Source code for biogeme.expressions.collectors

"""Functions that recursively collect information about expressions

Michel Bierlaire
Thu May 01 2025, 18:50:24
"""

from __future__ import annotations

from typing import Any, TYPE_CHECKING

from .beta_parameters import Beta
from .draws import Draws
from .random_variable import RandomVariable
from .variable import Variable

if TYPE_CHECKING:
    from .base_expressions import Expression


class ExpressionCollector:
    """Walk an expression tree and collect the values returned by handlers.

    Handlers are registered per expression type with :meth:`register`.
    During :meth:`walk`, each visited node whose exact type has a registered
    handler contributes the list returned by that handler; all lists are
    concatenated, a node's own results coming before those of its children.
    """

    def __init__(self):
        # Maps an expression type to the handler registered for it.
        self._registry = {}

    def register(self, expr_type: type[Expression]):
        """Register a handler function for a specific expression type.

        The handler is called as ``handler(expr, context)`` and must return
        a list of results. Registering a second handler for the same type
        replaces the first one.

        :param expr_type: the type of Expression for which the handler
            should be used
        :return: decorator that registers the handler and returns it unchanged
        """

        def decorator(func):
            self._registry[expr_type] = func
            return func

        return decorator

    def walk(self, expr: Expression, context: Any = None) -> list[Any]:
        """Traverse the expression tree and apply handlers to matching types.

        :param expr: the root expression to walk
        :param context: optional context object passed to handlers. If
            ``None``, a new dictionary is created. The key ``'ancestors'``
            is used by the traversal to store the path from the root to the
            node currently visited.
        :return: list of collected results from the handlers
        :raise TypeError: if a handler returns something that is not a list
        """
        return self._visit(expr, context)

    def _visit(
        self, expr: Expression, context: dict[str, Any] | None
    ) -> list[Any]:
        """Recursively visit expressions and collect handler results.

        Each handler function must return a list. If multiple handlers are
        triggered during the traversal, their outputs are flattened into a
        single list.

        :param expr: current expression node
        :param context: context passed down to handler functions
        :return: list of results from handler invocations
        :raise TypeError: if a handler returns something that is not a list
        """
        if context is None:
            context = {}
        if 'ancestors' not in context:
            context['ancestors'] = []

        context['ancestors'].append(expr)
        try:
            results = []

            # Handlers are looked up by exact type: a handler registered for
            # a base class is not applied to instances of its subclasses.
            handler = self._registry.get(type(expr))
            if handler:
                result = handler(expr, context)
                if not isinstance(result, list):
                    raise TypeError(
                        f"Handler for {type(expr).__name__} must return a list, "
                        f"got {type(result).__name__}"
                    )
                results.extend(result)

            for child in expr.get_children():
                results.extend(self._visit(child, context))

            return results
        finally:
            # Always restore the ancestor stack, even if a handler raised, so
            # that a caller-supplied context is not left with stale entries.
            context['ancestors'].pop()
def collect_init_values(expression: Expression) -> dict[str, float]:
    """Map the name of each free Beta in an expression to its initial value.

    :param expression: root of the expression tree to inspect
    :return: dictionary associating ``name`` with ``init_value`` for every
        Beta node whose ``is_free`` attribute is truthy. If the same name is
        met several times, the value of the last occurrence visited is kept.
    """
    beta_walker = ExpressionCollector()

    def name_and_value(expr: Beta, context) -> list[tuple[str, float]]:
        # Betas that are not free contribute nothing.
        return [(expr.name, expr.init_value)] if expr.is_free else []

    beta_walker.register(Beta)(name_and_value)

    pairs = beta_walker.walk(expression)
    return dict(pairs)
def list_of_variables_in_expression(
    the_expression: Expression,
) -> list[Variable]:
    """Collect the Variable nodes found in an expression tree.

    :param the_expression: root of the expression tree to inspect
    :return: the nodes whose type is exactly ``Variable``, in the order they
        are visited; a node met several times appears several times
    """
    collector = ExpressionCollector()

    def keep_variable(
        expr: Variable, context: Any | None = None
    ) -> list[Variable]:
        return [expr]

    collector.register(Variable)(keep_variable)
    return collector.walk(the_expression)
def list_of_all_betas_in_expression(
    the_expression: Expression,
) -> list[Beta]:
    """Collect the Beta nodes, free or not, found in an expression tree.

    :param the_expression: root of the expression tree to inspect
    :return: the nodes whose type is exactly ``Beta``, in the order they are
        visited; a node met several times appears several times
    """
    collector = ExpressionCollector()

    def keep_beta(expr: Beta, context: Any | None = None) -> list[Beta]:
        return [expr]

    collector.register(Beta)(keep_beta)
    return collector.walk(the_expression)
def list_of_free_betas_in_expression(
    the_expression: Expression,
) -> list[Beta]:
    """Collect the free Beta nodes found in an expression tree.

    :param the_expression: root of the expression tree to inspect
    :return: the ``Beta`` nodes whose ``is_free`` attribute is truthy, in the
        order they are visited; a node met several times appears several
        times
    """
    collector = ExpressionCollector()

    def keep_if_free(expr: Beta, context: Any | None = None) -> list[Beta]:
        if expr.is_free:
            return [expr]
        return []

    collector.register(Beta)(keep_if_free)
    return collector.walk(the_expression)
def list_of_fixed_betas_in_expression(
    the_expression: Expression,
) -> list[Beta]:
    """Collect the Beta nodes that are not free in an expression tree.

    :param the_expression: root of the expression tree to inspect
    :return: the ``Beta`` nodes whose ``is_free`` attribute is falsy, in the
        order they are visited; a node met several times appears several
        times
    """
    collector = ExpressionCollector()

    def keep_if_fixed(expr: Beta, context: Any | None = None) -> list[Beta]:
        if expr.is_free:
            return []
        return [expr]

    collector.register(Beta)(keep_if_fixed)
    return collector.walk(the_expression)
def list_of_random_variables_in_expression(
    the_expression: Expression,
) -> list[RandomVariable]:
    """Collect the RandomVariable nodes found in an expression tree.

    :param the_expression: root of the expression tree to inspect
    :return: the nodes whose type is exactly ``RandomVariable``, in the order
        they are visited; a node met several times appears several times
    """
    collector = ExpressionCollector()

    def keep_random_variable(
        expr: RandomVariable, context: Any | None = None
    ) -> list[RandomVariable]:
        return [expr]

    collector.register(RandomVariable)(keep_random_variable)
    return collector.walk(the_expression)
def list_of_draws_in_expression(
    the_expression: Expression,
) -> list[Draws]:
    """Collect the Draws nodes found in an expression tree.

    :param the_expression: root of the expression tree to inspect
    :return: the nodes whose type is exactly ``Draws``, in the order they are
        visited; a node met several times appears several times
    """
    collector = ExpressionCollector()

    def keep_draws(expr: Draws, context: Any | None = None) -> list[Draws]:
        return [expr]

    collector.register(Draws)(keep_draws)
    return collector.walk(the_expression)