Skip to content

Commit d4377b3

Browse files
Add should_run functionality (#123)
* add should_run functionality * added a few basic tests * improve naming, should_run message * remove _default_should_run import in tests * centralise should_run * adding skip arguments feature * ensured test accomodates tournament graphs * added a policy for using certain number of cores * modify default_should_run policy * modify should_run tests * update docstring * change func name * handling exports * minor edit * fix naming; imports * concise imports * revert prev due to circular import * remove custom Co-authored-by: Dan Schult <dschult@colgate.edu> --------- Co-authored-by: Dan Schult <dschult@colgate.edu>
1 parent 8914899 commit d4377b3

File tree

7 files changed

+131
-3
lines changed

7 files changed

+131
-3
lines changed

nx_parallel/algorithms/isolate.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
__all__ = ["number_of_isolates"]
66

77

8-
@nxp._configure_if_nx_active()
8+
@nxp._configure_if_nx_active(should_run=nxp.should_skip_parallel)
99
def number_of_isolates(G, get_chunks="chunks"):
1010
"""The parallel computation is implemented by dividing the list
1111
of isolated nodes into chunks and then finding the length of each chunk in parallel

nx_parallel/algorithms/shortest_paths/unweighted.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
]
1616

1717

18-
@nxp._configure_if_nx_active()
18+
@nxp._configure_if_nx_active(should_run=nxp.should_skip_parallel)
1919
def all_pairs_shortest_path_length(G, cutoff=None, get_chunks="chunks"):
2020
"""The parallel implementation first divides the nodes into chunks and then
2121
creates a generator to lazily compute shortest paths lengths for each node in

nx_parallel/interface.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,30 @@ def convert_to_nx(result, *, name=None):
8686
return result.graph_object
8787
return result
8888

89+
@classmethod
90+
def should_run(cls, name, args, kwargs):
91+
"""Determine whether this backend should run the specified algorithm
92+
with the given arguments.
93+
94+
Parameters
95+
----------
96+
cls : type
97+
`BackendInterface` class
98+
name : str
99+
Name of the target algorithm
100+
args : tuple
101+
Positional arguments passed to the algorithm's `should_run`.
102+
kwargs : dict
103+
Keyword arguments passed to the algorithm's `should_run`.
104+
105+
Returns
106+
-------
107+
bool or str
108+
If the algorithm should run, returns True.
109+
Otherwise, returns a string explaining why parallel execution is skipped.
110+
"""
111+
return getattr(cls, name).should_run(*args, **kwargs)
112+
89113

90114
for attr in ALGORITHMS:
91115
setattr(BackendInterface, attr, getattr(algorithms, attr))
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
import nx_parallel as nxp
2+
from nx_parallel.interface import ALGORITHMS
3+
import networkx as nx
4+
import inspect
5+
import pytest
6+
import os
7+
import joblib
8+
9+
10+
def get_functions_with_should_run():
11+
for name, obj in inspect.getmembers(nxp.algorithms, inspect.isfunction):
12+
if callable(obj.should_run):
13+
yield name
14+
15+
16+
def test_get_functions_with_should_run():
17+
assert set(get_functions_with_should_run()) == set(ALGORITHMS)
18+
19+
20+
def test_default_should_run():
21+
@nxp._configure_if_nx_active()
22+
def dummy_default():
23+
pass
24+
25+
with pytest.MonkeyPatch().context() as mp:
26+
mp.delitem(os.environ, "PYTEST_CURRENT_TEST", raising=False)
27+
assert (
28+
dummy_default.should_run()
29+
== "Parallel backend requires `n_jobs` > 1 to run"
30+
)
31+
32+
with joblib.parallel_config(n_jobs=4):
33+
assert dummy_default.should_run()
34+
35+
36+
def test_skip_parallel_backend():
37+
@nxp._configure_if_nx_active(should_run=nxp.should_skip_parallel)
38+
def dummy_skip_parallel():
39+
pass
40+
41+
assert dummy_skip_parallel.should_run() == "Fast algorithm; skip parallel execution"
42+
43+
44+
def test_should_run_if_large():
45+
@nxp._configure_if_nx_active(should_run=nxp.should_run_if_large)
46+
def dummy_if_large(G):
47+
pass
48+
49+
smallG = nx.fast_gnp_random_graph(20, 0.6, seed=42)
50+
largeG = nx.fast_gnp_random_graph(250, 0.6, seed=42)
51+
52+
assert dummy_if_large.should_run(smallG) == "Graph too small for parallel execution"
53+
assert dummy_if_large.should_run(largeG)
54+
55+
56+
@pytest.mark.parametrize("func_name", get_functions_with_should_run())
57+
def test_should_run(func_name):
58+
tournament_funcs = [
59+
"tournament_is_strongly_connected",
60+
]
61+
62+
if func_name in tournament_funcs:
63+
G = nx.tournament.random_tournament(15, seed=42)
64+
else:
65+
G = nx.fast_gnp_random_graph(40, 0.6, seed=42)
66+
H = nxp.ParallelGraph(G)
67+
func = getattr(nxp, func_name)
68+
69+
result = func.should_run(H)
70+
if not isinstance(result, (bool, str)):
71+
raise AssertionError(
72+
f"{func.__name__}.should_run has an invalid return type; {type(result)}"
73+
)

nx_parallel/utils/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
from .chunk import *
22
from .decorators import *
3+
from .should_run_policies import *

nx_parallel/utils/decorators.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,13 @@
33
from functools import wraps
44
import networkx as nx
55
from joblib import parallel_config
6+
from nx_parallel.utils.should_run_policies import default_should_run
67

78

89
__all__ = ["_configure_if_nx_active"]
910

1011

11-
def _configure_if_nx_active():
12+
def _configure_if_nx_active(should_run=None):
1213
"""Decorator to set the configuration for the parallel computation
1314
of the nx-parallel algorithms.
1415
"""
@@ -29,6 +30,10 @@ def wrapper(*args, **kwargs):
2930
return func(*args, **kwargs)
3031
return func(*args, **kwargs)
3132

33+
wrapper.should_run = default_should_run
34+
if should_run:
35+
wrapper.should_run = should_run
36+
3237
return wrapper
3338

3439
return decorator
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
import nx_parallel as nxp
2+
3+
4+
__all__ = [
5+
"default_should_run",
6+
"should_skip_parallel",
7+
"should_run_if_large",
8+
]
9+
10+
11+
def should_skip_parallel(*_):
12+
return "Fast algorithm; skip parallel execution"
13+
14+
15+
def should_run_if_large(G, *_):
16+
if len(G) <= 200:
17+
return "Graph too small for parallel execution"
18+
return True
19+
20+
21+
def default_should_run(*_):
22+
n_jobs = nxp.get_n_jobs()
23+
if n_jobs in (None, 0, 1):
24+
return "Parallel backend requires `n_jobs` > 1 to run"
25+
return True

0 commit comments

Comments
 (0)