diff --git a/node/src/integration-test/kotlin/net/corda/node/flows/FlowSessionCloseTest.kt b/node/src/integration-test/kotlin/net/corda/node/flows/FlowSessionCloseTest.kt index a7e0cf877e..b7abe4249f 100644 --- a/node/src/integration-test/kotlin/net/corda/node/flows/FlowSessionCloseTest.kt +++ b/node/src/integration-test/kotlin/net/corda/node/flows/FlowSessionCloseTest.kt @@ -43,7 +43,7 @@ class FlowSessionCloseTest { ).transpose().getOrThrow() CordaRPCClient(nodeAHandle.rpcAddress).start(user.username, user.password).use { - assertThatThrownBy { it.proxy.startFlow(::InitiatorFlow, nodeBHandle.nodeInfo.legalIdentities.first(), true, null, false).returnValue.getOrThrow() } + assertThatThrownBy { it.proxy.startFlow(::InitiatorFlow, nodeBHandle.nodeInfo.legalIdentities.first(), true, null, InitiatorFlow.ResponderReaction.NORMAL_CLOSE).returnValue.getOrThrow() } .isInstanceOf(CordaRuntimeException::class.java) .hasMessageContaining(PrematureSessionCloseException::class.java.name) .hasMessageContaining("The following session was closed before it was initialised") @@ -52,18 +52,26 @@ class FlowSessionCloseTest { } @Test(timeout=300_000) - fun `flow cannot access closed session`() { + fun `flow cannot access closed session, unless it's a duplicate close which is handled gracefully`() { driver(DriverParameters(startNodesInProcess = true, cordappsForAllNodes = listOf(enclosedCordapp()), notarySpecs = emptyList())) { val (nodeAHandle, nodeBHandle) = listOf( startNode(providedName = ALICE_NAME, rpcUsers = listOf(user)), startNode(providedName = BOB_NAME, rpcUsers = listOf(user)) ).transpose().getOrThrow() - InitiatorFlow.SessionAPI.values().forEach { sessionAPI -> - CordaRPCClient(nodeAHandle.rpcAddress).start(user.username, user.password).use { - assertThatThrownBy { it.proxy.startFlow(::InitiatorFlow, nodeBHandle.nodeInfo.legalIdentities.first(), false, sessionAPI, false).returnValue.getOrThrow() } - .isInstanceOf(UnexpectedFlowEndException::class.java) - .hasMessageContaining("Tried to access ended session") + + CordaRPCClient(nodeAHandle.rpcAddress).start(user.username, user.password).use { + InitiatorFlow.SessionAPI.values().forEach { sessionAPI -> + when (sessionAPI) { + InitiatorFlow.SessionAPI.CLOSE -> { + it.proxy.startFlow(::InitiatorFlow, nodeBHandle.nodeInfo.legalIdentities.first(), false, sessionAPI, InitiatorFlow.ResponderReaction.NORMAL_CLOSE).returnValue.getOrThrow() + } + else -> { + assertThatThrownBy { it.proxy.startFlow(::InitiatorFlow, nodeBHandle.nodeInfo.legalIdentities.first(), false, sessionAPI, InitiatorFlow.ResponderReaction.NORMAL_CLOSE).returnValue.getOrThrow() } + .isInstanceOf(UnexpectedFlowEndException::class.java) + .hasMessageContaining("Tried to access ended session") + } + } } } @@ -79,7 +87,7 @@ class FlowSessionCloseTest { ).transpose().getOrThrow() CordaRPCClient(nodeAHandle.rpcAddress).start(user.username, user.password).use { - it.proxy.startFlow(::InitiatorFlow, nodeBHandle.nodeInfo.legalIdentities.first(), false, null, false).returnValue.getOrThrow() + it.proxy.startFlow(::InitiatorFlow, nodeBHandle.nodeInfo.legalIdentities.first(), false, null, InitiatorFlow.ResponderReaction.NORMAL_CLOSE).returnValue.getOrThrow() } } } @@ -93,7 +101,7 @@ class FlowSessionCloseTest { ).transpose().getOrThrow() CordaRPCClient(nodeAHandle.rpcAddress).start(user.username, user.password).use { - it.proxy.startFlow(::InitiatorFlow, nodeBHandle.nodeInfo.legalIdentities.first(), false, null, true).returnValue.getOrThrow() + it.proxy.startFlow(::InitiatorFlow, nodeBHandle.nodeInfo.legalIdentities.first(), false, null, InitiatorFlow.ResponderReaction.RETRY_CLOSE_FROM_CHECKPOINT).returnValue.getOrThrow() } } } @@ -151,14 +159,21 @@ class FlowSessionCloseTest { @StartableByRPC class InitiatorFlow(val party: Party, private val prematureClose: Boolean = false, private val accessClosedSessionWithApi: SessionAPI? = null, - private val retryClose: Boolean = false): FlowLogic() { + private val responderReaction: ResponderReaction): FlowLogic() { @CordaSerializable enum class SessionAPI { SEND, SEND_AND_RECEIVE, RECEIVE, - GET_FLOW_INFO + GET_FLOW_INFO, + CLOSE + } + + @CordaSerializable + enum class ResponderReaction { + NORMAL_CLOSE, + RETRY_CLOSE_FROM_CHECKPOINT } @Suspendable @@ -169,7 +184,7 @@ class FlowSessionCloseTest { session.close() } - session.send(retryClose) + session.send(responderReaction) sleep(1.seconds) if (accessClosedSessionWithApi != null) { @@ -178,6 +193,7 @@ class FlowSessionCloseTest { SessionAPI.RECEIVE -> session.receive() SessionAPI.SEND_AND_RECEIVE -> session.sendAndReceive("dummy payload") SessionAPI.GET_FLOW_INFO -> session.getCounterpartyFlowInfo() + SessionAPI.CLOSE -> session.close() } } } @@ -192,16 +208,21 @@ class FlowSessionCloseTest { @Suspendable override fun call() { - val retryClose = otherSideSession.receive() + val responderReaction = otherSideSession.receive() .unwrap{ it } - otherSideSession.close() + when(responderReaction) { + InitiatorFlow.ResponderReaction.NORMAL_CLOSE -> { + otherSideSession.close() + } + InitiatorFlow.ResponderReaction.RETRY_CLOSE_FROM_CHECKPOINT -> { + otherSideSession.close() - // failing with a transient exception to force a replay of the close. - if (retryClose) { - if (!thrown) { - thrown = true - throw SQLTransientConnectionException("Connection is not available") + // failing with a transient exception to force a replay of the close. + if (!thrown) { + thrown = true + throw SQLTransientConnectionException("Connection is not available") + } } } }