Skip to content
Merged
28 changes: 24 additions & 4 deletions paddle/phi/common/float16.h
Original file line number Diff line number Diff line change
Expand Up @@ -123,17 +123,37 @@ struct PADDLE_ALIGN(2) float16 {
// http://stackoverflow.com/questions/1659440/32-bit-to-16-bit-floating-point-conversion
Bits v, s;
v.f = val;
// Extract sign bit and clear from value
uint32_t sign = v.si & sigN;
v.si ^= sign;
sign >>= shiftSign; // logical shift
sign >>= shiftSign;

// Handle subnormals: normalize using multiplication
const uint32_t subnormal_mask = -(minN > v.si);
s.si = mulN;
s.si = s.f * v.f; // correct subnormals
v.si ^= (s.si ^ v.si) & -(minN > v.si);
s.si = s.f * v.f; // Extract the fraction of the subnormal number through
// multiplication and conversion from float to int
v.si ^= (s.si ^ v.si) & subnormal_mask;

// Handle special values: infinity and NaN
v.si ^= (infN ^ v.si) & -((infN > v.si) & (v.si > maxN));
v.si ^= (nanN ^ v.si) & -((nanN > v.si) & (v.si > infN));

// Rounding: round to nearest, ties to even
// https://en.wikipedia.org/wiki/Rounding#Rounding_half_to_even
const uint32_t lsb =
(v.ui >> shift) & 0x1; // Least significant retained bit
v.ui += (0xFFF + lsb) & -(v.ui < infN); // Round with overflow protection

v.ui >>= shift; // logical shift

// Exponent adjustment for overflow (max values)
v.si ^= ((v.si - maxD) ^ v.si) & -(v.si > maxC);
v.si ^= ((v.si - minD) ^ v.si) & -(v.si > subC);
// Exponent adjustment for normal numbers
const uint32_t normal_mask = ~subnormal_mask;
v.si ^= ((v.si - minD) ^ v.si) & normal_mask;

// Combine sign and value bits
x = v.ui | sign;

#endif
Expand Down
Loading