Source code for training.runner

"""Generic training/evaluation runners for supervised one-step models."""

import copy
import importlib
import json
import math
import random
import subprocess
from pathlib import Path
from typing import Any

import numpy as np

try:
    import torch
    from torch.utils.data import DataLoader, Subset
except ModuleNotFoundError as exc:
    raise ModuleNotFoundError(
        "PyTorch is required for training/evaluation but is not installed in this "
        "environment. Use the `etl` or `etl-ngc` Docker service, or install `torch` "
        "in your active environment."
    ) from exc

from training import import_physicsnemo_attr
from training.adapters import get_adapter
from training.datasets import split_indices
from training.experiment import Experiment
from training.losses import get_loss_fn
from training.models import get_build_fn_and_adapter, model_entrypoint_string
from training.plotting import (
    resolve_plot_indices,
    save_delta_p_parity_plot,
    save_grid_prediction_plots,
    save_parity_plot,
    save_pointwise_profile_plots,
    save_profile_prediction_plots,
    select_best_worst_pointwise_cases,
)

try:
    from tqdm.auto import tqdm
except ModuleNotFoundError:
    tqdm = None


[docs] def to_plain_dict(cfg: Any) -> dict[str, Any]: """Convert OmegaConf DictConfig or dict to a plain dict.""" if isinstance(cfg, dict): return cfg try: from omegaconf import DictConfig, OmegaConf if isinstance(cfg, DictConfig): plain = OmegaConf.to_container(cfg, resolve=True) if isinstance(plain, dict): return plain except ModuleNotFoundError: pass raise TypeError(f"Expected dict-like config, got {type(cfg)}")
_to_plain_dict = to_plain_dict
[docs] def set_seed(seed: int) -> None: """Seed all random number generators for reproducibility.""" random.seed(seed) np.random.seed(seed) torch.manual_seed(seed) if torch.cuda.is_available(): torch.cuda.manual_seed_all(seed)
_set_seed = set_seed
[docs] def resolve_device(device_arg: str) -> torch.device: """Parse a device string ('auto', 'cpu', 'cuda') into a torch.device.""" if device_arg == "auto": return torch.device("cuda" if torch.cuda.is_available() else "cpu") if device_arg == "cuda" and not torch.cuda.is_available(): raise RuntimeError("Requested CUDA but no CUDA device is available.") return torch.device(device_arg)
_resolve_device = resolve_device def _resolve_path(raw_path: str | Path) -> Path: return Path(raw_path).expanduser().resolve() def _resolve_metrics_out_path(output_cfg: dict[str, Any], checkpoint_path: Path) -> Path | None: """Resolve where evaluation metrics JSON should be written.""" metrics_out_value = output_cfg.get("metrics_out") if metrics_out_value is None: return None if str(metrics_out_value).strip().lower() == "auto": return checkpoint_path.with_name("eval_metrics.json") return _resolve_path(str(metrics_out_value)) def _load_object(entrypoint: str): if ":" not in entrypoint: raise ValueError( f"Invalid entrypoint '{entrypoint}'. Expected format 'module.path:object'." ) module_path, object_name = entrypoint.rsplit(":", 1) module = importlib.import_module(module_path) if not hasattr(module, object_name): raise AttributeError( f"Entrypoint object '{object_name}' not found in module '{module_path}'." ) return getattr(module, object_name)
def build_experiment(
    experiment_entrypoint: str | None,
    model,
    optimizer,
    loss_fn,
    adapter,
    device: torch.device,
    **kwargs,
) -> Experiment:
    """Construct the Experiment object that drives training/evaluation steps.

    A non-empty *experiment_entrypoint* (``"module.path:ClassName"``) is
    imported and used as the class; otherwise the base :class:`Experiment`
    is used. Extra keyword arguments are forwarded to the constructor so
    custom experiments can accept domain-specific options.

    Raises:
        TypeError: If the resulting object lacks ``training_step`` or
            ``eval_step``.
    """
    experiment_cls = _load_object(experiment_entrypoint) if experiment_entrypoint else Experiment
    experiment = experiment_cls(
        model=model,
        optimizer=optimizer,
        loss_fn=loss_fn,
        adapter=adapter,
        device=device,
        **kwargs,
    )
    if any(not hasattr(experiment, hook) for hook in ("training_step", "eval_step")):
        raise TypeError("Experiment class must define training_step() and eval_step() methods.")
    return experiment


_build_experiment = build_experiment


def _git_code_version() -> str:
    """Return the short git commit hash of the checkout, or ``"unknown"``.

    Best-effort by design: any failure (git missing, not a repository, path
    lookup error, ...) yields ``"unknown"`` rather than aborting the run.
    """
    try:
        # Two levels above this file's directory; presumably the repo root.
        repo_dir = Path(__file__).resolve().parents[2]
        completed = subprocess.run(
            ["git", "rev-parse", "--short", "HEAD"],
            check=True,
            capture_output=True,
            text=True,
            cwd=repo_dir,
        )
    except Exception:
        return "unknown"
    return completed.stdout.strip()


