Skip to content
This repository was archived by the owner on Jun 27, 2023. It is now read-only.

Commit 7105dde

Browse files
authored
refactor mockgen and cleanup (#536)
1 parent 58935d8 commit 7105dde

File tree

9 files changed

+220
-298
lines changed

9 files changed

+220
-298
lines changed

gomock/call.go

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -50,16 +50,16 @@ func newCall(t TestHelper, receiver interface{}, method string, methodType refle
5050
t.Helper()
5151

5252
// TODO: check arity, types.
53-
margs := make([]Matcher, len(args))
53+
mArgs := make([]Matcher, len(args))
5454
for i, arg := range args {
5555
if m, ok := arg.(Matcher); ok {
56-
margs[i] = m
56+
mArgs[i] = m
5757
} else if arg == nil {
5858
// Handle nil specially so that passing a nil interface value
5959
// will match the typed nils of concrete args.
60-
margs[i] = Nil()
60+
mArgs[i] = Nil()
6161
} else {
62-
margs[i] = Eq(arg)
62+
mArgs[i] = Eq(arg)
6363
}
6464
}
6565

@@ -76,7 +76,7 @@ func newCall(t TestHelper, receiver interface{}, method string, methodType refle
7676
return rets
7777
}}
7878
return &Call{t: t, receiver: receiver, method: method, methodType: methodType,
79-
args: margs, origin: origin, minCalls: 1, maxCalls: 1, actions: actions}
79+
args: mArgs, origin: origin, minCalls: 1, maxCalls: 1, actions: actions}
8080
}
8181

