PR #498

closed

The Frugendorff: Recursive Weight Sharing + MLP 4x (1.1478 BPB, 15.19MB)

by newjordan
val_bpb
1.1478
Architecture
Transformer
Optimizer
Muon
Artifact Size
15.19MB

Training Techniques

Architecture
depth recurrence / weight sharing
6 unique transformer blocks are looped twice each to create 12 effective layers while storing only 6 blocks of parameters.
parameters: {"unique_blocks":6,"loops":2,"effective_layers":12}
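The loop structure can be sketched minimally. The scalar "blocks" below are illustrative stand-ins for real transformer blocks, not the PR's modules:

```python
def recurrent_forward(x, blocks, loops=2):
    # Apply each shared block once per loop:
    # 6 unique blocks x 2 loops = 12 effective layers,
    # while only 6 blocks of parameters are stored.
    for _ in range(loops):
        for block in blocks:
            x = block(x)
    return x

# Toy stand-ins: each "block" adds 1, so 6 blocks x 2 loops shifts x by 12.
toy_blocks = [lambda x: x + 1 for _ in range(6)]
depth = recurrent_forward(0, toy_blocks)
```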
MLP4x
Expanded MLP hidden size to 4x the model dimension to improve quality.
parameters: {"multiplier":4,"hidden_size":2560}
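A minimal shape sketch, assuming the model dimension is 640 (implied by hidden_size 2560 = 4 x 640); the ReLU and zero weights are placeholders, not the PR's actual activation or initialization:

```python
import numpy as np

d_model = 2560 // 4           # implied model dimension: 640
hidden = 4 * d_model          # PR's MLP hidden size: 2560

def mlp(x, w_in, w_out):
    # Expand to the 4x hidden width, apply a nonlinearity, project back.
    h = np.maximum(x @ w_in, 0.0)   # ReLU stand-in for the real activation
    return h @ w_out

w_in = np.zeros((d_model, hidden))
w_out = np.zeros((hidden, d_model))
y = mlp(np.zeros((1, d_model)), w_in, w_out)
```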
GQA
Grouped-query attention with fewer KV heads than attention heads.
parameters: {"heads":10,"kv_heads":5,"head_dim":64}
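The KV-head sharing can be sketched as a repeat along the head axis; the sequence length and zero arrays are illustrative:

```python
import numpy as np

heads, kv_heads, head_dim = 10, 5, 64   # PR configuration

def expand_kv(kv, heads=heads):
    # Each KV head is shared by heads // kv_heads query heads,
    # so K/V are repeated along the head axis to match the query count.
    return np.repeat(kv, heads // kv_heads, axis=0)

seq = 8
k = np.zeros((kv_heads, seq, head_dim))
k_expanded = expand_kv(k)
```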
Partial RoPE
Applies rotary position embeddings to only part of the head dimension with NTK-aware scaling.
parameters: {"rope_dims":16,"total_dims":64}
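A minimal sketch of rotating only the first 16 of 64 head dimensions. The frequency schedule is the standard RoPE one; the NTK-aware base rescaling (a context-length-dependent adjustment to `base`) is omitted here for brevity:

```python
import numpy as np

def partial_rope(x, pos, rope_dims=16, base=10000.0):
    # Rotate only the first `rope_dims` components; pass the rest through.
    # NTK-aware scaling would additionally rescale `base`.
    half = rope_dims // 2
    freqs = base ** (-np.arange(half) / half)
    ang = pos * freqs
    cos, sin = np.cos(ang), np.sin(ang)
    x1, x2 = x[:half], x[half:rope_dims]
    rotated = np.concatenate([x1 * cos - x2 * sin, x1 * sin + x2 * cos])
    return np.concatenate([rotated, x[rope_dims:]])

v = np.ones(64)
v0 = partial_rope(v, pos=0)    # position 0: rotation is the identity
v3 = partial_rope(v, pos=3)    # non-rope dims stay untouched at any position
```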
tied embeddings
Input and output embeddings are tied.
parameters: null
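Weight tying in its simplest form: the output projection reuses the input embedding table, so only one matrix is stored. The sizes below are placeholders, not from the PR:

```python
import numpy as np

vocab_size, d_model = 1000, 64     # placeholder sizes, not from the PR

embed = np.random.default_rng(0).standard_normal((vocab_size, d_model))
lm_head_weight = embed             # tied: one table stored, used twice

# Input side looks up embed[token]; output side computes hidden @ lm_head_weight.T,
# so a gradient update through either view updates the single shared matrix.
```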
XSA
XSA is applied in the last 2 unique layers.
parameters: {"layers":2}
SmearGate
Additional gating mechanism used in the architecture.
parameters: null
BigramHash
Auxiliary hashed bigram feature module.
parameters: {"buckets":2048,"dim":128}
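A hedged sketch of hashed bigram features: each (previous, current) token pair is hashed into one of 2048 buckets, which indexes a 128-dim learned table. The hash function and the BOS placeholder are illustrative; the real module's choices may differ:

```python
import numpy as np

BUCKETS, DIM = 2048, 128   # PR parameters

def bigram_buckets(tokens, buckets=BUCKETS):
    # Hash each (prev, cur) token pair to a bucket id.
    ids, prev = [], 0      # 0 as an assumed BOS placeholder
    for t in tokens:
        ids.append(hash((prev, t)) % buckets)
        prev = t
    return ids

table = np.zeros((BUCKETS, DIM))        # learned embedding table in the real model
ids = bigram_buckets([5, 17, 5, 17])    # repeated bigram (5, 17) -> same bucket
feats = table[ids]
```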
Shared Value Embedding
Uses a shared value embedding to reduce parameters.
parameters: {"dim":128}
U-Net skips
Skip connections are used within each loop iteration.
parameters: null
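A minimal sketch of U-Net style skips: activations from the first half of the blocks are stashed and added back into the second half, last-in first-out. Per the listing this happens within each loop iteration; the scalar stand-in blocks are illustrative:

```python
def unet_forward(x, blocks):
    # First half: run blocks and stash each output.
    n = len(blocks)
    stash = []
    for block in blocks[: n // 2]:
        x = block(x)
        stash.append(x)
    # Second half: add back stashed activations, last-in first-out.
    for block in blocks[n // 2 :]:
        x = block(x) + stash.pop()
    return x

toy_blocks = [lambda x: x + 1 for _ in range(6)]
out = unet_forward(0, toy_blocks)
```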
Initialization
Orthogonal loop positions
Loop position embeddings are QR-initialized to differentiate repeated passes through shared blocks.
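QR initialization yields mutually orthogonal loop embeddings; d_model = 640 is inferred from hidden_size 2560 = 4 x 640 and may differ from the PR:

```python
import numpy as np

def orthogonal_loop_embeddings(loops=2, d_model=640, seed=0):
    # QR-decompose a random matrix; the columns of Q are orthonormal,
    # so each loop pass gets a distinct, orthogonal position vector.
    rng = np.random.default_rng(seed)
    q, _ = np.linalg.qr(rng.standard_normal((d_model, loops)))
    return q.T   # shape (loops, d_model)

emb = orthogonal_loop_embeddings()
```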
Optimizer
Muon
weight_decay: null
momentum: 0.99
other_params: {"lr":0.025,"scope":"matrices"}
AdamW
weight_decay: null
momentum: null
other_params: {"lr":0.035,"scope":"embeddings"}
AdamW
weight_decay: null
momentum: null
other_params: {"lr":0.025,"scope":"scalars"}
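The three optimizer groups above, restated as plain config dicts (a summary of the listing, not the PR's actual code):

```python
# Muon preconditions the matrix parameters; AdamW handles the rest.
optimizer_groups = [
    {"optimizer": "Muon",  "scope": "matrices",   "lr": 0.025, "momentum": 0.99},
    {"optimizer": "AdamW", "scope": "embeddings", "lr": 0.035},
    {"optimizer": "AdamW", "scope": "scalars",    "lr": 0.025},
]
```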
Weight Averaging
SWA
parameters: {"every_steps":50,"condition":"scale < 0.2"}
EMA
parameters: {"decay":0.997,"applied_after":"distillation"}
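The EMA update in isolation, with the listed decay of 0.997 (applied after distillation per the listing); weights are shown as a flat list for brevity:

```python
def ema_update(avg, new, decay=0.997):
    # Exponential moving average: avg <- decay * avg + (1 - decay) * new.
    return [decay * a + (1 - decay) * n for a, n in zip(avg, new)]

avg = ema_update([0.0], [1.0])   # one step moves 0.3% of the way toward the target
```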
Quantization
int6 QAT
bits: 6
scope: MLP and attention weights
int8
bits: 8
scope: embeddings
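A symmetric fake-quantization sketch covering both settings (6 bits for MLP/attention weights, 8 for embeddings). The real QAT presumably uses per-channel scales and a straight-through estimator, which are omitted here:

```python
import numpy as np

def fake_quant(w, bits):
    # Symmetric fake quantization: snap to a signed integer grid, dequantize.
    qmax = 2 ** (bits - 1) - 1          # 31 for int6, 127 for int8
    scale = np.abs(w).max() / qmax
    return np.round(w / scale) * scale

w = np.array([0.5, -1.0, 0.25])
w6 = fake_quant(w, 6)   # MLP and attention weights in the PR
w8 = fake_quant(w, 8)   # embeddings
```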
Compression
zstd
level: 22
Evaluation
sliding window eval
parameters: {"stride":64}
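One common sliding-window evaluation scheme, sketched under the assumption that windows advance by the stated stride of 64 tokens (the PR's exact scoring rule is not given):

```python
def window_starts(seq_len, window, stride=64):
    # Start offsets of evaluation windows; with stride < window, each token
    # is scored with more left context than disjoint-chunk evaluation gives.
    return list(range(0, max(seq_len - window, 0) + 1, stride))

starts = window_starts(seq_len=256, window=128)
```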
Regularization
layerwise LN scale
parameters: {"scale":"1/sqrt(layer_idx+1)"}
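The per-layer scale from the listing, 1/sqrt(layer_idx + 1), as a one-liner:

```python
import math

def ln_scale(layer_idx):
    # Layerwise LayerNorm gain: deeper layers get a smaller scale.
    return 1.0 / math.sqrt(layer_idx + 1)

scales = [ln_scale(i) for i in range(12)]   # one per effective layer
```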
Other
other
Late-training replay of the last 100 training batches for 2 epochs at 10% of the base learning rate.
parameters: {"batches":100,"epochs":2,"lr_fraction":0.1}
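The replay loop, sketched around a `train_step` callback (a placeholder, not the PR's API):

```python
def late_replay(train_step, last_batches, base_lr, epochs=2, lr_fraction=0.1):
    # Re-run the final training batches at a fraction of the base learning rate.
    lr = base_lr * lr_fraction
    for _ in range(epochs):
        for batch in last_batches:
            train_step(batch, lr)

seen = []
late_replay(lambda b, lr: seen.append((b, lr)), [1, 2, 3], base_lr=1.0)
```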
other
Self-distillation using an EMA teacher during training.
parameters: {"teacher":"EMA","steps":50,"temperature":2,"alpha":0.7}
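A standard distillation loss using the listed T=2 and alpha=0.7. The blend with hard-label cross-entropy and the conventional T^2 factor on the KL term are assumptions, not confirmed details of the PR:

```python
import numpy as np

def softmax(z, t=1.0):
    e = np.exp(z / t - np.max(z / t))
    return e / e.sum()

def distill_loss(student_logits, teacher_logits, target, temperature=2.0, alpha=0.7):
    # alpha-weighted KL to the (EMA) teacher plus hard-label cross-entropy.
    ce = -np.log(softmax(student_logits)[target])
    p_t = softmax(teacher_logits, temperature)
    p_s = softmax(student_logits, temperature)
    kl = float(np.sum(p_t * (np.log(p_t) - np.log(p_s))))
    return alpha * kl * temperature**2 + (1 - alpha) * ce

logits = np.array([1.0, 2.0, 0.5])
loss_same = distill_loss(logits, logits, target=1)   # teacher == student: KL term is 0
```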

Novel Contributions

  • Recursive/fractal weight sharing to achieve 12 effective layers using only 6 stored transformer blocks
  • Reinvesting saved parameter budget into MLP 4x expansion
  • Orthogonal loop position embeddings to distinguish repeated passes through shared blocks
  • Combination of U-Net skips, SmearGate, BigramHash, shared value embedding, and XSA in a compact transformer
  • Full training pipeline including Muon, SWA, late QAT, training replay, self-distillation, and EMA