"""This module defines the DrawsManagement class, which manages the generation
and conversion of random draws for use in simulation-based models.

Michel Bierlaire
Thu Mar 27 08:42:16 2025
"""

from __future__ import annotations

import logging
from datetime import datetime, timedelta

import jax.numpy as jnp
import numpy as np
import pandas as pd

from biogeme.floating_point import JAX_FLOAT
from .factory import DrawFactory
from .native_draws import RandomNumberGeneratorTuple

# Threshold on the number of draws per observation: requesting this many
# draws or fewer triggers a warning that results may not be meaningful.
LOW_NUMBER_OF_DRAWS = 1000

logger = logging.getLogger(__name__)
class DrawsManagement:
    """
    Manages generation of simulation draws and conversion to JAX-compatible format.

    The draws are held as a NumPy array in ``draws``. That attribute is
    ``None`` until ``generate_draws`` is called or draws are assigned
    directly. The draws are exposed as a JAX array through ``draws_jax``.
    """

    def __init__(
        self,
        sample_size: int,
        number_of_draws: int,
        user_generators: dict[str, RandomNumberGeneratorTuple] | None = None,
    ):
        """
        Constructor for the DrawsManagement class.

        :param sample_size: The number of observations (rows) in the sample.
        :param number_of_draws: The number of draws to generate per observation.
            It is only validated when ``generate_draws`` is called.
        :param user_generators: Optional dictionary of user-defined random
            number generators, forwarded to the draw factory.
        :raises ValueError: If ``sample_size`` is not strictly positive.
        """
        if sample_size <= 0:
            raise ValueError(f'Incorrect sample size: {sample_size}')
        self.user_generators = user_generators
        self.sample_size: int = sample_size
        self.number_of_draws: int = number_of_draws
        self.factory: DrawFactory = DrawFactory(user_generators)
        # Filled by generate_draws (or assigned by extract_slice on a new instance).
        self.draws: np.ndarray | None = None
        self.draw_types: dict[str, str] | None = None
        # Wall-clock duration of the last successful call to generate_draws.
        self.processing_time: timedelta = timedelta(0)

    def generate_draws(
        self,
        draw_types: dict[str, str],
        variable_names: list[str],
    ) -> None:
        """
        Generates random draws using the configured factory.

        The result is stored in ``self.draws``. Nothing is returned.
        ``self.draw_types`` and ``self.processing_time`` are updated only
        once generation succeeds, so a failed call leaves the previous
        state untouched.

        :param draw_types: Mapping of variable names to draw types.
        :param variable_names: List of variable names requiring draws.
        :raises ValueError: If the number of draws is not strictly positive.
        """
        if self.number_of_draws <= 0:
            raise ValueError(f'Incorrect number of draws: {self.number_of_draws}')
        if self.number_of_draws <= LOW_NUMBER_OF_DRAWS:
            logger.warning(
                'The number of draws (%s) is low. The results may not be meaningful.',
                self.number_of_draws,
            )
        start_time = datetime.now()
        draws = self.factory.generate_draws(
            draw_types=draw_types,
            variable_names=variable_names,
            sample_size=self.sample_size,
            number_of_draws=self.number_of_draws,
        )
        # Commit the new state only after the factory succeeded.
        self.draws = draws
        self.draw_types = draw_types
        self.processing_time = datetime.now() - start_time

    @property
    def draws_jax(self) -> jnp.ndarray:
        """
        Returns the generated draws as a JAX array.

        If no draws have been generated, returns a zero-filled placeholder
        JAX array of shape (sample_size, 1, 1).

        :return: JAX-compatible array of draws.
        """
        if self.draws is None:
            return jnp.zeros((self.sample_size, 1, 1), dtype=JAX_FLOAT)
        return jnp.asarray(self.draws, dtype=JAX_FLOAT)

    def extract_slice(self, indices: pd.Index) -> DrawsManagement:
        """
        Create a new DrawsManagement instance containing only a subset of draws.

        This keeps the estimation and validation datasets consistent by
        slicing the original draws array according to the provided indices.
        The indices are used positionally, as in NumPy indexing.

        :param indices: The indices used to extract the subset of draws.
        :return: A new DrawsManagement instance containing the sliced draws.
            Its draw types are copied from this instance.
        :raises ValueError: If ``indices`` is empty, because the new instance
            would have a non-positive sample size.
        """
        sliced_draw_management = DrawsManagement(
            sample_size=len(indices),
            number_of_draws=self.number_of_draws,
            user_generators=self.user_generators,
        )
        if self.draws is not None:
            sliced_draw_management.draws = self.draws[indices]
        # Keep the slice self-describing: same draw types as the source.
        if self.draw_types is not None:
            sliced_draw_management.draw_types = dict(self.draw_types)
        return sliced_draw_management

    def remove_rows(self, indices: pd.Index) -> None:
        """
        Restrict the draws to the given rows. Typically called when the database is modified.

        ``indices`` designates the rows that are *kept*. The draws array is
        replaced by ``self.draws[indices]``, which uses NumPy positional
        indexing. ``sample_size`` is updated to the new number of rows, so
        that later calls to ``generate_draws`` and the ``draws_jax``
        placeholder stay consistent with the modified database.

        :param indices: Positions of the rows to retain.
        """
        if self.draws is None:
            # No array to slice; assumes one index per retained row
            # (integer positions, not a boolean mask) -- TODO confirm with callers.
            self.sample_size = len(indices)
            return
        self.draws = self.draws[indices]
        self.sample_size = self.draws.shape[0]