Skip to content

Commit 69094af

Browse files
authored
Arm backend: Add index_copy test (#18283)
* Add index_copy tests, to check the op fully delegates (except for U55) * Add tests for inplace and out-of-place versions Change-Id: Ia1df838cee75eb4cc3061e6b4d58fba0a6bb910d cc @digantdesai @freddan80 @per @zingo @oscarandersson8218 @mansnils @Sebastian-Larsson @robell Signed-off-by: Tom Allsop <tom.allsop@arm.com>
1 parent d9b394a commit 69094af

File tree

2 files changed

+190
-1
lines changed

2 files changed

+190
-1
lines changed

backends/arm/scripts/collect_testname_resources.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@
6161
"upsample_nearest2d.vec",
6262
"index_put.default",
6363
"conv_transpose2d.default",
64+
"index_copy.default",
6465
]
6566
_ALL_EDGE_OPS = _SAMPLE_INPUT.keys() | _CUSTOM_EDGE_OPS
6667

@@ -138,9 +139,16 @@ def _collect_arm_models(models_md: pathlib.Path) -> set[str]:
138139
def _normalize_op_name(edge_name: str) -> str:
139140
op, overload = edge_name.split(".")
140141

142+
# There are ops where we want to keep "copy" in the name
143+
# Add them in this list as we encounter them
144+
ignore_copy_list = {"index_copy"}
145+
141146
op = op.lower()
142147
op = op.removeprefix("_")
143-
op = op.removesuffix("_copy")
148+
149+
if op not in ignore_copy_list:
150+
op = op.removesuffix("_copy")
151+
144152
op = op.removesuffix("_with_indices")
145153

