Source code for biogeme.validation.prepare_validation
"""Split data into validation and estimation samples"""
from typing import NamedTuple
import logging
import numpy as np
import pandas as pd
logger = logging.getLogger(__name__)
[docs]
class EstimationValidationIndices(NamedTuple):
    """Pair of index sets describing one cross-validation fold.

    The tuple stores the row labels used for estimation (training) and the
    row labels held out for validation.
    """

    # Labels of the rows used to estimate (train) the model.
    estimation: pd.Index
    # Labels of the rows held out to validate the model.
    validation: pd.Index
[docs]
def split(
    dataframe: pd.DataFrame, slices: int, groups: str | None = None
) -> list[EstimationValidationIndices]:
    """
    Split a DataFrame into estimation/validation index sets for cross-validation.

    The rows (or, when ``groups`` is given, the distinct group identifiers) are
    shuffled with NumPy's global random generator and cut into ``slices`` folds
    with ``np.array_split``. Each fold serves once as the validation set; the
    matching estimation set is obtained with ``Index.difference``, i.e. it
    contains every index label of ``dataframe`` that is not in the validation
    set.

    :param dataframe: The full dataset to split.
    :param slices: The number of folds/slices. Must be >= 2.
    :param groups: Optional name of the column containing group identifiers.
        If provided, all rows with the same group ID are kept in the same fold.
    :return: A list of EstimationValidationIndices tuples containing index sets,
        one per fold.
    :raises ValueError: If `slices` is less than 2.
    :raises ValueError: If there are fewer rows (or fewer distinct groups, when
        ``groups`` is given) than slices, since some validation sets would
        then be empty.
    """
    if slices < 2:
        raise ValueError(f'The number of slices is {slices}. It must be at least 2.')

    if groups is None:
        number_of_rows = len(dataframe)
        if number_of_rows < slices:
            raise ValueError(
                f'The number of slices ({slices}) exceeds the number of '
                f'observations ({number_of_rows}): some validation sets would '
                f'be empty.'
            )
        # Shuffle row positions instead of the rows themselves: no copy of the
        # data is needed, and indexing the index always yields a pd.Index.
        shuffled_positions = np.random.permutation(number_of_rows)
        fold_data = [
            dataframe.index[positions]
            for positions in np.array_split(shuffled_positions, slices)
        ]
    else:
        group_column = dataframe[groups]
        group_ids = group_column.unique()
        if len(group_ids) < slices:
            raise ValueError(
                f'The number of slices ({slices}) exceeds the number of '
                f'groups ({len(group_ids)}) in column "{groups}": some '
                f'validation sets would be empty.'
            )
        # Folds are built from whole groups, so that rows sharing a group ID
        # can never be separated.
        np.random.shuffle(group_ids)
        fold_data = [
            # A boolean mask applied to the index avoids building a filtered
            # copy of the DataFrame for every fold.
            dataframe.index[group_column.isin(fold_groups).to_numpy()]
            for fold_groups in np.array_split(group_ids, slices)
        ]

    return [
        EstimationValidationIndices(
            estimation=dataframe.index.difference(validation),
            validation=validation,
        )
        for validation in fold_data
    ]