Skip to content

Commit 8c1fe98

Browse files
tanmaykmKristofferC
authored andcommitted
make default worker pool an AbstractWorkerPool (#49101)
Changes [Distributed._default_worker_pool](https://github.com/JuliaLang/julia/blob/5f5d2040511b42ba74bd7529a0eac9cf817ad496/stdlib/Distributed/src/workerpool.jl#L242) to hold an `AbstractWorkerPool` instead of `WorkerPool`. With this, alternate implementations can be plugged in as the default pool. Helps in cases where a cluster is always meant to use a certain custom pool. Lower level calls can then work without having to pass a custom pool reference with every call. (cherry picked from commit def2dda)
1 parent ab206cb commit 8c1fe98

File tree

3 files changed

+29
-5
lines changed

3 files changed

+29
-5
lines changed

stdlib/Distributed/src/pmap.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ struct BatchProcessingError <: Exception
66
end
77

88
"""
9-
pgenerate([::WorkerPool], f, c...) -> iterator
9+
pgenerate([::AbstractWorkerPool], f, c...) -> iterator
1010
1111
Apply `f` to each element of `c` in parallel using available workers and tasks.
1212
@@ -18,14 +18,14 @@ Note that `f` must be made available to all worker processes; see
1818
[Code Availability and Loading Packages](@ref code-availability)
1919
for details.
2020
"""
21-
function pgenerate(p::WorkerPool, f, c)
21+
function pgenerate(p::AbstractWorkerPool, f, c)
2222
if length(p) == 0
2323
return AsyncGenerator(f, c; ntasks=()->nworkers(p))
2424
end
2525
batches = batchsplit(c, min_batch_count = length(p) * 3)
2626
return Iterators.flatten(AsyncGenerator(remote(p, b -> asyncmap(f, b)), batches))
2727
end
28-
pgenerate(p::WorkerPool, f, c1, c...) = pgenerate(p, a->f(a...), zip(c1, c...))
28+
pgenerate(p::AbstractWorkerPool, f, c1, c...) = pgenerate(p, a->f(a...), zip(c1, c...))
2929
pgenerate(f, c) = pgenerate(default_worker_pool(), f, c)
3030
pgenerate(f, c1, c...) = pgenerate(a->f(a...), zip(c1, c...))
3131

stdlib/Distributed/src/workerpool.jl

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -239,12 +239,14 @@ perform a `remote_do` on it.
239239
"""
240240
remote_do(f, pool::AbstractWorkerPool, args...; kwargs...) = remotecall_pool(remote_do, f, pool, args...; kwargs...)
241241

242-
const _default_worker_pool = Ref{Union{WorkerPool, Nothing}}(nothing)
242+
const _default_worker_pool = Ref{Union{AbstractWorkerPool, Nothing}}(nothing)
243243

244244
"""
245245
default_worker_pool()
246246
247-
[`WorkerPool`](@ref) containing idle [`workers`](@ref) - used by `remote(f)` and [`pmap`](@ref) (by default).
247+
[`AbstractWorkerPool`](@ref) containing idle [`workers`](@ref) - used by `remote(f)` and [`pmap`](@ref)
248+
(by default). Unless one is explicitly set via `default_worker_pool!(pool)`, the default worker pool is
249+
initialized to a [`WorkerPool`](@ref).
248250
249251
# Examples
250252
```julia-repl
@@ -267,6 +269,15 @@ function default_worker_pool()
267269
return _default_worker_pool[]
268270
end
269271

272+
"""
273+
default_worker_pool!(pool::AbstractWorkerPool)
274+
275+
Set a [`AbstractWorkerPool`](@ref) to be used by `remote(f)` and [`pmap`](@ref) (by default).
276+
"""
277+
function default_worker_pool!(pool::AbstractWorkerPool)
278+
_default_worker_pool[] = pool
279+
end
280+
270281
"""
271282
remote([p::AbstractWorkerPool], f) -> Function
272283

stdlib/Distributed/test/distributed_exec.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -675,6 +675,19 @@ wp = CachingPool(workers())
675675
clear!(wp)
676676
@test length(wp.map_obj2ref) == 0
677677

678+
# default_worker_pool! tests
679+
wp_default = Distributed.default_worker_pool()
680+
try
681+
wp = CachingPool(workers())
682+
Distributed.default_worker_pool!(wp)
683+
@test [1:100...] == pmap(x->x, wp, 1:100)
684+
@test !isempty(wp.map_obj2ref)
685+
clear!(wp)
686+
@test isempty(wp.map_obj2ref)
687+
finally
688+
Distributed.default_worker_pool!(wp_default)
689+
end
690+
678691
# The below block of tests are usually run only on local development systems, since:
679692
# - tests which print errors
680693
# - addprocs tests are memory intensive

0 commit comments

Comments
 (0)