def _collect_resolved_model_params(
    model,
    model_params: dict,
    dataset_info: dict,
) -> dict[str, Any]:
    """Return the model hyper-parameters actually in effect.

    A ``_resolved_model_params`` dict published by the model wins (copied).
    Otherwise the configured params are copied and any of ``in_channels``,
    ``out_channels``, ``edge_dim`` and ``spatial_shape`` found in
    *dataset_info* but not configured are filled in.
    """
    published = getattr(model, "_resolved_model_params", None)
    if isinstance(published, dict):
        return dict(published)
    merged = dict(model_params)
    merged.update(
        {
            key: dataset_info[key]
            for key in ("in_channels", "out_channels", "edge_dim", "spatial_shape")
            if key in dataset_info and key not in merged
        }
    )
    return merged


def _serialize_norm_stats(
    norm_stats: dict[str, torch.Tensor] | None,
) -> dict[str, list[float]] | None:
    """Convert the ``x_mean``/``x_std`` tensors to JSON-friendly lists (None if empty)."""
    if not norm_stats:
        return None
    return {key: norm_stats[key].detach().cpu().tolist() for key in ("x_mean", "x_std")}
[docs] def normalize_split_cfg(split_cfg: dict, default_seed: int) -> dict[str, Any]: """Fill in default values for a split config dict.""" normalized = dict(split_cfg) normalized.setdefault("strategy", "sequential") if normalized["strategy"] in {"sequential", "random", "stratified"}: normalized.setdefault("train_ratio", 0.8) if normalized["strategy"] in {"random", "stratified"}: normalized.setdefault("seed", default_seed) return normalized
_normalize_split_cfg = normalize_split_cfg
def prepare_training(cfg_dict: dict) -> dict[str, Any]:
    """Run the setup shared by ``train()`` and HPO.

    Copies the ``model``/``data``/``training``/``output`` sections, seeds all
    RNGs, resolves the device, and builds the adapter, dataset and dataset
    info for the configured model.

    Returns:
        Dict with keys ``model_cfg``, ``data_cfg``, ``training_cfg``,
        ``output_cfg``, ``adapter_name``, ``adapter``, ``dataset``,
        ``dataset_info``, ``build_fn``, ``device`` and ``seed``.

    Raises:
        ValueError: If ``data.zarr_dir`` is missing or empty.
    """
    model_cfg, data_cfg, training_cfg, output_cfg = (
        dict(cfg_dict.get(section) or {}) for section in ("model", "data", "training", "output")
    )
    if not data_cfg.get("zarr_dir"):
        raise ValueError("data.zarr_dir is required.")

    seed = int(training_cfg.get("seed", 42))
    set_seed(seed)
    device = resolve_device(str(training_cfg.get("device", "auto")))

    build_fn, adapter_name = get_build_fn_and_adapter(model_cfg)
    adapter = get_adapter(adapter_name)
    dataset = adapter.build_dataset(data_cfg)
    dataset_info = adapter.dataset_info(dataset)

    return dict(
        model_cfg=model_cfg,
        data_cfg=data_cfg,
        training_cfg=training_cfg,
        output_cfg=output_cfg,
        adapter_name=adapter_name,
        adapter=adapter,
        dataset=dataset,
        dataset_info=dataset_info,
        build_fn=build_fn,
        device=device,
        seed=seed,
    )
def train_one_epoch(experiment: Experiment, dataloader: DataLoader) -> float:
    """Run ``experiment.training_step`` over every batch and return the mean loss.

    Raises:
        RuntimeError: If *dataloader* yields no batches.
    """
    loss_sum = 0.0
    batch_count = 0
    for batch_count, batch in enumerate(dataloader, start=1):
        loss_sum += experiment.training_step(batch)
    if batch_count == 0:
        raise RuntimeError("No training batches were produced.")
    return loss_sum / batch_count
def compute_val_loss(experiment: Experiment, val_loader: DataLoader) -> float:
    """Return the validation loss for *val_loader*.

    The result is the mean of ``experiment.validation_step(batch)`` over all
    batches (computed under ``torch.no_grad()``), plus the scalar returned by
    ``experiment.validation_epoch_loss(val_loader)``. Custom experiments can
    override either hook.

    Raises:
        RuntimeError: If *val_loader* yields no batches.
    """
    batch_total = 0.0
    batch_count = 0
    with torch.no_grad():
        for batch_count, batch in enumerate(val_loader, start=1):
            batch_total += experiment.validation_step(batch)
    if batch_count == 0:
        raise RuntimeError("Validation loader produced zero batches.")
    mean_batch_loss = batch_total / batch_count
    # NOTE(review): the epoch-level term runs outside no_grad; presumably so it
    # may use autograd. Confirm against the Experiment implementations.
    return mean_batch_loss + float(experiment.validation_epoch_loss(val_loader))
