Source code for biogeme.calculator.multiple_formula
"""This module defines the MultiRowEvaluator class, which evaluates multiple expressionson a given database using JAX for efficient batched computation. It returns resultsas a pandas DataFrame with one column per expression and one row per observation.Michel BierlaireWed Apr 2 13:10:17 2025"""importjaximportjax.numpyasjnpimportnumpyasnpimportpandasaspdfrombiogeme.exceptionsimportBiogemeErrorfrombiogeme.expressionsimportbuild_vectorized_functionfrombiogeme.floating_pointimportJAX_FLOAT,NUMPY_FLOATfrombiogeme.model_elementsimportModelElements
class MultiRowEvaluator:
    """
    Evaluates multiple expressions on a common dataset using JAX and returns
    results as a pandas DataFrame.

    This class compiles all expressions into JAX functions and evaluates them
    in a single batched operation.

    :param model_elements: Object containing the expressions and all elements
        needed to calculate them (expressions registry, database, draws).
    :param numerically_safe: Forwarded to each expression's
        ``recursive_construct_jax_function`` when building its JAX function.
    :param use_jit: Passed to ``build_vectorized_function`` for every
        expression; when True, the combined evaluator is also wrapped with
        ``jax.jit``, otherwise it runs as a plain Python function.
    :raises BiogemeError: if ``model_elements`` is None.
    """

    def __init__(
        self,
        model_elements: ModelElements,
        numerically_safe: bool,
        use_jit: bool,
    ):
        if model_elements is None:
            raise BiogemeError('A model must be provided.')
        self.multiple_model_elements = model_elements
        registry = model_elements.expressions_registry
        self.free_betas_names = registry.free_betas_names
        self.data_jax = model_elements.database.data_jax
        self.draws_jax = model_elements.draws_management.draws_jax
        # Output column names, in the iteration order of the expressions
        # dict; the functions below are built in the same order.
        self.names = list(model_elements.expressions.keys())
        n_rv = registry.number_of_random_variables
        # Zero vector with one entry per registered random variable.
        # NOTE(review): presumably a placeholder because this evaluator does
        # not integrate over random variables — confirm.
        self.random_variables_jax = jnp.zeros((n_rv,), dtype=JAX_FLOAT)

        vectorized_functions = [
            build_vectorized_function(
                expr.recursive_construct_jax_function(
                    numerically_safe=numerically_safe
                ),
                use_jit=use_jit,
            )
            for expr in model_elements.expressions.values()
        ]
        self.vectorized_functions = vectorized_functions

        # The closure captures the local list rather than ``self``.
        # Capturing ``self`` would create a reference cycle
        # (self -> _evaluate_all -> closure -> self), which keeps the data and
        # draws arrays alive until the cyclic garbage collector runs.
        def evaluate_all_impl(params, data, draws, rv):
            # Stack the per-expression results along axis 1 so that each
            # expression becomes one column (matching ``self.names``).
            # NOTE(review): assumes each function returns one value per
            # observation, i.e. shape (n_obs,) — confirm.
            # jnp.stack needs at least one array, so this raises when the
            # model has no expressions.
            return jnp.stack(
                [vf(params, data, draws, rv) for vf in vectorized_functions],
                axis=1,
            )

        if use_jit:
            self._evaluate_all = jax.jit(evaluate_all_impl)
        else:
            self._evaluate_all = evaluate_all_impl

    def evaluate(self, the_betas: dict[str, float]) -> pd.DataFrame:
        """
        Evaluates all expressions using the provided beta values.

        :param the_betas: A dictionary mapping beta names to their numerical
            values. It is converted into a parameter vector by the
            expressions registry's ``get_betas_array``.
        :return: A pandas DataFrame with one column per expression (named and
            ordered like the expressions of the model) and one row per
            observation.
        """
        param_vector = (
            self.multiple_model_elements.expressions_registry.get_betas_array(
                the_betas
            )
        )
        values = self._evaluate_all(
            param_vector,
            self.data_jax,
            self.draws_jax,
            self.random_variables_jax,
        )
        # Convert the JAX array to a NumPy array with the project's float
        # dtype before wrapping it in a DataFrame.
        return pd.DataFrame(
            np.asarray(values, dtype=NUMPY_FLOAT), columns=self.names
        )