Skip to content

feat: support aten.scalar_tensor dynamo converter #2595

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jan 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -2649,3 +2649,16 @@ def aten_ops_flip(
args[0],
args[1],
)


@dynamo_tensorrt_converter(torch.ops.aten.scalar_tensor.default)
def aten_ops_scalar_tensor(
ctx: ConversionContext,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
return impl.unary.scalar_tensor(
ctx, target, SourceIR.ATEN, name, args[0], dtype=kwargs.get("dtype")
)
21 changes: 19 additions & 2 deletions py/torch_tensorrt/dynamo/conversion/impl/unary/ops.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from typing import Optional
from typing import Optional, Union

import numpy as np
import tensorrt as trt
import torch
import torch_tensorrt.dynamo.conversion.impl as impl
from torch.fx.node import Target
from torch_tensorrt.dynamo._SourceIR import SourceIR
Expand All @@ -10,7 +12,8 @@
get_trt_tensor,
)
from torch_tensorrt.dynamo.conversion.impl.unary.base import convert_unary
from torch_tensorrt.fx.types import TRTTensor
from torch_tensorrt.fx.converters.converter_utils import set_layer_name
from torch_tensorrt.fx.types import TRTDataType, TRTTensor


def exp(
Expand Down Expand Up @@ -459,3 +462,17 @@ def trunc(
return impl.elementwise.trunc_div(
ctx, target, source_ir, f"{name}_trunc", input_val, dividend
)


def scalar_tensor(
ctx: ConversionContext,
target: Target,
source_ir: Optional[SourceIR],
name: str,
scalar: Union[int, float, bool],
dtype: Optional[Union[torch.dtype, np.dtype, TRTDataType]] = None,
) -> TRTTensor:
tensor = get_trt_tensor(ctx, scalar, f"{name}_scalar_tensor", dtype)
identity_layer = ctx.net.add_identity(tensor)
set_layer_name(identity_layer, target, name, source_ir)
return identity_layer.get_output(0)
95 changes: 95 additions & 0 deletions tests/py/dynamo/conversion/test_scalar_tensor_aten.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
import torch
import torch.nn as nn
from parameterized import parameterized
from torch.testing._internal.common_utils import run_tests

from .harness import DispatchTestCase


class TestScalarTensorConverter(DispatchTestCase):
@parameterized.expand(
[
(-2.00001,),
(-1.3,),
(-0.0,),
(1.0,),
(2.99,),
]
)
def test_scalar_tensor_float(self, scalar):
class ScalarTensor(nn.Module):
def forward(self):
return torch.ops.aten.scalar_tensor.default(scalar)

inputs = []
self.run_test(
ScalarTensor(),
inputs,
)

@parameterized.expand(
[
(-9999,),
(-1,),
(0,),
(2,),
(99999,),
]
)
def test_scalar_tensor_int(self, scalar):
class ScalarTensor(nn.Module):
def forward(self):
return torch.ops.aten.scalar_tensor.default(scalar)

inputs = []
self.run_test(
ScalarTensor(),
inputs,
)

@parameterized.expand(
[
(True,),
(False,),
]
)
def test_scalar_tensor_bool(self, scalar):
class ScalarTensor(nn.Module):
def forward(self):
return torch.ops.aten.scalar_tensor.default(scalar)

inputs = []
self.run_test(
ScalarTensor(),
inputs,
)

@parameterized.expand(
[
(-9999, torch.int),
(-2.00001, torch.float),
(-1, torch.float),
(0, torch.int),
(-0.0, torch.float),
(1.0, torch.int),
(2.99, torch.float),
(9999999, None),
(9999999.99999, None),
(True, torch.bool),
]
)
def test_scalar_tensor_dtype(self, scalar, dtype):
class ScalarTensor(nn.Module):
def forward(self):
return torch.ops.aten.scalar_tensor.default(scalar, dtype=dtype)

inputs = []
self.run_test(
ScalarTensor(),
inputs,
output_dtypes=None if dtype is None else [dtype],
)


if __name__ == "__main__":
run_tests()