diff --git a/mockito-kotlin/src/main/kotlin/org/mockito/kotlin/internal/CreateInstance.kt b/mockito-kotlin/src/main/kotlin/org/mockito/kotlin/internal/CreateInstance.kt index 0515e1b..b13186c 100644 --- a/mockito-kotlin/src/main/kotlin/org/mockito/kotlin/internal/CreateInstance.kt +++ b/mockito-kotlin/src/main/kotlin/org/mockito/kotlin/internal/CreateInstance.kt @@ -26,9 +26,16 @@ package org.mockito.kotlin.internal import kotlin.reflect.KClass +import kotlin.reflect.KProperty1 +import kotlin.reflect.full.primaryConstructor inline fun createInstance(): T { - return when (T::class) { + return createInstance(T::class) +} + +@Suppress("UNCHECKED_CAST") +fun createInstance(kClass: KClass): T { + return when (kClass) { Boolean::class -> false as T Byte::class -> 0.toByte() as T Char::class -> 0.toChar() as T @@ -37,16 +44,17 @@ inline fun createInstance(): T { Long::class -> 0L as T Float::class -> 0f as T Double::class -> 0.0 as T - else -> createInstance(T::class) + else -> createInstanceNonPrimitive(kClass) } } @Suppress("UNCHECKED_CAST") -fun createInstance(@Suppress("UNUSED_PARAMETER") kClass: KClass): T { - return if(kClass.isValue) { - val boxImpl = - kClass.java.declaredMethods.single { it.name == "box-impl" && it.parameterCount == 1 } - boxImpl.invoke(null, castNull()) as T +private fun createInstanceNonPrimitive(kClass: KClass): T { + return if (kClass.isValue) { + val boxImpl = + kClass.java.declaredMethods.single { it.name == "box-impl" && it.parameterCount == 1 } + val wrappedType = getValueClassWrappedType(kClass) + boxImpl.invoke(null, createInstance(wrappedType)) as T } else { castNull() } @@ -60,3 +68,11 @@ fun createInstance(@Suppress("UNUSED_PARAMETER") kClass: KClass): T */ @Suppress("UNCHECKED_CAST") private fun castNull(): T = null as T + +private fun getValueClassWrappedType(kClass: KClass<*>): KClass<*> { + require(kClass.isValue) + + val primaryConstructor = checkNotNull(kClass.primaryConstructor) + val wrappedType = primaryConstructor.parameters.single().type + return wrappedType.classifier as KClass<*> +} diff --git a/tests/src/test/kotlin/test/ArgumentCaptorTest.kt b/tests/src/test/kotlin/test/ArgumentCaptorTest.kt index 891f7ff..d467088 100644 --- a/tests/src/test/kotlin/test/ArgumentCaptorTest.kt +++ b/tests/src/test/kotlin/test/ArgumentCaptorTest.kt @@ -2,6 +2,7 @@ package test import com.nhaarman.expect.expect import com.nhaarman.expect.expectErrorWithMessage +import org.junit.Ignore import org.junit.Test import org.mockito.kotlin.* import java.util.* @@ -382,4 +383,36 @@ class ArgumentCaptorTest : TestBase() { verify(m).nullableValueClass(captor.capture()) expect(captor.firstValue).toBeNull() } + + @Test + @Ignore("See issue #555") + fun argumentCaptor_primitive_value_class() { + /* Given */ + val m: SynchronousFunctions = mock() + val valueClass = PrimitiveValueClass(123) + + /* When */ + m.primitiveValueClass(valueClass) + + /* Then */ + val captor = argumentCaptor() + verify(m).primitiveValueClass(captor.capture()) + expect(captor.firstValue).toBe(valueClass) + } + + @Test + @Ignore("See issue #555") + fun argumentCaptor_nullable_primitive_value_class() { + /* Given */ + val m: SynchronousFunctions = mock() + val valueClass = PrimitiveValueClass(123) + + /* When */ + m.nullablePrimitiveValueClass(valueClass) + + /* Then */ + val captor = argumentCaptor() + verify(m).nullablePrimitiveValueClass(captor.capture()) + expect(captor.firstValue).toBe(valueClass) + } } diff --git a/tests/src/test/kotlin/test/Classes.kt b/tests/src/test/kotlin/test/Classes.kt index ca06afd..b93e2d2 100644 --- a/tests/src/test/kotlin/test/Classes.kt +++ b/tests/src/test/kotlin/test/Classes.kt @@ -101,6 +101,8 @@ interface SynchronousFunctions { fun valueClassResult(): ValueClass fun nullableValueClassResult(): ValueClass? fun nestedValueClassResult(): NestedValueClass + fun primitiveValueClass(v: PrimitiveValueClass) + fun nullablePrimitiveValueClass(v: PrimitiveValueClass?) } interface SuspendFunctions { @@ -125,6 +127,9 @@ value class ValueClass(val content: String) @JvmInline value class NestedValueClass(val value: ValueClass) +@JvmInline +value class PrimitiveValueClass(val value: Long) + interface ExtraInterface abstract class ThrowingConstructor { diff --git a/tests/src/test/kotlin/test/MatchersTest.kt b/tests/src/test/kotlin/test/MatchersTest.kt index c913e4b..24ba28d 100644 --- a/tests/src/test/kotlin/test/MatchersTest.kt +++ b/tests/src/test/kotlin/test/MatchersTest.kt @@ -3,6 +3,7 @@ package test import com.nhaarman.expect.expect import com.nhaarman.expect.expectErrorWithMessage import kotlinx.coroutines.test.runTest +import org.junit.Ignore import org.junit.Test import org.junit.experimental.runners.Enclosed import org.junit.runner.RunWith @@ -201,6 +202,14 @@ class MatchersTest : TestBase() { } } + @Test + fun anyPrimitiveValueClass() { + mock().apply { + primitiveValueClass(PrimitiveValueClass(123)) + verify(this).primitiveValueClass(any()) + } + } + @Test fun anyNeverVerifiesForNullValue() { mock().apply { @@ -383,6 +392,22 @@ class MatchersTest : TestBase() { verify(this).nullableValueClass(anyOrNull()) } } + + @Test + fun anyOrNullNullablePrimitiveValueClass() { + mock().apply { + nullablePrimitiveValueClass(PrimitiveValueClass(123)) + verify(this).nullablePrimitiveValueClass(anyOrNull()) + } + } + + @Test + fun anyOrNullNullablePrimitiveValueClassNullValue() { + mock().apply { + nullablePrimitiveValueClass(null) + verify(this).nullablePrimitiveValueClass(anyOrNull()) + } + } } class EqMatchersTest { @@ -594,6 +619,25 @@ class MatchersTest : TestBase() { } } + @Test + fun eqNullableValueClass() { + val valueClass = ValueClass("Content") + mock().apply { + nullableValueClass(valueClass) + verify(this).nullableValueClass(eq(valueClass)) + } + } + + @Test + @Ignore("See issue #555") + fun eqNullablePrimitiveValueClass() { + val valueClass = PrimitiveValueClass(123) + mock().apply { + nullablePrimitiveValueClass(valueClass) + verify(this).nullablePrimitiveValueClass(eq(valueClass)) + } + } + @Test fun eqNestedValueClass() { val nestedValueClass = NestedValueClass(ValueClass("Content"))