Skip to content

Commit 7fb3dd9

Browse files
committed
simplify get rel types
1 parent 08ec26d commit 7fb3dd9

File tree

3 files changed

+14
-15
lines changed

3 files changed

+14
-15
lines changed

mcp_server/src/mcp_server_neo4j_gds/gds.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -204,9 +204,8 @@ def get_relationship_properties_keys(gds: GraphDataScience, relationshipTypes=No
204204
return df["properties_keys"].iloc[0]
205205

206206

207-
def get_relationship_types(gds: GraphDataScience, node_labels=None):
208-
if node_labels is None:
209-
node_labels = []
207+
def get_relationship_types(gds: GraphDataScience):
208+
node_labels = []
210209
type_extractor = """
211210
WITH type(r) AS type
212211
WITH DISTINCT type

mcp_server/src/mcp_server_neo4j_gds/server.py

Lines changed: 1 addition & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -108,14 +108,6 @@ async def handle_list_tools() -> list[types.Tool]:
108108
description="""Get relationship types in the database.""",
109109
inputSchema={
110110
"type": "object",
111-
"properties": {
112-
"nodeLabels": {
113-
"type": "array",
114-
"items": {"type": "string"},
115-
"description": "Ignore relationships whose source and target node is not in the specified node labels. To use this parameter, use get_node_labels first.",
116-
},
117-
},
118-
"required": [],
119111
},
120112
),
121113
]
@@ -151,10 +143,7 @@ async def handle_call_tool(
151143
result = get_node_labels(gds)
152144
return [types.TextContent(type="text", text=serialize_result(result))]
153145
elif name == "get_relationship_types":
154-
if "nodeLabels" in arguments:
155-
result = get_relationship_types(gds, arguments["nodeLabels"])
156-
else:
157-
result = get_relationship_types(gds)
146+
result = get_relationship_types(gds)
158147
return [types.TextContent(type="text", text=serialize_result(result))]
159148

160149
else:

mcp_server/tests/test_basic_tools.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,3 +125,14 @@ async def test_get_node_labels(mcp_client):
125125
properties_keys = json.loads(result_text)
126126

127127
assert properties_keys == ["UndergroundStation"]
128+
129+
130+
@pytest.mark.asyncio
131+
async def test_get_relationship_types(mcp_client):
132+
result = await mcp_client.call_tool("get_relationship_types")
133+
134+
assert len(result) == 1
135+
result_text = result[0]["text"]
136+
properties_keys = json.loads(result_text)
137+
138+
assert properties_keys == ["LINK"]

0 commit comments

Comments
 (0)