File tree Expand file tree Collapse file tree 1 file changed +7
-3
lines changed
cuda_core/cuda/core/experimental Expand file tree Collapse file tree 1 file changed +7
-3
lines changed Original file line number Diff line number Diff line change @@ -931,6 +931,10 @@ def multicast_supported(self) -> bool:
931
931
return bool (self ._get_cached_attribute (driver .CUdevice_attribute .CU_DEVICE_ATTRIBUTE_MULTICAST_SUPPORTED ))
932
932
933
933
934
+ _SUCCESS = driver .CUresult .CUDA_SUCCESS
935
+ _INVALID_CTX = driver .CUresult .CUDA_ERROR_INVALID_CONTEXT
936
+
937
+
934
938
class Device :
935
939
"""Represent a GPU and act as an entry point for cuda.core features.
936
940
@@ -960,7 +964,7 @@ class Device:
960
964
961
965
__slots__ = ("_id" , "_mr" , "_has_inited" , "_properties" )
962
966
963
- def __new__ (cls , device_id : int = None ):
967
+ def __new__ (cls , device_id : Optional [ int ] = None ):
964
968
global _is_cuInit
965
969
if _is_cuInit is False :
966
970
with _lock :
@@ -970,9 +974,9 @@ def __new__(cls, device_id: int = None):
970
974
# important: creating a Device instance does not initialize the GPU!
971
975
if device_id is None :
972
976
err , dev = driver .cuCtxGetDevice ()
973
- if err == 0 :
977
+ if err == _SUCCESS :
974
978
device_id = int (dev )
975
- elif err == 201 : # CUDA_ERROR_INVALID_CONTEXT
979
+ elif err == _INVALID_CTX :
976
980
ctx = handle_return (driver .cuCtxGetCurrent ())
977
981
assert int (ctx ) == 0
978
982
device_id = 0 # cudart behavior
You can’t perform that action at this time.
0 commit comments