Skip to content

nxadb_to_nx cleanup #32

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Aug 20, 2024
2 changes: 1 addition & 1 deletion nx_arangodb/algorithms/shortest_paths/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
extra_params=_dtype_param, version_added="24.04", _plc={"bfs", "sssp"}
)
def shortest_path(
G: nxadb.Graph | nxadb.DiGraph,
G: nxadb.Graph,
source=None,
target=None,
weight=None,
Expand Down
5 changes: 2 additions & 3 deletions nx_arangodb/classes/dict/adj.py
Original file line number Diff line number Diff line change
Expand Up @@ -1551,7 +1551,7 @@ def propagate_edge_directed_symmetric(

propagate_edge_func = (
propagate_edge_directed_symmetric
if self.is_directed and self.symmetrize_edges_if_directed
if self.symmetrize_edges_if_directed
else (
propagate_edge_directed
if self.is_directed
Expand Down Expand Up @@ -1606,8 +1606,7 @@ def set_edge_multigraph(
load_all_edge_attributes=True,
is_directed=self.is_directed,
is_multigraph=self.is_multigraph,
symmetrize_edges_if_directed=self.is_directed
and self.symmetrize_edges_if_directed,
symmetrize_edges_if_directed=self.symmetrize_edges_if_directed,
)

if self.is_directed:
Expand Down
39 changes: 20 additions & 19 deletions nx_arangodb/classes/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -432,27 +432,28 @@ def add_node(self, node_for_adding, **attr):
nx._clear_cache(self)

def number_of_edges(self, u=None, v=None):
if u is None:
######################
# NOTE: monkey patch #
######################
if not self.graph_exists_in_db:
return super().number_of_edges(u, v)

# Old:
# return int(self.size())
if u is not None:
return super().number_of_edges(u, v)

# New:
edge_collections = {
e_d["edge_collection"] for e_d in self.adb_graph.edge_definitions()
}
num = sum(
self.adb_graph.edge_collection(e).count() for e in edge_collections
)
num *= 2 if self.is_directed() and self.symmetrize_edges else 1
######################
# NOTE: monkey patch #
######################

return num
# Old:
# return int(self.size())

# Reason:
# It is more efficient to count the number of edges in the edge collections
# compared to relying on the DegreeView.
# New:
edge_collections = {
e_d["edge_collection"] for e_d in self.adb_graph.edge_definitions()
}
num = sum(self.adb_graph.edge_collection(e).count() for e in edge_collections)
num *= 2 if self.is_directed() and self.symmetrize_edges else 1

return num

super().number_of_edges(u, v)
# Reason:
# It is more efficient to count the number of edges in the edge collections
# compared to relying on the DegreeView.
8 changes: 4 additions & 4 deletions nx_arangodb/classes/multigraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,6 @@ def __init__(
**kwargs,
)

if self._graph_exists_in_db:
self.add_edge = self.add_edge_keyless

#######################
# Init helper methods #
#######################
Expand All @@ -71,7 +68,10 @@ def _set_factory_methods(self) -> None:
# nx.MultiGraph Overides #
##########################

def add_edge_keyless(self, u_for_edge, v_for_edge, key=None, **attr):
def add_edge(self, u_for_edge, v_for_edge, key=None, **attr):
if not self.graph_exists_in_db:
return super().add_edge(u_for_edge, v_for_edge, key=key, **attr)

if key is not None:
m = "ArangoDB MultiGraph does not support custom edge keys yet."
logger.warning(m)
Expand Down
35 changes: 23 additions & 12 deletions nx_arangodb/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
import networkx as nx

import nx_arangodb as nxadb
from nx_arangodb.classes.dict.adj import AdjListOuterDict
from nx_arangodb.classes.dict.node import NodeDict
from nx_arangodb.classes.function import do_load_all_edge_attributes
from nx_arangodb.logger import logger

Expand All @@ -30,7 +32,7 @@
def _to_nx_graph(G: Any, *args: Any, **kwargs: Any) -> nx.Graph:
logger.debug(f"_to_nx_graph for {G.__class__.__name__}")

if isinstance(G, nxadb.Graph | nxadb.DiGraph):
if isinstance(G, nxadb.Graph):
return nxadb_to_nx(G)

if isinstance(G, nx.Graph):
Expand Down Expand Up @@ -109,16 +111,14 @@ def nx_to_nxadb(

def nxadb_to_nx(G: nxadb.Graph) -> nx.Graph:
if not G.graph_exists_in_db:
logger.debug("graph does not exist, nothing to pull")
# TODO: Consider just returning G here?
# Avoids the need to re-create the graph from scratch
return G.to_networkx_class()(incoming_graph_data=G)
# Since nxadb.Graph is a subclass of nx.Graph, we can return it as is.
# This only applies if the graph does not exist in the database.
return G

# TODO: Re-enable this
# if G.use_nx_cache and G._node and G._adj:
# m = "**use_nx_cache** is enabled. using cached data. no pull required."
# logger.debug(m)
# return G
assert isinstance(G._node, NodeDict)
assert isinstance(G._adj, AdjListOuterDict)
if G._node.FETCHED_ALL_DATA and G._adj.FETCHED_ALL_DATA:
return G

start_time = time.time()

Expand All @@ -137,11 +137,22 @@ def nxadb_to_nx(G: nxadb.Graph) -> nx.Graph:

print(f"Graph '{G.adb_graph.name}' load took {time.time() - start_time}s")

G_NX: nx.Graph | nx.DiGraph = G.to_networkx_class()()
# NOTE: At this point, we _could_ choose to implement something similar to
# NodeDict._fetch_all() and AdjListOuterDict._fetch_all() to iterate through
# **node_dict** and **adj_dict**, and establish the "custom" Dictionary classes
# that we've implemented in nx_arangodb.classes.dict.
# However, this would involve adding additional for-loops and would likely be
# slower than the current implementation.
# Perhaps we should consider adding a feature flag to allow users to choose
# between the two methods? e.g `build_remote_dicts=True/False`
# If True, then we would return the (updated) nxadb.Graph that was passed in.
# If False, then we would return the nx.Graph that is built below:
Comment on lines +140 to +149
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This would be an interesting point of discussion to address in a separate PR, as a follow-up improvement.


G_NX: nx.Graph = G.to_networkx_class()()
G_NX._node = node_dict

if isinstance(G_NX, nx.DiGraph):
G_NX._succ = G._adj = adj_dict["succ"]
G_NX._succ = G_NX._adj = adj_dict["succ"]
Copy link
Member Author

@aMahanna aMahanna Aug 16, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was the culprit behind the bug: I mistakenly overwrote the `_adj` of the nxadb.Graph object (`G._adj`) instead of setting `G_NX._adj` 😨

G_NX._pred = adj_dict["pred"]

else:
Expand Down
15 changes: 15 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from arango import ArangoClient
from arango.database import StandardDatabase

import nx_arangodb as nxadb
from nx_arangodb.logger import logger

logger.setLevel(logging.INFO)
Expand Down Expand Up @@ -85,3 +86,17 @@ def load_two_relation_graph() -> None:
g.create_edge_definition(
e2, from_vertex_collections=[v2], to_vertex_collections=[v1]
)


def create_line_graph(load_attributes: set[str]) -> nxadb.Graph:
    """Build a five-node line graph and wrap it in an ``nxadb.Graph``.

    The path 1-2-3-4-5 is first assembled as a plain NetworkX graph. Every
    edge carries a ``my_custom_weight`` attribute: 1 on the first two edges
    and 1000 on the last two. The graph is then passed to ``nxadb.Graph``
    under the name ``"LineGraph"``.

    :param load_attributes: forwarded unchanged as
        ``edge_collections_attributes``.
    :returns: the ``nxadb.Graph`` built from the line graph.
    """
    # (source, target, my_custom_weight) for each edge, in insertion order.
    weighted_edges = (
        (1, 2, 1),
        (2, 3, 1),
        (3, 4, 1000),
        (4, 5, 1000),
    )

    line_graph = nx.Graph()
    for source, target, weight in weighted_edges:
        line_graph.add_edge(source, target, my_custom_weight=weight)

    return nxadb.Graph(
        incoming_graph_data=line_graph,
        graph_name="LineGraph",
        edge_collections_attributes=load_attributes,
    )
114 changes: 73 additions & 41 deletions tests/test.py
Original file line number Diff line number Diff line change
@@ -1,33 +1,22 @@
from typing import Any, Callable, Dict

import networkx as nx
import phenolrs
import pytest
from arango import DocumentDeleteError

import nx_arangodb as nxadb
from nx_arangodb.classes.dict.adj import EdgeAttrDict
from nx_arangodb.classes.dict.node import NodeAttrDict
from nx_arangodb.classes.dict.adj import AdjListOuterDict, EdgeAttrDict
from nx_arangodb.classes.dict.node import NodeAttrDict, NodeDict

from .conftest import db
from .conftest import create_line_graph, db

G_NX = nx.karate_club_graph()


def create_line_graph(load_attributes: set[str]) -> nxadb.Graph:
    """Build a five-node line graph and wrap it in an ``nxadb.Graph``.

    The path 1-2-3-4-5 is assembled as a plain NetworkX graph whose edges
    carry a ``my_custom_weight`` attribute: 1 on the first two edges and
    1000 on the last two. The graph is passed to ``nxadb.Graph`` under the
    name ``"LineGraph"``.

    :param load_attributes: when truthy, forwarded as
        ``edge_collections_attributes``; when falsy, that keyword is omitted.
    :returns: the ``nxadb.Graph`` built from the line graph.
    """
    line_graph = nx.Graph()
    for source, target, weight in ((1, 2, 1), (2, 3, 1), (3, 4, 1000), (4, 5, 1000)):
        line_graph.add_edge(source, target, my_custom_weight=weight)

    # Nothing to load: construct the graph without the attributes keyword.
    if not load_attributes:
        return nxadb.Graph(incoming_graph_data=line_graph, graph_name="LineGraph")

    return nxadb.Graph(
        incoming_graph_data=line_graph,
        graph_name="LineGraph",
        edge_collections_attributes=load_attributes,
    )
def assert_remote_dict(G: nxadb.Graph) -> None:
    """Assert that *G* still holds the custom dictionary classes.

    Checks, in order, that ``G._node`` is a ``NodeDict`` and that ``G._adj``
    is an ``AdjListOuterDict``.

    :param G: the graph whose internal dictionaries are inspected.
    :raises AssertionError: if either attribute has a different type.
    """
    expected_types = (
        ("_node", NodeDict),
        ("_adj", AdjListOuterDict),
    )
    # Attributes are fetched lazily, one per iteration, so `_adj` is only
    # read after the `_node` check has passed.
    for attribute_name, expected_type in expected_types:
        assert isinstance(getattr(G, attribute_name), expected_type)


def assert_same_dict_values(
Expand Down Expand Up @@ -162,18 +151,35 @@ def test_load_graph_with_non_default_weight_attribute():


@pytest.mark.parametrize(
"algorithm_func, assert_func",
"algorithm_func, assert_func, affected_by_symmetry",
[
(nx.betweenness_centrality, assert_bc),
(nx.pagerank, assert_pagerank),
(nx.community.louvain_communities, assert_louvain),
(nx.betweenness_centrality, assert_bc, True),
(nx.pagerank, assert_pagerank, False),
],
)
def test_algorithm(
algorithm_func: Callable[..., Any],
assert_func: Callable[..., Any],
affected_by_symmetry: bool,
load_karate_graph: Any,
) -> None:
def assert_func_should_fail(
r1: dict[str | int, float], r2: dict[str | int, float]
) -> None:
assert r1 != r2
assert len(r1) == len(r2)
with pytest.raises(AssertionError):
assert_func(r1, r2)

def assert_symmetry_differences(
r1: dict[str | int, float], r2: dict[str | int, float], should_be_equal: bool
) -> None:
if should_be_equal:
assert_func(r1, r2)
return

assert_func_should_fail(r1, r2)

G_1 = G_NX
G_2 = nxadb.Graph(incoming_graph_data=G_1)
G_3 = nxadb.Graph(graph_name="KarateGraph")
Expand All @@ -183,6 +189,9 @@ def test_algorithm(
G_7 = nxadb.MultiDiGraph(graph_name="KarateGraph", symmetrize_edges=True)
G_8 = nxadb.MultiDiGraph(graph_name="KarateGraph", symmetrize_edges=False)

for G in [G_3, G_4, G_5, G_6, G_7, G_8]:
assert_remote_dict(G)

r_1 = algorithm_func(G_1)
r_2 = algorithm_func(G_2)
r_3 = algorithm_func(G_1, backend="arangodb")
Expand All @@ -193,51 +202,74 @@ def test_algorithm(
assert_func(r_2, r_3)
assert_func(r_3, r_4)

try:
import phenolrs # noqa
except ModuleNotFoundError:
pytest.skip("phenolrs not installed")

r_7 = algorithm_func(G_3)
assert_remote_dict(G_3)
r_7_orig = algorithm_func.orig_func(G_3) # type: ignore
assert_remote_dict(G_3)

r_8 = algorithm_func(G_4)
assert_remote_dict(G_4)
r_8_orig = algorithm_func.orig_func(G_4) # type: ignore
assert_remote_dict(G_4)

r_9 = algorithm_func(G_5)
assert_remote_dict(G_5)
r_9_orig = algorithm_func.orig_func(G_5) # type: ignore
assert_remote_dict(G_5)

r_10 = algorithm_func(nx.DiGraph(incoming_graph_data=G_NX))

r_11 = algorithm_func(G_6)
assert_remote_dict(G_6)
r_11_orig = algorithm_func.orig_func(G_6) # type: ignore
assert_remote_dict(G_6)

r_12 = algorithm_func(G_7)
assert_remote_dict(G_7)
r_12_orig = algorithm_func.orig_func(G_7) # type: ignore
assert_remote_dict(G_7)

r_13 = algorithm_func(G_8)
assert_remote_dict(G_8)
r_13_orig = algorithm_func.orig_func(G_8) # type: ignore
assert_remote_dict(G_8)

assert_func(r_7, r_7_orig)
assert_func(r_8, r_8_orig)
assert_func(r_9, r_9_orig)
assert_func(r_7, r_1)
assert_func(r_7, r_8)
assert r_8 != r_9
assert r_8_orig != r_9_orig

assert_symmetry_differences(r_8, r_8_orig, should_be_equal=not affected_by_symmetry)
assert_func_should_fail(r_8, r_9)

assert_func(r_9, r_9_orig)
assert_symmetry_differences(
r_8_orig, r_9_orig, should_be_equal=affected_by_symmetry
)

assert_func(r_8, r_10)
assert_func(r_8_orig, r_10)
assert_func(r_7, r_11)
assert_func(r_8, r_11)
assert_symmetry_differences(
r_8_orig, r_10, should_be_equal=not affected_by_symmetry
)

assert_func(r_11, r_7)
assert_func(r_11, r_7)
assert_func(r_11, r_11_orig)
assert_func(r_12, r_12_orig)

assert_symmetry_differences(
r_12, r_12_orig, should_be_equal=not affected_by_symmetry
)
assert_func_should_fail(r_12, r_13)

assert_func(r_13, r_13_orig)
assert r_12 != r_13
assert r_12_orig != r_13_orig
assert_func(r_8, r_12)
assert_func(r_8_orig, r_12_orig)
assert_func(r_9, r_13)
assert_func(r_9_orig, r_13_orig)
assert_symmetry_differences(
r_12_orig, r_13_orig, should_be_equal=affected_by_symmetry
)

assert_func(r_12, r_8)
assert_func(r_12_orig, r_8_orig)

assert_func(r_13, r_9)
assert_func(r_13_orig, r_9_orig)


def test_shortest_path_remote_algorithm(load_karate_graph: Any) -> None:
Expand Down