@@ -21,78 +21,104 @@ layout(set = 0, binding = 1) uniform PRECISION sampler3D image_in;
 layout(set = 0, binding = 2) uniform PRECISION sampler3D kernel_in;
 layout(set = 0, binding = 3) uniform PRECISION sampler3D bias_in;
 
-layout(set = 0, binding = 4) uniform PRECISION restrict Out_channels {
-  int data;
-}
-out_channels;
-
-layout(set = 0, binding = 5) uniform PRECISION restrict In_length {
-  int data;
-}
-in_length;
-
-layout(set = 0, binding = 6) uniform PRECISION restrict Kernel_size {
-  int data;
-}
-kernel_size;
+layout(set = 0, binding = 4) uniform PRECISION restrict OutLimits {
+  ivec3 out_limits;
+};
+
+layout(set = 0, binding = 5) uniform PRECISION restrict InSizes {
+  ivec4 in_sizes;
+};
+
+layout(set = 0, binding = 6) uniform PRECISION restrict Params {
+  int kernel_size;
+  int stride;
+  int padding;
+  int dilation;
+  int in_group_size;
+  int out_group_size;
+};
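+// Note: in_group_size and out_group_size above are the per-group channel
+// counts, i.e. in_C / G and out_C / G in the notation introduced below.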
 
 layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
 
-/*
- * This implementation optimizes for simplicity (and partially for performance)
- * for a (1, C, L) input where C == groups. Hence we only focus on calculating
- * the rolling kernel of the L dimension.
- */
+// Let us define
+//
+// input = (N, in_C, in_L),
+// output = (N, out_C, out_L),
+// groups = G,
+// kernel = K,
+//
+// which results in shapes
+//
+// weight = (out_C, in_C / G, K),
+// bias = (out_C,).
+//
+// This implementation performs out_C shader invocations, where each invocation
+// calculates the rolling kernel of the length dimension for each batch, i.e.,
+// computes out_L * N results.
+//
+// Note that we can rewrite this implementation as out_L * out_C * ceil(N / 4)
+// shader invocations, where each invocation computes 1 result. But that
+// performs worse.
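+//
+// For example, with N = 1, in_C = 6, out_C = 4, G = 2, and K = 3, the weight
+// tensor has shape (4, 3, 3), the bias tensor has shape (4,), and the shader
+// is dispatched with out_C = 4 invocations along the channel axis.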
 void main() {
   const ivec3 pos = ivec3(gl_GlobalInvocationID);
 
-  // The global workgroup should have taken care of it. We only perform one
-  // work item for each 1d tensor along the length dimension.
-  if (pos.x >= 1) {
+  if (any(greaterThanEqual(pos, out_limits))) {
     return;
   }
 
-  int c = pos.y;
-  if (c >= out_channels.data) {
-    return;
-  }
-
-  // Assume n = 1, do not handle n > 1 case for now.
-  int n = pos.z;
-  if (n >= 1) {
-    return;
-  }
-
-  vec4 bias = texelFetch(bias_in, ivec3(c, 0, 0), 0);
-
-  for (int i = 0; i < in_length.data - kernel_size.data + 1; ++i) {
-    vec4 v = vec4(0);
-    for (int k = 0; k < kernel_size.data; ++k) {
-      const ivec3 in_pos = ivec3(i + k, c, 0);
-      const vec4 input_value = texelFetch(image_in, in_pos, 0);
-
-      // Note that we are reading the weight in the inner loop; this could be
-      // improved by moving it before the outer loop, since the weight vector
-      // is constant for the entire call.
-
-      // weight in input-space: (c, 0, k);
-      // notice that c is 4-packed. We need to mod 4 to get the actual weight.
-      const ivec3 w_pos = ivec3(k, 0, c / 4);
-      const vec4 weight = texelFetch(kernel_in, w_pos, 0);
-
-      float w = weight.x;
-      if (c % 4 == 1) {
-        w = weight.y;
-      } else if (c % 4 == 2) {
-        w = weight.z;
-      } else if (c % 4 == 3) {
-        w = weight.w;
+  int in_length = in_sizes.x;
+  int batch_size = in_sizes.z;
+
+  // "out_c" is the output's channel index where we write our result.
+  // Across shader invocations, this is the only value that varies.
+  int out_c = pos.y;
+  vec4 bias = texelFetch(bias_in, ivec3(out_c, 0, 0), 0);
+
+  // "in_c" tracks the input's channel start index.
+  // We iterate over the input group that corresponds to the output group.
+  int c_start = (out_c / out_group_size) * in_group_size;
+  int c_end = c_start + in_group_size;
+
+  // "in_l" tracks the input's length start index for our input-kernel overlay
+  // region.
+  int l_start = -padding;
+  int l_end = in_length + padding - dilation * (kernel_size - 1);
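+  // The loop over in_l below runs ceil((l_end - l_start) / stride) times,
+  // which equals the standard conv1d output length out_L.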
+
+  // Since the input/output tensors are channel-packed, which is along the
+  // batch dimension, we can batch-read/write four elements at a time.
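+  // That is, the texel at (l, c, n / 4) holds the values of batches n..n+3
+  // for channel c at length position l.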
+  for (int n = 0; n < batch_size; n += 4) {
+    // "out_l" tracks the output's length index where we write our result.
+    int out_l = 0;
+
+    for (int in_l = l_start; in_l < l_end; in_l += stride, ++out_l) {
+      vec4 sum = vec4(0);
+
+      for (int in_c = c_start; in_c < c_end; ++in_c) {
+        // "k" tracks the kernel's index for our input-kernel computation.
+        // It reads out-of-bound zeros, but trying to avoid them complicates
+        // for-loop conditions, which results in worse performance.
+        for (int k = 0; k < kernel_size; k += 4) {
+          // Since the weight tensor is width-packed, which is along the length
+          // dimension, we can batch-read four elements at a time.
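+          // E.g. the texel at x = 0 holds taps k = 0..3 in .xyzw and the texel
+          // at x = 1 holds taps k = 4..7; per the note above, reads past the
+          // last tap contribute zeros.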
+          const ivec3 w_pos = ivec3(k / 4, in_c % in_group_size, out_c);
+          const vec4 weight = texelFetch(kernel_in, w_pos, 0);
+
+          const ivec3 in_pos_0 = ivec3(in_l + k * dilation, in_c, n / 4);
+          sum = fma(weight.xxxx, texelFetch(image_in, in_pos_0, 0), sum);
+
+          const ivec3 in_pos_1 = ivec3(in_l + (k + 1) * dilation, in_c, n / 4);
+          sum = fma(weight.yyyy, texelFetch(image_in, in_pos_1, 0), sum);
+
+          const ivec3 in_pos_2 = ivec3(in_l + (k + 2) * dilation, in_c, n / 4);
+          sum = fma(weight.zzzz, texelFetch(image_in, in_pos_2, 0), sum);
+
+          const ivec3 in_pos_3 = ivec3(in_l + (k + 3) * dilation, in_c, n / 4);
+          sum = fma(weight.wwww, texelFetch(image_in, in_pos_3, 0), sum);
+        }
       }
 
-      v += w * input_value.x;
+      ivec3 out_pos = ivec3(out_l, out_c, n / 4);
+      imageStore(image_out, out_pos, sum + bias.x);
     }
-
-    ivec3 out_pos = ivec3(i, c, 0);
-    imageStore(image_out, out_pos, vec4(v.x + bias.x, 0, 0, 0));
   }
 }
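
For reference, the following is a minimal host-side sketch, in plain C, of the
grouped conv1d that the new shader computes. The function name and the flat
row-major tensor layout are illustrative assumptions, not part of this commit;
the index math mirrors the shader (in_l = out_l * stride - padding +
k * dilation, with out-of-range taps contributing zero).

/* Illustrative grouped 1-D convolution with stride, padding, and dilation.
 * Shapes follow the shader comment: input (N, in_C, in_L),
 * weight (out_C, in_C / G, K), bias (out_C,), output (N, out_C, out_L). */
void conv1d_reference(
    const float* input, const float* weight, const float* bias, float* output,
    int N, int in_C, int in_L, int out_C, int out_L,
    int G, int K, int stride, int padding, int dilation) {
  const int in_group_size = in_C / G;
  const int out_group_size = out_C / G;
  for (int n = 0; n < N; ++n) {
    for (int oc = 0; oc < out_C; ++oc) {
      /* Input-channel range of the group that oc belongs to. */
      const int c_start = (oc / out_group_size) * in_group_size;
      for (int ol = 0; ol < out_L; ++ol) {
        float sum = bias[oc];
        for (int cg = 0; cg < in_group_size; ++cg) {
          const int ic = c_start + cg;
          for (int k = 0; k < K; ++k) {
            const int il = ol * stride - padding + k * dilation;
            if (il < 0 || il >= in_L) continue; /* zero padding */
            sum += weight[(oc * in_group_size + cg) * K + k] *
                   input[(n * in_C + ic) * in_L + il];
          }
        }
        output[(n * out_C + oc) * out_L + ol] = sum;
      }
    }
  }
}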