Skip to content

Commit 31047ac

Browse files
authored
Add pyhccl (vllm-project#503)
This is the first step to support trl vllm serve on Ascend NPU vllm-project#459. This PR can work properly only when vllm-project/vllm#16464 is merged into vLLM. --------- Signed-off-by: hzji210@gmail.com <hzji210@gmail.com> Signed-off-by: nanxing <1014662416@qq.com>
1 parent 34448fa commit 31047ac

8 files changed

Lines changed: 589 additions & 1 deletion

File tree

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
#
2+
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
3+
# This file is a part of the vllm-ascend project.
4+
# Adapted from vllm/tests/basic_correctness/test_basic_correctness.py
5+
# Copyright 2023 The vLLM team.
6+
#
7+
# Licensed under the Apache License, Version 2.0 (the "License");
8+
# you may not use this file except in compliance with the License.
9+
# You may obtain a copy of the License at
10+
#
11+
# http://www.apache.org/licenses/LICENSE-2.0
12+
#
13+
# Unless required by applicable law or agreed to in writing, software
14+
# distributed under the License is distributed on an "AS IS" BASIS,
15+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16+
# See the License for the specific language governing permissions and
17+
# limitations under the License.
18+
#
19+
import multiprocessing
20+
import os
21+
22+
import torch
23+
import torch_npu # noqa: F401
24+
from vllm.distributed.parallel_state import (get_world_group,
25+
init_distributed_environment)
26+
from vllm.utils import update_environment_variables
27+
28+
from vllm_ascend.distributed.device_communicators.pyhccl import \
29+
PyHcclCommunicator
30+
31+
32+
def distributed_run(fn, world_size):
33+
number_of_processes = world_size
34+
processes: list[multiprocessing.Process] = []
35+
for i in range(number_of_processes):
36+
env: dict[str, str] = {}
37+
env['RANK'] = str(i)
38+
env['LOCAL_RANK'] = str(i)
39+
env['WORLD_SIZE'] = str(number_of_processes)
40+
env['LOCAL_WORLD_SIZE'] = str(number_of_processes)
41+
env['MASTER_ADDR'] = 'localhost'
42+
env['MASTER_PORT'] = '12345'
43+
p = multiprocessing.Process(target=fn, args=(env, ))
44+
processes.append(p)
45+
p.start()
46+
47+
for p in processes:
48+
p.join()
49+
50+
for p in processes:
51+
assert p.exitcode == 0
52+
53+
54+
def worker_fn_wrapper(fn):
55+
# `multiprocessing.Process` cannot accept environment variables directly
56+
# so we need to pass the environment variables as arguments
57+
# and update the environment variables in the function
58+
def wrapped_fn(env):
59+
update_environment_variables(env)
60+
local_rank = os.environ['LOCAL_RANK']
61+
device = torch.device(f"npu:{local_rank}")
62+
torch.npu.set_device(device)
63+
init_distributed_environment(backend="hccl")
64+
fn()
65+
66+
return wrapped_fn
67+
68+
69+
@worker_fn_wrapper
70+
def worker_fn():
71+
pynccl_comm = PyHcclCommunicator(get_world_group().cpu_group,
72+
device=get_world_group().device)
73+
tensor = torch.ones(16, 1024, 1024,
74+
dtype=torch.float32).npu(pynccl_comm.rank)
75+
tensor = pynccl_comm.all_reduce(tensor)
76+
torch.npu.synchronize()
77+
assert torch.all(tensor == pynccl_comm.world_size).cpu().item()
78+
79+
80+
# def test_pyhccl():
81+
# distributed_run(worker_fn, 2)
82+
83+
84+
@worker_fn_wrapper
85+
def broadcast_worker_fn():
86+
# Test broadcast for every root rank.
87+
# Essentially this is an all-gather operation.
88+
pyhccl_comm = PyHcclCommunicator(get_world_group().cpu_group,
89+
device=get_world_group().device)
90+
recv_tensors = [
91+
torch.empty(16,
92+
1024,
93+
1024,
94+
dtype=torch.float32,
95+
device=pyhccl_comm.device)
96+
for i in range(pyhccl_comm.world_size)
97+
]
98+
recv_tensors[pyhccl_comm.rank] = torch.ones(
99+
16, 1024, 1024, dtype=torch.float32,
100+
device=pyhccl_comm.device) * pyhccl_comm.rank
101+
102+
for i in range(pyhccl_comm.world_size):
103+
pyhccl_comm.broadcast(recv_tensors[i], src=i)
104+
# the broadcast op might be launched in a different stream
105+
# need to synchronize to make sure the tensor is ready
106+
torch.npu.synchronize()
107+
assert torch.all(recv_tensors[i] == i).cpu().item()
108+
109+
110+
# def test_pyhccl_broadcast():
111+
# distributed_run(broadcast_worker_fn, 4)

tests/singlecard/test_pyhccl.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
#
2+
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
3+
# This file is a part of the vllm-ascend project.
4+
# Adapted from vllm/tests/basic_correctness/test_basic_correctness.py
5+
# Copyright 2023 The vLLM team.
6+
#
7+
# Licensed under the Apache License, Version 2.0 (the "License");
8+
# you may not use this file except in compliance with the License.
9+
# You may obtain a copy of the License at
10+
#
11+
# http://www.apache.org/licenses/LICENSE-2.0
12+
#
13+
# Unless required by applicable law or agreed to in writing, software
14+
# distributed under the License is distributed on an "AS IS" BASIS,
15+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16+
# See the License for the specific language governing permissions and
17+
# limitations under the License.
18+
#
19+
import torch
20+
import torch_npu # noqa: F401
21+
22+
from vllm_ascend.distributed.device_communicators.pyhccl_wrapper import \
23+
HCCLLibrary
24+
25+
26+
def test_hcclGetUniqueId():
27+
torch.npu.set_device(0)
28+
lib = HCCLLibrary()
29+
unique_id = lib.hcclGetUniqueId()
30+
assert unique_id is not None

vllm_ascend/distributed/communicator.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,5 +30,6 @@ def __init__(self,
3030
device_group: Optional[ProcessGroup] = None,
3131
unique_name: str = ""):
3232
super().__init__(cpu_group, device, device_group, unique_name)
33+
# TODO(hz): Refer to CudaCommunicator's implementation to integrate PyHcclCommunicator
3334
# init device according to rank
3435
self.device = torch.npu.current_device()

vllm_ascend/distributed/device_communicators/__init__.py

Whitespace-only changes.
Lines changed: 166 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,166 @@
1+
#
2+
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
3+
# This file is a part of the vllm-ascend project.
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
#
17+
18+
from typing import Optional, Union
19+
20+
import torch
21+
import torch.distributed as dist
22+
import torch_npu # noqa: F401
23+
from torch.distributed import ProcessGroup, ReduceOp
24+
from vllm.distributed.utils import StatelessProcessGroup
25+
from vllm.logger import logger
26+
27+
from vllm_ascend.distributed.device_communicators.pyhccl_wrapper import (
28+
HCCLLibrary, aclrtStream_t, buffer_type, hcclComm_t, hcclDataTypeEnum,
29+
hcclRedOpTypeEnum, hcclUniqueId)
30+
from vllm_ascend.utils import current_stream
31+
32+
33+
class PyHcclCommunicator:
34+
35+
def __init__(
36+
self,
37+
group: Union[ProcessGroup, StatelessProcessGroup],
38+
device: Union[int, str, torch.device],
39+
library_path: Optional[str] = None,
40+
):
41+
"""
42+
Args:
43+
group: the process group to work on. If None, it will use the
44+
default process group.
45+
device: the device to bind the PyHcclCommunicator to. If None,
46+
it will be bind to f"npu:{local_rank}".
47+
library_path: the path to the HCCL library. If None, it will
48+
use the default library path.
49+
It is the caller's responsibility to make sure each communicator
50+
is bind to a unique device.
51+
"""
52+
53+
if not isinstance(group, StatelessProcessGroup):
54+
assert dist.is_initialized()
55+
assert dist.get_backend(group) != dist.Backend.HCCL, (
56+
"PyHcclCommunicator should be attached to a non-HCCL group.")
57+
# note: this rank is the rank in the group
58+
self.rank = dist.get_rank(group)
59+
self.world_size = dist.get_world_size(group)
60+
else:
61+
self.rank = group.rank
62+
self.world_size = group.world_size
63+
64+
self.group = group
65+
66+
# if world_size == 1, no need to create communicator
67+
if self.world_size == 1:
68+
self.available = False
69+
self.disabled = True
70+
return
71+
72+
try:
73+
self.hccl = HCCLLibrary(library_path)
74+
except Exception:
75+
# disable because of missing HCCL library
76+
# e.g. in a non-NPU environment
77+
self.available = False
78+
self.disabled = True
79+
return
80+
81+
self.available = True
82+
self.disabled = False
83+
84+
logger.info("vLLM is using pyhccl")
85+
86+
if isinstance(device, int):
87+
device = torch.device(f"npu:{device}")
88+
elif isinstance(device, str):
89+
device = torch.device(device)
90+
# now `device` is a `torch.device` object
91+
assert isinstance(device, torch.device)
92+
self.device = device
93+
94+
if self.rank == 0:
95+
# get the unique id from HCCL
96+
with torch.npu.device(device):
97+
self.unique_id = self.hccl.hcclGetUniqueId()
98+
else:
99+
# construct an empty unique id
100+
self.unique_id = hcclUniqueId()
101+
102+
if not isinstance(group, StatelessProcessGroup):
103+
tensor = torch.ByteTensor(list(self.unique_id.internal))
104+
ranks = dist.get_process_group_ranks(group)
105+
# arg `src` in `broadcast` is the global rank
106+
dist.broadcast(tensor, src=ranks[0], group=group)
107+
byte_list = tensor.tolist()
108+
for i, byte in enumerate(byte_list):
109+
self.unique_id.internal[i] = byte
110+
else:
111+
self.unique_id = group.broadcast_obj(self.unique_id, src=0)
112+
113+
# hccl communicator and stream will use this device
114+
# `torch.npu.device` is a context manager that changes the
115+
# current npu device to the specified one
116+
with torch.npu.device(device):
117+
self.comm: hcclComm_t = self.hccl.hcclCommInitRank(
118+
self.world_size, self.unique_id, self.rank)
119+
120+
stream = current_stream()
121+
# A small all_reduce for warmup.
122+
data = torch.zeros(1, device=device)
123+
self.all_reduce(data)
124+
stream.synchronize()
125+
del data
126+
127+
def all_reduce(self,
128+
in_tensor: torch.Tensor,
129+
op: ReduceOp = ReduceOp.SUM,
130+
stream=None) -> torch.Tensor:
131+
if self.disabled:
132+
return None
133+
# hccl communicator created on a specific device
134+
# will only work on tensors on the same device
135+
# otherwise it will cause "illegal memory access"
136+
assert in_tensor.device == self.device, (
137+
f"this hccl communicator is created to work on {self.device}, "
138+
f"but the input tensor is on {in_tensor.device}")
139+
140+
out_tensor = torch.empty_like(in_tensor)
141+
142+
if stream is None:
143+
stream = current_stream()
144+
self.hccl.hcclAllReduce(buffer_type(in_tensor.data_ptr()),
145+
buffer_type(out_tensor.data_ptr()),
146+
in_tensor.numel(),
147+
hcclDataTypeEnum.from_torch(in_tensor.dtype),
148+
hcclRedOpTypeEnum.from_torch(op), self.comm,
149+
aclrtStream_t(stream.npu_stream))
150+
return out_tensor
151+
152+
def broadcast(self, tensor: torch.Tensor, src: int, stream=None):
153+
if self.disabled:
154+
return
155+
assert tensor.device == self.device, (
156+
f"this hccl communicator is created to work on {self.device}, "
157+
f"but the input tensor is on {tensor.device}")
158+
if stream is None:
159+
stream = current_stream()
160+
if src == self.rank:
161+
buffer = buffer_type(tensor.data_ptr())
162+
else:
163+
buffer = buffer_type(tensor.data_ptr())
164+
self.hccl.hcclBroadcast(buffer, tensor.numel(),
165+
hcclDataTypeEnum.from_torch(tensor.dtype), src,
166+
self.comm, aclrtStream_t(stream.npu_stream))

0 commit comments

Comments
 (0)