Skip to content

Commit 9a46e90

Browse files
committed
Update on "[ET-VK] Tuning native layer norm local workgroup size to improve thread occupancy during reduce."
This diff is tuning the local workgroup size of the native layer norm operation in Vulkan backend of Executorch to improve thread occupancy during the reduce phase. Differential Revision: [D72581293](https://our.internmc.facebook.com/intern/diff/D72581293/) [ghstack-poisoned]
2 parents e0860b4 + 951aa92 commit 9a46e90

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

49 files changed

+1193
-144
lines changed

.ci/scripts/gather_benchmark_configs.py

+1
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
"samsung_galaxy_s22": "arn:aws:devicefarm:us-west-2:308535385114:devicepool:02a2cf0f-6d9b-45ee-ba1a-a086587469e6/e59f866a-30aa-4aa1-87b7-4510e5820dfa",
2424
"samsung_galaxy_s24": "arn:aws:devicefarm:us-west-2:308535385114:devicepool:02a2cf0f-6d9b-45ee-ba1a-a086587469e6/98f8788c-2e25-4a3c-8bb2-0d1e8897c0db",
2525
"google_pixel_8_pro": "arn:aws:devicefarm:us-west-2:308535385114:devicepool:02a2cf0f-6d9b-45ee-ba1a-a086587469e6/d65096ab-900b-4521-be8b-a3619b69236a",
26+
"google_pixel_3_private_rooted": "arn:aws:devicefarm:us-west-2:308535385114:devicepool:02a2cf0f-6d9b-45ee-ba1a-a086587469e6/98d23ca8-ea9e-4fb7-b725-d402017b198d",
2627
}
2728

2829
# Predefined benchmark configurations
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
name: android-perf (private devices)
2+
3+
on:
4+
schedule:
5+
- cron: 0 0,4,8,12,16,20 * * *
6+
pull_request:
7+
paths:
8+
- .github/workflows/android-perf-private-device-experiment.yml
9+
push:
10+
branches:
11+
- main
12+
paths:
13+
- .github/workflows/android-perf-private-device-experiment.yml
14+
# Note: GitHub has an upper limit of 10 inputs
15+
workflow_dispatch:
16+
inputs:
17+
models:
18+
description: Models to be benchmarked
19+
required: false
20+
type: string
21+
default: mv3,meta-llama/Llama-3.2-1B-Instruct-SpinQuant_INT4_EO8,meta-llama/Llama-3.2-1B-Instruct-QLORA_INT4_EO8
22+
devices:
23+
description: Target devices to run benchmark
24+
required: false
25+
type: string
26+
default: google_pixel_3_private_rooted
27+
benchmark_configs:
28+
description: The list of configs used the benchmark
29+
required: false
30+
type: string
31+
workflow_call:
32+
inputs:
33+
models:
34+
description: Models to be benchmarked
35+
required: false
36+
type: string
37+
default: mv3,meta-llama/Llama-3.2-1B-Instruct-SpinQuant_INT4_EO8,meta-llama/Llama-3.2-1B-Instruct-QLORA_INT4_EO8
38+
devices:
39+
description: Target devices to run benchmark
40+
required: false
41+
type: string
42+
default: google_pixel_3_private_rooted
43+
benchmark_configs:
44+
description: The list of configs used the benchmark
45+
required: false
46+
type: string
47+
48+
concurrency:
49+
group: android-perf-private-devices-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }}
50+
cancel-in-progress: true
51+
52+
jobs:
53+
android:
54+
uses: ./.github/workflows/android-perf.yml
55+
secrets: inherit
56+
permissions:
57+
id-token: write
58+
contents: read
59+
with:
60+
models: ${{ inputs.models }}
61+
devices: google_pixel_3_private_rooted
62+
benchmark_configs: ${{ inputs.benchmark_configs }}

.github/workflows/android-release-artifacts.yml

+2-1
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,8 @@ jobs:
4949
contents: read
5050
with:
5151
secrets-env: EXECUTORCH_MAVEN_SIGNING_KEYID EXECUTORCH_MAVEN_SIGNING_PASSWORD EXECUTORCH_MAVEN_CENTRAL_PASSWORD EXECUTORCH_MAVEN_CENTRAL_USERNAME EXECUTORCH_MAVEN_SIGNING_GPG_KEY_CONTENTS
52-
runner: linux.2xlarge
52+
# As this job has access to Maven credential, run this on a fresh ephemeral runner
53+
runner: ephemeral.linux.2xlarge
5354
docker-image: executorch-ubuntu-22.04-clang12-android
5455
submodules: 'recursive'
5556
ref: ${{ github.sha }}

