"""
Defines the strategy for sampling the MCMC based on hardware configuration and user's preferences.
Michel Bierlaire
Mon Oct 27 2025, 17:01:13
"""
from __future__ import annotations
import logging
from dataclasses import dataclass
from typing import Any
# Module-level logger used by the factory below to report which backend was selected.
logger = logging.getLogger(__name__)

# Short human-readable summary of each explicit sampler strategy, keyed by the
# strategy string. The insertion order is the order in which
# ``describe_strategies`` lists the entries.
SAMPLER_STRATEGIES_DESCRIPTION = {
    "numpyro-parallel": "one chain per device",
    "numpyro-vectorized": "all chains on one device",
    "pymc": "default PyMC sampler on CPU",
}
# [docs]  (Sphinx source-link residue from HTML extraction; not part of the code)
def describe_strategies() -> str:
    """
    Build a one-line, human-readable list of the documented sampler strategies.

    Every entry of ``SAMPLER_STRATEGIES_DESCRIPTION`` is rendered as
    ``'name' (description)``; the entries are joined with ``', '`` in the
    dictionary's iteration order.

    :returns: The comma-separated description of the strategies.
    :rtype: str
    """
    rendered_entries = (
        f"'{strategy_name}' ({summary})"
        for strategy_name, summary in SAMPLER_STRATEGIES_DESCRIPTION.items()
    )
    return ', '.join(rendered_entries)
# [docs]  (Sphinx source-link residue from HTML extraction; not part of the code)
@dataclass(frozen=True)
class SamplingConfig:
    """
    Immutable bundle of sampler settings produced by ``make_sampling_config``.

    Instances are frozen: attributes cannot be reassigned after construction.
    The field order below is part of the positional constructor signature.
    """

    # Sampling backend identifier; the factory in this module uses "numpyro" or "pymc".
    backend: str
    # NumPyro chain method ("parallel" or "vectorized"); ``None`` for the PyMC backend.
    chain_method: str | None
    # Number of CPU cores for the PyMC backend; ``None`` for the NumPyro backend.
    cores: int | None
    # Target acceptance rate, stored exactly as supplied by the caller.
    target_accept: float
    # Initialization method; the factory in this module always leaves it ``None``.
    # NOTE(review): presumably forwarded to the sampler by the consumer - confirm.
    init: str | None
    # Maximum tree depth; the factory in this module always leaves it ``None``.
    # NOTE(review): presumably a NUTS setting read by the consumer - confirm.
    max_treedepth: int | None
    # Extra NUTS keyword arguments; the factory in this module always leaves it ``None``.
    nuts_kwargs: dict[str, Any] | None
# ------------------------- Hardware helpers (single responsibility) -------------------------
def _jax_available() -> bool:
    """
    Report whether the JAX/NumPyro sampling stack can be imported.

    The check attempts to import ``jax``, then ``sample_numpyro_nuts`` from
    ``pymc.sampling.jax``, then ``numpyro``; the first failing import ends the
    check.

    :returns: ``True`` if all three imports succeed, ``False`` if any of them raises ``ImportError``.
    :rtype: bool
    """
    try:
        import jax  # noqa: F401
        from pymc.sampling.jax import sample_numpyro_nuts  # noqa: F401
        import numpyro  # noqa: F401
    except ImportError:
        # ``ModuleNotFoundError`` is a subclass of ``ImportError``, so it is covered here too.
        importable = False
    else:
        importable = True
    return importable
def _jax_devices_summary() -> tuple[int, list[str]]:
    """
    Summarize the devices reported by JAX.

    :returns: A pair ``(n_devices, platforms)``: the number of entries returned by ``jax.devices()`` and the sorted list of their distinct ``platform`` attributes (``'unknown'`` for a device lacking that attribute). ``(0, [])`` when JAX cannot be imported.
    :rtype: tuple[int, list[str]]
    """
    try:
        import jax
    except ImportError:
        # ``ModuleNotFoundError`` is a subclass of ``ImportError``, so it is covered here too.
        return 0, []

    detected = jax.devices()
    # Collapse duplicates first so that each platform name is listed once.
    distinct_platforms = {getattr(device, "platform", "unknown") for device in detected}
    return len(detected), sorted(distinct_platforms)
