From b79faf3db62251b3b379fd1955eac2840119cefb Mon Sep 17 00:00:00 2001 From: Bruno Bonanno <241804+bbonanno@users.noreply.github.com> Date: Wed, 10 Oct 2018 19:23:53 +0100 Subject: [PATCH] Bugfix: add support for multiple param lists and/or implicits --- .../org/mockito/IdiomaticMockitoTest.scala | 17 +++--- .../scala/org/mockito/DoSomethingMacro.scala | 40 ++++++------- macro/src/main/scala/org/mockito/Utils.scala | 6 ++ .../main/scala/org/mockito/VerifyMacro.scala | 60 +++++++++---------- .../main/scala/org/mockito/WhenMacro.scala | 40 ++++++------- 5 files changed, 86 insertions(+), 77 deletions(-) diff --git a/core/src/test/scala/org/mockito/IdiomaticMockitoTest.scala b/core/src/test/scala/org/mockito/IdiomaticMockitoTest.scala index a4cad64f..cee49f75 100644 --- a/core/src/test/scala/org/mockito/IdiomaticMockitoTest.scala +++ b/core/src/test/scala/org/mockito/IdiomaticMockitoTest.scala @@ -10,6 +10,8 @@ import scala.language.postfixOps class IdiomaticMockitoTest extends WordSpec with scalatest.Matchers with IdiomaticMockito with ArgumentMatchersSugar { + class Implicit[T] + class Foo { def bar = "not mocked" def baz = "not mocked" @@ -28,7 +30,7 @@ class IdiomaticMockitoTest extends WordSpec with scalatest.Matchers with Idiomat def iBlowUp(v: Int, v2: String): String = throw new IllegalArgumentException("I was called!") - def iHaveTypeParams[A, B](a: A, b: B): String = "not mocked" + def iHaveTypeParamsAndImplicits[A, B](a: A, b: B)(implicit v3: Implicit[A]): String = "not mocked" } class Bar { @@ -54,16 +56,17 @@ class IdiomaticMockitoTest extends WordSpec with scalatest.Matchers with Idiomat aMock.bar shouldBe "mocked again!" } - "create a mock where I can mix matchers and normal parameters" in { + "create a mock where I can mix matchers, normal and implicit parameters" in { val aMock = mock[Foo] + implicit val implicitValue: Implicit[Int] = mock[Implicit[Int]] - aMock.iHaveTypeParams[Int, String](*, "test") shouldReturn "mocked!" + aMock.iHaveTypeParamsAndImplicits[Int, String](*, "test") shouldReturn "mocked!" - aMock.iHaveTypeParams(3, "test") shouldBe "mocked!" - aMock.iHaveTypeParams(5, "test") shouldBe "mocked!" - aMock.iHaveTypeParams(5, "est") shouldBe "" + aMock.iHaveTypeParamsAndImplicits(3, "test") shouldBe "mocked!" + aMock.iHaveTypeParamsAndImplicits(5, "test") shouldBe "mocked!" + aMock.iHaveTypeParamsAndImplicits(5, "est") shouldBe "" - aMock.iHaveTypeParams[Int, String](*, "test") wasCalled twice + aMock.iHaveTypeParamsAndImplicits[Int, String](*, "test") wasCalled twice } "stub a real call" in { diff --git a/macro/src/main/scala/org/mockito/DoSomethingMacro.scala b/macro/src/main/scala/org/mockito/DoSomethingMacro.scala index 56900ea1..745dceb7 100644 --- a/macro/src/main/scala/org/mockito/DoSomethingMacro.scala +++ b/macro/src/main/scala/org/mockito/DoSomethingMacro.scala @@ -13,12 +13,12 @@ object DoSomethingMacro { c.Expr[T] { c.macroApplication match { - case q"$_.DoSomethingOps[$r]($v).willBe($_.returned).by[$_]($obj.$method[..$targs](..$args))" => - if (args.exists(a => isMatcher(c)(a))) { - val newArgs: Seq[Tree] = args.map(a => transformArg(c)(a)) - q"org.mockito.InternalMockitoSugar.doReturn[$r]($v).when($obj).$method[..$targs](..$newArgs)" + case q"$_.DoSomethingOps[$r]($v).willBe($_.returned).by[$_]($obj.$method[..$targs](...$args))" => + if (args.exists(a => hasMatchers(c)(a))) { + val newArgs = args.map(a => transformArgs(c)(a)) + q"org.mockito.InternalMockitoSugar.doReturn[$r]($v).when($obj).$method[..$targs](...$newArgs)" } else - q"org.mockito.InternalMockitoSugar.doReturn[$r]($v).when($obj).$method[..$targs](..$args)" + q"org.mockito.InternalMockitoSugar.doReturn[$r]($v).when($obj).$method[..$targs](...$args)" case q"$_.DoSomethingOps[$r]($v).willBe($_.returned).by[$_]($obj.$method[..$targs])" => q"org.mockito.InternalMockitoSugar.doReturn[$r]($v).when($obj).$method[..$targs]" @@ -33,12 +33,12 @@ object DoSomethingMacro { c.Expr[T] { c.macroApplication match { - case q"$_.DoSomethingOps[$r]($v).willBe($_.answered).by[$_]($obj.$method[..$targs](..$args))" => - if (args.exists(a => isMatcher(c)(a))) { - val newArgs: Seq[Tree] = args.map(a => transformArg(c)(a)) - q"org.mockito.Mockito.doAnswer(org.mockito.DoSomethingMacro.argumentToAnswer($v)).when($obj).$method[..$targs](..$newArgs)" + case q"$_.DoSomethingOps[$r]($v).willBe($_.answered).by[$_]($obj.$method[..$targs](...$args))" => + if (args.exists(a => hasMatchers(c)(a))) { + val newArgs = args.map(a => transformArgs(c)(a)) + q"org.mockito.Mockito.doAnswer(org.mockito.DoSomethingMacro.argumentToAnswer($v)).when($obj).$method[..$targs](...$newArgs)" } else - q"org.mockito.Mockito.doAnswer(org.mockito.DoSomethingMacro.argumentToAnswer($v)).when($obj).$method[..$targs](..$args)" + q"org.mockito.Mockito.doAnswer(org.mockito.DoSomethingMacro.argumentToAnswer($v)).when($obj).$method[..$targs](...$args)" case q"$_.DoSomethingOps[$r]($v).willBe($_.answered).by[$_]($obj.$method[..$targs])" => q"org.mockito.Mockito.doAnswer(org.mockito.DoSomethingMacro.argumentToAnswer($v)).when($obj).$method[..$targs]" @@ -53,12 +53,12 @@ object DoSomethingMacro { c.Expr[T] { c.macroApplication match { - case q"$_.ThrowSomethingOps[$_]($v).willBe($_.thrown).by[$_]($obj.$method[..$targs](..$args))" => - if (args.exists(a => isMatcher(c)(a))) { - val newArgs: Seq[Tree] = args.map(a => transformArg(c)(a)) - q"org.mockito.InternalMockitoSugar.doThrow($v).when($obj).$method[..$targs](..$newArgs)" + case q"$_.ThrowSomethingOps[$_]($v).willBe($_.thrown).by[$_]($obj.$method[..$targs](...$args))" => + if (args.exists(a => hasMatchers(c)(a))) { + val newArgs = args.map(a => transformArgs(c)(a)) + q"org.mockito.InternalMockitoSugar.doThrow($v).when($obj).$method[..$targs](...$newArgs)" } else - q"org.mockito.InternalMockitoSugar.doThrow($v).when($obj).$method[..$targs](..$args)" + q"org.mockito.InternalMockitoSugar.doThrow($v).when($obj).$method[..$targs](...$args)" case q"$_.ThrowSomethingOps[$_]($v).willBe($_.thrown).by[$_]($obj.$method[..$targs])" => q"org.mockito.InternalMockitoSugar.doThrow($v).when($obj).$method[..$targs]" @@ -73,12 +73,12 @@ object DoSomethingMacro { c.Expr[T] { c.macroApplication match { - case q"$_.theRealMethod.willBe($_.called).by[$_]($obj.$method[..$targs](..$args))" => - if (args.exists(a => isMatcher(c)(a))) { - val newArgs: Seq[Tree] = args.map(a => transformArg(c)(a)) - q"org.mockito.InternalMockitoSugar.doCallRealMethod.when($obj).$method[..$targs](..$newArgs)" + case q"$_.theRealMethod.willBe($_.called).by[$_]($obj.$method[..$targs](...$args))" => + if (args.exists(a => hasMatchers(c)(a))) { + val newArgs = args.map(a => transformArgs(c)(a)) + q"org.mockito.InternalMockitoSugar.doCallRealMethod.when($obj).$method[..$targs](...$newArgs)" } else - q"org.mockito.InternalMockitoSugar.doCallRealMethod.when($obj).$method[..$targs](..$args)" + q"org.mockito.InternalMockitoSugar.doCallRealMethod.when($obj).$method[..$targs](...$args)" case q"$_.theRealMethod.willBe($_.called).by[$_]($obj.$method[..$targs])" => q"org.mockito.InternalMockitoSugar.doCallRealMethod.when($obj).$method[..$targs]" diff --git a/macro/src/main/scala/org/mockito/Utils.scala b/macro/src/main/scala/org/mockito/Utils.scala index d4035b44..a36e6ff5 100644 --- a/macro/src/main/scala/org/mockito/Utils.scala +++ b/macro/src/main/scala/org/mockito/Utils.scala @@ -2,6 +2,9 @@ package org.mockito import scala.reflect.macros.blackbox object Utils { + private[mockito] def hasMatchers(c: blackbox.Context)(args: List[c.Tree]): Boolean = + args.exists(arg => isMatcher(c)(arg)) + private[mockito] def isMatcher(c: blackbox.Context)(arg: c.Tree): Boolean = { import c.universe._ if (arg.toString().contains("org.mockito.matchers.ValueClassMatchers")) true @@ -54,6 +57,9 @@ object Utils { } } + private[mockito] def transformArgs(c: blackbox.Context)(args: List[c.Tree]): List[c.Tree] = + args.map(arg => transformArg(c)(arg)) + private[mockito] def transformArg(c: blackbox.Context)(arg: c.Tree): c.Tree = { import c.universe._ if (isMatcher(c)(arg)) arg diff --git a/macro/src/main/scala/org/mockito/VerifyMacro.scala b/macro/src/main/scala/org/mockito/VerifyMacro.scala index bba10989..985c178d 100644 --- a/macro/src/main/scala/org/mockito/VerifyMacro.scala +++ b/macro/src/main/scala/org/mockito/VerifyMacro.scala @@ -18,12 +18,12 @@ object VerifyMacro { c.Expr[Unit] { c.macroApplication match { - case q"$_.StubbingOps[$_]($obj.$method[..$targs](..$args)).was($_.called)($order)" => - if (args.exists(a => isMatcher(c)(a))) { - val newArgs: Seq[Tree] = args.map(a => transformArg(c)(a)) - q"$order.verify($obj).$method[..$targs](..$newArgs)" + case q"$_.StubbingOps[$_]($obj.$method[..$targs](...$args)).was($_.called)($order)" => + if (args.exists(a => hasMatchers(c)(a))) { + val newArgs = args.map(a => transformArgs(c)(a)) + q"$order.verify($obj).$method[..$targs](...$newArgs)" } else - q"$order.verify($obj).$method[..$targs](..$args)" + q"$order.verify($obj).$method[..$targs](...$args)" case q"$_.StubbingOps[$_]($obj.$method[..$targs]).was($_.called)($order)" => q"$order.verify($obj).$method[..$targs]" @@ -38,12 +38,12 @@ object VerifyMacro { c.Expr[Unit] { c.macroApplication match { - case q"$_.StubbingOps[$_]($obj.$method[..$targs](..$args)).was($_.never).called($order)" => - if (args.exists(a => isMatcher(c)(a))) { - val newArgs: Seq[Tree] = args.map(a => transformArg(c)(a)) - q"$order.verifyWithMode($obj, org.mockito.Mockito.never).$method[..$targs](..$newArgs)" + case q"$_.StubbingOps[$_]($obj.$method[..$targs](...$args)).was($_.never).called($order)" => + if (args.exists(a => hasMatchers(c)(a))) { + val newArgs = args.map(a => transformArgs(c)(a)) + q"$order.verifyWithMode($obj, org.mockito.Mockito.never).$method[..$targs](...$newArgs)" } else - q"$order.verifyWithMode($obj, org.mockito.Mockito.never).$method[..$targs](..$args)" + q"$order.verifyWithMode($obj, org.mockito.Mockito.never).$method[..$targs](...$args)" case q"$_.StubbingOps[$_]($obj.$method[..$targs]).was($_.never).called($order)" => q"$order.verifyWithMode($obj, org.mockito.Mockito.never).$method[..$targs]" @@ -65,12 +65,12 @@ object VerifyMacro { c.Expr[Unit] { c.macroApplication match { - case q"$_.StubbingOps[$_]($obj.$method[..$targs](..$args)).wasCalled($times)($order)" => - if (args.exists(a => isMatcher(c)(a))) { - val newArgs: Seq[Tree] = args.map(a => transformArg(c)(a)) - q"$order.verifyWithMode($obj, org.mockito.Mockito.times($times.times)).$method[..$targs](..$newArgs)" + case q"$_.StubbingOps[$_]($obj.$method[..$targs](...$args)).wasCalled($times)($order)" => + if (args.exists(a => hasMatchers(c)(a))) { + val newArgs = args.map(a => transformArgs(c)(a)) + q"$order.verifyWithMode($obj, org.mockito.Mockito.times($times.times)).$method[..$targs](...$newArgs)" } else - q"$order.verifyWithMode($obj, org.mockito.Mockito.times($times.times)).$method[..$targs](..$args)" + q"$order.verifyWithMode($obj, org.mockito.Mockito.times($times.times)).$method[..$targs](...$args)" case q"$_.StubbingOps[$_]($obj.$method[..$targs]).wasCalled($times)($order)" => q"$order.verifyWithMode($obj, org.mockito.Mockito.times($times.times)).$method[..$targs]" @@ -87,12 +87,12 @@ object VerifyMacro { c.Expr[Unit] { c.macroApplication match { - case q"$_.StubbingOps[$_]($obj.$method[..$targs](..$args)).wasCalled($times)($order)" => - if (args.exists(a => isMatcher(c)(a))) { - val newArgs: Seq[Tree] = args.map(a => transformArg(c)(a)) - q"$order.verifyWithMode($obj, org.mockito.Mockito.atLeast($times.times)).$method[..$targs](..$newArgs)" + case q"$_.StubbingOps[$_]($obj.$method[..$targs](...$args)).wasCalled($times)($order)" => + if (args.exists(a => hasMatchers(c)(a))) { + val newArgs = args.map(a => transformArgs(c)(a)) + q"$order.verifyWithMode($obj, org.mockito.Mockito.atLeast($times.times)).$method[..$targs](...$newArgs)" } else - q"$order.verifyWithMode($obj, org.mockito.Mockito.atLeast($times.times)).$method[..$targs](..$args)" + q"$order.verifyWithMode($obj, org.mockito.Mockito.atLeast($times.times)).$method[..$targs](...$args)" case q"$_.StubbingOps[$_]($obj.$method[..$targs]).wasCalled($times)($order)" => q"$order.verifyWithMode($obj, org.mockito.Mockito.atLeast($times.times)).$method[..$targs]" @@ -109,12 +109,12 @@ object VerifyMacro { c.Expr[Unit] { c.macroApplication match { - case q"$_.StubbingOps[$_]($obj.$method[..$targs](..$args)).wasCalled($times)($order)" => - if (args.exists(a => isMatcher(c)(a))) { - val newArgs: Seq[Tree] = args.map(a => transformArg(c)(a)) - q"$order.verifyWithMode($obj, org.mockito.Mockito.atMost($times.times)).$method[..$targs](..$newArgs)" + case q"$_.StubbingOps[$_]($obj.$method[..$targs](...$args)).wasCalled($times)($order)" => + if (args.exists(a => hasMatchers(c)(a))) { + val newArgs = args.map(a => transformArgs(c)(a)) + q"$order.verifyWithMode($obj, org.mockito.Mockito.atMost($times.times)).$method[..$targs](...$newArgs)" } else - q"$order.verifyWithMode($obj, org.mockito.Mockito.atMost($times.times)).$method[..$targs](..$args)" + q"$order.verifyWithMode($obj, org.mockito.Mockito.atMost($times.times)).$method[..$targs](...$args)" case q"$_.StubbingOps[$_]($obj.$method[..$targs]).wasCalled($times)($order)" => q"$order.verifyWithMode($obj, org.mockito.Mockito.atMost($times.times)).$method[..$targs]" @@ -131,12 +131,12 @@ object VerifyMacro { c.Expr[Unit] { c.macroApplication match { - case q"$_.StubbingOps[$_]($obj.$method[..$targs](..$args)).wasCalled($_)($order)" => - if (args.exists(a => isMatcher(c)(a))) { - val newArgs: Seq[Tree] = args.map(a => transformArg(c)(a)) - q"$order.verifyWithMode($obj, org.mockito.Mockito.only).$method[..$targs](..$newArgs)" + case q"$_.StubbingOps[$_]($obj.$method[..$targs](...$args)).wasCalled($_)($order)" => + if (args.exists(a => hasMatchers(c)(a))) { + val newArgs = args.map(a => transformArgs(c)(a)) + q"$order.verifyWithMode($obj, org.mockito.Mockito.only).$method[..$targs](...$newArgs)" } else - q"$order.verifyWithMode($obj, org.mockito.Mockito.only).$method[..$targs](..$args)" + q"$order.verifyWithMode($obj, org.mockito.Mockito.only).$method[..$targs](...$args)" case q"$_.StubbingOps[$_]($obj.$method[..$targs]).wasCalled($_)($order)" => q"$order.verifyWithMode($obj, org.mockito.Mockito.only).$method[..$targs]" diff --git a/macro/src/main/scala/org/mockito/WhenMacro.scala b/macro/src/main/scala/org/mockito/WhenMacro.scala index 12514133..9999eb25 100644 --- a/macro/src/main/scala/org/mockito/WhenMacro.scala +++ b/macro/src/main/scala/org/mockito/WhenMacro.scala @@ -18,12 +18,12 @@ object WhenMacro { c.Expr[ReturnActions[T]] { c.macroApplication match { - case q"$_.StubbingOps[$t]($obj.$method[..$targs](..$args)).shouldReturn" => - if (args.exists(a => isMatcher(c)(a))) { - val newArgs: Seq[Tree] = args.map(a => transformArg(c)(a)) - q"new org.mockito.WhenMacro.ReturnActions(org.mockito.Mockito.when[$t]($obj.$method[..$targs](..$newArgs)))" + case q"$_.StubbingOps[$t]($obj.$method[..$targs](...$args)).shouldReturn" => + if (args.exists(a => hasMatchers(c)(a))) { + val newArgs = args.map(a => transformArgs(c)(a)) + q"new org.mockito.WhenMacro.ReturnActions(org.mockito.Mockito.when[$t]($obj.$method[..$targs](...$newArgs)))" } else - q"new org.mockito.WhenMacro.ReturnActions(org.mockito.Mockito.when[$t]($obj.$method[..$targs](..$args)))" + q"new org.mockito.WhenMacro.ReturnActions(org.mockito.Mockito.when[$t]($obj.$method[..$targs](...$args)))" case q"$_.StubbingOps[$t]($obj.$method[..$targs]).shouldReturn" => q"new org.mockito.WhenMacro.ReturnActions(org.mockito.Mockito.when[$t]($obj.$method[..$targs]))" @@ -39,12 +39,12 @@ object WhenMacro { c.Expr[ScalaOngoingStubbing[T]] { c.macroApplication match { - case q"$_.StubbingOps[$t]($obj.$method[..$targs](..$args)).shouldCallRealMethod" => - if (args.exists(a => isMatcher(c)(a))) { - val newArgs: Seq[Tree] = args.map(a => transformArg(c)(a)) - q"new org.mockito.stubbing.ScalaOngoingStubbing(org.mockito.Mockito.when[$t]($obj.$method[..$targs](..$newArgs)).thenCallRealMethod())" + case q"$_.StubbingOps[$t]($obj.$method[..$targs](...$args)).shouldCallRealMethod" => + if (args.exists(a => hasMatchers(c)(a))) { + val newArgs = args.map(a => transformArgs(c)(a)) + q"new org.mockito.stubbing.ScalaOngoingStubbing(org.mockito.Mockito.when[$t]($obj.$method[..$targs](...$newArgs)).thenCallRealMethod())" } else - q"new org.mockito.stubbing.ScalaOngoingStubbing(org.mockito.Mockito.when[$t]($obj.$method[..$targs](..$args)).thenCallRealMethod())" + q"new org.mockito.stubbing.ScalaOngoingStubbing(org.mockito.Mockito.when[$t]($obj.$method[..$targs](...$args)).thenCallRealMethod())" case q"$_.StubbingOps[$t]($obj.$method[..$targs]).shouldCallRealMethod" => q"new org.mockito.stubbing.ScalaOngoingStubbing(org.mockito.Mockito.when[$t]($obj.$method[..$targs]).thenCallRealMethod())" @@ -64,12 +64,12 @@ object WhenMacro { c.Expr[ThrowActions[T]] { c.macroApplication match { - case q"$_.StubbingOps[$t]($obj.$method[..$targs](..$args)).shouldThrow" => - if (args.exists(a => isMatcher(c)(a))) { - val newArgs: Seq[Tree] = args.map(a => transformArg(c)(a)) - q"new org.mockito.WhenMacro.ThrowActions(org.mockito.Mockito.when[$t]($obj.$method[..$targs](..$newArgs)))" + case q"$_.StubbingOps[$t]($obj.$method[..$targs](...$args)).shouldThrow" => + if (args.exists(a => hasMatchers(c)(a))) { + val newArgs = args.map(a => transformArgs(c)(a)) + q"new org.mockito.WhenMacro.ThrowActions(org.mockito.Mockito.when[$t]($obj.$method[..$targs](...$newArgs)))" } else - q"new org.mockito.WhenMacro.ThrowActions(org.mockito.Mockito.when[$t]($obj.$method[..$targs](..$args)))" + q"new org.mockito.WhenMacro.ThrowActions(org.mockito.Mockito.when[$t]($obj.$method[..$targs](...$args)))" case q"$_.StubbingOps[$t]($obj.$method[..$targs]).shouldThrow" => q"new org.mockito.WhenMacro.ThrowActions(org.mockito.Mockito.when[$t]($obj.$method[..$targs]))" @@ -111,12 +111,12 @@ object WhenMacro { c.Expr[AnswerActions[T]] { c.macroApplication match { - case q"$_.StubbingOps[$t]($obj.$method[..$targs](..$args)).shouldAnswer" => - if (args.exists(a => isMatcher(c)(a))) { - val newArgs: Seq[Tree] = args.map(a => transformArg(c)(a)) - q"new org.mockito.WhenMacro.AnswerActions(org.mockito.Mockito.when[$t]($obj.$method[..$targs](..$newArgs)))" + case q"$_.StubbingOps[$t]($obj.$method[..$targs](...$args)).shouldAnswer" => + if (args.exists(a => hasMatchers(c)(a))) { + val newArgs = args.map(a => transformArgs(c)(a)) + q"new org.mockito.WhenMacro.AnswerActions(org.mockito.Mockito.when[$t]($obj.$method[..$targs](...$newArgs)))" } else - q"new org.mockito.WhenMacro.AnswerActions(org.mockito.Mockito.when[$t]($obj.$method[..$targs](..$args)))" + q"new org.mockito.WhenMacro.AnswerActions(org.mockito.Mockito.when[$t]($obj.$method[..$targs](...$args)))" case q"$_.StubbingOps[$t]($obj.$method[..$targs]).shouldAnswer" => q"new org.mockito.WhenMacro.AnswerActions(org.mockito.Mockito.when[$t]($obj.$method[..$targs]))"