Source code for cases.moose_grid.etl.transformations.moose_transform
"""MOOSEDataTransformation: normalize, build graph, interpolate to grid.
Implements the DataTransformation ABC from physicsnemo-curator.
Pipeline:
1. Per-field mean/std normalization across all time steps and elements.
2. Graph edge construction from element→node connectivity (all node pairs
within each element, both directions).
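3. Piecewise-linear interpolation (over a Delaunay triangulation of the element
centroids; not bilinear) onto a regular Nx×Ny grid using scipy.interpolate's
linear N-D scattered-data interpolation (the method behind
griddata(..., method="linear")).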
Returns a dict containing all fields of MOOSEProcessedData, ready for
MOOSEZarrSink to write to disk.
"""
import itertools
import logging
from typing import Any, Optional

import numpy as np
from physicsnemo_curator.etl.data_transformations import DataTransformation
from physicsnemo_curator.etl.processing_config import ProcessingConfig
from scipy.interpolate import LinearNDInterpolator, griddata
from scipy.spatial import Delaunay

from cases.moose_grid.etl.schemas import MOOSEProcessedData, NormStats
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
[docs]
class MOOSEDataTransformation(DataTransformation):
    """Normalize, build graph, and interpolate MOOSE simulation data.

    Args:
        cfg: ProcessingConfig from the curator framework.
        grid_nx: Number of grid columns for the regular-grid output.
        grid_ny: Number of grid rows for the regular-grid output.
        eps: Small value added to std to avoid division by zero.
    """

    def __init__(
        self,
        cfg: ProcessingConfig,
        grid_nx: int = 64,
        grid_ny: int = 64,
        eps: float = 1e-8,
    ):
        super().__init__(cfg)
        self.grid_nx = grid_nx
        self.grid_ny = grid_ny
        self.eps = eps

    # ------------------------------------------------------------------
    # DataTransformation interface
    # ------------------------------------------------------------------

    def transform(self, data: dict[str, Any]) -> Optional[dict[str, Any]]:
        """Transform raw MOOSE data into ML-ready form.

        Args:
            data: dict produced by ExodusDataSource.read_file(). The keys read
                here are ``coords``, ``connectivity``, ``field_names``,
                ``fields``, ``time_steps``, ``probe_data``, ``probe_columns``
                and ``sim_name``.

        Returns:
            dict of MOOSEProcessedData fields, or None to skip this sample
            (returned when ``fields`` is empty).

        Raises:
            KeyError: If any of the keys listed above is missing from ``data``.
        """
        coords: np.ndarray = data["coords"]  # [N, D]
        connectivity: np.ndarray = data["connectivity"]  # [E, K]
        field_names: list[str] = data["field_names"]
        fields: np.ndarray = data["fields"]  # [T, E, F]
        time_steps: np.ndarray = data["time_steps"]
        probe_data: dict = data["probe_data"]
        probe_columns: list[str] = data["probe_columns"]
        sim_name: str = data["sim_name"]

        if fields.size == 0:
            logger.warning("Skipping %s: no element fields found.", sim_name)
            return None

        # 1. Normalize fields
        norm_fields, norm_stats = self._normalize(fields, field_names)

        # 2. Build undirected graph edges from element connectivity
        edge_src, edge_dst = self._build_edges(connectivity)

        # 3. Interpolate to regular grid
        grid_fields, grid_x, grid_y = self._interpolate_to_grid(
            coords, connectivity, norm_fields
        )

        # Route the results through the schema object before flattening to a
        # dict, so the set of emitted keys stays tied to MOOSEProcessedData.
        processed = MOOSEProcessedData(
            coords=coords,
            connectivity=connectivity,
            edge_src=edge_src,
            edge_dst=edge_dst,
            fields=norm_fields,
            field_names=field_names,
            norm_stats=norm_stats,
            probe_data=probe_data,
            probe_columns=probe_columns,
            grid_fields=grid_fields,
            grid_x=grid_x,
            grid_y=grid_y,
            time_steps=time_steps,
            sim_name=sim_name,
        )

        # Return as plain dict for the curator pipeline / zarr sink
        return {
            "coords": processed.coords,
            "connectivity": processed.connectivity,
            "edge_src": processed.edge_src,
            "edge_dst": processed.edge_dst,
            "fields": processed.fields,
            "field_names": processed.field_names,
            # NormStats objects are flattened to plain floats so the stats are
            # trivially serializable.
            "norm_stats": {
                name: {"mean": float(s.mean), "std": float(s.std)}
                for name, s in processed.norm_stats.items()
            },
            "probe_data": processed.probe_data,
            "probe_columns": processed.probe_columns,
            "grid_fields": processed.grid_fields,
            "grid_x": processed.grid_x,
            "grid_y": processed.grid_y,
            "time_steps": processed.time_steps,
            "sim_name": processed.sim_name,
        }

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _normalize(
        self, fields: np.ndarray, field_names: list[str]
    ) -> tuple[np.ndarray, dict[str, NormStats]]:
        """Z-score normalize each field independently.

        Statistics are taken over all time steps and elements of a field.
        The recorded ``std`` does NOT include ``eps``; the exact inverse of
        the normalization is ``x * (std + eps) + mean``.

        Args:
            fields: [T, E, F] raw element values.
            field_names: One name per field; only the first
                ``len(field_names)`` slices along the last axis are
                normalized, any remaining slices are copied through unchanged.

        Returns:
            Normalized array (same shape as ``fields``) and per-field stats.
        """
        # A plain copy would keep an integer dtype and silently truncate the
        # z-scores on assignment, so promote non-floating input to float64.
        # Floating (and complex) input keeps its dtype.
        if np.issubdtype(fields.dtype, np.inexact):
            out_dtype = fields.dtype
        else:
            out_dtype = np.float64
        norm_fields = fields.astype(out_dtype, copy=True)

        norm_stats: dict[str, NormStats] = {}
        for fi, name in enumerate(field_names):
            vals = fields[:, :, fi]  # [T, E]
            mean = float(vals.mean())
            std = float(vals.std())
            norm_fields[:, :, fi] = (vals - mean) / (std + self.eps)
            norm_stats[name] = NormStats(mean=mean, std=std)
        return norm_fields, norm_stats

    def _build_edges(self, connectivity: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
        """Build undirected graph edges from element-to-node connectivity.

        For each element, connect all pairs of its nodes in both directions.
        Duplicate edges (shared faces/edges between elements) are removed and
        the result is sorted lexicographically by (src, dst).

        Args:
            connectivity: [E, K] 0-indexed node indices per element.

        Returns:
            ``edge_src`` and ``edge_dst``, each an int32 array of shape [M].
        """
        conn = np.asarray(connectivity).astype(np.int64, copy=False)
        num_elem, nodes_per_elem = conn.shape

        if num_elem == 0 or nodes_per_elem < 2:
            # No element contributes a node pair.
            return np.empty(0, dtype=np.int32), np.empty(0, dtype=np.int32)

        # Local index pairs (i < j) within one element, applied to every
        # element at once instead of looping in Python.
        local_i, local_j = np.triu_indices(nodes_per_elem, k=1)
        node_a = conn[:, local_i].ravel()
        node_b = conn[:, local_j].ravel()

        # Undirected -> emit each pair in both directions.
        directed = np.stack(
            (np.concatenate((node_a, node_b)), np.concatenate((node_b, node_a))),
            axis=1,
        )  # [2 * E * K*(K-1)/2, 2]

        # Row-wise unique both de-duplicates and sorts lexicographically.
        edges = np.unique(directed, axis=0).astype(np.int32)  # [M, 2]
        return edges[:, 0], edges[:, 1]

    def _interpolate_to_grid(
        self,
        coords: np.ndarray,
        connectivity: np.ndarray,
        fields: np.ndarray,
    ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Interpolate element-centroid values to a regular Nx×Ny grid.

        Element centroids are computed as the mean position of their nodes.
        The centroids are triangulated once (Delaunay) and every time step is
        interpolated piecewise-linearly on that triangulation. Grid points
        outside the convex hull of the centroids are filled with 0.0.

        Args:
            coords: [N, D] node coordinates (only first two dims used).
            connectivity: [E, K] element→node connectivity (0-indexed).
            fields: [T, E, F] normalized element values.

        Returns:
            grid_fields: [T, Nx, Ny, F] float32.
            grid_x: [Nx] float32 column x-coordinates.
            grid_y: [Ny] float32 row y-coordinates.

        Note:
            Triangulation needs at least three non-collinear centroids;
            degenerate inputs are expected to make scipy raise (not handled
            here — confirm against callers if such meshes can occur).
        """
        # Element centroids: mean of node positions [E, 2]
        elem_xy = coords[:, :2][connectivity].mean(axis=1)
        x_min, y_min = elem_xy[:, 0].min(), elem_xy[:, 1].min()
        x_max, y_max = elem_xy[:, 0].max(), elem_xy[:, 1].max()

        # Query in float64: rounding the axis end points to float32 can move
        # the boundary grid points just outside the centroid hull, which would
        # turn real boundary data into fill values. Only the returned axes are
        # cast to float32.
        query_x = np.linspace(x_min, x_max, self.grid_nx, dtype=np.float64)
        query_y = np.linspace(y_min, y_max, self.grid_ny, dtype=np.float64)
        grid_x = query_x.astype(np.float32)
        grid_y = query_y.astype(np.float32)
        gx, gy = np.meshgrid(query_x, query_y, indexing="ij")  # [Nx, Ny]
        query_pts = np.stack((gx, gy), axis=-1)  # [Nx, Ny, 2]

        num_time = fields.shape[0]
        num_fields = fields.shape[2]
        grid_fields = np.zeros(
            (num_time, self.grid_nx, self.grid_ny, num_fields), dtype=np.float32
        )
        if num_time == 0 or num_fields == 0:
            return grid_fields, grid_x, grid_y

        # The point set is identical for every time step and field, so the
        # triangulation is built once and shared.
        triangulation = Delaunay(elem_xy)
        for t in range(num_time):
            # fields[t] is [E, F]: all fields of a time step are interpolated
            # in a single call, giving [Nx, Ny, F].
            interpolator = LinearNDInterpolator(
                triangulation, fields[t], fill_value=0.0
            )
            grid_fields[t] = interpolator(query_pts)
        return grid_fields, grid_x, grid_y