|
8 | 8 |
|
9 | 9 | #pragma once
|
10 | 10 |
|
| 11 | +#include <cassert> |
| 12 | +#include <climits> |
11 | 13 | #include <cstdint>
|
12 | 14 | #include <cstdlib>
|
13 | 15 | #include <cstring>
|
@@ -35,151 +37,217 @@ using native_fp16_t = void;
|
35 | 37 |
|
36 | 38 | namespace fbgemm {
|
37 | 39 |
|
38 |
| -// The IEEE754 standard species a binary16 as having the following format: |
39 |
| -// SEEEEEMMMMMMMMMM |
40 |
| -// 0432109876543210 |
41 |
| -// That is: |
42 |
| -// * 1 sign bit |
43 |
| -// * 5 exponent bits |
44 |
| -// * 10 mantissa/significand bits (an 11th bit is implicit) |
45 |
| -constexpr uint32_t f16_num_bits = 16; |
46 |
| -constexpr uint32_t f16_num_exponent_bits = 5; |
47 |
| -constexpr uint32_t f16_num_mantissa_bits = 10; |
48 |
| -constexpr uint32_t f16_num_non_sign_bits = |
49 |
| - f16_num_exponent_bits + f16_num_mantissa_bits; |
50 |
| -constexpr uint32_t f16_exponent_mask = 0b1'1111; // 5 bits |
51 |
| -constexpr uint32_t f16_sign_bit = 1u |
52 |
| - << (f16_num_exponent_bits + f16_num_mantissa_bits); |
53 |
| -constexpr uint32_t f16_exponent_bits = f16_exponent_mask |
54 |
| - << f16_num_mantissa_bits; |
55 |
| -constexpr uint32_t f16_mantissa_mask = 0b11'1111'1111; // 10 bits |
56 |
| -constexpr uint32_t f16_exponent_bias = 15; |
57 |
| -constexpr uint32_t f16_nan = 0x7F'FF; |
58 |
| - |
59 |
| -// The IEEE754 standard specifies a binary32 as having: |
60 |
| -// SEEEEEEEEMMMMMMMMMMMMMMMMMMMMMMM |
61 |
| -// That is: |
62 |
| -// * 1 sign bit |
63 |
| -// * 8 exponent bits |
64 |
| -// * 23 mantissa/significand bits (a 24th bit is implicit) |
65 |
| -constexpr uint32_t f32_num_exponent_bits = 8; |
66 |
| -constexpr uint32_t f32_num_mantissa_bits = 23; |
67 |
| -constexpr uint32_t f32_exponent_mask = 0b1111'1111; // 8 bits |
68 |
| -constexpr uint32_t f32_mantissa_mask = 0x7F'FF'FF; // 23 bits |
69 |
| -constexpr uint32_t f32_exponent_bias = 127; |
70 |
| -constexpr uint32_t f32_all_non_sign_mask = 0x7F'FF'FF'FF; // 31 bits |
71 |
| -constexpr uint32_t f32_most_significant_bit = 1u << 22; // Turn on 23rd bit |
72 |
| -constexpr uint32_t f32_num_non_sign_bits = |
73 |
| - f32_num_exponent_bits + f32_num_mantissa_bits; |
74 |
| - |
75 |
| -// Round to nearest even |
76 |
| -inline float16 cpu_float2half_rn(float f) { |
77 |
| - static_assert( |
78 |
| - sizeof(uint32_t) == sizeof(float), |
79 |
| - "Programming error sizeof(uint32_t) != sizeof(float)"); |
80 |
| - |
81 |
| - uint32_t* xp = reinterpret_cast<uint32_t*>(&f); |
82 |
| - uint32_t x = *xp; |
83 |
| - uint32_t u = (x & f32_all_non_sign_mask); |
84 |
| - |
85 |
| - // Get rid of +NaN/-NaN case first. |
86 |
| - if (u > 0x7f800000) { |
87 |
| - return static_cast<float16>(f16_nan); |
88 |
| - } |
89 |
| - |
90 |
| - uint32_t sign = ((x >> f16_num_bits) & f16_sign_bit); |
91 |
| - |
92 |
| - // Get rid of +Inf/-Inf, +0/-0. |
93 |
| - if (u > 0x477fefff) { |
94 |
| - return static_cast<float16>(sign | f16_exponent_bits); |
95 |
| - } |
96 |
| - if (u < 0x33000001) { |
97 |
| - return static_cast<float16>(sign | 0x0000); |
| 40 | +namespace detail { |
| 41 | + |
| 42 | +template <typename T, int ExponentBits, bool HasInfinity = true> |
| 43 | +struct FloatFormat { |
| 44 | + using value_type = T; |
| 45 | + static constexpr int bits = sizeof(T) * CHAR_BIT; |
| 46 | + static constexpr int exponent_bits = ExponentBits; |
| 47 | + static constexpr int mantissa_bits = bits - exponent_bits - 1; |
| 48 | + static constexpr int sign_bit_pos = bits - 1; |
| 49 | + static constexpr int exponent_bias = (1 << (exponent_bits - 1)) - 1; |
| 50 | + static constexpr int unbiased_exponent_min = -exponent_bias + 1; |
| 51 | + static constexpr int unbiased_exponent_max = |
| 52 | + HasInfinity ? exponent_bias : (exponent_bias + 1); |
| 53 | + static constexpr T sign_bit = T{1} << sign_bit_pos; |
| 54 | + static constexpr T exponent_mask = ((T{1} << exponent_bits) - 1) |
| 55 | + << mantissa_bits; |
| 56 | + static constexpr T mantissa_mask = (T{1} << mantissa_bits) - 1; |
| 57 | + // signaling/quiet encoding is unspecified by IEEE754. This mirrors x86/ARM. |
| 58 | + static constexpr T quiet_nan_bit = T{1} << (mantissa_bits - 1); |
| 59 | + |
| 60 | + static constexpr T nan = exponent_mask | mantissa_mask; |
| 61 | + static constexpr T overflow_value = HasInfinity ? exponent_mask : nan; |
| 62 | + static constexpr bool has_infinity = HasInfinity; |
| 63 | + static constexpr bool has_nan_payload = HasInfinity; |
| 64 | +}; |
| 65 | + |
| 66 | +using IEEE754Single = FloatFormat</*T=*/uint32_t, /*ExponentBits=*/8>; |
| 67 | +using IEEE754Half = FloatFormat</*T=*/uint16_t, /*ExponentBits=*/5>; |
| 68 | +// See https://arxiv.org/abs/1905.12322v3 |
| 69 | +using BFloat16 = FloatFormat</*T=*/uint16_t, /*ExponentBits=*/8>; |
| 70 | +// See https://doi.org/10.48550/arXiv.2209.05433 |
| 71 | +using FP8_E5M2 = FloatFormat</*T=*/uint8_t, /*ExponentBits=*/5>; |
| 72 | +// See https://doi.org/10.48550/arXiv.2209.05433 |
| 73 | +using FP8_E4M3FN = FloatFormat< |
| 74 | + /*T=*/uint8_t, |
| 75 | + /*ExponentBits=*/4, |
| 76 | + /*HasInfinity=*/false>; |
| 77 | + |
| 78 | +enum class RoundingMode { |
| 79 | + ToNearestTiesToEven, |
| 80 | + ToZero, |
| 81 | +}; |
| 82 | + |
| 83 | +// Generic IEEE754 truncation algorithm. |
| 84 | +template <typename Src, typename Tgt, RoundingMode RoundingMode> |
| 85 | +[[gnu::always_inline]] inline typename Tgt::value_type ieee754_trunc( |
| 86 | + typename Src::value_type value) { |
| 87 | + static_assert(Src::exponent_bits >= Tgt::exponent_bits); |
| 88 | + static_assert(Src::mantissa_bits > Tgt::mantissa_bits); |
| 89 | + using ST = typename Src::value_type; |
| 90 | + using TT = typename Tgt::value_type; |
| 91 | + |
| 92 | + ST src_exponent = value & Src::exponent_mask; |
| 93 | + ST src_mantissa = value & Src::mantissa_mask; |
| 94 | + // Fast-path: If there is no difference in exponent sizes (e.g. fp32 -> bf16) |
| 95 | + // and we round toward zero, then we can just drop the least significant bits. |
| 96 | + if constexpr ( |
| 97 | + Src::exponent_bits == Tgt::exponent_bits && Src::has_infinity && |
| 98 | + Tgt::has_infinity && RoundingMode == RoundingMode::ToZero) { |
| 99 | + TT result = value >> (Src::bits - Tgt::bits); |
| 100 | + // Turn signaling NaN into quiet NaN. This also avoids that the mantissa |
| 101 | + // is completely zero after truncation (which would be misinterpreted as |
| 102 | + // INF). |
| 103 | + if (src_exponent == Src::exponent_mask && src_mantissa != 0) { |
| 104 | + result |= Tgt::quiet_nan_bit; |
| 105 | + } |
| 106 | + return result; |
98 | 107 | }
|
99 | 108 |
|
100 |
| - uint32_t exponent = ((u >> f32_num_mantissa_bits) & f32_exponent_mask); |
101 |
| - uint32_t mantissa = (u & f32_mantissa_mask); |
102 |
| - |
103 |
| - uint32_t shift; |
104 |
| - if (exponent > f32_exponent_bias - f16_exponent_bias) { |
105 |
| - shift = f32_num_mantissa_bits - f16_num_mantissa_bits; |
106 |
| - exponent -= f32_exponent_bias - f16_exponent_bias; |
107 |
| - } else { |
108 |
| - shift = (f32_exponent_bias - 1) - exponent; |
109 |
| - exponent = 0; |
110 |
| - mantissa |= |
111 |
| - (1u |
112 |
| - << f32_num_mantissa_bits); // Bump the least significant exponent bit |
113 |
| - } |
114 |
| - const uint32_t lsb = (1u << shift); |
115 |
| - const uint32_t lsb_s1 = (lsb >> 1); |
116 |
| - const uint32_t lsb_m1 = (lsb - 1); |
117 |
| - |
118 |
| - // Round to nearest even. |
119 |
| - const uint32_t remainder = (mantissa & lsb_m1); |
120 |
| - mantissa >>= shift; |
121 |
| - if (remainder > lsb_s1 || (remainder == lsb_s1 && (mantissa & 0x1))) { |
122 |
| - ++mantissa; |
123 |
| - if (!(mantissa & f16_mantissa_mask)) { |
124 |
| - ++exponent; |
125 |
| - mantissa = 0; |
| 109 | + ST tgt_sign = |
| 110 | + (value & Src::sign_bit) >> (Src::sign_bit_pos - Tgt::sign_bit_pos); |
| 111 | + constexpr bool denormal_becomes_zero = |
| 112 | + Tgt::unbiased_exponent_min - Src::unbiased_exponent_min > |
| 113 | + Src::mantissa_bits - Tgt::mantissa_bits; |
| 114 | + if constexpr (denormal_becomes_zero) { |
| 115 | + // Fast-path for zero exponentbits: This means the number was zero or a |
| 116 | + // denormal number that will turn into zero in the Tgt format. |
| 117 | + if (src_exponent == 0) { |
| 118 | + return tgt_sign; // tgt_exponent == 0, tgt_mantissa == 0 |
126 | 119 | }
|
127 | 120 | }
|
128 | 121 |
|
129 |
| - return static_cast<float16>( |
130 |
| - sign | (exponent << f16_num_mantissa_bits) | mantissa); |
131 |
| -} |
132 |
| - |
133 |
| -// Round to zero |
134 |
| -inline float16 cpu_float2half_rz(float f) { |
135 |
| - static_assert( |
136 |
| - sizeof(uint32_t) == sizeof(float), |
137 |
| - "Programming error sizeof(uint32_t) != sizeof(float)"); |
138 |
| - |
139 |
| - const uint32_t* xp = reinterpret_cast<uint32_t*>(&f); |
140 |
| - const uint32_t x = *xp; |
141 |
| - const uint32_t u = (x & f32_all_non_sign_mask); |
142 |
| - |
143 |
| - // Get rid of +NaN/-NaN case first. |
144 |
| - if (u > 0x7f800000) { |
145 |
| - return static_cast<float16>(f16_nan); |
| 122 | + int unbiased_exponent = |
| 123 | + (src_exponent >> Src::mantissa_bits) - Src::exponent_bias; |
| 124 | + if (unbiased_exponent < Tgt::unbiased_exponent_min) { |
| 125 | + int shift = Tgt::unbiased_exponent_min - unbiased_exponent; |
| 126 | + if (shift <= Tgt::mantissa_bits + 1) { |
| 127 | + // Result is denormal. |
| 128 | + ST src_mantissa_one = src_mantissa; |
| 129 | + // Add explicit one if the source was not denormal. |
| 130 | + if (denormal_becomes_zero || src_exponent != 0) { |
| 131 | + src_mantissa_one |= TT{1} << Src::mantissa_bits; |
| 132 | + } else { |
| 133 | + shift--; |
| 134 | + } |
| 135 | + TT tgt_mantissa = |
| 136 | + src_mantissa_one >> (Src::mantissa_bits - Tgt::mantissa_bits + shift); |
| 137 | + |
| 138 | + if constexpr (RoundingMode == RoundingMode::ToNearestTiesToEven) { |
| 139 | + int half_pos = Src::mantissa_bits - Tgt::mantissa_bits + shift - 1; |
| 140 | + ST half = 1 << half_pos; |
| 141 | + ST remainder = src_mantissa_one & ((half << 1) - 1); |
| 142 | + if (remainder > half || |
| 143 | + (remainder == half && (tgt_mantissa & 1) != 0)) { |
| 144 | + tgt_mantissa += 1; |
| 145 | + } |
| 146 | + } else { |
| 147 | + assert(RoundingMode == RoundingMode::ToZero); |
| 148 | + } |
| 149 | + return tgt_sign | tgt_mantissa; // tgt_exponent == 0 |
| 150 | + } else { |
| 151 | + // Result is +/- zero |
| 152 | + return tgt_sign; // tgt_exponent == 0, tgt_mantissa == 0 |
| 153 | + } |
146 | 154 | }
|
147 | 155 |
|
148 |
| - uint32_t sign = ((x >> f16_num_bits) & f16_sign_bit); |
149 |
| - |
150 |
| - // Get rid of +Inf/-Inf, +0/-0. |
151 |
| - if (u > 0x477fefff) { |
152 |
| - return static_cast<float16>(sign | f16_exponent_bits); |
153 |
| - } |
154 |
| - if (u < 0x33000001) { |
155 |
| - return static_cast<float16>(sign | 0x0000); |
| 156 | + if (unbiased_exponent > Tgt::unbiased_exponent_max) { |
| 157 | + if (unbiased_exponent == Src::exponent_bias + 1 && src_mantissa != 0) { |
| 158 | + TT tgt_mantissa; |
| 159 | + if constexpr (Tgt::has_nan_payload) { |
| 160 | + // NaN; not a number |
| 161 | + tgt_mantissa = |
| 162 | + src_mantissa >> (Src::mantissa_bits - Tgt::mantissa_bits); |
| 163 | + tgt_mantissa |= Tgt::quiet_nan_bit; |
| 164 | + } else { |
| 165 | + tgt_mantissa = Tgt::mantissa_mask; |
| 166 | + } |
| 167 | + return tgt_sign | Tgt::exponent_mask | tgt_mantissa; |
| 168 | + } else { |
| 169 | + if (RoundingMode == RoundingMode::ToZero && |
| 170 | + (!Src::has_infinity || src_exponent != Src::exponent_mask)) { |
| 171 | + // Return largest finite number. |
| 172 | + return tgt_sign | (Tgt::exponent_mask - Tgt::has_infinity) | |
| 173 | + Tgt::mantissa_mask; |
| 174 | + } |
| 175 | + // Infinity or NaN for formats without infinity. |
| 176 | + return tgt_sign | Tgt::overflow_value; |
| 177 | + } |
156 | 178 | }
|
157 | 179 |
|
158 |
| - uint32_t exponent = ((u >> f32_num_mantissa_bits) & f32_exponent_mask); |
159 |
| - uint32_t mantissa = (u & f32_mantissa_mask); |
160 |
| - |
161 |
| - uint32_t shift; |
162 |
| - if (exponent > f32_exponent_bias - f16_exponent_bias) { |
163 |
| - shift = f32_num_mantissa_bits - f16_num_mantissa_bits; |
164 |
| - exponent -= f32_exponent_bias - f16_exponent_bias; |
| 180 | + // Normal number. |
| 181 | + TT tgt_mantissa = src_mantissa >> (Src::mantissa_bits - Tgt::mantissa_bits); |
| 182 | + TT tgt_exponent = (unbiased_exponent + Tgt::exponent_bias) |
| 183 | + << Tgt::mantissa_bits; |
| 184 | + if constexpr (RoundingMode == RoundingMode::ToNearestTiesToEven) { |
| 185 | + ST half = 1 << (Src::mantissa_bits - Tgt::mantissa_bits - 1); |
| 186 | + ST remainder = src_mantissa & ((half << 1) - 1); |
| 187 | + if (remainder > half || (remainder == half && (tgt_mantissa & 1) != 0)) { |
| 188 | + if (tgt_mantissa < Tgt::mantissa_mask) { |
| 189 | + tgt_mantissa += 1; |
| 190 | + } else { |
| 191 | + // Mantissa overflowed, increment exponent. |
| 192 | + |
| 193 | + // Normally we can just add to the exponent and will naturally end up |
| 194 | + // on infinity on overflow. But we need special treatments for formats |
| 195 | + // without infinity. |
| 196 | + if (Tgt::has_infinity || tgt_exponent != Tgt::exponent_mask) { |
| 197 | + tgt_mantissa = 0; |
| 198 | + tgt_exponent += TT{1} << Tgt::mantissa_bits; |
| 199 | + } else { |
| 200 | + // Return NaN. |
| 201 | + tgt_mantissa = Tgt::mantissa_mask; |
| 202 | + } |
| 203 | + } |
| 204 | + } |
165 | 205 | } else {
|
166 |
| - shift = (f32_exponent_bias - 1) - exponent; |
167 |
| - exponent = 0; |
168 |
| - mantissa |= |
169 |
| - (1u |
170 |
| - << f32_num_mantissa_bits); // Bump the least significant exponent bit |
| 206 | + assert(RoundingMode == RoundingMode::ToZero); |
171 | 207 | }
|
| 208 | + return tgt_sign | tgt_exponent | tgt_mantissa; |
| 209 | +} |
172 | 210 |
|
173 |
| - // Round to zero. |
174 |
| - mantissa >>= shift; |
| 211 | +} // namespace detail |
175 | 212 |
|
176 |
| - return static_cast<float16>( |
177 |
| - sign | (exponent << f16_num_mantissa_bits) | mantissa); |
| 213 | +inline float16 cpu_float2half_rn(float f) { |
| 214 | + uint32_t f_u32; |
| 215 | + std::memcpy(&f_u32, &f, sizeof(f_u32)); |
| 216 | + return detail::ieee754_trunc< |
| 217 | + /*Src=*/detail::IEEE754Single, |
| 218 | + /*Tgt=*/detail::IEEE754Half, |
| 219 | + detail::RoundingMode::ToNearestTiesToEven>(f_u32); |
178 | 220 | }
|
179 | 221 |
|
| 222 | +inline float16 cpu_float2half_rz(float f) { |
| 223 | + uint32_t f_u32; |
| 224 | + std::memcpy(&f_u32, &f, sizeof(f_u32)); |
| 225 | + return detail::ieee754_trunc< |
| 226 | + /*Src=*/detail::IEEE754Single, |
| 227 | + /*Tgt=*/detail::IEEE754Half, |
| 228 | + detail::RoundingMode::ToZero>(f_u32); |
| 229 | +}; |
| 230 | + |
180 | 231 | // Converts a 16-bit unsigned integer representation of a IEEE754 half-precision
|
181 | 232 | // float into an IEEE754 32-bit single-precision float
|
182 | 233 | inline float cpu_half2float_ref(const float16 h) {
|
| 234 | + constexpr uint32_t f16_num_exponent_bits = 5; |
| 235 | + constexpr uint32_t f16_num_mantissa_bits = 10; |
| 236 | + constexpr uint32_t f16_num_non_sign_bits = |
| 237 | + f16_num_exponent_bits + f16_num_mantissa_bits; |
| 238 | + constexpr uint32_t f16_exponent_bias = 15; |
| 239 | + constexpr uint32_t f16_exponent_mask = 0b1'1111; |
| 240 | + constexpr uint32_t f16_mantissa_mask = 0b11'1111'1111; |
| 241 | + |
| 242 | + constexpr uint32_t f32_num_exponent_bits = 8; |
| 243 | + constexpr uint32_t f32_num_mantissa_bits = 23; |
| 244 | + constexpr uint32_t f32_num_non_sign_bits = |
| 245 | + f32_num_exponent_bits + f32_num_mantissa_bits; |
| 246 | + constexpr uint32_t f32_exponent_bias = 127; |
| 247 | + constexpr uint32_t f32_exponent_mask = 0b1111'1111; |
| 248 | + constexpr uint32_t f32_mantissa_mask = 0x7F'FF'FF; |
| 249 | + constexpr uint32_t f32_most_significant_bit = 1u << 22; |
| 250 | + |
183 | 251 | // Get sign and exponent alone by themselves
|
184 | 252 | uint32_t sign_bit = (h >> f16_num_non_sign_bits) & 1;
|
185 | 253 | uint32_t exponent = (h >> f16_num_mantissa_bits) & f16_exponent_mask;
|
|
0 commit comments