Skip to content

Commit 49d6314

Browse files
MatzeBfacebook-github-bot
authored andcommitted
Add generic IEEE754 truncation code (#3820)
Summary: Pull Request resolved: #3820 X-link: facebookresearch/FBGEMM#905 This adds a generic implementation of IEEE754 floatingpoint truncation. This is in preparation for conversion to FP8 E5M2 and FP8 E4M3FN formats but for consistency also replaces the existing float2half conversion functions. Reviewed By: r-barnes Differential Revision: D69941314 fbshipit-source-id: 1fb3d34ef3d6bc4613fbc1d522bf7c3eca53d568
1 parent c8ee354 commit 49d6314

File tree

1 file changed

+195
-127
lines changed

1 file changed

+195
-127
lines changed

include/fbgemm/FloatConversion.h

Lines changed: 195 additions & 127 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88

99
#pragma once
1010

11+
#include <cassert>
12+
#include <climits>
1113
#include <cstdint>
1214
#include <cstdlib>
1315
#include <cstring>
@@ -35,151 +37,217 @@ using native_fp16_t = void;
3537

3638
namespace fbgemm {
3739

38-
// The IEEE754 standard species a binary16 as having the following format:
39-
// SEEEEEMMMMMMMMMM
40-
// 0432109876543210
41-
// That is:
42-
// * 1 sign bit
43-
// * 5 exponent bits
44-
// * 10 mantissa/significand bits (an 11th bit is implicit)
45-
constexpr uint32_t f16_num_bits = 16;
46-
constexpr uint32_t f16_num_exponent_bits = 5;
47-
constexpr uint32_t f16_num_mantissa_bits = 10;
48-
constexpr uint32_t f16_num_non_sign_bits =
49-
f16_num_exponent_bits + f16_num_mantissa_bits;
50-
constexpr uint32_t f16_exponent_mask = 0b1'1111; // 5 bits
51-
constexpr uint32_t f16_sign_bit = 1u
52-
<< (f16_num_exponent_bits + f16_num_mantissa_bits);
53-
constexpr uint32_t f16_exponent_bits = f16_exponent_mask
54-
<< f16_num_mantissa_bits;
55-
constexpr uint32_t f16_mantissa_mask = 0b11'1111'1111; // 10 bits
56-
constexpr uint32_t f16_exponent_bias = 15;
57-
constexpr uint32_t f16_nan = 0x7F'FF;
58-
59-
// The IEEE754 standard specifies a binary32 as having:
60-
// SEEEEEEEEMMMMMMMMMMMMMMMMMMMMMMM
61-
// That is:
62-
// * 1 sign bit
63-
// * 8 exponent bits
64-
// * 23 mantissa/significand bits (a 24th bit is implicit)
65-
constexpr uint32_t f32_num_exponent_bits = 8;
66-
constexpr uint32_t f32_num_mantissa_bits = 23;
67-
constexpr uint32_t f32_exponent_mask = 0b1111'1111; // 8 bits
68-
constexpr uint32_t f32_mantissa_mask = 0x7F'FF'FF; // 23 bits
69-
constexpr uint32_t f32_exponent_bias = 127;
70-
constexpr uint32_t f32_all_non_sign_mask = 0x7F'FF'FF'FF; // 31 bits
71-
constexpr uint32_t f32_most_significant_bit = 1u << 22; // Turn on 23rd bit
72-
constexpr uint32_t f32_num_non_sign_bits =
73-
f32_num_exponent_bits + f32_num_mantissa_bits;
74-
75-
// Round to nearest even
76-
inline float16 cpu_float2half_rn(float f) {
77-
static_assert(
78-
sizeof(uint32_t) == sizeof(float),
79-
"Programming error sizeof(uint32_t) != sizeof(float)");
80-
81-
uint32_t* xp = reinterpret_cast<uint32_t*>(&f);
82-
uint32_t x = *xp;
83-
uint32_t u = (x & f32_all_non_sign_mask);
84-
85-
// Get rid of +NaN/-NaN case first.
86-
if (u > 0x7f800000) {
87-
return static_cast<float16>(f16_nan);
88-
}
89-
90-
uint32_t sign = ((x >> f16_num_bits) & f16_sign_bit);
91-
92-
// Get rid of +Inf/-Inf, +0/-0.
93-
if (u > 0x477fefff) {
94-
return static_cast<float16>(sign | f16_exponent_bits);
95-
}
96-
if (u < 0x33000001) {
97-
return static_cast<float16>(sign | 0x0000);
40+
namespace detail {
41+
42+
template <typename T, int ExponentBits, bool HasInfinity = true>
43+
struct FloatFormat {
44+
using value_type = T;
45+
static constexpr int bits = sizeof(T) * CHAR_BIT;
46+
static constexpr int exponent_bits = ExponentBits;
47+
static constexpr int mantissa_bits = bits - exponent_bits - 1;
48+
static constexpr int sign_bit_pos = bits - 1;
49+
static constexpr int exponent_bias = (1 << (exponent_bits - 1)) - 1;
50+
static constexpr int unbiased_exponent_min = -exponent_bias + 1;
51+
static constexpr int unbiased_exponent_max =
52+
HasInfinity ? exponent_bias : (exponent_bias + 1);
53+
static constexpr T sign_bit = T{1} << sign_bit_pos;
54+
static constexpr T exponent_mask = ((T{1} << exponent_bits) - 1)
55+
<< mantissa_bits;
56+
static constexpr T mantissa_mask = (T{1} << mantissa_bits) - 1;
57+
// signaling/quiet encoding is unspecified by IEEE754. This mirrors x86/ARM.
58+
static constexpr T quiet_nan_bit = T{1} << (mantissa_bits - 1);
59+
60+
static constexpr T nan = exponent_mask | mantissa_mask;
61+
static constexpr T overflow_value = HasInfinity ? exponent_mask : nan;
62+
static constexpr bool has_infinity = HasInfinity;
63+
static constexpr bool has_nan_payload = HasInfinity;
64+
};
65+
66+
using IEEE754Single = FloatFormat</*T=*/uint32_t, /*ExponentBits=*/8>;
67+
using IEEE754Half = FloatFormat</*T=*/uint16_t, /*ExponentBits=*/5>;
68+
// See https://arxiv.org/abs/1905.12322v3
69+
using BFloat16 = FloatFormat</*T=*/uint16_t, /*ExponentBits=*/8>;
70+
// See https://doi.org/10.48550/arXiv.2209.05433
71+
using FP8_E5M2 = FloatFormat</*T=*/uint8_t, /*ExponentBits=*/5>;
72+
// See https://doi.org/10.48550/arXiv.2209.05433
73+
using FP8_E4M3FN = FloatFormat<
74+
/*T=*/uint8_t,
75+
/*ExponentBits=*/4,
76+
/*HasInfinity=*/false>;
77+
78+
enum class RoundingMode {
79+
ToNearestTiesToEven,
80+
ToZero,
81+
};
82+
83+
// Generic IEEE754 truncation algorithm.
84+
template <typename Src, typename Tgt, RoundingMode RoundingMode>
85+
[[gnu::always_inline]] inline typename Tgt::value_type ieee754_trunc(
86+
typename Src::value_type value) {
87+
static_assert(Src::exponent_bits >= Tgt::exponent_bits);
88+
static_assert(Src::mantissa_bits > Tgt::mantissa_bits);
89+
using ST = typename Src::value_type;
90+
using TT = typename Tgt::value_type;
91+
92+
ST src_exponent = value & Src::exponent_mask;
93+
ST src_mantissa = value & Src::mantissa_mask;
94+
// Fast-path: If there is no difference in exponent sizes (e.g. fp32 -> bf16)
95+
// and we round toward zero, then we can just drop the least significant bits.
96+
if constexpr (
97+
Src::exponent_bits == Tgt::exponent_bits && Src::has_infinity &&
98+
Tgt::has_infinity && RoundingMode == RoundingMode::ToZero) {
99+
TT result = value >> (Src::bits - Tgt::bits);
100+
// Turn signaling NaN into quiet NaN. This also avoids that the mantissa
101+
// is completely zero after truncation (which would be misinterpreted as
102+
// INF).
103+
if (src_exponent == Src::exponent_mask && src_mantissa != 0) {
104+
result |= Tgt::quiet_nan_bit;
105+
}
106+
return result;
98107
}
99108

100-
uint32_t exponent = ((u >> f32_num_mantissa_bits) & f32_exponent_mask);
101-
uint32_t mantissa = (u & f32_mantissa_mask);
102-
103-
uint32_t shift;
104-
if (exponent > f32_exponent_bias - f16_exponent_bias) {
105-
shift = f32_num_mantissa_bits - f16_num_mantissa_bits;
106-
exponent -= f32_exponent_bias - f16_exponent_bias;
107-
} else {
108-
shift = (f32_exponent_bias - 1) - exponent;
109-
exponent = 0;
110-
mantissa |=
111-
(1u
112-
<< f32_num_mantissa_bits); // Bump the least significant exponent bit
113-
}
114-
const uint32_t lsb = (1u << shift);
115-
const uint32_t lsb_s1 = (lsb >> 1);
116-
const uint32_t lsb_m1 = (lsb - 1);
117-
118-
// Round to nearest even.
119-
const uint32_t remainder = (mantissa & lsb_m1);
120-
mantissa >>= shift;
121-
if (remainder > lsb_s1 || (remainder == lsb_s1 && (mantissa & 0x1))) {
122-
++mantissa;
123-
if (!(mantissa & f16_mantissa_mask)) {
124-
++exponent;
125-
mantissa = 0;
109+
ST tgt_sign =
110+
(value & Src::sign_bit) >> (Src::sign_bit_pos - Tgt::sign_bit_pos);
111+
constexpr bool denormal_becomes_zero =
112+
Tgt::unbiased_exponent_min - Src::unbiased_exponent_min >
113+
Src::mantissa_bits - Tgt::mantissa_bits;
114+
if constexpr (denormal_becomes_zero) {
115+
// Fast-path for zero exponentbits: This means the number was zero or a
116+
// denormal number that will turn into zero in the Tgt format.
117+
if (src_exponent == 0) {
118+
return tgt_sign; // tgt_exponent == 0, tgt_mantissa == 0
126119
}
127120
}
128121

129-
return static_cast<float16>(
130-
sign | (exponent << f16_num_mantissa_bits) | mantissa);
131-
}
132-
133-
// Round to zero
134-
inline float16 cpu_float2half_rz(float f) {
135-
static_assert(
136-
sizeof(uint32_t) == sizeof(float),
137-
"Programming error sizeof(uint32_t) != sizeof(float)");
138-
139-
const uint32_t* xp = reinterpret_cast<uint32_t*>(&f);
140-
const uint32_t x = *xp;
141-
const uint32_t u = (x & f32_all_non_sign_mask);
142-
143-
// Get rid of +NaN/-NaN case first.
144-
if (u > 0x7f800000) {
145-
return static_cast<float16>(f16_nan);
122+
int unbiased_exponent =
123+
(src_exponent >> Src::mantissa_bits) - Src::exponent_bias;
124+
if (unbiased_exponent < Tgt::unbiased_exponent_min) {
125+
int shift = Tgt::unbiased_exponent_min - unbiased_exponent;
126+
if (shift <= Tgt::mantissa_bits + 1) {
127+
// Result is denormal.
128+
ST src_mantissa_one = src_mantissa;
129+
// Add explicit one if the source was not denormal.
130+
if (denormal_becomes_zero || src_exponent != 0) {
131+
src_mantissa_one |= TT{1} << Src::mantissa_bits;
132+
} else {
133+
shift--;
134+
}
135+
TT tgt_mantissa =
136+
src_mantissa_one >> (Src::mantissa_bits - Tgt::mantissa_bits + shift);
137+
138+
if constexpr (RoundingMode == RoundingMode::ToNearestTiesToEven) {
139+
int half_pos = Src::mantissa_bits - Tgt::mantissa_bits + shift - 1;
140+
ST half = 1 << half_pos;
141+
ST remainder = src_mantissa_one & ((half << 1) - 1);
142+
if (remainder > half ||
143+
(remainder == half && (tgt_mantissa & 1) != 0)) {
144+
tgt_mantissa += 1;
145+
}
146+
} else {
147+
assert(RoundingMode == RoundingMode::ToZero);
148+
}
149+
return tgt_sign | tgt_mantissa; // tgt_exponent == 0
150+
} else {
151+
// Result is +/- zero
152+
return tgt_sign; // tgt_exponent == 0, tgt_mantissa == 0
153+
}
146154
}
147155

148-
uint32_t sign = ((x >> f16_num_bits) & f16_sign_bit);
149-
150-
// Get rid of +Inf/-Inf, +0/-0.
151-
if (u > 0x477fefff) {
152-
return static_cast<float16>(sign | f16_exponent_bits);
153-
}
154-
if (u < 0x33000001) {
155-
return static_cast<float16>(sign | 0x0000);
156+
if (unbiased_exponent > Tgt::unbiased_exponent_max) {
157+
if (unbiased_exponent == Src::exponent_bias + 1 && src_mantissa != 0) {
158+
TT tgt_mantissa;
159+
if constexpr (Tgt::has_nan_payload) {
160+
// NaN; not a number
161+
tgt_mantissa =
162+
src_mantissa >> (Src::mantissa_bits - Tgt::mantissa_bits);
163+
tgt_mantissa |= Tgt::quiet_nan_bit;
164+
} else {
165+
tgt_mantissa = Tgt::mantissa_mask;
166+
}
167+
return tgt_sign | Tgt::exponent_mask | tgt_mantissa;
168+
} else {
169+
if (RoundingMode == RoundingMode::ToZero &&
170+
(!Src::has_infinity || src_exponent != Src::exponent_mask)) {
171+
// Return largest finite number.
172+
return tgt_sign | (Tgt::exponent_mask - Tgt::has_infinity) |
173+
Tgt::mantissa_mask;
174+
}
175+
// Infinity or NaN for formats without infinity.
176+
return tgt_sign | Tgt::overflow_value;
177+
}
156178
}
157179

158-
uint32_t exponent = ((u >> f32_num_mantissa_bits) & f32_exponent_mask);
159-
uint32_t mantissa = (u & f32_mantissa_mask);
160-
161-
uint32_t shift;
162-
if (exponent > f32_exponent_bias - f16_exponent_bias) {
163-
shift = f32_num_mantissa_bits - f16_num_mantissa_bits;
164-
exponent -= f32_exponent_bias - f16_exponent_bias;
180+
// Normal number.
181+
TT tgt_mantissa = src_mantissa >> (Src::mantissa_bits - Tgt::mantissa_bits);
182+
TT tgt_exponent = (unbiased_exponent + Tgt::exponent_bias)
183+
<< Tgt::mantissa_bits;
184+
if constexpr (RoundingMode == RoundingMode::ToNearestTiesToEven) {
185+
ST half = 1 << (Src::mantissa_bits - Tgt::mantissa_bits - 1);
186+
ST remainder = src_mantissa & ((half << 1) - 1);
187+
if (remainder > half || (remainder == half && (tgt_mantissa & 1) != 0)) {
188+
if (tgt_mantissa < Tgt::mantissa_mask) {
189+
tgt_mantissa += 1;
190+
} else {
191+
// Mantissa overflowed, increment exponent.
192+
193+
// Normally we can just add to the exponent and will naturally end up
194+
// on infinity on overflow. But we need special treatments for formats
195+
// without infinity.
196+
if (Tgt::has_infinity || tgt_exponent != Tgt::exponent_mask) {
197+
tgt_mantissa = 0;
198+
tgt_exponent += TT{1} << Tgt::mantissa_bits;
199+
} else {
200+
// Return NaN.
201+
tgt_mantissa = Tgt::mantissa_mask;
202+
}
203+
}
204+
}
165205
} else {
166-
shift = (f32_exponent_bias - 1) - exponent;
167-
exponent = 0;
168-
mantissa |=
169-
(1u
170-
<< f32_num_mantissa_bits); // Bump the least significant exponent bit
206+
assert(RoundingMode == RoundingMode::ToZero);
171207
}
208+
return tgt_sign | tgt_exponent | tgt_mantissa;
209+
}
172210

173-
// Round to zero.
174-
mantissa >>= shift;
211+
} // namespace detail
175212

176-
return static_cast<float16>(
177-
sign | (exponent << f16_num_mantissa_bits) | mantissa);
213+
inline float16 cpu_float2half_rn(float f) {
214+
uint32_t f_u32;
215+
std::memcpy(&f_u32, &f, sizeof(f_u32));
216+
return detail::ieee754_trunc<
217+
/*Src=*/detail::IEEE754Single,
218+
/*Tgt=*/detail::IEEE754Half,
219+
detail::RoundingMode::ToNearestTiesToEven>(f_u32);
178220
}
179221

222+
inline float16 cpu_float2half_rz(float f) {
223+
uint32_t f_u32;
224+
std::memcpy(&f_u32, &f, sizeof(f_u32));
225+
return detail::ieee754_trunc<
226+
/*Src=*/detail::IEEE754Single,
227+
/*Tgt=*/detail::IEEE754Half,
228+
detail::RoundingMode::ToZero>(f_u32);
229+
};
230+
180231
// Converts a 16-bit unsigned integer representation of a IEEE754 half-precision
181232
// float into an IEEE754 32-bit single-precision float
182233
inline float cpu_half2float_ref(const float16 h) {
234+
constexpr uint32_t f16_num_exponent_bits = 5;
235+
constexpr uint32_t f16_num_mantissa_bits = 10;
236+
constexpr uint32_t f16_num_non_sign_bits =
237+
f16_num_exponent_bits + f16_num_mantissa_bits;
238+
constexpr uint32_t f16_exponent_bias = 15;
239+
constexpr uint32_t f16_exponent_mask = 0b1'1111;
240+
constexpr uint32_t f16_mantissa_mask = 0b11'1111'1111;
241+
242+
constexpr uint32_t f32_num_exponent_bits = 8;
243+
constexpr uint32_t f32_num_mantissa_bits = 23;
244+
constexpr uint32_t f32_num_non_sign_bits =
245+
f32_num_exponent_bits + f32_num_mantissa_bits;
246+
constexpr uint32_t f32_exponent_bias = 127;
247+
constexpr uint32_t f32_exponent_mask = 0b1111'1111;
248+
constexpr uint32_t f32_mantissa_mask = 0x7F'FF'FF;
249+
constexpr uint32_t f32_most_significant_bit = 1u << 22;
250+
183251
// Get sign and exponent alone by themselves
184252
uint32_t sign_bit = (h >> f16_num_non_sign_bits) & 1;
185253
uint32_t exponent = (h >> f16_num_mantissa_bits) & f16_exponent_mask;

0 commit comments

Comments
 (0)