"""Define various items for JaxMichel BierlaireTue Mar 18 18:28:07 2025"""fromcollections.abcimportCallableimportjax.numpyasjnpfromjaximportjit,vmapJaxFunctionType=Callable[[jnp.ndarray,jnp.ndarray,jnp.ndarray,jnp.ndarray],jnp.array]
[docs]defbuild_vectorized_function(the_function,use_jit:bool):"""Build the function that is applied to each row of the databaser"""defvectorized_function(parameters,data,draws,random_variables):returnvmap(lambdarow,draw:the_function(parameters,row,draw,random_variables),in_axes=(0,0),)(data,draws)returnjit(vectorized_function)ifuse_jitelsevectorized_function