diff --git a/core/src/main/kotlin/net/corda/core/internal/InternalUtils.kt b/core/src/main/kotlin/net/corda/core/internal/InternalUtils.kt index 44df2fbb6e..9db156bcee 100644 --- a/core/src/main/kotlin/net/corda/core/internal/InternalUtils.kt +++ b/core/src/main/kotlin/net/corda/core/internal/InternalUtils.kt @@ -54,6 +54,8 @@ fun Path.writeLines(lines: Iterable, charset: Charset = UTF_8, var fun InputStream.copyTo(target: Path, vararg options: CopyOption): Long = Files.copy(this, target, *options) +fun Class.castIfPossible(obj: Any): T? = if (isInstance(obj)) cast(obj) else null + /** Returns a [DeclaredField] wrapper around the declared (possibly non-public) static field of the receiver [Class]. */ fun Class<*>.staticField(name: String): DeclaredField = DeclaredField(this, name, null) /** Returns a [DeclaredField] wrapper around the declared (possibly non-public) static field of the receiver [KClass]. */ diff --git a/core/src/main/kotlin/net/corda/core/serialization/SerializationToken.kt b/core/src/main/kotlin/net/corda/core/serialization/SerializationToken.kt index 86d4fdfa1b..f6670d7a20 100644 --- a/core/src/main/kotlin/net/corda/core/serialization/SerializationToken.kt +++ b/core/src/main/kotlin/net/corda/core/serialization/SerializationToken.kt @@ -5,6 +5,7 @@ import com.esotericsoftware.kryo.KryoException import com.esotericsoftware.kryo.Serializer import com.esotericsoftware.kryo.io.Input import com.esotericsoftware.kryo.io.Output +import net.corda.core.internal.castIfPossible import net.corda.core.node.ServiceHub import net.corda.core.serialization.SingletonSerializationToken.Companion.singletonSerializationToken @@ -46,11 +47,7 @@ class SerializeAsTokenSerializer : Serializer() { override fun read(kryo: Kryo, input: Input, type: Class): T { val token = (kryo.readClassAndObject(input) as? SerializationToken) ?: throw KryoException("Non-token read for tokenized type: ${type.name}") val fromToken = token.fromToken(kryo.serializationContext() ?: throw KryoException("Attempt to read a token for a ${SerializeAsToken::class.simpleName} instance of ${type.name} without initialising a context")) - if (type.isAssignableFrom(fromToken.javaClass)) { - return type.cast(fromToken) - } else { - throw KryoException("Token read ($token) did not return expected tokenized type: ${type.name}") - } + return type.castIfPossible(fromToken) ?: throw KryoException("Token read ($token) did not return expected tokenized type: ${type.name}") } } diff --git a/core/src/main/kotlin/net/corda/core/transactions/BaseTransaction.kt b/core/src/main/kotlin/net/corda/core/transactions/BaseTransaction.kt index 896fead8be..9f6bb582a5 100644 --- a/core/src/main/kotlin/net/corda/core/transactions/BaseTransaction.kt +++ b/core/src/main/kotlin/net/corda/core/transactions/BaseTransaction.kt @@ -3,6 +3,7 @@ package net.corda.core.transactions import net.corda.core.contracts.* import net.corda.core.identity.Party import net.corda.core.indexOfOrThrow +import net.corda.core.internal.castIfPossible import java.security.PublicKey import java.util.* import java.util.function.Predicate @@ -65,7 +66,6 @@ abstract class BaseTransaction( @Suppress("UNCHECKED_CAST") fun outRef(index: Int): StateAndRef = StateAndRef(outputs[index] as TransactionState, StateRef(id, index)) - /** * Returns a [StateAndRef] for the requested output state, or throws [IllegalArgumentException] if not found. */ @@ -89,34 +89,41 @@ abstract class BaseTransaction( * Clazz must be an extension of [ContractState]. * @return the possibly empty list of output states matching the clazz restriction. */ - fun outputsOfType(clazz: Class): List { - @Suppress("UNCHECKED_CAST") - return outputs.filter { clazz.isInstance(it.data) }.map { it.data as T } - } + fun outputsOfType(clazz: Class): List = outputs.mapNotNull { clazz.castIfPossible(it.data) } + + inline fun outputsOfType(): List = outputsOfType(T::class.java) /** * Helper to simplify filtering outputs according to a [Predicate]. - * @param predicate A filtering function taking a state of type T and returning true if it should be included in the list. - * The class filtering is applied before the predicate. * @param clazz The class type used for filtering via an [Class.isInstance] check. * Clazz must be an extension of [ContractState]. + * @param predicate A filtering function taking a state of type T and returning true if it should be included in the list. + * The class filtering is applied before the predicate. * @return the possibly empty list of output states matching the predicate and clazz restrictions. */ - fun filterOutputs(predicate: Predicate, clazz: Class): List { + fun filterOutputs(clazz: Class, predicate: Predicate): List { return outputsOfType(clazz).filter { predicate.test(it) } } + inline fun filterOutputs(crossinline predicate: (T) -> Boolean): List { + return filterOutputs(T::class.java, Predicate { predicate(it) }) + } + /** * Helper to simplify finding a single output matching a [Predicate]. - * @param predicate A filtering function taking a state of type T and returning true if this is the desired item. - * The class filtering is applied before the predicate. * @param clazz The class type used for filtering via an [Class.isInstance] check. * Clazz must be an extension of [ContractState]. + * @param predicate A filtering function taking a state of type T and returning true if this is the desired item. + * The class filtering is applied before the predicate. * @return the single item matching the predicate. * @throws IllegalArgumentException if no item, or multiple items are found matching the requirements. */ - fun findOutput(predicate: Predicate, clazz: Class): T { - return filterOutputs(predicate, clazz).single() + fun findOutput(clazz: Class, predicate: Predicate): T { + return outputsOfType(clazz).single { predicate.test(it) } + } + + inline fun findOutput(crossinline predicate: (T) -> Boolean): T { + return findOutput(T::class.java, Predicate { predicate(it) }) } /** @@ -126,55 +133,44 @@ abstract class BaseTransaction( * @return the possibly empty list of output [StateAndRef] states matching the clazz restriction. */ fun outRefsOfType(clazz: Class): List> { - @Suppress("UNCHECKED_CAST") - return outputs.mapIndexed { index, state -> StateAndRef(state, StateRef(id, index)) } - .filter { clazz.isInstance(it.state.data) } - .map { it as StateAndRef } + return outputs.mapIndexedNotNull { index, state -> + @Suppress("UNCHECKED_CAST") + clazz.castIfPossible(state.data)?.let { StateAndRef(state as TransactionState, StateRef(id, index)) } + } } + inline fun outRefsOfType(): List> = outRefsOfType(T::class.java) + /** * Helper to simplify filtering output [StateAndRef] items according to a [Predicate]. - * @param predicate A filtering function taking a state of type T and returning true if it should be included in the list. - * The class filtering is applied before the predicate. * @param clazz The class type used for filtering via an [Class.isInstance] check. * Clazz must be an extension of [ContractState]. + * @param predicate A filtering function taking a state of type T and returning true if it should be included in the list. + * The class filtering is applied before the predicate. * @return the possibly empty list of output [StateAndRef] states matching the predicate and clazz restrictions. */ - fun filterOutRefs(predicate: Predicate, clazz: Class): List> { + fun filterOutRefs(clazz: Class, predicate: Predicate): List> { return outRefsOfType(clazz).filter { predicate.test(it.state.data) } } + inline fun filterOutRefs(crossinline predicate: (T) -> Boolean): List> { + return filterOutRefs(T::class.java, Predicate { predicate(it) }) + } + /** * Helper to simplify finding a single output [StateAndRef] matching a [Predicate]. - * @param predicate A filtering function taking a state of type T and returning true if this is the desired item. - * The class filtering is applied before the predicate. * @param clazz The class type used for filtering via an [Class.isInstance] check. * Clazz must be an extension of [ContractState]. + * @param predicate A filtering function taking a state of type T and returning true if this is the desired item. + * The class filtering is applied before the predicate. * @return the single [StateAndRef] item matching the predicate. * @throws IllegalArgumentException if no item, or multiple items are found matching the requirements. */ - fun findOutRef(predicate: Predicate, clazz: Class): StateAndRef { - return filterOutRefs(predicate, clazz).single() - } - - //Kotlin extension methods to take advantage of Kotlin's smart type inference when querying the LedgerTransaction - inline fun outputsOfType(): List = this.outputsOfType(T::class.java) - - inline fun filterOutputs(crossinline predicate: (T) -> Boolean): List { - return filterOutputs(Predicate { predicate(it) }, T::class.java) - } - - inline fun findOutput(crossinline predicate: (T) -> Boolean): T { - return findOutput(Predicate { predicate(it) }, T::class.java) - } - - inline fun outRefsOfType(): List> = this.outRefsOfType(T::class.java) - - inline fun filterOutRefs(crossinline predicate: (T) -> Boolean): List> { - return filterOutRefs(Predicate { predicate(it) }, T::class.java) + fun findOutRef(clazz: Class, predicate: Predicate): StateAndRef { + return outRefsOfType(clazz).single { predicate.test(it.state.data) } } inline fun findOutRef(crossinline predicate: (T) -> Boolean): StateAndRef { - return findOutRef(Predicate { predicate(it) }, T::class.java) + return findOutRef(T::class.java, Predicate { predicate(it) }) } } \ No newline at end of file diff --git a/core/src/main/kotlin/net/corda/core/transactions/LedgerTransaction.kt b/core/src/main/kotlin/net/corda/core/transactions/LedgerTransaction.kt index dd370b6e6c..7745e7162c 100644 --- a/core/src/main/kotlin/net/corda/core/transactions/LedgerTransaction.kt +++ b/core/src/main/kotlin/net/corda/core/transactions/LedgerTransaction.kt @@ -3,6 +3,7 @@ package net.corda.core.transactions import net.corda.core.contracts.* import net.corda.core.crypto.SecureHash import net.corda.core.identity.Party +import net.corda.core.internal.castIfPossible import net.corda.core.serialization.CordaSerializable import java.security.PublicKey import java.util.* @@ -51,7 +52,7 @@ class LedgerTransaction( * @return The [StateAndRef] */ @Suppress("UNCHECKED_CAST") - fun inRef(index: Int) = inputs[index] as StateAndRef + fun inRef(index: Int): StateAndRef = inputs[index] as StateAndRef /** * Verifies this transaction and throws an exception if not valid, depending on the type. For general transactions: @@ -158,10 +159,9 @@ class LedgerTransaction( * [clazz] must be an extension of [ContractState]. * @return the possibly empty list of inputs matching the clazz restriction. */ - fun inputsOfType(clazz: Class): List { - @Suppress("UNCHECKED_CAST") - return inputs.map { it.state.data }.filterIsInstance(clazz) - } + fun inputsOfType(clazz: Class): List = inputs.mapNotNull { clazz.castIfPossible(it.state.data) } + + inline fun inputsOfType(): List = inputsOfType(T::class.java) /** * Helper to simplify getting all inputs states of a particular class, interface, or base class. @@ -171,18 +171,26 @@ class LedgerTransaction( */ fun inRefsOfType(clazz: Class): List> { @Suppress("UNCHECKED_CAST") - return inputs.filter { clazz.isInstance(it.state.data) }.map { it as StateAndRef } + return inputs.mapNotNull { if (clazz.isInstance(it.state.data)) it as StateAndRef else null } } + inline fun inRefsOfType(): List> = inRefsOfType(T::class.java) + /** * Helper to simplify filtering inputs according to a [Predicate]. - * @param predicate A filtering function taking a state of type T and returning true if it should be included in the list. - * The class filtering is applied before the predicate. * @param clazz The class type used for filtering via an [Class.isInstance] check. * [clazz] must be an extension of [ContractState]. + * @param predicate A filtering function taking a state of type T and returning true if it should be included in the list. + * The class filtering is applied before the predicate. * @return the possibly empty list of input states matching the predicate and clazz restrictions. */ - fun filterInputs(predicate: Predicate, clazz: Class): List = inputsOfType(clazz).filter { predicate.test(it) } + fun filterInputs(clazz: Class, predicate: Predicate): List { + return inputsOfType(clazz).filter { predicate.test(it) } + } + + inline fun filterInputs(crossinline predicate: (T) -> Boolean): List { + return filterInputs(T::class.java, Predicate { predicate(it) }) + } /** * Helper to simplify filtering inputs according to a [Predicate]. @@ -192,29 +200,47 @@ class LedgerTransaction( * [clazz] must be an extension of [ContractState]. * @return the possibly empty list of inputs [StateAndRef] matching the predicate and clazz restrictions. */ - fun filterInRefs(predicate: Predicate, clazz: Class): List> = inRefsOfType(clazz).filter { predicate.test(it.state.data) } + fun filterInRefs(clazz: Class, predicate: Predicate): List> { + return inRefsOfType(clazz).filter { predicate.test(it.state.data) } + } + + inline fun filterInRefs(crossinline predicate: (T) -> Boolean): List> { + return filterInRefs(T::class.java, Predicate { predicate(it) }) + } /** * Helper to simplify finding a single input [ContractState] matching a [Predicate]. - * @param predicate A filtering function taking a state of type T and returning true if this is the desired item. - * The class filtering is applied before the predicate. * @param clazz The class type used for filtering via an [Class.isInstance] check. * [clazz] must be an extension of ContractState. + * @param predicate A filtering function taking a state of type T and returning true if this is the desired item. + * The class filtering is applied before the predicate. * @return the single item matching the predicate. * @throws IllegalArgumentException if no item, or multiple items are found matching the requirements. */ - fun findInput(predicate: Predicate, clazz: Class): T = filterInputs(predicate, clazz).single() + fun findInput(clazz: Class, predicate: Predicate): T { + return inputsOfType(clazz).single { predicate.test(it) } + } + + inline fun findInput(crossinline predicate: (T) -> Boolean): T { + return findInput(T::class.java, Predicate { predicate(it) }) + } /** * Helper to simplify finding a single input matching a [Predicate]. - * @param predicate A filtering function taking a state of type T and returning true if this is the desired item. - * The class filtering is applied before the predicate. * @param clazz The class type used for filtering via an [Class.isInstance] check. * [clazz] must be an extension of ContractState. + * @param predicate A filtering function taking a state of type T and returning true if this is the desired item. + * The class filtering is applied before the predicate. * @return the single item matching the predicate. * @throws IllegalArgumentException if no item, or multiple items are found matching the requirements. */ - fun findInRef(predicate: Predicate, clazz: Class): StateAndRef = filterInRefs(predicate, clazz).single() + fun findInRef(clazz: Class, predicate: Predicate): StateAndRef { + return inRefsOfType(clazz).single { predicate.test(it.state.data) } + } + + inline fun findInRef(crossinline predicate: (T) -> Boolean): StateAndRef { + return findInRef(T::class.java, Predicate { predicate(it) }) + } /** * Helper to simplify getting an indexed command. @@ -230,34 +256,44 @@ class LedgerTransaction( * @return the possibly empty list of commands with [CommandData] values matching the clazz restriction. */ fun commandsOfType(clazz: Class): List { - return commands.filter { clazz.isInstance(it.value) }.map { Command(it.value, it.signers) } + return commands.mapNotNull { (signers, _, value) -> clazz.castIfPossible(value)?.let { Command(it, signers) } } } + inline fun commandsOfType(): List = commandsOfType(T::class.java) + /** * Helper to simplify filtering [Command] items according to a [Predicate]. - * @param predicate A filtering function taking a [CommandData] item of type T and returning true if it should be included in the list. - * The class filtering is applied before the predicate. * @param clazz The class type used for filtering via an [Class.isInstance] check. * [clazz] must be an extension of [CommandData]. + * @param predicate A filtering function taking a [CommandData] item of type T and returning true if it should be included in the list. + * The class filtering is applied before the predicate. * @return the possibly empty list of [Command] items with [CommandData] values matching the predicate and clazz restrictions. */ - fun filterCommands(predicate: Predicate, clazz: Class): List { + fun filterCommands(clazz: Class, predicate: Predicate): List { @Suppress("UNCHECKED_CAST") return commandsOfType(clazz).filter { predicate.test(it.value as T) } } + inline fun filterCommands(crossinline predicate: (T) -> Boolean): List { + return filterCommands(T::class.java, Predicate { predicate(it) }) + } /** * Helper to simplify finding a single [Command] items according to a [Predicate]. - * @param predicate A filtering function taking a [CommandData] item of type T and returning true if it should be included in the list. - * The class filtering is applied before the predicate. * @param clazz The class type used for filtering via an [Class.isInstance] check. * [clazz] must be an extension of [CommandData]. + * @param predicate A filtering function taking a [CommandData] item of type T and returning true if it should be included in the list. + * The class filtering is applied before the predicate. * @return the [Command] item with [CommandData] values matching the predicate and clazz restrictions. * @throws IllegalArgumentException if no items, or multiple items matched the requirements. */ - fun findCommand(predicate: Predicate, clazz: Class): Command { - return filterCommands(predicate, clazz).single() + fun findCommand(clazz: Class, predicate: Predicate): Command { + @Suppress("UNCHECKED_CAST") + return commandsOfType(clazz).single { predicate.test(it.value as T) } + } + + inline fun findCommand(crossinline predicate: (T) -> Boolean): Command { + return findCommand(T::class.java, Predicate { predicate(it) }) } /** @@ -273,37 +309,6 @@ class LedgerTransaction( * @return The Attachment with the matching id. * @throws IllegalArgumentException if no item matches the id. */ - fun getAttachment(id: SecureHash): Attachment = attachments.single { it.id == id } - - //Kotlin extension methods to take advantage of Kotlin's smart type inference when querying the LedgerTransaction - inline fun inputsOfType(): List = this.inputsOfType(T::class.java) - - inline fun inRefsOfType(): List> = this.inRefsOfType(T::class.java) - - inline fun filterInputs(crossinline predicate: (T) -> Boolean): List { - return filterInputs(Predicate { predicate(it) }, T::class.java) - } - - inline fun filterInRefs(crossinline predicate: (T) -> Boolean): List> { - return filterInRefs(Predicate { predicate(it) }, T::class.java) - } - - inline fun findInRef(crossinline predicate: (T) -> Boolean): StateAndRef { - return findInRef(Predicate { predicate(it) }, T::class.java) - } - - inline fun findInput(crossinline predicate: (T) -> Boolean): T { - return findInput(Predicate { predicate(it) }, T::class.java) - } - - inline fun commandsOfType(): List = this.commandsOfType(T::class.java) - - inline fun filterCommands(crossinline predicate: (T) -> Boolean): List { - return filterCommands(Predicate { predicate(it) }, T::class.java) - } - - inline fun findCommand(crossinline predicate: (T) -> Boolean): Command { - return findCommand(Predicate { predicate(it) }, T::class.java) - } + fun getAttachment(id: SecureHash): Attachment = attachments.first { it.id == id } } diff --git a/core/src/test/kotlin/net/corda/core/contracts/LedgerTransactionQueryTests.kt b/core/src/test/kotlin/net/corda/core/contracts/LedgerTransactionQueryTests.kt index 55b36038c6..752b0c9c02 100644 --- a/core/src/test/kotlin/net/corda/core/contracts/LedgerTransactionQueryTests.kt +++ b/core/src/test/kotlin/net/corda/core/contracts/LedgerTransactionQueryTests.kt @@ -189,7 +189,7 @@ class LedgerTransactionQueryTests : TestDependencyInjectionBase() { @Test fun `Filtered Input Tests`() { val ltx = makeDummyTransaction() - val intStates = ltx.filterInputs(Predicate { it.data.rem(2) == 0 }, IntTypeDummyState::class.java) + val intStates = ltx.filterInputs(IntTypeDummyState::class.java, Predicate { it.data.rem(2) == 0 }) assertEquals(3, intStates.size) assertEquals(listOf(0, 2, 4), intStates.map { it.data }) val stringStates: List = ltx.filterInputs { it.data == "3" } @@ -199,7 +199,7 @@ class LedgerTransactionQueryTests : TestDependencyInjectionBase() { @Test fun `Filtered InRef Tests`() { val ltx = makeDummyTransaction() - val intStates = ltx.filterInRefs(Predicate { it.data.rem(2) == 0 }, IntTypeDummyState::class.java) + val intStates = ltx.filterInRefs(IntTypeDummyState::class.java, Predicate { it.data.rem(2) == 0 }) assertEquals(3, intStates.size) assertEquals(listOf(0, 2, 4), intStates.map { it.state.data.data }) assertEquals(listOf(ltx.inputs[0], ltx.inputs[4], ltx.inputs[8]), intStates) @@ -211,7 +211,7 @@ class LedgerTransactionQueryTests : TestDependencyInjectionBase() { @Test fun `Filtered Output Tests`() { val ltx = makeDummyTransaction() - val intStates = ltx.filterOutputs(Predicate { it.data.rem(2) == 0 }, IntTypeDummyState::class.java) + val intStates = ltx.filterOutputs(IntTypeDummyState::class.java, Predicate { it.data.rem(2) == 0 }) assertEquals(3, intStates.size) assertEquals(listOf(0, 2, 4), intStates.map { it.data }) val stringStates: List = ltx.filterOutputs { it.data == "3" } @@ -221,7 +221,7 @@ class LedgerTransactionQueryTests : TestDependencyInjectionBase() { @Test fun `Filtered OutRef Tests`() { val ltx = makeDummyTransaction() - val intStates = ltx.filterOutRefs(Predicate { it.data.rem(2) == 0 }, IntTypeDummyState::class.java) + val intStates = ltx.filterOutRefs(IntTypeDummyState::class.java, Predicate { it.data.rem(2) == 0 }) assertEquals(3, intStates.size) assertEquals(listOf(0, 2, 4), intStates.map { it.state.data.data }) assertEquals(listOf(0, 4, 8), intStates.map { it.ref.index }) @@ -235,7 +235,7 @@ class LedgerTransactionQueryTests : TestDependencyInjectionBase() { @Test fun `Filtered Commands Tests`() { val ltx = makeDummyTransaction() - val intCmds1 = ltx.filterCommands(Predicate { it.id.rem(2) == 0 }, Commands.Cmd1::class.java) + val intCmds1 = ltx.filterCommands(Commands.Cmd1::class.java, Predicate { it.id.rem(2) == 0 }) assertEquals(3, intCmds1.size) assertEquals(listOf(0, 2, 4), intCmds1.map { (it.value as Commands.Cmd1).id }) val intCmds2 = ltx.filterCommands { it.id == 3 } @@ -245,7 +245,7 @@ class LedgerTransactionQueryTests : TestDependencyInjectionBase() { @Test fun `Find Input Tests`() { val ltx = makeDummyTransaction() - val intState = ltx.findInput(Predicate { it.data == 4 }, IntTypeDummyState::class.java) + val intState = ltx.findInput(IntTypeDummyState::class.java, Predicate { it.data == 4 }) assertEquals(ltx.getInput(8), intState) val stringState: StringTypeDummyState = ltx.findInput { it.data == "3" } assertEquals(ltx.getInput(7), stringState) @@ -254,7 +254,7 @@ class LedgerTransactionQueryTests : TestDependencyInjectionBase() { @Test fun `Find InRef Tests`() { val ltx = makeDummyTransaction() - val intState = ltx.findInRef(Predicate { it.data == 4 }, IntTypeDummyState::class.java) + val intState = ltx.findInRef(IntTypeDummyState::class.java, Predicate { it.data == 4 }) assertEquals(ltx.inRef(8), intState) val stringState: StateAndRef = ltx.findInRef { it.data == "3" } assertEquals(ltx.inRef(7), stringState) @@ -263,7 +263,7 @@ class LedgerTransactionQueryTests : TestDependencyInjectionBase() { @Test fun `Find Output Tests`() { val ltx = makeDummyTransaction() - val intState = ltx.findOutput(Predicate { it.data == 4 }, IntTypeDummyState::class.java) + val intState = ltx.findOutput(IntTypeDummyState::class.java, Predicate { it.data == 4 }) assertEquals(ltx.getOutput(8), intState) val stringState: StringTypeDummyState = ltx.findOutput { it.data == "3" } assertEquals(ltx.getOutput(7), stringState) @@ -272,7 +272,7 @@ class LedgerTransactionQueryTests : TestDependencyInjectionBase() { @Test fun `Find OutRef Tests`() { val ltx = makeDummyTransaction() - val intState = ltx.findOutRef(Predicate { it.data == 4 }, IntTypeDummyState::class.java) + val intState = ltx.findOutRef(IntTypeDummyState::class.java, Predicate { it.data == 4 }) assertEquals(ltx.outRef(8), intState) val stringState: StateAndRef = ltx.findOutRef { it.data == "3" } assertEquals(ltx.outRef(7), stringState) @@ -281,7 +281,7 @@ class LedgerTransactionQueryTests : TestDependencyInjectionBase() { @Test fun `Find Commands Tests`() { val ltx = makeDummyTransaction() - val intCmd1 = ltx.findCommand(Predicate { it.id == 2 }, Commands.Cmd1::class.java) + val intCmd1 = ltx.findCommand(Commands.Cmd1::class.java, Predicate { it.id == 2 }) assertEquals(ltx.getCommand(4), intCmd1) val intCmd2 = ltx.findCommand { it.id == 3 } assertEquals(ltx.getCommand(7), intCmd2) diff --git a/node/src/main/kotlin/net/corda/node/services/database/HibernateConfiguration.kt b/node/src/main/kotlin/net/corda/node/services/database/HibernateConfiguration.kt index 85d777d8d0..00d72cd0b4 100644 --- a/node/src/main/kotlin/net/corda/node/services/database/HibernateConfiguration.kt +++ b/node/src/main/kotlin/net/corda/node/services/database/HibernateConfiguration.kt @@ -1,5 +1,6 @@ package net.corda.node.services.database +import net.corda.core.internal.castIfPossible import net.corda.core.schemas.MappedSchema import net.corda.core.utilities.loggerFor import net.corda.node.services.api.SchemaService @@ -100,17 +101,14 @@ class HibernateConfiguration(val schemaService: SchemaService, val useDefaultLog override fun supportsAggressiveRelease(): Boolean = true - override fun getConnection(): Connection = - DatabaseTransactionManager.newTransaction(Connection.TRANSACTION_REPEATABLE_READ).connection - - override fun unwrap(unwrapType: Class): T { - try { - return unwrapType.cast(this) - } catch(e: ClassCastException) { - throw UnknownUnwrapTypeException(unwrapType) - } + override fun getConnection(): Connection { + return DatabaseTransactionManager.newTransaction(Connection.TRANSACTION_REPEATABLE_READ).connection } - override fun isUnwrappableAs(unwrapType: Class<*>?): Boolean = (unwrapType == NodeDatabaseConnectionProvider::class.java) + override fun unwrap(unwrapType: Class): T { + return unwrapType.castIfPossible(this) ?: throw UnknownUnwrapTypeException(unwrapType) + } + + override fun isUnwrappableAs(unwrapType: Class<*>?): Boolean = unwrapType == NodeDatabaseConnectionProvider::class.java } } \ No newline at end of file diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/SessionMessage.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/SessionMessage.kt index 8b25d0b0b7..baee9d549f 100644 --- a/node/src/main/kotlin/net/corda/node/services/statemachine/SessionMessage.kt +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/SessionMessage.kt @@ -4,6 +4,7 @@ import net.corda.core.flows.FlowException import net.corda.core.flows.FlowLogic import net.corda.core.flows.UnexpectedFlowEndException import net.corda.core.identity.Party +import net.corda.core.internal.castIfPossible import net.corda.core.serialization.CordaSerializable import net.corda.core.utilities.UntrustworthyData @@ -42,10 +43,7 @@ data class ErrorSessionEnd(override val recipientSessionId: Long, val errorRespo data class ReceivedSessionMessage(val sender: Party, val message: M) fun ReceivedSessionMessage.checkPayloadIs(type: Class): UntrustworthyData { - if (type.isInstance(message.payload)) { - return UntrustworthyData(type.cast(message.payload)) - } else { - throw UnexpectedFlowEndException("We were expecting a ${type.name} from $sender but we instead got a " + - "${message.payload.javaClass.name} (${message.payload})") - } + return type.castIfPossible(message.payload)?.let { UntrustworthyData(it) } ?: + throw UnexpectedFlowEndException("We were expecting a ${type.name} from $sender but we instead got a " + + "${message.payload.javaClass.name} (${message.payload})") } diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/StateMachineManager.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/StateMachineManager.kt index 6ee71cb3c1..009c5a10ec 100644 --- a/node/src/main/kotlin/net/corda/node/services/statemachine/StateMachineManager.kt +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/StateMachineManager.kt @@ -17,6 +17,7 @@ import net.corda.core.flows.FlowInitiator import net.corda.core.flows.FlowLogic import net.corda.core.flows.StateMachineRunId import net.corda.core.identity.Party +import net.corda.core.internal.castIfPossible import net.corda.core.messaging.DataFeed import net.corda.core.serialization.* import net.corda.core.serialization.SerializationDefaults.CHECKPOINT_CONTEXT @@ -145,10 +146,9 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, fun

, T> findStateMachines(flowClass: Class

): List>> { @Suppress("UNCHECKED_CAST") return mutex.locked { - stateMachines.keys - .map { it.logic } - .filterIsInstance(flowClass) - .map { it to (it.stateMachine as FlowStateMachineImpl).resultFuture } + stateMachines.keys.mapNotNull { + flowClass.castIfPossible(it.logic)?.let { it to (it.stateMachine as FlowStateMachineImpl).resultFuture } + } } } @@ -380,7 +380,7 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, private fun deserializeFiber(checkpoint: Checkpoint, logger: Logger): FlowStateMachineImpl<*>? { return try { - checkpoint.serializedFiber.deserialize>(context = CHECKPOINT_CONTEXT.withTokenContext(serializationContext)).apply { fromCheckpoint = true } + checkpoint.serializedFiber.deserialize(context = CHECKPOINT_CONTEXT.withTokenContext(serializationContext)).apply { fromCheckpoint = true } } catch (t: Throwable) { logger.error("Encountered unrestorable checkpoint!", t) null