"""MOOSEDataset: PyTorch Dataset over processed Zarr simulation stores.
Supports three representation modes for different PhysicsNeMo model families:
"graph" GNN / MeshGraphNet
─────────────────────────────────────────────────────────────────────────
coords float32 [N, D] node spatial coordinates
edge_index int64 [2, M] COO edge list (src / dst)
node_fields float32 [T, N, F] per-node fields (interpolated from elements)
elem_fields float32 [T, E, F] per-element fields (raw)
probe_data dict probe_name → float32 [Np, C]
"point_cloud" PointNet / Transformer
─────────────────────────────────────────────────────────────────────────
coords float32 [N, D] node spatial coordinates
node_fields float32 [T, N, F] per-node fields
"grid" CNN (U-Net, FNO)
─────────────────────────────────────────────────────────────────────────
grid_x float32 [Nx] column x-coordinates
grid_y float32 [Ny] row y-coordinates
grid_fields float32 [T, Nx, Ny, F] fields on regular grid
All modes also include:
field_names list[str] field name for each F index
norm_stats dict field_name → {"mean": float, "std": float}
sim_name str unique simulation identifier
time_steps float32 [T]
If time_idx is given (≥ 0), only that time step is returned (T-dim removed).
Denormalization
───────────────
Call ``dataset.denormalize("pressure", tensor)`` to recover a tensor in
original physical units.
"""
import json
import logging
from pathlib import Path
import numpy as np
import torch
from torch.utils.data import Dataset
logger = logging.getLogger(__name__)
[docs]
class MOOSEDataset(Dataset):
"""Dataset over a directory of processed MOOSE Zarr stores.
Args:
zarr_dir : Path to the directory containing ``*.zarr`` stores.
mode : One of "graph", "point_cloud", "grid".
time_idx : If ≥ 0, return only this time step (removes T dimension).
If -1 (default), return all time steps.
"""
MODES = ("graph", "point_cloud", "grid")
def __init__(
self,
zarr_dir: str | Path,
mode: str = "graph",
time_idx: int = -1,
):
if mode not in self.MODES:
raise ValueError(f"mode must be one of {self.MODES}, got '{mode}'")
self.zarr_dir = Path(zarr_dir)
self.mode = mode
self.time_idx = time_idx
self.sim_paths: list[Path] = sorted(self.zarr_dir.glob("*.zarr"))
if not self.sim_paths:
raise FileNotFoundError(f"No .zarr stores found in {self.zarr_dir}")
logger.info(
"MOOSEDataset: found %d simulation(s), mode='%s'",
len(self.sim_paths),
mode,
)
# ------------------------------------------------------------------
# Dataset interface
# ------------------------------------------------------------------
def __len__(self) -> int:
return len(self.sim_paths)
def __getitem__(self, idx: int) -> dict:
import zarr
store_path = self.sim_paths[idx]
root = zarr.open(str(store_path), mode="r")
# --- Common metadata ---
meta = root["metadata"]
field_names: list[str] = json.loads(meta.attrs["field_names"])
probe_columns: list[str] = json.loads(meta.attrs["probe_columns"])
sim_name: str = str(meta.attrs["sim_name"])
time_steps = torch.from_numpy(np.array(meta["time_steps"], dtype=np.float32))
norm_stats = load_norm_stats(meta)
sample: dict = {
"field_names": field_names,
"norm_stats": norm_stats,
"sim_name": sim_name,
"time_steps": time_steps,
}
if self.mode == "graph":
sample.update(self._load_graph(root, field_names, probe_columns))
elif self.mode == "point_cloud":
sample.update(self._load_point_cloud(root, field_names))
elif self.mode == "grid":
sample.update(self._load_grid(root, field_names))
return sample
# ------------------------------------------------------------------
# Mode-specific loaders
# ------------------------------------------------------------------
def _load_graph(self, root, field_names: list[str], probe_columns: list[str]) -> dict:
mesh = root["mesh"]
coords = to_tensor(mesh["coords"]) # [N, D]
connectivity = torch.from_numpy(np.array(mesh["connectivity"], dtype=np.int64))
edge_src = torch.from_numpy(np.array(mesh["edge_src"], dtype=np.int64))
edge_dst = torch.from_numpy(np.array(mesh["edge_dst"], dtype=np.int64))
edge_index = torch.stack([edge_src, edge_dst], dim=0) # [2, M]
# Element fields [T, E, F]
elem_fields = load_fields(root["fields"], field_names)
elem_fields = slice_time(elem_fields, self.time_idx)
# Interpolate element fields to nodes via simple centroid averaging
node_fields = elem_to_node(elem_fields, connectivity, coords.shape[0])
# Probes
probes_grp = root["probes"]
probe_data = {name: to_tensor(probes_grp[name]) for name in probes_grp.array_keys()}
return {
"coords": coords,
"edge_index": edge_index,
"elem_fields": elem_fields,
"node_fields": node_fields,
"probe_data": probe_data,
}
def _load_point_cloud(self, root, field_names: list[str]) -> dict:
mesh = root["mesh"]
coords = to_tensor(mesh["coords"]) # [N, D]
connectivity = torch.from_numpy(np.array(mesh["connectivity"], dtype=np.int64))
elem_fields = load_fields(root["fields"], field_names) # [T, E, F]
elem_fields = slice_time(elem_fields, self.time_idx)
node_fields = elem_to_node(elem_fields, connectivity, coords.shape[0])
return {
"coords": coords,
"node_fields": node_fields,
}
def _load_grid(self, root, field_names: list[str]) -> dict:
grid_grp = root["grid"]
grid_x = to_tensor(grid_grp["x"]) # [Nx]
grid_y = to_tensor(grid_grp["y"]) # [Ny]
# Load each field and stack: [T, Nx, Ny, F]
field_arrays = [
to_tensor(grid_grp[name]).unsqueeze(-1) # [T, Nx, Ny, 1]
for name in field_names
if name in grid_grp
]
grid_fields = torch.cat(field_arrays, dim=-1) # [T, Nx, Ny, F]
grid_fields = slice_time(grid_fields, self.time_idx)
return {
"grid_x": grid_x,
"grid_y": grid_y,
"grid_fields": grid_fields,
}
# ------------------------------------------------------------------
# Normalization helpers
# ------------------------------------------------------------------
[docs]
def denormalize(self, field_name: str, tensor: torch.Tensor) -> torch.Tensor:
"""Reverse the z-score normalization for a single field.
Args:
field_name: Name of the field (must be in norm_stats).
tensor : Normalized tensor of any shape.
Returns:
Tensor in original physical units.
"""
# Load stats from the first store (stats are per-simulation; use idx=0
# here as a convention — override per-sample if needed).
import zarr
root = zarr.open(str(self.sim_paths[0]), mode="r")
norm_stats = load_norm_stats(root["metadata"])
if field_name not in norm_stats:
raise KeyError(f"Field '{field_name}' not found in norm_stats.")
stats = norm_stats[field_name]
mean = torch.tensor(stats["mean"], dtype=tensor.dtype, device=tensor.device)
std = torch.tensor(stats["std"], dtype=tensor.dtype, device=tensor.device)
return tensor * std + mean
# ---------------------------------------------------------------------------
# Module-level helpers
# ---------------------------------------------------------------------------
def to_tensor(arr) -> torch.Tensor:
    """Copy ``arr`` into a float32 NumPy array and wrap it as a torch tensor.

    ``arr`` may be anything ``np.array`` accepts (e.g. a zarr array or a
    nested list). The returned tensor owns a fresh buffer.
    """
    as_float32 = np.array(arr, dtype=np.float32)
    return torch.from_numpy(as_float32)
