diff --git a/nx_arangodb/classes/graph.py b/nx_arangodb/classes/graph.py index dbab8b2b..81b519ca 100644 --- a/nx_arangodb/classes/graph.py +++ b/nx_arangodb/classes/graph.py @@ -3,6 +3,7 @@ from typing import Callable, ClassVar import networkx as nx +from adbnx_adapter import ADBNX_Adapter from arango import ArangoClient from arango.cursor import Cursor from arango.database import StandardDatabase @@ -38,15 +39,11 @@ def to_networkx_class(cls) -> type[nx.Graph]: def __init__( self, graph_name: str | None = None, - default_node_type: str = "nxadb_nodes", + default_node_type: str = "nxadb_node", edge_type_func: Callable[[str, str], str] = lambda u, v: f"{u}_to_{v}", *args, **kwargs, ): - if kwargs.get("incoming_graph_data") is not None and graph_name is not None: - m = "Cannot pass both **incoming_graph_data** and **graph_name** yet" - raise NotImplementedError(m) - self.__db = None self.__graph_name = None self.__graph_exists = False @@ -57,8 +54,8 @@ def __init__( self.auto_sync = True - self.graph_loader_parallelism = 20 - self.graph_loader_batch_size = 5000000 + self.graph_loader_parallelism = 10 + self.graph_loader_batch_size = 100000 # NOTE: Need to revisit these... # self.maintain_node_dict_cache = False @@ -74,11 +71,42 @@ def __init__( self.edge_type_func = edge_type_func self.default_edge_type = edge_type_func(default_node_type, default_node_type) + incoming_graph_data = kwargs.get("incoming_graph_data") if self.__graph_exists: self.adb_graph = self.db.graph(graph_name) self.__create_default_collections() self.__set_factory_methods() + if incoming_graph_data: + m = "Cannot pass both **incoming_graph_data** and **graph_name** yet if the already graph exists" # noqa: E501 + raise NotImplementedError(m) + + elif self.__graph_name and incoming_graph_data: + if not isinstance(incoming_graph_data, nx.Graph): + m = f"Type of **incoming_graph_data** not supported yet ({type(incoming_graph_data)})" + raise NotImplementedError(m) + + adapter = ADBNX_Adapter(self.db) + self.adb_graph = adapter.networkx_to_arangodb( + graph_name, + incoming_graph_data, + # TODO: Parameterize the edge definitions + # How can we work with a heterogenous **incoming_graph_data**? + edge_definitions=[ + { + "edge_collection": self.default_edge_type, + "from_vertex_collections": [self.default_node_type], + "to_vertex_collections": [self.default_node_type], + } + ], + ) + + self.__set_factory_methods() + self.__graph_exists = True + del kwargs["incoming_graph_data"] + + # self.__qa_chain = None + super().__init__(*args, **kwargs) ####################### @@ -207,6 +235,31 @@ def __set_graph_name(self, graph_name: str | None = None): def aql(self, query: str, bind_vars: dict | None = None, **kwargs) -> Cursor: return nxadb.classes.function.aql(self.db, query, bind_vars, **kwargs) + # NOTE: Ignore this for now + # def chat(self, prompt: str) -> str: + # if self.__qa_chain is None: + # if not self.__graph_exists: + # return "Could not initialize QA chain: Graph does not exist" + + # # try: + # from langchain.chains import ArangoGraphQAChain + # from langchain_community.graphs import ArangoGraph + # from langchain_openai import ChatOpenAI + + # model = ChatOpenAI(temperature=0, model_name="gpt-4") + + # self.__qa_chain = ArangoGraphQAChain.from_llm( + # llm=model, graph=ArangoGraph(self.db), verbose=True + # ) + + # # except Exception as e: + # # return f"Could not initialize QA chain: {e}" + + # self.__qa_chain.graph.set_schema() + # result = self.__qa_chain.invoke(prompt) + + # print(result["result"]) + def pull(self, load_node_dict=True, load_adj_dict=True, load_coo=True): """Load the graph from the ArangoDB database, and update existing graph object. diff --git a/tests/test.py b/tests/test.py index 8f8ab06f..9173bc94 100644 --- a/tests/test.py +++ b/tests/test.py @@ -10,97 +10,133 @@ def test_db(load_graph): assert db.version() +def test_load_graph_from_nxadb(): + graph_name = "KarateGraph" + + db.delete_graph(graph_name, drop_collections=True, ignore_missing=True) + + G_nx = nx.karate_club_graph() + + _ = nxadb.Graph( + graph_name=graph_name, + incoming_graph_data=G_nx, + default_node_type="person", + ) + + assert db.has_graph(graph_name) + assert db.has_collection("person") + assert db.has_collection("person_to_person") + assert db.collection("person").count() == len(G_nx.nodes) + assert db.collection("person_to_person").count() == len(G_nx.edges) + + db.delete_graph(graph_name, drop_collections=True) + + def test_bc(load_graph): G_1 = nx.karate_club_graph() G_2 = nxadb.Graph(incoming_graph_data=G_1) + G_3 = nxadb.Graph(graph_name="KarateGraph") r_1 = nx.betweenness_centrality(G_1) r_2 = nx.betweenness_centrality(G_2) r_3 = nx.betweenness_centrality(G_1, backend="arangodb") r_4 = nx.betweenness_centrality(G_2, backend="arangodb") + r_5 = nx.betweenness_centrality.orig_func(G_3) - assert len(r_1) == len(r_2) == len(r_3) == len(r_4) > 0 + assert len(r_1) == len(G_1) + assert r_1 == r_2 + assert r_2 == r_3 + assert r_3 == r_4 + assert len(r_1) == len(r_5) try: import phenolrs except ModuleNotFoundError: return - G_3 = nxadb.Graph(graph_name="KarateGraph") - r_5 = nx.betweenness_centrality(G_3) - G_4 = nxadb.Graph(graph_name="KarateGraph") - r_6 = nxadb.betweenness_centrality(G_4, pull_graph_on_cpu=False) + r_6 = nx.betweenness_centrality(G_4) - G_5 = nxadb.DiGraph(graph_name="KarateGraph") - r_7 = nx.betweenness_centrality(G_5) + G_5 = nxadb.Graph(graph_name="KarateGraph") + r_7 = nxadb.betweenness_centrality(G_5, pull_graph_on_cpu=False) - # assert r_5 == r_6 # this is acting strange. I need to revisit - assert r_6 == r_7 - assert len(r_5) == len(r_6) == len(r_7) > 0 + G_6 = nxadb.DiGraph(graph_name="KarateGraph") + r_8 = nx.betweenness_centrality(G_6) + + # assert r_6 == r_7 # this is acting strange. I need to revisit + assert r_7 == r_8 + assert len(r_6) == len(r_7) == len(r_8) == len(G_4) > 0 def test_pagerank(load_graph): G_1 = nx.karate_club_graph() - G_2 = nxadb.Graph(incoming_graph_data=G_1) + G_3 = nxadb.Graph(graph_name="KarateGraph") r_1 = nx.pagerank(G_1) r_2 = nx.pagerank(G_2) r_3 = nx.pagerank(G_1, backend="arangodb") r_4 = nx.pagerank(G_2, backend="arangodb") + r_5 = nx.pagerank.orig_func(G_3) - assert len(r_1) == len(r_2) == len(r_3) == len(r_4) > 0 + assert len(r_1) == len(G_1) + assert r_1 == r_2 + assert r_2 == r_3 + assert r_3 == r_4 + assert len(r_1) == len(r_5) try: import phenolrs except ModuleNotFoundError: return - G_3 = nxadb.Graph(graph_name="KarateGraph") - r_5 = nx.pagerank(G_3) - G_4 = nxadb.Graph(graph_name="KarateGraph") - r_6 = nxadb.pagerank(G_4, pull_graph_on_cpu=False) + r_6 = nx.pagerank(G_4) + + G_5 = nxadb.Graph(graph_name="KarateGraph") + r_7 = nxadb.pagerank(G_5, pull_graph_on_cpu=False) - G_5 = nxadb.DiGraph(graph_name="KarateGraph") - r_7 = nx.pagerank(G_5) + G_6 = nxadb.DiGraph(graph_name="KarateGraph") + r_8 = nx.pagerank(G_6) - assert len(r_5) == len(r_6) == len(r_7) == len(G_4) + assert len(r_6) == len(r_7) == len(r_8) == len(G_4) > 0 def test_louvain(load_graph): G_1 = nx.karate_club_graph() - G_2 = nxadb.Graph(incoming_graph_data=G_1) + G_3 = nxadb.Graph(graph_name="KarateGraph") r_1 = nx.community.louvain_communities(G_1) r_2 = nx.community.louvain_communities(G_2) r_3 = nx.community.louvain_communities(G_1, backend="arangodb") r_4 = nx.community.louvain_communities(G_2, backend="arangodb") + r_5 = nx.community.louvain_communities.orig_func(G_3) assert len(r_1) > 0 assert len(r_2) > 0 assert len(r_3) > 0 assert len(r_4) > 0 + assert len(r_5) > 0 try: import phenolrs except ModuleNotFoundError: return - G_3 = nxadb.Graph(graph_name="KarateGraph") - r_5 = nx.community.louvain_communities(G_3) - G_4 = nxadb.Graph(graph_name="KarateGraph") - r_6 = nxadb.community.louvain_communities(G_4, pull_graph_on_cpu=False) + r_6 = nx.community.louvain_communities(G_4) + + G_5 = nxadb.Graph(graph_name="KarateGraph") + r_7 = nxadb.community.louvain_communities(G_5, pull_graph_on_cpu=False) - G_5 = nxadb.DiGraph(graph_name="KarateGraph") - r_7 = nx.community.louvain_communities(G_5) + G_6 = nxadb.DiGraph(graph_name="KarateGraph") + r_8 = nx.community.louvain_communities(G_6) assert len(r_5) > 0 assert len(r_6) > 0 assert len(r_7) > 0 + assert len(r_8) > 0 def test_shortest_path(load_graph):