Skip to content

Commit e7aa0a6

Browse files
PengchengShi00 and root authored
refactor rollout weight update flow (#1828)
* refactor rollout weight update flow
* refactor weight update flow & support lmdeploy-disaggregated weight update
* Add SGLang weight update correctness tests
* Resolve conflicts with main
* Update checkpoint tests for update_rollout_info changes
* [CI] Skip lmdeploy disaggregated and sglang update weight correctness check
* Fix pre-commit formatting issues
* address weight update review comments
* fix CI
* fix bug: create DeviceMesh before using it
* rerun CI

---------

Co-authored-by: root <root@test-spc.shipengcheng.ailab-sys.svc.pjlab.local>
1 parent 9502f8b commit e7aa0a6

12 files changed

Lines changed: 1718 additions & 1437 deletions

File tree

tests/rl/test_rl_trainer_checkpoint.py

Lines changed: 10 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -127,11 +127,17 @@ def __init__(self):
127127
self.update_weights_count = 0
128128
self.rollout_info = None
129129

130-
def set_train_rollout_mode(self, mode: str):
131-
self.train_rollout_mode = mode
132-
133-
def update_rollout_info(self, info):
130+
def update_rollout_info(
131+
self,
132+
info,
133+
train_rollout_mode,
134+
weight_update_host,
135+
weight_update_port
136+
):
134137
self.rollout_info = info
138+
self.train_rollout_mode = train_rollout_mode
139+
self.weight_update_host = weight_update_host
140+
self.weight_update_port = weight_update_port
135141

136142
def onload(self, target="all"):
137143
return f"onload:{target}"

tests/rl/test_update_weight_disaggregated.py

Lines changed: 109 additions & 275 deletions
Large diffs are not rendered by default.

xtuner/v1/rl/rollout/worker.py

Lines changed: 24 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -144,6 +144,10 @@ class RolloutConfig(BaseModel):
144144
gpu_memory_utilization (float): GPU memory utilization ratio. Defaults to 0.85.
145145
random_seed (int): Random seed for reproducible generation. Defaults to 1024.
146146
rollout_cross_node_comm (bool): Enable cross-node communication. Defaults to False.
147+
weight_update_host (Optional[str]): Host used by train rank 0 to initialize the external NCCL weight update
148+
group. Defaults to None.
149+
weight_update_port (Optional[int]): Port used by train rank 0 to initialize the external NCCL weight update
150+
group. Defaults to 30000.
147151
rollout_max_batch_size_per_instance (int): Maximum batch size for the rollout worker. If not set, it
148152
will be determined automatically based on `context_length`. Defaults to 512.
149153
allow_over_concurrency_ratio (float): Deprecated compatibility option. Rollout runtime concurrency is
@@ -223,6 +227,26 @@ class RolloutConfig(BaseModel):
223227
help="Base port number for distributed communication among rollout workers.",
224228
),
225229
] = 25000
230+
weight_update_host: Annotated[
231+
Optional[str],
232+
Parameter(
233+
group=infer_group,
234+
help=(
235+
"Host used by train rank 0 to initialize the external NCCL weight update group. "
236+
"Only used for NCCL weight update."
237+
),
238+
),
239+
] = None
240+
weight_update_port: Annotated[
241+
Optional[int],
242+
Parameter(
243+
group=infer_group,
244+
help=(
245+
"Port used by train rank 0 to initialize the external NCCL weight update group. "
246+
"Only used for NCCL weight update."
247+
),
248+
),
249+
] = 30000
226250
rollout_max_batch_size_per_instance: Annotated[
227251
Optional[int],
228252
Parameter(

xtuner/v1/rl/trainer/controller.py

Lines changed: 12 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -290,11 +290,18 @@ def onload(self, target: Literal["model", "optimizer", "all"] = "all"):
290290
ray.get([worker.onload_optimizer.remote() for worker in self.workers], timeout=TRAIN_RAY_GET_TIMEOUT) # type: ignore
291291
return
292292

293-
def update_rollout_info(self, info_dict):
294-
ray.get([worker.update_rollout_info.remote(**info_dict) for worker in self.workers]) # type: ignore[attr-defined]
295-
296-
def set_train_rollout_mode(self, train_rollout_mode: str):
297-
ray.get([worker.set_train_rollout_mode.remote(train_rollout_mode) for worker in self.workers])
293+
def update_rollout_info(self, info_dict, train_rollout_mode, weight_update_host=None, weight_update_port=None):
294+
ray.get(
295+
[
296+
worker.update_rollout_info.remote(
297+
**info_dict,
298+
train_rollout_mode=train_rollout_mode,
299+
weight_update_host=weight_update_host,
300+
weight_update_port=weight_update_port,
301+
)
302+
for worker in self.workers
303+
]
304+
)
298305

299306
def update_weights(self):
300307
"""Update the weights of the training workers."""

0 commit comments

Comments
 (0)