|
15 | 15 | from nx_arangodb.classes.dict.adj import AdjListOuterDict, EdgeAttrDict, EdgeKeyDict
|
16 | 16 | from nx_arangodb.classes.dict.node import NodeAttrDict, NodeDict
|
17 | 17 |
|
18 |
| -from .conftest import Capturing, create_grid_graph, create_line_graph, db, run_gpu_tests |
| 18 | +from .conftest import create_grid_graph, create_line_graph, db, run_gpu_tests |
19 | 19 |
|
20 | 20 | G_NX = nx.karate_club_graph()
|
21 | 21 | G_NX_digraph = nx.DiGraph(G_NX)
|
@@ -344,38 +344,53 @@ def test_gpu_pagerank(graph_cls: type[nxadb.Graph]) -> None:
|
344 | 344 |
|
345 | 345 | assert nxadb.convert.GPU_AVAILABLE is True
|
346 | 346 | assert nx.config.backends.arangodb.use_gpu is True
|
| 347 | + assert graph.nxcg_graph is None |
347 | 348 |
|
348 |
| - res_gpu = None |
349 |
| - res_cpu = None |
350 |
| - |
351 |
| - # Measure GPU execution time |
| 349 | + # 1. GPU |
352 | 350 | start_gpu = time.time()
|
| 351 | + res_gpu = nx.pagerank(graph) |
| 352 | + gpu_time = time.time() - start_gpu |
353 | 353 |
|
354 |
| - # Note: While this works, we should use the logger or some alternative |
355 |
| - # approach testing this. Via stdout is not the best way to test this. |
356 |
| - with Capturing() as output_gpu: |
357 |
| - res_gpu = nx.pagerank(graph) |
| 354 | + assert graph.nxcg_graph is not None |
| 355 | + assert graph.nxcg_graph.number_of_nodes() == 250000 |
| 356 | + assert graph.nxcg_graph.number_of_edges() == 499000 |
358 | 357 |
|
359 |
| - assert any( |
360 |
| - "NXCG Graph construction took" in line for line in output_gpu |
361 |
| - ), "Expected output not found in GPU execution" |
| 358 | + # 2. GPU (cached) |
| 359 | + assert graph.use_nxcg_cache is True |
362 | 360 |
|
363 |
| - gpu_time = time.time() - start_gpu |
| 361 | + start_gpu_cached = time.time() |
| 362 | + res_gpu_cached = nx.pagerank(graph) |
| 363 | + gpu_cached_time = time.time() - start_gpu_cached |
364 | 364 |
|
365 |
| - # Disable GPU and measure CPU execution time |
366 |
| - nx.config.backends.arangodb.use_gpu = False |
367 |
| - start_cpu = time.time() |
368 |
| - with Capturing() as output_cpu: |
369 |
| - res_cpu = nx.pagerank(graph) |
| 365 | + assert gpu_cached_time < gpu_time |
| 366 | + assert_pagerank(res_gpu, res_gpu_cached, 10) |
| 367 | + |
| 368 | + # 3. GPU (disable cache) |
| 369 | + graph.use_nxcg_cache = False |
370 | 370 |
|
371 |
| - output_cpu_list = list(output_cpu) |
372 |
| - assert len(output_cpu_list) == 1 |
373 |
| - assert "Graph 'GridGraph' load took" in output_cpu_list[0] |
| 371 | + start_gpu_no_cache = time.time() |
| 372 | + res_gpu_no_cache = nx.pagerank(graph) |
| 373 | + gpu_no_cache_time = time.time() - start_gpu_no_cache |
374 | 374 |
|
| 375 | + assert gpu_cached_time < gpu_no_cache_time |
| 376 | + assert_pagerank(res_gpu_cached, res_gpu_no_cache, 10) |
| 377 | + |
| 378 | + # 4. CPU |
| 379 | + assert graph.nxcg_graph is not None |
| 380 | + graph.clear_nxcg_cache() |
| 381 | + assert graph.nxcg_graph is None |
| 382 | + nx.config.backends.arangodb.use_gpu = False |
| 383 | + |
| 384 | + start_cpu = time.time() |
| 385 | + res_cpu = nx.pagerank(graph) |
375 | 386 | cpu_time = time.time() - start_cpu
|
376 | 387 |
|
377 |
| - assert gpu_time < cpu_time, "GPU execution should be faster than CPU execution" |
378 |
| - assert_pagerank(res_gpu, res_cpu, 10) |
| 388 | + assert graph.nxcg_graph is None |
| 389 | + |
| 390 | + m = "GPU execution should be faster than CPU execution" |
| 391 | + assert gpu_time < cpu_time, m |
| 392 | + assert gpu_no_cache_time < cpu_time, m |
| 393 | + assert_pagerank(res_gpu_no_cache, res_cpu, 10) |
379 | 394 |
|
380 | 395 |
|
381 | 396 | @pytest.mark.parametrize(
|
|
0 commit comments