def _split_off_validation_cases(train_idx, val_ratio: float, seed: int) -> tuple[list, list]:
    """Split *train_idx* into ``(train_case_idx, val_case_idx)`` for early stopping.

    Cases are shuffled with a private ``random.Random(seed + 1)``, so the
    split is reproducible without consuming the global RNG stream. At least
    one case always goes to validation.

    Raises:
        ValueError: If no training cases remain after the split.
    """
    shuffled = list(train_idx)
    random.Random(seed + 1).shuffle(shuffled)
    n_val_cases = max(1, round(len(shuffled) * val_ratio))
    val_case_idx = shuffled[:n_val_cases]
    train_case_idx = shuffled[n_val_cases:]
    if not train_case_idx:
        raise ValueError("Training split is empty after validation split. Reduce val_ratio.")
    return train_case_idx, val_case_idx


def _make_dataloader(
    dataset,
    *,
    shuffle: bool,
    batch_size: int,
    num_workers: int,
    device: torch.device,
    collate_fn,
) -> DataLoader:
    """Build a DataLoader with the runner's standard options.

    Memory is pinned only when targeting CUDA, and worker processes are kept
    alive across epochs whenever multiprocessing workers are used.
    """
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=device.type == "cuda",
        persistent_workers=num_workers > 0,
        collate_fn=collate_fn,
    )


def _build_lr_scheduler(optimizer, scheduler_name: str, epochs: int, warmup_epochs: int):
    """Return the per-epoch LR scheduler for *scheduler_name*, or ``None``.

    Only ``"cosine"`` is supported: cosine annealing to 1e-7 over *epochs*,
    preceded by a linear warm-up (from 1e-3 x lr) when
    ``0 < warmup_epochs < epochs``. An empty name or ``"none"`` means no
    scheduler. Any other name also yields ``None``, but a warning is printed
    so a typo in the config is not silently ignored.
    """
    if scheduler_name != "cosine":
        if scheduler_name.strip().lower() not in {"", "none"}:
            print(
                f"Warning: unrecognised training.lr_scheduler='{scheduler_name}'; "
                "training without an LR scheduler."
            )
        return None
    if 0 < warmup_epochs < epochs:
        warmup_sched = torch.optim.lr_scheduler.LinearLR(
            optimizer, start_factor=1e-3, end_factor=1.0, total_iters=warmup_epochs
        )
        cosine_sched = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=epochs - warmup_epochs, eta_min=1e-7
        )
        return torch.optim.lr_scheduler.SequentialLR(
            optimizer, schedulers=[warmup_sched, cosine_sched], milestones=[warmup_epochs]
        )
    return torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs, eta_min=1e-7)


def _train_data_meta(
    dataset,
    data_cfg: dict[str, Any],
    adapter_name: str,
    train_case_idx: list,
) -> dict[str, Any]:
    """Describe the training dataset so ``evaluate()`` can rebuild it.

    Datasets exposing ``input_columns`` are recorded with their columns,
    normalisation settings/stats and the sims used to fit normalisation.
    All other datasets are recorded by their input/output fields and time
    indices.
    """
    zarr_dir = str(_resolve_path(str(data_cfg["zarr_dir"])))
    if hasattr(dataset, "input_columns"):
        min_dr = data_cfg.get("min_Dr")
        return {
            "zarr_dir": zarr_dir,
            "input_columns": list(dataset.input_columns),
            "output_columns": list(dataset.output_columns),
            "normalize": bool(getattr(dataset, "normalize", False)),
            "norm_stats": _serialize_norm_stats(getattr(dataset, "norm_stats", None)),
            "norm_fit_train_sims": [dataset.sim_names[i] for i in train_case_idx],
            "adapter": adapter_name,
            "local_velocity_normalization": bool(
                getattr(dataset, "local_velocity_normalization", False)
            ),
            "target_transform": data_cfg.get("target_transform"),
            "dataset_entrypoint": data_cfg.get("dataset_entrypoint"),
            "exclude_cases": getattr(dataset, "exclude_cases", []) or [],
            "min_Dr": float(min_dr) if min_dr is not None else None,
        }
    return {
        "zarr_dir": zarr_dir,
        "input_fields": list(dataset.input_fields),
        "output_fields": list(dataset.output_fields),
        "input_time_idx": int(dataset.input_time_idx),
        "target_time_idx": int(dataset.target_time_idx),
    }


