Skip to content

Commit db6a793

Browse files
committed
[Macros] Add parentContext(of:) operation to MacroExpansionContext.
Extend `MacroExpansionContext` with a new operation `parentContext(of:)` that provides a "pruned" version of the innermost enclosing syntax node for a given node. This allows us to establish contextual information for the syntax nodes that are passed to a macro expansion, without exposing information about unrelated parts of the source file.
1 parent 2b205bc commit db6a793

File tree

5 files changed

+298
-1
lines changed

5 files changed

+298
-1
lines changed

Sources/SwiftSyntaxMacros/BasicMacroExpansionContext.swift

+19
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,8 @@ public class BasicMacroExpansionContext {
6565
/// Used in conjunction with `expansionDiscriminator`.
6666
private var uniqueNames: [String: Int] = [:]
6767

68+
// The parent contexts of any nodes that have been detached.
69+
private var parentContexts: [Syntax: Syntax] = [:]
6870
}
6971

7072
extension BasicMacroExpansionContext {
@@ -86,6 +88,15 @@ extension BasicMacroExpansionContext {
8688
addDisconnected(detached, at: node.position, in: rootSourceFile)
8789
}
8890

91+
// Record macro parent contexts.
92+
do {
93+
var childNode = Syntax(detached)
94+
for parentNode in node.allMacroParentContexts() {
95+
parentContexts[childNode] = parentNode
96+
childNode = parentNode
97+
}
98+
}
99+
89100
return detached
90101
}
91102
}
@@ -190,4 +201,12 @@ extension BasicMacroExpansionContext: MacroExpansionContext {
190201
let converter = SourceLocationConverter(file: fileName, tree: rootSourceFile)
191202
return converter.location(for: rawPosition.advanced(by: offsetAdjustment))
192203
}
204+
205+
public func parentContext<Node: SyntaxProtocol>(of node: Node) -> Syntax? {
206+
if let context = node.allMacroParentContexts().first {
207+
return context
208+
}
209+
210+
return parentContexts[Syntax(node)]
211+
}
193212
}

Sources/SwiftSyntaxMacros/CMakeLists.txt

+1
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ add_swift_host_library(SwiftSyntaxMacros
2626
MacroReplacement.swift
2727
MacroSystem.swift
2828
Syntax+MacroEvaluation.swift
29+
Syntax+ParentContext.swift
2930
)
3031

3132
target_link_libraries(SwiftSyntaxMacros PUBLIC

Sources/SwiftSyntaxMacros/MacroExpansionContext.swift

+30
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,19 @@ public protocol MacroExpansionContext: AnyObject {
7878
at position: PositionInSyntaxNode,
7979
filePathMode: SourceLocationFilePathMode
8080
) -> AbstractSourceLocation?
81+
82+
/// Determine the parent context of the given syntax node.
83+
///
84+
/// For a syntax node that is part of the syntax provided to a macro
85+
/// expansion, find the innermost enclosing context node. A context
86+
/// node is an entity such as a function declaration, type declaration,
87+
/// or extension that can have other entities nested inside it.
88+
/// The resulting context node will have any information about nested
89+
/// entities removed from it, to prevent macro expansion operations from
90+
/// seeing unrelated code within the program. For more information
91+
/// about the identification and pruning of parent contexts, see
92+
/// `SyntaxProtocol.asMacroParentContext`.
93+
func parentContext<Node: SyntaxProtocol>(of node: Node) -> Syntax?
8194
}
8295

8396
extension MacroExpansionContext {
@@ -220,6 +233,23 @@ extension MacroExpansionContext {
220233
}
221234
}
222235

