Skip to content

Commit e50d3ec

Browse files
committed
expand multi-value returns in function extraction
When extracting a function, handle return statements containing multi-value function calls by expanding them into variable assignments. This ensures valid syntax when the extracted function adds additional return values. Fixes golang/go#77240
1 parent 613c127 commit e50d3ec

File tree

2 files changed

+138
-6
lines changed

2 files changed

+138
-6
lines changed

gopls/internal/golang/extract.go

Lines changed: 75 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1091,7 +1091,7 @@ func extractFunctionMethod(cpkg *cache.Package, pgf *parsego.File, start, end to
10911091
declarations = initializeVars(uninitialized, retVars, seenUninitialized, seenVars)
10921092
}
10931093

1094-
var declBuf, replaceBuf, newFuncBuf, ifBuf, commentBuf bytes.Buffer
1094+
var declBuf, replaceBuf, newFuncBuf, ifBuf bytes.Buffer
10951095
if err := format.Node(&declBuf, fset, declarations); err != nil {
10961096
return nil, nil, err
10971097
}
@@ -1112,6 +1112,25 @@ func extractFunctionMethod(cpkg *cache.Package, pgf *parsego.File, start, end to
11121112
}
11131113
}
11141114

1115+
newFuncResults := &ast.FieldList{List: append(returnTypes, getDecls(retVars)...)}
1116+
1117+
// Expand multi-value function calls in return statements.
1118+
// If a return contains a single CallExpr that is being augmented with new
1119+
// return values, the call return values must be expanded to maintain valid syntax.
1120+
ast.Inspect(extractedBlock, func(n ast.Node) bool {
1121+
switch n := n.(type) {
1122+
case *ast.BlockStmt:
1123+
n.List = expandFunctionCallReturnValues(n.List, info, newFuncResults, file, start)
1124+
case *ast.CaseClause:
1125+
n.Body = expandFunctionCallReturnValues(n.Body, info, newFuncResults, file, start)
1126+
case *ast.FuncLit:
1127+
// Don't descend into nested functions.
1128+
return false
1129+
}
1130+
1131+
return true
1132+
})
1133+
11151134
// Build the extracted function. We format the function declaration and body
11161135
// separately, so that comments are printed relative to the extracted
11171136
// BlockStmt.
@@ -1125,7 +1144,7 @@ func extractFunctionMethod(cpkg *cache.Package, pgf *parsego.File, start, end to
11251144
Name: ast.NewIdent(funName),
11261145
Type: &ast.FuncType{
11271146
Params: &ast.FieldList{List: paramTypes},
1128-
Results: &ast.FieldList{List: append(returnTypes, getDecls(retVars)...)},
1147+
Results: newFuncResults,
11291148
},
11301149
// Body handled separately -- see above.
11311150
}
@@ -1172,10 +1191,6 @@ func extractFunctionMethod(cpkg *cache.Package, pgf *parsego.File, start, end to
11721191

11731192
var fullReplacement strings.Builder
11741193
fullReplacement.Write(before)
1175-
if commentBuf.Len() > 0 {
1176-
comments := strings.ReplaceAll(commentBuf.String(), "\n", newLineIndent)
1177-
fullReplacement.WriteString(comments)
1178-
}
11791194
if declBuf.Len() > 0 { // add any initializations, if needed
11801195
initializations := strings.ReplaceAll(declBuf.String(), "\n", newLineIndent) +
11811196
newLineIndent
@@ -1216,6 +1231,60 @@ func extractFunctionMethod(cpkg *cache.Package, pgf *parsego.File, start, end to
12161231
}, nil
12171232
}
12181233

1234+
// expandFunctionCallReturnValues expands the return value of function calls when necessary.
1235+
func expandFunctionCallReturnValues(stmts []ast.Stmt, info *types.Info, newFuncResults *ast.FieldList, file *ast.File, start token.Pos) []ast.Stmt {
1236+
result := make([]ast.Stmt, 0, len(stmts))
1237+
for _, stmt := range stmts {
1238+
result = append(result, stmt)
1239+
1240+
retStmt, ok := stmt.(*ast.ReturnStmt)
1241+
if !ok {
1242+
continue
1243+
}
1244+
1245+
// when we have multiple return statement results, we can't have a CallExpr in it.
1246+
// in that case, we need to splat the values of that CallExpr into variable(s)
1247+
// and return them.
1248+
if len(retStmt.Results) <= 1 {
1249+
continue
1250+
}
1251+
1252+
// We can only have CallExpr in the first return statement result with the assumption that
1253+
// the original code is valid.
1254+
callExpr, ok := retStmt.Results[0].(*ast.CallExpr)
1255+
if !ok {
1256+
continue
1257+
}
1258+
1259+
// Infer the number of function's return value using the enclosing function
1260+
// signature and the original return statement results because we don't have
1261+
// type information here. This should be correct assuming the original code
1262+
// is valid to begin with.
1263+
expandedVars := make([]ast.Expr, len(newFuncResults.List)-len(retStmt.Results)+1) // plus one to replace the CallExpr
1264+
prevIdx := 1
1265+
for i := range expandedVars {
1266+
// ideally we want to generate a better name (e.g. `errX` for error values)
1267+
// but we don't have type info at this stage.
1268+
name, idx := freshName(info, file, start, "v", prevIdx)
1269+
expandedVars[i] = ast.NewIdent(name)
1270+
prevIdx = idx
1271+
}
1272+
1273+
result[len(result)-1] = ast.Stmt(&ast.AssignStmt{
1274+
Lhs: expandedVars,
1275+
Tok: token.DEFINE,
1276+
Rhs: []ast.Expr{callExpr},
1277+
TokPos: 0,
1278+
})
1279+
result = append(result, &ast.ReturnStmt{
1280+
Return: retStmt.Return,
1281+
Results: slices.Concat(expandedVars, retStmt.Results[1:]),
1282+
})
1283+
}
1284+
1285+
return result
1286+
}
1287+
12191288
// isSelector reports if e is the selector expr <x>, <sel>. It works for pointer and non-pointer selector expressions.
12201289
func isSelector(e ast.Expr, x, sel string) bool {
12211290
unary, ok := e.(*ast.UnaryExpr)
Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
This test verifies the fix for golang/go#44813: extraction failure when there
2+
are blank identifiers.
3+
4+
-- go.mod --
5+
module mod.test/extract
6+
7+
go 1.18
8+
9+
-- p1/p.go --
10+
package extract
11+
12+
func Fun(v2 int) (int, int, error) {
13+
switch v2 { //@codeaction("switch", "refactor.extract.function", end=end, result=ext)
14+
case 1: // also a comment!
15+
return doOne() // a comment!
16+
case 2:
17+
return doTwo()
18+
} //@loc(end, "}")
19+
20+
return 1, 3, nil
21+
}
22+
23+
func doOne() (int, int, error) {
24+
return 0, 1, nil
25+
}
26+
27+
func doTwo() (int, int, error) {
28+
return 0, 2, nil
29+
}
30+
31+
-- @ext/p1/p.go --
32+
package extract
33+
34+
func Fun(v2 int) (int, int, error) {
35+
i, i1, err, shouldReturn := newFunction(v2)
36+
if shouldReturn {
37+
return i, i1, err
38+
} //@loc(end, "}")
39+
40+
return 1, 3, nil
41+
}
42+
43+
func newFunction(v2 int) (int, int, error, bool) {
44+
switch v2 { //@codeaction("switch", "refactor.extract.function", end=end, result=ext)
45+
case 1:
46+
v1, // also a comment!
47+
v3, v4 := doOne()
48+
return v1, v3, v4, true // a comment!
49+
case 2:
50+
v1, v3, v4 := doTwo()
51+
return v1, v3, v4, true
52+
}
53+
return 0, 0, nil, false
54+
}
55+
56+
func doOne() (int, int, error) {
57+
return 0, 1, nil
58+
}
59+
60+
func doTwo() (int, int, error) {
61+
return 0, 2, nil
62+
}
63+

0 commit comments

Comments
 (0)