mirror of
https://github.com/corda/corda.git
synced 2024-12-22 06:17:55 +00:00
Change the serialization/deserialization code of SessionMessage data to add more validation.
Address PR comments As pointed out by Shams the SessionInit must be well formed at this point.
This commit is contained in:
parent
e232d111ea
commit
899f7f9d0d
@ -11,11 +11,12 @@ import net.corda.testing.TestDependencyInjectionBase
|
||||
import net.corda.testing.amqpSpecific
|
||||
import net.corda.testing.kryoSpecific
|
||||
import org.assertj.core.api.Assertions
|
||||
import org.junit.Assert.*
|
||||
import org.junit.Assert.assertArrayEquals
|
||||
import org.junit.Assert.assertEquals
|
||||
import org.junit.Test
|
||||
import java.io.ByteArrayOutputStream
|
||||
import java.io.NotSerializableException
|
||||
import java.nio.charset.StandardCharsets.*
|
||||
import java.nio.charset.StandardCharsets.US_ASCII
|
||||
import java.util.*
|
||||
|
||||
class ListsSerializationTest : TestDependencyInjectionBase() {
|
||||
@ -40,16 +41,19 @@ class ListsSerializationTest : TestDependencyInjectionBase() {
|
||||
@Test
|
||||
fun `check list can be serialized as part of SessionData`() {
|
||||
run {
|
||||
val sessionData = SessionData(123, listOf(1))
|
||||
val sessionData = SessionData(123, listOf(1).serialize())
|
||||
assertEqualAfterRoundTripSerialization(sessionData)
|
||||
assertEquals(listOf(1), sessionData.payload.deserialize())
|
||||
}
|
||||
run {
|
||||
val sessionData = SessionData(123, listOf(1, 2))
|
||||
val sessionData = SessionData(123, listOf(1, 2).serialize())
|
||||
assertEqualAfterRoundTripSerialization(sessionData)
|
||||
assertEquals(listOf(1, 2), sessionData.payload.deserialize())
|
||||
}
|
||||
run {
|
||||
val sessionData = SessionData(123, emptyList<Int>())
|
||||
val sessionData = SessionData(123, emptyList<Int>().serialize())
|
||||
assertEqualAfterRoundTripSerialization(sessionData)
|
||||
assertEquals(emptyList<Int>(), sessionData.payload.deserialize())
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -3,17 +3,19 @@ package net.corda.nodeapi.internal.serialization
|
||||
import com.esotericsoftware.kryo.Kryo
|
||||
import com.esotericsoftware.kryo.util.DefaultClassResolver
|
||||
import net.corda.core.serialization.CordaSerializable
|
||||
import net.corda.core.serialization.deserialize
|
||||
import net.corda.core.serialization.serialize
|
||||
import net.corda.node.services.statemachine.SessionData
|
||||
import net.corda.testing.TestDependencyInjectionBase
|
||||
import net.corda.testing.amqpSpecific
|
||||
import net.corda.testing.kryoSpecific
|
||||
import org.assertj.core.api.Assertions
|
||||
import org.bouncycastle.asn1.x500.X500Name
|
||||
import org.junit.Assert.assertArrayEquals
|
||||
import org.junit.Test
|
||||
import org.bouncycastle.asn1.x500.X500Name
|
||||
import java.io.ByteArrayOutputStream
|
||||
import java.util.*
|
||||
import kotlin.test.assertEquals
|
||||
|
||||
class MapsSerializationTest : TestDependencyInjectionBase() {
|
||||
private companion object {
|
||||
@ -33,8 +35,9 @@ class MapsSerializationTest : TestDependencyInjectionBase() {
|
||||
|
||||
@Test
|
||||
fun `check list can be serialized as part of SessionData`() {
|
||||
val sessionData = SessionData(123, smallMap)
|
||||
val sessionData = SessionData(123, smallMap.serialize())
|
||||
assertEqualAfterRoundTripSerialization(sessionData)
|
||||
assertEquals(smallMap, sessionData.payload.deserialize())
|
||||
}
|
||||
|
||||
@CordaSerializable
|
||||
|
@ -2,11 +2,13 @@ package net.corda.nodeapi.internal.serialization
|
||||
|
||||
import com.esotericsoftware.kryo.Kryo
|
||||
import com.esotericsoftware.kryo.util.DefaultClassResolver
|
||||
import net.corda.core.serialization.deserialize
|
||||
import net.corda.core.serialization.serialize
|
||||
import net.corda.node.services.statemachine.SessionData
|
||||
import net.corda.testing.TestDependencyInjectionBase
|
||||
import net.corda.testing.kryoSpecific
|
||||
import org.junit.Assert.*
|
||||
import org.junit.Assert.assertArrayEquals
|
||||
import org.junit.Assert.assertEquals
|
||||
import org.junit.Test
|
||||
import java.io.ByteArrayOutputStream
|
||||
import java.util.*
|
||||
@ -26,16 +28,19 @@ class SetsSerializationTest : TestDependencyInjectionBase() {
|
||||
@Test
|
||||
fun `check set can be serialized as part of SessionData`() {
|
||||
run {
|
||||
val sessionData = SessionData(123, setOf(1))
|
||||
val sessionData = SessionData(123, setOf(1).serialize())
|
||||
assertEqualAfterRoundTripSerialization(sessionData)
|
||||
assertEquals(setOf(1), sessionData.payload.deserialize())
|
||||
}
|
||||
run {
|
||||
val sessionData = SessionData(123, setOf(1, 2))
|
||||
val sessionData = SessionData(123, setOf(1, 2).serialize())
|
||||
assertEqualAfterRoundTripSerialization(sessionData)
|
||||
assertEquals(setOf(1, 2), sessionData.payload.deserialize())
|
||||
}
|
||||
run {
|
||||
val sessionData = SessionData(123, emptySet<Int>())
|
||||
val sessionData = SessionData(123, emptySet<Int>().serialize())
|
||||
assertEqualAfterRoundTripSerialization(sessionData)
|
||||
assertEquals(emptySet<Int>(), sessionData.payload.deserialize())
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -15,6 +15,8 @@ import net.corda.core.identity.PartyAndCertificate
|
||||
import net.corda.core.internal.*
|
||||
import net.corda.core.internal.concurrent.OpenFuture
|
||||
import net.corda.core.internal.concurrent.openFuture
|
||||
import net.corda.core.serialization.SerializationDefaults
|
||||
import net.corda.core.serialization.serialize
|
||||
import net.corda.core.transactions.SignedTransaction
|
||||
import net.corda.core.utilities.*
|
||||
import net.corda.node.services.api.FlowAppAuditEvent
|
||||
@ -327,7 +329,7 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
|
||||
is FlowSessionState.Initiated -> sessionState.peerSessionId
|
||||
else -> throw IllegalStateException("We've somehow held onto a non-initiated session: $session")
|
||||
}
|
||||
return SessionData(peerSessionId, payload)
|
||||
return SessionData(peerSessionId, payload.serialize(context = SerializationDefaults.P2P_CONTEXT))
|
||||
}
|
||||
|
||||
@Suspendable
|
||||
@ -389,7 +391,8 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
|
||||
session.state = FlowSessionState.Initiating(state.otherParty)
|
||||
session.retryable = retryable
|
||||
val (version, initiatingFlowClass) = session.flow.javaClass.flowVersionAndInitiatingClass
|
||||
val sessionInit = SessionInit(session.ourSessionId, initiatingFlowClass.name, version, session.flow.javaClass.appName, firstPayload)
|
||||
val payloadBytes = firstPayload?.serialize(context = SerializationDefaults.P2P_CONTEXT)
|
||||
val sessionInit = SessionInit(session.ourSessionId, initiatingFlowClass.name, version, session.flow.javaClass.appName, payloadBytes)
|
||||
sendInternal(session, sessionInit)
|
||||
if (waitForConfirmation) {
|
||||
session.waitForConfirmation()
|
||||
|
@ -5,7 +5,10 @@ import net.corda.core.flows.UnexpectedFlowEndException
|
||||
import net.corda.core.identity.Party
|
||||
import net.corda.core.internal.castIfPossible
|
||||
import net.corda.core.serialization.CordaSerializable
|
||||
import net.corda.core.serialization.SerializationDefaults
|
||||
import net.corda.core.serialization.SerializedBytes
|
||||
import net.corda.core.utilities.UntrustworthyData
|
||||
import java.io.IOException
|
||||
|
||||
@CordaSerializable
|
||||
interface SessionMessage
|
||||
@ -25,7 +28,7 @@ data class SessionInit(val initiatorSessionId: Long,
|
||||
val initiatingFlowClass: String,
|
||||
val flowVersion: Int,
|
||||
val appName: String,
|
||||
val firstPayload: Any?) : SessionMessage
|
||||
val firstPayload: SerializedBytes<Any>?) : SessionMessage
|
||||
|
||||
data class SessionConfirm(override val initiatorSessionId: Long,
|
||||
val initiatedSessionId: Long,
|
||||
@ -34,7 +37,7 @@ data class SessionConfirm(override val initiatorSessionId: Long,
|
||||
|
||||
data class SessionReject(override val initiatorSessionId: Long, val errorMessage: String) : SessionInitResponse
|
||||
|
||||
data class SessionData(override val recipientSessionId: Long, val payload: Any) : ExistingSessionMessage
|
||||
data class SessionData(override val recipientSessionId: Long, val payload: SerializedBytes<Any>) : ExistingSessionMessage
|
||||
|
||||
data class NormalSessionEnd(override val recipientSessionId: Long) : SessionEnd
|
||||
|
||||
@ -42,8 +45,15 @@ data class ErrorSessionEnd(override val recipientSessionId: Long, val errorRespo
|
||||
|
||||
data class ReceivedSessionMessage<out M : ExistingSessionMessage>(val sender: Party, val message: M)
|
||||
|
||||
fun <T> ReceivedSessionMessage<SessionData>.checkPayloadIs(type: Class<T>): UntrustworthyData<T> {
|
||||
return type.castIfPossible(message.payload)?.let { UntrustworthyData(it) } ?:
|
||||
throw UnexpectedFlowEndException("We were expecting a ${type.name} from $sender but we instead got a " +
|
||||
"${message.payload.javaClass.name} (${message.payload})")
|
||||
fun <T : Any> ReceivedSessionMessage<SessionData>.checkPayloadIs(type: Class<T>): UntrustworthyData<T> {
|
||||
val payloadData: T = try {
|
||||
val serializer = SerializationDefaults.SERIALIZATION_FACTORY
|
||||
serializer.deserialize<T>(message.payload, type, SerializationDefaults.P2P_CONTEXT)
|
||||
} catch (ex: Exception) {
|
||||
throw IOException("Payload invalid", ex)
|
||||
}
|
||||
return type.castIfPossible(payloadData)?.let { UntrustworthyData(it) } ?:
|
||||
throw UnexpectedFlowEndException("We were expecting a ${type.name} from $sender but we instead got a " +
|
||||
"${payloadData.javaClass.name} (${payloadData})")
|
||||
|
||||
}
|
||||
|
@ -15,11 +15,7 @@ import net.corda.core.crypto.SecureHash
|
||||
import net.corda.core.crypto.random63BitValue
|
||||
import net.corda.core.flows.*
|
||||
import net.corda.core.identity.Party
|
||||
import net.corda.core.internal.FlowStateMachine
|
||||
import net.corda.core.internal.ThreadBox
|
||||
import net.corda.core.internal.bufferUntilSubscribed
|
||||
import net.corda.core.internal.castIfPossible
|
||||
import net.corda.core.internal.uncheckedCast
|
||||
import net.corda.core.internal.*
|
||||
import net.corda.core.messaging.DataFeed
|
||||
import net.corda.core.serialization.SerializationDefaults.CHECKPOINT_CONTEXT
|
||||
import net.corda.core.serialization.SerializationDefaults.SERIALIZATION_FACTORY
|
||||
@ -290,7 +286,12 @@ class StateMachineManager(val serviceHub: ServiceHubInternal,
|
||||
}
|
||||
|
||||
private fun onSessionMessage(message: ReceivedMessage) {
|
||||
val sessionMessage = message.data.deserialize<SessionMessage>()
|
||||
val sessionMessage = try {
|
||||
message.data.deserialize<SessionMessage>()
|
||||
} catch (ex: Exception) {
|
||||
logger.error("Received corrupt SessionMessage data from ${message.peer}")
|
||||
return
|
||||
}
|
||||
val sender = serviceHub.networkMapCache.getPeerByLegalName(message.peer)
|
||||
if (sender != null) {
|
||||
when (sessionMessage) {
|
||||
@ -382,12 +383,7 @@ class StateMachineManager(val serviceHub: ServiceHubInternal,
|
||||
updateCheckpoint(fiber)
|
||||
session to initiatedFlowFactory
|
||||
} catch (e: SessionRejectException) {
|
||||
// TODO: Handle this more gracefully
|
||||
try {
|
||||
logger.warn("${e.logMessage}: $sessionInit")
|
||||
} catch (e: Throwable) {
|
||||
logger.warn("Problematic session init message during logging", e)
|
||||
}
|
||||
sendSessionReject(e.rejectMessage)
|
||||
return
|
||||
} catch (e: Exception) {
|
||||
|
@ -710,11 +710,11 @@ class FlowFrameworkTests {
|
||||
}
|
||||
|
||||
private fun sessionInit(clientFlowClass: KClass<out FlowLogic<*>>, flowVersion: Int = 1, payload: Any? = null): SessionInit {
|
||||
return SessionInit(0, clientFlowClass.java.name, flowVersion, "", payload)
|
||||
return SessionInit(0, clientFlowClass.java.name, flowVersion, "", payload?.serialize())
|
||||
}
|
||||
|
||||
private fun sessionConfirm(flowVersion: Int = 1) = SessionConfirm(0, 0, flowVersion, "")
|
||||
private fun sessionData(payload: Any) = SessionData(0, payload)
|
||||
private fun sessionData(payload: Any) = SessionData(0, payload.serialize())
|
||||
private val normalEnd = NormalSessionEnd(0)
|
||||
private fun erroredEnd(errorResponse: FlowException? = null) = ErrorSessionEnd(0, errorResponse)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user