Change the serialization/deserialization code of SessionMessage data to add more validation.

Address PR comments

As pointed out by Shams the SessionInit must be well formed at this point.
This commit is contained in:
Matthew Nesbit 2017-10-09 15:04:00 +01:00
parent e232d111ea
commit 899f7f9d0d
7 changed files with 53 additions and 32 deletions

View File

@ -11,11 +11,12 @@ import net.corda.testing.TestDependencyInjectionBase
import net.corda.testing.amqpSpecific import net.corda.testing.amqpSpecific
import net.corda.testing.kryoSpecific import net.corda.testing.kryoSpecific
import org.assertj.core.api.Assertions import org.assertj.core.api.Assertions
import org.junit.Assert.* import org.junit.Assert.assertArrayEquals
import org.junit.Assert.assertEquals
import org.junit.Test import org.junit.Test
import java.io.ByteArrayOutputStream import java.io.ByteArrayOutputStream
import java.io.NotSerializableException import java.io.NotSerializableException
import java.nio.charset.StandardCharsets.* import java.nio.charset.StandardCharsets.US_ASCII
import java.util.* import java.util.*
class ListsSerializationTest : TestDependencyInjectionBase() { class ListsSerializationTest : TestDependencyInjectionBase() {
@ -40,16 +41,19 @@ class ListsSerializationTest : TestDependencyInjectionBase() {
@Test @Test
fun `check list can be serialized as part of SessionData`() { fun `check list can be serialized as part of SessionData`() {
run { run {
val sessionData = SessionData(123, listOf(1)) val sessionData = SessionData(123, listOf(1).serialize())
assertEqualAfterRoundTripSerialization(sessionData) assertEqualAfterRoundTripSerialization(sessionData)
assertEquals(listOf(1), sessionData.payload.deserialize())
} }
run { run {
val sessionData = SessionData(123, listOf(1, 2)) val sessionData = SessionData(123, listOf(1, 2).serialize())
assertEqualAfterRoundTripSerialization(sessionData) assertEqualAfterRoundTripSerialization(sessionData)
assertEquals(listOf(1, 2), sessionData.payload.deserialize())
} }
run { run {
val sessionData = SessionData(123, emptyList<Int>()) val sessionData = SessionData(123, emptyList<Int>().serialize())
assertEqualAfterRoundTripSerialization(sessionData) assertEqualAfterRoundTripSerialization(sessionData)
assertEquals(emptyList<Int>(), sessionData.payload.deserialize())
} }
} }

View File

@ -3,17 +3,19 @@ package net.corda.nodeapi.internal.serialization
import com.esotericsoftware.kryo.Kryo import com.esotericsoftware.kryo.Kryo
import com.esotericsoftware.kryo.util.DefaultClassResolver import com.esotericsoftware.kryo.util.DefaultClassResolver
import net.corda.core.serialization.CordaSerializable import net.corda.core.serialization.CordaSerializable
import net.corda.core.serialization.deserialize
import net.corda.core.serialization.serialize import net.corda.core.serialization.serialize
import net.corda.node.services.statemachine.SessionData import net.corda.node.services.statemachine.SessionData
import net.corda.testing.TestDependencyInjectionBase import net.corda.testing.TestDependencyInjectionBase
import net.corda.testing.amqpSpecific import net.corda.testing.amqpSpecific
import net.corda.testing.kryoSpecific import net.corda.testing.kryoSpecific
import org.assertj.core.api.Assertions import org.assertj.core.api.Assertions
import org.bouncycastle.asn1.x500.X500Name
import org.junit.Assert.assertArrayEquals import org.junit.Assert.assertArrayEquals
import org.junit.Test import org.junit.Test
import org.bouncycastle.asn1.x500.X500Name
import java.io.ByteArrayOutputStream import java.io.ByteArrayOutputStream
import java.util.* import java.util.*
import kotlin.test.assertEquals
class MapsSerializationTest : TestDependencyInjectionBase() { class MapsSerializationTest : TestDependencyInjectionBase() {
private companion object { private companion object {
@ -33,8 +35,9 @@ class MapsSerializationTest : TestDependencyInjectionBase() {
@Test @Test
fun `check list can be serialized as part of SessionData`() { fun `check list can be serialized as part of SessionData`() {
val sessionData = SessionData(123, smallMap) val sessionData = SessionData(123, smallMap.serialize())
assertEqualAfterRoundTripSerialization(sessionData) assertEqualAfterRoundTripSerialization(sessionData)
assertEquals(smallMap, sessionData.payload.deserialize())
} }
@CordaSerializable @CordaSerializable

View File

@ -2,11 +2,13 @@ package net.corda.nodeapi.internal.serialization
import com.esotericsoftware.kryo.Kryo import com.esotericsoftware.kryo.Kryo
import com.esotericsoftware.kryo.util.DefaultClassResolver import com.esotericsoftware.kryo.util.DefaultClassResolver
import net.corda.core.serialization.deserialize
import net.corda.core.serialization.serialize import net.corda.core.serialization.serialize
import net.corda.node.services.statemachine.SessionData import net.corda.node.services.statemachine.SessionData
import net.corda.testing.TestDependencyInjectionBase import net.corda.testing.TestDependencyInjectionBase
import net.corda.testing.kryoSpecific import net.corda.testing.kryoSpecific
import org.junit.Assert.* import org.junit.Assert.assertArrayEquals
import org.junit.Assert.assertEquals
import org.junit.Test import org.junit.Test
import java.io.ByteArrayOutputStream import java.io.ByteArrayOutputStream
import java.util.* import java.util.*
@ -26,16 +28,19 @@ class SetsSerializationTest : TestDependencyInjectionBase() {
@Test @Test
fun `check set can be serialized as part of SessionData`() { fun `check set can be serialized as part of SessionData`() {
run { run {
val sessionData = SessionData(123, setOf(1)) val sessionData = SessionData(123, setOf(1).serialize())
assertEqualAfterRoundTripSerialization(sessionData) assertEqualAfterRoundTripSerialization(sessionData)
assertEquals(setOf(1), sessionData.payload.deserialize())
} }
run { run {
val sessionData = SessionData(123, setOf(1, 2)) val sessionData = SessionData(123, setOf(1, 2).serialize())
assertEqualAfterRoundTripSerialization(sessionData) assertEqualAfterRoundTripSerialization(sessionData)
assertEquals(setOf(1, 2), sessionData.payload.deserialize())
} }
run { run {
val sessionData = SessionData(123, emptySet<Int>()) val sessionData = SessionData(123, emptySet<Int>().serialize())
assertEqualAfterRoundTripSerialization(sessionData) assertEqualAfterRoundTripSerialization(sessionData)
assertEquals(emptySet<Int>(), sessionData.payload.deserialize())
} }
} }

View File

@ -15,6 +15,8 @@ import net.corda.core.identity.PartyAndCertificate
import net.corda.core.internal.* import net.corda.core.internal.*
import net.corda.core.internal.concurrent.OpenFuture import net.corda.core.internal.concurrent.OpenFuture
import net.corda.core.internal.concurrent.openFuture import net.corda.core.internal.concurrent.openFuture
import net.corda.core.serialization.SerializationDefaults
import net.corda.core.serialization.serialize
import net.corda.core.transactions.SignedTransaction import net.corda.core.transactions.SignedTransaction
import net.corda.core.utilities.* import net.corda.core.utilities.*
import net.corda.node.services.api.FlowAppAuditEvent import net.corda.node.services.api.FlowAppAuditEvent
@ -327,7 +329,7 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
is FlowSessionState.Initiated -> sessionState.peerSessionId is FlowSessionState.Initiated -> sessionState.peerSessionId
else -> throw IllegalStateException("We've somehow held onto a non-initiated session: $session") else -> throw IllegalStateException("We've somehow held onto a non-initiated session: $session")
} }
return SessionData(peerSessionId, payload) return SessionData(peerSessionId, payload.serialize(context = SerializationDefaults.P2P_CONTEXT))
} }
@Suspendable @Suspendable
@ -389,7 +391,8 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
session.state = FlowSessionState.Initiating(state.otherParty) session.state = FlowSessionState.Initiating(state.otherParty)
session.retryable = retryable session.retryable = retryable
val (version, initiatingFlowClass) = session.flow.javaClass.flowVersionAndInitiatingClass val (version, initiatingFlowClass) = session.flow.javaClass.flowVersionAndInitiatingClass
val sessionInit = SessionInit(session.ourSessionId, initiatingFlowClass.name, version, session.flow.javaClass.appName, firstPayload) val payloadBytes = firstPayload?.serialize(context = SerializationDefaults.P2P_CONTEXT)
val sessionInit = SessionInit(session.ourSessionId, initiatingFlowClass.name, version, session.flow.javaClass.appName, payloadBytes)
sendInternal(session, sessionInit) sendInternal(session, sessionInit)
if (waitForConfirmation) { if (waitForConfirmation) {
session.waitForConfirmation() session.waitForConfirmation()

View File

@ -5,7 +5,10 @@ import net.corda.core.flows.UnexpectedFlowEndException
import net.corda.core.identity.Party import net.corda.core.identity.Party
import net.corda.core.internal.castIfPossible import net.corda.core.internal.castIfPossible
import net.corda.core.serialization.CordaSerializable import net.corda.core.serialization.CordaSerializable
import net.corda.core.serialization.SerializationDefaults
import net.corda.core.serialization.SerializedBytes
import net.corda.core.utilities.UntrustworthyData import net.corda.core.utilities.UntrustworthyData
import java.io.IOException
@CordaSerializable @CordaSerializable
interface SessionMessage interface SessionMessage
@ -25,7 +28,7 @@ data class SessionInit(val initiatorSessionId: Long,
val initiatingFlowClass: String, val initiatingFlowClass: String,
val flowVersion: Int, val flowVersion: Int,
val appName: String, val appName: String,
val firstPayload: Any?) : SessionMessage val firstPayload: SerializedBytes<Any>?) : SessionMessage
data class SessionConfirm(override val initiatorSessionId: Long, data class SessionConfirm(override val initiatorSessionId: Long,
val initiatedSessionId: Long, val initiatedSessionId: Long,
@ -34,7 +37,7 @@ data class SessionConfirm(override val initiatorSessionId: Long,
data class SessionReject(override val initiatorSessionId: Long, val errorMessage: String) : SessionInitResponse data class SessionReject(override val initiatorSessionId: Long, val errorMessage: String) : SessionInitResponse
data class SessionData(override val recipientSessionId: Long, val payload: Any) : ExistingSessionMessage data class SessionData(override val recipientSessionId: Long, val payload: SerializedBytes<Any>) : ExistingSessionMessage
data class NormalSessionEnd(override val recipientSessionId: Long) : SessionEnd data class NormalSessionEnd(override val recipientSessionId: Long) : SessionEnd
@ -42,8 +45,15 @@ data class ErrorSessionEnd(override val recipientSessionId: Long, val errorRespo
data class ReceivedSessionMessage<out M : ExistingSessionMessage>(val sender: Party, val message: M) data class ReceivedSessionMessage<out M : ExistingSessionMessage>(val sender: Party, val message: M)
fun <T> ReceivedSessionMessage<SessionData>.checkPayloadIs(type: Class<T>): UntrustworthyData<T> { fun <T : Any> ReceivedSessionMessage<SessionData>.checkPayloadIs(type: Class<T>): UntrustworthyData<T> {
return type.castIfPossible(message.payload)?.let { UntrustworthyData(it) } ?: val payloadData: T = try {
throw UnexpectedFlowEndException("We were expecting a ${type.name} from $sender but we instead got a " + val serializer = SerializationDefaults.SERIALIZATION_FACTORY
"${message.payload.javaClass.name} (${message.payload})") serializer.deserialize<T>(message.payload, type, SerializationDefaults.P2P_CONTEXT)
} catch (ex: Exception) {
throw IOException("Payload invalid", ex)
}
return type.castIfPossible(payloadData)?.let { UntrustworthyData(it) } ?:
throw UnexpectedFlowEndException("We were expecting a ${type.name} from $sender but we instead got a " +
"${payloadData.javaClass.name} (${payloadData})")
} }

View File

@ -15,11 +15,7 @@ import net.corda.core.crypto.SecureHash
import net.corda.core.crypto.random63BitValue import net.corda.core.crypto.random63BitValue
import net.corda.core.flows.* import net.corda.core.flows.*
import net.corda.core.identity.Party import net.corda.core.identity.Party
import net.corda.core.internal.FlowStateMachine import net.corda.core.internal.*
import net.corda.core.internal.ThreadBox
import net.corda.core.internal.bufferUntilSubscribed
import net.corda.core.internal.castIfPossible
import net.corda.core.internal.uncheckedCast
import net.corda.core.messaging.DataFeed import net.corda.core.messaging.DataFeed
import net.corda.core.serialization.SerializationDefaults.CHECKPOINT_CONTEXT import net.corda.core.serialization.SerializationDefaults.CHECKPOINT_CONTEXT
import net.corda.core.serialization.SerializationDefaults.SERIALIZATION_FACTORY import net.corda.core.serialization.SerializationDefaults.SERIALIZATION_FACTORY
@ -290,7 +286,12 @@ class StateMachineManager(val serviceHub: ServiceHubInternal,
} }
private fun onSessionMessage(message: ReceivedMessage) { private fun onSessionMessage(message: ReceivedMessage) {
val sessionMessage = message.data.deserialize<SessionMessage>() val sessionMessage = try {
message.data.deserialize<SessionMessage>()
} catch (ex: Exception) {
logger.error("Received corrupt SessionMessage data from ${message.peer}")
return
}
val sender = serviceHub.networkMapCache.getPeerByLegalName(message.peer) val sender = serviceHub.networkMapCache.getPeerByLegalName(message.peer)
if (sender != null) { if (sender != null) {
when (sessionMessage) { when (sessionMessage) {
@ -382,12 +383,7 @@ class StateMachineManager(val serviceHub: ServiceHubInternal,
updateCheckpoint(fiber) updateCheckpoint(fiber)
session to initiatedFlowFactory session to initiatedFlowFactory
} catch (e: SessionRejectException) { } catch (e: SessionRejectException) {
// TODO: Handle this more gracefully
try {
logger.warn("${e.logMessage}: $sessionInit") logger.warn("${e.logMessage}: $sessionInit")
} catch (e: Throwable) {
logger.warn("Problematic session init message during logging", e)
}
sendSessionReject(e.rejectMessage) sendSessionReject(e.rejectMessage)
return return
} catch (e: Exception) { } catch (e: Exception) {

View File

@ -710,11 +710,11 @@ class FlowFrameworkTests {
} }
private fun sessionInit(clientFlowClass: KClass<out FlowLogic<*>>, flowVersion: Int = 1, payload: Any? = null): SessionInit { private fun sessionInit(clientFlowClass: KClass<out FlowLogic<*>>, flowVersion: Int = 1, payload: Any? = null): SessionInit {
return SessionInit(0, clientFlowClass.java.name, flowVersion, "", payload) return SessionInit(0, clientFlowClass.java.name, flowVersion, "", payload?.serialize())
} }
private fun sessionConfirm(flowVersion: Int = 1) = SessionConfirm(0, 0, flowVersion, "") private fun sessionConfirm(flowVersion: Int = 1) = SessionConfirm(0, 0, flowVersion, "")
private fun sessionData(payload: Any) = SessionData(0, payload) private fun sessionData(payload: Any) = SessionData(0, payload.serialize())
private val normalEnd = NormalSessionEnd(0) private val normalEnd = NormalSessionEnd(0)
private fun erroredEnd(errorResponse: FlowException? = null) = ErrorSessionEnd(0, errorResponse) private fun erroredEnd(errorResponse: FlowException? = null) = ErrorSessionEnd(0, errorResponse)