PR #481

closed

Record: Cosine TTT scheduling with per-layer lr — mean val_bpb=1.0970 (3 seeds)

by mrdavtan
val_bpb
1.0970
Architecture
Transformer
Optimizer
AdamW
Artifact Size
15.4–15.8 MB

Training Techniques

Quantization
int6
bits: 6
scope: per-row, all weights
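A minimal sketch of per-row symmetric int6 quantization, assuming the usual scheme where each weight row gets its own scale so its max-magnitude entry maps to ±31 (the exact rounding/clipping details of this record's implementation are not given):

```python
import numpy as np

def quantize_int6_per_row(w: np.ndarray):
    # Symmetric per-row quantization to 6 bits: each row gets its own
    # scale so that its max-magnitude weight maps to +/-31.
    qmax = 2 ** (6 - 1) - 1                      # 31
    scale = np.abs(w).max(axis=1, keepdims=True) / qmax
    scale[scale == 0] = 1.0                      # guard all-zero rows
    q = np.clip(np.round(w / scale), -qmax - 1, qmax).astype(np.int8)
    return q, scale

def dequantize(q: np.ndarray, scale: np.ndarray) -> np.ndarray:
    return q.astype(np.float32) * scale

w = np.random.randn(4, 8).astype(np.float32)
q, s = quantize_int6_per_row(w)
w_hat = dequantize(q, s)
```

Per-row scales bound the reconstruction error of each row by half its own scale, which is why rows with larger dynamic range accumulate more quantization damage.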
Compression
zstd
level: 22
Architecture
Partial RoPE
Applies rotary positional embeddings to only a subset of the head dimensions (16 of 64).
parameters: {"dimensions":"16/64"}
LN Scale
LayerNorm scaling modification.
parameters: null
SmearGate
Custom gating mechanism in the community stack.
parameters: null
BigramHash
Bigram hashing component with 2048 buckets.
parameters: {"buckets":2048}
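A sketch of what a bigram-hash component with 2048 buckets could look like: each consecutive token pair is hashed into a bucket id that can index a learned embedding table. The mixing constant is illustrative, not the record's exact hash:

```python
def bigram_hash(tokens, buckets: int = 2048):
    # Hash each consecutive token pair into one of `buckets` slots;
    # the bucket id would then index a learned embedding table.
    # The multiplier 1000003 is an illustrative mixing prime.
    return [(a * 1000003 + b) % buckets for a, b in zip(tokens, tokens[1:])]

ids = bigram_hash([5, 17, 17, 992], buckets=2048)
```

A sequence of n tokens yields n-1 bucket ids, one per bigram.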
MLP3x
Three-times-larger (wider/deeper) MLP stack using relu-squared activations.
parameters: {"multiplier":3}
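A sketch of the relu-squared activation and a 3x-widened MLP forward pass. Reading "multiplier: 3" as a width factor on a conventional 4·d_model hidden size is an assumption; the record does not spell out whether the factor applies to width or depth:

```python
import numpy as np

def relu2(x: np.ndarray) -> np.ndarray:
    # relu-squared activation: max(0, x) ** 2
    return np.maximum(x, 0.0) ** 2

def mlp3x(x: np.ndarray, d_model: int, rng) -> np.ndarray:
    # Hypothetical 3x-wide MLP: hidden size 3 * 4 * d_model instead of
    # the conventional 4 * d_model.
    hidden = 3 * 4 * d_model
    w1 = rng.standard_normal((d_model, hidden)) / np.sqrt(d_model)
    w2 = rng.standard_normal((hidden, d_model)) / np.sqrt(hidden)
    return relu2(x @ w1) @ w2

rng = np.random.default_rng(0)
x = rng.standard_normal((2, 8))
y = mlp3x(x, 8, rng)
```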
tied embeddings
Input and output embeddings are tied.
parameters: null
U-Net skips
U-Net-inspired skip connections in the architecture.
parameters: null
Initialization
OrthoInit
Orthogonal initialization.
Weight Averaging
EMA
parameters: {"decay":0.997}
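The EMA update with decay 0.997 can be sketched as follows (a minimal per-tensor version; the real implementation would walk the model's parameter list):

```python
def ema_update(avg, current, decay: float = 0.997):
    # Exponential moving average of weights:
    #   avg <- decay * avg + (1 - decay) * current
    return [decay * a + (1.0 - decay) * c for a, c in zip(avg, current)]

avg = [0.0]
for _ in range(3):
    avg = ema_update(avg, [1.0])
```

With decay 0.997 the average has an effective horizon of roughly 1/(1 - 0.997) ≈ 333 steps.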
Optimizer
AdamW
weight_decay: null
momentum: null
other_params: {"TTT":true,"learning_rate":0.0005}
Test-Time Training
full TTT
parameters: {"optimizer":"AdamW","learning_rate":0.0005,"epochs":30,"cosine_decay":true,"per_layer_lr":true,"freeze_blocks":0,"batch_seqs_per_gpu":64}
LR Schedule
cosine decay
parameters: {"epochs":30}
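The two scheduling ideas above — cosine decay of the TTT learning rate over 30 epochs and per-layer learning-rate groups scaled by measured quantization damage — can be sketched as below. The damage vector is a placeholder for the per-layer error measurements described in the record, not actual measured values:

```python
import math

def cosine_lr(base_lr: float, epoch: int, total_epochs: int = 30) -> float:
    # Cosine decay from base_lr at epoch 0 to 0 at total_epochs,
    # replacing a flat TTT schedule.
    return base_lr * 0.5 * (1.0 + math.cos(math.pi * epoch / total_epochs))

def per_layer_lrs(n_layers: int, base_lr: float = 5e-4, damage=None):
    # Scale each layer's lr by its normalized quantization damage, so the
    # most-damaged layer trains at base_lr and less-damaged layers slower.
    # `damage` is a hypothetical stand-in for measured per-layer error.
    if damage is None:
        damage = [1.0] * n_layers
    top = max(damage)
    return [base_lr * d / top for d in damage]
```

In a PyTorch-style setup, each entry of `per_layer_lrs` would seed one optimizer parameter group, and `cosine_lr` would be applied per epoch on top of those group base rates.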
Regularization
layerwise LN scale
parameters: null

Novel Contributions

  • Cosine learning-rate decay for TTT over 30 epochs instead of a flat schedule
  • Per-layer TTT learning-rate groups based on measured quantization damage
  • Analysis showing MLP output projections have much higher quantization error than input projections
  • Demonstration that TTT improves beyond merely repairing quantization damage
  • Extensive negative-result exploration of alternative compression and architectural ideas