diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml index 66d5afb8..a471e7b5 100644 --- a/.github/workflows/build.yaml +++ b/.github/workflows/build.yaml @@ -4,18 +4,49 @@ on: pull_request: push: branches: [ main ] + env: PACKAGE_DIR: nx_arangodb TESTS_DIR: tests + jobs: - build: + lint: runs-on: ubuntu-latest - continue-on-error: true steps: - uses: actions/checkout@v4 - name: Setup Python 3.10 - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 + with: + python-version: "3.10" + cache: 'pip' + cache-dependency-path: setup.py + + - name: Setup pip + run: python -m pip install --upgrade pip setuptools wheel + + - name: Install packages + run: pip install .[dev] + + - name: Run black + run: black --check --verbose --diff --color ${{env.PACKAGE_DIR}} ${{env.TESTS_DIR}} + + - name: Run flake8 + run: flake8 ${{env.PACKAGE_DIR}} ${{env.TESTS_DIR}} + + - name: Run isort + run: isort --check --profile=black ${{env.PACKAGE_DIR}} ${{env.TESTS_DIR}} + + - name: Run mypy + run: mypy ${{env.PACKAGE_DIR}} ${{env.TESTS_DIR}} + + test: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - name: Setup Python 3.10 + uses: actions/setup-python@v5 with: python-version: "3.10" cache: 'pip' @@ -30,22 +61,10 @@ jobs: run: python -m pip install --upgrade pip setuptools wheel - name: Install packages - run: pip install .[test] - - # - name: Run black - # run: black --check --verbose --diff --color ${{env.PACKAGE_DIR}} ${{env.TESTS_DIR}} - - # - name: Run flake8 - # run: flake8 ${{env.PACKAGE_DIR}} ${{env.TESTS_DIR}} - - # - name: Run isort - # run: isort --check --profile=black ${{env.PACKAGE_DIR}} ${{env.TESTS_DIR}} - - # - name: Run mypy - # run: mypy ${{env.PACKAGE_DIR}} ${{env.TESTS_DIR}} + run: pip install .[dev] - name: Run local tests run: pytest tests/test.py - name: Run NetworkX tests - run: ./run_nx_tests.sh \ No newline at end of file + run: ./run_nx_tests.sh diff --git a/nx_arangodb/algorithms/centrality/betweenness.py b/nx_arangodb/algorithms/centrality/betweenness.py index b3c4c1e0..a949fbd1 100644 --- a/nx_arangodb/algorithms/centrality/betweenness.py +++ b/nx_arangodb/algorithms/centrality/betweenness.py @@ -1,6 +1,9 @@ +# type: ignore +# NOTE: NetworkX algorithms are not typed + import networkx as nx -from nx_arangodb.convert import _to_nxadb_graph, _to_nxcg_graph +from nx_arangodb.convert import _to_nx_graph, _to_nxcg_graph from nx_arangodb.logger import logger from nx_arangodb.utils import networkx_algorithm @@ -40,7 +43,7 @@ def betweenness_centrality( print("Running nxcg.betweenness_centrality()") return nxcg.betweenness_centrality(G, k=k, normalized=normalized, weight=weight) - G = _to_nxadb_graph(G, pull_graph=pull_graph_on_cpu) + G = _to_nx_graph(G, pull_graph=pull_graph_on_cpu) logger.debug("using nx.betweenness_centrality") return nx.betweenness_centrality.orig_func( diff --git a/nx_arangodb/algorithms/community/louvain.py b/nx_arangodb/algorithms/community/louvain.py index 65426337..2c708863 100644 --- a/nx_arangodb/algorithms/community/louvain.py +++ b/nx_arangodb/algorithms/community/louvain.py @@ -1,8 +1,11 @@ +# type: ignore +# NOTE: NetworkX algorithms are not typed + from collections import deque import networkx as nx -from nx_arangodb.convert import _to_nxadb_graph, _to_nxcg_graph +from nx_arangodb.convert import _to_nx_graph, _to_nxcg_graph from nx_arangodb.logger import logger from nx_arangodb.utils import _dtype_param, networkx_algorithm @@ -50,7 +53,7 @@ def louvain_communities( seed=seed, ) - G = _to_nxadb_graph(G, pull_graph=pull_graph_on_cpu) + G = _to_nx_graph(G, pull_graph=pull_graph_on_cpu) logger.debug("using nx.louvain_communities") return nx.community.louvain_communities.orig_func( diff --git a/nx_arangodb/algorithms/link_analysis/pagerank_alg.py b/nx_arangodb/algorithms/link_analysis/pagerank_alg.py index d10aa443..a4da41c9 100644 --- a/nx_arangodb/algorithms/link_analysis/pagerank_alg.py +++ b/nx_arangodb/algorithms/link_analysis/pagerank_alg.py @@ -1,6 +1,9 @@ +# type: ignore +# NOTE: NetworkX algorithms are not typed + import networkx as nx -from nx_arangodb.convert import _to_nxadb_graph, _to_nxcg_graph +from nx_arangodb.convert import _to_nx_graph, _to_nxcg_graph from nx_arangodb.logger import logger from nx_arangodb.utils import _dtype_param, networkx_algorithm @@ -51,7 +54,7 @@ def pagerank( dtype=dtype, ) - G = _to_nxadb_graph(G, pull_graph=pull_graph_on_cpu) + G = _to_nx_graph(G, pull_graph=pull_graph_on_cpu) logger.debug("using nx.pagerank") return nx.algorithms.pagerank.orig_func( diff --git a/nx_arangodb/algorithms/shortest_paths/generic.py b/nx_arangodb/algorithms/shortest_paths/generic.py index b0ec7c09..658aabfb 100644 --- a/nx_arangodb/algorithms/shortest_paths/generic.py +++ b/nx_arangodb/algorithms/shortest_paths/generic.py @@ -1,3 +1,6 @@ +# type: ignore +# NOTE: NetworkX algorithms are not typed + import networkx as nx import nx_arangodb as nxadb diff --git a/nx_arangodb/classes/dict.py b/nx_arangodb/classes/dict.py index c7e4ee91..f70c6e70 100644 --- a/nx_arangodb/classes/dict.py +++ b/nx_arangodb/classes/dict.py @@ -1,8 +1,13 @@ +""" +A collection of dictionary-like objects for interacting with ArangoDB. +Used as the underlying data structure for NetworkX-ArangoDB graphs. +""" + from __future__ import annotations from collections import UserDict, defaultdict from collections.abc import Iterator -from typing import Any, Callable +from typing import Any, Callable, Generator from arango.database import StandardDatabase from arango.exceptions import DocumentInsertError @@ -21,6 +26,7 @@ aql_edge_get, aql_edge_id, aql_fetch_data, + aql_fetch_data_edge, aql_single, create_collection, doc_delete, @@ -81,7 +87,7 @@ def edge_attr_dict_factory( return lambda: EdgeAttrDict(db, graph) -class GraphDict(UserDict): +class GraphDict(UserDict[str, Any]): """A dictionary-like object for storing graph attributes. Given that ArangoDB does not have a concept of graph attributes, this class @@ -95,9 +101,12 @@ class GraphDict(UserDict): COLLECTION_NAME = "nxadb_graphs" - def __init__(self, db: StandardDatabase, graph_name: str, *args, **kwargs): + def __init__( + self, db: StandardDatabase, graph_name: str, *args: Any, **kwargs: Any + ): logger.debug("GraphDict.__init__") super().__init__(*args, **kwargs) + self.data: dict[str, Any] = {} self.db = db self.graph_name = graph_name @@ -120,7 +129,7 @@ def __contains__(self, key: str) -> bool: return aql_doc_has_key(self.db, self.graph_id, key) @key_is_string - def __getitem__(self, key: Any) -> Any: + def __getitem__(self, key: str) -> Any: """G.graph['foo']""" if value := self.data.get(key): return value @@ -138,7 +147,7 @@ def __getitem__(self, key: Any) -> Any: @key_is_string @key_is_not_reserved # @value_is_json_serializable # TODO? - def __setitem__(self, key: str, value: Any): + def __setitem__(self, key: str, value: Any) -> None: """G.graph['foo'] = 'bar'""" self.data[key] = value logger.debug(f"doc_update in GraphDict.__setitem__({key})") @@ -146,7 +155,7 @@ def __setitem__(self, key: str, value: Any): @key_is_string @key_is_not_reserved - def __delitem__(self, key): + def __delitem__(self, key: str) -> None: """del G.graph['foo']""" self.data.pop(key, None) logger.debug(f"doc_update in GraphDict.__delitem__({key})") @@ -155,14 +164,14 @@ def __delitem__(self, key): @keys_are_strings @keys_are_not_reserved # @values_are_json_serializable # TODO? - def update(self, attrs): + def update(self, attrs: Any) -> None: """G.graph.update({'foo': 'bar'})""" if attrs: self.data.update(attrs) logger.debug(f"doc_update in GraphDict.update({attrs})") doc_update(self.db, self.graph_id, attrs) - def clear(self): + def clear(self) -> None: """G.graph.clear()""" self.data.clear() logger.debug("cleared GraphDict") @@ -171,16 +180,139 @@ def clear(self): # doc_insert(self.db, self.COLLECTION_NAME, self.graph_id, silent=True) -class NodeDict(UserDict): - """The outer-level of the dict of dict structure representing the nodes (vertices) of a graph. +class NodeAttrDict(UserDict[str, Any]): + """The inner-level of the dict of dict structure + representing the nodes (vertices) of a graph. + + :param db: The ArangoDB database. + :type db: StandardDatabase + :param graph: The ArangoDB graph. + :type graph: Graph + """ + + def __init__(self, db: StandardDatabase, graph: Graph, *args: Any, **kwargs: Any): + logger.debug("NodeAttrDict.__init__") + + self.db = db + self.graph = graph + self.node_id: str + + super().__init__(*args, **kwargs) + self.data: dict[str, Any] = {} + + @key_is_string + def __contains__(self, key: str) -> bool: + """'foo' in G._node['node/1']""" + if key in self.data: + logger.debug(f"cached in NodeAttrDict.__contains__({key})") + return True + + logger.debug("aql_doc_has_key in NodeAttrDict.__contains__") + return aql_doc_has_key(self.db, self.node_id, key) + + @key_is_string + def __getitem__(self, key: str) -> Any: + """G._node['node/1']['foo']""" + if value := self.data.get(key): + logger.debug(f"cached in NodeAttrDict.__getitem__({key})") + return value + + logger.debug(f"aql_doc_get_key in NodeAttrDict.__getitem__({key})") + result = aql_doc_get_key(self.db, self.node_id, key) + + if not result: + raise KeyError(key) + + self.data[key] = result + + return result + + @key_is_string + @key_is_not_reserved + # @value_is_json_serializable # TODO? + def __setitem__(self, key: str, value: Any) -> None: + """G._node['node/1']['foo'] = 'bar'""" + self.data[key] = value + logger.debug(f"doc_update in NodeAttrDict.__setitem__({key})") + doc_update(self.db, self.node_id, {key: value}) + + @key_is_string + @key_is_not_reserved + def __delitem__(self, key: str) -> None: + """del G._node['node/1']['foo']""" + self.data.pop(key, None) + logger.debug(f"doc_update in NodeAttrDict({self.node_id}).__delitem__({key})") + doc_update(self.db, self.node_id, {key: None}) + + def __iter__(self) -> Iterator[str]: + """for key in G._node['node/1']""" + logger.debug(f"NodeAttrDict({self.node_id}).__iter__") + yield from aql_doc_get_keys(self.db, self.node_id) + + def __len__(self) -> int: + """len(G._node['node/1'])""" + logger.debug(f"NodeAttrDict({self.node_id}).__len__") + return aql_doc_get_length(self.db, self.node_id) + + # TODO: Revisit typing of return value + from collections.abc import KeysView + + def keys(self) -> Any: + """G._node['node/1'].keys()""" + logger.debug(f"NodeAttrDict({self.node_id}).keys()") + yield from self.__iter__() + + # TODO: Revisit typing of return value + def values(self) -> Any: + """G._node['node/1'].values()""" + logger.debug(f"NodeAttrDict({self.node_id}).values()") + self.data = self.db.document(self.node_id) + yield from self.data.values() + + # TODO: Revisit typing of return value + def items(self) -> Any: + """G._node['node/1'].items()""" + logger.debug(f"NodeAttrDict({self.node_id}).items()") + self.data = self.db.document(self.node_id) + yield from self.data.items() + + def clear(self) -> None: + """G._node['node/1'].clear()""" + self.data.clear() + logger.debug(f"cleared NodeAttrDict({self.node_id})") + + # if clear_remote: + # doc_insert(self.db, self.node_id, silent=True, overwrite=True) + + @keys_are_strings + @keys_are_not_reserved + # @values_are_json_serializable # TODO? + def update(self, attrs: Any) -> None: + """G._node['node/1'].update({'foo': 'bar'})""" + if attrs: + self.data.update(attrs) + + if not self.node_id: + logger.debug("Node ID not set, skipping NodeAttrDict(?).update()") + return + + logger.debug(f"NodeAttrDict({self.node_id}).update({attrs})") + doc_update(self.db, self.node_id, attrs) + + +class NodeDict(UserDict[str, NodeAttrDict]): + """The outer-level of the dict of dict structure representing the + nodes (vertices) of a graph. - The outer dict is keyed by ArangoDB Vertex IDs and the inner dict is keyed by Vertex attributes. + The outer dict is keyed by ArangoDB Vertex IDs and the inner dict + is keyed by Vertex attributes. :param db: The ArangoDB database. :type db: StandardDatabase :param graph: The ArangoDB graph. :type graph: Graph - :param default_node_type: The default node type. Used if the node ID is not formatted as 'type/id'. + :param default_node_type: The default node type. Used if the node ID + is not formatted as 'type/id'. :type default_node_type: str """ @@ -189,11 +321,12 @@ def __init__( db: StandardDatabase, graph: Graph, default_node_type: str, - *args, - **kwargs, + *args: Any, + **kwargs: Any, ): logger.debug("NodeDict.__init__") super().__init__(*args, **kwargs) + self.data: dict[str, NodeAttrDict] = {} self.db = db self.graph = graph @@ -210,7 +343,7 @@ def __contains__(self, key: str) -> bool: return True logger.debug(f"graph.has_vertex in NodeDict.__contains__({node_id})") - return self.graph.has_vertex(node_id) + return bool(self.graph.has_vertex(node_id)) @key_is_string def __getitem__(self, key: str) -> NodeAttrDict: @@ -234,7 +367,7 @@ def __getitem__(self, key: str) -> NodeAttrDict: raise KeyError(key) @key_is_string - def __setitem__(self, key: str, value: NodeAttrDict): + def __setitem__(self, key: str, value: NodeAttrDict) -> None: """G._node['node/1'] = {'foo': 'bar'} Not to be confused with: @@ -254,7 +387,7 @@ def __setitem__(self, key: str, value: NodeAttrDict): self.data[node_id] = node_attr_dict @key_is_string - def __delitem__(self, key: Any) -> None: + def __delitem__(self, key: str) -> None: """del g._node['node/1']""" node_id = get_node_id(key, self.default_node_type) @@ -262,7 +395,7 @@ def __delitem__(self, key: Any) -> None: raise KeyError(key) remove_statements = "\n".join( - f"REMOVE e IN `{edge_def['edge_collection']}` OPTIONS {{ignoreErrors: true}}" + f"REMOVE e IN `{edge_def['edge_collection']}` OPTIONS {{ignoreErrors: true}}" # noqa for edge_def in self.graph.edge_definitions() ) @@ -295,10 +428,9 @@ def __iter__(self) -> Iterator[str]: """iter(g._node)""" logger.debug("NodeDict.__iter__") for collection in self.graph.vertex_collections(): - for node_id in self.graph.vertex_collection(collection).ids(): - yield node_id + yield from self.graph.vertex_collection(collection).ids() - def clear(self): + def clear(self) -> None: """g._node.clear()""" self.data.clear() logger.debug("cleared NodeDict") @@ -308,7 +440,7 @@ def clear(self): # self.graph.vertex_collection(collection).truncate() @keys_are_strings - def update(self, nodes: dict[str, dict[str, Any]]): + def update(self, nodes: Any) -> None: """g._node.update({'node/1': {'foo': 'bar'}, 'node/2': {'baz': 'qux'}})""" raise NotImplementedError("NodeDict.update()") # for node_id, attrs in nodes.items(): @@ -322,27 +454,30 @@ def update(self, nodes: dict[str, dict[str, Any]]): # self.data[node_id] = node_attr_dict - def keys(self): + def keys(self) -> Any: """g._node.keys()""" logger.debug("NodeDict.keys()") return self.__iter__() - def values(self): + # TODO: Revisit typing of return value + def values(self) -> Any: """g._node.values()""" logger.debug("NodeDict.values()") self.__fetch_all() - return self.data.values() + yield from self.data.values() - def items(self, data: str | None = None, default: Any | None = None): + # TODO: Revisit typing of return value + def items(self, data: str | None = None, default: Any | None = None) -> Any: """g._node.items() or G._node.items(data='foo')""" if data is None: logger.debug("NodeDict.items(data=None)") self.__fetch_all() - return self.data.items() - - logger.debug(f"NodeDict.items(data={data})") - v_cols = list(self.graph.vertex_collections()) - return aql_fetch_data(self.db, v_cols, data, default, is_edge=False) + yield from self.data.items() + else: + logger.debug(f"NodeDict.items(data={data})") + v_cols = list(self.graph.vertex_collections()) + result = aql_fetch_data(self.db, v_cols, data, default) + yield from result.items() def __fetch_all(self): logger.debug("NodeDict.__fetch_all()") @@ -359,8 +494,11 @@ def __fetch_all(self): self.data[node_id] = node_attr_dict -class NodeAttrDict(UserDict): - """The inner-level of the dict of dict structure representing the nodes (vertices) of a graph. +class EdgeAttrDict(UserDict[str, Any]): + """The innermost-level of the dict of dict of dict structure + representing the Adjacency List of a graph. + + The innermost-dict is keyed by the edge attribute key. :param db: The ArangoDB database. :type db: StandardDatabase @@ -368,34 +506,43 @@ class NodeAttrDict(UserDict): :type graph: Graph """ - def __init__(self, db: StandardDatabase, graph: Graph, *args, **kwargs): - logger.debug("NodeAttrDict.__init__") + def __init__( + self, + db: StandardDatabase, + graph: Graph, + *args: Any, + **kwargs: Any, + ) -> None: + logger.debug("EdgeAttrDict.__init__") + + super().__init__(*args, **kwargs) + self.data: dict[str, Any] = {} self.db = db self.graph = graph - self.node_id: str | None = None - - super().__init__(*args, **kwargs) + self.edge_id: str @key_is_string def __contains__(self, key: str) -> bool: - """'foo' in G._node['node/1']""" + """'foo' in G._adj['node/1']['node/2']""" if key in self.data: - logger.debug(f"cached in NodeAttrDict.__contains__({key})") + logger.debug(f"cached in EdgeAttrDict({self.edge_id}).__contains__({key})") return True - logger.debug("aql_doc_has_key in NodeAttrDict.__contains__") - return aql_doc_has_key(self.db, self.node_id, key) + logger.debug(f"aql_doc_has_key in EdgeAttrDict({self.edge_id}).__contains__") + return aql_doc_has_key(self.db, self.edge_id, key) @key_is_string def __getitem__(self, key: str) -> Any: - """G._node['node/1']['foo']""" + """G._adj['node/1']['node/2']['foo']""" if value := self.data.get(key): - logger.debug(f"cached in NodeAttrDict.__getitem__({key})") + logger.debug(f"cached in EdgeAttrDict({self.edge_id}).__getitem__({key})") return value - logger.debug(f"aql_doc_get_key in NodeAttrDict.__getitem__({key})") - result = aql_doc_get_key(self.db, self.node_id, key) + logger.debug( + f"aql_doc_get_key in EdgeAttrDict({self.edge_id}).__getitem__({key})" + ) + result = aql_doc_get_key(self.db, self.edge_id, key) if not result: raise KeyError(key) @@ -407,73 +554,75 @@ def __getitem__(self, key: str) -> Any: @key_is_string @key_is_not_reserved # @value_is_json_serializable # TODO? - def __setitem__(self, key: str, value: Any): - """G._node['node/1']['foo'] = 'bar'""" + def __setitem__(self, key: str, value: Any) -> None: + """G._adj['node/1']['node/2']['foo'] = 'bar'""" self.data[key] = value - logger.debug(f"doc_update in NodeAttrDict.__setitem__({key})") - doc_update(self.db, self.node_id, {key: value}) + logger.debug(f"doc_update in EdgeAttrDict({self.edge_id}).__setitem__({key})") + doc_update(self.db, self.edge_id, {key: value}) @key_is_string @key_is_not_reserved - def __delitem__(self, key: str): - """del G._node['node/1']['foo']""" + def __delitem__(self, key: str) -> None: + """del G._adj['node/1']['node/2']['foo']""" self.data.pop(key, None) - logger.debug(f"doc_update in NodeAttrDict({self.node_id}).__delitem__({key})") - doc_update(self.db, self.node_id, {key: None}) + logger.debug(f"doc_update in EdgeAttrDict({self.edge_id}).__delitem__({key})") + doc_update(self.db, self.edge_id, {key: None}) def __iter__(self) -> Iterator[str]: - """for key in G._node['node/1']""" - logger.debug(f"NodeAttrDict({self.node_id}).__iter__") - for key in aql_doc_get_keys(self.db, self.node_id): - yield key + """for key in G._adj['node/1']['node/2']""" + logger.debug(f"EEdgeAttrDict({self.edge_id}).__iter__") + yield from aql_doc_get_keys(self.db, self.edge_id) def __len__(self) -> int: - """len(G._node['node/1'])""" - logger.debug(f"NodeAttrDict({self.node_id}).__len__") - return aql_doc_get_length(self.db, self.node_id) + """len(G._adj['node/1']['node/'2])""" + logger.debug(f"EdgeAttrDict({self.edge_id}).__len__") + return aql_doc_get_length(self.db, self.edge_id) - def keys(self): - """G._node['node/1'].keys()""" - logger.debug(f"NodeAttrDict({self.node_id}).keys()") + # TODO: Revisit typing of return value + def keys(self) -> Any: + """G._adj['node/1']['node/'2].keys()""" + logger.debug(f"EdgeAttrDict({self.edge_id}).keys()") return self.__iter__() - def values(self): - """G._node['node/1'].values()""" - logger.debug(f"NodeAttrDict({self.node_id}).values()") - self.data = self.db.document(self.node_id) - return self.data.values() + # TODO: Revisit typing of return value + def values(self) -> Any: + """G._adj['node/1']['node/'2].values()""" + logger.debug(f"EdgeAttrDict({self.edge_id}).values()") + self.data = self.db.document(self.edge_id) + yield from self.data.values() - def items(self): - """G._node['node/1'].items()""" - logger.debug(f"NodeAttrDict({self.node_id}).items()") - self.data = self.db.document(self.node_id) - return self.data.items() + # TODO: Revisit typing of return value + def items(self) -> Any: + """G._adj['node/1']['node/'2].items()""" + logger.debug(f"EdgeAttrDict({self.edge_id}).items()") + self.data = self.db.document(self.edge_id) + yield from self.data.items() - def clear(self): - """G._node['node/1'].clear()""" + def clear(self) -> None: + """G._adj['node/1']['node/'2].clear()""" self.data.clear() - logger.debug(f"cleared NodeAttrDict({self.node_id})") - - # if clear_remote: - # doc_insert(self.db, self.node_id, silent=True, overwrite=True) + logger.debug(f"cleared EdgeAttrDict({self.edge_id})") - def update(self, attrs: dict[str, Any]): - """G._node['node/1'].update({'foo': 'bar'})""" + @keys_are_strings + @keys_are_not_reserved + def update(self, attrs: Any) -> None: + """G._adj['node/1']['node/'2].update({'foo': 'bar'})""" if attrs: self.data.update(attrs) - if not self.node_id: - logger.debug(f"Node ID not set, skipping NodeAttrDict(?).update()") + if not hasattr(self, "edge_id"): + logger.debug("Edge ID not set, skipping EdgeAttrDict(?).update()") return - logger.debug(f"NodeAttrDict({self.node_id}).update({attrs})") - doc_update(self.db, self.node_id, attrs) + logger.debug(f"EdgeAttrDict({self.edge_id}).update({attrs})") + doc_update(self.db, self.edge_id, attrs) -class AdjListOuterDict(UserDict): - """The outer-level of the dict of dict of dict structure representing the Adjacency List of a graph. +class AdjListInnerDict(UserDict[str, EdgeAttrDict]): + """The inner-level of the dict of dict of dict structure + representing the Adjacency List of a graph. - The outer-dict is keyed by the node ID of the source node. + The inner-dict is keyed by the node ID of the destination node. :param db: The ArangoDB database. :type db: StandardDatabase @@ -491,302 +640,81 @@ def __init__( graph: Graph, default_node_type: str, edge_type_func: Callable[[str, str], str], - *args, - **kwargs, + adjlist_outer_dict: AdjListOuterDict | None, + *args: Any, + **kwargs: Any, ): - logger.debug("AdjListOuterDict.__init__") + logger.debug("AdjListInnerDict.__init__") super().__init__(*args, **kwargs) + self.data: dict[str, EdgeAttrDict] = {} self.db = db self.graph = graph self.default_node_type = default_node_type self.edge_type_func = edge_type_func - self.adjlist_inner_dict_factory = adjlist_inner_dict_factory( - db, graph, default_node_type, edge_type_func, self - ) + self.adjlist_outer_dict = adjlist_outer_dict + + self.src_node_id: str + + self.edge_attr_dict_factory = edge_attr_dict_factory(self.db, self.graph) self.FETCHED_ALL_DATA = False - # def __repr__(self) -> str: - # return f"'{self.graph.name}'" + def __get_mirrored_edge_attr_dict(self, dst_node_id: str) -> EdgeAttrDict | None: + if self.adjlist_outer_dict is None: + return None - # def __str__(self) -> str: - # return f"'{self.graph.name}'" + logger.debug(f"checking for mirrored edge ({self.src_node_id}, {dst_node_id})") + if dst_node_id in self.adjlist_outer_dict.data: + if self.src_node_id in self.adjlist_outer_dict.data[dst_node_id].data: + return self.adjlist_outer_dict.data[dst_node_id].data[self.src_node_id] - @key_is_string - def __contains__(self, key) -> bool: - """'node/1' in G.adj""" - node_id = get_node_id(key, self.default_node_type) + return None - if node_id in self.data: - logger.debug(f"cached in AdjListOuterDict.__contains__({node_id})") - return True + def __repr__(self) -> str: + return f"'{self.src_node_id}'" - logger.debug("graph.has_vertex in AdjListOuterDict.__contains__") - return self.graph.has_vertex(node_id) + def __str__(self) -> str: + return f"'{self.src_node_id}'" @key_is_string - def __getitem__(self, key: str) -> AdjListInnerDict: - """G.adj["node/1"]""" - node_type, node_id = get_node_type_and_id(key, self.default_node_type) - - if value := self.data.get(node_id): - logger.debug(f"cached in AdjListOuterDict.__getitem__({node_id})") - return value + def __contains__(self, key: str) -> bool: + """'node/2' in G.adj['node/1']""" + dst_node_id = get_node_id(key, self.default_node_type) - if self.graph.has_vertex(node_id): - logger.debug(f"graph.vertex in AdjListOuterDict.__getitem__({node_id})") - adjlist_inner_dict: AdjListInnerDict = self.adjlist_inner_dict_factory() - adjlist_inner_dict.src_node_id = node_id + if dst_node_id in self.data: + logger.debug(f"cached in AdjListInnerDict.__contains__({dst_node_id})") + return True - self.data[node_id] = adjlist_inner_dict + logger.debug(f"aql_edge_exists in AdjListInnerDict.__contains__({dst_node_id})") - return adjlist_inner_dict + result = aql_edge_exists( + self.db, + self.src_node_id, + dst_node_id, + self.graph.name, + direction="ANY", + ) - raise KeyError(key) + return result if result else False @key_is_string - def __setitem__(self, src_key: str, adjlist_inner_dict: AdjListInnerDict): - """ - g._adj['node/1'] = AdjListInnerDict() - """ - assert isinstance(adjlist_inner_dict, AdjListInnerDict) - assert not adjlist_inner_dict.src_node_id + def __getitem__(self, key: str) -> EdgeAttrDict: + """g._adj['node/1']['node/2']""" + dst_node_id = get_node_id(key, self.default_node_type) - logger.debug(f"AdjListOuterDict.__setitem__({src_key})") + if dst_node_id in self.data: + m = f"cached in AdjListInnerDict({self.src_node_id}).__getitem__({dst_node_id})" # noqa + logger.debug(m) + return self.data[dst_node_id] - src_node_type, src_node_id = get_node_type_and_id( - src_key, self.default_node_type - ) + if mirrored_edge_attr_dict := self.__get_mirrored_edge_attr_dict(dst_node_id): + logger.debug("No need to fetch the edge, as it is already cached") + self.data[dst_node_id] = mirrored_edge_attr_dict + return mirrored_edge_attr_dict - # NOTE: this might not actually be needed... - results = {} - edge_dict: dict[str, Any] - for dst_key, edge_dict in adjlist_inner_dict.data.items(): - dst_node_type, dst_node_id = get_node_type_and_id( - dst_key, self.default_node_type - ) - - edge_type = edge_dict.get("_edge_type") - if edge_type is None: - edge_type = self.edge_type_func(src_node_type, dst_node_type) - - logger.debug(f"graph.link({src_key}, {dst_key})") - results[dst_key] = self.graph.link( - edge_type, src_node_id, dst_node_id, edge_dict - ) - - adjlist_inner_dict.src_node_id = src_node_id - adjlist_inner_dict.data = results - - self.data[src_node_id] = adjlist_inner_dict - - @key_is_string - def __delitem__(self, key: Any) -> None: - """ - del G._adj['node/1'] - """ - # Nothing else to do here, as this delete is always invoked by - # G.remove_node(), which already removes all edges via - # del G._node['node/1'] - logger.debug(f"AdjListOuterDict.__delitem__({key}) (just cache)") - node_id = get_node_id(key, self.default_node_type) - self.data.pop(node_id, None) - - def __len__(self) -> int: - """len(g._adj)""" - logger.debug("AdjListOuterDict.__len__") - return sum( - [ - self.graph.vertex_collection(c).count() - for c in self.graph.vertex_collections() - ] - ) - - def __iter__(self) -> Iterator[str]: - """for k in g._adj""" - logger.debug("AdjListOuterDict.__iter__") - - if self.FETCHED_ALL_DATA: - yield from self.data.keys() - - else: - for collection in self.graph.vertex_collections(): - for id in self.graph.vertex_collection(collection).ids(): - yield id - - def keys(self): - """g._adj.keys()""" - logger.debug("AdjListOuterDict.keys()") - return self.__iter__() - - def clear(self): - """g._node.clear()""" - self.data.clear() - self.FETCHED_ALL_DATA = False - logger.debug("cleared AdjListOuterDict") - - # if clear_remote: - # for ed in self.graph.edge_definitions(): - # self.graph.edge_collection(ed["edge_collection"]).truncate() - - @keys_are_strings - def update(self, edges: dict[str, dict[str, dict[str, Any]]]): - """g._adj.update({'node/1': {'node/2': {'foo': 'bar'}})""" - raise NotImplementedError("AdjListOuterDict.update()") - - def values(self): - """g._adj.values()""" - logger.debug("AdjListOuterDict.values()") - self.__fetch_all() - return self.data.values() - - def items(self, data: str | None = None, default: Any | None = None): - """g._adj.items() or G._adj.items(data='foo')""" - if data is None: - logger.debug("AdjListOuterDict.items(data=None)") - self.__fetch_all() - return self.data.items() - - logger.debug(f"AdjListOuterDict.items(data={data})") - e_cols = [ed["edge_collection"] for ed in self.graph.edge_definitions()] - result = aql_fetch_data(self.db, e_cols, data, default, is_edge=True) - yield from result - - # TODO: Revisit - def __fetch_all(self) -> None: - logger.debug("AdjListOuterDict.__fetch_all()") - - if self.FETCHED_ALL_DATA: - logger.debug("Already fetched data, skipping fetch") - return - - self.clear() - # items = defaultdict(dict) - for ed in self.graph.edge_definitions(): - collection = ed["edge_collection"] - - for edge in self.graph.edge_collection(collection): - src_node_id = edge["_from"] - dst_node_id = edge["_to"] - - # items[src_node_id][dst_node_id] = edge - # items[dst_node_id][src_node_id] = edge - - if src_node_id in self.data: - src_inner_dict = self.data[src_node_id] - else: - src_inner_dict = self.adjlist_inner_dict_factory() - src_inner_dict.src_node_id = src_node_id - self.data[src_node_id] = src_inner_dict - - if dst_node_id in self.data: - dst_inner_dict = self.data[dst_node_id] - else: - dst_inner_dict = self.adjlist_inner_dict_factory() - dst_inner_dict.src_node_id = dst_node_id - self.data[dst_node_id] = dst_inner_dict - - edge_attr_dict = src_inner_dict.edge_attr_dict_factory() - edge_attr_dict.edge_id = edge["_id"] - edge_attr_dict.data = edge - - self.data[src_node_id].data[dst_node_id] = edge_attr_dict - self.data[dst_node_id].data[src_node_id] = edge_attr_dict - - self.FETCHED_ALL_DATA = True - - -class AdjListInnerDict(UserDict): - """The inner-level of the dict of dict of dict structure representing the Adjacency List of a graph. - - The inner-dict is keyed by the node ID of the destination node. - - :param db: The ArangoDB database. - :type db: StandardDatabase - :param graph: The ArangoDB graph. - :type graph: Graph - :param default_node_type: The default node type. - :type default_node_type: str - :param edge_type_func: The function to generate the edge type. - :type edge_type_func: Callable[[str, str], str] - """ - - def __init__( - self, - db: StandardDatabase, - graph: Graph, - default_node_type: str, - edge_type_func: Callable[[str, str], str], - adjlist_outer_dict: AdjListOuterDict, - *args, - **kwargs, - ): - logger.debug("AdjListInnerDict.__init__") - - super().__init__(*args, **kwargs) - - self.db = db - self.graph = graph - self.default_node_type = default_node_type - self.edge_type_func = edge_type_func - self.adjlist_outer_dict = adjlist_outer_dict - - self.src_node_id = None - - self.edge_attr_dict_factory = edge_attr_dict_factory(self.db, self.graph) - - self.FETCHED_ALL_DATA = False - - def __get_mirrored_edge_attr_dict(self, dst_node_id: str) -> bool: - logger.debug(f"checking for mirrored edge ({self.src_node_id}, {dst_node_id})") - if dst_node_id in self.adjlist_outer_dict.data: - if self.src_node_id in self.adjlist_outer_dict.data[dst_node_id].data: - return self.adjlist_outer_dict.data[dst_node_id].data[self.src_node_id] - - return None - - # def __repr__(self) -> str: - # return f"'{self.src_node_id}'" - - # def __str__(self) -> str: - # return f"'{self.src_node_id}'" - - @key_is_string - def __contains__(self, key) -> bool: - """'node/2' in G.adj['node/1']""" - dst_node_id = get_node_id(key, self.default_node_type) - - if dst_node_id in self.data: - logger.debug(f"cached in AdjListInnerDict.__contains__({dst_node_id})") - return True - - logger.debug(f"aql_edge_exists in AdjListInnerDict.__contains__({dst_node_id})") - return aql_edge_exists( - self.db, - self.src_node_id, - dst_node_id, - self.graph.name, - direction="ANY", - ) - - @key_is_string - def __getitem__(self, key) -> EdgeAttrDict: - """g._adj['node/1']['node/2']""" - dst_node_id = get_node_id(key, self.default_node_type) - - if dst_node_id in self.data: - m = f"cached in AdjListInnerDict({self.src_node_id}).__getitem__({dst_node_id})" - logger.debug(m) - return self.data[dst_node_id] - - if mirrored_edge_attr_dict := self.__get_mirrored_edge_attr_dict(dst_node_id): - logger.debug("No need to fetch the edge, as it is already cached") - self.data[dst_node_id] = mirrored_edge_attr_dict - return mirrored_edge_attr_dict - - m = f"aql_edge_get in AdjListInnerDict({self.src_node_id}).__getitem__({dst_node_id})" + m = f"aql_edge_get in AdjListInnerDict({self.src_node_id}).__getitem__({dst_node_id})" # noqa edge = aql_edge_get( self.db, self.src_node_id, @@ -807,7 +735,7 @@ def __getitem__(self, key) -> EdgeAttrDict: return edge_attr_dict @key_is_string - def __setitem__(self, key: str, value: dict | EdgeAttrDict): + def __setitem__(self, key: str, value: dict[str, Any] | EdgeAttrDict) -> None: """g._adj['node/1']['node/2'] = {'foo': 'bar'}""" assert isinstance(value, EdgeAttrDict) logger.debug(f"AdjListInnerDict({self.src_node_id}).__setitem__({key})") @@ -825,10 +753,11 @@ def __setitem__(self, key: str, value: dict | EdgeAttrDict): edge_type = self.edge_type_func(src_node_type, dst_node_type) logger.debug(f"No edge type specified, so generated: {edge_type})") - if edge_id := value.edge_id: + edge_id: str | None + if hasattr(value, "edge_id"): m = f"edge id found, deleting ({self.src_node_id, dst_node_id})" logger.debug(m) - self.graph.delete_edge(edge_id) + self.graph.delete_edge(value.edge_id) elif edge_id := aql_edge_id( self.db, @@ -857,7 +786,7 @@ def __setitem__(self, key: str, value: dict | EdgeAttrDict): self.data[dst_node_id] = edge_attr_dict @key_is_string - def __delitem__(self, key: Any) -> None: + def __delitem__(self, key: str) -> None: """del g._adj['node/1']['node/2']""" dst_node_id = get_node_id(key, self.default_node_type) self.data.pop(dst_node_id, None) @@ -877,7 +806,7 @@ def __delitem__(self, key: Any) -> None: ) if not edge_id: - m = f"edge not found, AdjListInnerDict({self.src_node_id}).__delitem__({dst_node_id})" + m = f"edge not found, AdjListInnerDict({self.src_node_id}).__delitem__({dst_node_id})" # noqa logger.debug(m) return @@ -889,7 +818,7 @@ def __len__(self) -> int: assert self.src_node_id if self.FETCHED_ALL_DATA: - m = f"Already fetched data, skipping AdjListInnerDict({self.src_node_id}).__len__" + m = f"Already fetched data, skipping AdjListInnerDict({self.src_node_id}).__len__" # noqa logger.debug(m) return len(self.data) @@ -903,14 +832,17 @@ def __len__(self) -> int: bind_vars = {"src_node_id": self.src_node_id, "graph_name": self.graph.name} logger.debug(f"aql_single in AdjListInnerDict({self.src_node_id}).__len__") - count = aql_single(self.db, query, bind_vars) + result = aql_single(self.db, query, bind_vars) + + if result is None: + return 0 - return count if count is not None else 0 + return int(result) def __iter__(self) -> Iterator[str]: """for k in g._adj['node/1']""" if self.FETCHED_ALL_DATA: - m = f"Already fetched data, skipping AdjListInnerDict({self.src_node_id}).__iter__" + m = f"Already fetched data, skipping AdjListInnerDict({self.src_node_id}).__iter__" # noqa logger.debug(m) yield from self.data.keys() @@ -925,34 +857,38 @@ def __iter__(self) -> Iterator[str]: logger.debug(f"aql in AdjListInnerDict({self.src_node_id}).__iter__") yield from aql(self.db, query, bind_vars) - def keys(self): + # TODO: Revisit typing of return value + def keys(self) -> Any: """g._adj['node/1'].keys()""" logger.debug(f"AdjListInnerDict({self.src_node_id}).keys()") return self.__iter__() - def clear(self): + def clear(self) -> None: """G._adj['node/1'].clear()""" self.data.clear() self.FETCHED_ALL_DATA = False logger.debug(f"cleared AdjListInnerDict({self.src_node_id})") - def update(self, edges: dict[str, dict[str, Any]]): + @keys_are_strings + def update(self, edges: Any) -> None: """g._adj['node/1'].update({'node/2': {'foo': 'bar'}})""" raise NotImplementedError("AdjListInnerDict.update()") - def values(self): + # TODO: Revisit typing of return value + def values(self) -> Any: """g._adj['node/1'].values()""" logger.debug(f"AdjListInnerDict({self.src_node_id}).values()") self.__fetch_all() - return self.data.values() + yield from self.data.values() - def items(self): + # TODO: Revisit typing of return value + def items(self) -> Any: """g._adj['node/1'].items()""" logger.debug(f"AdjListInnerDict({self.src_node_id}).items()") self.__fetch_all() - return self.data.items() + yield from self.data.items() - def __fetch_all(self): + def __fetch_all(self) -> None: logger.debug(f"AdjListInnerDict({self.src_node_id}).__fetch_all()") if self.FETCHED_ALL_DATA: @@ -978,119 +914,238 @@ def __fetch_all(self): self.FETCHED_ALL_DATA = True -class EdgeAttrDict(UserDict): - """The innermost-level of the dict of dict of dict structure representing the Adjacency List of a graph. +class AdjListOuterDict(UserDict[str, AdjListInnerDict]): + """The outer-level of the dict of dict of dict structure + representing the Adjacency List of a graph. - The innermost-dict is keyed by the edge attribute key. + The outer-dict is keyed by the node ID of the source node. :param db: The ArangoDB database. :type db: StandardDatabase :param graph: The ArangoDB graph. :type graph: Graph + :param default_node_type: The default node type. + :type default_node_type: str + :param edge_type_func: The function to generate the edge type. + :type edge_type_func: Callable[[str, str], str] """ def __init__( self, db: StandardDatabase, graph: Graph, - *args, - **kwargs, + default_node_type: str, + edge_type_func: Callable[[str, str], str], + *args: Any, + **kwargs: Any, ): - logger.debug("EdgeAttrDict.__init__") + logger.debug("AdjListOuterDict.__init__") super().__init__(*args, **kwargs) + self.data: dict[str, AdjListInnerDict] = {} self.db = db self.graph = graph - self.edge_id: str | None = None + self.default_node_type = default_node_type + self.edge_type_func = edge_type_func + self.adjlist_inner_dict_factory = adjlist_inner_dict_factory( + db, graph, default_node_type, edge_type_func, self + ) + + self.FETCHED_ALL_DATA = False + + def __repr__(self) -> str: + return f"'{self.graph.name}'" + + def __str__(self) -> str: + return f"'{self.graph.name}'" @key_is_string def __contains__(self, key: str) -> bool: - """'foo' in G._adj['node/1']['node/2']""" - if key in self.data: - logger.debug(f"cached in EdgeAttrDict({self.edge_id}).__contains__({key})") + """'node/1' in G.adj""" + node_id = get_node_id(key, self.default_node_type) + + if node_id in self.data: + logger.debug(f"cached in AdjListOuterDict.__contains__({node_id})") return True - logger.debug(f"aql_doc_has_key in EdgeAttrDict({self.edge_id}).__contains__") - return aql_doc_has_key(self.db, self.edge_id, key) + logger.debug("graph.has_vertex in AdjListOuterDict.__contains__") + return bool(self.graph.has_vertex(node_id)) @key_is_string - def __getitem__(self, key: str) -> Any: - """G._adj['node/1']['node/2']['foo']""" - if value := self.data.get(key): - logger.debug(f"cached in EdgeAttrDict({self.edge_id}).__getitem__({key})") + def __getitem__(self, key: str) -> AdjListInnerDict: + """G.adj["node/1"]""" + node_type, node_id = get_node_type_and_id(key, self.default_node_type) + + if value := self.data.get(node_id): + logger.debug(f"cached in AdjListOuterDict.__getitem__({node_id})") return value - logger.debug( - f"aql_doc_get_key in EdgeAttrDict({self.edge_id}).__getitem__({key})" - ) - result = aql_doc_get_key(self.db, self.edge_id, key) + if self.graph.has_vertex(node_id): + logger.debug(f"graph.vertex in AdjListOuterDict.__getitem__({node_id})") + adjlist_inner_dict: AdjListInnerDict = self.adjlist_inner_dict_factory() + adjlist_inner_dict.src_node_id = node_id - if not result: - raise KeyError(key) + self.data[node_id] = adjlist_inner_dict - self.data[key] = result + return adjlist_inner_dict - return result + raise KeyError(key) @key_is_string - @key_is_not_reserved - # @value_is_json_serializable # TODO? - def __setitem__(self, key: str, value: Any): - """G._adj['node/1']['node/2']['foo'] = 'bar'""" - self.data[key] = value - logger.debug(f"doc_update in EdgeAttrDict({self.edge_id}).__setitem__({key})") - doc_update(self.db, self.edge_id, {key: value}) + def __setitem__(self, src_key: str, adjlist_inner_dict: AdjListInnerDict) -> None: + """ + g._adj['node/1'] = AdjListInnerDict() + """ + assert isinstance(adjlist_inner_dict, AdjListInnerDict) + assert not hasattr(adjlist_inner_dict, "src_node_id") + + logger.debug(f"AdjListOuterDict.__setitem__({src_key})") + + src_node_type, src_node_id = get_node_type_and_id( + src_key, self.default_node_type + ) + + # NOTE: this might not actually be needed... + results = {} + for dst_key, edge_dict in adjlist_inner_dict.data.items(): + dst_node_type, dst_node_id = get_node_type_and_id( + dst_key, self.default_node_type + ) + + edge_type = edge_dict.get("_edge_type") + if edge_type is None: + edge_type = self.edge_type_func(src_node_type, dst_node_type) + + logger.debug(f"graph.link({src_key}, {dst_key})") + results[dst_key] = self.graph.link( + edge_type, src_node_id, dst_node_id, edge_dict + ) + + adjlist_inner_dict.src_node_id = src_node_id + adjlist_inner_dict.data = results + + self.data[src_node_id] = adjlist_inner_dict @key_is_string - @key_is_not_reserved - def __delitem__(self, key: str): - """del G._adj['node/1']['node/2']['foo']""" - self.data.pop(key, None) - logger.debug(f"doc_update in EdgeAttrDict({self.edge_id}).__delitem__({key})") - doc_update(self.db, self.edge_id, {key: None}) + def __delitem__(self, key: str) -> None: + """ + del G._adj['node/1'] + """ + # Nothing else to do here, as this delete is always invoked by + # G.remove_node(), which already removes all edges via + # del G._node['node/1'] + logger.debug(f"AdjListOuterDict.__delitem__({key}) (just cache)") + node_id = get_node_id(key, self.default_node_type) + self.data.pop(node_id, None) + + def __len__(self) -> int: + """len(g._adj)""" + logger.debug("AdjListOuterDict.__len__") + return sum( + [ + self.graph.vertex_collection(c).count() + for c in self.graph.vertex_collections() + ] + ) def __iter__(self) -> Iterator[str]: - """for key in G._adj['node/1']['node/2']""" - logger.debug(f"EEdgeAttrDict({self.edge_id}).__iter__") - for key in aql_doc_get_keys(self.db, self.edge_id): - yield key + """for k in g._adj""" + logger.debug("AdjListOuterDict.__iter__") - def __len__(self) -> int: - """len(G._adj['node/1']['node/'2])""" - logger.debug(f"EdgeAttrDict({self.edge_id}).__len__") - return aql_doc_get_length(self.db, self.edge_id) + if self.FETCHED_ALL_DATA: + yield from self.data.keys() - def keys(self): - """G._adj['node/1']['node/'2].keys()""" - logger.debug(f"EdgeAttrDict({self.edge_id}).keys()") + else: + for collection in self.graph.vertex_collections(): + yield from self.graph.vertex_collection(collection).ids() + + # TODO: Revisit typing of return value + def keys(self) -> Any: + """g._adj.keys()""" + logger.debug("AdjListOuterDict.keys()") return self.__iter__() - def values(self): - """G._adj['node/1']['node/'2].values()""" - logger.debug(f"EdgeAttrDict({self.edge_id}).values()") - self.data = self.db.document(self.edge_id) - return self.data.values() + def clear(self) -> None: + """g._node.clear()""" + self.data.clear() + self.FETCHED_ALL_DATA = False + logger.debug("cleared AdjListOuterDict") - def items(self): - """G._adj['node/1']['node/'2].items()""" - logger.debug(f"EdgeAttrDict({self.edge_id}).items()") - self.data = self.db.document(self.edge_id) - return self.data.items() + # if clear_remote: + # for ed in self.graph.edge_definitions(): + # self.graph.edge_collection(ed["edge_collection"]).truncate() - def clear(self): - """G._adj['node/1']['node/'2].clear()""" - self.data.clear() - logger.debug(f"cleared EdgeAttrDict({self.edge_id})") + @keys_are_strings + def update(self, edges: Any) -> None: + """g._adj.update({'node/1': {'node/2': {'foo': 'bar'}})""" + raise NotImplementedError("AdjListOuterDict.update()") - def update(self, attrs: dict[str, Any]): - """G._adj['node/1']['node/'2].update({'foo': 'bar'})""" - if attrs: - self.data.update(attrs) + # TODO: Revisit typing of return value + def values(self) -> Any: + """g._adj.values()""" + logger.debug("AdjListOuterDict.values()") + self.__fetch_all() + yield from self.data.values() + + # TODO: Revisit typing of return value + def items(self, data: str | None = None, default: Any | None = None) -> Any: + # TODO: Revisit typing + # -> ( + # Generator[tuple[str, AdjListInnerDict], None, None] + # | Generator[tuple[str, str, Any], None, None] + # ): + """g._adj.items() or G._adj.items(data='foo')""" + if data is None: + logger.debug("AdjListOuterDict.items(data=None)") + self.__fetch_all() + yield from self.data.items() - if not self.edge_id: - logger.debug("Edge ID not set, skipping EdgeAttrDict(?).update()") - return + else: + logger.debug(f"AdjListOuterDict.items(data={data})") + e_cols = [ed["edge_collection"] for ed in self.graph.edge_definitions()] + result = aql_fetch_data_edge(self.db, e_cols, data, default) + yield from result - logger.debug(f"EdgeAttrDict({self.edge_id}).update({attrs})") - doc_update(self.db, self.edge_id, attrs) + # TODO: Revisit this logic + def __fetch_all(self) -> None: + logger.debug("AdjListOuterDict.__fetch_all()") + + if self.FETCHED_ALL_DATA: + logger.debug("Already fetched data, skipping fetch") + return + + self.clear() + # items = defaultdict(dict) + for ed in self.graph.edge_definitions(): + collection = ed["edge_collection"] + + for edge in self.graph.edge_collection(collection): + src_node_id = edge["_from"] + dst_node_id = edge["_to"] + + # items[src_node_id][dst_node_id] = edge + # items[dst_node_id][src_node_id] = edge + + if src_node_id in self.data: + src_inner_dict = self.data[src_node_id] + else: + src_inner_dict = self.adjlist_inner_dict_factory() + src_inner_dict.src_node_id = src_node_id + self.data[src_node_id] = src_inner_dict + + if dst_node_id in self.data: + dst_inner_dict = self.data[dst_node_id] + else: + dst_inner_dict = self.adjlist_inner_dict_factory() + dst_inner_dict.src_node_id = dst_node_id + self.data[dst_node_id] = dst_inner_dict + + edge_attr_dict = src_inner_dict.edge_attr_dict_factory() + edge_attr_dict.edge_id = edge["_id"] + edge_attr_dict.data = edge + + self.data[src_node_id].data[dst_node_id] = edge_attr_dict + self.data[dst_node_id].data[src_node_id] = edge_attr_dict + + self.FETCHED_ALL_DATA = True diff --git a/nx_arangodb/classes/digraph.py b/nx_arangodb/classes/digraph.py index 39c16839..52f0539a 100644 --- a/nx_arangodb/classes/digraph.py +++ b/nx_arangodb/classes/digraph.py @@ -1,17 +1,19 @@ import os -from typing import ClassVar +from typing import Any, ClassVar import networkx as nx +import numpy as np +import numpy.typing as npt from arango import ArangoClient from arango.cursor import Cursor from arango.database import StandardDatabase from arango.exceptions import ServerConnectionError import nx_arangodb as nxadb -from nx_arangodb.exceptions import * +from nx_arangodb.exceptions import DatabaseNotSet, GraphNameNotSet from nx_arangodb.logger import logger -networkx_api = nxadb.utils.decorators.networkx_class(nx.DiGraph) +networkx_api = nxadb.utils.decorators.networkx_class(nx.DiGraph) # type: ignore __all__ = ["DiGraph"] @@ -22,15 +24,15 @@ class DiGraph(nx.DiGraph): @classmethod def to_networkx_class(cls) -> type[nx.DiGraph]: - return nx.DiGraph + return nx.DiGraph # type: ignore[no-any-return] def __init__( self, graph_name: str | None = None, # default_node_type: str = "nxadb_nodes", # edge_type_func: Callable[[str, str], str] = lambda u, v: f"{u}_to_{v}", - *args, - **kwargs, + *args: Any, + **kwargs: Any, ): m = "Please note that nxadb.DiGraph has no ArangoDB CRUD support yet." logger.warning(m) @@ -39,8 +41,8 @@ def __init__( m = "Cannot pass both **incoming_graph_data** and **graph_name** yet" raise NotImplementedError(m) - self.__db = None - self.__graph_name = None + self.__db: StandardDatabase | None = None + self.__graph_name: str | None = None self.__graph_exists = False self.__set_db() @@ -58,9 +60,9 @@ def __init__( self.use_nx_cache = True self.use_coo_cache = True - self.src_indices = None - self.dst_indices = None - self.vertex_ids_to_index = None + self.src_indices: npt.NDArray[np.int64] | None = None + self.dst_indices: npt.NDArray[np.int64] | None = None + self.vertex_ids_to_index: dict[str, int] | None = None # self.default_node_type = default_node_type # self.edge_type_func = edge_type_func @@ -99,7 +101,7 @@ def graph_exists(self) -> bool: # Setters # ########### - def __set_db(self, db: StandardDatabase | None = None): + def __set_db(self, db: StandardDatabase | None = None) -> None: if db is not None: if not isinstance(db, StandardDatabase): m = "arango.database.StandardDatabase" @@ -128,11 +130,10 @@ def __set_db(self, db: StandardDatabase | None = None): self.__db = None logger.warning(f"Could not connect to the database: {e}") - def __set_graph_name(self, graph_name: str | None = None): + def __set_graph_name(self, graph_name: str | None = None) -> None: if self.__db is None: - raise DatabaseNotSet( - "Cannot set graph name without setting the database first" - ) + m = "Cannot set graph name without setting the database first" + raise DatabaseNotSet(m) if graph_name is None: self.__graph_exists = False @@ -151,7 +152,7 @@ def __set_graph_name(self, graph_name: str | None = None): # ArangoDB Methods # #################### - def aql(self, query: str, bind_vars: dict | None = None, **kwargs) -> Cursor: + def aql(self, query: str, bind_vars: dict[str, Any] = {}, **kwargs: Any) -> Cursor: return nxadb.classes.function.aql(self.db, query, bind_vars, **kwargs) def pull(self, load_node_dict=True, load_adj_dict=True, load_coo=True): diff --git a/nx_arangodb/classes/function.py b/nx_arangodb/classes/function.py index b50107ca..afef0351 100644 --- a/nx_arangodb/classes/function.py +++ b/nx_arangodb/classes/function.py @@ -1,11 +1,18 @@ +""" +A collection of CRUD functions for the ArangoDB graph database. +Used by the nx_arangodb Graph, DiGraph, MultiGraph, and MultiDiGraph classes. +""" + from __future__ import annotations -from typing import Any, Tuple +from collections import UserDict +from typing import Any, Callable, Tuple -import arango -import networkx as nx import numpy as np -from arango import exceptions, graph +import numpy.typing as npt +from arango.collection import StandardCollection +from arango.cursor import Cursor +from arango.database import StandardDatabase import nx_arangodb as nxadb @@ -25,8 +32,8 @@ def get_arangodb_graph( ) -> Tuple[ dict[str, dict[str, Any]], dict[str, dict[str, dict[str, Any]]], - np.ndarray, - np.ndarray, + npt.NDArray[np.int64], + npt.NDArray[np.int64], dict[str, int], ]: """Pulls the graph from the database, assuming the graph exists. @@ -48,12 +55,12 @@ def get_arangodb_graph( edge_definitions = adb_graph.edge_definitions() e_cols = {c["edge_collection"] for c in edge_definitions} - metagraph = { + metagraph: dict[str, dict[str, Any]] = { "vertexCollections": {col: {} for col in v_cols}, "edgeCollections": {col: {} for col in e_cols}, } - from phenolrs.graph_loader import GraphLoader + from phenolrs.networkx_loader import NetworkXLoader kwargs = {} if G.graph_loader_parallelism is not None: @@ -61,7 +68,8 @@ def get_arangodb_graph( if G.graph_loader_batch_size is not None: kwargs["batch_size"] = G.graph_loader_batch_size - return GraphLoader.load( + # TODO: Remove ignore when phenolrs is published + return NetworkXLoader.load_into_networkx( # type: ignore G.db.name, metagraph, [G._host], @@ -75,10 +83,10 @@ def get_arangodb_graph( ) -def key_is_string(func) -> Any: +def key_is_string(func: Callable[..., Any]) -> Any: """Decorator to check if the key is a string.""" - def wrapper(self, key, *args, **kwargs) -> Any: + def wrapper(self: Any, key: str, *args: Any, **kwargs: Any) -> Any: if not isinstance(key, str): raise TypeError(f"'{key}' is not a string.") @@ -87,12 +95,13 @@ def wrapper(self, key, *args, **kwargs) -> Any: return wrapper -def keys_are_strings(func) -> Any: +def keys_are_strings(func: Callable[..., Any]) -> Any: """Decorator to check if the keys are strings.""" - def wrapper(self, dict, *args, **kwargs) -> Any: - if not all(isinstance(key, str) for key in dict): - raise TypeError(f"All keys must be strings.") + def wrapper(self: Any, dict: dict[Any, Any], *args: Any, **kwargs: Any) -> Any: + for key in dict: + if not isinstance(key, str): + raise TypeError(f"'{key}' is not a string.") return func(self, dict, *args, **kwargs) @@ -102,10 +111,10 @@ def wrapper(self, dict, *args, **kwargs) -> Any: RESERVED_KEYS = {"_id", "_key", "_rev"} -def key_is_not_reserved(func) -> Any: +def key_is_not_reserved(func: Callable[..., Any]) -> Any: """Decorator to check if the key is not reserved.""" - def wrapper(self, key, *args, **kwargs) -> Any: + def wrapper(self: Any, key: str, *args: Any, **kwargs: Any) -> Any: if key in RESERVED_KEYS: raise KeyError(f"'{key}' is a reserved key.") @@ -114,12 +123,13 @@ def wrapper(self, key, *args, **kwargs) -> Any: return wrapper -def keys_are_not_reserved(func) -> Any: +def keys_are_not_reserved(func: Any) -> Any: """Decorator to check if the keys are not reserved.""" - def wrapper(self, dict, *args, **kwargs) -> Any: - if any(key in RESERVED_KEYS for key in dict): - raise KeyError(f"All keys must not be reserved.") + def wrapper(self: Any, dict: dict[Any, Any], *args: Any, **kwargs: Any) -> Any: + for key in dict: + if key in RESERVED_KEYS: + raise KeyError(f"'{key}' is a reserved key.") return func(self, dict, *args, **kwargs) @@ -127,8 +137,8 @@ def wrapper(self, dict, *args, **kwargs) -> Any: def create_collection( - db: arango.StandardDatabase, collection_name: str, edge: bool = False -) -> arango.StandardCollection: + db: StandardDatabase, collection_name: str, edge: bool = False +) -> StandardCollection: """Creates a collection if it does not exist and returns it.""" if not db.has_collection(collection_name): db.create_collection(collection_name, edge=edge) @@ -137,22 +147,22 @@ def create_collection( def aql( - db: arango.StandardDatabase, query: str, bind_vars: dict[str, Any], **kwargs -) -> arango.Cursor: + db: StandardDatabase, query: str, bind_vars: dict[str, Any], **kwargs: Any +) -> Cursor: """Executes an AQL query and returns the cursor.""" return db.aql.execute(query, bind_vars=bind_vars, stream=True, **kwargs) def aql_as_list( - db: arango.StandardDatabase, query: str, bind_vars: dict[str, Any], **kwargs + db: StandardDatabase, query: str, bind_vars: dict[str, Any], **kwargs: Any ) -> list[Any]: """Executes an AQL query and returns the results as a list.""" return list(aql(db, query, bind_vars, **kwargs)) def aql_single( - db: arango.StandardDatabase, query: str, bind_vars: dict[str, Any] -) -> Any: + db: StandardDatabase, query: str, bind_vars: dict[str, Any] +) -> Any | None: """Executes an AQL query and returns the first result.""" result = aql_as_list(db, query, bind_vars) if len(result) == 0: @@ -164,41 +174,44 @@ def aql_single( return result[0] -def aql_doc_has_key(db: arango.StandardDatabase, id: str, key: str) -> bool: +def aql_doc_has_key(db: StandardDatabase, id: str, key: str) -> bool: """Checks if a document has a key.""" - query = f"RETURN HAS(DOCUMENT(@id), @key)" + query = "RETURN HAS(DOCUMENT(@id), @key)" bind_vars = {"id": id, "key": key} - return aql_single(db, query, bind_vars) + result = aql_single(db, query, bind_vars) + return bool(result) if result is not None else False -def aql_doc_get_key(db: arango.StandardDatabase, id: str, key: str) -> Any: +def aql_doc_get_key(db: StandardDatabase, id: str, key: str) -> Any: """Gets a key from a document.""" - query = f"RETURN DOCUMENT(@id).@key" + query = "RETURN DOCUMENT(@id).@key" bind_vars = {"id": id, "key": key} return aql_single(db, query, bind_vars) -def aql_doc_get_keys(db: arango.StandardDatabase, id: str) -> list[str]: +def aql_doc_get_keys(db: StandardDatabase, id: str) -> list[str]: """Gets the keys of a document.""" - query = f"RETURN ATTRIBUTES(DOCUMENT(@id))" + query = "RETURN ATTRIBUTES(DOCUMENT(@id))" bind_vars = {"id": id} - return aql_single(db, query, bind_vars) + result = aql_single(db, query, bind_vars) + return list(result) if result is not None else [] -def aql_doc_get_length(db: arango.StandardDatabase, id: str) -> int: +def aql_doc_get_length(db: StandardDatabase, id: str) -> int: """Gets the length of a document.""" - query = f"RETURN LENGTH(DOCUMENT(@id))" + query = "RETURN LENGTH(DOCUMENT(@id))" bind_vars = {"id": id} - return aql_single(db, query, bind_vars) + result = aql_single(db, query, bind_vars) + return int(result) if result is not None else 0 def aql_edge_exists( - db: arango.StandardDatabase, + db: StandardDatabase, src_node_id: str, dst_node_id: str, graph_name: str, direction: str, -): +) -> bool | None: return aql_edge( db, src_node_id, @@ -210,12 +223,12 @@ def aql_edge_exists( def aql_edge_get( - db: arango.StandardDatabase, + db: StandardDatabase, src_node_id: str, dst_node_id: str, graph_name: str, direction: str, -): +) -> Any | None: # TODO: need the use of DISTINCT return_clause = "DISTINCT e" if direction == "ANY" else "e" return aql_edge( @@ -229,15 +242,15 @@ def aql_edge_get( def aql_edge_id( - db: arango.StandardDatabase, + db: StandardDatabase, src_node_id: str, dst_node_id: str, graph_name: str, direction: str, -): +) -> str | None: # TODO: need the use of DISTINCT return_clause = "DISTINCT e._id" if direction == "ANY" else "e._id" - return aql_edge( + result = aql_edge( db, src_node_id, dst_node_id, @@ -246,21 +259,26 @@ def aql_edge_id( return_clause=return_clause, ) + return str(result) if result is not None else None + def aql_edge( - db: arango.StandardDatabase, + db: StandardDatabase, src_node_id: str, dst_node_id: str, graph_name: str, direction: str, return_clause: str, -): +) -> Any | None: if direction == "INBOUND": - filter_clause = f"e._from == @dst_node_id" + filter_clause = "e._from == @dst_node_id" elif direction == "OUTBOUND": - filter_clause = f"e._to == @dst_node_id" + filter_clause = "e._to == @dst_node_id" elif direction == "ANY": - filter_clause = f"(e._from == @dst_node_id AND e._to == @src_node_id) OR (e._to == @dst_node_id AND e._from == @src_node_id)" + filter_clause = """ + (e._from == @dst_node_id AND e._to == @src_node_id) + OR (e._to == @dst_node_id AND e._from == @src_node_id) + """ else: raise InvalidTraversalDirection(f"Invalid direction: {direction}") @@ -280,80 +298,87 @@ def aql_edge( def aql_fetch_data( - db: arango.StandardDatabase, + db: StandardDatabase, collections: list[str], data: str, default: Any, - is_edge: bool = True, -) -> dict[str, Any] | list[tuple[str, str, Any]]: - if is_edge: - items = [] - for collection in collections: - query = f""" - LET result = ( - FOR doc IN `{collection}` - RETURN [doc._from, doc._to, doc.@data or @default] - ) - - RETURN result - """ - - bind_vars = {"data": data, "default": default} +) -> dict[str, Any]: + items = {} + for collection in collections: + query = """ + LET result = ( + FOR doc IN @@collection + RETURN {[doc._id]: doc.@data or @default} + ) - items.extend(aql_single(db, query, bind_vars)) + RETURN MERGE(result) + """ - return items + bind_vars = {"data": data, "default": default, "@collection": collection} + result = aql_single(db, query, bind_vars) + items.update(result if result is not None else {}) - else: - return_clause = f"{{[doc._id]: doc.@data or @default}}" + return items - items = {} - for collection in collections: - query = f""" - LET result = ( - FOR doc IN `{collection}` - RETURN {return_clause} - ) - RETURN MERGE(result) - """ +def aql_fetch_data_edge( + db: StandardDatabase, + collections: list[str], + data: str, + default: Any, +) -> list[tuple[str, str, Any]]: + items = [] + for collection in collections: + query = """ + LET result = ( + FOR doc IN @@collection + RETURN [doc._from, doc._to, doc.@data or @default] + ) - bind_vars = {"data": data, "default": default} + RETURN result + """ - items.update(aql_single(db, query, bind_vars)) + bind_vars = {"data": data, "default": default, "@collection": collection} + result = aql_single(db, query, bind_vars) + items.extend(result if result is not None else []) - return items.items() + return items def doc_update( - db: arango.StandardDatabase, id: str, data: dict[str, Any], **kwargs + db: StandardDatabase, id: str, data: dict[str, Any], **kwargs: Any ) -> None: """Updates a document in the collection.""" db.update_document({**data, "_id": id}, keep_none=False, silent=True, **kwargs) -def doc_delete(db: arango.StandardDatabase, id: str, **kwargs) -> None: +def doc_delete(db: StandardDatabase, id: str, **kwargs: Any) -> None: """Deletes a document from the collection.""" db.delete_document(id, silent=True, **kwargs) def doc_insert( - db: arango.StandardDatabase, + db: StandardDatabase, collection: str, id: str, data: dict[str, Any] = {}, - **kwargs, -) -> dict[str, Any] | bool: + **kwargs: Any, +) -> dict[str, Any]: """Inserts a document into a collection.""" - return db.insert_document(collection, {**data, "_id": id}, overwrite=True, **kwargs) + result: dict[str, Any] = db.insert_document( + collection, {**data, "_id": id}, overwrite=True, **kwargs + ) + + return result def doc_get_or_insert( - db: arango.StandardDatabase, collection: str, id: str, **kwargs + db: StandardDatabase, collection: str, id: str, **kwargs: Any ) -> dict[str, Any]: """Loads a document if existing, otherwise inserts it & returns it.""" if db.has_document(id): - return db.document(id) + result: dict[str, Any] = db.document(id) + return result return doc_insert(db, collection, id, **kwargs) diff --git a/nx_arangodb/classes/graph.py b/nx_arangodb/classes/graph.py index 81b519ca..2372d734 100644 --- a/nx_arangodb/classes/graph.py +++ b/nx_arangodb/classes/graph.py @@ -1,8 +1,10 @@ import os from functools import cached_property -from typing import Callable, ClassVar +from typing import Any, Callable, ClassVar import networkx as nx +import numpy as np +import numpy.typing as npt from adbnx_adapter import ADBNX_Adapter from arango import ArangoClient from arango.cursor import Cursor @@ -10,7 +12,7 @@ from arango.exceptions import ServerConnectionError import nx_arangodb as nxadb -from nx_arangodb.exceptions import * +from nx_arangodb.exceptions import DatabaseNotSet, GraphNameNotSet from nx_arangodb.logger import logger from .dict import ( @@ -23,7 +25,7 @@ ) from .reportviews import CustomEdgeView, CustomNodeView -networkx_api = nxadb.utils.decorators.networkx_class(nx.Graph) +networkx_api = nxadb.utils.decorators.networkx_class(nx.Graph) # type: ignore __all__ = ["Graph"] @@ -34,15 +36,15 @@ class Graph(nx.Graph): @classmethod def to_networkx_class(cls) -> type[nx.Graph]: - return nx.Graph + return nx.Graph # type: ignore[no-any-return] def __init__( self, graph_name: str | None = None, default_node_type: str = "nxadb_node", edge_type_func: Callable[[str, str], str] = lambda u, v: f"{u}_to_{v}", - *args, - **kwargs, + *args: Any, + **kwargs: Any, ): self.__db = None self.__graph_name = None @@ -63,9 +65,9 @@ def __init__( self.use_nx_cache = True self.use_coo_cache = True - self.src_indices = None - self.dst_indices = None - self.vertex_ids_to_index = None + self.src_indices: npt.NDArray[np.int64] | None = None + self.dst_indices: npt.NDArray[np.int64] | None = None + self.vertex_ids_to_index: dict[str, int] | None = None self.default_node_type = default_node_type self.edge_type_func = edge_type_func @@ -83,7 +85,7 @@ def __init__( elif self.__graph_name and incoming_graph_data: if not isinstance(incoming_graph_data, nx.Graph): - m = f"Type of **incoming_graph_data** not supported yet ({type(incoming_graph_data)})" + m = f"Type of **incoming_graph_data** not supported yet ({type(incoming_graph_data)})" # noqa: E501 raise NotImplementedError(m) adapter = ADBNX_Adapter(self.db) @@ -179,7 +181,7 @@ def graph_exists(self) -> bool: # Setters # ########### - def __set_db(self, db: StandardDatabase | None = None): + def __set_db(self, db: StandardDatabase | None = None) -> None: if db is not None: if not isinstance(db, StandardDatabase): m = "arango.database.StandardDatabase" @@ -208,11 +210,10 @@ def __set_db(self, db: StandardDatabase | None = None): self.__db = None logger.warning(f"Could not connect to the database: {e}") - def __set_graph_name(self, graph_name: str | None = None): + def __set_graph_name(self, graph_name: str | None = None) -> None: if self.__db is None: - raise DatabaseNotSet( - "Cannot set graph name without setting the database first" - ) + m = "Cannot set graph name without setting the database first" + raise DatabaseNotSet(m) if graph_name is None: self.__graph_exists = False @@ -232,7 +233,7 @@ def __set_graph_name(self, graph_name: str | None = None): #################### # TODO: proper subgraphing! - def aql(self, query: str, bind_vars: dict | None = None, **kwargs) -> Cursor: + def aql(self, query: str, bind_vars: dict[str, Any] = {}, **kwargs: Any) -> Cursor: return nxadb.classes.function.aql(self.db, query, bind_vars, **kwargs) # NOTE: Ignore this for now @@ -273,8 +274,9 @@ def pull(self, load_node_dict=True, load_adj_dict=True, load_coo=True): and replace it with the edge data from the database. Comes with a remote reference to the database. <--- TODO: Should we paramaterize this? :type load_adj_dict: bool - :param load_coo: Load the COO representation. If False, the src & dst indices will be empty, - along with the node-ID-to-index mapping. Used for nx-cuGraph compatibility. + :param load_coo: Load the COO representation. If False, the src & dst + indices will be empty, along with the node-ID-to-index mapping. + Used for nx-cuGraph compatibility. :type load_coo: bool """ node_dict, adj_dict, src_indices, dst_indices, vertex_ids_to_indices = ( diff --git a/nx_arangodb/classes/multidigraph.py b/nx_arangodb/classes/multidigraph.py index cabc1b93..07c7d4c1 100644 --- a/nx_arangodb/classes/multidigraph.py +++ b/nx_arangodb/classes/multidigraph.py @@ -5,7 +5,7 @@ import nx_arangodb as nxadb from nx_arangodb.logger import logger -networkx_api = nxadb.utils.decorators.networkx_class(nx.MultiDiGraph) +networkx_api = nxadb.utils.decorators.networkx_class(nx.MultiDiGraph) # type: ignore __all__ = ["MultiDiGraph"] @@ -16,10 +16,10 @@ class MultiDiGraph(nx.MultiDiGraph): @classmethod def to_networkx_class(cls) -> type[nx.MultiDiGraph]: - return nx.MultiDiGraph + return nx.MultiDiGraph # type: ignore[no-any-return] def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.graph_exists = False - m = "nxadb.MultiDiGraph has not been implemented yet. This is a pass-through subclass of nx.MultiDiGraph for now." + m = "nxadb.MultiDiGraph has not been implemented yet. This is a pass-through subclass of nx.MultiDiGraph for now." # noqa logger.warning(m) diff --git a/nx_arangodb/classes/multigraph.py b/nx_arangodb/classes/multigraph.py index 80dc47fa..0706001c 100644 --- a/nx_arangodb/classes/multigraph.py +++ b/nx_arangodb/classes/multigraph.py @@ -5,7 +5,7 @@ import nx_arangodb as nxadb from nx_arangodb.logger import logger -networkx_api = nxadb.utils.decorators.networkx_class(nx.MultiGraph) +networkx_api = nxadb.utils.decorators.networkx_class(nx.MultiGraph) # type: ignore __all__ = ["MultiGraph"] @@ -16,10 +16,10 @@ class MultiGraph(nx.MultiGraph): @classmethod def to_networkx_class(cls) -> type[nx.MultiGraph]: - return nx.MultiGraph + return nx.MultiGraph # type: ignore[no-any-return] def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.graph_exists = False - m = "nxadb.MultiGraph has not been implemented yet. This is a pass-through subclass of nx.MultiGraph for now." + m = "nxadb.MultiGraph has not been implemented yet. This is a pass-through subclass of nx.MultiGraph for now." # noqa logger.warning(m) diff --git a/nx_arangodb/classes/reportviews.py b/nx_arangodb/classes/reportviews.py index 0dd4655f..03763255 100644 --- a/nx_arangodb/classes/reportviews.py +++ b/nx_arangodb/classes/reportviews.py @@ -1,3 +1,8 @@ +""" +An override of the NodeView, NodeDataView, EdgeView, and EdgeDataView classes +to allow for custom data filtering in the database instead of in Python. +""" + from __future__ import annotations import networkx as nx diff --git a/nx_arangodb/convert.py b/nx_arangodb/convert.py index 86978899..ef3aa86d 100644 --- a/nx_arangodb/convert.py +++ b/nx_arangodb/convert.py @@ -2,7 +2,7 @@ import itertools import time -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any import networkx as nx @@ -36,79 +36,23 @@ def from_networkx( graph: nx.Graph, - edge_attrs: AttrKey | dict[AttrKey, EdgeValue | None] | None = None, - edge_dtypes: Dtype | dict[AttrKey, Dtype | None] | None = None, - *, - node_attrs: AttrKey | dict[AttrKey, NodeValue | None] | None = None, - node_dtypes: Dtype | dict[AttrKey, Dtype | None] | None = None, - preserve_all_attrs: bool = False, - preserve_edge_attrs: bool = False, - preserve_node_attrs: bool = False, - preserve_graph_attrs: bool = False, + *args: Any, as_directed: bool = False, - name: str | None = None, - graph_name: str | None = None, -) -> nxadb.Graph: - """Convert a networkx graph to nx_arangodb graph; can convert all attributes. + **kwargs: Any, + # name: str | None = None, + # graph_name: str | None = None, +) -> nxadb.Graph | nxadb.DiGraph: + """Convert a networkx graph to nx_arangodb graph. TEMPORARY ASSUMPTION: The nx_arangodb Graph is a subclass of networkx Graph. Therefore, I'm going to assume that we _should_ be able instantiate an - nx_arangodb Graph using the **incoming_graph_data** parameter. Let's try it! + nx_arangodb Graph using the **incoming_graph_data** parameter. + + TODO: The actual implementation should store the graph in ArangoDB. Parameters ---------- G : networkx.Graph - edge_attrs : str or dict, optional - Dict that maps edge attributes to default values if missing in ``G``. - If None, then no edge attributes will be converted. - If default value is None, then missing values are handled with a mask. - A default value of ``nxcg.convert.REQUIRED`` or ``...`` indicates that - all edges have data for this attribute, and raise `KeyError` if not. - For convenience, `edge_attrs` may be a single attribute with default 1; - for example ``edge_attrs="weight"``. - edge_dtypes : dtype or dict, optional - node_attrs : str or dict, optional - Dict that maps node attributes to default values if missing in ``G``. - If None, then no node attributes will be converted. - If default value is None, then missing values are handled with a mask. - A default value of ``nxcg.convert.REQUIRED`` or ``...`` indicates that - all edges have data for this attribute, and raise `KeyError` if not. - For convenience, `node_attrs` may be a single attribute with no default; - for example ``node_attrs="weight"``. - node_dtypes : dtype or dict, optional - preserve_all_attrs : bool, default False - If True, then equivalent to setting preserve_edge_attrs, preserve_node_attrs, - and preserve_graph_attrs to True. - preserve_edge_attrs : bool, default False - Whether to preserve all edge attributes. - preserve_node_attrs : bool, default False - Whether to preserve all node attributes. - preserve_graph_attrs : bool, default False - Whether to preserve all graph attributes. - as_directed : bool, default False - If True, then the returned graph will be directed regardless of input. - If False, then the returned graph type is determined by input graph. - name : str, optional - The name of the algorithm when dispatched from networkx. - graph_name : str, optional - The name of the graph argument geing converted when dispatched from networkx. - - Returns - ------- - nx_arangodb.Graph - - Notes - ----- - For optimal performance, be as specific as possible about what is being converted: - - 1. Do you need edge values? Creating a graph with just the structure is the fastest. - 2. Do you know the edge attribute(s) you need? Specify with `edge_attrs`. - 3. Do you know the default values? Specify with ``edge_attrs={weight: default}``. - 4. Do you know if all edges have values? Specify with ``edge_attrs={weight: ...}``. - 5. Do you know the dtype of attributes? Specify with `edge_dtypes`. - - Conversely, using ``preserve_edge_attrs=True`` or ``preserve_all_attrs=True`` are - the slowest, but are also the most flexible and generic. See Also -------- @@ -140,23 +84,18 @@ def from_networkx( return klass(incoming_graph_data=graph) -def to_networkx(G: nxadb.Graph, *, sort_edges: bool = False) -> nx.Graph: +def to_networkx(G: nxadb.Graph, *args: Any, **kwargs: Any) -> nx.Graph: """Convert a nx_arangodb graph to networkx graph. All edge and node attributes and ``G.graph`` properties are converted. TEMPORARY ASSUMPTION: The nx_arangodb Graph is a subclass of networkx Graph. Therefore, I'm going to assume that we _should_ be able instantiate an - nx Graph using the **incoming_graph_data** parameter. Let's try it! + nx Graph using the **incoming_graph_data** parameter. Parameters ---------- G : nx_arangodb.Graph - sort_edges : bool, default False - Whether to sort the edge data of the input graph by (src, dst) indices - before converting. This can be useful to convert to networkx graphs - that iterate over edges consistently since edges are stored in dicts - in the order they were added. Returns ------- @@ -176,7 +115,7 @@ def to_networkx(G: nxadb.Graph, *, sort_edges: bool = False) -> nx.Graph: def from_networkx_arangodb( G: nxadb.Graph | nxadb.DiGraph, pull_graph: bool -) -> nxadb.Graph | nxadb.DiGraph: +) -> nx.Graph | nx.DiGraph: logger.debug(f"from_networkx_arangodb for {G.__class__.__name__}") if not isinstance(G, (nxadb.Graph, nxadb.DiGraph)): @@ -188,7 +127,7 @@ def from_networkx_arangodb( if not pull_graph: if isinstance(G, nxadb.DiGraph): - m = "nx_arangodb.DiGraph has no CRUD Support yet. Cannot rely on remote connection." + m = "nx_arangodb.DiGraph has no CRUD Support yet. Cannot rely on remote connection." # noqa: E501 raise NotImplementedError(m) logger.debug("graph exists, but not pulling. relying on remote connection...") @@ -218,9 +157,9 @@ def from_networkx_arangodb( logger.debug("creating nx graph from loaded ArangoDB data...") print("Creating nx graph from loaded ArangoDB data...") start_time = time.time() - result = nx.convert.from_dict_of_dicts( + result: nx.Graph = nx.convert.from_dict_of_dicts( adj_dict, - create_using=G.__class__, + create_using=G.to_networkx_class(), multigraph_input=G.is_multigraph(), ) @@ -229,47 +168,32 @@ def from_networkx_arangodb( end_time = time.time() print(f"NX Graph creation took {end_time - start_time}") - # TODO: Could we just get away with: - # G._node = node_dict - # G._adj = adj_dict - # ? - return result except Exception as err: raise nx.NetworkXError("Input is not a correct NetworkX graph.") from err -def _to_nxadb_graph( - G, - edge_attr: AttrKey | None = None, - edge_default: EdgeValue | None = 1, - edge_dtype: Dtype | None = None, +def _to_nx_graph( + G: Any, pull_graph: bool = True, -) -> nxadb.Graph | nxadb.DiGraph: - """Ensure that input type is a nx_arangodb graph, and convert if necessary.""" - logger.debug(f"_to_nxadb_graph for {G.__class__.__name__}") +) -> nx.Graph | nx.DiGraph: + """Ensure that input type is an nx graph, and convert if necessary.""" + logger.debug(f"_to_nx_graph for {G.__class__.__name__}") if isinstance(G, (nxadb.Graph, nxadb.DiGraph)): return from_networkx_arangodb(G, pull_graph) if isinstance(G, nx.Graph): - return from_networkx( - G, {edge_attr: edge_default} if edge_attr is not None else None, edge_dtype - ) + return G + # TODO: handle cugraph.Graph raise TypeError if GPU_ENABLED: - def _to_nxcg_graph( - G, - edge_attr: AttrKey | None = None, - edge_default: EdgeValue | None = 1, - edge_dtype: Dtype | None = None, - as_directed: bool = False, - ) -> nxcg.Graph | nxcg.DiGraph: + def _to_nxcg_graph(G: Any, as_directed: bool = False) -> nxcg.Graph | nxcg.DiGraph: """Ensure that input type is a nx_cugraph graph, and convert if necessary.""" logger.debug(f"_to_nxcg_graph for {G.__class__.__name__}") @@ -294,21 +218,8 @@ def _to_nxcg_graph( "nxadb.MultiGraph not yet supported for _to_nxcg_graph()" ) - # If G is a networkx graph, or is a nxadb graph that doesn't point to an "existing" - # ArangoDB graph, then we just treat it as a normal networkx graph & - # convert it to nx_cugraph. - # TODO: Need to revisit the "existing" ArangoDB graph condition... - if isinstance(G, nx.Graph): - logger.debug("converting networkx graph to nx_cugraph graph") - return nxcg.convert.from_networkx( - G, - {edge_attr: edge_default} if edge_attr is not None else None, - edge_dtype, - as_directed=as_directed, - ) - # TODO: handle cugraph.Graph - raise TypeError + raise TypeError(f"Expected nx_arangodb.Graph or nx.Graph; got {type(G)}") def nxcg_from_networkx_arangodb( G: nxadb.Graph | nxadb.DiGraph, as_directed: bool = False @@ -364,7 +275,7 @@ def nxcg_from_networkx_arangodb( print(f"COO (NumPy) -> COO (CuPy) took {end_time - start_time}") logger.debug("creating nx_cugraph graph from COO data...") - print(f"creating nx_cugraph graph from COO data...") + print("creating nx_cugraph graph from COO data...") start_time = time.time() rv = klass.from_coo( N=N, @@ -380,12 +291,6 @@ def nxcg_from_networkx_arangodb( else: - def _to_nxcg_graph( - G, - edge_attr: AttrKey | None = None, - edge_default: EdgeValue | None = 1, - edge_dtype: Dtype | None = None, - as_directed: bool = False, - ) -> nxcg.Graph | nxcg.DiGraph: + def _to_nxcg_graph(G: Any, as_directed: bool = False) -> nxcg.Graph | nxcg.DiGraph: m = "nx-cugraph is not installed; cannot convert to nx-cugraph graph" raise NotImplementedError(m) diff --git a/nx_arangodb/interface.py b/nx_arangodb/interface.py index 3acf4ecc..ec4a7f99 100644 --- a/nx_arangodb/interface.py +++ b/nx_arangodb/interface.py @@ -2,6 +2,7 @@ import os import sys +from typing import Any import networkx as nx @@ -11,19 +12,17 @@ class BackendInterface: # Required conversions @staticmethod - def convert_from_nx(graph, *args, edge_attrs=None, weight=None, **kwargs): - if weight is not None: - # MAINT: networkx 3.0, 3.1 - # For networkx 3.0 and 3.1 compatibility - if edge_attrs is not None: - raise TypeError( - "edge_attrs and weight arguments should not both be given" - ) - edge_attrs = {weight: 1} - return nxadb.from_networkx(graph, *args, edge_attrs=edge_attrs, **kwargs) + def convert_from_nx( + graph: Any, *args: Any, **kwargs: Any + ) -> nxadb.Graph | nxadb.DiGraph: + return nxadb.from_networkx(graph, *args, **kwargs) @staticmethod - def convert_to_nx(obj, *, name: str | None = None): + def convert_to_nx( + obj: nx.Graph | nx.DiGraph | nxadb.Graph | nxadb.DiGraph, + *, + name: str | None = None, + ) -> nx.Graph | nx.DiGraph: if isinstance(obj, nxadb.Graph): return nxadb.to_networkx(obj) return obj diff --git a/nx_arangodb/logger.py b/nx_arangodb/logger.py index a69ea0b2..50e2e7ce 100644 --- a/nx_arangodb/logger.py +++ b/nx_arangodb/logger.py @@ -8,7 +8,7 @@ handler = logging.StreamHandler() formatter = logging.Formatter( - f"[%(asctime)s] [%(levelname)s]: %(message)s", + "[%(asctime)s] [%(levelname)s]: %(message)s", "%H:%M:%S %z", ) diff --git a/nx_arangodb/typing.py b/nx_arangodb/typing.py index 9fccd9b5..cdd3cefa 100644 --- a/nx_arangodb/typing.py +++ b/nx_arangodb/typing.py @@ -7,6 +7,7 @@ import cupy as cp import numpy as np +import numpy.typing as npt AttrKey = TypeVar("AttrKey", bound=Hashable) EdgeKey = TypeVar("EdgeKey", bound=Hashable) @@ -20,4 +21,4 @@ class any_ndarray: def __class_getitem__(cls, item): - return cp.ndarray[item] | np.ndarray[item] + return cp.ndarray[item] | npt.NDArray[item] diff --git a/nx_arangodb/utils/decorators.py b/nx_arangodb/utils/decorators.py index 783b41f5..4f773d43 100644 --- a/nx_arangodb/utils/decorators.py +++ b/nx_arangodb/utils/decorators.py @@ -1,3 +1,4 @@ +# type: ignore # Copied from nx-cugraph from __future__ import annotations diff --git a/nx_arangodb/utils/misc.py b/nx_arangodb/utils/misc.py index 55e339d3..c71661f4 100644 --- a/nx_arangodb/utils/misc.py +++ b/nx_arangodb/utils/misc.py @@ -13,24 +13,10 @@ if TYPE_CHECKING: # import nx_cugraph as nxcg - from ..typing import Dtype, EdgeKey - -try: - from itertools import pairwise # Python >=3.10 -except ImportError: - - def pairwise(it): - it = iter(it) - for prev in it: - for cur in it: - yield (prev, cur) - prev = cur - + from ..typing import Dtype, EdgeKey # noqa __all__ = [ "index_dtype", - "_seed_to_int", - "_get_int_dtype", "_dtype_param", ] @@ -44,76 +30,3 @@ def pairwise(it): "in the algorithm. If None, then dtype is determined by the edge values." ), } - - -def _seed_to_int(seed: int | Random | None) -> int: - """Handle any valid seed argument and convert it to an int if necessary.""" - if seed is None: - return - if isinstance(seed, Random): - return seed.randint(0, sys.maxsize) - return op.index(seed) # Ensure seed is integral - - -def _get_int_dtype( - val: SupportsIndex, *, signed: bool | None = None, unsigned: bool | None = None -): - """Determine the smallest integer dtype that can store the integer ``val``. - - If signed or unsigned are unspecified, then signed integers are preferred - unless the value can be represented by a smaller unsigned integer. - - Raises - ------ - ValueError : If the value cannot be represented with an int dtype. - """ - # This is similar in spirit to `np.min_scalar_type` - if signed is not None: - if unsigned is not None and (not signed) is (not unsigned): - raise TypeError( - f"signed (={signed}) and unsigned (={unsigned}) keyword arguments " - "are incompatible." - ) - signed = bool(signed) - unsigned = not signed - elif unsigned is not None: - unsigned = bool(unsigned) - signed = not unsigned - - val = op.index(val) # Ensure val is integral - if val < 0: - if unsigned: - raise ValueError(f"Value is incompatible with unsigned int: {val}.") - signed = True - unsigned = False - - if signed is not False: - # Number of bytes (and a power of two) - signed_nbytes = (val + (val < 0)).bit_length() // 8 + 1 - signed_nbytes = next( - filter( - signed_nbytes.__le__, - itertools.accumulate(itertools.repeat(2), op.mul, initial=1), - ) - ) - if unsigned is not False: - # Number of bytes (and a power of two) - unsigned_nbytes = (val.bit_length() + 7) // 8 - unsigned_nbytes = next( - filter( - unsigned_nbytes.__le__, - itertools.accumulate(itertools.repeat(2), op.mul, initial=1), - ) - ) - if signed is None and unsigned is None: - # Prefer signed int if same size - signed = signed_nbytes <= unsigned_nbytes - - if signed: - dtype_string = f"i{signed_nbytes}" - else: - dtype_string = f"u{unsigned_nbytes}" - try: - return np.dtype(dtype_string) - except TypeError as exc: - raise ValueError("Value is too large to store as integer: {val}") from exc diff --git a/pyproject.toml b/pyproject.toml index 6ee4c0bc..559566d5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,7 +38,7 @@ dependencies = [ ] [project.optional-dependencies] -test = [ +dev = [ "packaging>=21", "pandas", "pytest", @@ -47,6 +47,11 @@ test = [ "pytest-mpl", "pytest-xdist", "scipy", + "black", + "flake8", + "Flake8-pyproject", + "isort", + "mypy", ] gpu = [ "nx-cugraph-cu12 @ https://pypi.nvidia.com" @@ -100,6 +105,24 @@ extend_skip_glob = [ "nx_arangodb/classes/__init__.py", ] +[tool.flake8] +max-line-length = 88 +extend-ignore = ["E203", "W503", "E251", "F401", "F403"] +exclude = [".git", ".idea", ".*_cache", "dist", "venv"] + +[tool.mypy] +strict = true +ignore_missing_imports = true +disallow_untyped_defs = false +disallow_untyped_calls = false +implicit_reexport = true +scripts_are_modules = true +follow_imports = "skip" +disallow_subclassing_any = false +disallow_untyped_decorators = false +exclude = ["venv", "build", "vendor/integration_api", "vendor/protodeps"] + + [tool.pytest.ini_options] minversion = "6.0" testpaths = "nx_arangodb/tests" @@ -151,105 +174,3 @@ exclude_lines = [ "raise AssertionError", "raise NotImplementedError", ] - -[tool.ruff] -# https://github.com/charliermarsh/ruff/ -line-length = 88 -target-version = "py39" -[tool.ruff.lint] -unfixable = [ - "F841", # unused-variable (Note: can leave useless expression) - "B905", # zip-without-explicit-strict (Note: prefer `zip(x, y, strict=True)`) -] -select = [ - "ALL", -] -external = [ - # noqa codes that ruff doesn't know about: https://github.com/charliermarsh/ruff#external -] -ignore = [ - # Would be nice to fix these - "D100", # Missing docstring in public module - "D101", # Missing docstring in public class - "D102", # Missing docstring in public method - "D103", # Missing docstring in public function - "D104", # Missing docstring in public package - "D105", # Missing docstring in magic method - - # Maybe consider - # "SIM300", # Yoda conditions are discouraged, use ... instead (Note: we're not this picky) - # "SIM401", # Use dict.get ... instead of if-else-block (Note: if-else better for coverage and sometimes clearer) - # "TRY004", # Prefer `TypeError` exception for invalid type (Note: good advice, but not worth the nuisance) - "B904", # Bare `raise` inside exception clause (like TRY200; sometimes okay) - "S310", # Audit URL open for permitted schemes (Note: we don't download URLs in normal usage) - - # Intentionally ignored - "A003", # Class attribute ... is shadowing a python builtin - "ANN101", # Missing type annotation for `self` in method - "ARG004", # Unused static method argument: `...` - "COM812", # Trailing comma missing - "D203", # 1 blank line required before class docstring (Note: conflicts with D211, which is preferred) - "D400", # First line should end with a period (Note: prefer D415, which also allows "?" and "!") - "F403", # `from .classes import *` used; unable to detect undefined names (Note: used to match networkx) - "N801", # Class name ... should use CapWords convention (Note:we have a few exceptions to this) - "N802", # Function name ... should be lowercase - "N803", # Argument name ... should be lowercase (Maybe okay--except in tests) - "N806", # Variable ... in function should be lowercase - "N807", # Function name should not start and end with `__` - "N818", # Exception name ... should be named with an Error suffix (Note: good advice) - "PLR0911", # Too many return statements - "PLR0912", # Too many branches - "PLR0913", # Too many arguments to function call - "PLR0915", # Too many statements - "PLR2004", # Magic number used in comparison, consider replacing magic with a constant variable - "PLW2901", # Outer for loop variable ... overwritten by inner assignment target (Note: good advice, but too strict) - "RET502", # Do not implicitly `return None` in function able to return non-`None` value - "RET503", # Missing explicit `return` at the end of function able to return non-`None` value - "RET504", # Unnecessary variable assignment before `return` statement - "S110", # `try`-`except`-`pass` detected, consider logging the exception (Note: good advice, but we don't log) - "S112", # `try`-`except`-`continue` detected, consider logging the exception (Note: good advice, but we don't log) - "SIM102", # Use a single `if` statement instead of nested `if` statements (Note: often necessary) - "SIM105", # Use contextlib.suppress(...) instead of try-except-pass (Note: try-except-pass is much faster) - "SIM108", # Use ternary operator ... instead of if-else-block (Note: if-else better for coverage and sometimes clearer) - "TRY003", # Avoid specifying long messages outside the exception class (Note: why?) - - # Ignored categories - "C90", # mccabe (Too strict, but maybe we should make things less complex) - "I", # isort (Should we replace `isort` with this?) - "ANN", # flake8-annotations - "BLE", # flake8-blind-except (Maybe consider) - "FBT", # flake8-boolean-trap (Why?) - "DJ", # flake8-django (We don't use django) - "EM", # flake8-errmsg (Perhaps nicer, but too much work) - # "ICN", # flake8-import-conventions (Doesn't allow "_" prefix such as `_np`) - "PYI", # flake8-pyi (We don't have stub files yet) - "SLF", # flake8-self (We can use our own private variables--sheesh!) - "TID", # flake8-tidy-imports (Rely on isort and our own judgement) - # "TCH", # flake8-type-checking - "ARG", # flake8-unused-arguments (Sometimes helpful, but too strict) - "TD", # flake8-todos (Maybe okay to add some of these) - "FIX", # flake8-fixme (like flake8-todos) - "ERA", # eradicate (We like code in comments!) - "PD", # pandas-vet (Intended for scripts that use pandas, not libraries) -] - -[tool.ruff.lint.per-file-ignores] -"__init__.py" = ["F401"] # Allow unused imports (w/o defining `__all__`) -# Allow assert, print, RNG, and no docstring -"nx_arangodb/**/tests/*py" = ["S101", "S311", "T201", "D103", "D100"] -"_nx_arangodb/__init__.py" = ["E501"] -"nx_arangodb/algorithms/**/*py" = ["D205", "D401"] # Allow flexible docstrings for algorithms -"scripts/update_readme.py" = ["INP001"] # Not part of a package - -[tool.ruff.lint.flake8-annotations] -mypy-init-return = true - -[tool.ruff.lint.flake8-builtins] -builtins-ignorelist = ["copyright"] - -[tool.ruff.lint.flake8-pytest-style] -fixture-parentheses = false -mark-parentheses = false - -[tool.ruff.lint.pydocstyle] -convention = "numpy" diff --git a/tests/conftest.py b/tests/conftest.py index 1ed72b36..ae2955ec 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -6,11 +6,14 @@ import pytest from adbnx_adapter import ADBNX_Adapter from arango import ArangoClient +from arango.database import StandardDatabase from nx_arangodb.logger import logger logger.setLevel(logging.INFO) +db: StandardDatabase + def pytest_addoption(parser: Any) -> None: parser.addoption("--url", action="store", default="http://localhost:8529") @@ -46,7 +49,8 @@ def pytest_configure(config: Any) -> None: @pytest.fixture(scope="function") -def load_graph(): +def load_graph() -> None: + global db db.delete_graph("KarateGraph", drop_collections=True, ignore_missing=True) adapter = ADBNX_Adapter(db) adapter.networkx_to_arangodb( diff --git a/tests/test.py b/tests/test.py index 9173bc94..32fe8a20 100644 --- a/tests/test.py +++ b/tests/test.py @@ -1,3 +1,5 @@ +from typing import Any + import networkx as nx import pytest @@ -6,7 +8,7 @@ from .conftest import db -def test_db(load_graph): +def test_db(load_graph: Any) -> None: assert db.version() @@ -50,15 +52,15 @@ def test_bc(load_graph): assert len(r_1) == len(r_5) try: - import phenolrs + import phenolrs # noqa except ModuleNotFoundError: - return + pytest.skip("phenolrs not installed") G_4 = nxadb.Graph(graph_name="KarateGraph") r_6 = nx.betweenness_centrality(G_4) G_5 = nxadb.Graph(graph_name="KarateGraph") - r_7 = nxadb.betweenness_centrality(G_5, pull_graph_on_cpu=False) + r_7 = nxadb.betweenness_centrality(G_5, pull_graph_on_cpu=False) # type: ignore G_6 = nxadb.DiGraph(graph_name="KarateGraph") r_8 = nx.betweenness_centrality(G_6) @@ -68,7 +70,7 @@ def test_bc(load_graph): assert len(r_6) == len(r_7) == len(r_8) == len(G_4) > 0 -def test_pagerank(load_graph): +def test_pagerank(load_graph: Any) -> None: G_1 = nx.karate_club_graph() G_2 = nxadb.Graph(incoming_graph_data=G_1) G_3 = nxadb.Graph(graph_name="KarateGraph") @@ -86,15 +88,15 @@ def test_pagerank(load_graph): assert len(r_1) == len(r_5) try: - import phenolrs + import phenolrs # noqa except ModuleNotFoundError: - return + pytest.skip("phenolrs not installed") G_4 = nxadb.Graph(graph_name="KarateGraph") r_6 = nx.pagerank(G_4) G_5 = nxadb.Graph(graph_name="KarateGraph") - r_7 = nxadb.pagerank(G_5, pull_graph_on_cpu=False) + r_7 = nxadb.pagerank(G_5, pull_graph_on_cpu=False) # type: ignore G_6 = nxadb.DiGraph(graph_name="KarateGraph") r_8 = nx.pagerank(G_6) @@ -102,7 +104,7 @@ def test_pagerank(load_graph): assert len(r_6) == len(r_7) == len(r_8) == len(G_4) > 0 -def test_louvain(load_graph): +def test_louvain(load_graph: Any) -> None: G_1 = nx.karate_club_graph() G_2 = nxadb.Graph(incoming_graph_data=G_1) G_3 = nxadb.Graph(graph_name="KarateGraph") @@ -120,15 +122,15 @@ def test_louvain(load_graph): assert len(r_5) > 0 try: - import phenolrs + import phenolrs # noqa except ModuleNotFoundError: - return + pytest.skip("phenolrs not installed") G_4 = nxadb.Graph(graph_name="KarateGraph") r_6 = nx.community.louvain_communities(G_4) G_5 = nxadb.Graph(graph_name="KarateGraph") - r_7 = nxadb.community.louvain_communities(G_5, pull_graph_on_cpu=False) + r_7 = nxadb.community.louvain_communities(G_5, pull_graph_on_cpu=False) # type: ignore # noqa G_6 = nxadb.DiGraph(graph_name="KarateGraph") r_8 = nx.community.louvain_communities(G_6) @@ -139,7 +141,7 @@ def test_louvain(load_graph): assert len(r_8) > 0 -def test_shortest_path(load_graph): +def test_shortest_path(load_graph: Any) -> None: G_1 = nxadb.Graph(graph_name="KarateGraph") G_2 = nxadb.DiGraph(graph_name="KarateGraph") @@ -154,7 +156,7 @@ def test_shortest_path(load_graph): assert r_3 != r_4 -def test_graph_nodes_crud(load_graph): +def test_graph_nodes_crud(load_graph: Any) -> None: G_1 = nxadb.Graph(graph_name="KarateGraph", foo="bar") G_2 = nx.Graph(nx.karate_club_graph()) @@ -254,7 +256,7 @@ def test_graph_nodes_crud(load_graph): assert not db.has_document(edge_id) -def test_graph_edges_crud(load_graph): +def test_graph_edges_crud(load_graph: Any) -> None: G_1 = nxadb.Graph(graph_name="KarateGraph") G_2 = nx.karate_club_graph() @@ -306,7 +308,7 @@ def test_graph_edges_crud(load_graph): result = list( db.aql.execute( - f"FOR e IN {G_1.default_edge_type} FILTER e._from == @src AND e._to == @dst RETURN e", + f"FOR e IN {G_1.default_edge_type} FILTER e._from == @src AND e._to == @dst RETURN e", # noqa bind_vars=bind_vars, ) ) @@ -315,7 +317,7 @@ def test_graph_edges_crud(load_graph): result = list( db.aql.execute( - f"FOR e IN {G_1.default_edge_type} FILTER e._from == @dst AND e._to == @src RETURN e", + f"FOR e IN {G_1.default_edge_type} FILTER e._from == @dst AND e._to == @src RETURN e", # noqa bind_vars=bind_vars, ) ) @@ -370,7 +372,7 @@ def test_graph_edges_crud(load_graph): assert G_1["person/2"]["person/1"]["weight"] == new_weight -def test_readme(load_graph): +def test_readme(load_graph: Any) -> None: G = nxadb.Graph(graph_name="KarateGraph") G_nx = nx.karate_club_graph() @@ -421,9 +423,9 @@ def test_readme(load_graph): assert len(G.edges) == len(G_nx.edges) -def test_digraph_nodes_crud(): +def test_digraph_nodes_crud() -> None: pytest.skip("Not implemented yet") -def test_digraph_edges_crud(): +def test_digraph_edges_crud() -> None: pytest.skip("Not implemented yet")