Skip to content

Commit 0343b68

Browse files
kinto0meta-codesync[bot]
authored andcommitted
go-to-definition on calls jumps to __init__/__new__/__call__
Summary: When go-to-definition is invoked on a call expression, jump to the underlying method definition instead of the class or variable: - Constructor calls (`Foo(1)`, `mod.Foo(1)`): jump to `__init__` and/or `__new__` on the class - Callable instances (`adder(5)` where `adder` has `__call__`): jump to `__call__` on the instance's class - Regular function calls: unchanged (still jumps to the function def) The implementation detects call context via `covering_nodes`, resolves the callee's type, and dispatches to `find_call_target_definitions` which uses the existing `find_attribute_definition_for_base_type` machinery. Only `ClassDef` types get constructor lookup; only `ClassType` (instances) get `__call__` lookup; functions and other callables fall through to normal go-to-definition. Note: this stack is a different implementation from D87102818. D87102818 reused the find-references queue which ended up being more complicated and required storing a reverse-mapping in state. it was more correct (it would only go to the constructor pyrefly thought was used). instead, this implementation goes to all `__init__`, `__new__`, `__call__`. it might be more useful to see all of them as separate definitions and is much cleaner to implement. fixes #1877 fixes #1728 Reviewed By: yangdanny97 Differential Revision: D96921976 fbshipit-source-id: c9bfad0699a421720a75ad146b0db92b33b97d36
1 parent 6701f6a commit 0343b68

File tree

4 files changed

+162
-24
lines changed

4 files changed

+162
-24
lines changed

pyrefly/lib/state/lsp.rs

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1402,6 +1402,40 @@ impl<'a> Transaction<'a> {
14021402
})
14031403
}
14041404

