val_bpb: 1.1307
Architecture: Transformer
Optimizer: Muon
Artifact Size: 15,892,986 bytes
Training Techniques
Architecture
XSA
Exclusive Self Attention applied only to the deepest layers, with an efficient GQA-aware implementation that avoids value-vector duplication via reshape and broadcasting.
parameters: {"layers":3,"total_layers":11,"head_count":8,"kv_heads":4}
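The GQA-aware trick described above can be sketched in numpy: instead of materializing a repeated copy of the K/V tensors for all 8 query heads, the queries are reshaped into groups of 2 per KV head and the matmul broadcasts over the group axis. This is a generic illustration of the reshape-and-broadcast idea, not the model's actual code; shapes are toy values except for the 8/4 head counts from the parameters.

```python
import numpy as np

# Toy shapes; 8 query heads and 4 KV heads come from the parameters above.
B, T, Hq, Hkv, D = 2, 5, 8, 4, 16
G = Hq // Hkv  # query heads per KV head

rng = np.random.default_rng(0)
q = rng.standard_normal((B, Hq, T, D))
k = rng.standard_normal((B, Hkv, T, D))

# Naive GQA: duplicate K for every query head (extra memory traffic).
k_rep = np.repeat(k, G, axis=1)                 # (B, Hq, T, D)
scores_rep = q @ k_rep.transpose(0, 1, 3, 2)    # (B, Hq, T, T)

# Broadcast version: group the queries instead of repeating K.
q_g = q.reshape(B, Hkv, G, T, D)                # (B, Hkv, G, T, D)
k_bc = k[:, :, None].transpose(0, 1, 2, 4, 3)   # (B, Hkv, 1, D, T)
scores_bc = (q_g @ k_bc).reshape(B, Hq, T, T)   # broadcasts over G

assert np.allclose(scores_rep, scores_bc)
```

The same grouping applies to the value matmul, which is where the description says the duplication is avoided.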
SmearGate
Uses SmearGate as part of the model architecture.
parameters: null
BigramHash
Adds a BigramHash component: adjacent-token bigrams are hashed into a bucketed embedding table.
parameters: {"buckets":2048,"dim":128}
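A minimal sketch of the bucketing step, assuming a standard construction: each (previous token, current token) pair is hashed into one of 2048 buckets (from the parameters) that index an auxiliary 128-dim embedding table. The multiplicative mixing constants are hypothetical; the model's actual hash is not specified.

```python
BUCKETS = 2048  # from the parameters above

def bigram_bucket(prev_tok: int, tok: int, buckets: int = BUCKETS) -> int:
    # Mix the pair with two odd 32-bit constants (hypothetical choice),
    # then reduce modulo the bucket count.
    h = (prev_tok * 0x9E3779B1 + tok * 0x85EBCA77) & 0xFFFFFFFF
    return h % buckets

def bigram_ids(tokens, start_tok=0):
    # Position 0 pairs with a sentinel "start" token.
    prev = [start_tok] + list(tokens[:-1])
    return [bigram_bucket(p, t) for p, t in zip(prev, tokens)]
```

The resulting bucket ids would be looked up in a `(2048, 128)` embedding table and added to the token representation.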
tied embeddings
Input and output embeddings are tied.
parameters: null
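Weight tying means a single parameter matrix serves as both the input embedding lookup and the output projection. A generic numpy sketch (toy sizes, not this model's dimensions):

```python
import numpy as np

V, D = 100, 32                       # toy vocab size and model width
E = np.random.default_rng(0).standard_normal((V, D)) * 0.02

def embed(token_ids):
    return E[token_ids]              # input: embedding lookup

def logits(hidden):
    return hidden @ E.T              # output head reuses the same matrix

h = embed([3, 7, 7])                 # (3, D)
out = logits(h)                      # (3, V)
```

Tying halves the embedding-related parameter count, which matters directly for the artifact size reported above.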
RoPE
NTK-aware rotary positional embeddings that auto-scale at longer context lengths.
parameters: {"train_seq_len":1024,"auto_scales_at":2048}
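One common NTK-aware rule (the exact rule used here is an assumption) raises the rotary base once the context passes the trained length, so the lowest frequencies stretch to cover the longer range. The trigger length and training length come from the parameters; the head dimension and base are hypothetical defaults.

```python
TRAIN_LEN = 1024   # train_seq_len from the parameters
AUTO_AT   = 2048   # auto_scales_at from the parameters

def ntk_base(seq_len, head_dim=64, base=10000.0,
             train_len=TRAIN_LEN, trigger=AUTO_AT):
    # Below the trigger, use the plain RoPE base; past it, scale the base
    # by (seq_len / train_len) ** (d / (d - 2)), a standard NTK-aware rule.
    if seq_len < trigger:
        return base
    scale = seq_len / train_len
    return base * scale ** (head_dim / (head_dim - 2))

def rope_freqs(seq_len, head_dim=64):
    b = ntk_base(seq_len, head_dim)
    return [b ** (-2 * i / head_dim) for i in range(head_dim // 2)]
```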
MLP3x
Uses a 3x MLP expansion with a squared-ReLU (ReLU²) activation.
parameters: {"hidden_dim":1536}
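The MLP block can be sketched as follows, with hidden_dim=1536 from the parameters implying a model width of 512 (an inference from the stated 3x expansion):

```python
import numpy as np

D, H = 512, 1536   # H = 3 * D; hidden_dim from the parameters

rng = np.random.default_rng(0)
W1 = rng.standard_normal((D, H)) * 0.02
W2 = rng.standard_normal((H, D)) * 0.02

def relu2(x):
    # Squared ReLU: zero out negatives, then square.
    return np.square(np.maximum(x, 0.0))

def mlp(x):
    return relu2(x @ W1) @ W2

y = mlp(rng.standard_normal((4, D)))   # (4, D)
```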
Optimizer
Muon
weight_decay: 0.04
momentum: 0.99
other_params: {"lr":0.025,"warmup_start_momentum":0.92,"warmup_steps":1500}
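The momentum warmup implied by the parameters (0.92 at the start, 0.99 after 1500 steps) can be sketched as a ramp; the linear shape is an assumption, since only the endpoints and the window are given.

```python
def muon_momentum(step, start=0.92, final=0.99, warmup_steps=1500):
    # Ramp momentum linearly from `start` to `final` over the warmup window,
    # then hold it at `final`. Endpoint values come from the parameters above.
    if step >= warmup_steps:
        return final
    return start + (step / warmup_steps) * (final - start)
```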
AdamW
weight_decay: 0.04
momentum: null
other_params: {"embedding_lr":0.035,"scalar_lr":0.025}
Weight Averaging
SWA
parameters: {"every_steps":120,"num_checkpoints":13,"scale_threshold":0.5}
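A minimal sketch of the averaging scheme, assuming an equal-weight average over the most recent periodic checkpoints (every_steps=120 and num_checkpoints=13 from the parameters; the role of scale_threshold is not specified here, so it is omitted):

```python
from collections import deque

class SWA:
    def __init__(self, every_steps=120, num_checkpoints=13):
        self.every = every_steps
        # deque with maxlen drops the oldest snapshot automatically.
        self.snaps = deque(maxlen=num_checkpoints)

    def maybe_snapshot(self, step, params):
        # `params` is a flat list of floats in this sketch.
        if step % self.every == 0:
            self.snaps.append(list(params))

    def averaged(self):
        n = len(self.snaps)
        return [sum(vals) / n for vals in zip(*self.snaps)]
```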
Compression
zstd
level: 22
Evaluation
sliding window eval
parameters: {"stride":64}
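Sliding-window evaluation with stride 64 (from the parameters) scores only the last `stride` tokens of each window, so every token is predicted with close-to-full left context. A sketch of the index bookkeeping, with the 1024 context length assumed from the sequence-length section:

```python
def eval_windows(n_tokens, ctx=1024, stride=64):
    """Yield (window_start, window_end, score_start) triples: the model sees
    tokens [window_start, window_end) but only tokens [score_start, window_end)
    contribute to the loss."""
    spans = []
    pos = 0
    while pos < n_tokens:
        start = max(0, pos + stride - ctx)
        end = min(pos + stride, n_tokens)
        spans.append((start, end, pos))
        pos = end
    return spans
```

Every token is scored exactly once, at the cost of one forward pass per `stride` tokens instead of one per `ctx` tokens.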
Initialization
OrthoInit
Orthogonal initialization with muP-scaled output projections.
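A sketch of orthogonal initialization via QR decomposition, with a muP-style down-scaling of output projections; the exact scaling rule (here 1/fan_in) is an assumption.

```python
import numpy as np

_rng = np.random.default_rng(0)

def orthogonal(shape):
    # QR of a Gaussian matrix gives an orthonormal-column Q; the sign fix
    # makes the distribution uniform over orthogonal matrices.
    a = _rng.standard_normal(shape)
    q, r = np.linalg.qr(a)
    return q * np.sign(np.diag(r))

def init_out_proj(d_in, d_out):
    # Hypothetical muP-style scaling: shrink output projections by fan-in.
    return orthogonal((d_in, d_out)) / d_in

W = orthogonal((64, 64))
```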
Sequence Length
sequence_length
train_length: 1024
eval_length: null
LR Schedule
warmdown
parameters: {"warmdown_iters":3000,"warmup_steps":30}
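The parameters imply a trapezoidal schedule: a 30-step linear warmup, a flat stretch, then a 3000-step linear decay to zero at the end of training. A sketch of the multiplier; `total_steps` is a hypothetical stand-in, since the run length is not given here.

```python
def lr_mult(step, total_steps, warmup_steps=30, warmdown_iters=3000):
    # Linear warmup over the first `warmup_steps` steps.
    if step < warmup_steps:
        return (step + 1) / warmup_steps
    # Linear warmdown over the final `warmdown_iters` steps.
    if step > total_steps - warmdown_iters:
        return max(0.0, (total_steps - step) / warmdown_iters)
    return 1.0
```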
Regularization
weight decay
parameters: {"muon_wd":0.04,"adam_wd":0.04}
Novel Contributions
- Efficient GQA-aware implementation of Exclusive Self Attention using reshape and broadcasting instead of repeat_interleave
- Applying XSA only to the deepest 3 of 11 layers to reduce compute while targeting layers with higher self-attention bias
- Combination of partial XSA with an 11-layer Transformer, GQA, SmearGate, BigramHash, and U-Net skip connections