diff --git a/nx_arangodb/classes/dict.py b/nx_arangodb/classes/dict.py index d0bccfe9..32949822 100644 --- a/nx_arangodb/classes/dict.py +++ b/nx_arangodb/classes/dict.py @@ -2258,7 +2258,8 @@ def set_edge_multigraph( load_all_edge_attributes=True, is_directed=self.is_directed, is_multigraph=self.is_multigraph, - symmetrize_edges_if_directed=self.symmetrize_edges_if_directed, + symmetrize_edges_if_directed=self.is_directed + and self.symmetrize_edges_if_directed, ) if self.is_directed: diff --git a/nx_arangodb/classes/digraph.py b/nx_arangodb/classes/digraph.py index 23d18fea..8989ac1d 100644 --- a/nx_arangodb/classes/digraph.py +++ b/nx_arangodb/classes/digraph.py @@ -45,11 +45,11 @@ def __init__( read_parallelism, read_batch_size, write_batch_size, + symmetrize_edges, *args, **kwargs, ) - self.symmetrize_edges = symmetrize_edges if self.graph_exists_in_db: assert isinstance(self._succ, AdjListOuterDict) assert isinstance(self._pred, AdjListOuterDict) diff --git a/nx_arangodb/classes/graph.py b/nx_arangodb/classes/graph.py index 6bc88865..843bb2b2 100644 --- a/nx_arangodb/classes/graph.py +++ b/nx_arangodb/classes/graph.py @@ -52,6 +52,7 @@ def __init__( read_parallelism: int = 10, read_batch_size: int = 100000, write_batch_size: int = 50000, + symmetrize_edges: bool = False, *args: Any, **kwargs: Any, ): @@ -80,7 +81,8 @@ def __init__( self.edge_indices: npt.NDArray[np.int64] | None = None self.vertex_ids_to_index: dict[str, int] | None = None - self.symmetrize_edges = False # Does not apply to undirected graphs + # Does not apply to undirected graphs + self.symmetrize_edges = symmetrize_edges self.edge_type_key = edge_type_key diff --git a/nx_arangodb/classes/multidigraph.py b/nx_arangodb/classes/multidigraph.py index 00194dbf..2ec8d3cc 100644 --- a/nx_arangodb/classes/multidigraph.py +++ b/nx_arangodb/classes/multidigraph.py @@ -1,6 +1,7 @@ -from typing import ClassVar +from typing import Any, Callable, ClassVar import networkx as nx +from arango.database import StandardDatabase import nx_arangodb as nxadb from nx_arangodb.classes.digraph import DiGraph @@ -20,10 +21,33 @@ class MultiDiGraph(MultiGraph, DiGraph, nx.MultiDiGraph): def to_networkx_class(cls) -> type[nx.MultiDiGraph]: return nx.MultiDiGraph # type: ignore[no-any-return] - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - m = "nxadb.MultiDiGraph has not been implemented yet. This is a pass-through subclass of nx.MultiDiGraph for now." # noqa - logger.warning(m) + def __init__( + self, + graph_name: str | None = None, + default_node_type: str | None = None, + edge_type_key: str = "_edge_type", + edge_type_func: Callable[[str, str], str] | None = None, + db: StandardDatabase | None = None, + read_parallelism: int = 10, + read_batch_size: int = 100000, + write_batch_size: int = 50000, + symmetrize_edges: bool = False, + *args: Any, + **kwargs: Any, + ): + super().__init__( + graph_name, + default_node_type, + edge_type_key, + edge_type_func, + db, + read_parallelism, + read_batch_size, + write_batch_size, + symmetrize_edges, + *args, + **kwargs, + ) ####################### # Init helper methods # diff --git a/tests/test.py b/tests/test.py index 6757289d..38d96b05 100644 --- a/tests/test.py +++ b/tests/test.py @@ -28,22 +28,30 @@ def assert_same_dict_values( def assert_bc(d1: dict[str | int, float], d2: dict[str | int, float]) -> None: + assert d1 + assert d2 assert_same_dict_values(d1, d2, 14) def assert_pagerank(d1: dict[str | int, float], d2: dict[str | int, float]) -> None: + assert d1 + assert d2 assert_same_dict_values(d1, d2, 15) def assert_louvain(l1: list[set[Any]], l2: list[set[Any]]) -> None: # TODO: Implement some kind of comparison # Reason: Louvain returns different results on different runs + assert l1 + assert l2 pass def assert_k_components( d1: dict[int, list[set[Any]]], d2: dict[int, list[set[Any]]] ) -> None: + assert d1 + assert d2 assert d1.keys() == d2.keys(), "Dictionaries have different keys" assert d1 == d2 @@ -91,6 +99,8 @@ def test_algorithm( G_4 = nxadb.DiGraph(graph_name="KarateGraph", symmetrize_edges=True) G_5 = nxadb.DiGraph(graph_name="KarateGraph", symmetrize_edges=False) G_6 = nxadb.MultiGraph(graph_name="KarateGraph") + G_7 = nxadb.MultiDiGraph(graph_name="KarateGraph", symmetrize_edges=True) + G_8 = nxadb.MultiDiGraph(graph_name="KarateGraph", symmetrize_edges=False) r_1 = algorithm_func(G_1) r_2 = algorithm_func(G_2) @@ -121,7 +131,12 @@ def test_algorithm( r_11 = algorithm_func(G_6) r_11_orig = algorithm_func.orig_func(G_6) # type: ignore - assert all([r_7, r_7_orig, r_8, r_8_orig, r_9, r_9_orig, r_10, r_11, r_11_orig]) + r_12 = algorithm_func(G_7) + r_12_orig = algorithm_func.orig_func(G_7) # type: ignore + + r_13 = algorithm_func(G_8) + r_13_orig = algorithm_func.orig_func(G_8) # type: ignore + assert_func(r_7, r_7_orig) assert_func(r_8, r_8_orig) assert_func(r_9, r_9_orig) @@ -134,6 +149,14 @@ def test_algorithm( assert_func(r_7, r_11) assert_func(r_8, r_11) assert_func(r_11, r_11_orig) + assert_func(r_12, r_12_orig) + assert_func(r_13, r_13_orig) + assert r_12 != r_13 + assert r_12_orig != r_13_orig + assert_func(r_8, r_12) + assert_func(r_8_orig, r_12_orig) + assert_func(r_9, r_13) + assert_func(r_9_orig, r_13_orig) def test_shortest_path_remote_algorithm(load_graph: Any) -> None: @@ -157,6 +180,7 @@ def test_shortest_path_remote_algorithm(load_graph: Any) -> None: (nxadb.Graph), (nxadb.DiGraph), (nxadb.MultiGraph), + (nxadb.MultiDiGraph), ], ) def test_nodes_crud(load_graph: Any, graph_cls: type[nxadb.Graph]) -> None: @@ -741,6 +765,176 @@ def test_multigraph_edges_crud(load_graph: Any) -> None: assert db.document(edge_id)["object"]["sub_object"]["foo"] == "baz" +def test_multidigraph_edges_crud(load_graph: Any) -> None: + G_1 = nxadb.MultiDiGraph(graph_name="KarateGraph") + G_2 = G_NX + + assert len(G_1.adj) == len(G_2.adj) + assert len(G_1.edges) == len(G_2.edges) + assert G_1.number_of_edges() == G_2.number_of_edges() + + for src, dst, w in G_1.edges.data("weight"): + assert G_1.adj[src][dst][0]["weight"] == w + + for src, dst, w in G_1.edges.data("bad_key", default="boom!"): + assert "bad_key" not in G_1.adj[src][dst][0] + assert w == "boom!" + + for k, edge_key_dict in G_1.adj["person/1"].items(): + assert db.has_document(k) + assert db.has_document(edge_key_dict[0]["_id"]) + + G_1.add_edge("person/1", "person/1", foo="bar", _edge_type="knows") + edge_id = G_1.adj["person/1"]["person/1"][0]["_id"] + doc = db.document(edge_id) + assert doc["foo"] == "bar" + assert G_1.adj["person/1"]["person/1"][0]["foo"] == "bar" + + del G_1.adj["person/1"]["person/1"][0]["foo"] + doc = db.document(edge_id) + assert "foo" not in doc + + G_1.adj["person/1"]["person/1"][0].update({"bar": "foo"}) + doc = db.document(edge_id) + assert doc["bar"] == "foo" + + assert len(G_1.adj["person/1"]["person/1"][0]) == len(doc) + adj_count = len(G_1.adj["person/1"]) + G_1.remove_edge("person/1", "person/1") + assert len(G_1.adj["person/1"]) == adj_count - 1 + assert not db.has_document(edge_id) + assert "person/1" in G_1 + + assert not db.has_document("person/new_node_1") + col_count = db.collection("knows").count() + + G_1.add_edge("new_node_1", "new_node_2", foo="bar") + assert db.document(G_1["new_node_1"]["new_node_2"][0]["_id"])["foo"] == "bar" + G_1.add_edge("new_node_1", "new_node_2", foo="bar", bar="foo") + doc = db.document(G_1["new_node_1"]["new_node_2"][1]["_id"]) + assert doc["foo"] == "bar" + assert doc["bar"] == "foo" + + bind_vars = { + "src": f"{G_1.default_node_type}/new_node_1", + "dst": f"{G_1.default_node_type}/new_node_2", + } + + result = list( + db.aql.execute( + f"FOR e IN knows FILTER e._from == @src AND e._to == @dst RETURN e", # noqa + bind_vars=bind_vars, + ) + ) + + assert len(result) == 2 + + result = list( + db.aql.execute( + f"FOR e IN knows FILTER e._from == @dst AND e._to == @src RETURN e", # noqa + bind_vars=bind_vars, + ) + ) + + assert len(result) == 0 + + assert db.collection("knows").count() == col_count + 2 + assert G_1.adj["new_node_1"]["new_node_2"][0] + assert G_1.adj["new_node_1"]["new_node_2"][0]["foo"] == "bar" + assert G_1.pred["new_node_2"]["new_node_1"][0] + assert "new_node_1" not in G_1.adj["new_node_2"] + assert ( + G_1.adj["new_node_1"]["new_node_2"][0]["_id"] + == G_1.pred["new_node_2"]["new_node_1"][0]["_id"] + ) + edge_id = G_1.adj["new_node_1"]["new_node_2"][0]["_id"] + doc = db.document(edge_id) + assert db.has_document(doc["_from"]) + assert db.has_document(doc["_to"]) + assert G_1.nodes["new_node_1"] + assert G_1.nodes["new_node_2"] + + assert len(G_1.adj["new_node_1"]["new_node_2"]) == 2 + G_1.remove_edge("new_node_1", "new_node_2") + G_1.clear() + assert "new_node_1" in G_1 + assert "new_node_2" in G_1 + assert "new_node_2" in G_1.adj["new_node_1"] + assert len(G_1.adj["new_node_1"]["new_node_2"]) == 1 + + G_1.add_edges_from( + [("new_node_1", "new_node_2"), ("new_node_1", "new_node_3")], foo="bar" + ) + G_1.clear() + assert "new_node_1" in G_1 + assert "new_node_2" in G_1 + assert "new_node_3" in G_1 + + for k in G_1.adj["new_node_1"]["new_node_2"]: + assert G_1.adj["new_node_1"]["new_node_2"][k]["foo"] == "bar" + assert G_1.pred["new_node_2"]["new_node_1"][k]["foo"] == "bar" + + for k in G_1.adj["new_node_1"]["new_node_3"]: + assert G_1.adj["new_node_1"]["new_node_3"][k]["foo"] == "bar" + assert G_1.pred["new_node_3"]["new_node_1"][k]["foo"] == "bar" + + assert len(G_1.adj["new_node_1"]["new_node_2"]) == 2 + assert len(G_1.adj["new_node_1"]["new_node_3"]) == 1 + G_1.remove_edges_from([("new_node_1", "new_node_2"), ("new_node_1", "new_node_3")]) + assert len(G_1.adj["new_node_1"]["new_node_2"]) == 1 + + assert "new_node_1" in G_1 + assert "new_node_2" in G_1 + assert "new_node_3" in G_1 + assert "new_node_2" in G_1.adj["new_node_1"] + assert "new_node_3" not in G_1.adj["new_node_1"] + + edge_id = "knows/1" + assert "person/1" not in G_1["person/2"] + assert ( + G_1.succ["person/1"]["person/2"][edge_id] + == G_1.pred["person/2"]["person/1"][edge_id] + ) + new_weight = 1000 + G_1["person/1"]["person/2"][edge_id]["weight"] = new_weight + assert G_1.succ["person/1"]["person/2"][edge_id]["weight"] == new_weight + assert G_1.pred["person/2"]["person/1"][edge_id]["weight"] == new_weight + G_1.clear() + assert G_1.succ["person/1"]["person/2"][edge_id]["weight"] == new_weight + G_1.clear() + assert G_1.pred["person/2"]["person/1"][edge_id]["weight"] == new_weight + + edge_id = G_1["person/1"]["person/2"][edge_id]["_id"] + G_1["person/1"]["person/2"][edge_id]["object"] = {"foo": "bar", "bar": "foo"} + assert "_rev" not in G_1["person/1"]["person/2"][edge_id]["object"] + assert isinstance(G_1["person/1"]["person/2"][edge_id]["object"], EdgeAttrDict) + assert db.document(edge_id)["object"] == {"foo": "bar", "bar": "foo"} + + G_1["person/1"]["person/2"][edge_id]["object"]["foo"] = "baz" + assert db.document(edge_id)["object"]["foo"] == "baz" + + del G_1["person/1"]["person/2"][edge_id]["object"]["foo"] + assert "_rev" not in G_1["person/1"]["person/2"][edge_id]["object"] + assert isinstance(G_1["person/1"]["person/2"][edge_id]["object"], EdgeAttrDict) + assert "foo" not in db.document(edge_id)["object"] + + G_1["person/1"]["person/2"][edge_id]["object"].update( + {"sub_object": {"foo": "bar"}} + ) + assert "_rev" not in G_1["person/1"]["person/2"][edge_id]["object"]["sub_object"] + assert isinstance( + G_1["person/1"]["person/2"][edge_id]["object"]["sub_object"], EdgeAttrDict + ) + assert db.document(edge_id)["object"]["sub_object"]["foo"] == "bar" + + G_1.clear() + + assert G_1["person/1"]["person/2"][edge_id]["object"]["sub_object"]["foo"] == "bar" + G_1["person/1"]["person/2"][edge_id]["object"]["sub_object"]["foo"] = "baz" + assert "_rev" not in G_1["person/1"]["person/2"][edge_id]["object"]["sub_object"] + assert db.document(edge_id)["object"]["sub_object"]["foo"] == "baz" + + def test_graph_dict_init(load_graph: Any) -> None: G = nxadb.Graph(graph_name="KarateGraph", default_node_type="person") assert db.collection("_graphs").has("KarateGraph")