diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
index 19f273ba3f..b0f718256f 100644
--- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
+++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -329,6 +329,23 @@ def aten_ops_squeeze(
     return impl.squeeze.squeeze(network, target, SourceIR.ATEN, name, args[0], args[1])
 
 
+@dynamo_tensorrt_converter(torch.ops.aten.erf.default)  # type: ignore[misc]
+def aten_ops_erf(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.unary.erf(
+        network,
+        target,
+        SourceIR.ATEN,
+        name,
+        args[0],
+    )
+
+
 @dynamo_tensorrt_converter(torch.ops.aten.unsqueeze.default)  # type: ignore[misc]
 def aten_ops_unsqueeze(
     network: TRTNetwork,
@@ -357,14 +374,14 @@ def aten_ops_softmax(
 
 @dynamo_tensorrt_converter(
     torch.ops.aten.split.Tensor, capability_validator=dynamic_unsupported_with_args([1])
-)
+)  # type: ignore[misc]
 @dynamo_tensorrt_converter(
     torch.ops.aten.split.sizes, capability_validator=dynamic_unsupported_with_args([1])
-)
+)  # type: ignore[misc]
 @dynamo_tensorrt_converter(
     torch.ops.aten.split_with_sizes.default,
     capability_validator=dynamic_unsupported_with_args([1]),
-)
+)  # type: ignore[misc]
 def aten_ops_split(
     network: TRTNetwork,
     target: Target,
diff --git a/py/torch_tensorrt/dynamo/conversion/impl/unary/ops.py b/py/torch_tensorrt/dynamo/conversion/impl/unary/ops.py
index a91efac621..1a52ae7dc6 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/unary/ops.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/unary/ops.py
@@ -401,3 +401,22 @@ def neg(
     return convert_unary(
         network, target, source_ir, name, trt.UnaryOperation.NEG, input_val
     )
+
+
+def erf(
+    network: TRTNetwork,
+    target: Target,
+    source_ir: Optional[SourceIR],
+    name: str,
+    input_val: TRTTensor,
+) -> TRTTensor:
+    """Add a TensorRT ERF unary layer; int8/int32 inputs are cast to float32 first."""
+    if isinstance(input_val, TRTTensor) and (
+        input_val.dtype == trt.int8 or input_val.dtype == trt.int32
+    ):
+        # TensorRT's ERF unary operation does not accept integer tensors.
+        input_val = cast_trt_tensor(network, input_val, trt.float32, name)
+
+    return convert_unary(
+        network, target, source_ir, name, trt.UnaryOperation.ERF, input_val
+    )
diff --git a/tests/py/dynamo/conversion/test_erf_aten.py b/tests/py/dynamo/conversion/test_erf_aten.py
new file mode 100644
index 0000000000..e50deeb5bb
--- /dev/null
+++ b/tests/py/dynamo/conversion/test_erf_aten.py
@@ -0,0 +1,52 @@
+import torch
+import torch.nn as nn
+from parameterized import parameterized
+from torch.testing._internal.common_utils import run_tests
+from torch_tensorrt import Input
+
+from .harness import DispatchTestCase
+
+
+class TestErfConverter(DispatchTestCase):
+    @parameterized.expand(
+        [
+            ("2d_dim_dtype_float", (2, 2), torch.float),
+            ("3d_dim_dtype_float", (2, 2, 2), torch.float),
+            ("2d_dim_dtype_half", (2, 2), torch.half),
+            ("3d_dim_dtype_half", (2, 2, 2), torch.half),
+        ]
+    )
+    def test_erf_float(self, _, shape, dtype):
+        class erf(nn.Module):
+            def forward(self, input):
+                return torch.erf(input)
+
+        inputs = [torch.randn(shape, dtype=dtype)]
+        self.run_test(
+            erf(),
+            inputs,
+            precision=dtype,
+            expected_ops={torch.ops.aten.erf.default},
+        )
+
+    @parameterized.expand(
+        [
+            ("2d_dim_dtype_int32", (2, 2), torch.int32, 0, 5),
+            ("3d_dim_dtype_int32", (2, 2, 2), torch.int32, 0, 5),
+        ]
+    )
+    def test_erf_int(self, _, shape, dtype, low, high):
+        class erf(nn.Module):
+            def forward(self, input):
+                return torch.erf(input)
+
+        inputs = [torch.randint(low, high, shape, dtype=dtype)]
+        self.run_test(
+            erf(),
+            inputs,
+            expected_ops={torch.ops.aten.erf.default},
+        )
+
+
+if __name__ == "__main__":
+    run_tests()