Skip to content

Commit 14ea170

Browse files
author
ssjia
committed
[ETVK] Add benchmark binary + im2col/GEMM conv2d prototype
This change does two related things on top of the existing direct conv2d path: it adds a new benchmark binary for general conv2d, and it adds an im2col-backed conv2d implementation that the benchmark exercises alongside the existing direct shader. **Why a benchmark binary** Profiling a sample CNN showed that the standard `conv2d_float` (general sliding window) shader accounts for ~93% of all conv time, with six 3x3 stride=1 same-channels shapes dominating. The existing custom-ops directory had benchmark binaries for pointwise and depthwise conv but no standalone way to iterate on the general kernel. The new `test_conv2d` binary fills that gap. `test_conv2d.cpp` includes 7 small accuracy configs (validated against a CPU float reference) and 13 performance configs covering the sample CNN's hotspots: the six dominant `C_in == C_out` 3x3 stride=1 shapes, several stride=2 downsample variants, two channel-reduction cases, and the 3-channel RGB stem. Perf configs are run in FP32 and FP16; accuracy configs are FP32-only because the reference is float. The binary uses 5 warmup + 20 timed iterations per case so the GPU governor reaches a stable clock before measurement. On a Pixel device, the reported per-call latencies for the direct path match the in-model profile within 0.84x-0.99x for all six dominant shapes, confirming the binary is a faithful proxy for in-model conv latency. **Why an im2col-backed conv2d** The im2col approach materializes the conv input into a `[1, K_total, H_out, W_out]` (or `[M, K_total]`) intermediate and runs the conv as a single tiled GEMM. The im2col K-axis layout `K = (ki * Kw + kj) * Cin_padded + ci` is chosen so that every 4-tile of K holds 4 consecutive `ci` values for the same `(ki, kj)` — that way each im2col output texel reads exactly one input texel and the GEMM can use a clean 1x1-style load pattern. On the sample CNN's hotspots this gives 1.20x-1.43x FP32 and 1.50x-1.80x FP16 speedups vs. the direct shader (estimated ~21% reduction in total FP32 conv time, ~36% in FP16) on Pixel 9 Pro XL. The implementation is split into three pieces so we can iterate on the GEMM step in isolation: - `conv2d_im2col.glsl` + `impl/Conv2dIm2Col.{h,cpp}`: the im2col dispatch only. - `conv2d_gemm.glsl` + the orchestration in `impl/Conv2dGemm.{h,cpp}`: a private GEMM shader for the im2col-backed case, separate from the production pointwise path so we can experiment with more aggressive optimizations (larger tiles, cooperative matrix, register blocking) without affecting `conv2d_pw_tiled`. - `Conv2dGemm.cpp` also does the CPU-side weight repack from `[C_out, C_in, Kh, Kw]` into the matching `[C_out, K_total]` layout, wrapped in a `FreeableBuffer` so the graph owns the lifetime. **Device-specific storage selection** Both shader templates codegen three variants of the im2col intermediate — `buffer`, `texture2d` width-packed `[K4_total, M]`, and `texture3d` channels-packed `[W_out, H_out, K4_total]` — and `conv2d_gemm_impl` picks at graph build time based on `graph.device_is_mali()` and the relevant max texture extents. Mali → buffer always (its texture sampling is comparatively slow vs SSBO reads). Adreno and others prefer `texture2d`, but for shapes where M would exceed `max_texture2d_dim` (e.g. `[1, 32, 144, 192]` with M = 27,648) the dispatch falls back to `texture3d`, then to `buffer` as a last resort. On Adreno (Samsung S921), the device-specific routing pushes wins to 0.47x-0.79x FP32 and 0.65x-0.96x FP16 on the dominant shapes. On Mali (Pixel 9 Pro XL), buffer routing pushes wins to 0.51x-0.78x FP32 and 0.34x-0.46x FP16. **Test integration** `test_etvk.test_conv2d.default` switches between `aten.convolution.default` and `et_vk.conv2d_gemm.default` based on the `impl_selector` string ("im2col" picks the new path), so the same benchmark binary exercises both implementations back-to-back per shape. Differential Revision: [D105120966](https://our.internmc.facebook.com/intern/diff/D105120966/) [ghstack-poisoned]
1 parent 97cece2 commit 14ea170

11 files changed

Lines changed: 1541 additions & 0 deletions

File tree

Lines changed: 188 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,188 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
/*
10+
* conv2d_gemm: GEMM step of im2col-backed conv2d.
11+
*
12+
* Reads the im2col'd input produced by conv2d_im2col.glsl as a 2D matrix
13+
* of shape [M, K_total] (M = H_out * W_out, K_total = Kh*Kw*Cin_padded)
14+
* and writes the conv2d output as texture3D channels-packed
15+
* logical shape [1, C_out, H_out, W_out].
16+
*
17+
* The im2col input can be any of:
18+
* - texture2d, width-packed: texel at (k4, m) holds 4 K values for row m.
19+
* IN_STORAGE=texture2d codegen.
20+
* - texture3d, channels-packed: texel at (ow, oh, k4) holds 4 K values
21+
* for output spatial position (oh, ow). Used when M would exceed
22+
* max_texture2d_dim. IN_STORAGE=texture3d codegen.
23+
* - buffer: vec4 at offset m*K4 + k4, same K packing.
24+
* IN_STORAGE=buffer codegen.
25+
*
26+
* The matmul interpretation is:
27+
* out[m, n] = sum_k im2col[m, k] * weight[n, k] + bias[n]
28+
* with M = H_out * W_out, K = K_total, N = C_out.
29+
*/
30+
31+
#version 450 core
32+
33+
#define PRECISION ${PRECISION}
34+
#define VEC4_T ${texel_load_type(DTYPE, "texture3d")}
35+
36+
#define TILE_M4 ${TILE_M4}
37+
#define TILE_K4 ${TILE_K4}
38+
#define TILE_N4 ${TILE_N4}
39+
40+
#define TILE_M ${TILE_M}
41+
#define TILE_K ${TILE_K4 * 4}
42+
#define TILE_N ${TILE_N4 * 4}
43+
44+
$if IN_STORAGE == "buffer":
45+
#define INPUT_BUFFER
46+
$elif IN_STORAGE == "texture3d":
47+
#define INPUT_TEXTURE3D
48+
49+
${define_required_extensions("texture3d", DTYPE)}
50+
$if IN_STORAGE == "buffer":
51+
${define_required_extensions("buffer", DTYPE)}
52+
53+
layout(std430) buffer;
54+
55+
#include "common.glslh"
56+
57+
${layout_declare_tensor(B, "w", "t_out", DTYPE, "texture3d")}
58+
$if IN_STORAGE == "buffer":
59+
${layout_declare_tensor(B, "r", "t_in", DTYPE, "buffer", is_scalar_array=False)}
60+
$else:
61+
${layout_declare_tensor(B, "r", "t_in", DTYPE, IN_STORAGE)}
62+
${layout_declare_tensor(B, "r", "t_weight_packed", DTYPE, "texture2d")}
63+
${layout_declare_tensor(B, "r", "t_bias", DTYPE, "texture2d")}
64+
65+
${layout_declare_ubo(B, "ivec4", "out_sizes")}
66+
67+
// Push constants are uploaded in 16-byte chunks (one ivec4 each).
68+
layout(push_constant) uniform restrict Block {
69+
ivec4 gemm_dims; // (K_total, K4_total, M, _unused)
70+
vec4 clamp_vals; // (out_min, out_max, _unused, _unused)
71+
};
72+
73+
#define K_TOTAL gemm_dims.x
74+
#define K4_TOTAL gemm_dims.y
75+
#define M_TOTAL gemm_dims.z
76+
#define OUT_MIN clamp_vals.x
77+
#define OUT_MAX clamp_vals.y
78+
79+
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
80+
81+
${layout_declare_spec_const(C, "int", "activation_type", "0")}
82+
83+
#include "linear_fp_input_tile.glslh"
84+
#include "linear_fp_packed_weight_tile_load.glslh"
85+
#include "linear_fp_output_tile_fp_compute.glslh"
86+
87+
/*
88+
* Load TILE_M rows × TILE_K4 K-tiles of the im2col'd input.
89+
* The im2col output is a contiguous (M, K_total/4) matrix of vec4s, so the
90+
* load is a plain 2D fetch — no spatial decomposition.
91+
*/
92+
void load_input_tile_with_checks(
93+
out FPInputTile tile,
94+
const int k4_start,
95+
const int m_start,
96+
const int K4,
97+
const int M,
98+
const int W_out) {
99+
// W_out is only consumed by the texture3d variant below.
100+
[[unroll]] for (int m = 0; m < TILE_M; ++m) {
101+
[[unroll]] for (int k4 = 0; k4 < TILE_K4; ++k4) {
102+
if (k4_start + k4 < K4 && m_start + m < M) {
103+
const int row = m_start + m;
104+
const int col = k4_start + k4;
105+
#if defined(INPUT_BUFFER)
106+
// Cast handles buffer/half (f16vec4) -> tile's VEC4_T (vec4).
107+
tile.data[m][k4] = VEC4_T(t_in[row * K4 + col]);
108+
#elif defined(INPUT_TEXTURE3D)
109+
// texture3d layout: row (the flat M index) decomposes into (ow, oh)
110+
// and K4 is along the Z axis.
111+
tile.data[m][k4] =
112+
texelFetch(t_in, ivec3(row % W_out, row / W_out, col), 0);
113+
#else
114+
tile.data[m][k4] = texelFetch(t_in, ivec2(col, row), 0);
115+
#endif
116+
} else {
117+
tile.data[m][k4] = VEC4_T(0.0);
118+
}
119+
}
120+
}
121+
}
122+
123+
void store_output_tile_with_checks(
124+
const FPOutTile out_tile,
125+
const int n4_start,
126+
const int m_start,
127+
const int N4,
128+
const int M,
129+
const int W_out) {
130+
[[unroll]] for (int m = 0; m < TILE_M; ++m) {
131+
[[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) {
132+
if (m_start + m < M && n4_start + n4 < N4) {
133+
const int spatial = m_start + m;
134+
VEC4_T texel = out_tile.data[m][n4];
135+
if (activation_type == 1) {
136+
texel = max(texel, VEC4_T(0.0));
137+
} else if (activation_type == 2) {
138+
texel = clamp(texel, VEC4_T(OUT_MIN), VEC4_T(OUT_MAX));
139+
}
140+
imageStore(
141+
t_out, ivec3(spatial % W_out, spatial / W_out, n4_start + n4), texel);
142+
}
143+
}
144+
}
145+
}
146+
147+
void main() {
148+
const int tile_idx_n = int(gl_GlobalInvocationID.x);
149+
const int tile_idx_m = int(gl_GlobalInvocationID.y);
150+
151+
const int n4_start = tile_idx_n * TILE_N4;
152+
const int m_start = tile_idx_m * TILE_M;
153+
154+
const int W_out = out_sizes.x;
155+
const int H_out = out_sizes.y;
156+
const int M = M_TOTAL;
157+
const int K4 = K4_TOTAL;
158+
const int N = out_sizes.z;
159+
const int N4 = div_up_4(N);
160+
161+
if (n4_start >= N4 || m_start >= M) {
162+
return;
163+
}
164+
165+
FPOutTile out_tile;
166+
initialize(out_tile);
167+
168+
FPInputTile in_tile;
169+
FPWeightTile w_tile;
170+
171+
for (int k4 = 0; k4 < K4; k4++) {
172+
load_input_tile_with_checks(in_tile, k4, m_start, K4, M, W_out);
173+
load_packed_weight_tile_with_checks(w_tile, n4_start, k4, 0, N4, K4);
174+
fp_accumulate_with_fp_weight(out_tile, in_tile, w_tile);
175+
}
176+
177+
// Apply bias
178+
[[unroll]] for (int m = 0; m < TILE_M; ++m) {
179+
[[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) {
180+
if (n4_start + n4 < N4) {
181+
out_tile.data[m][n4] +=
182+
texelFetch(t_bias, ivec2(n4_start + n4, 0), 0);
183+
}
184+
}
185+
}
186+
187+
store_output_tile_with_checks(out_tile, n4_start, m_start, N4, M, W_out);
188+
}
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
conv2d_gemm:
8+
parameter_names_with_default_values:
9+
DTYPE: float
10+
IN_STORAGE: texture2d
11+
TILE_M4: 1
12+
TILE_K4: 1
13+
TILE_N4: 1
14+
TILE_M: 4
15+
generate_variant_forall:
16+
combination:
17+
parameter_names: [IN_STORAGE, DTYPE]
18+
combos:
19+
- parameter_values: [texture2d, float]
20+
- parameter_values: [texture2d, half]
21+
- parameter_values: [texture3d, float]
22+
- parameter_values: [texture3d, half]
23+
- parameter_values: [buffer, float]
24+
- parameter_values: [buffer, half]
25+
shader_variants:
26+
- NAME: conv2d_gemm
Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
/*
10+
* Im2col transformation for FP32 / FP16 conv2d.
11+
*
12+
* The output is a 2D matrix of shape [M, K_total] where
13+
* M = H_out * W_out (number of output spatial positions)
14+
* K_total = Kh * Kw * align_up_4(C_in) (flattened receptive field)
15+
*
16+
* K layout (so a 4-tile in K — one vec4 — holds the same kernel position):
17+
* K = (ki * Kw + kj) * Cin_padded + ci
18+
*
19+
* Three codegen'd storage variants of the output tensor:
20+
* - texture2d, width-packed: texel at (k4, m) holds 4 K values for spatial
21+
* position m. Extents = (K_total/4, M).
22+
* - texture3d, channels-packed: texel at (ow, oh, k4) holds 4 K values
23+
* for output spatial position (oh, ow). Extents = (W_out, H_out, K4).
24+
* Used as a fallback when M would exceed max_texture2d_dim.
25+
* - buffer: vec4 at offset (m * K4 + k4), same K packing.
26+
*
27+
* The caller picks storage per device (Mali → buffer; others → texture2d
28+
* when its 2D extents fit, texture3d when its 3D extents fit, else buffer).
29+
*/
30+
31+
#version 450 core
32+
33+
#define PRECISION ${PRECISION}
34+
35+
#define VEC4_T ${texel_load_type(DTYPE, "texture3d")}
36+
37+
$if OUT_STORAGE == "buffer":
38+
#define OUTPUT_BUFFER
39+
#define VEC4_BUF_T ${texel_load_type(DTYPE, "buffer")}
40+
$elif OUT_STORAGE == "texture3d":
41+
#define OUTPUT_TEXTURE3D
42+
43+
${define_required_extensions("texture3d", DTYPE)}
44+
$if OUT_STORAGE == "buffer":
45+
${define_required_extensions("buffer", DTYPE)}
46+
47+
layout(std430) buffer;
48+
49+
$if OUT_STORAGE == "buffer":
50+
${layout_declare_tensor(B, "w", "t_out", DTYPE, "buffer", is_scalar_array=False)}
51+
$else:
52+
${layout_declare_tensor(B, "w", "t_out", DTYPE, OUT_STORAGE)}
53+
${layout_declare_tensor(B, "r", "t_in", DTYPE, "texture3d")}
54+
55+
${layout_declare_ubo(B, "ivec4", "in_sizes")}
56+
57+
// Push constants are uploaded in 16-byte chunks (one ivec4 each) to comply
58+
// with the per-entry size limit.
59+
layout(push_constant) uniform restrict Block {
60+
ivec4 kernel_stride; // (Kh, Kw, Sh, Sw)
61+
ivec4 padding_dil; // (Ph, Pw, Dh, Dw)
62+
ivec4 dims; // (Cin_padded, W_out, H_out, K4_total)
63+
};
64+
65+
#define KERNEL_H kernel_stride.x
66+
#define KERNEL_W kernel_stride.y
67+
#define STRIDE_H kernel_stride.z
68+
#define STRIDE_W kernel_stride.w
69+
#define PADDING_H padding_dil.x
70+
#define PADDING_W padding_dil.y
71+
#define DILATION_H padding_dil.z
72+
#define DILATION_W padding_dil.w
73+
#define CIN_PADDED dims.x
74+
#define W_OUT dims.y
75+
#define H_OUT dims.z
76+
#define K4_TOTAL dims.w
77+
78+
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
79+
80+
void main() {
81+
const int k4 = int(gl_GlobalInvocationID.x);
82+
const int m = int(gl_GlobalInvocationID.y);
83+
const int M = H_OUT * W_OUT;
84+
85+
if (k4 >= K4_TOTAL || m >= M) {
86+
return;
87+
}
88+
89+
const int k_start = k4 * 4;
90+
91+
// K = (ki * Kw + kj) * Cin_padded + ci ; since Cin_padded % 4 == 0, all 4
92+
// K values in this texel share the same (ki, kj) and span 4 consecutive
93+
// ci values starting at ci_start.
94+
const int krow_idx = k_start / CIN_PADDED; // ki * Kw + kj
95+
const int ci_start = k_start % CIN_PADDED;
96+
const int kj = krow_idx % KERNEL_W;
97+
const int ki = krow_idx / KERNEL_W;
98+
const int ci_blk = ci_start >> 2; // ci_start / 4
99+
100+
// Decompose flat output position m back into (oh, ow).
101+
const int ow = m % W_OUT;
102+
const int oh = m / W_OUT;
103+
104+
// Compute the input spatial position for this (oh, ow, ki, kj).
105+
const int ih = oh * STRIDE_H - PADDING_H + ki * DILATION_H;
106+
const int iw = ow * STRIDE_W - PADDING_W + kj * DILATION_W;
107+
108+
VEC4_T out_texel = VEC4_T(0);
109+
if (ih >= 0 && ih < in_sizes.y && iw >= 0 && iw < in_sizes.x) {
110+
out_texel = texelFetch(t_in, ivec3(iw, ih, ci_blk), 0);
111+
}
112+
113+
#if defined(OUTPUT_BUFFER)
114+
t_out[m * K4_TOTAL + k4] = VEC4_BUF_T(out_texel);
115+
#elif defined(OUTPUT_TEXTURE3D)
116+
imageStore(t_out, ivec3(ow, oh, k4), out_texel);
117+
#else
118+
imageStore(t_out, ivec2(k4, m), out_texel);
119+
#endif
120+
}
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
conv2d_im2col:
8+
parameter_names_with_default_values:
9+
DTYPE: float
10+
OUT_STORAGE: texture2d
11+
generate_variant_forall:
12+
combination:
13+
parameter_names: [OUT_STORAGE, DTYPE]
14+
combos:
15+
- parameter_values: [texture2d, float]
16+
- parameter_values: [texture2d, half]
17+
- parameter_values: [texture3d, float]
18+
- parameter_values: [texture3d, half]
19+
- parameter_values: [buffer, float]
20+
- parameter_values: [buffer, half]
21+
shader_variants:
22+
- NAME: conv2d_im2col

0 commit comments

Comments
 (0)