diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
index 7f472261db..4a7c780a6b 100644
--- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
+++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -2649,3 +2649,16 @@ def aten_ops_flip(
         args[0],
         args[1],
     )
+
+
+@dynamo_tensorrt_converter(torch.ops.aten.scalar_tensor.default)
+def aten_ops_scalar_tensor(
+    ctx: ConversionContext,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.unary.scalar_tensor(
+        ctx, target, SourceIR.ATEN, name, args[0], dtype=kwargs.get("dtype")
+    )
diff --git a/py/torch_tensorrt/dynamo/conversion/impl/unary/ops.py b/py/torch_tensorrt/dynamo/conversion/impl/unary/ops.py
index 9ed5d0636d..fc6c737e79 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/unary/ops.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/unary/ops.py
@@ -1,6 +1,8 @@
-from typing import Optional
+from typing import Optional, Union
 
+import numpy as np
 import tensorrt as trt
+import torch
 import torch_tensorrt.dynamo.conversion.impl as impl
 from torch.fx.node import Target
 from torch_tensorrt.dynamo._SourceIR import SourceIR
@@ -10,7 +12,8 @@
     get_trt_tensor,
 )
 from torch_tensorrt.dynamo.conversion.impl.unary.base import convert_unary
-from torch_tensorrt.fx.types import TRTTensor
+from torch_tensorrt.fx.converters.converter_utils import set_layer_name
+from torch_tensorrt.fx.types import TRTDataType, TRTTensor
 
 
 def exp(
@@ -459,3 +462,17 @@ def trunc(
     return impl.elementwise.trunc_div(
         ctx, target, source_ir, f"{name}_trunc", input_val, dividend
     )
+
+
+def scalar_tensor(
+    ctx: ConversionContext,
+    target: Target,
+    source_ir: Optional[SourceIR],
+    name: str,
+    scalar: Union[int, float, bool],
+    dtype: Optional[Union[torch.dtype, np.dtype, TRTDataType]] = None,
+) -> TRTTensor:
+    tensor = get_trt_tensor(ctx, scalar, f"{name}_scalar_tensor", dtype)
+    identity_layer = ctx.net.add_identity(tensor)
+    set_layer_name(identity_layer, target, name, source_ir)
+    return identity_layer.get_output(0)
diff --git a/tests/py/dynamo/conversion/test_scalar_tensor_aten.py b/tests/py/dynamo/conversion/test_scalar_tensor_aten.py
new file mode 100644
index 0000000000..28c3d7f481
--- /dev/null
+++ b/tests/py/dynamo/conversion/test_scalar_tensor_aten.py
@@ -0,0 +1,95 @@
+import torch
+import torch.nn as nn
+from parameterized import parameterized
+from torch.testing._internal.common_utils import run_tests
+
+from .harness import DispatchTestCase
+
+
+class TestScalarTensorConverter(DispatchTestCase):
+    @parameterized.expand(
+        [
+            (-2.00001,),
+            (-1.3,),
+            (-0.0,),
+            (1.0,),
+            (2.99,),
+        ]
+    )
+    def test_scalar_tensor_float(self, scalar):
+        class ScalarTensor(nn.Module):
+            def forward(self):
+                return torch.ops.aten.scalar_tensor.default(scalar)
+
+        inputs = []
+        self.run_test(
+            ScalarTensor(),
+            inputs,
+        )
+
+    @parameterized.expand(
+        [
+            (-9999,),
+            (-1,),
+            (0,),
+            (2,),
+            (99999,),
+        ]
+    )
+    def test_scalar_tensor_int(self, scalar):
+        class ScalarTensor(nn.Module):
+            def forward(self):
+                return torch.ops.aten.scalar_tensor.default(scalar)
+
+        inputs = []
+        self.run_test(
+            ScalarTensor(),
+            inputs,
+        )
+
+    @parameterized.expand(
+        [
+            (True,),
+            (False,),
+        ]
+    )
+    def test_scalar_tensor_bool(self, scalar):
+        class ScalarTensor(nn.Module):
+            def forward(self):
+                return torch.ops.aten.scalar_tensor.default(scalar)
+
+        inputs = []
+        self.run_test(
+            ScalarTensor(),
+            inputs,
+        )
+
+    @parameterized.expand(
+        [
+            (-9999, torch.int),
+            (-2.00001, torch.float),
+            (-1, torch.float),
+            (0, torch.int),
+            (-0.0, torch.float),
+            (1.0, torch.int),
+            (2.99, torch.float),
+            (9999999, None),
+            (9999999.99999, None),
+            (True, torch.bool),
+        ]
+    )
+    def test_scalar_tensor_dtype(self, scalar, dtype):
+        class ScalarTensor(nn.Module):
+            def forward(self):
+                return torch.ops.aten.scalar_tensor.default(scalar, dtype=dtype)
+
+        inputs = []
+        self.run_test(
+            ScalarTensor(),
+            inputs,
+            output_dtypes=None if dtype is None else [dtype],
+        )
+
+
+if __name__ == "__main__":
+    run_tests()
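
For context, here is a minimal end-to-end sketch (not part of the patch) of how this converter would be exercised through the dynamo frontend. The `AddScalar` module and shapes are illustrative assumptions; it presumes a CUDA-enabled torch_tensorrt build, and whether `torch.scalar_tensor` survives lowering as `aten.scalar_tensor.default` (rather than being constant-folded away) depends on the decomposition settings in use.

```python
# Hypothetical usage sketch, not part of the patch: compiles a module
# containing torch.scalar_tensor, which traces to
# aten.scalar_tensor.default and should dispatch to the new converter.
import torch
import torch_tensorrt


class AddScalar(torch.nn.Module):
    def forward(self, x):
        # impl.unary.scalar_tensor materializes the scalar as a TRT
        # constant followed by an identity layer.
        return x + torch.scalar_tensor(2.5, dtype=torch.float32)


model = AddScalar().eval().cuda()
inputs = [torch.randn(2, 3).cuda()]
trt_model = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs)
print(trt_model(*inputs))
```

The new tests themselves can be run directly, e.g. `python -m pytest tests/py/dynamo/conversion/test_scalar_tensor_aten.py`.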