"""Arithmetic expressions accepted by Biogeme: nary operators
Michel Bierlaire
Sat Sep 9 15:29:36 2023
"""
from __future__ import annotations
import logging
from typing import NamedTuple, TYPE_CHECKING
import jax.numpy as jnp
import pandas as pd
import pytensor.tensor as pt
from biogeme.exceptions import BiogemeError
from biogeme.expressions.bayesian import PymcModelBuilderType
from .base_expressions import Expression
from .beta_parameters import Beta
from .jax_utils import JaxFunctionType
from .variable import Variable
if TYPE_CHECKING:
pass
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
class LinearTermTuple(NamedTuple):
    """One term of a linear utility: a parameter paired with a variable.

    The fields are positional, so a term can also be unpacked as
    ``beta, x = term``.
    """

    # Parameter of the term.
    beta: Beta
    # Variable paired with the parameter.
    x: Variable
class LinearUtility(Expression):
    """When the utility function is linear, it is expressed as a list of
    terms, where a parameter multiplies a variable.
    """

    def __init__(self, list_of_terms: list[LinearTermTuple]):
        """Constructor

        :param list_of_terms: a non-empty list of pairs. Each pair contains
            first a Beta parameter, second a Variable.
        :type list_of_terms: list[LinearTermTuple]

        :raises biogeme.exceptions.BiogemeError: if the list is empty, or if
            one of its items is not a (Beta, Variable) pair.
        """
        super().__init__()

        # An empty utility cannot be split into betas and variables below,
        # so reject it up front with an explicit message.
        if not list_of_terms:
            raise BiogemeError(
                'A LinearUtility requires a non-empty list of '
                '(Beta, Variable) terms.'
            )

        for term in list_of_terms:
            # A term that cannot be unpacked into exactly two items is
            # reported as a BiogemeError, like any other invalid term.
            try:
                b, v = term
            except (TypeError, ValueError) as e:
                raise BiogemeError(
                    f'Each term must be a (Beta, Variable) pair. Got: {term!r}'
                ) from e
            if not isinstance(b, Beta) or not isinstance(v, Variable):
                raise BiogemeError(
                    f'Each term must be a (Beta, Variable) pair. Got: ({b}, {v})'
                )

        betas, variables = zip(*list_of_terms)
        self.betas = list(betas)  #: list of parameters
        self.variables = list(variables)  #: list of variables
        #: list of terms, exactly as provided by the caller
        self.list_of_terms = list_of_terms

        self.children += self.betas + self.variables

    def deep_flat_copy(self) -> LinearUtility:
        """Provides a copy of the expression. It is deep in the sense that it generates copies of the children.

        It is flat in the sense that any `MultipleExpression` is transformed into the currently selected expression.
        The flat part is irrelevant for this expression.

        :return: a new object of the same class, built from copies of each
            parameter and each variable.
        """
        copy_list_of_terms = [
            LinearTermTuple(beta=term[0].deep_flat_copy(), x=term[1].deep_flat_copy())
            for term in self.list_of_terms
        ]
        # type(self) rather than LinearUtility, so that a subclass is copied
        # as an instance of that subclass.
        return type(self)(list_of_terms=copy_list_of_terms)

    def __str__(self) -> str:
        """Render the utility as ``beta * x`` terms joined by `` + ``."""
        return ' + '.join([f'{b} * {x}' for b, x in self.list_of_terms])

    def __repr__(self) -> str:
        """Render the class name followed by the repr of the list of terms."""
        return f"LinearUtility({repr(self.list_of_terms)})"

    def recursive_construct_jax_function(
        self, numerically_safe: bool
    ) -> JaxFunctionType:
        """
        Generates a function to be used by biogeme_jax. Must be overloaded by each expression

        :param numerically_safe: forwarded unchanged to the construction of
            the function of each parameter and each variable.
        :return: the function takes four parameters: the parameters, one row
            of the database, the draws and the random variables. It returns
            the dot product of the parameter values and the variable values.
        """
        # The child functions are built once, here, and reused at each call
        # of the returned function.
        beta_fns = [
            b.recursive_construct_jax_function(numerically_safe=numerically_safe)
            for b in self.betas
        ]
        variable_fns = [
            v.recursive_construct_jax_function(numerically_safe=numerically_safe)
            for v in self.variables
        ]

        def the_jax_function(
            parameters: jnp.ndarray,
            one_row: jnp.ndarray,
            the_draws: jnp.ndarray,
            the_random_variables: jnp.ndarray,
        ) -> jnp.ndarray:
            beta_values = jnp.array(
                [
                    fn(parameters, one_row, the_draws, the_random_variables)
                    for fn in beta_fns
                ]
            )
            variable_values = jnp.array(
                [
                    fn(parameters, one_row, the_draws, the_random_variables)
                    for fn in variable_fns
                ]
            )
            return jnp.dot(beta_values, variable_values)

        return the_jax_function

    def recursive_construct_pymc_model_builder(self) -> PymcModelBuilderType:
        """
        PyMC builder for LinearUtility:
        - evaluate Beta (scalar) and Variable (per-observation) children
        - form elementwise products beta_k * x_k
        - stack along a new axis and sum to get the linear utility per observation

        :return: a function that takes a dataframe and returns the tensor
            representing the linear utility.
        """
        # Builders for each Beta and Variable term (preserve pairing order)
        beta_builders = [b.recursive_construct_pymc_model_builder() for b in self.betas]
        var_builders = [
            v.recursive_construct_pymc_model_builder() for v in self.variables
        ]

        def builder(dataframe: pd.DataFrame) -> pt.TensorVariable:
            # Evaluate all terms on the dataframe
            betas = [bb(dataframe=dataframe) for bb in beta_builders]
            vars_ = [vb(dataframe=dataframe) for vb in var_builders]

            if len(betas) != len(vars_):
                raise BiogemeError(
                    f"LinearUtility mismatch: {len(betas)} betas for {len(vars_)} variables."
                )

            # Form beta*x for each pair; broadcasting handles scalar beta with vector x
            try:
                products = [b * x for b, x in zip(betas, vars_)]
                # A single term needs neither stacking nor summing.
                if len(products) == 1:
                    return products[0]
                return pt.sum(pt.stack(products, axis=0), axis=0)
            except (TypeError, ValueError) as e:
                # Collect whatever static shape information is available on
                # each operand, to make the error message actionable.
                shape_pairs = [
                    (
                        getattr(getattr(b, "type", None), "shape", None),
                        getattr(getattr(x, "type", None), "shape", None),
                    )
                    for b, x in zip(betas, vars_)
                ]
                raise BiogemeError(
                    "LinearUtility terms are not shape-compatible. "
                    f"Got (beta_shape, var_shape) pairs: {shape_pairs}. "
                    "Each product beta_k * x_k must broadcast to a common per-observation shape."
                ) from e

        return builder