Skip to content

Commit 5ea9d99

Browse files
authored
new: invoke adbnx_adapter from nxadb.Graph constructor (#4)
* new: invoke `adbnx_adapter` from `nxadb.Graph` constructor * fix: conditional * fix: delete graph after creation * update graph_loader defaults * cleanup: test * cleanup
1 parent 9693659 commit 5ea9d99

File tree

2 files changed

+122
-33
lines changed

2 files changed

+122
-33
lines changed

nx_arangodb/classes/graph.py

Lines changed: 60 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from typing import Callable, ClassVar
44

55
import networkx as nx
6+
from adbnx_adapter import ADBNX_Adapter
67
from arango import ArangoClient
78
from arango.cursor import Cursor
89
from arango.database import StandardDatabase
@@ -38,15 +39,11 @@ def to_networkx_class(cls) -> type[nx.Graph]:
3839
def __init__(
3940
self,
4041
graph_name: str | None = None,
41-
default_node_type: str = "nxadb_nodes",
42+
default_node_type: str = "nxadb_node",
4243
edge_type_func: Callable[[str, str], str] = lambda u, v: f"{u}_to_{v}",
4344
*args,
4445
**kwargs,
4546
):
46-
if kwargs.get("incoming_graph_data") is not None and graph_name is not None:
47-
m = "Cannot pass both **incoming_graph_data** and **graph_name** yet"
48-
raise NotImplementedError(m)
49-
5047
self.__db = None
5148
self.__graph_name = None
5249
self.__graph_exists = False
@@ -57,8 +54,8 @@ def __init__(
5754

5855
self.auto_sync = True
5956

60-
self.graph_loader_parallelism = 20
61-
self.graph_loader_batch_size = 5000000
57+
self.graph_loader_parallelism = 10
58+
self.graph_loader_batch_size = 100000
6259

6360
# NOTE: Need to revisit these...
6461
# self.maintain_node_dict_cache = False
@@ -74,11 +71,42 @@ def __init__(
7471
self.edge_type_func = edge_type_func
7572
self.default_edge_type = edge_type_func(default_node_type, default_node_type)
7673

74+
incoming_graph_data = kwargs.get("incoming_graph_data")
7775
if self.__graph_exists:
7876
self.adb_graph = self.db.graph(graph_name)
7977
self.__create_default_collections()
8078
self.__set_factory_methods()
8179

80+
if incoming_graph_data:
81+
m = "Cannot pass both **incoming_graph_data** and **graph_name** yet if the already graph exists" # noqa: E501
82+
raise NotImplementedError(m)
83+
84+
elif self.__graph_name and incoming_graph_data:
85+
if not isinstance(incoming_graph_data, nx.Graph):
86+
m = f"Type of **incoming_graph_data** not supported yet ({type(incoming_graph_data)})"
87+
raise NotImplementedError(m)
88+
89+
adapter = ADBNX_Adapter(self.db)
90+
self.adb_graph = adapter.networkx_to_arangodb(
91+
graph_name,
92+
incoming_graph_data,
93+
# TODO: Parameterize the edge definitions
94+
# How can we work with a heterogenous **incoming_graph_data**?
95+
edge_definitions=[
96+
{
97+
"edge_collection": self.default_edge_type,
98+
"from_vertex_collections": [self.default_node_type],
99+
"to_vertex_collections": [self.default_node_type],
100+
}
101+
],
102+
)
103+
104+
self.__set_factory_methods()
105+
self.__graph_exists = True
106+
del kwargs["incoming_graph_data"]
107+
108+
# self.__qa_chain = None
109+
82110
super().__init__(*args, **kwargs)
83111

84112
#######################
@@ -207,6 +235,31 @@ def __set_graph_name(self, graph_name: str | None = None):
207235
def aql(self, query: str, bind_vars: dict | None = None, **kwargs) -> Cursor:
208236
return nxadb.classes.function.aql(self.db, query, bind_vars, **kwargs)
209237

238+
# NOTE: Ignore this for now
239+
# def chat(self, prompt: str) -> str:
240+
# if self.__qa_chain is None:
241+
# if not self.__graph_exists:
242+
# return "Could not initialize QA chain: Graph does not exist"
243+
244+
# # try:
245+
# from langchain.chains import ArangoGraphQAChain
246+
# from langchain_community.graphs import ArangoGraph
247+
# from langchain_openai import ChatOpenAI
248+
249+
# model = ChatOpenAI(temperature=0, model_name="gpt-4")
250+
251+
# self.__qa_chain = ArangoGraphQAChain.from_llm(
252+
# llm=model, graph=ArangoGraph(self.db), verbose=True
253+
# )
254+
255+
# # except Exception as e:
256+
# # return f"Could not initialize QA chain: {e}"
257+
258+
# self.__qa_chain.graph.set_schema()
259+
# result = self.__qa_chain.invoke(prompt)
260+
261+
# print(result["result"])
262+
210263
def pull(self, load_node_dict=True, load_adj_dict=True, load_coo=True):
211264
"""Load the graph from the ArangoDB database, and update existing graph object.
212265

tests/test.py

Lines changed: 62 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -10,97 +10,133 @@ def test_db(load_graph):
1010
assert db.version()
1111

1212

13+
def test_load_graph_from_nxadb():
14+
graph_name = "KarateGraph"
15+
16+
db.delete_graph(graph_name, drop_collections=True, ignore_missing=True)
17+
18+
G_nx = nx.karate_club_graph()
19+
20+
_ = nxadb.Graph(
21+
graph_name=graph_name,
22+
incoming_graph_data=G_nx,
23+
default_node_type="person",
24+
)
25+
26+
assert db.has_graph(graph_name)
27+
assert db.has_collection("person")
28+
assert db.has_collection("person_to_person")
29+
assert db.collection("person").count() == len(G_nx.nodes)
30+
assert db.collection("person_to_person").count() == len(G_nx.edges)
31+
32+
db.delete_graph(graph_name, drop_collections=True)
33+
34+
1335
def test_bc(load_graph):
1436
G_1 = nx.karate_club_graph()
1537
G_2 = nxadb.Graph(incoming_graph_data=G_1)
38+
G_3 = nxadb.Graph(graph_name="KarateGraph")
1639

1740
r_1 = nx.betweenness_centrality(G_1)
1841
r_2 = nx.betweenness_centrality(G_2)
1942
r_3 = nx.betweenness_centrality(G_1, backend="arangodb")
2043
r_4 = nx.betweenness_centrality(G_2, backend="arangodb")
44+
r_5 = nx.betweenness_centrality.orig_func(G_3)
2145

22-
assert len(r_1) == len(r_2) == len(r_3) == len(r_4) > 0
46+
assert len(r_1) == len(G_1)
47+
assert r_1 == r_2
48+
assert r_2 == r_3
49+
assert r_3 == r_4
50+
assert len(r_1) == len(r_5)
2351

2452
try:
2553
import phenolrs
2654
except ModuleNotFoundError:
2755
return
2856

29-
G_3 = nxadb.Graph(graph_name="KarateGraph")
30-
r_5 = nx.betweenness_centrality(G_3)
31-
3257
G_4 = nxadb.Graph(graph_name="KarateGraph")
33-
r_6 = nxadb.betweenness_centrality(G_4, pull_graph_on_cpu=False)
58+
r_6 = nx.betweenness_centrality(G_4)
3459

35-
G_5 = nxadb.DiGraph(graph_name="KarateGraph")
36-
r_7 = nx.betweenness_centrality(G_5)
60+
G_5 = nxadb.Graph(graph_name="KarateGraph")
61+
r_7 = nxadb.betweenness_centrality(G_5, pull_graph_on_cpu=False)
3762

38-
# assert r_5 == r_6 # this is acting strange. I need to revisit
39-
assert r_6 == r_7
40-
assert len(r_5) == len(r_6) == len(r_7) > 0
63+
G_6 = nxadb.DiGraph(graph_name="KarateGraph")
64+
r_8 = nx.betweenness_centrality(G_6)
65+
66+
# assert r_6 == r_7 # this is acting strange. I need to revisit
67+
assert r_7 == r_8
68+
assert len(r_6) == len(r_7) == len(r_8) == len(G_4) > 0
4169

4270

4371
def test_pagerank(load_graph):
4472
G_1 = nx.karate_club_graph()
45-
4673
G_2 = nxadb.Graph(incoming_graph_data=G_1)
74+
G_3 = nxadb.Graph(graph_name="KarateGraph")
4775

4876
r_1 = nx.pagerank(G_1)
4977
r_2 = nx.pagerank(G_2)
5078
r_3 = nx.pagerank(G_1, backend="arangodb")
5179
r_4 = nx.pagerank(G_2, backend="arangodb")
80+
r_5 = nx.pagerank.orig_func(G_3)
5281

53-
assert len(r_1) == len(r_2) == len(r_3) == len(r_4) > 0
82+
assert len(r_1) == len(G_1)
83+
assert r_1 == r_2
84+
assert r_2 == r_3
85+
assert r_3 == r_4
86+
assert len(r_1) == len(r_5)
5487

5588
try:
5689
import phenolrs
5790
except ModuleNotFoundError:
5891
return
5992

60-
G_3 = nxadb.Graph(graph_name="KarateGraph")
61-
r_5 = nx.pagerank(G_3)
62-
6393
G_4 = nxadb.Graph(graph_name="KarateGraph")
64-
r_6 = nxadb.pagerank(G_4, pull_graph_on_cpu=False)
94+
r_6 = nx.pagerank(G_4)
95+
96+
G_5 = nxadb.Graph(graph_name="KarateGraph")
97+
r_7 = nxadb.pagerank(G_5, pull_graph_on_cpu=False)
6598

66-
G_5 = nxadb.DiGraph(graph_name="KarateGraph")
67-
r_7 = nx.pagerank(G_5)
99+
G_6 = nxadb.DiGraph(graph_name="KarateGraph")
100+
r_8 = nx.pagerank(G_6)
68101

69-
assert len(r_5) == len(r_6) == len(r_7) == len(G_4)
102+
assert len(r_6) == len(r_7) == len(r_8) == len(G_4) > 0
70103

71104

72105
def test_louvain(load_graph):
73106
G_1 = nx.karate_club_graph()
74-
75107
G_2 = nxadb.Graph(incoming_graph_data=G_1)
108+
G_3 = nxadb.Graph(graph_name="KarateGraph")
76109

77110
r_1 = nx.community.louvain_communities(G_1)
78111
r_2 = nx.community.louvain_communities(G_2)
79112
r_3 = nx.community.louvain_communities(G_1, backend="arangodb")
80113
r_4 = nx.community.louvain_communities(G_2, backend="arangodb")
114+
r_5 = nx.community.louvain_communities.orig_func(G_3)
81115

82116
assert len(r_1) > 0
83117
assert len(r_2) > 0
84118
assert len(r_3) > 0
85119
assert len(r_4) > 0
120+
assert len(r_5) > 0
86121

87122
try:
88123
import phenolrs
89124
except ModuleNotFoundError:
90125
return
91126

92-
G_3 = nxadb.Graph(graph_name="KarateGraph")
93-
r_5 = nx.community.louvain_communities(G_3)
94-
95127
G_4 = nxadb.Graph(graph_name="KarateGraph")
96-
r_6 = nxadb.community.louvain_communities(G_4, pull_graph_on_cpu=False)
128+
r_6 = nx.community.louvain_communities(G_4)
129+
130+
G_5 = nxadb.Graph(graph_name="KarateGraph")
131+
r_7 = nxadb.community.louvain_communities(G_5, pull_graph_on_cpu=False)
97132

98-
G_5 = nxadb.DiGraph(graph_name="KarateGraph")
99-
r_7 = nx.community.louvain_communities(G_5)
133+
G_6 = nxadb.DiGraph(graph_name="KarateGraph")
134+
r_8 = nx.community.louvain_communities(G_6)
100135

101136
assert len(r_5) > 0
102137
assert len(r_6) > 0
103138
assert len(r_7) > 0
139+
assert len(r_8) > 0
104140

105141

106142
def test_shortest_path(load_graph):

0 commit comments

Comments
 (0)