
Commit 15d109b

refactor-attention
1 parent 2d56897 commit 15d109b

File tree: 3 files changed, +36 -10 lines changed

    examples/models/llama/attention.py
    examples/models/llama/llama_transformer.py
    examples/models/llama/model.py

examples/models/llama/attention.py (+15 -1)

@@ -162,7 +162,16 @@ def forward(
 
 @register_attention("mha")
 class AttentionMHA(Attention):
-    def __init__(self, args: ModelArgs, layer_id: int, rope: Rope):
+    def __init__(
+        self,
+        args: ModelArgs,
+        layer_id: int,
+        rope: Rope,
+        # wq: nn.Module,
+        # wk: nn.Module,
+        # wv: nn.Module,
+        # wo: nn.Module,
+    ):
         super().__init__()
         self.use_kv_cache = args.use_kv_cache
         self.n_heads = args.n_heads
@@ -196,6 +205,11 @@ def __init__(self, args: ModelArgs, layer_id: int, rope: Rope):
         )
         self.wo = nn.Linear(self.n_heads * self.head_dim, self.dim, bias=False)
 
+        # self.wq = wq
+        # self.wk = wk
+        # self.wv = wv
+        # self.wo = wo
+
         self.layer_id = layer_id
 
         self.rope = rope

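For context on the registry used above: @register_attention("mha") keeps AttentionMHA discoverable under the "mha" key in ATTENTION_REGISTRY, which model.py now queries directly. Below is a minimal, self-contained sketch of how such a decorator-based registry typically works; the actual implementation in attention.py may differ in its details.

# Illustrative sketch only; the real register_attention / ATTENTION_REGISTRY
# in attention.py may be implemented differently.
from typing import Dict, Type

ATTENTION_REGISTRY: Dict[str, Type] = {}


def register_attention(name: str):
    """Decorator that registers an attention class under a string key."""

    def decorator(cls):
        ATTENTION_REGISTRY[name] = cls
        return cls

    return decorator


@register_attention("mha")
class AttentionMHA:  # stand-in for the real Attention subclass
    def __init__(self, args, layer_id, rope):
        self.layer_id = layer_id
        self.rope = rope


# Lookup mirrors what model.py now does:
# cls = ATTENTION_REGISTRY[model_args.attention_type]
# attention = cls(model_args, layer_id, rope)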
examples/models/llama/llama_transformer.py (+7 -8)

@@ -19,6 +19,7 @@
 
 from executorch.examples.models.llama.model_args import ModelArgs
 from executorch.examples.models.llama.norm import RMSNorm
+
 from executorch.examples.models.llama.rope import Rope
 from torch import nn
 
@@ -83,7 +84,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 
 
 class TransformerBlock(nn.Module):
-    def __init__(self, layer_id: int, args: ModelArgs, rope: Rope):
+    def __init__(self, args: ModelArgs, attention: Attention):
         super().__init__()
         self.use_kv_cache = args.use_kv_cache
         self.n_heads = args.n_heads
@@ -94,8 +95,8 @@ def __init__(self, layer_id: int, args: ModelArgs, rope: Rope):
                 f"Unknown attention type: {args.attention_type}. "
                 f"Available: {list(ATTENTION_REGISTRY.keys())}"
             )
-        cls = ATTENTION_REGISTRY[args.attention_type]
-        self.attention = cls(args, layer_id, rope)
+
+        self.attention = attention
         if args.moe:
             self.block_sparse_moe = MOEFeedForward(args)
         else:
@@ -117,7 +118,7 @@ def forward(self, x, freqs_cos, freqs_sin, attn_options: ForwardOptions): # x:
 
 
 class Transformer(nn.Module):
-    def __init__(self, params: ModelArgs):
+    def __init__(self, params: ModelArgs, layers: nn.ModuleList, rope: Rope):
         super().__init__()
         self.params = params
         self.vocab_size = params.vocab_size
@@ -130,10 +131,8 @@ def __init__(self, params: ModelArgs):
             if self.apply_embedding
             else None
        )
-        self.rope = Rope(params)
-        self.layers = torch.nn.ModuleList()
-        for layer_id in range(params.n_layers):
-            self.layers.append(TransformerBlock(layer_id, params, self.rope))
+        self.layers = layers
+        self.rope = rope
         self.norm = RMSNorm(params.dim, eps=params.norm_eps)
         self.output = (
             nn.Linear(params.dim, params.vocab_size, bias=False)

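With these signature changes, TransformerBlock receives a ready-built Attention module and Transformer receives the prebuilt layer list and rope, so construction moves to the caller. A minimal wiring sketch follows, assuming the constructor signatures shown in this diff, that ModelArgs is default-constructible, and that ATTENTION_REGISTRY is importable from attention.py; the model.py diff below does the same thing inside torch.device("meta").

# Sketch of the new dependency-injection wiring (assumptions noted above).
from executorch.examples.models.llama.attention import ATTENTION_REGISTRY
from executorch.examples.models.llama.llama_transformer import (
    Transformer,
    TransformerBlock,
)
from executorch.examples.models.llama.model_args import ModelArgs
from executorch.examples.models.llama.rope import Rope
from torch import nn

args = ModelArgs()  # assumption: default-constructible config
rope = Rope(args)

# Each block gets its attention module injected; in principle different
# layers could now receive different registered attention implementations.
layers = nn.ModuleList()
attention_cls = ATTENTION_REGISTRY[args.attention_type]
for layer_id in range(args.n_layers):
    layers.append(TransformerBlock(args, attention_cls(args, layer_id, rope)))

model = Transformer(args, layers, rope)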
examples/models/llama/model.py (+14 -1)

@@ -18,6 +18,7 @@
 from executorch.examples.models.llama.llama_transformer import Transformer
 
 from executorch.examples.models.llama.model_args import ModelArgs
+from executorch.examples.models.llama.rope import Rope
 
 try:
     from .fairseq2 import convert_to_llama_checkpoint
@@ -173,7 +174,19 @@ def __init__(self, **kwargs):
         # They possess all other metadata a tensor carries such as size, stride, requires_grad.
         with torch.device("meta"):
             # Model itself is loaded in default dtype, fp32.
-            self.model_ = Transformer(model_args)
+
+            # Construct attention layers.
+            rope = Rope(model_args)
+            layers = nn.ModuleList()
+            cls = ATTENTION_REGISTRY[model_args.attention_type]
+            for layer_id in range(model_args.n_layers):
+                attention = cls(model_args, layer_id, rope)
+                transformer_block = TransformerBlock(model_args, attention)
+                layers.append(transformer_block)
+
+            # Construct transformer model.
+            self.model_ = Transformer(model_args, layers, rope)
+
         # Get checkpoint dtype.
         if checkpoint:
             self.model_.checkpoint_dtype = get_checkpoint_dtype(checkpoint)

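model.py now owns construction of the rope and the attention layers; the wq/wk/wv/wo parameters left commented out in AttentionMHA hint that the projection modules could eventually be built and injected here as well. The sketch below is purely hypothetical and not part of this commit: it only shows what constructing and passing those projections might look like, with illustrative shapes (grouped-query attention would size wk/wv by n_kv_heads).

import torch.nn as nn

# Hypothetical only: these parameters are still commented out in AttentionMHA.
dim, n_heads, head_dim = 512, 8, 64
wq = nn.Linear(dim, n_heads * head_dim, bias=False)
wk = nn.Linear(dim, n_heads * head_dim, bias=False)
wv = nn.Linear(dim, n_heads * head_dim, bias=False)
wo = nn.Linear(n_heads * head_dim, dim, bias=False)

# attention = AttentionMHA(args, layer_id, rope, wq=wq, wk=wk, wv=wv, wo=wo)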