Source code for training.adapters

"""Adapter layer to unify grid and graph model families."""

import importlib
from collections.abc import Callable
from pathlib import Path

import torch

from training import _require_pyg
from training.datasets import GraphPairDataset, GridPairDataset, parse_field_list


[docs] class ModelAdapter: """Interface for model/data-family-specific behavior.""" family: str
[docs] def build_dataset(self, data_cfg: dict): raise NotImplementedError
[docs] def dataset_info(self, dataset) -> dict: raise NotImplementedError
[docs] def build_batch(self, raw_batch, device: torch.device): raise NotImplementedError
[docs] def forward_train(self, model, batch) -> tuple[torch.Tensor, torch.Tensor]: raise NotImplementedError
[docs] def forward_eval(self, model, batch) -> tuple[torch.Tensor, torch.Tensor]: raise NotImplementedError
[docs] def collate_fn(self) -> Callable | None: return None
[docs] def accumulate_metrics( self, batch, pred: torch.Tensor, target: torch.Tensor, ) -> tuple[torch.Tensor, int]: raise NotImplementedError
class GridAdapter(ModelAdapter):
    """Adapter for the ``"grid"`` family, backed by ``GridPairDataset``."""

    family = "grid"

    def build_dataset(self, data_cfg: dict) -> GridPairDataset:
        """Create a ``GridPairDataset`` from ``data_cfg``.

        ``zarr_dir`` is required. ``input_fields`` / ``output_fields`` are
        passed through ``parse_field_list``; ``input_time_idx`` defaults to
        ``0`` and ``target_time_idx`` to ``-1``.
        """
        zarr_dir = data_cfg["zarr_dir"]
        in_fields = parse_field_list(data_cfg.get("input_fields"))
        out_fields = parse_field_list(data_cfg.get("output_fields"))
        t_in = int(data_cfg.get("input_time_idx", 0))
        t_out = int(data_cfg.get("target_time_idx", -1))
        return GridPairDataset(
            zarr_dir=zarr_dir,
            input_fields=in_fields,
            output_fields=out_fields,
            input_time_idx=t_in,
            target_time_idx=t_out,
        )

    def dataset_info(self, dataset: GridPairDataset) -> dict:
        """Return channel counts and the spatial shape reported by ``dataset``."""
        info = {}
        info["in_channels"] = len(dataset.input_indices)
        info["out_channels"] = len(dataset.output_indices)
        info["spatial_shape"] = dataset.spatial_shape
        return info

    def build_batch(self, raw_batch, device: torch.device):
        """Move an ``(x, y)`` pair to ``device``.

        The copy is requested as non-blocking only when ``device`` is CUDA.
        """
        async_copy = device.type == "cuda"
        inputs, targets = raw_batch
        inputs = inputs.to(device, non_blocking=async_copy)
        targets = targets.to(device, non_blocking=async_copy)
        return (inputs, targets)

    def forward_train(self, model, batch) -> tuple[torch.Tensor, torch.Tensor]:
        """Return ``(model(x), y)`` for an ``(x, y)`` batch."""
        inputs, targets = batch
        return model(inputs), targets

    def forward_eval(self, model, batch) -> tuple[torch.Tensor, torch.Tensor]:
        """Return ``(model(x), y)`` for an ``(x, y)`` batch."""
        inputs, targets = batch
        return model(inputs), targets

    def accumulate_metrics(
        self,
        batch,
        pred: torch.Tensor,
        target: torch.Tensor,
    ) -> tuple[torch.Tensor, int]:
        """Return the per-channel squared error summed over the batch.

        The squared error is averaged over dims 2 and 3 and then summed over
        dim 0, giving one value per channel. ``batch`` is not used.

        Returns:
            ``(field_se, num_samples)`` where ``num_samples`` is
            ``pred.shape[0]``.

        Raises:
            ValueError: If the shapes differ or ``pred`` is not 4-D.
        """
        if pred.shape != target.shape:
            raise ValueError(
                f"Prediction shape {tuple(pred.shape)} does not match "
                f"target shape {tuple(target.shape)}."
            )
        if pred.ndim != 4:
            raise ValueError(
                f"Grid predictions must have shape [B, C, Nx, Ny], got {tuple(pred.shape)}."
            )

        sq_err = (pred - target).pow(2)
        per_sample = sq_err.mean(dim=(2, 3))
        return per_sample.sum(dim=0), int(pred.size(0))