236+
extension MacroExpansionContext {
237+
/// Determine the parent context of the given syntax node.
238+
///
239+
/// For a syntax node that is part of the syntax provided to a macro
240+
/// expansion, find the innermost enclosing context node. A context
241+
/// node is an entity such as a function declaration, type declaration,
242+
/// or extension that can have other entities nested inside it.
243+
/// The resulting context node will have any information about nested
244+
/// entities removed from it, to prevent macro expansion operations from
245+
/// seeing unrelated code within the program. For more information
246+
/// about the identification and pruning of parent contexts, see
247+
/// `SyntaxProtocol.asMacroParentContext`.
248+
public func parentContext<Node: SyntaxProtocol>(of node: Node) -> Syntax? {
249+
return nil
250+
}
251+
}
252+
223253
/// Describe the position within a syntax node that can be used to compute
224254
/// source locations.
225255
public enum PositionInSyntaxNode {
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
//===----------------------------------------------------------------------===//
2+
//
3+
// This source file is part of the Swift.org open source project
4+
//
5+
// Copyright (c) 2014 - 2023 Apple Inc. and the Swift project authors
6+
// Licensed under Apache License v2.0 with Runtime Library Exception
7+
//
8+
// See https://swift.org/LICENSE.txt for license information
9+
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
import SwiftSyntax
14+
import SwiftSyntaxBuilder
15+
16+
extension SyntaxProtocol {
17+
/// If this is a
18+
public func asMacroParentContext() -> Syntax? {
19+
switch Syntax(self).asProtocol(SyntaxProtocol.self) {
20+
// Functions
21+
case var function as HasTrailingOptionalCodeBlock:
22+
function.body = nil
23+
return Syntax(function).detach()
24+
25+
// Nominal types and extensions.
26+
case var typeOrExtension as HasTrailingMemberDeclBlock:
27+
typeOrExtension.members = MemberDeclBlockSyntax(members: MemberDeclListSyntax())
28+
return Syntax(typeOrExtension).detach()
29+
30+
case var subscriptDecl as SubscriptDeclSyntax:
31+
subscriptDecl.accessor = nil
32+
return Syntax(subscriptDecl).detach()
33+
34+
case is EnumCaseElementSyntax:
35+
return Syntax(self).detach()
36+
37+
case var patternBinding as PatternBindingSyntax:
38+
patternBinding.accessor = nil
39+
patternBinding.initializer = nil
40+
return Syntax(patternBinding).detach()
41+
42+
default:
43+
return nil
44+
}
45+
}
46+
47+
/// Return an array of enclosing parent contexts for the purpose of macros,
48+
/// from the innermost enclosing parent context (first in the array) to the
49+
/// outermost.
50+
public func allMacroParentContexts() -> [Syntax] {
51+
var parentContexts: [Syntax] = []
52+
var currentNode = Syntax(self)
53+
while let parentNode = currentNode.parent {
54+
if let parentContext = parentNode.asMacroParentContext() {
55+
parentContexts.append(parentContext)
56+
}
57+
58+
currentNode = parentNode
59+
}
60+
61+
return parentContexts
62+
}
63+
}

Tests/SwiftSyntaxMacrosTest/MacroSystemTests.swift

+185-1
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ import SwiftDiagnostics
1414
import SwiftParser
1515
import SwiftSyntax
1616
import SwiftSyntaxBuilder
17-
import SwiftSyntaxMacros
17+
@_spi(Testing) import SwiftSyntaxMacros
1818
import _SwiftSyntaxTestSupport
1919
import XCTest
2020

@@ -145,6 +145,125 @@ public struct FileIDMacro: ExpressionMacro {
145145
}
146146
}
147147

