@@ -1223,3 +1223,55 @@ def test_multiclass_nms_rotated_with_keep_top_k(backend, pre_top_k):
1223
1223
'multiclass_nms_rotated returned more values than "keep_top_k"\n ' \
1224
1224
f'dets.shape: { dets .shape } \n ' \
1225
1225
f'keep_top_k: { keep_top_k } '
1226
+
1227
+
1228
+ @pytest .mark .parametrize ('backend' , [TEST_TENSORRT ])
1229
+ def test_multi_scale_deformable_attn (backend , save_dir = None ):
1230
+ backend .check_env ()
1231
+ from mmcv .ops .multi_scale_deform_attn import \
1232
+ MultiScaleDeformableAttnFunction
1233
+
1234
+ Bs = 2
1235
+ Nh = 8
1236
+ Nc = 32
1237
+ Nq = 100
1238
+ Np = 200
1239
+ spatial_shapes = [[68 , 120 ], [34 , 60 ]]
1240
+ Nl = len (spatial_shapes )
1241
+ Nk = sum ([spatial_shapes [i ][0 ] * spatial_shapes [i ][1 ] for i in range (Nl )])
1242
+ value = torch .rand (Bs , Nk , Nh , Nc ).cuda ()
1243
+ value_spatial_shapes = torch .LongTensor (spatial_shapes ).cuda ()
1244
+ level_start_index = torch .LongTensor (
1245
+ [0 , spatial_shapes [0 ][0 ] * spatial_shapes [0 ][1 ]]).cuda ()
1246
+ sampling_locations = torch .rand (Bs , Nq , Nh , Nl , Np , 2 ).cuda ()
1247
+ attention_weights = torch .rand (Bs , Nq , Nh , Nl , Np ).cuda ()
1248
+
1249
+ class TestModel (torch .nn .Module ):
1250
+
1251
+ def __init__ (self ) -> None :
1252
+ super ().__init__ ()
1253
+ self .im2col_step = 64
1254
+
1255
+ def forward (self , value , value_spatial_shapes , level_start_index ,
1256
+ sampling_locations , attention_weights ):
1257
+
1258
+ new_x = MultiScaleDeformableAttnFunction .apply (
1259
+ value , value_spatial_shapes , level_start_index ,
1260
+ sampling_locations , attention_weights , self .im2col_step )
1261
+ return new_x
1262
+
1263
+ model = TestModel ().cuda ()
1264
+
1265
+ with RewriterContext (cfg = {}, backend = backend .backend_name , opset = 11 ):
1266
+ backend .run_and_validate (
1267
+ model , [
1268
+ value , value_spatial_shapes , level_start_index ,
1269
+ sampling_locations , attention_weights
1270
+ ],
1271
+ 'multi_scale_deformable_attn' ,
1272
+ input_names = [
1273
+ 'value' , 'value_spatial_shapes' , 'level_start_index' ,
1274
+ 'sampling_locations' , 'attention_weights'
1275
+ ],
1276
+ output_names = ['output' ],
1277
+ save_dir = save_dir )
0 commit comments