-
Notifications
You must be signed in to change notification settings - Fork 881
Expand file tree
/
Copy pathembedding_legacy.glsl
More file actions
53 lines (40 loc) · 1.77 KB
/
embedding_legacy.glsl
File metadata and controls
53 lines (40 loc) · 1.77 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/
#version 450 core

// Legacy embedding lookup shader: each invocation of main gathers one output
// texel by reading four indices from t_in and using each one to fetch a single
// element from the t_weight texture.
//
// PRECISION and VEC4_T are substituted by the shader code generator; VEC4_T is
// the 4-component texel type corresponding to DTYPE.
#define PRECISION ${PRECISION}
#define VEC4_T ${texel_type(DTYPE)}

// Default memory layout for any buffer blocks declared in this shader.
layout(std430) buffer;

// Shader resources:
//   t_out    - written; DTYPE elements, storage kind selected by STORAGE.
//   t_in     - read; int elements, used by main as indices into t_weight.
//   t_weight - read; DTYPE elements, always declared as a 2D texture.
//   sizes    - ivec4 uniform; main bounds-checks the output tensor index
//              against it, so it presumably holds the output tensor sizes.
// NOTE(review): these helpers presumably assign binding indices in
// declaration order, so this order must stay in sync with the host-side
// dispatch code -- TODO confirm before reordering.
${layout_declare_tensor(B, "w", "t_out", DTYPE, STORAGE)}
${layout_declare_tensor(B, "r", "t_in", "int", STORAGE)}
${layout_declare_tensor(B, "r", "t_weight", DTYPE, "texture2d")}
${layout_declare_ubo(B, "ivec4", "sizes")}

// Provides the layout/indexing helpers used below (unhash_axis_map,
// unhash_packed_dim, lpos_to_tidx, load_texel_lpos, write_texel_lpos).
#include "indexing_utils.h"

// Workgroup dimensions come from specialization constants 0, 1 and 2.
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

// Output tensor layout, supplied as a specialization constant (default
// DEFAULT_LAYOUT) and decoded into the output axis map and packed dimension.
${layout_declare_spec_const(C, "int", "out_layout", "DEFAULT_LAYOUT")}
const lowp ivec4 out_axis_map = unhash_axis_map(out_layout);
const lowp int packed_dim = unhash_packed_dim(out_layout);

// Input (index) tensor layout. Only its axis map is decoded; main reads the
// input texel component directly via the w index modulo 4.
${layout_declare_spec_const(C, "int", "in_layout", "DEFAULT_LAYOUT")}
const lowp ivec4 in_axis_map = unhash_axis_map(in_layout);
// Compute entry point: every invocation produces exactly one output texel.
//
// Let dst_tidx be the output tensor index of this invocation's texel. For each
// texel component comp in 0..3:
//   1. read an int from t_in at logical position
//      (dst_tidx.y, dst_tidx.z * 4 + comp, dst_tidx.w / 4), taking component
//      dst_tidx.w % 4 of that texel;
//   2. treat that int as an index into t_weight: fetch texel row index / 4 in
//      column dst_tidx.x and keep component index % 4;
//   3. store the fetched value into component comp of the output texel.
//
// NOTE(review): only the texel's first tensor index is bounds-checked, so the
// trailing components are always computed, and the index read from t_in is
// not range-checked before texelFetch. TODO confirm callers guarantee valid
// indices.
//
// NOTE(review): if lpos_to_tidx (indexing_utils.h) already converts the
// packed-dim texel coordinate into an element index, then with a
// channel-packed output dst_tidx.z * 4 scales that coordinate twice and would
// read the wrong t_in row once there are more than 4 channels. TODO confirm
// against indexing_utils.h and the host-side dispatch.
void main() {
  const ivec3 dst_lpos = ivec3(gl_GlobalInvocationID);
  const ivec4 dst_tidx =
      lpos_to_tidx(dst_lpos, sizes, out_axis_map.w, packed_dim);

  // Skip invocations whose output tensor index is outside sizes.
  if (!all(lessThan(dst_tidx, sizes))) {
    return;
  }

  // Input coordinates shared by all four components of the output texel.
  // (Storing t_in and t_weight W-packed could make these lookups cheaper.)
  const int src_x = dst_tidx.y;
  const int src_y_base = dst_tidx.z * 4;
  const int src_z = dst_tidx.w / 4;
  const int src_comp = dst_tidx.w % 4;

  VEC4_T dst_texel;
  for (int comp = 0; comp < 4; comp++) {
    const ivec3 src_lpos = ivec3(src_x, src_y_base + comp, src_z);
    const int emb_idx = load_texel_lpos(t_in, src_lpos, in_axis_map)[src_comp];

    // t_weight is height-packed: indices 4k..4k+3 live in texel row k, one
    // per component.
    dst_texel[comp] =
        texelFetch(t_weight, ivec2(dst_tidx.x, emb_idx / 4), 0)[emb_idx % 4];
  }

  write_texel_lpos(t_out, dst_lpos, dst_texel, out_axis_map);
}