Source code for biogeme.expressions.belongs_to

"""Arithmetic expressions accepted by Biogeme: belongs to

Michel Bierlaire
Sat May 03 2025, 11:56:33
"""

from __future__ import annotations

import logging

import pandas as pd
import pytensor.tensor as pt
from jax import numpy as jnp

from .base_expressions import ExpressionOrNumeric
from .bayesian import PymcModelBuilderType
from .jax_utils import JaxFunctionType
from .unary_expressions import UnaryOperator

logger = logging.getLogger(__name__)


[docs] class BelongsTo(UnaryOperator): """ Check if a value belongs to a set """ def __init__(self, child: ExpressionOrNumeric, the_set: set[float]): """Constructor :param child: arithmetic expression :type child: biogeme.expressions.Expression :param the_set: set of values :type the_set: set(float) """ super().__init__(child) self.the_set: set[float] = the_set
[docs] def deep_flat_copy(self) -> BelongsTo: """Provides a copy of the expression. It is deep in the sense that it generates copies of the children. It is flat in the sense that any `MultipleExpression` is transformed into the currently selected expression. The flat part is irrelevant for this expression. """ child_copy = self.child.deep_flat_copy() return type(self)(child=child_copy, the_set=self.the_set)
def __str__(self) -> str: return f'BelongsTo({self.child}, "{self.the_set}")' def __repr__(self) -> str: return f'BelongsTo({self.child}, "{self.the_set}")'
[docs] def recursive_construct_jax_function( self, numerically_safe: bool ) -> JaxFunctionType: """ Generates a function to be used by biogeme_jax. Must be overloaded by each expression :return: the function takes two parameters: the parameters, and one row of the database. """ child_jax = self.child.recursive_construct_jax_function( numerically_safe=numerically_safe ) def the_jax_function( parameters: jnp.ndarray, one_row: jnp.ndarray, the_draws: jnp.ndarray, the_random_variables: jnp.ndarray, ) -> jnp.ndarray: child_value = child_jax( parameters, one_row, the_draws, the_random_variables ) return jnp.where( jnp.isin(child_value, jnp.array(list(self.the_set))), 1.0, 0.0 ) return the_jax_function
[docs] def recursive_construct_pymc_model_builder(self) -> PymcModelBuilderType: """ Generates recursively a function to be used by PyMc. Must be overloaded by each expression :return: the expression in TensorVariable format, suitable for PyMc """ child_builder = self.child.recursive_construct_pymc_model_builder() def builder(dataframe: pd.DataFrame) -> pt.TensorVariable: child_val: pt.TensorVariable = child_builder(dataframe) # Make a constant tensor of the set values, matching dtype for safety set_vals = pt.constant(set_values_np, dtype=child_val.dtype) # Membership test: membership = pt.any(pt.eq(child_val[..., None], set_vals), axis=-1) # Return 1.0 where in set, else 0.0, with correct dtype/shape return pt.where( membership, pt.ones_like(child_val), pt.zeros_like(child_val) ) return builder