Skip to content

Commit ec1cd04

Browse files
authored
Add mixed dtype check for XNNPACK partitioner (#9612)
### Summary Fixes #9023 Prevents the partitioner from handling ops with mixed dtypes. ### Test plan Unable to directly test due to auto-casting of dtypes and existing dtype checks in verifier.py.
1 parent 5531a0e commit ec1cd04

File tree

1 file changed

+24
-2
lines changed

1 file changed

+24
-2
lines changed

backends/xnnpack/partition/config/xnnpack_config.py

+24-2
Original file line numberDiff line numberDiff line change
@@ -144,9 +144,10 @@ def check_common_constraints(
144144
return True
145145

146146
def _check_inputs_are_valid_dtypes(self, node, valid_dtypes):
147-
# Check inputs are valid dtypes
147+
# Check inputs are valid and have the same dtypes
148148
# Gather all args which are nodes
149149
args_to_check = []
150+
reference_dtype = None
150151
for arg in node.args:
151152
if isinstance(arg, list) or isinstance(arg, tuple):
152153
for item in arg:
@@ -174,11 +175,32 @@ def _check_inputs_are_valid_dtypes(self, node, valid_dtypes):
174175
if arg_val.dtype not in valid_dtypes:
175176
return False
176177

178+
# Use the first dtype as reference
179+
reference_dtype = reference_dtype or arg_val.dtype
180+
181+
# Check for mixed dtypes
182+
if arg_val.dtype != reference_dtype:
183+
# Get op name if the attribute exists, otherwise use the full node target for logging
184+
op_name = (
185+
node.target.__name__
186+
if hasattr(node.target, "__name__")
187+
else str(node.target)
188+
)
189+
why(
190+
node,
191+
reason=(
192+
f"{op_name} does not support mixed input dtypes, "
193+
f"got: [{reference_dtype}, {arg_val.dtype}]"
194+
),
195+
)
196+
return False
197+
177198
return True
178199

179200
def _check_outputs_are_valid_dtypes(self, node, valid_dtypes):
180-
# Check outputs are valid dtype
201+
# Check outputs are valid
181202
node_val = node.meta.get("val", None)
203+
182204
if node_val is None:
183205
return True
184206

0 commit comments

Comments
 (0)