From cedb290bc99083330f3d6ec076f221d22ce25186 Mon Sep 17 00:00:00 2001 From: Shams Asari Date: Fri, 6 Sep 2019 10:25:00 +0100 Subject: [PATCH] CORDA-3188: Ignore synthetic and static fields when searching for state pointers (#5439) --- .../corda/core/internal/StatePointerSearch.kt | 32 ++++++++----------- .../core/internal/StatePointerSearchTests.kt | 15 +++++++++ 2 files changed, 29 insertions(+), 18 deletions(-) diff --git a/core/src/main/kotlin/net/corda/core/internal/StatePointerSearch.kt b/core/src/main/kotlin/net/corda/core/internal/StatePointerSearch.kt index 4fc21ca409..e552172844 100644 --- a/core/src/main/kotlin/net/corda/core/internal/StatePointerSearch.kt +++ b/core/src/main/kotlin/net/corda/core/internal/StatePointerSearch.kt @@ -11,8 +11,10 @@ import java.util.* * TODO: Doesn't handle calculated properties. Add support for this. */ class StatePointerSearch(val state: ContractState) { - // Classes in these packages should not be part of a search. - private val blackListedPackages = setOf("java.", "javax.", "org.bouncycastle.", "net.i2p.crypto.") + private companion object { + // Classes in these packages should not be part of a search. + private val blackListedPackages = setOf("java.", "javax.", "org.bouncycastle.", "net.i2p.crypto.") + } // Type required for traversal. private data class FieldWithObject(val obj: Any, val field: Field) @@ -21,16 +23,17 @@ class StatePointerSearch(val state: ContractState) { private val statePointers = mutableSetOf>() // Record seen objects to avoid getting stuck in loops. - private val seenObjects = Collections.newSetFromMap(IdentityHashMap()).apply { add(state) } + private val seenObjects = Collections.newSetFromMap(IdentityHashMap()) // Queue of fields to search. - private val fieldQueue = ArrayDeque().apply { addAllFields(state) } + private val fieldQueue = ArrayDeque() // Helper for adding all fields to the queue. - private fun ArrayDeque.addAllFields(obj: Any) { + private fun addAllFields(obj: Any) { val fields = FieldUtils.getAllFieldsList(obj::class.java) - val fieldsWithObjects = fields.mapNotNull { field -> + fields.mapNotNullTo(fieldQueue) { field -> + if (field.isSynthetic || field.isStatic) return@mapNotNullTo null // Ignore classes which have not been loaded. // Assumption: all required state classes are already loaded. val packageName = field.type.packageNameOrNull @@ -40,11 +43,10 @@ class StatePointerSearch(val state: ContractState) { FieldWithObject(obj, field) } } - addAll(fieldsWithObjects) } private fun handleIterable(iterable: Iterable<*>) { - iterable.forEach { obj -> handleObject(obj) } + iterable.forEach(::handleObject) } private fun handleMap(map: Map<*, *>) { @@ -55,8 +57,7 @@ class StatePointerSearch(val state: ContractState) { } private fun handleObject(obj: Any?) { - if (obj == null) return - seenObjects.add(obj) + if (obj == null || !seenObjects.add(obj)) return when (obj) { is Map<*, *> -> handleMap(obj) is StatePointer<*> -> statePointers.add(obj) @@ -64,22 +65,17 @@ class StatePointerSearch(val state: ContractState) { else -> { val packageName = obj.javaClass.packageNameOrNull ?: "" val isBlackListed = blackListedPackages.any { packageName.startsWith(it) } - if (isBlackListed.not()) fieldQueue.addAllFields(obj) + if (!isBlackListed) addAllFields(obj) } } } - private fun handleField(obj: Any, field: Field) { - val newObj = field.get(obj) ?: return - if (newObj in seenObjects) return - handleObject(newObj) - } - fun search(): Set> { + handleObject(state) while (fieldQueue.isNotEmpty()) { val (obj, field) = fieldQueue.pop() field.isAccessible = true - handleField(obj, field) + handleObject(field.get(obj)) } return statePointers } diff --git a/core/src/test/kotlin/net/corda/core/internal/StatePointerSearchTests.kt b/core/src/test/kotlin/net/corda/core/internal/StatePointerSearchTests.kt index 716fa2c157..cf1686e1f0 100644 --- a/core/src/test/kotlin/net/corda/core/internal/StatePointerSearchTests.kt +++ b/core/src/test/kotlin/net/corda/core/internal/StatePointerSearchTests.kt @@ -5,6 +5,7 @@ import net.corda.core.crypto.NullKeys import net.corda.core.identity.AbstractParty import net.corda.core.identity.AnonymousParty import net.corda.core.utilities.OpaqueBytes +import org.assertj.core.api.Assertions.assertThat import org.junit.Test import kotlin.test.assertEquals @@ -32,6 +33,15 @@ class StatePointerSearchTests { override val participants: List get() = listOf() } + private data class StateWithStaticField(val blah: Int) : ContractState { + companion object { + @JvmStatic + val pointer = LinearPointer(UniqueIdentifier(), LinearState::class.java) + } + + override val participants: List get() = listOf() + } + @Test fun `find pointer in state with generic type`() { val linearPointer = LinearPointer(UniqueIdentifier(), LinearState::class.java) @@ -74,4 +84,9 @@ class StatePointerSearchTests { assertEquals(results, setOf(linearPointer)) } + @Test + fun `ignore static fields`() { + val results = StatePointerSearch(StateWithStaticField(1)).search() + assertThat(results).isEmpty() + } } \ No newline at end of file