"""Adapter layer to unify grid and graph model families."""
import importlib
from collections.abc import Callable
from pathlib import Path
import torch
from training import _require_pyg
from training.datasets import GraphPairDataset, GridPairDataset, parse_field_list
[docs]
class ModelAdapter:
"""Interface for model/data-family-specific behavior."""
family: str
[docs]
def build_dataset(self, data_cfg: dict):
raise NotImplementedError
[docs]
def dataset_info(self, dataset) -> dict:
raise NotImplementedError
[docs]
def build_batch(self, raw_batch, device: torch.device):
raise NotImplementedError
[docs]
def forward_train(self, model, batch) -> tuple[torch.Tensor, torch.Tensor]:
raise NotImplementedError
[docs]
def forward_eval(self, model, batch) -> tuple[torch.Tensor, torch.Tensor]:
raise NotImplementedError
[docs]
def collate_fn(self) -> Callable | None:
return None
[docs]
def accumulate_metrics(
self,
batch,
pred: torch.Tensor,
target: torch.Tensor,
) -> tuple[torch.Tensor, int]:
raise NotImplementedError
class GridAdapter(ModelAdapter):
    """Adapter for models that consume ``[B, C, Nx, Ny]`` grid tensors."""

    family = "grid"

    def build_dataset(self, data_cfg: dict) -> GridPairDataset:
        """Create a ``GridPairDataset`` from the ``data_cfg`` mapping.

        ``zarr_dir`` is required; ``input_fields`` / ``output_fields`` are
        passed through ``parse_field_list``; the time indices default to
        ``0`` (input) and ``-1`` (target).
        """
        zarr_dir = data_cfg["zarr_dir"]
        in_fields = parse_field_list(data_cfg.get("input_fields"))
        out_fields = parse_field_list(data_cfg.get("output_fields"))
        t_in = int(data_cfg.get("input_time_idx", 0))
        t_out = int(data_cfg.get("target_time_idx", -1))
        return GridPairDataset(
            zarr_dir=zarr_dir,
            input_fields=in_fields,
            output_fields=out_fields,
            input_time_idx=t_in,
            target_time_idx=t_out,
        )

    def dataset_info(self, dataset: GridPairDataset) -> dict:
        """Report channel counts and the spatial shape of ``dataset``."""
        info = {}
        info["in_channels"] = len(dataset.input_indices)
        info["out_channels"] = len(dataset.output_indices)
        info["spatial_shape"] = dataset.spatial_shape
        return info

    def build_batch(self, raw_batch, device: torch.device):
        """Move the ``(x, y)`` pair to ``device``.

        The copy is requested as non-blocking only for CUDA devices.
        """
        inputs, target = raw_batch
        async_copy = device.type == "cuda"
        return tuple(
            tensor.to(device, non_blocking=async_copy) for tensor in (inputs, target)
        )

    def forward_train(self, model, batch) -> tuple[torch.Tensor, torch.Tensor]:
        """Apply ``model`` to the inputs and return ``(pred, target)``."""
        inputs, target = batch
        return model(inputs), target

    def forward_eval(self, model, batch) -> tuple[torch.Tensor, torch.Tensor]:
        """Apply ``model`` to the inputs and return ``(pred, target)``."""
        inputs, target = batch
        return model(inputs), target

    def accumulate_metrics(
        self,
        batch,
        pred: torch.Tensor,
        target: torch.Tensor,
    ) -> tuple[torch.Tensor, int]:
        """Return per-channel squared error and the batch size.

        The squared error is averaged over the two trailing (spatial) axes
        and summed over the batch axis, giving one value per channel.

        Raises:
            ValueError: If ``pred`` and ``target`` differ in shape, or if
                ``pred`` is not 4-dimensional.
        """
        if pred.shape != target.shape:
            raise ValueError(
                f"Prediction shape {tuple(pred.shape)} does not match target shape {tuple(target.shape)}."
            )
        if pred.ndim != 4:
            raise ValueError(
                f"Grid predictions must have shape [B, C, Nx, Ny], got {tuple(pred.shape)}."
            )
        sq_err = (pred - target).pow(2)
        per_sample = sq_err.mean(dim=(2, 3))  # [B, C]
        return per_sample.sum(dim=0), int(pred.size(0))
