Merge pull request #2978 from corda/CORDA-1352/aslemmer-os-killflow

CORDA-1352: Add killFlow
This commit is contained in:
Andras Slemmer 2018-05-15 15:49:27 +01:00 committed by GitHub
commit 32aa1bf9d2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 86 additions and 0 deletions

View File

@ -260,6 +260,13 @@ interface CordaRPCOps : RPCOps {
@RPCReturnsObservables @RPCReturnsObservables
fun <T> startTrackedFlowDynamic(logicType: Class<out FlowLogic<T>>, vararg args: Any?): FlowProgressHandle<T> fun <T> startTrackedFlowDynamic(logicType: Class<out FlowLogic<T>>, vararg args: Any?): FlowProgressHandle<T>
/**
* Attempts to kill a flow. This is not a clean termination and should be reserved for exceptional cases such as stuck fibers.
*
* @return whether the flow existed and was killed.
*/
fun killFlow(id: StateMachineRunId): Boolean
/** Returns Node's NodeInfo, assuming this will not change while the node is running. */ /** Returns Node's NodeInfo, assuming this will not change while the node is running. */
fun nodeInfo(): NodeInfo fun nodeInfo(): NodeInfo

View File

@ -9,6 +9,7 @@ import net.corda.core.crypto.SecureHash
import net.corda.core.flows.FlowInitiator import net.corda.core.flows.FlowInitiator
import net.corda.core.flows.FlowLogic import net.corda.core.flows.FlowLogic
import net.corda.core.flows.StartableByRPC import net.corda.core.flows.StartableByRPC
import net.corda.core.flows.StateMachineRunId
import net.corda.core.identity.AbstractParty import net.corda.core.identity.AbstractParty
import net.corda.core.identity.CordaX500Name import net.corda.core.identity.CordaX500Name
import net.corda.core.identity.Party import net.corda.core.identity.Party
@ -111,6 +112,8 @@ internal class CordaRPCOpsImpl(
return snapshot return snapshot
} }
override fun killFlow(id: StateMachineRunId) = smm.killFlow(id)
override fun stateMachinesFeed(): DataFeed<List<StateMachineInfo>, StateMachineUpdate> { override fun stateMachinesFeed(): DataFeed<List<StateMachineInfo>, StateMachineUpdate> {
return database.transaction { return database.transaction {
val (allStateMachines, changes) = smm.track() val (allStateMachines, changes) = smm.track()

View File

@ -4,6 +4,7 @@ import net.corda.client.rpc.PermissionException
import net.corda.core.contracts.ContractState import net.corda.core.contracts.ContractState
import net.corda.core.crypto.SecureHash import net.corda.core.crypto.SecureHash
import net.corda.core.flows.FlowLogic import net.corda.core.flows.FlowLogic
import net.corda.core.flows.StateMachineRunId
import net.corda.core.identity.AbstractParty import net.corda.core.identity.AbstractParty
import net.corda.core.identity.CordaX500Name import net.corda.core.identity.CordaX500Name
import net.corda.core.identity.Party import net.corda.core.identity.Party
@ -75,6 +76,10 @@ class RpcAuthorisationProxy(private val implementation: CordaRPCOps, private val
implementation.startTrackedFlowDynamic(logicType, *args) implementation.startTrackedFlowDynamic(logicType, *args)
} }
override fun killFlow(id: StateMachineRunId): Boolean = guard("killFlow") {
return implementation.killFlow(id)
}
override fun nodeInfo(): NodeInfo = guard("nodeInfo", implementation::nodeInfo) override fun nodeInfo(): NodeInfo = guard("nodeInfo", implementation::nodeInfo)
override fun notaryIdentities(): List<Party> = guard("notaryIdentities", implementation::notaryIdentities) override fun notaryIdentities(): List<Party> = guard("notaryIdentities", implementation::notaryIdentities)

View File

@ -5,6 +5,7 @@ import net.corda.core.contracts.ContractState
import net.corda.core.crypto.SecureHash import net.corda.core.crypto.SecureHash
import net.corda.core.doOnError import net.corda.core.doOnError
import net.corda.core.flows.FlowLogic import net.corda.core.flows.FlowLogic
import net.corda.core.flows.StateMachineRunId
import net.corda.core.identity.AbstractParty import net.corda.core.identity.AbstractParty
import net.corda.core.identity.CordaX500Name import net.corda.core.identity.CordaX500Name
import net.corda.core.internal.concurrent.doOnError import net.corda.core.internal.concurrent.doOnError
@ -90,6 +91,8 @@ class RpcExceptionHandlingProxy(private val delegate: SecureCordaRPCOps) : Corda
override fun acceptNewNetworkParameters(parametersHash: SecureHash) = wrap { delegate.acceptNewNetworkParameters(parametersHash) } override fun acceptNewNetworkParameters(parametersHash: SecureHash) = wrap { delegate.acceptNewNetworkParameters(parametersHash) }
override fun killFlow(id: StateMachineRunId) = wrap { delegate.killFlow(id) }
override fun nodeInfo() = wrap(delegate::nodeInfo) override fun nodeInfo() = wrap(delegate::nodeInfo)
override fun notaryIdentities() = wrap(delegate::notaryIdentities) override fun notaryIdentities() = wrap(delegate::notaryIdentities)

View File

@ -1,5 +1,6 @@
package net.corda.node package net.corda.node
import co.paralleluniverse.fibers.Fiber
import co.paralleluniverse.fibers.Suspendable import co.paralleluniverse.fibers.Suspendable
import net.corda.client.rpc.PermissionException import net.corda.client.rpc.PermissionException
import net.corda.core.context.AuthServiceId import net.corda.core.context.AuthServiceId
@ -23,6 +24,7 @@ import net.corda.core.node.services.queryBy
import net.corda.core.transactions.SignedTransaction import net.corda.core.transactions.SignedTransaction
import net.corda.core.utilities.OpaqueBytes import net.corda.core.utilities.OpaqueBytes
import net.corda.core.utilities.getOrThrow import net.corda.core.utilities.getOrThrow
import net.corda.core.utilities.unwrap
import net.corda.finance.DOLLARS import net.corda.finance.DOLLARS
import net.corda.finance.GBP import net.corda.finance.GBP
import net.corda.finance.USD import net.corda.finance.USD
@ -46,6 +48,7 @@ import net.corda.testing.node.internal.InternalMockNetwork.MockNode
import net.corda.testing.node.internal.InternalMockNodeParameters import net.corda.testing.node.internal.InternalMockNodeParameters
import net.corda.testing.node.testActor import net.corda.testing.node.testActor
import org.apache.commons.io.IOUtils import org.apache.commons.io.IOUtils
import org.assertj.core.api.Assertions.assertThat
import org.assertj.core.api.Assertions.assertThatExceptionOfType import org.assertj.core.api.Assertions.assertThatExceptionOfType
import org.junit.After import org.junit.After
import org.junit.Assert.assertArrayEquals import org.junit.Assert.assertArrayEquals
@ -296,6 +299,71 @@ class CordaRPCOpsImplTest {
} }
} }
@Test
fun `kill a stuck flow through RPC`() {
withPermissions(startFlow<NewJoinerFlow>(), invokeRpc(CordaRPCOps::killFlow), invokeRpc(CordaRPCOps::stateMachinesFeed), invokeRpc(CordaRPCOps::stateMachinesSnapshot)) {
val flow = rpc.startFlow(::NewJoinerFlow)
val killed = rpc.killFlow(flow.id)
assertThat(killed).isTrue()
assertThat(rpc.stateMachinesSnapshot().map { info -> info.id }).doesNotContain(flow.id)
}
}
@Test
fun `kill a waiting flow through RPC`() {
withPermissions(startFlow<HopefulFlow>(), invokeRpc(CordaRPCOps::killFlow), invokeRpc(CordaRPCOps::stateMachinesFeed), invokeRpc(CordaRPCOps::stateMachinesSnapshot)) {
val flow = rpc.startFlow(::HopefulFlow, alice)
val killed = rpc.killFlow(flow.id)
assertThat(killed).isTrue()
assertThat(rpc.stateMachinesSnapshot().map { info -> info.id }).doesNotContain(flow.id)
}
}
@Test
fun `kill a nonexistent flow through RPC`() {
withPermissions(invokeRpc(CordaRPCOps::killFlow)) {
val nonexistentFlowId = StateMachineRunId.createRandom()
val killed = rpc.killFlow(nonexistentFlowId)
assertThat(killed).isFalse()
}
}
@StartableByRPC
class NewJoinerFlow : FlowLogic<String>() {
@Suspendable
override fun call(): String {
logger.info("When can I join you say? Almost there buddy...")
Fiber.currentFiber().join()
return "You'll never get me!"
}
}
@StartableByRPC
class HopefulFlow(private val party: Party) : FlowLogic<String>() {
@Suspendable
override fun call(): String {
logger.info("Waiting for a miracle...")
val miracle = initiateFlow(party).receive<String>().unwrap { it }
return miracle
}
}
class NonRPCFlow : FlowLogic<Unit>() { class NonRPCFlow : FlowLogic<Unit>() {
@Suspendable @Suspendable
override fun call() = Unit override fun call() = Unit