# Source code for biogeme.latent_variables.measurement_equations

"""JAX-compatible ordered-probit measurement equations for Biogeme latent-variable models.

This module constructs the joint ordered-probit measurement model typically used
in hybrid choice models. It combines:

- latent-variable definitions (structural equations, normalization rules),
- ordinal indicators (intercepts, latent-variable loadings, measurement scales),
- Likert-scale definitions (categories and ordered thresholds),
- consistency checks across model components,
- JAX-friendly Biogeme :class:`~biogeme.expressions.Expression` objects.

Thresholds (cut-points) are obtained from the :class:`~biogeme.latent_variables.likert_indicators.LikertType`
associated with each indicator via :meth:`~biogeme.latent_variables.likert_indicators.LikertType.get_thresholds`.
Threshold sharing relies on parameter naming: if multiple indicators (or types)
produce cut-points with identical parameter names, Biogeme will estimate a
single shared set of parameters.

The main entry points are:

- :func:`measurement_equations_jax`: product of measurement likelihoods.
- :func:`log_measurement_equations_jax`: sum of log measurement likelihoods.

Michel Bierlaire
Thu Dec 11 2025, 16:26:07
"""

from __future__ import annotations

from biogeme.expressions import (
    Expression,
    MultipleProduct,
    MultipleSum,
    OrderedProbit,
    Variable,
    log,
)
from biogeme.tools import assert_sets_equal

from .latent_variables import LatentVariable
from .likert_indicators import LikertIndicator, LikertType


def _ordered_model(
    latent_variables: list[LatentVariable],
    likert_indicators: list[LikertIndicator],
    likert_types: list[LikertType],
    draw_type: str,
) -> dict[str, Expression]:
    """Build ordered-probit likelihood terms for all indicators.

    For each ordinal indicator, this function constructs an
    :class:`~biogeme.expressions.OrderedProbit` term using:

    - the indicator intercept,
    - the indicator-specific measurement scale,
    - the latent-variable structural equations (with the provided ``draw_type``),
    - the ordered cut-points returned by the corresponding
      :class:`~biogeme.latent_variables.likert_indicators.LikertType`.

    Both the linear predictor and the cut-points are divided by the
    indicator's measurement scale before being passed to ``OrderedProbit``.

    Two normalizations are enforced:

    - for each latent variable, the intercept of its normalization indicator
      (``lv.normalization.indicator``) is fixed to 0.0, and the loading of
      that latent variable on that indicator is ``lv.normalization.coefficient``;
    - for each Likert type, the measurement scale (``sigma_star``) of the
      indicator named by its ``scale_normalization`` attribute is fixed to 1.0.

    Indicators are processed in the order in which they are first declared
    by ``latent_variables``. The returned dictionary, and any product or sum
    built from it, therefore has a run-independent term order.

    Side effect: ``draw_type_jax`` is set to ``draw_type`` on every latent
    variable object. This happens only after all indicator types have been
    validated.

    :param latent_variables:
        Latent-variable objects defining structural equations, the list of
        associated indicators, and the normalization anchor.
    :param likert_indicators:
        Indicator objects providing intercepts, measurement scale parameters, and
        latent-variable loading parameters.
    :param likert_types:
        Likert-type objects defining categories, neutral labels, the scale
        normalization indicator, and the threshold parameterization used to
        build cut-points.
    :param draw_type:
        Draw type assigned to each latent variable for JAX-based computations.
    :raises ValueError:
        If an indicator references an unknown ``type`` (no matching entry in
        ``likert_types``).
    :return:
        Dictionary mapping each indicator name to its ordered-probit likelihood
        expression.
    """

    # Indicator names appearing in any latent variable, in first-declaration
    # order. An insertion-ordered dict is used rather than a set so that the
    # order of the likelihood terms does not depend on string-hash
    # randomization (PYTHONHASHSEED).
    all_indicators: list[str] = list(
        dict.fromkeys(ind for lv in latent_variables for ind in lv.indicators)
    )

    # Map of indicator objects by name.
    likert_mapping = {likert.name: likert for likert in likert_indicators}

    # Map of Likert types by type label.
    likert_types_mapping = {t.type: t for t in likert_types}

    # Check consistency between indicators declared by LVs and registered indicators.
    assert_sets_equal(
        name_a="Indicators",
        set_a=set(all_indicators),
        name_b="Likert indicators",
        set_b=set(likert_mapping),
    )

    # Resolve the Likert type of every indicator up front, so that an unknown
    # type is reported before any input object is modified.
    indicator_types: dict[str, LikertType] = {}
    for indicator in all_indicators:
        the_likert = likert_mapping[indicator]
        the_likert_type = likert_types_mapping.get(the_likert.type, None)
        if the_likert_type is None:
            error_msg = (
                f"Unknown type for indicator {the_likert.name}: {the_likert.type}. "
                f"Known types: {list(likert_types_mapping.keys())}"
            )
            raise ValueError(error_msg)
        indicator_types[indicator] = the_likert_type

    # Intercepts: one per indicator.
    intercepts: dict[str, float | Expression] = {
        k: likert_mapping[k].intercept_parameter for k in all_indicators
    }

    # Coefficients linking latent variables to indicators. The normalization
    # indicator of each LV uses the fixed normalization coefficient instead of
    # a free loading parameter.
    coefficients: dict[tuple[str, str], float | Expression] = {
        (lv.name, indicator): (
            lv.normalization.coefficient
            if indicator == lv.normalization.indicator
            else likert_mapping[indicator].get_lv_coefficient_parameter(
                latent_variable_name=lv.name
            )
        )
        for lv in latent_variables
        for indicator in lv.indicators
    }

    # Scale parameters of measurement error terms.
    sigma_star: dict[str, float | Expression] = {
        indicator: likert_mapping[indicator].scale_parameter
        for indicator in all_indicators
    }

    # Normalization: for each latent variable, force the intercept of its
    # normalization indicator to 0.
    for lv in latent_variables:
        intercepts[lv.normalization.indicator] = 0.0

    # Normalization: for each group of thresholds (Likert type), fix one
    # sigma_star to 1. This is required for identification of the ordered probit.
    # NOTE(review): if `scale_normalization` names an indicator that is not used
    # by any latent variable, this only adds an unused entry; confirm intended.
    for lt in likert_types:
        sigma_star[lt.scale_normalization] = 1.0

    # Linear predictors per indicator: intercept + sum of loading * LV.
    models: dict[str, float | Expression] = {
        indicator: intercepts[indicator] for indicator in all_indicators
    }
    for lv in latent_variables:
        lv.draw_type_jax = draw_type
        for indicator in lv.indicators:
            models[indicator] += (
                coefficients[(lv.name, indicator)] * lv.structural_equation_jax
            )

    # Ordered-probit likelihood terms, built in first-declaration order.
    ordered_ll: dict[str, Expression] = {}
    for indicator in all_indicators:
        the_likert_type = indicator_types[indicator]
        scale = sigma_star[indicator]
        cutpoints = the_likert_type.get_thresholds()

        ordered_ll[indicator] = OrderedProbit(
            eta=models[indicator] / scale,
            cutpoints=[t / scale for t in cutpoints],
            y=Variable(indicator),
            categories=the_likert_type.categories,
            neutral_labels=the_likert_type.neutral_labels,
            enforce_order=True,
            eps=1e-12,
        )

    return ordered_ll


