From 0b7a17a11dd123ca34462bd0cbb6c8d2cb124eb5 Mon Sep 17 00:00:00 2001
From: Nathanael See
Date: Thu, 6 Feb 2025 18:29:02 -0800
Subject: [PATCH] skip op in partitioning if there are bool input tensors (#8295)

Summary: Vulkan backend does not support bool tensors

Reviewed By: jorgep31415

Differential Revision: D69273733
---
 backends/vulkan/partitioner/vulkan_partitioner.py |  8 ++++++--
 backends/vulkan/utils.py                          | 14 ++++++++++++++
 2 files changed, 20 insertions(+), 2 deletions(-)

diff --git a/backends/vulkan/partitioner/vulkan_partitioner.py b/backends/vulkan/partitioner/vulkan_partitioner.py
index 6ff3fa8d70f..07660c8878f 100644
--- a/backends/vulkan/partitioner/vulkan_partitioner.py
+++ b/backends/vulkan/partitioner/vulkan_partitioner.py
@@ -61,7 +61,7 @@ def __init__(
         self.buffer_limit = buffer_limit
         self.require_dynamic_shapes = require_dynamic_shape
 
-    def op_node_is_compatible(
+    def op_node_is_compatible(  # noqa: C901: Function is too complex
         self, node: torch.fx.Node, features: Optional[OpFeatures] = None
     ) -> Tuple[bool, str]:
         """
@@ -98,8 +98,12 @@ def op_node_is_compatible(
                 and utils.is_tensor_node(arg)
                 and i not in features.skip_limits_check
             ):
+                # Check for bool inputs
+                if utils.tensor_node_is_bool(arg):
+                    return False, "contains bool tensor"
+
                 # Check for high dimensional tensors
-                if utils.is_tensor_node(arg) and utils.tensor_node_is_high_dim(arg):
+                if utils.tensor_node_is_high_dim(arg):
                     return False, "contains high dim tensor"
 
                 arg_texture_layouts = utils.possible_node_memory_layouts(
diff --git a/backends/vulkan/utils.py b/backends/vulkan/utils.py
index 5034747be9d..fa032cd7b4f 100644
--- a/backends/vulkan/utils.py
+++ b/backends/vulkan/utils.py
@@ -80,6 +80,20 @@ def is_tensor_node(node: torch.fx.Node) -> bool:
     return False
 
 
+def tensor_node_is_bool(node: torch.fx.Node) -> bool:
+    """
+    Returns true if a given node contains a tensor with bool dtype
+    """
+    if isinstance(node.meta["val"], FakeTensor):
+        return node.meta["val"].dtype == torch.bool
+    if isinstance(node.meta["val"], list) or isinstance(node.meta["val"], tuple):
+        for fake_tensor in node.meta["val"]:
+            if isinstance(fake_tensor, FakeTensor):
+                if fake_tensor.dtype == torch.bool:
+                    return True
+    return False
+
+
 ##
 ## Memory Layout, Storage Type Determination
 ##
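
Below is a minimal sketch, not part of the patch, of how the new tensor_node_is_bool helper could be exercised on its own. It assumes the module is importable as executorch.backends.vulkan.utils (a packaging assumption, not stated in the patch) and hand-populates node.meta["val"] with FakeTensorMode, standing in for the metadata that export normally attaches before the partitioner calls op_node_is_compatible.

    # Sketch only: exercises tensor_node_is_bool in isolation.
    import torch
    import torch.fx
    from torch._subclasses.fake_tensor import FakeTensorMode

    from executorch.backends.vulkan import utils  # assumed import path


    def make_placeholder_node(dtype: torch.dtype) -> torch.fx.Node:
        # Build a trivial graph with one placeholder and attach a FakeTensor
        # of the requested dtype to its meta["val"], mimicking export metadata.
        graph = torch.fx.Graph()
        node = graph.placeholder("x")
        fake_mode = FakeTensorMode()
        node.meta["val"] = fake_mode.from_tensor(torch.zeros(2, 2, dtype=dtype))
        return node


    bool_node = make_placeholder_node(torch.bool)
    float_node = make_placeholder_node(torch.float32)

    # With this patch, a bool input makes op_node_is_compatible reject the op
    # with "contains bool tensor"; a float input passes this particular check.
    print(utils.tensor_node_is_bool(bool_node))   # True
    print(utils.tensor_node_is_bool(float_node))  # False

Running the same check end to end through the partitioner would additionally require OpFeatures metadata for the surrounding op, which is outside the scope of this sketch.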