Skip to content

Commit cb0b0c8

Browse files
authored
[OPUS] Add finfo class for float-valued type properties (#2330)
* [OPUS] Add finfo class for float-valued type properties (eps/max/min/tiny/bits). Supports fp32, fp16, bf16, fp8, bf8, fp4, e8m0 with gfx950/gfx942 specializations. Verified bitwise against torch.finfo on both MI355 (gfx950) and MI308 (gfx942).
* [OPUS] Use explicit opus:: namespace in test_finfo.cu
* [OPUS] Apply black formatting and update README compile times
* [OPUS] Use __gfx942__ guard instead of __gfx950__ in numeric_limits and finfo
1 parent cca2de5 commit cb0b0c8

File tree

5 files changed

+255
-22
lines changed

5 files changed

+255
-22
lines changed

csrc/include/opus/opus.hpp

Lines changed: 81 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -828,7 +828,8 @@ REGISTER_DTYPE(i8 , signed char)
828828
REGISTER_DTYPE(u8 , unsigned char)
829829

830830
///////////////////////////////////////////////////////////////////////////////////////////////////////////
831-
// numeric_limits -- min/max/lowest/quiet_nan/infinity for all registered dtypes
831+
// numeric_limits -- returns min/max/lowest/quiet_nan/infinity in the *original* dtype
832+
// (see finfo below for float-valued properties like eps/max/min/tiny)
832833
template<typename T> struct numeric_limits;
833834

834835
template<> struct numeric_limits<fp32_t> {
@@ -858,10 +859,10 @@ template<> struct numeric_limits<bf16_t> {
858859
// fp8 E4M3: gfx950=OCP(ieee-like, NaN=0x7F), gfx942=fnuz(NaN=0x80). No infinity in either format.
859860
// NOTE: __builtin_bit_cast with _BitInt(8) is not yet constexpr in clang, so use static_cast via signed char.
860861
template<> struct numeric_limits<fp8_t> {
861-
#if defined(__gfx950__)
862-
static constexpr unsigned char bin_min = 0x08, bin_max = 0x7E, bin_lowest = 0xFE, bin_qnan = 0x7F, bin_inf = 0x00;
863-
#else
862+
#if defined(__gfx942__)
864863
static constexpr unsigned char bin_min = 0x08, bin_max = 0x7F, bin_lowest = 0xFF, bin_qnan = 0x80, bin_inf = 0x00;
864+
#else
865+
static constexpr unsigned char bin_min = 0x08, bin_max = 0x7E, bin_lowest = 0xFE, bin_qnan = 0x7F, bin_inf = 0x00;
865866
#endif
866867
OPUS_H_D static constexpr fp8_t min() { return static_cast<fp8_t>(static_cast<signed char>(bin_min)); }
867868
OPUS_H_D static constexpr fp8_t max() { return static_cast<fp8_t>(static_cast<signed char>(bin_max)); }
@@ -871,10 +872,10 @@ template<> struct numeric_limits<fp8_t> {
871872
};
872873
// bf8 E5M2: gfx950=OCP(ieee, has inf=0x7C, NaN=0x7E), gfx942=fnuz(no inf, NaN=0x80)
873874
template<> struct numeric_limits<bf8_t> {
874-
#if defined(__gfx950__)
875-
static constexpr unsigned char bin_min = 0x04, bin_max = 0x7B, bin_lowest = 0xFB, bin_qnan = 0x7F, bin_inf = 0x7C;
876-
#else
875+
#if defined(__gfx942__)
877876
static constexpr unsigned char bin_min = 0x04, bin_max = 0x7F, bin_lowest = 0xFF, bin_qnan = 0x80, bin_inf = 0x00;
877+
#else
878+
static constexpr unsigned char bin_min = 0x04, bin_max = 0x7B, bin_lowest = 0xFB, bin_qnan = 0x7F, bin_inf = 0x7C;
878879
#endif
879880
OPUS_H_D static constexpr bf8_t min() { return static_cast<bf8_t>(bin_min); }
880881
OPUS_H_D static constexpr bf8_t max() { return static_cast<bf8_t>(bin_max); }
@@ -927,6 +928,61 @@ template<> struct numeric_limits<u8_t> {
927928
OPUS_H_D static constexpr u8_t infinity() { return 0; }
928929
};
929930

931+
///////////////////////////////////////////////////////////////////////////////////////////////////////////
932+
// finfo -- like torch.finfo: eps/max/min/tiny as float, bits as int
933+
template<typename T> struct finfo;
934+
935+
template<> struct finfo<fp32_t> {
936+
static constexpr int bits = 32;
937+
OPUS_H_D static constexpr float eps() { return __builtin_bit_cast(float, 0x34000000u); } // 2^-23
938+
OPUS_H_D static constexpr float max() { return __builtin_bit_cast(float, 0x7F7FFFFFu); } // 3.4028235e+38
939+
OPUS_H_D static constexpr float min() { return __builtin_bit_cast(float, 0xFF7FFFFFu); } // -3.4028235e+38
940+
OPUS_H_D static constexpr float tiny() { return __builtin_bit_cast(float, 0x00800000u); } // 2^-126
941+
};
942+
template<> struct finfo<fp16_t> {
943+
static constexpr int bits = 16;
944+
OPUS_H_D static constexpr float eps() { return __builtin_bit_cast(float, 0x3A800000u); } // 2^-10 = 9.765625e-4
945+
OPUS_H_D static constexpr float max() { return __builtin_bit_cast(float, 0x477FE000u); } // 65504.0
946+
OPUS_H_D static constexpr float min() { return __builtin_bit_cast(float, 0xC77FE000u); } // -65504.0
947+
OPUS_H_D static constexpr float tiny() { return __builtin_bit_cast(float, 0x38800000u); } // 2^-14
948+
};
949+
template<> struct finfo<bf16_t> {
950+
static constexpr int bits = 16;
951+
OPUS_H_D static constexpr float eps() { return __builtin_bit_cast(float, 0x3C000000u); } // 2^-7 = 0.0078125
952+
OPUS_H_D static constexpr float max() { return __builtin_bit_cast(float, 0x7F7F0000u); } // 3.389531e+38
953+
OPUS_H_D static constexpr float min() { return __builtin_bit_cast(float, 0xFF7F0000u); } // -3.389531e+38
954+
OPUS_H_D static constexpr float tiny() { return __builtin_bit_cast(float, 0x00800000u); } // 2^-126
955+
};
956+
// fp8 E4M3: gfx950=OCP(float8_e4m3fn, bias=7), gfx942=fnuz(float8_e4m3fnuz, bias=8)
957+
template<> struct finfo<fp8_t> {
958+
static constexpr int bits = 8;
959+
OPUS_H_D static constexpr float eps() { return __builtin_bit_cast(float, 0x3E000000u); } // 2^-3 = 0.125
960+
#if defined(__gfx942__)
961+
OPUS_H_D static constexpr float max() { return __builtin_bit_cast(float, 0x43700000u); } // 240.0
962+
OPUS_H_D static constexpr float min() { return __builtin_bit_cast(float, 0xC3700000u); } // -240.0
963+
OPUS_H_D static constexpr float tiny() { return __builtin_bit_cast(float, 0x3C000000u); } // 2^-7 = 0.0078125
964+
#else
965+
OPUS_H_D static constexpr float max() { return __builtin_bit_cast(float, 0x43E00000u); } // 448.0
966+
OPUS_H_D static constexpr float min() { return __builtin_bit_cast(float, 0xC3E00000u); } // -448.0
967+
OPUS_H_D static constexpr float tiny() { return __builtin_bit_cast(float, 0x3C800000u); } // 2^-6 = 0.015625
968+
#endif
969+
};
970+
// bf8 E5M2: gfx950=OCP(float8_e5m2, bias=15), gfx942=fnuz(float8_e5m2fnuz, bias=16)
971+
template<> struct finfo<bf8_t> {
972+
static constexpr int bits = 8;
973+
#if defined(__gfx942__)
974+
OPUS_H_D static constexpr float eps() { return __builtin_bit_cast(float, 0x3E000000u); } // 2^-3 = 0.125
975+
OPUS_H_D static constexpr float max() { return __builtin_bit_cast(float, 0x47600000u); } // 57344.0
976+
OPUS_H_D static constexpr float min() { return __builtin_bit_cast(float, 0xC7600000u); } // -57344.0
977+
OPUS_H_D static constexpr float tiny() { return __builtin_bit_cast(float, 0x38000000u); } // 2^-15
978+
#else
979+
OPUS_H_D static constexpr float eps() { return __builtin_bit_cast(float, 0x3E800000u); } // 2^-2 = 0.25
980+
OPUS_H_D static constexpr float max() { return __builtin_bit_cast(float, 0x47600000u); } // 57344.0
981+
OPUS_H_D static constexpr float min() { return __builtin_bit_cast(float, 0xC7600000u); } // -57344.0
982+
OPUS_H_D static constexpr float tiny() { return __builtin_bit_cast(float, 0x38800000u); } // 2^-14
983+
#endif
984+
};
985+
930986
template<typename C, typename... S, std::enable_if_t<is_dtype_v<C> && (is_constant_v<S> && ...), bool> = true>
931987
OPUS_H_D constexpr auto slice(C&& container, S&&.../*ss*/) { return container; } // TODO: fallback slice a normal value does nonthing
932988
/////////////////////////////////////////////////////////////////////////////////////////////////////////
@@ -1039,6 +1095,24 @@ OPUS_DEFINE_DPACKS(uint4_t, unsigned char, 4, false) // uint4x2
10391095
OPUS_DEFINE_FPACKS(fp4_t, unsigned char, 4, 2, 1, true) // fp4x2
10401096
OPUS_DEFINE_FPACKS(e8m0_t, unsigned char, 8, 8, 0, false) // e8m0x1
10411097

1098+
// finfo specializations for subbyte/packed types (defined after OPUS_DEFINE_FPACKS)
1099+
// fp4 E2M1: 1 sign, 2 exp, 1 mantissa, bias=1
1100+
template<> struct finfo<fp4_t> {
1101+
static constexpr int bits = 4;
1102+
OPUS_H_D static constexpr float eps() { return __builtin_bit_cast(float, 0x3F000000u); } // 2^-1 = 0.5
1103+
OPUS_H_D static constexpr float max() { return __builtin_bit_cast(float, 0x40C00000u); } // 6.0
1104+
OPUS_H_D static constexpr float min() { return __builtin_bit_cast(float, 0xC0C00000u); } // -6.0
1105+
OPUS_H_D static constexpr float tiny() { return __builtin_bit_cast(float, 0x3F800000u); } // 1.0
1106+
};
1107+
// e8m0: 8-bit exponent only, unsigned, bias=127
1108+
template<> struct finfo<e8m0_t> {
1109+
static constexpr int bits = 8;
1110+
OPUS_H_D static constexpr float eps() { return __builtin_bit_cast(float, 0x3F800000u); } // 1.0
1111+
OPUS_H_D static constexpr float max() { return __builtin_bit_cast(float, 0x7F000000u); } // 2^127
1112+
OPUS_H_D static constexpr float min() { return __builtin_bit_cast(float, 0x00400000u); } // 2^-127 (unsigned, no negative)
1113+
OPUS_H_D static constexpr float tiny() { return __builtin_bit_cast(float, 0x00400000u); } // 2^-127
1114+
};
1115+
10421116
#pragma clang diagnostic push
10431117
#pragma clang diagnostic ignored "-Wuninitialized"
10441118
#pragma clang diagnostic ignored "-Wc++20-extensions"

op_tests/opus/README.md

Lines changed: 17 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,10 @@ op_tests/opus/
2020
│ ├── test_dtype_convert.cu # FP32<->BF16/FP16/FP8/FP4 round-trip kernels
2121
│ ├── test_load_store_if.cu # Predicated load/store + free function API tests
2222
│ ├── test_numeric_limits.cu # opus::numeric_limits kernel
23+
│ ├── test_finfo.cu # opus::finfo kernel
2324
│ ├── test_mdiv.cu # opus::magic_div kernel
2425
│ ├── test_workgroup_barrier.cu # Workgroup barrier kernel
25-
│ ├── setup.py # Parallel hipcc build: 11 .cu -> .o -> .so
26+
│ ├── setup.py # Parallel hipcc build: 12 .cu -> .o -> .so
2627
│ └── test_opus_device.py # Python test runner (builds .so, runs all tests)
2728
├── run_tests.sh # Runs host test + device tests
2829
└── README.md
@@ -110,18 +111,19 @@ Total wall clock ~6.9 s
110111
### Per-file device compile times
111112

112113
```
113-
test_vector_add.cu 187 ms
114-
test_async_load.cu 185 ms
115-
test_numeric_limits.cu 191 ms
116-
test_workgroup_barrier.cu 216 ms
117-
test_mdiv.cu 243 ms
118-
test_mxfp.cu 248 ms
119-
test_load_store_if.cu 354 ms
120-
test_dtype_convert.cu 506 ms
121-
test_mfma_f32.cu 1,445 ms
122-
test_mfma_f8.cu 1,654 ms
123-
test_mfma_f16.cu 1,712 ms <-- critical path
124-
link 31 ms
114+
test_finfo.cu 127 ms
115+
test_async_load.cu 130 ms
116+
test_numeric_limits.cu 143 ms
117+
test_vector_add.cu 147 ms
118+
test_workgroup_barrier.cu 147 ms
119+
test_mdiv.cu 167 ms
120+
test_load_store_if.cu 216 ms
121+
test_mxfp.cu 224 ms
122+
test_dtype_convert.cu 292 ms
123+
test_mfma_f32.cu 769 ms
124+
test_mfma_f16.cu 863 ms
125+
test_mfma_f8.cu 884 ms <-- critical path
126+
link 25 ms
125127
```
126128

127129
## How to add a new device test
@@ -232,10 +234,11 @@ In `device/test_opus_device.py`:
232234
| `test_load_store_if` | free_func_vector_add | Free functions `opus::load`/`opus::store`, `is_gmem_v`/`is_mem_v` type traits | all |
233235
| `test_load_store_if` | predicated_async_load | `gmem::async_load_if`, free function `opus::async_load_if`, `layout_linear::operator+` | all |
234236
| `test_numeric_limits` | all types | `opus::numeric_limits<T>` for fp32/fp16/bf16/fp8/bf8/i32/i16/i8/u8 | all |
237+
| `test_finfo` | all float types | `opus::finfo<T>` (eps/max/min/tiny/bits) for fp32/fp16/bf16/fp8/bf8/fp4/e8m0 | all |
235238
| `test_mdiv` | 11 divisors | `opus::magic_div` integer division by magic multiply | all |
236239
| `test_workgroup_barrier` | cumulative + streamk | `opus::workgroup_barrier` cross-workgroup synchronization | all |
237240

238-
Total: **50+ test calls** (14 MFMA + 4 MXFP + 1 vector_add + 1 async_load + 11 dtype_convert + 3 load_store_if + 9 numeric_limits + 11 mdiv + 4 workgroup_barrier).
241+
Total: **50+ test calls** (14 MFMA + 4 MXFP + 1 vector_add + 1 async_load + 11 dtype_convert + 3 load_store_if + 9 numeric_limits + 7 finfo + 11 mdiv + 4 workgroup_barrier).
239242

240243
## Notes
241244

op_tests/opus/device/setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
"test_mdiv.cu",
3434
"test_numeric_limits.cu",
3535
"test_workgroup_barrier.cu",
36+
"test_finfo.cu",
3637
]
3738

3839

op_tests/opus/device/test_finfo.cu

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
// SPDX-License-Identifier: MIT
2+
// Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved.
3+
//
4+
// Device test for opus::finfo.
5+
// Single-thread kernel writes eps/max/min/tiny as float and bits as int.
6+
7+
#ifdef __HIP_DEVICE_COMPILE__
8+
// ── Device pass ─────────────────────────────────────────────────────────────
9+
#include "opus/opus.hpp"
10+
namespace {
11+
12+
// Each type writes 5 floats: eps, max, min, tiny, __int_as_float(bits)
13+
template<typename T>
14+
__device__ void write_finfo(float* out) {
15+
out[0] = opus::finfo<T>::eps();
16+
out[1] = opus::finfo<T>::max();
17+
out[2] = opus::finfo<T>::min();
18+
out[3] = opus::finfo<T>::tiny();
19+
out[4] = __builtin_bit_cast(float, opus::finfo<T>::bits);
20+
}
21+
22+
__global__ void finfo_kernel(float* out) {
23+
if (__builtin_amdgcn_workitem_id_x() != 0) return;
24+
write_finfo<opus::fp32_t>(out + 0);
25+
write_finfo<opus::fp16_t>(out + 5);
26+
write_finfo<opus::bf16_t>(out + 10);
27+
write_finfo<opus::fp8_t >(out + 15);
28+
write_finfo<opus::bf8_t >(out + 20);
29+
write_finfo<opus::fp4_t >(out + 25);
30+
write_finfo<opus::e8m0_t>(out + 30);
31+
}
32+
} // anonymous namespace
33+
34+
#else
35+
// ── Host pass ───────────────────────────────────────────────────────────────
36+
#include "hip_host_minimal.h"
37+
#include <cstdio>
38+
39+
namespace {
40+
__global__ void finfo_kernel(float* out) {}
41+
}
42+
43+
extern "C" void run_finfo(void* d_out) {
44+
finfo_kernel<<<1, 1>>>(static_cast<float*>(d_out));
45+
hipError_t err = hipDeviceSynchronize();
46+
if (err != hipSuccess) {
47+
fprintf(stderr, "finfo_kernel failed: %s\n", hipGetErrorString(err));
48+
}
49+
}
50+
#endif

op_tests/opus/device/test_opus_device.py

Lines changed: 106 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
- MFMA variants (fp32, fp16, bf16, fp8, bf8)
1212
- MXFP variants (fp8, fp4) -- gfx950 only
1313
- vector_add, async_load, dtype_convert, predicated_copy, free_func_add,
14-
predicated_async_load, numeric_limits, mdiv, workgroup_barrier
14+
predicated_async_load, numeric_limits, finfo, mdiv, workgroup_barrier
1515
"""
1616

1717
import ctypes
@@ -141,6 +141,13 @@ def run_numeric_limits(self, Out):
141141
fn.argtypes = [_VP]
142142
fn(self._ptr(Out))
143143

144+
# -- finfo --
145+
def run_finfo(self, Out):
146+
fn = self._lib.run_finfo
147+
fn.restype = None
148+
fn.argtypes = [_VP]
149+
fn(self._ptr(Out))
150+
144151
# -- mdiv --
145152
def run_mdiv(self, Dividends, OutQ, OutR, divisor):
146153
fn = self._lib.run_mdiv
@@ -1298,6 +1305,103 @@ def ref_int(dtype, size):
12981305
return 0
12991306

13001307

1308+
def test_finfo(mod):
1309+
"""Test opus::finfo against torch.finfo reference values (bitwise comparison)."""
1310+
import struct
1311+
1312+
device = torch.device("cuda")
1313+
1314+
N_TYPES = 7 # fp32, fp16, bf16, fp8, bf8, fp4, e8m0
1315+
FIELDS_PER_TYPE = 5 # eps, max, min, tiny, bits
1316+
N = N_TYPES * FIELDS_PER_TYPE
1317+
out = torch.zeros(N, device=device, dtype=torch.float32)
1318+
mod.run_finfo(out)
1319+
raw = out.cpu()
1320+
1321+
fails = 0
1322+
fields = ["eps", "max", "min", "tiny", "bits"]
1323+
1324+
def float_to_u32(f):
1325+
return struct.unpack("I", struct.pack("f", float(f)))[0]
1326+
1327+
def u32_to_float(u):
1328+
return struct.unpack("f", struct.pack("I", u))[0]
1329+
1330+
def ref_from_torch_finfo(dtype):
1331+
fi = torch.finfo(dtype)
1332+
return {
1333+
"eps": fi.eps,
1334+
"max": fi.max,
1335+
"min": fi.min,
1336+
"tiny": fi.tiny,
1337+
}
1338+
1339+
fp8_dtype = _get_fp8_dtype()
1340+
bf8_dtype = _get_bf8_dtype()
1341+
1342+
# (name, offset, torch_dtype_or_None, manual_ref_or_None)
1343+
# For fp4 and e8m0 there is no torch.finfo, so we provide manual reference.
1344+
fp4_ref = {"eps": 0.5, "max": 6.0, "min": -6.0, "tiny": 1.0, "bits": 4}
1345+
e8m0_ref = {
1346+
"eps": 1.0,
1347+
"max": 2.0**127,
1348+
"min": 2.0**-127,
1349+
"tiny": 2.0**-127,
1350+
"bits": 8,
1351+
}
1352+
1353+
type_table = [
1354+
("fp32", 0, torch.float32, 32, None),
1355+
("fp16", 5, torch.float16, 16, None),
1356+
("bf16", 10, torch.bfloat16, 16, None),
1357+
("fp8", 15, fp8_dtype, 8, None),
1358+
("bf8", 20, bf8_dtype, 8, None),
1359+
("fp4", 25, None, 4, fp4_ref),
1360+
("e8m0", 30, None, 8, e8m0_ref),
1361+
]
1362+
1363+
for name, offset, dtype, expected_bits, manual_ref in type_table:
1364+
if dtype is not None:
1365+
ref = ref_from_torch_finfo(dtype)
1366+
ref["bits"] = expected_bits
1367+
else:
1368+
ref = manual_ref
1369+
1370+
type_fails = 0
1371+
for j, field in enumerate(fields):
1372+
actual_f32 = raw[offset + j].item()
1373+
if field == "bits":
1374+
# bits is stored as __int_as_float(bits), extract the int
1375+
actual_val = struct.unpack("I", struct.pack("f", actual_f32))[0]
1376+
expected_val = ref["bits"]
1377+
if actual_val != expected_val:
1378+
print(
1379+
f" {name}.{field}: {actual_val} != expected {expected_val}"
1380+
)
1381+
type_fails += 1
1382+
else:
1383+
expected_f32 = float(ref[field])
1384+
actual_bits = float_to_u32(actual_f32)
1385+
expected_bits_val = float_to_u32(expected_f32)
1386+
if actual_bits != expected_bits_val:
1387+
print(
1388+
f" {name}.{field}: 0x{actual_bits:08X} ({actual_f32}) "
1389+
f"!= expected 0x{expected_bits_val:08X} ({expected_f32})"
1390+
)
1391+
type_fails += 1
1392+
if type_fails == 0:
1393+
print(f" PASS: finfo<{name}> (all {len(fields)} fields)")
1394+
else:
1395+
print(f" FAIL: finfo<{name}> ({type_fails} field(s) wrong)")
1396+
fails += type_fails
1397+
1398+
if fails:
1399+
print(f" finfo: {fails} field(s) FAILED")
1400+
return 1
1401+
print(" PASS: finfo all types correct")
1402+
return 0
1403+
1404+
13011405
def test_wb_cumulative(mod):
13021406
"""Test workgroup_barrier wait_lt + inc: N workgroups contribute i+1 sequentially."""
13031407
device = torch.device("cuda")
@@ -1395,6 +1499,7 @@ def main():
13951499
failures += test_free_func_vector_add(mod)
13961500
failures += test_predicated_async_load(mod)
13971501
failures += test_numeric_limits(mod)
1502+
failures += test_finfo(mod)
13981503
failures += test_mdiv(mod)
13991504
failures += test_wb_cumulative(mod)
14001505
failures += test_wb_streamk_reduce(mod)

0 commit comments

Comments
 (0)