
Commit 20620a9

feat: support aten.scalar_tensor dynamo converter

1 parent b8403b8 · commit 20620a9

File tree

3 files changed: +101 -1 lines changed

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

17 additions & 0 deletions

@@ -2626,3 +2626,20 @@ def aten_ops_pdist(
         args[0],
         args_bounds_check(args, 1, 2),
     )
+
+
+@dynamo_tensorrt_converter(torch.ops.aten.scalar_tensor.default)
+def aten_ops_scalar_tensor(
+    ctx: ConversionContext,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.unary.scalar_tensor(
+        ctx,
+        target,
+        SourceIR.ATEN,
+        name,
+        args[0],
+    )
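
For context, a minimal sketch (not part of the commit) of exercising the converter registered above: a module whose lowered graph contains aten.scalar_tensor.default, compiled through the dynamo path. The module, input shape, and the need for a CUDA device with TensorRT installed are assumptions, and constant folding in the lowering pipeline may simplify the graph before conversion.

import torch
import torch_tensorrt

class AddScalar(torch.nn.Module):
    def forward(self, x):
        # Materializes 2.0 as a 0-d tensor node (aten.scalar_tensor.default)
        # in the traced graph, which the new converter handles.
        return x + torch.ops.aten.scalar_tensor.default(2.0)

model = AddScalar().eval().cuda()
inputs = [torch.randn(2, 3).cuda()]
# ir="dynamo" routes compilation through the dynamo converter registry.
trt_model = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs)
print(trt_model(*inputs))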

py/torch_tensorrt/dynamo/conversion/impl/unary/ops.py

15 additions & 1 deletion

@@ -1,4 +1,4 @@
-from typing import Optional
+from typing import Optional, Union
 
 import tensorrt as trt
 import torch_tensorrt.dynamo.conversion.impl as impl
@@ -10,6 +10,7 @@
     get_trt_tensor,
 )
 from torch_tensorrt.dynamo.conversion.impl.unary.base import convert_unary
+from torch_tensorrt.fx.converters.converter_utils import set_layer_name
 from torch_tensorrt.fx.types import TRTTensor
 
 
@@ -459,3 +460,16 @@ def trunc(
     return impl.elementwise.trunc_div(
         ctx, target, source_ir, f"{name}_trunc", input_val, dividend
     )
+
+
+def scalar_tensor(
+    ctx: ConversionContext,
+    target: Target,
+    source_ir: Optional[SourceIR],
+    name: str,
+    scalar: Union[int, float, bool],
+) -> TRTTensor:
+    tensor = get_trt_tensor(ctx, scalar, f"{name}_scalar_tensor")
+    identity_layer = ctx.net.add_identity(tensor)
+    set_layer_name(identity_layer, target, name, source_ir)
+    return identity_layer.get_output(0)
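
The helper above leans on get_trt_tensor to embed the Python scalar as a TensorRT constant, then wraps it in an identity layer so the op maps to a layer that set_layer_name can label. A rough standalone sketch with raw TensorRT APIs (the constant's shape, dtype, and the layer name here are simplified assumptions, not the helpers' exact behavior):

import numpy as np
import tensorrt as trt

logger = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(logger)
network = builder.create_network(
    1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
)

scalar = 2.99
# get_trt_tensor's role: turn the Python scalar into a constant ITensor.
const = network.add_constant((1,), np.array([scalar], dtype=np.float32))
# The identity layer gives the op its own nameable layer and output.
identity = network.add_identity(const.get_output(0))
identity.name = "aten_ops_scalar_tensor"  # roughly what set_layer_name provides
network.mark_output(identity.get_output(0))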
69 additions & 0 deletions

@@ -0,0 +1,69 @@
+import torch
+import torch.nn as nn
+from parameterized import parameterized
+from torch.testing._internal.common_utils import run_tests
+
+from .harness import DispatchTestCase
+
+
+class TestScalarTensorConverter(DispatchTestCase):
+    @parameterized.expand(
+        [
+            (-2.00001,),
+            (-1.3,),
+            (-0.0,),
+            (1.0,),
+            (2.99,),
+        ]
+    )
+    def test_scalar_tensor_float(self, scalar):
+        class ScalarTensor(nn.Module):
+            def forward(self):
+                return torch.ops.aten.scalar_tensor.default(scalar)
+
+        inputs = []
+        self.run_test(
+            ScalarTensor(),
+            inputs,
+        )
+
+    @parameterized.expand(
+        [
+            (-9999,),
+            (-1,),
+            (0,),
+            (2,),
+            (99999,),
+        ]
+    )
+    def test_scalar_tensor_int(self, scalar):
+        class ScalarTensor(nn.Module):
+            def forward(self):
+                return torch.ops.aten.scalar_tensor.default(scalar)
+
+        inputs = []
+        self.run_test(
+            ScalarTensor(),
+            inputs,
+        )
+
+    @parameterized.expand(
+        [
+            (True,),
+            (False,),
+        ]
+    )
+    def test_scalar_tensor_bool(self, scalar):
+        class ScalarTensor(nn.Module):
+            def forward(self):
+                return torch.ops.aten.scalar_tensor.default(scalar)
+
+        inputs = []
+        self.run_test(
+            ScalarTensor(),
+            inputs,
+        )
+
+
+if __name__ == "__main__":
+    run_tests()
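
Each test builds a no-input module and lets DispatchTestCase.run_test compare the TensorRT engine's output against eager PyTorch. The eager reference the cases assert against is simply aten.scalar_tensor producing a zero-dimensional tensor; a quick standalone sanity check (not part of the test harness):

import torch

for scalar in (-2.00001, -1.3, -0.0, 1.0, 2.99, -9999, 0, 2, True, False):
    out = torch.ops.aten.scalar_tensor.default(scalar)
    # scalar_tensor always yields a 0-d tensor; the printed dtype shows
    # the op's default dtype handling for each scalar kind.
    assert out.shape == torch.Size([])
    print(f"{scalar!r} -> {out} ({out.dtype})")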
