
Commit 6086781

Merge branch 'main' into fix_docs
2 parents: 1370ea8 + 1e97232

27 files changed: +849, -285 lines
backends/arm/scripts/parse_test_names.py

Lines changed: 102 additions & 0 deletions
@@ -0,0 +1,102 @@
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from executorch.exir.dialects.edge.spec.utils import SAMPLE_INPUT

# Add edge ops which we lower but which are not included in exir/dialects/edge/edge.yaml here.
CUSTOM_EDGE_OPS = ["linspace.default", "eye.default"]
ALL_EDGE_OPS = SAMPLE_INPUT.keys() | CUSTOM_EDGE_OPS

# Add all targets and TOSA profiles we support here.
TARGETS = {"tosa_BI", "tosa_MI", "u55_BI", "u85_BI"}


def get_edge_ops():
    """
    Returns a set of edge op names, normalized to the form used in unit test names:
    1. Names are lowercase.
    2. The overload is dropped if it is 'default'; otherwise it is appended after an underscore.
    3. Overly verbose names are shortened by removing certain prefixes/suffixes.

    Examples:
        abs.default -> abs
        split_copy.Tensor -> split_tensor
    """
    edge_ops = set()
    for edge_name in ALL_EDGE_OPS:
        op, overload = edge_name.split(".")

        # Normalize names
        op = op.lower()
        op = op.removeprefix("_")
        op = op.removesuffix("_copy")
        op = op.removesuffix("_with_indices")
        op = op.removesuffix("_no_training")
        overload = overload.lower()

        if overload == "default":
            edge_ops.add(op)
        else:
            edge_ops.add(f"{op}_{overload}")

    return edge_ops


def parse_test_name(test_name: str, edge_ops: set[str]) -> tuple[str, str, bool]:
    """
    Parses a test name of the form
        test_OP_TARGET_<not_delegated>_<any_other_info>
    where OP must match a string in edge_ops and TARGET must match one string in TARGETS.
    The "not_delegated" suffix indicates that the test checks that the op is not delegated.

    Examples of valid names: "test_mm_u55_BI_not_delegated" or "test_add_scalar_tosa_MI_two_inputs".

    Returns a tuple (OP, TARGET, IS_DELEGATED) if valid.
    """
    test_name = test_name.removeprefix("test_")
    is_delegated = "not_delegated" not in test_name
    assert (
        "reject" not in test_name
    ), f"Use 'not_delegated' instead of 'reject' in {test_name}"

    op = "None"
    target = "None"
    for potential_target in TARGETS:
        index = test_name.find(potential_target)
        if index != -1:
            op = test_name[: index - 1]
            target = potential_target
            break
    # Special case for convolution
    op = op.removesuffix("_1d")
    op = op.removesuffix("_2d")

    assert target != "None", f"{test_name} does not contain one of {TARGETS}"
    assert (
        op in edge_ops
    ), f"Parsed invalid OP from {test_name}: {op} does not exist in edge.yaml or CUSTOM_EDGE_OPS"

    return op, target, is_delegated


if __name__ == "__main__":
    """Parses a list of test names given on the command line."""
    import sys

    sys.tracebacklimit = 0  # Do not print a stack trace

    edge_ops = get_edge_ops()
    exit_code = 0

    for test_name in sys.argv[1:]:
        try:
            assert test_name[:5] == "test_", f"Unexpected input: {test_name}"
            parse_test_name(test_name, edge_ops)
        except AssertionError as e:
            print(e)
            exit_code = 1
        else:
            print(f"{test_name} OK")

    sys.exit(exit_code)
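
For a quick sanity check of the parser, a minimal sketch of calling it directly (this assumes mm.default and add.Scalar are present in SAMPLE_INPUT, which the docstring examples suggest):

edge_ops = get_edge_ops()

print(parse_test_name("test_mm_u55_BI_not_delegated", edge_ops))
# -> ("mm", "u55_BI", False)

print(parse_test_name("test_add_scalar_tosa_MI_two_inputs", edge_ops))
# -> ("add_scalar", "tosa_MI", True)

# Invalid names raise AssertionError, e.g.:
# parse_test_name("test_mm_reject_u55_BI", edge_ops)
# AssertionError: Use 'not_delegated' instead of 'reject' in mm_reject_u55_BI

The pre-push hook below invokes this script over all test function names found in changed op test files.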

backends/arm/scripts/pre-push

Lines changed: 38 additions & 0 deletions
@@ -166,6 +166,44 @@ for COMMIT in ${COMMITS}; do
     fi
   fi

+  # Op test checks
+  op_test_files=$(echo $commit_files | grep -oE 'backends/arm/test/ops/\S+')
+  if [ "$op_test_files" ]; then
+
+    # TODO: These checks can be removed when all unittests are refactored.
+    if grep -icq "SkipIfNoCorstone" $op_test_files; then
+      echo -e "${ERROR} @SkipIfNoCorstone300/320 is deprecated;"\
+        "please use XfailIfNoCorstone300/320 instead." >&2
+      FAILED=1
+    fi
+
+    if grep -icq "conftest.expectedFailureOnFVP" $op_test_files; then
+      echo -e "${ERROR} @conftest.expectedFailureOnFVP is deprecated;"\
+        "please use XfailIfCorstone300/320 instead." >&2
+      FAILED=1
+    fi
+
+    if grep -icq "unittest.TestCase" $op_test_files; then
+      echo -e "${ERROR} Use of the unittest test framework is deprecated;"\
+        "please use pytest instead." >&2
+      FAILED=1
+    fi
+
+    if grep -icq "on_fvp(" $op_test_files; then
+      echo -e "${ERROR} All unittests should run on FVP if relevant;"\
+        "the on_fvp suffix can be omitted." >&2
+      FAILED=1
+    fi
+
+    # Check that the tested op and target are parsed correctly from the test name
+    test_names=$(grep -h "def test_" $op_test_files | cut -d"(" -f1 | cut -d" " -f2)
+    python ./backends/arm/scripts/parse_test_names.py $test_names
+    if [ $? -ne 0 ]; then
+      echo -e "${ERROR} Failed op test name check." >&2
+      FAILED=1
+    fi
+  fi
+
   echo "" # Newline to visually separate commit processing
 done
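
The test-name extraction in the hook is plain text processing; a rough Python equivalent of the grep/cut chain, with a made-up file snippet for illustration:

# Mirrors: grep -h "def test_" $op_test_files | cut -d"(" -f1 | cut -d" " -f2
source = '''
def test_mm_u55_BI_not_delegated(test_data):
    ...
'''
test_names = [
    line.split("(")[0].split()[1]
    for line in source.splitlines()
    if "def test_" in line
]
print(test_names)  # ['test_mm_u55_BI_not_delegated']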

backends/cadence/aot/ops_registrations.py

Lines changed: 9 additions & 9 deletions
@@ -293,6 +293,15 @@
     "attention_mask.out(Tensor input, Tensor start, Tensor stop, *, Tensor(a!) out) -> Tensor(a!)"
 )

+# Custom ops in the aten namespace. RMSNorm is usually decomposed, so having
+# an out-variant is non-standard.
+
+lib_aten = Library("aten", "FRAGMENT")
+
+lib_aten.define(
+    "rms_norm.out(Tensor input, SymInt[] normalized_shape, Tensor? weight=None, float? eps=None, *, Tensor(a!) out) -> Tensor(a!)"
+)
+

 @register_fake("cadence::quantize_per_tensor")
 def quantize_per_tensor_meta(
@@ -619,15 +628,6 @@ def linalg_vector_norm_meta(
     return X.new_empty([], dtype=X.dtype)


-@register_fake("cadence::rms_norm")
-def rms_norm_meta(
-    X: torch.Tensor,
-    eps: float,
-    weight: torch.Tensor,
-) -> torch.Tensor:
-    return X.new_empty(X.shape, dtype=X.dtype)
-
-
 @register_fake("cadence::requantize")
 def requantize_meta(
     input: torch.Tensor,
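
For readers unfamiliar with the Library mechanism used above: defining a schema registers the op signature, and implementations are attached separately. A minimal self-contained sketch with a hypothetical op my_rms_norm in a hypothetical myns namespace (the diff instead extends the existing aten namespace via "FRAGMENT"):

import torch
from torch.library import Library

# "DEF" creates a new namespace; "FRAGMENT" (as in the diff) extends an existing one.
lib = Library("myns", "DEF")
lib.define("my_rms_norm(Tensor input, float eps, Tensor weight) -> Tensor")

def my_rms_norm_impl(input, eps, weight):
    # Reference RMSNorm: input / sqrt(mean(input^2) + eps), scaled by weight
    inv_rms = input.pow(2).mean(dim=-1, keepdim=True).add(eps).rsqrt()
    return input * inv_rms * weight

lib.impl("my_rms_norm", my_rms_norm_impl, "CompositeExplicitAutograd")

x = torch.randn(2, 8)
print(torch.ops.myns.my_rms_norm(x, 1e-6, torch.ones(8)).shape)  # torch.Size([2, 8])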

backends/vulkan/_passes/int4_weight_only_quantizer.py

Lines changed: 1 addition & 13 deletions
@@ -118,9 +118,6 @@ def _vk_replace_linear_int4(
     # Use custom vulkan linear layer as default
     linear_class: Type[torch.nn.Module] = VkWeightOnlyInt4Linear,
     copy_weights: bool = False,
-    # Serves the same purpose as `tensor_dim_limit` in
-    # executorch.backends.vulkan.partitioner.VulkanSupportedOperators
-    feature_limit: int = 16384,
 ):
     for name, child in module.named_children():
         if isinstance(child, torch.nn.Linear) and (
@@ -131,8 +128,6 @@
             if (
                 _check_linear_int4_k(child.in_features, groupsize, inner_k_tiles)
                 or padding_allowed
-            ) and (
-                child.out_features < feature_limit and child.in_features < feature_limit
             ):
                 new_linear = linear_class(
                     child.in_features,
@@ -175,7 +170,6 @@ def __init__(
         inner_k_tiles: Optional[int] = 8,
         device: torch.device = torch.device("cpu"),  # noqa
         precision: torch.dtype = torch.float32,
-        feature_limit: int = 16384,
     ) -> None:
         super().__init__()
         assert inner_k_tiles in [2, 4, 8]
@@ -186,9 +180,6 @@
         self.padding_allowed: bool = padding_allowed
         self.device: torch.device = device
         self.precision: torch.dtype = precision
-        # Serves the same purpose as `tensor_dim_limit` in
-        # executorch.backends.vulkan.partitioner.VulkanSupportedOperators
-        self.feature_limit = feature_limit

     @torch.no_grad()
     def _create_quantized_state_dict(
@@ -197,10 +188,7 @@ def _create_quantized_state_dict(
         cur_state_dict = model.state_dict()
         for fqn, mod in model.named_modules():
-            # Add additional check to make sure features do not exceed feature limit
-            if isinstance(mod, torch.nn.Linear) and (
-                mod.out_features < self.feature_limit
-                and mod.in_features < self.feature_limit
-            ):
+            if isinstance(mod, torch.nn.Linear):
                 out_features = mod.out_features
                 in_features = mod.in_features
                 logging.info(f"linear: {fqn}, in={in_features}, out={out_features}")
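
With the feature_limit gate removed, replacement hinges only on the isinstance and int4-compatibility checks. The traversal itself is the usual PyTorch module-swapping pattern; a generic sketch under that assumption (the make_linear factory is hypothetical, standing in for VkWeightOnlyInt4Linear construction):

import torch

def replace_linear_layers(module: torch.nn.Module, make_linear) -> None:
    # Walk direct children, swap qualifying Linear layers in place,
    # and recurse into all other submodules.
    for name, child in module.named_children():
        if isinstance(child, torch.nn.Linear):
            setattr(module, name, make_linear(child.in_features, child.out_features))
        else:
            replace_linear_layers(child, make_linear)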

backends/vulkan/_passes/squeeze_unsqueeze_inputs.py

Lines changed: 18 additions & 5 deletions
@@ -27,25 +27,38 @@ class SqueezeUnsqueezeInputs(ExportPass):
         exir_ops.edge.aten.gelu.default,
     }

+    def should_squeeze(self, op, shape: List[int]) -> bool:  # pyre-ignore
+        if len(shape) == 3:
+            return shape[1] == 1 and shape[0] > 1
+        if len(shape) == 4:
+            # No need to squeeze if all dims are 1 except the width dim
+            if all(dim == 1 for dim in shape[:-1]):
+                return False
+            # Otherwise, check for a squeezable dim
+            return 1 in shape[:-1]
+
+        # Prefer not to introduce additional orchestration ops by default
+        return False
+
     def call_operator(
         self,
         op,  # pyre-ignore
         args: Tuple[Argument, ...],
         kwargs: Dict[str, Argument],
         meta: NodeMetadata,
     ) -> ProxyValue:
-        def _squeezable(shape: List[int]) -> bool:
-            return len(shape) > 2 and 1 in shape
-
         if op not in self._squeezable_ops:
             return super().call_operator(op, args, kwargs, meta)
-
         # pyre-ignore[16]: `None` has no attribute `node`
         input_shape = args[0].node.meta["val"].shape
         output_shape = meta["val"].shape
-        if not _squeezable(input_shape):
+
+        if not self.should_squeeze(op, input_shape):
             return super().call_operator(op, args, kwargs, meta)

+        def _squeezable(shape: List[int]) -> bool:
+            return len(shape) > 2 and 1 in shape
+
         # squeeze input tensor
         squeeze_shape = list(input_shape)
         while _squeezable(squeeze_shape):
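
To make the new rule concrete, a standalone restatement of should_squeeze with example shapes (written as a free function purely for illustration; in the pass it is a method that also receives the op):

from typing import List

def should_squeeze(shape: List[int]) -> bool:
    if len(shape) == 3:
        return shape[1] == 1 and shape[0] > 1
    if len(shape) == 4:
        if all(dim == 1 for dim in shape[:-1]):
            return False  # already effectively 2D; squeezing buys nothing
        return 1 in shape[:-1]  # squeeze only if a non-width dim is 1
    return False  # prefer not to introduce extra orchestration ops

print(should_squeeze([4, 1, 32]))     # True
print(should_squeeze([1, 1, 1, 32]))  # False
print(should_squeeze([2, 1, 8, 32]))  # True
print(should_squeeze([2, 3]))         # False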

backends/vulkan/op_registry.py

Lines changed: 2 additions & 0 deletions
@@ -393,6 +393,7 @@ def register_int8_mm_op(features: OpFeatures):
 
 @update_features(exir_ops.edge.et_vk.linear_weight_int4.default)
 def register_int4_mm_op(features: OpFeatures):
+    features.buffer_impl = True
     features.texture_impl = TextureImplFeatures(
         uses_axis_map=False,
         valid_packed_dims={PackedDim.WIDTH},
@@ -401,6 +402,7 @@ def register_int4_mm_op(features: OpFeatures):
     features.optimal_storage = VkStorageType.TEXTURE_3D
     features.optimal_layout = VkMemoryLayout.TENSOR_WIDTH_PACKED
     features.handles_own_prepacking = True
+    features.skip_limits_check = {1}
     return features

backends/vulkan/runtime/api/containers/Tensor.cpp

Lines changed: 1 addition & 3 deletions
@@ -497,9 +497,7 @@ vTensor::vTensor(
   VK_CHECK_COND(
       dim_order_is_valid(dim_order_), "computed dim order is invalid");

-  if (storage_type != utils::kBuffer) {
-    set_logical_limits(storage_.image_extents_);
-  }
+  set_logical_limits(storage_.image_extents_);
 }

 // NOLINTNEXTLINE
