PR #369 (closed)
Submission: 11L NTK-RoPE + FA3 + Batch524K + XSA4 + EMA (val_bpb=1.1328)
by signalrush
val_bpb
1.1328
Architecture
Transformer
Optimizer
Muon
Artifact Size
15.87 MB
Training Techniques
Architecture
RoPE
NTK-aware RoPE that auto-scales the RoPE base frequency when sequence length exceeds the training length.
parameters: {"train_seq_len":1024}
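The PR does not spell out its scaling formula; a minimal sketch using the common NTK-aware rule (grow the base by s^(d/(d-2)) when the sequence exceeds the 1024-token training length) looks like this — the function name and the choice of formula are assumptions:

```python
def ntk_rope_base(base: float, head_dim: int, train_len: int, seq_len: int) -> float:
    # NTK-aware scaling: leave the base alone within the training length,
    # otherwise grow it so the lowest RoPE frequencies stretch to cover
    # the longer sequence. Formula is the common s^(d/(d-2)) rule, assumed.
    if seq_len <= train_len:
        return base
    scale = seq_len / train_len
    return base * scale ** (head_dim / (head_dim - 2))
```

With train_length 1024 and eval_length 2048 (see Sequence Length below), the base is only rescaled at evaluation time.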
XSA
Exclusive Self Attention applied to the last 4 layers to remove self-value bias.
parameters: {"layers":4}
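The submission does not define XSA beyond "remove self-value bias"; one plausible reading is that each token is barred from attending to its own value vector, i.e. the diagonal of the score matrix is masked before the softmax. A single-head numpy sketch under that assumption:

```python
import numpy as np

def exclusive_attention(q, k, v):
    # Plausible reading of "exclusive self attention": mask the diagonal
    # of the score matrix so no token attends to its own value vector.
    scores = q @ k.T / np.sqrt(q.shape[-1])        # (seq, seq)
    np.fill_diagonal(scores, -np.inf)              # exclude self
    scores -= scores.max(axis=-1, keepdims=True)   # numerically stable softmax
    probs = np.exp(scores)
    probs /= probs.sum(axis=-1, keepdims=True)
    return probs @ v
```

Causal masking is omitted here for brevity; in the real model it would be applied alongside the diagonal mask.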
MLP3x
3x MLP expansion with relu-squared activation.
parameters: {"expansion":3}
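The 3x expansion with relu-squared is straightforward; a forward-pass sketch (weight shapes and names are illustrative, not the PR's code):

```python
import numpy as np

def mlp3x_forward(x, w_up, w_down):
    # up-project to 3x width, apply relu-squared (square of max(h, 0)),
    # then project back down; biases omitted as is typical in speedruns
    h = np.maximum(x @ w_up, 0.0) ** 2
    return h @ w_down

dim = 8
rng = np.random.default_rng(0)
w_up = rng.standard_normal((dim, 3 * dim)) * 0.05
w_down = rng.standard_normal((3 * dim, dim)) * 0.05
```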
tied embeddings
Input and output embeddings are tied.
parameters: null
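Tying means one matrix serves both directions: rows are looked up for input embeddings and the same matrix (transposed) produces the output logits. A minimal sketch with hypothetical names:

```python
import numpy as np

vocab, dim = 100, 16
E = np.random.default_rng(0).standard_normal((vocab, dim)) * 0.02

def embed(token_ids):
    return E[token_ids]    # input embedding: row lookup in E

def logits(h):
    return h @ E.T         # output head: the same matrix, transposed
```

Note the separate tied_embed_lr in the AdamW block below, which suggests the shared matrix gets its own learning rate.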
KV head count
Uses grouped-query attention with 8 attention heads and 4 KV heads.
parameters: {"heads":8,"kv_heads":4}
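With 8 query heads over 4 KV heads, each KV head is shared by 2 query heads. The sharing can be sketched by repeating the KV tensors along the head axis (shapes are illustrative):

```python
import numpy as np

heads, kv_heads, seq, head_dim = 8, 4, 6, 16
group = heads // kv_heads            # 2 query heads share each KV head

rng = np.random.default_rng(0)
q = rng.standard_normal((heads, seq, head_dim))
k = rng.standard_normal((kv_heads, seq, head_dim))

# expand the KV heads so each query head sees its shared K (same for V)
k_shared = np.repeat(k, group, axis=0)       # (heads, seq, head_dim)
scores = q @ k_shared.transpose(0, 2, 1)     # (heads, seq, seq)
```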
SmearGate
Adds SmearGate as part of the model architecture.
parameters: null
BigramHash
Uses BigramHash features with 4096 buckets and 128-dimensional embeddings.
parameters: {"buckets":4096,"dimensions":128}
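The idea is to hash each (previous, current) token pair into one of 4096 buckets, each indexing a 128-dimensional embedding. The PR's hash function is not given; the mixing constant and BOS handling below are assumptions:

```python
def bigram_bucket(prev_tok: int, tok: int, buckets: int = 4096) -> int:
    # Hash the (previous, current) token pair into one of `buckets` ids.
    # 1000003 is an arbitrary odd mixing constant, not the PR's choice.
    return ((prev_tok * 1000003) ^ tok) % buckets

def bigram_features(token_ids, buckets=4096):
    # one bucket id per position; position 0 pairs with an assumed BOS id 0
    prev = [0] + list(token_ids[:-1])
    return [bigram_bucket(p, t, buckets) for p, t in zip(prev, token_ids)]
```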
U-Net skips
Uses U-Net style skip connections across encoder and decoder layers.
parameters: {"encoder":5,"decoder":6}
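With 5 encoder and 6 decoder layers, a common arrangement stacks encoder activations and pops them into the decoder as additive skips; the PR does not state the exact pairing, so this is a sketch:

```python
def unet_forward(x, encoder_layers, decoder_layers):
    # Encoder activations are pushed onto a stack, then popped and added
    # into successive decoder layers. With 5 encoder / 6 decoder layers,
    # the final decoder layer runs without a skip (assumed pairing).
    skips = []
    for layer in encoder_layers:
        x = layer(x)
        skips.append(x)
    for layer in decoder_layers:
        if skips:
            x = x + skips.pop()
        x = layer(x)
    return x
```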
Optimizer
Muon
weight_decay: 0.04
momentum: 0.99
other_params: {"lr":0.025}
AdamW
weight_decay: 0.04
momentum: null
other_params: {"tied_embed_lr":0.035,"lr":0.025}
Weight Averaging
EMA
parameters: {"decay":0.997}
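The EMA update with decay 0.997 is the standard shadow-parameter rule; a per-parameter sketch:

```python
def ema_update(ema_params, model_params, decay=0.997):
    # shadow <- decay * shadow + (1 - decay) * current, per parameter;
    # the shadow copy is what gets evaluated (and presumably shipped)
    return [decay * e + (1.0 - decay) * p
            for e, p in zip(ema_params, model_params)]
```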
Evaluation
sliding window eval
parameters: {"stride":64}
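Sliding-window evaluation with stride 64 scores each token with near-maximal left context by sliding a fixed window across the sequence. A sketch of the window placement (the scoring details are assumed):

```python
def window_starts(n_tokens: int, window: int = 2048, stride: int = 64):
    # Start offsets for overlapping eval windows. With stride 64, only the
    # newest 64 tokens of each window are scored, each with at least
    # window - stride tokens of left context (assumed convention).
    starts = list(range(0, max(n_tokens - window, 0) + 1, stride))
    if starts[-1] + window < n_tokens:   # cover a ragged tail
        starts.append(n_tokens - window)
    return starts
```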
Initialization
OrthoInit
Orthogonal initialization used with muP-scaled output projections.
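Orthogonal initialization draws a Gaussian matrix and orthogonalizes it via QR; a sketch in which `gain` stands in for the muP output-projection scaling, whose exact value the PR does not state:

```python
import numpy as np

def orthogonal_init(rows: int, cols: int, gain: float = 1.0, seed: int = 0):
    # Orthogonalize a Gaussian matrix via QR so columns are orthonormal.
    # `gain` is a placeholder for the muP scaling on output projections.
    a = np.random.default_rng(seed).standard_normal((rows, cols))
    q, r = np.linalg.qr(a)
    q = q * np.sign(np.diag(r))   # fix QR's sign ambiguity for determinism
    return gain * q
```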
Sequence Length
sequence_length
train_length: 1024
eval_length: 2048
LR Schedule
warmdown
parameters: {"warmdown_steps":3000}
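The PR gives only the step count, so the decay shape below (constant LR, then linear to zero over the final 3000 steps) is an assumption:

```python
def warmdown_lr(step: int, total_steps: int, base_lr: float,
                warmdown_steps: int = 3000) -> float:
    # constant LR until the warmdown begins, then linear decay to zero
    # over the last `warmdown_steps` (linear shape assumed)
    if step < total_steps - warmdown_steps:
        return base_lr
    return base_lr * (total_steps - step) / warmdown_steps
```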
Regularization
weight decay
parameters: {"value":0.04}
Compression
zstd
level: 22
Novel Contributions
- NTK-aware RoPE with automatic base scaling for longer sequences
- FlashAttention 3 on Hopper to increase training throughput within the time budget
- Reduced batch size to 524K tokens/step to obtain more gradient updates in 600 seconds
- Adaptive pruning to automatically fit each seed under the 16MB artifact limit
- Exclusive Self Attention on the last 4 layers
- EMA weight averaging during training
- Mixed-precision quantization with int5 MLP, int6 attention/bigram, and int8 embeddings
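The mixed-precision scheme in the last bullet can be sketched with symmetric per-tensor quantization at the stated bit widths (int5 MLP, int6 attention/bigram, int8 embeddings); the per-tensor granularity and symmetric scheme are assumptions, not confirmed by the PR:

```python
import numpy as np

def quantize_symmetric(w: np.ndarray, bits: int):
    # Symmetric per-tensor quantization to signed `bits`-bit integers,
    # e.g. bits=5 for MLP, 6 for attention/bigram, 8 for embeddings.
    qmax = 2 ** (bits - 1) - 1
    scale = np.abs(w).max() / qmax
    q = np.clip(np.round(w / scale), -qmax, qmax).astype(np.int32)
    return q, scale

def dequantize(q: np.ndarray, scale: float) -> np.ndarray:
    return q.astype(np.float32) * scale
```

Storing the int codes plus one float scale per tensor, and compressing with zstd level 22, is what keeps the artifact at 15.87 MB under the 16 MB limit.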