80 changes: 76 additions & 4 deletions tests/distributed/test_comm_ops.py
@@ -8,12 +8,13 @@
import ray
import torch

from vllm.distributed import (broadcast_tensor_dict,
from vllm.distributed import (broadcast_tensor_dict, get_pp_group,
is_pipeline_model_parallel_first_rank,
Member commented: nit: update the test

Collaborator (author) replied: Thanks, forgot to rerun tests
is_pipeline_model_parallel_last_rank,
tensor_model_parallel_all_gather,
tensor_model_parallel_all_reduce)

from ..utils import (init_test_distributed_environment,
multi_process_tensor_parallel)
from ..utils import init_test_distributed_environment, multi_process_parallel


@ray.remote(num_gpus=1, max_calls=1)
@@ -105,6 +106,68 @@ def broadcast_tensor_dict_test_worker(tp_size: int, pp_size: int, rank: int,
assert torch.allclose(recv_dict["f"], test_dict["f"])


@ray.remote(num_gpus=1, max_calls=1)
def send_recv_tensor_dict_test_worker(tp_size: int, pp_size: int, rank: int,
distributed_init_port: str):
del os.environ["CUDA_VISIBLE_DEVICES"]
device = torch.device(f"cuda:{rank}")
torch.cuda.set_device(device)
init_test_distributed_environment(tp_size, pp_size, rank,
distributed_init_port)

test_dict = {
# device tensor
"a": torch.arange(8, dtype=torch.float32, device="cuda"),
# CPU tensor
"b": torch.arange(16, dtype=torch.int8, device="cpu"),
"c": "test",
"d": [1, 2, 3],
"e": {
"a": 1,
"b": 2
},
# empty tensor
"f": torch.tensor([], dtype=torch.float32, device="cuda"),
}

if not is_pipeline_model_parallel_first_rank():
recv_dict = get_pp_group().recv_tensor_dict()

if not is_pipeline_model_parallel_last_rank():
get_pp_group().send_tensor_dict(test_dict)

if not is_pipeline_model_parallel_first_rank():
assert len(recv_dict) == len(test_dict)
assert torch.allclose(recv_dict["a"], test_dict["a"])
assert torch.allclose(recv_dict["b"], test_dict["b"])
assert recv_dict["c"] == test_dict["c"]
assert recv_dict["d"] == test_dict["d"]
assert recv_dict["e"] == test_dict["e"]
assert torch.allclose(recv_dict["f"], test_dict["f"])


@ray.remote(num_gpus=1, max_calls=1)
def send_recv_test_worker(tp_size: int, pp_size: int, rank: int,
distributed_init_port: str):
del os.environ["CUDA_VISIBLE_DEVICES"]
device = torch.device(f"cuda:{rank}")
torch.cuda.set_device(device)
init_test_distributed_environment(tp_size, pp_size, rank,
distributed_init_port)

size = 64
test_tensor = torch.arange(size, dtype=torch.float32, device="cuda")

if not is_pipeline_model_parallel_first_rank():
recv_tensor = get_pp_group().recv(size, dtype=torch.float32)

if not is_pipeline_model_parallel_last_rank():
get_pp_group().send(test_tensor)

if not is_pipeline_model_parallel_first_rank():
assert torch.allclose(test_tensor, recv_tensor)


@pytest.mark.skipif(torch.cuda.device_count() < 2,
reason="Need at least 2 GPUs to run the test.")
@pytest.mark.parametrize("tp_size", [2])
@@ -113,4 +176,13 @@ def broadcast_tensor_dict_test_worker(tp_size: int, pp_size: int, rank: int,
broadcast_tensor_dict_test_worker
])
def test_multi_process_tensor_parallel(tp_size, test_target):
multi_process_tensor_parallel(tp_size, 1, test_target)
multi_process_parallel(tp_size, 1, test_target)


@pytest.mark.skipif(torch.cuda.device_count() < 2,
reason="Need at least 2 GPUs to run the test.")
@pytest.mark.parametrize("pp_size", [2])
@pytest.mark.parametrize(
"test_target", [send_recv_test_worker, send_recv_tensor_dict_test_worker])
def test_multi_process_pipeline_parallel(pp_size, test_target):
multi_process_parallel(1, pp_size, test_target)
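The two workers above generalize beyond two ranks: every rank except the first blocks on a receive from its predecessor, and every rank except the last sends to its successor, so with pp_size > 2 data flows stage by stage through the whole pipeline. A minimal sketch of that relay pattern, reusing the helpers already imported by this test file (the fixed shape and dtype are illustrative assumptions):

    def relay_worker():
        pp_group = get_pp_group()
        if is_pipeline_model_parallel_first_rank():
            tensor = torch.arange(8, dtype=torch.float32, device="cuda")
        else:
            # Block until the previous stage delivers its output.
            tensor = pp_group.recv(size=torch.Size([8]), dtype=torch.float32)
        if not is_pipeline_model_parallel_last_rank():
            pp_group.send(tensor)  # dst defaults to the next pipeline rank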
5 changes: 2 additions & 3 deletions tests/distributed/test_custom_all_reduce.py
@@ -12,8 +12,7 @@
get_tp_group, graph_capture)

from ..utils import (ensure_model_parallel_initialized,
init_test_distributed_environment,
multi_process_tensor_parallel)
init_test_distributed_environment, multi_process_parallel)

random.seed(42)
test_sizes = [random.randint(1024, 2048 * 1024) for _ in range(8)]
@@ -113,4 +112,4 @@ def test_custom_allreduce(tp_size, pipeline_parallel_size, test_target):
world_size = tp_size * pipeline_parallel_size
if world_size > torch.cuda.device_count():
pytest.skip("Not enough GPUs to run the test.")
multi_process_tensor_parallel(tp_size, pipeline_parallel_size, test_target)
multi_process_parallel(tp_size, pipeline_parallel_size, test_target)
2 changes: 1 addition & 1 deletion tests/utils.py
@@ -129,7 +129,7 @@ def init_test_distributed_environment(
ensure_model_parallel_initialized(tp_size, pp_size)


def multi_process_tensor_parallel(
def multi_process_parallel(
tp_size: int,
pp_size: int,
test_target,
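The rename from multi_process_tensor_parallel to multi_process_parallel reflects that the helper now drives both parallelism dimensions. A rough sketch of what such a launcher does, consistent with the call sites above (the `get_open_port` helper and the exact ray setup are assumptions, not the verbatim implementation):

    import ray

    def multi_process_parallel(tp_size, pp_size, test_target):
        # One single-GPU ray worker per rank; the world size is tp * pp.
        ray.init()
        distributed_init_port = str(get_open_port())  # assumed helper
        refs = [
            test_target.remote(tp_size, pp_size, rank, distributed_init_port)
            for rank in range(tp_size * pp_size)
        ]
        ray.get(refs)  # re-raises assertion failures from the workers
        ray.shutdown()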
14 changes: 2 additions & 12 deletions vllm/distributed/device_communicators/pynccl.py
@@ -121,36 +121,26 @@ def all_reduce(self,
ncclRedOpTypeEnum.from_torch(op), self.comm,
cudaStream_t(stream.cuda_stream))

def send(self,
tensor: torch.Tensor,
dst: Optional[int] = None,
stream=None):
def send(self, tensor: torch.Tensor, dst: int, stream=None):
if self.disabled:
return
assert tensor.device == self.device, (
f"this nccl communicator is created to work on {self.device}, "
f"but the input tensor is on {tensor.device}")
if stream is None:
stream = self.stream
if dst is None:
dst = (self.rank + 1) % self.world_size
self.nccl.ncclSend(buffer_type(tensor.data_ptr()), tensor.numel(),
ncclDataTypeEnum.from_torch(tensor.dtype), dst,
self.comm, cudaStream_t(stream.cuda_stream))

def recv(self,
tensor: torch.Tensor,
src: Optional[int] = None,
stream=None):
def recv(self, tensor: torch.Tensor, src: int, stream=None):
if self.disabled:
return
assert tensor.device == self.device, (
f"this nccl communicator is created to work on {self.device}, "
f"but the input tensor is on {tensor.device}")
if stream is None:
stream = self.stream
if src is None:
src = (self.rank - 1) % self.world_size
self.nccl.ncclRecv(buffer_type(tensor.data_ptr()), tensor.numel(),
ncclDataTypeEnum.from_torch(tensor.dtype), src,
self.comm, cudaStream_t(stream.cuda_stream))
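Dropping the `Optional` defaults makes PyNcclCommunicator.send/recv purely mechanical: the neighbor-rank computation moves up into GroupCoordinator (see parallel_state.py below), and callers must always pass `dst`/`src`. A two-rank usage sketch, assuming `comm` is an already-initialized PyNcclCommunicator and `tensor` lives on this rank's device:

    if comm.rank == 0:
        comm.send(tensor, dst=1)
    else:
        buf = torch.empty_like(tensor)  # recv fills a caller-allocated buffer
        comm.recv(buf, src=0)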
187 changes: 187 additions & 0 deletions vllm/distributed/parallel_state.py
@@ -20,6 +20,7 @@
steps.
"""
import contextlib
import pickle
from collections import namedtuple
from contextlib import contextmanager, nullcontext
from dataclasses import dataclass
@@ -28,6 +29,7 @@
from unittest.mock import patch

import torch
import torch.distributed
from torch.distributed import Backend, ProcessGroup

import vllm.envs as envs
@@ -172,6 +174,16 @@ def last_rank(self):
"""Return the global rank of the last process in the group"""
return self.ranks[-1]

@property
def is_first_rank(self):
"""Return whether the caller is the first process in the group"""
return self.rank == self.first_rank

@property
def is_last_rank(self):
"""Return whether the caller is the last process in the group"""
return self.rank == self.last_rank

@property
def next_rank(self):
"""Return the global rank of the process that follows the caller"""
@@ -342,6 +354,70 @@ def broadcast_object_list(self,
group=self.device_group)
return obj_list

def send_object(self, obj: Any, dst: int) -> None:
"""Send the input object list to the destination rank."""
"""NOTE: `dst` is the local rank of the destination rank."""

assert dst < self.world_size, f"Invalid dst rank ({dst})"

assert dst != self.rank, (
"Invalid destination rank. Destination rank is the same "
"as the current rank.")

# Serialize object to tensor and get the size as well
object_tensor = torch.frombuffer(pickle.dumps(obj), dtype=torch.uint8)

size_tensor = torch.tensor([object_tensor.numel()],
dtype=torch.long,
device="cpu")

# Send object size
torch.distributed.send(size_tensor,
dst=self.ranks[dst],
group=self.cpu_group)

# Send object
torch.distributed.send(object_tensor,
dst=self.ranks[dst],
group=self.cpu_group)

return None

def recv_object(self, src: int) -> Any:
"""Receive the input object list from the source rank."""
"""NOTE: `src` is the local rank of the source rank."""

assert src < self.world_size, f"Invalid src rank ({src})"

assert src != self.rank, (
"Invalid source rank. Source rank is the same as the current rank."
)

size_tensor = torch.empty(1, dtype=torch.long, device="cpu")

# Receive object size
rank_size = torch.distributed.recv(size_tensor,
src=self.ranks[src],
group=self.cpu_group)

# Tensor to receive serialized objects into.
object_tensor = torch.empty( # type: ignore[call-overload]
size_tensor.item(), # type: ignore[arg-type]
dtype=torch.uint8,
device="cpu")

rank_object = torch.distributed.recv(object_tensor,
src=self.ranks[src],
group=self.cpu_group)

assert rank_object == rank_size, (
"Received object sender rank does not match the size sender rank.")

obj = pickle.loads(object_tensor.numpy().tobytes())

return obj
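
# A minimal local sketch of the framing implemented above (length header,
# then pickled payload); the header is what lets the receiver allocate
# `object_tensor` before posting the second recv:
#     payload = torch.frombuffer(pickle.dumps(obj), dtype=torch.uint8)
#     header = torch.tensor([payload.numel()], dtype=torch.long, device="cpu")
#     obj_back = pickle.loads(payload.numpy().tobytes())  # round-trips obj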

def broadcast_tensor_dict(
self,
tensor_dict: Optional[Dict[Any, Union[torch.Tensor, Any]]] = None,
@@ -433,6 +509,88 @@ def broadcast_tensor_dict(
async_handle.wait()
return tensor_dict

def send_tensor_dict(
self,
tensor_dict: Dict[Any, Union[torch.Tensor, Any]],
dst: Optional[int] = None
) -> Optional[Dict[Any, Union[torch.Tensor, Any]]]:
"""Send the input tensor dictionary.
NOTE: `dst` is the group-local rank of the destination.
"""
# Bypass the function if we are using only 1 GPU.
if not torch.distributed.is_initialized() or self.world_size == 1:
return tensor_dict

group = self.device_group
metadata_group = self.cpu_group

if dst is None:
dst = self.next_rank
assert dst < self.world_size, f"Invalid dst rank ({dst})"

metadata_list: List[Tuple[Any, Any]] = []
assert isinstance(
tensor_dict,
dict), f"Expecting a dictionary, got {type(tensor_dict)}"
metadata_list, tensor_list = _split_tensor_dict(tensor_dict)
# `metadata_list` lives in CPU memory.
# `send_object` does its serialization and deserialization on CPU,
# so we can use the CPU group.
self.send_object(metadata_list, dst=dst)
for tensor in tensor_list:
if tensor.numel() == 0:
# Skip sending empty tensors.
continue
if tensor.is_cpu:
# use metadata_group for CPU tensors
torch.distributed.send(tensor, dst=self.ranks[dst], group=metadata_group)
else:
# use group for GPU tensors
torch.distributed.send(tensor, dst=self.ranks[dst], group=group)
return None

def recv_tensor_dict(
self,
src: Optional[int] = None
) -> Optional[Dict[Any, Union[torch.Tensor, Any]]]:
"""Recv the input tensor dictionary.
NOTE: `src` is the local rank of the source rank.
"""
# Bypass the function if we are using only 1 GPU.
if not torch.distributed.is_initialized() or self.world_size == 1:
return None

group = self.device_group
metadata_group = self.cpu_group

if src is None:
src = self.prev_rank
assert src < self.world_size, f"Invalid src rank ({src})"

recv_metadata_list = self.recv_object(src=src)
tensor_dict = {}
for key, value in recv_metadata_list:
if isinstance(value, TensorMetadata):
tensor = torch.empty(value.size,
dtype=value.dtype,
device=value.device)
if tensor.numel() == 0:
# Skip receiving empty tensors; reconstruct them from metadata.
tensor_dict[key] = tensor
continue
if tensor.is_cpu:
# use metadata_group for CPU tensors
torch.distributed.recv(tensor,
src=self.ranks[src],
group=metadata_group)
else:
# use group for GPU tensors
torch.distributed.recv(tensor, src=self.ranks[src], group=group)
tensor_dict[key] = tensor
else:
tensor_dict[key] = value
return tensor_dict
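
# Usage sketch for one pipeline stage; `run_stage` is a hypothetical
# per-stage forward, and dst/src default to the pipeline neighbors:
#     inputs = (None if get_pp_group().is_first_rank
#               else get_pp_group().recv_tensor_dict())
#     outputs = run_stage(inputs)
#     if not get_pp_group().is_last_rank:
#         get_pp_group().send_tensor_dict(outputs)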

def barrier(self):
"""Barrier synchronization among the group.
NOTE: don't use `device_group` here! `barrier` in NCCL is
@@ -442,6 +600,35 @@ def barrier(self):
"""
torch.distributed.barrier(group=self.cpu_group)

def send(self, tensor: torch.Tensor, dst: Optional[int] = None) -> None:
"""Sends a tensor to the destination rank in a non-blocking way"""
"""NOTE: `dst` is the local rank of the destination rank."""
if dst is None:
dst = self.next_rank

pynccl_comm = self.pynccl_comm
if pynccl_comm is not None and not pynccl_comm.disabled:
pynccl_comm.send(tensor, dst)
else:
torch.distributed.send(tensor, self.ranks[dst], self.device_group)

def recv(self,
size: torch.Size,
dtype: torch.dtype,
src: Optional[int] = None) -> torch.Tensor:
"""Receives a tensor from the src rank."""
"""NOTE: `src` is the local rank of the destination rank."""
if src is None:
src = self.prev_rank

tensor = torch.empty(size, dtype=dtype, device=self.device)
pynccl_comm = self.pynccl_comm
if pynccl_comm is not None and not pynccl_comm.disabled:
pynccl_comm.recv(tensor, src)
else:
torch.distributed.recv(tensor, self.ranks[src], self.device_group)
return tensor
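
# Note: unlike recv_tensor_dict, recv() has no metadata channel, so the
# receiver must know shape and dtype out of band, e.g. (dims assumed):
#     t = get_pp_group().recv(torch.Size([seq_len, hidden]), torch.float16)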

def destroy(self):
if self.device_group is not None:
torch.distributed.destroy_process_group(self.device_group)