@@ -14,12 +14,12 @@ import (
14
14
"go/ast"
15
15
"go/token"
16
16
"go/types"
17
- "html/template"
18
17
"os"
19
18
"path/filepath"
20
19
"sort"
21
20
"strconv"
22
21
"strings"
22
+ "text/template"
23
23
"unicode"
24
24
25
25
"golang.org/x/tools/go/ast/astutil"
@@ -34,44 +34,34 @@ import (
34
34
35
35
const testTmplString = `
36
36
func {{.TestFuncName}}(t *{{.TestingPackageName}}.T) {
37
- {{- /* Constructor input parameters struct declaration. */}}
38
- {{- if and .Receiver .Receiver.Constructor}}
39
- {{- if gt (len .Receiver.Constructor.Args) 1}}
40
- type constructorArgs struct {
41
- {{- range .Receiver.Constructor.Args}}
42
- {{.Name}} {{.Type}}
43
- {{- end}}
44
- }
45
- {{- end}}
46
- {{- end}}
47
-
48
- {{- /* Functions/methods input parameters struct declaration. */}}
49
- {{- if gt (len .Func.Args) 1}}
50
- type args struct {
51
- {{- range .Func.Args}}
52
- {{.Name}} {{.Type}}
53
- {{- end}}
54
- }
55
- {{- end}}
56
-
57
37
{{- /* Test cases struct declaration and empty initialization. */}}
58
38
tests := []struct {
59
39
name string // description of this test case
40
+
41
+ {{- $commentPrinted := false }}
60
42
{{- if and .Receiver .Receiver.Constructor}}
61
- {{- if gt (len .Receiver.Constructor.Args) 1}}
62
- constructorArgs constructorArgs
43
+ {{- range .Receiver.Constructor.Args}}
44
+ {{- if .Name}}
45
+ {{- if not $commentPrinted}}
46
+ // Named input parameters for receiver constructor.
47
+ {{- $commentPrinted = true }}
48
+ {{- end}}
49
+ {{.Name}} {{.Type}}
63
50
{{- end}}
64
- {{- if eq (len .Receiver.Constructor.Args) 1}}
65
- constructorArg {{(index .Receiver.Constructor.Args 0).Type}}
66
51
{{- end}}
67
52
{{- end}}
68
53
69
- {{- if gt (len .Func.Args) 1}}
70
- args args
54
+ {{- $commentPrinted := false }}
55
+ {{- range .Func.Args}}
56
+ {{- if .Name}}
57
+ {{- if not $commentPrinted}}
58
+ // Named input parameters for target function.
59
+ {{- $commentPrinted = true }}
60
+ {{- end}}
61
+ {{.Name}} {{.Type}}
71
62
{{- end}}
72
- {{- if eq (len .Func.Args) 1}}
73
- arg {{(index .Func.Args 0).Type}}
74
63
{{- end}}
64
+
75
65
{{- range $index, $res := .Func.Results}}
76
66
{{- if eq $res.Name "gotErr"}}
77
67
wantErr bool
@@ -96,7 +86,12 @@ func {{.TestFuncName}}(t *{{.TestingPackageName}}.T) {
96
86
{{- .Receiver.Constructor.Name}}
97
87
98
88
{{- /* Constructor input parameters. */ -}}
99
- ({{- if eq (len .Receiver.Constructor.Args) 1}}tt.constructorArg{{end}}{{if gt (len .Func.Args) 1}}{{fieldNames .Receiver.Constructor.Args "tt.constructorArgs."}}{{end}})
89
+ (
90
+ {{- range $index, $arg := .Receiver.Constructor.Args}}
91
+ {{- if ne $index 0}}, {{end}}
92
+ {{- if .Name}}tt.{{.Name}}{{else}}{{.Value}}{{end}}
93
+ {{- end -}}
94
+ )
100
95
101
96
{{- /* Handles the error return from constructor. */}}
102
97
{{- $last := last .Receiver.Constructor.Results}}
@@ -123,7 +118,12 @@ func {{.TestFuncName}}(t *{{.TestingPackageName}}.T) {
123
118
{{- end}}{{.Func.Name}}
124
119
125
120
{{- /* Input parameters. */ -}}
126
- ({{- if eq (len .Func.Args) 1}}tt.arg{{end}}{{if gt (len .Func.Args) 1}}{{fieldNames .Func.Args "tt.args."}}{{end}})
121
+ (
122
+ {{- range $index, $arg := .Func.Args}}
123
+ {{- if ne $index 0}}, {{end}}
124
+ {{- if .Name}}tt.{{.Name}}{{else}}{{.Value}}{{end}}
125
+ {{- end -}}
126
+ )
127
127
128
128
{{- /* Handles the returned error before the rest of return value. */}}
129
129
{{- $last := last .Func.Results}}
@@ -155,8 +155,12 @@ func {{.TestFuncName}}(t *{{.TestingPackageName}}.T) {
155
155
}
156
156
`
157
157
158
+ // Name is the name of the field this input parameter should reference.
159
+ // Value is the expression this input parameter should accept.
160
+ //
161
+ // Exactly one of Name or Value must be set.
158
162
type field struct {
159
- Name , Type string
163
+ Name , Type , Value string
160
164
}
161
165
162
166
type function struct {
@@ -191,6 +195,9 @@ type testInfo struct {
191
195
var testTmpl = template .Must (template .New ("test" ).Funcs (template.FuncMap {
192
196
"add" : func (a , b int ) int { return a + b },
193
197
"last" : func (slice []field ) field {
198
+ if len (slice ) == 0 {
199
+ return field {}
200
+ }
194
201
return slice [len (slice )- 1 ]
195
202
},
196
203
"fieldNames" : func (fields []field , qualifier string ) (res string ) {
@@ -450,36 +457,32 @@ func AddTestForFunc(ctx context.Context, snapshot *cache.Snapshot, loc protocol.
450
457
451
458
errorType := types .Universe .Lookup ("error" ).Type ()
452
459
453
- // TODO(hxjiang): if input parameter is not named (meaning it's not used),
454
- // pass the zero value to the function call.
455
- // TODO(hxjiang): if the input parameter is named, define the field by using
456
- // the parameter's name instead of in%d.
457
460
// TODO(hxjiang): handle special case for ctx.Context input.
458
- for index := range sig .Params ().Len () {
459
- var name string
460
- if index == 0 {
461
- name = "in"
461
+ for i := range sig .Params ().Len () {
462
+ param := sig .Params ().At (i )
463
+ name , typ := param .Name (), param .Type ()
464
+ f := field {Type : types .TypeString (typ , qf )}
465
+ if name == "" || name == "_" {
466
+ f .Value = typesinternal .ZeroString (typ , qf )
462
467
} else {
463
- name = fmt . Sprintf ( "in%d" , index + 1 )
468
+ f . Name = name
464
469
}
465
- data .Func .Args = append (data .Func .Args , field {
466
- Name : name ,
467
- Type : types .TypeString (sig .Params ().At (index ).Type (), qf ),
468
- })
470
+ data .Func .Args = append (data .Func .Args , f )
469
471
}
470
472
471
- for index := range sig .Results ().Len () {
473
+ for i := range sig .Results ().Len () {
474
+ typ := sig .Results ().At (i ).Type ()
472
475
var name string
473
- if index == sig .Results ().Len ()- 1 && types .Identical (sig . Results (). At ( index ). Type () , errorType ) {
476
+ if i == sig .Results ().Len ()- 1 && types .Identical (typ , errorType ) {
474
477
name = "gotErr"
475
- } else if index == 0 {
478
+ } else if i == 0 {
476
479
name = "got"
477
480
} else {
478
- name = fmt .Sprintf ("got%d" , index + 1 )
481
+ name = fmt .Sprintf ("got%d" , i + 1 )
479
482
}
480
483
data .Func .Results = append (data .Func .Results , field {
481
484
Name : name ,
482
- Type : types .TypeString (sig . Results (). At ( index ). Type () , qf ),
485
+ Type : types .TypeString (typ , qf ),
483
486
})
484
487
}
485
488
@@ -587,25 +590,25 @@ func AddTestForFunc(ctx context.Context, snapshot *cache.Snapshot, loc protocol.
587
590
588
591
if constructor != nil {
589
592
data .Receiver .Constructor = & function {Name : constructor .Name ()}
590
- for index := range constructor .Signature ().Params ().Len () {
591
- var name string
592
- if index == 0 {
593
- name = "in"
593
+ for i := range constructor .Signature ().Params ().Len () {
594
+ param := constructor .Signature ().Params ().At (i )
595
+ name , typ := param .Name (), param .Type ()
596
+ f := field {Type : types .TypeString (typ , qf )}
597
+ if name == "" || name == "_" {
598
+ f .Value = typesinternal .ZeroString (typ , qf )
594
599
} else {
595
- name = fmt . Sprintf ( "in%d" , index + 1 )
600
+ f . Name = name
596
601
}
597
- data .Receiver .Constructor .Args = append (data .Receiver .Constructor .Args , field {
598
- Name : name ,
599
- Type : types .TypeString (constructor .Signature ().Params ().At (index ).Type (), qf ),
600
- })
602
+ data .Receiver .Constructor .Args = append (data .Receiver .Constructor .Args , f )
601
603
}
602
- for index := range constructor .Signature ().Results ().Len () {
604
+ for i := range constructor .Signature ().Results ().Len () {
605
+ typ := constructor .Signature ().Results ().At (i ).Type ()
603
606
var name string
604
- if index == 0 {
607
+ if i == 0 {
605
608
// The first return value must be of type T, *T, or a type whose named
606
609
// type is the same as named type of T.
607
610
name = varName
608
- } else if index == constructor .Signature ().Results ().Len ()- 1 && types .Identical (constructor . Signature (). Results (). At ( index ). Type () , errorType ) {
611
+ } else if i == constructor .Signature ().Results ().Len ()- 1 && types .Identical (typ , errorType ) {
609
612
name = "err"
610
613
} else {
611
614
// Drop any return values beyond the first and the last.
@@ -614,12 +617,48 @@ func AddTestForFunc(ctx context.Context, snapshot *cache.Snapshot, loc protocol.
614
617
}
615
618
data .Receiver .Constructor .Results = append (data .Receiver .Constructor .Results , field {
616
619
Name : name ,
617
- Type : types .TypeString (constructor . Signature (). Results (). At ( index ). Type () , qf ),
620
+ Type : types .TypeString (typ , qf ),
618
621
})
619
622
}
620
623
}
621
624
}
622
625
626
+ // Resolves duplicate parameter names between the function and its
627
+ // receiver's constructor. It adds prefix to the constructor's parameters
628
+ // until no conflicts remain.
629
+ if data .Receiver != nil && data .Receiver .Constructor != nil {
630
+ seen := map [string ]bool {}
631
+ for _ , f := range data .Func .Args {
632
+ if f .Name == "" {
633
+ continue
634
+ }
635
+ seen [f .Name ] = true
636
+ }
637
+
638
+ // "" for no change, "c" for constructor, "i" for input.
639
+ for _ , prefix := range []string {"" , "c" , "c_" , "i" , "i_" } {
640
+ conflict := false
641
+ for _ , f := range data .Receiver .Constructor .Args {
642
+ if f .Name == "" {
643
+ continue
644
+ }
645
+ if seen [prefix + f .Name ] {
646
+ conflict = true
647
+ break
648
+ }
649
+ }
650
+ if ! conflict {
651
+ for i , f := range data .Receiver .Constructor .Args {
652
+ if f .Name == "" {
653
+ continue
654
+ }
655
+ data .Receiver .Constructor .Args [i ].Name = prefix + data .Receiver .Constructor .Args [i ].Name
656
+ }
657
+ break
658
+ }
659
+ }
660
+ }
661
+
623
662
// Compute edits to update imports.
624
663
//
625
664
// If we're adding to an existing test file, we need to adjust existing
0 commit comments