Skip to content

Commit a5f4d5b

Browse files
Add dynamic shape support for cumsum/grid (#3051)
1 parent ee16bad commit a5f4d5b

File tree

4 files changed

+209
-13
lines changed

4 files changed

+209
-13
lines changed

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -331,10 +331,14 @@ def aten_ops_fmod(
331331
return impl.elementwise.fmod(ctx, target, SourceIR.ATEN, name, args[0], args[1])
332332

333333

334-
@dynamo_tensorrt_converter(torch.ops.aten.grid_sampler)
335-
@dynamo_tensorrt_converter(torch.ops.aten.grid_sampler_2d)
336-
@dynamo_tensorrt_converter(torch.ops.aten.grid_sampler.default)
337-
@dynamo_tensorrt_converter(torch.ops.aten.grid_sampler_2d.default)
334+
@dynamo_tensorrt_converter(torch.ops.aten.grid_sampler, supports_dynamic_shapes=True)
335+
@dynamo_tensorrt_converter(torch.ops.aten.grid_sampler_2d, supports_dynamic_shapes=True)
336+
@dynamo_tensorrt_converter(
337+
torch.ops.aten.grid_sampler.default, supports_dynamic_shapes=True
338+
)
339+
@dynamo_tensorrt_converter(
340+
torch.ops.aten.grid_sampler_2d.default, supports_dynamic_shapes=True
341+
)
338342
@enforce_tensor_types(
339343
{
340344
0: (TRTTensor,),
@@ -922,7 +926,7 @@ def aten_ops_chunk(
922926
)
923927

924928

925-
@dynamo_tensorrt_converter(torch.ops.aten.cumsum.default)
929+
@dynamo_tensorrt_converter(torch.ops.aten.cumsum.default, supports_dynamic_shapes=True)
926930
@enforce_tensor_types(
927931
{
928932
0: (TRTTensor,),

py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py

Lines changed: 34 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -387,17 +387,46 @@ def cumsum(
387387
input: TRTTensor,
388388
dim: int,
389389
) -> TRTTensor:
390+
390391
input_shape = input.shape
391392
dim = get_positive_dim(dim, len(input_shape))
393+
if input_shape[dim] < 0:
394+
trip_limit = impl.shape.shape(
395+
ctx, target, source_ir, name + "_shape", input, dim
396+
)
397+
# the trip_limit has to be a 0D shape tensor, but impl.shape.shape produces a 1D shape tensor
398+
# for example, if the trip limit is 3, the loop expects tensor(3), not tensor([3])
399+
# to reduce it from 1D to 0D, sum over axis 0 with keepdim=False via impl.reduce.sum
400+
trip_limit = impl.reduce.sum(
401+
ctx, target, source_ir, name, trip_limit, 0, keepdim=False
402+
)
403+
else:
404+
axis = np.array(input_shape[dim])
405+
trip_limit = get_trt_tensor(ctx, axis, f"{name}_trip_limit")
406+
392407
loop = ctx.net.add_loop()
393-
axis = np.array(input_shape[dim])
394-
trip_limit = get_trt_tensor(ctx, axis, f"{name}_trip_limit")
395408
loop.add_trip_limit(trip_limit, trt.TripLimit.COUNT)
396409
iterator = loop.add_iterator(input, dim, reverse=False)
397410
data = iterator.get_output(0)
398-
new_dims = tuple(data.shape)
399-
zeros = np.zeros(new_dims)
400-
zero_trttensor = get_trt_tensor(ctx, zeros, f"{name}_initial_value")
411+
if has_dynamic_shape(data.shape):
412+
data_shape = []
413+
for i in range(len(input_shape)):
414+
if i != dim:
415+
if input_shape[i] < 0:
416+
data_shape.append(
417+
impl.shape.shape(
418+
ctx, target, source_ir, name + f"_{i}_shape", input, i
419+
)
420+
)
421+
else:
422+
data_shape.append(input_shape[i])
423+
zero_trttensor = impl.full.full(
424+
ctx, target, source_ir, name + "_full", data_shape, 0.0
425+
)
426+
else:
427+
new_dims = tuple(data.shape)
428+
zeros = np.zeros(new_dims)
429+
zero_trttensor = get_trt_tensor(ctx, zeros, f"{name}_initial_value")
401430

402431
running_sum = loop.add_recurrence(zero_trttensor)
403432
set_layer_name(running_sum, target, f"{name}_running_sum", source_ir)

tests/py/dynamo/conversion/test_cumsum_aten.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import torch
22
import torch.nn as nn
3+
import torch_tensorrt
34
from parameterized import parameterized
45
from torch.testing._internal.common_utils import run_tests
56

@@ -46,7 +47,7 @@ def forward(self, x):
4647

4748
@parameterized.expand(
4849
[
49-
((4, 2, 3), 0),
50+
((2, 3, 3), 0),
5051
((4, 2, 3), 1),
5152
((1, 2, 3), 2),
5253
((1, 2, 3), -1),
@@ -64,6 +65,35 @@ def forward(self, x):
6465
inputs,
6566
)
6667

68+
@parameterized.expand(
    [
        ((1,), (2,), (3,), 0),
        ((1,), (2,), (3,), -1),
        ((2, 3), (2, 4), (2, 5), 0),
        ((2, 3), (3, 4), (4, 5), -1),
        ((1, 2, 2), (2, 2, 3), (3, 3, 3), 0),
        ((1, 2, 2), (2, 2, 3), (3, 2, 3), -2),
        ((1, 2, 2, 3), (2, 2, 3, 4), (3, 3, 4, 5), -3),
        ((1, 2, 2, 3), (2, 2, 3, 4), (3, 3, 4, 5), -2),
    ]
)
def test_cumsum_dynamic_shape(self, min_shape, opt_shape, max_shape, dims):
    """Verify that aten.cumsum.default converts correctly when the input
    tensor has a dynamic shape range (min/opt/max), for several ranks and
    both positive and negative `dims` values."""

    class CumsumModule(nn.Module):
        def forward(self, x):
            return torch.ops.aten.cumsum.default(x, dims)

    # A single dynamic-range input spanning min -> opt -> max shapes.
    dynamic_input = torch_tensorrt.Input(
        min_shape=min_shape,
        opt_shape=opt_shape,
        max_shape=max_shape,
    )
    self.run_test_with_dynamic_shape(CumsumModule(), [dynamic_input])
96+
6797

6898
if __name__ == "__main__":
6999
run_tests()

tests/py/dynamo/conversion/test_grid_aten.py

Lines changed: 135 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
import pytest
22
import torch
33
import torch.nn as nn
4-
from .harness import DispatchTestCase
4+
import torch_tensorrt
55
from parameterized import parameterized
66
from torch.testing._internal.common_utils import run_tests
7-
from torch_tensorrt import Input
7+
8+
from .harness import DispatchTestCase
89

910
grid_sampler_aten_ops = {
1011
"torch.ops.aten.grid_sampler": torch.ops.aten.grid_sampler,
@@ -185,6 +186,138 @@ def forward(self, x):
185186
grid_model = TestModule(op)
186187
self.run_test(grid_model, inputs)
187188

189+
@parameterized.expand(
    [
        (
            (1, 1, 2, 2),
            (2, 2, 3, 3),
            (3, 3, 5, 5),
            (1, 2, 2, 2),
            (2, 3, 3, 2),
            (3, 5, 5, 2),
            0,
            0,
            True,
        ),
        (
            (1, 1, 2, 2),
            (2, 2, 3, 3),
            (3, 3, 5, 5),
            (1, 2, 2, 2),
            (2, 3, 3, 2),
            (3, 5, 5, 2),
            0,
            2,
            True,
        ),
        (
            (1, 1, 2, 2),
            (1, 1, 3, 3),
            (1, 1, 5, 5),
            (1, 3, 3, 2),
            (1, 4, 4, 2),
            (1, 5, 5, 2),
            0,
            1,
            True,
        ),
        (
            (1, 1, 2, 2),
            (2, 2, 3, 3),
            (3, 3, 5, 5),
            (1, 4, 2, 2),
            (2, 4, 3, 2),
            (3, 4, 5, 2),
            1,
            0,
            True,
        ),
        (
            (1, 1, 2, 2),
            (2, 2, 3, 3),
            (3, 3, 5, 5),
            (1, 4, 2, 2),
            (2, 5, 3, 2),
            (3, 5, 5, 2),
            1,
            1,
            False,
        ),
    ]
)
def test_grid_2d_default_dynamic_shape(
    self,
    input_min_shape,
    input_opt_shape,
    input_max_shape,
    grid_min_shape,
    grid_opt_shape,
    grid_max_shape,
    interpolation_mode,
    padding_mode,
    align_corners,
):
    """Verify all four grid_sampler op variants convert under dynamic shapes.

    Each variant (grid_sampler, grid_sampler.default, grid_sampler_2d,
    grid_sampler_2d.default) is wrapped in its own module and run against the
    same pair of dynamic-range inputs: the sampled tensor and the sampling
    grid, each described by min/opt/max shapes.
    """

    class Grid_SAMPLER_2D(nn.Module):
        def forward(self, input, grid):
            return torch.ops.aten.grid_sampler_2d(
                input, grid, interpolation_mode, padding_mode, align_corners
            )

    class Grid_SAMPLER_2D_default(nn.Module):
        def forward(self, input, grid):
            return torch.ops.aten.grid_sampler_2d.default(
                input, grid, interpolation_mode, padding_mode, align_corners
            )

    class Grid_SAMPLER(nn.Module):
        def forward(self, input, grid):
            return torch.ops.aten.grid_sampler(
                input, grid, interpolation_mode, padding_mode, align_corners
            )

    class Grid_SAMPLER_default(nn.Module):
        def forward(self, input, grid):
            return torch.ops.aten.grid_sampler.default(
                input, grid, interpolation_mode, padding_mode, align_corners
            )

    inputs = [
        torch_tensorrt.Input(
            min_shape=input_min_shape,
            opt_shape=input_opt_shape,
            max_shape=input_max_shape,
            dtype=torch.float32,
            # BUG FIX: this kwarg was misspelled `torch_tensorrt=`; the grid
            # input below shows the correct keyword is `torch_tensor=`.
            torch_tensor=torch.randn(input_opt_shape, dtype=torch.float32),
        ),
        torch_tensorrt.Input(
            min_shape=grid_min_shape,
            opt_shape=grid_opt_shape,
            max_shape=grid_max_shape,
            dtype=torch.float32,
            # Grid values are sampled from {-1.0, 0.0}, i.e. within the
            # normalized [-1, 1] sampling coordinate range.
            torch_tensor=torch.randint(-1, 1, grid_opt_shape, dtype=torch.float32),
        ),
    ]
    self.run_test_with_dynamic_shape(
        Grid_SAMPLER_2D(),
        inputs,
        use_example_tensors=False,
    )
    self.run_test_with_dynamic_shape(
        Grid_SAMPLER_2D_default(),
        inputs,
        use_example_tensors=False,
    )
    self.run_test_with_dynamic_shape(
        Grid_SAMPLER(),
        inputs,
        use_example_tensors=False,
    )
    self.run_test_with_dynamic_shape(
        Grid_SAMPLER_default(),
        inputs,
        use_example_tensors=False,
    )
320+
188321

189322
if __name__ == "__main__":
190323
run_tests()

0 commit comments

Comments
 (0)