Source code for biogeme.calculator.single_formula

"""
Module in charge of the actual calculation of the formula on the database.

Michel Bierlaire
Wed Mar 26 19:30:57 2025
"""

from __future__ import annotations

import logging

import jax
import jax.numpy as jnp
import numpy as np

from biogeme.database import Database
from biogeme.exceptions import BiogemeError
from biogeme.expressions import (
    Expression,
    build_vectorized_function,
    collect_init_values,
)
from biogeme.floating_point import JAX_FLOAT, NUMPY_FLOAT
from biogeme.function_output import FunctionOutput, NamedFunctionOutput
from biogeme.model_elements import ModelElements
from biogeme.second_derivatives import SecondDerivativesMode

logger = logging.getLogger(__name__)


class CompiledFormulaEvaluator:
    """
    Compiles and evaluates a Biogeme expression using JAX for efficient
    repeated computation.

    The expression evaluated is the log likelihood stored in the model
    elements, multiplied by the weight expression when one is defined.
    """

    def __init__(
        self,
        model_elements: ModelElements,
        second_derivatives_mode: SecondDerivativesMode,
        numerically_safe: bool,
    ):
        """
        Prepares and compiles the JAX function for evaluating a Biogeme expression.

        :param model_elements: All elements needed to calculate the expression.
        :param second_derivatives_mode: specifies how second derivatives are calculated.
        :param numerically_safe: improves the numerical stability of the calculations.
        :raise BiogemeError: if the model elements contain no log likelihood
            expression.
        """
        self.model_elements = model_elements
        self.second_derivatives_mode = second_derivatives_mode
        self.numerically_safe = numerically_safe
        self.use_jit = model_elements.use_jit
        self.free_betas_names = (
            self.model_elements.expressions_registry.free_betas_names
        )
        self.data_jax = (
            self.model_elements.database.data_jax
            if self.model_elements.database is not None
            else None
        )
        self.draws_jax = (
            self.model_elements.draws_management.draws_jax
            if self.model_elements.draws_management is not None
            else None
        )
        n_rv = self.model_elements.expressions_registry.number_of_random_variables
        # The random variables are always evaluated at zero by this class.
        self.random_variables_jax = jnp.zeros((n_rv,), dtype=JAX_FLOAT)

        log_likelihood = self.model_elements.loglikelihood
        if log_likelihood is None:
            error_message = (
                'No expression found for log likelihood. '
                f'Available expressions: {self.model_elements.formula_names}'
            )
            raise BiogemeError(error_message)

        the_function = log_likelihood.recursive_construct_jax_function(
            numerically_safe=self.numerically_safe
        )
        vectorized_function = build_vectorized_function(
            the_function, use_jit=self.use_jit
        )

        if self.model_elements.weight is not None:
            weight_function = (
                self.model_elements.weight.recursive_construct_jax_function(
                    numerically_safe=numerically_safe
                )
            )
            vectorized_weight_function = build_vectorized_function(
                weight_function, use_jit=self.use_jit
            )
        else:
            vectorized_weight_function = None

        def sum_function(
            params: list[float],
            data: jnp.ndarray,
            draws: jnp.ndarray,
            random_variables: jnp.ndarray,
        ) -> tuple[jnp.ndarray, jnp.ndarray]:
            """Return the (weighted) sum of the formula and its individual terms."""
            values = vectorized_function(params, data, draws, random_variables)
            if vectorized_weight_function is not None:
                weights = vectorized_weight_function(
                    params, data, draws, random_variables
                )
                values = values * weights
            return jnp.asarray(jnp.sum(values), dtype=JAX_FLOAT), values

        self.sum_function = jax.jit(sum_function) if self.use_jit else sum_function

    def evaluate(
        self,
        the_betas: dict[str, float],
        gradient: bool,
        hessian: bool,
        bhhh: bool,
    ) -> FunctionOutput:
        """
        Evaluates the compiled expression and the requested derivatives.

        When ``gradient`` is False, only the function value is calculated,
        whatever the values of ``hessian`` and ``bhhh``.

        :param the_betas: Dictionary of parameter names to values.
        :param gradient: If True, compute the gradient.
        :param hessian: If True (and ``gradient`` is True), compute the Hessian.
        :param bhhh: If True (and ``gradient`` is True), compute the BHHH matrix.
        :return: the function value and the requested derivatives; the
            quantities that have not been requested are set to None.
        """
        free_betas_values = (
            self.model_elements.expressions_registry.get_complete_betas_array(
                betas_dict=the_betas
            )
        )

        if not gradient:
            return self._evaluate_function_only(free_betas_values)
        if bhhh:
            if hessian:
                return self._evaluate_autodiff_hessian_bhhh(free_betas_values)
            return self._evaluate_bhhh_only(free_betas_values, use_jit=self.use_jit)
        if hessian:
            return self._evaluate_autodiff_hessian(free_betas_values)
        return self._evaluate_function_and_gradient(free_betas_values)

    def _scalar_objective(self, params, data, draws, random_variables):
        """Return only the summed value of the formula: the quantity that is differentiated."""
        return self.sum_function(params, data, draws, random_variables)[0]

    def _evaluate_function_only(self, free_betas_values) -> FunctionOutput:
        """Calculate the value of the formula, without any derivative."""
        value_jax, _ = self.sum_function(
            free_betas_values,
            self.data_jax,
            self.draws_jax,
            self.random_variables_jax,
        )
        return FunctionOutput(
            function=float(value_jax),
            gradient=None,
            hessian=None,
            bhhh=None,
        )

    def _evaluate_function_and_gradient(self, free_betas_values) -> FunctionOutput:
        """Calculate the value of the formula and its gradient by automatic differentiation."""
        value_and_grad_fn = jax.value_and_grad(self._scalar_objective, argnums=0)
        value, the_gradient = value_and_grad_fn(
            free_betas_values,
            self.data_jax,
            self.draws_jax,
            self.random_variables_jax,
        )
        return FunctionOutput(
            function=float(value),
            gradient=np.asarray(the_gradient, dtype=NUMPY_FLOAT),
            hessian=None,
            bhhh=None,
        )

    def _evaluate_autodiff_hessian(self, free_betas_values) -> FunctionOutput:
        """
        Calculate the value of the formula, its gradient and its Hessian.

        The Hessian is obtained by finite differences or by automatic
        differentiation, depending on the second derivatives mode.

        :raise BiogemeError: if the second derivatives mode is ``NEVER``.
        """
        if self.second_derivatives_mode == SecondDerivativesMode.NEVER:
            error_msg = 'The second derivatives are not supposed to be evaluated'
            raise BiogemeError(error_msg)

        value_and_grad_fn = jax.value_and_grad(self._scalar_objective, argnums=0)
        value, the_gradient = value_and_grad_fn(
            free_betas_values,
            self.data_jax,
            self.draws_jax,
            self.random_variables_jax,
        )

        if jnp.all(the_gradient == 0.0):
            # Shortcut: an identically zero gradient is reported with a zero
            # Hessian, and no second derivative is calculated.
            the_hessian = np.zeros(
                (len(free_betas_values), len(free_betas_values)), dtype=NUMPY_FLOAT
            )
        elif self.second_derivatives_mode == SecondDerivativesMode.FINITE_DIFFERENCES:
            the_hessian = self._evaluate_finite_difference_hessian(free_betas_values)
        else:
            hessian_fn = jax.jacfwd(
                jax.grad(self._scalar_objective, argnums=0),
                argnums=0,
            )
            hess_autodiff = hessian_fn(
                free_betas_values,
                self.data_jax,
                self.draws_jax,
                self.random_variables_jax,
            )
            # Numerical errors are only reported: the matrix is returned as is.
            if jnp.any(jnp.isnan(hess_autodiff)):
                logger.warning(
                    'The calculation of second derivatives generated numerical errors.'
                )
            the_hessian = np.asarray(hess_autodiff, dtype=NUMPY_FLOAT)

        return FunctionOutput(
            function=float(value),
            gradient=np.asarray(the_gradient, dtype=NUMPY_FLOAT),
            hessian=the_hessian,
            bhhh=None,
        )

    def _evaluate_bhhh_only(self, free_betas_values, use_jit: bool) -> FunctionOutput:
        """
        Calculate the value of the formula, its gradient and the BHHH matrix.

        The gradient is the sum of the gradients calculated for each
        observation, and the BHHH matrix is the sum of their outer products.

        :param free_betas_values: values of the parameters.
        :param use_jit: if True, the vectorized log likelihood is built with
            just-in-time compilation.
        """
        _, individual_values = self.sum_function(
            free_betas_values,
            self.data_jax,
            self.draws_jax,
            self.random_variables_jax,
        )

        # NOTE(review): the gradients below are derived from the log
        # likelihood only, while the function value above includes the
        # weights (if any) -- confirm that this is intended.
        loglik_fn = self.model_elements.loglikelihood.recursive_construct_jax_function(
            numerically_safe=self.numerically_safe
        )
        vectorized_function = build_vectorized_function(loglik_fn, use_jit=use_jit)

        def one_gradient(p, d, r, rv):
            # A leading axis of length one is added, so that a single
            # observation can be processed by the vectorized function.
            draw_values = vectorized_function(p, d[None, :], r[None, :, :], rv)
            return jnp.mean(draw_values)

        per_obs_grad_fn = jax.vmap(
            jax.grad(one_gradient, argnums=0), in_axes=(None, 0, 0, 0)
        )
        if self.use_jit:
            per_obs_grad_fn = jax.jit(per_obs_grad_fn)

        free_betas_values_jnp = jnp.asarray(free_betas_values, dtype=JAX_FLOAT)
        number_of_observations = self.data_jax.shape[0]
        random_variables_broadcast = jnp.tile(
            self.random_variables_jax[None, :], (number_of_observations, 1)
        )

        individual_gradients = per_obs_grad_fn(
            free_betas_values_jnp,
            self.data_jax,
            self.draws_jax,
            random_variables_broadcast,
        )

        expected_shape = (number_of_observations, len(free_betas_values))
        if individual_gradients.shape != expected_shape:
            individual_gradients = jnp.tile(
                individual_gradients, (number_of_observations, 1)
            )

        the_gradient = jnp.sum(individual_gradients, axis=0)
        # Sum over the observations of the outer products g g^T, as a single
        # contraction. Full precision is requested so that the result is not
        # degraded on accelerators with reduced default matmul precision.
        bhhh_matrix = jnp.einsum(
            'ni,nj->ij',
            individual_gradients,
            individual_gradients,
            precision=jax.lax.Precision.HIGHEST,
        )

        return FunctionOutput(
            function=float(jnp.sum(individual_values)),
            gradient=np.asarray(the_gradient, dtype=NUMPY_FLOAT),
            hessian=None,
            bhhh=np.asarray(bhhh_matrix, dtype=NUMPY_FLOAT),
        )

    def _evaluate_autodiff_hessian_bhhh(self, free_betas_values) -> FunctionOutput:
        """Calculate the value of the formula, its gradient, its Hessian and the BHHH matrix."""
        bhhh_result = self._evaluate_bhhh_only(free_betas_values, use_jit=self.use_jit)
        hessian = self._evaluate_autodiff_hessian(free_betas_values).hessian
        return FunctionOutput(
            function=bhhh_result.function,
            gradient=bhhh_result.gradient,
            hessian=hessian,
            bhhh=bhhh_result.bhhh,
        )

    def _evaluate_finite_difference_hessian(self, free_betas_values) -> np.ndarray:
        """
        Approximate the Hessian by central finite differences of the gradient.

        The gradient is obtained by automatic differentiation, so that only
        one level of finite differences is needed. Differentiating function
        values twice with a step of the order of the square root of the
        machine precision would amplify the rounding errors up to the order of
        magnitude of the function itself.

        :param free_betas_values: values of the parameters.
        :return: the symmetrized approximation of the Hessian.
        """
        gradient_fn = jax.grad(self._scalar_objective, argnums=0)

        def gradient_at(point: np.ndarray) -> np.ndarray:
            return np.asarray(
                gradient_fn(
                    jnp.asarray(point, dtype=JAX_FLOAT),
                    self.data_jax,
                    self.draws_jax,
                    self.random_variables_jax,
                ),
                dtype=NUMPY_FLOAT,
            )

        x0 = np.asarray(free_betas_values, dtype=NUMPY_FLOAT)
        n = x0.size
        # Cube root of the machine precision: the usual step for central
        # differences, in the precision used for the JAX calculations.
        relative_step = float(jnp.finfo(JAX_FLOAT).eps) ** (1.0 / 3.0)

        the_hessian = np.zeros((n, n), dtype=NUMPY_FLOAT)
        for i in range(n):
            step = relative_step * max(1.0, abs(float(x0[i])))
            shift = np.zeros_like(x0)
            shift[i] = step
            the_hessian[i, :] = (gradient_at(x0 + shift) - gradient_at(x0 - shift)) / (
                2.0 * step
            )

        # The exact Hessian is symmetric: averaging with the transpose removes
        # the asymmetry introduced by the approximation.
        return 0.5 * (the_hessian + the_hessian.T)

    def evaluate_individual(
        self,
        the_betas: dict[str, float],
    ) -> np.ndarray:
        """
        Evaluates the compiled expression using provided beta values and returns
        the value of the expression for each observation.

        :param the_betas: Dictionary of parameter names to values.
        :return: A numpy array with one value per observation.
        """
        free_betas_values = (
            self.model_elements.expressions_registry.get_complete_betas_array(
                betas_dict=the_betas
            )
        )
        _, individual_values = self.sum_function(
            free_betas_values, self.data_jax, self.draws_jax, self.random_variables_jax
        )
        return np.asarray(individual_values, dtype=NUMPY_FLOAT)
