Skip to content

Commit e501c89

Browse files
authored
Support Functions Typed with __nv_bfloat16_raw (#262)
Currently Numbast's type system does not recognize `__nv_bfloat16_raw` data type. Certain CUDA C++ functions takes the raw type as input, so we also provide those support here. <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit ## Release Notes * **New Features** * Added support for bfloat16 raw type conversions, enabling bidirectional conversion between bfloat16 values and their raw representations in CUDA kernels. * Extended type system to recognize and properly handle the new bfloat16 raw variant. * **Tests** * Added comprehensive test coverage for bfloat16 raw conversion functionality. <sub>✏️ Tip: You can customize this high-level summary in your review settings.</sub> <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Co-authored-by: Michael Wang <[email protected]>
1 parent d52dcff commit e501c89

File tree

5 files changed

+46
-2
lines changed

5 files changed

+46
-2
lines changed

numbast/src/numbast/static/renderer.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,12 @@ def _try_import_numba_type(cls, typ: str):
110110
cls.Imports.add("from numba.cuda.types import bfloat16")
111111
cls._imported_numba_types.add(typ)
112112

113+
elif typ == "__nv_bfloat16_raw":
114+
cls.Imports.add(
115+
"from numba.cuda._internal.cuda_bf16 import _type_unnamed1405307 as bfloat16_raw_type"
116+
)
117+
cls._imported_numba_types.add(typ)
118+
113119
elif typ in vector_types:
114120
# CUDA target specific types
115121
cls.Imports.add("from numba.cuda.vector_types import vector_types")

numbast/src/numbast/static/tests/data/bf16.cuh

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,4 +10,12 @@ nv_bfloat16 inline __device__ add(nv_bfloat16 a, nv_bfloat16 b) {
1010
return a + b;
1111
}
1212

13+
__nv_bfloat16_raw inline __device__ bf16_to_raw(nv_bfloat16 a) {
14+
return __nv_bfloat16_raw(a);
15+
}
16+
17+
nv_bfloat16 inline __device__ bf16_from_raw(__nv_bfloat16_raw a) {
18+
return __nv_bfloat16(a);
19+
}
20+
1321
#endif

numbast/src/numbast/static/tests/test_bf16_support.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,3 +23,27 @@ def kernel(arr):
2323

2424
# Check that bfloat16 is imported
2525
assert "from numba.cuda.types import bfloat16" in res1["src"]
26+
27+
28+
def test_bindings_from_bf16raw(make_binding):
29+
res = make_binding("bf16.cuh", {}, {})
30+
31+
binding = res["bindings"]
32+
33+
bf16_from_raw = binding["bf16_from_raw"]
34+
bf16_to_raw = binding["bf16_to_raw"]
35+
36+
@cuda.jit
37+
def kernel(arr):
38+
x = bfloat16(3.14)
39+
40+
x_raw = bf16_to_raw(x)
41+
x2 = bf16_from_raw(x_raw)
42+
43+
arr[0] = float32(x2)
44+
45+
arr = cuda.device_array((1,), dtype="float32")
46+
47+
kernel[1, 1](arr)
48+
49+
assert pytest.approx(arr[0], 1e-2) == 3.14

numbast/src/numbast/static/types.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,10 @@ def to_numba_type_str(ty: str):
5959
BaseRenderer._try_import_numba_type("__nv_bfloat16")
6060
return "bfloat16"
6161

62+
if ty == "__nv_bfloat16_raw":
63+
BaseRenderer._try_import_numba_type("__nv_bfloat16_raw")
64+
return "bfloat16_raw_type"
65+
6266
if ty.endswith("*"):
6367
base_ty = ty.rstrip("*").rstrip(" ")
6468
ptr_ty_str = f"CPointer({to_numba_type_str(base_ty)})"

numbast/src/numbast/types.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
from numba.cuda.vector_types import vector_types
1010
from numba.misc.special import typeof
1111

12+
from numba.cuda._internal.cuda_bf16 import _type_unnamed1405307
13+
1214

1315
class FunctorType(nbtypes.Type):
1416
def __init__(self, name):
@@ -83,8 +85,8 @@ def register_enum_type(cxx_name: str, e: IntEnum):
8385

8486

8587
def to_numba_type(ty: str):
86-
if ty == "__nv_bfloat16":
87-
return bfloat16
88+
if ty == "__nv_bfloat16_raw":
89+
return _type_unnamed1405307
8890

8991
if "FunctorType" in ty:
9092
return FunctorType(ty[:-11])

0 commit comments

Comments
 (0)