Source code for biogeme.draws.management
"""
This module defines the DrawsManagement class, which manages the generation
and conversion of random draws for use in simulation-based models.
Michel Bierlaire
Thu Mar 27 08:42:16 2025
"""
from __future__ import annotations
import logging
from datetime import datetime, timedelta
import jax.numpy as jnp
import numpy as np
import pandas as pd
from biogeme.floating_point import JAX_FLOAT
from .factory import DrawFactory
from .native_draws import RandomNumberGeneratorTuple
logger = logging.getLogger(__name__)

# Requested draw counts at or below this value cause a warning to be logged
# when draws are generated, as the simulation results may not be meaningful.
LOW_NUMBER_OF_DRAWS = 1000
[docs]
class DrawsManagement:
    """
    Manages generation of simulation draws and conversion to JAX-compatible format.
    """

    def __init__(
        self,
        sample_size: int,
        number_of_draws: int,
        user_generators: dict[str, RandomNumberGeneratorTuple] | None = None,
    ):
        """
        Constructor for the DrawsManagement class.

        :param sample_size: The number of observations (rows) in the sample.
            Must be strictly positive.
        :param number_of_draws: The number of draws to generate per observation.
            It is validated when :meth:`generate_draws` is called, not here.
        :param user_generators: Optional dictionary of user-defined random number
            generators, forwarded to the :class:`DrawFactory`.
        :raises ValueError: if ``sample_size`` is not strictly positive.
        """
        if sample_size <= 0:
            raise ValueError(f'Incorrect sample size: {sample_size}')
        self.user_generators = user_generators
        self.sample_size: int = sample_size
        self.number_of_draws: int = number_of_draws
        self.factory: DrawFactory = DrawFactory(user_generators)
        # The three attributes below are filled in by generate_draws.
        self.draws: np.ndarray | None = None
        self.draw_types: dict[str, str] | None = None
        self.processing_time: timedelta = timedelta(0)

    def generate_draws(
        self,
        draw_types: dict[str, str],
        variable_names: list[str],
    ) -> None:
        """
        Generates random draws using the configured factory.

        The draws are stored in :attr:`draws` (nothing is returned). The
        requested draw types are stored in :attr:`draw_types`, and the
        elapsed wall-clock time in :attr:`processing_time`. These attributes
        are only updated once generation has succeeded.

        :param draw_types: Mapping of variable names to draw types.
        :param variable_names: List of variable names requiring draws.
        :raises ValueError: if the number of draws is not strictly positive.
        """
        if self.number_of_draws <= 0:
            raise ValueError(f'Incorrect number of draws: {self.number_of_draws}')
        if self.number_of_draws <= LOW_NUMBER_OF_DRAWS:
            warning_msg = (
                f'The number of draws ({self.number_of_draws}) is low. '
                'The results may not be meaningful.'
            )
            logger.warning(warning_msg)
        start_time = datetime.now()
        new_draws = self.factory.generate_draws(
            draw_types=draw_types,
            variable_names=variable_names,
            sample_size=self.sample_size,
            number_of_draws=self.number_of_draws,
        )
        self.processing_time = datetime.now() - start_time
        # Update the state only after success, so that a failure (invalid
        # number of draws, factory error) never leaves draw_types describing
        # draws that the object does not hold.
        self.draws = new_draws
        self.draw_types = draw_types

    @property
    def draws_jax(self) -> jnp.ndarray:
        """
        Returns the generated draws as a JAX array.

        If no draws have been generated, returns a JAX array of zeros of shape
        (sample_size, 1, 1). The conversion is performed on each access.

        :return: JAX-compatible array of draws.
        """
        if self.draws is not None:
            return jnp.asarray(self.draws, dtype=JAX_FLOAT)
        return jnp.zeros((self.sample_size, 1, 1), dtype=JAX_FLOAT)

    def remove_rows(self, indices: pd.Index) -> None:
        """Remove rows. Typically called when the database is modified.

        ``indices`` is used to index the rows of :attr:`draws`, so it
        designates the rows that are *kept*, as integer positions or a
        boolean mask. :attr:`sample_size` is updated to the number of
        remaining rows, even when no draws have been generated yet, so that
        later calls to :meth:`generate_draws` and :attr:`draws_jax` use the
        reduced sample.

        :param indices: selection of the rows to keep.
        """
        if self.draws is not None:
            self.draws = self.draws[indices]
            self.sample_size = self.draws.shape[0]
            return
        # No draws yet: only the sample size must follow the selection.
        selection = np.asarray(indices)
        if selection.dtype == np.bool_:
            self.sample_size = int(np.count_nonzero(selection))
        else:
            self.sample_size = int(selection.size)