class GraphAdapter(ModelAdapter):
    """Adapter for the ``"graph"`` family, backed by ``GraphPairDataset``.

    The adapter is stateful: ``build_batch`` stores the batch it returns in
    ``self._last_batch``, and ``accumulate_metrics`` reads graph membership
    from that stored batch rather than from its ``batch`` argument.
    """

    family = "graph"

    def __init__(self):
        """Check that PyG is available and cache its ``Batch`` class."""
        _require_pyg()
        # Imported lazily so that importing this module does not itself
        # require torch_geometric.
        from torch_geometric.data import Batch

        self._pyg_batch_cls = Batch
        self._last_batch = None

    def build_dataset(self, data_cfg: dict) -> GraphPairDataset:
        """Create a ``GraphPairDataset`` from ``data_cfg``.

        ``zarr_dir`` is required. ``input_fields`` / ``output_fields`` are
        passed through ``parse_field_list``; ``input_time_idx`` defaults to
        ``0`` and ``target_time_idx`` to ``-1``.
        """
        return GraphPairDataset(
            zarr_dir=data_cfg["zarr_dir"],
            input_fields=parse_field_list(data_cfg.get("input_fields")),
            output_fields=parse_field_list(data_cfg.get("output_fields")),
            input_time_idx=int(data_cfg.get("input_time_idx", 0)),
            target_time_idx=int(data_cfg.get("target_time_idx", -1)),
        )

    def dataset_info(self, dataset: GraphPairDataset) -> dict:
        """Return channel counts and the edge feature size reported by ``dataset``."""
        return {
            "in_channels": len(dataset.input_indices),
            "out_channels": len(dataset.output_indices),
            "edge_dim": dataset.edge_dim,
        }

    def collate_fn(self) -> Callable | None:
        """Return a collate callable that merges items via ``Batch.from_data_list``."""

        def _pyg_collate(items):
            return self._pyg_batch_cls.from_data_list(items)

        return _pyg_collate

    def build_batch(self, raw_batch, device: torch.device):
        """Move ``raw_batch`` to ``device`` and remember it for metrics.

        The copy is requested as non-blocking only when ``device`` is CUDA.
        """
        non_blocking = device.type == "cuda"
        batch = raw_batch.to(device, non_blocking=non_blocking)
        self._last_batch = batch
        return batch

    def forward_train(self, model, batch) -> tuple[torch.Tensor, torch.Tensor]:
        """Return ``(model(batch.x, batch.edge_attr, batch), batch.y)``."""
        pred = model(batch.x, batch.edge_attr, batch)
        return pred, batch.y

    def forward_eval(self, model, batch) -> tuple[torch.Tensor, torch.Tensor]:
        """Return ``(model(batch.x, batch.edge_attr, batch), batch.y)``."""
        pred = model(batch.x, batch.edge_attr, batch)
        return pred, batch.y

    def accumulate_metrics(
        self,
        batch,
        pred: torch.Tensor,
        target: torch.Tensor,
    ) -> tuple[torch.Tensor, int]:
        """Return the per-channel squared error, averaged per graph and summed.

        Node-wise squared errors are averaged within each graph (using the
        graph index tensor of the batch stored by ``build_batch``) and the
        per-graph means are summed over graphs.

        Args:
            batch: Not consulted; graph metadata comes from the batch stored
                by the most recent ``build_batch`` call.
            pred: Predictions, expected as ``[total_nodes, C]``.
            target: Targets with the same shape as ``pred``.

        Returns:
            ``(field_se, num_graphs)``.

        Raises:
            RuntimeError: If ``build_batch`` has not stored a batch yet.
            ValueError: If the shapes differ, ``pred`` is not 2-D, or the
                stored batch has no graph index tensor.
        """
        batch_obj = self._last_batch
        if batch_obj is None:
            raise RuntimeError("GraphAdapter has no batch metadata for metric accumulation.")
        if pred.shape != target.shape:
            raise ValueError(
                f"Prediction shape {tuple(pred.shape)} does not match "
                f"target shape {tuple(target.shape)}."
            )
        if pred.ndim != 2:
            raise ValueError(
                f"Graph predictions must have shape [total_nodes, C], got {tuple(pred.shape)}."
            )

        # An attribute-presence check is not enough: the attribute may exist
        # with value None (NOTE(review): PyG data objects appear to expose
        # ``batch`` this way -- confirm), which would otherwise surface as an
        # opaque failure inside ``index_add_`` below.
        graph_ids = getattr(batch_obj, "batch", None)
        if graph_ids is None:
            raise ValueError("Graph batch is missing 'batch' graph index tensor.")

        num_graphs = int(batch_obj.num_graphs)
        squared = (pred - target) ** 2

        graph_sum = torch.zeros(
            (num_graphs, squared.shape[1]), dtype=squared.dtype, device=squared.device
        )
        graph_count = torch.zeros(num_graphs, dtype=squared.dtype, device=squared.device)
        graph_sum.index_add_(0, graph_ids, squared)
        graph_count.index_add_(0, graph_ids, torch.ones_like(graph_ids, dtype=squared.dtype))

        # Clamp so that a graph with no nodes divides by 1 instead of 0.
        graph_mean = graph_sum / graph_count.clamp_min(1.0).unsqueeze(-1)
        field_se = graph_mean.sum(dim=0)
        return field_se, num_graphs
