Skip to content

Add ObjectCode ptx constructor #470

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Mar 1, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 19 additions & 4 deletions cuda_core/cuda/core/experimental/_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,10 +224,9 @@ class ObjectCode:
Note
----
This class has no default constructor. If you already have a cubin that you would
like to load, use the :meth:`from_cubin` alternative constructor. For all other
possible code types (ex: "ptx"), only :class:`~cuda.core.experimental.Program`
accepts them and returns an :class:`ObjectCode` instance with its
:meth:`~cuda.core.experimental.Program.compile` method.
like to load, use the :meth:`from_cubin` alternative constructor. Constructing directly
from any other code type should be avoided in favor of compilation through
:class:`~cuda.core.experimental.Program`.

Note
----
Expand Down Expand Up @@ -278,6 +277,22 @@ def from_cubin(module: Union[bytes, str], *, symbol_mapping: Optional[dict] = No
"""
return ObjectCode._init(module, "cubin", symbol_mapping=symbol_mapping)

@staticmethod
def from_ptx(module: Union[bytes, str], *, symbol_mapping: Optional[dict] = None) -> "ObjectCode":
    """Build an :class:`ObjectCode` instance that wraps existing PTX.

    Parameters
    ----------
    module : Union[bytes, str]
        The PTX to load: either a bytes object holding the ptx code
        in memory, or a string giving the path of a ptx file on disk.
    symbol_mapping : Optional[dict]
        Optional dictionary whose keys are unmangled symbol names and
        whose values are the mangled names to use when the symbols are
        retrieved (no mapping is applied by default).
    """
    # The code type tag is the only thing that differs from the cubin constructor.
    code_type = "ptx"
    return ObjectCode._init(module, code_type, symbol_mapping=symbol_mapping)

# TODO: do we want to unload in a finalizer? Probably not..

def _lazy_load_module(self, *args, **kwargs):
Expand Down
3 changes: 2 additions & 1 deletion cuda_core/cuda/core/experimental/_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -425,7 +425,8 @@ def close(self):
self._linker.close()
self._mnff.close()

def _can_load_generated_ptx(self):
@staticmethod
def _can_load_generated_ptx():
    """Return whether the NVRTC version does not exceed the driver version.

    The NVRTC (major, minor) pair is folded into a single integer as
    ``major * 1000 + minor * 10`` and compared against the value returned
    by ``cuDriverGetVersion``.
    """
    # Query the driver first, then NVRTC, so errors surface in the same order.
    driver_version = handle_return(driver.cuDriverGetVersion())
    major, minor = handle_return(nvrtc.nvrtcVersion())
    nvrtc_version = major * 1000 + minor * 10
    return nvrtc_version <= driver_version
Expand Down
8 changes: 8 additions & 0 deletions cuda_core/tests/test_linker.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,14 @@ def test_linker_link_cubin(compile_ptx_functions):
assert isinstance(linked_code, ObjectCode)


def test_linker_link_ptx_multiple(compile_ptx_functions):
    """Link several PTX inputs, each re-wrapped via ``ObjectCode.from_ptx``, into a cubin."""
    # Rebuild every compiled object from its raw code through the PTX constructor.
    ptx_objects = [ObjectCode.from_ptx(compiled.code) for compiled in compile_ptx_functions]
    linker = Linker(*ptx_objects, options=LinkerOptions(arch=ARCH))
    result = linker.link("cubin")
    assert isinstance(result, ObjectCode)


def test_linker_link_invalid_target_type(compile_ptx_functions):
options = LinkerOptions(arch=ARCH)
linker = Linker(*compile_ptx_functions, options=options)
Expand Down
51 changes: 36 additions & 15 deletions cuda_core/tests/test_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,25 +13,25 @@

from cuda.core.experimental import ObjectCode, Program, ProgramOptions, system

# Templated CUDA C++ SAXPY kernel source shared by the fixtures in this module.
# Each thread starts at its global index and advances by the total number of
# threads in the grid, so the element accesses must use the loop variable `i`
# (not the fixed starting index `tid`) for every element below N to be written.
SAXPY_KERNEL = """
template<typename T>
__global__ void saxpy(const T a,
const T* x,
const T* y,
T* out,
size_t N) {
const unsigned int tid = threadIdx.x + blockIdx.x * blockDim.x;
for (size_t i=tid; i<N; i+=gridDim.x*blockDim.x) {
out[i] = a * x[i] + y[i];
}
}
"""


@pytest.fixture(scope="function")
def get_saxpy_kernel(init_cuda):
code = """
template<typename T>
__global__ void saxpy(const T a,
const T* x,
const T* y,
T* out,
size_t N) {
const unsigned int tid = threadIdx.x + blockIdx.x * blockDim.x;
for (size_t i=tid; i<N; i+=gridDim.x*blockDim.x) {
out[tid] = a * x[tid] + y[tid];
}
}
"""

# prepare program
prog = Program(code, code_type="c++")
prog = Program(SAXPY_KERNEL, code_type="c++")
mod = prog.compile(
"cubin",
name_expressions=("saxpy<float>", "saxpy<double>"),
Expand All @@ -41,6 +41,17 @@ def get_saxpy_kernel(init_cuda):
return mod.get_kernel("saxpy<float>"), mod


@pytest.fixture(scope="function")
def get_saxpy_kernel_ptx(init_cuda):
    """Compile ``SAXPY_KERNEL`` to PTX and return ``(ptx, compiled_object)``.

    The first element is the compiled object's ``_module`` attribute; the
    second is the object returned by ``Program.compile`` itself.
    """
    program = Program(SAXPY_KERNEL, code_type="c++")
    # Both template instantiations are requested by name.
    instantiations = ("saxpy<float>", "saxpy<double>")
    compiled = program.compile("ptx", name_expressions=instantiations)
    return compiled._module, compiled


def test_get_kernel(init_cuda):
kernel = """extern "C" __global__ void ABC() { }"""

Expand Down Expand Up @@ -100,6 +111,16 @@ def test_object_code_load_cubin(get_saxpy_kernel):
mod.get_kernel("saxpy<double>") # force loading


def test_object_code_load_ptx(get_saxpy_kernel_ptx):
    """Check that ``ObjectCode.from_ptx`` wraps compiled PTX and can load a kernel.

    The fixture supplies the PTX together with the object it was compiled
    into; the PTX is re-wrapped through ``from_ptx`` using the same symbol
    mapping. Kernel retrieval is skipped when
    ``Program._can_load_generated_ptx()`` reports False.
    """
    ptx, mod = get_saxpy_kernel_ptx
    sym_map = mod._sym_map
    mod_obj = ObjectCode.from_ptx(ptx, symbol_mapping=sym_map)
    assert mod.code == ptx
    # Also inspect the object under test: the assertion above only looks at the
    # fixture's object and would pass even if from_ptx stored the wrong code.
    assert mod_obj.code == ptx
    if not Program._can_load_generated_ptx():
        pytest.skip("PTX version too new for current driver")
    mod_obj.get_kernel("saxpy<double>")  # force loading


def test_object_code_load_cubin_from_file(get_saxpy_kernel, tmp_path):
_, mod = get_saxpy_kernel
cubin = mod._module
Expand Down
Loading