
Non-record: Mini-Hymba hybrid attention+SSM heads (arXiv:2411.13676) #1961

Open
aparna-1407 wants to merge 5 commits into openai:main from aparna-1407:main
Conversation


@aparna-1407 aparna-1407 commented Apr 30, 2026

Mini-Hymba: Hybrid Attention + SSM Architecture

Track: Non-record
Base: PR #1493 (SP8192 + 3-layer recurrence + parallel residuals + QK-Gain + legal TTT)
Status: Complete — val_bpb will be updated when full run completes


What this does

Replaces one CausalSelfAttention module, layer 4, with a Mini-Hymba hybrid block that runs transformer attention heads and Mamba-lite SSM heads in parallel on the same input, then concatenates their outputs.

The motivation comes from Hymba: attention heads provide high-resolution token recall, while SSM heads provide efficient context summarization through a recurrent state. In this miniature version, only one layer is hybridized because the 1-layer ablation trained faster and reached much better BPB than the earlier 3-layer variant.
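
For illustration, here is a minimal sketch of the hybrid forward pass — attention heads and SSM heads running in parallel on the same input, with their outputs concatenated. The module name and the plain per-channel recurrence are illustrative assumptions, not the actual hymba_layer.py code:

```python
# Sketch only: names and the simple per-channel SSM recurrence are illustrative assumptions,
# not the actual hymba_layer.py implementation (which reuses the host's CastedLinear/RoPE).
import torch
import torch.nn as nn
import torch.nn.functional as F

class MiniHymbaBlock(nn.Module):
    """Hybrid block: attention heads and SSM heads run in parallel on the same input,
    and their outputs are concatenated before the output projection."""
    def __init__(self, d_model=512, n_attn_heads=4, n_ssm_heads=4):
        super().__init__()
        self.h = n_attn_heads
        self.hd = d_model // (n_attn_heads + n_ssm_heads)            # per-head width
        attn_dim, ssm_dim = self.h * self.hd, n_ssm_heads * self.hd
        self.qkv = nn.Linear(d_model, 3 * attn_dim, bias=False)      # attention branch
        self.ssm_in = nn.Linear(d_model, ssm_dim, bias=False)        # SSM branch
        self.decay = nn.Parameter(torch.full((ssm_dim,), -2.0))      # per-channel log-decay
        self.out = nn.Linear(attn_dim + ssm_dim, d_model, bias=False)

    def forward(self, x):                                            # x: (B, T, d_model)
        B, T, _ = x.shape
        # attention heads (causal): high-resolution token recall
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        q, k, v = (t.view(B, T, self.h, self.hd).transpose(1, 2) for t in (q, k, v))
        attn = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        attn = attn.transpose(1, 2).reshape(B, T, -1)
        # SSM heads: per-channel linear recurrence h_t = a * h_{t-1} + u_t (context summary)
        u = self.ssm_in(x)
        a = torch.sigmoid(self.decay)                                # decay in (0, 1)
        state = torch.zeros(B, u.size(-1), device=x.device, dtype=x.dtype)
        states = []
        for t in range(T):                                           # sequential scan, for clarity only
            state = a * state + u[:, t]
            states.append(state)
        ssm = torch.stack(states, dim=1)
        # concatenate both branches and project back to the model width
        return self.out(torch.cat([attn, ssm], dim=-1))
```

The real block reuses the host's RoPE (apply_rotary_emb) and the chunk-parallel scan described under Implementation; the sequential loop above is written out only for readability.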

Architecture changes

  • 1 Mini-Hymba layer at layer 4
  • 4 attention heads + 4 SSM heads per hybrid layer (50/50 split guarantees GQA compatibility)
  • Learnable meta tokens (4 per layer) prepended before attention — they store global context and
    reduce the forced-to-attend burden (Hymba paper §3.2); see the sketch after this list
  • Cross-layer KV sharing — layers 4 and 5 reuse layer 3's K/V projections, eliminating
    redundant parameters from the artifact entirely (no dead weights in state_dict)
  • Signed chunk-parallel fp32 SSM scan with HYMBA_SCAN_CHUNK=64
  • Everything else unchanged — TTT, GPTQ int8+zlib, BPB evaluation, artifact compression
    all run exactly as in the host script
  • Host RoPE, CastedLinear, QK-Gain, scorer, quantization, and artifact export paths unchanged
  • No eval-time n-gram or PPM cache
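
To make the meta-token and KV-sharing bullets concrete, here is a hedged sketch; the class and argument names (MetaTokenAttention, kv_source) are invented for illustration, and plain nn.Linear stands in for the host's CastedLinear, with RoPE omitted for brevity:

```python
# Sketch only: names are invented; not the actual hymba_layer.py code.
import torch
import torch.nn as nn
import torch.nn.functional as F

class MetaTokenAttention(nn.Module):
    """Attention with learnable meta tokens prepended; the K/V projection can be
    borrowed from an earlier layer instead of creating its own weights."""
    def __init__(self, d_model=512, n_heads=4, n_meta=4, kv_source=None):
        super().__init__()
        self.nh, self.hd, self.n_meta = n_heads, d_model // n_heads, n_meta
        self.meta = nn.Parameter(0.02 * torch.randn(1, n_meta, d_model))      # learnable meta tokens
        self.q_proj = nn.Linear(d_model, d_model, bias=False)
        # cross-layer KV sharing: point at another layer's projection instead of allocating new weights
        self.kv_proj = kv_source.kv_proj if kv_source is not None else nn.Linear(d_model, 2 * d_model, bias=False)
        self.o_proj = nn.Linear(d_model, d_model, bias=False)

    def forward(self, x):                                            # x: (B, T, d_model)
        B, T, D = x.shape
        x_ext = torch.cat([self.meta.expand(B, -1, -1), x], dim=1)   # prepend meta tokens
        q = self.q_proj(x_ext)
        k, v = self.kv_proj(x_ext).chunk(2, dim=-1)
        def split(t):
            return t.view(B, T + self.n_meta, self.nh, self.hd).transpose(1, 2)
        y = F.scaled_dot_product_attention(split(q), split(k), split(v), is_causal=True)
        y = y.transpose(1, 2).reshape(B, T + self.n_meta, D)
        return self.o_proj(y[:, self.n_meta:])                       # drop the meta positions

layer3 = MetaTokenAttention()
layer4 = MetaTokenAttention(kv_source=layer3)   # layers 4 and 5 reuse layer 3's K/V projection
layer5 = MetaTokenAttention(kv_source=layer3)
```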

Implementation

hymba_layer.py is a self-contained drop-in patch. It reuses CastedLinear, Rotary, and apply_rotary_emb from the host train_gpt.py script, so it avoids reimplementing RoPE and leaves the compression/scoring path untouched.
The original sequential SSM scan was stable but slow. The current implementation uses a chunk-parallel scan that reduces Python loop overhead from 1024 token steps to 16 chunk steps for sequence length 1024.
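
Below is a minimal sketch of the chunking idea for a simple per-channel recurrence h_t = a·h_{t-1} + u_t, assuming that form; the actual scan in hymba_layer.py may differ, but the structure is as described above: the Python loop runs over chunks of HYMBA_SCAN_CHUNK tokens, applying in-chunk inputs in parallel and carrying the state between chunks.

```python
# Sketch of a chunk-parallel fp32 scan for h_t = a * h_{t-1} + u_t (per-channel decay a).
# Illustrative only; the real hymba_layer.py scan may differ in detail.
import torch

def chunked_scan(u: torch.Tensor, a: torch.Tensor, chunk: int = 64) -> torch.Tensor:
    """u: (B, T, D) inputs, a: (D,) decays in (0, 1). Returns states h: (B, T, D) in fp32."""
    B, T, D = u.shape
    u, a = u.float(), a.float()                                   # fp32 scan
    idx = torch.arange(chunk, device=u.device)
    # decay[i, j] = a^(i-j) for j <= i, else 0: applies all in-chunk inputs in parallel
    expo = (idx[:, None] - idx[None, :]).clamp(min=0).float()
    decay = (a ** expo[:, :, None]) * (idx[:, None] >= idx[None, :]).float()[:, :, None]
    carry_pow = a ** (idx[:, None] + 1).float()                   # a^(i+1): decay of the carried state
    h_carry = u.new_zeros(B, D)
    out = []
    for s in range(0, T, chunk):                                  # 1024 tokens / 64 -> 16 Python steps
        u_c = u[:, s:s + chunk]
        c = u_c.size(1)
        h_c = torch.einsum('ijd,bjd->bid', decay[:c, :c], u_c) + carry_pow[None, :c] * h_carry[:, None]
        out.append(h_c)
        h_carry = h_c[:, -1]                                      # carry state into the next chunk
    return torch.cat(out, dim=1)
```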

Results

  • Training steps: 800

  • Training time: 877.4 seconds

  • Hardware: 1x NVIDIA RTX PRO 6000 Blackwell

  • Final unquantized val_loss: 2.4501

  • Final unquantized val_bpb: 1.4511

  • Final int8+zlib roundtrip val_loss: 2.49096816

  • Final int8+zlib roundtrip val_bpb: 1.47529166

  • Total int8+zlib submission size: 9,234,838 bytes

  • Peak allocated memory: 14,145 MiB

  • Artifact size:

    • Compressed int8+zlib model size: 9,186,139 bytes
    • Total compressed submission size: 9,234,838 bytes

Validation trajectory
step 0: val_bpb 4.1077
step 200: val_bpb 1.9420
step 400: val_bpb 1.6016
step 600: val_bpb 1.5044
step 800: val_bpb 1.4511
roundtrip: val_bpb 1.4753

Notes
The earlier 3-layer version trained stably but was slower and reached worse BPB for the same wall-clock budget. The 1-layer chunk-parallel version is the best current configuration. 3-layer results:

  • model_params: 16,984,514 — fits within 16MB artifact budget
  • Total submission size int8+zlib: 6,838,905 bytes — 6.8MB, well under limit (fewer iterations (only 200) and the 3-layer config used KV sharing, so layers 4 and 5 omitted/reused K/V projections. That reduced the raw quantized payload a lot. The 1-layer config has no sharing opportunity, so it keeps the normal layer’s K/V projections and ends up with a larger compressed payload)
  • 1-step sanity: val_bpb 3.81 post-quantization roundtrip (untrained baseline)
  • Architecture trains stably — loss decreases smoothly, no NaNs

| Config | Scan | Steps | Step time | Roundtrip val_bpb | Result log file name |
| --- | --- | --- | --- | --- | --- |
| 3 layers: 3, 4, 5 | sequential | 134 | ~6.7 s | 3.0721 | train_log_3layer_serial_200.txt |
| 3 layers: 3, 4, 5 | chunked | 20 | ~1.30 s | 3.2394 | train_log_3layer_chunked_20.txt |
| 1 layer: 4 | chunked | 800 | ~1.10 s | 1.4753 | train_log.txt |

Layer-position ablation at 800 steps:

| Hymba layer | Unquantized val_bpb | Roundtrip val_bpb |
| --- | --- | --- |
| 3 | 1.4964 | 1.5638 |
| 4 | 1.4515 | 1.4753 |
| 5 | 1.4863 | 1.5600 |

Layer 4 produced both the best validation BPB and the smallest quantization penalty, so all longer runs use HYMBA_LAYERS=4.
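
For reference, a hypothetical way the two knobs could be read, assuming they are plain environment variables (the actual parsing in hymba_layer.py may differ):

```python
# Hypothetical: assumes the knobs are read as environment variables.
import os

HYMBA_LAYERS = [int(i) for i in os.environ.get("HYMBA_LAYERS", "4").split(",")]   # e.g. "3,4,5"
HYMBA_SCAN_CHUNK = int(os.environ.get("HYMBA_SCAN_CHUNK", "64"))                  # chunk length for the SSM scan
```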

References

  1. Hymba (arXiv:2411.13676, ICLR 2025) — Dong, Fu, ..., Yingyan Lin, Kautz, Molchanov
  2. CPT (arXiv:2101.09868, ICLR 2021) — Yingyan Lin et al., Georgia Tech EIC Lab

@aparna-1407 aparna-1407 marked this pull request as draft April 30, 2026 05:36
@aparna-1407 aparna-1407 marked this pull request as ready for review April 30, 2026 06:40
- Integrate configurable Hymba layer with chunked SSM scan
- Document 1-layer 800-step result and artifact metrics
- Update submission metadata and training log for non-record 16MB run