PR #538

open

FP8 + Arithmetic Coding + SWA (1.1511 BPB)

by cruz-andr
val_bpb
1.1511
Architecture
Transformer
Optimizer
Muon
Artifact Size

Training Techniques

Quantization
fp8
bits: 8
scope: all
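For context on what "fp8, scope: all" implies numerically, here is a rough pure-Python simulation of E4M3 rounding (4 exponent bits, 3 mantissa bits, bias 7, max normal 448). This is my own illustration, not code from the PR; hardware rounding modes are only approximated:

```python
import math

def quantize_e4m3(x: float) -> float:
    """Round a float to the nearest FP8 E4M3-representable value.

    E4M3: 3 mantissa bits, exponents down to 2**-6 (subnormal spacing
    2**-9), values saturating at 448. A sketch, not a bit-exact model.
    """
    if x == 0.0:
        return 0.0
    sign = -1.0 if x < 0 else 1.0
    mag = min(abs(x), 448.0)            # saturate at the E4M3 max
    e = max(math.floor(math.log2(mag)), -6)   # clamp into subnormal range
    step = 2.0 ** (e - 3)               # 3 mantissa bits -> step 2**(e-3)
    return sign * round(mag / step) * step
```

The coarse mantissa is why E4M3 is used for the forward pass (more precision near 1) while the wider-range E5M2 handles gradients.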
Architecture
SmearGate + BigramHash
SmearGate gating mechanism and BigramHash embedding with vocab size 10240 and dim 128
parameters: {"layers":10,"dimensions":512,"mlp_multiplier":3,"bigram_vocab_size":10240,"bigram_dim":128,"heads":8,"kv_heads":4}
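The PR does not include the SmearGate/BigramHash implementation here, but the described shapes (bigram vocab 10240, dim 128) suggest something like the following sketch. The hash constant, the additive blend, and the fixed gate are my assumptions; in the model the gate would be learned:

```python
BIGRAM_VOCAB, BIGRAM_DIM = 10240, 128

def bigram_bucket(prev_tok: int, cur_tok: int) -> int:
    # Hash the (previous, current) token pair into the bigram vocab;
    # the multiplier is an arbitrary large prime (assumption, not from the PR).
    return ((prev_tok * 1000003) ^ cur_tok) % BIGRAM_VOCAB

def smear(embs, gate=0.5):
    # One plausible reading of a "smear" gate: blend each position's
    # embedding with its predecessor's. The scalar gate here stands in
    # for a learned, per-position gate.
    out = [list(embs[0])]
    for prev, cur in zip(embs, embs[1:]):
        out.append([c + gate * p for p, c in zip(prev, cur)])
    return out
```

The bigram table then maps each hashed bucket to a 128-dim vector that is concatenated with (or added to) the usual token embedding.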
Optimizer
Muon
weight_decay: 0.04
momentum: null
other_params: null
Weight Averaging
SWA
parameters: {"start_step":4500,"checkpoint_interval":50}
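The SWA schedule above (start at step 4500, snapshot every 50 steps) amounts to a running average over late checkpoints. A minimal sketch, with a scalar dict standing in for the model's parameters:

```python
def swa_update(avg, new, n_models):
    """Running average over checkpoints: avg <- (n*avg + new) / (n + 1)."""
    return {k: (avg[k] * n_models + new[k]) / (n_models + 1) for k in avg}

# Toy loop: snapshot every 50 steps from step 4500 onward.
swa_state, n = None, 0
for step in range(4000, 5001):
    weights = {"w": float(step)}        # stand-in for real parameters
    if step >= 4500 and step % 50 == 0:
        swa_state = dict(weights) if swa_state is None else swa_update(swa_state, weights, n)
        n += 1
```

Starting at 4500 rather than later means more snapshots fall inside the learning-rate warmdown, where iterates orbit a good basin and averaging helps most.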
Compression
custom
level: null
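The contributions list identifies this "custom" compressor as a pure Python 32-bit integer arithmetic coder driven by per-tensor empirical histograms. A minimal coder of that family is sketched below (names and table layout are mine, not the PR's); `cum` is the cumulative histogram, and `total = cum[-1]` must stay below 2**30 so ranges never collapse:

```python
TOP = 1 << 32
HALF, QUARTER, MASK = TOP >> 1, TOP >> 2, TOP - 1

def ac_encode(symbols, cum, total):
    lo, hi, pending, bits = 0, MASK, 0, []
    def emit(b):
        nonlocal pending
        bits.append(b)
        bits.extend([1 - b] * pending)  # flush deferred near-miss bits
        pending = 0
    for s in symbols:
        span = hi - lo + 1
        hi = lo + span * cum[s + 1] // total - 1
        lo = lo + span * cum[s] // total
        while True:                     # renormalize the 32-bit window
            if hi < HALF:
                emit(0)
            elif lo >= HALF:
                emit(1); lo -= HALF; hi -= HALF
            elif lo >= QUARTER and hi < HALF + QUARTER:
                pending += 1; lo -= QUARTER; hi -= QUARTER
            else:
                break
            lo, hi = lo << 1, (hi << 1) | 1
    pending += 1
    emit(1 if lo >= QUARTER else 0)     # disambiguate the final interval
    return bits

def ac_decode(bits, n, cum, total):
    it = iter(bits + [0] * (32 + 32 * n))   # zero-pad past the last bit
    code = 0
    for _ in range(32):
        code = (code << 1) | next(it)
    lo, hi, out = 0, MASK, []
    for _ in range(n):
        span = hi - lo + 1
        val = ((code - lo + 1) * total - 1) // span
        s = 0
        while cum[s + 1] <= val:
            s += 1
        out.append(s)
        hi = lo + span * cum[s + 1] // total - 1
        lo = lo + span * cum[s] // total
        while True:                     # mirror the encoder's renormalization
            if hi < HALF:
                pass
            elif lo >= HALF:
                lo -= HALF; hi -= HALF; code -= HALF
            elif lo >= QUARTER and hi < HALF + QUARTER:
                lo -= QUARTER; hi -= QUARTER; code -= QUARTER
            else:
                break
            lo, hi = lo << 1, (hi << 1) | 1
            code = (code << 1) | next(it)
    return out
```

With an exact static histogram the output length approaches the Shannon bound sum(c * log2(total / c)) bits, which is what lets it beat a generic byte-oriented compressor like zstd on quantized weight tensors.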
Test-Time Training
LoRA TTT
parameters: null
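No LoRA TTT parameters are given, but the general mechanism is standard: freeze the base weight and train only a low-rank delta at test time. A generic sketch (dimensions and `alpha` are illustrative, not from the PR):

```python
import numpy as np

def lora_forward(x, W, A, B, alpha=1.0):
    """y = x @ (W + alpha * A @ B).

    W (d_in x d_out) is frozen; only the low-rank factors A (d_in x r)
    and B (r x d_out) are updated during test-time training, so the
    trainable state scales with the rank r rather than the full weight.
    """
    return x @ W + alpha * (x @ A) @ B
```

Computing `(x @ A) @ B` instead of materializing `A @ B` keeps the extra cost proportional to the rank.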
Initialization
OrthoInit
Orthogonal initialization with muP-scaled output projections
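A common recipe matching this description is QR-based orthogonal initialization, with the output projection's gain shrunk with fan-in in muP style. The exact scaling rule is not stated in the PR; the `1/fan_in`-flavored gain below is an assumption:

```python
import numpy as np

def orthogonal_init(fan_out, fan_in, gain=1.0, rng=None):
    """Orthogonal init: QR-decompose a Gaussian matrix and keep Q,
    sign-corrected so the result is uniform over orthogonal frames."""
    rng = rng if rng is not None else np.random.default_rng()
    rows, cols = max(fan_out, fan_in), min(fan_out, fan_in)
    q, r = np.linalg.qr(rng.standard_normal((rows, cols)))
    q = q * np.sign(np.diag(r))         # fix the sign ambiguity of QR
    if fan_out < fan_in:
        q = q.T
    return gain * q

# Hypothetical muP-style readout scaling: shrink the output projection's
# gain with the model width (512 here) so logits stay O(1) as width grows.
w_out = orthogonal_init(768, 512, gain=1.0 / 512)
```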
Other
other
TF32 matmul precision for non-FP8 operations
parameters: null
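The PR does not show its exact flags for this, but the standard PyTorch switches for TF32 matmul precision look like the following config fragment:

```python
import torch

# Allow TF32 tensor cores for float32 matmuls outside the FP8 path
# (Ampere and newer GPUs); "high" trades the last ~13 mantissa bits
# of fp32 for roughly 8x matmul throughput.
torch.set_float32_matmul_precision("high")
torch.backends.cudnn.allow_tf32 = True
```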

Novel Contributions

  • FP8 training using TransformerEngine with hybrid E4M3 forward and E5M2 backward formats for increased throughput
  • Custom pure Python 32-bit integer arithmetic coder exploiting per-tensor empirical histograms to approach Shannon entropy, replacing zstd-22
  • Early start of Stochastic Weight Averaging (step 4500, checkpoints every 50 steps) so more checkpoints are averaged during the learning-rate warmdown
  • Use of TF32 precision for matmul operations outside FP8
  • Architecture modifications including SmearGate and BigramHash embeddings with large vocab size
  • FP8Linear wrapper isolating Muon optimizer from TransformerEngine's internal weight caches
  • Custom binary format eliminating torch.save pickle overhead
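On the last bullet: `torch.save` wraps tensors in a pickle stream, which adds framing overhead and forces a pickle parse on load. A minimal flat binary layout of the kind described might look like this (the field order and tags are my own sketch, not the PR's actual format):

```python
import struct

def save_tensors(path, tensors):
    """Write {name: (shape, raw_bytes)} as a flat binary file:
    [count] then per tensor [name_len][name][ndim][shape...][byte_len][bytes].
    Little-endian throughout; no pickle involved."""
    with open(path, "wb") as f:
        f.write(struct.pack("<I", len(tensors)))
        for name, (shape, raw) in tensors.items():
            nb = name.encode()
            f.write(struct.pack("<H", len(nb)))
            f.write(nb)
            f.write(struct.pack("<B", len(shape)))
            f.write(struct.pack(f"<{len(shape)}I", *shape))
            f.write(struct.pack("<Q", len(raw)))
            f.write(raw)

def load_tensors(path):
    """Inverse of save_tensors: returns {name: (shape, raw_bytes)}."""
    with open(path, "rb") as f:
        (count,) = struct.unpack("<I", f.read(4))
        out = {}
        for _ in range(count):
            (nlen,) = struct.unpack("<H", f.read(2))
            name = f.read(nlen).decode()
            (ndim,) = struct.unpack("<B", f.read(1))
            shape = struct.unpack(f"<{ndim}I", f.read(4 * ndim))
            (blen,) = struct.unpack("<Q", f.read(8))
            out[name] = (shape, f.read(blen))
        return out
```

The raw bytes can come straight from a tensor's storage, so loading is a read plus a reshape with no deserialization pass.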