diff --git a/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/ListsSerializationTest.kt b/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/ListsSerializationTest.kt index ede793b2be..d9d0f58e0c 100644 --- a/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/ListsSerializationTest.kt +++ b/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/ListsSerializationTest.kt @@ -11,11 +11,12 @@ import net.corda.testing.TestDependencyInjectionBase import net.corda.testing.amqpSpecific import net.corda.testing.kryoSpecific import org.assertj.core.api.Assertions -import org.junit.Assert.* +import org.junit.Assert.assertArrayEquals +import org.junit.Assert.assertEquals import org.junit.Test import java.io.ByteArrayOutputStream import java.io.NotSerializableException -import java.nio.charset.StandardCharsets.* +import java.nio.charset.StandardCharsets.US_ASCII import java.util.* class ListsSerializationTest : TestDependencyInjectionBase() { @@ -40,16 +41,19 @@ class ListsSerializationTest : TestDependencyInjectionBase() { @Test fun `check list can be serialized as part of SessionData`() { run { - val sessionData = SessionData(123, listOf(1)) + val sessionData = SessionData(123, listOf(1).serialize()) assertEqualAfterRoundTripSerialization(sessionData) + assertEquals(listOf(1), sessionData.payload.deserialize()) } run { - val sessionData = SessionData(123, listOf(1, 2)) + val sessionData = SessionData(123, listOf(1, 2).serialize()) assertEqualAfterRoundTripSerialization(sessionData) + assertEquals(listOf(1, 2), sessionData.payload.deserialize()) } run { - val sessionData = SessionData(123, emptyList()) + val sessionData = SessionData(123, emptyList().serialize()) assertEqualAfterRoundTripSerialization(sessionData) + assertEquals(emptyList(), sessionData.payload.deserialize()) } } diff --git a/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/MapsSerializationTest.kt b/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/MapsSerializationTest.kt index 9788885420..4e9f598eab 100644 --- a/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/MapsSerializationTest.kt +++ b/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/MapsSerializationTest.kt @@ -3,17 +3,19 @@ package net.corda.nodeapi.internal.serialization import com.esotericsoftware.kryo.Kryo import com.esotericsoftware.kryo.util.DefaultClassResolver import net.corda.core.serialization.CordaSerializable +import net.corda.core.serialization.deserialize import net.corda.core.serialization.serialize import net.corda.node.services.statemachine.SessionData import net.corda.testing.TestDependencyInjectionBase import net.corda.testing.amqpSpecific import net.corda.testing.kryoSpecific import org.assertj.core.api.Assertions +import org.bouncycastle.asn1.x500.X500Name import org.junit.Assert.assertArrayEquals import org.junit.Test -import org.bouncycastle.asn1.x500.X500Name import java.io.ByteArrayOutputStream import java.util.* +import kotlin.test.assertEquals class MapsSerializationTest : TestDependencyInjectionBase() { private companion object { @@ -33,8 +35,9 @@ class MapsSerializationTest : TestDependencyInjectionBase() { @Test fun `check list can be serialized as part of SessionData`() { - val sessionData = SessionData(123, smallMap) + val sessionData = SessionData(123, smallMap.serialize()) assertEqualAfterRoundTripSerialization(sessionData) + assertEquals(smallMap, sessionData.payload.deserialize()) } @CordaSerializable diff --git a/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/SetsSerializationTest.kt b/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/SetsSerializationTest.kt index 4a652a7521..210a0cd800 100644 --- a/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/SetsSerializationTest.kt +++ b/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/SetsSerializationTest.kt @@ -2,11 +2,13 @@ package net.corda.nodeapi.internal.serialization import com.esotericsoftware.kryo.Kryo import com.esotericsoftware.kryo.util.DefaultClassResolver +import net.corda.core.serialization.deserialize import net.corda.core.serialization.serialize import net.corda.node.services.statemachine.SessionData import net.corda.testing.TestDependencyInjectionBase import net.corda.testing.kryoSpecific -import org.junit.Assert.* +import org.junit.Assert.assertArrayEquals +import org.junit.Assert.assertEquals import org.junit.Test import java.io.ByteArrayOutputStream import java.util.* @@ -26,16 +28,19 @@ class SetsSerializationTest : TestDependencyInjectionBase() { @Test fun `check set can be serialized as part of SessionData`() { run { - val sessionData = SessionData(123, setOf(1)) + val sessionData = SessionData(123, setOf(1).serialize()) assertEqualAfterRoundTripSerialization(sessionData) + assertEquals(setOf(1), sessionData.payload.deserialize()) } run { - val sessionData = SessionData(123, setOf(1, 2)) + val sessionData = SessionData(123, setOf(1, 2).serialize()) assertEqualAfterRoundTripSerialization(sessionData) + assertEquals(setOf(1, 2), sessionData.payload.deserialize()) } run { - val sessionData = SessionData(123, emptySet()) + val sessionData = SessionData(123, emptySet().serialize()) assertEqualAfterRoundTripSerialization(sessionData) + assertEquals(emptySet(), sessionData.payload.deserialize()) } } diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/FlowStateMachineImpl.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/FlowStateMachineImpl.kt index 019940363b..89f0bf6dc6 100644 --- a/node/src/main/kotlin/net/corda/node/services/statemachine/FlowStateMachineImpl.kt +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/FlowStateMachineImpl.kt @@ -15,6 +15,8 @@ import net.corda.core.identity.PartyAndCertificate import net.corda.core.internal.* import net.corda.core.internal.concurrent.OpenFuture import net.corda.core.internal.concurrent.openFuture +import net.corda.core.serialization.SerializationDefaults +import net.corda.core.serialization.serialize import net.corda.core.transactions.SignedTransaction import net.corda.core.utilities.* import net.corda.node.services.api.FlowAppAuditEvent @@ -327,7 +329,7 @@ class FlowStateMachineImpl(override val id: StateMachineRunId, is FlowSessionState.Initiated -> sessionState.peerSessionId else -> throw IllegalStateException("We've somehow held onto a non-initiated session: $session") } - return SessionData(peerSessionId, payload) + return SessionData(peerSessionId, payload.serialize(context = SerializationDefaults.P2P_CONTEXT)) } @Suspendable @@ -389,7 +391,8 @@ class FlowStateMachineImpl(override val id: StateMachineRunId, session.state = FlowSessionState.Initiating(state.otherParty) session.retryable = retryable val (version, initiatingFlowClass) = session.flow.javaClass.flowVersionAndInitiatingClass - val sessionInit = SessionInit(session.ourSessionId, initiatingFlowClass.name, version, session.flow.javaClass.appName, firstPayload) + val payloadBytes = firstPayload?.serialize(context = SerializationDefaults.P2P_CONTEXT) + val sessionInit = SessionInit(session.ourSessionId, initiatingFlowClass.name, version, session.flow.javaClass.appName, payloadBytes) sendInternal(session, sessionInit) if (waitForConfirmation) { session.waitForConfirmation() diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/SessionMessage.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/SessionMessage.kt index fc103e6dca..c321d3768a 100644 --- a/node/src/main/kotlin/net/corda/node/services/statemachine/SessionMessage.kt +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/SessionMessage.kt @@ -5,7 +5,10 @@ import net.corda.core.flows.UnexpectedFlowEndException import net.corda.core.identity.Party import net.corda.core.internal.castIfPossible import net.corda.core.serialization.CordaSerializable +import net.corda.core.serialization.SerializationDefaults +import net.corda.core.serialization.SerializedBytes import net.corda.core.utilities.UntrustworthyData +import java.io.IOException @CordaSerializable interface SessionMessage @@ -25,7 +28,7 @@ data class SessionInit(val initiatorSessionId: Long, val initiatingFlowClass: String, val flowVersion: Int, val appName: String, - val firstPayload: Any?) : SessionMessage + val firstPayload: SerializedBytes?) : SessionMessage data class SessionConfirm(override val initiatorSessionId: Long, val initiatedSessionId: Long, @@ -34,7 +37,7 @@ data class SessionConfirm(override val initiatorSessionId: Long, data class SessionReject(override val initiatorSessionId: Long, val errorMessage: String) : SessionInitResponse -data class SessionData(override val recipientSessionId: Long, val payload: Any) : ExistingSessionMessage +data class SessionData(override val recipientSessionId: Long, val payload: SerializedBytes) : ExistingSessionMessage data class NormalSessionEnd(override val recipientSessionId: Long) : SessionEnd @@ -42,8 +45,15 @@ data class ErrorSessionEnd(override val recipientSessionId: Long, val errorRespo data class ReceivedSessionMessage(val sender: Party, val message: M) -fun ReceivedSessionMessage.checkPayloadIs(type: Class): UntrustworthyData { - return type.castIfPossible(message.payload)?.let { UntrustworthyData(it) } ?: +fun ReceivedSessionMessage.checkPayloadIs(type: Class): UntrustworthyData { + val payloadData: T = try { + val serializer = SerializationDefaults.SERIALIZATION_FACTORY + serializer.deserialize(message.payload, type, SerializationDefaults.P2P_CONTEXT) + } catch (ex: Exception) { + throw IOException("Payload invalid", ex) + } + return type.castIfPossible(payloadData)?.let { UntrustworthyData(it) } ?: throw UnexpectedFlowEndException("We were expecting a ${type.name} from $sender but we instead got a " + - "${message.payload.javaClass.name} (${message.payload})") + "${payloadData.javaClass.name} (${payloadData})") + } diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/StateMachineManager.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/StateMachineManager.kt index 2215354b07..74697821e9 100644 --- a/node/src/main/kotlin/net/corda/node/services/statemachine/StateMachineManager.kt +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/StateMachineManager.kt @@ -15,11 +15,7 @@ import net.corda.core.crypto.SecureHash import net.corda.core.crypto.random63BitValue import net.corda.core.flows.* import net.corda.core.identity.Party -import net.corda.core.internal.FlowStateMachine -import net.corda.core.internal.ThreadBox -import net.corda.core.internal.bufferUntilSubscribed -import net.corda.core.internal.castIfPossible -import net.corda.core.internal.uncheckedCast +import net.corda.core.internal.* import net.corda.core.messaging.DataFeed import net.corda.core.serialization.SerializationDefaults.CHECKPOINT_CONTEXT import net.corda.core.serialization.SerializationDefaults.SERIALIZATION_FACTORY @@ -290,7 +286,12 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, } private fun onSessionMessage(message: ReceivedMessage) { - val sessionMessage = message.data.deserialize() + val sessionMessage = try { + message.data.deserialize() + } catch (ex: Exception) { + logger.error("Received corrupt SessionMessage data from ${message.peer}") + return + } val sender = serviceHub.networkMapCache.getPeerByLegalName(message.peer) if (sender != null) { when (sessionMessage) { @@ -382,12 +383,7 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, updateCheckpoint(fiber) session to initiatedFlowFactory } catch (e: SessionRejectException) { - // TODO: Handle this more gracefully - try { - logger.warn("${e.logMessage}: $sessionInit") - } catch (e: Throwable) { - logger.warn("Problematic session init message during logging", e) - } + logger.warn("${e.logMessage}: $sessionInit") sendSessionReject(e.rejectMessage) return } catch (e: Exception) { diff --git a/node/src/test/kotlin/net/corda/node/services/statemachine/FlowFrameworkTests.kt b/node/src/test/kotlin/net/corda/node/services/statemachine/FlowFrameworkTests.kt index d4391109e3..418ffb5142 100644 --- a/node/src/test/kotlin/net/corda/node/services/statemachine/FlowFrameworkTests.kt +++ b/node/src/test/kotlin/net/corda/node/services/statemachine/FlowFrameworkTests.kt @@ -710,11 +710,11 @@ class FlowFrameworkTests { } private fun sessionInit(clientFlowClass: KClass>, flowVersion: Int = 1, payload: Any? = null): SessionInit { - return SessionInit(0, clientFlowClass.java.name, flowVersion, "", payload) + return SessionInit(0, clientFlowClass.java.name, flowVersion, "", payload?.serialize()) } private fun sessionConfirm(flowVersion: Int = 1) = SessionConfirm(0, 0, flowVersion, "") - private fun sessionData(payload: Any) = SessionData(0, payload) + private fun sessionData(payload: Any) = SessionData(0, payload.serialize()) private val normalEnd = NormalSessionEnd(0) private fun erroredEnd(errorResponse: FlowException? = null) = ErrorSessionEnd(0, errorResponse)