diff --git a/core/src/main/kotlin/net/corda/core/internal/StatePointerSearch.kt b/core/src/main/kotlin/net/corda/core/internal/StatePointerSearch.kt index 60e8b93482..0ae69ff7ae 100644 --- a/core/src/main/kotlin/net/corda/core/internal/StatePointerSearch.kt +++ b/core/src/main/kotlin/net/corda/core/internal/StatePointerSearch.kt @@ -1,18 +1,17 @@ package net.corda.core.internal import net.corda.core.contracts.ContractState -import net.corda.core.contracts.LinearPointer import net.corda.core.contracts.StatePointer -import net.corda.core.contracts.StaticPointer import java.lang.reflect.Field import java.util.* /** * Uses reflection to search for instances of [StatePointer] within a [ContractState]. + * TODO: Doesn't handle calculated properties. Add support for this. */ class StatePointerSearch(val state: ContractState) { // Classes in these packages should not be part of a search. - private val blackListedPackages = setOf("java.", "javax.") + private val blackListedPackages = setOf("java.", "javax.", "org.bouncycastle.", "net.i2p.crypto.") // Type required for traversal. private data class FieldWithObject(val obj: Any, val field: Field) @@ -21,14 +20,26 @@ class StatePointerSearch(val state: ContractState) { private val statePointers = mutableSetOf>() // Record seen objects to avoid getting stuck in loops. - private val seenObjects = mutableSetOf().apply { add(state) } + private val seenObjects = Collections.newSetFromMap(IdentityHashMap()).apply { add(state) } // Queue of fields to search. private val fieldQueue = ArrayDeque().apply { addAllFields(state) } + // Get fields of class and all super-classes. + private fun getAllFields(clazz: Class<*>): List { + val fields = mutableListOf() + var currentClazz = clazz + while (currentClazz.superclass != null) { + fields.addAll(currentClazz.declaredFields) + currentClazz = currentClazz.superclass + } + return fields + } + // Helper for adding all fields to the queue. private fun ArrayDeque.addAllFields(obj: Any) { - val fields = obj::class.java.declaredFields + val fields = getAllFields(obj::class.java) + val fieldsWithObjects = fields.mapNotNull { field -> // Ignore classes which have not been loaded. // Assumption: all required state classes are already loaded. @@ -36,39 +47,44 @@ class StatePointerSearch(val state: ContractState) { if (packageName == null) { null } else { - // Ignore JDK classes. - val isBlacklistedPackage = blackListedPackages.any { packageName.startsWith(it) } - if (isBlacklistedPackage) { - null - } else { - FieldWithObject(obj, field) - } + FieldWithObject(obj, field) } } addAll(fieldsWithObjects) } - private fun handleField(obj: Any, field: Field) { - when { - // StatePointer. Handles nullable StatePointers too. - field.type == LinearPointer::class.java -> statePointers.add(field.get(obj) as? LinearPointer<*> ?: return) - field.type == StaticPointer::class.java -> statePointers.add(field.get(obj) as? StaticPointer<*> ?: return) - // Not StatePointer. + private fun handleIterable(iterable: Iterable<*>) { + iterable.forEach { obj -> handleObject(obj) } + } + + private fun handleMap(map: Map<*, *>) { + map.forEach { k, v -> + handleObject(k) + handleObject(v) + } + } + + private fun handleObject(obj: Any?) { + if (obj == null) return + seenObjects.add(obj) + when (obj) { + is Map<*, *> -> handleMap(obj) + is StatePointer<*> -> statePointers.add(obj) + is Iterable<*> -> handleIterable(obj) else -> { - val newObj = field.get(obj) ?: return - - // Ignore nulls. - if (newObj in seenObjects) { - return - } - - // Recurse. - fieldQueue.addAllFields(newObj) - seenObjects.add(obj) + val packageName = obj.javaClass.`package`.name + val isBlackListed = blackListedPackages.any { packageName.startsWith(it) } + if (isBlackListed.not()) fieldQueue.addAllFields(obj) } } } + private fun handleField(obj: Any, field: Field) { + val newObj = field.get(obj) ?: return + if (newObj in seenObjects) return + handleObject(newObj) + } + fun search(): Set> { while (fieldQueue.isNotEmpty()) { val (obj, field) = fieldQueue.pop() diff --git a/core/src/test/kotlin/net/corda/core/internal/StatePointerSearchTests.kt b/core/src/test/kotlin/net/corda/core/internal/StatePointerSearchTests.kt new file mode 100644 index 0000000000..716fa2c157 --- /dev/null +++ b/core/src/test/kotlin/net/corda/core/internal/StatePointerSearchTests.kt @@ -0,0 +1,77 @@ +package net.corda.core.internal + +import net.corda.core.contracts.* +import net.corda.core.crypto.NullKeys +import net.corda.core.identity.AbstractParty +import net.corda.core.identity.AnonymousParty +import net.corda.core.utilities.OpaqueBytes +import org.junit.Test +import kotlin.test.assertEquals + +class StatePointerSearchTests { + + private val partyAndRef = PartyAndReference(AnonymousParty(NullKeys.NullPublicKey), OpaqueBytes.of(0)) + + private data class StateWithGeneric(val amount: Amount>>) : ContractState { + override val participants: List get() = listOf() + } + + private data class StateWithList(val pointerList: List>) : ContractState { + override val participants: List get() = listOf() + } + + private data class StateWithMap(val pointerMap: Map) : ContractState { + override val participants: List get() = listOf() + } + + private data class StateWithSet(val pointerSet: Set>) : ContractState { + override val participants: List get() = listOf() + } + + private data class StateWithListOfList(val pointerSet: List>>) : ContractState { + override val participants: List get() = listOf() + } + + @Test + fun `find pointer in state with generic type`() { + val linearPointer = LinearPointer(UniqueIdentifier(), LinearState::class.java) + val testState = StateWithGeneric(Amount(100L, Issued(partyAndRef, linearPointer))) + val results = StatePointerSearch(testState).search() + assertEquals(results, setOf(linearPointer)) + } + + @Test + fun `find pointers which are inside a list`() { + val linearPointerOne = LinearPointer(UniqueIdentifier(), LinearState::class.java) + val linearPointerTwo = LinearPointer(UniqueIdentifier(), LinearState::class.java) + val testState = StateWithList(listOf(linearPointerOne, linearPointerTwo)) + val results = StatePointerSearch(testState).search() + assertEquals(results, setOf(linearPointerOne, linearPointerTwo)) + } + + @Test + fun `find pointers which are inside a map`() { + val linearPointerOne = LinearPointer(UniqueIdentifier(), LinearState::class.java) + val linearPointerTwo = LinearPointer(UniqueIdentifier(), LinearState::class.java) + val testState = StateWithMap(mapOf(linearPointerOne to 1, 2 to linearPointerTwo)) + val results = StatePointerSearch(testState).search() + assertEquals(results, setOf(linearPointerOne, linearPointerTwo)) + } + + @Test + fun `find pointers which are inside a set`() { + val linearPointer = LinearPointer(UniqueIdentifier(), LinearState::class.java) + val testState = StateWithSet(setOf(linearPointer)) + val results = StatePointerSearch(testState).search() + assertEquals(results, setOf(linearPointer)) + } + + @Test + fun `find pointers which are inside nested iterables`() { + val linearPointer = LinearPointer(UniqueIdentifier(), LinearState::class.java) + val testState = StateWithListOfList(listOf(listOf(linearPointer))) + val results = StatePointerSearch(testState).search() + assertEquals(results, setOf(linearPointer)) + } + +} \ No newline at end of file diff --git a/node/src/test/kotlin/net/corda/node/services/transactions/ResolveStatePointersTest.kt b/node/src/test/kotlin/net/corda/node/services/transactions/ResolveStatePointersTest.kt index e6c6f06aa9..0d7326b790 100644 --- a/node/src/test/kotlin/net/corda/node/services/transactions/ResolveStatePointersTest.kt +++ b/node/src/test/kotlin/net/corda/node/services/transactions/ResolveStatePointersTest.kt @@ -87,7 +87,7 @@ class ResolveStatePointersTest { @Test fun `resolving nested pointers is possible`() { // Create barOne. - createPointedToState(barOne) + val barOneStateAndRef = createPointedToState(barOne) // Create another Bar - barTwo - which points to barOne. val barTwoStateAndRef = createPointedToState(barTwo) @@ -105,6 +105,7 @@ class ResolveStatePointersTest { // Check both Bar StateRefs have been added to the transaction. assertEquals(2, tx.referenceStates().size) + assertEquals(setOf(barOneStateAndRef.ref, barTwoStateAndRef.ref), tx.referenceStates().toSet()) } @Test