PR #710 (open)

Submission: 11L EMA + GPTQ-lite + Int6 (val_bpb: 1.1240)

by Dhruba531
val_bpb: 1.1240
Architecture: Transformer
Optimizer: Muon
Artifact Size: 15.58 MB

Training Techniques

Quantization
  • GPTQ-lite (bits: 6, scope: MLP and attention weights)
  • int8 (bits: 8, scope: embeddings)
  • STE QAT (bits: 6, scope: model weights)
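The STE QAT entry can be sketched as below: a minimal numpy sketch assuming a symmetric per-tensor int6 grid ([-31, 31]); the submission's actual granularity (per-tensor vs per-channel) and scale rule are not shown in this PR.

```python
import numpy as np

def fake_quant_int6_ste(w, grad_out):
    # Forward: snap weights to a symmetric int6 grid ([-31, 31]); backward:
    # pass the gradient through unchanged (straight-through estimator).
    # Per-tensor scaling is an assumption; the PR does not state it.
    scale = np.abs(w).max() / 31.0 + 1e-12
    w_q = np.clip(np.round(w / scale), -31, 31) * scale
    return w_q, grad_out  # grad_out returned as-is: the STE
```

In training, the fake-quantized weights are used for the forward pass while the full-precision master weights keep receiving the unmodified gradient.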
Architecture
  • MLP3x: 3x MLP expansion with relu-squared activation; parameters: {"expansion":3}
  • U-Net skip connections: encoder-decoder style skip connections across layers; parameters: {"encoder_layers":5,"decoder_layers":6}
  • XSA: efficient partial XSA applied to the last layers; parameters: {"layers":4}
  • Partial RoPE: rotary positional embeddings applied to a subset of dimensions with NTK-aware scaling; parameters: {"dimensions":16,"total_dimensions":64}
  • Tied embeddings: input and output embeddings are tied
  • KV head count: grouped-query attention with fewer KV heads than attention heads; parameters: {"heads":8,"kv_heads":4}
  • SmearGate: custom gating mechanism used in the model
  • BigramHash: bigram hashing feature with bucketed representation; parameters: {"buckets":2048,"dim":128}
  • Value Embeddings: shared value embeddings used in later layers with learned per-layer scales; parameters: {"layers":[9,10],"dim":128}
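The Partial RoPE entry (16 of 64 head dimensions, NTK-aware scaling) can be sketched as below. The interface, the first-half/second-half pair layout, and the NTK base adjustment base' = base * alpha^(d/(d-2)) are illustrative assumptions, not the submission's exact code.

```python
import numpy as np

def partial_rope_ntk(x, positions, rope_dims=16, base=10000.0, ntk_alpha=1.0):
    # Rotate only the first rope_dims of each head dimension (16 of 64 in
    # this submission) and pass the rest through unchanged. NTK-aware
    # scaling stretches the base frequency: base' = base * alpha^(d/(d-2)).
    half = rope_dims // 2
    ntk_base = base * ntk_alpha ** (rope_dims / (rope_dims - 2))
    inv_freq = ntk_base ** (-np.arange(half) / half)   # (half,)
    ang = positions[:, None] * inv_freq[None, :]       # (seq, half)
    cos, sin = np.cos(ang), np.sin(ang)
    x1, x2 = x[..., :half], x[..., half:rope_dims]
    rotated = np.concatenate([x1 * cos - x2 * sin,
                              x1 * sin + x2 * cos], axis=-1)
    return np.concatenate([rotated, x[..., rope_dims:]], axis=-1)
```

Leaving 48 of 64 dimensions unrotated gives the model position-free channels, while the NTK base stretch keeps the rotated frequencies usable beyond the training context.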
Optimizer
  • Muon (weight_decay: 0.04, momentum: 0.99, lr: 0.025, warmup_momentum_start: 0.92)
  • AdamW (weight_decay: 0.04, lr: 0.035, scope: embeddings)
  • AdamW (weight_decay: 0.04, lr: 0.025, scope: scalars)
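Muon's core step orthogonalizes each weight matrix's momentum buffer before applying it. A hedged sketch of that orthogonalization, using the quintic Newton-Schulz iteration and coefficients from the public Muon reference implementation (the submission's exact optimizer code is not shown in this PR):

```python
import numpy as np

def newton_schulz_orth(g, steps=5):
    # Quintic Newton-Schulz iteration that pushes the singular values of
    # the update matrix toward 1 (approximate orthogonalization).
    # Coefficients follow the public Muon reference implementation.
    a, b, c = 3.4445, -4.7750, 2.0315
    x = g / (np.linalg.norm(g) + 1e-7)  # Frobenius-normalize first
    transposed = x.shape[0] > x.shape[1]
    if transposed:
        x = x.T
    for _ in range(steps):
        m = x @ x.T
        x = a * x + (b * m + c * (m @ m)) @ x
    return x.T if transposed else x
```

The split above (Muon for weight matrices, AdamW for embeddings and scalars) follows the common Muon setup, since the orthogonalization step only makes sense for 2D parameters.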
Weight Averaging
  • EMA (decay: 0.997)
  • SWA (every_steps: 50, scale_threshold: 0.2)
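The EMA entry (decay 0.997) is a running average of the weights maintained alongside training; a minimal sketch (the SWA snapshot folding with every_steps=50 is not shown here):

```python
def ema_update(ema, params, decay=0.997):
    # Exponential moving average of weights, applied after each optimizer
    # step: ema <- decay * ema + (1 - decay) * w. decay=0.997 per this PR.
    for name, w in params.items():
        ema[name] = decay * ema[name] + (1.0 - decay) * w
    return ema
```

At evaluation time the EMA copy, not the raw weights, is used.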
Compression
  • zstd (level: 22)
Evaluation
  • sliding window eval (stride: 64)
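Sliding window eval with stride 64 typically means re-scoring only the newest stride tokens of each window, so every token is predicted with near-full left context; a sketch under that assumption, with a hypothetical `nll_fn(ctx, n)` interface returning the total natural-log NLL of the last n tokens of ctx:

```python
import math

def sliding_window_bpb(tokens, nll_fn, window=2048, stride=64):
    # Score the first window in full, then slide by `stride` and score only
    # the newest `stride` tokens of each subsequent window. Assumes a
    # byte-level tokenizer, i.e. bits per token == bits per byte.
    total_nll = nll_fn(tokens[:window], min(window, len(tokens)))
    n_scored = min(window, len(tokens))
    for end in range(window + stride, len(tokens) + 1, stride):
        total_nll += nll_fn(tokens[end - window:end], stride)
        n_scored += stride
    return total_nll / n_scored / math.log(2)  # nats per token -> bits
```

A smaller stride gives each token more context at the cost of proportionally more forward passes.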
Initialization
  • OrthoInit: orthogonal initialization used with muP-scaled output projections
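A sketch of orthogonal initialization via QR of a Gaussian matrix; the muP output-projection scale factor is an assumption, since the PR does not state the exact rule:

```python
import numpy as np

def ortho_init(fan_out, fan_in, rng, out_scale=1.0):
    # Orthogonal init via QR. For muP-style output projections, pass an
    # out_scale on the order of 1/width; the exact muP scale rule is not
    # stated in the PR, so that factor is assumed.
    a = rng.standard_normal((max(fan_out, fan_in), min(fan_out, fan_in)))
    q, _ = np.linalg.qr(a)                 # orthonormal columns
    w = q if fan_out >= fan_in else q.T    # orient to (fan_out, fan_in)
    return out_scale * w
```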
Sequence Length
  • train_length: 2048, eval_length: null
LR Schedule
  • warmdown (warmdown_steps: 3500)
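A warmdown schedule holds the LR constant and then decays it over the final warmdown_steps; the linear-to-zero shape below is an assumption (the PR only gives warmdown_steps=3500):

```python
def lr_scale(step, total_steps, warmdown_steps=3500):
    # Constant LR, then a linear "warmdown" to zero over the final
    # warmdown_steps. The linear shape is assumed, not stated in the PR.
    if step < total_steps - warmdown_steps:
        return 1.0
    return max(0.0, (total_steps - step) / warmdown_steps)
```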
Regularization
  • layerwise LN scale (scale_rule: "1/sqrt(layer_idx+1)")
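The stated scale rule is directly computable; a sketch assuming it sets the initial LayerNorm gain per layer (the PR does not say whether the rule is an init or a fixed multiplier):

```python
import math

# Per-layer LayerNorm gain following the stated rule 1/sqrt(layer_idx+1);
# deeper layers get progressively smaller gains.
n_layers = 11  # the submission is an 11-layer model
ln_scales = [1.0 / math.sqrt(i + 1) for i in range(n_layers)]
```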
Other
  • Late QAT: STE int6 fake-quantization enabled once the LR scale drops below 0.15 (threshold: 0.15)
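The late-QAT condition is a simple threshold on the LR schedule; a sketch of the trigger:

```python
def qat_active(current_lr_scale, threshold=0.15):
    # Late QAT: switch on STE int6 fake-quantization only once the LR
    # schedule has decayed below the threshold (0.15 per this PR).
    return current_lr_scale < threshold
```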

Novel Contributions

  • 11-layer transformer with 3x MLP expansion and U-Net skip connections
  • Efficient partial XSA on the last 4 layers
  • Partial RoPE with NTK-aware scaling
  • SmearGate, BigramHash, and shared value embeddings
  • EMA plus tight SWA during training
  • GPTQ-lite per-row optimal clip percentile search for int6 quantization
  • Late QAT with STE int6 fake-quantization
  • Int6 roundtrip evaluation with zstd-compressed artifact
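The per-row optimal clip percentile search named in the contributions can be sketched as below: for each weight row, try a few clip percentiles of |w|, quantize to int6, and keep the clip with the lowest reconstruction MSE. The candidate percentile set is an illustrative assumption, and full GPTQ's Hessian-based error compensation is omitted in this sketch.

```python
import numpy as np

def quantize_row_int6(row, clip):
    # Symmetric int6 quantization: 63 levels in [-31, 31] at step clip/31.
    scale = max(clip, 1e-12) / 31.0
    return np.clip(np.round(row / scale), -31, 31) * scale

def best_clip_per_row(W, percentiles=(99.0, 99.5, 99.9, 100.0)):
    # Per-row search over clip percentiles of |w|, keeping the clip whose
    # int6 reconstruction has the lowest MSE. The candidate set is an
    # assumption; the PR does not list the percentiles searched.
    out = np.empty_like(W)
    for i, row in enumerate(W):
        cands = [np.percentile(np.abs(row), p) for p in percentiles]
        recons = [quantize_row_int6(row, c) for c in cands]
        errs = [float(np.mean((row - r) ** 2)) for r in recons]
        out[i] = recons[int(np.argmin(errs))]
    return out
```

Clipping below the row maximum trades a little saturation error on outliers for a finer grid on the bulk of the weights, which usually lowers overall MSE.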