[ET-VK] Manual sync to fbsource #10238

Merged
merged 1 commit into from Apr 16, 2025
234 changes: 159 additions & 75 deletions backends/vulkan/runtime/graph/ops/glsl/native_layer_norm.glsl
@@ -43,106 +43,190 @@ ${layout_declare_spec_const(C, "int", "out_layout", "DEFAULT_LAYOUT")}
const lowp ivec4 out_axis_map = unhash_axis_map(out_layout);
const lowp int out_packed_dim = unhash_packed_dim(out_layout);

void main() {
const ivec3 lpos = ivec3(gl_GlobalInvocationID);
#define SHARED_MEMORY_FACTOR 2
#define MAX_WORKGROUP_SIZE 64

#define offset_pos_index(index) ((index) + ((index) >> 2))

shared VEC4_T shared_input[offset_pos_index(MAX_WORKGROUP_SIZE * SHARED_MEMORY_FACTOR)];

// function to reduce input data in workgroup's x dimension
void reduce_input(const int width_stride, const int shared_idx_offset) {
// wait for all shared memory writes to finish
memoryBarrierShared();
barrier();

// loop log(width_stride) times
for (int current_stride = 1, index = int(gl_LocalInvocationID.x << 1); current_stride < width_stride; current_stride *= 2, index <<= 1) {
// if the index at this thread is within the width stride
if (index < width_stride) {
const int local_shared_idx = shared_idx_offset + index;
// add the value at current stride to this thread's value
shared_input[offset_pos_index(local_shared_idx)] += shared_input[offset_pos_index(local_shared_idx + current_stride)];
}

if (any(greaterThanEqual(lpos, out_limits))) {
return;
memoryBarrierShared();
barrier();
}
}

void main() {
const ivec3 lpos = ivec3(gl_GlobalInvocationID);
const int width = int(sizes.x);

ivec3 in_pos = lpos_to_pos(lpos, in_axis_map);

// width batch read stride
const int width_stride = int(gl_WorkGroupSize.x) * SHARED_MEMORY_FACTOR;

// local memory starting offset for this thread
const int shared_idx_offset = width_stride * int(gl_WorkGroupSize.y * gl_LocalInvocationID.z + gl_LocalInvocationID.y);

// local memory index for this thread
const int shared_idx = shared_idx_offset + int(gl_LocalInvocationID.x);

// if packed dimension width
if (in_packed_dim != W_DIM) {
VEC4_T mean = VEC4_T(0);
VEC4_T delta = VEC4_T(0);
VEC4_T delta2 = VEC4_T(0);
VEC4_T M2 = VEC4_T(0);

// Use Welford's online algorithm to compute mean and variance in one pass
// https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_online_algorithm
ivec3 in_pos = lpos_to_pos(lpos, in_axis_map);
for (int w = 0; w < width; ++w) {
in_pos[in_axis_map.x] = w;
VEC4_T v = load_texel(t_in, in_pos);
delta = v - mean;
mean += delta / (w + 1);
delta2 = v - mean;
M2 += delta * delta2;
VEC4_T var = VEC4_T(0);

// Loop over the width in stride increments
for (int width_offset = 0; width_offset < width; width_offset += width_stride) {
// Read input in shared memory
for (int si = 0; si < SHARED_MEMORY_FACTOR; si++) {
in_pos[in_axis_map.x] = width_offset + int(gl_LocalInvocationID.x + si * gl_WorkGroupSize.x);

VEC4_T in_val = VEC4_T(0);
if (all(lessThan(in_pos, out_limits))) {
in_val = load_texel(t_in, in_pos);
}
shared_input[offset_pos_index(shared_idx + si * gl_WorkGroupSize.x)] = in_val;
}

reduce_input(width_stride, shared_idx_offset);
mean += shared_input[offset_pos_index(shared_idx_offset)];
}

mean /= width;

// Loop over the width in stride increments
for (int width_offset = 0; width_offset < width; width_offset += width_stride) {
// Read input in shared memory
for (int si = 0; si < SHARED_MEMORY_FACTOR; si++) {
in_pos[in_axis_map.x] = width_offset + int(gl_LocalInvocationID.x + si * gl_WorkGroupSize.x);

VEC4_T in_val = mean;
if (all(lessThan(in_pos, out_limits))) {
in_val = load_texel(t_in, in_pos);
}

const VEC4_T delta = in_val - mean;
shared_input[offset_pos_index(shared_idx + si * gl_WorkGroupSize.x)] = delta * delta;
}

reduce_input(width_stride, shared_idx_offset);
var += shared_input[offset_pos_index(shared_idx_offset)];
}

VEC4_T var = M2 / width;
var /= width;

VEC4_T rstd = pow(var + epsilon, VEC4_T(-0.5));
VEC4_T offset = -rstd * mean;

for (int w = 0; w < width; ++w) {
in_pos[in_axis_map.x] = w;
VEC4_T v = load_texel(t_in, in_pos);
// broadcasting
VEC4_T weight = load_texel(t_weight, ivec3(w, 0, 0)).xxxx;
VEC4_T bias = load_texel(t_bias, ivec3(w, 0, 0)).xxxx;
VEC4_T outtex = (v * rstd + offset) * weight + bias;
write_texel_lpos(t_out, ivec3(w, lpos.y, lpos.z), outtex, out_axis_map);
VEC4_T v = load_texel(t_in, lpos);
VEC4_T weight = load_texel(t_weight, ivec3(lpos.x, 0, 0)).xxxx;
VEC4_T bias = load_texel(t_bias, ivec3(lpos.x, 0, 0)).xxxx;
VEC4_T outtex = (v * rstd + offset) * weight + bias;
if (all(lessThan(lpos, out_limits))) {
write_texel_lpos(t_out, ivec3(lpos.x, lpos.y, lpos.z), outtex, out_axis_map);
}

write_texel(t_mean, lpos, mean);
write_texel(t_rstd, lpos, rstd);
if (gl_GlobalInvocationID.x == 0) {
write_texel(t_mean, lpos, mean);
write_texel(t_rstd, lpos, rstd);
}
} else {
const int packed_width = divup4(width);

const int last_packed_width_index = divup4(width) - 1;
T mean = T(0);
T delta = T(0);
T delta2 = T(0);
T M2 = T(0);
// Use Welford's online algorithm to compute mean and variance in one pass
// https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_online_algorithm
ivec3 in_pos = lpos_to_pos(lpos, in_axis_map);
T width_counter = T(1);

const bool has_unaligned_width = (width & 0x3) != 0;
const int fully_packed_4_comp_count = packed_width - mix(0, 1, has_unaligned_width);

// iterate through texels that are fully packed, i.e. have 4 components
for (int w = 0; w < fully_packed_4_comp_count; ++w) {
in_pos[in_axis_map.x] = w;
VEC4_T v = load_texel(t_in, in_pos);
for (int i=0; i<4; i++) {
delta = v[i] - mean;
mean += delta / width_counter;
delta2 = v[i] - mean;
M2 += delta * delta2;
width_counter++;
T var = T(0);
const int remain = width & 3;

const int in_pos_x_limit = out_limits[in_axis_map.x];

// Loop over the width in stride increments
for (int width_offset = 0; width_offset <= last_packed_width_index; width_offset += width_stride) {
// Read input in shared memory
for (int si = 0; si < SHARED_MEMORY_FACTOR; si++) {
const int in_pos_x = width_offset + int(gl_LocalInvocationID.x + si * gl_WorkGroupSize.x);
in_pos[in_axis_map.x] = in_pos_x;

VEC4_T in_val = VEC4_T(0);
if (in_pos_x < in_pos_x_limit) {
in_val = load_texel(t_in, in_pos);
}

if (in_pos_x == last_packed_width_index && remain != 0) {
const int remain_inv = 4 - remain;
in_val.y = mix(in_val.y, T(0), remain_inv > 2);
in_val.z = mix(in_val.z, T(0), remain_inv > 1);
in_val.w = mix(in_val.w, T(0), remain_inv > 0);
}

shared_input[offset_pos_index(shared_idx + si * gl_WorkGroupSize.x)] = in_val;
}

reduce_input(width_stride, shared_idx_offset);
const VEC4_T val = shared_input[offset_pos_index(shared_idx_offset)];
mean += val.x + val.y + val.z + val.w;
}

// handle the last texel if it's not 4-aligned
if (has_unaligned_width) {
in_pos[in_axis_map.x] = fully_packed_4_comp_count;
const int remaining_width = width & 0x3;

VEC4_T v = load_texel(t_in, in_pos);
for (int i=0; i<remaining_width; i++) {
delta = v[i] - mean;
mean += delta / width_counter;
delta2 = v[i] - mean;
M2 += delta * delta2;
width_counter++;
mean /= width;

// Loop over the width in stride increments
for (int width_offset = 0; width_offset <= last_packed_width_index; width_offset += width_stride) {
// Read input in shared memory
for (int si = 0; si < SHARED_MEMORY_FACTOR; si++) {
const int in_pos_x = width_offset + int(gl_LocalInvocationID.x + si * gl_WorkGroupSize.x);
in_pos[in_axis_map.x] = in_pos_x;

VEC4_T in_val = VEC4_T(mean);
if (in_pos_x < in_pos_x_limit) {
in_val = load_texel(t_in, in_pos);
}

if (in_pos_x == last_packed_width_index && remain != 0) {
const int remain_inv = 4 - remain;
in_val.y = mix(in_val.y, mean.x, remain_inv > 2);
in_val.z = mix(in_val.z, mean.x, remain_inv > 1);
in_val.w = mix(in_val.w, mean.x, remain_inv > 0);
}

const VEC4_T delta = in_val - mean;
const VEC4_T delta2 = delta * delta;
shared_input[offset_pos_index(shared_idx + si * gl_WorkGroupSize.x)] = delta2;
}

reduce_input(width_stride, shared_idx_offset);
const VEC4_T val = shared_input[offset_pos_index(shared_idx_offset)];
var += val.x + val.y + val.z + val.w;
}

T var = M2 / (width_counter - 1);
T rstd = inversesqrt(var + epsilon);
var /= width;

T rstd = pow(var + epsilon, T(-0.5));
T offset = -rstd * mean;

for (int w = 0; w < packed_width; ++w) {
in_pos[in_axis_map.x] = w;
VEC4_T v = load_texel(t_in, in_pos);
VEC4_T weight = load_texel(t_weight, ivec3(w, 0, 0));
VEC4_T bias = load_texel(t_bias, ivec3(w, 0, 0));
VEC4_T outtex = (v * rstd + offset) * weight + bias;
write_texel_lpos(t_out, ivec3(w, lpos.y, lpos.z), outtex, out_axis_map);
VEC4_T v = load_texel(t_in, lpos);
VEC4_T weight = load_texel(t_weight, ivec3(lpos.x, 0, 0));
VEC4_T bias = load_texel(t_bias, ivec3(lpos.x, 0, 0));
VEC4_T outtex = (v * rstd + offset) * weight + bias;
if (all(lessThan(lpos, out_limits))) {
write_texel_lpos(t_out, ivec3(lpos.x, lpos.y, lpos.z), outtex, out_axis_map);
}

write_texel(t_mean, lpos, VEC4_T(mean));
write_texel(t_rstd, lpos, VEC4_T(rstd));
if (gl_GlobalInvocationID.x == 0) {
write_texel(t_mean, lpos, VEC4_T(mean));
write_texel(t_rstd, lpos, VEC4_T(rstd));
}
}
}
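
For readers following the algorithmic change above: the rewritten shader drops the per-thread Welford single-pass estimator in favor of a two-pass mean/variance computation performed cooperatively in workgroup shared memory via reduce_input(). Below is a minimal, sequential CPU-side sketch of the same two-pass math; it is plain C++ for illustration only, and the name layer_norm_two_pass and the flat std::vector inputs are assumptions made for this sketch, not part of the shader.

// Illustrative sketch of the math the rewritten shader performs.
// Pass 1: accumulate the sum to get the mean.
// Pass 2: accumulate squared deltas to get the (biased) variance.
// The shader does both accumulations cooperatively in shared memory
// via reduce_input(); this version is sequential for clarity.
#include <cmath>
#include <cstddef>
#include <vector>

void layer_norm_two_pass(const std::vector<float>& x,
                         const std::vector<float>& weight,
                         const std::vector<float>& bias,
                         float epsilon,
                         std::vector<float>& out) {
  const std::size_t width = x.size();
  out.resize(width);

  float mean = 0.f;
  for (float v : x) mean += v;          // pass 1: plain sum
  mean /= static_cast<float>(width);

  float var = 0.f;
  for (float v : x) {                   // pass 2: sum of squared deltas
    const float delta = v - mean;
    var += delta * delta;
  }
  var /= static_cast<float>(width);     // biased variance, matching `var /= width`

  const float rstd = 1.f / std::sqrt(var + epsilon);
  const float offset = -rstd * mean;    // same folding as `offset = -rstd * mean`

  for (std::size_t i = 0; i < width; ++i) {
    out[i] = (x[i] * rstd + offset) * weight[i] + bias[i];
  }
}

The removed path instead updated mean and M2 in a single pass with Welford's recurrence (delta = v - mean; mean += delta / n; M2 += delta * (v - mean)). The two-pass form reads the row twice, but each pass is a plain sum that maps directly onto the shared-memory tree reduction, which is the point of the change.
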
14 changes: 9 additions & 5 deletions backends/vulkan/runtime/graph/ops/glsl/permute.glsl
@@ -31,6 +31,8 @@ layout(push_constant) uniform PRECISION restrict Block {
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
layout(constant_id = 3) const int packed_dim = C_DIM;

#extension GL_EXT_control_flow_attributes : require

void main() {
ivec3 pos = ivec3(gl_GlobalInvocationID);

@@ -54,11 +56,16 @@ void main() {
in_bchw_pos[out_ndims[2]] = pos.y;
in_bchw_pos[out_ndims[3]] = pos.x;

for (int j = 0; j < 4; ++j) {
const int in_packed_dim_size = in_sizes[3 - out_ndims[in_packed_dim_bchw_index]];

[[unroll]] for (int j = 0, bchw_index = in_bchw_pos[out_ndims[in_packed_dim_bchw_index]]; j < 4; ++j, ++bchw_index) {
// terminate the loop if trying to access input texture out of bounds
if (any(greaterThanEqual(in_bchw_pos.wzyx, in_sizes.xyzw))) {
if (bchw_index >= in_packed_dim_size) {
break;
}
// go to position in the input, that is mapped to the packed dim in the output
in_bchw_pos[out_ndims[in_packed_dim_bchw_index]] = bchw_index;

ivec3 fetch_pos;

fetch_pos.xy = in_bchw_pos.wz;
@@ -74,9 +81,6 @@ void main() {
// fetch input texel
VEC4_T inval = VEC4_T(load_texel(t_in, fetch_pos));
outval[j] = inval[in_packed_dim_lane_index];

// go to next position in the input, that is mapped to the packed dim in the output
in_bchw_pos[out_ndims[in_packed_dim_bchw_index]]++;
}

pos[packed_dim] = int(gl_GlobalInvocationID[packed_dim]);
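
For the permute.glsl change, a rough scalar sketch of the revised gather loop: instead of re-testing the full 4D bounds on every iteration, the loop precomputes the input size along the dimension that feeds the output's packed dim and breaks once the running index reaches it. The names gather_packed_lanes and fetch_lane below are hypothetical helpers invented for this illustration (C++, not the actual shader).

// Illustrative sketch of the revised per-texel gather.
// `fetch_lane(idx)` stands in for computing fetch_pos and reading one
// input lane; it is a hypothetical helper used only in this example.
#include <array>
#include <functional>

std::array<float, 4> gather_packed_lanes(
    int start_index,      // in_bchw_pos[out_ndims[in_packed_dim_bchw_index]]
    int packed_dim_size,  // in_sizes[3 - out_ndims[in_packed_dim_bchw_index]]
    const std::function<float(int)>& fetch_lane) {
  std::array<float, 4> outval{0.f, 0.f, 0.f, 0.f};
  // Walk up to 4 consecutive positions along the input dimension that maps
  // to the output's packed dim; stop early once the precomputed size is
  // reached instead of re-checking all four coordinates every iteration.
  for (int j = 0, idx = start_index; j < 4; ++j, ++idx) {
    if (idx >= packed_dim_size) {
      break;
    }
    outval[j] = fetch_lane(idx);
  }
  return outval;
}

In the shader the same idea is expressed by hoisting in_packed_dim_size out of the loop and folding the position update into the loop header, which also allows the loop to be marked [[unroll]].
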