diff --git a/py/torch_tensorrt/dynamo/conversion/impl/unary/ops.py b/py/torch_tensorrt/dynamo/conversion/impl/unary/ops.py index f5f6309657..a91efac621 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/unary/ops.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/unary/ops.py @@ -393,6 +393,11 @@ def neg( name: str, input_val: TRTTensor, ) -> TRTTensor: + if (isinstance(input_val, TRTTensor)) and ( + input_val.dtype == trt.int8 or input_val.dtype == trt.int32 + ): + input_val = cast_trt_tensor(network, input_val, trt.float32, name) + return convert_unary( network, target, source_ir, name, trt.UnaryOperation.NEG, input_val ) diff --git a/tests/py/dynamo/converters/test_neg_aten.py b/tests/py/dynamo/conversion/test_neg_aten.py similarity index 93% rename from tests/py/dynamo/converters/test_neg_aten.py rename to tests/py/dynamo/conversion/test_neg_aten.py index d5d805f9c2..bcb95b4172 100644 --- a/tests/py/dynamo/converters/test_neg_aten.py +++ b/tests/py/dynamo/conversion/test_neg_aten.py @@ -3,7 +3,8 @@ from parameterized import parameterized from torch.testing._internal.common_utils import run_tests from torch_tensorrt import Input -from torch_tensorrt.dynamo.test_utils import DispatchTestCase + +from .harness import DispatchTestCase class TestNegConverter(DispatchTestCase): @@ -43,8 +44,8 @@ def forward(self, input): self.run_test( neg(), inputs, - output_dtypes=[torch.int32], expected_ops={torch.ops.aten.neg.default}, + check_dtype=False, )