PR #515

open

Record: Atris Labs — 3-seed mean val_bpb=1.1807, 10L MLP3x Int5/Int6 BigramHash SmearGate SWA

by keshav55
val_bpb
1.1807
Architecture
Transformer
Optimizer
Muon
Artifact Size
14.6MB

Training Techniques

Architecture
MLP3x
MLP with 3x hidden dimension (1536 hidden), 10 transformer layers, 512 dim, GQA (8/4 heads)
parameters: {"layers":10,"mlp_hidden":1536,"embedding_dim":512,"GQA_heads":"8/4"}
BigramHash
XOR hash of token pairs to 128-dim embedding with projection
parameters: {"hash_size":10240,"embedding_dim":128}
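A minimal sketch of the BigramHash idea, using the parameters above (hash_size=10240, embedding_dim=128, projected to the 512-dim model width). The function and variable names here are illustrative, not from the PR:

```python
import numpy as np

HASH_SIZE, BIGRAM_DIM, MODEL_DIM = 10240, 128, 512

rng = np.random.default_rng(0)
bigram_table = rng.normal(0, 0.02, (HASH_SIZE, BIGRAM_DIM))  # learned in practice
W_proj = rng.normal(0, 0.02, (BIGRAM_DIM, MODEL_DIM))        # learned in practice

def bigram_hash_embed(tokens):
    """tokens: 1-D int array; returns (len, MODEL_DIM) bigram features."""
    prev = np.concatenate([[0], tokens[:-1]])  # previous token, pad position 0
    idx = (tokens ^ prev) % HASH_SIZE          # XOR the pair, bucket into table
    return bigram_table[idx] @ W_proj          # 128-dim lookup + projection

feats = bigram_hash_embed(np.array([5, 17, 17, 942]))
print(feats.shape)  # (4, 512)
```

The XOR-then-mod bucketing keeps the table small at the cost of hash collisions, which the model can tolerate since the feature is additive to the regular token embedding.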
SmearGate
Per-dimension learned gate blending current and previous token
parameters: null
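A sketch of the SmearGate mechanism as described: a learned per-dimension gate in (0, 1) mixes each position's hidden state with the previous position's, "smearing" tokens backward. The parameter name `gate_logits` and the sigmoid parameterization are assumptions:

```python
import numpy as np

MODEL_DIM = 512
gate_logits = np.zeros(MODEL_DIM)  # learned in practice; zeros give gate = 0.5

def smear_gate(x):
    """x: (seq, dim). Blend x[t] with x[t-1], one gate value per dimension."""
    g = 1.0 / (1.0 + np.exp(-gate_logits))         # sigmoid -> (0, 1)
    prev = np.concatenate([x[:1], x[:-1]], axis=0) # x[t-1]; position 0 reuses x[0]
    return g * x + (1.0 - g) * prev

x = np.random.default_rng(1).normal(size=(8, MODEL_DIM))
y = smear_gate(x)
print(y.shape)  # (8, 512)
```

Because position 0 blends with itself, the operation is shape-preserving and strictly causal, so it composes with the SWA-windowed attention without leaking future tokens.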
Optimizer
Muon
weight_decay: 0.04
momentum: 0.99
other_params: {"lr":0.02}
AdamW
weight_decay: 0.01
momentum: null
other_params: {"tied_embed_lr":0.03}
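The Muon settings above (lr=0.02, momentum=0.99, weight_decay=0.04) can be sketched as follows: momentum SGD whose 2-D update matrix is approximately orthogonalized by a few Newton-Schulz iterations. The quintic coefficients are the commonly published Muon values; this is an illustrative sketch, not the PR's exact code:

```python
import numpy as np

def newton_schulz(G, steps=5):
    """Approximately orthogonalize G via the Muon quintic iteration."""
    a, b, c = 3.4445, -4.7750, 2.0315
    X = G / (np.linalg.norm(G) + 1e-7)   # normalize so singular values <= 1
    transposed = X.shape[0] > X.shape[1]
    if transposed:
        X = X.T
    for _ in range(steps):
        A = X @ X.T
        X = a * X + (b * A + c * (A @ A)) @ X
    return X.T if transposed else X

def muon_step(W, grad, buf, lr=0.02, momentum=0.99, wd=0.04):
    buf = momentum * buf + grad          # momentum accumulation
    update = newton_schulz(buf)          # orthogonalized update direction
    W = W * (1 - lr * wd) - lr * update  # decoupled weight decay, then step
    return W, buf

rng = np.random.default_rng(7)
W, grad = rng.normal(size=(64, 32)), rng.normal(size=(64, 32))
W2, buf = muon_step(W, grad, np.zeros_like(W))
print(W2.shape)  # (64, 32)
```

Per the record, Muon handles the 2-D transformer matrices while AdamW (weight_decay=0.01) covers the tied embeddings at their own learning rate of 0.03.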
Weight Averaging
SWA
parameters: {"checkpoints_averaged":24,"phase":"warmdown"}
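The SWA step reported here averages weights over the 24 checkpoints saved during the LR warmdown phase. A minimal sketch, assuming state dicts of the form {param_name: ndarray}:

```python
import numpy as np

def swa_average(checkpoints):
    """Uniformly average a list of {name: ndarray} state dicts."""
    n = len(checkpoints)
    avg = {k: v.astype(np.float64).copy() for k, v in checkpoints[0].items()}
    for ckpt in checkpoints[1:]:
        for k, v in ckpt.items():
            avg[k] += v
    return {k: v / n for k, v in avg.items()}

# Toy demo: 24 "checkpoints" whose single weight ramps 0..23 -> mean 11.5.
ckpts = [{"w": np.full(3, float(i))} for i in range(24)]
avg = swa_average(ckpts)
print(avg["w"])  # [11.5 11.5 11.5]
```

Restricting the average to the warmdown phase means all 24 checkpoints sit near the same loss basin, which is what makes uniform averaging safe.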
Quantization
int5/int6
bits: 5 (MLP) / 6 (attention)
scope: MLP weights (int5), attention weights (int6, per-row scale)
fp16
bits: 16
scope: tied embeddings passthrough
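The mixed int5/int6 scheme with per-row scales can be sketched as symmetric round-to-nearest quantization at an arbitrary bit-width; `quantize_rows` and `dequantize` are illustrative names:

```python
import numpy as np

def quantize_rows(W, bits):
    """Symmetric per-row quantization to [-2^(bits-1)+1, 2^(bits-1)-1]."""
    qmax = 2 ** (bits - 1) - 1                       # 15 for int5, 31 for int6
    scale = np.abs(W).max(axis=1, keepdims=True) / qmax
    scale[scale == 0] = 1.0                          # guard all-zero rows
    q = np.clip(np.round(W / scale), -qmax, qmax).astype(np.int8)
    return q, scale

def dequantize(q, scale):
    return q.astype(np.float32) * scale

W = np.random.default_rng(2).normal(size=(4, 16)).astype(np.float32)
q5, s5 = quantize_rows(W, bits=5)                    # MLP-style int5
W_hat = dequantize(q5, s5)
print(np.abs(W - W_hat).max() <= 0.5 * s5.max())     # True: error <= half a step
```

Attention weights would use `bits=6` for a finer grid, while the tied embeddings bypass quantization entirely and ship as fp16.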
Regularization
magnitude pruning
parameters: {"pruning_amount":"3%"}
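A sketch of 3% global magnitude pruning as listed above: zero out the smallest-magnitude 3% of weights, which also makes the artifact more compressible downstream:

```python
import numpy as np

def magnitude_prune(W, amount=0.03):
    """Zero the `amount` fraction of weights with the smallest magnitude."""
    flat = np.abs(W).ravel()
    k = int(len(flat) * amount)
    if k == 0:
        return W
    thresh = np.partition(flat, k)[k]            # (k+1)-th smallest magnitude
    return np.where(np.abs(W) < thresh, 0.0, W)

W = np.random.default_rng(3).normal(size=(100, 100))
Wp = magnitude_prune(W, 0.03)
print((Wp == 0).mean())  # 0.03
```

Whether pruning is applied before or after quantization is not stated in the record; doing it before lets the quantizer spend its range on the surviving weights.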
Compression
zlib
level: null
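The final packaging step presumably serializes the quantized weights and deflates them with zlib to meet the 14.6MB artifact budget. The compression level is unspecified in the record; level 9 below is an assumption:

```python
import zlib
import numpy as np

# Int5-range values stored in int8 bytes waste ~3 bits per weight,
# which zlib's entropy coding largely recovers.
q = np.clip(np.round(np.random.default_rng(4).normal(scale=4, size=100_000)),
            -15, 15).astype(np.int8)
raw = q.tobytes()
packed = zlib.compress(raw, level=9)                 # level 9 is an assumption
restored = np.frombuffer(zlib.decompress(packed), dtype=np.int8)
print(len(packed) < len(raw), np.array_equal(restored, q))  # True True
```

The 3% pruning above helps here too: runs of zeroed weights are cheap for DEFLATE to encode.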
Sequence Length
sequence_length
train_length: 2048
eval_length: 2048

Novel Contributions

  • Use of BigramHash embedding with XOR hash of token pairs
  • SmearGate: a learned per-dimension gate that blends the current and previous token
  • Mixed precision quantization with Int5 for MLP weights and Int6 for attention weights
  • Use of Muon optimizer with tuned learning rates and momentum
  • SWA weight averaging over 24 checkpoints during warmdown
  • Extended evaluation context to 2048 tokens with RoPE extrapolation
  • Quantization-aware training (QAT) to reduce degradation from low-bit weights
  • 3% magnitude pruning combined with quantization and compression to fit artifact size