Skip to content

Commit 57ee262

Browse files
authored
Merge branch 'master' into macos-binary
2 parents 752cdda + 33a9a41 commit 57ee262

File tree

1 file changed

+11
-9
lines changed

1 file changed

+11
-9
lines changed

hivemind/moe/server/runtime.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,14 @@
66
from queue import SimpleQueue
77
from selectors import EVENT_READ, DefaultSelector
88
from statistics import mean
9-
from time import time
10-
from typing import Dict, NamedTuple, Optional
9+
from time import perf_counter
10+
from typing import Any, Dict, NamedTuple, Optional, Tuple
1111

1212
import torch
1313
from prefetch_generator import BackgroundGenerator
1414

1515
from hivemind.moe.server.module_backend import ModuleBackend
16+
from hivemind.moe.server.task_pool import TaskPoolBase
1617
from hivemind.utils import get_logger
1718

1819
logger = get_logger(__name__)
@@ -85,15 +86,11 @@ def run(self):
8586

8687
for pool, batch_index, batch in batch_iterator:
8788
logger.debug(f"Processing batch {batch_index} from pool {pool.name}")
88-
89-
start = time()
89+
start = perf_counter()
9090
try:
91-
outputs = pool.process_func(*batch)
92-
batch_processing_time = time() - start
93-
94-
batch_size = outputs[0].size(0)
91+
outputs, batch_size = self.process_batch(pool, batch_index, *batch)
92+
batch_processing_time = perf_counter() - start
9593
logger.debug(f"Pool {pool.name}: batch {batch_index} processed, size {batch_size}")
96-
9794
if self.stats_report_interval is not None:
9895
self.stats_reporter.report_stats(pool.name, batch_size, batch_processing_time)
9996

@@ -108,6 +105,11 @@ def run(self):
108105
if not self.shutdown_trigger.is_set():
109106
self.shutdown()
110107

108+
def process_batch(self, pool: TaskPoolBase, batch_index: int, *batch: torch.Tensor) -> Tuple[Any, int]:
109+
"""process one batch of tasks from a given pool, return a batch of results and total batch size"""
110+
outputs = pool.process_func(*batch)
111+
return outputs, outputs[0].size(0)
112+
111113
def shutdown(self):
112114
"""Gracefully terminate a running runtime."""
113115
logger.info("Shutting down")

0 commit comments

Comments
 (0)