Commit 25861fb

[sglang] Upgrade sglang to 0.4.6.post1 & misc fixes (volcengine#1385)
### Checklist Before Starting

- [x] Search for similar PR(s).

### What does this PR do?

- [x] Upgrade the required SGLang version to 0.4.6.post1, which supports Qwen3.
- [x] Fix: `flush_cache` was never awaited.
- [x] Remove an unused env var.
- [x] Fix: add the rank number to the port to avoid SGLang picking the same port when `random.seed` is set.
- [x] Feat: disable the SGLang memory imbalance check by default (sgl-project/sglang#5426).
- [x] Update `setup.py` so that older pip versions can still resolve the dependencies.
- [x] Fix: `tools_kwargs` length mismatch with the batch (volcengine#1380).

> Add one-line overview of what this PR aims to achieve or accomplish.

### High-Level Design

> Demonstrate the high-level design if this PR is complex.

### Specific Changes

> List the specific changes.

### API

> Demonstrate how the API changes if any.

### Usage Example

> Provide usage example(s) for easier usage.

```python
# Add code snippet or script demonstrating how to use this
```

### Test

> For changes that cannot be tested by CI (e.g., algorithm implementation, new model support), validate by experiment(s) and show results like training curve plots, evaluation results, etc.

### Additional Info.

- **Issue Number**: Fixes issue # or discussion # if any.
- **Training**: [Note which backend this PR will affect: FSDP, Megatron, both, or none]
- **Inference**: [Note which backend this PR will affect: vLLM, SGLang, both, or none]

### Checklist Before Submitting

- [ ] Read the [Contribute Guide](https://github.com/volcengine/verl?tab=readme-ov-file#contribution-guide).
- [ ] Apply [pre-commit checks](https://github.com/volcengine/verl?tab=readme-ov-file#code-linting-and-formatting).
- [ ] Add `[BREAKING]` to the PR title if it breaks any API.
- [ ] Update the documentation about your changes in the [docs](https://github.com/volcengine/verl/tree/main/docs).
- [ ] Add CI test(s) if necessary.
1 parent 1d14d57 commit 25861fb

File tree

9 files changed: +26 −27 lines changed

.github/workflows/e2e_ppo_trainer.yml

Lines changed: 4 additions & 4 deletions

```diff
@@ -180,7 +180,7 @@ jobs:
       HF_ENDPOINT: "https://hf-mirror.com"
       HF_HUB_ENABLE_HF_TRANSFER: "0" # This is more stable
     container:
-      image: ocss884/verl-sglang:ngc-th2.6.0-cu126-sglang0.4.5.post3
+      image: ocss884/verl-sglang:ngc-th2.6.0-cu126-sglang0.4.6.post1
       options: --gpus all --shm-size=10g
     steps:
       - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
@@ -209,7 +209,7 @@ jobs:
       HF_ENDPOINT: "https://hf-mirror.com"
       HF_HUB_ENABLE_HF_TRANSFER: "0" # This is more stable
     container:
-      image: ocss884/verl-sglang:ngc-th2.6.0-cu126-sglang0.4.5.post3
+      image: ocss884/verl-sglang:ngc-th2.6.0-cu126-sglang0.4.6.post1
       options: --gpus all --shm-size=10g
     steps:
       - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
@@ -238,7 +238,7 @@ jobs:
       HF_ENDPOINT: "https://hf-mirror.com"
       HF_HUB_ENABLE_HF_TRANSFER: "0" # This is more stable
     container:
-      image: ocss884/verl-sglang:ngc-th2.6.0-cu126-sglang0.4.5.post3
+      image: ocss884/verl-sglang:ngc-th2.6.0-cu126-sglang0.4.6.post1
       options: --gpus all --shm-size=10g
     steps:
       - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
@@ -267,7 +267,7 @@ jobs:
       HF_ENDPOINT: "https://hf-mirror.com"
       HF_HUB_ENABLE_HF_TRANSFER: "0" # This is more stable
     container:
-      image: ocss884/verl-sglang:ngc-th2.6.0-cu126-sglang0.4.5.post3
+      image: ocss884/verl-sglang:ngc-th2.6.0-cu126-sglang0.4.6.post1
       options: --gpus all --shm-size=50g # Visual dataloader requires large memory
     steps:
       - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
```

.github/workflows/sgl.yml

Lines changed: 1 addition & 1 deletion

```diff
@@ -41,7 +41,7 @@ jobs:
       HF_HUB_ENABLE_HF_TRANSFER: 1
       SGL_DISABLE_TP_MEMORY_INBALANCE_CHECK: "True"
     container:
-      image: ocss884/verl-sglang:ngc-th2.6.0-cu126-sglang0.4.5.post3
+      image: ocss884/verl-sglang:ngc-th2.6.0-cu126-sglang0.4.6.post1
       options: --gpus all --shm-size=10g
     steps:
       - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
```

docker/Dockerfile.sglang

Lines changed: 4 additions & 4 deletions

```diff
@@ -36,8 +36,8 @@ RUN pip config set global.index-url "${PIP_INDEX}" && \
     pip config set global.extra-index-url "${PIP_INDEX}" && \
     python -m pip install --upgrade pip

-# Install sglang-0.4.5.post3 and torch-memory-saver
-RUN pip install "sglang[all]==0.4.5.post3" --no-cache-dir --find-links https://flashinfer.ai/whl/cu124/torch2.6/flashinfer-python && pip install torch-memory-saver --no-cache-dir
+# Install sglang-0.4.6.post1 and torch-memory-saver
+RUN pip install "sglang[all]==0.4.6.post1" --no-cache-dir --find-links https://flashinfer.ai/whl/cu124/torch2.6/flashinfer-python && pip install torch-memory-saver --no-cache-dir

 # Install torch-2.6.0
 RUN pip install --no-cache-dir torch==2.6.0 torchvision==0.21.0 torchaudio==2.6.0 tensordict torchdata \
@@ -47,8 +47,8 @@ RUN pip install --no-cache-dir torch==2.6.0 torchvision==0.21.0 torchaudio==2.6.

 # Install flash_attn-2.7.4.post1
 RUN pip uninstall -y transformer-engine flash-attn && \
-    wget -v https://ghfast.top/https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.5cxx11abiFALSE-cp310-cp310-linux_x86_64.whl && \
-    pip install --no-cache-dir flash_attn-2.7.4.post1+cu12torch2.5cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
+    wget -v https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.6cxx11abiFALSE-cp310-cp310-linux_x86_64.whl && \
+    pip install --no-cache-dir flash_attn-2.7.4.post1+cu12torch2.6cxx11abiFALSE-cp310-cp310-linux_x86_64.whl

 # Fix cv2
 RUN pip uninstall -y pynvml nvidia-ml-py && \
```

docs/start/install.rst

Lines changed: 1 addition & 1 deletion

```diff
@@ -42,7 +42,7 @@ For vLLM with Megatron or FSDP, please use ``whatcanyousee/verl:ngc-cu124-vllm0.

 For latest vLLM with FSDP, please refer to ``hiyouga/verl:ngc-th2.6.0-cu126-vllm0.8.4-flashinfer0.2.2-cxx11abi0``.

-For SGLang with FSDP, please use ``ocss884/verl-sglang:ngc-th2.6.0-cu126-sglang0.4.5.post3`` which is provided by SGLang RL Group.
+For SGLang with FSDP, please use ``ocss884/verl-sglang:ngc-th2.6.0-cu126-sglang0.4.6.post1`` which is provided by SGLang RL Group.

 See files under ``docker/`` for NGC-based image or if you want to build your own.
```

docs/workers/sglang_worker.rst

Lines changed: 5 additions & 5 deletions

```diff
@@ -10,18 +10,18 @@ Introduction
 ------------
 `SGLang <https://github.com/sgl-project/sglang>`_ is an open-source state-of-the-art inference service engine, fully adopted by xAI to support all inference needs of Grok during research and serving processes.

-Currently, verl fully supports using SGLang as the inference engine during the rollout phase. As a rollout engine, SGLang provides the same feature coverage as vLLM., including memory saving and multi-node rollout features. After installing verl and SGLang, simply add ``actor_rollout_ref.rollout.name=sglang`` at startup to seamlessly switch between the two inference frameworks.
+Currently, verl fully supports using SGLang as the inference engine during the rollout phase. As a rollout engine, SGLang provides the same feature coverage as vLLM., including memory saving and multi-node rollout features. After installing verl and SGLang, simply add ``actor_rollout_ref.rollout.name=sglang`` at startup script to seamlessly switch between the two inference frameworks.

 In addition, the SGLang team is actively working on supporting features such as Multi-Turn Agentic RL, VLM RLHF, Server-Based RLHF, and Partial Rollout. You can track the related development progress in the `Tracking Roadmap <https://github.com/zhaochenyang20/Awesome-ML-SYS-Tutorial/issues/74>`_.

 Installation
 ------------
-First, follow the requirements outlined in `Install SGLang as rollout backend <https://verl.readthedocs.io/en/latest/start/install.html#install-sglang-as-rollout-backend>`_ for installation, and ensure that the version requirements are met. Generally, using the latest `SGLang <https://github.com/sgl-project/sglang>`_ from the main branch will allow stable training startup without needing to target a specific version.
+Please always follow the following command to install SGLang with verl.

 .. code-block:: bash
-
-   # Currently 0.4.5, subject to updates at any time, please refer to the latest version
-   pip install "sglang[all]>=0.4.5" --find-links https://flashinfer.ai/whl/cu124/torch2.5/flashinfer-python
+   pip install --upgrade pip
+   # Currently 0.4.6.post1, subject to updates at any time, please refer to the latest version specified in `setup.py`
+   pip install -e ".[sglang]"

 Using SGLang as the Inference Backend for PPO Training on a Single Machine
 -------------------------------------------------------------------------
```
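With this change the exact SGLang version is pinned by verl's `setup.py` extras rather than installed ad hoc. Below is a minimal sketch (not part of the commit) for checking that an environment ended up on the pinned release after running `pip install -e ".[sglang]"`, assuming the `0.4.6.post1` pin introduced here:

```python
# Sketch: verify the installed SGLang matches the version verl pins.
from importlib.metadata import PackageNotFoundError, version

EXPECTED = "0.4.6.post1"  # the pin from this commit's setup.py

try:
    installed = version("sglang")
except PackageNotFoundError:
    raise SystemExit("sglang not installed; run `pip install -e '.[sglang]'` in the verl repo")

if installed != EXPECTED:
    print(f"warning: sglang {installed} installed, but verl pins {EXPECTED}")
else:
    print(f"sglang {installed} matches the pin")
```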

recipe/dapo/src/main_dapo.py

Lines changed: 0 additions & 4 deletions

```diff
@@ -25,7 +25,6 @@

 def get_custom_reward_fn(config):
     import importlib.util
-    import os

     reward_fn_config = config.get("custom_reward_function") or {}
     file_path = reward_fn_config.get("path")
@@ -58,9 +57,6 @@ def main(config):


 def run_ppo(config) -> None:
-    # TODO(linjunrong.ocss884): this ENV is left for resolving SGLang conflict with ray devices
-    # isolation, will solve in the future
-    os.environ["ENSURE_CUDA_VISIBLE_DEVICES"] = os.environ.get("CUDA_VISIBLE_DEVICES", "")
     if not ray.is_initialized():
         # this is for local ray cluster
         ray.init(
```

setup.py

Lines changed: 2 additions & 2 deletions

```diff
@@ -36,8 +36,8 @@
     "pybind11",
     "pylatexenc",
     "ray[default]>=2.10",
-    "tensordict<=0.6.2",
     "torchdata",
+    "tensordict<=0.6.2",
     "transformers",
     "wandb",
 ]
@@ -50,7 +50,7 @@
 VLLM_REQUIRES = ["tensordict<=0.6.2", "vllm<=0.8.3"]
 SGLANG_REQUIRES = [
     "tensordict<=0.6.2",
-    "sglang[all]==0.4.5.post3",
+    "sglang[srt,openai]==0.4.6.post1",
     "torch-memory-saver>=0.0.5",
 ]
```
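The second hunk narrows the SGLang extras from `all` to `srt,openai` and moves to the `0.4.6.post1` pin; per the PR description, the reordering above helps older pip releases resolve the dependency set. A small sketch (not from the repo) showing how the new requirement string parses, using the `packaging` library that pip builds on:

```python
# Sketch: parse the new pin with packaging to inspect its name, extras,
# and the exact version it admits.
from packaging.requirements import Requirement

req = Requirement('sglang[srt,openai]==0.4.6.post1')
print(req.name)                               # sglang
print(sorted(req.extras))                     # ['openai', 'srt']
print(req.specifier.contains("0.4.6.post1"))  # True
print(req.specifier.contains("0.4.5.post3"))  # False
```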

verl/trainer/main_ppo.py

Lines changed: 0 additions & 3 deletions

```diff
@@ -65,9 +65,6 @@ def main(config):


 def run_ppo(config) -> None:
-    # TODO(linjunrong.ocss884): this ENV is left for resolving SGLang conflict with ray devices
-    # isolation, will solve in the future
-    os.environ["ENSURE_CUDA_VISIBLE_DEVICES"] = os.environ.get("CUDA_VISIBLE_DEVICES", "")
     if not ray.is_initialized():
         # this is for local ray cluster
         ray.init(
```

verl/workers/rollout/sglang_rollout/sglang_rollout.py

Lines changed: 9 additions & 3 deletions

```diff
@@ -112,6 +112,7 @@ def __init__(
         """
         super().__init__()
         self.config = config
+        os.environ.setdefault("SGL_DISABLE_TP_MEMORY_INBALANCE_CHECK", "true")

         assert not (not config.enforce_eager and config.free_cache_engine), "disable CUDA graph (enforce_eager = False) if free cache engine"
```
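This hunk implements the "disable SGLang memory imbalance check by default" item (sgl-project/sglang#5426). Because `os.environ.setdefault` writes the key only when it is absent, an explicit value exported by the user still takes precedence over the new default. A minimal sketch of that behavior (not part of the commit):

```python
# Minimal sketch: os.environ.setdefault only fills in a missing key,
# so a user's explicit setting of the variable is preserved.
import os

KEY = "SGL_DISABLE_TP_MEMORY_INBALANCE_CHECK"

os.environ[KEY] = "false"          # user exported an explicit override
os.environ.setdefault(KEY, "true")
print(os.environ[KEY])             # "false" -- the override wins

del os.environ[KEY]                # variable absent this time
os.environ.setdefault(KEY, "true")
print(os.environ[KEY])             # "true" -- the default applies
```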

```diff
@@ -128,7 +129,6 @@ def __init__(
             tensor_model_parallel_size=tensor_parallel_size,
             num_tp_per_train_tp=num_tp_per_train_tp,
         )
-
         assert model_hf_config.max_position_embeddings >= config.prompt_length + config.response_length, "model context length should be greater than total sequence length"

         tp_size = tensor_parallel_size
@@ -144,6 +144,7 @@ def __init__(
         # device_mesh_device = init_device_mesh("cuda", **device_mesh_kwargs)

         # get tp_rank of this process in this tp group
+        rank = device_mesh_cpu.get_rank()
         tp_rank = device_mesh_cpu["tp"].get_local_rank()
         visible_devices = [None] * device_mesh_cpu.size(1)
         torch.distributed.all_gather_object(visible_devices, os.environ["CUDA_VISIBLE_DEVICES"], device_mesh_cpu.get_group("tp"))
```
```diff
@@ -178,7 +179,10 @@ def __init__(
             load_format=load_format,
             dist_init_addr=dist_init_addr,
             nnodes=nnodes,
-            # NOTE(Chenyang): if you want to debug the sglang engine
+            # NOTE(linjunrong): add rank to prevent SGLang generate same port inside PortArgs.init_new
+            # when random.seed is being set during training
+            port=30000 + rank,
+            # NOTE(Chenyang): if you want to debug the SGLang engine output
             # please set the following parameters
             # Otherwise, it will make the engine run too slow
             # log_level="INFO",
```
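This is the "add rank num to port" fix from the PR description. When training code calls `random.seed(...)`, every process on a node draws the same "random" numbers, so a port picked through the seeded RNG collides across ranks. A sketch of the failure mode and the rank-offset remedy (illustrative only; `pick_port` is a stand-in, not SGLang's actual `PortArgs.init_new`):

```python
# Illustrative sketch of the collision: a globally seeded RNG makes every
# rank's "random" port identical, so co-located engines fight over one socket.
import random

def pick_port(seed: int) -> int:
    random.seed(seed)                   # training code seeds the global RNG
    return random.randint(30000, 40000)

print(pick_port(42), pick_port(42))     # same value twice -> bind conflict

# Remedy used by the commit: offset a fixed base port by the process rank,
# which is unique per process and independent of the RNG.
def pick_port_with_rank(base: int, rank: int) -> int:
    return base + rank

print(pick_port_with_rank(30000, 0), pick_port_with_rank(30000, 1))  # 30000 30001
```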
```diff
@@ -320,6 +324,8 @@ def generate_sequences(self, prompts: DataProto, **kwargs) -> DataProto:
         batch_size = batch_size * self.sampling_params["n"]
         if "multi_modal_inputs" in non_tensor_batch.keys():
             non_tensor_batch["multi_modal_inputs"] = np.repeat(non_tensor_batch["multi_modal_inputs"], self.sampling_params["n"], axis=0)
+        if "tools_kwargs" in non_tensor_batch.keys():
+            non_tensor_batch["tools_kwargs"] = np.repeat(non_tensor_batch["tools_kwargs"], self.sampling_params["n"], axis=0)
         seq = torch.cat([idx, response], dim=-1)

         response_length = response.size(1)
```
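This mirrors the `multi_modal_inputs` branch just above it and fixes volcengine#1380: when each prompt is sampled `n` times, every per-prompt entry in the non-tensor batch must also be repeated `n` times, or its length no longer matches the expanded batch. A standalone sketch with made-up data:

```python
# Minimal sketch (illustrative data): np.repeat keeps per-prompt kwargs
# aligned with a batch that was expanded n-fold by multi-sample rollout.
import numpy as np

n = 3  # sampling_params["n"]: samples drawn per prompt
tools_kwargs = np.array([{"tool": "calc"}, {"tool": "search"}], dtype=object)

expanded = np.repeat(tools_kwargs, n, axis=0)
print(len(expanded))   # 6 == 2 prompts * n, matching batch_size * n
print(expanded[:n])    # the first prompt's kwargs, once per sample
```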
```diff
@@ -350,6 +356,6 @@ def generate_sequences(self, prompts: DataProto, **kwargs) -> DataProto:

         # free cache engine
         if self.config.free_cache_engine and self.inference_engine._engine is not None and self.inference_engine._engine.tokenizer_manager is not None:
-            self.inference_engine._engine.tokenizer_manager.flush_cache()
+            self.inference_engine._engine.flush_cache()

         return DataProto(batch=batch, non_tensor_batch=non_tensor_batch)
```
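This is the "`flush_cache` was never awaited" fix. Assuming (per the PR description) that `TokenizerManager.flush_cache` is a coroutine function, calling it without `await` merely created a coroutine object and the cache was never flushed; routing the call through the engine avoids the dangling coroutine. A generic sketch of the pitfall (not SGLang code):

```python
# Generic asyncio pitfall: calling an async method without awaiting it
# returns a coroutine object and performs no work.
import asyncio

class TokenizerManagerLike:
    def __init__(self) -> None:
        self.flushed = False

    async def flush_cache(self) -> None:  # must be awaited to run
        self.flushed = True

tm = TokenizerManagerLike()
coro = tm.flush_cache()        # old call pattern: coroutine created, never run
print(tm.flushed)              # False -- nothing happened
coro.close()                   # silence the "never awaited" warning

asyncio.run(tm.flush_cache())  # a sync wrapper must actually drive the coroutine
print(tm.flushed)              # True
```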
