CORDA-644: Only serialise Kotlin lambdas when checkpointing. (#1801)

* Remove local function because it is serialised as a lambda.
* Don't automatically whitelist Kotlin lambdas unless checkpointing.
* Add comment to @CordaSerializable, warning not to allow AnnotationTarget.EXPRESSION.
This commit is contained in:
Chris Rankin 2017-10-09 13:02:40 +01:00 committed by GitHub
parent f83f1b7010
commit 689758a71c
5 changed files with 141 additions and 10 deletions

View File

@ -0,0 +1,92 @@
package net.corda.client.rpc
import co.paralleluniverse.fibers.Suspendable
import com.esotericsoftware.kryo.KryoException
import net.corda.core.flows.*
import net.corda.core.identity.Party
import net.corda.core.messaging.startFlow
import net.corda.core.serialization.CordaSerializable
import net.corda.core.utilities.getOrThrow
import net.corda.core.utilities.loggerFor
import net.corda.core.utilities.unwrap
import net.corda.node.internal.Node
import net.corda.node.internal.StartedNode
import net.corda.nodeapi.User
import net.corda.testing.*
import net.corda.testing.node.NodeBasedTest
import org.junit.After
import org.junit.Before
import org.junit.Rule
import org.junit.Test
import org.junit.rules.ExpectedException
@CordaSerializable
data class Packet(val x: () -> Long)
class BlacklistKotlinClosureTest : NodeBasedTest() {
companion object {
@Suppress("UNUSED") val logger = loggerFor<BlacklistKotlinClosureTest>()
const val EVIL: Long = 666
}
@StartableByRPC
@InitiatingFlow
class FlowC(private val remoteParty: Party, private val data: Packet) : FlowLogic<Unit>() {
@Suspendable
override fun call() {
val session = initiateFlow(remoteParty)
val x = session.sendAndReceive<Packet>(data).unwrap { x -> x }
logger.info("FlowC: ${x.x()}")
}
}
@InitiatedBy(FlowC::class)
class RemoteFlowC(private val session: FlowSession) : FlowLogic<Unit>() {
@Suspendable
override fun call() {
val packet = session.receive<Packet>().unwrap { x -> x }
logger.info("RemoteFlowC: ${packet.x() + 1}")
session.send(Packet({ packet.x() + 1 }))
}
}
@JvmField
@Rule
val expectedEx: ExpectedException = ExpectedException.none()
private val rpcUser = User("user1", "test", permissions = setOf("ALL"))
private lateinit var aliceNode: StartedNode<Node>
private lateinit var bobNode: StartedNode<Node>
private lateinit var aliceClient: CordaRPCClient
private var connection: CordaRPCConnection? = null
private fun login(username: String, password: String) {
connection = aliceClient.start(username, password)
}
@Before
fun setUp() {
setCordappPackages("net.corda.client.rpc")
aliceNode = startNode(ALICE.name, rpcUsers = listOf(rpcUser)).getOrThrow()
bobNode = startNode(BOB.name, rpcUsers = listOf(rpcUser)).getOrThrow()
bobNode.registerInitiatedFlow(RemoteFlowC::class.java)
aliceClient = CordaRPCClient(aliceNode.internals.configuration.rpcAddress!!)
}
@After
fun done() {
connection?.close()
bobNode.internals.stop()
aliceNode.internals.stop()
unsetCordappPackages()
}
@Test
fun `closure sent via RPC`() {
login(rpcUser.username, rpcUser.password)
val proxy = connection!!.proxy
expectedEx.expect(KryoException::class.java)
expectedEx.expectMessage("is not annotated or on the whitelist, so cannot be used in serialization")
proxy.startFlow(::FlowC, bobNode.info.chooseIdentity(), Packet{ EVIL }).returnValue.getOrThrow()
}
}

View File

@ -11,6 +11,9 @@ import java.lang.annotation.Inherited
* *
* It also makes it possible for a code reviewer to clearly identify the classes that can be passed on the wire. * It also makes it possible for a code reviewer to clearly identify the classes that can be passed on the wire.
* *
* Do NOT include [AnnotationTarget.EXPRESSION] as one of the @Target parameters, as this would allow any Lambda to
* be serialised. This would be a security hole.
*
* TODO: As we approach a long term wire format, this annotation will only be permitted on classes that meet certain criteria. * TODO: As we approach a long term wire format, this annotation will only be permitted on classes that meet certain criteria.
*/ */
@Target(AnnotationTarget.CLASS) @Target(AnnotationTarget.CLASS)

View File

@ -1,14 +1,22 @@
package net.corda.core.utilities package net.corda.core.utilities
import com.esotericsoftware.kryo.KryoException
import net.corda.core.crypto.random63BitValue import net.corda.core.crypto.random63BitValue
import net.corda.core.serialization.CordaSerializable import net.corda.core.serialization.CordaSerializable
import net.corda.core.serialization.deserialize import net.corda.core.serialization.deserialize
import net.corda.core.serialization.serialize import net.corda.core.serialization.serialize
import net.corda.nodeapi.internal.serialization.KRYO_CHECKPOINT_CONTEXT
import net.corda.testing.TestDependencyInjectionBase import net.corda.testing.TestDependencyInjectionBase
import org.assertj.core.api.Assertions.assertThat import org.assertj.core.api.Assertions.assertThat
import org.junit.Rule
import org.junit.Test import org.junit.Test
import org.junit.rules.ExpectedException
class KotlinUtilsTest : TestDependencyInjectionBase() { class KotlinUtilsTest : TestDependencyInjectionBase() {
@JvmField
@Rule
val expectedEx: ExpectedException = ExpectedException.none()
@Test @Test
fun `transient property which is null`() { fun `transient property which is null`() {
val test = NullTransientProperty() val test = NullTransientProperty()
@ -18,26 +26,58 @@ class KotlinUtilsTest : TestDependencyInjectionBase() {
} }
@Test @Test
fun `transient property with non-capturing lamba`() { fun `checkpointing a transient property with non-capturing lamba`() {
val original = NonCapturingTransientProperty() val original = NonCapturingTransientProperty()
val originalVal = original.transientVal val originalVal = original.transientVal
val copy = original.serialize().deserialize() val copy = original.serialize(context = KRYO_CHECKPOINT_CONTEXT).deserialize(context = KRYO_CHECKPOINT_CONTEXT)
val copyVal = copy.transientVal val copyVal = copy.transientVal
assertThat(copyVal).isNotEqualTo(originalVal) assertThat(copyVal).isNotEqualTo(originalVal)
assertThat(copy.transientVal).isEqualTo(copyVal) assertThat(copy.transientVal).isEqualTo(copyVal)
} }
@Test @Test
fun `transient property with capturing lamba`() { fun `serialise transient property with non-capturing lamba`() {
expectedEx.expect(KryoException::class.java)
expectedEx.expectMessage("is not annotated or on the whitelist, so cannot be used in serialization")
val original = NonCapturingTransientProperty()
original.serialize()
}
@Test
fun `deserialise transient property with non-capturing lamba`() {
expectedEx.expect(KryoException::class.java)
expectedEx.expectMessage("is not annotated or on the whitelist, so cannot be used in serialization")
val original = NonCapturingTransientProperty()
original.serialize(context = KRYO_CHECKPOINT_CONTEXT).deserialize()
}
@Test
fun `checkpointing a transient property with capturing lamba`() {
val original = CapturingTransientProperty("Hello") val original = CapturingTransientProperty("Hello")
val originalVal = original.transientVal val originalVal = original.transientVal
val copy = original.serialize().deserialize() val copy = original.serialize(context = KRYO_CHECKPOINT_CONTEXT).deserialize(context = KRYO_CHECKPOINT_CONTEXT)
val copyVal = copy.transientVal val copyVal = copy.transientVal
assertThat(copyVal).isNotEqualTo(originalVal) assertThat(copyVal).isNotEqualTo(originalVal)
assertThat(copy.transientVal).isEqualTo(copyVal) assertThat(copy.transientVal).isEqualTo(copyVal)
assertThat(copy.transientVal).startsWith("Hello") assertThat(copy.transientVal).startsWith("Hello")
} }
@Test
fun `serialise transient property with capturing lamba`() {
expectedEx.expect(KryoException::class.java)
expectedEx.expectMessage("is not annotated or on the whitelist, so cannot be used in serialization")
val original = CapturingTransientProperty("Hello")
original.serialize()
}
@Test
fun `deserialise transient property with capturing lamba`() {
expectedEx.expect(KryoException::class.java)
expectedEx.expectMessage("is not annotated or on the whitelist, so cannot be used in serialization")
val original = CapturingTransientProperty("Hello")
original.serialize(context = KRYO_CHECKPOINT_CONTEXT).deserialize()
}
private class NullTransientProperty { private class NullTransientProperty {
var evalCount = 0 var evalCount = 0
val transientValue by transient { val transientValue by transient {

View File

@ -63,8 +63,6 @@ class CordaClassResolver(serializationContext: SerializationContext) : DefaultCl
if (type.isArray) return checkClass(type.componentType) if (type.isArray) return checkClass(type.componentType)
// Specialised enum entry, so just resolve the parent Enum type since cannot annotate the specialised entry. // Specialised enum entry, so just resolve the parent Enum type since cannot annotate the specialised entry.
if (!type.isEnum && Enum::class.java.isAssignableFrom(type)) return checkClass(type.superclass) if (!type.isEnum && Enum::class.java.isAssignableFrom(type)) return checkClass(type.superclass)
// Kotlin lambdas require some special treatment
if (kotlin.jvm.internal.Lambda::class.java.isAssignableFrom(type)) return null
// It's safe to have the Class already, since Kryo loads it with initialisation off. // It's safe to have the Class already, since Kryo loads it with initialisation off.
// If we use a whitelist with blacklisting capabilities, whitelist.hasListed(type) may throw an IllegalStateException if input class is blacklisted. // If we use a whitelist with blacklisting capabilities, whitelist.hasListed(type) may throw an IllegalStateException if input class is blacklisted.
// Thus, blacklisting precedes annotation checking. // Thus, blacklisting precedes annotation checking.

View File

@ -404,9 +404,7 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
@Suspendable @Suspendable
private fun ReceiveRequest<*>.suspendAndExpectReceive(): ReceivedSessionMessage<*> { private fun ReceiveRequest<*>.suspendAndExpectReceive(): ReceivedSessionMessage<*> {
fun pollForMessage() = session.receivedMessages.poll() val polledMessage = session.receivedMessages.poll()
val polledMessage = pollForMessage()
return if (polledMessage != null) { return if (polledMessage != null) {
if (this is SendAndReceive) { if (this is SendAndReceive) {
// Since we've already received the message, we downgrade to a send only to get the payload out and not // Since we've already received the message, we downgrade to a send only to get the payload out and not
@ -417,7 +415,7 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
} else { } else {
// Suspend while we wait for a receive // Suspend while we wait for a receive
suspend(this) suspend(this)
pollForMessage() ?: session.receivedMessages.poll() ?:
throw IllegalStateException("Was expecting a ${receiveType.simpleName} but instead got nothing for $this") throw IllegalStateException("Was expecting a ${receiveType.simpleName} but instead got nothing for $this")
} }
} }