PR #417

closed

Record: 11L XSA4 + Tight SWA + FA3 + Two-Phase TTT (3-seed mean val_bpb=1.1227)

by EthanYangTW
val_bpb
1.1227
Architecture
Transformer
Optimizer
Adam
Artifact Size
15,758,953 bytes

Training Techniques

Architecture
XSA
Uses XSA attention in the last 4 layers.
parameters: {"layers":4}
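XSA's internal mechanism isn't described in this record; what is specified is the routing implied by `{"layers":4}`. A minimal sketch of that layer assignment, with a hypothetical `n_layers`-deep stack:

```python
# XSA internals aren't described in this record; this sketch shows only the
# routing implied by parameters {"layers": 4}: the last 4 layers use XSA,
# all earlier layers use standard attention. Names are hypothetical.
def attn_kind(layer_idx, n_layers, xsa_layers=4):
    return "xsa" if layer_idx >= n_layers - xsa_layers else "standard"
```

For a 12-layer model, layers 0-7 would use standard attention and layers 8-11 XSA.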
SmearGate
3x MLP with SmearGate nonlinearity.
parameters: {"mlp_multiplier":3}
BigramHash
Bigram hashing feature for token pair modeling.
parameters: {"buckets":2048}
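A sketch of what a bigram-hash feature with 2048 buckets can look like: each adjacent token pair is hashed to a bucket id that indexes an extra embedding table. The multiply/xor mix below is an illustrative assumption, not the record's actual hash function:

```python
# Illustrative sketch of a bigram-hash feature: each adjacent token pair
# (t[i-1], t[i]) is hashed into one of `buckets` ids and used as an extra
# embedding index. The multiplier/xor mix is an assumption, not the
# record's actual hash.
def bigram_hash_ids(tokens, buckets=2048):
    ids = [0]  # position 0 has no preceding token
    for prev, cur in zip(tokens, tokens[1:]):
        ids.append(((prev * 1000003) ^ cur) % buckets)
    return ids
```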
Partial RoPE
Applies rotary position embeddings to a fraction of head dimensions (16 of 64 channels).
parameters: {"train_fraction":16,"total_fraction":64}
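Reading `train_fraction/total_fraction = 16/64` as "rotate 16 of 64 head channels", a minimal partial-RoPE sketch (the rotation-pair layout and frequency base are assumptions):

```python
import numpy as np

# Sketch of partial RoPE matching train_fraction/total_fraction = 16/64:
# rotate only the first 16 of 64 head channels; the rest stay
# position-independent. Pair layout and base are assumptions.
def partial_rope(x, rot=16, base=10000.0):
    seq, dim = x.shape
    half = rot // 2
    ang = np.arange(seq)[:, None] / base ** (np.arange(half) / half)
    x1, x2 = x[:, :half], x[:, half:rot]
    rotated = np.concatenate([x1 * np.cos(ang) - x2 * np.sin(ang),
                              x1 * np.sin(ang) + x2 * np.cos(ang)], axis=1)
    return np.concatenate([rotated, x[:, rot:]], axis=1)
```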
KV head count
Grouped-query attention with fewer KV heads than attention heads.
parameters: {"heads":8,"kv_heads":4}
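With 8 query heads and 4 KV heads, each KV head serves `heads // kv_heads = 2` query heads. A sketch of the KV broadcast:

```python
import numpy as np

# Sketch of the grouped-query attention layout: 8 query heads share 4 KV
# heads, so each KV head is broadcast to heads // kv_heads = 2 query heads.
def expand_kv(kv, heads=8, kv_heads=4):
    # kv: (kv_heads, seq, head_dim) -> (heads, seq, head_dim)
    return np.repeat(kv, heads // kv_heads, axis=0)
```

This halves the KV cache relative to full multi-head attention at the same query-head count.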
Initialization
OrthoInit
Orthogonal initialization.
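The standard way to draw an orthogonal initial weight matrix is a QR decomposition of a Gaussian matrix; the record doesn't specify its exact variant, so this is a generic sketch:

```python
import numpy as np

# Sketch of orthogonal initialization via QR decomposition (assumes
# rows >= cols). The sign correction makes the result uniformly
# distributed over orthogonal matrices.
def orthogonal_init(rows, cols, seed=0):
    a = np.random.default_rng(seed).standard_normal((rows, cols))
    q, r = np.linalg.qr(a)
    return q * np.sign(np.diag(r))
```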
Regularization
layerwise LN scale
parameters: null
Quantization
mixed int6/int5 QAT
bits: mixed 5/6
scope: int5 for MLP layers, int6 for attention
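A sketch of symmetric per-tensor fake quantization, the usual QAT building block; the record's exact scheme (per-channel scales, zero points) isn't specified, so this is an assumption:

```python
import numpy as np

# Sketch of symmetric per-tensor fake quantization for QAT; the record
# uses int6 for attention weights and int5 for MLP weights.
def fake_quant(w, bits):
    qmax = 2 ** (bits - 1) - 1        # 31 for int6, 15 for int5
    amax = np.abs(w).max()
    if amax == 0:
        return w.copy()
    scale = amax / qmax
    return np.clip(np.round(w / scale), -qmax, qmax) * scale
```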
Weight Averaging
SWA
parameters: {"every_steps":50}
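The weight-averaging loop implied by `every_steps: 50` can be sketched as a running mean snapshot, with plain floats standing in for tensors:

```python
# Sketch of stochastic weight averaging with every_steps=50: every 50
# optimizer steps the current weights are folded into a running mean.
# Plain floats stand in for weight tensors.
class SWA:
    def __init__(self, every_steps=50):
        self.every_steps = every_steps
        self.avg, self.n = None, 0

    def maybe_update(self, step, weights):
        if step == 0 or step % self.every_steps != 0:
            return
        self.n += 1
        if self.avg is None:
            self.avg = dict(weights)
        else:
            for k in self.avg:
                self.avg[k] += (weights[k] - self.avg[k]) / self.n
```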
Compression
zstd
level: 22
Evaluation
sliding window eval
parameters: {"stride":32}
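With stride 32, sliding-window evaluation scores only the last 32 tokens of each window after the first, so every token is predicted with near-full left context. A sketch of the span schedule (`window=256` is an assumed context size, not from the record):

```python
# Sketch of sliding-window evaluation with stride=32: the first window
# scores all its tokens; each later window scores only its final 32.
# window=256 is an assumed context size.
def sliding_eval_spans(n_tokens, window=256, stride=32):
    spans = [(0, 0, min(window, n_tokens))]  # (ctx_start, score_start, score_end)
    pos = min(window, n_tokens)
    while pos < n_tokens:
        end = min(pos + stride, n_tokens)
        spans.append((end - window, pos, end))
        pos = end
    return spans
```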
Test-Time Training
two-phase TTT
parameters: {"phase_1":{"method":"norm-only recalibration","epochs":50,"optimizer":"Adam","learning_rate":0.01,"trainable_params":"~22K"},"phase_2":{"method":"selective-freeze block adaptation","epochs":10,"optimizer":"SGD","learning_rate":0.005,"trainable_params":"~7.6M"}}
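The parameter selection behind the two phases can be sketched as a name filter: phase 1 trains only normalization parameters (~22K, Adam @ 0.01), phase 2 unfreezes everything outside the preserved first 8 blocks (SGD @ 0.005). The `blocks.N.` / `norm` naming convention below is an assumption:

```python
# Sketch of the two-phase TTT parameter selection: phase 1 = norm-only
# recalibration; phase 2 = selective-freeze block adaptation that keeps
# the first 8 blocks frozen. Parameter naming is a hypothetical convention.
def trainable_params(names, phase, preserved_blocks=8):
    if phase == 1:
        return [n for n in names if "norm" in n]
    frozen = tuple(f"blocks.{i}." for i in range(preserved_blocks))
    return [n for n in names if not n.startswith(frozen)]
```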
Optimizer
Adam
weight_decay: null
momentum: null
other_params: {"phase":"TTT phase 1","learning_rate":0.01}
SGD
weight_decay: null
momentum: null
other_params: {"phase":"TTT phase 2","learning_rate":0.005}
Other
other
FA3 Hopper attention for faster training throughput.
parameters: {"step_time_ms":84.65}
other
Late QAT with 4% warmdown/quantization phase.
parameters: {"warmdown_fraction":0.04}
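The schedule implied by `warmdown_fraction: 0.04` is a simple switch: quantization-aware training activates only in the final 4% of steps. A minimal sketch:

```python
# Sketch of the late-QAT switch: fake quantization is enabled only during
# the final 4% of training (the warmdown/quantization phase).
def qat_active(step, total_steps, warmdown_fraction=0.04):
    return step >= int(round(total_steps * (1 - warmdown_fraction)))
```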
other
Tight SWA preserving averaged weights in the first 8 blocks during phase 2 TTT.
parameters: {"preserved_blocks":8}
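One way to read "preserving averaged weights in the first 8 blocks" is that after phase-2 adaptation those blocks are held at their SWA-averaged values while later blocks keep the adapted ones. A sketch under that reading (dict-of-floats and the `blocks.N.` prefix are illustrative assumptions):

```python
# Sketch of "tight SWA": the first 8 blocks are held at their SWA-averaged
# weights through phase-2 TTT; only later blocks keep adapted values.
# Dict-of-floats and the "blocks.N." prefix are illustrative assumptions.
def apply_tight_swa(adapted, swa_avg, preserved_blocks=8):
    frozen = tuple(f"blocks.{i}." for i in range(preserved_blocks))
    return {k: (swa_avg[k] if k.startswith(frozen) else v)
            for k, v in adapted.items()}
```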
other
Magnitude pruning.
parameters: {"pruning_rate":0.02}
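A rate of 0.02 means zeroing the 2% of weights with the smallest magnitudes. A global (whole-tensor) sketch; whether the record prunes globally or per-layer isn't specified:

```python
import numpy as np

# Sketch of magnitude pruning at rate 0.02: zero the 2% of weights with
# the smallest absolute values. Global thresholding is an assumption.
def magnitude_prune(w, rate=0.02):
    k = int(w.size * rate)
    if k == 0:
        return w.copy()
    thresh = np.partition(np.abs(w).ravel(), k - 1)[k - 1]
    out = w.copy()
    out[np.abs(out) <= thresh] = 0.0
    return out
```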

Novel Contributions

  • Two-phase test-time training with norm-only recalibration followed by selective-freeze block adaptation
  • FA3 Hopper attention to increase training throughput
  • Tight SWA with preserved first 8 blocks during adaptation
  • Late QAT with mixed int5 MLP and int6 attention
  • XSA attention in the last 4 layers
  • BigramHash and partial RoPE architecture modifications