model.runtime.runners.site_learning_runner
Site-learning workflow runner for ADELM.
Module Contents
Classes
SiteLearningRunner | Parameter-learning workflow runner.
Data
API
- class model.runtime.runners.site_learning_runner.SiteLearningRunner(config, model, site_ids=None, time_values=None, data=None)
Bases: model.runtime.runners.base_runner.BaseRunner
Parameter-learning workflow runner.
Initialization
- class EvaluationResult
Outcome of evaluate(): total and per-target losses, plus model outputs and final states.
- loss: torch.Tensor
- per_target_loss: dict
- per_target_sample_loss: dict
- per_target_site_loss: dict
- outputs: dict
- final_states: dict
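Code that consumes an EvaluationResult only needs its documented field names. The following is a minimal stand-in with the same fields, assuming the class behaves like a plain dataclass. EvaluationResultSketch and worst_target are hypothetical names, loss is a float rather than a torch.Tensor so the sketch runs without PyTorch, and the target names used in the usage lines are illustrative only.

```python
from dataclasses import dataclass, field
from typing import Any, Dict


@dataclass
class EvaluationResultSketch:
    """Hypothetical stand-in mirroring the documented EvaluationResult fields.

    In the real class `loss` is a torch.Tensor; a float is used here so the
    sketch runs without PyTorch.
    """

    loss: float
    per_target_loss: Dict[str, float] = field(default_factory=dict)
    per_target_sample_loss: Dict[str, Any] = field(default_factory=dict)
    per_target_site_loss: Dict[str, Any] = field(default_factory=dict)
    outputs: Dict[str, Any] = field(default_factory=dict)
    final_states: Dict[str, Any] = field(default_factory=dict)


def worst_target(result: EvaluationResultSketch) -> str:
    """Return the name of the target with the largest per-target loss."""
    return max(result.per_target_loss, key=result.per_target_loss.get)


# Usage with illustrative (hypothetical) target names.
example = EvaluationResultSketch(loss=0.3, per_target_loss={"gpp": 0.1, "et": 0.5})
print(worst_target(example))  # prints "et"
```

Ranking targets by per_target_loss is one way to see which observation stream dominates the total loss after a validation or test run.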
- property workflow_config
- runtime_output_dir()
- nn_training_config()
- site_mapping()
- site_targets_path()
- load_data()
- evaluate(drivers=None, targets=None, initial_states=None, target_map=None, sample_loss_fn=sample_r2_loss, site_loss_fn=site_r2_loss, stage='Validation')
Run a no-gradient forward pass and score it against observed targets.
Used for held-out validation and test evaluation. Computes the same weighted, per-target combined loss as train(), but without backpropagation, and appends a summary to the runtime log.
Parameters
drivers, targets : dict, optional
    Dynamic drivers and observation targets; default to the loaded data.
sample_loss_fn, site_loss_fn : callable
    Sample-level and site-level loss functions.
stage : str
    Label for the runtime-log summary (e.g. "Validation").
Returns
EvaluationResult
    Total loss, per-target losses, model outputs, and final states.
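The weighted, per-target combination that evaluate shares with train can be sketched in plain Python. This is a minimal illustration under stated assumptions, not the runner's implementation: r2_loss is a hypothetical reading of an R^2-based loss as 1 - R^2, "normalized" weights are assumed to mean rescaled to sum to 1, and equal weights are assumed when none are given. In PyTorch, a no-gradient pass such as the one evaluate performs is typically wrapped in torch.no_grad(); that part is omitted so the sketch runs without PyTorch.

```python
from typing import Dict, Optional, Sequence


def r2_loss(pred: Sequence[float], obs: Sequence[float]) -> float:
    """Hypothetical 1 - R^2 loss: SS_res / SS_tot (0 means a perfect fit)."""
    if len(pred) != len(obs) or not obs:
        raise ValueError("pred and obs must be non-empty and of equal length")
    mean_obs = sum(obs) / len(obs)
    ss_res = sum((o - p) ** 2 for p, o in zip(pred, obs))
    ss_tot = sum((o - mean_obs) ** 2 for o in obs)
    if ss_tot == 0:
        raise ValueError("observations have zero variance; R^2 is undefined")
    return ss_res / ss_tot  # 1 - R^2, since R^2 = 1 - SS_res / SS_tot


def normalize_weights(weights: Dict[str, float]) -> Dict[str, float]:
    """Rescale per-target weights so they sum to 1 (assumed meaning of 'normalized')."""
    total = sum(weights.values())
    if total <= 0:
        raise ValueError("weights must sum to a positive value")
    return {name: w / total for name, w in weights.items()}


def combined_loss(
    per_target: Dict[str, float], weights: Optional[Dict[str, float]] = None
) -> float:
    """Weighted sum of per-target losses; equal weights if none are given (assumption)."""
    if weights is None:
        weights = {name: 1.0 for name in per_target}
    w = normalize_weights(weights)
    return sum(w[name] * loss for name, loss in per_target.items())
```

For example, with per-target losses {"gpp": 0.0, "et": 1.0} and raw weights {"gpp": 3.0, "et": 1.0}, normalization gives 0.75 and 0.25, so the combined loss is 0.25. The target names are illustrative.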
- train(drivers=None, targets=None, optimizer=None, epochs=1, train_chunk_size=None, initial_states=None, target_map=None, sample_loss_fn=sample_r2_loss, site_loss_fn=site_r2_loss, loss_weights=None, max_grad_norm=None, skip_nan_grads=True, stop_on_nan=True, show_progress=False, progress_desc=None, show_spinup_message=True, debug=False, debug_nonfinite=False, epoch_offset=0)
Optimize the learnable parameters against observed targets.
Runs epochs passes over the time series, each split into chunks of train_chunk_size timesteps to bound backpropagation depth. For each chunk, the model is run forward, the per-target losses are weighted and combined, and the gradients are backpropagated, optionally clipped, and applied. Optional spinup cycles warm-start the prognostic states (without gradient tracking) at the start of each epoch. Non-finite losses or gradients can optionally be diagnosed in a debug file; depending on stop_on_nan and skip_nan_grads, the run is then either aborted or the offending chunk is skipped.
Parameters
drivers, targets : dict, optional
    Dynamic drivers and observation targets; default to the loaded data.
optimizer : torch.optim.Optimizer
    Required. Updates the learnable NN parameters.
epochs : int
    Number of passes over the training record.
train_chunk_size : int, optional
    Timesteps per backpropagation chunk; None uses the full record.
sample_loss_fn, site_loss_fn : callable
    Sample-level and site-level loss functions (see _target_losses()).
loss_weights : dict, optional
    Per-target weights; default to the normalized weights from the config.
max_grad_norm : float, optional
    Gradient-norm clipping threshold; None disables clipping.
skip_nan_grads, stop_on_nan : bool
    Behaviour on non-finite gradients and on a non-finite loss, respectively.
debug_nonfinite : bool
    Write a detailed diagnostic file when a non-finite value is encountered.
Returns
list[dict]
    Per-epoch history records (loss, per-target losses, gradient norm, final states), including any skipped-chunk markers.
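Two mechanics from the train description, splitting the record into backpropagation chunks and clipping or skipping on the gradient norm, can be sketched without PyTorch. This is a minimal illustration under assumptions, not the runner's code: chunk_bounds and clip_or_skip are hypothetical helpers, gradients are modelled as a flat list of floats, and a non-finite norm is treated as a signal to skip the chunk (the skip_nan_grads=True behaviour as described). On real tensors, torch.nn.utils.clip_grad_norm_ performs the clipping step and returns the total norm.

```python
import math
from typing import List, Optional, Sequence, Tuple


def chunk_bounds(n_timesteps: int, chunk_size: Optional[int]) -> List[Tuple[int, int]]:
    """Split [0, n_timesteps) into consecutive (start, stop) windows.

    chunk_size=None yields a single window covering the whole record,
    mirroring the documented meaning of train_chunk_size=None.
    """
    if n_timesteps <= 0:
        return []
    if chunk_size is None:
        return [(0, n_timesteps)]
    if chunk_size <= 0:
        raise ValueError("chunk_size must be positive")
    return [
        (start, min(start + chunk_size, n_timesteps))
        for start in range(0, n_timesteps, chunk_size)
    ]


def clip_or_skip(
    grads: Sequence[float], max_norm: Optional[float]
) -> Tuple[Optional[List[float]], float]:
    """Clip gradients to a global L2 norm, or flag the chunk for skipping.

    Returns (grads_to_apply, norm). grads_to_apply is None when the norm is
    NaN or infinite, meaning the optimizer step for this chunk is skipped.
    max_norm=None disables clipping.
    """
    norm = math.sqrt(sum(g * g for g in grads))
    if not math.isfinite(norm):
        return None, norm
    if max_norm is None or norm <= max_norm:
        return list(grads), norm
    scale = max_norm / norm  # rescale so the clipped norm equals max_norm
    return [g * scale for g in grads], norm
```

With 10 timesteps and a chunk size of 4, chunk_bounds yields (0, 4), (4, 8), (8, 10); the short final chunk keeps every timestep in play. Shorter chunks cap how far gradients propagate back in time, trading long-range credit assignment for memory and stability.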
- model.runtime.runners.site_learning_runner.__all__
['SiteLearningRunner']