@@ -36,22 +36,22 @@ func extractVariable(fset *token.FileSet, start, end token.Pos, src []byte, file
36
36
// TODO: stricter rules for selectorExpr.
37
37
case * ast.BasicLit , * ast.CompositeLit , * ast.IndexExpr , * ast.SliceExpr ,
38
38
* ast.UnaryExpr , * ast.BinaryExpr , * ast.SelectorExpr :
39
- lhsName , _ := generateAvailableIdentifier (expr .Pos (), path , pkg , info , "x " , 0 )
39
+ lhsName , _ := generateAvailableIdentifier (expr .Pos (), path , pkg , info , "newVar " , 0 )
40
40
lhsNames = append (lhsNames , lhsName )
41
41
case * ast.CallExpr :
42
42
tup , ok := info .TypeOf (expr ).(* types.Tuple )
43
43
if ! ok {
44
44
// If the call expression only has one return value, we can treat it the
45
45
// same as our standard extract variable case.
46
- lhsName , _ := generateAvailableIdentifier (expr .Pos (), path , pkg , info , "x " , 0 )
46
+ lhsName , _ := generateAvailableIdentifier (expr .Pos (), path , pkg , info , "newVar " , 0 )
47
47
lhsNames = append (lhsNames , lhsName )
48
48
break
49
49
}
50
50
idx := 0
51
51
for i := 0 ; i < tup .Len (); i ++ {
52
52
// Generate a unique variable for each return value.
53
53
var lhsName string
54
- lhsName , idx = generateAvailableIdentifier (expr .Pos (), path , pkg , info , "x " , idx )
54
+ lhsName , idx = generateAvailableIdentifier (expr .Pos (), path , pkg , info , "newVar " , idx )
55
55
lhsNames = append (lhsNames , lhsName )
56
56
}
57
57
default :
@@ -105,6 +105,346 @@ func extractVariable(fset *token.FileSet, start, end token.Pos, src []byte, file
105
105
}, nil
106
106
}
107
107
108
+ func replaceAllOccursOfExpr (fset * token.FileSet , start , end token.Pos , src []byte , file * ast.File , pkg * types.Package , info * types.Info ) (* token.FileSet , * analysis.SuggestedFix , error ) {
109
+ tokFile := fset .File (file .Pos ())
110
+ // exprs contains at least one expr.
111
+ exprs , _ , err := allOccurs (start , end , file )
112
+ if err != nil {
113
+ return nil , nil , fmt .Errorf ("extractVariable: cannot extract %s: %v" , safetoken .StartPosition (fset , start ), err )
114
+ }
115
+
116
+ scopes := make ([][]* types.Scope , len (exprs ))
117
+ for i , e := range exprs {
118
+ path , _ := astutil .PathEnclosingInterval (file , e .Pos (), e .End ())
119
+ scopes [i ] = CollectScopes (info , path , e .Pos ())
120
+ }
121
+
122
+ // Find the deepest common scope among all expressions.
123
+ commonScope , err := findDeepestCommonScope (scopes )
124
+ if err != nil {
125
+ return nil , nil , fmt .Errorf ("extractVariable: %v" , err )
126
+ }
127
+
128
+ var innerScopes []* types.Scope
129
+ for _ , scope := range scopes {
130
+ for _ , s := range scope {
131
+ if s != nil {
132
+ innerScopes = append (innerScopes , s )
133
+ break
134
+ }
135
+ }
136
+ }
137
+ if len (innerScopes ) != len (exprs ) {
138
+ return nil , nil , fmt .Errorf ("extractVariable: nil scope" )
139
+ }
140
+ // So the largest scope's name won't conflict.
141
+ innerScopes = append (innerScopes , commonScope )
142
+
143
+ // Create new AST node for extracted code.
144
+ var lhsNames []string
145
+ switch expr := exprs [0 ].(type ) {
146
+ // TODO: stricter rules for selectorExpr.
147
+ case * ast.BasicLit , * ast.CompositeLit , * ast.IndexExpr , * ast.SliceExpr ,
148
+ * ast.UnaryExpr , * ast.BinaryExpr , * ast.SelectorExpr :
149
+ lhsName , _ := generateAvailableIdentifierForAllScopes (innerScopes , "newVar" , 0 )
150
+ lhsNames = append (lhsNames , lhsName )
151
+ case * ast.CallExpr :
152
+ tup , ok := info .TypeOf (expr ).(* types.Tuple )
153
+ if ! ok {
154
+ // If the call expression only has one return value, we can treat it the
155
+ // same as our standard extract variable case.
156
+ lhsName , _ := generateAvailableIdentifierForAllScopes (innerScopes , "newVar" , 0 )
157
+ lhsNames = append (lhsNames , lhsName )
158
+ break
159
+ }
160
+ idx := 0
161
+ for i := 0 ; i < tup .Len (); i ++ {
162
+ // Generate a unique variable for each return value.
163
+ var lhsName string
164
+ lhsName , idx = generateAvailableIdentifierForAllScopes (innerScopes , "newVar" , idx )
165
+ lhsNames = append (lhsNames , lhsName )
166
+ }
167
+ default :
168
+ return nil , nil , fmt .Errorf ("cannot extract %T" , expr )
169
+ }
170
+
171
+ var validPath []ast.Node
172
+ if commonScope != innerScopes [0 ] {
173
+ // This means the first expr within function body is not the largest scope,
174
+ // we need to find the scope immediately follow the common
175
+ // scope where we will insert the statement before.
176
+ child := innerScopes [0 ]
177
+ for p := child ; p != nil ; p = p .Parent () {
178
+ if p == commonScope {
179
+ break
180
+ }
181
+ child = p
182
+ }
183
+ validPath , _ = astutil .PathEnclosingInterval (file , child .Pos (), child .End ())
184
+ } else {
185
+ // Just insert before the first expr.
186
+ validPath , _ = astutil .PathEnclosingInterval (file , exprs [0 ].Pos (), exprs [0 ].End ())
187
+ }
188
+ //
189
+ // TODO: There is a bug here: for a variable declared in a labeled
190
+ // switch/for statement it returns the for/switch statement itself
191
+ // which produces the below code which is a compiler error e.g.
192
+ // label:
193
+ // switch r1 := r() { ... break label ... }
194
+ // On extracting "r()" to a variable
195
+ // label:
196
+ // x := r()
197
+ // switch r1 := x { ... break label ... } // compiler error
198
+ //
199
+ insertBeforeStmt := analysisinternal .StmtToInsertVarBefore (validPath )
200
+ if insertBeforeStmt == nil {
201
+ return nil , nil , fmt .Errorf ("cannot find location to insert extraction" )
202
+ }
203
+ indent , err := calculateIndentation (src , tokFile , insertBeforeStmt )
204
+ if err != nil {
205
+ return nil , nil , err
206
+ }
207
+ newLineIndent := "\n " + indent
208
+
209
+ lhs := strings .Join (lhsNames , ", " )
210
+ assignStmt := & ast.AssignStmt {
211
+ Lhs : []ast.Expr {ast .NewIdent (lhs )},
212
+ Tok : token .DEFINE ,
213
+ Rhs : []ast.Expr {exprs [0 ]},
214
+ }
215
+ var buf bytes.Buffer
216
+ if err := format .Node (& buf , fset , assignStmt ); err != nil {
217
+ return nil , nil , err
218
+ }
219
+ assignment := strings .ReplaceAll (buf .String (), "\n " , newLineIndent ) + newLineIndent
220
+ var textEdits []analysis.TextEdit
221
+ textEdits = append (textEdits , analysis.TextEdit {
222
+ Pos : insertBeforeStmt .Pos (),
223
+ End : insertBeforeStmt .Pos (),
224
+ NewText : []byte (assignment ),
225
+ })
226
+ for _ , e := range exprs {
227
+ textEdits = append (textEdits , analysis.TextEdit {
228
+ Pos : e .Pos (),
229
+ End : e .End (),
230
+ NewText : []byte (lhs ),
231
+ })
232
+ }
233
+ return fset , & analysis.SuggestedFix {
234
+ TextEdits : textEdits ,
235
+ }, nil
236
+ }
237
+
238
+ // findDeepestCommonScope finds the deepest (innermost) scope that is common to all provided scope chains.
239
+ // Each scope chain represents the scopes of an expression from innermost to outermost.
240
+ // If no common scope is found, it returns an error.
241
+ func findDeepestCommonScope (scopeChains [][]* types.Scope ) (* types.Scope , error ) {
242
+ if len (scopeChains ) == 0 {
243
+ return nil , fmt .Errorf ("no scopes provided" )
244
+ }
245
+ // Get the first scope chain as the reference.
246
+ referenceChain := scopeChains [0 ]
247
+
248
+ // Iterate from innermost to outermost scope.
249
+ for i := 0 ; i < len (referenceChain ); i ++ {
250
+ candidateScope := referenceChain [i ]
251
+ if candidateScope == nil {
252
+ continue
253
+ }
254
+ isCommon := true
255
+ // See if other exprs' chains all have candidateScope as a common ancestor.
256
+ for _ , chain := range scopeChains [1 :] {
257
+ found := false
258
+ for j := 0 ; j < len (chain ); j ++ {
259
+ if chain [j ] == candidateScope {
260
+ found = true
261
+ break
262
+ }
263
+ }
264
+ if ! found {
265
+ isCommon = false
266
+ break
267
+ }
268
+ }
269
+ if isCommon {
270
+ return candidateScope , nil
271
+ }
272
+ }
273
+ return nil , fmt .Errorf ("no common scope found" )
274
+ }
275
+
276
+ // allOccurs finds all occurrences of an expression identical to the one
277
+ // specified by the start and end positions within the same function.
278
+ // It returns at least one ast.Expr.
279
+ func allOccurs (start , end token.Pos , file * ast.File ) ([]ast.Expr , bool , error ) {
280
+ if start == end {
281
+ return nil , false , fmt .Errorf ("start and end are equal" )
282
+ }
283
+ path , _ := astutil .PathEnclosingInterval (file , start , end )
284
+ if len (path ) == 0 {
285
+ return nil , false , fmt .Errorf ("no path enclosing interval" )
286
+ }
287
+ for _ , n := range path {
288
+ if _ , ok := n .(* ast.ImportSpec ); ok {
289
+ return nil , false , fmt .Errorf ("cannot extract variable in an import block" )
290
+ }
291
+ }
292
+ node := path [0 ]
293
+ if start != node .Pos () || end != node .End () {
294
+ return nil , false , fmt .Errorf ("range does not map to an AST node" )
295
+ }
296
+ expr , ok := node .(ast.Expr )
297
+ if ! ok {
298
+ return nil , false , fmt .Errorf ("node is not an expression" )
299
+ }
300
+
301
+ var exprs []ast.Expr
302
+ exprs = append (exprs , expr )
303
+ if funcDecl , ok := path [len (path )- 2 ].(* ast.FuncDecl ); ok {
304
+ ast .Inspect (funcDecl , func (n ast.Node ) bool {
305
+ if e , ok := n .(ast.Expr ); ok && e != expr {
306
+ if exprIdentical (e , expr ) {
307
+ exprs = append (exprs , e )
308
+ }
309
+ }
310
+ return true
311
+ })
312
+ }
313
+ sort .Slice (exprs , func (i , j int ) bool {
314
+ return exprs [i ].Pos () < exprs [j ].Pos ()
315
+ })
316
+
317
+ switch expr .(type ) {
318
+ case * ast.BasicLit , * ast.CompositeLit , * ast.IndexExpr , * ast.CallExpr ,
319
+ * ast.SliceExpr , * ast.UnaryExpr , * ast.BinaryExpr , * ast.SelectorExpr :
320
+ return exprs , true , nil
321
+ }
322
+ return nil , false , fmt .Errorf ("cannot extract an %T to a variable" , expr )
323
+ }
324
+
325
+ // generateAvailableIdentifierForAllScopes adjusts the new identifier name
326
+ // until there are no collisions in any of the provided scopes.
327
+ func generateAvailableIdentifierForAllScopes (scopes []* types.Scope , prefix string , idx int ) (string , int ) {
328
+ name := prefix
329
+ for {
330
+ collision := false
331
+ for _ , scope := range scopes {
332
+ if scope .Lookup (name ) != nil {
333
+ collision = true
334
+ break
335
+ }
336
+ }
337
+ if ! collision {
338
+ // No collision found; return the name and the current index.
339
+ return name , idx
340
+ }
341
+ // Adjust the name by appending the index and increment the index.
342
+ idx ++
343
+ name = fmt .Sprintf ("%s%d" , prefix , idx )
344
+ }
345
+ }
346
+
347
+ // exprIdentical recursively compares two ast.Expr nodes for structural equality,
348
+ // ignoring position fields.
349
+ func exprIdentical (x , y ast.Expr ) bool {
350
+ if x == nil || y == nil {
351
+ return x == y
352
+ }
353
+ switch x := x .(type ) {
354
+ case * ast.BasicLit :
355
+ y , ok := y .(* ast.BasicLit )
356
+ return ok && x .Kind == y .Kind && x .Value == y .Value
357
+ case * ast.CompositeLit :
358
+ y , ok := y .(* ast.CompositeLit )
359
+ if ! ok || len (x .Elts ) != len (y .Elts ) || ! exprIdentical (x .Type , y .Type ) {
360
+ return false
361
+ }
362
+ for i := range x .Elts {
363
+ if ! exprIdentical (x .Elts [i ], y .Elts [i ]) {
364
+ return false
365
+ }
366
+ }
367
+ return true
368
+ case * ast.ArrayType :
369
+ y , ok := y .(* ast.ArrayType )
370
+ return ok && exprIdentical (x .Len , y .Len ) && exprIdentical (x .Elt , y .Elt )
371
+ case * ast.Ellipsis :
372
+ y , ok := y .(* ast.Ellipsis )
373
+ return ok && exprIdentical (x .Elt , y .Elt )
374
+ case * ast.FuncLit :
375
+ y , ok := y .(* ast.FuncLit )
376
+ return ok && exprIdentical (x .Type , y .Type )
377
+ case * ast.IndexExpr :
378
+ y , ok := y .(* ast.IndexExpr )
379
+ return ok && exprIdentical (x .X , y .X ) && exprIdentical (x .Index , y .Index )
380
+ case * ast.IndexListExpr :
381
+ y , ok := y .(* ast.IndexListExpr )
382
+ if ! ok || len (x .Indices ) != len (y .Indices ) || ! exprIdentical (x .X , y .X ) {
383
+ return false
384
+ }
385
+ for i := range x .Indices {
386
+ if ! exprIdentical (x .Indices [i ], y .Indices [i ]) {
387
+ return false
388
+ }
389
+ }
390
+ return true
391
+ case * ast.SliceExpr :
392
+ y , ok := y .(* ast.SliceExpr )
393
+ return ok && exprIdentical (x .X , y .X ) && exprIdentical (x .Low , y .Low ) && exprIdentical (x .High , y .High ) && exprIdentical (x .Max , y .Max ) && x .Slice3 == y .Slice3
394
+ case * ast.TypeAssertExpr :
395
+ y , ok := y .(* ast.TypeAssertExpr )
396
+ return ok && exprIdentical (x .X , y .X ) && exprIdentical (x .Type , y .Type )
397
+ case * ast.StarExpr :
398
+ y , ok := y .(* ast.StarExpr )
399
+ return ok && exprIdentical (x .X , y .X )
400
+ case * ast.KeyValueExpr :
401
+ y , ok := y .(* ast.KeyValueExpr )
402
+ return ok && exprIdentical (x .Key , y .Key ) && exprIdentical (x .Value , y .Value )
403
+ case * ast.UnaryExpr :
404
+ y , ok := y .(* ast.UnaryExpr )
405
+ return ok && x .Op == y .Op && exprIdentical (x .X , y .X )
406
+ case * ast.MapType :
407
+ y , ok := y .(* ast.MapType )
408
+ return ok && exprIdentical (x .Value , y .Value ) && exprIdentical (x .Key , y .Key )
409
+ case * ast.ChanType :
410
+ y , ok := y .(* ast.ChanType )
411
+ return ok && exprIdentical (x .Value , y .Value ) && x .Dir == y .Dir
412
+ case * ast.BinaryExpr :
413
+ y , ok := y .(* ast.BinaryExpr )
414
+ return ok && x .Op == y .Op &&
415
+ exprIdentical (x .X , y .X ) &&
416
+ exprIdentical (x .Y , y .Y )
417
+ case * ast.Ident :
418
+ y , ok := y .(* ast.Ident )
419
+ return ok && x .Name == y .Name
420
+ case * ast.ParenExpr :
421
+ y , ok := y .(* ast.ParenExpr )
422
+ return ok && exprIdentical (x .X , y .X )
423
+ case * ast.SelectorExpr :
424
+ y , ok := y .(* ast.SelectorExpr )
425
+ return ok &&
426
+ exprIdentical (x .X , y .X ) &&
427
+ exprIdentical (x .Sel , y .Sel )
428
+ case * ast.CallExpr :
429
+ y , ok := y .(* ast.CallExpr )
430
+ if ! ok || len (x .Args ) != len (y .Args ) {
431
+ return false
432
+ }
433
+ if ! exprIdentical (x .Fun , y .Fun ) {
434
+ return false
435
+ }
436
+ for i := range x .Args {
437
+ if ! exprIdentical (x .Args [i ], y .Args [i ]) {
438
+ return false
439
+ }
440
+ }
441
+ return true
442
+ default :
443
+ // For unhandled expression types, consider them unequal.
444
+ return false
445
+ }
446
+ }
447
+
108
448
// canExtractVariable reports whether the code in the given range can be
109
449
// extracted to a variable.
110
450
func canExtractVariable (start , end token.Pos , file * ast.File ) (ast.Expr , []ast.Node , bool , error ) {
0 commit comments