Skip to content

Commit b5d94e0

Browse files
authored
nxadb_to_nx cleanup (#32)
* attempt: nxadb_to_nx cleanup
* checkpoint
* bring back other algorithms
* passing, but certain assertions are commented out; need to revisit failing assertions ASAP
* attempt cleanup: nx overrides
* cleanup: symmetrize_edges_if_directed
* cleanup: `test_algorithm` assertions
* fix: symmetrize edges
* fix: symmetrize edges
1 parent 32555a3 commit b5d94e0

File tree

7 files changed: 140 additions and 82 deletions.

nx_arangodb/algorithms/shortest_paths/generic.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
extra_params=_dtype_param, version_added="24.04", _plc={"bfs", "sssp"}
1515
)
1616
def shortest_path(
17-
G: nxadb.Graph | nxadb.DiGraph,
17+
G: nxadb.Graph,
1818
source=None,
1919
target=None,
2020
weight=None,

nx_arangodb/classes/dict/adj.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1607,7 +1607,7 @@ def propagate_edge_directed_symmetric(
16071607

16081608
propagate_edge_func = (
16091609
propagate_edge_directed_symmetric
1610-
if self.is_directed and self.symmetrize_edges_if_directed
1610+
if self.symmetrize_edges_if_directed
16111611
else (
16121612
propagate_edge_directed
16131613
if self.is_directed
@@ -1663,8 +1663,7 @@ def _fetch_all(self) -> None:
16631663
load_all_edge_attributes=True,
16641664
is_directed=self.is_directed,
16651665
is_multigraph=self.is_multigraph,
1666-
symmetrize_edges_if_directed=self.is_directed
1667-
and self.symmetrize_edges_if_directed,
1666+
symmetrize_edges_if_directed=self.symmetrize_edges_if_directed,
16681667
)
16691668

16701669
if self.is_directed:

nx_arangodb/classes/graph.py

Lines changed: 20 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -441,27 +441,28 @@ def add_node(self, node_for_adding, **attr):
441441
nx._clear_cache(self)
442442

443443
def number_of_edges(self, u=None, v=None):
444-
if u is None:
445-
######################
446-
# NOTE: monkey patch #
447-
######################
444+
if not self.graph_exists_in_db:
445+
return super().number_of_edges(u, v)
448446

449-
# Old:
450-
# return int(self.size())
447+
if u is not None:
448+
return super().number_of_edges(u, v)
451449

452-
# New:
453-
edge_collections = {
454-
e_d["edge_collection"] for e_d in self.adb_graph.edge_definitions()
455-
}
456-
num = sum(
457-
self.adb_graph.edge_collection(e).count() for e in edge_collections
458-
)
459-
num *= 2 if self.is_directed() and self.symmetrize_edges else 1
450+
######################
451+
# NOTE: monkey patch #
452+
######################
460453

461-
return num
454+
# Old:
455+
# return int(self.size())
462456

463-
# Reason:
464-
# It is more efficient to count the number of edges in the edge collections
465-
# compared to relying on the DegreeView.
457+
# New:
458+
edge_collections = {
459+
e_d["edge_collection"] for e_d in self.adb_graph.edge_definitions()
460+
}
461+
num = sum(self.adb_graph.edge_collection(e).count() for e in edge_collections)
462+
num *= 2 if self.is_directed() and self.symmetrize_edges else 1
463+
464+
return num
466465

467-
super().number_of_edges(u, v)
466+
# Reason:
467+
# It is more efficient to count the number of edges in the edge collections
468+
# compared to relying on the DegreeView.

nx_arangodb/classes/multigraph.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -50,9 +50,6 @@ def __init__(
5050
**kwargs,
5151
)
5252

53-
if self._graph_exists_in_db:
54-
self.add_edge = self.add_edge_keyless
55-
5653
#######################
5754
# Init helper methods #
5855
#######################
@@ -71,7 +68,10 @@ def _set_factory_methods(self) -> None:
7168
# nx.MultiGraph Overrides #
7269
##########################
7370

74-
def add_edge_keyless(self, u_for_edge, v_for_edge, key=None, **attr):
71+
def add_edge(self, u_for_edge, v_for_edge, key=None, **attr):
72+
if not self.graph_exists_in_db:
73+
return super().add_edge(u_for_edge, v_for_edge, key=key, **attr)
74+
7575
if key is not None:
7676
m = "ArangoDB MultiGraph does not support custom edge keys yet."
7777
logger.warning(m)

nx_arangodb/convert.py

Lines changed: 23 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
import networkx as nx
77

88
import nx_arangodb as nxadb
9+
from nx_arangodb.classes.dict.adj import AdjListOuterDict
10+
from nx_arangodb.classes.dict.node import NodeDict
911
from nx_arangodb.classes.function import do_load_all_edge_attributes
1012
from nx_arangodb.logger import logger
1113

@@ -30,7 +32,7 @@
3032
def _to_nx_graph(G: Any, *args: Any, **kwargs: Any) -> nx.Graph:
3133
logger.debug(f"_to_nx_graph for {G.__class__.__name__}")
3234

33-
if isinstance(G, nxadb.Graph | nxadb.DiGraph):
35+
if isinstance(G, nxadb.Graph):
3436
return nxadb_to_nx(G)
3537

3638
if isinstance(G, nx.Graph):
@@ -109,16 +111,14 @@ def nx_to_nxadb(
109111

110112
def nxadb_to_nx(G: nxadb.Graph) -> nx.Graph:
111113
if not G.graph_exists_in_db:
112-
logger.debug("graph does not exist, nothing to pull")
113-
# TODO: Consider just returning G here?
114-
# Avoids the need to re-create the graph from scratch
115-
return G.to_networkx_class()(incoming_graph_data=G)
114+
# Since nxadb.Graph is a subclass of nx.Graph, we can return it as is.
115+
# This only applies if the graph does not exist in the database.
116+
return G
116117

117-
# TODO: Re-enable this
118-
# if G.use_nx_cache and G._node and G._adj:
119-
# m = "**use_nx_cache** is enabled. using cached data. no pull required."
120-
# logger.debug(m)
121-
# return G
118+
assert isinstance(G._node, NodeDict)
119+
assert isinstance(G._adj, AdjListOuterDict)
120+
if G._node.FETCHED_ALL_DATA and G._adj.FETCHED_ALL_DATA:
121+
return G
122122

123123
start_time = time.time()
124124

@@ -137,11 +137,22 @@ def nxadb_to_nx(G: nxadb.Graph) -> nx.Graph:
137137

138138
print(f"Graph '{G.adb_graph.name}' load took {time.time() - start_time}s")
139139

140-
G_NX: nx.Graph | nx.DiGraph = G.to_networkx_class()()
140+
# NOTE: At this point, we _could_ choose to implement something similar to
141+
# NodeDict._fetch_all() and AdjListOuterDict._fetch_all() to iterate through
142+
# **node_dict** and **adj_dict**, and establish the "custom" Dictionary classes
143+
# that we've implemented in nx_arangodb.classes.dict.
144+
# However, this would involve adding additional for-loops and would likely be
145+
# slower than the current implementation.
146+
# Perhaps we should consider adding a feature flag to allow users to choose
147+
# between the two methods? e.g. `build_remote_dicts=True/False`
148+
# If True, then we would return the (updated) nxadb.Graph that was passed in.
149+
# If False, then we would return the nx.Graph that is built below:
150+
151+
G_NX: nx.Graph = G.to_networkx_class()()
141152
G_NX._node = node_dict
142153

143154
if isinstance(G_NX, nx.DiGraph):
144-
G_NX._succ = G._adj = adj_dict["succ"]
155+
G_NX._succ = G_NX._adj = adj_dict["succ"]
145156
G_NX._pred = adj_dict["pred"]
146157

147158
else:

tests/conftest.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from arango import ArangoClient
99
from arango.database import StandardDatabase
1010

11+
import nx_arangodb as nxadb
1112
from nx_arangodb.logger import logger
1213

1314
logger.setLevel(logging.INFO)
@@ -85,3 +86,17 @@ def load_two_relation_graph() -> None:
8586
g.create_edge_definition(
8687
e2, from_vertex_collections=[v2], to_vertex_collections=[v1]
8788
)
89+
90+
91+
def create_line_graph(load_attributes: set[str]) -> nxadb.Graph:
92+
G = nx.Graph()
93+
G.add_edge(1, 2, my_custom_weight=1)
94+
G.add_edge(2, 3, my_custom_weight=1)
95+
G.add_edge(3, 4, my_custom_weight=1000)
96+
G.add_edge(4, 5, my_custom_weight=1000)
97+
98+
return nxadb.Graph(
99+
incoming_graph_data=G,
100+
graph_name="LineGraph",
101+
edge_collections_attributes=load_attributes,
102+
)

tests/test.py

Lines changed: 75 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from typing import Any, Callable, Dict, Union
22

33
import networkx as nx
4+
import phenolrs
45
import pytest
56
from arango import DocumentDeleteError
67
from phenolrs.networkx.typings import (
@@ -11,33 +12,21 @@
1112
)
1213

1314
import nx_arangodb as nxadb
14-
from nx_arangodb.classes.dict.adj import EdgeAttrDict
15-
from nx_arangodb.classes.dict.node import NodeAttrDict
15+
from nx_arangodb.classes.dict.adj import AdjListOuterDict, EdgeAttrDict
16+
from nx_arangodb.classes.dict.node import NodeAttrDict, NodeDict
1617

17-
from .conftest import db
18+
from .conftest import create_line_graph, db
1819

1920
G_NX = nx.karate_club_graph()
2021

2122

22-
def extract_arangodb_key(adb_id: str) -> str:
23-
return adb_id.split("/")[1]
23+
def assert_remote_dict(G: nxadb.Graph) -> None:
24+
assert isinstance(G._node, NodeDict)
25+
assert isinstance(G._adj, AdjListOuterDict)
2426

2527

26-
def create_line_graph(load_attributes: set[str]) -> nxadb.Graph:
27-
G = nx.Graph()
28-
G.add_edge(1, 2, my_custom_weight=1)
29-
G.add_edge(2, 3, my_custom_weight=1)
30-
G.add_edge(3, 4, my_custom_weight=1000)
31-
G.add_edge(4, 5, my_custom_weight=1000)
32-
33-
if load_attributes:
34-
return nxadb.Graph(
35-
incoming_graph_data=G,
36-
graph_name="LineGraph",
37-
edge_collections_attributes=load_attributes,
38-
)
39-
40-
return nxadb.Graph(incoming_graph_data=G, graph_name="LineGraph")
28+
def extract_arangodb_key(adb_id: str) -> str:
29+
return adb_id.split("/")[1]
4130

4231

4332
def assert_same_dict_values(
@@ -172,18 +161,35 @@ def test_load_graph_with_non_default_weight_attribute():
172161

173162

174163
@pytest.mark.parametrize(
175-
"algorithm_func, assert_func",
164+
"algorithm_func, assert_func, affected_by_symmetry",
176165
[
177-
(nx.betweenness_centrality, assert_bc),
178-
(nx.pagerank, assert_pagerank),
179-
(nx.community.louvain_communities, assert_louvain),
166+
(nx.betweenness_centrality, assert_bc, True),
167+
(nx.pagerank, assert_pagerank, False),
180168
],
181169
)
182170
def test_algorithm(
183171
algorithm_func: Callable[..., Any],
184172
assert_func: Callable[..., Any],
173+
affected_by_symmetry: bool,
185174
load_karate_graph: Any,
186175
) -> None:
176+
def assert_func_should_fail(
177+
r1: dict[str | int, float], r2: dict[str | int, float]
178+
) -> None:
179+
assert r1 != r2
180+
assert len(r1) == len(r2)
181+
with pytest.raises(AssertionError):
182+
assert_func(r1, r2)
183+
184+
def assert_symmetry_differences(
185+
r1: dict[str | int, float], r2: dict[str | int, float], should_be_equal: bool
186+
) -> None:
187+
if should_be_equal:
188+
assert_func(r1, r2)
189+
return
190+
191+
assert_func_should_fail(r1, r2)
192+
187193
G_1 = G_NX
188194
G_2 = nxadb.Graph(incoming_graph_data=G_1)
189195
G_3 = nxadb.Graph(graph_name="KarateGraph")
@@ -193,6 +199,9 @@ def test_algorithm(
193199
G_7 = nxadb.MultiDiGraph(graph_name="KarateGraph", symmetrize_edges=True)
194200
G_8 = nxadb.MultiDiGraph(graph_name="KarateGraph", symmetrize_edges=False)
195201

202+
for G in [G_3, G_4, G_5, G_6, G_7, G_8]:
203+
assert_remote_dict(G)
204+
196205
r_1 = algorithm_func(G_1)
197206
r_2 = algorithm_func(G_2)
198207
r_3 = algorithm_func(G_1, backend="arangodb")
@@ -203,51 +212,74 @@ def test_algorithm(
203212
assert_func(r_2, r_3)
204213
assert_func(r_3, r_4)
205214

206-
try:
207-
import phenolrs # noqa
208-
except ModuleNotFoundError:
209-
pytest.skip("phenolrs not installed")
210-
211215
r_7 = algorithm_func(G_3)
216+
assert_remote_dict(G_3)
212217
r_7_orig = algorithm_func.orig_func(G_3) # type: ignore
218+
assert_remote_dict(G_3)
213219

214220
r_8 = algorithm_func(G_4)
221+
assert_remote_dict(G_4)
215222
r_8_orig = algorithm_func.orig_func(G_4) # type: ignore
223+
assert_remote_dict(G_4)
216224

217225
r_9 = algorithm_func(G_5)
226+
assert_remote_dict(G_5)
218227
r_9_orig = algorithm_func.orig_func(G_5) # type: ignore
228+
assert_remote_dict(G_5)
219229

220230
r_10 = algorithm_func(nx.DiGraph(incoming_graph_data=G_NX))
221231

222232
r_11 = algorithm_func(G_6)
233+
assert_remote_dict(G_6)
223234
r_11_orig = algorithm_func.orig_func(G_6) # type: ignore
235+
assert_remote_dict(G_6)
224236

225237
r_12 = algorithm_func(G_7)
238+
assert_remote_dict(G_7)
226239
r_12_orig = algorithm_func.orig_func(G_7) # type: ignore
240+
assert_remote_dict(G_7)
227241

228242
r_13 = algorithm_func(G_8)
243+
assert_remote_dict(G_8)
229244
r_13_orig = algorithm_func.orig_func(G_8) # type: ignore
245+
assert_remote_dict(G_8)
230246

231247
assert_func(r_7, r_7_orig)
232-
assert_func(r_8, r_8_orig)
233-
assert_func(r_9, r_9_orig)
234248
assert_func(r_7, r_1)
235249
assert_func(r_7, r_8)
236-
assert r_8 != r_9
237-
assert r_8_orig != r_9_orig
250+
251+
assert_symmetry_differences(r_8, r_8_orig, should_be_equal=not affected_by_symmetry)
252+
assert_func_should_fail(r_8, r_9)
253+
254+
assert_func(r_9, r_9_orig)
255+
assert_symmetry_differences(
256+
r_8_orig, r_9_orig, should_be_equal=affected_by_symmetry
257+
)
258+
238259
assert_func(r_8, r_10)
239-
assert_func(r_8_orig, r_10)
240-
assert_func(r_7, r_11)
241-
assert_func(r_8, r_11)
260+
assert_symmetry_differences(
261+
r_8_orig, r_10, should_be_equal=not affected_by_symmetry
262+
)
263+
264+
assert_func(r_11, r_7)
265+
assert_func(r_11, r_7)
242266
assert_func(r_11, r_11_orig)
243-
assert_func(r_12, r_12_orig)
267+
268+
assert_symmetry_differences(
269+
r_12, r_12_orig, should_be_equal=not affected_by_symmetry
270+
)
271+
assert_func_should_fail(r_12, r_13)
272+
244273
assert_func(r_13, r_13_orig)
245-
assert r_12 != r_13
246-
assert r_12_orig != r_13_orig
247-
assert_func(r_8, r_12)
248-
assert_func(r_8_orig, r_12_orig)
249-
assert_func(r_9, r_13)
250-
assert_func(r_9_orig, r_13_orig)
274+
assert_symmetry_differences(
275+
r_12_orig, r_13_orig, should_be_equal=affected_by_symmetry
276+
)
277+
278+
assert_func(r_12, r_8)
279+
assert_func(r_12_orig, r_8_orig)
280+
281+
assert_func(r_13, r_9)
282+
assert_func(r_13_orig, r_9_orig)
251283

252284

253285
def test_shortest_path_remote_algorithm(load_karate_graph: Any) -> None:

0 commit comments

Comments
 (0)