PR #369

closed

Submission: 11L NTK-RoPE + FA3 + Batch524K + XSA4 + EMA (val_bpb=1.1328)

by signalrush
val_bpb: 1.1328
Architecture: Transformer
Optimizer: Muon
Artifact Size: 15.87 MB

Training Techniques

Architecture
RoPE
NTK-aware RoPE that auto-scales the RoPE base frequency when sequence length exceeds the training length.
parameters: {"train_seq_len":1024}
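As a sketch, the commonly used NTK-aware scaling rule raises the RoPE base once the sequence exceeds the training length, leaving shorter contexts untouched. Only train_seq_len=1024 comes from the submission; base 10000 and head_dim 64 are illustrative assumptions:

```python
import math

def ntk_rope_base(seq_len, train_seq_len=1024, base=10000.0, head_dim=64):
    # NTK-aware scaling: enlarge the base by (seq_len / train_seq_len)
    # raised to d / (d - 2), so the lowest frequencies stretch to cover
    # the longer context while high frequencies barely move.
    # base and head_dim are assumptions, not the submission's values.
    if seq_len <= train_seq_len:
        return base
    scale = seq_len / train_seq_len
    return base * scale ** (head_dim / (head_dim - 2))

def rope_frequencies(seq_len, head_dim=64):
    # Per-pair rotary frequencies derived from the (possibly scaled) base.
    b = ntk_rope_base(seq_len, head_dim=head_dim)
    return [b ** (-2 * i / head_dim) for i in range(head_dim // 2)]
```

At the 1024-token training length the base is unchanged; at the 2048-token eval length it grows, which is what "auto-scales" refers to.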
XSA
Exclusive Self Attention applied to the last 4 layers to remove self-value bias.
parameters: {"layers":4}
MLP3x
3x MLP expansion with relu-squared activation.
parameters: {"expansion":3}
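A minimal sketch of the block, assuming the usual two-matrix feed-forward shape (the 3x expansion and relu-squared activation are from the submission; everything else is illustrative):

```python
def relu_squared(xs):
    # relu^2 activation: max(x, 0) ** 2, applied elementwise.
    return [max(x, 0.0) ** 2 for x in xs]

def mlp3x(x, w_in, w_out):
    # Feed-forward with 3x expansion: d_model -> 3*d_model -> d_model.
    # w_in has 3*len(x) columns, w_out has len(x) columns (assumed shapes).
    hidden = [sum(xi * wij for xi, wij in zip(x, col)) for col in w_in]
    hidden = relu_squared(hidden)
    return [sum(hi * wij for hi, wij in zip(hidden, col)) for col in w_out]
```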
tied embeddings
Input and output embeddings are tied.
parameters: null
KV head count
Uses grouped-query attention with 8 attention heads and 4 KV heads.
parameters: {"heads":8,"kv_heads":4}
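With 8 attention heads and 4 KV heads, each KV head is shared by a contiguous group of 8 // 4 = 2 query heads. A sketch of that index mapping:

```python
def kv_head_for_query(q_head, n_heads=8, n_kv_heads=4):
    # Grouped-query attention: query heads share KV heads in contiguous
    # groups of n_heads // n_kv_heads (here 2 query heads per KV head).
    # Halving the KV heads halves the KV-cache and projection size.
    group_size = n_heads // n_kv_heads
    return q_head // group_size
```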
SmearGate
Adds SmearGate as part of the model architecture.
parameters: null
BigramHash
Uses BigramHash features with 4096 buckets and 128-dimensional embeddings.
parameters: {"buckets":4096,"dimensions":128}
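The bucket count (4096) and embedding width (128) are from the submission; the hash itself is not specified, so the mixing constants below are illustrative. The idea is to map each (previous, current) token pair to one of 4096 buckets, each backed by a 128-dimensional embedding:

```python
def bigram_bucket(prev_token, token, buckets=4096):
    # Hash the (prev, cur) token pair into a fixed number of buckets.
    # The multiplier and xor-shift are illustrative mixing steps, not
    # the submission's actual hash.
    h = (prev_token * 1000003 + token) & 0xFFFFFFFF
    h ^= h >> 16
    return h % buckets

# The bucket indexes a learned table of shape (4096, 128); the looked-up
# vector is combined with the token embedding (e.g. added or concatenated).
```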
U-Net skips
Uses U-Net style skip connections across encoder and decoder layers.
parameters: {"encoder":5,"decoder":6}
Optimizer
Muon
weight_decay: 0.04
momentum: 0.99
other_params: {"lr":0.025}
AdamW
weight_decay: 0.04
momentum: null
other_params: {"tied_embed_lr":0.035,"lr":0.025}
Weight Averaging
EMA
parameters: {"decay":0.997}
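The EMA update is the standard exponential moving average of the weights, applied after each optimizer step; decay=0.997 is the submission's value:

```python
def ema_update(ema_params, params, decay=0.997):
    # ema <- decay * ema + (1 - decay) * param, elementwise.
    # Evaluation uses the EMA copy; training updates the raw weights.
    return [decay * e + (1.0 - decay) * p for e, p in zip(ema_params, params)]
```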
Evaluation
sliding window eval
parameters: {"stride":64}
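A sketch of strided sliding-window evaluation: each window scores only the tokens not yet scored (up to `stride` of them), with the rest of the window serving as context. stride=64 is the submission's value; the 2048 window size is an assumption based on eval_length:

```python
def sliding_windows(n_tokens, window=2048, stride=64):
    # Returns (begin, end, score_from) triples: the model sees tokens
    # [begin, end) but only tokens [score_from, end) contribute to the
    # loss, so every token is scored exactly once with long context.
    out = []
    pos = 0
    while pos < n_tokens:
        end = min(pos + stride, n_tokens)
        begin = max(0, end - window)
        out.append((begin, end, pos))
        pos = end
    return out
```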
Initialization
OrthoInit
Orthogonal initialization used with muP-scaled output projections.
Sequence Length
sequence_length
train_length: 1024
eval_length: 2048
LR Schedule
warmdown
parameters: {"warmdown_steps":3000}
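A warmdown schedule holds the learning rate constant and then decays it linearly to zero over the final warmdown_steps. warmdown_steps=3000 is the submission's value; peak_lr reuses the listed lr=0.025, and total_steps is an assumed placeholder:

```python
def lr_at_step(step, total_steps, warmdown_steps=3000, peak_lr=0.025):
    # Constant LR until total_steps - warmdown_steps, then a linear
    # ramp down to exactly 0 at total_steps.
    start = total_steps - warmdown_steps
    if step < start:
        return peak_lr
    return peak_lr * (total_steps - step) / warmdown_steps
```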
Regularization
weight decay
parameters: {"value":0.04}
Compression
zstd
level: 22

Novel Contributions

  • NTK-aware RoPE with automatic base scaling for longer sequences
  • FlashAttention 3 on Hopper to increase training throughput within the time budget
  • Reduced batch size to 524K tokens/step to obtain more gradient updates in 600 seconds
  • Adaptive pruning to automatically fit each seed under the 16MB artifact limit
  • Exclusive Self Attention on the last 4 layers
  • EMA weight averaging during training
  • Mixed-precision quantization with int5 MLP, int6 attention/bigram, and int8 embeddings