Skip to content

Commit 775d745

Browse files
update weight from disk design
1 parent e5cdbfb commit 775d745

7 files changed

Lines changed: 297 additions & 37 deletions

File tree

xtuner/v1/rl/trainer/controller.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -290,12 +290,12 @@ def onload(self, target: Literal["model", "optimizer", "all"] = "all"):
290290
ray.get([worker.onload_optimizer.remote() for worker in self.workers], timeout=TRAIN_RAY_GET_TIMEOUT) # type: ignore
291291
return
292292

293-
def update_rollout_info(self, info_dict, train_rollout_mode, weight_update_host=None, weight_update_port=None):
293+
def update_rollout_info(self, info_dict, weight_update_mode, weight_update_host=None, weight_update_port=None):
294294
ray.get(
295295
[
296296
worker.update_rollout_info.remote(
297297
**info_dict,
298-
train_rollout_mode=train_rollout_mode,
298+
weight_update_mode=weight_update_mode,
299299
weight_update_host=weight_update_host,
300300
weight_update_port=weight_update_port,
301301
)

xtuner/v1/rl/weight_update/__init__.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,23 @@
11
from .data import (
22
DeviceMeshRaw,
3+
DiskUpdateUpstreamTransport,
34
RolloutBackend,
45
RolloutEngineInfo,
56
RolloutWeightUpdateInfo,
67
ServiceUrlMap,
7-
TrainRolloutMode,
88
WeightTransportType,
99
WeightUpdateBatch,
1010
)
1111
from .transport import (
12+
DiskBackendAdapter,
13+
DiskWeightTransport,
1214
IPCBackendAdapter,
1315
IPCWeightTransport,
16+
LMDeployDiskBackendAdapter,
1417
LMDeployIPCBackendAdapter,
1518
NCCLBackendAdapter,
1619
NCCLWeightTransport,
20+
SGLangDiskBackendAdapter,
1721
SGLangIPCBackendAdapter,
1822
SGLangNCCLBackendAdapter,
1923
WeightTransport,
@@ -24,19 +28,23 @@
2428

2529

2630
__all__ = [
31+
"DiskBackendAdapter",
32+
"DiskUpdateUpstreamTransport",
33+
"DiskWeightTransport",
2734
"DeviceMeshRaw",
2835
"IPCBackendAdapter",
2936
"IPCWeightTransport",
37+
"LMDeployDiskBackendAdapter",
3038
"LMDeployIPCBackendAdapter",
3139
"NCCLBackendAdapter",
3240
"NCCLWeightTransport",
3341
"RolloutBackend",
3442
"RolloutEngineInfo",
3543
"RolloutWeightUpdateInfo",
44+
"SGLangDiskBackendAdapter",
3645
"SGLangIPCBackendAdapter",
3746
"SGLangNCCLBackendAdapter",
3847
"ServiceUrlMap",
39-
"TrainRolloutMode",
4048
"UpdateWeighter",
4149
"WeightIterator",
4250
"WeightTransportType",

xtuner/v1/rl/weight_update/data.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,9 @@
1010
DeviceMeshRaw: TypeAlias = List[List[int]] # A list of lists representing device mesh indices.
1111
ServiceUrlMap: TypeAlias = Dict[int, str] # A dictionary mapping rollout ranks to their server URLs.
1212
RolloutEngineInfo: TypeAlias = list[tuple[int, str, int]] # (rollout rank, server url, engine gpu count)
13-
TrainRolloutMode: TypeAlias = Literal["colocate", "disaggregated"] # Train and rollout deployment mode.
1413
RolloutBackend: TypeAlias = Literal["sglang", "vllm", "pytorch", "turbomind"] # Rollout inference backend.
15-
WeightTransportType: TypeAlias = Literal["ipc", "nccl"] # Supported weight transport types.
14+
WeightTransportType: TypeAlias = Literal["ipc", "nccl", "disk"] # Supported weight transport types.
15+
DiskUpdateUpstreamTransport: TypeAlias = Literal["ipc", "nccl"] # How disk-loaded weights are delivered to rollout.
1616

1717

1818
@dataclass
@@ -23,7 +23,6 @@ class RolloutWeightUpdateInfo:
2323
backend: RolloutBackend | None = None
2424
tp: int = 1
2525
ep: int = 1
26-
train_rollout_mode: TrainRolloutMode | None = None
2726
transport_type: WeightTransportType | None = None
2827
rollout_cfg_info: dict = field(default_factory=dict)
2928
endpoints: dict[str, str] = field(default_factory=lambda: {"update_weights": "update_weights"})
@@ -38,6 +37,10 @@ class RolloutWeightUpdateInfo:
3837
weight_update_host: str | None = None
3938
weight_update_port: int | None = None
4039

40+
# Disk update metadata.
41+
hf_weight_path: str | None = None
42+
disk_update_upstream_transport: DiskUpdateUpstreamTransport | None = None
43+
4144

4245
@dataclass
4346
class WeightUpdateBatch:

xtuner/v1/rl/weight_update/transport.py

Lines changed: 154 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -739,7 +739,7 @@ def ensure_nccl_weight_update_group(self):
739739

740740
def send(self, batch: WeightUpdateBatch) -> None:
741741
state_dict = batch.state_dict
742-
if not state_dict:
742+
if not state_dict and not batch.finished:
743743
return
744744

745745
train_sync_group = self.get_train_update_sync_group()
@@ -860,3 +860,156 @@ def teardown(self) -> None:
860860
self.group_name = None
861861
self.engine_urls = []
862862
self.external_group_world_size = None
863+
864+
865+
class DiskBackendAdapter:
866+
def update(self, weight_iterator: Any) -> None:
867+
raise NotImplementedError
868+
869+
def teardown(self) -> None:
870+
return
871+
872+
873+
class SGLangDiskBackendAdapter(DiskBackendAdapter):
874+
def __init__(self, *, rank: int, rollout_info: RolloutWeightUpdateInfo):
875+
self.rank = rank
876+
self.rollout_info = rollout_info
877+
self.executor: ThreadPoolExecutor | None = None
878+
879+
def build_request(self, hf_weight_path: str) -> WeightUpdateRequest:
880+
# SGLang already owns the disk reload path. XTuner only needs to pass
881+
# the HF checkpoint directory to the rollout server.
882+
return WeightUpdateRequest(
883+
endpoint="update_weights_from_disk",
884+
body={
885+
"model_path": hf_weight_path,
886+
"load_format": "safetensors",
887+
"abort_all_requests": True,
888+
"flush_cache": True,
889+
},
890+
)
891+
892+
def update(self, weight_iterator: Any) -> None:
893+
# SGLang consumes the checkpoint path on the rollout server side.
894+
del weight_iterator
895+
896+
hf_weight_path = self.rollout_info.hf_weight_path
897+
if not hf_weight_path:
898+
raise RuntimeError("Disk weight update requires rollout_info.hf_weight_path from rollout_config.")
899+
900+
try:
901+
if dist.get_rank() != 0:
902+
dist.barrier()
903+
return
904+
905+
target_urls = list(dict.fromkeys(url for url in self.rollout_info.rollout_server_url_dict.values() if url))
906+
if not target_urls:
907+
raise RuntimeError("Disk weight update requires at least one rollout server url.")
908+
request = self.build_request(hf_weight_path)
909+
self.executor = ThreadPoolExecutor(max_workers=max(1, len(target_urls)))
910+
futures = [
911+
self.executor.submit(
912+
WeightTransport.post_json,
913+
url,
914+
request.endpoint,
915+
request.body,
916+
api_key=self.rollout_info.api_key,
917+
)
918+
for url in target_urls
919+
]
920+
for future in futures:
921+
result = future.result()
922+
assert result.get("success", True), f"disk weight update failed: {result.get('message', result)}"
923+
dist.barrier()
924+
finally:
925+
self.teardown()
926+
DEVICE_MODULE.empty_cache()
927+
928+
def teardown(self) -> None:
929+
if self.executor is not None:
930+
self.executor.shutdown(wait=False, cancel_futures=True)
931+
self.executor = None
932+
933+
934+
class LMDeployDiskBackendAdapter(DiskBackendAdapter):
935+
def __init__(
936+
self,
937+
*,
938+
rank: int,
939+
logger: Any,
940+
rollout_info: RolloutWeightUpdateInfo,
941+
config: Any | None,
942+
upstream_transport: str,
943+
):
944+
self.upstream_transport = upstream_transport
945+
self._batch_transport = self._build_batch_transport(
946+
rank=rank,
947+
logger=logger,
948+
rollout_info=rollout_info,
949+
config=config,
950+
)
951+
952+
def _build_batch_transport(
953+
self,
954+
*,
955+
rank: int,
956+
logger: Any,
957+
rollout_info: RolloutWeightUpdateInfo,
958+
config: Any | None,
959+
) -> WeightTransport:
960+
if self.upstream_transport == "ipc":
961+
return IPCWeightTransport(
962+
rank=rank,
963+
logger=logger,
964+
config=config,
965+
rollout_info=rollout_info,
966+
)
967+
elif self.upstream_transport == "nccl":
968+
return NCCLWeightTransport(rank=rank, logger=logger, rollout_info=rollout_info)
969+
else:
970+
raise ValueError(f"Unsupported disk weight update upstream transport: {self.upstream_transport!r}")
971+
972+
def update(self, weight_iterator: Any) -> None:
973+
# WeightIterator.iter_batch_groups() switches disk mode to iter_disk_hf_batches().
974+
# The underlying LMDeploy transport then uses the existing tensor update endpoints.
975+
self._batch_transport.update(weight_iterator)
976+
977+
def teardown(self) -> None:
978+
self._batch_transport.teardown()
979+
980+
981+
class DiskWeightTransport(WeightTransport):
982+
_disk_adapter: DiskBackendAdapter
983+
984+
def __init__(self, *, rank: int, logger: Any, rollout_info: RolloutWeightUpdateInfo, config: Any | None = None):
985+
super().__init__(rank=rank, logger=logger, rollout_info=rollout_info)
986+
self.config = config
987+
self._disk_adapter = self._build_adapter()
988+
989+
def _build_adapter(self) -> DiskBackendAdapter:
990+
if self.backend == "sglang":
991+
return SGLangDiskBackendAdapter(rank=self.rank, rollout_info=self.rollout_info)
992+
elif self.backend == "pytorch":
993+
upstream_transport = self.rollout_info.disk_update_upstream_transport or "ipc"
994+
return LMDeployDiskBackendAdapter(
995+
rank=self.rank,
996+
logger=self.logger,
997+
config=self.config,
998+
rollout_info=self.rollout_info,
999+
upstream_transport=upstream_transport,
1000+
)
1001+
raise ValueError(f"Unsupported disk weight update backend: {self.backend!r}")
1002+
1003+
def update(self, weight_iterator: Any) -> None:
1004+
self._disk_adapter.update(weight_iterator)
1005+
1006+
def send(self, batch: WeightUpdateBatch) -> None:
1007+
raise NotImplementedError("DiskWeightTransport bypasses WeightIterator batches.")
1008+
1009+
def after_update_all_groups(self) -> None:
1010+
self._disk_adapter.teardown()
1011+
DEVICE_MODULE.empty_cache()
1012+
1013+
def teardown(self) -> None:
1014+
self._disk_adapter.teardown()
1015+
super().teardown()

0 commit comments

Comments
 (0)