Skip to content

Commit f625262

Browse files
author
Nathanael See
committed
Update base for Update on "[ET-VK][int4] Wrap int4 linear calls with view_copy nodes to squeeze/unsqueeze inputs"
This wrapping is done automatically for full-precision linear/mm nodes at torch.export graph-tracing time, but it is not done for the int4 op. The new pass adds view_copy nodes because subsequent passes can fuse redundant view_copy nodes and convert view_copy nodes to squeeze/unsqueeze nodes. Differential Revision: [D69065866](https://our.internmc.facebook.com/intern/diff/D69065866/) [ghstack-poisoned]
2 parents d15bdb5 + 7805229 commit f625262

File tree

29 files changed

+983
-427
lines changed

29 files changed

+983
-427
lines changed

.buckconfig

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,3 +33,6 @@
3333
**/.git, \
3434
cmake-out, \
3535
pip-out
36+
37+
[buck2]
38+
restarter=true

backends/arm/test/misc/test_tosa_spec.py

Lines changed: 37 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -20,16 +20,22 @@
2020
"TOSA-0.80+MI+8k",
2121
"TOSA-0.80+BI+u55",
2222
]
23-
test_valid_1_00_strings = [
24-
"TOSA-1.00.0+INT+FP+fft",
25-
"TOSA-1.00.0+FP+bf16+fft",
26-
"TOSA-1.00.0+INT+int4+cf",
27-
"TOSA-1.00.0+FP+cf+bf16+8k",
28-
"TOSA-1.00.0+FP+INT+bf16+fft+int4+cf",
29-
"TOSA-1.00.0+FP+INT+fft+int4+cf+8k",
23+
test_valid_1_0_strings = [
24+
"TOSA-1.0.0+INT+FP+fft",
25+
"TOSA-1.0.0+FP+bf16+fft",
26+
"TOSA-1.0.0+INT+int4+cf",
27+
"TOSA-1.0.0+FP+cf+bf16+8k",
28+
"TOSA-1.0.0+FP+INT+bf16+fft+int4+cf",
29+
"TOSA-1.0.0+FP+INT+fft+int4+cf+8k",
30+
"TOSA-1.0+INT+FP+fft",
31+
"TOSA-1.0+FP+bf16+fft",
32+
"TOSA-1.0+INT+int4+cf",
33+
"TOSA-1.0+FP+cf+bf16+8k",
34+
"TOSA-1.0+FP+INT+bf16+fft+int4+cf",
35+
"TOSA-1.0+FP+INT+fft+int4+cf+8k",
3036
]
3137

32-
test_valid_1_00_extensions = {
38+
test_valid_1_0_extensions = {
3339
"INT": ["int16", "int4", "var", "cf"],
3440
"FP": ["bf16", "fp8e4m3", "fp8e5m2", "fft", "var", "cf"],
3541
}
@@ -40,19 +46,19 @@
4046
"TOSA-0.80+8k",
4147
"TOSA-0.80+BI+MI",
4248
"TOSA-0.80+BI+U55",
43-
"TOSA-1.00.0+fft",
44-
"TOSA-1.00.0+fp+bf16+fft",
45-
"TOSA-1.00.0+INT+INT4+cf",
46-
"TOSA-1.00.0+BI",
47-
"TOSA-1.00.0+FP+FP+INT",
48-
"TOSA-1.00.0+FP+CF+bf16",
49-
"TOSA-1.00.0+BF16+fft+int4+cf+INT",
49+
"TOSA-1.0.0+fft",
50+
"TOSA-1.0.0+fp+bf16+fft",
51+
"TOSA-1.0.0+INT+INT4+cf",
52+
"TOSA-1.0.0+BI",
53+
"TOSA-1.0.0+FP+FP+INT",
54+
"TOSA-1.0.0+FP+CF+bf16",
55+
"TOSA-1.0.0+BF16+fft+int4+cf+INT",
5056
]
5157

5258
test_compile_specs = [
5359
([CompileSpec("tosa_version", "TOSA-0.80+BI".encode())],),
5460
([CompileSpec("tosa_version", "TOSA-0.80+BI+u55".encode())],),
55-
([CompileSpec("tosa_version", "TOSA-1.00.0+INT".encode())],),
61+
([CompileSpec("tosa_version", "TOSA-1.0.0+INT".encode())],),
5662
]
5763

5864
test_compile_specs_no_version = [
@@ -70,8 +76,8 @@ def test_version_string_0_80(self, version_string: str):
7076
assert isinstance(tosa_spec, Tosa_0_80)
7177
assert tosa_spec.profile in ["BI", "MI"]
7278

73-
@parameterized.expand(test_valid_1_00_strings) # type: ignore[misc]
74-
def test_version_string_1_00(self, version_string: str):
79+
@parameterized.expand(test_valid_1_0_strings) # type: ignore[misc]
80+
def test_version_string_1_0(self, version_string: str):
7581
tosa_spec = TosaSpecification.create_from_string(version_string)
7682
assert isinstance(tosa_spec, Tosa_1_00)
7783
assert [profile in ["INT", "FP"] for profile in tosa_spec.profiles].count(
@@ -80,7 +86,7 @@ def test_version_string_1_00(self, version_string: str):
8086

8187
for profile in tosa_spec.profiles:
8288
assert [
83-
e in test_valid_1_00_extensions[profile] for e in tosa_spec.extensions
89+
e in test_valid_1_0_extensions[profile] for e in tosa_spec.extensions
8490
]
8591

8692
@parameterized.expand(test_invalid_strings) # type: ignore[misc]
@@ -103,3 +109,15 @@ def test_create_from_invalid_compilespec(self, compile_specs: list[CompileSpec])
103109
tosa_spec = TosaSpecification.create_from_compilespecs(compile_specs)
104110

105111
assert tosa_spec is None
112+
113+
@parameterized.expand(test_valid_0_80_strings)
114+
def test_correct_string_representation_0_80(self, version_string: str):
115+
tosa_spec = TosaSpecification.create_from_string(version_string)
116+
assert isinstance(tosa_spec, Tosa_0_80)
117+
assert f"{tosa_spec}" == version_string
118+
119+
@parameterized.expand(test_valid_1_0_strings)
120+
def test_correct_string_representation_1_0(self, version_string: str):
121+
tosa_spec = TosaSpecification.create_from_string(version_string)
122+
assert isinstance(tosa_spec, Tosa_1_00)
123+
assert f"{tosa_spec}" == version_string

backends/arm/tosa_specification.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2024 Arm Limited and/or its affiliates.
1+
# Copyright 2024-2025 Arm Limited and/or its affiliates.
22
#
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
@@ -14,7 +14,9 @@
1414
import re
1515
from typing import List
1616

17-
from executorch.exir.backend.compile_spec_schema import CompileSpec
17+
from executorch.exir.backend.compile_spec_schema import ( # type: ignore[import-untyped]
18+
CompileSpec,
19+
)
1820
from packaging.version import Version
1921

2022

@@ -131,7 +133,7 @@ def __init__(self, version: Version, extras: List[str]):
131133
def __repr__(self):
132134
extensions = ""
133135
if self.level_8k:
134-
extensions += "+8K"
136+
extensions += "+8k"
135137
if self.is_U55_subset:
136138
extensions += "+u55"
137139
return f"TOSA-{str(self.version)}+{self.profile}{extensions}"
@@ -207,7 +209,10 @@ def _get_extensions_string(self) -> str:
207209
return "".join(["+" + e for e in self.extensions])
208210

209211
def __repr__(self):
210-
return f"TOSA-{self.version}{self._get_profiles_string()}{self._get_profiles_string()}"
212+
extensions = self._get_extensions_string()
213+
if self.level_8k:
214+
extensions += "+8k"
215+
return f"TOSA-{self.version}{self._get_profiles_string()}{extensions}"
211216

212217
def __hash__(self) -> int:
213218
return hash(str(self.version) + self._get_profiles_string())

backends/cadence/aot/functions_hifi.yaml

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -71,11 +71,11 @@
7171
kernels:
7272
- arg_meta: null
7373
kernel_name: cadence::impl::HiFi::full_out
74-
74+
7575
- op: gt.Scalar_out
7676
kernels:
7777
- arg_meta: null
78-
kernel_name: torch::executor::gt_scalar_out
78+
kernel_name: torch::executor::gt_scalar_out
7979

8080
- op: gelu.out
8181
kernels:
@@ -100,7 +100,7 @@
100100
- op: mean.out
101101
kernels:
102102
- arg_meta: null
103-
kernel_name: cadence::impl::HiFi::mean_dim_out
103+
kernel_name: cadence::impl::HiFi::mean_dim_out
104104

105105
- op: minimum.out
106106
kernels:
@@ -213,3 +213,13 @@
213213
kernels:
214214
- arg_meta: null
215215
kernel_name: cadence::impl::HiFi::quantized_linear_per_tensor_out
216+
217+
- func: cadence::quantized_fully_connected.out(Tensor src, Tensor weight, Tensor bias, int src_zero_point, Tensor weight_zero_point, Tensor out_multiplier, Tensor out_shift, int out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!)
218+
kernels:
219+
- arg_meta: null
220+
kernel_name: cadence::impl::HiFi::quantized_fully_connected_out
221+
222+
- func: cadence::quantized_fully_connected.per_tensor_out(Tensor src, Tensor weight, Tensor bias, int src_zero_point, int weight_zero_point, int out_multiplier, int out_shift, int out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!)
223+
kernels:
224+
- arg_meta: null
225+
kernel_name: cadence::impl::HiFi::quantized_fully_connected_per_tensor_out

backends/cadence/hifi/operators/op_clamp.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ namespace impl {
4848
namespace HiFi {
4949
namespace native {
5050

51-
Tensor& clamp_tensor_out(
51+
Tensor& clamp_Tensor_out(
5252
RuntimeContext& ctx,
5353
const Tensor& in,
5454
const executorch::aten::optional<Tensor>& min_opt,

0 commit comments

Comments
 (0)