backends/arm/_passes/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from .decompose_linear_pass import DecomposeLinearPass # noqa
2727
from .decompose_meandim_pass import DecomposeMeanDimPass # noqa
2828
from .decompose_select import DecomposeSelectPass # noqa
29+
from .decompose_silu_pass import DecomposeSiluPass # noqa
2930
from .decompose_softmax_pass import DecomposeSoftmaxPass # noqa
3031
from .decompose_softmax_unstable_pass import DecomposeSoftmaxUnstablePass # noqa
3132
from .decompose_sqrt_pass import DecomposeSqrtPass # noqa

backends/arm/_passes/arm_pass_manager.py

+2
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
DecomposeLinearPass,
3232
DecomposeMeanDimPass,
3333
DecomposeSelectPass,
34+
DecomposeSiluPass,
3435
DecomposeSoftmaxPass,
3536
DecomposeSoftmaxUnstablePass,
3637
DecomposeSqrtPass,
@@ -196,6 +197,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
196197
self.add_pass(DecomposeDivPass())
197198
self.add_pass(DecomposeLeakyReLUPass())
198199
self.add_pass(DecomposeSqrtPass())
200+
self.add_pass(DecomposeSiluPass())
199201

200202
if isinstance(self.tosa_spec, Tosa_0_80) and self.tosa_spec.is_U55_subset:
201203
# Numerically stable softmax uses amax which is not supported on Ethos-U55
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
# pyre-unsafe
7+
8+
import torch
9+
from executorch.exir.pass_base import ExportPass
10+
11+
aten_silu_ops = (torch.ops.aten.silu.default, torch.ops.aten.silu_.default)
12+
13+
14+
class DecomposeSiluPass(ExportPass):
15+
"""
16+
This pass decomposes silu into a mul and a sigmoid node.
17+
18+
Example:
19+
y = silu(a)
20+
Becomes:
21+
x = sigmoid(a)
22+
y = mul(a,x)
23+
"""
24+
25+
def call_operator(self, op, args, kwargs, meta):
26+
if op not in (aten_silu_ops):
27+
return super().call_operator(op, args, kwargs, meta)
28+
sigmoid_op = torch.ops.aten.sigmoid.default
29+
mul_op = torch.ops.aten.mul.Tensor
30+
31+
original = args[0]
32+
sigmoid = super().call_operator(sigmoid_op, (original,), {}, meta)
33+
34+
return super().call_operator(mul_op, (original, sigmoid), {}, meta)

backends/arm/_passes/match_arg_ranks_pass.py

+1
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ def __init__(self, exported_program):
4949
exir_ops.edge.aten.bitwise_left_shift.Tensor,
5050
exir_ops.edge.aten.eq.Tensor,
5151
exir_ops.edge.aten.gt.Tensor,
52+
exir_ops.edge.aten.ge.Tensor,
5253
exir_ops.edge.aten.lt.Tensor,
5354
exir_ops.edge.aten.pow.Tensor_Tensor,
5455
exir_ops.edge.aten.where.self,

backends/arm/_passes/replace_scalar_with_tensor_pass.py

+2
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
exir_ops.edge.aten.__lshift__.Scalar: exir_ops.edge.aten.bitwise_left_shift.Tensor,
2828
exir_ops.edge.aten.eq.Scalar: exir_ops.edge.aten.eq.Tensor,
2929
exir_ops.edge.aten.gt.Scalar: exir_ops.edge.aten.gt.Tensor,
30+
exir_ops.edge.aten.ge.Scalar: exir_ops.edge.aten.ge.Tensor,
3031
exir_ops.edge.aten.lt.Scalar: exir_ops.edge.aten.lt.Tensor,
3132
torch.ops.aten.add.Scalar: torch.ops.aten.add.Tensor,
3233
torch.ops.aten.sub.Scalar: torch.ops.aten.sub.Tensor,
@@ -36,6 +37,7 @@
3637
torch.ops.aten.__lshift__.Scalar: torch.ops.aten.bitwise_left_shift.Tensor,
3738
torch.ops.aten.eq.Scalar: torch.ops.aten.eq.Tensor,
3839
torch.ops.aten.gt.Scalar: torch.ops.aten.gt.Tensor,
40+
torch.ops.aten.ge.Scalar: torch.ops.aten.ge.Tensor,
3941
torch.ops.aten.lt.Scalar: torch.ops.aten.lt.Tensor,
4042
}
4143

backends/arm/operator_support/ethos_u55_support.py

+1
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,7 @@ class EthosU55NotSupported(OperatorSupportBase):
134134
exir_ops.edge.aten.eq.Tensor,
135135
exir_ops.edge.aten.eq.Scalar,
136136
exir_ops.edge.aten.ge.Tensor,
137+
exir_ops.edge.aten.ge.Scalar,
137138
exir_ops.edge.aten.gt.Tensor,
138139
exir_ops.edge.aten.gt.Scalar,
139140
exir_ops.edge.aten.le.Tensor,

