PR #653
openfeat(arch): Mish² Activation & PyTorch Native SDPA GQA Core (1.155 BPB) 8xH100
by demirelo
val_bpb
1.1552
Architecture
Transformer
Optimizer
Parameter Banking + Parallel Muon
Artifact Size
15.65 MB
Training Techniques
Architecture
Mish² Activation
Smooth, non-monotonic activation computed as F.mish(x).square(), replacing LeakyReLU²
parameters: null
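The Mish² activation above can be sketched directly from the expression given in the card; the shape of the helper is an illustration, not the PR's module layout.

```python
import torch
import torch.nn.functional as F

def mish_squared(x: torch.Tensor) -> torch.Tensor:
    # Mish(x) = x * tanh(softplus(x)); squaring keeps the output
    # non-negative and smooth, in the spirit of the squared-ReLU family
    # (here replacing LeakyyReLU², per the card: LeakyReLU²).
    return F.mish(x).square()

x = torch.linspace(-3, 3, 7)
y = mish_squared(x)
```

Squaring makes the activation non-negative everywhere while preserving Mish's smoothness, which the PR credits with implicit regularization.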
BigramHash
Hashing mechanism for bigrams with vocabulary size 1536
parameters: {"bigram_vocab_size":1536}
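Only the vocabulary size (1536) is given for BigramHash; the mixing function below is a hypothetical multiplicative hash chosen for illustration, not the PR's actual scheme.

```python
BIGRAM_VOCAB_SIZE = 1536  # from the PR's parameters

def bigram_hash(prev_tok: int, tok: int, mult: int = 31) -> int:
    # Map each consecutive token pair to one of 1536 bigram buckets.
    # The multiplier 31 is an illustrative assumption; the PR does not
    # specify its hash function.
    return (prev_tok * mult + tok) % BIGRAM_VOCAB_SIZE

tokens = [5, 9, 11, 2]
ids = [bigram_hash(a, b) for a, b in zip(tokens, tokens[1:])]
```

The hashed ids would typically index an auxiliary embedding table, letting the model condition on local bigram statistics at negligible parameter cost.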
Partial RoPE
Rotary Positional Embeddings applied to 16 of the 64 head dimensions
parameters: {"dims":"16/64"}
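A minimal sketch of the 16/64 partial-RoPE split, assuming the common half-split rotation pairing; the PR may pair dimensions differently.

```python
import torch

def partial_rope(x: torch.Tensor, rot_dims: int = 16, base: float = 10000.0) -> torch.Tensor:
    # x: (seq, head_dim). Rotate only the first `rot_dims` dimensions;
    # the remaining head_dim - rot_dims pass through unchanged
    # (the "16/64" split from the card).
    seq, head_dim = x.shape
    half = rot_dims // 2
    freqs = base ** (-torch.arange(half, dtype=torch.float32) / half)
    angles = torch.arange(seq, dtype=torch.float32)[:, None] * freqs[None, :]
    cos, sin = angles.cos(), angles.sin()
    x1, x2 = x[:, :half], x[:, half:rot_dims]
    rotated = torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)
    return torch.cat([rotated, x[:, rot_dims:]], dim=-1)

q = torch.randn(8, 64)
out = partial_rope(q)
```

Rotating only a quarter of each head leaves the remaining dimensions position-agnostic, a trade-off several small-model recipes use.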
LayerNorm Scale
LayerNorm scaled by 1/sqrt(layer+1)
parameters: null
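The layer-scaled LayerNorm is a one-line transform; this sketch assumes the 1/sqrt(layer+1) factor is applied to the normalized output, which the card's description implies but does not spell out.

```python
import math
import torch
import torch.nn.functional as F

def scaled_layernorm(x: torch.Tensor, layer: int) -> torch.Tensor:
    # Normalize, then damp by 1/sqrt(layer + 1) so deeper layers
    # contribute progressively smaller residual updates.
    return F.layer_norm(x, x.shape[-1:]) / math.sqrt(layer + 1)

x = torch.randn(4, 64)
y0 = scaled_layernorm(x, 0)  # scale 1
y3 = scaled_layernorm(x, 3)  # scale 1/2
```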
MLP
3× MLP layers using Mish² activation
parameters: {"count":3}
Optimizer
Parameter Banking + Parallel Muon
weight_decay: 0.04
momentum: 0.99
other_params: {"adam_wd":0.04,"muon_momentum_warmup_start":0.92,"muon_momentum_warmup_steps":1500,"matrix_lr":0.025,"scalar_lr":0.025,"tied_embed_lr":0.035}
Weight Averaging
EMA + Tight SWA
parameters: {"ema_decay":0.997,"swa_every":50}
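The EMA decay (0.997) and SWA interval (50 steps) come from the card; how the two averages are combined is not stated, so this sketch makes the plausible assumption that SWA snapshots are taken from the EMA weights.

```python
EMA_DECAY = 0.997
SWA_EVERY = 50  # both values from the PR's parameters

def ema_update(ema, params, decay=EMA_DECAY):
    # ema <- decay * ema + (1 - decay) * current weights
    for k in ema:
        ema[k] = decay * ema[k] + (1.0 - decay) * params[k]

def train_loop_sketch(steps=200):
    params = {"w": 0.0}
    ema = dict(params)
    swa_sum, swa_n = {"w": 0.0}, 0
    for step in range(1, steps + 1):
        params["w"] += 0.01          # stand-in for an optimizer step
        ema_update(ema, params)
        if step % SWA_EVERY == 0:    # "tight" SWA: frequent snapshots
            swa_sum["w"] += ema["w"]
            swa_n += 1
    return {"w": swa_sum["w"] / swa_n}

swa = train_loop_sketch()
```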
Test-Time Training
score-first TTT
parameters: {"chunk_size":32768,"optimizer":"SGD","learning_rate":0.002,"momentum":0.9,"epochs_per_chunk":3,"frozen_blocks":0,"gradient_clip":1,"batch_seqs":32}
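The TTT hyperparameters above can be read as the following loop. The "score-first" ordering is the key point: each chunk is scored with the current weights before the model adapts on it, so no chunk influences its own score. The toy model and loss here are stand-ins; only the SGD settings, epoch count, and clipping come from the card.

```python
import torch
import torch.nn as nn

def loss_fn(model, chunk):
    # Stand-in objective; the actual run scores language-model BPB.
    x, y = chunk
    return nn.functional.mse_loss(model(x), y)

def score_first_ttt(model, chunks, lr=0.002, momentum=0.9, epochs=3, clip=1.0):
    opt = torch.optim.SGD(model.parameters(), lr=lr, momentum=momentum)
    losses = []
    for chunk in chunks:
        with torch.no_grad():
            losses.append(loss_fn(model, chunk).item())  # score first...
        for _ in range(epochs):                          # ...then adapt
            opt.zero_grad()
            loss_fn(model, chunk).backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
            opt.step()
    return losses

model = nn.Linear(4, 1)
chunks = [(torch.randn(8, 4), torch.randn(8, 1)) for _ in range(3)]
losses = score_first_ttt(model, chunks)
```

With frozen_blocks set to 0, every block adapts; the gradient clip of 1 keeps the three SGD epochs per 32768-token chunk from destabilizing the weights.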
Compression
lzma
level: null
Evaluation
sliding window eval
parameters: {"stride":64}
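A sketch of the stride-64 sliding-window bookkeeping: each step scores only the new tokens, conditioned on as much prior context as the window allows. The window size is an assumption; only the stride comes from the card.

```python
def sliding_windows(n_tokens: int, window: int, stride: int = 64):
    # Yield (ctx_start, score_start, end) triples: each step scores the
    # `stride` new tokens in [score_start, end), conditioned on tokens
    # from ctx_start onward, so every token is scored exactly once.
    spans = []
    pos = 0
    while pos < n_tokens:
        end = min(pos + stride, n_tokens)
        ctx_start = max(0, end - window)
        spans.append((ctx_start, pos, end))
        pos = end
    return spans

spans = sliding_windows(300, window=128, stride=64)
```

A small stride trades compute for accuracy: most scored tokens see near-full context instead of the truncated context a non-overlapping split would give them.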
Regularization
weight decay
parameters: {"value":0.04}
LR Schedule
cosine decay
parameters: null
Other
other
Parameter Banking replaces many separate nn.Linear weights with 4 contiguous 3D parameter banks for efficiency
parameters: null
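The banking idea from the line above can be sketched as one 3D parameter indexed per layer; the class name and bank layout here are illustrative, not the PR's actual four-bank partition.

```python
import torch
import torch.nn as nn

class LinearBank(nn.Module):
    # One contiguous 3D parameter (n_layers, d_out, d_in) standing in
    # for n_layers separate nn.Linear weights: a single allocation means
    # one optimizer state slab and one all-reduce per bank instead of
    # one per layer.
    def __init__(self, n_layers: int, d_in: int, d_out: int):
        super().__init__()
        self.bank = nn.Parameter(torch.randn(n_layers, d_out, d_in) * 0.02)

    def forward(self, x: torch.Tensor, layer: int) -> torch.Tensor:
        return x @ self.bank[layer].T

bank = LinearBank(n_layers=12, d_in=64, d_out=64)
y = bank(torch.randn(2, 64), layer=3)
```

Banking pairs naturally with the Parallel Muon optimizer, which can orthogonalize all layers' matrices in one batched operation.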
other
PyTorch-native scaled dot product attention loop replacing flash-attn C++ dependency for robust multi-GPU synchronization
parameters: null
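The PyTorch-native GQA loop replacing flash-attn might look like the following: each group of query heads attends over one shared KV head via `F.scaled_dot_product_attention`. The group loop and `expand` are an assumption about the PR's exact formulation.

```python
import torch
import torch.nn.functional as F

def gqa_sdpa(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    # q: (B, Hq, T, D); k, v: (B, Hkv, T, D) with Hq a multiple of Hkv.
    # Each group of Hq // Hkv query heads shares one KV head; expand()
    # broadcasts the KV head across the group without copying memory.
    b, hq, t, d = q.shape
    hkv = k.shape[1]
    group = hq // hkv
    outs = []
    for g in range(hkv):
        qg = q[:, g * group:(g + 1) * group]
        outs.append(F.scaled_dot_product_attention(
            qg,
            k[:, g:g + 1].expand(-1, group, -1, -1),
            v[:, g:g + 1].expand(-1, group, -1, -1),
            is_causal=True))
    return torch.cat(outs, dim=1)

q = torch.randn(2, 8, 16, 32)
k = torch.randn(2, 2, 16, 32)
v = torch.randn(2, 2, 16, 32)
out = gqa_sdpa(q, k, v)
```

Staying inside PyTorch drops the flash-attn C++ extension, which the PR cites as the source of multi-GPU synchronization fragility on the 8xH100 run.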
Novel Contributions
- Integration of Mish² activation (F.mish(x).square()) for improved implicit regularization and BPB reduction
- PyTorch-native GQA-aware scaled dot product attention loop replacing flash-attn C++ dependency for stable multi-GPU training
- Use of Parameter Banking combined with Parallel Muon optimizer for faster training throughput
- Contest-legal, score-first Test-Time Training (TTT) protocol: each validation chunk is scored before SGD adaptation on it, combined with sliding-window evaluation
- Partial RoPE positional embeddings and layer-scaled LayerNorm for architectural improvements
- Artifact size kept under the 16 MB limit via GPTQ-lite int6 quantization and lzma compression