From c293d6b18acb7b6ffb36c11594f3a8bfb150dcdb Mon Sep 17 00:00:00 2001 From: Shams Asari Date: Wed, 6 Sep 2017 17:36:13 +0100 Subject: [PATCH] Added check to receive and sendAndReceive to make sure the primitive classes aren't used (#1400) --- .idea/compiler.xml | 5 +++ .../net/corda/core/flows/FlowsInJavaTest.java | 43 +++++++++++++++++++ .../statemachine/FlowStateMachineImpl.kt | 20 ++++++--- 3 files changed, 63 insertions(+), 5 deletions(-) diff --git a/.idea/compiler.xml b/.idea/compiler.xml index 83b640e02c..652807fef9 100644 --- a/.idea/compiler.xml +++ b/.idea/compiler.xml @@ -14,6 +14,9 @@ + + + @@ -96,6 +99,8 @@ + + diff --git a/core/src/test/java/net/corda/core/flows/FlowsInJavaTest.java b/core/src/test/java/net/corda/core/flows/FlowsInJavaTest.java index 5c6308c2d5..af255b2938 100644 --- a/core/src/test/java/net/corda/core/flows/FlowsInJavaTest.java +++ b/core/src/test/java/net/corda/core/flows/FlowsInJavaTest.java @@ -1,6 +1,7 @@ package net.corda.core.flows; import co.paralleluniverse.fibers.Suspendable; +import com.google.common.primitives.Primitives; import net.corda.core.identity.Party; import net.corda.testing.node.MockNetwork; import org.junit.After; @@ -11,6 +12,7 @@ import java.util.concurrent.ExecutionException; import java.util.concurrent.Future; import static org.assertj.core.api.AssertionsForClassTypes.assertThat; +import static org.junit.Assert.fail; public class FlowsInJavaTest { @@ -41,6 +43,30 @@ public class FlowsInJavaTest { assertThat(result.get()).isEqualTo("Hello"); } + @Test + public void primitiveClassForReceiveType() throws InterruptedException { + // Using the primitive classes causes problems with the checkpointing so we use the wrapper classes and convert + // to the primitive class at callsite. + for (Class receiveType : Primitives.allWrapperTypes()) { + primitiveReceiveTypeTest(receiveType); + } + } + + private void primitiveReceiveTypeTest(Class receiveType) throws InterruptedException { + PrimitiveReceiveFlow flow = new PrimitiveReceiveFlow(node2.getInfo().getLegalIdentity(), receiveType); + Future result = node1.getServices().startFlow(flow).getResultFuture(); + mockNet.runNetwork(); + try { + result.get(); + fail("ExecutionException should have been thrown"); + } catch (ExecutionException e) { + assertThat(e.getCause()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("primitive") + .hasMessageContaining(receiveType.getName()); + } + } + @InitiatingFlow private static class SendInUnwrapFlow extends FlowLogic { private final Party otherParty; @@ -74,4 +100,21 @@ public class FlowsInJavaTest { } } + @InitiatingFlow + private static class PrimitiveReceiveFlow extends FlowLogic { + private final Party otherParty; + private final Class receiveType; + + private PrimitiveReceiveFlow(Party otherParty, Class receiveType) { + this.otherParty = otherParty; + this.receiveType = receiveType; + } + + @Suspendable + @Override + public Void call() throws FlowException { + receive(Primitives.unwrap(receiveType), otherParty); + return null; + } + } } diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/FlowStateMachineImpl.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/FlowStateMachineImpl.kt index 977e1b416e..d8fa038ed2 100644 --- a/node/src/main/kotlin/net/corda/node/services/statemachine/FlowStateMachineImpl.kt +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/FlowStateMachineImpl.kt @@ -1,9 +1,11 @@ package net.corda.node.services.statemachine import co.paralleluniverse.fibers.Fiber +import co.paralleluniverse.fibers.Fiber.parkAndSerialize import co.paralleluniverse.fibers.FiberScheduler import co.paralleluniverse.fibers.Suspendable import co.paralleluniverse.strands.Strand +import com.google.common.primitives.Primitives import net.corda.core.concurrent.CordaFuture import net.corda.core.crypto.SecureHash import net.corda.core.crypto.random63BitValue @@ -165,24 +167,26 @@ class FlowStateMachineImpl(override val id: StateMachineRunId, payload: Any, sessionFlow: FlowLogic<*>, retrySend: Boolean): UntrustworthyData { + requireNonPrimitive(receiveType) logger.debug { "sendAndReceive(${receiveType.name}, $otherParty, ${payload.toString().abbreviate(300)}) ..." } val session = getConfirmedSessionIfPresent(otherParty, sessionFlow) - val sessionData = if (session == null) { + val receivedSessionData: ReceivedSessionMessage = if (session == null) { val newSession = startNewSession(otherParty, sessionFlow, payload, waitForConfirmation = true, retryable = retrySend) // Only do a receive here as the session init has carried the payload - receiveInternal(newSession, receiveType) + receiveInternal(newSession, receiveType) } else { val sendData = createSessionData(session, payload) - sendAndReceiveInternal(session, sendData, receiveType) + sendAndReceiveInternal(session, sendData, receiveType) } - logger.debug { "Received ${sessionData.message.payload.toString().abbreviate(300)}" } - return sessionData.checkPayloadIs(receiveType) + logger.debug { "Received ${receivedSessionData.message.payload.toString().abbreviate(300)}" } + return receivedSessionData.checkPayloadIs(receiveType) } @Suspendable override fun receive(receiveType: Class, otherParty: Party, sessionFlow: FlowLogic<*>): UntrustworthyData { + requireNonPrimitive(receiveType) logger.debug { "receive(${receiveType.name}, $otherParty) ..." } val session = getConfirmedSession(otherParty, sessionFlow) val sessionData = receiveInternal(session, receiveType) @@ -190,6 +194,12 @@ class FlowStateMachineImpl(override val id: StateMachineRunId, return sessionData.checkPayloadIs(receiveType) } + private fun requireNonPrimitive(receiveType: Class<*>) { + require(!receiveType.isPrimitive) { + "Use the wrapper type ${Primitives.wrap(receiveType).name} instead of the primitive $receiveType.class" + } + } + @Suspendable override fun send(otherParty: Party, payload: Any, sessionFlow: FlowLogic<*>) { logger.debug { "send($otherParty, ${payload.toString().abbreviate(300)})" }