Skip to content

Commit 90e3327

Browse files
committed
add support for generics
Fixes: golang#621
1 parent 32e424a commit 90e3327

File tree

6 files changed

+131
-77
lines changed

6 files changed

+131
-77
lines changed
Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// Copyright 2020 Google LLC
1+
// Copyright 2022 Google LLC
22
//
33
// Licensed under the Apache License, Version 2.0 (the "License");
44
// you may not use this file except in compliance with the License.
@@ -12,18 +12,16 @@
1212
// See the License for the specific language governing permissions and
1313
// limitations under the License.
1414

15-
//go:build !go1.14
16-
// +build !go1.14
15+
//go:build go1.18
16+
// +build go1.18
1717

18-
package gomock_test
18+
package main
1919

20-
import "testing"
20+
import "go/ast"
2121

22-
func TestDuplicateFinishCallFails(t *testing.T) {
23-
rep, ctrl := createFixtures(t)
24-
25-
ctrl.Finish()
26-
rep.assertPass("the first Finish call should succeed")
27-
28-
rep.assertFatal(ctrl.Finish, "Controller.Finish was called more than once. It has to be called exactly once.")
22+
func getTypeSpecTypeParams(ts *ast.TypeSpec) []*ast.Field {
23+
if ts == nil || ts.TypeParams == nil {
24+
return nil
25+
}
26+
return ts.TypeParams.List
2927
}
Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// Copyright 2019 Google LLC
1+
// Copyright 2022 Google LLC
22
//
33
// Licensed under the Apache License, Version 2.0 (the "License");
44
// you may not use this file except in compliance with the License.
@@ -12,16 +12,13 @@
1212
// See the License for the specific language governing permissions and
1313
// limitations under the License.
1414

15-
//go:build !go1.12
16-
// +build !go1.12
15+
//go:build !go1.18
16+
// +build !go1.18
1717

1818
package main
1919

20-
import (
21-
"log"
22-
)
20+
import "go/ast"
2321

24-
func printModuleVersion() {
25-
log.Printf("No version information is available for Mockgen compiled with " +
26-
"version 1.11")
22+
func getTypeSpecTypeParams(ts *ast.TypeSpec) []*ast.Field {
23+
return nil
2724
}

mockgen/mockgen.go

Lines changed: 43 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -371,46 +371,72 @@ func (g *generator) mockName(typeName string) string {
371371
return "Mock" + typeName
372372
}
373373

374+
// formattedTypeParams returns a long and short form of type param info used for
375+
// printing. If analyzing a interface with type param [I any, O any] the result
376+
// will be:
377+
// "[I any, O any]", "[I, O]"
378+
func (g *generator) formattedTypeParams(it *model.Interface, pkgOverride string) (string, string) {
379+
if len(it.TypeParams) == 0 {
380+
return "", ""
381+
}
382+
var long, short strings.Builder
383+
long.WriteString("[")
384+
short.WriteString("[")
385+
for i, v := range it.TypeParams {
386+
if i != 0 {
387+
long.WriteString(", ")
388+
short.WriteString(", ")
389+
}
390+
long.WriteString(v.Name)
391+
short.WriteString(v.Name)
392+
long.WriteString(fmt.Sprintf(" %s", v.Type.String(g.packageMap, pkgOverride)))
393+
}
394+
long.WriteString("]")
395+
short.WriteString("]")
396+
return long.String(), short.String()
397+
}
398+
374399
func (g *generator) GenerateMockInterface(intf *model.Interface, outputPackagePath string) error {
375400
mockType := g.mockName(intf.Name)
401+
longTp, shortTp := g.formattedTypeParams(intf, outputPackagePath)
376402

377403
g.p("")
378404
g.p("// %v is a mock of %v interface.", mockType, intf.Name)
379-
g.p("type %v struct {", mockType)
405+
g.p("type %v%v struct {", mockType, longTp)
380406
g.in()
381407
g.p("ctrl *gomock.Controller")
382-
g.p("recorder *%vMockRecorder", mockType)
408+
g.p("recorder *%vMockRecorder%v", mockType, shortTp)
383409
g.out()
384410
g.p("}")
385411
g.p("")
386412

387413
g.p("// %vMockRecorder is the mock recorder for %v.", mockType, mockType)
388-
g.p("type %vMockRecorder struct {", mockType)
414+
g.p("type %vMockRecorder%v struct {", mockType, longTp)
389415
g.in()
390-
g.p("mock *%v", mockType)
416+
g.p("mock *%v%v", mockType, shortTp)
391417
g.out()
392418
g.p("}")
393419
g.p("")
394420

395421
g.p("// New%v creates a new mock instance.", mockType)
396-
g.p("func New%v(ctrl *gomock.Controller) *%v {", mockType, mockType)
422+
g.p("func New%v%v(ctrl *gomock.Controller) *%v%v {", mockType, longTp, mockType, shortTp)
397423
g.in()
398-
g.p("mock := &%v{ctrl: ctrl}", mockType)
399-
g.p("mock.recorder = &%vMockRecorder{mock}", mockType)
424+
g.p("mock := &%v%v{ctrl: ctrl}", mockType, shortTp)
425+
g.p("mock.recorder = &%vMockRecorder%v{mock}", mockType, shortTp)
400426
g.p("return mock")
401427
g.out()
402428
g.p("}")
403429
g.p("")
404430

405431
// XXX: possible name collision here if someone has EXPECT in their interface.
406432
g.p("// EXPECT returns an object that allows the caller to indicate expected use.")
407-
g.p("func (m *%v) EXPECT() *%vMockRecorder {", mockType, mockType)
433+
g.p("func (m *%v%v) EXPECT() *%vMockRecorder%v {", mockType, shortTp, mockType, shortTp)
408434
g.in()
409435
g.p("return m.recorder")
410436
g.out()
411437
g.p("}")
412438

413-
g.GenerateMockMethods(mockType, intf, outputPackagePath)
439+
g.GenerateMockMethods(mockType, intf, outputPackagePath, shortTp)
414440

415441
return nil
416442
}
@@ -421,13 +447,13 @@ func (b byMethodName) Len() int { return len(b) }
421447
func (b byMethodName) Swap(i, j int) { b[i], b[j] = b[j], b[i] }
422448
func (b byMethodName) Less(i, j int) bool { return b[i].Name < b[j].Name }
423449

424-
func (g *generator) GenerateMockMethods(mockType string, intf *model.Interface, pkgOverride string) {
450+
func (g *generator) GenerateMockMethods(mockType string, intf *model.Interface, pkgOverride, shortTp string) {
425451
sort.Sort(byMethodName(intf.Methods))
426452
for _, m := range intf.Methods {
427453
g.p("")
428-
_ = g.GenerateMockMethod(mockType, m, pkgOverride)
454+
_ = g.GenerateMockMethod(mockType, m, pkgOverride, shortTp)
429455
g.p("")
430-
_ = g.GenerateMockRecorderMethod(mockType, m)
456+
_ = g.GenerateMockRecorderMethod(mockType, m, shortTp)
431457
}
432458
}
433459

@@ -446,7 +472,7 @@ func makeArgString(argNames, argTypes []string) string {
446472

447473
// GenerateMockMethod generates a mock method implementation.
448474
// If non-empty, pkgOverride is the package in which unqualified types reside.
449-
func (g *generator) GenerateMockMethod(mockType string, m *model.Method, pkgOverride string) error {
475+
func (g *generator) GenerateMockMethod(mockType string, m *model.Method, pkgOverride, shortTp string) error {
450476
argNames := g.getArgNames(m)
451477
argTypes := g.getArgTypes(m, pkgOverride)
452478
argString := makeArgString(argNames, argTypes)
@@ -467,7 +493,7 @@ func (g *generator) GenerateMockMethod(mockType string, m *model.Method, pkgOver
467493
idRecv := ia.allocateIdentifier("m")
468494

469495
g.p("// %v mocks base method.", m.Name)
470-
g.p("func (%v *%v) %v(%v)%v {", idRecv, mockType, m.Name, argString, retString)
496+
g.p("func (%v *%v%v) %v(%v)%v {", idRecv, mockType, shortTp, m.Name, argString, retString)
471497
g.in()
472498
g.p("%s.ctrl.T.Helper()", idRecv)
473499

@@ -511,7 +537,7 @@ func (g *generator) GenerateMockMethod(mockType string, m *model.Method, pkgOver
511537
return nil
512538
}
513539

514-
func (g *generator) GenerateMockRecorderMethod(mockType string, m *model.Method) error {
540+
func (g *generator) GenerateMockRecorderMethod(mockType string, m *model.Method, shortTp string) error {
515541
argNames := g.getArgNames(m)
516542

517543
var argString string
@@ -535,7 +561,7 @@ func (g *generator) GenerateMockRecorderMethod(mockType string, m *model.Method)
535561
idRecv := ia.allocateIdentifier("mr")
536562

537563
g.p("// %v indicates an expected call of %v.", m.Name, m.Name)
538-
g.p("func (%s *%vMockRecorder) %v(%v) *gomock.Call {", idRecv, mockType, m.Name, argString)
564+
g.p("func (%s *%vMockRecorder%v) %v(%v) *gomock.Call {", idRecv, mockType, shortTp, m.Name, argString)
539565
g.in()
540566
g.p("%s.mock.ctrl.T.Helper()", idRecv)
541567

@@ -558,7 +584,7 @@ func (g *generator) GenerateMockRecorderMethod(mockType string, m *model.Method)
558584
callArgs = ", " + idVarArgs + "..."
559585
}
560586
}
561-
g.p(`return %s.mock.ctrl.RecordCallWithMethodType(%s.mock, "%s", reflect.TypeOf((*%s)(nil).%s)%s)`, idRecv, idRecv, m.Name, mockType, m.Name, callArgs)
587+
g.p(`return %s.mock.ctrl.RecordCallWithMethodType(%s.mock, "%s", reflect.TypeOf((*%s%s)(nil).%s)%s)`, idRecv, idRecv, m.Name, mockType, shortTp, m.Name, callArgs)
562588

563589
g.out()
564590
g.p("}")

mockgen/model/model.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,8 +53,9 @@ func (pkg *Package) Imports() map[string]bool {
5353

5454
// Interface is a Go interface.
5555
type Interface struct {
56-
Name string
57-
Methods []*Method
56+
Name string
57+
Methods []*Method
58+
TypeParams []*Parameter
5859
}
5960

6061
// Print writes the interface name and its methods.

0 commit comments

Comments
 (0)