from collections import namedtuple

from _pytest.runner import CollectReport

from xdist.remote import Producer
from xdist.workermanage import parse_spec_config
from xdist.report import report_collection_diff


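# NodePending pairs a worker node with the list of its pending test indices;
# check_schedule() below builds these to pick steal candidates.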
NodePending = namedtuple("NodePending", ["node", "pending"])

# Every worker needs at least 2 tests in its queue - the current one and the
# next one.
MIN_PENDING = 2


class WorkStealingScheduling:
    """Implement work-stealing scheduling.

    Initially, tests are distributed evenly among all nodes.

    When some node completes most of its assigned tests (i.e. when only one
    pending test remains), an attempt is made to reassign some tests to that
    node.

    Attributes:

    :numnodes: The expected number of nodes taking part. The actual
       number of nodes will vary during the scheduler's lifetime as
       nodes are added by the DSession as they are brought up and
       removed either because of a dead node or normal shutdown. This
       number is primarily used to know when the initial collection is
       completed.

    :node2collection: Map of nodes and their test collection. All
       collections should always be identical.

    :node2pending: Map of nodes and the indices of their pending
       tests. The indices are an index into ``.pending`` (which is
       identical to their own collection stored in
       ``.node2collection``).

    :collection: The single collection, once it has been validated to be
       identical between all the nodes. It is initialised to None
       until ``.schedule()`` is called.

    :pending: List of indices of globally pending tests. These are
       tests which have not yet been allocated to a chunk for a node
       to process.

    :log: A ``Producer`` instance used for logging.

    :config: Config object, used for handling hooks.
    """

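    # Expected driving sequence (inferred from the hook names in the method
    # docstrings below): DSession calls add_node() and add_node_collection()
    # for each worker, then schedule() once collection_is_completed is True;
    # afterwards mark_test_complete(), mark_test_pending() and
    # remove_pending_tests_from_node() arrive as worker events, and
    # remove_node() is called on worker shutdown or crash.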
    def __init__(self, config, log=None):
        self.numnodes = len(parse_spec_config(config))
        self.node2collection = {}
        self.node2pending = {}
        self.pending = []
        self.collection = None
        if log is None:
            self.log = Producer("workstealsched")
        else:
            self.log = log.workstealsched
        self.config = config
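        # Node that has been sent a 'steal' request and has not yet answered;
        # at most one such request is kept in flight at a time.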
        self.steal_requested = None

    @property
    def nodes(self):
        """A list of all nodes in the scheduler."""
        return list(self.node2pending.keys())

    @property
    def collection_is_completed(self):
        """Boolean indicating whether the initial test collection is complete.

        This is a boolean indicating that all initial participating nodes
        have finished collection. The required number of initial
        nodes is defined by ``.numnodes``.
        """
        return len(self.node2collection) >= self.numnodes

    @property
    def tests_finished(self):
        """Return True if all tests have been executed by the nodes."""
        if not self.collection_is_completed:
            return False
        if self.pending:
            return False
        if self.steal_requested is not None:
            return False
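        # A node with fewer than MIN_PENDING queued items has at most the test
        # it is currently running left; once every node is in that state and no
        # steal request is outstanding, no further scheduling work remains.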
        for pending in self.node2pending.values():
            if len(pending) >= MIN_PENDING:
                return False
        return True

    @property
    def has_pending(self):
        """Return True if there are pending test items.

        This indicates that collection has finished and nodes are
        still processing test items, so this can be thought of as
        "the scheduler is active".
        """
        if self.pending:
            return True
        for pending in self.node2pending.values():
            if pending:
                return True
        return False

    def add_node(self, node):
        """Add a new node to the scheduler.

        From now on the node will be allocated chunks of tests to
        execute.

        Called by the ``DSession.worker_workerready`` hook when it
        successfully bootstraps a new node.
        """
        assert node not in self.node2pending
        self.node2pending[node] = []

    def add_node_collection(self, node, collection):
        """Add the collected test items from a node.

        The collection is stored in the ``.node2collection`` map.
        Called by the ``DSession.worker_collectionfinish`` hook.
        """
        assert node in self.node2pending
        if self.collection_is_completed:
            # A new node has been added later, perhaps an original one died.
            # ``.schedule()`` should have been called by now.
            assert self.collection
            if collection != self.collection:
                other_node = next(iter(self.node2collection.keys()))
                msg = report_collection_diff(
                    self.collection, collection, other_node.gateway.id, node.gateway.id
                )
                self.log(msg)
                return
        self.node2collection[node] = list(collection)

    def mark_test_complete(self, node, item_index, duration=None):
        """Mark test item as completed by node.

        This is called by the ``DSession.worker_testreport`` hook.
        """
        self.node2pending[node].remove(item_index)
        self.check_schedule()

    def mark_test_pending(self, item):
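        # Put the item back at the front of the global queue so it is handed
        # out again before the rest of the backlog.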
        self.pending.insert(
            0,
            self.collection.index(item),
        )
        self.check_schedule()

    def remove_pending_tests_from_node(self, node, indices):
        """Node returned some test indices back in response to the 'steal' command.

        This is called by ``DSession.worker_unscheduled``.
        """
        assert node is self.steal_requested
        self.steal_requested = None

        indices_set = set(indices)
        self.node2pending[node] = [
            i for i in self.node2pending[node] if i not in indices_set
        ]
        self.pending.extend(indices)
        self.check_schedule()

    def check_schedule(self):
        """Reschedule tests/perform load balancing."""
        nodes_up = [
            NodePending(node, pending)
            for node, pending in self.node2pending.items()
            if not node.shutting_down
        ]

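        # "Idle" means fewer than MIN_PENDING queued items, i.e. the node is
        # down to (at most) the test it is currently running.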
        def get_idle_nodes():
            return [node for node, pending in nodes_up if len(pending) < MIN_PENDING]

        idle_nodes = get_idle_nodes()
        if not idle_nodes:
            return

        if self.pending:
            # Distribute pending tests evenly among idle nodes
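            # Illustrative example: with 10 pending tests and 3 idle nodes the
            # loop sends 10 // 3 = 3, then 7 // 2 = 3, then 4 // 1 = 4 tests,
            # handing out the whole backlog in roughly equal chunks.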
            for i, node in enumerate(idle_nodes):
                nodes_remaining = len(idle_nodes) - i
                num_send = len(self.pending) // nodes_remaining
                self._send_tests(node, num_send)

        idle_nodes = get_idle_nodes()
        # No need to steal anything if all nodes have enough work to continue
        if not idle_nodes:
            return

        # Only one active stealing request is allowed
        if self.steal_requested is not None:
            return

        # Find the node that has the longest test queue
        steal_from = max(
            nodes_up, key=lambda node_pending: len(node_pending.pending), default=None
        )

        if steal_from is None:
            num_steal = 0
        else:
            # Steal half of the test queue - but keep that node running too.
            # If the node has 2 or fewer tests queued, stealing will fail
            # anyway.
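            # Illustrative example: a queue of 9 gives max_steal = 7 and
            # num_steal = min(4, 7) = 4; a queue of 2 gives num_steal = 0.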
            max_steal = max(0, len(steal_from.pending) - MIN_PENDING)
            num_steal = min(len(steal_from.pending) // 2, max_steal)

        if num_steal == 0:
            # Can't get more work - shutdown idle nodes. This will force them
            # to run the last test now instead of waiting for more tests.
            for node in idle_nodes:
                node.shutdown()
            return

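        # Ask the most loaded node to hand back the tail of its queue; the
        # reply arrives later via remove_pending_tests_from_node().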
        steal_from.node.send_steal(steal_from.pending[-num_steal:])
        self.steal_requested = steal_from.node

    def remove_node(self, node):
        """Remove a node from the scheduler.

        This should be called either when the node crashed or at
        shutdown time. In the former case any pending items assigned
        to the node will be re-scheduled. Called by the
        ``DSession.worker_workerfinished`` and
        ``DSession.worker_errordown`` hooks.

        Return the item which was being executed while the node
        crashed, or None if the node has no more pending items.
        """
        pending = self.node2pending.pop(node)

        # If the node was removed without completing its assigned tests, it crashed
        if pending:
            crashitem = self.collection[pending.pop(0)]
        else:
            crashitem = None

        self.pending.extend(pending)

        # A dead node won't respond to the "steal" request
        if self.steal_requested is node:
            self.steal_requested = None

        self.check_schedule()
        return crashitem

    def schedule(self):
        """Initiate distribution of the test collection

        Initiate scheduling of the items across the nodes. If this
        gets called again later it behaves the same as calling
        ``.check_schedule()`` on all nodes so that newly added nodes
        will start to be used.

        This is called by the ``DSession.worker_collectionfinish`` hook
        if ``.collection_is_completed`` is True.
        """
        assert self.collection_is_completed

        # Initial distribution already happened, reschedule on all nodes
        if self.collection is not None:
            self.check_schedule()
            return

        if not self._check_nodes_have_same_collection():
            self.log("**Different tests collected, aborting run**")
            return

        # Collections are identical, create the index of pending items.
        self.collection = list(self.node2collection.values())[0]
        self.pending[:] = range(len(self.collection))
        if not self.collection:
            return

        self.check_schedule()

    def _send_tests(self, node, num):
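        # Move the first `num` globally pending indices to this node's queue
        # and tell the worker to run them.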
        tests_per_node = self.pending[:num]
        if tests_per_node:
            del self.pending[:num]
            self.node2pending[node].extend(tests_per_node)
            node.send_runtest_some(tests_per_node)

    def _check_nodes_have_same_collection(self):
        """Return True if all nodes have collected the same items.

        If collections differ, this method returns False while logging
        the collection differences and posting collection errors to the
        ``pytest_collectreport`` hook.
        """
        node_collection_items = list(self.node2collection.items())
        first_node, col = node_collection_items[0]
        same_collection = True
        for node, collection in node_collection_items[1:]:
            msg = report_collection_diff(
                col, collection, first_node.gateway.id, node.gateway.id
            )
            if msg:
                same_collection = False
                self.log(msg)
                if self.config is not None:
                    rep = CollectReport(
                        node.gateway.id, "failed", longrepr=msg, result=[]
                    )
                    self.config.hook.pytest_collectreport(report=rep)

        return same_collection