Commit afa9da9

jianjunzhong committed: fix pre-commit

Signed-off-by: jianjunzhong <jianjunzhong@foxmail.com>
Parent: fcf779d

File tree: 6 files changed, +41 −70 lines

- verl/single_controller/ray/base.py
- verl/utils/vllm/utils.py
- verl/workers/rollout/vllm_rollout/__init__.py
- verl/workers/rollout/vllm_rollout/utils.py
- verl/workers/rollout/vllm_rollout/vllm_async_server.py
- verl/workers/rollout/vllm_rollout/vllm_rollout.py

verl/single_controller/ray/base.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -433,6 +433,7 @@ def _init_with_detached_workers(self, worker_names, worker_handles):
 
     def _get_master_addr_port(self, pg):
         """Get master addr and port for this worker group"""
+
         def _do_get_master_addr_port(pg):
             master_addr, master_port = ray.get(
                 get_master_addr_port.options(
@@ -442,6 +443,7 @@ def _do_get_master_addr_port(pg):
                 ).remote()
             )
             return master_addr, master_port
+
         if self._master_addr is None and self._master_port is None:
             self._master_addr, self._master_port = _do_get_master_addr_port(pg)
         elif self._master_addr is not None and self._master_port is not None:
```
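
The added blank lines only satisfy the formatter; for orientation, the surrounding helper follows a standard Ray pattern: run a small remote task inside the worker group's placement group and read back the rendezvous address. A minimal sketch of that pattern (the task body and scheduling details are illustrative assumptions, not verl's exact code):

```python
import socket

import ray
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy


@ray.remote
def get_master_addr_port():
    # Bind port 0 so the OS picks a free port on whichever node runs this task.
    with socket.socket() as sock:
        sock.bind(("", 0))
        port = sock.getsockname()[1]
    return ray.util.get_node_ip_address(), port


def fetch_master_addr_port(pg):
    # Pin the task to the placement group so the returned address belongs to a
    # node that actually hosts the worker group.
    strategy = PlacementGroupSchedulingStrategy(placement_group=pg)
    return ray.get(get_master_addr_port.options(scheduling_strategy=strategy).remote())
```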

verl/utils/vllm/utils.py

Lines changed: 0 additions & 2 deletions
```diff
@@ -12,9 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import torch
 
-from typing import Callable
 from msgspec import field
 from packaging import version as vs
 from vllm.lora.models import LoRAModel
```

verl/workers/rollout/vllm_rollout/__init__.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -14,7 +14,7 @@
 import os
 from importlib.metadata import PackageNotFoundError, version
 
-from .vllm_rollout import vLLMAsyncRollout  # noqa: F401
+from .vllm_rollout import ServerAdapter  # noqa: F401
 
 
 def get_version(pkg):
```
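
The rename swaps the re-exported symbol from vLLMAsyncRollout to ServerAdapter. The `# noqa: F401` marker is what keeps this line pre-commit-clean: the import exists only to publish the name at package level, and without the marker ruff's F401 (unused import) check would flag it. Illustrative downstream usage, assuming the package layout shown here:

```python
# The re-export lets callers import the adapter from the package root:
from verl.workers.rollout.vllm_rollout import ServerAdapter

# ...rather than reaching into the implementation module:
# from verl.workers.rollout.vllm_rollout.vllm_rollout import ServerAdapter
```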

verl/workers/rollout/vllm_rollout/utils.py

Lines changed: 9 additions & 11 deletions
```diff
@@ -67,10 +67,9 @@ def compute_logits(
 
     model.compute_logits = MethodType(compute_logits, model)
 
+
 # copy from https://github.com/vllm-project/vllm/blob/main/examples/offline_inference/rlhf_utils.py
-def rebuild_ipc(
-    handle: tuple[Callable, tuple], device_id: int | None = None
-) -> torch.Tensor:
+def rebuild_ipc(handle: tuple[Callable, tuple], device_id: int | None = None) -> torch.Tensor:
     func, args = handle
     list_args = list(args)
     if device_id is not None:
@@ -80,13 +79,15 @@ def rebuild_ipc(
     buffer = func(*list_args)
     return buffer
 
+
 class FlattenedTensorMetadata(TypedDict):
     name: str
     shape: torch.Size
     dtype: torch.dtype
     # specify the start offset of this tensor in shared ipc_buffer tensor
     offset: int
 
+
 class vLLMColocateWorkerExtension:
     """
     The class for vLLM's worker to inherit from, in the colocate setting.
@@ -96,6 +97,7 @@ class vLLMColocateWorkerExtension:
     NOTE: we define this class in a separate module, and the main module
     should pass the full qualified name as `worker_extension_cls` argument.
     """
+
     def __new__(cls, **kwargs):
         global_rank = kwargs.get("rank", 0) + int(os.environ.get("VERL_VLLM_MULTIPROC_GLOBAL_RANK_OFFSET", "0"))
         local_rank = kwargs.get("local_rank", 0)
@@ -115,7 +117,7 @@ def __new__(cls, **kwargs):
 
     def monkey_patch_compute_logits(self, vocab_size: int):
        _monkey_patch_compute_logits(self.model_runner.model, vocab_size)
-
+
     def _fetch_weights(self, zmq_handle: str, load: bool = True):
        from vllm.model_executor.model_loader.utils import process_weights_after_loading
 
@@ -126,14 +128,10 @@ def _fetch_weights(self, zmq_handle: str, load: bool = True):
         socket.connect(zmq_handle)
         weights_to_load = []
         while True:
-            payload: tuple[Callable, tuple] | list[FlattenedTensorMetadata] | None = (
-                socket.recv_pyobj()
-            )
+            payload: tuple[Callable, tuple] | list[FlattenedTensorMetadata] | None = socket.recv_pyobj()
             if payload is None:
                 # means the update is done
-                process_weights_after_loading(
-                    self.model_runner.model, self.model_config, self.device
-                )
+                process_weights_after_loading(self.model_runner.model, self.model_config, self.device)
                 torch.cuda.synchronize()
                 socket.send(b"")
                 break
```
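
For orientation: `rebuild_ipc` replays a handle produced by `torch.multiprocessing.reductions.reduce_tensor` (which `vllm_rollout.py` imports, per the diff below), so the receiving process maps the sender's CUDA memory instead of copying it. The `_fetch_weights` loop then treats each `recv_pyobj()` payload as either such a handle, a list of `FlattenedTensorMetadata` slices into one shared buffer, or `None` as the end-of-update sentinel that triggers `process_weights_after_loading`. A minimal sketch of the two IPC halves, as an assumption-level illustration rather than verl's actual transport code:

```python
from typing import Callable

import torch
from torch.multiprocessing.reductions import reduce_tensor


def export_tensor(t: torch.Tensor) -> tuple[Callable, tuple]:
    # Producer side: reduce_tensor() turns a CUDA tensor into a picklable
    # (rebuild_fn, args) handle carrying a CUDA IPC memory handle, not the data.
    return reduce_tensor(t)


def import_tensor(handle: tuple[Callable, tuple]) -> torch.Tensor:
    # Consumer side (a different process on the same node): replaying the handle
    # maps the producer's device memory. rebuild_ipc() above does exactly this,
    # optionally overriding the device id inside args first.
    rebuild_fn, args = handle
    return rebuild_fn(*args)
```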
```diff
@@ -191,10 +189,10 @@ def update_lora_weights_from_ipc(self, peft_config: dict, zmq_handles: dict[str,
             lora_tensors=dict(lora_weights),
         )
         self.add_lora(lora_request)
+        logger.info(f"vLLM load weights, loaded_params: {len(lora_weights)}")
         del lora_weights
         gc.collect()
         torch.cuda.empty_cache()
-        logger.info(f"vLLM load weights, loaded_params: {len(lora_weights)}")
 
     def report_device_id(self) -> str:
         """Report device ID for ZMQ handle."""
```

verl/workers/rollout/vllm_rollout/vllm_async_server.py

Lines changed: 20 additions & 21 deletions
```diff
@@ -13,20 +13,17 @@
 # limitations under the License.
 import argparse
 import asyncio
+import inspect
 import json
 import logging
 import os
-from concurrent.futures import Future
 from pprint import pprint
-from typing import Any, Callable, Optional
+from typing import Any, Optional
 
-import cloudpickle as pickle
 import numpy as np
 import ray
 import torch
 import vllm.entrypoints.cli.serve
-import zmq
-from filelock import FileLock
 from ray.actor import ActorHandle
 from vllm import SamplingParams
 from vllm.config import LoRAConfig
```
```diff
@@ -271,10 +268,9 @@ async def launch_server(self, master_address: str = None, master_port: int = Non
             server_args.append(json.dumps(v) if isinstance(v, dict) else str(v))
 
         # pass worker_extension_cls parameter for cuda-ipc based weights updating
-        server_args.extend([
-            "--worker_extension_cls",
-            "verl.workers.rollout.vllm_rollout.utils.vLLMColocateWorkerExtension"
-        ])
+        server_args.extend(
+            ["--worker_extension_cls", "verl.workers.rollout.vllm_rollout.utils.vLLMColocateWorkerExtension"]
+        )
 
         if self.replica_rank == 0:
             pprint(server_args)
```
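
`worker_extension_cls` is vLLM's hook for mixing extra methods into every worker process; once the server starts with it, those methods are reachable by name through `collective_rpc`, which is how the CUDA-IPC weight updates above are driven. A hedged sketch of the same mechanism in vLLM's offline API (the model name and call site are illustrative; verl wires this through the HTTP server instead):

```python
from vllm import LLM

# Workers are constructed with the extension class mixed in, so its methods
# (e.g. report_device_id from utils.py above) can be invoked on every worker.
llm = LLM(
    model="facebook/opt-125m",
    worker_extension_cls="verl.workers.rollout.vllm_rollout.utils.vLLMColocateWorkerExtension",
)
device_ids = llm.collective_rpc(method="report_device_id")
```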
```diff
@@ -336,8 +332,7 @@ async def run_server(self, args: argparse.Namespace):
         # Don't keep the dummy data in memory
         await engine_client.reset_mm_cache()
         await engine_client.collective_rpc(
-            method="monkey_patch_compute_logits",
-            kwargs={"vocab_size": len(self.model_config.tokenizer)}
+            method="monkey_patch_compute_logits", kwargs={"vocab_size": len(self.model_config.tokenizer)}
         )
 
         app = build_app(args)
```
```diff
@@ -376,18 +371,12 @@ async def run_headless(self, args: argparse.Namespace):
             executor_class=Executor.get_class(vllm_config),
             log_stats=not engine_args.disable_log_stats,
         )
-
+
     async def collective_rpc(
-        self,
-        method: str,
-        timeout: Optional[float] = None,
-        args: tuple = (),
-        kwargs: Optional[dict] = None
+        self, method: str, timeout: Optional[float] = None, args: tuple = (), kwargs: Optional[dict] = None
     ):
         """Perform a collective RPC call to the inference engine."""
-        return await self.engine.collective_rpc(
-            method=method, timeout=timeout, args=args, kwargs=kwargs
-        )
+        return await self.engine.collective_rpc(method=method, timeout=timeout, args=args, kwargs=kwargs)
 
     async def generate(
         self,
```
```diff
@@ -582,7 +571,17 @@ def __init__(
         nnodes: int,
         cuda_visible_devices: str,
     ):
-        super().__init__(config, model_config, rollout_mode, workers, replica_rank, node_rank, gpus_per_node, nnodes, cuda_visible_devices)
+        super().__init__(
+            config,
+            model_config,
+            rollout_mode,
+            workers,
+            replica_rank,
+            node_rank,
+            gpus_per_node,
+            nnodes,
+            cuda_visible_devices,
+        )
 
 
 _rollout_worker_actor_cls = ray.remote(ServerAdapter)
```

verl/workers/rollout/vllm_rollout/vllm_rollout.py

Lines changed: 9 additions & 35 deletions
```diff
@@ -26,49 +26,22 @@
 - After inference, all the parameters that doesn't belong to this pp rank is freed.
 """
 
-import asyncio
-import getpass
 import logging
 import os
-from dataclasses import asdict
-from types import MethodType
 from typing import Any, Generator, Optional
 
-import cloudpickle as pickle
 import ray
 import torch
-import torch.distributed
 import zmq
-import zmq.asyncio
-from filelock import FileLock
+from packaging import version as vs
 from torch.distributed.device_mesh import DeviceMesh
 from torch.multiprocessing.reductions import reduce_tensor
-from vllm.config import LoRAConfig
-
-try:
-    from vllm.worker.worker_base import WorkerWrapperBase
-except ModuleNotFoundError:
-    # https://github.com/vllm-project/vllm/commit/6a113d9aed8221a9c234535958e70e34ab6cac5b
-    from vllm.v1.worker.worker_base import WorkerWrapperBase
-
-from packaging import version as vs
 
 from verl import DataProto
 from verl.third_party.vllm import VLLM_SLEEP_LEVEL, get_version
-from verl.utils.device import is_npu_available
-from verl.utils.distributed import initialize_global_process_group_ray
-from verl.utils.ray_utils import ray_noset_visible_devices
-from verl.utils.vllm import TensorLoRARequest, VLLMHijack, is_version_ge
-from verl.utils.vllm.vllm_fp8_utils import apply_vllm_fp8_patches, is_fp8_model, load_quanted_weights
+from verl.utils.vllm import VLLMHijack, is_version_ge
 from verl.workers.config import HFModelConfig, RolloutConfig
 from verl.workers.rollout.base import BaseRollout
-from verl.workers.rollout.utils import get_free_port, is_valid_ipv6_address
-from verl.workers.rollout.vllm_rollout.utils import (
-    VLLM_LORA_INT_ID,
-    VLLM_LORA_NAME,
-    VLLM_LORA_PATH,
-    get_vllm_max_lora_rank,
-)
 
 logger = logging.getLogger(__file__)
 logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))
```
```diff
@@ -98,6 +71,7 @@ class ServerAdapter(BaseRollout):
     vLLM server adapter used in native async mode, serve as a client to request vLLM server
     to resume/release/update weights and kv_cache.
     """
+
     def __init__(
         self,
         config: RolloutConfig,
@@ -110,7 +84,7 @@ def __init__(
         rank = int(os.environ["RANK"])
         local_world_size = int(os.environ["RAY_LOCAL_WORLD_SIZE"])
         rollout_world_size = (
-            self.config.tensor_model_parallel_size 
+            self.config.tensor_model_parallel_size
             * self.config.data_parallel_size
             * self.config.pipeline_model_parallel_size
         )
```
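
The second hunk only strips trailing whitespace, but it sits in the adapter's sizing logic; as a worked example of that arithmetic (values illustrative, not verl defaults):

```python
# Each rollout replica spans TP x DP x PP worker ranks:
tensor_model_parallel_size = 2
data_parallel_size = 2
pipeline_model_parallel_size = 1

rollout_world_size = (
    tensor_model_parallel_size
    * data_parallel_size
    * pipeline_model_parallel_size
)
assert rollout_world_size == 4  # ranks 0-3 form one vLLM server instance
```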
```diff
@@ -122,7 +96,7 @@
             self.sleep_level = 1
         else:
             self.sleep_level = VLLM_SLEEP_LEVEL
-
+
         # Attributes related to weight updates
         from vllm.platforms import current_platform
 
@@ -137,7 +111,7 @@ async def _execute_method(
         non_block: bool = False,
         timeout: Optional[float] = None,
         args: tuple = (),
-        kwargs: Optional[dict] = None
+        kwargs: Optional[dict] = None,
     ) -> Any:
         """Execute method on inference engine via ray.
 
@@ -184,15 +158,15 @@ async def update_weights(self, weights: Generator[tuple[str, torch.Tensor], None
                 kwargs={
                     "peft_config": peft_config,
                     "zmq_handles": self.zmq_handles,
-                }
+                },
             )
         else:
             await self._execute_method(
                 "update_weights_from_ipc",
                 non_block=True,
                 kwargs={
                     "zmq_handles": self.zmq_handles,
-                }
+                },
             )
         await self._update_weights_per_tensor(weights)
 
@@ -225,7 +199,7 @@ def set_server_handle(self, server_handle: ray.actor.ActorHandle):
         """Set vLLMHttpServer handle"""
         if self.rollout_rank == 0:
             self.server_handle = server_handle
-
+
     def get_update_weights_zmq_handle(self) -> dict[str, str]:
         """Get ZMQ handle for weight updates."""
         suffix = f"{self.device_uuid}-{self.zmq_address_counter}"
```
