# Source code for adeptml.ensemble

"""Ensemble module: HybridModel combining neural networks and physics solvers."""

import torch
import torch.nn
from adeptml import configs, models


class HybridModel(torch.nn.Module):
    """Torch module for serial hybrid physics-informed models.

    Combines neural network modules (MLPs, custom ``torch.nn.Module``
    subclasses) with non-differentiable physics models in a single
    differentiable computation graph. Models are executed in the insertion
    order of ``config.models``.

    Parameters
    ----------
    config : HybridConfig
        Configuration object specifying all sub-models and optional input
        routing. See :class:`~adeptml.configs.HybridConfig` for details.

    Examples
    --------
    Serial NN → physics pipeline::

        cfg = HybridConfig(models={"nn": mlp_cfg, "physics": phy_cfg})
        model = HybridModel(cfg)
        out = model(x, phy_args=[arg1, arg2])

    Custom input routing — pass original input to the physics model
    regardless of what the NN outputs::

        cfg = HybridConfig(
            models={"nn": mlp_cfg, "physics": phy_cfg},
            model_inputs={"physics": {"Input": None, "nn": None}},
        )
    """

    def __init__(self, config: configs.HybridConfig):
        super().__init__()
        # Trainable sub-models; registered in a ModuleDict so their
        # parameters are visible to the enclosing module.
        self.models_nn = torch.nn.ModuleDict()
        # Physics sub-models: name -> autograd ``Function.apply`` callable.
        self.models_physics = {}
        self.config = config

        # Resolve the apply function for each physics model once at init time.
        for model_name, model_cfg in config.models.items():
            if isinstance(model_cfg, torch.nn.Module):
                self.models_nn[model_name] = model_cfg.to(configs.DEVICE)
            elif isinstance(model_cfg, configs.PhysicsConfig):
                if model_cfg.use_split_vjp:
                    self.models_physics[model_name] = models.Physics_SplitVJP.apply
                elif model_cfg.use_vjp:
                    self.models_physics[model_name] = models.Physics_VJP.apply
                else:
                    self.models_physics[model_name] = models.Physics.apply
            elif isinstance(model_cfg, configs.MLPConfig):
                self.models_nn[model_name] = models.MLP(model_cfg).to(configs.DEVICE)
            # Any other entry type is not registered here; forward() then
            # passes its input through unchanged.

        # Names of intermediate outputs that some model consumes via custom
        # input routing. Computed once here, not inside forward().
        self.to_save = []
        if config.model_inputs:
            self.to_save = list(
                {src for routing in config.model_inputs.values() for src in routing.keys()}
            )

    def forward(self, x, phy_args=None):
        """Run inference on the hybrid model.

        :param torch.Tensor x: Input tensor ``(batch, in_dim)``.
        :param phy_args: Extra positional arguments forwarded to physics
            sub-models (e.g. ``[arg1, arg2]``). May be a list of tensors or
            ``None``.
        :return: Output of the final model in the pipeline (``x`` itself if
            ``config.models`` is empty).
        :rtype: torch.Tensor
        :raises KeyError: If a routed input source is neither ``"Input"`` nor
            a model that has already run.
        """
        interim_data = {"Input": x}
        current_input = x
        out = x  # default if config.models is empty

        routing_table = self.config.model_inputs
        for model_name, model_cfg in self.config.models.items():
            # Custom routing replaces the default "previous model's output".
            if routing_table and model_name in routing_table:
                current_input = self._gather_inputs(
                    routing_table[model_name], interim_data, model_name
                )

            out = self._run_model(model_name, model_cfg, current_input, phy_args)

            # Cache only the outputs that a later model routes from.
            if model_name in self.to_save:
                interim_data[model_name] = out
            current_input = out

        return out

    @staticmethod
    def _select_columns(src, dims):
        """Return ``src`` restricted to columns ``dims``; ``None`` keeps all columns."""
        # Explicit None check: a truthiness test would raise for tensor/array
        # indices and would treat column index 0 as "all columns".
        if dims is None:
            return src
        # An empty Python sequence historically meant "all columns"; keep that.
        if isinstance(dims, (list, tuple, range)) and len(dims) == 0:
            return src
        return src[:, dims]

    @staticmethod
    def _gather_inputs(routing, interim_data, model_name):
        """Horizontally stack the routed (optionally column-sliced) sources for one model."""
        pieces = []
        for src_name, dims in routing.items():
            try:
                src = interim_data[src_name]
            except KeyError:
                raise KeyError(
                    f"Input source {src_name!r} for model {model_name!r} is not "
                    "available; it must be 'Input' or a model that runs earlier "
                    "in config.models."
                ) from None
            pieces.append(HybridModel._select_columns(src, dims))
        return torch.hstack(pieces)

    def _run_model(self, model_name, model_cfg, current_input, phy_args):
        """Execute a single sub-model on ``current_input`` and return its output."""
        if model_name in self.models_nn:
            return self.models_nn[model_name](current_input)

        if model_name in self.models_physics:
            apply_fn = self.models_physics[model_name]
            if isinstance(model_cfg, configs.PhysicsConfig) and model_cfg.use_split_vjp:
                # Split-VJP: forward_func already bakes in the pullback; no jacobian_func.
                return apply_fn(current_input, model_cfg.forward_func, phy_args)
            return apply_fn(
                current_input,
                model_cfg.forward_func,
                model_cfg.jacobian_func,
                phy_args,
            )

        # Entry was not registered in __init__: pass the input through unchanged.
        return current_input