"""
Raw Bayesian estimation results
built from ArviZ InferenceData (PyMC).
Michel Bierlaire
Mon Oct 20 2025, 17:18:07
"""
from __future__ import annotations
import contextlib
import logging
from datetime import timedelta
import arviz as az
import xarray as xr
from biogeme.tools import print_file_size, timeit
logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
# NetCDF-only RawBayesianResults class
# -----------------------------------------------------------------------------
class RawBayesianResults:
    """
    Minimal, NetCDF-only container of Bayesian estimation results.

    This class *holds* an ArviZ InferenceData (PyMC posterior, etc.) and a
    handful of metadata that cannot be robustly deduced from the InferenceData.

    - No YAML sidecar is produced.
    - All information is stored in a single NetCDF file via :meth:`save`.
    - To reload, use :meth:`load` which reads both posterior and metadata
      from the same NetCDF file.

    Stored metadata (beyond what can be inferred from idata):

    * model_name (str)
    * log_like_name (str)
    * number_of_observations (int)
    * user_notes (str)
    * data_name (str)
    * beta_names (list[str])  # model free/fixed parameter names for reporting
    * sampler (str | None)
    * target_accept (float | None)
    * run_time (timedelta | None)

    The number of chains and draws is not stored; it is read from the
    ``posterior`` group of the InferenceData (see :attr:`chains` and
    :attr:`draws`).
    """

    # Name of the former custom metadata group. It is not referenced by any
    # method of this class: metadata are stored in ``idata.posterior.attrs``.
    _META_GROUP = 'biogeme_meta'

    def __init__(
        self,
        *,
        idata: az.InferenceData,
        model_name: str,
        log_like_name: str,
        number_of_observations: int,
        user_notes: str = '',
        data_name: str = '',
        beta_names: list[str] | None = None,
        sampler: str | None = None,
        target_accept: float | None = None,
        run_time: timedelta | None = None,
    ) -> None:
        """
        Store the inference data and its accompanying metadata.

        All arguments are keyword-only.

        :param idata: ArviZ InferenceData holding the estimation output.
        :param model_name: name of the model.
        :param log_like_name: name associated with the log likelihood.
        :param number_of_observations: number of observations in the sample.
        :param user_notes: free-text notes.
        :param data_name: name of the data set.
        :param beta_names: parameter names for reporting; ``None`` is stored
            as an empty list.
        :param sampler: name of the sampler, if known.
        :param target_accept: target acceptance rate, if known.
        :param run_time: duration of the estimation, if known.
        """
        self._idata = idata
        self._model_name = model_name
        self._log_like_name = log_like_name
        self._user_notes = user_notes
        self._data_name = data_name
        self._beta_names = beta_names or []
        self._sampler = sampler
        self._target_accept = target_accept
        self._run_time = run_time
        self._number_of_observations = number_of_observations

    # ---------------- Stored metadata ----------------
    @property
    def idata(self) -> az.InferenceData:
        """The InferenceData object held by this container."""
        return self._idata

    @property
    def model_name(self) -> str:
        """Name of the model."""
        return self._model_name

    @property
    def log_like_name(self) -> str:
        """Name associated with the log likelihood."""
        return self._log_like_name

    @property
    def user_notes(self) -> str:
        """Free-text notes."""
        return self._user_notes

    @property
    def data_name(self) -> str:
        """Name of the data set."""
        return self._data_name

    @property
    def beta_names(self) -> list[str]:
        """Parameter names, returned as a new list on each access."""
        return list(self._beta_names)

    @property
    def sampler(self) -> str | None:
        """Name of the sampler, or ``None`` if unknown."""
        return self._sampler

    @property
    def target_accept(self) -> float | None:
        """Target acceptance rate, or ``None`` if unknown."""
        return self._target_accept

    @property
    def run_time(self) -> timedelta | None:
        """Duration of the estimation, or ``None`` if unknown."""
        return self._run_time

    @property
    def number_of_observations(self) -> int:
        """Number of observations in the sample."""
        return self._number_of_observations

    # ---------------- Properties inferred from idata ----------------
    @property
    def chains(self) -> int:
        """
        Size of the ``chain`` dimension of the posterior group.

        Returns 1 if the posterior group or that dimension is unavailable.
        """
        try:
            return int(self._idata.posterior.sizes.get('chain', 1))
        except (AttributeError, KeyError, TypeError, ValueError):
            return 1

    @property
    def draws(self) -> int:
        """
        Size of the ``draw`` dimension of the posterior group.

        Returns 0 if the posterior group or that dimension is unavailable.
        """
        try:
            return int(self._idata.posterior.sizes.get('draw', 0))
        except (AttributeError, KeyError, TypeError, ValueError):
            return 0

    # ---------------- Persistence (single NetCDF) ----------------
    def _metadata_dataset(self) -> xr.Dataset:
        """
        Build a tiny xarray Dataset to store metadata as attributes.

        Encoding used here: a missing sampler becomes ``''``; a missing
        target acceptance rate or run time becomes NaN; the parameter names
        are stored as a list. Note that :meth:`save` does not call this
        method: it relies on the JSON-based encoding of
        :meth:`_posterior_attrs`.
        """
        ds = xr.Dataset()
        ds.attrs['model_name'] = self._model_name
        ds.attrs['user_notes'] = self._user_notes
        ds.attrs['data_name'] = self._data_name
        ds.attrs['log_like_name'] = self._log_like_name
        ds.attrs['number_of_observations'] = self._number_of_observations
        ds.attrs['beta_names'] = list(self._beta_names)
        ds.attrs['sampler'] = self._sampler if self._sampler is not None else ''
        ds.attrs['target_accept'] = (
            float(self._target_accept)
            if self._target_accept is not None
            else float('nan')
        )
        ds.attrs['run_time_seconds'] = (
            float(self._run_time.total_seconds())
            if self._run_time is not None
            else float('nan')
        )
        return ds

    def _posterior_attrs(self) -> dict:
        """
        Encode the metadata as a flat dict of attributes.

        Missing values are encoded as the empty string, and the parameter
        names are encoded as a JSON string. :meth:`load` decodes this format.
        """
        import json

        return {
            'model_name': self._model_name or '',
            'user_notes': self._user_notes or '',
            'data_name': self._data_name or '',
            'log_like_name': self._log_like_name or '',
            'number_of_observations': self._number_of_observations or '',
            'beta_names': json.dumps(self._beta_names or []),
            'sampler': self._sampler or '',
            'target_accept': (
                self._target_accept if self._target_accept is not None else ''
            ),
            'run_time_seconds': (
                self._run_time.total_seconds() if self._run_time is not None else ''
            ),
        }

    def save(self, path: str) -> None:
        """
        Write a single NetCDF file with posterior + metadata.

        The metadata are attached as attributes of the ``posterior`` group,
        so that they survive a round trip through ``az.from_netcdf``. If they
        cannot be attached, a warning is logged and the file is written
        without them.

        :param path: name of the NetCDF file to write.
        """
        logger.info('Save simulation results on %s', path)
        # The attributes are written on a copy rather than on the held
        # object (presumably so that the caller's InferenceData is left
        # untouched; this depends on the semantics of ``copy``).
        idata = self._idata.copy()
        try:
            idata.posterior.attrs.update(self._posterior_attrs())
        except (AttributeError, KeyError, TypeError, ValueError) as e:
            logger.warning('Could not set posterior attrs metadata cleanly: %s', e)
        idata.to_netcdf(path, engine='h5netcdf')
        logger.info('Saved Bayesian results (posterior + metadata) to %s', path)

    @staticmethod
    def _detach_from_disk(idata: az.InferenceData) -> None:
        """
        Best-effort: load every group of ``idata`` and close its backend.

        Failures are logged at debug level and otherwise ignored.
        """
        try:
            for group_name in idata.groups():
                dataset = getattr(idata, group_name, None)
                if dataset is None:
                    continue
                # First make sure the arrays are in memory (not lazy on-disk),
                # then release backend resources if supported.
                for method_name in ('load', 'close'):
                    try:
                        getattr(dataset, method_name)()
                    except Exception as e:
                        logger.debug(
                            'Could not %s group %s: %s', method_name, group_name, e
                        )
        except Exception as e:
            # If anything goes wrong, keep behavior backward compatible.
            logger.debug('Could not detach InferenceData from disk: %s', e)

    @staticmethod
    def _metadata_from_attrs(attrs: dict) -> dict:
        """
        Decode metadata attributes into constructor keyword arguments.

        Missing or undecodable entries fall back to defaults: empty strings,
        0 observations, an empty list of parameter names, and ``None`` for the
        sampler, the target acceptance rate and the run time. For the two
        numerical optional values, ``None``, ``''`` and NaN are all treated
        as "missing".

        :param attrs: mapping of attribute names to raw values.
        :return: keyword arguments for the constructor (``idata`` excluded).
        """
        import json
        import math

        def is_missing(value: object) -> bool:
            """Identify the encodings of "no value": None, '' and NaN."""
            if value is None:
                return True
            if isinstance(value, str):
                return value == ''
            try:
                # NaN never compares equal to anything, so it must be
                # detected explicitly rather than by membership or equality.
                return math.isnan(value)
            except (TypeError, ValueError, OverflowError):
                return False

        # number_of_observations may come as str, int, or be missing
        try:
            number_of_observations = int(attrs.get('number_of_observations', 0))
        except (TypeError, ValueError, OverflowError):
            number_of_observations = 0

        # beta_names stored as JSON string
        try:
            decoded_names = json.loads(attrs.get('beta_names', '[]'))
        except (TypeError, ValueError):
            decoded_names = None
        beta_names: list[str] = (
            decoded_names if isinstance(decoded_names, list) else []
        )

        raw_target_accept = attrs.get('target_accept')
        try:
            target_accept = (
                None if is_missing(raw_target_accept) else float(raw_target_accept)
            )
        except (TypeError, ValueError, OverflowError):
            target_accept = None

        raw_run_time = attrs.get('run_time_seconds')
        try:
            run_time = (
                None
                if is_missing(raw_run_time)
                else timedelta(seconds=float(raw_run_time))
            )
        except (TypeError, ValueError, OverflowError):
            run_time = None

        return {
            'model_name': attrs.get('model_name', ''),
            'user_notes': attrs.get('user_notes', ''),
            'log_like_name': attrs.get('log_like_name', ''),
            'number_of_observations': number_of_observations,
            'data_name': attrs.get('data_name', ''),
            'beta_names': beta_names,
            # An empty string encodes "no sampler".
            'sampler': attrs.get('sampler') or None,
            'target_accept': target_accept,
            'run_time': run_time,
        }

    @classmethod
    @timeit(label='load')
    def load(cls, path: str) -> RawBayesianResults:
        """
        Load from a single NetCDF file written by :meth:`save`.

        Metadata are read from ``idata.posterior.attrs``, where they were
        stored by :meth:`save`. No custom ``biogeme_meta`` group is used
        anymore. If the posterior group is missing, default metadata are used.

        :param path: name of the NetCDF file to read.
        :return: the results object, with its data loaded in memory.
        """
        logger.debug('Read file %s', path)

        # On Windows, NetCDF backends may keep the file handle open via xarray's
        # file-manager cache. We therefore (i) reduce caching when possible,
        # (ii) eagerly load all groups into memory, and (iii) close datasets.
        # Some xarray versions reject file_cache_maxsize=0, hence the value 1;
        # if even that is rejected, we fall back to a no-op context.
        try:
            cache_ctx = xr.set_options(file_cache_maxsize=1)
        except Exception:
            cache_ctx = contextlib.nullcontext()

        with cache_ctx:
            idata = az.from_netcdf(path, engine='h5netcdf')
            # Detach from disk: load all datasets and close any open file
            # handles. This is meant to make it safe to delete the NetCDF
            # file immediately after loading.
            cls._detach_from_disk(idata)

        # Best-effort: clear xarray's global file cache to ensure no lingering
        # handles.
        try:
            from xarray.backends.file_manager import FILE_CACHE

            FILE_CACHE.clear()
        except Exception as e:
            logger.debug('Could not clear the xarray file cache: %s', e)

        logger.info('Loaded NetCDF file size: %s', print_file_size(path))

        # Read from posterior attrs if available
        try:
            posterior_attrs = idata.posterior.attrs
        except AttributeError as e:
            logger.info(
                'Posterior group missing or invalid in InferenceData loaded from %s: %s',
                path,
                e,
            )
            posterior_attrs = {}

        return cls(idata=idata, **cls._metadata_from_attrs(posterior_attrs))

    # Convenience dict for quick reporting
    def to_dict(self) -> dict:
        """
        Return the metadata, plus the number of chains and draws, as a dict.

        The run time is reported in seconds under the key
        ``run_time_seconds`` (``None`` if unknown).
        """
        return {
            'model_name': self.model_name,
            'user_notes': self.user_notes,
            'data_name': self.data_name,
            'log_like_name': self.log_like_name,
            'number_of_observations': self.number_of_observations,
            'beta_names': self.beta_names,
            'sampler': self.sampler,
            'chains': self.chains,
            'draws': self.draws,
            'target_accept': self.target_accept,
            'run_time_seconds': (
                self.run_time.total_seconds() if self.run_time is not None else None
            ),
        }