Skip to content

GA-150 | MultiDiGraph Support #26

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 49 commits into from
Aug 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
49 commits
Select commit Hold shift + click to select a range
9e31ddd
GA-149 | initial commit
aMahanna Aug 5, 2024
d7cd8c0
checkpoint
aMahanna Aug 5, 2024
f5fde45
checkpoint 2
aMahanna Aug 6, 2024
e0b39dd
checkpoint 3
aMahanna Aug 6, 2024
48d7ed1
Merge branch 'main' into GA-149
aMahanna Aug 6, 2024
d608f9b
checkpoint 4
aMahanna Aug 7, 2024
4e8370b
cleanup & comments
aMahanna Aug 7, 2024
4340e26
comments
aMahanna Aug 7, 2024
813eb12
cleanup: `__contains__`
aMahanna Aug 7, 2024
25de7fa
cleanup: `__getitem__`
aMahanna Aug 7, 2024
381af10
restructuring
aMahanna Aug 7, 2024
dcd1264
docstring updates
aMahanna Aug 7, 2024
3aa3266
checkpoint 5
aMahanna Aug 7, 2024
52bd984
cleanup
aMahanna Aug 7, 2024
2cdbb11
new helper functions
aMahanna Aug 8, 2024
a42b42f
checkpoint 6
aMahanna Aug 8, 2024
3eecb26
checkpoint 7
aMahanna Aug 8, 2024
8b46b19
cleanup
aMahanna Aug 8, 2024
69b31b7
add warning
aMahanna Aug 8, 2024
9c5c7ab
fix: conditional override
aMahanna Aug 8, 2024
43311d0
fix: func name
aMahanna Aug 8, 2024
5b31775
new: `FETCHED_ALL_IDS`
aMahanna Aug 8, 2024
3aa3404
fix: parameterize `EDGE_TYPE_KEY`
aMahanna Aug 9, 2024
d503b88
cleanup: redundant code
aMahanna Aug 9, 2024
837adb4
fix: `nodes` & `edges` properties
aMahanna Aug 9, 2024
91658b3
new: `__process_int_edge_key`
aMahanna Aug 9, 2024
d9861cc
new: `test_multigraph_*_crud`
aMahanna Aug 9, 2024
5dbc74b
update: `test_algorithm` for `nxadb.MultiGraph`
aMahanna Aug 9, 2024
0d2a9f2
fix: `__get_mirrored_adjlist_inner_dict`
aMahanna Aug 9, 2024
a6bd0b0
extra docstring
aMahanna Aug 12, 2024
54d82f0
new: graph overrides
aMahanna Aug 12, 2024
0dd8c4a
fix: EdgeKeyDict docstring
aMahanna Aug 12, 2024
b7ef351
update `phenolrs` wheel
aMahanna Aug 12, 2024
64c74d9
fix: phenolrs
aMahanna Aug 12, 2024
173094c
remove unused import
aMahanna Aug 12, 2024
5c72746
fix: except clause
aMahanna Aug 12, 2024
094b9a6
fix: logger info
aMahanna Aug 12, 2024
f155905
remove multigraph lock
aMahanna Aug 12, 2024
a2a8a5c
fix: typo
aMahanna Aug 12, 2024
5374443
cleanup: kwargs
aMahanna Aug 12, 2024
57e680c
remove print
aMahanna Aug 12, 2024
15a19f5
fix: add `write_batch_size` to config
aMahanna Aug 12, 2024
ec1cbc8
temp: `NodeDict.update` hack
aMahanna Aug 12, 2024
317666e
revert ec1cbc8
aMahanna Aug 12, 2024
b05d8f4
add custom exception
aMahanna Aug 13, 2024
1baedc0
update node & edge type logic for new vs existing graphs
aMahanna Aug 13, 2024
6431c94
fix: `symmetrize_edges` logic
aMahanna Aug 13, 2024
06a581f
GA-150 | initial commit
aMahanna Aug 13, 2024
10a2f3a
Merge branch 'main' into GA-150
aMahanna Aug 13, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion nx_arangodb/classes/dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -2258,7 +2258,8 @@ def set_edge_multigraph(
load_all_edge_attributes=True,
is_directed=self.is_directed,
is_multigraph=self.is_multigraph,
symmetrize_edges_if_directed=self.symmetrize_edges_if_directed,
symmetrize_edges_if_directed=self.is_directed
and self.symmetrize_edges_if_directed,
)

