Source code for training.models.meshgraphnet
"""Built-in MeshGraphNet model definition."""
from training import _require_pyg, import_physicsnemo_attr
from training.models import register_model
[docs]
def build(model_cfg: dict, dataset_info: dict):
    """Instantiate the PhysicsNeMo ``MeshGraphNet`` class for a dataset.

    The input/output sizes come from ``dataset_info``. The architecture
    hyperparameters come from ``model_cfg``, falling back to built-in
    defaults when a key is absent.

    Args:
        model_cfg: Mapping of optional hyperparameter overrides. It is read
            with ``.get`` for ``processor_size``, ``hidden_dim_processor``,
            ``hidden_dim_node_encoder``, ``num_layers_node_processor`` and
            ``num_layers_edge_processor``. Each value is coerced with
            ``int()``.
        dataset_info: Mapping that must provide ``in_channels``,
            ``edge_dim`` and ``out_channels``.

    Returns:
        The constructed model. It carries a ``_resolved_model_params``
        attribute holding a shallow copy of the keyword arguments passed to
        the constructor.

    Raises:
        KeyError: If ``dataset_info`` lacks one of the required keys.
        ValueError / TypeError: If a hyperparameter value in ``model_cfg``
            cannot be converted by ``int()`` (e.g. a non-numeric string or
            ``None``).
    """
    # Presumably fails fast when PyTorch Geometric is unavailable — the
    # helper is defined elsewhere; confirm its exact behavior there.
    _require_pyg()
    model_cls = import_physicsnemo_attr(
        "physicsnemo.models.meshgraphnet.meshgraphnet", "MeshGraphNet"
    )

    # Shape-defining arguments are taken verbatim from the dataset metadata.
    params = {
        "input_dim_nodes": dataset_info["in_channels"],
        "input_dim_edges": dataset_info["edge_dim"],
        "output_dim": dataset_info["out_channels"],
    }

    # Tunable hyperparameters, in the order they are added to ``params``:
    # a user override wins, otherwise the default listed here is used.
    tunable_defaults = (
        ("processor_size", 15),
        ("hidden_dim_processor", 128),
        ("hidden_dim_node_encoder", 128),
        ("num_layers_node_processor", 2),
        ("num_layers_edge_processor", 2),
    )
    for key, fallback in tunable_defaults:
        params[key] = int(model_cfg.get(key, fallback))

    network = model_cls(**params)
    # Keep a separate copy of the exact constructor kwargs on the model so
    # later edits to ``params`` (or to the copy) don't affect each other.
    network._resolved_model_params = params.copy()
    return network
# Register ``build`` as the factory for the "meshgraphnet" model name,
# paired with the "graph" adapter, at import time.
register_model("meshgraphnet", adapter="graph", build_fn=build)