mirror of
https://github.com/corda/corda.git
synced 2025-06-18 15:18:16 +00:00
CORDA-3188: Ignore synthetic and static fields when searching for state pointers (#5439)
This commit is contained in:
@ -11,8 +11,10 @@ import java.util.*
|
|||||||
* TODO: Doesn't handle calculated properties. Add support for this.
|
* TODO: Doesn't handle calculated properties. Add support for this.
|
||||||
*/
|
*/
|
||||||
class StatePointerSearch(val state: ContractState) {
|
class StatePointerSearch(val state: ContractState) {
|
||||||
// Classes in these packages should not be part of a search.
|
private companion object {
|
||||||
private val blackListedPackages = setOf("java.", "javax.", "org.bouncycastle.", "net.i2p.crypto.")
|
// Classes in these packages should not be part of a search.
|
||||||
|
private val blackListedPackages = setOf("java.", "javax.", "org.bouncycastle.", "net.i2p.crypto.")
|
||||||
|
}
|
||||||
|
|
||||||
// Type required for traversal.
|
// Type required for traversal.
|
||||||
private data class FieldWithObject(val obj: Any, val field: Field)
|
private data class FieldWithObject(val obj: Any, val field: Field)
|
||||||
@ -21,16 +23,17 @@ class StatePointerSearch(val state: ContractState) {
|
|||||||
private val statePointers = mutableSetOf<StatePointer<*>>()
|
private val statePointers = mutableSetOf<StatePointer<*>>()
|
||||||
|
|
||||||
// Record seen objects to avoid getting stuck in loops.
|
// Record seen objects to avoid getting stuck in loops.
|
||||||
private val seenObjects = Collections.newSetFromMap(IdentityHashMap<Any, Boolean>()).apply { add(state) }
|
private val seenObjects = Collections.newSetFromMap(IdentityHashMap<Any, Boolean>())
|
||||||
|
|
||||||
// Queue of fields to search.
|
// Queue of fields to search.
|
||||||
private val fieldQueue = ArrayDeque<FieldWithObject>().apply { addAllFields(state) }
|
private val fieldQueue = ArrayDeque<FieldWithObject>()
|
||||||
|
|
||||||
// Helper for adding all fields to the queue.
|
// Helper for adding all fields to the queue.
|
||||||
private fun ArrayDeque<FieldWithObject>.addAllFields(obj: Any) {
|
private fun addAllFields(obj: Any) {
|
||||||
val fields = FieldUtils.getAllFieldsList(obj::class.java)
|
val fields = FieldUtils.getAllFieldsList(obj::class.java)
|
||||||
|
|
||||||
val fieldsWithObjects = fields.mapNotNull { field ->
|
fields.mapNotNullTo(fieldQueue) { field ->
|
||||||
|
if (field.isSynthetic || field.isStatic) return@mapNotNullTo null
|
||||||
// Ignore classes which have not been loaded.
|
// Ignore classes which have not been loaded.
|
||||||
// Assumption: all required state classes are already loaded.
|
// Assumption: all required state classes are already loaded.
|
||||||
val packageName = field.type.packageNameOrNull
|
val packageName = field.type.packageNameOrNull
|
||||||
@ -40,11 +43,10 @@ class StatePointerSearch(val state: ContractState) {
|
|||||||
FieldWithObject(obj, field)
|
FieldWithObject(obj, field)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
addAll(fieldsWithObjects)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private fun handleIterable(iterable: Iterable<*>) {
|
private fun handleIterable(iterable: Iterable<*>) {
|
||||||
iterable.forEach { obj -> handleObject(obj) }
|
iterable.forEach(::handleObject)
|
||||||
}
|
}
|
||||||
|
|
||||||
private fun handleMap(map: Map<*, *>) {
|
private fun handleMap(map: Map<*, *>) {
|
||||||
@ -55,8 +57,7 @@ class StatePointerSearch(val state: ContractState) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
private fun handleObject(obj: Any?) {
|
private fun handleObject(obj: Any?) {
|
||||||
if (obj == null) return
|
if (obj == null || !seenObjects.add(obj)) return
|
||||||
seenObjects.add(obj)
|
|
||||||
when (obj) {
|
when (obj) {
|
||||||
is Map<*, *> -> handleMap(obj)
|
is Map<*, *> -> handleMap(obj)
|
||||||
is StatePointer<*> -> statePointers.add(obj)
|
is StatePointer<*> -> statePointers.add(obj)
|
||||||
@ -64,22 +65,17 @@ class StatePointerSearch(val state: ContractState) {
|
|||||||
else -> {
|
else -> {
|
||||||
val packageName = obj.javaClass.packageNameOrNull ?: ""
|
val packageName = obj.javaClass.packageNameOrNull ?: ""
|
||||||
val isBlackListed = blackListedPackages.any { packageName.startsWith(it) }
|
val isBlackListed = blackListedPackages.any { packageName.startsWith(it) }
|
||||||
if (isBlackListed.not()) fieldQueue.addAllFields(obj)
|
if (!isBlackListed) addAllFields(obj)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private fun handleField(obj: Any, field: Field) {
|
|
||||||
val newObj = field.get(obj) ?: return
|
|
||||||
if (newObj in seenObjects) return
|
|
||||||
handleObject(newObj)
|
|
||||||
}
|
|
||||||
|
|
||||||
fun search(): Set<StatePointer<*>> {
|
fun search(): Set<StatePointer<*>> {
|
||||||
|
handleObject(state)
|
||||||
while (fieldQueue.isNotEmpty()) {
|
while (fieldQueue.isNotEmpty()) {
|
||||||
val (obj, field) = fieldQueue.pop()
|
val (obj, field) = fieldQueue.pop()
|
||||||
field.isAccessible = true
|
field.isAccessible = true
|
||||||
handleField(obj, field)
|
handleObject(field.get(obj))
|
||||||
}
|
}
|
||||||
return statePointers
|
return statePointers
|
||||||
}
|
}
|
||||||
|
@ -5,6 +5,7 @@ import net.corda.core.crypto.NullKeys
|
|||||||
import net.corda.core.identity.AbstractParty
|
import net.corda.core.identity.AbstractParty
|
||||||
import net.corda.core.identity.AnonymousParty
|
import net.corda.core.identity.AnonymousParty
|
||||||
import net.corda.core.utilities.OpaqueBytes
|
import net.corda.core.utilities.OpaqueBytes
|
||||||
|
import org.assertj.core.api.Assertions.assertThat
|
||||||
import org.junit.Test
|
import org.junit.Test
|
||||||
import kotlin.test.assertEquals
|
import kotlin.test.assertEquals
|
||||||
|
|
||||||
@ -32,6 +33,15 @@ class StatePointerSearchTests {
|
|||||||
override val participants: List<AbstractParty> get() = listOf()
|
override val participants: List<AbstractParty> get() = listOf()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private data class StateWithStaticField(val blah: Int) : ContractState {
|
||||||
|
companion object {
|
||||||
|
@JvmStatic
|
||||||
|
val pointer = LinearPointer(UniqueIdentifier(), LinearState::class.java)
|
||||||
|
}
|
||||||
|
|
||||||
|
override val participants: List<AbstractParty> get() = listOf()
|
||||||
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun `find pointer in state with generic type`() {
|
fun `find pointer in state with generic type`() {
|
||||||
val linearPointer = LinearPointer(UniqueIdentifier(), LinearState::class.java)
|
val linearPointer = LinearPointer(UniqueIdentifier(), LinearState::class.java)
|
||||||
@ -74,4 +84,9 @@ class StatePointerSearchTests {
|
|||||||
assertEquals(results, setOf(linearPointer))
|
assertEquals(results, setOf(linearPointer))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun `ignore static fields`() {
|
||||||
|
val results = StatePointerSearch(StateWithStaticField(1)).search()
|
||||||
|
assertThat(results).isEmpty()
|
||||||
|
}
|
||||||
}
|
}
|
Reference in New Issue
Block a user