def calculate_single_formula(
    model_elements: ModelElements,
    the_betas: dict[str, float],
    gradient: bool,
    hessian: bool,
    bhhh: bool,
    second_derivatives_mode: SecondDerivativesMode,
    numerically_safe: bool,
) -> FunctionOutput:
    """
    Evaluates a single Biogeme expression using JAX, optionally computing its
    derivatives.

    A :class:`CompiledFormulaEvaluator` is built for the occasion, and used
    for one evaluation.

    :param model_elements: All elements needed to calculate the expression.
    :param the_betas: Dictionary of parameter names to values.
    :param gradient: If True, compute the gradient.
    :param hessian: If True, compute the Hessian (requires gradient=True).
    :param bhhh: If True, compute the BHHH matrix. The flag is forwarded to
        the evaluator.
    :param second_derivatives_mode: specifies how second derivatives are calculated.
    :param numerically_safe: improves the numerical stability of the calculations.
    :return: A FunctionOutput with the value and the requested derivatives.
    """
    evaluator = CompiledFormulaEvaluator(
        model_elements=model_elements,
        second_derivatives_mode=second_derivatives_mode,
        numerically_safe=numerically_safe,
    )
    requested = {'gradient': gradient, 'hessian': hessian, 'bhhh': bhhh}
    return evaluator.evaluate(the_betas=the_betas, **requested)
