Skip to content

Commit 64cbc4c

Browse files
Lazy load code modules (#269)
Lazy load module in ObjectCode
1 parent c1fea41 commit 64cbc4c

File tree

4 files changed

+53
-62
lines changed

4 files changed

+53
-62
lines changed

cuda_core/cuda/core/experimental/_module.py

Lines changed: 27 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -5,7 +5,7 @@
55
import importlib.metadata
66

77
from cuda import cuda
8-
from cuda.core.experimental._utils import handle_return
8+
from cuda.core.experimental._utils import handle_return, precondition
99

1010
_backend = {
1111
"old": {
@@ -106,30 +106,43 @@ class ObjectCode:
106106
107107
"""
108108

109-
__slots__ = ("_handle", "_code_type", "_module", "_loader", "_sym_map")
109+
__slots__ = ("_handle", "_backend_version", "_jit_options", "_code_type", "_module", "_loader", "_sym_map")
110110
_supported_code_type = ("cubin", "ptx", "ltoir", "fatbin")
111111

112112
def __init__(self, module, code_type, jit_options=None, *, symbol_mapping=None):
113113
if code_type not in self._supported_code_type:
114114
raise ValueError
115115
_lazy_init()
116+
117+
# handle is assigned during _lazy_load
116118
self._handle = None
119+
self._jit_options = jit_options
120+
121+
self._backend_version = "new" if (_py_major_ver >= 12 and _driver_ver >= 12000) else "old"
122+
self._loader = _backend[self._backend_version]
123+
124+
self._code_type = code_type
125+
self._module = module
126+
self._sym_map = {} if symbol_mapping is None else symbol_mapping
117127

118-
backend = "new" if (_py_major_ver >= 12 and _driver_ver >= 12000) else "old"
119-
self._loader = _backend[backend]
128+
# TODO: do we want to unload in a finalizer? Probably not..
120129

130+
def _lazy_load_module(self, *args, **kwargs):
131+
if self._handle is not None:
132+
return
133+
jit_options = self._jit_options
134+
module = self._module
121135
if isinstance(module, str):
122136
# TODO: this option is only taken by the new library APIs, but we have
123137
# a bug that we can't easily support it just yet (NVIDIA/cuda-python#73).
124138
if jit_options is not None:
125139
raise ValueError
126-
module = module.encode()
127140
self._handle = handle_return(self._loader["file"](module))
128141
else:
129142
assert isinstance(module, bytes)
130143
if jit_options is None:
131144
jit_options = {}
132-
if backend == "new":
145+
if self._backend_version == "new":
133146
args = (
134147
module,
135148
list(jit_options.keys()),
@@ -141,15 +154,15 @@ def __init__(self, module, code_type, jit_options=None, *, symbol_mapping=None):
141154
0,
142155
)
143156
else: # "old" backend
144-
args = (module, len(jit_options), list(jit_options.keys()), list(jit_options.values()))
157+
args = (
158+
module,
159+
len(jit_options),
160+
list(jit_options.keys()),
161+
list(jit_options.values()),
162+
)
145163
self._handle = handle_return(self._loader["data"](*args))
146164

147-
self._code_type = code_type
148-
self._module = module
149-
self._sym_map = {} if symbol_mapping is None else symbol_mapping
150-
151-
# TODO: do we want to unload in a finalizer? Probably not..
152-
165+
@precondition(_lazy_load_module)
153166
def get_kernel(self, name):
154167
"""Return the :obj:`Kernel` of a specified name from this object code.
155168
@@ -168,6 +181,7 @@ def get_kernel(self, name):
168181
name = self._sym_map[name]
169182
except KeyError:
170183
name = name.encode()
184+
171185
data = handle_return(self._loader["kernel"](self._handle, name))
172186
return Kernel._from_obj(data, self)
173187

cuda_core/tests/conftest.py

Lines changed: 8 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -11,10 +11,10 @@
1111
import sys
1212

1313
try:
14-
from cuda.bindings import driver
14+
from cuda.bindings import driver, nvrtc
1515
except ImportError:
1616
from cuda import cuda as driver
17-
17+
from cuda import nvrtc
1818
import pytest
1919

2020
from cuda.core.experimental import Device, _device
@@ -65,3 +65,9 @@ def clean_up_cffi_files():
6565
os.remove(f)
6666
except FileNotFoundError:
6767
pass # noqa: SIM105
68+
69+
70+
def can_load_generated_ptx():
71+
_, driver_ver = driver.cuDriverGetVersion()
72+
_, nvrtc_major, nvrtc_minor = nvrtc.nvrtcVersion()
73+
return nvrtc_major * 1000 + nvrtc_minor * 10 <= driver_ver

cuda_core/tests/test_module.py

Lines changed: 17 additions & 38 deletions
Original file line number | Diff line number | Diff line change
@@ -6,43 +6,22 @@
66
# this software and related documentation outside the terms of the EULA
77
# is strictly prohibited.
88

9-
import importlib
109

1110
import pytest
12-
13-
from cuda.core.experimental._module import ObjectCode
14-
15-
16-
@pytest.mark.skipif(
17-
int(importlib.metadata.version("cuda-python").split(".")[0]) < 12,
18-
reason="Module loading for older drivers validate require valid module code.",
19-
)
20-
def test_object_code_initialization():
21-
# Test with supported code types
22-
for code_type in ["cubin", "ptx", "fatbin"]:
23-
module_data = b"dummy_data"
24-
obj_code = ObjectCode(module_data, code_type)
25-
assert obj_code._code_type == code_type
26-
assert obj_code._module == module_data
27-
assert obj_code._handle is not None
28-
29-
# Test with unsupported code type
30-
with pytest.raises(ValueError):
31-
ObjectCode(b"dummy_data", "unsupported_code_type")
32-
33-
34-
# TODO add ObjectCode tests which provide the appropriate data for cuLibraryLoadFromFile
35-
def test_object_code_initialization_with_str():
36-
assert True
37-
38-
39-
def test_object_code_initialization_with_jit_options():
40-
assert True
41-
42-
43-
def test_object_code_get_kernel():
44-
assert True
45-
46-
47-
def test_kernel_from_obj():
48-
assert True
11+
from conftest import can_load_generated_ptx
12+
13+
from cuda.core.experimental import Program
14+
15+
16+
@pytest.mark.xfail(not can_load_generated_ptx(), reason="PTX version too new")
17+
def test_get_kernel():
18+
kernel = """
19+
extern __device__ int B();
20+
extern __device__ int C(int a, int b);
21+
__global__ void A() { int result = C(B(), 1);}
22+
"""
23+
object_code = Program(kernel, "c++").compile("ptx", options=("-rdc=true",))
24+
assert object_code._handle is None
25+
kernel = object_code.get_kernel("A")
26+
assert object_code._handle is not None
27+
assert kernel._handle is not None

cuda_core/tests/test_program.py

Lines changed: 1 addition & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -7,20 +7,12 @@
77
# is strictly prohibited.
88

99
import pytest
10+
from conftest import can_load_generated_ptx
1011

11-
from cuda import cuda, nvrtc
1212
from cuda.core.experimental import Device, Program
1313
from cuda.core.experimental._module import Kernel, ObjectCode
1414

1515

16-
def can_load_generated_ptx():
17-
_, driver_ver = cuda.cuDriverGetVersion()
18-
_, nvrtc_major, nvrtc_minor = nvrtc.nvrtcVersion()
19-
if nvrtc_major * 1000 + nvrtc_minor * 10 > driver_ver:
20-
return False
21-
return True
22-
23-
2416
def test_program_init_valid_code_type():
2517
code = 'extern "C" __global__ void my_kernel() {}'
2618
program = Program(code, "c++")

0 commit comments

Comments
 (0)