Skip to content

Commit cffda39

Browse files
authored
update: test_gpu (#48)
* fix: `logger` instead of `print` * update `test_gpu_pagerank` * temp: remove gpu ci filter * remove: `Capturing` * add asserts * bring back filter * fix: import
1 parent 514f69b commit cffda39

File tree

3 files changed

+41
-40
lines changed

3 files changed

+41
-40
lines changed

nx_arangodb/convert.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ def nxadb_to_nx(G: nxadb.Graph) -> nx.Graph:
135135
symmetrize_edges_if_directed=G.symmetrize_edges if G.is_directed() else False,
136136
)
137137

138-
print(f"Graph '{G.adb_graph.name}' load took {time.time() - start_time}s")
138+
logger.info(f"Graph '{G.adb_graph.name}' load took {time.time() - start_time}s")
139139

140140
# NOTE: At this point, we _could_ choose to implement something similar to
141141
# NodeDict._fetch_all() and AdjListOuterDict._fetch_all() to iterate through
@@ -195,7 +195,7 @@ def nxadb_to_nxcg(G: nxadb.Graph, as_directed: bool = False) -> nxcg.Graph:
195195
),
196196
)
197197

198-
print(f"ADB Graph '{G.adb_graph.name}' load took {time.time() - start_time}s")
198+
logger.info(f"Graph '{G.adb_graph.name}' load took {time.time() - start_time}s")
199199

200200
start_time = time.time()
201201

@@ -240,6 +240,6 @@ def nxadb_to_nxcg(G: nxadb.Graph, as_directed: bool = False) -> nxcg.Graph:
240240
key_to_id=vertex_ids_to_index,
241241
)
242242

243-
print(f"NXCG Graph construction took {time.time() - start_time}s")
243+
logger.info(f"NXCG Graph construction took {time.time() - start_time}s")
244244

245245
return G.nxcg_graph

tests/conftest.py

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -120,17 +120,3 @@ def create_grid_graph(graph_cls: type[nxadb.Graph]) -> nxadb.Graph:
120120
return graph_cls(
121121
incoming_graph_data=grid_graph, name="GridGraph", write_async=False
122122
)
123-
124-
125-
# Taken from:
126-
# https://stackoverflow.com/questions/16571150/how-to-capture-stdout-output-from-a-python-function-call
127-
class Capturing(list[str]):
128-
def __enter__(self):
129-
self._stdout = sys.stdout
130-
sys.stdout = self._stringio = StringIO()
131-
return self
132-
133-
def __exit__(self, *args):
134-
self.extend(self._stringio.getvalue().splitlines())
135-
del self._stringio # free up some memory
136-
sys.stdout = self._stdout

tests/test.py

Lines changed: 38 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from nx_arangodb.classes.dict.adj import AdjListOuterDict, EdgeAttrDict, EdgeKeyDict
1616
from nx_arangodb.classes.dict.node import NodeAttrDict, NodeDict
1717

18-
from .conftest import Capturing, create_grid_graph, create_line_graph, db, run_gpu_tests
18+
from .conftest import create_grid_graph, create_line_graph, db, run_gpu_tests
1919

2020
G_NX = nx.karate_club_graph()
2121
G_NX_digraph = nx.DiGraph(G_NX)
@@ -344,38 +344,53 @@ def test_gpu_pagerank(graph_cls: type[nxadb.Graph]) -> None:
344344

345345
assert nxadb.convert.GPU_AVAILABLE is True
346346
assert nx.config.backends.arangodb.use_gpu is True
347+
assert graph.nxcg_graph is None
347348

348-
res_gpu = None
349-
res_cpu = None
350-
351-
# Measure GPU execution time
349+
# 1. GPU
352350
start_gpu = time.time()
351+
res_gpu = nx.pagerank(graph)
352+
gpu_time = time.time() - start_gpu
353353

354-
# Note: While this works, we should use the logger or some alternative
355-
# approach testing this. Via stdout is not the best way to test this.
356-
with Capturing() as output_gpu:
357-
res_gpu = nx.pagerank(graph)
354+
assert graph.nxcg_graph is not None
355+
assert graph.nxcg_graph.number_of_nodes() == 250000
356+
assert graph.nxcg_graph.number_of_edges() == 499000
358357

359-
assert any(
360-
"NXCG Graph construction took" in line for line in output_gpu
361-
), "Expected output not found in GPU execution"
358+
# 2. GPU (cached)
359+
assert graph.use_nxcg_cache is True
362360

363-
gpu_time = time.time() - start_gpu
361+
start_gpu_cached = time.time()
362+
res_gpu_cached = nx.pagerank(graph)
363+
gpu_cached_time = time.time() - start_gpu_cached
364364

365-
# Disable GPU and measure CPU execution time
366-
nx.config.backends.arangodb.use_gpu = False
367-
start_cpu = time.time()
368-
with Capturing() as output_cpu:
369-
res_cpu = nx.pagerank(graph)
365+
assert gpu_cached_time < gpu_time
366+
assert_pagerank(res_gpu, res_gpu_cached, 10)
367+
368+
# 3. GPU (disable cache)
369+
graph.use_nxcg_cache = False
370370

371-
output_cpu_list = list(output_cpu)
372-
assert len(output_cpu_list) == 1
373-
assert "Graph 'GridGraph' load took" in output_cpu_list[0]
371+
start_gpu_no_cache = time.time()
372+
res_gpu_no_cache = nx.pagerank(graph)
373+
gpu_no_cache_time = time.time() - start_gpu_no_cache
374374

375+
assert gpu_cached_time < gpu_no_cache_time
376+
assert_pagerank(res_gpu_cached, res_gpu_no_cache, 10)
377+
378+
# 4. CPU
379+
assert graph.nxcg_graph is not None
380+
graph.clear_nxcg_cache()
381+
assert graph.nxcg_graph is None
382+
nx.config.backends.arangodb.use_gpu = False
383+
384+
start_cpu = time.time()
385+
res_cpu = nx.pagerank(graph)
375386
cpu_time = time.time() - start_cpu
376387

377-
assert gpu_time < cpu_time, "GPU execution should be faster than CPU execution"
378-
assert_pagerank(res_gpu, res_cpu, 10)
388+
assert graph.nxcg_graph is None
389+
390+
m = "GPU execution should be faster than CPU execution"
391+
assert gpu_time < cpu_time, m
392+
assert gpu_no_cache_time < cpu_time, m
393+
assert_pagerank(res_gpu_no_cache, res_cpu, 10)
379394

380395

381396
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)