Skip to content

Commit 2abe017

Browse files
committed
feat(lsp): add go to definition for ctes
1 parent 2808e66 commit 2abe017

File tree

3 files changed

+192
-68
lines changed

3 files changed

+192
-68
lines changed

sqlmesh/lsp/main.py

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -279,21 +279,31 @@ def goto_definition(
279279
raise RuntimeError(f"No context found for document: {document.path}")
280280

281281
references = get_references(self.lsp_context, uri, params.position)
282-
return [
283-
types.LocationLink(
284-
target_uri=reference.uri,
285-
target_selection_range=types.Range(
282+
location_links = []
283+
for reference in references:
284+
# Use target_range if available (for CTEs), otherwise default to start of file
285+
if reference.target_range:
286+
target_range = reference.target_range
287+
target_selection_range = reference.target_range
288+
else:
289+
target_range = types.Range(
286290
start=types.Position(line=0, character=0),
287291
end=types.Position(line=0, character=0),
288-
),
289-
target_range=types.Range(
292+
)
293+
target_selection_range = types.Range(
290294
start=types.Position(line=0, character=0),
291295
end=types.Position(line=0, character=0),
292-
),
293-
origin_selection_range=reference.range,
296+
)
297+
298+
location_links.append(
299+
types.LocationLink(
300+
target_uri=reference.uri,
301+
target_selection_range=target_selection_range,
302+
target_range=target_range,
303+
origin_selection_range=reference.range,
304+
)
294305
)
295-
for reference in references
296-
]
306+
return location_links
297307
except Exception as e:
298308
ls.show_message(f"Error getting references: {e}", types.MessageType.Error)
299309
return []

sqlmesh/lsp/reference.py

Lines changed: 108 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -6,23 +6,26 @@
66
from sqlmesh.lsp.context import LSPContext, ModelTarget, AuditTarget
77
from sqlglot import exp
88
from sqlmesh.lsp.description import generate_markdown_description
9+
from sqlglot.optimizer.scope import build_scope
910
from sqlmesh.lsp.uri import URI
1011
from sqlmesh.utils.pydantic import PydanticModel
1112

1213

1314
class Reference(PydanticModel):
1415
"""
15-
A reference to a model.
16+
A reference to a model or CTE.
1617
1718
Attributes:
1819
range: The range of the reference in the source file
1920
uri: The uri of the referenced model
2021
markdown_description: The markdown description of the referenced model
22+
target_range: The range of the definition for go-to-definition (optional, used for CTEs)
2123
"""
2224

2325
range: Range
2426
uri: str
2527
markdown_description: t.Optional[str] = None
28+
target_range: t.Optional[Range] = None
2629

2730

2831
def by_position(position: Position) -> t.Callable[[Reference], bool]:
@@ -88,6 +91,7 @@ def get_model_definitions_for_a_path(
8891
- Need to normalize it before matching
8992
- Try get_model before normalization
9093
- Match to models that the model refers to
94+
- Also find CTE references within the query
9195
"""
9296
path = document_uri.to_path()
9397
if path.suffix != ".sql":
@@ -126,66 +130,112 @@ def get_model_definitions_for_a_path(
126130
# Find all possible references
127131
references = []
128132

129-
# Get SQL query and find all table references
130-
tables = list(query.find_all(exp.Table))
131-
if len(tables) == 0:
132-
return []
133-
134133
with open(file_path, "r", encoding="utf-8") as file:
135134
read_file = file.readlines()
136135

137-
for table in tables:
138-
# Normalize the table reference
139-
unaliased = table.copy()
140-
if unaliased.args.get("alias") is not None:
141-
unaliased.set("alias", None)
142-
reference_name = unaliased.sql(dialect=dialect)
143-
try:
144-
normalized_reference_name = normalize_model_name(
145-
reference_name,
146-
default_catalog=lint_context.context.default_catalog,
147-
dialect=dialect,
148-
)
149-
if normalized_reference_name not in depends_on:
150-
continue
151-
except Exception:
152-
# Skip references that cannot be normalized
153-
continue
154-
155-
# Get the referenced model uri
156-
referenced_model = lint_context.context.get_model(
157-
model_or_snapshot=normalized_reference_name, raise_if_missing=False
158-
)
159-
if referenced_model is None:
160-
continue
161-
referenced_model_path = referenced_model._path
162-
# Check whether the path exists
163-
if not referenced_model_path.is_file():
164-
continue
165-
referenced_model_uri = URI.from_path(referenced_model_path)
166-
167-
# Extract metadata for positioning
168-
table_meta = TokenPositionDetails.from_meta(table.this.meta)
169-
table_range = _range_from_token_position_details(table_meta, read_file)
170-
start_pos = table_range.start
171-
end_pos = table_range.end
172-
173-
# If there's a catalog or database qualifier, adjust the start position
174-
catalog_or_db = table.args.get("catalog") or table.args.get("db")
175-
if catalog_or_db is not None:
176-
catalog_or_db_meta = TokenPositionDetails.from_meta(catalog_or_db.meta)
177-
catalog_or_db_range = _range_from_token_position_details(catalog_or_db_meta, read_file)
178-
start_pos = catalog_or_db_range.start
179-
180-
description = generate_markdown_description(referenced_model)
181-
182-
references.append(
183-
Reference(
184-
uri=referenced_model_uri.value,
185-
range=Range(start=start_pos, end=end_pos),
186-
markdown_description=description,
187-
)
188-
)
136+
# Build scope tree to properly handle nested CTEs
137+
root_scope = build_scope(query)
138+
139+
if root_scope:
140+
# Traverse all scopes to find CTE definitions and table references
141+
for scope in root_scope.traverse():
142+
# Build a map of CTE names to their definitions within this scope
143+
cte_definitions = {}
144+
145+
# For CTEs defined in this scope
146+
for cte in scope.ctes:
147+
if cte.alias:
148+
cte_definitions[cte.alias] = cte
149+
150+
# Also include CTEs from parent scopes (for references inside nested CTEs)
151+
parent = scope.parent
152+
while parent:
153+
for cte in parent.ctes:
154+
if cte.alias and cte.alias not in cte_definitions:
155+
cte_definitions[cte.alias] = cte
156+
parent = parent.parent
157+
158+
# Get all table references in this scope
159+
tables = list(scope.find_all(exp.Table))
160+
161+
for table in tables:
162+
table_name = table.name
163+
164+
# Check if this table reference is a CTE in the current scope
165+
if cte_def := cte_definitions.get(table_name):
166+
# This is a CTE reference - create a reference to the CTE definition
167+
alias = cte_def.args["alias"]
168+
if isinstance(alias, exp.TableAlias):
169+
identifier = alias.this
170+
if isinstance(identifier, exp.Identifier):
171+
target_range = _range_from_token_position_details(
172+
TokenPositionDetails.from_meta(identifier.meta), read_file
173+
)
174+
table_range = _range_from_token_position_details(
175+
TokenPositionDetails.from_meta(table.this.meta), read_file
176+
)
177+
references.append(
178+
Reference(
179+
uri=document_uri.value, # Same file
180+
range=table_range,
181+
target_range=target_range,
182+
)
183+
)
184+
185+
# For non-CTE tables, process as before (external model references)
186+
# Normalize the table reference
187+
unaliased = table.copy()
188+
if unaliased.args.get("alias") is not None:
189+
unaliased.set("alias", None)
190+
reference_name = unaliased.sql(dialect=dialect)
191+
try:
192+
normalized_reference_name = normalize_model_name(
193+
reference_name,
194+
default_catalog=lint_context.context.default_catalog,
195+
dialect=dialect,
196+
)
197+
if normalized_reference_name not in depends_on:
198+
continue
199+
except Exception:
200+
# Skip references that cannot be normalized
201+
continue
202+
203+
# Get the referenced model uri
204+
referenced_model = lint_context.context.get_model(
205+
model_or_snapshot=normalized_reference_name, raise_if_missing=False
206+
)
207+
if referenced_model is None:
208+
continue
209+
referenced_model_path = referenced_model._path
210+
# Check whether the path exists
211+
if not referenced_model_path.is_file():
212+
continue
213+
referenced_model_uri = URI.from_path(referenced_model_path)
214+
215+
# Extract metadata for positioning
216+
table_meta = TokenPositionDetails.from_meta(table.this.meta)
217+
table_range = _range_from_token_position_details(table_meta, read_file)
218+
start_pos = table_range.start
219+
end_pos = table_range.end
220+
221+
# If there's a catalog or database qualifier, adjust the start position
222+
catalog_or_db = table.args.get("catalog") or table.args.get("db")
223+
if catalog_or_db is not None:
224+
catalog_or_db_meta = TokenPositionDetails.from_meta(catalog_or_db.meta)
225+
catalog_or_db_range = _range_from_token_position_details(
226+
catalog_or_db_meta, read_file
227+
)
228+
start_pos = catalog_or_db_range.start
229+
230+
description = generate_markdown_description(referenced_model)
231+
232+
references.append(
233+
Reference(
234+
uri=referenced_model_uri.value,
235+
range=Range(start=start_pos, end=end_pos),
236+
markdown_description=description,
237+
)
238+
)
189239

190240
return references
191241

tests/lsp/test_reference_cte.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
import re
2+
from sqlmesh.core.context import Context
3+
from sqlmesh.lsp.context import LSPContext, ModelTarget
4+
from sqlmesh.lsp.reference import get_references
5+
from sqlmesh.lsp.uri import URI
6+
from lsprotocol.types import Range, Position
7+
import typing as t
8+
9+
10+
def test_cte_parsing():
    """Check that a CTE reference in ``sushi.customers`` resolves to its definition.

    For each of the two CTE names probed below, the test expects the name to
    appear on exactly two lines of the model file. It treats the first
    occurrence as the CTE definition and the second as the reference to it
    (NOTE(review): this relies on the layout of the example model file --
    confirm if that file changes).

    For each CTE it asserts that ``get_references`` at the reference position
    returns exactly one reference which:
      * points at the same file,
      * carries no markdown description,
      * has ``range`` on the reference line, and
      * has ``target_range`` on the definition line.
    """
    context = Context(paths=["examples/sushi"])
    lsp_context = LSPContext(context)

    # Locate the file backing the sushi.customers model
    sushi_customers_path = next(
        path
        for path, info in lsp_context.map.items()
        if isinstance(info, ModelTarget) and "sushi.customers" in info.names
    )
    sushi_customers_uri = URI.from_path(sushi_customers_path)

    with open(sushi_customers_path, "r", encoding="utf-8") as file:
        read_file = file.readlines()

    # Find position of the cte reference; the negative lookahead keeps this
    # pattern from also matching "current_marketing_outer"
    ranges = find_ranges_from_regex(read_file, r"current_marketing(?!_outer)")
    assert len(ranges) == 2
    # Offset by a few characters so the cursor sits inside the identifier
    position = Position(line=ranges[1].start.line, character=ranges[1].start.character + 4)
    references = get_references(lsp_context, sushi_customers_uri, position)
    assert len(references) == 1
    assert references[0].uri == sushi_customers_uri.value
    assert references[0].markdown_description is None
    assert (
        references[0].range.start.line == ranges[1].start.line
    )  # The reference location (where we clicked)
    # target_range is optional on Reference; fail clearly if it was not set
    assert references[0].target_range is not None
    assert (
        references[0].target_range.start.line == ranges[0].start.line
    )  # The CTE definition location

    # Find the position of the current_marketing_outer reference
    ranges = find_ranges_from_regex(read_file, r"current_marketing_outer")
    assert len(ranges) == 2
    position = Position(line=ranges[1].start.line, character=ranges[1].start.character + 4)
    references = get_references(lsp_context, sushi_customers_uri, position)
    assert len(references) == 1
    assert references[0].uri == sushi_customers_uri.value
    # Reference exposes ``markdown_description``; it has no ``description`` field
    assert references[0].markdown_description is None
    assert (
        references[0].range.start.line == ranges[1].start.line
    )  # The reference location (where we clicked)
    assert references[0].target_range is not None
    assert (
        references[0].target_range.start.line == ranges[0].start.line
    )  # The CTE definition location
53+
54+
55+
def find_ranges_from_regex(read_file: t.List[str], regex: str) -> t.List[Range]:
    """Locate the first match of ``regex`` on every line of ``read_file``.

    Lines are searched independently with ``re.search``, so each line
    contributes at most one range (its first match) and lines without a match
    contribute nothing. Line numbers are the zero-based indexes into
    ``read_file``; characters are the match's start and end offsets within
    that line.
    """
    found = []
    for line_number, line in enumerate(read_file):
        match = re.search(regex, line)
        if match is None:
            # No occurrence on this line
            continue
        found.append(
            Range(
                start=Position(line=line_number, character=match.start()),
                end=Position(line=line_number, character=match.end()),
            )
        )
    return found

0 commit comments

Comments
 (0)