Source code for biogeme.expressions.power_constant
"""Arithmetic expressions accepted by Biogeme: power constantMichel BierlaireSat Jun 28 2025, 12:20:48"""from__future__importannotationsimportloggingimportjaximportjax.numpyasjnpimportnumpyasnpfrombiogeme.exceptionsimportBiogemeErrorfrombiogeme.floating_pointimportJAX_FLOATfrom.base_expressionsimportExpressionOrNumericfrom.jax_utilsimportJaxFunctionTypefrom.unary_expressionsimportUnaryOperatorlogger=logging.getLogger(__name__)
class PowerConstant(UnaryOperator):
    """Raise the argument to a constant power.

    The exponent is a plain number fixed at construction time. When it is
    within machine precision of an integer, that integer is also stored in
    ``integer_exponent``; it is ``None`` otherwise. The integer form is what
    allows negative bases to be evaluated.
    """

    def __init__(self, child: ExpressionOrNumeric, exponent: float):
        """Constructor

        :param child: arithmetic expression to be raised to the power
        :type child: biogeme.expressions.Expression

        :param exponent: constant exponent
        :type exponent: float
        """
        super().__init__(child)
        self.exponent: float = exponent
        # An exponent closer to an integer than the machine epsilon is
        # considered to be that integer.
        epsilon = np.finfo(float).eps
        if abs(exponent - round(exponent)) < epsilon:
            self.integer_exponent = int(round(exponent))
        else:
            self.integer_exponent = None

    def deep_flat_copy(self) -> PowerConstant:
        """Provides a copy of the expression. It is deep in the sense that it
        generates copies of the children.

        It is flat in the sense that any `MultipleExpression` is transformed
        into the currently selected expression. The flat part is irrelevant
        for this expression.

        :return: a new expression of the same type, built from a copy of the
            child and the same exponent.
        """
        copy_child = self.child.deep_flat_copy()
        return type(self)(child=copy_child, exponent=self.exponent)

    def get_value(self) -> float:
        """Evaluates the value of the expression

        The rules are applied in this order:

        - a child value of 0 yields 0.0, whatever the exponent;
        - a positive child value is raised to ``exponent``;
        - a negative child value is raised to ``integer_exponent`` when the
          exponent is an integer;
        - a negative child value with a non integer exponent is an error;
        - a child value that is NaN yields NaN.

        :return: value of the expression
        :rtype: float

        :raise BiogemeError: if the child value is negative and the exponent
            is not an integer.
        """
        v = self.child.get_value()
        if v == 0:
            return 0.0
        if v > 0:
            return v**self.exponent
        if self.integer_exponent is not None:
            return v**self.integer_exponent
        if v < 0:
            error_msg = f'Cannot calculate {v}**{self.exponent}'
            raise BiogemeError(error_msg)
        # Every comparison above is False only for NaN: propagate it.
        return float('nan')

    def recursive_construct_jax_function(
        self, numerically_safe: bool
    ) -> JaxFunctionType:
        """Generates a function to be used by biogeme_jax.

        If ``numerically_safe`` is False, the generated function simply calls
        ``jnp.power`` on the value of the child.

        If ``numerically_safe`` is True and the exponent is an integer ``n``,
        the magnitude of the child value ``x`` is replaced by
        ``sqrt(x**2 + eps)``, raised to ``|n|``, inverted if the exponent is
        negative, and multiplied by ``(-1)**n`` if ``x`` is negative. The
        result is forced to 0 when ``x`` lies in ``[-eps, eps]`` (``n > 0``)
        or is exactly 0 (``n <= 0``).

        If ``numerically_safe`` is True and the exponent is not an integer,
        the result is 0 when ``x`` is 0, NaN when ``x`` is negative, and
        ``exp(exponent * log(max(x, eps)))`` otherwise.

        ``eps`` is the machine epsilon of ``JAX_FLOAT``.

        :param numerically_safe: if True, the guarded implementation
            described above is generated.
        :return: the function takes four parameters: the parameters, one row
            of the database, the draws and the random variables.
        """
        child_jax = self.child.recursive_construct_jax_function(
            numerically_safe=numerically_safe
        )

        if numerically_safe:

            def the_jax_function(
                parameters: jnp.ndarray,
                one_row: jnp.ndarray,
                the_draws: jnp.ndarray,
                the_random_variables: jnp.ndarray,
            ) -> jnp.ndarray:
                child_value = child_jax(
                    parameters, one_row, the_draws, the_random_variables
                )
                epsilon = jnp.finfo(JAX_FLOAT).eps

                # The test on the exponent is resolved when the function is
                # built: only one of the two implementations is traced.
                if self.integer_exponent is not None:
                    abs_exponent = jnp.abs(self.integer_exponent)
                    # Strictly positive stand-in for |child_value|.
                    safe_value = jnp.sqrt(child_value**2 + epsilon)
                    powered = safe_value**abs_exponent
                    # Sign lost by working on the magnitude.
                    signed = jnp.where(
                        child_value < 0, (-1) ** self.integer_exponent, 1.0
                    )
                    result = jnp.where(self.exponent < 0, 1.0 / powered, powered)
                    near_zero = jnp.logical_and(
                        child_value >= -epsilon, child_value <= epsilon
                    )

                    def zero_case(_):
                        return jnp.array(0.0, dtype=JAX_FLOAT)

                    def nonzero_case(_):
                        return result * signed

                    return jax.lax.cond(
                        (
                            near_zero
                            if self.integer_exponent > 0
                            else child_value == 0.0
                        ),
                        zero_case,
                        nonzero_case,
                        operand=None,
                    )

                def nan_branch(_):
                    # Same dtype as the other branches of the conditional.
                    return jnp.array(jnp.nan, dtype=JAX_FLOAT)

                def safe_branch(_):
                    # Lower bound on the argument of the logarithm.
                    # ``jnp.maximum`` is used instead of ``jnp.clip`` with its
                    # deprecated ``a_min`` keyword.
                    bounded_value = jnp.maximum(child_value, epsilon)
                    return jnp.exp(self.exponent * jnp.log(bounded_value))

                return jax.lax.cond(
                    child_value == 0.0,
                    lambda _: jnp.array(0.0, dtype=JAX_FLOAT),
                    lambda _: jax.lax.cond(
                        child_value < 0.0, nan_branch, safe_branch, operand=None
                    ),
                    operand=None,
                )

            return the_jax_function

        def the_jax_function(
            parameters: jnp.ndarray,
            one_row: jnp.ndarray,
            the_draws: jnp.ndarray,
            the_random_variables: jnp.ndarray,
        ) -> jnp.ndarray:
            child_value = child_jax(
                parameters, one_row, the_draws, the_random_variables
            )
            return jnp.power(child_value, self.exponent)

        return the_jax_function