Source code for biogeme.tools.jax_multicore

#!/usr/bin/env python3
import os
import platform
import warnings


def count_cpu_devices() -> int:
    """Return the number of CPU devices that JAX exposes.

    The CPU backend is queried explicitly via ``jax.devices("cpu")``. A bare
    ``jax.devices()`` only lists devices of JAX's *default* backend, so on a
    machine where an accelerator backend is the default it contains no CPU
    devices and the count would wrongly come out as 0.

    Returns:
        Number of CPU devices visible to JAX.
    """
    # Imported lazily so that importing this module does not import JAX.
    import jax

    return len(jax.devices("cpu"))
def _platform_quick_command() -> str:
    """Build copy-paste instructions for setting ``XLA_FLAGS`` on this OS.

    On Windows the text covers both PowerShell and cmd.exe; on any other
    system a POSIX shell (bash/zsh on macOS or Linux) is assumed. Both
    variants end with a Jupyter ``%env`` line meant for a new cell executed
    before ``import jax``. ``<number_of_cores>`` is a placeholder for the
    user to replace.
    """
    flag = "--xla_force_host_platform_device_count=<number_of_cores>"
    quoted_flag = '"' + flag + '"'
    jupyter_section = [
        "Jupyter (new cell, before `import jax`):\n",
        " %env XLA_FLAGS=" + quoted_flag + "\n",
    ]

    is_windows = platform.system().lower() == "windows"
    if not is_windows:
        # Every non-Windows OS is treated as a POSIX shell (macOS / Linux).
        shell_section = [
            "macOS / Linux (bash/zsh):\n",
            " export XLA_FLAGS=" + quoted_flag + "\n\n",
        ]
    else:
        # Windows users may be in either PowerShell or the classic cmd.exe,
        # whose `set` syntax takes the value unquoted.
        shell_section = [
            "Windows PowerShell:\n",
            " $env:XLA_FLAGS=" + quoted_flag + "\n\n",
            "Windows cmd.exe:\n",
            " set XLA_FLAGS=" + flag + "\n\n",
        ]
    return "".join(shell_section + jupyter_section)
def report_jax_cpu_devices() -> str:
    """Summarise the JAX CPU-device setup as a human-readable string.

    Returns:
        Three newline-terminated lines: the CPU-device count reported by
        ``count_cpu_devices`` next to ``os.cpu_count()`` (or ``'unknown'``),
        the current ``XLA_FLAGS`` environment variable (or ``'(none set)'``),
        and the OS name/release plus the Python version.
    """
    device_count = count_cpu_devices()
    logical_cores = os.cpu_count() or "unknown"
    xla_flags = os.environ.get("XLA_FLAGS", "(none set)")
    report = (
        f"Detected CPU devices: {device_count} | System logical cores: {logical_cores}\n"
        f"Current XLA_FLAGS: {xla_flags}\n"
        f"Platform: {platform.system()} {platform.release()} | Python: {platform.python_version()}\n"
    )
    return report
def warning_cpu_devices() -> None:
    """Warn when JAX exposes at most one CPU device.

    If ``count_cpu_devices()`` returns 0 or 1, a warning (default category
    ``UserWarning``) is issued. It states the detected count and is followed
    by the platform-specific ``XLA_FLAGS`` instructions from
    ``_platform_quick_command()``. Nothing happens when more than one CPU
    device is detected.
    """
    n_cpu_devices = count_cpu_devices()
    if n_cpu_devices > 1:
        return

    # Report the real count instead of a hard-coded "1": this branch also
    # runs when count_cpu_devices() returns 0.
    noun = "device" if n_cpu_devices == 1 else "devices"
    note = (
        f"Note: JAX currently sees {n_cpu_devices} CPU {noun}. "
        "To parallelize across CPU devices, set XLA_FLAGS as shown below "
        "and restart Python/Jupyter."
    )
    # The instructions are appended after the note, hence "shown below".
    message = "\n".join([note, _platform_quick_command()])
    # stacklevel=2 attributes the warning to the caller of this function.
    warnings.warn(message, stacklevel=2)
if __name__ == "__main__":
    # Running the module directly performs the CPU-device check, which warns
    # only when JAX sees at most one CPU device.
    warning_cpu_devices()