diff --git a/core/src/main/kotlin/net/corda/core/flows/FlowLogic.kt b/core/src/main/kotlin/net/corda/core/flows/FlowLogic.kt index 88829caeec..587e89dd08 100644 --- a/core/src/main/kotlin/net/corda/core/flows/FlowLogic.kt +++ b/core/src/main/kotlin/net/corda/core/flows/FlowLogic.kt @@ -16,10 +16,7 @@ import net.corda.core.CordaInternal import net.corda.core.crypto.SecureHash import net.corda.core.identity.Party import net.corda.core.identity.PartyAndCertificate -import net.corda.core.internal.FlowIORequest -import net.corda.core.internal.FlowStateMachine -import net.corda.core.internal.abbreviate -import net.corda.core.internal.uncheckedCast +import net.corda.core.internal.* import net.corda.core.messaging.DataFeed import net.corda.core.node.NodeInfo import net.corda.core.node.ServiceHub diff --git a/core/src/main/kotlin/net/corda/core/internal/FlowAsyncOperation.kt b/core/src/main/kotlin/net/corda/core/internal/FlowAsyncOperation.kt new file mode 100644 index 0000000000..c0cd2284d8 --- /dev/null +++ b/core/src/main/kotlin/net/corda/core/internal/FlowAsyncOperation.kt @@ -0,0 +1,33 @@ +/* + * R3 Proprietary and Confidential + * + * Copyright (c) 2018 R3 Limited. All rights reserved. + * + * The intellectual and technical concepts contained herein are proprietary to R3 and its suppliers and are protected by trade secret law. + * + * Distribution of this file or any portion thereof via any medium without the express permission of R3 is strictly prohibited. + */ + +package net.corda.core.internal + +import co.paralleluniverse.fibers.Suspendable +import net.corda.core.concurrent.CordaFuture +import net.corda.core.flows.FlowLogic +import net.corda.core.serialization.CordaSerializable + +/** + * Interface for arbitrary operations that can be invoked in a flow asynchronously - the flow will suspend until the + * operation completes. Operation parameters are expected to be injected via constructor. + */ +@CordaSerializable +interface FlowAsyncOperation { + /** Performs the operation in a non-blocking fashion. */ + fun execute(): CordaFuture +} + +/** Executes the specified [operation] and suspends until operation completion. */ +@Suspendable +fun FlowLogic.executeAsync(operation: FlowAsyncOperation, maySkipCheckpoint: Boolean = false): R { + val request = FlowIORequest.ExecuteAsyncOperation(operation) + return stateMachine.suspend(request, maySkipCheckpoint) +} \ No newline at end of file diff --git a/core/src/main/kotlin/net/corda/core/internal/FlowIORequest.kt b/core/src/main/kotlin/net/corda/core/internal/FlowIORequest.kt index 35913909e8..4150610f8e 100644 --- a/core/src/main/kotlin/net/corda/core/internal/FlowIORequest.kt +++ b/core/src/main/kotlin/net/corda/core/internal/FlowIORequest.kt @@ -91,5 +91,10 @@ sealed class FlowIORequest { * Suspend the flow until all Initiating sessions are confirmed. */ object WaitForSessionConfirmations : FlowIORequest() + + /** + * Execute the specified [operation], suspend the flow until completion. + */ + data class ExecuteAsyncOperation(val operation: FlowAsyncOperation) : FlowIORequest() } diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/Action.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/Action.kt index 3e9dc3cb0e..120baee8c1 100644 --- a/node/src/main/kotlin/net/corda/node/services/statemachine/Action.kt +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/Action.kt @@ -13,6 +13,7 @@ package net.corda.node.services.statemachine import net.corda.core.crypto.SecureHash import net.corda.core.flows.StateMachineRunId import net.corda.core.identity.Party +import net.corda.core.internal.FlowAsyncOperation import net.corda.node.services.messaging.DeduplicationHandler import java.time.Instant @@ -121,6 +122,11 @@ sealed class Action { * Commit the current database transaction. */ object CommitTransaction : Action() { override fun toString() = "CommitTransaction" } + + /** + * Execute the specified [operation]. + */ + data class ExecuteAsyncOperation(val operation: FlowAsyncOperation<*>) : Action() } /** diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/ActionExecutorImpl.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/ActionExecutorImpl.kt index a22a03bd1d..7a7d2cdb88 100644 --- a/node/src/main/kotlin/net/corda/node/services/statemachine/ActionExecutorImpl.kt +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/ActionExecutorImpl.kt @@ -81,6 +81,7 @@ class ActionExecutorImpl( is Action.CreateTransaction -> executeCreateTransaction() is Action.RollbackTransaction -> executeRollbackTransaction() is Action.CommitTransaction -> executeCommitTransaction() + is Action.ExecuteAsyncOperation -> executeAsyncOperation(fiber, action) } } @@ -218,6 +219,19 @@ class ActionExecutorImpl( } } + @Suspendable + private fun executeAsyncOperation(fiber: FlowFiber, action: Action.ExecuteAsyncOperation) { + val operationFuture = action.operation.execute() + operationFuture.thenMatch( + success = { result -> + fiber.scheduleEvent(Event.AsyncOperationCompletion(result)) + }, + failure = { exception -> + fiber.scheduleEvent(Event.Error(exception)) + } + ) + } + private fun serializeCheckpoint(checkpoint: Checkpoint): SerializedBytes { return checkpoint.serialize(context = checkpointSerializationContext) } diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/Event.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/Event.kt index 9d18ebb18c..17baa423c2 100644 --- a/node/src/main/kotlin/net/corda/node/services/statemachine/Event.kt +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/Event.kt @@ -124,4 +124,13 @@ sealed class Event { * @param returnValue the return value of the flow. */ data class FlowFinish(val returnValue: Any?) : Event() + + /** + * Signals the completion of a [FlowAsyncOperation]. + * + * Scheduling is triggered by the service that completes the future returned by the async operation. + * + * @param returnValue the result of the operation. + */ + data class AsyncOperationCompletion(val returnValue: Any?) : Event() } diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/FlowSessionImpl.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/FlowSessionImpl.kt index 7a2496b238..24c59aa6f9 100644 --- a/node/src/main/kotlin/net/corda/node/services/statemachine/FlowSessionImpl.kt +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/FlowSessionImpl.kt @@ -18,6 +18,7 @@ import net.corda.core.identity.Party import net.corda.core.internal.FlowIORequest import net.corda.core.internal.FlowStateMachine import net.corda.core.serialization.SerializationDefaults +import net.corda.core.serialization.SerializedBytes import net.corda.core.serialization.serialize import net.corda.core.utilities.NonEmptySet import net.corda.core.utilities.UntrustworthyData @@ -60,7 +61,10 @@ class FlowSessionImpl( sessionToMessage = mapOf(this to payload.serialize(context = SerializationDefaults.P2P_CONTEXT)), shouldRetrySend = false ) - return getFlowStateMachine().suspend(request, maySkipCheckpoint)[this]!!.checkPayloadIs(receiveType) + val responseValues: Map> = getFlowStateMachine().suspend(request, maySkipCheckpoint) + val responseForCurrentSession = responseValues[this]!! + + return responseForCurrentSession.checkPayloadIs(receiveType) } @Suspendable diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/transitions/StartedFlowTransition.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/transitions/StartedFlowTransition.kt index be0029bca2..c313ee59e7 100644 --- a/node/src/main/kotlin/net/corda/node/services/statemachine/transitions/StartedFlowTransition.kt +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/transitions/StartedFlowTransition.kt @@ -49,6 +49,7 @@ class StartedFlowTransition( is FlowIORequest.Sleep -> sleepTransition(flowIORequest) is FlowIORequest.GetFlowInfo -> getFlowInfoTransition(flowIORequest) is FlowIORequest.WaitForSessionConfirmations -> waitForSessionConfirmationsTransition() + is FlowIORequest.ExecuteAsyncOperation<*> -> executeAsyncOperation(flowIORequest) } } @@ -388,6 +389,9 @@ class StartedFlowTransition( is FlowIORequest.WaitForSessionConfirmations -> { collectErroredInitiatingSessionErrors(checkpoint) } + is FlowIORequest.ExecuteAsyncOperation<*> -> { + emptyList() + } } } @@ -406,4 +410,11 @@ class StartedFlowTransition( firstPayload = payload ) } + + private fun executeAsyncOperation(flowIORequest: FlowIORequest.ExecuteAsyncOperation<*>): TransitionResult { + return builder { + actions.add(Action.ExecuteAsyncOperation(flowIORequest.operation)) + FlowContinuation.ProcessEvents + } + } } \ No newline at end of file diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/transitions/TopLevelTransition.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/transitions/TopLevelTransition.kt index ce413287d1..7944f63f43 100644 --- a/node/src/main/kotlin/net/corda/node/services/statemachine/transitions/TopLevelTransition.kt +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/transitions/TopLevelTransition.kt @@ -39,6 +39,7 @@ class TopLevelTransition( is Event.Suspend -> suspendTransition(event) is Event.FlowFinish -> flowFinishTransition(event) is Event.InitiateFlow -> initiateFlowTransition(event) + is Event.AsyncOperationCompletion -> asyncOperationCompletionTransition(event) } } @@ -243,4 +244,10 @@ class TopLevelTransition( } return null } + + private fun asyncOperationCompletionTransition(event: Event.AsyncOperationCompletion): TransitionResult { + return builder { + resumeFlowLogic(event.returnValue) + } + } } \ No newline at end of file diff --git a/node/src/test/kotlin/net/corda/node/services/statemachine/FlowAsyncOperationTests.kt b/node/src/test/kotlin/net/corda/node/services/statemachine/FlowAsyncOperationTests.kt new file mode 100644 index 0000000000..711fbb93d3 --- /dev/null +++ b/node/src/test/kotlin/net/corda/node/services/statemachine/FlowAsyncOperationTests.kt @@ -0,0 +1,141 @@ +/* + * R3 Proprietary and Confidential + * + * Copyright (c) 2018 R3 Limited. All rights reserved. + * + * The intellectual and technical concepts contained herein are proprietary to R3 and its suppliers and are protected by trade secret law. + * + * Distribution of this file or any portion thereof via any medium without the express permission of R3 is strictly prohibited. + */ + +package net.corda.node.services.statemachine + +import co.paralleluniverse.fibers.Suspendable +import net.corda.core.concurrent.CordaFuture +import net.corda.core.flows.FlowLogic +import net.corda.core.internal.FlowAsyncOperation +import net.corda.core.internal.concurrent.OpenFuture +import net.corda.core.internal.concurrent.openFuture +import net.corda.core.internal.concurrent.transpose +import net.corda.core.internal.executeAsync +import net.corda.core.node.AppServiceHub +import net.corda.core.node.services.CordaService +import net.corda.core.serialization.SingletonSerializeAsToken +import net.corda.node.internal.StartedNode +import net.corda.testing.node.internal.InternalMockNetwork +import net.corda.testing.node.internal.startFlow +import org.junit.After +import org.junit.Before +import org.junit.Test +import java.util.concurrent.ConcurrentLinkedQueue +import java.util.concurrent.ExecutionException +import kotlin.test.assertFailsWith + +class FlowAsyncOperationTests { + private lateinit var mockNet: InternalMockNetwork + private lateinit var aliceNode: StartedNode + @Before + fun setup() { + mockNet = InternalMockNetwork( + cordappPackages = listOf("net.corda.testing.contracts", "net.corda.node.services.statemachine"), + notarySpecs = emptyList() + ) + aliceNode = mockNet.createNode() + } + + @After + fun cleanUp() { + mockNet.stopNodes() + } + + @Test + fun `operation errors are propagated correctly`() { + val flow = object : FlowLogic() { + @Suspendable + override fun call() { + executeAsync(ErroredExecute()) + } + } + + assertFailsWith { aliceNode.services.startFlow(flow).resultFuture.get() } + } + + private class ErroredExecute : FlowAsyncOperation { + override fun execute(): CordaFuture { + throw Exception() + } + } + + @Test + fun `operation result errors are propagated correctly`() { + val flow = object : FlowLogic() { + @Suspendable + override fun call() { + executeAsync(ErroredResult()) + } + } + + assertFailsWith { aliceNode.services.startFlow(flow).resultFuture.get() } + } + + private class ErroredResult : FlowAsyncOperation { + override fun execute(): CordaFuture { + val future = openFuture() + future.setException(Exception()) + return future + } + } + + @Test(timeout = 30_000) + fun `flows waiting on an async operation do not block the thread`() { + // Kick off 10 flows that submit a task to the service and wait until completion + val numFlows = 10 + val futures = (1..10).map { + aliceNode.services.startFlow(TestFlowWithAsyncAction(false)).resultFuture + } + // Make sure all flows submitted a task to the service and are awaiting completion + val service = aliceNode.services.cordaService(WorkerService::class.java) + while (service.pendingCount != numFlows) Thread.sleep(100) + // Complete all pending tasks. If async operations aren't handled as expected, and one of the previous flows is + // actually blocking the thread, the following flow will deadlock and the test won't finish. + aliceNode.services.startFlow(TestFlowWithAsyncAction(true)).resultFuture.get() + // Make sure all waiting flows completed successfully + futures.transpose().get() + } + + private class TestFlowWithAsyncAction(val completeAllTasks: Boolean) : FlowLogic() { + @Suspendable + override fun call() { + val scv = serviceHub.cordaService(WorkerService::class.java) + executeAsync(WorkerServiceTask(completeAllTasks, scv)) + } + } + + private class WorkerServiceTask(val completeAllTasks: Boolean, val service: WorkerService) : FlowAsyncOperation { + override fun execute(): CordaFuture { + return service.performTask(completeAllTasks) + } + } + + /** A dummy worker service that queues up tasks and allows clearing the entire task backlog. */ + @CordaService + class WorkerService(val serviceHub: AppServiceHub) : SingletonSerializeAsToken() { + private val pendingTasks = ConcurrentLinkedQueue>() + val pendingCount: Int get() = pendingTasks.count() + + fun performTask(completeAllTasks: Boolean): CordaFuture { + val taskFuture = openFuture() + pendingTasks.add(taskFuture) + if (completeAllTasks) { + synchronized(this) { + while (!pendingTasks.isEmpty()) { + val fut = pendingTasks.poll()!! + fut.set(Unit) + } + } + } + return taskFuture + } + } +} +