"""Source code for biogeme.profiling.timing."""

from __future__ import annotations

from contextlib import AbstractContextManager
from dataclasses import dataclass
from time import perf_counter
from typing import Any, Callable, Generic, TypeVar

import jax

T = TypeVar("T")


def block_until_ready(value: Any) -> None:
    """Wait for every JAX computation reachable from ``value`` to finish.

    JAX dispatches work asynchronously, so a timer stopped right after a call
    might only measure dispatch. This helper walks tuples, lists and dict
    values recursively and synchronizes every remaining leaf with
    :func:`jax.block_until_ready` before returning.

    Args:
        value: Any object. ``None`` is ignored. Leaves for which
            :func:`jax.block_until_ready` raises ``TypeError`` or
            ``AttributeError`` are skipped silently.
    """
    if value is None:
        return

    if isinstance(value, (tuple, list)):
        children = value
    elif isinstance(value, dict):
        # Only the values can hold device arrays; keys are ignored.
        children = value.values()
    else:
        try:
            jax.block_until_ready(value)
        except (TypeError, AttributeError):
            # Plain Python objects need no synchronization.
            pass
        return

    for child in children:
        block_until_ready(child)
def timed_call(
    function: Callable[..., T], *args: Any, **kwargs: Any
) -> tuple[T, float]:
    """Run ``function(*args, **kwargs)`` and measure its wall-clock duration.

    The result is passed through :func:`block_until_ready` before the clock
    is stopped, so asynchronous JAX work is counted in the measurement.

    Args:
        function: The callable to time.
        *args: Positional arguments forwarded to ``function``.
        **kwargs: Keyword arguments forwarded to ``function``.

    Returns:
        A ``(result, elapsed_seconds)`` pair. The elapsed time is the
        difference of two :func:`time.perf_counter` readings.
    """
    t0 = perf_counter()
    outcome = function(*args, **kwargs)
    # Without this wait, only the asynchronous dispatch would be timed.
    block_until_ready(outcome)
    return outcome, perf_counter() - t0
@dataclass
class TimedBlock(AbstractContextManager["TimedBlock"], Generic[T]):
    """Context manager that records the wall-clock duration of a ``with`` body.

    ``elapsed`` is written on exit, including when the body raises.
    Exceptions are never suppressed. Before the first exit, ``elapsed`` keeps
    its default of ``0.0``.

    Attributes:
        label: Optional free-form name for the timed region. The class itself
            only stores it.
        elapsed: Seconds spent inside the most recent ``with`` block.
        start_time: The :func:`time.perf_counter` reading taken on entry.
    """

    label: str | None = None
    elapsed: float = 0.0
    start_time: float = 0.0

    def __enter__(self) -> "TimedBlock":
        self.start_time = perf_counter()
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        finished_at = perf_counter()
        self.elapsed = finished_at - self.start_time
        # A falsy return value lets any in-flight exception propagate.
        return None