1405+
/// When a name or attribute in a call position resolves to a class, find
1406+
/// `__init__` and `__new__` definitions. When it resolves to a class
1407+
/// instance, find `__call__`. Returns all found definitions, or empty if
1408+
/// neither case applies. Does not match functions/callables — those should
1409+
/// use the normal go-to-definition path.
1410+
fn find_call_target_definitions(
1411+
&self,
1412+
handle: &Handle,
1413+
preference: FindPreference,
1414+
ty: Type,
1415+
) -> Vec<FindDefinitionItemWithDocstring> {
1416+
match &ty {
1417+
Type::ClassDef(_) => {
1418+
let mut defs = self.find_attribute_definition_for_base_type(
1419+
handle,
1420+
preference,
1421+
ty.clone(),
1422+
&dunder::INIT,
1423+
);
1424+
defs.extend(self.find_attribute_definition_for_base_type(
1425+
handle,
1426+
preference,
1427+
ty,
1428+
&dunder::NEW,
1429+
));
1430+
defs
1431+
}
1432+
Type::ClassType(_) => {
1433+
self.find_attribute_definition_for_base_type(handle, preference, ty, &dunder::CALL)
1434+
}
1435+
_ => vec![],
1436+
}
1437+
}
1438+
14051439
pub(crate) fn find_definition_for_base_type(
14061440
&self,
14071441
handle: &Handle,
@@ -1740,6 +1774,23 @@ impl<'a> Transaction<'a> {
17401774
.map_or(vec![], |item| vec![item])
17411775
}
17421776
ExprContext::Load | ExprContext::Del | ExprContext::Invalid => {
1777+
// If this name is the callee of a call expression, jump
1778+
// to constructor or __call__ definitions when applicable.
1779+
if let Some(AnyNodeRef::ExprCall(call)) = covering_nodes.get(1)
1780+
&& call.func.range() == id.range
1781+
&& let Some(bindings) = self.get_bindings(handle)
1782+
{
1783+
let key = Key::BoundName(ShortIdentifier::new(&id));
1784+
if bindings.is_valid_key(&key)
1785+
&& let Some(ty) = self.get_type(handle, &key)
1786+
{
1787+
let defs =
1788+
self.find_call_target_definitions(handle, preference, ty);
1789+
if !defs.is_empty() {
1790+
return defs;
1791+
}
1792+
}
1793+
}
17431794
// This is a usage of the variable
17441795
self.find_definition_for_name_use(handle, &id, preference)
17451796
.map_or(vec![], |item| vec![item])
@@ -1900,6 +1951,18 @@ impl<'a> Transaction<'a> {
19001951
identifier,
19011952
context: IdentifierContext::Attribute { base_range, .. },
19021953
}) => {
1954+
// If this attribute is the callee of a call expression, jump
1955+
// to constructor or __call__ definitions when applicable.
1956+
if let Some(AnyNodeRef::ExprAttribute(attr)) = covering_nodes.get(1)
1957+
&& let Some(AnyNodeRef::ExprCall(call)) = covering_nodes.get(2)
1958+
&& call.func.range() == attr.range()
1959+
&& let Some(ty) = self.get_type_trace(handle, attr.range())
1960+
{
1961+
let defs = self.find_call_target_definitions(handle, preference, ty);
1962+
if !defs.is_empty() {
1963+
return defs;
1964+
}
1965+
}
19031966
self.find_definition_for_attribute(handle, base_range, identifier.id(), preference)
19041967
}
19051968
Some(IdentifierWithContext {

pyrefly/lib/test/lsp/definition.rs

Lines changed: 90 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2238,10 +2238,6 @@ Definition Result:
22382238
);
22392239
}
22402240

2241-
// TODO: go-to-definition on constructor calls should jump to __init__/__new__,
2242-
// not the class definition. These tests document the current (wrong) behavior
2243-
// and will be updated when the feature is implemented.
2244-
22452241
#[test]
22462242
fn goto_def_constructor_call_same_module() {
22472243
let code = r#"
@@ -2259,8 +2255,8 @@ Bar("hello")
22592255
6 | Bar("hello")
22602256
^
22612257
Definition Result:
2262-
2 | class Bar:
2263-
^^^
2258+
3 | def __init__(self, name: str) -> None:
2259+
^^^^^^^^
22642260
"#
22652261
.trim(),
22662262
report.trim(),
@@ -2289,8 +2285,8 @@ Foo(1)
22892285
3 | Foo(1)
22902286
^
22912287
Definition Result:
2292-
2 | class Foo:
2293-
^^^
2288+
3 | def __init__(self, x: int) -> None:
2289+
^^^^^^^^
22942290
22952291
22962292
# foo_mod.py
@@ -2326,8 +2322,8 @@ Child(1)
23262322
7 | Child(1)
23272323
^
23282324
Definition Result:
2329-
4 | class Child(Base):
2330-
^^^^^
2325+
3 | def __init__(self, x: int) -> None:
2326+
^^^^^^^^
23312327
23322328
23332329
# base_mod.py
@@ -2357,8 +2353,8 @@ Singleton()
23572353
9 | Singleton()
23582354
^
23592355
Definition Result:
2360-
2 | class Singleton:
2361-
^^^^^^^^^
2356+
4 | def __new__(cls) -> "Singleton":
2357+
^^^^^^^
23622358
"#
23632359
.trim(),
23642360
report.trim(),
@@ -2384,8 +2380,11 @@ MyClass()
23842380
8 | MyClass()
23852381
^
23862382
Definition Result:
2387-
2 | class MyClass:
2388-
^^^^^^^
2383+
5 | def __init__(self) -> None:
2384+
^^^^^^^^
2385+
Definition Result:
2386+
3 | def __new__(cls) -> "MyClass":
2387+
^^^^^^^
23892388
"#
23902389
.trim(),
23912390
report.trim(),
@@ -2414,8 +2413,8 @@ foo_mod.Foo(1)
24142413
3 | foo_mod.Foo(1)
24152414
^
24162415
Definition Result:
2417-
2 | class Foo:
2418-
^^^
2416+
3 | def __init__(self, x: int) -> None:
2417+
^^^^^^^^
24192418
24202419
24212420
# foo_mod.py
@@ -2424,3 +2423,78 @@ Definition Result:
24242423
report.trim(),
24252424
);
24262425
}
2426+
2427+
#[test]
2428+
fn goto_def_callable_instance() {
2429+
let code = r#"
2430+
class Adder:
2431+
def __call__(self, x: int) -> int:
2432+
return x + 1
2433+
2434+
adder = Adder()
2435+
adder(5)
2436+
# ^
2437+
"#;
2438+
let report = get_batched_lsp_operations_report(&[("main", code)], get_test_report);
2439+
assert_eq!(
2440+
r#"
2441+
# main.py
2442+
7 | adder(5)
2443+
^
2444+
Definition Result:
2445+
3 | def __call__(self, x: int) -> int:
2446+
^^^^^^^^
2447+
"#
2448+
.trim(),
2449+
report.trim(),
2450+
);
2451+
}
2452+
2453+
#[test]
2454+
fn goto_def_non_constructor_call_goes_to_function() {
2455+
let code = r#"
2456+
def foo(x: int) -> int:
2457+
return x
2458+
2459+
foo(1)
2460+
# ^
2461+
"#;
2462+
let report = get_batched_lsp_operations_report(&[("main", code)], get_test_report);
2463+
assert_eq!(
2464+
r#"
2465+
# main.py
2466+
5 | foo(1)
2467+
^
2468+
Definition Result:
2469+
2 | def foo(x: int) -> int:
2470+
^^^
2471+
"#
2472+
.trim(),
2473+
report.trim(),
2474+
);
2475+
}
2476+
2477+
#[test]
2478+
fn goto_def_class_name_without_call_goes_to_class() {
2479+
let code = r#"
2480+
class Baz:
2481+
def __init__(self) -> None:
2482+
pass
2483+
2484+
x = Baz
2485+
# ^
2486+
"#;
2487+
let report = get_batched_lsp_operations_report(&[("main", code)], get_test_report);
2488+
assert_eq!(
2489+
r#"
2490+
# main.py
2491+
6 | x = Baz
2492+
^
2493+
Definition Result:
2494+
2 | class Baz:
2495+
^^^
2496+
"#
2497+
.trim(),
2498+
report.trim(),
2499+
);
2500+
}

