Source code for biogeme.expressions.beta_parameters
"""Representation of unknown parameters (``Beta``) to be estimated.

:author: Michel Bierlaire
:date: Sat Apr 20 14:54:16 2024
"""

from __future__ import annotations

import logging

import jax.numpy as jnp

from biogeme.exceptions import BiogemeError
from .elementary_expressions import Elementary
from .elementary_types import TypeOfElementaryExpression
from .jax_utils import JaxFunctionType

# Module-level logger, named after this module.
logger: logging.Logger = logging.getLogger(__name__)
class Beta(Elementary):
    """Unknown parameters to be estimated from data.

    A parameter has a name, an initial (default) value, optional lower and
    upper bounds, and a status. A status of 0 means the parameter is free;
    any other status means it is fixed to its initial value.
    """

    def __init__(
        self,
        name: str,
        value: float,
        lowerbound: float | None,
        upperbound: float | None,
        status: int,
    ):
        """Constructor

        :param name: name of the parameter.
        :param value: default value.
        :param lowerbound: if different from None, imposes a lower bound on
            the value of the parameter during the optimization.
        :param upperbound: if different from None, imposes an upper bound
            on the value of the parameter during the optimization.
        :param status: if different from 0, the parameter is fixed to its
            default value, and not modified by the optimization algorithm.

        :raise BiogemeError: if the first parameter is not a str.
        :raise BiogemeError: if the second parameter is not an int or a
            float.
        """
        # Validate the name first: it is the first documented check, and
        # the error message about the value below embeds the name.
        if not isinstance(name, str):
            error_msg = (
                f"The first parameter must be a string and "
                f"not a {type(name)}: {name}"
            )
            raise BiogemeError(error_msg)
        if not isinstance(value, (int, float)):
            error_msg = (
                f"The second parameter for {name} must be "
                f"a float and not a {type(value)}: {value}"
            )
            raise BiogemeError(error_msg)
        super().__init__(name)
        self.init_value = value
        self.lower_bound = lowerbound
        self.upper_bound = upperbound
        self.status = status

    def deep_flat_copy(self) -> Beta:
        """Provides a copy of the expression.

        It is deep in the sense that it generates copies of the children.
        It is flat in the sense that any `MultipleExpression` is transformed
        into the currently selected expression. The flat part is irrelevant
        for this expression.

        Only the constructor arguments (name, initial value, bounds and
        status) are passed to the new instance.

        :return: a new instance of the same class.
        """
        return type(self)(
            name=self.name,
            value=self.init_value,
            lowerbound=self.lower_bound,
            upperbound=self.upper_bound,
            status=self.status,
        )

    def get_value(self) -> float:
        """Calculates the value of the expression if it is simple.

        :return: the current initial value of the parameter.
        """
        return self.init_value

    @property
    def safe_beta_id(self) -> int:
        """Check the presence of the ID before using it.

        :return: the id of the parameter (``self.specific_id``).
        :raise BiogemeError: if no id has been defined for the parameter.
        """
        if self.specific_id is None:
            raise BiogemeError(f"No id defined for parameter {self.name}")
        return self.specific_id

    def fix_betas(
        self,
        beta_values: dict[str, float],
        prefix: str | None = None,
        suffix: str | None = None,
    ) -> None:
        """Fix all the values of the Beta parameters appearing in the
        dictionary.

        :param beta_values: dictionary containing the betas to be fixed
            (as key) and their value.
        :param prefix: if not None, the parameter is renamed, with a prefix
            defined by this argument.
        :param suffix: if not None, the parameter is renamed, with a suffix
            defined by this argument.
        """
        # The lookup uses the name *before* any renaming below.
        if self.name in beta_values:
            self.init_value = beta_values[self.name]
            self.status = 1
        # Renaming is applied whether or not the parameter was fixed above.
        if prefix is not None:
            self.name = f"{prefix}{self.name}"
        if suffix is not None:
            self.name = f"{self.name}{suffix}"

    def change_init_values(self, betas: dict[str, float]) -> None:
        """Modifies the initial values of the Beta parameters.

        The fact that the parameters are fixed or free is irrelevant here.

        :param betas: dictionary where the keys are the names of the
            parameters, and the values are the new value for the parameters.
        """
        value = betas.get(self.name)
        if value is not None and value != self.init_value:
            # Changing a fixed parameter is allowed, but reported.
            if self.is_fixed:
                warning_msg = (
                    f'Parameter {self.name} is fixed, but its value '
                    f'is changed from {self.init_value} to {value}.'
                )
                logger.warning(warning_msg)
            self.init_value = value

    def recursive_construct_jax_function(
        self, numerically_safe: bool
    ) -> JaxFunctionType:
        """Returns a JAX-compatible function providing the value of the beta.

        If the parameter is free, the function extracts the value from the
        parameter vector using the index ``self.safe_beta_id``. Otherwise,
        it returns the initial value of the parameter.

        Both the index and the initial value are read from ``self`` when the
        returned function is called, not when it is built.

        :param numerically_safe: ignored by this expression.
        :return: a function taking ``(parameters, one_row, the_draws,
            the_random_variables)`` and returning the value as a JAX array.
        """
        if self.is_free:

            def the_jax_function(
                parameters: jnp.ndarray,
                one_row: jnp.ndarray,
                the_draws: jnp.ndarray,
                the_random_variables: jnp.ndarray,
            ) -> jnp.ndarray:
                return jnp.asarray(parameters[self.safe_beta_id])

        else:

            def the_jax_function(
                parameters: jnp.ndarray,
                one_row: jnp.ndarray,
                the_draws: jnp.ndarray,
                the_random_variables: jnp.ndarray,
            ) -> jnp.ndarray:
                return jnp.asarray(self.init_value)

        return the_jax_function