 from queue import SimpleQueue
 from selectors import EVENT_READ, DefaultSelector
 from statistics import mean
-from time import time
-from typing import Dict, NamedTuple, Optional
+from time import perf_counter
+from typing import Any, Dict, NamedTuple, Optional, Tuple

 import torch
 from prefetch_generator import BackgroundGenerator

 from hivemind.moe.server.module_backend import ModuleBackend
+from hivemind.moe.server.task_pool import TaskPoolBase
 from hivemind.utils import get_logger

 logger = get_logger(__name__)
@@ -85,15 +86,11 @@ def run(self):

                 for pool, batch_index, batch in batch_iterator:
                     logger.debug(f"Processing batch {batch_index} from pool {pool.name}")
-
-                    start = time()
+                    start = perf_counter()
                     try:
-                        outputs = pool.process_func(*batch)
-                        batch_processing_time = time() - start
-
-                        batch_size = outputs[0].size(0)
+                        outputs, batch_size = self.process_batch(pool, batch_index, *batch)
+                        batch_processing_time = perf_counter() - start
                         logger.debug(f"Pool {pool.name}: batch {batch_index} processed, size {batch_size}")
-
                         if self.stats_report_interval is not None:
                             self.stats_reporter.report_stats(pool.name, batch_size, batch_processing_time)

@@ -108,6 +105,11 @@ def run(self):
             if not self.shutdown_trigger.is_set():
                 self.shutdown()

+    def process_batch(self, pool: TaskPoolBase, batch_index: int, *batch: torch.Tensor) -> Tuple[Any, int]:
+        """process one batch of tasks from a given pool, return a batch of results and total batch size"""
+        outputs = pool.process_func(*batch)
+        return outputs, outputs[0].size(0)
+
     def shutdown(self):
         """Gracefully terminate a running runtime."""
         logger.info("Shutting down")
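The new process_batch hook separates the per-batch work from the run loop, so a subclass can change how a batch is executed or how its size is reported for stats without re-implementing the loop. A minimal sketch, assuming the enclosing class is hivemind.moe.server.runtime.Runtime (the class name is not shown in this diff); TokenCountingRuntime is a hypothetical example, not part of the change:

from typing import Any, Tuple

import torch

from hivemind.moe.server.runtime import Runtime  # assumed location of the class modified above
from hivemind.moe.server.task_pool import TaskPoolBase


class TokenCountingRuntime(Runtime):
    """Hypothetical subclass: reports the total number of input elements instead of outputs[0].size(0)."""

    def process_batch(self, pool: TaskPoolBase, batch_index: int, *batch: torch.Tensor) -> Tuple[Any, int]:
        outputs = pool.process_func(*batch)
        # use the summed element count of all inputs as the "batch size" passed to the stats reporter
        return outputs, sum(tensor.numel() for tensor in batch)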