Source code for cases.alpha_d.datasets.profile

"""Per-case profile dataset for 1D-conv α_D training.

Wraps a :class:`TabularPairDataset` and exposes per-case views shaped
``(features, stations)`` so that a 1D conv along the station axis can
treat each case as a single sample.

The wrapper delegates flat row-level state to the inner tabular dataset,
which is what every existing access site in ``runner.py``,
``experiments/alpha_d.py``, and ``plotting.py`` reads. Subsets are built
by delegating to ``TabularPairDataset.subset_by_case_indices`` so that
flat state stays aligned with the train/val/test case split.
"""

from pathlib import Path

import numpy as np
import torch
from torch.utils.data import Dataset

from cases.alpha_d.feature_data import engineered_features_spec
from cases.alpha_d.transforms import alpha_d_residual_transform
from training.datasets import parse_field_list
from training.datasets_tabular import TabularPairDataset


class AlphaDProfileDataset(Dataset):
    """Per-case dataset producing ``(x, y, w, case_idx)`` profile tensors.

    Shapes per item:
        x        : ``[in_features, n_stations]``
        y        : ``[out_features, n_stations]``
        w        : ``[1, n_stations]`` (broadcast-compatible with y)
        case_idx : scalar long tensor

    Stations are sorted by ``z_hat`` per case (when the inner dataset exposes
    ``_raw_z_hat``) so the conv sees a monotone spatial sequence.

    All flat, row-level state lives on a wrapped :class:`TabularPairDataset`.
    The read-only properties below forward to it, so existing callers that
    read e.g. ``dataset._x`` or ``dataset.norm_stats`` keep working.
    """

    def __init__(self, **tabular_kwargs):
        """Build the inner tabular dataset and the per-case row slices.

        Unless the caller passes ``engineered_feature_names``, the alpha-D
        engineered-feature spec (names and builder) is injected. The alpha-D
        residual target transform is used unless ``target_transform`` is
        given. All keyword arguments are then forwarded to
        :class:`TabularPairDataset`.
        """
        # ``tabular_kwargs`` is a fresh dict built from **kwargs, so mutating
        # it here never leaks back to the caller.
        if "engineered_feature_names" not in tabular_kwargs:
            names, builder = engineered_features_spec()
            tabular_kwargs["engineered_feature_names"] = names
            tabular_kwargs["engineered_feature_builder"] = builder
        tabular_kwargs.setdefault("target_transform", alpha_d_residual_transform)
        self._inner = TabularPairDataset(**tabular_kwargs)
        self._case_slices = _build_case_slices(self._inner)

    @classmethod
    def _from_inner(cls, inner: "TabularPairDataset") -> "AlphaDProfileDataset":
        """Wrap an already-built inner dataset without re-running ``__init__``.

        The per-case slices are recomputed for ``inner``.
        """
        new = object.__new__(cls)
        new._inner = inner
        new._case_slices = _build_case_slices(inner)
        return new

    # ------------------------------------------------------------------
    # Delegated flat properties (consumed by runner / plotting / alpha_d)
    # ------------------------------------------------------------------
    @property
    def _x(self):
        return self._inner._x

    @property
    def _y(self):
        return self._inner._y

    @property
    def _w(self):
        return self._inner._w

    @property
    def _baseline_encoded(self):
        return self._inner._baseline_encoded

    @property
    def _row_case_idx(self):
        return self._inner._row_case_idx

    @property
    def _case_ids_unique(self):
        return self._inner._case_ids_unique

    @property
    def _case_meta(self):
        return self._inner._case_meta

    @property
    def _raw_z_hat(self):
        return self._inner._raw_z_hat

    @property
    def _raw_d_local_over_D(self):
        return self._inner._raw_d_local_over_D

    @property
    def norm_stats(self):
        return self._inner.norm_stats

    @property
    def normalize(self):
        return self._inner.normalize

    @property
    def has_target_baseline(self):
        return self._inner.has_target_baseline

    @property
    def local_velocity_normalization(self):
        return self._inner.local_velocity_normalization

    @property
    def exclude_cases(self):
        return self._inner.exclude_cases

    @property
    def input_columns(self):
        return self._inner.input_columns

    @property
    def output_columns(self):
        return self._inner.output_columns

    @property
    def in_features(self):
        return self._inner.in_features

    @property
    def out_features(self):
        return self._inner.out_features

    @property
    def sim_names(self):
        return self._inner.sim_names

    # ------------------------------------------------------------------
    # Dataset interface
    # ------------------------------------------------------------------
    def __len__(self) -> int:
        """Return the number of cases (one item per case)."""
        return len(self._case_slices)

    def __getitem__(self, ci: int):
        """Return ``(x, y, w, case_idx)`` for case ``ci``.

        Negative indices are normalised, so ``case_idx`` is always the
        non-negative position of the case that was fetched. Indices outside
        ``[-len(self), len(self))`` raise ``IndexError``.
        """
        # Indexing a range normalises negative (and numpy-integer) indices and
        # raises IndexError when out of range; the returned case_idx then
        # always names the case whose rows were gathered.
        ci = range(len(self._case_slices))[ci]
        idx = self._case_slices[ci]
        inner = self._inner
        x = inner._x[idx].T.contiguous()  # [C, S]
        y = inner._y[idx].T.contiguous()  # [O, S]
        if inner._w is not None:
            w = inner._w[idx].squeeze(-1).unsqueeze(0).contiguous()  # [1, S]
        else:
            # Uniform weights; match x's dtype AND device so downstream
            # arithmetic never mixes CPU and accelerator tensors.
            w = torch.ones(1, len(idx), dtype=x.dtype, device=x.device)
        return x, y, w, torch.tensor(ci, dtype=torch.long)

    # ------------------------------------------------------------------
    # Case-level subsetting
    # ------------------------------------------------------------------
    def subset_by_case_indices(self, case_indices) -> "AlphaDProfileDataset":
        """Return a new dataset restricted to the given case positions.

        Delegates to ``TabularPairDataset.subset_by_case_indices`` so the flat
        row-level state stays aligned with the case split. It then rebuilds
        the per-case slices. The result has the same type as ``self``.
        """
        case_indices = [int(i) for i in case_indices]
        return type(self)._from_inner(self._inner.subset_by_case_indices(case_indices))

    def add_baseline_to_encoded(self, encoded, row_mask=None, field_idx=None):
        """Delegate to the inner TabularPairDataset."""
        return self._inner.add_baseline_to_encoded(
            encoded,
            row_mask=row_mask,
            field_idx=field_idx,
        )
