Skip to content

Commit 670d2be

Browse files
authored
feat: support Dynamo converter for torch.ops.aten.erf.default op
Dynamo converter support for torch.ops.aten.erf.default op
2 parents e6e8099 + 3c4c2fe commit 670d2be

File tree

3 files changed

+89
-3
lines changed

3 files changed

+89
-3
lines changed

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -329,6 +329,23 @@ def aten_ops_squeeze(
329329
return impl.squeeze.squeeze(network, target, SourceIR.ATEN, name, args[0], args[1])
330330

331331

332+
@dynamo_tensorrt_converter(torch.ops.aten.erf.default) # type: ignore[misc]
333+
def aten_ops_erf(
334+
network: TRTNetwork,
335+
target: Target,
336+
args: Tuple[Argument, ...],
337+
kwargs: Dict[str, Argument],
338+
name: str,
339+
) -> Union[TRTTensor, Sequence[TRTTensor]]:
340+
return impl.unary.erf(
341+
network,
342+
target,
343+
SourceIR.ATEN,
344+
name,
345+
args[0],
346+
)
347+
348+
332349
@dynamo_tensorrt_converter(torch.ops.aten.unsqueeze.default) # type: ignore[misc]
333350
def aten_ops_unsqueeze(
334351
network: TRTNetwork,
@@ -357,14 +374,14 @@ def aten_ops_softmax(
357374

358375
@dynamo_tensorrt_converter(
359376
torch.ops.aten.split.Tensor, capability_validator=dynamic_unsupported_with_args([1])
360-
)
377+
) # type: ignore[misc]
361378
@dynamo_tensorrt_converter(
362379
torch.ops.aten.split.sizes, capability_validator=dynamic_unsupported_with_args([1])
363-
)
380+
) # type: ignore[misc]
364381
@dynamo_tensorrt_converter(
365382
torch.ops.aten.split_with_sizes.default,
366383
capability_validator=dynamic_unsupported_with_args([1]),
367-
)
384+
) # type: ignore[misc]
368385
def aten_ops_split(
369386
network: TRTNetwork,
370387
target: Target,

py/torch_tensorrt/dynamo/conversion/impl/unary/ops.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -401,3 +401,20 @@ def neg(
401401
return convert_unary(
402402
network, target, source_ir, name, trt.UnaryOperation.NEG, input_val
403403
)
404+
405+
406+
def erf(
407+
network: TRTNetwork,
408+
target: Target,
409+
source_ir: Optional[SourceIR],
410+
name: str,
411+
input_val: TRTTensor,
412+
) -> TRTTensor:
413+
if (isinstance(input_val, TRTTensor)) and (
414+
input_val.dtype == trt.int8 or input_val.dtype == trt.int32
415+
):
416+
input_val = cast_trt_tensor(network, input_val, trt.float32, name)
417+
418+
return convert_unary(
419+
network, target, source_ir, name, trt.UnaryOperation.ERF, input_val
420+
)
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
import torch
2+
import torch.nn as nn
3+
from parameterized import parameterized
4+
from torch.testing._internal.common_utils import run_tests
5+
from torch_tensorrt import Input
6+
7+
from .harness import DispatchTestCase
8+
9+
10+
class TestErfConverter(DispatchTestCase):
11+
@parameterized.expand(
12+
[
13+
("2d_dim_dtype_float", (2, 2), torch.float),
14+
("3d_dim_dtype_float", (2, 2, 2), torch.float),
15+
("2d_dim_dtype_half", (2, 2), torch.half),
16+
("3d_dim_dtype_half", (2, 2, 2), torch.half),
17+
]
18+
)
19+
def test_erf_float(self, _, x, type):
20+
class erf(nn.Module):
21+
def forward(self, input):
22+
return torch.erf(input)
23+
24+
inputs = [torch.randn(x, dtype=type)]
25+
self.run_test(
26+
erf(),
27+
inputs,
28+
precision=type,
29+
expected_ops={torch.ops.aten.erf.default},
30+
)
31+
32+
@parameterized.expand(
33+
[
34+
("2d_dim_dtype_int32", (2, 2), torch.int32, 0, 5),
35+
("3d_dim_dtype_int32", (2, 2, 2), torch.int32, 0, 5),
36+
]
37+
)
38+
def test_erf_int(self, _, x, type, min, max):
39+
class erf(nn.Module):
40+
def forward(self, input):
41+
return torch.erf(input)
42+
43+
inputs = [torch.randint(min, max, x, dtype=type)]
44+
self.run_test(
45+
erf(),
46+
inputs,
47+
expected_ops={torch.ops.aten.erf.default},
48+
)
49+
50+
51+
if __name__ == "__main__":
52+
run_tests()

0 commit comments

Comments
 (0)