From afcedbea6fcf592ec86f6eacf1087183e36a1152 Mon Sep 17 00:00:00 2001
From: Hoonkyung Cho
Date: Tue, 26 Mar 2024 17:00:08 +0900
Subject: [PATCH 1/2] feat: support aten.isnan converter

---
 .../dynamo/conversion/aten_ops_converters.py  | 17 ++++
 .../dynamo/conversion/impl/unary/ops.py       | 20 +++++
 tests/py/dynamo/conversion/test_isnan_aten.py | 82 +++++++++++++++++++
 3 files changed, 119 insertions(+)
 create mode 100644 tests/py/dynamo/conversion/test_isnan_aten.py

diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
index 0dd153d0aa..2341c7f5ac 100644
--- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
+++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -1493,6 +1493,23 @@ def aten_ops_isinf(
     )
 
 
+@dynamo_tensorrt_converter(torch.ops.aten.isnan.default)
+def aten_ops_isnan(
+    ctx: ConversionContext,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.unary.isnan(
+        ctx,
+        target,
+        SourceIR.ATEN,
+        name,
+        args[0],
+    )
+
+
 @dynamo_tensorrt_converter(torch.ops.aten.add.Tensor)
 @dynamo_tensorrt_converter(torch.ops.aten.add.Scalar)
 def aten_ops_add(
diff --git a/py/torch_tensorrt/dynamo/conversion/impl/unary/ops.py b/py/torch_tensorrt/dynamo/conversion/impl/unary/ops.py
index 554640ea5a..f02e6082e3 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/unary/ops.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/unary/ops.py
@@ -508,3 +508,23 @@ def scalar_tensor(
     identity_layer = ctx.net.add_identity(tensor)
     set_layer_name(identity_layer, target, name, source_ir)
     return identity_layer.get_output(0)
+
+
+def isnan(
+    ctx: ConversionContext,
+    target: Target,
+    source_ir: Optional[SourceIR],
+    name: str,
+    input: TRTTensor,
+) -> TRTTensor:
+    # False for NaN elements since NaN is not equal to anything, including itself.
+    equality_result = impl.elementwise.ops.eq(
+        ctx, target, source_ir, f"{name}_eq_nan", input, input
+    )
+
+    # Invert equality_result to get a mask where NaN values are marked as True.
+    nan_values_mask = logical_not(
+        ctx, target, source_ir, f"{name}_logical_not", equality_result
+    )
+
+    return nan_values_mask
diff --git a/tests/py/dynamo/conversion/test_isnan_aten.py b/tests/py/dynamo/conversion/test_isnan_aten.py
new file mode 100644
index 0000000000..5651b0ca25
--- /dev/null
+++ b/tests/py/dynamo/conversion/test_isnan_aten.py
@@ -0,0 +1,82 @@
+import torch
+import torch.nn as nn
+from parameterized import parameterized
+from torch.testing._internal.common_utils import run_tests
+
+from .harness import DispatchTestCase
+
+
+class TestIsNanConverter(DispatchTestCase):
+    @parameterized.expand(
+        [
+            (
+                torch.tensor(
+                    [
+                        1.23,
+                        float("nan"),
+                        -4.56,
+                        float("inf"),
+                        float("-inf"),
+                        -100.0,
+                        float("nan"),
+                        0.13,
+                        -0.13,
+                        3.14159265,
+                    ]
+                ),
+            ),
+        ]
+    )
+    def test_isnan_float(self, data):
+        class isnan(nn.Module):
+            def forward(self, input):
+                return torch.ops.aten.isnan.default(input)
+
+        inputs = [data]
+        self.run_test(
+            isnan(),
+            inputs,
+            output_dtypes=[torch.bool],
+        )
+
+    @parameterized.expand(
+        [
+            (torch.full((2, 2), float("nan"), dtype=torch.float32),),
+            (torch.full((3, 10, 5), float("nan"), dtype=torch.float32),),
+            (torch.randn((5, 10, 5), dtype=torch.float32),),
+        ]
+    )
+    def test_isnan_dim(self, data):
+        class isnan(nn.Module):
+            def forward(self, input):
+                return torch.ops.aten.isnan.default(input)
+
+        inputs = [data]
+        self.run_test(
+            isnan(),
+            inputs,
+            output_dtypes=[torch.bool],
+        )
+
+    @parameterized.expand(
+        [
+            ((10,), torch.int, 0, 5),
+            ((1, 20), torch.int32, -10, 10),
+            ((2, 3, 4), torch.int, -5, 5),
+        ]
+    )
+    def test_isnan_int(self, input_shape, dtype, low, high):
+        class isnan(nn.Module):
+            def forward(self, input):
+                return torch.ops.aten.isnan.default(input)
+
+        inputs = [torch.randint(low, high, input_shape, dtype=dtype)]
+        self.run_test(
+            isnan(),
+            inputs,
+            output_dtypes=[torch.bool],
+        )
+
+
+if __name__ == "__main__":
+    run_tests()

From 862017fb416386e961a3a742de2f6fd9d69483ee Mon Sep 17 00:00:00 2001
From: Hoonkyung Cho
Date: Mon, 1 Apr 2024 15:54:23 +0900
Subject: [PATCH 2/2] chore: minor fix

---
 py/torch_tensorrt/dynamo/conversion/impl/unary/ops.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/py/torch_tensorrt/dynamo/conversion/impl/unary/ops.py b/py/torch_tensorrt/dynamo/conversion/impl/unary/ops.py
index f02e6082e3..4bc24051ee 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/unary/ops.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/unary/ops.py
@@ -518,7 +518,7 @@ def isnan(
     input: TRTTensor,
 ) -> TRTTensor:
     # False for NaN elements since NaN is not equal to anything, including itself.
-    equality_result = impl.elementwise.ops.eq(
+    equality_result = impl.elementwise.eq(
         ctx, target, source_ir, f"{name}_eq_nan", input, input
     )
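
Note for reviewers (illustrative, not part of the patch): the isnan lowering relies on the IEEE-754 rule that NaN compares unequal to everything, including itself, so an element-wise self-equality followed by a logical NOT yields the NaN mask. A minimal PyTorch sketch of the same composition, which the TensorRT eq + logical_not layers mirror:

    import torch

    x = torch.tensor([1.23, float("nan"), float("inf"), -0.13])

    # Step 1: self-equality is False only at NaN positions.
    equality_result = torch.eq(x, x)

    # Step 2: invert the mask so NaN positions become True.
    nan_values_mask = torch.logical_not(equality_result)

    # Matches the reference op the converter implements.
    assert torch.equal(nan_values_mask, torch.isnan(x))
    print(nan_values_mask)  # tensor([False,  True, False, False])

Integer inputs (as exercised by test_isnan_int) cannot hold NaN, so the resulting mask is all False in that case.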