|
6 | 6 | from sqlmesh.lsp.context import LSPContext, ModelTarget, AuditTarget
|
7 | 7 | from sqlglot import exp
|
8 | 8 | from sqlmesh.lsp.description import generate_markdown_description
|
| 9 | +from sqlglot.optimizer.scope import build_scope |
9 | 10 | from sqlmesh.lsp.uri import URI
|
10 | 11 | from sqlmesh.utils.pydantic import PydanticModel
|
11 | 12 |
|
12 | 13 |
|
13 | 14 | class Reference(PydanticModel):
|
14 | 15 | """
|
15 |
| - A reference to a model. |
| 16 | + A reference to a model or CTE. |
16 | 17 |
|
17 | 18 | Attributes:
|
18 | 19 | range: The range of the reference in the source file
|
19 | 20 | uri: The uri of the referenced model
|
20 | 21 | markdown_description: The markdown description of the referenced model
|
| 22 | + target_range: The range of the definition for go-to-definition (optional, used for CTEs) |
21 | 23 | """
|
22 | 24 |
|
23 | 25 | range: Range
|
24 | 26 | uri: str
|
25 | 27 | markdown_description: t.Optional[str] = None
|
| 28 | + target_range: t.Optional[Range] = None |
26 | 29 |
|
27 | 30 |
|
28 | 31 | def by_position(position: Position) -> t.Callable[[Reference], bool]:
|
@@ -88,6 +91,7 @@ def get_model_definitions_for_a_path(
|
88 | 91 | - Need to normalize it before matching
|
89 | 92 | - Try get_model before normalization
|
90 | 93 | - Match to models that the model refers to
|
| 94 | + - Also find CTE references within the query |
91 | 95 | """
|
92 | 96 | path = document_uri.to_path()
|
93 | 97 | if path.suffix != ".sql":
|
@@ -126,66 +130,112 @@ def get_model_definitions_for_a_path(
|
126 | 130 | # Find all possible references
|
127 | 131 | references = []
|
128 | 132 |
|
129 |
| - # Get SQL query and find all table references |
130 |
| - tables = list(query.find_all(exp.Table)) |
131 |
| - if len(tables) == 0: |
132 |
| - return [] |
133 |
| - |
134 | 133 | with open(file_path, "r", encoding="utf-8") as file:
|
135 | 134 | read_file = file.readlines()
|
136 | 135 |
|
137 |
| - for table in tables: |
138 |
| - # Normalize the table reference |
139 |
| - unaliased = table.copy() |
140 |
| - if unaliased.args.get("alias") is not None: |
141 |
| - unaliased.set("alias", None) |
142 |
| - reference_name = unaliased.sql(dialect=dialect) |
143 |
| - try: |
144 |
| - normalized_reference_name = normalize_model_name( |
145 |
| - reference_name, |
146 |
| - default_catalog=lint_context.context.default_catalog, |
147 |
| - dialect=dialect, |
148 |
| - ) |
149 |
| - if normalized_reference_name not in depends_on: |
150 |
| - continue |
151 |
| - except Exception: |
152 |
| - # Skip references that cannot be normalized |
153 |
| - continue |
154 |
| - |
155 |
| - # Get the referenced model uri |
156 |
| - referenced_model = lint_context.context.get_model( |
157 |
| - model_or_snapshot=normalized_reference_name, raise_if_missing=False |
158 |
| - ) |
159 |
| - if referenced_model is None: |
160 |
| - continue |
161 |
| - referenced_model_path = referenced_model._path |
162 |
| - # Check whether the path exists |
163 |
| - if not referenced_model_path.is_file(): |
164 |
| - continue |
165 |
| - referenced_model_uri = URI.from_path(referenced_model_path) |
166 |
| - |
167 |
| - # Extract metadata for positioning |
168 |
| - table_meta = TokenPositionDetails.from_meta(table.this.meta) |
169 |
| - table_range = _range_from_token_position_details(table_meta, read_file) |
170 |
| - start_pos = table_range.start |
171 |
| - end_pos = table_range.end |
172 |
| - |
173 |
| - # If there's a catalog or database qualifier, adjust the start position |
174 |
| - catalog_or_db = table.args.get("catalog") or table.args.get("db") |
175 |
| - if catalog_or_db is not None: |
176 |
| - catalog_or_db_meta = TokenPositionDetails.from_meta(catalog_or_db.meta) |
177 |
| - catalog_or_db_range = _range_from_token_position_details(catalog_or_db_meta, read_file) |
178 |
| - start_pos = catalog_or_db_range.start |
179 |
| - |
180 |
| - description = generate_markdown_description(referenced_model) |
181 |
| - |
182 |
| - references.append( |
183 |
| - Reference( |
184 |
| - uri=referenced_model_uri.value, |
185 |
| - range=Range(start=start_pos, end=end_pos), |
186 |
| - markdown_description=description, |
187 |
| - ) |
188 |
| - ) |
| 136 | + # Build scope tree to properly handle nested CTEs |
| 137 | + root_scope = build_scope(query) |
| 138 | + |
| 139 | + if root_scope: |
| 140 | + # Traverse all scopes to find CTE definitions and table references |
| 141 | + for scope in root_scope.traverse(): |
| 142 | + # Build a map of CTE names to their definitions within this scope |
| 143 | + cte_definitions = {} |
| 144 | + |
| 145 | + # For CTEs defined in this scope |
| 146 | + for cte in scope.ctes: |
| 147 | + if cte.alias: |
| 148 | + cte_definitions[cte.alias] = cte |
| 149 | + |
| 150 | + # Also include CTEs from parent scopes (for references inside nested CTEs) |
| 151 | + parent = scope.parent |
| 152 | + while parent: |
| 153 | + for cte in parent.ctes: |
| 154 | + if cte.alias and cte.alias not in cte_definitions: |
| 155 | + cte_definitions[cte.alias] = cte |
| 156 | + parent = parent.parent |
| 157 | + |
| 158 | + # Get all table references in this scope |
| 159 | + tables = list(scope.find_all(exp.Table)) |
| 160 | + |
| 161 | + for table in tables: |
| 162 | + table_name = table.name |
| 163 | + |
| 164 | + # Check if this table reference is a CTE in the current scope |
| 165 | + if cte_def := cte_definitions.get(table_name): |
| 166 | + # This is a CTE reference - create a reference to the CTE definition |
| 167 | + alias = cte_def.args["alias"] |
| 168 | + if isinstance(alias, exp.TableAlias): |
| 169 | + identifier = alias.this |
| 170 | + if isinstance(identifier, exp.Identifier): |
| 171 | + target_range = _range_from_token_position_details( |
| 172 | + TokenPositionDetails.from_meta(identifier.meta), read_file |
| 173 | + ) |
| 174 | + table_range = _range_from_token_position_details( |
| 175 | + TokenPositionDetails.from_meta(table.this.meta), read_file |
| 176 | + ) |
| 177 | + references.append( |
| 178 | + Reference( |
| 179 | + uri=document_uri.value, # Same file |
| 180 | + range=table_range, |
| 181 | + target_range=target_range, |
| 182 | + ) |
| 183 | + ) |
| 184 | + |
| 185 | + # For non-CTE tables, process as before (external model references) |
| 186 | + # Normalize the table reference |
| 187 | + unaliased = table.copy() |
| 188 | + if unaliased.args.get("alias") is not None: |
| 189 | + unaliased.set("alias", None) |
| 190 | + reference_name = unaliased.sql(dialect=dialect) |
| 191 | + try: |
| 192 | + normalized_reference_name = normalize_model_name( |
| 193 | + reference_name, |
| 194 | + default_catalog=lint_context.context.default_catalog, |
| 195 | + dialect=dialect, |
| 196 | + ) |
| 197 | + if normalized_reference_name not in depends_on: |
| 198 | + continue |
| 199 | + except Exception: |
| 200 | + # Skip references that cannot be normalized |
| 201 | + continue |
| 202 | + |
| 203 | + # Get the referenced model uri |
| 204 | + referenced_model = lint_context.context.get_model( |
| 205 | + model_or_snapshot=normalized_reference_name, raise_if_missing=False |
| 206 | + ) |
| 207 | + if referenced_model is None: |
| 208 | + continue |
| 209 | + referenced_model_path = referenced_model._path |
| 210 | + # Check whether the path exists |
| 211 | + if not referenced_model_path.is_file(): |
| 212 | + continue |
| 213 | + referenced_model_uri = URI.from_path(referenced_model_path) |
| 214 | + |
| 215 | + # Extract metadata for positioning |
| 216 | + table_meta = TokenPositionDetails.from_meta(table.this.meta) |
| 217 | + table_range = _range_from_token_position_details(table_meta, read_file) |
| 218 | + start_pos = table_range.start |
| 219 | + end_pos = table_range.end |
| 220 | + |
| 221 | + # If there's a catalog or database qualifier, adjust the start position |
| 222 | + catalog_or_db = table.args.get("catalog") or table.args.get("db") |
| 223 | + if catalog_or_db is not None: |
| 224 | + catalog_or_db_meta = TokenPositionDetails.from_meta(catalog_or_db.meta) |
| 225 | + catalog_or_db_range = _range_from_token_position_details( |
| 226 | + catalog_or_db_meta, read_file |
| 227 | + ) |
| 228 | + start_pos = catalog_or_db_range.start |
| 229 | + |
| 230 | + description = generate_markdown_description(referenced_model) |
| 231 | + |
| 232 | + references.append( |
| 233 | + Reference( |
| 234 | + uri=referenced_model_uri.value, |
| 235 | + range=Range(start=start_pos, end=end_pos), |
| 236 | + markdown_description=description, |
| 237 | + ) |
| 238 | + ) |
189 | 239 |
|
190 | 240 | return references
|
191 | 241 |
|
|
0 commit comments