diff --git a/cuda_core/cuda/core/experimental/_linker.py b/cuda_core/cuda/core/experimental/_linker.py index 976d739f3..7736d7b2d 100644 --- a/cuda_core/cuda/core/experimental/_linker.py +++ b/cuda_core/cuda/core/experimental/_linker.py @@ -6,12 +6,12 @@ import weakref from contextlib import contextmanager from dataclasses import dataclass -from typing import List, Optional +from typing import List, Optional, Tuple, Union from warnings import warn from cuda.core.experimental._device import Device from cuda.core.experimental._module import ObjectCode -from cuda.core.experimental._utils import check_or_create_options, driver, handle_return +from cuda.core.experimental._utils import check_or_create_options, driver, handle_return, is_sequence # TODO: revisit this treatment for py313t builds _driver = None # populated if nvJitLink cannot be used @@ -130,15 +130,14 @@ class LinkerOptions: fma : bool, optional Use fast multiply-add. Default: True. - kernels_used : List[str], optional - Pass list of kernels that are used; any not in the list can be removed. This option can be specified multiple - times. - variables_used : List[str], optional - Pass a list of variables that are used; any not in the list can be removed. + kernels_used : [Union[str, Tuple[str], List[str]]], optional + Pass a kernel or sequence of kernels that are used; any not in the list can be removed. + variables_used : [Union[str, Tuple[str], List[str]]], optional + Pass a variable or sequence of variables that are used; any not in the list can be removed. optimize_unused_variables : bool, optional Assume that if a variable is not referenced in device code, it can be removed. Default: False. - ptxas_options : List[str], optional + ptxas_options : [Union[str, Tuple[str], List[str]]], optional Pass options to PTXAS. split_compile : int, optional Split compilation maximum thread count. Use 0 to use all available processors. Value of 1 disables split @@ -167,10 +166,10 @@ class LinkerOptions: prec_div: Optional[bool] = None prec_sqrt: Optional[bool] = None fma: Optional[bool] = None - kernels_used: Optional[List[str]] = None - variables_used: Optional[List[str]] = None + kernels_used: Optional[Union[str, Tuple[str], List[str]]] = None + variables_used: Optional[Union[str, Tuple[str], List[str]]] = None optimize_unused_variables: Optional[bool] = None - ptxas_options: Optional[List[str]] = None + ptxas_options: Optional[Union[str, Tuple[str], List[str]]] = None split_compile: Optional[int] = None split_compile_extended: Optional[int] = None no_cache: Optional[bool] = None @@ -213,16 +212,25 @@ def _init_nvjitlink(self): if self.fma is not None: self.formatted_options.append(f"-fma={'true' if self.fma else 'false'}") if self.kernels_used is not None: - for kernel in self.kernels_used: - self.formatted_options.append(f"-kernels-used={kernel}") + if isinstance(self.kernels_used, str): + self.formatted_options.append(f"-kernels-used={self.kernels_used}") + elif isinstance(self.kernels_used, list): + for kernel in self.kernels_used: + self.formatted_options.append(f"-kernels-used={kernel}") if self.variables_used is not None: - for variable in self.variables_used: - self.formatted_options.append(f"-variables-used={variable}") + if isinstance(self.variables_used, str): + self.formatted_options.append(f"-variables-used={self.variables_used}") + elif isinstance(self.variables_used, list): + for variable in self.variables_used: + self.formatted_options.append(f"-variables-used={variable}") if self.optimize_unused_variables is not None: self.formatted_options.append("-optimize-unused-variables") if self.ptxas_options is not None: - for opt in self.ptxas_options: - self.formatted_options.append(f"-Xptxas={opt}") + if isinstance(self.ptxas_options, str): + self.formatted_options.append(f"-Xptxas={self.ptxas_options}") + elif is_sequence(self.ptxas_options): + for opt in self.ptxas_options: + self.formatted_options.append(f"-Xptxas={opt}") if self.split_compile is not None: self.formatted_options.append(f"-split-compile={self.split_compile}") if self.split_compile_extended is not None: diff --git a/cuda_core/tests/test_linker.py b/cuda_core/tests/test_linker.py index 556fe9f7c..732f70c32 100644 --- a/cuda_core/tests/test_linker.py +++ b/cuda_core/tests/test_linker.py @@ -58,7 +58,9 @@ def compile_ltoir_functions(init_cuda): options += [ LinkerOptions(arch=ARCH, time=True), LinkerOptions(arch=ARCH, optimize_unused_variables=True), - LinkerOptions(arch=ARCH, ptxas_options=["-v"]), + LinkerOptions(arch=ARCH, ptxas_options="-v"), + LinkerOptions(arch=ARCH, ptxas_options=["-v", "--verbose"]), + LinkerOptions(arch=ARCH, ptxas_options=("-v", "--verbose")), LinkerOptions(arch=ARCH, split_compile=0), LinkerOptions(arch=ARCH, split_compile_extended=1), # The following options are supported by nvjitlink and deprecated by culink @@ -66,10 +68,12 @@ def compile_ltoir_functions(init_cuda): LinkerOptions(arch=ARCH, prec_div=True), LinkerOptions(arch=ARCH, prec_sqrt=True), LinkerOptions(arch=ARCH, fma=True), - LinkerOptions(arch=ARCH, kernels_used=["A"]), + LinkerOptions(arch=ARCH, kernels_used="A"), LinkerOptions(arch=ARCH, kernels_used=["C", "B"]), - LinkerOptions(arch=ARCH, variables_used=["var1"]), + LinkerOptions(arch=ARCH, kernels_used=("C", "B")), + LinkerOptions(arch=ARCH, variables_used="var1"), LinkerOptions(arch=ARCH, variables_used=["var1", "var2"]), + LinkerOptions(arch=ARCH, variables_used=("var1", "var2")), ] version = nvjitlink.version() if version >= (12, 5):