
Commit 2967302

Change weight to channel-packing in Conv1d
Differential Revision: D66417572
Pull Request resolved: #7057
1 parent: a35cb73

File tree: 2 files changed (+18, -16 lines)


backends/vulkan/runtime/graph/ops/glsl/conv1d.glsl

Lines changed: 17 additions & 15 deletions
@@ -101,23 +101,25 @@ void main() {
   // "k" tracks the kernel's index for our input-kernel computation.
   // It reads out-of-bound zeros, but trying to avoid them complicates
   // for-loop conditions, which results in worse performance.
-  for (int k = 0; k < kernel_size; k += 4) {
-    // Since the weight tensor is width-packed, which is along the length
-    // dimension, we can batch-read four elements at a time.
-    const ivec3 w_lpos = ivec3(k / 4, in_c % in_group_size, out_c);
-    const VEC4_T weight = load_texel_lpos(kernel_in, w_lpos, kernel_axis_map);
 
-    ivec3 in_pos = lpos_to_pos(ivec3(in_l + k * dilation, in_c, n / 4), in_axis_map);
-    sum = fma(weight.xxxx, load_texel(t_in, in_pos), sum);
-
-    in_pos[in_axis_map.x] += dilation;
-    sum = fma(weight.yyyy, load_texel(t_in, in_pos), sum);
+  // The weight tensor is channel-packed. This may not be the obvious
+  // choice for performance, since it requires more data fetches. However,
+  // for some sequence models we found that the weight tensor
+  // (out_channel, in_channel / group, kernel) often has
+  // out_channel >> kernel, leading to non-optimal use of memory as the
+  // weight tensor gets very deep. As a mitigation, we use channel-packing
+  // for the weight tensor, yielding a 75% reduction in weight-tensor
+  // memory.
+
+  // It is possible to further reduce the memory footprint by swapping the
+  // dimensions, using the x extent for out_channel and y for kernel.
+  for (int k = 0; k < kernel_size; k += 1) {
+    const ivec3 w_lpos = ivec3(k, in_c % in_group_size, out_c / 4);
+    const VEC4_T weight_texel = load_texel_lpos(kernel_in, w_lpos, kernel_axis_map);
+    VEC4_T weight = VEC4_T(weight_texel[out_c % 4]);
 
-    in_pos[in_axis_map.x] += dilation;
-    sum = fma(weight.zzzz, load_texel(t_in, in_pos), sum);
-
-    in_pos[in_axis_map.x] += dilation;
-    sum = fma(weight.wwww, load_texel(t_in, in_pos), sum);
+    ivec3 in_pos = lpos_to_pos(ivec3(in_l + k * dilation, in_c, n / 4), in_axis_map);
+    sum = fma(weight, load_texel(t_in, in_pos), sum);
   }
 }
 
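Why channel-packing saves memory here: with width-packing, the kernel dimension is folded four-to-a-texel, but Conv1d kernels are often shorter than four, so most lanes of each weight texel are padding while the texture's depth extent grows with out_channel. The sketch below works through the texel arithmetic for a typical sequence-model weight shape (illustrative C++, not part of this commit; texel_count and the example shape are assumptions, not ExecuTorch APIs or measured values):

#include <cstdint>
#include <iostream>

// Texels occupied by a (out_c, in_c / group, kernel) weight tensor when one
// logical dimension is packed four-to-a-texel (padded up to a multiple of 4).
int64_t texel_count(int64_t out_c, int64_t in_c, int64_t k, bool channels_packed) {
  auto div_up_4 = [](int64_t v) { return (v + 3) / 4; };
  return channels_packed
      ? div_up_4(out_c) * in_c * k   // new: pack along out_channel
      : out_c * in_c * div_up_4(k);  // old: pack along kernel (width)
}

int main() {
  // Hypothetical shape with out_channel >> kernel, as described above.
  const int64_t out_c = 512, in_c = 64, k = 1;
  std::cout << "width-packed:    " << texel_count(out_c, in_c, k, false) << " texels\n";
  std::cout << "channels-packed: " << texel_count(out_c, in_c, k, true) << " texels\n";
  // 32768 vs 8192 texels: a 4x cut, consistent with the 75% reduction cited
  // in the shader comment, since a width-packed texel holds only one useful
  // lane out of four when kernel_size == 1.
  return 0;
}

The trade-off named in the new comment is also visible in the loop above it: the channel-packed shader issues one weight fetch per kernel tap and broadcasts a single texel lane, whereas the width-packed version amortized one fetch across four taps.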
backends/vulkan/runtime/graph/ops/impl/Convolution.cpp

Lines changed: 1 addition & 1 deletion
@@ -407,7 +407,7 @@ void add_conv1d_node(
     const ValueRef out,
     const bool clamp_out) {
   ValueRef arg_weight = prepack_standard(
-      graph, weight, graph.storage_type_of(out), utils::kWidthPacked);
+      graph, weight, graph.storage_type_of(out), utils::kChannelsPacked);
   ValueRef arg_bias = prepack_biases(
       graph,
       bias,
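
The shader change above is the other half of this one-line contract: prepack_standard now lays the weight out channel-packed, and conv1d.glsl indexes it accordingly. A minimal sketch of that correspondence (illustrative only; WeightPos and these helpers are hypothetical, not the real prepack_standard implementation):

// Maps a logical weight element (out_c, in_c, k) to a texel coordinate plus
// a lane within the 4-wide texel, mirroring how conv1d.glsl reads it.
struct WeightPos {
  int x, y, z;  // texel coordinate
  int lane;     // component within the texel
};

// Old layout (utils::kWidthPacked): kernel dim folded into texel lanes.
WeightPos width_packed(int out_c, int in_c, int k) {
  return {k / 4, in_c, out_c, k % 4};
}

// New layout (utils::kChannelsPacked): out_channel folded into texel lanes,
// shrinking the texture's z extent from out_c to ceil(out_c / 4).
WeightPos channels_packed(int out_c, int in_c, int k) {
  return {k, in_c, out_c / 4, out_c % 4};
}

This matches the shader's lookup: w_lpos = ivec3(k, in_c % in_group_size, out_c / 4), with the lane selected by weight_texel[out_c % 4].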
