Source code for training.models.afno
"""Built-in AFNO model definition."""
from training import import_physicsnemo_attr
from training.models import register_model
[docs]
def build(model_cfg: dict, dataset_info: dict):
    """Construct an AFNO model from a model config and dataset description.

    The AFNO class is looked up through ``import_physicsnemo_attr`` and
    instantiated with keyword arguments assembled from the two inputs.

    Args:
        model_cfg: Mapping of optional hyperparameters. Recognised keys and
            the defaults used when a key is absent:
            ``patch_size`` ([16, 16]), ``embed_dim`` (256), ``depth`` (4),
            ``num_blocks`` (16), ``sparsity_threshold`` (0.01),
            ``hard_thresholding_fraction`` (1.0).
        dataset_info: Mapping that must provide ``spatial_shape``,
            ``in_channels`` and ``out_channels``.

    Returns:
        The instantiated model. The keyword arguments it was built with are
        also stored on it as ``_resolved_model_params`` (a shallow copy).

    Raises:
        KeyError: If ``dataset_info`` lacks one of the required keys.
    """
    afno_type = import_physicsnemo_attr("physicsnemo.models.afno.afno", "AFNO")

    # Arguments dictated by the dataset come first; they have no defaults.
    params = {
        "inp_shape": list(dataset_info["spatial_shape"]),
        "in_channels": dataset_info["in_channels"],
        "out_channels": dataset_info["out_channels"],
    }

    # Tunable hyperparameters: take the configured value or fall back to the
    # default, then coerce to the expected scalar/list type.
    params["patch_size"] = list(model_cfg.get("patch_size", [16, 16]))
    scalar_options = (
        ("embed_dim", int, 256),
        ("depth", int, 4),
        ("num_blocks", int, 16),
        ("sparsity_threshold", float, 0.01),
        ("hard_thresholding_fraction", float, 1.0),
    )
    for option, coerce, fallback in scalar_options:
        params[option] = coerce(model_cfg.get(option, fallback))

    network = afno_type(**params)
    # Keep a record of the exact constructor arguments on the instance.
    network._resolved_model_params = params.copy()
    return network
# Register `build` under the model name "afno", paired with the "grid" adapter.
# This runs as a side effect of the module being imported.
register_model("afno", build_fn=build, adapter="grid")