model.utils.cross_validation

Cross-validation helpers for site-first ADELM datasets.

All ADELM tensors are assumed to use site as the first dimension. Time-varying tensors additionally use time as the second dimension.
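As a minimal illustration of this convention, the snippet below uses hypothetical shapes and random data. A per-site boolean mask selects along dimension 0 of both static and time-varying tensors, and a per-time mask selects along dimension 1 of time-varying tensors only.

```python
import torch

n_sites, n_times, n_features = 4, 6, 3

# Hypothetical static tensor: one row of features per site -> [site, feature]
static = torch.randn(n_sites, n_features)
# Hypothetical time-varying tensor -> [site, time, feature]
dynamic = torch.randn(n_sites, n_times, n_features)

site_mask = torch.tensor([True, False, True, False])  # shape [site]
time_mask = torch.arange(n_times) < 4                  # shape [time]

# Site masks index the first dimension of both kinds of tensor.
static_sel = static[site_mask]          # [2, feature]
dynamic_sel = dynamic[site_mask]        # [2, time, feature]
# Time masks index the second dimension of time-varying tensors.
dynamic_sel_t = dynamic[:, time_mask]   # [site, 4, feature]
```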

Module Contents

Classes

CVFold

One cross-validation fold over a site-first dataset.

CrossValidation

Cross-validation interface for ADELM datasets.

API

class model.utils.cross_validation.CVFold

One cross-validation fold over a site-first dataset.

index: int

name: str

train_site_mask: torch.Tensor

valid_site_mask: torch.Tensor

train_time_mask: torch.Tensor

valid_time_mask: torch.Tensor

valid_sites: tuple

Default: ()

valid_time_range: tuple

Default: ()

target_mask(split='train')

Return a 2D [site, time] mask for the requested split.
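A minimal sketch of how such a mask can be formed, assuming the fold stores a 1D boolean site mask and a 1D boolean time mask per split (as the train_site_mask and train_time_mask fields suggest). The function name target_mask_sketch is hypothetical and this is not the CVFold implementation: it simply takes the outer logical AND of the two masks.

```python
import torch


def target_mask_sketch(site_mask: torch.Tensor, time_mask: torch.Tensor) -> torch.Tensor:
    """Combine a [site] mask and a [time] mask into a [site, time] mask.

    Illustrative stand-in only; assumes both inputs are 1D boolean tensors.
    """
    # Broadcasting [site, 1] & [1, time] yields [site, time].
    return site_mask[:, None] & time_mask[None, :]


train_site_mask = torch.tensor([True, True, False])
train_time_mask = torch.tensor([True, True, True, False])
mask = target_mask_sketch(train_site_mask, train_time_mask)
```

Under this sketch, a purely spatial fold corresponds to an all-True time mask, and a purely temporal fold to an all-True site mask.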

mask_tensor(tensor, split='train', fill_value=torch.nan)

Keep the entries that belong to the requested split and replace every other entry with fill_value.
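The masking step can be sketched with torch.where, assuming the split is given as a [site, time] boolean mask and the tensor is [site, time, ...]. The helper name mask_tensor_sketch is hypothetical; the real method takes a split name rather than a mask.

```python
import torch


def mask_tensor_sketch(tensor: torch.Tensor, mask: torch.Tensor,
                       fill_value: float = torch.nan) -> torch.Tensor:
    """Keep entries where ``mask`` is True and fill the rest with ``fill_value``.

    ``mask`` is a [site, time] boolean tensor and ``tensor`` is [site, time, ...].
    """
    # Append singleton dims so the 2D mask broadcasts over any trailing dims.
    expanded = mask.reshape(tuple(mask.shape) + (1,) * (tensor.dim() - mask.dim()))
    return torch.where(expanded, tensor, torch.full_like(tensor, fill_value))


values = torch.arange(12, dtype=torch.float32).reshape(2, 3, 2)  # [site, time, feature]
mask = torch.tensor([[True, False, True],
                     [False, False, True]])                        # [site, time]
masked = mask_tensor_sketch(values, mask)
```

Filling with NaN (the documented default) only makes sense for floating-point tensors; integer targets would need a different fill_value.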

class model.utils.cross_validation.CrossValidation(mode, folds, site_ids, time_values=None, metadata=None)

Bases: object

Cross-validation interface for ADELM datasets.

Iterate over folds (the prepared splits), then apply a fold's split to a single target tensor or to an entire dict[str, Tensor] collection.

__len__()

__iter__()

split_targets(targets, fold, fill_value=torch.nan)

Return (train_targets, valid_targets) for one fold.

mask_targets(targets, fold, split='train', fill_value=torch.nan)

Mask one tensor or a dict of tensors by the requested fold split.
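The tensor-or-dict behaviour of mask_targets and split_targets can be sketched as follows. To keep the example self-contained, the hypothetical helpers below take [site, time] masks directly instead of a CVFold, and the target names are made up for illustration.

```python
from typing import Mapping, Union

import torch

Targets = Union[torch.Tensor, Mapping[str, torch.Tensor]]


def _mask_one(tensor: torch.Tensor, mask: torch.Tensor, fill_value: float) -> torch.Tensor:
    # Add trailing singleton dims so a [site, time] mask broadcasts over extra dims.
    expanded = mask.reshape(tuple(mask.shape) + (1,) * (tensor.dim() - mask.dim()))
    return torch.where(expanded, tensor, torch.full_like(tensor, fill_value))


def mask_targets_sketch(targets: Targets, mask: torch.Tensor,
                        fill_value: float = torch.nan) -> Targets:
    """Mask a single tensor, or every tensor in a dict, with the same [site, time] mask."""
    if isinstance(targets, torch.Tensor):
        return _mask_one(targets, mask, fill_value)
    return {name: _mask_one(t, mask, fill_value) for name, t in targets.items()}


def split_targets_sketch(targets: Targets, train_mask: torch.Tensor,
                         valid_mask: torch.Tensor, fill_value: float = torch.nan):
    """Return (train_targets, valid_targets), each masked to its own split."""
    return (mask_targets_sketch(targets, train_mask, fill_value),
            mask_targets_sketch(targets, valid_mask, fill_value))


# Hypothetical example: 3 sites x 4 time steps, site 2 held out for validation.
targets = {"target_a": torch.ones(3, 4), "target_b": torch.ones(3, 4, 2)}
valid_mask = torch.tensor([False, False, True])[:, None].expand(3, 4)
train_mask = ~valid_mask
train, valid = split_targets_sketch(targets, train_mask, valid_mask)
```

Because both splits are returned at full [site, time] shape, the same model forward pass can be reused for training and validation, with the loss ignoring filled entries.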

classmethod spatial_random(site_ids, n_folds, shuffle=True, seed=0)

Random spatial CV by site.
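One plausible assignment rule, shown as a minimal stdlib sketch: optionally shuffle the sites with a seeded generator, then deal them round-robin into n_folds folds. The exact rule used by spatial_random is not documented here, so treat spatial_random_sketch (a hypothetical name) as an assumption. Each fold's training-site mask is the complement of its validation-site mask.

```python
import random
from typing import List, Sequence


def spatial_random_sketch(site_ids: Sequence[str], n_folds: int,
                          shuffle: bool = True, seed: int = 0) -> List[List[bool]]:
    """Return one validation-site mask (aligned with ``site_ids``) per fold."""
    if not 2 <= n_folds <= len(site_ids):
        raise ValueError("n_folds must be between 2 and the number of sites")
    order = list(range(len(site_ids)))
    if shuffle:
        # A dedicated Random instance keeps the split reproducible for a given seed.
        random.Random(seed).shuffle(order)
    fold_of = [0] * len(site_ids)
    for rank, site_idx in enumerate(order):
        fold_of[site_idx] = rank % n_folds  # deal sites round-robin into folds
    return [[f == k for f in fold_of] for k in range(n_folds)]


masks = spatial_random_sketch(["s0", "s1", "s2", "s3", "s4"], n_folds=2, seed=0)
```

Round-robin dealing keeps fold sizes within one site of each other; the lists can be turned into tensors with torch.tensor(mask, dtype=torch.bool).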

classmethod spatial_predefined(site_ids, fold_definition)

Spatial CV from a user-defined fold specification.
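The format of fold_definition is not specified on this page. The minimal sketch below assumes a mapping from fold name to the site ids held out in that fold, and rejects site ids that are not in site_ids. The function name and the mapping format are both hypothetical.

```python
from typing import Dict, Iterable, List, Mapping, Sequence


def spatial_predefined_sketch(site_ids: Sequence[str],
                              fold_definition: Mapping[str, Iterable[str]]) -> Dict[str, List[bool]]:
    """Map each fold name to a validation-site mask aligned with ``site_ids``.

    Assumes ``fold_definition`` is {fold_name: iterable of validation site ids}.
    """
    known = set(site_ids)
    masks = {}
    for name, sites in fold_definition.items():
        valid = set(sites)
        unknown = valid - known
        if unknown:
            # Fail loudly instead of silently producing an empty validation set.
            raise KeyError(f"fold {name!r} references unknown sites: {sorted(unknown)}")
        masks[name] = [s in valid for s in site_ids]
    return masks


folds = spatial_predefined_sketch(["a", "b", "c"], {"east": ["a"], "west": ["b", "c"]})
```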

classmethod temporal_block(site_ids, time_values, n_folds, train_start, train_end)

Temporal CV by contiguous time blocks within the training window.
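A minimal stdlib sketch of contiguous-block temporal CV, assuming time_values is sorted, the training window bounds are inclusive, and times outside the window belong to neither split. The name temporal_block_sketch is hypothetical, and the real method may differ (for example, by adding buffer periods between blocks).

```python
from typing import List, Sequence, Tuple


def temporal_block_sketch(time_values: Sequence, n_folds: int,
                          train_start, train_end) -> List[Tuple[List[bool], List[bool]]]:
    """Return one (train_time_mask, valid_time_mask) pair per fold."""
    window = [i for i, t in enumerate(time_values) if train_start <= t <= train_end]
    if not 2 <= n_folds <= len(window):
        raise ValueError("n_folds must be between 2 and the number of in-window time steps")
    in_window = set(window)
    folds = []
    for k in range(n_folds):
        # Contiguous, near-equal slices of the in-window time indices.
        lo = k * len(window) // n_folds
        hi = (k + 1) * len(window) // n_folds
        block = set(window[lo:hi])
        valid = [i in block for i in range(len(time_values))]
        train = [i in in_window and i not in block for i in range(len(time_values))]
        folds.append((train, valid))
    return folds


years = list(range(2000, 2010))
folds = temporal_block_sketch(years, n_folds=4, train_start=2001, train_end=2008)
```

Holding out contiguous blocks, rather than random time steps, limits leakage from temporal autocorrelation between neighbouring training and validation steps.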

classmethod from_config(config, site_ids, scheme=None, time_values=None, n_folds=None, fold_definition=None, shuffle=True, seed=0)

Build a cross-validator using the existing runtime config time window.