Source code for biogeme.expressions.numeric_expressions
"""Arithmetic expressions accepted by Biogeme: numeric expressions

Michel Bierlaire
Tue Mar 25 18:41:06 2025
"""

from __future__ import annotations

import logging

import jax.numpy as jnp
import pandas as pd
import pytensor.tensor as pt
from biogeme.bayesian_estimation import check_shape
from pytensor import config as pt_config
from pytensor.tensor import TensorVariable

from .base_expressions import Expression
from .bayesian import PymcModelBuilderType
from .jax_utils import JaxFunctionType
from ..floating_point import JAX_FLOAT

logger = logging.getLogger(__name__)
class Numeric(Expression):
    """
    Numerical expression for a simple number
    """

    def __init__(self, value: float | int | bool):
        """Constructor

        :param value: numerical value. Integers and booleans are accepted
            and stored as ``float`` via ``float(value)`` (so ``True`` becomes
            ``1.0`` and ``False`` becomes ``0.0``).
        :type value: float | int | bool
        """
        super().__init__()
        self.value = float(value)  #: numeric value

    def deep_flat_copy(self) -> Numeric:
        """Provides a copy of the expression.

        It is deep in the sense that it generates copies of the children.
        It is flat in the sense that any `MultipleExpression` is transformed
        into the currently selected expression.
        The flat part is irrelevant for this expression, which has no
        children: a new instance of the same class is built from the value.

        :return: a new expression of the same type, holding the same value.
        """
        return type(self)(value=self.value)

    def get_value(self) -> float:
        """Evaluates the value of the expression

        :return: value of the expression
        :rtype: float
        """
        return self.value

    def recursive_construct_jax_function(
        self, numerically_safe: bool
    ) -> JaxFunctionType:
        """
        Generates a function to be used by biogeme_jax for this numeric
        constant.

        :param numerically_safe: not used by this expression.
        :return: a function taking four arguments (the parameters, one row
            of the database, the draws, and the random variables). All four
            are ignored. The function returns the constant as a JAX array
            of dtype ``JAX_FLOAT``.
        """

        def the_jax_function(
            parameters: jnp.ndarray,
            one_row: jnp.ndarray,
            the_draws: jnp.ndarray,
            the_random_variables: jnp.ndarray,
        ) -> jnp.ndarray:
            # self.value is read when the function is called, not when it is
            # constructed.
            return jnp.array(self.value, dtype=JAX_FLOAT)

        return the_jax_function

    def recursive_construct_pymc_model_builder(self) -> PymcModelBuilderType:
        """
        Generates recursively a function to be used by PyMc for this numeric
        constant.

        :return: a builder (wrapped by ``check_shape``) that takes the
            dataframe and returns a TensorVariable of shape
            ``(len(dataframe),)``. Every entry equals the numeric value, in
            dtype ``pytensor.config.floatX``.
        """

        @check_shape
        def builder(dataframe: pd.DataFrame) -> TensorVariable:
            # Produce a constant vector of length len(dataframe) with the
            # numeric value (one entry per row of the data).
            n = len(dataframe)
            return pt.full((n,), self.value, dtype=pt_config.floatX)

        return builder