# Source code for adeptml.models

import torch
import torch.nn
from typing import Callable, Optional, List
from adeptml import configs

# Registry mapping an activation name to the callable applied to a tensor.
# MLP looks up ``config.hidden_activation`` and ``config.output_activation``
# in this table.
#
# The entries are created once, when this module is imported, and the same
# object is handed out on every lookup, so models built from the same name
# share one activation object.
ACTIVATIONS = {
    "identity": torch.nn.Identity(),
    "leakyrelu": torch.nn.LeakyReLU(),
    "relu": torch.nn.ReLU(),
    "elu": torch.nn.ELU(),
    "sigmoid": torch.nn.Sigmoid(),
    "tanh": torch.nn.Tanh(),
    # Plain function rather than a torch.nn.Module instance.
    "sin": torch.sin,
    "softplus": torch.nn.Softplus(),
    # "swish" is served by torch.nn.SiLU.
    "swish": torch.nn.SiLU(),
}


def _to_numpy(tensors):
    """Convert a list of tensors to numpy arrays. Returns an empty list for None/empty input."""
    if not tensors:
        return []
    return [t.detach().cpu().numpy() for t in tensors]


class MLP(torch.nn.Module):
    """
    Multilayer Perceptron (MLP) assembled from an ``MLPConfig``.

    The network is an input projection, followed by
    ``config.num_hidden_layers`` square hidden layers, followed by an output
    projection.

    Attributes
    ----------
    linear_in : torch.nn.Linear
        Maps ``config.num_input_dim`` features to ``config.num_hidden_dim``.
    layers : torch.nn.ModuleList
        Hidden ``Linear`` layers, each ``num_hidden_dim -> num_hidden_dim``.
    linear_out : torch.nn.Linear
        Maps ``config.num_hidden_dim`` features to ``config.num_output_dim``.
    nl1
        Hidden activation, looked up in ``ACTIVATIONS`` by
        ``config.hidden_activation``.
    nl2
        Output activation, looked up in ``ACTIVATIONS`` by
        ``config.output_activation``.

    Note
    ----
    The hidden activation is applied after each layer in ``layers`` only.
    The result of ``linear_in`` is passed to the first hidden layer without
    an activation in between.
    """

    def __init__(self, config: configs.MLPConfig):
        super().__init__()
        # ``layers`` is registered (empty) before ``linear_in`` and filled
        # afterwards. Keeping this order preserves both the parameter
        # ordering of the module and the order in which layers are
        # initialised.
        self.layers = torch.nn.ModuleList()
        self.linear_in = torch.nn.Linear(config.num_input_dim, config.num_hidden_dim)
        width = config.num_hidden_dim
        self.layers.extend(
            [torch.nn.Linear(width, width) for _ in range(config.num_hidden_layers)]
        )
        self.linear_out = torch.nn.Linear(width, config.num_output_dim)
        self.nl1 = ACTIVATIONS[config.hidden_activation]
        self.nl2 = ACTIVATIONS[config.output_activation]

    def forward(self, x):
        """
        Forward pass of the MLP model.

        :param torch.Tensor x: Input tensor.
        :return: Output tensor.
        :rtype: torch.Tensor
        """
        hidden = self.linear_in(x)
        for layer in self.layers:
            hidden = self.nl1(layer(hidden))
        return self.nl2(self.linear_out(hidden))
