Source code for biogeme.latent_variables.normalization_plan
from __future__ import annotations
"""Normalization plan objects."""
from collections.abc import Iterable, Iterator
from dataclasses import dataclass
from enum import Enum

from .normalization_refs import ParameterRef
[docs]
class ConflictPolicy(str, Enum):
    """Strategy applied by ``NormalizationPlan.add`` when the target is already fixed.

    Members subclass ``str``, so each one compares equal to its plain string
    value (for instance ``ConflictPolicy.ERROR == "error"``).
    """

    # Raise ``ValueError`` when the new value differs from the stored one;
    # an equal value is kept silently.
    ERROR = "error"
    # Replace the stored fixing unconditionally.
    OVERWRITE = "overwrite"
    # Presumably meant to tolerate re-adding an identical value.
    # NOTE(review): in NormalizationPlan.add this currently behaves exactly
    # like ERROR (equal values kept, differing values raise) — confirm intent.
    IGNORE_SAME = "ignore_same"
[docs]
@dataclass(frozen=True, slots=True)
class Fixing:
"""One explicit parameter fixing.
:param target: Semantic parameter reference.
:param value: Numeric value to impose.
:param note: Optional human-readable note.
"""
target: ParameterRef
value: float
note: str | None = None
[docs]
class NormalizationPlan:
    """Collection of explicit fixings.

    The plan itself does not decide whether a fixing is meaningful for a given
    model. That responsibility belongs to validation.

    At most one :class:`Fixing` is stored per target. Iteration yields the
    fixings ordered by ``target.key()``, so the order does not depend on the
    order in which fixings were added.
    """

    def __init__(self, fixings: Iterable[Fixing] | None = None) -> None:
        """Create a plan, optionally pre-populated.

        :param fixings: Initial fixings. Each one is passed to :meth:`add`
            with the default ``ConflictPolicy.ERROR`` policy.
        :raises ValueError: if two initial fixings target the same parameter
            with different values.
        """
        self._fixings: dict[ParameterRef, Fixing] = {}
        if fixings is not None:
            for fixing in fixings:
                self.add(fixing)

    def __len__(self) -> int:
        """Return the number of fixed parameters."""
        return len(self._fixings)

    def __iter__(self) -> Iterator[Fixing]:
        """Yield the stored fixings sorted by ``target.key()``."""
        for ref in sorted(self._fixings, key=lambda r: r.key()):
            yield self._fixings[ref]

    def __contains__(self, item: object) -> bool:
        """Return whether ``item`` is fixed in this plan.

        Without this method, ``in`` would fall back to :meth:`__iter__`, which
        yields :class:`Fixing` objects, so testing a parameter reference would
        compare it against fixings rather than against the stored targets.

        :param item: A parameter reference, checked like :meth:`is_fixed`, or
            a :class:`Fixing`, which is reported present only if an equal
            fixing is stored (the same answer the iteration fallback gave).
        """
        try:
            if isinstance(item, Fixing):
                return self._fixings.get(item.target) == item
            return item in self._fixings
        except TypeError:
            # Unhashable objects can never be keys of the plan.
            return False

    def __repr__(self) -> str:
        return f"{type(self).__name__}({self.as_list()!r})"

    def add(
        self,
        fixing: Fixing,
        *,
        on_conflict: ConflictPolicy = ConflictPolicy.ERROR,
    ) -> None:
        """Register a fixing, resolving a clash with an existing one.

        :param fixing: The fixing to register.
        :param on_conflict: What to do when ``fixing.target`` is already
            fixed. ``OVERWRITE`` always replaces the stored fixing. Any other
            policy keeps the stored fixing when the values are equal and
            raises when they differ. The plain string values (for example
            ``'overwrite'``) are accepted too.
        :raises ValueError: if ``on_conflict`` is not a valid
            :class:`ConflictPolicy`, or if the target is already fixed to a
            different value and the policy is not ``OVERWRITE``.
        """
        # Normalize first so a misspelled policy fails loudly instead of
        # silently behaving like ERROR.
        policy = ConflictPolicy(on_conflict)
        existing = self._fixings.get(fixing.target)
        if existing is None or policy is ConflictPolicy.OVERWRITE:
            self._fixings[fixing.target] = fixing
            return
        # NOTE(review): ERROR and IGNORE_SAME are handled identically here:
        # an equal value is silently kept under both. Confirm this is intended.
        if existing.value == fixing.value:
            return
        raise ValueError(
            f"Conflicting fixings for '{fixing.target}': "
            f"{existing.value} vs {fixing.value}."
        )

    def get(self, target: ParameterRef) -> float | None:
        """Return the value imposed on ``target``, or ``None`` if not fixed."""
        fixing = self._fixings.get(target)
        return None if fixing is None else fixing.value

    def get_fixing(self, target: ParameterRef) -> Fixing | None:
        """Return the full :class:`Fixing` for ``target``, or ``None``."""
        return self._fixings.get(target)

    def is_fixed(self, target: ParameterRef) -> bool:
        """Return whether ``target`` has a fixing in this plan."""
        return target in self._fixings

    def as_list(self) -> list[Fixing]:
        """Return a new list of the fixings, sorted by ``target.key()``."""
        return list(self)