From 1e3488bfb4043b4e7cb8b4395afd62f3936d3881 Mon Sep 17 00:00:00 2001 From: Alex Hoppen Date: Tue, 2 Apr 2024 11:59:07 -0700 Subject: [PATCH] =?UTF-8?q?Don=E2=80=99t=20repeat=20a=20function=20in=20`i?= =?UTF-8?q?ncomingCalls`=20if=20it=20contains=20multiple=20calls=20to=20th?= =?UTF-8?q?e=20same=20function?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Eg. if we have the following, and we get the call hierarchy of `foo`, we only want to show `bar` once, with multiple `fromRanges` instead of having two entries for `bar` in the call hierarchy. ```swift func foo() {} func bar() { foo() foo() } ``` --- Sources/SourceKitLSP/SourceKitLSPServer.swift | 66 ++++++++++--------- .../CallHierarchyTests.swift | 57 ++++++++++------ 2 files changed, 72 insertions(+), 51 deletions(-) diff --git a/Sources/SourceKitLSP/SourceKitLSPServer.swift b/Sources/SourceKitLSP/SourceKitLSPServer.swift index c4dc1fa66..fa1979c9f 100644 --- a/Sources/SourceKitLSP/SourceKitLSPServer.swift +++ b/Sources/SourceKitLSP/SourceKitLSPServer.swift @@ -2126,27 +2126,44 @@ extension SourceKitLSPServer { callableUsrs += index.occurrences(ofUSR: data.usr, roles: .overrideOf).flatMap { occurrence in occurrence.relations.filter { $0.roles.contains(.overrideOf) }.map(\.symbol.usr) } + // callOccurrences are all the places that any of the USRs in callableUsrs is called. + // We also load the `calledBy` roles to get the method that contains the reference to this call. let callOccurrences = callableUsrs.flatMap { index.occurrences(ofUSR: $0, roles: .calledBy) } - let calls = callOccurrences.flatMap { occurrence -> [CallHierarchyIncomingCall] in - guard let location = indexToLSPLocation(occurrence.location) else { - return [] + + // Maps functions that call a USR in `callableUSRs` to all the called occurrences of `callableUSRs` within the + // function. If a function `foo` calls `bar` multiple times, `callersToCalls[foo]` will contain two call + // `SymbolOccurrence`s. + // This way, we can group multiple calls to `bar` within `foo` to a single item with multiple `fromRanges`. + var callersToCalls: [Symbol: [SymbolOccurrence]] = [:] + + for call in callOccurrences { + // Callers are all `calledBy` relations of a call to a USR in `callableUsrs`, ie. all the functions that contain a + // call to a USR in callableUSRs. In practice, this should always be a single item. + let callers = call.relations.filter { $0.roles.contains(.calledBy) }.map(\.symbol) + for caller in callers { + callersToCalls[caller, default: []].append(call) } - return occurrence.relations.filter { $0.symbol.kind.isCallable } - .map { related in - // Resolve the caller's definition to find its location - let definition = index.primaryDefinitionOrDeclarationOccurrence(ofUSR: related.symbol.usr) - let definitionSymbolLocation = definition?.location - let definitionLocation = definitionSymbolLocation.flatMap(indexToLSPLocation) - - return CallHierarchyIncomingCall( - from: indexToLSPCallHierarchyItem( - symbol: related.symbol, - containerName: definition?.containerName, - location: definitionLocation ?? location // Use occurrence location as fallback - ), - fromRanges: [location.range] - ) - } + } + + let calls = callersToCalls.compactMap { (caller: Symbol, calls: [SymbolOccurrence]) -> CallHierarchyIncomingCall? in + // Resolve the caller's definition to find its location + let definition = index.primaryDefinitionOrDeclarationOccurrence(ofUSR: caller.usr) + let definitionSymbolLocation = definition?.location + let definitionLocation = definitionSymbolLocation.flatMap(indexToLSPLocation) + + let locations = calls.compactMap { indexToLSPLocation($0.location) }.sorted() + guard !locations.isEmpty else { + return nil + } + + return CallHierarchyIncomingCall( + from: indexToLSPCallHierarchyItem( + symbol: caller, + containerName: definition?.containerName, + location: definitionLocation ?? locations.first! + ), + fromRanges: locations.map(\.range) + ) } return calls.sorted(by: { $0.from.name < $1.from.name }) } @@ -2455,17 +2472,6 @@ extension IndexSymbolKind { return .null } } - - var isCallable: Bool { - switch self { - case .function, .instanceMethod, .classMethod, .staticMethod, .constructor, .destructor, .conversionFunction: - return true - case .unknown, .module, .namespace, .namespaceAlias, .macro, .enum, .struct, .protocol, .extension, .union, - .typealias, .field, .enumConstant, .parameter, .using, .concept, .commentTag, .variable, .instanceProperty, - .class, .staticProperty, .classProperty: - return false - } - } } extension SymbolOccurrence { diff --git a/Tests/SourceKitLSPTests/CallHierarchyTests.swift b/Tests/SourceKitLSPTests/CallHierarchyTests.swift index 203000640..af69f0ffd 100644 --- a/Tests/SourceKitLSPTests/CallHierarchyTests.swift +++ b/Tests/SourceKitLSPTests/CallHierarchyTests.swift @@ -275,9 +275,15 @@ final class CallHierarchyTests: XCTestCase { """ func 1️⃣foo() {} - var testVar: Int 2️⃣{ - let myVar = 3️⃣foo() - return 2 + var testVar: Int { + 2️⃣get { + let myVar = 3️⃣foo() + return 2 + } + } + + func 4️⃣testFunc() { + _ = 5️⃣testVar } """ ) @@ -310,6 +316,31 @@ final class CallHierarchyTests: XCTestCase { ) ] ) + + let testVarItem = try XCTUnwrap(calls?.first?.from) + + let callsToTestVar = try await project.testClient.send(CallHierarchyIncomingCallsRequest(item: testVarItem)) + XCTAssertEqual( + callsToTestVar, + [ + CallHierarchyIncomingCall( + from: CallHierarchyItem( + name: "testFunc()", + kind: .function, + tags: nil, + detail: nil, + uri: project.fileURI, + range: Range(project.positions["4️⃣"]), + selectionRange: Range(project.positions["4️⃣"]), + data: .dictionary([ + "usr": .string("s:4test0A4FuncyyF"), + "uri": .string(project.fileURI.stringValue), + ]) + ), + fromRanges: [Range(project.positions["5️⃣"])] + ) + ] + ) } func testIncomingCallHierarchyShowsAccessToVariables() async throws { @@ -348,24 +379,8 @@ final class CallHierarchyTests: XCTestCase { "uri": .string(project.fileURI.stringValue), ]) ), - fromRanges: [Range(project.positions["3️⃣"])] - ), - CallHierarchyIncomingCall( - from: CallHierarchyItem( - name: "testFunc()", - kind: .function, - tags: nil, - detail: nil, - uri: project.fileURI, - range: Range(project.positions["2️⃣"]), - selectionRange: Range(project.positions["2️⃣"]), - data: .dictionary([ - "usr": .string("s:4test0A4FuncyyF"), - "uri": .string(project.fileURI.stringValue), - ]) - ), - fromRanges: [Range(project.positions["4️⃣"])] - ), + fromRanges: [Range(project.positions["3️⃣"]), Range(project.positions["4️⃣"])] + ) ] ) }