Skip to content

Commit 2b82a7b

Browse files
author
ssjia
committed
[ET-VK][conv2d] Re-implement pointwise conv2d with tiled compute and blocked weight packing
Profiling EdgeTAM on Adreno shows pointwise 1×1 convolutions are a dominant bottleneck. This diff re-implements the stride=1, padding=0 pointwise path using the same tiled matmul approach as the recently landed linear shader rewrite. The new `conv2d_pw_tiled` shader reuses the shared linear tiled infrastructure (FPInputTile, FPWeightTile, FPOutTile, fp_accumulate_with_fp_weight, packed weight tile loading) with custom input/output tile load/store functions that map flat spatial indices to channels-packed texture3d coordinates. Weight packing uses the same 4OC×4IC blocked format as linear via the `pack_fp_linear_weight` shader. Dispatch uses DynamicDispatchNode for correct workgroup size updates during graph resizing. Only the stride=1, padding=0 pointwise path is changed; the general conv2d_pw shader for arbitrary stride/padding is left unchanged. EdgeTAM first frame on Samsung S25 (Adreno 830): 208 ms → 196 ms (~6%). Authored with Claude. Differential Revision: [D96756792](https://our.internmc.facebook.com/intern/diff/D96756792/) [ghstack-poisoned]
1 parent 552a31f commit 2b82a7b

File tree

8 files changed

+822
-0
lines changed

8 files changed

+822
-0
lines changed
Lines changed: 158 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,158 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#version 450 core
10+
11+
#define PRECISION ${PRECISION}
12+
#define VEC4_T ${texel_load_type(DTYPE, STORAGE)}
13+
14+
#define TILE_M4 ${TILE_M4}
15+
#define TILE_K4 ${TILE_K4}
16+
#define TILE_N4 ${TILE_N4}
17+
18+
#define TILE_M ${TILE_M}
19+
#define TILE_K ${TILE_K4 * 4}
20+
#define TILE_N ${TILE_N4 * 4}
21+
22+
${define_required_extensions(STORAGE, DTYPE)}
23+
24+
layout(std430) buffer;
25+
26+
#include "common.glslh"
27+
28+
${layout_declare_tensor(B, "w", "t_out", DTYPE, STORAGE)}
29+
${layout_declare_tensor(B, "r", "t_in", DTYPE, STORAGE)}
30+
${layout_declare_tensor(B, "r", "t_weight_packed", DTYPE, "texture2d")}
31+
${layout_declare_tensor(B, "r", "t_bias", DTYPE, "texture2d")}
32+
33+
${layout_declare_ubo(B, "ivec4", "in_sizes")}
34+
${layout_declare_ubo(B, "ivec4", "out_sizes")}
35+
36+
layout(push_constant) uniform restrict Block {
37+
int stride_h;
38+
int stride_w;
39+
int padding_h;
40+
int padding_w;
41+
float out_min;
42+
float out_max;
43+
};
44+
45+
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
46+
47+
${layout_declare_spec_const(C, "int", "stride_1_padding_0", "0")}
48+
${layout_declare_spec_const(C, "int", "activation_type", "0")}
49+
50+
#include "linear_fp_input_tile.glslh"
51+
#include "linear_fp_packed_weight_tile_load.glslh"
52+
#include "linear_fp_output_tile_fp_compute.glslh"
53+
54+
void load_input_tile_with_checks(
55+
out FPInputTile tile,
56+
const int k4_start,
57+
const int m_start,
58+
const int K4,
59+
const int M,
60+
const int W_out,
61+
const int W_in,
62+
const int H_in) {
63+
[[unroll]] for (int m = 0; m < TILE_M; ++m) {
64+
[[unroll]] for (int k4 = 0; k4 < TILE_K4; ++k4) {
65+
if (k4_start + k4 < K4 && m_start + m < M) {
66+
if (stride_1_padding_0 != 0) {
67+
const int spatial = m_start + m;
68+
tile.data[m][k4] =
69+
texelFetch(t_in, ivec3(spatial % W_out, spatial / W_out, k4_start + k4), 0);
70+
} else {
71+
const int out_spatial = m_start + m;
72+
const int out_x = out_spatial % W_out;
73+
const int out_y = out_spatial / W_out;
74+
const int in_x = out_x * stride_w - padding_w;
75+
const int in_y = out_y * stride_h - padding_h;
76+
if (in_x >= 0 && in_x < W_in && in_y >= 0 && in_y < H_in) {
77+
tile.data[m][k4] =
78+
texelFetch(t_in, ivec3(in_x, in_y, k4_start + k4), 0);
79+
} else {
80+
tile.data[m][k4] = VEC4_T(0.0);
81+
}
82+
}
83+
} else {
84+
tile.data[m][k4] = VEC4_T(0.0);
85+
}
86+
}
87+
}
88+
}
89+
90+
void store_output_tile_with_checks(
91+
const FPOutTile out_tile,
92+
const int n4_start,
93+
const int m_start,
94+
const int N4,
95+
const int M,
96+
const int W_out) {
97+
[[unroll]] for (int m = 0; m < TILE_M; ++m) {
98+
[[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) {
99+
if (m_start + m < M && n4_start + n4 < N4) {
100+
const int spatial = m_start + m;
101+
VEC4_T texel = out_tile.data[m][n4];
102+
if (activation_type == 1) {
103+
texel = max(texel, VEC4_T(0.0));
104+
} else if (activation_type == 2) {
105+
texel = clamp(texel, VEC4_T(out_min), VEC4_T(out_max));
106+
}
107+
imageStore(t_out, ivec3(spatial % W_out, spatial / W_out, n4_start + n4), texel);
108+
}
109+
}
110+
}
111+
}
112+
113+
void main() {
114+
const int tile_idx_n = int(gl_GlobalInvocationID.x);
115+
const int tile_idx_m = int(gl_GlobalInvocationID.y);
116+
117+
const int n4_start = tile_idx_n * TILE_N4;
118+
const int m_start = tile_idx_m * TILE_M;
119+
120+
const int W_in = in_sizes.x;
121+
const int H_in = in_sizes.y;
122+
const int K = in_sizes.z;
123+
const int K4 = div_up_4(K);
124+
125+
const int W_out = out_sizes.x;
126+
const int H_out = out_sizes.y;
127+
const int M = W_out * H_out;
128+
const int N = out_sizes.z;
129+
const int N4 = div_up_4(N);
130+
131+
if (n4_start >= N4 || m_start >= M) {
132+
return;
133+
}
134+
135+
FPOutTile out_tile;
136+
initialize(out_tile);
137+
138+
FPInputTile in_tile;
139+
FPWeightTile w_tile;
140+
141+
for (int k4 = 0; k4 < K4; k4++) {
142+
load_input_tile_with_checks(in_tile, k4, m_start, K4, M, W_out, W_in, H_in);
143+
load_packed_weight_tile_with_checks(w_tile, n4_start, k4, 0, N4, K4);
144+
fp_accumulate_with_fp_weight(out_tile, in_tile, w_tile);
145+
}
146+
147+
// Apply bias
148+
[[unroll]] for (int m = 0; m < TILE_M; ++m) {
149+
[[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) {
150+
if (n4_start + n4 < N4) {
151+
out_tile.data[m][n4] +=
152+
texelFetch(t_bias, ivec2(n4_start + n4, 0), 0);
153+
}
154+
}
155+
}
156+
157+
store_output_tile_with_checks(out_tile, n4_start, m_start, N4, M, W_out);
158+
}
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
conv2d_pw_tiled:
8+
parameter_names_with_default_values:
9+
DTYPE: float
10+
STORAGE: texture3d
11+
TILE_M4: 1
12+
TILE_K4: 1
13+
TILE_N4: 1
14+
TILE_M: 4
15+
generate_variant_forall:
16+
DTYPE:
17+
- VALUE: float
18+
- VALUE: half
19+
shader_variants:
20+
- NAME: conv2d_pw_tiled

0 commit comments

Comments
 (0)