CORDA-1494: Re-enable notarisation retries in the new state machine manager (#3295)

* Remove all notion of message level retry.

* Introduce randomness into de-duplication IDs based on the session rather than the flow, in support of idempotent flows.

* CORDA-1494: Re-enable notarisation retries in the new state machine manager.

The original message-based retry approach does not work well with the new
flow state machine due to the way sessions are handled. We decided to move
the retry logic to flow-level: introduce RetryableFlow (named `TimedFlow` in
the code) that won't have checkpoints persisted and will be restarted after
a configurable timeout if it does not complete in time.

The RetryableFlow functionality will be internal for now, as it's mainly
tailored for the notary client flow, and there are many subtle ways it can
fail when used with arbitrary flows.
This commit is contained in:
Andrius Dagys 2018-06-07 08:45:32 +01:00 committed by GitHub
parent 6a2e50b730
commit 0978d041a3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
26 changed files with 545 additions and 337 deletions

View File

@ -2225,7 +2225,7 @@ public final class net.corda.core.flows.NotaryFlow extends java.lang.Object
## ##
@DoNotImplement @DoNotImplement
@InitiatingFlow @InitiatingFlow
public static class net.corda.core.flows.NotaryFlow$Client extends net.corda.core.flows.FlowLogic public static class net.corda.core.flows.NotaryFlow$Client extends net.corda.core.flows.FlowLogic implements net.corda.core.internal.TimedFlow
public <init>(net.corda.core.transactions.SignedTransaction) public <init>(net.corda.core.transactions.SignedTransaction)
public <init>(net.corda.core.transactions.SignedTransaction, net.corda.core.utilities.ProgressTracker) public <init>(net.corda.core.transactions.SignedTransaction, net.corda.core.utilities.ProgressTracker)
@Suspendable @Suspendable

View File

@ -7,6 +7,7 @@ import net.corda.core.contracts.TimeWindow
import net.corda.core.crypto.TransactionSignature import net.corda.core.crypto.TransactionSignature
import net.corda.core.identity.Party import net.corda.core.identity.Party
import net.corda.core.internal.FetchDataFlow import net.corda.core.internal.FetchDataFlow
import net.corda.core.internal.TimedFlow
import net.corda.core.internal.notary.generateSignature import net.corda.core.internal.notary.generateSignature
import net.corda.core.internal.notary.validateSignatures import net.corda.core.internal.notary.validateSignatures
import net.corda.core.internal.pushToLoggingContext import net.corda.core.internal.pushToLoggingContext
@ -31,8 +32,10 @@ class NotaryFlow {
*/ */
@DoNotImplement @DoNotImplement
@InitiatingFlow @InitiatingFlow
open class Client(private val stx: SignedTransaction, open class Client(
override val progressTracker: ProgressTracker) : FlowLogic<List<TransactionSignature>>() { private val stx: SignedTransaction,
override val progressTracker: ProgressTracker
) : FlowLogic<List<TransactionSignature>>(), TimedFlow {
constructor(stx: SignedTransaction) : this(stx, tracker()) constructor(stx: SignedTransaction) : this(stx, tracker())
companion object { companion object {

View File

@ -19,13 +19,9 @@ sealed class FlowIORequest<out R : Any> {
* @property shouldRetrySend specifies whether the send should be retried. * @property shouldRetrySend specifies whether the send should be retried.
*/ */
data class Send( data class Send(
val sessionToMessage: Map<FlowSession, SerializedBytes<Any>>, val sessionToMessage: Map<FlowSession, SerializedBytes<Any>>
val shouldRetrySend: Boolean
) : FlowIORequest<Unit>() { ) : FlowIORequest<Unit>() {
override fun toString() = "Send(" + override fun toString() = "Send(sessionToMessage=${sessionToMessage.mapValues { it.value.hash }})"
"sessionToMessage=${sessionToMessage.mapValues { it.value.hash }}, " +
"shouldRetrySend=$shouldRetrySend" +
")"
} }
/** /**

View File

@ -0,0 +1,23 @@
package net.corda.core.internal
/**
 * Marker interface (it declares no members) for a flow that returns the same result when it is replayed from the
 * beginning. Any side effects the flow causes must be idempotent as well.
 *
 * Flow idempotency allows skipping the persisting of checkpoints, which gives better performance.
 */
interface IdempotentFlow
/**
 * An [IdempotentFlow] that needs to be replayed if it does not complete within a certain timeout.
 *
 * Example use: the notary client flow. When the client sends a request to an HA notary cluster, the request is
 * accepted by one of the cluster members, but that member might crash before returning a response. The client flow
 * would then be stuck waiting for that member to come back up. Retrying the notary flow re-sends the request to the
 * next available notary cluster member.
 *
 * Note that any sub-flows called by a [TimedFlow] are assumed to be [IdempotentFlow] and will NOT have checkpoints
 * persisted. Otherwise, it wouldn't be possible to correctly reset the [TimedFlow].
 */
// TODO: allow specifying retry settings per flow
interface TimedFlow : IdempotentFlow

View File

@ -0,0 +1,189 @@
package net.corda.node.services
import co.paralleluniverse.fibers.Suspendable
import com.nhaarman.mockito_kotlin.doReturn
import com.nhaarman.mockito_kotlin.mock
import com.nhaarman.mockito_kotlin.whenever
import net.corda.core.contracts.AlwaysAcceptAttachmentConstraint
import net.corda.core.crypto.SecureHash
import net.corda.core.flows.*
import net.corda.core.identity.CordaX500Name
import net.corda.core.identity.Party
import net.corda.core.internal.FlowIORequest
import net.corda.core.internal.ResolveTransactionsFlow
import net.corda.core.internal.notary.NotaryServiceFlow
import net.corda.core.internal.notary.TrustedAuthorityNotaryService
import net.corda.core.internal.notary.UniquenessProvider
import net.corda.core.node.AppServiceHub
import net.corda.core.node.NotaryInfo
import net.corda.core.node.services.CordaService
import net.corda.core.transactions.SignedTransaction
import net.corda.core.transactions.TransactionBuilder
import net.corda.core.utilities.seconds
import net.corda.node.internal.StartedNode
import net.corda.node.services.config.NodeConfiguration
import net.corda.node.services.config.NotaryConfig
import net.corda.node.services.config.P2PMessagingRetryConfiguration
import net.corda.nodeapi.internal.DevIdentityGenerator
import net.corda.nodeapi.internal.network.NetworkParametersCopier
import net.corda.testing.common.internal.testNetworkParameters
import net.corda.testing.contracts.DummyContract
import net.corda.testing.core.dummyCommand
import net.corda.testing.core.singleIdentity
import net.corda.testing.internal.LogHelper
import net.corda.testing.node.InMemoryMessagingNetwork
import net.corda.testing.node.MockNetworkParameters
import net.corda.testing.node.internal.InternalMockNetwork
import net.corda.testing.node.internal.InternalMockNodeParameters
import net.corda.testing.node.internal.startFlow
import org.junit.AfterClass
import org.junit.Before
import org.junit.BeforeClass
import org.junit.Test
import org.slf4j.MDC
import java.security.PublicKey
import java.util.concurrent.atomic.AtomicInteger
/**
 * Tests that a timed flow which receives no reply from the notary is restarted and then completes.
 *
 * One [InternalMockNetwork] is shared by all tests in the class. It hosts a custom notary cluster of
 * [CLUSTER_SIZE] nodes running [TestNotaryFlow], plus a single client node. [TestNotaryFlow] never answers the
 * first request counted by [requestsReceived], so a test can only finish once the client request is made again.
 */
class TimedFlowTests {
    companion object {
        /** The notary nodes don't run any consensus protocol, so 2 nodes are sufficient for the purpose of this test. */
        private const val CLUSTER_SIZE = 2

        /**
         * A shared counter across all notary service nodes. It is zeroed in place before every test (see
         * [resetCounter]) rather than being replaced, so every node thread always observes the same instance.
         */
        var requestsReceived: AtomicInteger = AtomicInteger(0)

        private lateinit var mockNet: InternalMockNetwork
        private lateinit var notary: Party
        private lateinit var node: StartedNode<InternalMockNetwork.MockNode>

        init {
            LogHelper.setLevel("+net.corda.flow", "+net.corda.testing.node", "+net.corda.node.services.messaging")
        }

        /** Creates the mock network once for the whole class and starts the notary cluster and the client node. */
        @BeforeClass
        @JvmStatic
        fun setup() {
            mockNet = InternalMockNetwork(
                    listOf("net.corda.testing.contracts", "net.corda.node.services"),
                    MockNetworkParameters().withServicePeerAllocationStrategy(InMemoryMessagingNetwork.ServicePeerAllocationStrategy.RoundRobin()),
                    threadPerNode = true
            )
            val (notaryParty, clientNode) = startClusterAndNode(mockNet)
            notary = notaryParty
            node = clientNode
        }

        /** Stops every node of the shared mock network once all tests have run. */
        @AfterClass
        @JvmStatic
        fun stopNodes() {
            mockNet.stopNodes()
        }

        /**
         * Starts [CLUSTER_SIZE] notary nodes sharing one composite notary identity, followed by the "Alice" node.
         *
         * @param mockNet the network on which the nodes are created.
         * @return the notary cluster identity paired with the started Alice node.
         */
        private fun startClusterAndNode(mockNet: InternalMockNetwork): Pair<Party, StartedNode<InternalMockNetwork.MockNode>> {
            val replicaIds = (0 until CLUSTER_SIZE)
            // The identity is generated into the base directories the notary nodes are about to be given,
            // which is why this has to happen before any node is created.
            val notaryIdentity = DevIdentityGenerator.generateDistributedNotaryCompositeIdentity(
                    replicaIds.map { mockNet.baseDirectory(mockNet.nextNodeId + it) },
                    CordaX500Name("Custom Notary", "Zurich", "CH"))

            val networkParameters = NetworkParametersCopier(testNetworkParameters(listOf(NotaryInfo(notaryIdentity, true))))
            val notaryConfig = mock<NotaryConfig> {
                whenever(it.custom).thenReturn(true)
                whenever(it.isClusterConfig).thenReturn(true)
                whenever(it.validating).thenReturn(true)
            }

            val notaryNodes = replicaIds.map { _ ->
                mockNet.createUnstartedNode(InternalMockNodeParameters(configOverrides = { conf: NodeConfiguration ->
                    doReturn(notaryConfig).whenever(conf).notary
                }))
            }

            val aliceNode = mockNet.createUnstartedNode(
                    InternalMockNodeParameters(
                            legalName = CordaX500Name("Alice", "AliceCorp", "GB"),
                            configOverrides = { conf: NodeConfiguration ->
                                // NOTE(review): presumably (redelivery delay, max retry count, backoff base) - confirm
                                // against P2PMessagingRetryConfiguration.
                                val retryConfig = P2PMessagingRetryConfiguration(1.seconds, 3, 1.0)
                                doReturn(retryConfig).whenever(conf).p2pMessagingRetry
                            }
                    )
            )

            // MockNetwork doesn't support notary clusters, so we create all the nodes we need unstarted, and then install the
            // network-parameters in their directories before they're started.
            val startedNodes = (notaryNodes + aliceNode).map { unstarted ->
                networkParameters.install(mockNet.baseDirectory(unstarted.id))
                unstarted.start()
            }
            // Alice was appended last, so the last started node is the client node.
            return Pair(notaryIdentity, startedNodes.last())
        }
    }

    /**
     * Zeroes the shared request counter before each test.
     *
     * The existing [AtomicInteger] is reset in place instead of being replaced: the property is a plain
     * (non-volatile) field that is also read by the notary node threads, so swapping the reference gives those
     * threads no guarantee of seeing the new instance.
     */
    @Before
    fun resetCounter() {
        requestsReceived.set(0)
    }

    /** Runs the notary client flow directly and checks that the signatures it returns complete the transaction. */
    @Test
    fun `timed flows are restarted`() {
        node.run {
            val issueTx = signInitialTransaction(notary) {
                setTimeWindow(services.clock.instant(), 30.seconds)
                addOutputState(DummyContract.SingleOwnerState(owner = info.singleIdentity()), DummyContract.PROGRAM_ID, AlwaysAcceptAttachmentConstraint)
            }
            val flow = NotaryFlow.Client(issueTx)
            // Blocks until the flow finishes; the very first notary request is never answered.
            val notarySignatures = services.startFlow(flow).resultFuture.get()
            (issueTx + notarySignatures).verifyRequiredSignatures()
        }
    }

    /** Same scenario as above, but the transaction is submitted through [FinalityFlow]. */
    @Test
    fun `timed sub-flows are restarted`() {
        node.run {
            val issueTx = signInitialTransaction(notary) {
                setTimeWindow(services.clock.instant(), 30.seconds)
                addOutputState(DummyContract.SingleOwnerState(owner = info.singleIdentity()), DummyContract.PROGRAM_ID, AlwaysAcceptAttachmentConstraint)
            }
            val flow = FinalityFlow(issueTx)
            val stx = services.startFlow(flow).resultFuture.get()
            stx.verifyRequiredSignatures()
        }
    }

    /**
     * Builds a transaction for [notary] carrying a dummy command for this node's single identity, applies [block]
     * to the builder, and signs the result with this node's services.
     *
     * @param notary the notary the transaction is built for.
     * @param block extra configuration applied to the [TransactionBuilder] after the command has been added.
     * @return the signed transaction.
     */
    private fun StartedNode<InternalMockNetwork.MockNode>.signInitialTransaction(notary: Party, block: TransactionBuilder.() -> Any?): SignedTransaction {
        return services.signInitialTransaction(
                TransactionBuilder(notary).apply {
                    addCommand(dummyCommand(services.myInfo.singleIdentity().owningKey))
                    block()
                }
        )
    }

    /** Notary service used by the cluster nodes: it has a mocked uniqueness provider and hands out [TestNotaryFlow]. */
    @CordaService
    private class TestNotaryService(override val services: AppServiceHub, override val notaryIdentityKey: PublicKey) : TrustedAuthorityNotaryService() {
        override val uniquenessProvider = mock<UniquenessProvider>()
        override fun createServiceFlow(otherPartySession: FlowSession): FlowLogic<Void?> = TestNotaryFlow(otherPartySession, this)
        override fun start() {}
        override fun stop() {}
    }

    /** A notary flow that will yield without returning a response on the very first received request. */
    private class TestNotaryFlow(otherSide: FlowSession, service: TestNotaryService) : NotaryServiceFlow(otherSide, service) {
        @Suspendable
        override fun validateRequest(requestPayload: NotarisationPayload): TransactionParts {
            val myIdentity = serviceHub.myInfo.legalIdentities.first()
            MDC.put("name", myIdentity.name.toString())
            logger.info("Received a request from ${otherSideSession.counterparty.name}")
            val stx = requestPayload.signedTransaction
            subFlow(ResolveTransactionsFlow(stx, otherSideSession))

            // The counter is shared by every notary node, so only the first request seen by the whole cluster is ignored.
            val isFirstRequest = TimedFlowTests.requestsReceived.getAndIncrement() == 0
            if (isFirstRequest) {
                logger.info("Ignoring")
                // Waiting forever
                stateMachine.suspend(FlowIORequest.WaitForLedgerCommit(SecureHash.randomSHA256()), false)
            } else {
                logger.info("Processing")
            }
            return TransactionParts(stx.id, stx.inputs, stx.tx.timeWindow, stx.notary)
        }
    }
}

View File

@ -2,9 +2,7 @@ package net.corda.services.messaging
import net.corda.core.concurrent.CordaFuture import net.corda.core.concurrent.CordaFuture
import net.corda.core.identity.CordaX500Name import net.corda.core.identity.CordaX500Name
import net.corda.core.internal.concurrent.map
import net.corda.core.internal.concurrent.openFuture import net.corda.core.internal.concurrent.openFuture
import net.corda.core.internal.randomOrNull
import net.corda.core.messaging.MessageRecipients import net.corda.core.messaging.MessageRecipients
import net.corda.core.messaging.SingleMessageRecipient import net.corda.core.messaging.SingleMessageRecipient
import net.corda.core.serialization.CordaSerializable import net.corda.core.serialization.CordaSerializable
@ -15,7 +13,6 @@ import net.corda.core.utilities.seconds
import net.corda.node.services.messaging.MessagingService import net.corda.node.services.messaging.MessagingService
import net.corda.node.services.messaging.ReceivedMessage import net.corda.node.services.messaging.ReceivedMessage
import net.corda.node.services.messaging.send import net.corda.node.services.messaging.send
import net.corda.testing.core.ALICE_NAME
import net.corda.testing.driver.DriverDSL import net.corda.testing.driver.DriverDSL
import net.corda.testing.driver.DriverParameters import net.corda.testing.driver.DriverParameters
import net.corda.testing.driver.InProcess import net.corda.testing.driver.InProcess
@ -26,10 +23,7 @@ import net.corda.testing.node.NotarySpec
import org.assertj.core.api.Assertions.assertThat import org.assertj.core.api.Assertions.assertThat
import org.junit.Test import org.junit.Test
import java.util.* import java.util.*
import java.util.concurrent.CountDownLatch
import java.util.concurrent.TimeUnit
import java.util.concurrent.atomic.AtomicBoolean import java.util.concurrent.atomic.AtomicBoolean
import java.util.concurrent.atomic.AtomicInteger
class P2PMessagingTest { class P2PMessagingTest {
private companion object { private companion object {
@ -43,128 +37,12 @@ class P2PMessagingTest {
} }
} }
@Test
fun `distributed service requests are retried if one of the nodes in the cluster goes down without sending a response`() {
startDriverWithDistributedService { distributedServiceNodes ->
val alice = startAlice()
val serviceAddress = alice.services.networkMapCache.run {
val notaryParty = notaryIdentities.randomOrNull()!!
alice.internalServices.networkService.getAddressOfParty(getPartyInfo(notaryParty)!!)
}
val responseMessage = "response"
val crashingNodes = simulateCrashingNodes(distributedServiceNodes, responseMessage)
// Send a single request with retry
val responseFuture = alice.receiveFrom(serviceAddress, retryId = 0)
crashingNodes.firstRequestReceived.await(5, TimeUnit.SECONDS)
// The request wasn't successful.
assertThat(responseFuture.isDone).isFalse()
crashingNodes.ignoreRequests = false
// The retry should be successful.
val response = responseFuture.getOrThrow(10.seconds)
assertThat(response).isEqualTo(responseMessage)
}
}
@Test
fun `distributed service request retries are persisted across client node restarts`() {
startDriverWithDistributedService { distributedServiceNodes ->
val alice = startAlice()
val serviceAddress = alice.services.networkMapCache.run {
val notaryParty = notaryIdentities.randomOrNull()!!
alice.internalServices.networkService.getAddressOfParty(getPartyInfo(notaryParty)!!)
}
val responseMessage = "response"
val crashingNodes = simulateCrashingNodes(distributedServiceNodes, responseMessage)
// Send a single request with retry
alice.receiveFrom(serviceAddress, retryId = 0)
// Wait until the first request is received
crashingNodes.firstRequestReceived.await()
// Stop alice's node after we ensured that the first request was delivered and ignored.
alice.stop()
val numberOfRequestsReceived = crashingNodes.requestsReceived.get()
assertThat(numberOfRequestsReceived).isGreaterThanOrEqualTo(1)
crashingNodes.ignoreRequests = false
// Restart the node and expect a response
val aliceRestarted = startAlice()
val responseFuture = openFuture<Any>()
aliceRestarted.internalServices.networkService.runOnNextMessage("test.response") {
responseFuture.set(it.data.deserialize())
}
val response = responseFuture.getOrThrow()
assertThat(crashingNodes.requestsReceived.get()).isGreaterThan(numberOfRequestsReceived)
assertThat(response).isEqualTo(responseMessage)
}
}
private fun startDriverWithDistributedService(dsl: DriverDSL.(List<InProcess>) -> Unit) { private fun startDriverWithDistributedService(dsl: DriverDSL.(List<InProcess>) -> Unit) {
driver(DriverParameters(startNodesInProcess = true, notarySpecs = listOf(NotarySpec(DISTRIBUTED_SERVICE_NAME, cluster = ClusterSpec.Raft(clusterSize = 2))))) { driver(DriverParameters(startNodesInProcess = true, notarySpecs = listOf(NotarySpec(DISTRIBUTED_SERVICE_NAME, cluster = ClusterSpec.Raft(clusterSize = 2))))) {
dsl(defaultNotaryHandle.nodeHandles.getOrThrow().map { (it as InProcess) }) dsl(defaultNotaryHandle.nodeHandles.getOrThrow().map { (it as InProcess) })
} }
} }
private fun DriverDSL.startAlice(): InProcess {
return startNode(providedName = ALICE_NAME, customOverrides = mapOf("p2pMessagingRetry" to mapOf(
"messageRedeliveryDelay" to 1.seconds, "backoffBase" to 1.0, "maxRetryCount" to 3)))
.map { (it as InProcess) }
.getOrThrow()
}
data class CrashingNodes(
val firstRequestReceived: CountDownLatch,
val requestsReceived: AtomicInteger,
var ignoreRequests: Boolean
)
/**
* Sets up the [distributedServiceNodes] to respond to "test.request" requests. All nodes will receive requests and
* either ignore them or respond to "test.response", depending on the value of [CrashingNodes.ignoreRequests],
* initially set to true. This may be used to simulate scenarios where nodes receive request messages but crash
* before sending back a response.
*/
private fun simulateCrashingNodes(distributedServiceNodes: List<InProcess>, responseMessage: String): CrashingNodes {
val crashingNodes = CrashingNodes(
requestsReceived = AtomicInteger(0),
firstRequestReceived = CountDownLatch(1),
ignoreRequests = true
)
distributedServiceNodes.forEach {
val nodeName = it.services.myInfo.legalIdentitiesAndCerts.first().name
it.internalServices.networkService.addMessageHandler("test.request") { netMessage, _, handler ->
crashingNodes.requestsReceived.incrementAndGet()
crashingNodes.firstRequestReceived.countDown()
// The node which receives the first request will ignore all requests
print("$nodeName: Received request - ")
if (crashingNodes.ignoreRequests) {
println("ignoring")
// Requests are ignored to simulate a service node crashing before sending back a response.
// A retry by the client will result in the message being redelivered to another node in the service cluster.
} else {
println("sending response")
val request = netMessage.data.deserialize<TestRequest>()
val response = it.internalServices.networkService.createMessage("test.response", responseMessage.serialize().bytes)
it.internalServices.networkService.send(response, request.replyTo)
}
handler.afterDatabaseTransaction()
}
}
return crashingNodes
}
private fun assertAllNodesAreUsed(participatingServiceNodes: List<InProcess>, serviceName: CordaX500Name, originatingNode: InProcess) { private fun assertAllNodesAreUsed(participatingServiceNodes: List<InProcess>, serviceName: CordaX500Name, originatingNode: InProcess) {
// Setup each node in the distributed service to return back it's NodeInfo so that we can know which node is being used // Setup each node in the distributed service to return back it's NodeInfo so that we can know which node is being used
participatingServiceNodes.forEach { node -> participatingServiceNodes.forEach { node ->
@ -195,12 +73,12 @@ class P2PMessagingTest {
} }
} }
private fun InProcess.receiveFrom(target: MessageRecipients, retryId: Long? = null): CordaFuture<Any> { private fun InProcess.receiveFrom(target: MessageRecipients): CordaFuture<Any> {
val response = openFuture<Any>() val response = openFuture<Any>()
internalServices.networkService.runOnNextMessage("test.response") { netMessage -> internalServices.networkService.runOnNextMessage("test.response") { netMessage ->
response.set(netMessage.data.deserialize()) response.set(netMessage.data.deserialize())
} }
internalServices.networkService.send("test.request", TestRequest(replyTo = internalServices.networkService.myAddress), target, retryId = retryId) internalServices.networkService.send("test.request", TestRequest(replyTo = internalServices.networkService.myAddress), target)
return response return response
} }

View File

@ -64,8 +64,6 @@ interface MessagingService {
* There is no way to know if a message has been received. If your flow requires this, you need the recipient * There is no way to know if a message has been received. If your flow requires this, you need the recipient
* to send an ACK message back. * to send an ACK message back.
* *
* @param retryId if provided the message will be scheduled for redelivery until [cancelRedelivery] is called for this id.
* Note that this feature should only be used when the target is an idempotent distributed service, e.g. a notary.
* @param sequenceKey an object that may be used to enable a parallel [MessagingService] implementation. Two * @param sequenceKey an object that may be used to enable a parallel [MessagingService] implementation. Two
* subsequent send()s with the same [sequenceKey] (up to equality) are guaranteed to be delivered in the same * subsequent send()s with the same [sequenceKey] (up to equality) are guaranteed to be delivered in the same
* sequence the send()s were called. By default this is chosen conservatively to be [target]. * sequence the send()s were called. By default this is chosen conservatively to be [target].
@ -74,7 +72,6 @@ interface MessagingService {
fun send( fun send(
message: Message, message: Message,
target: MessageRecipients, target: MessageRecipients,
retryId: Long? = null,
sequenceKey: Any = target sequenceKey: Any = target
) )
@ -82,7 +79,6 @@ interface MessagingService {
data class AddressedMessage( data class AddressedMessage(
val message: Message, val message: Message,
val target: MessageRecipients, val target: MessageRecipients,
val retryId: Long? = null,
val sequenceKey: Any = target val sequenceKey: Any = target
) )
@ -95,9 +91,6 @@ interface MessagingService {
@Suspendable @Suspendable
fun send(addressedMessages: List<AddressedMessage>) fun send(addressedMessages: List<AddressedMessage>)
/** Cancels the scheduled message redelivery for the specified [retryId] */
fun cancelRedelivery(retryId: Long)
/** /**
* Returns an initialised [Message] with the current time, etc, already filled in. * Returns an initialised [Message] with the current time, etc, already filled in.
* *
@ -115,7 +108,7 @@ interface MessagingService {
val myAddress: SingleMessageRecipient val myAddress: SingleMessageRecipient
} }
fun MessagingService.send(topicSession: String, payload: Any, to: MessageRecipients, deduplicationId: SenderDeduplicationId = SenderDeduplicationId(DeduplicationId.createRandom(newSecureRandom()), ourSenderUUID), retryId: Long? = null, additionalHeaders: Map<String, String> = emptyMap()) = send(createMessage(topicSession, payload.serialize().bytes, deduplicationId, additionalHeaders), to, retryId) fun MessagingService.send(topicSession: String, payload: Any, to: MessageRecipients, deduplicationId: SenderDeduplicationId = SenderDeduplicationId(DeduplicationId.createRandom(newSecureRandom()), ourSenderUUID), additionalHeaders: Map<String, String> = emptyMap()) = send(createMessage(topicSession, payload.serialize().bytes, deduplicationId, additionalHeaders), to)
interface MessageHandlerRegistration interface MessageHandlerRegistration

View File

@ -15,7 +15,11 @@ import net.corda.core.serialization.SingletonSerializeAsToken
import net.corda.core.serialization.deserialize import net.corda.core.serialization.deserialize
import net.corda.core.serialization.internal.nodeSerializationEnv import net.corda.core.serialization.internal.nodeSerializationEnv
import net.corda.core.serialization.serialize import net.corda.core.serialization.serialize
import net.corda.core.utilities.* import net.corda.core.utilities.ByteSequence
import net.corda.core.utilities.NetworkHostAndPort
import net.corda.core.utilities.OpaqueBytes
import net.corda.core.utilities.contextLogger
import net.corda.core.utilities.trace
import net.corda.node.VersionInfo import net.corda.node.VersionInfo
import net.corda.node.internal.LifecycleSupport import net.corda.node.internal.LifecycleSupport
import net.corda.node.internal.artemis.ReactiveArtemisConsumer import net.corda.node.internal.artemis.ReactiveArtemisConsumer
@ -26,43 +30,41 @@ import net.corda.node.services.statemachine.DeduplicationId
import net.corda.node.services.statemachine.ExternalEvent import net.corda.node.services.statemachine.ExternalEvent
import net.corda.node.services.statemachine.SenderDeduplicationId import net.corda.node.services.statemachine.SenderDeduplicationId
import net.corda.node.utilities.AffinityExecutor import net.corda.node.utilities.AffinityExecutor
import net.corda.node.utilities.PersistentMap
import net.corda.nodeapi.ArtemisTcpTransport.Companion.p2pConnectorTcpTransport import net.corda.nodeapi.ArtemisTcpTransport.Companion.p2pConnectorTcpTransport
import net.corda.nodeapi.internal.ArtemisMessagingComponent import net.corda.nodeapi.internal.ArtemisMessagingComponent
import net.corda.nodeapi.internal.ArtemisMessagingComponent.* import net.corda.nodeapi.internal.ArtemisMessagingComponent.ArtemisAddress
import net.corda.nodeapi.internal.ArtemisMessagingComponent.Companion.BRIDGE_CONTROL import net.corda.nodeapi.internal.ArtemisMessagingComponent.Companion.BRIDGE_CONTROL
import net.corda.nodeapi.internal.ArtemisMessagingComponent.Companion.BRIDGE_NOTIFY import net.corda.nodeapi.internal.ArtemisMessagingComponent.Companion.BRIDGE_NOTIFY
import net.corda.nodeapi.internal.ArtemisMessagingComponent.Companion.JOURNAL_HEADER_SIZE import net.corda.nodeapi.internal.ArtemisMessagingComponent.Companion.JOURNAL_HEADER_SIZE
import net.corda.nodeapi.internal.ArtemisMessagingComponent.Companion.P2PMessagingHeaders import net.corda.nodeapi.internal.ArtemisMessagingComponent.Companion.P2PMessagingHeaders
import net.corda.nodeapi.internal.ArtemisMessagingComponent.Companion.PEERS_PREFIX import net.corda.nodeapi.internal.ArtemisMessagingComponent.Companion.PEERS_PREFIX
import net.corda.nodeapi.internal.ArtemisMessagingComponent.NodeAddress
import net.corda.nodeapi.internal.ArtemisMessagingComponent.RemoteInboxAddress
import net.corda.nodeapi.internal.ArtemisMessagingComponent.ServiceAddress
import net.corda.nodeapi.internal.bridging.BridgeControl import net.corda.nodeapi.internal.bridging.BridgeControl
import net.corda.nodeapi.internal.bridging.BridgeEntry import net.corda.nodeapi.internal.bridging.BridgeEntry
import net.corda.nodeapi.internal.persistence.CordaPersistence import net.corda.nodeapi.internal.persistence.CordaPersistence
import net.corda.nodeapi.internal.persistence.NODE_DATABASE_PREFIX
import net.corda.nodeapi.internal.requireMessageSize import net.corda.nodeapi.internal.requireMessageSize
import org.apache.activemq.artemis.api.core.ActiveMQObjectClosedException import org.apache.activemq.artemis.api.core.ActiveMQObjectClosedException
import org.apache.activemq.artemis.api.core.Message.HDR_DUPLICATE_DETECTION_ID import org.apache.activemq.artemis.api.core.Message.HDR_DUPLICATE_DETECTION_ID
import org.apache.activemq.artemis.api.core.Message.HDR_VALIDATED_USER import org.apache.activemq.artemis.api.core.Message.HDR_VALIDATED_USER
import org.apache.activemq.artemis.api.core.RoutingType import org.apache.activemq.artemis.api.core.RoutingType
import org.apache.activemq.artemis.api.core.SimpleString import org.apache.activemq.artemis.api.core.SimpleString
import org.apache.activemq.artemis.api.core.client.* import org.apache.activemq.artemis.api.core.client.ActiveMQClient
import org.apache.commons.lang.ArrayUtils.EMPTY_BYTE_ARRAY import org.apache.activemq.artemis.api.core.client.ClientConsumer
import org.apache.activemq.artemis.api.core.client.ClientMessage
import org.apache.activemq.artemis.api.core.client.ClientProducer
import org.apache.activemq.artemis.api.core.client.ClientSession
import org.apache.activemq.artemis.api.core.client.ServerLocator
import rx.Observable import rx.Observable
import rx.Subscription import rx.Subscription
import rx.subjects.PublishSubject import rx.subjects.PublishSubject
import java.io.Serializable
import java.security.PublicKey import java.security.PublicKey
import java.time.Instant import java.time.Instant
import java.util.* import java.util.*
import java.util.concurrent.ConcurrentHashMap import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.CountDownLatch import java.util.concurrent.CountDownLatch
import java.util.concurrent.ScheduledFuture
import java.util.concurrent.TimeUnit
import javax.annotation.concurrent.ThreadSafe import javax.annotation.concurrent.ThreadSafe
import javax.persistence.Column
import javax.persistence.Entity
import javax.persistence.Id
import javax.persistence.Lob
/** /**
* This class implements the [MessagingService] API using Apache Artemis, the successor to their ActiveMQ product. * This class implements the [MessagingService] API using Apache Artemis, the successor to their ActiveMQ product.
@ -106,35 +108,12 @@ class P2PMessagingClient(val config: NodeConfiguration,
companion object { companion object {
private val log = contextLogger() private val log = contextLogger()
fun createMessageToRedeliver(): PersistentMap<Long, Pair<Message, MessageRecipients>, RetryMessage, Long> {
return PersistentMap(
toPersistentEntityKey = { it },
fromPersistentEntity = {
Pair(it.key,
Pair(it.message.deserialize(context = SerializationDefaults.STORAGE_CONTEXT),
it.recipients.deserialize(context = SerializationDefaults.STORAGE_CONTEXT))
)
},
toPersistentEntity = { _key: Long, (_message: Message, _recipient: MessageRecipients): Pair<Message, MessageRecipients> ->
RetryMessage().apply {
key = _key
message = _message.serialize(context = SerializationDefaults.STORAGE_CONTEXT).bytes
recipients = _recipient.serialize(context = SerializationDefaults.STORAGE_CONTEXT).bytes
}
},
persistentEntityClass = RetryMessage::class.java
)
}
class NodeClientMessage(override val topic: String, override val data: ByteSequence, override val uniqueMessageId: DeduplicationId, override val senderUUID: String?, override val additionalHeaders: Map<String, String>) : Message { class NodeClientMessage(override val topic: String, override val data: ByteSequence, override val uniqueMessageId: DeduplicationId, override val senderUUID: String?, override val additionalHeaders: Map<String, String>) : Message {
override val debugTimestamp: Instant = Instant.now() override val debugTimestamp: Instant = Instant.now()
override fun toString() = "$topic#${String(data.bytes)}" override fun toString() = "$topic#${String(data.bytes)}"
} }
} }
private val messageMaxRetryCount: Int = config.p2pMessagingRetry.maxRetryCount
private val backoffBase: Double = config.p2pMessagingRetry.backoffBase
private class InnerState { private class InnerState {
var started = false var started = false
var running = false var running = false
@ -150,17 +129,12 @@ class P2PMessagingClient(val config: NodeConfiguration,
fun sendMessage(address: String, message: ClientMessage) = producer!!.send(address, message) fun sendMessage(address: String, message: ClientMessage) = producer!!.send(address, message)
} }
private val messagesToRedeliver = createMessageToRedeliver()
private val scheduledMessageRedeliveries = ConcurrentHashMap<Long, ScheduledFuture<*>>()
/** A registration to handle messages of different types */ /** A registration to handle messages of different types */
data class HandlerRegistration(val topic: String, val callback: Any) : MessageHandlerRegistration data class HandlerRegistration(val topic: String, val callback: Any) : MessageHandlerRegistration
override val myAddress: SingleMessageRecipient = NodeAddress(myIdentity, advertisedAddress) override val myAddress: SingleMessageRecipient = NodeAddress(myIdentity, advertisedAddress)
override val ourSenderUUID = UUID.randomUUID().toString() override val ourSenderUUID = UUID.randomUUID().toString()
private val messageRedeliveryDelaySeconds = config.p2pMessagingRetry.messageRedeliveryDelay.seconds
private val state = ThreadBox(InnerState()) private val state = ThreadBox(InnerState())
private val knownQueues = Collections.newSetFromMap(ConcurrentHashMap<String, Boolean>()) private val knownQueues = Collections.newSetFromMap(ConcurrentHashMap<String, Boolean>())
private val delayStartQueues = Collections.newSetFromMap(ConcurrentHashMap<String, Boolean>()) private val delayStartQueues = Collections.newSetFromMap(ConcurrentHashMap<String, Boolean>())
@ -170,21 +144,6 @@ class P2PMessagingClient(val config: NodeConfiguration,
private val deduplicator = P2PMessageDeduplicator(database) private val deduplicator = P2PMessageDeduplicator(database)
internal var messagingExecutor: MessagingExecutor? = null internal var messagingExecutor: MessagingExecutor? = null
@Entity
@javax.persistence.Table(name = "${NODE_DATABASE_PREFIX}message_retry")
class RetryMessage(
@Id
@Column(name = "message_id", length = 64, nullable = false)
var key: Long = 0,
@Lob
@Column(nullable = false)
var message: ByteArray = EMPTY_BYTE_ARRAY,
@Lob
@Column(nullable = false)
var recipients: ByteArray = EMPTY_BYTE_ARRAY
) : Serializable
fun start() { fun start() {
state.locked { state.locked {
started = true started = true
@ -235,8 +194,6 @@ class P2PMessagingClient(val config: NodeConfiguration,
registerBridgeControl(bridgeSession!!, inboxes.toList()) registerBridgeControl(bridgeSession!!, inboxes.toList())
enumerateBridges(bridgeSession!!, inboxes.toList()) enumerateBridges(bridgeSession!!, inboxes.toList())
} }
resumeMessageRedelivery()
} }
private fun InnerState.registerBridgeControl(session: ClientSession, inboxes: List<String>) { private fun InnerState.registerBridgeControl(session: ClientSession, inboxes: List<String>) {
@ -335,12 +292,6 @@ class P2PMessagingClient(val config: NodeConfiguration,
sendBridgeControl(startupMessage) sendBridgeControl(startupMessage)
} }
private fun resumeMessageRedelivery() {
messagesToRedeliver.forEach { retryId, (message, target) ->
send(message, target, retryId)
}
}
private val shutdownLatch = CountDownLatch(1) private val shutdownLatch = CountDownLatch(1)
/** /**
@ -508,53 +459,15 @@ class P2PMessagingClient(val config: NodeConfiguration,
override fun close() = stop() override fun close() = stop()
@Suspendable @Suspendable
override fun send(message: Message, target: MessageRecipients, retryId: Long?, sequenceKey: Any) { override fun send(message: Message, target: MessageRecipients, sequenceKey: Any) {
requireMessageSize(message.data.size, maxMessageSize) requireMessageSize(message.data.size, maxMessageSize)
messagingExecutor!!.send(message, target) messagingExecutor!!.send(message, target)
retryId?.let {
database.transaction {
messagesToRedeliver.computeIfAbsent(it, { Pair(message, target) })
}
scheduledMessageRedeliveries[it] = nodeExecutor.schedule({
sendWithRetry(0, message, target, retryId)
}, messageRedeliveryDelaySeconds, TimeUnit.SECONDS)
}
} }
@Suspendable @Suspendable
override fun send(addressedMessages: List<MessagingService.AddressedMessage>) { override fun send(addressedMessages: List<MessagingService.AddressedMessage>) {
for ((message, target, retryId, sequenceKey) in addressedMessages) { for ((message, target, sequenceKey) in addressedMessages) {
send(message, target, retryId, sequenceKey) send(message, target, sequenceKey)
}
}
private fun sendWithRetry(retryCount: Int, message: Message, target: MessageRecipients, retryId: Long) {
log.trace { "Attempting to retry #$retryCount message delivery for $retryId" }
if (retryCount >= messageMaxRetryCount) {
log.warn("Reached the maximum number of retries ($messageMaxRetryCount) for message $message redelivery to $target")
scheduledMessageRedeliveries.remove(retryId)
return
}
val messageWithRetryCount = object : Message by message {
override val uniqueMessageId = DeduplicationId("${message.uniqueMessageId.toString}-$retryCount")
}
messagingExecutor!!.send(messageWithRetryCount, target)
scheduledMessageRedeliveries[retryId] = nodeExecutor.schedule({
sendWithRetry(retryCount + 1, message, target, retryId)
}, messageRedeliveryDelaySeconds * Math.pow(backoffBase, retryCount.toDouble()).toLong(), TimeUnit.SECONDS)
}
override fun cancelRedelivery(retryId: Long) {
database.transaction {
messagesToRedeliver.remove(retryId)
}
scheduledMessageRedeliveries[retryId]?.let {
log.trace { "Cancelling message redelivery for retry id $retryId" }
if (!it.isDone) it.cancel(true)
scheduledMessageRedeliveries.remove(retryId)
} }
} }

View File

@ -3,19 +3,18 @@ package net.corda.node.services.schema
import net.corda.core.contracts.ContractState import net.corda.core.contracts.ContractState
import net.corda.core.contracts.FungibleAsset import net.corda.core.contracts.FungibleAsset
import net.corda.core.contracts.LinearState import net.corda.core.contracts.LinearState
import net.corda.node.internal.schemas.NodeInfoSchemaV1
import net.corda.core.schemas.CommonSchemaV1 import net.corda.core.schemas.CommonSchemaV1
import net.corda.core.schemas.MappedSchema import net.corda.core.schemas.MappedSchema
import net.corda.core.schemas.PersistentState import net.corda.core.schemas.PersistentState
import net.corda.core.schemas.QueryableState import net.corda.core.schemas.QueryableState
import net.corda.core.serialization.SingletonSerializeAsToken import net.corda.core.serialization.SingletonSerializeAsToken
import net.corda.node.internal.schemas.NodeInfoSchemaV1
import net.corda.node.services.api.SchemaService import net.corda.node.services.api.SchemaService
import net.corda.node.services.api.SchemaService.SchemaOptions import net.corda.node.services.api.SchemaService.SchemaOptions
import net.corda.node.services.events.NodeSchedulerService import net.corda.node.services.events.NodeSchedulerService
import net.corda.node.services.identity.PersistentIdentityService import net.corda.node.services.identity.PersistentIdentityService
import net.corda.node.services.keys.PersistentKeyManagementService import net.corda.node.services.keys.PersistentKeyManagementService
import net.corda.node.services.messaging.P2PMessageDeduplicator import net.corda.node.services.messaging.P2PMessageDeduplicator
import net.corda.node.services.messaging.P2PMessagingClient
import net.corda.node.services.persistence.DBCheckpointStorage import net.corda.node.services.persistence.DBCheckpointStorage
import net.corda.node.services.persistence.DBTransactionMappingStorage import net.corda.node.services.persistence.DBTransactionMappingStorage
import net.corda.node.services.persistence.DBTransactionStorage import net.corda.node.services.persistence.DBTransactionStorage
@ -45,7 +44,6 @@ class NodeSchemaService(extraSchemas: Set<MappedSchema> = emptySet(), includeNot
NodeSchedulerService.PersistentScheduledState::class.java, NodeSchedulerService.PersistentScheduledState::class.java,
NodeAttachmentService.DBAttachment::class.java, NodeAttachmentService.DBAttachment::class.java,
P2PMessageDeduplicator.ProcessedMessage::class.java, P2PMessageDeduplicator.ProcessedMessage::class.java,
P2PMessagingClient.RetryMessage::class.java,
PersistentIdentityService.PersistentIdentity::class.java, PersistentIdentityService.PersistentIdentity::class.java,
PersistentIdentityService.PersistentIdentityNames::class.java, PersistentIdentityService.PersistentIdentityNames::class.java,
ContractUpgradeServiceImpl.DBContractUpgrade::class.java ContractUpgradeServiceImpl.DBContractUpgrade::class.java

View File

@ -135,6 +135,19 @@ sealed class Action {
* Retry a flow from the last checkpoint, or if there is no checkpoint, restart the flow with the same invocation details. * Retry a flow from the last checkpoint, or if there is no checkpoint, restart the flow with the same invocation details.
*/ */
data class RetryFlowFromSafePoint(val currentState: StateMachineState) : Action() data class RetryFlowFromSafePoint(val currentState: StateMachineState) : Action()
/**
 * Schedule the flow [flowId] to be retried if it does not complete within the timeout period specified in the configuration.
 *
 * Note that this only works with [TimedFlow].
 *
 * @property flowId the run ID of the flow whose retry timeout should be scheduled.
 */
data class ScheduleFlowTimeout(val flowId: StateMachineRunId) : Action()
/**
 * Cancel the retry timeout for flow [flowId]. This must be called when a timed flow completes to prevent
 * unnecessary additional invocations.
 *
 * @property flowId the run ID of the flow whose scheduled retry timeout should be cancelled.
 */
data class CancelFlowTimeout(val flowId: StateMachineRunId) : Action()
} }
/** /**

View File

@ -74,9 +74,10 @@ class ActionExecutorImpl(
is Action.ExecuteAsyncOperation -> executeAsyncOperation(fiber, action) is Action.ExecuteAsyncOperation -> executeAsyncOperation(fiber, action)
is Action.ReleaseSoftLocks -> executeReleaseSoftLocks(action) is Action.ReleaseSoftLocks -> executeReleaseSoftLocks(action)
is Action.RetryFlowFromSafePoint -> executeRetryFlowFromSafePoint(action) is Action.RetryFlowFromSafePoint -> executeRetryFlowFromSafePoint(action)
is Action.ScheduleFlowTimeout -> scheduleFlowTimeout(action)
is Action.CancelFlowTimeout -> cancelFlowTimeout(action)
} }
} }
private fun executeReleaseSoftLocks(action: Action.ReleaseSoftLocks) { private fun executeReleaseSoftLocks(action: Action.ReleaseSoftLocks) {
if (action.uuid != null) services.vaultService.softLockRelease(action.uuid) if (action.uuid != null) services.vaultService.softLockRelease(action.uuid)
} }
@ -234,4 +235,12 @@ class ActionExecutorImpl(
private fun serializeCheckpoint(checkpoint: Checkpoint): SerializedBytes<Checkpoint> { private fun serializeCheckpoint(checkpoint: Checkpoint): SerializedBytes<Checkpoint> {
return checkpoint.serialize(context = checkpointSerializationContext) return checkpoint.serialize(context = checkpointSerializationContext)
} }
/** Executes [Action.CancelFlowTimeout] by asking the state machine manager to cancel the timeout of the action's flow. */
private fun cancelFlowTimeout(action: Action.CancelFlowTimeout) {
    val flowId = action.flowId
    stateMachineManager.cancelFlowTimeout(flowId)
}
/** Executes [Action.ScheduleFlowTimeout] by asking the state machine manager to schedule a timeout for the action's flow. */
private fun scheduleFlowTimeout(action: Action.ScheduleFlowTimeout) {
    val flowId = action.flowId
    stateMachineManager.scheduleFlowTimeout(flowId)
}
} }

View File

@ -18,16 +18,16 @@ data class DeduplicationId(val toString: String) {
* creating IDs in case the message-generating flow logic is replayed on hard failure. * creating IDs in case the message-generating flow logic is replayed on hard failure.
* *
* A normal deduplication ID consists of: * A normal deduplication ID consists of:
* 1. A deduplication seed set per flow. This is either the flow's ID or in case of an initated flow the * 1. A deduplication seed set per session. This is the initiator's session ID, with a prefix for initiator
* initiator's session ID. * or initiated.
* 2. The number of *clean* suspends since the start of the flow. * 2. The number of *clean* suspends since the start of the flow.
* 3. An optional additional index, for cases where several messages are sent as part of the state transition. * 3. An optional additional index, for cases where several messages are sent as part of the state transition.
* Note that care must be taken with this index, it must be a deterministic counter. For example a naive * Note that care must be taken with this index, it must be a deterministic counter. For example a naive
* iteration over a HashMap will produce a different list of indeces than a previous run, causing the * iteration over a HashMap will produce a different list of indeces than a previous run, causing the
* message-id map to change, which means deduplication will not happen correctly. * message-id map to change, which means deduplication will not happen correctly.
*/ */
fun createForNormal(checkpoint: Checkpoint, index: Int): DeduplicationId { fun createForNormal(checkpoint: Checkpoint, index: Int, session: SessionState): DeduplicationId {
return DeduplicationId("N-${checkpoint.deduplicationSeed}-${checkpoint.numberOfSuspends}-$index") return DeduplicationId("N-${session.deduplicationSeed}-${checkpoint.numberOfSuspends}-$index")
} }
/** /**

View File

@ -73,8 +73,7 @@ class FlowSessionImpl(
@Suspendable @Suspendable
override fun send(payload: Any, maySkipCheckpoint: Boolean) { override fun send(payload: Any, maySkipCheckpoint: Boolean) {
val request = FlowIORequest.Send( val request = FlowIORequest.Send(
sessionToMessage = mapOf(this to payload.serialize(context = SerializationDefaults.P2P_CONTEXT)), sessionToMessage = mapOf(this to payload.serialize(context = SerializationDefaults.P2P_CONTEXT))
shouldRetrySend = false
) )
return getFlowStateMachine().suspend(request, maySkipCheckpoint) return getFlowStateMachine().suspend(request, maySkipCheckpoint)
} }

View File

@ -138,7 +138,7 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
val transitionExecutor = getTransientField(TransientValues::transitionExecutor) val transitionExecutor = getTransientField(TransientValues::transitionExecutor)
val eventQueue = getTransientField(TransientValues::eventQueue) val eventQueue = getTransientField(TransientValues::eventQueue)
try { try {
eventLoop@while (true) { eventLoop@ while (true) {
val nextEvent = eventQueue.receive() val nextEvent = eventQueue.receive()
val continuation = processEvent(transitionExecutor, nextEvent) val continuation = processEvent(transitionExecutor, nextEvent)
when (continuation) { when (continuation) {
@ -326,11 +326,14 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
parkAndSerialize { _, _ -> parkAndSerialize { _, _ ->
logger.trace { "Suspended on $ioRequest" } logger.trace { "Suspended on $ioRequest" }
// Will skip checkpoint if there are any idempotent flows in the subflow stack.
val skipPersistingCheckpoint = containsIdempotentFlows() || maySkipCheckpoint
contextTransactionOrNull = transaction.value contextTransactionOrNull = transaction.value
val event = try { val event = try {
Event.Suspend( Event.Suspend(
ioRequest = ioRequest, ioRequest = ioRequest,
maySkipCheckpoint = maySkipCheckpoint, maySkipCheckpoint = skipPersistingCheckpoint,
fiber = this.serialize(context = serializationContext.value) fiber = this.serialize(context = serializationContext.value)
) )
} catch (throwable: Throwable) { } catch (throwable: Throwable) {
@ -354,6 +357,11 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
)) ))
} }
/**
 * Reports whether any entry on this fiber's current sub-flow stack has a flow class that is
 * an [IdempotentFlow] (or a subtype of it).
 */
private fun containsIdempotentFlows(): Boolean =
    snapshot().checkpoint.subFlowStack.any { subFlow ->
        IdempotentFlow::class.java.isAssignableFrom(subFlow.flowClass)
    }
@Suspendable @Suspendable
override fun scheduleEvent(event: Event) { override fun scheduleEvent(event: Event) {
getTransientField(TransientValues::eventQueue).send(event) getTransientField(TransientValues::eventQueue).send(event)

View File

@ -0,0 +1,9 @@
package net.corda.node.services.statemachine
import net.corda.core.CordaException
/**
 * This exception is fired once the retry timeout of a [TimedFlow] expires.
 * It will indicate to the flow hospital to restart the flow.
 *
 * @property maxRetries the retry limit carried along with the exception, for whoever handles it to enforce.
 */
data class FlowTimeoutException(val maxRetries: Int) : CordaException("replaying flow from the last checkpoint")

View File

@ -8,25 +8,17 @@ import co.paralleluniverse.strands.channels.Channels
import com.codahale.metrics.Gauge import com.codahale.metrics.Gauge
import net.corda.core.concurrent.CordaFuture import net.corda.core.concurrent.CordaFuture
import net.corda.core.context.InvocationContext import net.corda.core.context.InvocationContext
import net.corda.core.context.InvocationOrigin
import net.corda.core.flows.FlowException import net.corda.core.flows.FlowException
import net.corda.core.flows.FlowInfo import net.corda.core.flows.FlowInfo
import net.corda.core.flows.FlowLogic import net.corda.core.flows.FlowLogic
import net.corda.core.flows.StateMachineRunId import net.corda.core.flows.StateMachineRunId
import net.corda.core.identity.Party import net.corda.core.identity.Party
import net.corda.core.internal.FlowStateMachine import net.corda.core.internal.*
import net.corda.core.internal.ThreadBox
import net.corda.core.internal.bufferUntilSubscribed
import net.corda.core.internal.castIfPossible
import net.corda.core.internal.concurrent.OpenFuture import net.corda.core.internal.concurrent.OpenFuture
import net.corda.core.internal.concurrent.map import net.corda.core.internal.concurrent.map
import net.corda.core.internal.concurrent.openFuture import net.corda.core.internal.concurrent.openFuture
import net.corda.core.messaging.DataFeed import net.corda.core.messaging.DataFeed
import net.corda.core.serialization.SerializationContext import net.corda.core.serialization.*
import net.corda.core.serialization.SerializationDefaults
import net.corda.core.serialization.SerializedBytes
import net.corda.core.serialization.deserialize
import net.corda.core.serialization.serialize
import net.corda.core.utilities.ProgressTracker import net.corda.core.utilities.ProgressTracker
import net.corda.core.utilities.Try import net.corda.core.utilities.Try
import net.corda.core.utilities.contextLogger import net.corda.core.utilities.contextLogger
@ -39,11 +31,6 @@ import net.corda.node.services.messaging.DeduplicationHandler
import net.corda.node.services.messaging.ReceivedMessage import net.corda.node.services.messaging.ReceivedMessage
import net.corda.node.services.statemachine.FlowStateMachineImpl.Companion.createSubFlowVersion import net.corda.node.services.statemachine.FlowStateMachineImpl.Companion.createSubFlowVersion
import net.corda.node.services.statemachine.interceptors.* import net.corda.node.services.statemachine.interceptors.*
import net.corda.node.services.statemachine.interceptors.DumpHistoryOnErrorInterceptor
import net.corda.node.services.statemachine.interceptors.FiberDeserializationChecker
import net.corda.node.services.statemachine.interceptors.FiberDeserializationCheckingInterceptor
import net.corda.node.services.statemachine.interceptors.HospitalisingInterceptor
import net.corda.node.services.statemachine.interceptors.PrintingInterceptor
import net.corda.node.services.statemachine.transitions.StateMachine import net.corda.node.services.statemachine.transitions.StateMachine
import net.corda.node.utilities.AffinityExecutor import net.corda.node.utilities.AffinityExecutor
import net.corda.nodeapi.internal.persistence.CordaPersistence import net.corda.nodeapi.internal.persistence.CordaPersistence
@ -55,11 +42,11 @@ import rx.Observable
import rx.subjects.PublishSubject import rx.subjects.PublishSubject
import java.security.SecureRandom import java.security.SecureRandom
import java.util.* import java.util.*
import java.util.concurrent.ConcurrentHashMap import java.util.concurrent.*
import java.util.concurrent.ExecutorService
import java.util.concurrent.locks.ReentrantLock import java.util.concurrent.locks.ReentrantLock
import javax.annotation.concurrent.ThreadSafe import javax.annotation.concurrent.ThreadSafe
import kotlin.collections.ArrayList import kotlin.collections.ArrayList
import kotlin.collections.HashMap
import kotlin.concurrent.withLock import kotlin.concurrent.withLock
import kotlin.streams.toList import kotlin.streams.toList
@ -83,18 +70,28 @@ class SingleThreadedStateMachineManager(
private class Flow(val fiber: FlowStateMachineImpl<*>, val resultFuture: OpenFuture<Any?>) private class Flow(val fiber: FlowStateMachineImpl<*>, val resultFuture: OpenFuture<Any?>)
/**
 * Bookkeeping entry for a flow that currently has a retry timeout scheduled.
 *
 * The manager stores one of these per flow ID; each time a timeout is (re)scheduled a new entry is written
 * whose [retryCount] is the previous entry's count plus one (or one if there was no previous entry).
 */
private data class ScheduledTimeout(
/** Will fire a [FlowTimeoutException] indicating to the flow hospital to restart the flow. */
val scheduledFuture: ScheduledFuture<*>,
/** Specifies the number of times this flow has been retried. */
val retryCount: Int = 0
)
// A list of all the state machines being managed by this class. We expose snapshots of it via the stateMachines // A list of all the state machines being managed by this class. We expose snapshots of it via the stateMachines
// property. // property.
private class InnerState { private class InnerState {
val changesPublisher = PublishSubject.create<StateMachineManager.Change>()!! val changesPublisher = PublishSubject.create<StateMachineManager.Change>()!!
// True if we're shutting down, so don't resume anything. /** True if we're shutting down, so don't resume anything. */
var stopping = false var stopping = false
val flows = HashMap<StateMachineRunId, Flow>() val flows = HashMap<StateMachineRunId, Flow>()
val startedFutures = HashMap<StateMachineRunId, OpenFuture<Unit>>() val startedFutures = HashMap<StateMachineRunId, OpenFuture<Unit>>()
/** Flows scheduled to be retried if not finished within the specified timeout period. */
val timedFlows = HashMap<StateMachineRunId, ScheduledTimeout>()
} }
private val mutex = ThreadBox(InnerState()) private val mutex = ThreadBox(InnerState())
private val scheduler = FiberExecutorScheduler("Same thread scheduler", executor) private val scheduler = FiberExecutorScheduler("Same thread scheduler", executor)
private val timeoutScheduler = Executors.newScheduledThreadPool(1)
// How many Fibers are running and not suspended. If zero and stopping is true, then we are halted. // How many Fibers are running and not suspended. If zero and stopping is true, then we are halted.
private val liveFibers = ReusableLatch() private val liveFibers = ReusableLatch()
// Monitoring support. // Monitoring support.
@ -209,8 +206,8 @@ class SingleThreadedStateMachineManager(
} }
override fun killFlow(id: StateMachineRunId): Boolean { override fun killFlow(id: StateMachineRunId): Boolean {
return mutex.locked { return mutex.locked {
cancelTimeoutIfScheduled(id)
val flow = flows.remove(id) val flow = flows.remove(id)
if (flow != null) { if (flow != null) {
logger.debug("Killing flow known to physical node.") logger.debug("Killing flow known to physical node.")
@ -262,6 +259,7 @@ class SingleThreadedStateMachineManager(
override fun removeFlow(flowId: StateMachineRunId, removalReason: FlowRemovalReason, lastState: StateMachineState) { override fun removeFlow(flowId: StateMachineRunId, removalReason: FlowRemovalReason, lastState: StateMachineState) {
mutex.locked { mutex.locked {
cancelTimeoutIfScheduled(flowId)
val flow = flows.remove(flowId) val flow = flows.remove(flowId)
if (flow != null) { if (flow != null) {
decrementLiveFibers() decrementLiveFibers()
@ -426,10 +424,11 @@ class SingleThreadedStateMachineManager(
"unknown session $recipientId, discarding..." "unknown session $recipientId, discarding..."
} }
} else { } else {
throw IllegalArgumentException("Cannot find flow corresponding to session ID $recipientId") logger.warn("Cannot find flow corresponding to session ID $recipientId.")
} }
} else { } else {
val flow = mutex.locked { flows[flowId] } ?: throw IllegalStateException("Cannot find fiber corresponding to ID $flowId") val flow = mutex.locked { flows[flowId] }
?: throw IllegalStateException("Cannot find fiber corresponding to ID $flowId")
flow.fiber.scheduleEvent(Event.DeliverSessionMessage(sessionMessage, deduplicationHandler, sender)) flow.fiber.scheduleEvent(Event.DeliverSessionMessage(sessionMessage, deduplicationHandler, sender))
} }
} catch (exception: Exception) { } catch (exception: Exception) {
@ -444,6 +443,7 @@ class SingleThreadedStateMachineManager(
val payload = RejectSessionMessage(message, errorId) val payload = RejectSessionMessage(message, errorId)
return ExistingSessionMessage(initiatorSessionId, payload) return ExistingSessionMessage(initiatorSessionId, payload)
} }
val replyError = try { val replyError = try {
val initiatedFlowFactory = getInitiatedFlowFactory(sessionMessage) val initiatedFlowFactory = getInitiatedFlowFactory(sessionMessage)
val initiatedSessionId = SessionId.createRandom(secureRandom) val initiatedSessionId = SessionId.createRandom(secureRandom)
@ -486,8 +486,8 @@ class SingleThreadedStateMachineManager(
} catch (e: ClassCastException) { } catch (e: ClassCastException) {
throw SessionRejectException("${message.initiatorFlowClassName} is not a flow") throw SessionRejectException("${message.initiatorFlowClassName} is not a flow")
} }
return serviceHub.getFlowFactory(initiatingFlowClass) ?: return serviceHub.getFlowFactory(initiatingFlowClass)
throw SessionRejectException("$initiatingFlowClass is not registered") ?: throw SessionRejectException("$initiatingFlowClass is not registered")
} }
private fun <A> startInitiatedFlow( private fun <A> startInitiatedFlow(
@ -532,7 +532,7 @@ class SingleThreadedStateMachineManager(
flowLogic.stateMachine = flowStateMachineImpl flowLogic.stateMachine = flowStateMachineImpl
val frozenFlowLogic = (flowLogic as FlowLogic<*>).serialize(context = checkpointSerializationContext!!) val frozenFlowLogic = (flowLogic as FlowLogic<*>).serialize(context = checkpointSerializationContext!!)
val flowCorDappVersion= createSubFlowVersion(serviceHub.cordappProvider.getCordappForFlow(flowLogic), serviceHub.myInfo.platformVersion) val flowCorDappVersion = createSubFlowVersion(serviceHub.cordappProvider.getCordappForFlow(flowLogic), serviceHub.myInfo.platformVersion)
val initialCheckpoint = Checkpoint.create(invocationContext, flowStart, flowLogic.javaClass, frozenFlowLogic, ourIdentity, deduplicationSeed, flowCorDappVersion).getOrThrow() val initialCheckpoint = Checkpoint.create(invocationContext, flowStart, flowLogic.javaClass, frozenFlowLogic, ourIdentity, deduplicationSeed, flowCorDappVersion).getOrThrow()
val startedFuture = openFuture<Unit>() val startedFuture = openFuture<Unit>()
@ -556,6 +556,59 @@ class SingleThreadedStateMachineManager(
return startedFuture.map { flowStateMachineImpl as FlowStateMachine<A> } return startedFuture.map { flowStateMachineImpl as FlowStateMachine<A> }
} }
/** Schedules (or re-schedules) the retry timeout for [flowId] while holding the manager's state lock. */
override fun scheduleFlowTimeout(flowId: StateMachineRunId): Unit = mutex.locked { scheduleTimeout(flowId) }
/** Cancels the retry timeout for [flowId], if one is scheduled, while holding the manager's state lock. */
override fun cancelFlowTimeout(flowId: StateMachineRunId): Unit = mutex.locked { cancelTimeoutIfScheduled(flowId) }
/**
 * Arms the retry timeout for the flow [flowId], replacing any timeout that is already pending for it.
 *
 * If the flow is not known to this manager a warning is logged and nothing is scheduled. Otherwise any
 * still-pending timeout is cancelled, a new one is scheduled using the retry count recorded so far
 * (zero when there is no previous entry), and the recorded count is advanced by one.
 *
 * Assumes lock is taken on the [InnerState].
 */
private fun InnerState.scheduleTimeout(flowId: StateMachineRunId) {
    val flow = flows[flowId]
    if (flow == null) {
        logger.warn("Unable to schedule timeout for flow $flowId flow not found.")
        return
    }
    // Carry the retry count over from the previous entry, cancelling its future if it has not fired yet.
    val previousRetries = timedFlows[flowId]?.let { existing ->
        val pending = existing.scheduledFuture
        if (!pending.isDone) pending.cancel(true)
        existing.retryCount
    } ?: 0
    val newFuture = scheduleTimeoutException(flow, previousRetries)
    timedFlows[flowId] = ScheduledTimeout(newFuture, previousRetries + 1)
}
/**
 * Schedules a [FlowTimeoutException] to be fired in order to restart the flow.
 *
 * The delay is `messageRedeliveryDelay * backoffBase^retryCount` seconds, taken from the node's
 * `p2pMessagingRetry` configuration. When it elapses, an [Event.Error] wrapping the exception
 * (carrying the configured `maxRetryCount`) is scheduled on the flow's fiber.
 *
 * @param flow the flow whose fiber receives the timeout error event.
 * @param retryCount the exponent applied to the back-off base; `0` yields the plain redelivery delay.
 * @return the future of the scheduled task, which can be used to cancel the timeout before it fires.
 */
private fun scheduleTimeoutException(flow: Flow, retryCount: Int): ScheduledFuture<*> {
    return with(serviceHub.configuration.p2pMessagingRetry) {
        // Multiply in floating point and truncate only the final product. Truncating the multiplier first
        // (as `Math.pow(...).toLong()` did) discards the fractional part of the back-off, e.g. 1.8^1 -> 1.
        val backoffMultiplier = Math.pow(backoffBase, retryCount.toDouble())
        val timeoutDelaySeconds = (messageRedeliveryDelay.seconds * backoffMultiplier).toLong()
        timeoutScheduler.schedule({
            val event = Event.Error(FlowTimeoutException(maxRetryCount))
            flow.fiber.scheduleEvent(event)
        }, timeoutDelaySeconds, TimeUnit.SECONDS)
    }
}
/**
 * Drops the timeout entry recorded for [flowId], cancelling its future first if it has not completed yet.
 * Does nothing when no timeout is recorded for the flow.
 *
 * Assumes lock is taken on the [InnerState].
 */
private fun InnerState.cancelTimeoutIfScheduled(flowId: StateMachineRunId) {
    val scheduled = timedFlows[flowId] ?: return
    val pending = scheduled.scheduledFuture
    if (!pending.isDone) pending.cancel(true)
    timedFlows.remove(flowId)
}
private fun deserializeCheckpoint(serializedCheckpoint: SerializedBytes<Checkpoint>): Checkpoint? { private fun deserializeCheckpoint(serializedCheckpoint: SerializedBytes<Checkpoint>): Checkpoint? {
return try { return try {
serializedCheckpoint.deserialize(context = checkpointSerializationContext!!) serializedCheckpoint.deserialize(context = checkpointSerializationContext!!)
@ -663,6 +716,8 @@ class SingleThreadedStateMachineManager(
} else { } else {
oldFlow.resultFuture.captureLater(flow.resultFuture) oldFlow.resultFuture.captureLater(flow.resultFuture)
} }
val flowLogic = flow.fiber.logic
if (flowLogic is TimedFlow) scheduleTimeout(id)
flow.fiber.scheduleEvent(Event.DoRemainingWork) flow.fiber.scheduleEvent(Event.DoRemainingWork)
when (checkpoint.flowState) { when (checkpoint.flowState) {
is FlowState.Unstarted -> { is FlowState.Unstarted -> {

View File

@ -1,6 +1,7 @@
package net.corda.node.services.statemachine package net.corda.node.services.statemachine
import net.corda.core.flows.StateMachineRunId import net.corda.core.flows.StateMachineRunId
import net.corda.core.internal.TimedFlow
import net.corda.core.utilities.loggerFor import net.corda.core.utilities.loggerFor
import java.sql.SQLException import java.sql.SQLException
import java.time.Instant import java.time.Instant
@ -12,7 +13,7 @@ import java.util.concurrent.ConcurrentHashMap
object StaffedFlowHospital : FlowHospital { object StaffedFlowHospital : FlowHospital {
private val log = loggerFor<StaffedFlowHospital>() private val log = loggerFor<StaffedFlowHospital>()
private val staff = listOf(DeadlockNurse, DuplicateInsertSpecialist) private val staff = listOf(DeadlockNurse, DuplicateInsertSpecialist, DoctorTimeout)
private val patients = ConcurrentHashMap<StateMachineRunId, MedicalHistory>() private val patients = ConcurrentHashMap<StateMachineRunId, MedicalHistory>()
@ -124,4 +125,31 @@ object StaffedFlowHospital : FlowHospital {
return exception != null && (exception is org.hibernate.exception.ConstraintViolationException || mentionsConstraintViolation(exception.cause)) return exception != null && (exception is org.hibernate.exception.ConstraintViolationException || mentionsConstraintViolation(exception.cause))
} }
} }
/**
 * Restarts [TimedFlow], keeping track of the number of retries and making sure it does not
 * exceed the limit specified by the [FlowTimeoutException].
 */
object DoctorTimeout : Staff {
    /**
     * Returns [Diagnosis.DISCHARGE] when [newError] is a [FlowTimeoutException], the fiber's sub-flow stack
     * contains a [TimedFlow], and [history] shows this specialist has not yet discharged the flow more than
     * the exception's `maxRetries` times. In every other case returns [Diagnosis.NOT_MY_SPECIALTY], logging
     * a warning if the error was a timeout that could not be acted upon.
     */
    override fun consult(flowFiber: FlowFiber, currentState: StateMachineState, newError: Throwable, history: MedicalHistory): Diagnosis {
        if (newError !is FlowTimeoutException) return Diagnosis.NOT_MY_SPECIALTY
        if (!isTimedFlow(flowFiber)) {
            log.warn("Unable to restart flow: ${flowFiber.javaClass}, it is not timed and does not contain any timed sub-flows.")
            return Diagnosis.NOT_MY_SPECIALTY
        }
        if (history.notDischargedForTheSameThingMoreThan(newError.maxRetries, this)) {
            return Diagnosis.DISCHARGE
        }
        log.warn("Maximum number of retries reached for timed flow ${flowFiber.javaClass}")
        return Diagnosis.NOT_MY_SPECIALTY
    }

    /** True if any entry on the fiber's current sub-flow stack has a flow class assignable to [TimedFlow]. */
    private fun isTimedFlow(flowFiber: FlowFiber): Boolean {
        return flowFiber.snapshot().checkpoint.subFlowStack.any {
            TimedFlow::class.java.isAssignableFrom(it.flowClass)
        }
    }
}
} }

View File

@ -92,6 +92,8 @@ interface StateMachineManagerInternal {
fun removeSessionBindings(sessionIds: Set<SessionId>) fun removeSessionBindings(sessionIds: Set<SessionId>)
fun removeFlow(flowId: StateMachineRunId, removalReason: FlowRemovalReason, lastState: StateMachineState) fun removeFlow(flowId: StateMachineRunId, removalReason: FlowRemovalReason, lastState: StateMachineState)
fun retryFlowFromSafePoint(currentState: StateMachineState) fun retryFlowFromSafePoint(currentState: StateMachineState)
fun scheduleFlowTimeout(flowId: StateMachineRunId)
fun cancelFlowTimeout(flowId: StateMachineRunId)
} }
/** /**

View File

@ -51,7 +51,6 @@ data class StateMachineState(
* @param flowState the state of the flow itself, including the frozen fiber/FlowLogic. * @param flowState the state of the flow itself, including the frozen fiber/FlowLogic.
* @param errorState the "dirtiness" state including the involved errors and their propagation status. * @param errorState the "dirtiness" state including the involved errors and their propagation status.
* @param numberOfSuspends the number of flow suspends due to IO API calls. * @param numberOfSuspends the number of flow suspends due to IO API calls.
* @param deduplicationSeed the basis seed for the deduplication ID. This is used to produce replayable IDs.
*/ */
data class Checkpoint( data class Checkpoint(
val invocationContext: InvocationContext, val invocationContext: InvocationContext,
@ -60,8 +59,7 @@ data class Checkpoint(
val subFlowStack: List<SubFlow>, val subFlowStack: List<SubFlow>,
val flowState: FlowState, val flowState: FlowState,
val errorState: ErrorState, val errorState: ErrorState,
val numberOfSuspends: Int, val numberOfSuspends: Int
val deduplicationSeed: String
) { ) {
companion object { companion object {
@ -82,8 +80,7 @@ data class Checkpoint(
subFlowStack = listOf(topLevelSubFlow), subFlowStack = listOf(topLevelSubFlow),
flowState = FlowState.Unstarted(flowStart, frozenFlowLogic), flowState = FlowState.Unstarted(flowStart, frozenFlowLogic),
errorState = ErrorState.Clean, errorState = ErrorState.Clean,
numberOfSuspends = 0, numberOfSuspends = 0
deduplicationSeed = deduplicationSeed
) )
} }
} }
@ -95,13 +92,19 @@ data class Checkpoint(
*/ */
sealed class SessionState { sealed class SessionState {
abstract val deduplicationSeed: String
/** /**
* We haven't yet sent the initialisation message * We haven't yet sent the initialisation message
*/ */
data class Uninitiated( data class Uninitiated(
val party: Party, val party: Party,
val initiatingSubFlow: SubFlow.Initiating val initiatingSubFlow: SubFlow.Initiating,
) : SessionState() val sourceSessionId: SessionId,
val additionalEntropy: Long
) : SessionState() {
override val deduplicationSeed: String get() = "R-${sourceSessionId.toLong}-$additionalEntropy"
}
/** /**
* We have sent the initialisation message but have not yet received a confirmation. * We have sent the initialisation message but have not yet received a confirmation.
@ -109,7 +112,8 @@ sealed class SessionState {
*/ */
data class Initiating( data class Initiating(
val bufferedMessages: List<Pair<DeduplicationId, ExistingSessionMessagePayload>>, val bufferedMessages: List<Pair<DeduplicationId, ExistingSessionMessagePayload>>,
val rejectionError: FlowError? val rejectionError: FlowError?,
override val deduplicationSeed: String
) : SessionState() ) : SessionState()
/** /**
@ -121,7 +125,8 @@ sealed class SessionState {
val peerFlowInfo: FlowInfo, val peerFlowInfo: FlowInfo,
val receivedMessages: List<DataSessionMessage>, val receivedMessages: List<DataSessionMessage>,
val initiatedState: InitiatedSessionState, val initiatedState: InitiatedSessionState,
val errors: List<FlowError> val errors: List<FlowError>,
override val deduplicationSeed: String
) : SessionState() ) : SessionState()
} }

View File

@ -1,7 +1,19 @@
package net.corda.node.services.statemachine.transitions package net.corda.node.services.statemachine.transitions
import net.corda.core.flows.UnexpectedFlowEndException import net.corda.core.flows.UnexpectedFlowEndException
import net.corda.node.services.statemachine.* import net.corda.node.services.statemachine.Action
import net.corda.node.services.statemachine.ConfirmSessionMessage
import net.corda.node.services.statemachine.DataSessionMessage
import net.corda.node.services.statemachine.EndSessionMessage
import net.corda.node.services.statemachine.ErrorSessionMessage
import net.corda.node.services.statemachine.Event
import net.corda.node.services.statemachine.ExistingSessionMessage
import net.corda.node.services.statemachine.FlowError
import net.corda.node.services.statemachine.InitiatedSessionState
import net.corda.node.services.statemachine.RejectSessionMessage
import net.corda.node.services.statemachine.SenderDeduplicationId
import net.corda.node.services.statemachine.SessionState
import net.corda.node.services.statemachine.StateMachineState
/** /**
* This transition handles incoming session messages. It handles the following cases: * This transition handles incoming session messages. It handles the following cases:
@ -62,7 +74,8 @@ class DeliverSessionMessageTransition(
peerFlowInfo = message.initiatedFlowInfo, peerFlowInfo = message.initiatedFlowInfo,
receivedMessages = emptyList(), receivedMessages = emptyList(),
initiatedState = InitiatedSessionState.Live(message.initiatedSessionId), initiatedState = InitiatedSessionState.Live(message.initiatedSessionId),
errors = emptyList() errors = emptyList(),
deduplicationSeed = sessionState.deduplicationSeed
) )
val newCheckpoint = currentState.checkpoint.copy( val newCheckpoint = currentState.checkpoint.copy(
sessions = currentState.checkpoint.sessions + (event.sessionMessage.recipientSessionId to initiatedSession) sessions = currentState.checkpoint.sessions + (event.sessionMessage.recipientSessionId to initiatedSession)

View File

@ -6,7 +6,22 @@ import net.corda.core.flows.UnexpectedFlowEndException
import net.corda.core.internal.FlowIORequest import net.corda.core.internal.FlowIORequest
import net.corda.core.serialization.SerializedBytes import net.corda.core.serialization.SerializedBytes
import net.corda.core.utilities.toNonEmptySet import net.corda.core.utilities.toNonEmptySet
import net.corda.node.services.statemachine.* import net.corda.node.services.statemachine.Action
import net.corda.node.services.statemachine.Checkpoint
import net.corda.node.services.statemachine.DataSessionMessage
import net.corda.node.services.statemachine.DeduplicationId
import net.corda.node.services.statemachine.ExistingSessionMessage
import net.corda.node.services.statemachine.FlowError
import net.corda.node.services.statemachine.FlowSessionImpl
import net.corda.node.services.statemachine.FlowState
import net.corda.node.services.statemachine.InitialSessionMessage
import net.corda.node.services.statemachine.InitiatedSessionState
import net.corda.node.services.statemachine.SenderDeduplicationId
import net.corda.node.services.statemachine.SessionId
import net.corda.node.services.statemachine.SessionMap
import net.corda.node.services.statemachine.SessionState
import net.corda.node.services.statemachine.StateMachineState
import net.corda.node.services.statemachine.SubFlow
/** /**
* This transition describes what should happen with a specific [FlowIORequest]. Note that at this time the request * This transition describes what should happen with a specific [FlowIORequest]. Note that at this time the request
@ -214,13 +229,15 @@ class StartedFlowTransition(
if (sessionState !is SessionState.Uninitiated) { if (sessionState !is SessionState.Uninitiated) {
continue continue
} }
val deduplicationId = DeduplicationId.createForNormal(checkpoint, index++) val initialMessage = createInitialSessionMessage(sessionState.initiatingSubFlow, sourceSessionId, sessionState.additionalEntropy, null)
val initialMessage = createInitialSessionMessage(sessionState.initiatingSubFlow, sourceSessionId, null) val newSessionState = SessionState.Initiating(
actions.add(Action.SendInitial(sessionState.party, initialMessage, SenderDeduplicationId(deduplicationId, startingState.senderUUID)))
newSessions[sourceSessionId] = SessionState.Initiating(
bufferedMessages = emptyList(), bufferedMessages = emptyList(),
rejectionError = null rejectionError = null,
deduplicationSeed = sessionState.deduplicationSeed
) )
val deduplicationId = DeduplicationId.createForNormal(checkpoint, index++, newSessionState)
actions.add(Action.SendInitial(sessionState.party, initialMessage, SenderDeduplicationId(deduplicationId, startingState.senderUUID)))
newSessions[sourceSessionId] = newSessionState
} }
currentState = currentState.copy(checkpoint = checkpoint.copy(sessions = newSessions)) currentState = currentState.copy(checkpoint = checkpoint.copy(sessions = newSessions))
} }
@ -249,14 +266,15 @@ class StartedFlowTransition(
return freshErrorTransition(CannotFindSessionException(sourceSessionId)) return freshErrorTransition(CannotFindSessionException(sourceSessionId))
} else { } else {
val sessionMessage = DataSessionMessage(message) val sessionMessage = DataSessionMessage(message)
val deduplicationId = DeduplicationId.createForNormal(checkpoint, index++) val deduplicationId = DeduplicationId.createForNormal(checkpoint, index++, existingSessionState)
when (existingSessionState) { when (existingSessionState) {
is SessionState.Uninitiated -> { is SessionState.Uninitiated -> {
val initialMessage = createInitialSessionMessage(existingSessionState.initiatingSubFlow, sourceSessionId, message) val initialMessage = createInitialSessionMessage(existingSessionState.initiatingSubFlow, sourceSessionId, existingSessionState.additionalEntropy, message)
actions.add(Action.SendInitial(existingSessionState.party, initialMessage, SenderDeduplicationId(deduplicationId, startingState.senderUUID))) actions.add(Action.SendInitial(existingSessionState.party, initialMessage, SenderDeduplicationId(deduplicationId, startingState.senderUUID)))
newSessions[sourceSessionId] = SessionState.Initiating( newSessions[sourceSessionId] = SessionState.Initiating(
bufferedMessages = emptyList(), bufferedMessages = emptyList(),
rejectionError = null rejectionError = null,
deduplicationSeed = existingSessionState.deduplicationSeed
) )
Unit Unit
} }
@ -388,12 +406,13 @@ class StartedFlowTransition(
private fun createInitialSessionMessage( private fun createInitialSessionMessage(
initiatingSubFlow: SubFlow.Initiating, initiatingSubFlow: SubFlow.Initiating,
sourceSessionId: SessionId, sourceSessionId: SessionId,
additionalEntropy: Long,
payload: SerializedBytes<Any>? payload: SerializedBytes<Any>?
): InitialSessionMessage { ): InitialSessionMessage {
return InitialSessionMessage( return InitialSessionMessage(
initiatorSessionId = sourceSessionId, initiatorSessionId = sourceSessionId,
// We add additional entropy to add to the initiated side's deduplication seed. // We add additional entropy to add to the initiated side's deduplication seed.
initiationEntropy = context.secureRandom.nextLong(), initiationEntropy = additionalEntropy,
initiatorFlowClassName = initiatingSubFlow.classToInitiateWith.name, initiatorFlowClassName = initiatingSubFlow.classToInitiateWith.name,
flowVersion = initiatingSubFlow.flowInfo.flowVersion, flowVersion = initiatingSubFlow.flowInfo.flowVersion,
appName = initiatingSubFlow.flowInfo.appName, appName = initiatingSubFlow.flowInfo.appName,

View File

@ -2,6 +2,7 @@ package net.corda.node.services.statemachine.transitions
import net.corda.core.flows.InitiatingFlow import net.corda.core.flows.InitiatingFlow
import net.corda.core.internal.FlowIORequest import net.corda.core.internal.FlowIORequest
import net.corda.core.internal.TimedFlow
import net.corda.core.utilities.Try import net.corda.core.utilities.Try
import net.corda.node.services.statemachine.* import net.corda.node.services.statemachine.*
@ -95,11 +96,20 @@ class TopLevelTransition(
val subFlow = SubFlow.create(event.subFlowClass, event.subFlowVersion) val subFlow = SubFlow.create(event.subFlowClass, event.subFlowVersion)
when (subFlow) { when (subFlow) {
is Try.Success -> { is Try.Success -> {
val containsTimedSubFlows = currentState.checkpoint.subFlowStack.any {
TimedFlow::class.java.isAssignableFrom(it.flowClass)
}
val isCurrentSubFlowTimed = TimedFlow::class.java.isAssignableFrom(event.subFlowClass)
currentState = currentState.copy( currentState = currentState.copy(
checkpoint = currentState.checkpoint.copy( checkpoint = currentState.checkpoint.copy(
subFlowStack = currentState.checkpoint.subFlowStack + subFlow.value subFlowStack = currentState.checkpoint.subFlowStack + subFlow.value
) )
) )
// We don't schedule a timeout if there already is a timed subflow on the stack - a timeout had
// been scheduled already.
if (isCurrentSubFlowTimed && !containsTimedSubFlows) {
actions.add(Action.ScheduleFlowTimeout(currentState.flowLogic.runId))
}
} }
is Try.Failure -> { is Try.Failure -> {
freshErrorTransition(subFlow.exception) freshErrorTransition(subFlow.exception)
@ -115,35 +125,56 @@ class TopLevelTransition(
if (checkpoint.subFlowStack.isEmpty()) { if (checkpoint.subFlowStack.isEmpty()) {
freshErrorTransition(UnexpectedEventInState()) freshErrorTransition(UnexpectedEventInState())
} else { } else {
val lastSubFlowClass = checkpoint.subFlowStack.last().flowClass
val isLastSubFlowTimed = TimedFlow::class.java.isAssignableFrom(lastSubFlowClass)
val newSubFlowStack = checkpoint.subFlowStack.dropLast(1)
currentState = currentState.copy( currentState = currentState.copy(
checkpoint = checkpoint.copy( checkpoint = checkpoint.copy(
subFlowStack = checkpoint.subFlowStack.subList(0, checkpoint.subFlowStack.size - 1).toList() subFlowStack = newSubFlowStack
) )
) )
if (isLastSubFlowTimed && !containsTimedFlows(currentState.checkpoint.subFlowStack)) {
actions.add(Action.CancelFlowTimeout(currentState.flowLogic.runId))
}
} }
FlowContinuation.ProcessEvents FlowContinuation.ProcessEvents
} }
} }
/**
 * Returns true when at least one entry of [subFlowStack] has a flow class assignable to [TimedFlow];
 * returns false for an empty stack.
 */
private fun containsTimedFlows(subFlowStack: List<SubFlow>): Boolean =
    subFlowStack.any { subFlow -> TimedFlow::class.java.isAssignableFrom(subFlow.flowClass) }
private fun suspendTransition(event: Event.Suspend): TransitionResult { private fun suspendTransition(event: Event.Suspend): TransitionResult {
return builder { return builder {
val newCheckpoint = currentState.checkpoint.copy( val newCheckpoint = currentState.checkpoint.copy(
flowState = FlowState.Started(event.ioRequest, event.fiber), flowState = FlowState.Started(event.ioRequest, event.fiber),
numberOfSuspends = currentState.checkpoint.numberOfSuspends + 1 numberOfSuspends = currentState.checkpoint.numberOfSuspends + 1
) )
actions.addAll(arrayOf( if (event.maySkipCheckpoint) {
Action.PersistCheckpoint(context.id, newCheckpoint), actions.addAll(arrayOf(
Action.PersistDeduplicationFacts(currentState.pendingDeduplicationHandlers), Action.CommitTransaction,
Action.CommitTransaction, Action.ScheduleEvent(Event.DoRemainingWork)
Action.AcknowledgeMessages(currentState.pendingDeduplicationHandlers), ))
Action.ScheduleEvent(Event.DoRemainingWork) currentState = currentState.copy(
)) checkpoint = newCheckpoint,
currentState = currentState.copy( isFlowResumed = false
checkpoint = newCheckpoint, )
pendingDeduplicationHandlers = emptyList(), } else {
isFlowResumed = false, actions.addAll(arrayOf(
isAnyCheckpointPersisted = true Action.PersistCheckpoint(context.id, newCheckpoint),
) Action.PersistDeduplicationFacts(currentState.pendingDeduplicationHandlers),
Action.CommitTransaction,
Action.AcknowledgeMessages(currentState.pendingDeduplicationHandlers),
Action.ScheduleEvent(Event.DoRemainingWork)
))
currentState = currentState.copy(
checkpoint = newCheckpoint,
pendingDeduplicationHandlers = emptyList(),
isFlowResumed = false,
isAnyCheckpointPersisted = true
)
}
FlowContinuation.ProcessEvents FlowContinuation.ProcessEvents
} }
} }
@ -191,7 +222,7 @@ class TopLevelTransition(
val sendEndMessageActions = currentState.checkpoint.sessions.values.mapIndexed { index, state -> val sendEndMessageActions = currentState.checkpoint.sessions.values.mapIndexed { index, state ->
if (state is SessionState.Initiated && state.initiatedState is InitiatedSessionState.Live) { if (state is SessionState.Initiated && state.initiatedState is InitiatedSessionState.Live) {
val message = ExistingSessionMessage(state.initiatedState.peerSinkSessionId, EndSessionMessage) val message = ExistingSessionMessage(state.initiatedState.peerSinkSessionId, EndSessionMessage)
val deduplicationId = DeduplicationId.createForNormal(currentState.checkpoint, index) val deduplicationId = DeduplicationId.createForNormal(currentState.checkpoint, index, state)
Action.SendExisting(state.peerParty, message, SenderDeduplicationId(deduplicationId, currentState.senderUUID)) Action.SendExisting(state.peerParty, message, SenderDeduplicationId(deduplicationId, currentState.senderUUID))
} else { } else {
null null
@ -210,7 +241,7 @@ class TopLevelTransition(
} }
val sourceSessionId = SessionId.createRandom(context.secureRandom) val sourceSessionId = SessionId.createRandom(context.secureRandom)
val sessionImpl = FlowSessionImpl(event.party, sourceSessionId) val sessionImpl = FlowSessionImpl(event.party, sourceSessionId)
val newSessions = checkpoint.sessions + (sourceSessionId to SessionState.Uninitiated(event.party, initiatingSubFlow)) val newSessions = checkpoint.sessions + (sourceSessionId to SessionState.Uninitiated(event.party, initiatingSubFlow, sourceSessionId, context.secureRandom.nextLong()))
currentState = currentState.copy(checkpoint = checkpoint.copy(sessions = newSessions)) currentState = currentState.copy(checkpoint = checkpoint.copy(sessions = newSessions))
actions.add(Action.AddSessionBinding(context.id, sourceSessionId)) actions.add(Action.AddSessionBinding(context.id, sourceSessionId))
FlowContinuation.Resume(sessionImpl) FlowContinuation.Resume(sessionImpl)
@ -241,4 +272,4 @@ class TopLevelTransition(
FlowContinuation.Abort FlowContinuation.Abort
} }
} }
} }

View File

@ -1,7 +1,17 @@
package net.corda.node.services.statemachine.transitions package net.corda.node.services.statemachine.transitions
import net.corda.core.flows.FlowInfo import net.corda.core.flows.FlowInfo
import net.corda.node.services.statemachine.* import net.corda.node.services.statemachine.Action
import net.corda.node.services.statemachine.ConfirmSessionMessage
import net.corda.node.services.statemachine.DataSessionMessage
import net.corda.node.services.statemachine.DeduplicationId
import net.corda.node.services.statemachine.ExistingSessionMessage
import net.corda.node.services.statemachine.FlowStart
import net.corda.node.services.statemachine.FlowState
import net.corda.node.services.statemachine.InitiatedSessionState
import net.corda.node.services.statemachine.SenderDeduplicationId
import net.corda.node.services.statemachine.SessionState
import net.corda.node.services.statemachine.StateMachineState
/** /**
* This transition is responsible for starting the flow from a FlowLogic instance. It creates the first checkpoint and * This transition is responsible for starting the flow from a FlowLogic instance. It creates the first checkpoint and
@ -45,7 +55,8 @@ class UnstartedFlowTransition(
} else { } else {
listOf(DataSessionMessage(initiatingMessage.firstPayload)) listOf(DataSessionMessage(initiatingMessage.firstPayload))
}, },
errors = emptyList() errors = emptyList(),
deduplicationSeed = "D-${initiatingMessage.initiatorSessionId.toLong}-${initiatingMessage.initiationEntropy}"
) )
val confirmationMessage = ConfirmSessionMessage(flowStart.initiatedSessionId, flowStart.initiatedFlowInfo) val confirmationMessage = ConfirmSessionMessage(flowStart.initiatedSessionId, flowStart.initiatedFlowInfo)
val sessionMessage = ExistingSessionMessage(initiatingMessage.initiatorSessionId, confirmationMessage) val sessionMessage = ExistingSessionMessage(initiatingMessage.initiatorSessionId, confirmationMessage)
@ -58,7 +69,7 @@ class UnstartedFlowTransition(
Action.SendExisting( Action.SendExisting(
flowStart.peerSession.counterparty, flowStart.peerSession.counterparty,
sessionMessage, sessionMessage,
SenderDeduplicationId(DeduplicationId.createForNormal(currentState.checkpoint, 0), currentState.senderUUID) SenderDeduplicationId(DeduplicationId.createForNormal(currentState.checkpoint, 0, initiatedState), currentState.senderUUID)
) )
) )
} }

View File

@ -70,9 +70,9 @@ class RetryFlowMockTest {
val messagesSent = mutableListOf<Message>() val messagesSent = mutableListOf<Message>()
val partyB = internalNodeB.info.legalIdentities.first() val partyB = internalNodeB.info.legalIdentities.first()
internalNodeA.setMessagingServiceSpy(object : MessagingServiceSpy(internalNodeA.network) { internalNodeA.setMessagingServiceSpy(object : MessagingServiceSpy(internalNodeA.network) {
override fun send(message: Message, target: MessageRecipients, retryId: Long?, sequenceKey: Any) { override fun send(message: Message, target: MessageRecipients, sequenceKey: Any) {
messagesSent.add(message) messagesSent.add(message)
messagingService.send(message, target, retryId) messagingService.send(message, target)
} }
}) })
internalNodeA.startFlow(SendAndRetryFlow(1, partyB)).get() internalNodeA.startFlow(SendAndRetryFlow(1, partyB)).get()

View File

@ -4,8 +4,18 @@ import net.corda.core.concurrent.CordaFuture
import net.corda.core.contracts.Command import net.corda.core.contracts.Command
import net.corda.core.contracts.StateAndRef import net.corda.core.contracts.StateAndRef
import net.corda.core.contracts.StateRef import net.corda.core.contracts.StateRef
import net.corda.core.crypto.* import net.corda.core.crypto.Crypto
import net.corda.core.flows.* import net.corda.core.crypto.SecureHash
import net.corda.core.crypto.TransactionSignature
import net.corda.core.crypto.generateKeyPair
import net.corda.core.crypto.sha256
import net.corda.core.crypto.sign
import net.corda.core.flows.NotarisationPayload
import net.corda.core.flows.NotarisationRequest
import net.corda.core.flows.NotarisationRequestSignature
import net.corda.core.flows.NotaryError
import net.corda.core.flows.NotaryException
import net.corda.core.flows.NotaryFlow
import net.corda.core.identity.Party import net.corda.core.identity.Party
import net.corda.core.internal.notary.generateSignature import net.corda.core.internal.notary.generateSignature
import net.corda.core.messaging.MessageRecipients import net.corda.core.messaging.MessageRecipients
@ -26,7 +36,12 @@ import net.corda.testing.core.ALICE_NAME
import net.corda.testing.core.dummyCommand import net.corda.testing.core.dummyCommand
import net.corda.testing.core.singleIdentity import net.corda.testing.core.singleIdentity
import net.corda.testing.node.TestClock import net.corda.testing.node.TestClock
import net.corda.testing.node.internal.* import net.corda.testing.node.internal.InMemoryMessage
import net.corda.testing.node.internal.InternalMockNetwork
import net.corda.testing.node.internal.InternalMockNodeParameters
import net.corda.testing.node.internal.MessagingServiceSpy
import net.corda.testing.node.internal.setMessagingServiceSpy
import net.corda.testing.node.internal.startFlow
import org.assertj.core.api.Assertions.assertThat import org.assertj.core.api.Assertions.assertThat
import org.junit.After import org.junit.After
import org.junit.Before import org.junit.Before
@ -293,7 +308,7 @@ class ValidatingNotaryServiceTests {
private fun runNotarisationAndInterceptClientPayload(payloadModifier: (NotarisationPayload) -> NotarisationPayload) { private fun runNotarisationAndInterceptClientPayload(payloadModifier: (NotarisationPayload) -> NotarisationPayload) {
aliceNode.setMessagingServiceSpy(object : MessagingServiceSpy(aliceNode.network) { aliceNode.setMessagingServiceSpy(object : MessagingServiceSpy(aliceNode.network) {
override fun send(message: Message, target: MessageRecipients, retryId: Long?, sequenceKey: Any) { override fun send(message: Message, target: MessageRecipients, sequenceKey: Any) {
val messageData = message.data.deserialize<Any>() as? InitialSessionMessage val messageData = message.data.deserialize<Any>() as? InitialSessionMessage
val payload = messageData?.firstPayload!!.deserialize() val payload = messageData?.firstPayload!!.deserialize()
@ -301,10 +316,10 @@ class ValidatingNotaryServiceTests {
val alteredPayload = payloadModifier(payload) val alteredPayload = payloadModifier(payload)
val alteredMessageData = messageData.copy(firstPayload = alteredPayload.serialize()) val alteredMessageData = messageData.copy(firstPayload = alteredPayload.serialize())
val alteredMessage = InMemoryMessage(message.topic, OpaqueBytes(alteredMessageData.serialize().bytes), message.uniqueMessageId) val alteredMessage = InMemoryMessage(message.topic, OpaqueBytes(alteredMessageData.serialize().bytes), message.uniqueMessageId)
messagingService.send(alteredMessage, target, retryId) messagingService.send(alteredMessage, target)
} else { } else {
messagingService.send(message, target, retryId) messagingService.send(message, target)
} }
} }
}) })

View File

@ -420,7 +420,7 @@ class InMemoryMessagingNetwork private constructor(
state.locked { check(handlers.remove(registration as Handler)) } state.locked { check(handlers.remove(registration as Handler)) }
} }
override fun send(message: Message, target: MessageRecipients, retryId: Long?, sequenceKey: Any) { override fun send(message: Message, target: MessageRecipients, sequenceKey: Any) {
check(running) check(running)
msgSend(this, message, target) msgSend(this, message, target)
if (!sendManuallyPumped) { if (!sendManuallyPumped) {
@ -429,8 +429,8 @@ class InMemoryMessagingNetwork private constructor(
} }
override fun send(addressedMessages: List<MessagingService.AddressedMessage>) { override fun send(addressedMessages: List<MessagingService.AddressedMessage>) {
for ((message, target, retryId, sequenceKey) in addressedMessages) { for ((message, target, sequenceKey) in addressedMessages) {
send(message, target, retryId, sequenceKey) send(message, target, sequenceKey)
} }
} }
@ -443,8 +443,6 @@ class InMemoryMessagingNetwork private constructor(
netNodeHasShutdown(peerHandle) netNodeHasShutdown(peerHandle)
} }
override fun cancelRedelivery(retryId: Long) {}
/** Returns the given (topic & session, data) pair as a newly created message object. */ /** Returns the given (topic & session, data) pair as a newly created message object. */
override fun createMessage(topic: String, data: ByteArray, deduplicationId: SenderDeduplicationId, additionalHeaders: Map<String, String>): Message { override fun createMessage(topic: String, data: ByteArray, deduplicationId: SenderDeduplicationId, additionalHeaders: Map<String, String>): Message {
return InMemoryMessage(topic, OpaqueBytes(data), deduplicationId.deduplicationId, senderUUID = deduplicationId.senderUUID) return InMemoryMessage(topic, OpaqueBytes(data), deduplicationId.deduplicationId, senderUUID = deduplicationId.senderUUID)