PR #559

open

Non-record: TernaryRecurrentGPT - ternary 1.58-bit MLP + depth recurrence (1xL4 val_bpb=1.5348)

by Parswanadh
val_bpb
1.5348
Architecture
TernaryRecurrentGPT
Optimizer
Muon + AdamW
Artifact Size
12,372,468 bytes

Training Techniques

Quantization
STE QAT
bits: 1.58
scope: MLP
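A minimal sketch of ternary quantization-aware training with a straight-through estimator (STE). The absmean scale and rounding recipe below is an assumption (the PR does not spell out its quantizer); the function names are hypothetical:

```python
import numpy as np

def ternary_quantize(w, eps=1e-8):
    """Quantize weights to {-1, 0, +1} * scale.

    Assumption: a per-tensor absmean scale with round-and-clip, one
    common recipe for 1.58-bit ternary weights. log2(3) ~= 1.58 bits
    is the information content of a three-valued weight.
    """
    scale = np.abs(w).mean() + eps           # per-tensor absmean scale
    q = np.clip(np.round(w / scale), -1, 1)  # ternary codes in {-1, 0, +1}
    return q * scale, q

def ste_backward(grad_out):
    """Straight-through estimator: the gradient of the loss w.r.t. the
    quantized weights is passed through the non-differentiable round()
    unchanged, so the full-precision master weights keep learning."""
    return grad_out
```

During QAT the forward pass uses the quantized weights while the optimizer updates the full-precision copies via the STE.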
Architecture
depth recurrence
7 unique layers repeated in 2 loops for an effective depth of 14
parameters: {"unique_layers":7,"loops":2,"effective_depth":14}
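The depth recurrence can be sketched as below: 7 unique blocks applied in 2 loops for an effective depth of 14, with a scalar gate per loop initialized at 1.0 (per the Initialization note). How exactly the gate enters the residual path is an assumption; the interpolation form here is illustrative:

```python
def depth_recurrent(x, layers, loops=2, gates=None):
    """Run the unique `layers` `loops` times: 7 layers * 2 loops = depth 14.

    Assumption: each loop's gate interpolates between the loop input and
    its output; initializing at 1.0 means the stack starts as a plain
    deep network (the loop output is taken fully).
    """
    if gates is None:
        gates = [1.0] * loops            # loop gates initialized at 1.0
    calls = 0
    for g in gates:
        y = x
        for layer in layers:             # one pass over the unique layers
            y = layer(y)
            calls += 1
        x = (1 - g) * x + g * y          # g = 1.0 -> pure recurrence
    return x, calls
```

With the gates at 1.0 the parameter count stays that of 7 layers while the compute matches a 14-layer model.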
BigramHash
Bigram hashing with 2048 buckets
parameters: {"buckets":2048}
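One plausible reading of BigramHash: hash each (previous, current) token pair into one of 2048 buckets indexing a small auxiliary embedding table. The multiplier below is an arbitrary odd constant; the PR's exact hash function is not specified:

```python
def bigram_bucket(prev_tok, tok, buckets=2048):
    """Hash a token bigram into one of `buckets` ids.

    Assumption: a multiply-xor hash; any cheap deterministic mix works,
    since collisions are absorbed by the learned bucket embeddings.
    """
    return ((prev_tok * 1000003) ^ tok) % buckets
```

Each position then adds the embedding of its bigram bucket to its usual token embedding, giving the model cheap access to two-token context.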
SmearGate
SmearGate gating mechanism
parameters: null
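The PR lists no parameters for SmearGate, so the form below is an assumption: a learned scalar gate that "smears" each previous position's activation forward into the current one:

```python
import numpy as np

def smear_gate(x, gate_logit=0.0):
    """Blend each position with the previous position's activation:

        out[t] = x[t] + sigmoid(gate_logit) * x[t-1]

    Assumption: one scalar gate shared across positions; position 0 has
    no predecessor and receives no smear.
    """
    g = 1.0 / (1.0 + np.exp(-gate_logit))  # gate in (0, 1)
    prev = np.roll(x, 1, axis=0)           # shift activations forward by one
    prev[0] = 0.0                          # no token before position 0
    return x + g * prev
```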
tied embeddings
FP16 tied embeddings
parameters: {"precision":"FP16"}
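Tied embeddings reuse one matrix for both the input lookup and the output head, halving embedding memory; storing it in FP16 halves it again. A sketch with illustrative sizes (vocab and width are placeholders, not the PR's values):

```python
import numpy as np

vocab, d = 256, 32  # illustrative sizes, not the PR's actual dimensions
rng = np.random.default_rng(0)

# One FP16 matrix serves as both the token embedding and the lm_head.
E = (rng.standard_normal((vocab, d)) * 0.02).astype(np.float16)

tokens = np.array([3, 17, 42])
h = E[tokens]                                            # input lookup
logits = h.astype(np.float32) @ E.T.astype(np.float32)   # output head reuses E
```

Upcasting to FP32 for the logit matmul is a common stability choice, not something the PR states.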
Optimizer
Muon + AdamW
weight_decay: 0.04
momentum: null
other_params: null
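Muon's core step orthogonalizes each 2D weight-matrix update via a Newton-Schulz iteration, while AdamW typically handles embeddings and other non-matrix parameters. Muon itself uses a tuned quintic iteration; the cubic version below is a simplified sketch:

```python
import numpy as np

def newton_schulz_orth(G, steps=30):
    """Approximately orthogonalize a gradient matrix (Muon's core step).

    Simplification: the cubic iteration X <- 1.5 X - 0.5 X X^T X, which
    drives all singular values toward 1; real Muon uses a tuned quintic
    with the same fixed point.
    """
    X = G / (np.linalg.norm(G) + 1e-8)   # put singular values into (0, 1]
    for _ in range(steps):
        X = 1.5 * X - 0.5 * X @ X.T @ X  # push singular values toward 1
    return X
```

The orthogonalized update is then applied with momentum and a learning rate; weight decay (0.04 here) is handled as in AdamW.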
Weight Averaging
SWA
parameters: {"start_percent":40}
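Stochastic weight averaging keeps a running mean of the weights, started here once 40% of training is done (`start_percent: 40`). A sketch with a scalar stand-in for the weight tensors:

```python
def swa_update(avg, w, n_averaged):
    """Incremental running mean of the weights collected so far."""
    return (avg * n_averaged + w) / (n_averaged + 1)

total_steps = 100
start = int(0.40 * total_steps)      # begin averaging at 40% of training

avg, n = None, 0
for step in range(total_steps):
    w = float(step)                  # stand-in for this step's weights
    if step >= start:
        avg = w if n == 0 else swa_update(avg, w, n)
        n += 1
# `avg` is the SWA checkpoint used for evaluation.
```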
Initialization
loop gates initialized at 1.0
Fixes a failure mode observed in PR #319
Sequence Length
sequence_length
train_length: null
eval_length: 512
Other
other
Neural Cache disabled due to +0.028 bpb penalty at this scale
parameters: null

Novel Contributions

  • Ternary 1.58-bit quantization of MLP weights via STE-based QAT
  • Depth recurrence: 7 unique layers repeated twice for an effective depth of 14
  • Loop gates initialized at 1.0 to fix the failure mode seen in PR #319
  • Integration of BigramHash with 2048 buckets and SmearGate
  • FP16 tied embeddings
  • Muon optimizer combined with AdamW (weight decay 0.04)
  • Stochastic weight averaging (SWA) starting at 40% of training
  • Disabling Neural Cache due to negative impact on val_bpb at this scale