Skip to content
This repository was archived by the owner on Jun 27, 2023. It is now read-only.

Commit 8321731

Browse files
committed
Add DoAndReturn as a complement to Do that does not ignore its return values.
Refactored how all the call actions work to make this easier.
1 parent 61503c5 commit 8321731

File tree

4 files changed

+183
-107
lines changed

4 files changed

+183
-107
lines changed

gomock/call.go

Lines changed: 84 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -25,12 +25,11 @@ import (
2525
type Call struct {
2626
t TestReporter // for triggering test failures on invalid call setup
2727

28-
receiver interface{} // the receiver of the method call
29-
method string // the name of the method
30-
methodType reflect.Type // the type of the method
31-
args []Matcher // the args
32-
rets []interface{} // the return values (if any)
33-
origin string // file and line number of call setup
28+
receiver interface{} // the receiver of the method call
29+
method string // the name of the method
30+
methodType reflect.Type // the type of the method
31+
args []Matcher // the args
32+
origin string // file and line number of call setup
3433

3534
preReqs []*Call // prerequisite calls
3635

@@ -39,9 +38,10 @@ type Call struct {
3938

4039
numCalls int // actual number made
4140

42-
// Actions
43-
doFunc reflect.Value
44-
setArgs map[int]reflect.Value
41+
// actions are called when this Call is called. Each action gets the args and
42+
// can set the return values by returning a non-nil slice. Actions run in the
43+
// order they are created.
44+
actions []func([]interface{}) []interface{}
4545
}
4646

4747
// AnyTimes allows the expectation to be called 0 or more times
@@ -70,11 +70,56 @@ func (c *Call) MaxTimes(n int) *Call {
7070
return c
7171
}
7272

73-
// Do declares the action to run when the call is matched.
73+
// DoAndReturn declares the action to run when the call is matched.
74+
// The return values from this function are returned by the mocked function.
75+
// It takes an interface{} argument to support n-arity functions.
76+
func (c *Call) DoAndReturn(f interface{}) *Call {
77+
// TODO: Check arity and types here, rather than dying badly elsewhere.
78+
v := reflect.ValueOf(f)
79+
80+
c.addAction(func(args []interface{}) []interface{} {
81+
vargs := make([]reflect.Value, len(args))
82+
ft := v.Type()
83+
for i := 0; i < len(args); i++ {
84+
if args[i] != nil {
85+
vargs[i] = reflect.ValueOf(args[i])
86+
} else {
87+
// Use the zero value for the arg.
88+
vargs[i] = reflect.Zero(ft.In(i))
89+
}
90+
}
91+
vrets := v.Call(vargs)
92+
rets := make([]interface{}, len(vrets))
93+
for i, ret := range vrets {
94+
rets[i] = ret.Interface()
95+
}
96+
return rets
97+
})
98+
return c
99+
}
100+
101+
// Do declares the action to run when the call is matched. The function's
102+
// return values are ignored to retain backward compatibility. To use the
103+
// return values call DoAndReturn.
74104
// It takes an interface{} argument to support n-arity functions.
75105
func (c *Call) Do(f interface{}) *Call {
76106
// TODO: Check arity and types here, rather than dying badly elsewhere.
77-
c.doFunc = reflect.ValueOf(f)
107+
v := reflect.ValueOf(f)
108+
109+
c.addAction(func(args []interface{}) []interface{} {
110+
vargs := make([]reflect.Value, len(args))
111+
ft := v.Type()
112+
for i := 0; i < len(args); i++ {
113+
if args[i] != nil {
114+
vargs[i] = reflect.ValueOf(args[i])
115+
} else {
116+
// Use the zero value for the arg.
117+
vargs[i] = reflect.Zero(ft.In(i))
118+
}
119+
v.Call(vargs)
120+
}
121+
return nil
122+
})
78123
return c
79124
}
80125

@@ -113,7 +158,10 @@ func (c *Call) Return(rets ...interface{}) *Call {
113158
}
114159
}
115160

116-
c.rets = rets
161+
c.addAction(func([]interface{}) []interface{} {
162+
return rets
163+
})
164+
117165
return c
118166
}
119167

@@ -131,9 +179,6 @@ func (c *Call) SetArg(n int, value interface{}) *Call {
131179
h.Helper()
132180
}
133181

