diff --git a/doc/tests/test_client_only_endpoints.py b/doc/tests/test_client_only_endpoints.py
index b8f56e15b..7c8f86a5a 100644
--- a/doc/tests/test_client_only_endpoints.py
+++ b/doc/tests/test_client_only_endpoints.py
@@ -59,7 +59,7 @@ def find_covered_server_endpoints() -> List[str]:
 
     driver.close()
 
-    return [ep["name"] for ep in all_server_endpoints if not ep["name"] in IGNORED_SERVER_ENDPOINTS]
+    return [ep["name"] for ep in all_server_endpoints if ep["name"] not in IGNORED_SERVER_ENDPOINTS]
 
 
 def check_rst_files(endpoints: List[str]) -> None:
diff --git a/graphdatascience/graph/graph_cypher_runner.py b/graphdatascience/graph/graph_cypher_runner.py
new file mode 100644
index 000000000..34dcf0a3f
--- /dev/null
+++ b/graphdatascience/graph/graph_cypher_runner.py
@@ -0,0 +1,384 @@
+from collections import defaultdict
+from typing import Any, Dict, NamedTuple, Optional, Tuple
+
+from pandas import Series
+
+from ..error.illegal_attr_checker import IllegalAttrChecker
+from ..query_runner.arrow_query_runner import ArrowQueryRunner
+from ..query_runner.query_runner import QueryRunner
+from ..server_version.server_version import ServerVersion
+from .graph_object import Graph
+
+
+class NodeProperty(NamedTuple):
+    name: str
+    property_key: str
+    default_value: Optional[Any] = None
+
+
+class NodeProjection(NamedTuple):
+    name: str
+    source_label: str
+    properties: Optional[list[NodeProperty]] = None
+
+
+class RelationshipProperty(NamedTuple):
+    name: str
+    property_key: str
+    default_value: Optional[Any] = None
+
+
+class RelationshipProjection(NamedTuple):
+    name: str
+    source_type: str
+    properties: Optional[list[RelationshipProperty]] = None
+
+
+class MatchParts(NamedTuple):
+    match: str = ""
+    source_where: str = ""
+    optional_match: str = ""
+    optional_where: str = ""
+
+    def __str__(self) -> str:
+        return "\n".join(
+            part
+            for part in [
+                self.match,
+                self.source_where,
+                self.optional_match,
+                self.optional_where,
+            ]
+            if part
+        )
+
+
+class MatchPattern(NamedTuple):
+    label_filter: str = ""
+    left_arrow: str = ""
+    type_filter: str = ""
+    right_arrow: str = ""
+
+    def __str__(self) -> str:
+        return f"{self.left_arrow}{self.type_filter}{self.right_arrow}(target{self.label_filter})"
+
+
+class LabelPropertyMapping(NamedTuple):
+    label: str
+    property_key: str
+    default_value: Optional[Any] = None
+
+
+class GraphCypherRunner(IllegalAttrChecker):
+    def __init__(self, query_runner: QueryRunner, namespace: str, server_version: ServerVersion) -> None:
+        if server_version < ServerVersion(2, 4, 0):
+            raise ValueError("The new Cypher projection is only supported since GDS 2.4.0.")
+        super().__init__(query_runner, namespace, server_version)
+
+    def project(
+        self,
+        graph_name: str,
+        *,
+        nodes: Any = None,
+        relationships: Any = None,
+        where: Optional[str] = None,
+        allow_disconnected_nodes: bool = False,
+        inverse: bool = False,
+        combine_labels_with: str = "OR",
+        **config: Any,
+    ) -> Tuple[Graph, "Series[Any]"]:
+        """
+        Project a graph using Cypher projection.
+
+        Parameters
+        ----------
+        graph_name : str
+            The name of the graph to project.
+        nodes : Any
+            The nodes to project. If not specified, all nodes are projected.
+        relationships : Any
+            The relationships to project. If not specified, all relationships
+            are projected.
+        where : Optional[str]
+            A Cypher WHERE clause to filter the nodes and relationships to
+            project.
+        allow_disconnected_nodes : bool
+            Whether to allow disconnected nodes in the projected graph.
+        inverse : bool
+            Whether to project inverse relationships.
+            The projected graph will be configured as NATURAL.
+        combine_labels_with : str
+            Whether to combine node labels with AND or OR. The default is OR.
+            Allowed values are 'AND' and 'OR'.
+        **config : Any
+            Additional configuration for the projection.
+
+        Returns
+        -------
+        A tuple of the projected graph and statistics about the projection
+        """
+
+        query_params: Dict[str, Any] = {"graph_name": graph_name}
+
+        data_config: Dict[str, Any] = {}
+        data_config_is_static = True
+
+        nodes = self._node_projections_spec(nodes)
+        rels = self._rel_projections_spec(relationships)
+
+        match_parts = MatchParts()
+        match_pattern = MatchPattern(
+            left_arrow="<-" if inverse else "-",
+            right_arrow="-" if inverse else "->",
+        )
+
+        label_mappings = defaultdict(list)
+
+        if nodes:
+            if len(nodes) == 1 or combine_labels_with == "AND":
+                match_pattern = match_pattern._replace(label_filter=f":{':'.join(spec.source_label for spec in nodes)}")
+
+                projected_labels = [spec.name for spec in nodes]
+                data_config["sourceNodeLabels"] = projected_labels
+                data_config["targetNodeLabels"] = projected_labels
+
+            elif combine_labels_with == "OR":
+                source_labels_filter = " OR ".join(f"source:{spec.source_label}" for spec in nodes)
+                target_labels_filter = " OR ".join(f"target:{spec.source_label}" for spec in nodes)
+                if allow_disconnected_nodes:
+                    match_parts = match_parts._replace(
+                        source_where=f"WHERE {source_labels_filter}", optional_where=f"WHERE {target_labels_filter}"
+                    )
+                else:
+                    match_parts = match_parts._replace(
+                        source_where=f"WHERE ({source_labels_filter}) AND ({target_labels_filter})"
+                    )
+
+                data_config["sourceNodeLabels"] = "labels(source)"
+                data_config["targetNodeLabels"] = "labels(target)"
+                data_config_is_static = False
+            else:
+                raise ValueError(f"Invalid value for combine_labels_with: {combine_labels_with}")
+
+            for spec in nodes:
+                if spec.properties:
+                    for prop in spec.properties:
+                        label_mappings[spec.source_label].append(
+                            LabelPropertyMapping(spec.source_label, prop.property_key, prop.default_value)
+                        )
+
+        rel_var = ""
+        if rels:
+            if len(rels) == 1:
+                data_config["relationshipType"] = rels[0].source_type
+            else:
+                rel_var = "rel"
+                data_config["relationshipTypes"] = "type(rel)"
+                data_config_is_static = False
+
+            match_pattern = match_pattern._replace(
+                type_filter=f"[{rel_var}:{'|'.join(spec.source_type for spec in rels)}]"
+            )
+
+        source = f"(source{match_pattern.label_filter})"
+        if allow_disconnected_nodes:
+            match_parts = match_parts._replace(
+                match=f"MATCH {source}", optional_match=f"OPTIONAL MATCH (source){match_pattern}"
+            )
+        else:
+            match_parts = match_parts._replace(match=f"MATCH {source}{match_pattern}")
+
+        match_part = str(match_parts)
+
+        case_part = []
+        if label_mappings:
+            with_rel = f", {rel_var}" if rel_var else ""
+            case_part = [f"WITH source, target{with_rel}"]
+            for kind in ["source", "target"]:
+                case_part.append("CASE")
+
+                for label, mappings in label_mappings.items():
+                    mapping_projection = ", ".join(f".{key.property_key}" for key in mappings)
+                    when_part = f"WHEN '{label}' in labels({kind}) THEN [{kind} {{{mapping_projection}}}]"
+                    case_part.append(when_part)
+
+                case_part.append(f"END AS {kind}NodeProperties")
+
+            data_config["sourceNodeProperties"] = "sourceNodeProperties"
+            data_config["targetNodeProperties"] = "targetNodeProperties"
+            data_config_is_static = False
+
+        args = ["$graph_name", "source", "target"]
+
+        if data_config:
+            if data_config_is_static:
+                query_params["data_config"] = data_config
+                args += ["$data_config"]
+            else:
+                args += [self._render_map(data_config)]
+
+        if config:
+            query_params["config"] = config
+            args += ["$config"]
+
+        return_part = f"RETURN {self._namespace}({', '.join(args)})"
+
+        query = "\n".join(part for part in [match_part, *case_part, return_part] if part)
+
+        result = self._query_runner.run_query_with_logging(query, query_params)
+        result = result.squeeze()
+
+        return Graph(graph_name, self._query_runner, self._server_version), result  # type: ignore
+
+    def run_project(
+        self, query: str, params: Optional[Dict[str, Any]] = None, database: Optional[str] = None
+    ) -> Tuple[Graph, "Series[Any]"]:
+        """
+        Run a Cypher projection.
+        The provided query must end with a `RETURN gds.graph.project(...)` call.
+
+        Parameters
+        ----------
+        query: str
+            the Cypher projection query
+        params: Dict[str, Any]
+            parameters to the query
+        database: str
+            the database on which to run the query
+
+        Returns
+        -------
+        A tuple of the projected graph and statistics about the projection
+        """
+
+        return_clause = f"RETURN {self._namespace}"
+
+        return_index = query.rfind(return_clause)
+        if return_index == -1:
+            raise ValueError(f"Invalid query, the query must end with a `{return_clause}` clause: {query}")
+
+        return_index += len(return_clause)
+        return_part = query[return_index:]
+
+        # Remove surrounding parentheses and whitespace
+        right_paren = return_part.rfind(")") + 1
+        return_part = return_part[:right_paren].strip("() \n\t")
+
+        graph_name = return_part.split(",", maxsplit=1)[0]
+        graph_name = graph_name.strip()
+
+        if graph_name.startswith("$"):
+            if params is None:
+                raise ValueError(
+                    f"Invalid query, the query references parameter `{graph_name}` but no params were given"
+                )
+
+            graph_name = graph_name[1:]
+            graph_name = params[graph_name]
+        else:
+            # remove the quotes
+            graph_name = graph_name.strip("'\"")
+
+        # remove possible `AS graph` from the end of the query
+        end_of_query = return_index + right_paren
+        query = query[:end_of_query]
+
+        # run_cypher
+        qr = self._query_runner
+
+        # The Arrow query runner should not be used to execute arbitrary Cypher
+        if isinstance(qr, ArrowQueryRunner):
+            qr = qr.fallback_query_runner()
+
+        result = qr.run_query(query, params, database, False)
+        result = result.squeeze()
+
+        return Graph(graph_name, self._query_runner, self._server_version), result  # type: ignore
+
+    def _node_projections_spec(self, spec: Any) -> list[NodeProjection]:
+        if spec is None or spec is False:
+            return []
+
+        if isinstance(spec, str):
+            spec = [spec]
+
+        if isinstance(spec, list):
+            return [self._node_projection_spec(node) for node in spec]
+
+        if isinstance(spec, dict):
+            return [self._node_projection_spec(node, name) for name, node in spec.items()]
+
+        raise TypeError(f"Invalid node projections specification: {spec}")
+
+    def _node_projection_spec(self, spec: Any, name: Optional[str] = None) -> NodeProjection:
+        if isinstance(spec, str):
+            return NodeProjection(name=name or spec, source_label=spec)
+
+        if name is None:
+            raise ValueError(f"Node projections with properties must use the dict syntax: {spec}")
+
+        if isinstance(spec, dict):
+            properties = [self._node_properties_spec(prop, name) for name, prop in spec.items()]
+            return NodeProjection(name=name, source_label=name, properties=properties)
+
+        if isinstance(spec, list):
+            properties = [self._node_properties_spec(prop) for prop in spec]
+            return NodeProjection(name=name, source_label=name, properties=properties)
+
specification: {spec}") + + def _node_properties_spec(self, spec: Any, name: Optional[str] = None) -> NodeProperty: + if isinstance(spec, str): + return NodeProperty(name=name or spec, property_key=spec) + + if isinstance(spec, dict): + name = spec.pop("name", name) + if name is None: + raise ValueError( + f"Node properties must specify either a name in the outer dict or by using the `name` key: {spec}" + ) + property_key = spec.pop("property_key", name) + + return NodeProperty(name=name, property_key=property_key, **spec) + + if spec is True: + if name is None: + raise ValueError(f"Node properties spec must be used with the dict syntax: {spec}") + + return NodeProperty(name=name, property_key=name) + + raise TypeError(f"Invalid node property specification: {spec}") + + def _rel_projections_spec(self, spec: Any) -> list[RelationshipProjection]: + if spec is None or spec is False: + return [] + + if isinstance(spec, str): + spec = [spec] + + if isinstance(spec, list): + return [self._rel_projection_spec(node) for node in spec] + + if isinstance(spec, dict): + return [self._rel_projection_spec(node, name) for name, node in spec.items()] + + raise TypeError(f"Invalid relationship projection specification: {spec}") + + def _rel_projection_spec(self, spec: Any, name: Optional[str] = None) -> RelationshipProjection: + if isinstance(spec, str): + return RelationshipProjection(name=name or spec, source_type=spec) + + raise TypeError(f"Invalid relationship projection specification: {spec}") + + def _rel_properties_spec(self, properties: Dict[str, Any]) -> list[RelationshipProperty]: + raise TypeError(f"Invalid relationship projection specification: {properties}") + + def _render_map(self, mapping: Dict[str, Any]) -> str: + return "{" + ", ".join(f"{key}: {value}" for key, value in mapping.items()) + "}" + + # + # def estimate(self, *, nodes: Any, relationships: Any, **config: Any) -> "Series[Any]": + # pass diff --git a/graphdatascience/graph/graph_proc_runner.py b/graphdatascience/graph/graph_proc_runner.py index f7633dcd8..90c1d6763 100644 --- a/graphdatascience/graph/graph_proc_runner.py +++ b/graphdatascience/graph/graph_proc_runner.py @@ -27,6 +27,7 @@ from .graph_sample_runner import GraphSampleRunner from .graph_type_check import graph_type_check, graph_type_check_optional from .ogb_loader import OGBLLoader, OGBNLoader +from graphdatascience.graph.graph_cypher_runner import GraphCypherRunner Strings = Union[str, List[str]] @@ -165,6 +166,11 @@ def project(self) -> GraphProjectRunner: self._namespace += ".project" return GraphProjectRunner(self._query_runner, self._namespace, self._server_version) + @property + def cypher(self) -> GraphCypherRunner: + self._namespace += ".project" + return GraphCypherRunner(self._query_runner, self._namespace, self._server_version) + @property def export(self) -> GraphExportRunner: self._namespace += ".export" diff --git a/graphdatascience/tests/unit/test_graph_cypher.py b/graphdatascience/tests/unit/test_graph_cypher.py new file mode 100644 index 000000000..532690ebb --- /dev/null +++ b/graphdatascience/tests/unit/test_graph_cypher.py @@ -0,0 +1,308 @@ +import pytest + +from .conftest import CollectingQueryRunner +from graphdatascience.graph_data_science import GraphDataScience +from graphdatascience.server_version.server_version import ServerVersion + + +@pytest.mark.parametrize("server_version", [ServerVersion(2, 4, 0)]) +def test_run_project(runner: CollectingQueryRunner, gds: GraphDataScience) -> None: + G, _ = gds.graph.cypher.run_project("MATCH 
+    G, _ = gds.graph.cypher.run_project("MATCH (s)-->(t) RETURN gds.graph.project('gg', s, t)")
+
+    assert G.name() == "gg"
+    assert runner.last_params() == {}
+
+    assert runner.last_query() == "MATCH (s)-->(t) RETURN gds.graph.project('gg', s, t)"
+
+
+@pytest.mark.parametrize("server_version", [ServerVersion(2, 4, 0)])
+def test_run_project_with_return_as(runner: CollectingQueryRunner, gds: GraphDataScience) -> None:
+    G, _ = gds.graph.cypher.run_project("MATCH (s)-->(t) RETURN gds.graph.project('gg', s, t) AS graph")
+
+    assert G.name() == "gg"
+    assert runner.last_params() == {}
+
+    assert runner.last_query() == "MATCH (s)-->(t) RETURN gds.graph.project('gg', s, t)"
+
+
+@pytest.mark.parametrize("server_version", [ServerVersion(2, 4, 0)])
+def test_run_project_with_graph_name_parameter(runner: CollectingQueryRunner, gds: GraphDataScience) -> None:
+    G, _ = gds.graph.cypher.run_project(
+        "MATCH (s)-->(t) RETURN gds.graph.project($graph_name, s, t)", params={"graph_name": "gg"}
+    )
+
+    assert G.name() == "gg"
+    assert runner.last_params() == {"graph_name": "gg"}
+
+    assert runner.last_query() == "MATCH (s)-->(t) RETURN gds.graph.project($graph_name, s, t)"
+
+
+@pytest.mark.parametrize("server_version", [ServerVersion(2, 4, 0)])
+def test_all(runner: CollectingQueryRunner, gds: GraphDataScience) -> None:
+    G, _ = gds.graph.cypher.project("g")
+
+    assert G.name() == "g"
+    assert runner.last_params() == {"graph_name": "g"}
+
+    assert (
+        runner.last_query()
+        == """MATCH (source)-->(target)
+RETURN gds.graph.project($graph_name, source, target)"""
+    )
+
+
+@pytest.mark.parametrize("server_version", [ServerVersion(2, 4, 0)])
+def test_disconnected(runner: CollectingQueryRunner, gds: GraphDataScience) -> None:
+    G, _ = gds.graph.cypher.project("g", allow_disconnected_nodes=True)
+
+    assert G.name() == "g"
+    assert runner.last_params() == {"graph_name": "g"}
+
+    assert (
+        runner.last_query()
+        == """MATCH (source)
+OPTIONAL MATCH (source)-->(target)
+RETURN gds.graph.project($graph_name, source, target)"""
+    )
+
+
+@pytest.mark.parametrize("server_version", [ServerVersion(2, 4, 0)])
+def test_inverse_graph(runner: CollectingQueryRunner, gds: GraphDataScience) -> None:
+    G, _ = gds.graph.cypher.project("g", inverse=True)  # TODO: or using orientation="INVERSE"?
+
+    assert G.name() == "g"
+    assert runner.last_params() == {"graph_name": "g"}
+
+    assert (
+        runner.last_query()
+        == """MATCH (source)<--(target)
+RETURN gds.graph.project($graph_name, source, target)"""
+    )
+
+
+@pytest.mark.parametrize("server_version", [ServerVersion(2, 4, 0)])
+def test_single_node_label(runner: CollectingQueryRunner, gds: GraphDataScience) -> None:
+    G, _ = gds.graph.cypher.project("g", nodes="A")
+
+    assert G.name() == "g"
+    assert runner.last_params() == {
+        "graph_name": "g",
+        "data_config": {"sourceNodeLabels": ["A"], "targetNodeLabels": ["A"]},
+    }
+
+    assert (
+        runner.last_query()
+        == """MATCH (source:A)-->(target:A)
+RETURN gds.graph.project($graph_name, source, target, $data_config)"""
+    )
+
+
+@pytest.mark.parametrize("server_version", [ServerVersion(2, 4, 0)])
+def test_disconnected_nodes_single_node_label(runner: CollectingQueryRunner, gds: GraphDataScience) -> None:
+    G, _ = gds.graph.cypher.project("g", nodes="A", allow_disconnected_nodes=True)
+
+    assert G.name() == "g"
+    assert runner.last_params() == {
+        "graph_name": "g",
+        "data_config": {"sourceNodeLabels": ["A"], "targetNodeLabels": ["A"]},
+    }
+
+    assert (
+        runner.last_query()
+        == """MATCH (source:A)
+OPTIONAL MATCH (source)-->(target:A)
+RETURN gds.graph.project($graph_name, source, target, $data_config)"""
+    )
+
+
+@pytest.mark.parametrize("server_version", [ServerVersion(2, 4, 0)])
+def test_single_node_label_alias(runner: CollectingQueryRunner, gds: GraphDataScience) -> None:
+    G, _ = gds.graph.cypher.project("g", nodes={"Target": "Label"})
+
+    assert G.name() == "g"
+    assert runner.last_params() == {
+        "graph_name": "g",
+        "data_config": {"sourceNodeLabels": ["Target"], "targetNodeLabels": ["Target"]},
+    }
+
+    assert (
+        runner.last_query()
+        == """MATCH (source:Label)-->(target:Label)
+RETURN gds.graph.project($graph_name, source, target, $data_config)"""
+    )
+
+
+@pytest.mark.parametrize("server_version", [ServerVersion(2, 4, 0)])
+def test_multiple_node_labels_and(runner: CollectingQueryRunner, gds: GraphDataScience) -> None:
+    G, _ = gds.graph.cypher.project("g", nodes=["A", "B"], combine_labels_with="AND")
+
+    assert G.name() == "g"
+    assert runner.last_params() == {
+        "graph_name": "g",
+        "data_config": {"sourceNodeLabels": ["A", "B"], "targetNodeLabels": ["A", "B"]},
+    }
+
+    assert (
+        runner.last_query()
+        == """MATCH (source:A:B)-->(target:A:B)
+RETURN gds.graph.project($graph_name, source, target, $data_config)"""
+    )
+
+
+@pytest.mark.parametrize("server_version", [ServerVersion(2, 4, 0)])
+def test_disconnected_nodes_multiple_node_labels_and(runner: CollectingQueryRunner, gds: GraphDataScience) -> None:
+    G, _ = gds.graph.cypher.project("g", nodes=["A", "B"], combine_labels_with="AND", allow_disconnected_nodes=True)
+
+    assert G.name() == "g"
+    assert runner.last_params() == {
+        "graph_name": "g",
+        "data_config": {"sourceNodeLabels": ["A", "B"], "targetNodeLabels": ["A", "B"]},
+    }
+
+    assert (
+        runner.last_query()
+        == """MATCH (source:A:B)
+OPTIONAL MATCH (source)-->(target:A:B)
+RETURN gds.graph.project($graph_name, source, target, $data_config)"""
+    )
+
+
+@pytest.mark.parametrize("server_version", [ServerVersion(2, 4, 0)])
+def test_multiple_node_labels_or(runner: CollectingQueryRunner, gds: GraphDataScience) -> None:
+    G, _ = gds.graph.cypher.project("g", nodes=["A", "B"], combine_labels_with="OR")
+
+    assert G.name() == "g"
+    assert runner.last_params() == {"graph_name": "g"}
+
+    assert runner.last_query() == (
+        """MATCH (source)-->(target)
+WHERE (source:A OR source:B) AND (target:A OR target:B)
+RETURN gds.graph.project($graph_name, source, target, {"""
+        "sourceNodeLabels: labels(source), "
+        "targetNodeLabels: labels(target)})"
+    )
+
+
+@pytest.mark.parametrize("server_version", [ServerVersion(2, 4, 0)])
+def test_disconnected_nodes_multiple_node_labels_or(runner: CollectingQueryRunner, gds: GraphDataScience) -> None:
+    G, _ = gds.graph.cypher.project("g", nodes=["A", "B"], combine_labels_with="OR", allow_disconnected_nodes=True)
+
+    assert G.name() == "g"
+    assert runner.last_params() == {"graph_name": "g"}
+
+    assert runner.last_query() == (
+        """MATCH (source)
+WHERE source:A OR source:B
+OPTIONAL MATCH (source)-->(target)
+WHERE target:A OR target:B
+RETURN gds.graph.project($graph_name, source, target, {"""
+        "sourceNodeLabels: labels(source), "
+        "targetNodeLabels: labels(target)})"
+    )
+
+
+@pytest.mark.parametrize("server_version", [ServerVersion(2, 4, 0)])
+def test_single_multi_graph(runner: CollectingQueryRunner, gds: GraphDataScience) -> None:
+    G, _ = gds.graph.cypher.project("g", nodes="A", relationships="REL")
+
+    assert G.name() == "g"
+    assert runner.last_params() == {
+        "graph_name": "g",
+        "data_config": {"sourceNodeLabels": ["A"], "targetNodeLabels": ["A"], "relationshipType": "REL"},
+    }
+
+    assert (
+        runner.last_query()
+        == """MATCH (source:A)-[:REL]->(target:A)
+RETURN gds.graph.project($graph_name, source, target, $data_config)"""
+    )
+
+
+@pytest.mark.parametrize("server_version", [ServerVersion(2, 4, 0)])
+def test_disconnected_nodes_single_multi_graph(runner: CollectingQueryRunner, gds: GraphDataScience) -> None:
+    G, _ = gds.graph.cypher.project("g", nodes="A", relationships="REL", allow_disconnected_nodes=True)
+
+    assert G.name() == "g"
+    assert runner.last_params() == {
+        "graph_name": "g",
+        "data_config": {"sourceNodeLabels": ["A"], "targetNodeLabels": ["A"], "relationshipType": "REL"},
+    }
+
+    assert (
+        runner.last_query()
+        == """MATCH (source:A)
+OPTIONAL MATCH (source)-[:REL]->(target:A)
+RETURN gds.graph.project($graph_name, source, target, $data_config)"""
+    )
+
+
+@pytest.mark.parametrize("server_version", [ServerVersion(2, 4, 0)])
+def test_multiple_multi_graph(runner: CollectingQueryRunner, gds: GraphDataScience) -> None:
+    G, _ = gds.graph.cypher.project("g", nodes=["A", "B"], relationships=["REL1", "REL2"])
+
+    assert G.name() == "g"
+    assert runner.last_params() == {"graph_name": "g"}
+
+    assert (
+        runner.last_query()
+        == """MATCH (source)-[rel:REL1|REL2]->(target)
+WHERE (source:A OR source:B) AND (target:A OR target:B)
+RETURN gds.graph.project($graph_name, source, target, {"""
+        "sourceNodeLabels: labels(source), "
+        "targetNodeLabels: labels(target), "
+        "relationshipTypes: type(rel)})"
+    )
+
+
+@pytest.mark.parametrize("server_version", [ServerVersion(2, 4, 0)])
+def test_node_properties(runner: CollectingQueryRunner, gds: GraphDataScience) -> None:
+    G, _ = gds.graph.cypher.project(
+        "g", nodes={"L1": ["prop1"], "L2": ["prop2", "prop3"], "L3": {"prop4": True, "prop5": {}}}
+    )
+
+    assert G.name() == "g"
+    assert runner.last_params() == {"graph_name": "g"}
+
+    assert runner.last_query() == (
+        """MATCH (source)-->(target)
+WHERE (source:L1 OR source:L2 OR source:L3) AND (target:L1 OR target:L2 OR target:L3)
+WITH source, target
+CASE
+WHEN 'L1' in labels(source) THEN [source {.prop1}]
+WHEN 'L2' in labels(source) THEN [source {.prop2, .prop3}]
+WHEN 'L3' in labels(source) THEN [source {.prop4, .prop5}]
+END AS sourceNodeProperties
+CASE
+WHEN 'L1' in labels(target) THEN [target {.prop1}]
+WHEN 'L2' in labels(target) THEN [target {.prop2, .prop3}]
+WHEN 'L3' in labels(target) THEN [target {.prop4, .prop5}]
+END AS targetNodeProperties
+RETURN gds.graph.project($graph_name, source, target, {"""
+        "sourceNodeLabels: labels(source), "
+        "targetNodeLabels: labels(target), "
+        "sourceNodeProperties: sourceNodeProperties, "
+        "targetNodeProperties: targetNodeProperties})"
+    )
+
+
+@pytest.mark.skip(reason="Not implemented yet")
+@pytest.mark.parametrize("server_version", [ServerVersion(2, 4, 0)])
+def test_node_properties_alias(runner: CollectingQueryRunner, gds: GraphDataScience) -> None:
+    G, _ = gds.graph.cypher.project(
+        "g", nodes={"A": {"target_prop1": "source_prop1", "target_prop2": {"property_key": "source_prop2"}}}
+    )
+
+    assert G.name() == "g"
+    assert runner.last_params() == {"graph_name": "g"}
+
+    assert runner.last_query() == (
+        """MATCH (source:A)-->(target:A)
+WITH source, target, """
+        "[{target_prop1: source.source_prop1, target_prop2: source.source_prop2}] AS sourceNodeProperties, "
+        """[{target_prop1: target.source_prop1, target_prop2: target.source_prop2}] AS targetNodeProperties
+RETURN gds.graph.project($graph_name, source, target, {"""
+        "sourceNodeLabels: labels(source), "
+        "targetNodeLabels: labels(target), "
+        "sourceNodeProperties: sourceNodeProperties, "
+        "targetNodeProperties: targetNodeProperties})"
+    )
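
Illustrative usage of the API added by this patch (not part of the diff): a minimal sketch assuming a GDS server >= 2.4.0; the connection URI and credentials are placeholders, and the calls mirror the unit tests above.

    from graphdatascience import GraphDataScience

    # Hypothetical connection details; any reachable Bolt endpoint would do.
    gds = GraphDataScience("bolt://localhost:7687", auth=("neo4j", "password"))

    # Keyword-argument form, as exercised by test_single_multi_graph:
    # projects :A nodes connected by :REL relationships into a graph named "g".
    G, stats = gds.graph.cypher.project("g", nodes="A", relationships="REL")

    # Raw-query form, as exercised by test_run_project:
    # the query must end in a RETURN gds.graph.project(...) call.
    G, stats = gds.graph.cypher.run_project(
        "MATCH (s)-->(t) RETURN gds.graph.project('gg', s, t)"
    )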