1414
1515import dataclasses
1616import logging
17+ import os
18+ import random
1719import shutil
1820import traceback
1921import typing as tp
2022import warnings
23+ from concurrent import futures
2124from pathlib import Path
2225
2326import pydantic
2427import submitit
2528
2629import exca
30+ from exca import utils
2731from exca .cachedict import inflight
2832
2933from . import errors , identity , items
@@ -280,7 +284,7 @@ def _exclude_from_cls_uid(cls) -> list[str]:
280284 return ["." ] # force ignored in uid
281285
282286 # Used by Backend._run for the inflight registry (concurrent worker safety).
283- _is_off_process : tp .ClassVar [bool ] = False
287+ _concurrent : tp .ClassVar [bool ] = False
284288
285289 folder : Path | None = None
286290 # deprecated: declare `CACHE_TYPE` on the Step subclass.
@@ -364,7 +368,7 @@ def _clear_cache(
364368 uid : str ,
365369 ) -> None :
366370 """Drop everything cached for this uid (cd row, error row, job folder)."""
367- if self ._is_off_process and paths .cache_folder .exists ():
371+ if self ._concurrent and paths .cache_folder .exists ():
368372 try :
369373 reg = inflight .InflightRegistry (paths .cache_folder )
370374 info = reg .get ([uid ])
@@ -421,23 +425,19 @@ def _run(self, step: Step, batch: items.StepItems) -> items.StepItems:
421425 for uid in to_compute :
422426 paths ._ensure_folders (uid )
423427 reg : inflight .InflightRegistry | None = None
424- if self ._is_off_process :
428+ if self ._concurrent :
425429 reg = inflight .InflightRegistry (paths .cache_folder )
426430 with inflight .inflight_session (reg , to_compute ):
427- # Re-check after lock (another worker may have finished).
428431 recheck = _CachedEntry .lookup_statuses (cd , to_compute )
429- still_uids = [u for u in to_compute if recheck [u ] != "success" ]
430- if still_uids :
431- filtered = batch .select (still_uids )
432+ pending_uids = [u for u in to_compute if recheck [u ] != "success" ]
433+ if pending_uids :
434+ filtered = batch .select (pending_uids )
432435 wrapper = _CachingCall (step , cd , paths .step_uid )
433436 try :
434- job = self ._submit (wrapper , filtered , paths = paths )
435- if reg is not None :
436- inflight .record_worker_info (reg , still_uids , job )
437- job .result ()
437+ self ._execute (wrapper , filtered , paths = paths , reg = reg )
438438 finally :
439439 if mode in ("force" , "retry" ):
440- self ._recomputed .update (still_uids )
440+ self ._recomputed .update (pending_uids )
441441 verify = _CachedEntry .lookup_statuses (cd , to_compute )
442442 for uid in to_compute :
443443 if verify [uid ] != "success" :
@@ -452,17 +452,16 @@ def _run(self, step: Step, batch: items.StepItems) -> items.StepItems:
452452 mode = mode ,
453453 )
454454
455- def _submit (self , wrapper : _CachingCall , * args : tp .Any , paths : StepPaths ) -> tp .Any :
456- """Submit wrapper for execution. Default: inline execution."""
457- wrapper (* args )
458- return _InlineJob ()
459-
460-
461- class _InlineJob :
462- """Completion signal for inline execution; the value lives in cache."""
463-
464- def result (self ) -> None :
465- return None
    def _execute(
        self,
        wrapper: _CachingCall,
        pending: items.StepItems,
        *,
        paths: StepPaths,
        reg: inflight.InflightRegistry | None,
    ) -> None:
        """Run *wrapper* on *pending* items. Override for chunking/pools/arrays.

        Parameters
        ----------
        wrapper: _CachingCall
            callable wrapping the step; computes and stores results into the cache
        pending: items.StepItems
            the items that still need to be computed
        paths: StepPaths
            folder layout for this step (unused by this base implementation;
            available for overrides that need log/cache folders)
        reg: inflight.InflightRegistry | None
            registry for concurrent-worker safety; non-None only on backends
            with ``_concurrent = True`` (unused here; overrides use it to
            record worker info)
        """
        # Base behavior: synchronous, in-process execution of the whole
        # batch in a single call.
        wrapper(pending)
