
Commit 17060c5

Update on "[ET-VK] Adding all tensor packing support for native layer norm."
This diff updates the ExecuTorch Vulkan backend's `native layer norm` operation to support width-, height-, and channel-packed tensors, and adds new test cases to cases.py to exercise the operation. Differential Revision: [D71663678](https://our.internmc.facebook.com/intern/diff/D71663678/) [ghstack-poisoned]
2 parents 31fa1aa + 1c9e7dd commit 17060c5

File tree

105 files changed (+2013, -819 lines)


.github/scripts/extract_benchmark_results.py

Lines changed: 2 additions & 1 deletion
@@ -360,6 +360,7 @@ def transform(
             "app_type": app_type,
             # Just keep a copy of the benchmark config here
             "benchmark_config": json.dumps(benchmark_config),
+            "job_conclusion": "SUCCESS",
         },
     },
     "model": {
@@ -455,7 +456,7 @@ def transform_failure_record(
         },
         "metric": {
             "name": "FAILURE_REPORT",
-            "benchmark_values": 0,
+            "benchmark_values": [0],
             "target_value": 0,
             "extra_info": {
                 "method": "",

CMakeLists.txt

Lines changed: 11 additions & 31 deletions
@@ -645,13 +645,18 @@ target_link_options_shared_lib(executorch)
 # Real integrations should supply their own YAML file that only lists the
 # operators necessary for the models that will run.
 #
+if(EXECUTORCH_BUILD_KERNELS_OPTIMIZED)
+  # find pytorch lib here to make it available to all
+  # sub-directories. Find it before including portable so that
+  # optimized_portable_kernels can use it.
+  find_package_torch_headers()
+endif()
+
 if(BUILD_EXECUTORCH_PORTABLE_OPS)
   add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/kernels/portable)
 endif()

 if(EXECUTORCH_BUILD_KERNELS_OPTIMIZED)
-  # find pytorch lib here to make it available to all sub-directories
-  find_package_torch_headers()
   add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/kernels/optimized)
 endif()

@@ -764,10 +769,6 @@ if(EXECUTORCH_BUILD_EXTENSION_MODULE)
   add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/extension/module)
 endif()

-if(EXECUTORCH_BUILD_EXTENSION_TRAINING)
-  add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/extension/training)
-endif()
-
 if(EXECUTORCH_BUILD_EXTENSION_RUNNER_UTIL)
   add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/extension/runner_util)
 endif()
@@ -872,34 +873,13 @@ if(EXECUTORCH_BUILD_PYBIND)

   if(EXECUTORCH_BUILD_EXTENSION_TRAINING)

-    set(_pybind_training_dep_libs
-      ${TORCH_PYTHON_LIBRARY}
-      etdump
-      executorch
-      util
-      torch
-      extension_training
-    )
-
-    if(EXECUTORCH_BUILD_XNNPACK)
-      # need to explicitly specify XNNPACK and microkernels-prod
-      # here otherwise uses XNNPACK and microkernel-prod symbols from libtorch_cpu
-      list(APPEND _pybind_training_dep_libs xnnpack_backend XNNPACK microkernels-prod)
-    endif()
-
-    # pybind training
-    pybind11_add_module(_training_lib SHARED extension/training/pybindings/_training_lib.cpp)
-
-    target_include_directories(_training_lib PRIVATE ${TORCH_INCLUDE_DIRS})
-    target_compile_options(_training_lib PUBLIC ${_pybind_compile_options})
-    target_link_libraries(_training_lib PRIVATE ${_pybind_training_dep_libs})
-
-    install(TARGETS _training_lib
-      LIBRARY DESTINATION executorch/extension/training/pybindings
-    )
   endif()
 endif()

+if(EXECUTORCH_BUILD_EXTENSION_TRAINING)
+  add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/extension/training)
+endif()
+
 if(EXECUTORCH_BUILD_KERNELS_CUSTOM)
   # TODO: move all custom kernels to ${CMAKE_CURRENT_SOURCE_DIR}/kernels/custom
   add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/extension/llm/custom_ops)

backends/apple/coreml/scripts/install_requirements.sh

Lines changed: 1 addition & 1 deletion
@@ -12,7 +12,7 @@ SCRIPT_DIR_PATH="$(

 # TODO(jathu): remove the need to fetch coremltools to build deps for coreml_executor_runner.
 # Keep this version in sync with: pyproject.toml
-COREMLTOOLS_VERSION="8.1"
+COREMLTOOLS_VERSION="8.2"

 red=`tput setaf 1`
 green=`tput setaf 2`

backends/apple/coreml/test/test_coreml_quantizer.py

Lines changed: 3 additions & 1 deletion
@@ -32,7 +32,9 @@ def quantize_and_compare(
 ) -> None:
     assert quantization_type in {"PTQ", "QAT"}

-    pre_autograd_aten_dialect = export_for_training(model, example_inputs).module()
+    pre_autograd_aten_dialect = export_for_training(
+        model, example_inputs, strict=True
+    ).module()

     quantization_config = LinearQuantizerConfig.from_dict(
         {

backends/apple/mps/test/test_mps_utils.py

Lines changed: 1 addition & 1 deletion
@@ -207,7 +207,7 @@ def lower_module_and_test_output(
     expected_output = model(*sample_inputs)

     model = torch.export.export_for_training(
-        model, sample_inputs, dynamic_shapes=dynamic_shapes
+        model, sample_inputs, dynamic_shapes=dynamic_shapes, strict=True
     ).module()

     edge_program = export_to_edge(
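
Both this change and the CoreML quantizer test above pin `strict=True` on `torch.export.export_for_training` instead of relying on the default, which has shifted between PyTorch releases. A minimal sketch of the call pattern (the module and inputs are placeholders):

```python
# Minimal sketch of the export call the tests now use. MyModule and
# sample_inputs are placeholders; strict=True requests strict
# (TorchDynamo-based) tracing explicitly rather than relying on the default.
import torch


class MyModule(torch.nn.Module):
    def forward(self, x):
        return torch.nn.functional.layer_norm(x, x.shape[-1:])


sample_inputs = (torch.randn(2, 8),)
exported = torch.export.export_for_training(MyModule(), sample_inputs, strict=True)
pre_autograd_module = exported.module()  # pre-autograd ATen-dialect GraphModule
```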

backends/arm/_passes/__init__.py

Lines changed: 5 additions & 0 deletions
@@ -7,6 +7,7 @@
 from . import arm_pass_utils  # noqa
 from .annotate_channels_last_dim_order_pass import AnnotateChannelsLastDimOrder  # noqa
 from .annotate_decomposed_matmul import AnnotateDecomposedMatmulPass  # noqa
+from .arm_pass import ArmPass  # noqa
 from .cast_int64_pass import CastInt64BuffersToInt32Pass  # noqa
 from .cast_to_int32_pass import CastToInt32Pass  # noqa
 from .conv1d_unsqueeze_pass import Conv1dUnsqueezePass  # noqa
@@ -41,6 +42,10 @@
 from .meandim_to_averagepool_pass import ConvertMeanDimToAveragePoolPass  # noqa
 from .mm_to_bmm_pass import ConvertMmToBmmPass  # noqa
 from .remove_clone_pass import RemoveClonePass  # noqa
+from .replace_scalar_with_tensor_pass import (  # noqa
+    ReplaceScalarWithTensorArgPassTOSABI,
+    ReplaceScalarWithTensorArgPassTOSAMI,
+)
 from .scalars_to_attribute_pass import ScalarsToAttributePass  # noqa
 from .size_adjust_conv2d_pass import SizeAdjustConv2DPass  # noqa
 from .unsqueeze_before_repeat_pass import UnsqueezeBeforeRepeatPass  # noqa

backends/arm/_passes/arm_pass.py

Lines changed: 33 additions & 0 deletions
@@ -0,0 +1,33 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-unsafe
+
+import traceback
+from typing import Optional
+
+import torch
+from executorch.exir.pass_base import ExportPass, NodeMetadata
+
+
+class ArmPass(ExportPass):
+    """Base class for Arm passes"""
+
+    def __init__(self, exported_program: Optional[torch.export.ExportedProgram] = None):
+        super(ArmPass, self).__init__()
+        self.exported_program = exported_program
+
+    def call_operator(self, op, args, kwargs, meta, updated: Optional[bool] = False):
+        if not updated:
+            return super().call_operator(op, args, kwargs, meta)
+
+        # if updated we should update metadata
+        new_meta = {}
+        keys = meta.data.keys()
+        for key in keys:
+            new_meta[key] = meta[key]
+        old_stack_trace = new_meta.get("stack_trace", "")
+        new_meta["stack_trace"] = f"{old_stack_trace}\n{traceback.format_stack()[-2]}"
+        return super().call_operator(op, args, kwargs, NodeMetadata(new_meta))
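
The `updated` flag is how subclasses (such as the decomposed layer norm pass later in this commit) signal that an operator was rewritten, so the base class copies the old metadata and appends the current stack frame to `stack_trace`. A hypothetical subclass to illustrate the pattern (the abs-to-neg rewrite is invented for the example):

```python
# Hypothetical pass illustrating ArmPass.call_operator's `updated` flag.
# The abs -> neg rewrite is invented for the example; only the updated=True
# metadata handling mirrors the new base class above.
from executorch.backends.arm._passes.arm_pass import ArmPass
from executorch.exir.dialects._ops import ops as exir_ops


class ReplaceAbsWithNegPass(ArmPass):
    def call_operator(self, op, args, kwargs, meta):
        if op != exir_ops.edge.aten.abs.default:
            # Untouched ops keep their metadata as-is.
            return super().call_operator(op, args, kwargs, meta)
        # The op is rewritten, so pass updated=True to copy the old meta
        # and extend its stack trace with the current frame.
        return super().call_operator(
            exir_ops.edge.aten.neg.default, args, kwargs, meta, updated=True
        )
```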

backends/arm/_passes/arm_pass_manager.py

Lines changed: 6 additions & 7 deletions
@@ -42,18 +42,17 @@
     MatchArgRanksPass,
     QuantizeOperatorArguments,
     RemoveClonePass,
+    ReplaceScalarWithTensorArgPassTOSABI,
+    ReplaceScalarWithTensorArgPassTOSAMI,
     RetraceFoldedDtypesPass,
     ScalarsToAttributePass,
     SizeAdjustConv2DPass,
     UnsqueezeBeforeRepeatPass,
     UnsqueezeScalarPlaceholdersPass,
 )
+
 from executorch.backends.arm.tosa_specification import Tosa_0_80, TosaSpecification
 from executorch.backends.transforms.fuse_view_copy import FuseViewCopyTransform
-
-from executorch.backends.transforms.replace_scalar_with_tensor import (
-    ReplaceScalarWithTensorArgPass,
-)
 from executorch.backends.xnnpack._passes.remove_getitem_op import RemoveGetItemPass
 from executorch.exir import ExportedProgram
 from executorch.exir.pass_manager import PassManager
@@ -84,7 +83,7 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModule
         if isinstance(self.tosa_spec, Tosa_0_80) and self.tosa_spec.is_U55_subset:
             self.add_pass(CastToInt32Pass())

-        self.add_pass(ReplaceScalarWithTensorArgPass())
+        self.add_pass(ReplaceScalarWithTensorArgPassTOSABI())
         self.add_pass(AnnotateDecomposedMatmulPass())
         self.add_pass(QuantizeOperatorArguments())
         self.add_pass(FoldAndAnnotateQParamsPass())  # type: ignore[call-arg]
@@ -113,7 +112,7 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModule
         return self._transform(exported_program.graph_module)

     def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
-        self.add_pass(ReplaceScalarWithTensorArgPass())
+        self.add_pass(ReplaceScalarWithTensorArgPassTOSAMI())
         self.add_pass(FuseQuantizedActivationPass())
         self.add_pass(RemoveGetItemPass())
         self.add_pass(ConvertSplitToSlicePass())
@@ -170,7 +169,7 @@ def transform_to_backend_pipeline(self, exported_program: ExportedProgram):
         )

     def transform_for_annotation_pipeline(self, graph_module: GraphModule):
-        self.add_pass(ReplaceScalarWithTensorArgPass())
+        self.add_pass(ReplaceScalarWithTensorArgPassTOSABI())
         self.add_pass(ScalarsToAttributePass())
         self.add_pass(DecomposeLayerNormPass())
         self.add_pass(DecomposeVarPass())

backends/arm/_passes/arm_pass_utils.py

Lines changed: 16 additions & 2 deletions
@@ -7,12 +7,12 @@

 # pyre-unsafe

+import traceback
 from inspect import isclass
 from typing import Optional, Sequence

 import torch
 import torch.fx
-
 from executorch.exir import ExportedProgram
 from executorch.exir.dialects._ops import ops as exir_ops

@@ -96,6 +96,7 @@ def create_node(
     kwargs: Optional[dict] = None,
     quantize: bool = False,
     q_params: Optional[tuple] = None,
+    from_node: Optional[torch.fx.Node] = None,
 ):
     """
     Adds a node to 'graph'. graph.inserting_before/after() should be used before the call to decide where to insert the node.
@@ -108,15 +109,26 @@ def create_node(
         args=args,
         kwargs=kwargs or {},
     )
+
+    new_meta = {}
+    if from_node:
+        keys = from_node.meta.keys()
+        for key in keys:
+            new_meta[key] = from_node.meta[key]
+    old_stack_trace = new_meta.get("stack_trace", "")
+    new_meta["stack_trace"] = f"{old_stack_trace}\n{traceback.format_stack()[-2]}"
+    node.meta = new_meta
+
     if quantize and q_params:
-        return insert_q_dq_pair(graph, node, q_params)
+        return insert_q_dq_pair(graph, node, q_params, from_node)
     return node


 def insert_q_dq_pair(
     graph: torch.fx.Graph,
     anchor: torch.fx.Node,
     q_params: tuple,
+    from_node: Optional[torch.fx.Node] = None,
 ):
     """
     Inserts a q dq node pair after the node 'anchor'.
@@ -127,13 +139,15 @@ def insert_q_dq_pair(
         graph=graph,
         op_target=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
         args=(),  # We add the argument last
+        from_node=from_node if from_node else anchor,
     )
     q.meta = anchor.meta
     with graph.inserting_after(q):
         dq = create_node(
             graph=graph,
             op_target=exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
             args=(q,) + q_params,
+            from_node=from_node if from_node else anchor,
         )
     dq.meta = q.meta
     anchor.replace_all_uses_with(dq)
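
Taken together, `from_node` lets decomposition passes thread the original node's metadata through every node they create, including inserted q/dq pairs. A short sketch of the intended call pattern (the graph, node, and operands are placeholders):

```python
# Sketch of how a decomposition pass would use the new from_node argument.
# graph, orig_node, x and dims are placeholders; the op target mirrors the
# var op used by the layer norm decomposition below.
from executorch.backends.arm._passes.arm_pass_utils import create_node
from executorch.exir.dialects._ops import ops as exir_ops


def emit_var(graph, orig_node, x, dims):
    with graph.inserting_before(orig_node):
        # The new node inherits orig_node's meta, with the current stack
        # frame appended to its "stack_trace" entry.
        return create_node(
            graph,
            exir_ops.edge.aten.var.correction,
            args=(x, dims),
            kwargs={"correction": 0, "keepdim": True},
            from_node=orig_node,
        )
```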

backends/arm/_passes/decompose_layernorm_pass.py

Lines changed: 30 additions & 8 deletions
@@ -9,9 +9,10 @@
 import operator

 import torch
+from executorch.backends.arm._passes import ArmPass
 from executorch.backends.arm._passes.arm_pass_utils import create_node
 from executorch.exir.dialects._ops import ops as exir_ops
-from executorch.exir.pass_base import ExportPass, PassResult
+from executorch.exir.pass_base import PassResult


 def get_layer_norm_decomposition(op) -> tuple:
@@ -40,7 +41,7 @@ def get_layer_norm_decomposition(op) -> tuple:
     raise RuntimeError(f"Can't get layer_norm composition for op {op}")


-class DecomposeLayerNormPass(ExportPass):
+class DecomposeLayerNormPass(ArmPass):
     """
     layernorm is defined as: ((x - E[x]) / sqrt(Var[x] + eps)) * weights + bias
     Decompose layernorm(x, normalized_shape, weights, bias, eps) to a sequence of:
@@ -111,35 +112,56 @@ def call(self, graph_module: torch.fx.GraphModule):
                     var_op,
                     args=(x, dims),
                     kwargs={"correction": 0, "keepdim": keepdim},
+                    from_node=node,
                 )
                 full = create_node(
                     graph_module.graph,
                     full_op,
                     args=(epsilon_reshaped_shape, epsilon),
                     kwargs={"dtype": dtype},
+                    from_node=node,
+                )
+                add0 = create_node(
+                    graph_module.graph, add_op, args=(var, full), from_node=node
+                )
+                rsqrt = create_node(
+                    graph_module.graph, rsqrt_op, args=(add0,), from_node=node
+                )
+                mul0 = create_node(
+                    graph_module.graph, mul_op, args=(sub, rsqrt), from_node=node
                 )
-                add0 = create_node(graph_module.graph, add_op, args=(var, full))
-                rsqrt = create_node(graph_module.graph, rsqrt_op, args=(add0,))
-                mul0 = create_node(graph_module.graph, mul_op, args=(sub, rsqrt))
                 if weights is not None:
                     weights_reshaped = create_node(
                         graph_module.graph,
                         view_op,
                         args=(weights, weights_reshaped_shape),
+                        from_node=node,
                     )
                     mul1 = create_node(
-                        graph_module.graph, mul_op, args=(mul0, weights_reshaped)
+                        graph_module.graph,
+                        mul_op,
+                        args=(
+                            mul0,
+                            weights_reshaped,
+                        ),
+                        from_node=node,
                     )
                 else:
                     mul1 = mul0
                 output = mul1
                 if bias is not None:
                     bias_reshaped_shape = weights_reshaped_shape
                     bias_reshaped = create_node(
-                        graph_module.graph, view_op, args=(bias, bias_reshaped_shape)
+                        graph_module.graph,
+                        view_op,
+                        args=(bias, bias_reshaped_shape),
+                        from_node=node,
                     )
                     output = create_node(
-                        graph_module.graph, add_op, args=(mul1, bias_reshaped)
+                        graph_module.graph,
+                        add_op,
+                        args=(mul1, bias_reshaped),
+                        from_node=node,
                    )

            users = [user for user in node.users if node != user]
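The docstring's definition, ((x - E[x]) / sqrt(Var[x] + eps)) * weights + bias, can be checked quickly in eager mode; the decomposition above is exactly this algebra expressed as edge-dialect nodes. A small sanity check (illustrative, not part of the commit):

```python
# Eager-mode sanity check that the decomposition in the docstring matches
# torch.nn.functional.layer_norm. Shapes and eps are arbitrary.
import torch

x = torch.randn(2, 3, 8)
normalized_shape = (8,)
weights, bias, eps = torch.randn(8), torch.randn(8), 1e-5

mean = x.mean(dim=-1, keepdim=True)
var = x.var(dim=-1, unbiased=False, keepdim=True)  # matches correction=0
decomposed = (x - mean) * torch.rsqrt(var + eps) * weights + bias

reference = torch.nn.functional.layer_norm(x, normalized_shape, weights, bias, eps)
torch.testing.assert_close(decomposed, reference)
```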