@@ -739,7 +739,7 @@ def ensure_nccl_weight_update_group(self):
739739
740740 def send (self , batch : WeightUpdateBatch ) -> None :
741741 state_dict = batch .state_dict
742- if not state_dict :
742+ if not state_dict and not batch . finished :
743743 return
744744
745745 train_sync_group = self .get_train_update_sync_group ()
@@ -860,3 +860,156 @@ def teardown(self) -> None:
860860 self .group_name = None
861861 self .engine_urls = []
862862 self .external_group_world_size = None
863+
864+
865+ class DiskBackendAdapter :
866+ def update (self , weight_iterator : Any ) -> None :
867+ raise NotImplementedError
868+
869+ def teardown (self ) -> None :
870+ return
871+
872+
873+ class SGLangDiskBackendAdapter (DiskBackendAdapter ):
874+ def __init__ (self , * , rank : int , rollout_info : RolloutWeightUpdateInfo ):
875+ self .rank = rank
876+ self .rollout_info = rollout_info
877+ self .executor : ThreadPoolExecutor | None = None
878+
879+ def build_request (self , hf_weight_path : str ) -> WeightUpdateRequest :
880+ # SGLang already owns the disk reload path. XTuner only needs to pass
881+ # the HF checkpoint directory to the rollout server.
882+ return WeightUpdateRequest (
883+ endpoint = "update_weights_from_disk" ,
884+ body = {
885+ "model_path" : hf_weight_path ,
886+ "load_format" : "safetensors" ,
887+ "abort_all_requests" : True ,
888+ "flush_cache" : True ,
889+ },
890+ )
891+
892+ def update (self , weight_iterator : Any ) -> None :
893+ # SGLang consumes the checkpoint path on the rollout server side.
894+ del weight_iterator
895+
896+ hf_weight_path = self .rollout_info .hf_weight_path
897+ if not hf_weight_path :
898+ raise RuntimeError ("Disk weight update requires rollout_info.hf_weight_path from rollout_config." )
899+
900+ try :
901+ if dist .get_rank () != 0 :
902+ dist .barrier ()
903+ return
904+
905+ target_urls = list (dict .fromkeys (url for url in self .rollout_info .rollout_server_url_dict .values () if url ))
906+ if not target_urls :
907+ raise RuntimeError ("Disk weight update requires at least one rollout server url." )
908+ request = self .build_request (hf_weight_path )
909+ self .executor = ThreadPoolExecutor (max_workers = max (1 , len (target_urls )))
910+ futures = [
911+ self .executor .submit (
912+ WeightTransport .post_json ,
913+ url ,
914+ request .endpoint ,
915+ request .body ,
916+ api_key = self .rollout_info .api_key ,
917+ )
918+ for url in target_urls
919+ ]
920+ for future in futures :
921+ result = future .result ()
922+ assert result .get ("success" , True ), f"disk weight update failed: { result .get ('message' , result )} "
923+ dist .barrier ()
924+ finally :
925+ self .teardown ()
926+ DEVICE_MODULE .empty_cache ()
927+
928+ def teardown (self ) -> None :
929+ if self .executor is not None :
930+ self .executor .shutdown (wait = False , cancel_futures = True )
931+ self .executor = None
932+
933+
934+ class LMDeployDiskBackendAdapter (DiskBackendAdapter ):
935+ def __init__ (
936+ self ,
937+ * ,
938+ rank : int ,
939+ logger : Any ,
940+ rollout_info : RolloutWeightUpdateInfo ,
941+ config : Any | None ,
942+ upstream_transport : str ,
943+ ):
944+ self .upstream_transport = upstream_transport
945+ self ._batch_transport = self ._build_batch_transport (
946+ rank = rank ,
947+ logger = logger ,
948+ rollout_info = rollout_info ,
949+ config = config ,
950+ )
951+
952+ def _build_batch_transport (
953+ self ,
954+ * ,
955+ rank : int ,
956+ logger : Any ,
957+ rollout_info : RolloutWeightUpdateInfo ,
958+ config : Any | None ,
959+ ) -> WeightTransport :
960+ if self .upstream_transport == "ipc" :
961+ return IPCWeightTransport (
962+ rank = rank ,
963+ logger = logger ,
964+ config = config ,
965+ rollout_info = rollout_info ,
966+ )
967+ elif self .upstream_transport == "nccl" :
968+ return NCCLWeightTransport (rank = rank , logger = logger , rollout_info = rollout_info )
969+ else :
970+ raise ValueError (f"Unsupported disk weight update upstream transport: { self .upstream_transport !r} " )
971+
972+ def update (self , weight_iterator : Any ) -> None :
973+ # WeightIterator.iter_batch_groups() switches disk mode to iter_disk_hf_batches().
974+ # The underlying LMDeploy transport then uses the existing tensor update endpoints.
975+ self ._batch_transport .update (weight_iterator )
976+
977+ def teardown (self ) -> None :
978+ self ._batch_transport .teardown ()
979+
980+
981+ class DiskWeightTransport (WeightTransport ):
982+ _disk_adapter : DiskBackendAdapter
983+
984+ def __init__ (self , * , rank : int , logger : Any , rollout_info : RolloutWeightUpdateInfo , config : Any | None = None ):
985+ super ().__init__ (rank = rank , logger = logger , rollout_info = rollout_info )
986+ self .config = config
987+ self ._disk_adapter = self ._build_adapter ()
988+
989+ def _build_adapter (self ) -> DiskBackendAdapter :
990+ if self .backend == "sglang" :
991+ return SGLangDiskBackendAdapter (rank = self .rank , rollout_info = self .rollout_info )
992+ elif self .backend == "pytorch" :
993+ upstream_transport = self .rollout_info .disk_update_upstream_transport or "ipc"
994+ return LMDeployDiskBackendAdapter (
995+ rank = self .rank ,
996+ logger = self .logger ,
997+ config = self .config ,
998+ rollout_info = self .rollout_info ,
999+ upstream_transport = upstream_transport ,
1000+ )
1001+ raise ValueError (f"Unsupported disk weight update backend: { self .backend !r} " )
1002+
1003+ def update (self , weight_iterator : Any ) -> None :
1004+ self ._disk_adapter .update (weight_iterator )
1005+
1006+ def send (self , batch : WeightUpdateBatch ) -> None :
1007+ raise NotImplementedError ("DiskWeightTransport bypasses WeightIterator batches." )
1008+
1009+ def after_update_all_groups (self ) -> None :
1010+ self ._disk_adapter .teardown ()
1011+ DEVICE_MODULE .empty_cache ()
1012+
1013+ def teardown (self ) -> None :
1014+ self ._disk_adapter .teardown ()
1015+ super ().teardown ()
0 commit comments