@@ -371,46 +371,72 @@ func (g *generator) mockName(typeName string) string {
371
371
return "Mock" + typeName
372
372
}
373
373
374
+ // formattedTypeParams returns a long and short form of type param info used for
375
+ // printing. If analyzing a interface with type param [I any, O any] the result
376
+ // will be:
377
+ // "[I any, O any]", "[I, O]"
378
+ func (g * generator ) formattedTypeParams (it * model.Interface , pkgOverride string ) (string , string ) {
379
+ if len (it .TypeParams ) == 0 {
380
+ return "" , ""
381
+ }
382
+ var long , short strings.Builder
383
+ long .WriteString ("[" )
384
+ short .WriteString ("[" )
385
+ for i , v := range it .TypeParams {
386
+ if i != 0 {
387
+ long .WriteString (", " )
388
+ short .WriteString (", " )
389
+ }
390
+ long .WriteString (v .Name )
391
+ short .WriteString (v .Name )
392
+ long .WriteString (fmt .Sprintf (" %s" , v .Type .String (g .packageMap , pkgOverride )))
393
+ }
394
+ long .WriteString ("]" )
395
+ short .WriteString ("]" )
396
+ return long .String (), short .String ()
397
+ }
398
+
374
399
func (g * generator ) GenerateMockInterface (intf * model.Interface , outputPackagePath string ) error {
375
400
mockType := g .mockName (intf .Name )
401
+ longTp , shortTp := g .formattedTypeParams (intf , outputPackagePath )
376
402
377
403
g .p ("" )
378
404
g .p ("// %v is a mock of %v interface." , mockType , intf .Name )
379
- g .p ("type %v struct {" , mockType )
405
+ g .p ("type %v%v struct {" , mockType , longTp )
380
406
g .in ()
381
407
g .p ("ctrl *gomock.Controller" )
382
- g .p ("recorder *%vMockRecorder" , mockType )
408
+ g .p ("recorder *%vMockRecorder%v " , mockType , shortTp )
383
409
g .out ()
384
410
g .p ("}" )
385
411
g .p ("" )
386
412
387
413
g .p ("// %vMockRecorder is the mock recorder for %v." , mockType , mockType )
388
- g .p ("type %vMockRecorder struct {" , mockType )
414
+ g .p ("type %vMockRecorder%v struct {" , mockType , longTp )
389
415
g .in ()
390
- g .p ("mock *%v" , mockType )
416
+ g .p ("mock *%v%v " , mockType , shortTp )
391
417
g .out ()
392
418
g .p ("}" )
393
419
g .p ("" )
394
420
395
421
g .p ("// New%v creates a new mock instance." , mockType )
396
- g .p ("func New%v(ctrl *gomock.Controller) *%v {" , mockType , mockType )
422
+ g .p ("func New%v%v (ctrl *gomock.Controller) *%v%v {" , mockType , longTp , mockType , shortTp )
397
423
g .in ()
398
- g .p ("mock := &%v{ctrl: ctrl}" , mockType )
399
- g .p ("mock.recorder = &%vMockRecorder{mock}" , mockType )
424
+ g .p ("mock := &%v%v {ctrl: ctrl}" , mockType , shortTp )
425
+ g .p ("mock.recorder = &%vMockRecorder%v {mock}" , mockType , shortTp )
400
426
g .p ("return mock" )
401
427
g .out ()
402
428
g .p ("}" )
403
429
g .p ("" )
404
430
405
431
// XXX: possible name collision here if someone has EXPECT in their interface.
406
432
g .p ("// EXPECT returns an object that allows the caller to indicate expected use." )
407
- g .p ("func (m *%v) EXPECT() *%vMockRecorder {" , mockType , mockType )
433
+ g .p ("func (m *%v%v ) EXPECT() *%vMockRecorder%v {" , mockType , shortTp , mockType , shortTp )
408
434
g .in ()
409
435
g .p ("return m.recorder" )
410
436
g .out ()
411
437
g .p ("}" )
412
438
413
- g .GenerateMockMethods (mockType , intf , outputPackagePath )
439
+ g .GenerateMockMethods (mockType , intf , outputPackagePath , shortTp )
414
440
415
441
return nil
416
442
}
@@ -421,13 +447,13 @@ func (b byMethodName) Len() int { return len(b) }
421
447
func (b byMethodName ) Swap (i , j int ) { b [i ], b [j ] = b [j ], b [i ] }
422
448
func (b byMethodName ) Less (i , j int ) bool { return b [i ].Name < b [j ].Name }
423
449
424
- func (g * generator ) GenerateMockMethods (mockType string , intf * model.Interface , pkgOverride string ) {
450
+ func (g * generator ) GenerateMockMethods (mockType string , intf * model.Interface , pkgOverride , shortTp string ) {
425
451
sort .Sort (byMethodName (intf .Methods ))
426
452
for _ , m := range intf .Methods {
427
453
g .p ("" )
428
- _ = g .GenerateMockMethod (mockType , m , pkgOverride )
454
+ _ = g .GenerateMockMethod (mockType , m , pkgOverride , shortTp )
429
455
g .p ("" )
430
- _ = g .GenerateMockRecorderMethod (mockType , m )
456
+ _ = g .GenerateMockRecorderMethod (mockType , m , shortTp )
431
457
}
432
458
}
433
459
@@ -446,7 +472,7 @@ func makeArgString(argNames, argTypes []string) string {
446
472
447
473
// GenerateMockMethod generates a mock method implementation.
448
474
// If non-empty, pkgOverride is the package in which unqualified types reside.
449
- func (g * generator ) GenerateMockMethod (mockType string , m * model.Method , pkgOverride string ) error {
475
+ func (g * generator ) GenerateMockMethod (mockType string , m * model.Method , pkgOverride , shortTp string ) error {
450
476
argNames := g .getArgNames (m )
451
477
argTypes := g .getArgTypes (m , pkgOverride )
452
478
argString := makeArgString (argNames , argTypes )
@@ -467,7 +493,7 @@ func (g *generator) GenerateMockMethod(mockType string, m *model.Method, pkgOver
467
493
idRecv := ia .allocateIdentifier ("m" )
468
494
469
495
g .p ("// %v mocks base method." , m .Name )
470
- g .p ("func (%v *%v) %v(%v)%v {" , idRecv , mockType , m .Name , argString , retString )
496
+ g .p ("func (%v *%v%v ) %v(%v)%v {" , idRecv , mockType , shortTp , m .Name , argString , retString )
471
497
g .in ()
472
498
g .p ("%s.ctrl.T.Helper()" , idRecv )
473
499
@@ -511,7 +537,7 @@ func (g *generator) GenerateMockMethod(mockType string, m *model.Method, pkgOver
511
537
return nil
512
538
}
513
539
514
- func (g * generator ) GenerateMockRecorderMethod (mockType string , m * model.Method ) error {
540
+ func (g * generator ) GenerateMockRecorderMethod (mockType string , m * model.Method , shortTp string ) error {
515
541
argNames := g .getArgNames (m )
516
542
517
543
var argString string
@@ -535,7 +561,7 @@ func (g *generator) GenerateMockRecorderMethod(mockType string, m *model.Method)
535
561
idRecv := ia .allocateIdentifier ("mr" )
536
562
537
563
g .p ("// %v indicates an expected call of %v." , m .Name , m .Name )
538
- g .p ("func (%s *%vMockRecorder) %v(%v) *gomock.Call {" , idRecv , mockType , m .Name , argString )
564
+ g .p ("func (%s *%vMockRecorder%v ) %v(%v) *gomock.Call {" , idRecv , mockType , shortTp , m .Name , argString )
539
565
g .in ()
540
566
g .p ("%s.mock.ctrl.T.Helper()" , idRecv )
541
567
@@ -558,7 +584,7 @@ func (g *generator) GenerateMockRecorderMethod(mockType string, m *model.Method)
558
584
callArgs = ", " + idVarArgs + "..."
559
585
}
560
586
}
561
- g .p (`return %s.mock.ctrl.RecordCallWithMethodType(%s.mock, "%s", reflect.TypeOf((*%s)(nil).%s)%s)` , idRecv , idRecv , m .Name , mockType , m .Name , callArgs )
587
+ g .p (`return %s.mock.ctrl.RecordCallWithMethodType(%s.mock, "%s", reflect.TypeOf((*%s%s )(nil).%s)%s)` , idRecv , idRecv , m .Name , mockType , shortTp , m .Name , callArgs )
562
588
563
589
g .out ()
564
590
g .p ("}" )
0 commit comments