class LinearUtility(Expression):
    """When the utility function is linear, it is expressed as a list of
    terms, where a parameter multiplies a variable.
    """

    def __init__(self, list_of_terms: list[LinearTermTuple]):
        """Constructor

        :param list_of_terms: a non-empty list of tuples. Each tuple
            contains first a Beta parameter, second a Variable.
        :type list_of_terms: list(tuple(biogeme.expressions.Beta,
            biogeme.expressions.Variable))

        :raises biogeme.exceptions.BiogemeError: if the object is not an
            iterable of (parameter, variable) pairs, or if it is empty.
        """
        super().__init__()

        # Materialize once: the terms are traversed more than once, which
        # would silently exhaust a generator passed by the caller.
        try:
            terms = list(list_of_terms)
        except TypeError as err:
            raise BiogemeError(
                f'A list of (Beta, Variable) pairs is expected. '
                f'Got: {list_of_terms}'
            ) from err

        if not terms:
            # Without this check, splitting an empty list into betas and
            # variables fails with an unrelated "not enough values to
            # unpack" ValueError.
            raise BiogemeError(
                'A linear utility requires at least one (Beta, Variable) term.'
            )

        betas = []
        variables = []
        for term in terms:
            # Report malformed terms (not a pair) with the documented
            # exception type instead of a raw ValueError/TypeError.
            try:
                beta, variable = term
            except (TypeError, ValueError) as err:
                raise BiogemeError(
                    f'Each term must be a (Beta, Variable) pair. Got: {term}'
                ) from err
            if not isinstance(beta, Beta) or not isinstance(variable, Variable):
                raise BiogemeError(
                    f'Each term must be a (Beta, Variable) pair. Got: ({beta}, {variable})'
                )
            betas.append(beta)
            variables.append(variable)

        #: list of parameters
        self.betas = betas
        #: list of variables
        self.variables = variables
        #: list of terms, as (parameter, variable) tuples
        self.list_of_terms = list(zip(self.betas, self.variables))
        self.children += self.betas + self.variables

    def deep_flat_copy(self) -> LinearUtility:
        """Provides a copy of the expression. It is deep in the sense that
        it generates copies of the children. It is flat in the sense that
        any `MultipleExpression` is transformed into the currently selected
        expression. The flat part is irrelevant for this expression.
        """
        copy_list_of_terms = [
            LinearTermTuple(
                beta=beta.deep_flat_copy(),
                x=variable.deep_flat_copy(),
            )
            for beta, variable in self.list_of_terms
        ]
        return type(self)(list_of_terms=copy_list_of_terms)

    def recursive_construct_jax_function(
        self, numerically_safe: bool
    ) -> JaxFunctionType:
        """Generates the function evaluating the linear utility, to be used
        by biogeme_jax.

        :param numerically_safe: flag forwarded unchanged to the function
            builders of every parameter and variable of the utility.
        :return: a function taking four arguments (the parameters, one row
            of the database, the draws, and the random variables) and
            returning the sum over the terms of the value of each parameter
            multiplied by the value of its variable.
        """
        # Build the children's functions once, outside the returned closure.
        beta_fns = [
            b.recursive_construct_jax_function(numerically_safe=numerically_safe)
            for b in self.betas
        ]
        variable_fns = [
            v.recursive_construct_jax_function(numerically_safe=numerically_safe)
            for v in self.variables
        ]

        def the_jax_function(
            parameters: jnp.ndarray,
            one_row: jnp.ndarray,
            the_draws: jnp.ndarray,
            the_random_variables: jnp.ndarray,
        ) -> jnp.ndarray:
            beta_values = jnp.array(
                [
                    fn(parameters, one_row, the_draws, the_random_variables)
                    for fn in beta_fns
                ]
            )
            variable_values = jnp.array(
                [
                    fn(parameters, one_row, the_draws, the_random_variables)
                    for fn in variable_fns
                ]
            )
            # Inner product: sum_k beta_k * x_k.
            return jnp.dot(beta_values, variable_values)

        return the_jax_function