def train(cfg: dict | Any) -> dict[str, Any]:
    """Train a supervised model and save a checkpoint plus ``run_meta.json``.

    Steps:

    1. Shared setup via ``prepare_training``.
    2. Build the model, loss, optimizer and experiment. AdamW is used when
       ``training.weight_decay > 0``, plain Adam otherwise.
    3. Split cases into train/test. When ``training.early_stopping.patience > 0``,
       a validation subset is also carved from the training cases.
    4. Train for up to ``training.epochs`` epochs, with an optional cosine LR
       schedule. With early stopping, the best-validation weights are restored
       at the end.
    5. Save the checkpoint and a ``run_meta.json`` beside it. ``evaluate()``
       later uses that file to rebuild the dataset and split.

    Args:
        cfg: Plain dict or OmegaConf ``DictConfig`` with ``model``, ``data``,
            ``training`` and ``output`` sections.

    Returns:
        Summary with the checkpoint and run_meta paths, the last epoch's mean
        training loss, and the train case/sample and test case counts.

    Raises:
        ValueError: On invalid ``training.epochs``/``batch_size``/``num_workers``,
            or when the validation split leaves no training cases.
        RuntimeError: If the training loader yields no batches.
    """
    cfg_dict = to_plain_dict(cfg)
    prep = prepare_training(cfg_dict)
    model_cfg = prep["model_cfg"]
    data_cfg = prep["data_cfg"]
    training_cfg = prep["training_cfg"]
    output_cfg = prep["output_cfg"]
    adapter_name = prep["adapter_name"]
    adapter = prep["adapter"]
    dataset = prep["dataset"]
    dataset_info = prep["dataset_info"]
    build_fn = prep["build_fn"]
    device = prep["device"]
    seed = prep["seed"]

    # Parse and validate the loop settings up front, so a bad config fails
    # before the model is built and the data is split or re-normalised.
    epochs = int(training_cfg.get("epochs", 20))
    batch_size = int(training_cfg.get("batch_size", 4))
    num_workers = int(training_cfg.get("num_workers", 0))
    if epochs < 1:
        raise ValueError("training.epochs must be >= 1.")
    if batch_size < 1:
        raise ValueError("training.batch_size must be >= 1.")
    if num_workers < 0:
        raise ValueError("training.num_workers must be >= 0.")
    scheduler_name = str(training_cfg.get("lr_scheduler") or "")
    warmup_epochs = int(training_cfg.get("lr_warmup_epochs", 0))

    model_params = dict(model_cfg.get("params") or {})
    model = build_fn(model_params, dataset_info).to(device)

    loss_name = str(training_cfg.get("loss", "mse"))
    loss_fn = get_loss_fn(loss_name)
    lr = float(training_cfg.get("lr", 1.0e-3))
    weight_decay = float(training_cfg.get("weight_decay", 0.0))
    if weight_decay > 0:
        optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    else:
        optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    experiment_entrypoint = training_cfg.get("experiment")
    experiment = _build_experiment(
        experiment_entrypoint=experiment_entrypoint,
        model=model,
        optimizer=optimizer,
        loss_fn=loss_fn,
        adapter=adapter,
        device=device,
    )

    split_cfg = _normalize_split_cfg(dict(data_cfg.get("split") or {}), default_seed=seed)
    num_cases = len(dataset.sim_names) if hasattr(dataset, "sim_names") else len(dataset)
    # NOTE(review): sim_names is read unconditionally here even though
    # num_cases guards for its absence; confirm every dataset provides it.
    train_idx, test_idx, train_sims, test_sims = split_indices(
        num_cases=num_cases,
        split_cfg=split_cfg,
        sim_names=dataset.sim_names,
    )

    # Early stopping: carve a validation split from the training cases.
    early_stop_cfg = dict(training_cfg.get("early_stopping") or {})
    patience = int(early_stop_cfg.get("patience", 0))
    use_early_stopping = patience > 0
    train_case_idx: list = list(train_idx)
    val_case_idx: list = []
    if use_early_stopping:
        train_case_idx, val_case_idx = _split_off_validation_cases(
            train_idx, float(early_stop_cfg.get("val_ratio", 0.15)), seed
        )

    # Pointwise normalization must be fit on training cases only.
    if adapter_name == "pointwise" and bool(data_cfg.get("normalize", False)):
        normalized_data_cfg = dict(data_cfg)
        normalized_data_cfg["norm_from_case_indices"] = train_case_idx
        dataset = adapter.build_dataset(normalized_data_cfg)

    if hasattr(dataset, "subset_by_case_indices"):
        train_dataset = dataset.subset_by_case_indices(train_case_idx)
        val_dataset = dataset.subset_by_case_indices(val_case_idx) if use_early_stopping else None
    else:
        train_dataset = Subset(dataset, train_case_idx)
        val_dataset = Subset(dataset, val_case_idx) if use_early_stopping else None

    # Let the experiment bind any case-specific state derived from the
    # finalised train/val datasets (target names, per-case geometry,
    # normalisation stats, etc.). Generic experiments are no-ops.
    experiment.prepare_for_training(train_dataset, val_dataset, device)

    dataloader = _make_dataloader(
        train_dataset,
        shuffle=True,
        batch_size=batch_size,
        num_workers=num_workers,
        device=device,
        collate_fn=adapter.collate_fn(),
    )
    val_loader = None
    if use_early_stopping and val_dataset is not None:
        val_loader = _make_dataloader(
            val_dataset,
            shuffle=False,
            batch_size=batch_size,
            num_workers=num_workers,
            device=device,
            collate_fn=adapter.collate_fn(),
        )

    scheduler = _build_lr_scheduler(optimizer, scheduler_name, epochs, warmup_epochs)

    model_name = str(model_cfg.get("name", "custom"))
    if adapter_name == "pointwise":
        print(
            f"Training model='{model_name}' adapter='{adapter_name}' on "
            f"{len(train_case_idx)} train case(s) ({len(train_dataset)} samples), "
            f"{len(test_idx)} test case(s), device={device}."
        )
    else:
        print(
            f"Training model='{model_name}' adapter='{adapter_name}' on {len(train_dataset)} "
            f"train case(s), {len(test_idx)} test case(s), device={device}."
        )
    if use_early_stopping:
        print(f"Early stopping enabled (patience={patience}, val_cases={len(val_case_idx)}).")

    best_val_loss = float("inf")
    best_state_dict = None
    patience_counter = 0
    last_avg_loss = float("nan")

    epoch_iter: Any = range(1, epochs + 1)
    epoch_progress = None
    if tqdm is not None:
        epoch_progress = tqdm(epoch_iter, total=epochs, desc="training", unit="epoch")
        epoch_iter = epoch_progress
    try:
        for epoch in epoch_iter:
            avg_loss = train_one_epoch(experiment, dataloader)
            last_avg_loss = avg_loss
            experiment.on_epoch_end(int(epoch), avg_loss)
            experiment.on_epoch_end_extra_step()
            if scheduler is not None:
                scheduler.step()

            if val_loader is None:
                if epoch_progress is not None:
                    epoch_progress.set_postfix(loss=f"{avg_loss:.3e}")
                else:
                    print(f"epoch {epoch}/{epochs}: loss={avg_loss:.6e}")
                continue

            val_loss = compute_val_loss(experiment, val_loader)
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                best_state_dict = copy.deepcopy(model.state_dict())
                patience_counter = 0
            else:
                patience_counter += 1
            if epoch_progress is not None:
                epoch_progress.set_postfix(
                    loss=f"{avg_loss:.3e}", val=f"{val_loss:.3e}", patience=patience_counter
                )
            else:
                print(
                    f"epoch {epoch}/{epochs}: loss={avg_loss:.6e} "
                    f"val_loss={val_loss:.6e} patience={patience_counter}/{patience}"
                )
            if patience_counter >= patience:
                print(f"Early stopping at epoch {epoch} (best val_loss={best_val_loss:.6e}).")
                break
    finally:
        # A tqdm bar only closes itself when iterated to exhaustion; close it
        # explicitly so early stopping or an exception does not leave it open.
        if epoch_progress is not None:
            epoch_progress.close()

    if use_early_stopping and best_state_dict is not None:
        model.load_state_dict(best_state_dict)
        print(f"Restored best weights (val_loss={best_val_loss:.6e}).")

    checkpoint_value = output_cfg.get("checkpoint") or f"../data/models/{model_name}_model.mdlus"
    checkpoint_path = _resolve_path(str(checkpoint_value))
    checkpoint_path.parent.mkdir(parents=True, exist_ok=True)
    model.save(str(checkpoint_path))

    model_params_resolved = _collect_resolved_model_params(model, model_params, dataset_info)
    split_meta: dict[str, Any] = {
        "strategy": split_cfg["strategy"],
        "train_sims": train_sims,
        "test_sims": test_sims,
    }
    if "train_ratio" in split_cfg:
        split_meta["train_ratio"] = float(split_cfg["train_ratio"])
    if "seed" in split_cfg:
        split_meta["seed"] = int(split_cfg["seed"])
    if split_cfg.get("strategy") == "file":
        split_meta["train_file"] = str(_resolve_path(str(split_cfg["train_file"])))
        split_meta["test_file"] = str(_resolve_path(str(split_cfg["test_file"])))

    data_meta = _train_data_meta(dataset, data_cfg, adapter_name, train_case_idx)

    run_meta = {
        "code_version": _git_code_version(),
        "model_name": model_name,
        "entrypoint": model_entrypoint_string(model_cfg, build_fn),
        "adapter": adapter_name,
        "model_params": model_params,
        "model_params_resolved": model_params_resolved,
        "data": data_meta,
        "split": split_meta,
        "training": {
            "epochs": epochs,
            "loss": loss_name,
            "lr": lr,
            "lr_scheduler": scheduler_name or None,
            "lr_warmup_epochs": warmup_epochs,
            "seed": seed,
            "final_train_loss": float(last_avg_loss),
            "best_val_loss": float(best_val_loss) if use_early_stopping else None,
            "experiment": experiment_entrypoint,
        },
        "checkpoint": str(checkpoint_path),
    }
    run_meta_path = checkpoint_path.with_name("run_meta.json")
    run_meta_path.write_text(json.dumps(run_meta, indent=2), encoding="utf-8")

    print(f"Saved model checkpoint to {checkpoint_path}")
    print(f"Saved run metadata to {run_meta_path}")
    return {
        "checkpoint": str(checkpoint_path),
        "run_meta": str(run_meta_path),
        "final_train_loss": float(last_avg_loss),
        "train_cases": len(train_case_idx),
        "train_samples": len(train_dataset),
        "test_cases": len(test_idx),
    }
