@@ -95,6 +95,9 @@ class QueueManager(BaseManager):
9595 self .finish_request_barrier = [
9696 threading .Barrier (self .num_client ) for _ in range (self .local_data_parallel_size )
9797 ]
98+ self .worker_process_tp_barrier = [
99+ threading .Barrier (self .num_client ) for _ in range (self .local_data_parallel_size )
100+ ]
98101
99102 # Register shared objects with proxy types
100103 QueueManager .register (
@@ -161,6 +164,10 @@ class QueueManager(BaseManager):
161164 "get_finish_request_barrier" ,
162165 callable = lambda idx : self .finish_request_barrier [idx ],
163166 )
167+ QueueManager .register (
168+ "get_worker_process_tp_barrier" ,
169+ callable = lambda idx : self .worker_process_tp_barrier [idx ],
170+ )
164171 self .manager : BaseManager = QueueManager (address = self .address , authkey = self .authkey )
165172 self .manager .start ()
166173 else :
@@ -180,6 +187,7 @@ class QueueManager(BaseManager):
180187 QueueManager .register ("get_disaggregate_requests" )
181188 QueueManager .register ("get_available_prefill_instances" )
182189 QueueManager .register ("get_finish_request_barrier" )
190+ QueueManager .register ("get_worker_process_tp_barrier" )
183191 self .manager = QueueManager (address = self .address , authkey = self .authkey )
184192 self ._connect_with_retry ()
185193
@@ -199,6 +207,7 @@ class QueueManager(BaseManager):
199207 self .disaggregate_requests = self .manager .get_disaggregate_requests (self .local_data_parallel_id )
200208 self .available_prefill_instances = self .manager .get_available_prefill_instances ()
201209 self .finish_request_barrier = self .manager .get_finish_request_barrier (self .local_data_parallel_id )
210+ self .worker_process_tp_barrier = self .manager .get_worker_process_tp_barrier (self .local_data_parallel_id )
202211 self .finished_req_queue = self .manager .get_finish_request_queue (self .local_data_parallel_id )
203212 assert self .num_client == len (self .client_read_flag )
204213
0 commit comments