class Physics(torch.autograd.Function):
    """Custom autograd function wrapping a physics model with a full Jacobian.

    The backward pass materialises the full ``(batch, out, in)`` Jacobian and
    contracts it with the upstream gradient via a batched matrix-vector
    product.

    Use :class:`Physics_VJP` when a manual VJP is cheaper than the full
    Jacobian, or :class:`Physics_SplitVJP` when the physics solver already
    produces a pullback closure (e.g. via ``jax.vjp``).

    See Also
    --------
    Physics_VJP, Physics_SplitVJP
    """

    @staticmethod
    def forward(
        ctx,
        x: torch.Tensor,
        forward_fun: Callable,
        jacobian_fun: Callable,
        args: Optional[List[torch.Tensor]] = None,
    ):
        """
        Run the physics forward pass.

        The physics function is evaluated on numpy copies of ``x`` and
        ``args``; its result is converted to a tensor of ``x``'s dtype and
        placed on ``configs.DEVICE``.

        :param ctx: PyTorch autograd context.
        :param x: Input tensor ``(batch, in_dim)``.
        :param forward_fun: Physics function ``(x_np, *args_np) -> ndarray``.
        :param jacobian_fun: Jacobian function ``(x_np, *args_np) -> ndarray``
            with shape ``(batch, out_dim, in_dim)``.
        :param args: Extra positional arguments passed to ``forward_fun`` and
            ``jacobian_fun``.  Gradients are *not* computed w.r.t. these.
        :return: Output tensor ``(batch, out_dim)``.
        :rtype: torch.Tensor
        """
        # ``args`` may be None or empty; in both cases only ``x`` is saved.
        ctx.save_for_backward(x, *(args or ()))
        ctx.jacobian_fun = jacobian_fun

        out = forward_fun(x.detach().cpu().numpy(), *_to_numpy(args))
        return torch.tensor(out, dtype=x.dtype).to(configs.DEVICE)

    @staticmethod
    def backward(ctx, grad_output):
        """Compute VJP via full Jacobian: ``grad = grad_output @ J``.

        :param ctx: PyTorch autograd context populated by :meth:`forward`.
        :param grad_output: Upstream gradient ``(batch, out_dim)``.
        :return: ``(grad_x, None, None, None)``; ``grad_x`` is ``None`` when
            no gradient is required for ``x``.
        """
        if not ctx.needs_input_grad[0]:
            return None, None, None, None

        x, *args = ctx.saved_tensors
        x_np = x.detach().cpu().numpy()

        jac = ctx.jacobian_fun(x_np, *_to_numpy(args))
        # Build the Jacobian next to the upstream gradient so the product
        # never mixes devices.
        jac = torch.tensor(jac, dtype=grad_output.dtype).to(grad_output.device)
        jac = jac.reshape(x_np.shape[0], -1, x_np.shape[1])

        # (batch, 1, out) @ (batch, out, in) -> (batch, 1, in) -> (batch, in)
        grad_x = torch.matmul(grad_output.unsqueeze(1), jac).squeeze(1)

        # The gradient handed back for ``x`` has to live where ``x`` lives,
        # which is not necessarily ``configs.DEVICE``.
        return grad_x.to(x.device), None, None, None
class Physics_VJP(torch.autograd.Function):
    """Custom autograd function wrapping a physics model with a manual VJP.

    Unlike :class:`Physics`, the backward pass delegates to a user-supplied
    ``jacobian_fun(x, grad_output, *args) -> ndarray`` that returns the VJP
    directly, avoiding explicit Jacobian materialisation.

    See Also
    --------
    Physics, Physics_SplitVJP
    """

    @staticmethod
    def forward(
        ctx,
        x: torch.Tensor,
        forward_fun: Callable,
        jacobian_fun: Callable,
        args: Optional[List[torch.Tensor]] = None,
    ):
        """
        Run the physics forward pass.

        The physics function is evaluated on numpy copies of ``x`` and
        ``args``; its result is converted to a tensor of ``x``'s dtype and
        placed on ``configs.DEVICE``.

        :param ctx: PyTorch autograd context.
        :param x: Input tensor ``(batch, in_dim)``.
        :param forward_fun: Physics function ``(x_np, *args_np) -> ndarray``.
        :param jacobian_fun: VJP function
            ``(x_np, grad_output_np, *args_np) -> ndarray`` with shape
            ``(batch, in_dim)``.
        :param args: Extra positional arguments.  Gradients are *not*
            computed w.r.t. these.
        :return: Output tensor ``(batch, out_dim)``.
        :rtype: torch.Tensor
        """
        # ``args`` may be None or empty; in both cases only ``x`` is saved.
        ctx.save_for_backward(x, *(args or ()))
        ctx.jacobian_fun = jacobian_fun

        out = forward_fun(x.detach().cpu().numpy(), *_to_numpy(args))
        return torch.tensor(out, dtype=x.dtype).to(configs.DEVICE)

    @staticmethod
    def backward(ctx, grad_output):
        """Compute VJP by calling the user-supplied VJP function directly.

        :param ctx: PyTorch autograd context populated by :meth:`forward`.
        :param grad_output: Upstream gradient ``(batch, out_dim)``.
        :return: ``(grad_x, None, None, None)``; ``grad_x`` is ``None`` when
            no gradient is required for ``x``.
        """
        if not ctx.needs_input_grad[0]:
            return None, None, None, None

        x, *args = ctx.saved_tensors
        vjp = ctx.jacobian_fun(
            x.detach().cpu().numpy(),
            grad_output.detach().cpu().numpy(),
            *_to_numpy(args),
        )
        grad_x = torch.tensor(vjp, dtype=grad_output.dtype)

        # The gradient handed back for ``x`` has to live where ``x`` lives,
        # which is not necessarily ``configs.DEVICE``.
        return grad_x.to(x.device), None, None, None
