"""Dataset wrappers and split helpers for training workflows."""
import random
import re as _re
from collections import defaultdict
from pathlib import Path
import torch
from torch.utils.data import Dataset
from dataset.moose_dataset import MOOSEDataset
from training import _require_pyg
[docs]
def parse_field_list(raw: str | list[str] | tuple[str, ...] | None) -> list[str] | None:
"""Parse field selections from CLI-style strings or list values."""
if raw is None:
return None
if isinstance(raw, str):
names = [item.strip() for item in raw.split(",") if item.strip()]
else:
names = [str(item).strip() for item in raw if str(item).strip()]
if not names:
raise ValueError("Field list is empty. Provide at least one field name.")
return names
[docs]
def resolve_time_idx(time_idx: int, num_steps: int, label: str) -> int:
"""Resolve negative time indexing and validate the index."""
resolved = time_idx if time_idx >= 0 else num_steps + time_idx
if resolved < 0 or resolved >= num_steps:
raise ValueError(f"{label}={time_idx} is out of range for {num_steps} time step(s).")
return resolved
[docs]
class GridPairDataset(Dataset):
"""Build supervised grid pairs `(x, y)` from MOOSEDataset grid samples."""
def __init__(
self,
zarr_dir: str | Path,
input_fields: list[str] | None,
output_fields: list[str] | None,
input_time_idx: int,
target_time_idx: int,
):
self.zarr_dir = Path(zarr_dir)
self.base = MOOSEDataset(zarr_dir=self.zarr_dir, mode="grid", time_idx=-1)
reference = self.base[0]
self.field_names = list(reference["field_names"])
self.field_to_index = {name: idx for idx, name in enumerate(self.field_names)}
self.sim_names = [path.stem for path in self.base.sim_paths]
self.input_fields = input_fields or list(self.field_names)
self.output_fields = output_fields or list(self.input_fields)
missing_inputs = [name for name in self.input_fields if name not in self.field_to_index]
if missing_inputs:
raise ValueError(f"Unknown input field(s): {missing_inputs}")
missing_outputs = [name for name in self.output_fields if name not in self.field_to_index]
if missing_outputs:
raise ValueError(f"Unknown output field(s): {missing_outputs}")
grid = reference["grid_fields"]
if grid.ndim != 4:
raise ValueError(
f"Expected grid_fields with shape [T, Nx, Ny, F], got {tuple(grid.shape)}"
)
self.num_time_steps = int(grid.shape[0])
self.spatial_shape = (int(grid.shape[1]), int(grid.shape[2]))
self.input_time_idx = resolve_time_idx(
input_time_idx, self.num_time_steps, "input_time_idx"
)
self.target_time_idx = resolve_time_idx(
target_time_idx, self.num_time_steps, "target_time_idx"
)
self.input_indices = [self.field_to_index[name] for name in self.input_fields]
self.output_indices = [self.field_to_index[name] for name in self.output_fields]
def __len__(self) -> int:
return len(self.base)
def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]:
sample = self.base[idx]
if list(sample["field_names"]) != self.field_names:
raise ValueError(
f"Field names differ across simulations. Expected {self.field_names}, "
f"got {sample['field_names']} in {sample['sim_name']}."
)
grid = sample["grid_fields"]
if grid.ndim != 4:
raise ValueError(
f"Expected grid_fields with shape [T, Nx, Ny, F], got {tuple(grid.shape)}"
)
spatial_shape = (int(grid.shape[1]), int(grid.shape[2]))
if spatial_shape != self.spatial_shape:
raise ValueError(
f"Grid size mismatch for {sample['sim_name']}: expected "
f"{self.spatial_shape}, got {spatial_shape}."
)
x = self._select_channels(grid, self.input_time_idx, self.input_indices)
y = self._select_channels(grid, self.target_time_idx, self.output_indices)
return x, y
@staticmethod
def _select_channels(
grid: torch.Tensor, time_idx: int, field_indices: list[int]
) -> torch.Tensor:
tensor = grid[time_idx, :, :, field_indices]
if tensor.ndim == 2:
tensor = tensor.unsqueeze(-1)
return tensor.permute(2, 0, 1).contiguous()
[docs]
class GraphPairDataset(Dataset):
"""Build supervised PyG Data objects from MOOSEDataset graph samples."""
def __init__(
self,
zarr_dir: str | Path,
input_fields: list[str] | None,
output_fields: list[str] | None,
input_time_idx: int,
target_time_idx: int,
):
_require_pyg()
from torch_geometric.data import Data
self._pyg_data_cls = Data
self.zarr_dir = Path(zarr_dir)
self.base = MOOSEDataset(zarr_dir=self.zarr_dir, mode="graph", time_idx=-1)
reference = self.base[0]
self.field_names = list(reference["field_names"])
self.field_to_index = {name: idx for idx, name in enumerate(self.field_names)}
self.sim_names = [path.stem for path in self.base.sim_paths]
self.input_fields = input_fields or list(self.field_names)
self.output_fields = output_fields or list(self.input_fields)
missing_inputs = [name for name in self.input_fields if name not in self.field_to_index]
if missing_inputs:
raise ValueError(f"Unknown input field(s): {missing_inputs}")
missing_outputs = [name for name in self.output_fields if name not in self.field_to_index]
if missing_outputs:
raise ValueError(f"Unknown output field(s): {missing_outputs}")
node_fields = reference["node_fields"]
if node_fields.ndim != 3:
raise ValueError(
f"Expected node_fields with shape [T, N, F], got {tuple(node_fields.shape)}"
)
coords = reference["coords"]
if coords.ndim != 2:
raise ValueError(f"Expected coords with shape [N, D], got {tuple(coords.shape)}")
self.num_time_steps = int(node_fields.shape[0])
self.coord_dim = int(coords.shape[1])
self.edge_dim = self.coord_dim + 1
self.input_time_idx = resolve_time_idx(
input_time_idx, self.num_time_steps, "input_time_idx"
)
self.target_time_idx = resolve_time_idx(
target_time_idx, self.num_time_steps, "target_time_idx"
)
self.input_indices = [self.field_to_index[name] for name in self.input_fields]
self.output_indices = [self.field_to_index[name] for name in self.output_fields]
def __len__(self) -> int:
return len(self.base)
def __getitem__(self, idx: int):
sample = self.base[idx]
if list(sample["field_names"]) != self.field_names:
raise ValueError(
f"Field names differ across simulations. Expected {self.field_names}, "
f"got {sample['field_names']} in {sample['sim_name']}."
)
node_fields = sample["node_fields"]
if node_fields.ndim != 3:
raise ValueError(
f"Expected node_fields with shape [T, N, F], got {tuple(node_fields.shape)}"
)
coords = sample["coords"].float().contiguous()
edge_index = sample["edge_index"].long().contiguous()
x = self._select_channels(node_fields, self.input_time_idx, self.input_indices)
y = self._select_channels(node_fields, self.target_time_idx, self.output_indices)
edge_attr = self._build_edge_attr(coords, edge_index)
data = self._pyg_data_cls(
x=x.float().contiguous(),
y=y.float().contiguous(),
edge_index=edge_index,
edge_attr=edge_attr.float().contiguous(),
pos=coords,
)
data.num_nodes = int(coords.shape[0])
return data
@staticmethod
def _select_channels(
node_fields: torch.Tensor, time_idx: int, field_indices: list[int]
) -> torch.Tensor:
tensor = node_fields[time_idx, :, field_indices]
if tensor.ndim == 1:
tensor = tensor.unsqueeze(-1)
return tensor.contiguous()
@staticmethod
def _build_edge_attr(coords: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor:
src = edge_index[0]
dst = edge_index[1]
displacement = coords[dst] - coords[src]
distance = torch.linalg.norm(displacement, dim=1, keepdim=True)
return torch.cat([displacement, distance], dim=1)
def _read_sim_name_list(path: str | Path) -> list[str]:
entries: list[str] = []
for line in Path(path).read_text(encoding="utf-8").splitlines():
stripped = line.strip()
if not stripped or stripped.startswith("#"):
continue
entries.append(stripped.removesuffix(".zarr"))
if not entries:
raise ValueError(f"Split file '{path}' contains no simulation names.")
return entries
def _parse_case_params(sim_name: str) -> dict[str, float]:
"""Extract Re, Dr, Lr from a simulation name like ``Re_5000__Dr_0p144__Lr_0p01``.
Returns a dict with keys ``Re``, ``Dr``, ``Lr``. Values that cannot be
parsed are returned as 0.0 so the case still lands in *some* bin.
"""
params: dict[str, float] = {}
for key in ("Re", "Dr", "Lr"):
match = _re.search(rf"{key}_([0-9p]+)", sim_name)
if match:
val_str = match.group(1).replace("p", ".")
try:
params[key] = float(val_str)
except ValueError:
params[key] = 0.0
else:
params[key] = 0.0
return params
def _stratified_split(
sim_names: list[str],
train_ratio: float,
seed: int,
n_bins: int = 3,
) -> tuple[list[int], list[int]]:
"""Stratified train/test split over (Dr, Re) bins.
Each case is assigned to a composite bin from its quantile-binned Dr
and Re values, and *train_ratio* is sampled from each bin. Stratifying
on (Dr, Re) ensures parameter-space corners (e.g. extreme Dr at low
Re) appear in both train and test. Lr is intentionally not included
in the stratum key — adding it over-fragments small bins and weakens
coverage of the corners that actually drive prediction error.
"""
parsed = [_parse_case_params(name) for name in sim_names]
re_vals = [p["Re"] for p in parsed]
dr_vals = [p["Dr"] for p in parsed]
def _quantile_bin(values: list[float], n: int) -> list[int]:
sorted_unique = sorted(set(values))
if len(sorted_unique) <= n:
mapping = {v: i for i, v in enumerate(sorted_unique)}
return [mapping[v] for v in values]
edges = [sorted_unique[int(len(sorted_unique) * i / n)] for i in range(n)]
bins = []
for v in values:
b = n - 1
for j in range(1, n):
if v < edges[j]:
b = j - 1
break
bins.append(b)
return bins
re_bins = _quantile_bin(re_vals, n_bins)
dr_bins = _quantile_bin(dr_vals, n_bins)
# Composite bin key per case: (Dr, Re)
bin_groups: dict[tuple[int, int], list[int]] = defaultdict(list)
for idx in range(len(sim_names)):
key = (dr_bins[idx], re_bins[idx])
bin_groups[key].append(idx)
rng = random.Random(seed)
train_idx: list[int] = []
test_idx: list[int] = []
for _key, indices in sorted(bin_groups.items()):
rng.shuffle(indices)
n_train = max(1, min(len(indices) - 1, round(len(indices) * train_ratio)))
if len(indices) == 1:
# Single-case bins go to train
train_idx.extend(indices)
else:
train_idx.extend(indices[:n_train])
test_idx.extend(indices[n_train:])
# Fallback: if test is empty (all bins had 1 case), move some from train
if not test_idx:
rng.shuffle(train_idx)
n_train = max(1, round(len(train_idx) * train_ratio))
test_idx = train_idx[n_train:]
train_idx = train_idx[:n_train]
return sorted(train_idx), sorted(test_idx)
[docs]
def split_indices(
num_cases: int,
split_cfg: dict,
sim_names: list[str],
) -> tuple[list[int], list[int], list[str], list[str]]:
"""Return split indices and simulation-name lists."""
if num_cases != len(sim_names):
raise ValueError(f"sim_names length {len(sim_names)} does not match num_cases {num_cases}.")
if num_cases < 2:
raise ValueError(f"Need at least 2 cases to split train/test, but found {num_cases}.")
strategy = str(split_cfg.get("strategy", "sequential"))
if strategy in {"sequential", "random"}:
train_ratio = float(split_cfg.get("train_ratio", 0.8))
if not 0.0 < train_ratio < 1.0:
raise ValueError("train_ratio must be between 0 and 1 (exclusive).")
n_train = int(num_cases * train_ratio)
n_train = max(1, min(num_cases - 1, n_train))
indices = list(range(num_cases))
if strategy == "random":
seed = int(split_cfg.get("seed", 42))
rng = random.Random(seed)
rng.shuffle(indices)
train_idx = sorted(indices[:n_train])
test_idx = sorted(indices[n_train:])
else:
train_idx = indices[:n_train]
test_idx = indices[n_train:]
elif strategy == "stratified":
train_ratio = float(split_cfg.get("train_ratio", 0.8))
if not 0.0 < train_ratio < 1.0:
raise ValueError("train_ratio must be between 0 and 1 (exclusive).")
seed = int(split_cfg.get("seed", 42))
n_bins = int(split_cfg.get("n_bins", 3))
train_idx, test_idx = _stratified_split(
sim_names, train_ratio=train_ratio, seed=seed, n_bins=n_bins
)
elif strategy == "file":
train_file = split_cfg.get("train_file")
test_file = split_cfg.get("test_file")
if not train_file or not test_file:
raise ValueError(
"split.strategy='file' requires both split.train_file and split.test_file"
)
train_names = _read_sim_name_list(train_file)
test_names = _read_sim_name_list(test_file)
sim_to_idx = {name: idx for idx, name in enumerate(sim_names)}
unknown_train = [name for name in train_names if name not in sim_to_idx]
unknown_test = [name for name in test_names if name not in sim_to_idx]
if unknown_train or unknown_test:
raise ValueError(
"Split files contain unknown simulation names. "
f"unknown_train={unknown_train}, unknown_test={unknown_test}"
)
overlap = sorted(set(train_names).intersection(test_names))
if overlap:
raise ValueError(f"Split files overlap on simulation name(s): {overlap}.")
train_idx = sorted({sim_to_idx[name] for name in train_names})
test_idx = sorted({sim_to_idx[name] for name in test_names})
if not train_idx or not test_idx:
raise ValueError("Both train and test split files must contain at least one case.")
else:
raise ValueError(
"split.strategy must be one of {'sequential', 'random', 'stratified', 'file'}, "
f"got '{strategy}'."
)
train_sims = [sim_names[idx] for idx in train_idx]
test_sims = [sim_names[idx] for idx in test_idx]
return train_idx, test_idx, train_sims, test_sims