"""Arithmetic expressions accepted by Biogeme: drawsMichel BierlaireFri Jun 27 2025, 14:41:17"""from__future__importannotationsimportloggingimportpandasaspdimportpymcaspmfromjaximportnumpyasjnpfrompytensor.tensorimportTensorVariablefrombiogeme.drawsimportPyMcDistributionFactory,get_distribution,pymc_distributionsfrombiogeme.exceptionsimportBiogemeErrorfrom.bayesianimportDimension,PymcModelBuilderTypefrom.elementary_expressionsimportElementaryfrom.elementary_typesimportTypeOfElementaryExpressionfrom.jax_utilsimportJaxFunctionTypelogger=logging.getLogger(__name__)
class Draws(Elementary):
    """Draws for Monte-Carlo integration.

    The expression stands for a named random variable represented by a
    series of draws. It can be turned into a JAX function that selects its
    draws, or into a PyMC random variable.
    """

    expression_type = TypeOfElementaryExpression.DRAWS

    def __init__(
        self,
        name: str,
        draw_type: str = 'NORMAL',
        dict_of_distributions: dict[str, PyMcDistributionFactory] | None = None,
    ):
        """Constructor

        :param name: name of the random variable with a series of draws.
        :type name: string

        :param draw_type: type of draws. When building the PyMC model, it is
            passed to ``get_distribution`` together with
            ``dict_of_distributions`` to select the distribution factory.
        :type draw_type: string

        :param dict_of_distributions: distribution factories available for
            the PyMC model, indexed by draw type. If None, the default
            ``pymc_distributions`` is used.
        :type dict_of_distributions: dict[str, PyMcDistributionFactory] | None
        """
        super().__init__(name)
        self.draw_type = draw_type
        # Passed as ``dims`` when the PyMC variable is created.
        self._draw_dimension = Dimension.OBS
        self._is_complex = True
        self.dict_of_distributions: dict[str, PyMcDistributionFactory] = (
            dict_of_distributions
            if dict_of_distributions is not None
            else pymc_distributions
        )

    @property
    def draw_dimension(self) -> Dimension:
        """Dimension used as ``dims`` when creating the PyMC variable."""
        return self._draw_dimension

    def deep_flat_copy(self) -> Draws:
        """Provides a copy of the expression.

        It is deep in the sense that it generates copies of the children.
        It is flat in the sense that any `MultipleExpression` is transformed
        into the currently selected expression. The flat part is irrelevant
        for this expression.

        The copy keeps the same name, draw type and dictionary of
        distributions. The dictionary itself is shared, not duplicated.
        Without it, a copy of draws defined with custom distributions would
        silently fall back to the default ones.

        :return: a new ``Draws`` expression of the same type.
        """
        return type(self)(
            name=self.name,
            draw_type=self.draw_type,
            dict_of_distributions=self.dict_of_distributions,
        )

    def __str__(self) -> str:
        return f'Draws("{self.name}", "{self.draw_type}")'

    @property
    def safe_draw_id(self) -> int:
        """Check the presence of the draw ID before its usage.

        :return: the draw ID (``specific_id``) of this expression.
        :raise BiogemeError: if no ID has been defined for the draws.
        """
        if self.specific_id is None:
            raise BiogemeError(f"No id defined for draw {self.name}")
        return self.specific_id

    def recursive_construct_jax_function(
        self, numerically_safe: bool
    ) -> JaxFunctionType:
        """Generates a function to be used by biogeme_jax.

        :param numerically_safe: not used by this expression.

        :return: a function taking the parameters, one row of the database,
            the draws and the random variables. It returns the slice of the
            draws at index ``safe_draw_id`` along the last axis.
        """

        def the_jax_function(
            parameters: jnp.ndarray,
            one_row: jnp.ndarray,
            the_draws: jnp.ndarray,
            the_random_variables: jnp.ndarray,
        ) -> jnp.ndarray:
            # The ID is read at call time, so a missing ID raises
            # BiogemeError when the function is evaluated.
            return jnp.take(the_draws, self.safe_draw_id, axis=-1)

        return the_jax_function

    def recursive_construct_pymc_model_builder(self) -> PymcModelBuilderType:
        """Generates recursively a function to be used by PyMc.

        The distribution factory is selected eagerly, so configuration
        errors surface here rather than when the model is built.

        :return: a builder that, inside an active PyMC model context, returns
            the model variable named ``self.name`` if it is already
            registered. Otherwise it creates that variable from the selected
            distribution with ``dims=self.draw_dimension``.
        :raise BiogemeError: if ``get_distribution`` raises a ``ValueError``
            for this draw type.
        """
        try:
            selected_distribution = get_distribution(
                name=self.draw_type, the_dict=self.dict_of_distributions
            )
        except ValueError as e:
            error_msg = f'Problem generating draws for {self.name}. {e}'
            raise BiogemeError(error_msg) from e

        def builder(dataframe: pd.DataFrame) -> TensorVariable:
            # ``dataframe`` is required by the builder interface but not
            # needed to create the draws.
            model = pm.modelcontext(None)  # Get current active model context
            # Reuse the variable if several expressions refer to the same
            # draws. PyMC does not allow two variables with the same name.
            if self.name in model.named_vars:
                return model.named_vars[self.name]
            return selected_distribution(name=self.name, dims=self.draw_dimension)

        return builder