PR #842

open

Non-record: 5L MLP4x + SlidingWindow + SWA + QAT — val_bpb 1.33 (1xH100)

by JUSTSUJAY
val_bpb
1.3380
Architecture
Transformer
Optimizer
Muon
Artifact Size
14.5MB

Training Techniques

Architecture
MLP4x
5-layer Transformer with the MLP expansion factor widened to 4x (hidden size 2048), trading depth for width relative to deeper, narrower stacks.
parameters: {"layers":5,"model_dim":512,"mlp_mult":4,"hidden":2048,"num_heads":8,"num_kv_heads":4}
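A minimal numpy sketch of the widened feed-forward block implied by these parameters (512 → 2048 → 512). The weight names and the ReLU nonlinearity are assumptions not stated in the PR; the zero-initialized output projection follows the OrthoInit note in this PR:

```python
import numpy as np

def mlp4x(x, w_in, w_out):
    """Feed-forward block with a 4x expansion: model_dim 512 -> hidden 2048 -> 512.

    x: (seq, 512) activations; w_in: (512, 2048); w_out: (2048, 512).
    """
    h = x @ w_in          # expand to the 2048-dim hidden
    h = np.maximum(h, 0)  # nonlinearity (ReLU here; the actual choice is an assumption)
    return h @ w_out      # project back to model_dim

rng = np.random.default_rng(0)
x = rng.standard_normal((1024, 512)).astype(np.float32)
w_in = (rng.standard_normal((512, 2048)) * 0.02).astype(np.float32)
w_out = np.zeros((2048, 512), dtype=np.float32)  # zero-init output projection
y = mlp4x(x, w_in, w_out)
```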
BigramHash
Learned hashed embeddings for adjacent token pairs to inject lightweight bigram context.
parameters: {"buckets":4096,"dim":128}
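A sketch of the hashed bigram lookup, assuming 4096 buckets and a 128-dim table as listed. The mixing constant and the BOS handling for the first position are illustrative assumptions; `table` stands in for the learned parameter matrix:

```python
import numpy as np

def bigram_hash_embed(tokens, table, buckets=4096):
    """Look up a learned embedding for each adjacent (prev, cur) token pair.

    tokens: (seq,) token ids; table: (buckets, 128) learned parameters.
    Any cheap pair hash works; the constant below is illustrative.
    """
    prev = np.concatenate([[0], tokens[:-1]])   # assume id 0 stands in before the first token
    h = (prev * 1000003 + tokens) % buckets     # hash each pair into a bucket
    return table[h]                             # (seq, 128) bigram features

rng = np.random.default_rng(0)
table = rng.standard_normal((4096, 128)).astype(np.float32)
tokens = rng.integers(0, 50257, size=256)
feats = bigram_hash_embed(tokens, table)
```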
SmearGate
Learned per-dimension gate blending each token embedding with the previous token embedding.
parameters: null
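Since no parameters are listed, here is one plausible parameterization of the gate as a sketch: a learned per-dimension logit passed through a sigmoid, blending each embedding with its predecessor. The exact parameterization is an assumption:

```python
import numpy as np

def smear_gate(emb, gate_logits):
    """Blend each token embedding with the previous token's embedding.

    gate_logits: (dim,) learned parameters; sigmoid gives a per-dimension
    gate in (0, 1). Parameterization is an assumption.
    """
    gate = 1.0 / (1.0 + np.exp(-gate_logits))
    prev = np.roll(emb, 1, axis=0)
    prev[0] = 0.0                      # first token has no predecessor to smear in
    return gate * emb + (1.0 - gate) * prev

emb = np.ones((4, 3), dtype=np.float32)
out = smear_gate(emb, np.zeros(3))     # zero logits -> gate = 0.5 everywhere
```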
Initialization
OrthoInit
Orthogonal initialization for all weight matrices, with zero-init for output projections.
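The standard recipe for this (QR decomposition of a Gaussian matrix, sign-corrected), as a sketch; the specific weight names are illustrative:

```python
import numpy as np

def ortho_init(shape, rng):
    """Orthogonal init via QR of a Gaussian matrix (orthonormal columns)."""
    a = rng.standard_normal(shape)
    q, r = np.linalg.qr(a)
    q *= np.sign(np.diag(r))  # sign fix so the result is uniform over orthogonal matrices
    return q.astype(np.float32)

rng = np.random.default_rng(0)
w_hidden = ortho_init((512, 512), rng)           # orthogonal hidden weight
w_out = np.zeros((512, 512), dtype=np.float32)   # output projections start at zero
```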
Quantization
STE QAT
bits: 8
scope: all
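The forward pass of straight-through int8 QAT, sketched with symmetric per-tensor fake quantization (the granularity and symmetric scheme are assumptions). During training the backward pass treats this op as the identity, i.e. the straight-through estimator, so gradients flow to the full-precision weights:

```python
import numpy as np

def fake_quant_int8(w):
    """Symmetric per-tensor int8 fake quantization (forward pass only).

    Quantize to [-127, 127], then dequantize; the model trains against
    the quantized weights while keeping fp master weights.
    """
    scale = np.abs(w).max() / 127.0
    q = np.clip(np.round(w / scale), -127, 127)
    return q * scale

w = np.array([0.5, -1.27, 0.003, 1.0], dtype=np.float32)
wq = fake_quant_int8(w)  # 0.003 falls below one quantization step and rounds to 0
```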
Weight Averaging
SWA
parameters: {"checkpoints":18,"interval_steps":50,"phase":"warmdown"}
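A minimal sketch of uniform checkpoint averaging over the 18 warmdown checkpoints (dict-of-arrays checkpoint format is an assumption):

```python
import numpy as np

def swa_average(checkpoints):
    """Uniformly average parameter dicts from N saved checkpoints."""
    return {name: np.mean([ckpt[name] for ckpt in checkpoints], axis=0)
            for name in checkpoints[0]}

# e.g. 18 checkpoints saved every 50 steps during the warmdown phase
ckpts = [{"w": np.full((2, 2), float(i))} for i in range(18)]
avg = swa_average(ckpts)
```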
Evaluation
sliding window eval
parameters: {"stride":64}
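With stride 64, each evaluation window advances 64 tokens and only the final 64 positions of each window (after the first) are scored, so nearly every token is predicted with close to the full left context. A sketch of which positions each window scores, assuming the window length matches the train length of 1024:

```python
def sliding_window_scored_positions(n_tokens, window=1024, stride=64):
    """Return the positions scored across all windows of a sliding-window eval.

    The first window scores everything; each later window scores only
    its final `stride` positions.
    """
    scored = []
    start = 0
    while start + window <= n_tokens:
        lo = window - stride if start > 0 else 0
        scored.extend(range(start + lo, start + window))
        start += stride
    return scored

pos = sliding_window_scored_positions(1152)  # every token scored exactly once
```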
Optimizer
Muon
weight_decay: 0.04
momentum: 0.99
other_params: {"matrix_lr":0.03,"embed_lr":0.06,"momentum_warmup_start":0.92,"momentum_warmup_steps":500}
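The `momentum_warmup_*` entries suggest momentum ramping from 0.92 to the final 0.99 over the first 500 steps; a sketch assuming a linear ramp (the ramp shape is an assumption):

```python
def muon_momentum(step, start=0.92, end=0.99, warmup_steps=500):
    """Linearly warm Muon's momentum from `start` to `end` over `warmup_steps`."""
    if step >= warmup_steps:
        return end
    return start + (end - start) * step / warmup_steps
```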
LR Schedule
warmdown
parameters: {"warmdown_frac":0.5}
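With `warmdown_frac` 0.5, the learning rate holds constant for the first half of training and then decays linearly to zero. A sketch, using the `matrix_lr` of 0.03 from the optimizer block as the base rate:

```python
def warmdown_lr(step, total_steps, base_lr, warmdown_frac=0.5):
    """Constant LR, then a linear decay to 0 over the final warmdown_frac of steps."""
    warmdown_start = int(total_steps * (1 - warmdown_frac))
    if step < warmdown_start:
        return base_lr
    return base_lr * (total_steps - step) / (total_steps - warmdown_start)

lr_mid = warmdown_lr(750, 1000, 0.03)  # halfway through warmdown -> half the base LR
```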
Regularization
weight decay
parameters: {"value":0.04}
Sequence Length
sequence_length
train_length: 1024
eval_length: null
Compression
zlib
level: null

Novel Contributions

  • 5-layer MLP4x architecture that outperformed deeper, narrower models under single-GPU compute constraints
  • BigramHash embedding with 4096 buckets for lightweight bigram context
  • SmearGate token blending mechanism
  • Orthogonal initialization aligned with Muon optimizer geometry
  • Quantization-aware training with int8 STE to reduce quantization gap
  • Stochastic weight averaging over 18 checkpoints
  • Sliding window evaluation with stride 64 for improved validation score