PR #1237

open

Non-record: Fused Triton relu^2 kernel — negative result (val_bpb=1.1198)

by ibarrajo
val_bpb
1.1198
Architecture
Transformer
Optimizer
Muon
Artifact Size
15.1 MB

Training Techniques

Architecture
ReLU²
MLP activation is relu(x).square(), implemented as a fused Triton kernel with a fallback to standard PyTorch ops.
parameters: null
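A minimal PyTorch sketch of the activation (the Triton kernel itself is omitted; the function names here are illustrative, not from the PR):

```python
import torch
import torch.nn.functional as F

def relu2(x: torch.Tensor) -> torch.Tensor:
    """ReLU² MLP activation: the square of relu(x)."""
    return F.relu(x).square()

# torch.compile is expected to fuse the relu and the square into a single
# kernel, which is the PR's negative result: the hand-written Triton kernel
# for the same op gave no measurable speedup over the compiled baseline.
relu2_compiled = torch.compile(relu2)
```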
XSA
Attention uses XSA across all layers.
parameters: {"layers":11}
Partial RoPE
Applies RoPE to a subset of dimensions.
parameters: {"dimensions":"16/64"}
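A minimal sketch of partial RoPE under the PR's 16/64 split: only the first 16 of 64 head dimensions are rotated, the rest pass through unchanged. Batch and head axes are omitted for brevity, and the base frequency is the common default, assumed here rather than taken from the PR:

```python
import torch

def partial_rope(x: torch.Tensor, rot_dims: int = 16, base: float = 10000.0) -> torch.Tensor:
    """Apply rotary position embeddings to the first `rot_dims` of the
    head dimension (16 of 64 here), leaving the remaining dims untouched.

    x: (seq_len, head_dim)
    """
    seq_len, head_dim = x.shape
    x_rot, x_pass = x[:, :rot_dims], x[:, rot_dims:]
    half = rot_dims // 2
    # One frequency per rotated dimension pair.
    freqs = base ** (-torch.arange(half, dtype=torch.float32) / half)
    angles = torch.arange(seq_len, dtype=torch.float32)[:, None] * freqs[None, :]
    cos, sin = angles.cos(), angles.sin()
    x1, x2 = x_rot[:, :half], x_rot[:, half:]
    rotated = torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)
    return torch.cat([rotated, x_pass], dim=-1)
```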
BigramHash
Bigram hash embeddings are used for token representation.
parameters: {"vocab":6144,"dimensions":128}
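A sketch of the idea: each (previous token, token) bigram is hashed into a fixed table and the looked-up vector augments the token representation. The table size (6144) and dimension (128) follow the PR's parameters; the hash mixing constant is illustrative:

```python
import torch
import torch.nn as nn

class BigramHashEmbedding(nn.Module):
    """Hash each (prev_token, token) bigram into a fixed embedding table."""

    def __init__(self, n_buckets: int = 6144, dim: int = 128):
        super().__init__()
        self.n_buckets = n_buckets
        self.table = nn.Embedding(n_buckets, dim)

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        # tokens: (batch, seq). Pair each token with its predecessor.
        prev = torch.roll(tokens, shifts=1, dims=-1)
        prev[..., 0] = 0  # no predecessor at the first position
        # Simple multiplicative hash of the bigram into the bucket range
        # (the constant is arbitrary, chosen only for illustration).
        h = (prev * 1000003 + tokens) % self.n_buckets
        return self.table(h)
```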
SmearGate
SmearGate is included as part of the architecture.
parameters: null
U-Net skip connections
U-Net style skip connections with learned per-dimension scaling.
parameters: null
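A minimal sketch of the skip mechanism: an encoder-side activation is re-injected on the decoder side, scaled by a learned per-dimension weight. The exact wiring between layers is not specified in the PR, so this shows only the scaled-skip building block:

```python
import torch
import torch.nn as nn

class UNetSkip(nn.Module):
    """Add a saved encoder activation back in, with learned per-dim scaling."""

    def __init__(self, dim: int):
        super().__init__()
        self.scale = nn.Parameter(torch.ones(dim))  # learned per-dimension scale

    def forward(self, x: torch.Tensor, skip: torch.Tensor) -> torch.Tensor:
        return x + self.scale * skip
```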
Optimizer
Muon
weight_decay: null
momentum: null
other_params: {"also_used":"AdamW"}
Quantization
late QAT
bits: 5
scope: model
Weight Averaging
EMA
parameters: {"decay":0.997}
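A minimal sketch of parameter EMA with the PR's decay of 0.997. A real implementation would also handle non-float buffers and device placement:

```python
import torch

class EMA:
    """Exponential moving average of model parameters (decay 0.997)."""

    def __init__(self, model: torch.nn.Module, decay: float = 0.997):
        self.decay = decay
        self.shadow = {k: v.detach().clone() for k, v in model.state_dict().items()}

    @torch.no_grad()
    def update(self, model: torch.nn.Module) -> None:
        # shadow <- decay * shadow + (1 - decay) * current
        for k, v in model.state_dict().items():
            self.shadow[k].mul_(self.decay).add_(v, alpha=1 - self.decay)
```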
SWA
parameters: {"interval":50,"during":"warmdown"}
LR Schedule
warmdown
parameters: {"warmdown_steps":3500}
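A sketch of the schedule shape: hold the base learning rate, then decay linearly to zero over the final 3,500 steps. The linear form is assumed; the PR only gives the warmdown length:

```python
def warmdown_lr(step: int, total_steps: int, base_lr: float,
                warmdown_steps: int = 3500) -> float:
    """Constant LR, then linear decay to 0 over the last `warmdown_steps`."""
    start = total_steps - warmdown_steps
    if step < start:
        return base_lr
    return base_lr * (total_steps - step) / warmdown_steps
```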
Regularization
magnitude pruning
parameters: {"sparsity":0.1}
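A minimal sketch of unstructured magnitude pruning at the PR's 10% sparsity: the smallest-magnitude fraction of entries is zeroed. Ties at the threshold may prune slightly more than the target fraction:

```python
import torch

@torch.no_grad()
def magnitude_prune(weight: torch.Tensor, sparsity: float = 0.1) -> torch.Tensor:
    """Return a copy of `weight` with the smallest-|w| fraction set to zero."""
    k = int(weight.numel() * sparsity)
    if k == 0:
        return weight.clone()
    # k-th smallest absolute value serves as the pruning threshold.
    threshold = weight.abs().flatten().kthvalue(k).values
    return torch.where(weight.abs() <= threshold,
                       torch.zeros_like(weight), weight)
```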
logit softcap
parameters: {"value":30}
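The cap value of 30 is from the PR; the tanh form below is the standard logit softcap (as popularized by Gemma-style models) and is assumed here rather than confirmed by the PR:

```python
import torch

def softcap(logits: torch.Tensor, cap: float = 30.0) -> torch.Tensor:
    """Smoothly bound logits to (-cap, cap) via cap * tanh(logits / cap)."""
    return cap * torch.tanh(logits / cap)
```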
Test-Time Training
score-first TTT
parameters: null
Compression
zstd
level: null

Novel Contributions

  • Fused Triton ReLU² activation kernel with torch.compile fallback
  • Benchmark finding that torch.compile already fuses ReLU² effectively, so the custom Triton kernel yields no speedup
  • QK-Gain 4.0 included
  • Score-first TTT s_0 submission