Commit 8a5d102

digantdesai authored and facebook-github-bot committed

Add support for quantized LeakyReLU (#1)

Summary:
X-link: pytorch/pytorch#104309
Pull Request resolved: #1

Also adds support for backend_config.

Reviewed By: mcr229

Differential Revision: D47043207

fbshipit-source-id: 51abd266bba7441c28578f6c58686a3d021d9d2a

1 parent c551f20
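The backend_config change itself lives in the linked PyTorch diff (pytorch/pytorch#104309) rather than in the two files below. For orientation, a minimal sketch of what registering LeakyReLU patterns in a PyTorch quantization BackendConfig can look like, assuming quint8 input/output dtypes and the backend name "illustrative_xnnpack" (both assumptions, not the actual patch):

import torch
from torch.ao.quantization.backend_config import (
    BackendConfig,
    BackendPatternConfig,
    DTypeConfig,
    ObservationType,
)

# Assumed dtype support for a quantized LeakyReLU kernel: quint8 in, quint8 out.
leaky_relu_dtype_config = DTypeConfig(
    input_dtype=torch.quint8,
    output_dtype=torch.quint8,
)

# One pattern config per op form added by this commit (module, functional,
# and in-place functional). LeakyReLU rescales negative values, so its output
# gets its own observer rather than reusing the input's quantization params.
leaky_relu_configs = [
    BackendPatternConfig(pattern)
    .set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT)
    .add_dtype_config(leaky_relu_dtype_config)
    for pattern in (
        torch.nn.LeakyReLU,
        torch.nn.functional.leaky_relu,
        torch.nn.functional.leaky_relu_,
    )
]

backend_config = BackendConfig("illustrative_xnnpack").set_backend_pattern_configs(
    leaky_relu_configs
)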

File tree

2 files changed: +44 −0

backends/xnnpack/partition/xnnpack_partitioner.py

3 additions, 0 deletions

@@ -552,6 +552,9 @@ def __init__(self):
             torch.nn.ReLU,
             torch.nn.functional.relu,
             torch.nn.functional.relu_,
+            torch.nn.functional.leaky_relu,
+            torch.nn.functional.leaky_relu_,
+            torch.nn.LeakyReLU,
         ]
 
         # Modules which support dynamic quantization
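This hunk extends the list of targets the quantized partitioner will claim. The surrounding partitioner code is not part of this diff; a simplified sketch of how such a list is typically consumed while walking an FX graph (the helper name and the modules argument are illustrative, not from the patch):

import torch

# The ops list from the hunk above, including the three new LeakyReLU entries.
SUPPORTED_QUANT_OPS = [
    torch.nn.ReLU,
    torch.nn.functional.relu,
    torch.nn.functional.relu_,
    torch.nn.functional.leaky_relu,
    torch.nn.functional.leaky_relu_,
    torch.nn.LeakyReLU,
]


def is_supported_quant_target(node: torch.fx.Node, modules: dict) -> bool:
    # Functional form: the node calls one of the supported functions directly.
    if node.op == "call_function":
        return node.target in SUPPORTED_QUANT_OPS
    # Module form: the node invokes a submodule whose type is supported
    # (modules is typically dict(graph_module.named_modules())).
    if node.op == "call_module":
        return type(modules[str(node.target)]) in SUPPORTED_QUANT_OPS
    return False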

backends/xnnpack/test/test_xnnpack_quantized.py

41 additions, 0 deletions

@@ -178,6 +178,47 @@ def test_xnnpack_qhardtanh(self):
         example_inputs = (torch.randn(1, 1, 1),)
         self.quantize_and_test_model(torch.nn.Hardtanh(), example_inputs)
 
+    def test_xnnpack_leaky_relu(self):
+        example_inputs = (torch.randn(1, 3, 3),)
+
+        class LeakyReLUModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.leaky_relu_out_of_place = torch.nn.LeakyReLU(negative_slope=0.2)
+
+            def forward(self, x):
+                return self.leaky_relu_out_of_place(x)
+
+        self.quantize_and_test_model(LeakyReLUModule(), example_inputs)
+
+    def test_xnnpack_leaky_relu2(self):
+        example_inputs = (torch.randn(1, 3, 3),)
+
+        class LeakyReLUModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.leaky_relu_in_place = torch.nn.LeakyReLU(
+                    negative_slope=0.08, inplace=True
+                )
+
+            def forward(self, x):
+                return self.leaky_relu_in_place(x)
+
+        self.quantize_and_test_model(LeakyReLUModule(), example_inputs)
+
+    def test_xnnpack_leaky_relu3(self):
+        example_inputs = (torch.randn(1, 3, 3),)
+
+        class LeakyReLUModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.leaky_relu_functional_default = torch.nn.functional.leaky_relu
+
+            def forward(self, x):
+                return self.leaky_relu_functional_default(x)
+
+        self.quantize_and_test_model(LeakyReLUModule(), example_inputs)
+
     def test_xnnpack_qlinear(self):
         in_size = 1
         input_size = 3
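quantize_and_test_model is a harness helper defined elsewhere in the test suite. For readers without the repo handy, a minimal sketch of a comparable FX quantize-then-check flow for the first test case, assuming the default qnnpack qconfig mapping and loose tolerances (the harness's actual qconfig, tolerances, and lowering steps may differ):

import torch
from torch.ao.quantization import get_default_qconfig_mapping
from torch.ao.quantization.quantize_fx import convert_fx, prepare_fx

model = torch.nn.Sequential(torch.nn.LeakyReLU(negative_slope=0.2)).eval()
example_inputs = (torch.randn(1, 3, 3),)

# Insert observers, run a calibration pass, then convert to a quantized model.
prepared = prepare_fx(model, get_default_qconfig_mapping("qnnpack"), example_inputs)
prepared(*example_inputs)  # calibration
quantized = convert_fx(prepared)

# The graph now contains a quantized LeakyReLU, which the partitioner change
# above makes eligible for XNNPACK delegation.
torch.testing.assert_close(
    quantized(*example_inputs),
    model(*example_inputs),
    atol=0.1,
    rtol=0.1,
)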
