PR #399 (open)
Record: Parallel Muon + Parameter Banking — 81.87ms/step, val_bpb 1.1247 (3-seed mean)
by abaybektursun
val_bpb
1.1247
Architecture
Transformer
Optimizer
Parallel Muon
Artifact Size
~15.8 MB
Training Techniques
Quantization
int6
bits: 6
scope: evaluation artifact / model weights
Architecture
Parameter Banking
Restructures 66 separate linear weight matrices into 4 contiguous 3D parameter banks to enable batched optimizer operations.
parameters: {"qo_bank":[22,512,512],"kv_bank":[22,256,512],"mlp_up_bank":[11,1536,512],"mlp_down_bank":[11,512,1536]}
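A minimal sketch of the banking idea, using the qo_bank shape listed above (layer count and dims are taken from this record's parameters; variable names are illustrative):

```python
import torch

# 22 attention layers, model dim 512 (from the qo_bank entry above).
n_layers, d_model = 22, 512

# Before: 22 separate per-layer weight matrices, each touched by its
# own optimizer operation.
per_layer_weights = [torch.randn(d_model, d_model) for _ in range(n_layers)]

# After: one contiguous 3D bank; batched kernels (e.g. torch.bmm) can
# process all 22 layers in a single launch.
qo_bank = torch.stack(per_layer_weights, dim=0)  # [22, 512, 512]
assert qo_bank.is_contiguous()

# A layer's forward pass still sees its own 2D slice (a view, no copy).
w_layer_3 = qo_bank[3]
```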
Partial RoPE
Uses partial rotary positional embeddings as part of the base architecture.
parameters: {"dimensions":16}
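A hedged sketch of partial RoPE, assuming `dimensions: 16` means only the first 16 channels of each head are rotated while the remaining channels pass through unrotated (a common convention; the PR's exact slicing may differ):

```python
import torch

def partial_rope(x, rot_dims=16, base=10000.0):
    # x: [batch, seq, heads, head_dim]; rotate only the first rot_dims
    # channels, leave the rest as position-free pass-through channels.
    b, t, h, d = x.shape
    x_rot, x_pass = x[..., :rot_dims], x[..., rot_dims:]
    half = rot_dims // 2
    freqs = 1.0 / base ** (torch.arange(half) / half)        # [half]
    angles = torch.arange(t)[:, None] * freqs[None, :]       # [t, half]
    cos = angles.cos()[None, :, None, :]                     # broadcast to heads
    sin = angles.sin()[None, :, None, :]
    x1, x2 = x_rot[..., :half], x_rot[..., half:]
    rotated = torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)
    return torch.cat([rotated, x_pass], dim=-1)
```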
XSA
Includes the XSA attention component in the base model.
parameters: {"last_n":4}
LayerNorm scale
Applies learnable LayerNorm scaling.
parameters: null
weight tying
Uses tied embeddings / tied output weights.
parameters: null
Optimizer
Parallel Muon
weight_decay: 0.04
momentum: 0.99
other_params: {"muon_momentum_warmup_start":0.92,"muon_momentum_warmup_steps":1500,"warmdown_iters":3000,"matrix_lr":0.025,"scalar_lr":0.025,"tied_embed_lr":0.035}
Weight Averaging
EMA
parameters: {"decay":0.997}
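The EMA update with the listed decay can be sketched as follows (the update rule is standard; the function name is illustrative):

```python
import torch

DECAY = 0.997  # from the EMA parameters above

@torch.no_grad()
def ema_update(ema_params, model_params, decay=DECAY):
    # ema <- decay * ema + (1 - decay) * current weights, in place.
    for e, p in zip(ema_params, model_params):
        e.mul_(decay).add_(p, alpha=1 - decay)
```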
Evaluation
sliding window eval
parameters: {"stride":64}
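A hedged sketch of what a stride-64 sliding-window evaluation typically computes: each window after the first re-feeds overlapping context but scores only its final `stride` tokens, so every scored token gets long left context. The window length here is illustrative, not from the PR:

```python
def sliding_windows(n_tokens, window=1024, stride=64):
    # Yields (start, score_from) pairs: `start` is the window's offset
    # into the token stream, `score_from` is the index within the window
    # where loss accounting begins. The first window scores all tokens.
    start = 0
    while start + window <= n_tokens:
        score_from = 0 if start == 0 else window - stride
        yield start, score_from
        start += stride
```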
Regularization
weight decay
parameters: {"muon_wd":0.04,"adam_wd":0.04}
LR Schedule
linear warmup + warmdown
parameters: {"muon_momentum_warmup_steps":1500,"warmdown_iters":3000}
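The two linear ramps can be sketched directly from the listed parameters; the momentum warmup endpoints come from the optimizer section above, and the total step count in the usage below is illustrative:

```python
def muon_momentum(step, start=0.92, end=0.99, warmup_steps=1500):
    # Linear warmup of Muon momentum: 0.92 -> 0.99 over 1500 steps,
    # then held constant.
    frac = min(step / warmup_steps, 1.0)
    return start + frac * (end - start)

def lr_scale(step, total_steps, warmdown_iters=3000):
    # Constant LR multiplier of 1.0, then linear decay to zero over
    # the final 3000 iterations.
    return min((total_steps - step) / warmdown_iters, 1.0)
```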
Other
other
Parameter banking enables batched Newton-Schulz orthogonalization and explicit asynchronous communication scheduling (reduce_scatter, all_reduce, all_gather) to restore compute-communication overlap.
parameters: {"optimizer_time_reduction_ms":{"before":19.7,"after":1.3}}
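A sketch of batched Newton-Schulz over a parameter bank via `torch.bmm`, using the quintic coefficients common in Muon implementations (the PR's exact iteration count, coefficients, and dtype are assumptions; Muon typically runs this in bfloat16, float32 is used here for clarity):

```python
import torch

def batched_newton_schulz(bank, steps=5, eps=1e-7):
    # Approximately orthogonalize every matrix in a [n, rows, cols] bank
    # at once: each iteration is three batched matmuls instead of 3*n
    # separate per-layer matmuls.
    a, b, c = 3.4445, -4.7750, 2.0315
    X = bank.float()
    # Normalize so the spectral norm is <= 1 (Frobenius >= spectral).
    X = X / (X.norm(dim=(1, 2), keepdim=True) + eps)
    transposed = X.size(1) > X.size(2)
    if transposed:  # tall banks (e.g. an mlp_up bank) are handled transposed
        X = X.transpose(1, 2)
    for _ in range(steps):
        A = torch.bmm(X, X.transpose(1, 2))
        B = b * A + c * torch.bmm(A, A)
        X = a * X + torch.bmm(B, X)
    if transposed:
        X = X.transpose(1, 2)
    return X.to(bank.dtype)
```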
Novel Contributions
- Parameter Banking: restructuring 66 linear weights into 4 contiguous parameter banks for batched optimizer operations
- Parallel Muon communication strategy adapted to work without DDP on banked parameters
- Batched Newton-Schulz orthogonalization over parameter banks using torch.bmm
- Explicit asynchronous communication schedule (reduce_scatter, all_reduce, all_gather) to restore overlap
- Architecture-agnostic systems optimization that improves training throughput without changing model architecture or hyperparameters