class PointwiseAdapter(ModelAdapter):
    """Adapter for tabular / pointwise MLP models.

    Expects datasets producing ``(x, y)`` tuples of shape ``(D_in,)`` and
    ``(D_out,)`` which the default collate batches into ``(B, D_in)`` and
    ``(B, D_out)``. Batches may also carry one or two extra tensors (a
    weight and/or a case index), giving 3- or 4-tuples.
    """

    family = "pointwise"

    @staticmethod
    def _resolve_entrypoint(spec, key: str):
        """Import and return the object named by a ``'<module>:<attr>'`` spec.

        Raises:
            ValueError: If ``spec`` contains no ``':'`` separator.
        """
        module_name, sep, attr_name = str(spec).partition(":")
        if not sep:
            raise ValueError(
                f"data.{key} must have the format '<module>:<callable>', got {spec!r}."
            )
        return getattr(importlib.import_module(module_name), attr_name)

    @staticmethod
    def _read_input_columns(path_value) -> list[str]:
        """Read non-blank, stripped lines of a column-list file.

        Raises:
            FileNotFoundError: If the file does not exist.
            ValueError: If the file has no non-blank lines.
        """
        cols_path = Path(str(path_value))
        if not cols_path.exists():
            raise FileNotFoundError(f"input_columns_file not found: {cols_path}")
        columns = [
            line.strip() for line in cols_path.read_text().splitlines() if line.strip()
        ]
        if not columns:
            raise ValueError(f"input_columns_file is empty: {cols_path}")
        return columns

    def build_dataset(self, data_cfg: dict):
        """Create a ``TabularPairDataset`` from ``data_cfg``.

        Input columns come from ``input_columns_file`` (one name per line)
        when that key is set, otherwise from ``input_columns`` via
        ``parse_field_list``. ``engineered_features_entrypoint`` names a
        callable that is *called* and must return ``(names, builder)``;
        ``target_transform`` names an object that is passed through
        uncalled. Both use the ``'<module>:<attr>'`` format.

        Raises:
            KeyError: If ``zarr_dir`` is missing.
            FileNotFoundError: If ``input_columns_file`` does not exist.
            ValueError: If ``input_columns_file`` is empty or an entrypoint
                spec has no ``':'`` separator.
        """
        # Imported lazily so the tabular dataset module is only needed when
        # this adapter is actually used.
        from training.datasets_tabular import TabularPairDataset

        norm_from_case_indices = data_cfg.get("norm_from_case_indices")
        if norm_from_case_indices is not None:
            norm_from_case_indices = [int(i) for i in norm_from_case_indices]

        if data_cfg.get("input_columns_file") is not None:
            input_columns = self._read_input_columns(data_cfg["input_columns_file"])
        else:
            input_columns = parse_field_list(data_cfg.get("input_columns"))

        def _opt_float(key: str) -> float | None:
            v = data_cfg.get(key)
            return float(v) if v is not None else None

        exclude_cases = data_cfg.get("exclude_cases")
        if exclude_cases is not None:
            exclude_cases = [str(c) for c in exclude_cases]

        eng_names = None
        eng_builder = None
        eng_ep = data_cfg.get("engineered_features_entrypoint")
        if eng_ep is not None:
            factory = self._resolve_entrypoint(eng_ep, "engineered_features_entrypoint")
            eng_names, eng_builder = factory()

        target_transform = None
        tt_ep = data_cfg.get("target_transform")
        if tt_ep is not None:
            target_transform = self._resolve_entrypoint(tt_ep, "target_transform")

        return TabularPairDataset(
            zarr_dir=data_cfg["zarr_dir"],
            input_columns=input_columns,
            output_columns=parse_field_list(data_cfg.get("output_columns")),
            normalize=bool(data_cfg.get("normalize", False)),
            norm_stats=data_cfg.get("norm_stats"),
            norm_from_case_indices=norm_from_case_indices,
            throat_weight=_opt_float("throat_weight"),
            downstream_weight=_opt_float("downstream_weight"),
            include_case_idx=bool(data_cfg.get("include_case_idx", False)),
            exclude_cases=exclude_cases,
            local_velocity_normalization=bool(
                data_cfg.get("local_velocity_normalization", False)
            ),
            min_Dr=_opt_float("min_Dr"),
            target_transform=target_transform,
            engineered_feature_names=eng_names,
            engineered_feature_builder=eng_builder,
        )

    def dataset_info(self, dataset) -> dict:
        """Return the input/output feature counts reported by ``dataset``."""
        return {
            "in_features": dataset.in_features,
            "out_features": dataset.out_features,
        }

    def build_batch(self, raw_batch, device: torch.device):
        """Move every tensor of a 2-, 3- or 4-element batch to ``device``.

        The copy is requested as non-blocking only when ``device`` is CUDA.
        The result is always a tuple with the same number of elements.
        """
        non_blocking = device.type == "cuda"
        if len(raw_batch) in (3, 4):
            return tuple(t.to(device, non_blocking=non_blocking) for t in raw_batch)
        # Any other length must unpack as exactly (x, y).
        x, y = raw_batch
        return (
            x.to(device, non_blocking=non_blocking),
            y.to(device, non_blocking=non_blocking),
        )

    def forward_train(self, model, batch):
        """Return ``(model(x), y, *extras)``.

        For 3- and 4-element batches the trailing elements are passed
        through unchanged after the prediction and target.
        """
        if len(batch) in (3, 4):
            pred = model(batch[0])
            return (pred, *batch[1:])
        x, y = batch
        return model(x), y

    def forward_eval(self, model, batch):
        """Return ``(model(x), y)``, ignoring any extra batch elements."""
        if len(batch) >= 3:
            x, y = batch[0], batch[1]
        else:
            x, y = batch
        return model(x), y

    def accumulate_metrics(
        self,
        batch,
        pred: torch.Tensor,
        target: torch.Tensor,
    ) -> tuple[torch.Tensor, int]:
        """Return the squared error summed over dim 0 and the sample count.

        ``batch`` is not used.

        Returns:
            ``(field_se, num_samples)`` where ``num_samples`` is
            ``pred.shape[0]``.

        Raises:
            ValueError: If ``pred`` and ``target`` shapes differ.
        """
        if pred.shape != target.shape:
            raise ValueError(
                f"Prediction shape {tuple(pred.shape)} does not match "
                f"target shape {tuple(target.shape)}."
            )
        squared = (pred - target) ** 2
        field_se = squared.sum(dim=0)
        num_samples = int(pred.shape[0])
        return field_se, num_samples
