diff --git a/examples/models/llama2/export_llama_lib.py b/examples/models/llama2/export_llama_lib.py
index 449824a33b9..da42f177f76 100644
--- a/examples/models/llama2/export_llama_lib.py
+++ b/examples/models/llama2/export_llama_lib.py
@@ -9,6 +9,7 @@
 import argparse
 import copy
 import logging
+import math
 import os
 import shlex
 
@@ -143,6 +144,80 @@ def replace_sdpa_with_custom_op(module: torch.nn.Module) -> torch.nn.Module:
     return module
 
 
+class SDPASimple(torch.nn.Module):
+
+    def __init__(
+        self,
+        kv_cache: KVCache,
+        dim: int,
+        head_dim: int,
+        n_rep: int,
+    ):
+        super().__init__()
+        self.kv_cache = kv_cache
+        self.dim = dim
+        self.head_dim = head_dim
+        self.n_rep = n_rep
+
+    def forward(
+        self,
+        input_pos: torch.Tensor,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        bsz,
+        seqlen,
+        mask,
+    ):
+        q = q.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
+        k = k.transpose(1, 2)
+        v = v.transpose(1, 2)
+
+        k, v = self.kv_cache.update(input_pos, k, v)
+        attn_mask = mask[None, None, input_pos]
+
+        k = k.repeat_interleave(self.n_rep, dim=1)
+        v = v.repeat_interleave(self.n_rep, dim=1)
+        scale_factor = 1 / math.sqrt(q.size(-1))
+        attn_weight = q @ k.transpose(-2, -1) * scale_factor
+        attn_weight += attn_mask
+        attn_weight = torch.softmax(attn_weight, dim=-1)
+        y = attn_weight @ v
+
+        return y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)
+
+
+def replace_sdpa_with_simple_sdpa(module: torch.nn.Module):
+    for name, child in module.named_children():
+        if isinstance(child, SDPA):
+            setattr(
+                module,
+                name,
+                SDPASimple(child.kv_cache, child.dim, child.head_dim, child.n_rep),
+            )
+        else:
+            replace_sdpa_with_simple_sdpa(child)
+    return module
+
+
+def replace_causal_mask(module: torch.nn.Module):
+    for buffer_fqn_name, buffer in module.named_buffers():
+        buffer_name = buffer_fqn_name.split(".")[-1]
+        if buffer_name == "mask":
+            max_seq_len = buffer.shape[-1]
+            mask = torch.full(
+                (max_seq_len, max_seq_len),
+                float("-inf"),
+                device="cpu",
+            )
+
+            mask = torch.triu(mask, diagonal=1)
+            module.register_buffer(buffer_name, mask)
+    for _, child in module.named_children():
+        replace_causal_mask(child)
+    return module
+
+
 def quantize(
     model: torch.nn.Module,
     qmode: str,
diff --git a/examples/models/llama2/tests/TARGETS b/examples/models/llama2/tests/TARGETS
new file mode 100644
index 00000000000..3d2aef6209f
--- /dev/null
+++ b/examples/models/llama2/tests/TARGETS
@@ -0,0 +1,15 @@
+load("@fbcode_macros//build_defs:python_unittest.bzl", "python_unittest")
+
+oncall("executorch")
+
+python_unittest(
+    name = "test_simple_sdpa",
+    srcs = [
+        "test_simple_sdpa.py",
+    ],
+    deps = [
+        "//caffe2:torch",
+        "//executorch/examples/models/llama2:export_library",
+        "//executorch/examples/models/llama2:llama_transformer",
+    ],
+)
diff --git a/examples/models/llama2/tests/test_simple_sdpa.py b/examples/models/llama2/tests/test_simple_sdpa.py
new file mode 100644
index 00000000000..e5360f0e0fa
--- /dev/null
+++ b/examples/models/llama2/tests/test_simple_sdpa.py
@@ -0,0 +1,54 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import copy
+import unittest
+
+import torch
+from executorch.examples.models.llama2.export_llama_lib import SDPASimple
+from executorch.examples.models.llama2.llama_transformer import KVCache, SDPA
+
+
+class SDPATest(unittest.TestCase):
+    def test_simple_sdpa(self):
+        # Verify that the simple SDPA implementation matches the original SDPA module defined in llama_transformer.py.
+        max_batch_size = 1
+        max_seq_length = 128
+        n_heads = 8
+        head_dim = 8
+        dim = 64
+        n_rep = 1
+        bsz = 1
+        seqlen = 1
+        n_local_heads = n_heads
+        kv_cache = KVCache(
+            max_batch_size=max_batch_size,
+            max_seq_length=max_seq_length,
+            n_heads=n_heads,
+            head_dim=head_dim,
+            transpose_cache=True,
+        )
+        sdpa = SDPA(
+            kv_cache=copy.deepcopy(kv_cache), dim=dim, head_dim=head_dim, n_rep=n_rep
+        )
+        input_pos = torch.tensor([0])
+        query = torch.randn(1, 1, n_local_heads, head_dim)
+        key = torch.randn(1, 1, n_local_heads, head_dim)
+        value = torch.randn(1, 1, n_local_heads, head_dim)
+        mask = torch.randn(max_seq_length, max_seq_length)
+        sdpa_output = sdpa(
+            input_pos, query, key, value, bsz=bsz, seqlen=seqlen, mask=mask
+        )
+
+        simple_sdpa = SDPASimple(
+            kv_cache=copy.deepcopy(kv_cache), dim=dim, head_dim=head_dim, n_rep=n_rep
+        )
+        simple_sdpa_output = simple_sdpa(
+            input_pos, query, key, value, bsz=bsz, seqlen=seqlen, mask=mask
+        )
+
+        # Compare the outputs of the two SDPA implementations.
+        self.assertTrue(torch.allclose(sdpa_output, simple_sdpa_output))
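
Note (not part of the patch): beyond the unit test above, the two new source transforms are meant to be applied together on an eager Llama-style model before export, since SDPASimple adds the mask buffer directly to the attention scores and therefore expects an additive 0/-inf causal mask. The sketch below illustrates that usage; the `prepare_for_export` helper is a hypothetical name introduced here for illustration, while `replace_sdpa_with_simple_sdpa` and `replace_causal_mask` are the functions this diff adds to export_llama_lib.py.

    # Illustrative sketch, assuming a model whose attention layers use the SDPA
    # module from llama_transformer.py and register a "mask" buffer.
    import torch

    from executorch.examples.models.llama2.export_llama_lib import (
        replace_causal_mask,
        replace_sdpa_with_simple_sdpa,
    )


    def prepare_for_export(model: torch.nn.Module) -> torch.nn.Module:
        # Hypothetical helper: swap every SDPA submodule for the decomposed
        # SDPASimple implementation (matmul + softmax instead of the fused op)...
        model = replace_sdpa_with_simple_sdpa(model)
        # ...and rebuild the "mask" buffer as an additive causal mask
        # (0 on and below the diagonal, -inf above it), which is the form
        # SDPASimple adds to the attention scores.
        model = replace_causal_mask(model)
        return model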