Source code for biogeme.bayesian_estimation.raw_bayesian_results

"""
Raw Bayesian estimation results built from an ArviZ InferenceData (PyMC).

Michel Bierlaire
Mon Oct 20 2025, 17:18:07
"""

from __future__ import annotations

import contextlib
import json
import logging
import math
from datetime import timedelta

import arviz as az
import xarray as xr
from biogeme.tools import print_file_size, timeit

logger = logging.getLogger(__name__)


# -----------------------------------------------------------------------------
# NetCDF-only RawBayesianResults class
# -----------------------------------------------------------------------------
class RawBayesianResults:
    """
    Minimal, NetCDF-only container of Bayesian estimation results.

    This class *holds* an ArviZ InferenceData (PyMC posterior, etc.) and a
    handful of metadata that cannot be robustly deduced from the
    InferenceData.

    - No YAML sidecar is produced.
    - All information is stored in a single NetCDF file via :meth:`save`.
    - To reload, use :meth:`load`, which reads both the posterior and the
      metadata from the same NetCDF file.

    Stored metadata (beyond what can be inferred from idata):

    * model_name (str)
    * user_notes (str)
    * data_name (str)
    * beta_names (list[str])  # model free/fixed parameter names for reporting
    * sampler (str | None)
    * target_accept (float | None)
    * run_time (timedelta | None)
    """

    # Legacy group name; :meth:`save` no longer writes a custom metadata group.
    _META_GROUP = "biogeme_meta"

    def __init__(
        self,
        *,
        idata: az.InferenceData,
        model_name: str,
        log_like_name: str,
        number_of_observations: int,
        user_notes: str = "",
        data_name: str = "",
        beta_names: list[str] | None = None,
        sampler: str | None = None,
        target_accept: float | None = None,
        run_time: timedelta | None = None,
    ) -> None:
        self._idata = idata
        self._model_name = model_name
        self._log_like_name = log_like_name
        self._user_notes = user_notes
        self._data_name = data_name
        self._beta_names = beta_names or []
        self._sampler = sampler
        self._target_accept = target_accept
        self._run_time = run_time
        self._number_of_observations = number_of_observations

    # ---------------- Properties inferred from idata ----------------

    @property
    def idata(self) -> az.InferenceData:
        return self._idata

    @property
    def model_name(self) -> str:
        return self._model_name

    @property
    def log_like_name(self) -> str:
        return self._log_like_name

    @property
    def user_notes(self) -> str:
        return self._user_notes

    @property
    def data_name(self) -> str:
        return self._data_name

    @property
    def beta_names(self) -> list[str]:
        return list(self._beta_names)

    @property
    def sampler(self) -> str | None:
        return self._sampler

    @property
    def target_accept(self) -> float | None:
        return self._target_accept

    @property
    def run_time(self) -> timedelta | None:
        return self._run_time

    @property
    def chains(self) -> int:
        try:
            return int(self._idata.posterior.sizes.get("chain", 1))
        except (AttributeError, KeyError, TypeError, ValueError):
            return 1

    @property
    def draws(self) -> int:
        try:
            return int(self._idata.posterior.sizes.get("draw", 0))
        except (AttributeError, KeyError, TypeError, ValueError):
            return 0

    @property
    def number_of_observations(self) -> int:
        return self._number_of_observations

    # ---------------- Persistence (single NetCDF) ----------------

    def _metadata_dataset(self) -> xr.Dataset:
        """Build a tiny xarray Dataset storing the metadata as attributes.

        Note: currently not called by :meth:`save`, which mirrors the
        metadata on the posterior attributes instead.
        """
        ds = xr.Dataset()
        ds.attrs["model_name"] = self._model_name
        ds.attrs["user_notes"] = self._user_notes
        ds.attrs["data_name"] = self._data_name
        ds.attrs["log_like_name"] = self._log_like_name
        ds.attrs["number_of_observations"] = self._number_of_observations
        ds.attrs["beta_names"] = list(self._beta_names)
        ds.attrs["sampler"] = self._sampler if self._sampler is not None else ""
        ds.attrs["target_accept"] = (
            float(self._target_accept)
            if self._target_accept is not None
            else float("nan")
        )
        ds.attrs["run_time_seconds"] = (
            float(self._run_time.total_seconds())
            if self._run_time is not None
            else float("nan")
        )
        return ds
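
    # Construction sketch (hypothetical names, for illustration only): a toy
    # posterior built with ``az.from_dict`` is enough to exercise this
    # container; no PyMC run is required.
    #
    #   import numpy as np
    #   toy_idata = az.from_dict(
    #       posterior={"beta_cost": np.random.randn(2, 500)}  # 2 chains, 500 draws
    #   )
    #   results = RawBayesianResults(
    #       idata=toy_idata,
    #       model_name="toy_model",
    #       log_like_name="loglike",
    #       number_of_observations=1000,
    #       beta_names=["beta_cost"],
    #   )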

    def save(self, path: str) -> None:
        """Write a single NetCDF file with posterior + metadata."""
        logger.info(f'Saving Bayesian estimation results to {path}')
        idata = self._idata.copy()
        # Mirror the metadata on the posterior group, a standard group that
        # survives az.from_netcdf; no custom metadata group is written.
        try:
            posterior_attrs = {
                "model_name": self._model_name or "",
                "user_notes": self._user_notes or "",
                "data_name": self._data_name or "",
                "log_like_name": self._log_like_name or "",
                "number_of_observations": self._number_of_observations or "",
                "beta_names": json.dumps(self._beta_names or []),
                "sampler": self._sampler or "",
                "target_accept": (
                    self._target_accept if self._target_accept is not None else ""
                ),
                "run_time_seconds": (
                    self._run_time.total_seconds()
                    if self._run_time is not None
                    else ""
                ),
            }
            idata.posterior.attrs.update(posterior_attrs)
        except (AttributeError, KeyError, TypeError, ValueError) as e:
            logger.warning("Could not set posterior attrs metadata cleanly: %s", e)
        az.to_netcdf(idata, path, engine="h5netcdf")
        logger.info(f'Saved Bayesian results (posterior + metadata) to {path}')
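
    # Inspection sketch (assumes "results.nc" was written by :meth:`save`,
    # continuing the toy example above): the metadata live as plain
    # attributes on the standard ``posterior`` group, so any NetCDF reader
    # can recover them without this class.
    #
    #   results.save("results.nc")
    #   with xr.open_dataset(
    #       "results.nc", group="posterior", engine="h5netcdf"
    #   ) as ds:
    #       print(ds.attrs["model_name"])              # "toy_model"
    #       print(json.loads(ds.attrs["beta_names"]))  # ["beta_cost"]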

    @classmethod
    @timeit(label='load')
    def load(cls, path: str) -> RawBayesianResults:
        """
        Load from a single NetCDF file written by :meth:`save`.

        Metadata are read from ``idata.posterior.attrs``, where they were
        stored by :meth:`save`. No custom ``biogeme_meta`` group is used
        anymore.
        """
        logger.debug(f"Read file {path}")

        # On Windows, NetCDF backends may keep the file handle open via
        # xarray's file-manager cache. We therefore (i) reduce/disable
        # caching when possible, (ii) eagerly load all groups into memory,
        # and (iii) close the datasets. Some xarray versions reject
        # file_cache_maxsize=0, so we fall back safely.
        try:
            cache_ctx = xr.set_options(file_cache_maxsize=1)
        except Exception:
            cache_ctx = contextlib.nullcontext()
        with cache_ctx:
            idata = az.from_netcdf(path, engine="h5netcdf")

        # Detach from disk: load all datasets and close any open file
        # handles. This makes it safe to delete the NetCDF file immediately
        # after loading.
        try:
            for group_name in idata.groups():
                ds = getattr(idata, group_name, None)
                if ds is None:
                    continue
                # Ensure arrays are in memory (not lazy on-disk).
                try:
                    ds.load()
                except Exception:
                    pass
                # Close backend resources if supported.
                try:
                    ds.close()
                except Exception:
                    pass
        except Exception:
            # If anything goes wrong, keep the behavior backward compatible.
            pass

        # Best effort: clear xarray's global file cache to ensure no
        # lingering handles.
        try:
            from xarray.backends.file_manager import FILE_CACHE

            FILE_CACHE.clear()
        except Exception:
            pass

        logger.info(f"Loaded NetCDF file size: {print_file_size(path)}")

        # Defaults
        model_name = ""
        user_notes = ""
        log_like_name = ""
        number_of_observations: int = 0
        data_name = ""
        beta_names: list[str] = []
        sampler: str | None = None

        # Read from the posterior attrs, if available.
        try:
            p_attrs = idata.posterior.attrs
        except AttributeError as e:
            logger.info(
                "Posterior group missing or invalid in InferenceData loaded "
                "from %s: %s",
                path,
                e,
            )
            p_attrs = {}

        model_name = p_attrs.get("model_name", model_name)
        user_notes = p_attrs.get("user_notes", user_notes)
        data_name = p_attrs.get("data_name", data_name)
        log_like_name = p_attrs.get("log_like_name", log_like_name)

        # number_of_observations may come as str, int, or be missing.
        no_raw = p_attrs.get("number_of_observations", number_of_observations)
        try:
            number_of_observations = int(no_raw)
        except (TypeError, ValueError):
            number_of_observations = 0

        # beta_names is stored as a JSON string.
        beta_names_raw = p_attrs.get("beta_names", "[]")
        try:
            beta_names = json.loads(beta_names_raw) or beta_names
        except (TypeError, ValueError):
            beta_names = []

        sampler = p_attrs.get("sampler") or sampler

        ta = p_attrs.get("target_accept")
        try:
            target_accept = float(ta) if ta not in (None, "") else None
        except (TypeError, ValueError):
            target_accept = None

        # run_time_seconds may be missing, empty, or NaN. NaN cannot be
        # detected with a membership test, since NaN != NaN, so convert
        # first and check explicitly.
        rts = p_attrs.get("run_time_seconds")
        try:
            rts_value = float(rts)
            run_time = (
                timedelta(seconds=rts_value)
                if not math.isnan(rts_value)
                else None
            )
        except (TypeError, ValueError):
            run_time = None

        return cls(
            idata=idata,
            model_name=model_name,
            user_notes=user_notes,
            log_like_name=log_like_name,
            number_of_observations=number_of_observations,
            data_name=data_name,
            beta_names=beta_names,
            sampler=sampler,
            target_accept=target_accept,
            run_time=run_time,
        )
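
    # Round-trip sketch (continuing the example above; assumes
    # ``import os``): because :meth:`load` eagerly loads and closes every
    # group, the NetCDF file can be deleted immediately after loading, even
    # on Windows.
    #
    #   loaded = RawBayesianResults.load("results.nc")
    #   os.remove("results.nc")  # safe: no lingering file handles
    #   assert loaded.model_name == "toy_model"
    #   assert loaded.beta_names == ["beta_cost"]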

    # Convenience dict for quick reporting.
    def to_dict(self) -> dict:
        return {
            "model_name": self.model_name,
            "user_notes": self.user_notes,
            "data_name": self.data_name,
            "log_like_name": self.log_like_name,
            "number_of_observations": self.number_of_observations,
            "beta_names": self.beta_names,
            "sampler": self.sampler,
            "chains": self.chains,
            "draws": self.draws,
            "target_accept": self.target_accept,
            "run_time_seconds": (
                self.run_time.total_seconds() if self.run_time is not None else None
            ),
        }
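
# Reporting sketch (illustrative, continuing the toy example above):
# ``to_dict`` flattens the metadata and sampling dimensions for quick
# logging or tabulation.
#
#   import pprint
#   pprint.pprint(loaded.to_dict())
#   # e.g. chains == 2 and draws == 500 for the toy posterior above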