Skip to content

Commit 708fd70

Browse files
committed
fix type hint; compare against enums
1 parent eb82b80 commit 708fd70

File tree

1 file changed

+7
-3
lines changed

1 file changed

+7
-3
lines changed

cuda_core/cuda/core/experimental/_device.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -931,6 +931,10 @@ def multicast_supported(self) -> bool:
931931
return bool(self._get_cached_attribute(driver.CUdevice_attribute.CU_DEVICE_ATTRIBUTE_MULTICAST_SUPPORTED))
932932

933933

934+
_SUCCESS = driver.CUresult.CUDA_SUCCESS
935+
_INVALID_CTX = driver.CUresult.CUDA_ERROR_INVALID_CONTEXT
936+
937+
934938
class Device:
935939
"""Represent a GPU and act as an entry point for cuda.core features.
936940
@@ -960,7 +964,7 @@ class Device:
960964

961965
__slots__ = ("_id", "_mr", "_has_inited", "_properties")
962966

963-
def __new__(cls, device_id: int = None):
967+
def __new__(cls, device_id: Optional[int] = None):
964968
global _is_cuInit
965969
if _is_cuInit is False:
966970
with _lock:
@@ -970,9 +974,9 @@ def __new__(cls, device_id: int = None):
970974
# important: creating a Device instance does not initialize the GPU!
971975
if device_id is None:
972976
err, dev = driver.cuCtxGetDevice()
973-
if err == 0:
977+
if err == _SUCCESS:
974978
device_id = int(dev)
975-
elif err == 201: # CUDA_ERROR_INVALID_CONTEXT
979+
elif err == _INVALID_CTX:
976980
ctx = handle_return(driver.cuCtxGetCurrent())
977981
assert int(ctx) == 0
978982
device_id = 0 # cudart behavior

0 commit comments

Comments
 (0)