PR #1691

open

Submission/fastattn mtp dr

by AVINASH0052 · View on GitHub
val_bpb
1.2244
Architecture
Transformer
Optimizer
Muon
Artifact Size
16 MB

Training Techniques

Architecture
GQA
Grouped query attention with 8 attention heads and 4 KV heads in the baseline; the alternate fork uses 8 heads and 2 KV heads.
parameters: {"heads":8,"kv_heads":4}
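A minimal sketch of the head layout this describes: with grouped query attention, several query heads read from one shared KV head. The function name and mapping below are illustrative, not taken from the submission's code.

```python
def kv_head_for(query_head: int, n_heads: int = 8, n_kv_heads: int = 4) -> int:
    """Map a query head index to the KV head it reads from (GQA grouping)."""
    group_size = n_heads // n_kv_heads  # query heads sharing one KV head
    return query_head // group_size

# Baseline: 8 query heads over 4 KV heads -> groups of 2.
baseline = [kv_head_for(h) for h in range(8)]
# Alternate fork: 8 query heads over 2 KV heads -> groups of 4.
alternate = [kv_head_for(h, n_kv_heads=2) for h in range(8)]
```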
weight tying
Tied embedding and unembedding tables, and reused tied embeddings for the MTP vocabulary projection.
parameters: null
U-Net skip connections
U-Net style skip connections with per-feature skip weights.
parameters: null
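One way to read "per-feature skip weights" is an elementwise learned weight on the encoder-side activation before it is added back in the decoder half of the U-Net. This is a hedged sketch of that interpretation, not the submission's implementation.

```python
def unet_skip(decoder_x, encoder_x, skip_weight):
    """Add a U-Net skip connection with a learned weight per feature.

    decoder_x, encoder_x, skip_weight are equal-length feature vectors;
    each feature of the earlier (encoder) activation is scaled by its own
    weight before being added to the later (decoder) activation.
    """
    return [d + w * e for d, w, e in zip(decoder_x, skip_weight, encoder_x)]
```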
ReLU²
MLP uses ReLU squared activation.
parameters: null
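For reference, the ReLU² activation is simply the squared rectifier:

```python
def relu_squared(x: float) -> float:
    """ReLU squared: max(x, 0)^2 — zero for negative inputs, x^2 otherwise."""
    return max(x, 0.0) ** 2
```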
depth recurrence
Runs the block stack multiple times with shared weights in the alternate fork.
parameters: {"reps":2}
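Depth recurrence with shared weights amounts to looping the same block stack, here with `reps=2` as listed above. A minimal sketch, with blocks modeled as plain callables:

```python
def run_blocks(x, blocks, reps: int = 2):
    """Apply the full block stack `reps` times with shared weights.

    Parameters are reused across repetitions, so effective depth doubles
    (for reps=2) without adding parameters.
    """
    for _ in range(reps):
        for block in blocks:
            x = block(x)
    return x
```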
KV head count
Uses fewer KV heads (2 instead of 4) in the alternate fork, shrinking the KV cache and attention parameter count while keeping 8 query heads.
parameters: {"kv_heads":2}
MTP head
Adds a multi-token prediction auxiliary head with a small projection and RMSNorm; the head contributes a training-only auxiliary loss and is dropped at inference.
parameters: {"aux_token_offset":2}
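A hedged sketch of the two pieces named above: target alignment for the auxiliary loss at `aux_token_offset=2`, and a bare RMSNorm (shown here without a learned gain). Function names are illustrative; the submission's actual head also includes a projection and reuses the tied embedding for the vocabulary logits.

```python
import math

def rmsnorm(x, eps: float = 1e-6):
    """RMSNorm without a learned gain: scale x to unit root-mean-square."""
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v / rms for v in x]

def mtp_targets(tokens, aux_token_offset: int = 2):
    """Targets for multi-token prediction: the hidden state at position t
    is trained to predict token t + offset, i.e. the sequence shifted left."""
    return tokens[aux_token_offset:]
```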
Optimizer
Muon
weight_decay: null
momentum: null
other_params: {"matrix_lr":0.04}
Adam
weight_decay: null
momentum: null
other_params: {"embeddings_lr":0.05,"scalars_lr":0.04}
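The two optimizer entries suggest the usual Muon-style split: 2-D matrix parameters go to Muon at `matrix_lr`, while embeddings and scalars go to Adam at their own learning rates. A hypothetical routing sketch under that assumption (the name-matching rule is illustrative):

```python
def route_param(name: str, ndim: int):
    """Assign a parameter to an optimizer and learning rate.

    Assumed policy: embeddings -> Adam @ 0.05, other matrices -> Muon @ 0.04,
    scalars and 1-D params -> Adam @ 0.04, matching the values listed above.
    """
    if "embed" in name:
        return ("adam", 0.05)
    if ndim >= 2:
        return ("muon", 0.04)
    return ("adam", 0.04)
```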
Quantization
GPTQ
bits: 8
scope: all
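As context for the 8-bit setting, a minimal symmetric round-to-nearest quantizer is sketched below. Note this is only the storage format; full GPTQ additionally compensates quantization error column by column using second-order (Hessian) information, which is omitted here.

```python
def quantize8(weights):
    """Symmetric 8-bit round-to-nearest quantization of a weight vector.

    Returns integer codes in [-127, 127] plus a single scale; this is a
    simplified stand-in for GPTQ's per-column, error-compensated procedure.
    """
    scale = max(abs(w) for w in weights) / 127.0 or 1.0  # avoid zero scale
    q = [round(w / scale) for w in weights]
    return q, scale
```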
Compression
zlib
level: null
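Since the level is unspecified, zlib's default applies. A self-contained round-trip with Python's standard `zlib` module:

```python
import zlib

payload = b"model weights blob " * 100
blob = zlib.compress(payload)       # default compression level
restored = zlib.decompress(blob)    # lossless round trip
```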
LR Schedule
linear warmup
parameters: {"warmup_steps":20}
linear cooldown
parameters: {"cooldown_steps":1200}
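The schedule above can be sketched as a multiplier on the base learning rate: linear ramp over the first 20 steps, flat in the middle, linear decay over the last 1200 steps. The total step count is a free parameter here, not stated in the submission.

```python
def lr_scale(step: int, total_steps: int,
             warmup: int = 20, cooldown: int = 1200) -> float:
    """LR multiplier: linear warmup, constant middle, linear cooldown."""
    if step < warmup:
        return (step + 1) / warmup
    if step > total_steps - cooldown:
        return max(0.0, (total_steps - step) / cooldown)
    return 1.0
```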
Regularization
logit softcap
parameters: null
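Logit softcapping bounds logits smoothly via a scaled tanh; since no parameters are listed, the cap value below (15.0) is purely illustrative.

```python
import math

def softcap(logit: float, cap: float = 15.0) -> float:
    """Smoothly squash a logit into (-cap, cap): cap * tanh(logit / cap).

    Near zero it is approximately the identity; large logits saturate at ±cap.
    """
    return cap * math.tanh(logit / cap)
```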
Sequence Length
sequence_length
train_length: 1024
eval_length: 1024

Novel Contributions

  • Added multi-token prediction auxiliary loss to the baseline model
  • Used depth recurrence with shared weights in the alternate fork
  • Adjusted model width/depth and KV head count while keeping the parameter budget similar
  • Kept the proven fast baseline components such as Muon, GQA, U-Net skips, tied embeddings, and GPTQ compression