Change the serialization/deserialization code of SessionMessage data to add more validation.

Address PR comments

As pointed out by Shams the SessionInit must be well formed at this point.
This commit is contained in:
Matthew Nesbit 2017-10-09 15:04:00 +01:00
parent e232d111ea
commit 899f7f9d0d
7 changed files with 53 additions and 32 deletions

View File

@ -11,11 +11,12 @@ import net.corda.testing.TestDependencyInjectionBase
import net.corda.testing.amqpSpecific
import net.corda.testing.kryoSpecific
import org.assertj.core.api.Assertions
import org.junit.Assert.*
import org.junit.Assert.assertArrayEquals
import org.junit.Assert.assertEquals
import org.junit.Test
import java.io.ByteArrayOutputStream
import java.io.NotSerializableException
import java.nio.charset.StandardCharsets.*
import java.nio.charset.StandardCharsets.US_ASCII
import java.util.*
class ListsSerializationTest : TestDependencyInjectionBase() {
@ -40,16 +41,19 @@ class ListsSerializationTest : TestDependencyInjectionBase() {
@Test
fun `check list can be serialized as part of SessionData`() {
run {
val sessionData = SessionData(123, listOf(1))
val sessionData = SessionData(123, listOf(1).serialize())
assertEqualAfterRoundTripSerialization(sessionData)
assertEquals(listOf(1), sessionData.payload.deserialize())
}
run {
val sessionData = SessionData(123, listOf(1, 2))
val sessionData = SessionData(123, listOf(1, 2).serialize())
assertEqualAfterRoundTripSerialization(sessionData)
assertEquals(listOf(1, 2), sessionData.payload.deserialize())
}
run {
val sessionData = SessionData(123, emptyList<Int>())
val sessionData = SessionData(123, emptyList<Int>().serialize())
assertEqualAfterRoundTripSerialization(sessionData)
assertEquals(emptyList<Int>(), sessionData.payload.deserialize())
}
}

View File

@ -3,17 +3,19 @@ package net.corda.nodeapi.internal.serialization
import com.esotericsoftware.kryo.Kryo
import com.esotericsoftware.kryo.util.DefaultClassResolver
import net.corda.core.serialization.CordaSerializable
import net.corda.core.serialization.deserialize
import net.corda.core.serialization.serialize
import net.corda.node.services.statemachine.SessionData
import net.corda.testing.TestDependencyInjectionBase
import net.corda.testing.amqpSpecific
import net.corda.testing.kryoSpecific
import org.assertj.core.api.Assertions
import org.bouncycastle.asn1.x500.X500Name
import org.junit.Assert.assertArrayEquals
import org.junit.Test
import org.bouncycastle.asn1.x500.X500Name
import java.io.ByteArrayOutputStream
import java.util.*
import kotlin.test.assertEquals
class MapsSerializationTest : TestDependencyInjectionBase() {
private companion object {
@ -33,8 +35,9 @@ class MapsSerializationTest : TestDependencyInjectionBase() {
@Test
fun `check list can be serialized as part of SessionData`() {
val sessionData = SessionData(123, smallMap)
val sessionData = SessionData(123, smallMap.serialize())
assertEqualAfterRoundTripSerialization(sessionData)
assertEquals(smallMap, sessionData.payload.deserialize())
}
@CordaSerializable

View File

@ -2,11 +2,13 @@ package net.corda.nodeapi.internal.serialization
import com.esotericsoftware.kryo.Kryo
import com.esotericsoftware.kryo.util.DefaultClassResolver
import net.corda.core.serialization.deserialize
import net.corda.core.serialization.serialize
import net.corda.node.services.statemachine.SessionData
import net.corda.testing.TestDependencyInjectionBase
import net.corda.testing.kryoSpecific
import org.junit.Assert.*
import org.junit.Assert.assertArrayEquals
import org.junit.Assert.assertEquals
import org.junit.Test
import java.io.ByteArrayOutputStream
import java.util.*
@ -26,16 +28,19 @@ class SetsSerializationTest : TestDependencyInjectionBase() {
@Test
fun `check set can be serialized as part of SessionData`() {
run {
val sessionData = SessionData(123, setOf(1))
val sessionData = SessionData(123, setOf(1).serialize())
assertEqualAfterRoundTripSerialization(sessionData)
assertEquals(setOf(1), sessionData.payload.deserialize())
}
run {
val sessionData = SessionData(123, setOf(1, 2))
val sessionData = SessionData(123, setOf(1, 2).serialize())
assertEqualAfterRoundTripSerialization(sessionData)
assertEquals(setOf(1, 2), sessionData.payload.deserialize())
}
run {
val sessionData = SessionData(123, emptySet<Int>())
val sessionData = SessionData(123, emptySet<Int>().serialize())
assertEqualAfterRoundTripSerialization(sessionData)
assertEquals(emptySet<Int>(), sessionData.payload.deserialize())
}
}

View File

@ -15,6 +15,8 @@ import net.corda.core.identity.PartyAndCertificate
import net.corda.core.internal.*
import net.corda.core.internal.concurrent.OpenFuture
import net.corda.core.internal.concurrent.openFuture
import net.corda.core.serialization.SerializationDefaults
import net.corda.core.serialization.serialize
import net.corda.core.transactions.SignedTransaction
import net.corda.core.utilities.*
import net.corda.node.services.api.FlowAppAuditEvent
@ -327,7 +329,7 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
is FlowSessionState.Initiated -> sessionState.peerSessionId
else -> throw IllegalStateException("We've somehow held onto a non-initiated session: $session")
}
return SessionData(peerSessionId, payload)
return SessionData(peerSessionId, payload.serialize(context = SerializationDefaults.P2P_CONTEXT))
}
@Suspendable
@ -389,7 +391,8 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
session.state = FlowSessionState.Initiating(state.otherParty)
session.retryable = retryable
val (version, initiatingFlowClass) = session.flow.javaClass.flowVersionAndInitiatingClass
val sessionInit = SessionInit(session.ourSessionId, initiatingFlowClass.name, version, session.flow.javaClass.appName, firstPayload)
val payloadBytes = firstPayload?.serialize(context = SerializationDefaults.P2P_CONTEXT)
val sessionInit = SessionInit(session.ourSessionId, initiatingFlowClass.name, version, session.flow.javaClass.appName, payloadBytes)
sendInternal(session, sessionInit)
if (waitForConfirmation) {
session.waitForConfirmation()

View File

@ -5,7 +5,10 @@ import net.corda.core.flows.UnexpectedFlowEndException
import net.corda.core.identity.Party
import net.corda.core.internal.castIfPossible
import net.corda.core.serialization.CordaSerializable
import net.corda.core.serialization.SerializationDefaults
import net.corda.core.serialization.SerializedBytes
import net.corda.core.utilities.UntrustworthyData
import java.io.IOException
@CordaSerializable
interface SessionMessage
@ -25,7 +28,7 @@ data class SessionInit(val initiatorSessionId: Long,
val initiatingFlowClass: String,
val flowVersion: Int,
val appName: String,
val firstPayload: Any?) : SessionMessage
val firstPayload: SerializedBytes<Any>?) : SessionMessage
data class SessionConfirm(override val initiatorSessionId: Long,
val initiatedSessionId: Long,
@ -34,7 +37,7 @@ data class SessionConfirm(override val initiatorSessionId: Long,
data class SessionReject(override val initiatorSessionId: Long, val errorMessage: String) : SessionInitResponse
data class SessionData(override val recipientSessionId: Long, val payload: Any) : ExistingSessionMessage
data class SessionData(override val recipientSessionId: Long, val payload: SerializedBytes<Any>) : ExistingSessionMessage
data class NormalSessionEnd(override val recipientSessionId: Long) : SessionEnd
@ -42,8 +45,15 @@ data class ErrorSessionEnd(override val recipientSessionId: Long, val errorRespo
data class ReceivedSessionMessage<out M : ExistingSessionMessage>(val sender: Party, val message: M)
fun <T> ReceivedSessionMessage<SessionData>.checkPayloadIs(type: Class<T>): UntrustworthyData<T> {
return type.castIfPossible(message.payload)?.let { UntrustworthyData(it) } ?:
fun <T : Any> ReceivedSessionMessage<SessionData>.checkPayloadIs(type: Class<T>): UntrustworthyData<T> {
val payloadData: T = try {
val serializer = SerializationDefaults.SERIALIZATION_FACTORY
serializer.deserialize<T>(message.payload, type, SerializationDefaults.P2P_CONTEXT)
} catch (ex: Exception) {
throw IOException("Payload invalid", ex)
}
return type.castIfPossible(payloadData)?.let { UntrustworthyData(it) } ?:
throw UnexpectedFlowEndException("We were expecting a ${type.name} from $sender but we instead got a " +
"${message.payload.javaClass.name} (${message.payload})")
"${payloadData.javaClass.name} (${payloadData})")
}

View File

@ -15,11 +15,7 @@ import net.corda.core.crypto.SecureHash
import net.corda.core.crypto.random63BitValue
import net.corda.core.flows.*
import net.corda.core.identity.Party
import net.corda.core.internal.FlowStateMachine
import net.corda.core.internal.ThreadBox
import net.corda.core.internal.bufferUntilSubscribed
import net.corda.core.internal.castIfPossible
import net.corda.core.internal.uncheckedCast
import net.corda.core.internal.*
import net.corda.core.messaging.DataFeed
import net.corda.core.serialization.SerializationDefaults.CHECKPOINT_CONTEXT
import net.corda.core.serialization.SerializationDefaults.SERIALIZATION_FACTORY
@ -290,7 +286,12 @@ class StateMachineManager(val serviceHub: ServiceHubInternal,
}
private fun onSessionMessage(message: ReceivedMessage) {
val sessionMessage = message.data.deserialize<SessionMessage>()
val sessionMessage = try {
message.data.deserialize<SessionMessage>()
} catch (ex: Exception) {
logger.error("Received corrupt SessionMessage data from ${message.peer}")
return
}
val sender = serviceHub.networkMapCache.getPeerByLegalName(message.peer)
if (sender != null) {
when (sessionMessage) {
@ -382,12 +383,7 @@ class StateMachineManager(val serviceHub: ServiceHubInternal,
updateCheckpoint(fiber)
session to initiatedFlowFactory
} catch (e: SessionRejectException) {
// TODO: Handle this more gracefully
try {
logger.warn("${e.logMessage}: $sessionInit")
} catch (e: Throwable) {
logger.warn("Problematic session init message during logging", e)
}
logger.warn("${e.logMessage}: $sessionInit")
sendSessionReject(e.rejectMessage)
return
} catch (e: Exception) {

View File

@ -710,11 +710,11 @@ class FlowFrameworkTests {
}
private fun sessionInit(clientFlowClass: KClass<out FlowLogic<*>>, flowVersion: Int = 1, payload: Any? = null): SessionInit {
return SessionInit(0, clientFlowClass.java.name, flowVersion, "", payload)
return SessionInit(0, clientFlowClass.java.name, flowVersion, "", payload?.serialize())
}
private fun sessionConfirm(flowVersion: Int = 1) = SessionConfirm(0, 0, flowVersion, "")
private fun sessionData(payload: Any) = SessionData(0, payload)
private fun sessionData(payload: Any) = SessionData(0, payload.serialize())
private val normalEnd = NormalSessionEnd(0)
private fun erroredEnd(errorResponse: FlowException? = null) = ErrorSessionEnd(0, errorResponse)