class ProfileAdapter(ModelAdapter):
    """Adapter for per-case profile (1D-conv) models.

    Datasets emit per-case ``(x, y, w, case_idx)`` items shaped ``[C, S]``,
    ``[O, S]``, ``[1, S]``, scalar. After default collation the batch is
    ``[B, C, S]``, ``[B, O, S]``, ``[B, 1, S]``, ``[B]``.

    Dataset construction is delegated to a case-side callable resolved from
    ``data.dataset_entrypoint`` (mirrors ``model.entrypoint``); the adapter
    itself knows nothing about specific cases.
    """

    family = "profile"

    def build_dataset(self, data_cfg: dict):
        """Build the dataset by calling ``data_cfg["dataset_entrypoint"]``.

        The entrypoint is a ``'<module>:<callable>'`` string; the callable is
        imported and invoked with ``data_cfg``.

        Raises:
            ValueError: If ``dataset_entrypoint`` is missing or has no ``':'``
                separator.
        """
        ep = data_cfg.get("dataset_entrypoint")
        if ep is None:
            raise ValueError(
                "Profile adapter requires data.dataset_entrypoint "
                "in the format '<module>:<callable>' where the callable "
                "accepts a data_cfg dict and returns a profile dataset."
            )
        module_name, sep, fn_name = str(ep).partition(":")
        if not sep:
            # Report the offending value instead of failing with a bare
            # tuple-unpacking error.
            raise ValueError(
                "data.dataset_entrypoint must have the format "
                f"'<module>:<callable>', got {ep!r}."
            )
        build = getattr(importlib.import_module(module_name), fn_name)
        return build(data_cfg)

    def dataset_info(self, dataset) -> dict:
        """Return channel counts and the station count of ``dataset``.

        The station count is the last dimension of the first item's input.
        """
        x, _, _, _ = dataset[0]
        return {
            "in_channels": dataset.in_features,
            "out_channels": dataset.out_features,
            "n_stations": int(x.shape[-1]),
        }

    def build_batch(self, raw_batch, device: torch.device):
        """Move an ``(x, y, w, cidx)`` batch to ``device``.

        The copy is requested as non-blocking only when ``device`` is CUDA.
        """
        non_blocking = device.type == "cuda"
        x, y, w, cidx = raw_batch
        return (
            x.to(device, non_blocking=non_blocking),
            y.to(device, non_blocking=non_blocking),
            w.to(device, non_blocking=non_blocking),
            cidx.to(device, non_blocking=non_blocking),
        )

    def forward_train(self, model, batch):
        """Return ``(model(x), y, w, cidx)`` for an ``(x, y, w, cidx)`` batch."""
        x, y, w, cidx = batch
        pred = model(x)
        return pred, y, w, cidx

    def forward_eval(self, model, batch):
        """Return ``(model(x), y)`` using the first two batch elements."""
        x, y = batch[0], batch[1]
        pred = model(x)
        return pred, y

    def accumulate_metrics(
        self,
        batch,
        pred: torch.Tensor,
        target: torch.Tensor,
    ) -> tuple[torch.Tensor, int]:
        """Return the squared error summed over dims 0 and 2, and a count.

        ``batch`` is not used.

        Returns:
            ``(field_se, num_samples)`` where ``num_samples`` is
            ``pred.shape[0] * pred.shape[2]``.

        Raises:
            ValueError: If the shapes differ or ``pred`` is not 3-D.
        """
        if pred.shape != target.shape:
            raise ValueError(
                f"Prediction shape {tuple(pred.shape)} does not match "
                f"target shape {tuple(target.shape)}."
            )
        if pred.ndim != 3:
            raise ValueError(
                f"Profile predictions must have shape [B, O, S], got {tuple(pred.shape)}."
            )
        squared = (pred - target) ** 2
        field_se = squared.sum(dim=(0, 2))
        num_samples = int(pred.shape[0] * pred.shape[2])
        return field_se, num_samples
# Maps each adapter family name to the class that implements it.
ADAPTER_REGISTRY: dict[str, type[ModelAdapter]] = {
    "grid": GridAdapter,
    "graph": GraphAdapter,
    "pointwise": PointwiseAdapter,
    "profile": ProfileAdapter,
}
def get_adapter(name: str) -> ModelAdapter:
    """Instantiate the adapter registered under ``name``.

    Raises:
        ValueError: If ``name`` is not a key of ``ADAPTER_REGISTRY``; the
            message lists the registered names in sorted order.
    """
    adapter_cls = ADAPTER_REGISTRY.get(name)
    if adapter_cls is None:
        available = ", ".join(sorted(ADAPTER_REGISTRY))
        raise ValueError(f"Unknown adapter '{name}'. Available adapters: {available}")
    return adapter_cls()