@@ -2,13 +2,13 @@
 
 template <int block_size>
 static __global__ void ssm_conv_f32(
-        const float * src0, const float * src1, const float * src2, const float * src3,
-        const int src0_ne0, const int src0_nb1, const int src0_nb2,
-        const int src1_nb0, const int src1_nb1,
-        const int src2_nb1, const int src2_nb2,
-        const int src3_nb1,
+        const float * src0, const float * src1, const float * src2,
+        const int src0_nb1, const int src0_nb2,
+        const int src1_nb0, const int src1_nb1, const int src1_nb2,
+        const int src2_nb1,
         float * dst,
-        const int nc, const int nr, const int n_t, const int n_kv) {
+        const int dst_nb0, const int dst_nb1, const int dst_nb2,
+        const int nc, const int nr, const int n_t, const int n_s) {
 
 //  const int row = blockIdx.x*blockDim.y + threadIdx.y;
     const int tid = threadIdx.x;
@@ -24,136 +24,118 @@ static __global__ void ssm_conv_f32(
     const int ir1 = min(ir0 + dr, nr);
     const int ir  = ir1 - ir0;
 
-    if (n_kv > 1) {
-        // multiple sequences means it's hard to know when it's the first time a state is read,
-        // so copy them all over to the destination, just to be sure.
-        for (int i3 = 0; i3 < n_kv; ++i3) {
-            float * s0 = (float *) ((char *) src0 + ir0*src0_nb1 + i3*src0_nb2);
-            float * s  = (float *) ((char *) dst + ir0*src2_nb1 + i3*src2_nb2 + nr*n_t*sizeof(float));
-            // can't use memcpy because of d_conv vs d_conv - 1
-            for (int i1 = 0; i1 < ir; ++i1) {
-                for (int i0 = 0; i0 < nc - 1; ++i0) {
-                    // copy s0 to last (d_conv - 1) columns of s
-                    s[1 + i0 + i1*nc] = s0[i0 + i1*(nc - 1)];
-                }
-            }
-        }
-    }
+    // TODO: maybe require src0 to have d_conv columns instead of (d_conv - 1)?
+    // This would avoid having to copy into an intermediate buffer, but the state would be bigger.
 
-    for (int i2 = 0; i2 < n_t; ++i2) {
-        int32_t * sq = (int32_t *) ((char *) src3 + i2*src3_nb1); // {n_kv, n_tokens}
-        float *   x = (float *) ((char *) dst + ir0*sizeof(float) + i2*(nr*sizeof(float))); // {d_inner, n_tokens}
-        float *   s = (float *) ((char *) dst + ir0*src2_nb1 + sq[0]*src2_nb2 + nr*n_t*sizeof(float)); // {d_conv, d_inner, n_kv}
-        float *  s0; // {d_conv - 1, d_inner, n_kv}
-        float *  x0 = (float *) ((char *) src1 + ir0*src1_nb0 + i2*src1_nb1); // {d_inner, n_tokens}
-        float *   c = (float *) ((char *) src2 + ir0*src2_nb1); // {d_conv, d_inner}
-        int ne0s0;
-
-        // avoid needing to copy the state for the first token
-        if (i2 == 0) {
-            s0 = (float *) ((char *) src0 + ir0*src0_nb1 + sq[0]*src0_nb2); // {d_conv - 1, d_inner, n_kv}
-            ne0s0 = src0_ne0;
-        } else {
-            // the source is the last (d_conv - 1) columns of the destination
-            s0 = s + 1;
-            ne0s0 = nc;
-        }
+    // float * s = (float *) params->wdata + (nc*dr + CACHE_LINE_SIZE_F32) * ith;
+    extern __shared__ float wdata_f32[]; // work buffer for all threads
+    float * s = (float *) wdata_f32 + nc*dr*tid;
 
-    // d_inner
+    for (int i3 = 0; i3 < n_s; ++i3) {
+        float * s0 = (float *) ((char *) src0 + ir0*src0_nb1 + i3*src0_nb2); // {d_conv - 1, d_inner, n_s}
+
+        // copy the state into working memory
+        // can't use memcpy because (d_conv) != (d_conv - 1)
         for (int i1 = 0; i1 < ir; ++i1) {
-            // shift state left
             for (int i0 = 0; i0 < nc - 1; ++i0) {
-                s[i0 + i1*nc] = s0[i0 + i1*ne0s0];
+                s[1 + i0 + i1*nc] = s0[i0 + i1*(nc - 1)];
             }
-            // insert x on the last column
-            s[(nc - 1) + i1*nc] = x0[i1];
         }
 
-    // handle copies when there are multiple output states
-    for (int i3 = 1; i3 < n_kv; ++i3) {
-        int32_t seq = sq[i3];
-        if (0 <= seq && seq < n_kv) {
-            float * s1 = s + (seq - sq[0])*nc*nr;
+        for (int i2 = 0; i2 < n_t; ++i2) {
+            float * x  = (float *) ((char *) dst  + ir0*dst_nb0  + i2*dst_nb1  + i3*dst_nb2);  // {d_inner, n_t, n_s}
+            float * x0 = (float *) ((char *) src1 + ir0*src1_nb0 + i2*src1_nb1 + i3*src1_nb2); // {d_inner, n_t, n_s}
+            float * c  = (float *) ((char *) src2 + ir0*src2_nb1); // {d_conv, d_inner}
 
-            // memcpy(s1, s, nc*ir*sizeof(float));
-            for (int i4 = 0; i4 < nc*ir; i4++) {
-                s1[i4] = s[i4];
+            // shift state left
+            // memmove(s, s + 1, (nc*ir - 1) * sizeof(float));
+            for (int i4 = 0; i4 < nc*ir - 1; ++i4) {
+                s[i4] = s[i4+1];
+            }
+
+            // d_inner
+            for (int i1 = 0; i1 < ir; ++i1) {
+                // insert x on the last column
+                s[(nc - 1) + i1*nc] = x0[i1];
+            }
+
+            // it seems a little faster when this is separate from the state shift
+            for (int i1 = 0; i1 < ir; ++i1) {
+                // rowwise dot product
+                // NOTE: not using ggml_vec_dot_f32, because its sum is in double precision
+                float sumf = 0.0f;
+                for (int i0 = 0; i0 < nc; ++i0) {
+                    int i = i0 + i1*nc;
+                    sumf += s[i] * c[i];
                 }
-        } else {
-            // stop at negative or too big seq_ids
-            break;
+                x[i1] = sumf;
             }
         }
 
-    // it seems a little faster when this is separate from the state shift
+        // copy the state out of it
         for (int i1 = 0; i1 < ir; ++i1) {
-            // rowwise dot product
-            float sumf = 0.0f;
-            for (int i0 = 0; i0 < nc; ++i0) {
-                int i = i0 + i1*nc;
-                sumf += s[i] * c[i];
+            for (int i0 = 0; i0 < nc - 1; ++i0) {
+                s0[i0 + i1*(nc - 1)] = s[1 + i0 + i1*nc];
             }
-            x[i1] = sumf;
         }
     }
 }
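
For clarity, here is a minimal single-threaded sketch of what the rewritten kernel computes for one block of rows of one sequence at one timestep. The helper and its name are illustrative, not part of the patch; it assumes the same contiguous {d_conv, d_inner} layout for the working state s and the weights c as the kernel above.

    // Hypothetical CPU reference for one timestep: shift the rolling window,
    // append the new column of x, then take a rowwise dot with the weights.
    static void ssm_conv_ref(float * s, const float * x0, const float * c,
                             float * x, int nc, int nr) {
        for (int i1 = 0; i1 < nr; ++i1) {
            // shift this row's window left by one column
            for (int i0 = 0; i0 < nc - 1; ++i0) {
                s[i0 + i1*nc] = s[i0 + 1 + i1*nc];
            }
            // insert the incoming value of x on the last column
            s[(nc - 1) + i1*nc] = x0[i1];
            // rowwise dot product, accumulated in single precision
            // like the kernel (deliberately not double, per the NOTE above)
            float sumf = 0.0f;
            for (int i0 = 0; i0 < nc; ++i0) {
                sumf += s[i0 + i1*nc] * c[i0 + i1*nc];
            }
            x[i1] = sumf;
        }
    }

Note that the kernel's flat shift (s[i4] = s[i4+1] across the whole buffer) drags the first element of each row into the last slot of the previous row, but that slot is overwritten by the x insertion right after, so the two formulations agree.
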
 
 static void ssm_conv_f32_cuda(
-        const float * src0, const float * src1, const float * src2, const float * src3,
-        const int src0_ne0, const int src0_nb1, const int src0_nb2,
-        const int src1_nb0, const int src1_nb1,
-        const int src2_nb1, const int src2_nb2,
-        const int src3_nb1,
+        const float * src0, const float * src1, const float * src2,
+        const int src0_nb1, const int src0_nb2,
+        const int src1_nb0, const int src1_nb1, const int src1_nb2,
+        const int src2_nb1,
         float * dst,
-        const int nc, const int nr, const int n_t, const int n_kv, cudaStream_t stream) {
+        const int dst_nb0, const int dst_nb1, const int dst_nb2,
+        const int nc, const int nr, const int n_t, const int n_s,
+        cudaStream_t stream) {
 
     const dim3 block_dims(WARP_SIZE, 1, 1);
     const int  nblocks = 1; // TODO
+    const int  shmem_size = nc * (nr + WARP_SIZE - 1) * sizeof(float); // TODO
 
-    ssm_conv_f32<WARP_SIZE><<<nblocks, block_dims, 0, stream>>>(
-        src0, src1, src2, src3,
-        src0_ne0, src0_nb1, src0_nb2,
-        src1_nb0, src1_nb1,
-        src2_nb1, src2_nb2,
-        src3_nb1,
+    ssm_conv_f32<WARP_SIZE><<<nblocks, block_dims, shmem_size, stream>>>(
+        src0, src1, src2,
+        src0_nb1, src0_nb2,
+        src1_nb0, src1_nb1, src1_nb2,
+        src2_nb1,
         dst,
-        nc, nr, n_t, n_kv);
+        dst_nb0, dst_nb1, dst_nb2,
+        nc, nr, n_t, n_s);
 }
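
A note on the shared-memory sizing above, assuming the kernel prologue not shown in this hunk computes dr as the ceiling of nr/block_size (as ggml's other per-thread row splits do): each thread owns a private slice of nc*dr floats, and nc * (nr + WARP_SIZE - 1) is a padded upper bound over all WARP_SIZE slices. A hypothetical helper making the arithmetic explicit:

    // Illustrative only; the launcher above inlines this expression.
    static size_t ssm_conv_shmem_bytes(int nc, int nr, int block_size) {
        const int dr = (nr + block_size - 1) / block_size; // rows per thread (ceil)
        // block_size threads, nc*dr floats each; since
        // dr*block_size <= nr + block_size - 1, the total never exceeds
        // nc*(nr + block_size - 1) floats, which is what the launcher requests
        return (size_t) nc * dr * block_size * sizeof(float);
    }
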
 
 void ggml_cuda_op_ssm_conv(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const struct ggml_tensor * src0 = dst->src[0]; // conv_state
     const struct ggml_tensor * src1 = dst->src[1]; // x
     const struct ggml_tensor * src2 = dst->src[2]; // conv1d.weight
-    const struct ggml_tensor * src3 = dst->src[3]; // state_seq
 
-    const int nc   = src2->ne[0]; // d_conv
-    const int nr   = src0->ne[1]; // d_inner
-    const int n_t  = src1->ne[1]; // n_tokens
-    const int n_kv = src0->ne[2]; // max number of sequences in the batch
+    const int nc  = src2->ne[0]; // d_conv
+    const int nr  = src0->ne[1]; // d_inner
+    const int n_t = src1->ne[1]; // tokens per sequence
+    const int n_s = src0->ne[2]; // number of sequences in the batch
 
-    GGML_ASSERT((nr*n_t) + (nc*nr*n_kv) == ggml_nelements(dst));
+    GGML_ASSERT(ggml_are_same_shape(src1, dst));
     GGML_ASSERT(src0->nb[0] == sizeof(float));
     GGML_ASSERT(src1->nb[0] == sizeof(float));
     GGML_ASSERT(src2->nb[0] == sizeof(float));
-    GGML_ASSERT(src3->nb[0] == sizeof(int32_t));
     GGML_ASSERT(src0->nb[1] == src0->ne[0]*sizeof(float));
-    // for use with the destination state offset between sequences
-    GGML_ASSERT(src2->nb[2] == src2->ne[1]*src2->ne[0]*sizeof(float));
 
     const float * src0_d = (const float *)src0->data;
     const float * src1_d = (const float *)src1->data;
     const float * src2_d = (const float *)src2->data;
-    const float * src3_d = (const float *)src3->data;
     float * dst_d = (float *)dst->data;
     cudaStream_t stream = ctx.stream();
 
     GGML_ASSERT(src0->type == GGML_TYPE_F32);
     GGML_ASSERT( dst->type == GGML_TYPE_F32);
 
-    ssm_conv_f32_cuda(src0_d, src1_d, src2_d, src3_d,
-        src0->ne[0], src0->nb[1], src0->nb[2],
-        src1->nb[0], src1->nb[1],
-        src2->nb[1], src2->nb[2],
-        src3->nb[1],
-        dst_d, nc, nr, n_t, n_kv, stream);
+    ssm_conv_f32_cuda(src0_d, src1_d, src2_d,
+        src0->nb[1], src0->nb[2],
+        src1->nb[0], src1->nb[1], src1->nb[2],
+        src2->nb[1],
+        dst_d,
+        dst->nb[0], dst->nb[1], dst->nb[2],
+        nc, nr, n_t, n_s,
+        stream);
 }
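
With the state_seq input gone, the op's contract is simpler: dst has the same shape as x, and the conv state is updated in place through src0. A rough sketch of how the op might be wired up on the graph side at this stage of the refactor; the exact ggml_ssm_conv signature here is an assumption, inferred from the three dst->src slots above:

    // Hypothetical graph construction; shapes follow the kernel's comments.
    // conv_state:        {d_conv - 1, d_inner, n_s}
    // x:                 {d_inner, n_t, n_s}
    // conv1d.weight (c): {d_conv, d_inner}
    struct ggml_tensor * y = ggml_ssm_conv(ctx0, conv_state, x, c);
    // y has the same shape as x, per GGML_ASSERT(ggml_are_same_shape(src1, dst))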