Source code for biogeme.expressions.boxcox

"""Arithmetic expressions accepted by Biogeme: BoxCox

Michel Bierlaire
Mon Nov 03 2025, 17:16:46
"""

from __future__ import annotations

import logging
import math

import jax
import jax.numpy as jnp
import pandas as pd
import pytensor.tensor as pt
from biogeme.expressions import Beta, Expression

from .bayesian import PymcModelBuilderType
from .binary_expressions import BinaryOperator
from .convert import validate_and_convert
from .jax_utils import JaxFunctionType

logger = logging.getLogger(__name__)


[docs] class BoxCox(BinaryOperator): """ Box–Cox transform with McLaurin expansion near :math:`\\ell = 0`. .. math:: B(x, \\ell) = \\frac{x^{\\ell} - 1}{\\ell} with the limit .. math:: \\lim_{\\ell \\to 0} B(x, \\ell) = \\log(x). To avoid numerical issues, we use a McLaurin expansion for small :math:`\\ell`: .. math:: \\log(x) + \\ell \\log(x)^2 + \\frac{1}{6} \\ell^2 \\log(x)^3 + \\frac{1}{24} \\ell^3 \\log(x)^4. and a special case :math:`B(0, \\ell) = 0`. This class reproduces the behaviour of ``boxcox_old`` but implements the piecewise logic with JAX / PyTensor control flow instead of :class:`Elem`, so it is compatible with JAX and PyMC backends. """ def __init__(self, x: Expression, ell: Expression): # Always store the validated/converted children and pass them to the # parent constructor. This avoids keeping both raw and converted # references and prevents duplicating children in the expression tree. x_c = validate_and_convert(x) ell_c = validate_and_convert(ell) super().__init__(left=x_c, right=ell_c) self.x = x_c self.ell = ell_c def __str__(self) -> str: return f'BoxCox({self.left}, {self.right})' def __repr__(self) -> str: return f'BoxCox({repr(self.left)}, {repr(self.right)})'
[docs] def deep_flat_copy(self) -> BoxCox: """Provides a copy of the expression. It is deep in the sense that it generates copies of the children. It is flat in the sense that any `MultipleExpression` is transformed into the currently selected expression. The flat part is irrelevant for this expression. """ left_copy = self.left.deep_flat_copy() right_copy = self.right.deep_flat_copy() return type(self)(x=left_copy, ell=right_copy)
[docs] def get_value(self) -> float: """ Evaluate the Box–Cox transform for scalar values. - If ``x == 0``, returns 0.0. - If ``|ell| < 1e-5``, uses the McLaurin expansion around ``ell = 0``. - Otherwise, uses the standard Box–Cox formula. :return: Scalar value of the Box–Cox transform. """ # Retrieve scalar values; will raise if not possible x_value = float(self.x.get_value()) ell_value = float(self.ell.get_value()) # Convention: B(0, ell) = 0 for any ell if x_value == 0.0: return 0.0 # McLaurin expansion around ell = 0 for numerical stability if abs(ell_value) < 1.0e-5: lx = math.log(x_value) return ( lx + ell_value * lx**2 + (ell_value**2) * lx**3 / 6.0 + (ell_value**3) * lx**4 / 24.0 ) # Regular Box–Cox formula return (x_value**ell_value - 1.0) / ell_value
# ------------------------------------------------------------------ # JAX builder # ------------------------------------------------------------------
[docs] def recursive_construct_jax_function( self, numerically_safe: bool ) -> JaxFunctionType: """ JAX implementation of the Box–Cox transform. Uses: - regular formula when ``|ell| >= 1e-5``, - McLaurin expansion when ``|ell| < 1e-5``, - value 0.0 when ``x == 0``. """ get_x = self.x.recursive_construct_jax_function( numerically_safe=numerically_safe ) get_ell = self.ell.recursive_construct_jax_function( numerically_safe=numerically_safe ) def the_jax(theta, one_row, draws, rvars): x = get_x(theta, one_row, draws, rvars) ell = get_ell(theta, one_row, draws, rvars) # regular branch def regular_branch(_): return (jnp.power(x, ell) - 1.0) / ell # McLaurin expansion branch def mclaurin_branch(_): lx = jnp.log(x) return ( lx + ell * lx**2 + (ell**2) * lx**3 / 6.0 + (ell**3) * lx**4 / 24.0 ) # choose between regular and McLaurin based on |ell| def inner(_): return jax.lax.cond( jnp.abs(ell) < 1.0e-5, mclaurin_branch, regular_branch, operand=None, ) # top-level: x == 0 -> 0.0, else inner val = jax.lax.cond(x == 0.0, lambda _: 0.0, inner, operand=None) return val return the_jax
# ------------------------------------------------------------------ # PyMC / PyTensor builder # ------------------------------------------------------------------
[docs] def recursive_construct_pymc_model_builder(self) -> PymcModelBuilderType: """ PyTensor implementation of the Box–Cox transform, mirroring the original ``boxcox_old`` piecewise logic with ``pt.switch``. """ x_b = self.x.recursive_construct_pymc_model_builder() ell_b = self.ell.recursive_construct_pymc_model_builder() def builder(dataframe: pd.DataFrame) -> pt.TensorVariable: x = x_b(dataframe) ell = ell_b(dataframe) # Warn if ell is a Beta without bounds, as in boxcox_old # (we inspect the expression tree statically, so do this outside # builder if you prefer; left here for simplicity) # NOTE: if ell is not a Beta, this does nothing. # You can drop this if you already warn elsewhere. if isinstance(self.ell, Beta) and ( self.ell.upper_bound is None or self.ell.lower_bound is None ): warning_msg = ( f'It is advised to set the bounds on parameter {self.ell.name}. ' f'A value of -10 and 10 should be appropriate: ' f'Beta("{self.ell.name}", {self.ell.init_value}, -10, 10, ' f'{self.ell.status})' ) logger.warning(warning_msg) lx = pt.log(x) # your Expression log, which maps to pt.log regular = (x**ell - 1.0) / ell mclaurin = lx + ell * lx**2 + ell**2 * lx**3 / 6.0 + ell**3 * lx**4 / 24.0 close_to_zero = pt.lt(ell, 1.0e-5) & pt.gt(ell, -1.0e-5) smooth = pt.switch(close_to_zero, mclaurin, regular) result = pt.switch(pt.eq(x, 0.0), 0.0, smooth) return result return builder