[misc] remove is_cuda_available #5319

Merged (13 commits) on Apr 21, 2025.

Changes from all commits:
4 changes: 2 additions & 2 deletions python/sglang/srt/layers/activation.py
@@ -28,9 +28,9 @@
     get_tensor_model_parallel_world_size,
 )
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
-from sglang.srt.utils import is_cuda_available, set_weight_attrs
+from sglang.srt.utils import is_cuda, set_weight_attrs

-_is_cuda = is_cuda_available()
+_is_cuda = is_cuda()

 if _is_cuda:
     from sgl_kernel import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul
@@ -3,10 +3,10 @@
 import triton.language as tl

 from sglang.srt.managers.schedule_batch import global_server_args_dict
-from sglang.srt.utils import is_hip
+from sglang.srt.utils import is_cuda, is_hip

-is_cuda_available = torch.cuda.is_available()
-if is_cuda_available:
+_is_cuda = is_cuda()
+if _is_cuda:
     CUDA_CAPABILITY = torch.cuda.get_device_capability()

 _is_hip = is_hip()

[Author comment] Mark: equivalent to is_cuda because it is used after _is_hip
@@ -1037,12 +1037,12 @@ def extend_attention_fwd(
         num_warps = 4

     else:
-        if is_cuda_available and CUDA_CAPABILITY[0] >= 9:
+        if _is_cuda and CUDA_CAPABILITY[0] >= 9:
             if Lq <= 256:
                 BLOCK_M, BLOCK_N = (128, 64)
             else:
                 BLOCK_M, BLOCK_N = (32, 64)
-        elif is_cuda_available and CUDA_CAPABILITY[0] >= 8:
+        elif _is_cuda and CUDA_CAPABILITY[0] >= 8:
             if Lq <= 128:
                 BLOCK_M, BLOCK_N = (128, 128)
             elif Lq <= 256:
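For readers unfamiliar with the helpers involved: torch.cuda.is_available() is a runtime check and is also True on ROCm builds (HIP is exposed through the torch.cuda namespace), while is_cuda() / is_hip() in sglang.srt.utils tell the two stacks apart. The sketch below is an assumption about their behavior (their bodies are not part of this diff); it only illustrates why the swap above is safe once the HIP path is already handled separately via _is_hip, as the author notes.

    import torch

    def is_cuda() -> bool:
        # Sketch (assumption): True only on CUDA builds of PyTorch.
        return torch.version.cuda is not None

    def is_hip() -> bool:
        # Sketch (assumption): True only on ROCm builds of PyTorch.
        return torch.version.hip is not None

    # torch.cuda.is_available() is True on both CUDA and ROCm machines with a
    # visible GPU, so it only matches is_cuda() after the HIP case has been
    # split off, which is what the surrounding code does with _is_hip.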
@@ -23,10 +23,10 @@
 from sglang.srt.layers.attention.triton_ops.prefill_attention import (
     context_attention_fwd,
 )
-from sglang.srt.utils import is_hip
+from sglang.srt.utils import is_cuda, is_hip

-is_cuda_available = torch.cuda.is_available()
-if is_cuda_available:
+_is_cuda = is_cuda()
+if _is_cuda:
     CUDA_CAPABILITY = torch.cuda.get_device_capability()

 _is_hip = is_hip()

[Author comment] Mark: equivalent to is_cuda because it is used after _is_hip
@@ -345,12 +345,12 @@ def extend_attention_fwd(
         num_warps = 4

     else:
-        if is_cuda_available and CUDA_CAPABILITY[0] >= 9:
+        if _is_cuda and CUDA_CAPABILITY[0] >= 9:
             if Lq <= 256:
                 BLOCK_M, BLOCK_N = (128, 64)
             else:
                 BLOCK_M, BLOCK_N = (32, 64)
-        elif is_cuda_available and CUDA_CAPABILITY[0] >= 8:
+        elif _is_cuda and CUDA_CAPABILITY[0] >= 8:
             # sm86/sm89 has a much smaller shared memory size (100K) than sm80 (160K)
             if CUDA_CAPABILITY[1] == 9 or CUDA_CAPABILITY[1] == 6:
                 if Lq <= 128:
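As context for the capability branches above, torch.cuda.get_device_capability() returns a (major, minor) tuple, so the thresholds map to GPU generations roughly as follows (example values, not part of this PR):

    import torch

    if torch.cuda.is_available():
        major, minor = torch.cuda.get_device_capability()
        # Typical values: (9, 0) on H100 (sm90), (8, 0) on A100 (sm80),
        # (8, 6) on RTX 3090 (sm86), (8, 9) on RTX 4090 (sm89).
        print(f"compute capability: sm{major}{minor}")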
@@ -22,8 +22,12 @@
 import triton
 import triton.language as tl

-is_cuda_available = torch.cuda.is_available()
-if is_cuda_available:
+from sglang.srt.utils import is_cuda, is_hip
+
+_is_cuda = is_cuda()
+_is_hip = is_hip()
+
+if _is_cuda or _is_hip:
     CUDA_CAPABILITY = torch.cuda.get_device_capability()

[Author comment] Maybe NOT equivalent
@@ -172,7 +176,7 @@ def context_attention_fwd(
     b_seq_len: [b]
     out: [b * s, head, head_dim]
     """
-    if is_cuda_available and CUDA_CAPABILITY[0] > 8:
+    if (_is_cuda or _is_hip) and CUDA_CAPABILITY[0] > 8:
         BLOCK = 128
     else:
         BLOCK = 64
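The "Maybe NOT equivalent" note above marks the one spot where a bare is_cuda() swap could change behavior: the old torch.cuda.is_available() guard was also True on ROCm, so the new code checks _is_cuda or _is_hip to keep covering both stacks. A minimal self-contained sketch of the resulting guard, under the same assumptions about the helpers as in the earlier sketch:

    import torch

    # Assumed helper semantics (not shown in this diff):
    _is_cuda = torch.version.cuda is not None  # CUDA build
    _is_hip = torch.version.hip is not None    # ROCm build

    if _is_cuda or _is_hip:
        # Matches the old torch.cuda.is_available() behavior on CUDA and ROCm,
        # whereas a bare _is_cuda check would skip this branch on ROCm.
        CUDA_CAPABILITY = torch.cuda.get_device_capability()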
4 changes: 2 additions & 2 deletions python/sglang/srt/layers/layernorm.py
@@ -20,9 +20,9 @@
 import torch.nn as nn

 from sglang.srt.custom_op import CustomOp
-from sglang.srt.utils import is_cuda_available
+from sglang.srt.utils import is_cuda

-_is_cuda = is_cuda_available()
+_is_cuda = is_cuda()

 if _is_cuda:
     from sgl_kernel import (

[Author comment] Mark: equivalent to is_cuda
4 changes: 2 additions & 2 deletions python/sglang/srt/layers/quantization/modelopt_quant.py
@@ -22,9 +22,9 @@
     requantize_with_max_scale,
 )
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.utils import is_cuda_available
+from sglang.srt.utils import is_cuda

-if is_cuda_available():
+if is_cuda():
     from sgl_kernel import cutlass_scaled_fp4_mm, scaled_fp4_quant

 # Initialize logger for the module

[Author comment] Mark: equivalent to is_cuda
6 changes: 3 additions & 3 deletions python/sglang/srt/layers/quantization/w8a8_int8.py
@@ -11,10 +11,10 @@
     QuantizeMethodBase,
 )
 from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8
-from sglang.srt.utils import is_cuda_available, set_weight_attrs
+from sglang.srt.utils import is_cuda, set_weight_attrs

-is_cuda = is_cuda_available()
-if is_cuda:
+_is_cuda = is_cuda()
+if _is_cuda:
     from sgl_kernel import int8_scaled_mm
12 changes: 6 additions & 6 deletions python/sglang/srt/layers/rotary_embedding.py
@@ -8,11 +8,11 @@
 import torch.nn as nn

 from sglang.srt.custom_op import CustomOp
-from sglang.srt.utils import is_cuda_available
+from sglang.srt.utils import is_cuda

-_is_cuda_available = is_cuda_available()
+_is_cuda = is_cuda()

-if _is_cuda_available:
+if _is_cuda:
     from sgl_kernel import apply_rope_with_cos_sin_cache_inplace
 else:
     from vllm._custom_ops import rotary_embedding as vllm_rotary_embedding

[Author comment] Mark: equivalent to is_cuda

@@ -82,7 +82,7 @@ def __init__(

         cache = self._compute_cos_sin_cache()
         # NOTE(ByronHsu): cache needs to be in FP32 for numerical stability
-        if not _is_cuda_available:
+        if not _is_cuda:
             cache = cache.to(dtype)
         self.cos_sin_cache: torch.Tensor
         self.register_buffer("cos_sin_cache", cache, persistent=False)

@@ -149,7 +149,7 @@ def forward_cuda(
         key: torch.Tensor,
         offsets: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-        if _is_cuda_available and (self.head_size in [64, 128, 256, 512]):
+        if _is_cuda and (self.head_size in [64, 128, 256, 512]):
             apply_rope_with_cos_sin_cache_inplace(
                 positions=positions,
                 query=query,

@@ -652,7 +652,7 @@ def forward_hip(self, *args, **kwargs):
     def forward(self, *args, **kwargs):
         if torch.compiler.is_compiling():
             return self.forward_native(*args, **kwargs)
-        if _is_cuda_available:
+        if _is_cuda:
             return self.forward_cuda(*args, **kwargs)
         else:
             return self.forward_native(*args, **kwargs)
4 changes: 2 additions & 2 deletions python/sglang/srt/layers/sampler.py
@@ -10,9 +10,9 @@
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
-from sglang.srt.utils import crash_on_warnings, get_bool_env_var, is_cuda_available
+from sglang.srt.utils import crash_on_warnings, get_bool_env_var, is_cuda

-if is_cuda_available():
+if is_cuda():
     from sgl_kernel import (
         min_p_sampling_from_probs,
         top_k_renorm_prob,

[Author comment] Mark: equivalent to is_cuda
4 changes: 2 additions & 2 deletions python/sglang/srt/models/minicpm3.py
@@ -40,9 +40,9 @@
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
-from sglang.srt.utils import add_prefix, is_cuda_available
+from sglang.srt.utils import add_prefix, is_cuda

-if is_cuda_available():
+if is_cuda():
     from sgl_kernel import bmm_fp8

[Author comment] Mark: equivalent to is_cuda
4 changes: 2 additions & 2 deletions python/sglang/srt/speculative/build_eagle_tree.py
@@ -4,9 +4,9 @@

 import torch

-from sglang.srt.utils import is_cuda_available, is_hip
+from sglang.srt.utils import is_cuda, is_hip

-if is_cuda_available() or is_hip():
+if is_cuda() or is_hip():
     from sgl_kernel import (
         build_tree_kernel_efficient as sgl_build_tree_kernel_efficient,
     )

[Author comment] Mark: equivalent to is_cuda
4 changes: 2 additions & 2 deletions python/sglang/srt/speculative/eagle_utils.py
@@ -19,9 +19,9 @@
 from sglang.srt.mem_cache.memory_pool import TokenToKVPoolAllocator
 from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode
 from sglang.srt.speculative.build_eagle_tree import build_tree_kernel_efficient
-from sglang.srt.utils import fast_topk, is_cuda_available, is_hip, next_power_of_2
+from sglang.srt.utils import fast_topk, is_cuda, is_hip, next_power_of_2

-if is_cuda_available():
+if is_cuda():
     from sgl_kernel import (
         top_k_renorm_prob,
         top_p_renorm_prob,

[Author comment] Mark: equivalent to is_cuda
9 changes: 2 additions & 7 deletions python/sglang/srt/speculative/eagle_worker.py
@@ -34,14 +34,9 @@
     select_top_k_tokens,
 )
 from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
-from sglang.srt.utils import (
-    empty_context,
-    fast_topk,
-    get_available_gpu_memory,
-    is_cuda_available,
-)
+from sglang.srt.utils import empty_context, fast_topk, get_available_gpu_memory, is_cuda

-if is_cuda_available():
+if is_cuda():
     from sgl_kernel import segment_packbits

 logger = logging.getLogger(__name__)

[Author comment] Mark: equivalent to is_cuda
4 changes: 0 additions & 4 deletions python/sglang/srt/utils.py
@@ -130,10 +130,6 @@ def is_flashinfer_available():
     return importlib.util.find_spec("flashinfer") is not None and is_cuda()


-def is_cuda_available():
-    return is_cuda()
-
-
 _ENABLE_TORCH_INFERENCE_MODE = get_bool_env_var(
     "SGLANG_ENABLE_TORCH_INFERENCE_MODE", "false"
 )
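Since the removed wrapper simply returned is_cuda(), downstream callers migrate with a one-line substitution; a minimal sketch (assuming sglang is importable, as in the call sites changed above):

    # Before this PR (wrapper now removed):
    #   from sglang.srt.utils import is_cuda_available
    #   _is_cuda = is_cuda_available()

    # After this PR:
    from sglang.srt.utils import is_cuda

    _is_cuda = is_cuda()  # same value the removed is_cuda_available() returned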