PR #406 (open)

Non-record: 11L XSA4 + EMA + SDTTT (3-seed mean val_bpb=1.1287)

by dentity007
val_bpb: 1.1287
Architecture: Transformer
Optimizer: SGD
Artifact Size: 15.7 MB

Training Techniques

Architecture
XSA
Exclusive Self-Attention applied to the last 4 layers of the model.
parameters: {"layers":4}
MLP3x
Uses a 3x expansion ratio in the MLP blocks.
parameters: null
KV head count
Uses grouped-query attention with 4 KV heads and 8 attention heads.
parameters: {"heads":8,"kv_heads":4}
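With 8 query heads and 4 KV heads, each KV head is shared by 2 query heads. A minimal numpy sketch of grouped-query attention follows; the tensor layout, head dimension, and causal masking are illustrative assumptions, not taken from the PR's code.

```python
import numpy as np

def gqa_attention(q, k, v, n_heads=8, n_kv_heads=4):
    """Grouped-query attention: n_heads query heads share n_kv_heads K/V heads.

    Shapes (hypothetical for this sketch): q is (n_heads, T, d),
    k and v are (n_kv_heads, T, d). Causal masking is assumed.
    """
    group = n_heads // n_kv_heads       # 2 query heads per KV head here
    d = q.shape[-1]
    out = np.empty_like(q)
    for h in range(n_heads):
        kv = h // group                 # which shared KV head this query head reads
        scores = q[h] @ k[kv].T / np.sqrt(d)            # (T, T)
        T = scores.shape[0]
        mask = np.tril(np.ones((T, T), dtype=bool))     # causal mask
        scores = np.where(mask, scores, -np.inf)
        w = np.exp(scores - scores.max(-1, keepdims=True))
        w /= w.sum(-1, keepdims=True)                   # row-wise softmax
        out[h] = w @ v[kv]
    return out
```

Halving the KV heads relative to query heads shrinks the KV cache and the K/V projection weights, which also helps the artifact-size budget.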
tied embeddings
FP16 tied embedding passthrough.
parameters: null
SmearGate
Includes SmearGate as part of the architecture.
parameters: null
U-Net skip connections
Adds U-Net style skip connections.
parameters: null
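The PR does not spell out how the skips are wired across the 11 blocks. A common U-Net-style scheme pairs the first half of the blocks with the last half symmetrically (first with last, and so on), leaving the middle block unpaired; the sketch below shows that pairing as an assumption, with plain additive skips and no learned gating.

```python
def forward_with_unet_skips(x, blocks, n_skip=5):
    """U-Net-style skips across transformer blocks (hypothetical pairing).

    With 11 blocks and n_skip=5: outputs of blocks 0..4 are saved on a
    stack and added (LIFO, so block 0 pairs with block 10) to the inputs
    of blocks 6..10; block 5 has no skip. The PR's exact pairing and any
    skip weighting are not specified.
    """
    saved = []
    for i, blk in enumerate(blocks):
        if len(blocks) - i <= n_skip:   # decoder half: consume a skip
            x = x + saved.pop()
        x = blk(x)
        if i < n_skip:                  # encoder half: save the activation
            saved.append(x)
    return x
```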
Quantization
int6 QAT
bits: 6
scope: model weights
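In QAT with a straight-through estimator (STE), the forward pass rounds weights to the low-bit grid while the backward pass treats the rounding as identity. The sketch below shows the forward fake-quantization only; per-tensor symmetric scaling is an assumption, since the PR does not detail its quantization scheme.

```python
import numpy as np

def fake_quant_int6(w, bits=6):
    """Symmetric per-tensor fake quantization for QAT (forward pass only).

    Int6 codes span [-32, 31]. In training, the backward pass would use the
    straight-through estimator (in PyTorch: w + (fq(w) - w).detach()) so
    gradients flow through the rounding as if it were identity.
    """
    qmax = 2 ** (bits - 1) - 1                      # 31 for int6
    absmax = np.abs(w).max()
    scale = absmax / qmax if absmax > 0 else 1.0
    q = np.clip(np.round(w / scale), -qmax - 1, qmax)
    return q * scale, q.astype(np.int8)             # dequantized weights, int codes
```

Storing the int6 codes (plus scales) and then zstd-compressing them is what keeps the artifact at 15.7 MB; the low-entropy integer codes compress much better than raw FP16 weights.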
Compression
zstd
level: null
Weight Averaging
EMA
parameters: {"decay":0.997}
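EMA keeps a shadow copy of the weights, updated after each optimizer step, and evaluation uses the shadow copy instead of the raw weights. A plain-dict sketch with the PR's decay of 0.997; a real PyTorch version would walk `state_dict()` tensors instead.

```python
class EMA:
    """Exponential moving average of model weights (decay=0.997 per the PR).

    After each training step: shadow = decay * shadow + (1 - decay) * weight.
    """
    def __init__(self, params, decay=0.997):
        self.decay = decay
        self.shadow = {k: float(v) for k, v in params.items()}

    def update(self, params):
        d = self.decay
        for k, v in params.items():
            self.shadow[k] = d * self.shadow[k] + (1 - d) * float(v)
```

Unlike SWA's uniform average over checkpoints, EMA weights recent steps geometrically, which is likely why it was swapped in here.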
Evaluation
sliding window eval
parameters: {"stride":64}
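Strided sliding-window evaluation scores each new chunk of `stride` tokens with a full window of left context, so no token is scored with a truncated context except at the very start. Only stride=64 comes from the PR; the window size, the nats-based `nll_fn` interface, and `bytes_per_token` below are stand-ins for illustration.

```python
import math

def sliding_window_bpb(nll_fn, tokens, window=256, stride=64, bytes_per_token=1):
    """Strided sliding-window eval, reporting bits per byte.

    nll_fn(ctx, start) is a hypothetical model hook returning the summed
    NLL in nats of ctx[start:] given ctx[:start] as context. Each pass
    scores only the tokens not scored by the previous window.
    """
    total_nll, n_scored, prev_end = 0.0, 0, 0
    for begin in range(0, len(tokens), stride):
        end = min(begin + window, len(tokens))
        ctx = tokens[begin:end]
        n_new = end - prev_end                  # tokens newly scored this pass
        total_nll += nll_fn(ctx, len(ctx) - n_new)
        n_scored += n_new
        prev_end = end
        if end == len(tokens):
            break
    # mean NLL in nats -> bits per token -> bits per byte
    return total_nll / n_scored / math.log(2) / bytes_per_token
```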
Test-Time Training
Self-Distillation TTT
parameters: {"learning_rate":0.001,"temperature":2,"epochs":2,"freeze_blocks":4,"momentum":0.9}
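In self-distillation TTT, the model takes a few optimizer steps on the test stream against its own teacher predictions (per the PR: lr=1e-3, momentum=0.9, 2 epochs, first 4 blocks frozen, temperature 2). Only the temperature-scaled distillation loss is sketched here, in numpy; the update loop and the choice of teacher (e.g. the EMA weights) are omitted and not specified by the PR.

```python
import numpy as np

def distill_loss(student_logits, teacher_logits, T=2.0):
    """Temperature-scaled KL distillation loss (T=2 per the PR).

    KL(teacher || student) on temperature-softened distributions, scaled
    by T^2 as in Hinton et al. so gradients keep their magnitude.
    """
    def log_softmax(z):
        z = z - z.max(-1, keepdims=True)
        return z - np.log(np.exp(z).sum(-1, keepdims=True))
    p_t = np.exp(log_softmax(teacher_logits / T))      # teacher probabilities
    log_p_s = log_softmax(student_logits / T)          # student log-probabilities
    kl = (p_t * (np.log(p_t + 1e-12) - log_p_s)).sum(-1)
    return float(kl.mean() * T * T)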
Regularization
weight decay
parameters: {"weight_decay":0.04}
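With plain SGD, weight decay of 0.04 is typically folded into the gradient before the momentum update (PyTorch's coupled form). A one-step sketch; lr and momentum here are illustrative values, not from the PR's training recipe.

```python
def sgd_step(w, grad, buf, lr=0.01, momentum=0.9, weight_decay=0.04):
    """One SGD+momentum step with coupled weight decay (wd=0.04 per the PR).

    Decay is added to the gradient (grad + wd * w) before the momentum
    buffer update, matching torch.optim.SGD's handling of weight_decay.
    """
    g = grad + weight_decay * w
    buf = momentum * buf + g
    return w - lr * buf, buf
```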
Other
Stock PyTorch SDPA used instead of FlashAttention-3 or custom kernels.
parameters: null

Novel Contributions

  • 11-layer architecture with XSA applied to the last 4 layers
  • EMA replacing SWA
  • Self-Distillation TTT at evaluation time
  • Int6 QAT with STE and zstd compression
  • Sliding-window evaluation with stride 64
  • Grouped-query attention with 4 KV heads
  • FP16 tied embedding passthrough
  • SmearGate and U-Net skip connections