def _has_hpo(cfg: Any) -> bool: """True iff the config requests Optuna HPO (non-empty search_space).""" cfg_dict = to_plain_dict(cfg) if not isinstance(cfg, dict) else cfg hpo = cfg_dict.get("hpo") if not isinstance(hpo, dict): return False return bool(hpo.get("search_space")) def _log_hpo_summary(results: dict[str, Any]) -> None: """Print the post-run HPO summary that previously lived in src/train.py.""" n_complete = results.get("n_complete", 0) n_pruned = results.get("n_pruned", 0) n_trials = results.get("n_trials", 0) print(f"\nHPO complete: {n_complete} finished, {n_pruned} pruned, {n_trials} total") if n_complete > 0: print(f"Best trial #{results['best_trial_number']}: val_loss={results['best_value']:.6e}") print(f"Best params: {json.dumps(results['best_params'], indent=2)}") print(f"Artifacts saved to: {results['output_dir']}") if "retrain" in results: retrain = results["retrain"] print(f"\nRetrained model saved to: {retrain['checkpoint']}") print(f"Final train loss: {retrain['final_train_loss']:.6e}")
def train_or_hpo(cfg: Any) -> None:
    """Run Optuna HPO when ``cfg.hpo.search_space`` is populated, else plain training.

    This is the single place that decides between training and HPO for every
    ``@hydra.main`` entry point (top-level ``src/train.py`` and the per-case
    ``src/cases/<case>/train.py`` wrappers). Which entry point is used only
    affects which config is loaded, never whether HPO runs.
    """
    if not _has_hpo(cfg):
        train(cfg)
        return
    # Imported lazily so plain training does not depend on the HPO stack.
    from training.hpo.study import run_hpo

    _log_hpo_summary(run_hpo(to_plain_dict(cfg)))
