cuEquivariance · JAX · 3k atoms · 160k edges · float32
2 layers · forces via jax.grad · timings include forward + backward
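The force path can be sketched in plain JAX. `toy_energy` below is a hypothetical stand-in for the actual MACE model (a simple pairwise 1/r sum), but the `jax.grad` pattern is the same one that produces the backward-pass kernels profiled below.

```python
import jax
import jax.numpy as jnp

def toy_energy(positions):
    """Hypothetical pairwise 1/r energy; stands in for the MACE model."""
    n = positions.shape[0]
    diff = positions[:, None, :] - positions[None, :, :]
    # Add the identity on the diagonal so sqrt stays differentiable at r=0.
    r = jnp.sqrt(jnp.sum(diff**2, axis=-1) + jnp.eye(n))
    mask = ~jnp.eye(n, dtype=bool)
    return jnp.sum(jnp.where(mask, 1.0 / r, 0.0))

positions = jnp.arange(4.0)[:, None] * jnp.ones((1, 3))  # 4 atoms on a line
forces = -jax.grad(toy_energy)(positions)  # jax.grad builds the backward pass
```

Because the energy depends only on interatomic differences, the forces sum to zero, a quick sanity check on the gradient.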
The channelwise tensor product (message passing + its backward) is the single largest cost.
| Kernel | Time | % total |
|---|---|---|
| layer_1 TP backward (jvp_T) | 11.5 ms | 41.8% |
| layer_1 TP forward (3 splits) | 2.6 ms | 9.6% |
| layer_0 TP forward (3 splits) | 1.3 ms | 4.9% |
| layer_0 TP backward (jvp_T) | 1.1 ms | 3.9% |
Layer 1 operates on higher-order node features (0e+1o+2e), making its TP significantly more expensive than layer 0's.
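A minimal sketch of what a channelwise TP in message passing computes: gather sender features per edge, multiply channelwise with edge attributes, scatter-add into receivers. The shapes and function names here are illustrative assumptions, not cuEquivariance's actual kernel interface; the `jax.vjp` call is what generates the transposed backward ("jvp_T") kernels in the table.

```python
import jax
import jax.numpy as jnp

def channelwise_tp(node_feats, edge_attrs, senders, receivers, num_nodes):
    # Gather sender features per edge, multiply channelwise, scatter-add
    # into receiver nodes.
    messages = node_feats[senders] * edge_attrs                 # (E, C)
    return jax.ops.segment_sum(messages, receivers, num_nodes)  # (N, C)

N, E, C = 8, 32, 16  # toy sizes; the benchmark has N=3k, E=160k
key0, key1, key2, key3 = jax.random.split(jax.random.PRNGKey(0), 4)
feats = jax.random.normal(key0, (N, C))
attrs = jax.random.normal(key1, (E, C))
senders = jax.random.randint(key2, (E,), 0, N)
receivers = jax.random.randint(key3, (E,), 0, N)

out, vjp_fn = jax.vjp(
    lambda f: channelwise_tp(f, attrs, senders, receivers, N), feats)
(grad_feats,) = vjp_fn(jnp.ones_like(out))  # the backward ("jvp_T") pass
```

The backward pass must re-gather and re-scatter over all edges for each input, which is consistent with it dominating the forward in the table above.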
Spherical harmonics and the symmetric contraction are both negligible at this model size.

Spherical harmonics · 0.8% of total
Polynomials of unit vectors (L=0..3); XLA fuses them into simple element-wise kernels.

| Pass | Time |
|---|---|
| Forward | 70 µs |
| Backward | 153 µs |
| Total | 223 µs |
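A sketch of that polynomial structure for the first few orders; normalization constants are omitted and the component ordering is illustrative:

```python
import jax.numpy as jnp

def sh_l0_to_l2(u):
    # Unnormalized real harmonics of a unit vector u: pure polynomials
    # in x, y, z, so XLA can fuse them into element-wise kernels.
    x, y, z = u[..., 0], u[..., 1], u[..., 2]
    l0 = jnp.ones_like(x)                                  # L=0: constant
    l1 = jnp.stack([x, y, z], axis=-1)                     # L=1: linear
    l2 = jnp.stack([x * y, y * z, 3.0 * z**2 - 1.0,
                    x * z, x**2 - y**2], axis=-1)          # L=2: quadratic
    return l0, l1, l2

u = jnp.array([0.0, 0.0, 1.0])  # unit vector along z
l0, l1, l2 = sh_l0_to_l2(u)
```

No matmuls, no gathers — just element-wise arithmetic per edge, which is why both passes sit in the microsecond range.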
Symmetric contraction · 1.0% of total
Contracts higher-order features (0e+1o+2e) with learned weights per species; matmuls run at HIGHEST precision.

| Pass | Time |
|---|---|
| Forward | 80 µs |
| Backward | 185 µs |
| Total | 265 µs |
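A hedged sketch of the per-species contraction pattern with the precision note applied; the einsum index layout is an assumption for illustration, not cuEquivariance's actual symmetric-contraction kernel.

```python
import jax
import jax.numpy as jnp

def species_contraction(feats, weights, species):
    # feats:   (N, K, C) higher-order features per atom
    # weights: (S, K, C) learned weights per species
    # species: (N,) integer species index in [0, S)
    w = weights[species]  # select each atom's weight block -> (N, K, C)
    # Contract over the K axis at HIGHEST matmul precision.
    return jnp.einsum('nkc,nkc->nc', feats, w,
                      precision=jax.lax.Precision.HIGHEST)

key_f, key_w = jax.random.split(jax.random.PRNGKey(0))
feats = jax.random.normal(key_f, (5, 4, 8))
weights = jax.random.normal(key_w, (3, 4, 8))
species = jnp.array([0, 1, 2, 0, 1])
out = species_contraction(feats, weights, species)
```

`jax.lax.Precision.HIGHEST` forces full-precision accumulation on tensor-core hardware, at a cost that is evidently negligible here.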
MACE MP-L · 3k atoms · 160k edges · float32 · 3.5M params · profiled with nsys on a single GPU