from __future__ import annotations
import numpy as np
import pandas as pd
import pytensor.tensor as pt
from pytensor.tensor.math import softplus as pt_softplus
from .base_expressions import Expression
from .bayesian import PymcModelBuilderType
from .convert import validate_and_convert
from .jax_utils import JaxFunctionType
from ..exceptions import BiogemeError
class OrderedBase(Expression):
    """
    Base class for ordered-response models (logit and probit).

    This class implements the common logic for ordered discrete-choice models.
    Given a latent variable :math:`\\eta_n` and ordered cutpoints
    :math:`\\tau_1 < \\tau_2 < \\dots < \\tau_{K-1}`, the probability of observing
    category :math:`y_k` is

    .. math::

        P(y = y_k \\mid \\eta, \\tau)
        = \\mathrm{CDF}(\\tau_k - \\eta) - \\mathrm{CDF}(\\tau_{k-1} - \\eta),

    where :math:`\\tau_0 = -\\infty` and :math:`\\tau_K = +\\infty`.

    Subclasses must implement the appropriate cumulative distribution
    function (CDF), either logistic (logit) or Gaussian (probit), by
    overriding :meth:`_cdf_jax` and :meth:`_cdf_pt`.

    This class returns per-observation **probabilities**. To obtain per-observation
    log-likelihoods, use the corresponding log-variants (e.g., :class:`OrderedLogLogit`).

    :param eta: Expression defining the latent variable :math:`\\eta_n`.
    :param cutpoints: List of expressions defining the cutpoints (length K-1).
    :param y: Expression defining the observed categorical response.
    :param categories: Ordered list of category labels (e.g. ``[1, 2, 3, 4, 5]``).
        If ``None``, defaults to ``[0, 1, ..., K-1]``.
    :param enforce_order: If ``True``, ensures that cutpoints are monotonically
        increasing by reparameterizing them as the first cutpoint plus a
        cumulative sum of softplus-transformed increments. Both the JAX and
        the PyTensor builders apply this same differentiable transform.
    :param eps: Lower bound for probabilities to avoid numerical issues.
    :param neutral_labels: Labels that may appear in the data and must be
        treated as "always valid"; their contribution is probability 1.
        Useful to avoid crashes on placeholder/missing/special codes.

    **Examples**

    Ordered Logit with 5 Likert levels (valid labels 1..5) and two neutral codes 98/99:

    .. code-block:: python

        income = Variable('Income')
        age = Variable('Age')
        beta_income = Beta('beta_income', 0, None, None, 0)
        beta_age = Beta('beta_age', 0, None, None, 0)
        eta = beta_income * income + beta_age * age
        # Four thresholds for five ordered responses
        tau1 = Beta('tau1', -1, None, None, 0)
        tau2 = Beta('tau2', 0, None, None, 0)
        tau3 = Beta('tau3', 1, None, None, 0)
        tau4 = Beta('tau4', 2, None, None, 0)
        y = Variable('Satisfaction')  # coded as 1, 2, 3, 4, 5, 98, 99
        model = OrderedLogit(
            eta=eta,
            cutpoints=[tau1, tau2, tau3, tau4],
            y=y,
            categories=[1, 2, 3, 4, 5],
            neutral_labels=[98, 99],
        )
        prob_vec = model.recursive_construct_pymc_model_builder()

    Ordered Probit with the same structure:

    .. code-block:: python

        model = OrderedProbit(
            eta=eta,
            cutpoints=[tau1, tau2, tau3, tau4],
            y=y,
            categories=[1, 2, 3, 4, 5],
            neutral_labels=[98, 99],
        )
    """

    def __init__(
        self,
        eta: Expression,
        cutpoints: list[Expression],
        y: Expression,
        categories: list[float] | tuple[float, ...] | None = None,
        neutral_labels: list[float] | tuple[float, ...] | None = None,
        enforce_order: bool = True,
        eps: float = 1e-12,
    ):
        super().__init__()
        self.eta = validate_and_convert(eta)
        self.cutpoints = [validate_and_convert(c) for c in cutpoints]
        self.y = validate_and_convert(y)
        self.enforce_order = bool(enforce_order)
        self.eps = float(eps)
        # Determine K and validate 'categories'
        K = len(self.cutpoints) + 1
        if categories is None:
            # NOTE(review): the default uses int32 labels while explicit
            # categories are cast to float64 below; the builders only compare
            # labels elementwise against y, but confirm exact equality holds
            # for the label codings used in practice.
            self.categories = np.arange(K, dtype=np.int32)
        else:
            cats = np.asarray(categories)
            if cats.ndim != 1 or cats.size != K:
                raise BiogemeError(
                    f"'categories' must be a 1-D sequence of length K={K}; got shape {cats.shape}."
                )
            self.categories = cats.astype(np.float64)
        # Neutral/skip labels
        if neutral_labels is None or len(neutral_labels) == 0:
            self.neutral_labels = np.asarray([], dtype=np.float64)
            self._has_neutrals = False
        else:
            self.neutral_labels = np.asarray(neutral_labels, dtype=np.float64)
            self._has_neutrals = True
        # Register children
        self.children += [self.eta, self.y] + self.cutpoints

    # --------------------------------------------------------------------------
    # JAX builder
    # --------------------------------------------------------------------------
    def recursive_construct_jax_function(
        self, numerically_safe: bool
    ) -> JaxFunctionType:
        """
        Build a JAX-compatible function returning per-observation probabilities.

        :param numerically_safe: Whether to use numerically stable operators.
        :return: A callable taking model parameters and one observation, and returning
            the probability of the observed category.
        """
        import jax
        import jax.numpy as jnp

        # Constants captured by the closure: number of cutpoints (K-1),
        # number of categories (K), and the static label arrays.
        Km1 = len(self.cutpoints)
        K = Km1 + 1
        eps = self.eps
        cats = jnp.asarray(self.categories)
        neutrals = jnp.asarray(self.neutral_labels)

        def get_val(expr, θ, row, draws, rvars):
            # Evaluate a child expression for a single observation by
            # building (and immediately calling) its JAX function.
            fn = expr.recursive_construct_jax_function(
                numerically_safe=numerically_safe
            )
            return fn(θ, row, draws, rvars)

        def order_cuts(raw_tau: jnp.ndarray) -> jnp.ndarray:
            """Ensure monotonic cutpoints using softplus increments (differentiable)."""
            if (not self.enforce_order) or (raw_tau.size == 0):
                return raw_tau
            # tau[0] is free; each subsequent cutpoint adds a strictly
            # positive softplus increment, guaranteeing tau is increasing.
            tau0 = raw_tau[0]
            deltas = jax.nn.softplus(jnp.diff(raw_tau))
            return jnp.concatenate([jnp.array([tau0]), tau0 + jnp.cumsum(deltas)])

        CDF = self._cdf_jax

        def the_jax(θ, one_row, draws, rvars):
            # Scalars for a single observation 'one_row'
            eta = get_val(self.eta, θ, one_row, draws, rvars)
            raw = jnp.array(
                [get_val(c, θ, one_row, draws, rvars) for c in self.cutpoints]
            )
            tau = order_cuts(raw)
            # Observed label and matches
            y_val = get_val(self.y, θ, one_row, draws, rvars)
            match_cat = cats == y_val
            has_cat = jnp.any(match_cat)
            kpos = jnp.argmax(
                match_cat
            )  # 0..K-1 (undefined if has_cat=False, masked later)
            # neutrals.size is a Python int at trace time, so this branch
            # folds to a constant when there are no neutral labels.
            has_neutral = jnp.logical_and(neutrals.size > 0, jnp.any(neutrals == y_val))
            # Category probabilities via CDF differences
            if K == 1:
                # Degenerate model with a single category: probability 1.
                probs = jnp.array([1.0])
            else:
                p0 = CDF(tau[0] - eta)
                mids = [
                    CDF(tau[k] - eta) - CDF(tau[k - 1] - eta) for k in range(1, Km1)
                ]
                pK = 1.0 - CDF(tau[-1] - eta)
                probs = jnp.concatenate(
                    [jnp.array([p0]), *(jnp.array([m]) for m in mids), jnp.array([pK])]
                )
            # Clip away exact 0/1 before selecting the observed category.
            probs = jnp.clip(probs, eps, 1.0 - eps)
            p_valid = probs[kpos]
            # Neutral labels contribute probability 1; unknown labels fall
            # back to eps rather than crashing.
            p = jnp.where(has_neutral, 1.0, jnp.where(has_cat, p_valid, eps))
            return p

        return the_jax

    # --------------------------------------------------------------------------
    # PyMC / PyTensor builder
    # --------------------------------------------------------------------------
    def recursive_construct_pymc_model_builder(self) -> PymcModelBuilderType:
        """
        Build a PyTensor-compatible function returning per-observation probabilities.

        :return: Callable taking a pandas DataFrame and returning a PyTensor
            variable of probabilities of the observed categories.
        """
        K = len(self.cutpoints) + 1
        eps = pt.as_tensor_variable(self.eps)
        cats_const = pt.constant(np.asarray(self.categories))
        neutrals_const = pt.constant(np.asarray(self.neutral_labels))
        # Pre-build the child builders once; the returned closure reuses them.
        eta_b = self.eta.recursive_construct_pymc_model_builder()
        cut_b = [c.recursive_construct_pymc_model_builder() for c in self.cutpoints]
        y_b = self.y.recursive_construct_pymc_model_builder()

        def builder(dataframe: pd.DataFrame) -> pt.TensorVariable:
            eta = eta_b(dataframe)  # (N,)
            # Raw cutpoints stacked: (N, K-1) or broadcasted to that
            if len(cut_b):
                raw = pt.stack(
                    [cb(dataframe) for cb in cut_b], axis=1
                )  # (N, K-1) or (K-1,)
                # Ensure 2D shape (N, K-1) for downstream ops
                if raw.ndim == 1:
                    # Broadcast a shared (K-1,) cutpoint vector to one row
                    # per observation by adding a (N, 1) zero column.
                    raw = pt.shape_padaxis(raw, 0) + pt.zeros_like(eta)[:, None]
                if self.enforce_order:
                    # Monotone cutpoints via softplus increments (same logic as JAX):
                    # tau0 = raw[:, 0]; deltas = softplus(diff(raw)); tau = concat([tau0, tau0 + cumsum(deltas)])
                    tau0 = raw[:, 0:1]
                    deltas = pt_softplus(raw[:, 1:] - raw[:, :-1])
                    tau = pt.concatenate(
                        [tau0, tau0 + pt.cumsum(deltas, axis=1)], axis=1
                    )
                else:
                    tau = raw
            else:
                tau = pt.zeros((eta.shape[0], 0), dtype=eta.dtype)  # K == 1

            CDF = self._cdf_pt
            # Probabilities matrix (N, K)
            if K == 1:
                probs = pt.ones((eta.shape[0], 1), dtype=eta.dtype)
            else:
                c = tau
                p0 = CDF(c[:, 0] - eta)  # (N,)
                mids = [
                    CDF(c[:, k] - eta) - CDF(c[:, k - 1] - eta) for k in range(1, K - 1)
                ]
                pK = 1.0 - CDF(c[:, -1] - eta)
                probs = pt.stack([p0, *mids, pK], axis=1)  # (N, K)
            probs = pt.clip(probs, eps, 1.0 - eps)
            # Map observed labels to positions 0..K-1 using provided categories
            y_val = y_b(dataframe)  # (N,)
            matches = pt.eq(y_val[:, None], cats_const[None, :])  # (N, K)
            idx = pt.argmax(matches, axis=1)  # (N,)
            any_match = pt.any(matches, axis=1)  # (N,)
            rows = pt.arange(y_val.shape[0])
            chosen = probs[rows, idx]  # (N,)
            p_valid = chosen
            # Neutral labels detection
            if self._has_neutrals:
                neutral_matches = pt.eq(
                    y_val[:, None], neutrals_const[None, :]
                )  # (N, |neutrals|)
                has_neutral = pt.any(neutral_matches, axis=1)  # (N,)
            else:
                # No neutral labels configured: all-False mask of shape (N,).
                has_neutral = pt.zeros_like(any_match)
            # Neutral rows get probability 1; unmatched labels fall back to eps.
            p = pt.where(has_neutral, 1.0, pt.where(any_match, p_valid, eps))
            return p

        return builder

    # --------------------------------------------------------------------------
    # Abstract CDF hooks for subclasses
    # --------------------------------------------------------------------------
    def _cdf_jax(self, z):
        """Subclass hook for the cumulative distribution function in JAX."""
        raise NotImplementedError

    def _cdf_pt(self, z):
        """Subclass hook for the cumulative distribution function in PyTensor."""
        raise NotImplementedError
