"""Arithmetic expressions accepted by Biogeme: powerMichel Bierlaire10.04.2025 15:56"""from__future__importannotationsimportloggingimportjaximportjax.numpyasjnpfrombiogeme.floating_pointimportJAX_FLOATfrom.base_expressionsimportExpressionOrNumericfrom.binary_expressionsimportBinaryOperatorfrom.jax_utilsimportJaxFunctionTypefrom.numeric_expressionsimportNumericlogger=logging.getLogger(__name__)
class Power(BinaryOperator):
    """Power expression: ``left ** right``.

    When the base is a numeric constant, two cases are simplified at
    construction time:

    - a base equal to 1 always evaluates to 1;
    - a base equal to 0 evaluates to 0. This convention is applied whatever
      the exponent, consistently with the numerically safe JAX
      implementation, which also returns 0 whenever the base is 0.
    """

    def __init__(self, left: ExpressionOrNumeric, right: ExpressionOrNumeric):
        """Constructor

        :param left: base of the power (expression or plain number).
        :type left: biogeme.expressions.Expression
        :param right: exponent of the power (expression or plain number).
        :type right: biogeme.expressions.Expression
        """
        super().__init__(left, right)
        self.simplified = None
        # Inspect the stored child rather than the raw argument: a plain
        # number such as the 0 in ``0 ** beta`` is stored as an expression
        # (``deep_flat_copy`` relies on ``self.left`` being one). Checking the
        # raw argument missed that case, and made a copy of ``0 ** beta``
        # (rebuilt from a Numeric base) behave differently from the original.
        base = self.left
        if isinstance(base, Numeric):
            if base.value == 0:
                self.simplified = Numeric(0)
            elif base.value == 1:
                self.simplified = Numeric(1)

    def deep_flat_copy(self) -> Power:
        """Provides a copy of the expression.

        It is deep in the sense that it generates copies of the children.
        It is flat in the sense that any `MultipleExpression` is transformed
        into the currently selected expression. The flat part is irrelevant
        for this expression.

        :return: a new expression of the same type, built from copies of the
            base and the exponent.
        """
        copy_left = self.left.deep_flat_copy()
        copy_right = self.right.deep_flat_copy()
        return type(self)(left=copy_left, right=copy_right)

    def get_value(self) -> float:
        """Evaluates the value of the expression.

        If the expression was simplified at construction time, the value of
        the simplified constant is returned; otherwise ``left ** right`` is
        computed with Python's ``**`` operator.

        :return: value of the expression
        :rtype: float
        """
        if self.simplified is not None:
            return self.simplified.get_value()
        return self.left.get_value() ** self.right.get_value()

    def recursive_construct_jax_function(
        self, numerically_safe: bool
    ) -> JaxFunctionType:
        """Generates the JAX function evaluating ``left ** right``.

        :param numerically_safe: if False, ``jnp.power`` is applied directly.
            If True, the generated function returns 0 when the base is
            exactly 0, NaN when the base is negative, and otherwise
            ``exp(exponent * log(max(base, eps)))``, where ``eps`` is the
            machine epsilon of ``JAX_FLOAT``.
        :return: a function taking the parameters, one row of the database,
            the draws and the random variables, and returning the value of
            the expression.
        """
        if self.simplified is not None:
            return self.simplified.recursive_construct_jax_function(
                numerically_safe=numerically_safe
            )
        left_jax: JaxFunctionType = self.left.recursive_construct_jax_function(
            numerically_safe=numerically_safe
        )
        right_jax: JaxFunctionType = self.right.recursive_construct_jax_function(
            numerically_safe=numerically_safe
        )

        if not numerically_safe:

            def the_jax_function(
                parameters: jnp.ndarray,
                one_row: jnp.ndarray,
                the_draws: jnp.ndarray,
                the_random_variables: jnp.ndarray,
            ) -> jnp.ndarray:
                base = left_jax(parameters, one_row, the_draws, the_random_variables)
                exponent = right_jax(
                    parameters, one_row, the_draws, the_random_variables
                )
                return jnp.power(base, exponent)

            return the_jax_function

        # Constant for the whole expression: computed once, not at each call.
        epsilon = jnp.finfo(JAX_FLOAT).eps

        def the_safe_jax_function(
            parameters: jnp.ndarray,
            one_row: jnp.ndarray,
            the_draws: jnp.ndarray,
            the_random_variables: jnp.ndarray,
        ) -> jnp.ndarray:
            base = left_jax(parameters, one_row, the_draws, the_random_variables)
            exponent = right_jax(parameters, one_row, the_draws, the_random_variables)

            def zero_branch(_):
                return jnp.array(0.0, dtype=JAX_FLOAT)

            def nan_branch(_):
                # Explicit dtype so that every lax.cond branch returns the
                # same type (a bare Python float is weakly typed).
                return jnp.array(jnp.nan, dtype=JAX_FLOAT)

            def power_branch(_):
                # Floor the base at epsilon so that log() is always finite.
                # Equivalent to jnp.clip(base, a_min=epsilon) without relying
                # on the deprecated ``a_min`` keyword.
                safe_base = jnp.maximum(base, epsilon)
                return jnp.exp(exponent * jnp.log(safe_base))

            def nonzero_branch(_):
                return jax.lax.cond(base < 0.0, nan_branch, power_branch, None)

            # lax.cond (rather than jnp.where) so that the branch that is not
            # taken is not differentiated: a huge or infinite value there
            # would otherwise turn into a NaN gradient.
            return jax.lax.cond(base == 0.0, zero_branch, nonzero_branch, None)

        return the_safe_jax_function