PR #1525 (open)

Non-Record 16 MB Track: Hybrid XSA-SSM

by Jash-Vora
val_bpb: 1.2552
Architecture: Hybrid
Optimizer: AdamW
Artifact Size: 14,915,938 bytes

Training Techniques

Architecture: Mamba
Replaced the middle transformer layer with a Mamba SSM block inserted as a residual bottleneck in a U-Net-style encoder-decoder stack.
parameters: {"layer":5,"d_model":512,"d_state":64,"d_conv":4,"expand":2}
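A minimal sketch of this swap, assuming a mamba_ssm-style Mamba module and the reported parameters; the class name, norm placement, and block-list wiring are illustrative, not the PR's actual code:

```python
import torch.nn as nn
from mamba_ssm import Mamba  # assumed dependency providing the SSM block


class MambaBottleneck(nn.Module):
    """Mamba SSM block with residual insertion, per the reported parameters."""

    def __init__(self, d_model=512, d_state=64, d_conv=4, expand=2):
        super().__init__()
        self.norm = nn.LayerNorm(d_model)
        self.mamba = Mamba(d_model=d_model, d_state=d_state,
                           d_conv=d_conv, expand=expand)

    def forward(self, x):
        # Residual insertion: the SSM output is added to its input, so the
        # block starts close to identity and preserves the U-Net skip path.
        return x + self.mamba(self.norm(x))


# Hypothetical wiring: replace the middle layer (index 5) of the block stack.
# model.blocks[5] = MambaBottleneck()
```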
Optimizer: AdamW
weight_decay: null
momentum: null
other_params: {"scalar_optimizer_for_mamba_params":true,"mamba_params_routed_to_adamw":["A_log","dt_bias","D"]}
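One way to implement the routing named in other_params, assuming the track's usual split of matrix weights to Muon and everything else to AdamW; the suffix matching and learning rate are assumptions:

```python
import torch

# Parameters named in other_params are vectors/scalars, which Muon's
# orthogonalized matrix update does not handle; route them to AdamW.
SCALAR_SUFFIXES = ("A_log", "dt_bias", "D")

adamw_params, muon_params = [], []
for name, p in model.named_parameters():  # model: the hybrid network above
    if any(name.endswith(s) for s in SCALAR_SUFFIXES) or p.ndim < 2:
        adamw_params.append(p)
    else:
        muon_params.append(p)

adamw = torch.optim.AdamW(adamw_params, lr=3e-4)  # lr is an assumption
# muon = Muon(muon_params, ...)  # Muon class from the track's codebase
```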
Other
Applied torch._dynamo.disable to the Mamba forward call to avoid TorchDynamo tracing issues with custom CUDA kernels.
parameters: null
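torch._dynamo.disable is a standard PyTorch escape hatch; wrapping the block's bound forward, as sketched below, is one plausible way to apply it (the access path is hypothetical):

```python
import torch

# Exclude only the Mamba block from TorchDynamo graph capture; its custom
# CUDA kernels run eagerly while the rest of the model is still compiled.
mamba_block = model.blocks[5]  # hypothetical access path to the Mamba layer
mamba_block.forward = torch._dynamo.disable(mamba_block.forward)

model = torch.compile(model)
```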

Novel Contributions

  • Hybrid XSA-SSM architecture
  • Middle-layer Mamba SSM bottleneck replacing a transformer attention block
  • Residual insertion of the Mamba block
  • Routing Mamba's internal scalar parameters (A_log, dt_bias, D) to AdamW instead of Muon
  • Disabling TorchDynamo tracing for the Mamba forward pass