Skip to content

Commit 1e77bc0

Browse files
Add occupancy tests, except for cluster-related queries
1 parent 42f63b4 commit 1e77bc0

File tree

1 file changed

+52
-1
lines changed

1 file changed

+52
-1
lines changed

cuda_core/tests/test_module.py

Lines changed: 52 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
import pytest
1313

1414
import cuda.core.experimental
15-
from cuda.core.experimental import ObjectCode, Program, ProgramOptions, system
15+
from cuda.core.experimental import Device, ObjectCode, Program, ProgramOptions, system
1616

1717
SAXPY_KERNEL = """
1818
template<typename T>
@@ -34,6 +34,11 @@ def test_kernel_attributes_init_disabled():
3434
cuda.core.experimental._module.KernelAttributes() # Ensure back door is locked.
3535

3636

37+
def test_kernel_occupancy_init_disabled():
    """Calling ``KernelOccupancy()`` directly must raise ``RuntimeError``."""
    expected_message = r"^KernelOccupancy cannot be instantiated directly\."
    with pytest.raises(RuntimeError, match=expected_message):
        # Ensure back door is locked.
        cuda.core.experimental._module.KernelOccupancy()
40+
41+
3742
def test_kernel_init_disabled():
    """Calling ``Kernel()`` directly must raise ``RuntimeError``."""
    expected_message = r"^Kernel objects cannot be instantiated directly\."
    with pytest.raises(RuntimeError, match=expected_message):
        # Ensure back door is locked.
        cuda.core.experimental._module.Kernel()
@@ -162,3 +167,49 @@ def test_object_code_load_cubin_from_file(get_saxpy_kernel, tmp_path):
162167
def test_object_code_handle(get_saxpy_object_code):
    """The object code supplied by the fixture must expose a non-``None`` handle."""
    assert get_saxpy_object_code.handle is not None
170+
171+
172+
@pytest.mark.parametrize("block_size", [32, 64, 96, 120, 128, 256])
@pytest.mark.parametrize("smem_size_per_block", [0, 32, 4096])
def test_saxpy_occupancy_max_active_block_per_multiprocessor(get_saxpy_kernel, block_size, smem_size_per_block):
    """Check that the reported active-blocks-per-SM count fits the device limits.

    For every (block_size, smem_size_per_block) combination the value returned by
    ``kernel.occupancy.max_active_blocks_per_multiprocessor`` must be a positive
    ``int``, and the threads, dynamic shared memory and registers implied by that
    many resident blocks must not exceed the per-multiprocessor limits reported
    by ``Device().properties``.
    """
    kernel, _ = get_saxpy_kernel
    dev_props = Device().properties

    # Sanity-check that the parametrized launch configuration is valid on this device.
    assert block_size <= dev_props.max_threads_per_block
    assert smem_size_per_block <= dev_props.max_shared_memory_per_block

    num_blocks_per_sm = kernel.occupancy.max_active_blocks_per_multiprocessor(block_size, smem_size_per_block)
    assert isinstance(num_blocks_per_sm, int)
    assert num_blocks_per_sm > 0

    kernel_threads_per_sm = num_blocks_per_sm * block_size
    kernel_smem_size_per_sm = num_blocks_per_sm * smem_size_per_block
    assert kernel_threads_per_sm <= dev_props.max_threads_per_multiprocessor
    assert kernel_smem_size_per_sm <= dev_props.max_shared_memory_per_multiprocessor

    # The CUDA NUM_REGS function attribute is a per-thread count, so the register
    # footprint on one SM scales with resident threads, not with resident blocks.
    kernel_regs_per_sm = kernel.attributes.num_regs() * kernel_threads_per_sm
    assert kernel_regs_per_sm <= dev_props.max_registers_per_multiprocessor
187+
188+
189+
@pytest.mark.parametrize("block_size_limit", [32, 64, 96, 120, 128, 256])
@pytest.mark.parametrize("smem_size_per_block", [0, 32, 4096])
def test_saxpy_occupancy_max_potential_block_size(get_saxpy_kernel, block_size_limit, smem_size_per_block):
    """Validate the shape and bounds of ``max_potential_block_size`` results.

    The call must return a 2-tuple of positive ``int`` values, unpacked here as
    (minimum grid size, maximum block size), and the block size must respect the
    requested ``block_size_limit``.
    """
    kernel, _unused = get_saxpy_kernel
    props = Device().properties

    # The parametrized inputs themselves must be valid for this device.
    assert block_size_limit <= props.max_threads_per_block
    assert smem_size_per_block <= props.max_shared_memory_per_block

    result = kernel.occupancy.max_potential_block_size(smem_size_per_block, block_size_limit)
    assert isinstance(result, tuple)
    assert len(result) == 2

    grid_size, block_size = result
    # Type checks first, then positivity, in the same order for both entries.
    for entry in (grid_size, block_size):
        assert isinstance(entry, int)
    for entry in (grid_size, block_size):
        assert entry > 0
    assert block_size <= block_size_limit
205+
206+
207+
@pytest.mark.parametrize("num_blocks_per_sm, block_size", [(4, 32), (2, 64), (2, 96), (3, 120), (2, 128), (1, 256)])
def test_saxpy_occupancy_available_dynamic_shared_memory_per_block(get_saxpy_kernel, num_blocks_per_sm, block_size):
    """Check the reported available dynamic shared memory against device limits.

    The value returned by ``available_dynamic_shared_memory_per_block`` must fit
    within the per-block shared memory limit, and ``num_blocks_per_sm`` times
    that value must fit within the per-multiprocessor limit.
    """
    kernel, _unused = get_saxpy_kernel
    props = Device().properties

    # The parametrized configuration must be launchable on this device.
    threads_per_sm = num_blocks_per_sm * block_size
    assert block_size <= props.max_threads_per_block
    assert threads_per_sm <= props.max_threads_per_multiprocessor

    available_smem = kernel.occupancy.available_dynamic_shared_memory_per_block(num_blocks_per_sm, block_size)
    assert available_smem <= props.max_shared_memory_per_block

    smem_per_sm = num_blocks_per_sm * available_smem
    assert smem_per_sm <= props.max_shared_memory_per_multiprocessor

0 commit comments

Comments
 (0)