diff --git a/gomock/call.go b/gomock/call.go index a3fa1ae4..4a8a11f1 100644 --- a/gomock/call.go +++ b/gomock/call.go @@ -380,19 +380,21 @@ func (c *Call) matches(args []interface{}) error { } } - // Check that all prerequisite calls have been satisfied. + // Check that the call is not exhausted. + if c.exhausted() { + return fmt.Errorf("Expected call at %s has already been called the max number of times.", c.origin) + } + + return nil +} + +func (c *Call) arePreReqsSatisfied() error { for _, preReqCall := range c.preReqs { if !preReqCall.satisfied() { return fmt.Errorf("Expected call at %s doesn't have a prerequisite call satisfied:\n%v\nshould be called before:\n%v", c.origin, preReqCall, c) } } - - // Check that the call is not exhausted. - if c.exhausted() { - return fmt.Errorf("Expected call at %s has already been called the max number of times.", c.origin) - } - return nil } diff --git a/gomock/callset.go b/gomock/callset.go index c44a8a58..1556204a 100644 --- a/gomock/callset.go +++ b/gomock/callset.go @@ -73,6 +73,8 @@ func (cs callSet) FindMatch(receiver interface{}, method string, args []interfac err := call.matches(args) if err != nil { fmt.Fprintf(&callsErrors, "\n%v", err) + } else if err := call.arePreReqsSatisfied(); err != nil { + fmt.Fprintf(&callsErrors, "\n%v", err) } else { return call, nil } @@ -94,6 +96,41 @@ func (cs callSet) FindMatch(receiver interface{}, method string, args []interfac return nil, fmt.Errorf(callsErrors.String()) } +// FindMatch searches for a matching call. Returns error with explanation message if no call matched. +func (cs callSet) FindLooseMatch(receiver interface{}, method string, args []interface{}) (*Call, error) { + key := callSetKey{receiver, method} + + // Search through the expected calls. + expected := cs.expected[key] + var callsErrors bytes.Buffer + for _, call := range expected { + err := call.matches(args) + if err != nil { + continue + } else if err = call.arePreReqsSatisfied(); err != nil { + fmt.Fprintf(&callsErrors, "\n%v", err) + } else { + return call, nil + } + } + + // If we haven't found a match then search through the exhausted calls so we + // get useful error messages. + exhausted := cs.exhausted[key] + for _, call := range exhausted { + if err := call.matches(args); err != nil { + fmt.Fprintf(&callsErrors, "\n%v", err) + } + } + + errString := callsErrors.String() + if errString == "" { + return nil, nil + } else { + return nil, fmt.Errorf(callsErrors.String()) + } +} + // Failures returns the calls that are not satisfied. func (cs callSet) Failures() []*Call { failures := make([]*Call, 0, len(cs.expected)) diff --git a/gomock/controller.go b/gomock/controller.go index a7b79188..dc9d1f15 100644 --- a/gomock/controller.go +++ b/gomock/controller.go @@ -57,10 +57,11 @@ package gomock import ( "fmt" - "golang.org/x/net/context" "reflect" "runtime" "sync" + + "golang.org/x/net/context" ) // A TestReporter is something that can be used to report test failures. @@ -78,6 +79,7 @@ type Controller struct { t TestReporter expectedCalls *callSet finished bool + LooseMode bool } func NewController(t TestReporter) *Controller { @@ -144,12 +146,25 @@ func (ctrl *Controller) Call(receiver interface{}, method string, args ...interf ctrl.mu.Lock() defer ctrl.mu.Unlock() - expected, err := ctrl.expectedCalls.FindMatch(receiver, method, args) + var actions []func([]interface{}) []interface{} + var expected *Call + var err error + if !ctrl.LooseMode { + expected, err = ctrl.expectedCalls.FindMatch(receiver, method, args) + } else { + expected, err = ctrl.expectedCalls.FindLooseMatch(receiver, method, args) + } if err != nil { origin := callerInfo(2) ctrl.t.Fatalf("Unexpected call to %T.%v(%v) at %s because: %s", receiver, method, args, origin, err) } + // this is to protect against nil dereference for calls that are not + // expected + if expected == nil && ctrl.LooseMode { + return actions + } + // Two things happen here: // * the matching call no longer needs to check prerequite calls, // * and the prerequite calls are no longer expected, so remove them. @@ -158,7 +173,8 @@ func (ctrl *Controller) Call(receiver interface{}, method string, args ...interf ctrl.expectedCalls.Remove(preReqCall) } - actions := expected.call(args) + actions = expected.call(args) + if expected.exhausted() { ctrl.expectedCalls.Remove(expected) } diff --git a/gomock/controller_test.go b/gomock/controller_test.go index b3bb4de2..15f9f17f 100644 --- a/gomock/controller_test.go +++ b/gomock/controller_test.go @@ -159,6 +159,16 @@ func createFixtures(t *testing.T) (reporter *ErrorReporter, ctrl *gomock.Control // successful or failed. reporter = NewErrorReporter(t) ctrl = gomock.NewController(reporter) + ctrl.LooseMode = false + return +} + +func createLooseFixtures(t *testing.T) (reporter *ErrorReporter, ctrl *gomock.Controller) { + // Same as above only this one enables LooseMode which won't + // fail for unexpected calls + reporter = NewErrorReporter(t) + ctrl = gomock.NewController(reporter) + ctrl.LooseMode = true return } @@ -168,7 +178,7 @@ func TestNoCalls(t *testing.T) { reporter.assertPass("No calls expected or made.") } -func TestNoRecordedCallsForAReceiver(t *testing.T) { +func TestNoRecordedCallsForAReceiverStrictMode(t *testing.T) { reporter, ctrl := createFixtures(t) subject := new(Subject) @@ -178,7 +188,17 @@ func TestNoRecordedCallsForAReceiver(t *testing.T) { ctrl.Finish() } -func TestNoRecordedMatchingMethodNameForAReceiver(t *testing.T) { +func TestNoRecordedCallsForAReceiverLooseMode(t *testing.T) { + reporter, ctrl := createLooseFixtures(t) + subject := new(Subject) + + ctrl.Call(subject, "NotRecordedMethod", "argument") + ctrl.Finish() + + reporter.assertPass("No calls expected but LooseMode does not cause failures") +} + +func TestNoRecordedMatchingMethodNameForAReceiverStrictMode(t *testing.T) { reporter, ctrl := createFixtures(t) subject := new(Subject) @@ -192,6 +212,80 @@ func TestNoRecordedMatchingMethodNameForAReceiver(t *testing.T) { }) } +func TestNoRecordedMatchingMethodNameForAReceiverLooseMode(t *testing.T) { + reporter, ctrl := createLooseFixtures(t) + subject := new(Subject) + + ctrl.RecordCall(subject, "FooMethod", "argument") + ctrl.Call(subject, "NotRecordedMethod", "argument") + reporter.assertFatal(func() { + // The expected call wasn't made. + ctrl.Finish() + }) +} + +func TestMakingUnMatchingCallWhereASpecificCallAreExpectedLooseMode(t *testing.T) { + reporter, ctrl := createLooseFixtures(t) + subject := new(Subject) + + ctrl.RecordCall(subject, "FooMethod", "argument") + ctrl.Call(subject, "FooMethod", "argument", "morearg2") + ctrl.Call(subject, "FooMethod", "argument1000") + ctrl.Call(subject, "FooMethod", "argument") + reporter.assertPass("Expected method call never made") +} + +func TestMakingUnMatchingCallWhereSpecificCallsAreExpectedNTimesLooseMode(t *testing.T) { + reporter, ctrl := createLooseFixtures(t) + subject := new(Subject) + + ctrl.RecordCall(subject, "FooMethod", "argument") + ctrl.Call(subject, "FooMethod", "argument", "morearg2") + ctrl.Call(subject, "FooMethod", "argument1000") + ctrl.Call(subject, "FooMethod", "argument") + reporter.assertFatal(func() { + ctrl.Call(subject, "FooMethod", "argument") + }) +} + +func TestMakingAnUnexpectedCallWhereCallsAreExpectedLooseMode(t *testing.T) { + reporter, ctrl := createLooseFixtures(t) + subject := new(Subject) + + ctrl.RecordCall(subject, "FooMethod", "argument") + ctrl.Call(subject, "NotRecordedMethod", "argument") + ctrl.Call(subject, "FooMethod", "argument") + reporter.assertPass("Expected method call made eventually") +} + +func TestMakingExpectedCallsInOrderLooseMode(t *testing.T) { + reporter, ctrl := createLooseFixtures(t) + subject := new(Subject) + + ctrl.RecordCall(subject, "FooMethod", "argument2").After( + ctrl.RecordCall(subject, "FooMethod", "argument"), + ) + reporter.assertFatal(func() { + ctrl.Call(subject, "FooMethod", "argument2") + ctrl.Call(subject, "FooMethod", "argument") + }) + reporter.assertFail("Expected method calls in order even in LooseMode") +} + +func TestMakingAnUnexpectedCallWhereCallsAreExpectedStrictMode(t *testing.T) { + reporter, ctrl := createFixtures(t) + subject := new(Subject) + + ctrl.RecordCall(subject, "FooMethod", "argument") + reporter.assertFatal(func() { + ctrl.Call(subject, "NotRecordedMethod", "argument") + }) + ctrl.Call(subject, "FooMethod", "argument") + ctrl.Finish() + reporter.assertFail("Expected method call made eventually but only after other unexecpted " + + "calls (This is not permitted in loose mode)") +} + // This tests that a call with an arguments of some primitive type matches a recorded call. func TestExpectedMethodCall(t *testing.T) { reporter, ctrl := createFixtures(t) @@ -204,7 +298,7 @@ func TestExpectedMethodCall(t *testing.T) { reporter.assertPass("Expected method call made.") } -func TestUnexpectedMethodCall(t *testing.T) { +func TestUnexpectedMethodCallStrict(t *testing.T) { reporter, ctrl := createFixtures(t) subject := new(Subject) @@ -215,7 +309,17 @@ func TestUnexpectedMethodCall(t *testing.T) { ctrl.Finish() } -func TestRepeatedCall(t *testing.T) { +func TestUnexpectedMethodCallLoose(t *testing.T) { + reporter, ctrl := createLooseFixtures(t) + subject := new(Subject) + + ctrl.Call(subject, "FooMethod", "argument") + reporter.assertPass("No expected calls in loose mode are allowed") + + ctrl.Finish() +} + +func TestRepeatedCallStrict(t *testing.T) { reporter, ctrl := createFixtures(t) subject := new(Subject) @@ -231,6 +335,22 @@ func TestRepeatedCall(t *testing.T) { reporter.assertFail("After calling one too many times.") } +func TestRepeatedCallLoose(t *testing.T) { + reporter, ctrl := createLooseFixtures(t) + subject := new(Subject) + + ctrl.RecordCall(subject, "FooMethod", "argument").Times(3) + ctrl.Call(subject, "FooMethod", "argument") + ctrl.Call(subject, "FooMethod", "argument") + ctrl.Call(subject, "FooMethod", "argument") + reporter.assertPass("After expected repeated method calls.") + reporter.assertFatal(func() { + ctrl.Call(subject, "FooMethod", "argument") + }) + ctrl.Finish() + reporter.assertFail("After calling one too many times.") +} + func TestUnexpectedArgCount(t *testing.T) { reporter, ctrl := createFixtures(t) defer reporter.recoverUnexpectedFatal()