def deep_flat_copy(self) -> log:
    """Return a deep, flat copy of this expression.

    "Deep" means that the child expression is copied as well instead of
    being shared with the original. "Flat" means that any
    `MultipleExpression` is replaced by its currently selected expression;
    this aspect is irrelevant for this expression.

    :return: a new expression of the same type, built on a copy of the child.
    """
    return type(self)(child=self.child.deep_flat_copy())
def get_value(self) -> float:
    """Evaluate the expression numerically.

    :return: natural logarithm of the value of the child expression,
        as computed by `np.log`.
    :rtype: float
    """
    child_value = self.child.get_value()
    return np.log(child_value)
def recursive_construct_jax_function(
    self, numerically_safe: bool
) -> JaxFunctionType:
    """Build the JAX function that evaluates the logarithm of the child.

    :param numerically_safe: if True, the returned function replaces the
        logarithm by its tangent line at EPSILON whenever the argument is
        below EPSILON, so that the value stays finite for zero or negative
        arguments. If False, `jnp.log` is applied directly.
    :return: a function taking four arguments (the parameters, one row of
        the database, the draws and the random variables) and returning
        the value of the expression.
    """
    child_jax = self.child.recursive_construct_jax_function(
        numerically_safe=numerically_safe
    )

    if numerically_safe:

        def the_jax_function(
            parameters: jnp.ndarray,
            one_row: jnp.ndarray,
            the_draws: jnp.ndarray,
            the_random_variables: jnp.ndarray,
        ) -> jnp.ndarray:
            child_value = child_jax(
                parameters, one_row, the_draws, the_random_variables
            )
            epsilon = EPSILON
            # Tangent of log at epsilon: same value and same slope there.
            slope = 1.0 / epsilon
            intercept = jnp.log(epsilon) - slope * epsilon
            approx_log = slope * child_value + intercept

            below_threshold = child_value < epsilon
            # A data-dependent `jax.lax.cond` is lowered to a select when
            # batched with `vmap`, so both branches are evaluated. The
            # derivative of log at 0 then yields 0/0 = NaN in reverse mode,
            # even though that branch is discarded. Clamping the argument
            # of the unselected log keeps its derivative finite.
            safe_argument = jnp.where(below_threshold, epsilon, child_value)
            return jnp.where(
                below_threshold, approx_log, jnp.log(safe_argument)
            )

        return the_jax_function

    def the_jax_function(
        parameters: jnp.ndarray,
        one_row: jnp.ndarray,
        the_draws: jnp.ndarray,
        the_random_variables: jnp.ndarray,
    ) -> jnp.ndarray:
        child_value = child_jax(
            parameters, one_row, the_draws, the_random_variables
        )
        return jnp.log(child_value)

    return the_jax_function