"""Arithmetic expressions accepted by Biogeme: logzero

Michel Bierlaire
Sat Jun 28 2025, 12:17:05
"""

from __future__ import annotations

import logging

import jax
import jax.numpy as jnp
import numpy as np
import pandas as pd
import pytensor.tensor as pt
from biogeme.expressions.bayesian import PymcModelBuilderType
from biogeme.floating_point import EPSILON, JAX_FLOAT

from .base_expressions import ExpressionOrNumeric
from .jax_utils import JaxFunctionType
from .unary_expressions import UnaryOperator

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
class logzero(UnaryOperator):
    """
    logarithm expression. Returns zero if the argument is zero.
    """

    def __init__(self, child: ExpressionOrNumeric):
        """Constructor

        :param child: first arithmetic expression
        :type child: biogeme.expressions.Expression
        """
        super().__init__(child)

    def deep_flat_copy(self) -> logzero:
        """Provides a copy of the expression. It is deep in the sense that it
        generates copies of the children. It is flat in the sense that any
        `MultipleExpression` is transformed into the currently selected
        expression. The flat part is irrelevant for this expression.
        """
        copy_child = self.child.deep_flat_copy()
        return type(self)(child=copy_child)

    def get_value(self) -> float:
        """Evaluates the value of the expression

        :return: ``0.0`` if the child evaluates to zero, the natural logarithm
            of the child's value otherwise.
        :rtype: float
        """
        v = self.child.get_value()
        # Float zero (not the int 0) so both branches match the annotation.
        return 0.0 if v == 0 else np.log(v)

    def recursive_construct_jax_function(
        self, numerically_safe: bool
    ) -> JaxFunctionType:
        """
        Generates a function to be used by biogeme_jax.

        Both variants select the result with ``jnp.where`` instead of
        ``jax.lax.cond``. JAX documents that a ``lax.cond`` whose predicate is
        batched by ``vmap`` is converted to a select that evaluates both
        branches; ``log(0) = -inf`` would then be computed, and its infinite
        derivative would make the gradient NaN even though the value itself is
        discarded. Feeding ``log`` a harmless argument (1.0) wherever its
        result is not kept ("double where") avoids this, and also makes the
        function work element-wise on non-scalar inputs.

        :param numerically_safe: if True, nonzero arguments strictly below
            ``EPSILON`` are mapped through the tangent line of ``log`` at
            ``EPSILON`` instead of ``log`` itself.
        :return: the function takes four arguments: the parameters, one row
            of the database, the draws and the random variables.
        """
        child_jax = self.child.recursive_construct_jax_function(
            numerically_safe=numerically_safe
        )

        if numerically_safe:

            def the_jax_function(
                parameters: jnp.ndarray,
                one_row: jnp.ndarray,
                the_draws: jnp.ndarray,
                the_random_variables: jnp.ndarray,
            ) -> jnp.ndarray:
                child_value = child_jax(
                    parameters, one_row, the_draws, the_random_variables
                )
                is_zero = child_value == 0.0
                is_small = child_value < EPSILON
                # Tangent line of log at EPSILON: equal to log(x) in value and
                # slope at x == EPSILON, and finite for every x below it.
                slope = 1.0 / EPSILON
                intercept = jnp.log(EPSILON) - slope * EPSILON
                approx_log = slope * child_value + intercept
                # Only take the log where its result is actually kept.
                safe_child = jnp.where(is_small, 1.0, child_value)
                log_value = jnp.where(is_small, approx_log, jnp.log(safe_child))
                return jnp.where(
                    is_zero, jnp.asarray(0.0, dtype=JAX_FLOAT), log_value
                )

            return the_jax_function

        def the_jax_function(
            parameters: jnp.ndarray,
            one_row: jnp.ndarray,
            the_draws: jnp.ndarray,
            the_random_variables: jnp.ndarray,
        ) -> jnp.ndarray:
            child_value = child_jax(
                parameters, one_row, the_draws, the_random_variables
            )
            is_zero = child_value == 0.0
            # Replace zeros by 1.0 before the log, so neither -inf nor its
            # infinite derivative appears in the discarded branch.
            safe_child = jnp.where(is_zero, 1.0, child_value)
            return jnp.where(
                is_zero, jnp.asarray(0.0, dtype=JAX_FLOAT), jnp.log(safe_child)
            )

        return the_jax_function

    def recursive_construct_pymc_model_builder(self) -> PymcModelBuilderType:
        """Builds the PyTensor expression of this node for the PyMC model.

        :return: a builder mapping a dataframe to a tensor equal to 0 where
            the child is 0, and to ``log(child)`` elsewhere.
        """
        child_pymc = self.child.recursive_construct_pymc_model_builder()

        def builder(dataframe: pd.DataFrame) -> pt.TensorVariable:
            x = child_pymc(dataframe=dataframe)
            # Boolean mask for zeros, cast to x's dtype: 1 where x == 0, else 0.
            is_zero = pt.eq(x, 0.0)
            m = pt.cast(is_zero, x.dtype)
            one_minus_m = 1.0 - m
            # Shift exact zeros to EPSILON so the log and its gradient stay
            # finite there; nonzero entries are untouched because m == 0.
            safe_log = pt.log(x + m * EPSILON)
            # Branch-free masking: exactly 0 where x == 0, log(x) otherwise.
            return one_minus_m * safe_log

        return builder