Skip to content

Commit 110711a

Browse files
committed
feat(lsp): add go to definition for ctes
1 parent a58a59f commit 110711a

File tree

3 files changed

+111
-15
lines changed

3 files changed

+111
-15
lines changed

sqlmesh/lsp/main.py

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -265,21 +265,31 @@ def goto_definition(
265265
raise RuntimeError(f"No context found for document: {document.path}")
266266

267267
references = get_references(self.lsp_context, uri, params.position)
268-
return [
269-
types.LocationLink(
270-
target_uri=reference.uri,
271-
target_selection_range=types.Range(
268+
location_links = []
269+
for reference in references:
270+
# Use target_range if available (for CTEs), otherwise default to start of file
271+
if reference.target_range:
272+
target_range = reference.target_range
273+
target_selection_range = reference.target_range
274+
else:
275+
target_range = types.Range(
272276
start=types.Position(line=0, character=0),
273277
end=types.Position(line=0, character=0),
274-
),
275-
target_range=types.Range(
278+
)
279+
target_selection_range = types.Range(
276280
start=types.Position(line=0, character=0),
277281
end=types.Position(line=0, character=0),
278-
),
279-
origin_selection_range=reference.range,
282+
)
283+
284+
location_links.append(
285+
types.LocationLink(
286+
target_uri=reference.uri,
287+
target_selection_range=target_selection_range,
288+
target_range=target_range,
289+
origin_selection_range=reference.range,
290+
)
280291
)
281-
for reference in references
282-
]
292+
return location_links
283293
except Exception as e:
284294
ls.show_message(f"Error getting references: {e}", types.MessageType.Error)
285295
return []

sqlmesh/lsp/reference.py

Lines changed: 44 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,17 +11,19 @@
1111

1212
class Reference(PydanticModel):
1313
"""
14-
A reference to a model.
14+
A reference to a model or CTE.
1515
1616
Attributes:
1717
range: The range of the reference in the source file
18-
uri: The uri of the referenced model
19-
description: The description of the referenced model
18+
uri: The uri of the referenced model or file
19+
description: The description of the referenced model or CTE
20+
target_range: The range of the definition for go-to-definition (optional, used for CTEs)
2021
"""
2122

2223
range: Range
2324
uri: str
2425
description: t.Optional[str] = None
26+
target_range: t.Optional[Range] = None
2527

2628

2729
def by_position(position: Position) -> t.Callable[[Reference], bool]:
@@ -87,6 +89,7 @@ def get_model_definitions_for_a_path(
8789
- Need to normalize it before matching
8890
- Try get_model before normalization
8991
- Match to models that the model refers to
92+
- Also find CTE references within the query
9093
"""
9194
path = document_uri.to_path()
9295
if path.suffix != ".sql":
@@ -127,13 +130,49 @@ def get_model_definitions_for_a_path(
127130

128131
# Get SQL query and find all table references
129132
tables = list(query.find_all(exp.Table))
130-
if len(tables) == 0:
131-
return []
132133

133134
with open(file_path, "r", encoding="utf-8") as file:
134135
read_file = file.readlines()
135136

137+
# Build a map of CTE names to their definitions for CTE go-to-definition
138+
cte_definitions = {}
139+
with_clause = query.find(exp.With)
140+
if with_clause:
141+
for cte in with_clause.expressions:
142+
if isinstance(cte.alias, str):
143+
cte_definitions[cte.alias] = cte
144+
136145
for table in tables:
146+
table_name = table.name
147+
148+
# Check if this table reference is a CTE
149+
if table_name in cte_definitions:
150+
try:
151+
# This is a CTE reference - create a reference to the CTE definition
152+
cte_def = cte_definitions[table_name]
153+
args = cte_def.args["alias"]
154+
if args and isinstance(args, exp.TableAlias):
155+
identifier = args.this
156+
if isinstance(identifier, exp.Identifier):
157+
meta = identifier.meta
158+
159+
table_meta = TokenPositionDetails.from_meta(meta)
160+
target_range = _range_from_token_position_details(table_meta, read_file)
161+
table_meta = TokenPositionDetails.from_meta(table.this.meta)
162+
table_range = _range_from_token_position_details(table_meta, read_file)
163+
164+
references.append(
165+
Reference(
166+
uri=document_uri.value, # Same file
167+
range=table_range,
168+
target_range=target_range,
169+
)
170+
)
171+
except Exception:
172+
pass
173+
continue
174+
175+
# For non-CTE tables, process as before (external model references)
137176
# Normalize the table reference
138177
unaliased = table.copy()
139178
if unaliased.args.get("alias") is not None:

tests/lsp/test_reference_cte.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
import re
2+
from sqlmesh.core.context import Context
3+
from sqlmesh.lsp.context import LSPContext, ModelTarget
4+
from sqlmesh.lsp.reference import get_references
5+
from sqlmesh.lsp.uri import URI
6+
from lsprotocol.types import Range, Position
7+
import typing as t
8+
9+
10+
def test_cte_parsing():
11+
context = Context(paths=["examples/sushi"])
12+
lsp_context = LSPContext(context)
13+
14+
# Find model URIs
15+
sushi_customers_path = next(
16+
path
17+
for path, info in lsp_context.map.items()
18+
if isinstance(info, ModelTarget) and "sushi.customers" in info.names
19+
)
20+
21+
with open(sushi_customers_path, "r", encoding="utf-8") as file:
22+
read_file = file.readlines()
23+
24+
# Find position of the cte reference
25+
ranges = find_ranges_from_regex(read_file, r"current_marketing")
26+
assert len(ranges) == 2
27+
# Middle of the second range
28+
position = Position(line=ranges[1].start.line, character=ranges[1].start.character + 4)
29+
30+
references = get_references(lsp_context, URI.from_path(sushi_customers_path), position)
31+
32+
assert len(references) == 1
33+
assert references[0].uri == URI.from_path(sushi_customers_path).value
34+
assert references[0].description is None
35+
assert references[0].range == ranges[1]
36+
37+
38+
def find_ranges_from_regex(read_file: t.List[str], regex: str) -> t.List[Range]:
39+
"""Find all ranges in the read file that match the regex."""
40+
return [
41+
Range(
42+
start=Position(line=line_number, character=match.start()),
43+
end=Position(line=line_number, character=match.end()),
44+
)
45+
for line_number, line in enumerate(read_file)
46+
for match in [m for m in [re.search(regex, line)] if m]
47+
]

0 commit comments

Comments
 (0)