PR #1202 (open)

Record: Unified Attention + FA3 + Legal TTT (val_bpb=1.1412, 3-seed)
by VirajDeshwal

val_bpb: 1.1412
Architecture: Transformer
Optimizer: Parallel Muon
Artifact Size: ~15.97 MB

Training Techniques

Architecture
Unified Attention
Replaces separate Q/K/V projections with a single unified projection whose output splits into functional bands.
parameters: {"layers":11,"dimension":528,"heads":4}
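A minimal numpy sketch of the unified projection, assuming the fused weight is laid out as contiguous [Q | K | V] bands (the PR's exact band layout may differ):

```python
import numpy as np

d_model = 528                                   # from the PR's parameters
rng = np.random.default_rng(0)
x = rng.standard_normal((16, d_model))          # (seq_len, d_model)
W_unified = 0.02 * rng.standard_normal((d_model, 3 * d_model))

qkv = x @ W_unified                             # one matmul instead of three
q, k, v = np.split(qkv, 3, axis=-1)             # split into functional bands

# Each band is equivalent to a separate projection with that weight slice.
assert np.allclose(q, x @ W_unified[:, :d_model])
assert q.shape == k.shape == v.shape == (16, d_model)
```

Fusing the three projections into one matmul trades three small GEMMs for a single larger, better-utilized one.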
MLP3x
Uses a 3x MLP expansion in the model.
parameters: {"multiplier":3}
LeakyReLU
Uses a squared LeakyReLU activation in the MLP.
parameters: {"squared":true,"alpha":0.5}
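The two MLP entries above combine as follows; this sketch takes "squared" literally (activation squared after the leak, sign not preserved), which is one plausible reading of the stated parameters:

```python
import numpy as np

def leaky_relu_squared(x, alpha=0.5):
    # Literal reading of "LeakyReLU squared" with the PR's alpha=0.5;
    # whether the sign is kept after squaring is not specified, so this
    # sketch squares the activation directly.
    return np.where(x >= 0.0, x, alpha * x) ** 2

def mlp3x(x, W_in, W_out):
    # 3x expansion per the MLP3x entry: W_in maps d -> 3d, W_out maps 3d -> d.
    return leaky_relu_squared(x @ W_in) @ W_out

d = 528
rng = np.random.default_rng(0)
x = rng.standard_normal((4, d))
W_in = 0.02 * rng.standard_normal((d, 3 * d))
W_out = 0.02 * rng.standard_normal((3 * d, d))
assert mlp3x(x, W_in, W_out).shape == (4, d)
```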
SmearGate
Position-mixing gate with zero-init sigmoid.
parameters: null
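A hypothetical sketch of the gate: each position mixes in the previous position's activation through a learned sigmoid, whose weights start at zero so training begins from sigmoid(0) = 0.5 mixing. The exact wiring of the gate in the PR is not specified here:

```python
import numpy as np

def smear_gate(x, w_gate):
    # w_gate is zero-initialised, so the gate starts at sigmoid(0) = 0.5.
    gate = 1.0 / (1.0 + np.exp(-(x @ w_gate)))   # (seq_len, 1)
    prev = np.roll(x, 1, axis=0)
    prev[0] = 0.0                                # nothing before position 0
    return x + gate * prev

x = np.arange(12, dtype=float).reshape(4, 3)
y = smear_gate(x, np.zeros((3, 1)))              # zero-init gate
assert np.allclose(y[0], x[0])                   # first token unchanged
assert np.allclose(y[1], x[1] + 0.5 * x[0])      # 0.5 mixing at init
```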
VE128
Value embedding applied on later layers.
parameters: {"layers":[9,10]}
U-Net skip connections
Encoder-decoder style skip connections between layers.
parameters: null
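One common way to wire encoder-decoder skips over a plain layer stack, sketched below; the split point for an odd depth like 11 is an assumption:

```python
def unet_forward(x, layers):
    # Outputs of the first half are pushed on a stack and added back in
    # reverse order in the second half (symmetric pushes and pops).
    n = len(layers)
    skips = []
    for i, layer in enumerate(layers):
        if i >= n - n // 2:
            x = x + skips.pop()                  # decoder half: add skip
        x = layer(x)
        if i < n // 2:
            skips.append(x)                      # encoder half: save skip
    return x

# With identity layers at depth 4, both saved skips equal the input,
# so the output is x + 2x = 3x.
assert unet_forward(1.0, [lambda v: v] * 4) == 3.0
```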
Regularization
LN scale
Per-layer LayerNorm scale decayed as 1/sqrt(layer+1).
parameters: {"scale":"1/sqrt(layer+1)"}
Weight Averaging
EMA + Tight SWA
parameters: {"ema_decay":0.997,"swa_every":50}
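A sketch of the two averages on a scalar weight trace, using the PR's ema_decay and swa_every; taking "tight SWA" as a plain mean of every 50th checkpoint is an assumption:

```python
def weight_average(history, ema_decay=0.997, swa_every=50):
    # history: per-step weight values. Returns (EMA, SWA) over the trace.
    ema = history[0]
    swa_points = []
    for step, w in enumerate(history):
        ema = ema_decay * ema + (1 - ema_decay) * w
        if step % swa_every == 0:
            swa_points.append(w)                 # sampled every swa_every steps
    return ema, sum(swa_points) / len(swa_points)

# A constant trace averages to itself under both schemes.
ema, swa = weight_average([2.0] * 101)
assert abs(ema - 2.0) < 1e-9 and abs(swa - 2.0) < 1e-9
```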
Quantization
GPTQ-lite
bits: 6
scope: weights
int8
bits: 8
scope: embeddings
STE QAT
bits: 6
scope: all
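The 6-bit STE QAT entry can be sketched as symmetric per-tensor fake quantization; per-tensor (rather than per-channel) scaling is an assumption:

```python
import numpy as np

def fake_quantize(w, bits=6):
    # Forward pass sees the rounded 6-bit value; with a straight-through
    # estimator, the backward pass would treat this op as identity
    # (autograd is not shown in this sketch).
    qmax = 2 ** (bits - 1) - 1                   # 31 for 6 bits
    scale = np.abs(w).max() / qmax
    q = np.clip(np.round(w / scale), -qmax - 1, qmax)
    return q * scale

w = np.linspace(-1.0, 1.0, 7)
wq = fake_quantize(w, bits=6)
# Rounding error is at most half a quantization step.
assert np.abs(w - wq).max() <= 0.5 * (1.0 / 31) + 1e-12
```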
Compression
lzma
level: 6
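Python's stdlib exposes the same codec at the same preset; the container format of the actual artifact is not specified here:

```python
import lzma

blob = bytes(range(256)) * 64                    # repetitive stand-in payload
packed = lzma.compress(blob, preset=6)           # preset 6, as in the PR
assert lzma.decompress(packed) == blob           # round-trips losslessly
assert len(packed) < len(blob)                   # repetitive data shrinks
```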
Test-Time Training
score-first TTT
parameters: {"chunk_size":32768,"epochs_per_chunk":3,"learning_rate":0.002,"momentum":0.9,"stride":64}
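The "legal" part of score-first TTT is the ordering: each chunk is scored with the current weights before any gradient step on it, so no chunk's loss reflects updates from its own tokens. A sketch with a stand-in model (the score/train interface is hypothetical):

```python
class TinyTTTModel:
    # Stand-in model that just records the order of operations.
    def __init__(self):
        self.log = []
    def score(self, chunk):
        self.log.append(("score", len(chunk)))
        return float(len(chunk))                 # pretend total loss
    def train_step(self, chunk, lr, momentum):
        self.log.append(("train", len(chunk)))

def score_first_ttt(model, data, chunk_size=32768, epochs_per_chunk=3,
                    lr=0.002, momentum=0.9):
    total_loss, total_tokens = 0.0, 0
    for start in range(0, len(data), chunk_size):
        chunk = data[start:start + chunk_size]
        total_loss += model.score(chunk)         # evaluate first...
        total_tokens += len(chunk)
        for _ in range(epochs_per_chunk):        # ...then adapt on the chunk
            model.train_step(chunk, lr=lr, momentum=momentum)
    return total_loss / total_tokens

model = TinyTTTModel()
assert score_first_ttt(model, list(range(100)), chunk_size=40) == 1.0
assert model.log[0] == ("score", 40)             # scored before any training
```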
Optimizer
SGD
weight_decay: null
momentum: 0.9
other_params: {"lr":0.002,"cosine_decay":true,"grad_clip":1}
LR Schedule
cosine decay
parameters: {"applied_to":"TTT chunks"}
Evaluation
sliding window eval
parameters: {"stride":64}
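With stride 64, each window after the first scores only its newest 64 tokens, so every scored token has near-full left context. A sketch of the window bookkeeping (taking window = train_length = 2048 is an assumption):

```python
def sliding_eval_spans(n_tokens, window=2048, stride=64):
    # Each entry is (context_start, score_start, end): the model conditions
    # on [context_start, end) but only [score_start, end) counts toward
    # val_bpb. The first window scores everything it covers.
    spans = [(0, 0, min(window, n_tokens))]
    end = min(window, n_tokens)
    while end < n_tokens:
        new_end = min(end + stride, n_tokens)
        spans.append((max(0, new_end - window), end, new_end))
        end = new_end
    return spans

spans = sliding_eval_spans(5000, window=2048, stride=64)
# Every token is scored exactly once.
assert sum(e - s for _, s, e in spans) == 5000
```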
Sequence Length
train_length: 2048
eval_length: null

Novel Contributions

  • Unified Attention with a single W_unified projection replacing Q/K/V
  • FA3 head-dim padding from 44 to 48 for Hopper compatibility while remaining mathematically lossless
  • Legal score-first test-time training protocol
  • Parallel Muon with parameter banking and batched tensor operations
  • Combining unified attention, FA3 speedups, and legal TTT to improve val_bpb under the 16MB constraint
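The losslessness claim for the FA3 head-dim padding can be checked directly: zero-padding Q and K leaves every QKᵀ dot product unchanged, and the padded V columns produce all-zero output columns that are sliced off. The one assumption is that the softmax scale stays at 1/sqrt(44) rather than following the padded dimension (FA3-style kernels accept an explicit softmax scale):

```python
import numpy as np

def attention(q, k, v, scale):
    # Plain single-head attention with an explicit softmax scale.
    s = scale * q @ k.T
    p = np.exp(s - s.max(axis=-1, keepdims=True))
    p /= p.sum(axis=-1, keepdims=True)
    return p @ v

rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((8, 44)) for _ in range(3))
scale = 1.0 / np.sqrt(44)                        # keep the original scale

pad = lambda t: np.pad(t, ((0, 0), (0, 4)))      # 44 -> 48 with zeros
out_ref = attention(q, k, v, scale)
out_pad = attention(pad(q), pad(k), pad(v), scale)[:, :44]
assert np.allclose(out_ref, out_pad)             # padding is lossless
```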