def _indices_for_test_split( sim_names: list[str], split_meta: dict[str, Any], ) -> tuple[list[int], list[str], list[str]]: train_sims = [str(name) for name in split_meta.get("train_sims", [])] test_sims = [str(name) for name in split_meta.get("test_sims", [])] if test_sims: sim_to_idx = {name: idx for idx, name in enumerate(sim_names)} unknown_test = [name for name in test_sims if name not in sim_to_idx] if unknown_test: raise ValueError(f"run_meta split contains unknown test sim name(s): {unknown_test}") test_idx = [sim_to_idx[name] for name in test_sims] return test_idx, train_sims, test_sims reconstructed_split = _normalize_split_cfg(dict(split_meta), default_seed=42) train_idx, test_idx, train_sims, test_sims = split_indices( num_cases=len(sim_names), split_cfg=reconstructed_split, sim_names=sim_names, ) _ = train_idx return test_idx, train_sims, test_sims
def _eval_data_cfg_from_meta(data_meta: dict[str, Any]) -> dict[str, Any]:
    """Rebuild a dataset config from the ``data`` section that ``train()`` wrote.

    Metadata containing ``input_columns`` (written when the training dataset
    exposed columns) yields a column-based config. Anything else is treated
    as field-based metadata with input/target time indices.
    """
    if "input_columns" in data_meta:
        return {
            "zarr_dir": data_meta["zarr_dir"],
            "input_columns": data_meta.get("input_columns"),
            "output_columns": data_meta.get("output_columns"),
            "normalize": bool(data_meta.get("normalize", False)),
            "norm_stats": data_meta.get("norm_stats"),
            "local_velocity_normalization": bool(
                data_meta.get("local_velocity_normalization", False)
            ),
            "target_transform": data_meta.get("target_transform"),
            "dataset_entrypoint": data_meta.get("dataset_entrypoint"),
            "exclude_cases": data_meta.get("exclude_cases", []),
            "min_Dr": data_meta.get("min_Dr"),
        }
    return {
        "zarr_dir": data_meta["zarr_dir"],
        "input_fields": data_meta.get("input_fields"),
        "output_fields": data_meta.get("output_fields"),
        "input_time_idx": int(data_meta.get("input_time_idx", 0)),
        "target_time_idx": int(data_meta.get("target_time_idx", -1)),
    }


