Source code for biogeme.expressions.logit_expressions

"""Arithmetic expressions accepted by Biogeme: logit

:author: Michel Bierlaire
:date: Sat Sep  9 15:28:39 2023
"""

from __future__ import annotations

import logging
from itertools import chain
from typing import TYPE_CHECKING

import jax
import jax.numpy as jnp
import numpy as np
import pandas as pd
import pytensor.tensor as pt
from biogeme.floating_point import JAX_FLOAT, MIN_EXP_ARG
from jax.scipy.special import logsumexp

from .base_expressions import Expression, LogitTuple
from .bayesian import PymcModelBuilderType
from .convert import validate_and_convert
from .jax_utils import JaxFunctionType
from ..deprecated import deprecated
from ..exceptions import BiogemeError

if TYPE_CHECKING:
    from . import ExpressionOrNumeric
# Module-level logger, named after this module.
logger: logging.Logger = logging.getLogger(__name__)


# PyTensor constant holding MIN_EXP_ARG. The PyMC code paths in this module
# substitute it for utilities that are non-finite or that belong to an
# unavailable alternative.
MASKED_UTILITY_PT = pt.as_tensor_variable(float(MIN_EXP_ARG))


def _describe_tensor(t: pt.TensorVariable) -> str:
    """Compact description of a PyTensor variable for error messages."""
    t_type = getattr(t, "type", None)
    shape = getattr(t_type, "shape", None)
    return f"ndim={getattr(t, 'ndim', '?')}, shape={shape}"


def _stack_alternative_tensors(
    dataframe: pd.DataFrame,
    builders: list[PymcModelBuilderType],
    util_keys: list[int],
    singular: str,
    plural: str,
) -> pt.TensorVariable:
    """Evaluate one builder per alternative and stack the results along axis 1.

    :param dataframe: DataFrame used as input to builders.
    :param builders: one builder per alternative, in the order of ``util_keys``.
    :param util_keys: alternative identifiers (used in error messages only).
    :param singular: noun naming one quantity in messages (e.g. ``"utility"``).
    :param plural: plural of ``singular`` (e.g. ``"utilities"``).
    :return: the stacked tensor.
    :raises BiogemeError: if a tensor is not 1-D, or if stacking fails.
    """
    tensors = [build(dataframe) for build in builders]

    offending = [
        (alt, _describe_tensor(t))
        for alt, t in zip(util_keys, tensors)
        if getattr(t, "ndim", None) != 1
    ]
    if offending:
        details = ", ".join(f"alt={alt}: {desc}" for alt, desc in offending)
        raise BiogemeError(
            f"LogLogit PyMC builder: each {singular} must be a 1-D tensor "
            "with one value per observation.\n"
            f"Offending {plural}: {details}"
        )

    try:
        return pt.stack(tensors, axis=1)  # (N, K)
    except (TypeError, ValueError, IndexError) as e:
        shapes = [_describe_tensor(t) for t in tensors]
        raise BiogemeError(
            f"LogLogit PyMC builder: {plural} for some alternatives are not "
            "shape-compatible for stacking along axis=1.\n"
            f"{plural.capitalize()} (alt_id → tensor): "
            + ", ".join(f"{alt}: {desc}" for alt, desc in zip(util_keys, shapes))
            + f"\nExpected all {plural} to be 1-D tensors with the same "
            "length N (number of rows in the dataframe)."
        ) from e


