"""Arithmetic expressions accepted by Biogeme: expm1Michel BierlaireMon Nov 03 2025, 16:44:14"""from__future__importannotationsimportloggingimportnumpyasnpimportpandasaspdimportpytensor.tensorasptfrombiogeme.floating_pointimportMAX_EXP_ARG,MIN_EXP_ARGfromjaximportnumpyasjnpfrom.base_expressionsimportExpressionOrNumericfrom.bayesianimportPymcModelBuilderTypefrom.jax_utilsimportJaxFunctionTypefrom.unary_expressionsimportUnaryOperatorlogger=logging.getLogger(__name__)
class expm1(UnaryOperator):
    """Exponential minus one, :math:`e^x - 1`.

    Every evaluation path delegates to a dedicated ``expm1`` routine
    (``np.expm1``, ``jnp.expm1`` or ``pt.expm1``) rather than computing
    ``exp(x) - 1`` directly. The direct form loses precision through
    cancellation when ``x`` is close to zero.
    """

    def __init__(self, child: ExpressionOrNumeric) -> None:
        """Create the expression.

        :param child: expression (or numeric value) ``x`` whose
            exponential minus one is computed.
        """
        super().__init__(child)

    def deep_flat_copy(self) -> expm1:
        """Return a deep copy of the expression, including its child."""
        return type(self)(child=self.child.deep_flat_copy())

    def get_value(self) -> float:
        """Evaluate the expression with NumPy.

        The child value is passed to ``np.expm1`` as is; no clamping is
        applied on this path.

        :return: :math:`e^x - 1`, where ``x`` is the value of the child.
        """
        argument = self.child.get_value()
        return np.expm1(argument)

    def recursive_construct_jax_function(
        self, numerically_safe: bool
    ) -> JaxFunctionType:
        """Build the JAX function computing :math:`e^x - 1` (Biogeme-JAX).

        :param numerically_safe: forwarded to the child builder. When true,
            the child value is also clipped to ``[MIN_EXP_ARG, MAX_EXP_ARG]``
            (constants from ``biogeme.floating_point``) before ``jnp.expm1``
            is applied. This is presumably meant to keep the exponential
            from overflowing. Outside those bounds the clipped argument is
            constant, so its derivative with respect to the child is zero.
        :return: callable(parameters, one_row, the_draws, the_random_variables)
        """
        child_jax = self.child.recursive_construct_jax_function(
            numerically_safe=numerically_safe
        )
        # Decide once, while building, whether the argument gets clamped;
        # the returned closure only reads this flag.
        clamp_argument = bool(numerically_safe)

        def the_jax_function(
            parameters: jnp.ndarray,
            one_row: jnp.ndarray,
            the_draws: jnp.ndarray,
            the_random_variables: jnp.ndarray,
        ) -> jnp.ndarray:
            argument = child_jax(
                parameters, one_row, the_draws, the_random_variables
            )
            if clamp_argument:
                argument = jnp.clip(argument, min=MIN_EXP_ARG, max=MAX_EXP_ARG)
            return jnp.expm1(argument)

        return the_jax_function

    def recursive_construct_pymc_model_builder(self) -> PymcModelBuilderType:
        """Build the PyMC (PyTensor) representation of :math:`e^x - 1`.

        No clamping is applied on this path; the child tensor is passed
        straight to ``pt.expm1``.

        :return: a builder mapping a dataframe to a PyTensor
            ``TensorVariable``.
        """
        child_builder = self.child.recursive_construct_pymc_model_builder()

        def builder(dataframe: pd.DataFrame) -> pt.TensorVariable:
            return pt.expm1(child_builder(dataframe=dataframe))

        return builder