def _save_eval_plots(
    plot_files: list[str],
    *,
    adapter,
    model,
    experiment,
    eval_dataset,
    output_fields: list[str],
    device: torch.device,
    output_cfg: dict[str, Any],
    plot_dir_value: Any,
    extended_metrics: dict[str, Any],
    all_preds: list[torch.Tensor],
    all_targets: list[torch.Tensor],
) -> None:
    """Write the evaluation plots for ``adapter.family``, appending paths to *plot_files*.

    Paths are appended as each plotting call returns. If a later call raises
    (the caller tolerates ``ModuleNotFoundError``), *plot_files* still lists
    the files already written.

    Raises:
        ValueError: For adapter families other than grid, pointwise and profile.
    """
    family = adapter.family
    plot_dpi = int(output_cfg.get("plot_dpi", 150))

    if family == "grid":
        plot_indices = resolve_plot_indices(
            num_cases=len(eval_dataset),
            raw_indices=output_cfg.get("plot_case_indices"),
            max_cases=int(output_cfg.get("plot_max_cases", 3)),
        )
        plot_files.extend(
            save_grid_prediction_plots(
                model=model,
                dataset=eval_dataset,
                output_fields=output_fields,
                device=device,
                plot_dir=plot_dir_value,
                plot_indices=plot_indices,
                plot_cmap=str(output_cfg.get("plot_cmap", "viridis")),
                plot_dpi=plot_dpi,
                quiver_step=int(output_cfg.get("plot_quiver_step", 4)),
                vel_x_field=str(output_cfg.get("plot_velocity_x_field", "vel_x")),
                vel_y_field=str(output_cfg.get("plot_velocity_y_field", "vel_y")),
            )
        )
        return

    if family not in {"pointwise", "profile"}:
        raise ValueError(
            "Plotting currently supports grid, pointwise, and profile "
            f"adapters. Received adapter='{family}'."
        )

    plot_cases = select_best_worst_pointwise_cases(extended_metrics, output_fields)
    case_plot_fn = (
        save_pointwise_profile_plots if family == "pointwise" else save_profile_prediction_plots
    )
    plot_files.extend(
        case_plot_fn(
            model=model,
            dataset=eval_dataset,
            output_fields=output_fields,
            device=device,
            plot_dir=plot_dir_value,
            case_entries=plot_cases,
            plot_dpi=plot_dpi,
            decode_fn=experiment.decode_for_plotting,
            baseline_fn=experiment.baseline_for_plotting,
        )
    )
    plot_files.extend(
        save_parity_plot(
            cat_preds=torch.cat(all_preds, dim=0),
            cat_targets=torch.cat(all_targets, dim=0),
            dataset=eval_dataset,
            output_fields=output_fields,
            plot_dir=plot_dir_value,
            decode_fn=experiment.decode_for_plotting,
            plot_dpi=plot_dpi,
        )
    )
    dp_per_case = extended_metrics.get("delta_p", {}).get("per_case") or []
    dp_parity = save_delta_p_parity_plot(
        per_case=dp_per_case,
        plot_dir=plot_dir_value,
        plot_dpi=plot_dpi,
    )
    if dp_parity is not None:
        plot_files.append(dp_parity)


