
Commit 077e46d

feat: dynamic shape support for pow/mod/eq operator (#2982)
1 parent: 853556e

4 files changed: +175 −2 lines


py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 9 additions & 2 deletions
@@ -2002,6 +2002,7 @@ def aten_ops_div(
 @dynamo_tensorrt_converter(
     torch.ops.aten.pow.Tensor_Scalar, supports_dynamic_shapes=True
 )
+@dynamo_tensorrt_converter(operator.pow, supports_dynamic_shapes=True)
 def aten_ops_pow(
     ctx: ConversionContext,
     target: Target,
@@ -2278,6 +2279,7 @@ def aten_ops_bitwise_not(
 
 @dynamo_tensorrt_converter(torch.ops.aten.eq.Tensor, supports_dynamic_shapes=True)
 @dynamo_tensorrt_converter(torch.ops.aten.eq.Scalar, supports_dynamic_shapes=True)
+@dynamo_tensorrt_converter(operator.eq, supports_dynamic_shapes=True)
 @enforce_tensor_types(
     {
         0: (TRTTensor,),
@@ -3149,8 +3151,13 @@ def aten_ops_copy(
 )
 
 
-@dynamo_tensorrt_converter(torch.ops.aten.remainder.Scalar)
-@dynamo_tensorrt_converter(torch.ops.aten.remainder.Tensor)
+@dynamo_tensorrt_converter(
+    torch.ops.aten.remainder.Scalar, supports_dynamic_shapes=True
+)
+@dynamo_tensorrt_converter(
+    torch.ops.aten.remainder.Tensor, supports_dynamic_shapes=True
+)
+@dynamo_tensorrt_converter(operator.mod, supports_dynamic_shapes=True)
 @enforce_tensor_types(
     {
         0: (TRTTensor,),
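
Note (not part of the diff): the `operator.pow`, `operator.eq`, and `operator.mod` registrations cover graphs in which Dynamo records the plain Python operators (`**`, `==`, `%`) instead of their `aten` overloads. Below is a minimal sketch of compiling such a model with a dynamic batch dimension; the module, shape ranges, and call site are illustrative assumptions rather than part of this commit, and the exact `torch_tensorrt.compile` options may differ across releases.

import torch
import torch.nn as nn
import torch_tensorrt

class PowModEq(nn.Module):
    def forward(self, x, y):
        # `**`, `%`, and `==` can appear in the traced graph as operator.pow /
        # operator.mod / operator.eq, which the registrations above now mark
        # as dynamic-shape capable.
        return x ** 2.0, x % y, x == y

model = PowModEq().eval().cuda()

# One dynamic-shape spec per input: any batch size between min_shape and
# max_shape is valid at runtime; opt_shape is the engine's optimization point.
specs = [
    torch_tensorrt.Input(
        min_shape=(1, 20), opt_shape=(4, 20), max_shape=(8, 20), dtype=torch.float32
    )
    for _ in range(2)
]

trt_model = torch_tensorrt.compile(model, ir="dynamo", inputs=specs)
print(trt_model(torch.randn(4, 20).cuda(), torch.randn(4, 20).cuda()))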

tests/py/dynamo/conversion/test_eq_aten.py

Lines changed: 48 additions & 0 deletions
@@ -146,6 +146,54 @@ def forward(self, lhs_val):
             input_specs,
         )
 
+    @parameterized.expand(
+        [
+            ((1,), (3,), (5,)),
+            ((1, 20), (2, 20), (3, 20)),
+            ((2, 3, 4), (3, 4, 5), (4, 5, 6)),
+            ((2, 3, 4, 5), (3, 5, 5, 6), (4, 5, 6, 7)),
+        ]
+    )
+    def test_eq_operator_dynamic_shape(self, min_shape, opt_shape, max_shape):
+        class eq_tensor_operator(nn.Module):
+            def forward(self, lhs_val, rhs_val):
+                return lhs_val == rhs_val
+
+        class eq_tensor_scalar_operator(nn.Module):
+            def forward(self, lhs_val, rhs_val):
+                return lhs_val == torch.tensor(1)
+
+        class eq_scalar_operator(nn.Module):
+            def forward(self, lhs_val, rhs_val):
+                return lhs_val == 1.0
+
+        input_specs = [
+            Input(
+                dtype=torch.float32,
+                min_shape=min_shape,
+                opt_shape=opt_shape,
+                max_shape=max_shape,
+            ),
+            Input(
+                dtype=torch.float32,
+                min_shape=min_shape,
+                opt_shape=opt_shape,
+                max_shape=max_shape,
+            ),
+        ]
+        self.run_test_with_dynamic_shape(
+            eq_tensor_operator(),
+            input_specs,
+        )
+        self.run_test_with_dynamic_shape(
+            eq_tensor_scalar_operator(),
+            input_specs,
+        )
+        self.run_test_with_dynamic_shape(
+            eq_scalar_operator(),
+            input_specs,
+        )
+
 
 if __name__ == "__main__":
     run_tests()
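
Side note on the three module variants above: they exercise different graph-level spellings of equality. A rough way to inspect which op a given form records is sketched below; this is illustrative only, and the exact node names depend on the PyTorch version.

import torch
import torch.nn as nn

class EqScalar(nn.Module):
    def forward(self, lhs_val):
        return lhs_val == 1.0

# torch.export captures the aten-level graph: `lhs_val == 1.0` typically
# appears as torch.ops.aten.eq.Scalar and tensor-to-tensor comparison as
# torch.ops.aten.eq.Tensor, while the operator.eq registration in this
# commit covers graphs where the raw Python operator is kept.
ep = torch.export.export(EqScalar(), (torch.randn(2, 20),))
print(ep.graph)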

tests/py/dynamo/conversion/test_pow_aten.py

Lines changed: 59 additions & 0 deletions
@@ -59,6 +59,65 @@ def forward(self, lhs_val):
             inputs,
         )
 
+    @parameterized.expand(
+        [
+            (
+                "2d_dim_dtype_half",
+                (1, 1),
+                (2, 2),
+                (4, 4),
+                torch.half,
+                torch.half,
+            ),
+            (
+                "3d_dim_dtype_float",
+                (1, 1, 1),
+                (1, 2, 3),
+                (3, 3, 3),
+                torch.float,
+                torch.float,
+            ),
+        ]
+    )
+    def test_pow_dynamic_shape(
+        self, _, min_shape, opt_shape, max_shape, type, output_type
+    ):
+        class pow(nn.Module):
+            def forward(self, lhs_val, rhs_val):
+                return torch.ops.aten.pow.Tensor_Tensor(lhs_val, rhs_val)
+
+        class pow_scalar(nn.Module):
+            def forward(self, lhs_val, rhs_val):
+                return torch.ops.aten.pow.Tensor_Scalar(lhs_val, 2.0)
+
+        class pow_operator(nn.Module):
+            def forward(self, lhs_val, rhs_val):
+                return lhs_val**rhs_val
+
+        input_specs = [
+            Input(
+                min_shape=min_shape,
+                opt_shape=opt_shape,
+                max_shape=max_shape,
+                dtype=type,
+            ),
+            Input(
+                min_shape=min_shape,
+                opt_shape=opt_shape,
+                max_shape=max_shape,
+                dtype=type,
+            ),
+        ]
+        self.run_test_with_dynamic_shape(
+            pow(), input_specs, output_dtypes=[output_type]
+        )
+        self.run_test_with_dynamic_shape(
+            pow_scalar(), input_specs, output_dtypes=[output_type]
+        )
+        self.run_test_with_dynamic_shape(
+            pow_operator(), input_specs, output_dtypes=[output_type]
+        )
+
 
 if __name__ == "__main__":
     run_tests()
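
The pow cases follow the same structure as the eq ones: run_test_with_dynamic_shape builds an engine over the min/opt/max range and compares against PyTorch. The sketch below is a rough end-to-end sanity check outside the test harness; the module, shapes, and tolerances are hypothetical.

import torch
import torch.nn as nn
import torch_tensorrt

class Pow(nn.Module):
    def forward(self, lhs_val, rhs_val):
        # `**` on two tensors goes through the pow converters registered above.
        return lhs_val ** rhs_val

model = Pow().eval().cuda()
specs = [
    torch_tensorrt.Input(
        min_shape=(1, 1), opt_shape=(2, 2), max_shape=(4, 4), dtype=torch.float32
    )
    for _ in range(2)
]
trt_model = torch_tensorrt.compile(model, ir="dynamo", inputs=specs)

# Any shape inside the [min_shape, max_shape] range should be accepted at runtime.
for n in (1, 2, 4):
    a, b = torch.rand(n, n).cuda(), torch.rand(n, n).cuda()
    assert torch.allclose(trt_model(a, b), a ** b, atol=1e-3, rtol=1e-3)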

tests/py/dynamo/conversion/test_remainder_aten.py

Lines changed: 59 additions & 0 deletions
@@ -55,6 +55,65 @@ def forward(self, lhs_val, rhs_val):
             inputs,
         )
 
+    @parameterized.expand(
+        [
+            (
+                "2d_dim_dtype_half",
+                (1, 1),
+                (2, 2),
+                (4, 4),
+                torch.half,
+                torch.half,
+            ),
+            (
+                "3d_dim_dtype_float",
+                (1, 1, 1),
+                (1, 2, 3),
+                (3, 3, 3),
+                torch.float,
+                torch.float,
+            ),
+        ]
+    )
+    def test_remainder_dynamic_shape(
+        self, _, min_shape, opt_shape, max_shape, type, output_type
+    ):
+        class remainder(nn.Module):
+            def forward(self, lhs_val, rhs_val):
+                return torch.ops.aten.remainder.Tensor(lhs_val, rhs_val)
+
+        class remainder_scalar(nn.Module):
+            def forward(self, lhs_val, rhs_val):
+                return torch.ops.aten.remainder.Scalar(lhs_val, 2)
+
+        class mod_operator(nn.Module):
+            def forward(self, lhs_val, rhs_val):
+                return lhs_val % rhs_val
+
+        input_specs = [
+            Input(
+                min_shape=min_shape,
+                opt_shape=opt_shape,
+                max_shape=max_shape,
+                dtype=type,
+            ),
+            Input(
+                min_shape=min_shape,
+                opt_shape=opt_shape,
+                max_shape=max_shape,
+                dtype=type,
+            ),
+        ]
+        self.run_test_with_dynamic_shape(
+            remainder(), input_specs, output_dtypes=[output_type]
+        )
+        self.run_test_with_dynamic_shape(
+            remainder_scalar(), input_specs, output_dtypes=[output_type]
+        )
+        self.run_test_with_dynamic_shape(
+            mod_operator(), input_specs, output_dtypes=[output_type]
+        )
+
 
 if __name__ == "__main__":
     run_tests()
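
One semantic detail these remainder tests rely on: `%` on tensors and `torch.ops.aten.remainder` both follow Python's modulo convention (the result takes the sign of the divisor), which is the behavior the converter has to reproduce. A quick eager-mode check:

import torch

x = torch.tensor([-7.0, -1.0, 5.0])
y = torch.tensor([3.0, 3.0, -3.0])

# Both spellings follow Python's % convention: the result has the divisor's sign.
print(x % y)                                  # tensor([ 2.,  2., -1.])
print(torch.ops.aten.remainder.Tensor(x, y))  # tensor([ 2.,  2., -1.])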
