Skip to content

Commit 524affb

Browse files
authored
Merge pull request #77 from mockito/issue/75
Implements ScalaNullResultGuardian
2 parents 107962b + 02c2026 commit 524affb

File tree

12 files changed

+756
-677
lines changed

12 files changed

+756
-677
lines changed

.sbtopts

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
-J-Xmx4G
2+
-J-XX:MaxMetaspaceSize=1G
3+
-J-XX:MaxPermSize=1G
4+
-J-XX:+CMSClassUnloadingEnabled

build.sbt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@ import scala.io.Source
44
import scala.language.postfixOps
55
import scala.util.Try
66

7+
ThisBuild / scalaVersion := "2.12.8"
8+
79
lazy val commonSettings =
810
Seq(
911
organization := "org.mockito",
@@ -26,6 +28,7 @@ lazy val commonSettings =
2628
"-Ypartial-unification",
2729
"-language:higherKinds",
2830
"-Xfatal-warnings",
31+
"-language:reflectiveCalls",
2932
// "-Xmacro-settings:mockito-print-when,mockito-print-do-something,mockito-print-verify,mockito-print-captor,mockito-print-matcher,mockito-print-extractor"
3033
),
3134
)

common/src/main/scala/org/mockito/MockitoAPI.scala

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ private[mockito] trait MockCreator {
3333
def mock[T <: AnyRef: ClassTag: WeakTypeTag](mockSettings: MockSettings): T
3434
def mock[T <: AnyRef: ClassTag: WeakTypeTag](name: String)(implicit defaultAnswer: DefaultAnswer): T
3535

36-
def spy[T](realObj: T): T
36+
def spy[T <: AnyRef: ClassTag: WeakTypeTag](realObj: T): T
3737
def spyLambda[T <: AnyRef: ClassTag](realObj: T): T
3838

3939
/**
@@ -222,6 +222,9 @@ private[mockito] trait MockitoEnhancer extends MockCreator {
222222
override def mock[T <: AnyRef: ClassTag: WeakTypeTag](name: String)(implicit defaultAnswer: DefaultAnswer): T =
223223
mock(withSettings.name(name))
224224

225+
def spy[T <: AnyRef: ClassTag: WeakTypeTag](realObj: T): T =
226+
mock[T](withSettings(DefaultAnswers.CallsRealMethods).spiedInstance(realObj))
227+
225228
/**
226229
* Delegates to <code>Mockito.reset(T... mocks)</code>, but restores the default stubs that
227230
* deal with default argument values
@@ -317,11 +320,6 @@ private[mockito] trait Verifications {
317320
*/
318321
private[mockito] trait Rest extends MockitoEnhancer with DoSomething with Verifications {
319322

320-
/**
321-
* Delegates to <code>Mockito.spy()</code>, it's only here to expose the full Mockito API
322-
*/
323-
def spy[T](realObj: T): T = Mockito.spy(realObj)
324-
325323
/**
326324
* Creates a "spy" in a way that supports lambdas and anonymous classes as they don't work with the standard spy as
327325
* they are created as final classes by the compiler

common/src/main/scala/org/mockito/ReflectionUtils.scala

Lines changed: 32 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,41 @@ package org.mockito
22

33
import java.util.function
44

5-
import org.mockito.internal.handler.ScalaMockHandler.{ArgumentExtractor, Extractors}
6-
7-
import scala.reflect.runtime.universe._
5+
import org.mockito.internal.handler.ScalaMockHandler.{ ArgumentExtractor, Extractors }
6+
import org.mockito.invocation.InvocationOnMock
7+
import ru.vyarus.java.generics.resolver.GenericsResolver
88

99
private[mockito] object ReflectionUtils {
1010

11+
import scala.reflect.runtime.{ universe => ru }
12+
import ru._
13+
14+
private val mirror = runtimeMirror(getClass.getClassLoader)
15+
private val customMirror = mirror.asInstanceOf[{
16+
def methodToJava(sym: scala.reflect.internal.Symbols#MethodSymbol): java.lang.reflect.Method
17+
}]
18+
19+
implicit class InvocationOnMockOps(invocation: InvocationOnMock) {
20+
def returnType: Class[_] = {
21+
val method = invocation.getMethod
22+
val clazz = method.getDeclaringClass
23+
val javaReturnType = invocation.getMethod.getReturnType
24+
25+
if (javaReturnType == classOf[Object])
26+
mirror
27+
.classSymbol(clazz)
28+
.info
29+
.decls
30+
.filter(d => d.isMethod && !d.isConstructor)
31+
.find(d => customMirror.methodToJava(d.asInstanceOf[scala.reflect.internal.Symbols#MethodSymbol]) == method)
32+
.map(_.asMethod)
33+
.filter(_.returnType.typeSymbol.isClass)
34+
.map(methodSymbol => mirror.runtimeClass(methodSymbol.returnType.typeSymbol.asClass))
35+
.getOrElse(GenericsResolver.resolve(invocation.getMock.getClass).`type`(clazz).method(method).resolveReturnClass())
36+
else javaReturnType
37+
}
38+
}
39+
1140
def interfaces[T](implicit tag: WeakTypeTag[T]): List[Class[_]] =
1241
tag.tpe match {
1342
case RefinedType(types, _) =>

common/src/main/scala/org/mockito/internal/handler/ScalaMockHandler.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ class ScalaMockHandler[T](mockSettings: MockCreationSettings[T]) extends MockHan
3232

3333
object ScalaMockHandler {
3434
def apply[T](mockSettings: MockCreationSettings[T]): MockHandler[T] =
35-
new InvocationNotifierHandler[T](new NullResultGuardian[T](new ScalaMockHandler(mockSettings)), mockSettings)
35+
new InvocationNotifierHandler[T](new ScalaNullResultGuardian[T](new ScalaMockHandler(mockSettings)), mockSettings)
3636

3737
private def readField[T](invocation: InterceptedInvocation, field: String): T = {
3838
val f = classOf[InterceptedInvocation].getDeclaredField(field)
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
package org.mockito.internal.handler
2+
import org.mockito.internal.util.Primitives.defaultValue
3+
import org.mockito.invocation.{ Invocation, InvocationContainer, MockHandler }
4+
import org.mockito.mock.MockCreationSettings
5+
import org.mockito.ReflectionUtils._
6+
7+
class ScalaNullResultGuardian[T](delegate: MockHandler[T]) extends MockHandler[T] {
8+
9+
override def handle(invocation: Invocation): AnyRef = {
10+
val result = delegate.handle(invocation)
11+
val returnType = invocation.returnType
12+
if (result == null && returnType.isPrimitive)
13+
defaultValue(returnType).asInstanceOf[AnyRef]
14+
else
15+
result
16+
}
17+
18+
override def getMockSettings: MockCreationSettings[T] = delegate.getMockSettings
19+
override def getInvocationContainer: InvocationContainer = delegate.getInvocationContainer
20+
}

common/src/main/scala/org/mockito/stubbing/ReturnsSmartNulls.scala

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,22 +2,21 @@ package org.mockito.stubbing
22
import java.lang.reflect.Modifier.isFinal
33

44
import org.mockito.Mockito.mock
5+
import org.mockito.ReflectionUtils._
56
import org.mockito.internal.debugging.LocationImpl
67
import org.mockito.internal.exceptions.Reporter.smartNullPointerException
78
import org.mockito.internal.stubbing.defaultanswers.ReturnsMoreEmptyValues
89
import org.mockito.internal.util.ObjectMethodsGuru.isToStringMethod
910
import org.mockito.invocation.{ InvocationOnMock, Location }
10-
import ru.vyarus.java.generics.resolver.GenericsResolver.resolve
1111

1212
object ReturnsSmartNulls extends DefaultAnswer {
1313

1414
val delegate = new ReturnsMoreEmptyValues
1515

1616
override def apply(invocation: InvocationOnMock): Option[Any] = Option(delegate.answer(invocation)).orElse {
17-
val method = invocation.getMethod
18-
val context = resolve(invocation.getMock.getClass).`type`(method.getDeclaringClass)
19-
val returnType = context.method(method).resolveReturnClass()
20-
if (!returnType.isPrimitive && !isFinal(returnType.getModifiers))
17+
val returnType = invocation.returnType
18+
19+
if (!returnType.isPrimitive && !isFinal(returnType.getModifiers) && classOf[Object] != returnType)
2120
Some(mock(returnType, ThrowsSmartNullPointer(invocation)))
2221
else
2322
None

core/src/main/scala/org/mockito/IdiomaticMockito.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ trait IdiomaticMockito extends MockCreator {
2222
override def mock[T <: AnyRef: ClassTag: WeakTypeTag](implicit defaultAnswer: DefaultAnswer): T =
2323
MockitoSugar.mock[T]
2424

25-
override def spy[T](realObj: T): T = MockitoSugar.spy(realObj)
25+
override def spy[T <: AnyRef: ClassTag: WeakTypeTag](realObj: T): T = MockitoSugar.spy(realObj)
2626

2727
override def spyLambda[T <: AnyRef: ClassTag](realObj: T): T = MockitoSugar.spyLambda(realObj)
2828

0 commit comments

Comments
 (0)