@@ -1091,7 +1091,7 @@ func extractFunctionMethod(cpkg *cache.Package, pgf *parsego.File, start, end to
10911091 declarations = initializeVars (uninitialized , retVars , seenUninitialized , seenVars )
10921092 }
10931093
1094- var declBuf , replaceBuf , newFuncBuf , ifBuf , commentBuf bytes.Buffer
1094+ var declBuf , replaceBuf , newFuncBuf , ifBuf bytes.Buffer
10951095 if err := format .Node (& declBuf , fset , declarations ); err != nil {
10961096 return nil , nil , err
10971097 }
@@ -1112,6 +1112,25 @@ func extractFunctionMethod(cpkg *cache.Package, pgf *parsego.File, start, end to
11121112 }
11131113 }
11141114
1115+ newFuncResults := & ast.FieldList {List : append (returnTypes , getDecls (retVars )... )}
1116+
1117+ // Expand multi-value function calls in return statements.
1118+ // If a return contains a single CallExpr that is being augmented with new
1119+ // return values, the call return values must be expanded to maintain valid syntax.
1120+ ast .Inspect (extractedBlock , func (n ast.Node ) bool {
1121+ switch n := n .(type ) {
1122+ case * ast.BlockStmt :
1123+ n .List = expandFunctionCallReturnValues (n .List , info , newFuncResults , file , start )
1124+ case * ast.CaseClause :
1125+ n .Body = expandFunctionCallReturnValues (n .Body , info , newFuncResults , file , start )
1126+ case * ast.FuncLit :
1127+ // Don't descend into nested functions.
1128+ return false
1129+ }
1130+
1131+ return true
1132+ })
1133+
11151134 // Build the extracted function. We format the function declaration and body
11161135 // separately, so that comments are printed relative to the extracted
11171136 // BlockStmt.
@@ -1125,7 +1144,7 @@ func extractFunctionMethod(cpkg *cache.Package, pgf *parsego.File, start, end to
11251144 Name : ast .NewIdent (funName ),
11261145 Type : & ast.FuncType {
11271146 Params : & ast.FieldList {List : paramTypes },
1128- Results : & ast. FieldList { List : append ( returnTypes , getDecls ( retVars ) ... )} ,
1147+ Results : newFuncResults ,
11291148 },
11301149 // Body handled separately -- see above.
11311150 }
@@ -1172,10 +1191,6 @@ func extractFunctionMethod(cpkg *cache.Package, pgf *parsego.File, start, end to
11721191
11731192 var fullReplacement strings.Builder
11741193 fullReplacement .Write (before )
1175- if commentBuf .Len () > 0 {
1176- comments := strings .ReplaceAll (commentBuf .String (), "\n " , newLineIndent )
1177- fullReplacement .WriteString (comments )
1178- }
11791194 if declBuf .Len () > 0 { // add any initializations, if needed
11801195 initializations := strings .ReplaceAll (declBuf .String (), "\n " , newLineIndent ) +
11811196 newLineIndent
@@ -1216,6 +1231,60 @@ func extractFunctionMethod(cpkg *cache.Package, pgf *parsego.File, start, end to
12161231 }, nil
12171232}
12181233
1234+ // expandFunctionCallReturnValues expands the return value of function calls when necessary.
1235+ func expandFunctionCallReturnValues (stmts []ast.Stmt , info * types.Info , newFuncResults * ast.FieldList , file * ast.File , start token.Pos ) []ast.Stmt {
1236+ result := make ([]ast.Stmt , 0 , len (stmts ))
1237+ for _ , stmt := range stmts {
1238+ result = append (result , stmt )
1239+
1240+ retStmt , ok := stmt .(* ast.ReturnStmt )
1241+ if ! ok {
1242+ continue
1243+ }
1244+
1245+ // when we have multiple return statement results, we can't have a CallExpr in it.
1246+ // in that case, we need to splat the values of that CallExpr into variable(s)
1247+ // and return them.
1248+ if len (retStmt .Results ) <= 1 {
1249+ continue
1250+ }
1251+
1252+ // We can only have CallExpr in the first return statement result with the assumption that
1253+ // the original code is valid.
1254+ callExpr , ok := retStmt .Results [0 ].(* ast.CallExpr )
1255+ if ! ok {
1256+ continue
1257+ }
1258+
1259+ // Infer the number of function's return value using the enclosing function
1260+ // signature and the original return statement results because we don't have
1261+ // type information here. This should be correct assuming the original code
1262+ // is valid to begin with.
1263+ expandedVars := make ([]ast.Expr , len (newFuncResults .List )- len (retStmt .Results )+ 1 ) // plus one to replace the CallExpr
1264+ prevIdx := 1
1265+ for i := range expandedVars {
1266+ // ideally we want to generate a better name (e.g. `errX` for error values)
1267+ // but we don't have type info at this stage.
1268+ name , idx := freshName (info , file , start , "v" , prevIdx )
1269+ expandedVars [i ] = ast .NewIdent (name )
1270+ prevIdx = idx
1271+ }
1272+
1273+ result [len (result )- 1 ] = ast .Stmt (& ast.AssignStmt {
1274+ Lhs : expandedVars ,
1275+ Tok : token .DEFINE ,
1276+ Rhs : []ast.Expr {callExpr },
1277+ TokPos : 0 ,
1278+ })
1279+ result = append (result , & ast.ReturnStmt {
1280+ Return : retStmt .Return ,
1281+ Results : slices .Concat (expandedVars , retStmt .Results [1 :]),
1282+ })
1283+ }
1284+
1285+ return result
1286+ }
1287+
12191288// isSelector reports if e is the selector expr <x>, <sel>. It works for pointer and non-pointer selector expressions.
12201289func isSelector (e ast.Expr , x , sel string ) bool {
12211290 unary , ok := e .(* ast.UnaryExpr )
0 commit comments