PR #578

open

GPTQ + Early QAT + Legal TTT — 3-seed mean val_bpb 1.1215

by newjordan
val_bpb
1.1215
Architecture
11L/512d/8H/4KV/3xMLP (relu²) with U-Net skip connections, Partial RoPE (16/64), XSA last 4 layers, BigramHash(2048), VE128 on layers 9-10, SmearGate, logit softcap 30, tied embeddings
Optimizer
Muon (lr=0.025, WD=0.04, momentum=0.99) for base training; SGD + momentum 0.9 for TTT
Artifact Size
15.56 MB

Training Techniques

Quantization
GPTQ with early QAT
bits: 6
scope: all weights (per-row int6 quantization with Hessian-aware error compensation)
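A minimal sketch of a per-row int6 quantizer with error feedback. As a toy stand-in for GPTQ, it diffuses each column's rounding error into the next column; a real GPTQ pass instead propagates the error into all remaining columns weighted by the inverse Hessian of the layer inputs. Function names are illustrative, not the PR's code.

```python
def quantize_row_int6(row):
    """Per-row symmetric int6 fake-quant with greedy error feedback
    (identity-Hessian simplification of GPTQ's compensation step)."""
    scale = max(abs(w) for w in row) / 31 or 1.0  # symmetric levels -31..31
    out, carry = [], 0.0
    for w in row:
        w = w + carry                       # fold in residual from previous column
        q = max(-31, min(31, round(w / scale)))
        out.append(q * scale)               # dequantized value
        carry = w - q * scale               # residual pushed into the next column
    return out
```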
Weight Averaging
EMA
parameters: {"decay":0.995,"usage":"smoothed weights for evaluation, raw weights for training"}
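A minimal sketch of the EMA scheme above (decay 0.995): the optimizer keeps updating the raw weights, while the smoothed copy is used only for evaluation and export.

```python
# EMA weight averaging: ema <- decay * ema + (1 - decay) * raw.
def ema_update(ema, raw, decay=0.995):
    for k in ema:
        ema[k] = decay * ema[k] + (1.0 - decay) * raw[k]
    return ema
```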
Test-Time Training
Legal Score-First TTT
parameters: {"EMA_decay":0.995,"cosine_lr_decay_fixed":true,"embedding_freeze":["tok_emb","bigram","ve_shared"],"optimizer":"SGD + momentum 0.9","epochs_per_chunk":3,"grad_clip":1}
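A hedged sketch of the TTT inner step implied by these parameters: SGD with momentum 0.9, global-norm gradient clipping at 1.0, and frozen embedding tables (tok_emb, bigram, ve_shared). The function and its signature are illustrative, not the PR's code.

```python
def ttt_step(params, grads, vel, lr, momentum=0.9, clip=1.0, frozen=()):
    # Clip gradients by their global norm, as in grad_clip=1.
    norm = sum(g * g for g in grads.values()) ** 0.5
    scale = min(1.0, clip / max(norm, 1e-12))
    for k in params:
        if k in frozen:                     # frozen tables get no TTT updates
            continue
        vel[k] = momentum * vel[k] + grads[k] * scale
        params[k] -= lr * vel[k]
    return params, vel
```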
Architecture
Partial RoPE, XSA, BigramHash, VE128, SmearGate, logit softcap, tied embeddings
Transformer with U-Net skip connections and multiple architectural enhancements
parameters: {"layers":11,"dimension":512,"heads":8,"kv_heads":4,"mlp_expansion":3,"bigram_hash_buckets":2048,"ve_layers":[9,10],"logit_softcap":30}
Compression
zstd
level: 22
Evaluation
sliding window eval with stride 32
parameters: {"stride":32}
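A sketch of how stride-32 sliding-window evaluation carves up a token stream: each step scores only the `stride` newest tokens, so every scored token keeps close to a full left context. The window size here is illustrative; only the stride comes from the parameters above.

```python
def window_spans(n_tokens, window=512, stride=32):
    """Return (ctx_start, score_start, score_end) spans covering every
    token exactly once, with context capped at `window` tokens."""
    spans = []
    pos = 0
    while pos < n_tokens:
        ctx_start = max(0, pos + stride - window)
        spans.append((ctx_start, pos, min(pos + stride, n_tokens)))
        pos += stride
    return spans
```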
LR Schedule
cosine decay
parameters: {"fixed_window":200}
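A minimal sketch of cosine decay over a fixed window of 200 steps (the `fixed_window` parameter above); past the window the learning rate stays at zero rather than restarting.

```python
import math

def cosine_lr(step, base_lr, window=200):
    # Decay from base_lr to 0 over `window` steps, then hold at 0.
    t = min(step, window) / window
    return base_lr * 0.5 * (1.0 + math.cos(math.pi * t))
```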

Novel Contributions

  • GPTQ quantization replacing naive per-row int6 rounding with Hessian-aware error compensation, cutting quantization error by 32%
  • Early QAT with matched clipping: roughly 3x more QAT steps, with 99.95th-percentile clipping matched to the GPTQ export quantizer
  • Legal Score-First Test-Time Training (TTT) with EMA smoothing and cosine LR decay fixed over the actual training window
  • Embedding freeze (tok_emb, bigram, ve_shared) during TTT to stabilize adaptation
  • U-Net skip connections plus further architectural enhancements: Partial RoPE, XSA in the last layers, BigramHash, VE128, SmearGate, and logit softcap
  • Evaluation improvements: finer sliding-window stride (32 vs 64) and extended TTT epochs (8 vs 3) for free val_bpb gains
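The matched-clipping idea above can be sketched as clipping weights at the 99.95th percentile of |w| during QAT, so the fake-quant range seen in training matches the export quantizer's range. Helper names are illustrative, not the PR's code.

```python
def percentile_abs(values, pct=99.95):
    """Value at the pct-th percentile of |values| (nearest-rank)."""
    s = sorted(abs(v) for v in values)
    idx = min(len(s) - 1, int(round(pct / 100.0 * (len(s) - 1))))
    return s[idx]

def clip_matched(values, pct=99.95):
    # Clip symmetrically at the percentile threshold, so rare outlier
    # weights no longer stretch the quantization scale.
    c = percentile_abs(values, pct)
    return [max(-c, min(c, v)) for v in values]
```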