148+
extension PatternBindingSyntax {
149+
/// When the variable is declaring a single binding, produce the name of
150+
/// that binding.
151+
fileprivate var singleBindingName: String? {
152+
if let identifierPattern = pattern.as(IdentifierPatternSyntax.self) {
153+
return identifierPattern.identifier.text
154+
}
155+
156+
return nil
157+
}
158+
}
159+
160+
extension SyntaxProtocol {
161+
/// Form a function name.
162+
private func formFunctionName(
163+
_ baseName: String, _ parameters: ParameterClauseSyntax?,
164+
isSubscript: Bool = false
165+
) -> String {
166+
let argumentNames: [String] = parameters?.parameterList.map { param in
167+
let argumentLabelText = param.argumentName?.text ?? "_"
168+
return argumentLabelText + ":"
169+
} ?? []
170+
171+
return "\(baseName)(\(argumentNames.joined(separator: "")))"
172+
}
173+
174+
/// Form the #function name for the given node.
175+
fileprivate func functionName<Context: MacroExpansionContext>(
176+
in context: Context
177+
) -> String? {
178+
// Declarations with parameters.
179+
// FIXME: Can we abstract over these?
180+
if let function = self.as(FunctionDeclSyntax.self) {
181+
return formFunctionName(
182+
function.identifier.text, function.signature.input
183+
)
184+
}
185+
186+
if let initializer = self.as(InitializerDeclSyntax.self) {
187+
return formFunctionName("init", initializer.signature.input)
188+
}
189+
190+
if let subscriptDecl = self.as(SubscriptDeclSyntax.self) {
191+
return formFunctionName(
192+
"subscript", subscriptDecl.indices, isSubscript: true
193+
)
194+
}
195+
196+
if let enumCase = self.as(EnumCaseElementSyntax.self) {
197+
guard let associatedValue = enumCase.associatedValue else {
198+
return enumCase.identifier.text
199+
}
200+
201+
let argumentNames = associatedValue.parameterList.map { param in
202+
guard let firstName = param.firstName else {
203+
return "_:"
204+
}
205+
206+
return firstName.text + ":"
207+
}.joined()
208+
209+
return "\(enumCase.identifier.text)(\(argumentNames))"
210+
}
211+
212+
// Accessors use their enclosing context, i.e., a subscript or pattern
213+
// binding.
214+
if self.is(AccessorDeclSyntax.self) {
215+
guard let parentContext = context.parentContext(of: self) else {
216+
return nil
217+
}
218+
219+
return parentContext.functionName(in: context)
220+
}
221+
222+
// All declarations with identifiers.
223+
if let identified = self.asProtocol(IdentifiedDeclSyntax.self) {
224+
return identified.identifier.text
225+
}
226+
227+
// Extensions
228+
if let extensionDecl = self.as(ExtensionDeclSyntax.self) {
229+
// FIXME: It would be nice to be able to switch on type syntax...
230+
let extendedType = extensionDecl.extendedType
231+
if let simple = extendedType.as(SimpleTypeIdentifierSyntax.self) {
232+
return simple.name.text
233+
}
234+
235+
if let member = extendedType.as(MemberTypeIdentifierSyntax.self) {
236+
return member.name.text
237+
}
238+
}
239+
240+
// Pattern bindings.
241+
if let patternBinding = self.as(PatternBindingSyntax.self),
242+
let singleVarName = patternBinding.singleBindingName {
243+
return singleVarName
244+
}
245+
246+
return nil
247+
}
248+
}
249+
250+
public struct FunctionMacro: ExpressionMacro {
251+
public static func expansion<
252+
Node: FreestandingMacroExpansionSyntax,
253+
Context: MacroExpansionContext
254+
>(
255+
of node: Node,
256+
in context: Context
257+
) -> ExprSyntax {
258+
guard let parentContext = context.parentContext(of: node),
259+
let name = parentContext.functionName(in: context) else {
260+
return #""<unknown>""#
261+
}
262+
263+
return ExprSyntax("\(literal: name)").with(\.leadingTrivia, node.leadingTrivia)
264+
}
265+
}
266+
148267
/// Macro whose only purpose is to ensure that we cannot see "out" of the
149268
/// macro expansion syntax node we were given.
150269
struct CheckContextIndependenceMacro: ExpressionMacro {
@@ -707,6 +826,7 @@ public let testMacros: [String: Macro.Type] = [
707826
"colorLiteral": ColorLiteralMacro.self,
708827
"column": ColumnMacro.self,
709828
"fileID": FileIDMacro.self,
829+
"function": FunctionMacro.self,
710830
"imageLiteral": ImageLiteralMacro.self,
711831
"stringify": StringifyMacro.self,
712832
"myError": ErrorMacro.self,
@@ -769,6 +889,70 @@ final class MacroSystemTests: XCTestCase {
769889
)
770890
}
771891

892+
func testPoundFunction() {
893+
assertMacroExpansion(
894+
macros: testMacros,
895+
"""
896+
func f(a: Int, _: Double, c: Int) {
897+
print(#function)
898+
}
899+
900+
struct X {
901+
var computed: String {
902+
get {
903+
#function
904+
}
905+
}
906+
907+
init(from: String) {
908+
#function
909+
}
910+
911+
subscript(a: Int) -> String {
912+
#function
913+
}
914+
915+
subscript(a a: Int) -> String {
916+
#function
917+
}
918+
}
919+
920+
extension A {
921+
static var staticProp: String = #function
922+
}
923+
""",
924+
"""
925+
func f(a: Int, _: Double, c: Int) {
926+
print("f(a:_:c:)")
927+
}
928+
929+
struct X {
930+
var computed: String {
931+
get {
932+
"computed"
933+
}
934+
}
935+
936+
init(from: String) {
937+
"init(from:)"
938+
}
939+
940+
subscript(a: Int) -> String {
941+
"subscript(_:)"
942+
}
943+
944+
subscript(a a: Int) -> String {
945+
"subscript(a:)"
946+
}
947+
}
948+
949+
extension A {
950+
static var staticProp: String = "staticProp"
951+
}
952+
"""
953+
)
954+
}
955+
772956
func testContextUniqueLocalNames() {
773957
let context = BasicMacroExpansionContext()
774958

0 commit comments

Comments
 (0)