"""Arithmetic expressions accepted by Biogeme: Derive
Michel Bierlaire
Fri May 02 2025, 13:24:27
"""
from __future__ import annotations
import logging
import jax
import jax.numpy as jnp
from .base_expressions import ExpressionOrNumeric
from .collectors import list_of_all_betas_in_expression, list_of_variables_in_expression
from .jax_utils import JaxFunctionType
from .unary_expressions import UnaryOperator
from .variable import Variable
logger = logging.getLogger(__name__)
[docs]
class Derive(UnaryOperator):
"""
Derivative with respect to a variable
"""
def __init__(self, child: ExpressionOrNumeric, name: str):
"""Constructor
:param child: first arithmetic expression
:type child: biogeme.expressions.Expression
"""
super().__init__(child)
# Name of the elementary expression by which the derivative is taken
self.name = name
# Check if it is a variable or a parameter
list_of_variables = list_of_variables_in_expression(child)
self.variable: Variable | None = next(
(variable for variable in list_of_variables if variable.name == name), None
)
if self.variable is not None:
self.children.append(self.variable)
list_of_betas = list_of_all_betas_in_expression(child)
self.beta: Variable | None = next(
(beta for beta in list_of_betas if beta.name == name), None
)
if self.beta is not None:
self.children.append(self.beta)
if self.beta is None and self.variable is None:
logger.warning(
f'Variable {name} does not appear in expression {child}. Derivative is trivially zero.'
)
[docs]
def deep_flat_copy(self) -> Derive:
"""Provides a copy of the expression. It is deep in the sense that it generates copies of the children.
It is flat in the sense that any `MultipleExpression` is transformed into the currently selected expression.
The flat part is irrelevant for this expression.
"""
copy_child = self.child.deep_flat_copy()
return type(self)(child=copy_child, name=self.name)
def __str__(self) -> str:
return f'Derive({self.child}, "{self.name}")'
def __repr__(self) -> str:
return f'Derive({self.child}, "{self.name}")'
[docs]
def recursive_construct_jax_function_variable(
self, numerically_safe: bool
) -> JaxFunctionType:
"""
Generates a function to be used by biogeme_jax. Must be overloaded by each expression
:return: the function takes two parameters: the parameters, and one row of the database.
"""
child_jax = self.child.recursive_construct_jax_function(
numerically_safe=numerically_safe
)
def the_jax_function(
parameters: jnp.ndarray,
one_row: jnp.ndarray,
the_draws: jnp.ndarray,
the_random_variables: jnp.ndarray,
) -> jnp.ndarray:
# Compute gradient with respect to the data row (i.e., Variable values)
grad_wrt_row = jax.grad(
lambda p, row, d, rv: child_jax(p, row, d, rv), argnums=1
)
# Get derivative w.r.t. Variable 'X' (assuming it’s index i in row)
index = self.variable.safe_variable_id
value = grad_wrt_row(parameters, one_row, the_draws, the_random_variables)[
index
]
return value
return the_jax_function
[docs]
def recursive_construct_jax_function_beta(
self, numerically_safe: bool
) -> JaxFunctionType:
"""
Generates a function to be used by biogeme_jax. Must be overloaded by each expression
:return: the function takes two parameters: the parameters, and one row of the database.
"""
child_jax = self.child.recursive_construct_jax_function(
numerically_safe=numerically_safe
)
def the_jax_function(
parameters: jnp.ndarray,
one_row: jnp.ndarray,
the_draws: jnp.ndarray,
the_random_variables: jnp.ndarray,
) -> jnp.ndarray:
# Compute gradient with respect to beta
grad_wrt_beta = jax.grad(
lambda p, row, d, rv: child_jax(p, row, d, rv), argnums=0
)
# Get derivative w.r.t. Variable 'X' (assuming it’s index i in row)
index = self.beta.safe_beta_id
value = grad_wrt_beta(parameters, one_row, the_draws, the_random_variables)[
index
]
return value
return the_jax_function
[docs]
def recursive_construct_jax_function(
self, numerically_safe: bool
) -> JaxFunctionType:
"""
Generates a function to be used by biogeme_jax. Must be overloaded by each expression
:return: the function takes two parameters: the parameters, and one row of the database.
"""
if self.beta is not None:
return self.recursive_construct_jax_function_beta(
numerically_safe=numerically_safe
)
if self.variable is not None:
return self.recursive_construct_jax_function_variable(
numerically_safe=numerically_safe
)
# Return zero function if neither beta nor variable is matched
return lambda p, row, d, rv: jnp.array(0.0)