def _prepare_pymc_logit_inputs(
    dataframe: pd.DataFrame,
    choice_builder: PymcModelBuilderType,
    util_builders: list[PymcModelBuilderType],
    av_builders: list[PymcModelBuilderType] | None,
    util_keys: list[int],
) -> tuple[pt.TensorVariable, pt.TensorVariable, pt.TensorVariable | None]:
    """
    Validate and construct PyTensor inputs for the LogLogit PyMC builder.

    This function takes care of:
    - evaluating choice / utilities / availabilities,
    - checking their shapes,
    - stacking utilities / availabilities along axis 1,
    - sanitizing NaN/Inf to safe values.

    :param dataframe: DataFrame used as input to builders.
    :param choice_builder: Builder returning the chosen alternative (per obs).
    :param util_builders: Builders returning utilities (one per alternative).
    :param av_builders: Builders returning availabilities (one per alt), or None.
    :param util_keys: List of alternative identifiers (for error messages).
    :return: (choice_i32, utilities, availabilities_or_None)
    :raises BiogemeError: If shapes are inconsistent or not 1-D where expected.
    """
    # ---- Choice vector ----
    choice = choice_builder(dataframe)
    # getattr (as for utilities/availabilities) so that an object without
    # ``ndim`` is reported as a BiogemeError rather than an AttributeError.
    if getattr(choice, "ndim", None) != 1:
        raise BiogemeError(
            "LogLogit PyMC builder: the choice expression must return a 1-D "
            f"tensor of per-observation choices.\n"
            f"Got: {_describe_tensor(choice)}"
        )
    choice_i32 = pt.cast(choice, "int32")  # (N,)

    # ---- Utilities ----
    utilities = _stack_alternative_tensors(
        dataframe=dataframe,
        builders=util_builders,
        util_keys=util_keys,
        singular="utility",
        plural="utilities",
    )
    # Sanitize upstream NaN/Inf in utilities to a very small finite value
    utilities = pt.where(_isfinite(utilities), utilities, MASKED_UTILITY_PT)

    # ---- Availabilities ----
    if av_builders is None:
        return choice_i32, utilities, None

    availabilities = _stack_alternative_tensors(
        dataframe=dataframe,
        builders=av_builders,
        util_keys=util_keys,
        singular="availability",
        plural="availabilities",
    )
    # Treat non-finite availability as unavailable (0.0)
    availabilities = pt.where(_isfinite(availabilities), availabilities, 0.0)

    return choice_i32, utilities, availabilities


# PyTensor does not have isfinite, so define our own
def _isfinite(x: pt.TensorVariable) -> pt.TensorVariable:
    """Return an elementwise mask that is True where ``x`` is finite.

    An element counts as finite when it is neither NaN nor infinite.

    :param x: tensor to test.
    :return: boolean tensor built from ``pt.isnan`` and ``pt.isinf``.
    """
    is_nan = pt.isnan(x)
    is_inf = pt.isinf(x)
    non_finite = is_nan | is_inf
    return ~non_finite


def index_of(key: float, keys: jnp.ndarray | list[int]) -> jnp.ndarray:
    """Return the index of a key within ``keys``, for biogeme_jax.

    :param key: value to look for.
    :param keys: candidate values, as a JAX/NumPy array or a plain sequence.
    :return: scalar integer array holding the position of the first entry of
        ``keys`` equal to ``key``. If no entry matches, the comparison is all
        False and ``argmax`` yields 0.
    """
    # Coerce to an array so that the comparison is elementwise even when a
    # Python list is supplied (``list == scalar`` is a single bool).
    return jnp.argmax(jnp.asarray(keys) == key)
