Commit bd4455c

Update on "[ET-VK] Improve packing format for int4 linear operator + misc improvements"
## Context

Improve the performance of the quantized int4 linear shader by packing the scales and zeros tensor, as well as the weight tensor, in a more optimal way. See the comments in the `pack_int4_linear_weight_transposed_interleave` shader for more details about how the new packing works.

## Changes

* Split int8 quantized linear and int4 quantized linear into separate C++ files for better code organization
* Introduce a packing shader for int4 weights
* Update the int4 linear shader to account for the packed weights

## Impact

This change massively improves the performance of the weight-quantized int4 linear operator. With this change, running LLaMa 3.2 1B now achieves 10 tok/s, up from 0.9 tok/s, on an Adreno 740. This is a 10x improvement!

With this change:

```
/home/ssjia/scratch/bin/app_bin: 1 file pushed, 0 skipped. 332.3 MB/s (74692800 bytes in 0.214s)
I 00:00:00.003353 executorch:cpuinfo_utils.cpp:62] Reading file /sys/devices/soc0/image_version
I 00:00:00.003533 executorch:cpuinfo_utils.cpp:78] Failed to open midr file /sys/devices/soc0/image_version
I 00:00:00.003563 executorch:cpuinfo_utils.cpp:91] Reading file /sys/devices/system/cpu/cpu0/regs/identification/midr_el1
I 00:00:00.003685 executorch:cpuinfo_utils.cpp:91] Reading file /sys/devices/system/cpu/cpu1/regs/identification/midr_el1
I 00:00:00.003747 executorch:cpuinfo_utils.cpp:91] Reading file /sys/devices/system/cpu/cpu2/regs/identification/midr_el1
I 00:00:00.003799 executorch:cpuinfo_utils.cpp:91] Reading file /sys/devices/system/cpu/cpu3/regs/identification/midr_el1
I 00:00:00.003852 executorch:cpuinfo_utils.cpp:91] Reading file /sys/devices/system/cpu/cpu4/regs/identification/midr_el1
I 00:00:00.003902 executorch:cpuinfo_utils.cpp:91] Reading file /sys/devices/system/cpu/cpu5/regs/identification/midr_el1
I 00:00:00.003976 executorch:main.cpp:69] Resetting threadpool with num threads = 6
I 00:00:00.004289 executorch:runner.cpp:68] Creating LLaMa runner: model_path=/data/local/tmp/llama3-1b/vk/llama3.pte, tokenizer_path=/data/local/tmp/tokenizer.model
I 00:00:04.841690 executorch:runner.cpp:101] Reading metadata from model
I 00:00:04.841808 executorch:runner.cpp:126] Metadata: get_vocab_size = 128256
I 00:00:04.841830 executorch:runner.cpp:126] Metadata: get_bos_id = 128000
I 00:00:04.841851 executorch:runner.cpp:126] Metadata: use_sdpa_with_kv_cache = 1
I 00:00:04.841874 executorch:runner.cpp:126] Metadata: use_kv_cache = 1
I 00:00:04.841893 executorch:runner.cpp:126] Metadata: get_max_context_len = 128
I 00:00:04.841909 executorch:runner.cpp:126] Metadata: get_max_seq_len = 128
I 00:00:04.841927 executorch:runner.cpp:126] Metadata: enable_dynamic_shape = 0
I 00:00:04.841945 executorch:runner.cpp:133] eos_id = 128009
I 00:00:04.841951 executorch:runner.cpp:133] eos_id = 128001
I 00:00:04.841963 executorch:runner.cpp:188] RSS after loading model: 2229.828125 MiB (0 if unsupported)
<|begin_of_text|><|start_header_id|>system<|end_header_id|>Tell me a short story.<|eot_id|><|start_header_id|>assistant<|end_header_id|>
I 00:00:06.239633 executorch:runner.cpp:258] RSS after prompt prefill: 2229.828125 MiB (0 if unsupported)
Here's a short story for you: **The Library of Lost Memories** In a small, dusty town nestled between two great rivers, there was a library that held the secrets of the past. It was a place where memories were stored, not retrieved, and the librarians were the guardians of the past. The library was called the Library of Lost Memories, and it was said that anyone who entered its doors would be given a glimpse into the memories of those who had come before. The librarians were wise and kind, and they would only allow those who were
I 00:00:17.699086 executorch:runner.cpp:272] RSS after finishing text generation: 2229.828125 MiB (0 if unsupported)
I 00:00:17.699155 executorch:stats.h:108] Prompt Tokens: 14 Generated Tokens: 113
I 00:00:17.699161 executorch:stats.h:114] Model Load Time: 4.837000 (seconds)
I 00:00:17.699165 executorch:stats.h:124] Total inference time: 12.857000 (seconds) Rate: 8.788987 (tokens/second)
I 00:00:17.699168 executorch:stats.h:132] Prompt evaluation: 1.398000 (seconds) Rate: 10.014306 (tokens/second)
I 00:00:17.699171 executorch:stats.h:143] Generated 113 tokens: 11.459000 (seconds) Rate: 9.861244 (tokens/second)
I 00:00:17.699174 executorch:stats.h:151] Time to first generated token: 1.398000 (seconds)
I 00:00:17.699177 executorch:stats.h:158] Sampling time over 127 tokens: 549246500.843000 (seconds)
```

Before this change:

```
/home/ssjia/scratch/bin/app_bin: 1 file pushed, 0 skipped. 302.0 MB/s (74637464 bytes in 0.236s)
I 00:00:00.003050 executorch:cpuinfo_utils.cpp:62] Reading file /sys/devices/soc0/image_version
I 00:00:00.003200 executorch:cpuinfo_utils.cpp:78] Failed to open midr file /sys/devices/soc0/image_version
I 00:00:00.003226 executorch:cpuinfo_utils.cpp:91] Reading file /sys/devices/system/cpu/cpu0/regs/identification/midr_el1
I 00:00:00.003337 executorch:cpuinfo_utils.cpp:91] Reading file /sys/devices/system/cpu/cpu1/regs/identification/midr_el1
I 00:00:00.003396 executorch:cpuinfo_utils.cpp:91] Reading file /sys/devices/system/cpu/cpu2/regs/identification/midr_el1
I 00:00:00.003449 executorch:cpuinfo_utils.cpp:91] Reading file /sys/devices/system/cpu/cpu3/regs/identification/midr_el1
I 00:00:00.003502 executorch:cpuinfo_utils.cpp:91] Reading file /sys/devices/system/cpu/cpu4/regs/identification/midr_el1
I 00:00:00.003553 executorch:cpuinfo_utils.cpp:91] Reading file /sys/devices/system/cpu/cpu5/regs/identification/midr_el1
I 00:00:00.003629 executorch:main.cpp:69] Resetting threadpool with num threads = 6
I 00:00:00.004075 executorch:runner.cpp:68] Creating LLaMa runner: model_path=/data/local/tmp/llama3-1b/vk/llama3.pte, tokenizer_path=/data/local/tmp/tokenizer.model
I 00:00:05.417531 executorch:runner.cpp:101] Reading metadata from model
I 00:00:05.417647 executorch:runner.cpp:126] Metadata: get_vocab_size = 128256
I 00:00:05.417669 executorch:runner.cpp:126] Metadata: get_bos_id = 128000
I 00:00:05.417698 executorch:runner.cpp:126] Metadata: use_sdpa_with_kv_cache = 1
I 00:00:05.417716 executorch:runner.cpp:126] Metadata: use_kv_cache = 1
I 00:00:05.417735 executorch:runner.cpp:126] Metadata: get_max_context_len = 128
I 00:00:05.417751 executorch:runner.cpp:126] Metadata: get_max_seq_len = 128
I 00:00:05.417768 executorch:runner.cpp:126] Metadata: enable_dynamic_shape = 0
I 00:00:05.417787 executorch:runner.cpp:133] eos_id = 128009
I 00:00:05.417793 executorch:runner.cpp:133] eos_id = 128001
I 00:00:05.417808 executorch:runner.cpp:188] RSS after loading model: 2230.812500 MiB (0 if unsupported)
<|begin_of_text|><|start_header_id|>system<|end_header_id|>Tell me a short story.<|eot_id|><|start_header_id|>assistant<|end_header_id|>
I 00:00:19.689616 executorch:runner.cpp:258] RSS after prompt prefill: 2230.812500 MiB (0 if unsupported)
Here's a short story for you: **The Library of Lost Memories** In a small, dusty town nestled between two great rivers, there was a library that held the secrets of the past. It was a place where memories were stored, not retrieved, and the librarians were the guardians of the past. The library was called the Library of Lost Memories, and it was said that anyone who entered its doors would be given a glimpse into the memories of those who had come before. The librarians were wise and kind, and they would only allow those who were
I 00:02:15.269693 executorch:runner.cpp:272] RSS after finishing text generation: 2230.812500 MiB (0 if unsupported)
I 00:02:15.269810 executorch:stats.h:108] Prompt Tokens: 14 Generated Tokens: 113
I 00:02:15.269825 executorch:stats.h:114] Model Load Time: 5.414000 (seconds)
I 00:02:15.269832 executorch:stats.h:124] Total inference time: 129.852000 (seconds) Rate: 0.870221 (tokens/second)
I 00:02:15.269837 executorch:stats.h:132] Prompt evaluation: 14.271000 (seconds) Rate: 0.981010 (tokens/second)
I 00:02:15.269841 executorch:stats.h:143] Generated 113 tokens: 115.581000 (seconds) Rate: 0.977669 (tokens/second)
I 00:02:15.269844 executorch:stats.h:151] Time to first generated token: 14.271000 (seconds)
I 00:02:15.269847 executorch:stats.h:158] Sampling time over 127 tokens: 549711269.115000 (seconds)
PyTorchObserver {"prompt_tokens":14,"generated_tokens":113,"model_load_start_ms":1743712527974,"model_load_end_ms":1743712533388,"inference_start_ms":1743712533388,"inference_end_ms":1743712663240,"prompt_eval_end_ms":1743712547659,"first_token_ms":1743712547659,"aggregate_sampling_time_ms":549711269115,"SCALING_FACTOR_UNITS_PER_SECOND":1000}
```

Differential Revision: [D72412950](https://our.internmc.facebook.com/intern/diff/D72412950/)

[ghstack-poisoned]
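The core idea behind int4 packing, storing two 4-bit values per byte so that each memory fetch brings in twice as many weights, can be sketched in plain Python. This is an illustrative sketch only; the actual interleaved layout is defined in the `pack_int4_linear_weight_transposed_interleave` shader, and the helper names here are made up:

```python
# Illustrative int4 packing sketch (NOT the actual shader layout):
# two unsigned 4-bit values are stored per byte, low nibble first.

def pack_int4(vals):
    """Pack ints in [0, 15] into bytes, two values per byte."""
    assert len(vals) % 2 == 0, "need an even number of values"
    out = bytearray()
    for lo, hi in zip(vals[0::2], vals[1::2]):
        out.append((hi << 4) | lo)  # hi nibble in bits 4-7, lo in bits 0-3
    return bytes(out)

def unpack_int4(packed):
    """Invert pack_int4: recover the original 4-bit values."""
    vals = []
    for b in packed:
        vals.append(b & 0x0F)         # low nibble first
        vals.append((b >> 4) & 0x0F)  # then high nibble
    return vals

weights = [0, 1, 7, 15, 8, 3]
assert unpack_int4(pack_int4(weights)) == weights
```

The round trip above is the invariant any packing scheme must satisfy; the shader additionally chooses *which* pairs share a byte so that neighboring threads read contiguous memory.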
2 parents 7671891 + 8c55765 commit bd4455c

66 files changed: +2807 / -1006 lines changed


.github/release.yml

Lines changed: 71 additions & 0 deletions

```yaml
# .github/release.yml

changelog:
  exclude:
    labels:
      - ignore-for-release
  categories:
    - title: Breaking Changes 🛠
      labels:
        - Semver-Major
        - breaking-change
    - title: API
      labels:
        - "release notes: api"
    - title: ARM
      labels:
        - "release notes: arm"
    - title: NXP
      labels:
        - "release notes: nxp"
    - title: Exir
      labels:
        - "release notes: exir"
    - title: Misc
      labels:
        - "release notes: misc"
    - title: Apple
      labels:
        - "release notes: apple"
    - title: Build
      labels:
        - "release notes: build"
    - title: Vulkan
      labels:
        - "release notes: vulkan"
    - title: Cadence
      labels:
        - "release notes: cadence"
    - title: Runtime
      labels:
        - "release notes: runtime"
    - title: XNNPACK
      labels:
        - "release notes: xnnpack"
    - title: Devtools
      labels:
        - "release notes: devtools"
    - title: Examples
      labels:
        - "release notes: examples"
    - title: Mediatek
      labels:
        - "release notes: mediatek"
    - title: Openvino
      labels:
        - "release notes: openvino"
    - title: Qualcomm
      labels:
        - "release notes: qualcomm"
    - title: Training
      labels:
        - "release notes: training"
    - title: Quantization
      labels:
        - "release notes: quantization"
    - title: Ops & kernels
      labels:
        - "release notes: ops & kernels"
    - title: Other Changes
      labels:
        - "*"
```

.github/scripts/run_nm.py

Lines changed: 171 additions & 0 deletions

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import re
import subprocess
import sys
from dataclasses import dataclass
from typing import Dict, List, Optional


@dataclass
class Symbol:
    name: str
    addr: int
    size: int
    symbol_type: str


class Parser:
    def __init__(self, elf: str, toolchain_prefix: str = "", filter=None):
        self.elf = elf
        self.toolchain_prefix = toolchain_prefix
        self.symbols: Dict[str, Symbol] = self._get_nm_output()
        self.filter = filter

    @staticmethod
    def run_nm(
        elf_file_path: str, args: Optional[List[str]] = None, nm: str = "nm"
    ) -> str:
        """
        Run the nm command on the specified ELF file.
        """
        args = [] if args is None else args
        cmd = [nm] + args + [elf_file_path]
        try:
            result = subprocess.run(cmd, check=True, capture_output=True, text=True)
            return result.stdout
        except FileNotFoundError:
            print("Error: 'nm' command not found. Please ensure it's installed.")
            sys.exit(1)
        except subprocess.CalledProcessError as e:
            print(f"Error running nm on {elf_file_path}: {e}")
            print(f"stderr: {e.stderr}")
            sys.exit(1)

    def _get_nm_output(self) -> Dict[str, Symbol]:
        args = [
            "--print-size",
            "--size-sort",
            "--reverse-sort",
            "--demangle",
            "--format=bsd",
        ]
        output = Parser.run_nm(
            self.elf,
            args,
            nm=self.toolchain_prefix + "nm" if self.toolchain_prefix else "nm",
        )
        lines = output.splitlines()
        symbols = []
        symbol_pattern = re.compile(
            r"(?P<addr>[0-9a-fA-F]+)\s+(?P<size>[0-9a-fA-F]+)\s+(?P<type>\w)\s+(?P<name>.+)"
        )

        def parse_line(line: str) -> Optional[Symbol]:
            match = symbol_pattern.match(line)
            if match:
                addr = int(match.group("addr"), 16)
                size = int(match.group("size"), 16)
                type_ = match.group("type").strip().strip("\n")
                name = match.group("name").strip().strip("\n")
                return Symbol(name=name, addr=addr, size=size, symbol_type=type_)
            return None

        for line in lines:
            symbol = parse_line(line)
            if symbol:
                symbols.append(symbol)

        assert len(symbols) > 0, "No symbols found in nm output"
        if len(symbols) != len(lines):
            print(
                f"** Warning: Not all lines were parsed, check the output of nm. Parsed {len(symbols)} lines, given {len(lines)}"
            )
        if any(symbol.size == 0 for symbol in symbols):
            print("** Warning: Some symbols have zero size, check the output of nm.")

        # TODO: Populate the section and module fields from the linker map if available (-Wl,-Map=linker.map)
        return {symbol.name: symbol for symbol in symbols}

    def print(self):
        print(f"Elf: {self.elf}")

        def print_table(filter=None, filter_name=None):
            print("\nAddress\t\tSize\tType\tName")
            # Apply filter and sort symbols
            symbols_to_print = {
                name: sym
                for name, sym in self.symbols.items()
                if not filter or filter(sym)
            }
            sorted_symbols = sorted(
                symbols_to_print.items(), key=lambda x: x[1].size, reverse=True
            )

            # Print symbols and calculate total size
            size_total = 0
            for name, sym in sorted_symbols:
                print(f"{hex(sym.addr)}\t\t{sym.size}\t{sym.symbol_type}\t{sym.name}")
                size_total += sym.size

            # Print summary
            symbol_percent = len(symbols_to_print) / len(self.symbols) * 100
            print("-----")
            print(f"> Total bytes: {size_total}")
            print(
                f"Counted: {len(symbols_to_print)}/{len(self.symbols)}, {symbol_percent:0.2f}% (filter: '{filter_name}')"
            )
            print("=====\n")

        # Print tables with different filters
        def is_executorch_symbol(s):
            return "executorch" in s.name or s.name.startswith("et")

        FILTER_NAME_TO_FILTER_AND_LABEL = {
            "all": (None, "All"),
            "executorch": (is_executorch_symbol, "ExecuTorch"),
            "executorch_text": (
                lambda s: is_executorch_symbol(s) and s.symbol_type.lower() == "t",
                "ExecuTorch .text",
            ),
        }

        filter_func, label = FILTER_NAME_TO_FILTER_AND_LABEL.get(
            self.filter, FILTER_NAME_TO_FILTER_AND_LABEL["all"]
        )
        print_table(filter_func, label)


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(
        description="Process ELF file and linker map file."
    )
    parser.add_argument(
        "-e", "--elf-file-path", required=True, help="Path to the ELF file"
    )
    parser.add_argument(
        "-f",
        "--filter",
        required=False,
        default="all",
        help="Filter symbols by pre-defined filters",
        choices=["all", "executorch", "executorch_text"],
    )
    parser.add_argument(
        "-p",
        "--toolchain-prefix",
        required=False,
        default="",
        help="Optional toolchain prefix for nm",
    )

    args = parser.parse_args()
    p = Parser(args.elf_file_path, args.toolchain_prefix, filter=args.filter)
    p.print()
```
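The script's `symbol_pattern` regex expects nm's BSD-format output, i.e. `address size type name` with hex address and size. A quick sanity check of that same pattern against a sample line (the symbol shown is made up for illustration):

```python
import re

# Same pattern as in run_nm.py: hex addr, hex size, one-letter type, name.
symbol_pattern = re.compile(
    r"(?P<addr>[0-9a-fA-F]+)\s+(?P<size>[0-9a-fA-F]+)\s+(?P<type>\w)\s+(?P<name>.+)"
)

# A made-up line in `nm --print-size --format=bsd --demangle` style.
line = "0000a1b0 00000120 T executorch::runtime::Method::execute()"
m = symbol_pattern.match(line)
assert m is not None
assert int(m.group("addr"), 16) == 0xA1B0   # address parsed as hex
assert int(m.group("size"), 16) == 0x120    # size parsed as hex
assert m.group("type") == "T"               # .text (global) symbol
assert m.group("name") == "executorch::runtime::Method::execute()"
```

Lines that do not match the pattern (e.g. undefined symbols printed without a size field) are skipped by `parse_line`, which is why the script warns when the parsed count differs from the line count.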

.github/workflows/trunk.yml

Lines changed: 54 additions & 0 deletions

The following job is added to the `jobs:` section (`@@ -231,6 +231,60 @@`), ahead of the existing `test-coreml-delegate` job:

```yaml
  test-arm-cortex-m-size-test:
    name: test-arm-cortex-m-size-test
    uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
    permissions:
      id-token: write
      contents: read
    with:
      runner: linux.2xlarge
      docker-image: executorch-ubuntu-22.04-arm-sdk
      submodules: 'true'
      ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
      timeout: 90
      script: |
        # The generic Linux job chooses to use base env, not the one setup by the image
        CONDA_ENV=$(conda env list --json | jq -r ".envs | .[-1]")
        conda activate "${CONDA_ENV}"

        source .ci/scripts/utils.sh
        install_executorch "--use-pt-pinned-commit"
        .ci/scripts/setup-arm-baremetal-tools.sh
        source examples/arm/ethos-u-scratch/setup_path.sh

        # Use baremetal toolchain
        arm-none-eabi-c++ --version
        toolchain_cmake=examples/arm/ethos-u-setup/arm-none-eabi-gcc.cmake
        toolchain_cmake=$(realpath ${toolchain_cmake})

        # Build and test size test
        bash test/build_size_test.sh "-DCMAKE_TOOLCHAIN_FILE=${toolchain_cmake} -DEXECUTORCH_BUILD_ARM_BAREMETAL=ON"
        elf="cmake-out/test/size_test"

        # Dump basic info
        ls -al ${elf}
        arm-none-eabi-size ${elf}

        # Dump symbols
        python .github/scripts/run_nm.py -e ${elf}
        python .github/scripts/run_nm.py -e ${elf} -f "executorch" -p "arm-none-eabi-"
        python .github/scripts/run_nm.py -e ${elf} -f "executorch_text" -p "arm-none-eabi-"

        # Add basic guard - TODO: refine this!
        arm-none-eabi-strip ${elf}
        output=$(ls -la ${elf})
        arr=($output)
        size=${arr[4]}
        threshold="102400" # 100KiB
        echo "size: $size, threshold: $threshold"
        if [[ "$size" -le "$threshold" ]]; then
          echo "Success $size <= $threshold"
        else
          echo "Fail $size > $threshold"
          exit 1
        fi
```

backends/arm/_passes/__init__.py

Lines changed: 2 additions & 0 deletions

```diff
@@ -21,6 +21,7 @@
 from .decompose_batchnorm_pass import DecomposeBatchNormPass  # noqa
 from .decompose_div_pass import DecomposeDivPass  # noqa
 from .decompose_layernorm_pass import DecomposeLayerNormPass  # noqa
+from .decompose_leaky_relu_pass import DecomposeLeakyReLUPass  # noqa
 from .decompose_linear_pass import DecomposeLinearPass  # noqa
 from .decompose_meandim_pass import DecomposeMeanDimPass  # noqa
 from .decompose_select import DecomposeSelectPass  # noqa
@@ -39,6 +40,7 @@
 from .insert_table_ops import InsertTableOpsPass  # noqa
 from .keep_dims_false_to_squeeze_pass import KeepDimsFalseToSqueezePass  # noqa
 from .match_arg_ranks_pass import MatchArgRanksPass  # noqa
+from .match_where_self_arg_dtype_pass import MatchWhereSelfDtypePass  # noqa
 from .meandim_to_averagepool_pass import ConvertMeanDimToAveragePoolPass  # noqa
 from .mm_to_bmm_pass import ConvertMmToBmmPass  # noqa
 from .remove_clone_pass import RemoveClonePass  # noqa
```

backends/arm/_passes/arm_pass_manager.py

Lines changed: 6 additions & 0 deletions

```diff
@@ -26,6 +26,7 @@
     DecomposeBatchNormPass,
     DecomposeDivPass,
     DecomposeLayerNormPass,
+    DecomposeLeakyReLUPass,
     DecomposeLinearPass,
     DecomposeMeanDimPass,
     DecomposeSelectPass,
@@ -40,6 +41,7 @@
     InsertTableOpsPass,
     KeepDimsFalseToSqueezePass,
     MatchArgRanksPass,
+    MatchWhereSelfDtypePass,
     QuantizeOperatorArguments,
     RemoveClonePass,
     ReplaceScalarWithTensorArgPassTOSABI,
@@ -80,6 +82,7 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
         self.add_pass(ConvertToClampPass())
         self.add_pass(ConvertMinMaxPass())
         self.add_pass(ConvertAnyDefaultDimDimsPass())
+        self.add_pass(MatchWhereSelfDtypePass())
         if isinstance(self.tosa_spec, Tosa_0_80) and self.tosa_spec.is_U55_subset:
             self.add_pass(CastToInt32Pass())

@@ -119,6 +122,7 @@ def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
         self.add_pass(FuseBatchnorm2DPass(exported_program))
         self.add_pass(ConvertMmToBmmPass())
         self.add_pass(DecomposeLinearPass())
+        self.add_pass(DecomposeLeakyReLUPass())
         self.add_pass(DecomposeBatchNormPass())
         self.add_pass(DecomposeLayerNormPass())
         self.add_pass(DecomposeVarPass())
@@ -130,6 +134,7 @@ def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
         self.add_pass(ConvertToClampPass())
         self.add_pass(ConvertMinMaxPass())
         self.add_pass(ConvertAnyDefaultDimDimsPass())
+        self.add_pass(MatchWhereSelfDtypePass())

         self.add_pass(AnnotateDecomposedMatmulPass())
         self.add_pass(QuantizeOperatorArguments())
@@ -175,6 +180,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
         self.add_pass(DecomposeVarPass())
         self.add_pass(DecomposeMeanDimPass())
         self.add_pass(DecomposeDivPass())
+        self.add_pass(DecomposeLeakyReLUPass())

         if isinstance(self.tosa_spec, Tosa_0_80) and self.tosa_spec.is_U55_subset:
             # Numerically stable softmax uses amax which is not supported on Ethos-U55
```

backends/arm/_passes/arm_pass_utils.py

Lines changed: 8 additions & 3 deletions

```diff
@@ -13,6 +13,7 @@

 import torch
 import torch.fx
+from executorch.backends.arm.tosa_utils import get_node_debug_info
 from executorch.exir import ExportedProgram
 from executorch.exir.dialects._ops import ops as exir_ops

@@ -169,9 +170,13 @@ def get_first_fake_tensor(node: torch.fx.Node) -> FakeTensor:
     else:
         fake_tensor = node.meta["val"]

-    assert isinstance(
-        fake_tensor, FakeTensor
-    ), f'Found {fake_tensor} in meta["val"] of {node}, expected to find FakeTensor.'
+    if not isinstance(fake_tensor, FakeTensor):
+        raise TypeError(
+            f'Expected a FakeTensor in meta["val"] of node {node}, but got '
+            f"{type(fake_tensor).__name__}\n"
+            f"{get_node_debug_info(node)}"
+        )

     return fake_tensor
```
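The `arm_pass_utils.py` hunk above swaps an `assert isinstance(...)` for an explicit `raise TypeError(...)`. One reason this pattern is preferred for input validation: asserts are stripped when Python runs with `-O`, while a raised exception always fires and can carry richer debug output. A minimal sketch of the same pattern (the `check_float` helpers are hypothetical, not from the codebase):

```python
def check_float_assert(val):
    # Silently skipped when Python runs with -O: the check can vanish.
    assert isinstance(val, float), f"expected float, got {type(val).__name__}"
    return val

def check_float_raise(val):
    # Always enforced, and the message can carry extra diagnostic context.
    if not isinstance(val, float):
        raise TypeError(f"expected float, got {type(val).__name__}")
    return val

try:
    check_float_raise("not a float")
except TypeError as e:
    print(e)  # expected float, got str
```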