Skip to content

Commit 7e7cbd1

Browse files
committed
Arm backend: Fix meandim decomp bug
Wrong indexing caused the 4th dim to be skipped when reshaping 5D tensors. Additionally add unittest which covers this case. Signed-off-by: Adrian Lundell <[email protected]> Change-Id: Ia7addeec7e4c01b2afde41e18a0b6ce932832238
1 parent c52d0bf commit 7e7cbd1

File tree

2 files changed

+7
-1
lines changed

2 files changed

+7
-1
lines changed

backends/arm/_passes/decompose_meandim_pass.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ def call_operator(self, op, args, kwargs, meta):
105105

106106
# Reshape back to 5D if necessary
107107
if len(input_shape) > 4:
108-
original_dims = input_shape[0:-4]
108+
original_dims = input_shape[0:-3]
109109
temp_shape = list(x.data.shape)[1:]
110110
temp_shape = original_dims + temp_shape
111111
dims_to_reduce = [dim + len(original_dims) - 1 for dim in dims_to_reduce]

backends/arm/test/ops/test_mean_dim.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,11 @@ class MeanDim(torch.nn.Module):
210210
(1, 2),
211211
False,
212212
),
213+
"rank5_2": lambda: (
214+
torch.rand(1, 4, 7, 3, 2),
215+
(2),
216+
False,
217+
),
213218
"u55_avg_pool_not_supported": lambda: (
214219
torch.rand(1, 1, 1, 257),
215220
(0, 1, 2, 3),
@@ -255,6 +260,7 @@ def test_mean_dim_tosa_BI(test_data):
255260
"rank5_01234": "Rank 5 graph input currently not supported in EthosUBackend (passes since CHW are all averaged over so data order does not matter in this case)",
256261
"rank5_234": "Rank 5 graph input currently not supported in EthosUBackend (passes since CHW are all averaged over so data order does not matter in this case)",
257262
"rank5_12": "Rank 5 graph input currently not supported in EthosUBackend",
263+
"rank5_2": "Rank 5 graph input currently not supported in EthosUBackend",
258264
}
259265

260266

0 commit comments

Comments
 (0)