mirror of
https://github.com/corda/corda.git
synced 2025-06-05 09:00:53 +00:00
CORDA-2426 Fixed bug in state pointer search. (#4561)
* Fixed bug in state pointer search and added tests. * Blacklisted problematic package. * Addressed Shams' comments. * Addressed round two of comments. * Fixed another bug whereby the DFS gets stuck in an infinite loop.
This commit is contained in:
parent
5c5407fbed
commit
084b3a1a1d
@ -1,18 +1,17 @@
|
|||||||
package net.corda.core.internal
|
package net.corda.core.internal
|
||||||
|
|
||||||
import net.corda.core.contracts.ContractState
|
import net.corda.core.contracts.ContractState
|
||||||
import net.corda.core.contracts.LinearPointer
|
|
||||||
import net.corda.core.contracts.StatePointer
|
import net.corda.core.contracts.StatePointer
|
||||||
import net.corda.core.contracts.StaticPointer
|
|
||||||
import java.lang.reflect.Field
|
import java.lang.reflect.Field
|
||||||
import java.util.*
|
import java.util.*
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Uses reflection to search for instances of [StatePointer] within a [ContractState].
|
* Uses reflection to search for instances of [StatePointer] within a [ContractState].
|
||||||
|
* TODO: Doesn't handle calculated properties. Add support for this.
|
||||||
*/
|
*/
|
||||||
class StatePointerSearch(val state: ContractState) {
|
class StatePointerSearch(val state: ContractState) {
|
||||||
// Classes in these packages should not be part of a search.
|
// Classes in these packages should not be part of a search.
|
||||||
private val blackListedPackages = setOf("java.", "javax.")
|
private val blackListedPackages = setOf("java.", "javax.", "org.bouncycastle.", "net.i2p.crypto.")
|
||||||
|
|
||||||
// Type required for traversal.
|
// Type required for traversal.
|
||||||
private data class FieldWithObject(val obj: Any, val field: Field)
|
private data class FieldWithObject(val obj: Any, val field: Field)
|
||||||
@ -21,14 +20,26 @@ class StatePointerSearch(val state: ContractState) {
|
|||||||
private val statePointers = mutableSetOf<StatePointer<*>>()
|
private val statePointers = mutableSetOf<StatePointer<*>>()
|
||||||
|
|
||||||
// Record seen objects to avoid getting stuck in loops.
|
// Record seen objects to avoid getting stuck in loops.
|
||||||
private val seenObjects = mutableSetOf<Any>().apply { add(state) }
|
private val seenObjects = Collections.newSetFromMap(IdentityHashMap<Any, Boolean>()).apply { add(state) }
|
||||||
|
|
||||||
// Queue of fields to search.
|
// Queue of fields to search.
|
||||||
private val fieldQueue = ArrayDeque<FieldWithObject>().apply { addAllFields(state) }
|
private val fieldQueue = ArrayDeque<FieldWithObject>().apply { addAllFields(state) }
|
||||||
|
|
||||||
|
// Get fields of class and all super-classes.
|
||||||
|
private fun getAllFields(clazz: Class<*>): List<Field> {
|
||||||
|
val fields = mutableListOf<Field>()
|
||||||
|
var currentClazz = clazz
|
||||||
|
while (currentClazz.superclass != null) {
|
||||||
|
fields.addAll(currentClazz.declaredFields)
|
||||||
|
currentClazz = currentClazz.superclass
|
||||||
|
}
|
||||||
|
return fields
|
||||||
|
}
|
||||||
|
|
||||||
// Helper for adding all fields to the queue.
|
// Helper for adding all fields to the queue.
|
||||||
private fun ArrayDeque<FieldWithObject>.addAllFields(obj: Any) {
|
private fun ArrayDeque<FieldWithObject>.addAllFields(obj: Any) {
|
||||||
val fields = obj::class.java.declaredFields
|
val fields = getAllFields(obj::class.java)
|
||||||
|
|
||||||
val fieldsWithObjects = fields.mapNotNull { field ->
|
val fieldsWithObjects = fields.mapNotNull { field ->
|
||||||
// Ignore classes which have not been loaded.
|
// Ignore classes which have not been loaded.
|
||||||
// Assumption: all required state classes are already loaded.
|
// Assumption: all required state classes are already loaded.
|
||||||
@ -36,39 +47,44 @@ class StatePointerSearch(val state: ContractState) {
|
|||||||
if (packageName == null) {
|
if (packageName == null) {
|
||||||
null
|
null
|
||||||
} else {
|
} else {
|
||||||
// Ignore JDK classes.
|
FieldWithObject(obj, field)
|
||||||
val isBlacklistedPackage = blackListedPackages.any { packageName.startsWith(it) }
|
|
||||||
if (isBlacklistedPackage) {
|
|
||||||
null
|
|
||||||
} else {
|
|
||||||
FieldWithObject(obj, field)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
addAll(fieldsWithObjects)
|
addAll(fieldsWithObjects)
|
||||||
}
|
}
|
||||||
|
|
||||||
private fun handleField(obj: Any, field: Field) {
|
private fun handleIterable(iterable: Iterable<*>) {
|
||||||
when {
|
iterable.forEach { obj -> handleObject(obj) }
|
||||||
// StatePointer. Handles nullable StatePointers too.
|
}
|
||||||
field.type == LinearPointer::class.java -> statePointers.add(field.get(obj) as? LinearPointer<*> ?: return)
|
|
||||||
field.type == StaticPointer::class.java -> statePointers.add(field.get(obj) as? StaticPointer<*> ?: return)
|
private fun handleMap(map: Map<*, *>) {
|
||||||
// Not StatePointer.
|
map.forEach { k, v ->
|
||||||
|
handleObject(k)
|
||||||
|
handleObject(v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private fun handleObject(obj: Any?) {
|
||||||
|
if (obj == null) return
|
||||||
|
seenObjects.add(obj)
|
||||||
|
when (obj) {
|
||||||
|
is Map<*, *> -> handleMap(obj)
|
||||||
|
is StatePointer<*> -> statePointers.add(obj)
|
||||||
|
is Iterable<*> -> handleIterable(obj)
|
||||||
else -> {
|
else -> {
|
||||||
val newObj = field.get(obj) ?: return
|
val packageName = obj.javaClass.`package`.name
|
||||||
|
val isBlackListed = blackListedPackages.any { packageName.startsWith(it) }
|
||||||
// Ignore nulls.
|
if (isBlackListed.not()) fieldQueue.addAllFields(obj)
|
||||||
if (newObj in seenObjects) {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Recurse.
|
|
||||||
fieldQueue.addAllFields(newObj)
|
|
||||||
seenObjects.add(obj)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private fun handleField(obj: Any, field: Field) {
|
||||||
|
val newObj = field.get(obj) ?: return
|
||||||
|
if (newObj in seenObjects) return
|
||||||
|
handleObject(newObj)
|
||||||
|
}
|
||||||
|
|
||||||
fun search(): Set<StatePointer<*>> {
|
fun search(): Set<StatePointer<*>> {
|
||||||
while (fieldQueue.isNotEmpty()) {
|
while (fieldQueue.isNotEmpty()) {
|
||||||
val (obj, field) = fieldQueue.pop()
|
val (obj, field) = fieldQueue.pop()
|
||||||
|
@ -0,0 +1,77 @@
|
|||||||
|
package net.corda.core.internal
|
||||||
|
|
||||||
|
import net.corda.core.contracts.*
|
||||||
|
import net.corda.core.crypto.NullKeys
|
||||||
|
import net.corda.core.identity.AbstractParty
|
||||||
|
import net.corda.core.identity.AnonymousParty
|
||||||
|
import net.corda.core.utilities.OpaqueBytes
|
||||||
|
import org.junit.Test
|
||||||
|
import kotlin.test.assertEquals
|
||||||
|
|
||||||
|
class StatePointerSearchTests {
|
||||||
|
|
||||||
|
private val partyAndRef = PartyAndReference(AnonymousParty(NullKeys.NullPublicKey), OpaqueBytes.of(0))
|
||||||
|
|
||||||
|
private data class StateWithGeneric(val amount: Amount<Issued<LinearPointer<LinearState>>>) : ContractState {
|
||||||
|
override val participants: List<AbstractParty> get() = listOf()
|
||||||
|
}
|
||||||
|
|
||||||
|
private data class StateWithList(val pointerList: List<LinearPointer<LinearState>>) : ContractState {
|
||||||
|
override val participants: List<AbstractParty> get() = listOf()
|
||||||
|
}
|
||||||
|
|
||||||
|
private data class StateWithMap(val pointerMap: Map<Any, Any>) : ContractState {
|
||||||
|
override val participants: List<AbstractParty> get() = listOf()
|
||||||
|
}
|
||||||
|
|
||||||
|
private data class StateWithSet(val pointerSet: Set<LinearPointer<LinearState>>) : ContractState {
|
||||||
|
override val participants: List<AbstractParty> get() = listOf()
|
||||||
|
}
|
||||||
|
|
||||||
|
private data class StateWithListOfList(val pointerSet: List<List<LinearPointer<LinearState>>>) : ContractState {
|
||||||
|
override val participants: List<AbstractParty> get() = listOf()
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun `find pointer in state with generic type`() {
|
||||||
|
val linearPointer = LinearPointer(UniqueIdentifier(), LinearState::class.java)
|
||||||
|
val testState = StateWithGeneric(Amount(100L, Issued(partyAndRef, linearPointer)))
|
||||||
|
val results = StatePointerSearch(testState).search()
|
||||||
|
assertEquals(results, setOf(linearPointer))
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun `find pointers which are inside a list`() {
|
||||||
|
val linearPointerOne = LinearPointer(UniqueIdentifier(), LinearState::class.java)
|
||||||
|
val linearPointerTwo = LinearPointer(UniqueIdentifier(), LinearState::class.java)
|
||||||
|
val testState = StateWithList(listOf(linearPointerOne, linearPointerTwo))
|
||||||
|
val results = StatePointerSearch(testState).search()
|
||||||
|
assertEquals(results, setOf(linearPointerOne, linearPointerTwo))
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun `find pointers which are inside a map`() {
|
||||||
|
val linearPointerOne = LinearPointer(UniqueIdentifier(), LinearState::class.java)
|
||||||
|
val linearPointerTwo = LinearPointer(UniqueIdentifier(), LinearState::class.java)
|
||||||
|
val testState = StateWithMap(mapOf(linearPointerOne to 1, 2 to linearPointerTwo))
|
||||||
|
val results = StatePointerSearch(testState).search()
|
||||||
|
assertEquals(results, setOf(linearPointerOne, linearPointerTwo))
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun `find pointers which are inside a set`() {
|
||||||
|
val linearPointer = LinearPointer(UniqueIdentifier(), LinearState::class.java)
|
||||||
|
val testState = StateWithSet(setOf(linearPointer))
|
||||||
|
val results = StatePointerSearch(testState).search()
|
||||||
|
assertEquals(results, setOf(linearPointer))
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun `find pointers which are inside nested iterables`() {
|
||||||
|
val linearPointer = LinearPointer(UniqueIdentifier(), LinearState::class.java)
|
||||||
|
val testState = StateWithListOfList(listOf(listOf(linearPointer)))
|
||||||
|
val results = StatePointerSearch(testState).search()
|
||||||
|
assertEquals(results, setOf(linearPointer))
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
@ -87,7 +87,7 @@ class ResolveStatePointersTest {
|
|||||||
@Test
|
@Test
|
||||||
fun `resolving nested pointers is possible`() {
|
fun `resolving nested pointers is possible`() {
|
||||||
// Create barOne.
|
// Create barOne.
|
||||||
createPointedToState(barOne)
|
val barOneStateAndRef = createPointedToState(barOne)
|
||||||
|
|
||||||
// Create another Bar - barTwo - which points to barOne.
|
// Create another Bar - barTwo - which points to barOne.
|
||||||
val barTwoStateAndRef = createPointedToState(barTwo)
|
val barTwoStateAndRef = createPointedToState(barTwo)
|
||||||
@ -105,6 +105,7 @@ class ResolveStatePointersTest {
|
|||||||
|
|
||||||
// Check both Bar StateRefs have been added to the transaction.
|
// Check both Bar StateRefs have been added to the transaction.
|
||||||
assertEquals(2, tx.referenceStates().size)
|
assertEquals(2, tx.referenceStates().size)
|
||||||
|
assertEquals(setOf(barOneStateAndRef.ref, barTwoStateAndRef.ref), tx.referenceStates().toSet())
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
|
Loading…
x
Reference in New Issue
Block a user