134-
if c.setArgs == nil {
135-
c.setArgs = make(map[int]reflect.Value)
136-
}
137182
mt := c.methodType
138183
// TODO: This will break on variadic methods.
139184
// We will need to check those at invocation time.
@@ -159,7 +204,17 @@ func (c *Call) SetArg(n int, value interface{}) *Call {
159204
c.t.Fatalf("SetArg(%d, ...) referring to argument of non-pointer non-interface non-slice type %v [%s]",
160205
n, at, c.origin)
161206
}
162-
c.setArgs[n] = reflect.ValueOf(value)
207+
208+
c.addAction(func(args []interface{}) []interface{} {
209+
v := reflect.ValueOf(value)
210+
switch reflect.TypeOf(args[n]).Kind() {
211+
case reflect.Slice:
212+
setSlice(args[n], v)
213+
default:
214+
reflect.ValueOf(args[n]).Elem().Set(v)
215+
}
216+
return nil
217+
})
163218
return c
164219
}
165220

@@ -296,43 +351,21 @@ func (c *Call) dropPrereqs() (preReqs []*Call) {
296351
return
297352
}
298353

299-
func (c *Call) call(args []interface{}) (rets []interface{}, action func()) {
300-
c.numCalls++
301-
302-
// Actions
303-
if c.doFunc.IsValid() {
304-
doArgs := make([]reflect.Value, len(args))
305-
ft := c.doFunc.Type()
306-
for i := 0; i < len(args); i++ {
307-
if args[i] != nil {
308-
doArgs[i] = reflect.ValueOf(args[i])
309-
} else {
310-
// Use the zero value for the arg.
311-
doArgs[i] = reflect.Zero(ft.In(i))
312-
}
313-
}
314-
action = func() { c.doFunc.Call(doArgs) }
315-
}
316-
for n, v := range c.setArgs {
317-
switch reflect.TypeOf(args[n]).Kind() {
318-
case reflect.Slice:
319-
setSlice(args[n], v)
320-
default:
321-
reflect.ValueOf(args[n]).Elem().Set(v)
322-
}
323-
}
324-
325-
rets = c.rets
326-
if rets == nil {
354+
func (c *Call) defaultActions() []func([]interface{}) []interface{} {
355+
return []func([]interface{}) []interface{}{func([]interface{}) []interface{} {
327356
// Synthesize the zero value for each of the return args' types.
328357
mt := c.methodType
329-
rets = make([]interface{}, mt.NumOut())
358+
rets := make([]interface{}, mt.NumOut())
330359
for i := 0; i < mt.NumOut(); i++ {
331360
rets[i] = reflect.Zero(mt.Out(i)).Interface()
332361
}
333-
}
362+
return rets
363+
}}
364+
}
334365

335-
return
366+
func (c *Call) call(args []interface{}) []func([]interface{}) []interface{} {
367+
c.numCalls++
368+
return c.actions
336369
}
337370

338371
// InOrder declares that the given calls should occur in order.
@@ -348,3 +381,7 @@ func setSlice(arg interface{}, v reflect.Value) {
348381
va.Index(i).Set(v.Index(i))
349382
}
350383
}
384+
385+
func (c *Call) addAction(action func([]interface{}) []interface{}) {
386+
c.actions = append(c.actions, action)
387+
}

gomock/call_test.go

Lines changed: 0 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
package gomock
22

33
import (
4-
"reflect"
54
"testing"
65
)
76

@@ -50,34 +49,3 @@ func TestCall_After(t *testing.T) {
5049
}
5150
})
5251
}
53-
54-
func TestCall_SetArg(t *testing.T) {
55-
t.Run("SetArgSlice", func(t *testing.T) {
56-
c := &Call{
57-
methodType: reflect.TypeOf(func([]byte) {}),
58-
t: &mockTestReporter{},
59-
}
60-
c.SetArg(0, []byte{1, 2, 3})
61-
62-
in := []byte{4, 5, 6}
63-
c.call([]interface{}{in})
64-
65-
if in[0] != 1 || in[1] != 2 || in[2] != 3 {
66-
t.Error("Expected SetArg() to modify input slice argument")
67-
}
68-
})
69-
70-
t.Run("SetArgPointer", func(t *testing.T) {
71-
c := &Call{
72-
methodType: reflect.TypeOf(func(*int) {}),
73-
t: &mockTestReporter{},
74-
}
75-
c.SetArg(0, 42)
76-
77-
in := 43
78-
c.call([]interface{}{&in})
79-
if in != 42 {
80-
t.Error("Expected SetArg() to modify value pointed to by argument")
81-
}
82-
})
83-
}

