"""Function collecting recursively information about expressionsMichel BierlaireThu May 01 2025, 18:50:24"""from__future__importannotationsfromtypingimportAny,TYPE_CHECKINGfrom.beta_parametersimportBetafrom.drawsimportDrawsfrom.random_variableimportRandomVariablefrom.variableimportVariableifTYPE_CHECKING:from.base_expressionsimportExpression
class ExpressionCollector:
    """Walks the tree and collects handler return values.

    Handlers are registered per expression type with :meth:`register`.
    During :meth:`walk`, every node whose *exact* type has a registered
    handler is passed to that handler. The lists returned by the handlers
    are concatenated in depth-first pre-order: a node's own results come
    before the results collected from its children.
    """

    def __init__(self) -> None:
        # Maps an expression type to its handler. Lookup uses type(expr)
        # directly, so a handler registered for a base class is NOT used
        # for instances of its subclasses.
        self._registry: dict[type, Any] = {}

    def register(self, expr_type: type[Expression]):
        """Register a handler function for a specific expression type.

        The handler is called as ``handler(expr, context)`` and must return
        a list of results. Registering another handler for the same type
        replaces the previous one.

        :param expr_type: the type of Expression for which the handler
            should be used
        :return: decorator that registers the handler and returns it
            unchanged
        """

        def decorator(func):
            self._registry[expr_type] = func
            return func

        return decorator

    def walk(
        self, expr: Expression, context: dict[str, Any] | None = None
    ) -> list[Any]:
        """Traverse the expression tree and apply handlers to matching types.

        :param expr: the root expression to walk
        :param context: optional dict passed to handlers. If it has no
            ``'ancestors'`` key, an empty list is stored there. During the
            walk, that list holds the path from the root to the node being
            visited (current node last). Each node pushed is popped again,
            even when a handler raises.
        :return: list of collected results from the handlers
        :raises TypeError: if a handler returns something other than a list
        """
        return self._visit(expr, context)

    def _visit(self, expr: Expression, context: dict[str, Any] | None) -> list[Any]:
        """Recursively visit expressions and collect handler results.

        Each handler function must return a list. If multiple handlers are
        triggered during the traversal, their outputs are flattened into a
        single list, in depth-first pre-order.

        NOTE: the traversal is recursive, so extremely deep trees are bounded
        by Python's recursion limit.

        :param expr: current expression node
        :param context: context passed down to handler functions
        :return: list of results from handler invocations
        :raises TypeError: if a handler returns something other than a list
        """
        if context is None:
            context = {}
        if 'ancestors' not in context:
            context['ancestors'] = []
        context['ancestors'].append(expr)
        try:
            results: list[Any] = []
            handler = self._registry.get(type(expr))
            if handler is not None:
                result = handler(expr, context)
                if not isinstance(result, list):
                    raise TypeError(
                        f"Handler for {type(expr).__name__} must return a list, got {type(result).__name__}"
                    )
                results.extend(result)
            for child in expr.get_children():
                results.extend(self._visit(child, context))
        finally:
            # Always undo our push so a caller-supplied context stays
            # consistent even if a handler or a child visit raised.
            context['ancestors'].pop()
        return results
def _collect_nodes_of_type(
    the_expression: Expression,
    node_type: type[Expression],
) -> list[Any]:
    """Collect every node of exactly ``node_type`` in an expression tree.

    Shared implementation of the ``list_of_*_in_expression`` functions.

    :param the_expression: root of the expression tree to search
    :param node_type: expression class to collect. ExpressionCollector
        dispatches on the exact type, so instances of subclasses are not
        collected.
    :return: the matching nodes in depth-first pre-order. There is no
        de-duplication: a node occurring several times is listed each time.
    """
    walker = ExpressionCollector()

    @walker.register(node_type)
    def _collect_node(expr: Expression, context: Any | None = None) -> list[Expression]:
        return [expr]

    return walker.walk(the_expression)


def list_of_variables_in_expression(
    the_expression: Expression,
) -> list[Variable]:
    """Return the Variable nodes appearing in the expression.

    :param the_expression: expression to inspect
    :return: Variable nodes in traversal order, repeated per occurrence
    """
    return _collect_nodes_of_type(the_expression, Variable)


def list_of_all_betas_in_expression(
    the_expression: Expression,
) -> list[Beta]:
    """Return all Beta nodes (free and fixed) appearing in the expression.

    :param the_expression: expression to inspect
    :return: Beta nodes in traversal order, repeated per occurrence
    """
    return _collect_nodes_of_type(the_expression, Beta)


def list_of_free_betas_in_expression(
    the_expression: Expression,
) -> list[Beta]:
    """Return the Beta nodes whose ``is_free`` attribute is true.

    :param the_expression: expression to inspect
    :return: free Beta nodes in traversal order, repeated per occurrence
    """
    return [
        beta
        for beta in list_of_all_betas_in_expression(the_expression)
        if beta.is_free
    ]


def list_of_fixed_betas_in_expression(
    the_expression: Expression,
) -> list[Beta]:
    """Return the Beta nodes whose ``is_free`` attribute is false.

    :param the_expression: expression to inspect
    :return: fixed Beta nodes in traversal order, repeated per occurrence
    """
    return [
        beta
        for beta in list_of_all_betas_in_expression(the_expression)
        if not beta.is_free
    ]


def list_of_random_variables_in_expression(
    the_expression: Expression,
) -> list[RandomVariable]:
    """Return the RandomVariable nodes appearing in the expression.

    :param the_expression: expression to inspect
    :return: RandomVariable nodes in traversal order, repeated per occurrence
    """
    return _collect_nodes_of_type(the_expression, RandomVariable)


def list_of_draws_in_expression(
    the_expression: Expression,
) -> list[Draws]:
    """Return the Draws nodes appearing in the expression.

    :param the_expression: expression to inspect
    :return: Draws nodes in traversal order, repeated per occurrence
    """
    return _collect_nodes_of_type(the_expression, Draws)