From b2a91e968e7ab8f8f1998552cf44ec278b8da86e Mon Sep 17 00:00:00 2001 From: Stephen Jia Date: Fri, 4 Apr 2025 13:49:31 -0700 Subject: [PATCH] [ET-VK][ez] Make squeeze insertion requirements more strict ## Context Refactor the `SqueezeUnsqueezeInputs` pass to be more clear about its intention. For Llama models, input shapes to 4 bit linear will oftentimes have the shape `[1, seq_len, dim]`; under the current implementation of the pass, the input would be squeezed to `[seq_len, dim]` even though the squeeze is not necessary. The original intention of thispass was to squeeze inputs with shape `[batch_size, 1, dim]` to `[batch_size, dim]` before calling the 4-bit linear operator. ## Changes To avoid inserting unnecessary squeeze/unsqueezes, be more specific about when squeeze/unsqueeze should be added. I would like to consider refactoring this pass in the future, since the logic is currently a bit uninttuitive. Squeeze/unsqueeze is also inserted for gelu and relu, but this is to create a chain of unsqueeze/squeeze that will be eliminated by a later pass (see https://github.com/pytorch/executorch/pull/8601 / D69673068). I think eventually it will be good to rewrite the pass to make shape management more explicit and self contained within the pass rather than inserting ops which are expected to be removed later on. Differential Revision: [D72480178](https://our.internmc.facebook.com/intern/diff/D72480178/) [ghstack-poisoned] --- .../vulkan/_passes/squeeze_unsqueeze_inputs.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/backends/vulkan/_passes/squeeze_unsqueeze_inputs.py b/backends/vulkan/_passes/squeeze_unsqueeze_inputs.py index a0160efa90f..60cede7610d 100644 --- a/backends/vulkan/_passes/squeeze_unsqueeze_inputs.py +++ b/backends/vulkan/_passes/squeeze_unsqueeze_inputs.py @@ -27,6 +27,13 @@ class SqueezeUnsqueezeInputs(ExportPass): exir_ops.edge.aten.gelu.default, } + def should_squeeze(self, op, shape: List[int]) -> bool: + if len(shape) == 3: + return shape[1] == 1 and shape[0] > 1 + + # Prefer not to introduce additional orchestration ops by default + return False + def call_operator( self, op, # pyre-ignore @@ -34,18 +41,18 @@ def call_operator( kwargs: Dict[str, Argument], meta: NodeMetadata, ) -> ProxyValue: - def _squeezable(shape: List[int]) -> bool: - return len(shape) > 2 and 1 in shape - if op not in self._squeezable_ops: return super().call_operator(op, args, kwargs, meta) - # pyre-ignore[16]: `None` has no attribute `node` input_shape = args[0].node.meta["val"].shape output_shape = meta["val"].shape - if not _squeezable(input_shape): + + if not self.should_squeeze(op, input_shape): return super().call_operator(op, args, kwargs, meta) + def _squeezable(shape: List[int]) -> bool: + return len(shape) > 2 and 1 in shape + # squeeze input tensor squeeze_shape = list(input_shape) while _squeezable(squeeze_shape):