def calculate_single_formula_from_expression(
    expression: Expression,
    database: Database,
    number_of_draws: int,
    the_betas: dict[str, float],
    second_derivatives_mode: SecondDerivativesMode,
    numerically_safe: bool,
    use_jit: bool,
) -> float:
    """
    Evaluates an expression on a database and returns its value, without any
    derivative.

    The model elements are built from the expression, without weight.

    :param expression: the expression to be evaluated.
    :param database: database used to build the model elements.
    :param number_of_draws: number of draws used to build the model elements.
    :param the_betas: Dictionary of parameter names to values.
    :param second_derivatives_mode: specifies how second derivatives are calculated.
    :param numerically_safe: improves the numerical stability of the calculations.
    :param use_jit: if True, performs just-in-time compilation.
    :return: the value of the expression.
    """
    elements = ModelElements.from_expression_and_weight(
        log_like=expression,
        weight=None,
        database=database,
        number_of_draws=number_of_draws,
        use_jit=use_jit,
    )
    # Only the value is requested: all the derivatives are switched off.
    return calculate_single_formula(
        model_elements=elements,
        second_derivatives_mode=second_derivatives_mode,
        numerically_safe=numerically_safe,
        the_betas=the_betas,
        gradient=False,
        hessian=False,
        bhhh=False,
    ).function