def load_fields(fields_grp, field_names: list[str]) -> torch.Tensor:
    """Stack the requested element fields along a new last axis → [T, E, F].

    Names in ``field_names`` that are not present in ``fields_grp`` are
    skipped, so F equals the number of names actually found.

    Raises:
        ValueError: If none of the requested names exists in ``fields_grp``.
    """
    columns = []
    for name in field_names:
        if name not in fields_grp:
            continue
        # Each field gains a trailing singleton axis: [T, E, 1]
        columns.append(to_tensor(fields_grp[name]).unsqueeze(-1))
    if len(columns) == 0:
        raise ValueError("No matching field arrays found in 'fields' group.")
    return torch.cat(columns, dim=-1)  # [T, E, F]
def slice_time(tensor: torch.Tensor, time_idx: int) -> torch.Tensor:
    """Pick a single time step along the first axis, or pass the tensor through.

    A negative ``time_idx`` returns ``tensor`` unchanged; otherwise
    ``tensor[time_idx]`` is returned, which drops the leading (T) axis.
    """
    if time_idx < 0:
        return tensor
    return tensor[time_idx]
def elem_to_node(
    elem_fields: torch.Tensor,
    connectivity: torch.Tensor,
    n_nodes: int,
) -> torch.Tensor:
    """Average element fields onto nodes (scatter mean over element→node map).

    Every element adds its field values to each node listed in its
    connectivity row; each node's sum is then divided by the number of times
    the node was referenced. Nodes referenced by no element come out as zero.

    Args:
        elem_fields  : [..., E, F] (leading dims may include T).
        connectivity : [E, K] 0-indexed integer node indices.
        n_nodes      : N, the number of nodes in the output.

    Returns:
        node_fields : [..., N, F], on the same device as ``elem_fields``.
    """
    lead_shape = elem_fields.shape[:-2]
    n_elems, n_feats = elem_fields.shape[-2], elem_fields.shape[-1]
    device = elem_fields.device

    # Compute the flattened batch size explicitly: ``reshape(-1, E, F)`` is
    # ambiguous (and raises) whenever the tensor has zero elements.
    batch = 1
    for dim in lead_shape:
        batch *= dim
    flat_fields = elem_fields.reshape(batch, n_elems, n_feats)  # [B, E, F]

    # Keep indices on the same device as the values being accumulated.
    connectivity = connectivity.to(device)

    # Sum element values into their nodes, one connectivity column at a time.
    node_acc = torch.zeros(
        batch, n_nodes, n_feats, dtype=elem_fields.dtype, device=device
    )
    for k in range(connectivity.shape[1]):
        node_acc.index_add_(1, connectivity[:, k], flat_fields)

    # Count how often each node is referenced, in a single pass over all
    # (element, local-node) pairs.
    flat_idx = connectivity.reshape(-1)
    node_cnt = torch.zeros(n_nodes, dtype=torch.float32, device=device)
    node_cnt.index_add_(
        0,
        flat_idx,
        torch.ones(flat_idx.shape[0], dtype=torch.float32, device=device),
    )

    # Avoid division by zero (nodes not referenced by any element)
    node_cnt = node_cnt.clamp(min=1.0)
    node_fields = node_acc / node_cnt.view(1, n_nodes, 1)  # [B, N, F]
    return node_fields.reshape(*lead_shape, n_nodes, n_feats)
def load_norm_stats(meta_grp) -> dict[str, dict[str, float]]:
    """Collect per-field normalization stats from the ``norm_stats`` subgroup.

    Returns a mapping ``field_name → {"mean": float, "std": float}`` built
    from the ``mean`` / ``std`` attributes of each child group, or an empty
    dict when ``meta_grp`` has no ``norm_stats`` entry.
    """
    if "norm_stats" not in meta_grp:
        return {}
    norm_grp = meta_grp["norm_stats"]
    collected: dict[str, dict[str, float]] = {}
    for key in norm_grp.group_keys():
        attrs = norm_grp[key].attrs
        collected[key] = {
            "mean": float(attrs["mean"]),
            "std": float(attrs["std"]),
        }
    return collected