Source code for biogeme.expressions.comparison_expressions

"""Arithmetic expressions accepted by Biogeme: comparison operators

Michel Bierlaire
Wed Mar 26 13:17:40 2025
"""

from __future__ import annotations

import logging

from jax import Array, lax

from .base_expressions import ExpressionOrNumeric
from .binary_expressions import BinaryOperator
from .jax_utils import JaxFunctionType

logger = logging.getLogger(__name__)


class ComparisonOperator(BinaryOperator):
    """Common parent class of the comparison expressions.

    It adds no behavior of its own: construction is delegated unchanged to
    :class:`BinaryOperator`.
    """

    def __init__(self, left: ExpressionOrNumeric, right: ExpressionOrNumeric):
        """Build the comparison from its two operands.

        :param left: first arithmetic expression
        :type left: biogeme.expressions.Expression

        :param right: second arithmetic expression
        :type right: biogeme.expressions.Expression
        """
        super().__init__(left, right)
class Equal(ComparisonOperator):
    """Logical equal.

    Both evaluation paths (:meth:`get_value` and the generated JAX function)
    yield ``1.0`` when the two operands compare equal and ``0.0`` otherwise.
    """

    def __init__(self, left: ExpressionOrNumeric, right: ExpressionOrNumeric):
        """Constructor

        :param left: first arithmetic expression
        :type left: biogeme.expressions.Expression

        :param right: second arithmetic expression
        :type right: biogeme.expressions.Expression
        """
        super().__init__(left, right)

    def deep_flat_copy(self) -> Equal:
        """Provides a copy of the expression.

        It is deep in the sense that it generates copies of the children.
        It is flat in the sense that any `MultipleExpression` is transformed
        into the currently selected expression. The flat part is irrelevant
        for this expression.

        :return: a new object of the same class, built from copies of both
            children.
        """
        left_copy = self.left.deep_flat_copy()
        right_copy = self.right.deep_flat_copy()
        return type(self)(left=left_copy, right=right_copy)

    def __str__(self) -> str:
        return f'({self.left} == {self.right})'

    def __repr__(self) -> str:
        return f'Equal({repr(self.left)}, {repr(self.right)})'

    def get_value(self) -> float:
        """Evaluates the value of the expression

        :return: 1.0 if the value of the left child equals the value of the
            right child, 0.0 otherwise.
        :rtype: float
        """
        # Return real floats so the result agrees with the documented return
        # type and with the 1.0 / 0.0 produced by the JAX function.
        return 1.0 if self.left.get_value() == self.right.get_value() else 0.0

    def recursive_construct_jax_function(
        self, numerically_safe: bool
    ) -> JaxFunctionType:
        """Builds the JAX function that evaluates this comparison.

        :param numerically_safe: forwarded unchanged to the children when
            their own JAX functions are constructed.
        :return: a function taking ``(parameters, one_row, the_draws,
            the_random_variables)``, passing them to both children, and
            returning 1.0 where the left value equals the right value and
            0.0 otherwise.
        """
        import jax.numpy as jnp

        left_fn: JaxFunctionType = self.left.recursive_construct_jax_function(
            numerically_safe=numerically_safe
        )
        right_fn: JaxFunctionType = self.right.recursive_construct_jax_function(
            numerically_safe=numerically_safe
        )

        def the_jax_function(
            parameters: jnp.ndarray,
            one_row: jnp.ndarray,
            the_draws: jnp.ndarray,
            the_random_variables: jnp.ndarray,
        ) -> Array:
            left_value = left_fn(
                parameters, one_row, the_draws, the_random_variables
            )
            right_value = right_fn(
                parameters, one_row, the_draws, the_random_variables
            )
            # A select between two constants gives the same 1.0 / 0.0 as a
            # two-branch `lax.cond`, without the dummy operand and without
            # requiring the predicate to be a scalar.
            return jnp.where(left_value == right_value, 1.0, 0.0)

        return the_jax_function
