Skip to content
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 15 additions & 1 deletion nx_arangodb/classes/dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
aql_fetch_data_edge,
aql_single,
create_collection,
do_load_all_edge_attributes,
doc_delete,
doc_get_or_insert,
doc_insert,
Expand Down Expand Up @@ -105,6 +106,7 @@ def adjlist_inner_dict_factory(
default_node_type: str,
edge_type_key: str,
edge_type_func: Callable[[str, str], str],
edge_collection_attributes: set[str],
graph_type: str,
adjlist_outer_dict: AdjListOuterDict | None = None,
) -> Callable[..., AdjListInnerDict]:
Expand All @@ -114,6 +116,7 @@ def adjlist_inner_dict_factory(
default_node_type,
edge_type_key,
edge_type_func,
edge_collection_attributes,
graph_type,
adjlist_outer_dict,
)
Expand All @@ -125,6 +128,7 @@ def adjlist_outer_dict_factory(
default_node_type: str,
edge_type_key: str,
edge_type_func: Callable[[str, str], str],
edge_collection_attributes: set[str],
graph_type: str,
symmetrize_edges_if_directed: bool,
) -> Callable[..., AdjListOuterDict]:
Expand All @@ -134,6 +138,7 @@ def adjlist_outer_dict_factory(
default_node_type,
edge_type_key,
edge_type_func,
edge_collection_attributes,
graph_type,
symmetrize_edges_if_directed,
)
Expand Down Expand Up @@ -752,6 +757,7 @@ def _fetch_all(self):
load_node_dict=True,
load_adj_dict=False,
load_coo=False,
edge_collections_attributes=set(),
load_all_vertex_attributes=True,
load_all_edge_attributes=False, # not used
is_directed=False, # not used
Expand Down Expand Up @@ -1376,6 +1382,7 @@ def __init__(
default_node_type: str,
edge_type_key: str,
edge_type_func: Callable[[str, str], str],
edge_collection_attributes: set[str],
graph_type: str,
adjlist_outer_dict: AdjListOuterDict | None,
*args: Any,
Expand All @@ -1395,6 +1402,7 @@ def __init__(
self.graph = graph
self.edge_type_key = edge_type_key
self.edge_type_func = edge_type_func
self.edge_collection_attributes = edge_collection_attributes
self.default_node_type = default_node_type
self.edge_attr_dict_factory = edge_attr_dict_factory(self.db, self.graph)
self.edge_key_dict_factory = edge_key_dict_factory(
Expand Down Expand Up @@ -1934,6 +1942,7 @@ def __init__(
default_node_type: str,
edge_type_key: str,
edge_type_func: Callable[[str, str], str],
edge_collection_attributes: set[str],
graph_type: str,
symmetrize_edges_if_directed: bool,
*args: Any,
Expand All @@ -1953,13 +1962,15 @@ def __init__(
self.graph = graph
self.edge_type_key = edge_type_key
self.edge_type_func = edge_type_func
self.edge_collection_attributes = edge_collection_attributes
self.default_node_type = default_node_type
self.adjlist_inner_dict_factory = adjlist_inner_dict_factory(
db,
graph,
default_node_type,
edge_type_key,
edge_type_func,
self.edge_collection_attributes,
graph_type,
self,
)
Expand Down Expand Up @@ -2254,8 +2265,11 @@ def set_edge_multigraph(
load_node_dict=False,
load_adj_dict=True,
load_coo=False,
edge_collections_attributes=self.edge_collection_attributes,
load_all_vertex_attributes=False, # not used
load_all_edge_attributes=True,
load_all_edge_attributes=do_load_all_edge_attributes(
self.edge_collection_attributes
),
is_directed=self.is_directed,
is_multigraph=self.is_multigraph,
symmetrize_edges_if_directed=self.is_directed
Expand Down
2 changes: 2 additions & 0 deletions nx_arangodb/classes/digraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ def __init__(
default_node_type: str | None = None,
edge_type_key: str = "_edge_type",
edge_type_func: Callable[[str, str], str] | None = None,
edge_collections_attributes: set[str] | None = None,
db: StandardDatabase | None = None,
read_parallelism: int = 10,
read_batch_size: int = 100000,
Expand All @@ -41,6 +42,7 @@ def __init__(
default_node_type,
edge_type_key,
edge_type_func,
edge_collections_attributes,
db,
read_parallelism,
read_batch_size,
Expand Down
55 changes: 50 additions & 5 deletions nx_arangodb/classes/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,39 @@
)


def _build_meta_graph(
v_cols: list[str], e_cols: set[str], edge_collections_attributes: set[str]
) -> dict[str, dict[str, Any]]:
if len(edge_collections_attributes) == 0:
return {
"vertexCollections": {col: set() for col in v_cols},
"edgeCollections": {col: set() for col in e_cols},
}
else:
return {
"vertexCollections": {col: set() for col in v_cols},
"edgeCollections": {
col: {attr: set() for attr in edge_collections_attributes}
for col in e_cols
},
}


def do_load_all_edge_attributes(attributes: set[str] | None) -> bool:
if attributes is None:
return True
if len(attributes) == 0:
return True

return False


def get_arangodb_graph(
adb_graph: Graph,
load_node_dict: bool,
load_adj_dict: bool,
load_coo: bool,
edge_collections_attributes: set[str],
load_all_vertex_attributes: bool,
load_all_edge_attributes: bool,
is_directed: bool,
Expand All @@ -69,10 +97,9 @@ def get_arangodb_graph(
edge_definitions = adb_graph.edge_definitions()
e_cols = {c["edge_collection"] for c in edge_definitions}

metagraph: dict[str, dict[str, Any]] = {
"vertexCollections": {col: set() for col in v_cols},
"edgeCollections": {col: set() for col in e_cols},
}
metagraph: dict[str, dict[str, Any]] = _build_meta_graph(
v_cols, e_cols, edge_collections_attributes
)

if not any((load_node_dict, load_adj_dict, load_coo)):
raise ValueError("At least one of the load flags must be True.")
Expand All @@ -89,6 +116,24 @@ def get_arangodb_graph(
assert config.username
assert config.password

# Default derived from the requested attribute set: with no specific edge
# attributes requested, all edge attributes are loaded.
res_do_load_all_edge_attributes = do_load_all_edge_attributes(
    edge_collections_attributes
)

if res_do_load_all_edge_attributes is not load_all_edge_attributes:
    if edge_collections_attributes:
        # Specific edge attributes were requested AND the caller asked for
        # all of them: the two requests contradict each other.
        raise ValueError(
            "You have specified to load at least one specific edge attribute"
            " and at the same time set the parameter `load_all_edge_attributes`"
            " to true. This combination is not allowed."
        )
    # No specific attributes were requested, yet the caller passed
    # `load_all_edge_attributes=False` on purpose: honor that choice.
    res_do_load_all_edge_attributes = load_all_edge_attributes

(
node_dict,
adj_dict,
Expand All @@ -106,7 +151,7 @@ def get_arangodb_graph(
load_adj_dict=load_adj_dict,
load_coo=load_coo,
load_all_vertex_attributes=load_all_vertex_attributes,
load_all_edge_attributes=load_all_edge_attributes,
load_all_edge_attributes=res_do_load_all_edge_attributes,
is_directed=is_directed,
is_multigraph=is_multigraph,
symmetrize_edges_if_directed=symmetrize_edges_if_directed,
Expand Down
19 changes: 19 additions & 0 deletions nx_arangodb/classes/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ def __init__(
default_node_type: str | None = None,
edge_type_key: str = "_edge_type",
edge_type_func: Callable[[str, str], str] | None = None,
edge_collections_attributes: set[str] | None = None,
db: StandardDatabase | None = None,
read_parallelism: int = 10,
read_batch_size: int = 100000,
Expand All @@ -69,6 +70,8 @@ def __init__(
self.read_batch_size = read_batch_size
self.write_batch_size = write_batch_size

self._set_edge_collections_attributes_to_fetch(edge_collections_attributes)

# NOTE: Need to revisit these...
# self.maintain_node_dict_cache = False
# self.maintain_adj_dict_cache = False
Expand Down Expand Up @@ -222,6 +225,7 @@ def _set_factory_methods(self) -> None:
*node_args,
self.edge_type_key,
self.edge_type_func,
self.get_edge_attributes,
self.__class__.__name__,
)

Expand All @@ -236,6 +240,17 @@ def _set_factory_methods(self) -> None:
*adj_args, self.symmetrize_edges
)

def _set_edge_collections_attributes_to_fetch(
self, attributes: set[str] | None
) -> None:
if attributes is None:
self._edge_collections_attributes = set()
return
if len(attributes) > 0:
self._edge_collections_attributes = attributes
if "_id" not in attributes:
self._edge_collections_attributes.add("_id")

###########
# Getters #
###########
Expand All @@ -258,6 +273,10 @@ def graph_name(self) -> str:
def graph_exists_in_db(self) -> bool:
return self._graph_exists_in_db

@property
def get_edge_attributes(self) -> set[str]:
    """Return the edge attribute set held in ``_edge_collections_attributes``."""
    stored_attributes = self._edge_collections_attributes
    return stored_attributes

###########
# Setters #
###########
Expand Down
2 changes: 2 additions & 0 deletions nx_arangodb/classes/multidigraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ def __init__(
default_node_type: str | None = None,
edge_type_key: str = "_edge_type",
edge_type_func: Callable[[str, str], str] | None = None,
edge_collections_attributes: set[str] | None = None,
db: StandardDatabase | None = None,
read_parallelism: int = 10,
read_batch_size: int = 100000,
Expand All @@ -40,6 +41,7 @@ def __init__(
default_node_type,
edge_type_key,
edge_type_func,
edge_collections_attributes,
db,
read_parallelism,
read_batch_size,
Expand Down
2 changes: 2 additions & 0 deletions nx_arangodb/classes/multigraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ def __init__(
default_node_type: str | None = None,
edge_type_key: str = "_edge_type",
edge_type_func: Callable[[str, str], str] | None = None,
edge_collections_attributes: set[str] | None = None,
db: StandardDatabase | None = None,
read_parallelism: int = 10,
read_batch_size: int = 100000,
Expand All @@ -40,6 +41,7 @@ def __init__(
default_node_type,
edge_type_key,
edge_type_func,
edge_collections_attributes,
db,
read_parallelism,
read_batch_size,
Expand Down
10 changes: 7 additions & 3 deletions nx_arangodb/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import networkx as nx

import nx_arangodb as nxadb
from nx_arangodb.classes.function import do_load_all_edge_attributes
from nx_arangodb.logger import logger

try:
Expand Down Expand Up @@ -126,9 +127,9 @@ def nxadb_to_nx(G: nxadb.Graph) -> nx.Graph:
load_node_dict=True,
load_adj_dict=True,
load_coo=False,
edge_collections_attributes=G.get_edge_attributes,
load_all_vertex_attributes=False,
# TODO: Only return the edge attributes that are needed
load_all_edge_attributes=True,
load_all_edge_attributes=do_load_all_edge_attributes(G.get_edge_attributes),
is_directed=G.is_directed(),
is_multigraph=G.is_multigraph(),
symmetrize_edges_if_directed=G.symmetrize_edges if G.is_directed() else False,
Expand Down Expand Up @@ -171,8 +172,11 @@ def nxadb_to_nxcg(G: nxadb.Graph, as_directed: bool = False) -> nxcg.Graph:
load_node_dict=False,
load_adj_dict=False,
load_coo=True,
edge_collections_attributes=G.get_edge_attributes,
load_all_vertex_attributes=False, # not used
load_all_edge_attributes=False, # not used
load_all_edge_attributes=do_load_all_edge_attributes(
G.get_edge_attributes
),
is_directed=G.is_directed(),
is_multigraph=G.is_multigraph(),
symmetrize_edges_if_directed=(
Expand Down
Loading