# -*- coding: utf-8 -*-
# Copyright (c) 2023, Yu Zhang, Songlin Yang

- from typing import Tuple
+ from typing import Optional, Tuple

import torch
import triton
from fla.utils import contiguous


- @torch.jit.script
- def normalize_output(q, k, o):
-     k = k.transpose(-2, -1)
-     k = k.cumsum(-1)
-     k = k.transpose(-2, -1)
-     z = (q * k).sum(-1, keepdim=True)
-     return o / (z + 1e-5)
-
-
@triton.jit
def chunk_simple_gla_fwd_kernel_h(
    k,
    v,
    h,
    g,
-     initial_state,  # initial state of the chunk [B, H, D_head_K, D_head_V]
-     final_state,  # final state of the chunk [B, H, D_head_K, D_head_V]
+     h0,
+     ht,
    s_qk_h,
    s_qk_t,
    s_qk_d,
@@ -36,7 +27,6 @@ def chunk_simple_gla_fwd_kernel_h(
    s_vo_d,
    s_h_h,
    s_h_t,
-     H: tl.constexpr,
    T: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
@@ -53,17 +43,13 @@ def chunk_simple_gla_fwd_kernel_h(
    b_h = tl.zeros([BK, BV], dtype=tl.float32)

    if USE_INITIAL_STATE:
-         p_h0 = tl.make_block_ptr(initial_state + i_bh * K * V,
-                                  (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
+         p_h0 = tl.make_block_ptr(h0 + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
        b_h = tl.load(p_h0, boundary_check=(0, 1)).to(tl.float32)

    for i_t in range(NT):
-         p_k = tl.make_block_ptr(
-             k + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
-         p_v = tl.make_block_ptr(
-             v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
-         p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * K * V,
-                                 (K, V), (s_h_t, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
+         p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
+         p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
+         p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * K * V, (K, V), (s_h_t, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))

        tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1))
        # [BK, BT]
@@ -72,13 +58,12 @@ def chunk_simple_gla_fwd_kernel_h(
        b_v = tl.load(p_v, boundary_check=(0, 1))
        # [BK, BV]
        b_g_last = tl.load(g + i_bh * T + i_t * BT + BT - 1)
-         b_h *= tl.math.exp2(b_g_last)
+         b_h *= tl.exp(b_g_last)
        b_g = tl.load(g + i_bh * T + i_t * BT + tl.arange(0, BT))
-         b_h += tl.dot(b_k, (b_v * tl.math.exp2(b_g_last - b_g)[:, None]).to(b_k.dtype), allow_tf32=False)
+         b_h += tl.dot(b_k, (b_v * tl.exp(b_g_last - b_g)[:, None]).to(b_k.dtype), allow_tf32=False)

    if STORE_FINAL_STATE:
-         p_ht = tl.make_block_ptr(
-             final_state + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
+         p_ht = tl.make_block_ptr(ht + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
        tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), boundary_check=(0, 1))

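With the gates now kept in natural-log space, `chunk_simple_gla_fwd_kernel_h` computes the chunk-level state recurrence

    h[i_t + 1] = exp(b_g_last) * h[i_t] + b_k @ (b_v * exp(b_g_last - b_g)[:, None])

where `b_g` holds the within-chunk cumulative log decays of chunk `i_t` and `b_g_last` is their last entry; this is the same recurrence the exp2 variant computed, minus the log2(e) pre-scaling (see the note on the forward pass below).
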
@@ -99,7 +84,6 @@ def chunk_simple_gla_fwd_kernel_o(
    s_h_h,
    s_h_t,
    scale,
-     H: tl.constexpr,
    T: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
@@ -115,12 +99,9 @@ def chunk_simple_gla_fwd_kernel_o(
    b_o = tl.zeros([BT, BV], dtype=tl.float32)
    b_s = tl.zeros([BT, BT], dtype=tl.float32)
    for i_k in range(tl.cdiv(K, BK)):
-         p_q = tl.make_block_ptr(
-             q + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
-         p_k = tl.make_block_ptr(
-             k + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
-         p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * K * V,
-                                 (K, V), (s_h_t, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
+         p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
+         p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
+         p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * K * V, (K, V), (s_h_t, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))

        # [BT, BK]
        b_q = tl.load(p_q, boundary_check=(0, 1))
@@ -135,16 +116,14 @@ def chunk_simple_gla_fwd_kernel_o(

    p_g = g + i_bh * T + i_t * BT + tl.arange(0, BT)
    b_g = tl.load(p_g)
-     b_o = b_o * tl.math.exp2(b_g)[:, None]
-     b_s = b_s * tl.math.exp2(b_g[:, None] - b_g[None, :])
+     b_o = b_o * tl.exp(b_g)[:, None]
+     b_s = b_s * tl.exp(b_g[:, None] - b_g[None, :])
    b_s = tl.where(m_s, b_s, 0)

-     p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V),
-                             (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
+     p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
    b_v = tl.load(p_v, boundary_check=(0, 1))
    b_o = (b_o + tl.dot(b_s.to(b_v.dtype), b_v, allow_tf32=False)) * scale
-     p_o = tl.make_block_ptr(o + i_bh * s_vo_h, (T, V),
-                             (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
+     p_o = tl.make_block_ptr(o + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
    tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))

@@ -163,7 +142,6 @@ def chunk_simple_gla_bwd_kernel_dh(
    s_h_h,
    s_h_t,
    scale,
-     H: tl.constexpr,
    T: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
@@ -177,22 +155,18 @@ def chunk_simple_gla_bwd_kernel_dh(
    # [BK, BV]
    b_dh = tl.zeros([BK, BV], dtype=tl.float32)
    for i_t in range(NT - 1, -1, -1):
-         p_q = tl.make_block_ptr(
-             q + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
-         p_do = tl.make_block_ptr(
-             do + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
-         p_dh = tl.make_block_ptr(dh + i_bh * s_h_h + i_t * K * V,
-                                  (K, V), (s_h_t, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
+         p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
+         p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
+         p_dh = tl.make_block_ptr(dh + i_bh * s_h_h + i_t * K * V, (K, V), (s_h_t, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))

        tl.store(p_dh, b_dh.to(p_dh.dtype.element_ty), boundary_check=(0, 1))
        # [BK, BT]
        b_q = tl.load(p_q, boundary_check=(0, 1))
-         b_q = (b_q * scale * tl.math.exp2(tl.load(g + i_bh * T +
-                                                   i_t * BT + tl.arange(0, BT)))[None, :]).to(b_q.dtype)
+         b_q = (b_q * scale * tl.exp(tl.load(g + i_bh * T + i_t * BT + tl.arange(0, BT)))[None, :]).to(b_q.dtype)
        # [BT, V]
        b_do = tl.load(p_do, boundary_check=(0, 1))
        # [BK, BV]
-         b_dh *= tl.math.exp2(tl.load(g + i_bh * T + i_t * BT + BT - 1))
+         b_dh *= tl.exp(tl.load(g + i_bh * T + i_t * BT + BT - 1))
        b_dh += tl.dot(b_q, b_do.to(b_q.dtype), allow_tf32=False)

@@ -217,8 +191,6 @@ def chunk_simple_gla_bwd_kernel_dqkv(
    s_h_h,
    s_h_t,
    scale,
-     B: tl.constexpr,
-     H: tl.constexpr,
    T: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
@@ -231,35 +203,28 @@ def chunk_simple_gla_bwd_kernel_dqkv(
    n_bh = tl.num_programs(2)
    o_i = tl.arange(0, BT)

-     p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (K, T),
-                             (s_qk_d, s_qk_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
-     p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K),
-                             (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
+     p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
+     p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))

    b_q = tl.load(p_q, boundary_check=(0, 1))
    b_k = tl.load(p_k, boundary_check=(0, 1))
    b_s = tl.dot(b_k, b_q, allow_tf32=False)
    p_g = g + i_bh * T + i_t * BT + tl.arange(0, BT)
    b_g = tl.load(p_g)
    b_g_last = tl.load(g + i_bh * T + i_t * BT + BT - 1)
-     mask = tl.math.exp2(b_g[None, :] - b_g[:, None])
+     mask = tl.exp(b_g[None, :] - b_g[:, None])
    mask = tl.where(o_i[:, None] <= o_i[None, :], mask * scale, 0)
    b_s = b_s * mask

    b_dq = tl.zeros([BT, BK], dtype=tl.float32)
    b_dk = tl.zeros([BT, BK], dtype=tl.float32)
    b_ds = tl.zeros([BT, BT], dtype=tl.float32)
    for i_v in range(tl.cdiv(V, BV)):
-         p_v = tl.make_block_ptr(
-             v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
-         p_h = tl.make_block_ptr(h + i_bh * s_h_h, (V, NT * K), (1, s_h_t),
-                                 (i_v * BV, i_t * K + i_k * BK), (BV, BK), (0, 1))
-         p_do = tl.make_block_ptr(
-             do + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
-         p_dh = tl.make_block_ptr(dh + i_bh * s_h_h, (NT * K, V),
-                                  (s_h_t, 1), (i_t * K + i_k * BK, i_v * BV), (BK, BV), (1, 0))
-         p_dv = tl.make_block_ptr(dv + (i_k * n_bh + i_bh) * s_vo_h, (T, V),
-                                  (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
+         p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
+         p_h = tl.make_block_ptr(h + i_bh * s_h_h, (V, NT * K), (1, s_h_t), (i_v * BV, i_t * K + i_k * BK), (BV, BK), (0, 1))
+         p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
+         p_dh = tl.make_block_ptr(dh + i_bh * s_h_h, (NT * K, V), (s_h_t, 1), (i_t * K + i_k * BK, i_v * BV), (BK, BV), (1, 0))
+         p_dv = tl.make_block_ptr(dv + (i_k * n_bh + i_bh) * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))

        # [BT, BV]
        b_v = tl.load(p_v, boundary_check=(0, 1))
        b_do = tl.load(p_do, boundary_check=(0, 1))
@@ -273,21 +238,19 @@ def chunk_simple_gla_bwd_kernel_dqkv(
        b_dq += tl.dot(b_do, b_h, allow_tf32=False) * scale
        b_dk += tl.dot(b_v, tl.trans(b_dh), allow_tf32=False)
        # [BT, BV]
-         b_dv = tl.dot(b_k, b_dh, allow_tf32=False) * tl.math.exp2(-b_g + b_g_last)[:, None] + \
-             tl.dot(b_s.to(b_q.dtype), b_do, allow_tf32=False)
+         b_dv = tl.dot(b_k, b_dh, allow_tf32=False) * tl.exp(-b_g + b_g_last)[:, None]
+         b_dv += tl.dot(b_s.to(b_q.dtype), b_do, allow_tf32=False)
        tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))

-     b_dq = b_dq * tl.math.exp2(b_g)[:, None]
-     b_dk = b_dk * tl.math.exp2(-b_g + b_g_last)[:, None]
+     b_dq = b_dq * tl.exp(b_g)[:, None]
+     b_dk = b_dk * tl.exp(-b_g + b_g_last)[:, None]
    b_ds = b_ds * tl.trans(mask)
    b_ds = b_ds.to(b_k.dtype)
    # [BT, BK]
    b_dq += tl.dot(b_ds, b_k, allow_tf32=False)
    b_dk += tl.trans(tl.dot(b_q, b_ds, allow_tf32=False))
-     p_dq = tl.make_block_ptr(dq + i_bh * s_qk_h, (T, K),
-                              (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
-     p_dk = tl.make_block_ptr(dk + i_bh * s_qk_h, (T, K),
-                              (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
+     p_dq = tl.make_block_ptr(dq + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
+     p_dk = tl.make_block_ptr(dk + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
    tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
    tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))

@@ -300,19 +263,14 @@ class SimpleGLAFunction(torch.autograd.Function):
    def forward(ctx, q, k, v, g, scale, initial_state, output_final_state):
        B, H, T, K, V = *q.shape, v.shape[-1]
        BT = 64
-         BK, BV = min(64, triton.next_power_of_2(K)), min(
-             64, triton.next_power_of_2(V))
+         BK, BV = min(64, triton.next_power_of_2(K)), min(64, triton.next_power_of_2(V))
        NT, NK, NV = triton.cdiv(T, BT), triton.cdiv(K, BK), triton.cdiv(V, BV)
-         num_stages = 1
        num_warps = 4 if BK == 64 else 2
+         num_stages = 1

-         if scale is None:
-             scale = K ** -0.5
-
-         BT = 64
        assert T % BT == 0, 'sequence length must be divisible by BT'
        g = g.reshape(B, H, -1, BT)
-         g = g.cumsum(-1) * 1.44269504
+         g = g.cumsum(-1)
        g = g.reshape(B, H, -1)

        final_state = None
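Dropping the `* 1.44269504` factor here is what licenses the `tl.math.exp2` -> `tl.exp` rewrite in the kernels: `1.44269504` is log2(e), and exp2(g * log2(e)) == exp(g). A quick standalone sanity check of that identity (illustrative snippet, not part of the file):

    import torch

    g = torch.randn(16, dtype=torch.float64)
    LOG2_E = 1.4426950408889634  # log2(e), the constant removed above
    # exp2(g * log2(e)) equals exp(g), so both gate formulations match
    assert torch.allclose(torch.exp2(g * LOG2_E), torch.exp(g))
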
@@ -326,7 +284,7 @@ def forward(ctx, q, k, v, g, scale, initial_state, output_final_state):
            q.stride(1), q.stride(2), q.stride(3),
            v.stride(1), v.stride(2), v.stride(3),
            h.stride(1), h.stride(2),
-             H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,
+             T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,
            USE_INITIAL_STATE=initial_state is not None,
            STORE_FINAL_STATE=output_final_state,
            num_warps=num_warps,
@@ -340,30 +298,29 @@ def forward(ctx, q, k, v, g, scale, initial_state, output_final_state):
            v.stride(1), v.stride(2), v.stride(3),
            h.stride(1), h.stride(2),
            scale,
-             H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV,
+             T=T, K=K, V=V, BT=BT, BK=BK, BV=BV,
            num_warps=num_warps,
            num_stages=num_stages
        )

        ctx.save_for_backward(q, k, v, h, g)
+         ctx.scale = scale
        return o.to(q.dtype), final_state

    @staticmethod
    @custom_bwd
    @contiguous
-     def backward(ctx, do, scale, d_ht=None):
+     def backward(ctx, do, dht=None):
        q, k, v, h, g = ctx.saved_tensors

        B, H, T, K, V = *q.shape, v.shape[-1]
        BT = 64
-         BK, BV = min(32 if q.dtype == torch.float32 else 64, triton.next_power_of_2(K)), min(
-             32 if q.dtype == torch.float32 else 64, triton.next_power_of_2(V))
+         BK = min(32 if q.dtype == torch.float32 else 64, triton.next_power_of_2(K))
+         BV = min(32 if q.dtype == torch.float32 else 64, triton.next_power_of_2(V))
        NT, NK, NV = triton.cdiv(T, BT), triton.cdiv(K, BK), triton.cdiv(V, BV)
-         num_stages = 1
        num_warps = 4 if BK == 64 else 2
-
-         if scale is None:
-             scale = K ** -0.5
+         num_stages = 1
+         scale = ctx.scale

        dh = q.new_empty(B, H, NT * K, V)
        grid = (NK, NV, B * H)
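Saving `scale` on `ctx` also repairs the `backward` signature: autograd hands `backward` exactly one gradient per forward output (here `do` for `o` and `dht` for `final_state`), so the old `backward(ctx, do, scale, d_ht=None)` was receiving the final-state gradient in its `scale` slot. A minimal sketch of the general pattern (hypothetical toy example, not from this file):

    import torch

    class Scaled(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x, scale):
            # non-tensor hyper-parameters travel on ctx, not through backward's arguments
            ctx.scale = scale
            return x * scale

        @staticmethod
        def backward(ctx, dy):  # one incoming grad per forward output
            return dy * ctx.scale, None  # one outgoing grad per forward input
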
@@ -373,7 +330,7 @@ def backward(ctx, do, scale, d_ht=None):
            v.stride(1), v.stride(2), v.stride(3),
            dh.stride(1), dh.stride(2),
            scale,
-             H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,
+             T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,
            num_warps=num_warps,
            num_stages=num_stages
        )
@@ -389,7 +346,7 @@ def backward(ctx, do, scale, d_ht=None):
            v.stride(1), v.stride(2), v.stride(3),
            dh.stride(1), dh.stride(2),
            scale,
-             B=B, H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,
+             T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,
            num_warps=num_warps,
            num_stages=num_stages
        )
@@ -409,12 +366,31 @@ def chunk_simple_gla(
    k: torch.Tensor,
    v: torch.Tensor,
    g: torch.Tensor,  # log decay
-     scale: float = None,
+     scale: Optional[float] = None,
    initial_state: torch.Tensor = None,
    output_final_state: bool = False
) -> Tuple[torch.Tensor, torch.Tensor]:
-     if initial_state is not None:
-         initial_state = initial_state.detach()
+     r"""
+     Args:
+         q (torch.Tensor):
+             queries of shape `(B, H, T, K)`
+         k (torch.Tensor):
+             keys of shape `(B, H, T, K)`
+         v (torch.Tensor):
+             values of shape `(B, H, T, V)`
+         g (torch.Tensor):
+             Forget gates of shape `(B, H, T)` applied to keys.
+             Compared to GLA, the gating is head-wise instead of elementwise.
+         scale (Optional[float]):
+             Scale factor for the attention scores.
+             If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
+         initial_state (Optional[torch.Tensor]):
+             Initial state of shape `(B, H, K, V)`. Default: `None`.
+         output_final_state (Optional[bool]):
+             Whether to output the final state of shape `(B, H, K, V)`. Default: `False`.
+     """
+     if scale is None:
+         scale = k.shape[-1] ** -0.5
    g = g.float()
    o, final_state = SimpleGLAFunction.apply(q, k, v, g, scale, initial_state, output_final_state)
    return o, final_state
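For reference, a minimal usage sketch of the updated `chunk_simple_gla` API; the import path and the gate initialization below are assumptions for illustration, not taken from this diff:

    import torch
    from fla.ops.simple_gla import chunk_simple_gla  # assumed import path

    B, H, T, K, V = 2, 4, 256, 64, 64  # T must be divisible by BT = 64
    q = torch.randn(B, H, T, K, device='cuda', dtype=torch.bfloat16)
    k = torch.randn(B, H, T, K, device='cuda', dtype=torch.bfloat16)
    v = torch.randn(B, H, T, V, device='cuda', dtype=torch.bfloat16)
    # head-wise log-decay gates; logsigmoid keeps them negative (decay in (0, 1))
    g = torch.nn.functional.logsigmoid(torch.randn(B, H, T, device='cuda'))

    o, final_state = chunk_simple_gla(q, k, v, g, output_final_state=True)
    print(o.shape, final_state.shape)  # (B, H, T, V) and (B, H, K, V)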