Commit fdaba9a
feat: Support aten.dot dynamo converter (#3043)
1 parent 8ecc809 commit fdaba9a
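This commit registers torch.ops.aten.dot with the existing matmul dynamo converter instead of decomposing it during lowering. As a quick sanity check of the identity this relies on, in plain PyTorch (nothing TensorRT-specific assumed): aten.dot is the inner product of two 1-D tensors and agrees with matmul on 1-D inputs.

import torch

a = torch.randn(3)
b = torch.randn(3)

# aten.dot is the 1-D inner product; for two 1-D tensors it coincides with
# matmul, which is why the matmul converter can also serve aten.dot.
assert torch.allclose(torch.ops.aten.dot.default(a, b), torch.matmul(a, b))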

File tree

4 files changed (+44 −7 lines changed):

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
py/torch_tensorrt/dynamo/conversion/impl/matmul.py
py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py
tests/py/dynamo/conversion/test_matmul_aten.py

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 1 addition & 0 deletions

@@ -548,6 +548,7 @@ def aten_ops_hard_sigmoid(
 
 
 @dynamo_tensorrt_converter(torch.ops.aten.matmul, supports_dynamic_shapes=True)
+@dynamo_tensorrt_converter(torch.ops.aten.dot.default, supports_dynamic_shapes=True)
 @dynamo_tensorrt_converter(torch.ops.aten.mm.default, supports_dynamic_shapes=True)
 @dynamo_tensorrt_converter(torch.ops.aten.mv.default, supports_dynamic_shapes=True)
 @dynamo_tensorrt_converter(torch.ops.aten.bmm.default, supports_dynamic_shapes=True)
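All five decorators stack onto one shared converter entry point. That function's body sits outside this diff, so the following is a hedged sketch of what it plausibly looks like, following the dynamo converter signature convention; the routing through impl.matmul.matrix_multiply is an assumption, confirmed only indirectly by the matmul.py changes below.

from typing import Dict, Sequence, Tuple, Union

from torch.fx.node import Argument, Target
from torch_tensorrt.dynamo._SourceIR import SourceIR
from torch_tensorrt.dynamo.conversion import impl
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
from torch_tensorrt.dynamo.types import TRTTensor


# Sketch of the shared converter the stacked decorators register (assumed):
def aten_ops_matmul(
    ctx: ConversionContext,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    # aten.matmul, aten.dot, aten.mm, aten.mv and aten.bmm all lower through
    # the same TensorRT matrix-multiply implementation.
    return impl.matmul.matrix_multiply(
        ctx, target, SourceIR.ATEN, name, args[0], args[1]
    )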

py/torch_tensorrt/dynamo/conversion/impl/matmul.py

Lines changed: 8 additions & 6 deletions

@@ -1,15 +1,17 @@
 from typing import Optional
 
+import tensorrt as trt
 import torch
 from torch.fx.node import Target
 from torch_tensorrt import _enums
 from torch_tensorrt.dynamo._SourceIR import SourceIR
 from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
-from torch_tensorrt.dynamo.conversion.converter_utils import get_trt_tensor
-from torch_tensorrt.fx.converters.converter_utils import broadcast, set_layer_name
-from torch_tensorrt.fx.types import TRTTensor
-
-import tensorrt as trt
+from torch_tensorrt.dynamo.conversion.converter_utils import (
+    broadcast,
+    get_trt_tensor,
+    set_layer_name,
+)
+from torch_tensorrt.dynamo.types import TRTTensor
 
 
 def matrix_multiply(
@@ -43,7 +45,7 @@ def matrix_multiply(
         other_matrix_op = trt.MatrixOperation.VECTOR
 
     input, other = broadcast(
-        ctx.net, input, other, f"{name}_input", f"{name}_other", preset_diff
+        ctx, input, other, f"{name}_input", f"{name}_other", preset_diff
     )
     layer = ctx.net.add_matrix_multiply(input, input_matrix_op, other, other_matrix_op)
     set_layer_name(layer, target, name, source_ir)
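The context line other_matrix_op = trt.MatrixOperation.VECTOR is what makes 1-D operands, and hence aten.dot, legal here: rank-1 tensors are flagged as vectors for TensorRT's matrix-multiply layer. A minimal sketch of that rank dispatch, with the helper name pick_matrix_ops invented purely for illustration:

import tensorrt as trt


def pick_matrix_ops(input_rank: int, other_rank: int):
    # Hypothetical helper: rank-1 operands become VECTOR operands, so
    # dot(a, b) lowers to a vector-vector matrix multiply whose result
    # is a scalar.
    input_op = trt.MatrixOperation.VECTOR if input_rank == 1 else trt.MatrixOperation.NONE
    other_op = trt.MatrixOperation.VECTOR if other_rank == 1 else trt.MatrixOperation.NONE
    return input_op, other_op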

py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py

Lines changed: 0 additions & 1 deletion

@@ -33,7 +33,6 @@
     aten.detach,
     aten.diag_embed,
     aten.diagonal_backward,
-    aten.dot,
     aten.elu_backward,
     aten.embedding_dense_backward,
     aten.empty_like,
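Removing aten.dot from this group means lowering no longer expands the op, so the graph keeps a single aten.dot node for the new converter to claim. For reference, the equivalence the dropped decomposition relied on is an elementwise multiply followed by a sum (a property of dot itself, not a line from this diff):

import torch

a, b = torch.randn(4), torch.randn(4)

# dot(a, b) == sum(a * b): roughly what the decomposition expanded to.
assert torch.allclose(torch.ops.aten.dot.default(a, b), (a * b).sum())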

tests/py/dynamo/conversion/test_matmul_aten.py

Lines changed: 35 additions & 0 deletions

@@ -8,6 +8,41 @@
 
 
 class TestMatMulConverter(DispatchTestCase):
+    @parameterized.expand(
+        [
+            (
+                "1_1",
+                (1,),
+                (1,),
+            ),
+            (
+                "1_1",
+                (2,),
+                (2,),
+            ),
+            (
+                "1_1",
+                (3,),
+                (3,),
+            ),
+        ]
+    )
+    def test_matmul_dot(self, _, input_shape, other_shape):
+        class MatMul(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.other = nn.Parameter(torch.randn(*other_shape))
+
+            def forward(self, input):
+                return torch.ops.aten.dot.default(input, self.other)
+
+        inputs = [torch.randn(*input_shape)]
+
+        self.run_test(
+            MatMul(),
+            inputs,
+        )
+
     @parameterized.expand(
         [
             (
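Beyond the unit test, the new path can be smoke-tested end to end. The following is a hypothetical check, assuming a CUDA machine with torch_tensorrt installed; the Dot module, tolerances, and min_block_size choice are illustrative, not part of the commit:

import torch
import torch_tensorrt


class Dot(torch.nn.Module):
    def forward(self, a, b):
        return torch.ops.aten.dot.default(a, b)


model = Dot().eval().cuda()
inputs = [torch.randn(3).cuda(), torch.randn(3).cuda()]

# min_block_size=1 keeps the single-op graph from falling back to eager PyTorch.
trt_model = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs, min_block_size=1)
torch.testing.assert_close(trt_model(*inputs), model(*inputs), rtol=1e-3, atol=1e-3)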
