Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
141 changes: 135 additions & 6 deletions gopls/internal/golang/extract.go
Original file line number Diff line number Diff line change
Expand Up @@ -630,6 +630,8 @@ func extractFunctionMethod(cpkg *cache.Package, pgf *parsego.File, start, end to
allReturnsFinalErr = true // all ReturnStmts have final 'err' expression
hasReturn = false // selection contains a ReturnStmt
filter = []ast.Node{(*ast.ReturnStmt)(nil), (*ast.FuncLit)(nil)}

origRetStmts []*ast.ReturnStmt // return stmts in source order, for type lookups
)
curEnclosing.Inspect(filter, func(cur inspector.Cursor) (descend bool) {
if funcLit, ok := cur.Node().(*ast.FuncLit); ok {
Expand All @@ -643,6 +645,8 @@ func extractFunctionMethod(cpkg *cache.Package, pgf *parsego.File, start, end to
}
hasReturn = true

origRetStmts = append(origRetStmts, ret)

if cur.Parent() == curStart.Parent() {
hasNonNestedReturn = true
}
Expand Down Expand Up @@ -1091,7 +1095,7 @@ func extractFunctionMethod(cpkg *cache.Package, pgf *parsego.File, start, end to
declarations = initializeVars(uninitialized, retVars, seenUninitialized, seenVars)
}

var declBuf, replaceBuf, newFuncBuf, ifBuf, commentBuf bytes.Buffer
var declBuf, replaceBuf, newFuncBuf, ifBuf bytes.Buffer
if err := format.Node(&declBuf, fset, declarations); err != nil {
return nil, nil, err
}
Expand All @@ -1112,6 +1116,15 @@ func extractFunctionMethod(cpkg *cache.Package, pgf *parsego.File, start, end to
}
}

newFuncResults := &ast.FieldList{List: append(returnTypes, getDecls(retVars)...)}

// Expand multi-value function calls in return statements.
// If a return contains a single CallExpr that is being augmented with new
// return values, the call return values must be expanded to maintain valid syntax.
if err := expandMultiValueCallReturns(extractedBlock, info, newFuncResults, file, start, origRetStmts); err != nil {
return nil, nil, err
}

// Build the extracted function. We format the function declaration and body
// separately, so that comments are printed relative to the extracted
// BlockStmt.
Expand All @@ -1125,7 +1138,7 @@ func extractFunctionMethod(cpkg *cache.Package, pgf *parsego.File, start, end to
Name: ast.NewIdent(funName),
Type: &ast.FuncType{
Params: &ast.FieldList{List: paramTypes},
Results: &ast.FieldList{List: append(returnTypes, getDecls(retVars)...)},
Results: newFuncResults,
},
// Body handled separately -- see above.
}
Expand Down Expand Up @@ -1172,10 +1185,6 @@ func extractFunctionMethod(cpkg *cache.Package, pgf *parsego.File, start, end to

var fullReplacement strings.Builder
fullReplacement.Write(before)
if commentBuf.Len() > 0 {
comments := strings.ReplaceAll(commentBuf.String(), "\n", newLineIndent)
fullReplacement.WriteString(comments)
}
if declBuf.Len() > 0 { // add any initializations, if needed
initializations := strings.ReplaceAll(declBuf.String(), "\n", newLineIndent) +
newLineIndent
Expand Down Expand Up @@ -1216,6 +1225,126 @@ func extractFunctionMethod(cpkg *cache.Package, pgf *parsego.File, start, end to
}, nil
}

// expandMultiValueCallReturns expands multi-value function calls in return
// statements within the extracted block.
func expandMultiValueCallReturns(extractedBlock *ast.BlockStmt, info *types.Info, newFuncResults *ast.FieldList, file *ast.File, start token.Pos, origRetStmts []*ast.ReturnStmt) error {
// The re-parsed AST has no type information, so we pair its return stmts
// with the original (type-checked) ones to look up types for naming.
//
// The pairing is done as a separate pass because the second pass doesn't
// exactly visit the ReturnStmt in the same way as how the origRetStmts
// is collected (via ast.Inspect).
origRetMap := map[*ast.ReturnStmt]*ast.ReturnStmt{}
origIdx := 0
ast.Inspect(extractedBlock, func(n ast.Node) bool {
switch n := n.(type) {
case *ast.ReturnStmt:
if origIdx < len(origRetStmts) {
origRetMap[n] = origRetStmts[origIdx]
origIdx++
} else {
// The re-parsed AST may have injected returns appended
// at the end with no original counterpart but it is ok since
// we can guarantee it will not have CallExpr in it.
return false
}
case *ast.FuncLit:
return false // don't descend into closures.
}

return true
})

// Traverse the extracted block again and do the actual expansion.
var expandErr error
ast.Inspect(extractedBlock, func(n ast.Node) bool {
if expandErr != nil {
return false
}
switch n := n.(type) {
case *ast.BlockStmt:
n.List, expandErr = expandFunctionCallReturnValues(n.List, info, newFuncResults, file, start, origRetMap)
case *ast.CaseClause:
n.Body, expandErr = expandFunctionCallReturnValues(n.Body, info, newFuncResults, file, start, origRetMap)
case *ast.FuncLit:
return false // don't descend into closures.
}

return true
})
return expandErr
}

// expandFunctionCallReturnValues expands the return value of function calls
// in the given statement list when necessary.
func expandFunctionCallReturnValues(stmts []ast.Stmt, info *types.Info, newFuncResults *ast.FieldList, file *ast.File, start token.Pos, origRetMap map[*ast.ReturnStmt]*ast.ReturnStmt) ([]ast.Stmt, error) {
result := make([]ast.Stmt, 0, len(stmts))
for _, stmt := range stmts {
result = append(result, stmt)

// When we have multiple return statement results, we can't have a CallExpr in it.
// In that case, we need to splat the values of that CallExpr into variable(s)
// and return them.
retStmt, ok := stmt.(*ast.ReturnStmt)
if !ok || len(retStmt.Results) <= 1 {
continue
}

// We can only have CallExpr in the first return statement result with the assumption that
// the original code is valid.
callExpr, ok := retStmt.Results[0].(*ast.CallExpr)
if !ok {
continue
}

// Infer the number of function's return value using the enclosing function
// signature and the original return statement results because we don't have
// type information here. This should be correct assuming the original code
// is valid to begin with.
expandedVars := make([]ast.Expr, len(newFuncResults.List)-len(retStmt.Results)+1) // plus one to replace the CallExpr

// Use type information from the original return statement to
// generate type-aware names and detect scope collisions.
origRet := origRetMap[retStmt]
if origRet == nil {
return nil, bug.Errorf("no original return statement for re-parsed return")
}

scopePos := origRet.Pos()
origCallExpr := origRet.Results[0].(*ast.CallExpr)
sig := info.TypeOf(origCallExpr.Fun).Underlying().(*types.Signature)
tup := sig.Results()

// Generate type-aware names for each expanded return values.
prevIdxByPrefix := map[string]int{}
for i := range expandedVars {
prefix := "v"
if name, ok := varNameForType(tup.At(i).Type()); ok {
prefix = name
}

prev := prevIdxByPrefix[prefix]
name, next := freshName(info, file, scopePos, prefix, prev)
prevIdxByPrefix[prefix] = next

expandedVars[i] = ast.NewIdent(name)
}

result[len(result)-1] = ast.Stmt(&ast.AssignStmt{
Lhs: expandedVars,
Tok: token.DEFINE,
Rhs: []ast.Expr{callExpr},
TokPos: 0,
})
result = append(result, &ast.ReturnStmt{
Return: retStmt.Return,
Results: slices.Concat(expandedVars, retStmt.Results[1:]),
})
}

return result, nil
}

// isSelector reports if e is the selector expr <x>, <sel>. It works for pointer and non-pointer selector expressions.
func isSelector(e ast.Expr, x, sel string) bool {
unary, ok := e.(*ast.UnaryExpr)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
This test verifies the fix for golang/go#77240: type-aware variable naming
when expanding multi-value function call returns during extraction.

-- go.mod --
module mod.test/extract

go 1.18

-- p1/p.go --
package extract

func Fun(v2 int) (int, int, error) {
switch v2 { //@codeaction("switch", "refactor.extract.function", end=end, result=ext)
case 1: // also a comment!
return doOne() // a comment!
case 2:
return doTwo()
} //@loc(end, "}")

return 1, 3, nil
}

func doOne() (int, int, error) {
return 0, 1, nil
}

func doTwo() (int, int, error) {
return 0, 2, nil
}

-- @ext/p1/p.go --
package extract

func Fun(v2 int) (int, int, error) {
i, i1, err, shouldReturn := newFunction(v2)
if shouldReturn {
return i, i1, err
} //@loc(end, "}")

return 1, 3, nil
}

func newFunction(v2 int) (int, int, error, bool) {
switch v2 { //@codeaction("switch", "refactor.extract.function", end=end, result=ext)
case 1:
i, // also a comment!
i1, err := doOne()
return i, i1, err, true // a comment!
case 2:
i, i1, err := doTwo()
return i, i1, err, true
}
return 0, 0, nil, false
}

func doOne() (int, int, error) {
return 0, 1, nil
}

func doTwo() (int, int, error) {
return 0, 2, nil
}

-- p2/p.go --
package extract

import "fmt"

func Fun(v2 int) (int, int, error) {
switch v2 { //@codeaction("switch", "refactor.extract.function", end=end2, result=ext2)
case 1:
i := v2 + 1
i1 := v2 + 2
err := fmt.Errorf("foo")
fmt.Println(i, i1, err)
return doOne()
case 2:
return doTwo()
} //@loc(end2, "}")

return 1, 3, nil
}

func doOne() (int, int, error) { return 0, 1, nil }
func doTwo() (int, int, error) { return 0, 2, nil }

-- @ext2/p2/p.go --
package extract

import "fmt"

func Fun(v2 int) (int, int, error) {
i, i1, err, shouldReturn := newFunction(v2)
if shouldReturn {
return i, i1, err
} //@loc(end2, "}")

return 1, 3, nil
}

func newFunction(v2 int) (int, int, error, bool) {
switch v2 { //@codeaction("switch", "refactor.extract.function", end=end2, result=ext2)
case 1:
i := v2 + 1
i1 := v2 + 2
err := fmt.Errorf("foo")
fmt.Println(i, i1, err)
i2, i3, err1 := doOne()
return i2, i3, err1, true
case 2:
i, i1, err := doTwo()
return i, i1, err, true
}
return 0, 0, nil, false
}

func doOne() (int, int, error) { return 0, 1, nil }
func doTwo() (int, int, error) { return 0, 2, nil }