PR #640


Record Submission: 1.1570 BPB - 73.7M Ternary U-Net + NeoMuon + 4x relu²MLP + Factored Tied Emb + Poly5 Softcap + YaRN2048 + 8192BPE + FP8QAT + Bitmask-LZMA + Stride-16 Sliding

by CiprianFlorin-Ifrim
val_bpb
1.1570
Architecture
Ternary U-Net Transformer
Optimizer
NeoMuon
Artifact Size
15.99 MB

Training Techniques

Quantization
BitNet b1.58 ternary quantisation with FP8 QAT
bits: 1
scope: weights ternary {-1,0,+1}, with FP8 QAT for the remaining floating-point parameters
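For reference, a minimal numpy sketch of b1.58-style ternary quantisation with per-group absmean scaling (the group size of 64 is an assumption; the submission does not state it):

```python
import numpy as np

def ternary_quantize(w, group_size=64):
    """Quantize weights to {-1, 0, +1} with per-group absmean scaling."""
    flat = w.reshape(-1, group_size)
    scale = np.abs(flat).mean(axis=1, keepdims=True) + 1e-8  # absmean per group
    q = np.clip(np.round(flat / scale), -1.0, 1.0)           # ternary codes
    return q.reshape(w.shape), scale

def ternary_dequantize(q, scale, group_size=64):
    """Reconstruct approximate weights from ternary codes and per-group scales."""
    return (q.reshape(-1, group_size) * scale).reshape(q.shape)
```

In QAT the dequantized weights are used in the forward pass while gradients flow to the latent full-precision weights through the straight-through estimator.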
Architecture
U-Net encoder/decoder
U-Net encoder/decoder with learned skip weights (ones-init) and per-block residual mix from input embedding
parameters: {"layers":10,"dim":768,"heads":8,"kv_heads":4,"head_dim":96,"MLP_expansion":4,"MLP_hidden":3072,"activation":"relu²","embedding_dim":254,"vocab_size":8192,"positional_encoding":"YaRN max_len=2048 ROPE_BASE=5000"}
Factored tied embedding
8192×254 bottleneck with learned 254-to-768 and 768-to-254 projections
parameters: null
Fused QKV projection
Single TernaryLinear fused QKV projection
parameters: null
FlashAttention-3
Hopper-native kernels for attention
parameters: null
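The factored tied embedding above can be sketched in numpy as follows (the 0.02 init scale is illustrative, not from the submission):

```python
import numpy as np

V, R, D = 8192, 254, 768  # vocab, bottleneck, model dim (from the card)
rng = np.random.default_rng(0)
E = rng.normal(0.0, 0.02, (V, R)).astype(np.float32)       # shared 8192x254 table
W_up = rng.normal(0.0, 0.02, (R, D)).astype(np.float32)    # learned 254 -> 768
W_down = rng.normal(0.0, 0.02, (D, R)).astype(np.float32)  # learned 768 -> 254

def embed(token_ids):
    """Input side: look up the 254-d row, project up to the 768-d model width."""
    return E[token_ids] @ W_up

def logits(hidden):
    """Output side: project down to 254-d, then reuse (tie) E as the unembedding."""
    return (hidden @ W_down) @ E.T
```

Tying through the bottleneck keeps the table at 8192×254 instead of 8192×768 (roughly a 3x reduction in embedding parameters) at the cost of two small projection matrices.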
Optimizer
NeoMuon
weight_decay: 0
momentum: 0.95
other_params: {"backend_steps":3,"momentum_warmup_start":0.85,"momentum_warmup_steps":500,"adam_lr":0.05,"adam_wd":0.05,"matrix_lr":0.04,"scalar_lr":0.02,"tied_embed_lr":0.02}
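The Newton-Schulz backend (backend_steps: 3) approximately orthogonalizes each gradient matrix before the update. A numpy sketch using the quintic coefficients from the public Muon reference implementation (whether NeoMuon uses the same polynomial is an assumption):

```python
import numpy as np

def newton_schulz(G, steps=3, eps=1e-7):
    """Approximately orthogonalize G by iterating an odd quintic polynomial
    that pushes its singular values toward 1."""
    a, b, c = 3.4445, -4.7750, 2.0315  # coefficients from the Muon reference code
    X = G / (np.linalg.norm(G) + eps)  # Frobenius-normalize so singular values <= 1
    tall = X.shape[0] > X.shape[1]
    if tall:                           # iterate on the wide orientation
        X = X.T
    for _ in range(steps):
        A = X @ X.T
        X = a * X + (b * A + c * (A @ A)) @ X
    return X.T if tall else X
```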
Evaluation
sliding window eval
parameters: {"stride":16,"temperature_scaling":0.9,"temperature_grid_points":5}
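The stride-16 sliding-window evaluation re-runs the model every 16 tokens so that each scored token (after the first window) sees nearly a full context. A pure-Python sketch of the span bookkeeping (the 2048-token window is taken from the YaRN max_len, since eval_length is not stated on the card):

```python
def sliding_window_spans(n_tokens, ctx=2048, stride=16):
    """Return (begin, end, score_from) spans: the window advances by `stride`
    and only tokens in [score_from, end) are scored, so every token is scored
    exactly once (the first window scores all of its positions)."""
    spans, begin, scored_upto = [], 0, 0
    while scored_upto < n_tokens:
        end = min(begin + ctx, n_tokens)
        score_from = max(scored_upto, end - stride) if spans else 0
        spans.append((begin, end, score_from))
        scored_upto = end
        begin += stride
    return spans
```

With stride 16 and a 2048 window this costs about 128 forward passes per 2048 tokens, which is why it is an evaluation-time technique only.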
Compression
Base-3 + LZMA
level: 9
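A sketch of the Base-3 + LZMA stage: ternary codes are packed five trits per byte (3^5 = 243 ≤ 256) and the byte stream is LZMA-compressed at level 9. The five-trits-per-byte packing is an assumption about how the base-3 encoding is realized; the card does not spell it out.

```python
import lzma
import numpy as np

def pack_ternary_lzma(q):
    """Map {-1,0,+1} codes to base-3 digits, pack five per byte, LZMA-compress."""
    trits = (q.astype(np.int8).ravel() + 1).astype(np.uint8)  # {-1,0,1} -> {0,1,2}
    pad = (-len(trits)) % 5
    trits = np.concatenate([trits, np.zeros(pad, np.uint8)])
    groups = trits.reshape(-1, 5).astype(np.uint16)
    powers = np.array([81, 27, 9, 3, 1], dtype=np.uint16)
    packed = (groups * powers).sum(axis=1).astype(np.uint8)   # value <= 242
    return lzma.compress(packed.tobytes(), preset=9), q.size

def unpack_ternary_lzma(blob, n):
    """Invert pack_ternary_lzma: decompress, unpack digits, shift back to {-1,0,1}."""
    packed = np.frombuffer(lzma.decompress(blob), dtype=np.uint8).astype(np.uint16)
    digits = []
    for p in (81, 27, 9, 3, 1):
        digits.append(packed // p)
        packed = packed % p
    trits = np.stack(digits, axis=1).ravel()[:n]
    return trits.astype(np.int8) - 1
```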
Regularization
Z-loss regularisation
parameters: {"weight":0.0001}
Sequence Length
sequence_length
train_length: 1024
eval_length: null
LR Schedule
warmdown
parameters: {"warmdown_fraction":0.2}
Other
other
Shrinkage fix to correct ternary zero-fraction scale mismatch, eliminating roundtrip gaps
parameters: null

Novel Contributions

  • Use of BitNet b1.58 ternary quantisation with per-group absmean scaling
  • Integration of the NeoMuon optimizer with 3 Newton-Schulz steps to compensate for ternary STE gradient attenuation
  • 4x relu² MLP expansion with fused gate+up projection
  • U-Net encoder/decoder with learned skip weights and per-block residual mix from input embedding
  • Factored tied embedding with 8192×254 bottleneck and learned projections
  • Polynomial softcap (degree 5, cap=10) with Z-loss regularisation
  • YaRN positional encoding with max_len=2048 and ROPE_BASE=5000
  • Fused QKV projection using single TernaryLinear
  • FlashAttention-3 for faster attention computation
  • Temperature scaling during evaluation with sliding window stride=16
  • Artifact compression using base-3 encoding combined with LZMA achieving 39% reduction over int8+zlib
  • FP8 QAT to halve the size of the floating-point parameters with minimal bpb penalty
  • Shrinkage fix to eliminate roundtrip gaps in ternary quantisation
  • Width over depth design choice (768d/10L) for faster training steps and better performance
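For concreteness, the 4x relu² MLP (768 → 3072 → 768, dimensions from the card) in numpy; the fused projection and ternary weights are omitted here, and the 0.02 init scale is illustrative:

```python
import numpy as np

D, H = 768, 3072  # model dim and 4x MLP hidden size (from the card)
rng = np.random.default_rng(0)
W_in = rng.normal(0.0, 0.02, (D, H)).astype(np.float32)
W_out = rng.normal(0.0, 0.02, (H, D)).astype(np.float32)

def relu2_mlp(x):
    """relu² MLP: square the ReLU activation before the down-projection."""
    h = np.maximum(x @ W_in, 0.0) ** 2  # relu²
    return h @ W_out
```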