Skip to content

Commit 93d1d24

Browse files
authored
GA-163 | test_graph (#33)
* GA-163 | initial commit will fail * unlock adbnx * fix: `incoming_graph_data` * fix: incoming_graph_data * fix: off-by-one IDs * checkpoint * checkpoint: `BaseGraphTester` is passing * checkpoint: BaseGraphAttrTester * cleanup: `aql_fetch_data`, `aql_fetch_data_edge` * use pytest skip for failing tests * checkpoint: optimize `__iter__` * checkpoint: run `test_graph` * add comment * checkpoint * attempt: slleep * fix: lint * cleanup: getitem * cleanup: copy * attempt: shorten sleep * fix: `__set_adj_elements` * fix: mypy * attempt: decrease sleep * fix: graph name * fix: `_rev`, `use_experimental_views` * new: `nbunch_iter` override * set experimental views to false * fix: lint
1 parent 61d75df commit 93d1d24

File tree

13 files changed

+1553
-297
lines changed

13 files changed

+1553
-297
lines changed

.circleci/config.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ jobs:
7373

7474
- run:
7575
name: Run local tests
76-
command: pytest tests/test.py
76+
command: pytest tests/*.py
7777

7878
- run:
7979
name: Run NetworkX tests

nx_arangodb/classes/dict/adj.py

Lines changed: 40 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
GraphAdjDict,
1515
MultiDiGraphAdjDict,
1616
MultiGraphAdjDict,
17+
NodeDict,
1718
)
1819

1920
from nx_arangodb.exceptions import EdgeTypeAmbiguity, MultipleEdgesFound
@@ -210,13 +211,13 @@ def __contains__(self, key: str) -> bool:
210211
@logger_debug
211212
def __getitem__(self, key: str) -> Any:
212213
"""G._adj['node/1']['node/2']['foo']"""
213-
if value := self.data.get(key):
214-
return value
214+
if key in self.data:
215+
return self.data[key]
215216

216217
assert self.edge_id
217218
result = aql_doc_get_key(self.db, self.edge_id, key, self.parent_keys)
218219

219-
if not result:
220+
if result is None:
220221
raise KeyError(key)
221222

222223
edge_attr_dict_value = process_edge_attr_dict_value(self, key, result)
@@ -637,6 +638,14 @@ def __len__(self) -> int:
637638
@logger_debug
638639
def __iter__(self) -> Iterator[str]:
639640
"""for k in g._adj['node/1']['node/2']"""
641+
if not (self.FETCHED_ALL_DATA or self.FETCHED_ALL_IDS):
642+
self._fetch_all()
643+
644+
yield from self.data.keys()
645+
646+
@logger_debug
647+
def keys(self) -> Any:
648+
"""g._adj['node/1']['node/2'].keys()"""
640649
if self.FETCHED_ALL_IDS:
641650
yield from self.data.keys()
642651

@@ -661,11 +670,6 @@ def __iter__(self) -> Iterator[str]:
661670
self.data[edge_id] = self.edge_attr_dict_factory()
662671
yield edge_id
663672

664-
@logger_debug
665-
def keys(self) -> Any:
666-
"""g._adj['node/1']['node/2'].keys()"""
667-
return self.__iter__()
668-
669673
@logger_debug
670674
def values(self) -> Any:
671675
"""g._adj['node/1']['node/2'].values()"""
@@ -1165,6 +1169,14 @@ def __len__(self) -> int:
11651169
@logger_debug
11661170
def __iter__(self) -> Iterator[str]:
11671171
"""for k in g._adj['node/1']"""
1172+
if not (self.FETCHED_ALL_DATA or self.FETCHED_ALL_IDS):
1173+
self._fetch_all()
1174+
1175+
yield from self.data.keys()
1176+
1177+
@logger_debug
1178+
def keys(self) -> Any:
1179+
"""g._adj['node/1'].keys()"""
11681180
if self.FETCHED_ALL_IDS:
11691181
yield from self.data.keys()
11701182

@@ -1182,11 +1194,6 @@ def __iter__(self) -> Iterator[str]:
11821194
self.__contains_helper(edge_id)
11831195
yield edge_id
11841196

1185-
@logger_debug
1186-
def keys(self) -> Any:
1187-
"""g._adj['node/1'].keys()"""
1188-
return self.__iter__()
1189-
11901197
@logger_debug
11911198
def clear(self) -> None:
11921199
"""G._adj['node/1'].clear()"""
@@ -1528,6 +1535,14 @@ def __len__(self) -> int:
15281535
@logger_debug
15291536
def __iter__(self) -> Iterator[str]:
15301537
"""for k in g._adj"""
1538+
if not (self.FETCHED_ALL_DATA or self.FETCHED_ALL_IDS):
1539+
self._fetch_all()
1540+
1541+
yield from self.data.keys()
1542+
1543+
@logger_debug
1544+
def keys(self) -> Any:
1545+
"""g._adj.keys()"""
15311546
if self.FETCHED_ALL_IDS:
15321547
yield from self.data.keys()
15331548

@@ -1540,11 +1555,6 @@ def __iter__(self) -> Iterator[str]:
15401555
self.data[node_id] = lazy_adjlist_inner_dict
15411556
yield node_id
15421557

1543-
@logger_debug
1544-
def keys(self) -> Any:
1545-
"""g._adj.keys()"""
1546-
return self.__iter__()
1547-
15481558
@logger_debug
15491559
def clear(self) -> None:
15501560
"""g._node.clear()"""
@@ -1599,15 +1609,15 @@ def items(self, data: str | None = None, default: Any | None = None) -> Any:
15991609

16001610
else:
16011611
e_cols = [ed["edge_collection"] for ed in self.graph.edge_definitions()]
1602-
result = aql_fetch_data_edge(self.db, e_cols, data, default)
1603-
yield from result
1612+
yield from aql_fetch_data_edge(self.db, e_cols, data, default)
16041613

16051614
@logger_debug
16061615
def __set_adj_elements(
16071616
self,
1608-
edges_dict: (
1617+
adj_dict: (
16091618
GraphAdjDict | DiGraphAdjDict | MultiGraphAdjDict | MultiDiGraphAdjDict
16101619
),
1620+
node_dict: NodeDict | None = None,
16111621
) -> None:
16121622
def set_edge_graph(
16131623
src_node_id: str, dst_node_id: str, edge: dict[str, Any]
@@ -1691,7 +1701,11 @@ def propagate_edge_directed_symmetric(
16911701
)
16921702
)
16931703

1694-
for src_node_id, inner_dict in edges_dict.items():
1704+
if node_dict is not None:
1705+
for node_id in node_dict.keys():
1706+
self.__set_adj_inner_dict(self, node_id)
1707+
1708+
for src_node_id, inner_dict in adj_dict.items():
16951709
for dst_node_id, edge_or_edges in inner_dict.items():
16961710

16971711
self.__set_adj_inner_dict(self, src_node_id)
@@ -1721,16 +1735,16 @@ def _fetch_all(self) -> None:
17211735
self.clear()
17221736

17231737
(
1724-
_,
1738+
node_dict,
17251739
adj_dict,
17261740
*_,
17271741
) = get_arangodb_graph(
17281742
self.graph,
1729-
load_node_dict=False,
1743+
load_node_dict=True,
17301744
load_adj_dict=True,
17311745
load_coo=False,
17321746
edge_collections_attributes=set(), # not used
1733-
load_all_vertex_attributes=False, # not used
1747+
load_all_vertex_attributes=False,
17341748
load_all_edge_attributes=True,
17351749
is_directed=self.is_directed,
17361750
is_multigraph=self.is_multigraph,
@@ -1740,7 +1754,7 @@ def _fetch_all(self) -> None:
17401754
if self.is_directed:
17411755
adj_dict = adj_dict["succ"]
17421756

1743-
self.__set_adj_elements(adj_dict)
1757+
self.__set_adj_elements(adj_dict, node_dict)
17441758

17451759
self.FETCHED_ALL_DATA = True
17461760
self.FETCHED_ALL_IDS = True

nx_arangodb/classes/dict/node.py

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -134,13 +134,13 @@ def __contains__(self, key: str) -> bool:
134134
@logger_debug
135135
def __getitem__(self, key: str) -> Any:
136136
"""G._node['node/1']['foo']"""
137-
if value := self.data.get(key):
138-
return value
137+
if key in self.data:
138+
return self.data[key]
139139

140140
assert self.node_id
141141
result = aql_doc_get_key(self.db, self.node_id, key, self.parent_keys)
142142

143-
if not result:
143+
if result is None:
144144
raise KeyError(key)
145145

146146
node_attr_dict_value = process_node_attr_dict_value(self, key, result)
@@ -348,7 +348,15 @@ def __len__(self) -> int:
348348

349349
@logger_debug
350350
def __iter__(self) -> Iterator[str]:
351-
"""iter(g._node)"""
351+
"""for k in g._node"""
352+
if not (self.FETCHED_ALL_IDS or self.FETCHED_ALL_DATA):
353+
self._fetch_all()
354+
355+
yield from self.data.keys()
356+
357+
@logger_debug
358+
def keys(self) -> Any:
359+
"""g._node.keys()"""
352360
if self.FETCHED_ALL_IDS:
353361
yield from self.data.keys()
354362
else:
@@ -360,11 +368,6 @@ def __iter__(self) -> Iterator[str]:
360368
self.data[node_id] = empty_node_attr_dict
361369
yield node_id
362370

363-
@logger_debug
364-
def keys(self) -> Any:
365-
"""g._node.keys()"""
366-
return self.__iter__()
367-
368371
@logger_debug
369372
def clear(self) -> None:
370373
"""g._node.clear()"""
@@ -429,8 +432,7 @@ def items(self, data: str | None = None, default: Any | None = None) -> Any:
429432
yield from self.data.items()
430433
else:
431434
v_cols = list(self.graph.vertex_collections())
432-
result = aql_fetch_data(self.db, v_cols, data, default)
433-
yield from result.items()
435+
yield from aql_fetch_data(self.db, v_cols, data, default)
434436

435437
@logger_debug
436438
def _fetch_all(self):

nx_arangodb/classes/digraph.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ def to_networkx_class(cls) -> type[nx.DiGraph]:
2424

2525
def __init__(
2626
self,
27+
incoming_graph_data: Any = None,
2728
name: str | None = None,
2829
default_node_type: str | None = None,
2930
edge_type_key: str = "_edge_type",
@@ -39,6 +40,7 @@ def __init__(
3940
**kwargs: Any,
4041
):
4142
super().__init__(
43+
incoming_graph_data,
4244
name,
4345
default_node_type,
4446
edge_type_key,

nx_arangodb/classes/function.py

Lines changed: 25 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
from __future__ import annotations
77

8-
from typing import Any, Callable, Optional, Tuple
8+
from typing import Any, Callable, Generator, Optional, Tuple
99

1010
import networkx as nx
1111
from arango import ArangoError, DocumentInsertError
@@ -160,6 +160,9 @@ def key_is_string(func: Callable[..., Any]) -> Any:
160160

161161
def wrapper(self: Any, key: Any, *args: Any, **kwargs: Any) -> Any:
162162
""""""
163+
if key is None:
164+
raise ValueError("Key cannot be None.")
165+
163166
if not isinstance(key, str):
164167
if not isinstance(key, (int, float)):
165168
raise TypeError(f"{key} cannot be casted to string.")
@@ -332,7 +335,7 @@ def aql_doc_has_key(
332335

333336
def aql_doc_get_key(
334337
db: StandardDatabase, id: str, key: str, nested_keys: list[str] = []
335-
) -> Any:
338+
) -> Any | None:
336339
"""Gets a key from a document."""
337340
nested_keys_str = "." + ".".join(nested_keys) if nested_keys else ""
338341
query = f"RETURN DOCUMENT(@id){nested_keys_str}.@key"
@@ -541,47 +544,34 @@ def aql_fetch_data(
541544
collections: list[str],
542545
data: str,
543546
default: Any,
544-
) -> dict[str, Any]:
545-
items = {}
546-
for collection in collections:
547-
query = """
548-
LET result = (
549-
FOR doc IN @@collection
550-
RETURN {[doc._id]: doc.@data or @default}
551-
)
552-
553-
RETURN MERGE(result)
554-
"""
555-
556-
bind_vars = {"data": data, "default": default, "@collection": collection}
557-
result = aql_single(db, query, bind_vars)
558-
items.update(result if result is not None else {})
547+
) -> Generator[dict[str, Any], None, None]:
548+
bind_vars = {"data": data, "default": default}
549+
query = """
550+
FOR doc IN @@collection
551+
RETURN [doc._id, doc.@data or @default]
552+
"""
559553

560-
return items
554+
for collection in collections:
555+
bind_vars["@collection"] = collection
556+
yield from aql(db, query, bind_vars)
561557

562558

563559
def aql_fetch_data_edge(
564560
db: StandardDatabase,
565561
collections: list[str],
566562
data: str,
567563
default: Any,
568-
) -> list[tuple[str, str, Any]]:
569-
items = []
570-
for collection in collections:
571-
query = """
572-
LET result = (
573-
FOR doc IN @@collection
574-
RETURN [doc._from, doc._to, doc.@data or @default]
575-
)
576-
577-
RETURN result
578-
"""
579-
580-
bind_vars = {"data": data, "default": default, "@collection": collection}
581-
result = aql_single(db, query, bind_vars)
582-
items.extend(result if result is not None else [])
564+
) -> Generator[tuple[str, str, Any], None, None]:
565+
bind_vars = {"data": data, "default": default}
566+
query = """
567+
FOR doc IN @@collection
568+
RETURN [doc._from, doc._to, doc.@data or @default]
569+
"""
583570

584-
return items
571+
for collection in collections:
572+
bind_vars["@collection"] = collection
573+
for item in aql(db, query, bind_vars):
574+
yield tuple(item)
585575

586576

587577
def doc_update(
@@ -619,6 +609,7 @@ def doc_get_or_insert(
619609
"""Loads a document if existing, otherwise inserts it & returns it."""
620610
if db.has_document(id):
621611
result: dict[str, Any] = db.document(id)
612+
del result["_rev"]
622613
return result
623614

624615
return doc_insert(db, collection, id, **kwargs)

0 commit comments

Comments
 (0)