# Source code for training.datasets_tabular

"""Tabular dataset for pointwise/axial-profile MLP training.

Reads per-case Zarr stores produced by the alpha_D ETL pipeline.
Each store contains:

    {case_name}.zarr/
        features/   float32 [N_stations, D_in]
        targets/    float32 [N_stations, D_out]
        sample_weight/  [N_stations]  (optional; stations default to weight 1)
        metadata/   attrs: case_id, feature_names, target_names, ...

All cases are loaded and concatenated row-wise.  Splitting is done at
the case level via ``subset_by_case_indices``.
"""

from __future__ import annotations

from collections.abc import Callable
from pathlib import Path

import numpy as np
import torch
from torch.utils.data import Dataset


class TabularPairDataset(Dataset):
    """Reads a directory of ``.zarr`` stores and produces ``(x, y)`` pairs.

    Stores are read in sorted path order and their rows are concatenated;
    the first store defines the feature/target column schema for all of them.

    Parameters
    ----------
    zarr_dir : str or Path
        Directory containing ``*.zarr`` stores.
    input_columns : list[str] or None
        Feature column names (base or engineered) to select. If *None*, all
        base (stored) features are used; engineered columns are only included
        when named explicitly.
    output_columns : list[str] or None
        Target column names to select. If *None*, use all targets.
    normalize : bool
        If *True*, z-score normalize the selected input features after
        loading. Statistics come from ``norm_stats`` when supplied, otherwise
        from the loaded rows (restricted to ``norm_from_case_indices`` when
        given).
    norm_stats : dict or None
        Externally supplied ``{"x_mean": Tensor, "x_std": Tensor}``, each 1-D
        with one entry per selected input column. Stored even when
        ``normalize`` is False, but only applied when it is True.
    norm_from_case_indices : list[int] or None
        Case indices (into :attr:`sim_names`) whose rows are used to compute
        normalization statistics. Ignored when ``norm_stats`` is supplied or
        ``normalize`` is False.
    throat_weight : float or None
        Stations where ``is_throat == 1`` receive this weight; others receive
        weight 1. Only applied when positive and an ``is_throat`` feature
        exists. NOTE: when applied it replaces any per-store
        ``sample_weight`` instead of multiplying it.
    downstream_weight : float or None
        Stations where ``is_downstream == 1`` receive this weight, multiplied
        into any existing weights (throat weights or per-store
        ``sample_weight``). Only applied when positive and an
        ``is_downstream`` feature exists.
    include_case_idx : bool
        If *True*, ``__getitem__`` returns a case-index tensor as the last
        element so that per-case losses can be computed.
    exclude_cases : list[str] or None
        Store names (``*.zarr`` file stems) to skip.
    local_velocity_normalization : bool
        Forwarded to ``target_transform``. The dataset attribute of the same
        name reflects the transform's ``extras`` (False without a transform).
    min_Dr : float or None
        If set, exclude cases whose diameter ratio Dr is below this value.
        Dr is parsed from the case name (``Re_*__Dr_XpXXX__Lr_*``); names
        without a ``Dr_`` part parse as 0.0.
    target_transform : callable or None
        Case-specific transform applied to the concatenated targets
        ``full_y``. Receives ``(full_y, full_x)`` (``full_x`` holds all base +
        engineered columns, un-normalized) plus keyword context
        (``target_names``, ``feature_names``, ``case_meta_list``,
        ``rows_per_case``, ``local_velocity_normalization``) and returns
        ``(transformed_y, extras)``. When extras contains
        ``baseline_encoded`` (row-aligned with ``full_y``), it is stashed as
        ``self._baseline_encoded`` and ``self.has_target_baseline`` is set to
        True so consumers can re-add the baseline at decode time.
    engineered_feature_names : list[str] or None
        Names of synthesized columns appended to each row in the order given.
        If *None*, no engineered columns are synthesized. Caller-supplied
        because engineered-feature schemas are case-specific.
    engineered_feature_builder : callable or None
        ``(features, raw_feature_names) -> dict[name, ndarray[N]]`` mapping
        each name in ``engineered_feature_names`` to a 1-D column. Required
        when ``engineered_feature_names`` is set.

    Raises
    ------
    FileNotFoundError
        If ``zarr_dir`` contains no ``*.zarr`` stores.
    ValueError
        If the filters remove every store, a requested column is unknown, a
        later store declares column names different from the first store,
        ``engineered_feature_builder`` is missing, or the normalization
        inputs are invalid.
    """

    def __init__(
        self,
        zarr_dir: str | Path,
        input_columns: list[str] | None = None,
        output_columns: list[str] | None = None,
        normalize: bool = False,
        norm_stats: dict | None = None,
        norm_from_case_indices: list[int] | None = None,
        throat_weight: float | None = None,
        downstream_weight: float | None = None,
        include_case_idx: bool = False,
        exclude_cases: list[str] | None = None,
        local_velocity_normalization: bool = False,
        min_Dr: float | None = None,
        target_transform: Callable[..., tuple] | None = None,
        engineered_feature_names: list[str] | None = None,
        engineered_feature_builder: Callable[..., dict] | None = None,
    ):
        self.zarr_dir = Path(zarr_dir)
        sim_paths = sorted(self.zarr_dir.glob("*.zarr"))
        if not sim_paths:
            raise FileNotFoundError(f"No .zarr stores found in {self.zarr_dir}")

        # Filter out excluded cases (matched against the store's file stem,
        # not against the ``case_id`` metadata attribute).
        if exclude_cases:
            exclude_set = set(exclude_cases)
            sim_paths = [sp for sp in sim_paths if sp.stem not in exclude_set]

        # Filter by minimum diameter ratio parsed from the store name.
        if min_Dr is not None:
            sim_paths = [sp for sp in sim_paths if self._parse_Dr(sp.stem) >= min_Dr]

        eng_names: list[str] = (
            list(engineered_feature_names) if engineered_feature_names is not None else []
        )
        if eng_names and engineered_feature_builder is None:
            raise ValueError(
                "engineered_feature_builder is required when engineered_feature_names is set."
            )

        # Without this check an over-aggressive filter only surfaces later as
        # an opaque "need at least one array to concatenate" error.
        if not sim_paths:
            raise ValueError(
                f"All .zarr stores in {self.zarr_dir} were removed by the "
                "exclude_cases / min_Dr filters."
            )

        import zarr  # local import: only needed once stores are actually read

        all_x: list[np.ndarray] = []
        all_y: list[np.ndarray] = []
        all_w: list[np.ndarray] = []
        case_ids: list[str] = []
        rows_per_case: list[int] = []
        case_meta_list: list[dict] = []
        has_weights = False
        raw_feature_names: list[str] = []
        raw_target_names: list[str] = []

        for sp in sim_paths:
            root = zarr.open(store=str(sp), mode="r")
            features = np.array(root["features"][:], dtype=np.float32)
            targets = np.array(root["targets"][:], dtype=np.float32)
            if "sample_weight" in root:
                weights = np.array(root["sample_weight"][:], dtype=np.float32)
                has_weights = True
            else:
                weights = np.ones(features.shape[0], dtype=np.float32)

            meta = root["metadata"]
            case_id = str(meta.attrs.get("case_id", sp.stem))
            feature_names = self._decode_name_list(meta.attrs.get("feature_names", "[]"))
            target_names = self._decode_name_list(meta.attrs.get("target_names", "[]"))

            if not case_ids:
                # The first store defines the column schema for the dataset.
                raw_feature_names = feature_names
                raw_target_names = target_names
                self._base_feature_names = list(raw_feature_names)
                self._all_feature_names = list(raw_feature_names) + eng_names
                self._all_target_names = list(raw_target_names)
                if output_columns is not None:
                    tgt_map = {n: i for i, n in enumerate(raw_target_names)}
                    missing = [c for c in output_columns if c not in tgt_map]
                    if missing:
                        raise ValueError(f"Unknown output columns: {missing}")
                    self._tgt_idx = [tgt_map[c] for c in output_columns]
                    self.output_columns = list(output_columns)
                else:
                    self._tgt_idx = list(range(targets.shape[1]))
                    self.output_columns = list(raw_target_names)
            else:
                # Column indices were resolved from the first store; a store
                # declaring a different column order would be silently misread.
                for kind, first, current in (
                    ("feature_names", raw_feature_names, feature_names),
                    ("target_names", raw_target_names, target_names),
                ):
                    if first and current and first != current:
                        raise ValueError(
                            f"{sp.name}: {kind} {current} differ from the first "
                            f"store's {first}."
                        )

            # Pass the case's metadata through verbatim. Case code reads
            # whatever keys it needs (with its own ``.get`` defaults); the
            # generic dataset doesn't know which keys are physics-relevant.
            case_meta_list.append(dict(meta.attrs))

            # Keep ALL base features (engineered columns may need source
            # columns that are not in input_columns); selection happens later.
            case_x = features
            if eng_names:
                engineered = engineered_feature_builder(features, raw_feature_names)
                engineered_cols = [
                    np.asarray(engineered[name]).reshape(-1, 1) for name in eng_names
                ]
                case_x = np.concatenate([features, *engineered_cols], axis=1)
            all_x.append(case_x.astype(np.float32))
            all_y.append(targets[:, self._tgt_idx])
            all_w.append(weights)
            case_ids.append(case_id)
            rows_per_case.append(features.shape[0])

        # ----------------------------------------------------------
        # Concatenate
        # ----------------------------------------------------------
        full_x = np.concatenate(all_x, axis=0)  # [N, D_base + D_eng]
        full_y = np.concatenate(all_y, axis=0)  # [N, D_out]

        self._case_meta = case_meta_list
        self.exclude_cases = list(exclude_cases) if exclude_cases else []

        # Raw geometry columns (before normalization) for the delta_p loss.
        z_col = self._feature_col("z_hat")
        d_col = self._feature_col("d_local_over_D")
        self._raw_z_hat = (
            torch.from_numpy(full_x[:, z_col].copy()) if z_col is not None else None
        )
        self._raw_d_local_over_D = (
            torch.from_numpy(full_x[:, d_col].copy()) if d_col is not None else None
        )

        # ----------------------------------------------------------
        # Optional case-supplied target transform. Extras keys consumed by
        # the dataset: ``baseline_encoded`` (stashed on
        # ``self._baseline_encoded`` for downstream consumers — metrics,
        # plotting, the Δp integral) and ``local_velocity_normalization``
        # (propagated onto ``self.local_velocity_normalization``).
        # ----------------------------------------------------------
        self._baseline_encoded: torch.Tensor | None = None
        self.has_target_baseline = False
        self.local_velocity_normalization = False
        if target_transform is not None:
            full_y, extras = target_transform(
                full_y,
                full_x,
                target_names=self.output_columns,
                feature_names=self._all_feature_names,
                case_meta_list=case_meta_list,
                rows_per_case=rows_per_case,
                local_velocity_normalization=local_velocity_normalization,
            )
            extras = extras or {}
            baseline = extras.get("baseline_encoded")
            if baseline is not None:
                self._baseline_encoded = torch.from_numpy(
                    np.asarray(baseline, dtype=np.float32)
                )
                self.has_target_baseline = True
            self.local_velocity_normalization = bool(
                extras.get("local_velocity_normalization", False)
            )

        # Resolve input columns against the full (base + engineered) schema.
        if input_columns is not None:
            feat_map = {n: i for i, n in enumerate(self._all_feature_names)}
            missing = [c for c in input_columns if c not in feat_map]
            if missing:
                raise ValueError(f"Unknown input columns: {missing}")
            self._feat_idx = [feat_map[c] for c in input_columns]
            self.input_columns = list(input_columns)
        else:
            self._feat_idx = list(range(len(self._base_feature_names)))
            self.input_columns = list(self._base_feature_names)

        self._x = torch.from_numpy(full_x[:, self._feat_idx].copy())
        self._y = torch.from_numpy(full_y)
        raw_w = np.concatenate(all_w, axis=0)
        self._w = torch.from_numpy(raw_w).unsqueeze(-1) if has_weights else None

        self._case_ids_unique = case_ids
        self._rows_per_case = rows_per_case
        # Per-row case index. Rows of each case are contiguous and in case
        # order; ``subset_by_case_indices`` relies on this layout.
        self._row_case_idx = np.repeat(
            np.arange(len(rows_per_case), dtype=np.int32), rows_per_case
        )

        # ----------------------------------------------------------
        # Region weights (computed from un-normalized feature values)
        # ----------------------------------------------------------
        self.throat_weight = throat_weight
        self.downstream_weight = downstream_weight

        throat_mask = self._region_mask(full_x, "is_throat")
        if throat_weight is not None and throat_weight > 0 and throat_mask is not None:
            # Starts from ones: per-store sample weights are NOT kept here.
            new_w = torch.ones(len(self._x), dtype=torch.float32)
            new_w[throat_mask] = float(throat_weight)
            self._w = new_w.unsqueeze(-1)

        downstream_mask = self._region_mask(full_x, "is_downstream")
        if (
            downstream_weight is not None
            and downstream_weight > 0
            and downstream_mask is not None
        ):
            # Multiplies into existing weights (throat or sample weights).
            if self._w is None:
                self._w = torch.ones(len(self._x), 1, dtype=torch.float32)
            else:
                self._w = self._w.clone()
            self._w[downstream_mask] *= float(downstream_weight)

        # Case index tensor for per-case losses
        self.include_case_idx = include_case_idx
        self._case_idx_tensor = torch.from_numpy(self._row_case_idx).long()

        # ----------------------------------------------------------
        # Feature normalization
        # ----------------------------------------------------------
        self.normalize = normalize
        if norm_stats is not None:
            self.norm_stats = self._coerce_norm_stats(norm_stats, dtype=self._x.dtype)
        elif normalize:
            stats_source_x = self._x
            if norm_from_case_indices is not None:
                keep = sorted({int(i) for i in norm_from_case_indices})
                if not keep:
                    raise ValueError("norm_from_case_indices must not be empty.")
                invalid = [i for i in keep if i < 0 or i >= len(self._case_ids_unique)]
                if invalid:
                    raise ValueError(
                        f"norm_from_case_indices contains out-of-range case index(es): {invalid}"
                    )
                mask = np.isin(self._row_case_idx, keep)
                if not np.any(mask):
                    raise ValueError(
                        "norm_from_case_indices selected zero rows; cannot compute normalization stats."
                    )
                stats_source_x = self._x[torch.from_numpy(mask)]
            self.norm_stats = self._compute_norm_stats(stats_source_x)
        else:
            self.norm_stats = None

        if self.normalize and self.norm_stats is not None:
            n_in = self._x.shape[1]
            if self.norm_stats["x_mean"].shape[0] != n_in:
                # A size-1 stats vector would otherwise broadcast silently.
                raise ValueError(
                    f"norm_stats have {self.norm_stats['x_mean'].shape[0]} entries "
                    f"but the dataset has {n_in} input columns."
                )
            self._x = (self._x - self.norm_stats["x_mean"]) / self.norm_stats["x_std"]

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _parse_Dr(case_name: str) -> float:
        """Extract the diameter ratio from a case name like ``Re_*__Dr_0p333__Lr_*``.

        Returns 0.0 when the name has no ``Dr_`` component.
        """
        for part in case_name.split("__"):
            if part.startswith("Dr_"):
                return float(part[3:].replace("p", "."))
        return 0.0

    @staticmethod
    def _decode_name_list(value) -> list:
        """Return a column-name list stored either as a JSON string or a native list."""
        import json

        if isinstance(value, str):
            value = json.loads(value)
        return list(value)

    def _feature_col(self, name: str) -> int | None:
        """Index of *name* in the full (base + engineered) feature schema, or None."""
        try:
            return self._all_feature_names.index(name)
        except ValueError:
            return None

    def _region_mask(self, full_x: np.ndarray, name: str) -> torch.Tensor | None:
        """Boolean row mask ``full_x[:, name] > 0.5``, or None if the column is absent."""
        col = self._feature_col(name)
        if col is None:
            return None
        return torch.from_numpy(full_x[:, col] > 0.5)

    # ------------------------------------------------------------------
    # Normalization
    # ------------------------------------------------------------------
    @staticmethod
    def _compute_norm_stats(x: torch.Tensor) -> dict[str, torch.Tensor]:
        """Compute per-feature mean and std from an input data tensor.

        Std is clamped to 1e-8 so constant features do not divide by zero.
        With fewer than two rows the (unbiased) std is undefined (NaN), so a
        unit std is used instead of propagating NaN into every feature.
        """
        x_mean = x.mean(dim=0)
        if x.shape[0] < 2:
            x_std = torch.ones_like(x_mean)
        else:
            x_std = x.std(dim=0).clamp(min=1e-8)
        return {"x_mean": x_mean, "x_std": x_std}

    @staticmethod
    def _coerce_norm_stats(
        norm_stats: dict,
        *,
        dtype: torch.dtype,
    ) -> dict[str, torch.Tensor]:
        """Validate externally supplied stats and convert them to 1-D tensors of *dtype*."""
        if "x_mean" not in norm_stats or "x_std" not in norm_stats:
            raise ValueError("norm_stats must contain both 'x_mean' and 'x_std'.")
        x_mean = torch.as_tensor(norm_stats["x_mean"], dtype=dtype)
        x_std = torch.as_tensor(norm_stats["x_std"], dtype=dtype).clamp(min=1e-8)
        if x_mean.ndim != 1 or x_std.ndim != 1:
            raise ValueError("norm_stats['x_mean'] and norm_stats['x_std'] must be 1D.")
        if x_mean.shape != x_std.shape:
            raise ValueError(
                "norm_stats['x_mean'] and norm_stats['x_std'] must have matching shapes."
            )
        return {"x_mean": x_mean, "x_std": x_std}

    # ------------------------------------------------------------------
    # Public properties
    # ------------------------------------------------------------------
    @property
    def in_features(self) -> int:
        """Number of selected input columns."""
        return len(self.input_columns)

    @property
    def out_features(self) -> int:
        """Number of selected target columns."""
        return len(self.output_columns)

    @property
    def sim_names(self) -> list[str]:
        """Unique case IDs in discovery order (compatible with split_indices)."""
        return self._case_ids_unique

    # ------------------------------------------------------------------
    # Dataset interface
    # ------------------------------------------------------------------
    def __len__(self) -> int:
        return len(self._x)

    def __getitem__(self, idx: int):
        """Return ``(x, y[, w][, case_idx])`` for row *idx*.

        ``w`` is present only when weights exist; ``case_idx`` only when
        ``include_case_idx`` is True.
        """
        item = [self._x[idx], self._y[idx]]
        if self._w is not None:
            item.append(self._w[idx])
        if self.include_case_idx:
            item.append(self._case_idx_tensor[idx])
        return tuple(item)

    # ------------------------------------------------------------------
    # Case-level subsetting
    # ------------------------------------------------------------------
    def subset_by_case_indices(self, case_indices: list[int]) -> "TabularPairDataset":
        """Return a new dataset containing only rows for the given case indices.

        ``case_indices`` indexes into ``self.sim_names``. Cases appear in the
        subset in the order given (duplicates are dropped, keeping the first
        occurrence), and their rows are gathered in that same order so that
        ``sim_names``, per-case metadata, per-row case indices and all
        row-aligned tensors stay consistent. Normalization statistics are
        shared with the parent (same dict object), not recomputed.

        Raises
        ------
        IndexError
            If any index is outside ``[0, len(self.sim_names))``.
        """
        n_cases = len(self._case_ids_unique)
        order = list(dict.fromkeys(int(i) for i in case_indices))
        bad = [i for i in order if not 0 <= i < n_cases]
        if bad:
            raise IndexError(f"case index(es) out of range for {n_cases} cases: {bad}")

        # Rows of case i are contiguous, starting after all rows of cases < i.
        starts = np.cumsum([0] + list(self._rows_per_case[:-1]))
        pieces = [
            np.arange(starts[i], starts[i] + self._rows_per_case[i], dtype=np.int64)
            for i in order
        ]
        row_idx = np.concatenate(pieces) if pieces else np.empty(0, dtype=np.int64)
        rows = torch.from_numpy(row_idx)

        new = object.__new__(TabularPairDataset)
        new.zarr_dir = self.zarr_dir
        new.input_columns = list(self.input_columns)
        new.output_columns = list(self.output_columns)
        new._base_feature_names = list(self._base_feature_names)
        new._all_feature_names = list(self._all_feature_names)
        new._all_target_names = list(self._all_target_names)
        new._feat_idx = list(self._feat_idx)
        new._tgt_idx = list(self._tgt_idx)
        new._x = self._x[rows]
        new._y = self._y[rows]
        new._w = self._w[rows] if self._w is not None else None
        new._case_ids_unique = [self._case_ids_unique[i] for i in order]
        new.normalize = self.normalize
        new.norm_stats = self.norm_stats  # share parent's stats (don't recompute)
        new.throat_weight = self.throat_weight
        new.downstream_weight = self.downstream_weight
        new.include_case_idx = self.include_case_idx
        new.local_velocity_normalization = self.local_velocity_normalization
        new.exclude_cases = list(self.exclude_cases)

        # Per-case metadata and raw geometry arrays
        new._case_meta = [self._case_meta[i] for i in order]
        new._raw_z_hat = self._raw_z_hat[rows] if self._raw_z_hat is not None else None
        new._raw_d_local_over_D = (
            self._raw_d_local_over_D[rows] if self._raw_d_local_over_D is not None else None
        )

        # Rebuild rows_per_case and row_case_idx for the subset (same
        # contiguous layout as the parent).
        new._rows_per_case = [self._rows_per_case[i] for i in order]
        new._row_case_idx = np.repeat(
            np.arange(len(order), dtype=np.int32), new._rows_per_case
        )
        new._case_idx_tensor = torch.from_numpy(new._row_case_idx).long()

        new.has_target_baseline = self.has_target_baseline
        new._baseline_encoded = (
            self._baseline_encoded[rows] if self._baseline_encoded is not None else None
        )
        return new

    # ------------------------------------------------------------------
    # Residual-target helpers
    # ------------------------------------------------------------------
    def add_baseline_to_encoded(
        self,
        encoded: torch.Tensor,
        row_mask: np.ndarray | torch.Tensor | None = None,
        field_idx: int | None = None,
    ) -> torch.Tensor:
        """Re-add the per-row encoded baseline to an encoded tensor.

        No-op when no target baseline was attached (``has_target_baseline``
        is False) so callers can use it unconditionally at decode boundaries.

        Parameters
        ----------
        encoded
            Encoded tensor in residual space (e.g. a model prediction).
        row_mask
            Optional row selector (numpy boolean / integer array, torch
            tensor on any device, or a plain sequence) used to slice the
            baseline before adding.
        field_idx
            Output-field index when the dataset has multiple target columns.
            When ``None``, falls back to auto-squeezing a single-field
            baseline so 1-D ``encoded`` tensors broadcast correctly.

        Returns
        -------
        torch.Tensor
            ``encoded + baseline``, on ``encoded``'s device and dtype.
        """
        if not self.has_target_baseline or self._baseline_encoded is None:
            return encoded
        bl = self._baseline_encoded
        if row_mask is not None:
            # Move the selector to the baseline's device; indexing a CPU
            # tensor with a CUDA mask would fail.
            bl = bl[torch.as_tensor(row_mask, device=bl.device)]
        if field_idx is not None:
            bl = bl[..., field_idx]
        elif encoded.dim() == 1 and bl.dim() == 2 and bl.shape[1] == 1:
            # Allow callers to pass a 1-D field-slice; broadcast accordingly.
            bl = bl.squeeze(-1)
        return encoded + bl.to(device=encoded.device, dtype=encoded.dtype)