Skip to content

Commit 45fd796

Browse files
copyrightlyfacebook-github-bot
authored andcommitted
conv1d general case (#3223)
Summary: Pull Request resolved: #3223 We port jorgep31415's work of conv1d for lite interpreter into ET. The current implementation supports general batch_size, weight_size, stride, padding, dilation and groups. Reviewed By: jorgep31415 Differential Revision: D56380147 fbshipit-source-id: 62fdc2958d683590317aaec5be3d0366f6df42e4
1 parent f89c312 commit 45fd796

File tree

7 files changed

+175
-111
lines changed

7 files changed

+175
-111
lines changed

backends/vulkan/runtime/graph/ops/glsl/conv1d.glsl

Lines changed: 87 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -21,78 +21,104 @@ layout(set = 0, binding = 1) uniform PRECISION sampler3D image_in;
2121
layout(set = 0, binding = 2) uniform PRECISION sampler3D kernel_in;
2222
layout(set = 0, binding = 3) uniform PRECISION sampler3D bias_in;
2323

24-
layout(set = 0, binding = 4) uniform PRECISION restrict Out_channels {
25-
int data;
26-
}
27-
out_channels;
28-
29-
layout(set = 0, binding = 5) uniform PRECISION restrict In_length {
30-
int data;
31-
}
32-
in_length;
33-
34-
layout(set = 0, binding = 6) uniform PRECISION restrict Kernel_size {
35-
int data;
36-
}
37-
kernel_size;
24+
layout(set = 0, binding = 4) uniform PRECISION restrict OutLimits {
25+
ivec3 out_limits;
26+
};
27+
28+
layout(set = 0, binding = 5) uniform PRECISION restrict InSizes {
29+
ivec4 in_sizes;
30+
};
31+
32+
layout(set = 0, binding = 6) uniform PRECISION restrict Params {
33+
int kernel_size;
34+
int stride;
35+
int padding;
36+
int dilation;
37+
int in_group_size;
38+
int out_group_size;
39+
};
3840

3941
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
4042

41-
/*
42-
* This implementation optimize for simplicity (and partially performance) for a
43-
* (1, C, L) where C == groups. Hence we only focus on calculating the rolling
44-
* kernel of the L dimension.
45-
*/
43+
// Let us define
44+
//
45+
// input = (N, in_C, in_L),
46+
// output = (N, out_C, out_L),
47+
// groups = G,
48+
// kernel = K,
49+
//
50+
// which results in shapes
51+
//
52+
// weight = (out_C, in_C / G, K),
53+
// bias = (out_C,).
54+
//
55+
// This implementation performs out_C shader invocations, where each invocation
56+
// calculates the rolling kernel of the length dimension for each batch, i.e.,
57+
// computes out_L * N results.
58+
//
59+
// Note that we can rewrite this implementation as out_L * out_C * ceil(N / 4)
60+
// shader invocations, where each invocation computes 1 result. But that
61+
// performs worse.
4662
void main() {
4763
const ivec3 pos = ivec3(gl_GlobalInvocationID);
4864

49-
// The global workgroup should have taken care of it. We only perform one
50-
// work item for each 1d tensor on lengths
51-
if (pos.x >= 1) {
65+
if (any(greaterThanEqual(pos, out_limits))) {
5266
return;
5367
}
5468

55-
int c = pos.y;
56-
if (c >= out_channels.data) {
57-
return;
58-
}
59-
60-
// Assume n = 1, do not handle n > 1 case for now.
61-
int n = pos.z;
62-
if (n >= 1) {
63-
return;
64-
}
65-
66-
vec4 bias = texelFetch(bias_in, ivec3(c, 0, 0), 0);
67-
68-
for (int i = 0; i < in_length.data - kernel_size.data + 1; ++i) {
69-
vec4 v = vec4(0);
70-
for (int k = 0; k < kernel_size.data; ++k) {
71-
const ivec3 in_pos = ivec3(i+k, c, 0);
72-
const vec4 input_value = texelFetch(image_in, in_pos, 0);
73-
74-
// Note that we are reading weight in the inner loop, this could be
75-
// improved by moving it before the outer loop. Since the weight vector is
76-
// contant for the entire call.
77-
78-
// weight in input-space: (c, 0, k);
79-
// notice that c is 4-packed. We need to mod 4 to get the actual weight.
80-
const ivec3 w_pos = ivec3(k, 0, c / 4);
81-
const vec4 weight = texelFetch(kernel_in, w_pos, 0);
82-
83-
float w = weight.x;
84-
if (c % 4 == 1) {
85-
w = weight.y;
86-
} else if (c % 4 == 2) {
87-
w = weight.z;
88-
} else if (c % 4 == 3) {
89-
w = weight.w;
69+
int in_length = in_sizes.x;
70+
int batch_size = in_sizes.z;
71+
72+
// "out_c" is the output's channel index where we write our result.
73+
// Across shader invocations, this is the only value that varies.
74+
int out_c = pos.y;
75+
vec4 bias = texelFetch(bias_in, ivec3(out_c, 0, 0), 0);
76+
77+
// "in_c" tracks the input's channel start index.
78+
// We iterate over the input group that corresponds to the output group.
79+
int c_start = (out_c / out_group_size) * in_group_size;
80+
int c_end = c_start + in_group_size;
81+
82+
// "in_l" tracks the input's length start index for our input-kernel overlay
83+
// region.
84+
int l_start = -padding;
85+
int l_end = in_length + padding - dilation * (kernel_size - 1);
86+
87+
// Since the input/output tensors are channel-packed, which is along the
88+
// batch dimension, we can batch-read/write four elements at a time.
89+
for (int n = 0; n < batch_size; n += 4) {
90+
// "out_l" tracks the output's length index where we write our result.
91+
int out_l = 0;
92+
93+
for (int in_l = l_start; in_l < l_end; in_l += stride, ++out_l) {
94+
vec4 sum = vec4(0);
95+
96+
for (int in_c = c_start; in_c < c_end; ++in_c) {
97+
// "k" tracks the kernel's index for our input-kernel computation.
98+
// It reads out-of-bound zeros, but trying to avoid them complicates
99+
// for-loop conditions, which results in worse performance.
100+
for (int k = 0; k < kernel_size; k += 4) {
101+
// Since the weight tensor is width-packed, which is along the length
102+
// dimension, we can batch-read four elements at a time.
103+
const ivec3 w_pos = ivec3(k / 4, in_c % in_group_size, out_c);
104+
const vec4 weight = texelFetch(kernel_in, w_pos, 0);
105+
106+
const ivec3 in_pos_0 = ivec3(in_l + k * dilation, in_c, n / 4);
107+
sum = fma(weight.xxxx, texelFetch(image_in, in_pos_0, 0), sum);
108+
109+
const ivec3 in_pos_1 = ivec3(in_l + (k+1) * dilation, in_c, n / 4);
110+
sum = fma(weight.yyyy, texelFetch(image_in, in_pos_1, 0), sum);
111+
112+
const ivec3 in_pos_2 = ivec3(in_l + (k+2) * dilation, in_c, n / 4);
113+
sum = fma(weight.zzzz, texelFetch(image_in, in_pos_2, 0), sum);
114+
115+
const ivec3 in_pos_3 = ivec3(in_l + (k+3) * dilation, in_c, n / 4);
116+
sum = fma(weight.wwww, texelFetch(image_in, in_pos_3, 0), sum);
117+
}
90118
}
91119

92-
v += w * input_value.x;
120+
ivec3 out_pos = ivec3(out_l, out_c, n / 4);
121+
imageStore(image_out, out_pos, sum + bias.x);
93122
}
94-
95-
ivec3 out_pos = ivec3(i, c, 0);
96-
imageStore(image_out, out_pos, vec4(v.x + bias.x, 0, 0, 0));
97123
}
98124
}

backends/vulkan/runtime/graph/ops/impl/Convolution.cpp

Lines changed: 37 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,11 @@ void resize_conv1d_node(
6161
vTensorPtr out = graph->get_tensor(args[0].refs[0]);
6262
vTensorPtr self = graph->get_tensor(args[1].refs[0]);
6363
TensorRefPtr weight_ref = graph->get_tref(extra_args[0]);
64+
65+
int64_t stride_size = graph->get_int_list(extra_args[1])->at(0);
66+
int64_t padding_size = graph->get_int_list(extra_args[2])->at(0);
67+
int64_t dilation_size = graph->get_int_list(extra_args[3])->at(0);
68+
6469
const std::vector<int64_t>& weight_sizes = weight_ref->sizes;
6570

6671
const std::vector<int64_t>& in_sizes = self->sizes();
@@ -71,8 +76,9 @@ void resize_conv1d_node(
7176
int64_t in_length = in_sizes.at(2);
7277

7378
new_out_sizes.at(0) = in_sizes.at(0);
74-
new_out_sizes.at(1) = in_sizes.at(1);
75-
new_out_sizes.at(2) = in_length - kernel_size + 1;
79+
new_out_sizes.at(1) = weight_sizes.at(0);
80+
new_out_sizes.at(2) = calc_out_size(
81+
in_length, kernel_size, stride_size, padding_size, dilation_size, false);
7682

7783
out->virtual_resize(new_out_sizes);
7884
}
@@ -244,10 +250,6 @@ ValueRef prepack_weights(
244250
}
245251

246252
void check_conv_args(const vTensor& in, const vTensor& out) {
247-
if (in.sizes().at(0) > 1) {
248-
VK_THROW(
249-
"aten.convolution.default: input batch size > 1 is not supported yet!");
250-
}
251253
VK_CHECK_COND(check_memory_layout_is(in, api::kChannelsPacked));
252254
VK_CHECK_COND(check_memory_layout_is(out, api::kChannelsPacked));
253255
}
@@ -260,7 +262,7 @@ struct Conv2dParams final {
260262
Conv2dParams create_conv2d_params(
261263
ComputeGraph& graph,
262264
const ValueRef weight,
263-
const KernelParams& p,
265+
const Kernel2dParams& p,
264266
const bool transposed) {
265267
const auto& overlay_region = api::utils::make_ivec2({
266268
p.kernel_size.data[0] +
@@ -275,7 +277,7 @@ Conv2dParams create_conv2d_params(
275277
return {overlay_region, in_group_size};
276278
}
277279

278-
void check_conv2d_params(const KernelParams& p, const bool transposed) {
280+
void check_conv2d_params(const Kernel2dParams& p, const bool transposed) {
279281
if (transposed) {
280282
if (p.dilation.data[0] > 1 || p.dilation.data[1] > 1) {
281283
VK_THROW(
@@ -342,12 +344,15 @@ void add_conv2d_node(
342344

343345
vTensorPtr t_in = graph.get_tensor(arg_in);
344346
vTensorPtr t_out = graph.get_tensor(out);
347+
if (t_in->sizes().at(0) > 1) {
348+
VK_THROW("conv2d: input batch size > 1 is not supported yet!");
349+
}
345350
check_conv_args(*t_in, *t_out);
346351

347352
api::utils::uvec3 global_size = t_out->extents();
348353
api::utils::uvec3 local_size = adaptive_work_group_size(global_size);
349354

350-
KernelParams kernel_params = create_kernel_params(
355+
Kernel2dParams kernel_params = create_kernel2d_params(
351356
graph,
352357
weight,
353358
/*kernel_size_only = */ false,
@@ -395,8 +400,7 @@ void add_conv1d_node(
395400
const ValueRef groups,
396401
const ValueRef out) {
397402
ValueRef arg_in = prepack_if_tensor_ref(graph, in);
398-
ValueRef arg_weight =
399-
prepack_if_tensor_ref(graph, weight, graph.memory_layout_of(arg_in));
403+
ValueRef arg_weight = prepack_if_tensor_ref(graph, weight, api::kWidthPacked);
400404
ValueRef arg_bias = prepack_biases(
401405
graph,
402406
bias,
@@ -414,37 +418,29 @@ void add_conv1d_node(
414418
std::vector<int64_t> in_sizes = t_in->sizes();
415419
std::vector<int64_t> weight_sizes = t_weight->sizes();
416420
std::vector<int64_t> out_sizes = t_out->sizes();
417-
IntListPtr stride_sizes = graph.get_int_list(stride);
418-
IntListPtr padding_sizes = graph.get_int_list(padding);
419-
IntListPtr dilation_sizes = graph.get_int_list(dilation);
420-
int64_t weight_out_channels = weight_sizes.at(0);
421-
int64_t kernel_size = weight_sizes.at(2);
422-
int64_t in_length = in_sizes.at(2);
423-
424-
VK_CHECK_COND(in_sizes.size() == 3, "input must be a 3-dim tensor");
425-
VK_CHECK_COND(weight_sizes.size() == 3, "weight must be a 3-dim tensor");
426-
VK_CHECK_COND(
427-
stride_sizes->size() == 1 && stride_sizes->at(0) == 1,
428-
"stride must be 1");
429-
VK_CHECK_COND(
430-
padding_sizes->size() == 1 && padding_sizes->at(0) == 0,
431-
"padding must be 0");
432-
VK_CHECK_COND(
433-
dilation_sizes->size() == 1 && dilation_sizes->at(0) == 1,
434-
"dilation must be 1");
435-
VK_CHECK_COND(
436-
groups_val == in_sizes.at(1), "groups must be equal to in_channels");
437-
VK_CHECK_COND(
438-
groups_val == weight_sizes.at(0),
439-
"groups must be equal to weight_sizes.at(0)");
440-
VK_CHECK_COND(weight_sizes.at(1) == 1, "weight_sizes.at(1) must be 1");
441421

442422
check_conv_args(*t_in, *t_out);
443423

444-
api::utils::uvec3 global_size = {
445-
1, static_cast<uint32_t>(weight_out_channels), 1};
424+
int32_t in_channels = in_sizes.at(1);
425+
int32_t out_channels = weight_sizes.at(0);
426+
int32_t kernel_size = weight_sizes.at(2);
427+
int32_t stride_size = graph.get_int_list(stride)->at(0);
428+
int32_t padding_size = graph.get_int_list(padding)->at(0);
429+
int32_t dilation_size = graph.get_int_list(dilation)->at(0);
430+
int32_t in_group_size = static_cast<int64_t>(in_channels / groups_val);
431+
int32_t out_group_size = static_cast<int64_t>(out_channels / groups_val);
432+
433+
api::utils::uvec3 global_size = {1, static_cast<uint32_t>(out_channels), 1};
446434
api::utils::uvec3 local_size = {1, 1, 1};
447435

436+
Kernel1dParams kernel_params = {
437+
kernel_size,
438+
stride_size,
439+
padding_size,
440+
dilation_size,
441+
in_group_size,
442+
out_group_size};
443+
448444
std::string kernel_name("conv1d");
449445
kernel_name.reserve(kShaderNameReserve);
450446

@@ -460,15 +456,15 @@ void add_conv1d_node(
460456
{{arg_in, arg_weight, arg_bias}, api::MemoryAccessType::READ}},
461457
// Shader params buffers
462458
{
463-
graph.create_params_buffer(weight_out_channels),
464-
graph.create_params_buffer(in_length),
465-
graph.create_params_buffer(kernel_size),
459+
t_out->texture_limits_ubo(),
460+
t_in->sizes_ubo(),
461+
graph.create_params_buffer(kernel_params),
466462
},
467463
// Specialization Constants
468464
{},
469465
// Resizing Logic
470466
resize_conv1d_node,
471-
{weight}));
467+
{weight, stride, padding, dilation}));
472468
}
473469

474470
void conv(ComputeGraph& graph, const std::vector<ValueRef>& args) {

backends/vulkan/runtime/graph/ops/impl/Pool.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ void add_max_pool2d_node(
7676
std::string kernel_name("max_pool2d");
7777
add_dtype_suffix(kernel_name, *t_out);
7878

79-
KernelParams kernel_params = create_kernel_params(
79+
Kernel2dParams kernel_params = create_kernel2d_params(
8080
graph,
8181
kernel_size,
8282
/*kernel_size_only = */ true,

backends/vulkan/runtime/graph/ops/impl/utils/KernelUtils.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ api::utils::ivec2 make_ivec2_kernel_size(
2626
}
2727
}
2828

29-
KernelParams create_kernel_params(
29+
Kernel2dParams create_kernel2d_params(
3030
ComputeGraph& graph,
3131
const ValueRef weight,
3232
const bool kernel_size_only,

backends/vulkan/runtime/graph/ops/impl/utils/KernelUtils.h

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,21 +16,38 @@
1616

1717
namespace vkcompute {
1818

19-
struct KernelParams final {
19+
struct Kernel2dParams final {
2020
api::utils::ivec2 kernel_size;
2121
api::utils::ivec2 stride;
2222
api::utils::ivec2 padding;
2323
api::utils::ivec2 dilation;
2424
};
2525

26-
KernelParams create_kernel_params(
26+
struct Kernel1dParams final {
27+
int kernel_size;
28+
int stride;
29+
int padding;
30+
int dilation;
31+
int in_group_size;
32+
int out_group_size;
33+
};
34+
35+
Kernel2dParams create_kernel2d_params(
2736
ComputeGraph& graph,
2837
const ValueRef weight,
2938
const bool kernel_size_only,
3039
const ValueRef stride,
3140
const ValueRef padding,
3241
const ValueRef dilation);
3342

43+
int64_t calc_out_size(
44+
const int64_t in_size,
45+
const int64_t kernel_size,
46+
const int64_t stride,
47+
const int64_t padding,
48+
const int64_t dilation,
49+
const bool ceil_mode);
50+
3451
std::vector<int64_t> calc_out_sizes_hw(
3552
ComputeGraph& graph,
3653
const std::vector<int64_t>& in_sizes,

0 commit comments

Comments
 (0)