def _cpu_core_count() -> int:
    """
    Best-effort CPU core count.

    :returns: The value of ``os.cpu_count()``, or ``1`` when that call returns ``None`` (count undetermined).
    :rtype: int
    """
    # Imported locally because the module does not import ``os`` at top level.
    # ``os`` is part of the standard library, so no ImportError guard is needed;
    # the real fallback is the ``or 1`` below.
    import os

    return os.cpu_count() or 1
def _select_chain_method(n_dev: int) -> str:
    """
    Choose the NumPyro chain method from the number of devices.

    :param n_dev: Number of available JAX devices.
    :type n_dev: int
    :returns: ``'parallel'`` when ``n_dev`` is greater than one, ``'vectorized'`` otherwise.
    :rtype: str
    """
    if n_dev > 1:
        return "parallel"
    return "vectorized"
# ------------------------- Public factory -------------------------
# [docs]  (Sphinx source-link residue from HTML extraction; not part of the code)
def make_sampling_config(
    strategy: str,
    target_accept: float,
) -> SamplingConfig:
    """
    Create a :class:`SamplingConfig` from a short strategy string.

    The strategy is matched after stripping surrounding whitespace and
    lower-casing; a falsy value (``None`` or ``''``) is treated as an empty
    string and therefore rejected.

    * ``'automatic'``: NumPyro when JAX/NumPyro is importable (``'parallel'``
      chains if more than one device is detected, ``'vectorized'`` otherwise),
      else PyMC with the detected CPU core count.
    * ``'numpyro-parallel'`` / ``'numpyro-vectorized'``: NumPyro with the
      requested chain method; if JAX/NumPyro is not importable, a warning is
      logged and a PyMC configuration is returned instead.
    * ``'pymc'``: PyMC with the detected CPU core count.

    In every case ``init``, ``max_treedepth`` and ``nuts_kwargs`` are ``None``.

    :param strategy: One of ``'automatic'``, ``'numpyro-parallel'``, ``'numpyro-vectorized'``, or ``'pymc'``.
    :type strategy: str
    :param target_accept: Target acceptance rate, stored unchanged in the result.
    :type target_accept: float
    :returns: A ready-to-use configuration object.
    :rtype: SamplingConfig
    :raises ValueError: If ``strategy`` is not one of the allowed values.
    """

    def _build(backend: str, chain_method: str | None, cores: int | None) -> SamplingConfig:
        # Only these three fields vary between strategies; the rest are fixed here.
        return SamplingConfig(
            backend=backend,
            chain_method=chain_method,
            cores=cores,
            target_accept=target_accept,
            init=None,
            max_treedepth=None,
            nuts_kwargs=None,
        )

    normalized = (strategy or "").strip().lower()

    if normalized == "pymc":
        return _build("pymc", None, _cpu_core_count())

    # Explicit NumPyro requests: chain method plus the warning emitted on fallback.
    explicit_numpyro = {
        "numpyro-parallel": (
            "parallel",
            "Requested numpyro-parallel but JAX/NumPyro not available; falling back to PyMC.",
        ),
        "numpyro-vectorized": (
            "vectorized",
            "Requested numpyro-vectorized but JAX/NumPyro not available; falling back to PyMC.",
        ),
    }
    if normalized in explicit_numpyro:
        requested_method, fallback_warning = explicit_numpyro[normalized]
        if _jax_available():
            return _build("numpyro", requested_method, None)
        logger.warning(fallback_warning)
        return _build("pymc", None, _cpu_core_count())

    if normalized == "automatic":
        if not _jax_available():
            # Fallback to PyMC with a sensible cores default
            detected_cores = _cpu_core_count()
            logger.info(
                "Auto sampling: JAX not available → PyMC (cores=%s)", detected_cores
            )
            return _build("pymc", None, detected_cores)

        # JAX/NumPyro is preferred: several devices → parallel chains, else vectorized.
        device_count, platform_names = _jax_devices_summary()
        auto_method = _select_chain_method(device_count or 1)
        logger.info(
            "Auto sampling: JAX available (devices=%s, platforms=%s) → numpyro/%s",
            device_count,
            ",".join(platform_names) if platform_names else "unknown",
            auto_method,
        )
        return _build("numpyro", auto_method, None)

    raise ValueError(
        f"Unknown sampling strategy: {strategy!r}. Allowed values are: 'automatic', 'numpyro-parallel', 'numpyro-vectorized', 'pymc'."
    )