class NotEqual(ComparisonOperator):
    """Logical not equal.

    Both evaluation paths (:meth:`get_value` and the generated JAX function)
    yield ``1.0`` when the two operands differ and ``0.0`` otherwise.
    """

    def __init__(self, left: ExpressionOrNumeric, right: ExpressionOrNumeric):
        """Constructor

        :param left: first arithmetic expression
        :type left: biogeme.expressions.Expression

        :param right: second arithmetic expression
        :type right: biogeme.expressions.Expression
        """
        super().__init__(left, right)

    def deep_flat_copy(self) -> NotEqual:
        """Provides a copy of the expression.

        It is deep in the sense that it generates copies of the children.
        It is flat in the sense that any `MultipleExpression` is transformed
        into the currently selected expression. The flat part is irrelevant
        for this expression.

        :return: a new object of the same class, built from copies of both
            children.
        """
        left_copy = self.left.deep_flat_copy()
        right_copy = self.right.deep_flat_copy()
        return type(self)(left=left_copy, right=right_copy)

    def __str__(self) -> str:
        return f'({self.left} != {self.right})'

    def __repr__(self) -> str:
        return f'NotEqual({repr(self.left)}, {repr(self.right)})'

    def get_value(self) -> float:
        """Evaluates the value of the expression

        :return: 1.0 if the value of the left child differs from the value
            of the right child, 0.0 otherwise.
        :rtype: float
        """
        # Return real floats so the result agrees with the documented return
        # type and with the 1.0 / 0.0 produced by the JAX function.
        return 1.0 if self.left.get_value() != self.right.get_value() else 0.0

    def recursive_construct_jax_function(
        self, numerically_safe: bool
    ) -> JaxFunctionType:
        """Builds the JAX function that evaluates this comparison.

        :param numerically_safe: forwarded unchanged to the children when
            their own JAX functions are constructed.
        :return: a function taking ``(parameters, one_row, the_draws,
            the_random_variables)``, passing them to both children, and
            returning 1.0 where the left value differs from the right value
            and 0.0 otherwise.
        """
        import jax.numpy as jnp

        left_fn: JaxFunctionType = self.left.recursive_construct_jax_function(
            numerically_safe=numerically_safe
        )
        right_fn: JaxFunctionType = self.right.recursive_construct_jax_function(
            numerically_safe=numerically_safe
        )

        def the_jax_function(
            parameters: jnp.ndarray,
            one_row: jnp.ndarray,
            the_draws: jnp.ndarray,
            the_random_variables: jnp.ndarray,
        ) -> Array:
            left_value = left_fn(
                parameters, one_row, the_draws, the_random_variables
            )
            right_value = right_fn(
                parameters, one_row, the_draws, the_random_variables
            )
            # A select between two constants gives the same 1.0 / 0.0 as a
            # two-branch `lax.cond`, without the dummy operand and without
            # requiring the predicate to be a scalar.
            return jnp.where(left_value != right_value, 1.0, 0.0)

        return the_jax_function
class LessOrEqual(ComparisonOperator):
    """Logical less or equal.

    Both evaluation paths (:meth:`get_value` and the generated JAX function)
    yield ``1.0`` when the left operand is less than or equal to the right
    operand and ``0.0`` otherwise.
    """

    def __init__(self, left: ExpressionOrNumeric, right: ExpressionOrNumeric):
        """Constructor

        :param left: first arithmetic expression
        :type left: biogeme.expressions.Expression

        :param right: second arithmetic expression
        :type right: biogeme.expressions.Expression
        """
        super().__init__(left, right)

    def deep_flat_copy(self) -> LessOrEqual:
        """Provides a copy of the expression.

        It is deep in the sense that it generates copies of the children.
        It is flat in the sense that any `MultipleExpression` is transformed
        into the currently selected expression. The flat part is irrelevant
        for this expression.

        :return: a new object of the same class, built from copies of both
            children.
        """
        left_copy = self.left.deep_flat_copy()
        right_copy = self.right.deep_flat_copy()
        return type(self)(left=left_copy, right=right_copy)

    def __str__(self) -> str:
        return f'({self.left} <= {self.right})'

    def __repr__(self) -> str:
        return f'LessOrEqual({repr(self.left)}, {repr(self.right)})'

    def get_value(self) -> float:
        """Evaluates the value of the expression

        :return: 1.0 if the value of the left child is less than or equal to
            the value of the right child, 0.0 otherwise.
        :rtype: float
        """
        # Return real floats so the result agrees with the documented return
        # type and with the 1.0 / 0.0 produced by the JAX function.
        return 1.0 if self.left.get_value() <= self.right.get_value() else 0.0

    def recursive_construct_jax_function(
        self, numerically_safe: bool
    ) -> JaxFunctionType:
        """Builds the JAX function that evaluates this comparison.

        :param numerically_safe: forwarded unchanged to the children when
            their own JAX functions are constructed.
        :return: a function taking ``(parameters, one_row, the_draws,
            the_random_variables)``, passing them to both children, and
            returning 1.0 where the left value is less than or equal to the
            right value and 0.0 otherwise.
        """
        import jax.numpy as jnp

        left_fn: JaxFunctionType = self.left.recursive_construct_jax_function(
            numerically_safe=numerically_safe
        )
        right_fn: JaxFunctionType = self.right.recursive_construct_jax_function(
            numerically_safe=numerically_safe
        )

        def the_jax_function(
            parameters: jnp.ndarray,
            one_row: jnp.ndarray,
            the_draws: jnp.ndarray,
            the_random_variables: jnp.ndarray,
        ) -> Array:
            left_value = left_fn(
                parameters, one_row, the_draws, the_random_variables
            )
            right_value = right_fn(
                parameters, one_row, the_draws, the_random_variables
            )
            # A select between two constants gives the same 1.0 / 0.0 as a
            # two-branch `lax.cond`, without the dummy operand and without
            # requiring the predicate to be a scalar.
            return jnp.where(left_value <= right_value, 1.0, 0.0)

        return the_jax_function
