Added check to receive and sendAndReceive to make sure the primitive classes aren't used (#1400)

This commit is contained in:
Shams Asari 2017-09-06 17:36:13 +01:00
parent 50c51493c3
commit c293d6b18a
3 changed files with 63 additions and 5 deletions

5
.idea/compiler.xml generated
View File

@ -14,6 +14,9 @@
<module name="client_test" target="1.8" /> <module name="client_test" target="1.8" />
<module name="corda-project_main" target="1.8" /> <module name="corda-project_main" target="1.8" />
<module name="corda-project_test" target="1.8" /> <module name="corda-project_test" target="1.8" />
<module name="corda-webserver_integrationTest" target="1.8" />
<module name="corda-webserver_main" target="1.8" />
<module name="corda-webserver_test" target="1.8" />
<module name="cordform-common_main" target="1.8" /> <module name="cordform-common_main" target="1.8" />
<module name="cordform-common_test" target="1.8" /> <module name="cordform-common_test" target="1.8" />
<module name="core_main" target="1.8" /> <module name="core_main" target="1.8" />
@ -96,6 +99,8 @@
<module name="verifier_test" target="1.8" /> <module name="verifier_test" target="1.8" />
<module name="webcapsule_main" target="1.6" /> <module name="webcapsule_main" target="1.6" />
<module name="webcapsule_test" target="1.6" /> <module name="webcapsule_test" target="1.6" />
<module name="webserver-webcapsule_main" target="1.6" />
<module name="webserver-webcapsule_test" target="1.6" />
<module name="webserver_integrationTest" target="1.8" /> <module name="webserver_integrationTest" target="1.8" />
<module name="webserver_main" target="1.8" /> <module name="webserver_main" target="1.8" />
<module name="webserver_test" target="1.8" /> <module name="webserver_test" target="1.8" />

View File

@ -1,6 +1,7 @@
package net.corda.core.flows; package net.corda.core.flows;
import co.paralleluniverse.fibers.Suspendable; import co.paralleluniverse.fibers.Suspendable;
import com.google.common.primitives.Primitives;
import net.corda.core.identity.Party; import net.corda.core.identity.Party;
import net.corda.testing.node.MockNetwork; import net.corda.testing.node.MockNetwork;
import org.junit.After; import org.junit.After;
@ -11,6 +12,7 @@ import java.util.concurrent.ExecutionException;
import java.util.concurrent.Future; import java.util.concurrent.Future;
import static org.assertj.core.api.AssertionsForClassTypes.assertThat; import static org.assertj.core.api.AssertionsForClassTypes.assertThat;
import static org.junit.Assert.fail;
public class FlowsInJavaTest { public class FlowsInJavaTest {
@ -41,6 +43,30 @@ public class FlowsInJavaTest {
assertThat(result.get()).isEqualTo("Hello"); assertThat(result.get()).isEqualTo("Hello");
} }
@Test
public void primitiveClassForReceiveType() throws InterruptedException {
// Using the primitive classes causes problems with the checkpointing so we use the wrapper classes and convert
// to the primitive class at callsite.
for (Class<?> receiveType : Primitives.allWrapperTypes()) {
primitiveReceiveTypeTest(receiveType);
}
}
private void primitiveReceiveTypeTest(Class<?> receiveType) throws InterruptedException {
PrimitiveReceiveFlow flow = new PrimitiveReceiveFlow(node2.getInfo().getLegalIdentity(), receiveType);
Future<?> result = node1.getServices().startFlow(flow).getResultFuture();
mockNet.runNetwork();
try {
result.get();
fail("ExecutionException should have been thrown");
} catch (ExecutionException e) {
assertThat(e.getCause())
.isInstanceOf(IllegalArgumentException.class)
.hasMessageContaining("primitive")
.hasMessageContaining(receiveType.getName());
}
}
@InitiatingFlow @InitiatingFlow
private static class SendInUnwrapFlow extends FlowLogic<String> { private static class SendInUnwrapFlow extends FlowLogic<String> {
private final Party otherParty; private final Party otherParty;
@ -74,4 +100,21 @@ public class FlowsInJavaTest {
} }
} }
@InitiatingFlow
private static class PrimitiveReceiveFlow extends FlowLogic<Void> {
private final Party otherParty;
private final Class<?> receiveType;
private PrimitiveReceiveFlow(Party otherParty, Class<?> receiveType) {
this.otherParty = otherParty;
this.receiveType = receiveType;
}
@Suspendable
@Override
public Void call() throws FlowException {
receive(Primitives.unwrap(receiveType), otherParty);
return null;
}
}
} }

