@@ -101,23 +101,25 @@ void main() {
101
101
// "k" tracks the kernel's index for our input-kernel computation.
102
102
// It reads out-of-bounds zeros, but trying to avoid them complicates
103
103
// for-loop conditions, which results in worse performance.
104
- for (int k = 0 ; k < kernel_size; k += 4 ) {
105
- // Since the weight tensor is width-packed, which is along the length
106
- // dimension, we can batch-read four elements at a time.
107
- const ivec3 w_lpos = ivec3 (k / 4 , in_c % in_group_size, out_c);
108
- const VEC4_T weight = load_texel_lpos(kernel_in, w_lpos, kernel_axis_map);
109
104
110
- ivec3 in_pos = lpos_to_pos(ivec3 (in_l + k * dilation, in_c, n / 4 ), in_axis_map);
111
- sum = fma(weight.xxxx, load_texel(t_in, in_pos), sum);
112
-
113
- in_pos[in_axis_map.x] += dilation;
114
- sum = fma(weight.yyyy, load_texel(t_in, in_pos), sum);
105
+ // The weight tensor is channel-packed. This may not be the obvious choice
106
+ // for performance, since it requires more data fetches. The reason is
107
+ // that, for some sequence models, we found that the weight tensor
108
+ // (out_channel, in_channel / group, kernel) often has a large
109
+ // out_channel >> kernel, leading to non-optimal use of memory as the
110
+ // weight tensor gets very deep. As a mitigation, we use channel-packing
111
+ // for the weight tensor, yielding a 75% reduction in weight-tensor
112
+ // memory.
113
+
114
+ // It is possible to further reduce the memory footprint by swapping the
115
+ // dimensions, using x extent for out_channel, and y for kernel.
116
+ for (int k = 0 ; k < kernel_size; k += 1 ) {
117
+ const ivec3 w_lpos = ivec3 (k, in_c % in_group_size, out_c / 4 );
118
+ const VEC4_T weight_texel = load_texel_lpos(kernel_in, w_lpos, kernel_axis_map);
119
+ VEC4_T weight = VEC4_T(weight_texel[out_c % 4 ]);
115
120
116
- in_pos[in_axis_map.x] += dilation;
117
- sum = fma(weight.zzzz, load_texel(t_in, in_pos), sum);
118
-
119
- in_pos[in_axis_map.x] += dilation;
120
- sum = fma(weight.wwww, load_texel(t_in, in_pos), sum);
121
+ ivec3 in_pos = lpos_to_pos(ivec3 (in_l + k * dilation, in_c, n / 4 ), in_axis_map);
122
+ sum = fma(weight, load_texel(t_in, in_pos), sum);
121
123
}
122
124
}
123
125
0 commit comments