training

The generic training framework. One pair of entry points (train.py, evaluate.py) drives every model through a model registry, four family-specific adapter classes (grid, graph, pointwise, profile), a shared runner, and pluggable experiment hooks.

Package

Shared helpers for generic training and evaluation workflows.

training.import_physicsnemo_module(module_path)[source]

Import a physicsnemo module from installed package or vendored source.

training.import_physicsnemo_attr(module_path, attr_name)[source]

Import a single symbol from a physicsnemo module.

Adapters

Adapter layer to unify grid and graph model families.

class training.adapters.ModelAdapter[source]

Bases: object

Interface for model/data-family-specific behavior.

family: str
build_dataset(data_cfg)[source]
dataset_info(dataset)[source]
Return type:

dict

build_batch(raw_batch, device)[source]
forward_train(model, batch)[source]
Return type:

tuple[Tensor, Tensor]

forward_eval(model, batch)[source]
Return type:

tuple[Tensor, Tensor]

collate_fn()[source]
Return type:

Callable | None

accumulate_metrics(batch, pred, target)[source]
Return type:

tuple[Tensor, int]

class training.adapters.GridAdapter[source]

Bases: ModelAdapter

family: str = 'grid'
build_dataset(data_cfg)[source]
Return type:

GridPairDataset

dataset_info(dataset)[source]
Return type:

dict

build_batch(raw_batch, device)[source]
forward_train(model, batch)[source]
Return type:

tuple[Tensor, Tensor]

forward_eval(model, batch)[source]
Return type:

tuple[Tensor, Tensor]

accumulate_metrics(batch, pred, target)[source]
Return type:

tuple[Tensor, int]

class training.adapters.GraphAdapter[source]

Bases: ModelAdapter

family: str = 'graph'
build_dataset(data_cfg)[source]
Return type:

GraphPairDataset

dataset_info(dataset)[source]
Return type:

dict

collate_fn()[source]
Return type:

Callable | None

build_batch(raw_batch, device)[source]
forward_train(model, batch)[source]
Return type:

tuple[Tensor, Tensor]

forward_eval(model, batch)[source]
Return type:

tuple[Tensor, Tensor]

accumulate_metrics(batch, pred, target)[source]
Return type:

tuple[Tensor, int]

class training.adapters.PointwiseAdapter[source]

Bases: ModelAdapter

Adapter for tabular / pointwise MLP models.

Expects datasets producing (x, y) tuples of shape (D_in,) and (D_out,) which the default collate batches into (B, D_in) and (B, D_out).

family: str = 'pointwise'
build_dataset(data_cfg)[source]
dataset_info(dataset)[source]
Return type:

dict

build_batch(raw_batch, device)[source]
forward_train(model, batch)[source]
forward_eval(model, batch)[source]
accumulate_metrics(batch, pred, target)[source]
Return type:

tuple[Tensor, int]

class training.adapters.ProfileAdapter[source]

Bases: ModelAdapter

Adapter for per-case profile (1D-conv) models.

Datasets emit per-case (x, y, w, case_idx) items shaped [C, S], [O, S], [1, S], scalar. After default collation the batch is [B, C, S], [B, O, S], [B, 1, S], [B].

Dataset construction is delegated to a case-side callable resolved from data.dataset_entrypoint (mirrors model.entrypoint); the adapter itself knows nothing about specific cases.

family: str = 'profile'
build_dataset(data_cfg)[source]
dataset_info(dataset)[source]
Return type:

dict

build_batch(raw_batch, device)[source]
forward_train(model, batch)[source]
forward_eval(model, batch)[source]
accumulate_metrics(batch, pred, target)[source]
Return type:

tuple[Tensor, int]

training.adapters.get_adapter(name)[source]
Return type:

ModelAdapter

Runner

Generic training/evaluation runners for supervised one-step models.

training.runner.to_plain_dict(cfg)[source]

Convert OmegaConf DictConfig or dict to a plain dict.

Return type:

dict[str, Any]

training.runner.set_seed(seed)[source]

Seed all random number generators for reproducibility.

Return type:

None
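
A minimal sketch of what seeding "all" generators can look like. Which libraries the real set_seed touches is not stated here, so the optional NumPy and PyTorch branches below are assumptions, and set_seed_sketch is a hypothetical name.

```python
import random


def set_seed_sketch(seed: int) -> None:
    """Seed Python's RNG plus NumPy and PyTorch when they are importable."""
    random.seed(seed)
    try:
        import numpy as np  # optional dependency (assumption)

        np.random.seed(seed)
    except ImportError:
        pass
    try:
        import torch  # optional dependency (assumption)

        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)
    except ImportError:
        pass


# Re-seeding with the same value reproduces the same random draws.
set_seed_sketch(0)
first = [random.random() for _ in range(3)]
set_seed_sketch(0)
second = [random.random() for _ in range(3)]
```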

training.runner.resolve_device(device_arg)[source]

Parse a device string (‘auto’, ‘cpu’, ‘cuda’) into a torch.device.

Return type:

device

training.runner.build_experiment(experiment_entrypoint, model, optimizer, loss_fn, adapter, device, **kwargs)[source]

Instantiate an Experiment from an optional entrypoint string.

Extra kwargs are forwarded to the experiment constructor, allowing custom experiments to accept domain-specific arguments.

Return type:

Experiment

training.runner.normalize_split_cfg(split_cfg, default_seed)[source]

Fill in default values for a split config dict.

Return type:

dict[str, Any]

training.runner.prepare_training(cfg_dict)[source]

Build adapter, dataset, dataset_info, and build_fn from a config dict.

Returns a dict with keys: model_cfg, data_cfg, training_cfg, output_cfg, adapter_name, adapter, dataset, dataset_info, build_fn, device, seed. This is the shared setup that both train() and HPO use.

Return type:

dict[str, Any]

training.runner.train_one_epoch(experiment, dataloader)[source]

Run one training epoch, return average loss.

Raises RuntimeError if the dataloader produces zero batches.

Return type:

float
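
The contract above (average loss, RuntimeError on an empty loader) can be sketched without any framework. The stand-in experiment class is hypothetical, and averaging per batch rather than per sample is an assumption.

```python
from typing import Any, Iterable


class LengthLossExperiment:
    """Hypothetical stand-in exposing only training_step(batch) -> float."""

    def training_step(self, batch: Any) -> float:
        # Pretend the loss equals the batch length, just to have a number.
        return float(len(batch))


def train_one_epoch_sketch(experiment: Any, dataloader: Iterable[Any]) -> float:
    """Average the per-batch losses; refuse to average over zero batches."""
    total = 0.0
    num_batches = 0
    for batch in dataloader:
        total += experiment.training_step(batch)
        num_batches += 1
    if num_batches == 0:
        raise RuntimeError("dataloader produced zero batches")
    return total / num_batches


# Two batches with losses 2.0 and 4.0.
avg = train_one_epoch_sketch(LengthLossExperiment(), [[1, 2], [3, 4, 5, 6]])
```

Raising on an empty loader surfaces a misconfigured split immediately instead of returning a meaningless 0.0 or NaN.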

training.runner.compute_val_loss(experiment, val_loader)[source]

Evaluate on a validation loader, return average loss.

Uses experiment.validation_step(), which by default runs eval_step() and then applies loss_fn. Custom experiments can override it.

Raises RuntimeError if the loader produces zero batches.

Return type:

float

training.runner.train(cfg)[source]

Train a supervised model and save checkpoint + run_meta.json.

Return type:

dict[str, Any]

training.runner.train_or_hpo(cfg)[source]

Dispatch to Optuna HPO when cfg.hpo.search_space is populated; train otherwise.

Single source of truth for the train-vs-HPO branching shared by every @hydra.main entry point (top-level src/train.py and the per-case src/cases/<case>/train.py wrappers). Entry-point choice only affects which config is loaded — never whether HPO runs.

Return type:

None
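
The branching rule can be sketched as a small dispatcher. Here the config is a plain dict and the two callables are passed in; the real function returns None, so the returned label exists only to make the sketch easy to check.

```python
from typing import Any, Callable, Mapping


def train_or_hpo_sketch(
    cfg: Mapping[str, Any],
    train_fn: Callable[[Mapping[str, Any]], Any],
    hpo_fn: Callable[[Mapping[str, Any]], Any],
) -> str:
    """Run HPO only when cfg['hpo']['search_space'] is non-empty."""
    hpo_cfg = cfg.get("hpo") or {}
    search_space = hpo_cfg.get("search_space") or {}
    if search_space:
        hpo_fn(cfg)
        return "hpo"
    train_fn(cfg)
    return "train"


calls = []
mode_with_space = train_or_hpo_sketch(
    {"hpo": {"search_space": {"training.lr": {"type": "float"}}}},
    lambda cfg: calls.append("train"),
    lambda cfg: calls.append("hpo"),
)
mode_without_space = train_or_hpo_sketch(
    {"hpo": {"search_space": {}}},
    lambda cfg: calls.append("train"),
    lambda cfg: calls.append("hpo"),
)
```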

training.runner.evaluate(cfg)[source]

Evaluate checkpoint using run_meta.json to reconstruct dataset and split.

Return type:

dict[str, Any]

Experiment

Experiment abstraction for optional custom training/eval steps.

class training.experiment.Experiment(model, optimizer, loss_fn, adapter, device, **kwargs)[source]

Bases: object

Default experiment for supervised one-step training.

training_step(batch)[source]
Return type:

float

eval_step(batch)[source]
Return type:

tuple[Tensor, Tensor]

validation_step(batch)[source]

Compute validation loss for one batch.

Override in subclasses for custom validation logic (e.g., weighted metrics, physics-informed losses). The default uses eval_step followed by loss_fn.

Return type:

float

validation_epoch_loss(val_loader)[source]

Optional validation loss term computed once per epoch.

Subclasses can override this to add case-level or physics-based penalties that require access to the full validation set rather than individual batches.

Return type:

float

on_epoch_end(epoch, avg_loss)[source]
Return type:

None

compute_extended_metrics(eval_dataset, all_preds, all_targets)[source]

Per-experiment extended evaluation metrics.

Receives the eval dataset and per-batch lists of CPU tensors (left unconcatenated because graph adapters produce variable shapes). Subclasses concatenate as needed and produce a metrics dict that the runner merges into the extended block of the metrics JSON. Default: empty.

Return type:

dict[str, Any]

print_extended_metrics(metrics)[source]

Pretty-print the dict returned by compute_extended_metrics.

Default: no-op (the generic runner already prints overall + per-field MSE/RMSE). Subclasses override to format case-specific blocks.

Return type:

None

decode_for_plotting(values, dataset, field_name, mask)[source]

Decode encoded model output to physical/labelled space for plotting.

Returns (decoded_ndarray, label) or None. When None, the plotter shows raw encoded values with field_name as the y-axis label and skips the physical-RMSE subtitle. Case experiments override to apply baseline re-add, unit conversion, etc.

baseline_for_plotting(dataset, field_name, mask)[source]

Return (decoded_baseline_ndarray, label) or None.

Case experiments that train a residual on top of an analytical baseline override this so the plotter can overlay the baseline as a third curve next to ground truth and prediction. Default: no baseline.

prepare_for_training(train_dataset, val_dataset, device)[source]

Bind dataset-derived state onto the experiment before training.

Called once after datasets are subsetted but before the training loop. Case experiments override to capture per-case geometry, target names, normalisation statistics, etc. that the generic core would otherwise have to read by name. Default: no-op.

Return type:

None

on_epoch_end_extra_step()[source]

Optional extra gradient step run once per epoch.

Case experiments override to apply secondary objectives that aren’t well-expressed as per-batch losses (e.g. integral-of-profile losses that need a full case at a time). Default: no-op.

Return type:

None

Losses

Loss registry for supervised one-step training.

training.losses.mse_loss(pred, target, weight=None)[source]
Return type:

torch.Tensor

training.losses.l1_loss(pred, target, weight=None)[source]
Return type:

torch.Tensor

training.losses.relative_l2_loss(pred, target, weight=None)[source]
Return type:

torch.Tensor

training.losses.get_loss_fn(name)[source]
Return type:

Callable[[Tensor, Tensor], Tensor]

Internal datasets

MOOSEDataset (under dataset) is the public Dataset API. The classes below are the input/output-slicing variants used internally by train.py.

Dataset wrappers and split helpers for training workflows.

training.datasets.parse_field_list(raw)[source]

Parse field selections from CLI-style strings or list values.

Return type:

list[str] | None
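
As an illustration, a parser with this return type might look like the following. The comma separator and the "empty means None" rule are assumptions, not confirmed behavior of the real function.

```python
from typing import Iterable, List, Optional, Union


def parse_field_list_sketch(
    raw: Union[str, Iterable[str], None],
) -> Optional[List[str]]:
    """Turn 'a, b' or ['a', 'b'] into ['a', 'b']; return None when nothing is selected."""
    if raw is None:
        return None
    # Assumption: CLI-style strings separate field names with commas.
    items = raw.split(",") if isinstance(raw, str) else list(raw)
    fields = [str(item).strip() for item in items if str(item).strip()]
    return fields or None


from_string = parse_field_list_sketch("vel_x, vel_y")
from_list = parse_field_list_sketch(["vel_x", "vel_y"])
```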

training.datasets.resolve_time_idx(time_idx, num_steps, label)[source]

Resolve negative time indexing and validate the index.

Return type:

int
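
A minimal sketch of negative-index resolution with validation. The exception type and message are assumptions; only the signature comes from the reference above.

```python
def resolve_time_idx_sketch(time_idx: int, num_steps: int, label: str) -> int:
    """Map a possibly negative index into [0, num_steps) or raise."""
    # Negative indices count from the end, as in Python sequence indexing.
    resolved = time_idx + num_steps if time_idx < 0 else time_idx
    if not 0 <= resolved < num_steps:
        raise IndexError(
            f"{label}={time_idx} is out of range for {num_steps} time steps"
        )
    return resolved


last_step = resolve_time_idx_sketch(-1, 10, "target_time_idx")
first_step = resolve_time_idx_sketch(0, 10, "input_time_idx")
```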

class training.datasets.GridPairDataset(zarr_dir, input_fields, output_fields, input_time_idx, target_time_idx)[source]

Bases: Dataset

Build supervised grid pairs (x, y) from MOOSEDataset grid samples.

class training.datasets.GraphPairDataset(zarr_dir, input_fields, output_fields, input_time_idx, target_time_idx)[source]

Bases: Dataset

Build supervised PyG Data objects from MOOSEDataset graph samples.

training.datasets.split_indices(num_cases, split_cfg, sim_names)[source]

Return split indices and simulation-name lists.

Return type:

tuple[list[int], list[int], list[str], list[str]]
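
One way a case-level split with this return shape could work is sketched below. The split_cfg keys (seed, test_fraction), the seeded shuffle, and the ordering of the returned tuple (train indices, test indices, train names, test names) are all assumptions for illustration.

```python
import random
from typing import Any, Dict, List, Sequence, Tuple


def split_indices_sketch(
    num_cases: int,
    split_cfg: Dict[str, Any],
    sim_names: Sequence[str],
) -> Tuple[List[int], List[int], List[str], List[str]]:
    """Shuffle case indices with a fixed seed and carve off a test fraction."""
    if len(sim_names) != num_cases:
        raise ValueError("sim_names must have one entry per case")
    indices = list(range(num_cases))
    random.Random(split_cfg.get("seed", 0)).shuffle(indices)
    num_test = int(round(num_cases * split_cfg.get("test_fraction", 0.2)))
    test_idx = sorted(indices[:num_test])
    train_idx = sorted(indices[num_test:])
    train_sims = [sim_names[i] for i in train_idx]
    test_sims = [sim_names[i] for i in test_idx]
    return train_idx, test_idx, train_sims, test_sims


names = [f"case_{i}" for i in range(10)]
cfg = {"seed": 42, "test_fraction": 0.2}
train_idx, test_idx, train_sims, test_sims = split_indices_sketch(10, cfg, names)
```

Splitting by case rather than by row keeps all stations of one simulation on the same side of the split.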

Tabular dataset for pointwise/axial-profile MLP training.

Reads per-case Zarr stores produced by the alpha_D ETL pipeline. Each store contains:

{case_name}.zarr/
    features/   float32 [N_stations, D_in]
    targets/    float32 [N_stations, D_out]
    metadata/   attrs: case_id, feature_names, target_names, …

All cases are loaded and concatenated row-wise. Splitting is done at the case level via subset_by_case_indices.

class training.datasets_tabular.TabularPairDataset(zarr_dir, input_columns=None, output_columns=None, normalize=False, norm_stats=None, norm_from_case_indices=None, throat_weight=None, downstream_weight=None, include_case_idx=False, exclude_cases=None, local_velocity_normalization=False, min_Dr=None, target_transform=None, engineered_feature_names=None, engineered_feature_builder=None)[source]

Bases: Dataset

Reads a directory of .zarr stores and produces (x, y) pairs.

Parameters:
  • zarr_dir (str or Path) – Directory containing *.zarr stores.

  • input_columns (list[str] or None) – Feature column names to select. If None, use all features.

  • output_columns (list[str] or None) – Target column names to select. If None, use all targets.

  • normalize (bool) – If True, z-score normalize input features after loading. Statistics are computed from the loaded data (or from externally supplied norm_stats).

  • norm_stats (dict or None) – Externally supplied {"x_mean": Tensor, "x_std": Tensor}. If None and normalize is True, computed from the loaded data.

  • throat_weight (float or None) – Stations where is_throat == 1 receive this weight; others receive weight 1.

  • downstream_weight (float or None) – Stations where is_downstream == 1 receive this weight. Applied multiplicatively with throat_weight when both are set.

  • include_case_idx (bool) – If True, __getitem__ returns a case-index tensor as the last element so that per-case losses can be computed.

  • min_Dr (float or None) – If set, exclude cases whose diameter ratio Dr is below this value. Dr is parsed from the case name (Re_*__Dr_XpXXX__Lr_*).

  • engineered_feature_names (list[str] or None) – Names of synthesized columns appended to each row in the order given. If None, no engineered columns are synthesized. Caller-supplied because engineered-feature schemas are case-specific.

  • engineered_feature_builder (callable or None) – (features, raw_feature_names) -> dict[name, ndarray[N]] mapping each name in engineered_feature_names to a 1-D column. Required when engineered_feature_names is set.

  • target_transform (callable or None) – Case-specific transform applied to full_y after local-velocity normalisation. Receives (full_y, full_x) plus keyword context (target_names, feature_names, case_meta_list, rows_per_case, local_velocity_normalization) and returns (transformed_y, extras). When extras contains baseline_encoded, it is stashed as self._baseline_encoded and self.has_target_baseline is set to True so consumers can re-add the baseline at decode time.
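
The weighting and min_Dr rules from the parameter list can be sketched in plain Python. The regular expression assumes "p" stands for the decimal point in Dr_XpXXX, the example case names are hypothetical, and dropping cases whose name cannot be parsed is an assumption.

```python
import re
from typing import Optional

_DR_PATTERN = re.compile(r"__Dr_(\d+)p(\d+)__")


def parse_dr(case_name: str) -> Optional[float]:
    """Extract the diameter ratio from a name like Re_*__Dr_0p500__Lr_*."""
    match = _DR_PATTERN.search(case_name)
    if match is None:
        return None
    return float(f"{match.group(1)}.{match.group(2)}")


def keep_case(case_name: str, min_Dr: Optional[float]) -> bool:
    """Keep every case when min_Dr is None; otherwise require Dr >= min_Dr."""
    if min_Dr is None:
        return True
    dr = parse_dr(case_name)
    return dr is not None and dr >= min_Dr


def row_weight(
    is_throat: int,
    is_downstream: int,
    throat_weight: Optional[float] = None,
    downstream_weight: Optional[float] = None,
) -> float:
    """Start at 1 and multiply in each configured weight whose flag is set."""
    weight = 1.0
    if throat_weight is not None and is_throat == 1:
        weight *= throat_weight
    if downstream_weight is not None and is_downstream == 1:
        weight *= downstream_weight
    return weight


dr_value = parse_dr("Re_1000__Dr_0p500__Lr_2p0")
both_flags = row_weight(1, 1, throat_weight=2.0, downstream_weight=3.0)
```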

property in_features: int
property out_features: int
property sim_names: list[str]

Unique case IDs in discovery order (compatible with split_indices).

subset_by_case_indices(case_indices)[source]

Return a new dataset containing only rows for the given case indices.

case_indices indexes into self.sim_names.

Return type:

TabularPairDataset

add_baseline_to_encoded(encoded, row_mask=None, field_idx=None)[source]

Re-add the per-row encoded baseline to an encoded tensor.

No-op when no target baseline was attached (has_target_baseline is False) so callers can use it unconditionally at decode boundaries.

Parameters:
  • encoded (torch.Tensor) – Encoded tensor in residual space (e.g. a model prediction).

  • row_mask (np.ndarray | torch.Tensor | None) – Optional row selector — numpy boolean / integer array or torch tensor — to slice the baseline before adding.

  • field_idx (int | None) – Output-field index when the dataset has multiple target columns. When None, falls back to auto-squeezing a single-field baseline so 1-D encoded tensors broadcast correctly.

Return type:

torch.Tensor

Split I/O and plotting

Helpers for exporting reusable train/test split files.

training.split_io.write_sim_name_list(path, sim_names)[source]

Write one simulation name per line and return the resolved output path.

Return type:

Path
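
A minimal sketch of the one-name-per-line format using pathlib. Creating missing parent directories and writing a trailing newline are assumptions about the real helper.

```python
import tempfile
from pathlib import Path
from typing import Iterable, Union


def write_sim_name_list_sketch(
    path: Union[str, Path], sim_names: Iterable[str]
) -> Path:
    """Write one simulation name per line and return the resolved path."""
    out_path = Path(path).expanduser().resolve()
    out_path.parent.mkdir(parents=True, exist_ok=True)
    out_path.write_text(
        "".join(f"{name}\n" for name in sim_names), encoding="utf-8"
    )
    return out_path


# Write into a temporary directory so the example leaves nothing behind.
with tempfile.TemporaryDirectory() as tmp:
    written = write_sim_name_list_sketch(
        Path(tmp) / "splits" / "train.txt", ["case_a", "case_b"]
    )
    contents = written.read_text(encoding="utf-8")
```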

training.split_io.export_split_files(train_sims, test_sims, output_dir, *, train_filename='train.txt', test_filename='test.txt')[source]

Export train/test simulation-name lists into a directory.

Return type:

dict[str, str]

training.split_io.load_run_meta(run_meta_path)[source]

Load a run_meta JSON file.

Return type:

dict[str, Any]

training.split_io.export_split_files_from_run_meta(run_meta_path, output_dir, *, train_filename='train.txt', test_filename='test.txt')[source]

Export reusable split files from a run_meta.json file.

Return type:

dict[str, str]

Plotting helpers for evaluation outputs.

training.plotting.parse_index_list(raw)[source]
Return type:

list[int] | None

training.plotting.resolve_plot_indices(num_cases, raw_indices, max_cases)[source]
Return type:

list[int]

training.plotting.select_best_worst_pointwise_cases(extended_metrics, output_fields)[source]

Choose one best and one worst pointwise case for profile plotting.

Return type:

list[dict[str, Any]]

training.plotting.save_pointwise_profile_plots(model, dataset, output_fields, device, plot_dir, case_entries, *, plot_dpi=150, decode_fn=None, baseline_fn=None)[source]

Save best/worst profile plots for pointwise/tabular models.

When decode_fn is supplied (typically experiment.decode_for_plotting), the plotter applies it to both predicted and target tensors before plotting; the function returns the decoded values and the y-axis label. When decode_fn is None or returns None, the plotter shows raw encoded values with field_name as the label.

When baseline_fn is supplied (typically experiment.baseline_for_plotting) and returns a non-None decoded array, it is overlaid as a third “Baseline” curve so model-versus-baseline improvement is visible. Has no effect when baseline_fn is None or the callback returns None.

Return type:

list[str]

training.plotting.save_profile_prediction_plots(model, dataset, output_fields, device, plot_dir, case_entries, *, plot_dpi=150, decode_fn=None, baseline_fn=None)[source]

Save best/worst per-case prediction plots for profile (1D-conv) models.

The profile model consumes per-case [1, C, S] tensors and emits [1, O, S] predictions. This plotter mirrors save_pointwise_profile_plots() (same z_hat axis, same decode_fn / baseline_fn contract, same RMSE subtitle) but runs the model once per case rather than once per row.

Return type:

list[str]

training.plotting.save_parity_plot(*, cat_preds, cat_targets, dataset, output_fields, plot_dir, decode_fn=None, plot_dpi=150)[source]

Save one parity plot per output field across the whole evaluation set.

Each scatter point is one (case, station) pair: ground-truth value on X, predicted value on Y, with a y=x reference line. When decode_fn is supplied, predictions and targets are mapped to physical space first (case-by-case via the same hook the per-case plots use). When the dataset’s input_columns include is_upstream / is_throat / is_downstream flags, points are colored by region.

Accepts both pointwise ([N_rows, O]) and profile ([N_cases, O, S]) prediction tensors; profile tensors are scattered to flat row layout via dataset._case_slices so the same plotting path works for both.

Return type:

list[str]

training.plotting.save_delta_p_parity_plot(*, per_case, plot_dir, plot_dpi=150)[source]

Save a Δp parity plot: one marker per test case, log-log.

per_case is the full sorted list produced by compute_delta_p_metrics (extended.delta_p.per_case). Each entry must carry delta_p_gt / delta_p_pred; when Dr is present, points are colored on a discrete diameter-ratio colormap so the low-Dr cluster (the typical alpha-D failure mode) stands out.

Return type:

str | None

training.plotting.save_grid_prediction_plots(model, dataset, output_fields, device, plot_dir, plot_indices, plot_cmap='viridis', plot_dpi=150, quiver_step=4, vel_x_field='vel_x', vel_y_field='vel_y')[source]
Return type:

list[str]

Model registry

Model registry and entrypoint resolution for training workflows.

class training.models.ModelEntry(build_fn, adapter)[source]

Bases: object

build_fn: Callable
adapter: str
training.models.register_model(name, build_fn, adapter)[source]
Return type:

None

training.models.resolve_entrypoint(entrypoint)[source]
Return type:

Callable

training.models.get_build_fn_and_adapter(model_cfg)[source]

Resolve model build function and adapter.

Built-in models derive adapter from registry. Custom entrypoints must specify model.adapter explicitly.

Return type:

tuple[Callable, str]
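
The registry rule above (built-ins carry their adapter, custom entrypoints must name one) can be sketched as follows. The "module:attr" entrypoint format, the model_cfg keys name/entrypoint/adapter, and every *_sketch name are assumptions for illustration.

```python
import importlib
from dataclasses import dataclass
from typing import Any, Callable, Dict, Mapping, Tuple


@dataclass(frozen=True)
class ModelEntrySketch:
    build_fn: Callable
    adapter: str


_REGISTRY: Dict[str, ModelEntrySketch] = {}


def register_model_sketch(name: str, build_fn: Callable, adapter: str) -> None:
    _REGISTRY[name] = ModelEntrySketch(build_fn=build_fn, adapter=adapter)


def resolve_entrypoint_sketch(entrypoint: str) -> Callable:
    """Import 'package.module:attr' and return the attribute."""
    module_path, _, attr = entrypoint.partition(":")
    if not module_path or not attr:
        raise ValueError(f"expected 'module:attr', got {entrypoint!r}")
    return getattr(importlib.import_module(module_path), attr)


def get_build_fn_and_adapter_sketch(
    model_cfg: Mapping[str, Any],
) -> Tuple[Callable, str]:
    entrypoint = model_cfg.get("entrypoint")
    if entrypoint:
        adapter = model_cfg.get("adapter")
        if not adapter:
            raise ValueError("custom entrypoints must set model.adapter")
        return resolve_entrypoint_sketch(entrypoint), adapter
    entry = _REGISTRY[model_cfg["name"]]
    return entry.build_fn, entry.adapter


# Illustrative registration; the name and adapter pairing are made up.
register_model_sketch("toy_model", lambda model_cfg, dataset_info: "toy", "grid")
builtin = get_build_fn_and_adapter_sketch({"name": "toy_model"})
custom = get_build_fn_and_adapter_sketch(
    {"entrypoint": "math:sqrt", "adapter": "pointwise"}
)
```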

training.models.model_entrypoint_string(model_cfg, build_fn)[source]
Return type:

str

Built-in models

Built-in FullyConnected MLP model definition.

Supports optional inter-layer dropout for improved regularisation in deeper networks. Dropout is applied between hidden layers during training and disabled at eval time, so no extra state needs to be persisted in the checkpoint.

training.models.mlp.build(model_cfg, dataset_info)[source]

Built-in FNO model definition.

training.models.fno.build(model_cfg, dataset_info)[source]

Built-in AFNO model definition.

training.models.afno.build(model_cfg, dataset_info)[source]

Built-in Pix2Pix model definition.

training.models.pix2pix.build(model_cfg, dataset_info)[source]

Built-in MeshGraphNet model definition.

training.models.meshgraphnet.build(model_cfg, dataset_info)[source]

1D-convolutional profile model.

Each case is one sample of shape [C, S] (channels = features, S = stations) and the model produces [O, S] per case. Replicate padding avoids zero-pad artefacts at the boundary; dilations widen the receptive field without increasing depth.

The class subclasses physicsnemo.Module so checkpoints round-trip through Module.from_checkpoint. PhysicsNeMo records cls.__module__/cls.__name__ at save time and re-imports the class by getattr(module, name) at load time, so the class must live at module scope (not inside build).

When physicsnemo is unavailable in the environment the module imports cleanly but the model is not registered — mirroring how GraphAdapter only activates if PyG is importable. This keeps _load_builtins from hard-failing in envs without physicsnemo.

Backward-compat: AlphaDConv1D was the original class name. The alias subclass below keeps old .mdlus checkpoints loadable, since PhysicsNeMo embeds __name__ in the saved archive.
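
Why an alias subclass is enough can be shown with a toy loader that, like the mechanism described above, looks a class up by its recorded name with getattr. Nothing here uses PhysicsNeMo itself; the classes, the throwaway module object, and load_by_recorded_name are hypothetical stand-ins.

```python
import types
from typing import Any


class Conv1DProfileSketch:
    """Stand-in for the real model class (no physicsnemo dependency here)."""

    def __init__(self, in_channels: int, out_channels: int) -> None:
        self.in_channels = in_channels
        self.out_channels = out_channels


class AlphaDConv1DSketch(Conv1DProfileSketch):
    """Alias subclass: keeps the old class name resolvable at module scope."""


# A throwaway module object standing in for the module that defines the classes.
models_module = types.ModuleType("conv1d_profile_sketch")
models_module.Conv1DProfileSketch = Conv1DProfileSketch
models_module.AlphaDConv1DSketch = AlphaDConv1DSketch


def load_by_recorded_name(
    module: types.ModuleType, class_name: str, **init_kwargs: Any
) -> Any:
    """Resolve the class by name, then construct it from saved arguments."""
    cls = getattr(module, class_name)  # AttributeError if the name is gone
    return cls(**init_kwargs)


# An "old checkpoint" that recorded the original class name still loads.
old_model = load_by_recorded_name(
    models_module, "AlphaDConv1DSketch", in_channels=4, out_channels=2
)
```

A class defined inside a function would never appear as a module attribute, which is why the lookup above requires module-scope definitions.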

class training.models.conv1d_profile.Conv1DProfile(in_channels, out_channels, hidden, num_blocks, kernel_size, dilations, dropout, activation='silu')[source]

Bases: Module

Residual dilated 1D-conv stack over the station axis.

forward(x)[source]
Return type:

Tensor

class training.models.conv1d_profile.AlphaDConv1D(in_channels, out_channels, hidden, num_blocks, kernel_size, dilations, dropout, activation='silu')[source]

Bases: Conv1DProfile

Backward-compat alias for old .mdlus checkpoints.

training.models.conv1d_profile.build(model_cfg, dataset_info)[source]

Hyperparameter optimization

Optuna-based hyperparameter optimization for the training pipeline.

Study creation, orchestration, and artifact saving.

training.hpo.study.create_study(hpo_cfg)[source]

Create or resume an Optuna study from config.

Return type:

Study

training.hpo.study.run_hpo(cfg_dict)[source]

Run hyperparameter optimization.

Parameters:

cfg_dict (dict) – Full Hydra config (resolved) with hpo section + base training config.

Returns:

Summary with best trial info, study stats, saved artifacts.

Return type:

dict

Optuna objective function for hyperparameter optimization.

training.hpo.objective.make_objective(base_cfg, search_space, hpo_cfg, prepared, train_inner_idx, val_idx)[source]

Create an Optuna objective function.

The returned closure rebuilds only the model, optimizer, and loss per trial. The dataset, adapter, and dataset_info are cached from prepared (the output of runner.prepare_training).

Parameters:
  • base_cfg (dict) – Full training config without the hpo section.

  • search_space (dict) – YAML search-space definition (dot-path -> spec).

  • hpo_cfg (dict) – The hpo section of the config.

  • prepared (dict) – Output of runner.prepare_training(base_cfg).

  • train_inner_idx (list[int]) – Case indices (into dataset.sim_names) for inner training.

  • val_idx (list[int]) – Case indices for validation.

Return type:

Callable[[Trial], float]

Parse YAML search space definitions into Optuna trial suggestions.

training.hpo.search_space.validate_search_space(search_space, base_cfg)[source]

Check that all search-space keys are safe and exist in base_cfg.

Raises ValueError for unsafe prefixes and KeyError for dot-paths that do not exist in base_cfg (catches typos).

Return type:

None

training.hpo.search_space.sample_from_search_space(trial, search_space)[source]

Sample hyperparameters from a YAML search-space definition.

Each entry in search_space maps a dot-path config key to a spec dict:

training.lr:
  type: float
  low: 1e-5
  high: 1e-2
  log: true

Supported types: float, int, categorical.

Returns a dict mapping dot-path keys to sampled values.

Return type:

dict[str, Any]
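
The mapping from spec dicts to trial suggestions can be sketched as below. FakeTrial is a hypothetical stand-in that mimics the Optuna Trial methods suggest_float, suggest_int, and suggest_categorical so the example runs without Optuna; the choices key for categorical specs and the model.num_layers / training.loss paths are assumptions.

```python
import math
import random
from typing import Any, Dict, Mapping, Sequence


class FakeTrial:
    """Hypothetical stand-in for an Optuna trial (random sampling only)."""

    def __init__(self, seed: int = 0) -> None:
        self._rng = random.Random(seed)

    def suggest_float(self, name: str, low: float, high: float, *, log: bool = False) -> float:
        if log:
            # Sample uniformly in log space, then map back.
            return math.exp(self._rng.uniform(math.log(low), math.log(high)))
        return self._rng.uniform(low, high)

    def suggest_int(self, name: str, low: int, high: int) -> int:
        return self._rng.randint(low, high)

    def suggest_categorical(self, name: str, choices: Sequence[Any]) -> Any:
        return self._rng.choice(list(choices))


def sample_from_search_space_sketch(
    trial: Any, search_space: Mapping[str, Mapping[str, Any]]
) -> Dict[str, Any]:
    """Return {dot_path: sampled_value} for every entry in the search space."""
    sampled: Dict[str, Any] = {}
    for key, spec in search_space.items():
        kind = spec["type"]
        if kind == "float":
            sampled[key] = trial.suggest_float(
                key, spec["low"], spec["high"], log=spec.get("log", False)
            )
        elif kind == "int":
            sampled[key] = trial.suggest_int(key, spec["low"], spec["high"])
        elif kind == "categorical":
            sampled[key] = trial.suggest_categorical(key, spec["choices"])
        else:
            raise ValueError(f"unsupported type {kind!r} for {key!r}")
    return sampled


space = {
    "training.lr": {"type": "float", "low": 1e-5, "high": 1e-2, "log": True},
    "model.num_layers": {"type": "int", "low": 2, "high": 6},
    "training.loss": {"type": "categorical", "choices": ["mse", "l1"]},
}
sampled = sample_from_search_space_sketch(FakeTrial(seed=0), space)
```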

training.hpo.search_space.apply_overrides(base_cfg, overrides)[source]

Deep-copy base_cfg and set values at dot-paths.

Raises KeyError if any intermediate or leaf key does not already exist in the base config. This prevents silent creation of new keys from typos.

Return type:

dict
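
A minimal sketch of strict dot-path overrides on a nested dict. It follows the stated contract (deep copy, KeyError on any missing key); the function name and the example config are illustrative.

```python
import copy
from typing import Any, Dict, Mapping


def apply_overrides_sketch(
    base_cfg: Mapping[str, Any], overrides: Mapping[str, Any]
) -> Dict[str, Any]:
    """Return a deep copy of base_cfg with values replaced at existing dot-paths."""
    cfg = copy.deepcopy(dict(base_cfg))
    for dot_path, value in overrides.items():
        *parents, leaf = dot_path.split(".")
        node: Any = cfg
        for part in parents:
            if not isinstance(node, dict) or part not in node:
                raise KeyError(dot_path)
            node = node[part]
        # The leaf must already exist, so typos cannot create new keys.
        if not isinstance(node, dict) or leaf not in node:
            raise KeyError(dot_path)
        node[leaf] = value
    return cfg


base = {"training": {"lr": 1e-3, "epochs": 10}}
updated = apply_overrides_sketch(base, {"training.lr": 0.01})
```

Because the copy is deep, each trial can mutate its own config without leaking values into the shared base config.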

Non-fatal Optuna visualization helpers.

training.hpo.visualize.save_study_plots(study, output_dir)[source]

Generate and save standard Optuna plots. Non-fatal on errors.

Returns list of saved file paths (may be empty if matplotlib or Optuna visualization is not available).

Return type:

list[str]