[misc] remove is_cuda_available #5319
Changes from all commits:
3c0cc54
2276949
1655902
a37fd3c
05374c9
7b79be9
249c23a
a04d958
51f60d3
7b18732
6a66544
243a3ab
f9def37
File 1

@@ -23,10 +23,10 @@
 from sglang.srt.layers.attention.triton_ops.prefill_attention import (
     context_attention_fwd,
 )
-from sglang.srt.utils import is_hip
+from sglang.srt.utils import is_cuda, is_hip

-is_cuda_available = torch.cuda.is_available()
-if is_cuda_available:
+_is_cuda = is_cuda()
+if _is_cuda:
     CUDA_CAPABILITY = torch.cuda.get_device_capability()

 _is_hip = is_hip()

[review comment] Mark: equivalent to is_cuda because it is used after _is_hip

@@ -345,12 +345,12 @@ def extend_attention_fwd(
         num_warps = 4

     else:
-        if is_cuda_available and CUDA_CAPABILITY[0] >= 9:
+        if _is_cuda and CUDA_CAPABILITY[0] >= 9:
             if Lq <= 256:
                 BLOCK_M, BLOCK_N = (128, 64)
             else:
                 BLOCK_M, BLOCK_N = (32, 64)
-        elif is_cuda_available and CUDA_CAPABILITY[0] >= 8:
+        elif _is_cuda and CUDA_CAPABILITY[0] >= 8:
             # sm86/sm89 has a much smaller shared memory size (100K) than sm80 (160K)
             if CUDA_CAPABILITY[1] == 9 or CUDA_CAPABILITY[1] == 6:
                 if Lq <= 128:
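The review note above hinges on the gap between a runtime check and a build check: `torch.cuda.is_available()` is also True on ROCm builds of PyTorch, whereas `is_cuda()` is expected to be True only on genuine CUDA builds. A minimal sketch of that distinction follows; the helper bodies here are assumptions for illustration, not the actual implementations in `sglang.srt.utils`.

```python
import torch


def runtime_cuda_available() -> bool:
    # What the removed flag used: True whenever a GPU is visible, including
    # on ROCm builds, since HIP is exposed through the torch.cuda namespace.
    return torch.cuda.is_available()


def assumed_is_cuda() -> bool:
    # Assumed semantics of sglang.srt.utils.is_cuda(): a CUDA (NVIDIA) build.
    return torch.version.cuda is not None and torch.cuda.is_available()


def assumed_is_hip() -> bool:
    # Assumed semantics of sglang.srt.utils.is_hip(): a ROCm/HIP build.
    return torch.version.hip is not None


# In this file the flag is only consulted on the non-HIP branch (it sits
# after the `_is_hip` check), which is why the reviewer marks the swap as
# behaviour-preserving here.
```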
File 2

@@ -22,8 +22,12 @@
 import triton
 import triton.language as tl

-is_cuda_available = torch.cuda.is_available()
-if is_cuda_available:
+from sglang.srt.utils import is_cuda, is_hip
+
+_is_cuda = is_cuda()
+_is_hip = is_hip()
+
+if _is_cuda or _is_hip:
     CUDA_CAPABILITY = torch.cuda.get_device_capability()

[review comment] Maybe NOT equivalent

@@ -172,7 +176,7 @@ def context_attention_fwd(
     b_seq_len: [b]
     out: [b * s, head, head_dim]
     """
-    if is_cuda_available and CUDA_CAPABILITY[0] > 8:
+    if (_is_cuda or _is_hip) and CUDA_CAPABILITY[0] > 8:
         BLOCK = 128
     else:
         BLOCK = 64
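The "Maybe NOT equivalent" note appears to point at the same gap: in this file the old flag was not guarded by a HIP branch, so a build-level check can disagree with the old runtime check (for example on a ROCm build, or on a CUDA build running without a visible device). That is presumably why the hunk widens the gate to `(_is_cuda or _is_hip)` rather than `_is_cuda` alone, keeping capable ROCm devices on the larger block size. A self-contained sketch, with the same assumed helper semantics as in the earlier sketch:

```python
import torch

# Assumed helper semantics, for illustration only (not sglang's code).
_is_cuda = torch.version.cuda is not None and torch.cuda.is_available()
_is_hip = torch.version.hip is not None

if _is_cuda or _is_hip:
    # On ROCm this reports the gfx-derived capability, e.g. a major of 9
    # on MI200/MI300-class devices.
    CUDA_CAPABILITY = torch.cuda.get_device_capability()


def pick_block(old: bool) -> int:
    if old:
        # Old gate: runtime availability, True on CUDA and ROCm alike.
        gate = torch.cuda.is_available() and CUDA_CAPABILITY[0] > 8
    else:
        # New gate from the diff: covers both back ends explicitly, so a
        # capable ROCm device keeps BLOCK = 128 instead of dropping to 64.
        gate = (_is_cuda or _is_hip) and CUDA_CAPABILITY[0] > 8
    return 128 if gate else 64
```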
File 3

@@ -20,9 +20,9 @@
 import torch.nn as nn

 from sglang.srt.custom_op import CustomOp
-from sglang.srt.utils import is_cuda_available
+from sglang.srt.utils import is_cuda

-_is_cuda = is_cuda_available()
+_is_cuda = is_cuda()

 if _is_cuda:
     from sgl_kernel import (

[review comment] Mark: equivalent to is_cuda
File 4

@@ -22,9 +22,9 @@
     requantize_with_max_scale,
 )
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.utils import is_cuda_available
+from sglang.srt.utils import is_cuda

-if is_cuda_available():
+if is_cuda():
     from sgl_kernel import cutlass_scaled_fp4_mm, scaled_fp4_quant

 # Initialize logger for the module

[review comment] Mark: equivalent to is_cuda
File 5

@@ -8,11 +8,11 @@
 import torch.nn as nn

 from sglang.srt.custom_op import CustomOp
-from sglang.srt.utils import is_cuda_available
+from sglang.srt.utils import is_cuda

-_is_cuda_available = is_cuda_available()
+_is_cuda = is_cuda()

-if _is_cuda_available:
+if _is_cuda:
     from sgl_kernel import apply_rope_with_cos_sin_cache_inplace
 else:
     from vllm._custom_ops import rotary_embedding as vllm_rotary_embedding

[review comment] Mark: equivalent to is_cuda

@@ -82,7 +82,7 @@ def __init__(

         cache = self._compute_cos_sin_cache()
         # NOTE(ByronHsu): cache needs to be in FP32 for numerical stability
-        if not _is_cuda_available:
+        if not _is_cuda:
             cache = cache.to(dtype)
         self.cos_sin_cache: torch.Tensor
         self.register_buffer("cos_sin_cache", cache, persistent=False)

@@ -149,7 +149,7 @@ def forward_cuda(
         key: torch.Tensor,
         offsets: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-        if _is_cuda_available and (self.head_size in [64, 128, 256, 512]):
+        if _is_cuda and (self.head_size in [64, 128, 256, 512]):
             apply_rope_with_cos_sin_cache_inplace(
                 positions=positions,
                 query=query,

@@ -652,7 +652,7 @@ def forward_hip(self, *args, **kwargs):
     def forward(self, *args, **kwargs):
         if torch.compiler.is_compiling():
             return self.forward_native(*args, **kwargs)
-        if _is_cuda_available:
+        if _is_cuda:
             return self.forward_cuda(*args, **kwargs)
         else:
             return self.forward_native(*args, **kwargs)
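The rename in this file threads one flag through three decisions: which kernel gets imported, whether the cos/sin cache is downcast, and which forward path is dispatched. The condensed, hypothetical sketch below pulls those hunks into one view; the class and method bodies are placeholders rather than the real RotaryEmbedding implementation, and the `_is_cuda` stand-in only assumes the semantics of `is_cuda()`.

```python
import torch

# Stand-in for is_cuda() from sglang.srt.utils; the real helper may be
# implemented differently (this assumes a CUDA-build check).
_is_cuda = torch.version.cuda is not None and torch.cuda.is_available()


class RopeDispatchSketch(torch.nn.Module):
    """Abbreviated stand-in for the rotary-embedding class touched above."""

    def __init__(self, rotary_dim: int, max_len: int, dtype: torch.dtype):
        super().__init__()
        cache = self._compute_cos_sin_cache(rotary_dim, max_len)
        # Mirrors the NOTE in the diff: the cache stays in FP32 on the CUDA
        # path for numerical stability and is downcast only on the fallback.
        if not _is_cuda:
            cache = cache.to(dtype)
        self.register_buffer("cos_sin_cache", cache, persistent=False)

    @staticmethod
    def _compute_cos_sin_cache(rotary_dim: int, max_len: int) -> torch.Tensor:
        inv_freq = 1.0 / (10000.0 ** (torch.arange(0, rotary_dim, 2).float() / rotary_dim))
        freqs = torch.outer(torch.arange(max_len).float(), inv_freq)
        return torch.cat((freqs.cos(), freqs.sin()), dim=-1)

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        return x  # placeholder for the pure-PyTorch rotary path

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        return x  # placeholder for the fused sgl_kernel path

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if torch.compiler.is_compiling():
            return self.forward_native(x)
        # Only genuine CUDA builds take the fused path; ROCm and CPU builds
        # fall back to forward_native, matching the renamed _is_cuda gate.
        return self.forward_cuda(x) if _is_cuda else self.forward_native(x)
```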
File 6

@@ -10,9 +10,9 @@
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
-from sglang.srt.utils import crash_on_warnings, get_bool_env_var, is_cuda_available
+from sglang.srt.utils import crash_on_warnings, get_bool_env_var, is_cuda

-if is_cuda_available():
+if is_cuda():
     from sgl_kernel import (
         min_p_sampling_from_probs,
         top_k_renorm_prob,

[review comment] Mark: equivalent to is_cuda
File 7

@@ -40,9 +40,9 @@
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
-from sglang.srt.utils import add_prefix, is_cuda_available
+from sglang.srt.utils import add_prefix, is_cuda

-if is_cuda_available():
+if is_cuda():
     from sgl_kernel import bmm_fp8

[review comment] Mark: equivalent to is_cuda
File 8

@@ -4,9 +4,9 @@

 import torch

-from sglang.srt.utils import is_cuda_available, is_hip
+from sglang.srt.utils import is_cuda, is_hip

-if is_cuda_available() or is_hip():
+if is_cuda() or is_hip():
     from sgl_kernel import (
         build_tree_kernel_efficient as sgl_build_tree_kernel_efficient,
     )

[review comment] Mark: equivalent to is_cuda
File 9

@@ -19,9 +19,9 @@
 from sglang.srt.mem_cache.memory_pool import TokenToKVPoolAllocator
 from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode
 from sglang.srt.speculative.build_eagle_tree import build_tree_kernel_efficient
-from sglang.srt.utils import fast_topk, is_cuda_available, is_hip, next_power_of_2
+from sglang.srt.utils import fast_topk, is_cuda, is_hip, next_power_of_2

-if is_cuda_available():
+if is_cuda():
     from sgl_kernel import (
         top_k_renorm_prob,
         top_p_renorm_prob,

[review comment] Mark: equivalent to is_cuda
File 10

@@ -34,14 +34,9 @@
     select_top_k_tokens,
 )
 from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
-from sglang.srt.utils import (
-    empty_context,
-    fast_topk,
-    get_available_gpu_memory,
-    is_cuda_available,
-)
+from sglang.srt.utils import empty_context, fast_topk, get_available_gpu_memory, is_cuda

-if is_cuda_available():
+if is_cuda():
     from sgl_kernel import segment_packbits

 logger = logging.getLogger(__name__)

[review comment] Mark: equivalent to is_cuda