"""Arithmetic expressions accepted by Biogeme: belongs toMichel BierlaireSat May 03 2025, 11:56:33"""from__future__importannotationsimportloggingimportpandasaspdimportpytensor.tensorasptfromjaximportnumpyasjnpfrom.base_expressionsimportExpressionOrNumericfrom.bayesianimportPymcModelBuilderTypefrom.jax_utilsimportJaxFunctionTypefrom.unary_expressionsimportUnaryOperatorlogger=logging.getLogger(__name__)
class BelongsTo(UnaryOperator):
    """Check if a value belongs to a set.

    The expression evaluates to 1.0 where the value of the child expression
    is an element of ``the_set``, and to 0.0 otherwise.
    """

    def __init__(self, child: ExpressionOrNumeric, the_set: set[float]):
        """Constructor

        :param child: arithmetic expression
        :type child: biogeme.expressions.Expression

        :param the_set: set of values
        :type the_set: set(float)
        """
        super().__init__(child)
        self.the_set: set[float] = the_set

    def deep_flat_copy(self) -> BelongsTo:
        """Provides a copy of the expression.

        It is deep in the sense that it generates copies of the children.
        It is flat in the sense that any `MultipleExpression` is transformed
        into the currently selected expression. The flat part is irrelevant
        for this expression.

        Note that only the child is copied: the new expression refers to the
        same ``the_set`` object as this one.

        :return: a new expression of the same type as ``self``.
        """
        child_copy = self.child.deep_flat_copy()
        return type(self)(child=child_copy, the_set=self.the_set)

    def recursive_construct_jax_function(
        self, numerically_safe: bool
    ) -> JaxFunctionType:
        """Generates a function to be used by biogeme_jax.

        Must be overloaded by each expression.

        :param numerically_safe: flag forwarded unchanged to the child
            expression when its own function is built.

        :return: a function taking the parameters, one row of the database,
            the draws and the random variables. It returns 1.0 where the
            value of the child belongs to the set, and 0.0 otherwise.
        """
        child_jax = self.child.recursive_construct_jax_function(
            numerically_safe=numerically_safe
        )

        def the_jax_function(
            parameters: jnp.ndarray,
            one_row: jnp.ndarray,
            the_draws: jnp.ndarray,
            the_random_variables: jnp.ndarray,
        ) -> jnp.ndarray:
            child_value = child_jax(
                parameters, one_row, the_draws, the_random_variables
            )
            # The set is read each time the function runs, so it reflects
            # the content of ``self.the_set`` at that moment.
            set_values = jnp.array(list(self.the_set))
            return jnp.where(jnp.isin(child_value, set_values), 1.0, 0.0)

        return the_jax_function

    def recursive_construct_pymc_model_builder(self) -> PymcModelBuilderType:
        """Generates recursively a function to be used by PyMc.

        Must be overloaded by each expression.

        :return: a builder that takes the dataframe and returns the
            expression in TensorVariable format, suitable for PyMc. The
            result has the dtype and shape of the child tensor, with ones
            where the child value belongs to the set and zeros elsewhere.
        """
        # Function-scope import: NumPy is not imported at module level here.
        import numpy as np

        child_builder = self.child.recursive_construct_pymc_model_builder()

        # The builder previously used ``set_values_np`` without defining it,
        # which raised a NameError. The set is materialised once, here, as a
        # 1-D array captured by the closure below.
        set_values_np = np.asarray(list(self.the_set))

        def builder(dataframe: pd.DataFrame) -> pt.TensorVariable:
            child_val: pt.TensorVariable = child_builder(dataframe)

            # Make a constant tensor of the set values, cast to the dtype of
            # the child so that both sides of the comparison share a dtype.
            # NOTE(review): with an integer child, non-integer set values are
            # presumably truncated by this cast -- confirm this is acceptable.
            set_vals = pt.constant(set_values_np, dtype=child_val.dtype)

            # Membership test: add a trailing axis to the child, compare it
            # against every set value, and reduce over that axis.
            membership = pt.any(pt.eq(child_val[..., None], set_vals), axis=-1)

            # Return 1.0 where in set, else 0.0, with the child's dtype/shape.
            return pt.where(
                membership, pt.ones_like(child_val), pt.zeros_like(child_val)
            )

        return builder