class Physics_SplitVJP(torch.autograd.Function):
    """Custom autograd function for physics models that expose a pullback closure.

    This mode supports ``forward_func`` functions that return both the output
    *and* a pullback closure — the pattern produced by ``jax.vjp``:

    .. code-block:: python

        y, pullback_fn = forward_func(x, *args)
        # later, during backward:
        (dx, *_) = pullback_fn(grad_output)

    The pullback closure is stored in the autograd context during the forward
    pass and invoked during the backward pass.  No part of the physics forward
    is re-executed during backpropagation, making this the most
    memory-efficient mode when the solver is expensive and already caches its
    intermediates (e.g. via ``jax.checkpoint`` inside ``jax.vjp``).  For the
    same reason no input tensors are saved on the context: the backward pass
    only needs the closure and the device of ``x``.

    ``jacobian_func`` is not used in this mode and should be omitted from
    :class:`~adeptml.configs.PhysicsConfig`.

    Example
    -------
    Given the JAX physics module::

        from Funcs import physics  # physics(x)-> (y_normalized, pullback_fn)

    Configure and use as::

        from adeptml.configs import PhysicsConfig, HybridConfig
        from adeptml.ensemble import HybridModel

        phy_cfg = PhysicsConfig(
            forward_func=physics,
            use_split_vjp=True,
        )
        hybrid_cfg = HybridConfig(models={"nn": mlp_cfg, "physics": phy_cfg})
        model = HybridModel(hybrid_cfg)

        # During training the pullback is cached automatically; no extra code
        # is needed in the training loop.

    See Also
    --------
    Physics, Physics_VJP
    """

    @staticmethod
    def forward(
        ctx,
        x: torch.Tensor,
        forward_fun: Callable,
        args: Optional[List[torch.Tensor]] = None,
    ):
        """
        Run the split-VJP physics forward pass.

        Calls ``forward_fun(x_np, *args_np)`` which must return
        ``(output_np, pullback_fn)``.  The pullback is stored in ``ctx`` for
        use during :meth:`backward`.  The output is converted to a tensor of
        ``x``'s dtype and placed on ``configs.DEVICE``.

        :param ctx: PyTorch autograd context.
        :param x: Input tensor ``(batch, in_dim)``.
        :param forward_fun: Split-VJP physics function
            ``(x_np, *args_np) -> (ndarray, callable)``.
        :param args: Extra positional arguments (e.g. boundary conditions,
            initial conditions, simulation time).  Gradients are *not*
            computed w.r.t. these.
        :return: Output tensor ``(batch, out_dim)``.
        :rtype: torch.Tensor
        """
        # backward() never reads x or args, so they are deliberately not
        # saved on the context; only the device of x is remembered so the
        # gradient can be returned where x lives.
        ctx.input_device = x.device

        out_np, pullback_fn = forward_fun(x.detach().cpu().numpy(), *_to_numpy(args))
        ctx.pullback_fn = pullback_fn
        return torch.tensor(out_np, dtype=x.dtype).to(configs.DEVICE)

    @staticmethod
    def backward(ctx, grad_output):
        """Call the stored pullback closure with the upstream gradient.

        The cotangent for ``x`` is ``pullback_fn(grad_output)[0]``; cotangents
        for ``forward_fun`` and ``args`` are ``None``.

        :param ctx: PyTorch autograd context populated by :meth:`forward`.
        :param grad_output: Upstream gradient ``(batch, out_dim)``.
        :return: ``(dx, None, None)``; ``dx`` is ``None`` when no gradient is
            required for ``x``.
        """
        if not ctx.needs_input_grad[0]:
            return None, None, None

        cotangents = ctx.pullback_fn(grad_output.detach().cpu().numpy())
        # cotangents is a tuple ordered by the arguments of forward_fun;
        # index 0 corresponds to x.
        dx = torch.tensor(cotangents[0], dtype=grad_output.dtype)
        return dx.to(ctx.input_device), None, None