Skip to content

Commit 0738caa

Browse files
author
Wei
authored
Changes done internally at Facebook (#1172)
b18ad449f9cbcad8d9e13c74c7605dc4dcca53bc Jason Park <[email protected]> AccOpGraph and pattern matcher from GraphModule 0888aecb9f5a6add2d855300f574592587ee8484 Jason Park <[email protected]> Dependency checker e935e26f2ace98e9ef251ad8328d6a07ddca828f Jason Park <[email protected]> Sort topologically. d6b5a5187cd22a7a7340be7bf021edbaf178b981 Jason Park <[email protected]> Grouped op fusion for layer norm c41054bec9cf617248d08363e15992bc24eb0ce0 Shreyansh Prajapati <[email protected]> Test dynamic shape support for acc_ops.eisum 9f9505f320c87e9069658349fcc6040d8fdc77a2 Shreyansh Prajapati <[email protected]> Test dynamic shape support for acc_ops.elu fec2cab3bb76c22f0eb660d6dd9a4f3ae5523cfd Shreyansh Prajapati <[email protected]> Test dynamic shape support for acc_ops.embedding 1a07ef201654adb752160e8e47e2f8369abe142a Shreyansh Prajapati <[email protected]> Test dynamic shape support for acc_ops.eq 7f110836cd4e36c04954591698aac4addac8fcb9 Shreyansh Prajapati <[email protected]> Test dynamic shape support for acc_ops.tanh c1b61c387833adcc602036605886a05df1d18152 Shreyansh Prajapati <[email protected]> Test dynamic shape support for acc_ops.size 3b3ae7efd15a599a2ef48e98ff2c0b7c1303d187 Kefei Lu <[email protected]> fx2trt: remove some comments bc4724e645d10dd75b47cef084a680f5eaa9d0bf Shreyansh Prajapati <[email protected]> Test dynamic shape support for acc_ops.gelu 252a70960be30e037bee2d712ff286743015df66 Shreyansh Prajapati <[email protected]> Test dynamic shape support for acc_ops.getitem 80039709fde7cdcdf81abdbfd2199adfe9df215b Shreyansh Prajapati <[email protected]> Test dynamic shape support for acc_ops.gt d6c9202ad3c1d9f35884cd189edad289e36a6787 Shreyansh Prajapati <[email protected]> Test dynamic shape support for acc_ops.hardsigmoid 6b1b7ba4690c8d5724ff6784f2f56a7cafad2c78 Shreyansh Prajapati <[email protected]> Test dynamic shape support for acc_ops.hardtanh 37747aec5f7ec8d1e1552f9280dae41910c966ba Shreyansh Prajapati <[email protected]> Test dynamic shape support for acc_ops.isinf 6eb5891faae2e95ac7b6918f5a88b3a0a84a5140 Shreyansh Prajapati <[email protected]> Test dynamic shape support for acc_ops.leaky_relu df212a34c238cd60a6e13f508400468cec539961 Shreyansh Prajapati <[email protected]> Test dynamic shape support for acc_ops.linear 8dc800208a7482d4b89a37526d37cdd0934a2402 Shreyansh Prajapati <[email protected]> Test dynamic shape support for acc_ops.logical_and edfd5e9f6336700ebc91e907a9ee207d8fe25336 Shreyansh Prajapati <[email protected]> Test dynamic shape support for acc_ops.logical_or 588f7801308ee2129cb67276fe40e00f9078af1d Shreyansh Prajapati <[email protected]> Test dynamic shape support for acc_ops.xor cba566389cc4b310dbccceb9dd265ada4aded201 Shreyansh Prajapati <[email protected]> Test dynamic shape support for acc_ops.lt 67497e1c99ef796e5ad88b3ecf5751d596877932 Shreyansh Prajapati <[email protected]> Test dynamic shape support for acc_ops.max c91cc6dda5221e55b2e960a6cd546f551d01cc7b Shreyansh Prajapati <[email protected]> Test dynamic shape support for acc_ops.maximum fd95c60b850bff8e00ba844da38c5d04c8bc25ba Shreyansh Prajapati <[email protected]> Test dynamic shape support for acc_ops.maxpool e7a00066f73231c466f79f15d7f48b917c855104 Shreyansh Prajapati <[email protected]> Test dynamic shape support for acc_ops.min 5f9fcfd27a4d1f7d6a8f034ea2b7c6c5c8a51809 Jason Park <[email protected]> grouped swish LN 249ce82c30f0d70dd6de952c4e85ce273b09cd45 Jason Park <[email protected]> Handling different eps for layer norms. 48907b6557d57569545d926b94f96c0d58ae2fcb Kefei Lu <[email protected]> fx2trt: log input specs
1 parent 5be9af2 commit 0738caa

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

41 files changed

+1275
-32
lines changed

examples/fx/quantized_resnet_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ def build_int8_trt(rn18):
4949
# uncomment to check per channel quant works
5050
weight=torch.quantization.default_per_channel_weight_observer,
5151
)
52-
prepared = prepare_fx(rn18, {"": qconfig})
52+
prepared = prepare_fx(rn18, {"": qconfig}, data)
5353
for _ in range(10):
5454
prepared(data)
5555
quantized_rn18 = convert_to_reference(prepared)

