"""Arithmetic expressions accepted by Biogeme: variablesMichel BierlaireFri Jun 27 2025, 14:43:42"""from__future__importannotationsimportloggingimportjax.numpyasjnpimportpandasaspdimportpymcaspmfrompytensor.tensorimportTensorVariablefrombiogeme.bayesian_estimationimportcheck_shapefrombiogeme.exceptionsimportBiogemeErrorfrom.bayesianimportDimension,PymcModelBuilderTypefrom.elementary_expressionsimportElementaryfrom.elementary_typesimportTypeOfElementaryExpressionfrom.jax_utilsimportJaxFunctionTypelogger=logging.getLogger(__name__)
class Variable(Elementary):
    """Explanatory variable.

    This represents the explanatory variables of the choice model.
    Typically, they come from the data set.
    """

    expression_type = TypeOfElementaryExpression.VARIABLE

    def __init__(self, name: str):
        """Constructor

        :param name: name of the variable.
        """
        super().__init__(name)

    def deep_flat_copy(self) -> Variable:
        """Provides a copy of the expression.

        It is deep in the sense that it generates copies of the children.
        It is flat in the sense that any `MultipleExpression` is transformed
        into the currently selected expression. The flat part is irrelevant
        for this expression.

        :return: a new expression of the same type, built from the same name.
        """
        return type(self)(name=self.name)

    @property
    def safe_variable_id(self) -> int:
        """Id of the variable, checked to be defined before use.

        :return: ``self.specific_id``.
        :raise BiogemeError: if ``self.specific_id`` is ``None``.
        """
        if self.specific_id is None:
            raise BiogemeError(f"No id defined for variable {self.name}")
        return self.specific_id

    def recursive_construct_jax_function(
        self, numerically_safe: bool
    ) -> JaxFunctionType:
        """Generates a function to be used by biogeme_jax.

        :param numerically_safe: unused by this expression.
        :return: a function taking four arguments (the parameters, one row
            of the database, the draws and the random variables) and
            returning the entry of ``one_row`` at index
            ``self.safe_variable_id`` along the last axis.
        """

        def the_jax_function(
            parameters: jnp.ndarray,
            one_row: jnp.ndarray,
            the_draws: jnp.ndarray,
            the_random_variables: jnp.ndarray,
        ) -> jnp.ndarray:
            # The id is looked up when this function is called (or traced),
            # not when it is built, so a missing id surfaces at that point.
            return jnp.take(one_row, self.safe_variable_id, axis=-1)

        return the_jax_function

    def recursive_construct_pymc_model_builder(self) -> PymcModelBuilderType:
        """Generates recursively a function to be used by PyMC.

        :return: a builder that, given a dataframe, returns the ``pm.Data``
            container named ``self.name`` in the active PyMC model, holding
            the values of the dataframe column ``self.name``.
        """

        @check_shape
        def builder(dataframe: pd.DataFrame) -> TensorVariable:
            """Registers (or refreshes) the data container of the variable.

            :param dataframe: data; must contain a column ``self.name``.
            :return: the data container of the variable.
            :raise BiogemeError: if the column is missing from ``dataframe``.
            """
            model = pm.modelcontext(None)  # active model
            try:
                values = dataframe[self.name].to_numpy()
            except KeyError as e:
                raise BiogemeError(
                    f"Column '{self.name}' not found in the dataframe."
                ) from e

            existing = model.named_vars.get(self.name)
            if existing is not None:
                # A variable with this name is already registered in the
                # model: update its values and reuse it, presumably to avoid
                # registering a duplicate name.
                pm.set_data({self.name: values}, model=model)
                return existing

            return pm.Data(
                self.name,
                values,
                dims=Dimension.OBS,
            )

        return builder