Skip to content
2 changes: 1 addition & 1 deletion nx_arangodb/algorithms/shortest_paths/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
extra_params=_dtype_param, version_added="24.04", _plc={"bfs", "sssp"}
)
def shortest_path(
G: nxadb.Graph | nxadb.DiGraph,
G: nxadb.Graph,
source=None,
target=None,
weight=None,
Expand Down
5 changes: 2 additions & 3 deletions nx_arangodb/classes/dict/adj.py
Original file line number Diff line number Diff line change
Expand Up @@ -1607,7 +1607,7 @@ def propagate_edge_directed_symmetric(

propagate_edge_func = (
propagate_edge_directed_symmetric
if self.is_directed and self.symmetrize_edges_if_directed
if self.symmetrize_edges_if_directed
else (
propagate_edge_directed
if self.is_directed
Expand Down Expand Up @@ -1663,8 +1663,7 @@ def _fetch_all(self) -> None:
load_all_edge_attributes=True,
is_directed=self.is_directed,
is_multigraph=self.is_multigraph,
symmetrize_edges_if_directed=self.is_directed
and self.symmetrize_edges_if_directed,
symmetrize_edges_if_directed=self.symmetrize_edges_if_directed,
)

if self.is_directed:
Expand Down
39 changes: 20 additions & 19 deletions nx_arangodb/classes/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -441,27 +441,28 @@ def add_node(self, node_for_adding, **attr):
nx._clear_cache(self)

def number_of_edges(self, u=None, v=None):
if u is None:
######################
# NOTE: monkey patch #
######################
if not self.graph_exists_in_db:
return super().number_of_edges(u, v)

# Old:
# return int(self.size())
if u is not None:
return super().number_of_edges(u, v)

# New:
edge_collections = {
e_d["edge_collection"] for e_d in self.adb_graph.edge_definitions()
}
num = sum(
self.adb_graph.edge_collection(e).count() for e in edge_collections
)
num *= 2 if self.is_directed() and self.symmetrize_edges else 1
######################
# NOTE: monkey patch #
######################

return num
# Old:
# return int(self.size())

# Reason:
# It is more efficient to count the number of edges in the edge collections
# compared to relying on the DegreeView.
# New:
edge_collections = {
e_d["edge_collection"] for e_d in self.adb_graph.edge_definitions()
}
num = sum(self.adb_graph.edge_collection(e).count() for e in edge_collections)
num *= 2 if self.is_directed() and self.symmetrize_edges else 1

return num

super().number_of_edges(u, v)
# Reason:
# It is more efficient to count the number of edges in the edge collections
# compared to relying on the DegreeView.
8 changes: 4 additions & 4 deletions nx_arangodb/classes/multigraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,6 @@ def __init__(
**kwargs,
)

if self._graph_exists_in_db:
self.add_edge = self.add_edge_keyless

#######################
# Init helper methods #
#######################
Expand All @@ -71,7 +68,10 @@ def _set_factory_methods(self) -> None:
# nx.MultiGraph Overides #
##########################

def add_edge_keyless(self, u_for_edge, v_for_edge, key=None, **attr):
def add_edge(self, u_for_edge, v_for_edge, key=None, **attr):
if not self.graph_exists_in_db:
return super().add_edge(u_for_edge, v_for_edge, key=key, **attr)

if key is not None:
m = "ArangoDB MultiGraph does not support custom edge keys yet."
logger.warning(m)
Expand Down
35 changes: 23 additions & 12 deletions nx_arangodb/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
import networkx as nx

import nx_arangodb as nxadb
from nx_arangodb.classes.dict.adj import AdjListOuterDict
from nx_arangodb.classes.dict.node import NodeDict
from nx_arangodb.classes.function import do_load_all_edge_attributes
from nx_arangodb.logger import logger

Expand All @@ -30,7 +32,7 @@
def _to_nx_graph(G: Any, *args: Any, **kwargs: Any) -> nx.Graph:
logger.debug(f"_to_nx_graph for {G.__class__.__name__}")

if isinstance(G, nxadb.Graph | nxadb.DiGraph):
if isinstance(G, nxadb.Graph):
return nxadb_to_nx(G)

if isinstance(G, nx.Graph):
Expand Down Expand Up @@ -109,16 +111,14 @@ def nx_to_nxadb(

def nxadb_to_nx(G: nxadb.Graph) -> nx.Graph:
if not G.graph_exists_in_db:
logger.debug("graph does not exist, nothing to pull")
# TODO: Consider just returning G here?
# Avoids the need to re-create the graph from scratch
return G.to_networkx_class()(incoming_graph_data=G)
# Since nxadb.Graph is a subclass of nx.Graph, we can return it as is.
# This only applies if the graph does not exist in the database.
return G

# TODO: Re-enable this
# if G.use_nx_cache and G._node and G._adj:
# m = "**use_nx_cache** is enabled. using cached data. no pull required."
# logger.debug(m)
# return G
assert isinstance(G._node, NodeDict)
assert isinstance(G._adj, AdjListOuterDict)
if G._node.FETCHED_ALL_DATA and G._adj.FETCHED_ALL_DATA:
return G

start_time = time.time()

Expand All @@ -137,11 +137,22 @@ def nxadb_to_nx(G: nxadb.Graph) -> nx.Graph:

print(f"Graph '{G.adb_graph.name}' load took {time.time() - start_time}s")

G_NX: nx.Graph | nx.DiGraph = G.to_networkx_class()()
# NOTE: At this point, we _could_ choose to implement something similar to
# NodeDict._fetch_all() and AdjListOuterDict._fetch_all() to iterate through
# **node_dict** and **adj_dict**, and establish the "custom" Dictionary classes
# that we've implemented in nx_arangodb.classes.dict.
# However, this would involve adding additional for-loops and would likely be
# slower than the current implementation.
# Perhaps we should consider adding a feature flag to allow users to choose
# between the two methods? e.g `build_remote_dicts=True/False`
# If True, then we would return the (updated) nxadb.Graph that was passed in.
# If False, then we would return the nx.Graph that is built below:
Comment on lines +140 to +149
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This would be an interesting point of discussion to address in a separate PR as an improvement.


G_NX: nx.Graph = G.to_networkx_class()()
G_NX._node = node_dict

if isinstance(G_NX, nx.DiGraph):
G_NX._succ = G._adj = adj_dict["succ"]
G_NX._succ = G_NX._adj = adj_dict["succ"]
Copy link
Member Author

@aMahanna aMahanna Aug 16, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was the culprit behind the bug: I mistakenly overwrote the `_adj` of the nxadb.Graph object 😨

G_NX._pred = adj_dict["pred"]

else:
Expand Down
15 changes: 15 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from arango import ArangoClient
from arango.database import StandardDatabase

import nx_arangodb as nxadb
from nx_arangodb.logger import logger

logger.setLevel(logging.INFO)
Expand Down Expand Up @@ -85,3 +86,17 @@ def load_two_relation_graph() -> None:
g.create_edge_definition(
e2, from_vertex_collections=[v2], to_vertex_collections=[v1]
)


def create_line_graph(load_attributes: set[str]) -> nxadb.Graph:
    """Build the line graph 1-2-3-4-5 and hand it to ``nxadb.Graph``.

    Every edge carries a ``my_custom_weight`` attribute: the first two
    edges weigh 1 and the last two weigh 1000.

    Args:
        load_attributes: Forwarded unchanged as the
            ``edge_collections_attributes`` argument of ``nxadb.Graph``.

    Returns:
        The ``nxadb.Graph`` created from the line graph under the graph
        name ``"LineGraph"``.
    """
    # (u, v, weight) triples, in the order the edges are inserted.
    weighted_edges = (
        (1, 2, 1),
        (2, 3, 1),
        (3, 4, 1000),
        (4, 5, 1000),
    )

    line_graph = nx.Graph()
    for u, v, weight in weighted_edges:
        line_graph.add_edge(u, v, my_custom_weight=weight)

    return nxadb.Graph(
        incoming_graph_data=line_graph,
        graph_name="LineGraph",
        edge_collections_attributes=load_attributes,
    )
118 changes: 75 additions & 43 deletions tests/test.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Any, Callable, Dict, Union

import networkx as nx
import phenolrs
import pytest
from arango import DocumentDeleteError
from phenolrs.networkx.typings import (
Expand All @@ -11,33 +12,21 @@
)

import nx_arangodb as nxadb
from nx_arangodb.classes.dict.adj import EdgeAttrDict
from nx_arangodb.classes.dict.node import NodeAttrDict
from nx_arangodb.classes.dict.adj import AdjListOuterDict, EdgeAttrDict
from nx_arangodb.classes.dict.node import NodeAttrDict, NodeDict

from .conftest import db
from .conftest import create_line_graph, db

G_NX = nx.karate_club_graph()


def extract_arangodb_key(adb_id: str) -> str:
return adb_id.split("/")[1]
def assert_remote_dict(G: nxadb.Graph) -> None:
    """Assert that *G* still uses the nx_arangodb dictionary classes.

    Checks, in order, that ``G._node`` is a ``NodeDict`` and that
    ``G._adj`` is an ``AdjListOuterDict``.

    Raises:
        AssertionError: If either attribute is not an instance of the
            expected class.
    """
    expected_types = (("_node", NodeDict), ("_adj", AdjListOuterDict))
    for attr_name, expected_type in expected_types:
        assert isinstance(getattr(G, attr_name), expected_type)


def create_line_graph(load_attributes: set[str]) -> nxadb.Graph:
G = nx.Graph()
G.add_edge(1, 2, my_custom_weight=1)
G.add_edge(2, 3, my_custom_weight=1)
G.add_edge(3, 4, my_custom_weight=1000)
G.add_edge(4, 5, my_custom_weight=1000)

if load_attributes:
return nxadb.Graph(
incoming_graph_data=G,
graph_name="LineGraph",
edge_collections_attributes=load_attributes,
)

return nxadb.Graph(incoming_graph_data=G, graph_name="LineGraph")
def extract_arangodb_key(adb_id: str) -> str:
    """Return the second ``/``-separated segment of *adb_id*.

    NOTE(review): presumably *adb_id* is an ArangoDB document id of the
    form ``"<collection>/<key>"`` -- confirm against callers.

    Raises:
        IndexError: If *adb_id* contains no ``/``.
    """
    segments = adb_id.split("/")
    return segments[1]


def assert_same_dict_values(
Expand Down Expand Up @@ -172,18 +161,35 @@ def test_load_graph_with_non_default_weight_attribute():


@pytest.mark.parametrize(
"algorithm_func, assert_func",
"algorithm_func, assert_func, affected_by_symmetry",
[
(nx.betweenness_centrality, assert_bc),
(nx.pagerank, assert_pagerank),
(nx.community.louvain_communities, assert_louvain),
(nx.betweenness_centrality, assert_bc, True),
(nx.pagerank, assert_pagerank, False),
],
)
def test_algorithm(
algorithm_func: Callable[..., Any],
assert_func: Callable[..., Any],
affected_by_symmetry: bool,
load_karate_graph: Any,
) -> None:
def assert_func_should_fail(
r1: dict[str | int, float], r2: dict[str | int, float]
) -> None:
assert r1 != r2
assert len(r1) == len(r2)
with pytest.raises(AssertionError):
assert_func(r1, r2)

def assert_symmetry_differences(
r1: dict[str | int, float], r2: dict[str | int, float], should_be_equal: bool
) -> None:
if should_be_equal:
assert_func(r1, r2)
return

assert_func_should_fail(r1, r2)

G_1 = G_NX
G_2 = nxadb.Graph(incoming_graph_data=G_1)
G_3 = nxadb.Graph(graph_name="KarateGraph")
Expand All @@ -193,6 +199,9 @@ def test_algorithm(
G_7 = nxadb.MultiDiGraph(graph_name="KarateGraph", symmetrize_edges=True)
G_8 = nxadb.MultiDiGraph(graph_name="KarateGraph", symmetrize_edges=False)

for G in [G_3, G_4, G_5, G_6, G_7, G_8]:
assert_remote_dict(G)

r_1 = algorithm_func(G_1)
r_2 = algorithm_func(G_2)
r_3 = algorithm_func(G_1, backend="arangodb")
Expand All @@ -203,51 +212,74 @@ def test_algorithm(
assert_func(r_2, r_3)
assert_func(r_3, r_4)

try:
import phenolrs # noqa
except ModuleNotFoundError:
pytest.skip("phenolrs not installed")

r_7 = algorithm_func(G_3)
assert_remote_dict(G_3)
r_7_orig = algorithm_func.orig_func(G_3) # type: ignore
assert_remote_dict(G_3)

r_8 = algorithm_func(G_4)
assert_remote_dict(G_4)
r_8_orig = algorithm_func.orig_func(G_4) # type: ignore
assert_remote_dict(G_4)

r_9 = algorithm_func(G_5)
assert_remote_dict(G_5)
r_9_orig = algorithm_func.orig_func(G_5) # type: ignore
assert_remote_dict(G_5)

r_10 = algorithm_func(nx.DiGraph(incoming_graph_data=G_NX))

r_11 = algorithm_func(G_6)
assert_remote_dict(G_6)
r_11_orig = algorithm_func.orig_func(G_6) # type: ignore
assert_remote_dict(G_6)

r_12 = algorithm_func(G_7)
assert_remote_dict(G_7)
r_12_orig = algorithm_func.orig_func(G_7) # type: ignore
assert_remote_dict(G_7)

r_13 = algorithm_func(G_8)
assert_remote_dict(G_8)
r_13_orig = algorithm_func.orig_func(G_8) # type: ignore
assert_remote_dict(G_8)

assert_func(r_7, r_7_orig)
assert_func(r_8, r_8_orig)
assert_func(r_9, r_9_orig)
assert_func(r_7, r_1)
assert_func(r_7, r_8)
assert r_8 != r_9
assert r_8_orig != r_9_orig

assert_symmetry_differences(r_8, r_8_orig, should_be_equal=not affected_by_symmetry)
assert_func_should_fail(r_8, r_9)

assert_func(r_9, r_9_orig)
assert_symmetry_differences(
r_8_orig, r_9_orig, should_be_equal=affected_by_symmetry
)

assert_func(r_8, r_10)
assert_func(r_8_orig, r_10)
assert_func(r_7, r_11)
assert_func(r_8, r_11)
assert_symmetry_differences(
r_8_orig, r_10, should_be_equal=not affected_by_symmetry
)

assert_func(r_11, r_7)
assert_func(r_11, r_7)
assert_func(r_11, r_11_orig)
assert_func(r_12, r_12_orig)

assert_symmetry_differences(
r_12, r_12_orig, should_be_equal=not affected_by_symmetry
)
assert_func_should_fail(r_12, r_13)

assert_func(r_13, r_13_orig)
assert r_12 != r_13
assert r_12_orig != r_13_orig
assert_func(r_8, r_12)
assert_func(r_8_orig, r_12_orig)
assert_func(r_9, r_13)
assert_func(r_9_orig, r_13_orig)
assert_symmetry_differences(
r_12_orig, r_13_orig, should_be_equal=affected_by_symmetry
)

assert_func(r_12, r_8)
assert_func(r_12_orig, r_8_orig)

assert_func(r_13, r_9)
assert_func(r_13_orig, r_9_orig)


def test_shortest_path_remote_algorithm(load_karate_graph: Any) -> None:
Expand Down