PR #352

open

Memory Tokens + Mixed Quantization (val_bpb: 1.1659)

val_bpb
1.1659
Architecture
Transformer
Optimizer
Muon
Artifact Size
15,070,662 bytes

Training Techniques

Architecture
Memory Tokens
64 learnable embedding vectors overwrite (or are prepended to) the first K positions of each sequence, providing a global context scratchpad that later positions can read through causal attention.
parameters: {"num_memory_tokens":64}
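A minimal NumPy sketch of the idea, using the PR's num_memory_tokens=64; d_model and seq_len are illustrative, and the sketch assumes the "overwrite" variant (the learned vectors replace the first K positions):

```python
import numpy as np

rng = np.random.default_rng(0)
num_memory_tokens = 64   # K, from the PR parameters
d_model = 128            # illustrative; the real model dim is not stated here
seq_len = 256

# Learnable memory embeddings (a trainable parameter in the real model).
memory = rng.normal(scale=0.02, size=(num_memory_tokens, d_model))

def apply_memory_tokens(token_embeddings: np.ndarray) -> np.ndarray:
    """Overwrite the first K positions of the sequence with the shared
    memory embeddings, so every later position (via causal attention)
    has a global scratchpad to read from."""
    out = token_embeddings.copy()
    out[:num_memory_tokens] = memory
    return out

x = rng.normal(size=(seq_len, d_model))
y = apply_memory_tokens(x)
```

The memory vectors carry no token-specific content; they only become useful because attention lets the model write-free read of a shared, learned prefix at every position.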
MLP3x
Uses 3x MLP expansion in the transformer blocks.
parameters: {"mlp_multiplier":3}
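A sketch of the narrower feed-forward block, assuming a plain two-matrix MLP; the ReLU is a placeholder since the PR does not name the activation:

```python
import numpy as np

def relu(x):
    return np.maximum(x, 0.0)

def mlp3x(x, w_in, w_out):
    """Feed-forward block with hidden width 3 * d_model
    (mlp_multiplier = 3, vs. the conventional 4x)."""
    return relu(x @ w_in) @ w_out

d = 32
rng = np.random.default_rng(0)
x = rng.normal(size=(4, d))
y = mlp3x(x, rng.normal(size=(d, 3 * d)), rng.normal(size=(3 * d, d)))
```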
BigramHash
Hashes consecutive token pairs to inject local context via BigramHashEmbedding.
parameters: {"vocab_size":10240}
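A NumPy sketch of the bigram-hash lookup using the PR's table size of 10240; the multiplicative hash constant and the embedding width are assumptions, as the PR does not specify them:

```python
import numpy as np

vocab_size = 10240  # hash-table size from the PR parameters
d_model = 64        # illustrative
rng = np.random.default_rng(0)
bigram_table = rng.normal(scale=0.02, size=(vocab_size, d_model))

def bigram_hash_ids(tokens: np.ndarray) -> np.ndarray:
    """Hash each (previous, current) token pair into [0, vocab_size).
    The multiplier 1000003 is an arbitrary odd constant; the real
    hash function is not specified in the PR."""
    prev = np.concatenate(([0], tokens[:-1]))  # pad position 0
    return (prev * 1000003 + tokens) % vocab_size

tokens = np.array([5, 17, 17, 42])
ids = bigram_hash_ids(tokens)
bigram_embs = bigram_table[ids]   # added to the token embeddings
```

Collisions are expected at this table size; the model simply learns around them.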
SmearGate
Learned gate blends each token's embedding with the previous token's embedding.
parameters: null
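One plausible parameterization as a NumPy sketch; the gate here is a per-position scalar from a learned projection, which is an assumption since the PR lists no parameters:

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def smear_gate(emb: np.ndarray, w: np.ndarray, b: float) -> np.ndarray:
    """Blend each position's embedding with the previous position's
    embedding using a learned scalar gate g in (0, 1):
        out_t = (1 - g_t) * emb_t + g_t * emb_{t-1}
    w and b would be learned; here they are random placeholders."""
    prev = np.vstack([emb[:1], emb[:-1]])  # shift right; position 0 sees itself
    g = sigmoid(emb @ w + b)[:, None]      # one gate value per position
    return (1.0 - g) * emb + g * prev

rng = np.random.default_rng(0)
emb = rng.normal(size=(8, 16))
out = smear_gate(emb, rng.normal(size=16), 0.0)
```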
Partial RoPE
Applies rotary position encoding to only part of the head dimensions, leaving the rest content-only.
parameters: {"dimensions":16,"total_dimensions":64}
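A NumPy sketch for a single head, rotating only the first 16 of 64 dimensions per the PR parameters (the rotate-half pairing and base 10000 are conventional choices, not stated in the PR):

```python
import numpy as np

def partial_rope(x: np.ndarray, rot_dims: int = 16) -> np.ndarray:
    """Apply rotary position encoding to only the first `rot_dims`
    of each head's dimensions (16 of 64 per the PR); the remaining
    dimensions pass through untouched as content-only channels."""
    seq_len, head_dim = x.shape
    half = rot_dims // 2
    pos = np.arange(seq_len)[:, None]             # (T, 1)
    freq = 10000.0 ** (-np.arange(half) / half)   # (half,)
    angle = pos * freq                            # (T, half)
    cos, sin = np.cos(angle), np.sin(angle)
    x1, x2 = x[:, :half], x[:, half:rot_dims]
    rotated = np.concatenate([x1 * cos - x2 * sin,
                              x1 * sin + x2 * cos], axis=-1)
    return np.concatenate([rotated, x[:, rot_dims:]], axis=-1)

x = np.ones((4, 64))
y = partial_rope(x)
```

At position 0 the rotation angle is zero, so the head vector passes through unchanged; only the first 16 dims vary with position.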
tied embeddings
Input and output embeddings are tied.
parameters: null
KV head count
Uses grouped-query attention with fewer KV heads than attention heads.
parameters: {"attention_heads":8,"kv_heads":4}
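A NumPy sketch of grouped-query attention with the PR's 8 query heads over 4 KV heads (head dim and sequence length are illustrative):

```python
import numpy as np

def grouped_query_attention(q, k, v):
    """Causal attention with 8 query heads sharing 4 KV heads
    (per the PR): each KV head is repeated to serve a group of
    2 query heads. Shapes: q is (Hq, T, d); k, v are (Hkv, T, d)."""
    hq, t, d = q.shape
    group = hq // k.shape[0]
    k = np.repeat(k, group, axis=0)                 # (Hq, T, d)
    v = np.repeat(v, group, axis=0)
    scores = q @ k.transpose(0, 2, 1) / np.sqrt(d)  # (Hq, T, T)
    # causal mask: position t attends only to positions <= t
    mask = np.triu(np.ones((t, t), dtype=bool), k=1)
    scores = np.where(mask, -np.inf, scores)
    w = np.exp(scores - scores.max(axis=-1, keepdims=True))
    w /= w.sum(axis=-1, keepdims=True)
    return w @ v

rng = np.random.default_rng(0)
q = rng.normal(size=(8, 5, 16))
k = rng.normal(size=(4, 5, 16))
v = rng.normal(size=(4, 5, 16))
out = grouped_query_attention(q, k, v)
```

The payoff is a KV cache (and KV parameter count) half the size of full multi-head attention, with the query heads still fully independent.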
Regularization
LN Scale
Scales each layer's output by 1/sqrt(layer+1).
parameters: {"scale":"1/sqrt(layer+1)"}
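The schedule from the PR's parameters, as a one-liner; where exactly the scale attaches (residual branch vs. LayerNorm gain) is not specified:

```python
import numpy as np

def ln_scale(layer_index: int) -> float:
    """Depth-dependent scale 1/sqrt(layer+1) from the PR parameters,
    applied per layer (0-indexed)."""
    return 1.0 / np.sqrt(layer_index + 1)

scales = [ln_scale(i) for i in range(4)]  # decays from 1.0 toward 0.5
```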
Optimizer
Muon
weight_decay: 0.04
momentum: 0.95
other_params: {"matrix_lr":0.04}
AdamW
weight_decay: 0.04
momentum: null
other_params: {"scope":"embeddings/scalars"}
Weight Averaging
EMA
parameters: {"decay":0.997,"update_every_steps":10}
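A sketch of the EMA update with the PR's decay=0.997 and update_every_steps=10, on a toy scalar weight:

```python
import numpy as np

def ema_update(ema_w, w, decay=0.997):
    """Exponential moving average of weights:
    ema <- decay * ema + (1 - decay) * w."""
    return decay * ema_w + (1.0 - decay) * w

w = np.zeros(3)     # stand-in for the current training weights
ema = np.ones(3)    # stand-in for the averaged copy
for step in range(1, 101):
    if step % 10 == 0:          # update_every_steps = 10
        ema = ema_update(ema, w)
```

Updating only every 10 steps makes the averaging nearly free; the EMA copy is what gets evaluated and exported.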
Quantization
mixed int5/int6 QAT
bits: mixed (5 and 6; see scope)
scope: MLP weights int5, attention weights int6, embeddings fp16
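A NumPy sketch of the per-tensor fake quantization at the two bit widths; symmetric scaling off the max-abs value is an assumption, as the PR does not describe the scaling scheme:

```python
import numpy as np

def fake_quant(w: np.ndarray, bits: int) -> np.ndarray:
    """Symmetric per-tensor fake quantization: map weights onto a
    signed integer grid with `bits` bits, then back to float.
    Per the PR: int5 for MLP weights, int6 for attention weights;
    embeddings stay fp16 and skip this entirely."""
    qmax = 2 ** (bits - 1) - 1
    scale = np.abs(w).max() / qmax
    return np.clip(np.round(w / scale), -qmax - 1, qmax) * scale

rng = np.random.default_rng(0)
mlp_w = fake_quant(rng.normal(size=(8, 8)), bits=5)   # at most 32 levels
attn_w = fake_quant(rng.normal(size=(8, 8)), bits=6)  # at most 64 levels
```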
Compression
zstd
level: 22
Evaluation
sliding window eval
parameters: {"stride":128,"seq_len":1024,"batched_windows":256,"compiled_forward_logits":true}
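A sketch of the window/stride bookkeeping with the PR's stride=128 and seq_len=1024; each window scores only its last stride tokens (the first window scores everything), so every token is scored exactly once with long left context. The batching of 256 windows and the compiled forward pass are throughput tricks omitted here:

```python
def sliding_windows(n_tokens: int, seq_len: int = 1024, stride: int = 128):
    """Enumerate (window_start, score_from) pairs for sliding-window
    evaluation: window_start is the window's first token, score_from
    is the offset within the window where scoring begins."""
    wins = []
    pos = 0
    while pos + seq_len <= n_tokens:
        score_from = 0 if pos == 0 else seq_len - stride
        wins.append((pos, score_from))
        pos += stride
    return wins

wins = sliding_windows(2048)
```

With these parameters, 2048 tokens need 9 windows: the first covers 1024 tokens, and each later one contributes its final 128.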
Sequence Length
sequence_length
train_length: 2048
eval_length: 1024
LR Schedule
warmdown
parameters: {"warmdown_steps":3000}
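A sketch of the schedule with the PR's warmdown_steps=3000; a constant phase followed by a linear ramp to zero is the usual shape of "warmdown", and total_steps is illustrative since the run length is not stated:

```python
def lr_scale(step: int, total_steps: int, warmdown_steps: int = 3000) -> float:
    """Constant LR for most of training, then a linear warmdown
    to zero over the final `warmdown_steps` optimizer steps."""
    return min(1.0, (total_steps - step) / warmdown_steps)

schedule = [lr_scale(s, 10000) for s in (0, 7000, 8500, 10000)]
```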
Other
other
Multi-token prediction (MTP) auxiliary heads are used during training and stripped before export.
parameters: {"k":2,"alpha":0.2}
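A sketch of combining the losses with the PR's k=2 and alpha=0.2; averaging the auxiliary losses before weighting is an assumption, as the PR does not give the combination rule:

```python
import numpy as np

def mtp_loss(main_loss: float, aux_losses: np.ndarray, alpha: float = 0.2) -> float:
    """Total training loss: next-token loss plus alpha times the mean
    of the k auxiliary multi-token-prediction losses. The aux heads
    exist only for this term and are dropped before export."""
    return main_loss + alpha * np.mean(aux_losses)

total = mtp_loss(2.0, np.array([2.5, 3.0]))  # 2.0 + 0.2 * 2.75
```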
other
Late QAT: fake int6 quantization via the straight-through estimator (STE), enabled once lr_scale drops below 0.1.
parameters: {"quantization":"int6","method":"STE QAT"}
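A forward-pass sketch of the gating logic; NumPy has no autograd, so the STE half of the trick is described in the docstring rather than implemented (in an autograd framework it would be `w + stop_gradient(quant(w) - w)`). The max-abs scaling is an assumption:

```python
import numpy as np

def qat_forward(w: np.ndarray, lr_scale: float, bits: int = 6) -> np.ndarray:
    """Late QAT gate: run full precision while the LR is still high,
    then switch to fake int6 quantization once lr_scale < 0.1.
    With a straight-through estimator, backward would treat the
    round() as identity so gradients flow to the float weights."""
    if lr_scale >= 0.1:
        return w
    qmax = 2 ** (bits - 1) - 1
    scale = np.abs(w).max() / qmax
    return np.clip(np.round(w / scale), -qmax - 1, qmax) * scale

w = np.array([0.5, -0.3, 0.31])
early = qat_forward(w, lr_scale=1.0)   # unchanged, full precision
late = qat_forward(w, lr_scale=0.05)   # snapped to the int6 grid
```

Delaying QAT until the schedule has mostly decayed lets the model train unconstrained first, then adapt to the quantization grid during the low-LR tail.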

Novel Contributions

  • Memory tokens: 64 learnable embedding vectors used as a global context scratchpad
  • A/B tested improvement from memory tokens (-0.014 BPB)
  • Mixed quantization with int5 MLP weights and int6 attention weights
  • Late QAT with fake int6 quantization
  • Batched sliding window evaluation with compiled forward_logits