Skip to content

Commit fbac65e

Browse files
committed
add condinst head unit testing
1 parent d580a59 commit fbac65e

File tree

1 file changed

+252
-0
lines changed

1 file changed

+252
-0
lines changed

tests/test_codebase/test_mmdet/test_mmdet_models.py

Lines changed: 252 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2364,3 +2364,255 @@ def test_solov2_head_predict_by_feat(backend_type):
23642364
atol=1e-05)
23652365
else:
23662366
assert rewrite_outputs is not None
2367+
2368+
2369+
def get_condinst_bbox_head():
2370+
"""condinst Bbox Head Config."""
2371+
test_cfg = Config(
2372+
dict(
2373+
mask_thr=0.5,
2374+
max_per_img=100,
2375+
min_bbox_size=0,
2376+
nms=dict(iou_threshold=0.6, type='nms'),
2377+
nms_pre=1000,
2378+
score_thr=0.05))
2379+
from mmdet.models.dense_heads import CondInstBboxHead
2380+
model = CondInstBboxHead(
2381+
center_sampling=True,
2382+
centerness_on_reg=True,
2383+
conv_bias=True,
2384+
dcn_on_last_conv=False,
2385+
feat_channels=256,
2386+
in_channels=256,
2387+
loss_bbox=dict(loss_weight=1.0, type='GIoULoss'),
2388+
loss_centerness=dict(
2389+
loss_weight=1.0, type='CrossEntropyLoss', use_sigmoid=True),
2390+
loss_cls=dict(
2391+
alpha=0.25,
2392+
gamma=2.0,
2393+
loss_weight=1.0,
2394+
type='FocalLoss',
2395+
use_sigmoid=True),
2396+
norm_on_bbox=True,
2397+
num_classes=80,
2398+
num_params=169,
2399+
stacked_convs=4,
2400+
strides=[
2401+
8,
2402+
16,
2403+
32,
2404+
64,
2405+
128,
2406+
],
2407+
test_cfg=test_cfg,
2408+
)
2409+
2410+
model.requires_grad_(False)
2411+
return model
2412+
2413+
2414+
@pytest.mark.skipif(
2415+
reason='Only support GPU test', condition=not torch.cuda.is_available())
2416+
@pytest.mark.parametrize('backend_type',
2417+
[Backend.ONNXRUNTIME, Backend.TENSORRT])
2418+
def test_condinst_bbox_head_predict_by_feat(backend_type):
2419+
"""Test predict_by_feat rewrite of condinst bbox head."""
2420+
check_backend(backend_type)
2421+
condinst_bbox_head = get_condinst_bbox_head()
2422+
condinst_bbox_head.cpu().eval()
2423+
s = 128
2424+
batch_img_metas = [{
2425+
'scale_factor': np.ones(4),
2426+
'pad_shape': (s, s, 3),
2427+
'img_shape': (s, s, 3)
2428+
}]
2429+
2430+
output_names = ['dets', 'labels', 'param_preds', 'points', 'strides']
2431+
deploy_cfg = Config(
2432+
dict(
2433+
backend_config=dict(
2434+
type=backend_type.value,
2435+
common_config=dict(max_workspace_size=1 << 32),
2436+
model_inputs=[
2437+
dict(
2438+
input_shapes=dict(
2439+
input=dict(
2440+
min_shape=[1, 3, 320, 320],
2441+
opt_shape=[1, 3, 800, 1344],
2442+
max_shape=[1, 3, 1344, 1344])))
2443+
]),
2444+
onnx_config=dict(output_names=output_names, input_shape=None),
2445+
codebase_config=dict(
2446+
type='mmdet',
2447+
task='ObjectDetection',
2448+
post_processing=dict(
2449+
score_threshold=0.05,
2450+
confidence_threshold=0.005,
2451+
iou_threshold=0.5,
2452+
max_output_boxes_per_class=200,
2453+
pre_top_k=5000,
2454+
keep_top_k=100,
2455+
background_label_id=-1,
2456+
export_postprocess_mask=False))))
2457+
2458+
seed_everything(1234)
2459+
cls_scores = [
2460+
torch.rand(1, condinst_bbox_head.num_classes, pow(2, i), pow(2, i))
2461+
for i in range(5, 0, -1)
2462+
]
2463+
seed_everything(5678)
2464+
bbox_preds = [
2465+
torch.rand(1, 4, pow(2, i), pow(2, i)) for i in range(5, 0, -1)
2466+
]
2467+
seed_everything(9101)
2468+
score_factors = [
2469+
torch.rand(1, 1, pow(2, i), pow(2, i)) for i in range(5, 0, -1)
2470+
]
2471+
seed_everything(1121)
2472+
param_preds = [
2473+
torch.rand(1, condinst_bbox_head.num_params, pow(2, i), pow(2, i))
2474+
for i in range(5, 0, -1)
2475+
]
2476+
2477+
# to get outputs of onnx/tensorrt model after rewrite
2478+
wrapped_model = WrapModel(
2479+
condinst_bbox_head, 'predict_by_feat', batch_img_metas=batch_img_metas)
2480+
rewrite_inputs = {
2481+
'cls_scores': cls_scores,
2482+
'bbox_preds': bbox_preds,
2483+
'score_factors': score_factors,
2484+
'param_preds': param_preds,
2485+
}
2486+
rewrite_outputs, is_backend_output = get_rewrite_outputs(
2487+
wrapped_model=wrapped_model,
2488+
model_inputs=rewrite_inputs,
2489+
deploy_cfg=deploy_cfg)
2490+
2491+
if is_backend_output:
2492+
dets = rewrite_outputs[0]
2493+
labels = rewrite_outputs[1]
2494+
param_preds = rewrite_outputs[2]
2495+
points = rewrite_outputs[3]
2496+
strides = rewrite_outputs[4]
2497+
assert dets.shape[-1] == 5
2498+
assert labels is not None
2499+
assert param_preds.shape[-1] == condinst_bbox_head.num_params
2500+
assert points.shape[-1] == 2
2501+
assert strides is not None
2502+
else:
2503+
assert rewrite_outputs is not None
2504+
2505+
2506+
def get_condinst_mask_head():
2507+
"""condinst Mask Head Config."""
2508+
test_cfg = Config(
2509+
dict(
2510+
mask_thr=0.5,
2511+
max_per_img=100,
2512+
min_bbox_size=0,
2513+
nms=dict(iou_threshold=0.6, type='nms'),
2514+
nms_pre=1000,
2515+
score_thr=0.05))
2516+
from mmdet.models.dense_heads import CondInstMaskHead
2517+
model = CondInstMaskHead(
2518+
mask_feature_head=dict(
2519+
end_level=2,
2520+
feat_channels=128,
2521+
in_channels=256,
2522+
mask_stride=8,
2523+
norm_cfg=dict(requires_grad=True, type='BN'),
2524+
num_stacked_convs=4,
2525+
out_channels=8,
2526+
start_level=0),
2527+
num_layers=3,
2528+
feat_channels=8,
2529+
mask_out_stride=4,
2530+
size_of_interest=8,
2531+
max_masks_to_train=300,
2532+
loss_mask=dict(
2533+
activate=True,
2534+
eps=5e-06,
2535+
loss_weight=1.0,
2536+
type='DiceLoss',
2537+
use_sigmoid=True),
2538+
test_cfg=test_cfg,
2539+
)
2540+
2541+
model.requires_grad_(False)
2542+
return model
2543+
2544+
2545+
@pytest.mark.skipif(
2546+
reason='Only support GPU test', condition=not torch.cuda.is_available())
2547+
@pytest.mark.parametrize('backend_type',
2548+
[Backend.ONNXRUNTIME, Backend.TENSORRT])
2549+
def test_condinst_mask_head_predict_by_feat(backend_type):
2550+
"""Test predict_by_feat rewrite of condinst mask head."""
2551+
check_backend(backend_type)
2552+
s = 128
2553+
batch_img_metas = [{
2554+
'scale_factor': np.ones(4),
2555+
'pad_shape': (s, s, 3),
2556+
'img_shape': (s, s, 3)
2557+
}]
2558+
2559+
output_names = ['dets', 'labels', 'masks']
2560+
deploy_cfg = Config(
2561+
dict(
2562+
backend_config=dict(
2563+
type=backend_type.value,
2564+
common_config=dict(max_workspace_size=1 << 32),
2565+
model_inputs=[
2566+
dict(
2567+
input_shapes=dict(
2568+
input=dict(
2569+
min_shape=[1, 3, 320, 320],
2570+
opt_shape=[1, 3, 800, 1344],
2571+
max_shape=[1, 3, 1344, 1344])))
2572+
]),
2573+
onnx_config=dict(output_names=output_names, input_shape=None),
2574+
codebase_config=dict(
2575+
type='mmdet',
2576+
task='ObjectDetection')))
2577+
2578+
class TestCondInstMaskHeadModel(torch.nn.Module):
2579+
def __init__(self, condinst_mask_head):
2580+
super(TestCondInstMaskHeadModel, self).__init__()
2581+
self.mask_head = condinst_mask_head
2582+
2583+
def predict_by_feat(self, mask_preds, det, label, batch_img_metas):
2584+
results = dict(dets=det, labels=label)
2585+
return self.mask_head.predict_by_feat(mask_preds, results, batch_img_metas)
2586+
2587+
head = get_condinst_mask_head()
2588+
condinst_mask_head = TestCondInstMaskHeadModel(head)
2589+
condinst_mask_head.cpu().eval()
2590+
2591+
seed_everything(1234)
2592+
mask_preds = torch.rand(1, 100, 200, 200)
2593+
seed_everything(5678)
2594+
dets = torch.rand(1, 100, 5)
2595+
labels = torch.rand(1, 100)
2596+
2597+
# to get outputs of onnx/tensorrt model after rewrite
2598+
wrapped_model = WrapModel(
2599+
condinst_mask_head, 'predict_by_feat', batch_img_metas=batch_img_metas)
2600+
rewrite_inputs = {
2601+
'mask_preds': mask_preds,
2602+
'det': dets,
2603+
'label': labels
2604+
}
2605+
rewrite_outputs, is_backend_output = get_rewrite_outputs(
2606+
wrapped_model=wrapped_model,
2607+
model_inputs=rewrite_inputs,
2608+
deploy_cfg=deploy_cfg)
2609+
2610+
if is_backend_output:
2611+
dets = rewrite_outputs[0]
2612+
labels = rewrite_outputs[1]
2613+
masks = rewrite_outputs[2]
2614+
assert dets.shape[-1] == 5
2615+
assert labels is not None
2616+
assert masks is not None
2617+
else:
2618+
assert rewrite_outputs is not None

0 commit comments

Comments
 (0)