gomock/controller.go

Lines changed: 29 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -116,8 +116,7 @@ func (ctrl *Controller) RecordCall(receiver interface{}, method string, args ...
116116
}
117117
}
118118
ctrl.t.Fatalf("gomock: failed finding method %s on %T", method, receiver)
119-
// In case t.Fatalf does not panic.
120-
panic(fmt.Sprintf("gomock: failed finding method %s on %T", method, receiver))
119+
panic("unreachable")
121120
}
122121

123122
func (ctrl *Controller) RecordCallWithMethodType(receiver interface{}, method string, methodType reflect.Type, args ...interface{}) *Call {
@@ -140,6 +139,7 @@ func (ctrl *Controller) RecordCallWithMethodType(receiver interface{}, method st
140139

141140
origin := callerInfo(2)
142141
call := &Call{t: ctrl.t, receiver: receiver, method: method, methodType: methodType, args: margs, origin: origin, minCalls: 1, maxCalls: 1}
142+
call.actions = call.defaultActions()
143143

144144
ctrl.expectedCalls.Add(call)
145145
return call
@@ -150,36 +150,37 @@ func (ctrl *Controller) Call(receiver interface{}, method string, args ...interf
150150
h.Helper()
151151
}
152152

153-
ctrl.mu.Lock()
154-
defer ctrl.mu.Unlock()
153+
// Nest this code so we can use defer to make sure the lock is released.
154+
actions := func() []func([]interface{}) []interface{} {
155+
ctrl.mu.Lock()
156+
defer ctrl.mu.Unlock()
155157

156-
expected, err := ctrl.expectedCalls.FindMatch(receiver, method, args)
157-
if err != nil {
158-
origin := callerInfo(2)
159-
ctrl.t.Fatalf("Unexpected call to %T.%v(%v) at %s because: %s", receiver, method, args, origin, err)
160-
}
158+
expected, err := ctrl.expectedCalls.FindMatch(receiver, method, args)
159+
if err != nil {
160+
origin := callerInfo(2)
161+
ctrl.t.Fatalf("Unexpected call to %T.%v(%v) at %s because: %s", receiver, method, args, origin, err)
162+
}
161163

162-
// Two things happen here:
163-
// * the matching call no longer needs to check prerequite calls,
164-
// * and the prerequite calls are no longer expected, so remove them.
165-
preReqCalls := expected.dropPrereqs()
166-
for _, preReqCall := range preReqCalls {
167-
ctrl.expectedCalls.Remove(preReqCall)
168-
}
164+
// Two things happen here:
165+
// * the matching call no longer needs to check prerequite calls,
166+
// * and the prerequite calls are no longer expected, so remove them.
167+
preReqCalls := expected.dropPrereqs()
168+
for _, preReqCall := range preReqCalls {
169+
ctrl.expectedCalls.Remove(preReqCall)
170+
}
169171

170-
rets, action := expected.call(args)
171-
if expected.exhausted() {
172-
ctrl.expectedCalls.Remove(expected)
173-
}
172+
actions := expected.call(args)
173+
if expected.exhausted() {
174+
ctrl.expectedCalls.Remove(expected)
175+
}
176+
return actions
177+
}()
174178

175-
// Don't hold the lock while doing the call's action (if any)
176-
// so that actions may execute concurrently.
177-
// We use the deferred Unlock to capture any panics that happen above;
178-
// here we add a deferred Lock to balance it.
179-
ctrl.mu.Unlock()
180-
defer ctrl.mu.Lock()
181-
if action != nil {
182-
action()
179+
var rets []interface{}
180+
for _, action := range actions {
181+
if r := action(args); r != nil {
182+
rets = r
183+
}
183184
}
184185

185186
return rets

0 commit comments

Comments
 (0)