
Commit 096fa1e

[ET-VK] Adding all tensor packing support for native layer norm. (#9870)
This PR was created by the merge bot to help merge the original PR into the main branch.
ghstack PR number: #9532 by @trivedivivek. Please use this as the source of truth for the PR details, comments, and reviews.
ghstack PR base: https://github.com/pytorch/executorch/tree/gh/trivedivivek/71/base
ghstack PR head: https://github.com/pytorch/executorch/tree/gh/trivedivivek/71/head
Merge bot PR base: https://github.com/pytorch/executorch/tree/main
Merge bot PR head: https://github.com/pytorch/executorch/tree/gh/trivedivivek/71/orig
@diff-train-skip-merge

Co-authored-by: Vivek Trivedi <[email protected]>
1 parent 9e7b469 commit 096fa1e

File tree

4 files changed: +114 -39 lines


backends/vulkan/op_registry.py

+14 -1

@@ -576,7 +576,6 @@ def register_ported_op_all_packed_dims(features: OpFeatures):
     [
         exir_ops.edge.aten.embedding.default,
         exir_ops.edge.aten._native_batch_norm_legit_no_training.default,
-        exir_ops.edge.aten.native_layer_norm.default,
     ]
 )
 def register_ported_ops_with_prepacking(features: OpFeatures):
@@ -587,6 +586,20 @@ def register_ported_ops_with_prepacking(features: OpFeatures):
     return features


+# Ported ops that support their own prepacking.
+@update_features(
+    [
+        exir_ops.edge.aten.native_layer_norm.default,
+    ]
+)
+def register_ported_ops_with_prepacking_all_dims(features: OpFeatures):
+    features.texture_impl = TextureImplFeatures(
+        valid_packed_dims=all_packed_dims,
+    )
+    features.handles_own_prepacking = True
+    return features
+
+
 #######################
 ## Utility functions ##
 #######################
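The change above moves native_layer_norm out of the channels-packed-only registration and into a new registration that allows all packed dims while keeping handles_own_prepacking. For readers unfamiliar with this registry, @update_features applies the decorated configuration function to every op in the list. A simplified Python sketch of that mechanism, with hypothetical stand-in names (the real update_features and OpFeatures in op_registry.py are more featureful):

# Hypothetical, simplified sketch of the decorator-based registration pattern.
# `registry`, `OpFeatures`, and `update_features` here are illustrative
# stand-ins for the actual definitions in backends/vulkan/op_registry.py.
registry = {}

class OpFeatures:
    def __init__(self):
        self.texture_impl = None
        self.handles_own_prepacking = False

def update_features(ops):
    def decorator(fn):
        for op in ops:
            registry[op] = fn(OpFeatures())  # one configured record per op
        return fn
    return decorator

@update_features(["aten.native_layer_norm.default"])
def register_ported_ops_with_prepacking_all_dims(features):
    features.handles_own_prepacking = True
    return features

assert registry["aten.native_layer_norm.default"].handles_own_prepacking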

backends/vulkan/runtime/graph/ops/glsl/native_layer_norm.glsl

+94 -32

@@ -15,6 +15,8 @@

 #define VEC4_T ${texel_type(DTYPE)}

+#define T ${texel_component_type(DTYPE)}
+
 layout(std430) buffer;

 ${layout_declare_tensor(B, "w", "t_out", DTYPE, STORAGE)}
@@ -48,37 +50,97 @@ void main() {

   const int width = int(sizes.x);

-  VEC4_T mean = VEC4_T(0);
-  VEC4_T delta = VEC4_T(0);
-  VEC4_T delta2 = VEC4_T(0);
-  VEC4_T M2 = VEC4_T(0);
-
-  // Use Welford's online algorithm to compute mean and variance in one pass
-  // https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_online_algorithm
-  ivec3 in_pos = lpos_to_pos(lpos, in_axis_map);
-  for (int w = 0; w < width; ++w) {
-    in_pos[in_axis_map.x] = w;
-    VEC4_T v = load_texel(t_in, in_pos);
-    delta = v - mean;
-    mean += delta / (w + 1);
-    delta2 = v - mean;
-    M2 += delta * delta2;
-  }
-
-  VEC4_T var = M2 / width;
-  VEC4_T rstd = pow(var + epsilon, VEC4_T(-0.5));
-  VEC4_T offset = -rstd * mean;
-
-  for (int w = 0; w < width; ++w) {
-    in_pos[in_axis_map.x] = w;
-    VEC4_T v = load_texel(t_in, in_pos);
-    // broadcasting
-    VEC4_T weight = load_texel(t_weight, ivec3(w, 0, 0)).xxxx;
-    VEC4_T bias = load_texel(t_bias, ivec3(w, 0, 0)).xxxx;
-    VEC4_T outtex = (v * rstd + offset) * weight + bias;
-    write_texel_lpos(t_out, ivec3(w, lpos.y, lpos.z), outtex, out_axis_map);
+  if (in_packed_dim != W_DIM) {
+    VEC4_T mean = VEC4_T(0);
+    VEC4_T delta = VEC4_T(0);
+    VEC4_T delta2 = VEC4_T(0);
+    VEC4_T M2 = VEC4_T(0);
+
+    // Use Welford's online algorithm to compute mean and variance in one pass
+    // https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_online_algorithm
+    ivec3 in_pos = lpos_to_pos(lpos, in_axis_map);
+    for (int w = 0; w < width; ++w) {
+      in_pos[in_axis_map.x] = w;
+      VEC4_T v = load_texel(t_in, in_pos);
+      delta = v - mean;
+      mean += delta / (w + 1);
+      delta2 = v - mean;
+      M2 += delta * delta2;
+    }
+
+    VEC4_T var = M2 / width;
+    VEC4_T rstd = pow(var + epsilon, VEC4_T(-0.5));
+    VEC4_T offset = -rstd * mean;
+
+    for (int w = 0; w < width; ++w) {
+      in_pos[in_axis_map.x] = w;
+      VEC4_T v = load_texel(t_in, in_pos);
+      // broadcasting
+      VEC4_T weight = load_texel(t_weight, ivec3(w, 0, 0)).xxxx;
+      VEC4_T bias = load_texel(t_bias, ivec3(w, 0, 0)).xxxx;
+      VEC4_T outtex = (v * rstd + offset) * weight + bias;
+      write_texel_lpos(t_out, ivec3(w, lpos.y, lpos.z), outtex, out_axis_map);
+    }
+
+    write_texel(t_mean, lpos, mean);
+    write_texel(t_rstd, lpos, rstd);
+  } else {
+    const int packed_width = divup4(width);
+
+    T mean = T(0);
+    T delta = T(0);
+    T delta2 = T(0);
+    T M2 = T(0);
+    // Use Welford's online algorithm to compute mean and variance in one pass
+    // https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_online_algorithm
+    ivec3 in_pos = lpos_to_pos(lpos, in_axis_map);
+    T width_counter = T(1);
+
+    const bool has_unaligned_width = (width & 0x3) != 0;
+    const int fully_packed_4_comp_count = packed_width - mix(0, 1, has_unaligned_width);
+
+    // iterate through texels that are fully packed, i.e. have 4 components
+    for (int w = 0; w < fully_packed_4_comp_count; ++w) {
+      in_pos[in_axis_map.x] = w;
+      VEC4_T v = load_texel(t_in, in_pos);
+      for (int i = 0; i < 4; i++) {
+        delta = v[i] - mean;
+        mean += delta / width_counter;
+        delta2 = v[i] - mean;
+        M2 += delta * delta2;
+        width_counter++;
+      }
+    }
+
+    // handle the last texel if it is not 4-aligned
+    if (has_unaligned_width) {
+      in_pos[in_axis_map.x] = fully_packed_4_comp_count;
+      const int remaining_width = width & 0x3;
+
+      VEC4_T v = load_texel(t_in, in_pos);
+      for (int i = 0; i < remaining_width; i++) {
+        delta = v[i] - mean;
+        mean += delta / width_counter;
+        delta2 = v[i] - mean;
+        M2 += delta * delta2;
+        width_counter++;
+      }
+    }
+
+    T var = M2 / (width_counter - 1);
+    T rstd = inversesqrt(var + epsilon);
+    T offset = -rstd * mean;
+
+    for (int w = 0; w < packed_width; ++w) {
+      in_pos[in_axis_map.x] = w;
+      VEC4_T v = load_texel(t_in, in_pos);
+      VEC4_T weight = load_texel(t_weight, ivec3(w, 0, 0));
+      VEC4_T bias = load_texel(t_bias, ivec3(w, 0, 0));
+      VEC4_T outtex = (v * rstd + offset) * weight + bias;
+      write_texel_lpos(t_out, ivec3(w, lpos.y, lpos.z), outtex, out_axis_map);
+    }
+
+    write_texel(t_mean, lpos, VEC4_T(mean));
+    write_texel(t_rstd, lpos, VEC4_T(rstd));
   }
-
-  write_texel(t_mean, lpos, mean);
-  write_texel(t_rstd, lpos, rstd);
 }
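Both the existing branch and the new width-packed branch compute mean and variance with Welford's one-pass algorithm; the difference is that the width-packed branch accumulates scalars across the four components of each texel and separately drains a trailing, partially filled texel when width is not a multiple of 4. A small Python model of that loop structure (illustrative only, not part of the commit):

# Python model of the shader's width-packed Welford loop: full texels hold
# 4 components; a final partial texel is handled separately.
def welford_width_packed(values):
    mean, m2, count = 0.0, 0.0, 0
    full_texels = len(values) // 4  # fully_packed_4_comp_count in the shader
    tail = len(values) & 0x3        # remaining_width in the shader
    for w in range(full_texels):
        for c in values[4 * w : 4 * w + 4]:
            count += 1
            delta = c - mean
            mean += delta / count
            m2 += delta * (c - mean)
    for c in values[4 * full_texels : 4 * full_texels + tail]:
        count += 1
        delta = c - mean
        mean += delta / count
        m2 += delta * (c - mean)
    # Biased variance (divide by N), matching M2 / (width_counter - 1) in the
    # shader, since width_counter starts at 1 and ends at N + 1.
    return mean, m2 / count

# Width of 6: one full texel plus a 2-component tail.
print(welford_width_packed([1.0, 2.0, 3.0, 4.0, 5.0, 6.0]))  # (3.5, ~2.917)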

backends/vulkan/runtime/graph/ops/impl/NativeLayerNorm.cpp

+1 -6

@@ -48,11 +48,6 @@ void resize_native_layer_norm_node(
   rstd->virtual_resize(mean_size);
 }

-void check_args(const api::vTensor& in, const api::vTensor& out) {
-  VK_CHECK_COND(check_packed_dim_is(in, WHCN::kChannelsDim));
-  VK_CHECK_COND(check_packed_dim_is(out, WHCN::kChannelsDim));
-}
-
 void add_native_layer_norm_node(
     ComputeGraph& graph,
     const ValueRef in,
@@ -84,7 +79,7 @@ void add_native_layer_norm_node(
   vTensorPtr t_input = graph.get_tensor(in);
   float epsilon = graph.extract_scalar<float>(eps);

-  check_args(*t_input, *t_out);
+  VK_CHECK_COND(check_same_packed_dim(*t_input, *t_out));

   std::vector<int64_t> in_sizes = t_input->sizes();
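The net effect is a relaxed precondition: instead of pinning both tensors to the channels-packed layout, the op now only requires that input and output agree on their packed dimension. A hypothetical Python rendering of the before and after predicates (the real checks are VK_CHECK_COND calls over api::vTensor packed-dim queries):

# Hypothetical rendering of the old vs. new validation predicates.
CHANNELS_DIM = "C"

def old_check(in_packed_dim, out_packed_dim):
    # Before: both tensors had to be channels-packed.
    return in_packed_dim == CHANNELS_DIM and out_packed_dim == CHANNELS_DIM

def new_check(in_packed_dim, out_packed_dim):
    # After: any packing is allowed, as long as input and output match.
    return in_packed_dim == out_packed_dim

assert new_check("W", "W") and not old_check("W", "W")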

backends/vulkan/test/op_tests/cases.py

+5 -0

@@ -385,6 +385,11 @@ def get_native_layer_norm_inputs():
             ((S, XL, M1, M2), [M2], (M2), (M2), 0.001),
         ]
     )
+    test_suite.layouts = [
+        "utils::kWidthPacked",
+        "utils::kHeightPacked",
+        "utils::kChannelsPacked",
+    ]
     return test_suite
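Setting test_suite.layouts to all three packings means every input shape in the suite now runs under width-, height-, and channels-packed storage, exercising both shader branches. A minimal sketch of that cross-product expansion (hypothetical; the real VkTestSuite harness differs):

# Hypothetical sketch: each case is expected to run once per memory layout.
layouts = [
    "utils::kWidthPacked",
    "utils::kHeightPacked",
    "utils::kChannelsPacked",
]
# (input sizes, normalized_shape, weight, bias, eps), as in cases.py
cases = [((5, 7), [7], (7,), (7,), 0.001)]

test_matrix = [(case, layout) for case in cases for layout in layouts]
print(len(test_matrix))  # 3: one run per packed dim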
