PR #838

open

Non-Record: 11L Parallel Muon + LN Scale + LeakyReLU² MLP3x + Legal TTT — val_bpb 1.1215 (3-seed mean)

by aryanbhosale
val_bpb
1.1215
Architecture
Transformer
Optimizer
Parallel Muon
Artifact Size
~15.85 MB

Training Techniques

Optimizer
Parallel Muon
weight_decay: 0.04
momentum: 0.92
other_params: {"momentum_schedule_end":0.99,"momentum_schedule_steps":1500,"newton_schulz_steps":5,"parameter_banking":true,"async_reduce_scatter_all_gather":true}
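The orthogonalizing Newton-Schulz step at the core of Muon-style optimizers can be sketched as below (NumPy). The quintic coefficients are the commonly published Muon ones; the PR's parameter banking and async reduce-scatter/all-gather collectives are distributed-training details not reproduced here.

```python
import numpy as np

def newton_schulz(G, steps=5, eps=1e-7):
    """Approximately orthogonalize a gradient/momentum matrix via the quintic
    Newton-Schulz iteration used by Muon-style optimizers (public coefficients;
    this PR's exact variant may differ)."""
    a, b, c = 3.4445, -4.7750, 2.0315
    X = G / (np.linalg.norm(G) + eps)   # Frobenius norm bounds the spectral norm
    transposed = X.shape[0] > X.shape[1]
    if transposed:                       # keep the Gram matrix on the small side
        X = X.T
    for _ in range(steps):
        A = X @ X.T
        X = a * X + (b * A + c * A @ A) @ X
    return X.T if transposed else X
```

After `newton_schulz_steps: 5` iterations the singular values of the result cluster near 1, so the update direction is approximately orthogonal.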
Architecture
MLP3x
3x expansion MLP with LeakyReLU(0.5)^2 activation
parameters: {"hidden_dim":1536}
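A minimal NumPy sketch of the activation and block (hidden_dim 1536 with 3x expansion implies a 512-dim model, which is an inference from the config; whether the square preserves sign is not specified, so this sketch squares directly):

```python
import numpy as np

def leaky_relu_sq(x, slope=0.5):
    # LeakyReLU with negative slope 0.5, then squared
    y = np.where(x > 0, x, slope * x)
    return y * y

def mlp3x(x, w_in, w_out):
    # 3x expansion MLP: 512 -> 1536 -> 512 (biasless, as is common in this setting)
    return leaky_relu_sq(x @ w_in) @ w_out
```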
LN Scale
Depth-dependent normalization scaling by 1/sqrt(layer_idx+1)
parameters: null
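The depth-dependent scaling is simple enough to state exactly from the description, a normalized output damped by 1/sqrt(layer_idx + 1) (sketch below uses plain LayerNorm; the PR may apply it to RMSNorm instead):

```python
import numpy as np

def scaled_layernorm(x, layer_idx, eps=1e-6):
    # Standard LayerNorm over the last axis...
    mu = x.mean(-1, keepdims=True)
    var = x.var(-1, keepdims=True)
    y = (x - mu) / np.sqrt(var + eps)
    # ...then scale by 1/sqrt(layer_idx + 1): deeper layers contribute less
    return y / np.sqrt(layer_idx + 1)
```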
SmearGate
Additional gating mechanism in the architecture
parameters: null
BigramHash
Bigram hashing feature module
parameters: {"dimensions":1536,"dim":128}
Partial RoPE
Rotary positional embeddings applied to only part of the head dimensions
parameters: {"dimensions":"16/64"}
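Assuming "16/64" means the first 16 of 64 head dimensions are rotated (the usual partial-RoPE layout; the PR's exact split and frequency base are not stated), a per-position sketch:

```python
import numpy as np

def partial_rope(x, pos, rot_dims=16, base=10000.0):
    """Apply rotary embedding to the first `rot_dims` of a 64-dim head vector;
    the remaining dims pass through untouched."""
    d = rot_dims // 2
    inv_freq = base ** (-np.arange(d) / d)
    ang = pos * inv_freq
    cos, sin = np.cos(ang), np.sin(ang)
    x1, x2 = x[:d], x[d:rot_dims]          # paired halves of the rotated slice
    rotated = np.concatenate([x1 * cos - x2 * sin, x1 * sin + x2 * cos])
    return np.concatenate([rotated, x[rot_dims:]])
```

Rotation preserves the norm of the rotated slice, so only relative-position information is injected.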
tied embeddings
Input and output embeddings are tied
parameters: null
KV head count
Grouped-query attention with 8 attention heads and 4 KV heads
parameters: {"heads":8,"kv_heads":4}
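With 8 query heads over 4 KV heads, each KV head serves 2 query heads. A sketch of the KV expansion step (the standard GQA mechanism; head dimension assumed, not given):

```python
import numpy as np

def expand_kv(kv, n_heads=8, n_kv_heads=4):
    # kv: (n_kv_heads, seq, head_dim) -> (n_heads, seq, head_dim)
    # Each KV head is replicated for its group of n_heads // n_kv_heads query heads.
    return np.repeat(kv, n_heads // n_kv_heads, axis=0)
```

The KV cache (and KV projection parameters) shrink by the group factor while queries keep full head count.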
XSA
Exclusive self-attention used in the last 4 layers
parameters: {"layers":4}
Value Residual
Caches V from layer 0 and blends via learned lambda
parameters: null
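The blend itself is a one-liner; a sketch assuming a convex combination with a learned per-layer scalar (how lambda is parameterized or constrained is not stated in this entry):

```python
def blend_values(v, v0, lam):
    """Value residual: mix this layer's value tensor `v` with the cached
    layer-0 values `v0` via a learned scalar `lam` in [0, 1]."""
    return lam * v + (1.0 - lam) * v0
```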
Gated Attention
Per-head sigmoid gating on attention outputs
parameters: null
U-Net skips
Skip connections inspired by U-Net
parameters: null
Initialization
OrthoInit
Orthogonal initialization
Weight Averaging
EMA
parameters: {"decay":0.997}
SWA
parameters: {"every_steps":50,"scale_threshold":0.2}
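The two averaging schemes can be sketched as update rules over a parameter dict (the `scale_threshold` gating for SWA snapshots is PR-specific and not reproduced here):

```python
def ema_update(ema, params, decay=0.997):
    # Exponential moving average: new value leaks in at rate (1 - decay)
    return {k: decay * ema[k] + (1.0 - decay) * params[k] for k in ema}

def swa_update(swa, params, n):
    # Stochastic weight averaging: running mean over snapshots taken
    # every `every_steps` optimizer steps; n = snapshots averaged so far
    return {k: (swa[k] * n + params[k]) / (n + 1) for k in swa}
```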
Quantization
GPTQ-lite
bits: 6
scope: all weights with FP16 embedding passthrough
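A heavily simplified sketch of per-row symmetric int6 quantization with percentile clipping, the ingredient the contributions list describes as a "per-row 5-percentile clip search" (the actual GPTQ-style error compensation and the search loop over clip thresholds are not shown; `clip_pct` here is a hypothetical knob):

```python
import numpy as np

def quantize_int6_row(w, clip_pct=100.0):
    """Quantize one weight row to symmetric int6. clip_pct < 100 clips
    outliers before choosing the scale, trading clipping error for finer
    resolution on the bulk of the row."""
    qmax = 2 ** 5 - 1                                   # int6 symmetric range [-31, 31]
    scale = np.percentile(np.abs(w), clip_pct) / qmax
    q = np.clip(np.round(w / scale), -qmax, qmax).astype(np.int8)
    return q, scale
```

Embeddings bypass this path entirely and stay FP16, per the scope line above.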
Compression
zstd
level: 22
Evaluation
sliding window eval
parameters: {"stride":64,"chunk_size":32000}
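One plausible reading of these parameters (a sketch, since the entry does not define them): score tokens in blocks of `stride`, each with up to `chunk_size` tokens of left context, so every token is scored exactly once with near-maximal context.

```python
def sliding_window_spans(n_tokens, window=32000, stride=64):
    """Return (ctx_start, end, score_start) triples: the model sees
    [ctx_start, end) but only tokens in [score_start, end) contribute loss."""
    spans = []
    start = 0
    while start < n_tokens:
        ctx_start = max(0, start + stride - window)
        spans.append((ctx_start, min(start + stride, n_tokens), start))
        start += stride
    return spans
```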
Test-Time Training
score-first TTT
parameters: {"learning_rate":0.002,"momentum":0.9,"epochs":3,"all_blocks_unfrozen":true}
LR Schedule
cosine decay
parameters: {"warmdown_steps":3500}
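A sketch of one common shape for this schedule, constant LR followed by a cosine warmdown to zero over the final `warmdown_steps` (the PR may instead use a linear warmdown or a nonzero floor; those details are not given):

```python
import math

def lr_at(step, total_steps, base_lr, warmdown_steps=3500):
    # Hold base_lr, then decay along a cosine over the last `warmdown_steps`
    start = total_steps - warmdown_steps
    if step < start:
        return base_lr
    t = (step - start) / warmdown_steps      # progress through the warmdown, in [0, 1]
    return base_lr * 0.5 * (1.0 + math.cos(math.pi * t))
```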
Regularization
layerwise LN scale
parameters: {"scale":"1/sqrt(layer_idx+1)"}

Novel Contributions

  • Parallel Muon with parameter banking and batched Newton-Schulz updates
  • Depth-dependent LN Scale normalization
  • LeakyReLU(0.5)^2 MLP with 3x expansion
  • Legal score-first test-time training under inference_mode
  • EMA plus SWA model averaging
  • GPTQ-lite int6 quantization with per-row 5-percentile clip search
  • Flash Attention 3 and torch.compile(fullgraph=True) training stack