Skip to content

[GA-157] Recursive GraphDict #17

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
Aug 2, 2024
203 changes: 185 additions & 18 deletions nx_arangodb/classes/dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from __future__ import annotations

import json
from collections import UserDict, defaultdict
from collections.abc import Iterator
from typing import Any, Callable, Generator
Expand Down Expand Up @@ -36,6 +37,7 @@
get_arangodb_graph,
get_node_id,
get_node_type_and_id,
json_serializable,
key_is_not_reserved,
key_is_string,
keys_are_not_reserved,
Expand All @@ -54,6 +56,12 @@ def graph_dict_factory(
return lambda: GraphDict(db, graph_name)


def graph_attr_dict_factory(
db: StandardDatabase, graph: Graph, graph_id: str
) -> Callable[..., GraphAttrDict]:
return lambda: GraphAttrDict(db, graph, graph_id)


def node_dict_factory(
db: StandardDatabase, graph: Graph, default_node_type: str
) -> Callable[..., NodeDict]:
Expand Down Expand Up @@ -98,6 +106,36 @@ def edge_attr_dict_factory(
#########


def build_graph_attr_dict_data(
parent: GraphAttrDict, data: dict[str, Any]
) -> dict[str, Any | GraphAttrDict]:
"""Recursively build a GraphAttrDict from a dict.

It's possible that **value** is a nested dict, so we need to
recursively build a GraphAttrDict for each nested dict.

Returns the parent GraphAttrDict.
"""
graph_attr_dict_data = {}
for key, value in data.items():
graph_attr_dict_value = process_graph_attr_dict_value(parent, key, value)
graph_attr_dict_data[key] = graph_attr_dict_value

return graph_attr_dict_data


def process_graph_attr_dict_value(parent: GraphAttrDict, key: str, value: Any) -> Any:
if not isinstance(value, dict):
return value

graph_attr_dict = parent.graph_attr_dict_factory()
graph_attr_dict.root = parent.root or parent
graph_attr_dict.parent_keys = parent.parent_keys + [key]
graph_attr_dict.data = build_graph_attr_dict_data(graph_attr_dict, value)

return graph_attr_dict


class GraphDict(UserDict[str, Any]):
"""A dictionary-like object for storing graph attributes.

Expand All @@ -110,8 +148,6 @@ class GraphDict(UserDict[str, Any]):
:type graph_name: str
"""

COLLECTION_NAME = "nxadb_graphs"

@logger_debug
def __init__(
self, db: StandardDatabase, graph_name: str, *args: Any, **kwargs: Any
Expand All @@ -121,13 +157,28 @@ def __init__(

self.db = db
self.graph_name = graph_name
self.COLLECTION_NAME = "nxadb_graphs"
self.graph_id = f"{self.COLLECTION_NAME}/{graph_name}"

self.adb_graph = db.graph(graph_name)
self.collection = create_collection(db, self.COLLECTION_NAME)
self.graph_attr_dict_factory = graph_attr_dict_factory(
self.db, self.adb_graph, self.graph_id
)

result = doc_get_or_insert(self.db, self.COLLECTION_NAME, self.graph_id)
for k, v in result.items():
self.data[k] = self.__process_graph_dict_value(k, v)

def __process_graph_dict_value(self, key: str, value: Any) -> Any:
if not isinstance(value, dict):
return value

data = doc_get_or_insert(self.db, self.COLLECTION_NAME, self.graph_id)
self.data.update(data)
graph_attr_dict = self.graph_attr_dict_factory()
graph_attr_dict.parent_keys = [key]
graph_attr_dict.data = build_graph_attr_dict_data(graph_attr_dict, value)

return graph_attr_dict

@key_is_string
@logger_debug
Expand All @@ -148,20 +199,25 @@ def __getitem__(self, key: str) -> Any:

result = aql_doc_get_key(self.db, self.graph_id, key)

if not result:
if result is None:
raise KeyError(key)

self.data[key] = result
graph_dict_value = self.__process_graph_dict_value(key, result)
self.data[key] = graph_dict_value

return result
return graph_dict_value

@key_is_string
@key_is_not_reserved
@logger_debug
# @value_is_json_serializable # TODO?
def __setitem__(self, key: str, value: Any) -> None:
"""G.graph['foo'] = 'bar'"""
self.data[key] = value
if value is None:
self.__delitem__(key)
return

graph_dict_value = self.__process_graph_dict_value(key, value)
self.data[key] = graph_dict_value
self.data["_rev"] = doc_update(self.db, self.graph_id, {key: value})

@key_is_string
Expand All @@ -172,25 +228,128 @@ def __delitem__(self, key: str) -> None:
self.data.pop(key, None)
self.data["_rev"] = doc_update(self.db, self.graph_id, {key: None})

@keys_are_strings
@keys_are_not_reserved
# @values_are_json_serializable # TODO?
@logger_debug
def update(self, attrs: Any) -> None:
"""G.graph.update({'foo': 'bar'})"""

if not attrs:
return

self.data.update(attrs)
graph_attr_dict = self.graph_attr_dict_factory()
graph_attr_dict_data = build_graph_attr_dict_data(graph_attr_dict, attrs)
graph_attr_dict.data = graph_attr_dict_data

self.data.update(graph_attr_dict_data)
self.data["_rev"] = doc_update(self.db, self.graph_id, attrs)

# @logger_debug
# def clear(self) -> None:
# """G.graph.clear()"""
# self.data.clear()
@logger_debug
def clear(self) -> None:
"""G.graph.clear()"""
self.data.clear()


@json_serializable
class GraphAttrDict(UserDict[str, Any]):
"""The inner-level of the dict of dict structure
representing the attributes of a graph stored in the database.

Only used if the value associated with a GraphDict key is a dict.

:param db: The ArangoDB database.
:type db: StandardDatabase
:param graph: The ArangoDB graph.
:type graph: Graph
:param graph_id: The ArangoDB graph ID.
:type graph_id: str
"""

@logger_debug
def __init__(
self,
db: StandardDatabase,
graph: Graph,
graph_id: str,
*args: Any,
**kwargs: Any,
):
super().__init__(*args, **kwargs)
self.data: dict[str, Any] = {}

self.db = db
self.graph = graph
self.graph_id: str = graph_id

self.root: GraphAttrDict | None = None
self.parent_keys: list[str] = []
self.graph_attr_dict_factory = graph_attr_dict_factory(
self.db, self.graph, self.graph_id
)

# # if clear_remote:
# # doc_insert(self.db, self.COLLECTION_NAME, self.graph_id, silent=True)
@key_is_string
@logger_debug
def __contains__(self, key: str) -> bool:
"""'bar' in G.graph['foo']"""
if key in self.data:
return True

return aql_doc_has_key(self.db, self.graph.name, key)

@key_is_string
@logger_debug
def __getitem__(self, key: str) -> Any:
"""G.graph['foo']['bar']"""

if value := self.data.get(key):
return value

result = aql_doc_get_key(self.db, self.graph_id, key, self.parent_keys)

if result is None:
raise KeyError(key)

graph_attr_dict_value = process_graph_attr_dict_value(self, key, result)
self.data[key] = graph_attr_dict_value

return graph_attr_dict_value

@key_is_string
@logger_debug
def __setitem__(self, key, value):
"""
G.graph['foo'] = 'bar'
G.graph['object'] = {'foo': 'bar'}
G._node['object']['foo'] = 'baz'
"""
if value is None:
self.__delitem__(key)
return

graph_attr_dict_value = process_graph_attr_dict_value(self, key, value)
update_dict = get_update_dict(self.parent_keys, {key: value})
self.data[key] = graph_attr_dict_value
root_data = self.root.data if self.root else self.data
root_data["_rev"] = doc_update(self.db, self.graph_id, update_dict)

@key_is_string
@logger_debug
def __delitem__(self, key):
"""del G.graph['foo']['bar']"""
self.data.pop(key, None)
update_dict = get_update_dict(self.parent_keys, {key: None})
root_data = self.root.data if self.root else self.data
root_data["_rev"] = doc_update(self.db, self.graph_id, update_dict)

@logger_debug
def update(self, attrs: Any) -> None:
"""G.graph['foo'].update({'bar': 'baz'})"""
if not attrs:
return

self.data.update(build_graph_attr_dict_data(self, attrs))
updated_dict = get_update_dict(self.parent_keys, attrs)
root_data = self.root.data if self.root else self.data
root_data["_rev"] = doc_update(self.db, self.graph_id, updated_dict)


########
Expand Down Expand Up @@ -304,6 +463,10 @@ def __setitem__(self, key: str, value: Any) -> None:
G._node['node/1']['object'] = {'foo': 'bar'}
G._node['node/1']['object']['foo'] = 'baz'
"""
if value is None:
self.__delitem__(key)
return

assert self.node_id
node_attr_dict_value = process_node_attr_dict_value(self, key, value)
update_dict = get_update_dict(self.parent_keys, {key: value})
Expand Down Expand Up @@ -656,6 +819,10 @@ def __getitem__(self, key: str) -> Any:
@logger_debug
def __setitem__(self, key: str, value: Any) -> None:
"""G._adj['node/1']['node/2']['foo'] = 'bar'"""
if value is None:
self.__delitem__(key)
return

assert self.edge_id
edge_attr_dict_value = process_edge_attr_dict_value(self, key, value)
update_dict = get_update_dict(self.parent_keys, {key: value})
Expand Down
11 changes: 11 additions & 0 deletions nx_arangodb/classes/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,17 @@ def get_arangodb_graph(
)


def json_serializable(cls):
def to_dict(self):
return {
key: (value.to_dict() if isinstance(value, cls) else value)
for key, value in self.items()
}

cls.to_dict = to_dict
return cls


def key_is_string(func: Callable[..., Any]) -> Any:
"""Decorator to check if the key is a string."""

Expand Down
Loading