Skip to content

Commit 1d609d1

Browse files
committed
format
1 parent ff8f714 commit 1d609d1

File tree

3 files changed

+16
-8
lines changed

3 files changed

+16
-8
lines changed

python/sglang/srt/disaggregation/decode.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -429,9 +429,9 @@ def event_loop_normal_disagg_decode(self):
429429
# polling and allocating kv cache
430430
self.process_decode_queue()
431431
batch = self.get_next_disagg_decode_batch_to_run()
432-
432+
433433
is_real_batch = True
434-
434+
435435
if batch and batch.forward_mode.is_extend():
436436
self.cur_batch = batch
437437
# Generate fake extend output.
@@ -442,9 +442,12 @@ def event_loop_normal_disagg_decode(self):
442442
is_real_batch = False
443443

444444
# Handle DP attention
445-
if self.server_args.enable_dp_attention or self.server_args.enable_sp_layernorm:
445+
if (
446+
self.server_args.enable_dp_attention
447+
or self.server_args.enable_sp_layernorm
448+
):
446449
batch, _ = self.prepare_dp_attn_batch(batch)
447-
450+
448451
if is_real_batch:
449452
self.cur_batch = batch
450453

python/sglang/srt/disaggregation/prefill.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -183,9 +183,12 @@ def event_loop_normal_disagg_prefill(self):
183183
)
184184
self.process_prefill_chunk()
185185
batch = self.get_new_batch_prefill()
186-
186+
187187
# Handle DP attention
188-
if self.server_args.enable_dp_attention or self.server_args.enable_sp_layernorm:
188+
if (
189+
self.server_args.enable_dp_attention
190+
or self.server_args.enable_sp_layernorm
191+
):
189192
batch, _ = self.prepare_dp_attn_batch(batch)
190193

191194
self.cur_batch = batch

python/sglang/srt/managers/data_parallel_controller.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,12 +24,12 @@
2424
import zmq
2525

2626
from sglang.srt.disaggregation.utils import DisaggregationMode
27-
from sglang.srt.managers.schedule_batch import Req
2827
from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
2928
from sglang.srt.managers.io_struct import (
3029
TokenizedEmbeddingReqInput,
3130
TokenizedGenerateReqInput,
3231
)
32+
from sglang.srt.managers.schedule_batch import Req
3333
from sglang.srt.managers.scheduler import run_scheduler_process
3434
from sglang.srt.server_args import PortArgs, ServerArgs
3535
from sglang.srt.utils import bind_port, configure_logger, get_zmq_socket
@@ -225,7 +225,9 @@ def launch_tensor_parallel_group(
225225
def round_robin_scheduler(self, req: Req):
226226
if self.server_args.disaggregation_mode == DisaggregationMode.NULL:
227227
self.workers[self.round_robin_counter].send_pyobj(req)
228-
self.round_robin_counter = (self.round_robin_counter + 1) % len(self.workers)
228+
self.round_robin_counter = (self.round_robin_counter + 1) % len(
229+
self.workers
230+
)
229231
else:
230232
self.workers[req.bootstrap_room % len(self.workers)].send_pyobj(req)
231233

0 commit comments

Comments
 (0)