Source code for biogeme.profiling.jax_profile

from __future__ import annotations

from collections import defaultdict
from dataclasses import dataclass, field
from typing import Any, Callable

from .timing import timed_call


def _shape_and_dtype(value: Any) -> tuple[tuple[int, ...] | None, str | None]:
    if value is None:
        return None, None
    shape = getattr(value, "shape", None)
    dtype = getattr(value, "dtype", None)
    normalized_shape = tuple(shape) if shape is not None else None
    normalized_dtype = str(dtype) if dtype is not None else None
    return normalized_shape, normalized_dtype


def _make_signature(args: tuple[Any, ...], kwargs: dict[str, Any]) -> tuple[Any, ...]:
    """Build a deterministic signature describing one call's arguments.

    Positional arguments contribute their ``(shape, dtype)`` pair in order.
    Keyword arguments follow as ``(name, (shape, dtype))`` entries ordered by
    name, so the keyword order used at the call site does not matter.
    """
    parts: list[Any] = [_shape_and_dtype(arg) for arg in args]
    for key in sorted(kwargs):
        parts.append((key, _shape_and_dtype(kwargs[key])))
    return tuple(parts)


@dataclass
class FunctionProfile:
    """Accumulated build and call statistics for one named function.

    Instances are updated in place by ``JaxExecutionProfile``.
    """

    # Number of builds recorded for this function.
    build_count: int = 0
    # Number of timed calls recorded for this function.
    call_count: int = 0
    # Sum of the durations of all recorded calls.
    total_time: float = 0.0
    # Duration of the first recorded call; None until one is recorded.
    first_call_time: float | None = None
    # Duration of the most recent recorded call; None until one is recorded.
    last_call_time: float | None = None
    # Distinct argument signatures (shape/dtype summaries) seen so far.
    signatures: set[tuple[Any, ...]] = field(default_factory=set)

    @property
    def mean_time(self) -> float:
        """Average duration per recorded call, or 0.0 if there were none."""
        if not self.call_count:
            return 0.0
        return self.total_time / self.call_count
@dataclass
class JaxExecutionProfile:
    """Lightweight profiler for JAX-related build and execution events.

    Profiling is opt-in: while ``enabled`` is False, ``record_build`` and
    ``add_note`` do nothing and ``timed_call`` simply forwards to the wrapped
    function without recording anything.
    """

    # Master switch checked by every recording method.
    enabled: bool = False
    # Per-name statistics. The default ``defaultdict`` creates entries on first
    # use; a plain ``dict`` passed by the caller is also accepted because all
    # writes go through ``_profile``.
    functions: dict[str, FunctionProfile] = field(
        default_factory=lambda: defaultdict(FunctionProfile)
    )
    # Free-form messages collected via ``add_note``, in insertion order.
    notes: list[str] = field(default_factory=list)

    def _profile(self, name: str) -> FunctionProfile:
        """Return the statistics entry for ``name``, creating it if absent.

        ``dict.get`` never triggers a ``defaultdict`` factory, so this behaves
        the same for the default container and for a plain ``dict`` supplied
        to the constructor (which would otherwise raise ``KeyError``).
        """
        profile = self.functions.get(name)
        if profile is None:
            profile = FunctionProfile()
            self.functions[name] = profile
        return profile

    def record_build(self, name: str) -> None:
        """Increment the build counter for ``name`` (no-op when disabled)."""
        if not self.enabled:
            return
        self._profile(name).build_count += 1

    def add_note(self, message: str) -> None:
        """Append ``message`` to ``notes`` (no-op when disabled)."""
        if not self.enabled:
            return
        self.notes.append(message)

    def timed_call(
        self, name: str, function: Callable[..., Any], *args: Any, **kwargs: Any
    ) -> Any:
        """Call ``function(*args, **kwargs)`` and record the call under ``name``.

        When disabled, the function is invoked directly with no bookkeeping.
        Otherwise the argument shape/dtype signature is recorded, the call is
        timed, and the call count, total, first and last durations are updated.

        Args:
            name: Key under which statistics are accumulated.
            function: The callable to invoke.
            *args, **kwargs: Forwarded unchanged to ``function``. Because of
                this method's own parameters, ``function`` cannot be given
                keyword arguments named ``name`` or ``function`` through here.

        Returns:
            Whatever ``function`` returns.
        """
        if not self.enabled:
            return function(*args, **kwargs)
        profile = self._profile(name)
        # The signature is stored before the call, so it is kept even if the
        # call raises (in which case no count or timing is recorded).
        profile.signatures.add(_make_signature(args, kwargs))
        # This resolves to the ``timed_call`` helper imported from ``.timing``:
        # method names are not in scope inside a method body, so no recursion.
        # NOTE(review): the unit of ``elapsed`` is defined by that helper.
        result, elapsed = timed_call(function, *args, **kwargs)
        profile.call_count += 1
        profile.total_time += elapsed
        profile.last_call_time = elapsed
        if profile.first_call_time is None:
            profile.first_call_time = elapsed
        return result

    def to_dict(self) -> dict[str, Any]:
        """Return a snapshot of the profile built from plain dicts and lists.

        Functions are listed in name order; each signature is rendered with
        ``repr`` and sorted by that representation so the output is
        deterministic.
        """
        return {
            "enabled": self.enabled,
            "notes": list(self.notes),
            "functions": {
                name: {
                    "build_count": stats.build_count,
                    "call_count": stats.call_count,
                    "total_time": stats.total_time,
                    "first_call_time": stats.first_call_time,
                    "last_call_time": stats.last_call_time,
                    "mean_time": stats.mean_time,
                    "distinct_signatures": len(stats.signatures),
                    "signatures": [
                        repr(signature)
                        for signature in sorted(stats.signatures, key=repr)
                    ],
                }
                for name, stats in sorted(self.functions.items())
            },
        }

    def summary(self) -> str:
        """Return a human-readable multi-line report.

        When disabled, a single placeholder line is returned instead.
        """
        if not self.enabled:
            return "JaxExecutionProfile(enabled=False)"
        lines = ["JAX execution profile"]
        for name, stats in sorted(self.functions.items()):
            lines.append(
                f"- {name}: builds={stats.build_count}, "
                f"calls={stats.call_count}, "
                f"first={stats.first_call_time}, "
                f"last={stats.last_call_time}, "
                f"mean={stats.mean_time}, "
                f"distinct_signatures={len(stats.signatures)}"
            )
        if self.notes:
            lines.append("Notes:")
            lines.extend(f" - {note}" for note in self.notes)
        return "\n".join(lines)