PR #1524

closed

Non-Record: Hybrid XSA-SSM

by Jash-Vora
val_bpb: 1.2552
Architecture: Hybrid
Optimizer: AdamW
Artifact Size: 14,915,938 bytes

Training Techniques

Architecture
Mamba
Replaced the middle transformer layer with a Mamba SSM block as a bottleneck in a U-Net-style encoder-decoder stack.
parameters: {"layer":5,"d_model":512,"d_state":64,"d_conv":4,"expand":2}
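The layout described above can be sketched roughly as follows. This is a minimal illustration, not the PR's code: `NUM_LAYERS = 11` is an assumption (chosen only so that zero-based index 5 is the middle layer), the `build_layer_plan` and `skip_pairs` helpers are hypothetical, and the mirrored skip pairing is one common reading of "U-Net-style". Only the `MambaConfig` field values come from the parameters listed above.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class MambaConfig:
    # Values taken from the parameters listed in the summary.
    layer: int = 5
    d_model: int = 512
    d_state: int = 64
    d_conv: int = 4
    expand: int = 2

    @property
    def d_inner(self) -> int:
        # Mamba-style blocks widen the hidden size by `expand` internally.
        return self.expand * self.d_model


def build_layer_plan(num_layers: int, cfg: MambaConfig) -> list:
    """Return one block type per layer, with the SSM block at cfg.layer."""
    if not 0 <= cfg.layer < num_layers:
        raise ValueError("cfg.layer must be a valid layer index")
    return ["mamba" if i == cfg.layer else "attention" for i in range(num_layers)]


def skip_pairs(num_layers: int, bottleneck: int) -> list:
    """Pair each encoder layer with the decoder layer mirrored around the bottleneck."""
    pairs = []
    for enc in range(bottleneck):
        dec = 2 * bottleneck - enc
        if dec < num_layers:
            pairs.append((enc, dec))
    return pairs


NUM_LAYERS = 11  # ASSUMPTION: the summary does not state the total layer count
cfg = MambaConfig()
plan = build_layer_plan(NUM_LAYERS, cfg)
pairs = skip_pairs(NUM_LAYERS, cfg.layer)
```

With these assumed values, the plan has attention blocks on both sides of a single SSM block, and each encoder layer feeds a skip connection to its mirror image on the decoder side.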
Optimizer
AdamW
weight_decay: null
momentum: null
other_params: {"Mamba_internal_params_to_AdamW":["A_log","dt_bias","D"]}
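One plausible way to express this routing is a name-based split of the parameter list. The sketch below works on parameter names only, so it runs without PyTorch. The example names and the `split_by_optimizer` helper are hypothetical, and sending every other parameter to Muon is a simplification: the summary only states that `A_log`, `dt_bias`, and `D` go to AdamW.

```python
# Leaf names of the Mamba-internal parameters that the summary routes to AdamW.
ADAMW_LEAF_NAMES = frozenset({"A_log", "dt_bias", "D"})


def split_by_optimizer(param_names, adamw_leaf_names=ADAMW_LEAF_NAMES):
    """Group parameter names by the optimizer that should update them."""
    groups = {"adamw": [], "muon": []}
    for name in param_names:
        # Compare only the last dotted component, so "D" matches
        # "...mamba.D" but not unrelated names that merely contain a "D".
        leaf = name.rsplit(".", 1)[-1]
        target = "adamw" if leaf in adamw_leaf_names else "muon"
        groups[target].append(name)
    return groups


# Hypothetical parameter names, for illustration only.
example_names = [
    "blocks.4.attn.qkv.weight",
    "blocks.5.mamba.in_proj.weight",
    "blocks.5.mamba.A_log",
    "blocks.5.mamba.dt_bias",
    "blocks.5.mamba.D",
]
groups = split_by_optimizer(example_names)
```

In a real training script the same test would typically be applied to the names yielded by `model.named_parameters()`, with the two resulting parameter lists passed to separate optimizer instances.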
Other
other
Applied torch._dynamo.disable to the Mamba forward call to avoid TorchDynamo tracing issues with custom CUDA kernels.
parameters: null
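A minimal sketch of this pattern, assuming a recent PyTorch 2.x install. The `BottleneckWrapper` class is hypothetical, and an `nn.Linear` stands in for the Mamba block so the example runs without the custom CUDA kernels. How the PR itself applies the decorator is not shown in this summary.

```python
import torch
import torch._dynamo
import torch.nn as nn


class BottleneckWrapper(nn.Module):
    """Hypothetical wrapper; `inner` stands in for the Mamba SSM block."""

    def __init__(self, inner: nn.Module):
        super().__init__()
        self.inner = inner

    @torch._dynamo.disable
    def _run_inner(self, x: torch.Tensor) -> torch.Tensor:
        # TorchDynamo skips this call and runs it eagerly, so any custom
        # kernels inside `inner` are never traced.
        return self.inner(x)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Residual connection around the untraced call.
        return x + self._run_inner(x)


block = BottleneckWrapper(nn.Linear(512, 512))  # stand-in for the SSM block
x = torch.randn(2, 8, 512)
y = block(x)
```

When the surrounding model is wrapped in `torch.compile`, the decorated call becomes a graph break: the layers before and after it can still be compiled, while the wrapped call runs in eager mode.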

Novel Contributions

  • Hybrid XSA-SSM architecture
  • Inserted a Mamba SSM block at the midpoint bottleneck of the model
  • Routed Mamba internal parameters to AdamW instead of Muon
  • Disabled TorchDynamo tracing for the Mamba forward pass