PR #281
Status: closed
Title: Non-record: val_bpb=1.1374, FA2+SWA adaptation of Farnsworth
Author: charmquark1984
val_bpb: 1.1381
Architecture: Transformer
Optimizer: Muon
Artifact Size: 15.59 MB
Training Techniques
Architecture
MLP3x
3x expansion MLP with ReLU^2 activation in an 11-layer transformer
parameters: {"layers":11,"model_dim":512,"heads":8,"kv_heads":4,"mlp_hidden":1536}
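A minimal sketch of the MLP block described above, assuming a plain two-matrix feed-forward with the listed dims (512 -> 1536 -> 512) and the ReLU^2 activation; weight scaling here is illustrative:

```python
import numpy as np

def relu_squared(x):
    # ReLU^2 activation: max(x, 0) ** 2
    return np.maximum(x, 0.0) ** 2

def mlp3x(x, w_in, w_out):
    # 3x-expansion MLP: model_dim 512 -> hidden 1536 -> model_dim 512
    return relu_squared(x @ w_in) @ w_out

model_dim, hidden = 512, 1536
rng = np.random.default_rng(0)
x = rng.standard_normal((4, model_dim))
w_in = rng.standard_normal((model_dim, hidden)) * model_dim ** -0.5
w_out = rng.standard_normal((hidden, model_dim)) * hidden ** -0.5
y = mlp3x(x, w_in, w_out)
```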
SmearGate
Learned sigmoid token blending gate
parameters: {"params":"~512"}
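The ~512 parameter count suggests one gate logit per model dimension. A sketch of one plausible blending form, assuming the gate mixes each token with its predecessor (the PR names the gate but not the exact formula):

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def smear_gate(prev_tok, cur_tok, gate_logits):
    # Blend each token embedding with its predecessor through a learned
    # per-channel sigmoid gate (~512 params = one logit per model dim).
    # The blending formula x_t + g * x_{t-1} is a guess, not from the PR.
    return [c + sigmoid(g) * p for p, c, g in zip(prev_tok, cur_tok, gate_logits)]
```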
BigramHash
2048-bucket hash embedding for token-pair features
parameters: {"buckets":2048,"dim":128}
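The bucketing step can be sketched as a hash of the (previous, current) token pair into the 2048-entry embedding table; the mixing constants below are illustrative, not taken from the PR:

```python
def bigram_bucket(prev_tok, tok, n_buckets=2048):
    # Hash the (previous, current) token pair into one of 2048 buckets,
    # whose 128-dim embeddings are added as token-pair features.
    h = (prev_tok * 1000003 + tok) & 0xFFFFFFFF
    h ^= h >> 13  # cheap bit mixing before the modulo
    return h % n_buckets
```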
KV head count
Grouped-query attention with fewer KV heads than attention heads
parameters: {"heads":8,"kv_heads":4}
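With 8 query heads and 4 KV heads, each KV head serves 2 query heads. A sketch of the KV expansion step, assuming heads are laid out on the leading axis:

```python
import numpy as np

def expand_kv(kv, n_heads=8, n_kv_heads=4):
    # Grouped-query attention: replicate each KV head for the
    # n_heads // n_kv_heads query heads that share it.
    group = n_heads // n_kv_heads
    return np.repeat(kv, group, axis=0)
```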
RoPE
NTK-scaled rotary position embeddings
parameters: {"base":10000}
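A sketch of the standard NTK-aware scaling rule applied to the listed base of 10000, with head_dim 64 (512 dims / 8 heads); the scale factor would only differ from 1 if the context were extended beyond the trained length:

```python
def ntk_scaled_base(base=10000.0, scale=1.0, head_dim=64):
    # NTK-aware RoPE scaling: stretch the frequency base by
    # scale ** (d / (d - 2)); scale=1 leaves base=10000 unchanged.
    return base * scale ** (head_dim / (head_dim - 2))

def rope_inv_freq(head_dim=64, base=10000.0):
    # Per-pair rotation frequencies for rotary position embeddings.
    return [base ** (-2.0 * i / head_dim) for i in range(head_dim // 2)]
```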
Quantization
mixed int6/int8
bits: 6
scope: MLP+attention; embeddings int8
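A minimal sketch of symmetric round-to-nearest quantization at the listed widths, assuming per-tensor scales (the PR does not specify the grouping):

```python
def quantize_symmetric(w, bits=6):
    # Symmetric quantization: int6 (bits=6) for MLP/attention weights,
    # int8 (bits=8) for embeddings.
    qmax = 2 ** (bits - 1) - 1          # 31 for int6, 127 for int8
    m = max(abs(v) for v in w)
    scale = m / qmax if m else 1.0
    q = [max(-qmax, min(qmax, round(v / scale))) for v in w]
    return q, scale

def dequantize(q, scale):
    return [v * scale for v in q]
```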
Compression
zstd
level: 22
Optimizer
Muon
weight_decay: 0.042
momentum: 0.99
other_params: {"warmup_steps":1500,"warmdown_iters":3000}
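The core of Muon is an approximate orthogonalization of each 2D gradient update. A sketch of the quintic Newton-Schulz iteration, with coefficients as published in the public Muon reference code (momentum accumulation and the final update scaling are elided):

```python
import numpy as np

def newton_schulz5(g, steps=5, eps=1e-7):
    # Approximately orthogonalize a gradient matrix: drive its singular
    # values toward 1 with a quintic Newton-Schulz iteration.
    a, b, c = 3.4445, -4.7750, 2.0315
    x = g / (np.linalg.norm(g) + eps)   # normalize so the iteration converges
    transpose = g.shape[0] > g.shape[1]
    if transpose:
        x = x.T                          # keep the Gram matrix small
    for _ in range(steps):
        s = x @ x.T
        x = a * x + (b * s + c * s @ s) @ x
    return x.T if transpose else x
```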
Weight Averaging
SWA
parameters: {"checkpoints":7,"every":200,"during":"warmdown"}
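The averaging step itself is a plain mean over saved parameter sets; a sketch, with checkpoints represented as flat dicts for simplicity:

```python
def swa_average(checkpoints):
    # Stochastic Weight Averaging: mean of saved checkpoints. Per the PR,
    # 7 checkpoints are taken every 200 steps during LR warmdown.
    n = len(checkpoints)
    return {k: sum(c[k] for c in checkpoints) / n for k in checkpoints[0]}
```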
Evaluation
sliding window eval
parameters: {"stride":64}
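A sketch of how stride-64 sliding-window evaluation partitions a token stream: each window re-reads context but only the trailing stride of tokens is newly scored, so almost every token is conditioned on a near-full window:

```python
def sliding_windows(n_tokens, window=2048, stride=64):
    # Yield (ctx_start, ctx_end, n_scored) spans for sliding-window eval.
    spans, pos = [], 0
    while pos < n_tokens:
        start = max(0, pos + stride - window)   # context fed to the model
        end = min(pos + stride, n_tokens)
        spans.append((start, end, end - pos))   # only end - pos tokens scored
        pos = end
    return spans
```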
Test-Time Training
full TTT
parameters: {"learning_rate":0.002,"momentum":0.9,"epochs":3}
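A sketch of the inner update of full test-time training with the listed hyperparameters: plain SGD with momentum applied to the whole model on the evaluation stream for 3 epochs. Parameters are flattened to a list of scalars here, and loss/gradient computation is elided:

```python
def ttt_step(params, grads, bufs, lr=0.002, momentum=0.9):
    # One SGD+momentum step of test-time training (PR config:
    # lr=0.002, momentum=0.9, repeated for 3 epochs over the eval data).
    for i, g in enumerate(grads):
        bufs[i] = momentum * bufs[i] + g
        params[i] -= lr * bufs[i]
    return params, bufs
```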
Initialization
Orthogonal
Orthogonal initialization with muP output scaling
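A sketch of QR-based orthogonal initialization; the gain is illustrative, since the PR only says the output layer additionally gets muP-style scaling (an extra 1/width factor on the readout) without giving exact values:

```python
import numpy as np

def orthogonal_init(rows, cols, gain=1.0, seed=0):
    # Draw a Gaussian matrix and orthogonalize it via QR, yielding
    # orthonormal columns (rows >= cols) or rows (rows < cols).
    rng = np.random.default_rng(seed)
    a = rng.standard_normal((max(rows, cols), min(rows, cols)))
    q, r = np.linalg.qr(a)
    q *= np.sign(np.diag(r))   # fix column signs for a uniform distribution
    w = q if rows >= cols else q.T
    return gain * w
```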
Sequence Length
sequence_length
train_length: 2048
eval_length: 2048
LR Schedule
warmdown
parameters: {"warmup_steps":1500,"warmdown_iters":3000}
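The listed parameters describe a trapezoidal schedule. A sketch of the LR multiplier as a function of step, with a hypothetical total step count (the PR does not state it):

```python
def lr_scale(step, total_steps, warmup_steps=1500, warmdown_iters=3000):
    # Trapezoidal LR schedule: linear warmup for 1500 steps, constant
    # plateau, then linear warmdown to zero over the final 3000 steps.
    if step < warmup_steps:
        return step / warmup_steps
    if step > total_steps - warmdown_iters:
        return max(0.0, (total_steps - step) / warmdown_iters)
    return 1.0
```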
Regularization
weight decay
parameters: {"weight_decay":0.042}
Other
other
FlashAttention 2 / torch SDPA flash backend used instead of FA3 Hopper kernels
parameters: {"version":"2.8.3"}
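A minimal sketch of the torch SDPA call with this model's head layout; SDPA dispatches to a FlashAttention-2 kernel when the flash backend is eligible (e.g. fp16/bf16 on GPU) and otherwise falls back to the math backend, so this runs on CPU too. KV heads are expanded manually here rather than relying on SDPA's native GQA support:

```python
import torch
import torch.nn.functional as F

# Shapes from the PR config: 8 query heads, 4 KV heads, head_dim 512/8 = 64.
q = torch.randn(1, 8, 16, 64)
kv = torch.randn(2, 1, 4, 16, 64)
k = kv[0].repeat_interleave(2, dim=1)   # expand 4 KV heads to 8 for GQA
v = kv[1].repeat_interleave(2, dim=1)

# Dispatches to FlashAttention 2 when that backend is available;
# otherwise torch silently selects another SDPA backend.
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
```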
Novel Contributions
- Adaptation of FarnsworthEngine to FlashAttention 2 instead of FA3 Hopper kernels
- Weight decay tuned as a control knob for compressed artifact size targeting
- Benchmark showing cuDNN SDP is faster than Flash SDP on H100 but yields worse model quality
- Systematic sweep identifying WD=0.042 as optimal for ~15.5MB artifact size and best BPB
- Use of SWA during warmdown combined with TTT and sliding-window evaluation