@@ -2087,32 +2087,33 @@ def recv_data_transfer_result(self):
20872087 logger .error (f"recv_data_transfer_result: { str (traceback .format_exc ())} " )
20882088 raise e
20892089
2090- def reset (self ):
2090+ def reset (self , wait_for_tasks_done = False ):
20912091 """
20922092 Reset the RadixTree.
20932093 """
2094- logger .info (f"wait for cache_task_inflight_signal to reset { self .cache_task_inflight_signal .value } " )
2095- while np .sum (self .cache_task_inflight_signal .value ) != 0 :
2096- time .sleep (0.1 )
20972094
2098- logger .info ("wait for recv_data_transfer_result done" )
2099- while not self .cache_task_queue .result_queue_empty ():
2100- time .sleep (0.1 )
2095+ if wait_for_tasks_done :
2096+ logger .info (f"wait for cache_task_inflight_signal to reset: { self .cache_task_inflight_signal .value } " )
2097+ while np .sum (self .cache_task_inflight_signal .value ) != 0 :
2098+ time .sleep (0.1 )
2099+
2100+ logger .info ("wait for recv_data_transfer_result done" )
2101+ while not self .cache_task_queue .result_queue_empty ():
2102+ time .sleep (0.1 )
2103+
2104+ logger .info ("wait for cpu_free_future to finish" )
2105+ if self .cpu_free_future is not None :
2106+ self .cpu_free_future .result ()
2107+
2108+ logger .info ("wait for gpu_free_task_future to finish" )
2109+ if self .gpu_free_task_future is not None :
2110+ self .gpu_free_task_future .result ()
21012111
21022112 logger .info (f"Resetting the RadixTree! node_map len { len (self .node_map )} " )
21032113
2104- logger .info ("waiting for cpu_free_future to finish" )
2105- if self .cpu_free_future is not None :
2106- self .cpu_free_future .result ()
2114+ # clear future & events
21072115 self .cpu_free_future = None
2108- logger .info ("reset cpu_free_future" )
2109-
2110- logger .info ("waiting for gpu_free_task_future to finish" )
2111- if self .gpu_free_task_future is not None :
2112- self .gpu_free_task_future .result ()
21132116 self .gpu_free_task_future = None
2114- logger .info ("reset gpu_free_task_future" )
2115-
21162117 self .task_swapping_event .clear ()
21172118
21182119 # clear node map
@@ -2157,11 +2158,11 @@ def clear_prefix_cache(self):
21572158 prefix_tree_status_signal = self .prefix_tree_status_signal
21582159 while True :
21592160 if prefix_tree_status_signal .value [0 ] == PrefixTreeStatus .CLEARING :
2160- self .reset ()
2161+ self .reset (wait_for_tasks_done = True )
21612162 prefix_tree_status_signal .value [0 ] = PrefixTreeStatus .CLEARED
21622163 logger .info ("Prefix cache tree is cleared." )
21632164 if prefix_tree_status_signal .value [0 ] == PrefixTreeStatus .UPDATING :
2164- self .reset ()
2165+ self .reset (wait_for_tasks_done = False )
21652166 prefix_tree_status_signal .value [0 ] = PrefixTreeStatus .NORMAL
21662167 logger .info ("Prefix cache tree is updated." )
21672168 time .sleep (0.01 )
0 commit comments