diff --git a/nx_arangodb/classes/dict.py b/nx_arangodb/classes/dict.py index d15bfcc5..eb26d8ba 100644 --- a/nx_arangodb/classes/dict.py +++ b/nx_arangodb/classes/dict.py @@ -5,6 +5,7 @@ from __future__ import annotations +import json from collections import UserDict, defaultdict from collections.abc import Iterator from typing import Any, Callable, Generator @@ -36,6 +37,7 @@ get_arangodb_graph, get_node_id, get_node_type_and_id, + json_serializable, key_is_not_reserved, key_is_string, keys_are_not_reserved, @@ -54,6 +56,12 @@ def graph_dict_factory( return lambda: GraphDict(db, graph_name) +def graph_attr_dict_factory( + db: StandardDatabase, graph: Graph, graph_id: str +) -> Callable[..., GraphAttrDict]: + return lambda: GraphAttrDict(db, graph, graph_id) + + def node_dict_factory( db: StandardDatabase, graph: Graph, default_node_type: str ) -> Callable[..., NodeDict]: @@ -98,6 +106,36 @@ def edge_attr_dict_factory( ######### +def build_graph_attr_dict_data( + parent: GraphAttrDict, data: dict[str, Any] +) -> dict[str, Any | GraphAttrDict]: + """Recursively build a GraphAttrDict from a dict. + + It's possible that **value** is a nested dict, so we need to + recursively build a GraphAttrDict for each nested dict. + + Returns the parent GraphAttrDict. + """ + graph_attr_dict_data = {} + for key, value in data.items(): + graph_attr_dict_value = process_graph_attr_dict_value(parent, key, value) + graph_attr_dict_data[key] = graph_attr_dict_value + + return graph_attr_dict_data + + +def process_graph_attr_dict_value(parent: GraphAttrDict, key: str, value: Any) -> Any: + if not isinstance(value, dict): + return value + + graph_attr_dict = parent.graph_attr_dict_factory() + graph_attr_dict.root = parent.root or parent + graph_attr_dict.parent_keys = parent.parent_keys + [key] + graph_attr_dict.data = build_graph_attr_dict_data(graph_attr_dict, value) + + return graph_attr_dict + + class GraphDict(UserDict[str, Any]): """A dictionary-like object for storing graph attributes. @@ -110,8 +148,6 @@ class GraphDict(UserDict[str, Any]): :type graph_name: str """ - COLLECTION_NAME = "nxadb_graphs" - @logger_debug def __init__( self, db: StandardDatabase, graph_name: str, *args: Any, **kwargs: Any @@ -121,13 +157,28 @@ def __init__( self.db = db self.graph_name = graph_name + self.COLLECTION_NAME = "nxadb_graphs" self.graph_id = f"{self.COLLECTION_NAME}/{graph_name}" self.adb_graph = db.graph(graph_name) self.collection = create_collection(db, self.COLLECTION_NAME) + self.graph_attr_dict_factory = graph_attr_dict_factory( + self.db, self.adb_graph, self.graph_id + ) + + result = doc_get_or_insert(self.db, self.COLLECTION_NAME, self.graph_id) + for k, v in result.items(): + self.data[k] = self.__process_graph_dict_value(k, v) + + def __process_graph_dict_value(self, key: str, value: Any) -> Any: + if not isinstance(value, dict): + return value - data = doc_get_or_insert(self.db, self.COLLECTION_NAME, self.graph_id) - self.data.update(data) + graph_attr_dict = self.graph_attr_dict_factory() + graph_attr_dict.parent_keys = [key] + graph_attr_dict.data = build_graph_attr_dict_data(graph_attr_dict, value) + + return graph_attr_dict @key_is_string @logger_debug @@ -148,20 +199,25 @@ def __getitem__(self, key: str) -> Any: result = aql_doc_get_key(self.db, self.graph_id, key) - if not result: + if result is None: raise KeyError(key) - self.data[key] = result + graph_dict_value = self.__process_graph_dict_value(key, result) + self.data[key] = graph_dict_value - return result + return graph_dict_value @key_is_string @key_is_not_reserved @logger_debug - # @value_is_json_serializable # TODO? def __setitem__(self, key: str, value: Any) -> None: """G.graph['foo'] = 'bar'""" - self.data[key] = value + if value is None: + self.__delitem__(key) + return + + graph_dict_value = self.__process_graph_dict_value(key, value) + self.data[key] = graph_dict_value self.data["_rev"] = doc_update(self.db, self.graph_id, {key: value}) @key_is_string @@ -172,25 +228,128 @@ def __delitem__(self, key: str) -> None: self.data.pop(key, None) self.data["_rev"] = doc_update(self.db, self.graph_id, {key: None}) - @keys_are_strings - @keys_are_not_reserved # @values_are_json_serializable # TODO? @logger_debug def update(self, attrs: Any) -> None: """G.graph.update({'foo': 'bar'})""" + if not attrs: return - self.data.update(attrs) + graph_attr_dict = self.graph_attr_dict_factory() + graph_attr_dict_data = build_graph_attr_dict_data(graph_attr_dict, attrs) + graph_attr_dict.data = graph_attr_dict_data + + self.data.update(graph_attr_dict_data) self.data["_rev"] = doc_update(self.db, self.graph_id, attrs) - # @logger_debug - # def clear(self) -> None: - # """G.graph.clear()""" - # self.data.clear() + @logger_debug + def clear(self) -> None: + """G.graph.clear()""" + self.data.clear() + + +@json_serializable +class GraphAttrDict(UserDict[str, Any]): + """The inner-level of the dict of dict structure + representing the attributes of a graph stored in the database. + + Only used if the value associated with a GraphDict key is a dict. + + :param db: The ArangoDB database. + :type db: StandardDatabase + :param graph: The ArangoDB graph. + :type graph: Graph + :param graph_id: The ArangoDB graph ID. + :type graph_id: str + """ + + @logger_debug + def __init__( + self, + db: StandardDatabase, + graph: Graph, + graph_id: str, + *args: Any, + **kwargs: Any, + ): + super().__init__(*args, **kwargs) + self.data: dict[str, Any] = {} + + self.db = db + self.graph = graph + self.graph_id: str = graph_id + + self.root: GraphAttrDict | None = None + self.parent_keys: list[str] = [] + self.graph_attr_dict_factory = graph_attr_dict_factory( + self.db, self.graph, self.graph_id + ) - # # if clear_remote: - # # doc_insert(self.db, self.COLLECTION_NAME, self.graph_id, silent=True) + @key_is_string + @logger_debug + def __contains__(self, key: str) -> bool: + """'bar' in G.graph['foo']""" + if key in self.data: + return True + + return aql_doc_has_key(self.db, self.graph.name, key) + + @key_is_string + @logger_debug + def __getitem__(self, key: str) -> Any: + """G.graph['foo']['bar']""" + + if value := self.data.get(key): + return value + + result = aql_doc_get_key(self.db, self.graph_id, key, self.parent_keys) + + if result is None: + raise KeyError(key) + + graph_attr_dict_value = process_graph_attr_dict_value(self, key, result) + self.data[key] = graph_attr_dict_value + + return graph_attr_dict_value + + @key_is_string + @logger_debug + def __setitem__(self, key, value): + """ + G.graph['foo'] = 'bar' + G.graph['object'] = {'foo': 'bar'} + G._node['object']['foo'] = 'baz' + """ + if value is None: + self.__delitem__(key) + return + + graph_attr_dict_value = process_graph_attr_dict_value(self, key, value) + update_dict = get_update_dict(self.parent_keys, {key: value}) + self.data[key] = graph_attr_dict_value + root_data = self.root.data if self.root else self.data + root_data["_rev"] = doc_update(self.db, self.graph_id, update_dict) + + @key_is_string + @logger_debug + def __delitem__(self, key): + """del G.graph['foo']['bar']""" + self.data.pop(key, None) + update_dict = get_update_dict(self.parent_keys, {key: None}) + root_data = self.root.data if self.root else self.data + root_data["_rev"] = doc_update(self.db, self.graph_id, update_dict) + + @logger_debug + def update(self, attrs: Any) -> None: + """G.graph['foo'].update({'bar': 'baz'})""" + if not attrs: + return + + self.data.update(build_graph_attr_dict_data(self, attrs)) + updated_dict = get_update_dict(self.parent_keys, attrs) + root_data = self.root.data if self.root else self.data + root_data["_rev"] = doc_update(self.db, self.graph_id, updated_dict) ######## @@ -304,6 +463,10 @@ def __setitem__(self, key: str, value: Any) -> None: G._node['node/1']['object'] = {'foo': 'bar'} G._node['node/1']['object']['foo'] = 'baz' """ + if value is None: + self.__delitem__(key) + return + assert self.node_id node_attr_dict_value = process_node_attr_dict_value(self, key, value) update_dict = get_update_dict(self.parent_keys, {key: value}) @@ -656,6 +819,10 @@ def __getitem__(self, key: str) -> Any: @logger_debug def __setitem__(self, key: str, value: Any) -> None: """G._adj['node/1']['node/2']['foo'] = 'bar'""" + if value is None: + self.__delitem__(key) + return + assert self.edge_id edge_attr_dict_value = process_edge_attr_dict_value(self, key, value) update_dict = get_update_dict(self.parent_keys, {key: value}) diff --git a/nx_arangodb/classes/function.py b/nx_arangodb/classes/function.py index 44a2ef98..873ff74a 100644 --- a/nx_arangodb/classes/function.py +++ b/nx_arangodb/classes/function.py @@ -97,6 +97,17 @@ def get_arangodb_graph( ) +def json_serializable(cls): + def to_dict(self): + return { + key: (value.to_dict() if isinstance(value, cls) else value) + for key, value in self.items() + } + + cls.to_dict = to_dict + return cls + + def key_is_string(func: Callable[..., Any]) -> Any: """Decorator to check if the key is a string.""" diff --git a/tests/test.py b/tests/test.py index 0c9da34c..e15b1a47 100644 --- a/tests/test.py +++ b/tests/test.py @@ -3,6 +3,7 @@ import networkx as nx import pandas as pd import pytest +from arango import DocumentDeleteError import nx_arangodb as nxadb from nx_arangodb.classes.dict import EdgeAttrDict, NodeAttrDict @@ -423,6 +424,194 @@ def test_graph_edges_crud(load_graph: Any) -> None: assert db.document(edge_id)["object"]["sub_object"]["foo"] == "baz" +def test_graph_dict_init(load_graph: Any) -> None: + G = nxadb.Graph(graph_name="KarateGraph", default_node_type="person") + assert db.collection("_graphs").has("KarateGraph") + graph_document = db.collection("_graphs").get("KarateGraph") + assert graph_document["_key"] == "KarateGraph" + assert graph_document["edgeDefinitions"] == [ + {"collection": "knows", "from": ["person"], "to": ["person"]}, + {"collection": "person_to_person", "from": ["person"], "to": ["person"]}, + ] + assert graph_document["orphanCollections"] == [] + + graph_doc_id = G.graph.graph_id + assert db.has_document(graph_doc_id) + + +def test_graph_dict_init_extended(load_graph: Any) -> None: + # Tests that available data (especially dicts) will be properly + # stored as GraphDicts in the internal cache. + G = nxadb.Graph(graph_name="KarateGraph", foo="bar", bar={"baz": True}) + G.graph["foo"] = "!!!" + G.graph["bar"]["baz"] = False + assert db.document(G.graph.graph_id)["foo"] == "!!!" + assert db.document(G.graph.graph_id)["bar"]["baz"] is False + assert "baz" not in db.document(G.graph.graph_id) + + +def test_graph_dict_clear_will_not_remove_remote_data(load_graph: Any) -> None: + G_adb = nxadb.Graph( + graph_name="KarateGraph", + foo="bar", + bar={"a": 4}, + ) + + G_adb.graph["ant"] = {"b": 5} + G_adb.graph["ant"]["b"] = 6 + G_adb.clear() + try: + G_adb.graph["ant"] + except KeyError: + raise AssertionError("Not allowed to fail.") + + assert db.document(G_adb.graph.graph_id)["ant"] == {"b": 6} + + +def test_graph_dict_set_item(load_graph: Any) -> None: + G = nxadb.Graph(graph_name="KarateGraph", default_node_type="person") + try: + db.collection(G.graph.COLLECTION_NAME).delete(G.name) + except DocumentDeleteError: + pass + except Exception as e: + print(f"An unexpected error occurred: {e}") + raise + + json_values = [ + "aString", + 1, + 1.0, + True, + False, + {"a": "b"}, + ["a", "b", "c"], + {"a": "b", "c": ["a", "b", "c"]}, + None, + ] + + for value in json_values: + G.graph["json"] = value + + if value is None: + assert "json" not in db.document(G.graph.graph_id) + else: + assert G.graph["json"] == value + assert db.document(G.graph.graph_id)["json"] == value + + +def test_graph_dict_update(load_graph: Any) -> None: + G = nxadb.Graph(graph_name="KarateGraph", default_node_type="person") + G.clear() + + G.graph["a"] = "b" + to_update = {"c": "d"} + G.graph.update(to_update) + + # local + assert G.graph["a"] == "b" + assert G.graph["c"] == "d" + + # remote + adb_doc = db.collection("nxadb_graphs").get(G.graph_name) + assert adb_doc["a"] == "b" + assert adb_doc["c"] == "d" + + +def test_graph_attr_dict_nested_update(load_graph: Any) -> None: + G = nxadb.Graph(graph_name="KarateGraph", default_node_type="person") + G.clear() + + G.graph["a"] = {"b": "c"} + G.graph["a"].update({"d": "e"}) + assert G.graph["a"]["b"] == "c" + assert G.graph["a"]["d"] == "e" + assert db.document(G.graph.graph_id)["a"]["b"] == "c" + assert db.document(G.graph.graph_id)["a"]["d"] == "e" + + +def test_graph_dict_nested_1(load_graph: Any) -> None: + G = nxadb.Graph(graph_name="KarateGraph", default_node_type="person") + G.clear() + icon = {"football_icon": "MJ7"} + + G.graph["a"] = {"b": icon} + assert G.graph["a"]["b"] == icon + assert db.document(G.graph.graph_id)["a"]["b"] == icon + + +def test_graph_dict_nested_2(load_graph: Any) -> None: + G = nxadb.Graph(graph_name="KarateGraph", default_node_type="person") + G.clear() + icon = {"football_icon": "MJ7"} + + G.graph["x"] = {"y": icon} + G.graph["x"]["y"]["amount_of_goals"] = 1337 + + assert G.graph["x"]["y"]["amount_of_goals"] == 1337 + assert db.document(G.graph.graph_id)["x"]["y"]["amount_of_goals"] == 1337 + + +def test_graph_dict_empty_values(load_graph: Any) -> None: + G = nxadb.Graph(graph_name="KarateGraph", default_node_type="person") + G.clear() + + G.graph["empty"] = {} + assert G.graph["empty"] == {} + assert db.document(G.graph.graph_id)["empty"] == {} + + G.graph["none"] = None + assert "none" not in db.document(G.graph.graph_id) + assert "none" not in G.graph + + +def test_graph_dict_nested_overwrite(load_graph: Any) -> None: + G = nxadb.Graph(graph_name="KarateGraph", default_node_type="person") + G.clear() + icon1 = {"football_icon": "MJ7"} + icon2 = {"basketball_icon": "MJ23"} + + G.graph["a"] = {"b": icon1} + G.graph["a"]["b"]["football_icon"] = "ChangedIcon" + assert G.graph["a"]["b"]["football_icon"] == "ChangedIcon" + assert db.document(G.graph.graph_id)["a"]["b"]["football_icon"] == "ChangedIcon" + + # Overwrite entire nested dictionary + G.graph["a"] = {"b": icon2} + assert G.graph["a"]["b"]["basketball_icon"] == "MJ23" + assert db.document(G.graph.graph_id)["a"]["b"]["basketball_icon"] == "MJ23" + + +def test_graph_dict_complex_nested(load_graph: Any) -> None: + G = nxadb.Graph(graph_name="KarateGraph", default_node_type="person") + G.clear() + + complex_structure = {"level1": {"level2": {"level3": {"key": "value"}}}} + + G.graph["complex"] = complex_structure + assert G.graph["complex"]["level1"]["level2"]["level3"]["key"] == "value" + assert ( + db.document(G.graph.graph_id)["complex"]["level1"]["level2"]["level3"]["key"] + == "value" + ) + + +def test_graph_dict_nested_deletion(load_graph: Any) -> None: + G = nxadb.Graph(graph_name="KarateGraph", default_node_type="person") + G.clear() + icon = {"football_icon": "MJ7", "amount_of_goals": 1337} + + G.graph["x"] = {"y": icon} + del G.graph["x"]["y"]["amount_of_goals"] + assert "amount_of_goals" not in G.graph["x"]["y"] + assert "amount_of_goals" not in db.document(G.graph.graph_id)["x"]["y"] + + # Delete top-level key + del G.graph["x"] + assert "x" not in G.graph + assert "x" not in db.document(G.graph.graph_id) + + def test_readme(load_graph: Any) -> None: G = nxadb.Graph(graph_name="KarateGraph", default_node_type="person")