View File

@ -1,9 +1,11 @@
package net.corda.node.services.statemachine package net.corda.node.services.statemachine
import co.paralleluniverse.fibers.Fiber import co.paralleluniverse.fibers.Fiber
import co.paralleluniverse.fibers.Fiber.parkAndSerialize
import co.paralleluniverse.fibers.FiberScheduler import co.paralleluniverse.fibers.FiberScheduler
import co.paralleluniverse.fibers.Suspendable import co.paralleluniverse.fibers.Suspendable
import co.paralleluniverse.strands.Strand import co.paralleluniverse.strands.Strand
import com.google.common.primitives.Primitives
import net.corda.core.concurrent.CordaFuture import net.corda.core.concurrent.CordaFuture
import net.corda.core.crypto.SecureHash import net.corda.core.crypto.SecureHash
import net.corda.core.crypto.random63BitValue import net.corda.core.crypto.random63BitValue
@ -165,24 +167,26 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
payload: Any, payload: Any,
sessionFlow: FlowLogic<*>, sessionFlow: FlowLogic<*>,
retrySend: Boolean): UntrustworthyData<T> { retrySend: Boolean): UntrustworthyData<T> {
requireNonPrimitive(receiveType)
logger.debug { "sendAndReceive(${receiveType.name}, $otherParty, ${payload.toString().abbreviate(300)}) ..." } logger.debug { "sendAndReceive(${receiveType.name}, $otherParty, ${payload.toString().abbreviate(300)}) ..." }
val session = getConfirmedSessionIfPresent(otherParty, sessionFlow) val session = getConfirmedSessionIfPresent(otherParty, sessionFlow)
val sessionData = if (session == null) { val receivedSessionData: ReceivedSessionMessage<SessionData> = if (session == null) {
val newSession = startNewSession(otherParty, sessionFlow, payload, waitForConfirmation = true, retryable = retrySend) val newSession = startNewSession(otherParty, sessionFlow, payload, waitForConfirmation = true, retryable = retrySend)
// Only do a receive here as the session init has carried the payload // Only do a receive here as the session init has carried the payload
receiveInternal<SessionData>(newSession, receiveType) receiveInternal(newSession, receiveType)
} else { } else {
val sendData = createSessionData(session, payload) val sendData = createSessionData(session, payload)
sendAndReceiveInternal<SessionData>(session, sendData, receiveType) sendAndReceiveInternal(session, sendData, receiveType)
} }
logger.debug { "Received ${sessionData.message.payload.toString().abbreviate(300)}" } logger.debug { "Received ${receivedSessionData.message.payload.toString().abbreviate(300)}" }
return sessionData.checkPayloadIs(receiveType) return receivedSessionData.checkPayloadIs(receiveType)
} }
@Suspendable @Suspendable
override fun <T : Any> receive(receiveType: Class<T>, override fun <T : Any> receive(receiveType: Class<T>,
otherParty: Party, otherParty: Party,
sessionFlow: FlowLogic<*>): UntrustworthyData<T> { sessionFlow: FlowLogic<*>): UntrustworthyData<T> {
requireNonPrimitive(receiveType)
logger.debug { "receive(${receiveType.name}, $otherParty) ..." } logger.debug { "receive(${receiveType.name}, $otherParty) ..." }
val session = getConfirmedSession(otherParty, sessionFlow) val session = getConfirmedSession(otherParty, sessionFlow)
val sessionData = receiveInternal<SessionData>(session, receiveType) val sessionData = receiveInternal<SessionData>(session, receiveType)
@ -190,6 +194,12 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
return sessionData.checkPayloadIs(receiveType) return sessionData.checkPayloadIs(receiveType)
} }
private fun requireNonPrimitive(receiveType: Class<*>) {
require(!receiveType.isPrimitive) {
"Use the wrapper type ${Primitives.wrap(receiveType).name} instead of the primitive $receiveType.class"
}
}
@Suspendable @Suspendable
override fun send(otherParty: Party, payload: Any, sessionFlow: FlowLogic<*>) { override fun send(otherParty: Party, payload: Any, sessionFlow: FlowLogic<*>) {
logger.debug { "send($otherParty, ${payload.toString().abbreviate(300)})" } logger.debug { "send($otherParty, ${payload.toString().abbreviate(300)})" }