Skip to content

Improve perf of accessing dev.compute_capability #459

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Feb 25, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 33 additions & 29 deletions cuda_core/cuda/core/experimental/_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@
from cuda.core.experimental._utils import ComputeCapability, CUDAError, driver, handle_return, precondition, runtime

_tls = threading.local()
_tls_lock = threading.Lock()
_lock = threading.Lock()
_is_cuInit = False


class DeviceProperties:
Expand Down Expand Up @@ -938,6 +939,12 @@ class Device:
__slots__ = ("_id", "_mr", "_has_inited", "_properties")

def __new__(cls, device_id=None):
global _is_cuInit
if _is_cuInit is False:
with _lock:
handle_return(driver.cuInit(0))
_is_cuInit = True

# important: creating a Device instance does not initialize the GPU!
if device_id is None:
device_id = handle_return(runtime.cudaGetDevice())
Expand All @@ -948,27 +955,26 @@ def __new__(cls, device_id=None):
raise ValueError(f"device_id must be within [0, {total}), got {device_id}")

# ensure Device is singleton
with _tls_lock:
if not hasattr(_tls, "devices"):
total = handle_return(runtime.cudaGetDeviceCount())
_tls.devices = []
for dev_id in range(total):
dev = super().__new__(cls)
dev._id = dev_id
# If the device is in TCC mode, or does not support memory pools for some other reason,
# use the SynchronousMemoryResource which does not use memory pools.
if (
handle_return(
runtime.cudaDeviceGetAttribute(runtime.cudaDeviceAttr.cudaDevAttrMemoryPoolsSupported, 0)
)
) == 1:
dev._mr = _DefaultAsyncMempool(dev_id)
else:
dev._mr = _SynchronousMemoryResource(dev_id)

dev._has_inited = False
dev._properties = None
_tls.devices.append(dev)
if not hasattr(_tls, "devices"):
total = handle_return(runtime.cudaGetDeviceCount())
_tls.devices = []
for dev_id in range(total):
dev = super().__new__(cls)
dev._id = dev_id
# If the device is in TCC mode, or does not support memory pools for some other reason,
# use the SynchronousMemoryResource which does not use memory pools.
if (
handle_return(
runtime.cudaDeviceGetAttribute(runtime.cudaDeviceAttr.cudaDevAttrMemoryPoolsSupported, 0)
)
) == 1:
dev._mr = _DefaultAsyncMempool(dev_id)
else:
dev._mr = _SynchronousMemoryResource(dev_id)

dev._has_inited = False
dev._properties = None
_tls.devices.append(dev)

return _tls.devices[device_id]

Expand Down Expand Up @@ -1029,13 +1035,11 @@ def properties(self) -> DeviceProperties:
@property
def compute_capability(self) -> ComputeCapability:
    """Return a named tuple with 2 fields: major and minor.

    The result is memoized in the DeviceProperties cache so that repeated
    accesses do not re-query device attributes.
    """
    # Fast path: the tuple was already built and cached under its own key.
    if "compute_capability" in self.properties._cache:
        return self.properties._cache["compute_capability"]
    # First access: assemble the tuple from the individual major/minor
    # attribute properties (themselves cached by DeviceProperties), then
    # memoize the composed result for subsequent lookups.
    cc = ComputeCapability(self.properties.compute_capability_major, self.properties.compute_capability_minor)
    self.properties._cache["compute_capability"] = cc
    return cc

@property
@precondition(_check_context_initialized)
Expand Down
3 changes: 1 addition & 2 deletions cuda_core/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,7 @@ def _device_unset_current():
return
handle_return(driver.cuCtxPopCurrent())
if hasattr(_device._tls, "devices"):
with _device._tls_lock:
del _device._tls.devices
del _device._tls.devices


@pytest.fixture(scope="function")
Expand Down
Loading