def evaluate_formula(
    model_elements: ModelElements,
    the_betas: dict[str, float],
    second_derivatives_mode: SecondDerivativesMode,
    numerically_safe: bool,
) -> float:
    """
    Evaluates a single Biogeme expression using JAX, without any derivative.

    :param model_elements: All elements needed to calculate the expression.
    :param the_betas: Dictionary of parameter names to values.
    :param second_derivatives_mode: specifies how second derivatives are calculated.
    :param numerically_safe: improves the numerical stability of the calculations.
    :return: the value of the expression.
    """
    # Gradient, Hessian and BHHH are all switched off: only the value is needed.
    no_derivatives = dict(gradient=False, hessian=False, bhhh=False)
    output = calculate_single_formula(
        model_elements=model_elements,
        the_betas=the_betas,
        second_derivatives_mode=second_derivatives_mode,
        numerically_safe=numerically_safe,
        **no_derivatives,
    )
    return output.function
def evaluate_model_per_row(
    model_elements: ModelElements,
    the_betas: dict[str, float],
    second_derivatives_mode: SecondDerivativesMode,
    numerically_safe: bool,
) -> np.ndarray:
    """
    Evaluates a Biogeme expression for each entry in the database, without
    aggregating the results.

    The expression is compiled using JAX and applied to all observations.
    The values are returned individually: they are not summed.

    :param model_elements: All elements needed to calculate the expression.
    :param the_betas: Dictionary mapping parameter names to their values.
    :param second_derivatives_mode: specifies how second derivatives are calculated.
    :param numerically_safe: improves the numerical stability of the calculations.
    :return: A NumPy array of values, one for each observation in the database.
    """
    evaluator = CompiledFormulaEvaluator(
        model_elements=model_elements,
        second_derivatives_mode=second_derivatives_mode,
        numerically_safe=numerically_safe,
    )
    values_per_row = evaluator.evaluate_individual(the_betas=the_betas)
    return values_per_row
