Skip to content

Commit 3913247

Browse files
authored
fix: cache nxcg graph instead of coo representation (#31)
* fix: cache `nxcg` graph instead of coo representation
* fix lint
* fix print statements
1 parent ea19c1f commit 3913247

File tree

2 files changed

+54
-63
lines changed

2 files changed

+54
-63
lines changed

nx_arangodb/classes/graph.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -76,14 +76,10 @@ def __init__(
7676
# self.maintain_node_dict_cache = False
7777
# self.maintain_adj_dict_cache = False
7878
self.use_nx_cache = True
79-
self.use_coo_cache = True
80-
# self.__qa_chain = None
79+
self.use_nxcg_cache = True
80+
self.nxcg_graph = None
8181

82-
self.src_indices: npt.NDArray[np.int64] | None = None
83-
self.dst_indices: npt.NDArray[np.int64] | None = None
84-
self.edge_indices: npt.NDArray[np.int64] | None = None
85-
self.vertex_ids_to_index: dict[str, int] | None = None
86-
self.edge_values: dict[str, list[int | float]] | None = None
82+
# self.__qa_chain = None
8783

8884
# Does not apply to undirected graphs
8985
self.symmetrize_edges = symmetrize_edges
@@ -379,6 +375,9 @@ def clear_edges(self):
379375
logger.info("Note that clearing edges ony erases the edges in the local cache")
380376
super().clear_edges()
381377

378+
def clear_nxcg_cache(self):
379+
self.nxcg_graph = None
380+
382381
@cached_property
383382
def nodes(self):
384383
if self.graph_exists_in_db:

nx_arangodb/convert.py

Lines changed: 48 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ def nxadb_to_nx(G: nxadb.Graph) -> nx.Graph:
135135
symmetrize_edges_if_directed=G.symmetrize_edges if G.is_directed() else False,
136136
)
137137

138-
print(f"ADB -> Dictionaries load took {time.time() - start_time}s")
138+
print(f"Graph '{G.adb_graph.name}' load took {time.time() - start_time}s")
139139

140140
G_NX: nx.Graph | nx.DiGraph = G.to_networkx_class()()
141141
G_NX._node = node_dict
@@ -153,74 +153,62 @@ def nxadb_to_nx(G: nxadb.Graph) -> nx.Graph:
153153
if GPU_ENABLED:
154154

155155
def nxadb_to_nxcg(G: nxadb.Graph, as_directed: bool = False) -> nxcg.Graph:
156-
if (
157-
G.use_coo_cache
158-
and G.src_indices is not None
159-
and G.dst_indices is not None
160-
and G.edge_indices is not None
161-
and G.vertex_ids_to_index is not None
162-
and G.edge_values is not None
163-
):
164-
m = "**use_coo_cache** is enabled. using cached COO data. no pull required."
156+
if G.use_nxcg_cache and G.nxcg_graph is not None:
157+
m = "**use_nxcg_cache** is enabled. using cached NXCG Graph. no pull required." # noqa
165158
logger.debug(m)
166159

167-
else:
168-
start_time = time.time()
169-
170-
(
171-
_,
172-
_,
173-
src_indices,
174-
dst_indices,
175-
edge_indices,
176-
vertex_ids_to_index,
177-
edge_values,
178-
) = nxadb.classes.function.get_arangodb_graph(
179-
adb_graph=G.adb_graph,
180-
load_node_dict=False,
181-
load_adj_dict=False,
182-
load_coo=True,
183-
edge_collections_attributes=G.get_edge_attributes,
184-
load_all_vertex_attributes=False, # not used
185-
load_all_edge_attributes=do_load_all_edge_attributes(
186-
G.get_edge_attributes
187-
),
188-
is_directed=G.is_directed(),
189-
is_multigraph=G.is_multigraph(),
190-
symmetrize_edges_if_directed=(
191-
G.symmetrize_edges if G.is_directed() else False
192-
),
193-
)
194-
195-
print(f"ADB -> COO load took {time.time() - start_time}s")
196-
197-
G.src_indices = src_indices
198-
G.dst_indices = dst_indices
199-
G.edge_indices = edge_indices
200-
G.vertex_ids_to_index = vertex_ids_to_index
201-
G.edge_values = edge_values
202-
203-
N = len(G.vertex_ids_to_index) # type: ignore
204-
src_indices_cp = cp.array(G.src_indices)
205-
dst_indices_cp = cp.array(G.dst_indices)
206-
edge_indices_cp = cp.array(G.edge_indices)
160+
return G.nxcg_graph
161+
162+
start_time = time.time()
163+
164+
(
165+
_,
166+
_,
167+
src_indices,
168+
dst_indices,
169+
edge_indices,
170+
vertex_ids_to_index,
171+
edge_values,
172+
) = nxadb.classes.function.get_arangodb_graph(
173+
adb_graph=G.adb_graph,
174+
load_node_dict=False,
175+
load_adj_dict=False,
176+
load_coo=True,
177+
edge_collections_attributes=G.get_edge_attributes,
178+
load_all_vertex_attributes=False, # not used
179+
load_all_edge_attributes=do_load_all_edge_attributes(G.get_edge_attributes),
180+
is_directed=G.is_directed(),
181+
is_multigraph=G.is_multigraph(),
182+
symmetrize_edges_if_directed=(
183+
G.symmetrize_edges if G.is_directed() else False
184+
),
185+
)
186+
187+
print(f"ADB Graph '{G.adb_graph.name}' load took {time.time() - start_time}s")
188+
189+
start_time = time.time()
190+
191+
N = len(vertex_ids_to_index)
192+
src_indices_cp = cp.array(src_indices)
193+
dst_indices_cp = cp.array(dst_indices)
194+
edge_indices_cp = cp.array(edge_indices)
207195

208196
if G.is_multigraph():
209197
if G.is_directed() or as_directed:
210198
klass = nxcg.MultiDiGraph
211199
else:
212200
klass = nxcg.MultiGraph
213201

214-
return klass.from_coo(
202+
G.nxcg_graph = klass.from_coo(
215203
N=N,
216204
src_indices=src_indices_cp,
217205
dst_indices=dst_indices_cp,
218206
edge_indices=edge_indices_cp,
219-
edge_values=G.edge_values,
207+
edge_values=edge_values,
220208
# edge_masks,
221209
# node_values,
222210
# node_masks,
223-
key_to_id=G.vertex_ids_to_index,
211+
key_to_id=vertex_ids_to_index,
224212
# edge_keys=edge_keys,
225213
)
226214

@@ -230,13 +218,17 @@ def nxadb_to_nxcg(G: nxadb.Graph, as_directed: bool = False) -> nxcg.Graph:
230218
else:
231219
klass = nxcg.Graph
232220

233-
return klass.from_coo(
221+
G.nxcg_graph = klass.from_coo(
234222
N=N,
235223
src_indices=src_indices_cp,
236224
dst_indices=dst_indices_cp,
237-
edge_values=G.edge_values,
225+
edge_values=edge_values,
238226
# edge_masks,
239227
# node_values,
240228
# node_masks,
241-
key_to_id=G.vertex_ids_to_index,
229+
key_to_id=vertex_ids_to_index,
242230
)
231+
232+
print(f"NXCG Graph construction took {time.time() - start_time}s")
233+
234+
return G.nxcg_graph

0 commit comments

Comments
 (0)