Commit 97cb762

[misc] remove is_cuda_available (sgl-project#5319)
1 parent 1195182 commit 97cb762

File tree

14 files changed: +42 -47 lines changed
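For context: in the Triton kernel files below, the old code built its flag directly from torch.cuda.is_available(), a runtime device check, while the other files imported is_cuda_available from sglang.srt.utils. The sketch below illustrates one plausible difference between that runtime check and the is_cuda()/is_hip() platform checks that replace it. This is an assumption for illustration only, not the actual sglang.srt.utils implementation.

import torch

def is_cuda_available() -> bool:
    # Runtime check: True only when a CUDA device is actually visible.
    return torch.cuda.is_available()

def is_cuda() -> bool:
    # Platform check (assumed): True when torch is a CUDA (non-ROCm) build,
    # regardless of whether a GPU is visible at import time.
    return torch.version.cuda is not None and torch.version.hip is None

def is_hip() -> bool:
    # Platform check (assumed): True for ROCm builds of torch.
    return torch.version.hip is not None

Under this assumption the rename is mostly mechanical, but modules that previously gated on a visible GPU now gate on how torch was built.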

python/sglang/srt/layers/activation.py

Lines changed: 2 additions & 2 deletions

@@ -28,9 +28,9 @@
     get_tensor_model_parallel_world_size,
 )
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
-from sglang.srt.utils import is_cuda_available, set_weight_attrs
+from sglang.srt.utils import is_cuda, set_weight_attrs

-_is_cuda = is_cuda_available()
+_is_cuda = is_cuda()

 if _is_cuda:
     from sgl_kernel import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul

python/sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py

Lines changed: 5 additions & 5 deletions

@@ -3,10 +3,10 @@
 import triton.language as tl

 from sglang.srt.managers.schedule_batch import global_server_args_dict
-from sglang.srt.utils import is_hip
+from sglang.srt.utils import is_cuda, is_hip

-is_cuda_available = torch.cuda.is_available()
-if is_cuda_available:
+_is_cuda = is_cuda()
+if _is_cuda:
     CUDA_CAPABILITY = torch.cuda.get_device_capability()

 _is_hip = is_hip()
@@ -1037,12 +1037,12 @@ def extend_attention_fwd(
         num_warps = 4

     else:
-        if is_cuda_available and CUDA_CAPABILITY[0] >= 9:
+        if _is_cuda and CUDA_CAPABILITY[0] >= 9:
             if Lq <= 256:
                 BLOCK_M, BLOCK_N = (128, 64)
             else:
                 BLOCK_M, BLOCK_N = (32, 64)
-        elif is_cuda_available and CUDA_CAPABILITY[0] >= 8:
+        elif _is_cuda and CUDA_CAPABILITY[0] >= 8:
             if Lq <= 128:
                 BLOCK_M, BLOCK_N = (128, 128)
             elif Lq <= 256:

python/sglang/srt/layers/attention/triton_ops/extend_attention.py

Lines changed: 5 additions & 5 deletions

@@ -23,10 +23,10 @@
 from sglang.srt.layers.attention.triton_ops.prefill_attention import (
     context_attention_fwd,
 )
-from sglang.srt.utils import is_hip
+from sglang.srt.utils import is_cuda, is_hip

-is_cuda_available = torch.cuda.is_available()
-if is_cuda_available:
+_is_cuda = is_cuda()
+if _is_cuda:
     CUDA_CAPABILITY = torch.cuda.get_device_capability()

 _is_hip = is_hip()
@@ -345,12 +345,12 @@ def extend_attention_fwd(
         num_warps = 4

     else:
-        if is_cuda_available and CUDA_CAPABILITY[0] >= 9:
+        if _is_cuda and CUDA_CAPABILITY[0] >= 9:
             if Lq <= 256:
                 BLOCK_M, BLOCK_N = (128, 64)
             else:
                 BLOCK_M, BLOCK_N = (32, 64)
-        elif is_cuda_available and CUDA_CAPABILITY[0] >= 8:
+        elif _is_cuda and CUDA_CAPABILITY[0] >= 8:
             # sm86/sm89 has a much smaller shared memory size (100K) than sm80 (160K)
             if CUDA_CAPABILITY[1] == 9 or CUDA_CAPABILITY[1] == 6:
                 if Lq <= 128:

python/sglang/srt/layers/attention/triton_ops/prefill_attention.py

Lines changed: 7 additions & 3 deletions

@@ -22,8 +22,12 @@
 import triton
 import triton.language as tl

-is_cuda_available = torch.cuda.is_available()
-if is_cuda_available:
+from sglang.srt.utils import is_cuda, is_hip
+
+_is_cuda = is_cuda()
+_is_hip = is_hip()
+
+if _is_cuda or _is_hip:
     CUDA_CAPABILITY = torch.cuda.get_device_capability()


@@ -172,7 +176,7 @@ def context_attention_fwd(
     b_seq_len: [b]
     out: [b * s, head, head_dim]
     """
-    if is_cuda_available and CUDA_CAPABILITY[0] > 8:
+    if (_is_cuda or _is_hip) and CUDA_CAPABILITY[0] > 8:
         BLOCK = 128
     else:
         BLOCK = 64
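The hunk above also widens the gate to cover ROCm (_is_cuda or _is_hip). The standalone snippet below illustrates how the capability-based block size resolves on a given device; it assumes a CUDA or ROCm build of torch and is not part of the commit.

import torch

if torch.cuda.is_available():
    # On ROCm builds, torch.cuda.get_device_capability() reports the
    # (major, minor) pair from the HIP runtime, so the same call covers
    # both back ends.
    major, minor = torch.cuda.get_device_capability()
    block = 128 if major > 8 else 64  # mirrors the BLOCK = 128 / 64 split above
    print(f"compute capability {major}.{minor} -> BLOCK = {block}")
else:
    print("no GPU visible; the Triton prefill kernel path is not taken")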

python/sglang/srt/layers/layernorm.py

Lines changed: 2 additions & 2 deletions

@@ -20,9 +20,9 @@
 import torch.nn as nn

 from sglang.srt.custom_op import CustomOp
-from sglang.srt.utils import is_cuda_available
+from sglang.srt.utils import is_cuda

-_is_cuda = is_cuda_available()
+_is_cuda = is_cuda()

 if _is_cuda:
     from sgl_kernel import (

python/sglang/srt/layers/quantization/modelopt_quant.py

Lines changed: 2 additions & 2 deletions

@@ -22,9 +22,9 @@
     requantize_with_max_scale,
 )
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.utils import is_cuda_available
+from sglang.srt.utils import is_cuda

-if is_cuda_available():
+if is_cuda():
     from sgl_kernel import cutlass_scaled_fp4_mm, scaled_fp4_quant

 # Initialize logger for the module

python/sglang/srt/layers/quantization/w8a8_int8.py

Lines changed: 3 additions & 3 deletions

@@ -11,10 +11,10 @@
     QuantizeMethodBase,
 )
 from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8
-from sglang.srt.utils import is_cuda_available, set_weight_attrs
+from sglang.srt.utils import is_cuda, set_weight_attrs

-is_cuda = is_cuda_available()
-if is_cuda:
+_is_cuda = is_cuda()
+if _is_cuda:
     from sgl_kernel import int8_scaled_mm

python/sglang/srt/layers/rotary_embedding.py

Lines changed: 6 additions & 6 deletions

@@ -8,11 +8,11 @@
 import torch.nn as nn

 from sglang.srt.custom_op import CustomOp
-from sglang.srt.utils import is_cuda_available
+from sglang.srt.utils import is_cuda

-_is_cuda_available = is_cuda_available()
+_is_cuda = is_cuda()

-if _is_cuda_available:
+if _is_cuda:
     from sgl_kernel import apply_rope_with_cos_sin_cache_inplace
 else:
     from vllm._custom_ops import rotary_embedding as vllm_rotary_embedding
@@ -82,7 +82,7 @@ def __init__(

         cache = self._compute_cos_sin_cache()
         # NOTE(ByronHsu): cache needs to be in FP32 for numerical stability
-        if not _is_cuda_available:
+        if not _is_cuda:
             cache = cache.to(dtype)
         self.cos_sin_cache: torch.Tensor
         self.register_buffer("cos_sin_cache", cache, persistent=False)
@@ -149,7 +149,7 @@ def forward_cuda(
         key: torch.Tensor,
         offsets: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-        if _is_cuda_available and (self.head_size in [64, 128, 256, 512]):
+        if _is_cuda and (self.head_size in [64, 128, 256, 512]):
             apply_rope_with_cos_sin_cache_inplace(
                 positions=positions,
                 query=query,
@@ -652,7 +652,7 @@ def forward_hip(self, *args, **kwargs):
     def forward(self, *args, **kwargs):
         if torch.compiler.is_compiling():
             return self.forward_native(*args, **kwargs)
-        if _is_cuda_available:
+        if _is_cuda:
             return self.forward_cuda(*args, **kwargs)
         else:
             return self.forward_native(*args, **kwargs)
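The final hunk keeps the existing dispatch pattern and only renames the flag. A stripped-down, hypothetical version of that pattern (not the actual sglang CustomOp or RotaryEmbedding class) is sketched below to show how the module-level flag drives which forward path runs.

import torch

_is_cuda = torch.version.cuda is not None  # assumed stand-in for is_cuda()

class RotaryEmbeddingLike(torch.nn.Module):
    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        return x  # placeholder for the pure-PyTorch path

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        return x  # placeholder for the sgl_kernel-backed path

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Under torch.compile, always take the native path; otherwise use the
        # CUDA path only when the module-level flag is set.
        if torch.compiler.is_compiling():
            return self.forward_native(x)
        return self.forward_cuda(x) if _is_cuda else self.forward_native(x)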

python/sglang/srt/layers/sampler.py

Lines changed: 2 additions & 2 deletions

@@ -10,9 +10,9 @@
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
-from sglang.srt.utils import crash_on_warnings, get_bool_env_var, is_cuda_available
+from sglang.srt.utils import crash_on_warnings, get_bool_env_var, is_cuda

-if is_cuda_available():
+if is_cuda():
     from sgl_kernel import (
         min_p_sampling_from_probs,
         top_k_renorm_prob,

python/sglang/srt/models/minicpm3.py

Lines changed: 2 additions & 2 deletions

@@ -40,9 +40,9 @@
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
-from sglang.srt.utils import add_prefix, is_cuda_available
+from sglang.srt.utils import add_prefix, is_cuda

-if is_cuda_available():
+if is_cuda():
     from sgl_kernel import bmm_fp8
