Skip to content

Commit dcf0caf

Browse files
committed
unittest for ms_deformable_cross_attn
1 parent 7ee9309 commit dcf0caf

File tree

1 file changed

+52
-0
lines changed

1 file changed

+52
-0
lines changed

tests/test_ops/test_ops.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1223,3 +1223,55 @@ def test_multiclass_nms_rotated_with_keep_top_k(backend, pre_top_k):
12231223
'multiclass_nms_rotated returned more values than "keep_top_k"\n' \
12241224
f'dets.shape: {dets.shape}\n' \
12251225
f'keep_top_k: {keep_top_k}'
1226+
1227+
1228+
@pytest.mark.parametrize('backend', [TEST_TENSORRT])
1229+
def test_multi_scale_deformable_attn(backend, save_dir=None):
1230+
backend.check_env()
1231+
from mmcv.ops.multi_scale_deform_attn import \
1232+
MultiScaleDeformableAttnFunction
1233+
1234+
Bs = 2
1235+
Nh = 8
1236+
Nc = 32
1237+
Nq = 100
1238+
Np = 200
1239+
spatial_shapes = [[68, 120], [34, 60]]
1240+
Nl = len(spatial_shapes)
1241+
Nk = sum([spatial_shapes[i][0] * spatial_shapes[i][1] for i in range(Nl)])
1242+
value = torch.rand(Bs, Nk, Nh, Nc).cuda()
1243+
value_spatial_shapes = torch.LongTensor(spatial_shapes).cuda()
1244+
level_start_index = torch.LongTensor(
1245+
[0, spatial_shapes[0][0] * spatial_shapes[0][1]]).cuda()
1246+
sampling_locations = torch.rand(Bs, Nq, Nh, Nl, Np, 2).cuda()
1247+
attention_weights = torch.rand(Bs, Nq, Nh, Nl, Np).cuda()
1248+
1249+
class TestModel(torch.nn.Module):
1250+
1251+
def __init__(self) -> None:
1252+
super().__init__()
1253+
self.im2col_step = 64
1254+
1255+
def forward(self, value, value_spatial_shapes, level_start_index,
1256+
sampling_locations, attention_weights):
1257+
1258+
new_x = MultiScaleDeformableAttnFunction.apply(
1259+
value, value_spatial_shapes, level_start_index,
1260+
sampling_locations, attention_weights, self.im2col_step)
1261+
return new_x
1262+
1263+
model = TestModel().cuda()
1264+
1265+
with RewriterContext(cfg={}, backend=backend.backend_name, opset=11):
1266+
backend.run_and_validate(
1267+
model, [
1268+
value, value_spatial_shapes, level_start_index,
1269+
sampling_locations, attention_weights
1270+
],
1271+
'multi_scale_deformable_attn',
1272+
input_names=[
1273+
'value', 'value_spatial_shapes', 'level_start_index',
1274+
'sampling_locations', 'attention_weights'
1275+
],
1276+
output_names=['output'],
1277+
save_dir=save_dir)

0 commit comments

Comments
 (0)