PR #383 (open)
Record: 11L Full Stack + XSA4 + Tight SWA + Late QAT (val_bpb=1.1320)
by joelnishanth
val_bpb
1.1320
Architecture
Transformer
Optimizer
Muon
Artifact Size
15,753,020 bytes
Training Techniques
Architecture
XSA
Applied XSA to the last 4 layers in a GQA-aware, zero-allocation form.
parameters: {"layers":4}
Partial RoPE
Used partial rotary positional embeddings with NTK-aware scaling.
parameters: {"dimensions":16,"base_dimensions":64}
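A minimal sketch of the partial-RoPE idea: only 16 of the 64 head dimensions are rotated (per the record's parameters), the rest pass through, and an NTK-aware factor stretches the frequency base. The base of 10000 and the exact NTK exponent are common defaults assumed here, not taken from the record.

```python
import numpy as np

def partial_rope(x, rot_dims=16, base=10000.0, ntk_scale=1.0):
    """Rotate only the first `rot_dims` channels of each head (partial
    RoPE); remaining channels pass through untouched. ntk_scale > 1
    stretches the frequency base for longer contexts (NTK-aware scaling).
    x: (seq, head_dim). base and the NTK exponent are assumptions."""
    seq, head_dim = x.shape
    base = base * ntk_scale ** (rot_dims / (rot_dims - 2))
    half = rot_dims // 2
    inv_freq = base ** (-np.arange(half) / half)      # (half,)
    ang = np.arange(seq)[:, None] * inv_freq          # (seq, half)
    cos, sin = np.cos(ang), np.sin(ang)
    x1, x2 = x[:, :half], x[:, half:rot_dims]
    rot = np.concatenate([x1 * cos - x2 * sin, x1 * sin + x2 * cos], axis=1)
    return np.concatenate([rot, x[:, rot_dims:]], axis=1)
```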
MLP3x
Expanded MLP width by 3x with relu-squared activation.
parameters: {"expansion":3}
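The 3x-expanded MLP with relu-squared can be sketched as below; the bias-free projections are an assumption in line with minimal-transformer practice, not stated in the record.

```python
import numpy as np

def mlp3x(x, w_in, w_out):
    """MLP with 3x width expansion and relu-squared activation.
    x: (batch, d); w_in: (d, 3d); w_out: (3d, d). No biases (assumed)."""
    h = np.maximum(x @ w_in, 0.0) ** 2   # relu(h)^2
    return h @ w_out

d = 8
rng = np.random.default_rng(0)
w_in = rng.standard_normal((d, 3 * d)) / np.sqrt(d)
w_out = rng.standard_normal((3 * d, d)) / np.sqrt(3 * d)
out = mlp3x(rng.standard_normal((4, d)), w_in, w_out)
```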
tied embeddings
Input and output embeddings are tied.
parameters: null
KV head count
Used grouped-query attention with 4 KV heads out of 8 attention heads.
parameters: {"heads":8,"kv_heads":4}
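With 8 query heads over 4 KV heads, each KV head serves 2 query heads, halving the KV cache versus full MHA. A score-computation sketch (function name and shapes are illustrative):

```python
import numpy as np

def gqa_attention_scores(q, k, n_heads=8, n_kv_heads=4):
    """Grouped-query attention scores: KV heads are repeated across
    their query-head group before the usual scaled dot product.
    q: (n_heads, seq, hd); k: (n_kv_heads, seq, hd)."""
    group = n_heads // n_kv_heads
    k_rep = np.repeat(k, group, axis=0)          # (n_heads, seq, hd)
    hd = q.shape[-1]
    return q @ k_rep.transpose(0, 2, 1) / np.sqrt(hd)
```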
SmearGate
Included SmearGate as part of the model architecture.
parameters: null
BigramHash
Added BigramHash with 2048 buckets and 128-dimensional embeddings.
parameters: {"buckets":2048,"dim":128}
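BigramHash maps each (previous token, current token) pair into one of 2048 embedding rows, giving the model a cheap n-gram feature alongside the token embedding. The hash scheme below (multiply-and-XOR with a Fibonacci-hashing constant) is an illustrative assumption; only the bucket count and dimension come from the record.

```python
import numpy as np

class BigramHash:
    """Hash each (prev_token, token) bigram into one of `buckets`
    embedding rows; the looked-up vector can be added to the regular
    token embedding. Hash constant and XOR mix are assumptions."""
    def __init__(self, buckets=2048, dim=128, seed=0):
        self.buckets = buckets
        self.table = np.random.default_rng(seed).standard_normal((buckets, dim)) * 0.02

    def __call__(self, tokens):
        tokens = np.asarray(tokens, dtype=np.int64)
        prev = np.concatenate([[0], tokens[:-1]])          # previous-token ids
        idx = ((prev * 0x9E3779B1) ^ tokens) % self.buckets
        return self.table[idx]                             # (seq, dim)
```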
Shared Value Embedding
Shared value embeddings were used in layers 9 and 10.
parameters: {"layers":[9,10],"dim":128}
Initialization
OrthoInit
Orthogonal initialization with projection scaling.
Optimizer
Muon
weight_decay: 0.04
momentum: 0.99
other_params: {"lr":0.025}
AdamW
weight_decay: 0.04
momentum: null
other_params: {"lr_embeddings":0.035,"lr_scalars":0.025}
Weight Averaging
SWA
parameters: {"interval_steps":50,"threshold_scale":0.2,"checkpoints":12}
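The "tight" SWA averages snapshots taken every 50 steps, but only once the LR scale has decayed below 0.2, confining the average to late training (the record collects 12 such checkpoints). A sketch, with class and method names as assumptions:

```python
import numpy as np

class TightSWA:
    """Running average of weight snapshots taken every `interval_steps`
    steps, gated on the LR scale having dropped below `threshold_scale`."""
    def __init__(self, interval_steps=50, threshold_scale=0.2):
        self.interval = interval_steps
        self.threshold = threshold_scale
        self.avg, self.count = None, 0

    def maybe_snapshot(self, step, lr_scale, weights):
        if lr_scale >= self.threshold or step % self.interval != 0:
            return
        self.count += 1
        if self.avg is None:
            self.avg = [np.array(w, dtype=np.float64) for w in weights]
        else:
            for a, w in zip(self.avg, weights):
                a += (np.asarray(w) - a) / self.count   # incremental mean
```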
Quantization
STE QAT
bits: 6
scope: MLP + attention weights
int8
bits: 8
scope: embeddings
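The STE QAT forward pass can be sketched as symmetric per-tensor fake quantization: round to a signed 6-bit grid, then dequantize. In a framework with autograd, the straight-through estimator would treat the rounding as identity in the backward pass; this numpy sketch shows the forward path only, and the per-tensor symmetric scheme is an assumption.

```python
import numpy as np

def fake_quant(w, bits=6):
    """Symmetric per-tensor fake quantization (QAT forward pass).
    Rounds to a signed `bits`-bit grid and dequantizes; STE backward
    would pass gradients through the round unchanged."""
    qmax = 2 ** (bits - 1) - 1                     # 31 for int6
    amax = np.abs(w).max()
    scale = amax / qmax if amax > 0 else 1.0
    q = np.clip(np.round(w / scale), -qmax - 1, qmax)
    return q * scale
```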
Compression
zstd
level: 22
Evaluation
sliding window eval
parameters: {"stride":64}
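Sliding-window eval scores each token once with left context from the preceding window; only the stride of 64 is from the record, and the window length of 256 below is an illustrative assumption.

```python
def sliding_window_spans(n_tokens, window=256, stride=64):
    """Plan a sliding-window evaluation: each chunk scores only its last
    `stride` tokens, with up to `window - stride` tokens of left context,
    so every token is scored exactly once."""
    spans = []
    for pos in range(0, n_tokens, stride):
        ctx_start = max(0, pos + stride - window)
        spans.append((ctx_start, pos, min(pos + stride, n_tokens)))
    return spans  # (context_start, score_start, score_end) triples
```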
LR Schedule
warmdown
parameters: {"iters":3000,"wallclock_based":true}
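A warmdown schedule holds the LR flat, then decays it linearly to zero over the final stretch. The record triggers the warmdown by wallclock time; the step-count version below is a simplification.

```python
def warmdown_lr_scale(step, total_steps, warmdown_iters=3000):
    """LR multiplier: 1.0 until the last `warmdown_iters` steps, then a
    linear decay to zero (step-based stand-in for the wallclock trigger)."""
    decay_start = total_steps - warmdown_iters
    if step < decay_start:
        return 1.0
    return max(0.0, (total_steps - step) / warmdown_iters)
```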
Regularization
layerwise LN scale
parameters: {"scale":"1/sqrt(layer_idx+1)"}
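The layerwise LN scale initializes each layer's norm gain at 1/sqrt(layer_idx + 1), so deeper layers start with smaller output scale and the residual stream grows more slowly with depth. For the record's 11 layers:

```python
import math

def layerwise_ln_scale(layer_idx):
    """Per-layer LayerNorm gain init: 1/sqrt(layer_idx + 1)."""
    return 1.0 / math.sqrt(layer_idx + 1)

# gain inits for an 11-layer stack, decreasing from 1.0
scales = [layerwise_ln_scale(i) for i in range(11)]
```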
gradient clipping
parameters: {"clip_norm":0.3}
Other
other
Late QAT applied when learning-rate scale dropped below 0.1.
parameters: {"threshold_lr_scale":0.1}
other
Used FlashAttention 3 with FA2 fallback.
parameters: null
Novel Contributions
- 11-layer transformer with 512 hidden size and GQA
- XSA on the last 4 layers
- Partial RoPE with NTK-aware scaling
- U-Net skip connections
- SmearGate and BigramHash features
- Shared Value Embedding in later layers
- Tight SWA during late training
- Late QAT with STE int6
- FlashAttention 3 with FA2 fallback
- Orthogonal initialization with projection scaling