Skip to content

Commit 462d7c1

Browse files
njhill and wseaton committed
only do forward context collective when needed
Signed-off-by: Nick Hill <[email protected]> Co-authored-by: Will Eaton <[email protected]> Signed-off-by: Nick Hill <[email protected]>
1 parent 6064eaf commit 462d7c1

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

vllm/forward_context.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -63,7 +63,7 @@ def get_forward_context() -> ForwardContext:
6363
def set_forward_context(attn_metadata: Any,
6464
vllm_config: VllmConfig,
6565
virtual_engine: int = 0,
66-
num_tokens: int = 0):
66+
num_tokens: Optional[int] = None):
6767
"""A context manager that stores the current forward context,
6868
can be attention metadata, etc.
6969
Here we can inject common logic for every model forward pass.
@@ -73,8 +73,8 @@ def set_forward_context(attn_metadata: Any,
7373
if need_to_track_batchsize:
7474
forward_start_time = time.perf_counter()
7575
dp_metadata: Optional[DPMetadata] = None
76-
if vllm_config.parallel_config.data_parallel_size > 1:
77-
dp_size = vllm_config.parallel_config.data_parallel_size
76+
dp_size = vllm_config.parallel_config.data_parallel_size
77+
if dp_size > 1 and (attn_metadata is not None or num_tokens is not None):
7878
dp_rank = vllm_config.parallel_config.data_parallel_rank
7979
if attn_metadata is not None and hasattr(attn_metadata,
8080
"num_prefill_tokens"):

0 commit comments

Comments
 (0)