"""Source code for biogeme.profiling.environment: JAX runtime environment reporting."""
from __future__ import annotations
import os
from typing import Any
import jax
[docs]
# Environment variables copied verbatim into the snapshot, in report order.
_REPORTED_ENV_VARS: tuple[str, ...] = (
    "XLA_FLAGS",
    "OMP_NUM_THREADS",
    "MKL_NUM_THREADS",
    "OPENBLAS_NUM_THREADS",
    "NUMEXPR_NUM_THREADS",
)


def report_jax_environment() -> dict[str, Any]:
    """Return a snapshot of the JAX runtime and thread-related environment.

    The snapshot contains, in this order:

    * ``backend``: the value returned by :func:`jax.default_backend`.
    * ``device_count``: the value returned by :func:`jax.device_count`.
    * ``local_device_count``: the value returned by
      :func:`jax.local_device_count`.
    * ``devices``: ``str()`` of every device returned by :func:`jax.devices`.
    * ``jax_enable_x64``: the current value of the ``jax_enable_x64`` flag.
    * ``XLA_FLAGS``, ``OMP_NUM_THREADS``, ``MKL_NUM_THREADS``,
      ``OPENBLAS_NUM_THREADS``, ``NUMEXPR_NUM_THREADS``: the raw string
      value of each environment variable, or ``None`` when it is unset.

    :return: a new dictionary with the keys listed above.
    """
    snapshot: dict[str, Any] = {
        "backend": jax.default_backend(),
        "device_count": jax.device_count(),
        "local_device_count": jax.local_device_count(),
        "devices": [str(device) for device in jax.devices()],
        # Attribute access is the supported way to read this flag:
        # ``jax.config.read`` rejects flags that have a context manager
        # (``jax_enable_x64`` is one) with AttributeError on JAX 0.4.x.
        "jax_enable_x64": jax.config.jax_enable_x64,
    }
    # Unset variables are reported as None rather than omitted.
    snapshot.update({name: os.environ.get(name) for name in _REPORTED_ENV_VARS})
    return snapshot