MACE MP-L

GPU Inference Profile

cuEquivariance · JAX · 3k atoms · 160k edges · float32

MACE Layer Architecture

[Diagram: one MACE layer. Spherical Harmonics, a Linear node embedding, and an MLP feed the Tensor Product; its output flows through Linear → Symmetric Contraction → Linear + Skip.]

2 layers · jax.grad for forces
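
A minimal sketch of the force-evaluation pattern this refers to. The `energy_fn` here is a hypothetical placeholder, not the real two-layer MACE forward pass:

```python
import jax
import jax.numpy as jnp

def energy_fn(params, positions, species):
    """Hypothetical stand-in for the MACE forward pass; returns a scalar energy."""
    del params, species            # placeholders so the sketch runs
    return jnp.sum(positions ** 2)

# Forces are the negative gradient of the energy w.r.t. positions.
# jax.grad builds the backward pass that dominates the profile below.
forces_fn = jax.jit(jax.grad(energy_fn, argnums=1))

positions = jnp.ones((3_000, 3))
forces = -forces_fn(None, positions, None)  # shape (3000, 3)
```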

Where Does the Time Go?

28 ms / inference step

includes forward + backward (grad for forces)
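
A sketch of how a per-step number like this is commonly measured in JAX. The `step` function is a trivial stand-in for the real energy-plus-forces computation, and the profile itself was collected with nsys:

```python
import time
import jax
import jax.numpy as jnp

@jax.jit
def step(positions):
    # Stand-in for energy + forces: forward pass plus the jax.grad backward pass.
    return -jax.grad(lambda x: jnp.sum(x ** 2))(positions)

positions = jnp.ones((3_000, 3))
step(positions).block_until_ready()  # compile once, outside the timing loop

n = 100
t0 = time.perf_counter()
for _ in range(n):
    out = step(positions)
out.block_until_ready()              # flush JAX's async dispatch queue
print(f"{(time.perf_counter() - t0) / n * 1e3:.2f} ms / inference step")
```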

[Pie chart: Tensor Products 59.6% · Linear Layers 22.1% · Data movement & other 16.5% · Symmetric Contraction and Spherical Harmonics <2% combined]

Component Breakdown

Component               Backend        Time      % of total
Tensor Products         Custom CUDA    16.5 ms   59.6%
Linear Layers (GEMM)    XLA / cuBLAS    6.1 ms   22.1%
Data movement & other   XLA fusions     4.6 ms   16.5%
Symmetric Contraction   Custom CUDA    265 µs     1.0%
Spherical Harmonics     Pure JAX       223 µs     0.8%

Tensor Products Dominate

The channelwise tensor product (message passing + its backward) is the single largest cost.

Kernel                         Time      % of total
layer_1 TP backward (jvp_T)    11.5 ms   41.8%
layer_1 TP forward (3 splits)   2.6 ms    9.6%
layer_0 TP forward (3 splits)   1.3 ms    4.9%
layer_0 TP backward (jvp_T)     1.1 ms    3.9%
Layer 1's backward TP alone accounts for 41.8% of total inference time. This layer carries higher-order irreps (0e+1o+2e), which makes its tensor product significantly more expensive.
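
To make the cost concrete, here is a toy pure-JAX version of the gather / couple / scatter pattern that a channelwise tensor product implements. Shapes, names, and the single-path simplification are illustrative; the real kernels fuse many (l1, l2, l3) paths:

```python
import jax
import jax.numpy as jnp

def channelwise_tp(h, sh, radial, cg, senders, receivers, num_nodes):
    """Toy channelwise tensor product for one (l1, l2, l3) path.

    h:        (N, C, 2*l1+1)            node features
    sh:       (E, 2*l2+1)               spherical harmonics per edge
    radial:   (E, C)                    per-edge radial weights (MLP output)
    cg:       (2*l1+1, 2*l2+1, 2*l3+1)  Clebsch-Gordan coefficients
    """
    h_src = h[senders]                                  # gather onto edges
    msg = jnp.einsum('eci,ej,ijk->eck', h_src, sh, cg)  # couple channelwise
    msg = radial[..., None] * msg                       # radial modulation
    # Scatter-sum messages to receivers; with 160k edges this edge-space
    # intermediate is what makes the TP (and its transpose) so expensive.
    return jax.ops.segment_sum(msg, receivers, num_segments=num_nodes)
```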

Spherical Harmonics & Symmetric Contraction

Both are negligible at this model size.

Spherical Harmonics

Forward 70 µs
Backward 153 µs
Total 223 µs

0.8% of total
Polynomials of the edge unit vectors (L = 0..3); XLA fuses them into simple element-wise kernels.
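
A sketch of why this is cheap: the low-L real spherical harmonics are just polynomials in the components of the unit vector (e3nn-style ordering assumed, normalization constants omitted):

```python
import jax.numpy as jnp

def sph_harm_l0_to_l2(u):
    """Real spherical harmonics up to L=2 as polynomials of a unit vector u,
    up to normalization. Element-wise math like this fuses into a handful of
    XLA kernels. (The profiled model goes up to L=3.)"""
    x, y, z = u[..., 0], u[..., 1], u[..., 2]
    return jnp.stack([
        jnp.ones_like(x),                                 # L=0
        y, z, x,                                          # L=1
        x * y, y * z, 3 * z**2 - 1, x * z, x**2 - y**2,   # L=2
    ], axis=-1)
```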

Symmetric Contraction

Forward 80 µs
Backward 185 µs
Total 265 µs

1.0% of total
Contracts higher-order features with learned weights per species.
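
As a rough illustration of the operation, a toy order-2 contraction. The cuEquivariance kernel and MACE's actual basis and weight layout differ, and all shapes here are hypothetical:

```python
import jax.numpy as jnp

def symmetric_contraction(feats, weights, basis, species):
    """Toy order-2 symmetric contraction with per-species weights.

    feats:   (N, C, D)     node features, D = flattened irrep dimension
    weights: (S, C, P)     learned weights: S species, P coupling paths
    basis:   (P, D, D, D)  fixed generalized Clebsch-Gordan coefficients
    species: (N,)          integer species index per node
    """
    w = weights[species]  # per-node weight lookup: (N, C, P)
    # Contract two copies of the features against the fixed basis and mix
    # the coupling paths with the learned per-species weights.
    return jnp.einsum('ncp,pijk,ncj,nck->nci', w, basis, feats, feats)
```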

Key Takeaways

60% Tensor Products are the bottleneck. Layer 1's backward pass alone is 42% due to higher-order irreps (0e+1o+2e).
22% Linear layers (IrrepsLinear) are the second-largest cost: standard GEMMs run at HIGHEST matmul precision (see the snippet after this list).
<2% Spherical Harmonics (0.8%) and Symmetric Contraction (1.0%) are negligible. Not worth optimizing at this scale.
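
For reference, the precision setting referred to above, as it appears in JAX (a sketch, not the model's actual linear layer):

```python
import jax
import jax.numpy as jnp

x = jnp.ones((4096, 512), dtype=jnp.float32)
w = jnp.ones((512, 512), dtype=jnp.float32)

# Precision.HIGHEST keeps float32 GEMMs in full fp32, whereas
# Precision.DEFAULT may allow faster TF32 tensor-core paths.
y = jnp.dot(x, w, precision=jax.lax.Precision.HIGHEST)
```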

MACE MP-L · 3k atoms · 160k edges · float32 · 3.5M params · profiled with nsys on a single GPU