Skip to content

Commit 82912d2

Browse files
committed
internal/imports: fix starvation in external candidate search
CL 589975 triggered a latent starvation bug when the findImports operation of internal/imports was cancelled. This CL first reproduces the starvation conditions by refactoring to isolate the algorithm in a unit test. Then the missing select statement is added to fix the bug. addExternalCandidates is also simplified somewhat using x/sync/errgroup. Many thanks to [email protected] for finding the root cause of this starvation. Fixes golang/go#67923 Change-Id: Ib0a12a9a667af84150d84c3e988e460c9ae1d973 Reviewed-on: https://go-review.googlesource.com/c/tools/+/591756 Reviewed-by: Hyang-Ah Hana Kim <[email protected]> LUCI-TryBot-Result: Go LUCI <[email protected]>
1 parent 3e94830 commit 82912d2

File tree

2 files changed

+148
-72
lines changed

2 files changed

+148
-72
lines changed

internal/imports/fix.go

+103-72
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ import (
2727
"unicode"
2828
"unicode/utf8"
2929

30+
"golang.org/x/sync/errgroup"
3031
"golang.org/x/tools/go/ast/astutil"
3132
"golang.org/x/tools/internal/event"
3233
"golang.org/x/tools/internal/gocommand"
@@ -1140,8 +1141,8 @@ type Resolver interface {
11401141
// scan works with callback to search for packages. See scanCallback for details.
11411142
scan(ctx context.Context, callback *scanCallback) error
11421143

1143-
// loadExports returns the set of exported symbols in the package at dir.
1144-
// loadExports may be called concurrently.
1144+
// loadExports returns the package name and set of exported symbols in the
1145+
// package at dir. loadExports may be called concurrently.
11451146
loadExports(ctx context.Context, pkg *pkg, includeTest bool) (string, []stdlib.Symbol, error)
11461147

11471148
// scoreImportPath returns the relevance for an import path.
@@ -1218,54 +1219,52 @@ func addExternalCandidates(ctx context.Context, pass *pass, refs references, fil
12181219
imp *ImportInfo
12191220
pkg *packageInfo
12201221
}
1221-
results := make(chan result, len(refs))
1222+
results := make([]*result, len(refs))
12221223

1223-
ctx, cancel := context.WithCancel(ctx)
1224-
var wg sync.WaitGroup
1225-
defer func() {
1226-
cancel()
1227-
wg.Wait()
1228-
}()
1229-
var (
1230-
firstErr error
1231-
firstErrOnce sync.Once
1232-
)
1233-
for pkgName, symbols := range refs {
1234-
wg.Add(1)
1235-
go func(pkgName string, symbols map[string]bool) {
1236-
defer wg.Done()
1224+
g, ctx := errgroup.WithContext(ctx)
12371225

1238-
found, err := findImport(ctx, pass, found[pkgName], pkgName, symbols)
1226+
searcher := symbolSearcher{
1227+
logf: pass.env.logf,
1228+
srcDir: pass.srcDir,
1229+
xtest: strings.HasSuffix(pass.f.Name.Name, "_test"),
1230+
loadExports: resolver.loadExports,
1231+
}
1232+
1233+
i := 0
1234+
for pkgName, symbols := range refs {
1235+
index := i // claim an index in results
1236+
i++
1237+
pkgName := pkgName
1238+
symbols := symbols
12391239

1240+
g.Go(func() error {
1241+
found, err := searcher.search(ctx, found[pkgName], pkgName, symbols)
12401242
if err != nil {
1241-
firstErrOnce.Do(func() {
1242-
firstErr = err
1243-
cancel()
1244-
})
1245-
return
1243+
return err
12461244
}
1247-
12481245
if found == nil {
1249-
return // No matching package.
1246+
return nil // No matching package.
12501247
}
12511248

12521249
imp := &ImportInfo{
12531250
ImportPath: found.importPathShort,
12541251
}
1255-
12561252
pkg := &packageInfo{
12571253
name: pkgName,
12581254
exports: symbols,
12591255
}
1260-
results <- result{imp, pkg}
1261-
}(pkgName, symbols)
1256+
results[index] = &result{imp, pkg}
1257+
return nil
1258+
})
1259+
}
1260+
if err := g.Wait(); err != nil {
1261+
return err
12621262
}
1263-
go func() {
1264-
wg.Wait()
1265-
close(results)
1266-
}()
12671263

1268-
for result := range results {
1264+
for _, result := range results {
1265+
if result == nil {
1266+
continue
1267+
}
12691268
// Don't offer completions that would shadow predeclared
12701269
// names, such as github.com/coreos/etcd/error.
12711270
if types.Universe.Lookup(result.pkg.name) != nil { // predeclared
@@ -1279,7 +1278,7 @@ func addExternalCandidates(ctx context.Context, pass *pass, refs references, fil
12791278
}
12801279
pass.addCandidate(result.imp, result.pkg)
12811280
}
1282-
return firstErr
1281+
return nil
12831282
}
12841283

12851284
// notIdentifier reports whether ch is an invalid identifier character.
@@ -1669,39 +1668,55 @@ func sortSymbols(syms []stdlib.Symbol) {
16691668
})
16701669
}
16711670

1672-
// findImport searches for a package with the given symbols.
1673-
// If no package is found, findImport returns ("", false, nil)
1674-
func findImport(ctx context.Context, pass *pass, candidates []pkgDistance, pkgName string, symbols map[string]bool) (*pkg, error) {
1671+
// A symbolSearcher searches for a package with a set of symbols, among a set
1672+
// of candidates. See [symbolSearcher.search].
1673+
//
1674+
// The search occurs within the scope of a single file, with context captured
1675+
// in srcDir and xtest.
1676+
type symbolSearcher struct {
1677+
logf func(string, ...any)
1678+
srcDir string // directory containing the file
1679+
xtest bool // if set, the file containing is an x_test file
1680+
loadExports func(ctx context.Context, pkg *pkg, includeTest bool) (string, []stdlib.Symbol, error)
1681+
}
1682+
1683+
// search searches the provided candidates for a package containing all
1684+
// exported symbols.
1685+
//
1686+
// If successful, returns the resulting package.
1687+
func (s *symbolSearcher) search(ctx context.Context, candidates []pkgDistance, pkgName string, symbols map[string]bool) (*pkg, error) {
16751688
// Sort the candidates by their import package length,
16761689
// assuming that shorter package names are better than long
16771690
// ones. Note that this sorts by the de-vendored name, so
16781691
// there's no "penalty" for vendoring.
16791692
sort.Sort(byDistanceOrImportPathShortLength(candidates))
1680-
if pass.env.Logf != nil {
1693+
if s.logf != nil {
16811694
for i, c := range candidates {
1682-
pass.env.Logf("%s candidate %d/%d: %v in %v", pkgName, i+1, len(candidates), c.pkg.importPathShort, c.pkg.dir)
1695+
s.logf("%s candidate %d/%d: %v in %v", pkgName, i+1, len(candidates), c.pkg.importPathShort, c.pkg.dir)
16831696
}
16841697
}
1685-
resolver, err := pass.env.GetResolver()
1686-
if err != nil {
1687-
return nil, err
1688-
}
16891698

1690-
// Collect exports for packages with matching names.
1699+
// Arrange rescv so that we can we can await results in order of relevance
1700+
// and exit as soon as we find the first match.
1701+
//
1702+
// Search with bounded concurrency, returning as soon as the first result
1703+
// among rescv is non-nil.
16911704
rescv := make([]chan *pkg, len(candidates))
16921705
for i := range candidates {
16931706
rescv[i] = make(chan *pkg, 1)
16941707
}
16951708
const maxConcurrentPackageImport = 4
16961709
loadExportsSem := make(chan struct{}, maxConcurrentPackageImport)
16971710

1711+
// Ensure that all work is completed at exit.
16981712
ctx, cancel := context.WithCancel(ctx)
16991713
var wg sync.WaitGroup
17001714
defer func() {
17011715
cancel()
17021716
wg.Wait()
17031717
}()
17041718

1719+
// Start the search.
17051720
wg.Add(1)
17061721
go func() {
17071722
defer wg.Done()
@@ -1712,51 +1727,67 @@ func findImport(ctx context.Context, pass *pass, candidates []pkgDistance, pkgNa
17121727
return
17131728
}
17141729

1730+
i := i
1731+
c := c
17151732
wg.Add(1)
1716-
go func(c pkgDistance, resc chan<- *pkg) {
1733+
go func() {
17171734
defer func() {
17181735
<-loadExportsSem
17191736
wg.Done()
17201737
}()
1721-
1722-
pass.env.logf("loading exports in dir %s (seeking package %s)", c.pkg.dir, pkgName)
1723-
// If we're an x_test, load the package under test's test variant.
1724-
includeTest := strings.HasSuffix(pass.f.Name.Name, "_test") && c.pkg.dir == pass.srcDir
1725-
_, exports, err := resolver.loadExports(ctx, c.pkg, includeTest)
1726-
if err != nil {
1727-
pass.env.logf("loading exports in dir %s (seeking package %s): %v", c.pkg.dir, pkgName, err)
1728-
resc <- nil
1729-
return
1738+
if s.logf != nil {
1739+
s.logf("loading exports in dir %s (seeking package %s)", c.pkg.dir, pkgName)
17301740
}
1731-
1732-
exportsMap := make(map[string]bool, len(exports))
1733-
for _, sym := range exports {
1734-
exportsMap[sym.Name] = true
1735-
}
1736-
1737-
// If it doesn't have the right
1738-
// symbols, send nil to mean no match.
1739-
for symbol := range symbols {
1740-
if !exportsMap[symbol] {
1741-
resc <- nil
1742-
return
1741+
pkg, err := s.searchOne(ctx, c, symbols)
1742+
if err != nil {
1743+
if s.logf != nil && ctx.Err() == nil {
1744+
s.logf("loading exports in dir %s (seeking package %s): %v", c.pkg.dir, pkgName, err)
17431745
}
1746+
pkg = nil
17441747
}
1745-
resc <- c.pkg
1746-
}(c, rescv[i])
1748+
rescv[i] <- pkg // may be nil
1749+
}()
17471750
}
17481751
}()
17491752

1753+
// Await the first (best) result.
17501754
for _, resc := range rescv {
1751-
pkg := <-resc
1752-
if pkg == nil {
1753-
continue
1755+
select {
1756+
case r := <-resc:
1757+
if r != nil {
1758+
return r, nil
1759+
}
1760+
case <-ctx.Done():
1761+
return nil, ctx.Err()
17541762
}
1755-
return pkg, nil
17561763
}
17571764
return nil, nil
17581765
}
17591766

1767+
func (s *symbolSearcher) searchOne(ctx context.Context, c pkgDistance, symbols map[string]bool) (*pkg, error) {
1768+
if ctx.Err() != nil {
1769+
return nil, ctx.Err()
1770+
}
1771+
// If we're considering the package under test from an x_test, load the
1772+
// test variant.
1773+
includeTest := s.xtest && c.pkg.dir == s.srcDir
1774+
_, exports, err := s.loadExports(ctx, c.pkg, includeTest)
1775+
if err != nil {
1776+
return nil, err
1777+
}
1778+
1779+
exportsMap := make(map[string]bool, len(exports))
1780+
for _, sym := range exports {
1781+
exportsMap[sym.Name] = true
1782+
}
1783+
for symbol := range symbols {
1784+
if !exportsMap[symbol] {
1785+
return nil, nil // no match
1786+
}
1787+
}
1788+
return c.pkg, nil
1789+
}
1790+
17601791
// pkgIsCandidate reports whether pkg is a candidate for satisfying the
17611792
// finding which package pkgIdent in the file named by filename is trying
17621793
// to refer to.

internal/imports/fix_test.go

+45
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,13 @@ import (
1111
"go/build"
1212
"log"
1313
"os"
14+
"path"
1415
"path/filepath"
1516
"reflect"
1617
"sort"
1718
"strings"
1819
"sync"
20+
"sync/atomic"
1921
"testing"
2022

2123
"golang.org/x/tools/go/packages/packagestest"
@@ -2956,6 +2958,49 @@ var _, _ = fmt.Sprintf, dot.Dot
29562958
}.processTest(t, "golang.org/fake", "x.go", nil, nil, want)
29572959
}
29582960

2961+
func TestSymbolSearchStarvation(t *testing.T) {
2962+
// This test verifies the fix for golang/go#67923: searching through
2963+
// candidates should not starve when the context is cancelled.
2964+
//
2965+
// To reproduce the conditions that led to starvation, cancel the context
2966+
// half way through the search, by leveraging the loadExports callback.
2967+
const candCount = 100
2968+
var loaded atomic.Int32
2969+
ctx, cancel := context.WithCancel(context.Background())
2970+
searcher := symbolSearcher{
2971+
logf: t.Logf,
2972+
srcDir: "/path/to/foo",
2973+
loadExports: func(ctx context.Context, pkg *pkg, includeTest bool) (string, []stdlib.Symbol, error) {
2974+
if loaded.Add(1) > candCount/2 {
2975+
cancel()
2976+
}
2977+
return "bar", []stdlib.Symbol{
2978+
{Name: "A", Kind: stdlib.Var},
2979+
{Name: "B", Kind: stdlib.Var},
2980+
// Missing: "C", so that none of these packages match.
2981+
}, nil
2982+
},
2983+
}
2984+
2985+
var candidates []pkgDistance
2986+
for i := 0; i < candCount; i++ {
2987+
name := fmt.Sprintf("bar%d", i)
2988+
candidates = append(candidates, pkgDistance{
2989+
pkg: &pkg{
2990+
dir: path.Join(searcher.srcDir, name),
2991+
importPathShort: "foo/" + name,
2992+
packageName: name,
2993+
relevance: 1,
2994+
},
2995+
distance: 1,
2996+
})
2997+
}
2998+
2999+
// We don't actually care what happens, as long as it doesn't deadlock!
3000+
_, err := searcher.search(ctx, candidates, "bar", map[string]bool{"A": true, "B": true, "C": true})
3001+
t.Logf("search completed with err: %v", err)
3002+
}
3003+
29593004
func TestMatchesPath(t *testing.T) {
29603005
tests := []struct {
29613006
ident string

0 commit comments

Comments
 (0)