Source code for biogeme.expressions.distributed_parameter
"""Arithmetic expressions accepted by Biogeme: DistributedParameter

Michel Bierlaire
10.04.2025 11:48
"""

from __future__ import annotations

import logging

import pandas as pd
import pymc as pm
import pytensor.tensor as pt
from jax import numpy as jnp

from . import (
    ExpressionOrNumeric,
)
from .bayesian import Dimension, PymcModelBuilderType
from .jax_utils import JaxFunctionType
from .unary_expressions import UnaryOperator

logger = logging.getLogger(__name__)
class DistributedParameter(UnaryOperator):
    """Distributed parameter in Bayesian estimation.

    The parameter is defined by a child expression (most of the time
    ``mu + sigma * xi``). For JAX, the parameter simply evaluates to its
    child. For PyMC, the child value is registered in the active model as a
    ``pm.Deterministic`` variable named after the parameter.
    """

    def __init__(
        self,
        name: str,
        child: ExpressionOrNumeric,
    ) -> None:
        """Constructor

        :param name: name of the parameter. It is also used as the name of
            the deterministic variable created in the PyMC model.
        :param child: expression of the parameter. Most of the time,
            mu + sigma * xi.
        """
        super().__init__(child)
        self.name = name
        # Column of the dataframe identifying individuals (panel data).
        # ``None`` means the parameter is built at the observation level.
        # It is not a constructor argument; presumably assigned by the code
        # setting up panel estimation -- verify against callers.
        self.panel_column = None

    def deep_flat_copy(self) -> DistributedParameter:
        """Provides a copy of the expression.

        It is deep in the sense that it generates copies of the children.
        It is flat in the sense that any `MultipleExpression` is transformed
        into the currently selected expression. The flat part is irrelevant
        for this expression.

        The panel column is carried over, so that the copy builds the same
        individual-level PyMC variables as the original.

        :return: a copy of the expression.
        """
        copy_child = self.child.deep_flat_copy()
        the_copy = type(self)(name=self.name, child=copy_child)
        # ``panel_column`` is not a constructor argument: without this
        # assignment the copy would silently fall back to the
        # observation-level branch of the PyMC builder.
        the_copy.panel_column = self.panel_column
        return the_copy

    def recursive_construct_jax_function(
        self, numerically_safe: bool
    ) -> JaxFunctionType:
        """Generates a function to be used by biogeme_jax.

        The distributed parameter is transparent for JAX: the generated
        function returns the value of the child expression.

        :param numerically_safe: forwarded to the child expression.
        :return: a function taking four arguments: the parameters, one row
            of the database, the draws and the random variables. It returns
            the value of the child expression for those arguments.
        """
        child_jax = self.child.recursive_construct_jax_function(
            numerically_safe=numerically_safe
        )

        def the_jax_function(
            parameters: jnp.ndarray,
            one_row: jnp.ndarray,
            the_draws: jnp.ndarray,
            the_random_variables: jnp.ndarray,
        ) -> jnp.ndarray:
            return child_jax(parameters, one_row, the_draws, the_random_variables)

        return the_jax_function

    def recursive_construct_pymc_model_builder(self) -> PymcModelBuilderType:
        """Generates recursively a function to be used by PyMC.

        :return: a builder taking the dataframe and returning the value of
            the parameter as a TensorVariable registered in the current PyMC
            model under ``self.name``. If a variable with that name already
            exists in the model, that variable is returned unchanged.
        """
        child_pymc = self.child.recursive_construct_pymc_model_builder()

        def builder(dataframe: pd.DataFrame) -> pt.TensorVariable:
            model = pm.modelcontext(None)  # Get current active model context
            # Reuse the variable if one with this name already exists.
            if self.name in model.named_vars:
                return model.named_vars[self.name]

            child_value = child_pymc(dataframe=dataframe)

            # Panel case: map individuals -> observations using the panel column.
            if (
                self.panel_column is not None
                and self.panel_column in dataframe.columns
            ):
                panel_ids = dataframe[self.panel_column].to_numpy()
                # Map arbitrary ids to 0..(n_individuals-1). With sort=True,
                # code k corresponds to the k-th smallest id.
                # Assumption: child_value has shape (n_individuals, ...) with
                # individuals in that same order, consistent with the order
                # used to build the draws. The calling code must ensure this.
                codes, _ = pd.factorize(panel_ids, sort=True)
                idx = pt.as_tensor_variable(codes, dtype="int64")

                indiv_name = f"{self.name}_per_individual"
                if indiv_name not in model.named_vars:
                    pm.Deterministic(
                        indiv_name,
                        child_value,
                        dims=(Dimension.INDIVIDUALS.value,),
                    )
                # Broadcast from individual-level to observation-level.
                child_value_obs = child_value[idx]
                return pm.Deterministic(
                    self.name,
                    child_value_obs,
                    dims=(Dimension.OBS.value,),
                )

            return pm.Deterministic(
                self.name, child_value, dims=(Dimension.OBS.value,)
            )

        return builder