@@ -133,23 +133,27 @@ def __init__(
 
     def forward(
         self,
-        input_pos: torch.Tensor,
+        input_pos: Optional[torch.Tensor],
         q: torch.Tensor,  # Already have rotary embeddings. (bs, n_local_heads, seqlen, head_dim)
         k: torch.Tensor,  # Already have rotary embeddings. (bs, n_local_kv_heads, seqlen, head_dim)
         v: torch.Tensor,  # (bs, n_local_kv_heads, seqlen, head_dim)
         bsz,
         seqlen,
         mask: torch.Tensor,
     ) -> torch.Tensor:
-        if self.enable_dynamic_shape:
-            start_pos = input_pos[-1].item()
-            torch._check_is_size(start_pos)
-            torch._check(start_pos < self.max_context_len)
-            seq_length = q.size(2)
-            # pyre-ignore: Incompatible parameter type [6]
-            attn_mask = mask.narrow(0, start_pos, seq_length)
+        if input_pos is None:
+            # No kv cache
+            attn_mask = mask[:seqlen, :seqlen]
         else:
-            attn_mask = mask[None, None, input_pos]
+            if self.enable_dynamic_shape:
+                start_pos = input_pos[-1].item()
+                torch._check_is_size(start_pos)
+                torch._check(start_pos < self.max_context_len)
+                seq_length = q.size(2)
+                # pyre-ignore: Incompatible parameter type [6]
+                attn_mask = mask.narrow(0, start_pos, seq_length)
+            else:
+                attn_mask = mask[None, None, input_pos]
 
         # TODO(kimishpatel): This should not be necessary because scaled_dot_product_attention
         # can natively support GQA now. But needs enable_gqa=True
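For reference, a minimal standalone sketch of the mask-selection logic introduced in the hunk above. The helper name `select_attn_mask` and its flat argument list are illustrative only (the diff keeps this logic inline in `SDPA.forward`); the three branches mirror the added code.

```python
from typing import Optional

import torch


def select_attn_mask(
    mask: torch.Tensor,                 # (max_context_len, max_context_len) causal mask
    input_pos: Optional[torch.Tensor],  # None when running without a KV cache
    q: torch.Tensor,                    # (bs, n_heads, seqlen, head_dim)
    seqlen: int,
    max_context_len: int,
    enable_dynamic_shape: bool,
) -> torch.Tensor:
    if input_pos is None:
        # No KV cache: plain causal mask over the current sequence only.
        return mask[:seqlen, :seqlen]
    if enable_dynamic_shape:
        # Export-friendly path: bound the start position and slice the mask with narrow().
        start_pos = input_pos[-1].item()
        torch._check_is_size(start_pos)
        torch._check(start_pos < max_context_len)
        return mask.narrow(0, start_pos, q.size(2))
    # Static-shape path: gather the mask rows for the positions being decoded.
    return mask[None, None, input_pos]
```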
@@ -218,13 +222,13 @@ def __init__(self, args: ModelArgs, layer_id: int, rope: Rope):
                 self.head_dim,
                 args.enable_dynamic_shape,
             )
-            self.SDPA = SDPA(
-                dim=self.n_local_heads * self.head_dim,
-                head_dim=self.head_dim,
-                n_rep=self.n_rep,
-                max_context_len=self.max_context_len,
-                enable_dynamic_shape=args.enable_dynamic_shape,
-            )
+        self.SDPA = SDPA(
+            dim=self.n_local_heads * self.head_dim,
+            head_dim=self.head_dim,
+            n_rep=self.n_rep,
+            max_context_len=self.max_context_len,
+            enable_dynamic_shape=args.enable_dynamic_shape,
+        )
 
     def forward(
         self,
@@ -257,21 +261,5 @@ def forward(
         if self.use_kv_cache:
             assert input_pos is not None
             k, v = self.kv_cache.update(input_pos, k, v)
-            output = self.SDPA(input_pos, q, k, v, bsz, seqlen, self.mask)
-            return self.wo(output), None
-
-        # grouped multiquery attention: expand out keys and values
-        k = k.repeat_interleave(self.n_rep, dim=1)
-        v = v.repeat_interleave(self.n_rep, dim=1)
-
-        assert hasattr(self, "mask")
-
-        mask = self.mask[:seqlen, :seqlen]
-
-        output = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
-
-        output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
-
-        output = self.wo(output)
-
-        return output, None
+        output = self.SDPA(input_pos, q, k, v, bsz, seqlen, self.mask)
+        return self.wo(output), None
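The deleted fallback above is the old non-KV-cache path that now routes through `self.SDPA` with `input_pos=None`. A self-contained sketch of what that removed branch did, for comparison; the function name `no_cache_attention` and passing `wo`/`n_rep` as arguments are illustrative, not part of the original module.

```python
import torch
import torch.nn.functional as F


def no_cache_attention(q, k, v, mask, n_rep: int, wo: torch.nn.Linear):
    # q: (bs, n_heads, seqlen, head_dim); k, v: (bs, n_kv_heads, seqlen, head_dim)
    bsz, _, seqlen, _ = q.shape
    # Grouped multiquery attention: expand K/V so every query head has a matching KV head.
    k = k.repeat_interleave(n_rep, dim=1)
    v = v.repeat_interleave(n_rep, dim=1)
    attn_mask = mask[:seqlen, :seqlen]
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=0.0)
    out = out.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
    return wo(out)
```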