8282
// AnyTimes allows the expectation to be called 0 or more times
@@ -113,19 +113,19 @@ func (c *Call) DoAndReturn(f interface{}) *Call {
113113
v := reflect.ValueOf(f)
114114

115115
c.addAction(func(args []interface{}) []interface{} {
116-
vargs := make([]reflect.Value, len(args))
116+
vArgs := make([]reflect.Value, len(args))
117117
ft := v.Type()
118118
for i := 0; i < len(args); i++ {
119119
if args[i] != nil {
120-
vargs[i] = reflect.ValueOf(args[i])
120+
vArgs[i] = reflect.ValueOf(args[i])
121121
} else {
122122
// Use the zero value for the arg.
123-
vargs[i] = reflect.Zero(ft.In(i))
123+
vArgs[i] = reflect.Zero(ft.In(i))
124124
}
125125
}
126-
vrets := v.Call(vargs)
127-
rets := make([]interface{}, len(vrets))
128-
for i, ret := range vrets {
126+
vRets := v.Call(vArgs)
127+
rets := make([]interface{}, len(vRets))
128+
for i, ret := range vRets {
129129
rets[i] = ret.Interface()
130130
}
131131
return rets
@@ -142,17 +142,17 @@ func (c *Call) Do(f interface{}) *Call {
142142
v := reflect.ValueOf(f)
143143

144144
c.addAction(func(args []interface{}) []interface{} {
145-
vargs := make([]reflect.Value, len(args))
145+
vArgs := make([]reflect.Value, len(args))
146146
ft := v.Type()
147147
for i := 0; i < len(args); i++ {
148148
if args[i] != nil {
149-
vargs[i] = reflect.ValueOf(args[i])
149+
vArgs[i] = reflect.ValueOf(args[i])
150150
} else {
151151
// Use the zero value for the arg.
152-
vargs[i] = reflect.Zero(ft.In(i))
152+
vArgs[i] = reflect.Zero(ft.In(i))
153153
}
154154
}
155-
v.Call(vargs)
155+
v.Call(vArgs)
156156
return nil
157157
})
158158
return c
@@ -353,12 +353,12 @@ func (c *Call) matches(args []interface{}) error {
353353
// matches all the remaining arguments or the lack of any.
354354
// Convert the remaining arguments, if any, into a slice of the
355355
// expected type.
356-
vargsType := c.methodType.In(c.methodType.NumIn() - 1)
357-
vargs := reflect.MakeSlice(vargsType, 0, len(args)-i)
356+
vArgsType := c.methodType.In(c.methodType.NumIn() - 1)
357+
vArgs := reflect.MakeSlice(vArgsType, 0, len(args)-i)
358358
for _, arg := range args[i:] {
359-
vargs = reflect.Append(vargs, reflect.ValueOf(arg))
359+
vArgs = reflect.Append(vArgs, reflect.ValueOf(arg))
360360
}
361-
if m.Matches(vargs.Interface()) {
361+
if m.Matches(vArgs.Interface()) {
362362
// Got Foo(a, b, c, d, e) want Foo(matcherA, matcherB, gomock.Any())
363363
// Got Foo(a, b, c, d, e) want Foo(matcherA, matcherB, someSliceMatcher)
364364
// Got Foo(a, b) want Foo(matcherA, matcherB, gomock.Any())

gomock/callset_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ func TestCallSetFindMatch(t *testing.T) {
8484

8585
c1 := newCall(t, receiver, method, reflect.TypeOf(receiverType{}.Func))
8686
cs.exhausted = map[callSetKey][]*Call{
87-
callSetKey{receiver: receiver, fname: method}: []*Call{c1},
87+
{receiver: receiver, fname: method}: {c1},
8888
}
8989

9090
_, err := cs.FindMatch(receiver, method, args)

gomock/matchers.go

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,6 @@ func (n notMatcher) Matches(x interface{}) bool {
153153
}
154154

155155
func (n notMatcher) String() string {
156-
// TODO: Improve this if we add a NotString method to the Matcher interface.
157156
return "not(" + n.m.String() + ")"
158157
}
159158

mockgen/mockgen.go

Lines changed: 44 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ import (
3838

3939
"github.com/golang/mock/mockgen/model"
4040

41+
"golang.org/x/mod/modfile"
4142
toolsimports "golang.org/x/tools/imports"
4243
)
4344

@@ -84,6 +85,7 @@ func main() {
8485
log.Fatal("Expected exactly two arguments")
8586
}
8687
packageName = flag.Arg(0)
88+
interfaces := strings.Split(flag.Arg(1), ",")
8789
if packageName == "." {
8890
dir, err := os.Getwd()
8991
if err != nil {
@@ -94,7 +96,7 @@ func main() {
9496
log.Fatalf("Parse package name failed: %v", err)
9597
}
9698
}
97-
pkg, err = reflectMode(packageName, strings.Split(flag.Arg(1), ","))
99+
pkg, err = reflectMode(packageName, interfaces)
98100
}
99101
if err != nil {
100102
log.Fatalf("Loading input failed: %v", err)
@@ -394,11 +396,6 @@ func (g *generator) GenerateMockInterface(intf *model.Interface, outputPackagePa
394396
g.p("}")
395397
g.p("")
396398

397-
// TODO: Re-enable this if we can import the interface reliably.
398-
// g.p("// Verify that the mock satisfies the interface at compile time.")
399-
// g.p("var _ %v = (*%v)(nil)", typeName, mockType)
400-
// g.p("")
401-
402399
g.p("// New%v creates a new mock instance.", mockType)
403400
g.p("func New%v(ctrl *gomock.Controller) *%v {", mockType, mockType)
404401
g.in()
@@ -665,3 +662,44 @@ func printVersion() {
665662
printModuleVersion()
666663
}
667664
}
665+
666+
// parseImportPackage get package import path via source file
667+
// an alternative implementation is to use:
668+
// cfg := &packages.Config{Mode: packages.NeedName, Tests: true, Dir: srcDir}
669+
// pkgs, err := packages.Load(cfg, "file="+source)
670+
// However, it will call "go list" and slow down the performance
671+
func parsePackageImport(srcDir string) (string, error) {
672+
moduleMode := os.Getenv("GO111MODULE")
673+
// trying to find the module
674+
if moduleMode != "off" {
675+
currentDir := srcDir
676+
for {
677+
dat, err := ioutil.ReadFile(filepath.Join(currentDir, "go.mod"))
678+
if os.IsNotExist(err) {
679+
if currentDir == filepath.Dir(currentDir) {
680+
// at the root
681+
break
682+
}
683+
currentDir = filepath.Dir(currentDir)
684+
continue
685+
} else if err != nil {
686+
return "", err
687+
}
688+
modulePath := modfile.ModulePath(dat)
689+
return filepath.ToSlash(filepath.Join(modulePath, strings.TrimPrefix(srcDir, currentDir))), nil
690+
}
691+
}
692+
// fall back to GOPATH mode
693+
goPaths := os.Getenv("GOPATH")
694+
if goPaths == "" {
695+
return "", fmt.Errorf("GOPATH is not set")
696+
}
697+
goPathList := strings.Split(goPaths, string(os.PathListSeparator))
698+
for _, goPath := range goPathList {
699+
sourceRoot := filepath.Join(goPath, "src") + string(os.PathSeparator)
700+
if strings.HasPrefix(srcDir, sourceRoot) {
701+
return filepath.ToSlash(strings.TrimPrefix(srcDir, sourceRoot)), nil
702+
}
703+
}
704+
return "", errOutsideGoPath
705+
}

mockgen/mockgen_test.go

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,9 @@ package main
22

33
import (
44
"fmt"
5+
"io/ioutil"
6+
"os"
7+
"path/filepath"
58
"reflect"
69
"regexp"
710
"strings"
@@ -364,3 +367,85 @@ func Test_createPackageMap(t *testing.T) {
364367
})
365368
}
366369
}
370+
371+
func TestParsePackageImport_FallbackGoPath(t *testing.T) {
372+
goPath, err := ioutil.TempDir("", "gopath")
373+
if err != nil {
374+
t.Error(err)
375+
}
376+
defer func() {
377+
if err = os.RemoveAll(goPath); err != nil {
378+
t.Error(err)
379+
}
380+
}()
381+
srcDir := filepath.Join(goPath, "src/example.com/foo")
382+
err = os.MkdirAll(srcDir, 0755)
383+
if err != nil {
384+
t.Error(err)
385+
}
386+
key := "GOPATH"
387+
value := goPath
388+
if err := os.Setenv(key, value); err != nil {
389+
t.Fatalf("unable to set environment variable %q to %q: %v", key, value, err)
390+
}
391+
key = "GO111MODULE"
392+
value = "on"
393+
if err := os.Setenv(key, value); err != nil {
394+
t.Fatalf("unable to set environment variable %q to %q: %v", key, value, err)
395+
}
396+
pkgPath, err := parsePackageImport(srcDir)
397+
expected := "example.com/foo"
398+
if pkgPath != expected {
399+
t.Errorf("expect %s, got %s", expected, pkgPath)
400+
}
401+
}
402+
403+
func TestParsePackageImport_FallbackMultiGoPath(t *testing.T) {
404+
var goPathList []string
405+
406+
// first gopath
407+
goPath, err := ioutil.TempDir("", "gopath1")
408+
if err != nil {
409+
t.Error(err)
410+
}
411+
goPathList = append(goPathList, goPath)
412+
defer func() {
413+
if err = os.RemoveAll(goPath); err != nil {
414+
t.Error(err)
415+
}
416+
}()
417+
srcDir := filepath.Join(goPath, "src/example.com/foo")
418+
err = os.MkdirAll(srcDir, 0755)
419+
if err != nil {
420+
t.Error(err)
421+
}
422+
423+
// second gopath
424+
goPath, err = ioutil.TempDir("", "gopath2")
425+
if err != nil {
426+
t.Error(err)
427+
}
428+
goPathList = append(goPathList, goPath)
429+
defer func() {
430+
if err = os.RemoveAll(goPath); err != nil {
431+
t.Error(err)
432+
}
433+
}()
434+
435+
goPaths := strings.Join(goPathList, string(os.PathListSeparator))
436+
key := "GOPATH"
437+
value := goPaths
438+
if err := os.Setenv(key, value); err != nil {
439+
t.Fatalf("unable to set environment variable %q to %q: %v", key, value, err)
440+
}
441+
key = "GO111MODULE"
442+
value = "on"
443+
if err := os.Setenv(key, value); err != nil {
444+
t.Fatalf("unable to set environment variable %q to %q: %v", key, value, err)
445+
}
446+
pkgPath, err := parsePackageImport(srcDir)
447+
expected := "example.com/foo"
448+
if pkgPath != expected {
449+
t.Errorf("expect %s, got %s", expected, pkgPath)
450+
}
451+
}

mockgen/model/model.go

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ func (intf *Interface) addImports(im map[string]bool) {
7171
}
7272
}
7373

74-
// AddMethod adds a new method, deduplicating by method name.
74+
// AddMethod adds a new method, de-duplicating by method name.
7575
func (intf *Interface) AddMethod(m *Method) {
7676
for _, me := range intf.Methods {
7777
if me.Name == m.Name {
@@ -260,11 +260,10 @@ func (mt *MapType) addImports(im map[string]bool) {
260260
// NamedType is an exported type in a package.
261261
type NamedType struct {
262262
Package string // may be empty
263-
Type string // TODO: should this be typed Type?
263+
Type string
264264
}
265265

266266
func (nt *NamedType) String(pm map[string]string, pkgOverride string) string {
267-
// TODO: is this right?
268267
if pkgOverride == nt.Package {
269268
return nt.Type
270269
}

0 commit comments

Comments
 (0)