model.utils.cross_validation
Cross-validation helpers for site-first ADELM datasets.
All ADELM tensors are assumed to use site as the first dimension.
Time-varying tensors additionally use time as the second dimension.
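Because every tensor puts site first, one boolean site mask selects the same sites from static and time-varying tensors alike. The sketch below shows this with nested Python lists standing in for torch tensors. select_sites is a hypothetical helper, not part of this module. With a torch boolean mask of length n_sites, the equivalent selection is tensor[site_mask].

```python
from typing import List, Sequence, TypeVar

T = TypeVar("T")


# Hypothetical helper for illustration; not part of model.utils.cross_validation.
def select_sites(tensor: Sequence[T], site_mask: Sequence[bool]) -> List[T]:
    """Keep the rows (sites) of a site-first tensor where site_mask is True."""
    if len(tensor) != len(site_mask):
        raise ValueError("site dimension does not match mask length")
    return [row for row, keep in zip(tensor, site_mask) if keep]


# Static tensor: shape [site]. Time-varying tensor: shape [site, time].
static_feature = [120.0, 340.0, 95.0]
time_series = [
    [0.1, 0.2, 0.3, 0.4],
    [1.1, 1.2, 1.3, 1.4],
    [2.1, 2.2, 2.3, 2.4],
]
site_mask = [True, False, True]

print(select_sites(static_feature, site_mask))  # [120.0, 95.0]
print(select_sites(time_series, site_mask))     # [[0.1, 0.2, 0.3, 0.4], [2.1, 2.2, 2.3, 2.4]]
```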
Module Contents
Classes
| Class | Description |
|---|---|
| CVFold | One cross-validation fold over a site-first dataset. |
| CrossValidation | Cross-validation interface for ADELM datasets. |
API
- class model.utils.cross_validation.CVFold
One cross-validation fold over a site-first dataset.
- index: int
- name: str
- train_site_mask: torch.Tensor
- valid_site_mask: torch.Tensor
- train_time_mask: torch.Tensor
- valid_time_mask: torch.Tensor
- valid_sites: tuple = ()
- valid_time_range: tuple = ()
- target_mask(split='train')
Return a 2D [site, time] mask for the requested split.
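A minimal sketch of one plausible construction, assuming target_mask combines the fold's site mask and time mask for the split so that entry (s, t) is kept only when both site s and time t belong to that split. outer_and is a hypothetical name, and plain lists stand in for torch tensors. With 1D torch bool tensors, the same outer product is site_mask[:, None] & time_mask[None, :].

```python
from typing import List, Sequence


# Hypothetical helper; assumes the 2D mask is the outer AND of the 1D site and time masks.
def outer_and(site_mask: Sequence[bool], time_mask: Sequence[bool]) -> List[List[bool]]:
    """Build a [site, time] mask: (s, t) is True only if site s and time t are both selected."""
    return [[s and t for t in time_mask] for s in site_mask]


train_site_mask = [True, True, False]         # sites 0 and 1 train, site 2 is held out
train_time_mask = [True, True, True, False]   # last time step lies outside the training window

for row in outer_and(train_site_mask, train_time_mask):
    print(row)
# [True, True, True, False]
# [True, True, True, False]
# [False, False, False, False]
```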
- mask_tensor(tensor, split='train', fill_value=torch.nan)
Keep only the entries that belong to the requested split and fill the remainder with fill_value.
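A sketch of the fill behaviour on a [site, time] tensor, assuming mask_tensor keeps entries where the split's 2D mask is True and writes fill_value everywhere else. fill_outside is a hypothetical helper over nested lists. For a torch float tensor and a boolean mask of the same shape, tensor.masked_fill(~mask, fill_value) does the same.

```python
import math
from typing import List, Sequence


# Hypothetical helper for illustration; not this module's implementation.
def fill_outside(
    tensor: Sequence[Sequence[float]],
    mask: Sequence[Sequence[bool]],
    fill_value: float = math.nan,
) -> List[List[float]]:
    """Return a copy of a [site, time] tensor with entries outside mask replaced by fill_value."""
    if len(tensor) != len(mask):
        raise ValueError("tensor and mask must have the same number of sites")
    return [
        [value if keep else fill_value for value, keep in zip(row, mask_row)]
        for row, mask_row in zip(tensor, mask)
    ]


targets = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
valid_mask = [[False, False, True], [False, False, True]]
print(fill_outside(targets, valid_mask))
# [[nan, nan, 3.0], [nan, nan, 6.0]]
```

Filling rather than dropping keeps both copies at the full [site, time] shape, so they stay aligned with the model inputs; any loss computed on them then has to skip NaN targets.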
- class model.utils.cross_validation.CrossValidation(mode, folds, site_ids, time_values=None, metadata=None)
Bases: object
Cross-validation interface for ADELM datasets.
Use folds to iterate over the prepared splits, then apply a split either to a single target tensor or to an entire dict[str, Tensor] collection.
- __len__()
- __iter__()
- split_targets(targets, fold, fill_value=torch.nan)
Return (train_targets, valid_targets) for one fold.
- mask_targets(targets, fold, split='train', fill_value=torch.nan)
Mask one tensor or a dict of tensors by the requested fold split.
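The following sketch shows how a single-tensor-or-dict interface and a (train, valid) pair can fit together. It assumes, without confirmation from the source, that split_targets amounts to masking the same targets once with the train mask and once with the validation mask. mask_any and split_any are hypothetical names, and the fold is represented by its two 2D masks rather than by a CVFold.

```python
import math
from typing import Dict, List, Sequence, Tuple, Union

Grid = List[List[float]]
Targets = Union[Grid, Dict[str, Grid]]
Mask = Sequence[Sequence[bool]]


def fill_outside(tensor: Sequence[Sequence[float]], mask: Mask, fill_value: float) -> Grid:
    """Copy a [site, time] tensor, replacing entries outside mask with fill_value."""
    if len(tensor) != len(mask):
        raise ValueError("tensor and mask must have the same number of sites")
    return [[v if keep else fill_value for v, keep in zip(row, m)] for row, m in zip(tensor, mask)]


# Hypothetical analogue of mask_targets: accepts one tensor or a dict of tensors.
def mask_any(targets: Targets, mask: Mask, fill_value: float = math.nan) -> Targets:
    """Apply the mask to one tensor, or to every value of a dict of tensors."""
    if isinstance(targets, dict):
        return {name: fill_outside(t, mask, fill_value) for name, t in targets.items()}
    return fill_outside(targets, mask, fill_value)


# Hypothetical analogue of split_targets: one masked copy per split.
def split_any(
    targets: Targets, train_mask: Mask, valid_mask: Mask, fill_value: float = math.nan
) -> Tuple[Targets, Targets]:
    """Return (train_targets, valid_targets) from the train and validation masks."""
    return mask_any(targets, train_mask, fill_value), mask_any(targets, valid_mask, fill_value)


targets = {"y1": [[1.0, 2.0], [3.0, 4.0]], "y2": [[5.0, 6.0], [7.0, 8.0]]}
train_mask = [[True, True], [False, False]]  # site 0 trains
valid_mask = [[False, False], [True, True]]  # site 1 validates
train, valid = split_any(targets, train_mask, valid_mask)
print(train["y1"])  # [[1.0, 2.0], [nan, nan]]
print(valid["y2"])  # [[nan, nan], [7.0, 8.0]]
```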
- classmethod spatial_random(site_ids, n_folds, shuffle=True, seed=0)
Random spatial CV by site.
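One common way to implement random spatial CV, sketched here as an assumption rather than this module's code: shuffle the site IDs with a seeded local RNG, cut them into n_folds contiguous chunks whose sizes differ by at most one (the same sizing rule as scikit-learn's KFold), and hold out each chunk once. random_site_folds is a hypothetical name.

```python
import random
from typing import List, Sequence, Tuple


# Hypothetical sketch; not the implementation behind CrossValidation.spatial_random.
def random_site_folds(
    site_ids: Sequence[str], n_folds: int, shuffle: bool = True, seed: int = 0
) -> List[Tuple[List[str], List[str]]]:
    """Partition sites into n_folds groups; each group is held out once for validation."""
    if not 2 <= n_folds <= len(site_ids):
        raise ValueError("n_folds must be between 2 and the number of sites")
    order = list(site_ids)
    if shuffle:
        random.Random(seed).shuffle(order)  # local RNG: reproducible, global state untouched
    # Contiguous chunks: the first len % n_folds chunks get one extra site.
    base, extra = divmod(len(order), n_folds)
    folds, start = [], 0
    for k in range(n_folds):
        stop = start + base + (1 if k < extra else 0)
        valid = order[start:stop]
        train = order[:start] + order[stop:]
        folds.append((train, valid))
        start = stop
    return folds


for train, valid in random_site_folds(["A", "B", "C", "D", "E"], n_folds=2, shuffle=False):
    print(train, valid)
# ['D', 'E'] ['A', 'B', 'C']
# ['A', 'B', 'C'] ['D', 'E']
```

Because whole sites are held out, validation scores measure transfer to unseen sites rather than fit at sites already seen during training.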
- classmethod spatial_predefined(site_ids, fold_definition)
Spatial CV from a user-defined fold specification.
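The format of fold_definition is not documented here. The sketch assumes, purely for illustration, a mapping from fold name to the site IDs held out in that fold, and shows checks such a specification usually needs: every listed site must exist and, as an assumed rule, no site may be held out in two folds. predefined_site_folds is a hypothetical name.

```python
from typing import Dict, List, Sequence, Tuple


# Hypothetical sketch; the {fold_name: held_out_site_ids} format is an assumption.
def predefined_site_folds(
    site_ids: Sequence[str], fold_definition: Dict[str, Sequence[str]]
) -> List[Tuple[str, List[bool], List[bool]]]:
    """Turn {fold_name: validation_site_ids} into (name, train_site_mask, valid_site_mask)."""
    known = set(site_ids)
    seen: Dict[str, str] = {}
    folds = []
    for name, valid_sites in fold_definition.items():
        unknown = set(valid_sites) - known
        if unknown:
            raise ValueError(f"fold {name!r} lists unknown sites: {sorted(unknown)}")
        for site in valid_sites:
            # Assumed rule: validation folds must be disjoint.
            if site in seen:
                raise ValueError(f"site {site!r} appears in folds {seen[site]!r} and {name!r}")
            seen[site] = name
        held_out = set(valid_sites)
        valid_mask = [s in held_out for s in site_ids]
        train_mask = [not v for v in valid_mask]
        folds.append((name, train_mask, valid_mask))
    return folds


sites = ["A", "B", "C", "D"]
for name, train_mask, valid_mask in predefined_site_folds(sites, {"north": ["A", "B"], "south": ["C", "D"]}):
    print(name, valid_mask)
# north [True, True, False, False]
# south [False, False, True, True]
```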
- classmethod temporal_block(site_ids, time_values, n_folds, train_start, train_end)
Temporal CV by contiguous time blocks within the training window.
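A minimal sketch of contiguous temporal blocks, assuming time_values is sorted and comparable to train_start and train_end (integer years here). The in-window time steps are cut into n_folds consecutive blocks, each block is held out once, training uses the rest of the window, and steps outside the window belong to neither split. Returning a (first, last) pair per fold mirrors CVFold.valid_time_range only by assumption; temporal_block_folds is a hypothetical name.

```python
from typing import List, Sequence, Tuple


# Hypothetical sketch; assumes time_values is sorted ascending.
def temporal_block_folds(
    time_values: Sequence[int], n_folds: int, train_start: int, train_end: int
) -> List[Tuple[List[bool], List[bool], Tuple[int, int]]]:
    """Split the in-window time steps into n_folds contiguous validation blocks."""
    window = [i for i, t in enumerate(time_values) if train_start <= t <= train_end]
    if not 2 <= n_folds <= len(window):
        raise ValueError("n_folds must be between 2 and the number of time steps in the window")
    in_window = set(window)
    base, extra = divmod(len(window), n_folds)
    folds, start = [], 0
    for k in range(n_folds):
        stop = start + base + (1 if k < extra else 0)
        block = set(window[start:stop])
        valid_mask = [i in block for i in range(len(time_values))]
        # Training uses the rest of the window; steps outside [train_start, train_end] stay unused.
        train_mask = [i in in_window and i not in block for i in range(len(time_values))]
        valid_range = (time_values[window[start]], time_values[window[stop - 1]])
        folds.append((train_mask, valid_mask, valid_range))
        start = stop
    return folds


years = [2000, 2001, 2002, 2003, 2004, 2005]
for train_mask, valid_mask, valid_range in temporal_block_folds(years, 2, 2001, 2004):
    print(valid_range, train_mask)
# (2001, 2002) [False, False, False, True, True, False]
# (2003, 2004) [False, True, True, False, False, False]
```

Holding out contiguous blocks instead of scattered time steps keeps validation periods from being interleaved with training periods, which reduces leakage through temporal autocorrelation.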
- classmethod from_config(config, site_ids, scheme=None, time_values=None, n_folds=None, fold_definition=None, shuffle=True, seed=0)
Build a cross-validator using the existing runtime config time window.