Source code for biogeme.expressions.distributed_parameter
"""Arithmetic expressions accepted by Biogeme:exp
Michel Bierlaire
10.04.2025 11:48
"""
from __future__ import annotations
import logging
import pandas as pd
import pymc as pm
import pytensor.tensor as pt
from jax import numpy as jnp
from . import (
ExpressionOrNumeric,
)
from .bayesian import Dimension, PymcModelBuilderType
from .jax_utils import JaxFunctionType
from .unary_expressions import UnaryOperator
logger = logging.getLogger(__name__)
[docs]
class DistributedParameter(UnaryOperator):
"""
Distributed parameter in Bayesian estimation
"""
def __init__(
self,
name: str,
child: ExpressionOrNumeric,
) -> None:
"""Constructor
:param child: expression of the parameter. Most of the time, mu + sigma * xi.
"""
super().__init__(child)
self.name = name
self.panel_column = None
[docs]
def deep_flat_copy(self) -> DistributedParameter:
"""Provides a copy of the expression. It is deep in the sense that it generates copies of the children.
It is flat in the sense that any `MultipleExpression` is transformed into the currently selected expression.
The flat part is irrelevant for this expression.
"""
copy_child = self.child.deep_flat_copy()
return type(self)(name=self.name, child=copy_child)
def __str__(self) -> str:
return f'DistributedParameter({self.name}, {self.child})'
def __repr__(self) -> str:
return f'DistributedParameter({self.name}, {self.child}, {repr(self.child)})'
[docs]
def recursive_construct_jax_function(
self, numerically_safe: bool
) -> JaxFunctionType:
"""
Generates a function to be used by biogeme_jax. Must be overloaded by each expression
:return: the function takes two parameters: the parameters, and one row of the database.
"""
child_jax = self.child.recursive_construct_jax_function(
numerically_safe=numerically_safe
)
def the_jax_function(
parameters: jnp.ndarray,
one_row: jnp.ndarray,
the_draws: jnp.ndarray,
the_random_variables: jnp.ndarray,
) -> jnp.ndarray:
child_value = child_jax(
parameters, one_row, the_draws, the_random_variables
)
return child_value
return the_jax_function
[docs]
def recursive_construct_pymc_model_builder(self) -> PymcModelBuilderType:
"""
Generates recursively a function to be used by PyMc. Must be overloaded by each expression
:return: the expression in TensorVariable format, suitable for PyMc
"""
child_pymc = self.child.recursive_construct_pymc_model_builder()
def builder(dataframe: pd.DataFrame) -> pt.TensorVariable:
model = pm.modelcontext(None) # Get current active model context
if self.name in model.named_vars:
return model.named_vars[self.name]
child_value = child_pymc(dataframe=dataframe)
# Panel case: map individuals -> observations using the panel column.
if self.panel_column is not None and self.panel_column in dataframe.columns:
# Panel ids for each observation; we map them to integer indices.
# Assumption: the order of individuals used to build the draws is
# consistent with the codes produced here (e.g., via factorize).
panel_ids = dataframe[self.panel_column].to_numpy()
# Map arbitrary ids to 0..(n_individuals-1).
codes, uniques = pd.factorize(panel_ids, sort=True)
# We rely on child_value having shape (n_individuals, ...) with the same
# ordering as `uniques`. If not, the calling code must ensure consistency.
idx = pt.as_tensor_variable(codes, dtype="int64")
indiv_name = f"{self.name}_per_individual"
if indiv_name not in model.named_vars:
pm.Deterministic(
indiv_name,
child_value,
dims=(Dimension.INDIVIDUALS.value,),
)
# Broadcast from individual-level to observation-level.
child_value_obs = child_value[idx]
return pm.Deterministic(
self.name,
child_value_obs,
dims=(Dimension.OBS.value,),
)
return pm.Deterministic(self.name, child_value, dims=(Dimension.OBS.value,))
return builder