Source code for training.models.fno
"""Built-in FNO model definition."""
from training import import_physicsnemo_attr
from training.models import register_model
[docs]
def build(model_cfg: dict, dataset_info: dict):
    """Build an FNO model from a model config and dataset metadata.

    Channel counts are taken from ``dataset_info``. Every other constructor
    argument is looked up in ``model_cfg`` (falling back to the defaults
    below) and coerced to plain Python ``int`` values before being passed on.

    Args:
        model_cfg: Mapping of optional hyperparameters. Recognized keys and
            their defaults: ``dimension`` (2), ``latent_channels`` (32),
            ``num_fno_layers`` (4), ``num_fno_modes`` (12), ``padding`` (5),
            ``decoder_layers`` (1), ``decoder_layer_size`` (32).
            ``num_fno_modes`` may be a single value or an iterable of values
            (presumably one per spatial dimension -- confirm against the FNO
            constructor's documentation).
        dataset_info: Mapping that must contain ``in_channels`` and
            ``out_channels``; these are forwarded unchanged.

    Returns:
        The constructed model. A copy of the keyword arguments used to build
        it is attached as ``model._resolved_model_params``.

    Raises:
        KeyError: If ``dataset_info`` lacks ``in_channels`` or
            ``out_channels``.
        ValueError, TypeError: If a hyperparameter cannot be converted to
            ``int``.
    """
    fno_cls = import_physicsnemo_attr("physicsnemo.models.fno.fno", "FNO")

    # ``num_fno_modes`` is the one setting that may be a scalar or a
    # collection, so it cannot go through a bare ``int(...)`` like the others.
    # Normalize it anyway: numeric strings, floats, NumPy integers and
    # config-container lists all become a plain ``int`` or ``list[int]``.
    modes = model_cfg.get("num_fno_modes", 12)
    if isinstance(modes, (str, bytes)) or not hasattr(modes, "__iter__"):
        modes = int(modes)
    else:
        modes = [int(m) for m in modes]

    resolved = {
        "in_channels": dataset_info["in_channels"],
        "out_channels": dataset_info["out_channels"],
        "dimension": int(model_cfg.get("dimension", 2)),
        "latent_channels": int(model_cfg.get("latent_channels", 32)),
        "num_fno_layers": int(model_cfg.get("num_fno_layers", 4)),
        "num_fno_modes": modes,
        "padding": int(model_cfg.get("padding", 5)),
        "decoder_layers": int(model_cfg.get("decoder_layers", 1)),
        "decoder_layer_size": int(model_cfg.get("decoder_layer_size", 32)),
    }

    model = fno_cls(**resolved)
    # Keep an independent copy of the exact constructor arguments on the
    # instance so later mutation of ``resolved`` cannot alter the record.
    model._resolved_model_params = dict(resolved)
    return model
# Module-level side effect: importing this module passes ``build`` to
# ``register_model`` under the key "fno", tagged with the "grid" adapter.
register_model(
    "fno",
    build_fn=build,
    adapter="grid",
)