# Source code for biogeme.expressions.conditional_sum

"""Arithmetic expressions accepted by Biogeme: ConditionalSum

Michel Bierlaire
Sat Sep  9 15:29:36 2023
"""

from __future__ import annotations

import logging
from typing import Iterable, NamedTuple

import jax
import pandas as pd
import pytensor.tensor as pt

from biogeme.exceptions import BiogemeError
from .base_expressions import Expression, ExpressionOrNumeric
from .bayesian import PymcModelBuilderType
from .convert import validate_and_convert
from .jax_utils import JaxFunctionType

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)


class ConditionalTermTuple(NamedTuple):
    """Pair associating a condition with the term that it controls."""

    # Expression (or numeric value) deciding if the term enters the sum.
    condition: ExpressionOrNumeric
    # Expression (or numeric value) contributed when the condition is non-zero.
    term: ExpressionOrNumeric
class ConditionalSum(Expression):
    """This expression returns the sum of a selected list of expressions.

    Each entry pairs a condition with a term. A term is included in the sum
    only if its condition evaluates to a non-zero value.
    """

    def __init__(self, list_of_terms: Iterable[ConditionalTermTuple]):
        """Constructor

        :param list_of_terms: iterable of ``ConditionalTermTuple`` objects,
            each containing a term and the associated condition. Any
            iterable is accepted, including generators.

        :raise BiogemeError: if the list of terms is empty.
        :raise BiogemeError: if one of the expressions is invalid, that is
            neither a numeric value nor a biogeme.expressions.Expression
            object (as reported by ``validate_and_convert``).
        """
        # Materialize before testing for emptiness: a generator is always
        # truthy, so checking the raw iterable would let an empty one through.
        # Falsy arguments (empty list, None, ...) still end up as an empty
        # list and trigger the same BiogemeError as before.
        terms = list(list_of_terms) if list_of_terms else []
        if not terms:
            raise BiogemeError('The argument of ConditionalSum cannot be empty')
        super().__init__()

        # Store converted copies of the tuples, so that every condition and
        # every term has gone through validate_and_convert.
        self.list_of_terms = [
            the_term._replace(
                condition=validate_and_convert(the_term.condition),
                term=validate_and_convert(the_term.term),
            )
            for the_term in terms
        ]

        # Register both members of each pair as children of this expression.
        for the_term in self.list_of_terms:
            self.children.append(the_term.condition)
            self.children.append(the_term.term)

    def deep_flat_copy(self) -> ConditionalSum:
        """Provides a copy of the expression. It is deep in the sense that it
        generates copies of the children. It is flat in the sense that any
        `MultipleExpression` is transformed into the currently selected
        expression. The flat part is irrelevant for this expression.

        :return: a new object of the same class, built from copies of each
            condition and each term.
        """
        copy_list_of_terms: list[ConditionalTermTuple] = [
            ConditionalTermTuple(
                condition=the_term.condition.deep_flat_copy(),
                term=the_term.term.deep_flat_copy(),
            )
            for the_term in self.list_of_terms
        ]
        return type(self)(list_of_terms=copy_list_of_terms)

    def get_value(self) -> float:
        """Evaluates the value of the expression

        :return: sum of the values of the terms whose condition evaluates
            to a non-zero value.
        :rtype: float
        """
        result = 0.0
        for the_term in self.list_of_terms:
            condition = the_term.condition.get_value()
            # The term is evaluated only when its condition is non-zero.
            if condition != 0:
                result += the_term.term.get_value()
        return result

    def recursive_construct_jax_function(
        self, numerically_safe: bool
    ) -> JaxFunctionType:
        """Generates recursively the function to be used by JAX.

        :param numerically_safe: flag forwarded unchanged to the construction
            of the JAX function of each condition and each term.
        :return: a function of ``(parameters, one_row, the_draws,
            the_random_variables)`` returning the sum of the terms whose
            condition is non-zero.
        """
        # Build the JAX function of each condition and each term once, outside
        # of the returned function.
        compiled_terms = [
            (
                cond.recursive_construct_jax_function(
                    numerically_safe=numerically_safe
                ),
                term.recursive_construct_jax_function(
                    numerically_safe=numerically_safe
                ),
            )
            for cond, term in self.list_of_terms
        ]

        def the_jax_function(parameters, one_row, the_draws, the_random_variables):
            result = 0.0
            for cond_fn, term_fn in compiled_terms:
                cond_val = cond_fn(parameters, one_row, the_draws, the_random_variables)

                # The two branches are handed to jax.lax.cond within the same
                # iteration, so they refer to the current term_fn.
                def include_branch(_):
                    return term_fn(parameters, one_row, the_draws, the_random_variables)

                def skip_branch(_):
                    return 0.0

                result += jax.lax.cond(
                    cond_val != 0.0, include_branch, skip_branch, operand=None
                )
            return result

        return the_jax_function

    def recursive_construct_pymc_model_builder(self) -> PymcModelBuilderType:
        """Generates recursively a function to be used by PyMc.

        :return: a function that takes the dataframe and returns the
            expression in TensorVariable format, suitable for PyMc. Terms
            whose condition is zero are replaced by zeros before summing.
        """
        pymc_terms = [
            (
                cond.recursive_construct_pymc_model_builder(),
                term.recursive_construct_pymc_model_builder(),
            )
            for cond, term in self.list_of_terms
        ]

        def builder(dataframe: pd.DataFrame) -> pt.TensorVariable:
            # Build all (condition, term) once
            built = [(c_fn(dataframe), t_fn(dataframe)) for c_fn, t_fn in pymc_terms]

            # Stack conditions as booleans and terms with a single zeros_like
            conds = pt.stack([pt.neq(c, 0) for c, _ in built], axis=0)
            terms = pt.stack([t for _, t in built], axis=0)

            # Keep a term where its condition is non-zero, zero elsewhere, and
            # sum along the stacking axis.
            masked = pt.where(conds, terms, pt.zeros_like(terms))
            return masked.sum(axis=0)

        return builder

    def __str__(self) -> str:
        """Readable form listing each ``condition: term`` pair."""
        pairs = ', '.join(f'{k}: {v}' for k, v in self.list_of_terms)
        return 'ConditionalSum(' + pairs + ')'

    def __repr__(self) -> str:
        """Representation based on the ``repr`` of the stored list of terms."""
        return f"ConditionalSum({repr(self.list_of_terms)})"