forked from networkx/nx-parallel
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtest_get_chunks.py
More file actions
86 lines (74 loc) · 2.93 KB
/
test_get_chunks.py
File metadata and controls
86 lines (74 loc) · 2.93 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
# smoke tests for all functions supporting `get_chunks` kwarg
import importlib
import inspect
import math
import random
import types
import networkx as nx
import nx_parallel as nxp
def get_all_functions(package_name="nx_parallel"):
    """Return the public functions of a package together with their parameters.

    Parameters
    ----------
    package_name : str, optional
        Importable name of the package to inspect (default ``"nx_parallel"``).

    Returns
    -------
    dict
        Maps each public (non-underscore) function found directly on the
        package to ``{"args": [...], "kwargs": [...]}``, where ``"args"`` lists
        the positional-or-keyword parameter names and ``"kwargs"`` lists the
        keyword-only parameter names.
    """
    package = importlib.import_module(package_name)
    functions = {}
    for name, obj in inspect.getmembers(package, inspect.isfunction):
        if name.startswith("_"):
            continue
        # FullArgSpec is (args, varargs, varkw, defaults, kwonlyargs, ...);
        # slicing ``[:2]`` would yield the *varargs name* instead of the
        # keyword arguments, so pick the fields by name.
        spec = inspect.getfullargspec(obj)
        functions[name] = {"args": spec.args, "kwargs": spec.kwonlyargs}
    return functions
def get_functions_with_get_chunks():
    """Return the names of the package functions that accept ``get_chunks``.

    Only the parameter names recorded under the ``"args"`` key of
    ``get_all_functions()`` are searched; names are returned in the order
    ``get_all_functions()`` yields them.
    """
    return [
        name
        for name, params in get_all_functions().items()
        if "get_chunks" in params["args"]
    ]
def test_get_chunks():
    """Smoke-test every nx_parallel function that accepts ``get_chunks``.

    Each function is called twice on the same graph: once with its default
    ``get_chunks`` and once with ``random_chunking``. The two results must
    agree. Float results (and per-node values of functions listed in
    ``chk_dict_vals``) are compared with ``math.isclose``; everything else is
    compared with ``==``.
    """

    def random_chunking(nodes):
        """Shuffle ``nodes`` reproducibly and split them into chunks."""
        _nodes = list(nodes)
        # A private RNG gives the same shuffle as ``random.seed(42)`` followed
        # by ``random.shuffle`` without reseeding the global ``random`` module
        # for every other test in the session.
        random.Random(42).shuffle(_nodes)
        num_chunks = nxp.cpu_count()
        num_in_chunk = max(len(_nodes) // num_chunks, 1)
        return nxp.chunks(_nodes, num_in_chunk)

    get_chunks_funcs = get_functions_with_get_chunks()
    ignore_funcs = {
        "number_of_isolates",
        "is_reachable",
    }
    tournament_funcs = {
        "tournament_is_strongly_connected",
    }
    # Functions whose per-node values are floats and need a tolerant compare.
    chk_dict_vals = {
        "betweenness_centrality",
    }

    def assert_results_match(func, c1, c2, nodes):
        """Assert that the two results of ``func`` agree."""
        if func in chk_dict_vals:
            # Presumably the chunked computation accumulates partial sums in
            # a different order, so values may differ in the last bits.
            for node in nodes:
                assert math.isclose(c1[node], c2[node], abs_tol=1e-16)
        elif isinstance(c1, float):
            assert math.isclose(c1, c2, abs_tol=1e-16)
        else:
            assert c1 == c2

    G = nx.fast_gnp_random_graph(50, 0.6, seed=42)
    H = nxp.ParallelGraph(G)
    # Tournament-only functions get their own graph under separate names, so
    # the shared random graph is not replaced for every function tested
    # after them.
    T = nx.tournament.random_tournament(50, seed=42)
    HT = nxp.ParallelGraph(T)

    for func in get_chunks_funcs:
        print(func)
        if func in ignore_funcs:
            continue
        if func in tournament_funcs:
            c1 = getattr(nxp, func)(HT)
            c2 = getattr(nxp, func)(HT, get_chunks=random_chunking)
            assert c1 == c2
            continue
        c1 = getattr(nxp, func)(H)
        c2 = getattr(nxp, func)(H, get_chunks=random_chunking)
        if isinstance(c1, types.GeneratorType):
            # Materialize lazy (node, value) results so they can be compared.
            c1, c2 = dict(c1), dict(c2)
        assert_results_match(func, c1, c2, G.nodes)