biogeme.profiling.jax_profile module

class biogeme.profiling.jax_profile.FunctionProfile(build_count=0, call_count=0, total_time=0.0, first_call_time=None, last_call_time=None, signatures=<factory>)

Bases: object

Parameters:
  • build_count (int)

  • call_count (int)

  • total_time (float)

  • first_call_time (float | None)

  • last_call_time (float | None)

  • signatures (set[tuple[Any, ...]])

build_count: int = 0
call_count: int = 0
first_call_time: float | None = None
last_call_time: float | None = None
property mean_time: float
signatures: set[tuple[Any, ...]]
total_time: float = 0.0
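The fields above describe a plain dataclass of per-function counters. The sketch below is a hypothetical stand-in, not the biogeme source. It assumes that `mean_time` is `total_time / call_count` and is 0.0 before any call, and that the `<factory>` default of `signatures` is an empty set built per instance. The name `FunctionProfileSketch` is illustrative only.

```python
from __future__ import annotations

from dataclasses import dataclass, field
from typing import Any


@dataclass
class FunctionProfileSketch:
    """Hypothetical stand-in mirroring the documented FunctionProfile fields."""

    build_count: int = 0
    call_count: int = 0
    total_time: float = 0.0
    first_call_time: float | None = None
    last_call_time: float | None = None
    # default_factory gives each instance its own empty set (the "<factory>" default)
    signatures: set[tuple[Any, ...]] = field(default_factory=set)

    @property
    def mean_time(self) -> float:
        # Assumption: average time per call, 0.0 if the function was never called
        if self.call_count == 0:
            return 0.0
        return self.total_time / self.call_count


profile = FunctionProfileSketch(call_count=4, total_time=2.0)
print(profile.mean_time)  # 0.5
```

The factory matters because dataclasses reject a bare mutable default such as `set()`; a shared set would otherwise leak signatures between functions.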
class biogeme.profiling.jax_profile.JaxExecutionProfile(enabled=False, functions=<factory>, notes=<factory>)

Bases: object

Lightweight profiler for JAX-related build and execution events.

Parameters:
  • enabled (bool)

  • functions (dict[str, FunctionProfile])

  • notes (list[str])
add_note(message)

Parameters:
  • message (str)

Return type: None

enabled: bool = False
functions: dict[str, FunctionProfile]
notes: list[str]
record_build(name)

Parameters:
  • name (str)

Return type: None
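The two event recorders, `add_note` and `record_build`, can be pictured with the minimal sketch below. It is not the biogeme implementation: it assumes that a profiler with `enabled=False` records nothing, and it simplifies `functions` to a name-to-build-count dict. The class name `BuildCounterSketch` and the function name `"loglikelihood"` are hypothetical.

```python
from dataclasses import dataclass, field


@dataclass
class BuildCounterSketch:
    """Hypothetical sketch of add_note/record_build; not the biogeme implementation."""

    enabled: bool = False
    # Simplification: name -> build count instead of name -> FunctionProfile
    build_counts: dict[str, int] = field(default_factory=dict)
    notes: list[str] = field(default_factory=list)

    def add_note(self, message: str) -> None:
        # Assumption: notes are only collected when profiling is enabled
        if self.enabled:
            self.notes.append(message)

    def record_build(self, name: str) -> None:
        # Assumption: each build (e.g. constructing or re-tracing a JAX function)
        # increments a per-name counter; a disabled profiler records nothing
        if not self.enabled:
            return
        self.build_counts[name] = self.build_counts.get(name, 0) + 1


profiler = BuildCounterSketch(enabled=True)
profiler.record_build("loglikelihood")
profiler.record_build("loglikelihood")
profiler.add_note("using float64")
print(profiler.build_counts)  # {'loglikelihood': 2}
```

A build count above one for the same name would hint that a function is being rebuilt more often than expected.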

summary()

Return type: str

timed_call(name, function, *args, **kwargs)

Parameters:
  • name (str)

  • function (Callable[..., Any])

  • args (Any)

  • kwargs (Any)

Return type: Any
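The signature of `timed_call` suggests a wrapper that runs `function(*args, **kwargs)`, returns its result, and updates the per-name statistics. The sketch below is an assumption-laden illustration, not the library code: it treats `first_call_time` and `last_call_time` as durations (they could equally be timestamps), records a signature as the tuple of argument type names (a JAX-oriented version might use shapes and dtypes instead, since those drive retracing), and simply forwards the call when disabled. `CallStats` and `TimedCallSketch` are hypothetical names.

```python
from __future__ import annotations

import time
from dataclasses import dataclass, field
from typing import Any, Callable


@dataclass
class CallStats:
    call_count: int = 0
    total_time: float = 0.0
    first_call_time: float | None = None
    last_call_time: float | None = None
    signatures: set[tuple[Any, ...]] = field(default_factory=set)


@dataclass
class TimedCallSketch:
    """Hypothetical sketch of timed_call; not the biogeme implementation."""

    enabled: bool = False
    functions: dict[str, CallStats] = field(default_factory=dict)

    def timed_call(
        self, name: str, function: Callable[..., Any], *args: Any, **kwargs: Any
    ) -> Any:
        if not self.enabled:
            # Assumption: a disabled profiler just forwards the call
            return function(*args, **kwargs)
        start = time.perf_counter()
        result = function(*args, **kwargs)
        # JAX dispatches work asynchronously; a real wrapper would likely call
        # jax.block_until_ready(result) before reading the clock. Omitted here.
        elapsed = time.perf_counter() - start
        stats = self.functions.setdefault(name, CallStats())
        stats.call_count += 1
        stats.total_time += elapsed
        if stats.first_call_time is None:
            stats.first_call_time = elapsed  # assumption: a duration, not a timestamp
        stats.last_call_time = elapsed
        # Assumption: a signature is the tuple of positional-argument type names
        stats.signatures.add(tuple(type(arg).__name__ for arg in args))
        return result


profiler = TimedCallSketch(enabled=True)
total = profiler.timed_call("add", lambda a, b: a + b, 2, 3)
profiler.timed_call("add", lambda a, b: a + b, 2.0, 3.0)
stats = profiler.functions["add"]
print(total, stats.call_count, sorted(stats.signatures))
# 5 2 [('float', 'float'), ('int', 'int')]
```

Keeping `first_call_time` apart from later calls is useful with JAX because the first call of a jitted function usually includes compilation.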

to_dict()

Return type: dict[str, Any]
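`summary()` returns a string and `to_dict()` a dictionary, so one plausible reading is a human-readable report plus a machine-readable export. The sketch below is hypothetical: the report layout, the key names, and the choice to turn signature sets into lists (so the result is JSON-serializable) are all assumptions, and `Stats`, `ReportSketch`, and `"loglike"` are illustrative names.

```python
import json
from dataclasses import dataclass, field
from typing import Any


@dataclass
class Stats:
    build_count: int = 0
    call_count: int = 0
    total_time: float = 0.0
    signatures: set = field(default_factory=set)


@dataclass
class ReportSketch:
    """Hypothetical sketch of summary/to_dict; not the biogeme implementation."""

    enabled: bool = False
    functions: dict = field(default_factory=dict)
    notes: list = field(default_factory=list)

    def to_dict(self) -> dict[str, Any]:
        # Assumption: sets become lists (sorted by repr for a stable order)
        return {
            "enabled": self.enabled,
            "functions": {
                name: {
                    "build_count": s.build_count,
                    "call_count": s.call_count,
                    "total_time": s.total_time,
                    "signatures": [list(sig) for sig in sorted(s.signatures, key=repr)],
                }
                for name, s in self.functions.items()
            },
            "notes": list(self.notes),
        }

    def summary(self) -> str:
        # Assumption: one line per profiled function, then the notes
        lines = [f"JAX profile (enabled={self.enabled})"]
        for name, s in sorted(self.functions.items()):
            mean = s.total_time / s.call_count if s.call_count else 0.0
            lines.append(f"{name}: builds={s.build_count} calls={s.call_count} mean={mean:.3f}s")
        lines.extend(f"note: {note}" for note in self.notes)
        return "\n".join(lines)


report = ReportSketch(
    enabled=True,
    functions={"loglike": Stats(build_count=1, call_count=4, total_time=2.0, signatures={("float",)})},
    notes=["x64 enabled"],
)
print(report.summary())
# JAX profile (enabled=True)
# loglike: builds=1 calls=4 mean=0.500s
# note: x64 enabled
print(json.dumps(report.to_dict()))
```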