py/torch_tensorrt/fx/lower.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,10 +104,13 @@ def __call__(self, mod, input, split_name) -> TRTInterpreterResult:
104104
),
105105
self.lower_setting.opt_profile_replica,
106106
)
107-
if self.lower_setting.explicit_batch_dimension and self.lower_setting.dynamic_batch
107+
if self.lower_setting.explicit_batch_dimension
108+
and self.lower_setting.dynamic_batch
108109
else InputTensorSpec.from_tensors(input)
109110
)
110111
)
112+
logger.info(f"{split_name=} {input_specs_val=}")
113+
111114
# Prepare algorithm selector and timing_cache for TRTInterpreter
112115
algo_selector = None
113116
if self.lower_setting.algo_selector:

py/torch_tensorrt/fx/lower_setting.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -64,11 +64,7 @@ class LowerSetting(LowerSettingBasic):
6464
cache file is provided.
6565
cuda_graph_batch_size (int): Cuda graph batch size, default to be -1.
6666
preset_lowerer (str): when specified, use a preset logic to build the
67-
instance of Lowerer. Refer to
68-
`caffe2.torch.fb.model_transform.fx2trt.presets.LowererPresetsManager` on
69-
how presets are applied. Refer to
70-
`caffe2.torch.fb.model_transform.fx2trt.presets.ESUHMLowererPreset` on how
71-
to add a preset.
67+
instance of Lowerer.
7268
opt_profile_replica (int): the number of opt profile set for TensorRT engine, this field is
7369
only used by explicit batch dim with dynamic shape mode.
7470
dynamic_batch: enable the dynamic shape in TRT with dim=-1 for the 1st dimension.

py/torch_tensorrt/fx/passes/pass_utils.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,14 +63,17 @@ def pass_with_validation(
6363
y = y.cpu()
6464
accuracy_check = torch.allclose(x, y, **kwargs)
6565
if not accuracy_check:
66+
_LOGGER.error(
67+
f"Pass {pass_} failed correctness check, get original model output as {x} and processed model output as {y} for output {kk}."
68+
)
6669
if suppress_accuracy_check_failure:
6770
_LOGGER.error(
68-
f"pass {pass_} failed correctness check due to output {kk}, escape current pass."
71+
f"Pass {pass_} failed correctness check due to output {kk}."
6972
)
7073
return processed_module
7174
else:
7275
raise RuntimeError(
73-
f"pass {pass_} failed correctness check due to output {kk}"
76+
f"Pass {pass_} failed correctness check due to output {kk}"
7477
)
7578
return processed_module
7679

py/torch_tensorrt/fx/test/converters/acc_op/test_dequantize.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,24 @@ def forward(self, x):
4545
TestModule(), input_specs, expected_ops={acc_ops.dequantize}
4646
)
4747

48+
def test_dequantize_with_dynamic_shape_four_dimensions(self):
49+
class TestModule(nn.Module):
50+
def forward(self, x):
51+
x = torch.quantize_per_tensor(x, 1, 0, torch.quint8)
52+
return x.dequantize()
53+
54+
input_specs = [
55+
InputTensorSpec(
56+
shape=(-1, -1, -1, -1),
57+
dtype=torch.float32,
58+
shape_ranges=[((1, 1, 1, 1), (1, 2, 3, 3), (3, 3, 3, 3))],
59+
),
60+
]
61+
62+
self.run_test_with_dynamic_shape(
63+
TestModule(), input_specs, expected_ops={acc_ops.dequantize}
64+
)
65+
4866

4967
if __name__ == "__main__":
5068
run_tests()

py/torch_tensorrt/fx/test/converters/acc_op/test_einsum.py

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops
44
from parameterized import parameterized
55
from torch.testing._internal.common_utils import run_tests
6-
from torch_tensorrt.fx.tools.common_fx2trt import AccTestCase
6+
from torch_tensorrt.fx.tools.common_fx2trt import AccTestCase, InputTensorSpec
77

88

99
class TestConverter(AccTestCase):
@@ -30,6 +30,37 @@ def forward(self, x, y):
3030
test_implicit_batch_dim=False,
3131
)
3232