def evaluate_expression(
    expression: Expression,
    numerically_safe: bool,
    use_jit: bool,
    database: Database | None = None,
    betas: dict[str, float] | None = None,
    number_of_draws: int = 1000,
    aggregation: bool = False,
) -> np.ndarray | float:
    """Evaluate an arithmetic expression

    :param expression: the expression to be evaluated
    :param numerically_safe: if True, the numerical stability of the
        evaluation is improved, possibly at the expense of calculation speed.
        Set it to False except if necessary.
    :param use_jit: if True, performs just-in-time compilation.
    :param database: database, needed if the expression involves `Variable`.
        If None, a dummy database is used.
    :param betas: values of the parameters, if the expression involves `Beta`.
        If None, the initial values collected from the expression are used.
    :param number_of_draws: number of draws for Monte Carlo integration, if
        the expression involves it.
    :param aggregation: if True, the sum over all rows is calculated. If
        False, the value for each row is returned.
    :return: the aggregated value if `aggregation` is True, the array of
        values for each row otherwise.
    """
    the_database = database if database is not None else Database.dummy_database()
    elements = ModelElements.from_expression_and_weight(
        log_like=expression,
        weight=None,
        database=the_database,
        number_of_draws=number_of_draws,
        use_jit=use_jit,
    )
    the_betas = betas if betas is not None else collect_init_values(expression=expression)

    # Both evaluators share the same signature: only the aggregation differs.
    evaluator = evaluate_formula if aggregation else evaluate_model_per_row
    return evaluator(
        model_elements=elements,
        the_betas=the_betas,
        second_derivatives_mode=SecondDerivativesMode.NEVER,
        numerically_safe=numerically_safe,
    )