146154
overload = overload.lower()
Lines changed: 181 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,181 @@
1+
# Copyright 2026 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
from typing import Tuple
6+
7+
import torch
8+
from executorch.backends.arm._passes import InsertInt32CastsAfterInt64PlaceholdersPass
9+
from executorch.backends.arm.test import common
10+
from executorch.backends.arm.test.tester.test_pipeline import (
11+
EthosU85PipelineINT,
12+
OpNotSupportedPipeline,
13+
TosaPipelineFP,
14+
TosaPipelineINT,
15+
VgfPipeline,
16+
)
17+
18+
input_t = Tuple[int, torch.Tensor, torch.LongTensor, torch.Tensor] # dim, x, index, y
19+
20+
21+
class IndexCopyModule(torch.nn.Module):
22+
base_test_data = {
23+
"rand_1d": lambda: (
24+
0,
25+
torch.rand((6,), dtype=torch.float32),
26+
torch.LongTensor([0, 2, 5]),
27+
torch.tensor([10.0, 20.0, 30.0], dtype=torch.float32),
28+
),
29+
"rand_3d": lambda: (
30+
0,
31+
torch.rand((4, 2, 3), dtype=torch.float32),
32+
torch.LongTensor([0, 3]),
33+
torch.ones((2, 2, 3), dtype=torch.float32),
34+
),
35+
"rand_3d_dim_1": lambda: (
36+
1,
37+
torch.rand((4, 2, 3), dtype=torch.float32),
38+
torch.LongTensor([0, 1]),
39+
torch.ones((4, 2, 3), dtype=torch.float32),
40+
),
41+
"rand_3d_dim_2": lambda: (
42+
2,
43+
torch.rand((4, 2, 3), dtype=torch.float32),
44+
torch.LongTensor([0]),
45+
torch.ones((4, 2, 1), dtype=torch.float32),
46+
),
47+
"rand_single_index": lambda: (
48+
0,
49+
torch.rand((4, 5), dtype=torch.float32),
50+
torch.LongTensor([0]),
51+
torch.zeros((1, 5), dtype=torch.float32),
52+
),
53+
"rand_single_index_not_zero": lambda: (
54+
0,
55+
torch.rand((4, 5), dtype=torch.float32),
56+
torch.LongTensor([2]),
57+
torch.zeros((1, 5), dtype=torch.float32),
58+
),
59+
"rand_all_rows": lambda: (
60+
0,
61+
torch.rand((3, 4), dtype=torch.float32),
62+
torch.LongTensor([0, 1, 2]),
63+
torch.ones((3, 4), dtype=torch.float32),
64+
),
65+
}
66+
67+
test_data = {
68+
f"{name}_{variant}": (
69+
lambda test_case=test_case, inplace=inplace: (test_case(), inplace)
70+
)
71+
for name, test_case in base_test_data.items()
72+
for variant, inplace in (
73+
("out_of_place", False),
74+
("in_place", True),
75+
)
76+
}
77+
78+
aten_ops = {
79+
False: ["torch.ops.aten.index_put.default"],
80+
True: ["torch.ops.aten.index_put_.default"],
81+
}
82+
exir_op = "executorch_exir_dialects_edge__ops_aten_index_put_default"
83+
84+
def __init__(self, inplace: bool = False):
85+
super().__init__()
86+
self.inplace = inplace
87+
88+
def forward(
89+
self, dim: int, x: torch.Tensor, index: torch.LongTensor, y: torch.Tensor
90+
):
91+
if self.inplace:
92+
return x.index_copy_(dim, index, y)
93+
return x.index_copy(dim, index, y)
94+
95+
96+
xfails_u85 = {
97+
"rand_single_index_not_zero_out_of_place": "MLETORCH-1949: index_copy (SCATTER/INDEX_PUT) produces incorrect results for non-zero indices on U85",
98+
"rand_single_index_not_zero_in_place": "MLETORCH-1949: index_copy (SCATTER/INDEX_PUT) produces incorrect results for non-zero indices on U85",
99+
}
100+
101+
102+
@common.parametrize("test_data", IndexCopyModule.test_data)
103+
def test_index_copy_tosa_FP(test_data):
104+
inputs, inplace = test_data()
105+
module = IndexCopyModule(inplace=inplace)
106+
pipeline = TosaPipelineFP(
107+
module=module,
108+
test_data=inputs,
109+
aten_op=[],
110+
transform_passes=[InsertInt32CastsAfterInt64PlaceholdersPass()],
111+
)
112+
pipeline.run()
113+
114+
115+
@common.parametrize("test_data", IndexCopyModule.test_data)
116+
def test_index_copy_tosa_INT(test_data):
117+
inputs, inplace = test_data()
118+
module = IndexCopyModule(inplace=inplace)
119+
pipeline = TosaPipelineINT(
120+
module=module,
121+
test_data=inputs,
122+
aten_op=IndexCopyModule.aten_ops[inplace],
123+
)
124+
pipeline.run()
125+
126+
127+
@common.parametrize("test_data", IndexCopyModule.test_data)
128+
def test_index_copy_u55_INT(test_data):
129+
inputs, inplace = test_data()
130+
# SCATTER (index_put) is not supported on U55
131+
pipeline = OpNotSupportedPipeline[input_t](
132+
IndexCopyModule(inplace=inplace),
133+
inputs,
134+
{IndexCopyModule.exir_op: 1},
135+
quantize=True,
136+
u55_subset=True,
137+
n_expected_delegates=0,
138+
)
139+
pipeline.run()
140+
141+
142+
@common.parametrize("test_data", IndexCopyModule.test_data, xfails=xfails_u85)
143+
@common.XfailIfNoCorstone320
144+
def test_index_copy_u85_INT(test_data):
145+
inputs, inplace = test_data()
146+
pipeline = EthosU85PipelineINT[input_t](
147+
IndexCopyModule(inplace=inplace),
148+
inputs,
149+
aten_ops=IndexCopyModule.aten_ops[inplace],
150+
)
151+
# int64 index inputs need to be cast to int32; _to_dim_order_copy is not delegated
152+
pipeline.tester.use_portable_ops = True
153+
pipeline.run()
154+
155+
156+
@common.parametrize("test_data", IndexCopyModule.test_data)
157+
@common.SkipIfNoModelConverter
158+
def test_index_copy_vgf_no_quant(test_data):
159+
inputs, inplace = test_data()
160+
pipeline = VgfPipeline[input_t](
161+
IndexCopyModule(inplace=inplace),
162+
inputs,
163+
aten_op=[],
164+
transform_passes=[InsertInt32CastsAfterInt64PlaceholdersPass()],
165+
quantize=False,
166+
)
167+
pipeline.run()
168+
169+
170+
@common.parametrize("test_data", IndexCopyModule.test_data)
171+
@common.SkipIfNoModelConverter
172+
def test_index_copy_vgf_quant(test_data):
173+
inputs, inplace = test_data()
174+
pipeline = VgfPipeline[input_t](
175+
IndexCopyModule(inplace=inplace),
176+
inputs,
177+
aten_op=IndexCopyModule.aten_ops[inplace],
178+
quantize=True,
179+
tosa_spec="TOSA-1.0+INT",
180+
)
181+
pipeline.run()

0 commit comments

Comments
 (0)