def _build_case_slices(inner: "TabularPairDataset") -> list[np.ndarray]:
    """Index per-case row slices into ``inner``, sorted by ``z_hat``.

    Returns one array per case position ``ci`` in
    ``range(len(inner._case_ids_unique))``. Each array holds the flat-row
    indices whose ``inner._row_case_idx`` equals ``ci``. A case with no rows
    yields an empty array.

    When ``inner._raw_z_hat`` is not ``None``, each slice is reordered by
    ascending ``z_hat``. A stable sort is used, so rows with equal ``z_hat``
    keep their original row order and the station sequence is reproducible.
    """
    z_hat = inner._raw_z_hat
    slices: list[np.ndarray] = []
    for ci in range(len(inner._case_ids_unique)):
        idx = np.where(inner._row_case_idx == ci)[0]
        if z_hat is not None:
            # Stable sort: ties in z_hat keep their original relative order
            # instead of depending on the sorting algorithm's internals.
            order = np.argsort(z_hat[idx].numpy(), kind="stable")
            idx = idx[order]
        slices.append(idx)
    return slices
def build_dataset(data_cfg: dict) -> "AlphaDProfileDataset":
    """Construct an :class:`AlphaDProfileDataset` from a Hydra data config.

    Called by :class:`training.adapters.ProfileAdapter.build_dataset` after it
    resolves ``data.dataset_entrypoint``. Reads all the alpha-D-flavoured
    kwargs the generic adapter no longer knows about. Engineered features and
    target transform default-injection still happen inside
    ``AlphaDProfileDataset.__init__``.

    Input columns come from ``input_columns_file`` when that key is set. The
    file holds one name per line; surrounding whitespace is stripped and blank
    lines are ignored. Otherwise the columns come from ``input_columns`` via
    ``parse_field_list``.

    Raises:
        FileNotFoundError: ``input_columns_file`` is set but does not exist.
        ValueError: ``input_columns_file`` contains no column names.
    """
    norm_from_case_indices = data_cfg.get("norm_from_case_indices")
    if norm_from_case_indices is not None:
        norm_from_case_indices = [int(i) for i in norm_from_case_indices]

    if data_cfg.get("input_columns_file") is not None:
        cols_path = Path(str(data_cfg["input_columns_file"]))
        if not cols_path.exists():
            raise FileNotFoundError(f"input_columns_file not found: {cols_path}")
        # Decode explicitly rather than relying on the platform locale.
        # Column names may be non-ASCII, and "utf-8-sig" also drops a leading
        # BOM that would otherwise be glued onto the first column name.
        text = cols_path.read_text(encoding="utf-8-sig")
        input_columns = [line.strip() for line in text.splitlines() if line.strip()]
        if not input_columns:
            raise ValueError(f"input_columns_file is empty: {cols_path}")
    else:
        input_columns = parse_field_list(data_cfg.get("input_columns"))

    def _opt_float(key: str) -> float | None:
        # Optional numeric config value: None stays None, anything else -> float.
        v = data_cfg.get(key)
        return float(v) if v is not None else None

    exclude_cases = data_cfg.get("exclude_cases")
    if exclude_cases is not None:
        exclude_cases = [str(c) for c in exclude_cases]

    return AlphaDProfileDataset(
        zarr_dir=data_cfg["zarr_dir"],
        input_columns=input_columns,
        output_columns=parse_field_list(data_cfg.get("output_columns")),
        normalize=bool(data_cfg.get("normalize", False)),
        norm_stats=data_cfg.get("norm_stats"),
        norm_from_case_indices=norm_from_case_indices,
        throat_weight=_opt_float("throat_weight"),
        downstream_weight=_opt_float("downstream_weight"),
        include_case_idx=bool(data_cfg.get("include_case_idx", False)),
        exclude_cases=exclude_cases,
        local_velocity_normalization=bool(data_cfg.get("local_velocity_normalization", False)),
        min_Dr=_opt_float("min_Dr"),
    )