Source code for biogeme.draws.pymc_draws
from __future__ import annotations
from collections.abc import Callable
from functools import partial
from typing import TypeAlias
import pymc as pm
from pymc.distributions import Distribution
PyMcDistributionFactory: TypeAlias = Callable[[str], Distribution]
pymc_distributions: dict[str, PyMcDistributionFactory] = {
"Cauchy": pm.Cauchy, # Heavy-tailed distribution defined by location and scale
"ChiSquared": pm.ChiSquared, # Distribution of a sum of squared standard normal variables
"Exponential": pm.Exponential, # Memoryless distribution for positive values, rate parameter
"Flat": pm.Flat, # Improper uniform prior over all real numbers (no bounds)
"Gumbel": pm.Gumbel, # Distribution of the maximum of samples from exponential families
"HalfCauchy": pm.HalfCauchy, # Positive-only Cauchy, often used as prior for scale parameters
"HalfFlat": pm.HalfFlat, # Improper uniform prior on positive reals
"HalfNormal": pm.HalfNormal, # Positive part of a normal distribution
"Logistic": pm.Logistic, # S-shaped distribution, similar to normal but with heavier tails
"LogNormal": pm.LogNormal, # Distribution of a variable whose log is normal
"Normal": pm.Normal, # Gaussian distribution defined by mean and standard deviation
"TruncatedNormal": pm.TruncatedNormal, # Normal distribution limited to a given interval
"Uniform": pm.Uniform, # Proper uniform distribution between specified bounds
"UniformSym": partial(pm.Uniform, lower=-1, upper=1),
"Weibull": pm.Weibull, # Flexible distribution for modeling lifetimes or failure times
}
[docs]
def get_distribution(
name: str, the_dict: dict[str, PyMcDistributionFactory]
) -> PyMcDistributionFactory:
"""Return a PyMC continuous distribution factory by name, ignoring case.
The returned callable can be used like any PyMC distribution constructor,
for example::
dist = get_distribution("Normal")
rv = dist("beta_time", mu=0.0, sigma=1.0)
"""
keymap: dict[str, PyMcDistributionFactory] = {
k.lower(): v for k, v in the_dict.items()
}
the_distribution = keymap.get(name.lower())
if the_distribution is None:
error_msg = (
f"{name} is not a valid distribution. Available distributions are "
f"{get_list_of_available_distributions()}"
)
raise ValueError(error_msg)
return the_distribution
[docs]
def get_list_of_available_distributions() -> list[str]:
return list(pymc_distributions.keys())