Skip to content

Commit f0f16c2

Browse files
committed
go/ast: implement Apply for general tree traversal/rewriting
See also golang/go#17108.
1 parent f298a47 commit f0f16c2

File tree

1 file changed

+365
-0
lines changed

1 file changed

+365
-0
lines changed

src/go/ast/apply.go

+365
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,365 @@
1+
// Copyright 2016 The Go Authors. All rights reserved.
2+
// Use of this source code is governed by a BSD-style
3+
// license that can be found in the LICENSE file.
4+
5+
package ast
6+
7+
import (
8+
"fmt"
9+
"reflect"
10+
)
11+
12+
// An ApplyFunc is invoked by Apply for each node n, even if n is nil,
13+
// before and/or after the node's children.
14+
//
15+
// The parent, name, and index arguments identify the parent node's field
16+
// containing n. If that field is a slice, index identifies the node's position
17+
// in that slice; index is < 0 otherwise. Roughly speaking, the following
18+
// invariants hold:
19+
//
20+
// parent.name == n if index < 0
21+
// parent.name[index] == n if index >= 0
22+
//
23+
// SetField(parent, name, index, n1) can be used to change that field
24+
// to a different node n1.
25+
//
26+
// Exception: If the parent is a *Package, and Apply is iterating
27+
// through the Files map, name is the filename, and index is -1.
28+
//
29+
// The return value of ApplyFunc controls the syntax tree traversal.
30+
// See Apply for details.
31+
type ApplyFunc func(parent Node, name string, index int, n Node) bool
32+
33+
// Apply traverses a syntax tree recursively, starting with root,
34+
// and calling pre and post for each node as described below. The
35+
// result is the (possibly modified) syntax tree.
36+
//
37+
// If pre is not nil, it is called for each node before its children
38+
// are traversed (pre-order). If the result of calling pre is false,
39+
// no children are traversed, and post is not called for that node.
40+
//
41+
// If post is not nil, it is called for each node after its children
42+
// were traversed (post-order). If the result of calling post is false,
43+
// traversal is terminated and Apply returns immediately.
44+
//
45+
// Only fields that refer to AST nodes are considered children.
46+
// Children are traversed in the order in which they appear in the
47+
// respective node's struct definition.
48+
func Apply(root Node, pre, post ApplyFunc) Node {
49+
defer func() {
50+
if r := recover(); r != nil && r != abort {
51+
panic(r)
52+
}
53+
}()
54+
a := &application{root, pre, post}
55+
a.apply(a, "Node", -1, a.Node)
56+
return a.Node
57+
}
58+
59+
// SetField sets the named field in the parent node to n. If the field
60+
// is a slice, index is the slice index. The named field must exist in
61+
// the parent, n must be assignable to that field, and the field must be
62+
// indexable if index >= 0. In other words, SetField performs the following
63+
// assignment:
64+
//
65+
// parent.name = n if index < 0
66+
// parent.name[index] = n if index >= 0
67+
//
68+
// The parent node may be a pointer to the struct containing the named
69+
// field, or it may be the struct itself.
70+
//
71+
// Exception: If the parent is a Package, n must be a *File and name is
72+
// interpreted as the filename in the Package.Files map.
73+
func SetField(parent Node, name string, index int, n Node) {
74+
// TODO(gri) This doesn't handle the Package.Files map yet.
75+
v := reflect.Indirect(reflect.ValueOf(parent)).FieldByName(name)
76+
if index >= 0 {
77+
v = v.Index(index)
78+
}
79+
v.Set(reflect.ValueOf(n))
80+
}
81+
82+
type application struct {
83+
Node
84+
pre, post ApplyFunc
85+
}
86+
87+
func (a *application) apply(parent Node, name string, index int, n Node) {
88+
if a.pre != nil && !a.pre(parent, name, index, n) {
89+
return
90+
}
91+
92+
// walk children
93+
// (the order of the cases matches the order
94+
// of the corresponding node types in ast.go)
95+
switch n := n.(type) {
96+
case nil:
97+
// nothing to do
98+
99+
// Comments and fields
100+
case *Comment:
101+
// nothing to do
102+
103+
case *CommentGroup:
104+
if n != nil {
105+
for i, x := range n.List {
106+
a.apply(n, "List", i, x)
107+
}
108+
}
109+
110+
case *Field:
111+
a.apply(n, "Doc", -1, n.Doc)
112+
a.applyIdentList(n, "Names", n.Names)
113+
a.apply(n, "Type", -1, n.Type)
114+
a.apply(n, "Tag", -1, n.Tag)
115+
a.apply(n, "Comment", -1, n.Comment)
116+
117+
case *FieldList:
118+
if n != nil {
119+
for i, x := range n.List {
120+
a.apply(n, "List", i, x)
121+
}
122+
}
123+
124+
// Expressions
125+
case *BadExpr, *Ident, *BasicLit:
126+
// nothing to do
127+
128+
case *Ellipsis:
129+
a.apply(n, "Elt", -1, n.Elt)
130+
131+
case *FuncLit:
132+
a.apply(n, "Type", -1, n.Type)
133+
a.apply(n, "Body", -1, n.Body)
134+
135+
case *CompositeLit:
136+
a.apply(n, "Type", -1, n.Type)
137+
a.applyExprList(n, "Elts", n.Elts)
138+
139+
case *ParenExpr:
140+
a.apply(n, "X", -1, n.X)
141+
142+
case *SelectorExpr:
143+
a.apply(n, "X", -1, n.X)
144+
a.apply(n, "Sel", -1, n.Sel)
145+
146+
case *IndexExpr:
147+
a.apply(n, "X", -1, n.X)
148+
a.apply(n, "Index", -1, n.Index)
149+
150+
case *SliceExpr:
151+
a.apply(n, "X", -1, n.X)
152+
a.apply(n, "Low", -1, n.Low)
153+
a.apply(n, "High", -1, n.High)
154+
a.apply(n, "Max", -1, n.Max)
155+
156+
case *TypeAssertExpr:
157+
a.apply(n, "X", -1, n.X)
158+
a.apply(n, "Type", -1, n.Type)
159+
160+
case *CallExpr:
161+
a.apply(n, "Fun", -1, n.Fun)
162+
a.applyExprList(n, "Args", n.Args)
163+
164+
case *StarExpr:
165+
a.apply(n, "X", -1, n.X)
166+
167+
case *UnaryExpr:
168+
a.apply(n, "X", -1, n.X)
169+
170+
case *BinaryExpr:
171+
a.apply(n, "X", -1, n.X)
172+
a.apply(n, "Y", -1, n.Y)
173+
174+
case *KeyValueExpr:
175+
a.apply(n, "Key", -1, n.Key)
176+
a.apply(n, "Value", -1, n.Value)
177+
178+
// Types
179+
case *ArrayType:
180+
a.apply(n, "Len", -1, n.Len)
181+
a.apply(n, "Elt", -1, n.Elt)
182+
183+
case *StructType:
184+
a.apply(n, "Fields", -1, n.Fields)
185+
186+
case *FuncType:
187+
a.apply(n, "Params", -1, n.Params)
188+
a.apply(n, "Results", -1, n.Results)
189+
190+
case *InterfaceType:
191+
a.apply(n, "Methods", -1, n.Methods)
192+
193+
case *MapType:
194+
a.apply(n, "Key", -1, n.Key)
195+
a.apply(n, "Value", -1, n.Value)
196+
197+
case *ChanType:
198+
a.apply(n, "Value", -1, n.Value)
199+
200+
// Statements
201+
case *BadStmt:
202+
// nothing to do
203+
204+
case *DeclStmt:
205+
a.apply(n, "Decl", -1, n.Decl)
206+
207+
case *EmptyStmt:
208+
// nothing to do
209+
210+
case *LabeledStmt:
211+
a.apply(n, "Label", -1, n.Label)
212+
a.apply(n, "Stmt", -1, n.Stmt)
213+
214+
case *ExprStmt:
215+
a.apply(n, "X", -1, n.X)
216+
217+
case *SendStmt:
218+
a.apply(n, "Chan", -1, n.Chan)
219+
a.apply(n, "Value", -1, n.Value)
220+
221+
case *IncDecStmt:
222+
a.apply(n, "X", -1, n.X)
223+
224+
case *AssignStmt:
225+
a.applyExprList(n, "Lhs", n.Lhs)
226+
a.applyExprList(n, "Rhs", n.Rhs)
227+
228+
case *GoStmt:
229+
a.apply(n, "Call", -1, n.Call)
230+
231+
case *DeferStmt:
232+
a.apply(n, "Call", -1, n.Call)
233+
234+
case *ReturnStmt:
235+
a.applyExprList(n, "Results", n.Results)
236+
237+
case *BranchStmt:
238+
a.apply(n, "Label", -1, n.Label)
239+
240+
case *BlockStmt:
241+
a.applyStmtList(n, "List", n.List)
242+
243+
case *IfStmt:
244+
a.apply(n, "Init", -1, n.Init)
245+
a.apply(n, "Cond", -1, n.Cond)
246+
a.apply(n, "Body", -1, n.Body)
247+
a.apply(n, "Else", -1, n.Else)
248+
249+
case *CaseClause:
250+
a.applyExprList(n, "List", n.List)
251+
a.applyStmtList(n, "Body", n.Body)
252+
253+
case *SwitchStmt:
254+
a.apply(n, "Init", -1, n.Init)
255+
a.apply(n, "Tag", -1, n.Tag)
256+
a.apply(n, "Body", -1, n.Body)
257+
258+
case *TypeSwitchStmt:
259+
a.apply(n, "Init", -1, n.Init)
260+
a.apply(n, "Assign", -1, n.Assign)
261+
a.apply(n, "Body", -1, n.Body)
262+
263+
case *CommClause:
264+
a.apply(n, "Comm", -1, n.Comm)
265+
a.applyStmtList(n, "Body", n.Body)
266+
267+
case *SelectStmt:
268+
a.apply(n, "Body", -1, n.Body)
269+
270+
case *ForStmt:
271+
a.apply(n, "Init", -1, n.Init)
272+
a.apply(n, "Cond", -1, n.Cond)
273+
a.apply(n, "Post", -1, n.Post)
274+
a.apply(n, "Body", -1, n.Body)
275+
276+
case *RangeStmt:
277+
a.apply(n, "Key", -1, n.Key)
278+
a.apply(n, "Value", -1, n.Value)
279+
a.apply(n, "X", -1, n.X)
280+
a.apply(n, "Body", -1, n.Body)
281+
282+
// Declarations
283+
case *ImportSpec:
284+
a.apply(n, "Doc", -1, n.Doc)
285+
a.apply(n, "Name", -1, n.Name)
286+
a.apply(n, "Path", -1, n.Path)
287+
a.apply(n, "Comment", -1, n.Comment)
288+
289+
case *ValueSpec:
290+
a.apply(n, "Doc", -1, n.Doc)
291+
a.applyIdentList(n, "Names", n.Names)
292+
a.apply(n, "Type", -1, n.Type)
293+
a.applyExprList(n, "Values", n.Values)
294+
a.apply(n, "Comment", -1, n.Comment)
295+
296+
case *TypeSpec:
297+
a.apply(n, "Doc", -1, n.Doc)
298+
a.apply(n, "Name", -1, n.Name)
299+
a.apply(n, "Type", -1, n.Type)
300+
a.apply(n, "Comment", -1, n.Comment)
301+
302+
case *BadDecl:
303+
// nothing to do
304+
305+
case *GenDecl:
306+
a.apply(n, "Doc", -1, n.Doc)
307+
for i, x := range n.Specs {
308+
a.apply(n, "Specs", i, x)
309+
}
310+
311+
case *FuncDecl:
312+
a.apply(n, "Doc", -1, n.Doc)
313+
a.apply(n, "Recv", -1, n.Recv)
314+
a.apply(n, "Name", -1, n.Name)
315+
a.apply(n, "Type", -1, n.Type)
316+
a.apply(n, "Body", -1, n.Body)
317+
318+
// Files and packages
319+
case *File:
320+
a.apply(n, "Doc", -1, n.Doc)
321+
a.apply(n, "Name", -1, n.Name)
322+
a.applyDeclList(n, "Decls", n.Decls)
323+
// don't walk n.Comments - they have been
324+
// visited already through the individual
325+
// nodes
326+
327+
case *Package:
328+
for name, f := range n.Files {
329+
a.apply(n, name, -1, f)
330+
}
331+
332+
default:
333+
panic(fmt.Sprintf("ast.Apply: unexpected node type %T", n))
334+
}
335+
336+
if a.post != nil && !a.post(parent, name, index, n) {
337+
panic(abort)
338+
}
339+
}
340+
341+
var abort = new(int) // singleton, to signal abortion of Apply
342+
343+
func (a *application) applyIdentList(parent Node, name string, list []*Ident) {
344+
for i, x := range list {
345+
a.apply(parent, name, i, x)
346+
}
347+
}
348+
349+
func (a *application) applyExprList(parent Node, name string, list []Expr) {
350+
for i, x := range list {
351+
a.apply(parent, name, i, x)
352+
}
353+
}
354+
355+
func (a *application) applyStmtList(parent Node, name string, list []Stmt) {
356+
for i, x := range list {
357+
a.apply(parent, name, i, x)
358+
}
359+
}
360+
361+
func (a *application) applyDeclList(parent Node, name string, list []Decl) {
362+
for i, x := range list {
363+
a.apply(parent, name, i, x)
364+
}
365+
}

0 commit comments

Comments
 (0)