
Commit a19b71d

[Step] Add Items distribution (array/pools)
1 parent 18cbe6d commit a19b71d

8 files changed: 221 additions & 146 deletions


exca/cachedict/inflight.py

Lines changed: 5 additions & 17 deletions
@@ -356,28 +356,18 @@ def wait_for_inflight(
 def inflight_session(
     reg: InflightRegistry | None,
     item_uids: tp.Collection[str],
-    *,
-    local: bool = False,
 ) -> tp.Iterator[list[str]]:
     """Wait for in-flight items, claim available ones, release+close on exit.

     When *reg* is ``None`` (no cache folder), yields all *item_uids*
     unchanged so that callers never need a ``None`` guard.

-    Parameters
-    ----------
-    local:
-        Set to ``True`` when items will be processed locally (no Slurm
-        submission). Stamps claims with ``_LOCAL_JOB_ID`` so ``is_alive``
-        can distinguish "local work in progress" from "Slurm submission
-        that never completed ``update_worker_info``".
-
     Self-deadlock is prevented internally: ``wait_for_inflight`` skips items
     owned by the current PID, and ``claim`` treats same-PID rows as already
     ours.

-    The registry connection is closed on exit; callers must perform any
-    ``record_worker_info`` calls inside the ``with`` block.
+    Callers should call ``record_worker_info`` inside the ``with`` block
+    to stamp claimed items with an appropriate liveness signal.
     """
     if reg is None:
         yield list(item_uids)
@@ -406,8 +396,6 @@ def inflight_session(
         msg = "Claim race: got %d/%d items, re-waiting"
         logger.info(msg, len(claimed), len(item_uids))
         time.sleep(random.uniform(0.5, 2.0))
-    if local:
-        reg.update_worker_info(claimed, job_id=_LOCAL_JOB_ID)
     try:
         yield claimed
     finally:
@@ -417,10 +405,10 @@ def inflight_session(


 def record_worker_info(
-    reg: InflightRegistry, item_uids: list[str], job: submitit.Job[tp.Any]
+    reg: InflightRegistry, item_uids: list[str], job: tp.Any = None
 ) -> None:
-    """Stamp a submitit *job*'s worker info on *item_uids*. Slurm jobs get
-    job_id + folder; other backends get the local sentinel."""
+    """Stamp worker info on *item_uids*. Slurm jobs get job_id + folder;
+    everything else (including no job) gets the local PID sentinel."""
     if isinstance(job, submitit.SlurmJob):
         reg.update_worker_info(
             item_uids, job_id=str(job.job_id), job_folder=str(job.paths.folder)
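
The resulting calling convention, exercised by the map.py and backends.py
changes below, looks roughly like this (a sketch assembled from this diff;
`cache_folder` and the uids are illustrative):

    reg = inflight.InflightRegistry(cache_folder)  # or None when there is no cache
    with inflight.inflight_session(reg, ["uid-1", "uid-2"]) as claimed:
        if reg is not None:
            # No `job` argument: stamps the local PID sentinel (_LOCAL_JOB_ID),
            # replacing what `local=True` used to do inside the session.
            inflight.record_worker_info(reg, claimed)
        ...  # compute the claimed items while the claims are held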

exca/cachedict/test_inflight.py

Lines changed: 4 additions & 2 deletions
@@ -79,9 +79,11 @@ def fresh() -> inflight.InflightRegistry:
         raise ValueError("boom")
     assert seen(["a"]) == {}

-    # Local: claims marked with job_id="local".
-    with inflight.inflight_session(fresh(), ["loc"], local=True) as claimed:
+    # Local: record_worker_info without job stamps _LOCAL_JOB_ID.
+    reg = fresh()
+    with inflight.inflight_session(reg, ["loc"]) as claimed:
         assert claimed == ["loc"]
+        inflight.record_worker_info(reg, claimed)
     assert seen(["loc"])["loc"].job_id == inflight._LOCAL_JOB_ID

     # Nested: inner session must NOT release outer's claim.
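
The trailing "nested" comment refers to the same-PID invariant documented in
``inflight_session``; a sketch of the shape being tested (not the actual test
body):

    with inflight.inflight_session(reg, ["x"]):
        with inflight.inflight_session(reg, ["x"]):
            pass  # same-PID rows are treated as already ours: nothing new claimed
        # exiting the inner session must leave the outer claim on "x" in place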

exca/map.py

Lines changed: 11 additions & 48 deletions
@@ -18,9 +18,8 @@
 import numpy as np
 import pydantic

-from . import base, slurm
+from . import base, slurm, utils
 from .cachedict import CacheDict, inflight
-from .utils import ShortItemUid


 @dataclasses.dataclass
@@ -70,46 +69,7 @@ def __call__(self, items: tp.Sequence[tp.Any]) -> tp.Iterator[tp.Any]:
         return self.infra._method_override(items)


-def _make_pool_executor(pool: str, max_workers: int) -> futures.Executor:
-    """Falls back to ThreadPoolExecutor (with a warning) if ``pool="processpool"``
-    and ``ProcessPoolExecutor`` cannot be created (eg ``sem_open`` EPERM)."""
-    if pool == "processpool":
-        try:
-            return futures.ProcessPoolExecutor(max_workers=max_workers)
-        except PermissionError as e:
-            logger.warning(
-                "ProcessPoolExecutor unavailable (%s); falling back to "
-                "ThreadPoolExecutor. Set cluster='threadpool' or cluster=None "
-                "to silence this warning.",
-                e,
-            )
-    return futures.ThreadPoolExecutor(max_workers=max_workers)
-
-
-def to_chunks(
-    items: list[X], *, max_chunks: int | None, min_items_per_chunk: int = 1
-) -> tp.Iterator[list[X]]:
-    """Split a list of items into several smaller lists of items
-
-    Parameters
-    ----------
-    max_chunks: optional int
-        maximum number of chunks to create
-    min_items_per_chunk: int
-        minimum number of items per chunk
-
-    Yields
-    ------
-    list of items
-    """
-    splits = min(
-        len(items) if max_chunks is None else max_chunks,
-        int(np.ceil(len(items) / min_items_per_chunk)),
-    )
-    items_per_chunk = int(np.ceil(len(items) / splits))
-    for k in range(splits):
-        # select a batch/chunk of items_per_chunk items to send to a job
-        yield items[k * items_per_chunk : (k + 1) * items_per_chunk]
+_make_pool_executor = utils.make_pool_executor  # deprecated


 class MapInfra(base.BaseInfra, slurm.SubmititMixin):
@@ -281,7 +241,9 @@ def apply(
         params.pop("self")
         max_length = params.pop("item_uid_max_length")
         if max_length is not None:
-            params["item_uid"] = ShortItemUid(params["item_uid"], max_length=max_length)
+            params["item_uid"] = utils.ShortItemUid(
+                params["item_uid"], max_length=max_length
+            )
         if self._infra_method is not None:
             raise RuntimeError(f"Infra was already applied: {self._infra_method}")

@@ -392,7 +354,7 @@ def _method_override(self, *args: tp.Any, **kwargs: tp.Any) -> tp.Iterator[tp.An
         if missing:
             jobs: list[tp.Any] = []
             uid_item_chunks = list(
-                to_chunks(
+                utils.to_chunks(
                     missing,
                     max_chunks=self.max_jobs,
                     min_items_per_chunk=self.min_samples_per_job,
@@ -455,9 +417,10 @@ def _method_override_futures(self, items: tp.Sequence[tp.Any]) -> tp.Iterator[tp
             pool = None
         # avoid processing same files at same time if several jobs overlap
        np.random.shuffle(missing)
-        with inflight.inflight_session(
-            self._inflight_registry(), [k for k, _ in missing], local=True
-        ) as claimed_uids:
+        reg = self._inflight_registry()
+        with inflight.inflight_session(reg, [k for k, _ in missing]) as claimed_uids:
+            if reg is not None:
+                inflight.record_worker_info(reg, claimed_uids)
            claimed_set = set(claimed_uids)
            # Re-check cache after wait: other workers may have completed
            # items while we were blocked in inflight_session.
@@ -487,7 +450,7 @@ def _method_override_futures(self, items: tp.Sequence[tp.Any]) -> tp.Iterator[tp
            max_workers = min(max_workers, self.max_jobs)
        with _make_pool_executor(pool, max_workers) as ex:
            mitems = [ki[1] for ki in missing]
-           chunks = to_chunks(mitems, max_chunks=3 * max_workers)
+           chunks = utils.to_chunks(mitems, max_chunks=3 * max_workers)
            for chunk in chunks:
                j = ex.submit(
                    self._call_and_store,
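
Since the chunking helper now lives in `exca.utils`, its arithmetic is worth a
quick worked example (values are made up; behavior follows the deleted source
above, which uses `np.ceil` rather than `math.ceil`):

    import math

    items = list(range(10))                    # 10 items
    # utils.to_chunks(items, max_chunks=3, min_items_per_chunk=1) computes:
    splits = min(3, math.ceil(10 / 1))         # -> 3 chunks
    items_per_chunk = math.ceil(10 / 3)        # -> 4 items per chunk
    chunks = [items[k * 4 : (k + 1) * 4] for k in range(3)]
    # -> [0, 1, 2, 3], [4, 5, 6, 7], [8, 9]; the last chunk takes the remainder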

exca/steps/backends.py

Lines changed: 108 additions & 31 deletions
@@ -14,16 +14,20 @@

 import dataclasses
 import logging
+import os
+import random
 import shutil
 import traceback
 import typing as tp
 import warnings
+from concurrent import futures
 from pathlib import Path

 import pydantic
 import submitit

 import exca
+from exca import utils
 from exca.cachedict import inflight

 from . import errors, identity, items
@@ -280,7 +284,7 @@ def _exclude_from_cls_uid(cls) -> list[str]:
         return ["."]  # force ignored in uid

     # Used by Backend._run for the inflight registry (concurrent worker safety).
-    _is_off_process: tp.ClassVar[bool] = False
+    _concurrent: tp.ClassVar[bool] = False

     folder: Path | None = None
     # deprecated: declare `CACHE_TYPE` on the Step subclass.
@@ -364,7 +368,7 @@ def _clear_cache(
         uid: str,
     ) -> None:
         """Drop everything cached for this uid (cd row, error row, job folder)."""
-        if self._is_off_process and paths.cache_folder.exists():
+        if self._concurrent and paths.cache_folder.exists():
             try:
                 reg = inflight.InflightRegistry(paths.cache_folder)
                 info = reg.get([uid])
@@ -421,23 +425,19 @@ def _run(self, step: Step, batch: items.StepItems) -> items.StepItems:
         for uid in to_compute:
             paths._ensure_folders(uid)
         reg: inflight.InflightRegistry | None = None
-        if self._is_off_process:
+        if self._concurrent:
             reg = inflight.InflightRegistry(paths.cache_folder)
         with inflight.inflight_session(reg, to_compute):
-            # Re-check after lock (another worker may have finished).
             recheck = _CachedEntry.lookup_statuses(cd, to_compute)
-            still_uids = [u for u in to_compute if recheck[u] != "success"]
-            if still_uids:
-                filtered = batch.select(still_uids)
+            pending_uids = [u for u in to_compute if recheck[u] != "success"]
+            if pending_uids:
+                filtered = batch.select(pending_uids)
                 wrapper = _CachingCall(step, cd, paths.step_uid)
                 try:
-                    job = self._submit(wrapper, filtered, paths=paths)
-                    if reg is not None:
-                        inflight.record_worker_info(reg, still_uids, job)
-                    job.result()
+                    self._execute(wrapper, filtered, paths=paths, reg=reg)
                 finally:
                     if mode in ("force", "retry"):
-                        self._recomputed.update(still_uids)
+                        self._recomputed.update(pending_uids)
         verify = _CachedEntry.lookup_statuses(cd, to_compute)
         for uid in to_compute:
             if verify[uid] != "success":
@@ -452,17 +452,16 @@ def _run(self, step: Step, batch: items.StepItems) -> items.StepItems:
                 mode=mode,
             )

-    def _submit(self, wrapper: _CachingCall, *args: tp.Any, paths: StepPaths) -> tp.Any:
-        """Submit wrapper for execution. Default: inline execution."""
-        wrapper(*args)
-        return _InlineJob()
-
-
-class _InlineJob:
-    """Completion signal for inline execution; the value lives in cache."""
-
-    def result(self) -> None:
-        return None
+    def _execute(
+        self,
+        wrapper: _CachingCall,
+        pending: items.StepItems,
+        *,
+        paths: StepPaths,
+        reg: inflight.InflightRegistry | None,
+    ) -> None:
+        """Run *wrapper* on *pending* items. Override for chunking/pools/arrays."""
+        wrapper(pending)


 class Cached(Backend):
@@ -473,7 +472,7 @@ class _SubmititBackend(Backend):
     """Base for submitit backends."""

     # Submitit cloud-pickles inputs to a worker (`SubmititDebug` overrides: inline).
-    _is_off_process: tp.ClassVar[bool] = True
+    _concurrent: tp.ClassVar[bool] = True

     job_name: str | None = None
     timeout_min: int | None = None
@@ -482,24 +481,51 @@ class _SubmititBackend(Backend):
     cpus_per_task: int | None = None
     gpus_per_node: int | None = None
     mem_gb: float | None = None
+    max_jobs: int = 128
+    min_items_per_job: int = 1

     # passed as `cluster=` to submitit.AutoExecutor; subclasses pin it.
     _CLUSTER: tp.ClassVar[str | None] = None

     def _submitit_params(self) -> dict[str, tp.Any]:
         """Build the kwargs dict forwarded to ``AutoExecutor.update_parameters``."""
         fields = set(type(self).model_fields) - set(Backend.model_fields)
-        params = {k: getattr(self, k) for k in fields if getattr(self, k) is not None}
+        skip = {"max_jobs", "min_items_per_job"}
+        params = {
+            k: getattr(self, k) for k in fields - skip if getattr(self, k) is not None
+        }
         if "job_name" in params:
             params["name"] = params.pop("job_name")
         return params

-    def _submit(self, wrapper: _CachingCall, *args: tp.Any, paths: StepPaths) -> tp.Any:
+    def _execute(
+        self,
+        wrapper: _CachingCall,
+        pending: items.StepItems,
+        *,
+        paths: StepPaths,
+        reg: inflight.InflightRegistry | None,
+    ) -> None:
+        uids = list(pending.uids)
+        random.shuffle(uids)
+        chunks = [
+            pending.select(c)
+            for c in utils.to_chunks(
+                uids, max_chunks=self.max_jobs, min_items_per_chunk=self.min_items_per_job
+            )
+        ]
         executor = submitit.AutoExecutor(folder=paths._logs_folder, cluster=self._CLUSTER)
-        executor.update_parameters(**self._submitit_params())
-        with submitit.helpers.clean_env():
-            job = executor.submit(wrapper, *args)
-        return job
+        params = self._submitit_params()
+        if self._CLUSTER in ("slurm", None):
+            params["slurm_array_parallelism"] = len(chunks)
+        executor.update_parameters(**params)
+        with submitit.helpers.clean_env(), executor.batch():
+            jobs = [executor.submit(wrapper, c) for c in chunks]
+        if reg is not None:
+            for c, j in zip(chunks, jobs):
+                inflight.record_worker_info(reg, c.uids, j)
+        for j in jobs:
+            j.result()


 class LocalProcess(_SubmititBackend):
@@ -512,7 +538,7 @@ class SubmititDebug(_SubmititBackend):
    """Debug executor (inline but simulates submitit)."""

    _CLUSTER: tp.ClassVar[str | None] = "debug"
-   _is_off_process: tp.ClassVar[bool] = False
+   _concurrent: tp.ClassVar[bool] = False


 class Slurm(_SubmititBackend):
@@ -542,3 +568,54 @@ class Auto(Slurm):
     """Auto-detect executor (local or Slurm). Slurm fields only apply on slurm."""

     _CLUSTER: tp.ClassVar[str | None] = None
+
+
+class _PoolBackend(Backend):
+    """Base for concurrent.futures pool backends."""
+
+    max_jobs: int | None = 128
+    _POOL_TYPE: tp.ClassVar[str]
+
+    def _execute(
+        self,
+        wrapper: _CachingCall,
+        pending: items.StepItems,
+        *,
+        paths: StepPaths,
+        reg: inflight.InflightRegistry | None,
+    ) -> None:
+        uids = list(pending.uids)
+        random.shuffle(uids)
+        cpus = max(1, (os.cpu_count() or 1) - 1)
+        max_workers = min(len(uids), cpus)
+        if self.max_jobs is not None:
+            max_workers = min(max_workers, self.max_jobs)
+        chunks = [
+            pending.select(c) for c in utils.to_chunks(uids, max_chunks=3 * max_workers)
+        ]
+        if reg is not None:
+            for c in chunks:
+                inflight.record_worker_info(reg, c.uids)
+        with utils.make_pool_executor(self._POOL_TYPE, max_workers) as pool:
+            futs = [pool.submit(wrapper, c) for c in chunks]
+            try:
+                for f in futures.as_completed(futs):
+                    f.result()
+            except BaseException:
+                for f in futs:
+                    f.cancel()
+                raise
+
+
+class ProcessPool(_PoolBackend):
+    """Process pool execution + caching."""
+
+    _POOL_TYPE: tp.ClassVar[str] = "processpool"
+    _concurrent: tp.ClassVar[bool] = True
+
+
+class ThreadPool(_PoolBackend):
+    """Thread pool execution + caching."""
+
+    _POOL_TYPE: tp.ClassVar[str] = "threadpool"
+    _concurrent: tp.ClassVar[bool] = True
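
To make the pool sizing concrete, here is the `_PoolBackend._execute`
arithmetic traced on made-up numbers (100 uids, an 8-core machine, the default
`max_jobs=128`):

    n_uids, total_cpus = 100, 8                # illustrative values
    cpus = max(1, total_cpus - 1)              # -> 7, leave one core free
    max_workers = min(n_uids, cpus)            # -> 7
    max_workers = min(max_workers, 128)        # max_jobs cap: still 7
    n_chunks = 3 * max_workers                 # to_chunks target: 21 chunks
    # ~3 chunks per worker keeps the pool busy when chunk runtimes vary.

Note the asymmetry with the submitit path: `_PoolBackend` stamps worker info
before submitting (there is no job object, so the local PID sentinel applies),
while `_SubmititBackend._execute` stamps each chunk with its job only after
`executor.batch()` exits, presumably because array job ids are assigned at
batch submission time.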
