77from typing import Optional , Union
88
99from cuda import cuda
10+ from cuda .core .experimental ._device import Device
1011from cuda .core .experimental ._kernel_arg_handler import ParamHolder
1112from cuda .core .experimental ._module import Kernel
1213from cuda .core .experimental ._stream import Stream
@@ -38,10 +39,14 @@ class LaunchConfig:
3839 ----------
3940 grid : Union[tuple, int]
4041 Collection of threads that will execute a kernel function.
42+ cluster : Union[tuple, int]
43+ Group of blocks (Thread Block Cluster) that will execute on the same
44+ GPU Processing Cluster (GPC). Blocks within a cluster have access to
45+ distributed shared memory and can be explicitly synchronized.
4146 block : Union[tuple, int]
4247 Group of threads (Thread Block) that will execute on the same
43- multiprocessor. Threads within a thread blocks have access to
44- shared memory and can be explicitly synchronized.
48+ streaming multiprocessor (SM). Threads within a thread block have
49+ access to shared memory and can be explicitly synchronized.
4550 stream : :obj:`Stream`
4651 The stream establishing the stream ordering semantic of a
4752 launch.
@@ -53,13 +58,22 @@ class LaunchConfig:
5358
5459 # TODO: expand LaunchConfig to include other attributes
5560 grid : Union [tuple , int ] = None
61+ cluster : Union [tuple , int ] = None
5662 block : Union [tuple , int ] = None
5763 stream : Stream = None
5864 shmem_size : Optional [int ] = None
5965
6066 def __post_init__ (self ):
67+ _lazy_init ()
6168 self .grid = self ._cast_to_3_tuple (self .grid )
6269 self .block = self ._cast_to_3_tuple (self .block )
70+ # thread block clusters are supported starting H100
71+ if self .cluster is not None :
72+ if not _use_ex :
73+ raise CUDAError ("thread block clusters require cuda.bindings & driver 11.8+" )
74+ if Device ().compute_capability < (9 , 0 ):
75+ raise CUDAError ("thread block clusters are not supported on devices with compute capability < 9.0" )
76+ self .cluster = self ._cast_to_3_tuple (self .cluster )
6377 # we handle "stream=None" in the launch API
6478 if self .stream is not None and not isinstance (self .stream , Stream ):
6579 try :
@@ -69,8 +83,6 @@ def __post_init__(self):
6983 if self .shmem_size is None :
7084 self .shmem_size = 0
7185
72- _lazy_init ()
73-
7486 def _cast_to_3_tuple (self , cfg ):
7587 if isinstance (cfg , int ):
7688 if cfg < 1 :
@@ -133,7 +145,15 @@ def launch(kernel, config, *kernel_args):
133145 drv_cfg .blockDimX , drv_cfg .blockDimY , drv_cfg .blockDimZ = config .block
134146 drv_cfg .hStream = config .stream .handle
135147 drv_cfg .sharedMemBytes = config .shmem_size
136- drv_cfg .numAttrs = 0 # TODO
148+ attrs = [] # TODO: support more attributes
149+ if config .cluster :
150+ attr = cuda .CUlaunchAttribute ()
151+ attr .id = cuda .CUlaunchAttributeID .CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION
152+ dim = attr .value .clusterDim
153+ dim .x , dim .y , dim .z = config .cluster
154+ attrs .append (attr )
155+ drv_cfg .numAttrs = len (attrs )
156+ drv_cfg .attrs = attrs
137157 handle_return (cuda .cuLaunchKernelEx (drv_cfg , int (kernel ._handle ), args_ptr , 0 ))
138158 else :
139159 # TODO: check if config has any unsupported attrs
0 commit comments