@@ -828,7 +828,8 @@ REGISTER_DTYPE(i8 , signed char)
828828REGISTER_DTYPE(u8 , unsigned char )
829829
830830// /////////////////////////////////////////////////////////////////////////////////////////////////////////
831- // numeric_limits -- min/max/lowest/quiet_nan/infinity for all registered dtypes
831+ // numeric_limits -- returns min/max/lowest/quiet_nan/infinity in the *original* dtype
832+ // (see finfo below for float-valued properties like eps/max/min/tiny)
832833template<typename T> struct numeric_limits;
833834
834835template <> struct numeric_limits <fp32_t > {
@@ -858,10 +859,10 @@ template<> struct numeric_limits<bf16_t> {
858859// fp8 E4M3: gfx950=OCP(ieee-like, NaN=0x7F), gfx942=fnuz(NaN=0x80). No infinity in either format.
859860// NOTE: __builtin_bit_cast with _BitInt(8) is not yet constexpr in clang, so use static_cast via signed char.
860861template <> struct numeric_limits <fp8_t > {
861- #if defined(__gfx950__)
862- static constexpr unsigned char bin_min = 0x08 , bin_max = 0x7E , bin_lowest = 0xFE , bin_qnan = 0x7F , bin_inf = 0x00 ;
863- #else
862+ #if defined(__gfx942__)
864863 static constexpr unsigned char bin_min = 0x08 , bin_max = 0x7F , bin_lowest = 0xFF , bin_qnan = 0x80 , bin_inf = 0x00 ;
864+ #else
865+ static constexpr unsigned char bin_min = 0x08 , bin_max = 0x7E , bin_lowest = 0xFE , bin_qnan = 0x7F , bin_inf = 0x00 ;
865866#endif
866867 OPUS_H_D static constexpr fp8_t min () { return static_cast <fp8_t >(static_cast <signed char >(bin_min)); }
867868 OPUS_H_D static constexpr fp8_t max () { return static_cast <fp8_t >(static_cast <signed char >(bin_max)); }
@@ -871,10 +872,10 @@ template<> struct numeric_limits<fp8_t> {
871872};
872873// bf8 E5M2: gfx950=OCP(ieee, has inf=0x7C, NaN=0x7D..0x7F, bin_qnan uses 0x7F), gfx942=fnuz(no inf, NaN=0x80)
873874template <> struct numeric_limits <bf8_t > {
874- #if defined(__gfx950__)
875- static constexpr unsigned char bin_min = 0x04 , bin_max = 0x7B , bin_lowest = 0xFB , bin_qnan = 0x7F , bin_inf = 0x7C ;
876- #else
875+ #if defined(__gfx942__)
877876 static constexpr unsigned char bin_min = 0x04 , bin_max = 0x7F , bin_lowest = 0xFF , bin_qnan = 0x80 , bin_inf = 0x00 ;
877+ #else
878+ static constexpr unsigned char bin_min = 0x04 , bin_max = 0x7B , bin_lowest = 0xFB , bin_qnan = 0x7F , bin_inf = 0x7C ;
878879#endif
879880 OPUS_H_D static constexpr bf8_t min () { return static_cast <bf8_t >(bin_min); }
880881 OPUS_H_D static constexpr bf8_t max () { return static_cast <bf8_t >(bin_max); }
@@ -927,6 +928,61 @@ template<> struct numeric_limits<u8_t> {
927928 OPUS_H_D static constexpr u8_t infinity () { return 0 ; }
928929};
929930
931+ // /////////////////////////////////////////////////////////////////////////////////////////////////////////
932+ // finfo -- like torch.finfo: eps/max/min/tiny as float, bits as int
933+ template <typename T> struct finfo ;
934+
935+ template <> struct finfo <fp32_t > {
936+ static constexpr int bits = 32 ;
937+ OPUS_H_D static constexpr float eps () { return __builtin_bit_cast (float , 0x34000000u ); } // 2^-23
938+ OPUS_H_D static constexpr float max () { return __builtin_bit_cast (float , 0x7F7FFFFFu ); } // 3.4028235e+38
939+ OPUS_H_D static constexpr float min () { return __builtin_bit_cast (float , 0xFF7FFFFFu ); } // -3.4028235e+38
940+ OPUS_H_D static constexpr float tiny () { return __builtin_bit_cast (float , 0x00800000u ); } // 2^-126
941+ };
942+ template <> struct finfo <fp16_t > {
943+ static constexpr int bits = 16 ;
944+ OPUS_H_D static constexpr float eps () { return __builtin_bit_cast (float , 0x3A800000u ); } // 2^-10 = 9.765625e-4
945+ OPUS_H_D static constexpr float max () { return __builtin_bit_cast (float , 0x477FE000u ); } // 65504.0
946+ OPUS_H_D static constexpr float min () { return __builtin_bit_cast (float , 0xC77FE000u ); } // -65504.0
947+ OPUS_H_D static constexpr float tiny () { return __builtin_bit_cast (float , 0x38800000u ); } // 2^-14
948+ };
949+ template <> struct finfo <bf16_t > {
950+ static constexpr int bits = 16 ;
951+ OPUS_H_D static constexpr float eps () { return __builtin_bit_cast (float , 0x3C000000u ); } // 2^-7 = 0.0078125
952+ OPUS_H_D static constexpr float max () { return __builtin_bit_cast (float , 0x7F7F0000u ); } // 3.389531e+38
953+ OPUS_H_D static constexpr float min () { return __builtin_bit_cast (float , 0xFF7F0000u ); } // -3.389531e+38
954+ OPUS_H_D static constexpr float tiny () { return __builtin_bit_cast (float , 0x00800000u ); } // 2^-126
955+ };
956+ // fp8 E4M3: gfx950=OCP(float8_e4m3fn, bias=7), gfx942=fnuz(float8_e4m3fnuz, bias=8)
957+ template <> struct finfo <fp8_t > {
958+ static constexpr int bits = 8 ;
959+ OPUS_H_D static constexpr float eps () { return __builtin_bit_cast (float , 0x3E000000u ); } // 2^-3 = 0.125
960+ #if defined(__gfx942__)
961+ OPUS_H_D static constexpr float max () { return __builtin_bit_cast (float , 0x43700000u ); } // 240.0
962+ OPUS_H_D static constexpr float min () { return __builtin_bit_cast (float , 0xC3700000u ); } // -240.0
963+ OPUS_H_D static constexpr float tiny () { return __builtin_bit_cast (float , 0x3C000000u ); } // 2^-7 = 0.0078125
964+ #else
965+ OPUS_H_D static constexpr float max () { return __builtin_bit_cast (float , 0x43E00000u ); } // 448.0
966+ OPUS_H_D static constexpr float min () { return __builtin_bit_cast (float , 0xC3E00000u ); } // -448.0
967+ OPUS_H_D static constexpr float tiny () { return __builtin_bit_cast (float , 0x3C800000u ); } // 2^-6 = 0.015625
968+ #endif
969+ };
970+ // bf8 E5M2: gfx950=OCP(float8_e5m2, bias=15), gfx942=fnuz(float8_e5m2fnuz, bias=16)
971+ template <> struct finfo <bf8_t > {
972+ static constexpr int bits = 8 ;
973+ #if defined(__gfx942__)
974+ OPUS_H_D static constexpr float eps () { return __builtin_bit_cast (float , 0x3E000000u ); } // 2^-3 = 0.125
975+ OPUS_H_D static constexpr float max () { return __builtin_bit_cast (float , 0x47600000u ); } // 57344.0
976+ OPUS_H_D static constexpr float min () { return __builtin_bit_cast (float , 0xC7600000u ); } // -57344.0
977+ OPUS_H_D static constexpr float tiny () { return __builtin_bit_cast (float , 0x38000000u ); } // 2^-15
978+ #else
979+ OPUS_H_D static constexpr float eps () { return __builtin_bit_cast (float , 0x3E800000u ); } // 2^-2 = 0.25
980+ OPUS_H_D static constexpr float max () { return __builtin_bit_cast (float , 0x47600000u ); } // 57344.0
981+ OPUS_H_D static constexpr float min () { return __builtin_bit_cast (float , 0xC7600000u ); } // -57344.0
982+ OPUS_H_D static constexpr float tiny () { return __builtin_bit_cast (float , 0x38800000u ); } // 2^-14
983+ #endif
984+ };
985+
930986template <typename C, typename ... S, std::enable_if_t <is_dtype_v<C> && (is_constant_v<S> && ...), bool> = true>
931987OPUS_H_D constexpr auto slice(C&& container, S&&.../* ss*/ ) { return container; } // TODO: fallback slice a normal value does nonthing
932988// ///////////////////////////////////////////////////////////////////////////////////////////////////////
@@ -1039,6 +1095,24 @@ OPUS_DEFINE_DPACKS(uint4_t, unsigned char, 4, false) // uint4x2
10391095OPUS_DEFINE_FPACKS(fp4_t , unsigned char , 4 , 2 , 1 , true ) // fp4x2
10401096OPUS_DEFINE_FPACKS(e8m0_t , unsigned char , 8 , 8 , 0 , false ) // fp4x2
10411097
1098+ // finfo specializations for subbyte/packed types (defined after OPUS_DEFINE_FPACKS)
1099+ // fp4 E2M1: 1 sign, 2 exp, 1 mantissa, bias=1
1100+ template<> struct finfo<fp4_t> {
1101+ static constexpr int bits = 4 ;
1102+ OPUS_H_D static constexpr float eps () { return __builtin_bit_cast (float , 0x3F000000u ); } // 2^-1 = 0.5
1103+ OPUS_H_D static constexpr float max () { return __builtin_bit_cast (float , 0x40C00000u ); } // 6.0
1104+ OPUS_H_D static constexpr float min () { return __builtin_bit_cast (float , 0xC0C00000u ); } // -6.0
1105+ OPUS_H_D static constexpr float tiny () { return __builtin_bit_cast (float , 0x3F800000u ); } // 1.0
1106+ };
1107+ // e8m0: 8-bit exponent only, unsigned, bias=127
1108+ template <> struct finfo <e8m0_t > {
1109+ static constexpr int bits = 8 ;
1110+ OPUS_H_D static constexpr float eps () { return __builtin_bit_cast (float , 0x3F800000u ); } // 1.0
1111+ OPUS_H_D static constexpr float max () { return __builtin_bit_cast (float , 0x7F000000u ); } // 2^127
1112+ OPUS_H_D static constexpr float min () { return __builtin_bit_cast (float , 0x00400000u ); } // 2^-127 (unsigned, no negative)
1113+ OPUS_H_D static constexpr float tiny () { return __builtin_bit_cast (float , 0x00400000u ); } // 2^-127
1114+ };
1115+
10421116#pragma clang diagnostic push
10431117#pragma clang diagnostic ignored "-Wuninitialized"
10441118#pragma clang diagnostic ignored "-Wc++20-extensions"
0 commit comments