training
The generic training framework. One pair of entry points (train.py,
evaluate.py) drives every model through a model registry, four adapter
classes (grid, graph, pointwise, profile; the graph adapter activates only
when PyG is importable), a shared runner, and pluggable experiment hooks.
Package
Shared helpers for generic training and evaluation workflows.
Adapters
Adapter layer to unify grid and graph model families.
- class training.adapters.ModelAdapter
Bases: object
Interface for model/data-family-specific behavior.
- class training.adapters.GridAdapter
Bases: ModelAdapter
- class training.adapters.GraphAdapter
Bases: ModelAdapter
- class training.adapters.PointwiseAdapter
Bases: ModelAdapter
Adapter for tabular / pointwise MLP models.
Expects datasets producing (x, y) tuples of shape (D_in,) and (D_out,), which the default collate batches into (B, D_in) and (B, D_out).
- class training.adapters.ProfileAdapter
Bases: ModelAdapter
Adapter for per-case profile (1D-conv) models.
Datasets emit per-case (x, y, w, case_idx) items shaped [C, S], [O, S], [1, S], and a scalar, respectively. After default collation the batch is [B, C, S], [B, O, S], [B, 1, S], [B].
Dataset construction is delegated to a case-side callable resolved from data.dataset_entrypoint (mirroring model.entrypoint); the adapter itself knows nothing about specific cases.
Runner
Generic training/evaluation runners for supervised one-step models.
- training.runner.set_seed(seed)
Seed all random number generators for reproducibility.
- training.runner.resolve_device(device_arg)
Parse a device string (‘auto’, ‘cpu’, ‘cuda’) into a torch.device.
Return type: torch.device
- training.runner.build_experiment(experiment_entrypoint, model, optimizer, loss_fn, adapter, device, **kwargs)
Instantiate an Experiment from an optional entrypoint string.
Extra kwargs are forwarded to the experiment constructor, allowing custom experiments to accept domain-specific arguments.
Return type: Experiment
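The entrypoint-string pattern (also used by model.entrypoint and data.dataset_entrypoint) can be sketched as follows. This assumes a hypothetical "package.module:ClassName" string format and uses a stand-in default class; the real string syntax, default, and error handling may differ.

```python
import importlib
from typing import Any, Optional


class DefaultExperiment:
    """Stand-in for training.experiment.Experiment (hypothetical, for illustration)."""

    def __init__(self, model: Any, optimizer: Any, loss_fn: Any, adapter: Any,
                 device: Any, **kwargs: Any) -> None:
        self.model = model
        self.optimizer = optimizer
        self.loss_fn = loss_fn
        self.adapter = adapter
        self.device = device
        # Domain-specific extras forwarded by the caller.
        self.extra = kwargs


def resolve_entrypoint(spec: str) -> Any:
    """Import 'package.module:attr' and return the attribute (assumed format)."""
    module_name, sep, attr = spec.partition(":")
    if not sep or not module_name or not attr:
        raise ValueError(f"Expected 'module:attr', got {spec!r}")
    module = importlib.import_module(module_name)
    return getattr(module, attr)


def build_experiment_sketch(experiment_entrypoint: Optional[str], model: Any,
                            optimizer: Any, loss_fn: Any, adapter: Any,
                            device: Any, **kwargs: Any) -> Any:
    # No entrypoint: fall back to the default experiment class.
    if experiment_entrypoint:
        cls = resolve_entrypoint(experiment_entrypoint)
    else:
        cls = DefaultExperiment
    # Extra kwargs go straight to the constructor.
    return cls(model, optimizer, loss_fn, adapter, device, **kwargs)
```

Resolving the class lazily keeps case-specific experiment modules out of the generic core's import graph.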
- training.runner.normalize_split_cfg(split_cfg, default_seed)
Fill in default values for a split config dict.
- training.runner.prepare_training(cfg_dict)
Build adapter, dataset, dataset_info, and build_fn from a config dict.
Returns a dict with keys: model_cfg, data_cfg, training_cfg, output_cfg, adapter_name, adapter, dataset, dataset_info, build_fn, device, seed. This is the shared setup used by both train() and HPO.
- training.runner.train_one_epoch(experiment, dataloader)
Run one training epoch and return the average loss.
Raises RuntimeError if the dataloader produces zero batches.
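The averaging and the empty-loader check can be sketched as below. This assumes a hypothetical experiment.train_step(batch) hook that returns a scalar loss; the real Experiment's training hook is not documented on this page and may differ (for example, by returning a tensor).

```python
from typing import Any, Iterable


def train_one_epoch_sketch(experiment: Any, dataloader: Iterable[Any]) -> float:
    """Average a per-batch loss over one pass through the loader."""
    total = 0.0
    num_batches = 0
    for batch in dataloader:
        # Hypothetical hook: performs one optimizer step and returns the batch loss.
        total += float(experiment.train_step(batch))
        num_batches += 1
    if num_batches == 0:
        # An empty loader would otherwise divide by zero or silently report 0.
        raise RuntimeError("dataloader produced zero batches")
    return total / num_batches
```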
- training.runner.compute_val_loss(experiment, val_loader)
Evaluate on a validation loader and return the average loss.
Uses experiment.validation_step(), which defaults to eval_step() + loss_fn. Custom experiments can override it.
Raises RuntimeError if the loader produces zero batches.
- training.runner.train_or_hpo(cfg)
Dispatch to Optuna HPO when cfg.hpo.search_space is populated; train otherwise.
Single source of truth for the train-vs-HPO branching shared by every @hydra.main entry point (the top-level src/train.py and the per-case src/cases/<case>/train.py wrappers). The choice of entry point only affects which config is loaded, never whether HPO runs.
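The branching rule is small enough to state as code. This sketch uses plain dicts in place of the Hydra/OmegaConf config, and train_fn/hpo_fn are hypothetical callables standing in for the real training and Optuna paths.

```python
from typing import Any, Callable, Mapping


def train_or_hpo_sketch(cfg: Mapping[str, Any],
                        train_fn: Callable[[Mapping[str, Any]], Any],
                        hpo_fn: Callable[[Mapping[str, Any]], Any]) -> Any:
    """Run HPO only when cfg['hpo']['search_space'] is non-empty."""
    hpo_cfg = cfg.get("hpo") or {}
    if hpo_cfg.get("search_space"):
        return hpo_fn(cfg)
    # Missing, None, or empty search space: plain training.
    return train_fn(cfg)
```

Keying the decision on the config content rather than the entry point is what lets every wrapper share one code path.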
Experiment
Experiment abstraction for optional custom training/eval steps.
- class training.experiment.Experiment(model, optimizer, loss_fn, adapter, device, **kwargs)
Bases: object
Default experiment for supervised one-step training.
- validation_step(batch)
Compute validation loss for one batch.
Override in subclasses for custom validation logic (e.g., weighted metrics, physics-informed losses). The default uses eval_step followed by loss_fn.
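The default-then-override relationship can be sketched with a stand-in class. Batches are assumed to be (x, y) tuples, or (x, y, w) for the weighted variant, and the scalar "model" is purely illustrative; the real Experiment works on tensors and adapter-specific batch layouts.

```python
from typing import Any, Callable, Sequence


class ExperimentSketch:
    """Minimal stand-in for Experiment; eval_step here is a hypothetical forward pass."""

    def __init__(self, model: Callable[[float], float],
                 loss_fn: Callable[[float, float], float]) -> None:
        self.model = model
        self.loss_fn = loss_fn

    def eval_step(self, batch: Sequence[Any]) -> float:
        x = batch[0]
        return self.model(x)

    def validation_step(self, batch: Sequence[Any]) -> float:
        # Default: eval_step followed by loss_fn against the batch target.
        return self.loss_fn(self.eval_step(batch), batch[1])


class WeightedExperiment(ExperimentSketch):
    def validation_step(self, batch: Sequence[Any]) -> float:
        # Override: scale the default loss by a per-batch weight stored in batch[2].
        return batch[2] * super().validation_step(batch)
```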
- validation_epoch_loss(val_loader)
Optional validation loss term computed once per epoch.
Subclasses can override this to add case-level or physics-based penalties that require access to the full validation set rather than individual batches.
- compute_extended_metrics(eval_dataset, all_preds, all_targets)
Per-experiment extended evaluation metrics.
Receives the eval dataset and per-batch lists of CPU tensors (left unconcatenated because graph adapters produce variable shapes). Subclasses concatenate as needed and return a metrics dict that the runner merges into the extended block of the metrics JSON. Default: empty dict.
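A subclass typically flattens the per-batch lists before computing anything. This stdlib sketch shows that step for pointwise-style [B_i, O] batches, with nested lists standing in for CPU tensors. The helper name per_field_mse and the returned dict layout are hypothetical; the real extended block is case-specific.

```python
from typing import Dict, List, Sequence

Rows = List[List[float]]


def per_field_mse(all_preds: Sequence[Rows], all_targets: Sequence[Rows],
                  field_names: Sequence[str]) -> Dict[str, float]:
    """Concatenate per-batch [B_i, O] row lists, then compute MSE per output field."""
    preds = [row for batch in all_preds for row in batch]
    targets = [row for batch in all_targets for row in batch]
    if not preds or len(preds) != len(targets):
        raise ValueError("need matching, non-empty predictions and targets")
    return {
        name: sum((p[o] - t[o]) ** 2 for p, t in zip(preds, targets)) / len(preds)
        for o, name in enumerate(field_names)
    }
```

Batches of different sizes (here B_1 = 1 and B_2 = 2 in a typical test) concatenate without padding, which is why the runner hands over lists rather than a pre-stacked tensor.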
- print_extended_metrics(metrics)
Pretty-print the dict returned by compute_extended_metrics.
Default: no-op (the generic runner already prints overall and per-field MSE/RMSE). Subclasses override this to format case-specific blocks.
- decode_for_plotting(values, dataset, field_name, mask)
Decode encoded model output to physical/labelled space for plotting.
Returns (decoded_ndarray, label) or None. When None, the plotter shows raw encoded values with field_name as the y-axis label and skips the physical-RMSE subtitle. Case experiments override this to re-add a baseline, convert units, etc.
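The plotter side of this contract can be sketched as a helper that calls the hook and falls back when it returns None. The helper name values_and_label is hypothetical; only the (decoded, label)-or-None return shape comes from the description above.

```python
from typing import Any, Callable, Optional, Sequence, Tuple

DecodeFn = Callable[..., Optional[Tuple[Sequence[float], str]]]


def values_and_label(values: Sequence[float], dataset: Any, field_name: str,
                     mask: Any, decode_fn: Optional[DecodeFn]
                     ) -> Tuple[Sequence[float], str, bool]:
    """Return (values, y_label, decoded_flag) following the None-fallback contract."""
    if decode_fn is not None:
        decoded = decode_fn(values, dataset, field_name, mask)
        if decoded is not None:
            arr, label = decoded
            return arr, label, True
    # No decoder, or the decoder declined: raw encoded values labelled by
    # field name; the caller should then skip the physical-RMSE subtitle.
    return values, field_name, False
```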
- baseline_for_plotting(dataset, field_name, mask)
Return (decoded_baseline_ndarray, label) or None.
Case experiments that train a residual on top of an analytical baseline override this so the plotter can overlay the baseline as a third curve next to ground truth and prediction. Default: no baseline.
- prepare_for_training(train_dataset, val_dataset, device)
Bind dataset-derived state onto the experiment before training.
Called once after the datasets are subsetted but before the training loop. Case experiments override this to capture per-case geometry, target names, normalisation statistics, and other state that the generic core would otherwise have to read by name. Default: no-op.
Losses
Loss registry for supervised one-step training.
Internal datasets
MOOSEDataset (under dataset) is the public Dataset API.
The classes below are the input/output-slicing variants used internally
by train.py.
Dataset wrappers and split helpers for training workflows.
- training.datasets.parse_field_list(raw)
Parse field selections from CLI-style strings or list values.
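Accepting both shapes of input can be sketched as below. The separator rules (commas and/or whitespace) and the None-to-empty-list behaviour are assumptions; the real parser may accept a different syntax.

```python
from typing import Iterable, List, Optional, Union


def parse_field_list_sketch(raw: Optional[Union[str, Iterable[str]]]) -> List[str]:
    """Accept 'a,b', 'a b', ['a', 'b'] or None and return a clean list of names."""
    if raw is None:
        return []
    if isinstance(raw, str):
        # Treat commas and whitespace as interchangeable separators.
        parts = raw.replace(",", " ").split()
    else:
        parts = [str(item).strip() for item in raw]
    # Drop empty entries left by stray separators.
    return [p for p in parts if p]
```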
- training.datasets.resolve_time_idx(time_idx, num_steps, label)
Resolve negative time indexing and validate the index.
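Negative indices follow Python's convention (-1 is the last step). A sketch of the resolve-then-validate logic; raising IndexError is an assumption, and the real function may raise a different exception type.

```python
def resolve_time_idx_sketch(time_idx: int, num_steps: int, label: str) -> int:
    """Map a possibly negative index into [0, num_steps) or raise."""
    resolved = time_idx + num_steps if time_idx < 0 else time_idx
    if not 0 <= resolved < num_steps:
        # label names the offending argument (e.g. input_time_idx) in the message.
        raise IndexError(
            f"{label}={time_idx} is out of range for {num_steps} time steps")
    return resolved
```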
- class training.datasets.GridPairDataset(zarr_dir, input_fields, output_fields, input_time_idx, target_time_idx)
Bases: Dataset
Build supervised grid pairs (x, y) from MOOSEDataset grid samples.
- class training.datasets.GraphPairDataset(zarr_dir, input_fields, output_fields, input_time_idx, target_time_idx)
Bases: Dataset
Build supervised PyG Data objects from MOOSEDataset graph samples.
- training.datasets.split_indices(num_cases, split_cfg, sim_names)
Return split indices and simulation-name lists.
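Both normalize_split_cfg and split_indices are documented only at the signature level. A sketch of a seeded, case-level split under stated assumptions: the config keys test_fraction and seed and their defaults are hypothetical, and the real function may return a different structure.

```python
import random
from typing import Any, Dict, List, Sequence, Tuple


def split_indices_sketch(num_cases: int, split_cfg: Dict[str, Any],
                         sim_names: Sequence[str]
                         ) -> Tuple[List[int], List[int], List[str], List[str]]:
    """Seeded case-level train/test split (hypothetical keys: test_fraction, seed)."""
    if len(sim_names) != num_cases:
        raise ValueError("sim_names must have one entry per case")
    # Fill defaults the way a normalize step might (default values are assumptions).
    cfg = {"test_fraction": 0.2, "seed": 0, **split_cfg}
    order = list(range(num_cases))
    random.Random(cfg["seed"]).shuffle(order)
    n_test = int(round(cfg["test_fraction"] * num_cases))
    test_idx = sorted(order[:n_test])
    train_idx = sorted(order[n_test:])
    return (train_idx, test_idx,
            [sim_names[i] for i in train_idx],
            [sim_names[i] for i in test_idx])
```

Splitting whole cases (rather than rows) keeps every station of a simulation on one side of the split, so test metrics are not inflated by near-duplicate neighbouring rows.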
Tabular dataset for pointwise/axial-profile MLP training.
Reads per-case Zarr stores produced by the alpha_D ETL pipeline. Each store has the layout:
    {case_name}.zarr/
        features/   float32 [N_stations, D_in]
        targets/    float32 [N_stations, D_out]
        metadata/   attrs: case_id, feature_names, target_names, …
All cases are loaded and concatenated row-wise. Splitting is done at
the case level via subset_by_case_indices.
- class training.datasets_tabular.TabularPairDataset(zarr_dir, input_columns=None, output_columns=None, normalize=False, norm_stats=None, norm_from_case_indices=None, throat_weight=None, downstream_weight=None, include_case_idx=False, exclude_cases=None, local_velocity_normalization=False, min_Dr=None, target_transform=None, engineered_feature_names=None, engineered_feature_builder=None)
Bases: Dataset
Reads a directory of .zarr stores and produces (x, y) pairs.
Parameters:
  - zarr_dir (str or Path) – Directory containing *.zarr stores.
  - input_columns (list[str] or None) – Feature column names to select. If None, use all features.
  - output_columns (list[str] or None) – Target column names to select. If None, use all targets.
  - normalize (bool) – If True, z-score normalize input features after loading. Statistics are computed from the loaded data (or taken from externally supplied norm_stats).
  - norm_stats (dict or None) – Externally supplied {"x_mean": Tensor, "x_std": Tensor}. If None and normalize is True, statistics are computed from the loaded data.
  - throat_weight (float or None) – Stations where is_throat == 1 receive this weight; others receive weight 1.
  - downstream_weight (float or None) – Stations where is_downstream == 1 receive this weight. Applied multiplicatively with throat_weight when both are set.
  - include_case_idx (bool) – If True, __getitem__ returns a case-index tensor as the last element so that per-case losses can be computed.
  - min_Dr (float or None) – If set, exclude cases whose diameter ratio Dr is below this value. Dr is parsed from the case name (Re_*__Dr_XpXXX__Lr_*).
  - engineered_feature_names (list[str] or None) – Names of synthesized columns appended to each row in the order given. If None, no engineered columns are synthesized. Caller-supplied because engineered-feature schemas are case-specific.
  - engineered_feature_builder (callable or None) – (features, raw_feature_names) -> dict[name, ndarray[N]] mapping each name in engineered_feature_names to a 1-D column. Required when engineered_feature_names is set.
  - target_transform (callable or None) – Case-specific transform applied to full_y after local-velocity normalisation. Receives (full_y, full_x) plus keyword context (target_names, feature_names, case_meta_list, rows_per_case, local_velocity_normalization) and returns (transformed_y, extras). When extras contains baseline_encoded, it is stored as self._baseline_encoded and self.has_target_baseline is set to True so consumers can re-add the baseline at decode time.
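The throat_weight/downstream_weight rule reads naturally as a per-station product. A stdlib sketch, assuming the is_throat and is_downstream flags are available as 0/1 sequences; the helper name station_weights is hypothetical, and the real class works on arrays or tensors.

```python
from typing import List, Optional, Sequence


def station_weights(is_throat: Sequence[int], is_downstream: Sequence[int],
                    throat_weight: Optional[float] = None,
                    downstream_weight: Optional[float] = None) -> List[float]:
    """Per-station loss weights: 1 by default, multiplied by each active flag's weight."""
    weights = []
    for throat, down in zip(is_throat, is_downstream):
        w = 1.0
        if throat_weight is not None and throat == 1:
            w *= throat_weight
        if downstream_weight is not None and down == 1:
            w *= downstream_weight
        weights.append(w)
    return weights
```

For example, with throat_weight=3.0 and downstream_weight=2.0, a station flagged as both gets weight 6.0.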
- subset_by_case_indices(case_indices)
Return a new dataset containing only rows for the given case indices.
case_indices indexes into self.sim_names.
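Because cases are concatenated row-wise, selecting cases amounts to mapping each case index to its contiguous row range. A sketch assuming per-case row counts are known (rows_per_case appears in the target_transform context above); the helper names are hypothetical.

```python
from typing import List, Sequence, Tuple


def case_row_slices(rows_per_case: Sequence[int]) -> List[Tuple[int, int]]:
    """Half-open [start, stop) row range for each case, in concatenation order."""
    slices, start = [], 0
    for n in rows_per_case:
        slices.append((start, start + n))
        start += n
    return slices


def rows_for_cases(rows_per_case: Sequence[int],
                   case_indices: Sequence[int]) -> List[int]:
    """Row indices belonging to the selected cases (indices into sim_names)."""
    slices = case_row_slices(rows_per_case)
    rows: List[int] = []
    for c in case_indices:
        start, stop = slices[c]
        rows.extend(range(start, stop))
    return rows
```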
- add_baseline_to_encoded(encoded, row_mask=None, field_idx=None)
Re-add the per-row encoded baseline to an encoded tensor.
No-op when no target baseline was attached (has_target_baseline is False), so callers can use it unconditionally at decode boundaries.
Parameters:
  - encoded (torch.Tensor) – Encoded tensor in residual space (e.g. a model prediction).
  - row_mask (np.ndarray | torch.Tensor | None) – Optional row selector (a numpy boolean or integer array, or a torch tensor) used to slice the baseline before adding.
  - field_idx (int | None) – Output-field index when the dataset has multiple target columns. When None, falls back to auto-squeezing a single-field baseline so that 1-D encoded tensors broadcast correctly.
Return type: torch.Tensor
Split I/O and plotting
Helpers for exporting reusable train/test split files.
- training.split_io.write_sim_name_list(path, sim_names)
Write one simulation name per line and return the resolved output path.
- training.split_io.export_split_files(train_sims, test_sims, output_dir, *, train_filename='train.txt', test_filename='test.txt')
Export train/test simulation-name lists into a directory.
- training.split_io.export_split_files_from_run_meta(run_meta_path, output_dir, *, train_filename='train.txt', test_filename='test.txt')
Export reusable split files from a run_meta.json file.
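A sketch of the one-name-per-line format and the run_meta.json round trip. The key names train_sims and test_sims inside run_meta.json are hypothetical (the file's schema is not documented on this page), and the helper names are illustrative rather than the module's API.

```python
import json
from pathlib import Path
from typing import Iterable, Tuple


def write_names(path: Path, names: Iterable[str]) -> Path:
    """Write one simulation name per line and return the resolved path."""
    path.parent.mkdir(parents=True, exist_ok=True)
    path.write_text("".join(f"{n}\n" for n in names), encoding="utf-8")
    return path.resolve()


def export_from_run_meta(run_meta_path: Path, output_dir: Path,
                         train_filename: str = "train.txt",
                         test_filename: str = "test.txt") -> Tuple[Path, Path]:
    # Hypothetical key names; the real run_meta.json schema may differ.
    meta = json.loads(run_meta_path.read_text(encoding="utf-8"))
    return (write_names(output_dir / train_filename, meta["train_sims"]),
            write_names(output_dir / test_filename, meta["test_sims"]))
```

Plain text files keep the split reusable by tools that know nothing about the training package.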
Plotting helpers for evaluation outputs.
- training.plotting.select_best_worst_pointwise_cases(extended_metrics, output_fields)
Choose one best and one worst pointwise case for profile plotting.
- training.plotting.save_pointwise_profile_plots(model, dataset, output_fields, device, plot_dir, case_entries, *, plot_dpi=150, decode_fn=None, baseline_fn=None)
Save best/worst profile plots for pointwise/tabular models.
When decode_fn is supplied (typically experiment.decode_for_plotting), the plotter applies it to both the predicted and target tensors before plotting; the function returns the decoded values and the y-axis label. When decode_fn is None or returns None, the plotter shows raw encoded values with field_name as the label.
When baseline_fn is supplied (typically experiment.baseline_for_plotting) and returns a non-None decoded array, it is overlaid as a third “Baseline” curve so that the model's improvement over the baseline is visible. It has no effect when baseline_fn is None or the callback returns None.
- training.plotting.save_profile_prediction_plots(model, dataset, output_fields, device, plot_dir, case_entries, *, plot_dpi=150, decode_fn=None, baseline_fn=None)
Save best/worst per-case prediction plots for profile (1D-conv) models.
The profile model consumes per-case [1, C, S] tensors and emits [1, O, S] predictions. This plotter mirrors save_pointwise_profile_plots() (same z_hat axis, same decode_fn/baseline_fn contract, same RMSE subtitle) but runs the model once per case rather than once per row.
- training.plotting.save_parity_plot(*, cat_preds, cat_targets, dataset, output_fields, plot_dir, decode_fn=None, plot_dpi=150)
Save one parity plot per output field across the whole evaluation set.
Each scatter point is one (case, station) pair: ground-truth value on X, predicted value on Y, with a y=x reference line. When decode_fn is supplied, predictions and targets are first mapped to physical space (case by case, via the same hook the per-case plots use). When the dataset's input_columns include the is_upstream/is_throat/is_downstream flags, points are colored by region.
Accepts both pointwise ([N_rows, O]) and profile ([N_cases, O, S]) prediction tensors; profile tensors are scattered to the flat row layout via dataset._case_slices so the same plotting path works for both.
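The profile-to-row scatter can be sketched with nested lists. This assumes dataset._case_slices holds one half-open (start, stop) row range per case, in case order, and that stations beyond a case's row count (padding, if any) are ignored; the actual attribute type may differ (for example, Python slice objects).

```python
from typing import List, Sequence, Tuple

Matrix = List[List[float]]


def profile_to_rows(profile_preds: Sequence[Sequence[Sequence[float]]],
                    case_slices: Sequence[Tuple[int, int]]) -> Matrix:
    """Scatter [N_cases, O, S] predictions into a flat [N_rows, O] layout."""
    n_rows = max(stop for _, stop in case_slices)
    n_out = len(profile_preds[0])
    flat: Matrix = [[0.0] * n_out for _ in range(n_rows)]
    for case, (start, stop) in enumerate(case_slices):
        # Station s of this case becomes row start + s; the O and S axes swap.
        for s, row in enumerate(range(start, stop)):
            for o in range(n_out):
                flat[row][o] = profile_preds[case][o][s]
    return flat
```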
- training.plotting.save_delta_p_parity_plot(*, per_case, plot_dir, plot_dpi=150)
Save a Δp parity plot on log-log axes, with one marker per test case.
per_case is the full sorted list produced by compute_delta_p_metrics (extended.delta_p.per_case). Each entry must carry delta_p_gt/delta_p_pred; when Dr is present, points are colored on a discrete diameter-ratio colormap so that the low-Dr cluster (the typical alpha-D failure mode) stands out.
Model registry
Model registry and entrypoint resolution for training workflows.
- training.models.get_build_fn_and_adapter(model_cfg)
Resolve the model build function and adapter.
Built-in models derive their adapter from the registry. Custom entrypoints must specify model.adapter explicitly.
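The two resolution paths can be sketched with a decorator-based registry. The registry layout, the model_cfg key name, and the injected resolve callable are hypothetical; only the rule that custom entrypoints must name model.adapter comes from the description above.

```python
from typing import Any, Callable, Dict, Mapping, Tuple

# Hypothetical registry: model name -> (build function, adapter name).
_REGISTRY: Dict[str, Tuple[Callable[..., Any], str]] = {}


def register(name: str, adapter: str) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
    def decorator(build_fn: Callable[..., Any]) -> Callable[..., Any]:
        _REGISTRY[name] = (build_fn, adapter)
        return build_fn
    return decorator


@register("mlp", adapter="pointwise")
def build_mlp(**kwargs: Any) -> Dict[str, Any]:
    return {"kind": "mlp", **kwargs}


def get_build_fn_and_adapter_sketch(
        model_cfg: Mapping[str, Any],
        resolve: Callable[[str], Callable[..., Any]]) -> Tuple[Callable[..., Any], str]:
    """Built-ins take the registry's adapter; custom entrypoints must name one."""
    entrypoint = model_cfg.get("entrypoint")
    if entrypoint:
        adapter = model_cfg.get("adapter")
        if not adapter:
            raise ValueError("custom model.entrypoint requires model.adapter")
        return resolve(entrypoint), adapter
    # "name" is an assumed key for selecting a built-in model.
    return _REGISTRY[model_cfg["name"]]
```

Requiring an explicit adapter for custom models avoids guessing a data family from an arbitrary callable.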
Built-in models
Built-in FullyConnected MLP model definition.
Supports optional inter-layer dropout for improved regularisation in deeper networks. Dropout is applied between hidden layers during training and disabled at eval time, so no extra state needs to be persisted in the checkpoint.
Built-in FNO model definition.
Built-in AFNO model definition.
Built-in Pix2Pix model definition.
Built-in MeshGraphNet model definition.
1D-convolutional profile model.
Each case is one sample of shape [C, S] (channels = features, S =
stations) and the model produces [O, S] per case. Replicate padding
avoids zero-pad artefacts at the boundary; dilations widen the
receptive field without increasing depth.
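The depth-versus-dilation trade-off follows from the standard formula for stacked stride-1 convolutions: each layer with kernel size k and dilation d widens the receptive field by (k − 1)·d stations. A sketch, assuming one dilation value per block and a configurable number of convolutions per block (Conv1DProfile's exact block layout is not documented on this page).

```python
from typing import Sequence


def receptive_field(kernel_size: int, dilations: Sequence[int],
                    convs_per_block: int = 1) -> int:
    """Receptive field (in stations) of stacked stride-1 dilated 1D convolutions."""
    rf = 1
    for d in dilations:
        # Each conv with kernel k and dilation d extends the field by (k - 1) * d.
        rf += convs_per_block * (kernel_size - 1) * d
    return rf
```

With kernel_size=3, four layers with dilations [1, 2, 4, 8] see 31 stations, versus 9 for four undilated layers of the same depth.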
The class subclasses physicsnemo.Module so checkpoints round-trip
through Module.from_checkpoint. PhysicsNeMo records
cls.__module__/cls.__name__ at save time and re-imports the
class by getattr(module, name) at load time, so the class must live
at module scope (not inside build).
When physicsnemo is unavailable in the environment the module
imports cleanly but the model is not registered — mirroring how
GraphAdapter only activates if PyG is importable. This keeps
_load_builtins from hard-failing in envs without physicsnemo.
Backward-compat: AlphaDConv1D was the original class name. The
alias subclass below keeps old .mdlus checkpoints loadable, since
PhysicsNeMo embeds __name__ in the saved archive.
- class training.models.conv1d_profile.Conv1DProfile(in_channels, out_channels, hidden, num_blocks, kernel_size, dilations, dropout, activation='silu')
Bases: physicsnemo.Module
Residual dilated 1D-conv stack over the station axis.
- class training.models.conv1d_profile.AlphaDConv1D(in_channels, out_channels, hidden, num_blocks, kernel_size, dilations, dropout, activation='silu')
Bases: Conv1DProfile
Backward-compat alias for old .mdlus checkpoints.
Hyperparameter optimization
Optuna-based hyperparameter optimization for the training pipeline.
Study creation, orchestration, and artifact saving.
- training.hpo.study.create_study(hpo_cfg)
Create or resume an Optuna study from config.
Return type: Study
Optuna objective function for hyperparameter optimization.
- training.hpo.objective.make_objective(base_cfg, search_space, hpo_cfg, prepared, train_inner_idx, val_idx)
Create an Optuna objective function.
The returned closure rebuilds only the model, optimizer, and loss per trial. The dataset, adapter, and dataset_info are cached from prepared (the output of runner.prepare_training).
Parameters:
  - base_cfg (dict) – Full training config without the hpo section.
  - search_space (dict) – YAML search-space definition (dot-path -> spec).
  - hpo_cfg (dict) – The hpo section of the config.
  - prepared (dict) – Output of runner.prepare_training(base_cfg).
  - train_inner_idx (list[int]) – Case indices (into dataset.sim_names) for inner training.
Parse YAML search space definitions into Optuna trial suggestions.
- training.hpo.search_space.validate_search_space(search_space, base_cfg)
Check that all search-space keys are safe and exist in base_cfg.
Raises ValueError for unsafe prefixes and KeyError for dot-paths that do not exist in base_cfg (catches typos).
- training.hpo.search_space.sample_from_search_space(trial, search_space)
Sample hyperparameters from a YAML search-space definition.
Each entry in search_space maps a dot-path config key to a spec dict:
    training.lr:
      type: float
      low: 1e-5
      high: 1e-2
      log: true
Supported types: float, int, categorical.
Returns a dict mapping dot-path keys to sampled values.
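A sketch of the spec-to-trial mapping. trial.suggest_float, trial.suggest_int, and trial.suggest_categorical are real Optuna Trial methods; the "choices" key for categorical specs is an assumption, since only a float example is shown above. The float() cast guards against YAML loaders that leave exponent literals such as 1e-5 as strings.

```python
from typing import Any, Dict, Mapping


def sample_from_search_space_sketch(
        trial: Any, search_space: Mapping[str, Mapping[str, Any]]) -> Dict[str, Any]:
    """Map each dot-path spec to the matching Optuna trial.suggest_* call."""
    sampled: Dict[str, Any] = {}
    for key, spec in search_space.items():
        kind = spec["type"]
        if kind == "float":
            sampled[key] = trial.suggest_float(key, float(spec["low"]), float(spec["high"]),
                                               log=bool(spec.get("log", False)))
        elif kind == "int":
            sampled[key] = trial.suggest_int(key, int(spec["low"]), int(spec["high"]),
                                             log=bool(spec.get("log", False)))
        elif kind == "categorical":
            # "choices" is an assumed key name for the candidate list.
            sampled[key] = trial.suggest_categorical(key, list(spec["choices"]))
        else:
            raise ValueError(f"Unsupported search-space type {kind!r} for {key}")
    return sampled
```

Using the dot-path itself as the Optuna parameter name keeps the study's parameter table aligned with the config keys that apply_overrides later writes.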
- training.hpo.search_space.apply_overrides(base_cfg, overrides)
Deep-copy base_cfg and set values at dot-paths.
Raises KeyError if any intermediate or leaf key does not already exist in the base config. This prevents silent creation of new keys from typos.
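The copy-then-set behaviour with strict key checking can be sketched on plain nested dicts. The function name is illustrative and the error messages are assumptions; the real implementation may also handle OmegaConf containers.

```python
import copy
from typing import Any, Dict, Mapping


def apply_overrides_sketch(base_cfg: Mapping[str, Any],
                           overrides: Mapping[str, Any]) -> Dict[str, Any]:
    """Deep-copy base_cfg and set dot-path keys, refusing to create new keys."""
    cfg = copy.deepcopy(dict(base_cfg))
    for dot_path, value in overrides.items():
        *parents, leaf = dot_path.split(".")
        node = cfg
        for part in parents:
            if not isinstance(node, dict) or part not in node:
                raise KeyError(f"{dot_path}: missing intermediate key {part!r}")
            node = node[part]
        if not isinstance(node, dict) or leaf not in node:
            raise KeyError(f"{dot_path}: missing leaf key {leaf!r}")
        node[leaf] = value
    return cfg
```

The deep copy means each HPO trial can override the same base config without leaking values into later trials.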
Non-fatal Optuna visualization helpers.