|
15 | 15 |
|
16 | 16 | #define VEC4_T ${texel_type(DTYPE)}
|
17 | 17 |
|
| 18 | +#define T ${texel_component_type(DTYPE)} |
| 19 | + |
18 | 20 | layout(std430) buffer;
|
19 | 21 |
|
20 | 22 | ${layout_declare_tensor(B, "w", "t_out", DTYPE, STORAGE)}
|
@@ -48,37 +50,97 @@ void main() {
|
48 | 50 |
|
49 | 51 | const int width = int(sizes.x);
|
50 | 52 |
|
51 |
| - VEC4_T mean = VEC4_T(0); |
52 |
| - VEC4_T delta = VEC4_T(0); |
53 |
| - VEC4_T delta2 = VEC4_T(0); |
54 |
| - VEC4_T M2 = VEC4_T(0); |
55 |
| - |
56 |
| - // Use Welford's online algorithm to compute mean and variance in one pass |
57 |
| - // https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_online_algorithm |
58 |
| - ivec3 in_pos = lpos_to_pos(lpos, in_axis_map); |
59 |
| - for (int w = 0; w < width; ++w) { |
60 |
| - in_pos[in_axis_map.x] = w; |
61 |
| - VEC4_T v = load_texel(t_in, in_pos); |
62 |
| - delta = v - mean; |
63 |
| - mean += delta / (w + 1); |
64 |
| - delta2 = v - mean; |
65 |
| - M2 += delta * delta2; |
66 |
| - } |
67 |
| - |
68 |
| - VEC4_T var = M2 / width; |
69 |
| - VEC4_T rstd = pow(var + epsilon, VEC4_T(-0.5)); |
70 |
| - VEC4_T offset = -rstd * mean; |
71 |
| - |
72 |
| - for (int w = 0; w < width; ++w) { |
73 |
| - in_pos[in_axis_map.x] = w; |
74 |
| - VEC4_T v = load_texel(t_in, in_pos); |
75 |
| - // broadcasting |
76 |
| - VEC4_T weight = load_texel(t_weight, ivec3(w, 0, 0)).xxxx; |
77 |
| - VEC4_T bias = load_texel(t_bias, ivec3(w, 0, 0)).xxxx; |
78 |
| - VEC4_T outtex = (v * rstd + offset) * weight + bias; |
79 |
| - write_texel_lpos(t_out, ivec3(w, lpos.y, lpos.z), outtex, out_axis_map); |
| 53 | + if (in_packed_dim != W_DIM) { |
| 54 | + VEC4_T mean = VEC4_T(0); |
| 55 | + VEC4_T delta = VEC4_T(0); |
| 56 | + VEC4_T delta2 = VEC4_T(0); |
| 57 | + VEC4_T M2 = VEC4_T(0); |
| 58 | + |
| 59 | + // Use Welford's online algorithm to compute mean and variance in one pass |
| 60 | + // https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_online_algorithm |
| 61 | + ivec3 in_pos = lpos_to_pos(lpos, in_axis_map); |
| 62 | + for (int w = 0; w < width; ++w) { |
| 63 | + in_pos[in_axis_map.x] = w; |
| 64 | + VEC4_T v = load_texel(t_in, in_pos); |
| 65 | + delta = v - mean; |
| 66 | + mean += delta / (w + 1); |
| 67 | + delta2 = v - mean; |
| 68 | + M2 += delta * delta2; |
| 69 | + } |
| 70 | + |
| 71 | + VEC4_T var = M2 / width; |
| 72 | + VEC4_T rstd = pow(var + epsilon, VEC4_T(-0.5)); |
| 73 | + VEC4_T offset = -rstd * mean; |
| 74 | + |
| 75 | + for (int w = 0; w < width; ++w) { |
| 76 | + in_pos[in_axis_map.x] = w; |
| 77 | + VEC4_T v = load_texel(t_in, in_pos); |
| 78 | + // broadcasting |
| 79 | + VEC4_T weight = load_texel(t_weight, ivec3(w, 0, 0)).xxxx; |
| 80 | + VEC4_T bias = load_texel(t_bias, ivec3(w, 0, 0)).xxxx; |
| 81 | + VEC4_T outtex = (v * rstd + offset) * weight + bias; |
| 82 | + write_texel_lpos(t_out, ivec3(w, lpos.y, lpos.z), outtex, out_axis_map); |
| 83 | + } |
| 84 | + |
| 85 | + write_texel(t_mean, lpos, mean); |
| 86 | + write_texel(t_rstd, lpos, rstd); |
| 87 | + } else { |
| 88 | + const int packed_width = divup4(width); |
| 89 | + |
| 90 | + T mean = T(0); |
| 91 | + T delta = T(0); |
| 92 | + T delta2 = T(0); |
| 93 | + T M2 = T(0); |
| 94 | + // Use Welford's online algorithm to compute mean and variance in one pass |
| 95 | + // https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_online_algorithm |
| 96 | + ivec3 in_pos = lpos_to_pos(lpos, in_axis_map); |
| 97 | + T width_counter = T(1); |
| 98 | + |
| 99 | + const bool has_unaligned_width = (width & 0x3) != 0; |
| 100 | + const int fully_packed_4_comp_count = packed_width - mix(0, 1, has_unaligned_width); |
| 101 | + |
| 102 | + // iterate through texels that are fully packed ie. has 4 components |
| 103 | + for (int w = 0; w < fully_packed_4_comp_count; ++w) { |
| 104 | + in_pos[in_axis_map.x] = w; |
| 105 | + VEC4_T v = load_texel(t_in, in_pos); |
| 106 | + for (int i=0; i<4; i++) { |
| 107 | + delta = v[i] - mean; |
| 108 | + mean += delta / width_counter; |
| 109 | + delta2 = v[i] - mean; |
| 110 | + M2 += delta * delta2; |
| 111 | + width_counter++; |
| 112 | + } |
| 113 | + } |
| 114 | + |
| 115 | + // handle last texel if its not 4 aligned |
| 116 | + if (has_unaligned_width) { |
| 117 | + in_pos[in_axis_map.x] = fully_packed_4_comp_count; |
| 118 | + const int remaining_width = width & 0x3; |
| 119 | + |
| 120 | + VEC4_T v = load_texel(t_in, in_pos); |
| 121 | + for (int i=0; i<remaining_width; i++) { |
| 122 | + delta = v[i] - mean; |
| 123 | + mean += delta / width_counter; |
| 124 | + delta2 = v[i] - mean; |
| 125 | + M2 += delta * delta2; |
| 126 | + width_counter++; |
| 127 | + } |
| 128 | + } |
| 129 | + |
| 130 | + T var = M2 / (width_counter - 1); |
| 131 | + T rstd = inversesqrt(var + epsilon); |
| 132 | + T offset = -rstd * mean; |
| 133 | + |
| 134 | + for (int w = 0; w < packed_width; ++w) { |
| 135 | + in_pos[in_axis_map.x] = w; |
| 136 | + VEC4_T v = load_texel(t_in, in_pos); |
| 137 | + VEC4_T weight = load_texel(t_weight, ivec3(w, 0, 0)); |
| 138 | + VEC4_T bias = load_texel(t_bias, ivec3(w, 0, 0)); |
| 139 | + VEC4_T outtex = (v * rstd + offset) * weight + bias; |
| 140 | + write_texel_lpos(t_out, ivec3(w, lpos.y, lpos.z), outtex, out_axis_map); |
| 141 | + } |
| 142 | + |
| 143 | + write_texel(t_mean, lpos, VEC4_T(mean)); |
| 144 | + write_texel(t_rstd, lpos, VEC4_T(rstd)); |
80 | 145 | }
|
81 |
| - |
82 |
| - write_texel(t_mean, lpos, mean); |
83 |
| - write_texel(t_rstd, lpos, rstd); |
84 | 146 | }
|
0 commit comments