Skip to content

Commit 2d71f12

Browse files
committed
fix: Fix for Dynamic Shape Tests + Input class
1 parent 645ea45 commit 2d71f12

File tree

11 files changed

+73
-33
lines changed

11 files changed

+73
-33
lines changed

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 20 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,8 @@
1414
_LOGGER: logging.Logger = logging.getLogger(__name__)
1515

1616

17-
def or_none(args, i):
18-
return args[i] if len(args) > i else None
17+
def args_bounds_check(args, i, replacement=None):
18+
return args[i] if len(args) > i else replacement
1919

2020

2121
@dynamo_tensorrt_converter(torch.ops.aten.batch_norm)
@@ -59,17 +59,24 @@ def aten_ops_div(
5959
# If both are TRTTensor, both are cast to float32
6060
if isinstance(args[0], TRTTensor) and isinstance(args[1], TRTTensor):
6161
kwargs_new["input"], kwargs_new["other"] = cast_int_int_div_trt_tensor(
62-
network, kwargs_new["input"], kwargs_new["other"]
62+
network,
63+
kwargs_new["input"],
64+
kwargs_new["other"],
65+
name,
6366
)
6467
# If one is TRTTensor, it is cast to float32
6568
elif isinstance(args[0], TRTTensor) and (
6669
kwargs_new["input"].dtype == trt.int8 or kwargs_new["input"].dtype == trt.int32
6770
):
68-
kwargs_new["input"] = cast_trt_tensor(network, kwargs_new["input"], trt.float32)
71+
kwargs_new["input"] = cast_trt_tensor(
72+
network, kwargs_new["input"], trt.float32, name
73+
)
6974
elif isinstance(args[1], TRTTensor) and (
7075
kwargs_new["other"].dtype == trt.int8 or kwargs_new["other"].dtype == trt.int32
7176
):
72-
kwargs_new["other"] = cast_trt_tensor(network, kwargs_new["other"], trt.float32)
77+
kwargs_new["other"] = cast_trt_tensor(
78+
network, kwargs_new["other"], trt.float32, name
79+
)
7380
rounding_mode = kwargs.get("rounding_mode")
7481
if rounding_mode is None:
7582
return acc_ops_converters.acc_ops_div(network, target, None, kwargs_new, name)
@@ -136,10 +143,10 @@ def aten_ops_embedding(
136143
name,
137144
input=args[1],
138145
weight=args[0],
139-
max_norm=or_none(args, 2),
140-
norm_type=or_none(args, 3),
141-
scale_grad_by_freq=or_none(args, 4),
142-
sparse=or_none(args, 5),
146+
max_norm=args_bounds_check(args, 2),
147+
norm_type=args_bounds_check(args, 3),
148+
scale_grad_by_freq=args_bounds_check(args, 4),
149+
sparse=args_bounds_check(args, 5),
143150
)
144151

145152

@@ -311,11 +318,11 @@ def aten_ops_clamp(
311318
return impl.elementwise.clamp(
312319
network,
313320
target,
314-
SourceIR.ACC,
321+
SourceIR.ATEN,
315322
name,
316323
input_val=args[0],
317-
min_val=or_none(args, 1),
318-
max_val=or_none(args, 2),
324+
min_val=args_bounds_check(args, 1),
325+
max_val=args_bounds_check(args, 2),
319326
)
320327

321328

@@ -349,5 +356,5 @@ def aten_ops_slice(
349356
args[1],
350357
args[2],
351358
args[3],
352-
args[4],
359+
args_bounds_check(args, 4, replacement=1),
353360
)

py/torch_tensorrt/dynamo/conversion/converter_utils.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
TRTNetwork,
66
TRTTensor,
77
)
8-
import torch_tensorrt as trt
8+
import tensorrt as trt
99
from typing import List
1010

1111

@@ -71,21 +71,23 @@ def cast_int_int_div_trt_tensor(
7171
network: TRTNetwork,
7272
lhs_val: TRTTensor,
7373
rhs_val: TRTTensor,
74+
name: str,
7475
) -> List[TRTTensor]:
7576
"""
7677
Given two `int` data type TRT Tensor to div operation, cast the TRT Tensor to float type
7778
Args:
7879
network (TRTNetwork): A TensorRT network
7980
lhs_val (TRTTensor): A TRT Tensor numerator
8081
rhs_val (TRTTensor): A TRT Tensor numerator
82+
name (str): Name of calling layer
8183
Returns:
8284
A list of lhs_val and rhs_val casted to the approriate datatype
8385
"""
8486
if (lhs_val.dtype == trt.int8 or lhs_val.dtype == trt.int32) and (
8587
rhs_val.dtype == trt.int8 or rhs_val.dtype == trt.int32
8688
):
87-
lhs_val = cast_trt_tensor(network, lhs_val, trt.float32)
88-
rhs_val = cast_trt_tensor(network, rhs_val, trt.float32)
89+
lhs_val = cast_trt_tensor(network, lhs_val, trt.float32, name)
90+
rhs_val = cast_trt_tensor(network, rhs_val, trt.float32, name)
8991
return list((lhs_val, rhs_val))
9092

9193

py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,8 @@ def slice_op(
4646
assert input.shape[dim] != -1, "Can't chunk on dynamic shape dimension!"
4747
start_int = cast(int, start)
4848
stop_int = cast(int, stop)
49+
if stop_int == 2**63 - 1:
50+
stop_int = input.shape[dim]
4951
step_int = cast(int, step)
5052
start = [0] * len(input.shape)
5153
start[dim] = start_int

py/torch_tensorrt/dynamo/conversion/trt_interpreter.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,15 @@ def __init__(
6464
+ "\n".join(f"{i}" for i in missing_ops)
6565
)
6666

67-
self.optimization_profiles: Optional[List] = None
67+
self.optimization_profiles = (
68+
[self.builder.create_optimization_profile()]
69+
if any(
70+
input_spec.shape_mode == Input._ShapeMode.DYNAMIC
71+
for input_spec in input_specs
72+
)
73+
else None
74+
)
75+
6876
self.input_specs = input_specs
6977
self.input_specs_iter = 0
7078
self._cur_node_name: Optional[str] = None
@@ -257,7 +265,7 @@ def placeholder(self, target, args, kwargs):
257265
opt_shape = current_input.shape["opt_shape"]
258266
max_shape = current_input.shape["max_shape"]
259267
self.optimization_profiles[0].set_shape(
260-
target, [min_shape, opt_shape, max_shape]
268+
target, min_shape, opt_shape, max_shape
261269
)
262270
assert len(min_shape) == len(opt_shape) == len(max_shape)
263271
for i in range(len(min_shape)):

py/torch_tensorrt/dynamo/test_utils.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ def run_test(
6767
rtol,
6868
atol,
6969
precision=torch.float,
70+
check_dtype=True,
7071
):
7172
with torch.no_grad():
7273
cuda_inputs = []
@@ -117,7 +118,12 @@ def run_test(
117118
if ref.dtype == torch.int64:
118119
ref = ref.int() # convert torch.max's index output tensor to int32
119120
torch.testing.assert_close(
120-
out.cpu(), ref, rtol=rtol, atol=atol, equal_nan=True
121+
out.cpu(),
122+
ref,
123+
rtol=rtol,
124+
atol=atol,
125+
equal_nan=True,
126+
check_dtype=check_dtype,
121127
)
122128

123129
def run_test_custom_compare_results(
@@ -254,6 +260,7 @@ def run_test(
254260
rtol=1e-03,
255261
atol=1e-03,
256262
precision=torch.float,
263+
check_dtype=True,
257264
):
258265
mod.eval()
259266
mod = self.generate_graph(mod, inputs, expected_ops, unexpected_ops, None)
@@ -267,7 +274,15 @@ def run_test(
267274
Input.from_tensors(inputs),
268275
)
269276
super().run_test(
270-
mod, inputs, expected_ops, unexpected_ops, interp, rtol, atol, precision
277+
mod,
278+
inputs,
279+
expected_ops,
280+
unexpected_ops,
281+
interp,
282+
rtol,
283+
atol,
284+
precision,
285+
check_dtype,
271286
)
272287

273288
def run_test_with_dynamic_shape(

tests/py/dynamo/backend/test_backend_compiler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -222,7 +222,7 @@ def test_int64_input_partial_support(self):
222222
class PartiallySupportedMultiOp(torch.nn.Module):
223223
def forward(self, x, y):
224224
return torch.ops.aten.div.Tensor_mode(
225-
x, torch.ops.aten.add.Tensor(y, y), rounding_mode="floor"
225+
x, torch.ops.aten.add.Tensor(y, y), rounding_mode=None
226226
)
227227

228228
fx_graph = torch.fx.symbolic_trace(PartiallySupportedMultiOp())

tests/py/dynamo/backend/test_decompositions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -226,7 +226,7 @@ def forward(self, x):
226226
# Validate that the results between Torch and Torch-TRT are similar
227227
optimized_model = torch_tensorrt.compile(
228228
fx_graph,
229-
"torch_tensorrt",
229+
"torch_compile",
230230
inputs,
231231
min_block_size=1,
232232
pass_through_build_failures=True,

tests/py/dynamo/converters/test_binary_ops_aten.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from typing import Callable
2+
import unittest
23

34
import torch
45
import torch.nn as nn
@@ -76,6 +77,7 @@ def forward(self, x):
7677
self.run_test(m, inputs, expected_ops={expected_op})
7778

7879
@parameterized.expand([(op[1].__name__, op[0], op[1]) for op in elementwise_ops])
80+
@unittest.skip("Pending reimplementation of all binary converters in Dynamo")
7981
def test_elementwise_ops_mismatched_dtypes(
8082
self, name, orig_op: Callable, expected_op
8183
):
@@ -84,12 +86,15 @@ def __init__(self, orig_op):
8486
super().__init__()
8587
self.orig_op = orig_op
8688

87-
def forward(self, x):
88-
return self.orig_op(x.int(), x)
89+
def forward(self, x, y):
90+
return self.orig_op(x, y)
8991

9092
m = TestModule(orig_op)
9193
# Avoid dividing by 0.
92-
inputs = [2 * torch.rand(1, 1, dtype=torch.float) + 1]
94+
inputs = [
95+
2 * torch.rand(1, 1, dtype=torch.float) + 1,
96+
torch.randint(1, 3, (1, 1), dtype=torch.int),
97+
]
9398
self.run_test(m, inputs, expected_ops={expected_op})
9499

95100
@parameterized.expand([(op[1].__name__, op[0], op[1]) for op in elementwise_ops])

tests/py/dynamo/converters/test_embedding_aten.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -78,13 +78,13 @@ def forward(self, input, weights):
7878
input_specs = [
7979
Input(
8080
shape=(-1, -1, -1, -1),
81-
dtype=torch.float32,
81+
dtype=torch.int,
8282
shape_ranges=[((1, 1, 1, 1), (2, 3, 4, 5), (2, 3, 10, 10))],
8383
),
8484
Input(
85-
shape=(-1, -1, -1, -1),
85+
shape=(-1, -1),
8686
dtype=torch.float32,
87-
shape_ranges=[((1, 1, 1, 1), (2, 3, 4, 5), (2, 3, 10, 10))],
87+
shape_ranges=[((1, 1), (2, 3), (2, 3))],
8888
),
8989
]
9090

tests/py/dynamo/converters/test_sigmoid_aten.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ def forward(self, x):
6060
inputs,
6161
expected_ops={torch.ops.aten.sigmoid.default},
6262
precision=torch.half,
63+
check_dtype=False,
6364
)
6465

6566

tests/py/dynamo/models/test_models.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ def test_resnet18(ir):
3232
"ir": ir,
3333
"pass_through_build_failures": True,
3434
"optimization_level": 1,
35-
"min_block_size": 8,
35+
"min_block_size": 10,
3636
"ir": "torch_compile",
3737
}
3838

@@ -66,7 +66,7 @@ def test_mobilenet_v2(ir):
6666
"ir": ir,
6767
"pass_through_build_failures": True,
6868
"optimization_level": 1,
69-
"min_block_size": 8,
69+
"min_block_size": 10,
7070
"ir": "torch_compile",
7171
}
7272

@@ -100,7 +100,7 @@ def test_efficientnet_b0(ir):
100100
"ir": ir,
101101
"pass_through_build_failures": True,
102102
"optimization_level": 1,
103-
"min_block_size": 8,
103+
"min_block_size": 10,
104104
"ir": "torch_compile",
105105
}
106106

@@ -143,7 +143,7 @@ def test_bert_base_uncased(ir):
143143
"ir": ir,
144144
"pass_through_build_failures": True,
145145
"optimization_level": 1,
146-
"min_block_size": 8,
146+
"min_block_size": 10,
147147
"ir": "torch_compile",
148148
}
149149
trt_mod = torchtrt.compile(model, **compile_spec)
@@ -181,7 +181,7 @@ def test_resnet18_half(ir):
181181
"ir": ir,
182182
"pass_through_build_failures": True,
183183
"optimization_level": 1,
184-
"min_block_size": 8,
184+
"min_block_size": 10,
185185
"ir": "torch_compile",
186186
}
187187

0 commit comments

Comments
 (0)