"""Tabular dataset for pointwise/axial-profile MLP training.
Reads per-case Zarr stores produced by the alpha_D ETL pipeline.
Each store contains::

    {case_name}.zarr/
        features/       float32 [N_stations, D_in]
        targets/        float32 [N_stations, D_out]
        sample_weight/  float32 [N_stations]  (optional)
        metadata/       attrs: case_id, feature_names, target_names, ...

All cases are loaded and concatenated row-wise. Splitting is done at
the case level via ``subset_by_case_indices``.
"""
from __future__ import annotations
from collections.abc import Callable
from pathlib import Path
import numpy as np
import torch
from torch.utils.data import Dataset
[docs]
class TabularPairDataset(Dataset):
"""Reads a directory of ``.zarr`` stores and produces ``(x, y)`` pairs.
Parameters
----------
zarr_dir : str or Path
Directory containing ``*.zarr`` stores.
input_columns : list[str] or None
Feature column names to select. If *None*, use all features.
output_columns : list[str] or None
Target column names to select. If *None*, use all targets.
normalize : bool
If *True*, z-score normalize input features after loading.
Statistics are computed from the loaded data (or from externally
supplied ``norm_stats``).
norm_stats : dict or None
Externally supplied ``{"x_mean": Tensor, "x_std": Tensor}``.
If *None* and *normalize* is True, computed from the loaded data.
throat_weight : float or None
Stations where ``is_throat == 1`` receive this weight; others
receive weight 1.
downstream_weight : float or None
Stations where ``is_downstream == 1`` receive this weight.
Applied multiplicatively with ``throat_weight`` when both are set.
include_case_idx : bool
If *True*, ``__getitem__`` returns a case-index tensor as the last
element so that per-case losses can be computed.
min_Dr : float or None
If set, exclude cases whose diameter ratio Dr is below this value.
Dr is parsed from the case name (``Re_*__Dr_XpXXX__Lr_*``).
engineered_feature_names : list[str] or None
Names of synthesized columns appended to each row in the order given.
If *None*, no engineered columns are synthesized. Caller-supplied
because engineered-feature schemas are case-specific.
engineered_feature_builder : callable or None
``(features, raw_feature_names) -> dict[name, ndarray[N]]`` mapping
each name in ``engineered_feature_names`` to a 1-D column. Required
when ``engineered_feature_names`` is set.
target_transform : callable or None
Case-specific transform applied to ``full_y`` after local-velocity
normalisation. Receives ``(full_y, full_x)`` plus keyword context
(``target_names``, ``feature_names``, ``case_meta_list``,
``rows_per_case``, ``local_velocity_normalization``) and returns
``(transformed_y, extras)``. When extras contains ``baseline_encoded``,
it is stashed as ``self._baseline_encoded`` and ``self.has_target_baseline``
is set to True so consumers can re-add the baseline at decode time.
"""
def __init__(
self,
zarr_dir: str | Path,
input_columns: list[str] | None = None,
output_columns: list[str] | None = None,
normalize: bool = False,
norm_stats: dict | None = None,
norm_from_case_indices: list[int] | None = None,
throat_weight: float | None = None,
downstream_weight: float | None = None,
include_case_idx: bool = False,
exclude_cases: list[str] | None = None,
local_velocity_normalization: bool = False,
min_Dr: float | None = None,
target_transform: Callable[..., tuple] | None = None,
engineered_feature_names: list[str] | None = None,
engineered_feature_builder: Callable[..., dict] | None = None,
):
import json
import zarr
self.zarr_dir = Path(zarr_dir)
sim_paths = sorted(self.zarr_dir.glob("*.zarr"))
if not sim_paths:
raise FileNotFoundError(f"No .zarr stores found in {self.zarr_dir}")
# Filter out excluded cases
if exclude_cases:
exclude_set = set(exclude_cases)
sim_paths = [sp for sp in sim_paths if sp.stem not in exclude_set]
# Filter by minimum diameter ratio
if min_Dr is not None:
sim_paths = [sp for sp in sim_paths if self._parse_Dr(sp.stem) >= min_Dr]
all_x: list[np.ndarray] = []
all_y: list[np.ndarray] = []
all_w: list[np.ndarray] = []
case_ids: list[str] = []
rows_per_case: list[int] = []
case_meta_list: list[dict] = []
has_weights = False
eng_names: list[str] = (
list(engineered_feature_names) if engineered_feature_names is not None else []
)
if eng_names and engineered_feature_builder is None:
raise ValueError(
"engineered_feature_builder is required when engineered_feature_names is set."
)
for sp in sim_paths:
root = zarr.open(store=str(sp), mode="r")
features = np.array(root["features"][:], dtype=np.float32)
targets = np.array(root["targets"][:], dtype=np.float32)
if "sample_weight" in root:
weights = np.array(root["sample_weight"][:], dtype=np.float32)
has_weights = True
else:
weights = np.ones(features.shape[0], dtype=np.float32)
meta = root["metadata"]
case_id = str(meta.attrs.get("case_id", sp.stem))
# On first store, resolve column names
if not case_ids:
raw_feature_names = json.loads(meta.attrs.get("feature_names", "[]"))
raw_target_names = json.loads(meta.attrs.get("target_names", "[]"))
self._base_feature_names = list(raw_feature_names)
self._all_feature_names = list(raw_feature_names) + eng_names
self._all_target_names = list(raw_target_names)
if output_columns is not None:
tgt_map = {n: i for i, n in enumerate(raw_target_names)}
missing = [c for c in output_columns if c not in tgt_map]
if missing:
raise ValueError(f"Unknown output columns: {missing}")
self._tgt_idx = [tgt_map[c] for c in output_columns]
self.output_columns = list(output_columns)
else:
self._tgt_idx = list(range(targets.shape[1]))
self.output_columns = list(raw_target_names)
# Pass the case's metadata through verbatim. Case code reads
# whatever keys it needs (with its own ``.get`` defaults); the
# generic dataset doesn't know which keys are physics-relevant.
case_meta_list.append(dict(meta.attrs))
# Load ALL base features (derived columns need access to
# source columns that may not be in input_columns).
if eng_names:
engineered = engineered_feature_builder(features, raw_feature_names)
engineered_cols = [engineered[name].reshape(-1, 1) for name in eng_names]
all_x.append(
np.concatenate([features] + engineered_cols, axis=1).astype(np.float32)
)
else:
all_x.append(features.astype(np.float32))
all_y.append(targets[:, self._tgt_idx])
all_w.append(weights)
case_ids.append(case_id)
rows_per_case.append(features.shape[0])
# ----------------------------------------------------------
# Concatenate
# ----------------------------------------------------------
full_x = np.concatenate(all_x, axis=0) # [N, D_base]
full_y = np.concatenate(all_y, axis=0) # [N, D_out]
# Store per-case metadata
self._case_meta = case_meta_list
self.exclude_cases = list(exclude_cases) if exclude_cases else []
# Store raw geometry columns (before normalization) for delta_p loss
z_hat_col = (
self._all_feature_names.index("z_hat") if "z_hat" in self._all_feature_names else None
)
d_over_D_col = (
self._all_feature_names.index("d_local_over_D")
if "d_local_over_D" in self._all_feature_names
else None
)
self._raw_z_hat = (
torch.from_numpy(full_x[:, z_hat_col].copy()) if z_hat_col is not None else None
)
self._raw_d_local_over_D = (
torch.from_numpy(full_x[:, d_over_D_col].copy()) if d_over_D_col is not None else None
)
# ----------------------------------------------------------
# Optional case-supplied target transform. Extras keys consumed by
# the dataset: ``baseline_encoded`` (stashed on
# ``self._baseline_encoded`` for downstream consumers — metrics,
# plotting, the Δp integral) and ``local_velocity_normalization``
# (propagated onto ``self.local_velocity_normalization``).
# ----------------------------------------------------------
self._baseline_encoded: torch.Tensor | None = None
self.has_target_baseline = False
self.local_velocity_normalization = False
if target_transform is not None:
full_y, extras = target_transform(
full_y,
full_x,
target_names=self.output_columns,
feature_names=self._all_feature_names,
case_meta_list=case_meta_list,
rows_per_case=rows_per_case,
local_velocity_normalization=local_velocity_normalization,
)
extras = extras or {}
baseline = extras.get("baseline_encoded")
if baseline is not None:
self._baseline_encoded = torch.from_numpy(np.asarray(baseline, dtype=np.float32))
self.has_target_baseline = True
self.local_velocity_normalization = bool(
extras.get("local_velocity_normalization", False)
)
# Resolve input columns
if input_columns is not None:
feat_map = {n: i for i, n in enumerate(self._all_feature_names)}
missing = [c for c in input_columns if c not in feat_map]
if missing:
raise ValueError(f"Unknown input columns: {missing}")
self._feat_idx = [feat_map[c] for c in input_columns]
self.input_columns = list(input_columns)
else:
self._feat_idx = list(range(len(self._base_feature_names)))
self.input_columns = list(self._base_feature_names)
self._x = torch.from_numpy(full_x[:, self._feat_idx].copy())
self._y = torch.from_numpy(full_y)
raw_w = np.concatenate(all_w, axis=0)
self._w = torch.from_numpy(raw_w).unsqueeze(-1) if has_weights else None
self._case_ids_unique = case_ids
self._rows_per_case = rows_per_case
# Build per-row case index for subsetting
self._row_case_idx = np.concatenate(
[np.full(n, i, dtype=np.int32) for i, n in enumerate(rows_per_case)]
)
# ----------------------------------------------------------
# Region weights (throat)
# ----------------------------------------------------------
throat_col_full = (
self._all_feature_names.index("is_throat")
if "is_throat" in self._all_feature_names
else None
)
self.throat_weight = throat_weight
self.downstream_weight = downstream_weight
if throat_weight is not None and throat_weight > 0 and throat_col_full is not None:
new_w = torch.ones(len(self._x), dtype=torch.float32)
new_w[full_x[:, throat_col_full] > 0.5] = float(throat_weight)
self._w = new_w.unsqueeze(-1)
# Region weights (downstream)
downstream_col_full = (
self._all_feature_names.index("is_downstream")
if "is_downstream" in self._all_feature_names
else None
)
if (
downstream_weight is not None
and downstream_weight > 0
and downstream_col_full is not None
):
if self._w is not None:
# Multiply with existing weights (e.g. throat weights)
ds_mask = full_x[:, downstream_col_full] > 0.5
self._w = self._w.clone()
self._w[ds_mask] *= float(downstream_weight)
else:
new_w = torch.ones(len(self._x), dtype=torch.float32)
new_w[full_x[:, downstream_col_full] > 0.5] = float(downstream_weight)
self._w = new_w.unsqueeze(-1)
# Case index tensor for per-case losses
self.include_case_idx = include_case_idx
self._case_idx_tensor = torch.from_numpy(self._row_case_idx).long()
# Feature normalization
self.normalize = normalize
if norm_stats is not None:
self.norm_stats = self._coerce_norm_stats(norm_stats, dtype=self._x.dtype)
elif normalize:
stats_source_x = self._x
if norm_from_case_indices is not None:
keep = sorted({int(i) for i in norm_from_case_indices})
if not keep:
raise ValueError("norm_from_case_indices must not be empty.")
invalid = [i for i in keep if i < 0 or i >= len(self._case_ids_unique)]
if invalid:
raise ValueError(
f"norm_from_case_indices contains out-of-range case index(es): {invalid}"
)
mask = np.isin(self._row_case_idx, keep)
if not np.any(mask):
raise ValueError(
"norm_from_case_indices selected zero rows; cannot compute normalization stats."
)
stats_source_x = self._x[mask]
self.norm_stats = self._compute_norm_stats(stats_source_x)
else:
self.norm_stats = None
if self.normalize and self.norm_stats is not None:
self._x = (self._x - self.norm_stats["x_mean"]) / self.norm_stats["x_std"]
# ------------------------------------------------------------------
# Helpers
# ------------------------------------------------------------------
@staticmethod
def _parse_Dr(case_name: str) -> float:
"""Extract the diameter ratio from a case name like ``Re_*__Dr_0p333__Lr_*``."""
for part in case_name.split("__"):
if part.startswith("Dr_"):
return float(part[3:].replace("p", "."))
return 0.0
# ------------------------------------------------------------------
# Normalization
# ------------------------------------------------------------------
@staticmethod
def _compute_norm_stats(x: torch.Tensor) -> dict[str, torch.Tensor]:
"""Compute per-feature mean and std from input data tensor."""
return {
"x_mean": x.mean(dim=0),
"x_std": x.std(dim=0).clamp(min=1e-8),
}
@staticmethod
def _coerce_norm_stats(
norm_stats: dict,
*,
dtype: torch.dtype,
) -> dict[str, torch.Tensor]:
if "x_mean" not in norm_stats or "x_std" not in norm_stats:
raise ValueError("norm_stats must contain both 'x_mean' and 'x_std'.")
x_mean = torch.as_tensor(norm_stats["x_mean"], dtype=dtype)
x_std = torch.as_tensor(norm_stats["x_std"], dtype=dtype).clamp(min=1e-8)
if x_mean.ndim != 1 or x_std.ndim != 1:
raise ValueError("norm_stats['x_mean'] and norm_stats['x_std'] must be 1D.")
if x_mean.shape != x_std.shape:
raise ValueError(
"norm_stats['x_mean'] and norm_stats['x_std'] must have matching shapes."
)
return {"x_mean": x_mean, "x_std": x_std}
# ------------------------------------------------------------------
# Public properties
# ------------------------------------------------------------------
@property
def in_features(self) -> int:
return len(self.input_columns)
@property
def out_features(self) -> int:
return len(self.output_columns)
@property
def sim_names(self) -> list[str]:
"""Unique case IDs in discovery order (compatible with split_indices)."""
return self._case_ids_unique
# ------------------------------------------------------------------
# Dataset interface
# ------------------------------------------------------------------
def __len__(self) -> int:
return len(self._x)
def __getitem__(self, idx: int):
if self.include_case_idx:
if self._w is not None:
return self._x[idx], self._y[idx], self._w[idx], self._case_idx_tensor[idx]
return self._x[idx], self._y[idx], self._case_idx_tensor[idx]
if self._w is not None:
return self._x[idx], self._y[idx], self._w[idx]
return self._x[idx], self._y[idx]
# ------------------------------------------------------------------
# Case-level subsetting
# ------------------------------------------------------------------
[docs]
def subset_by_case_indices(self, case_indices: list[int]) -> "TabularPairDataset":
"""Return a new dataset containing only rows for the given case indices.
``case_indices`` indexes into ``self.sim_names``.
"""
keep = set(case_indices)
mask = np.isin(self._row_case_idx, list(keep))
new = object.__new__(TabularPairDataset)
new.zarr_dir = self.zarr_dir
new.input_columns = list(self.input_columns)
new.output_columns = list(self.output_columns)
new._base_feature_names = list(self._base_feature_names)
new._all_feature_names = list(self._all_feature_names)
new._all_target_names = list(self._all_target_names)
new._feat_idx = list(self._feat_idx)
new._tgt_idx = list(self._tgt_idx)
new._x = self._x[mask]
new._y = self._y[mask]
new._w = self._w[mask] if self._w is not None else None
new._case_ids_unique = [self._case_ids_unique[i] for i in case_indices]
new.normalize = self.normalize
new.norm_stats = self.norm_stats # share parent's stats (don't recompute)
new.throat_weight = self.throat_weight
new.downstream_weight = self.downstream_weight
new.include_case_idx = self.include_case_idx
new.local_velocity_normalization = self.local_velocity_normalization
new.exclude_cases = list(self.exclude_cases)
# Propagate per-case metadata and raw geometry arrays
new._case_meta = [self._case_meta[i] for i in case_indices]
new._raw_z_hat = self._raw_z_hat[mask] if self._raw_z_hat is not None else None
new._raw_d_local_over_D = (
self._raw_d_local_over_D[mask] if self._raw_d_local_over_D is not None else None
)
# Rebuild rows_per_case and row_case_idx for the subset
new._rows_per_case = [self._rows_per_case[i] for i in case_indices]
new._row_case_idx = np.concatenate(
[np.full(n, new_i, dtype=np.int32) for new_i, n in enumerate(new._rows_per_case)]
)
new._case_idx_tensor = torch.from_numpy(new._row_case_idx).long()
new.has_target_baseline = self.has_target_baseline
new._baseline_encoded = (
self._baseline_encoded[mask] if self._baseline_encoded is not None else None
)
return new
# ------------------------------------------------------------------
# Residual-target helpers
# ------------------------------------------------------------------
[docs]
def add_baseline_to_encoded(
self,
encoded: torch.Tensor,
row_mask: np.ndarray | torch.Tensor | None = None,
field_idx: int | None = None,
) -> torch.Tensor:
"""Re-add the per-row encoded baseline to an encoded tensor.
No-op when no target baseline was attached (``has_target_baseline``
is False) so callers can use it unconditionally at decode boundaries.
Parameters
----------
encoded
Encoded tensor in residual space (e.g. a model prediction).
row_mask
Optional row selector — numpy boolean / integer array or torch
tensor — to slice the baseline before adding.
field_idx
Output-field index when the dataset has multiple target
columns. When ``None``, falls back to auto-squeezing a
single-field baseline so 1-D ``encoded`` tensors broadcast
correctly.
"""
if not self.has_target_baseline or self._baseline_encoded is None:
return encoded
bl = self._baseline_encoded
if row_mask is not None:
if isinstance(row_mask, np.ndarray):
row_mask = torch.as_tensor(row_mask)
bl = bl[row_mask]
if field_idx is not None:
bl = bl[..., field_idx]
elif encoded.dim() == 1 and bl.dim() == 2 and bl.shape[1] == 1:
# Allow callers to pass a 1-D field-slice; broadcast accordingly.
bl = bl.squeeze(-1)
return encoded + bl.to(encoded.dtype).to(encoded.device)