From 8faf72f7b50ba087617ba79ff27a596c85a703cf Mon Sep 17 00:00:00 2001 From: Dimos Raptis Date: Mon, 20 Apr 2020 08:09:38 +0100 Subject: [PATCH] [ENT-5210] - Whitelist SNAPPY encoding (#6163) * [ENT-5210] - Whitelist SNAPPY encoding * Remove unused imports --- .../kotlin/net/corda/core/flows/FlowLogic.kt | 17 ++--------------- .../net/corda/core/internal/FlowStateMachine.kt | 3 +++ .../statemachine/FlowStateMachineImpl.kt | 11 +++++++++++ .../internal/SerializationScheme.kt | 8 +++++++- 4 files changed, 23 insertions(+), 16 deletions(-) diff --git a/core/src/main/kotlin/net/corda/core/flows/FlowLogic.kt b/core/src/main/kotlin/net/corda/core/flows/FlowLogic.kt index 8e4abdb05a..b63a235468 100644 --- a/core/src/main/kotlin/net/corda/core/flows/FlowLogic.kt +++ b/core/src/main/kotlin/net/corda/core/flows/FlowLogic.kt @@ -24,9 +24,6 @@ import net.corda.core.messaging.DataFeed import net.corda.core.node.NodeInfo import net.corda.core.node.ServiceHub import net.corda.core.serialization.CordaSerializable -import net.corda.core.serialization.SerializationDefaults -import net.corda.core.serialization.SerializedBytes -import net.corda.core.serialization.serialize import net.corda.core.transactions.SignedTransaction import net.corda.core.utilities.ProgressTracker import net.corda.core.utilities.UntrustworthyData @@ -267,7 +264,7 @@ abstract class FlowLogic { @Suspendable internal fun FlowSession.sendAndReceiveWithRetry(receiveType: Class, payload: Any): UntrustworthyData { val request = FlowIORequest.SendAndReceive( - sessionToMessage = mapOf(this to payload.serialize(context = SerializationDefaults.P2P_CONTEXT)), + sessionToMessage = stateMachine.serialize(mapOf(this to payload)), shouldRetrySend = true ) return stateMachine.suspend(request, maySkipCheckpoint = false)[this]!!.checkPayloadIs(receiveType) @@ -350,21 +347,11 @@ abstract class FlowLogic { @JvmOverloads fun sendAllMap(payloadsPerSession: Map, maySkipCheckpoint: Boolean = false) { val request = FlowIORequest.Send( - sessionToMessage = serializePayloads(payloadsPerSession) + sessionToMessage = stateMachine.serialize(payloadsPerSession) ) stateMachine.suspend(request, maySkipCheckpoint) } - @Suspendable - private fun serializePayloads(payloadsPerSession: Map): Map> { - val cachedSerializedPayloads = mutableMapOf>() - - return payloadsPerSession.mapValues { (_, payload) -> - cachedSerializedPayloads[payload] ?: payload.serialize(context = SerializationDefaults.P2P_CONTEXT).also { cachedSerializedPayloads[payload] = it } - } - } - - /** * Invokes the given subflow. This function returns once the subflow completes successfully with the result * returned by that subflow's [call] method. If the subflow has a progress tracker, it is attached to the diff --git a/core/src/main/kotlin/net/corda/core/internal/FlowStateMachine.kt b/core/src/main/kotlin/net/corda/core/internal/FlowStateMachine.kt index c057efa31e..443647fc1f 100644 --- a/core/src/main/kotlin/net/corda/core/internal/FlowStateMachine.kt +++ b/core/src/main/kotlin/net/corda/core/internal/FlowStateMachine.kt @@ -8,6 +8,7 @@ import net.corda.core.context.InvocationContext import net.corda.core.flows.* import net.corda.core.identity.Party import net.corda.core.node.ServiceHub +import net.corda.core.serialization.SerializedBytes import org.slf4j.Logger /** This is an internal interface that is implemented by code in the node module. You should look at [FlowLogic]. */ @@ -17,6 +18,8 @@ interface FlowStateMachine { @Suspendable fun suspend(ioRequest: FlowIORequest, maySkipCheckpoint: Boolean): SUSPENDRETURN + fun serialize(payloads: Map): Map> + @Suspendable fun initiateFlow(destination: Destination, wellKnownParty: Party): FlowSession diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/FlowStateMachineImpl.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/FlowStateMachineImpl.kt index 1791afc7e5..6d2c95efa9 100644 --- a/node/src/main/kotlin/net/corda/node/services/statemachine/FlowStateMachineImpl.kt +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/FlowStateMachineImpl.kt @@ -13,8 +13,11 @@ import net.corda.core.flows.* import net.corda.core.identity.AnonymousParty import net.corda.core.identity.Party import net.corda.core.internal.* +import net.corda.core.serialization.SerializationDefaults +import net.corda.core.serialization.SerializedBytes import net.corda.core.serialization.internal.CheckpointSerializationContext import net.corda.core.serialization.internal.checkpointSerialize +import net.corda.core.serialization.serialize import net.corda.core.utilities.ProgressTracker import net.corda.core.utilities.Try import net.corda.core.utilities.debug @@ -429,6 +432,14 @@ class FlowStateMachineImpl(override val id: StateMachineRunId, FlowStackSnapshotFactory.instance.persistAsJsonFile(flowClass, serviceHub.configuration.baseDirectory, id) } + override fun serialize(payloads: Map): Map> { + val cachedSerializedPayloads = mutableMapOf>() + + return payloads.mapValues { (_, payload) -> + cachedSerializedPayloads[payload] ?: payload.serialize(context = SerializationDefaults.P2P_CONTEXT).also { cachedSerializedPayloads[payload] = it } + } + } + @Suspendable override fun suspend(ioRequest: FlowIORequest, maySkipCheckpoint: Boolean): R { val serializationContext = TransientReference(getTransientField(TransientValues::checkpointSerializationContext)) diff --git a/serialization/src/main/kotlin/net/corda/serialization/internal/SerializationScheme.kt b/serialization/src/main/kotlin/net/corda/serialization/internal/SerializationScheme.kt index 102875b6fa..2447ed9642 100644 --- a/serialization/src/main/kotlin/net/corda/serialization/internal/SerializationScheme.kt +++ b/serialization/src/main/kotlin/net/corda/serialization/internal/SerializationScheme.kt @@ -17,6 +17,12 @@ internal object NullEncodingWhitelist : EncodingWhitelist { override fun acceptEncoding(encoding: SerializationEncoding) = false } +internal object SnappyEncodingWhitelist: EncodingWhitelist { + override fun acceptEncoding(encoding: SerializationEncoding): Boolean { + return encoding == CordaSerializationEncoding.SNAPPY + } +} + @KeepForDJVM data class SerializationContextImpl @JvmOverloads constructor(override val preferredSerializationVersion: SerializationMagic, override val deserializationClassLoader: ClassLoader, @@ -25,7 +31,7 @@ data class SerializationContextImpl @JvmOverloads constructor(override val prefe override val objectReferencesEnabled: Boolean, override val useCase: SerializationContext.UseCase, override val encoding: SerializationEncoding?, - override val encodingWhitelist: EncodingWhitelist = NullEncodingWhitelist, + override val encodingWhitelist: EncodingWhitelist = SnappyEncodingWhitelist, override val lenientCarpenterEnabled: Boolean = false, override val carpenterDisabled: Boolean = false, override val preventDataLoss: Boolean = false,