Skip to content

Commit f6023c0

Browse files
committed
Fix handling of primitive value class arguments in eq matcher and argumentCaptor.
1 parent ae8905b commit f6023c0

File tree

7 files changed

+154
-61
lines changed

7 files changed

+154
-61
lines changed

mockito-kotlin/src/main/kotlin/org/mockito/kotlin/ArgumentCaptor.kt

Lines changed: 52 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -25,43 +25,45 @@
2525

2626
package org.mockito.kotlin
2727

28-
import org.mockito.kotlin.internal.createInstance
2928
import org.mockito.ArgumentCaptor
29+
import org.mockito.kotlin.internal.createInstance
3030
import java.lang.reflect.Array
3131
import kotlin.reflect.KClass
32+
import kotlin.reflect.KType
33+
import kotlin.reflect.typeOf
3234

3335
/**
3436
* Creates a [KArgumentCaptor] for given type.
3537
*/
36-
inline fun <reified T : Any> argumentCaptor(): KArgumentCaptor<T> {
37-
return KArgumentCaptor(T::class)
38+
inline fun <reified T : Any?> argumentCaptor(): KArgumentCaptor<T> {
39+
return KArgumentCaptor(typeOf<T>())
3840
}
3941

4042
/**
4143
* Creates 2 [KArgumentCaptor]s for given types.
4244
*/
4345
inline fun <reified A : Any, reified B : Any> argumentCaptor(
44-
a: KClass<A> = A::class,
45-
b: KClass<B> = B::class
46+
@Suppress("unused") a: KClass<A> = A::class,
47+
@Suppress("unused") b: KClass<B> = B::class
4648
): Pair<KArgumentCaptor<A>, KArgumentCaptor<B>> {
4749
return Pair(
48-
KArgumentCaptor(a),
49-
KArgumentCaptor(b)
50+
KArgumentCaptor(typeOf<A>()),
51+
KArgumentCaptor(typeOf<B>())
5052
)
5153
}
5254

5355
/**
5456
* Creates 3 [KArgumentCaptor]s for given types.
5557
*/
5658
inline fun <reified A : Any, reified B : Any, reified C : Any> argumentCaptor(
57-
a: KClass<A> = A::class,
58-
b: KClass<B> = B::class,
59-
c: KClass<C> = C::class
59+
@Suppress("unused") a: KClass<A> = A::class,
60+
@Suppress("unused") b: KClass<B> = B::class,
61+
@Suppress("unused") c: KClass<C> = C::class
6062
): Triple<KArgumentCaptor<A>, KArgumentCaptor<B>, KArgumentCaptor<C>> {
6163
return Triple(
62-
KArgumentCaptor(a),
63-
KArgumentCaptor(b),
64-
KArgumentCaptor(c)
64+
KArgumentCaptor(typeOf<A>()),
65+
KArgumentCaptor(typeOf<B>()),
66+
KArgumentCaptor(typeOf<C>())
6567
)
6668
}
6769

@@ -97,35 +99,35 @@ class ArgumentCaptorHolder5<out A, out B, out C, out D, out E>(
9799
* Creates 4 [KArgumentCaptor]s for given types.
98100
*/
99101
inline fun <reified A : Any, reified B : Any, reified C : Any, reified D : Any> argumentCaptor(
100-
a: KClass<A> = A::class,
101-
b: KClass<B> = B::class,
102-
c: KClass<C> = C::class,
103-
d: KClass<D> = D::class
102+
@Suppress("unused") a: KClass<A> = A::class,
103+
@Suppress("unused") b: KClass<B> = B::class,
104+
@Suppress("unused") c: KClass<C> = C::class,
105+
@Suppress("unused") d: KClass<D> = D::class
104106
): ArgumentCaptorHolder4<KArgumentCaptor<A>, KArgumentCaptor<B>, KArgumentCaptor<C>, KArgumentCaptor<D>> {
105107
return ArgumentCaptorHolder4(
106-
KArgumentCaptor(a),
107-
KArgumentCaptor(b),
108-
KArgumentCaptor(c),
109-
KArgumentCaptor(d)
108+
KArgumentCaptor(typeOf<A>()),
109+
KArgumentCaptor(typeOf<B>()),
110+
KArgumentCaptor(typeOf<C>()),
111+
KArgumentCaptor(typeOf<D>())
110112
)
111113
}
112114

113115
/**
114116
* Creates 4 [KArgumentCaptor]s for given types.
115117
*/
116118
inline fun <reified A : Any, reified B : Any, reified C : Any, reified D : Any, reified E : Any> argumentCaptor(
117-
a: KClass<A> = A::class,
118-
b: KClass<B> = B::class,
119-
c: KClass<C> = C::class,
120-
d: KClass<D> = D::class,
121-
e: KClass<E> = E::class
119+
@Suppress("unused") a: KClass<A> = A::class,
120+
@Suppress("unused") b: KClass<B> = B::class,
121+
@Suppress("unused") c: KClass<C> = C::class,
122+
@Suppress("unused") d: KClass<D> = D::class,
123+
@Suppress("unused") e: KClass<E> = E::class
122124
): ArgumentCaptorHolder5<KArgumentCaptor<A>, KArgumentCaptor<B>, KArgumentCaptor<C>, KArgumentCaptor<D>, KArgumentCaptor<E>> {
123125
return ArgumentCaptorHolder5(
124-
KArgumentCaptor(a),
125-
KArgumentCaptor(b),
126-
KArgumentCaptor(c),
127-
KArgumentCaptor(d),
128-
KArgumentCaptor(e)
126+
KArgumentCaptor(typeOf<A>()),
127+
KArgumentCaptor(typeOf<B>()),
128+
KArgumentCaptor(typeOf<C>()),
129+
KArgumentCaptor(typeOf<D>()),
130+
KArgumentCaptor(typeOf<E>())
129131
)
130132
}
131133

@@ -140,7 +142,7 @@ inline fun <reified T : Any> argumentCaptor(f: KArgumentCaptor<T>.() -> Unit): K
140142
* Creates a [KArgumentCaptor] for given nullable type.
141143
*/
142144
inline fun <reified T : Any> nullableArgumentCaptor(): KArgumentCaptor<T?> {
143-
return KArgumentCaptor(T::class)
145+
return KArgumentCaptor(typeOf<T>())
144146
}
145147

146148
/**
@@ -157,17 +159,17 @@ inline fun <reified T : Any> capture(captor: ArgumentCaptor<T>): T {
157159
return captor.capture() ?: createInstance()
158160
}
159161

160-
class KArgumentCaptor<out T : Any?> (
161-
private val tClass: KClass<*>
162-
) {
162+
class KArgumentCaptor<out T : Any?>(private val kType: KType) {
163+
private val clazz = kType.classifier as KClass<*>
164+
163165
private val captor: ArgumentCaptor<Any?> =
164-
if (tClass.isValue) {
166+
if (clazz.isValue && !kType.isMarkedNullable) {
165167
val boxImpl =
166-
tClass.java.declaredMethods
168+
clazz.java.declaredMethods
167169
.single { it.name == "box-impl" && it.parameterCount == 1 }
168170
boxImpl.parameters[0].type // is the boxed type of the value type
169171
} else {
170-
tClass.java
172+
clazz.java
171173
}.let {
172174
ArgumentCaptor.forClass(it)
173175
}
@@ -219,27 +221,29 @@ class KArgumentCaptor<out T : Any?> (
219221
// In Java, `captor.capture` returns null and so the method is called with `[null]`
220222
// In Kotlin, we have to create `[null]` explicitly.
221223
// This code-path is applied for non-vararg array arguments as well, but it seems to work fine.
222-
return captor.capture() as T ?: if (tClass.java.isArray) {
224+
return toKotlinType(captor.capture()) ?: if (clazz.java.isArray) {
223225
singleElementArray()
224226
} else {
225-
createInstance(tClass)
227+
createInstance(clazz)
226228
} as T
227229
}
228230

229-
private fun singleElementArray(): Any? = Array.newInstance(tClass.java.componentType, 1)
231+
private fun singleElementArray(): Any? = Array.newInstance(clazz.java.componentType, 1)
230232

231233
@Suppress("UNCHECKED_CAST")
232-
private fun toKotlinType(rawCapturedValue: Any?) : T {
233-
return if(tClass.isValue) {
234-
rawCapturedValue
235-
?.let {
234+
private fun toKotlinType(rawCapturedValue: Any?): T {
235+
if (rawCapturedValue == null) return null as T
236+
237+
if (clazz.isValue && rawCapturedValue::class != clazz) {
238+
return rawCapturedValue
239+
.let {
236240
val boxImpl =
237-
tClass.java.declaredMethods.single { it.name == "box-impl" && it.parameterCount == 1 }
241+
clazz.java.declaredMethods.single { it.name == "box-impl" && it.parameterCount == 1 }
238242
boxImpl.invoke(null, it)
239243
} as T
240-
} else {
241-
rawCapturedValue as T
242244
}
245+
246+
return rawCapturedValue as T
243247
}
244248
}
245249

mockito-kotlin/src/main/kotlin/org/mockito/kotlin/Matchers.kt

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ import org.mockito.ArgumentMatcher
2929
import org.mockito.ArgumentMatchers
3030
import org.mockito.kotlin.internal.createInstance
3131
import kotlin.reflect.KClass
32+
import kotlin.reflect.typeOf
3233

3334
/** Object argument that is equal to the given value. */
3435
inline fun <reified T : Any?> eq(value: T): T {
@@ -91,11 +92,17 @@ inline fun <reified T > anyValueClass(): T {
9192
return boxImpl.invoke(null, ArgumentMatchers.any(boxedType)) as T
9293
}
9394

94-
inline fun <reified T > eqValueClass(value: T): T {
95+
inline fun <reified T> eqValueClass(value: T): T {
9596
require(T::class.isValue) {
9697
"${T::class.qualifiedName} is not a value class."
9798
}
9899

100+
if (typeOf<T>().isMarkedNullable) {
101+
// if the value is both value class and nullable, then Kotlin passes the value class boxed
102+
// towards Mockito java code.
103+
return ArgumentMatchers.eq(value)
104+
}
105+
99106
val unboxImpl =
100107
T::class.java.declaredMethods
101108
.single { it.name == "unbox-impl" && it.parameterCount == 0 }

tests/src/test/kotlin/test/ArgumentCaptorTest.kt

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,16 @@ package test
22

33
import com.nhaarman.expect.expect
44
import com.nhaarman.expect.expectErrorWithMessage
5-
import org.junit.Ignore
65
import org.junit.Test
7-
import org.mockito.kotlin.*
6+
import org.mockito.kotlin.KArgumentCaptor
7+
import org.mockito.kotlin.any
8+
import org.mockito.kotlin.argumentCaptor
9+
import org.mockito.kotlin.doNothing
10+
import org.mockito.kotlin.mock
11+
import org.mockito.kotlin.nullableArgumentCaptor
12+
import org.mockito.kotlin.times
13+
import org.mockito.kotlin.verify
14+
import org.mockito.kotlin.whenever
815
import java.util.*
916

1017
class ArgumentCaptorTest : TestBase() {
@@ -385,7 +392,6 @@ class ArgumentCaptorTest : TestBase() {
385392
}
386393

387394
@Test
388-
@Ignore("See issue #555")
389395
fun argumentCaptor_primitive_value_class() {
390396
/* Given */
391397
val m: SynchronousFunctions = mock()
@@ -401,7 +407,6 @@ class ArgumentCaptorTest : TestBase() {
401407
}
402408

403409
@Test
404-
@Ignore("See issue #555")
405410
fun argumentCaptor_nullable_primitive_value_class() {
406411
/* Given */
407412
val m: SynchronousFunctions = mock()
@@ -411,7 +416,7 @@ class ArgumentCaptorTest : TestBase() {
411416
m.nullablePrimitiveValueClass(valueClass)
412417

413418
/* Then */
414-
val captor = argumentCaptor<PrimitiveValueClass>()
419+
val captor = argumentCaptor<PrimitiveValueClass?>()
415420
verify(m).nullablePrimitiveValueClass(captor.capture())
416421
expect(captor.firstValue).toBe(valueClass)
417422
}

tests/src/test/kotlin/test/Classes.kt

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -98,11 +98,13 @@ interface SynchronousFunctions {
9898
fun valueClass(v: ValueClass)
9999
fun nullableValueClass(v: ValueClass?)
100100
fun nestedValueClass(v: NestedValueClass)
101+
fun primitiveValueClass(v: PrimitiveValueClass)
102+
fun nullablePrimitiveValueClass(v: PrimitiveValueClass?)
101103
fun valueClassResult(): ValueClass
102104
fun nullableValueClassResult(): ValueClass?
103105
fun nestedValueClassResult(): NestedValueClass
104-
fun primitiveValueClass(v: PrimitiveValueClass)
105-
fun nullablePrimitiveValueClass(v: PrimitiveValueClass?)
106+
fun primitiveValueClassResult(): PrimitiveValueClass
107+
fun nullablePrimitiveValueClassResult(): PrimitiveValueClass?
106108
}
107109

108110
interface SuspendFunctions {
@@ -118,6 +120,8 @@ interface SuspendFunctions {
118120
suspend fun valueClassResult(): ValueClass
119121
suspend fun nullableValueClassResult(): ValueClass?
120122
suspend fun nestedValueClassResult(): NestedValueClass
123+
suspend fun primitiveValueClassResult(): PrimitiveValueClass
124+
suspend fun nullablePrimitiveValueClassResult(): PrimitiveValueClass?
121125
suspend fun builderMethod(): SuspendFunctions
122126
}
123127

tests/src/test/kotlin/test/CoroutinesOngoingStubbingTest.kt

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -337,4 +337,40 @@ class CoroutinesOngoingStubbingTest {
337337
expect(result).toBe(nestedValueClass)
338338
expect(result.value).toBe(nestedValueClass.value)
339339
}
340+
341+
@Test
342+
fun `should stub suspendable function call with primitive value class result`() = runTest {
343+
/* Given */
344+
val primitiveValueClass = PrimitiveValueClass(42)
345+
val mock = mock<SuspendFunctions> {
346+
on(mock.primitiveValueClassResult()) doSuspendableAnswer {
347+
delay(1)
348+
primitiveValueClass
349+
}
350+
}
351+
352+
/* When */
353+
val result: PrimitiveValueClass = mock.primitiveValueClassResult()
354+
355+
/* Then */
356+
expect(result).toBe(primitiveValueClass)
357+
}
358+
359+
@Test
360+
fun `should stub suspendable function call with nullable primitive value class result`() = runTest {
361+
/* Given */
362+
val primitiveValueClass = PrimitiveValueClass(42)
363+
val mock = mock<SuspendFunctions> {
364+
on (mock.nullablePrimitiveValueClassResult()) doSuspendableAnswer {
365+
delay(1)
366+
primitiveValueClass
367+
}
368+
}
369+
370+
/* When */
371+
val result: PrimitiveValueClass? = mock.nullablePrimitiveValueClassResult()
372+
373+
/* Then */
374+
expect(result).toBe(primitiveValueClass)
375+
}
340376
}

tests/src/test/kotlin/test/MatchersTest.kt

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@ package test
33
import com.nhaarman.expect.expect
44
import com.nhaarman.expect.expectErrorWithMessage
55
import kotlinx.coroutines.test.runTest
6-
import org.junit.Ignore
76
import org.junit.Test
87
import org.junit.experimental.runners.Enclosed
98
import org.junit.runner.RunWith
@@ -629,12 +628,20 @@ class MatchersTest : TestBase() {
629628
}
630629

631630
@Test
632-
@Ignore("See issue #555")
631+
fun eqPrimitiveValueClass() {
632+
val primitiveValueClass = PrimitiveValueClass(123)
633+
mock<SynchronousFunctions>().apply {
634+
primitiveValueClass(primitiveValueClass)
635+
verify(this).primitiveValueClass(eq(primitiveValueClass))
636+
}
637+
}
638+
639+
@Test
633640
fun eqNullablePrimitiveValueClass() {
634-
val valueClass = PrimitiveValueClass(123)
641+
val primitiveValueClass = PrimitiveValueClass(123) as PrimitiveValueClass?
635642
mock<SynchronousFunctions>().apply {
636-
nullablePrimitiveValueClass(valueClass)
637-
verify(this).nullablePrimitiveValueClass(eq(valueClass))
643+
nullablePrimitiveValueClass(primitiveValueClass)
644+
verify(this).nullablePrimitiveValueClass(eq(primitiveValueClass))
638645
}
639646
}
640647

tests/src/test/kotlin/test/OngoingStubbingTest.kt

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -317,6 +317,36 @@ class OngoingStubbingTest : TestBase() {
317317
expect(result.value).toBe(nestedValueClass.value)
318318
}
319319

320+
@Test
321+
fun `should stub function call with primitive value class result`() {
322+
/* Given */
323+
val primitiveValueClass = PrimitiveValueClass(42)
324+
val mock = mock<SynchronousFunctions> {
325+
on { primitiveValueClassResult() } doReturn primitiveValueClass
326+
}
327+
328+
/* When */
329+
val result: PrimitiveValueClass = mock.primitiveValueClassResult()
330+
331+
/* Then */
332+
expect(result).toBe(primitiveValueClass)
333+
}
334+
335+
@Test
336+
fun `should stub function call with nullable primitive value class result`() {
337+
/* Given */
338+
val primitiveValueClass = PrimitiveValueClass(42)
339+
val mock = mock<SynchronousFunctions> {
340+
on { nullablePrimitiveValueClassResult() } doReturn primitiveValueClass
341+
}
342+
343+
/* When */
344+
val result: PrimitiveValueClass? = mock.nullablePrimitiveValueClassResult()
345+
346+
/* Then */
347+
expect(result).toBe(primitiveValueClass)
348+
}
349+
320350
@Test
321351
fun `should stub consecutive function calls with value class results`() {
322352
/* Given */

0 commit comments

Comments
 (0)