diff --git a/README.md b/README.md index a4213e50..17a7d5ce 100644 --- a/README.md +++ b/README.md @@ -33,7 +33,23 @@ the version of mockgen used to generate your mocks. ## Running mockgen -`mockgen` has two modes of operation: source and reflect. +`mockgen` has three modes of operation: archive, source and reflect. + +### Archive mode + +Archive mode generates mock interfaces from a package archive +file (.a). It is enabled by using the -archive flag, the import +path is also needed as a non-flag argument. No other flags are +required. + +Example: + +```bash +# Build the package to a archive. +go build -o pkg.a database/sql/driver + +mockgen -archive=pkg.a database/sql/driver +``` ### Source mode @@ -71,6 +87,8 @@ The `mockgen` command is used to generate source code for a mock class given a Go source file containing interfaces to be mocked. It supports the following flags: +- `-archive`: A package archive file containing interfaces to be mocked. + - `-source`: A file containing interfaces to be mocked. - `-destination`: A file to which to write the resulting source code. If you diff --git a/mockgen/archive.go b/mockgen/archive.go new file mode 100644 index 00000000..3aeb7e5e --- /dev/null +++ b/mockgen/archive.go @@ -0,0 +1,55 @@ +package main + +import ( + "fmt" + "go/token" + "go/types" + "log" + "os" + + "github.com/golang/mock/mockgen/model" + + "golang.org/x/tools/go/gcexportdata" +) + +func archiveMode(importpath, archive string) (*model.Package, error) { + f, err := os.Open(archive) + if err != nil { + return nil, err + } + defer f.Close() + r, err := gcexportdata.NewReader(f) + if err != nil { + return nil, fmt.Errorf("read export data %q: %v", archive, err) + } + + fset := token.NewFileSet() + imports := make(map[string]*types.Package) + tp, err := gcexportdata.Read(r, fset, imports, importpath) + if err != nil { + return nil, err + } + + pkg := &model.Package{ + Name: tp.Name(), + PkgPath: tp.Path(), + } + for _, name := range tp.Scope().Names() { + m := tp.Scope().Lookup(name) + tn, ok := m.(*types.TypeName) + if !ok { + continue + } + ti, ok := tn.Type().Underlying().(*types.Interface) + if !ok { + continue + } + it, err := model.InterfaceFromGoTypesType(ti) + if err != nil { + log.Fatal(err) + } + it.Name = m.Name() + pkg.Interfaces = append(pkg.Interfaces, it) + } + return pkg, nil +} diff --git a/mockgen/mockgen.go b/mockgen/mockgen.go index 79cb921c..d3305bc1 100644 --- a/mockgen/mockgen.go +++ b/mockgen/mockgen.go @@ -54,6 +54,7 @@ var ( ) var ( + archive = flag.String("archive", "", "(archive mode) Input Go archive file; enables archive mode.") source = flag.String("source", "", "(source mode) Input Go source file; enables source mode.") destination = flag.String("destination", "", "Output file; defaults to stdout.") mockNames = flag.String("mock_names", "", "Comma-separated interfaceName=mockName pairs of explicit mock names to use. Mock names default to 'Mock'+ interfaceName suffix.") @@ -80,6 +81,12 @@ func main() { var packageName string if *source != "" { pkg, err = sourceMode(*source) + } else if *archive != "" { + if flag.NArg() != 1 { + usage() + log.Fatal("Expected exactly one argument") + } + pkg, err = archiveMode(flag.Arg(0), *archive) } else { if flag.NArg() != 2 { usage() @@ -139,6 +146,8 @@ func main() { g := new(generator) if *source != "" { g.filename = *source + } else if *archive != "" { + g.filename = *archive } else { g.srcPackage = packageName g.srcInterfaces = flag.Arg(1) @@ -201,7 +210,14 @@ func usage() { flag.PrintDefaults() } -const usageText = `mockgen has two modes of operation: source and reflect. +const usageText = `mockgen has three modes of operation: archive, source and reflect. + +Archive mode generates mock interfaces from a package archive +file (.a). It is enabled by using the -archive flag, the import +path is also needed as a non-flag argument. No other flags are +required. +Example: + mockgen -archive=pkg.a importpath Source mode generates mock interfaces from a source file. It is enabled by using the -source flag. Other flags that diff --git a/mockgen/model/model_gotypes.go b/mockgen/model/model_gotypes.go new file mode 100644 index 00000000..2b1a976a --- /dev/null +++ b/mockgen/model/model_gotypes.go @@ -0,0 +1,164 @@ +package model + +import ( + "fmt" + "go/types" +) + +// InterfaceFromGoTypesType returns a pointer to an interface for the +// given reflection interface type. +func InterfaceFromGoTypesType(it *types.Interface) (*Interface, error) { + intf := &Interface{} + + for i := 0; i < it.NumMethods(); i++ { + mt := it.Method(i) + // TODO: need to skip unexported methods? or just raise an error? + m := &Method{ + Name: mt.Name(), + } + + var err error + m.In, m.Variadic, m.Out, err = funcArgsFromGoTypesType(mt.Type().(*types.Signature)) + if err != nil { + return nil, err + } + + intf.AddMethod(m) + } + + return intf, nil +} + +func funcArgsFromGoTypesType(t *types.Signature) (in []*Parameter, variadic *Parameter, out []*Parameter, err error) { + nin := t.Params().Len() + if t.Variadic() { + nin-- + } + var p *Parameter + for i := 0; i < nin; i++ { + p, err = parameterFromGoTypesType(t.Params().At(i), false) + if err != nil { + return + } + in = append(in, p) + } + if t.Variadic() { + p, err = parameterFromGoTypesType(t.Params().At(nin), true) + if err != nil { + return + } + variadic = p + } + for i := 0; i < t.Results().Len(); i++ { + p, err = parameterFromGoTypesType(t.Results().At(i), false) + if err != nil { + return + } + out = append(out, p) + } + return +} + +func parameterFromGoTypesType(v *types.Var, variadic bool) (*Parameter, error) { + t := v.Type() + if variadic { + t = t.(*types.Slice).Elem() + } + tt, err := typeFromGoTypesType(t) + if err != nil { + return nil, err + } + return &Parameter{Name: v.Name(), Type: tt}, nil +} + +func typeFromGoTypesType(t types.Type) (Type, error) { + // Hack workaround for https://golang.org/issue/3853. + // This explicit check should not be necessary. + // if t == byteType { + // return PredeclaredType("byte"), nil + // } + + if t, ok := t.(*types.Named); ok { + tn := t.Obj() + if tn.Pkg() == nil { + return PredeclaredType(tn.Name()), nil + } + return &NamedType{ + Package: tn.Pkg().Path(), + Type: tn.Name(), + }, nil + } + + // only unnamed or predeclared types after here + + // Lots of types have element types. Let's do the parsing and error checking for all of them. + var elemType Type + if t, ok := t.(interface{ Elem() types.Type }); ok { + var err error + elemType, err = typeFromGoTypesType(t.Elem()) + if err != nil { + return nil, err + } + } + + switch t := t.(type) { + case *types.Array: + return &ArrayType{ + Len: int(t.Len()), + Type: elemType, + }, nil + case *types.Basic: + return PredeclaredType(t.String()), nil + case *types.Chan: + var dir ChanDir + switch t.Dir() { + case types.RecvOnly: + dir = RecvDir + case types.SendOnly: + dir = SendDir + } + return &ChanType{ + Dir: dir, + Type: elemType, + }, nil + case *types.Signature: + in, variadic, out, err := funcArgsFromGoTypesType(t) + if err != nil { + return nil, err + } + return &FuncType{ + In: in, + Out: out, + Variadic: variadic, + }, nil + case *types.Interface: + if t.NumMethods() == 0 { + return PredeclaredType("interface{}"), nil + } + case *types.Map: + kt, err := typeFromGoTypesType(t.Key()) + if err != nil { + return nil, err + } + return &MapType{ + Key: kt, + Value: elemType, + }, nil + case *types.Pointer: + return &PointerType{ + Type: elemType, + }, nil + case *types.Slice: + return &ArrayType{ + Len: -1, + Type: elemType, + }, nil + case *types.Struct: + if t.NumFields() == 0 { + return PredeclaredType("struct{}"), nil + } + } + + // TODO: Struct, UnsafePointer + return nil, fmt.Errorf("can't yet turn %v (%T) into a model.Type", t.String(), t) +}