diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
index 16dcee707c..ca650b09f6 100644
--- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
+++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -354,6 +354,29 @@ def aten_ops_rsqrt(
     )
 
 
+@dynamo_tensorrt_converter(torch.ops.aten.neg.default)  # type: ignore[misc]
+def aten_ops_neg(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    input_val = args[0]
+    if isinstance(input_val, TRTTensor) and input_val.dtype in (
+        trt.int8, trt.int32
+    ):
+        input_val = cast_trt_tensor(network, input_val, trt.float32, name)
+
+    return impl.unary.neg(
+        network,
+        target,
+        SourceIR.ATEN,
+        name,
+        input_val,
+    )
+
+
 @dynamo_tensorrt_converter(torch.ops.aten.squeeze.dim)  # type: ignore[misc]
 @dynamo_tensorrt_converter(torch.ops.aten.squeeze.dims)  # type: ignore[misc]
 def aten_ops_squeeze(
diff --git a/py/torch_tensorrt/dynamo/conversion/impl/unary/ops.py b/py/torch_tensorrt/dynamo/conversion/impl/unary/ops.py
index 4a6380c964..f5f6309657 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/unary/ops.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/unary/ops.py
@@ -384,3 +384,15 @@ def isinf(
     return convert_unary(
         network, target, source_ir, name, trt.UnaryOperation.ISINF, input_val
     )
+
+
+def neg(
+    network: TRTNetwork,
+    target: Target,
+    source_ir: Optional[SourceIR],
+    name: str,
+    input_val: TRTTensor,
+) -> TRTTensor:
+    return convert_unary(
+        network, target, source_ir, name, trt.UnaryOperation.NEG, input_val
+    )
diff --git a/tests/py/dynamo/converters/test_neg_aten.py b/tests/py/dynamo/converters/test_neg_aten.py
new file mode 100644
index 0000000000..d5d805f9c2
--- /dev/null
+++ b/tests/py/dynamo/converters/test_neg_aten.py
@@ -0,0 +1,52 @@
+import torch
+import torch.nn as nn
+from parameterized import parameterized
+from torch.testing._internal.common_utils import run_tests
+from torch_tensorrt import Input
+from torch_tensorrt.dynamo.test_utils import DispatchTestCase
+
+
+class TestNegConverter(DispatchTestCase):
+    @parameterized.expand(
+        [
+            ("2d_dim_dtype_float", (2, 2), torch.float),
+            ("3d_dim_dtype_float", (2, 2, 2), torch.float),
+            ("2d_dim_dtype_half", (2, 2), torch.half),
+            ("3d_dim_dtype_half", (2, 2, 2), torch.half),
+        ]
+    )
+    def test_neg_float(self, _, x, dtype):
+        class neg(nn.Module):
+            def forward(self, input):
+                return torch.neg(input)
+
+        inputs = [torch.randn(x, dtype=dtype)]
+        self.run_test(
+            neg(),
+            inputs,
+            precision=dtype,
+            expected_ops={torch.ops.aten.neg.default},
+        )
+
+    @parameterized.expand(
+        [
+            ("2d_dim_dtype_int32", (2, 2), torch.int32, 0, 5),
+            ("3d_dim_dtype_int32", (2, 2, 2), torch.int32, 0, 5),
+        ]
+    )
+    def test_neg_int(self, _, x, dtype, low, high):
+        class neg(nn.Module):
+            def forward(self, input):
+                return torch.neg(input)
+
+        inputs = [torch.randint(low, high, x, dtype=dtype)]
+        self.run_test(
+            neg(),
+            inputs,
+            output_dtypes=[torch.int32],
+            expected_ops={torch.ops.aten.neg.default},
+        )
+
+
+if __name__ == "__main__":
+    run_tests()