Skip to content

Commit fb41cf7

Browse files
committed
fix: Add test suite for torch.compile backend (#1849)
1 parent 2addf5e commit fb41cf7

File tree

7 files changed

+292
-39
lines changed

7 files changed

+292
-39
lines changed

.circleci/config.yml

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -727,6 +727,22 @@ commands:
727727
- store_artifacts:
728728
path: /tmp/testlogs
729729

730+
test-dynamo-torch_compile-core:
731+
description: "Test the Dynamo torch_compile path"
732+
steps:
733+
- run:
734+
name: Run Dynamo torch_compile core tests
735+
command: |
736+
cd py/torch_tensorrt/dynamo/torch_compile
737+
pushd test/
738+
pytest --junitxml=/tmp/artifacts/test_results/dynamo/torch_compile/test_results.xml
739+
popd
740+
741+
- store_test_results:
742+
path: /tmp/artifacts
743+
- store_artifacts:
744+
path: /tmp/testlogs
745+
730746
test-dynamo-torch_compile:
731747
description: "Test the Dynamo torch_compile path"
732748
steps:
@@ -953,6 +969,7 @@ jobs:
953969
# We install torch after torch-trt because pip automatically enforces the version constraint otherwise
954970
- dump-test-env
955971
- test-dynamo-torch_compile
972+
- test-dynamo-torch_compile-core
956973
- test-dynamo-fx_ts
957974

958975
package-x86_64-linux:

py/torch_tensorrt/dynamo/test/utils.py

Lines changed: 0 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -13,42 +13,3 @@ def cosine_similarity(gt_tensor, pred_tensor):
1313
res = res.cpu().detach().item()
1414

1515
return res
16-
17-
18-
def same_output_format(trt_output, torch_output):
19-
# For each encountered collection type, ensure the torch and trt outputs agree
20-
# on type and size, checking recursively through all member elements.
21-
if isinstance(trt_output, tuple):
22-
return (
23-
isinstance(torch_output, tuple)
24-
and (len(trt_output) == len(torch_output))
25-
and all(
26-
same_output_format(trt_entry, torch_entry)
27-
for trt_entry, torch_entry in zip(trt_output, torch_output)
28-
)
29-
)
30-
elif isinstance(trt_output, list):
31-
return (
32-
isinstance(torch_output, list)
33-
and (len(trt_output) == len(torch_output))
34-
and all(
35-
same_output_format(trt_entry, torch_entry)
36-
for trt_entry, torch_entry in zip(trt_output, torch_output)
37-
)
38-
)
39-
elif isinstance(trt_output, dict):
40-
return (
41-
isinstance(torch_output, dict)
42-
and (len(trt_output) == len(torch_output))
43-
and (trt_output.keys() == torch_output.keys())
44-
and all(
45-
same_output_format(trt_output[key], torch_output[key])
46-
for key in trt_output.keys()
47-
)
48-
)
49-
elif isinstance(trt_output, set) or isinstance(trt_output, frozenset):
50-
raise AssertionError(
51-
"Unsupported output type 'set' encountered in output format check."
52-
)
53-
else:
54-
return type(trt_output) is type(torch_output)
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
from torch_tensorrt.dynamo.torch_compile.utils import prepare_device, prepare_inputs
2+
from utils import same_output_format
3+
import torch_tensorrt
4+
import unittest
5+
import torch
6+
7+
8+
class TestPrepareDevice(unittest.TestCase):
9+
def test_prepare_cuda_device(self):
10+
gpu_id = 0
11+
device = torch.device(f"cuda:{gpu_id}")
12+
prepared_device = prepare_device(device)
13+
self.assertTrue(isinstance(prepared_device, torch.device))
14+
self.assertTrue(prepared_device.index == gpu_id)
15+
16+
def test_prepare_trt_device(self):
17+
gpu_id = 4
18+
device = torch_tensorrt.Device(gpu_id=gpu_id)
19+
prepared_device = prepare_device(device)
20+
self.assertTrue(isinstance(prepared_device, torch.device))
21+
self.assertTrue(prepared_device.index == gpu_id)
22+
23+
24+
class TestPrepareInputs(unittest.TestCase):
25+
def test_prepare_single_tensor_input(self):
26+
inputs = [torch.ones((4, 4))]
27+
prepared_inputs = prepare_inputs(inputs)
28+
self.assertTrue(
29+
same_output_format(inputs, prepared_inputs, enforce_tensor_type=False)
30+
)
31+
32+
def test_prepare_trt_input(self):
33+
inputs = [torch_tensorrt.Input(shape=(4, 3), dtype=torch.float)]
34+
prepared_inputs = prepare_inputs(inputs)
35+
self.assertTrue(
36+
same_output_format(inputs, prepared_inputs, enforce_tensor_type=False)
37+
)
38+
39+
def test_prepare_mixed_type_compound_tensor_input(self):
40+
inputs = {
41+
"first": [
42+
torch.ones((4, 4)),
43+
torch_tensorrt.Input(shape=(4, 3), dtype=torch.float),
44+
],
45+
"second": (
46+
torch.rand((5, 1)),
47+
(torch.rand((5, 1)), torch_tensorrt.Input(shape=(2, 3))),
48+
),
49+
}
50+
prepared_inputs = prepare_inputs(inputs)
51+
self.assertTrue(
52+
same_output_format(inputs, prepared_inputs, enforce_tensor_type=False)
53+
)
54+
55+
56+
if __name__ == "__main__":
57+
unittest.main()
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
from functools import partial
2+
from utils import fx_dynamo_testing_backend
3+
from torch.testing._internal.common_utils import run_tests, TestCase
4+
import torch
5+
6+
7+
class TestLowering(TestCase):
8+
def test_lowering_inplace_op(self):
9+
class FullySupported(torch.nn.Module):
10+
def __init__(self, *args, **kwargs) -> None:
11+
super().__init__(*args, **kwargs)
12+
13+
def forward(self, x, y):
14+
x = torch.ops.aten.add_.Tensor(x, y)
15+
x = torch.ops.aten.relu_.default(x)
16+
return x
17+
18+
# Operations expected to be included in the traced graph after decompositions
19+
expected_ops = {torch.ops.aten.add.Tensor, torch.ops.aten.relu.default}
20+
21+
# Trace module and set up custom backend to track intermediate graphs
22+
fx_graph = torch.fx.symbolic_trace(FullySupported())
23+
partitioned_graphs = []
24+
custom_backend = partial(
25+
fx_dynamo_testing_backend,
26+
store_intermediate_graphs=partitioned_graphs,
27+
)
28+
29+
# Invoke compilation
30+
compiled_graph = torch.compile(fx_graph, backend=custom_backend)
31+
compiled_graph(
32+
torch.rand(
33+
5,
34+
).cuda(),
35+
torch.rand(
36+
5,
37+
).cuda(),
38+
)
39+
40+
# Iterate over intermediate graphs, attempt to match nodes
41+
for fx_module in partitioned_graphs:
42+
for _, submodule in fx_module.named_children():
43+
for node in submodule.graph.nodes:
44+
45+
if node.op == "call_function" and node.target in expected_ops:
46+
expected_ops.remove(node.target)
47+
48+
self.assertEqual(
49+
len(expected_ops), 0, "All operators should have been decomposed"
50+
)
51+
52+
53+
if __name__ == "__main__":
54+
run_tests()
Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
from torch_tensorrt.dynamo.torch_compile.lowering import partition
2+
from torch.testing._internal.common_utils import run_tests, TestCase
3+
import torch
4+
from copy import deepcopy
5+
import numpy as np
6+
7+
8+
class TestPartitioning(TestCase):
9+
def test_partition_fully_supported_one_op(self):
10+
class FullySupportedOneOp(torch.nn.Module):
11+
def __init__(self, *args, **kwargs) -> None:
12+
super().__init__(*args, **kwargs)
13+
14+
def forward(self, x, y):
15+
return torch.ops.aten.add.Tensor(x, y)
16+
17+
fx_graph = torch.fx.symbolic_trace(FullySupportedOneOp())
18+
partitioned_graph = partition(deepcopy(fx_graph))
19+
self.assertEqual(
20+
len(list(partitioned_graph.named_children())),
21+
0,
22+
"Single operators should not be segmented",
23+
)
24+
25+
def test_partition_fully_supported_multi_op(self):
26+
class FullySupportedMultiOp(torch.nn.Module):
27+
def __init__(self, *args, **kwargs) -> None:
28+
super().__init__(*args, **kwargs)
29+
30+
def forward(self, x, y):
31+
sum_ = torch.ops.aten.sub.Tensor(x, y)
32+
concat_ = torch.ops.aten.cat.default(x, sum_)
33+
relu_ = torch.ops.aten.relu.default(concat_)
34+
pow_ = torch.ops.aten.pow.Tensor_Scalar(relu_, 2)
35+
return pow_
36+
37+
fx_graph = torch.fx.symbolic_trace(FullySupportedMultiOp())
38+
partitioned_graph = partition(deepcopy(fx_graph))
39+
self.assertEqual(
40+
len(list(partitioned_graph.named_children())),
41+
1,
42+
"All operators are supported, there should be one segment",
43+
)
44+
45+
def test_partition_partially_supported_multi_op(self):
46+
class PartiallySupportedMultiOp(torch.nn.Module):
47+
def __init__(self, *args, **kwargs) -> None:
48+
super().__init__(*args, **kwargs)
49+
50+
def forward(self, x, y):
51+
sum_1 = torch.ops.aten.add.Tensor(x, y)
52+
sum_2 = torch.ops.aten.add.Tensor(x, sum_1)
53+
sum_ = np.sum(sum_1) + np.sum(sum_2)
54+
relu_ = torch.ops.aten.relu.default(sum_)
55+
pow_ = torch.ops.aten.pow.Tensor_Scalar(relu_, 2)
56+
return pow_
57+
58+
fx_graph = torch.fx.symbolic_trace(PartiallySupportedMultiOp())
59+
partitioned_graph = partition(deepcopy(fx_graph))
60+
self.assertEqual(
61+
len(list(partitioned_graph.named_children())),
62+
2,
63+
"Unsupported operators interleave supported ones, expected 2 segments",
64+
)
65+
66+
67+
if __name__ == "__main__":
68+
run_tests()
Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
from copy import deepcopy
2+
from functools import partial
3+
from typing import List, Sequence
4+
import torch
5+
from torch_tensorrt.dynamo.torch_compile.lowering._decompositions import (
6+
get_decompositions,
7+
)
8+
from torch_tensorrt.dynamo.torch_compile.lowering._partition import (
9+
partition,
10+
)
11+
12+
from torch._dynamo.backends.common import fake_tensor_unsupported
13+
14+
from torch._functorch.aot_autograd import aot_module_simplified, make_boxed_compiler
15+
16+
17+
@fake_tensor_unsupported
18+
def fx_dynamo_testing_backend(
19+
gm: torch.fx.GraphModule,
20+
sample_inputs: Sequence[torch.Tensor],
21+
*,
22+
store_intermediate_graphs: List,
23+
):
24+
"""Helper Dynamo backend exclusively for testing"""
25+
custom_backend = partial(
26+
compile_module_testing,
27+
store_intermediate_graphs=store_intermediate_graphs,
28+
)
29+
30+
# Invoke AOTAutograd to translate operators to aten
31+
return aot_module_simplified(
32+
gm,
33+
sample_inputs,
34+
fw_compiler=make_boxed_compiler(custom_backend),
35+
decompositions=get_decompositions(),
36+
)
37+
38+
39+
def compile_module_testing(
40+
gm: torch.fx.GraphModule,
41+
example_inputs: Sequence[torch.Tensor],
42+
*,
43+
store_intermediate_graphs: List,
44+
) -> torch.fx.GraphModule:
45+
"""Helper compiler exclusively for testing"""
46+
partitioned_module = partition(gm)
47+
48+
# Store intermediate graph from partitioned module
49+
store_intermediate_graphs.append(deepcopy(partitioned_module))
50+
51+
return partitioned_module
52+
53+
54+
def same_output_format(trt_output, torch_output, enforce_tensor_type=True):
55+
# For each encountered collection type, ensure the torch and trt outputs agree
56+
# on type and size, checking recursively through all member elements.
57+
if isinstance(trt_output, tuple):
58+
return (
59+
isinstance(torch_output, tuple)
60+
and (len(trt_output) == len(torch_output))
61+
and all(
62+
same_output_format(trt_entry, torch_entry, enforce_tensor_type)
63+
for trt_entry, torch_entry in zip(trt_output, torch_output)
64+
)
65+
)
66+
elif isinstance(trt_output, list):
67+
return (
68+
isinstance(torch_output, list)
69+
and (len(trt_output) == len(torch_output))
70+
and all(
71+
same_output_format(trt_entry, torch_entry, enforce_tensor_type)
72+
for trt_entry, torch_entry in zip(trt_output, torch_output)
73+
)
74+
)
75+
elif isinstance(trt_output, dict):
76+
return (
77+
isinstance(torch_output, dict)
78+
and (len(trt_output) == len(torch_output))
79+
and (trt_output.keys() == torch_output.keys())
80+
and all(
81+
same_output_format(
82+
trt_output[key], torch_output[key], enforce_tensor_type
83+
)
84+
for key in trt_output.keys()
85+
)
86+
)
87+
elif isinstance(trt_output, set) or isinstance(trt_output, frozenset):
88+
raise AssertionError(
89+
"Unsupported output type 'set' encountered in output format check."
90+
)
91+
elif enforce_tensor_type:
92+
return type(trt_output) is type(torch_output)
93+
else:
94+
return True

py/torch_tensorrt/dynamo/torch_compile/utils.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,3 +64,5 @@ def prepare_device(device: Union[Device, torch.device]) -> torch.device:
6464
raise ValueError(
6565
"Invalid device provided. Supported options: torch.device | torch_tensorrt.Device"
6666
)
67+
68+
return device

0 commit comments

Comments
 (0)