PR #452

closed

10L XSA + EMA + Partial RoPE + LN Scale (val_bpb: 1.1366)

by ofirkris
val_bpb
1.1366
Architecture
GPT
Optimizer
Muon
Artifact Size
15,820,386 bytes

Training Techniques

Architecture
XSA
Exclusive Self-Attention applied to the last 4 transformer layers.
parameters: {"layers":4}
Partial RoPE
Rotary positional embeddings applied to only part of the head dimension.
parameters: {"dimensions":"16/64"}
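A minimal sketch of the partial RoPE above: rotary embeddings are applied to only the first 16 of 64 head dimensions, and the remaining dimensions pass through unrotated. The rotation base (10000) and the adjacent-pair layout are common defaults, not confirmed by the PR.

```python
import math

def partial_rope(q, pos, rope_dims=16):
    """Rotate the first `rope_dims` entries of a per-head query/key vector
    by position-dependent angles; dims beyond rope_dims are untouched.
    Base 10000 and pairwise rotation are assumed defaults."""
    out = list(q)
    for i in range(rope_dims // 2):
        theta = pos / (10000 ** (2 * i / rope_dims))
        c, s = math.cos(theta), math.sin(theta)
        x, y = q[2 * i], q[2 * i + 1]
        out[2 * i] = x * c - y * s
        out[2 * i + 1] = x * s + y * c
    return out
```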
SmearGate
Gating mechanism used in the model.
parameters: null
BigramHash
Bigram hashing feature with learned embedding dimension.
parameters: {"size":10240,"dim":128}
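The bigram-hash feature can be sketched as hashing each (previous token, current token) pair into a 10,240-slot table whose slots index learned 128-dim embeddings. The multiplier below is an arbitrary odd constant for illustration, not the PR's actual hash function.

```python
def bigram_hash(prev_tok: int, tok: int, table_size: int = 10240) -> int:
    """Map a token bigram to a slot in a fixed-size hash table; each slot
    indexes a learned embedding (dim 128 per the PR) added to the input.
    The hash constant here is illustrative only."""
    return (prev_tok * 1000003 + tok) % table_size
```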
U-Net skip connections
Skip connections inspired by U-Net added to the transformer.
parameters: null
Weight Averaging
EMA
Exponential moving average of model weights, maintained alongside training.
parameters: {"decay":0.997}
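The EMA update can be sketched as a standard running average of the weights, applied each optimizer step with the PR's decay of 0.997; implementation details beyond the update rule are assumptions.

```python
def ema_update(ema, params, decay=0.997):
    """One EMA step: ema <- decay * ema + (1 - decay) * current weight.
    With decay 0.997 (per the PR), the average effectively spans roughly
    1 / (1 - 0.997) ~= 333 recent optimizer steps."""
    for name, w in params.items():
        ema[name] = decay * ema[name] + (1.0 - decay) * w
    return ema
```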
Regularization
LN Scale
LayerNorm output scaled down with layer depth.
parameters: {"scale":"1/sqrt(layer_idx+1)"}
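The LN Scale rule follows directly from the listed formula: each layer's normalization is scaled by 1/sqrt(layer_idx+1), so deeper layers contribute less. Whether this is a fixed multiplier or an initialization is not stated in the summary.

```python
import math

def ln_scale(layer_idx: int) -> float:
    """Scale factor 1 / sqrt(layer_idx + 1): layer 0 scales by 1.0,
    layer 9 of the 10-layer model by about 0.316."""
    return 1.0 / math.sqrt(layer_idx + 1)
```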
Quantization
int5
bits: 5
scope: MLP
int6
bits: 6
scope: attention
fp16
bits: 16
scope: embeddings
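A hedged sketch of the mixed-precision scheme: symmetric quantization to 5 bits for MLP weights and 6 bits for attention, with embeddings kept in fp16. Per-tensor scaling is assumed here; per-channel scales or stochastic rounding, if the PR uses them, are not captured.

```python
def quantize_symmetric(weights, bits):
    """Symmetric quantization: map floats to integers in
    [-(2^(bits-1)), 2^(bits-1) - 1] with a single scale per tensor."""
    qmax = 2 ** (bits - 1) - 1          # 15 for int5, 31 for int6
    scale = max(abs(w) for w in weights) / qmax
    if scale == 0.0:
        scale = 1.0
    quantized = [max(-qmax - 1, min(qmax, round(w / scale))) for w in weights]
    return quantized, scale

def dequantize(quantized, scale):
    """Recover approximate float weights from integers and the scale."""
    return [q * scale for q in quantized]
```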
Compression
zstd
level: 22
Initialization
Orthogonal init
Orthogonal weight initialization.
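Orthogonal init can be sketched with Gram-Schmidt on random Gaussian rows; real implementations typically use a QR decomposition instead, but the resulting property (orthonormal rows) is the same.

```python
import random

def orthogonal_init(rows, cols, rng):
    """Build a rows x cols matrix (rows <= cols) with orthonormal rows by
    Gram-Schmidt on Gaussian random vectors. Pure-Python illustration;
    production code would use a QR decomposition."""
    basis = []
    while len(basis) < rows:
        v = [rng.gauss(0.0, 1.0) for _ in range(cols)]
        for u in basis:                       # remove components along earlier rows
            d = sum(a * b for a, b in zip(v, u))
            v = [a - d * b for a, b in zip(v, u)]
        norm = sum(a * a for a in v) ** 0.5
        if norm > 1e-8:                       # skip (rare) near-dependent draws
            basis.append([a / norm for a in v])
    return basis
```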
Optimizer
Muon
weight_decay: 0.04
momentum: 0.99
other_params: null
AdamW
weight_decay: null
momentum: null
other_params: {"scope":"embeddings/scalars"}
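Muon's distinguishing step is an approximate orthogonalization of each 2-D gradient via a Newton-Schulz iteration, which is why AdamW handles embeddings and scalars (the iteration only applies to matrices). A numpy sketch of that core; the quintic coefficients are the ones from the public Muon reference implementation, and the PR's momentum 0.99 and weight decay 0.04 wrap around this step.

```python
import numpy as np

def newton_schulz_orth(g, steps=5):
    """Approximately orthogonalize gradient matrix G (the core of Muon):
    normalize, then repeat X <- a*X + (b*A + c*A@A) @ X with A = X @ X^T,
    driving all singular values toward 1. Coefficients from the public
    Muon reference; momentum/weight-decay handling is omitted."""
    a, b, c = 3.4445, -4.7750, 2.0315
    x = g / (np.linalg.norm(g) + 1e-7)
    transposed = x.shape[0] > x.shape[1]
    if transposed:                 # iterate in the wide orientation
        x = x.T
    for _ in range(steps):
        m = x @ x.T
        x = a * x + (b * m + c * m @ m) @ x
    return x.T if transposed else x
```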
Evaluation
sliding window eval
parameters: {"stride":64,"seq_len":2048}
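The sliding-window evaluation can be sketched as a scoring plan: advance in steps of stride 64, and at each step score only the newest 64 tokens while conditioning on up to 2048 tokens of left context. How the first short-context windows are handled is an assumption here.

```python
def sliding_eval_spans(n_tokens, seq_len=2048, stride=64):
    """Return (context_start, score_start, score_end) triples: the model
    sees tokens [context_start, score_end) but only tokens in
    [score_start, score_end) count toward val_bpb, so every token is
    scored with near-full left context."""
    spans = []
    score_start = 0
    while score_start < n_tokens:
        score_end = min(score_start + stride, n_tokens)
        context_start = max(0, score_end - seq_len)
        spans.append((context_start, score_start, score_end))
        score_start = score_end
    return spans
```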
Test-Time Training
SGD post-quantization
parameters: {"epochs":3}
LR Schedule
warmdown
parameters: {"warmdown_iters":3000}
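A sketch matching the listed parameter: the learning-rate multiplier holds at 1.0 until the final 3000 iterations, then decays linearly to zero. Constant-then-linear is the common reading of "warmdown"; the PR's actual curve may differ.

```python
def lr_multiplier(step, total_iters, warmdown_iters=3000):
    """Constant LR for most of training, then a linear warmdown to zero
    over the last warmdown_iters steps (3000 per the PR)."""
    if step < total_iters - warmdown_iters:
        return 1.0
    return max(0.0, (total_iters - step) / warmdown_iters)
```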

Novel Contributions

  • 10-layer GPT with XSA on the last 4 layers
  • EMA with decay 0.997
  • Partial RoPE using 16/64 dimensions
  • LN Scale based on layer index
  • Mixed precision quantization with int5 MLP and int6 attention
  • 3.2% magnitude pruning
  • SmearGate and BigramHash(10240)
  • Orthogonal initialization
  • Muon optimizer with AdamW for embeddings/scalars
  • Sliding window evaluation with stride 64