PR #1450

Status: open

Record: TMA Megakernel + Triple Loop + Parallel Residuals — val_bpb 1.08480

by andrewbaggio1
val_bpb: 1.0848
Architecture: Transformer
Optimizer: Muon
Artifact Size: 15,750,593

Training Techniques

Quantization
GPTQ (bits: 6, scope: weight matrices)
GPTQ (bits: 8, scope: embeddings)
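GPTQ itself compensates rounding error column by column using second-order (Hessian) information; the sketch below shows only the plain symmetric k-bit round-to-nearest grid that GPTQ rounds onto. All values are toy illustrations, not the submission's quantizer.

```python
def fake_quant(w, bits=6):
    # Symmetric k-bit grid: qmax levels on each side of zero.
    # This is the rounding target GPTQ uses; GPTQ's actual algorithm
    # additionally redistributes each weight's rounding error onto
    # not-yet-quantized weights. Assumes w has at least one nonzero entry.
    qmax = 2 ** (bits - 1) - 1
    scale = max(abs(v) for v in w) / qmax
    return [round(v / scale) * scale for v in w]
```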
Architecture
depth recurrence
Layers 4-5 are looped multiple times to create virtual layers from fewer physical layers.
parameters: {"layers":[4,5],"loops":3,"virtual_layers":17}
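A minimal sketch of the depth-recurrence loop in pure Python. Only the looped indices (4, 5), the loop count (3), and the 17-virtual-layer total come from the record; the layer stubs and forward function are illustrative.

```python
NUM_LOOPS = 3
LOOPED = {4, 5}  # physical layer indices that are re-applied

def forward(x, layers):
    # Re-running the looped layers turns 13 physical layers into
    # 13 + 2 * (NUM_LOOPS - 1) = 17 virtual layers.
    for i, layer in enumerate(layers):
        reps = NUM_LOOPS if i in LOOPED else 1
        for _ in range(reps):
            x = layer(x)
    return x

# Toy check with recording stubs in place of real transformer blocks.
calls = []
layers = [lambda x, i=i: calls.append(i) or x for i in range(13)]
forward(0, layers)
assert len(calls) == 17
```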
parallel residuals
Attention and MLP branches read from the same pre-residual input and are summed in parallel.
parameters: {"layers":[7,8,9,10]}
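The difference from a standard pre-norm block can be sketched in a few lines (layer norms omitted for brevity; the toy scalar branches stand in for real attention and MLP modules):

```python
def sequential_block(x, attn, mlp):
    # Standard residual block: the MLP sees the attention output.
    x = x + attn(x)
    x = x + mlp(x)
    return x

def parallel_block(x, attn, mlp):
    # GPT-J-style parallel residual: both branches read the same
    # pre-residual input and their outputs are summed.
    return x + attn(x) + mlp(x)
```

The parallel form lets the two branch matmuls run concurrently, at the cost of the MLP no longer conditioning on the attention output within the same block.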
LeakyReLU
Uses leaky ReLU squared activation in the MLP.
parameters: {"slope":0.5}
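A sketch of the activation, assuming the straightforward "square of leaky ReLU" form; only the 0.5 slope comes from the record, and whether the negative branch keeps its sign is an assumption.

```python
def leaky_relu_sq(x, slope=0.5):
    # Leaky ReLU followed by squaring: positives become x**2,
    # negatives become (slope * x)**2. A sign-preserving variant
    # (y * abs(y)) would be an equally plausible reading.
    y = x if x > 0 else slope * x
    return y * y
```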
weight tying
Tied input and output embeddings.
parameters: null
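Weight tying means the unembedding projection reuses the input embedding matrix (logits = h · Eᵀ), so one parameter tensor serves both roles. A toy sketch with made-up values:

```python
E = [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]  # toy: vocab=3, d_model=2

def embed(token_id):
    return E[token_id]

def logits(h):
    # Unembed with the *same* matrix E: each row of E scores one token.
    return [sum(hi * ei for hi, ei in zip(h, row)) for row in E]
```

With an artifact-size-constrained submission like this one, tying also roughly halves the embedding parameters that must be stored and compressed.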
Partial RoPE
Applies rotary position embeddings to only part of the head dimensions.
parameters: {"dimensions":16,"total_dimensions":64}
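A sketch of partial RoPE for one query/key vector: only the first 16 of 64 head dimensions are rotated and the rest pass through unchanged. The dimension counts come from the record; the pair layout and frequency base follow common RoPE practice, not necessarily this submission's exact layout.

```python
import math

def partial_rope(q, pos, rot_dims=16, base=10000.0):
    # Rotate adjacent pairs within the first rot_dims dimensions by a
    # position-dependent angle; dimensions rot_dims..len(q) are untouched.
    out = list(q)
    for i in range(0, rot_dims, 2):
        theta = pos / (base ** (i / rot_dims))
        c, s = math.cos(theta), math.sin(theta)
        x, y = q[i], q[i + 1]
        out[i] = x * c - y * s
        out[i + 1] = x * s + y * c
    return out
```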
XSA
Uses XSA attention across all layers.
parameters: null
U-Net skip connections
Uses gated skip connections in a U-Net-like pattern.
parameters: null
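One way to read "gated skip connections in a U-Net-like pattern": first-half layer outputs are pushed onto a stack, then popped and mixed into the second half, pairing outer layers with outer layers. The scalar-gate form below is an assumption for illustration; the record does not specify the gating mechanism.

```python
def unet_forward(x, layers, gates):
    # First half: run layers and save outputs as skips.
    # Second half: pop skips in reverse (U-Net nesting) and add them,
    # scaled by a learned per-skip gate (here just a float).
    half = len(layers) // 2
    skips = []
    for i, layer in enumerate(layers):
        if i >= half:
            x = x + gates[i - half] * skips.pop()
        x = layer(x)
        if i < half:
            skips.append(x)
    return x
```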
Weight Averaging
EMA
parameters: {"decay":0.997}
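The EMA update itself is one line per parameter; the 0.997 decay comes from the record, and the flat-list representation of weights is a simplification.

```python
def ema_update(ema_w, w, decay=0.997):
    # Exponential moving average of weights: ema <- decay*ema + (1-decay)*w.
    # Applied after each optimizer step; the EMA copy is what gets evaluated.
    return [decay * e + (1.0 - decay) * p for e, p in zip(ema_w, w)]
```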
Evaluation
sliding window eval
parameters: {"stride":64}
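Sliding-window evaluation scores every token exactly once while giving each scored token near-full left context: each window predicts only its final `stride` tokens, with the rest serving as context. The stride of 64 matches the record's setting; the window size below is illustrative.

```python
def sliding_eval_spans(n_tokens, window=128, stride=64):
    # Returns (ctx_start, score_start, end) spans: loss is computed only
    # on tokens in [score_start, end), conditioned on [ctx_start, end).
    spans = []
    score_start = 0
    while score_start < n_tokens:
        end = min(score_start + stride, n_tokens)
        ctx_start = max(0, end - window)
        spans.append((ctx_start, score_start, end))
        score_start = end
    return spans
```

This costs roughly window/stride forward passes per token compared to disjoint-chunk evaluation, in exchange for a lower (and less chunk-boundary-sensitive) measured bpb.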
Compression
brotli
level: 11
Optimizer
Muon
weight_decay: null
momentum: null
other_params: null

Novel Contributions

  • Triton TMA fused MLP forward megakernel for Hopper GPUs
  • Avoids materializing the 2048-wide MLP intermediate activation
  • Reported +10.5% training throughput improvement
  • Triple depth recurrence with NUM_LOOPS=3
  • GPT-J style parallel residuals in layers 7-10
  • Combined frontier techniques to reach val_bpb 1.08480
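The fusion idea behind the megakernel can be illustrated without Triton or TMA: produce the hidden activation one tile at a time and accumulate it straight into the second matmul, so the full 2048-wide intermediate row never exists in memory. Everything below (shapes, tiling, function names) is a conceptual pure-Python sketch, not the submission's kernel.

```python
def fused_mlp_row(x, W1, W2, act, tile=8):
    # Unfused: h = act(x @ W1)  (materializes the wide hidden row),
    #          y = h @ W2.
    # Fused: each hidden element is computed and immediately folded
    # into the output accumulator, tile by tile.
    d_out = len(W2[0])
    hidden_dim = len(W1[0])
    out = [0.0] * d_out
    for h0 in range(0, hidden_dim, tile):
        for h in range(h0, min(h0 + tile, hidden_dim)):
            a = act(sum(x[i] * W1[i][h] for i in range(len(x))))
            for j in range(d_out):
                out[j] += a * W2[h][j]
    return out
```

On a real GPU the win comes from keeping the hidden tile in registers/shared memory and using TMA bulk copies for the weight tiles, rather than round-tripping the 2048-wide activation through global memory between two separate kernels.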