if self.is_directed:
Expand Down
2 changes: 1 addition & 1 deletion nx_arangodb/classes/digraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,11 @@ def __init__(
read_parallelism,
read_batch_size,
write_batch_size,
symmetrize_edges,
*args,
**kwargs,
)

self.symmetrize_edges = symmetrize_edges
if self.graph_exists_in_db:
assert isinstance(self._succ, AdjListOuterDict)
assert isinstance(self._pred, AdjListOuterDict)
Expand Down
4 changes: 3 additions & 1 deletion nx_arangodb/classes/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ def __init__(
read_parallelism: int = 10,
read_batch_size: int = 100000,
write_batch_size: int = 50000,
symmetrize_edges: bool = False,
*args: Any,
**kwargs: Any,
):
Expand Down Expand Up @@ -80,7 +81,8 @@ def __init__(
self.edge_indices: npt.NDArray[np.int64] | None = None
self.vertex_ids_to_index: dict[str, int] | None = None

self.symmetrize_edges = False # Does not apply to undirected graphs
# Does not apply to undirected graphs
self.symmetrize_edges = symmetrize_edges

self.edge_type_key = edge_type_key

Expand Down
34 changes: 29 additions & 5 deletions nx_arangodb/classes/multidigraph.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import ClassVar
from typing import Any, Callable, ClassVar

import networkx as nx
from arango.database import StandardDatabase

import nx_arangodb as nxadb
from nx_arangodb.classes.digraph import DiGraph
Expand All @@ -20,10 +21,33 @@ class MultiDiGraph(MultiGraph, DiGraph, nx.MultiDiGraph):
def to_networkx_class(cls) -> type[nx.MultiDiGraph]:
return nx.MultiDiGraph # type: ignore[no-any-return]

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
m = "nxadb.MultiDiGraph has not been implemented yet. This is a pass-through subclass of nx.MultiDiGraph for now." # noqa
logger.warning(m)
def __init__(
    self,
    graph_name: str | None = None,
    default_node_type: str | None = None,
    edge_type_key: str = "_edge_type",
    edge_type_func: Callable[[str, str], str] | None = None,
    db: StandardDatabase | None = None,
    read_parallelism: int = 10,
    read_batch_size: int = 100000,
    write_batch_size: int = 50000,
    symmetrize_edges: bool = False,
    *args: Any,
    **kwargs: Any,
):
    """Initialize the graph by delegating to the next ``__init__`` in the MRO.

    This method adds no logic of its own: every named parameter is passed
    on positionally, in declaration order, followed by ``*args`` and
    ``**kwargs``. The parent signature must therefore keep the same
    parameter order.

    NOTE(review): the meaning of each parameter is defined by the parent
    initializers, which are not visible here -- confirm there before
    documenting them individually.
    """
    # Positional forwarding: the order below must mirror the parent signature.
    super().__init__(
        graph_name,
        default_node_type,
        edge_type_key,
        edge_type_func,
        db,
        read_parallelism,
        read_batch_size,
        write_batch_size,
        symmetrize_edges,
        *args,
        **kwargs,
    )

#######################
# Init helper methods #
Expand Down
196 changes: 195 additions & 1 deletion tests/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,22 +28,30 @@ def assert_same_dict_values(


def assert_bc(d1: dict[str | int, float], d2: dict[str | int, float]) -> None:
    """Require both result mappings to be non-empty, then compare them.

    The comparison is delegated to ``assert_same_dict_values`` with ``14``
    as its third argument (presumably a number of digits -- confirm against
    that helper).
    """
    for scores in (d1, d2):
        assert scores
    assert_same_dict_values(d1, d2, 14)


def assert_pagerank(d1: dict[str | int, float], d2: dict[str | int, float]) -> None:
    """Require both result mappings to be non-empty, then compare them.

    The comparison is delegated to ``assert_same_dict_values`` with ``15``
    as its third argument (presumably a number of digits -- confirm against
    that helper).
    """
    for scores in (d1, d2):
        assert scores
    assert_same_dict_values(d1, d2, 15)


def assert_louvain(l1: list[set[Any]], l2: list[set[Any]]) -> None:
    """Check only that both community lists are non-empty.

    The two results are deliberately not compared with each other.

    Raises:
        AssertionError: if either ``l1`` or ``l2`` is empty/falsy.
    """
    # TODO: Implement some kind of comparison
    # Reason: Louvain returns different results on different runs
    assert l1
    assert l2


def assert_k_components(
    d1: dict[int, list[set[Any]]], d2: dict[int, list[set[Any]]]
) -> None:
    """Require two non-empty result mappings to be equal.

    The key sets are checked first so that a key mismatch fails with a
    dedicated message before the full equality check runs.

    Raises:
        AssertionError: if either mapping is empty, the key sets differ,
            or the mappings are not equal.
    """
    for components in (d1, d2):
        assert components
    same_keys = d1.keys() == d2.keys()
    assert same_keys, "Dictionaries have different keys"
    assert d1 == d2

Expand Down Expand Up @@ -91,6 +99,8 @@ def test_algorithm(
G_4 = nxadb.DiGraph(graph_name="KarateGraph", symmetrize_edges=True)
G_5 = nxadb.DiGraph(graph_name="KarateGraph", symmetrize_edges=False)
G_6 = nxadb.MultiGraph(graph_name="KarateGraph")
G_7 = nxadb.MultiDiGraph(graph_name="KarateGraph", symmetrize_edges=True)
G_8 = nxadb.MultiDiGraph(graph_name="KarateGraph", symmetrize_edges=False)

r_1 = algorithm_func(G_1)
r_2 = algorithm_func(G_2)
Expand Down Expand Up @@ -121,7 +131,12 @@ def test_algorithm(
r_11 = algorithm_func(G_6)
r_11_orig = algorithm_func.orig_func(G_6) # type: ignore

assert all([r_7, r_7_orig, r_8, r_8_orig, r_9, r_9_orig, r_10, r_11, r_11_orig])
r_12 = algorithm_func(G_7)
r_12_orig = algorithm_func.orig_func(G_7) # type: ignore

r_13 = algorithm_func(G_8)
r_13_orig = algorithm_func.orig_func(G_8) # type: ignore

assert_func(r_7, r_7_orig)
assert_func(r_8, r_8_orig)
assert_func(r_9, r_9_orig)
Expand All @@ -134,6 +149,14 @@ def test_algorithm(
assert_func(r_7, r_11)
assert_func(r_8, r_11)
assert_func(r_11, r_11_orig)
assert_func(r_12, r_12_orig)
assert_func(r_13, r_13_orig)
assert r_12 != r_13
assert r_12_orig != r_13_orig
assert_func(r_8, r_12)
assert_func(r_8_orig, r_12_orig)
assert_func(r_9, r_13)
assert_func(r_9_orig, r_13_orig)


def test_shortest_path_remote_algorithm(load_graph: Any) -> None:
Expand All @@ -157,6 +180,7 @@ def test_shortest_path_remote_algorithm(load_graph: Any) -> None:
(nxadb.Graph),
(nxadb.DiGraph),
(nxadb.MultiGraph),
(nxadb.MultiDiGraph),
],
)
def test_nodes_crud(load_graph: Any, graph_cls: type[nxadb.Graph]) -> None:
Expand Down Expand Up @@ -741,6 +765,176 @@ def test_multigraph_edges_crud(load_graph: Any) -> None:
assert db.document(edge_id)["object"]["sub_object"]["foo"] == "baz"


def test_multidigraph_edges_crud(load_graph: Any) -> None:
    """Exercise edge create/read/update/delete on ``nxadb.MultiDiGraph``.

    The test works against the "KarateGraph" graph and checks, in order:
    size parity with ``G_NX``, edge-data iteration, a self-loop lifecycle,
    parallel edges between two new nodes (stored in one direction only),
    bulk add/remove, and nested-dict edge attributes written through to
    the database.

    The steps share state in the database, so statement order matters.
    """
    G_1 = nxadb.MultiDiGraph(graph_name="KarateGraph")
    G_2 = G_NX

    # Size parity with the reference graph.
    assert len(G_1.adj) == len(G_2.adj)
    assert len(G_1.edges) == len(G_2.edges)
    assert G_1.number_of_edges() == G_2.number_of_edges()

    # Edge-data iteration agrees with direct adjacency access (edge key 0).
    for src, dst, w in G_1.edges.data("weight"):
        assert G_1.adj[src][dst][0]["weight"] == w

    # A missing attribute yields the supplied default.
    for src, dst, w in G_1.edges.data("bad_key", default="boom!"):
        assert "bad_key" not in G_1.adj[src][dst][0]
        assert w == "boom!"

    # Neighbour keys and edge ``_id`` values both resolve to documents.
    for k, edge_key_dict in G_1.adj["person/1"].items():
        assert db.has_document(k)
        assert db.has_document(edge_key_dict[0]["_id"])

    # Self-loop: create, read back from the database, mutate, then remove.
    G_1.add_edge("person/1", "person/1", foo="bar", _edge_type="knows")
    edge_id = G_1.adj["person/1"]["person/1"][0]["_id"]
    doc = db.document(edge_id)
    assert doc["foo"] == "bar"
    assert G_1.adj["person/1"]["person/1"][0]["foo"] == "bar"

    del G_1.adj["person/1"]["person/1"][0]["foo"]
    doc = db.document(edge_id)
    assert "foo" not in doc

    G_1.adj["person/1"]["person/1"][0].update({"bar": "foo"})
    doc = db.document(edge_id)
    assert doc["bar"] == "foo"

    assert len(G_1.adj["person/1"]["person/1"][0]) == len(doc)
    adj_count = len(G_1.adj["person/1"])
    G_1.remove_edge("person/1", "person/1")
    assert len(G_1.adj["person/1"]) == adj_count - 1
    assert not db.has_document(edge_id)
    assert "person/1" in G_1

    assert not db.has_document("person/new_node_1")
    col_count = db.collection("knows").count()

    # Two parallel edges between the same new nodes get edge keys 0 and 1.
    G_1.add_edge("new_node_1", "new_node_2", foo="bar")
    assert db.document(G_1["new_node_1"]["new_node_2"][0]["_id"])["foo"] == "bar"
    G_1.add_edge("new_node_1", "new_node_2", foo="bar", bar="foo")
    doc = db.document(G_1["new_node_1"]["new_node_2"][1]["_id"])
    assert doc["foo"] == "bar"
    assert doc["bar"] == "foo"

    bind_vars = {
        "src": f"{G_1.default_node_type}/new_node_1",
        "dst": f"{G_1.default_node_type}/new_node_2",
    }

    # Both edges are stored in the src -> dst direction ...
    result = list(
        db.aql.execute(
            "FOR e IN knows FILTER e._from == @src AND e._to == @dst RETURN e",
            bind_vars=bind_vars,
        )
    )

    assert len(result) == 2

    # ... and none in the reverse direction.
    result = list(
        db.aql.execute(
            "FOR e IN knows FILTER e._from == @dst AND e._to == @src RETURN e",
            bind_vars=bind_vars,
        )
    )

    assert len(result) == 0

    assert db.collection("knows").count() == col_count + 2
    assert G_1.adj["new_node_1"]["new_node_2"][0]
    assert G_1.adj["new_node_1"]["new_node_2"][0]["foo"] == "bar"
    assert G_1.pred["new_node_2"]["new_node_1"][0]
    assert "new_node_1" not in G_1.adj["new_node_2"]
    # Successor and predecessor views expose the same edge document.
    assert (
        G_1.adj["new_node_1"]["new_node_2"][0]["_id"]
        == G_1.pred["new_node_2"]["new_node_1"][0]["_id"]
    )
    edge_id = G_1.adj["new_node_1"]["new_node_2"][0]["_id"]
    doc = db.document(edge_id)
    assert db.has_document(doc["_from"])
    assert db.has_document(doc["_to"])
    assert G_1.nodes["new_node_1"]
    assert G_1.nodes["new_node_2"]

    # Removing one of the two parallel edges leaves the other in place.
    assert len(G_1.adj["new_node_1"]["new_node_2"]) == 2
    G_1.remove_edge("new_node_1", "new_node_2")
    # NOTE(review): clear() presumably drops only locally cached state; the
    # assertions that follow still find the nodes and the remaining edge.
    G_1.clear()
    assert "new_node_1" in G_1
    assert "new_node_2" in G_1
    assert "new_node_2" in G_1.adj["new_node_1"]
    assert len(G_1.adj["new_node_1"]["new_node_2"]) == 1

    # Bulk add: shared attributes are applied to every new edge.
    G_1.add_edges_from(
        [("new_node_1", "new_node_2"), ("new_node_1", "new_node_3")], foo="bar"
    )
    G_1.clear()
    assert "new_node_1" in G_1
    assert "new_node_2" in G_1
    assert "new_node_3" in G_1

    for k in G_1.adj["new_node_1"]["new_node_2"]:
        assert G_1.adj["new_node_1"]["new_node_2"][k]["foo"] == "bar"
        assert G_1.pred["new_node_2"]["new_node_1"][k]["foo"] == "bar"

    for k in G_1.adj["new_node_1"]["new_node_3"]:
        assert G_1.adj["new_node_1"]["new_node_3"][k]["foo"] == "bar"
        assert G_1.pred["new_node_3"]["new_node_1"][k]["foo"] == "bar"

    # Bulk remove takes one edge per listed pair.
    assert len(G_1.adj["new_node_1"]["new_node_2"]) == 2
    assert len(G_1.adj["new_node_1"]["new_node_3"]) == 1
    G_1.remove_edges_from([("new_node_1", "new_node_2"), ("new_node_1", "new_node_3")])
    assert len(G_1.adj["new_node_1"]["new_node_2"]) == 1

    assert "new_node_1" in G_1
    assert "new_node_2" in G_1
    assert "new_node_3" in G_1
    assert "new_node_2" in G_1.adj["new_node_1"]
    assert "new_node_3" not in G_1.adj["new_node_1"]

    # Existing edge addressed by its document id; succ and pred stay in sync.
    edge_id = "knows/1"
    assert "person/1" not in G_1["person/2"]
    assert (
        G_1.succ["person/1"]["person/2"][edge_id]
        == G_1.pred["person/2"]["person/1"][edge_id]
    )
    new_weight = 1000
    G_1["person/1"]["person/2"][edge_id]["weight"] = new_weight
    assert G_1.succ["person/1"]["person/2"][edge_id]["weight"] == new_weight
    assert G_1.pred["person/2"]["person/1"][edge_id]["weight"] == new_weight
    G_1.clear()
    assert G_1.succ["person/1"]["person/2"][edge_id]["weight"] == new_weight
    G_1.clear()
    assert G_1.pred["person/2"]["person/1"][edge_id]["weight"] == new_weight

    # Nested dict attributes are wrapped as EdgeAttrDict and written through.
    edge_id = G_1["person/1"]["person/2"][edge_id]["_id"]
    G_1["person/1"]["person/2"][edge_id]["object"] = {"foo": "bar", "bar": "foo"}
    assert "_rev" not in G_1["person/1"]["person/2"][edge_id]["object"]
    assert isinstance(G_1["person/1"]["person/2"][edge_id]["object"], EdgeAttrDict)
    assert db.document(edge_id)["object"] == {"foo": "bar", "bar": "foo"}

    G_1["person/1"]["person/2"][edge_id]["object"]["foo"] = "baz"
    assert db.document(edge_id)["object"]["foo"] == "baz"

    del G_1["person/1"]["person/2"][edge_id]["object"]["foo"]
    assert "_rev" not in G_1["person/1"]["person/2"][edge_id]["object"]
    assert isinstance(G_1["person/1"]["person/2"][edge_id]["object"], EdgeAttrDict)
    assert "foo" not in db.document(edge_id)["object"]

    G_1["person/1"]["person/2"][edge_id]["object"].update(
        {"sub_object": {"foo": "bar"}}
    )
    assert "_rev" not in G_1["person/1"]["person/2"][edge_id]["object"]["sub_object"]
    assert isinstance(
        G_1["person/1"]["person/2"][edge_id]["object"]["sub_object"], EdgeAttrDict
    )
    assert db.document(edge_id)["object"]["sub_object"]["foo"] == "bar"

    G_1.clear()

    # The nested value is still readable and writable after clear().
    assert G_1["person/1"]["person/2"][edge_id]["object"]["sub_object"]["foo"] == "bar"
    G_1["person/1"]["person/2"][edge_id]["object"]["sub_object"]["foo"] = "baz"
    assert "_rev" not in G_1["person/1"]["person/2"][edge_id]["object"]["sub_object"]
    assert db.document(edge_id)["object"]["sub_object"]["foo"] == "baz"


def test_graph_dict_init(load_graph: Any) -> None:
G = nxadb.Graph(graph_name="KarateGraph", default_node_type="person")
assert db.collection("_graphs").has("KarateGraph")
Expand Down