From b71f0c49fb8569c4adc6f10014e4c6ca5e0ba120 Mon Sep 17 00:00:00 2001 From: Andras Slemmer Date: Wed, 22 Nov 2017 16:03:44 +0000 Subject: [PATCH] Multi threaded state machine --- .../corda/client/rpc/RPCPerformanceTests.kt | 4 +- .../net/corda/core/internal/ConcurrentBox.kt | 17 + .../net/corda/core/internal/FetchDataFlow.kt | 5 +- .../net/corda/core/internal/LifeCycle.kt | 31 +- .../corda/docs/IntegrationTestingTutorial.kt | 49 +- .../mocknetwork/TutorialMockNetwork.kt | 2 +- .../kotlin/net/corda/flowhook/FiberMonitor.kt | 119 ++-- .../net/corda/flowhook/FlowHookContainer.kt | 50 +- .../internal/ArtemisMessagingClient.kt | 22 +- .../net/corda/node/amqp/AMQPBridgeTest.kt | 5 +- .../net/corda/node/amqp/ProtonWrapperTests.kt | 5 +- .../net/corda/node/internal/AbstractNode.kt | 31 +- .../net/corda/node/internal/EnterpriseNode.kt | 41 +- .../kotlin/net/corda/node/internal/Node.kt | 13 +- .../config/EnterpriseConfiguration.kt | 38 +- .../node/services/config/NodeConfiguration.kt | 10 + .../services/events/NodeSchedulerService.kt | 29 +- .../identity/PersistentIdentityService.kt | 2 +- .../messaging/ArtemisMessagingServer.kt | 2 +- .../node/services/messaging/Messaging.kt | 12 +- .../services/messaging/MessagingExecutor.kt | 221 ++++++ .../services/messaging/P2PMessagingClient.kt | 164 ++--- .../services/messaging/RPCMessagingClient.kt | 9 +- .../node/services/messaging/RPCServer.kt | 2 +- ...bstractPartyToX500NameAsStringConverter.kt | 2 - .../DBTransactionMappingStorage.kt | 30 +- .../persistence/DBTransactionStorage.kt | 10 +- .../statemachine/ActionExecutorImpl.kt | 12 +- .../services/statemachine/FlowMessaging.kt | 9 +- .../statemachine/FlowStateMachineImpl.kt | 40 +- .../MultiThreadedStateMachineManager.kt | 668 ++++++++++++++++++ .../statemachine/SessionRejectException.kt | 8 + ...t => SingleThreadedStateMachineManager.kt} | 20 +- .../statemachine/StateMachineState.kt | 4 +- .../statemachine/TransitionExecutor.kt | 2 +- .../DumpHistoryOnErrorInterceptor.kt | 4 +- .../interceptors/HospitalisingInterceptor.kt | 3 + .../node/services/vault/NodeVaultService.kt | 12 +- node/src/main/resources/reference.conf | 8 + .../events/NodeSchedulerServiceTest.kt | 4 - .../messaging/ArtemisMessagingTest.kt | 11 +- .../statemachine/FlowFrameworkTests.kt | 9 +- perftestcordapp/build.gradle | 2 +- .../perftestcordapp}/NodePerformanceTests.kt | 96 ++- .../flows/CashIssueAndPaymentNoSelection.kt | 1 + .../net/corda/irs/flows/AutoOfferFlow.kt | 4 +- .../testing/node/InMemoryMessagingNetwork.kt | 8 +- .../kotlin/net/corda/testing/node/MockNode.kt | 12 +- .../testing/node/internal/NodeBasedTest.kt | 3 +- .../node/internal/performance/Injectors.kt | 25 +- 50 files changed, 1515 insertions(+), 375 deletions(-) create mode 100644 core/src/main/kotlin/net/corda/core/internal/ConcurrentBox.kt create mode 100644 node/src/main/kotlin/net/corda/node/services/messaging/MessagingExecutor.kt create mode 100644 node/src/main/kotlin/net/corda/node/services/statemachine/MultiThreadedStateMachineManager.kt create mode 100644 node/src/main/kotlin/net/corda/node/services/statemachine/SessionRejectException.kt rename node/src/main/kotlin/net/corda/node/services/statemachine/{StateMachineManagerImpl.kt => SingleThreadedStateMachineManager.kt} (98%) rename {node/src/integration-test/kotlin/net/corda/node => perftestcordapp/src/integrationTest/kotlin/com/r3/corda/enterprise/perftestcordapp}/NodePerformanceTests.kt (58%) diff --git a/client/rpc/src/test/kotlin/net/corda/client/rpc/RPCPerformanceTests.kt b/client/rpc/src/test/kotlin/net/corda/client/rpc/RPCPerformanceTests.kt index b945487cb9..19581b81d4 100644 --- a/client/rpc/src/test/kotlin/net/corda/client/rpc/RPCPerformanceTests.kt +++ b/client/rpc/src/test/kotlin/net/corda/client/rpc/RPCPerformanceTests.kt @@ -2,6 +2,7 @@ package net.corda.client.rpc import com.google.common.base.Stopwatch import net.corda.client.rpc.internal.RPCClientConfiguration +import net.corda.core.internal.concurrent.doneFuture import net.corda.core.messaging.RPCOps import net.corda.core.utilities.minutes import net.corda.core.utilities.seconds @@ -144,10 +145,11 @@ class RPCPerformanceTests : AbstractRPCTest() { parallelism = 8, overallDuration = 5.minutes, injectionRate = 20000L / TimeUnit.SECONDS, + workBound = 50, queueSizeMetricName = "$mode.QueueSize", workDurationMetricName = "$mode.WorkDuration", work = { - proxy.ops.simpleReply(ByteArray(4096), 4096) + doneFuture(proxy.ops.simpleReply(ByteArray(4096), 4096)) } ) } diff --git a/core/src/main/kotlin/net/corda/core/internal/ConcurrentBox.kt b/core/src/main/kotlin/net/corda/core/internal/ConcurrentBox.kt new file mode 100644 index 0000000000..fc3fb08d48 --- /dev/null +++ b/core/src/main/kotlin/net/corda/core/internal/ConcurrentBox.kt @@ -0,0 +1,17 @@ +package net.corda.core.internal + +import java.util.concurrent.locks.ReentrantReadWriteLock +import kotlin.concurrent.read +import kotlin.concurrent.write + +/** + * A [ConcurrentBox] allows the implementation of track() with reduced contention. [concurrent] may be run from several + * threads (which means it MUST be threadsafe!), while [exclusive] stops the world until the tracking has been set up. + * Internally [ConcurrentBox] is implemented simply as a read-write lock. + */ +class ConcurrentBox(val content: T) { + val lock = ReentrantReadWriteLock() + + inline fun concurrent(block: T.() -> R): R = lock.read { block(content) } + inline fun exclusive(block: T.() -> R): R = lock.write { block(content) } +} diff --git a/core/src/main/kotlin/net/corda/core/internal/FetchDataFlow.kt b/core/src/main/kotlin/net/corda/core/internal/FetchDataFlow.kt index 28b44dfa4e..1db993649a 100644 --- a/core/src/main/kotlin/net/corda/core/internal/FetchDataFlow.kt +++ b/core/src/main/kotlin/net/corda/core/internal/FetchDataFlow.kt @@ -17,6 +17,7 @@ import net.corda.core.serialization.SerializeAsTokenContext import net.corda.core.transactions.SignedTransaction import net.corda.core.utilities.NonEmptySet import net.corda.core.utilities.UntrustworthyData +import net.corda.core.utilities.debug import net.corda.core.utilities.unwrap import java.util.* @@ -72,7 +73,7 @@ sealed class FetchDataFlow( return if (toFetch.isEmpty()) { Result(fromDisk, emptyList()) } else { - logger.info("Requesting ${toFetch.size} dependency(s) for verification from ${otherSideSession.counterparty.name}") + logger.debug { "Requesting ${toFetch.size} dependency(s) for verification from ${otherSideSession.counterparty.name}" } // TODO: Support "large message" response streaming so response sizes are not limited by RAM. // We can then switch to requesting items in large batches to minimise the latency penalty. @@ -89,7 +90,7 @@ sealed class FetchDataFlow( } // Check for a buggy/malicious peer answering with something that we didn't ask for. val downloaded = validateFetchResponse(UntrustworthyData(maybeItems), toFetch) - logger.info("Fetched ${downloaded.size} elements from ${otherSideSession.counterparty.name}") + logger.debug { "Fetched ${downloaded.size} elements from ${otherSideSession.counterparty.name}" } maybeWriteToDisk(downloaded) Result(fromDisk, downloaded) } diff --git a/core/src/main/kotlin/net/corda/core/internal/LifeCycle.kt b/core/src/main/kotlin/net/corda/core/internal/LifeCycle.kt index 96786ea3e9..78055bc2de 100644 --- a/core/src/main/kotlin/net/corda/core/internal/LifeCycle.kt +++ b/core/src/main/kotlin/net/corda/core/internal/LifeCycle.kt @@ -13,19 +13,38 @@ class LifeCycle>(initial: S) { private val lock = ReentrantReadWriteLock() private var state = initial - /** Assert that the lifecycle in the [requiredState]. */ - fun requireState(requiredState: S) { - requireState({ "Required state to be $requiredState, was $it" }) { it == requiredState } + /** + * Assert that the lifecycle in the [requiredState]. Optionally runs [block], for the duration of which the + * lifecycle is guaranteed to stay in [requiredState]. + */ + fun requireState( + requiredState: S, + block: () -> A + ): A { + return requireState( + errorMessage = { "Required state to be $requiredState, was $it" }, + predicate = { it == requiredState }, + block = block + ) } + fun requireState(requiredState: S) = requireState(requiredState) {} /** Assert something about the current state atomically. */ + fun requireState( + errorMessage: (S) -> String, + predicate: (S) -> Boolean, + block: () -> A + ): A { + return lock.readLock().withLock { + require(predicate(state)) { errorMessage(state) } + block() + } + } fun requireState( errorMessage: (S) -> String = { "Predicate failed on state $it" }, predicate: (S) -> Boolean ) { - lock.readLock().withLock { - require(predicate(state)) { errorMessage(state) } - } + requireState(errorMessage, predicate) {} } /** Transition the state from [from] to [to]. */ diff --git a/docs/source/example-code/src/integration-test/kotlin/net/corda/docs/IntegrationTestingTutorial.kt b/docs/source/example-code/src/integration-test/kotlin/net/corda/docs/IntegrationTestingTutorial.kt index ea99522584..98c930d6fe 100644 --- a/docs/source/example-code/src/integration-test/kotlin/net/corda/docs/IntegrationTestingTutorial.kt +++ b/docs/source/example-code/src/integration-test/kotlin/net/corda/docs/IntegrationTestingTutorial.kt @@ -5,6 +5,7 @@ import net.corda.core.messaging.CordaRPCOps import net.corda.core.messaging.startFlow import net.corda.core.messaging.vaultTrackBy import net.corda.core.node.services.Vault +import net.corda.core.node.services.vault.QueryCriteria import net.corda.core.utilities.OpaqueBytes import net.corda.core.utilities.getOrThrow import net.corda.finance.DOLLARS @@ -63,14 +64,15 @@ class IntegrationTestingTutorial : IntegrationTest() { // END 2 // START 3 - val bobVaultUpdates = bobProxy.vaultTrackBy().updates - val aliceVaultUpdates = aliceProxy.vaultTrackBy().updates + val bobVaultUpdates = bobProxy.vaultTrackBy(criteria = QueryCriteria.VaultQueryCriteria(status = Vault.StateStatus.ALL)).updates + val aliceVaultUpdates = aliceProxy.vaultTrackBy(criteria = QueryCriteria.VaultQueryCriteria(status = Vault.StateStatus.ALL)).updates // END 3 // START 4 + val numberOfStates = 10 val issueRef = OpaqueBytes.of(0) val notaryParty = aliceProxy.notaryIdentities().first() - (1..10).map { i -> + (1..numberOfStates).map { i -> aliceProxy.startFlow(::CashIssueFlow, i.DOLLARS, issueRef, @@ -78,7 +80,7 @@ class IntegrationTestingTutorial : IntegrationTest() { ).returnValue }.transpose().getOrThrow() // We wait for all of the issuances to run before we start making payments - (1..10).map { i -> + (1..numberOfStates).map { i -> aliceProxy.startFlow(::CashPaymentFlow, i.DOLLARS, bob.nodeInfo.chooseIdentity(), @@ -88,7 +90,7 @@ class IntegrationTestingTutorial : IntegrationTest() { bobVaultUpdates.expectEvents { parallel( - (1..10).map { i -> + (1..numberOfStates).map { i -> expect( match = { update: Vault.Update -> update.produced.first().state.data.amount.quantity == i * 100L @@ -102,21 +104,44 @@ class IntegrationTestingTutorial : IntegrationTest() { // END 4 // START 5 - for (i in 1..10) { + for (i in 1..numberOfStates) { bobProxy.startFlow(::CashPaymentFlow, i.DOLLARS, alice.nodeInfo.chooseIdentity()).returnValue.getOrThrow() } aliceVaultUpdates.expectEvents { sequence( - (1..10).map { i -> - expect { update: Vault.Update -> - println("Alice got vault update of $update") - assertEquals(update.produced.first().state.data.amount.quantity, i * 100L) - } - } + // issuance + parallel( + (1..numberOfStates).map { i -> + expect(match = { it.moved() == -i * 100 }) { update: Vault.Update -> + assertEquals(0, update.consumed.size) + } + } + ), + // move to Bob + parallel( + (1..numberOfStates).map { i -> + expect(match = { it.moved() == i * 100 }) { update: Vault.Update -> + } + } + ), + // move back to Alice + sequence( + (1..numberOfStates).map { i -> + expect(match = { it.moved() == -i * 100 }) { update: Vault.Update -> + assertEquals(update.consumed.size, 0) + } + } + ) ) } // END 5 } } + + fun Vault.Update.moved(): Int { + val consumedSum = consumed.sumBy { it.state.data.amount.quantity.toInt() } + val producedSum = produced.sumBy { it.state.data.amount.quantity.toInt() } + return consumedSum - producedSum + } } \ No newline at end of file diff --git a/docs/source/example-code/src/main/kotlin/net/corda/docs/tutorial/mocknetwork/TutorialMockNetwork.kt b/docs/source/example-code/src/main/kotlin/net/corda/docs/tutorial/mocknetwork/TutorialMockNetwork.kt index 0147e4e623..45fb50ef78 100644 --- a/docs/source/example-code/src/main/kotlin/net/corda/docs/tutorial/mocknetwork/TutorialMockNetwork.kt +++ b/docs/source/example-code/src/main/kotlin/net/corda/docs/tutorial/mocknetwork/TutorialMockNetwork.kt @@ -85,7 +85,7 @@ class TutorialMockNetwork { // modify message if it's 1 nodeB.setMessagingServiceSpy(object : MessagingServiceSpy(nodeB.network) { - override fun send(message: Message, target: MessageRecipients, retryId: Long?, sequenceKey: Any, acknowledgementHandler: (() -> Unit)?) { + override fun send(message: Message, target: MessageRecipients, retryId: Long?, sequenceKey: Any) { val messageData = message.data.deserialize() as? ExistingSessionMessage val payload = messageData?.payload if (payload is DataSessionMessage && payload.payload.deserialize() == 1) { diff --git a/experimental/flow-hook/src/main/kotlin/net/corda/flowhook/FiberMonitor.kt b/experimental/flow-hook/src/main/kotlin/net/corda/flowhook/FiberMonitor.kt index e0f1aedaca..4741fcbdce 100644 --- a/experimental/flow-hook/src/main/kotlin/net/corda/flowhook/FiberMonitor.kt +++ b/experimental/flow-hook/src/main/kotlin/net/corda/flowhook/FiberMonitor.kt @@ -3,6 +3,9 @@ package net.corda.flowhook import co.paralleluniverse.fibers.Fiber import net.corda.core.internal.uncheckedCast import net.corda.core.utilities.loggerFor +import net.corda.flowhook.FiberMonitor.correlator +import net.corda.flowhook.FiberMonitor.inspect +import net.corda.flowhook.FiberMonitor.newEvent import net.corda.nodeapi.internal.persistence.DatabaseTransaction import java.sql.Connection import java.time.Instant @@ -10,11 +13,12 @@ import java.util.concurrent.Executors import java.util.concurrent.ScheduledExecutorService import java.util.concurrent.TimeUnit import java.util.concurrent.atomic.AtomicBoolean +import kotlin.concurrent.thread data class MonitorEvent(val type: MonitorEventType, val keys: List, val extra: Any? = null) data class FullMonitorEvent(val timestamp: Instant, val trace: List, val event: MonitorEvent) { - override fun toString() = event.toString() + override fun toString() = "$timestamp: ${event.type}" } enum class MonitorEventType { @@ -25,6 +29,8 @@ enum class MonitorEventType { FiberStarted, FiberParking, + FiberParked, + FiberResuming, FiberException, FiberResumed, FiberEnded, @@ -37,7 +43,12 @@ enum class MonitorEventType { SetThreadLocals, SetInheritableThreadLocals, GetThreadLocals, - GetInheritableThreadLocals + GetInheritableThreadLocals, + + SendSessionMessage, + + BrokerFlushStart, + BrokerFlushEnd } /** @@ -58,11 +69,26 @@ object FiberMonitor { private val started = AtomicBoolean(false) private var executor: ScheduledExecutorService? = null - val correlator = MonitorEventCorrelator() + private val correlator = MonitorEventCorrelator() + + private val eventsToDrop = setOf( + MonitorEventType.TransactionCreated, + MonitorEventType.ConnectionRequested, + MonitorEventType.ConnectionAcquired, + MonitorEventType.ConnectionReleased, + MonitorEventType.NettyThreadLocalMapCreated, + MonitorEventType.SetThreadLocals, + MonitorEventType.SetInheritableThreadLocals, + MonitorEventType.GetThreadLocals, + MonitorEventType.GetInheritableThreadLocals + ) fun newEvent(event: MonitorEvent) { if (executor != null) { val fullEvent = FullMonitorEvent(Instant.now(), Exception().stackTrace.toList(), event) + if (event.type in eventsToDrop) { + return + } executor!!.execute { processEvent(fullEvent) } @@ -75,6 +101,12 @@ object FiberMonitor { executor = Executors.newSingleThreadScheduledExecutor() executor!!.scheduleAtFixedRate(this::inspect, 100, 100, TimeUnit.MILLISECONDS) } + thread { + while (true) { + Thread.sleep(1000) + this + } + } } // Break on this function or [newEvent]. @@ -174,58 +206,49 @@ class MonitorEventCorrelator { fun getByType() = merged().entries.groupBy { it.key.javaClass } fun addEvent(fullMonitorEvent: FullMonitorEvent) { - events.add(fullMonitorEvent) + synchronized(events) { + events.add(fullMonitorEvent) + } } fun merged(): Map> { - val merged = HashMap>() - for (event in events) { - val eventLists = HashSet>() - for (key in event.event.keys) { - val list = merged[key] - if (list != null) { - eventLists.add(list) + val keyToEvents = HashMap>() + + synchronized(events) { + for (event in events) { + for (key in event.event.keys) { + keyToEvents.getOrPut(key) { HashSet() }.add(event) } } - val newList = when (eventLists.size) { - 0 -> ArrayList() - 1 -> eventLists.first() - else -> mergeAll(eventLists) + } + + val components = ArrayList>() + val visited = HashSet() + for (root in keyToEvents.keys) { + if (root in visited) { + continue } - newList.add(event) - for (key in event.event.keys) { - merged[key] = newList + val component = HashSet() + val toVisit = arrayListOf(root) + while (toVisit.isNotEmpty()) { + val current = toVisit.removeAt(toVisit.size - 1) + if (current in visited) { + continue + } + toVisit.addAll(keyToEvents[current]!!.flatMapTo(HashSet()) { it.event.keys }) + component.add(current) + visited.add(current) + } + components.add(component) + } + + val merged = HashMap>() + for (component in components) { + val eventList = component.flatMapTo(HashSet()) { keyToEvents[it]!! }.sortedBy { it.timestamp } + for (key in component) { + merged[key] = eventList } } return merged } - - fun mergeAll(lists: Collection>): ArrayList { - return lists.fold(ArrayList()) { merged, next -> merge(merged, next) } - } - - fun merge(a: List, b: List): ArrayList { - val merged = ArrayList() - var aIndex = 0 - var bIndex = 0 - while (true) { - if (aIndex >= a.size) { - merged.addAll(b.subList(bIndex, b.size)) - return merged - } - if (bIndex >= b.size) { - merged.addAll(a.subList(aIndex, a.size)) - return merged - } - val aElem = a[aIndex] - val bElem = b[bIndex] - if (aElem.timestamp < bElem.timestamp) { - merged.add(aElem) - aIndex++ - } else { - merged.add(bElem) - bIndex++ - } - } - } -} \ No newline at end of file +} diff --git a/experimental/flow-hook/src/main/kotlin/net/corda/flowhook/FlowHookContainer.kt b/experimental/flow-hook/src/main/kotlin/net/corda/flowhook/FlowHookContainer.kt index abb55cca2c..d2d01cb6dd 100644 --- a/experimental/flow-hook/src/main/kotlin/net/corda/flowhook/FlowHookContainer.kt +++ b/experimental/flow-hook/src/main/kotlin/net/corda/flowhook/FlowHookContainer.kt @@ -1,19 +1,37 @@ package net.corda.flowhook import co.paralleluniverse.fibers.Fiber +import net.corda.core.internal.declaredField import net.corda.node.services.statemachine.Event +import net.corda.node.services.statemachine.ExistingSessionMessage +import net.corda.node.services.statemachine.InitialSessionMessage +import net.corda.node.services.statemachine.SessionMessage import net.corda.nodeapi.internal.persistence.contextTransactionOrNull +import org.apache.activemq.artemis.core.io.buffer.TimedBuffer import java.sql.Connection +import java.util.concurrent.TimeUnit @Suppress("UNUSED") object FlowHookContainer { @JvmStatic @Hook("co.paralleluniverse.fibers.Fiber") - fun park() { + fun park1(blocker: Any?, postParkAction: Any?, timeout: Long?, unit: TimeUnit?) { FiberMonitor.newEvent(MonitorEvent(MonitorEventType.FiberParking, keys = listOf(Fiber.currentFiber()))) } + @JvmStatic + @Hook("co.paralleluniverse.fibers.Fiber", passThis = true) + fun exec(fiber: Any) { + FiberMonitor.newEvent(MonitorEvent(MonitorEventType.FiberResuming, keys = listOf(fiber))) + } + + @JvmStatic + @Hook("co.paralleluniverse.fibers.Fiber", passThis = true) + fun onParked(fiber: Any) { + FiberMonitor.newEvent(MonitorEvent(MonitorEventType.FiberParked, keys = listOf(fiber))) + } + @JvmStatic @Hook("net.corda.node.services.statemachine.FlowStateMachineImpl") fun run() { @@ -150,6 +168,36 @@ object FlowHookContainer { })) } + @JvmStatic + @Hook("net.corda.node.services.statemachine.FlowMessagingImpl") + fun sendSessionMessage(party: Any, message: Any, deduplicationId: Any) { + message as SessionMessage + val sessionId = when (message) { + is InitialSessionMessage -> { + message.initiatorSessionId + } + is ExistingSessionMessage -> { + message.recipientSessionId + } + } + FiberMonitor.newEvent(MonitorEvent(MonitorEventType.SendSessionMessage, keys = listOf(currentFiberOrThread(), sessionId))) + } + + @JvmStatic + @Hook("org.apache.activemq.artemis.core.io.buffer.TimedBuffer", passThis = true) + fun flush(buffer: Any, force: Boolean): () -> Unit { + buffer as TimedBuffer + val thread = Thread.currentThread() + FiberMonitor.newEvent(MonitorEvent(MonitorEventType.BrokerFlushStart, keys = listOf(thread), extra = object { + val force = force + val pendingSync = buffer.declaredField("pendingSync").value + })) + + return { + FiberMonitor.newEvent(MonitorEvent(MonitorEventType.BrokerFlushEnd, keys = listOf(thread))) + } + } + private fun currentFiberOrThread(): Any { return Fiber.currentFiber() ?: Thread.currentThread() } diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/ArtemisMessagingClient.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/ArtemisMessagingClient.kt index 99b3ab8e38..1050fa1a3c 100644 --- a/node-api/src/main/kotlin/net/corda/nodeapi/internal/ArtemisMessagingClient.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/ArtemisMessagingClient.kt @@ -7,18 +7,22 @@ import net.corda.nodeapi.ArtemisTcpTransport import net.corda.nodeapi.ConnectionDirection import net.corda.nodeapi.internal.ArtemisMessagingComponent.Companion.NODE_USER import net.corda.nodeapi.internal.config.SSLConfiguration -import org.apache.activemq.artemis.api.core.client.ActiveMQClient +import org.apache.activemq.artemis.api.core.client.* import org.apache.activemq.artemis.api.core.client.ActiveMQClient.DEFAULT_ACK_BATCH_SIZE -import org.apache.activemq.artemis.api.core.client.ClientProducer -import org.apache.activemq.artemis.api.core.client.ClientSession -import org.apache.activemq.artemis.api.core.client.ClientSessionFactory -class ArtemisMessagingClient(private val config: SSLConfiguration, private val serverAddress: NetworkHostAndPort, private val maxMessageSize: Int) { +class ArtemisMessagingClient( + private val config: SSLConfiguration, + private val serverAddress: NetworkHostAndPort, + private val maxMessageSize: Int, + private val autoCommitSends: Boolean = true, + private val autoCommitAcks: Boolean = true, + private val confirmationWindowSize: Int = -1 +) { companion object { private val log = loggerFor() } - class Started(val sessionFactory: ClientSessionFactory, val session: ClientSession, val producer: ClientProducer) + class Started(val serverLocator: ServerLocator, val sessionFactory: ClientSessionFactory, val session: ClientSession, val producer: ClientProducer) var started: Started? = null private set @@ -35,17 +39,18 @@ class ArtemisMessagingClient(private val config: SSLConfiguration, private val s clientFailureCheckPeriod = -1 minLargeMessageSize = maxMessageSize isUseGlobalPools = nodeSerializationEnv != null + confirmationWindowSize = this@ArtemisMessagingClient.confirmationWindowSize } val sessionFactory = locator.createSessionFactory() // Login using the node username. The broker will authenticate us as its node (as opposed to another peer) // using our TLS certificate. // Note that the acknowledgement of messages is not flushed to the Artermis journal until the default buffer // size of 1MB is acknowledged. - val session = sessionFactory!!.createSession(NODE_USER, NODE_USER, false, true, true, locator.isPreAcknowledge, DEFAULT_ACK_BATCH_SIZE) + val session = sessionFactory!!.createSession(NODE_USER, NODE_USER, false, autoCommitSends, autoCommitAcks, locator.isPreAcknowledge, DEFAULT_ACK_BATCH_SIZE) session.start() // Create a general purpose producer. val producer = session.createProducer() - return Started(sessionFactory, session, producer).also { started = it } + return Started(locator, sessionFactory, session, producer).also { started = it } } fun stop() = synchronized(this) { @@ -55,6 +60,7 @@ class ArtemisMessagingClient(private val config: SSLConfiguration, private val s session.commit() // Closing the factory closes all the sessions it produced as well. sessionFactory.close() + serverLocator.close() } started = null } diff --git a/node/src/integration-test/kotlin/net/corda/node/amqp/AMQPBridgeTest.kt b/node/src/integration-test/kotlin/net/corda/node/amqp/AMQPBridgeTest.kt index f1046d4d41..51a940109c 100644 --- a/node/src/integration-test/kotlin/net/corda/node/amqp/AMQPBridgeTest.kt +++ b/node/src/integration-test/kotlin/net/corda/node/amqp/AMQPBridgeTest.kt @@ -5,9 +5,7 @@ import com.nhaarman.mockito_kotlin.whenever import net.corda.core.crypto.toStringShort import net.corda.core.internal.div import net.corda.core.utilities.NetworkHostAndPort -import net.corda.node.services.config.CertChainPolicyConfig -import net.corda.node.services.config.NodeConfiguration -import net.corda.node.services.config.configureWithDevSSLCertificate +import net.corda.node.services.config.* import net.corda.node.services.messaging.ArtemisMessagingServer import net.corda.nodeapi.internal.ArtemisMessagingClient import net.corda.nodeapi.internal.ArtemisMessagingComponent @@ -143,6 +141,7 @@ class AMQPBridgeTest { doReturn(artemisAddress).whenever(it).p2pAddress doReturn("").whenever(it).exportJMXto doReturn(emptyList()).whenever(it).certificateChainCheckPolicies + doReturn(EnterpriseConfiguration(MutualExclusionConfiguration(false, "", 20000, 40000))).whenever(it).enterpriseConfiguration } artemisConfig.configureWithDevSSLCertificate() val artemisServer = ArtemisMessagingServer(artemisConfig, artemisPort, MAX_MESSAGE_SIZE) diff --git a/node/src/integration-test/kotlin/net/corda/node/amqp/ProtonWrapperTests.kt b/node/src/integration-test/kotlin/net/corda/node/amqp/ProtonWrapperTests.kt index 64d7c09990..0ab376e936 100644 --- a/node/src/integration-test/kotlin/net/corda/node/amqp/ProtonWrapperTests.kt +++ b/node/src/integration-test/kotlin/net/corda/node/amqp/ProtonWrapperTests.kt @@ -8,9 +8,7 @@ import net.corda.core.identity.CordaX500Name import net.corda.core.internal.div import net.corda.core.toFuture import net.corda.core.utilities.NetworkHostAndPort -import net.corda.node.services.config.CertChainPolicyConfig -import net.corda.node.services.config.NodeConfiguration -import net.corda.node.services.config.configureWithDevSSLCertificate +import net.corda.node.services.config.* import net.corda.node.services.messaging.ArtemisMessagingServer import net.corda.nodeapi.internal.ArtemisMessagingClient import net.corda.nodeapi.internal.ArtemisMessagingComponent.Companion.P2P_PREFIX @@ -224,6 +222,7 @@ class ProtonWrapperTests { doReturn(NetworkHostAndPort("0.0.0.0", artemisPort)).whenever(it).p2pAddress doReturn("").whenever(it).exportJMXto doReturn(emptyList()).whenever(it).certificateChainCheckPolicies + doReturn(EnterpriseConfiguration(MutualExclusionConfiguration(false, "", 20000, 40000))).whenever(it).enterpriseConfiguration } artemisConfig.configureWithDevSSLCertificate() diff --git a/node/src/main/kotlin/net/corda/node/internal/AbstractNode.kt b/node/src/main/kotlin/net/corda/node/internal/AbstractNode.kt index a9dfdcbb40..e5ad108921 100644 --- a/node/src/main/kotlin/net/corda/node/internal/AbstractNode.kt +++ b/node/src/main/kotlin/net/corda/node/internal/AbstractNode.kt @@ -2,7 +2,7 @@ package net.corda.node.internal import com.codahale.metrics.MetricRegistry import com.google.common.collect.MutableClassToInstanceMap -import com.google.common.util.concurrent.MoreExecutors +import com.google.common.util.concurrent.ThreadFactoryBuilder import com.zaxxer.hikari.HikariConfig import com.zaxxer.hikari.HikariDataSource import net.corda.confidential.SwapIdentitiesFlow @@ -90,7 +90,7 @@ import java.time.Duration import java.util.* import java.util.concurrent.ConcurrentHashMap import java.util.concurrent.ExecutorService -import java.util.concurrent.TimeUnit.SECONDS +import java.util.concurrent.Executors import kotlin.collections.set import kotlin.reflect.KClass import net.corda.core.crypto.generateKeyPair as cryptoGenerateKeyPair @@ -110,7 +110,7 @@ abstract class AbstractNode(val configuration: NodeConfiguration, val platformClock: CordaClock, protected val versionInfo: VersionInfo, protected val cordappLoader: CordappLoader, - private val busyNodeLatch: ReusableLatch = ReusableLatch()) : SingletonSerializeAsToken() { + protected val busyNodeLatch: ReusableLatch = ReusableLatch()) : SingletonSerializeAsToken() { private class StartedNodeImpl( override val internals: N, @@ -131,7 +131,7 @@ abstract class AbstractNode(val configuration: NodeConfiguration, // We will run as much stuff in this single thread as possible to keep the risk of thread safety bugs low during the // low-performance prototyping period. - protected abstract val serverThread: AffinityExecutor + protected abstract val serverThread: AffinityExecutor.ServiceAffinityExecutor protected lateinit var networkParameters: NetworkParameters private val cordappServices = MutableClassToInstanceMap.create() @@ -140,7 +140,7 @@ abstract class AbstractNode(val configuration: NodeConfiguration, protected val services: ServiceHubInternal get() = _services private lateinit var _services: ServiceHubInternalImpl protected var myNotaryIdentity: PartyAndCertificate? = null - private lateinit var checkpointStorage: CheckpointStorage + protected lateinit var checkpointStorage: CheckpointStorage private lateinit var tokenizableServices: List protected lateinit var attachments: NodeAttachmentService protected lateinit var network: MessagingService @@ -229,23 +229,15 @@ abstract class AbstractNode(val configuration: NodeConfiguration, val notaryService = makeNotaryService(nodeServices, database) val smm = makeStateMachineManager(database) val flowLogicRefFactory = FlowLogicRefFactoryImpl(cordappLoader.appClassLoader) - val flowStarter = FlowStarterImpl(serverThread, smm, flowLogicRefFactory) + val flowStarter = FlowStarterImpl(smm, flowLogicRefFactory) val schedulerService = NodeSchedulerService( platformClock, database, flowStarter, transactionStorage, unfinishedSchedules = busyNodeLatch, - serverThread = serverThread, - flowLogicRefFactory = flowLogicRefFactory) - if (serverThread is ExecutorService) { - runOnStop += { - // We wait here, even though any in-flight messages should have been drained away because the - // server thread can potentially have other non-messaging tasks scheduled onto it. The timeout value is - // arbitrary and might be inappropriate. - MoreExecutors.shutdownAndAwaitTermination(serverThread as ExecutorService, 50, SECONDS) - } - } + flowLogicRefFactory = flowLogicRefFactory + ) makeVaultObservers(schedulerService, database.hibernateConfig, smm, schemaService, flowLogicRefFactory) val rpcOps = makeRPCOps(flowStarter, database, smm) startMessagingService(rpcOps) @@ -327,8 +319,9 @@ abstract class AbstractNode(val configuration: NodeConfiguration, } protected abstract fun myAddresses(): List + protected open fun makeStateMachineManager(database: CordaPersistence): StateMachineManager { - return StateMachineManagerImpl( + return SingleThreadedStateMachineManager( services, checkpointStorage, serverThread, @@ -841,9 +834,9 @@ internal fun logVendorString(database: CordaPersistence, log: Logger) { } } -internal class FlowStarterImpl(private val serverThread: AffinityExecutor, private val smm: StateMachineManager, private val flowLogicRefFactory: FlowLogicRefFactory) : FlowStarter { +internal class FlowStarterImpl(private val smm: StateMachineManager, private val flowLogicRefFactory: FlowLogicRefFactory) : FlowStarter { override fun startFlow(logic: FlowLogic, context: InvocationContext): CordaFuture> { - return serverThread.fetchFrom { smm.startFlow(logic, context) } + return smm.startFlow(logic, context) } override fun invokeFlowAsync( diff --git a/node/src/main/kotlin/net/corda/node/internal/EnterpriseNode.kt b/node/src/main/kotlin/net/corda/node/internal/EnterpriseNode.kt index 901391a114..d4bb045015 100644 --- a/node/src/main/kotlin/net/corda/node/internal/EnterpriseNode.kt +++ b/node/src/main/kotlin/net/corda/node/internal/EnterpriseNode.kt @@ -4,24 +4,36 @@ import com.codahale.metrics.MetricFilter import com.codahale.metrics.MetricRegistry import com.codahale.metrics.graphite.GraphiteReporter import com.codahale.metrics.graphite.PickledGraphite +import com.google.common.util.concurrent.ThreadFactoryBuilder import com.jcraft.jsch.JSch import com.jcraft.jsch.JSchException +import net.corda.core.crypto.newSecureRandom import net.corda.core.identity.CordaX500Name import net.corda.core.internal.Emoji import net.corda.core.internal.concurrent.thenMatch import net.corda.core.utilities.loggerFor import net.corda.node.VersionInfo +import net.corda.node.internal.cordapp.CordappLoader import net.corda.node.services.config.GraphiteOptions import net.corda.node.services.config.NodeConfiguration import net.corda.node.services.config.RelayConfiguration +import net.corda.node.services.statemachine.MultiThreadedStateMachineManager +import net.corda.node.services.statemachine.SingleThreadedStateMachineManager +import net.corda.node.services.statemachine.StateMachineManager +import net.corda.nodeapi.internal.persistence.CordaPersistence import org.fusesource.jansi.Ansi import org.fusesource.jansi.AnsiConsole import java.io.IOException import java.net.InetAddress +import java.util.concurrent.ExecutorService +import java.util.concurrent.Executors import java.util.concurrent.TimeUnit -class EnterpriseNode(configuration: NodeConfiguration, - versionInfo: VersionInfo) : Node(configuration, versionInfo) { +open class EnterpriseNode(configuration: NodeConfiguration, + versionInfo: VersionInfo, + initialiseSerialization: Boolean = true, + cordappLoader: CordappLoader = makeCordappLoader(configuration) +) : Node(configuration, versionInfo, initialiseSerialization, cordappLoader) { companion object { private val logger by lazy { loggerFor() } @@ -144,4 +156,29 @@ D""".trimStart() registerOptionalMetricsReporter(configuration, started.services.monitoringService.metrics) return started } + + private fun makeStateMachineExecutorService(): ExecutorService { + return Executors.newFixedThreadPool( + configuration.enterpriseConfiguration.tuning.flowThreadPoolSize, + ThreadFactoryBuilder().setNameFormat("flow-executor-%d").build() + ) + } + + override fun makeStateMachineManager(database: CordaPersistence): StateMachineManager { + if (configuration.enterpriseConfiguration.useMultiThreadedSMM) { + val executor = makeStateMachineExecutorService() + runOnStop += { executor.shutdown() } + return MultiThreadedStateMachineManager( + services, + checkpointStorage, + executor, + database, + newSecureRandom(), + busyNodeLatch, + cordappLoader.appClassLoader + ) + } else { + return super.makeStateMachineManager(database) + } + } } \ No newline at end of file diff --git a/node/src/main/kotlin/net/corda/node/internal/Node.kt b/node/src/main/kotlin/net/corda/node/internal/Node.kt index d41c55fedb..8d663fca57 100644 --- a/node/src/main/kotlin/net/corda/node/internal/Node.kt +++ b/node/src/main/kotlin/net/corda/node/internal/Node.kt @@ -83,7 +83,8 @@ open class Node(configuration: NodeConfiguration, private val sameVmNodeCounter = AtomicInteger() val scanPackagesSystemProperty = "net.corda.node.cordapp.scan.packages" val scanPackagesSeparator = "," - private fun makeCordappLoader(configuration: NodeConfiguration): CordappLoader { + @JvmStatic + protected fun makeCordappLoader(configuration: NodeConfiguration): CordappLoader { return System.getProperty(scanPackagesSystemProperty)?.let { scanPackages -> CordappLoader.createDefaultWithTestPackages(configuration, scanPackages.split(scanPackagesSeparator)) } ?: CordappLoader.createDefault(configuration.baseDirectory) @@ -157,8 +158,12 @@ open class Node(configuration: NodeConfiguration, bridgeControlListener = BridgeControlListener(configuration, serverAddress, networkParameters.maxMessageSize) printBasicNodeInfo("Incoming connection address", advertisedAddress.toString()) + + val rpcServerConfiguration = RPCServerConfiguration.default.copy( + rpcThreadPoolSize = configuration.enterpriseConfiguration.tuning.rpcThreadPoolSize + ) rpcServerAddresses?.let { - rpcMessagingClient = RPCMessagingClient(configuration.rpcOptions.sslConfig, it.admin, networkParameters.maxMessageSize) + rpcMessagingClient = RPCMessagingClient(configuration.rpcOptions.sslConfig, it.admin, networkParameters.maxMessageSize, rpcServerConfiguration) } verifierMessagingClient = when (configuration.verifierType) { VerifierType.OutOfProcess -> VerifierMessagingClient(configuration, serverAddress, services.monitoringService.metrics, networkParameters.maxMessageSize) @@ -175,8 +180,10 @@ open class Node(configuration: NodeConfiguration, serverThread, database, services.networkMapCache, + services.monitoringService.metrics, advertisedAddress, - networkParameters.maxMessageSize) + networkParameters.maxMessageSize + ) } private fun startLocalRpcBroker(): BrokerAddresses? { diff --git a/node/src/main/kotlin/net/corda/node/services/config/EnterpriseConfiguration.kt b/node/src/main/kotlin/net/corda/node/services/config/EnterpriseConfiguration.kt index f8aee69d48..3087838093 100644 --- a/node/src/main/kotlin/net/corda/node/services/config/EnterpriseConfiguration.kt +++ b/node/src/main/kotlin/net/corda/node/services/config/EnterpriseConfiguration.kt @@ -1,5 +1,39 @@ package net.corda.node.services.config -data class EnterpriseConfiguration(val mutualExclusionConfiguration: MutualExclusionConfiguration) +data class EnterpriseConfiguration( + val mutualExclusionConfiguration: MutualExclusionConfiguration, + val useMultiThreadedSMM: Boolean = true, + val tuning: PerformanceTuning = PerformanceTuning.default +) -data class MutualExclusionConfiguration(val on: Boolean = false, val machineName: String, val updateInterval: Long, val waitInterval: Long) \ No newline at end of file +data class MutualExclusionConfiguration(val on: Boolean = false, val machineName: String, val updateInterval: Long, val waitInterval: Long) + +/** + * @param flowThreadPoolSize Determines the size of the thread pool used by the flow framework to run flows. + * @param maximumMessagingBatchSize Determines the maximum number of jobs the messaging layer submits asynchronously + * before waiting for a flush from the broker. + * @param rpcThreadPoolSize Determines the number of threads used by the RPC server to serve requests. + * @param p2pConfirmationWindowSize Determines the number of bytes buffered by the broker before flushing to disk and + * acking the triggering send. Setting this to -1 causes session commits to immediately return, potentially + * causing blowup in the broker if the rate of sends exceeds the broker's flush rate. Note also that this window + * causes send latency to be around [brokerConnectionTtlCheckInterval] if the window isn't saturated. + * @param brokerConnectionTtlCheckIntervalMs Determines the interval of TTL timeout checks, but most importantly it also + * determines the flush period of message acks in case [p2pConfirmationWindowSize] is not saturated in time. + */ +data class PerformanceTuning( + val flowThreadPoolSize: Int, + val maximumMessagingBatchSize: Int, + val rpcThreadPoolSize: Int, + val p2pConfirmationWindowSize: Int, + val brokerConnectionTtlCheckIntervalMs: Long +) { + companion object { + val default = PerformanceTuning( + flowThreadPoolSize = 1, + maximumMessagingBatchSize = 256, + rpcThreadPoolSize = 4, + p2pConfirmationWindowSize = 1048576, + brokerConnectionTtlCheckIntervalMs = 20 + ) + } +} diff --git a/node/src/main/kotlin/net/corda/node/services/config/NodeConfiguration.kt b/node/src/main/kotlin/net/corda/node/services/config/NodeConfiguration.kt index f56ddbd4cd..4c2cafbf2a 100644 --- a/node/src/main/kotlin/net/corda/node/services/config/NodeConfiguration.kt +++ b/node/src/main/kotlin/net/corda/node/services/config/NodeConfiguration.kt @@ -209,6 +209,16 @@ data class NodeConfigurationImpl( if (dataSourceUrl.contains(":sqlserver:") && !dataSourceUrl.contains("sendStringParametersAsUnicode", true)) { dataSourceProperties[DataSourceConfigTag.DATA_SOURCE_URL] = dataSourceUrl + ";sendStringParametersAsUnicode=false" } + + // Adjust connection pool size depending on N=flow thread pool size. + // If there is no configured pool size set it to N + 1, otherwise check that it's greater than N. + val flowThreadPoolSize = enterpriseConfiguration.tuning.flowThreadPoolSize + val maxConnectionPoolSize = dataSourceProperties.getProperty("maximumPoolSize") + if (maxConnectionPoolSize == null) { + dataSourceProperties.setProperty("maximumPoolSize", (flowThreadPoolSize + 1).toString()) + } else { + require(maxConnectionPoolSize.toInt() > flowThreadPoolSize) + } } } diff --git a/node/src/main/kotlin/net/corda/node/services/events/NodeSchedulerService.kt b/node/src/main/kotlin/net/corda/node/services/events/NodeSchedulerService.kt index 72ef0fe6a6..cfee41f1c7 100644 --- a/node/src/main/kotlin/net/corda/node/services/events/NodeSchedulerService.kt +++ b/node/src/main/kotlin/net/corda/node/services/events/NodeSchedulerService.kt @@ -58,7 +58,6 @@ class NodeSchedulerService(private val clock: CordaClock, private val flowStarter: FlowStarter, private val stateLoader: StateLoader, private val unfinishedSchedules: ReusableLatch = ReusableLatch(), - private val serverThread: Executor, private val flowLogicRefFactory: FlowLogicRefFactory, private val log: Logger = staticLog, private val scheduledStates: MutableMap = createMap()) @@ -244,24 +243,22 @@ class NodeSchedulerService(private val clock: CordaClock, } private fun onTimeReached(scheduledState: ScheduledStateRef) { - serverThread.execute { - var flowName: String? = "(unknown)" - try { - database.transaction { - val scheduledFlow = getScheduledFlow(scheduledState) - if (scheduledFlow != null) { - flowName = scheduledFlow.javaClass.name - // TODO refactor the scheduler to store and propagate the original invocation context - val context = InvocationContext.newInstance(Origin.Scheduled(scheduledState)) - val future = flowStarter.startFlow(scheduledFlow, context).flatMap { it.resultFuture } - future.then { - unfinishedSchedules.countDown() - } + var flowName: String? = "(unknown)" + try { + database.transaction { + val scheduledFlow = getScheduledFlow(scheduledState) + if (scheduledFlow != null) { + flowName = scheduledFlow.javaClass.name + // TODO refactor the scheduler to store and propagate the original invocation context + val context = InvocationContext.newInstance(Origin.Scheduled(scheduledState)) + val future = flowStarter.startFlow(scheduledFlow, context).flatMap { it.resultFuture } + future.then { + unfinishedSchedules.countDown() } } - } catch (e: Exception) { - log.error("Failed to start scheduled flow $flowName for $scheduledState due to an internal error", e) } + } catch (e: Exception) { + log.error("Failed to start scheduled flow $flowName for $scheduledState due to an internal error", e) } } diff --git a/node/src/main/kotlin/net/corda/node/services/identity/PersistentIdentityService.kt b/node/src/main/kotlin/net/corda/node/services/identity/PersistentIdentityService.kt index 210e610ee4..9d43100b43 100644 --- a/node/src/main/kotlin/net/corda/node/services/identity/PersistentIdentityService.kt +++ b/node/src/main/kotlin/net/corda/node/services/identity/PersistentIdentityService.kt @@ -135,7 +135,7 @@ class PersistentIdentityService(override val trustRoot: X509Certificate, log.debug { "Registering identity $identity" } val key = mapToKey(identity) - keyToParties.addWithDuplicatesAllowed(key, identity) + keyToParties.addWithDuplicatesAllowed(key, identity, false) // Always keep the first party we registered, as that's the well known identity principalToParties.addWithDuplicatesAllowed(identity.name, key, false) val parentId = mapToKey(identityCertChain[1].publicKey) diff --git a/node/src/main/kotlin/net/corda/node/services/messaging/ArtemisMessagingServer.kt b/node/src/main/kotlin/net/corda/node/services/messaging/ArtemisMessagingServer.kt index f6837b80c1..21e72cdcae 100644 --- a/node/src/main/kotlin/net/corda/node/services/messaging/ArtemisMessagingServer.kt +++ b/node/src/main/kotlin/net/corda/node/services/messaging/ArtemisMessagingServer.kt @@ -142,7 +142,7 @@ class ArtemisMessagingServer(private val config: NodeConfiguration, journalBufferSize_AIO = maxMessageSize // Required to address IllegalArgumentException (when Artemis uses Linux Async IO): Record is too large to store. journalFileSize = maxMessageSize // The size of each journal file in bytes. Artemis default is 10MiB. managementNotificationAddress = SimpleString(NOTIFICATIONS_ADDRESS) - + connectionTtlCheckInterval = config.enterpriseConfiguration.tuning.brokerConnectionTtlCheckIntervalMs // JMX enablement if (config.exportJMXto.isNotEmpty()) { isJMXManagementEnabled = true diff --git a/node/src/main/kotlin/net/corda/node/services/messaging/Messaging.kt b/node/src/main/kotlin/net/corda/node/services/messaging/Messaging.kt index 898c881e90..03b6e33517 100644 --- a/node/src/main/kotlin/net/corda/node/services/messaging/Messaging.kt +++ b/node/src/main/kotlin/net/corda/node/services/messaging/Messaging.kt @@ -1,5 +1,6 @@ package net.corda.node.services.messaging +import co.paralleluniverse.fibers.Suspendable import net.corda.core.crypto.newSecureRandom import net.corda.core.identity.CordaX500Name import net.corda.core.messaging.MessageRecipients @@ -60,15 +61,13 @@ interface MessagingService { * @param sequenceKey an object that may be used to enable a parallel [MessagingService] implementation. Two * subsequent send()s with the same [sequenceKey] (up to equality) are guaranteed to be delivered in the same * sequence the send()s were called. By default this is chosen conservatively to be [target]. - * @param acknowledgementHandler if non-null this handler will be called once the sent message has been committed by - * the broker. Note that if specified [send] itself may return earlier than the commit. */ + @Suspendable fun send( message: Message, target: MessageRecipients, retryId: Long? = null, - sequenceKey: Any = target, - acknowledgementHandler: (() -> Unit)? = null + sequenceKey: Any = target ) /** A message with a target and sequenceKey specified. */ @@ -84,10 +83,9 @@ interface MessagingService { * implementation. * * @param addressedMessages The list of messages together with the recipients, retry ids and sequence keys. - * @param acknowledgementHandler if non-null this handler will be called once all sent messages have been committed - * by the broker. Note that if specified [send] itself may return earlier than the commit. */ - fun send(addressedMessages: List, acknowledgementHandler: (() -> Unit)? = null) + @Suspendable + fun send(addressedMessages: List) /** Cancels the scheduled message redelivery for the specified [retryId] */ fun cancelRedelivery(retryId: Long) diff --git a/node/src/main/kotlin/net/corda/node/services/messaging/MessagingExecutor.kt b/node/src/main/kotlin/net/corda/node/services/messaging/MessagingExecutor.kt new file mode 100644 index 0000000000..a0e661890a --- /dev/null +++ b/node/src/main/kotlin/net/corda/node/services/messaging/MessagingExecutor.kt @@ -0,0 +1,221 @@ +package net.corda.node.services.messaging + +import co.paralleluniverse.fibers.Suspendable +import co.paralleluniverse.strands.SettableFuture +import com.codahale.metrics.MetricRegistry +import net.corda.core.messaging.MessageRecipients +import net.corda.core.utilities.contextLogger +import net.corda.core.utilities.trace +import net.corda.node.VersionInfo +import net.corda.node.services.statemachine.FlowMessagingImpl +import org.apache.activemq.artemis.api.core.ActiveMQDuplicateIdException +import org.apache.activemq.artemis.api.core.ActiveMQException +import org.apache.activemq.artemis.api.core.SimpleString +import org.apache.activemq.artemis.api.core.client.ClientMessage +import org.apache.activemq.artemis.api.core.client.ClientProducer +import org.apache.activemq.artemis.api.core.client.ClientSession +import java.util.* +import java.util.concurrent.ArrayBlockingQueue +import java.util.concurrent.ExecutionException +import kotlin.concurrent.thread + +interface AddressToArtemisQueueResolver { + /** + * Resolves a [MessageRecipients] to an Artemis queue name, creating the underlying queue if needed. + */ + fun resolveTargetToArtemisQueue(address: MessageRecipients): String +} + +/** + * The [MessagingExecutor] is responsible for handling send and acknowledge jobs. It batches them using a bounded + * blocking queue, submits the jobs asynchronously and then waits for them to flush using [ClientSession.commit]. + * Note that even though we buffer in theory this shouldn't increase latency as the executor is immediately woken up if + * it was waiting. The number of jobs in the queue is only ever greater than 1 if the commit takes a long time. + */ +class MessagingExecutor( + val session: ClientSession, + val producer: ClientProducer, + val versionInfo: VersionInfo, + val resolver: AddressToArtemisQueueResolver, + metricRegistry: MetricRegistry, + queueBound: Int +) { + private sealed class Job { + data class Acknowledge(val message: ClientMessage) : Job() + data class Send( + val message: Message, + val target: MessageRecipients, + val sentFuture: SettableFuture + ) : Job() { + override fun toString() = "Send(${message.uniqueMessageId}, target=$target)" + } + object Shutdown : Job() { override fun toString() = "Shutdown" } + } + + private val queue = ArrayBlockingQueue(queueBound) + private var executor: Thread? = null + private val cordaVendor = SimpleString(versionInfo.vendor) + private val releaseVersion = SimpleString(versionInfo.releaseVersion) + private val sendMessageSizeMetric = metricRegistry.histogram("SendMessageSize") + private val sendLatencyMetric = metricRegistry.timer("SendLatency") + private val sendBatchSizeMetric = metricRegistry.histogram("SendBatchSize") + + private companion object { + val log = contextLogger() + val amqDelayMillis = System.getProperty("amq.delivery.delay.ms", "0").toInt() + } + + /** + * Submit a send job of [message] to [target] and wait until it finishes. + * This call may yield the fiber. + */ + @Suspendable + fun send(message: Message, target: MessageRecipients) { + val sentFuture = SettableFuture() + val job = Job.Send(message, target, sentFuture) + val context = sendLatencyMetric.time() + try { + queue.put(job) + sentFuture.get() + } catch (executionException: ExecutionException) { + throw executionException.cause!! + } finally { + context.stop() + } + } + + /** + * Submit an acknowledge job of [message]. + * This call does NOT wait for confirmation of the ACK receive. If a failure happens then the message will either be + * redelivered, deduped and acked, or the message was actually acked before failure in which case all is good. + */ + fun acknowledge(message: ClientMessage) { + queue.put(Job.Acknowledge(message)) + } + + fun start() { + require(executor == null) + executor = thread(name = "Messaging executor", isDaemon = true) { + val batch = ArrayList() + eventLoop@ while (true) { + batch.add(queue.take()) // Block until at least one job is available. + queue.drainTo(batch) + sendBatchSizeMetric.update(batch.filter { it is Job.Send }.size) + val shouldShutdown = try { + // Try to handle the batch in one commit. + handleBatchTransactional(batch) + } catch (exception: ActiveMQException) { + // A job failed, rollback and do it one at a time, simply log and skip if an individual job fails. + // If a send job fails the exception will be re-raised in the corresponding future. + // Note that this fallback assumes that there are no two jobs in the batch that depend on one + // another. As the exception is re-raised in the requesting calling thread in case of a send, we can + // assume no "in-flight" messages will be sent out of order after failure. + log.warn("Exception while handling transactional batch, falling back to handling one job at a time", exception) + handleBatchOneByOne(batch) + } + batch.clear() + if (shouldShutdown) { + break@eventLoop + } + } + } + } + + fun close() { + val executor = this.executor + if (executor != null) { + queue.offer(Job.Shutdown) + executor.join() + this.executor = null + } + } + + /** + * Handles a batch of jobs in one transaction. + * @return true if the executor should shut down, false otherwise. + * @throws ActiveMQException + */ + private fun handleBatchTransactional(batch: List): Boolean { + for (job in batch) { + when (job) { + is Job.Acknowledge -> { + acknowledgeJob(job) + } + is Job.Send -> { + sendJob(job) + } + Job.Shutdown -> { + session.commit() + return true + } + } + } + session.commit() + return false + } + + /** + * Handles a batch of jobs one by one, committing after each. + * @return true if the executor should shut down, false otherwise. + */ + private fun handleBatchOneByOne(batch: List): Boolean { + for (job in batch) { + try { + when (job) { + is Job.Acknowledge -> { + acknowledgeJob(job) + session.commit() + } + is Job.Send -> { + try { + sendJob(job) + session.commit() + } catch (duplicateException: ActiveMQDuplicateIdException) { + log.warn("Message duplication", duplicateException) + job.sentFuture.set(Unit) + } + } + Job.Shutdown -> { + session.commit() + return true + } + } + } catch (exception: Throwable) { + log.error("Exception while handling job $job, disregarding", exception) + if (job is Job.Send) { + job.sentFuture.setException(exception) + } + session.rollback() + } + } + return false + } + + private fun sendJob(job: Job.Send) { + val mqAddress = resolver.resolveTargetToArtemisQueue(job.target) + val artemisMessage = session.createMessage(true).apply { + putStringProperty(P2PMessagingClient.cordaVendorProperty, cordaVendor) + putStringProperty(P2PMessagingClient.releaseVersionProperty, releaseVersion) + putIntProperty(P2PMessagingClient.platformVersionProperty, versionInfo.platformVersion) + putStringProperty(P2PMessagingClient.topicProperty, SimpleString(job.message.topic)) + sendMessageSizeMetric.update(job.message.data.bytes.size) + writeBodyBufferBytes(job.message.data.bytes) + // Use the magic deduplication property built into Artemis as our message identity too + putStringProperty(org.apache.activemq.artemis.api.core.Message.HDR_DUPLICATE_DETECTION_ID, SimpleString(job.message.uniqueMessageId.toString)) + + // For demo purposes - if set then add a delay to messages in order to demonstrate that the flows are doing as intended + if (amqDelayMillis > 0 && job.message.topic == FlowMessagingImpl.sessionTopic) { + putLongProperty(org.apache.activemq.artemis.api.core.Message.HDR_SCHEDULED_DELIVERY_TIME, System.currentTimeMillis() + amqDelayMillis) + } + } + log.trace { + "Send to: $mqAddress topic: ${job.message.topic} " + + "sessionID: ${job.message.topic} id: ${job.message.uniqueMessageId}" + } + producer.send(SimpleString(mqAddress), artemisMessage) { job.sentFuture.set(Unit) } + } + + private fun acknowledgeJob(job: Job.Acknowledge) { + job.message.individualAcknowledge() + } +} \ No newline at end of file diff --git a/node/src/main/kotlin/net/corda/node/services/messaging/P2PMessagingClient.kt b/node/src/main/kotlin/net/corda/node/services/messaging/P2PMessagingClient.kt index 03bc7073ed..e0769c8f27 100644 --- a/node/src/main/kotlin/net/corda/node/services/messaging/P2PMessagingClient.kt +++ b/node/src/main/kotlin/net/corda/node/services/messaging/P2PMessagingClient.kt @@ -1,5 +1,7 @@ package net.corda.node.services.messaging +import co.paralleluniverse.fibers.Suspendable +import com.codahale.metrics.MetricRegistry import net.corda.core.crypto.toStringShort import net.corda.core.identity.CordaX500Name import net.corda.core.internal.ThreadBox @@ -18,7 +20,6 @@ import net.corda.node.VersionInfo import net.corda.node.services.api.NetworkMapCacheInternal import net.corda.node.services.config.NodeConfiguration import net.corda.node.services.statemachine.DeduplicationId -import net.corda.node.services.statemachine.FlowMessagingImpl import net.corda.node.utilities.AffinityExecutor import net.corda.node.utilities.AppendOnlyPersistentMap import net.corda.node.utilities.PersistentMap @@ -32,7 +33,8 @@ import net.corda.nodeapi.internal.bridging.BridgeEntry import net.corda.nodeapi.internal.persistence.CordaPersistence import net.corda.nodeapi.internal.persistence.NODE_DATABASE_PREFIX import org.apache.activemq.artemis.api.core.ActiveMQObjectClosedException -import org.apache.activemq.artemis.api.core.Message.* +import org.apache.activemq.artemis.api.core.Message.HDR_DUPLICATE_DETECTION_ID +import org.apache.activemq.artemis.api.core.Message.HDR_VALIDATED_USER import org.apache.activemq.artemis.api.core.RoutingType import org.apache.activemq.artemis.api.core.SimpleString import org.apache.activemq.artemis.api.core.client.ClientConsumer @@ -78,7 +80,7 @@ import javax.persistence.Lob * @param maxMessageSize A bound applied to the message size. */ @ThreadSafe -class P2PMessagingClient(config: NodeConfiguration, +class P2PMessagingClient(val config: NodeConfiguration, private val versionInfo: VersionInfo, serverAddress: NetworkHostAndPort, private val myIdentity: PublicKey, @@ -86,20 +88,20 @@ class P2PMessagingClient(config: NodeConfiguration, private val nodeExecutor: AffinityExecutor.ServiceAffinityExecutor, private val database: CordaPersistence, private val networkMap: NetworkMapCacheInternal, + private val metricRegistry: MetricRegistry, advertisedAddress: NetworkHostAndPort = serverAddress, maxMessageSize: Int -) : SingletonSerializeAsToken(), MessagingService { +) : SingletonSerializeAsToken(), MessagingService, AddressToArtemisQueueResolver { companion object { private val log = contextLogger() // This is a "property" attached to an Artemis MQ message object, which contains our own notion of "topic". // We should probably try to unify our notion of "topic" (really, just a string that identifies an endpoint // that will handle messages, like a URL) with the terminology used by underlying MQ libraries, to avoid // confusion. - private val topicProperty = SimpleString("platform-topic") - private val cordaVendorProperty = SimpleString("corda-vendor") - private val releaseVersionProperty = SimpleString("release-version") - private val platformVersionProperty = SimpleString("platform-version") - private val amqDelayMillis = System.getProperty("amq.delivery.delay.ms", "0").toInt() + val topicProperty = SimpleString("platform-topic") + val cordaVendorProperty = SimpleString("corda-vendor") + val releaseVersionProperty = SimpleString("release-version") + val platformVersionProperty = SimpleString("platform-version") private val messageMaxRetryCount: Int = 3 fun createProcessedMessages(): AppendOnlyPersistentMap { @@ -159,20 +161,23 @@ class P2PMessagingClient(config: NodeConfiguration, /** A registration to handle messages of different types */ data class HandlerRegistration(val topic: String, val callback: Any) : MessageHandlerRegistration - private val cordaVendor = SimpleString(versionInfo.vendor) - private val releaseVersion = SimpleString(versionInfo.releaseVersion) - /** An executor for sending messages */ - private val messagingExecutor = AffinityExecutor.ServiceAffinityExecutor("Messaging ${myIdentity.toStringShort()}", 1) - override val myAddress: SingleMessageRecipient = NodeAddress(myIdentity, advertisedAddress) private val messageRedeliveryDelaySeconds = config.messageRedeliveryDelaySeconds.toLong() - private val artemis = ArtemisMessagingClient(config, serverAddress, maxMessageSize) + private val artemis = ArtemisMessagingClient( + config = config, + serverAddress = serverAddress, + maxMessageSize = maxMessageSize, + autoCommitSends = false, + autoCommitAcks = false, + confirmationWindowSize = config.enterpriseConfiguration.tuning.p2pConfirmationWindowSize + ) private val state = ThreadBox(InnerState()) private val knownQueues = Collections.newSetFromMap(ConcurrentHashMap()) private val handlers = ConcurrentHashMap() private val processedMessages = createProcessedMessages() + private var messagingExecutor: MessagingExecutor? = null @Entity @javax.persistence.Table(name = "${NODE_DATABASE_PREFIX}message_ids") @@ -203,7 +208,8 @@ class P2PMessagingClient(config: NodeConfiguration, fun start() { state.locked { - val session = artemis.start().session + val started = artemis.start() + val session = started.session val inbox = RemoteInboxAddress(myIdentity).queueName val inboxes = mutableListOf(inbox) // Create a queue, consumer and producer for handling P2P network messages. @@ -220,6 +226,18 @@ class P2PMessagingClient(config: NodeConfiguration, deliver(msg, message) } } + + val messagingExecutor = MessagingExecutor( + session, + started.producer, + versionInfo, + this@P2PMessagingClient, + metricRegistry, + queueBound = config.enterpriseConfiguration.tuning.maximumMessagingBatchSize + ) + this@P2PMessagingClient.messagingExecutor = messagingExecutor + messagingExecutor.start() + registerBridgeControl(session, inboxes) enumerateBridges(session, inboxes) } @@ -253,6 +271,7 @@ class P2PMessagingClient(config: NodeConfiguration, val artemisMessage = client.session.createMessage(false) artemisMessage.writeBodyBufferBytes(controlPacket) client.producer.send(BRIDGE_CONTROL, artemisMessage) + client.session.commit() } private fun updateBridgesOnNetworkChange(change: NetworkMapCache.MapChange) { @@ -419,12 +438,7 @@ class P2PMessagingClient(config: NodeConfiguration, // processing a message but if so, it'll be parked waiting for us to count down the latch, so // the session itself is still around and we can still ack messages as a result. override fun acknowledge() { - messagingExecutor.fetchFrom { - state.locked { - artemisMessage.individualAcknowledge() - artemis.started!!.session.commit() - } - } + messagingExecutor!!.acknowledge(artemisMessage) } } deliverTo(msg, HandlerRegistration(msg.topic, deliverTo), acknowledgeHandle) @@ -476,6 +490,7 @@ class P2PMessagingClient(config: NodeConfiguration, shutdownLatch.await() } // Only first caller to gets running true to protect against double stop, which seems to happen in some integration tests. + messagingExecutor?.close() if (running) { state.locked { artemis.stop() @@ -483,74 +498,43 @@ class P2PMessagingClient(config: NodeConfiguration, } } - override fun send(message: Message, target: MessageRecipients, retryId: Long?, sequenceKey: Any, acknowledgementHandler: (() -> Unit)?) { - // We have to perform sending on a different thread pool, since using the same pool for messaging and - // fibers leads to Netty buffer memory leaks, caused by both Netty and Quasar fiddling with thread-locals. - messagingExecutor.fetchFrom { - state.locked { - val mqAddress = getMQAddress(target) - val artemis = artemis.started!! - val artemisMessage = artemis.session.createMessage(true).apply { - putStringProperty(cordaVendorProperty, cordaVendor) - putStringProperty(releaseVersionProperty, releaseVersion) - putIntProperty(platformVersionProperty, versionInfo.platformVersion) - putStringProperty(topicProperty, SimpleString(message.topic)) - writeBodyBufferBytes(message.data.bytes) - // Use the magic deduplication property built into Artemis as our message identity too - putStringProperty(HDR_DUPLICATE_DETECTION_ID, SimpleString(message.uniqueMessageId.toString)) - - // For demo purposes - if set then add a delay to messages in order to demonstrate that the flows are doing as intended - if (amqDelayMillis > 0 && message.topic == FlowMessagingImpl.sessionTopic) { - putLongProperty(HDR_SCHEDULED_DELIVERY_TIME, System.currentTimeMillis() + amqDelayMillis) - } - } - log.trace { - "Send to: $mqAddress topic: ${message.topic} " + - "sessionID: ${message.topic} id: ${message.uniqueMessageId}" - } - artemis.producer.send(mqAddress, artemisMessage) - retryId?.let { - database.transaction { - messagesToRedeliver.computeIfAbsent(it, { Pair(message, target) }) - } - scheduledMessageRedeliveries[it] = messagingExecutor.schedule({ - sendWithRetry(0, mqAddress, artemisMessage, it) - }, messageRedeliveryDelaySeconds, TimeUnit.SECONDS) - - } + @Suspendable + override fun send(message: Message, target: MessageRecipients, retryId: Long?, sequenceKey: Any) { + messagingExecutor!!.send(message, target) + retryId?.let { + database.transaction { + messagesToRedeliver.computeIfAbsent(it, { Pair(message, target) }) } + scheduledMessageRedeliveries[it] = nodeExecutor.schedule({ + sendWithRetry(0, message, target, retryId) + }, messageRedeliveryDelaySeconds, TimeUnit.SECONDS) + } - acknowledgementHandler?.invoke() } - override fun send(addressedMessages: List, acknowledgementHandler: (() -> Unit)?) { + @Suspendable + override fun send(addressedMessages: List) { for ((message, target, retryId, sequenceKey) in addressedMessages) { - send(message, target, retryId, sequenceKey, null) + send(message, target, retryId, sequenceKey) } - acknowledgementHandler?.invoke() } - private fun sendWithRetry(retryCount: Int, address: String, message: ClientMessage, retryId: Long) { - fun ClientMessage.randomiseDuplicateId() { - putStringProperty(HDR_DUPLICATE_DETECTION_ID, SimpleString(UUID.randomUUID().toString())) - } - + private fun sendWithRetry(retryCount: Int, message: Message, target: MessageRecipients, retryId: Long) { log.trace { "Attempting to retry #$retryCount message delivery for $retryId" } if (retryCount >= messageMaxRetryCount) { - log.warn("Reached the maximum number of retries ($messageMaxRetryCount) for message $message redelivery to $address") + log.warn("Reached the maximum number of retries ($messageMaxRetryCount) for message $message redelivery to $target") scheduledMessageRedeliveries.remove(retryId) return } - message.randomiseDuplicateId() - - state.locked { - log.trace { "Retry #$retryCount sending message $message to $address for $retryId" } - artemis.started!!.producer.send(address, message) + val messageWithRetryCount = object : Message by message { + override val uniqueMessageId = DeduplicationId("${message.uniqueMessageId.toString}-$retryCount") } - scheduledMessageRedeliveries[retryId] = messagingExecutor.schedule({ - sendWithRetry(retryCount + 1, address, message, retryId) + messagingExecutor!!.send(messageWithRetryCount, target) + + scheduledMessageRedeliveries[retryId] = nodeExecutor.schedule({ + sendWithRetry(retryCount + 1, message, target, retryId) }, messageRedeliveryDelaySeconds, TimeUnit.SECONDS) } @@ -565,14 +549,14 @@ class P2PMessagingClient(config: NodeConfiguration, } } - private fun getMQAddress(target: MessageRecipients): String { - return if (target == myAddress) { + override fun resolveTargetToArtemisQueue(address: MessageRecipients): String { + return if (address == myAddress) { // If we are sending to ourselves then route the message directly to our P2P queue. RemoteInboxAddress(myIdentity).queueName } else { // Otherwise we send the message to an internal queue for the target residing on our broker. It's then the // broker's job to route the message to the target's P2P queue. - val internalTargetQueue = (target as? ArtemisAddress)?.queueName ?: throw IllegalArgumentException("Not an Artemis address") + val internalTargetQueue = (address as? ArtemisAddress)?.queueName ?: throw IllegalArgumentException("Not an Artemis address") createQueueIfAbsent(internalTargetQueue) internalTargetQueue } @@ -581,20 +565,18 @@ class P2PMessagingClient(config: NodeConfiguration, /** Attempts to create a durable queue on the broker which is bound to an address of the same name. */ private fun createQueueIfAbsent(queueName: String) { if (!knownQueues.contains(queueName)) { - state.alreadyLocked { - val session = artemis.started!!.session - val queueQuery = session.queueQuery(SimpleString(queueName)) - if (!queueQuery.isExists) { - log.info("Create fresh queue $queueName bound on same address") - session.createQueue(queueName, RoutingType.ANYCAST, queueName, true) - if (queueName.startsWith(PEERS_PREFIX)) { - val keyHash = queueName.substring(PEERS_PREFIX.length) - val peers = networkMap.getNodesByOwningKeyIndex(keyHash) - for (node in peers) { - val bridge = BridgeEntry(queueName, node.addresses, node.legalIdentities.map { it.name }) - val createBridgeMessage = BridgeControl.Create(myIdentity.toStringShort(), bridge) - sendBridgeControl(createBridgeMessage) - } + val session = artemis.started!!.session + val queueQuery = session.queueQuery(SimpleString(queueName)) + if (!queueQuery.isExists) { + log.info("Create fresh queue $queueName bound on same address") + session.createQueue(queueName, RoutingType.ANYCAST, queueName, true) + if (queueName.startsWith(PEERS_PREFIX)) { + val keyHash = queueName.substring(PEERS_PREFIX.length) + val peers = networkMap.getNodesByOwningKeyIndex(keyHash) + for (node in peers) { + val bridge = BridgeEntry(queueName, node.addresses, node.legalIdentities.map { it.name }) + val createBridgeMessage = BridgeControl.Create(myIdentity.toStringShort(), bridge) + sendBridgeControl(createBridgeMessage) } } } diff --git a/node/src/main/kotlin/net/corda/node/services/messaging/RPCMessagingClient.kt b/node/src/main/kotlin/net/corda/node/services/messaging/RPCMessagingClient.kt index 6543cbb6fb..7fbd823671 100644 --- a/node/src/main/kotlin/net/corda/node/services/messaging/RPCMessagingClient.kt +++ b/node/src/main/kotlin/net/corda/node/services/messaging/RPCMessagingClient.kt @@ -11,14 +11,19 @@ import net.corda.nodeapi.internal.config.SSLConfiguration import net.corda.nodeapi.internal.crypto.X509Utilities import org.apache.activemq.artemis.api.core.management.ActiveMQServerControl -class RPCMessagingClient(private val config: SSLConfiguration, serverAddress: NetworkHostAndPort, maxMessageSize: Int) : SingletonSerializeAsToken(), AutoCloseable { +class RPCMessagingClient( + private val config: SSLConfiguration, + serverAddress: NetworkHostAndPort, + maxMessageSize: Int, + private val rpcServerConfiguration: RPCServerConfiguration = RPCServerConfiguration.default +) : SingletonSerializeAsToken(), AutoCloseable { private val artemis = ArtemisMessagingClient(config, serverAddress, maxMessageSize) private var rpcServer: RPCServer? = null fun start(rpcOps: RPCOps, securityManager: RPCSecurityManager) = synchronized(this) { val locator = artemis.start().sessionFactory.serverLocator val myCert = config.loadSslKeyStore().getCertificate(X509Utilities.CORDA_CLIENT_TLS) - rpcServer = RPCServer(rpcOps, NODE_USER, NODE_USER, locator, securityManager, CordaX500Name.build(myCert.subjectX500Principal)) + rpcServer = RPCServer(rpcOps, NODE_USER, NODE_USER, locator, securityManager, CordaX500Name.build(myCert.subjectX500Principal), rpcServerConfiguration) } fun start2(serverControl: ActiveMQServerControl) = synchronized(this) { diff --git a/node/src/main/kotlin/net/corda/node/services/messaging/RPCServer.kt b/node/src/main/kotlin/net/corda/node/services/messaging/RPCServer.kt index 0e06674aa0..a25dc5dd45 100644 --- a/node/src/main/kotlin/net/corda/node/services/messaging/RPCServer.kt +++ b/node/src/main/kotlin/net/corda/node/services/messaging/RPCServer.kt @@ -442,7 +442,7 @@ class ObservableContext( val artemisMessage = it.session.createMessage(false) serverToClient.writeToClientMessage(serializationContextWithObservableContext, artemisMessage) it.producer.send(clientAddress, artemisMessage) - log.debug("<- RPC <- $serverToClient") + log.debug { "<- RPC <- $serverToClient" } } } catch (throwable: Throwable) { log.error("Failed to send message, kicking client. Message was $serverToClient", throwable) diff --git a/node/src/main/kotlin/net/corda/node/services/persistence/AbstractPartyToX500NameAsStringConverter.kt b/node/src/main/kotlin/net/corda/node/services/persistence/AbstractPartyToX500NameAsStringConverter.kt index cf813a25d9..2617c4530c 100644 --- a/node/src/main/kotlin/net/corda/node/services/persistence/AbstractPartyToX500NameAsStringConverter.kt +++ b/node/src/main/kotlin/net/corda/node/services/persistence/AbstractPartyToX500NameAsStringConverter.kt @@ -21,7 +21,6 @@ class AbstractPartyToX500NameAsStringConverter(private val identityService: Iden if (party != null) { val partyName = identityService.wellKnownPartyFromAnonymous(party)?.toString() if (partyName != null) return partyName - log.warn("Identity service unable to resolve AbstractParty: $party") } return null // non resolvable anonymous parties } @@ -30,7 +29,6 @@ class AbstractPartyToX500NameAsStringConverter(private val identityService: Iden if (dbData != null) { val party = identityService.wellKnownPartyFromX500Name(CordaX500Name.parse(dbData)) if (party != null) return party - log.warn("Identity service unable to resolve X500name: $dbData") } return null // non resolvable anonymous parties are stored as nulls } diff --git a/node/src/main/kotlin/net/corda/node/services/persistence/DBTransactionMappingStorage.kt b/node/src/main/kotlin/net/corda/node/services/persistence/DBTransactionMappingStorage.kt index 3ed86c65af..bbf0735087 100644 --- a/node/src/main/kotlin/net/corda/node/services/persistence/DBTransactionMappingStorage.kt +++ b/node/src/main/kotlin/net/corda/node/services/persistence/DBTransactionMappingStorage.kt @@ -3,6 +3,7 @@ package net.corda.node.services.persistence import net.corda.core.internal.bufferUntilSubscribed import net.corda.core.crypto.SecureHash import net.corda.core.flows.StateMachineRunId +import net.corda.core.internal.ConcurrentBox import net.corda.core.messaging.DataFeed import net.corda.core.messaging.StateMachineTransactionMapping import net.corda.node.services.api.StateMachineRecordedTransactionMappingStorage @@ -51,16 +52,27 @@ class DBTransactionMappingStorage : StateMachineRecordedTransactionMappingStorag } } - val stateMachineTransactionMap = createMap() - val updates: PublishSubject = PublishSubject.create() - - override fun addMapping(stateMachineRunId: StateMachineRunId, transactionId: SecureHash) { - stateMachineTransactionMap[transactionId] = stateMachineRunId - updates.bufferUntilDatabaseCommit().onNext(StateMachineTransactionMapping(stateMachineRunId, transactionId)) + private class InnerState { + val stateMachineTransactionMap = createMap() + val updates: PublishSubject = PublishSubject.create() } - override fun track(): DataFeed, StateMachineTransactionMapping> = - DataFeed(stateMachineTransactionMap.allPersisted().map { StateMachineTransactionMapping(it.second, it.first) }.toList(), - updates.bufferUntilSubscribed().wrapWithDatabaseTransaction()) + private val concurrentBox = ConcurrentBox(InnerState()) + + override fun addMapping(stateMachineRunId: StateMachineRunId, transactionId: SecureHash) { + concurrentBox.concurrent { + stateMachineTransactionMap[transactionId] = stateMachineRunId + updates.bufferUntilDatabaseCommit().onNext(StateMachineTransactionMapping(stateMachineRunId, transactionId)) + } + } + + override fun track(): DataFeed, StateMachineTransactionMapping> { + return concurrentBox.exclusive { + DataFeed( + stateMachineTransactionMap.allPersisted().map { StateMachineTransactionMapping(it.second, it.first) }.toList(), + updates.bufferUntilSubscribed().wrapWithDatabaseTransaction() + ) + } + } } diff --git a/node/src/main/kotlin/net/corda/node/services/persistence/DBTransactionStorage.kt b/node/src/main/kotlin/net/corda/node/services/persistence/DBTransactionStorage.kt index d715cc70d5..3898880d31 100644 --- a/node/src/main/kotlin/net/corda/node/services/persistence/DBTransactionStorage.kt +++ b/node/src/main/kotlin/net/corda/node/services/persistence/DBTransactionStorage.kt @@ -3,7 +3,7 @@ package net.corda.node.services.persistence import net.corda.core.concurrent.CordaFuture import net.corda.core.crypto.SecureHash import net.corda.core.crypto.TransactionSignature -import net.corda.core.internal.ThreadBox +import net.corda.core.internal.ConcurrentBox import net.corda.core.internal.VisibleForTesting import net.corda.core.internal.bufferUntilSubscribed import net.corda.core.internal.concurrent.doneFuture @@ -79,10 +79,10 @@ class DBTransactionStorage(cacheSizeBytes: Long) : WritableTransactionStorage, S } } - private val txStorage = ThreadBox(createTransactionsMap(cacheSizeBytes)) + private val txStorage = ConcurrentBox(createTransactionsMap(cacheSizeBytes)) override fun addTransaction(transaction: SignedTransaction): Boolean = - txStorage.locked { + txStorage.concurrent { addWithDuplicatesAllowed(transaction.id, transaction.toTxCacheValue()).apply { updatesPublisher.bufferUntilDatabaseCommit().onNext(transaction) } @@ -94,13 +94,13 @@ class DBTransactionStorage(cacheSizeBytes: Long) : WritableTransactionStorage, S override val updates: Observable = updatesPublisher.wrapWithDatabaseTransaction() override fun track(): DataFeed, SignedTransaction> { - return txStorage.locked { + return txStorage.exclusive { DataFeed(allPersisted().map { it.second.toSignedTx() }.toList(), updatesPublisher.bufferUntilSubscribed().wrapWithDatabaseTransaction()) } } override fun trackTransaction(id: SecureHash): CordaFuture { - return txStorage.locked { + return txStorage.exclusive { val existingTransaction = get(id) if (existingTransaction == null) { updatesPublisher.filter { it.id == id }.toFuture() diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/ActionExecutorImpl.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/ActionExecutorImpl.kt index 4f7514da01..5bbb81c5f9 100644 --- a/node/src/main/kotlin/net/corda/node/services/statemachine/ActionExecutorImpl.kt +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/ActionExecutorImpl.kt @@ -122,7 +122,6 @@ class ActionExecutorImpl( val exception = error.flowException log.debug("Propagating error", exception) } - val pendingSendAcks = CountUpDownLatch(0) for (sessionState in action.sessions) { // We cannot propagate if the session isn't live. if (sessionState.initiatedState !is InitiatedSessionState.Live) { @@ -133,14 +132,9 @@ class ActionExecutorImpl( val sinkSessionId = sessionState.initiatedState.peerSinkSessionId val existingMessage = ExistingSessionMessage(sinkSessionId, errorMessage) val deduplicationId = DeduplicationId.createForError(errorMessage.errorId, sinkSessionId) - pendingSendAcks.countUp() - flowMessaging.sendSessionMessage(sessionState.peerParty, existingMessage, deduplicationId) { - pendingSendAcks.countDown() - } + flowMessaging.sendSessionMessage(sessionState.peerParty, existingMessage, deduplicationId) } } - // TODO we simply block here, perhaps this should be explicit in the worker state - pendingSendAcks.await() } @Suspendable @@ -163,12 +157,12 @@ class ActionExecutorImpl( @Suspendable private fun executeSendInitial(action: Action.SendInitial) { - flowMessaging.sendSessionMessage(action.party, action.initialise, action.deduplicationId, null) + flowMessaging.sendSessionMessage(action.party, action.initialise, action.deduplicationId) } @Suspendable private fun executeSendExisting(action: Action.SendExisting) { - flowMessaging.sendSessionMessage(action.peerParty, action.message, action.deduplicationId, null) + flowMessaging.sendSessionMessage(action.peerParty, action.message, action.deduplicationId) } @Suspendable diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/FlowMessaging.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/FlowMessaging.kt index 8adf000fca..9f149cc0ae 100644 --- a/node/src/main/kotlin/net/corda/node/services/statemachine/FlowMessaging.kt +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/FlowMessaging.kt @@ -1,5 +1,6 @@ package net.corda.node.services.statemachine +import co.paralleluniverse.fibers.Suspendable import com.esotericsoftware.kryo.KryoException import net.corda.core.flows.FlowException import net.corda.core.identity.Party @@ -20,7 +21,8 @@ interface FlowMessaging { * Send [message] to [party] using [deduplicationId]. Optionally [acknowledgementHandler] may be specified to * listen on the send acknowledgement. */ - fun sendSessionMessage(party: Party, message: SessionMessage, deduplicationId: DeduplicationId, acknowledgementHandler: (() -> Unit)?) + @Suspendable + fun sendSessionMessage(party: Party, message: SessionMessage, deduplicationId: DeduplicationId) /** * Start the messaging using the [onMessage] message handler. @@ -45,7 +47,8 @@ class FlowMessagingImpl(val serviceHub: ServiceHubInternal): FlowMessaging { } } - override fun sendSessionMessage(party: Party, message: SessionMessage, deduplicationId: DeduplicationId, acknowledgementHandler: (() -> Unit)?) { + @Suspendable + override fun sendSessionMessage(party: Party, message: SessionMessage, deduplicationId: DeduplicationId) { log.trace { "Sending message $deduplicationId $message to party $party" } val networkMessage = serviceHub.networkService.createMessage(sessionTopic, serializeSessionMessage(message).bytes, deduplicationId) val partyInfo = serviceHub.networkMapCache.getPartyInfo(party) ?: throw IllegalArgumentException("Don't know about $party") @@ -54,7 +57,7 @@ class FlowMessagingImpl(val serviceHub: ServiceHubInternal): FlowMessaging { is InitialSessionMessage -> message.initiatorSessionId is ExistingSessionMessage -> message.recipientSessionId } - serviceHub.networkService.send(networkMessage, address, sequenceKey = sequenceKey, acknowledgementHandler = acknowledgementHandler) + serviceHub.networkService.send(networkMessage, address, sequenceKey = sequenceKey) } private fun serializeSessionMessage(message: SessionMessage): SerializedBytes { diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/FlowStateMachineImpl.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/FlowStateMachineImpl.kt index 76b3afe2fc..7329590715 100644 --- a/node/src/main/kotlin/net/corda/node/services/statemachine/FlowStateMachineImpl.kt +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/FlowStateMachineImpl.kt @@ -7,7 +7,6 @@ import co.paralleluniverse.fibers.Suspendable import co.paralleluniverse.strands.Strand import co.paralleluniverse.strands.channels.Channel import com.codahale.metrics.Counter -import com.codahale.metrics.Metric import net.corda.core.concurrent.CordaFuture import net.corda.core.context.InvocationContext import net.corda.core.flows.* @@ -43,10 +42,7 @@ class TransientReference(@Transient val value: A) class FlowStateMachineImpl(override val id: StateMachineRunId, override val logic: FlowLogic, - scheduler: FiberScheduler, - private val totalSuccessMetric: Counter, - private val totalErrorMetric: Counter - // Store the Party rather than the full cert path with PartyAndCertificate + scheduler: FiberScheduler ) : Fiber(id.toString(), scheduler), FlowStateMachine, FlowFiber { companion object { /** @@ -55,18 +51,6 @@ class FlowStateMachineImpl(override val id: StateMachineRunId, fun currentStateMachine(): FlowStateMachineImpl<*>? = Strand.currentStrand() as? FlowStateMachineImpl<*> private val log: Logger = LoggerFactory.getLogger("net.corda.flow") - - @Suspendable - private fun abortFiber(): Nothing { - Fiber.park() - throw IllegalStateException("Ended fiber unparked") - } - - private fun extractThreadLocalTransaction(): TransientReference { - val transaction = contextTransaction - contextTransactionOrNull = null - return TransientReference(transaction) - } } override val serviceHub get() = getTransientField(TransientValues::serviceHub) @@ -90,6 +74,12 @@ class FlowStateMachineImpl(override val id: StateMachineRunId, return field.get(suppliedValues.value) } + private fun extractThreadLocalTransaction(): TransientReference { + val transaction = contextTransaction + contextTransactionOrNull = null + return TransientReference(transaction) + } + /** * Return the logger for this state machine. The logger name incorporates [id] and so including it in the log message * is not necessary. @@ -145,8 +135,7 @@ class FlowStateMachineImpl(override val id: StateMachineRunId, val startTime = System.nanoTime() val resultOrError = try { val result = logic.call() - // TODO expose maySkipCheckpoint here - suspend(FlowIORequest.WaitForSessionConfirmations, maySkipCheckpoint = false) + suspend(FlowIORequest.WaitForSessionConfirmations, maySkipCheckpoint = true) Try.Success(result) } catch (throwable: Throwable) { logger.warn("Flow threw exception", throwable) @@ -154,15 +143,13 @@ class FlowStateMachineImpl(override val id: StateMachineRunId, } val finalEvent = when (resultOrError) { is Try.Success -> { - totalSuccessMetric.inc() Event.FlowFinish(resultOrError.value) } is Try.Failure -> { - totalErrorMetric.inc() Event.Error(resultOrError.exception) } } - processEvent(getTransientField(TransientValues::transitionExecutor), finalEvent) + scheduleEvent(finalEvent) processEventsUntilFlowIsResumed() recordDuration(startTime) @@ -192,6 +179,13 @@ class FlowStateMachineImpl(override val id: StateMachineRunId, return resume.result as FlowSession } + @Suspendable + private fun abortFiber(): Nothing { + while (true) { + Fiber.park() + } + } + // TODO Dummy implementation of access to application specific permission controls and audit logging override fun checkFlowPermission(permissionName: String, extraAuditData: Map) { val permissionGranted = true // TODO define permission control service on ServiceHubInternal and actually check authorization. @@ -257,7 +251,7 @@ class FlowStateMachineImpl(override val id: StateMachineRunId, require(processEvent(transitionExecutor.value, event) == FlowContinuation.ProcessEvents) Fiber.unparkDeserialized(this, scheduler) } - return processEventsUntilFlowIsResumed() as R + return uncheckedCast(processEventsUntilFlowIsResumed()) } @Suspendable diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/MultiThreadedStateMachineManager.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/MultiThreadedStateMachineManager.kt new file mode 100644 index 0000000000..c62713511b --- /dev/null +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/MultiThreadedStateMachineManager.kt @@ -0,0 +1,668 @@ +package net.corda.node.services.statemachine + +import co.paralleluniverse.fibers.Fiber +import co.paralleluniverse.fibers.FiberExecutorScheduler +import co.paralleluniverse.fibers.Suspendable +import co.paralleluniverse.fibers.instrument.SuspendableHelper +import co.paralleluniverse.strands.channels.Channels +import com.codahale.metrics.Gauge +import net.corda.core.concurrent.CordaFuture +import net.corda.core.context.InvocationContext +import net.corda.core.flows.FlowException +import net.corda.core.flows.FlowInfo +import net.corda.core.flows.FlowLogic +import net.corda.core.flows.StateMachineRunId +import net.corda.core.identity.Party +import net.corda.core.internal.* +import net.corda.core.internal.concurrent.OpenFuture +import net.corda.core.internal.concurrent.map +import net.corda.core.internal.concurrent.openFuture +import net.corda.core.messaging.DataFeed +import net.corda.core.serialization.* +import net.corda.core.utilities.ProgressTracker +import net.corda.core.utilities.Try +import net.corda.core.utilities.contextLogger +import net.corda.core.utilities.debug +import net.corda.node.internal.InitiatedFlowFactory +import net.corda.node.services.api.CheckpointStorage +import net.corda.node.services.api.ServiceHubInternal +import net.corda.node.services.config.shouldCheckCheckpoints +import net.corda.node.services.messaging.AcknowledgeHandle +import net.corda.node.services.messaging.ReceivedMessage +import net.corda.node.services.statemachine.interceptors.* +import net.corda.node.services.statemachine.transitions.StateMachine +import net.corda.node.services.statemachine.transitions.StateMachineConfiguration +import net.corda.node.utilities.AffinityExecutor +import net.corda.nodeapi.internal.persistence.CordaPersistence +import net.corda.nodeapi.internal.serialization.SerializeAsTokenContextImpl +import net.corda.nodeapi.internal.serialization.withTokenContext +import org.apache.activemq.artemis.utils.ReusableLatch +import rx.Observable +import rx.subjects.PublishSubject +import java.security.SecureRandom +import java.util.* +import java.util.concurrent.ConcurrentHashMap +import java.util.concurrent.ExecutorService +import javax.annotation.concurrent.ThreadSafe +import kotlin.collections.ArrayList +import kotlin.streams.toList + +/** + * The StateMachineManagerImpl will always invoke the flow fibers on the given [AffinityExecutor], regardless of which + * thread actually starts them via [startFlow]. + */ +@ThreadSafe +class MultiThreadedStateMachineManager( + val serviceHub: ServiceHubInternal, + val checkpointStorage: CheckpointStorage, + val executor: ExecutorService, + val database: CordaPersistence, + val secureRandom: SecureRandom, + private val unfinishedFibers: ReusableLatch = ReusableLatch(), + private val classloader: ClassLoader = MultiThreadedStateMachineManager::class.java.classLoader +) : StateMachineManager, StateMachineManagerInternal { + companion object { + private val logger = contextLogger() + } + + private class Flow(val fiber: FlowStateMachineImpl<*>, val resultFuture: OpenFuture) + private enum class State { + UNSTARTED, + STARTED, + STOPPING, + STOPPED + } + + private val lifeCycle = LifeCycle(State.UNSTARTED) + private class InnerState { + val flows = ConcurrentHashMap() + val startedFutures = ConcurrentHashMap>() + val changesPublisher = PublishSubject.create()!! + } + + private val concurrentBox = ConcurrentBox(InnerState()) + + private val scheduler = FiberExecutorScheduler("Flow fiber scheduler", executor) + // How many Fibers are running and not suspended. If zero and stopping is true, then we are halted. + private val liveFibers = ReusableLatch() + // Monitoring support. + private val metrics = serviceHub.monitoringService.metrics + private val sessionToFlow = ConcurrentHashMap() + private val flowMessaging: FlowMessaging = FlowMessagingImpl(serviceHub) + private val fiberDeserializationChecker = if (serviceHub.configuration.shouldCheckCheckpoints()) FiberDeserializationChecker() else null + private val transitionExecutor = makeTransitionExecutor() + + private var checkpointSerializationContext: SerializationContext? = null + private var tokenizableServices: List? = null + private var actionExecutor: ActionExecutor? = null + + override val allStateMachines: List> + get() = concurrentBox.content.flows.values.map { it.fiber.logic } + + + private val totalStartedFlows = metrics.counter("Flows.Started") + private val totalFinishedFlows = metrics.counter("Flows.Finished") + private val totalSuccessFlows = metrics.counter("Flows.Success") + private val totalErrorFlows = metrics.counter("Flows.Error") + + /** + * An observable that emits triples of the changing flow, the type of change, and a process-specific ID number + * which may change across restarts. + * + * We use assignment here so that multiple subscribers share the same wrapped Observable. + */ + override val changes: Observable = concurrentBox.content.changesPublisher + + override fun start(tokenizableServices: List) { + checkQuasarJavaAgentPresence() + this.tokenizableServices = tokenizableServices + val checkpointSerializationContext = SerializationDefaults.CHECKPOINT_CONTEXT.withTokenContext( + SerializeAsTokenContextImpl(tokenizableServices, SerializationDefaults.SERIALIZATION_FACTORY, SerializationDefaults.CHECKPOINT_CONTEXT, serviceHub) + ) + this.checkpointSerializationContext = checkpointSerializationContext + this.actionExecutor = makeActionExecutor(checkpointSerializationContext) + fiberDeserializationChecker?.start(checkpointSerializationContext) + val fibers = restoreFlowsFromCheckpoints() + metrics.register("Flows.InFlight", Gauge { concurrentBox.content.flows.size }) + Fiber.setDefaultUncaughtExceptionHandler { fiber, throwable -> + (fiber as FlowStateMachineImpl<*>).logger.warn("Caught exception from flow", throwable) + } + serviceHub.networkMapCache.nodeReady.then { + resumeRestoredFlows(fibers) + flowMessaging.start { receivedMessage, acknowledgeHandle -> + lifeCycle.requireState(State.STARTED) { + onSessionMessage(receivedMessage, acknowledgeHandle) + } + } + } + lifeCycle.transition(State.UNSTARTED, State.STARTED) + } + + override fun > findStateMachines(flowClass: Class): List>> { + return concurrentBox.content.flows.values.mapNotNull { + flowClass.castIfPossible(it.fiber.logic)?.let { it to it.stateMachine.resultFuture } + } + } + + /** + * Start the shutdown process, bringing the [MultiThreadedStateMachineManager] to a controlled stop. When this method returns, + * all Fibers have been suspended and checkpointed, or have completed. + * + * @param allowedUnsuspendedFiberCount Optional parameter is used in some tests. + */ + override fun stop(allowedUnsuspendedFiberCount: Int) { + require(allowedUnsuspendedFiberCount >= 0) + lifeCycle.transition(State.STARTED, State.STOPPING) + for ((_, flow) in concurrentBox.content.flows) { + flow.fiber.scheduleEvent(Event.SoftShutdown) + } + // Account for any expected Fibers in a test scenario. + liveFibers.countDown(allowedUnsuspendedFiberCount) + liveFibers.await() + fiberDeserializationChecker?.let { + val foundUnrestorableFibers = it.stop() + check(!foundUnrestorableFibers) { "Unrestorable checkpoints were created, please check the logs for details." } + } + lifeCycle.transition(State.STOPPING, State.STOPPED) + } + + /** + * Atomic get snapshot + subscribe. This is needed so we don't miss updates between subscriptions to [changes] and + * calls to [allStateMachines] + */ + override fun track(): DataFeed>, StateMachineManager.Change> { + return concurrentBox.exclusive { + DataFeed(flows.values.map { it.fiber.logic }, changesPublisher.bufferUntilSubscribed()) + } + } + + override fun startFlow( + flowLogic: FlowLogic, + context: InvocationContext, + ourIdentity: Party? + ): CordaFuture> { + return lifeCycle.requireState(State.STARTED) { + startFlowInternal( + invocationContext = context, + flowLogic = flowLogic, + flowStart = FlowStart.Explicit, + ourIdentity = ourIdentity ?: getOurFirstIdentity(), + initialUnacknowledgedMessage = null, + isStartIdempotent = false + ) + } + } + + override fun killFlow(id: StateMachineRunId): Boolean { + concurrentBox.concurrent { + val flow = flows.remove(id) + if (flow != null) { + logger.debug("Killing flow known to physical node.") + decrementLiveFibers() + totalFinishedFlows.inc() + unfinishedFibers.countDown() + try { + flow.fiber.interrupt() + return true + } finally { + database.transaction { + checkpointStorage.removeCheckpoint(id) + } + } + } else { + // TODO replace with a clustered delete after we'll support clustered nodes + logger.debug("Unable to kill a flow unknown to physical node. Might be processed by another physical node.") + return false + } + } + } + + override fun addSessionBinding(flowId: StateMachineRunId, sessionId: SessionId) { + val previousFlowId = sessionToFlow.put(sessionId, flowId) + if (previousFlowId != null) { + if (previousFlowId == flowId) { + logger.warn("Session binding from $sessionId to $flowId re-added") + } else { + throw IllegalStateException( + "Attempted to add session binding from session $sessionId to flow $flowId, " + + "however there was already a binding to $previousFlowId" + ) + } + } + } + + override fun removeSessionBindings(sessionIds: Set) { + val reRemovedSessionIds = HashSet() + for (sessionId in sessionIds) { + val flowId = sessionToFlow.remove(sessionId) + if (flowId == null) { + reRemovedSessionIds.add(sessionId) + } + } + if (reRemovedSessionIds.isNotEmpty()) { + logger.warn("Session binding from $reRemovedSessionIds re-removed") + } + } + + override fun removeFlow(flowId: StateMachineRunId, removalReason: FlowRemovalReason, lastState: StateMachineState) { + concurrentBox.concurrent { + val flow = flows.remove(flowId) + if (flow != null) { + decrementLiveFibers() + totalFinishedFlows.inc() + unfinishedFibers.countDown() + return when (removalReason) { + is FlowRemovalReason.OrderlyFinish -> removeFlowOrderly(flow, removalReason, lastState) + is FlowRemovalReason.ErrorFinish -> removeFlowError(flow, removalReason, lastState) + FlowRemovalReason.SoftShutdown -> flow.fiber.scheduleEvent(Event.SoftShutdown) + } + } else { + logger.warn("Flow $flowId re-finished") + } + } + } + + override fun signalFlowHasStarted(flowId: StateMachineRunId) { + concurrentBox.concurrent { + startedFutures.remove(flowId)?.set(Unit) + } + } + + private fun checkQuasarJavaAgentPresence() { + check(SuspendableHelper.isJavaAgentActive(), { + """Missing the '-javaagent' JVM argument. Make sure you run the tests with the Quasar java agent attached to your JVM. + #See https://docs.corda.net/troubleshooting.html - 'Fiber classes not instrumented' for more details.""".trimMargin("#") + }) + } + + private fun decrementLiveFibers() { + liveFibers.countDown() + } + + private fun incrementLiveFibers() { + liveFibers.countUp() + } + + private fun restoreFlowsFromCheckpoints(): List { + return checkpointStorage.getAllCheckpoints().map { (id, serializedCheckpoint) -> + // If a flow is added before start() then don't attempt to restore it + if (concurrentBox.content.flows.containsKey(id)) return@map null + val checkpoint = deserializeCheckpoint(serializedCheckpoint) + if (checkpoint == null) return@map null + createFlowFromCheckpoint( + id = id, + checkpoint = checkpoint, + initialUnacknowledgedMessage = null, + isAnyCheckpointPersisted = true, + isStartIdempotent = false + ) + }.toList().filterNotNull() + } + + private fun resumeRestoredFlows(flows: List) { + for (flow in flows) { + addAndStartFlow(flow.fiber.id, flow) + } + } + + private fun onSessionMessage(message: ReceivedMessage, acknowledgeHandle: AcknowledgeHandle) { + val peer = message.peer + val sessionMessage = try { + message.data.deserialize() + } catch (ex: Exception) { + logger.error("Received corrupt SessionMessage data from $peer") + acknowledgeHandle.acknowledge() + return + } + val sender = serviceHub.networkMapCache.getPeerByLegalName(peer) + if (sender != null) { + when (sessionMessage) { + is ExistingSessionMessage -> onExistingSessionMessage(sessionMessage, acknowledgeHandle, sender) + is InitialSessionMessage -> onSessionInit(sessionMessage, message.platformVersion, acknowledgeHandle, sender) + } + } else { + logger.error("Unknown peer $peer in $sessionMessage") + } + } + + private fun onExistingSessionMessage(sessionMessage: ExistingSessionMessage, acknowledgeHandle: AcknowledgeHandle, sender: Party) { + try { + val recipientId = sessionMessage.recipientSessionId + val flowId = sessionToFlow[recipientId] + if (flowId == null) { + acknowledgeHandle.acknowledge() + if (sessionMessage.payload is EndSessionMessage) { + logger.debug { + "Got ${EndSessionMessage::class.java.simpleName} for " + + "unknown session $recipientId, discarding..." + } + } else { + throw IllegalArgumentException("Cannot find flow corresponding to session ID $recipientId") + } + } else { + val flow = concurrentBox.content.flows[flowId] ?: throw IllegalStateException("Cannot find fiber corresponding to ID $flowId") + flow.fiber.scheduleEvent(Event.DeliverSessionMessage(sessionMessage, acknowledgeHandle, sender)) + } + } catch (exception: Exception) { + logger.error("Exception while routing $sessionMessage", exception) + throw exception + } + } + + private fun onSessionInit(sessionMessage: InitialSessionMessage, senderPlatformVersion: Int, acknowledgeHandle: AcknowledgeHandle, sender: Party) { + fun createErrorMessage(initiatorSessionId: SessionId, message: String): ExistingSessionMessage { + val errorId = secureRandom.nextLong() + val payload = RejectSessionMessage(message, errorId) + return ExistingSessionMessage(initiatorSessionId, payload) + } + val replyError = try { + val initiatedFlowFactory = getInitiatedFlowFactory(sessionMessage) + val initiatedSessionId = SessionId.createRandom(secureRandom) + val senderSession = FlowSessionImpl(sender, initiatedSessionId) + val flowLogic = initiatedFlowFactory.createFlow(senderSession) + val initiatedFlowInfo = when (initiatedFlowFactory) { + is InitiatedFlowFactory.Core -> FlowInfo(serviceHub.myInfo.platformVersion, "corda") + is InitiatedFlowFactory.CorDapp -> FlowInfo(initiatedFlowFactory.flowVersion, initiatedFlowFactory.appName) + } + val senderCoreFlowVersion = when (initiatedFlowFactory) { + is InitiatedFlowFactory.Core -> senderPlatformVersion + is InitiatedFlowFactory.CorDapp -> null + } + startInitiatedFlow(flowLogic, acknowledgeHandle, senderSession, initiatedSessionId, sessionMessage, senderCoreFlowVersion, initiatedFlowInfo) + null + } catch (exception: Exception) { + logger.warn("Exception while creating initiated flow", exception) + createErrorMessage( + sessionMessage.initiatorSessionId, + (exception as? SessionRejectException)?.message ?: "Unable to establish session" + ) + } + + if (replyError != null) { + flowMessaging.sendSessionMessage(sender, replyError, DeduplicationId.createRandom(secureRandom)) + acknowledgeHandle.acknowledge() + } + } + + // TODO this is a temporary hack until we figure out multiple identities + private fun getOurFirstIdentity(): Party { + return serviceHub.myInfo.legalIdentities[0] + } + + private fun getInitiatedFlowFactory(message: InitialSessionMessage): InitiatedFlowFactory<*> { + val initiatingFlowClass = try { + Class.forName(message.initiatorFlowClassName, true, classloader).asSubclass(FlowLogic::class.java) + } catch (e: ClassNotFoundException) { + throw SessionRejectException("Don't know ${message.initiatorFlowClassName}") + } catch (e: ClassCastException) { + throw SessionRejectException("${message.initiatorFlowClassName} is not a flow") + } + return serviceHub.getFlowFactory(initiatingFlowClass) ?: + throw SessionRejectException("$initiatingFlowClass is not registered") + } + + private fun startInitiatedFlow( + flowLogic: FlowLogic, + triggeringUnacknowledgedMessage: AcknowledgeHandle, + peerSession: FlowSessionImpl, + initiatedSessionId: SessionId, + initiatingMessage: InitialSessionMessage, + senderCoreFlowVersion: Int?, + initiatedFlowInfo: FlowInfo + ) { + val flowStart = FlowStart.Initiated(peerSession, initiatedSessionId, initiatingMessage, senderCoreFlowVersion, initiatedFlowInfo) + val ourIdentity = getOurFirstIdentity() + startFlowInternal( + InvocationContext.peer(peerSession.counterparty.name), flowLogic, flowStart, ourIdentity, + triggeringUnacknowledgedMessage, + isStartIdempotent = false + ) + } + + private fun startFlowInternal( + invocationContext: InvocationContext, + flowLogic: FlowLogic, + flowStart: FlowStart, + ourIdentity: Party, + initialUnacknowledgedMessage: AcknowledgeHandle?, + isStartIdempotent: Boolean + ): CordaFuture> { + val flowId = StateMachineRunId.createRandom() + val deduplicationSeed = when (flowStart) { + FlowStart.Explicit -> flowId.uuid.toString() + is FlowStart.Initiated -> + "${flowStart.initiatingMessage.initiatorSessionId.toLong}-" + + "${flowStart.initiatingMessage.initiationEntropy}" + } + + // Before we construct the state machine state by freezing the FlowLogic we need to make sure that lazy properties + // have access to the fiber (and thereby the service hub) + val flowStateMachineImpl = FlowStateMachineImpl(flowId, flowLogic, scheduler) + val resultFuture = openFuture() + flowStateMachineImpl.transientValues = TransientReference(createTransientValues(flowId, resultFuture)) + flowLogic.stateMachine = flowStateMachineImpl + val frozenFlowLogic = (flowLogic as FlowLogic<*>).serialize(context = checkpointSerializationContext!!) + + val initialCheckpoint = Checkpoint.create(invocationContext, flowStart, flowLogic.javaClass, frozenFlowLogic, ourIdentity, deduplicationSeed).getOrThrow() + val startedFuture = openFuture() + val initialState = StateMachineState( + checkpoint = initialCheckpoint, + unacknowledgedMessages = initialUnacknowledgedMessage?.let { listOf(it) } ?: emptyList(), + isFlowResumed = false, + isTransactionTracked = false, + isAnyCheckpointPersisted = false, + isStartIdempotent = isStartIdempotent, + isRemoved = false, + flowLogic = flowLogic + ) + flowStateMachineImpl.transientState = TransientReference(initialState) + concurrentBox.concurrent { + startedFutures[flowId] = startedFuture + } + totalStartedFlows.inc() + addAndStartFlow(flowId, Flow(flowStateMachineImpl, resultFuture)) + return startedFuture.map { flowStateMachineImpl as FlowStateMachine } + } + + private fun deserializeCheckpoint(serializedCheckpoint: SerializedBytes): Checkpoint? { + return try { + serializedCheckpoint.deserialize(context = checkpointSerializationContext!!) + } catch (exception: Throwable) { + logger.error("Encountered unrestorable checkpoint!", exception) + null + } + } + + private fun verifyFlowLogicIsSuspendable(logic: FlowLogic) { + // Quasar requires (in Java 8) that at least the call method be annotated suspendable. Unfortunately, it's + // easy to forget to add this when creating a new flow, so we check here to give the user a better error. + // + // The Kotlin compiler can sometimes generate a synthetic bridge method from a single call declaration, which + // forwards to the void method and then returns Unit. However annotations do not get copied across to this + // bridge, so we have to do a more complex scan here. + val call = logic.javaClass.methods.first { !it.isSynthetic && it.name == "call" && it.parameterCount == 0 } + if (call.getAnnotation(Suspendable::class.java) == null) { + throw FlowException("${logic.javaClass.name}.call() is not annotated as @Suspendable. Please fix this.") + } + } + + private fun createTransientValues(id: StateMachineRunId, resultFuture: CordaFuture): FlowStateMachineImpl.TransientValues { + return FlowStateMachineImpl.TransientValues( + eventQueue = Channels.newChannel(16, Channels.OverflowPolicy.BLOCK), + resultFuture = resultFuture, + database = database, + transitionExecutor = transitionExecutor, + actionExecutor = actionExecutor!!, + stateMachine = StateMachine(id, StateMachineConfiguration.default, secureRandom), + serviceHub = serviceHub, + checkpointSerializationContext = checkpointSerializationContext!! + ) + } + + private fun createFlowFromCheckpoint( + id: StateMachineRunId, + checkpoint: Checkpoint, + isAnyCheckpointPersisted: Boolean, + isStartIdempotent: Boolean, + initialUnacknowledgedMessage: AcknowledgeHandle? + ): Flow { + val flowState = checkpoint.flowState + val resultFuture = openFuture() + val fiber = when (flowState) { + is FlowState.Unstarted -> { + val logic = flowState.frozenFlowLogic.deserialize(context = checkpointSerializationContext!!) + val state = StateMachineState( + checkpoint = checkpoint, + unacknowledgedMessages = initialUnacknowledgedMessage?.let { listOf(it) } ?: emptyList(), + isFlowResumed = false, + isTransactionTracked = false, + isAnyCheckpointPersisted = isAnyCheckpointPersisted, + isStartIdempotent = isStartIdempotent, + isRemoved = false, + flowLogic = logic + ) + val fiber = FlowStateMachineImpl(id, logic, scheduler) + fiber.transientValues = TransientReference(createTransientValues(id, resultFuture)) + fiber.transientState = TransientReference(state) + fiber.logic.stateMachine = fiber + fiber + } + is FlowState.Started -> { + val fiber = flowState.frozenFiber.deserialize(context = checkpointSerializationContext!!) + val state = StateMachineState( + checkpoint = checkpoint, + unacknowledgedMessages = initialUnacknowledgedMessage?.let { listOf(it) } ?: emptyList(), + isFlowResumed = false, + isTransactionTracked = false, + isAnyCheckpointPersisted = isAnyCheckpointPersisted, + isStartIdempotent = isStartIdempotent, + isRemoved = false, + flowLogic = fiber.logic + ) + fiber.transientValues = TransientReference(createTransientValues(id, resultFuture)) + fiber.transientState = TransientReference(state) + fiber.logic.stateMachine = fiber + fiber + } + } + + verifyFlowLogicIsSuspendable(fiber.logic) + + return Flow(fiber, resultFuture) + } + + private fun addAndStartFlow(id: StateMachineRunId, flow: Flow) { + val checkpoint = flow.fiber.snapshot().checkpoint + for (sessionId in getFlowSessionIds(checkpoint)) { + sessionToFlow.put(sessionId, id) + } + concurrentBox.concurrent { + incrementLiveFibers() + unfinishedFibers.countUp() + flows.put(id, flow) + flow.fiber.scheduleEvent(Event.DoRemainingWork) + when (checkpoint.flowState) { + is FlowState.Unstarted -> { + flow.fiber.start() + } + is FlowState.Started -> { + Fiber.unparkDeserialized(flow.fiber, scheduler) + } + } + changesPublisher.onNext(StateMachineManager.Change.Add(flow.fiber.logic)) + } + } + + private fun getFlowSessionIds(checkpoint: Checkpoint): Set { + val initiatedFlowStart = (checkpoint.flowState as? FlowState.Unstarted)?.flowStart as? FlowStart.Initiated + return if (initiatedFlowStart == null) { + checkpoint.sessions.keys + } else { + checkpoint.sessions.keys + initiatedFlowStart.initiatedSessionId + } + } + + private fun makeActionExecutor(checkpointSerializationContext: SerializationContext): ActionExecutor { + return ActionExecutorImpl( + serviceHub, + checkpointStorage, + flowMessaging, + this, + checkpointSerializationContext, + metrics + ) + } + + private fun makeTransitionExecutor(): TransitionExecutor { + val interceptors = ArrayList() + interceptors.add { HospitalisingInterceptor(PropagatingFlowHospital, it) } + if (serviceHub.configuration.devMode) { + interceptors.add { DumpHistoryOnErrorInterceptor(it) } + } + if (serviceHub.configuration.shouldCheckCheckpoints()) { + interceptors.add { FiberDeserializationCheckingInterceptor(fiberDeserializationChecker!!, it) } + } + if (logger.isDebugEnabled) { + interceptors.add { PrintingInterceptor(it) } + } + val transitionExecutor: TransitionExecutor = TransitionExecutorImpl(secureRandom, database) + return interceptors.fold(transitionExecutor) { executor, interceptor -> interceptor(executor) } + } + + private fun InnerState.removeFlowOrderly( + flow: Flow, + removalReason: FlowRemovalReason.OrderlyFinish, + lastState: StateMachineState + ) { + totalSuccessFlows.inc() + drainFlowEventQueue(flow) + // final sanity checks + require(lastState.unacknowledgedMessages.isEmpty()) + require(lastState.isRemoved) + require(lastState.checkpoint.subFlowStack.size == 1) + sessionToFlow.none { it.value == flow.fiber.id } + flow.resultFuture.set(removalReason.flowReturnValue) + lastState.flowLogic.progressTracker?.currentStep = ProgressTracker.DONE + changesPublisher.onNext(StateMachineManager.Change.Removed(lastState.flowLogic, Try.Success(removalReason.flowReturnValue))) + } + + private fun InnerState.removeFlowError( + flow: Flow, + removalReason: FlowRemovalReason.ErrorFinish, + lastState: StateMachineState + ) { + totalErrorFlows.inc() + drainFlowEventQueue(flow) + val flowError = removalReason.flowErrors[0] // TODO what to do with several? + val exception = flowError.exception + (exception as? FlowException)?.originalErrorId = flowError.errorId + flow.resultFuture.setException(exception) + lastState.flowLogic.progressTracker?.endWithError(exception) + changesPublisher.onNext(StateMachineManager.Change.Removed(lastState.flowLogic, Try.Failure(exception))) + } + + // The flow's event queue may be non-empty in case it shut down abruptly. We handle outstanding events here. + private fun drainFlowEventQueue(flow: Flow) { + while (true) { + val event = flow.fiber.transientValues!!.value.eventQueue.tryReceive() ?: return + when (event) { + is Event.DoRemainingWork -> {} + is Event.DeliverSessionMessage -> { + // Acknowledge the message so it doesn't leak in the broker. + event.acknowledgeHandle.acknowledge() + when (event.sessionMessage.payload) { + EndSessionMessage -> { + logger.debug { "Unhandled message ${event.sessionMessage} by ${flow.fiber} due to flow shutting down" } + } + else -> { + logger.warn("Unhandled message ${event.sessionMessage} by ${flow.fiber} due to flow shutting down") + } + } + } + else -> { + logger.warn("Unhandled event $event by ${flow.fiber} due to flow shutting down") + } + } + } + } +} diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/SessionRejectException.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/SessionRejectException.kt new file mode 100644 index 0000000000..8db2ea67a8 --- /dev/null +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/SessionRejectException.kt @@ -0,0 +1,8 @@ +package net.corda.node.services.statemachine + +import net.corda.core.CordaException + +/** + * An exception propagated and thrown in case a session initiation fails. + */ +class SessionRejectException(reason: String) : CordaException(reason) \ No newline at end of file diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/StateMachineManagerImpl.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/SingleThreadedStateMachineManager.kt similarity index 98% rename from node/src/main/kotlin/net/corda/node/services/statemachine/StateMachineManagerImpl.kt rename to node/src/main/kotlin/net/corda/node/services/statemachine/SingleThreadedStateMachineManager.kt index e1d5d7cc92..50fba1e293 100644 --- a/node/src/main/kotlin/net/corda/node/services/statemachine/StateMachineManagerImpl.kt +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/SingleThreadedStateMachineManager.kt @@ -6,7 +6,6 @@ import co.paralleluniverse.fibers.Suspendable import co.paralleluniverse.fibers.instrument.SuspendableHelper import co.paralleluniverse.strands.channels.Channels import com.codahale.metrics.Gauge -import net.corda.core.CordaException import net.corda.core.concurrent.CordaFuture import net.corda.core.context.InvocationContext import net.corda.core.flows.FlowException @@ -46,6 +45,7 @@ import rx.subjects.PublishSubject import java.security.SecureRandom import java.util.* import java.util.concurrent.ConcurrentHashMap +import java.util.concurrent.ExecutorService import javax.annotation.concurrent.ThreadSafe import kotlin.collections.ArrayList import kotlin.streams.toList @@ -55,14 +55,14 @@ import kotlin.streams.toList * thread actually starts them via [startFlow]. */ @ThreadSafe -class StateMachineManagerImpl( +class SingleThreadedStateMachineManager( val serviceHub: ServiceHubInternal, val checkpointStorage: CheckpointStorage, - val executor: AffinityExecutor, + val executor: ExecutorService, val database: CordaPersistence, val secureRandom: SecureRandom, private val unfinishedFibers: ReusableLatch = ReusableLatch(), - private val classloader: ClassLoader = StateMachineManagerImpl::class.java.classLoader + private val classloader: ClassLoader = SingleThreadedStateMachineManager::class.java.classLoader ) : StateMachineManager, StateMachineManagerInternal { companion object { private val logger = contextLogger() @@ -145,7 +145,7 @@ class StateMachineManagerImpl( } /** - * Start the shutdown process, bringing the [StateMachineManagerImpl] to a controlled stop. When this method returns, + * Start the shutdown process, bringing the [SingleThreadedStateMachineManager] to a controlled stop. When this method returns, * all Fibers have been suspended and checkpointed, or have completed. * * @param allowedUnsuspendedFiberCount Optional parameter is used in some tests. @@ -328,7 +328,6 @@ class StateMachineManagerImpl( private fun onExistingSessionMessage(sessionMessage: ExistingSessionMessage, acknowledgeHandle: AcknowledgeHandle, sender: Party) { try { - executor.checkOnThread() val recipientId = sessionMessage.recipientSessionId val flowId = sessionToFlow[recipientId] if (flowId == null) { @@ -381,7 +380,7 @@ class StateMachineManagerImpl( } if (replyError != null) { - flowMessaging.sendSessionMessage(sender, replyError, DeduplicationId.createRandom(secureRandom), null) + flowMessaging.sendSessionMessage(sender, replyError, DeduplicationId.createRandom(secureRandom)) acknowledgeHandle.acknowledge() } } @@ -439,7 +438,7 @@ class StateMachineManagerImpl( // Before we construct the state machine state by freezing the FlowLogic we need to make sure that lazy properties // have access to the fiber (and thereby the service hub) - val flowStateMachineImpl = FlowStateMachineImpl(flowId, flowLogic, scheduler, totalSuccessFlows, totalErrorFlows) + val flowStateMachineImpl = FlowStateMachineImpl(flowId, flowLogic, scheduler) val resultFuture = openFuture() flowStateMachineImpl.transientValues = TransientReference(createTransientValues(flowId, resultFuture)) flowLogic.stateMachine = flowStateMachineImpl @@ -523,7 +522,7 @@ class StateMachineManagerImpl( isRemoved = false, flowLogic = logic ) - val fiber = FlowStateMachineImpl(id, logic, scheduler, totalSuccessFlows, totalErrorFlows) + val fiber = FlowStateMachineImpl(id, logic, scheduler) fiber.transientValues = TransientReference(createTransientValues(id, resultFuture)) fiber.transientState = TransientReference(state) fiber.logic.stateMachine = fiber @@ -651,6 +650,7 @@ class StateMachineManagerImpl( while (true) { val event = flow.fiber.transientValues!!.value.eventQueue.tryReceive() ?: return when (event) { + is Event.DoRemainingWork -> {} is Event.DeliverSessionMessage -> { // Acknowledge the message so it doesn't leak in the broker. event.acknowledgeHandle.acknowledge() @@ -670,5 +670,3 @@ class StateMachineManagerImpl( } } } - -class SessionRejectException(reason: String) : CordaException(reason) diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/StateMachineState.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/StateMachineState.kt index d6449865be..ff41c37fb5 100644 --- a/node/src/main/kotlin/net/corda/node/services/statemachine/StateMachineState.kt +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/StateMachineState.kt @@ -169,7 +169,7 @@ sealed class FlowState { val flowStart: FlowStart, val frozenFlowLogic: SerializedBytes> ) : FlowState() { - override fun toString() = "Unstarted(flowStart=$flowStart, frozenFlowLogic=${frozenFlowLogic.hash}" + override fun toString() = "Unstarted(flowStart=$flowStart, frozenFlowLogic=${frozenFlowLogic.hash})" } /** @@ -182,7 +182,7 @@ sealed class FlowState { val flowIORequest: FlowIORequest<*>, val frozenFiber: SerializedBytes> ) : FlowState() { - override fun toString() = "Started(flowIORequest=$flowIORequest, frozenFiber=${frozenFiber.hash}" + override fun toString() = "Started(flowIORequest=$flowIORequest, frozenFiber=${frozenFiber.hash})" } } diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/TransitionExecutor.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/TransitionExecutor.kt index 127cd5a286..768768eef9 100644 --- a/node/src/main/kotlin/net/corda/node/services/statemachine/TransitionExecutor.kt +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/TransitionExecutor.kt @@ -20,6 +20,6 @@ interface TransitionExecutor { } /** - * An interceptor of a transition. These are currently explicitly hooked up in [StateMachineManagerImpl]. + * An interceptor of a transition. These are currently explicitly hooked up in [MultiThreadedStateMachineManager]. */ typealias TransitionInterceptor = (TransitionExecutor) -> TransitionExecutor diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/interceptors/DumpHistoryOnErrorInterceptor.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/interceptors/DumpHistoryOnErrorInterceptor.kt index 7c7ed6e209..c6d62d31e2 100644 --- a/node/src/main/kotlin/net/corda/node/services/statemachine/interceptors/DumpHistoryOnErrorInterceptor.kt +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/interceptors/DumpHistoryOnErrorInterceptor.kt @@ -35,13 +35,13 @@ class DumpHistoryOnErrorInterceptor(val delegate: TransitionExecutor) : Transiti } if (nextState.checkpoint.errorState is ErrorState.Errored) { - log.warn("Flow ${fiber.id} dirtied, dumping all transitions:\n${record!!.joinToString("\n")}") + log.warn("Flow ${fiber.id} errored, dumping all transitions:\n${record!!.joinToString("\n")}") for (error in nextState.checkpoint.errorState.errors) { log.warn("Flow ${fiber.id} error", error.exception) } } - if (transition.newState.isRemoved) { + if (nextState.isRemoved) { records.remove(fiber.id) } diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/interceptors/HospitalisingInterceptor.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/interceptors/HospitalisingInterceptor.kt index 8573f937e0..2143fed67e 100644 --- a/node/src/main/kotlin/net/corda/node/services/statemachine/interceptors/HospitalisingInterceptor.kt +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/interceptors/HospitalisingInterceptor.kt @@ -38,6 +38,9 @@ class HospitalisingInterceptor( } } } + if (nextState.isRemoved) { + hospitalisedFlows.remove(fiber.id) + } return Pair(continuation, nextState) } } \ No newline at end of file diff --git a/node/src/main/kotlin/net/corda/node/services/vault/NodeVaultService.kt b/node/src/main/kotlin/net/corda/node/services/vault/NodeVaultService.kt index a731f4841e..9207330e95 100644 --- a/node/src/main/kotlin/net/corda/node/services/vault/NodeVaultService.kt +++ b/node/src/main/kotlin/net/corda/node/services/vault/NodeVaultService.kt @@ -65,7 +65,7 @@ class NodeVaultService( val updatesPublisher: rx.Observer> get() = _updatesPublisher.bufferUntilDatabaseCommit().tee(_rawUpdatesPublisher) } - private val mutex = ThreadBox(InnerState()) + private val concurrentBox = ConcurrentBox(InnerState()) private fun recordUpdate(update: Vault.Update): Vault.Update { if (!update.isEmpty()) { @@ -103,10 +103,10 @@ class NodeVaultService( } override val rawUpdates: Observable> - get() = mutex.locked { _rawUpdatesPublisher } + get() = concurrentBox.content._rawUpdatesPublisher override val updates: Observable> - get() = mutex.locked { _updatesInDbTx } + get() = concurrentBox.content._updatesInDbTx override fun notifyAll(statesToRecord: StatesToRecord, txns: Iterable) { if (statesToRecord == StatesToRecord.NONE) @@ -205,7 +205,7 @@ class NodeVaultService( private fun processAndNotify(update: Vault.Update) { if (!update.isEmpty()) { recordUpdate(update) - mutex.locked { + concurrentBox.concurrent { // flowId required by SoftLockManager to perform auto-registration of soft locks for new states val uuid = (Strand.currentStrand() as? FlowStateMachineImpl<*>)?.id?.uuid val vaultUpdate = if (uuid != null) update.copy(flowId = uuid) else update @@ -387,7 +387,7 @@ class NodeVaultService( @Throws(VaultQueryException::class) override fun _queryBy(criteria: QueryCriteria, paging: PageSpecification, sorting: Sort, contractStateType: Class): Vault.Page { - log.info("Vault Query for contract type: $contractStateType, criteria: $criteria, pagination: $paging, sorting: $sorting") + log.debug {"Vault Query for contract type: $contractStateType, criteria: $criteria, pagination: $paging, sorting: $sorting" } // calculate total results where a page specification has been defined var totalStates = -1L if (!paging.isDefault) { @@ -468,7 +468,7 @@ class NodeVaultService( @Throws(VaultQueryException::class) override fun _trackBy(criteria: QueryCriteria, paging: PageSpecification, sorting: Sort, contractStateType: Class): DataFeed, Vault.Update> { - return mutex.locked { + return concurrentBox.exclusive { val snapshotResults = _queryBy(criteria, paging, sorting, contractStateType) val updates: Observable> = uncheckedCast(_updatesPublisher.bufferUntilSubscribed().filter { it.containsType(contractStateType, snapshotResults.stateTypes) }) DataFeed(snapshotResults, updates) diff --git a/node/src/main/resources/reference.conf b/node/src/main/resources/reference.conf index b5547910dd..646599c601 100644 --- a/node/src/main/resources/reference.conf +++ b/node/src/main/resources/reference.conf @@ -31,6 +31,14 @@ enterpriseConfiguration = { updateInterval = 20000 waitInterval = 40000 } + tuning = { + flowThreadPoolSize = 1 + rpcThreadPoolSize = 4 + maximumMessagingBatchSize = 256 + p2pConfirmationWindowSize = 1048576 + brokerConnectionTtlCheckIntervalMs = 20 + } + useMultiThreadedSMM = true } rpcSettings = { useSsl = false diff --git a/node/src/test/kotlin/net/corda/node/services/events/NodeSchedulerServiceTest.kt b/node/src/test/kotlin/net/corda/node/services/events/NodeSchedulerServiceTest.kt index a063bca8b6..81b579db57 100644 --- a/node/src/test/kotlin/net/corda/node/services/events/NodeSchedulerServiceTest.kt +++ b/node/src/test/kotlin/net/corda/node/services/events/NodeSchedulerServiceTest.kt @@ -1,10 +1,7 @@ package net.corda.node.services.events -import com.google.common.util.concurrent.MoreExecutors import com.nhaarman.mockito_kotlin.* import net.corda.core.contracts.* -import net.corda.core.crypto.generateKeyPair -import net.corda.core.crypto.newSecureRandom import net.corda.core.flows.FlowLogic import net.corda.core.flows.FlowLogicRef import net.corda.core.flows.FlowLogicRefFactory @@ -58,7 +55,6 @@ class NodeSchedulerServiceTest { database, flowStarter, stateLoader, - serverThread = MoreExecutors.directExecutor(), flowLogicRefFactory = flowLogicRefFactory, log = log, scheduledStates = mutableMapOf()).apply { start() } diff --git a/node/src/test/kotlin/net/corda/node/services/messaging/ArtemisMessagingTest.kt b/node/src/test/kotlin/net/corda/node/services/messaging/ArtemisMessagingTest.kt index 02dc19df6c..c3f9f73b89 100644 --- a/node/src/test/kotlin/net/corda/node/services/messaging/ArtemisMessagingTest.kt +++ b/node/src/test/kotlin/net/corda/node/services/messaging/ArtemisMessagingTest.kt @@ -1,17 +1,12 @@ package net.corda.node.services.messaging +import com.codahale.metrics.MetricRegistry import com.nhaarman.mockito_kotlin.doReturn import com.nhaarman.mockito_kotlin.whenever import net.corda.core.crypto.generateKeyPair -import net.corda.core.concurrent.CordaFuture -import com.codahale.metrics.MetricRegistry -import net.corda.core.crypto.generateKeyPair -import net.corda.core.internal.concurrent.openFuture import net.corda.core.utilities.NetworkHostAndPort import net.corda.node.internal.configureDatabase -import net.corda.node.services.config.CertChainPolicyConfig -import net.corda.node.services.config.NodeConfiguration -import net.corda.node.services.config.configureWithDevSSLCertificate +import net.corda.node.services.config.* import net.corda.node.services.network.NetworkMapCacheImpl import net.corda.node.services.network.PersistentNetworkMapCache import net.corda.node.services.transactions.PersistentUniquenessProvider @@ -73,6 +68,7 @@ class ArtemisMessagingTest { doReturn("").whenever(it).exportJMXto doReturn(emptyList()).whenever(it).certificateChainCheckPolicies doReturn(5).whenever(it).messageRedeliveryDelaySeconds + doReturn(EnterpriseConfiguration(MutualExclusionConfiguration(false, "", 20000, 40000))).whenever(it).enterpriseConfiguration } LogHelper.setLevel(PersistentUniquenessProvider::class) database = configureDatabase(makeTestDataSourceProperties(), DatabaseConfig(runMigration = true), rigorousMock()) @@ -176,6 +172,7 @@ class ArtemisMessagingTest { ServiceAffinityExecutor("ArtemisMessagingTests", 1), database, networkMapCache, + MetricRegistry(), maxMessageSize = maxMessageSize).apply { config.configureWithDevSSLCertificate() messagingClient = this diff --git a/node/src/test/kotlin/net/corda/node/services/statemachine/FlowFrameworkTests.kt b/node/src/test/kotlin/net/corda/node/services/statemachine/FlowFrameworkTests.kt index 26013e7c91..f6e924e542 100644 --- a/node/src/test/kotlin/net/corda/node/services/statemachine/FlowFrameworkTests.kt +++ b/node/src/test/kotlin/net/corda/node/services/statemachine/FlowFrameworkTests.kt @@ -72,10 +72,6 @@ class FlowFrameworkTests { private lateinit var alice: Party private lateinit var bob: Party - private fun StartedNode<*>.flushSmm() { - (this.smm as StateMachineManagerImpl).executor.flush() - } - @Before fun start() { mockNet = MockNetwork( @@ -165,7 +161,6 @@ class FlowFrameworkTests { aliceNode.registerFlowFactory(ReceiveFlow::class) { InitiatedSendFlow("Hello", it) } bobNode.services.startFlow(ReceiveFlow(alice).nonTerminating()) // Prepare checkpointed receive flow // Make sure the add() has finished initial processing. - bobNode.flushSmm() bobNode.internals.disableDBCloseOnStop() bobNode.dispose() // kill receiver val restoredFlow = bobNode.restartAndGetRestoredFlow() @@ -191,7 +186,6 @@ class FlowFrameworkTests { assertEquals(1, bobNode.checkpointStorage.checkpoints().size) } // Make sure the add() has finished initial processing. - bobNode.flushSmm() bobNode.internals.disableDBCloseOnStop() // Restart node and thus reload the checkpoint and resend the message with same UUID bobNode.dispose() @@ -204,7 +198,6 @@ class FlowFrameworkTests { val (firstAgain, fut1) = node2b.getSingleFlow() // Run the network which will also fire up the second flow. First message should get deduped. So message data stays in sync. mockNet.runNetwork() - node2b.flushSmm() fut1.getOrThrow() val receivedCount = receivedSessionMessages.count { it.isPayloadTransfer } @@ -575,7 +568,7 @@ class FlowFrameworkTests { @Test fun `customised client flow which has annotated @InitiatingFlow again`() { - assertThatExceptionOfType(ExecutionException::class.java).isThrownBy { + assertThatExceptionOfType(IllegalArgumentException::class.java).isThrownBy { aliceNode.services.startFlow(IncorrectCustomSendFlow("Hello", bob)).resultFuture }.withMessageContaining(InitiatingFlow::class.java.simpleName) } diff --git a/perftestcordapp/build.gradle b/perftestcordapp/build.gradle index 48142030f1..a1456642b9 100644 --- a/perftestcordapp/build.gradle +++ b/perftestcordapp/build.gradle @@ -43,7 +43,7 @@ dependencies { // TODO Remove this once we have app configs compile "com.typesafe:config:$typesafe_config_version" - testCompile project(':test-utils') + testCompile project(':node-driver') testCompile project(path: ':core', configuration: 'testArtifacts') testCompile "junit:junit:$junit_version" diff --git a/node/src/integration-test/kotlin/net/corda/node/NodePerformanceTests.kt b/perftestcordapp/src/integrationTest/kotlin/com/r3/corda/enterprise/perftestcordapp/NodePerformanceTests.kt similarity index 58% rename from node/src/integration-test/kotlin/net/corda/node/NodePerformanceTests.kt rename to perftestcordapp/src/integrationTest/kotlin/com/r3/corda/enterprise/perftestcordapp/NodePerformanceTests.kt index 58a7b7e8ff..aaf4379353 100644 --- a/node/src/integration-test/kotlin/net/corda/node/NodePerformanceTests.kt +++ b/perftestcordapp/src/integrationTest/kotlin/com/r3/corda/enterprise/perftestcordapp/NodePerformanceTests.kt @@ -1,22 +1,19 @@ -package net.corda.node +package com.r3.corda.enterprise.perftestcordapp import co.paralleluniverse.fibers.Suspendable import com.google.common.base.Stopwatch +import com.r3.corda.enterprise.perftestcordapp.flows.CashIssueAndPaymentNoSelection import net.corda.core.flows.FlowLogic import net.corda.core.flows.StartableByRPC -import net.corda.core.internal.concurrent.transpose import net.corda.core.messaging.startFlow import net.corda.core.utilities.OpaqueBytes import net.corda.core.utilities.getOrThrow import net.corda.core.utilities.minutes import net.corda.finance.DOLLARS +import net.corda.finance.flows.CashIssueAndPaymentFlow import net.corda.finance.flows.CashIssueFlow -import net.corda.finance.flows.CashPaymentFlow import net.corda.node.services.Permissions.Companion.startFlow -import net.corda.testing.core.ALICE_NAME -import net.corda.testing.core.DUMMY_BANK_A_NAME -import net.corda.testing.core.DUMMY_NOTARY_NAME -import net.corda.testing.core.TestIdentity +import net.corda.testing.core.* import net.corda.testing.driver.NodeHandle import net.corda.testing.driver.PortAllocation import net.corda.testing.driver.driver @@ -73,7 +70,7 @@ class NodePerformanceTests : IntegrationTest() { queueBound = 50 ) { val timing = Stopwatch.createStarted().apply { - connection.proxy.startFlow(::EmptyFlow).returnValue.getOrThrow() + connection.proxy.startFlow(NodePerformanceTests::EmptyFlow).returnValue.getOrThrow() }.stop().elapsed(TimeUnit.MICROSECONDS) timings.add(timing) } @@ -95,8 +92,14 @@ class NodePerformanceTests : IntegrationTest() { a as NodeHandle.InProcess val metricRegistry = startReporter((this as InternalDriverDSL).shutdownManager, a.node.services.monitoringService.metrics) a.rpcClientToNode().use("A", "A") { connection -> - startPublishingFixedRateInjector(metricRegistry, 1, 5.minutes, 2000L / TimeUnit.SECONDS) { - connection.proxy.startFlow(::EmptyFlow).returnValue.get() + startPublishingFixedRateInjector( + metricRegistry = metricRegistry, + parallelism = 16, + overallDuration = 5.minutes, + injectionRate = 2000L / TimeUnit.SECONDS, + workBound = 50 + ) { + connection.proxy.startFlow(NodePerformanceTests::EmptyFlow).returnValue } } } @@ -109,8 +112,14 @@ class NodePerformanceTests : IntegrationTest() { a as NodeHandle.InProcess val metricRegistry = startReporter((this as InternalDriverDSL).shutdownManager, a.node.services.monitoringService.metrics) a.rpcClientToNode().use("A", "A") { connection -> - startPublishingFixedRateInjector(metricRegistry, 1, 5.minutes, 2000L / TimeUnit.SECONDS) { - connection.proxy.startFlow(::CashIssueFlow, 1.DOLLARS, OpaqueBytes.of(0), ALICE).returnValue.get() + startPublishingFixedRateInjector( + metricRegistry = metricRegistry, + parallelism = 16, + overallDuration = 5.minutes, + injectionRate = 2000L / TimeUnit.SECONDS, + workBound = 50 + ) { + connection.proxy.startFlow(::CashIssueFlow, 1.DOLLARS, OpaqueBytes.of(0), ALICE).returnValue } } } @@ -118,24 +127,50 @@ class NodePerformanceTests : IntegrationTest() { @Test fun `self pay rate`() { - val user = User("A", "A", setOf(startFlow(), startFlow())) + val user = User("A", "A", setOf(startFlow())) driver( notarySpecs = listOf(NotarySpec(DUMMY_NOTARY_NAME, rpcUsers = listOf(user))), startNodesInProcess = true, - extraCordappPackagesToScan = listOf("net.corda.finance"), + extraCordappPackagesToScan = listOf("net.corda.finance", "com.r3.corda.enterprise.perftestcordapp"), portAllocation = PortAllocation.Incremental(20000) ) { val notary = defaultNotaryNode.getOrThrow() as NodeHandle.InProcess val metricRegistry = startReporter((this as InternalDriverDSL).shutdownManager, notary.node.services.monitoringService.metrics) notary.rpcClientToNode().use("A", "A") { connection -> - println("ISSUING") - val doneFutures = (1..100).toList().map { - connection.proxy.startFlow(::CashIssueFlow, 1.DOLLARS, OpaqueBytes.of(0), defaultNotaryIdentity).returnValue - }.toList() - doneFutures.transpose().get() - println("STARTING PAYMENT") - startPublishingFixedRateInjector(metricRegistry, 8, 5.minutes, 5L / TimeUnit.SECONDS) { - connection.proxy.startFlow(::CashPaymentFlow, 1.DOLLARS, defaultNotaryIdentity).returnValue.get() + startPublishingFixedRateInjector( + metricRegistry = metricRegistry, + parallelism = 64, + overallDuration = 5.minutes, + injectionRate = 300L / TimeUnit.SECONDS, + workBound = 50 + ) { + connection.proxy.startFlow(::CashIssueAndPaymentFlow, 1.DOLLARS, OpaqueBytes.of(0), defaultNotaryIdentity, false, defaultNotaryIdentity).returnValue + } + } + } + } + + @Test + fun `self pay rate without selection`() { + val user = User("A", "A", setOf(startFlow())) + driver( + notarySpecs = listOf(NotarySpec(DUMMY_NOTARY_NAME)), + startNodesInProcess = true, + portAllocation = PortAllocation.Incremental(20000) + ) { + val aliceFuture = startNode(providedName = ALICE_NAME, rpcUsers = listOf(user), startInSameProcess = true) + val alice = aliceFuture.getOrThrow() as NodeHandle.InProcess + defaultNotaryNode.getOrThrow() + val metricRegistry = startReporter((this as InternalDriverDSL).shutdownManager, alice.node.services.monitoringService.metrics) + alice.rpcClientToNode().use("A", "A") { connection -> + startPublishingFixedRateInjector( + metricRegistry = metricRegistry, + parallelism = 64, + overallDuration = 5.minutes, + injectionRate = 50L / TimeUnit.SECONDS, + workBound = 500 + ) { + connection.proxy.startFlow(::CashIssueAndPaymentNoSelection, 1.DOLLARS, OpaqueBytes.of(0), alice.nodeInfo.legalIdentities[0], false, defaultNotaryIdentity).returnValue } } } @@ -143,18 +178,19 @@ class NodePerformanceTests : IntegrationTest() { @Test fun `single pay`() { - val user = User("A", "A", setOf(startFlow(), startFlow())) + val user = User("A", "A", setOf(startFlow())) driver( - notarySpecs = listOf(NotarySpec(DUMMY_NOTARY_NAME, rpcUsers = listOf(user))), + notarySpecs = listOf(NotarySpec(DUMMY_NOTARY_NAME)), startNodesInProcess = true, - extraCordappPackagesToScan = listOf("net.corda.finance"), portAllocation = PortAllocation.Incremental(20000) ) { - val notary = defaultNotaryNode.getOrThrow() as NodeHandle.InProcess - val metricRegistry = startReporter((this as InternalDriverDSL).shutdownManager, notary.node.services.monitoringService.metrics) - notary.rpcClientToNode().use("A", "A") { connection -> - connection.proxy.startFlow(::CashIssueFlow, 1.DOLLARS, OpaqueBytes.of(0), defaultNotaryIdentity).returnValue.getOrThrow() - connection.proxy.startFlow(::CashPaymentFlow, 1.DOLLARS, defaultNotaryIdentity).returnValue.getOrThrow() + val aliceFuture = startNode(providedName = ALICE_NAME, rpcUsers = listOf(user)) + val bobFuture = startNode(providedName = BOB_NAME, rpcUsers = listOf(user)) + val alice = aliceFuture.getOrThrow() as NodeHandle.InProcess + val bob = bobFuture.getOrThrow() as NodeHandle.InProcess + defaultNotaryNode.getOrThrow() + alice.rpcClientToNode().use("A", "A") { connection -> + connection.proxy.startFlow(::CashIssueAndPaymentNoSelection, 1.DOLLARS, OpaqueBytes.of(0), bob.nodeInfo.legalIdentities[0], false, defaultNotaryIdentity).returnValue.getOrThrow() } } } diff --git a/perftestcordapp/src/main/kotlin/com/r3/corda/enterprise/perftestcordapp/flows/CashIssueAndPaymentNoSelection.kt b/perftestcordapp/src/main/kotlin/com/r3/corda/enterprise/perftestcordapp/flows/CashIssueAndPaymentNoSelection.kt index 1cc43611a1..92261509ef 100644 --- a/perftestcordapp/src/main/kotlin/com/r3/corda/enterprise/perftestcordapp/flows/CashIssueAndPaymentNoSelection.kt +++ b/perftestcordapp/src/main/kotlin/com/r3/corda/enterprise/perftestcordapp/flows/CashIssueAndPaymentNoSelection.kt @@ -41,6 +41,7 @@ class CashIssueAndPaymentNoSelection(val amount: Amount, fun deriveState(txState: TransactionState, amt: Amount>, owner: AbstractParty) = txState.copy(data = txState.data.copy(amount = amt, owner = owner)) + progressTracker.currentStep = GENERATING_TX val issueResult = subFlow(CashIssueFlow(amount, issueRef, notary)) val cashStateAndRef = issueResult.stx.tx.outRef(0) diff --git a/samples/irs-demo/cordapp/src/main/kotlin/net/corda/irs/flows/AutoOfferFlow.kt b/samples/irs-demo/cordapp/src/main/kotlin/net/corda/irs/flows/AutoOfferFlow.kt index 9da48655a9..229daa9c5c 100644 --- a/samples/irs-demo/cordapp/src/main/kotlin/net/corda/irs/flows/AutoOfferFlow.kt +++ b/samples/irs-demo/cordapp/src/main/kotlin/net/corda/irs/flows/AutoOfferFlow.kt @@ -71,8 +71,8 @@ object AutoOfferFlow { // and because in a real life app you'd probably have more complex logic here e.g. describing why the report // was filed, checking that the reportee is a regulated entity and not some random node from the wrong // country and so on. - val regulator = serviceHub.identityService.partiesFromName("Regulator", true).single() - subFlow(ReportToRegulatorFlow(regulator, finalTx)) + // val regulator = serviceHub.identityService.partiesFromName("Regulator", true).single() + // subFlow(ReportToRegulatorFlow(regulator, finalTx)) return finalTx } diff --git a/testing/node-driver/src/main/kotlin/net/corda/testing/node/InMemoryMessagingNetwork.kt b/testing/node-driver/src/main/kotlin/net/corda/testing/node/InMemoryMessagingNetwork.kt index 3859f4adc1..6c2ee03fe4 100644 --- a/testing/node-driver/src/main/kotlin/net/corda/testing/node/InMemoryMessagingNetwork.kt +++ b/testing/node-driver/src/main/kotlin/net/corda/testing/node/InMemoryMessagingNetwork.kt @@ -327,20 +327,18 @@ class InMemoryMessagingNetwork internal constructor( state.locked { check(handlers.remove(registration as Handler)) } } - override fun send(message: Message, target: MessageRecipients, retryId: Long?, sequenceKey: Any, acknowledgementHandler: (() -> Unit)?) { + override fun send(message: Message, target: MessageRecipients, retryId: Long?, sequenceKey: Any) { check(running) msgSend(this, message, target) - acknowledgementHandler?.invoke() if (!sendManuallyPumped) { pumpSend(false) } } - override fun send(addressedMessages: List, acknowledgementHandler: (() -> Unit)?) { + override fun send(addressedMessages: List) { for ((message, target, retryId, sequenceKey) in addressedMessages) { - send(message, target, retryId, sequenceKey, null) + send(message, target, retryId, sequenceKey) } - acknowledgementHandler?.invoke() } override fun stop() { diff --git a/testing/node-driver/src/main/kotlin/net/corda/testing/node/MockNode.kt b/testing/node-driver/src/main/kotlin/net/corda/testing/node/MockNode.kt index 43dfdf1735..0c776f0682 100644 --- a/testing/node-driver/src/main/kotlin/net/corda/testing/node/MockNode.kt +++ b/testing/node-driver/src/main/kotlin/net/corda/testing/node/MockNode.kt @@ -37,7 +37,6 @@ import net.corda.node.services.messaging.MessagingService import net.corda.node.services.transactions.BFTNonValidatingNotaryService import net.corda.node.services.transactions.BFTSMaRt import net.corda.node.services.transactions.InMemoryTransactionVerifierService -import net.corda.node.utilities.AffinityExecutor import net.corda.node.utilities.AffinityExecutor.ServiceAffinityExecutor import net.corda.nodeapi.internal.DevIdentityGenerator import net.corda.nodeapi.internal.config.User @@ -45,14 +44,14 @@ import net.corda.nodeapi.internal.network.NetworkParametersCopier import net.corda.nodeapi.internal.network.NotaryInfo import net.corda.nodeapi.internal.persistence.CordaPersistence import net.corda.nodeapi.internal.persistence.DatabaseConfig -import net.corda.testing.core.DUMMY_NOTARY_NAME import net.corda.testing.common.internal.testNetworkParameters +import net.corda.testing.core.DUMMY_NOTARY_NAME +import net.corda.testing.core.setGlobalSerialization import net.corda.testing.internal.rigorousMock import net.corda.testing.internal.testThreadFactory import net.corda.testing.node.MockServices.Companion.MOCK_VERSION_INFO import net.corda.testing.node.MockServices.Companion.makeTestDataSourceProperties import net.corda.testing.node.MockServices.Companion.makeTestDatabaseProperties -import net.corda.testing.core.setGlobalSerialization import org.apache.activemq.artemis.utils.ReusableLatch import org.apache.sshd.common.util.security.SecurityUtils import rx.internal.schedulers.CachedThreadScheduler @@ -270,7 +269,7 @@ open class MockNetwork(private val cordappPackages: List, private val entropyRoot = args.entropyRoot var counter = entropyRoot override val log get() = staticLog - override val serverThread: AffinityExecutor = + override val serverThread = if (mockNet.threadPerNode) { ServiceAffinityExecutor("Mock node $id thread", 1) } else { @@ -514,6 +513,9 @@ private fun mockNodeConfiguration(): NodeConfiguration { doReturn(5).whenever(it).messageRedeliveryDelaySeconds doReturn(5.seconds.toMillis()).whenever(it).additionalNodeInfoPollingFrequencyMsec doReturn(null).whenever(it).devModeOptions - doReturn(EnterpriseConfiguration(MutualExclusionConfiguration(false, "", 20000, 40000))).whenever(it).enterpriseConfiguration + doReturn(EnterpriseConfiguration( + mutualExclusionConfiguration = MutualExclusionConfiguration(false, "", 20000, 40000), + useMultiThreadedSMM = false + )).whenever(it).enterpriseConfiguration } } diff --git a/testing/node-driver/src/main/kotlin/net/corda/testing/node/internal/NodeBasedTest.kt b/testing/node-driver/src/main/kotlin/net/corda/testing/node/internal/NodeBasedTest.kt index 612f499510..f475694aeb 100644 --- a/testing/node-driver/src/main/kotlin/net/corda/testing/node/internal/NodeBasedTest.kt +++ b/testing/node-driver/src/main/kotlin/net/corda/testing/node/internal/NodeBasedTest.kt @@ -8,6 +8,7 @@ import net.corda.core.internal.div import net.corda.core.node.NodeInfo import net.corda.core.utilities.getOrThrow import net.corda.node.VersionInfo +import net.corda.node.internal.EnterpriseNode import net.corda.node.internal.Node import net.corda.node.internal.StartedNode import net.corda.node.internal.cordapp.CordappLoader @@ -128,7 +129,7 @@ abstract class NodeBasedTest(private val cordappPackages: List = emptyLi } class InProcessNode( - configuration: NodeConfiguration, versionInfo: VersionInfo, cordappPackages: List) : Node( + configuration: NodeConfiguration, versionInfo: VersionInfo, cordappPackages: List) : EnterpriseNode( configuration, versionInfo, false, CordappLoader.createDefaultWithTestPackages(configuration, cordappPackages)) { override fun getRxIoScheduler() = CachedThreadScheduler(testThreadFactory()).also { runOnStop += it::shutdown } } diff --git a/testing/node-driver/src/main/kotlin/net/corda/testing/node/internal/performance/Injectors.kt b/testing/node-driver/src/main/kotlin/net/corda/testing/node/internal/performance/Injectors.kt index 0e73bbf070..eafd572ecb 100644 --- a/testing/node-driver/src/main/kotlin/net/corda/testing/node/internal/performance/Injectors.kt +++ b/testing/node-driver/src/main/kotlin/net/corda/testing/node/internal/performance/Injectors.kt @@ -1,10 +1,14 @@ package net.corda.testing.node.internal.performance + import com.codahale.metrics.Gauge import com.codahale.metrics.MetricRegistry import com.google.common.base.Stopwatch +import net.corda.core.concurrent.CordaFuture +import net.corda.core.utilities.getOrThrow import net.corda.testing.internal.performance.Rate import net.corda.testing.node.internal.ShutdownManager +import org.slf4j.LoggerFactory import java.time.Duration import java.util.* import java.util.concurrent.CountDownLatch @@ -16,6 +20,7 @@ import java.util.concurrent.locks.ReentrantLock import kotlin.concurrent.thread import kotlin.concurrent.withLock +private val log = LoggerFactory.getLogger("TightLoopInjector") fun startTightLoopInjector( parallelism: Int, numberOfInjections: Int, @@ -34,7 +39,11 @@ fun startTightLoopInjector( while (true) { if (leftToSubmit.getAndDecrement() == 0) break executor.submit { - work() + try { + work() + } catch (exception: Exception) { + log.error("Error while executing injection", exception) + } if (queuedCount.decrementAndGet() < queueBound / 2) { lock.withLock { canQueueAgain.signal() @@ -60,11 +69,13 @@ fun startPublishingFixedRateInjector( parallelism: Int, overallDuration: Duration, injectionRate: Rate, + workBound: Int, queueSizeMetricName: String = "QueueSize", workDurationMetricName: String = "WorkDuration", - work: () -> Unit + work: () -> CordaFuture<*> ) { val workSemaphore = Semaphore(0) + val workBoundSemaphore = Semaphore(workBound) metricRegistry.register(queueSizeMetricName, Gauge { workSemaphore.availablePermits() }) val workDurationTimer = metricRegistry.timer(workDurationMetricName) ShutdownManager.run { @@ -72,19 +83,16 @@ fun startPublishingFixedRateInjector( registerShutdown { executor.shutdown() } val workExecutor = Executors.newFixedThreadPool(parallelism) registerShutdown { workExecutor.shutdown() } - val timings = Collections.synchronizedList(ArrayList()) for (i in 1..parallelism) { workExecutor.submit { try { while (true) { workSemaphore.acquire() + workBoundSemaphore.acquire() workDurationTimer.time { - timings.add( - Stopwatch.createStarted().apply { - work() - }.stop().elapsed(TimeUnit.MICROSECONDS) - ) + work().getOrThrow() } + workBoundSemaphore.release() } } catch (throwable: Throwable) { throwable.printStackTrace() @@ -105,4 +113,3 @@ fun startPublishingFixedRateInjector( Thread.sleep(overallDuration.toMillis()) } } -