mirror of
https://github.com/corda/corda.git
synced 2025-06-16 22:28:15 +00:00
Added check to receive and sendAndReceive to make sure the primitive classes aren't used (#1400)
This commit is contained in:
@ -1,9 +1,11 @@
|
||||
package net.corda.node.services.statemachine
|
||||
|
||||
import co.paralleluniverse.fibers.Fiber
|
||||
import co.paralleluniverse.fibers.Fiber.parkAndSerialize
|
||||
import co.paralleluniverse.fibers.FiberScheduler
|
||||
import co.paralleluniverse.fibers.Suspendable
|
||||
import co.paralleluniverse.strands.Strand
|
||||
import com.google.common.primitives.Primitives
|
||||
import net.corda.core.concurrent.CordaFuture
|
||||
import net.corda.core.crypto.SecureHash
|
||||
import net.corda.core.crypto.random63BitValue
|
||||
@ -165,24 +167,26 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
|
||||
payload: Any,
|
||||
sessionFlow: FlowLogic<*>,
|
||||
retrySend: Boolean): UntrustworthyData<T> {
|
||||
requireNonPrimitive(receiveType)
|
||||
logger.debug { "sendAndReceive(${receiveType.name}, $otherParty, ${payload.toString().abbreviate(300)}) ..." }
|
||||
val session = getConfirmedSessionIfPresent(otherParty, sessionFlow)
|
||||
val sessionData = if (session == null) {
|
||||
val receivedSessionData: ReceivedSessionMessage<SessionData> = if (session == null) {
|
||||
val newSession = startNewSession(otherParty, sessionFlow, payload, waitForConfirmation = true, retryable = retrySend)
|
||||
// Only do a receive here as the session init has carried the payload
|
||||
receiveInternal<SessionData>(newSession, receiveType)
|
||||
receiveInternal(newSession, receiveType)
|
||||
} else {
|
||||
val sendData = createSessionData(session, payload)
|
||||
sendAndReceiveInternal<SessionData>(session, sendData, receiveType)
|
||||
sendAndReceiveInternal(session, sendData, receiveType)
|
||||
}
|
||||
logger.debug { "Received ${sessionData.message.payload.toString().abbreviate(300)}" }
|
||||
return sessionData.checkPayloadIs(receiveType)
|
||||
logger.debug { "Received ${receivedSessionData.message.payload.toString().abbreviate(300)}" }
|
||||
return receivedSessionData.checkPayloadIs(receiveType)
|
||||
}
|
||||
|
||||
@Suspendable
|
||||
override fun <T : Any> receive(receiveType: Class<T>,
|
||||
otherParty: Party,
|
||||
sessionFlow: FlowLogic<*>): UntrustworthyData<T> {
|
||||
requireNonPrimitive(receiveType)
|
||||
logger.debug { "receive(${receiveType.name}, $otherParty) ..." }
|
||||
val session = getConfirmedSession(otherParty, sessionFlow)
|
||||
val sessionData = receiveInternal<SessionData>(session, receiveType)
|
||||
@ -190,6 +194,12 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
|
||||
return sessionData.checkPayloadIs(receiveType)
|
||||
}
|
||||
|
||||
private fun requireNonPrimitive(receiveType: Class<*>) {
|
||||
require(!receiveType.isPrimitive) {
|
||||
"Use the wrapper type ${Primitives.wrap(receiveType).name} instead of the primitive $receiveType.class"
|
||||
}
|
||||
}
|
||||
|
||||
@Suspendable
|
||||
override fun send(otherParty: Party, payload: Any, sessionFlow: FlowLogic<*>) {
|
||||
logger.debug { "send($otherParty, ${payload.toString().abbreviate(300)})" }
|
||||
|
Reference in New Issue
Block a user