class OrderedLogit(OrderedBase):
    """Ordered response model using the logistic cumulative distribution function.

    Returns per-observation probabilities.
    """

    def _cdf_jax(self, z):
        """Logistic CDF in JAX: sigmoid(z) = 1 / (1 + exp(-z)).

        :param z: Scalar or array argument of the CDF.
        :return: Logistic CDF evaluated at ``z``.
        """
        import jax

        return jax.nn.sigmoid(z)

    def _cdf_pt(self, z):
        """Logistic CDF in PyTensor.

        Uses the built-in ``pt.sigmoid`` op instead of a hand-rolled
        ``1 / (1 + pt.exp(-z))``, so PyTensor's stabilizing graph rewrites
        can recognize it (e.g. fusing ``log(sigmoid(x))`` into a stable
        softplus form in the log-likelihood variants).

        :param z: Tensor argument of the CDF.
        :return: Logistic CDF evaluated at ``z``.
        """
        return pt.sigmoid(z)

    def __repr__(self):
        return f'OrderedLogit({repr(self.eta)})'
class OrderedProbit(OrderedBase):
    """Ordered response model using the standard normal cumulative distribution function.

    Returns per-observation probabilities.
    """

    def _cdf_jax(self, z):
        """Standard normal CDF in JAX via the error function.

        Uses the identity :math:`\\Phi(z) = (1 + \\mathrm{erf}(z/\\sqrt{2})) / 2`.
        """
        import jax.numpy as jnp
        import jax.lax as lax

        scaled = z / jnp.sqrt(2.0)
        return 0.5 * (1.0 + lax.erf(scaled))

    def _cdf_pt(self, z):
        """Standard normal CDF in PyTensor (same erf identity as the JAX version)."""
        sqrt_two = pt.sqrt(pt.as_tensor_variable(2.0))
        return 0.5 * (1.0 + pt.erf(z / sqrt_two))

    def __repr__(self):
        return f'OrderedProbit({self.eta!r})'