def get_value_and_derivatives(
    expression: Expression,
    numerically_safe: bool,
    use_jit: bool,
    betas: dict[str, float] | None = None,
    database: Database | None = None,
    number_of_draws: int = 1000,
    gradient: bool = True,
    hessian: bool = True,
    bhhh: bool = True,
    named_results: bool = False,
) -> FunctionOutput | NamedFunctionOutput:
    """Evaluate an expression together with its derivatives.

    The second derivatives mode is set to ``SecondDerivativesMode.ANALYTICAL``.

    :param expression: the expression to be evaluated.
    :param numerically_safe: improves the numerical stability of the calculations.
    :param use_jit: if True, performs just-in-time compilation.
    :param betas: values of the parameters. If None, the initial values
        collected from the expression are used.
    :param database: database used for the evaluation. If None, a dummy
        database is used.
    :param number_of_draws: number of draws used to build the model elements.
    :param gradient: If True, compute the gradient.
    :param hessian: If True, compute the Hessian.
    :param bhhh: If True, compute the BHHH matrix.
    :param named_results: if True, the output is wrapped into a
        NamedFunctionOutput, built with the indices of the free parameters.
    :return: the value of the expression and the requested derivatives.
    """
    if database is None:
        database = Database.dummy_database()

    model_elements = ModelElements.from_expression_and_weight(
        log_like=expression,
        weight=None,
        database=database,
        number_of_draws=number_of_draws,
        use_jit=use_jit,
    )
    the_compiled_formula = CompiledFormulaEvaluator(
        model_elements=model_elements,
        second_derivatives_mode=SecondDerivativesMode.ANALYTICAL,
        numerically_safe=numerically_safe,
    )
    if betas is None:
        betas = collect_init_values(expression=expression)

    result: FunctionOutput = the_compiled_formula.evaluate(
        the_betas=betas, gradient=gradient, hessian=hessian, bhhh=bhhh
    )
    if not named_results:
        return result

    # A separate name is used, so that the boolean flag is not rebound to an
    # object of a different type.
    return NamedFunctionOutput(
        function_output=result,
        mapping=model_elements.expressions_registry.free_betas_indices,
    )
def get_value_c(
    expression: Expression,
    numerically_safe: bool,
    use_jit: bool,
    database: Database | None = None,
    betas: dict[str, float] | None = None,
    number_of_draws: int = 1000,
    aggregation: bool = False,
) -> np.ndarray | float:
    """Backward-compatible alias of :func:`evaluate_expression`.

    This function used to be a member of the Expression class. All arguments
    are forwarded unchanged.
    """
    forwarded_arguments = {
        'expression': expression,
        'numerically_safe': numerically_safe,
        'database': database,
        'betas': betas,
        'number_of_draws': number_of_draws,
        'aggregation': aggregation,
        'use_jit': use_jit,
    }
    return evaluate_expression(**forwarded_arguments)