diff --git a/nx_arangodb/classes/dict/adj.py b/nx_arangodb/classes/dict/adj.py index 1c9ea7cc..c2543eaa 100644 --- a/nx_arangodb/classes/dict/adj.py +++ b/nx_arangodb/classes/dict/adj.py @@ -34,6 +34,8 @@ check_list_for_errors, doc_insert, doc_update, + edge_get, + edge_link, get_arangodb_graph, get_node_id, get_node_type_and_id, @@ -144,7 +146,6 @@ def process_edge_attr_dict_value(parent: EdgeAttrDict, key: str, value: Any) -> return value edge_attr_dict = parent.edge_attr_dict_factory() - edge_attr_dict.root = parent.root or parent edge_attr_dict.edge_id = parent.edge_id edge_attr_dict.parent_keys = parent.parent_keys + [key] edge_attr_dict.data = build_edge_attr_dict_data(edge_attr_dict, value) @@ -183,8 +184,6 @@ def __init__( # EdgeAttrDict may be a child of another EdgeAttrDict # e.g G._adj['node/1']['node/2']['object']['foo'] = 'bar' # In this case, **parent_keys** would be ['object'] - # and **root** would be G._adj['node/1']['node/2'] - self.root: EdgeAttrDict | None = None self.parent_keys: list[str] = [] self.edge_attr_dict_factory = edge_attr_dict_factory(self.db, self.graph) @@ -236,8 +235,7 @@ def __setitem__(self, key: str, value: Any) -> None: edge_attr_dict_value = process_edge_attr_dict_value(self, key, value) update_dict = get_update_dict(self.parent_keys, {key: value}) self.data[key] = edge_attr_dict_value - root_data = self.root.data if self.root else self.data - root_data["_rev"] = doc_update(self.db, self.edge_id, update_dict) + doc_update(self.db, self.edge_id, update_dict) @key_is_string @key_is_not_reserved @@ -247,8 +245,7 @@ def __delitem__(self, key: str) -> None: assert self.edge_id self.data.pop(key, None) update_dict = get_update_dict(self.parent_keys, {key: None}) - root_data = self.root.data if self.root else self.data - root_data["_rev"] = doc_update(self.db, self.edge_id, update_dict) + doc_update(self.db, self.edge_id, update_dict) @keys_are_strings @keys_are_not_reserved @@ -265,8 +262,7 @@ def update(self, attrs: Any) -> None: return update_dict = get_update_dict(self.parent_keys, attrs) - root_data = self.root.data if self.root else self.data - root_data["_rev"] = doc_update(self.db, self.edge_id, update_dict) + doc_update(self.db, self.edge_id, update_dict) class EdgeKeyDict(UserDict[str, EdgeAttrDict]): @@ -457,7 +453,7 @@ def __contains__(self, key: str | int) -> bool: if self.FETCHED_ALL_IDS: return False - edge = self.graph.edge(key) + edge = edge_get(self.graph, key) if edge is None: logger.warning(f"Edge '{key}' does not exist in Graph.") @@ -500,7 +496,7 @@ def __getitem__(self, key: str | int) -> EdgeAttrDict: if key not in self.data and self.FETCHED_ALL_IDS: raise KeyError(key) - edge = self.graph.edge(key) + edge = edge_get(self.graph, key) if edge is None: raise KeyError(key) @@ -546,8 +542,12 @@ def __setitem__(self, key: int, edge_attr_dict: EdgeAttrDict) -> None: if not edge_type: edge_type = self.default_edge_type - edge = self.graph.link( - edge_type, self.src_node_id, self.dst_node_id, edge_attr_dict.data + edge = edge_link( + self.graph, + edge_type, + self.src_node_id, + self.dst_node_id, + edge_attr_dict.data, ) edge_data: dict[str, Any] = { @@ -1032,12 +1032,17 @@ def __setitem__graph( can_return_multiple=False, ) - if edge_id: - edge = doc_insert(self.db, edge_type, edge_id, edge_attr_dict.data) - else: - edge = self.graph.link( - edge_type, self.src_node_id, dst_node_id, edge_attr_dict.data + edge = ( + doc_insert(self.db, edge_type, edge_id, edge_attr_dict.data) + if edge_id + else edge_link( + self.graph, + edge_type, + self.src_node_id, + dst_node_id, + edge_attr_dict.data, ) + ) edge_data: dict[str, Any] = { **edge_attr_dict.data, @@ -1064,6 +1069,7 @@ def __setitem__multigraph( assert list(edge_key_dict.data.keys())[0] == "-1" assert edge_key_dict.src_node_id is None assert edge_key_dict.dst_node_id is None + assert self.src_node_id is not None edge_attr_dict = edge_key_dict.data["-1"] @@ -1071,8 +1077,8 @@ def __setitem__multigraph( if edge_type is None: edge_type = self.edge_type_func(self.src_node_type, dst_node_type) - edge = self.graph.link( - edge_type, self.src_node_id, dst_node_id, edge_attr_dict.data + edge = edge_link( + self.graph, edge_type, self.src_node_id, dst_node_id, edge_attr_dict.data ) edge_data: dict[str, Any] = { @@ -1217,7 +1223,7 @@ def _fetch_all(self) -> None: query = f""" FOR v, e IN 1..1 {self.traversal_direction.name} @src_node_id GRAPH @graph_name - RETURN e + RETURN UNSET(e, '_rev') """ bind_vars = {"src_node_id": self.src_node_id, "graph_name": self.graph.name} @@ -1550,14 +1556,20 @@ def __set_adj_elements( def set_edge_graph( src_node_id: str, dst_node_id: str, edge: dict[str, Any] ) -> EdgeAttrDict: + edge.pop("_rev", None) + adjlist_inner_dict = self.data[src_node_id] edge_attr_dict: EdgeAttrDict edge_attr_dict = adjlist_inner_dict._create_edge_attr_dict(edge) - adjlist_inner_dict.data[dst_node_id] = edge_attr_dict + if dst_node_id not in adjlist_inner_dict.data: + adjlist_inner_dict.data[dst_node_id] = edge_attr_dict + else: + existing_edge_attr_dict = adjlist_inner_dict.data[dst_node_id] + existing_edge_attr_dict.data.update(edge_attr_dict.data) - return edge_attr_dict + return adjlist_inner_dict.data[dst_node_id] # type: ignore # false positive def set_edge_multigraph( src_node_id: str, dst_node_id: str, edges: dict[int, dict[str, Any]] @@ -1571,8 +1583,16 @@ def set_edge_multigraph( edge_key_dict.FETCHED_ALL_IDS = True for edge in edges.values(): + edge.pop("_rev", None) + + edge_attr_dict: EdgeAttrDict edge_attr_dict = adjlist_inner_dict._create_edge_attr_dict(edge) - edge_key_dict.data[edge["_id"]] = edge_attr_dict + + if edge["_id"] not in edge_key_dict.data: + edge_key_dict.data[edge["_id"]] = edge_attr_dict + else: + existing_edge_attr_dict = edge_key_dict.data[edge["_id"]] + existing_edge_attr_dict.data.update(edge_attr_dict.data) adjlist_inner_dict.data[dst_node_id] = edge_key_dict @@ -1618,11 +1638,6 @@ def propagate_edge_directed_symmetric( for src_node_id, inner_dict in edges_dict.items(): for dst_node_id, edge_or_edges in inner_dict.items(): - if not self.is_directed: - if src_node_id in self.data: - if dst_node_id in self.data[src_node_id].data: - continue # can skip due not directed - self.__set_adj_inner_dict(self, src_node_id) self.__set_adj_inner_dict(self, dst_node_id) edge_attr_or_key_dict = set_edge_func( # type: ignore[operator] diff --git a/nx_arangodb/classes/dict/graph.py b/nx_arangodb/classes/dict/graph.py index bc075126..840663ff 100644 --- a/nx_arangodb/classes/dict/graph.py +++ b/nx_arangodb/classes/dict/graph.py @@ -62,7 +62,6 @@ def process_graph_attr_dict_value(parent: GraphAttrDict, key: str, value: Any) - return value graph_attr_dict = parent.graph_attr_dict_factory() - graph_attr_dict.root = parent.root or parent graph_attr_dict.parent_keys = parent.parent_keys + [key] graph_attr_dict.data = build_graph_attr_dict_data(graph_attr_dict, value) @@ -149,7 +148,7 @@ def __setitem__(self, key: str, value: Any) -> None: graph_dict_value = self.__process_graph_dict_value(key, value) self.data[key] = graph_dict_value - self.data["_rev"] = doc_update(self.db, self.graph_id, {key: value}) + doc_update(self.db, self.graph_id, {key: value}) @key_is_string @key_is_not_reserved @@ -157,7 +156,7 @@ def __setitem__(self, key: str, value: Any) -> None: def __delitem__(self, key: str) -> None: """del G.graph['foo']""" self.data.pop(key, None) - self.data["_rev"] = doc_update(self.db, self.graph_id, {key: None}) + doc_update(self.db, self.graph_id, {key: None}) # @values_are_json_serializable # TODO? @logger_debug @@ -172,7 +171,7 @@ def update(self, attrs: Any) -> None: graph_attr_dict.data = graph_attr_dict_data self.data.update(graph_attr_dict_data) - self.data["_rev"] = doc_update(self.db, self.graph_id, attrs) + doc_update(self.db, self.graph_id, attrs) @logger_debug def clear(self) -> None: @@ -211,7 +210,6 @@ def __init__( self.graph = graph self.graph_id: str = graph_id - self.root: GraphAttrDict | None = None self.parent_keys: list[str] = [] self.graph_attr_dict_factory = graph_attr_dict_factory( self.db, self.graph, self.graph_id @@ -262,8 +260,7 @@ def __setitem__(self, key, value): graph_attr_dict_value = process_graph_attr_dict_value(self, key, value) update_dict = get_update_dict(self.parent_keys, {key: value}) self.data[key] = graph_attr_dict_value - root_data = self.root.data if self.root else self.data - root_data["_rev"] = doc_update(self.db, self.graph_id, update_dict) + doc_update(self.db, self.graph_id, update_dict) @key_is_string @logger_debug @@ -271,8 +268,7 @@ def __delitem__(self, key): """del G.graph['foo']['bar']""" self.data.pop(key, None) update_dict = get_update_dict(self.parent_keys, {key: None}) - root_data = self.root.data if self.root else self.data - root_data["_rev"] = doc_update(self.db, self.graph_id, update_dict) + doc_update(self.db, self.graph_id, update_dict) @logger_debug def update(self, attrs: Any) -> None: @@ -282,5 +278,4 @@ def update(self, attrs: Any) -> None: self.data.update(build_graph_attr_dict_data(self, attrs)) updated_dict = get_update_dict(self.parent_keys, attrs) - root_data = self.root.data if self.root else self.data - root_data["_rev"] = doc_update(self.db, self.graph_id, updated_dict) + doc_update(self.db, self.graph_id, updated_dict) diff --git a/nx_arangodb/classes/dict/node.py b/nx_arangodb/classes/dict/node.py index 62314aa7..2348de8a 100644 --- a/nx_arangodb/classes/dict/node.py +++ b/nx_arangodb/classes/dict/node.py @@ -31,6 +31,7 @@ logger_debug, separate_nodes_by_collections, upsert_collection_documents, + vertex_get, ) ############# @@ -78,7 +79,6 @@ def process_node_attr_dict_value(parent: NodeAttrDict, key: str, value: Any) -> return value node_attr_dict = parent.node_attr_dict_factory() - node_attr_dict.root = parent.root or parent node_attr_dict.node_id = parent.node_id node_attr_dict.parent_keys = parent.parent_keys + [key] node_attr_dict.data = build_node_attr_dict_data(node_attr_dict, value) @@ -109,8 +109,6 @@ def __init__(self, db: StandardDatabase, graph: Graph, *args: Any, **kwargs: Any # NodeAttrDict may be a child of another NodeAttrDict # e.g G._node['node/1']['object']['foo'] = 'bar' # In this case, **parent_keys** would be ['object'] - # and **root** would be G._node['node/1'] - self.root: NodeAttrDict | None = None self.parent_keys: list[str] = [] self.node_attr_dict_factory = node_attr_dict_factory(self.db, self.graph) @@ -168,8 +166,7 @@ def __setitem__(self, key: str, value: Any) -> None: node_attr_dict_value = process_node_attr_dict_value(self, key, value) update_dict = get_update_dict(self.parent_keys, {key: value}) self.data[key] = node_attr_dict_value - root_data = self.root.data if self.root else self.data - root_data["_rev"] = doc_update(self.db, self.node_id, update_dict) + doc_update(self.db, self.node_id, update_dict) @key_is_string @key_is_not_reserved @@ -179,8 +176,7 @@ def __delitem__(self, key: str) -> None: assert self.node_id self.data.pop(key, None) update_dict = get_update_dict(self.parent_keys, {key: None}) - root_data = self.root.data if self.root else self.data - root_data["_rev"] = doc_update(self.db, self.node_id, update_dict) + doc_update(self.db, self.node_id, update_dict) @keys_are_strings @keys_are_not_reserved @@ -198,8 +194,7 @@ def update(self, attrs: Any) -> None: return update_dict = get_update_dict(self.parent_keys, attrs) - root_data = self.root.data if self.root else self.data - root_data["_rev"] = doc_update(self.db, self.node_id, update_dict) + doc_update(self.db, self.node_id, update_dict) class NodeDict(UserDict[str, NodeAttrDict]): @@ -280,14 +275,14 @@ def __getitem__(self, key: str) -> NodeAttrDict: """G._node['node/1']""" node_id = get_node_id(key, self.default_node_type) - if vertex := self.data.get(node_id): - return vertex + if vertex_cache := self.data.get(node_id): + return vertex_cache if node_id not in self.data and self.FETCHED_ALL_IDS: raise KeyError(key) - if vertex := self.graph.vertex(node_id): - node_attr_dict = self._create_node_attr_dict(vertex) + if vertex_db := vertex_get(self.graph, node_id): + node_attr_dict = self._create_node_attr_dict(vertex_db) self.data[node_id] = node_attr_dict return node_attr_dict @@ -458,6 +453,7 @@ def _fetch_all(self): ) for node_id, node_data in node_dict.items(): + del node_data["_rev"] # TODO: Optimize away via phenolrs node_attr_dict = self._create_node_attr_dict(node_data) self.data[node_id] = node_attr_dict diff --git a/nx_arangodb/classes/function.py b/nx_arangodb/classes/function.py index d62c92c8..52808c50 100644 --- a/nx_arangodb/classes/function.py +++ b/nx_arangodb/classes/function.py @@ -400,7 +400,10 @@ def aql_edge_get( direction: str, can_return_multiple: bool = False, ) -> Any | None: - return_clause = "DISTINCT e" if direction == "ANY" else "e" + return_clause = "UNSET(e, '_rev')" + if direction == "ANY": + return_clause = f"DISTINCT {return_clause}" + return aql_edge( db, src_node_id, @@ -583,10 +586,9 @@ def aql_fetch_data_edge( def doc_update( db: StandardDatabase, id: str, data: dict[str, Any], **kwargs: Any -) -> str: +) -> None: """Updates a document in the collection.""" - res = db.update_document({**data, "_id": id}, keep_none=False, **kwargs) - return str(res["_rev"]) + db.update_document({**data, "_id": id}, keep_none=False, silent=True, **kwargs) def doc_delete(db: StandardDatabase, id: str, **kwargs: Any) -> None: @@ -606,6 +608,8 @@ def doc_insert( collection, {**data, "_id": id}, overwrite=True, **kwargs ) + del result["_rev"] + return result @@ -620,6 +624,36 @@ def doc_get_or_insert( return doc_insert(db, collection, id, **kwargs) +def vertex_get(graph: Graph, id: str) -> dict[str, Any] | None: + """Gets a vertex from the graph.""" + vertex: dict[str, Any] | None = graph.vertex(id) + if vertex is None: + return None + + del vertex["_rev"] + return vertex + + +def edge_get(graph: Graph, id: str) -> dict[str, Any] | None: + """Gets an edge from the graph.""" + edge: dict[str, Any] | None = graph.edge(id) + if edge is None: + return None + + del edge["_rev"] + + return edge + + +def edge_link( + graph: Graph, collection: str, src_id: str, dst_id: str, data: dict[str, Any] +) -> dict[str, Any]: + """Links two vertices via an edge.""" + edge: dict[str, Any] = graph.link(collection, src_id, dst_id, data) + del edge["_rev"] + return edge + + def get_node_id(key: str, default_node_type: str) -> str: """Gets the node ID.""" return key if "/" in key else f"{default_node_type}/{key}" diff --git a/tests/test.py b/tests/test.py index bec8c021..2da38d58 100644 --- a/tests/test.py +++ b/tests/test.py @@ -439,6 +439,7 @@ def test_edge_adj_dict_update_existing_single_collection_graph_and_digraph( local_edges_dict[from_doc_id][to_doc_id] = { "_id": edge_doc_id, "extraValue": edge_doc["_key"], + "newDict": {"foo": "bar"}, } G_1.adj.update(local_edges_dict) @@ -446,18 +447,42 @@ def test_edge_adj_dict_update_existing_single_collection_graph_and_digraph( edge_col = db.collection("knows") edge_col_docs = edge_col.all() - # Check if the extraValue attribute was added to each document in the database + # Check if the attributes were added to each document in the database for doc in edge_col_docs: assert "extraValue" in doc assert doc["extraValue"] == doc["_key"] + assert "newDict" in doc + assert doc["newDict"] == {"foo": "bar"} + assert "weight" in doc - # Check if the extraValue attribute was added to each document in the local cache + # Check if the attributes were added to each document in the local cache for from_doc_id, target_dict in local_edges_dict.items(): for to_doc_id, edge_doc in target_dict.items(): - assert "extraValue" in G_1._adj[from_doc_id][to_doc_id] - assert G_1.adj[from_doc_id][to_doc_id][ - "extraValue" - ] == extract_arangodb_key(edge_doc["_id"]) + key = extract_arangodb_key(edge_doc["_id"]) + + adj_edge = G_1._adj.data[from_doc_id].data[to_doc_id].data + assert adj_edge["extraValue"] == key + assert db.document(edge_doc["_id"])["extraValue"] == key + assert "_rev" not in adj_edge + + assert isinstance(adj_edge["newDict"], EdgeAttrDict) + G_1.adj[from_doc_id][to_doc_id]["newDict"]["foo"] = "foo" + assert db.document(edge_doc["_id"])["newDict"] == {"foo": "foo"} + + if G_1.is_directed(): + pred_edge = G_1._pred.data[to_doc_id].data[from_doc_id].data + assert pred_edge["extraValue"] == key + assert "_rev" not in pred_edge + + assert isinstance(pred_edge["newDict"], EdgeAttrDict) + assert pred_edge["newDict"]["foo"] == "foo" + else: + reverse_adj_edge = G_1._adj.data[to_doc_id].data[from_doc_id].data + assert reverse_adj_edge["extraValue"] == key + assert "_rev" not in reverse_adj_edge + + assert isinstance(reverse_adj_edge["newDict"], EdgeAttrDict) + assert reverse_adj_edge["newDict"]["foo"] == "foo" @pytest.mark.parametrize( @@ -491,6 +516,7 @@ def test_edge_adj_dict_update_existing_single_collection_MultiGraph_and_MultiDiG local_edges_dict[from_doc_id][to_doc_id][edge_id] = { "_id": edge_doc["_id"], "extraValue": edge_doc["_key"], + "newDict": {"foo": "bar"}, } G_1.adj.update(local_edges_dict) @@ -498,19 +524,46 @@ def test_edge_adj_dict_update_existing_single_collection_MultiGraph_and_MultiDiG edge_col = db.collection("knows") edge_col_docs = edge_col.all() - # Check if the extraValue attribute was added to each document in the database + # Check if the attributes were added to each document in the database for doc in edge_col_docs: assert "extraValue" in doc assert doc["extraValue"] == doc["_key"] + assert "newDict" in doc + assert doc["newDict"] == {"foo": "bar"} + assert "weight" in doc - # Check if the extraValue attribute was added to each document in the local cache + # Check if the attributes were added to each document in the local cache for from_doc_id, target_dict in local_edges_dict.items(): for to_doc_id, edge_dict in target_dict.items(): for edge_id, edge_doc in edge_dict.items(): - assert "extraValue" in G_1._adj[from_doc_id][to_doc_id][edge_id] - assert G_1.adj[from_doc_id][to_doc_id][edge_id][ - "extraValue" - ] == extract_arangodb_key(edge_doc["_id"]) + key = extract_arangodb_key(edge_doc["_id"]) + + adj_edge = G_1._adj.data[from_doc_id].data[to_doc_id].data[edge_id].data + assert adj_edge["extraValue"] == key + assert db.document(edge_doc["_id"])["extraValue"] == key + + assert isinstance(adj_edge["newDict"], EdgeAttrDict) + G_1.adj[from_doc_id][to_doc_id][edge_id]["newDict"]["foo"] = "foo" + assert db.document(edge_doc["_id"])["newDict"] == {"foo": "foo"} + + if G_1.is_directed(): + pred_edge = ( + G_1._pred.data[to_doc_id].data[from_doc_id].data[edge_id].data + ) + assert pred_edge["extraValue"] == key + assert "_rev" not in pred_edge + + assert isinstance(pred_edge["newDict"], EdgeAttrDict) + assert pred_edge["newDict"]["foo"] == "foo" + else: + reverse_adj_edge = ( + G_1._adj.data[to_doc_id].data[from_doc_id].data[edge_id].data + ) + assert reverse_adj_edge["extraValue"] == key + assert "_rev" not in reverse_adj_edge + + assert isinstance(reverse_adj_edge["newDict"], EdgeAttrDict) + assert reverse_adj_edge["newDict"]["foo"] == "foo" def test_edge_dict_update_multiple_collections(load_two_relation_graph: Any) -> None: @@ -592,7 +645,9 @@ def test_nodes_crud(load_karate_graph: Any, graph_cls: type[nxadb.Graph]) -> Non assert len(G_1.nodes) == len(G_2.nodes) for k, v in G_1.nodes(data=True): - assert db.document(k) == v + doc = db.document(k) + del doc["_rev"] + assert doc == v for k, v in G_1.nodes(data="club"): assert db.document(k)["club"] == v @@ -618,7 +673,9 @@ def test_nodes_crud(load_karate_graph: Any, graph_cls: type[nxadb.Graph]) -> Non assert db.document("person/3")["club"] == "bar" for k in G_1: - assert G_1.nodes[k] == db.document(k) + doc = db.document(k) + del doc["_rev"] + assert G_1.nodes[k] == doc for v in G_1.nodes.values(): assert v @@ -737,6 +794,7 @@ def test_graph_edges_crud(load_karate_graph: Any) -> None: G_1.adj["person/1"]["person/1"].update({"bar": "foo"}) doc = db.document(edge_id) + del doc["_rev"] assert doc["bar"] == "foo" assert len(G_1.adj["person/1"]["person/1"]) == len(doc) @@ -884,6 +942,7 @@ def test_digraph_edges_crud(load_karate_graph: Any) -> None: G_1.adj["person/1"]["person/1"].update({"bar": "foo"}) doc = db.document(edge_id) + del doc["_rev"] assert doc["bar"] == "foo" assert len(G_1.adj["person/1"]["person/1"]) == len(doc) @@ -1033,6 +1092,7 @@ def test_multigraph_edges_crud(load_karate_graph: Any) -> None: G_1.adj["person/1"]["person/1"][0].update({"bar": "foo"}) doc = db.document(edge_id) + del doc["_rev"] assert doc["bar"] == "foo" assert len(G_1.adj["person/1"]["person/1"][0]) == len(doc) @@ -1195,6 +1255,7 @@ def test_multidigraph_edges_crud(load_karate_graph: Any) -> None: G_1.adj["person/1"]["person/1"][0].update({"bar": "foo"}) doc = db.document(edge_id) + del doc["_rev"] assert doc["bar"] == "foo" assert len(G_1.adj["person/1"]["person/1"][0]) == len(doc)