diff --git a/backends/vulkan/_passes/squeeze_unsqueeze_inputs.py b/backends/vulkan/_passes/squeeze_unsqueeze_inputs.py
index a0160efa90f..b4337829d7f 100644
--- a/backends/vulkan/_passes/squeeze_unsqueeze_inputs.py
+++ b/backends/vulkan/_passes/squeeze_unsqueeze_inputs.py
@@ -27,6 +27,19 @@ class SqueezeUnsqueezeInputs(ExportPass):
         exir_ops.edge.aten.gelu.default,
     }
 
+    def should_squeeze(self, op, shape: List[int]) -> bool:  # pyre-ignore
+        if len(shape) == 3:
+            return shape[1] == 1 and shape[0] > 1
+        if len(shape) == 4:
+            # No need to squeeze if all dims are 1 except the width dim
+            if all(dim == 1 for dim in shape[:-1]):
+                return False
+            # Otherwise, check for squeezable dim
+            return 1 in shape[:-1]
+
+        # Prefer not to introduce additional orchestration ops by default
+        return False
+
     def call_operator(
         self,
         op,  # pyre-ignore
@@ -34,18 +47,18 @@ def call_operator(
         kwargs: Dict[str, Argument],
         meta: NodeMetadata,
     ) -> ProxyValue:
-        def _squeezable(shape: List[int]) -> bool:
-            return len(shape) > 2 and 1 in shape
-
         if op not in self._squeezable_ops:
             return super().call_operator(op, args, kwargs, meta)
-
         # pyre-ignore[16]: `None` has no attribute `node`
         input_shape = args[0].node.meta["val"].shape
         output_shape = meta["val"].shape
-        if not _squeezable(input_shape):
+
+        if not self.should_squeeze(op, input_shape):
             return super().call_operator(op, args, kwargs, meta)
 
+        def _squeezable(shape: List[int]) -> bool:
+            return len(shape) > 2 and 1 in shape
+
         # squeeze input tensor
         squeeze_shape = list(input_shape)
         while _squeezable(squeeze_shape):