PR #388
Status: closed
Record: 11L + Tight SWA + VE128 + Partial RoPE + LN Scale + TTT (val_bpb: 1.1231)
by ElliotSlusky
val_bpb: 1.1231
Architecture: Transformer
Optimizer: Muon
Artifact Size: 15.43 MB
Training Techniques
Architecture
Partial RoPE
Applies RoPE to only 16 of the 64 attention head dimensions, leaving the remaining 48 position-free.
parameters: {"dimensions":16,"total_dimensions":64}
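A minimal numpy sketch of partial RoPE as described above: the standard rotate-pair formulation is applied to the first 16 of 64 head dimensions, and the rest pass through untouched (the base frequency and which dimensions are rotated are assumptions, not taken from the PR).

```python
import numpy as np

def partial_rope(x, rot_dims=16, base=10000.0):
    """Apply rotary position embeddings to the first `rot_dims` of each
    head dimension; the remaining dimensions stay position-free.
    x: (seq_len, head_dim) array for a single head."""
    seq_len, head_dim = x.shape
    half = rot_dims // 2
    # RoPE frequencies span only the rotated slice.
    inv_freq = base ** (-np.arange(half) / half)
    angles = np.arange(seq_len)[:, None] * inv_freq[None, :]   # (seq, half)
    cos, sin = np.cos(angles), np.sin(angles)
    x1, x2 = x[:, :half], x[:, half:rot_dims]                  # rotation pairs
    rotated = np.concatenate([x1 * cos - x2 * sin,
                              x1 * sin + x2 * cos], axis=-1)
    return np.concatenate([rotated, x[:, rot_dims:]], axis=-1)
```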
Shared Value Embeddings
A learned embedding table shared across layers 9 and 10 and added to the value path with per-layer learned scales.
parameters: {"dim":128,"layers":[9,10]}
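A sketch of the shared value-embedding idea: one token-indexed table feeds the value path of layers 9 and 10, each with its own learned scalar. The vocab size, init scale, and scalar init values below are placeholders, not from the PR.

```python
import numpy as np

rng = np.random.default_rng(0)
vocab, ve_dim = 256, 128                   # vocab size is an assumption
value_emb = rng.normal(scale=0.02, size=(vocab, ve_dim))  # single shared table
ve_scale = {9: 0.5, 10: 0.5}               # per-layer learned scalars (assumed init)

def add_value_embedding(v, token_ids, layer_idx):
    """v: (seq, ve_dim) value activations. Only layers listed in
    ve_scale receive the shared embedding, scaled per layer."""
    if layer_idx in ve_scale:
        v = v + ve_scale[layer_idx] * value_emb[token_ids]
    return v
```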
LN Scale
LayerNorm scaling factor of 1/sqrt(layer_idx+1).
parameters: null
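The LN scale rule is simple enough to state in a few lines: normalize, then damp the output by 1/sqrt(layer_idx+1) so deeper layers contribute progressively less to the residual stream. A numpy sketch (learnable gain/bias omitted for brevity):

```python
import numpy as np

def scaled_layernorm(x, layer_idx, eps=1e-5):
    """LayerNorm followed by a depth-dependent 1/sqrt(layer_idx+1) scale."""
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    y = (x - mu) / np.sqrt(var + eps)
    return y / np.sqrt(layer_idx + 1)
```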
MLP3x
Uses 3x MLP expansion with relu-squared activation.
parameters: null
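The MLP block then reduces to a 3x-wide projection with a relu-squared nonlinearity, sketched here in numpy (bias-free, which is an assumption):

```python
import numpy as np

def mlp3x(x, w_in, w_out):
    """Feed-forward block with 3x expansion and relu^2 activation.
    Shapes: x (d,), w_in (d, 3d), w_out (3d, d)."""
    h = x @ w_in
    h = np.maximum(h, 0.0) ** 2      # relu-squared: zero for negatives
    return h @ w_out
```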
GQA
Uses grouped-query attention with 4 KV heads across 8 attention heads.
parameters: {"heads":8,"kv_heads":4}
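With 8 query heads over 4 KV heads, each pair of query heads shares one KV head. A reference numpy sketch of causal GQA (the real run uses the cuDNN SDPA kernel noted under Other; this just shows the head grouping):

```python
import numpy as np

def gqa_attention(q, k, v, n_heads=8, n_kv_heads=4):
    """q: (n_heads, seq, d); k, v: (n_kv_heads, seq, d).
    Each group of n_heads // n_kv_heads query heads shares one KV head."""
    group = n_heads // n_kv_heads
    k = np.repeat(k, group, axis=0)          # expand KV heads to match queries
    v = np.repeat(v, group, axis=0)
    seq, d = q.shape[1], q.shape[-1]
    scores = q @ k.transpose(0, 2, 1) / np.sqrt(d)
    mask = np.triu(np.ones((seq, seq), dtype=bool), k=1)   # causal mask
    scores = np.where(mask, -1e9, scores)
    w = np.exp(scores - scores.max(axis=-1, keepdims=True))
    w /= w.sum(axis=-1, keepdims=True)
    return w @ v
```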
SmearGate
Includes SmearGate in the architecture.
parameters: null
BigramHash
Adds BigramHash features with 2048 buckets and 128-dimensional embeddings.
parameters: {"buckets":2048,"dim":128}
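A sketch of the bigram-hash feature: each (previous, current) token pair is hashed into one of 2048 buckets and looked up in a 128-dim table. The multiplicative hash and the id-0 padding for the first position are assumptions; the PR only specifies bucket count and dimension.

```python
import numpy as np

BUCKETS, DIM = 2048, 128
rng = np.random.default_rng(0)
bigram_table = rng.normal(scale=0.02, size=(BUCKETS, DIM))

def bigram_hash_features(token_ids):
    """Return a (seq, DIM) feature per position, keyed on the
    (previous, current) token bigram via a simple multiplicative hash."""
    ids = np.asarray(token_ids)
    prev = np.concatenate([[0], ids[:-1]])   # assumed pad id 0 at position 0
    h = (prev * 1000003 + ids) % BUCKETS
    return bigram_table[h]
```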
tied embeddings
Input and output embeddings are tied.
parameters: null
U-Net skip connections
Uses encoder-decoder style skip connections with 5 encoder and 6 decoder layers.
parameters: {"encoder_layers":5,"decoder_layers":6}
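With 5 encoder and 6 decoder layers there are only 5 skips, so one decoder layer runs without one. A sketch assuming additive skips paired last-in-first-out, with the final decoder layer unskipped (the record's exact pairing may differ):

```python
def unet_forward(x, encoder_layers, decoder_layers):
    """Encoder layers push their outputs onto a stack; decoder layers
    pop and add them before running, until the stack is empty."""
    skips = []
    for f in encoder_layers:
        x = f(x)
        skips.append(x)
    for f in decoder_layers:
        if skips:
            x = x + skips.pop()   # additive skip from the matching encoder
        x = f(x)
    return x
```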
Optimizer
Muon
weight_decay: 0.042
momentum: 0.99
other_params: {"lr":0.025,"warmup_momentum_start":0.92,"warmup_steps":1500}
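Muon's defining step orthogonalizes the momentum buffer with a Newton-Schulz iteration before applying it as the update. A numpy sketch (the quintic coefficients follow the public reference implementation; the 0.92 to 0.99 momentum warmup from the params above is omitted):

```python
import numpy as np

def newton_schulz(G, steps=5, eps=1e-7):
    """Approximately orthogonalize G so its singular values move toward 1."""
    a, b, c = 3.4445, -4.7750, 2.0315
    X = G / (np.linalg.norm(G) + eps)        # Frobenius-normalize: svals <= 1
    transposed = G.shape[0] > G.shape[1]
    if transposed:                            # keep the Gram matrix small
        X = X.T
    for _ in range(steps):
        A = X @ X.T
        X = a * X + (b * A + c * A @ A) @ X
    return X.T if transposed else X

def muon_update(momentum_buf, grad, beta=0.99, ns_steps=5):
    """One Muon step: accumulate momentum, then orthogonalize it."""
    momentum_buf = beta * momentum_buf + grad
    return momentum_buf, newton_schulz(momentum_buf, ns_steps)
```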
AdamW
weight_decay: 0.042
momentum: null
other_params: {"lr_embeddings":0.035,"lr_scalars":0.025}
Weight Averaging
SWA
parameters: {"every_n_steps":50,"start_scale":0.2,"num_checkpoints":16}
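A sketch of the "tight" SWA variant: only the most recent 16 snapshots, taken every 50 steps, are averaged. The reading of start_scale below is an assumption, interpreting it (and the "low-scale checkpoints" note under Novel Contributions) as gating snapshots to steps where the LR has decayed to at most 20% of base.

```python
from collections import deque
import numpy as np

class TightSWA:
    """Weight averaging over a rolling window of recent checkpoints."""
    def __init__(self, every_n_steps=50, start_scale=0.2, num_checkpoints=16):
        self.every, self.start_scale = every_n_steps, start_scale
        self.snaps = deque(maxlen=num_checkpoints)   # old snapshots fall out

    def maybe_snapshot(self, step, lr_frac, params):
        # Assumption: snapshot only once LR has decayed below start_scale.
        if lr_frac <= self.start_scale and step % self.every == 0:
            self.snaps.append({k: v.copy() for k, v in params.items()})

    def averaged(self):
        n = len(self.snaps)
        return {k: sum(s[k] for s in self.snaps) / n for k in self.snaps[0]}
```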
Compression
zstd
level: 22
Evaluation
sliding window eval
parameters: {"stride":64}
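Sliding-window eval with stride 64 means the window advances 64 tokens at a time and only the newly exposed positions are scored, so every token past the first window sees nearly a full window of left context. A sketch of the window bookkeeping (window length here is the 2048 train length, an assumption):

```python
def eval_windows(n_tokens, window=2048, stride=64):
    """Yield (start, end, scored_from) triples: the model reads tokens
    [start, end) but only positions [scored_from, end) receive fresh
    predictions, so each token is scored exactly once."""
    scored, start = 0, 0
    while scored < n_tokens:
        end = min(start + window, n_tokens)
        yield start, end, scored
        scored = end
        start += stride
```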
Test-Time Training
full TTT
parameters: {"learning_rate":0.008,"epochs":25,"momentum":0.9,"batch_seqs":32,"freeze_blocks":0}
Initialization
OrthoInit
Orthogonal initialization with projection scaling by 1/sqrt(2*num_layers).
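A numpy sketch of this init: draw an orthogonal matrix via QR (with the standard sign fix so the draw is uniform), then apply the extra 1/sqrt(2*num_layers) factor to residual-path projections:

```python
import numpy as np

def ortho_init(shape, num_layers, rng, is_proj=True):
    """Orthogonal init; projections get damped by 1/sqrt(2*num_layers)."""
    a = rng.normal(size=shape)
    q, r = np.linalg.qr(a)
    q = q * np.sign(np.diag(r))        # sign fix: uniform over orthogonal group
    if is_proj:
        q = q * (1.0 / np.sqrt(2 * num_layers))
    return q
```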
Sequence Length
sequence_length
train_length: 2048
eval_length: null
LR Schedule
warmdown
parameters: {"warmdown_iters":4000}
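The warmdown schedule holds LR constant and then decays it linearly to zero over the final 4000 iterations. A sketch (total_steps is a hypothetical argument; the PR does not state the run length):

```python
def lr_at(step, total_steps, base_lr, warmdown_iters=4000):
    """Constant LR, then linear decay to zero over the last warmdown_iters."""
    decay_start = total_steps - warmdown_iters
    if step < decay_start:
        return base_lr
    return base_lr * (total_steps - step) / warmdown_iters
```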
Regularization
layerwise LN scale
parameters: {"scale_factor":"1/sqrt(layer_idx+1)"}
Quantization
mixed int6/int8
bits: 6
scope: MLP and attention weights int6 per-row; embeddings int8 per-row
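A sketch of symmetric per-row quantization as used here: each row gets its own scale so the largest-magnitude entry maps to the extreme code (31 for int6, 127 for int8); the int6 codes are stored in int8 containers before zstd picks up the redundancy.

```python
import numpy as np

def quantize_rows(w, bits=6):
    """Symmetric per-row quantization to signed `bits`-bit codes."""
    qmax = 2 ** (bits - 1) - 1                 # 31 for int6, 127 for int8
    scale = np.abs(w).max(axis=1, keepdims=True) / qmax
    scale = np.where(scale == 0, 1.0, scale)   # guard all-zero rows
    q = np.round(w / scale).astype(np.int8)    # int6 codes held in int8
    return q, scale

def dequantize_rows(q, scale):
    return q.astype(np.float32) * scale
```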
Other
other
Uses cuDNN SDPA attention implementation for speed.
parameters: {"speedup_vs_fa2":1.18}
Novel Contributions
- Tight SWA restricted to low-scale checkpoints and averaging only the most recent 16 checkpoints
- Shared Value Embeddings with a single table shared across layers 9 and 10
- Partial RoPE applied to only 16 of 64 attention dimensions
- LayerNorm scaling by 1/sqrt(layer_idx+1)
- Test-Time Training with full-weight SGD on validation data after quantization
- Int6+zstd quantization with int8 embeddings
- cuDNN SDPA attention for faster grouped-query attention