backends/arm/operator_support/tosa_supported_operators.py

+2
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,7 @@ def is_node_supported(
178178
exir_ops.edge.aten.full.default,
179179
exir_ops.edge.aten.full_like.default,
180180
exir_ops.edge.aten.ge.Tensor,
181+
exir_ops.edge.aten.ge.Scalar,
181182
exir_ops.edge.aten.gt.Tensor,
182183
exir_ops.edge.aten.gt.Scalar,
183184
exir_ops.edge.aten.le.Tensor,
@@ -228,6 +229,7 @@ def is_node_supported(
228229
exir_ops.edge.aten.__lshift__.Scalar,
229230
torch.ops.aten.scalar_tensor.default,
230231
exir_ops.edge.aten.gelu.default,
232+
exir_ops.edge.aten.alias_copy.default,
231233
]
232234

233235
return supported

backends/arm/operators/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
op_erf,
2323
op_exp,
2424
op_ge,
25-
op_get_item,
2625
op_gt,
2726
op_le,
2827
op_log,
@@ -51,5 +50,6 @@
5150
op_view,
5251
op_where,
5352
ops_binary,
53+
ops_identity,
5454
ops_unary,
5555
)

backends/arm/operators/op_get_item.py

-35
This file was deleted.
+47
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
# pyre-unsafe
7+
8+
from typing import List
9+
10+
import torch
11+
import torch.fx
12+
13+
import tosa_tools.v0_80.serializer.tosa_serializer as ts
14+
15+
from executorch.backends.arm.operators.node_visitor import (
16+
NodeVisitor,
17+
register_node_visitor,
18+
)
19+
from executorch.backends.arm.tosa_mapping import TosaArg
20+
21+
22+
def identity_operator_factory(identity_target: str):
23+
"""
24+
Creates and registers NodeVisitors for operators that map directly
25+
to a TOSA IDENTITY op.
26+
"""
27+
28+
class IdentityOperatorVisitor(NodeVisitor):
29+
target = identity_target
30+
31+
def define_node(
32+
self,
33+
node: torch.fx.Node,
34+
tosa_graph: ts.TosaSerializer,
35+
inputs: List[TosaArg],
36+
output: TosaArg,
37+
) -> None:
38+
# Simply add an identityOp
39+
tosa_graph.addOperator(
40+
ts.TosaOp.Op().IDENTITY, [inputs[0].name], [output.name]
41+
)
42+
43+
register_node_visitor(IdentityOperatorVisitor)
44+
45+
46+
identity_operator_factory("getitem")
47+
identity_operator_factory("aten.alias_copy.default")

backends/arm/quantizer/arm_quantizer.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -286,10 +286,10 @@ def _annotate_all_static_patterns(
286286
quantization_config: Optional[QuantizationConfig],
287287
filter_fn: Optional[Callable[[Node], bool]] = None,
288288
) -> GraphModule:
289-
"""Loops over all STATIC_OPS and runs the corresponding registred annotator.
289+
"""Loops over all STATIC_OPS and runs the corresponding registered annotator.
290290
Args:
291291
model: The model to annotate statically.
292-
quantization_config: Specifices the QuantizationSpecs for the model's
292+
quantization_config: Specifies the QuantizationSpecs for the model's
293293
input activations, output activations, weights and biases.
294294
filter_fn: An optional filter function that takes a node and returns whether the node should be annotated.
295295
Returns:

backends/arm/quantizer/quantization_annotator.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -244,6 +244,11 @@ def _match_pattern(
244244
operator.getitem,
245245
]
246246

247+
_one_to_one_shared_input_or_input_act_qspec = [
248+
torch.ops.aten.adaptive_avg_pool2d.default,
249+
torch.ops.aten.alias_copy.default,
250+
]
251+
247252

248253
def get_quant_properties( # noqa: C901
249254
node: Node, gm: torch.fx.GraphModule, quantization_config
@@ -332,7 +337,7 @@ def any_or_hardtanh_min_zero(n: Node):
332337
_QuantProperty(2, shared_qspec), # type: ignore[arg-type]
333338
]
334339
quant_properties.quant_output = _QuantProperty(0, shared_qspec) # type: ignore[arg-type]
335-
elif node.target == torch.ops.aten.adaptive_avg_pool2d.default:
340+
elif node.target in _one_to_one_shared_input_or_input_act_qspec:
336341
input_qspec = (
337342
SharedQuantizationSpec(node.args[0]) # type: ignore[arg-type]
338343
if arm_quantizer_utils.is_output_annotated(node.args[0]) # type: ignore

0 commit comments

Comments
 (0)