From d0d3ec92282c245ba15886c2db637d212da27ded Mon Sep 17 00:00:00 2001
From: Mergen Nachin
Date: Thu, 11 Apr 2024 22:42:09 -0700
Subject: [PATCH 1/2] Decouple custom ops in llama_transformer.py Part 1/N
 (#3005)

Summary:
This is a no-op

Pull Request resolved: https://github.com/pytorch/executorch/pull/3005

Test Plan:
CI

Run with `python -m examples.models.llama2.export_llama -c stories110M.pt -p params.json -kv --use_sdpa_with_kv_cache -X`

and with `python -m examples.models.llama2.export_llama -c stories110M.pt -p params.json -kv -X`

Make sure both work

Differential Revision: D56048177

Pulled By: mergennachin
---
 examples/models/llama2/builder.py           |   6 +-
 examples/models/llama2/export_llama_lib.py  |   4 +
 examples/models/llama2/llama_transformer.py | 134 ++++++++++++++------
 examples/models/llama2/model.py             |  16 +--
 4 files changed, 104 insertions(+), 56 deletions(-)

diff --git a/examples/models/llama2/builder.py b/examples/models/llama2/builder.py
index 35577ad3ec7..00d71a5b014 100644
--- a/examples/models/llama2/builder.py
+++ b/examples/models/llama2/builder.py
@@ -206,11 +206,7 @@ def source_transform(
     def _get_dynamic_shape(self) -> Any:
         dim = torch.export.Dim("token_dim", max=self.model.params.max_seq_len - 1)
         if self.use_kv_cache:
-            if self.use_sdpa_with_kv_cache:
-                return None
-            else:
-                # return {1: dim}, {0: dim}} TODO update xnnpack to be able to handle dynamic shape kv cache
-                return None
+            return None
         else:
             return ({1: dim},)

diff --git a/examples/models/llama2/export_llama_lib.py b/examples/models/llama2/export_llama_lib.py
index 76cfd00f3b7..8728b3fdd21 100644
--- a/examples/models/llama2/export_llama_lib.py
+++ b/examples/models/llama2/export_llama_lib.py
@@ -492,6 +492,10 @@ def _prepare_for_llama_export(modelname: str, args) -> LlamaEdgeManager:
     if args.expand_rope_table:
         transforms.append(materialze_broadcast_of_rope_freq_cis)

+    if args.use_sdpa_with_kv_cache:
+        pass
+        # TODO: Next diff transforms.append()
+
     return (
         load_llama_model(
             checkpoint=checkpoint_path,
diff --git a/examples/models/llama2/llama_transformer.py b/examples/models/llama2/llama_transformer.py
index 66fc47b17f0..d0794b8c376 100644
--- a/examples/models/llama2/llama_transformer.py
+++ b/examples/models/llama2/llama_transformer.py
@@ -209,6 +209,95 @@ def update(
         return k_out, v_out


+class SDPA(nn.Module):
+    def __init__(
+        self,
+        kv_cache: KVCache,
+        mask,
+        use_sdpa_with_kv_cache_op: bool,
+        dim: int,
+        n_rep: int,
+    ):
+        super().__init__()
+        self.kv_cache = kv_cache
+        self.mask = mask
+        self.use_sdpa_with_kv_cache_op = use_sdpa_with_kv_cache_op
+        self.dim = dim
+        self.n_rep = n_rep
+
+    def forward(
+        self,
+        input_pos: torch.Tensor,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        bsz,
+        seqlen,
+    ) -> torch.Tensor:
+        if not self.use_sdpa_with_kv_cache_op:
+            return self._forward_default(
+                input_pos,
+                q,
+                k,
+                v,
+                bsz,
+                seqlen,
+            )
+        else:
+            return self._forward_custom(
+                input_pos,
+                q,
+                k,
+                v,
+                bsz,
+                seqlen,
+            )
+
+    def _forward_custom(
+        self,
+        input_pos: torch.Tensor,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        bsz,
+        seqlen,
+    ):
+        from .custom_ops import sdpa_with_kv_cache  # noqa
+
+        output = torch.ops.llama.sdpa_with_kv_cache(
+            q,
+            k,
+            v,
+            self.kv_cache.k_cache,
+            self.kv_cache.v_cache,
+            input_pos[-1].item(),
+            seqlen,
+        )
+        return output.view(bsz, seqlen, self.dim)
+
+    def _forward_default(
+        self,
+        input_pos: torch.Tensor,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        bsz,
+        seqlen,
+    ) -> torch.Tensor:
+        q = q.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
+        k = k.transpose(1, 2)
+        v = v.transpose(1, 2)
+
+        k, v = self.kv_cache.update(input_pos, k, v)
+        mask = self.mask[None, None, input_pos]
+
+        k = k.repeat_interleave(self.n_rep, dim=1)
+        v = v.repeat_interleave(self.n_rep, dim=1)
+        y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
+
+        return y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)
+
+
 class Attention(nn.Module):
     def __init__(self, args: ModelArgs, layer_id: int):
         super().__init__()
@@ -229,7 +318,6 @@ def __init__(self, args: ModelArgs, layer_id: int):
         self.wv = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
         self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)

-        self.use_sdpa_with_kv_cache_op = args.use_sdpa_with_kv_cache_op
         self.layer_id = layer_id

         causal_mask = torch.tril(
@@ -250,6 +338,13 @@ def __init__(self, args: ModelArgs, layer_id: int):
                 self.head_dim,
                 not args.use_sdpa_with_kv_cache_op,  # if we are using the custom op dont transpose the cache. Expect untransposed q k v
             )
+            self.SDPA = SDPA(
+                self.kv_cache,
+                self.mask,
+                args.use_sdpa_with_kv_cache_op,
+                self.dim,
+                self.n_rep,
+            )

     def forward(
         self,
@@ -272,41 +367,8 @@ def forward(

         if self.use_kv_cache:
             assert input_pos is not None
-
-            if not self.use_sdpa_with_kv_cache_op:
-
-                q = q.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
-                k = k.transpose(1, 2)
-                v = v.transpose(1, 2)
-
-                k, v = self.kv_cache.update(input_pos, k, v)
-                mask = self.mask[None, None, input_pos]
-
-                k = k.repeat_interleave(self.n_rep, dim=1)
-                v = v.repeat_interleave(self.n_rep, dim=1)
-                y = F.scaled_dot_product_attention(
-                    q, k, v, attn_mask=mask, dropout_p=0.0
-                )
-
-                y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)
-
-                y = self.wo(y)
-                return y
-            else:
-                from .custom_ops import sdpa_with_kv_cache  # noqa
-
-                output = torch.ops.llama.sdpa_with_kv_cache(
-                    q,
-                    k,
-                    v,
-                    self.kv_cache.k_cache,
-                    self.kv_cache.v_cache,
-                    input_pos[-1].item(),
-                    seqlen,
-                )
-                output = output.view(bsz, seqlen, -1)
-                output = self.wo(output)
-                return output
+            output = self.SDPA(input_pos, q, k, v, bsz, seqlen)
+            return self.wo(output)

         q = q.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
         k = k.transpose(1, 2)
diff --git a/examples/models/llama2/model.py b/examples/models/llama2/model.py
index 5428e34b74f..31931adb389 100644
--- a/examples/models/llama2/model.py
+++ b/examples/models/llama2/model.py
@@ -209,11 +209,7 @@ def get_eager_model(self):

     def get_example_inputs(self):
         if self.use_kv_cache:
-            if self.use_sdpa_with_kv_cache_op:
-                return self.get_example_inputs_kvcache_sdpa()
-            else:
-                # return self.get_example_inputs_kvcache() TODO xnnpack does not handle forwarding symints, update partitioner to not partition symints
-                return self.get_example_inputs_kvcache_sdpa()
+            return self.get_example_inputs_kvcache_sdpa()
         else:
             return (
                 torch.tensor(
@@ -231,13 +227,3 @@ def get_example_inputs_kvcache_sdpa(self):
                 [0], dtype=torch.long
             ),  # start_pos, what token of output are we on.)
         )
-
-    def get_example_inputs_kvcache(self):
-        return (
-            torch.tensor(
-                [[1, 2, 3]], dtype=torch.long
-            ),  # tokens, with kv cache our input token length is always just 1 token.
-            torch.tensor(
-                [0, 1, 2], dtype=torch.long
-            ),  # start_pos, what token of output are we on.
-        )

From fbd4d36a112b1176af0e6c248a3bc7147f6446b9 Mon Sep 17 00:00:00 2001
From: Mergen Nachin
Date: Thu, 11 Apr 2024 22:42:23 -0700
Subject: [PATCH 2/2] Decouple custom ops in llama_transformer.py Part 2/N
 (#3007)

Summary:
Pull Request resolved: https://github.com/pytorch/executorch/pull/3007

Keep llama_transformer.py looking like the stock implementation so that it can be reused everywhere. Do the custom-op substitution via a module swap in export_llama_lib.py instead.

Differential Revision: D56048640
---
 examples/models/llama2/TARGETS              |  2 +-
 examples/models/llama2/export_llama_lib.py  | 61 ++++++++++++++++++++-
 examples/models/llama2/llama_transformer.py | 53 ------------------
 3 files changed, 59 insertions(+), 57 deletions(-)

diff --git a/examples/models/llama2/TARGETS b/examples/models/llama2/TARGETS
index 09ebd5aeada..9da7a26d6d7 100644
--- a/examples/models/llama2/TARGETS
+++ b/examples/models/llama2/TARGETS
@@ -18,7 +18,6 @@ runtime.python_library(
     ],
     deps = [
         "//caffe2:torch",
-        "//executorch/examples/models/llama2/custom_ops:custom_ops_aot_py",
     ],
 )

@@ -86,6 +85,7 @@ runtime.python_library(
         "//executorch/backends/vulkan/partitioner:vulkan_partitioner",
         "//executorch/examples/models:model_base",
         "//executorch/examples/models:models",
+        "//executorch/examples/models/llama2/custom_ops:custom_ops_aot_py",
         "//executorch/examples/portable:utils",
         "//executorch/exir:lib",
         "//executorch/sdk/etrecord:etrecord",
diff --git a/examples/models/llama2/export_llama_lib.py b/examples/models/llama2/export_llama_lib.py
index 8728b3fdd21..aa195209ad9 100644
--- a/examples/models/llama2/export_llama_lib.py
+++ b/examples/models/llama2/export_llama_lib.py
@@ -23,7 +23,11 @@
     XnnpackDynamicallyQuantizedPartitioner,
 )

-from executorch.examples.models.llama2.llama_transformer import Transformer
+from executorch.examples.models.llama2.llama_transformer import (
+    KVCache,
+    SDPA,
+    Transformer,
+)

 from executorch.exir.backend.backend_details import CompileSpec
 from executorch.sdk.etrecord import generate_etrecord
@@ -88,6 +92,58 @@ def materialze_broadcast_of_rope_freq_cis(
     return module


+class SDPACustom(torch.nn.Module):
+    def __init__(
+        self,
+        kv_cache: KVCache,
+        mask,
+        dim: int,
+    ):
+        super().__init__()
+        self.kv_cache = kv_cache
+        self.mask = mask
+        self.dim = dim
+
+    def forward(
+        self,
+        input_pos: torch.Tensor,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        bsz,
+        seqlen,
+    ):
+        output = torch.ops.llama.sdpa_with_kv_cache(
+            q,
+            k,
+            v,
+            self.kv_cache.k_cache,
+            self.kv_cache.v_cache,
+            input_pos[-1].item(),
+            seqlen,
+        )
+        return output.view(bsz, seqlen, self.dim)
+
+
+def _replace_sdpa_with_custom_op(module: torch.nn.Module):
+    for name, child in module.named_children():
+        if isinstance(child, SDPA):
+            setattr(
+                module,
+                name,
+                SDPACustom(child.kv_cache, child.mask, child.dim),
+            )
+        else:
+            _replace_sdpa_with_custom_op(child)
+
+
+def replace_sdpa_with_custom_op(module: torch.nn.Module) -> torch.nn.Module:
+    from executorch.examples.models.llama2.custom_ops import sdpa_with_kv_cache  # noqa
+
+    _replace_sdpa_with_custom_op(module)
+    return module
+
+
 def quantize(
     model: torch.nn.Module,
     qmode: str,
@@ -493,8 +549,7 @@ def _prepare_for_llama_export(modelname: str, args) -> LlamaEdgeManager:
         transforms.append(materialze_broadcast_of_rope_freq_cis)

     if args.use_sdpa_with_kv_cache:
-        pass
-        # TODO: Next diff transforms.append()
+        transforms.append(replace_sdpa_with_custom_op)

     return (
         load_llama_model(
diff --git a/examples/models/llama2/llama_transformer.py b/examples/models/llama2/llama_transformer.py
index d0794b8c376..c353a913bf0 100644
--- a/examples/models/llama2/llama_transformer.py
+++ b/examples/models/llama2/llama_transformer.py
@@ -214,14 +214,12 @@ def __init__(
         self,
         kv_cache: KVCache,
         mask,
-        use_sdpa_with_kv_cache_op: bool,
         dim: int,
         n_rep: int,
     ):
         super().__init__()
         self.kv_cache = kv_cache
         self.mask = mask
-        self.use_sdpa_with_kv_cache_op = use_sdpa_with_kv_cache_op
         self.dim = dim
         self.n_rep = n_rep

@@ -233,56 +231,6 @@ def forward(
         v: torch.Tensor,
         bsz,
         seqlen,
-    ) -> torch.Tensor:
-        if not self.use_sdpa_with_kv_cache_op:
-            return self._forward_default(
-                input_pos,
-                q,
-                k,
-                v,
-                bsz,
-                seqlen,
-            )
-        else:
-            return self._forward_custom(
-                input_pos,
-                q,
-                k,
-                v,
-                bsz,
-                seqlen,
-            )
-
-    def _forward_custom(
-        self,
-        input_pos: torch.Tensor,
-        q: torch.Tensor,
-        k: torch.Tensor,
-        v: torch.Tensor,
-        bsz,
-        seqlen,
-    ):
-        from .custom_ops import sdpa_with_kv_cache  # noqa
-
-        output = torch.ops.llama.sdpa_with_kv_cache(
-            q,
-            k,
-            v,
-            self.kv_cache.k_cache,
-            self.kv_cache.v_cache,
-            input_pos[-1].item(),
-            seqlen,
-        )
-        return output.view(bsz, seqlen, self.dim)
-
-    def _forward_default(
-        self,
-        input_pos: torch.Tensor,
-        q: torch.Tensor,
-        k: torch.Tensor,
-        v: torch.Tensor,
-        bsz,
-        seqlen,
     ) -> torch.Tensor:
         q = q.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
         k = k.transpose(1, 2)
@@ -341,7 +289,6 @@ def __init__(self, args: ModelArgs, layer_id: int):
         self.SDPA = SDPA(
             self.kv_cache,
             self.mask,
-            args.use_sdpa_with_kv_cache_op,
             self.dim,
             self.n_rep,
         )
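
Note on the module-swap pattern used in PATCH 2/2: `replace_sdpa_with_custom_op` walks the module tree and replaces every `SDPA` submodule with an `SDPACustom` that calls the fused `torch.ops.llama.sdpa_with_kv_cache` op, so `llama_transformer.py` never has to import the custom op. The snippet below is a minimal, self-contained sketch of that pattern only; `EagerSDPA`, `CustomSDPA`, `Block`, and `swap` are illustrative names and are not part of the ExecuTorch code in these patches.

```python
import torch


class EagerSDPA(torch.nn.Module):
    """Stand-in for the portable SDPA module (plain PyTorch math)."""

    def forward(self, q, k, v):
        return torch.nn.functional.scaled_dot_product_attention(q, k, v)


class CustomSDPA(torch.nn.Module):
    """Stand-in for SDPACustom; a real one would call a fused custom op."""

    def forward(self, q, k, v):
        return torch.nn.functional.scaled_dot_product_attention(q, k, v)


class Block(torch.nn.Module):
    """Toy container playing the role of the Attention module."""

    def __init__(self):
        super().__init__()
        self.attn = EagerSDPA()
        self.proj = torch.nn.Linear(8, 8)


def swap(module: torch.nn.Module) -> torch.nn.Module:
    # Same traversal as _replace_sdpa_with_custom_op: recurse over children
    # and re-assign matching submodules in place via setattr.
    for name, child in module.named_children():
        if isinstance(child, EagerSDPA):
            setattr(module, name, CustomSDPA())
        else:
            swap(child)
    return module


block = swap(Block())
assert isinstance(block.attn, CustomSDPA)  # the eager module was swapped out
```

Because the swap happens in the export path (via the `transforms` list in `_prepare_for_llama_export`), the eager model stays backend-agnostic and the custom op is only pulled in when `--use_sdpa_with_kv_cache` is passed.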