class GraphAdapter(ModelAdapter):
    """Adapter for graph models fed with PyTorch Geometric batches."""

    family = "graph"

    def __init__(self):
        """Check for PyG via ``_require_pyg`` and import its ``Batch`` class lazily."""
        _require_pyg()
        from torch_geometric.data import Batch

        self._pyg_batch_cls = Batch
        # Most recent batch returned by ``build_batch``; used by
        # ``accumulate_metrics`` only when the caller's batch carries no
        # graph index tensor.
        self._last_batch = None

    def build_dataset(self, data_cfg: dict) -> GraphPairDataset:
        """Create a ``GraphPairDataset`` from the ``data_cfg`` mapping.

        ``zarr_dir`` is required; ``input_fields`` / ``output_fields`` are
        passed through ``parse_field_list``; the time indices default to
        ``0`` (input) and ``-1`` (target).
        """
        return GraphPairDataset(
            zarr_dir=data_cfg["zarr_dir"],
            input_fields=parse_field_list(data_cfg.get("input_fields")),
            output_fields=parse_field_list(data_cfg.get("output_fields")),
            input_time_idx=int(data_cfg.get("input_time_idx", 0)),
            target_time_idx=int(data_cfg.get("target_time_idx", -1)),
        )

    def dataset_info(self, dataset: GraphPairDataset) -> dict:
        """Report channel counts and the edge feature size of ``dataset``."""
        return {
            "in_channels": len(dataset.input_indices),
            "out_channels": len(dataset.output_indices),
            "edge_dim": dataset.edge_dim,
        }

    def collate_fn(self) -> Callable | None:
        """Return the callable that merges a list of graphs into one batch.

        The bound ``Batch.from_data_list`` classmethod is returned directly
        rather than wrapped in a local closure: a function defined inside a
        method cannot be pickled, which matters when data-loading workers
        are started with the ``spawn`` method.
        """
        return self._pyg_batch_cls.from_data_list

    def build_batch(self, raw_batch, device: torch.device):
        """Move ``raw_batch`` to ``device`` and remember it as the last batch.

        The copy is requested as non-blocking only for CUDA devices.
        """
        non_blocking = device.type == "cuda"
        batch = raw_batch.to(device, non_blocking=non_blocking)
        self._last_batch = batch
        return batch

    def forward_train(self, model, batch) -> tuple[torch.Tensor, torch.Tensor]:
        """Call ``model(x, edge_attr, batch)`` and return ``(pred, batch.y)``."""
        pred = model(batch.x, batch.edge_attr, batch)
        return pred, batch.y

    def forward_eval(self, model, batch) -> tuple[torch.Tensor, torch.Tensor]:
        """Call ``model(x, edge_attr, batch)`` and return ``(pred, batch.y)``."""
        pred = model(batch.x, batch.edge_attr, batch)
        return pred, batch.y

    def accumulate_metrics(
        self,
        batch,
        pred: torch.Tensor,
        target: torch.Tensor,
    ) -> tuple[torch.Tensor, int]:
        """Return per-channel squared error and the number of graphs.

        The squared error is averaged over the nodes of each graph and the
        per-graph means are summed, giving one value per channel.

        Graph membership is read from ``batch`` when it carries a ``batch``
        index tensor; otherwise the batch last seen by ``build_batch`` is
        used.

        Raises:
            RuntimeError: If neither ``batch`` nor a remembered batch is
                available.
            ValueError: If ``pred`` and ``target`` differ in shape, if
                ``pred`` is not 2-dimensional, or if the batch has no graph
                index tensor.
        """
        # Prefer the batch that actually produced ``pred``: the remembered
        # one may belong to a different step if another batch was built in
        # the meantime.
        if getattr(batch, "batch", None) is not None:
            batch_obj = batch
        else:
            batch_obj = self._last_batch
        if batch_obj is None:
            raise RuntimeError("GraphAdapter has no batch metadata for metric accumulation.")
        if pred.shape != target.shape:
            raise ValueError(
                f"Prediction shape {tuple(pred.shape)} does not match target shape {tuple(target.shape)}."
            )
        if pred.ndim != 2:
            raise ValueError(
                f"Graph predictions must have shape [total_nodes, C], got {tuple(pred.shape)}."
            )
        # The attribute may exist but hold ``None`` (PyG ``Data`` presumably
        # exposes ``batch`` as a property -- TODO confirm), so test the value
        # rather than using ``hasattr``.
        graph_ids = getattr(batch_obj, "batch", None)
        if graph_ids is None:
            raise ValueError("Graph batch is missing 'batch' graph index tensor.")
        num_graphs = int(batch_obj.num_graphs)
        squared = (pred - target) ** 2
        # No-op when the indices already live on the prediction's device.
        graph_ids = graph_ids.to(squared.device)
        graph_sum = torch.zeros(
            (num_graphs, squared.shape[1]), dtype=squared.dtype, device=squared.device
        )
        graph_count = torch.zeros(num_graphs, dtype=squared.dtype, device=squared.device)
        graph_sum.index_add_(0, graph_ids, squared)
        graph_count.index_add_(0, graph_ids, torch.ones_like(graph_ids, dtype=squared.dtype))
        # Clamp so a graph without nodes yields 0 instead of dividing by zero.
        graph_mean = graph_sum / graph_count.clamp_min(1.0).unsqueeze(-1)
        field_se = graph_mean.sum(dim=0)
        return field_se, num_graphs
