# Source code for biogeme.model_elements.model_elements

from __future__ import annotations

import pandas as pd

from biogeme.audit_tuple import AuditTuple, merge_audit_tuples
from biogeme.constants import LOG_LIKE, WEIGHT
from biogeme.database import Database, audit_dataframe
from biogeme.draws import DrawsManagement, RandomNumberGeneratorTuple
from biogeme.exceptions import BiogemeError
from biogeme.expressions import Expression, audit_expression
from biogeme.expressions.base_expressions import LogitTuple
from biogeme.expressions_registry import ExpressionRegistry
from biogeme.function_output import FunctionOutput, NamedFunctionOutput
from biogeme.model_elements.audit import audit_chosen_alternative, audit_variables
from .database_adapter import ModelElementsAdapter


class ModelElements:
    """
    Container for all key components required to define and estimate a model,
    using an adapter-based design.

    :param expressions: Dict of expressions to be evaluated.
    :param use_jit: Whether to use just-in-time compilation from Jax.
    :param adapter: Adapter implementing the model elements interface.
    :param number_of_draws: Number of Monte Carlo draws.
    :param draws_management: Optional object managing the draws.
    :param user_defined_draws: dict with user defined draw generators.
    :param expressions_registry: Optional expressions registry.
    :raises ValueError: if both ``number_of_draws`` (non-zero) and
        ``draws_management`` are provided.
    """

    loglikelihood_name: str = LOG_LIKE
    weight_name: str = WEIGHT

    def __init__(
        self,
        expressions: dict[str, Expression],
        use_jit: bool,
        adapter: ModelElementsAdapter,
        number_of_draws: int | None = None,
        draws_management: DrawsManagement | None = None,
        user_defined_draws: dict[str, RandomNumberGeneratorTuple] | None = None,
        expressions_registry: ExpressionRegistry | None = None,
    ):
        # Validate before touching the adapter or the database, so that a
        # rejected configuration does not leave a registered listener behind.
        # A number_of_draws of 0 counts as "not provided": the alternative
        # constructor below uses 0 as its default.
        if number_of_draws and draws_management is not None:
            raise ValueError(
                "Only one of 'number_of_draws' or 'draws_management' can be "
                "provided (or none of them)."
            )
        self.use_jit = use_jit
        self._adapter = adapter
        self.expressions = expressions
        self.number_of_draws: int | None = number_of_draws
        self._draws_management = draws_management
        self.user_defined_draws = user_defined_draws
        self._expressions_registry = expressions_registry
        self._adapter.prepare(self.expressions)
        self._database = self._adapter.database
        # Registered last: the callback reads self._draws_management, which
        # must already exist when it is invoked.
        self._database.register_listener(self.on_database_update)

    @property
    def expressions_registry(self) -> ExpressionRegistry:
        """Registry of the expressions.

        If none was provided at construction, it is built by the adapter on
        first access and cached.
        """
        if self._expressions_registry is None:
            self._expressions_registry = self._adapter.build_registry(
                self.expressions
            )
        return self._expressions_registry

    @property
    def draws_management(self) -> DrawsManagement:
        """Object managing the draws, created lazily on first access.

        When no object was provided at construction, a new one is created for
        the current sample size and, if the expressions registry reports that
        draws are required, the draws are generated.

        :return: the draws management object.
        :raises BiogemeError: if the sample size of the draws differs from the
            sample size of the model elements.
        """
        if self._draws_management is None:
            self._draws_management = DrawsManagement(
                sample_size=self.sample_size,
                number_of_draws=self.number_of_draws,
                user_generators=self.user_defined_draws,
            )
            if self.expressions_registry.requires_draws:
                self._draws_management.generate_draws(
                    draw_types=self.expressions_registry.draw_types(),
                    variable_names=self.expressions_registry.draws_names,
                )
        # Read the private attribute: going through the property here would
        # re-enter this very check and recurse until RecursionError.
        draws_sample_size = self._draws_management.sample_size
        if draws_sample_size != self.sample_size:
            error_msg = (
                f'Inconsistent sizes: database[{self.sample_size}] '
                f'and draws [{draws_sample_size}]'
            )
            raise BiogemeError(error_msg)
        return self._draws_management

    @property
    def free_betas_names(self) -> list[str]:
        """Returns the names of the parameters that must be estimated.

        :return: list of names of the parameters
        """
        return self.expressions_registry.free_betas_names

    @property
    def database(self) -> Database:
        """Database provided by the adapter."""
        return self._adapter.database

    @classmethod
    def from_expression_and_weight(
        cls,
        log_like: Expression,
        adapter: ModelElementsAdapter,
        use_jit: bool,
        weight: Expression | None = None,
        number_of_draws: int = 0,
        draws_management: DrawsManagement | None = None,
        user_defined_draws: dict[str, RandomNumberGeneratorTuple] | None = None,
    ) -> ModelElements:
        """
        Alternative constructor for two expressions.

        :param log_like: Expression for the log-likelihood.
        :param adapter: Adapter implementing the model elements interface.
        :param use_jit: use just-in-time compilation from Jax
        :param weight: Expression for the weight. If None, no weight
            expression is registered.
        :param number_of_draws: Number of Monte Carlo draws.
        :param draws_management: Optional object managing the draws.
        :param user_defined_draws: dict with user defined draw generators.
        :return: the model elements.
        """
        expressions = {cls.loglikelihood_name: log_like}
        if weight is not None:
            expressions[cls.weight_name] = weight
        return cls(
            expressions=expressions,
            adapter=adapter,
            number_of_draws=number_of_draws,
            draws_management=draws_management,
            user_defined_draws=user_defined_draws,
            use_jit=use_jit,
        )

    def on_database_update(self, updated_index: pd.Index) -> None:
        """Update the draws object to remain consistent with the new database.

        Registered as a listener on the database in ``__init__``.

        :param updated_index: index received from the database notification,
            forwarded to ``DrawsManagement.remove_rows``.
        """
        if self._draws_management is None:
            # No draws exist yet: they are created lazily, with the size
            # reported at that time, when first needed. Going through the
            # ``draws_management`` property here would either run its
            # size-consistency check (which fails if the database has already
            # shrunk) or build brand-new draws only to trim them immediately.
            return
        self._draws_management.remove_rows(updated_index)

    @property
    def sample_size(self) -> int:
        """Sample size reported by the adapter."""
        return self._adapter.sample_size

    @property
    def number_of_observations(self) -> int:
        """Number of observations reported by the adapter."""
        return self._adapter.number_of_observations

    @property
    def loglikelihood(self) -> Expression | None:
        """Log likelihood expression, or None if not defined."""
        return self.expressions.get(self.loglikelihood_name)

    @property
    def weight(self) -> Expression | None:
        """Weight expression, or None if not defined."""
        return self.expressions.get(self.weight_name)

    @property
    def formula_names(self) -> list[str]:
        """Names of the expressions, in insertion order."""
        return list(self.expressions.keys())

    def generate_named_output(
        self, function_output: FunctionOutput
    ) -> NamedFunctionOutput:
        """Assigns parameter names to the entries of the gradient and the hessian.

        :param function_output: output of the function evaluation.
        :return: the same output, with entries mapped to the free parameter
            names through the registry's ``free_betas_indices``.
        """
        return NamedFunctionOutput(
            function_output=function_output,
            mapping=self.expressions_registry.free_betas_indices,
        )

    def audit(self) -> AuditTuple:
        """Audit the model elements.

        Merges, in this order: the audit of each expression, the audit of the
        database dataframe, the audit of the variables used by each
        expression and, when a log likelihood is defined, the audit of the
        chosen alternative of each logit found in it.

        :return: the merged audit.
        """
        expressions = list(self.expressions.values())

        # First, we audit the expressions.
        expression_audits = [audit_expression(expr) for expr in expressions]

        # Second, we audit the database.
        database_audits = [audit_dataframe(data=self.database.dataframe)]

        # Then, we check the variables.
        variables_audits = [
            audit_variables(expression=expr, database=self.database)
            for expr in expressions
        ]

        # Finally, we verify the logit formulas, if any.
        logit_audits = []
        loglikelihood = self.loglikelihood
        if loglikelihood is not None:
            logits_to_check: list[LogitTuple] = loglikelihood.logit_choice_avail()
            logit_audits = [
                audit_chosen_alternative(
                    choice=logit_to_check.choice,
                    availability=logit_to_check.availabilities,
                    database=self.database,
                    use_jit=self.use_jit,
                )
                for logit_to_check in logits_to_check
            ]

        return merge_audit_tuples(
            expression_audits + database_audits + variables_audits + logit_audits
        )