def evaluate(cfg: dict | Any) -> dict[str, Any]:
    """Evaluate a checkpoint on its test split, using ``run_meta.json`` to rebuild data.

    The run metadata (``eval.run_meta``, defaulting to ``run_meta.json`` next
    to the checkpoint) supplies the dataset config and the train/test split.
    The checkpoint is loaded in eval mode and run over the test cases to
    accumulate per-field MSE/RMSE. The experiment's extended metrics are
    collected, and plots are written when ``output.plot_dir`` is set.
    Metrics JSON is written when ``output.metrics_out`` is set.

    Args:
        cfg: Plain dict or OmegaConf ``DictConfig`` with ``eval`` and
            ``output`` sections.

    Returns:
        The metrics payload: paths, case/sample counts, overall and per-field
        MSE/RMSE, plot info, the split, and optional ``extended`` metrics.

    Raises:
        ValueError: If ``eval.checkpoint`` is missing, run_meta lacks a
            ``data`` section, the eval batch size/worker count is invalid, or
            plotting is requested for an unsupported adapter family.
        FileNotFoundError: If the checkpoint or run_meta file does not exist.
        RuntimeError: If no evaluation samples were processed.
    """
    cfg_dict = _to_plain_dict(cfg)
    eval_cfg = dict(cfg_dict.get("eval") or {})
    output_cfg = dict(cfg_dict.get("output") or {})

    checkpoint_value = eval_cfg.get("checkpoint")
    if not checkpoint_value:
        raise ValueError("eval.checkpoint is required.")
    checkpoint_path = _resolve_path(str(checkpoint_value))
    if not checkpoint_path.exists():
        raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")

    run_meta_value = eval_cfg.get("run_meta")
    run_meta_path = (
        _resolve_path(str(run_meta_value))
        if run_meta_value
        else checkpoint_path.with_name("run_meta.json")
    )
    if not run_meta_path.exists():
        raise FileNotFoundError(
            f"run_meta.json not found: {run_meta_path}. "
            "Set eval.run_meta explicitly or train with the new runner first."
        )
    run_meta = json.loads(run_meta_path.read_text(encoding="utf-8"))
    # Tolerate a missing *or* null "training" section.
    training_meta = dict(run_meta.get("training") or {})

    adapter_name = str(run_meta["adapter"])
    adapter = get_adapter(adapter_name)
    data_meta = dict(run_meta.get("data") or {})
    if not data_meta:
        raise ValueError(f"run_meta at {run_meta_path} is missing 'data' section.")
    data_cfg = _eval_data_cfg_from_meta(data_meta)
    dataset = adapter.build_dataset(data_cfg)

    split_meta = dict(run_meta.get("split") or {})
    test_idx, train_sims, test_sims = _indices_for_test_split(dataset.sim_names, split_meta)
    if hasattr(dataset, "subset_by_case_indices"):
        eval_dataset = dataset.subset_by_case_indices(test_idx)
    else:
        eval_dataset = Subset(dataset, test_idx)

    device = _resolve_device(str(eval_cfg.get("device", "auto")))
    batch_size = int(eval_cfg.get("batch_size", 4))
    num_workers = int(eval_cfg.get("num_workers", 0))
    if batch_size < 1:
        raise ValueError("eval.batch_size must be >= 1.")
    if num_workers < 0:
        raise ValueError("eval.num_workers must be >= 0.")
    dataloader = DataLoader(
        eval_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=device.type == "cuda",
        persistent_workers=num_workers > 0,
        collate_fn=adapter.collate_fn(),
    )

    # PhysicsNeMo re-exports Module at the top level. NGC physicsnemo:25.11
    # dropped the legacy `physicsnemo.core.module` path; the top-level export
    # works in both NGC and the vendored submodule.
    module_cls = import_physicsnemo_attr("physicsnemo", "Module")
    model = module_cls.from_checkpoint(str(checkpoint_path)).to(device)
    # Inference mode for metrics and plots: dropout/batch-norm layers (if any)
    # must not behave as in training.
    model.eval()

    loss_name = str(training_meta.get("loss", "mse"))
    loss_fn = get_loss_fn(loss_name if loss_name in {"mse", "l1", "relative_l2"} else "mse")
    experiment_entrypoint = eval_cfg.get("experiment") or training_meta.get("experiment")
    experiment = _build_experiment(
        experiment_entrypoint=experiment_entrypoint,
        model=model,
        optimizer=None,
        loss_fn=loss_fn,
        adapter=adapter,
        device=device,
    )

    if hasattr(dataset, "output_columns"):
        output_fields = list(dataset.output_columns)
    else:
        output_fields = list(dataset.output_fields)

    total_se_per_field = torch.zeros(len(output_fields), dtype=torch.float64)
    total_samples = 0
    all_preds: list[torch.Tensor] = []
    all_targets: list[torch.Tensor] = []
    with torch.no_grad():
        for batch in dataloader:
            pred, target = experiment.eval_step(batch)
            field_se, sample_count = adapter.accumulate_metrics(batch, pred, target)
            total_se_per_field += field_se.detach().to(torch.float64).cpu()
            total_samples += int(sample_count)
            all_preds.append(pred.detach().cpu())
            all_targets.append(target.detach().cpu())
    if total_samples == 0:
        raise RuntimeError("No evaluation samples were processed.")

    per_field_mse = total_se_per_field / float(total_samples)
    per_field_rmse = torch.sqrt(per_field_mse)
    overall_mse = float(per_field_mse.mean().item())
    overall_rmse = math.sqrt(overall_mse)

    # Experiment-specific extended metrics (no-op for the base Experiment;
    # alpha-D and other case experiments override to add per-field R², per-region
    # breakdown, Δp integration error, etc.).
    extended_metrics: dict[str, Any] = experiment.compute_extended_metrics(
        eval_dataset,
        all_preds,
        all_targets,
    )

    plot_files: list[str] = []
    plot_dir_value = output_cfg.get("plot_dir")
    if plot_dir_value is not None:
        try:
            _save_eval_plots(
                plot_files,
                adapter=adapter,
                model=model,
                experiment=experiment,
                eval_dataset=eval_dataset,
                output_fields=output_fields,
                device=device,
                output_cfg=output_cfg,
                plot_dir_value=plot_dir_value,
                extended_metrics=extended_metrics,
                all_preds=all_preds,
                all_targets=all_targets,
            )
        except ModuleNotFoundError as exc:
            print(f"Skipping plot generation: {exc}")

    payload = {
        "zarr_dir": str(_resolve_path(str(data_cfg["zarr_dir"]))),
        "checkpoint": str(checkpoint_path),
        "run_meta": str(run_meta_path),
        "adapter": adapter.family,
        "num_cases": len(test_idx),
        "num_samples": total_samples,
        "train_cases": len(train_sims),
        "test_cases": len(test_idx),
        "overall": {
            "mse": overall_mse,
            "rmse": overall_rmse,
        },
        "per_field": [
            {
                "name": field_name,
                "mse": float(mse_value),
                "rmse": float(rmse_value),
            }
            for field_name, mse_value, rmse_value in zip(
                output_fields,
                per_field_mse.tolist(),
                per_field_rmse.tolist(),
            )
        ],
        "plots": {
            "plot_dir": str(_resolve_path(str(plot_dir_value)))
            if plot_dir_value is not None
            else None,
            "num_saved": len(plot_files),
            "files": plot_files,
        },
        "split": {
            "train_sims": train_sims,
            "test_sims": test_sims,
        },
    }
    if extended_metrics:
        payload["extended"] = extended_metrics

    print(
        f"Evaluated adapter='{adapter.family}' on {len(test_idx)} test case(s) "
        f"({total_samples} sample(s)), "
        f"overall mse={overall_mse:.6e}, rmse={overall_rmse:.6e}."
    )
    for row in payload["per_field"]:
        print(f"{row['name']}: mse={row['mse']:.6e}, rmse={row['rmse']:.6e}")
    if extended_metrics:
        experiment.print_extended_metrics(extended_metrics)

    metrics_out_path = _resolve_metrics_out_path(output_cfg, checkpoint_path)
    if metrics_out_path is not None:
        metrics_out_path.parent.mkdir(parents=True, exist_ok=True)
        metrics_out_path.write_text(json.dumps(payload, indent=2), encoding="utf-8")
        print(f"Saved metrics JSON to {metrics_out_path}")
    if plot_files:
        print(f"Saved {len(plot_files)} plot(s) to {plot_dir_value}")

    return payload