class PointwiseAdapter(ModelAdapter):
    """Adapter for tabular / pointwise MLP models.

    Expects datasets producing ``(x, y)`` tuples of shape
    ``(D_in,)`` and ``(D_out,)`` which the default collate batches into
    ``(B, D_in)`` and ``(B, D_out)``.
    """

    family = "pointwise"

    @staticmethod
    def _resolve_entrypoint(spec, option: str):
        """Import and return the attribute named by a ``'<module>:<attr>'`` spec.

        Args:
            spec: Entrypoint value taken from the data config.
            option: Config key the value came from (used in the error text).

        Raises:
            ValueError: If ``spec`` contains no ``':'`` separator.
        """
        module_name, sep, attr_name = str(spec).partition(":")
        if not sep:
            raise ValueError(
                f"{option} must have the format '<module>:<callable>', got {spec!r}."
            )
        return getattr(importlib.import_module(module_name), attr_name)

    def build_dataset(self, data_cfg: dict):
        """Create a ``TabularPairDataset`` from the ``data_cfg`` mapping.

        Input columns come from ``input_columns_file`` (one name per
        non-blank line) when that key is set, otherwise from
        ``input_columns``. ``engineered_features_entrypoint`` names a
        zero-argument callable returning ``(names, builder)``;
        ``target_transform`` names an object passed through unchanged.

        Raises:
            FileNotFoundError: If ``input_columns_file`` does not exist.
            ValueError: If ``input_columns_file`` lists no columns, or an
                entrypoint is not of the form ``'<module>:<callable>'``.
        """
        from training.datasets_tabular import TabularPairDataset

        def _opt_float(key: str) -> float | None:
            # Missing / ``None`` options stay ``None`` instead of failing in float().
            v = data_cfg.get(key)
            return float(v) if v is not None else None

        norm_from_case_indices = data_cfg.get("norm_from_case_indices")
        if norm_from_case_indices is not None:
            norm_from_case_indices = [int(i) for i in norm_from_case_indices]

        if data_cfg.get("input_columns_file") is not None:
            cols_path = Path(str(data_cfg["input_columns_file"]))
            if not cols_path.exists():
                raise FileNotFoundError(f"input_columns_file not found: {cols_path}")
            input_columns = [
                line.strip() for line in cols_path.read_text().splitlines() if line.strip()
            ]
            if not input_columns:
                raise ValueError(f"input_columns_file is empty: {cols_path}")
        else:
            input_columns = parse_field_list(data_cfg.get("input_columns"))

        exclude_cases = data_cfg.get("exclude_cases")
        if exclude_cases is not None:
            exclude_cases = [str(c) for c in exclude_cases]

        eng_names = None
        eng_builder = None
        eng_ep = data_cfg.get("engineered_features_entrypoint")
        if eng_ep is not None:
            # The entrypoint is a factory: calling it yields (names, builder).
            factory = self._resolve_entrypoint(eng_ep, "engineered_features_entrypoint")
            eng_names, eng_builder = factory()

        target_transform = None
        tt_ep = data_cfg.get("target_transform")
        if tt_ep is not None:
            # Unlike the engineered-features entrypoint, this one is not called here.
            target_transform = self._resolve_entrypoint(tt_ep, "target_transform")

        return TabularPairDataset(
            zarr_dir=data_cfg["zarr_dir"],
            input_columns=input_columns,
            output_columns=parse_field_list(data_cfg.get("output_columns")),
            normalize=bool(data_cfg.get("normalize", False)),
            norm_stats=data_cfg.get("norm_stats"),
            norm_from_case_indices=norm_from_case_indices,
            throat_weight=_opt_float("throat_weight"),
            downstream_weight=_opt_float("downstream_weight"),
            include_case_idx=bool(data_cfg.get("include_case_idx", False)),
            exclude_cases=exclude_cases,
            local_velocity_normalization=bool(data_cfg.get("local_velocity_normalization", False)),
            min_Dr=_opt_float("min_Dr"),
            target_transform=target_transform,
            engineered_feature_names=eng_names,
            engineered_feature_builder=eng_builder,
        )

    def dataset_info(self, dataset) -> dict:
        """Report the input and output feature counts of ``dataset``."""
        return {
            "in_features": dataset.in_features,
            "out_features": dataset.out_features,
        }

    def build_batch(self, raw_batch, device: torch.device):
        """Move every tensor of a 2-, 3- or 4-element batch to ``device``.

        Batches are ``(x, y)``, ``(x, y, w_or_cidx)`` or ``(x, y, w, cidx)``.
        The copy is requested as non-blocking only for CUDA devices.
        """
        non_blocking = device.type == "cuda"
        if len(raw_batch) in (3, 4):
            tensors = tuple(raw_batch)
        else:
            # Any other length must be a plain pair; unpacking enforces that.
            x, y = raw_batch
            tensors = (x, y)
        return tuple(t.to(device, non_blocking=non_blocking) for t in tensors)

    def forward_train(self, model, batch):
        """Apply ``model`` to ``x`` and return ``pred`` followed by the rest of the batch.

        Returns ``(pred, y)``, ``(pred, y, w)`` or ``(pred, y, w, cidx)``
        depending on the batch length.
        """
        if len(batch) == 4:
            x, y, w, cidx = batch
            return model(x), y, w, cidx
        if len(batch) == 3:
            x, y, w = batch
            return model(x), y, w
        x, y = batch
        return model(x), y

    def forward_eval(self, model, batch):
        """Apply ``model`` to ``x`` and return ``(pred, y)``, ignoring extra entries."""
        x, y = batch[:2] if len(batch) >= 3 else batch
        return model(x), y

    def accumulate_metrics(
        self,
        batch,
        pred: torch.Tensor,
        target: torch.Tensor,
    ) -> tuple[torch.Tensor, int]:
        """Return squared error summed over axis 0 and the size of that axis.

        Raises:
            ValueError: If ``pred`` and ``target`` differ in shape.
        """
        if pred.shape != target.shape:
            raise ValueError(
                f"Prediction shape {tuple(pred.shape)} does not match "
                f"target shape {tuple(target.shape)}."
            )
        squared = (pred - target) ** 2
        field_se = squared.sum(dim=0)
        num_samples = int(pred.shape[0])
        return field_se, num_samples
