PR #366

open

Non-record: 10L Int5-MLP + TTT + Backout Connection (val_bpb=1.1574 on 8xH100 SXM)

by shivnarainms22
val_bpb: 1.1574
Architecture: Transformer
Optimizer: Muon
Artifact Size: 15.5 MB

Training Techniques

Architecture
SmearGate
Adds SmearGate to the model architecture.
parameters: null
BigramHash
Uses a BigramHash module for token/context representation.
parameters: {"size":10240,"dim":128}
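A minimal sketch of a BigramHash-style lookup: the (previous, current) token pair is hashed into a fixed table of learned vectors. The table size 10240 and dim 128 come from the PR's parameters; the hash mixing constants and the sentinel id for position 0 are illustrative assumptions, not the PR's actual implementation.

```python
import random

TABLE_SIZE = 10240   # from the PR's parameters
DIM = 128            # from the PR's parameters

random.seed(0)
# Learned embedding table, here filled with small random values.
table = [[random.gauss(0.0, 0.02) for _ in range(DIM)] for _ in range(TABLE_SIZE)]

def bigram_hash(prev_tok: int, cur_tok: int) -> int:
    # Multiplicative mixing hash (constants are an illustrative choice).
    h = (prev_tok * 1000003 + cur_tok) * 2654435761
    return h % TABLE_SIZE

def bigram_embedding(tokens: list[int]) -> list[list[float]]:
    out = []
    prev = 0  # sentinel id for "no previous token" (assumption)
    for t in tokens:
        out.append(table[bigram_hash(prev, t)])
        prev = t
    return out

embs = bigram_embedding([5, 17, 17, 942])
```

Each position gets a context-dependent vector at the cost of one hash and one table lookup, with collisions absorbed by training.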
MLP3x
Uses a 3x expansion MLP with relu^2 activation.
parameters: {"hidden_dim":1536}
weight tying
Uses tied embeddings.
parameters: null
KV head count
Uses grouped-query attention with 8 attention heads and 4 KV heads.
parameters: {"heads":8,"kv_heads":4}
Quantization
mixed int5/int6
bits: 5, scope: MLP
bits: 6, scope: attention
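A minimal sketch of per-tensor integer quantization at the two bit widths listed above (int5 for MLP weights, int6 for attention). The symmetric absmax scheme is an assumption; the PR only states the bit widths and scopes.

```python
def quantize(weights, bits):
    # Symmetric absmax quantization: map the largest magnitude to qmax.
    qmax = 2 ** (bits - 1) - 1                    # 15 for int5, 31 for int6
    scale = max(abs(w) for w in weights) / qmax or 1.0
    q = [max(-qmax - 1, min(qmax, round(w / scale))) for w in weights]
    return q, scale

def dequantize(q, scale):
    return [x * scale for x in q]

w = [0.31, -0.07, 0.002, -0.5]
q5, s5 = quantize(w, bits=5)                      # MLP scope -> 5 bits
w_hat = dequantize(q5, s5)
```

The roundtrip error per weight is bounded by half the scale, which is why the wider int6 grid is reserved for the more sensitive attention weights.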
Compression
zstd
level: 22
Weight Averaging
EMA
parameters: {"decay":0.997,"start_step":50}
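A minimal sketch of the EMA weight averaging above: after start_step=50, shadow weights track the training weights with decay 0.997. The update rule is the standard exponential moving average; the training loop and weights are mocked.

```python
DECAY, START_STEP = 0.997, 50   # from the PR's parameters

def ema_update(shadow, weights, step):
    if step < START_STEP:
        return list(weights)    # before start_step: shadow just mirrors weights
    return [DECAY * s + (1 - DECAY) * w for s, w in zip(shadow, weights)]

shadow = [0.0, 0.0]
for step in range(100):
    weights = [1.0, -1.0]       # stand-in for the real parameters at this step
    shadow = ema_update(shadow, weights, step)
```

At decay 0.997 the shadow averages over roughly the last 1/(1-0.997) ≈ 333 steps, smoothing late-training noise without the separate checkpoint schedule SWA needs.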
Optimizer
Muon
weight_decay: 0.04
momentum: 0.99
other_params: {"matrix_lr":0.02,"warmup_momentum":0.92}
Evaluation
sliding window eval
parameters: {"stride":64}
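A minimal sketch of sliding-window evaluation with stride=64: after the first window, each window scores only its final 64 tokens so every token is evaluated with substantial left context. The context window length of 256 is an illustrative assumption; the PR only states the stride.

```python
STRIDE, WINDOW = 64, 256   # stride from the PR; window length assumed

def window_spans(n_tokens):
    # Returns (ctx_start, score_start, end) triples that score every token
    # exactly once while re-feeding up to WINDOW tokens of context.
    spans, scored = [], 0
    while scored < n_tokens:
        end = min(scored + (WINDOW if scored == 0 else STRIDE), n_tokens)
        ctx_start = max(0, end - WINDOW)
        spans.append((ctx_start, scored, end))
        scored = end
    return spans

spans = window_spans(400)
```

Smaller strides give each scored token more context at the cost of more forward passes, which is the usual trade-off this technique tunes.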
Test-Time Training
full TTT
parameters: {"epochs":3,"learning_rate":0.002,"momentum":0.9,"grad_clip":1,"frozen_blocks":2}
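A minimal sketch of the test-time training mechanics above: 3 epochs of momentum SGD (lr 0.002, momentum 0.9), global-norm gradient clipping at 1.0, and the first 2 blocks frozen. The model here is a toy quadratic standing in for the quantized transformer; only the update mechanics follow the PR's parameters.

```python
EPOCHS, LR, MOMENTUM, GRAD_CLIP, FROZEN_BLOCKS = 3, 0.002, 0.9, 1.0, 2

# One scalar "weight" per block; blocks 0 and 1 stay frozen.
weights = [1.0, 1.0, 1.0, 1.0]
velocity = [0.0] * len(weights)

def grads(ws):
    # Toy loss: sum of (w - 0.5)^2 per block -> gradient 2*(w - 0.5).
    return [2.0 * (w - 0.5) for w in ws]

for _ in range(EPOCHS):
    g = grads(weights)
    norm = sum(x * x for x in g) ** 0.5
    if norm > GRAD_CLIP:                         # global-norm clipping
        g = [x * GRAD_CLIP / norm for x in g]
    for i in range(len(weights)):
        if i < FROZEN_BLOCKS:
            continue                             # frozen blocks: no update
        velocity[i] = MOMENTUM * velocity[i] + g[i]
        weights[i] -= LR * velocity[i]
```

Freezing the early blocks limits how far test-time updates can drift the representation while still letting later blocks adapt to the evaluation stream.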
Regularization
magnitude pruning
parameters: {"sparsity":0.03}
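A minimal sketch of magnitude pruning at 3% sparsity: the smallest 3% of weights by absolute value are zeroed. Whether the PR prunes globally or per-layer is not stated; global pruning over a single weight list is an assumption here.

```python
SPARSITY = 0.03   # from the PR's parameters

def magnitude_prune(weights, sparsity=SPARSITY):
    k = int(len(weights) * sparsity)
    if k == 0:
        return list(weights)
    # Threshold at the k-th smallest magnitude (ties may zero a few extra).
    threshold = sorted(abs(w) for w in weights)[k - 1]
    return [0.0 if abs(w) <= threshold else w for w in weights]

pruned = magnitude_prune([i / 100 for i in range(1, 101)])
```

At 3% this acts as light regularization and also helps the zstd stage, since runs of exact zeros compress well.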
Other
Backout connection
Subtracts a learned scalar multiple of the midpoint hidden state from the final representation before RMSNorm.
parameters: {"layer":5,"lambda_init":0.2,"extra_parameters":1}
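A minimal sketch of the backout connection: the hidden state saved at the midpoint (layer 5 of 10) is scaled by a learned scalar initialized to lambda_init=0.2 and subtracted from the final hidden state before the output RMSNorm. The transformer blocks are mocked; only the backout wiring and the RMSNorm follow the description.

```python
LAYER, LAMBDA_INIT = 5, 0.2   # from the PR's parameters

def rmsnorm(x, eps=1e-8):
    rms = (sum(v * v for v in x) / len(x) + eps) ** 0.5
    return [v / rms for v in x]

def forward(x, n_layers=10, lam=LAMBDA_INIT):
    mid = None
    for layer in range(n_layers):
        x = [v + 0.1 for v in x]      # stand-in for a transformer block
        if layer + 1 == LAYER:
            mid = list(x)             # snapshot the midpoint hidden state
    # Backout: subtract the learned scalar multiple of the midpoint state.
    x = [v - lam * m for v, m in zip(x, mid)]
    return rmsnorm(x)

out = forward([1.0, 2.0])
```

With the single extra parameter lam, the model can learn how much of the mid-network representation to "back out" of the final one, at negligible cost.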

Novel Contributions

  • Backout connection at the U-Net midpoint with a learned scalar subtraction
  • Test-time training on validation tokens after quantization roundtrip
  • EMA replacing SWA for weight averaging
  • Mixed int5 MLP / int6 attention quantization with zstd-22 compression
  • SmearGate and BigramHash architectural additions