Merged

Changes from all commits (34 commits)
b315606
GA-163 | initial commit
aMahanna Aug 16, 2024
aa99026
Merge branch 'main' into GA-163
aMahanna Aug 16, 2024
26e1a85
unlock adbnx
aMahanna Aug 16, 2024
d6da2a3
fix: `incoming_graph_data`
aMahanna Aug 16, 2024
6561949
fix: incoming_graph_data
aMahanna Aug 17, 2024
469de71
fix: off-by-one IDs
aMahanna Aug 17, 2024
b742027
checkpoint
aMahanna Aug 17, 2024
8b47e4d
checkpoint: `BaseGraphTester` is passing
aMahanna Aug 17, 2024
0483486
checkpoint: BaseGraphAttrTester
aMahanna Aug 18, 2024
1ed111e
cleanup: `aql_fetch_data`, `aql_fetch_data_edge`
aMahanna Aug 18, 2024
f5963a6
use pytest skip for failing tests
aMahanna Aug 18, 2024
eb6717e
checkpoint: optimize `__iter__`
aMahanna Aug 18, 2024
04dc9c1
checkpoint: run `test_graph`
aMahanna Aug 19, 2024
2199ae3
add comment
aMahanna Aug 19, 2024
173f0a7
Merge branch 'main' into GA-163
aMahanna Aug 19, 2024
bc64fe9
checkpoint
aMahanna Aug 19, 2024
0df6c2b
attempt: slleep
aMahanna Aug 19, 2024
aa4b336
fix: lint
aMahanna Aug 19, 2024
5aa3eb2
cleanup: getitem
aMahanna Aug 20, 2024
b03f4cf
cleanup: copy
aMahanna Aug 20, 2024
27adfa3
attempt: shorten sleep
aMahanna Aug 20, 2024
c34898a
Merge branch 'main' into GA-163
aMahanna Aug 20, 2024
0d18563
fix: `__set_adj_elements`
aMahanna Aug 20, 2024
b0434a9
fix: mypy
aMahanna Aug 20, 2024
3f07ae1
attempt: decrease sleep
aMahanna Aug 20, 2024
8b87046
Merge branch 'main' into GA-163
aMahanna Aug 20, 2024
28dd130
Merge branch 'main' into GA-163
aMahanna Aug 21, 2024
443d436
Merge branch 'main' into GA-163
aMahanna Aug 21, 2024
dcb94ff
fix: graph name
aMahanna Aug 21, 2024
e7339de
fix: `_rev`, `use_experimental_views`
aMahanna Aug 21, 2024
6e5b504
new: `nbunch_iter` override
aMahanna Aug 21, 2024
e64781e
set experimental views to false
aMahanna Aug 22, 2024
bc88320
Merge branch 'main' into GA-163
aMahanna Aug 23, 2024
f4f7579
fix: lint
aMahanna Aug 23, 2024
2 changes: 1 addition & 1 deletion .circleci/config.yml
@@ -73,7 +73,7 @@ jobs:

- run:
name: Run local tests
command: pytest tests/test.py
command: pytest tests/*.py

- run:
name: Run NetworkX tests
66 changes: 40 additions & 26 deletions nx_arangodb/classes/dict/adj.py
@@ -14,6 +14,7 @@
GraphAdjDict,
MultiDiGraphAdjDict,
MultiGraphAdjDict,
NodeDict,
)

from nx_arangodb.exceptions import EdgeTypeAmbiguity, MultipleEdgesFound
@@ -210,13 +211,13 @@ def __contains__(self, key: str) -> bool:
@logger_debug
def __getitem__(self, key: str) -> Any:
"""G._adj['node/1']['node/2']['foo']"""
if value := self.data.get(key):
return value
if key in self.data:
return self.data[key]

assert self.edge_id
result = aql_doc_get_key(self.db, self.edge_id, key, self.parent_keys)

if not result:
if result is None:
raise KeyError(key)

edge_attr_dict_value = process_edge_attr_dict_value(self, key, result)
@@ -637,6 +638,14 @@ def __len__(self) -> int:
@logger_debug
def __iter__(self) -> Iterator[str]:
"""for k in g._adj['node/1']['node/2']"""
if not (self.FETCHED_ALL_DATA or self.FETCHED_ALL_IDS):
self._fetch_all()

yield from self.data.keys()
Comment on lines +641 to +644

Member:
This is good. Just as a global comment: at some point we should think, in general, about all tasks that could take some (or a long) time to finish (e.g. loading a lot of edges, or loading huge edges). Some indicator for that, or some visibility into it, would be great at some point.

Member (Author):
Ya, good point. Something like a temporary spinner or a progress bar would be useful to have.

Member (Author):
I've worked with these progress libraries before; could be something we can introduce:


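For illustration, a minimal sketch of the kind of progress indicator floated above, assuming tqdm as the progress library; the fetch_edges_with_progress helper, its edge_ids argument, and the fetch_edge callable are hypothetical, not part of this PR:

from tqdm import tqdm

def fetch_edges_with_progress(edge_ids, fetch_edge):
    """Hypothetical sketch: wrap a bulk edge fetch in a tqdm progress bar."""
    edges = {}
    # tqdm renders a progress bar over any sized iterable, inferring the total from len().
    for edge_id in tqdm(edge_ids, desc="Fetching edges"):
        edges[edge_id] = fetch_edge(edge_id)  # assumed callable performing the actual DB read
    return edges

When the total is unknown, tqdm falls back to a plain running counter, which would cover the temporary-spinner case mentioned above.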
@logger_debug
def keys(self) -> Any:
"""g._adj['node/1']['node/2'].keys()"""
if self.FETCHED_ALL_IDS:
yield from self.data.keys()

@@ -661,11 +670,6 @@ def __iter__(self) -> Iterator[str]:
self.data[edge_id] = self.edge_attr_dict_factory()
yield edge_id

@logger_debug
def keys(self) -> Any:
"""g._adj['node/1']['node/2'].keys()"""
return self.__iter__()

@logger_debug
def values(self) -> Any:
"""g._adj['node/1']['node/2'].values()"""
@@ -1165,6 +1169,14 @@ def __len__(self) -> int:
@logger_debug
def __iter__(self) -> Iterator[str]:
"""for k in g._adj['node/1']"""
if not (self.FETCHED_ALL_DATA or self.FETCHED_ALL_IDS):
self._fetch_all()

yield from self.data.keys()

@logger_debug
def keys(self) -> Any:
"""g._adj['node/1'].keys()"""
if self.FETCHED_ALL_IDS:
yield from self.data.keys()

@@ -1182,11 +1194,6 @@ def __iter__(self) -> Iterator[str]:
self.__contains_helper(edge_id)
yield edge_id

@logger_debug
def keys(self) -> Any:
"""g._adj['node/1'].keys()"""
return self.__iter__()

@logger_debug
def clear(self) -> None:
"""G._adj['node/1'].clear()"""
@@ -1528,6 +1535,14 @@ def __len__(self) -> int:
@logger_debug
def __iter__(self) -> Iterator[str]:
"""for k in g._adj"""
if not (self.FETCHED_ALL_DATA or self.FETCHED_ALL_IDS):
self._fetch_all()

yield from self.data.keys()

@logger_debug
def keys(self) -> Any:
"""g._adj.keys()"""
if self.FETCHED_ALL_IDS:
yield from self.data.keys()

@@ -1540,11 +1555,6 @@ def __iter__(self) -> Iterator[str]:
self.data[node_id] = lazy_adjlist_inner_dict
yield node_id

@logger_debug
def keys(self) -> Any:
"""g._adj.keys()"""
return self.__iter__()

@logger_debug
def clear(self) -> None:
"""g._node.clear()"""
@@ -1599,15 +1609,15 @@ def items(self, data: str | None = None, default: Any | None = None) -> Any:

else:
e_cols = [ed["edge_collection"] for ed in self.graph.edge_definitions()]
result = aql_fetch_data_edge(self.db, e_cols, data, default)
yield from result
yield from aql_fetch_data_edge(self.db, e_cols, data, default)

@logger_debug
def __set_adj_elements(
self,
edges_dict: (
adj_dict: (
GraphAdjDict | DiGraphAdjDict | MultiGraphAdjDict | MultiDiGraphAdjDict
),
node_dict: NodeDict | None = None,
) -> None:
def set_edge_graph(
src_node_id: str, dst_node_id: str, edge: dict[str, Any]
@@ -1691,7 +1701,11 @@ def propagate_edge_directed_symmetric(
)
)

for src_node_id, inner_dict in edges_dict.items():
if node_dict is not None:
for node_id in node_dict.keys():
self.__set_adj_inner_dict(self, node_id)

for src_node_id, inner_dict in adj_dict.items():
for dst_node_id, edge_or_edges in inner_dict.items():

self.__set_adj_inner_dict(self, src_node_id)
@@ -1721,16 +1735,16 @@ def _fetch_all(self) -> None:
self.clear()

(
_,
node_dict,
adj_dict,
*_,
) = get_arangodb_graph(
self.graph,
load_node_dict=False,
load_node_dict=True,
load_adj_dict=True,
load_coo=False,
edge_collections_attributes=set(), # not used
load_all_vertex_attributes=False, # not used
load_all_vertex_attributes=False,
load_all_edge_attributes=True,
is_directed=self.is_directed,
is_multigraph=self.is_multigraph,
@@ -1740,7 +1754,7 @@
if self.is_directed:
adj_dict = adj_dict["succ"]

self.__set_adj_elements(adj_dict)
self.__set_adj_elements(adj_dict, node_dict)

self.FETCHED_ALL_DATA = True
self.FETCHED_ALL_IDS = True
24 changes: 13 additions & 11 deletions nx_arangodb/classes/dict/node.py
@@ -134,13 +134,13 @@ def __contains__(self, key: str) -> bool:
@logger_debug
def __getitem__(self, key: str) -> Any:
"""G._node['node/1']['foo']"""
if value := self.data.get(key):
return value
if key in self.data:
return self.data[key]

assert self.node_id
result = aql_doc_get_key(self.db, self.node_id, key, self.parent_keys)

if not result:
if result is None:
raise KeyError(key)

node_attr_dict_value = process_node_attr_dict_value(self, key, result)
@@ -348,7 +348,15 @@ def __len__(self) -> int:

@logger_debug
def __iter__(self) -> Iterator[str]:
"""iter(g._node)"""
"""for k in g._node"""
if not (self.FETCHED_ALL_IDS or self.FETCHED_ALL_DATA):
self._fetch_all()

yield from self.data.keys()

@logger_debug
def keys(self) -> Any:
"""g._node.keys()"""
if self.FETCHED_ALL_IDS:
yield from self.data.keys()
else:
@@ -360,11 +368,6 @@ def __iter__(self) -> Iterator[str]:
self.data[node_id] = empty_node_attr_dict
yield node_id

@logger_debug
def keys(self) -> Any:
"""g._node.keys()"""
return self.__iter__()

@logger_debug
def clear(self) -> None:
"""g._node.clear()"""
@@ -429,8 +432,7 @@ def items(self, data: str | None = None, default: Any | None = None) -> Any:
yield from self.data.items()
else:
v_cols = list(self.graph.vertex_collections())
result = aql_fetch_data(self.db, v_cols, data, default)
yield from result.items()
yield from aql_fetch_data(self.db, v_cols, data, default)

@logger_debug
def _fetch_all(self):
2 changes: 2 additions & 0 deletions nx_arangodb/classes/digraph.py
@@ -24,6 +24,7 @@ def to_networkx_class(cls) -> type[nx.DiGraph]:

def __init__(
self,
incoming_graph_data: Any = None,
name: str | None = None,
default_node_type: str | None = None,
edge_type_key: str = "_edge_type",
@@ -39,6 +40,7 @@ def __init__(
**kwargs: Any,
):
super().__init__(
incoming_graph_data,
name,
default_node_type,
edge_type_key,
59 changes: 25 additions & 34 deletions nx_arangodb/classes/function.py
@@ -5,7 +5,7 @@

from __future__ import annotations

from typing import Any, Callable, Optional, Tuple
from typing import Any, Callable, Generator, Optional, Tuple

import networkx as nx
from arango import ArangoError, DocumentInsertError
@@ -160,6 +160,9 @@ def key_is_string(func: Callable[..., Any]) -> Any:

def wrapper(self: Any, key: Any, *args: Any, **kwargs: Any) -> Any:
""""""
if key is None:
raise ValueError("Key cannot be None.")

if not isinstance(key, str):
if not isinstance(key, (int, float)):
raise TypeError(f"{key} cannot be casted to string.")
@@ -332,7 +335,7 @@ def aql_doc_has_key(

def aql_doc_get_key(
db: StandardDatabase, id: str, key: str, nested_keys: list[str] = []
) -> Any:
) -> Any | None:
"""Gets a key from a document."""
nested_keys_str = "." + ".".join(nested_keys) if nested_keys else ""
query = f"RETURN DOCUMENT(@id){nested_keys_str}.@key"
@@ -541,47 +544,34 @@ def aql_fetch_data(
collections: list[str],
data: str,
default: Any,
) -> dict[str, Any]:
items = {}
for collection in collections:
query = """
LET result = (
FOR doc IN @@collection
RETURN {[doc._id]: doc.@data or @default}
)

RETURN MERGE(result)
"""

bind_vars = {"data": data, "default": default, "@collection": collection}
result = aql_single(db, query, bind_vars)
items.update(result if result is not None else {})
) -> Generator[dict[str, Any], None, None]:
bind_vars = {"data": data, "default": default}
query = """
FOR doc IN @@collection
RETURN [doc._id, doc.@data or @default]
"""

return items
for collection in collections:
bind_vars["@collection"] = collection
yield from aql(db, query, bind_vars)


def aql_fetch_data_edge(
db: StandardDatabase,
collections: list[str],
data: str,
default: Any,
) -> list[tuple[str, str, Any]]:
items = []
for collection in collections:
query = """
LET result = (
FOR doc IN @@collection
RETURN [doc._from, doc._to, doc.@data or @default]
)

RETURN result
"""

bind_vars = {"data": data, "default": default, "@collection": collection}
result = aql_single(db, query, bind_vars)
items.extend(result if result is not None else [])
) -> Generator[tuple[str, str, Any], None, None]:
bind_vars = {"data": data, "default": default}
query = """
FOR doc IN @@collection
RETURN [doc._from, doc._to, doc.@data or @default]
"""

return items
for collection in collections:
bind_vars["@collection"] = collection
for item in aql(db, query, bind_vars):
yield tuple(item)

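For context on the two helpers above: both now stream rows lazily from AQL instead of materializing a merged dict or list first, so callers can iterate without holding the full result set in memory. A minimal usage sketch, with the db handle and the "person" collection name assumed purely for illustration:

# Hypothetical caller: consume (id, value) pairs as they stream in.
for node_id, name in aql_fetch_data(db, ["person"], "name", default=None):
    print(node_id, name)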

def doc_update(
@@ -619,6 +609,7 @@ def doc_get_or_insert(
"""Loads a document if existing, otherwise inserts it & returns it."""
if db.has_document(id):
result: dict[str, Any] = db.document(id)
del result["_rev"]
return result

return doc_insert(db, collection, id, **kwargs)