class ProfileAdapter(ModelAdapter):
    """Adapter for per-case profile (1D-conv) models.

    Datasets emit per-case ``(x, y, w, case_idx)`` items shaped
    ``[C, S]``, ``[O, S]``, ``[1, S]``, scalar. After default collation
    the batch is ``[B, C, S]``, ``[B, O, S]``, ``[B, 1, S]``, ``[B]``.

    Dataset construction is delegated to a case-side callable resolved from
    ``data.dataset_entrypoint`` (mirrors ``model.entrypoint``); the adapter
    itself knows nothing about specific cases.
    """

    family = "profile"

    def build_dataset(self, data_cfg: dict):
        """Build the dataset by calling the configured entrypoint with ``data_cfg``.

        Raises:
            ValueError: If ``dataset_entrypoint`` is missing or is not of
                the form ``'<module>:<callable>'``.
        """
        ep = data_cfg.get("dataset_entrypoint")
        if ep is None:
            raise ValueError(
                "Profile adapter requires data.dataset_entrypoint "
                "in the format '<module>:<callable>' where the callable "
                "accepts a data_cfg dict and returns a profile dataset."
            )
        # Split on the first ':' only, so the attribute part may itself contain ':'.
        module_name, sep, fn_name = str(ep).partition(":")
        if not sep:
            # Fail with an explicit message instead of an opaque unpacking error.
            raise ValueError(
                f"Invalid data.dataset_entrypoint {ep!r}: "
                "expected the format '<module>:<callable>'."
            )
        build = getattr(importlib.import_module(module_name), fn_name)
        return build(data_cfg)

    def dataset_info(self, dataset) -> dict:
        """Report channel counts and the station count of ``dataset``.

        The station count is the size of the last axis of the first item's
        input tensor, so the dataset must contain at least one item.
        """
        first_x, _, _, _ = dataset[0]
        return {
            "in_channels": dataset.in_features,
            "out_channels": dataset.out_features,
            "n_stations": int(first_x.shape[-1]),
        }

    def build_batch(self, raw_batch, device: torch.device):
        """Move the ``(x, y, w, cidx)`` batch to ``device``.

        The copy is requested as non-blocking only for CUDA devices.
        """
        x, y, w, cidx = raw_batch
        non_blocking = device.type == "cuda"
        return tuple(t.to(device, non_blocking=non_blocking) for t in (x, y, w, cidx))

    def forward_train(self, model, batch):
        """Apply ``model`` to ``x`` and return ``(pred, y, w, cidx)``."""
        x, y, w, cidx = batch
        return model(x), y, w, cidx

    def forward_eval(self, model, batch):
        """Apply ``model`` to ``x`` and return ``(pred, y)``."""
        x, y = batch[0], batch[1]
        return model(x), y

    def accumulate_metrics(
        self,
        batch,
        pred: torch.Tensor,
        target: torch.Tensor,
    ) -> tuple[torch.Tensor, int]:
        """Return per-output squared error and the number of points it covers.

        The squared error is summed over the batch and station axes
        (0 and 2); the count is ``B * S``.

        Raises:
            ValueError: If ``pred`` and ``target`` differ in shape, or if
                ``pred`` is not 3-dimensional.
        """
        if pred.shape != target.shape:
            raise ValueError(
                f"Prediction shape {tuple(pred.shape)} does not match "
                f"target shape {tuple(target.shape)}."
            )
        if pred.ndim != 3:
            raise ValueError(
                f"Profile predictions must have shape [B, O, S], got {tuple(pred.shape)}."
            )
        squared = (pred - target) ** 2
        field_se = squared.sum(dim=(0, 2))
        num_samples = int(pred.shape[0] * pred.shape[2])
        return field_se, num_samples
# Maps each adapter's ``family`` name ("grid", "graph", "pointwise",
# "profile") to its class; ``get_adapter`` resolves names through this table.
ADAPTER_REGISTRY = {
    adapter_cls.family: adapter_cls
    for adapter_cls in (GridAdapter, GraphAdapter, PointwiseAdapter, ProfileAdapter)
}
def get_adapter(name: str) -> ModelAdapter:
    """Instantiate the adapter registered under ``name``.

    Raises:
        ValueError: If ``name`` is not a key of ``ADAPTER_REGISTRY``; the
            message lists the registered names in sorted order.
    """
    adapter_cls = ADAPTER_REGISTRY.get(name)
    if adapter_cls is None:
        known = ", ".join(sorted(ADAPTER_REGISTRY))
        raise ValueError(f"Unknown adapter '{name}'. Available adapters: {known}")
    return adapter_cls()