# Source code for cases.alpha_d.experiment

"""Alpha-D experiment.

Specialises :class:`training.experiment.Experiment` with the alpha-D
case's evaluation metrics (per-region pointwise + Δp integral) and the
plotting hooks the runner uses for per-case profile and parity plots.
The training step itself is inherited unchanged.
"""

from __future__ import annotations

from typing import Any

import torch

from training.experiment import Experiment


class AlphaDExperiment(Experiment):
    """Alpha-D-specific evaluation + plotting hooks.

    On top of :class:`Experiment` this keeps two pieces of state that
    :meth:`prepare_for_training` binds from the datasets:

    * ``local_velocity_normalization`` -- the flag read from the training
      (or, failing that, validation) dataset. The evaluation and plotting
      hooks use it as a fallback when the dataset they receive does not
      carry the flag itself.
    * ``alpha_d_target_name`` -- the first output column of the training
      dataset (``"log_alpha_D"`` until bound).
    """

    def __init__(
        self,
        model,
        optimizer,
        loss_fn,
        adapter,
        device,
        **kwargs,
    ):
        super().__init__(model, optimizer, loss_fn, adapter, device, **kwargs)
        # Defaults until prepare_for_training binds dataset-specific values.
        self.local_velocity_normalization: bool = False
        self.alpha_d_target_name: str = "log_alpha_D"

    # ------------------------------------------------------------------
    # Evaluation hooks
    # ------------------------------------------------------------------

    @staticmethod
    def _flatten_profile_outputs(
        preds: torch.Tensor,
        targets: torch.Tensor,
        case_slices,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Scatter per-case ``[N_cases, O, S]`` tensors to flat ``[N_rows, O]``.

        The profile adapter emits one ``[O, S]`` slab per case, but the
        pointwise metrics helper expects flat rows aligned with the inner
        TabularPairDataset row layout. ``case_slices[ci]`` holds the flat
        row indices of case ``ci``. Presumably there are ``S`` of them,
        matching the slab's last axis -- TODO confirm against the dataset.

        The output buffers and the index tensors are allocated on the source
        tensors' devices, so GPU-resident predictions scatter without a
        device mismatch.

        NOTE(review): rows not listed in any slice stay uninitialised
        (``torch.empty``). This assumes the slices cover ``0..N_rows-1``.
        """
        n_rows = sum(len(rows) for rows in case_slices)
        # Both buffers use the prediction field count, as before: a target
        # slab with a different O fails loudly on assignment below.
        n_fields = preds.shape[1]
        flat_preds = torch.empty(
            n_rows, n_fields, dtype=preds.dtype, device=preds.device
        )
        flat_targets = torch.empty(
            n_rows, n_fields, dtype=targets.dtype, device=targets.device
        )
        for ci, rows in enumerate(case_slices):
            rows_t = torch.as_tensor(rows, dtype=torch.long, device=preds.device)
            flat_preds[rows_t] = preds[ci].T  # [O, S] -> [S, O]
            # .to() is a no-op when targets share the predictions' device.
            flat_targets[rows_t.to(targets.device)] = targets[ci].T
        return flat_preds, flat_targets

    def compute_extended_metrics(
        self,
        eval_dataset,
        all_preds: list[torch.Tensor],
        all_targets: list[torch.Tensor],
    ) -> dict[str, Any]:
        """Pointwise + Δp metrics for the alpha-D adapter.

        Requires a TabularPairDataset / AlphaDProfileDataset (gated by
        ``_row_case_idx``). The method returns ``{}`` in these cases:

        * the dataset lacks ``_row_case_idx``, or its ``output_columns`` is
          empty (other adapters);
        * the evaluation pass produced no batches;
        * the predictions are 3-D but the dataset has no ``_case_slices``.

        Args:
            eval_dataset: the dataset that was evaluated.
            all_preds: per-batch prediction tensors, concatenated on dim 0.
                May be flat ``[N_rows, O]`` or per-case ``[N_cases, O, S]``
                (profile adapter).
            all_targets: per-batch target tensors with the same layout.

        Returns:
            The dict from ``compute_pointwise_extended_metrics``, plus a
            ``"delta_p"`` entry when ``compute_delta_p_metrics`` returns a
            truthy result.
        """
        from cases.alpha_d.metrics import (
            compute_delta_p_metrics,
            compute_pointwise_extended_metrics,
        )

        if not hasattr(eval_dataset, "_row_case_idx"):
            return {}
        output_fields = list(getattr(eval_dataset, "output_columns", []))
        if not output_fields:
            return {}
        # torch.cat raises on an empty list. An evaluation pass with no
        # batches has nothing to score.
        if not all_preds or not all_targets:
            return {}

        cat_preds = torch.cat(all_preds, dim=0)
        cat_targets = torch.cat(all_targets, dim=0)

        # Profile adapter output: scatter back to the flat row layout so the
        # same pointwise helper covers both adapters.
        if cat_preds.dim() == 3:
            case_slices = getattr(eval_dataset, "_case_slices", None)
            if case_slices is None:
                return {}
            cat_preds, cat_targets = self._flatten_profile_outputs(
                cat_preds, cat_targets, case_slices
            )

        metrics = compute_pointwise_extended_metrics(
            cat_preds,
            cat_targets,
            eval_dataset,
            output_fields,
        )

        # The dataset's own flag wins. Otherwise fall back to the flag bound
        # in prepare_for_training.
        local_vel_norm = bool(
            getattr(
                eval_dataset,
                "local_velocity_normalization",
                self.local_velocity_normalization,
            )
        )
        dp_metrics = compute_delta_p_metrics(
            self.model,
            eval_dataset,
            self.device,
            alpha_d_target_name=str(output_fields[0]),
            local_velocity_normalization=local_vel_norm,
        )
        if dp_metrics:
            metrics["delta_p"] = dp_metrics
        return metrics

    def print_extended_metrics(self, metrics: dict[str, Any]) -> None:
        """Print ``metrics`` via ``cases.alpha_d.metrics.print_extended_metrics``."""
        from cases.alpha_d.metrics import print_extended_metrics as _print

        _print(metrics)

    # ------------------------------------------------------------------
    # Training-lifecycle hooks
    # ------------------------------------------------------------------

    def prepare_for_training(
        self,
        train_dataset,
        val_dataset,
        device: torch.device,
    ) -> None:
        """Bind alpha-D-specific state from the datasets onto the experiment.

        Args:
            train_dataset: source of ``alpha_d_target_name`` (its first
                ``output_columns`` entry, when present and non-empty) and of
                ``local_velocity_normalization``.
            val_dataset: consulted for ``local_velocity_normalization`` only
                when the training dataset does not enable it. May be ``None``.
            device: accepted for interface compatibility. It is not used here.
        """
        if hasattr(train_dataset, "output_columns") and train_dataset.output_columns:
            self.alpha_d_target_name = str(train_dataset.output_columns[0])
        # Track the local-velocity-normalisation flag. decode_for_plotting and
        # the eval-time Δp integral fall back to it when the dataset they
        # receive does not define the flag.
        self.local_velocity_normalization = bool(
            getattr(train_dataset, "local_velocity_normalization", False)
        )
        if not self.local_velocity_normalization and val_dataset is not None:
            self.local_velocity_normalization = bool(
                getattr(val_dataset, "local_velocity_normalization", False)
            )

    def decode_for_plotting(
        self,
        values: torch.Tensor,
        dataset,
        field_name: str,
        mask,
    ) -> tuple[Any, str]:
        """Re-add encoded baseline, decode to bulk α_D for profile plotting.

        Args:
            values: encoded values for the masked rows. They are detached and
                cloned onto the CPU before any modification.
            dataset: must provide ``add_baseline_to_encoded``. Optionally
                exposes ``_raw_d_local_over_D`` (indexed by ``mask``) and
                ``local_velocity_normalization``.
            field_name: output field being plotted. It selects the physical
                decode and the returned label.
            mask: row selector applied to ``_raw_d_local_over_D``. It is also
                forwarded to ``add_baseline_to_encoded``.

        Returns:
            ``(decoded, label)``. ``decoded`` is the ``.numpy()`` form of the
            physical values. ``label`` is ``"alpha_D"`` for alpha-D targets,
            otherwise ``field_name``.
        """
        from cases.alpha_d.physics.targets import (
            field_values_to_physical,
            is_alpha_d_target,
        )

        values = values.detach().cpu().clone()
        # NOTE(review): the baseline is always re-added for field index 0,
        # whatever field_name is. Confirm this is intended for non-alpha-D
        # fields.
        values = dataset.add_baseline_to_encoded(values, row_mask=mask, field_idx=0)

        d_over_D_attr = getattr(dataset, "_raw_d_local_over_D", None)
        d_over_D = (
            d_over_D_attr[mask].detach().cpu() if d_over_D_attr is not None else None
        )

        # The dataset's own flag wins. Otherwise fall back to the flag bound
        # in prepare_for_training.
        local_vel_norm = bool(
            getattr(
                dataset,
                "local_velocity_normalization",
                self.local_velocity_normalization,
            )
        )
        decoded = field_values_to_physical(
            values,
            field_name=field_name,
            d_over_D=d_over_D,
            local_velocity_normalization=local_vel_norm,
        )
        label = "alpha_D" if is_alpha_d_target(field_name) else field_name
        return decoded.detach().cpu().numpy(), label

    def baseline_for_plotting(self, dataset, field_name: str, mask) -> tuple[Any, str] | None:
        """Decode the analytical alpha-D baseline for the masked stations.

        Reuses ``decode_for_plotting`` with a zero residual so the physical
        decode pipeline (baseline re-add + unit conversion) stays in one
        place.

        Returns:
            ``None`` when the dataset has no target baseline (falsy
            ``has_target_baseline``, or ``_baseline_encoded`` missing or
            ``None``). Otherwise returns the ``(decoded, label)`` tuple from
            ``decode_for_plotting``.
        """
        if not getattr(dataset, "has_target_baseline", False):
            return None
        if getattr(dataset, "_baseline_encoded", None) is None:
            return None
        # One zero per selected row. This assumes mask is a boolean selector,
        # so its sum is the row count.
        n = int(mask.sum().item())
        zero = torch.zeros(n, dtype=torch.float32)
        return self.decode_for_plotting(zero, dataset, field_name, mask)