"""Arithmetic expressions accepted by Biogeme: logical andMichel BierlaireSat Jun 14 2025, 10:14:27"""from__future__importannotationsimportloggingimportjaximportjax.numpyasjnpimportpandasaspdfrombiogeme.expressionsimportPymcModelBuilderTypefrombiogeme.floating_pointimportJAX_FLOATfrompytensor.tensorimportTensorVariable,neq,switchfrom.base_expressionsimportExpressionOrNumericfrom.binary_expressionsimportBinaryOperatorfrom.jax_utilsimportJaxFunctionTypelogger=logging.getLogger(__name__)
class And(BinaryOperator):
    """
    Logical and.

    Numeric convention: an operand equal to 0.0 is false, any other value
    is true. The expression evaluates to 1.0 when both operands are true
    and to 0.0 otherwise.
    """

    def __init__(self, left: ExpressionOrNumeric, right: ExpressionOrNumeric):
        """Constructor

        :param left: first operand (an expression or a numeric value)
        :type left: ExpressionOrNumeric

        :param right: second operand (an expression or a numeric value)
        :type right: ExpressionOrNumeric
        """
        super().__init__(left, right)

    def deep_flat_copy(self) -> And:
        """Provides a copy of the expression.

        It is deep in the sense that it generates copies of the children.
        It is flat in the sense that any `MultipleExpression` is transformed
        into the currently selected expression. The flat part is irrelevant
        for this expression.

        :return: a new instance of the same class built from copies of the
            left and right children.
        """
        copy_left = self.left.deep_flat_copy()
        copy_right = self.right.deep_flat_copy()
        return type(self)(left=copy_left, right=copy_right)

    def __str__(self) -> str:
        return f'({self.left} and {self.right})'

    def __repr__(self) -> str:
        return f'({repr(self.left)} and {repr(self.right)})'

    def get_value(self) -> float:
        """Evaluates the value of the expression.

        The right operand is only evaluated when the left operand is
        nonzero (short-circuit evaluation).

        :return: 1.0 if both operands are nonzero, 0.0 otherwise.
        :rtype: float
        """
        if self.left.get_value() == 0.0:
            return 0.0
        if self.right.get_value() == 0.0:
            return 0.0
        return 1.0

    def recursive_construct_jax_function(
        self, numerically_safe: bool
    ) -> JaxFunctionType:
        """Generates a function to be used by biogeme_jax.

        Implements the logical AND with the same numeric convention as
        :meth:`get_value`: zero is false, nonzero is true, and the result is
        1.0 or 0.0.

        :param numerically_safe: forwarded unchanged to both children when
            building their JAX functions.
        :return: a function taking four arguments (the parameters, one row
            of the database, the draws, and the random variables) and
            returning 1.0 if both operands are nonzero, 0.0 otherwise, with
            dtype JAX_FLOAT.
        """
        left_jax: JaxFunctionType = self.left.recursive_construct_jax_function(
            numerically_safe=numerically_safe
        )
        right_jax: JaxFunctionType = self.right.recursive_construct_jax_function(
            numerically_safe=numerically_safe
        )

        def the_jax_function(
            parameters: jnp.ndarray,
            one_row: jnp.ndarray,
            the_draws: jnp.ndarray,
            the_random_variables: jnp.ndarray,
        ) -> jnp.ndarray:
            left_value = left_jax(parameters, one_row, the_draws, the_random_variables)

            def if_true(_):
                right_value = right_jax(
                    parameters, one_row, the_draws, the_random_variables
                )
                # Cast explicitly: ``jnp.where`` on Python scalars follows
                # JAX's default float dtype, whereas ``if_false`` returns
                # JAX_FLOAT. ``lax.cond`` rejects branches whose output
                # types differ, so both branches must agree on the dtype.
                return jnp.where(right_value != 0.0, 1.0, 0.0).astype(JAX_FLOAT)

            def if_false(_):
                return jnp.array(0.0, dtype=JAX_FLOAT)

            # The right operand is only requested in the branch selected when
            # the left operand is nonzero.
            return jax.lax.cond(left_value != 0.0, if_true, if_false, operand=None)

        return the_jax_function

    def recursive_construct_pymc_model_builder(self) -> PymcModelBuilderType:
        """
        Generates recursively a function to be used by PyMC.

        Implements logical AND using numeric convention:
        - 0 → False
        - ≠0 → True
        Returns 1.0 if both sides are nonzero, else 0.0.

        :return: a builder that takes a pandas DataFrame and returns the
            corresponding symbolic tensor.
        """
        left_pymc = self.left.recursive_construct_pymc_model_builder()
        right_pymc = self.right.recursive_construct_pymc_model_builder()

        def builder(dataframe: pd.DataFrame) -> TensorVariable:
            left_value = left_pymc(dataframe=dataframe)
            right_value = right_pymc(dataframe=dataframe)

            # Convert to boolean using nonzero test
            left_bool = neq(left_value, 0.0)
            right_bool = neq(right_value, 0.0)

            # Logical and, then convert back to float (0.0 or 1.0)
            return switch(left_bool & right_bool, 1.0, 0.0)

        return builder