class OrderedLogLogit(OrderedLogit):
    """Ordered response model using logistic CDF, returning per-observation log-likelihoods.

    Thin wrapper around :class:`OrderedLogit`: the parent's probability is
    clipped to ``[eps, 1]`` and passed through the natural logarithm.
    """

    def recursive_construct_jax_function(
        self, numerically_safe: bool
    ) -> JaxFunctionType:
        """Build a JAX callable returning the log-probability of each observation.

        :param numerically_safe: Whether to use numerically stable operators.
        :return: Callable with the same signature as the parent's, returning
            ``log(clip(p, eps, 1))`` instead of ``p``.
        """
        import jax.numpy as jnp

        probability_fn = super().recursive_construct_jax_function(numerically_safe)

        def log_probability(*args, **kwargs):
            # Clip below by eps so the log never sees zero.
            probability = probability_fn(*args, **kwargs)
            return jnp.log(jnp.clip(probability, self.eps, 1.0))

        return log_probability

    def recursive_construct_pymc_model_builder(self) -> PymcModelBuilderType:
        """Build a PyTensor builder returning per-observation log-probabilities.

        :return: Callable taking a pandas DataFrame and returning
            ``log(clip(p, eps, 1))`` where ``p`` comes from the parent builder.
        """
        probability_builder = super().recursive_construct_pymc_model_builder()
        lower_bound = pt.as_tensor_variable(self.eps)

        def log_builder(dataframe: pd.DataFrame) -> pt.TensorVariable:
            probability = probability_builder(dataframe)
            return pt.log(pt.clip(probability, lower_bound, 1.0))

        return log_builder