class LogLogit(Expression):
    """Expression capturing the logit formula.

    It contains one formula for the target alternative, a dict of formula
    for the availabilities and a dict of formulas for the utilities
    """

    def __init__(
        self,
        util: dict[int, ExpressionOrNumeric],
        av: dict[int, ExpressionOrNumeric] | None,
        choice: ExpressionOrNumeric,
    ):
        """Constructor

        :param util: dictionary where the keys are the identifiers of
            the alternatives, and the elements are objects defining the
            utility functions.
        :type util: dict(int:biogeme.expressions.Expression)

        :param av: dictionary where the keys are the identifiers of
            the alternatives, and the elements are object of type
            biogeme.expressions.Expression defining the availability
            conditions. If av is None, all the alternatives are assumed
            to be always available
        :type av: dict(int:biogeme.expressions.Expression)

        :param choice: formula to obtain the alternative for which the
            logit probability must be calculated.
        :type choice: biogeme.expressions.Expression

        :raise BiogemeError: if one of the expressions is invalid, that is
            neither a numeric value nor a
            biogeme.expressions.Expression object.
        """
        Expression.__init__(self)
        #: dict of utility formulas
        self.util: dict[int, Expression] = {
            alt_id: validate_and_convert(util_expression)
            for (alt_id, util_expression) in util.items()
        }
        #: dict of availability formulas
        self.av: dict[int, Expression] | None = None
        if av is not None:
            self.av = {
                alt_id: validate_and_convert(avail_expression)
                for (alt_id, avail_expression) in av.items()
            }
            for e in self.av.values():
                self.children.append(e)
            # Note: these follow the ordering of the ``av`` dict, and are only
            # defined when availabilities are provided.
            self.av_keys = jnp.array(list(self.av.keys()), dtype=JAX_FLOAT)
            self.av_values = tuple(self.av.values())

        #: expression for the chosen alternative
        self.choice: Expression = validate_and_convert(choice)
        self.children.append(self.choice)
        for e in self.util.values():
            self.children.append(e)

        # Convert the dict into list for biogeme_jax
        self.util_keys = jnp.array(list(self.util.keys()), dtype=JAX_FLOAT)
        self.util_values = tuple(self.util.values())

    def deep_flat_copy(self) -> LogLogit:
        """Provides a copy of the expression. It is deep in the sense that it
        generates copies of the children. It is flat in the sense that any
        `MultipleExpression` is transformed into the currently selected
        expression. The flat part is irrelevant for this expression.
        """
        copy_util = {key: util.deep_flat_copy() for key, util in self.util.items()}
        copy_av = (
            {
                key: av.deep_flat_copy() if av is not None else None
                for key, av in self.av.items()
            }
            if self.av is not None
            else None
        )
        copy_choice = self.choice.deep_flat_copy()
        return type(self)(util=copy_util, av=copy_av, choice=copy_choice)

    def logit_choice_avail(self) -> list[LogitTuple]:
        """Collect the (choice, availabilities) pairs of this expression tree.

        :return: the tuples reported by all children, followed by the tuple of
            this expression when availabilities have been provided.
        """
        result: list[LogitTuple] = list(
            chain.from_iterable(e.logit_choice_avail() for e in self.children)
        )
        if self.av is not None:
            this_tuple: LogitTuple = LogitTuple(
                choice=self.choice, availabilities=self.av
            )
            result.append(this_tuple)
        return result

    def get_value(self) -> float:
        """Evaluates the value of the expression

        :return: value of the expression, that is the log of the logit
            probability of the chosen alternative. It is ``-inf`` if the
            chosen alternative is not available.
        :rtype: float

        :raise BiogemeError: if the chosen alternative does not correspond
            to any of the utility functions

        :raise BiogemeError: if the chosen alternative does not correspond
            to any of entry in the availability condition
        """
        choice = int(self.choice.get_value())
        if choice not in self.util:
            error_msg = (
                f'Alternative {choice} does not appear in the list '
                f'of utility functions: {self.util.keys()}'
            )
            raise BiogemeError(error_msg)

        # ``av is None`` means that every alternative is always available, so
        # the availability checks apply only when a dict has been provided.
        if self.av is not None:
            if choice not in self.av:
                error_msg = (
                    f'Alternative {choice} does not appear in the list '
                    f'of availabilities: {self.av.keys()}'
                )
                raise BiogemeError(error_msg)
            if self.av[choice].get_value() == 0.0:
                # The probability is zero, so its logarithm is -inf.
                return -np.inf

        v_chosen = self.util[choice].get_value()
        denom = 0.0
        for alt_id, utility in self.util.items():
            if self.av is None or self.av[alt_id].get_value() != 0.0:
                # Utilities are shifted by the chosen one, so that the result
                # is simply minus the log of the sum.
                denom += np.exp(utility.get_value() - v_chosen)
        return -np.log(denom)

    # Legacy camelCase name, kept for backward compatibility.
    # NOTE(review): presumably ``deprecated`` forwards the call to
    # ``get_value`` and emits a warning -- confirm in ``..deprecated``.
    @deprecated(get_value)
    def getValue(self) -> float:
        pass

    def __str__(self) -> str:
        s = f'{self.get_class_name()}[choice={self.choice}]'
        util_str = ', '.join(f'{int(i)}:{e}' for i, e in self.util.items())
        s += f'U=({util_str})'
        if self.av is None:
            s += '[always available]'
        else:
            av_str = ', '.join(f'{int(i)}:{e}' for i, e in self.av.items())
            s += f'av=({av_str})'
        return s

    def recursive_construct_jax_function(
        self, numerically_safe: bool
    ) -> JaxFunctionType:
        """
        Generates a JAX-compatible function.

        The generated function computes the log of the logit probability of
        the chosen alternative, from the utilities and, if provided, the
        availability conditions.

        :param numerically_safe: forwarded to the children when their own JAX
            functions are constructed.
        :return: A function that takes parameters, a row of the database,
            and random draws.
        :raise BiogemeError: if availabilities are provided, but not for each
            alternative with a utility function.
        """

        def get_value(
            expression: Expression,
            parameters: jnp.ndarray,
            one_row: jnp.ndarray,
            the_draws: jnp.ndarray,
            the_random_variables: jnp.ndarray,
        ) -> jnp.ndarray:
            """Retrieve the JAX function of an object and evaluate it."""
            jax_fn = expression.recursive_construct_jax_function(
                numerically_safe=numerically_safe
            )
            return jax_fn(parameters, one_row, the_draws, the_random_variables)

        def shifted_utilities(
            choice_index: jnp.ndarray,
            parameters: jnp.ndarray,
            one_row: jnp.ndarray,
            the_draws: jnp.ndarray,
            the_random_variables: jnp.ndarray,
        ) -> jnp.ndarray:
            """Utilities of all alternatives, minus the utility of the chosen one."""
            # Compute v_chosen
            branches = tuple(
                lambda _, V=V_expr: jnp.asarray(
                    get_value(V, parameters, one_row, the_draws, the_random_variables),
                    dtype=JAX_FLOAT,
                )
                for V_expr in self.util_values
            )
            v_chosen = jax.lax.switch(choice_index, branches, operand=None)
            return jnp.array(
                [
                    get_value(V, parameters, one_row, the_draws, the_random_variables)
                    - v_chosen
                    for V in self.util_values
                ]
            )

        if self.av is None:

            def the_jax_function(
                parameters: jnp.ndarray,
                one_row: jnp.ndarray,
                the_draws: jnp.ndarray,
                the_random_variables: jnp.ndarray,
            ) -> jnp.ndarray:
                """JAX-compatible function for logit probability calculation,
                all alternatives being available."""
                choice_id = get_value(
                    self.choice, parameters, one_row, the_draws, the_random_variables
                )
                choice_index = index_of(choice_id, self.util_keys)
                all_utils = shifted_utilities(
                    choice_index, parameters, one_row, the_draws, the_random_variables
                )
                # Compute the log-sum-exp safely
                return -logsumexp(all_utils)

            return the_jax_function

        # ``choice_index`` is a position in ``util_keys``, and the availability
        # mask is applied elementwise to the utilities. The availabilities must
        # therefore be listed in the order of the utilities, which is not
        # necessarily the order of the ``av`` dict.
        missing = [alt_id for alt_id in self.util if alt_id not in self.av]
        if missing:
            error_msg = (
                f'Alternative(s) {missing} do not appear in the list '
                f'of availabilities: {self.av.keys()}'
            )
            raise BiogemeError(error_msg)
        av_values = tuple(self.av[alt_id] for alt_id in self.util)

        def the_jax_function(
            parameters: jnp.ndarray,
            one_row: jnp.ndarray,
            the_draws: jnp.ndarray,
            the_random_variables: jnp.ndarray,
        ) -> jnp.ndarray:
            """JAX-compatible function for logit probability calculation
            with availability conditions."""
            choice_id = get_value(
                self.choice, parameters, one_row, the_draws, the_random_variables
            )
            choice_index = index_of(choice_id, self.util_keys)

            # Get availability of chosen alternative
            av_branches = tuple(
                lambda _, av=av_expr: get_value(
                    av, parameters, one_row, the_draws, the_random_variables
                )
                for av_expr in av_values
            )
            chosen_avail = jax.lax.switch(choice_index, av_branches, operand=None)

            def unavailable_branch(_):
                # If the chosen alternative is unavailable
                return -jnp.finfo(JAX_FLOAT).max

            def available_branch(_):
                all_utils = shifted_utilities(
                    choice_index, parameters, one_row, the_draws, the_random_variables
                )
                all_avail = jnp.array(
                    [
                        get_value(
                            A, parameters, one_row, the_draws, the_random_variables
                        )
                        for A in av_values
                    ]
                )
                # Unavailable alternatives do not contribute to the sum.
                masked_utils = jnp.where(all_avail != 0.0, all_utils, -jnp.inf)
                return -logsumexp(masked_utils)

            # Conditionally compute result
            return jax.lax.cond(
                chosen_avail == 0.0,
                unavailable_branch,
                available_branch,
                operand=None,
            )

        return the_jax_function

    @staticmethod
    def _pymc_log_likelihood(
        choice_i32: pt.TensorVariable,
        utilities: pt.TensorVariable,
        availabilities: pt.TensorVariable | None,
        util_keys_vec: pt.TensorVariable,
    ) -> pt.TensorVariable:
        """Per-observation logit log-likelihood from stacked PyTensor inputs.

        :param choice_i32: (N,) codes of the chosen alternatives.
        :param utilities: (N, K) utilities, one column per alternative.
        :param availabilities: (N, K) availabilities, or None if all the
            alternatives are always available.
        :param util_keys_vec: (K,) codes of the alternatives, in column order.
        :return: (N,) log-likelihoods. Observations whose choice matches no
            alternative, or whose chosen alternative is unavailable, receive
            a large negative value (-1.0e30).
        """
        if availabilities is None:
            U_masked = utilities
        else:
            U_masked = pt.where(
                pt.neq(availabilities, 0.0), utilities, MASKED_UTILITY_PT
            )

        # Map choice codes to utility columns
        matches = pt.eq(choice_i32[:, None], util_keys_vec[None, :])  # (N, K)
        idx = pt.argmax(matches, axis=1)  # (N,)
        any_match = pt.any(matches, axis=1)  # (N,)

        n_obs = choice_i32.shape[0]
        row_idx = pt.arange(n_obs)

        # Safe index: use column 0 when there is no match; penalize later
        idx_safe = pt.where(any_match, idx, 0)

        chosen_logits = U_masked[row_idx, idx_safe]  # (N,)
        ll = chosen_logits - pt.logsumexp(U_masked, axis=1)  # (N,)

        neg_large = pt.cast(pt.as_tensor_variable(-1.0e30), utilities.dtype)
        if availabilities is not None:
            chosen_av = availabilities[row_idx, idx_safe]
            ll = pt.where(pt.eq(chosen_av, 0.0), neg_large, ll)
        ll = pt.where(any_match, ll, neg_large)
        return ll

    def recursive_construct_pymc_model_builder(self) -> PymcModelBuilderType:
        """Return a *vectorized* PyTensor builder computing per-observation
        log-likelihoods for the multinomial logit, compatible with PyMC.

        Shapes (vectorized over observations):
          - util_keys_vec: (K,)
          - choice: (N,) arbitrary codes (mapped to column indices)
          - U / AV: (N, K)
          - output ll: (N,)
        """
        util_keys = list(self.util.keys())
        util_builders = [
            self.util[k].recursive_construct_pymc_model_builder() for k in util_keys
        ]
        choice_builder = self.choice.recursive_construct_pymc_model_builder()
        av_builders = (
            [self.av[k].recursive_construct_pymc_model_builder() for k in util_keys]
            if self.av is not None
            else None
        )
        util_keys_vec = pt.constant(np.asarray(util_keys, dtype=np.int32))

        def builder(dataframe: pd.DataFrame) -> pt.TensorVariable:
            # Delegate all validation + stacking to the helper
            choice_i32, utilities, availabilities = _prepare_pymc_logit_inputs(
                dataframe=dataframe,
                choice_builder=choice_builder,
                util_builders=util_builders,
                av_builders=av_builders,
                util_keys=util_keys,
            )
            return self._pymc_log_likelihood(
                choice_i32=choice_i32,
                utilities=utilities,
                availabilities=availabilities,
                util_keys_vec=util_keys_vec,
            )

        return builder

    def recursive_construct_pymc_model_builder_old(self) -> PymcModelBuilderType:
        """Return a *vectorized* PyTensor builder computing per-observation
        log-likelihoods for the multinomial logit, compatible with PyMC.

        Earlier variant of :meth:`recursive_construct_pymc_model_builder`: the
        inputs are stacked directly, without the shape validation.

        Shapes (vectorized over observations):
          - util_keys_vec: (K,)
          - choice: (N,) arbitrary codes (mapped to column indices)
          - U / AV: (N, K)
          - output ll: (N,)
        """
        util_keys = list(self.util.keys())
        util_builders = [
            self.util[k].recursive_construct_pymc_model_builder() for k in util_keys
        ]
        choice_builder = self.choice.recursive_construct_pymc_model_builder()
        av_builders = (
            [self.av[k].recursive_construct_pymc_model_builder() for k in util_keys]
            if self.av is not None
            else None
        )
        util_keys_vec = pt.constant(np.asarray(util_keys, dtype=np.int32))

        def builder(dataframe: pd.DataFrame) -> pt.TensorVariable:
            choice_i32 = pt.cast(choice_builder(dataframe), "int32")  # (N,)
            utilities = pt.stack(
                [ub(dataframe) for ub in util_builders], axis=1
            )  # (N, K)

            # Sanitize upstream NaN/Inf in utilities to a very small finite value
            utilities = pt.where(_isfinite(utilities), utilities, MASKED_UTILITY_PT)

            if av_builders is None:
                availabilities = None
            else:
                availabilities = pt.stack(
                    [ab(dataframe) for ab in av_builders], axis=1
                )  # (N, K)
                # Treat non-finite availability as unavailable (0.0)
                availabilities = pt.where(
                    _isfinite(availabilities), availabilities, 0.0
                )

            return self._pymc_log_likelihood(
                choice_i32=choice_i32,
                utilities=utilities,
                availabilities=availabilities,
                util_keys_vec=util_keys_vec,
            )

        return builder