Skip to content

Commit cf80ab2

Browse files
authored
Add support to mock object methods (#297)
* Update Dependencies * Add support to mock object methods
1 parent a12b704 commit cf80ab2

File tree

20 files changed

+138
-63
lines changed

20 files changed

+138
-63
lines changed

.scalafmt.conf

+1-1
Original file line numberDiff line numberDiff line change
@@ -7,5 +7,5 @@ spaces.inImportCurlyBraces = true
77
indentOperator = spray
88
unindentTopLevelOperators = true
99

10-
version=2.6.4
10+
version=2.7.2
1111

README.md

+29
Original file line numberDiff line numberDiff line change
@@ -748,6 +748,35 @@ val failingMock: Foo = mock[Foo].returnsMT[ErrorOr, MyClass](*) raises Error("er
748748
val failingMock: Foo = mock[Foo].returnsMT[ErrorOr, MyClass](*) returns Left(Error("error"))
749749
```
750750
751+
## Mocking Scala `object`
752+
753+
Since version 1.16.0 it is possible to mock `object` methods, given that such definitions are global, the way to mock them is sligtly different in order to ensure we restore the real implementation of the object after we are done
754+
Example:
755+
756+
```scala
757+
object FooObject {
758+
def simpleMethod: String = "not mocked!"
759+
}
760+
761+
"mock" should {
762+
"stub an object method" in {
763+
FooObject.simpleMethod shouldBe "not mocked!"
764+
765+
withObjectMocked[FooObject.type] {
766+
FooObject.simpleMethod returns "mocked!"
767+
//or
768+
when(FooObject.simpleMethod) thenReturn "mocked!"
769+
770+
FooObject.simpleMethod shouldBe "mocked!"
771+
}
772+
773+
FooObject.simpleMethod shouldBe "not mocked!"
774+
}
775+
}
776+
```
777+
778+
As you can see, the effect of the mocking is only available inside the code block passed to `withObjectMocked`, when such block ends the behavior of the object is restored to its original implementation
779+
751780
752781
## Notes
753782

common/src/main/scala/org/mockito/MockitoAPI.scala

+13-2
Original file line numberDiff line numberDiff line change
@@ -68,8 +68,8 @@ private[mockito] trait DoSomething {
6868
* match argument types (`Type`)}}}
6969
*/
7070
def doReturn[T: ValueClassExtractor](toBeReturned: T, toBeReturnedNext: T*): Stubber =
71-
toBeReturnedNext.foldLeft(Mockito.doAnswer(ScalaReturns(toBeReturned))) {
72-
case (s, v) => s.doAnswer(ScalaReturns(v))
71+
toBeReturnedNext.foldLeft(Mockito.doAnswer(ScalaReturns(toBeReturned))) { case (s, v) =>
72+
s.doAnswer(ScalaReturns(v))
7373
}
7474

7575
/**
@@ -616,6 +616,17 @@ private[mockito] trait MockitoEnhancer extends MockCreator {
616616
* they are created as final classes by the compiler
617617
*/
618618
def spyLambda[T <: AnyRef: ClassTag](realObj: T): T = Mockito.mock(clazz, AdditionalAnswers.delegatesTo(realObj))
619+
620+
/**
621+
* Mocks the specified object only for the context of the block
622+
*/
623+
def withObjectMocked[O <: AnyRef: ClassTag](block: => Any): Unit = {
624+
val moduleField = clazz[O].getDeclaredField("MODULE$")
625+
val realImpl = moduleField.get(null)
626+
ReflectionUtils.setFinalStatic(moduleField, mock[O])
627+
try block
628+
finally ReflectionUtils.setFinalStatic(moduleField, realImpl)
629+
}
619630
}
620631

621632
private[mockito] trait Verifications {

common/src/main/scala/org/mockito/ReflectionUtils.scala

+9-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
package org.mockito
22

3-
import java.lang.reflect.Method
3+
import java.lang.reflect.{ Field, Method, Modifier }
44

55
import org.mockito.internal.ValueClassWrapper
66
import org.mockito.invocation.InvocationOnMock
@@ -122,4 +122,12 @@ object ReflectionUtils {
122122
.toOption
123123
.getOrElse(Seq.empty)
124124
}
125+
126+
def setFinalStatic(field: Field, newValue: Any) = {
127+
field.setAccessible(true)
128+
val modifiersField = classOf[Field].getDeclaredField("modifiers")
129+
modifiersField.setAccessible(true)
130+
modifiersField.setInt(field, field.getModifiers & ~Modifier.FINAL)
131+
field.set(null, newValue)
132+
}
125133
}

common/src/main/scala/org/mockito/internal/invocation/ScalaInvocation.scala

+3-3
Original file line numberDiff line numberDiff line change
@@ -53,9 +53,9 @@ class ScalaInvocation(
5353
other match {
5454
case that: ScalaInvocation =>
5555
super.equals(that) &&
56-
getMock == that.getMock &&
57-
mockitoMethod == that.mockitoMethod &&
58-
(arguments sameElements that.arguments)
56+
getMock == that.getMock &&
57+
mockitoMethod == that.mockitoMethod &&
58+
(arguments sameElements that.arguments)
5959
case _ => false
6060
}
6161
override def hashCode(): Int = {

common/src/main/scala/org/mockito/stubbing/ScalaBaseStubbing.scala

+4-4
Original file line numberDiff line numberDiff line change
@@ -12,17 +12,17 @@ abstract class ScalaBaseStubbing[T: ValueClassExtractor] {
1212
protected def delegate: OngoingStubbing[T]
1313

1414
protected def _thenReturn(value: T, values: Seq[T]): ScalaOngoingStubbing[T] =
15-
values.foldLeft(delegate.thenAnswer(ScalaReturns(value))) {
16-
case (s, v) => s.thenAnswer(ScalaReturns(v))
15+
values.foldLeft(delegate.thenAnswer(ScalaReturns(value))) { case (s, v) =>
16+
s.thenAnswer(ScalaReturns(v))
1717
}
1818

1919
private def thenThrow(t: Throwable): ScalaOngoingStubbing[T] = delegate thenAnswer new ScalaThrowsException(t)
2020

2121
protected def _thenThrow(throwables: Seq[Throwable]): ScalaOngoingStubbing[T] =
2222
if (throwables == null || throwables.isEmpty) thenThrow(null)
2323
else
24-
throwables.tail.foldLeft(thenThrow(throwables.head)) {
25-
case (os, t) => os andThenThrow t
24+
throwables.tail.foldLeft(thenThrow(throwables.head)) { case (os, t) =>
25+
os andThenThrow t
2626
}
2727

2828
protected def _thenThrow[E <: Throwable: ClassTag]: ScalaOngoingStubbing[T] = thenThrow((new ObjenesisStd).newInstance(clazz))

common/src/test/scala/org/mockito/matchers/MatcherProps.scala

+13-14
Original file line numberDiff line numberDiff line change
@@ -14,20 +14,19 @@ class MatcherProps extends Properties("matchers") {
1414
import Generators._
1515

1616
property("AllOf") = forAll(chooseNum(0, 8))(length =>
17-
forAll(listOfN(length, arbitrary[ArgumentMatcher[MiniInt]]), arbitrary[MiniInt]) {
18-
case (matchers, value) =>
19-
val allOf = AllOf(matchers: _*)
20-
val stringRep = allOf.toString
21-
22-
classify(allOf.matches(value), "matches", "doesn't match") {
23-
(allOf.matches(value) ?= matchers.forall(_.matches(value))) :| "matches all underlying" &&
24-
matchers.iff {
25-
case Nil => stringRep ?= "<any>"
26-
case matcher :: Nil => stringRep ?= matcher.toString()
27-
case _ => stringRep ?= s"allOf(${matchers.mkString(", ")})"
28-
} :| "renders to string correctly"
29-
30-
}
17+
forAll(listOfN(length, arbitrary[ArgumentMatcher[MiniInt]]), arbitrary[MiniInt]) { case (matchers, value) =>
18+
val allOf = AllOf(matchers: _*)
19+
val stringRep = allOf.toString
20+
21+
classify(allOf.matches(value), "matches", "doesn't match") {
22+
(allOf.matches(value) ?= matchers.forall(_.matches(value))) :| "matches all underlying" &&
23+
matchers.iff {
24+
case Nil => stringRep ?= "<any>"
25+
case matcher :: Nil => stringRep ?= matcher.toString()
26+
case _ => stringRep ?= s"allOf(${matchers.mkString(", ")})"
27+
} :| "renders to string correctly"
28+
29+
}
3130
}
3231
)
3332

core/src/main/scala/org/mockito/MockitoScalaSession.scala

+15-13
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,11 @@ class MockitoScalaSession(name: String, strictness: Strictness, logger: MockitoS
3636
mockitoSession.finishMocking(e)
3737
listener.reportIssues().foreach {
3838
case unStubbedCalls: UnexpectedInvocations if unStubbedCalls.nonEmpty =>
39-
throw new UnexpectedInvocationException(s"""A NullPointerException was thrown, check if maybe related to
40-
|$unStubbedCalls""".stripMargin, e)
39+
throw new UnexpectedInvocationException(
40+
s"""A NullPointerException was thrown, check if maybe related to
41+
|$unStubbedCalls""".stripMargin,
42+
e
43+
)
4144
case _ => throw e
4245
}
4346
case other =>
@@ -72,8 +75,8 @@ object MockitoScalaSession {
7275
override def toString: String =
7376
if (nonEmpty) {
7477
val locations = invocations.zipWithIndex
75-
.map {
76-
case (invocation, idx) => s"${idx + 1}. $invocation ${invocation.getLocation}"
78+
.map { case (invocation, idx) =>
79+
s"${idx + 1}. $invocation ${invocation.getLocation}"
7780
}
7881
.mkString("\n")
7982
s"""Unexpected invocations found
@@ -92,8 +95,8 @@ object MockitoScalaSession {
9295
override def toString: String =
9396
if (nonEmpty) {
9497
val locations = stubbings.zipWithIndex
95-
.map {
96-
case (stubbing, idx) => s"${idx + 1}. $stubbing ${stubbing.getLocation}"
98+
.map { case (stubbing, idx) =>
99+
s"${idx + 1}. $stubbing ${stubbing.getLocation}"
97100
}
98101
.mkString("\n")
99102
s"""Unnecessary stubbings detected.
@@ -113,8 +116,8 @@ object MockitoScalaSession {
113116
lazy val stubbings: Set[StubbedInvocationMatcher] =
114117
mockDetails
115118
.flatMap(_.getStubbings.asScala)
116-
.collect {
117-
case s: StubbedInvocationMatcher => s
119+
.collect { case s: StubbedInvocationMatcher =>
120+
s
118121
}
119122

120123
lazy val invocations: Set[Invocation] = mockDetails.flatMap(_.getInvocations.asScala)
@@ -140,11 +143,10 @@ object MockitoScalaSession {
140143
stubbings
141144
.filterNot(_.wasUsed())
142145
.flatMap(s => lenientStubbings.find(_.getMethod === s.getMethod).map(s -> _))
143-
.foreach {
144-
case (stubbing, lenient) =>
145-
stubbing.markStubUsed(new DescribedInvocation {
146-
override def getLocation: Location = lenient.getLocation
147-
})
146+
.foreach { case (stubbing, lenient) =>
147+
stubbing.markStubUsed(new DescribedInvocation {
148+
override def getLocation: Location = lenient.getLocation
149+
})
148150
}
149151
}
150152

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
distributionBase=GRADLE_USER_HOME
22
distributionPath=wrapper/dists
3-
distributionUrl=https\://services.gradle.org/distributions/gradle-6.6-bin.zip
3+
distributionUrl=https\://services.gradle.org/distributions/gradle-6.6.1-bin.zip
44
zipStoreBase=GRADLE_USER_HOME
55
zipStorePath=wrapper/dists

macro/src/main/scala/org/mockito/DoSomethingMacro.scala

+2-2
Original file line numberDiff line numberDiff line change
@@ -189,8 +189,8 @@ object DoSomethingMacro {
189189
else if (pf.isDefinedAt(invocation.children.last)) {
190190
val values = invocation.children
191191
.dropRight(1)
192-
.collect {
193-
case q"$_ val $name:$_ = $value" => name.toString -> value.asInstanceOf[c.Tree]
192+
.collect { case q"$_ val $name:$_ = $value" =>
193+
name.toString -> value.asInstanceOf[c.Tree]
194194
}
195195
.toMap
196196

macro/src/main/scala/org/mockito/VerifyMacro.scala

+2-2
Original file line numberDiff line numberDiff line change
@@ -64,8 +64,8 @@ private[mockito] trait VerificationMacroTransformer {
6464
else if (invocation.children.nonEmpty && pf.isDefinedAt(invocation.children.last)) {
6565
val values = invocation.children
6666
.dropRight(1)
67-
.collect {
68-
case q"$_ val $name:$_ = $value" => name.toString -> value.asInstanceOf[c.Tree]
67+
.collect { case q"$_ val $name:$_ = $value" =>
68+
name.toString -> value.asInstanceOf[c.Tree]
6969
}
7070
.toMap
7171

macro/src/main/scala/org/mockito/WhenMacro.scala

+2-2
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,8 @@ object WhenMacro {
2424
else if (pf.isDefinedAt(invocation.children.last)) {
2525
val values = invocation.children
2626
.dropRight(1)
27-
.collect {
28-
case q"$_ val $name:$_ = $value" => name.toString -> value.asInstanceOf[c.Tree]
27+
.collect { case q"$_ val $name:$_ = $value" =>
28+
name.toString -> value.asInstanceOf[c.Tree]
2929
}
3030
.toMap
3131

macro/src/main/scala/org/mockito/captor/Captor.scala

+2-2
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@ trait Captor[T] {
1818
def values: List[T]
1919

2020
def hasCaptured(expectations: T*)(implicit $eq: Equality[T]): Unit =
21-
expectations.zip(values).foreach {
22-
case (e, v) => if (e !== v) throw new ArgumentsAreDifferent(s"Got [$v] instead of [$e]")
21+
expectations.zip(values).foreach { case (e, v) =>
22+
if (e !== v) throw new ArgumentsAreDifferent(s"Got [$v] instead of [$e]")
2323
}
2424
}
2525

project/Dependencies.scala

+2-2
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ import sbt._
22

33
object Dependencies {
44

5-
val scalatestVersion = "3.2.1"
5+
val scalatestVersion = "3.2.2"
66

77
val commonLibraries = Seq(
88
"org.mockito" % "mockito-core" % "3.5.13",
@@ -25,5 +25,5 @@ object Dependencies {
2525
val scalaz = "org.scalaz" %% "scalaz-core" % "7.3.2" % "provided"
2626

2727
val catsLaws = "org.typelevel" %% "cats-laws" % "2.0.0"
28-
val disciplineScalatest = "org.typelevel" %% "discipline-scalatest" % "2.0.0"
28+
val disciplineScalatest = "org.typelevel" %% "discipline-scalatest" % "2.0.1"
2929
}

scalatest/src/test/scala/user/org/mockito/IdiomaticStubbingTest.scala

+16-5
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,15 @@
11
package user.org.mockito
22

3+
import java.lang.reflect.{ Field, Modifier }
34
import java.util.concurrent.atomic.AtomicInteger
45

5-
import org.mockito.ArgumentMatchersSugar
6-
import org.mockito.IdiomaticStubbing
76
import org.mockito.invocation.InvocationOnMock
7+
import org.mockito.{ clazz, ArgumentMatchersSugar, IdiomaticStubbing }
88
import org.scalatest.matchers.should.Matchers
99
import org.scalatest.wordspec.AnyWordSpec
10-
import user.org.mockito.matchers.ValueCaseClassInt
11-
import user.org.mockito.matchers.ValueCaseClassString
12-
import user.org.mockito.matchers.ValueClass
10+
import user.org.mockito.matchers.{ ValueCaseClassInt, ValueCaseClassString, ValueClass }
11+
12+
import scala.reflect.ClassTag
1313

1414
class IdiomaticStubbingTest extends AnyWordSpec with Matchers with ArgumentMatchersSugar with IdiomaticMockitoTestSetup with IdiomaticStubbing {
1515

@@ -302,5 +302,16 @@ class IdiomaticStubbingTest extends AnyWordSpec with Matchers with ArgumentMatch
302302
mocked(*) returns "123"
303303
mocked("key") shouldBe "123"
304304
}
305+
306+
"stub an object method" in {
307+
FooObject.simpleMethod shouldBe "not mocked!"
308+
309+
withObjectMocked[FooObject.type] {
310+
FooObject.simpleMethod returns "mocked!"
311+
FooObject.simpleMethod shouldBe "mocked!"
312+
}
313+
314+
FooObject.simpleMethod shouldBe "not mocked!"
315+
}
305316
}
306317
}

scalatest/src/test/scala/user/org/mockito/MockitoSugarTest.scala

+11
Original file line numberDiff line numberDiff line change
@@ -432,6 +432,17 @@ class MockitoSugarTest extends AnyWordSpec with MockitoSugar with Matchers with
432432

433433
verify(aMock).varargMethod("hola", 1, 2, 3)
434434
}
435+
436+
"stub an object method" in {
437+
FooObject.simpleMethod shouldBe "not mocked!"
438+
439+
withObjectMocked[FooObject.type] {
440+
when(FooObject.simpleMethod) thenReturn "mocked!"
441+
FooObject.simpleMethod shouldBe "mocked!"
442+
}
443+
444+
FooObject.simpleMethod shouldBe "not mocked!"
445+
}
435446
}
436447

437448
"spyLambda[T]" should {

scalatest/src/test/scala/user/org/mockito/PostfixVerificationsTest.scala

+4-4
Original file line numberDiff line numberDiff line change
@@ -683,11 +683,11 @@ class PostfixVerificationsTest extends AnyWordSpec with IdiomaticMockitoTestSetu
683683
"answersPF" in {
684684
val org = orgDouble()
685685

686-
org.doSomethingWithThisInt(*) answersPF {
687-
case i: Int => i * 10 + 2
686+
org.doSomethingWithThisInt(*) answersPF { case i: Int =>
687+
i * 10 + 2
688688
}
689-
org.doSomethingWithThisIntAndString(*, *) answersPF {
690-
case (i: Int, s: String) => (i * 10 + s.toInt).toString
689+
org.doSomethingWithThisIntAndString(*, *) answersPF { case (i: Int, s: String) =>
690+
(i * 10 + s.toInt).toString
691691
}
692692
org.doSomethingWithThisIntAndStringAndBoolean(*, *, *) answersPF {
693693
case (i: Int, s: String, true) => (i * 10 + s.toInt).toString + " verdadero"

scalatest/src/test/scala/user/org/mockito/PrefixExpectationsTest.scala

+4-4
Original file line numberDiff line numberDiff line change
@@ -707,11 +707,11 @@ class PrefixExpectationsTest extends AnyWordSpec with Matchers with ArgumentMatc
707707
"answersPF" in {
708708
val org = orgDouble()
709709

710-
org.doSomethingWithThisInt(*) answersPF {
711-
case i: Int => i * 10 + 2
710+
org.doSomethingWithThisInt(*) answersPF { case i: Int =>
711+
i * 10 + 2
712712
}
713-
org.doSomethingWithThisIntAndString(*, *) answersPF {
714-
case (i: Int, s: String) => (i * 10 + s.toInt).toString
713+
org.doSomethingWithThisIntAndString(*, *) answersPF { case (i: Int, s: String) =>
714+
(i * 10 + s.toInt).toString
715715
}
716716
org.doSomethingWithThisIntAndStringAndBoolean(*, *, *) answersPF {
717717
case (i: Int, s: String, true) => (i * 10 + s.toInt).toString + " verdadero"

scalatest/src/test/scala/user/org/mockito/TestModel.scala

+4
Original file line numberDiff line numberDiff line change
@@ -126,3 +126,7 @@ class TestController(org: Org) {
126126
def async(f: => Int): Int = f
127127
def test(id: Int) = async(org.doSomethingWithThisInt(id))
128128
}
129+
130+
object FooObject {
131+
def simpleMethod: String = "not mocked!"
132+
}

version.properties

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
#Version of the produced binaries. This file is intended to be checked-in.
22
#It will be automatically bumped by release automation.
3-
version=1.15.2
3+
version=1.16.0
44
previousVersion=1.15.1

0 commit comments

Comments
 (0)