Source code for biogeme.expressions.variable

"""Arithmetic expressions accepted by Biogeme: variables

Michel Bierlaire
Fri Jun 27 2025, 14:43:42
"""

from __future__ import annotations

import logging

import jax.numpy as jnp
import pandas as pd
import pymc as pm
from pytensor.tensor import TensorVariable

from biogeme.bayesian_estimation import check_shape
from biogeme.exceptions import BiogemeError
from .bayesian import Dimension, PymcModelBuilderType
from .elementary_expressions import Elementary
from .elementary_types import TypeOfElementaryExpression
from .jax_utils import JaxFunctionType

logger = logging.getLogger(__name__)


class Variable(Elementary):
    """Explanatory variable

    This represents the explanatory variables of the choice model.
    Typically, they come from the data set.
    """

    expression_type = TypeOfElementaryExpression.VARIABLE

    def __init__(self, name: str):
        """Constructor

        :param name: name of the variable.
        :type name: string
        """
        super().__init__(name)

    def deep_flat_copy(self) -> Variable:
        """Provides a copy of the expression.

        It is deep in the sense that it generates copies of the children.
        It is flat in the sense that any `MultipleExpression` is transformed
        into the currently selected expression. The flat part is irrelevant
        for this expression.

        :return: a new object of the same class as this one, constructed
            from the name only.
        """
        return type(self)(name=self.name)

    @property
    def safe_variable_id(self) -> int:
        """Check the presence of the ID before using it.

        :return: the value of ``specific_id``.
        :raise BiogemeError: if ``specific_id`` is None.
        """
        if self.specific_id is None:
            raise BiogemeError(f"No id defined for variable {self.name}")
        return self.specific_id

    def recursive_construct_jax_function(
        self, numerically_safe: bool
    ) -> JaxFunctionType:
        """Generates a function to be used by biogeme_jax.

        :param numerically_safe: not used by this expression.
        :return: a function taking four arguments (the parameters, one row
            of the database, the draws and the random variables) and
            returning the entry of the row located at the index of this
            variable, taken along the last axis. Only the row is used.
        """

        def the_jax_function(
            parameters: jnp.ndarray,
            one_row: jnp.ndarray,
            the_draws: jnp.ndarray,
            the_random_variables: jnp.ndarray,
        ) -> jnp.ndarray:
            # The id is read when the function is called, not when it is
            # built, so a missing id raises BiogemeError at call time.
            return jnp.take(one_row, self.safe_variable_id, axis=-1)

        return the_jax_function

    def recursive_construct_pymc_model_builder(self) -> PymcModelBuilderType:
        """Generates a function to be used by PyMc.

        :return: a builder taking a dataframe and returning the column named
            after this variable as a PyMc data variable attached to the
            active model. The builder raises a BiogemeError if the column is
            not in the dataframe.
        """

        # NOTE(review): the checks performed by check_shape are defined in
        # biogeme.bayesian_estimation and are not visible from this file.
        @check_shape
        def builder(dataframe: pd.DataFrame) -> TensorVariable:
            model = pm.modelcontext(None)  # active model
            try:
                values = dataframe[self.name].to_numpy()
            except KeyError as e:
                raise BiogemeError(
                    f"Column '{self.name}' not found in the dataframe."
                ) from e

            existing = model.named_vars.get(self.name)
            if existing is not None:
                # A variable with this name is already registered in the
                # model: refresh it with the current values and reuse it
                # instead of declaring it a second time.
                pm.set_data({self.name: values}, model=model)
                return existing

            return pm.Data(
                self.name,
                values,
                dims=Dimension.OBS,
            )

        return builder