PR #366 (open)
Non-record: 10L Int5-MLP + TTT + Backout Connection (val_bpb=1.1574 on 8xH100 SXM)
by shivnarainms22
val_bpb
1.1574
Architecture
Transformer
Optimizer
Muon
Artifact Size
15.5MB
Training Techniques
Architecture
SmearGate
Adds SmearGate to the model architecture.
parameters: null
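The card gives no parameters for SmearGate, so the mechanism below is an assumption: a learned sigmoid gate that "smears" each position by mixing in the previous token's activation. The function name, gate projection `w_gate`, and additive form are all hypothetical.

```python
import numpy as np

def smear_gate(x, w_gate):
    """Hypothetical SmearGate: blend each position with its predecessor.

    x: (seq, dim) activations; w_gate: (dim,) learned gate projection
    (both names are illustrative, not from the PR).
    """
    gate = 1.0 / (1.0 + np.exp(-(x @ w_gate)))             # (seq,) in (0, 1)
    prev = np.vstack([np.zeros((1, x.shape[1])), x[:-1]])  # shift right one step
    return x + gate[:, None] * prev
```

Position 0 has no predecessor, so it passes through unchanged.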
BigramHash
Uses a BigramHash module for token/context representation.
parameters: {"size":10240,"dim":128}
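The parameters pin down a 10240-entry table of 128-dim vectors; the hashing scheme itself is not specified, so the multiplicative hash below is an assumption. The idea: each (previous, current) token pair is hashed into a bucket whose learned embedding augments the token representation.

```python
import numpy as np

def bigram_hash_embed(tokens, table, prime=1000003):
    """Look up a learned embedding for each (prev, cur) token bigram.

    table: (10240, 128) per this PR; the hash function is an assumption.
    """
    size = table.shape[0]
    prev = np.concatenate([[0], tokens[:-1]])   # pad position 0 with token 0
    buckets = (prev * prime + tokens) % size
    return table[buckets]
```

Collisions are expected and tolerated: with a 50k-token vocabulary, many bigrams share each of the 10240 buckets.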
MLP3x
Uses a 3x expansion MLP with relu^2 activation.
parameters: {"hidden_dim":1536}
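A minimal sketch of the MLP block: hidden_dim=1536 with a 3x expansion implies a model width of 512 (an inference, not stated in the card). The activation squares the rectified pre-activation.

```python
import numpy as np

def mlp3x(x, w_in, w_out):
    """MLP with 3x expansion (512 -> 1536 -> 512, widths assumed) and relu^2."""
    h = np.maximum(x @ w_in, 0.0) ** 2   # relu^2: squared rectifier
    return h @ w_out
```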
weight tying
Uses tied embeddings.
parameters: null
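Weight tying means one matrix serves as both the input embedding and the output head, saving a vocab-by-dim parameter block. A toy-sized sketch (sizes are illustrative):

```python
import numpy as np

E = np.random.default_rng(0).standard_normal((100, 16))  # (vocab, dim), toy sizes

def embed(tokens):
    return E[tokens]

def logits(h):
    return h @ E.T   # tied: the output projection reuses the embedding matrix
```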
KV head count
Uses grouped-query attention with 8 attention heads and 4 KV heads.
parameters: {"heads":8,"kv_heads":4}
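With 8 attention heads and 4 KV heads, each KV head serves two query heads. A self-contained sketch of that grouping (head dimension and the repeat-based sharing are standard GQA mechanics, not repo code):

```python
import numpy as np

def gqa(q, k, v):
    """Grouped-query attention: 8 query heads share 4 KV heads (2:1 grouping).

    q: (seq, 8, hd); k, v: (seq, 4, hd).
    """
    rep = q.shape[1] // k.shape[1]
    k = np.repeat(k, rep, axis=1)   # each KV head is reused by 2 query heads
    v = np.repeat(v, rep, axis=1)
    s = np.einsum('qhd,khd->hqk', q, k) / np.sqrt(q.shape[-1])
    causal = np.triu(np.ones(s.shape[1:], dtype=bool), 1)
    s = np.where(causal, -np.inf, s)            # mask future positions
    w = np.exp(s - s.max(-1, keepdims=True))
    w = w / w.sum(-1, keepdims=True)
    return np.einsum('hqk,khd->qhd', w, v)
```

Halving the KV heads halves the KV cache and the K/V projection parameters, which also shrinks the quantized artifact.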
Quantization
mixed int5/int6
bits: 5, scope: MLP
bits: 6, scope: attention
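A sketch of the quantizer, assuming symmetric per-tensor scaling (the card gives only the bit widths; per-channel scaling or another scheme is equally possible). The quantized tensors are what gets zstd-compressed into the 15.5MB artifact.

```python
import numpy as np

def quantize(w, bits):
    """Symmetric quantization: bits=5 for MLP weights, 6 for attention."""
    qmax = 2 ** (bits - 1) - 1            # 15 for int5, 31 for int6
    scale = np.abs(w).max() / qmax
    q = np.clip(np.round(w / scale), -qmax, qmax).astype(np.int8)
    return q, scale

def dequantize(q, scale):
    return q.astype(np.float32) * scale
```

The roundtrip error per weight is at most half a quantization step (scale / 2).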
Compression
zstd
level: 22
Weight Averaging
EMA
parameters: {"decay":0.997,"start_step":50}
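The EMA update with these parameters, sketched as a per-step function over a dict of weights (the dict layout is illustrative):

```python
def ema_update(ema, w, step, decay=0.997, start_step=50):
    """EMA of model weights: mirror until start_step, then blend.

    decay=0.997 gives an averaging horizon of roughly 1/(1-0.997) ~ 333 steps.
    """
    if step < start_step:
        return dict(w)
    return {k: decay * ema[k] + (1 - decay) * w[k] for k in w}
```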
Optimizer
Muon
weight_decay: 0.04
momentum: 0.99
other_params: {"matrix_lr":0.02,"warmup_momentum":0.92}
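Muon's core operation is an approximate orthogonalization of each 2-D weight gradient via a Newton-Schulz iteration; matrix_lr=0.02 is the learning rate applied to these orthogonalized updates. A sketch following the coefficients used in the public Muon reference code (the iteration count and exact coefficients may differ in this repo):

```python
import numpy as np

def newton_schulz(G, steps=5):
    """Approximately orthogonalize a 2-D gradient, as Muon does per matrix."""
    a, b, c = 3.4445, -4.7750, 2.0315      # quintic iteration coefficients
    X = G / (np.linalg.norm(G) + 1e-7)     # normalize so singular values <= 1
    flip = X.shape[0] > X.shape[1]
    if flip:
        X = X.T
    for _ in range(steps):
        A = X @ X.T
        X = a * X + (b * A + c * (A @ A)) @ X
    return X.T if flip else X
```

After a few iterations the singular values of the update cluster near 1, so every direction in the gradient contributes at a similar magnitude.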
Evaluation
sliding window eval
parameters: {"stride":64}
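With stride 64, each evaluation window scores only its last 64 tokens but conditions on the full preceding context, so no token is scored with a cold context. A sketch of the span bookkeeping (the window length of 256 is an assumption; the PR states only the stride):

```python
def eval_spans(n_tokens, context=256, stride=64):
    """Spans for sliding-window eval: (ctx_start, score_start, end) triples.

    Tokens in [score_start, end) are scored, conditioned on [ctx_start, end).
    """
    spans, begin = [], 0
    while begin < n_tokens:
        end = min(begin + stride, n_tokens)
        spans.append((max(0, end - context), begin, end))
        begin = end
    return spans
```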
Test-Time Training
full TTT
parameters: {"epochs":3,"learning_rate":0.002,"momentum":0.9,"grad_clip":1,"frozen_blocks":2}
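The TTT parameters describe plain momentum SGD over the validation tokens for 3 epochs, with global-norm gradient clipping at 1 and the first two blocks frozen. A sketch of one update step under that reading (parameter names are illustrative, not the repo's):

```python
import numpy as np

def ttt_step(params, grads, vel, lr=2e-3, mu=0.9, clip=1.0,
             frozen=frozenset({"block0", "block1"})):
    """One momentum-SGD step of test-time training.

    Gradients are scaled down if their global norm exceeds `clip`;
    frozen blocks are skipped entirely.
    """
    gnorm = np.sqrt(sum(float((g ** 2).sum()) for g in grads.values()))
    scale = min(1.0, clip / (gnorm + 1e-8))
    for name, p in params.items():
        if name in frozen:
            continue
        vel[name] = mu * vel[name] + scale * grads[name]
        p -= lr * vel[name]
    return params, vel
```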
Regularization
magnitude pruning
parameters: {"sparsity":0.03}
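Magnitude pruning at 3% sparsity zeroes the smallest-magnitude 3% of entries, which costs little accuracy and makes the quantized tensors compress better. A per-tensor sketch:

```python
import numpy as np

def magnitude_prune(w, sparsity=0.03):
    """Zero out the smallest-magnitude `sparsity` fraction of a weight tensor."""
    k = int(round(sparsity * w.size))
    if k == 0:
        return w.copy()
    cutoff = np.partition(np.abs(w).ravel(), k - 1)[k - 1]  # k-th smallest |w|
    return np.where(np.abs(w) <= cutoff, 0.0, w)
```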
Other
other
Backout connection that subtracts a learned scalar multiple of the midpoint hidden state from the final representation before RMSNorm.
parameters: {"layer":5,"lambda_init":0.2,"extra_parameters":1}
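The backout connection as described: the layer-5 (U-Net midpoint) hidden state, scaled by a single learned scalar initialized at 0.2, is subtracted from the final representation just before the final RMSNorm. A sketch of that composition:

```python
import numpy as np

def backout(h_final, h_mid, lam=0.2):
    """Subtract a learned scalar multiple of the midpoint hidden state."""
    return h_final - lam * h_mid

def rmsnorm(x, eps=1e-6):
    return x / np.sqrt((x ** 2).mean(axis=-1, keepdims=True) + eps)

# final representation = rmsnorm(backout(h_layer10, h_layer5, lam))
```

At lam=0 this reduces to the baseline; the single extra parameter matches "extra_parameters": 1 in the card.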
Novel Contributions
- Backout connection at the U-Net midpoint with a learned scalar subtraction
- Test-time training on validation tokens after quantization roundtrip
- EMA replacing SWA for weight averaging
- Mixed int5 MLP / int6 attention quantization with zstd-22 compression
- SmearGate and BigramHash architectural additions