Source code for biogeme.expressions.expm1

"""Arithmetic expressions accepted by Biogeme: expm1

Michel Bierlaire
Mon Nov 03 2025, 16:44:14
"""

from __future__ import annotations

import logging

import numpy as np
import pandas as pd
import pytensor.tensor as pt
from biogeme.floating_point import MAX_EXP_ARG, MIN_EXP_ARG
from jax import numpy as jnp

from .base_expressions import ExpressionOrNumeric
from .bayesian import PymcModelBuilderType
from .jax_utils import JaxFunctionType
from .unary_expressions import UnaryOperator

logger = logging.getLogger(__name__)


[docs] class expm1(UnaryOperator): """ exponential minus one expression, i.e. eˣ - 1, implemented in a numerically stable way. """ def __init__(self, child: ExpressionOrNumeric) -> None: """Constructor :param child: arithmetic expression :type child: biogeme.expressions.Expression """ super().__init__(child)
[docs] def deep_flat_copy(self) -> expm1: """Provides a deep copy of the expression.""" copy_child = self.child.deep_flat_copy() return type(self)(child=copy_child)
def __str__(self) -> str: return f'expm1({self.child})' def __repr__(self) -> str: return f'expm1({repr(self.child)})'
[docs] def get_value(self) -> float: """Evaluates the value of the expression :return: eˣ - 1 :rtype: float """ return np.expm1(self.child.get_value())
[docs] def recursive_construct_jax_function( self, numerically_safe: bool ) -> JaxFunctionType: """ Generates a JAX-compatible function for Biogeme-JAX. :return: callable(parameters, one_row, the_draws, the_random_variables) """ child_jax = self.child.recursive_construct_jax_function( numerically_safe=numerically_safe ) if numerically_safe: def the_jax_function( parameters: jnp.ndarray, one_row: jnp.ndarray, the_draws: jnp.ndarray, the_random_variables: jnp.ndarray, ) -> jnp.ndarray: child_value = child_jax( parameters, one_row, the_draws, the_random_variables ) safe_value = jnp.clip(child_value, min=MIN_EXP_ARG, max=MAX_EXP_ARG) result = jnp.expm1(safe_value) return result return the_jax_function def the_jax_function( parameters: jnp.ndarray, one_row: jnp.ndarray, the_draws: jnp.ndarray, the_random_variables: jnp.ndarray, ) -> jnp.ndarray: child_value = child_jax( parameters, one_row, the_draws, the_random_variables ) result = jnp.expm1(child_value) return result return the_jax_function
[docs] def recursive_construct_pymc_model_builder(self) -> PymcModelBuilderType: """ Generates recursively a function for PyMC representation. :return: the expression in TensorVariable format (PyTensor) """ child_pymc = self.child.recursive_construct_pymc_model_builder() def builder(dataframe: pd.DataFrame) -> pt.TensorVariable: child_value = child_pymc(dataframe=dataframe) return pt.expm1(child_value) return builder