class GreaterOrEqual(ComparisonOperator):
    """Logical greater or equal.

    Both evaluation paths (:meth:`get_value` and the generated JAX function)
    yield ``1.0`` when the left operand is greater than or equal to the right
    operand and ``0.0`` otherwise.
    """

    def __init__(self, left: ExpressionOrNumeric, right: ExpressionOrNumeric):
        """Constructor

        :param left: first arithmetic expression
        :type left: biogeme.expressions.Expression

        :param right: second arithmetic expression
        :type right: biogeme.expressions.Expression
        """
        super().__init__(left, right)

    def deep_flat_copy(self) -> GreaterOrEqual:
        """Provides a copy of the expression.

        It is deep in the sense that it generates copies of the children.
        It is flat in the sense that any `MultipleExpression` is transformed
        into the currently selected expression. The flat part is irrelevant
        for this expression.

        :return: a new object of the same class, built from copies of both
            children.
        """
        left_copy = self.left.deep_flat_copy()
        right_copy = self.right.deep_flat_copy()
        return type(self)(left=left_copy, right=right_copy)

    def __str__(self) -> str:
        return f'({self.left} >= {self.right})'

    def __repr__(self) -> str:
        return f'GreaterOrEqual({repr(self.left)}, {repr(self.right)})'

    def get_value(self) -> float:
        """Evaluates the value of the expression

        :return: 1.0 if the value of the left child is greater than or equal
            to the value of the right child, 0.0 otherwise.
        :rtype: float
        """
        # Return real floats so the result agrees with the documented return
        # type and with the 1.0 / 0.0 produced by the JAX function.
        return 1.0 if self.left.get_value() >= self.right.get_value() else 0.0

    def recursive_construct_jax_function(
        self, numerically_safe: bool
    ) -> JaxFunctionType:
        """Builds the JAX function that evaluates this comparison.

        :param numerically_safe: forwarded unchanged to the children when
            their own JAX functions are constructed.
        :return: a function taking ``(parameters, one_row, the_draws,
            the_random_variables)``, passing them to both children, and
            returning 1.0 where the left value is greater than or equal to
            the right value and 0.0 otherwise.
        """
        import jax.numpy as jnp

        left_fn: JaxFunctionType = self.left.recursive_construct_jax_function(
            numerically_safe=numerically_safe
        )
        right_fn: JaxFunctionType = self.right.recursive_construct_jax_function(
            numerically_safe=numerically_safe
        )

        def the_jax_function(
            parameters: jnp.ndarray,
            one_row: jnp.ndarray,
            the_draws: jnp.ndarray,
            the_random_variables: jnp.ndarray,
        ) -> Array:
            left_value = left_fn(
                parameters, one_row, the_draws, the_random_variables
            )
            right_value = right_fn(
                parameters, one_row, the_draws, the_random_variables
            )
            # A select between two constants gives the same 1.0 / 0.0 as a
            # two-branch `lax.cond`, without the dummy operand and without
            # requiring the predicate to be a scalar.
            return jnp.where(left_value >= right_value, 1.0, 0.0)

        return the_jax_function
