
Commit 361d8dd

Implement work-stealing scheduling
Closes pytest-dev#858
1 parent c81ac4d commit 361d8dd

File tree: 7 files changed, +518 -2 lines


changelog/858.feature

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
New ``worksteal`` scheduler, based on the idea of `work stealing <https://en.wikipedia.org/wiki/Work_stealing>`_. It is similar to the ``load`` scheduler, but it should handle tests with significantly differing durations better and, at the same time, provide similar or better reuse of fixtures.
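For orientation (not part of the changelog file itself): the new mode is selected like the existing ones. Assuming a pytest-xdist build that contains this commit, a minimal invocation could look like:

    pytest -n 4 --dist worksteal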

src/xdist/dsession.py

Lines changed: 13 additions & 0 deletions
@@ -8,6 +8,7 @@
     LoadScopeScheduling,
     LoadFileScheduling,
     LoadGroupScheduling,
+    WorkStealingScheduling,
 )
@@ -100,6 +101,7 @@ def pytest_xdist_make_scheduler(self, config, log):
             "loadscope": LoadScopeScheduling,
             "loadfile": LoadFileScheduling,
             "loadgroup": LoadGroupScheduling,
+            "worksteal": WorkStealingScheduling,
         }
         return schedulers[dist](config, log)

@@ -282,6 +284,17 @@ def worker_runtest_protocol_complete(self, node, item_index, duration):
         """
         self.sched.mark_test_complete(node, item_index, duration)
 
+    def worker_unscheduled(self, node, indices):
+        """
+        Emitted when a node fires the 'unscheduled' event, signalling that
+        some tests have been removed from the worker's queue and should be
+        sent to some worker again.
+
+        This should happen only in response to the 'steal' command, so
+        schedulers that do not use the 'steal' command do not have to
+        implement it.
+        """
+        self.sched.remove_pending_tests_from_node(node, indices)
+
     def worker_collectreport(self, node, rep):
         """Emitted when a node calls the pytest_collectreport hook.

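To illustrate the round-trip this hook completes, here is a toy, self-contained Python sketch; the Toy* classes and the return-value shortcut are invented for illustration and are not the real execnet-based xdist API:

# Toy sketch of the 'steal' -> 'unscheduled' round-trip (hypothetical names).
class ToyNode:
    def __init__(self, queue):
        self.queue = list(queue)  # indices of tests queued but not yet started

    def send_steal(self, indices):
        # The real worker answers asynchronously with an 'unscheduled' event;
        # here we simply return the indices it gives back.
        given_back = [i for i in indices if i in self.queue]
        self.queue = [i for i in self.queue if i not in given_back]
        return given_back


class ToyScheduler:
    def __init__(self):
        self.node2pending = {}
        self.pending = []  # indices available to be re-sent to idle workers

    def remove_pending_tests_from_node(self, node, indices):
        dropped = set(indices)
        self.node2pending[node] = [
            i for i in self.node2pending[node] if i not in dropped
        ]
        self.pending.extend(indices)


class ToySession:
    """Forwards the event to the scheduler, like DSession.worker_unscheduled above."""

    def __init__(self, scheduler):
        self.sched = scheduler

    def worker_unscheduled(self, node, indices):
        self.sched.remove_pending_tests_from_node(node, indices)


scheduler = ToyScheduler()
node = ToyNode(queue=[3, 4, 5, 6])
scheduler.node2pending[node] = list(node.queue)

stolen = node.send_steal([5, 6])                # ask the node to give back two tests
ToySession(scheduler).worker_unscheduled(node, stolen)
print(scheduler.pending)                        # [5, 6], ready to re-send elsewhere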
src/xdist/plugin.py

Lines changed: 11 additions & 1 deletion
@@ -94,7 +94,15 @@ def pytest_addoption(parser):
         "--dist",
         metavar="distmode",
         action="store",
-        choices=["each", "load", "loadscope", "loadfile", "loadgroup", "no"],
+        choices=[
+            "each",
+            "load",
+            "loadscope",
+            "loadfile",
+            "loadgroup",
+            "worksteal",
+            "no",
+        ],
         dest="dist",
         default="no",
         help=(
help=(
@@ -107,6 +115,8 @@ def pytest_addoption(parser):
             "loadfile: load balance by sending test grouped by file"
             " to any available environment.\n\n"
             "loadgroup: like load, but sends tests marked with 'xdist_group' to the same worker.\n\n"
+            "worksteal: split the test suite between available environments,"
+            " then rebalance when any worker runs out of tests.\n\n"
             "(default) no: run tests inprocess, don't distribute."
         ),
     )

src/xdist/scheduler/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -3,3 +3,4 @@
 from xdist.scheduler.loadfile import LoadFileScheduling  # noqa
 from xdist.scheduler.loadscope import LoadScopeScheduling  # noqa
 from xdist.scheduler.loadgroup import LoadGroupScheduling  # noqa
+from xdist.scheduler.worksteal import WorkStealingScheduling  # noqa

src/xdist/scheduler/worksteal.py

Lines changed: 319 additions & 0 deletions
@@ -0,0 +1,319 @@
from collections import namedtuple

from _pytest.runner import CollectReport

from xdist.remote import Producer
from xdist.workermanage import parse_spec_config
from xdist.report import report_collection_diff


NodePending = namedtuple("NodePending", ["node", "pending"])

# Every worker needs at least 2 tests in queue - the current and the next one.
MIN_PENDING = 2


class WorkStealingScheduling:
    """Implement work-stealing scheduling.

    Initially, tests are distributed evenly among all nodes.

    When some node completes most of its assigned tests (when only one pending
    test remains), an attempt to reassign some tests to that node is made.

    Attributes:

    :numnodes: The expected number of nodes taking part. The actual
        number of nodes will vary during the scheduler's lifetime as
        nodes are added by the DSession as they are brought up and
        removed either because of a dead node or normal shutdown. This
        number is primarily used to know when the initial collection is
        completed.

    :node2collection: Map of nodes and their test collection. All
        collections should always be identical.

    :node2pending: Map of nodes and the indices of their pending
        tests. The indices are an index into ``.pending`` (which is
        identical to their own collection stored in
        ``.node2collection``).

    :collection: The one collection once it is validated to be
        identical between all the nodes. It is initialised to None
        until ``.schedule()`` is called.

    :pending: List of indices of globally pending tests. These are
        tests which have not yet been allocated to a chunk for a node
        to process.

    :log: A py.log.Producer instance.

    :config: Config object, used for handling hooks.
    """

    def __init__(self, config, log=None):
        self.numnodes = len(parse_spec_config(config))
        self.node2collection = {}
        self.node2pending = {}
        self.pending = []
        self.collection = None
        if log is None:
            self.log = Producer("workstealsched")
        else:
            self.log = log.workstealsched
        self.config = config
        self.steal_requested = None

    @property
    def nodes(self):
        """A list of all nodes in the scheduler."""
        return list(self.node2pending.keys())

    @property
    def collection_is_completed(self):
        """Boolean indicating whether the initial test collection is complete.

        This is a boolean indicating all initial participating nodes
        have finished collection. The required number of initial
        nodes is defined by ``.numnodes``.
        """
        return len(self.node2collection) >= self.numnodes

    @property
    def tests_finished(self):
        """Return True if all tests have been executed by the nodes."""
        if not self.collection_is_completed:
            return False
        if self.pending:
            return False
        if self.steal_requested is not None:
            return False
        for pending in self.node2pending.values():
            if len(pending) >= MIN_PENDING:
                return False
        return True

    @property
    def has_pending(self):
        """Return True if there are pending test items

        This indicates that collection has finished and nodes are
        still processing test items, so this can be thought of as
        "the scheduler is active".
        """
        if self.pending:
            return True
        for pending in self.node2pending.values():
            if pending:
                return True
        return False

    def add_node(self, node):
        """Add a new node to the scheduler.

        From now on the node will be allocated chunks of tests to
        execute.

        Called by the ``DSession.worker_workerready`` hook when it
        successfully bootstraps a new node.
        """
        assert node not in self.node2pending
        self.node2pending[node] = []

    def add_node_collection(self, node, collection):
        """Add the collected test items from a node

        The collection is stored in the ``.node2collection`` map.
        Called by the ``DSession.worker_collectionfinish`` hook.
        """
        assert node in self.node2pending
        if self.collection_is_completed:
            # A new node has been added later, perhaps an original one died.
            # .schedule() should have been called by now
            assert self.collection
            if collection != self.collection:
                other_node = next(iter(self.node2collection.keys()))
                msg = report_collection_diff(
                    self.collection, collection, other_node.gateway.id, node.gateway.id
                )
                self.log(msg)
            return
        self.node2collection[node] = list(collection)

    def mark_test_complete(self, node, item_index, duration=None):
        """Mark test item as completed by node

        This is called by the ``DSession.worker_testreport`` hook.
        """
        self.node2pending[node].remove(item_index)
        self.check_schedule()

    def mark_test_pending(self, item):
        self.pending.insert(
            0,
            self.collection.index(item),
        )
        self.check_schedule()

    def remove_pending_tests_from_node(self, node, indices):
        """Node returned some test indices back in response to the 'steal' command.

        This is called by ``DSession.worker_unscheduled``.
        """
        assert node is self.steal_requested
        self.steal_requested = None

        indices_set = set(indices)
        self.node2pending[node] = [
            i for i in self.node2pending[node] if i not in indices_set
        ]
        self.pending.extend(indices)
        self.check_schedule()

    def check_schedule(self):
        """Reschedule tests/perform load balancing."""
        nodes_up = [
            NodePending(node, pending)
            for node, pending in self.node2pending.items()
            if not node.shutting_down
        ]

        def get_idle_nodes():
            return [node for node, pending in nodes_up if len(pending) < MIN_PENDING]

        idle_nodes = get_idle_nodes()
        if not idle_nodes:
            return

        if self.pending:
            # Distribute pending tests evenly among idle nodes
            for i, node in enumerate(idle_nodes):
                nodes_remaining = len(idle_nodes) - i
                num_send = len(self.pending) // nodes_remaining
                self._send_tests(node, num_send)

            idle_nodes = get_idle_nodes()
            # No need to steal anything if all nodes have enough work to continue
            if not idle_nodes:
                return

        # Only one active stealing request is allowed
        if self.steal_requested is not None:
            return

        # Find the node that has the longest test queue
        steal_from = max(
            nodes_up, key=lambda node_pending: len(node_pending.pending), default=None
        )

        if steal_from is None:
            num_steal = 0
        else:
            # Steal half of the test queue - but keep that node running too.
            # If the node has 2 or fewer tests queued, stealing will fail
            # anyway.
            max_steal = max(0, len(steal_from.pending) - MIN_PENDING)
            num_steal = min(len(steal_from.pending) // 2, max_steal)

        if num_steal == 0:
            # Can't get more work - shutdown idle nodes. This will force them
            # to run the last test now instead of waiting for more tests.
            for node in idle_nodes:
                node.shutdown()
            return

        steal_from.node.send_steal(steal_from.pending[-num_steal:])
        self.steal_requested = steal_from.node

    def remove_node(self, node):
        """Remove a node from the scheduler

        This should be called either when the node crashed or at
        shutdown time. In the former case any pending items assigned
        to the node will be re-scheduled. Called by the
        ``DSession.worker_workerfinished`` and
        ``DSession.worker_errordown`` hooks.

        Return the item which was being executed while the node
        crashed or None if the node has no more pending items.

        """
        pending = self.node2pending.pop(node)

        # If node was removed without completing its assigned tests - it crashed
        if pending:
            crashitem = self.collection[pending.pop(0)]
        else:
            crashitem = None

        self.pending.extend(pending)

        # Dead node won't respond to "steal" request
        if self.steal_requested is node:
            self.steal_requested = None

        self.check_schedule()
        return crashitem

    def schedule(self):
        """Initiate distribution of the test collection

        Initiate scheduling of the items across the nodes. If this
        gets called again later it behaves the same as calling
        ``.check_schedule()`` on all nodes so that newly added nodes
        will start to be used.

        This is called by the ``DSession.worker_collectionfinish`` hook
        if ``.collection_is_completed`` is True.
        """
        assert self.collection_is_completed

        # Initial distribution already happened, reschedule on all nodes
        if self.collection is not None:
            self.check_schedule()
            return

        if not self._check_nodes_have_same_collection():
            self.log("**Different tests collected, aborting run**")
            return

        # Collections are identical, create the index of pending items.
        self.collection = list(self.node2collection.values())[0]
        self.pending[:] = range(len(self.collection))
        if not self.collection:
            return

        self.check_schedule()

    def _send_tests(self, node, num):
        tests_per_node = self.pending[:num]
        if tests_per_node:
            del self.pending[:num]
            self.node2pending[node].extend(tests_per_node)
            node.send_runtest_some(tests_per_node)

    def _check_nodes_have_same_collection(self):
        """Return True if all nodes have collected the same items.

        If collections differ, this method returns False while logging
        the collection differences and posting collection errors to
        pytest_collectreport hook.
        """
        node_collection_items = list(self.node2collection.items())
        first_node, col = node_collection_items[0]
        same_collection = True
        for node, collection in node_collection_items[1:]:
            msg = report_collection_diff(
                col, collection, first_node.gateway.id, node.gateway.id
            )
            if msg:
                same_collection = False
                self.log(msg)
                if self.config is not None:
                    rep = CollectReport(
                        node.gateway.id, "failed", longrepr=msg, result=[]
                    )
                    self.config.hook.pytest_collectreport(report=rep)

        return same_collection
