Merge pull request #1853 from corda/mnesbit-protect-sessionmessage

Protect the serialization/deserialization code of SessionMessage data
This commit is contained in:
Matthew Nesbit 2017-10-11 15:53:03 +01:00 committed by GitHub
commit 692b57d845
7 changed files with 53 additions and 32 deletions

View File

@ -11,11 +11,12 @@ import net.corda.testing.TestDependencyInjectionBase
import net.corda.testing.amqpSpecific
import net.corda.testing.kryoSpecific
import org.assertj.core.api.Assertions
import org.junit.Assert.*
import org.junit.Assert.assertArrayEquals
import org.junit.Assert.assertEquals
import org.junit.Test
import java.io.ByteArrayOutputStream
import java.io.NotSerializableException
import java.nio.charset.StandardCharsets.*
import java.nio.charset.StandardCharsets.US_ASCII
import java.util.*
class ListsSerializationTest : TestDependencyInjectionBase() {
@ -40,16 +41,19 @@ class ListsSerializationTest : TestDependencyInjectionBase() {
@Test
fun `check list can be serialized as part of SessionData`() {
run {
val sessionData = SessionData(123, listOf(1))
val sessionData = SessionData(123, listOf(1).serialize())
assertEqualAfterRoundTripSerialization(sessionData)
assertEquals(listOf(1), sessionData.payload.deserialize())
}
run {
val sessionData = SessionData(123, listOf(1, 2))
val sessionData = SessionData(123, listOf(1, 2).serialize())
assertEqualAfterRoundTripSerialization(sessionData)
assertEquals(listOf(1, 2), sessionData.payload.deserialize())
}
run {
val sessionData = SessionData(123, emptyList<Int>())
val sessionData = SessionData(123, emptyList<Int>().serialize())
assertEqualAfterRoundTripSerialization(sessionData)
assertEquals(emptyList<Int>(), sessionData.payload.deserialize())
}
}

View File

@ -3,17 +3,19 @@ package net.corda.nodeapi.internal.serialization
import com.esotericsoftware.kryo.Kryo
import com.esotericsoftware.kryo.util.DefaultClassResolver
import net.corda.core.serialization.CordaSerializable
import net.corda.core.serialization.deserialize
import net.corda.core.serialization.serialize
import net.corda.node.services.statemachine.SessionData
import net.corda.testing.TestDependencyInjectionBase
import net.corda.testing.amqpSpecific
import net.corda.testing.kryoSpecific
import org.assertj.core.api.Assertions
import org.bouncycastle.asn1.x500.X500Name
import org.junit.Assert.assertArrayEquals
import org.junit.Test
import org.bouncycastle.asn1.x500.X500Name
import java.io.ByteArrayOutputStream
import java.util.*
import kotlin.test.assertEquals
class MapsSerializationTest : TestDependencyInjectionBase() {
private companion object {
@ -33,8 +35,9 @@ class MapsSerializationTest : TestDependencyInjectionBase() {
@Test
fun `check list can be serialized as part of SessionData`() {
val sessionData = SessionData(123, smallMap)
val sessionData = SessionData(123, smallMap.serialize())
assertEqualAfterRoundTripSerialization(sessionData)
assertEquals(smallMap, sessionData.payload.deserialize())
}
@CordaSerializable

View File

@ -2,11 +2,13 @@ package net.corda.nodeapi.internal.serialization
import com.esotericsoftware.kryo.Kryo
import com.esotericsoftware.kryo.util.DefaultClassResolver
import net.corda.core.serialization.deserialize
import net.corda.core.serialization.serialize
import net.corda.node.services.statemachine.SessionData
import net.corda.testing.TestDependencyInjectionBase
import net.corda.testing.kryoSpecific
import org.junit.Assert.*
import org.junit.Assert.assertArrayEquals
import org.junit.Assert.assertEquals
import org.junit.Test
import java.io.ByteArrayOutputStream
import java.util.*
@ -26,16 +28,19 @@ class SetsSerializationTest : TestDependencyInjectionBase() {
@Test
fun `check set can be serialized as part of SessionData`() {
run {
val sessionData = SessionData(123, setOf(1))
val sessionData = SessionData(123, setOf(1).serialize())
assertEqualAfterRoundTripSerialization(sessionData)
assertEquals(setOf(1), sessionData.payload.deserialize())
}
run {
val sessionData = SessionData(123, setOf(1, 2))
val sessionData = SessionData(123, setOf(1, 2).serialize())
assertEqualAfterRoundTripSerialization(sessionData)
assertEquals(setOf(1, 2), sessionData.payload.deserialize())
}
run {
val sessionData = SessionData(123, emptySet<Int>())
val sessionData = SessionData(123, emptySet<Int>().serialize())
assertEqualAfterRoundTripSerialization(sessionData)
assertEquals(emptySet<Int>(), sessionData.payload.deserialize())
}
}

View File

@ -15,6 +15,8 @@ import net.corda.core.identity.PartyAndCertificate
import net.corda.core.internal.*
import net.corda.core.internal.concurrent.OpenFuture
import net.corda.core.internal.concurrent.openFuture
import net.corda.core.serialization.SerializationDefaults
import net.corda.core.serialization.serialize
import net.corda.core.transactions.SignedTransaction
import net.corda.core.utilities.*
import net.corda.node.services.api.FlowAppAuditEvent
@ -327,7 +329,7 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
is FlowSessionState.Initiated -> sessionState.peerSessionId
else -> throw IllegalStateException("We've somehow held onto a non-initiated session: $session")
}
return SessionData(peerSessionId, payload)
return SessionData(peerSessionId, payload.serialize(context = SerializationDefaults.P2P_CONTEXT))
}
@Suspendable
@ -389,7 +391,8 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
session.state = FlowSessionState.Initiating(state.otherParty)
session.retryable = retryable
val (version, initiatingFlowClass) = session.flow.javaClass.flowVersionAndInitiatingClass
val sessionInit = SessionInit(session.ourSessionId, initiatingFlowClass.name, version, session.flow.javaClass.appName, firstPayload)
val payloadBytes = firstPayload?.serialize(context = SerializationDefaults.P2P_CONTEXT)
val sessionInit = SessionInit(session.ourSessionId, initiatingFlowClass.name, version, session.flow.javaClass.appName, payloadBytes)
sendInternal(session, sessionInit)
if (waitForConfirmation) {
session.waitForConfirmation()

View File

@ -5,7 +5,10 @@ import net.corda.core.flows.UnexpectedFlowEndException
import net.corda.core.identity.Party
import net.corda.core.internal.castIfPossible
import net.corda.core.serialization.CordaSerializable
import net.corda.core.serialization.SerializationDefaults
import net.corda.core.serialization.SerializedBytes
import net.corda.core.utilities.UntrustworthyData
import java.io.IOException
@CordaSerializable
interface SessionMessage
@ -25,7 +28,7 @@ data class SessionInit(val initiatorSessionId: Long,
val initiatingFlowClass: String,
val flowVersion: Int,
val appName: String,
val firstPayload: Any?) : SessionMessage
val firstPayload: SerializedBytes<Any>?) : SessionMessage
data class SessionConfirm(override val initiatorSessionId: Long,
val initiatedSessionId: Long,
@ -34,7 +37,7 @@ data class SessionConfirm(override val initiatorSessionId: Long,
data class SessionReject(override val initiatorSessionId: Long, val errorMessage: String) : SessionInitResponse
data class SessionData(override val recipientSessionId: Long, val payload: Any) : ExistingSessionMessage
data class SessionData(override val recipientSessionId: Long, val payload: SerializedBytes<Any>) : ExistingSessionMessage
data class NormalSessionEnd(override val recipientSessionId: Long) : SessionEnd
@ -42,8 +45,15 @@ data class ErrorSessionEnd(override val recipientSessionId: Long, val errorRespo
data class ReceivedSessionMessage<out M : ExistingSessionMessage>(val sender: Party, val message: M)
fun <T> ReceivedSessionMessage<SessionData>.checkPayloadIs(type: Class<T>): UntrustworthyData<T> {
return type.castIfPossible(message.payload)?.let { UntrustworthyData(it) } ?:
fun <T : Any> ReceivedSessionMessage<SessionData>.checkPayloadIs(type: Class<T>): UntrustworthyData<T> {
val payloadData: T = try {
val serializer = SerializationDefaults.SERIALIZATION_FACTORY
serializer.deserialize<T>(message.payload, type, SerializationDefaults.P2P_CONTEXT)
} catch (ex: Exception) {
throw IOException("Payload invalid", ex)
}
return type.castIfPossible(payloadData)?.let { UntrustworthyData(it) } ?:
throw UnexpectedFlowEndException("We were expecting a ${type.name} from $sender but we instead got a " +
"${message.payload.javaClass.name} (${message.payload})")
"${payloadData.javaClass.name} (${payloadData})")
}

View File

@ -15,11 +15,7 @@ import net.corda.core.crypto.SecureHash
import net.corda.core.crypto.random63BitValue
import net.corda.core.flows.*
import net.corda.core.identity.Party
import net.corda.core.internal.FlowStateMachine
import net.corda.core.internal.ThreadBox
import net.corda.core.internal.bufferUntilSubscribed
import net.corda.core.internal.castIfPossible
import net.corda.core.internal.uncheckedCast
import net.corda.core.internal.*
import net.corda.core.messaging.DataFeed
import net.corda.core.serialization.SerializationDefaults.CHECKPOINT_CONTEXT
import net.corda.core.serialization.SerializationDefaults.SERIALIZATION_FACTORY
@ -290,7 +286,12 @@ class StateMachineManager(val serviceHub: ServiceHubInternal,
}
private fun onSessionMessage(message: ReceivedMessage) {
val sessionMessage = message.data.deserialize<SessionMessage>()
val sessionMessage = try {
message.data.deserialize<SessionMessage>()
} catch (ex: Exception) {
logger.error("Received corrupt SessionMessage data from ${message.peer}")
return
}
val sender = serviceHub.networkMapCache.getPeerByLegalName(message.peer)
if (sender != null) {
when (sessionMessage) {
@ -382,12 +383,7 @@ class StateMachineManager(val serviceHub: ServiceHubInternal,
updateCheckpoint(fiber)
session to initiatedFlowFactory
} catch (e: SessionRejectException) {
// TODO: Handle this more gracefully
try {
logger.warn("${e.logMessage}: $sessionInit")
} catch (e: Throwable) {
logger.warn("Problematic session init message during logging", e)
}
logger.warn("${e.logMessage}: $sessionInit")
sendSessionReject(e.rejectMessage)
return
} catch (e: Exception) {

View File

@ -710,11 +710,11 @@ class FlowFrameworkTests {
}
private fun sessionInit(clientFlowClass: KClass<out FlowLogic<*>>, flowVersion: Int = 1, payload: Any? = null): SessionInit {
return SessionInit(0, clientFlowClass.java.name, flowVersion, "", payload)
return SessionInit(0, clientFlowClass.java.name, flowVersion, "", payload?.serialize())
}
private fun sessionConfirm(flowVersion: Int = 1) = SessionConfirm(0, 0, flowVersion, "")
private fun sessionData(payload: Any) = SessionData(0, payload)
private fun sessionData(payload: Any) = SessionData(0, payload.serialize())
private val normalEnd = NormalSessionEnd(0)
private fun erroredEnd(errorResponse: FlowException? = null) = ErrorSessionEnd(0, errorResponse)