model.runtime.configs.site_learning_config
Site-learning runtime configuration.
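Module Contents
Classes
- TargetVariableConfig: One observed target variable under site_learning.targets.<name>.
- TargetConfig: All observed target variables under the site_learning.targets block.
- TrainingConfig: Optimizer and training-schedule settings under site_learning.training.
- CrossValidationConfig: Cross-validation protocol under site_learning.cross_validation.
- SiteLearningInitializationConfig: NN-weight initialization under site_learning.initialization.
- SiteLearningConfig: Top-level site_learning block driving parameter optimization.
API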
- class model.runtime.configs.site_learning_config.TargetVariableConfig
Bases: model.runtime.configs.base_config.BaseConfig
One observed target variable under site_learning.targets.<name>.
Attributes
mapping : str or list
    Name of the variable in the target NetCDF, optionally given as [name, scale].
layer : int, optional
    1-based soil layer to select when the model output is layer-wise.
sample_loss_weight : float
    Relative weight of this target's sample-level loss in the total loss.
site_loss_weight : float
    Weight of this target's site-level (per-site mean) loss; 0 disables it.
- mapping: object
None
- layer: object
None
- sample_loss_weight: float
1.0
- site_loss_weight: float
1.0
- classmethod from_raw(name, raw)
- validate()
- apply_layer_selection(tensor, label)
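The two forms of mapping and the 1-based layer convention can be read as in the sketch below. It assumes that a bare string means an unscaled variable and that the second element of [name, scale] is a numeric scale factor; parse_mapping and layer_to_index are hypothetical helpers, not part of this module, and the module's own handling (for example in from_raw or apply_layer_selection) may differ.

```python
from typing import Optional, Sequence, Tuple, Union

# Hypothetical helpers illustrating the documented semantics of
# TargetVariableConfig.mapping and TargetVariableConfig.layer.
# They are NOT the module's implementation.


def parse_mapping(mapping: Union[str, Sequence]) -> Tuple[str, float]:
    """Return (netcdf_variable_name, scale) from 'name' or [name, scale]."""
    if isinstance(mapping, str):
        return mapping, 1.0  # assumption: a bare name means no rescaling
    if len(mapping) != 2:
        raise ValueError(f"mapping must be a name or [name, scale], got {mapping!r}")
    name, scale = mapping
    return str(name), float(scale)


def layer_to_index(layer: Optional[int]) -> Optional[int]:
    """Convert the 1-based soil layer to a 0-based array index (None = no selection)."""
    if layer is None:
        return None
    if layer < 1:
        raise ValueError("layer is 1-based and must be >= 1")
    return layer - 1
```

For example, parse_mapping(["GPP", 0.001]) yields ("GPP", 0.001), and layer 1 (the top soil layer) maps to index 0 of a layer-wise output.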
- class model.runtime.configs.site_learning_config.TargetConfig
Bases: model.runtime.configs.base_config.BaseConfig
All observed target variables under the site_learning.targets block.
Each key is a target name mapped to a TargetVariableConfig. Sample-loss weights are normalised across targets, so their relative magnitudes, not their absolute scale, determine the multi-objective balance.
- specs: dict[str, model.runtime.configs.site_learning_config.TargetVariableConfig]
field(…)
- classmethod from_dict(raw)
- property normalized_sample_loss_weights
- property variable_names
- property variables
- property loader_mapping
- validate()
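The normalisation described above means that scaling every weight by the same factor leaves the balance unchanged. A minimal sketch, assuming normalisation divides each weight by the sum of all sample-loss weights (normalize_sample_loss_weights and total_sample_loss are hypothetical helpers; the exact scheme behind the normalized_sample_loss_weights property is not documented here):

```python
from typing import Dict


def normalize_sample_loss_weights(weights: Dict[str, float]) -> Dict[str, float]:
    """Rescale per-target sample-loss weights so they sum to 1.

    Assumption: dividing by the sum is the normalisation used; it makes only
    the relative magnitudes of the weights matter.
    """
    total = sum(weights.values())
    if total <= 0:
        raise ValueError("sample-loss weights must sum to a positive value")
    return {name: w / total for name, w in weights.items()}


def total_sample_loss(losses: Dict[str, float], weights: Dict[str, float]) -> float:
    """Weighted sum of per-target sample losses using the normalised weights."""
    norm = normalize_sample_loss_weights(weights)
    return sum(norm[name] * losses[name] for name in norm)
```

Under this assumption, weights {"GPP": 2.0, "ET": 1.0} and {"GPP": 20.0, "ET": 10.0} both give GPP two thirds of the sample-loss budget.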
- class model.runtime.configs.site_learning_config.TrainingConfig
Bases: model.runtime.configs.base_config.BaseConfig
Optimizer and training-schedule settings under site_learning.training.
Attributes
num_epochs, lr, weight_decay, seed : int / float
    Core settings for training with the Adam optimizer: epoch count, learning rate, weight decay, and random seed.
train_chunk_size : int, optional
    Number of timesteps per backpropagation chunk; None uses the full record. Smaller chunks reduce memory use and backpropagation depth.
max_grad_norm : float
    Gradient-norm clipping threshold.
val_within_train_enabled, val_fraction : bool / float
    Hold out a fraction of the training data for in-training validation.
early_stopping_* :
    Patience-based early stopping on the validation metric.
reduce_lr_* :
    ReduceLROnPlateau schedule parameters.
min_sites_for_site_loss, min_samples_per_site_for_site_loss : int
    Thresholds below which the site-level loss is skipped (too few sites, or too few samples per site, to form a stable per-site mean).
- num_epochs: int
10
- lr: float
0.001
- seed: int
42
- train_chunk_size: object
None
- max_grad_norm: float
1.0
- weight_decay: float
0.0001
- debug: bool
False
- val_within_train_enabled: bool
True
- val_fraction: float
0.3
- early_stopping_enabled: bool
True
- early_stopping_patience: int
10
- early_stopping_min_delta: float
0.0
- reduce_lr_enabled: bool
True
- reduce_lr_patience: int
5
- reduce_lr_factor: float
0.5
- reduce_lr_min_delta: float
0.0
- reduce_lr_min_lr: float
1e-06
- min_sites_for_site_loss: int
10
- min_samples_per_site_for_site_loss: int
365
- classmethod from_dict(raw)
- validate()
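Two of the training settings above control how much data enters each loss term. The sketch below shows one reading of them: chunk_bounds splits a record into train_chunk_size-long backpropagation chunks, and site_loss_enabled applies the site-loss thresholds. Both functions are hypothetical; the gating rule (drop sites with too few samples, then skip the term if too few sites remain) is an assumption about how the two thresholds combine.

```python
from typing import Dict, List, Optional, Tuple


def chunk_bounds(n_timesteps: int, train_chunk_size: Optional[int]) -> List[Tuple[int, int]]:
    """Split [0, n_timesteps) into consecutive (start, stop) backprop chunks.

    train_chunk_size=None means a single chunk covering the full record;
    the last chunk may be shorter than train_chunk_size.
    """
    if train_chunk_size is None:
        return [(0, n_timesteps)]
    if train_chunk_size <= 0:
        raise ValueError("train_chunk_size must be positive or None")
    return [
        (start, min(start + train_chunk_size, n_timesteps))
        for start in range(0, n_timesteps, train_chunk_size)
    ]


def site_loss_enabled(
    samples_per_site: Dict[str, int],
    min_sites_for_site_loss: int = 10,
    min_samples_per_site_for_site_loss: int = 365,
) -> bool:
    """Decide whether the site-level (per-site mean) loss term is computed.

    Assumption: sites with fewer than min_samples_per_site_for_site_loss
    samples are ignored, and the term is skipped when fewer than
    min_sites_for_site_loss sites remain.
    """
    eligible = [
        site
        for site, n in samples_per_site.items()
        if n >= min_samples_per_site_for_site_loss
    ]
    return len(eligible) >= min_sites_for_site_loss
```

In a PyTorch loop, max_grad_norm and the reduce_lr_* fields would map naturally onto torch.nn.utils.clip_grad_norm_(params, max_norm) and torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=..., patience=..., threshold=..., min_lr=...); whether the runner uses these calls, and how reduce_lr_min_delta maps to threshold, is an assumption.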
- class model.runtime.configs.site_learning_config.CrossValidationConfig
Bases: model.runtime.configs.base_config.BaseConfig
Cross-validation protocol under site_learning.cross_validation.
Attributes
enabled : bool
    Whether to run cross-validation rather than a single train/validation split.
scheme : str or None
    'spatial' (split by site) or 'temporal' (split by time period).
n_folds : int
    Number of CV folds.
cv_seed, shuffle : int / bool
    Control random fold assignment in the 'random' spatial mode.
spatial_mode : str
    'random' for randomly assigned folds, or 'predefined' to read fold membership from spatial_fold_path.
spatial_fold_path : str, optional
    Path to a fold-definition file; required when spatial_mode is 'predefined'.
- enabled: bool
False
- scheme: object
None
- n_folds: int
5
- cv_seed: int
42
- spatial_mode: str
'random'
- spatial_fold_path: object
None
- shuffle: bool
True
- classmethod from_dict(raw)
- validate()
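For the 'random' spatial mode, fold assignment driven by n_folds, cv_seed, and shuffle might look like the sketch below. assign_spatial_folds and iter_splits are hypothetical; dealing sites round-robin after a seeded shuffle is one common choice, not necessarily the module's. Because it uses Python's random.Random, the fold membership will not match a NumPy- or PyTorch-based implementation for the same seed.

```python
import random
from typing import Iterator, List, Sequence, Tuple


def assign_spatial_folds(
    site_ids: Sequence[str], n_folds: int = 5, cv_seed: int = 42, shuffle: bool = True
) -> List[List[str]]:
    """Partition sites into n_folds disjoint folds (spatial scheme, 'random' mode).

    Shuffles a copy of the site list with a seeded RNG, then deals sites
    round-robin so fold sizes differ by at most one.
    """
    if n_folds < 2:
        raise ValueError("n_folds must be at least 2")
    if n_folds > len(site_ids):
        raise ValueError("n_folds cannot exceed the number of sites")
    sites = list(site_ids)
    if shuffle:
        random.Random(cv_seed).shuffle(sites)
    return [sites[k::n_folds] for k in range(n_folds)]


def iter_splits(folds: List[List[str]]) -> Iterator[Tuple[List[str], List[str]]]:
    """Yield (train_sites, held_out_sites), holding out each fold exactly once."""
    for k, held_out in enumerate(folds):
        train = [s for j, fold in enumerate(folds) if j != k for s in fold]
        yield train, held_out
```

Splitting by site rather than by row keeps every observation from a held-out site out of training, which is what makes the spatial scheme a test of generalisation to unseen locations.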
- class model.runtime.configs.site_learning_config.SiteLearningInitializationConfig
Bases: model.runtime.configs.base_config.BaseConfig
NN-weight initialization under site_learning.initialization.
Attributes
init_nn_weights_path : str, optional
    Checkpoint of pretrained NN weights to warm-start from.
frozen_parameters : list[str]
    Names of learnable parameters to hold fixed during training.
- init_nn_weights_path: object
None
- frozen_parameters: list[str]
field(…)
- classmethod from_dict(raw)
- validate()
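Holding parameters fixed starts by splitting the learnable names into trainable and frozen sets. A minimal sketch: split_frozen is hypothetical, and rejecting names that match no learnable parameter is an assumed safeguard (so a typo cannot silently leave a parameter trainable), not documented behaviour of validate().

```python
from typing import Iterable, List, Tuple


def split_frozen(
    parameter_names: Iterable[str], frozen_parameters: Iterable[str]
) -> Tuple[List[str], List[str]]:
    """Partition learnable-parameter names into (trainable, frozen), preserving order.

    Assumption: any name in frozen_parameters that matches no learnable
    parameter is treated as a configuration error.
    """
    names = list(parameter_names)
    frozen = set(frozen_parameters)
    unknown = frozen.difference(names)
    if unknown:
        raise ValueError(f"unknown frozen_parameters: {sorted(unknown)}")
    trainable = [n for n in names if n not in frozen]
    held = [n for n in names if n in frozen]
    return trainable, held
```

If the learnable parameters are PyTorch tensors, the frozen ones would then typically get requires_grad=False (or be left out of the optimizer's parameter list); which names count as "learnable parameters" here is an assumption.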
- class model.runtime.configs.site_learning_config.SiteLearningConfig
Bases: model.runtime.configs.base_config.BaseConfig
Top-level site_learning block driving parameter optimization.
Groups the site selection (domain), time window (time), spinup, observed targets, training schedule, cross-validation protocol, and NN initialization into the configuration consumed by SiteLearningRunner.
Attributes
output_dir : str, optional
    Directory for learned weights, summaries, and final inference output.
targets_path : str, optional
    Path to the NetCDF file holding the observed target variables.
save_final_inference : bool
    Whether to write a full forward simulation after training completes.
- domain: model.runtime.configs.site_simulation_config.DomainConfig
field(…)
- time: model.runtime.configs.site_simulation_config.SiteTimeConfig
field(…)
- spinup: model.runtime.configs.site_simulation_config.SiteSpinupConfig
field(…)
- output_dir: object
None
- targets_path: object
None
- targets: model.runtime.configs.site_learning_config.TargetConfig
field(…)
- training: model.runtime.configs.site_learning_config.TrainingConfig
field(…)
- cross_validation: model.runtime.configs.site_learning_config.CrossValidationConfig
field(…)
- save_final_inference: bool
True
- initialization: model.runtime.configs.site_learning_config.SiteLearningInitializationConfig
field(…)
- classmethod from_dict(raw)
- validate()
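The nesting of the site_learning block, with each sub-block falling back to the defaults listed above, can be illustrated with a reduced stand-in. Everything named Mini* and _overlay is hypothetical and covers only a subset of fields; rejecting unknown keys is an assumption, while the defaults and the rule that spatial_fold_path is required when spatial_mode is 'predefined' come from the attribute documentation above.

```python
from dataclasses import dataclass, field, fields
from typing import Any, Dict, Optional


def _overlay(config_type, raw: Optional[Dict[str, Any]]):
    """Build a dataclass from its defaults overlaid with raw; reject unknown keys."""
    raw = dict(raw or {})
    unknown = set(raw) - {f.name for f in fields(config_type)}
    if unknown:
        raise ValueError(f"{config_type.__name__}: unknown keys {sorted(unknown)}")
    return config_type(**raw)


@dataclass
class MiniTrainingConfig:
    # Subset of the documented TrainingConfig defaults.
    num_epochs: int = 10
    lr: float = 0.001
    train_chunk_size: Optional[int] = None
    early_stopping_patience: int = 10


@dataclass
class MiniCrossValidationConfig:
    # Subset of the documented CrossValidationConfig defaults.
    enabled: bool = False
    n_folds: int = 5
    spatial_mode: str = "random"
    spatial_fold_path: Optional[str] = None

    def validate(self) -> None:
        if self.spatial_mode not in ("random", "predefined"):
            raise ValueError(f"unknown spatial_mode {self.spatial_mode!r}")
        if self.spatial_mode == "predefined" and self.spatial_fold_path is None:
            raise ValueError("spatial_fold_path is required when spatial_mode is 'predefined'")


@dataclass
class MiniSiteLearningConfig:
    output_dir: Optional[str] = None
    targets_path: Optional[str] = None
    save_final_inference: bool = True
    training: MiniTrainingConfig = field(default_factory=MiniTrainingConfig)
    cross_validation: MiniCrossValidationConfig = field(default_factory=MiniCrossValidationConfig)

    @classmethod
    def from_dict(cls, raw: Dict[str, Any]) -> "MiniSiteLearningConfig":
        raw = dict(raw)
        # Build nested blocks first; missing blocks fall back to their defaults.
        nested = {
            "training": _overlay(MiniTrainingConfig, raw.pop("training", None)),
            "cross_validation": _overlay(
                MiniCrossValidationConfig, raw.pop("cross_validation", None)
            ),
        }
        return _overlay(cls, {**raw, **nested})

    def validate(self) -> None:
        self.cross_validation.validate()
```

For instance, MiniSiteLearningConfig.from_dict({"training": {"num_epochs": 50}}) keeps lr at 0.001 and n_folds at 5, while a cross_validation block with spatial_mode 'predefined' and no spatial_fold_path fails validate().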