Skip to content

Commit 44bfbc7

Browse files
committed
replace_all
1 parent 386503d commit 44bfbc7

File tree

8 files changed

+574
-33
lines changed

8 files changed

+574
-33
lines changed

gopls/internal/golang/codeaction.go

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -237,6 +237,7 @@ var codeActionProducers = [...]codeActionProducer{
237237
{kind: settings.RefactorExtractMethod, fn: refactorExtractMethod},
238238
{kind: settings.RefactorExtractToNewFile, fn: refactorExtractToNewFile},
239239
{kind: settings.RefactorExtractVariable, fn: refactorExtractVariable},
240+
{kind: settings.RefactorReplaceAllOccursOfExpr, fn: refactorReplaceAllOccursOfExpr},
240241
{kind: settings.RefactorInlineCall, fn: refactorInlineCall, needPkg: true},
241242
{kind: settings.RefactorRewriteChangeQuote, fn: refactorRewriteChangeQuote},
242243
{kind: settings.RefactorRewriteFillStruct, fn: refactorRewriteFillStruct, needPkg: true},
@@ -458,6 +459,15 @@ func refactorExtractVariable(ctx context.Context, req *codeActionsRequest) error
458459
return nil
459460
}
460461

462+
// refactorReplaceAllOccursOfExpr produces "Replace all occcurrance of expr" code actions.
463+
// See [replaceAllOccursOfExpr] for command implementation.
464+
func refactorReplaceAllOccursOfExpr(ctx context.Context, req *codeActionsRequest) error {
465+
if _, ok, _ := allOccurs(req.start, req.end, req.pgf.File); ok {
466+
req.addApplyFixAction(fmt.Sprintf("Replace all occcurrances of expression"), fixReplaceAllOccursOfExpr, req.loc)
467+
}
468+
return nil
469+
}
470+
461471
// refactorExtractToNewFile produces "Extract declarations to new file" code actions.
462472
// See [server.commandHandler.ExtractToNewFile] for command implementation.
463473
func refactorExtractToNewFile(ctx context.Context, req *codeActionsRequest) error {

gopls/internal/golang/extract.go

Lines changed: 343 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,22 +36,22 @@ func extractVariable(fset *token.FileSet, start, end token.Pos, src []byte, file
3636
// TODO: stricter rules for selectorExpr.
3737
case *ast.BasicLit, *ast.CompositeLit, *ast.IndexExpr, *ast.SliceExpr,
3838
*ast.UnaryExpr, *ast.BinaryExpr, *ast.SelectorExpr:
39-
lhsName, _ := generateAvailableIdentifier(expr.Pos(), path, pkg, info, "x", 0)
39+
lhsName, _ := generateAvailableIdentifier(expr.Pos(), path, pkg, info, "newVar", 0)
4040
lhsNames = append(lhsNames, lhsName)
4141
case *ast.CallExpr:
4242
tup, ok := info.TypeOf(expr).(*types.Tuple)
4343
if !ok {
4444
// If the call expression only has one return value, we can treat it the
4545
// same as our standard extract variable case.
46-
lhsName, _ := generateAvailableIdentifier(expr.Pos(), path, pkg, info, "x", 0)
46+
lhsName, _ := generateAvailableIdentifier(expr.Pos(), path, pkg, info, "newVar", 0)
4747
lhsNames = append(lhsNames, lhsName)
4848
break
4949
}
5050
idx := 0
5151
for i := 0; i < tup.Len(); i++ {
5252
// Generate a unique variable for each return value.
5353
var lhsName string
54-
lhsName, idx = generateAvailableIdentifier(expr.Pos(), path, pkg, info, "x", idx)
54+
lhsName, idx = generateAvailableIdentifier(expr.Pos(), path, pkg, info, "newVar", idx)
5555
lhsNames = append(lhsNames, lhsName)
5656
}
5757
default:
@@ -105,6 +105,346 @@ func extractVariable(fset *token.FileSet, start, end token.Pos, src []byte, file
105105
}, nil
106106
}
107107

108+
func replaceAllOccursOfExpr(fset *token.FileSet, start, end token.Pos, src []byte, file *ast.File, pkg *types.Package, info *types.Info) (*token.FileSet, *analysis.SuggestedFix, error) {
109+
tokFile := fset.File(file.Pos())
110+
// exprs contains at least one expr.
111+
exprs, _, err := allOccurs(start, end, file)
112+
if err != nil {
113+
return nil, nil, fmt.Errorf("extractVariable: cannot extract %s: %v", safetoken.StartPosition(fset, start), err)
114+
}
115+
116+
scopes := make([][]*types.Scope, len(exprs))
117+
for i, e := range exprs {
118+
path, _ := astutil.PathEnclosingInterval(file, e.Pos(), e.End())
119+
scopes[i] = CollectScopes(info, path, e.Pos())
120+
}
121+
122+
// Find the deepest common scope among all expressions.
123+
commonScope, err := findDeepestCommonScope(scopes)
124+
if err != nil {
125+
return nil, nil, fmt.Errorf("extractVariable: %v", err)
126+
}
127+
128+
var innerScopes []*types.Scope
129+
for _, scope := range scopes {
130+
for _, s := range scope {
131+
if s != nil {
132+
innerScopes = append(innerScopes, s)
133+
break
134+
}
135+
}
136+
}
137+
if len(innerScopes) != len(exprs) {
138+
return nil, nil, fmt.Errorf("extractVariable: nil scope")
139+
}
140+
// So the largest scope's name won't conflict.
141+
innerScopes = append(innerScopes, commonScope)
142+
143+
// Create new AST node for extracted code.
144+
var lhsNames []string
145+
switch expr := exprs[0].(type) {
146+
// TODO: stricter rules for selectorExpr.
147+
case *ast.BasicLit, *ast.CompositeLit, *ast.IndexExpr, *ast.SliceExpr,
148+
*ast.UnaryExpr, *ast.BinaryExpr, *ast.SelectorExpr:
149+
lhsName, _ := generateAvailableIdentifierForAllScopes(innerScopes, "newVar", 0)
150+
lhsNames = append(lhsNames, lhsName)
151+
case *ast.CallExpr:
152+
tup, ok := info.TypeOf(expr).(*types.Tuple)
153+
if !ok {
154+
// If the call expression only has one return value, we can treat it the
155+
// same as our standard extract variable case.
156+
lhsName, _ := generateAvailableIdentifierForAllScopes(innerScopes, "newVar", 0)
157+
lhsNames = append(lhsNames, lhsName)
158+
break
159+
}
160+
idx := 0
161+
for i := 0; i < tup.Len(); i++ {
162+
// Generate a unique variable for each return value.
163+
var lhsName string
164+
lhsName, idx = generateAvailableIdentifierForAllScopes(innerScopes, "newVar", idx)
165+
lhsNames = append(lhsNames, lhsName)
166+
}
167+
default:
168+
return nil, nil, fmt.Errorf("cannot extract %T", expr)
169+
}
170+
171+
var validPath []ast.Node
172+
if commonScope != innerScopes[0] {
173+
// This means the first expr within function body is not the largest scope,
174+
// we need to find the scope immediately follow the common
175+
// scope where we will insert the statement before.
176+
child := innerScopes[0]
177+
for p := child; p != nil; p = p.Parent() {
178+
if p == commonScope {
179+
break
180+
}
181+
child = p
182+
}
183+
validPath, _ = astutil.PathEnclosingInterval(file, child.Pos(), child.End())
184+
} else {
185+
// Just insert before the first expr.
186+
validPath, _ = astutil.PathEnclosingInterval(file, exprs[0].Pos(), exprs[0].End())
187+
}
188+
//
189+
// TODO: There is a bug here: for a variable declared in a labeled
190+
// switch/for statement it returns the for/switch statement itself
191+
// which produces the below code which is a compiler error e.g.
192+
// label:
193+
// switch r1 := r() { ... break label ... }
194+
// On extracting "r()" to a variable
195+
// label:
196+
// x := r()
197+
// switch r1 := x { ... break label ... } // compiler error
198+
//
199+
insertBeforeStmt := analysisinternal.StmtToInsertVarBefore(validPath)
200+
if insertBeforeStmt == nil {
201+
return nil, nil, fmt.Errorf("cannot find location to insert extraction")
202+
}
203+
indent, err := calculateIndentation(src, tokFile, insertBeforeStmt)
204+
if err != nil {
205+
return nil, nil, err
206+
}
207+
newLineIndent := "\n" + indent
208+
209+
lhs := strings.Join(lhsNames, ", ")
210+
assignStmt := &ast.AssignStmt{
211+
Lhs: []ast.Expr{ast.NewIdent(lhs)},
212+
Tok: token.DEFINE,
213+
Rhs: []ast.Expr{exprs[0]},
214+
}
215+
var buf bytes.Buffer
216+
if err := format.Node(&buf, fset, assignStmt); err != nil {
217+
return nil, nil, err
218+
}
219+
assignment := strings.ReplaceAll(buf.String(), "\n", newLineIndent) + newLineIndent
220+
var textEdits []analysis.TextEdit
221+
textEdits = append(textEdits, analysis.TextEdit{
222+
Pos: insertBeforeStmt.Pos(),
223+
End: insertBeforeStmt.Pos(),
224+
NewText: []byte(assignment),
225+
})
226+
for _, e := range exprs {
227+
textEdits = append(textEdits, analysis.TextEdit{
228+
Pos: e.Pos(),
229+
End: e.End(),
230+
NewText: []byte(lhs),
231+
})
232+
}
233+
return fset, &analysis.SuggestedFix{
234+
TextEdits: textEdits,
235+
}, nil
236+
}
237+
238+
// findDeepestCommonScope finds the deepest (innermost) scope that is common to all provided scope chains.
239+
// Each scope chain represents the scopes of an expression from innermost to outermost.
240+
// If no common scope is found, it returns an error.
241+
func findDeepestCommonScope(scopeChains [][]*types.Scope) (*types.Scope, error) {
242+
if len(scopeChains) == 0 {
243+
return nil, fmt.Errorf("no scopes provided")
244+
}
245+
// Get the first scope chain as the reference.
246+
referenceChain := scopeChains[0]
247+
248+
// Iterate from innermost to outermost scope.
249+
for i := 0; i < len(referenceChain); i++ {
250+
candidateScope := referenceChain[i]
251+
if candidateScope == nil {
252+
continue
253+
}
254+
isCommon := true
255+
// See if other exprs' chains all have candidateScope as a common ancestor.
256+
for _, chain := range scopeChains[1:] {
257+
found := false
258+
for j := 0; j < len(chain); j++ {
259+
if chain[j] == candidateScope {
260+
found = true
261+
break
262+
}
263+
}
264+
if !found {
265+
isCommon = false
266+
break
267+
}
268+
}
269+
if isCommon {
270+
return candidateScope, nil
271+
}
272+
}
273+
return nil, fmt.Errorf("no common scope found")
274+
}
275+
276+
// allOccurs finds all occurrences of an expression identical to the one
277+
// specified by the start and end positions within the same function.
278+
// It returns at least one ast.Expr.
279+
func allOccurs(start, end token.Pos, file *ast.File) ([]ast.Expr, bool, error) {
280+
if start == end {
281+
return nil, false, fmt.Errorf("start and end are equal")
282+
}
283+
path, _ := astutil.PathEnclosingInterval(file, start, end)
284+
if len(path) == 0 {
285+
return nil, false, fmt.Errorf("no path enclosing interval")
286+
}
287+
for _, n := range path {
288+
if _, ok := n.(*ast.ImportSpec); ok {
289+
return nil, false, fmt.Errorf("cannot extract variable in an import block")
290+
}
291+
}
292+
node := path[0]
293+
if start != node.Pos() || end != node.End() {
294+
return nil, false, fmt.Errorf("range does not map to an AST node")
295+
}
296+
expr, ok := node.(ast.Expr)
297+
if !ok {
298+
return nil, false, fmt.Errorf("node is not an expression")
299+
}
300+
301+
var exprs []ast.Expr
302+
exprs = append(exprs, expr)
303+
if funcDecl, ok := path[len(path)-2].(*ast.FuncDecl); ok {
304+
ast.Inspect(funcDecl, func(n ast.Node) bool {
305+
if e, ok := n.(ast.Expr); ok && e != expr {
306+
if exprIdentical(e, expr) {
307+
exprs = append(exprs, e)
308+
}
309+
}
310+
return true
311+
})
312+
}
313+
sort.Slice(exprs, func(i, j int) bool {
314+
return exprs[i].Pos() < exprs[j].Pos()
315+
})
316+
317+
switch expr.(type) {
318+
case *ast.BasicLit, *ast.CompositeLit, *ast.IndexExpr, *ast.CallExpr,
319+
*ast.SliceExpr, *ast.UnaryExpr, *ast.BinaryExpr, *ast.SelectorExpr:
320+
return exprs, true, nil
321+
}
322+
return nil, false, fmt.Errorf("cannot extract an %T to a variable", expr)
323+
}
324+
325+
// generateAvailableIdentifierForAllScopes adjusts the new identifier name
326+
// until there are no collisions in any of the provided scopes.
327+
func generateAvailableIdentifierForAllScopes(scopes []*types.Scope, prefix string, idx int) (string, int) {
328+
name := prefix
329+
for {
330+
collision := false
331+
for _, scope := range scopes {
332+
if scope.Lookup(name) != nil {
333+
collision = true
334+
break
335+
}
336+
}
337+
if !collision {
338+
// No collision found; return the name and the current index.
339+
return name, idx
340+
}
341+
// Adjust the name by appending the index and increment the index.
342+
idx++
343+
name = fmt.Sprintf("%s%d", prefix, idx)
344+
}
345+
}
346+
347+
// exprIdentical recursively compares two ast.Expr nodes for structural equality,
348+
// ignoring position fields.
349+
func exprIdentical(x, y ast.Expr) bool {
350+
if x == nil || y == nil {
351+
return x == y
352+
}
353+
switch x := x.(type) {
354+
case *ast.BasicLit:
355+
y, ok := y.(*ast.BasicLit)
356+
return ok && x.Kind == y.Kind && x.Value == y.Value
357+
case *ast.CompositeLit:
358+
y, ok := y.(*ast.CompositeLit)
359+
if !ok || len(x.Elts) != len(y.Elts) || !exprIdentical(x.Type, y.Type) {
360+
return false
361+
}
362+
for i := range x.Elts {
363+
if !exprIdentical(x.Elts[i], y.Elts[i]) {
364+
return false
365+
}
366+
}
367+
return true
368+
case *ast.ArrayType:
369+
y, ok := y.(*ast.ArrayType)
370+
return ok && exprIdentical(x.Len, y.Len) && exprIdentical(x.Elt, y.Elt)
371+
case *ast.Ellipsis:
372+
y, ok := y.(*ast.Ellipsis)
373+
return ok && exprIdentical(x.Elt, y.Elt)
374+
case *ast.FuncLit:
375+
y, ok := y.(*ast.FuncLit)
376+
return ok && exprIdentical(x.Type, y.Type)
377+
case *ast.IndexExpr:
378+
y, ok := y.(*ast.IndexExpr)
379+
return ok && exprIdentical(x.X, y.X) && exprIdentical(x.Index, y.Index)
380+
case *ast.IndexListExpr:
381+
y, ok := y.(*ast.IndexListExpr)
382+
if !ok || len(x.Indices) != len(y.Indices) || !exprIdentical(x.X, y.X) {
383+
return false
384+
}
385+
for i := range x.Indices {
386+
if !exprIdentical(x.Indices[i], y.Indices[i]) {
387+
return false
388+
}
389+
}
390+
return true
391+
case *ast.SliceExpr:
392+
y, ok := y.(*ast.SliceExpr)
393+
return ok && exprIdentical(x.X, y.X) && exprIdentical(x.Low, y.Low) && exprIdentical(x.High, y.High) && exprIdentical(x.Max, y.Max) && x.Slice3 == y.Slice3
394+
case *ast.TypeAssertExpr:
395+
y, ok := y.(*ast.TypeAssertExpr)
396+
return ok && exprIdentical(x.X, y.X) && exprIdentical(x.Type, y.Type)
397+
case *ast.StarExpr:
398+
y, ok := y.(*ast.StarExpr)
399+
return ok && exprIdentical(x.X, y.X)
400+
case *ast.KeyValueExpr:
401+
y, ok := y.(*ast.KeyValueExpr)
402+
return ok && exprIdentical(x.Key, y.Key) && exprIdentical(x.Value, y.Value)
403+
case *ast.UnaryExpr:
404+
y, ok := y.(*ast.UnaryExpr)
405+
return ok && x.Op == y.Op && exprIdentical(x.X, y.X)
406+
case *ast.MapType:
407+
y, ok := y.(*ast.MapType)
408+
return ok && exprIdentical(x.Value, y.Value) && exprIdentical(x.Key, y.Key)
409+
case *ast.ChanType:
410+
y, ok := y.(*ast.ChanType)
411+
return ok && exprIdentical(x.Value, y.Value) && x.Dir == y.Dir
412+
case *ast.BinaryExpr:
413+
y, ok := y.(*ast.BinaryExpr)
414+
return ok && x.Op == y.Op &&
415+
exprIdentical(x.X, y.X) &&
416+
exprIdentical(x.Y, y.Y)
417+
case *ast.Ident:
418+
y, ok := y.(*ast.Ident)
419+
return ok && x.Name == y.Name
420+
case *ast.ParenExpr:
421+
y, ok := y.(*ast.ParenExpr)
422+
return ok && exprIdentical(x.X, y.X)
423+
case *ast.SelectorExpr:
424+
y, ok := y.(*ast.SelectorExpr)
425+
return ok &&
426+
exprIdentical(x.X, y.X) &&
427+
exprIdentical(x.Sel, y.Sel)
428+
case *ast.CallExpr:
429+
y, ok := y.(*ast.CallExpr)
430+
if !ok || len(x.Args) != len(y.Args) {
431+
return false
432+
}
433+
if !exprIdentical(x.Fun, y.Fun) {
434+
return false
435+
}
436+
for i := range x.Args {
437+
if !exprIdentical(x.Args[i], y.Args[i]) {
438+
return false
439+
}
440+
}
441+
return true
442+
default:
443+
// For unhandled expression types, consider them unequal.
444+
return false
445+
}
446+
}
447+
108448
// canExtractVariable reports whether the code in the given range can be
109449
// extracted to a variable.
110450
func canExtractVariable(start, end token.Pos, file *ast.File) (ast.Expr, []ast.Node, bool, error) {

0 commit comments

Comments
 (0)