|
19 | 19 | codebase_config=dict(type='mmdet', task='ObjectDetection')))
|
20 | 20 |
|
21 | 21 |
|
22 |
| -def get_trt_config(output_names, shape): |
| 22 | +def get_trt_config(output_names, shape, dynamic_axes=None): |
23 | 23 | deploy_cfg_tensorrt = Config(
|
24 | 24 | dict(
|
25 |
| - onnx_config=dict(input_shape=None, output_names=output_names), |
| 25 | + onnx_config=dict( |
| 26 | + input_shape=None, |
| 27 | + output_names=output_names, |
| 28 | + dynamic_axes=dynamic_axes), |
26 | 29 | backend_config=dict(
|
27 | 30 | type='tensorrt',
|
28 | 31 | common_config=dict(
|
@@ -615,3 +618,30 @@ def linspace_caller(*arg, **kwargs):
|
615 | 618 |
|
616 | 619 | assert np.allclose(
|
617 | 620 | model_output, rewrite_outputs, rtol=1e-03, atol=1e-05)
|
| 621 | + |
| 622 | + |
| 623 | +@backend_checker(Backend.TENSORRT) |
| 624 | +@pytest.mark.parametrize('dtype', [torch.bool, torch.float32]) |
| 625 | +@pytest.mark.parametrize('dynamic_axes', |
| 626 | + [None, dict(input=dict({ |
| 627 | + 0: 'dim0', |
| 628 | + 1: 'dim1' |
| 629 | + }))]) |
| 630 | +def test_cat__tensorrt(dtype, dynamic_axes): |
| 631 | + input = torch.rand(2, 4) |
| 632 | + model = WrapFunction(lambda input: torch.cat( |
| 633 | + [input.to(dtype), input.to(dtype)], -1)) |
| 634 | + pytorch_output = model(input) |
| 635 | + rewrite_output, _ = get_rewrite_outputs( |
| 636 | + model, |
| 637 | + model_inputs={'input': input}, |
| 638 | + deploy_cfg=get_trt_config(['output'], |
| 639 | + shape=[2, 4], |
| 640 | + dynamic_axes=dynamic_axes), |
| 641 | + run_with_backend=True) |
| 642 | + assert pytorch_output.dtype == rewrite_output[0].dtype |
| 643 | + assert torch.allclose( |
| 644 | + pytorch_output.cpu().float(), |
| 645 | + rewrite_output[0].cpu().float(), |
| 646 | + rtol=1e-3, |
| 647 | + atol=1e-5) |
0 commit comments