diff --git a/node/src/test/kotlin/net/corda/node/services/statemachine/RetryFlowMockTest.kt b/node/src/test/kotlin/net/corda/node/services/statemachine/RetryFlowMockTest.kt index 6a0c4477e3..2d8b8fc231 100644 --- a/node/src/test/kotlin/net/corda/node/services/statemachine/RetryFlowMockTest.kt +++ b/node/src/test/kotlin/net/corda/node/services/statemachine/RetryFlowMockTest.kt @@ -2,7 +2,11 @@ package net.corda.node.services.statemachine import co.paralleluniverse.fibers.Suspendable import net.corda.core.concurrent.CordaFuture -import net.corda.core.flows.* +import net.corda.core.flows.FlowInfo +import net.corda.core.flows.FlowLogic +import net.corda.core.flows.FlowSession +import net.corda.core.flows.InitiatedBy +import net.corda.core.flows.InitiatingFlow import net.corda.core.identity.CordaX500Name import net.corda.core.identity.Party import net.corda.core.internal.FlowStateMachine @@ -15,9 +19,12 @@ import net.corda.node.services.FinalityHandler import net.corda.node.services.messaging.Message import net.corda.node.services.persistence.DBTransactionStorage import net.corda.nodeapi.internal.persistence.contextTransaction -import net.corda.testing.common.internal.eventually import net.corda.testing.core.TestIdentity -import net.corda.testing.node.internal.* +import net.corda.testing.node.internal.InternalMockNetwork +import net.corda.testing.node.internal.enclosedCordapp +import net.corda.testing.node.internal.MessagingServiceSpy +import net.corda.testing.node.internal.newContext +import net.corda.testing.node.internal.TestStartedNode import org.assertj.core.api.Assertions.assertThat import org.assertj.core.api.Assertions.assertThatThrownBy import org.hibernate.exception.ConstraintViolationException @@ -28,7 +35,9 @@ import org.junit.Test import java.sql.SQLException import java.time.Duration import java.util.* -import java.util.concurrent.atomic.AtomicInteger +import java.util.concurrent.CountDownLatch +import java.util.concurrent.Semaphore +import java.util.concurrent.TimeUnit import kotlin.test.assertEquals import kotlin.test.assertNotNull import kotlin.test.assertNull @@ -47,7 +56,6 @@ class RetryFlowMockTest { RetryFlow.count = 0 SendAndRetryFlow.count = 0 RetryInsertFlow.count = 0 - KeepSendingFlow.count.set(0) } private fun TestStartedNode.startFlow(logic: FlowLogic): CordaFuture { @@ -89,35 +97,33 @@ class RetryFlowMockTest { assertEquals(2, SendAndRetryFlow.count) } - @Test - fun `Restart does not set senderUUID`() { + @Test(timeout=300_000) + fun `Restart does not set senderUUID`() { val messagesSent = Collections.synchronizedList(mutableListOf()) val partyB = nodeB.info.legalIdentities.first() + val expectedMessagesSent = CountDownLatch(3) nodeA.setMessagingServiceSpy(object : MessagingServiceSpy() { override fun send(message: Message, target: MessageRecipients, sequenceKey: Any) { messagesSent.add(message) + expectedMessagesSent.countDown() messagingService.send(message, target) } }) - val count = 10000 // Lots of iterations so the flow keeps going long enough - nodeA.startFlow(KeepSendingFlow(count, partyB)) - eventually(waitBetween = Duration.ofMillis(10)) { - assertTrue(messagesSent.isNotEmpty()) - assertNotNull(messagesSent.first().senderUUID) - } + nodeA.startFlow(KeepSendingFlow(partyB)) + KeepSendingFlow.lock.acquire() + assertTrue(messagesSent.isNotEmpty()) + assertNotNull(messagesSent.first().senderUUID) nodeA = mockNet.restartNode(nodeA) - // This is a bit racy because restarting the node actually starts it, so we need to make sure there's enough iterations we get here with flow still going. nodeA.setMessagingServiceSpy(object : MessagingServiceSpy() { override fun send(message: Message, target: MessageRecipients, sequenceKey: Any) { messagesSent.add(message) + expectedMessagesSent.countDown() messagingService.send(message, target) } }) - // Now short circuit the iterations so the flow finishes soon. - KeepSendingFlow.count.set(count - 2) - eventually(waitBetween = Duration.ofMillis(10)) { - assertTrue(nodeA.smm.allStateMachines.isEmpty()) - } + ReceiveFlow3.lock.release() + assertTrue(expectedMessagesSent.await(20, TimeUnit.SECONDS)) + assertEquals(3, messagesSent.size) assertNull(messagesSent.last().senderUUID) } @@ -235,32 +241,36 @@ class RetryFlowMockTest { } @InitiatingFlow - class KeepSendingFlow(private val i: Int, private val other: Party) : FlowLogic() { + class KeepSendingFlow(private val other: Party) : FlowLogic() { + companion object { - val count = AtomicInteger(0) + val lock = Semaphore(0) } @Suspendable override fun call() { val session = initiateFlow(other) - session.send(i.toString()) - do { - logger.info("Sending... $count") - session.send("Boo") - } while (count.getAndIncrement() < i) + session.send("boo") + lock.release() + session.receive() + session.send("boo") } } @Suppress("unused") @InitiatedBy(KeepSendingFlow::class) class ReceiveFlow3(private val other: FlowSession) : FlowLogic() { + + companion object { + val lock = Semaphore(0) + } + @Suspendable override fun call() { - var count = other.receive().unwrap { it.toInt() } - while (count-- > 0) { - val received = other.receive().unwrap { it } - logger.info("Received... $received $count") - } + other.receive() + lock.acquire() + other.send("hoo") + other.receive() } } @@ -286,4 +296,27 @@ class RetryFlowMockTest { contextTransaction.session.save(tx) } } + + @InitiatingFlow + class UnbalancedSendAndReceiveFlow(private val other: Party) : FlowLogic() { + + @Suspendable + override fun call() { + val session = initiateFlow(other) + session.send("boo") + session.receive() + session.receive() + } + } + + @Suppress("unused") + @InitiatedBy(UnbalancedSendAndReceiveFlow::class) + class UnbalancedSendAndReceiveResponder(private val other: FlowSession) : FlowLogic() { + + @Suspendable + override fun call() { + other.receive() + other.send("hoo") + } + } }