Commit 0d09d52
refactor: Centralizing sigmoid implementation
Signed-off-by: Naren Dasan <[email protected]>
1 parent dfd98a5 commit 0d09d52

6 files changed: +112, -47 lines changed
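In short, the commit replaces per-frontend sigmoid code with one shared helper that every frontend delegates to, tagged with its source IR. A minimal stub sketch of the pattern (runs without TensorRT installed; SourceIR and convert_sigmoid here are simplified stand-ins for the real names in the diffs below, which build an actual TensorRT activation layer):

from enum import Enum, auto

# Simplified stand-ins for the names in the diffs below; the real
# convert_sigmoid (impl/activation.py) builds a TensorRT activation layer.
class SourceIR(Enum):
    ACC = auto()
    ATEN = auto()
    NN = auto()
    UNKNOWN = auto()

def convert_sigmoid(network, target, kwargs, name, source_ir=SourceIR.UNKNOWN):
    # Stand-in body: just record what was requested.
    return (name, source_ir, kwargs["input"])

# Each frontend normalizes its arguments, then delegates:
print(convert_sigmoid("net", "acc_ops.sigmoid", {"input": "x"}, "sig", SourceIR.ACC))
args = ("x",)  # aten converters receive positional args and repackage them
print(convert_sigmoid("net", "aten.sigmoid", {"input": args[0]}, "sig", SourceIR.ATEN))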

py/torch_tensorrt/fx/converters/acc_ops_converters.py

Lines changed: 1 addition & 10 deletions

@@ -3156,17 +3156,8 @@ def acc_ops_sigmoid(
     kwargs: Dict[str, Argument],
     name: str,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    input_val = kwargs["input"]
-
-    if not isinstance(input_val, TRTTensor):
-        raise RuntimeError(
-            f"Sigmoid received input {input_val} that is not part "
-            "of the TensorRT region!"
-        )

-    return add_activation_layer(
-        network, input_val, trt.ActivationType.SIGMOID, target, name
-    )
+    return activation.convert_sigmoid(network, target, kwargs, name, SourceIR.ACC)


 @tensorrt_converter(acc_ops.permute)

py/torch_tensorrt/fx/converters/activation.py

Lines changed: 0 additions & 37 deletions
This file was deleted.

py/torch_tensorrt/fx/converters/aten_ops_converters.py

Lines changed: 16 additions & 0 deletions

@@ -482,3 +482,19 @@ def aten_ops_sym_size(
     )
     set_layer_name(slice_layer, target, "_slice_layer")
     return slice_layer.get_output(0)
+
+
+@tensorrt_converter(torch.ops.aten.sigmoid.default)
+def aten_ops_sigmoid(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+
+    kwargs_new = {
+        "input": args[0],
+    }
+
+    return activation.convert_sigmoid(network, target, kwargs_new, name, SourceIR.ATEN)
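A quick check of the target op (a sketch needing only PyTorch, not TensorRT): torch.ops.aten.sigmoid.default, the overload this converter registers for, matches torch.sigmoid in eager mode.

import torch

# torch.ops.aten.sigmoid.default is the ATen overload registered above;
# eager execution agrees with torch.sigmoid.
x = torch.randn(4)
assert torch.allclose(torch.ops.aten.sigmoid.default(x), torch.sigmoid(x))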

py/torch_tensorrt/fx/converters/impl/activation.py

Lines changed: 28 additions & 0 deletions

@@ -85,3 +85,31 @@ def relu_dyn_range_fn(dyn_range):
     return convert_activation(
         network, input_val, operation_type, target, name, relu_dyn_range_fn, source_ir
     )
+
+
+def convert_sigmoid(
+    network: TRTNetwork,
+    target: Target,
+    kwargs: Dict[str, Any],
+    name: str,
+    source_ir: SourceIR = SourceIR.UNKNOWN,
+):
+    input_val = kwargs["input"]
+    operation_type = trt.ActivationType.SIGMOID
+
+    def sigmoid_dyn_range_fn(dyn_range):
+        def sigmoid_fn(x):
+            # TODO: Can this just call torch.nn.functional.sigmoid?
+            return 1 / (1 + np.exp(-x))
+
+        return sigmoid_fn(dyn_range[0]), sigmoid_fn(dyn_range[1])
+
+    return convert_activation(
+        network,
+        input_val,
+        operation_type,
+        target,
+        name,
+        sigmoid_dyn_range_fn,
+        source_ir,
+    )
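For intuition on sigmoid_dyn_range_fn (a standalone sketch needing only NumPy): sigmoid is monotonically increasing, so mapping just the two endpoints of the input dynamic range through it yields the output dynamic range.

import numpy as np

# Standalone check of the dynamic-range mapping above: sigmoid is
# monotonic, so transforming the endpoints of the input range gives
# the endpoints of the output range, always inside (0, 1).
def sigmoid_fn(x):
    return 1 / (1 + np.exp(-x))

dyn_range = (-6.0, 6.0)
print(sigmoid_fn(dyn_range[0]), sigmoid_fn(dyn_range[1]))  # ~0.00247, ~0.99753

On the TODO in the diff: torch.sigmoid on a scalar tensor would give the same values; the NumPy form simply avoids a tensor round-trip for the two Python floats of the range.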

py/torch_tensorrt/fx/converters/nn_ops_converters.py

Lines changed: 14 additions & 0 deletions

@@ -22,3 +22,17 @@ def relu(network, submod, args, kwargs, layer_name):
         name=layer_name,
         source_ir=SourceIR.NN,
     )
+
+
+@tensorrt_converter(torch.nn.modules.activation.Sigmoid)
+def sigmoid(network, submod, args, kwargs, layer_name):
+    # args/kwargs should have already been normalized to kwargs
+    assert len(args) == 0
+
+    activation.convert_sigmoid(
+        network=network,
+        target="torch.nn.modules.activation.Sigmoid",
+        kwargs=kwargs,
+        name=layer_name,
+        source_ir=SourceIR.NN,
+    )

Lines changed: 53 additions & 0 deletions

@@ -0,0 +1,53 @@
+import torch
+import torch.nn as nn
+from torch.testing._internal.common_utils import run_tests
+from torch_tensorrt.fx.tools.common_fx2trt import DispatchTestCase, InputTensorSpec
+
+
+class TestSigmoidConverter(DispatchTestCase):
+    def test_sigmoid(self):
+        class TestModule(nn.Module):
+            def forward(self, x):
+                return nn.functional.sigmoid(x)
+
+        inputs = [torch.randn(1, 10)]
+        self.run_test(
+            TestModule(), inputs, expected_ops={torch.ops.aten.sigmoid.default}
+        )
+
+    def test_sigmoid_with_dynamic_shape(self):
+        class TestModule(nn.Module):
+            def forward(self, x):
+                return nn.functional.sigmoid(x)
+
+        input_specs = [
+            InputTensorSpec(
+                shape=(-1, -1, -1),
+                dtype=torch.float32,
+                shape_ranges=[((1, 1, 1), (1, 2, 3), (3, 3, 3))],
+            ),
+        ]
+        self.run_test_with_dynamic_shape(
+            TestModule(), input_specs, expected_ops={torch.ops.aten.sigmoid.default}
+        )
+
+    def test_sigmoid_with_dynamic_shape_four_dimensions(self):
+        class TestModule(nn.Module):
+            def forward(self, x):
+                return nn.functional.sigmoid(x)
+
+        input_specs = [
+            InputTensorSpec(
+                shape=(-1, -1, -1, -1),
+                dtype=torch.float32,
+                shape_ranges=[((1, 1, 1, 5), (1, 2, 3, 5), (3, 3, 3, 5))],
+            ),
+        ]
+
+        self.run_test_with_dynamic_shape(
+            TestModule(), input_specs, expected_ops={torch.ops.aten.sigmoid.default}
+        )
+
+
+if __name__ == "__main__":
+    run_tests()
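The wrapped module can be sanity-checked in eager mode without TensorRT (a minimal sketch; the expected_ops entries above reflect that nn.functional.sigmoid lowers to torch.ops.aten.sigmoid.default during tracing):

import torch
import torch.nn as nn

# Eager-mode sanity check of the module shape exercised by the tests
# above; no TensorRT required.
class TestModule(nn.Module):
    def forward(self, x):
        return nn.functional.sigmoid(x)

out = TestModule()(torch.randn(1, 10))
assert out.shape == (1, 10)
assert bool(((out > 0) & (out < 1)).all())  # sigmoid maps into (0, 1)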