pyrefly/lib/test/lsp/hover.rs

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1020,7 +1020,7 @@ from mymod.submod.deep import Bar
10201020
}
10211021

10221022
#[test]
1023-
fn hover_on_constructor_shows_instance_type() {
1023+
fn hover_on_constructor() {
10241024
let code = r#"
10251025
class Person:
10261026
def __init__(self, name: str, age: int) -> None: ...
@@ -1030,8 +1030,9 @@ Person()
10301030
"#;
10311031
let report = get_batched_lsp_operations_report_allow_error(&[("main", code)], get_test_report);
10321032
assert!(
1033-
report
1034-
.contains("def Person(\n self: Person,\n name: str,\n age: int\n) -> Person"),
1033+
report.contains(
1034+
"def __init__(\n self: Person,\n name: str,\n age: int\n) -> Person"
1035+
),
10351036
"Expected constructor hover to show complete signature with -> Person, got: {report}"
10361037
);
10371038
}

pyrefly/lib/test/lsp/lsp_interaction/definition.rs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -489,18 +489,18 @@ fn goto_type_def_on_list_of_primitives_shows_selector() {
489489

490490
#[test]
491491
fn test_go_to_def_constructor_calls() {
492-
// Note: go-to-definition currently goes to the class definition, not __init__.
492+
// go-to-definition on constructor calls should go to __init__
493493
let root = get_test_files_root();
494494
let constructor_root = root.path().join("constructor_references");
495495
test_go_to_def(
496496
constructor_root,
497497
None,
498498
"usage.py",
499499
vec![
500-
// Person("Alice", 30) - goes to class Person definition
501-
(7, 7, "person.py", 6, 6, 6, 12),
502-
// Person("Bob", 25) - goes to class Person definition
503-
(8, 7, "person.py", 6, 6, 6, 12),
500+
// Person("Alice", 30) - goes to Person.__init__ definition
501+
(7, 7, "person.py", 7, 8, 7, 16),
502+
// Person("Bob", 25) - goes to Person.__init__ definition
503+
(8, 7, "person.py", 7, 8, 7, 16),
504504
],
505505
);
506506
}

0 commit comments

Comments
 (0)