# Source code for biogeme.expressions.draws

"""Arithmetic expressions accepted by Biogeme: draws

Michel Bierlaire
Fri Jun 27 2025, 14:41:17
"""

from __future__ import annotations

import logging

import pandas as pd
import pymc as pm
from jax import numpy as jnp
from pytensor.tensor import TensorVariable

from biogeme.draws import PyMcDistributionFactory, get_distribution, pymc_distributions
from biogeme.exceptions import BiogemeError
from .bayesian import Dimension, PymcModelBuilderType
from .elementary_expressions import Elementary
from .elementary_types import TypeOfElementaryExpression
from .jax_utils import JaxFunctionType

logger = logging.getLogger(__name__)


class Draws(Elementary):
    """
    Draws for Monte-Carlo integration
    """

    expression_type = TypeOfElementaryExpression.DRAWS

    def __init__(
        self,
        name: str,
        draw_type: str = 'NORMAL',
        dict_of_distributions: dict[str, PyMcDistributionFactory] | None = None,
    ):
        """Constructor

        :param name: name of the random variable with a series of draws.
        :param draw_type: type of draws. It is the key used to look up the
            distribution in ``dict_of_distributions``.
        :param dict_of_distributions: dictionary mapping names of types of
            draws to PyMc distribution factories. If None, the module-level
            ``pymc_distributions`` is used.
        """
        super().__init__(name)
        self.draw_type = draw_type
        self._draw_dimension = Dimension.OBS
        self._is_complex = True
        self.dict_of_distributions: dict[str, PyMcDistributionFactory] = (
            dict_of_distributions
            if dict_of_distributions is not None
            else pymc_distributions
        )

    @property
    def draw_dimension(self) -> Dimension:
        """Dimension passed as ``dims`` to the PyMc distribution."""
        return self._draw_dimension

    def set_draw_per_observation(self) -> None:
        """Select ``Dimension.OBS`` as the dimension of the draws."""
        self._draw_dimension = Dimension.OBS

    def set_draw_per_individual(self) -> None:
        """Select ``Dimension.INDIVIDUALS`` as the dimension of the draws."""
        self._draw_dimension = Dimension.INDIVIDUALS

    def deep_flat_copy(self) -> Draws:
        """Provides a copy of the expression. It is deep in the sense that it
        generates copies of the children. It is flat in the sense that any
        `MultipleExpression` is transformed into the currently selected
        expression. The flat part is irrelevant for this expression.

        The copy keeps the dictionary of distributions and the dimension of
        the draws of the original expression.

        :return: a new object of the same type as ``self``.
        """
        the_copy = type(self)(name=self.name, draw_type=self.draw_type)
        # These attributes are set after construction (rather than passed to
        # the constructor) so that the call above keeps the exact signature
        # that was used before. The dictionary itself is shared, not
        # duplicated, exactly as the constructor does with the object it
        # receives.
        the_copy.dict_of_distributions = self.dict_of_distributions
        the_copy._draw_dimension = self._draw_dimension
        return the_copy

    def __str__(self) -> str:
        return f'Draws("{self.name}", "{self.draw_type}")'

    @property
    def safe_draw_id(self) -> int:
        """Check the presence of the draw ID before its usage

        :return: the value of ``specific_id``.
        :raise BiogemeError: if ``specific_id`` is None.
        """
        if self.specific_id is None:
            raise BiogemeError(f"No id defined for draw {self.name}")
        return self.specific_id

    def recursive_construct_jax_function(
        self, numerically_safe: bool
    ) -> JaxFunctionType:
        """
        Generates a function to be used by biogeme_jax. Must be overloaded by
        each expression

        :param numerically_safe: not used by this expression.
        :return: the function takes four arguments: the parameters, one row
            of the database, the draws and the random variables. It returns
            the entries of the draws at index ``safe_draw_id`` along the last
            axis.
        """

        def the_jax_function(
            parameters: jnp.ndarray,
            one_row: jnp.ndarray,
            the_draws: jnp.ndarray,
            the_random_variables: jnp.ndarray,
        ) -> jnp.ndarray:
            # The ID is read at call time, so that a missing ID is reported
            # when the function is evaluated, not when it is built.
            return jnp.take(the_draws, self.safe_draw_id, axis=-1)

        return the_jax_function

    def recursive_construct_pymc_model_builder(self) -> PymcModelBuilderType:
        """
        Generates recursively a function to be used by PyMc. Must be
        overloaded by each expression

        :return: a function that takes a data frame and returns the
            expression in TensorVariable format, suitable for PyMc
        :raise BiogemeError: if ``get_distribution`` raises a ValueError for
            the requested type of draws.
        """
        try:
            selected_distribution = get_distribution(
                name=self.draw_type, the_dict=self.dict_of_distributions
            )
        except ValueError as e:
            error_msg = f'Problem generating draws for {self.name}. {e}'
            raise BiogemeError(error_msg) from e

        def builder(dataframe: pd.DataFrame) -> TensorVariable:
            # Get current active model context
            model = pm.modelcontext(None)
            # If a variable with this name is already registered in the
            # model, it is reused instead of being created a second time.
            if self.name in model.named_vars:
                return model.named_vars[self.name]
            return selected_distribution(name=self.name, dims=self.draw_dimension)

        return builder