# Source code for biogeme.bayesian_estimation.raw_bayesian_results

"""
Raw Bayesian estimation results
built from ArviZ InferenceData (PyMC).

Michel Bierlaire
Mon Oct 20 2025, 17:18:07
"""

from __future__ import annotations

import contextlib
import logging
from datetime import timedelta

import arviz as az
import xarray as xr

from biogeme.tools import print_file_size, timeit

logger = logging.getLogger(__name__)


# -----------------------------------------------------------------------------
# NetCDF-only RawBayesianResults class
# -----------------------------------------------------------------------------
class RawBayesianResults:
    """
    Minimal, NetCDF-only container of Bayesian estimation results.

    This class *holds* an ArviZ InferenceData (PyMC posterior, etc.) and a
    handful of metadata that cannot be robustly deduced from the
    InferenceData.

    - No YAML sidecar is produced.
    - All information is stored in a single NetCDF file via :meth:`save`.
    - To reload, use :meth:`load` which reads both posterior and metadata
      from the same NetCDF file.

    Stored metadata (beyond what can be inferred from idata):

      * model_name (str)
      * log_like_name (str)
      * number_of_observations (int)
      * user_notes (str)
      * data_name (str)
      * beta_names (list[str])  # model free/fixed parameter names for reporting
      * sampler (str | None)
      * target_accept (float | None)
      * run_time (timedelta | None)
    """

    # Name of the legacy custom metadata group. Neither :meth:`save` nor
    # :meth:`load` uses it: metadata travel in ``idata.posterior.attrs``.
    _META_GROUP = 'biogeme_meta'

    def __init__(
        self,
        *,
        idata: az.InferenceData,
        model_name: str,
        log_like_name: str,
        number_of_observations: int,
        user_notes: str = '',
        data_name: str = '',
        beta_names: list[str] | None = None,
        sampler: str | None = None,
        target_accept: float | None = None,
        run_time: timedelta | None = None,
    ) -> None:
        """Store the inference data and its accompanying metadata.

        All arguments are keyword-only.

        :param idata: ArviZ InferenceData holding the estimation output.
        :param model_name: name of the model.
        :param log_like_name: name of the log likelihood quantity.
        :param number_of_observations: number of observations.
        :param user_notes: free-text notes.
        :param data_name: name of the data set.
        :param beta_names: parameter names; ``None`` is stored as an empty list.
        :param sampler: name of the sampler, if known.
        :param target_accept: target acceptance rate, if known.
        :param run_time: duration of the run, if known.
        """
        self._idata = idata
        self._model_name = model_name
        self._log_like_name = log_like_name
        self._user_notes = user_notes
        self._data_name = data_name
        self._beta_names = beta_names or []
        self._sampler = sampler
        self._target_accept = target_accept
        self._run_time = run_time
        self._number_of_observations = number_of_observations

    # ---------------- Properties ----------------
    @property
    def idata(self) -> az.InferenceData:
        """The wrapped InferenceData object."""
        return self._idata

    @property
    def model_name(self) -> str:
        """Name of the model."""
        return self._model_name

    @property
    def log_like_name(self) -> str:
        """Name of the log likelihood quantity."""
        return self._log_like_name

    @property
    def user_notes(self) -> str:
        """Free-text notes."""
        return self._user_notes

    @property
    def data_name(self) -> str:
        """Name of the data set."""
        return self._data_name

    @property
    def beta_names(self) -> list[str]:
        """A copy of the parameter names, so callers cannot mutate the original."""
        return list(self._beta_names)

    @property
    def sampler(self) -> str | None:
        """Name of the sampler, or ``None``."""
        return self._sampler

    @property
    def target_accept(self) -> float | None:
        """Target acceptance rate, or ``None``."""
        return self._target_accept

    @property
    def run_time(self) -> timedelta | None:
        """Duration of the run, or ``None``."""
        return self._run_time

    @property
    def chains(self) -> int:
        """Size of the ``chain`` dimension of the posterior (1 if unavailable)."""
        try:
            return int(self._idata.posterior.sizes.get('chain', 1))
        except (AttributeError, KeyError, TypeError, ValueError):
            return 1

    @property
    def draws(self) -> int:
        """Size of the ``draw`` dimension of the posterior (0 if unavailable)."""
        try:
            return int(self._idata.posterior.sizes.get('draw', 0))
        except (AttributeError, KeyError, TypeError, ValueError):
            return 0

    @property
    def number_of_observations(self) -> int:
        """Number of observations."""
        return self._number_of_observations

    # ---------------- Persistence (single NetCDF) ----------------
    def _metadata_dataset(self) -> xr.Dataset:
        """Build a tiny xarray Dataset to store metadata as attributes.

        Missing optional values are encoded as an empty string (sampler) or
        NaN (target_accept, run_time_seconds). This encoding differs from the
        one written by :meth:`save`, which does not call this method.

        :return: an empty dataset whose ``attrs`` carry the metadata.
        """
        ds = xr.Dataset()
        ds.attrs['model_name'] = self._model_name
        ds.attrs['user_notes'] = self._user_notes
        ds.attrs['data_name'] = self._data_name
        ds.attrs['log_like_name'] = self._log_like_name
        ds.attrs['number_of_observations'] = self._number_of_observations
        ds.attrs['beta_names'] = list(self._beta_names)
        ds.attrs['sampler'] = self._sampler if self._sampler is not None else ''
        ds.attrs['target_accept'] = (
            float(self._target_accept)
            if self._target_accept is not None
            else float('nan')
        )
        ds.attrs['run_time_seconds'] = (
            float(self._run_time.total_seconds())
            if self._run_time is not None
            else float('nan')
        )
        return ds

    def _posterior_metadata_attrs(self) -> dict:
        """Build the attribute dict that :meth:`save` mirrors on the posterior.

        Missing or falsy values are written as empty strings, and
        ``beta_names`` is JSON-encoded into a single string.

        :return: mapping from attribute name to value.
        """
        import json

        return {
            'model_name': self._model_name or '',
            'user_notes': self._user_notes or '',
            'data_name': self._data_name or '',
            'log_like_name': self._log_like_name or '',
            'number_of_observations': self._number_of_observations or '',
            'beta_names': json.dumps(self._beta_names or []),
            'sampler': self._sampler or '',
            'target_accept': (
                self._target_accept if self._target_accept is not None else ''
            ),
            'run_time_seconds': (
                self._run_time.total_seconds() if self._run_time is not None else ''
            ),
        }

    def save(self, path: str) -> None:
        """Write a single NetCDF file with posterior + metadata.

        The metadata are attached to the attributes of the posterior group of
        a copy of the inference data, so the object held by this instance is
        left untouched. If attaching the metadata fails, a warning is logged
        and the file is still written.

        :param path: name of the NetCDF file to write.
        """
        logger.info('Save simulation results on %s', path)
        idata = self._idata.copy()

        # Mirror metadata on a standard group so it survives az.from_netcdf
        try:
            idata.posterior.attrs.update(self._posterior_metadata_attrs())
        except (AttributeError, KeyError, TypeError, ValueError) as e:
            logger.warning('Could not set posterior attrs metadata cleanly: %s', e)

        idata.to_netcdf(path, engine='h5netcdf')
        logger.info('Saved Bayesian results (posterior + metadata) to %s', path)

    @staticmethod
    def _detach_from_disk(idata: az.InferenceData) -> None:
        """Best-effort: load every group into memory and close it.

        Failures are logged at debug level and never propagated.

        :param idata: inference data whose groups must be detached.
        """
        try:
            for group_name in idata.groups():
                ds = getattr(idata, group_name, None)
                if ds is None:
                    continue
                # Ensure arrays are in memory (not lazy on-disk)
                try:
                    ds.load()
                except Exception as e:
                    logger.debug('Could not load group %s: %s', group_name, e)
                # Close backend resources if supported
                try:
                    ds.close()
                except Exception as e:
                    logger.debug('Could not close group %s: %s', group_name, e)
        except Exception as e:
            # If anything goes wrong, keep behavior backward compatible.
            logger.debug('Could not detach inference data from disk: %s', e)

    @staticmethod
    def _optional_float(value: object) -> float | None:
        """Convert a stored attribute to a float, or ``None`` if it is missing.

        ``None``, an empty string, a NaN, or anything that ``float`` rejects
        is treated as missing.

        :param value: raw attribute value.
        :return: the float value, or ``None``.
        """
        import math

        if value is None:
            return None
        try:
            number = float(value)
        except (TypeError, ValueError):
            return None
        # NaN never compares equal to itself, so it needs an explicit test.
        return None if math.isnan(number) else number

    @classmethod
    @timeit(label='load')
    def load(cls, path: str) -> RawBayesianResults:
        """
        Load from a single NetCDF file written by :meth:`save`.

        Metadata are read from ``idata.posterior.attrs``, where they were
        stored by :meth:`save`. No custom ``biogeme_meta`` group is used
        anymore. Missing or unreadable metadata fall back to defaults
        (empty strings, 0, an empty list, or ``None``).

        :param path: name of the NetCDF file to read.
        :return: the reconstructed results object.
        """
        import json

        logger.debug('Read file %s', path)

        # On Windows, NetCDF backends may keep the file handle open via xarray's
        # file-manager cache. We therefore (i) reduce/disable caching when possible,
        # (ii) eagerly load all groups into memory, and (iii) close datasets.
        # Some xarray versions reject file_cache_maxsize=0, so we fall back safely.
        try:
            cache_ctx = xr.set_options(file_cache_maxsize=1)
        except Exception:
            cache_ctx = contextlib.nullcontext()

        with cache_ctx:
            idata = az.from_netcdf(path, engine='h5netcdf')
            # Detach from disk: load all datasets and close any open file handles.
            # This makes it safe to delete the NetCDF file immediately after loading.
            cls._detach_from_disk(idata)

        # Best-effort: clear xarray's global file cache to ensure no lingering handles.
        try:
            from xarray.backends.file_manager import FILE_CACHE

            FILE_CACHE.clear()
        except Exception as e:
            logger.debug('Could not clear the xarray file cache: %s', e)

        logger.info('Loaded NetCDF file size: %s', print_file_size(path))

        # Read from posterior attrs if available
        try:
            p_attrs = idata.posterior.attrs
        except AttributeError as e:
            logger.info(
                'Posterior group missing or invalid in InferenceData loaded from %s: %s',
                path,
                e,
            )
            p_attrs = {}

        model_name = p_attrs.get('model_name', '')
        user_notes = p_attrs.get('user_notes', '')
        data_name = p_attrs.get('data_name', '')
        log_like_name = p_attrs.get('log_like_name', '')

        # number_of_observations may come as str, int, or be missing
        try:
            number_of_observations = int(p_attrs.get('number_of_observations', 0))
        except (TypeError, ValueError):
            number_of_observations = 0

        # beta_names stored as JSON string
        try:
            decoded_names = json.loads(p_attrs.get('beta_names', '[]'))
        except (TypeError, ValueError):
            decoded_names = []
        # Valid JSON may still decode to something that is not a list.
        beta_names: list[str] = (
            decoded_names if isinstance(decoded_names, list) else []
        )

        sampler: str | None = p_attrs.get('sampler') or None
        target_accept = cls._optional_float(p_attrs.get('target_accept'))

        run_time_seconds = cls._optional_float(p_attrs.get('run_time_seconds'))
        try:
            run_time = (
                timedelta(seconds=run_time_seconds)
                if run_time_seconds is not None
                else None
            )
        except (TypeError, ValueError, OverflowError):
            # timedelta raises OverflowError for infinite or out-of-range values.
            run_time = None

        return cls(
            idata=idata,
            model_name=model_name,
            user_notes=user_notes,
            log_like_name=log_like_name,
            number_of_observations=number_of_observations,
            data_name=data_name,
            beta_names=beta_names,
            sampler=sampler,
            target_accept=target_accept,
            run_time=run_time,
        )

    # Convenience dict for quick reporting
    def to_dict(self) -> dict:
        """Return the metadata, plus chain and draw counts, as a plain dict.

        The run time is reported in seconds under ``run_time_seconds``
        (``None`` when unknown).
        """
        return {
            'model_name': self.model_name,
            'user_notes': self.user_notes,
            'data_name': self.data_name,
            'log_like_name': self.log_like_name,
            'number_of_observations': self.number_of_observations,
            'beta_names': self.beta_names,
            'sampler': self.sampler,
            'chains': self.chains,
            'draws': self.draws,
            'target_accept': self.target_accept,
            'run_time_seconds': (
                self.run_time.total_seconds() if self.run_time is not None else None
            ),
        }