def measurement_equations_jax(
    latent_variables: list[LatentVariable],
    likert_indicators: list[LikertIndicator],
    likert_types: list[LikertType],
    draw_type: str,
) -> Expression:
    """Return the product of ordered-probit measurement likelihood terms.

    :param latent_variables:
        Latent-variable objects defining structural equations and associated
        indicators.
    :param likert_indicators:
        Likert indicator objects providing measurement parameters.
    :param likert_types:
        Likert-type objects defining categories and thresholds.
    :param draw_type:
        Draw type assigned to latent variables for JAX.
    :return:
        Expression representing the product of ordered-probit likelihood terms.
    """
    ordered_ll = _ordered_model(
        latent_variables=latent_variables,
        likert_indicators=likert_indicators,
        likert_types=likert_types,
        draw_type=draw_type,
    )
    return MultipleProduct(ordered_ll)


def log_measurement_equations_jax(
    latent_variables: list[LatentVariable],
    likert_indicators: list[LikertIndicator],
    likert_types: list[LikertType],
    draw_type: str,
) -> Expression:
    """Return the sum of log ordered-probit measurement likelihood terms.

    :param latent_variables:
        Latent-variable objects defining structural equations and associated
        indicators.
    :param likert_indicators:
        Likert indicator objects providing measurement parameters.
    :param likert_types:
        Likert-type objects defining categories and thresholds.
    :param draw_type:
        Draw type assigned to latent variables for JAX.
    :return:
        Expression representing the sum of log ordered-probit likelihood terms.
    """
    ordered_ll = _ordered_model(
        latent_variables=latent_variables,
        likert_indicators=likert_indicators,
        likert_types=likert_types,
        draw_type=draw_type,
    )
    # Take the log of each indicator's likelihood term individually, keeping
    # the indicator names as keys.
    log_ordered_ll = {
        indicator: log(likelihood) for indicator, likelihood in ordered_ll.items()
    }
    return MultipleSum(log_ordered_ll)