Skip to content

Commit eb38a3e

Browse files
committed
gopls/internal: add code action "extract declarations to new file"
This code action moves selected code sections to a newly created file within the same package. The created filename is chosen as the first {function, type, const, var} name encountered. In addition, import declarations are added or removed as needed. Fixes golang/go#65707
1 parent 850c7c3 commit eb38a3e

File tree

6 files changed

+383
-4
lines changed

6 files changed

+383
-4
lines changed

gopls/internal/golang/codeaction.go

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -240,10 +240,6 @@ func fixedByImportFix(fix *imports.ImportFix, diagnostics []protocol.Diagnostic)
240240

241241
// getExtractCodeActions returns any refactor.extract code actions for the selection.
242242
func getExtractCodeActions(pgf *parsego.File, rng protocol.Range, options *settings.Options) ([]protocol.CodeAction, error) {
243-
if rng.Start == rng.End {
244-
return nil, nil
245-
}
246-
247243
start, end, err := pgf.RangePos(rng)
248244
if err != nil {
249245
return nil, err
@@ -286,6 +282,16 @@ func getExtractCodeActions(pgf *parsego.File, rng protocol.Range, options *setti
286282
}
287283
commands = append(commands, cmd)
288284
}
285+
if canExtractToNewFile(pgf, start, end) {
286+
cmd, err := command.NewExtractToNewFileCommand(
287+
"Extract declarations to new file",
288+
command.ExtractToNewFileArgs{URI: pgf.URI, Range: rng},
289+
)
290+
if err != nil {
291+
return nil, err
292+
}
293+
commands = append(commands, cmd)
294+
}
289295
var actions []protocol.CodeAction
290296
for i := range commands {
291297
actions = append(actions, newCodeAction(commands[i].Title, protocol.RefactorExtract, &commands[i], nil, options))
Lines changed: 316 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,316 @@
1+
// Copyright 2024 The Go Authors. All rights reserved.
2+
// Use of this source code is governed by a BSD-style
3+
// license that can be found in the LICENSE file.
4+
5+
package golang
6+
7+
// This file defines the code action "Extract declarations to new file".
8+
9+
import (
10+
"context"
11+
"errors"
12+
"fmt"
13+
"go/ast"
14+
"go/format"
15+
"go/token"
16+
"go/types"
17+
"os"
18+
"path/filepath"
19+
"strings"
20+
21+
"golang.org/x/tools/gopls/internal/cache"
22+
"golang.org/x/tools/gopls/internal/cache/parsego"
23+
"golang.org/x/tools/gopls/internal/file"
24+
"golang.org/x/tools/gopls/internal/protocol"
25+
"golang.org/x/tools/gopls/internal/util/bug"
26+
"golang.org/x/tools/gopls/internal/util/typesutil"
27+
)
28+
29+
// canExtractToNewFile reports whether the code in the given range can be extracted to a new file.
30+
func canExtractToNewFile(pgf *parsego.File, start, end token.Pos) bool {
31+
_, _, _, ok := selectedToplevelDecls(pgf, start, end)
32+
return ok
33+
}
34+
35+
// findImportEdits finds imports specs that needs to be added to the new file
36+
// or deleted from the old file if the range is extracted to a new file.
37+
//
38+
// TODO: handle dot imports
39+
func findImportEdits(file *ast.File, info *types.Info, start, end token.Pos) (adds []*ast.ImportSpec, deletes []*ast.ImportSpec) {
40+
// make a map from a pkgName to its references
41+
pkgNameReferences := make(map[*types.PkgName][]*ast.Ident)
42+
for ident, use := range info.Uses {
43+
if pkgName, ok := use.(*types.PkgName); ok {
44+
pkgNameReferences[pkgName] = append(pkgNameReferences[pkgName], ident)
45+
}
46+
}
47+
48+
// PkgName referenced in the extracted selection must be
49+
// imported in the new file.
50+
// PkgName only refereced in the extracted selection must be
51+
// deleted from the original file.
52+
for _, spec := range file.Imports {
53+
pkgName, ok := typesutil.ImportedPkgName(info, spec)
54+
if !ok {
55+
continue
56+
}
57+
usedInSelection := false
58+
usedInNonSelection := false
59+
for _, ident := range pkgNameReferences[pkgName] {
60+
if contain(start, end, ident.Pos(), ident.End()) {
61+
usedInSelection = true
62+
} else {
63+
usedInNonSelection = true
64+
}
65+
}
66+
if usedInSelection {
67+
adds = append(adds, spec)
68+
}
69+
if usedInSelection && !usedInNonSelection {
70+
deletes = append(deletes, spec)
71+
}
72+
}
73+
74+
return adds, deletes
75+
}
76+
77+
// ExtractToNewFile moves selected declarations into a new file.
78+
func ExtractToNewFile(
79+
ctx context.Context,
80+
snapshot *cache.Snapshot,
81+
fh file.Handle,
82+
rng protocol.Range,
83+
) (*protocol.WorkspaceEdit, error) {
84+
errorPrefix := "ExtractToNewFile"
85+
86+
pkg, pgf, err := NarrowestPackageForFile(ctx, snapshot, fh.URI())
87+
if err != nil {
88+
return nil, err
89+
}
90+
91+
start, end, err := pgf.RangePos(rng)
92+
if err != nil {
93+
return nil, fmt.Errorf("%s: %w", errorPrefix, err)
94+
}
95+
96+
start, end, filename, ok := selectedToplevelDecls(pgf, start, end)
97+
if !ok {
98+
return nil, bug.Errorf("precondition unmet")
99+
}
100+
101+
end = skipWhiteSpaces(pgf, end)
102+
103+
replaceRange, err := pgf.PosRange(start, end)
104+
if err != nil {
105+
return nil, bug.Errorf("findRangeAndFilename returned invalid range: %v", err)
106+
}
107+
108+
adds, deletes := findImportEdits(pgf.File, pkg.TypesInfo(), start, end)
109+
110+
var importDeletes []protocol.TextEdit
111+
// For unparenthesised declarations like `import "fmt"` we remove
112+
// the whole declaration because simply removing importSpec leaves
113+
// `import \n`, which does not compile.
114+
// For parenthesised declarations like `import ("fmt"\n "log")`
115+
// we only remove the ImportSpec, because removing the whole declaration
116+
// might remove other ImportsSpecs we don't want to touch.
117+
parenthesisFreeImports := findParenthesisFreeImports(pgf)
118+
for _, importSpec := range deletes {
119+
if decl := parenthesisFreeImports[importSpec]; decl != nil {
120+
importDeletes = append(importDeletes, removeNode(pgf, decl))
121+
} else {
122+
importDeletes = append(importDeletes, removeNode(pgf, importSpec))
123+
}
124+
}
125+
126+
importAdds := ""
127+
if len(adds) > 0 {
128+
importAdds += "import ("
129+
for _, importSpec := range adds {
130+
if importSpec.Name != nil {
131+
importAdds += importSpec.Name.Name + " " + importSpec.Path.Value + "\n"
132+
} else {
133+
importAdds += importSpec.Path.Value + "\n"
134+
}
135+
}
136+
importAdds += ")"
137+
}
138+
139+
newFileURI, err := resolveNewFileURI(ctx, snapshot, pgf.URI.Dir().Path(), filename)
140+
if err != nil {
141+
return nil, fmt.Errorf("%s: %w", errorPrefix, err)
142+
}
143+
144+
// TODO: attempt to duplicate the copyright header, if any.
145+
newFileContent, err := format.Source([]byte(
146+
"package " + pgf.File.Name.Name + "\n" +
147+
importAdds + "\n" +
148+
string(pgf.Src[start-pgf.File.FileStart:end-pgf.File.FileStart]),
149+
))
150+
if err != nil {
151+
return nil, err
152+
}
153+
154+
return protocol.NewWorkspaceEdit(
155+
// original file edits
156+
protocol.DocumentChangeEdit(fh, append(importDeletes, protocol.TextEdit{Range: replaceRange, NewText: ""})),
157+
protocol.DocumentChangeCreate(newFileURI),
158+
// created file edits
159+
protocol.DocumentChangeEdit(&uriVersion{uri: newFileURI, version: 0}, []protocol.TextEdit{
160+
{Range: protocol.Range{}, NewText: string(newFileContent)},
161+
})), nil
162+
}
163+
164+
// uriVersion implements protocol.fileHandle
165+
type uriVersion struct {
166+
uri protocol.DocumentURI
167+
version int32
168+
}
169+
170+
func (fh *uriVersion) URI() protocol.DocumentURI {
171+
return fh.uri
172+
}
173+
func (fh *uriVersion) Version() int32 {
174+
return fh.version
175+
}
176+
177+
// resolveNewFileURI checks that basename.go does not exists in dir, otherwise
178+
// select basename.{1,2,3,4,5}.go as filename.
179+
func resolveNewFileURI(ctx context.Context, snapshot *cache.Snapshot, dir string, basename string) (protocol.DocumentURI, error) {
180+
basename = strings.ToLower(basename)
181+
newPath := protocol.URIFromPath(filepath.Join(dir, basename+".go"))
182+
for count := 1; ; count++ {
183+
fh, err := snapshot.ReadFile(ctx, newPath)
184+
if err != nil {
185+
return "", nil
186+
}
187+
if _, err := fh.Content(); errors.Is(err, os.ErrNotExist) {
188+
break
189+
}
190+
if count >= 5 {
191+
return "", fmt.Errorf("resolveNewFileURI: exceeded retry limit")
192+
}
193+
filename := fmt.Sprintf("%s.%d.go", basename, count)
194+
newPath = protocol.URIFromPath(filepath.Join(dir, filename))
195+
}
196+
return newPath, nil
197+
}
198+
199+
// selectedToplevelDecls returns the lexical extent of the top-level
200+
// declarations enclosed by [start, end), along with the name of the
201+
// first declaration. The returned boolean reports whether the selection
202+
// should be offered code action.
203+
func selectedToplevelDecls(pgf *parsego.File, start, end token.Pos) (token.Pos, token.Pos, string, bool) {
204+
// selection cannot intersect a package declaration
205+
if intersect(start, end, pgf.File.Package, pgf.File.Name.End()) {
206+
return 0, 0, "", false
207+
}
208+
firstName := ""
209+
for _, decl := range pgf.File.Decls {
210+
if intersect(start, end, decl.Pos(), decl.End()) {
211+
var id *ast.Ident
212+
switch v := decl.(type) {
213+
case *ast.BadDecl:
214+
return 0, 0, "", false
215+
case *ast.FuncDecl:
216+
// if only selecting keyword "func" or function name, extend selection to the
217+
// whole function
218+
if contain(v.Pos(), v.Name.End(), start, end) {
219+
start, end = v.Pos(), v.End()
220+
}
221+
id = v.Name
222+
case *ast.GenDecl:
223+
// selection cannot intersect an import declaration
224+
if v.Tok == token.IMPORT {
225+
return 0, 0, "", false
226+
}
227+
// if only selecting keyword "type", "const", or "var", extend selection to the
228+
// whole declaration
229+
if v.Tok == token.TYPE && contain(v.Pos(), v.Pos()+4, start, end) ||
230+
v.Tok == token.CONST && contain(v.Pos(), v.Pos()+5, start, end) ||
231+
v.Tok == token.VAR && contain(v.Pos(), v.Pos()+3, start, end) {
232+
start, end = v.Pos(), v.End()
233+
}
234+
if len(v.Specs) > 0 {
235+
switch spec := v.Specs[0].(type) {
236+
case *ast.TypeSpec:
237+
id = spec.Name
238+
case *ast.ValueSpec:
239+
id = spec.Names[0]
240+
}
241+
}
242+
}
243+
// selection cannot partially intersect a node
244+
if !contain(start, end, decl.Pos(), decl.End()) {
245+
return 0, 0, "", false
246+
}
247+
if id != nil && firstName == "" {
248+
firstName = id.Name
249+
}
250+
// extends selection to docs comments
251+
var c *ast.CommentGroup
252+
switch decl := decl.(type) {
253+
case *ast.GenDecl:
254+
c = decl.Doc
255+
case *ast.FuncDecl:
256+
c = decl.Doc
257+
}
258+
if c != nil && c.Pos() < start {
259+
start = c.Pos()
260+
}
261+
}
262+
}
263+
for _, comment := range pgf.File.Comments {
264+
if intersect(start, end, comment.Pos(), comment.End()) {
265+
if !contain(start, end, comment.Pos(), comment.End()) {
266+
// selection cannot partially intersect a comment
267+
return 0, 0, "", false
268+
}
269+
}
270+
}
271+
if firstName == "" {
272+
return 0, 0, "", false
273+
}
274+
return start, end, firstName, true
275+
}
276+
277+
func skipWhiteSpaces(pgf *parsego.File, pos token.Pos) token.Pos {
278+
i := pos
279+
for ; i-pgf.File.FileStart < token.Pos(len(pgf.Src)); i++ {
280+
c := pgf.Src[i-pgf.File.FileStart]
281+
if !(c == ' ' || c == '\t' || c == '\n') {
282+
break
283+
}
284+
}
285+
return i
286+
}
287+
288+
func findParenthesisFreeImports(pgf *parsego.File) map[*ast.ImportSpec]*ast.GenDecl {
289+
decls := make(map[*ast.ImportSpec]*ast.GenDecl)
290+
for _, decl := range pgf.File.Decls {
291+
if g, ok := decl.(*ast.GenDecl); ok {
292+
if !g.Lparen.IsValid() && len(g.Specs) > 0 {
293+
if v, ok := g.Specs[0].(*ast.ImportSpec); ok {
294+
decls[v] = g
295+
}
296+
}
297+
}
298+
}
299+
return decls
300+
}
301+
302+
// removeNode returns a TextEdit that removes the node
303+
func removeNode(pgf *parsego.File, node ast.Node) protocol.TextEdit {
304+
rng, _ := pgf.PosRange(node.Pos(), node.End())
305+
return protocol.TextEdit{Range: rng, NewText: ""}
306+
}
307+
308+
// intersect checks if [a, b) and [c, d) intersect, assuming a <= b and c <= d
309+
func intersect(a, b, c, d token.Pos) bool {
310+
return !(b <= c || d <= a)
311+
}
312+
313+
// contain checks if [a, b) contains [c, d), assuming a <= b and c <= d
314+
func contain(a, b, c, d token.Pos) bool {
315+
return a <= c && d <= b
316+
}

gopls/internal/protocol/command/command_gen.go

Lines changed: 20 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)