diff --git a/cuda_core/cuda/core/experimental/__init__.py b/cuda_core/cuda/core/experimental/__init__.py
index 9093163b3..7739ae64e 100644
--- a/cuda_core/cuda/core/experimental/__init__.py
+++ b/cuda_core/cuda/core/experimental/__init__.py
@@ -5,7 +5,8 @@
 from cuda.core.experimental import utils
 from cuda.core.experimental._device import Device
 from cuda.core.experimental._event import Event, EventOptions
-from cuda.core.experimental._launcher import LaunchConfig, launch
+from cuda.core.experimental._launch_config import LaunchConfig
+from cuda.core.experimental._launcher import launch
 from cuda.core.experimental._linker import Linker, LinkerOptions
 from cuda.core.experimental._module import ObjectCode
 from cuda.core.experimental._program import Program, ProgramOptions
diff --git a/cuda_core/cuda/core/experimental/_launch_config.py b/cuda_core/cuda/core/experimental/_launch_config.py
new file mode 100644
index 000000000..bb4e92fb3
--- /dev/null
+++ b/cuda_core/cuda/core/experimental/_launch_config.py
@@ -0,0 +1,97 @@
+# Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. ALL RIGHTS RESERVED.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+from dataclasses import dataclass
+from typing import Optional, Union
+
+from cuda.core.experimental._device import Device
+from cuda.core.experimental._utils.cuda_utils import (
+    CUDAError,
+    cast_to_3_tuple,
+    driver,
+    get_binding_version,
+    handle_return,
+)
+
+# TODO: revisit this treatment for py313t builds
+_inited = False
+
+
+def _lazy_init():
+    global _inited
+    if _inited:
+        return
+
+    global _use_ex
+    # binding availability depends on cuda-python version
+    _py_major_minor = get_binding_version()
+    _driver_ver = handle_return(driver.cuDriverGetVersion())
+    _use_ex = (_driver_ver >= 11080) and (_py_major_minor >= (11, 8))
+    _inited = True
+
+
+@dataclass
+class LaunchConfig:
+    """Customizable launch options.
+
+    Attributes
+    ----------
+    grid : Union[tuple, int]
+        Collection of threads that will execute a kernel function.
+    cluster : Union[tuple, int]
+        Group of blocks (Thread Block Cluster) that will execute on the same
+        GPU Processing Cluster (GPC). Blocks within a cluster have access to
+        distributed shared memory and can be explicitly synchronized.
+    block : Union[tuple, int]
+        Group of threads (Thread Block) that will execute on the same
+        streaming multiprocessor (SM). Threads within a thread block have
+        access to shared memory and can be explicitly synchronized.
+    shmem_size : int, optional
+        Dynamic shared-memory size per thread block in bytes.
+        (Defaults to 0.)
+
+    """
+
+    # TODO: expand LaunchConfig to include other attributes
+    grid: Union[tuple, int] = None
+    cluster: Union[tuple, int] = None
+    block: Union[tuple, int] = None
+    shmem_size: Optional[int] = None
+
+    def __post_init__(self):
+        _lazy_init()
+        self.grid = cast_to_3_tuple("LaunchConfig.grid", self.grid)
+        self.block = cast_to_3_tuple("LaunchConfig.block", self.block)
+        # thread block clusters are supported starting H100
+        if self.cluster is not None:
+            if not _use_ex:
+                err, drvers = driver.cuDriverGetVersion()
+                drvers_fmt = f" (got driver version {drvers})" if err == driver.CUresult.CUDA_SUCCESS else ""
+                raise CUDAError(f"thread block clusters require cuda.bindings & driver 11.8+{drvers_fmt}")
+            cc = Device().compute_capability
+            if cc < (9, 0):
+                raise CUDAError(
+                    f"thread block clusters are not supported on devices with compute capability < 9.0 (got {cc})"
+                )
+            self.cluster = cast_to_3_tuple("LaunchConfig.cluster", self.cluster)
+        if self.shmem_size is None:
+            self.shmem_size = 0
+
+
+def _to_native_launch_config(config: LaunchConfig) -> driver.CUlaunchConfig:
+    _lazy_init()
+    drv_cfg = driver.CUlaunchConfig()
+    drv_cfg.gridDimX, drv_cfg.gridDimY, drv_cfg.gridDimZ = config.grid
+    drv_cfg.blockDimX, drv_cfg.blockDimY, drv_cfg.blockDimZ = config.block
+    drv_cfg.sharedMemBytes = config.shmem_size
+    attrs = []  # TODO: support more attributes
+    if config.cluster:
+        attr = driver.CUlaunchAttribute()
+        attr.id = driver.CUlaunchAttributeID.CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION
+        dim = attr.value.clusterDim
+        dim.x, dim.y, dim.z = config.cluster
+        attrs.append(attr)
+    drv_cfg.numAttrs = len(attrs)
+    drv_cfg.attrs = attrs
+    return drv_cfg
diff --git a/cuda_core/cuda/core/experimental/_launcher.py b/cuda_core/cuda/core/experimental/_launcher.py
index 7eef9de13..72afb5ffd 100644
--- a/cuda_core/cuda/core/experimental/_launcher.py
+++ b/cuda_core/cuda/core/experimental/_launcher.py
@@ -2,17 +2,13 @@
 #
 # SPDX-License-Identifier: Apache-2.0
 
-from dataclasses import dataclass
-from typing import Optional, Union
-
-from cuda.core.experimental._device import Device
 from cuda.core.experimental._kernel_arg_handler import ParamHolder
+from cuda.core.experimental._launch_config import LaunchConfig, _to_native_launch_config
 from cuda.core.experimental._module import Kernel
 from cuda.core.experimental._stream import Stream
 from cuda.core.experimental._utils.clear_error_support import assert_type
 from cuda.core.experimental._utils.cuda_utils import (
-    CUDAError,
-    cast_to_3_tuple,
     check_or_create_options,
     driver,
     get_binding_version,
@@ -37,54 +33,6 @@ def _lazy_init():
     _inited = True
 
 
-@dataclass
-class LaunchConfig:
-    """Customizable launch options.
-
-    Attributes
-    ----------
-    grid : Union[tuple, int]
-        Collection of threads that will execute a kernel function.
-    cluster : Union[tuple, int]
-        Group of blocks (Thread Block Cluster) that will execute on the same
-        GPU Processing Cluster (GPC). Blocks within a cluster have access to
-        distributed shared memory and can be explicitly synchronized.
-    block : Union[tuple, int]
-        Group of threads (Thread Block) that will execute on the same
-        streaming multiprocessor (SM). Threads within a thread blocks have
-        access to shared memory and can be explicitly synchronized.
-    shmem_size : int, optional
-        Dynamic shared-memory size per thread block in bytes.
-        (Default to size 0)
-
-    """
-
-    # TODO: expand LaunchConfig to include other attributes
-    grid: Union[tuple, int] = None
-    cluster: Union[tuple, int] = None
-    block: Union[tuple, int] = None
-    shmem_size: Optional[int] = None
-
-    def __post_init__(self):
-        _lazy_init()
-        self.grid = cast_to_3_tuple("LaunchConfig.grid", self.grid)
-        self.block = cast_to_3_tuple("LaunchConfig.block", self.block)
-        # thread block clusters are supported starting H100
-        if self.cluster is not None:
-            if not _use_ex:
-                err, drvers = driver.cuDriverGetVersion()
-                drvers_fmt = f" (got driver version {drvers})" if err == driver.CUresult.CUDA_SUCCESS else ""
-                raise CUDAError(f"thread block clusters require cuda.bindings & driver 11.8+{drvers_fmt}")
-            cc = Device().compute_capability
-            if cc < (9, 0):
-                raise CUDAError(
-                    f"thread block clusters are not supported on devices with compute capability < 9.0 (got {cc})"
-                )
-            self.cluster = cast_to_3_tuple("LaunchConfig.cluster", self.cluster)
-        if self.shmem_size is None:
-            self.shmem_size = 0
-
-
 def launch(stream, config, kernel, *kernel_args):
     """Launches a :obj:`~_module.Kernel` object with launch-time configuration.
 
@@ -114,6 +62,7 @@ def launch(stream, config, kernel, *kernel_args):
             f"stream must either be a Stream object or support __cuda_stream__ (got {type(stream)})"
         ) from e
     assert_type(kernel, Kernel)
+    _lazy_init()
     config = check_or_create_options(LaunchConfig, config, "launch config")
 
     # TODO: can we ensure kernel_args is valid/safe to use here?
@@ -127,25 +76,13 @@ def launch(stream, config, kernel, *kernel_args):
     # mainly to see if the "Ex" API is available and if so we use it, as it's more feature
     # rich.
     if _use_ex:
-        drv_cfg = driver.CUlaunchConfig()
-        drv_cfg.gridDimX, drv_cfg.gridDimY, drv_cfg.gridDimZ = config.grid
-        drv_cfg.blockDimX, drv_cfg.blockDimY, drv_cfg.blockDimZ = config.block
+        drv_cfg = _to_native_launch_config(config)
         drv_cfg.hStream = stream.handle
-        drv_cfg.sharedMemBytes = config.shmem_size
-        attrs = []  # TODO: support more attributes
-        if config.cluster:
-            attr = driver.CUlaunchAttribute()
-            attr.id = driver.CUlaunchAttributeID.CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION
-            dim = attr.value.clusterDim
-            dim.x, dim.y, dim.z = config.cluster
-            attrs.append(attr)
-        drv_cfg.numAttrs = len(attrs)
-        drv_cfg.attrs = attrs
         handle_return(driver.cuLaunchKernelEx(drv_cfg, int(kernel._handle), args_ptr, 0))
     else:
         # TODO: check if config has any unsupported attrs
         handle_return(
             driver.cuLaunchKernel(
-                int(kernel._handle), *config.grid, *config.block, config.shmem_size, stream._handle, args_ptr, 0
+                int(kernel._handle), *config.grid, *config.block, config.shmem_size, stream.handle, args_ptr, 0
             )
         )
diff --git a/cuda_core/cuda/core/experimental/_module.py b/cuda_core/cuda/core/experimental/_module.py
index 29b810382..9c80c687b 100644
--- a/cuda_core/cuda/core/experimental/_module.py
+++ b/cuda_core/cuda/core/experimental/_module.py
@@ -6,6 +6,8 @@
 from typing import Optional, Union
 from warnings import warn
 
+from cuda.core.experimental._launch_config import LaunchConfig, _to_native_launch_config
+from cuda.core.experimental._stream import Stream
 from cuda.core.experimental._utils.clear_error_support import (
     assert_type,
     assert_type_str_or_bytes,
@@ -184,6 +186,170 @@ def cluster_scheduling_policy_preference(self, device_id: int = None) -> int:
         )
 
 
+MaxPotentialBlockSizeOccupancyResult = namedtuple("MaxPotential", ("min_grid_size", "max_block_size"))
+
+
+class KernelOccupancy:
+    """Occupancy-related queries for a kernel, accessed via :obj:`Kernel.occupancy`."""
+
+    def __new__(self, *args, **kwargs):
+        raise RuntimeError("KernelOccupancy cannot be instantiated directly. Please use Kernel APIs.")
RuntimeError("KernelOccupancy cannot be instantiated directly. Please use Kernel APIs.") + + slots = ("_handle",) + + @classmethod + def _init(cls, handle): + self = super().__new__(cls) + self._handle = handle + + return self + + def max_active_blocks_per_multiprocessor(self, block_size: int, dynamic_shared_memory_size: int) -> int: + """Occupancy of the kernel. + + Returns the maximum number of active blocks per multiprocessor for this kernel. + + Parameters + ---------- + block_size: int + Block size parameter used to launch this kernel. + dynamic_shared_memory_size: int + The amount of dynamic shared memory in bytes needed by block. + Use `0` if block does not need shared memory. + + Returns + ------- + int + The maximum number of active blocks per multiprocessor. + + Note + ---- + The fraction of the product of maximum number of active blocks per multiprocessor + and the block size to the maximum number of threads per multiprocessor is known as + theoretical multiprocessor utilization (occupancy). + + """ + return handle_return( + driver.cuOccupancyMaxActiveBlocksPerMultiprocessor(self._handle, block_size, dynamic_shared_memory_size) + ) + + def max_potential_block_size( + self, dynamic_shared_memory_needed: Union[int, driver.CUoccupancyB2DSize], block_size_limit: int + ) -> MaxPotentialBlockSizeOccupancyResult: + """MaxPotentialBlockSizeOccupancyResult: Suggested launch configuration for reasonable occupancy. + + Returns the minimum grid size needed to achieve the maximum occupancy and + the maximum block size that can achieve the maximum occupancy. + + Parameters + ---------- + dynamic_shared_memory_needed: Union[int, driver.CUoccupancyB2DSize] + The amount of dynamic shared memory in bytes needed by block. + Use `0` if block does not need shared memory. Use C-callable + represented by :obj:`~driver.CUoccupancyB2DSize` to encode + amount of needed dynamic shared memory which varies depending + on tne block size. + block_size_limit: int + Known upper limit on the kernel block size. Use `0` to indicate + the maximum block size permitted by the device / kernel instead + + Returns + ------- + :obj:`~MaxPotentialBlockSizeOccupancyResult` + An object with `min_grid_size` amd `max_block_size` attributes encoding + the suggested launch configuration. + + Note + ---- + Please be advised that use of C-callable that requires Python Global + Interpreter Lock may lead to deadlocks. + + """ + if isinstance(dynamic_shared_memory_needed, int): + min_grid_size, max_block_size = handle_return( + driver.cuOccupancyMaxPotentialBlockSize( + self._handle, None, dynamic_shared_memory_needed, block_size_limit + ) + ) + elif isinstance(dynamic_shared_memory_needed, driver.CUoccupancyB2DSize): + min_grid_size, max_block_size = handle_return( + driver.cuOccupancyMaxPotentialBlockSize( + self._handle, dynamic_shared_memory_needed.getPtr(), 0, block_size_limit + ) + ) + else: + raise TypeError( + "dynamic_shared_memory_needed expected to have type int, or CUoccupancyB2DSize, " + f"got {type(dynamic_shared_memory_needed)}" + ) + return MaxPotentialBlockSizeOccupancyResult(min_grid_size=min_grid_size, max_block_size=max_block_size) + + def available_dynamic_shared_memory_per_block(self, num_blocks_per_multiprocessor: int, block_size: int) -> int: + """Dynamic shared memory available per block for given launch configuration. + + The amount of dynamic shared memory per block, in bytes, for given kernel launch configuration. 
+
+        Parameters
+        ----------
+        num_blocks_per_multiprocessor: int
+            Number of blocks to be concurrently executing on a multiprocessor.
+        block_size: int
+            Block size parameter used to launch this kernel.
+
+        Returns
+        -------
+        int
+            Dynamic shared memory available per block for the given launch configuration.
+        """
+        return handle_return(
+            driver.cuOccupancyAvailableDynamicSMemPerBlock(self._handle, num_blocks_per_multiprocessor, block_size)
+        )
+
+    def max_potential_cluster_size(self, config: LaunchConfig, stream: Optional[Stream] = None) -> int:
+        """Maximum potential cluster size.
+
+        The maximum potential cluster size for this kernel and the given launch configuration.
+
+        Parameters
+        ----------
+        config: :obj:`~_launch_config.LaunchConfig`
+            Kernel launch configuration. Cluster dimensions in the configuration are ignored.
+        stream: :obj:`~Stream`, optional
+            The stream on which this kernel is to be launched.
+
+        Returns
+        -------
+        int
+            The maximum cluster size that can be launched for this kernel and launch configuration.
+        """
+        drv_cfg = _to_native_launch_config(config)
+        if stream is not None:
+            drv_cfg.hStream = stream.handle
+        return handle_return(driver.cuOccupancyMaxPotentialClusterSize(self._handle, drv_cfg))
+
+    def max_active_clusters(self, config: LaunchConfig, stream: Optional[Stream] = None) -> int:
+        """Maximum number of active clusters on the target device.
+
+        The maximum number of clusters that could concurrently execute on the target device.
+
+        Parameters
+        ----------
+        config: :obj:`~_launch_config.LaunchConfig`
+            Kernel launch configuration.
+        stream: :obj:`~Stream`, optional
+            The stream on which this kernel is to be launched.
+
+        Returns
+        -------
+        int
+            The maximum number of clusters that could co-exist on the target device.
+        """
+        drv_cfg = _to_native_launch_config(config)
+        if stream is not None:
+            drv_cfg.hStream = stream.handle
+        return handle_return(driver.cuOccupancyMaxActiveClusters(self._handle, drv_cfg))
+
+
 ParamInfo = namedtuple("ParamInfo", ["offset", "size"])
 
 
@@ -198,7 +364,7 @@ class Kernel:
 
     """
 
-    __slots__ = ("_handle", "_module", "_attributes")
+    __slots__ = ("_handle", "_module", "_attributes", "_occupancy")
 
     def __new__(self, *args, **kwargs):
        raise RuntimeError("Kernel objects cannot be instantiated directly. Please use ObjectCode APIs.")
Please use ObjectCode APIs.") @@ -211,6 +377,7 @@ def _from_obj(cls, obj, mod): ker._handle = obj ker._module = mod ker._attributes = None + ker._occupancy = None return ker @property @@ -250,6 +417,13 @@ def arguments_info(self) -> list[ParamInfo]: _, param_info = self._get_arguments_info(param_info=True) return param_info + @property + def occupancy(self) -> KernelOccupancy: + """Get the occupancy information for launching this kernel.""" + if self._occupancy is None: + self._occupancy = KernelOccupancy._init(self._handle) + return self._occupancy + # TODO: implement from_handle() diff --git a/cuda_core/docs/source/api_private.rst b/cuda_core/docs/source/api_private.rst index a38cc2602..94f5a5df1 100644 --- a/cuda_core/docs/source/api_private.rst +++ b/cuda_core/docs/source/api_private.rst @@ -19,7 +19,11 @@ CUDA runtime _stream.Stream _event.Event _device.DeviceProperties + _launch_config.LaunchConfig _module.KernelAttributes + _module.KernelOccupancy + _module.ParamInfo + _module.MaxPotentialBlockSizeOccupancyResult CUDA compilation toolchain @@ -29,3 +33,4 @@ CUDA compilation toolchain :toctree: generated/ _module.Kernel + _module.ObjectCode diff --git a/cuda_core/docs/source/release/0.3.0-notes.rst b/cuda_core/docs/source/release/0.3.0-notes.rst index eb365b4cb..df1f21ffd 100644 --- a/cuda_core/docs/source/release/0.3.0-notes.rst +++ b/cuda_core/docs/source/release/0.3.0-notes.rst @@ -21,6 +21,7 @@ New features ------------ - :class:`Kernel` adds :property:`Kernel.num_arguments` and :property:`Kernel.arguments_info` for introspection of kernel arguments. (#612) +- Add pythonic access to kernel occupancy calculation functions via :property:`Kernel.occupancy`. (#648) New examples ------------ diff --git a/cuda_core/tests/test_module.py b/cuda_core/tests/test_module.py index d85a4745e..36b60b3bc 100644 --- a/cuda_core/tests/test_module.py +++ b/cuda_core/tests/test_module.py @@ -9,9 +9,14 @@ from conftest import skipif_testing_with_compute_sanitizer import cuda.core.experimental -from cuda.core.experimental import ObjectCode, Program, ProgramOptions, system +from cuda.core.experimental import Device, ObjectCode, Program, ProgramOptions, system from cuda.core.experimental._utils.cuda_utils import CUDAError, driver, get_binding_version, handle_return +try: + import numba +except ImportError: + numba = None + SAXPY_KERNEL = r""" template __global__ void saxpy(const T a, @@ -41,6 +46,11 @@ def test_kernel_attributes_init_disabled(): cuda.core.experimental._module.KernelAttributes() # Ensure back door is locked. +def test_kernel_occupancy_init_disabled(): + with pytest.raises(RuntimeError, match=r"^KernelOccupancy cannot be instantiated directly\."): + cuda.core.experimental._module.KernelOccupancy() # Ensure back door is locked. + + def test_kernel_init_disabled(): with pytest.raises(RuntimeError, match=r"^Kernel objects cannot be instantiated directly\."): cuda.core.experimental._module.Kernel() # Ensure back door is locked. 
@@ -248,6 +258,118 @@ def test_num_args_error_handling(deinit_all_contexts_function, cuda12_prerequisi
         _ = krn.num_arguments
 
 
+@pytest.mark.parametrize("block_size", [32, 64, 96, 120, 128, 256])
+@pytest.mark.parametrize("smem_size_per_block", [0, 32, 4096])
+def test_occupancy_max_active_block_per_multiprocessor(get_saxpy_kernel, block_size, smem_size_per_block):
+    kernel, _ = get_saxpy_kernel
+    dev_props = Device().properties
+    assert block_size <= dev_props.max_threads_per_block
+    assert smem_size_per_block <= dev_props.max_shared_memory_per_block
+    num_blocks_per_sm = kernel.occupancy.max_active_blocks_per_multiprocessor(block_size, smem_size_per_block)
+    assert isinstance(num_blocks_per_sm, int)
+    assert num_blocks_per_sm > 0
+    kernel_threads_per_sm = num_blocks_per_sm * block_size
+    kernel_smem_size_per_sm = num_blocks_per_sm * smem_size_per_block
+    assert kernel_threads_per_sm <= dev_props.max_threads_per_multiprocessor
+    assert kernel_smem_size_per_sm <= dev_props.max_shared_memory_per_multiprocessor
+    assert kernel.attributes.num_regs() * num_blocks_per_sm <= dev_props.max_registers_per_multiprocessor
+
+
+@pytest.mark.parametrize("block_size_limit", [32, 64, 96, 120, 128, 256, 0])
+@pytest.mark.parametrize("smem_size_per_block", [0, 32, 4096])
+def test_occupancy_max_potential_block_size_constant(get_saxpy_kernel, block_size_limit, smem_size_per_block):
+    """Tests the use case where the shared memory needed is independent of the block size."""
+    kernel, _ = get_saxpy_kernel
+    dev_props = Device().properties
+    assert block_size_limit <= dev_props.max_threads_per_block
+    assert smem_size_per_block <= dev_props.max_shared_memory_per_block
+    config_data = kernel.occupancy.max_potential_block_size(smem_size_per_block, block_size_limit)
+    assert isinstance(config_data, tuple)
+    assert len(config_data) == 2
+    min_grid_size, max_block_size = config_data
+    assert isinstance(min_grid_size, int)
+    assert isinstance(max_block_size, int)
+    assert min_grid_size > 0
+    assert max_block_size > 0
+    if block_size_limit > 0:
+        assert max_block_size <= block_size_limit
+    else:
+        assert max_block_size <= dev_props.max_threads_per_block
+    assert min_grid_size == config_data.min_grid_size
+    assert max_block_size == config_data.max_block_size
+    invalid_dsmem = Ellipsis
+    with pytest.raises(TypeError):
+        kernel.occupancy.max_potential_block_size(invalid_dsmem, block_size_limit)
+
+
+@pytest.mark.skipif(numba is None, reason="Test requires numba to be installed")
+@pytest.mark.parametrize("block_size_limit", [32, 64, 96, 120, 128, 277, 0])
+def test_occupancy_max_potential_block_size_b2dsize(get_saxpy_kernel, block_size_limit):
+    """Tests the use case where the shared memory needed depends on the block size."""
+    kernel, _ = get_saxpy_kernel
+
+    def shared_memory_needed(block_size: numba.intc) -> numba.size_t:
+        "Size of dynamic shared memory needed by a kernel with this block size"
+        return 1024 * (block_size // 32)
+
+    b2dsize_sig = numba.size_t(numba.intc)
+    dsmem_needed_cfunc = numba.cfunc(b2dsize_sig)(shared_memory_needed)
+    fn_ptr = ctypes.cast(dsmem_needed_cfunc.ctypes, ctypes.c_void_p).value
+    b2dsize_fn = driver.CUoccupancyB2DSize(_ptr=fn_ptr)
+    config_data = kernel.occupancy.max_potential_block_size(b2dsize_fn, block_size_limit)
+    dev_props = Device().properties
+    assert block_size_limit <= dev_props.max_threads_per_block
+    min_grid_size, max_block_size = config_data
+    assert isinstance(min_grid_size, int)
+    assert isinstance(max_block_size, int)
+    assert min_grid_size > 0
+    assert max_block_size > 0
+    if block_size_limit > 0:
+        assert max_block_size <= block_size_limit
+    else:
+        assert max_block_size <= dev_props.max_threads_per_block
+
+
+@pytest.mark.parametrize("num_blocks_per_sm, block_size", [(4, 32), (2, 64), (2, 96), (3, 120), (2, 128), (1, 256)])
+def test_occupancy_available_dynamic_shared_memory_per_block(get_saxpy_kernel, num_blocks_per_sm, block_size):
+    kernel, _ = get_saxpy_kernel
+    dev_props = Device().properties
+    assert block_size <= dev_props.max_threads_per_block
+    assert num_blocks_per_sm * block_size <= dev_props.max_threads_per_multiprocessor
+    smem_size = kernel.occupancy.available_dynamic_shared_memory_per_block(num_blocks_per_sm, block_size)
+    assert smem_size <= dev_props.max_shared_memory_per_block
+    assert num_blocks_per_sm * smem_size <= dev_props.max_shared_memory_per_multiprocessor
+
+
+@pytest.mark.parametrize("cluster", [None, 2])
+def test_occupancy_max_active_clusters(get_saxpy_kernel, cluster):
+    kernel, _ = get_saxpy_kernel
+    dev = Device()
+    if cluster and dev.compute_capability < (9, 0):
+        pytest.skip("Device with compute capability 9.0 or higher is required for cluster support")
+    launch_config = cuda.core.experimental.LaunchConfig(grid=128, block=64, cluster=cluster)
+    query_fn = kernel.occupancy.max_active_clusters
+    max_active_clusters = query_fn(launch_config)
+    assert isinstance(max_active_clusters, int)
+    assert max_active_clusters >= 0
+    max_active_clusters = query_fn(launch_config, stream=dev.default_stream)
+    assert isinstance(max_active_clusters, int)
+    assert max_active_clusters >= 0
+
+
+def test_occupancy_max_potential_cluster_size(get_saxpy_kernel):
+    kernel, _ = get_saxpy_kernel
+    dev = Device()
+    launch_config = cuda.core.experimental.LaunchConfig(grid=128, block=64)
+    query_fn = kernel.occupancy.max_potential_cluster_size
+    max_potential_cluster_size = query_fn(launch_config)
+    assert isinstance(max_potential_cluster_size, int)
+    assert max_potential_cluster_size >= 0
+    max_potential_cluster_size = query_fn(launch_config, stream=dev.default_stream)
+    assert isinstance(max_potential_cluster_size, int)
+    assert max_potential_cluster_size >= 0
+
+
 def test_module_serialization_roundtrip(get_saxpy_kernel):
     _, objcode = get_saxpy_kernel
     result = pickle.loads(pickle.dumps(objcode))  # nosec B403, B301
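For reviewers, a minimal end-to-end sketch of how the new `kernel.occupancy` API composes with the relocated `LaunchConfig`. The `scale` kernel, variable names, and the `ProgramOptions(arch=...)` value below are illustrative assumptions, not part of this PR; only the `KernelOccupancy` calls and `LaunchConfig` usage come from the changes above.

```python
# Hypothetical sketch -- the "scale" kernel and option values are assumptions,
# not from this PR; the occupancy queries mirror the new KernelOccupancy API.
from cuda.core.experimental import Device, LaunchConfig, Program, ProgramOptions

code = r"""
extern "C" __global__ void scale(float a, float* x, size_t n) {
    size_t i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) x[i] *= a;
}
"""

dev = Device()
dev.set_current()

cc = dev.compute_capability
prog = Program(code, code_type="c++", options=ProgramOptions(arch=f"sm_{cc.major}{cc.minor}"))
kernel = prog.compile("cubin").get_kernel("scale")

# Driver-suggested block size for best occupancy: 0 bytes of dynamic shared
# memory, no upper limit on the block size.
min_grid_size, block_size = kernel.occupancy.max_potential_block_size(0, 0)

# Cross-check the theoretical occupancy: blocks of that size resident per SM.
blocks_per_sm = kernel.occupancy.max_active_blocks_per_multiprocessor(block_size, 0)
print(f"block size {block_size}: {blocks_per_sm} resident blocks/SM")

# The suggestion feeds straight into the relocated LaunchConfig; an actual
# launch(stream, config, kernel, ...) would follow as in the existing examples.
n = 1 << 20
config = LaunchConfig(grid=(n + block_size - 1) // block_size, block=block_size)
```

On devices with compute capability 9.0 or higher, the same pattern extends to the cluster queries (`max_potential_cluster_size`, `max_active_clusters`) by passing a `LaunchConfig`, matching the new tests above.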