PR #307

open

Record: 11L XSA4 + EMA + Batch524K + zstd fallback (val_bpb: 1.1357)

by dennisimoo
val_bpb: 1.1357
Architecture: Transformer
Optimizer: Muon
Artifact Size: 15.67 MB

Training Techniques

Quantization
int6
parameters: {"bits":6,"scope":"all"}
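The record's export is not shown here, but symmetric 6-bit quantization can be sketched as follows (function names and the per-tensor scaling choice are illustrative, not the PR's actual code):

```python
def quantize_int6(values):
    # Symmetric 6-bit quantization: integer codes in [-32, 31],
    # with one per-tensor scale chosen from the max magnitude.
    scale = max(abs(v) for v in values) / 31.0 or 1.0
    codes = [max(-32, min(31, round(v / scale))) for v in values]
    return codes, scale

def dequantize_int6(codes, scale):
    # Reconstruct approximate float values from codes and the scale.
    return [c * scale for c in codes]
```

Six bits per weight (plus scales) is what keeps the artifact under the 16 MB limit noted in the contributions.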
Architecture
XSA
Uses XSA on the last 4 layers of the model.
parameters: {"layers":4}
EMA
Exponential moving average is enabled during training.
parameters: {"enabled":1,"decay":0.997}
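The EMA update itself is one line per parameter; a pure-Python sketch with the record's decay of 0.997 (names are illustrative):

```python
def ema_update(shadow, current, decay=0.997):
    # shadow <- decay * shadow + (1 - decay) * live weights
    return [decay * s + (1.0 - decay) * c for s, c in zip(shadow, current)]
```

At evaluation time the shadow weights are swapped in for the live ones, which tends to smooth out late-training noise.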
SmearGate
Included as part of the model variant described in the README.
parameters: null
BigramHash
Included as part of the model variant described in the README.
parameters: null
MLP3x
Uses a 3x MLP multiplier.
parameters: {"multiplier":3}
Weight Averaging
EMA
parameters: {"decay":0.997}
Compression
zstd
parameters: {"level":null}
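The Python-or-CLI fallback mentioned in the contributions below could look like this sketch (whether the `zstandard` module or the `zstd` binary is present is an assumption about the environment):

```python
import shutil
import subprocess

def zstd_compress(data: bytes) -> bytes:
    # Prefer the `zstandard` Python package; fall back to the zstd CLI.
    try:
        import zstandard
        return zstandard.ZstdCompressor().compress(data)
    except ImportError:
        if shutil.which("zstd") is None:
            raise RuntimeError("neither zstandard module nor zstd CLI available")
        proc = subprocess.run(["zstd", "-q", "-c"], input=data,
                              stdout=subprocess.PIPE, check=True)
        return proc.stdout
```

Either path produces a standard zstd frame, so decompression on the evaluation side does not need to know which one ran.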
Evaluation
sliding window eval
parameters: {"stride":64}
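Sliding-window eval scores each token with near-full left context by advancing the window a small stride at a time; a schematic of the span bookkeeping (function name is illustrative, not the PR's code):

```python
def eval_spans(n_tokens, window=2048, stride=64):
    # Each entry is (ctx_start, ctx_end, score_start): the model sees
    # tokens[ctx_start:ctx_end] but loss is counted only from score_start on,
    # so every token is scored exactly once with up to `window` of context.
    spans = [(0, min(n_tokens, window), 0)]
    pos = spans[0][1]
    while pos < n_tokens:
        ctx_end = min(n_tokens, pos + stride)
        spans.append((max(0, ctx_end - window), ctx_end, pos))
        pos = ctx_end
    return spans
```

A smaller stride gives each token more context at the cost of proportionally more forward passes.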
Sequence Length
sequence_length
train_length: 2048
eval_length: 2048
LR Schedule
warmdown
parameters: {"warmdown_steps":3000,"warmup_steps":20}
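The parameters above describe a trapezoidal schedule: a 20-step linear warmup, a flat plateau, and a 3000-step linear warmdown to zero. A sketch (the total step count below is hypothetical):

```python
def lr_scale(step, total_steps, warmup_steps=20, warmdown_steps=3000):
    # Multiplier on the base learning rate at a given step.
    if step < warmup_steps:
        return (step + 1) / warmup_steps                        # linear warmup
    if step >= total_steps - warmdown_steps:
        return max(0.0, (total_steps - step) / warmdown_steps)  # linear warmdown
    return 1.0                                                  # plateau
```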
Regularization
weight decay
parameters: {"muon_wd":0.04,"adam_wd":0.04}
Other
other
Uses a larger fixed-budget batch setting to improve step count under the wall-clock cap.
parameters: {"train_batch_tokens":524288}
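With a fixed batch of 524,288 tokens (2^19), the optimizer step count for a given token budget is simple integer division (the budget in the test is purely illustrative):

```python
TRAIN_BATCH_TOKENS = 524_288  # 2**19, from the parameters above

def optimizer_steps(total_token_budget, batch_tokens=TRAIN_BATCH_TOKENS):
    # Fixed-budget batching: every optimizer step consumes batch_tokens tokens,
    # so the step count is fully determined by the token budget.
    return total_token_budget // batch_tokens
```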
other
Provides SDPA fallback when flash_attn_interface Python bindings are unavailable.
parameters: null
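A backend-selection sketch of that fallback, using only an import-availability check (the function name is illustrative, not the PR's actual code):

```python
import importlib.util

def attention_backend():
    # Prefer the FlashAttention-3 Python bindings when importable;
    # otherwise fall back to torch.nn.functional.scaled_dot_product_attention.
    if importlib.util.find_spec("flash_attn_interface") is not None:
        return "flash_attn_3"
    return "sdpa"
```

Probing with `find_spec` rather than a bare `import` avoids paying the import cost on machines where the bindings are known to be absent.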
other
Enables torch.compile behind an environment flag for reliable eager smoke tests and faster compiled runs.
parameters: null
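Gating torch.compile behind an environment variable might look like this (the variable name is an assumption; the PR's actual flag is not shown in this record):

```python
import os

def maybe_compile(model, flag="USE_TORCH_COMPILE"):
    # Compile only when the flag is set to "1"; otherwise return the model
    # unchanged so smoke tests run in plain eager mode.
    if os.environ.get(flag, "0") == "1":
        import torch  # imported lazily; eager runs never hit this path
        return torch.compile(model)
    return model
```

Keeping eager mode as the default makes failures reproducible and debuggable, while full training runs opt in to the compiled path.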

Novel Contributions

  • 11-layer XSA4 model with EMA averaging
  • Fixed-budget batch size of 524,288 tokens to improve step count under the time cap
  • SDPA fallback for flash_attn_interface when FA3 Python bindings are unavailable
  • torch.compile gated behind an environment flag for safer testing and faster full runs
  • zstd Python-or-CLI fallback to keep int6 export under the 16MB limit