[ENT-5210] - Whitelist SNAPPY encoding (#6163)

* [ENT-5210] - Whitelist SNAPPY encoding

* Remove unused imports
This commit is contained in:
Dimos Raptis 2020-04-20 08:09:38 +01:00 committed by GitHub
parent 6f437b5b09
commit 8faf72f7b5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 23 additions and 16 deletions

View File

@ -24,9 +24,6 @@ import net.corda.core.messaging.DataFeed
import net.corda.core.node.NodeInfo import net.corda.core.node.NodeInfo
import net.corda.core.node.ServiceHub import net.corda.core.node.ServiceHub
import net.corda.core.serialization.CordaSerializable import net.corda.core.serialization.CordaSerializable
import net.corda.core.serialization.SerializationDefaults
import net.corda.core.serialization.SerializedBytes
import net.corda.core.serialization.serialize
import net.corda.core.transactions.SignedTransaction import net.corda.core.transactions.SignedTransaction
import net.corda.core.utilities.ProgressTracker import net.corda.core.utilities.ProgressTracker
import net.corda.core.utilities.UntrustworthyData import net.corda.core.utilities.UntrustworthyData
@ -267,7 +264,7 @@ abstract class FlowLogic<out T> {
@Suspendable @Suspendable
internal fun <R : Any> FlowSession.sendAndReceiveWithRetry(receiveType: Class<R>, payload: Any): UntrustworthyData<R> { internal fun <R : Any> FlowSession.sendAndReceiveWithRetry(receiveType: Class<R>, payload: Any): UntrustworthyData<R> {
val request = FlowIORequest.SendAndReceive( val request = FlowIORequest.SendAndReceive(
sessionToMessage = mapOf(this to payload.serialize(context = SerializationDefaults.P2P_CONTEXT)), sessionToMessage = stateMachine.serialize(mapOf(this to payload)),
shouldRetrySend = true shouldRetrySend = true
) )
return stateMachine.suspend(request, maySkipCheckpoint = false)[this]!!.checkPayloadIs(receiveType) return stateMachine.suspend(request, maySkipCheckpoint = false)[this]!!.checkPayloadIs(receiveType)
@ -350,21 +347,11 @@ abstract class FlowLogic<out T> {
@JvmOverloads @JvmOverloads
fun sendAllMap(payloadsPerSession: Map<FlowSession, Any>, maySkipCheckpoint: Boolean = false) { fun sendAllMap(payloadsPerSession: Map<FlowSession, Any>, maySkipCheckpoint: Boolean = false) {
val request = FlowIORequest.Send( val request = FlowIORequest.Send(
sessionToMessage = serializePayloads(payloadsPerSession) sessionToMessage = stateMachine.serialize(payloadsPerSession)
) )
stateMachine.suspend(request, maySkipCheckpoint) stateMachine.suspend(request, maySkipCheckpoint)
} }
@Suspendable
private fun serializePayloads(payloadsPerSession: Map<FlowSession, Any>): Map<FlowSession, SerializedBytes<Any>> {
val cachedSerializedPayloads = mutableMapOf<Any, SerializedBytes<Any>>()
return payloadsPerSession.mapValues { (_, payload) ->
cachedSerializedPayloads[payload] ?: payload.serialize(context = SerializationDefaults.P2P_CONTEXT).also { cachedSerializedPayloads[payload] = it }
}
}
/** /**
* Invokes the given subflow. This function returns once the subflow completes successfully with the result * Invokes the given subflow. This function returns once the subflow completes successfully with the result
* returned by that subflow's [call] method. If the subflow has a progress tracker, it is attached to the * returned by that subflow's [call] method. If the subflow has a progress tracker, it is attached to the

View File

@ -8,6 +8,7 @@ import net.corda.core.context.InvocationContext
import net.corda.core.flows.* import net.corda.core.flows.*
import net.corda.core.identity.Party import net.corda.core.identity.Party
import net.corda.core.node.ServiceHub import net.corda.core.node.ServiceHub
import net.corda.core.serialization.SerializedBytes
import org.slf4j.Logger import org.slf4j.Logger
/** This is an internal interface that is implemented by code in the node module. You should look at [FlowLogic]. */ /** This is an internal interface that is implemented by code in the node module. You should look at [FlowLogic]. */
@ -17,6 +18,8 @@ interface FlowStateMachine<FLOWRETURN> {
@Suspendable @Suspendable
fun <SUSPENDRETURN : Any> suspend(ioRequest: FlowIORequest<SUSPENDRETURN>, maySkipCheckpoint: Boolean): SUSPENDRETURN fun <SUSPENDRETURN : Any> suspend(ioRequest: FlowIORequest<SUSPENDRETURN>, maySkipCheckpoint: Boolean): SUSPENDRETURN
fun serialize(payloads: Map<FlowSession, Any>): Map<FlowSession, SerializedBytes<Any>>
@Suspendable @Suspendable
fun initiateFlow(destination: Destination, wellKnownParty: Party): FlowSession fun initiateFlow(destination: Destination, wellKnownParty: Party): FlowSession

View File

@ -13,8 +13,11 @@ import net.corda.core.flows.*
import net.corda.core.identity.AnonymousParty import net.corda.core.identity.AnonymousParty
import net.corda.core.identity.Party import net.corda.core.identity.Party
import net.corda.core.internal.* import net.corda.core.internal.*
import net.corda.core.serialization.SerializationDefaults
import net.corda.core.serialization.SerializedBytes
import net.corda.core.serialization.internal.CheckpointSerializationContext import net.corda.core.serialization.internal.CheckpointSerializationContext
import net.corda.core.serialization.internal.checkpointSerialize import net.corda.core.serialization.internal.checkpointSerialize
import net.corda.core.serialization.serialize
import net.corda.core.utilities.ProgressTracker import net.corda.core.utilities.ProgressTracker
import net.corda.core.utilities.Try import net.corda.core.utilities.Try
import net.corda.core.utilities.debug import net.corda.core.utilities.debug
@ -429,6 +432,14 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
FlowStackSnapshotFactory.instance.persistAsJsonFile(flowClass, serviceHub.configuration.baseDirectory, id) FlowStackSnapshotFactory.instance.persistAsJsonFile(flowClass, serviceHub.configuration.baseDirectory, id)
} }
override fun serialize(payloads: Map<FlowSession, Any>): Map<FlowSession, SerializedBytes<Any>> {
val cachedSerializedPayloads = mutableMapOf<Any, SerializedBytes<Any>>()
return payloads.mapValues { (_, payload) ->
cachedSerializedPayloads[payload] ?: payload.serialize(context = SerializationDefaults.P2P_CONTEXT).also { cachedSerializedPayloads[payload] = it }
}
}
@Suspendable @Suspendable
override fun <R : Any> suspend(ioRequest: FlowIORequest<R>, maySkipCheckpoint: Boolean): R { override fun <R : Any> suspend(ioRequest: FlowIORequest<R>, maySkipCheckpoint: Boolean): R {
val serializationContext = TransientReference(getTransientField(TransientValues::checkpointSerializationContext)) val serializationContext = TransientReference(getTransientField(TransientValues::checkpointSerializationContext))

View File

@ -17,6 +17,12 @@ internal object NullEncodingWhitelist : EncodingWhitelist {
override fun acceptEncoding(encoding: SerializationEncoding) = false override fun acceptEncoding(encoding: SerializationEncoding) = false
} }
internal object SnappyEncodingWhitelist: EncodingWhitelist {
override fun acceptEncoding(encoding: SerializationEncoding): Boolean {
return encoding == CordaSerializationEncoding.SNAPPY
}
}
@KeepForDJVM @KeepForDJVM
data class SerializationContextImpl @JvmOverloads constructor(override val preferredSerializationVersion: SerializationMagic, data class SerializationContextImpl @JvmOverloads constructor(override val preferredSerializationVersion: SerializationMagic,
override val deserializationClassLoader: ClassLoader, override val deserializationClassLoader: ClassLoader,
@ -25,7 +31,7 @@ data class SerializationContextImpl @JvmOverloads constructor(override val prefe
override val objectReferencesEnabled: Boolean, override val objectReferencesEnabled: Boolean,
override val useCase: SerializationContext.UseCase, override val useCase: SerializationContext.UseCase,
override val encoding: SerializationEncoding?, override val encoding: SerializationEncoding?,
override val encodingWhitelist: EncodingWhitelist = NullEncodingWhitelist, override val encodingWhitelist: EncodingWhitelist = SnappyEncodingWhitelist,
override val lenientCarpenterEnabled: Boolean = false, override val lenientCarpenterEnabled: Boolean = false,
override val carpenterDisabled: Boolean = false, override val carpenterDisabled: Boolean = false,
override val preventDataLoss: Boolean = false, override val preventDataLoss: Boolean = false,