PR #401

closed

Record: 11L + EMA + Tight SWA + QAT0.15 + VE128 + Partial RoPE + LN Scale (val_bpb: 1.1243)

by newjordan
val_bpb
1.1243
Architecture
Transformer
Optimizer
Muon
Artifact Size
15.88 MB

Training Techniques

Weight Averaging
EMA
parameters: {"decay":0.997,"accumulation":"float32"}
SWA
parameters: {"tight":true,"every_steps":50,"start_scale_threshold":0.2,"from_ema_weights":true}
Quantization
STE QAT
bits: 6
scope: MLP + attention weights
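A sketch of 6-bit straight-through-estimator fake quantization, which would be applied to the MLP and attention weight matrices during the QAT phase. Per-tensor absmax scaling is an assumption; the record does not specify the scaling scheme.

```python
import torch

def fake_quant_ste(w: torch.Tensor, bits: int = 6) -> torch.Tensor:
    """Symmetric per-tensor fake quantization with a straight-through estimator:
    the forward pass uses quantized weights, the backward pass flows to the
    latent full-precision weights unchanged."""
    qmax = 2 ** (bits - 1) - 1                              # 31 for int6
    scale = w.detach().abs().max().clamp(min=1e-8) / qmax
    w_q = (w / scale).round().clamp(-qmax - 1, qmax) * scale
    # STE: detach the quantization error so gradients bypass the rounding
    return w + (w_q - w).detach()
```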
Architecture
Partial RoPE
Applies rotary position embeddings to only part of the head dimensions with NTK-aware scaling.
parameters: {"dimensions":16,"total_dimensions":64}
LN Scale
Layerwise layer-norm scaling by 1/sqrt(layer_idx+1).
parameters: null
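The scaling rule is given (1/sqrt(layer_idx+1)); a minimal sketch of a layer norm scaled this way. Which norms in the block receive the scale is not specified and is left to the caller.

```python
import torch.nn as nn

class ScaledLayerNorm(nn.Module):
    """LayerNorm whose output is scaled by 1/sqrt(layer_idx + 1)."""
    def __init__(self, dim: int, layer_idx: int):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.scale = (layer_idx + 1) ** -0.5

    def forward(self, x):
        return self.norm(x) * self.scale
```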
MLP3x
Uses 3x MLP expansion with relu-squared activation.
parameters: {"multiplier":3}
SmearGate
Custom gating mechanism used in the MLP/architecture.
parameters: null
BigramHash
Adds hashed bigram features with shared embeddings.
parameters: {"buckets":2048,"dimension":128}
tied embeddings
Input and output embeddings are tied.
parameters: null
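Weight tying in its usual PyTorch form, for reference (sizes are illustrative, not from the record):

```python
import torch.nn as nn

vocab, dim = 50304, 768              # illustrative sizes
embed = nn.Embedding(vocab, dim)
lm_head = nn.Linear(dim, vocab, bias=False)
lm_head.weight = embed.weight        # input and output embeddings share one matrix
```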
KV head count
Uses grouped-query attention with fewer KV heads than attention heads.
parameters: {"heads":8,"kv_heads":4}
Shared Value Embedding
Shared value embeddings used in selected layers.
parameters: {"dimension":128,"layers":[9,10]}
Optimizer
Muon
weight_decay: 0.04
momentum: 0.99
other_params: {"lr":0.025,"warmup_momentum_start":0.92,"warmup_steps":1500}
AdamW
weight_decay: 0.04
momentum: null
other_params: {"lr_embeddings":0.035,"lr_scalars":0.025}
Compression
zstd
level: 22
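The artifact is zstd-compressed at level 22; a minimal sketch using the `zstandard` Python bindings (file names are placeholders).

```python
import zstandard as zstd

def compress_artifact(src="model.bin", dst="model.bin.zst", level=22):
    cctx = zstd.ZstdCompressor(level=level)
    with open(src, "rb") as fin, open(dst, "wb") as fout:
        fout.write(cctx.compress(fin.read()))
```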
Evaluation
sliding window eval
parameters: {"stride":64}
Test-Time Training
full TTT
parameters: {"epochs":8,"learning_rate":0.002,"momentum":0.9}
Sequence Length
sequence_length
train_length: 2048
eval_length: null
LR Schedule
warmdown
parameters: {"warmdown_steps":3500}
Regularization
layerwise LN scale
parameters: {"scale_rule":"1/sqrt(layer_idx+1)"}

Novel Contributions

  • Stacking EMA with tight SWA so SWA collects from EMA-averaged weights
  • Earlier start threshold (0.15) for the late-stage QAT phase, increasing the time spent under int6 fake quantization
  • Longer warmdown schedule of 3500 iterations
  • Partial RoPE with NTK-aware scaling
  • Layerwise LN scaling
  • Shared Value Embedding
  • SmearGate and BigramHash architectural additions