Skip to content

add initial support for torch.ops.aten.neg.default converter #2147

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Sep 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,6 +354,29 @@ def aten_ops_rsqrt(
)


@dynamo_tensorrt_converter(torch.ops.aten.neg.default)
def aten_ops_neg(
network: TRTNetwork,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
input_val = args[0]
if (isinstance(input_val, TRTTensor)) and (
input_val.dtype == trt.int8 or input_val.dtype == trt.int32
):
input_val = cast_trt_tensor(network, input_val, trt.float32, name)

return impl.unary.neg(
network,
target,
SourceIR.ATEN,
name,
input_val,
)


@dynamo_tensorrt_converter(torch.ops.aten.squeeze.dim) # type: ignore[misc]
@dynamo_tensorrt_converter(torch.ops.aten.squeeze.dims) # type: ignore[misc]
def aten_ops_squeeze(
Expand Down
12 changes: 12 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/impl/unary/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,3 +384,15 @@ def isinf(
return convert_unary(
network, target, source_ir, name, trt.UnaryOperation.ISINF, input_val
)


def neg(
network: TRTNetwork,
target: Target,
source_ir: Optional[SourceIR],
name: str,
input_val: TRTTensor,
) -> TRTTensor:
return convert_unary(
network, target, source_ir, name, trt.UnaryOperation.NEG, input_val
)
52 changes: 52 additions & 0 deletions tests/py/dynamo/converters/test_neg_aten.py
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider renaming this file test_unary_aten.py, so any subsequent unary op tests can go in the same file

Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
import torch
import torch.nn as nn
from parameterized import parameterized
from torch.testing._internal.common_utils import run_tests
from torch_tensorrt import Input
from torch_tensorrt.dynamo.test_utils import DispatchTestCase


class TestNegConverter(DispatchTestCase):
@parameterized.expand(
[
("2d_dim_dtype_float", (2, 2), torch.float),
("3d_dim_dtype_float", (2, 2, 2), torch.float),
("2d_dim_dtype_half", (2, 2), torch.half),
("3d_dim_dtype_half", (2, 2, 2), torch.half),
]
)
def test_neg_float(self, _, x, type):
class neg(nn.Module):
def forward(self, input):
return torch.neg(input)

inputs = [torch.randn(x, dtype=type)]
self.run_test(
neg(),
inputs,
precision=type,
expected_ops={torch.ops.aten.neg.default},
)

@parameterized.expand(
[
("2d_dim_dtype_int32", (2, 2), torch.int32, 0, 5),
("3d_dim_dtype_int32", (2, 2, 2), torch.int32, 0, 5),
]
)
def test_neg_int(self, _, x, type, min, max):
class neg(nn.Module):
def forward(self, input):
return torch.neg(input)

inputs = [torch.randint(min, max, x, dtype=type)]
self.run_test(
neg(),
inputs,
output_dtypes=[torch.int32],
expected_ops={torch.ops.aten.neg.default},
)


if __name__ == "__main__":
run_tests()