class OrderedLogProbit(OrderedProbit):
    """Ordered response model using probit CDF, returning per-observation log-likelihoods.

    Thin wrapper around :class:`OrderedProbit`: the parent's probability is
    clipped to ``[eps, 1]`` and passed through the natural logarithm.
    """

    def recursive_construct_jax_function(
        self, numerically_safe: bool
    ) -> JaxFunctionType:
        """Build a JAX callable returning the log-probability of each observation.

        :param numerically_safe: Whether to use numerically stable operators.
        :return: Callable with the same signature as the parent's, returning
            ``log(clip(p, eps, 1))`` instead of ``p``.
        """
        import jax.numpy as jnp

        probability_fn = super().recursive_construct_jax_function(numerically_safe)

        def log_probability(*args, **kwargs):
            # Clip below by eps so the log never sees zero.
            probability = probability_fn(*args, **kwargs)
            return jnp.log(jnp.clip(probability, self.eps, 1.0))

        return log_probability

    def recursive_construct_pymc_model_builder(self) -> PymcModelBuilderType:
        """Build a PyTensor builder returning per-observation log-probabilities.

        :return: Callable taking a pandas DataFrame and returning
            ``log(clip(p, eps, 1))`` where ``p`` comes from the parent builder.
        """
        probability_builder = super().recursive_construct_pymc_model_builder()
        lower_bound = pt.as_tensor_variable(self.eps)

        def log_builder(dataframe: pd.DataFrame) -> pt.TensorVariable:
            probability = probability_builder(dataframe)
            return pt.log(pt.clip(probability, lower_bound, 1.0))

        return log_builder