Skip to content

Commit 334872a

Browse files
author
morelos
committed
Update on "[ET-VK][Ops] dequantize_per_token.default test setup"
Creating dequantize_per_token testing framework along with a reference implementation for testing Differential Revision: [D76267037](https://our.internmc.facebook.com/intern/diff/D76267037/) [ghstack-poisoned]
2 parents 208ddac + 49c1f77 commit 334872a

File tree

79 files changed

+1958
-701
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

79 files changed

+1958
-701
lines changed

.github/workflows/trunk.yml

Lines changed: 0 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -693,32 +693,3 @@ jobs:
693693
build-mode: Release
694694
build-tool: cmake
695695
docker-image: executorch-ubuntu-22.04-clang12
696-
697-
unittest-nxp-neutron:
698-
uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
699-
permissions:
700-
id-token: write
701-
contents: read
702-
with:
703-
runner: linux.2xlarge
704-
docker-image: executorch-ubuntu-22.04-clang12
705-
submodules: 'recursive'
706-
ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
707-
timeout: 90
708-
script: |
709-
set -eux
710-
711-
# The generic Linux job chooses to use base env, not the one setup by the image
712-
CONDA_ENV=$(conda env list --json | jq -r ".envs | .[-1]")
713-
conda activate "${CONDA_ENV}"
714-
715-
# Build and install Executorch
716-
PYTHON_EXECUTABLE=python \
717-
CMAKE_ARGS="-DEXECUTORCH_BUILD_NXP_NEUTRON=ON" \
718-
.ci/scripts/setup-linux.sh --build-tool "cmake"
719-
720-
# Install test requirements
721-
pip install -r backends/nxp/requirements-tests.txt
722-
723-
# Run pytest
724-
PYTHON_EXECUTABLE=python bash backends/nxp/run_unittests.sh

backends/arm/README.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -187,3 +187,9 @@ It is possible to control the compilation flow to aid in development and debug o
187187
Configuration of the EthosUBackend export flow is controlled by CompileSpec information (essentially used as compilation flags) to determine which of these outputs is produced. In particular this allows for use of the tosa_reference_model to run intermediate output to check for correctness and quantization accuracy without a full loop via hardware implementation.
188188

189189
As this is in active development see the EthosUBackend for accurate information on [compilation flags](https://github.com/pytorch/executorch/blob/29f6dc9353e90951ed3fae3c57ae416de0520067/backends/arm/arm_backend.py#L319-L324)
190+
191+
## Model specific and optional passes
192+
The current TOSA version does not support int64. For LLMs, for example Llama, aten.embedding is often the first operator, and it requires int64 indices.
193+
In order to lower this to TOSA, an int64->int32 cast needs to be injected. This pass needs to run very early in the lowering process and can be passed to the to_edge_transform_and_lower() function call as an optional parameter. See the example in backends/arm/test/models/test_llama.py.
194+
By doing this, aten.embedding will be decomposed into aten.index_select, which can handle int32 indices.
195+
Note that this additional step is only needed for pure float models. With quantization this is automatically handled during annotation before the export stage.

backends/arm/_passes/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from .convert_to_clamp import ConvertToClampPass # noqa
2323
from .decompose_cosine_similarity_pass import DecomposeCosineSimilarityPass # noqa
2424
from .decompose_div_pass import DecomposeDivPass # noqa
25+
from .decompose_embedding_pass import DecomposeEmbeddingPass # noqa # noqa
2526
from .decompose_gelu_pass import DecomposeGeluPass # noqa
2627
from .decompose_groupnorm_pass import DecomposeGroupNormPass # noqa
2728
from .decompose_layernorm_pass import DecomposeLayerNormPass # noqa
@@ -46,6 +47,9 @@
4647
from .fuse_constant_ops_pass import ComputeConstantOpsAOT, FuseConstantArgsPass # noqa
4748
from .fuse_equal_placeholders_pass import FuseEqualPlaceholdersPass # noqa
4849
from .fuse_quantized_activation_pass import FuseQuantizedActivationPass # noqa
50+
from .insert_int64_input_cast_pass import ( # noqa # noqa
51+
InsertCastForOpsWithInt64InputPass,
52+
)
4953
from .insert_rescales_pass import InsertRescalePass # noqa
5054
from .insert_table_ops import InsertTableOpsPass # noqa
5155
from .match_arg_ranks_pass import MatchArgRanksPass # noqa

backends/arm/_passes/arm_pass_manager.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
# LICENSE file in the root directory of this source tree.
77

88
# pyre-unsafe
9-
109
from executorch.backends.arm._passes import (
1110
AnnotateChannelsLastDimOrder,
1211
AnnotateDecomposedMatmulPass,
@@ -26,6 +25,7 @@
2625
ConvertToClampPass,
2726
DecomposeCosineSimilarityPass,
2827
DecomposeDivPass,
28+
DecomposeEmbeddingPass,
2929
DecomposeGeluPass,
3030
DecomposeGroupNormPass,
3131
DecomposeLayerNormPass,
@@ -46,6 +46,7 @@
4646
FuseConstantArgsPass,
4747
FuseEqualPlaceholdersPass,
4848
FuseQuantizedActivationPass,
49+
InsertCastForOpsWithInt64InputPass,
4950
InsertRescalePass,
5051
InsertTableOpsPass,
5152
MatchArgRanksPass,
@@ -139,6 +140,7 @@ def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
139140
self.add_pass(DecomposeSqrtPass())
140141
self.add_pass(ConvertIntPowToMuls())
141142
self.add_pass(ReplaceScalarWithTensorArgPassTOSAMI())
143+
self.add_pass(DecomposeEmbeddingPass())
142144
self.add_pass(FuseQuantizedActivationPass())
143145
self.add_pass(RemoveGetItemPass())
144146
self.add_pass(ConvertSplitToSlicePass())
@@ -211,6 +213,8 @@ def transform_to_backend_pipeline(self, exported_program: ExportedProgram):
211213
)
212214

213215
def transform_for_annotation_pipeline(self, graph_module: GraphModule):
216+
self.add_pass(InsertCastForOpsWithInt64InputPass())
217+
self.add_pass(DecomposeEmbeddingPass())
214218
self.add_pass(DecomposeScaledDotProductAttention())
215219
self.add_pass(ReplaceScalarWithTensorArgPassTOSABI())
216220
self.add_pass(ScalarsToAttributePass())
Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
# pyre-unsafe
7+
8+
9+
import logging
10+
from math import prod
11+
12+
import torch
13+
from executorch.exir.dialects._ops import ops as exir_ops
14+
from executorch.exir.pass_base import ExportPass, PassResult
15+
16+
from .arm_pass_utils import create_node, get_first_fake_tensor
17+
18+
logger = logging.getLogger(__name__)
19+
logger.setLevel(logging.WARNING)
20+
21+
22+
class DecomposeEmbeddingPass(ExportPass):
    """
    This pass decomposes embedding into index_select.

    Example:
          o = embedding(w, i)
    Becomes:
          i = view_copy(i)  # flatten indices
          o = index_select(w, i)
          o = view_copy(o)  # reshape back output
    Note:
          i = indices is expected to be int32 before this pass
    """

    aten_ops = (torch.ops.aten.embedding.default,)
    edge_ops = (exir_ops.edge.aten.embedding.default,)

    def get_decomposition(self, op):
        """Return the ``(view_copy, index_select)`` ops matching the dialect of ``op``.

        Args:
            op: The embedding operator target found on a graph node.

        Returns:
            A 2-tuple of ATen ops when ``op`` is in ``aten_ops``, or of edge
            dialect ops when ``op`` is in ``edge_ops``.

        Raises:
            RuntimeError: If ``op`` is in neither ``aten_ops`` nor ``edge_ops``.
        """
        if op in self.aten_ops:
            return (
                torch.ops.aten.view_copy.default,
                torch.ops.aten.index_select.default,
            )

        if op in self.edge_ops:
            return (
                exir_ops.edge.aten.view_copy.default,
                exir_ops.edge.aten.index_select.default,
            )
        raise RuntimeError(
            f"[{self.__class__.__name__}] Can't get decomposition for op {op}"
        )

    def call(self, graph_module):
        """Replace every embedding node with view_copy -> index_select -> view_copy.

        Args:
            graph_module: The graph module whose graph is rewritten in place.

        Returns:
            A ``PassResult`` holding the resulting graph module and a flag that
            is True only if at least one embedding node was decomposed.

        Raises:
            RuntimeError: If the shape computed from the indices and weights
                differs from the shape recorded for the embedding node.
        """
        graph = graph_module.graph
        modified_graph = False
        embedding_ops = self.aten_ops + self.edge_ops

        # Iterate over a snapshot: nodes are inserted and erased inside the loop.
        for node in list(graph.nodes):
            if node.op != "call_function":
                continue
            if node.target not in embedding_ops:
                continue

            args = node.args

            weights = args[0]
            indices = args[1]

            weights_shape = get_first_fake_tensor(weights).shape
            indices_shape = get_first_fake_tensor(indices).shape

            # Embedding output has the indices' shape plus a trailing dim of
            # size weights_shape[1].
            output_shape = torch.Size(list(indices_shape) + [weights_shape[1]])
            recorded_shape = get_first_fake_tensor(node).shape
            if output_shape != recorded_shape:
                # Both halves of the message must be interpolated; the second
                # half was previously a plain string and printed the braces
                # literally.
                raise RuntimeError(
                    f"[{self.__class__.__name__}] Unexpected output shape mismatch {output_shape} "
                    f"!= {recorded_shape}"
                )

            view_copy_op, index_select_op = self.get_decomposition(node.target)

            with graph.inserting_before(node):
                # Flatten the indices to 1-D before selecting along dim 0.
                reshaped_indices = [prod(list(indices_shape))]
                flattened_indices = create_node(
                    graph=graph,
                    op_target=view_copy_op,
                    args=(indices, reshaped_indices),
                )
                node.replace_input_with(indices, flattened_indices)

                index_select = create_node(
                    graph=graph,
                    op_target=index_select_op,
                    args=(weights, 0, flattened_indices),
                )
                node.replace_all_uses_with(index_select)
                graph.erase_node(node)

            with graph.inserting_after(index_select):
                restored_output = create_node(
                    graph,
                    view_copy_op,
                )
                restored_output.args = (
                    index_select,
                    output_shape,
                )
                # Redirect the former consumers of the embedding to the
                # reshaped result; the reshape node itself must keep reading
                # from index_select.
                original_users = [
                    user for user in index_select.users if user != restored_output
                ]
                for user in original_users:
                    user.replace_input_with(index_select, restored_output)

            modified_graph = True

        if modified_graph:
            graph.eliminate_dead_code()
            graph_module.recompile()
            # Re-run the base ExportPass.call on the rewritten module
            # (presumably to refresh node metadata -- TODO confirm).
            graph_module = super().call(graph_module).graph_module
        return PassResult(graph_module, modified_graph)
Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
# pyre-unsafe
7+
8+
9+
import logging
10+
11+
import torch
12+
from executorch.exir.dialects._ops import ops as exir_ops
13+
from executorch.exir.pass_base import ExportPass, PassResult
14+
15+
from .arm_pass_utils import create_node, get_first_fake_tensor
16+
17+
logger = logging.getLogger(__name__)
18+
logger.setLevel(logging.WARNING)
19+
20+
21+
class InsertCastForOpsWithInt64InputPass(ExportPass):
    """
    This pass inserts a cast to int32 in front of the indices input of
    embedding nodes.

    Example:
          o = embedding(w, i)
    Becomes:
          i = _to_copy(i, dtype=int32)
          o = embedding(w, i)
    """

    aten_ops = (torch.ops.aten.embedding.default,)
    edge_ops = (exir_ops.edge.aten.embedding.default,)

    def get_decomposition(self, op):
        """Return the ``_to_copy`` op matching the dialect of ``op``.

        Args:
            op: The operator target found on a graph node.

        Returns:
            The edge dialect ``_to_copy`` op when ``op`` is in ``edge_ops``, or
            the ATen ``_to_copy`` op when ``op`` is in ``aten_ops``.

        Raises:
            RuntimeError: If ``op`` is in neither ``edge_ops`` nor ``aten_ops``.
        """
        if op in self.edge_ops:
            return exir_ops.edge.aten._to_copy.default

        if op in self.aten_ops:
            return torch.ops.aten._to_copy.default

        raise RuntimeError(
            f"[{self.__class__.__name__}] Can't get decomposition for op {op}"
        )

    def _check_aten_embedding_within_int32(self, weights, indices, node: torch.fx.Node):
        """Check whether the first dimension of ``weights`` is below the int32 maximum.

        Args:
            weights: The weights argument of the embedding node.
            indices: The indices argument of the embedding node (not inspected).
            node: The embedding node; its name is used in the warning.

        Returns:
            True if the cast can be inserted, False (after logging a warning)
            otherwise.
        """
        weights_shape = get_first_fake_tensor(weights).shape
        vocab_size = weights_shape[0]

        # Essentially output = weight[indices] which means 0 <= indices[i] < vocab_size
        # So should be good if vocab size or number embeddings is below max int32
        if vocab_size >= torch.iinfo(torch.int32).max:
            # The two literals are concatenated, so the separating space has
            # to be written explicitly.
            logger.warning(
                f"[{node.name}] has size ({vocab_size}) that exceeds int32 limit, "
                "so aten.embedding will not be lowered to TOSA."
            )
            return False

        return True

    def call(self, graph_module):
        """Insert an int32 ``_to_copy`` before the indices of each eligible embedding.

        Args:
            graph_module: The graph module whose graph is rewritten in place.

        Returns:
            A ``PassResult`` holding the resulting graph module and a flag that
            is True only if at least one cast was inserted.
        """
        graph = graph_module.graph
        modified_graph = False
        handled_ops = self.aten_ops + self.edge_ops

        # Iterate over a snapshot since nodes are inserted inside the loop.
        for node in list(graph.nodes):
            if node.op != "call_function":
                continue
            if node.target not in handled_ops:
                continue

            args = node.args
            weights = args[0]
            indices = args[1]

            valid_for_insert = False
            if node.target in (
                exir_ops.edge.aten.embedding.default,
                torch.ops.aten.embedding.default,
            ):
                valid_for_insert = self._check_aten_embedding_within_int32(
                    weights, indices, node
                )

            if valid_for_insert:
                to_copy_op = self.get_decomposition(node.target)
                with graph.inserting_before(node):
                    cast_before = create_node(
                        graph,
                        to_copy_op,
                        args=(indices,),
                        kwargs={
                            "dtype": torch.int32,
                            "memory_format": torch.preserve_format,
                        },
                    )
                    node.replace_input_with(indices, cast_before)

                modified_graph = True

        if modified_graph:
            graph_module.recompile()
            # Re-run the base ExportPass.call on the rewritten module
            # (presumably to refresh node metadata -- TODO confirm).
            graph_module = super().call(graph_module).graph_module
        # Report the real modification state instead of an unconditional True.
        return PassResult(graph_module, modified_graph)

backends/arm/operator_support/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,9 @@
77

88
from . import ( # noqa
99
convolution_support,
10+
embedding_support,
1011
ethos_u55_support,
12+
index_select_support,
1113
minmax_support,
1214
pool_2d_support,
1315
reduce_sum_support,
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
7+
import torch
8+
9+
import torch.fx as fx
10+
from executorch.backends.arm.operator_support.tosa_supported_operators import (
11+
register_tosa_support_check,
12+
SupportedTOSAOperatorCheck,
13+
)
14+
from executorch.backends.arm.tosa_specification import TosaSpecification
15+
from executorch.exir.dialects._ops import ops as exir_ops
16+
17+
18+
@register_tosa_support_check
class EmbeddingSupported(SupportedTOSAOperatorCheck):
    """Support check for ``aten.embedding.default`` in the edge dialect.

    A node is accepted only when it has exactly two input nodes and the
    second one (the indices) has dtype int32.
    """

    targets = [exir_ops.edge.aten.embedding.default]

    tosa_specs = [
        TosaSpecification.create_from_string("TOSA-0.80+BI"),
        TosaSpecification.create_from_string("TOSA-0.80+MI"),
        TosaSpecification.create_from_string("TOSA-1.0+INT"),
        TosaSpecification.create_from_string("TOSA-1.0+FP"),
    ]

    def is_node_tosa_supported(
        self, node: fx.Node, tosa_spec: TosaSpecification
    ) -> bool:  # type: ignore[override, misc]
        """Return True if ``node`` is an embedding that can be lowered.

        Args:
            node: The embedding node to check.
            tosa_spec: The targeted TOSA specification (not inspected here).

        Returns:
            False, after reporting a reject reason, when the node does not have
            exactly two input nodes or its indices are not int32; True
            otherwise.
        """
        # Note aten.embedding.default requires int64 indices and TOSA does not support it.
        # Int32 indices here for aten.embedding.default is ok since it will be decomposed into ops that can handle it.
        input_nodes = node.all_input_nodes
        # Reject explicitly rather than assert: asserts vanish under -O, and a
        # support check should answer "not supported" instead of aborting.
        if len(input_nodes) != 2:
            self.reporter.report_reject(
                node,
                f"Number of inputs to {node.target} is {len(input_nodes)}, expected 2.",
            )
            return False

        indices_val = input_nodes[1].meta["val"]
        indices_dtype = indices_val.dtype

        if indices_dtype != torch.int32:
            self.reporter.report_reject(
                node,
                f"Indices dtype {indices_dtype} is not supported in {node.target}.",
            )
            return False

        return True

0 commit comments

Comments
 (0)