class Less(ComparisonOperator):
    """Logical less.

    Both evaluation paths (:meth:`get_value` and the generated JAX function)
    yield ``1.0`` when the left operand is strictly less than the right
    operand and ``0.0`` otherwise.
    """

    def __init__(self, left: ExpressionOrNumeric, right: ExpressionOrNumeric):
        """Constructor

        :param left: first arithmetic expression
        :type left: biogeme.expressions.Expression

        :param right: second arithmetic expression
        :type right: biogeme.expressions.Expression
        """
        super().__init__(left, right)

    def deep_flat_copy(self) -> Less:
        """Provides a copy of the expression.

        It is deep in the sense that it generates copies of the children.
        It is flat in the sense that any `MultipleExpression` is transformed
        into the currently selected expression. The flat part is irrelevant
        for this expression.

        :return: a new object of the same class, built from copies of both
            children.
        """
        left_copy = self.left.deep_flat_copy()
        right_copy = self.right.deep_flat_copy()
        return type(self)(left=left_copy, right=right_copy)

    def __str__(self) -> str:
        return f'({self.left} < {self.right})'

    def __repr__(self) -> str:
        return f'Less({repr(self.left)}, {repr(self.right)})'

    def get_value(self) -> float:
        """Evaluates the value of the expression

        :return: 1.0 if the value of the left child is strictly less than
            the value of the right child, 0.0 otherwise.
        :rtype: float
        """
        # Return real floats so the result agrees with the documented return
        # type and with the 1.0 / 0.0 produced by the JAX function.
        return 1.0 if self.left.get_value() < self.right.get_value() else 0.0

    def recursive_construct_jax_function(
        self, numerically_safe: bool
    ) -> JaxFunctionType:
        """Builds the JAX function that evaluates this comparison.

        :param numerically_safe: forwarded unchanged to the children when
            their own JAX functions are constructed.
        :return: a function taking ``(parameters, one_row, the_draws,
            the_random_variables)``, passing them to both children, and
            returning 1.0 where the left value is strictly less than the
            right value and 0.0 otherwise.
        """
        import jax.numpy as jnp

        left_fn: JaxFunctionType = self.left.recursive_construct_jax_function(
            numerically_safe=numerically_safe
        )
        right_fn: JaxFunctionType = self.right.recursive_construct_jax_function(
            numerically_safe=numerically_safe
        )

        def the_jax_function(
            parameters: jnp.ndarray,
            one_row: jnp.ndarray,
            the_draws: jnp.ndarray,
            the_random_variables: jnp.ndarray,
        ) -> Array:
            left_value = left_fn(
                parameters, one_row, the_draws, the_random_variables
            )
            right_value = right_fn(
                parameters, one_row, the_draws, the_random_variables
            )
            # A select between two constants gives the same 1.0 / 0.0 as a
            # two-branch `lax.cond`, without the dummy operand and without
            # requiring the predicate to be a scalar.
            return jnp.where(left_value < right_value, 1.0, 0.0)

        return the_jax_function
class Greater(ComparisonOperator):
    """Logical greater.

    Both evaluation paths (:meth:`get_value` and the generated JAX function)
    yield ``1.0`` when the left operand is strictly greater than the right
    operand and ``0.0`` otherwise.
    """

    def __init__(self, left: ExpressionOrNumeric, right: ExpressionOrNumeric):
        """Constructor

        :param left: first arithmetic expression
        :type left: biogeme.expressions.Expression

        :param right: second arithmetic expression
        :type right: biogeme.expressions.Expression
        """
        super().__init__(left, right)

    def deep_flat_copy(self) -> Greater:
        """Provides a copy of the expression.

        It is deep in the sense that it generates copies of the children.
        It is flat in the sense that any `MultipleExpression` is transformed
        into the currently selected expression. The flat part is irrelevant
        for this expression.

        :return: a new object of the same class, built from copies of both
            children.
        """
        left_copy = self.left.deep_flat_copy()
        right_copy = self.right.deep_flat_copy()
        return type(self)(left=left_copy, right=right_copy)

    def __str__(self) -> str:
        return f'({self.left} > {self.right})'

    def __repr__(self) -> str:
        return f'Greater({repr(self.left)}, {repr(self.right)})'

    def get_value(self) -> float:
        """Evaluates the value of the expression

        :return: 1.0 if the value of the left child is strictly greater than
            the value of the right child, 0.0 otherwise.
        :rtype: float
        """
        # Return real floats so the result agrees with the documented return
        # type and with the 1.0 / 0.0 produced by the JAX function.
        return 1.0 if self.left.get_value() > self.right.get_value() else 0.0

    def recursive_construct_jax_function(
        self, numerically_safe: bool
    ) -> JaxFunctionType:
        """Builds the JAX function that evaluates this comparison.

        :param numerically_safe: forwarded unchanged to the children when
            their own JAX functions are constructed.
        :return: a function taking ``(parameters, one_row, the_draws,
            the_random_variables)``, passing them to both children, and
            returning 1.0 where the left value is strictly greater than the
            right value and 0.0 otherwise.
        """
        import jax.numpy as jnp

        left_fn: JaxFunctionType = self.left.recursive_construct_jax_function(
            numerically_safe=numerically_safe
        )
        right_fn: JaxFunctionType = self.right.recursive_construct_jax_function(
            numerically_safe=numerically_safe
        )

        def the_jax_function(
            parameters: jnp.ndarray,
            one_row: jnp.ndarray,
            the_draws: jnp.ndarray,
            the_random_variables: jnp.ndarray,
        ) -> Array:
            left_value = left_fn(
                parameters, one_row, the_draws, the_random_variables
            )
            right_value = right_fn(
                parameters, one_row, the_draws, the_random_variables
            )
            # A select between two constants gives the same 1.0 / 0.0 as a
            # two-branch `lax.cond`, without the dummy operand and without
            # requiring the predicate to be a scalar.
            return jnp.where(left_value > right_value, 1.0, 0.0)

        return the_jax_function