diff --git a/client/rpc/src/integration-test/kotlin/net/corda/client/rpc/BlacklistKotlinClosureTest.kt b/client/rpc/src/integration-test/kotlin/net/corda/client/rpc/BlacklistKotlinClosureTest.kt new file mode 100644 index 0000000000..776b96f87f --- /dev/null +++ b/client/rpc/src/integration-test/kotlin/net/corda/client/rpc/BlacklistKotlinClosureTest.kt @@ -0,0 +1,92 @@ +package net.corda.client.rpc + +import co.paralleluniverse.fibers.Suspendable +import com.esotericsoftware.kryo.KryoException +import net.corda.core.flows.* +import net.corda.core.identity.Party +import net.corda.core.messaging.startFlow +import net.corda.core.serialization.CordaSerializable +import net.corda.core.utilities.getOrThrow +import net.corda.core.utilities.loggerFor +import net.corda.core.utilities.unwrap +import net.corda.node.internal.Node +import net.corda.node.internal.StartedNode +import net.corda.nodeapi.User +import net.corda.testing.* +import net.corda.testing.node.NodeBasedTest +import org.junit.After +import org.junit.Before +import org.junit.Rule +import org.junit.Test +import org.junit.rules.ExpectedException + +@CordaSerializable +data class Packet(val x: () -> Long) + +class BlacklistKotlinClosureTest : NodeBasedTest() { + companion object { + @Suppress("UNUSED") val logger = loggerFor() + const val EVIL: Long = 666 + } + + @StartableByRPC + @InitiatingFlow + class FlowC(private val remoteParty: Party, private val data: Packet) : FlowLogic() { + @Suspendable + override fun call() { + val session = initiateFlow(remoteParty) + val x = session.sendAndReceive(data).unwrap { x -> x } + logger.info("FlowC: ${x.x()}") + } + } + + @InitiatedBy(FlowC::class) + class RemoteFlowC(private val session: FlowSession) : FlowLogic() { + @Suspendable + override fun call() { + val packet = session.receive().unwrap { x -> x } + logger.info("RemoteFlowC: ${packet.x() + 1}") + session.send(Packet({ packet.x() + 1 })) + } + } + + @JvmField + @Rule + val expectedEx: ExpectedException = ExpectedException.none() + + private val rpcUser = User("user1", "test", permissions = setOf("ALL")) + private lateinit var aliceNode: StartedNode + private lateinit var bobNode: StartedNode + private lateinit var aliceClient: CordaRPCClient + private var connection: CordaRPCConnection? = null + + private fun login(username: String, password: String) { + connection = aliceClient.start(username, password) + } + + @Before + fun setUp() { + setCordappPackages("net.corda.client.rpc") + aliceNode = startNode(ALICE.name, rpcUsers = listOf(rpcUser)).getOrThrow() + bobNode = startNode(BOB.name, rpcUsers = listOf(rpcUser)).getOrThrow() + bobNode.registerInitiatedFlow(RemoteFlowC::class.java) + aliceClient = CordaRPCClient(aliceNode.internals.configuration.rpcAddress!!) + } + + @After + fun done() { + connection?.close() + bobNode.internals.stop() + aliceNode.internals.stop() + unsetCordappPackages() + } + + @Test + fun `closure sent via RPC`() { + login(rpcUser.username, rpcUser.password) + val proxy = connection!!.proxy + expectedEx.expect(KryoException::class.java) + expectedEx.expectMessage("is not annotated or on the whitelist, so cannot be used in serialization") + proxy.startFlow(::FlowC, bobNode.info.chooseIdentity(), Packet{ EVIL }).returnValue.getOrThrow() + } +} \ No newline at end of file diff --git a/core/src/main/kotlin/net/corda/core/serialization/CordaSerializable.kt b/core/src/main/kotlin/net/corda/core/serialization/CordaSerializable.kt index ff90f2a462..ce80444256 100644 --- a/core/src/main/kotlin/net/corda/core/serialization/CordaSerializable.kt +++ b/core/src/main/kotlin/net/corda/core/serialization/CordaSerializable.kt @@ -11,6 +11,9 @@ import java.lang.annotation.Inherited * * It also makes it possible for a code reviewer to clearly identify the classes that can be passed on the wire. * + * Do NOT include [AnnotationTarget.EXPRESSION] as one of the @Target parameters, as this would allow any Lambda to + * be serialised. This would be a security hole. + * * TODO: As we approach a long term wire format, this annotation will only be permitted on classes that meet certain criteria. */ @Target(AnnotationTarget.CLASS) diff --git a/core/src/test/kotlin/net/corda/core/utilities/KotlinUtilsTest.kt b/core/src/test/kotlin/net/corda/core/utilities/KotlinUtilsTest.kt index 60bfdd8369..22ef438a6b 100644 --- a/core/src/test/kotlin/net/corda/core/utilities/KotlinUtilsTest.kt +++ b/core/src/test/kotlin/net/corda/core/utilities/KotlinUtilsTest.kt @@ -1,14 +1,22 @@ package net.corda.core.utilities +import com.esotericsoftware.kryo.KryoException import net.corda.core.crypto.random63BitValue import net.corda.core.serialization.CordaSerializable import net.corda.core.serialization.deserialize import net.corda.core.serialization.serialize +import net.corda.nodeapi.internal.serialization.KRYO_CHECKPOINT_CONTEXT import net.corda.testing.TestDependencyInjectionBase import org.assertj.core.api.Assertions.assertThat +import org.junit.Rule import org.junit.Test +import org.junit.rules.ExpectedException class KotlinUtilsTest : TestDependencyInjectionBase() { + @JvmField + @Rule + val expectedEx: ExpectedException = ExpectedException.none() + @Test fun `transient property which is null`() { val test = NullTransientProperty() @@ -18,26 +26,58 @@ class KotlinUtilsTest : TestDependencyInjectionBase() { } @Test - fun `transient property with non-capturing lamba`() { + fun `checkpointing a transient property with non-capturing lamba`() { val original = NonCapturingTransientProperty() val originalVal = original.transientVal - val copy = original.serialize().deserialize() + val copy = original.serialize(context = KRYO_CHECKPOINT_CONTEXT).deserialize(context = KRYO_CHECKPOINT_CONTEXT) val copyVal = copy.transientVal assertThat(copyVal).isNotEqualTo(originalVal) assertThat(copy.transientVal).isEqualTo(copyVal) } @Test - fun `transient property with capturing lamba`() { + fun `serialise transient property with non-capturing lamba`() { + expectedEx.expect(KryoException::class.java) + expectedEx.expectMessage("is not annotated or on the whitelist, so cannot be used in serialization") + val original = NonCapturingTransientProperty() + original.serialize() + } + + @Test + fun `deserialise transient property with non-capturing lamba`() { + expectedEx.expect(KryoException::class.java) + expectedEx.expectMessage("is not annotated or on the whitelist, so cannot be used in serialization") + val original = NonCapturingTransientProperty() + original.serialize(context = KRYO_CHECKPOINT_CONTEXT).deserialize() + } + + @Test + fun `checkpointing a transient property with capturing lamba`() { val original = CapturingTransientProperty("Hello") val originalVal = original.transientVal - val copy = original.serialize().deserialize() + val copy = original.serialize(context = KRYO_CHECKPOINT_CONTEXT).deserialize(context = KRYO_CHECKPOINT_CONTEXT) val copyVal = copy.transientVal assertThat(copyVal).isNotEqualTo(originalVal) assertThat(copy.transientVal).isEqualTo(copyVal) assertThat(copy.transientVal).startsWith("Hello") } + @Test + fun `serialise transient property with capturing lamba`() { + expectedEx.expect(KryoException::class.java) + expectedEx.expectMessage("is not annotated or on the whitelist, so cannot be used in serialization") + val original = CapturingTransientProperty("Hello") + original.serialize() + } + + @Test + fun `deserialise transient property with capturing lamba`() { + expectedEx.expect(KryoException::class.java) + expectedEx.expectMessage("is not annotated or on the whitelist, so cannot be used in serialization") + val original = CapturingTransientProperty("Hello") + original.serialize(context = KRYO_CHECKPOINT_CONTEXT).deserialize() + } + private class NullTransientProperty { var evalCount = 0 val transientValue by transient { diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/CordaClassResolver.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/CordaClassResolver.kt index b614284139..afcdcbf0c8 100644 --- a/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/CordaClassResolver.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/CordaClassResolver.kt @@ -63,8 +63,6 @@ class CordaClassResolver(serializationContext: SerializationContext) : DefaultCl if (type.isArray) return checkClass(type.componentType) // Specialised enum entry, so just resolve the parent Enum type since cannot annotate the specialised entry. if (!type.isEnum && Enum::class.java.isAssignableFrom(type)) return checkClass(type.superclass) - // Kotlin lambdas require some special treatment - if (kotlin.jvm.internal.Lambda::class.java.isAssignableFrom(type)) return null // It's safe to have the Class already, since Kryo loads it with initialisation off. // If we use a whitelist with blacklisting capabilities, whitelist.hasListed(type) may throw an IllegalStateException if input class is blacklisted. // Thus, blacklisting precedes annotation checking. diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/FlowStateMachineImpl.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/FlowStateMachineImpl.kt index b3103f979e..59f40a8533 100644 --- a/node/src/main/kotlin/net/corda/node/services/statemachine/FlowStateMachineImpl.kt +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/FlowStateMachineImpl.kt @@ -404,9 +404,7 @@ class FlowStateMachineImpl(override val id: StateMachineRunId, @Suspendable private fun ReceiveRequest<*>.suspendAndExpectReceive(): ReceivedSessionMessage<*> { - fun pollForMessage() = session.receivedMessages.poll() - - val polledMessage = pollForMessage() + val polledMessage = session.receivedMessages.poll() return if (polledMessage != null) { if (this is SendAndReceive) { // Since we've already received the message, we downgrade to a send only to get the payload out and not @@ -417,7 +415,7 @@ class FlowStateMachineImpl(override val id: StateMachineRunId, } else { // Suspend while we wait for a receive suspend(this) - pollForMessage() ?: + session.receivedMessages.poll() ?: throw IllegalStateException("Was expecting a ${receiveType.simpleName} but instead got nothing for $this") } }