Skip to content

Commit fab1463

Browse files
[ET-VK] Reduced int precision for all int storage in conv pw op to improve performance. (#7499)
[ET-VK] Reduced int precision for all int storage in conv pw op to improve performance. Pull Request resolved: #7447. This diff reduces the precision of all int storage in the conv pw op to improve performance. The code changes include adding the extension GL_EXT_shader_explicit_arithmetic_types_int16 and changing integer variables from int/ivec types to uint16_t/u16vec types. ghstack-source-id: 260166244 @exported-using-ghexport Differential Revision: [D67674212](https://our.internmc.facebook.com/intern/diff/D67674212/) Co-authored-by: Vivek Trivedi <[email protected]>
1 parent c001634 commit fab1463

File tree

1 file changed

+12
-10
lines changed

1 file changed

+12
-10
lines changed

backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_output_tile.glsl

+12-10
Original file line numberDiff line numberDiff line change
@@ -32,35 +32,37 @@ ${layout_declare_ubo(8, "float", "out_min", "float", "out_max")}
3232

3333
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
3434

35+
#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require
36+
3537
/*
3638
* Computes a depthwise convolution. Each shader invocation calculates the
3739
* output at a single output location.
3840
*/
3941
void main() {
40-
const ivec3 pos = ivec3(gl_GlobalInvocationID);
42+
const u16vec3 pos = u16vec3(gl_GlobalInvocationID);
4143

4244
if (any(greaterThanEqual(pos, out_limits))) {
4345
return;
4446
}
4547

4648
// Compute the index of the top-left element of the overlay region. Negative
4749
// indices indicate that the top-left element is in a region added by padding.
48-
const ivec2 ipos = pos.xy * stride - padding;
50+
const u16vec2 ipos = pos.xy * u16vec2(stride) - u16vec2(padding);
4951

5052
// Compute the start and end of the input indices to load. Padding is assumed
5153
// to be constant 0 padding, so any reads from the padding region are skipped.
52-
const ivec2 start = ipos;
53-
const ivec2 end = ipos + overlay_region.xy;
54+
const u16vec2 start = ipos;
55+
const u16vec2 end = ipos + u16vec2(overlay_region.xy);
5456

55-
VEC4_T sum = texelFetch(t_bias, ivec2(pos.z, 0), 0);
56-
int kx = 0;
57-
for (int y = start.y, i = 0; i < TILE_SIZE; y += dilation.y, i++) {
58-
for (int x = start.x, j = 0; j < TILE_SIZE; x += dilation.x, j++) {
57+
VEC4_T sum = texelFetch(t_bias, u16vec2(pos.z, 0), 0);
58+
uint16_t kx = uint16_t(0);
59+
for (uint16_t y = start.y, i = uint16_t(0); i < uint16_t(TILE_SIZE); y += uint16_t(dilation.y), i++) {
60+
for (uint16_t x = start.x, j = uint16_t(0); j < uint16_t(TILE_SIZE); x += uint16_t(dilation.x), j++) {
5961
// The weight kernel was rearranged such that every NxN filter is
6062
// flattened to fit in one row. Each filter was then stacked on top of
6163
// each other vertically.
62-
const vec4 in_texel = texelFetch(t_in, ivec3(x, y, pos.z), 0);
63-
sum = fma(in_texel, texelFetch(t_kernel, ivec2(kx, pos.z), 0), sum);
64+
const vec4 in_texel = texelFetch(t_in, u16vec3(x, y, pos.z), 0);
65+
sum = fma(in_texel, texelFetch(t_kernel, u16vec2(kx, pos.z), 0), sum);
6466
kx++;
6567
}
6668
}

0 commit comments

Comments
 (0)