Source code for biogeme.model_elements.model_elements

from __future__ import annotations

import pandas as pd

from biogeme.audit_tuple import AuditTuple, display_messages, merge_audit_tuples
from biogeme.constants import LOG_LIKE, WEIGHT
from biogeme.database import Database, PanelDatabase, audit_dataframe
from biogeme.default_parameters import MISSING_VALUE
from biogeme.draws import DrawsManagement, RandomNumberGeneratorTuple
from biogeme.exceptions import BiogemeError
from biogeme.expressions import Expression, audit_expression
from biogeme.expressions.base_expressions import LogitTuple
from biogeme.expressions_registry import ExpressionRegistry
from biogeme.function_output import FunctionOutput, NamedFunctionOutput
from biogeme.model_elements.audit import audit_chosen_alternative, audit_variables


class ModelElements:
    """Container for all key components required to define and estimate a model.

    :param expressions: Dict of expressions to be evaluated, indexed by name.
    :param use_jit: Flag stored on the object and forwarded to the audit of
        the chosen alternative.
    :param database: The database containing the data. If None, the database
        returned by ``Database.dummy_database()`` is used. If its
        ``panel_column`` is not None, the data is treated as panel data and
        a flattened copy is exposed through :attr:`database`.
    :param number_of_draws: Number of draws, forwarded to the
        ``DrawsManagement`` object created when none is supplied.
    :param draws_management: Optional object managing the draws. It cannot be
        combined with a non-zero ``number_of_draws``.
    :param user_defined_draws: Optional user-defined generators, forwarded to
        the ``DrawsManagement`` object created when none is supplied.
    :param expressions_registry: Optional registry of the expressions. If
        None, it is built from the expressions and the database.

    :raise ValueError: if both ``number_of_draws`` and ``draws_management``
        are provided.
    :raise BiogemeError: if the sample size of the supplied
        ``draws_management`` differs from the sample size of the database.
    """

    def __init__(
        self,
        expressions: dict[str, Expression],
        use_jit: bool,
        database: Database | None,
        number_of_draws: int | None = None,
        draws_management: DrawsManagement | None = None,
        user_defined_draws: dict[str, RandomNumberGeneratorTuple] | None = None,
        expressions_registry: ExpressionRegistry | None = None,
    ):
        self.use_jit = use_jit
        if database is None:
            database = Database.dummy_database()

        # Set to True by _prepare_for_panel, and reset to False by
        # on_database_update, so that the flat database is rebuilt lazily.
        self.panel_prepared: bool = False
        self._database: Database = database
        self._is_panel: bool = database.panel_column is not None
        self._panel_database: PanelDatabase | None = (
            PanelDatabase(database=database, panel_column=database.panel_column)
            if self._is_panel
            else None
        )

        self.expressions: dict[str, Expression] = expressions
        self._flat_database: Database | None = None
        self.number_of_draws: int | None = number_of_draws
        self.draws_management = draws_management
        self.user_defined_draws = user_defined_draws
        self.expressions_registry = expressions_registry

        # Validate inputs: at most one of the two can be provided. A number
        # of draws equal to zero (or None) counts as "not provided".
        if self.number_of_draws and self.draws_management is not None:
            raise ValueError(
                "One of 'number_of_draws' or 'draws_management' must be provided (or none of them)."
            )

        if self.expressions_registry is None:
            self.expressions_registry = ExpressionRegistry(
                self.expressions.values(), self.database
            )

        if self._database is not None:
            self.database.register_listener(self.on_database_update)

        if self.draws_management is None:
            self.draws_management = DrawsManagement(
                sample_size=self.sample_size,
                number_of_draws=self.number_of_draws,
                user_generators=user_defined_draws,
            )
            if self.expressions_registry.requires_draws:
                self.draws_management.generate_draws(
                    draw_types=self.expressions_registry.draw_types(),
                    variable_names=self.expressions_registry.draws_names,
                )
        elif self.draws_management.sample_size != self.sample_size:
            # Check the consistency of the sizes.
            error_msg = f'Inconsistent sizes: database[{self.sample_size}] and draws [{self.draws_management.sample_size}]'
            raise BiogemeError(error_msg)

        if self._is_panel:
            # NOTE(review): the flat database is rebuilt here even if it has
            # already been built above, so that the object built at the end
            # of the constructor is a fresh one, as before. As a consequence,
            # the listener registered above is attached to the previous flat
            # database -- confirm that this is intended.
            self._prepare_for_panel()

        display_messages(self.audit())

    def is_panel(self) -> bool:
        """Report if the database has been identified as panel data.

        :return: True if the ``panel_column`` of the database was not None
            at construction.
        """
        return self._is_panel

    @property
    def database(self) -> Database:
        """Database to be used for the calculations.

        :return: the original database for non-panel data, and the flattened
            database for panel data. The latter is (re)built on demand when
            ``panel_prepared`` is False.
        """
        if not self.is_panel():
            return self._database
        if not self.panel_prepared:
            self._prepare_for_panel()
        return self._flat_database

    @classmethod
    def from_expression_and_weight(
        cls,
        log_like: Expression,
        database: Database,
        use_jit: bool,
        weight: Expression | None = None,
        number_of_draws: int = 0,
        draws_management: DrawsManagement | None = None,
        user_defined_draws: dict[str, RandomNumberGeneratorTuple] | None = None,
    ) -> ModelElements:
        """Alternative constructor for two expressions.

        :param log_like: Expression for the log-likelihood.
        :param database: The database containing data.
        :param use_jit: Flag forwarded to the constructor.
        :param weight: Expression for the weight. If None, only the
            log-likelihood is registered.
        :param number_of_draws: Number of Monte Carlo draws.
        :param draws_management: Optional object managing the draws.
        :param user_defined_draws: Optional user-defined generators,
            forwarded to the constructor.
        :return: the model elements, with the expressions stored under the
            keys ``LOG_LIKE`` and, if a weight is given, ``WEIGHT``.
        """
        expressions = {LOG_LIKE: log_like}
        if weight is not None:
            expressions[WEIGHT] = weight
        return cls(
            expressions=expressions,
            database=database,
            number_of_draws=number_of_draws,
            draws_management=draws_management,
            user_defined_draws=user_defined_draws,
            use_jit=use_jit,
        )

    def on_database_update(self, updated_index: pd.Index):
        """Update the draws object to remain consistent with the new database.

        :param updated_index: index forwarded to ``remove_rows`` of the draws
            management object.
        """
        self.draws_management.remove_rows(updated_index)
        # The flat database is outdated: it is rebuilt at the next access.
        self.panel_prepared = False

    def _prepare_for_panel(self) -> None:
        """Build the flat database and the corresponding expressions registry.

        The panel database is flattened, each expression is informed about
        the maximum number of observations per individual, and the registry
        is rebuilt on the flat database.
        """
        flat_dataframe, maximum_number_of_observations_per_individual = (
            self._panel_database.flatten_database(missing_data=MISSING_VALUE)
        )
        self._flat_database = Database(
            name=f'flat {self._database.name}', dataframe=flat_dataframe
        )
        for expression in self.expressions.values():
            expression.set_maximum_number_of_observations_per_individual(
                max_number=maximum_number_of_observations_per_individual
            )
        self.expressions_registry = ExpressionRegistry(
            self.expressions.values(), self._flat_database
        )
        # Record that the flat database is up to date. Without this, the
        # `database` property would flatten the data again at each access.
        self.panel_prepared = True

    @property
    def sample_size(self) -> int:
        """Number of rows of :attr:`database` (the flat one for panel data)."""
        if self._database is None:
            return 0
        return self.database.num_rows()

    @property
    def number_of_observations(self) -> int:
        """Number of rows of the original (non-flattened) database."""
        if self._database is None:
            return 0
        return self._database.num_rows()

    @property
    def loglikelihood(self) -> Expression | None:
        """Expression stored under the key ``LOG_LIKE``, or None."""
        return self.expressions.get(LOG_LIKE)

    @property
    def weight(self) -> Expression | None:
        """Expression stored under the key ``WEIGHT``, or None."""
        return self.expressions.get(WEIGHT)

    @property
    def formula_names(self) -> list[str]:
        """Names of the expressions, in the order of the dict."""
        return list(self.expressions)

    def generate_named_output(
        self, function_output: FunctionOutput
    ) -> NamedFunctionOutput:
        """Assign parameter names to the entries of the gradient and the hessian.

        :param function_output: output to be wrapped.
        :return: the output combined with the mapping
            ``free_betas_indices`` of the expressions registry.
        """
        return NamedFunctionOutput(
            function_output=function_output,
            mapping=self.expressions_registry.free_betas_indices,
        )

    def audit(self) -> AuditTuple:
        """Audit the model elements.

        :return: the merged audits of the expressions, the database, the
            variables and, if a log-likelihood is defined, the logit formulas.
        """
        # First, we audit the expressions.
        expression_audits = [
            audit_expression(expr) for expr in self.expressions.values()
        ]

        # Second, we audit the database.
        database_audits = (
            [audit_dataframe(data=self.database.dataframe)]
            if self._database is not None
            else []
        )

        # Then, we check the variables.
        variables_audits = [
            audit_variables(expression=expr, database=self.database)
            for expr in self.expressions.values()
        ]

        # Finally, we verify the logit formula, if any.
        logit_audits = []
        if self.loglikelihood is not None:
            logits_to_check: list[LogitTuple] = self.loglikelihood.logit_choice_avail()
            logit_audits = [
                audit_chosen_alternative(
                    choice=logit_to_check.choice,
                    availability=logit_to_check.availabilities,
                    database=self.database,
                    use_jit=self.use_jit,
                )
                for logit_to_check in logits_to_check
            ]

        return merge_audit_tuples(
            expression_audits + database_audits + variables_audits + logit_audits
        )