"""Experiment abstraction for optional custom training/eval steps."""
from __future__ import annotations
from collections.abc import Callable
from typing import Any
import torch
class Experiment:
    """Default experiment for supervised one-step training.

    Bundles a model, optimizer, loss function and data adapter, and exposes
    the per-batch / per-epoch hooks a training runner calls. Subclasses
    override individual hooks to customise training, validation, metrics or
    plotting.

    The adapter is used through three methods:

    * ``build_batch(batch, device)`` -> a prepared batch;
    * ``forward_train(model, prepared)`` -> ``(pred, target)`` or
      ``(pred, target, weight, ...)``;
    * ``forward_eval(model, prepared)`` -> returned unchanged by ``eval_step``.
    """

    def __init__(
        self,
        model,
        optimizer: torch.optim.Optimizer | None,
        loss_fn: Callable[..., torch.Tensor] | None,
        adapter,
        device: torch.device,
        **kwargs,
    ):
        """Store the training components.

        Args:
            model: Network being trained/evaluated. ``train()`` and ``eval()``
                are called on it (presumably a ``torch.nn.Module``).
            optimizer: Optimizer stepped by ``training_step``. May be ``None``
                when only evaluation hooks are used.
            loss_fn: Called as ``loss_fn(pred, target)`` or
                ``loss_fn(pred, target, weight)`` depending on how many items
                ``adapter.forward_train`` returns. May be ``None`` when only
                evaluation hooks are used.
            adapter: Object providing ``build_batch``, ``forward_train`` and
                ``forward_eval``.
            device: Device handed to ``adapter.build_batch``.
            **kwargs: Accepted and ignored by this base class (presumably so
                subclass-specific options can flow through one constructor
                call).
        """
        self.model = model
        self.optimizer = optimizer
        self.loss_fn = loss_fn
        self.adapter = adapter
        self.device = device

    def _loss_from_result(self, result) -> torch.Tensor:
        """Apply ``loss_fn`` to the output of ``adapter.forward_train``.

        ``result`` is ``(pred, target)`` or ``(pred, target, weight, ...)``.
        When a third item is present it is passed to ``loss_fn`` as the
        weight; any items after the third are ignored.

        Raises:
            ValueError: If ``result`` has fewer than two items.
        """
        if len(result) >= 3:
            return self.loss_fn(result[0], result[1], result[2])
        if len(result) != 2:
            raise ValueError(
                "adapter.forward_train must return (pred, target) or "
                f"(pred, target, weight); got {len(result)} item(s)."
            )
        pred, target = result
        return self.loss_fn(pred, target)

    def training_step(self, batch) -> float:
        """Run one optimisation step on ``batch`` and return the loss.

        Puts the model in training mode, builds the batch through the
        adapter, computes the loss, back-propagates and steps the optimizer.

        Raises:
            RuntimeError: If ``optimizer`` or ``loss_fn`` is ``None``.
        """
        if self.optimizer is None or self.loss_fn is None:
            raise RuntimeError("Experiment.training_step requires optimizer and loss_fn.")
        self.model.train()
        prepared = self.adapter.build_batch(batch, self.device)
        loss = self._loss_from_result(self.adapter.forward_train(self.model, prepared))
        self.optimizer.zero_grad(set_to_none=True)
        loss.backward()
        self.optimizer.step()
        return float(loss.detach().cpu())

    def eval_step(self, batch) -> tuple[torch.Tensor, torch.Tensor]:
        """Run the evaluation forward pass on ``batch``.

        Puts the model in eval mode and returns whatever
        ``adapter.forward_eval`` returns. Gradient tracking is not disabled
        here; callers that do not need autograd can wrap the call in
        ``torch.no_grad()``.
        """
        self.model.eval()
        prepared = self.adapter.build_batch(batch, self.device)
        return self.adapter.forward_eval(self.model, prepared)

    def validation_step(self, batch) -> float:
        """Compute validation loss for one batch.

        Override in subclasses for custom validation logic (e.g., weighted
        metrics, physics-informed losses). The default puts the model in eval
        mode, runs ``adapter.forward_train`` and applies ``loss_fn`` to its
        ``(pred, target[, weight])`` output, all under ``torch.no_grad()``.
        The loss is returned as a plain float, so no autograd graph is needed.

        Raises:
            RuntimeError: If ``loss_fn`` is ``None``.
        """
        if self.loss_fn is None:
            raise RuntimeError("Experiment.validation_step requires loss_fn.")
        self.model.eval()
        # The result is reduced to a Python float, so building an autograd
        # graph here would only waste memory and time.
        with torch.no_grad():
            prepared = self.adapter.build_batch(batch, self.device)
            loss = self._loss_from_result(self.adapter.forward_train(self.model, prepared))
        return float(loss.detach().cpu())

    def validation_epoch_loss(self, val_loader) -> float:
        """Optional validation loss term computed once per epoch.

        Subclasses can override this to add case-level or physics-based
        penalties that require access to the full validation set rather than
        individual batches. Default: ``0.0``.
        """
        _ = val_loader
        return 0.0

    def on_epoch_end(self, epoch: int, avg_loss: float) -> None:
        """Hook receiving the epoch index and its average loss.

        Presumably invoked by the runner once per epoch. Default: no-op.
        """
        _ = (epoch, avg_loss)

    def compute_extended_metrics(
        self,
        eval_dataset,
        all_preds: list[torch.Tensor],
        all_targets: list[torch.Tensor],
    ) -> dict[str, Any]:
        """Per-experiment extended evaluation metrics.

        Receives per-batch CPU tensor lists (not concatenated, because graph
        adapters have variable shapes) and the eval dataset. Subclasses
        concatenate as needed and produce a metrics dict the runner merges
        into the ``extended`` block of the metrics JSON. Default: empty.
        """
        _ = (eval_dataset, all_preds, all_targets)
        return {}

    def print_extended_metrics(self, metrics: dict[str, Any]) -> None:
        """Pretty-print the dict returned by ``compute_extended_metrics``.

        Default: no-op (the generic runner already prints overall + per-field
        MSE/RMSE). Subclasses override to format case-specific blocks.
        """
        _ = metrics

    def decode_for_plotting(
        self,
        values: torch.Tensor,
        dataset,
        field_name: str,
        mask,
    ):
        """Decode encoded model output to physical/labelled space for plotting.

        Returns ``(decoded_ndarray, label)`` or ``None``. When ``None``, the
        plotter shows raw encoded values with ``field_name`` as the y-axis
        label and skips the physical-RMSE subtitle. Case experiments override
        to apply baseline re-add, unit conversion, etc.
        """
        _ = (values, dataset, field_name, mask)
        return None

    def baseline_for_plotting(
        self,
        dataset,
        field_name: str,
        mask,
    ):
        """Return ``(decoded_baseline_ndarray, label)`` or ``None``.

        Case experiments that train a residual on top of an analytical
        baseline override this so the plotter can overlay the baseline as a
        third curve next to ground truth and prediction. Default: no baseline.
        """
        _ = (dataset, field_name, mask)
        return None

    def prepare_for_training(
        self,
        train_dataset,
        val_dataset,
        device: torch.device,
    ) -> None:
        """Bind dataset-derived state onto the experiment before training.

        Called once after datasets are subsetted but before the training loop.
        Case experiments override to capture per-case geometry, target names,
        normalisation statistics, etc. that the generic core would otherwise
        have to read by name. Default: no-op.
        """
        _ = (train_dataset, val_dataset, device)