Skip to content

Commit 7158ca5

Browse files
apbosegs-olive
authored andcommitted
Moving elementwise core to impl - rsqrt (FX Converter Refactor [9/N]) <Target: converter_reorg_elementwise> (#1905)
1 parent 45e43ca commit 7158ca5

File tree

3 files changed

+96
-0
lines changed

3 files changed

+96
-0
lines changed

py/torch_tensorrt/fx/converters/aten_ops_converters.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import torch_tensorrt.fx.tracer.acc_tracer.acc_utils as acc_utils
2323
from torch_tensorrt.fx.converters.impl import activation
2424
from torch_tensorrt.fx.converters.impl.elementwise import trunc_div
25+
from torch_tensorrt.fx.converters.impl.elementwise import rsqrt
2526

2627
_LOGGER: logging.Logger = logging.getLogger(__name__)
2728

@@ -366,6 +367,42 @@ def aten_ops_relu(
366367
)
367368

368369

370+
@tensorrt_converter(torch.ops.aten.relu.default)
371+
def aten_ops_relu(
372+
network: TRTNetwork,
373+
target: Target,
374+
args: Tuple[Argument, ...],
375+
kwargs: Dict[str, Argument],
376+
name: str,
377+
) -> Union[TRTTensor, Sequence[TRTTensor]]:
378+
379+
return activation.relu(
380+
network,
381+
target,
382+
SourceIR.ATEN,
383+
name,
384+
args[0],
385+
)
386+
387+
388+
@tensorrt_converter(torch.ops.aten.rsqrt.default)
389+
def aten_ops_rsqrt(
390+
network: TRTNetwork,
391+
target: Target,
392+
args: Tuple[Argument, ...],
393+
kwargs: Dict[str, Argument],
394+
name: str,
395+
) -> Union[TRTTensor, Sequence[TRTTensor]]:
396+
397+
return rsqrt(
398+
network,
399+
target,
400+
SourceIR.ATEN,
401+
name,
402+
args[0],
403+
)
404+
405+
369406
@tensorrt_converter(torch.ops.aten.sub.Tensor)
370407
def aten_ops_sub(
371408
network: TRTNetwork,

py/torch_tensorrt/fx/converters/impl/elementwise/ops.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,3 +109,33 @@ def trunc_div(
109109
)
110110

111111
return output
112+
113+
114+
def rsqrt(
115+
network: TRTNetwork,
116+
target: Target,
117+
source_ir: Optional[SourceIR],
118+
name: str,
119+
input: TRTTensor,
120+
) -> TRTTensor:
121+
122+
sqrt_trt_output = convert_unary(
123+
network,
124+
target,
125+
source_ir,
126+
f"{name}_sqrt",
127+
trt.UnaryOperation.SQRT,
128+
input,
129+
)
130+
131+
output = convert_binary_elementwise(
132+
network,
133+
target,
134+
source_ir,
135+
f"{name}_output",
136+
trt.ElementWiseOperation.DIV,
137+
1,
138+
sqrt_trt_output,
139+
)
140+
141+
return output
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
import torch
2+
import torch.nn as nn
3+
from parameterized import parameterized
4+
from torch.testing._internal.common_utils import run_tests
5+
from torch_tensorrt.fx.tools.common_fx2trt import DispatchTestCase, InputTensorSpec
6+
7+
8+
class TestRSqrtConverter(DispatchTestCase):
9+
@parameterized.expand(
10+
[
11+
("2d_dim_alpha", (2, 1), 2),
12+
("3d_dim_alpha", (2, 1, 2), 2),
13+
]
14+
)
15+
def test_rsqrt(self, _, x, alpha):
16+
class rsqrt(nn.Module):
17+
def forward(self, input):
18+
return torch.rsqrt(input)
19+
20+
inputs = [torch.randn(x) + 1]
21+
self.run_test(
22+
rsqrt(),
23+
inputs,
24+
expected_ops={torch.ops.aten.rsqrt.default},
25+
)
26+
27+
28+
if __name__ == "__main__":
29+
run_tests()

0 commit comments

Comments
 (0)