
Commit f71e096

Update and rename test_simple_gla_for_mamba2.py to test_simple_gla.py
1 parent 11c7f66 commit f71e096

File tree

2 files changed: +71 -105 lines changed


tests/ops/test_simple_gla.py

Lines changed: 71 additions & 0 deletions
```python
# -*- coding: utf-8 -*-

import pytest
import torch

from fla.ops.simple_gla import chunk_simple_gla


@pytest.mark.parametrize("vary_A", [True, False])
@pytest.mark.parametrize("dtype", [torch.float, torch.bfloat16])
def test_simple_gla_to_mamba2(vary_A, dtype):
    r"""
    Map Mamba-2's `mamba_chunk_scan_combined` kernel to FLA's `simple_gla` kernel.

    Dependencies:
    $ pip install mamba-ssm==2.2.2 triton==2.3.1

    Reference: `ssd_minimal_discrete` and `test_correctness` in the mamba repository:
    https://github.com/state-spaces/mamba/blob/v2.2.2/mamba_ssm/modules/ssd_minimal.py#L82
    """
    from mamba_ssm.modules.ssd_minimal import ssd_minimal_discrete
    from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined
    torch.manual_seed(42)

    # dimensions, denoted (B, T, Q, D, P) in the Mamba-2 paper
    batch, seq_len, chunk_size, dim, headdim = 2, 512, 8, 64, 16
    n_heads = dim // headdim  # (H) in the paper
    ngroups = n_heads  # (G) in the paper; NOTE: do not use group-query here
    dstate = 64  # (N) in the paper
    device = "cuda"
    atol = 5e-4 if dtype == torch.float else 1e-2

    x = 0.1 * torch.randn(batch, seq_len, n_heads, headdim, dtype=dtype, device=device)
    dt = torch.ones(batch, seq_len, n_heads, dtype=dtype, device=device)  # dt=1, so it can be ignored

    if vary_A:
        A = -0.1 * torch.rand(1, seq_len, n_heads, dtype=dtype, device=device)
    else:  # constant A for all positions
        A = -0.1 * torch.rand(n_heads, dtype=dtype, device=device)

    B = 0.1 * torch.randn(batch, seq_len, ngroups, dstate, dtype=dtype, device=device)
    C = 0.1 * torch.randn(batch, seq_len, ngroups, dstate, dtype=dtype, device=device)

    y_ssd, final_ssd = ssd_minimal_discrete(x * dt.unsqueeze(-1), A * dt, B, C, chunk_size)

    if not vary_A:
        # NOTE: the fused kernel does not support A varying over time
        y_fuse, final_fuse = mamba_chunk_scan_combined(x, dt, A, B, C, chunk_size, D=None, return_final_states=True)
        assert y_ssd.allclose(y_fuse, 0, atol), f"y diff: {torch.abs(y_ssd - y_fuse).max()}"
        # the fused kernel upcasts the state to float32
        # https://github.com/state-spaces/mamba/blob/v2.2.2/mamba_ssm/ops/triton/ssd_combined.py#L650
        final_fuse = final_fuse.to(dtype)
        assert final_ssd.allclose(final_fuse, 0, atol), f"final diff: {torch.abs(final_ssd - final_fuse).max()}"

    # mapping inputs Mamba2 -> FLA
    # C, B, X: [batch, seq, head, hidden] -> [batch, head, seq, hidden]
    # g: [batch, seq, head] -> [batch, head, seq]
    q = C.transpose(1, 2)
    k = B.transpose(1, 2)
    v = x.transpose(1, 2)
    g = (A * dt).transpose(1, 2)

    # mapping outputs Mamba2 -> FLA
    y_rearrange = y_ssd.transpose(1, 2)
    final_rearrange = final_ssd.transpose(2, 3)

    # comparing outputs between the FLA kernel and the Mamba2 kernel
    outputs_gla_fuse, final_gla_fuse = chunk_simple_gla(q, k, v, g, scale=1.0, output_final_state=True)
    assert y_rearrange.allclose(outputs_gla_fuse, 0, atol), f"y diff: {torch.abs(y_rearrange - outputs_gla_fuse).max()}"
    final_gla_fuse = final_gla_fuse.to(dtype)  # states are hard-coded to float32 in the FLA kernel
    assert final_rearrange.allclose(final_gla_fuse, 0, atol), f"final diff: {torch.abs(final_rearrange - final_gla_fuse).max()}"
```
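Both reference paths in this test compute the same per-head, scalar-decay linear recurrence; the mapping q=C, k=B, v=x (with dt=1), g=A*dt is only a layout change. As a reading aid, not part of this commit, here is a naive PyTorch restatement of that recurrence, assuming the usual simple-GLA semantics where `g` holds log-space decays applied to the running state (the function name `naive_simple_gla` is invented here for illustration):

```python
import torch


def naive_simple_gla(q, k, v, g, scale=1.0):
    """Naive O(T) recurrence that the chunked kernels are expected to match.

    Shapes (FLA layout, as produced by the mapping in the test above):
      q, k: [batch, heads, seq, dstate]   -- C and B in Mamba-2 notation
      v:    [batch, heads, seq, headdim]  -- x (times dt)
      g:    [batch, heads, seq]           -- A * dt, log-space scalar decay per head and step
    """
    batch, heads, seq, dstate = q.shape
    headdim = v.shape[-1]
    # running state S_t, kept in float32 as the fused kernels do
    S = q.new_zeros(batch, heads, dstate, headdim, dtype=torch.float32)
    o = torch.empty(batch, heads, seq, headdim, dtype=torch.float32, device=q.device)
    for t in range(seq):
        decay = g[:, :, t].float().exp()[..., None, None]                    # [batch, heads, 1, 1]
        rank1 = k[:, :, t, :, None].float() * v[:, :, t, None, :].float()    # k_t v_t^T
        S = decay * S + rank1                                                # S_t = exp(g_t) S_{t-1} + k_t v_t^T
        o[:, :, t] = torch.einsum('bhn,bhnp->bhp', q[:, :, t].float(), S) * scale  # o_t = q_t^T S_t
    return o.to(v.dtype), S
```

On the tensors built in the test, `naive_simple_gla(q, k, v, g)` should agree with both `ssd_minimal_discrete` and `chunk_simple_gla` within the same tolerances, with small differences expected from the chunked kernels' floating-point accumulation order.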

tests/ops/test_simple_gla_for_mamba2.py

Lines changed: 0 additions & 105 deletions
This file was deleted.
