Source code for training.models.mlp
"""Built-in FullyConnected MLP model definition.
Supports optional inter-layer dropout for improved regularisation in
deeper networks. Dropout is applied between hidden layers during
training and disabled at eval time, so no extra state needs to be
persisted in the checkpoint.
"""
import torch
from training import import_physicsnemo_attr
from training.models import register_model
class _InterLayerDropout(torch.nn.Module):
"""Wraps a FullyConnected model, adding dropout between hidden layers.
Unlike the previous input-only ``_DropoutWrapper``, this applies
dropout *after each hidden layer's activation*, which is the
standard placement for MLP regularisation.
The saved checkpoint contains only the inner model so that
evaluation works with the standard ``Module.from_checkpoint`` path.
"""
def __init__(self, model: torch.nn.Module, dropout: float):
super().__init__()
self.model = model
self.p = dropout
n_layers = len(model.layers)
self.drops = torch.nn.ModuleList([torch.nn.Dropout(dropout) for _ in range(n_layers)])
def forward(self, x):
x_skip = None
for i, layer in enumerate(self.model.layers):
x = layer(x)
if self.training:
x = self.drops[i](x)
if self.model.skip_connections and i % 2 == 0:
if x_skip is not None:
x, x_skip = x + x_skip, x
else:
x_skip = x
return self.model.final_layer(x)
# Delegate PhysicsNeMo serialization to the inner model.
def save(self, path): # noqa: D102
return self.model.save(path)
def state_dict(self, *args, **kwargs): # noqa: D102
return self.model.state_dict(*args, **kwargs)
def load_state_dict(self, *args, **kwargs): # noqa: D102
return self.model.load_state_dict(*args, **kwargs)
[docs]
def build(model_cfg: dict, dataset_info: dict):
fc_cls = import_physicsnemo_attr("physicsnemo.models.mlp.fully_connected", "FullyConnected")
resolved = {
"in_features": dataset_info["in_features"],
"out_features": dataset_info["out_features"],
"layer_size": int(model_cfg.get("layer_size", 128)),
"num_layers": int(model_cfg.get("num_layers", 6)),
"activation_fn": model_cfg.get("activation_fn", "silu"),
"skip_connections": bool(model_cfg.get("skip_connections", True)),
"adaptive_activations": bool(model_cfg.get("adaptive_activations", False)),
"weight_norm": bool(model_cfg.get("weight_norm", False)),
}
model = fc_cls(**resolved)
dropout = float(model_cfg.get("dropout", 0.0))
if dropout > 0:
model = _InterLayerDropout(model, dropout)
resolved["dropout"] = dropout
model._resolved_model_params = dict(resolved)
return model
register_model("mlp", build_fn=build, adapter="pointwise")