PR #374

RECORD (open)

Record: 11L + Tight SWA + Shared VE128 + Partial RoPE + LN Scale + XSA4 (val_bpb: 1.1246)

val_bpb: 1.1246
Architecture: Transformer
Optimizer: Muon
Artifact Size: 15.71 MB

Training Techniques

Architecture
MLP3x
Uses 3x MLP expansion with relu-squared activation.
parameters: {"multiplier":3}
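A minimal numpy sketch of the 3x-expansion MLP with relu-squared activation (names and shapes are illustrative, not the record's implementation):

```python
import numpy as np

def relu2(x):
    # relu-squared activation: max(x, 0)^2
    return np.maximum(x, 0.0) ** 2

def mlp3x(x, w_in, w_out):
    # x: (seq, d_model); w_in: (d_model, 3*d_model); w_out: (3*d_model, d_model)
    return relu2(x @ w_in) @ w_out

d_model = 8
rng = np.random.default_rng(0)
x = rng.standard_normal((4, d_model))
w_in = rng.standard_normal((d_model, 3 * d_model))
w_out = rng.standard_normal((3 * d_model, d_model))
y = mlp3x(x, w_in, w_out)  # shape (4, 8)
```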
XSA
Efficient partial XSA applied to the last 4 layers in a GQA-aware, zero-allocation manner.
parameters: {"layers":4}
Partial RoPE
Applies RoPE to only part of the head dimensions with NTK-aware scaling.
parameters: {"dimensions":16,"total_dimensions":64}
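A sketch of partial RoPE: only the first 16 of 64 head dimensions are rotated, the rest pass through unchanged. The NTK-aware scaling here uses the common base-stretch formula `base * scale^(d/(d-2))`; the record's exact variant is not specified.

```python
import numpy as np

def ntk_base(base, scale, rot_dim):
    # NTK-aware scaling: stretch the RoPE base so long contexts interpolate
    # smoothly (assumed formula; several variants exist)
    return base * scale ** (rot_dim / (rot_dim - 2))

def partial_rope(x, rot_dim=16, base=10000.0, scale=1.0):
    # x: (seq, head_dim); rotate only the first rot_dim dims
    seq, head_dim = x.shape
    b = ntk_base(base, scale, rot_dim)
    inv_freq = 1.0 / (b ** (np.arange(0, rot_dim, 2) / rot_dim))
    ang = np.arange(seq)[:, None] * inv_freq          # (seq, rot_dim/2)
    cos, sin = np.cos(ang), np.sin(ang)
    x1, x2 = x[:, 0:rot_dim:2], x[:, 1:rot_dim:2]
    rot = np.empty_like(x[:, :rot_dim])
    rot[:, 0::2] = x1 * cos - x2 * sin
    rot[:, 1::2] = x1 * sin + x2 * cos
    return np.concatenate([rot, x[:, rot_dim:]], axis=1)

x = np.random.default_rng(1).standard_normal((6, 64))
y = partial_rope(x, rot_dim=16)
```

Because the transform is a pure rotation, the norm of the rotated slice is preserved, and the unrotated 48 dimensions are untouched.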
KV head count
Uses grouped-query attention with 4 KV heads across 8 attention heads.
parameters: {"heads":8,"kv_heads":4}
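The GQA layout can be sketched as follows: 4 KV heads are each shared by 2 of the 8 query heads, implemented here by repeating the K tensor before the score computation (a minimal numpy illustration, not the record's kernel):

```python
import numpy as np

def repeat_kv(kv, n_rep):
    # kv: (kv_heads, seq, head_dim) -> (kv_heads * n_rep, seq, head_dim)
    # each KV head is duplicated for its group of query heads
    return np.repeat(kv, n_rep, axis=0)

def gqa_scores(q, k, heads=8, kv_heads=4):
    # q: (heads, seq, d); k: (kv_heads, seq, d)
    k_rep = repeat_kv(k, heads // kv_heads)
    d = q.shape[-1]
    return q @ k_rep.transpose(0, 2, 1) / np.sqrt(d)

rng = np.random.default_rng(2)
q = rng.standard_normal((8, 5, 16))
k = rng.standard_normal((4, 5, 16))
s = gqa_scores(q, k)  # (8, 5, 5)
```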
tied embeddings
Input and output embeddings are tied.
parameters: null
SmearGate
Adds SmearGate as part of the architecture.
parameters: null
BigramHash
Adds a bigram hashing feature with learned embeddings.
parameters: {"buckets":2048,"dimension":128}
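A sketch of the bigram-hash feature: each (previous, current) token pair is hashed into one of 2048 buckets, which indexes a learned 128-dim embedding table. The hash function and the sentinel for position 0 are assumptions; the record does not specify them.

```python
import numpy as np

BUCKETS, DIM = 2048, 128

def bigram_bucket(prev_tok, tok, buckets=BUCKETS):
    # hash the (previous, current) token pair into a fixed bucket
    # (multiplier is an arbitrary odd constant; the real hash is unspecified)
    return (prev_tok * 1000003 + tok) % buckets

def bigram_features(tokens, table):
    # table: (buckets, dim) learned embedding table; position 0 pairs with
    # an assumed sentinel "previous token" of 0
    prev = np.concatenate([[0], tokens[:-1]])
    idx = np.array([bigram_bucket(p, t) for p, t in zip(prev, tokens)])
    return table[idx]

rng = np.random.default_rng(3)
table = rng.standard_normal((BUCKETS, DIM))
feats = bigram_features(np.array([5, 9, 9, 5]), table)  # (4, 128)
```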
Shared Value Embedding
Shares a single value embedding table across layers 9 and 10 with per-layer learned scales.
parameters: {"dimension":128,"layers":[9,10]}
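The sharing scheme can be sketched as one table with per-layer learned scalars (class shape and init are illustrative; in training the scales would be learned parameters):

```python
import numpy as np

class SharedValueEmbedding:
    # one value-embedding table shared by several layers, each layer applying
    # its own learned scalar scale
    def __init__(self, vocab, dim, layers, rng):
        self.table = rng.standard_normal((vocab, dim))
        self.scale = {l: 1.0 for l in layers}  # learned per-layer scalars

    def lookup(self, tokens, layer):
        return self.scale[layer] * self.table[tokens]

rng = np.random.default_rng(4)
sve = SharedValueEmbedding(vocab=64, dim=128, layers=[9, 10], rng=rng)
sve.scale[10] = 0.5  # stand-in for a learned value
v9 = sve.lookup(np.array([1, 2]), layer=9)
v10 = sve.lookup(np.array([1, 2]), layer=10)
```

Only one table's worth of parameters is stored; the two layers differ solely by their scalar.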
U-Net skip connections
Uses encoder-decoder style skip connections.
parameters: {"encoder_layers":5,"decoder_layers":6}
Initialization
OrthoInit
Orthogonal initialization with projection scaling by 1/sqrt(2*num_layers).
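A sketch of this initialization via QR decomposition, with the stated 1/sqrt(2*num_layers) projection scaling (the sign fix-up is a standard trick, not confirmed from the record):

```python
import numpy as np

def ortho_init(fan_in, fan_out, num_layers, rng):
    # orthogonal init via QR, projection weights scaled by 1/sqrt(2*num_layers)
    a = rng.standard_normal((fan_in, fan_out))
    q, r = np.linalg.qr(a)
    # fix QR sign ambiguity so the result is deterministic given the RNG
    q *= np.sign(np.diag(r))
    return q / np.sqrt(2 * num_layers)

rng = np.random.default_rng(5)
w = ortho_init(64, 64, num_layers=11, rng=rng)
```

After scaling, the columns remain mutually orthogonal with squared norm 1/(2*num_layers).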
Regularization
LN scale
parameters: {"scale_factor":"1/sqrt(layer_idx+1)"}
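The layer-wise scale factor itself is simple; as a formula it down-weights deeper layers:

```python
def ln_scale(layer_idx):
    # per-layer scale applied to the normalization output: 1/sqrt(layer_idx+1)
    return 1.0 / (layer_idx + 1) ** 0.5
```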
Optimizer
Muon
weight_decay: 0.04
momentum: 0.99
other_params: {"lr":0.025,"warmup_momentum_start":0.92,"warmup_momentum_end":0.99,"warmup_steps":1500}
AdamW
weight_decay: 0.04
momentum: null
other_params: {"embedding_lr":0.035,"scalar_lr":0.025}
Weight Averaging
SWA
parameters: {"every_steps":50,"start_scale_threshold":0.2,"checkpoint_window_steps":600,"num_checkpoints":12}
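A sketch of the "tight" SWA gating: checkpoints are sampled every 50 steps, but only enter the running average once the LR scale has dropped below 0.2, capped at 12 checkpoints. The running-sum bookkeeping is illustrative; the checkpoint-window handling is simplified away here.

```python
import numpy as np

class TightSWA:
    # running average of late-training checkpoints: only snapshots taken while
    # the LR scale is below start_scale_threshold enter the average
    def __init__(self, every_steps=50, start_scale_threshold=0.2, num_checkpoints=12):
        self.every = every_steps
        self.thresh = start_scale_threshold
        self.max_ckpts = num_checkpoints
        self.sum = None
        self.count = 0

    def maybe_update(self, step, lr_scale, params):
        if step % self.every or lr_scale >= self.thresh or self.count >= self.max_ckpts:
            return
        if self.sum is None:
            self.sum = {k: v.copy() for k, v in params.items()}
        else:
            for k, v in params.items():
                self.sum[k] += v
        self.count += 1

    def averaged(self):
        return {k: v / self.count for k, v in self.sum.items()}

swa = TightSWA()
for step in range(0, 201, 50):
    # toy "checkpoints" whose weights equal the step number
    swa.maybe_update(step, lr_scale=0.1, params={"w": np.full(3, float(step))})
avg = swa.averaged()  # mean of 0, 50, 100, 150, 200
```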
Quantization
STE QAT
bits: 6
scope: MLP + attention weights
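A sketch of the fake-quantization forward pass behind STE QAT, assuming symmetric per-tensor scaling (the record's granularity is not stated). In training, the backward pass would use the straight-through estimator, i.e. the gradient of the identity.

```python
import numpy as np

def fake_quant(w, bits=6):
    # symmetric per-tensor fake quantization: snap weights to a 2^bits-level
    # grid and dequantize; STE passes gradients through unchanged
    qmax = 2 ** (bits - 1) - 1           # 31 for int6
    scale = np.abs(w).max() / qmax
    if scale == 0:
        return w
    q = np.clip(np.round(w / scale), -qmax - 1, qmax)
    return q * scale

w = np.array([-1.0, -0.4, 0.0, 0.3, 1.0])
wq = fake_quant(w, bits=6)
```

Each value moves by at most half a quantization step, so the network trains against the rounding error it will see after export.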
int8
bits: 8
scope: embeddings
Compression
zstd
level: 22
Evaluation
sliding window eval
parameters: null
LR Schedule
warmdown
parameters: {"warmdown_iters":3000,"wallclock_based":true}
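A step-based sketch of the warmdown schedule: constant LR, then a linear decay to zero over the final warmdown_iters. The record uses a wallclock-based variant; this version keys off steps for clarity.

```python
def lr_scale(step, total_steps, warmdown_iters=3000):
    # constant LR until the warmdown window, then linear decay to zero
    start = total_steps - warmdown_iters
    if step < start:
        return 1.0
    return max(0.0, (total_steps - step) / warmdown_iters)
```

Note this LR scale is also what gates the late QAT switch (threshold 0.1) and tight-SWA collection (threshold 0.2) listed elsewhere in this record.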
Sequence Length
sequence_length
train_length: 2048
eval_length: null
Other
other
Late QAT with STE int6 fake-quantization when LR scale drops below 0.1.
parameters: {"threshold":0.1}

Novel Contributions

  • Tight SWA restricted to late training checkpoints with scale < 0.2
  • Shared Value Embedding across layers 9 and 10
  • Partial RoPE with NTK-aware scaling
  • Efficient partial XSA on the last 4 layers
  • Layer-wise LN scale factor of 1/sqrt(layer_idx+1)
  • Late QAT using STE int6 fake quantization
  • Sliding window evaluation to obtain the reported best val_bpb