PR #531
openRecord: 11L + XSA4 + EMA + Late QAT + GPTQ-lite (1.1325 BPB)
by pragnyanramtha
val_bpb
1.1324
Architecture
GPT with GQA (8 heads, 4 KV heads)
Optimizer
Muon + AdamW
Artifact Size
16.9-17.4 MB
Training Techniques
Architecture
XSA4
Cross-layer Shared Attention with post-attention geometric subtraction on last 4 layers to save parameters
parameters: {"layers":4,"active_layers":[7,8,9,10]}
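The record only names the technique, so the sketch below is one plausible reading of "post-attention geometric subtraction": remove from the residual stream its component along a shared attention output, which needs no new parameters. The function name and vector representation are illustrative, not taken from the submission.

```python
def geometric_subtract(x, shared):
    # One reading of "post-attention geometric subtraction" (illustrative):
    # subtract from x its projection onto the shared attention output,
    # leaving the part of x orthogonal to it. Zero new parameters.
    dot = sum(a * b for a, b in zip(x, shared))
    norm2 = sum(b * b for b in shared) or 1.0  # guard against zero vector
    return [a - (dot / norm2) * b for a, b in zip(x, shared)]
```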
Layer count
Increased number of layers from 9 to 11 (5 encoder + 6 decoder with skip connections)
parameters: {"num_layers":11,"encoder_layers":5,"decoder_layers":6}
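A minimal sketch of the 5-encoder / 6-decoder layout with skip connections, assuming U-Net-style pairing (mirrored encoder activations added to decoder inputs; since there is one more decoder layer than encoder layer, the leading decoder block gets no skip). The pairing order is an assumption.

```python
def forward_with_skips(x, enc, dec):
    # Run encoder blocks, stashing each activation for a skip connection.
    skips = []
    for f in enc:
        x = f(x)
        skips.append(x)
    # Decoder blocks in excess of the encoder count get no skip input.
    n_extra = len(dec) - len(skips)
    for i, g in enumerate(dec):
        if i >= n_extra:
            x = x + skips.pop()  # mirror: last encoder pairs with first eligible decoder
        x = g(x)
    return x
```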
GQA
Grouped Query Attention with 8 heads and 4 KV heads
parameters: {"heads":8,"kv_heads":4}
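With 8 query heads and 4 KV heads, GQA maps each pair of query heads to one shared KV head. A minimal sketch of that mapping:

```python
def kv_head_index(q_head, n_heads=8, n_kv_heads=4):
    # Each group of n_heads // n_kv_heads query heads shares one KV head,
    # halving KV-cache and KV-projection parameters at this config.
    group = n_heads // n_kv_heads
    return q_head // group
```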
Embeddings
Tied embeddings with Bigram hash
parameters: {"vocab_size":10240,"embedding_dim":128}
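The usual bigram-hash idea is to bucket the (previous, current) token pair into an embedding table and add that vector to the unigram embedding. The mixing constant and bucket count below are hypothetical; the submission does not specify them.

```python
def bigram_bucket(prev_tok, tok, n_buckets=10240):
    # Hash the (prev, current) token pair into a fixed-size table.
    # The multiplier 1000003 is an illustrative mixing prime, not from
    # the submission; n_buckets=10240 assumes the table matches vocab_size.
    return (prev_tok * 1000003 + tok) % n_buckets
```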
Attention
Flash Attention 3 with RoPE and SmearGate
parameters: null
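RoPE rotates each (even, odd) pair of query/key dimensions by a position-dependent angle, so attention scores depend only on relative position. A minimal pure-Python sketch (the base `theta=10000.0` is the common default, assumed here):

```python
import math

def rope(x, pos, theta=10000.0):
    # Rotary position embedding: rotate each 2D slice of the head vector
    # by pos / theta**(2i/d), giving relative-position-dependent dot products.
    d = len(x)
    out = [0.0] * d
    for i in range(d // 2):
        ang = pos / (theta ** (2 * i / d))
        c, s = math.cos(ang), math.sin(ang)
        x1, x2 = x[2 * i], x[2 * i + 1]
        out[2 * i] = x1 * c - x2 * s
        out[2 * i + 1] = x1 * s + x2 * c
    return out
```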
Weight Averaging
EMA
parameters: {"decay":0.997,"start_step":0,"duration":"full training"}
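The EMA update itself is standard; a minimal sketch of the running average kept for the full training duration (per the record, in float32 on CPU, with the averaged weights swapped in at the end):

```python
def ema_update(avg, params, decay=0.997):
    # Exponential moving average of the weights: avg <- d*avg + (1-d)*w.
    # In the record this runs for all of training and the averaged
    # weights are applied at the end.
    return [decay * a + (1 - decay) * p for a, p in zip(avg, params)]
```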
Quantization
STE QAT
bits: 6
scope: weights during the backward pass, once LR < 15% of peak
GPTQ-lite
bits: null
scope: attention layers int6, MLP int5, rest int8 or pass-through
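The contributions list mentions a 5-percentile MSE search for clipping levels. One reading: try five candidate clip fractions of the max weight magnitude and keep the one minimizing reconstruction MSE. The candidate fractions below are assumptions, not the submission's values.

```python
def quantize_clipped(w, bits, clip):
    # Quantize to a symmetric grid of 2**(bits-1)-1 levels, clamped at +/-clip.
    levels = 2 ** (bits - 1) - 1
    scale = clip / levels
    return [max(-clip, min(clip, round(v / scale) * scale)) for v in w]

def search_clip(w, bits, fracs=(0.8, 0.85, 0.9, 0.95, 1.0)):
    # Evaluate 5 candidate clipping fractions (illustrative values) and
    # keep the one with the lowest mean-squared reconstruction error.
    wmax = max(abs(v) for v in w)
    best = None
    for f in fracs:
        q = quantize_clipped(w, bits, f * wmax)
        mse = sum((a - b) ** 2 for a, b in zip(w, q)) / len(w)
        if best is None or mse < best[0]:
            best = (mse, f)
    return best[1]
```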
LR Schedule
warmdown
parameters: {"warmdown_iters":3000,"warmup_steps":20}
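A sketch of the schedule implied by these parameters: linear warmup over 20 steps, then (assumed) constant at peak, then linear warmdown over the final 3000 iterations.

```python
def lr_scale(step, total_steps, warmup_steps=20, warmdown_iters=3000):
    # Multiplier on the peak LR: linear warmup, flat plateau (assumed),
    # linear warmdown to zero over the last warmdown_iters steps.
    if step < warmup_steps:
        return step / warmup_steps
    if step > total_steps - warmdown_iters:
        return (total_steps - step) / warmdown_iters
    return 1.0
```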
Optimizer
Muon + AdamW
weight_decay: null
momentum: null
other_params: {"lr_matrix":0.02,"lr_embedding":0.03,"lr_scalar":0.02,"grad_accum_steps":8}
Compression
zstd
level: null
Other
other
Compile with fullgraph=True to get a single full-graph compilation with no graph breaks, reducing compilation overhead and leaving time for more training steps.
parameters: null
Evaluation
sliding window eval
parameters: {"stride":64}
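Sliding-window evaluation with stride 64 scores only the last 64 tokens of each window (after the first), so every token is predicted with near-maximal left context. The window length below is an assumption; only the stride is given.

```python
def eval_positions(n_tokens, window=256, stride=64):
    # Token positions scored per sliding window: the first window scores
    # everything, later windows only their final `stride` tokens, so each
    # position is evaluated exactly once with full left context.
    scored = []
    start = 0
    while start + window <= n_tokens:
        lo = start + window - stride if start > 0 else 0
        scored.append(list(range(lo, start + window)))
        start += stride
    return scored
```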
Novel Contributions
- XSA4: Cross-layer Shared Attention on the last 4 layers, sharing attention via post-attention geometric subtraction with zero new parameters, saving ~800K parameters
- Model depth increased to 11 layers, enabled by the XSA4 parameter savings, while still fitting the 16 MB budget
- Full-duration EMA (decay 0.997) maintaining a float32 running average on CPU, applied to the weights at the end of training
- Late Quantization-Aware Training (QAT): activates only once the learning rate drops below 15% of peak, quantizing weights to int6 during the backward pass
- GPTQ-lite with a 5-percentile MSE search for optimal clipping levels, applied selectively to attention and MLP layers
- fullgraph=True compilation, enabled by the subclass design, avoids graph breaks, speeding up compilation and allowing more training steps within the time budget