biogeme.expressions.jax_utils module

Utility definitions for JAX.

Michel Bierlaire Tue Mar 18 18:28:07 2025

biogeme.expressions.jax_utils.build_vectorized_function(the_function, use_jit, profiler=None, profile_name=None)

Build and cache the row-wise vectorized version of a JAX function.

The returned callable applies the_function observation by observation, vectorizing jointly over the rows of data and draws while broadcasting parameters and random_variables.

Return type:

Callable[[Array, Array, Array, Array], Array]

Parameters:
  • the_function (Callable[[Array, Array, Array, Array], Array])

  • use_jit (bool)

  • profiler (JaxExecutionProfile | None)

  • profile_name (str | None)
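The row-wise vectorization described above can be illustrated with a minimal sketch built on jax.vmap. It is not the Biogeme implementation: the helper vectorize_rowwise and the toy row_function are hypothetical, the argument order (parameters, one row of data, draws, random_variables) is an assumption inferred from the description, and caching and profiling are left out.

```python
import jax
import jax.numpy as jnp


def row_function(parameters, one_row, draws, random_variables):
    # Hypothetical per-observation function: a linear index of the row,
    # plus the mean of that observation's draws, plus the sum of the
    # random variables.
    return jnp.dot(parameters, one_row) + jnp.mean(draws) + jnp.sum(random_variables)


def vectorize_rowwise(the_function, use_jit):
    # Assumed axis layout: map over axis 0 of the data and of the draws,
    # and broadcast parameters and random_variables (in_axes entry None).
    vectorized = jax.vmap(the_function, in_axes=(None, 0, 0, None))
    # Optionally compile the vectorized function with XLA.
    return jax.jit(vectorized) if use_jit else vectorized


parameters = jnp.array([1.0, 2.0])
data = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])   # 3 observations
draws = jnp.array([[0.0, 0.0], [1.0, 1.0], [2.0, 2.0]])  # 2 draws per observation
random_variables = jnp.array([0.5])

vectorized_function = vectorize_rowwise(row_function, use_jit=True)
# One value per observation, so the result has shape (3,).
result = vectorized_function(parameters, data, draws, random_variables)
```

In this sketch the same parameters and random_variables arrays are reused for every observation, while each observation receives its own row of data and its own draws.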