model.runtime.runners.site_learning_runner

Site-learning workflow runner for ADELM.

Module Contents

Classes

SiteLearningRunner

Parameter-learning workflow runner.

Data

API

class model.runtime.runners.site_learning_runner.SiteLearningRunner(config, model, site_ids=None, time_values=None, data=None)

Bases: model.runtime.runners.base_runner.BaseRunner

Parameter-learning workflow runner.

Initialization

class EvaluationResult

Outcome of evaluate(): total and per-target losses plus outputs.

loss: torch.Tensor

per_target_loss: dict

per_target_sample_loss: dict

per_target_site_loss: dict

outputs: dict

final_states: dict

property workflow_config

runtime_output_dir()

nn_training_config()

site_mapping()

site_targets_path()

load_data()

evaluate(drivers=None, targets=None, initial_states=None, target_map=None, sample_loss_fn=sample_r2_loss, site_loss_fn=site_r2_loss, stage='Validation')

Run a no-gradient forward pass and score it against observed targets.

Used for held-out validation and test evaluation. Computes the same weighted, per-target combined loss as train() but without backpropagation, and appends a summary to the runtime log.

Parameters

drivers, targets : dict, optional
    Dynamic drivers and observation targets; default to the loaded data.
sample_loss_fn, site_loss_fn : callable
    Sample-level and site-level loss functions.
stage : str
    Label for the runtime-log summary (e.g. "Validation").

Returns

EvaluationResult
    Total loss, per-target losses, model outputs, and final states.
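The weighted, per-target combination described above can be sketched in plain Python. This is an illustration under stated assumptions, not the runner's implementation: the loss is taken to be 1 - R² (suggested by the default names sample_r2_loss and site_r2_loss, whose actual definitions are not shown here), the weights are normalized to sum to one, and r2_loss, combined_loss, and the target names "gpp" and "et" are hypothetical.

```python
def r2_loss(pred, obs):
    """Assumed loss: 1 - R^2, i.e. residual over total sum of squares.

    Requires obs to vary (a constant record would give a zero denominator).
    """
    mean_obs = sum(obs) / len(obs)
    ss_res = sum((p - o) ** 2 for p, o in zip(pred, obs))
    ss_tot = sum((o - mean_obs) ** 2 for o in obs)
    return ss_res / ss_tot


def combined_loss(outputs, targets, weights=None):
    """Return (total, per_target) using weights normalized to sum to 1."""
    names = list(targets)
    if weights is None:
        # Hypothetical default: every target counts equally.
        weights = {name: 1.0 for name in names}
    total_weight = sum(weights[name] for name in names)
    per_target = {name: r2_loss(outputs[name], targets[name]) for name in names}
    total = sum(weights[name] / total_weight * per_target[name] for name in names)
    return total, per_target


# Hypothetical targets: "gpp" is reproduced exactly, "et" misses one value.
targets = {"gpp": [1.0, 2.0, 3.0], "et": [2.0, 4.0, 6.0]}
outputs = {"gpp": [1.0, 2.0, 3.0], "et": [2.0, 4.0, 7.0]}
total, per_target = combined_loss(outputs, targets)
```

Because no gradients are involved, the same arithmetic serves for validation and test scoring; only the data passed in changes.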

train(drivers=None, targets=None, optimizer=None, epochs=1, train_chunk_size=None, initial_states=None, target_map=None, sample_loss_fn=sample_r2_loss, site_loss_fn=site_r2_loss, loss_weights=None, max_grad_norm=None, skip_nan_grads=True, stop_on_nan=True, show_progress=False, progress_desc=None, show_spinup_message=True, debug=False, debug_nonfinite=False, epoch_offset=0)

Optimize the learnable parameters against observed targets.

Makes the number of passes given by epochs over the time series, splitting each pass into chunks of train_chunk_size timesteps to bound backpropagation depth. For each chunk, the model is run forward, the per-target losses are weighted and combined, and the gradients are backpropagated, optionally clipped, and applied. Optional spinup cycles warm-start the prognostic states (no_grad) at the start of each epoch. When a non-finite loss or gradient occurs, it is optionally diagnosed to a debug file, and the run either aborts or skips the chunk.

Parameters

drivers, targets : dict, optional
    Dynamic drivers and observation targets; default to the loaded data.
optimizer : torch.optim.Optimizer
    Required. Updates the learnable NN parameters.
epochs : int
    Number of passes over the training record.
train_chunk_size : int, optional
    Timesteps per backprop chunk; None uses the full record.
sample_loss_fn, site_loss_fn : callable
    Sample-level and site-level loss functions (see _target_losses()).
loss_weights : dict, optional
    Per-target weights; default to the normalized config weights.
max_grad_norm : float, optional
    Gradient-norm clipping threshold; None disables clipping.
skip_nan_grads, stop_on_nan : bool
    Behaviour on non-finite gradients / loss.
debug_nonfinite : bool
    Write a detailed diagnostic file when a non-finite value is hit.

Returns

list[dict]
    Per-epoch history records (loss, per-target losses, grad norm, final states), including any skipped-chunk markers.
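The chunking, clipping, and skip-on-non-finite steps described above can be illustrated with a toy loop in plain Python. It is a minimal sketch, not the runner's code: a one-parameter model y = w * x with an analytic gradient stands in for the torch model and backpropagation, and chunk_bounds, train_sketch, the skip_nonfinite flag, and the history keys are all hypothetical names chosen for the example.

```python
import math


def chunk_bounds(n_steps, chunk_size=None):
    """Return (start, stop) pairs covering n_steps; None means one chunk."""
    if chunk_size is None:
        return [(0, n_steps)]
    if chunk_size < 1:
        raise ValueError("chunk_size must be a positive integer")
    return [(start, min(start + chunk_size, n_steps))
            for start in range(0, n_steps, chunk_size)]


def train_sketch(xs, ys, w=0.0, lr=0.01, epochs=1, chunk_size=None,
                 max_grad_norm=None, skip_nonfinite=True):
    """Fit y = w * x by chunked gradient descent; return (w, history)."""
    history = []
    for epoch in range(epochs):
        record = {"epoch": epoch, "loss": 0.0, "skipped_chunks": []}
        for start, stop in chunk_bounds(len(xs), chunk_size):
            x_chunk, y_chunk = xs[start:stop], ys[start:stop]
            n = stop - start
            # Forward pass and mean-squared-error loss on this chunk only.
            errors = [w * x - y for x, y in zip(x_chunk, y_chunk)]
            loss = sum(e * e for e in errors) / n
            # Analytic d(loss)/dw stands in for backpropagation.
            grad = sum(2.0 * e * x for e, x in zip(errors, x_chunk)) / n
            if not (math.isfinite(loss) and math.isfinite(grad)):
                if skip_nonfinite:
                    # Leave w untouched and mark the chunk as skipped.
                    record["skipped_chunks"].append((start, stop))
                    continue
                raise FloatingPointError(f"non-finite value in chunk {start}:{stop}")
            # Norm clipping: rescale the gradient if it exceeds the threshold.
            if max_grad_norm is not None and abs(grad) > max_grad_norm:
                grad *= max_grad_norm / abs(grad)
            w -= lr * grad
            record["loss"] += loss
        history.append(record)
    return w, history


# One observation is NaN, so the first chunk is skipped in every epoch.
xs = [1.0, 2.0, 3.0, 4.0]
ys = [2.0, float("nan"), 6.0, 8.0]
w_fit, history = train_sketch(xs, ys, epochs=50, chunk_size=2)
```

Updating per chunk rather than per full record keeps each backward pass short, at the cost of more, noisier updates per epoch.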

model.runtime.runners.site_learning_runner.__all__

['SiteLearningRunner']