PR #331 (open)

10L MLP3x + BigramHash(2048) + SWA + Stride-32: 1.1487 BPB

by RhodriumView on GitHub
val_bpb: 1.1487
Architecture: Transformer
Optimizer: Muon
Artifact Size: 14.9 MB

Training Techniques

Architecture
MLP3x
10-layer transformer whose relu² MLP is expanded to 3x the model width.
parameters: {"layers":10,"hidden":1536}
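A minimal sketch of the MLP block, assuming "relu²" means squaring the ReLU output and "3x" means d_ff = 3 * d_model (the PR does not spell out either detail); plain Python lists stand in for tensors:

```python
def relu2(x):
    """Squared ReLU activation: relu(x) ** 2."""
    return max(x, 0.0) ** 2

def mlp3x_forward(x, w_in, w_out):
    """One MLP block with d_ff = 3 * d_model (the '3x' expansion).
    x: list of d_model floats; w_in: d_ff rows of d_model weights;
    w_out: d_model rows of d_ff weights."""
    h = [relu2(sum(w * xi for w, xi in zip(row, x))) for row in w_in]
    return [sum(w * hi for w, hi in zip(row, h)) for row in w_out]
```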
BigramHash
Hashed (previous-token, current-token) bigram embedding table added to the input features for n-gram information.
parameters: {"vocab":2048,"dim":128}
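A sketch of how such a feature could work: hash each (previous, current) token pair into one of 2048 buckets and add that bucket's 128-dim embedding to the token embedding. The multiplier prime and the XOR mixing are assumptions; the PR does not state the hash:

```python
def bigram_bucket(prev_tok, cur_tok, n_buckets=2048):
    """Hash a (previous, current) token pair into one of n_buckets.
    The constant 1000003 is an arbitrary mixing prime (assumed)."""
    return ((prev_tok * 1000003) ^ cur_tok) % n_buckets

def add_bigram_features(token_embs, tokens, bigram_table):
    """Add the hashed-bigram embedding (dim 128 in the PR) to each position's
    token embedding; position 0 has no previous token and is left unchanged."""
    out = [token_embs[0][:]]
    for t in range(1, len(tokens)):
        b = bigram_bucket(tokens[t - 1], tokens[t], len(bigram_table))
        out.append([e + f for e, f in zip(token_embs[t], bigram_table[b])])
    return out
```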
SmearGate
Learnable previous-token blending mechanism.
parameters: null
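One plausible reading of "learnable previous-token blending", sketched with a single scalar gate (per-channel gates are equally plausible; the PR reports no parameters):

```python
import math

def smear_gate(embs, gate_logit):
    """Blend each position with the previous position's embedding:
    out[t] = (1 - g) * embs[t] + g * embs[t - 1], with g = sigmoid(gate_logit).
    gate_logit would be a learned parameter during training."""
    g = 1.0 / (1.0 + math.exp(-gate_logit))
    out = [embs[0][:]]
    for t in range(1, len(embs)):
        out.append([(1 - g) * c + g * p for c, p in zip(embs[t], embs[t - 1])])
    return out
```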
weight tying
Input and output embeddings are tied.
parameters: null
KV head count
Uses grouped-query attention with fewer KV heads than attention heads.
parameters: {"heads":8,"kv_heads":4}
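With 8 query heads and 4 KV heads, each KV head is shared by a group of consecutive query heads, halving the KV cache relative to full multi-head attention. The group mapping is standard GQA bookkeeping:

```python
def kv_head_for(q_head, n_heads=8, n_kv_heads=4):
    """Map a query head index to the KV head it shares: with 8 query heads
    and 4 KV heads, each KV head serves a group of 2 consecutive query heads."""
    group_size = n_heads // n_kv_heads
    return q_head // group_size
```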
Initialization
OrthoInit
Orthogonal initialization with scaled projections.
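A pure-Python sketch of orthogonal initialization via Gram-Schmidt on Gaussian rows, with a projection scale applied afterwards; the PR's exact scaling rule is not given, so `scale` here is a free parameter:

```python
import random

def ortho_init(n, scale=1.0, seed=0):
    """Return an n x n matrix whose rows are orthogonal with norm `scale`
    (Gram-Schmidt on random Gaussian rows; a sketch, not the PR's exact init)."""
    rng = random.Random(seed)
    rows = []
    while len(rows) < n:
        v = [rng.gauss(0.0, 1.0) for _ in range(n)]
        for u in rows:
            d = sum(a * b for a, b in zip(v, u))
            v = [a - d * b for a, b in zip(v, u)]
        norm = sum(a * a for a in v) ** 0.5
        if norm > 1e-8:  # retry in the (rare) degenerate case
            rows.append([a / norm for a in v])
    return [[scale * a for a in row] for row in rows]
```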
Quantization
mixed int5/int6
bits: null
scope: MLP int5, attention int6, embeddings fp16
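A sketch of symmetric per-tensor quantization at the stated widths (int5 for MLP weights, int6 for attention), assuming a plain max-abs scale; the PR does not specify the calibration scheme:

```python
def quantize(weights, bits):
    """Symmetric quantization: scale = max|w| / (2**(bits-1) - 1), then round
    onto the signed grid (int5 -> [-16, 15], int6 -> [-32, 31])."""
    qmax = 2 ** (bits - 1) - 1
    scale = max((abs(w) for w in weights), default=0.0) / qmax
    if scale == 0.0:
        scale = 1.0  # all-zero tensor: any scale reconstructs it exactly
    q = [max(-qmax - 1, min(qmax, round(w / scale))) for w in weights]
    return q, scale

def dequantize(q, scale):
    """Reconstruct approximate float weights from the integer grid."""
    return [qi * scale for qi in q]
```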
Compression
zstd
level: 22
Weight Averaging
SWA
parameters: {"checkpoints_averaged":24}
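SWA here is a uniform average of the last 24 checkpoints' parameters. Flattening each checkpoint to a parameter vector, the averaging step is:

```python
def average_checkpoints(checkpoints):
    """Uniform average of parameter vectors from the last N checkpoints
    (N = 24 in this PR); each checkpoint is a flat list of floats."""
    n = len(checkpoints)
    return [sum(vals) / n for vals in zip(*checkpoints)]
```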
Evaluation
sliding window eval
parameters: {"stride":32,"context_length":2048}
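Stride-32 sliding-window evaluation scores only the last 32 tokens of each 2048-token window, so every token after the first window is predicted with nearly full context (at the cost of ~64x more forward passes than stride-2048). A sketch of the window schedule, assuming this standard scheme:

```python
def sliding_windows(n_tokens, context_length=2048, stride=32):
    """Yield (start, end, n_scored): each window spans tokens [start, end) and
    only its final n_scored tokens contribute to the loss."""
    pos = 0
    while pos < n_tokens:
        end = min(pos + stride, n_tokens)
        start = max(0, end - context_length)
        yield (start, end, end - pos)
        pos = end
```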
Optimizer
Muon
weight_decay: 0.04
momentum: null
other_params: {"adamw_for_embeddings_scalars":true}
Sequence Length
sequence_length
train_length: 2048
eval_length: 2048
Regularization
weight decay
parameters: {"value":0.04}

Novel Contributions

  • 10-layer relu² MLP3x transformer
  • BigramHash(2048) with SmearGate
  • Orthogonal initialization
  • Mixed int5/int6 quantization with zstd-22 compression
  • SWA averaging over late checkpoints
  • Stride-32 dense sliding-window evaluation