33+
@parameterized.expand(
34+
[
35+
("4d_dim", "bcwd,bcdh->bcwh", (2, 3, 4, 5), (2, 3, 5, 6)),
36+
("4d_dim_ext", "bcxd,bcyd->bcxy", (2, 3, 4, 5), (2, 3, 6, 5)),
37+
# TRT does not support ellipsis or diagonal operations
38+
]
39+
)
40+
def test_einsum_with_dynamic_shape_four_dimensions(
41+
self, _, equation, x_size, y_size
42+
):
43+
class Einsum(nn.Module):
44+
def forward(self, x, y):
45+
return torch.einsum(equation, x, y)
46+
47+
input_specs = [
48+
InputTensorSpec(
49+
shape=(-1, -1, -1, -1),
50+
dtype=torch.float32,
51+
shape_ranges=[((1, 1, 3, 3), (1, 2, 3, 3), (3, 3, 3, 3))],
52+
),
53+
InputTensorSpec(
54+
shape=(-1, -1, -1, -1),
55+
dtype=torch.float32,
56+
shape_ranges=[((1, 1, 3, 3), (1, 2, 3, 3), (3, 3, 3, 3))],
57+
),
58+
]
59+
60+
self.run_test_with_dynamic_shape(
61+
Einsum(), input_specs, expected_ops={acc_ops.einsum}
62+
)
63+
3364

3465
if __name__ == "__main__":
3566
run_tests()

py/torch_tensorrt/fx/test/converters/acc_op/test_elu.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,23 @@ def forward(self, x):
3030
TestModule(), input_specs, expected_ops={acc_ops.elu}
3131
)
3232

33+
def test_elu_with_dynamic_shape_four_dimensions(self):
34+
class TestModule(nn.Module):
35+
def forward(self, x):
36+
return nn.functional.elu(x)
37+
38+
input_specs = [
39+
InputTensorSpec(
40+
shape=(-1, -1, -1, -1),
41+
dtype=torch.float32,
42+
shape_ranges=[((1, 1, 1, 1), (1, 2, 3, 5), (3, 3, 3, 5))],
43+
),
44+
]
45+
46+
self.run_test_with_dynamic_shape(
47+
TestModule(), input_specs, expected_ops={acc_ops.elu}
48+
)
49+
3350

3451
if __name__ == "__main__":
3552
run_tests()

py/torch_tensorrt/fx/test/converters/acc_op/test_embedding.py

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops
66
from parameterized import param, parameterized
77
from torch.testing._internal.common_utils import run_tests
8-
from torch_tensorrt.fx.tools.common_fx2trt import AccTestCase
8+
from torch_tensorrt.fx.tools.common_fx2trt import AccTestCase, InputTensorSpec
99

1010

1111
@unittest.skip(
@@ -62,6 +62,46 @@ def forward(self, indices, weights):
6262
test_explicit_batch_dim=True,
6363
)
6464

65+
def test_embedding_with_dynamic_shape_four_dimensions(
66+
self,
67+
test_name,
68+
indices_tensor,
69+
weights_tensor,
70+
padding_idx=None,
71+
max_norm=None,
72+
norm_type=2.0,
73+
scale_grad_by_freq=False,
74+
sparse=False,
75+
):
76+
class TestEmbedding(torch.nn.Module):
77+
def forward(self, indices, weights):
78+
return torch.nn.functional.embedding(
79+
input=indices,
80+
weight=weights,
81+
padding_idx=padding_idx,
82+
max_norm=max_norm,
83+
norm_type=norm_type,
84+
scale_grad_by_freq=scale_grad_by_freq,
85+
sparse=sparse,
86+
)
87+
88+
input_specs = [
89+
InputTensorSpec(
90+
shape=(-1, -1, -1, -1),
91+
dtype=torch.float32,
92+
shape_ranges=[((1, 1, 1, 1), (2, 3, 4, 5), (2, 3, 10, 10))],
93+
),
94+
InputTensorSpec(
95+
shape=(-1, -1, -1, -1),
96+
dtype=torch.float32,
97+
shape_ranges=[((1, 1, 1, 1), (2, 3, 4, 5), (2, 3, 10, 10))],
98+
),
99+
]
100+
101+
self.run_test_with_dynamic_shape(
102+
TestEmbedding(), input_specs, expected_ops={acc_ops.embedding}
103+
)
104+
65105

66106
if __name__ == "__main__":
67107
run_tests()

py/torch_tensorrt/fx/test/converters/acc_op/test_eq.py

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops
33
from parameterized import parameterized
44
from torch.testing._internal.common_utils import run_tests
5-
from torch_tensorrt.fx.tools.common_fx2trt import AccTestCase
5+
from torch_tensorrt.fx.tools.common_fx2trt import AccTestCase, InputTensorSpec
66

77

88
class TestEqConverter(AccTestCase):
@@ -184,6 +184,28 @@ def forward(self, x, y):
184184
)
185185

