Skip to content

new: invoke adbnx_adapter from nxadb.Graph constructor #4

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Jun 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 60 additions & 7 deletions nx_arangodb/classes/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import Callable, ClassVar

import networkx as nx
from adbnx_adapter import ADBNX_Adapter
from arango import ArangoClient
from arango.cursor import Cursor
from arango.database import StandardDatabase
Expand Down Expand Up @@ -38,15 +39,11 @@ def to_networkx_class(cls) -> type[nx.Graph]:
def __init__(
self,
graph_name: str | None = None,
default_node_type: str = "nxadb_nodes",
default_node_type: str = "nxadb_node",
edge_type_func: Callable[[str, str], str] = lambda u, v: f"{u}_to_{v}",
*args,
**kwargs,
):
if kwargs.get("incoming_graph_data") is not None and graph_name is not None:
m = "Cannot pass both **incoming_graph_data** and **graph_name** yet"
raise NotImplementedError(m)

self.__db = None
self.__graph_name = None
self.__graph_exists = False
Expand All @@ -57,8 +54,8 @@ def __init__(

self.auto_sync = True

self.graph_loader_parallelism = 20
self.graph_loader_batch_size = 5000000
self.graph_loader_parallelism = 10
self.graph_loader_batch_size = 100000

# NOTE: Need to revisit these...
# self.maintain_node_dict_cache = False
Expand All @@ -74,11 +71,42 @@ def __init__(
self.edge_type_func = edge_type_func
self.default_edge_type = edge_type_func(default_node_type, default_node_type)

incoming_graph_data = kwargs.get("incoming_graph_data")
if self.__graph_exists:
self.adb_graph = self.db.graph(graph_name)
self.__create_default_collections()
self.__set_factory_methods()

if incoming_graph_data:
m = "Cannot pass both **incoming_graph_data** and **graph_name** yet if the already graph exists" # noqa: E501
raise NotImplementedError(m)

elif self.__graph_name and incoming_graph_data:
if not isinstance(incoming_graph_data, nx.Graph):
m = f"Type of **incoming_graph_data** not supported yet ({type(incoming_graph_data)})"
raise NotImplementedError(m)

adapter = ADBNX_Adapter(self.db)
self.adb_graph = adapter.networkx_to_arangodb(
graph_name,
incoming_graph_data,
# TODO: Parameterize the edge definitions
# How can we work with a heterogeneous **incoming_graph_data**?
edge_definitions=[
{
"edge_collection": self.default_edge_type,
"from_vertex_collections": [self.default_node_type],
"to_vertex_collections": [self.default_node_type],
}
],
)

self.__set_factory_methods()
self.__graph_exists = True
del kwargs["incoming_graph_data"]

# self.__qa_chain = None

super().__init__(*args, **kwargs)

#######################
Expand Down Expand Up @@ -207,6 +235,31 @@ def __set_graph_name(self, graph_name: str | None = None):
def aql(self, query: str, bind_vars: dict | None = None, **kwargs) -> Cursor:
return nxadb.classes.function.aql(self.db, query, bind_vars, **kwargs)

# NOTE: Ignore this for now
# def chat(self, prompt: str) -> str:
# if self.__qa_chain is None:
# if not self.__graph_exists:
# return "Could not initialize QA chain: Graph does not exist"

# # try:
# from langchain.chains import ArangoGraphQAChain
# from langchain_community.graphs import ArangoGraph
# from langchain_openai import ChatOpenAI

# model = ChatOpenAI(temperature=0, model_name="gpt-4")

# self.__qa_chain = ArangoGraphQAChain.from_llm(
# llm=model, graph=ArangoGraph(self.db), verbose=True
# )

# # except Exception as e:
# # return f"Could not initialize QA chain: {e}"

# self.__qa_chain.graph.set_schema()
# result = self.__qa_chain.invoke(prompt)

# print(result["result"])

def pull(self, load_node_dict=True, load_adj_dict=True, load_coo=True):
"""Load the graph from the ArangoDB database, and update existing graph object.

Expand Down
88 changes: 62 additions & 26 deletions tests/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,97 +10,133 @@ def test_db(load_graph):
assert db.version()


def test_load_graph_from_nxadb():
graph_name = "KarateGraph"

db.delete_graph(graph_name, drop_collections=True, ignore_missing=True)

G_nx = nx.karate_club_graph()

_ = nxadb.Graph(
graph_name=graph_name,
incoming_graph_data=G_nx,
default_node_type="person",
)

assert db.has_graph(graph_name)
assert db.has_collection("person")
assert db.has_collection("person_to_person")
assert db.collection("person").count() == len(G_nx.nodes)
assert db.collection("person_to_person").count() == len(G_nx.edges)

db.delete_graph(graph_name, drop_collections=True)


def test_bc(load_graph):
G_1 = nx.karate_club_graph()
G_2 = nxadb.Graph(incoming_graph_data=G_1)
G_3 = nxadb.Graph(graph_name="KarateGraph")

r_1 = nx.betweenness_centrality(G_1)
r_2 = nx.betweenness_centrality(G_2)
r_3 = nx.betweenness_centrality(G_1, backend="arangodb")
r_4 = nx.betweenness_centrality(G_2, backend="arangodb")
r_5 = nx.betweenness_centrality.orig_func(G_3)

assert len(r_1) == len(r_2) == len(r_3) == len(r_4) > 0
assert len(r_1) == len(G_1)
assert r_1 == r_2
assert r_2 == r_3
assert r_3 == r_4
assert len(r_1) == len(r_5)

try:
import phenolrs
except ModuleNotFoundError:
return

G_3 = nxadb.Graph(graph_name="KarateGraph")
r_5 = nx.betweenness_centrality(G_3)

G_4 = nxadb.Graph(graph_name="KarateGraph")
r_6 = nxadb.betweenness_centrality(G_4, pull_graph_on_cpu=False)
r_6 = nx.betweenness_centrality(G_4)

G_5 = nxadb.DiGraph(graph_name="KarateGraph")
r_7 = nx.betweenness_centrality(G_5)
G_5 = nxadb.Graph(graph_name="KarateGraph")
r_7 = nxadb.betweenness_centrality(G_5, pull_graph_on_cpu=False)

# assert r_5 == r_6 # this is acting strange. I need to revisit
assert r_6 == r_7
assert len(r_5) == len(r_6) == len(r_7) > 0
G_6 = nxadb.DiGraph(graph_name="KarateGraph")
r_8 = nx.betweenness_centrality(G_6)

# assert r_6 == r_7 # this is acting strange. I need to revisit
assert r_7 == r_8
assert len(r_6) == len(r_7) == len(r_8) == len(G_4) > 0


def test_pagerank(load_graph):
G_1 = nx.karate_club_graph()

G_2 = nxadb.Graph(incoming_graph_data=G_1)
G_3 = nxadb.Graph(graph_name="KarateGraph")

r_1 = nx.pagerank(G_1)
r_2 = nx.pagerank(G_2)
r_3 = nx.pagerank(G_1, backend="arangodb")
r_4 = nx.pagerank(G_2, backend="arangodb")
r_5 = nx.pagerank.orig_func(G_3)

assert len(r_1) == len(r_2) == len(r_3) == len(r_4) > 0
assert len(r_1) == len(G_1)
assert r_1 == r_2
assert r_2 == r_3
assert r_3 == r_4
assert len(r_1) == len(r_5)

try:
import phenolrs
except ModuleNotFoundError:
return

G_3 = nxadb.Graph(graph_name="KarateGraph")
r_5 = nx.pagerank(G_3)

G_4 = nxadb.Graph(graph_name="KarateGraph")
r_6 = nxadb.pagerank(G_4, pull_graph_on_cpu=False)
r_6 = nx.pagerank(G_4)

G_5 = nxadb.Graph(graph_name="KarateGraph")
r_7 = nxadb.pagerank(G_5, pull_graph_on_cpu=False)

G_5 = nxadb.DiGraph(graph_name="KarateGraph")
r_7 = nx.pagerank(G_5)
G_6 = nxadb.DiGraph(graph_name="KarateGraph")
r_8 = nx.pagerank(G_6)

assert len(r_5) == len(r_6) == len(r_7) == len(G_4)
assert len(r_6) == len(r_7) == len(r_8) == len(G_4) > 0


def test_louvain(load_graph):
G_1 = nx.karate_club_graph()

G_2 = nxadb.Graph(incoming_graph_data=G_1)
G_3 = nxadb.Graph(graph_name="KarateGraph")

r_1 = nx.community.louvain_communities(G_1)
r_2 = nx.community.louvain_communities(G_2)
r_3 = nx.community.louvain_communities(G_1, backend="arangodb")
r_4 = nx.community.louvain_communities(G_2, backend="arangodb")
r_5 = nx.community.louvain_communities.orig_func(G_3)

assert len(r_1) > 0
assert len(r_2) > 0
assert len(r_3) > 0
assert len(r_4) > 0
assert len(r_5) > 0

try:
import phenolrs
except ModuleNotFoundError:
return

G_3 = nxadb.Graph(graph_name="KarateGraph")
r_5 = nx.community.louvain_communities(G_3)

G_4 = nxadb.Graph(graph_name="KarateGraph")
r_6 = nxadb.community.louvain_communities(G_4, pull_graph_on_cpu=False)
r_6 = nx.community.louvain_communities(G_4)

G_5 = nxadb.Graph(graph_name="KarateGraph")
r_7 = nxadb.community.louvain_communities(G_5, pull_graph_on_cpu=False)

G_5 = nxadb.DiGraph(graph_name="KarateGraph")
r_7 = nx.community.louvain_communities(G_5)
G_6 = nxadb.DiGraph(graph_name="KarateGraph")
r_8 = nx.community.louvain_communities(G_6)

assert len(r_5) > 0
assert len(r_6) > 0
assert len(r_7) > 0
assert len(r_8) > 0


def test_shortest_path(load_graph):
Expand Down
Loading