Skip to content

Commit cf8d0cf

Browse files
[ET-VK] Making stride equals dilation the default mode for conv2d dw. (#7629)
* [ET-VK] Fixing conv2d dw incorrect output when stride != dilation issue. Pull Request resolved: #7595 This diff makes the current implementation of conv2d dw in the Vulkan backend of ExecuTorch a special case that is used only when stride equals dilation, since that is the only time this kind of caching is possible. If stride does not equal dilation, the old implementation is used. Additional test cases are added to ensure the computation is correct when stride != dilation. ghstack-source-id: 261183385 @exported-using-ghexport Differential Revision: [D67908916](https://our.internmc.facebook.com/intern/diff/D67908916/) * [ET-VK] Making stride equals dilation the default mode for conv2d dw. Pull Request resolved: #7596 This diff makes stride equals dilation the default mode for the conv2d dw output op. It also adds a separate source file to handle the case where stride does not equal dilation. ghstack-source-id: 261183386 Differential Revision: [D67979760](https://our.internmc.facebook.com/intern/diff/D67979760/) * Update conv2d_dw_output_tile.glsl --------- Co-authored-by: Vivek Trivedi <[email protected]>
1 parent f31e466 commit cf8d0cf

File tree

4 files changed

+101
-42
lines changed

4 files changed

+101
-42
lines changed

backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_output_tile.glsl

Lines changed: 0 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,6 @@ layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
4343
* output at a single output location.
4444
*/
4545

46-
#if STRIDE_EQ_DILATION
4746
void main() {
4847
// x and y are divided by batch size to determine 3d position
4948
// since work size is calculated by x * ((y + B_Y - 1) / B_Y) * z
@@ -125,42 +124,3 @@ void main() {
125124
}
126125
}
127126
}
128-
129-
#else
130-
void main() {
131-
const uint div_by_x = gl_GlobalInvocationID.x / out_limits.x;
132-
const ivec3 pos = ivec3(
133-
gl_GlobalInvocationID.x % out_limits.x,
134-
div_by_x % out_limits.y,
135-
div_by_x / out_limits.y);
136-
137-
if (any(greaterThanEqual(pos, out_limits))) {
138-
return;
139-
}
140-
141-
// Compute the index of the top-left element of the overlay region. Negative
142-
// indices indicate that the top-left element is in a region added by padding.
143-
const ivec2 ipos = pos.xy * stride - padding;
144-
145-
// Compute the start and end of the input indices to load. Padding is assumed
146-
// to be constant 0 padding, so any reads from the padding region is skipped.
147-
const ivec2 start = ipos;
148-
const ivec2 end = ipos + overlay_region.xy;
149-
150-
VEC4_T sum = texelFetch(t_bias, ivec2(pos.z, 0), 0);
151-
int kx = 0;
152-
for (int y = start.y, i = 0; i < TILE_SIZE; y += dilation.y, i++) {
153-
for (int x = start.x, j = 0; j < TILE_SIZE; x += dilation.x, j++) {
154-
// The weight kernel was rearranged such that every NxN filter is
155-
// flattened to fit in one row. Each filter was then stacked on top of
156-
// each other vertically.
157-
const vec4 in_texel = texelFetch(t_in, ivec3(x, y, pos.z), 0);
158-
sum = fma(in_texel, texelFetch(t_kernel, ivec2(kx, pos.z), 0), sum);
159-
kx++;
160-
}
161-
}
162-
163-
imageStore(t_out, pos, op(sum, out_min, out_max));
164-
}
165-
166-
#endif
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#version 450 core
10+
11+
#define PRECISION ${PRECISION}
12+
13+
#define VEC4_T ${texel_type(DTYPE)}
14+
15+
#define TILE_SIZE ${TILE_SIZE}
16+
17+
#define op(X, A, B) ${OPERATOR}
18+
19+
#include "indexing_utils.h"
20+
21+
layout(std430) buffer;
22+
23+
${layout_declare_tensor(0, "w", "t_out", DTYPE, "texture3d")}
24+
${layout_declare_tensor(1, "r", "t_in", DTYPE, "texture3d")}
25+
${layout_declare_tensor(2, "r", "t_kernel", DTYPE, "texture2d")}
26+
${layout_declare_tensor(3, "r", "t_bias", DTYPE, "texture2d")}
27+
${layout_declare_ubo(4, "ivec3", "out_limits")}
28+
${layout_declare_ubo(5, "ivec4", "in_sizes")}
29+
${layout_declare_ubo(6, "ivec2", "kernel_size", "ivec2", "stride", "ivec2", "padding", "ivec2", "dilation")}
30+
${layout_declare_ubo(7, "ivec2", "overlay_region", "int", "in_group_size")}
31+
${layout_declare_ubo(8, "float", "out_min", "float", "out_max")}
32+
33+
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
34+
35+
/*
36+
* Computes a depthwise convolution. Each shader invocation calculates the
37+
* output at a single output location.
38+
*/
39+
40+
void main() {
41+
const uint div_by_x = gl_GlobalInvocationID.x / out_limits.x;
42+
const ivec3 pos = ivec3(
43+
gl_GlobalInvocationID.x % out_limits.x,
44+
div_by_x % out_limits.y,
45+
div_by_x / out_limits.y);
46+
47+
if (any(greaterThanEqual(pos, out_limits))) {
48+
return;
49+
}
50+
51+
// Compute the index of the top-left element of the overlay region. Negative
52+
// indices indicate that the top-left element is in a region added by padding.
53+
const ivec2 ipos = pos.xy * stride - padding;
54+
55+
// Compute the start and end of the input indices to load. Padding is assumed
56+
// to be constant 0 padding, so any reads from the padding region is skipped.
57+
const ivec2 start = ipos;
58+
const ivec2 end = ipos + overlay_region.xy;
59+
60+
VEC4_T sum = texelFetch(t_bias, ivec2(pos.z, 0), 0);
61+
int kx = 0;
62+
for (int y = start.y, i = 0; i < TILE_SIZE; y += dilation.y, i++) {
63+
for (int x = start.x, j = 0; j < TILE_SIZE; x += dilation.x, j++) {
64+
// The weight kernel was rearranged such that every NxN filter is
65+
// flattened to fit in one row. Each filter was then stacked on top of
66+
// each other vertically.
67+
const vec4 in_texel = texelFetch(t_in, ivec3(x, y, pos.z), 0);
68+
sum = fma(in_texel, texelFetch(t_kernel, ivec2(kx, pos.z), 0), sum);
69+
kx++;
70+
}
71+
}
72+
73+
imageStore(t_out, pos, op(sum, out_min, out_max));
74+
}
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
conv2d_dw_sned_output_tile:
8+
parameter_names_with_default_values:
9+
OPERATOR: X
10+
NDIM: 3
11+
DTYPE: float
12+
TILE_SIZE: 3
13+
generate_variant_forall:
14+
DTYPE:
15+
- VALUE: half
16+
- VALUE: float
17+
shader_variants:
18+
- NAME: conv2d_dw_sned_output_tile_3x3
19+
- NAME: conv2d_dw_sned_output_tile_3x3_clamp
20+
OPERATOR: clamp(X, A, B)
21+
- NAME: conv2d_dw_sned_output_tile_5x5
22+
TILE_SIZE: 5
23+
- NAME: conv2d_dw_sned_output_tile_5x5_clamp
24+
OPERATOR: clamp(X, A, B)
25+
TILE_SIZE: 5

backends/vulkan/runtime/graph/ops/impl/Convolution.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -134,8 +134,8 @@ vkapi::ShaderInfo get_conv2d_shader(
134134
case Conv2dMethod::Depthwise:
135135
kernel_name = "conv2d_dw";
136136
if (!prepack_weights) {
137-
if (stride_equals_dilation) {
138-
kernel_name += "_sed";
137+
if (!stride_equals_dilation) {
138+
kernel_name += "_sned";
139139
}
140140
const auto& weight_sizes = graph.get_tref(weight)->sizes;
141141
if (weight_sizes.at(2) == 3 && weight_sizes.at(3) == 3) {

0 commit comments

Comments
 (0)