PR #359 (closed)

11L MLP3x + Int6 QAT + XSA + EMA + BigramHash + FA3 (val_bpb 1.1345)

by tmustier
val_bpb: 1.1345
Architecture: Transformer
Optimizer: Muon
Artifact Size: 15.37 MB

Training Techniques

Quantization
STE QAT
bits: 6
scope: all
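The 6-bit STE QAT above can be sketched as fake quantization in the forward pass with a straight-through gradient. This is a minimal sketch assuming symmetric per-tensor scaling; the PR only states 6-bit STE QAT over all weights, not the exact scheme.

```python
import numpy as np

def fake_quant_int6(w, bits=6):
    # Symmetric per-tensor fake quantization (scale granularity is an
    # assumption; the PR specifies only bits=6, scope=all).
    qmax = 2 ** (bits - 1) - 1                 # 31 for int6
    max_abs = np.abs(w).max()
    scale = max_abs / qmax if max_abs > 0 else 1.0
    q = np.clip(np.round(w / scale), -qmax - 1, qmax)
    return q * scale                            # dequantized values used in the forward pass

# In QAT, the backward pass uses the straight-through estimator: the
# quantizer's gradient is treated as identity, so gradients flow to the
# full-precision master weights. In PyTorch this is commonly written as
#   w_q = w + (fake_quant(w) - w).detach()

w = np.array([0.5, -0.31, 0.02, 1.0])
w_q = fake_quant_int6(w)
```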
Architecture
MLP3x
3x MLP expansion with SwiGLU in an 11-layer Transformer
parameters: {"layers":11,"width":512}
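A SwiGLU MLP with 3x expansion at width 512 can be sketched as below. Whether "3x" refers to the hidden width of each projection or to total MLP parameters is an assumption; the weight names are illustrative.

```python
import numpy as np

def swiglu_mlp(x, w_gate, w_up, w_down):
    # SwiGLU: down( SiLU(x @ w_gate) * (x @ w_up) )
    gate = x @ w_gate
    silu = gate / (1.0 + np.exp(-gate))         # SiLU / swish activation
    return (silu * (x @ w_up)) @ w_down

d_model, d_hidden = 512, 3 * 512                # width 512, 3x expansion per the PR
rng = np.random.default_rng(0)
x = rng.standard_normal((4, d_model)) * 0.02
w_gate = rng.standard_normal((d_model, d_hidden)) * 0.02
w_up = rng.standard_normal((d_model, d_hidden)) * 0.02
w_down = rng.standard_normal((d_hidden, d_model)) * 0.02
y = swiglu_mlp(x, w_gate, w_up, w_down)
```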
XSA
Cross-sequence attention applied to the last 4 layers
parameters: {"last_n_layers":4}
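The PR does not define XSA beyond the name. One common reading of cross-sequence attention is a causal mask that does not reset at packed-sequence boundaries, so tokens in the last 4 layers may attend to earlier tokens from other sequences in the same window. A toy mask sketch under that assumption:

```python
import numpy as np

def causal_mask(total_len, seq_boundaries, cross_sequence):
    # True = attention allowed. With cross_sequence=True (the assumed XSA
    # behaviour) tokens may attend to earlier tokens from other packed
    # sequences; otherwise the causal mask is block-diagonal per sequence.
    i = np.arange(total_len)[:, None]
    j = np.arange(total_len)[None, :]
    mask = j <= i                               # standard causal mask
    if not cross_sequence:
        seq_id = np.searchsorted(seq_boundaries, np.arange(total_len), side="right")
        mask &= seq_id[:, None] == seq_id[None, :]
    return mask

# Two packed sequences of length 4 each (boundary after token index 3).
xsa_mask = causal_mask(8, [4], cross_sequence=True)
base_mask = causal_mask(8, [4], cross_sequence=False)
```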
BigramHash
Embedding augmentation from a hashed bigram table
parameters: {"vocab_size":2048,"dim":128}
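A sketch of the BigramHash augmentation: hash each (previous token, token) pair into a 2048-entry table of dim-128 vectors and combine the result with the token embedding. The hash function and the combination (add vs. concatenate) are assumptions; only the table sizes come from the PR.

```python
import numpy as np

N_BUCKETS, DIM = 2048, 128                      # from the PR's parameters

def bigram_hash(prev_ids, ids, n_buckets=N_BUCKETS):
    # Simple multiplicative hash of the (previous token, token) pair.
    # The actual hash used in the PR is not specified.
    return (prev_ids * 1000003 + ids) % n_buckets

rng = np.random.default_rng(0)
bigram_table = rng.standard_normal((N_BUCKETS, DIM)) * 0.02

ids = np.array([5, 17, 17, 9])
prev_ids = np.concatenate([[0], ids[:-1]])      # shift right; 0 as start id (assumed)
aug = bigram_table[bigram_hash(prev_ids, ids)]  # (seq, 128), combined with token embeddings
```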
Optimizer
Muon
weight_decay: 0.04
momentum: 0.99
other_params: {"adam_for_non_matrix_params":true}
Adam
weight_decay: 0.04
momentum: null
other_params: {"used_for":"non-matrix params"}
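The Muon/Adam split can be sketched as parameter grouping: Muon updates 2-D hidden weight matrices, Adam handles everything else. Routing embeddings and the head to Adam follows common Muon usage; the PR states only that Adam covers non-matrix params, so the exact rule here is an assumption.

```python
def split_param_groups(named_params):
    # Muon orthogonalizes 2-D weight-matrix updates; Adam takes gains,
    # biases, and (per common Muon convention, assumed here) the
    # embedding and head matrices.
    muon, adam = [], []
    for name, shape in named_params:
        if len(shape) == 2 and "embed" not in name and "head" not in name:
            muon.append(name)
        else:
            adam.append(name)
    return muon, adam

params = [("embed.weight", (50304, 512)),
          ("blocks.0.attn.w_qkv", (512, 1536)),
          ("blocks.0.norm.gain", (512,)),
          ("lm_head.weight", (50304, 512))]
muon_names, adam_names = split_param_groups(params)
```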
Weight Averaging
EMA
parameters: {"decay":0.997}
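EMA weight averaging keeps a shadow copy of the weights, updated each step and typically used for evaluation. A minimal sketch with the PR's decay of 0.997:

```python
def ema_update(ema, params, decay=0.997):       # decay from the PR
    # Exponential moving average of weights, maintained alongside the
    # live weights; evaluation/checkpointing uses the averaged copy.
    return {k: decay * ema[k] + (1.0 - decay) * params[k] for k in ema}

ema = {"w": 1.0}
for step in range(3):
    ema = ema_update(ema, {"w": 0.0})
```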
Compression
zstd
level: 22
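The small artifact suggests the int6 weights are bit-packed before zstd. A packing sketch (4 values per 3 bytes); the storage layout is an assumption, since the PR only states int6 quantization plus zstd at level 22.

```python
def pack_int6(values):
    # Pack 6-bit integers in [-32, 31] into a byte stream: 4 values per
    # 3 bytes, offset to unsigned before packing.
    bits, nbits, out = 0, 0, bytearray()
    for v in values:
        bits = (bits << 6) | ((v + 32) & 0x3F)  # offset-binary 6-bit code
        nbits += 6
        while nbits >= 8:
            nbits -= 8
            out.append((bits >> nbits) & 0xFF)
    if nbits:
        out.append((bits << (8 - nbits)) & 0xFF)  # zero-pad final byte
    return bytes(out)

packed = pack_int6([-32, 0, 31, 5])
# The packed stream would then be compressed with zstd at level 22,
# e.g. `zstd -22 --ultra weights.bin`.
```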
Evaluation
sliding window eval
parameters: {"stride":64}
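Sliding-window evaluation slides a 2048-token window by stride 64 and scores only the newest tokens of each window, so almost every token is predicted with near-full context. An indexing sketch (the exact scoring convention is assumed):

```python
def sliding_windows(n_tokens, window=2048, stride=64):
    # Yield (start, end, score_from): loss is computed only over
    # positions [score_from, end), so each token is scored once.
    spans, start = [], 0
    while True:
        end = min(start + window, n_tokens)
        score_from = start if start == 0 else end - stride
        spans.append((start, end, score_from))
        if end == n_tokens:
            break
        start += stride
    return spans

spans = sliding_windows(4096)
```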
Sequence Length
train_length: 2048
eval_length: 2048
LR Schedule
warmdown
parameters: {"warmdown_iters":3000,"warmup_steps":20}
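The schedule parameters suggest a trapezoidal shape: a 20-step linear warmup, a flat phase, then a linear warmdown to zero over the final 3000 iterations. A sketch under that assumption (the PR gives only the two parameter values; `total_steps=10000` below is illustrative):

```python
def lr_scale(step, total_steps, warmup_steps=20, warmdown_iters=3000):
    # Trapezoidal LR multiplier: linear warmup, constant middle,
    # linear decay to 0 over the last warmdown_iters steps.
    if step < warmup_steps:
        return (step + 1) / warmup_steps
    if step >= total_steps - warmdown_iters:
        return (total_steps - step) / warmdown_iters
    return 1.0

scales = [lr_scale(s, total_steps=10000) for s in (0, 19, 5000, 9999)]
```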
Regularization
weight decay
parameters: {"muon_wd":0.04,"adam_wd":0.04}
Other
FlashAttention 3 (Hopper build) used for competitive throughput
parameters: {"required":true}

Novel Contributions

  • 11-layer Transformer with 3x MLP expansion
  • Int6 STE QAT with zstd-22 compression
  • XSA on the last 4 layers
  • EMA weight averaging
  • BigramHash(2048) embedding augmentation
  • Muon optimizer combined with Adam for non-matrix parameters
  • Sliding-window evaluation with stride 64
  • FlashAttention 3 Hopper build for throughput