466465
467466
468467class Cached (Backend ):
@@ -473,7 +472,7 @@ class _SubmititBackend(Backend):
473472 """Base for submitit backends."""
474473
475474 # Submitit cloud-pickles inputs to a worker (`SubmititDebug` overrides: inline).
476- _is_off_process : tp .ClassVar [bool ] = True
475+ _concurrent : tp .ClassVar [bool ] = True
477476
478477 job_name : str | None = None
479478 timeout_min : int | None = None
@@ -482,24 +481,51 @@ class _SubmititBackend(Backend):
482481 cpus_per_task : int | None = None
483482 gpus_per_node : int | None = None
484483 mem_gb : float | None = None
484+ max_jobs : int = 128
485+ min_items_per_job : int = 1
485486
486487 # passed as `cluster=` to submitit.AutoExecutor; subclasses pin it.
487488 _CLUSTER : tp .ClassVar [str | None ] = None
488489
489490 def _submitit_params (self ) -> dict [str , tp .Any ]:
490491 """Build the kwargs dict forwarded to ``AutoExecutor.update_parameters``."""
491492 fields = set (type (self ).model_fields ) - set (Backend .model_fields )
492- params = {k : getattr (self , k ) for k in fields if getattr (self , k ) is not None }
493+ skip = {"max_jobs" , "min_items_per_job" }
494+ params = {
495+ k : getattr (self , k ) for k in fields - skip if getattr (self , k ) is not None
496+ }
493497 if "job_name" in params :
494498 params ["name" ] = params .pop ("job_name" )
495499 return params
496500
497- def _submit (self , wrapper : _CachingCall , * args : tp .Any , paths : StepPaths ) -> tp .Any :
501+ def _execute (
502+ self ,
503+ wrapper : _CachingCall ,
504+ pending : items .StepItems ,
505+ * ,
506+ paths : StepPaths ,
507+ reg : inflight .InflightRegistry | None ,
508+ ) -> None :
509+ uids = list (pending .uids )
510+ random .shuffle (uids )
511+ chunks = [
512+ pending .select (c )
513+ for c in utils .to_chunks (
514+ uids , max_chunks = self .max_jobs , min_items_per_chunk = self .min_items_per_job
515+ )
516+ ]
498517 executor = submitit .AutoExecutor (folder = paths ._logs_folder , cluster = self ._CLUSTER )
499- executor .update_parameters (** self ._submitit_params ())
500- with submitit .helpers .clean_env ():
501- job = executor .submit (wrapper , * args )
502- return job
518+ params = self ._submitit_params ()
519+ if self ._CLUSTER in ("slurm" , None ):
520+ params ["slurm_array_parallelism" ] = len (chunks )
521+ executor .update_parameters (** params )
522+ with submitit .helpers .clean_env (), executor .batch ():
523+ jobs = [executor .submit (wrapper , c ) for c in chunks ]
524+ if reg is not None :
525+ for c , j in zip (chunks , jobs ):
526+ inflight .record_worker_info (reg , c .uids , j )
527+ for j in jobs :
528+ j .result ()
503529
504530
505531class LocalProcess (_SubmititBackend ):
@@ -512,7 +538,7 @@ class SubmititDebug(_SubmititBackend):
512538 """Debug executor (inline but simulates submitit)."""
513539
514540 _CLUSTER : tp .ClassVar [str | None ] = "debug"
515- _is_off_process : tp .ClassVar [bool ] = False
541+ _concurrent : tp .ClassVar [bool ] = False
516542
517543
518544class Slurm (_SubmititBackend ):
@@ -542,3 +568,54 @@ class Auto(Slurm):
542568 """Auto-detect executor (local or Slurm). Slurm fields only apply on slurm."""
543569
544570 _CLUSTER : tp .ClassVar [str | None ] = None
571+
572+
class _PoolBackend(Backend):
    """Base for concurrent.futures pool backends.

    Subclasses pin ``_POOL_TYPE`` (the key passed to
    ``utils.make_pool_executor``) and set ``_concurrent`` accordingly.
    """

    # Upper bound on pool workers; None means no explicit cap (still
    # bounded by CPU count and by the number of pending items).
    max_jobs: int | None = 128
    _POOL_TYPE: tp.ClassVar[str]

    def _execute(
        self,
        wrapper: _CachingCall,
        pending: items.StepItems,
        *,
        paths: StepPaths,
        reg: inflight.InflightRegistry | None,
    ) -> None:
        """Run *wrapper* over chunks of *pending* in a local worker pool.

        Blocks until all chunks complete; on the first failure, cancels
        any not-yet-started chunks and re-raises.
        """
        uids = list(pending.uids)
        if not uids:  # empty batch: nothing to submit, avoid a 0-worker pool
            return
        random.shuffle(uids)
        # leave one cpu for the main process, but always keep at least one worker
        cpus = max(1, (os.cpu_count() or 1) - 1)
        max_workers = min(len(uids), cpus)
        if self.max_jobs is not None:
            max_workers = min(max_workers, self.max_jobs)
        # ~3 chunks per worker gives the pool some load-balancing slack
        chunks = [
            pending.select(c) for c in utils.to_chunks(uids, max_chunks=3 * max_workers)
        ]
        if reg is not None:
            # mark uids as in flight before submitting (no job handle for pools)
            for c in chunks:
                inflight.record_worker_info(reg, c.uids)
        with utils.make_pool_executor(self._POOL_TYPE, max_workers) as pool:
            futs = [pool.submit(wrapper, c) for c in chunks]
            try:
                for f in futures.as_completed(futs):
                    f.result()
            except BaseException:
                # best effort: only not-yet-running futures can be cancelled;
                # already-running chunks finish before the pool shuts down
                for f in futs:
                    f.cancel()
                raise
608+
609+
class ProcessPool(_PoolBackend):
    """Process pool execution + caching."""

    # key passed to utils.make_pool_executor to select a process-based pool
    _POOL_TYPE: tp.ClassVar[str] = "processpool"
    # workers run off-process, so Backend._run activates the inflight registry
    _concurrent: tp.ClassVar[bool] = True
615+
616+
class ThreadPool(_PoolBackend):
    """Thread pool execution + caching."""

    # key passed to utils.make_pool_executor to select a thread-based pool
    _POOL_TYPE: tp.ClassVar[str] = "threadpool"
    # concurrent workers (even in-process threads) need the inflight registry
    _concurrent: tp.ClassVar[bool] = True
0 commit comments