Skip to content

Commit 511b4fb

Browse files
Arm backend: Cast int64 buffers to int32 (#18234)
Cast int64 buffers using get_attr in the same way as regular placeholders. --------- Signed-off-by: Adrian Lundell <adrian.lundell@arm.com>
1 parent b6b701e commit 511b4fb

File tree

2 files changed

+19
-3
lines changed

2 files changed

+19
-3
lines changed

backends/arm/_passes/insert_int32_casts_after_int64_placeholders.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,8 +89,10 @@ def _insert_placeholder_i64_to_i32_casts(self, graph_module: torch.fx.GraphModul
8989
modified = False
9090
graph = graph_module.graph
9191
for node in graph.nodes:
92-
if node.op != "placeholder":
92+
if node.op not in ("placeholder", "get_attr"):
9393
continue
94+
if "val" not in node.meta:
95+
continue # Ignore submodule get_attrs
9496
node_val = node.meta["val"]
9597
if not self._is_tensor_of_dtype(node_val, torch.int64):
9698
continue

backends/arm/test/ops/test_clamp.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -257,7 +257,9 @@ def test_clamp_vgf_quant(test_data):
257257
pipeline.run()
258258

259259

260-
aten_op_tensor = "torch.ops.aten.clamp.Tensor"
260+
aten_op_tensor = [
261+
"torch.ops.aten.clamp.Tensor",
262+
]
261263
exir_op_tensor = "executorch_exir_dialects_edge__ops_aten_clamp_Tensor"
262264

263265
test_data_suite_tensor_FP = {
@@ -413,10 +415,22 @@ def test_clamp_tosa_INT_tensor(test_data):
413415
input_tensor, min_val, max_val = test_data()
414416
model = Clamp(min_val, max_val)
415417

418+
# Check that int64 inputs are cast to int32 in the tfa pipeline
419+
if any(
420+
t.dtype == torch.int64
421+
for t in (input_tensor, min_val, max_val)
422+
if isinstance(t, torch.Tensor)
423+
):
424+
aten_op = aten_op_tensor + [
425+
"torch.ops.dim_order_ops._to_dim_order_copy.default"
426+
]
427+
else:
428+
aten_op = aten_op_tensor
429+
416430
pipeline = TosaPipelineINT[input_t](
417431
model,
418432
(input_tensor,),
419-
aten_op_tensor,
433+
aten_op,
420434
exir_op_tensor,
421435
)
422436
pipeline.run()

0 commit comments

Comments
 (0)