Skip to content

update: test_gpu #48

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Aug 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions nx_arangodb/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ def nxadb_to_nx(G: nxadb.Graph) -> nx.Graph:
symmetrize_edges_if_directed=G.symmetrize_edges if G.is_directed() else False,
)

print(f"Graph '{G.adb_graph.name}' load took {time.time() - start_time}s")
logger.info(f"Graph '{G.adb_graph.name}' load took {time.time() - start_time}s")

# NOTE: At this point, we _could_ choose to implement something similar to
# NodeDict._fetch_all() and AdjListOuterDict._fetch_all() to iterate through
Expand Down Expand Up @@ -195,7 +195,7 @@ def nxadb_to_nxcg(G: nxadb.Graph, as_directed: bool = False) -> nxcg.Graph:
),
)

print(f"ADB Graph '{G.adb_graph.name}' load took {time.time() - start_time}s")
logger.info(f"Graph '{G.adb_graph.name}' load took {time.time() - start_time}s")

start_time = time.time()

Expand Down Expand Up @@ -240,6 +240,6 @@ def nxadb_to_nxcg(G: nxadb.Graph, as_directed: bool = False) -> nxcg.Graph:
key_to_id=vertex_ids_to_index,
)

print(f"NXCG Graph construction took {time.time() - start_time}s")
logger.info(f"NXCG Graph construction took {time.time() - start_time}s")

return G.nxcg_graph
14 changes: 0 additions & 14 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,17 +120,3 @@ def create_grid_graph(graph_cls: type[nxadb.Graph]) -> nxadb.Graph:
return graph_cls(
incoming_graph_data=grid_graph, name="GridGraph", write_async=False
)


# Taken from:
# https://stackoverflow.com/questions/16571150/how-to-capture-stdout-output-from-a-python-function-call
class Capturing(list[str]):
def __enter__(self):
self._stdout = sys.stdout
sys.stdout = self._stringio = StringIO()
return self

def __exit__(self, *args):
self.extend(self._stringio.getvalue().splitlines())
del self._stringio # free up some memory
sys.stdout = self._stdout
61 changes: 38 additions & 23 deletions tests/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from nx_arangodb.classes.dict.adj import AdjListOuterDict, EdgeAttrDict, EdgeKeyDict
from nx_arangodb.classes.dict.node import NodeAttrDict, NodeDict

from .conftest import Capturing, create_grid_graph, create_line_graph, db, run_gpu_tests
from .conftest import create_grid_graph, create_line_graph, db, run_gpu_tests

G_NX = nx.karate_club_graph()
G_NX_digraph = nx.DiGraph(G_NX)
Expand Down Expand Up @@ -344,38 +344,53 @@ def test_gpu_pagerank(graph_cls: type[nxadb.Graph]) -> None:

assert nxadb.convert.GPU_AVAILABLE is True
assert nx.config.backends.arangodb.use_gpu is True
assert graph.nxcg_graph is None

res_gpu = None
res_cpu = None

# Measure GPU execution time
# 1. GPU
start_gpu = time.time()
res_gpu = nx.pagerank(graph)
gpu_time = time.time() - start_gpu

# Note: While this works, we should use the logger or some alternative
# approach testing this. Via stdout is not the best way to test this.
with Capturing() as output_gpu:
res_gpu = nx.pagerank(graph)
assert graph.nxcg_graph is not None
assert graph.nxcg_graph.number_of_nodes() == 250000
assert graph.nxcg_graph.number_of_edges() == 499000

assert any(
"NXCG Graph construction took" in line for line in output_gpu
), "Expected output not found in GPU execution"
# 2. GPU (cached)
assert graph.use_nxcg_cache is True

gpu_time = time.time() - start_gpu
start_gpu_cached = time.time()
res_gpu_cached = nx.pagerank(graph)
gpu_cached_time = time.time() - start_gpu_cached

# Disable GPU and measure CPU execution time
nx.config.backends.arangodb.use_gpu = False
start_cpu = time.time()
with Capturing() as output_cpu:
res_cpu = nx.pagerank(graph)
assert gpu_cached_time < gpu_time
assert_pagerank(res_gpu, res_gpu_cached, 10)

# 3. GPU (disable cache)
graph.use_nxcg_cache = False

output_cpu_list = list(output_cpu)
assert len(output_cpu_list) == 1
assert "Graph 'GridGraph' load took" in output_cpu_list[0]
start_gpu_no_cache = time.time()
res_gpu_no_cache = nx.pagerank(graph)
gpu_no_cache_time = time.time() - start_gpu_no_cache

assert gpu_cached_time < gpu_no_cache_time
assert_pagerank(res_gpu_cached, res_gpu_no_cache, 10)

# 4. CPU
assert graph.nxcg_graph is not None
graph.clear_nxcg_cache()
assert graph.nxcg_graph is None
nx.config.backends.arangodb.use_gpu = False

start_cpu = time.time()
res_cpu = nx.pagerank(graph)
cpu_time = time.time() - start_cpu

assert gpu_time < cpu_time, "GPU execution should be faster than CPU execution"
assert_pagerank(res_gpu, res_cpu, 10)
assert graph.nxcg_graph is None

m = "GPU execution should be faster than CPU execution"
assert gpu_time < cpu_time, m
assert gpu_no_cache_time < cpu_time, m
assert_pagerank(res_gpu_no_cache, res_cpu, 10)


@pytest.mark.parametrize(
Expand Down