31 changes: 26 additions & 5 deletions backends/arm/operators/op_hardtanh.py
@@ -1,4 +1,4 @@
# Copyright 2023 Arm Limited and/or its affiliates.
# Copyright 2023-2024 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
@@ -11,6 +11,8 @@
register_node_visitor,
)
from executorch.backends.arm.tosa_mapping import TosaArg

from executorch.backends.arm.tosa_quant_utils import get_quant_node_args
from serializer.tosa_serializer import TosaOp


@@ -30,12 +32,31 @@ def define_node(
is_quant_node: bool,
) -> None:
attr = ts.TosaSerializerAttribute()

if is_quant_node:
# Get quant parameters
scale, zp = get_quant_node_args(node.all_input_nodes[0])
# Convert to quantized representation
clamp_min_qs = round((inputs[1].number / scale) + zp)
clamp_min_qs = max(clamp_min_qs, -128)
Reviewer comment (Contributor): we should get qmin/qmax from the quant node.
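A possible follow-up, sketched under the assumption that the quantize/dequantize node carries its integer bounds as FX node arguments (the real helper's signature is not visible in this diff, so the helper name and argument positions below are hypothetical), would be to clamp against qmin/qmax read from the node instead of the hard-coded -128/127:

# Hypothetical sketch only, reusing `node` and `inputs` from the code above.
# Assumes the q/dq node's args follow the usual
# (input, scale, zero_point, quant_min, quant_max, dtype) layout.
def get_quant_node_bounds(quant_node):
    scale, zero_point, qmin, qmax = quant_node.args[1:5]
    return scale, zero_point, qmin, qmax

scale, zp, qmin, qmax = get_quant_node_bounds(node.all_input_nodes[0])
clamp_min_qs = max(round(inputs[1].number / scale) + zp, qmin)
clamp_max_qs = min(round(inputs[2].number / scale) + zp, qmax)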

clamp_max_qs = round((inputs[2].number / scale) + zp)
clamp_max_qs = min(clamp_max_qs, 127)
# Set fp values to 0.0 since they are not used
clamp_min_fp = 0.0
clamp_max_fp = 0.0
else:
clamp_min_fp = inputs[1].number
clamp_max_fp = inputs[2].number
# Set qs values to 0 since they are not used
clamp_min_qs = 0
clamp_max_qs = 0

attr.ClampAttribute(
tosa_graph.builder,
int(inputs[1].number),
int(inputs[2].number),
inputs[1].number,
inputs[2].number,
clamp_min_qs,
clamp_max_qs,
clamp_min_fp,
clamp_max_fp,
)

tosa_graph.addOperator(TosaOp.Op().CLAMP, [inputs[0].name], [output.name], attr)
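For reference, the new quantized branch maps the floating-point clamp bounds into the integer domain with the usual affine rule q = round(x / scale) + zero_point and then clips to the int8 range. A standalone numeric illustration (the scale and zero point are made-up values, not taken from this PR):

# Standalone illustration of the clamp-bound quantization; values are arbitrary.
scale, zero_point = 0.05, -10          # example quantization parameters
clamp_min_fp, clamp_max_fp = 0.0, 6.0  # e.g. ReLU6 bounds

clamp_min_qs = max(round(clamp_min_fp / scale) + zero_point, -128)
clamp_max_qs = min(round(clamp_max_fp / scale) + zero_point, 127)

print(clamp_min_qs, clamp_max_qs)  # -10 110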
67 changes: 54 additions & 13 deletions backends/arm/test/ops/test_conv_combos.py
@@ -12,6 +12,7 @@
import torch
from executorch.backends.arm.test import common
from executorch.backends.arm.test.tester.arm_tester import ArmTester
from parameterized import parameterized

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
@@ -126,6 +127,32 @@ def forward(self, x):
return x


class ComboConvRelu6(torch.nn.Module):
edge_op_list = [
"executorch_exir_dialects_edge__ops_aten_convolution_default",
"executorch_exir_dialects_edge__ops_aten_hardtanh_default",
]

test_data = [
(20 * torch.randn(1, 3, 256, 256),),
(5 * torch.randn(1, 3, 256, 256),),
(torch.randn(1, 3, 256, 256),),
(-5 * torch.randn(1, 3, 256, 256),),
]

def __init__(self):
super().__init__()
self.conv2d = torch.nn.Conv2d(
in_channels=3, out_channels=3, kernel_size=3, stride=1, groups=1
)
self.relu6 = torch.nn.ReLU6()

def forward(self, x):
x = self.conv2d(x)
x = self.relu6(x)
return x
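Note that edge_op_list above expects aten_hardtanh_default rather than a dedicated relu6 op: torch.nn.ReLU6 is hardtanh with bounds (0, 6), which is presumably why the quantized hardtanh support added in this PR is what enables this combo. A quick standalone sanity check of that equivalence:

# Standalone sanity check (not part of the PR): ReLU6 matches hardtanh(0, 6).
import torch

x = 20 * torch.randn(1, 3, 256, 256)
assert torch.allclose(torch.nn.ReLU6()(x), torch.nn.functional.hardtanh(x, 0.0, 6.0))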


class TestConvCombos(unittest.TestCase):
def _test_conv_combo_tosa_MI_pipeline(
self, module: torch.nn.Module, test_data: Tuple[torch.Tensor]
@@ -222,15 +249,9 @@ def test_conv_batchnorm_relu_tosa_MI(self):
model = ComboConvBatchnormRelu()
self._test_conv_combo_tosa_MI_pipeline(model, model.get_inputs())

# TODO(MLETORCH-85): Investigate numerical issue. This diff is present in legacy
# testcase as well (and also not tested). For now, just increase the
# tolerance, such that we don't skip the test entirely (i.e. we maintain
# functionality).
def test_conv_batchnorm_relu_tosa_BI(self):
model = ComboConvBatchnormRelu()
self._test_conv_combo_tosa_BI_pipeline(
model, model.get_inputs(), atol=1.0, rtol=1.0
)
self._test_conv_combo_tosa_BI_pipeline(model, model.get_inputs())

@unittest.skipIf(
not common.VELA_INSTALLED,
@@ -240,21 +261,41 @@ def test_conv_batchnorm_relu_u55_BI(self):
model = ComboConvBatchnormRelu()
self._test_conv_combo_u55_BI_pipeline(model, model.get_inputs())

##################
## Conv + ReLU6 ##
##################
@parameterized.expand(ComboConvRelu6.test_data)
def test_conv_relu6_tosa_MI(self, test_data: torch.Tensor):
model = ComboConvRelu6()
test_data = (test_data,)
self._test_conv_combo_tosa_MI_pipeline(model, test_data)

@parameterized.expand(ComboConvRelu6.test_data)
def test_conv_relu6_tosa_BI(self, test_data: torch.Tensor):
model = ComboConvRelu6()
test_data = (test_data,)
self._test_conv_combo_tosa_BI_pipeline(model, test_data)

@parameterized.expand(ComboConvRelu6.test_data)
@unittest.skipIf(
not common.VELA_INSTALLED,
"There is no point in running U55 tests if the Vela tool is not installed",
)
def test_conv_relu6_u55_BI(self, test_data: torch.Tensor):
model = ComboConvRelu6()
test_data = (test_data,)
self._test_conv_combo_u55_BI_pipeline(model, test_data)

###############################
## Block bottleneck residual ##
###############################
def test_block_bottleneck_residual_tosa_MI(self):
model = ComboBlockBottleneckResidual()
self._test_conv_combo_tosa_MI_pipeline(model, model.get_inputs())

# TODO(MLETORCH-85): Investigate numerical issue. This diff was present in legacy
# testcase as well. For now, just increase the tolerance, such that
# we don't skip the test entirely (i.e. we maintain functionality).
def test_block_bottleneck_residual_tosa_BI(self):
model = ComboBlockBottleneckResidual()
self._test_conv_combo_tosa_BI_pipeline(
model, model.get_inputs(), atol=1.0, rtol=1.0
)
self._test_conv_combo_tosa_BI_pipeline(model, model.get_inputs())

@unittest.skipIf(
not common.VELA_INSTALLED,
2 changes: 1 addition & 1 deletion backends/xnnpack/test/tester/tester.py
@@ -595,7 +595,7 @@ def _assert_outputs_equal(model_output, ref_output, atol=1e-03, rtol=1e-03):
f"Output {i} does not match reference output.\n"
f"\tGiven atol: {atol}, rtol: {rtol}.\n"
f"\tOutput tensor shape: {model.shape}, dtype: {model.dtype}\n"
f"\tDifference: max: {torch.max(model-ref)}, abs: {torch.max(torch.abs(model-ref))}.\n"
f"\tDifference: max: {torch.max(model-ref)}, abs: {torch.max(torch.abs(model-ref))}, mean abs error: {torch.mean(torch.abs(model-ref))}.\n"
f"\t-- Model vs. Reference --\n"
f"\t Numel: {model.numel()}, {ref.numel()}\n"
f"\tMedian: {model.median()}, {ref.median()}\n"
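The added field reports the mean absolute error next to the existing max-difference diagnostics. A small standalone example of the three quantities with dummy tensors:

# Standalone example of the reported difference metrics; tensors are dummies.
import torch

model = torch.tensor([1.0, 2.0, 3.5])
ref = torch.tensor([1.0, 2.0, 3.0])

print(torch.max(model - ref))              # max signed difference: 0.5
print(torch.max(torch.abs(model - ref)))   # max absolute difference: 0.5
print(torch.mean(torch.abs(model - ref)))  # mean absolute error: ~0.1667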