186186

187+
class TestEqOperatorSimpleConverterWithDynamicShape(AccTestCase):
188+
def test_eq(self):
189+
class Eq(torch.nn.Module):
190+
def forward(self, x, y):
191+
return x == y
192+
193+
input_specs = [
194+
InputTensorSpec(
195+
shape=(-1, -1, -1, -1),
196+
dtype=torch.float32,
197+
shape_ranges=[((1, 1, 1, 1), (2, 3, 4, 5), (2, 3, 10, 10))],
198+
),
199+
InputTensorSpec(
200+
shape=(-1, -1, -1, -1),
201+
dtype=torch.float32,
202+
shape_ranges=[((1, 1, 1, 1), (2, 3, 4, 5), (2, 3, 10, 10))],
203+
),
204+
]
205+
206+
self.run_test_with_dynamic_shape(Eq(), input_specs, expected_ops={acc_ops.eq})
207+
208+
187209
class TestEqOperatorConstantConverter(AccTestCase):
188210
@parameterized.expand(
189211
[
@@ -243,5 +265,25 @@ def forward(self, x):
243265
)
244266

245267

268+
class TestConstInputConverterWithDynamicShape(AccTestCase):
269+
def test_eq(self):
270+
class Eq(torch.nn.Module):
271+
def __init__(self):
272+
super().__init__()
273+
274+
def forward(self, x):
275+
return x.shape[0] == 4
276+
277+
input_specs = [
278+
InputTensorSpec(
279+
shape=(-1, -1, -1, -1),
280+
dtype=torch.float32,
281+
shape_ranges=[((1, 1, 1, 1), (1, 2, 3, 3), (3, 3, 3, 3))],
282+
),
283+
]
284+
285+
self.run_test_with_dynamic_shape(Eq(), input_specs, expected_ops={acc_ops.eq})
286+
287+
246288
if __name__ == "__main__":
247289
run_tests()

py/torch_tensorrt/fx/test/converters/acc_op/test_gelu.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,23 @@ def forward(self, x):
3535
TestModule(), input_specs, expected_ops={acc_ops.gelu}
3636
)
3737

38+
def test_gelu_with_dynamic_shape_four_dimensions(self):
39+
class TestModule(nn.Module):
40+
def forward(self, x):
41+
return nn.functional.gelu(x)
42+
43+
input_specs = [
44+
InputTensorSpec(
45+
shape=(-1, -1, -1, -1),
46+
dtype=torch.float32,
47+
shape_ranges=[((1, 1, 1, 1), (1, 2, 3, 3), (3, 3, 3, 3))],
48+
),
49+
]
50+
51+
self.run_test_with_dynamic_shape(
52+
TestModule(), input_specs, expected_ops={acc_ops.gelu}
53+
)
54+
3855

3956
if __name__ == "__main__":
4057
run_tests()

py/torch_tensorrt/fx/test/converters/acc_op/test_getitem.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,52 @@ def forward(self, x):
148148
Getitem(idx), input_specs, expected_ops={acc_ops.getitem}
149149
)
150150

151+
# Testing with following parameters results into Error:
152+
# AssertionError: We don't support slicing tensor on dynamic shape.
153+
154+
"""
155+
("ellipsis", (slice(None, None, None), ..., slice(0, -3, 2))),
156+
(
157+
"slice_end_none",
158+
(slice(None, None, None), slice(None, None, None), slice(1, None, 1)),
159+
),
160+
(
161+
"slice_step_none",
162+
(slice(None, None, None), slice(None, None, None), slice(0, 3, None)),
163+
),
164+
"""
165+
166+
@parameterized.expand(
167+
[
168+
("slice_batch_dim", slice(None, None, None)),
169+
(
170+
"slice_all_none",
171+
(slice(None, None, None), slice(None, None, None)),
172+
),
173+
]
174+
)
175+
def test_getitem_with_dynamic_shape_four_dimensions(self, _, idx):
176+
class Getitem(nn.Module):
177+
def __init__(self, idx):
178+
super().__init__()
179+
self.idx = idx
180+
181+
def forward(self, x):
182+
x = x + x
183+
return x[self.idx]
184+
185+
input_specs = [
186+
InputTensorSpec(
187+
shape=(-1, -1, -1, -1),
188+
dtype=torch.float32,
189+
shape_ranges=[((1, 1, 1, 1), (3, 3, 3, 3), (5, 5, 5, 5))],
190+
),
191+
]
192+
193+
self.run_test_with_dynamic_shape(
194+
Getitem(idx), input_specs, expected_ops={acc_ops.getitem}
195+
)
196+
151197

152198
if __name__ == "__main__":
153199
run_tests()

0 commit comments

Comments
 (0)