Multi threaded state machine

This commit is contained in:
Andras Slemmer 2017-11-22 16:03:44 +00:00
parent 2edf632f7f
commit b71f0c49fb
50 changed files with 1515 additions and 375 deletions

View File

@ -2,6 +2,7 @@ package net.corda.client.rpc
import com.google.common.base.Stopwatch import com.google.common.base.Stopwatch
import net.corda.client.rpc.internal.RPCClientConfiguration import net.corda.client.rpc.internal.RPCClientConfiguration
import net.corda.core.internal.concurrent.doneFuture
import net.corda.core.messaging.RPCOps import net.corda.core.messaging.RPCOps
import net.corda.core.utilities.minutes import net.corda.core.utilities.minutes
import net.corda.core.utilities.seconds import net.corda.core.utilities.seconds
@ -144,10 +145,11 @@ class RPCPerformanceTests : AbstractRPCTest() {
parallelism = 8, parallelism = 8,
overallDuration = 5.minutes, overallDuration = 5.minutes,
injectionRate = 20000L / TimeUnit.SECONDS, injectionRate = 20000L / TimeUnit.SECONDS,
workBound = 50,
queueSizeMetricName = "$mode.QueueSize", queueSizeMetricName = "$mode.QueueSize",
workDurationMetricName = "$mode.WorkDuration", workDurationMetricName = "$mode.WorkDuration",
work = { work = {
proxy.ops.simpleReply(ByteArray(4096), 4096) doneFuture(proxy.ops.simpleReply(ByteArray(4096), 4096))
} }
) )
} }

View File

@ -0,0 +1,17 @@
package net.corda.core.internal
import java.util.concurrent.locks.ReentrantReadWriteLock
import kotlin.concurrent.read
import kotlin.concurrent.write
/**
* A [ConcurrentBox] allows the implementation of track() with reduced contention. [concurrent] may be run from several
* threads (which means it MUST be threadsafe!), while [exclusive] stops the world until the tracking has been set up.
* Internally [ConcurrentBox] is implemented simply as a read-write lock.
*/
class ConcurrentBox<out T>(val content: T) {
val lock = ReentrantReadWriteLock()
inline fun <R> concurrent(block: T.() -> R): R = lock.read { block(content) }
inline fun <R> exclusive(block: T.() -> R): R = lock.write { block(content) }
}

View File

@ -17,6 +17,7 @@ import net.corda.core.serialization.SerializeAsTokenContext
import net.corda.core.transactions.SignedTransaction import net.corda.core.transactions.SignedTransaction
import net.corda.core.utilities.NonEmptySet import net.corda.core.utilities.NonEmptySet
import net.corda.core.utilities.UntrustworthyData import net.corda.core.utilities.UntrustworthyData
import net.corda.core.utilities.debug
import net.corda.core.utilities.unwrap import net.corda.core.utilities.unwrap
import java.util.* import java.util.*
@ -72,7 +73,7 @@ sealed class FetchDataFlow<T : NamedByHash, in W : Any>(
return if (toFetch.isEmpty()) { return if (toFetch.isEmpty()) {
Result(fromDisk, emptyList()) Result(fromDisk, emptyList())
} else { } else {
logger.info("Requesting ${toFetch.size} dependency(s) for verification from ${otherSideSession.counterparty.name}") logger.debug { "Requesting ${toFetch.size} dependency(s) for verification from ${otherSideSession.counterparty.name}" }
// TODO: Support "large message" response streaming so response sizes are not limited by RAM. // TODO: Support "large message" response streaming so response sizes are not limited by RAM.
// We can then switch to requesting items in large batches to minimise the latency penalty. // We can then switch to requesting items in large batches to minimise the latency penalty.
@ -89,7 +90,7 @@ sealed class FetchDataFlow<T : NamedByHash, in W : Any>(
} }
// Check for a buggy/malicious peer answering with something that we didn't ask for. // Check for a buggy/malicious peer answering with something that we didn't ask for.
val downloaded = validateFetchResponse(UntrustworthyData(maybeItems), toFetch) val downloaded = validateFetchResponse(UntrustworthyData(maybeItems), toFetch)
logger.info("Fetched ${downloaded.size} elements from ${otherSideSession.counterparty.name}") logger.debug { "Fetched ${downloaded.size} elements from ${otherSideSession.counterparty.name}" }
maybeWriteToDisk(downloaded) maybeWriteToDisk(downloaded)
Result(fromDisk, downloaded) Result(fromDisk, downloaded)
} }

View File

@ -13,19 +13,38 @@ class LifeCycle<S : Enum<S>>(initial: S) {
private val lock = ReentrantReadWriteLock() private val lock = ReentrantReadWriteLock()
private var state = initial private var state = initial
/** Assert that the lifecycle in the [requiredState]. */ /**
fun requireState(requiredState: S) { * Assert that the lifecycle in the [requiredState]. Optionally runs [block], for the duration of which the
requireState({ "Required state to be $requiredState, was $it" }) { it == requiredState } * lifecycle is guaranteed to stay in [requiredState].
*/
fun <A> requireState(
requiredState: S,
block: () -> A
): A {
return requireState(
errorMessage = { "Required state to be $requiredState, was $it" },
predicate = { it == requiredState },
block = block
)
} }
fun requireState(requiredState: S) = requireState(requiredState) {}
/** Assert something about the current state atomically. */ /** Assert something about the current state atomically. */
fun <A> requireState(
errorMessage: (S) -> String,
predicate: (S) -> Boolean,
block: () -> A
): A {
return lock.readLock().withLock {
require(predicate(state)) { errorMessage(state) }
block()
}
}
fun requireState( fun requireState(
errorMessage: (S) -> String = { "Predicate failed on state $it" }, errorMessage: (S) -> String = { "Predicate failed on state $it" },
predicate: (S) -> Boolean predicate: (S) -> Boolean
) { ) {
lock.readLock().withLock { requireState(errorMessage, predicate) {}
require(predicate(state)) { errorMessage(state) }
}
} }
/** Transition the state from [from] to [to]. */ /** Transition the state from [from] to [to]. */

View File

@ -5,6 +5,7 @@ import net.corda.core.messaging.CordaRPCOps
import net.corda.core.messaging.startFlow import net.corda.core.messaging.startFlow
import net.corda.core.messaging.vaultTrackBy import net.corda.core.messaging.vaultTrackBy
import net.corda.core.node.services.Vault import net.corda.core.node.services.Vault
import net.corda.core.node.services.vault.QueryCriteria
import net.corda.core.utilities.OpaqueBytes import net.corda.core.utilities.OpaqueBytes
import net.corda.core.utilities.getOrThrow import net.corda.core.utilities.getOrThrow
import net.corda.finance.DOLLARS import net.corda.finance.DOLLARS
@ -63,14 +64,15 @@ class IntegrationTestingTutorial : IntegrationTest() {
// END 2 // END 2
// START 3 // START 3
val bobVaultUpdates = bobProxy.vaultTrackBy<Cash.State>().updates val bobVaultUpdates = bobProxy.vaultTrackBy<Cash.State>(criteria = QueryCriteria.VaultQueryCriteria(status = Vault.StateStatus.ALL)).updates
val aliceVaultUpdates = aliceProxy.vaultTrackBy<Cash.State>().updates val aliceVaultUpdates = aliceProxy.vaultTrackBy<Cash.State>(criteria = QueryCriteria.VaultQueryCriteria(status = Vault.StateStatus.ALL)).updates
// END 3 // END 3
// START 4 // START 4
val numberOfStates = 10
val issueRef = OpaqueBytes.of(0) val issueRef = OpaqueBytes.of(0)
val notaryParty = aliceProxy.notaryIdentities().first() val notaryParty = aliceProxy.notaryIdentities().first()
(1..10).map { i -> (1..numberOfStates).map { i ->
aliceProxy.startFlow(::CashIssueFlow, aliceProxy.startFlow(::CashIssueFlow,
i.DOLLARS, i.DOLLARS,
issueRef, issueRef,
@ -78,7 +80,7 @@ class IntegrationTestingTutorial : IntegrationTest() {
).returnValue ).returnValue
}.transpose().getOrThrow() }.transpose().getOrThrow()
// We wait for all of the issuances to run before we start making payments // We wait for all of the issuances to run before we start making payments
(1..10).map { i -> (1..numberOfStates).map { i ->
aliceProxy.startFlow(::CashPaymentFlow, aliceProxy.startFlow(::CashPaymentFlow,
i.DOLLARS, i.DOLLARS,
bob.nodeInfo.chooseIdentity(), bob.nodeInfo.chooseIdentity(),
@ -88,7 +90,7 @@ class IntegrationTestingTutorial : IntegrationTest() {
bobVaultUpdates.expectEvents { bobVaultUpdates.expectEvents {
parallel( parallel(
(1..10).map { i -> (1..numberOfStates).map { i ->
expect( expect(
match = { update: Vault.Update<Cash.State> -> match = { update: Vault.Update<Cash.State> ->
update.produced.first().state.data.amount.quantity == i * 100L update.produced.first().state.data.amount.quantity == i * 100L
@ -102,21 +104,44 @@ class IntegrationTestingTutorial : IntegrationTest() {
// END 4 // END 4
// START 5 // START 5
for (i in 1..10) { for (i in 1..numberOfStates) {
bobProxy.startFlow(::CashPaymentFlow, i.DOLLARS, alice.nodeInfo.chooseIdentity()).returnValue.getOrThrow() bobProxy.startFlow(::CashPaymentFlow, i.DOLLARS, alice.nodeInfo.chooseIdentity()).returnValue.getOrThrow()
} }
aliceVaultUpdates.expectEvents { aliceVaultUpdates.expectEvents {
sequence( sequence(
(1..10).map { i -> // issuance
expect { update: Vault.Update<Cash.State> -> parallel(
println("Alice got vault update of $update") (1..numberOfStates).map { i ->
assertEquals(update.produced.first().state.data.amount.quantity, i * 100L) expect(match = { it.moved() == -i * 100 }) { update: Vault.Update<Cash.State> ->
} assertEquals(0, update.consumed.size)
} }
}
),
// move to Bob
parallel(
(1..numberOfStates).map { i ->
expect(match = { it.moved() == i * 100 }) { update: Vault.Update<Cash.State> ->
}
}
),
// move back to Alice
sequence(
(1..numberOfStates).map { i ->
expect(match = { it.moved() == -i * 100 }) { update: Vault.Update<Cash.State> ->
assertEquals(update.consumed.size, 0)
}
}
)
) )
} }
// END 5 // END 5
} }
} }
fun Vault.Update<Cash.State>.moved(): Int {
val consumedSum = consumed.sumBy { it.state.data.amount.quantity.toInt() }
val producedSum = produced.sumBy { it.state.data.amount.quantity.toInt() }
return consumedSum - producedSum
}
} }

View File

@ -85,7 +85,7 @@ class TutorialMockNetwork {
// modify message if it's 1 // modify message if it's 1
nodeB.setMessagingServiceSpy(object : MessagingServiceSpy(nodeB.network) { nodeB.setMessagingServiceSpy(object : MessagingServiceSpy(nodeB.network) {
override fun send(message: Message, target: MessageRecipients, retryId: Long?, sequenceKey: Any, acknowledgementHandler: (() -> Unit)?) { override fun send(message: Message, target: MessageRecipients, retryId: Long?, sequenceKey: Any) {
val messageData = message.data.deserialize<Any>() as? ExistingSessionMessage val messageData = message.data.deserialize<Any>() as? ExistingSessionMessage
val payload = messageData?.payload val payload = messageData?.payload
if (payload is DataSessionMessage && payload.payload.deserialize() == 1) { if (payload is DataSessionMessage && payload.payload.deserialize() == 1) {

View File

@ -3,6 +3,9 @@ package net.corda.flowhook
import co.paralleluniverse.fibers.Fiber import co.paralleluniverse.fibers.Fiber
import net.corda.core.internal.uncheckedCast import net.corda.core.internal.uncheckedCast
import net.corda.core.utilities.loggerFor import net.corda.core.utilities.loggerFor
import net.corda.flowhook.FiberMonitor.correlator
import net.corda.flowhook.FiberMonitor.inspect
import net.corda.flowhook.FiberMonitor.newEvent
import net.corda.nodeapi.internal.persistence.DatabaseTransaction import net.corda.nodeapi.internal.persistence.DatabaseTransaction
import java.sql.Connection import java.sql.Connection
import java.time.Instant import java.time.Instant
@ -10,11 +13,12 @@ import java.util.concurrent.Executors
import java.util.concurrent.ScheduledExecutorService import java.util.concurrent.ScheduledExecutorService
import java.util.concurrent.TimeUnit import java.util.concurrent.TimeUnit
import java.util.concurrent.atomic.AtomicBoolean import java.util.concurrent.atomic.AtomicBoolean
import kotlin.concurrent.thread
data class MonitorEvent(val type: MonitorEventType, val keys: List<Any>, val extra: Any? = null) data class MonitorEvent(val type: MonitorEventType, val keys: List<Any>, val extra: Any? = null)
data class FullMonitorEvent(val timestamp: Instant, val trace: List<StackTraceElement>, val event: MonitorEvent) { data class FullMonitorEvent(val timestamp: Instant, val trace: List<StackTraceElement>, val event: MonitorEvent) {
override fun toString() = event.toString() override fun toString() = "$timestamp: ${event.type}"
} }
enum class MonitorEventType { enum class MonitorEventType {
@ -25,6 +29,8 @@ enum class MonitorEventType {
FiberStarted, FiberStarted,
FiberParking, FiberParking,
FiberParked,
FiberResuming,
FiberException, FiberException,
FiberResumed, FiberResumed,
FiberEnded, FiberEnded,
@ -37,7 +43,12 @@ enum class MonitorEventType {
SetThreadLocals, SetThreadLocals,
SetInheritableThreadLocals, SetInheritableThreadLocals,
GetThreadLocals, GetThreadLocals,
GetInheritableThreadLocals GetInheritableThreadLocals,
SendSessionMessage,
BrokerFlushStart,
BrokerFlushEnd
} }
/** /**
@ -58,11 +69,26 @@ object FiberMonitor {
private val started = AtomicBoolean(false) private val started = AtomicBoolean(false)
private var executor: ScheduledExecutorService? = null private var executor: ScheduledExecutorService? = null
val correlator = MonitorEventCorrelator() private val correlator = MonitorEventCorrelator()
private val eventsToDrop = setOf(
MonitorEventType.TransactionCreated,
MonitorEventType.ConnectionRequested,
MonitorEventType.ConnectionAcquired,
MonitorEventType.ConnectionReleased,
MonitorEventType.NettyThreadLocalMapCreated,
MonitorEventType.SetThreadLocals,
MonitorEventType.SetInheritableThreadLocals,
MonitorEventType.GetThreadLocals,
MonitorEventType.GetInheritableThreadLocals
)
fun newEvent(event: MonitorEvent) { fun newEvent(event: MonitorEvent) {
if (executor != null) { if (executor != null) {
val fullEvent = FullMonitorEvent(Instant.now(), Exception().stackTrace.toList(), event) val fullEvent = FullMonitorEvent(Instant.now(), Exception().stackTrace.toList(), event)
if (event.type in eventsToDrop) {
return
}
executor!!.execute { executor!!.execute {
processEvent(fullEvent) processEvent(fullEvent)
} }
@ -75,6 +101,12 @@ object FiberMonitor {
executor = Executors.newSingleThreadScheduledExecutor() executor = Executors.newSingleThreadScheduledExecutor()
executor!!.scheduleAtFixedRate(this::inspect, 100, 100, TimeUnit.MILLISECONDS) executor!!.scheduleAtFixedRate(this::inspect, 100, 100, TimeUnit.MILLISECONDS)
} }
thread {
while (true) {
Thread.sleep(1000)
this
}
}
} }
// Break on this function or [newEvent]. // Break on this function or [newEvent].
@ -174,58 +206,49 @@ class MonitorEventCorrelator {
fun getByType() = merged().entries.groupBy { it.key.javaClass } fun getByType() = merged().entries.groupBy { it.key.javaClass }
fun addEvent(fullMonitorEvent: FullMonitorEvent) { fun addEvent(fullMonitorEvent: FullMonitorEvent) {
events.add(fullMonitorEvent) synchronized(events) {
events.add(fullMonitorEvent)
}
} }
fun merged(): Map<Any, List<FullMonitorEvent>> { fun merged(): Map<Any, List<FullMonitorEvent>> {
val merged = HashMap<Any, ArrayList<FullMonitorEvent>>() val keyToEvents = HashMap<Any, HashSet<FullMonitorEvent>>()
for (event in events) {
val eventLists = HashSet<ArrayList<FullMonitorEvent>>() synchronized(events) {
for (key in event.event.keys) { for (event in events) {
val list = merged[key] for (key in event.event.keys) {
if (list != null) { keyToEvents.getOrPut(key) { HashSet() }.add(event)
eventLists.add(list)
} }
} }
val newList = when (eventLists.size) { }
0 -> ArrayList()
1 -> eventLists.first() val components = ArrayList<Set<Any>>()
else -> mergeAll(eventLists) val visited = HashSet<Any>()
for (root in keyToEvents.keys) {
if (root in visited) {
continue
} }
newList.add(event) val component = HashSet<Any>()
for (key in event.event.keys) { val toVisit = arrayListOf(root)
merged[key] = newList while (toVisit.isNotEmpty()) {
val current = toVisit.removeAt(toVisit.size - 1)
if (current in visited) {
continue
}
toVisit.addAll(keyToEvents[current]!!.flatMapTo(HashSet()) { it.event.keys })
component.add(current)
visited.add(current)
}
components.add(component)
}
val merged = HashMap<Any, List<FullMonitorEvent>>()
for (component in components) {
val eventList = component.flatMapTo(HashSet()) { keyToEvents[it]!! }.sortedBy { it.timestamp }
for (key in component) {
merged[key] = eventList
} }
} }
return merged return merged
} }
}
fun mergeAll(lists: Collection<List<FullMonitorEvent>>): ArrayList<FullMonitorEvent> {
return lists.fold(ArrayList()) { merged, next -> merge(merged, next) }
}
fun merge(a: List<FullMonitorEvent>, b: List<FullMonitorEvent>): ArrayList<FullMonitorEvent> {
val merged = ArrayList<FullMonitorEvent>()
var aIndex = 0
var bIndex = 0
while (true) {
if (aIndex >= a.size) {
merged.addAll(b.subList(bIndex, b.size))
return merged
}
if (bIndex >= b.size) {
merged.addAll(a.subList(aIndex, a.size))
return merged
}
val aElem = a[aIndex]
val bElem = b[bIndex]
if (aElem.timestamp < bElem.timestamp) {
merged.add(aElem)
aIndex++
} else {
merged.add(bElem)
bIndex++
}
}
}
}

View File

@ -1,19 +1,37 @@
package net.corda.flowhook package net.corda.flowhook
import co.paralleluniverse.fibers.Fiber import co.paralleluniverse.fibers.Fiber
import net.corda.core.internal.declaredField
import net.corda.node.services.statemachine.Event import net.corda.node.services.statemachine.Event
import net.corda.node.services.statemachine.ExistingSessionMessage
import net.corda.node.services.statemachine.InitialSessionMessage
import net.corda.node.services.statemachine.SessionMessage
import net.corda.nodeapi.internal.persistence.contextTransactionOrNull import net.corda.nodeapi.internal.persistence.contextTransactionOrNull
import org.apache.activemq.artemis.core.io.buffer.TimedBuffer
import java.sql.Connection import java.sql.Connection
import java.util.concurrent.TimeUnit
@Suppress("UNUSED") @Suppress("UNUSED")
object FlowHookContainer { object FlowHookContainer {
@JvmStatic @JvmStatic
@Hook("co.paralleluniverse.fibers.Fiber") @Hook("co.paralleluniverse.fibers.Fiber")
fun park() { fun park1(blocker: Any?, postParkAction: Any?, timeout: Long?, unit: TimeUnit?) {
FiberMonitor.newEvent(MonitorEvent(MonitorEventType.FiberParking, keys = listOf(Fiber.currentFiber()))) FiberMonitor.newEvent(MonitorEvent(MonitorEventType.FiberParking, keys = listOf(Fiber.currentFiber())))
} }
@JvmStatic
@Hook("co.paralleluniverse.fibers.Fiber", passThis = true)
fun exec(fiber: Any) {
FiberMonitor.newEvent(MonitorEvent(MonitorEventType.FiberResuming, keys = listOf(fiber)))
}
@JvmStatic
@Hook("co.paralleluniverse.fibers.Fiber", passThis = true)
fun onParked(fiber: Any) {
FiberMonitor.newEvent(MonitorEvent(MonitorEventType.FiberParked, keys = listOf(fiber)))
}
@JvmStatic @JvmStatic
@Hook("net.corda.node.services.statemachine.FlowStateMachineImpl") @Hook("net.corda.node.services.statemachine.FlowStateMachineImpl")
fun run() { fun run() {
@ -150,6 +168,36 @@ object FlowHookContainer {
})) }))
} }
@JvmStatic
@Hook("net.corda.node.services.statemachine.FlowMessagingImpl")
fun sendSessionMessage(party: Any, message: Any, deduplicationId: Any) {
message as SessionMessage
val sessionId = when (message) {
is InitialSessionMessage -> {
message.initiatorSessionId
}
is ExistingSessionMessage -> {
message.recipientSessionId
}
}
FiberMonitor.newEvent(MonitorEvent(MonitorEventType.SendSessionMessage, keys = listOf(currentFiberOrThread(), sessionId)))
}
@JvmStatic
@Hook("org.apache.activemq.artemis.core.io.buffer.TimedBuffer", passThis = true)
fun flush(buffer: Any, force: Boolean): () -> Unit {
buffer as TimedBuffer
val thread = Thread.currentThread()
FiberMonitor.newEvent(MonitorEvent(MonitorEventType.BrokerFlushStart, keys = listOf(thread), extra = object {
val force = force
val pendingSync = buffer.declaredField<Boolean>("pendingSync").value
}))
return {
FiberMonitor.newEvent(MonitorEvent(MonitorEventType.BrokerFlushEnd, keys = listOf(thread)))
}
}
private fun currentFiberOrThread(): Any { private fun currentFiberOrThread(): Any {
return Fiber.currentFiber() ?: Thread.currentThread() return Fiber.currentFiber() ?: Thread.currentThread()
} }

View File

@ -7,18 +7,22 @@ import net.corda.nodeapi.ArtemisTcpTransport
import net.corda.nodeapi.ConnectionDirection import net.corda.nodeapi.ConnectionDirection
import net.corda.nodeapi.internal.ArtemisMessagingComponent.Companion.NODE_USER import net.corda.nodeapi.internal.ArtemisMessagingComponent.Companion.NODE_USER
import net.corda.nodeapi.internal.config.SSLConfiguration import net.corda.nodeapi.internal.config.SSLConfiguration
import org.apache.activemq.artemis.api.core.client.ActiveMQClient import org.apache.activemq.artemis.api.core.client.*
import org.apache.activemq.artemis.api.core.client.ActiveMQClient.DEFAULT_ACK_BATCH_SIZE import org.apache.activemq.artemis.api.core.client.ActiveMQClient.DEFAULT_ACK_BATCH_SIZE
import org.apache.activemq.artemis.api.core.client.ClientProducer
import org.apache.activemq.artemis.api.core.client.ClientSession
import org.apache.activemq.artemis.api.core.client.ClientSessionFactory
class ArtemisMessagingClient(private val config: SSLConfiguration, private val serverAddress: NetworkHostAndPort, private val maxMessageSize: Int) { class ArtemisMessagingClient(
private val config: SSLConfiguration,
private val serverAddress: NetworkHostAndPort,
private val maxMessageSize: Int,
private val autoCommitSends: Boolean = true,
private val autoCommitAcks: Boolean = true,
private val confirmationWindowSize: Int = -1
) {
companion object { companion object {
private val log = loggerFor<ArtemisMessagingClient>() private val log = loggerFor<ArtemisMessagingClient>()
} }
class Started(val sessionFactory: ClientSessionFactory, val session: ClientSession, val producer: ClientProducer) class Started(val serverLocator: ServerLocator, val sessionFactory: ClientSessionFactory, val session: ClientSession, val producer: ClientProducer)
var started: Started? = null var started: Started? = null
private set private set
@ -35,17 +39,18 @@ class ArtemisMessagingClient(private val config: SSLConfiguration, private val s
clientFailureCheckPeriod = -1 clientFailureCheckPeriod = -1
minLargeMessageSize = maxMessageSize minLargeMessageSize = maxMessageSize
isUseGlobalPools = nodeSerializationEnv != null isUseGlobalPools = nodeSerializationEnv != null
confirmationWindowSize = this@ArtemisMessagingClient.confirmationWindowSize
} }
val sessionFactory = locator.createSessionFactory() val sessionFactory = locator.createSessionFactory()
// Login using the node username. The broker will authenticate us as its node (as opposed to another peer) // Login using the node username. The broker will authenticate us as its node (as opposed to another peer)
// using our TLS certificate. // using our TLS certificate.
// Note that the acknowledgement of messages is not flushed to the Artermis journal until the default buffer // Note that the acknowledgement of messages is not flushed to the Artermis journal until the default buffer
// size of 1MB is acknowledged. // size of 1MB is acknowledged.
val session = sessionFactory!!.createSession(NODE_USER, NODE_USER, false, true, true, locator.isPreAcknowledge, DEFAULT_ACK_BATCH_SIZE) val session = sessionFactory!!.createSession(NODE_USER, NODE_USER, false, autoCommitSends, autoCommitAcks, locator.isPreAcknowledge, DEFAULT_ACK_BATCH_SIZE)
session.start() session.start()
// Create a general purpose producer. // Create a general purpose producer.
val producer = session.createProducer() val producer = session.createProducer()
return Started(sessionFactory, session, producer).also { started = it } return Started(locator, sessionFactory, session, producer).also { started = it }
} }
fun stop() = synchronized(this) { fun stop() = synchronized(this) {
@ -55,6 +60,7 @@ class ArtemisMessagingClient(private val config: SSLConfiguration, private val s
session.commit() session.commit()
// Closing the factory closes all the sessions it produced as well. // Closing the factory closes all the sessions it produced as well.
sessionFactory.close() sessionFactory.close()
serverLocator.close()
} }
started = null started = null
} }

View File

@ -5,9 +5,7 @@ import com.nhaarman.mockito_kotlin.whenever
import net.corda.core.crypto.toStringShort import net.corda.core.crypto.toStringShort
import net.corda.core.internal.div import net.corda.core.internal.div
import net.corda.core.utilities.NetworkHostAndPort import net.corda.core.utilities.NetworkHostAndPort
import net.corda.node.services.config.CertChainPolicyConfig import net.corda.node.services.config.*
import net.corda.node.services.config.NodeConfiguration
import net.corda.node.services.config.configureWithDevSSLCertificate
import net.corda.node.services.messaging.ArtemisMessagingServer import net.corda.node.services.messaging.ArtemisMessagingServer
import net.corda.nodeapi.internal.ArtemisMessagingClient import net.corda.nodeapi.internal.ArtemisMessagingClient
import net.corda.nodeapi.internal.ArtemisMessagingComponent import net.corda.nodeapi.internal.ArtemisMessagingComponent
@ -143,6 +141,7 @@ class AMQPBridgeTest {
doReturn(artemisAddress).whenever(it).p2pAddress doReturn(artemisAddress).whenever(it).p2pAddress
doReturn("").whenever(it).exportJMXto doReturn("").whenever(it).exportJMXto
doReturn(emptyList<CertChainPolicyConfig>()).whenever(it).certificateChainCheckPolicies doReturn(emptyList<CertChainPolicyConfig>()).whenever(it).certificateChainCheckPolicies
doReturn(EnterpriseConfiguration(MutualExclusionConfiguration(false, "", 20000, 40000))).whenever(it).enterpriseConfiguration
} }
artemisConfig.configureWithDevSSLCertificate() artemisConfig.configureWithDevSSLCertificate()
val artemisServer = ArtemisMessagingServer(artemisConfig, artemisPort, MAX_MESSAGE_SIZE) val artemisServer = ArtemisMessagingServer(artemisConfig, artemisPort, MAX_MESSAGE_SIZE)

View File

@ -8,9 +8,7 @@ import net.corda.core.identity.CordaX500Name
import net.corda.core.internal.div import net.corda.core.internal.div
import net.corda.core.toFuture import net.corda.core.toFuture
import net.corda.core.utilities.NetworkHostAndPort import net.corda.core.utilities.NetworkHostAndPort
import net.corda.node.services.config.CertChainPolicyConfig import net.corda.node.services.config.*
import net.corda.node.services.config.NodeConfiguration
import net.corda.node.services.config.configureWithDevSSLCertificate
import net.corda.node.services.messaging.ArtemisMessagingServer import net.corda.node.services.messaging.ArtemisMessagingServer
import net.corda.nodeapi.internal.ArtemisMessagingClient import net.corda.nodeapi.internal.ArtemisMessagingClient
import net.corda.nodeapi.internal.ArtemisMessagingComponent.Companion.P2P_PREFIX import net.corda.nodeapi.internal.ArtemisMessagingComponent.Companion.P2P_PREFIX
@ -224,6 +222,7 @@ class ProtonWrapperTests {
doReturn(NetworkHostAndPort("0.0.0.0", artemisPort)).whenever(it).p2pAddress doReturn(NetworkHostAndPort("0.0.0.0", artemisPort)).whenever(it).p2pAddress
doReturn("").whenever(it).exportJMXto doReturn("").whenever(it).exportJMXto
doReturn(emptyList<CertChainPolicyConfig>()).whenever(it).certificateChainCheckPolicies doReturn(emptyList<CertChainPolicyConfig>()).whenever(it).certificateChainCheckPolicies
doReturn(EnterpriseConfiguration(MutualExclusionConfiguration(false, "", 20000, 40000))).whenever(it).enterpriseConfiguration
} }
artemisConfig.configureWithDevSSLCertificate() artemisConfig.configureWithDevSSLCertificate()

View File

@ -2,7 +2,7 @@ package net.corda.node.internal
import com.codahale.metrics.MetricRegistry import com.codahale.metrics.MetricRegistry
import com.google.common.collect.MutableClassToInstanceMap import com.google.common.collect.MutableClassToInstanceMap
import com.google.common.util.concurrent.MoreExecutors import com.google.common.util.concurrent.ThreadFactoryBuilder
import com.zaxxer.hikari.HikariConfig import com.zaxxer.hikari.HikariConfig
import com.zaxxer.hikari.HikariDataSource import com.zaxxer.hikari.HikariDataSource
import net.corda.confidential.SwapIdentitiesFlow import net.corda.confidential.SwapIdentitiesFlow
@ -90,7 +90,7 @@ import java.time.Duration
import java.util.* import java.util.*
import java.util.concurrent.ConcurrentHashMap import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.ExecutorService import java.util.concurrent.ExecutorService
import java.util.concurrent.TimeUnit.SECONDS import java.util.concurrent.Executors
import kotlin.collections.set import kotlin.collections.set
import kotlin.reflect.KClass import kotlin.reflect.KClass
import net.corda.core.crypto.generateKeyPair as cryptoGenerateKeyPair import net.corda.core.crypto.generateKeyPair as cryptoGenerateKeyPair
@ -110,7 +110,7 @@ abstract class AbstractNode(val configuration: NodeConfiguration,
val platformClock: CordaClock, val platformClock: CordaClock,
protected val versionInfo: VersionInfo, protected val versionInfo: VersionInfo,
protected val cordappLoader: CordappLoader, protected val cordappLoader: CordappLoader,
private val busyNodeLatch: ReusableLatch = ReusableLatch()) : SingletonSerializeAsToken() { protected val busyNodeLatch: ReusableLatch = ReusableLatch()) : SingletonSerializeAsToken() {
private class StartedNodeImpl<out N : AbstractNode>( private class StartedNodeImpl<out N : AbstractNode>(
override val internals: N, override val internals: N,
@ -131,7 +131,7 @@ abstract class AbstractNode(val configuration: NodeConfiguration,
// We will run as much stuff in this single thread as possible to keep the risk of thread safety bugs low during the // We will run as much stuff in this single thread as possible to keep the risk of thread safety bugs low during the
// low-performance prototyping period. // low-performance prototyping period.
protected abstract val serverThread: AffinityExecutor protected abstract val serverThread: AffinityExecutor.ServiceAffinityExecutor
protected lateinit var networkParameters: NetworkParameters protected lateinit var networkParameters: NetworkParameters
private val cordappServices = MutableClassToInstanceMap.create<SerializeAsToken>() private val cordappServices = MutableClassToInstanceMap.create<SerializeAsToken>()
@ -140,7 +140,7 @@ abstract class AbstractNode(val configuration: NodeConfiguration,
protected val services: ServiceHubInternal get() = _services protected val services: ServiceHubInternal get() = _services
private lateinit var _services: ServiceHubInternalImpl private lateinit var _services: ServiceHubInternalImpl
protected var myNotaryIdentity: PartyAndCertificate? = null protected var myNotaryIdentity: PartyAndCertificate? = null
private lateinit var checkpointStorage: CheckpointStorage protected lateinit var checkpointStorage: CheckpointStorage
private lateinit var tokenizableServices: List<Any> private lateinit var tokenizableServices: List<Any>
protected lateinit var attachments: NodeAttachmentService protected lateinit var attachments: NodeAttachmentService
protected lateinit var network: MessagingService protected lateinit var network: MessagingService
@ -229,23 +229,15 @@ abstract class AbstractNode(val configuration: NodeConfiguration,
val notaryService = makeNotaryService(nodeServices, database) val notaryService = makeNotaryService(nodeServices, database)
val smm = makeStateMachineManager(database) val smm = makeStateMachineManager(database)
val flowLogicRefFactory = FlowLogicRefFactoryImpl(cordappLoader.appClassLoader) val flowLogicRefFactory = FlowLogicRefFactoryImpl(cordappLoader.appClassLoader)
val flowStarter = FlowStarterImpl(serverThread, smm, flowLogicRefFactory) val flowStarter = FlowStarterImpl(smm, flowLogicRefFactory)
val schedulerService = NodeSchedulerService( val schedulerService = NodeSchedulerService(
platformClock, platformClock,
database, database,
flowStarter, flowStarter,
transactionStorage, transactionStorage,
unfinishedSchedules = busyNodeLatch, unfinishedSchedules = busyNodeLatch,
serverThread = serverThread, flowLogicRefFactory = flowLogicRefFactory
flowLogicRefFactory = flowLogicRefFactory) )
if (serverThread is ExecutorService) {
runOnStop += {
// We wait here, even though any in-flight messages should have been drained away because the
// server thread can potentially have other non-messaging tasks scheduled onto it. The timeout value is
// arbitrary and might be inappropriate.
MoreExecutors.shutdownAndAwaitTermination(serverThread as ExecutorService, 50, SECONDS)
}
}
makeVaultObservers(schedulerService, database.hibernateConfig, smm, schemaService, flowLogicRefFactory) makeVaultObservers(schedulerService, database.hibernateConfig, smm, schemaService, flowLogicRefFactory)
val rpcOps = makeRPCOps(flowStarter, database, smm) val rpcOps = makeRPCOps(flowStarter, database, smm)
startMessagingService(rpcOps) startMessagingService(rpcOps)
@ -327,8 +319,9 @@ abstract class AbstractNode(val configuration: NodeConfiguration,
} }
protected abstract fun myAddresses(): List<NetworkHostAndPort> protected abstract fun myAddresses(): List<NetworkHostAndPort>
protected open fun makeStateMachineManager(database: CordaPersistence): StateMachineManager { protected open fun makeStateMachineManager(database: CordaPersistence): StateMachineManager {
return StateMachineManagerImpl( return SingleThreadedStateMachineManager(
services, services,
checkpointStorage, checkpointStorage,
serverThread, serverThread,
@ -841,9 +834,9 @@ internal fun logVendorString(database: CordaPersistence, log: Logger) {
} }
} }
internal class FlowStarterImpl(private val serverThread: AffinityExecutor, private val smm: StateMachineManager, private val flowLogicRefFactory: FlowLogicRefFactory) : FlowStarter { internal class FlowStarterImpl(private val smm: StateMachineManager, private val flowLogicRefFactory: FlowLogicRefFactory) : FlowStarter {
override fun <T> startFlow(logic: FlowLogic<T>, context: InvocationContext): CordaFuture<FlowStateMachine<T>> { override fun <T> startFlow(logic: FlowLogic<T>, context: InvocationContext): CordaFuture<FlowStateMachine<T>> {
return serverThread.fetchFrom { smm.startFlow(logic, context) } return smm.startFlow(logic, context)
} }
override fun <T> invokeFlowAsync( override fun <T> invokeFlowAsync(

View File

@ -4,24 +4,36 @@ import com.codahale.metrics.MetricFilter
import com.codahale.metrics.MetricRegistry import com.codahale.metrics.MetricRegistry
import com.codahale.metrics.graphite.GraphiteReporter import com.codahale.metrics.graphite.GraphiteReporter
import com.codahale.metrics.graphite.PickledGraphite import com.codahale.metrics.graphite.PickledGraphite
import com.google.common.util.concurrent.ThreadFactoryBuilder
import com.jcraft.jsch.JSch import com.jcraft.jsch.JSch
import com.jcraft.jsch.JSchException import com.jcraft.jsch.JSchException
import net.corda.core.crypto.newSecureRandom
import net.corda.core.identity.CordaX500Name import net.corda.core.identity.CordaX500Name
import net.corda.core.internal.Emoji import net.corda.core.internal.Emoji
import net.corda.core.internal.concurrent.thenMatch import net.corda.core.internal.concurrent.thenMatch
import net.corda.core.utilities.loggerFor import net.corda.core.utilities.loggerFor
import net.corda.node.VersionInfo import net.corda.node.VersionInfo
import net.corda.node.internal.cordapp.CordappLoader
import net.corda.node.services.config.GraphiteOptions import net.corda.node.services.config.GraphiteOptions
import net.corda.node.services.config.NodeConfiguration import net.corda.node.services.config.NodeConfiguration
import net.corda.node.services.config.RelayConfiguration import net.corda.node.services.config.RelayConfiguration
import net.corda.node.services.statemachine.MultiThreadedStateMachineManager
import net.corda.node.services.statemachine.SingleThreadedStateMachineManager
import net.corda.node.services.statemachine.StateMachineManager
import net.corda.nodeapi.internal.persistence.CordaPersistence
import org.fusesource.jansi.Ansi import org.fusesource.jansi.Ansi
import org.fusesource.jansi.AnsiConsole import org.fusesource.jansi.AnsiConsole
import java.io.IOException import java.io.IOException
import java.net.InetAddress import java.net.InetAddress
import java.util.concurrent.ExecutorService
import java.util.concurrent.Executors
import java.util.concurrent.TimeUnit import java.util.concurrent.TimeUnit
class EnterpriseNode(configuration: NodeConfiguration, open class EnterpriseNode(configuration: NodeConfiguration,
versionInfo: VersionInfo) : Node(configuration, versionInfo) { versionInfo: VersionInfo,
initialiseSerialization: Boolean = true,
cordappLoader: CordappLoader = makeCordappLoader(configuration)
) : Node(configuration, versionInfo, initialiseSerialization, cordappLoader) {
companion object { companion object {
private val logger by lazy { loggerFor<EnterpriseNode>() } private val logger by lazy { loggerFor<EnterpriseNode>() }
@ -144,4 +156,29 @@ D""".trimStart()
registerOptionalMetricsReporter(configuration, started.services.monitoringService.metrics) registerOptionalMetricsReporter(configuration, started.services.monitoringService.metrics)
return started return started
} }
private fun makeStateMachineExecutorService(): ExecutorService {
return Executors.newFixedThreadPool(
configuration.enterpriseConfiguration.tuning.flowThreadPoolSize,
ThreadFactoryBuilder().setNameFormat("flow-executor-%d").build()
)
}
override fun makeStateMachineManager(database: CordaPersistence): StateMachineManager {
if (configuration.enterpriseConfiguration.useMultiThreadedSMM) {
val executor = makeStateMachineExecutorService()
runOnStop += { executor.shutdown() }
return MultiThreadedStateMachineManager(
services,
checkpointStorage,
executor,
database,
newSecureRandom(),
busyNodeLatch,
cordappLoader.appClassLoader
)
} else {
return super.makeStateMachineManager(database)
}
}
} }

View File

@ -83,7 +83,8 @@ open class Node(configuration: NodeConfiguration,
private val sameVmNodeCounter = AtomicInteger() private val sameVmNodeCounter = AtomicInteger()
val scanPackagesSystemProperty = "net.corda.node.cordapp.scan.packages" val scanPackagesSystemProperty = "net.corda.node.cordapp.scan.packages"
val scanPackagesSeparator = "," val scanPackagesSeparator = ","
private fun makeCordappLoader(configuration: NodeConfiguration): CordappLoader { @JvmStatic
protected fun makeCordappLoader(configuration: NodeConfiguration): CordappLoader {
return System.getProperty(scanPackagesSystemProperty)?.let { scanPackages -> return System.getProperty(scanPackagesSystemProperty)?.let { scanPackages ->
CordappLoader.createDefaultWithTestPackages(configuration, scanPackages.split(scanPackagesSeparator)) CordappLoader.createDefaultWithTestPackages(configuration, scanPackages.split(scanPackagesSeparator))
} ?: CordappLoader.createDefault(configuration.baseDirectory) } ?: CordappLoader.createDefault(configuration.baseDirectory)
@ -157,8 +158,12 @@ open class Node(configuration: NodeConfiguration,
bridgeControlListener = BridgeControlListener(configuration, serverAddress, networkParameters.maxMessageSize) bridgeControlListener = BridgeControlListener(configuration, serverAddress, networkParameters.maxMessageSize)
printBasicNodeInfo("Incoming connection address", advertisedAddress.toString()) printBasicNodeInfo("Incoming connection address", advertisedAddress.toString())
val rpcServerConfiguration = RPCServerConfiguration.default.copy(
rpcThreadPoolSize = configuration.enterpriseConfiguration.tuning.rpcThreadPoolSize
)
rpcServerAddresses?.let { rpcServerAddresses?.let {
rpcMessagingClient = RPCMessagingClient(configuration.rpcOptions.sslConfig, it.admin, networkParameters.maxMessageSize) rpcMessagingClient = RPCMessagingClient(configuration.rpcOptions.sslConfig, it.admin, networkParameters.maxMessageSize, rpcServerConfiguration)
} }
verifierMessagingClient = when (configuration.verifierType) { verifierMessagingClient = when (configuration.verifierType) {
VerifierType.OutOfProcess -> VerifierMessagingClient(configuration, serverAddress, services.monitoringService.metrics, networkParameters.maxMessageSize) VerifierType.OutOfProcess -> VerifierMessagingClient(configuration, serverAddress, services.monitoringService.metrics, networkParameters.maxMessageSize)
@ -175,8 +180,10 @@ open class Node(configuration: NodeConfiguration,
serverThread, serverThread,
database, database,
services.networkMapCache, services.networkMapCache,
services.monitoringService.metrics,
advertisedAddress, advertisedAddress,
networkParameters.maxMessageSize) networkParameters.maxMessageSize
)
} }
private fun startLocalRpcBroker(): BrokerAddresses? { private fun startLocalRpcBroker(): BrokerAddresses? {

View File

@ -1,5 +1,39 @@
package net.corda.node.services.config package net.corda.node.services.config
data class EnterpriseConfiguration(val mutualExclusionConfiguration: MutualExclusionConfiguration) data class EnterpriseConfiguration(
val mutualExclusionConfiguration: MutualExclusionConfiguration,
val useMultiThreadedSMM: Boolean = true,
val tuning: PerformanceTuning = PerformanceTuning.default
)
data class MutualExclusionConfiguration(val on: Boolean = false, val machineName: String, val updateInterval: Long, val waitInterval: Long) data class MutualExclusionConfiguration(val on: Boolean = false, val machineName: String, val updateInterval: Long, val waitInterval: Long)
/**
* @param flowThreadPoolSize Determines the size of the thread pool used by the flow framework to run flows.
* @param maximumMessagingBatchSize Determines the maximum number of jobs the messaging layer submits asynchronously
* before waiting for a flush from the broker.
* @param rpcThreadPoolSize Determines the number of threads used by the RPC server to serve requests.
* @param p2pConfirmationWindowSize Determines the number of bytes buffered by the broker before flushing to disk and
* acking the triggering send. Setting this to -1 causes session commits to immediately return, potentially
* causing blowup in the broker if the rate of sends exceeds the broker's flush rate. Note also that this window
* causes send latency to be around [brokerConnectionTtlCheckInterval] if the window isn't saturated.
* @param brokerConnectionTtlCheckIntervalMs Determines the interval of TTL timeout checks, but most importantly it also
* determines the flush period of message acks in case [p2pConfirmationWindowSize] is not saturated in time.
*/
data class PerformanceTuning(
val flowThreadPoolSize: Int,
val maximumMessagingBatchSize: Int,
val rpcThreadPoolSize: Int,
val p2pConfirmationWindowSize: Int,
val brokerConnectionTtlCheckIntervalMs: Long
) {
companion object {
val default = PerformanceTuning(
flowThreadPoolSize = 1,
maximumMessagingBatchSize = 256,
rpcThreadPoolSize = 4,
p2pConfirmationWindowSize = 1048576,
brokerConnectionTtlCheckIntervalMs = 20
)
}
}

View File

@ -209,6 +209,16 @@ data class NodeConfigurationImpl(
if (dataSourceUrl.contains(":sqlserver:") && !dataSourceUrl.contains("sendStringParametersAsUnicode", true)) { if (dataSourceUrl.contains(":sqlserver:") && !dataSourceUrl.contains("sendStringParametersAsUnicode", true)) {
dataSourceProperties[DataSourceConfigTag.DATA_SOURCE_URL] = dataSourceUrl + ";sendStringParametersAsUnicode=false" dataSourceProperties[DataSourceConfigTag.DATA_SOURCE_URL] = dataSourceUrl + ";sendStringParametersAsUnicode=false"
} }
// Adjust connection pool size depending on N=flow thread pool size.
// If there is no configured pool size set it to N + 1, otherwise check that it's greater than N.
val flowThreadPoolSize = enterpriseConfiguration.tuning.flowThreadPoolSize
val maxConnectionPoolSize = dataSourceProperties.getProperty("maximumPoolSize")
if (maxConnectionPoolSize == null) {
dataSourceProperties.setProperty("maximumPoolSize", (flowThreadPoolSize + 1).toString())
} else {
require(maxConnectionPoolSize.toInt() > flowThreadPoolSize)
}
} }
} }

View File

@ -58,7 +58,6 @@ class NodeSchedulerService(private val clock: CordaClock,
private val flowStarter: FlowStarter, private val flowStarter: FlowStarter,
private val stateLoader: StateLoader, private val stateLoader: StateLoader,
private val unfinishedSchedules: ReusableLatch = ReusableLatch(), private val unfinishedSchedules: ReusableLatch = ReusableLatch(),
private val serverThread: Executor,
private val flowLogicRefFactory: FlowLogicRefFactory, private val flowLogicRefFactory: FlowLogicRefFactory,
private val log: Logger = staticLog, private val log: Logger = staticLog,
private val scheduledStates: MutableMap<StateRef, ScheduledStateRef> = createMap()) private val scheduledStates: MutableMap<StateRef, ScheduledStateRef> = createMap())
@ -244,24 +243,22 @@ class NodeSchedulerService(private val clock: CordaClock,
} }
private fun onTimeReached(scheduledState: ScheduledStateRef) { private fun onTimeReached(scheduledState: ScheduledStateRef) {
serverThread.execute { var flowName: String? = "(unknown)"
var flowName: String? = "(unknown)" try {
try { database.transaction {
database.transaction { val scheduledFlow = getScheduledFlow(scheduledState)
val scheduledFlow = getScheduledFlow(scheduledState) if (scheduledFlow != null) {
if (scheduledFlow != null) { flowName = scheduledFlow.javaClass.name
flowName = scheduledFlow.javaClass.name // TODO refactor the scheduler to store and propagate the original invocation context
// TODO refactor the scheduler to store and propagate the original invocation context val context = InvocationContext.newInstance(Origin.Scheduled(scheduledState))
val context = InvocationContext.newInstance(Origin.Scheduled(scheduledState)) val future = flowStarter.startFlow(scheduledFlow, context).flatMap { it.resultFuture }
val future = flowStarter.startFlow(scheduledFlow, context).flatMap { it.resultFuture } future.then {
future.then { unfinishedSchedules.countDown()
unfinishedSchedules.countDown()
}
} }
} }
} catch (e: Exception) {
log.error("Failed to start scheduled flow $flowName for $scheduledState due to an internal error", e)
} }
} catch (e: Exception) {
log.error("Failed to start scheduled flow $flowName for $scheduledState due to an internal error", e)
} }
} }

View File

@ -135,7 +135,7 @@ class PersistentIdentityService(override val trustRoot: X509Certificate,
log.debug { "Registering identity $identity" } log.debug { "Registering identity $identity" }
val key = mapToKey(identity) val key = mapToKey(identity)
keyToParties.addWithDuplicatesAllowed(key, identity) keyToParties.addWithDuplicatesAllowed(key, identity, false)
// Always keep the first party we registered, as that's the well known identity // Always keep the first party we registered, as that's the well known identity
principalToParties.addWithDuplicatesAllowed(identity.name, key, false) principalToParties.addWithDuplicatesAllowed(identity.name, key, false)
val parentId = mapToKey(identityCertChain[1].publicKey) val parentId = mapToKey(identityCertChain[1].publicKey)

View File

@ -142,7 +142,7 @@ class ArtemisMessagingServer(private val config: NodeConfiguration,
journalBufferSize_AIO = maxMessageSize // Required to address IllegalArgumentException (when Artemis uses Linux Async IO): Record is too large to store. journalBufferSize_AIO = maxMessageSize // Required to address IllegalArgumentException (when Artemis uses Linux Async IO): Record is too large to store.
journalFileSize = maxMessageSize // The size of each journal file in bytes. Artemis default is 10MiB. journalFileSize = maxMessageSize // The size of each journal file in bytes. Artemis default is 10MiB.
managementNotificationAddress = SimpleString(NOTIFICATIONS_ADDRESS) managementNotificationAddress = SimpleString(NOTIFICATIONS_ADDRESS)
connectionTtlCheckInterval = config.enterpriseConfiguration.tuning.brokerConnectionTtlCheckIntervalMs
// JMX enablement // JMX enablement
if (config.exportJMXto.isNotEmpty()) { if (config.exportJMXto.isNotEmpty()) {
isJMXManagementEnabled = true isJMXManagementEnabled = true

View File

@ -1,5 +1,6 @@
package net.corda.node.services.messaging package net.corda.node.services.messaging
import co.paralleluniverse.fibers.Suspendable
import net.corda.core.crypto.newSecureRandom import net.corda.core.crypto.newSecureRandom
import net.corda.core.identity.CordaX500Name import net.corda.core.identity.CordaX500Name
import net.corda.core.messaging.MessageRecipients import net.corda.core.messaging.MessageRecipients
@ -60,15 +61,13 @@ interface MessagingService {
* @param sequenceKey an object that may be used to enable a parallel [MessagingService] implementation. Two * @param sequenceKey an object that may be used to enable a parallel [MessagingService] implementation. Two
* subsequent send()s with the same [sequenceKey] (up to equality) are guaranteed to be delivered in the same * subsequent send()s with the same [sequenceKey] (up to equality) are guaranteed to be delivered in the same
* sequence the send()s were called. By default this is chosen conservatively to be [target]. * sequence the send()s were called. By default this is chosen conservatively to be [target].
* @param acknowledgementHandler if non-null this handler will be called once the sent message has been committed by
* the broker. Note that if specified [send] itself may return earlier than the commit.
*/ */
@Suspendable
fun send( fun send(
message: Message, message: Message,
target: MessageRecipients, target: MessageRecipients,
retryId: Long? = null, retryId: Long? = null,
sequenceKey: Any = target, sequenceKey: Any = target
acknowledgementHandler: (() -> Unit)? = null
) )
/** A message with a target and sequenceKey specified. */ /** A message with a target and sequenceKey specified. */
@ -84,10 +83,9 @@ interface MessagingService {
* implementation. * implementation.
* *
* @param addressedMessages The list of messages together with the recipients, retry ids and sequence keys. * @param addressedMessages The list of messages together with the recipients, retry ids and sequence keys.
* @param acknowledgementHandler if non-null this handler will be called once all sent messages have been committed
* by the broker. Note that if specified [send] itself may return earlier than the commit.
*/ */
fun send(addressedMessages: List<AddressedMessage>, acknowledgementHandler: (() -> Unit)? = null) @Suspendable
fun send(addressedMessages: List<AddressedMessage>)
/** Cancels the scheduled message redelivery for the specified [retryId] */ /** Cancels the scheduled message redelivery for the specified [retryId] */
fun cancelRedelivery(retryId: Long) fun cancelRedelivery(retryId: Long)

View File

@ -0,0 +1,221 @@
package net.corda.node.services.messaging
import co.paralleluniverse.fibers.Suspendable
import co.paralleluniverse.strands.SettableFuture
import com.codahale.metrics.MetricRegistry
import net.corda.core.messaging.MessageRecipients
import net.corda.core.utilities.contextLogger
import net.corda.core.utilities.trace
import net.corda.node.VersionInfo
import net.corda.node.services.statemachine.FlowMessagingImpl
import org.apache.activemq.artemis.api.core.ActiveMQDuplicateIdException
import org.apache.activemq.artemis.api.core.ActiveMQException
import org.apache.activemq.artemis.api.core.SimpleString
import org.apache.activemq.artemis.api.core.client.ClientMessage
import org.apache.activemq.artemis.api.core.client.ClientProducer
import org.apache.activemq.artemis.api.core.client.ClientSession
import java.util.*
import java.util.concurrent.ArrayBlockingQueue
import java.util.concurrent.ExecutionException
import kotlin.concurrent.thread
interface AddressToArtemisQueueResolver {
/**
* Resolves a [MessageRecipients] to an Artemis queue name, creating the underlying queue if needed.
*/
fun resolveTargetToArtemisQueue(address: MessageRecipients): String
}
/**
* The [MessagingExecutor] is responsible for handling send and acknowledge jobs. It batches them using a bounded
* blocking queue, submits the jobs asynchronously and then waits for them to flush using [ClientSession.commit].
* Note that even though we buffer in theory this shouldn't increase latency as the executor is immediately woken up if
* it was waiting. The number of jobs in the queue is only ever greater than 1 if the commit takes a long time.
*/
class MessagingExecutor(
val session: ClientSession,
val producer: ClientProducer,
val versionInfo: VersionInfo,
val resolver: AddressToArtemisQueueResolver,
metricRegistry: MetricRegistry,
queueBound: Int
) {
private sealed class Job {
data class Acknowledge(val message: ClientMessage) : Job()
data class Send(
val message: Message,
val target: MessageRecipients,
val sentFuture: SettableFuture<Unit>
) : Job() {
override fun toString() = "Send(${message.uniqueMessageId}, target=$target)"
}
object Shutdown : Job() { override fun toString() = "Shutdown" }
}
private val queue = ArrayBlockingQueue<Job>(queueBound)
private var executor: Thread? = null
private val cordaVendor = SimpleString(versionInfo.vendor)
private val releaseVersion = SimpleString(versionInfo.releaseVersion)
private val sendMessageSizeMetric = metricRegistry.histogram("SendMessageSize")
private val sendLatencyMetric = metricRegistry.timer("SendLatency")
private val sendBatchSizeMetric = metricRegistry.histogram("SendBatchSize")
private companion object {
val log = contextLogger()
val amqDelayMillis = System.getProperty("amq.delivery.delay.ms", "0").toInt()
}
/**
* Submit a send job of [message] to [target] and wait until it finishes.
* This call may yield the fiber.
*/
@Suspendable
fun send(message: Message, target: MessageRecipients) {
val sentFuture = SettableFuture<Unit>()
val job = Job.Send(message, target, sentFuture)
val context = sendLatencyMetric.time()
try {
queue.put(job)
sentFuture.get()
} catch (executionException: ExecutionException) {
throw executionException.cause!!
} finally {
context.stop()
}
}
/**
* Submit an acknowledge job of [message].
* This call does NOT wait for confirmation of the ACK receive. If a failure happens then the message will either be
* redelivered, deduped and acked, or the message was actually acked before failure in which case all is good.
*/
fun acknowledge(message: ClientMessage) {
queue.put(Job.Acknowledge(message))
}
fun start() {
require(executor == null)
executor = thread(name = "Messaging executor", isDaemon = true) {
val batch = ArrayList<Job>()
eventLoop@ while (true) {
batch.add(queue.take()) // Block until at least one job is available.
queue.drainTo(batch)
sendBatchSizeMetric.update(batch.filter { it is Job.Send }.size)
val shouldShutdown = try {
// Try to handle the batch in one commit.
handleBatchTransactional(batch)
} catch (exception: ActiveMQException) {
// A job failed, rollback and do it one at a time, simply log and skip if an individual job fails.
// If a send job fails the exception will be re-raised in the corresponding future.
// Note that this fallback assumes that there are no two jobs in the batch that depend on one
// another. As the exception is re-raised in the requesting calling thread in case of a send, we can
// assume no "in-flight" messages will be sent out of order after failure.
log.warn("Exception while handling transactional batch, falling back to handling one job at a time", exception)
handleBatchOneByOne(batch)
}
batch.clear()
if (shouldShutdown) {
break@eventLoop
}
}
}
}
fun close() {
val executor = this.executor
if (executor != null) {
queue.offer(Job.Shutdown)
executor.join()
this.executor = null
}
}
/**
* Handles a batch of jobs in one transaction.
* @return true if the executor should shut down, false otherwise.
* @throws ActiveMQException
*/
private fun handleBatchTransactional(batch: List<Job>): Boolean {
for (job in batch) {
when (job) {
is Job.Acknowledge -> {
acknowledgeJob(job)
}
is Job.Send -> {
sendJob(job)
}
Job.Shutdown -> {
session.commit()
return true
}
}
}
session.commit()
return false
}
/**
* Handles a batch of jobs one by one, committing after each.
* @return true if the executor should shut down, false otherwise.
*/
private fun handleBatchOneByOne(batch: List<Job>): Boolean {
for (job in batch) {
try {
when (job) {
is Job.Acknowledge -> {
acknowledgeJob(job)
session.commit()
}
is Job.Send -> {
try {
sendJob(job)
session.commit()
} catch (duplicateException: ActiveMQDuplicateIdException) {
log.warn("Message duplication", duplicateException)
job.sentFuture.set(Unit)
}
}
Job.Shutdown -> {
session.commit()
return true
}
}
} catch (exception: Throwable) {
log.error("Exception while handling job $job, disregarding", exception)
if (job is Job.Send) {
job.sentFuture.setException(exception)
}
session.rollback()
}
}
return false
}
private fun sendJob(job: Job.Send) {
val mqAddress = resolver.resolveTargetToArtemisQueue(job.target)
val artemisMessage = session.createMessage(true).apply {
putStringProperty(P2PMessagingClient.cordaVendorProperty, cordaVendor)
putStringProperty(P2PMessagingClient.releaseVersionProperty, releaseVersion)
putIntProperty(P2PMessagingClient.platformVersionProperty, versionInfo.platformVersion)
putStringProperty(P2PMessagingClient.topicProperty, SimpleString(job.message.topic))
sendMessageSizeMetric.update(job.message.data.bytes.size)
writeBodyBufferBytes(job.message.data.bytes)
// Use the magic deduplication property built into Artemis as our message identity too
putStringProperty(org.apache.activemq.artemis.api.core.Message.HDR_DUPLICATE_DETECTION_ID, SimpleString(job.message.uniqueMessageId.toString))
// For demo purposes - if set then add a delay to messages in order to demonstrate that the flows are doing as intended
if (amqDelayMillis > 0 && job.message.topic == FlowMessagingImpl.sessionTopic) {
putLongProperty(org.apache.activemq.artemis.api.core.Message.HDR_SCHEDULED_DELIVERY_TIME, System.currentTimeMillis() + amqDelayMillis)
}
}
log.trace {
"Send to: $mqAddress topic: ${job.message.topic} " +
"sessionID: ${job.message.topic} id: ${job.message.uniqueMessageId}"
}
producer.send(SimpleString(mqAddress), artemisMessage) { job.sentFuture.set(Unit) }
}
private fun acknowledgeJob(job: Job.Acknowledge) {
job.message.individualAcknowledge()
}
}

View File

@ -1,5 +1,7 @@
package net.corda.node.services.messaging package net.corda.node.services.messaging
import co.paralleluniverse.fibers.Suspendable
import com.codahale.metrics.MetricRegistry
import net.corda.core.crypto.toStringShort import net.corda.core.crypto.toStringShort
import net.corda.core.identity.CordaX500Name import net.corda.core.identity.CordaX500Name
import net.corda.core.internal.ThreadBox import net.corda.core.internal.ThreadBox
@ -18,7 +20,6 @@ import net.corda.node.VersionInfo
import net.corda.node.services.api.NetworkMapCacheInternal import net.corda.node.services.api.NetworkMapCacheInternal
import net.corda.node.services.config.NodeConfiguration import net.corda.node.services.config.NodeConfiguration
import net.corda.node.services.statemachine.DeduplicationId import net.corda.node.services.statemachine.DeduplicationId
import net.corda.node.services.statemachine.FlowMessagingImpl
import net.corda.node.utilities.AffinityExecutor import net.corda.node.utilities.AffinityExecutor
import net.corda.node.utilities.AppendOnlyPersistentMap import net.corda.node.utilities.AppendOnlyPersistentMap
import net.corda.node.utilities.PersistentMap import net.corda.node.utilities.PersistentMap
@ -32,7 +33,8 @@ import net.corda.nodeapi.internal.bridging.BridgeEntry
import net.corda.nodeapi.internal.persistence.CordaPersistence import net.corda.nodeapi.internal.persistence.CordaPersistence
import net.corda.nodeapi.internal.persistence.NODE_DATABASE_PREFIX import net.corda.nodeapi.internal.persistence.NODE_DATABASE_PREFIX
import org.apache.activemq.artemis.api.core.ActiveMQObjectClosedException import org.apache.activemq.artemis.api.core.ActiveMQObjectClosedException
import org.apache.activemq.artemis.api.core.Message.* import org.apache.activemq.artemis.api.core.Message.HDR_DUPLICATE_DETECTION_ID
import org.apache.activemq.artemis.api.core.Message.HDR_VALIDATED_USER
import org.apache.activemq.artemis.api.core.RoutingType import org.apache.activemq.artemis.api.core.RoutingType
import org.apache.activemq.artemis.api.core.SimpleString import org.apache.activemq.artemis.api.core.SimpleString
import org.apache.activemq.artemis.api.core.client.ClientConsumer import org.apache.activemq.artemis.api.core.client.ClientConsumer
@ -78,7 +80,7 @@ import javax.persistence.Lob
* @param maxMessageSize A bound applied to the message size. * @param maxMessageSize A bound applied to the message size.
*/ */
@ThreadSafe @ThreadSafe
class P2PMessagingClient(config: NodeConfiguration, class P2PMessagingClient(val config: NodeConfiguration,
private val versionInfo: VersionInfo, private val versionInfo: VersionInfo,
serverAddress: NetworkHostAndPort, serverAddress: NetworkHostAndPort,
private val myIdentity: PublicKey, private val myIdentity: PublicKey,
@ -86,20 +88,20 @@ class P2PMessagingClient(config: NodeConfiguration,
private val nodeExecutor: AffinityExecutor.ServiceAffinityExecutor, private val nodeExecutor: AffinityExecutor.ServiceAffinityExecutor,
private val database: CordaPersistence, private val database: CordaPersistence,
private val networkMap: NetworkMapCacheInternal, private val networkMap: NetworkMapCacheInternal,
private val metricRegistry: MetricRegistry,
advertisedAddress: NetworkHostAndPort = serverAddress, advertisedAddress: NetworkHostAndPort = serverAddress,
maxMessageSize: Int maxMessageSize: Int
) : SingletonSerializeAsToken(), MessagingService { ) : SingletonSerializeAsToken(), MessagingService, AddressToArtemisQueueResolver {
companion object { companion object {
private val log = contextLogger() private val log = contextLogger()
// This is a "property" attached to an Artemis MQ message object, which contains our own notion of "topic". // This is a "property" attached to an Artemis MQ message object, which contains our own notion of "topic".
// We should probably try to unify our notion of "topic" (really, just a string that identifies an endpoint // We should probably try to unify our notion of "topic" (really, just a string that identifies an endpoint
// that will handle messages, like a URL) with the terminology used by underlying MQ libraries, to avoid // that will handle messages, like a URL) with the terminology used by underlying MQ libraries, to avoid
// confusion. // confusion.
private val topicProperty = SimpleString("platform-topic") val topicProperty = SimpleString("platform-topic")
private val cordaVendorProperty = SimpleString("corda-vendor") val cordaVendorProperty = SimpleString("corda-vendor")
private val releaseVersionProperty = SimpleString("release-version") val releaseVersionProperty = SimpleString("release-version")
private val platformVersionProperty = SimpleString("platform-version") val platformVersionProperty = SimpleString("platform-version")
private val amqDelayMillis = System.getProperty("amq.delivery.delay.ms", "0").toInt()
private val messageMaxRetryCount: Int = 3 private val messageMaxRetryCount: Int = 3
fun createProcessedMessages(): AppendOnlyPersistentMap<DeduplicationId, Instant, ProcessedMessage, String> { fun createProcessedMessages(): AppendOnlyPersistentMap<DeduplicationId, Instant, ProcessedMessage, String> {
@ -159,20 +161,23 @@ class P2PMessagingClient(config: NodeConfiguration,
/** A registration to handle messages of different types */ /** A registration to handle messages of different types */
data class HandlerRegistration(val topic: String, val callback: Any) : MessageHandlerRegistration data class HandlerRegistration(val topic: String, val callback: Any) : MessageHandlerRegistration
private val cordaVendor = SimpleString(versionInfo.vendor)
private val releaseVersion = SimpleString(versionInfo.releaseVersion)
/** An executor for sending messages */
private val messagingExecutor = AffinityExecutor.ServiceAffinityExecutor("Messaging ${myIdentity.toStringShort()}", 1)
override val myAddress: SingleMessageRecipient = NodeAddress(myIdentity, advertisedAddress) override val myAddress: SingleMessageRecipient = NodeAddress(myIdentity, advertisedAddress)
private val messageRedeliveryDelaySeconds = config.messageRedeliveryDelaySeconds.toLong() private val messageRedeliveryDelaySeconds = config.messageRedeliveryDelaySeconds.toLong()
private val artemis = ArtemisMessagingClient(config, serverAddress, maxMessageSize) private val artemis = ArtemisMessagingClient(
config = config,
serverAddress = serverAddress,
maxMessageSize = maxMessageSize,
autoCommitSends = false,
autoCommitAcks = false,
confirmationWindowSize = config.enterpriseConfiguration.tuning.p2pConfirmationWindowSize
)
private val state = ThreadBox(InnerState()) private val state = ThreadBox(InnerState())
private val knownQueues = Collections.newSetFromMap(ConcurrentHashMap<String, Boolean>()) private val knownQueues = Collections.newSetFromMap(ConcurrentHashMap<String, Boolean>())
private val handlers = ConcurrentHashMap<String, MessageHandler>() private val handlers = ConcurrentHashMap<String, MessageHandler>()
private val processedMessages = createProcessedMessages() private val processedMessages = createProcessedMessages()
private var messagingExecutor: MessagingExecutor? = null
@Entity @Entity
@javax.persistence.Table(name = "${NODE_DATABASE_PREFIX}message_ids") @javax.persistence.Table(name = "${NODE_DATABASE_PREFIX}message_ids")
@ -203,7 +208,8 @@ class P2PMessagingClient(config: NodeConfiguration,
fun start() { fun start() {
state.locked { state.locked {
val session = artemis.start().session val started = artemis.start()
val session = started.session
val inbox = RemoteInboxAddress(myIdentity).queueName val inbox = RemoteInboxAddress(myIdentity).queueName
val inboxes = mutableListOf(inbox) val inboxes = mutableListOf(inbox)
// Create a queue, consumer and producer for handling P2P network messages. // Create a queue, consumer and producer for handling P2P network messages.
@ -220,6 +226,18 @@ class P2PMessagingClient(config: NodeConfiguration,
deliver(msg, message) deliver(msg, message)
} }
} }
val messagingExecutor = MessagingExecutor(
session,
started.producer,
versionInfo,
this@P2PMessagingClient,
metricRegistry,
queueBound = config.enterpriseConfiguration.tuning.maximumMessagingBatchSize
)
this@P2PMessagingClient.messagingExecutor = messagingExecutor
messagingExecutor.start()
registerBridgeControl(session, inboxes) registerBridgeControl(session, inboxes)
enumerateBridges(session, inboxes) enumerateBridges(session, inboxes)
} }
@ -253,6 +271,7 @@ class P2PMessagingClient(config: NodeConfiguration,
val artemisMessage = client.session.createMessage(false) val artemisMessage = client.session.createMessage(false)
artemisMessage.writeBodyBufferBytes(controlPacket) artemisMessage.writeBodyBufferBytes(controlPacket)
client.producer.send(BRIDGE_CONTROL, artemisMessage) client.producer.send(BRIDGE_CONTROL, artemisMessage)
client.session.commit()
} }
private fun updateBridgesOnNetworkChange(change: NetworkMapCache.MapChange) { private fun updateBridgesOnNetworkChange(change: NetworkMapCache.MapChange) {
@ -419,12 +438,7 @@ class P2PMessagingClient(config: NodeConfiguration,
// processing a message but if so, it'll be parked waiting for us to count down the latch, so // processing a message but if so, it'll be parked waiting for us to count down the latch, so
// the session itself is still around and we can still ack messages as a result. // the session itself is still around and we can still ack messages as a result.
override fun acknowledge() { override fun acknowledge() {
messagingExecutor.fetchFrom { messagingExecutor!!.acknowledge(artemisMessage)
state.locked {
artemisMessage.individualAcknowledge()
artemis.started!!.session.commit()
}
}
} }
} }
deliverTo(msg, HandlerRegistration(msg.topic, deliverTo), acknowledgeHandle) deliverTo(msg, HandlerRegistration(msg.topic, deliverTo), acknowledgeHandle)
@ -476,6 +490,7 @@ class P2PMessagingClient(config: NodeConfiguration,
shutdownLatch.await() shutdownLatch.await()
} }
// Only first caller to gets running true to protect against double stop, which seems to happen in some integration tests. // Only first caller to gets running true to protect against double stop, which seems to happen in some integration tests.
messagingExecutor?.close()
if (running) { if (running) {
state.locked { state.locked {
artemis.stop() artemis.stop()
@ -483,74 +498,43 @@ class P2PMessagingClient(config: NodeConfiguration,
} }
} }
override fun send(message: Message, target: MessageRecipients, retryId: Long?, sequenceKey: Any, acknowledgementHandler: (() -> Unit)?) { @Suspendable
// We have to perform sending on a different thread pool, since using the same pool for messaging and override fun send(message: Message, target: MessageRecipients, retryId: Long?, sequenceKey: Any) {
// fibers leads to Netty buffer memory leaks, caused by both Netty and Quasar fiddling with thread-locals. messagingExecutor!!.send(message, target)
messagingExecutor.fetchFrom { retryId?.let {
state.locked { database.transaction {
val mqAddress = getMQAddress(target) messagesToRedeliver.computeIfAbsent(it, { Pair(message, target) })
val artemis = artemis.started!!
val artemisMessage = artemis.session.createMessage(true).apply {
putStringProperty(cordaVendorProperty, cordaVendor)
putStringProperty(releaseVersionProperty, releaseVersion)
putIntProperty(platformVersionProperty, versionInfo.platformVersion)
putStringProperty(topicProperty, SimpleString(message.topic))
writeBodyBufferBytes(message.data.bytes)
// Use the magic deduplication property built into Artemis as our message identity too
putStringProperty(HDR_DUPLICATE_DETECTION_ID, SimpleString(message.uniqueMessageId.toString))
// For demo purposes - if set then add a delay to messages in order to demonstrate that the flows are doing as intended
if (amqDelayMillis > 0 && message.topic == FlowMessagingImpl.sessionTopic) {
putLongProperty(HDR_SCHEDULED_DELIVERY_TIME, System.currentTimeMillis() + amqDelayMillis)
}
}
log.trace {
"Send to: $mqAddress topic: ${message.topic} " +
"sessionID: ${message.topic} id: ${message.uniqueMessageId}"
}
artemis.producer.send(mqAddress, artemisMessage)
retryId?.let {
database.transaction {
messagesToRedeliver.computeIfAbsent(it, { Pair(message, target) })
}
scheduledMessageRedeliveries[it] = messagingExecutor.schedule({
sendWithRetry(0, mqAddress, artemisMessage, it)
}, messageRedeliveryDelaySeconds, TimeUnit.SECONDS)
}
} }
scheduledMessageRedeliveries[it] = nodeExecutor.schedule({
sendWithRetry(0, message, target, retryId)
}, messageRedeliveryDelaySeconds, TimeUnit.SECONDS)
} }
acknowledgementHandler?.invoke()
} }
override fun send(addressedMessages: List<MessagingService.AddressedMessage>, acknowledgementHandler: (() -> Unit)?) { @Suspendable
override fun send(addressedMessages: List<MessagingService.AddressedMessage>) {
for ((message, target, retryId, sequenceKey) in addressedMessages) { for ((message, target, retryId, sequenceKey) in addressedMessages) {
send(message, target, retryId, sequenceKey, null) send(message, target, retryId, sequenceKey)
} }
acknowledgementHandler?.invoke()
} }
private fun sendWithRetry(retryCount: Int, address: String, message: ClientMessage, retryId: Long) { private fun sendWithRetry(retryCount: Int, message: Message, target: MessageRecipients, retryId: Long) {
fun ClientMessage.randomiseDuplicateId() {
putStringProperty(HDR_DUPLICATE_DETECTION_ID, SimpleString(UUID.randomUUID().toString()))
}
log.trace { "Attempting to retry #$retryCount message delivery for $retryId" } log.trace { "Attempting to retry #$retryCount message delivery for $retryId" }
if (retryCount >= messageMaxRetryCount) { if (retryCount >= messageMaxRetryCount) {
log.warn("Reached the maximum number of retries ($messageMaxRetryCount) for message $message redelivery to $address") log.warn("Reached the maximum number of retries ($messageMaxRetryCount) for message $message redelivery to $target")
scheduledMessageRedeliveries.remove(retryId) scheduledMessageRedeliveries.remove(retryId)
return return
} }
message.randomiseDuplicateId() val messageWithRetryCount = object : Message by message {
override val uniqueMessageId = DeduplicationId("${message.uniqueMessageId.toString}-$retryCount")
state.locked {
log.trace { "Retry #$retryCount sending message $message to $address for $retryId" }
artemis.started!!.producer.send(address, message)
} }
scheduledMessageRedeliveries[retryId] = messagingExecutor.schedule({ messagingExecutor!!.send(messageWithRetryCount, target)
sendWithRetry(retryCount + 1, address, message, retryId)
scheduledMessageRedeliveries[retryId] = nodeExecutor.schedule({
sendWithRetry(retryCount + 1, message, target, retryId)
}, messageRedeliveryDelaySeconds, TimeUnit.SECONDS) }, messageRedeliveryDelaySeconds, TimeUnit.SECONDS)
} }
@ -565,14 +549,14 @@ class P2PMessagingClient(config: NodeConfiguration,
} }
} }
private fun getMQAddress(target: MessageRecipients): String { override fun resolveTargetToArtemisQueue(address: MessageRecipients): String {
return if (target == myAddress) { return if (address == myAddress) {
// If we are sending to ourselves then route the message directly to our P2P queue. // If we are sending to ourselves then route the message directly to our P2P queue.
RemoteInboxAddress(myIdentity).queueName RemoteInboxAddress(myIdentity).queueName
} else { } else {
// Otherwise we send the message to an internal queue for the target residing on our broker. It's then the // Otherwise we send the message to an internal queue for the target residing on our broker. It's then the
// broker's job to route the message to the target's P2P queue. // broker's job to route the message to the target's P2P queue.
val internalTargetQueue = (target as? ArtemisAddress)?.queueName ?: throw IllegalArgumentException("Not an Artemis address") val internalTargetQueue = (address as? ArtemisAddress)?.queueName ?: throw IllegalArgumentException("Not an Artemis address")
createQueueIfAbsent(internalTargetQueue) createQueueIfAbsent(internalTargetQueue)
internalTargetQueue internalTargetQueue
} }
@ -581,20 +565,18 @@ class P2PMessagingClient(config: NodeConfiguration,
/** Attempts to create a durable queue on the broker which is bound to an address of the same name. */ /** Attempts to create a durable queue on the broker which is bound to an address of the same name. */
private fun createQueueIfAbsent(queueName: String) { private fun createQueueIfAbsent(queueName: String) {
if (!knownQueues.contains(queueName)) { if (!knownQueues.contains(queueName)) {
state.alreadyLocked { val session = artemis.started!!.session
val session = artemis.started!!.session val queueQuery = session.queueQuery(SimpleString(queueName))
val queueQuery = session.queueQuery(SimpleString(queueName)) if (!queueQuery.isExists) {
if (!queueQuery.isExists) { log.info("Create fresh queue $queueName bound on same address")
log.info("Create fresh queue $queueName bound on same address") session.createQueue(queueName, RoutingType.ANYCAST, queueName, true)
session.createQueue(queueName, RoutingType.ANYCAST, queueName, true) if (queueName.startsWith(PEERS_PREFIX)) {
if (queueName.startsWith(PEERS_PREFIX)) { val keyHash = queueName.substring(PEERS_PREFIX.length)
val keyHash = queueName.substring(PEERS_PREFIX.length) val peers = networkMap.getNodesByOwningKeyIndex(keyHash)
val peers = networkMap.getNodesByOwningKeyIndex(keyHash) for (node in peers) {
for (node in peers) { val bridge = BridgeEntry(queueName, node.addresses, node.legalIdentities.map { it.name })
val bridge = BridgeEntry(queueName, node.addresses, node.legalIdentities.map { it.name }) val createBridgeMessage = BridgeControl.Create(myIdentity.toStringShort(), bridge)
val createBridgeMessage = BridgeControl.Create(myIdentity.toStringShort(), bridge) sendBridgeControl(createBridgeMessage)
sendBridgeControl(createBridgeMessage)
}
} }
} }
} }

View File

@ -11,14 +11,19 @@ import net.corda.nodeapi.internal.config.SSLConfiguration
import net.corda.nodeapi.internal.crypto.X509Utilities import net.corda.nodeapi.internal.crypto.X509Utilities
import org.apache.activemq.artemis.api.core.management.ActiveMQServerControl import org.apache.activemq.artemis.api.core.management.ActiveMQServerControl
class RPCMessagingClient(private val config: SSLConfiguration, serverAddress: NetworkHostAndPort, maxMessageSize: Int) : SingletonSerializeAsToken(), AutoCloseable { class RPCMessagingClient(
private val config: SSLConfiguration,
serverAddress: NetworkHostAndPort,
maxMessageSize: Int,
private val rpcServerConfiguration: RPCServerConfiguration = RPCServerConfiguration.default
) : SingletonSerializeAsToken(), AutoCloseable {
private val artemis = ArtemisMessagingClient(config, serverAddress, maxMessageSize) private val artemis = ArtemisMessagingClient(config, serverAddress, maxMessageSize)
private var rpcServer: RPCServer? = null private var rpcServer: RPCServer? = null
fun start(rpcOps: RPCOps, securityManager: RPCSecurityManager) = synchronized(this) { fun start(rpcOps: RPCOps, securityManager: RPCSecurityManager) = synchronized(this) {
val locator = artemis.start().sessionFactory.serverLocator val locator = artemis.start().sessionFactory.serverLocator
val myCert = config.loadSslKeyStore().getCertificate(X509Utilities.CORDA_CLIENT_TLS) val myCert = config.loadSslKeyStore().getCertificate(X509Utilities.CORDA_CLIENT_TLS)
rpcServer = RPCServer(rpcOps, NODE_USER, NODE_USER, locator, securityManager, CordaX500Name.build(myCert.subjectX500Principal)) rpcServer = RPCServer(rpcOps, NODE_USER, NODE_USER, locator, securityManager, CordaX500Name.build(myCert.subjectX500Principal), rpcServerConfiguration)
} }
fun start2(serverControl: ActiveMQServerControl) = synchronized(this) { fun start2(serverControl: ActiveMQServerControl) = synchronized(this) {

View File

@ -442,7 +442,7 @@ class ObservableContext(
val artemisMessage = it.session.createMessage(false) val artemisMessage = it.session.createMessage(false)
serverToClient.writeToClientMessage(serializationContextWithObservableContext, artemisMessage) serverToClient.writeToClientMessage(serializationContextWithObservableContext, artemisMessage)
it.producer.send(clientAddress, artemisMessage) it.producer.send(clientAddress, artemisMessage)
log.debug("<- RPC <- $serverToClient") log.debug { "<- RPC <- $serverToClient" }
} }
} catch (throwable: Throwable) { } catch (throwable: Throwable) {
log.error("Failed to send message, kicking client. Message was $serverToClient", throwable) log.error("Failed to send message, kicking client. Message was $serverToClient", throwable)

View File

@ -21,7 +21,6 @@ class AbstractPartyToX500NameAsStringConverter(private val identityService: Iden
if (party != null) { if (party != null) {
val partyName = identityService.wellKnownPartyFromAnonymous(party)?.toString() val partyName = identityService.wellKnownPartyFromAnonymous(party)?.toString()
if (partyName != null) return partyName if (partyName != null) return partyName
log.warn("Identity service unable to resolve AbstractParty: $party")
} }
return null // non resolvable anonymous parties return null // non resolvable anonymous parties
} }
@ -30,7 +29,6 @@ class AbstractPartyToX500NameAsStringConverter(private val identityService: Iden
if (dbData != null) { if (dbData != null) {
val party = identityService.wellKnownPartyFromX500Name(CordaX500Name.parse(dbData)) val party = identityService.wellKnownPartyFromX500Name(CordaX500Name.parse(dbData))
if (party != null) return party if (party != null) return party
log.warn("Identity service unable to resolve X500name: $dbData")
} }
return null // non resolvable anonymous parties are stored as nulls return null // non resolvable anonymous parties are stored as nulls
} }

View File

@ -3,6 +3,7 @@ package net.corda.node.services.persistence
import net.corda.core.internal.bufferUntilSubscribed import net.corda.core.internal.bufferUntilSubscribed
import net.corda.core.crypto.SecureHash import net.corda.core.crypto.SecureHash
import net.corda.core.flows.StateMachineRunId import net.corda.core.flows.StateMachineRunId
import net.corda.core.internal.ConcurrentBox
import net.corda.core.messaging.DataFeed import net.corda.core.messaging.DataFeed
import net.corda.core.messaging.StateMachineTransactionMapping import net.corda.core.messaging.StateMachineTransactionMapping
import net.corda.node.services.api.StateMachineRecordedTransactionMappingStorage import net.corda.node.services.api.StateMachineRecordedTransactionMappingStorage
@ -51,16 +52,27 @@ class DBTransactionMappingStorage : StateMachineRecordedTransactionMappingStorag
} }
} }
val stateMachineTransactionMap = createMap() private class InnerState {
val updates: PublishSubject<StateMachineTransactionMapping> = PublishSubject.create() val stateMachineTransactionMap = createMap()
val updates: PublishSubject<StateMachineTransactionMapping> = PublishSubject.create()
override fun addMapping(stateMachineRunId: StateMachineRunId, transactionId: SecureHash) {
stateMachineTransactionMap[transactionId] = stateMachineRunId
updates.bufferUntilDatabaseCommit().onNext(StateMachineTransactionMapping(stateMachineRunId, transactionId))
} }
override fun track(): DataFeed<List<StateMachineTransactionMapping>, StateMachineTransactionMapping> = private val concurrentBox = ConcurrentBox(InnerState())
DataFeed(stateMachineTransactionMap.allPersisted().map { StateMachineTransactionMapping(it.second, it.first) }.toList(),
updates.bufferUntilSubscribed().wrapWithDatabaseTransaction()) override fun addMapping(stateMachineRunId: StateMachineRunId, transactionId: SecureHash) {
concurrentBox.concurrent {
stateMachineTransactionMap[transactionId] = stateMachineRunId
updates.bufferUntilDatabaseCommit().onNext(StateMachineTransactionMapping(stateMachineRunId, transactionId))
}
}
override fun track(): DataFeed<List<StateMachineTransactionMapping>, StateMachineTransactionMapping> {
return concurrentBox.exclusive {
DataFeed(
stateMachineTransactionMap.allPersisted().map { StateMachineTransactionMapping(it.second, it.first) }.toList(),
updates.bufferUntilSubscribed().wrapWithDatabaseTransaction()
)
}
}
} }

View File

@ -3,7 +3,7 @@ package net.corda.node.services.persistence
import net.corda.core.concurrent.CordaFuture import net.corda.core.concurrent.CordaFuture
import net.corda.core.crypto.SecureHash import net.corda.core.crypto.SecureHash
import net.corda.core.crypto.TransactionSignature import net.corda.core.crypto.TransactionSignature
import net.corda.core.internal.ThreadBox import net.corda.core.internal.ConcurrentBox
import net.corda.core.internal.VisibleForTesting import net.corda.core.internal.VisibleForTesting
import net.corda.core.internal.bufferUntilSubscribed import net.corda.core.internal.bufferUntilSubscribed
import net.corda.core.internal.concurrent.doneFuture import net.corda.core.internal.concurrent.doneFuture
@ -79,10 +79,10 @@ class DBTransactionStorage(cacheSizeBytes: Long) : WritableTransactionStorage, S
} }
} }
private val txStorage = ThreadBox(createTransactionsMap(cacheSizeBytes)) private val txStorage = ConcurrentBox(createTransactionsMap(cacheSizeBytes))
override fun addTransaction(transaction: SignedTransaction): Boolean = override fun addTransaction(transaction: SignedTransaction): Boolean =
txStorage.locked { txStorage.concurrent {
addWithDuplicatesAllowed(transaction.id, transaction.toTxCacheValue()).apply { addWithDuplicatesAllowed(transaction.id, transaction.toTxCacheValue()).apply {
updatesPublisher.bufferUntilDatabaseCommit().onNext(transaction) updatesPublisher.bufferUntilDatabaseCommit().onNext(transaction)
} }
@ -94,13 +94,13 @@ class DBTransactionStorage(cacheSizeBytes: Long) : WritableTransactionStorage, S
override val updates: Observable<SignedTransaction> = updatesPublisher.wrapWithDatabaseTransaction() override val updates: Observable<SignedTransaction> = updatesPublisher.wrapWithDatabaseTransaction()
override fun track(): DataFeed<List<SignedTransaction>, SignedTransaction> { override fun track(): DataFeed<List<SignedTransaction>, SignedTransaction> {
return txStorage.locked { return txStorage.exclusive {
DataFeed(allPersisted().map { it.second.toSignedTx() }.toList(), updatesPublisher.bufferUntilSubscribed().wrapWithDatabaseTransaction()) DataFeed(allPersisted().map { it.second.toSignedTx() }.toList(), updatesPublisher.bufferUntilSubscribed().wrapWithDatabaseTransaction())
} }
} }
override fun trackTransaction(id: SecureHash): CordaFuture<SignedTransaction> { override fun trackTransaction(id: SecureHash): CordaFuture<SignedTransaction> {
return txStorage.locked { return txStorage.exclusive {
val existingTransaction = get(id) val existingTransaction = get(id)
if (existingTransaction == null) { if (existingTransaction == null) {
updatesPublisher.filter { it.id == id }.toFuture() updatesPublisher.filter { it.id == id }.toFuture()

View File

@ -122,7 +122,6 @@ class ActionExecutorImpl(
val exception = error.flowException val exception = error.flowException
log.debug("Propagating error", exception) log.debug("Propagating error", exception)
} }
val pendingSendAcks = CountUpDownLatch(0)
for (sessionState in action.sessions) { for (sessionState in action.sessions) {
// We cannot propagate if the session isn't live. // We cannot propagate if the session isn't live.
if (sessionState.initiatedState !is InitiatedSessionState.Live) { if (sessionState.initiatedState !is InitiatedSessionState.Live) {
@ -133,14 +132,9 @@ class ActionExecutorImpl(
val sinkSessionId = sessionState.initiatedState.peerSinkSessionId val sinkSessionId = sessionState.initiatedState.peerSinkSessionId
val existingMessage = ExistingSessionMessage(sinkSessionId, errorMessage) val existingMessage = ExistingSessionMessage(sinkSessionId, errorMessage)
val deduplicationId = DeduplicationId.createForError(errorMessage.errorId, sinkSessionId) val deduplicationId = DeduplicationId.createForError(errorMessage.errorId, sinkSessionId)
pendingSendAcks.countUp() flowMessaging.sendSessionMessage(sessionState.peerParty, existingMessage, deduplicationId)
flowMessaging.sendSessionMessage(sessionState.peerParty, existingMessage, deduplicationId) {
pendingSendAcks.countDown()
}
} }
} }
// TODO we simply block here, perhaps this should be explicit in the worker state
pendingSendAcks.await()
} }
@Suspendable @Suspendable
@ -163,12 +157,12 @@ class ActionExecutorImpl(
@Suspendable @Suspendable
private fun executeSendInitial(action: Action.SendInitial) { private fun executeSendInitial(action: Action.SendInitial) {
flowMessaging.sendSessionMessage(action.party, action.initialise, action.deduplicationId, null) flowMessaging.sendSessionMessage(action.party, action.initialise, action.deduplicationId)
} }
@Suspendable @Suspendable
private fun executeSendExisting(action: Action.SendExisting) { private fun executeSendExisting(action: Action.SendExisting) {
flowMessaging.sendSessionMessage(action.peerParty, action.message, action.deduplicationId, null) flowMessaging.sendSessionMessage(action.peerParty, action.message, action.deduplicationId)
} }
@Suspendable @Suspendable

View File

@ -1,5 +1,6 @@
package net.corda.node.services.statemachine package net.corda.node.services.statemachine
import co.paralleluniverse.fibers.Suspendable
import com.esotericsoftware.kryo.KryoException import com.esotericsoftware.kryo.KryoException
import net.corda.core.flows.FlowException import net.corda.core.flows.FlowException
import net.corda.core.identity.Party import net.corda.core.identity.Party
@ -20,7 +21,8 @@ interface FlowMessaging {
* Send [message] to [party] using [deduplicationId]. Optionally [acknowledgementHandler] may be specified to * Send [message] to [party] using [deduplicationId]. Optionally [acknowledgementHandler] may be specified to
* listen on the send acknowledgement. * listen on the send acknowledgement.
*/ */
fun sendSessionMessage(party: Party, message: SessionMessage, deduplicationId: DeduplicationId, acknowledgementHandler: (() -> Unit)?) @Suspendable
fun sendSessionMessage(party: Party, message: SessionMessage, deduplicationId: DeduplicationId)
/** /**
* Start the messaging using the [onMessage] message handler. * Start the messaging using the [onMessage] message handler.
@ -45,7 +47,8 @@ class FlowMessagingImpl(val serviceHub: ServiceHubInternal): FlowMessaging {
} }
} }
override fun sendSessionMessage(party: Party, message: SessionMessage, deduplicationId: DeduplicationId, acknowledgementHandler: (() -> Unit)?) { @Suspendable
override fun sendSessionMessage(party: Party, message: SessionMessage, deduplicationId: DeduplicationId) {
log.trace { "Sending message $deduplicationId $message to party $party" } log.trace { "Sending message $deduplicationId $message to party $party" }
val networkMessage = serviceHub.networkService.createMessage(sessionTopic, serializeSessionMessage(message).bytes, deduplicationId) val networkMessage = serviceHub.networkService.createMessage(sessionTopic, serializeSessionMessage(message).bytes, deduplicationId)
val partyInfo = serviceHub.networkMapCache.getPartyInfo(party) ?: throw IllegalArgumentException("Don't know about $party") val partyInfo = serviceHub.networkMapCache.getPartyInfo(party) ?: throw IllegalArgumentException("Don't know about $party")
@ -54,7 +57,7 @@ class FlowMessagingImpl(val serviceHub: ServiceHubInternal): FlowMessaging {
is InitialSessionMessage -> message.initiatorSessionId is InitialSessionMessage -> message.initiatorSessionId
is ExistingSessionMessage -> message.recipientSessionId is ExistingSessionMessage -> message.recipientSessionId
} }
serviceHub.networkService.send(networkMessage, address, sequenceKey = sequenceKey, acknowledgementHandler = acknowledgementHandler) serviceHub.networkService.send(networkMessage, address, sequenceKey = sequenceKey)
} }
private fun serializeSessionMessage(message: SessionMessage): SerializedBytes<SessionMessage> { private fun serializeSessionMessage(message: SessionMessage): SerializedBytes<SessionMessage> {

View File

@ -7,7 +7,6 @@ import co.paralleluniverse.fibers.Suspendable
import co.paralleluniverse.strands.Strand import co.paralleluniverse.strands.Strand
import co.paralleluniverse.strands.channels.Channel import co.paralleluniverse.strands.channels.Channel
import com.codahale.metrics.Counter import com.codahale.metrics.Counter
import com.codahale.metrics.Metric
import net.corda.core.concurrent.CordaFuture import net.corda.core.concurrent.CordaFuture
import net.corda.core.context.InvocationContext import net.corda.core.context.InvocationContext
import net.corda.core.flows.* import net.corda.core.flows.*
@ -43,10 +42,7 @@ class TransientReference<out A>(@Transient val value: A)
class FlowStateMachineImpl<R>(override val id: StateMachineRunId, class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
override val logic: FlowLogic<R>, override val logic: FlowLogic<R>,
scheduler: FiberScheduler, scheduler: FiberScheduler
private val totalSuccessMetric: Counter,
private val totalErrorMetric: Counter
// Store the Party rather than the full cert path with PartyAndCertificate
) : Fiber<Unit>(id.toString(), scheduler), FlowStateMachine<R>, FlowFiber { ) : Fiber<Unit>(id.toString(), scheduler), FlowStateMachine<R>, FlowFiber {
companion object { companion object {
/** /**
@ -55,18 +51,6 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
fun currentStateMachine(): FlowStateMachineImpl<*>? = Strand.currentStrand() as? FlowStateMachineImpl<*> fun currentStateMachine(): FlowStateMachineImpl<*>? = Strand.currentStrand() as? FlowStateMachineImpl<*>
private val log: Logger = LoggerFactory.getLogger("net.corda.flow") private val log: Logger = LoggerFactory.getLogger("net.corda.flow")
@Suspendable
private fun abortFiber(): Nothing {
Fiber.park()
throw IllegalStateException("Ended fiber unparked")
}
private fun extractThreadLocalTransaction(): TransientReference<DatabaseTransaction> {
val transaction = contextTransaction
contextTransactionOrNull = null
return TransientReference(transaction)
}
} }
override val serviceHub get() = getTransientField(TransientValues::serviceHub) override val serviceHub get() = getTransientField(TransientValues::serviceHub)
@ -90,6 +74,12 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
return field.get(suppliedValues.value) return field.get(suppliedValues.value)
} }
private fun extractThreadLocalTransaction(): TransientReference<DatabaseTransaction> {
val transaction = contextTransaction
contextTransactionOrNull = null
return TransientReference(transaction)
}
/** /**
* Return the logger for this state machine. The logger name incorporates [id] and so including it in the log message * Return the logger for this state machine. The logger name incorporates [id] and so including it in the log message
* is not necessary. * is not necessary.
@ -145,8 +135,7 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
val startTime = System.nanoTime() val startTime = System.nanoTime()
val resultOrError = try { val resultOrError = try {
val result = logic.call() val result = logic.call()
// TODO expose maySkipCheckpoint here suspend(FlowIORequest.WaitForSessionConfirmations, maySkipCheckpoint = true)
suspend(FlowIORequest.WaitForSessionConfirmations, maySkipCheckpoint = false)
Try.Success(result) Try.Success(result)
} catch (throwable: Throwable) { } catch (throwable: Throwable) {
logger.warn("Flow threw exception", throwable) logger.warn("Flow threw exception", throwable)
@ -154,15 +143,13 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
} }
val finalEvent = when (resultOrError) { val finalEvent = when (resultOrError) {
is Try.Success -> { is Try.Success -> {
totalSuccessMetric.inc()
Event.FlowFinish(resultOrError.value) Event.FlowFinish(resultOrError.value)
} }
is Try.Failure -> { is Try.Failure -> {
totalErrorMetric.inc()
Event.Error(resultOrError.exception) Event.Error(resultOrError.exception)
} }
} }
processEvent(getTransientField(TransientValues::transitionExecutor), finalEvent) scheduleEvent(finalEvent)
processEventsUntilFlowIsResumed() processEventsUntilFlowIsResumed()
recordDuration(startTime) recordDuration(startTime)
@ -192,6 +179,13 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
return resume.result as FlowSession return resume.result as FlowSession
} }
@Suspendable
private fun abortFiber(): Nothing {
while (true) {
Fiber.park()
}
}
// TODO Dummy implementation of access to application specific permission controls and audit logging // TODO Dummy implementation of access to application specific permission controls and audit logging
override fun checkFlowPermission(permissionName: String, extraAuditData: Map<String, String>) { override fun checkFlowPermission(permissionName: String, extraAuditData: Map<String, String>) {
val permissionGranted = true // TODO define permission control service on ServiceHubInternal and actually check authorization. val permissionGranted = true // TODO define permission control service on ServiceHubInternal and actually check authorization.
@ -257,7 +251,7 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
require(processEvent(transitionExecutor.value, event) == FlowContinuation.ProcessEvents) require(processEvent(transitionExecutor.value, event) == FlowContinuation.ProcessEvents)
Fiber.unparkDeserialized(this, scheduler) Fiber.unparkDeserialized(this, scheduler)
} }
return processEventsUntilFlowIsResumed() as R return uncheckedCast(processEventsUntilFlowIsResumed())
} }
@Suspendable @Suspendable

View File

@ -0,0 +1,668 @@
package net.corda.node.services.statemachine
import co.paralleluniverse.fibers.Fiber
import co.paralleluniverse.fibers.FiberExecutorScheduler
import co.paralleluniverse.fibers.Suspendable
import co.paralleluniverse.fibers.instrument.SuspendableHelper
import co.paralleluniverse.strands.channels.Channels
import com.codahale.metrics.Gauge
import net.corda.core.concurrent.CordaFuture
import net.corda.core.context.InvocationContext
import net.corda.core.flows.FlowException
import net.corda.core.flows.FlowInfo
import net.corda.core.flows.FlowLogic
import net.corda.core.flows.StateMachineRunId
import net.corda.core.identity.Party
import net.corda.core.internal.*
import net.corda.core.internal.concurrent.OpenFuture
import net.corda.core.internal.concurrent.map
import net.corda.core.internal.concurrent.openFuture
import net.corda.core.messaging.DataFeed
import net.corda.core.serialization.*
import net.corda.core.utilities.ProgressTracker
import net.corda.core.utilities.Try
import net.corda.core.utilities.contextLogger
import net.corda.core.utilities.debug
import net.corda.node.internal.InitiatedFlowFactory
import net.corda.node.services.api.CheckpointStorage
import net.corda.node.services.api.ServiceHubInternal
import net.corda.node.services.config.shouldCheckCheckpoints
import net.corda.node.services.messaging.AcknowledgeHandle
import net.corda.node.services.messaging.ReceivedMessage
import net.corda.node.services.statemachine.interceptors.*
import net.corda.node.services.statemachine.transitions.StateMachine
import net.corda.node.services.statemachine.transitions.StateMachineConfiguration
import net.corda.node.utilities.AffinityExecutor
import net.corda.nodeapi.internal.persistence.CordaPersistence
import net.corda.nodeapi.internal.serialization.SerializeAsTokenContextImpl
import net.corda.nodeapi.internal.serialization.withTokenContext
import org.apache.activemq.artemis.utils.ReusableLatch
import rx.Observable
import rx.subjects.PublishSubject
import java.security.SecureRandom
import java.util.*
import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.ExecutorService
import javax.annotation.concurrent.ThreadSafe
import kotlin.collections.ArrayList
import kotlin.streams.toList
/**
* The StateMachineManagerImpl will always invoke the flow fibers on the given [AffinityExecutor], regardless of which
* thread actually starts them via [startFlow].
*/
@ThreadSafe
class MultiThreadedStateMachineManager(
val serviceHub: ServiceHubInternal,
val checkpointStorage: CheckpointStorage,
val executor: ExecutorService,
val database: CordaPersistence,
val secureRandom: SecureRandom,
private val unfinishedFibers: ReusableLatch = ReusableLatch(),
private val classloader: ClassLoader = MultiThreadedStateMachineManager::class.java.classLoader
) : StateMachineManager, StateMachineManagerInternal {
companion object {
private val logger = contextLogger()
}
private class Flow(val fiber: FlowStateMachineImpl<*>, val resultFuture: OpenFuture<Any?>)
private enum class State {
UNSTARTED,
STARTED,
STOPPING,
STOPPED
}
private val lifeCycle = LifeCycle(State.UNSTARTED)
private class InnerState {
val flows = ConcurrentHashMap<StateMachineRunId, Flow>()
val startedFutures = ConcurrentHashMap<StateMachineRunId, OpenFuture<Unit>>()
val changesPublisher = PublishSubject.create<StateMachineManager.Change>()!!
}
private val concurrentBox = ConcurrentBox(InnerState())
private val scheduler = FiberExecutorScheduler("Flow fiber scheduler", executor)
// How many Fibers are running and not suspended. If zero and stopping is true, then we are halted.
private val liveFibers = ReusableLatch()
// Monitoring support.
private val metrics = serviceHub.monitoringService.metrics
private val sessionToFlow = ConcurrentHashMap<SessionId, StateMachineRunId>()
private val flowMessaging: FlowMessaging = FlowMessagingImpl(serviceHub)
private val fiberDeserializationChecker = if (serviceHub.configuration.shouldCheckCheckpoints()) FiberDeserializationChecker() else null
private val transitionExecutor = makeTransitionExecutor()
private var checkpointSerializationContext: SerializationContext? = null
private var tokenizableServices: List<Any>? = null
private var actionExecutor: ActionExecutor? = null
override val allStateMachines: List<FlowLogic<*>>
get() = concurrentBox.content.flows.values.map { it.fiber.logic }
private val totalStartedFlows = metrics.counter("Flows.Started")
private val totalFinishedFlows = metrics.counter("Flows.Finished")
private val totalSuccessFlows = metrics.counter("Flows.Success")
private val totalErrorFlows = metrics.counter("Flows.Error")
/**
* An observable that emits triples of the changing flow, the type of change, and a process-specific ID number
* which may change across restarts.
*
* We use assignment here so that multiple subscribers share the same wrapped Observable.
*/
override val changes: Observable<StateMachineManager.Change> = concurrentBox.content.changesPublisher
override fun start(tokenizableServices: List<Any>) {
checkQuasarJavaAgentPresence()
this.tokenizableServices = tokenizableServices
val checkpointSerializationContext = SerializationDefaults.CHECKPOINT_CONTEXT.withTokenContext(
SerializeAsTokenContextImpl(tokenizableServices, SerializationDefaults.SERIALIZATION_FACTORY, SerializationDefaults.CHECKPOINT_CONTEXT, serviceHub)
)
this.checkpointSerializationContext = checkpointSerializationContext
this.actionExecutor = makeActionExecutor(checkpointSerializationContext)
fiberDeserializationChecker?.start(checkpointSerializationContext)
val fibers = restoreFlowsFromCheckpoints()
metrics.register("Flows.InFlight", Gauge<Int> { concurrentBox.content.flows.size })
Fiber.setDefaultUncaughtExceptionHandler { fiber, throwable ->
(fiber as FlowStateMachineImpl<*>).logger.warn("Caught exception from flow", throwable)
}
serviceHub.networkMapCache.nodeReady.then {
resumeRestoredFlows(fibers)
flowMessaging.start { receivedMessage, acknowledgeHandle ->
lifeCycle.requireState(State.STARTED) {
onSessionMessage(receivedMessage, acknowledgeHandle)
}
}
}
lifeCycle.transition(State.UNSTARTED, State.STARTED)
}
override fun <A : FlowLogic<*>> findStateMachines(flowClass: Class<A>): List<Pair<A, CordaFuture<*>>> {
return concurrentBox.content.flows.values.mapNotNull {
flowClass.castIfPossible(it.fiber.logic)?.let { it to it.stateMachine.resultFuture }
}
}
/**
* Start the shutdown process, bringing the [MultiThreadedStateMachineManager] to a controlled stop. When this method returns,
* all Fibers have been suspended and checkpointed, or have completed.
*
* @param allowedUnsuspendedFiberCount Optional parameter is used in some tests.
*/
override fun stop(allowedUnsuspendedFiberCount: Int) {
require(allowedUnsuspendedFiberCount >= 0)
lifeCycle.transition(State.STARTED, State.STOPPING)
for ((_, flow) in concurrentBox.content.flows) {
flow.fiber.scheduleEvent(Event.SoftShutdown)
}
// Account for any expected Fibers in a test scenario.
liveFibers.countDown(allowedUnsuspendedFiberCount)
liveFibers.await()
fiberDeserializationChecker?.let {
val foundUnrestorableFibers = it.stop()
check(!foundUnrestorableFibers) { "Unrestorable checkpoints were created, please check the logs for details." }
}
lifeCycle.transition(State.STOPPING, State.STOPPED)
}
/**
* Atomic get snapshot + subscribe. This is needed so we don't miss updates between subscriptions to [changes] and
* calls to [allStateMachines]
*/
override fun track(): DataFeed<List<FlowLogic<*>>, StateMachineManager.Change> {
return concurrentBox.exclusive {
DataFeed(flows.values.map { it.fiber.logic }, changesPublisher.bufferUntilSubscribed())
}
}
override fun <A> startFlow(
flowLogic: FlowLogic<A>,
context: InvocationContext,
ourIdentity: Party?
): CordaFuture<FlowStateMachine<A>> {
return lifeCycle.requireState(State.STARTED) {
startFlowInternal(
invocationContext = context,
flowLogic = flowLogic,
flowStart = FlowStart.Explicit,
ourIdentity = ourIdentity ?: getOurFirstIdentity(),
initialUnacknowledgedMessage = null,
isStartIdempotent = false
)
}
}
override fun killFlow(id: StateMachineRunId): Boolean {
concurrentBox.concurrent {
val flow = flows.remove(id)
if (flow != null) {
logger.debug("Killing flow known to physical node.")
decrementLiveFibers()
totalFinishedFlows.inc()
unfinishedFibers.countDown()
try {
flow.fiber.interrupt()
return true
} finally {
database.transaction {
checkpointStorage.removeCheckpoint(id)
}
}
} else {
// TODO replace with a clustered delete after we'll support clustered nodes
logger.debug("Unable to kill a flow unknown to physical node. Might be processed by another physical node.")
return false
}
}
}
override fun addSessionBinding(flowId: StateMachineRunId, sessionId: SessionId) {
val previousFlowId = sessionToFlow.put(sessionId, flowId)
if (previousFlowId != null) {
if (previousFlowId == flowId) {
logger.warn("Session binding from $sessionId to $flowId re-added")
} else {
throw IllegalStateException(
"Attempted to add session binding from session $sessionId to flow $flowId, " +
"however there was already a binding to $previousFlowId"
)
}
}
}
override fun removeSessionBindings(sessionIds: Set<SessionId>) {
val reRemovedSessionIds = HashSet<SessionId>()
for (sessionId in sessionIds) {
val flowId = sessionToFlow.remove(sessionId)
if (flowId == null) {
reRemovedSessionIds.add(sessionId)
}
}
if (reRemovedSessionIds.isNotEmpty()) {
logger.warn("Session binding from $reRemovedSessionIds re-removed")
}
}
override fun removeFlow(flowId: StateMachineRunId, removalReason: FlowRemovalReason, lastState: StateMachineState) {
concurrentBox.concurrent {
val flow = flows.remove(flowId)
if (flow != null) {
decrementLiveFibers()
totalFinishedFlows.inc()
unfinishedFibers.countDown()
return when (removalReason) {
is FlowRemovalReason.OrderlyFinish -> removeFlowOrderly(flow, removalReason, lastState)
is FlowRemovalReason.ErrorFinish -> removeFlowError(flow, removalReason, lastState)
FlowRemovalReason.SoftShutdown -> flow.fiber.scheduleEvent(Event.SoftShutdown)
}
} else {
logger.warn("Flow $flowId re-finished")
}
}
}
override fun signalFlowHasStarted(flowId: StateMachineRunId) {
concurrentBox.concurrent {
startedFutures.remove(flowId)?.set(Unit)
}
}
private fun checkQuasarJavaAgentPresence() {
check(SuspendableHelper.isJavaAgentActive(), {
"""Missing the '-javaagent' JVM argument. Make sure you run the tests with the Quasar java agent attached to your JVM.
#See https://docs.corda.net/troubleshooting.html - 'Fiber classes not instrumented' for more details.""".trimMargin("#")
})
}
private fun decrementLiveFibers() {
liveFibers.countDown()
}
private fun incrementLiveFibers() {
liveFibers.countUp()
}
private fun restoreFlowsFromCheckpoints(): List<Flow> {
return checkpointStorage.getAllCheckpoints().map { (id, serializedCheckpoint) ->
// If a flow is added before start() then don't attempt to restore it
if (concurrentBox.content.flows.containsKey(id)) return@map null
val checkpoint = deserializeCheckpoint(serializedCheckpoint)
if (checkpoint == null) return@map null
createFlowFromCheckpoint(
id = id,
checkpoint = checkpoint,
initialUnacknowledgedMessage = null,
isAnyCheckpointPersisted = true,
isStartIdempotent = false
)
}.toList().filterNotNull()
}
private fun resumeRestoredFlows(flows: List<Flow>) {
for (flow in flows) {
addAndStartFlow(flow.fiber.id, flow)
}
}
private fun onSessionMessage(message: ReceivedMessage, acknowledgeHandle: AcknowledgeHandle) {
val peer = message.peer
val sessionMessage = try {
message.data.deserialize<SessionMessage>()
} catch (ex: Exception) {
logger.error("Received corrupt SessionMessage data from $peer")
acknowledgeHandle.acknowledge()
return
}
val sender = serviceHub.networkMapCache.getPeerByLegalName(peer)
if (sender != null) {
when (sessionMessage) {
is ExistingSessionMessage -> onExistingSessionMessage(sessionMessage, acknowledgeHandle, sender)
is InitialSessionMessage -> onSessionInit(sessionMessage, message.platformVersion, acknowledgeHandle, sender)
}
} else {
logger.error("Unknown peer $peer in $sessionMessage")
}
}
private fun onExistingSessionMessage(sessionMessage: ExistingSessionMessage, acknowledgeHandle: AcknowledgeHandle, sender: Party) {
try {
val recipientId = sessionMessage.recipientSessionId
val flowId = sessionToFlow[recipientId]
if (flowId == null) {
acknowledgeHandle.acknowledge()
if (sessionMessage.payload is EndSessionMessage) {
logger.debug {
"Got ${EndSessionMessage::class.java.simpleName} for " +
"unknown session $recipientId, discarding..."
}
} else {
throw IllegalArgumentException("Cannot find flow corresponding to session ID $recipientId")
}
} else {
val flow = concurrentBox.content.flows[flowId] ?: throw IllegalStateException("Cannot find fiber corresponding to ID $flowId")
flow.fiber.scheduleEvent(Event.DeliverSessionMessage(sessionMessage, acknowledgeHandle, sender))
}
} catch (exception: Exception) {
logger.error("Exception while routing $sessionMessage", exception)
throw exception
}
}
private fun onSessionInit(sessionMessage: InitialSessionMessage, senderPlatformVersion: Int, acknowledgeHandle: AcknowledgeHandle, sender: Party) {
fun createErrorMessage(initiatorSessionId: SessionId, message: String): ExistingSessionMessage {
val errorId = secureRandom.nextLong()
val payload = RejectSessionMessage(message, errorId)
return ExistingSessionMessage(initiatorSessionId, payload)
}
val replyError = try {
val initiatedFlowFactory = getInitiatedFlowFactory(sessionMessage)
val initiatedSessionId = SessionId.createRandom(secureRandom)
val senderSession = FlowSessionImpl(sender, initiatedSessionId)
val flowLogic = initiatedFlowFactory.createFlow(senderSession)
val initiatedFlowInfo = when (initiatedFlowFactory) {
is InitiatedFlowFactory.Core -> FlowInfo(serviceHub.myInfo.platformVersion, "corda")
is InitiatedFlowFactory.CorDapp -> FlowInfo(initiatedFlowFactory.flowVersion, initiatedFlowFactory.appName)
}
val senderCoreFlowVersion = when (initiatedFlowFactory) {
is InitiatedFlowFactory.Core -> senderPlatformVersion
is InitiatedFlowFactory.CorDapp -> null
}
startInitiatedFlow(flowLogic, acknowledgeHandle, senderSession, initiatedSessionId, sessionMessage, senderCoreFlowVersion, initiatedFlowInfo)
null
} catch (exception: Exception) {
logger.warn("Exception while creating initiated flow", exception)
createErrorMessage(
sessionMessage.initiatorSessionId,
(exception as? SessionRejectException)?.message ?: "Unable to establish session"
)
}
if (replyError != null) {
flowMessaging.sendSessionMessage(sender, replyError, DeduplicationId.createRandom(secureRandom))
acknowledgeHandle.acknowledge()
}
}
// TODO this is a temporary hack until we figure out multiple identities
private fun getOurFirstIdentity(): Party {
return serviceHub.myInfo.legalIdentities[0]
}
private fun getInitiatedFlowFactory(message: InitialSessionMessage): InitiatedFlowFactory<*> {
val initiatingFlowClass = try {
Class.forName(message.initiatorFlowClassName, true, classloader).asSubclass(FlowLogic::class.java)
} catch (e: ClassNotFoundException) {
throw SessionRejectException("Don't know ${message.initiatorFlowClassName}")
} catch (e: ClassCastException) {
throw SessionRejectException("${message.initiatorFlowClassName} is not a flow")
}
return serviceHub.getFlowFactory(initiatingFlowClass) ?:
throw SessionRejectException("$initiatingFlowClass is not registered")
}
private fun <A> startInitiatedFlow(
flowLogic: FlowLogic<A>,
triggeringUnacknowledgedMessage: AcknowledgeHandle,
peerSession: FlowSessionImpl,
initiatedSessionId: SessionId,
initiatingMessage: InitialSessionMessage,
senderCoreFlowVersion: Int?,
initiatedFlowInfo: FlowInfo
) {
val flowStart = FlowStart.Initiated(peerSession, initiatedSessionId, initiatingMessage, senderCoreFlowVersion, initiatedFlowInfo)
val ourIdentity = getOurFirstIdentity()
startFlowInternal(
InvocationContext.peer(peerSession.counterparty.name), flowLogic, flowStart, ourIdentity,
triggeringUnacknowledgedMessage,
isStartIdempotent = false
)
}
private fun <A> startFlowInternal(
invocationContext: InvocationContext,
flowLogic: FlowLogic<A>,
flowStart: FlowStart,
ourIdentity: Party,
initialUnacknowledgedMessage: AcknowledgeHandle?,
isStartIdempotent: Boolean
): CordaFuture<FlowStateMachine<A>> {
val flowId = StateMachineRunId.createRandom()
val deduplicationSeed = when (flowStart) {
FlowStart.Explicit -> flowId.uuid.toString()
is FlowStart.Initiated ->
"${flowStart.initiatingMessage.initiatorSessionId.toLong}-" +
"${flowStart.initiatingMessage.initiationEntropy}"
}
// Before we construct the state machine state by freezing the FlowLogic we need to make sure that lazy properties
// have access to the fiber (and thereby the service hub)
val flowStateMachineImpl = FlowStateMachineImpl(flowId, flowLogic, scheduler)
val resultFuture = openFuture<Any?>()
flowStateMachineImpl.transientValues = TransientReference(createTransientValues(flowId, resultFuture))
flowLogic.stateMachine = flowStateMachineImpl
val frozenFlowLogic = (flowLogic as FlowLogic<*>).serialize(context = checkpointSerializationContext!!)
val initialCheckpoint = Checkpoint.create(invocationContext, flowStart, flowLogic.javaClass, frozenFlowLogic, ourIdentity, deduplicationSeed).getOrThrow()
val startedFuture = openFuture<Unit>()
val initialState = StateMachineState(
checkpoint = initialCheckpoint,
unacknowledgedMessages = initialUnacknowledgedMessage?.let { listOf(it) } ?: emptyList(),
isFlowResumed = false,
isTransactionTracked = false,
isAnyCheckpointPersisted = false,
isStartIdempotent = isStartIdempotent,
isRemoved = false,
flowLogic = flowLogic
)
flowStateMachineImpl.transientState = TransientReference(initialState)
concurrentBox.concurrent {
startedFutures[flowId] = startedFuture
}
totalStartedFlows.inc()
addAndStartFlow(flowId, Flow(flowStateMachineImpl, resultFuture))
return startedFuture.map { flowStateMachineImpl as FlowStateMachine<A> }
}
private fun deserializeCheckpoint(serializedCheckpoint: SerializedBytes<Checkpoint>): Checkpoint? {
return try {
serializedCheckpoint.deserialize(context = checkpointSerializationContext!!)
} catch (exception: Throwable) {
logger.error("Encountered unrestorable checkpoint!", exception)
null
}
}
private fun verifyFlowLogicIsSuspendable(logic: FlowLogic<Any?>) {
// Quasar requires (in Java 8) that at least the call method be annotated suspendable. Unfortunately, it's
// easy to forget to add this when creating a new flow, so we check here to give the user a better error.
//
// The Kotlin compiler can sometimes generate a synthetic bridge method from a single call declaration, which
// forwards to the void method and then returns Unit. However annotations do not get copied across to this
// bridge, so we have to do a more complex scan here.
val call = logic.javaClass.methods.first { !it.isSynthetic && it.name == "call" && it.parameterCount == 0 }
if (call.getAnnotation(Suspendable::class.java) == null) {
throw FlowException("${logic.javaClass.name}.call() is not annotated as @Suspendable. Please fix this.")
}
}
private fun createTransientValues(id: StateMachineRunId, resultFuture: CordaFuture<Any?>): FlowStateMachineImpl.TransientValues {
return FlowStateMachineImpl.TransientValues(
eventQueue = Channels.newChannel(16, Channels.OverflowPolicy.BLOCK),
resultFuture = resultFuture,
database = database,
transitionExecutor = transitionExecutor,
actionExecutor = actionExecutor!!,
stateMachine = StateMachine(id, StateMachineConfiguration.default, secureRandom),
serviceHub = serviceHub,
checkpointSerializationContext = checkpointSerializationContext!!
)
}
private fun createFlowFromCheckpoint(
id: StateMachineRunId,
checkpoint: Checkpoint,
isAnyCheckpointPersisted: Boolean,
isStartIdempotent: Boolean,
initialUnacknowledgedMessage: AcknowledgeHandle?
): Flow {
val flowState = checkpoint.flowState
val resultFuture = openFuture<Any?>()
val fiber = when (flowState) {
is FlowState.Unstarted -> {
val logic = flowState.frozenFlowLogic.deserialize(context = checkpointSerializationContext!!)
val state = StateMachineState(
checkpoint = checkpoint,
unacknowledgedMessages = initialUnacknowledgedMessage?.let { listOf(it) } ?: emptyList(),
isFlowResumed = false,
isTransactionTracked = false,
isAnyCheckpointPersisted = isAnyCheckpointPersisted,
isStartIdempotent = isStartIdempotent,
isRemoved = false,
flowLogic = logic
)
val fiber = FlowStateMachineImpl(id, logic, scheduler)
fiber.transientValues = TransientReference(createTransientValues(id, resultFuture))
fiber.transientState = TransientReference(state)
fiber.logic.stateMachine = fiber
fiber
}
is FlowState.Started -> {
val fiber = flowState.frozenFiber.deserialize(context = checkpointSerializationContext!!)
val state = StateMachineState(
checkpoint = checkpoint,
unacknowledgedMessages = initialUnacknowledgedMessage?.let { listOf(it) } ?: emptyList(),
isFlowResumed = false,
isTransactionTracked = false,
isAnyCheckpointPersisted = isAnyCheckpointPersisted,
isStartIdempotent = isStartIdempotent,
isRemoved = false,
flowLogic = fiber.logic
)
fiber.transientValues = TransientReference(createTransientValues(id, resultFuture))
fiber.transientState = TransientReference(state)
fiber.logic.stateMachine = fiber
fiber
}
}
verifyFlowLogicIsSuspendable(fiber.logic)
return Flow(fiber, resultFuture)
}
private fun addAndStartFlow(id: StateMachineRunId, flow: Flow) {
val checkpoint = flow.fiber.snapshot().checkpoint
for (sessionId in getFlowSessionIds(checkpoint)) {
sessionToFlow.put(sessionId, id)
}
concurrentBox.concurrent {
incrementLiveFibers()
unfinishedFibers.countUp()
flows.put(id, flow)
flow.fiber.scheduleEvent(Event.DoRemainingWork)
when (checkpoint.flowState) {
is FlowState.Unstarted -> {
flow.fiber.start()
}
is FlowState.Started -> {
Fiber.unparkDeserialized(flow.fiber, scheduler)
}
}
changesPublisher.onNext(StateMachineManager.Change.Add(flow.fiber.logic))
}
}
private fun getFlowSessionIds(checkpoint: Checkpoint): Set<SessionId> {
val initiatedFlowStart = (checkpoint.flowState as? FlowState.Unstarted)?.flowStart as? FlowStart.Initiated
return if (initiatedFlowStart == null) {
checkpoint.sessions.keys
} else {
checkpoint.sessions.keys + initiatedFlowStart.initiatedSessionId
}
}
private fun makeActionExecutor(checkpointSerializationContext: SerializationContext): ActionExecutor {
return ActionExecutorImpl(
serviceHub,
checkpointStorage,
flowMessaging,
this,
checkpointSerializationContext,
metrics
)
}
private fun makeTransitionExecutor(): TransitionExecutor {
val interceptors = ArrayList<TransitionInterceptor>()
interceptors.add { HospitalisingInterceptor(PropagatingFlowHospital, it) }
if (serviceHub.configuration.devMode) {
interceptors.add { DumpHistoryOnErrorInterceptor(it) }
}
if (serviceHub.configuration.shouldCheckCheckpoints()) {
interceptors.add { FiberDeserializationCheckingInterceptor(fiberDeserializationChecker!!, it) }
}
if (logger.isDebugEnabled) {
interceptors.add { PrintingInterceptor(it) }
}
val transitionExecutor: TransitionExecutor = TransitionExecutorImpl(secureRandom, database)
return interceptors.fold(transitionExecutor) { executor, interceptor -> interceptor(executor) }
}
private fun InnerState.removeFlowOrderly(
flow: Flow,
removalReason: FlowRemovalReason.OrderlyFinish,
lastState: StateMachineState
) {
totalSuccessFlows.inc()
drainFlowEventQueue(flow)
// final sanity checks
require(lastState.unacknowledgedMessages.isEmpty())
require(lastState.isRemoved)
require(lastState.checkpoint.subFlowStack.size == 1)
sessionToFlow.none { it.value == flow.fiber.id }
flow.resultFuture.set(removalReason.flowReturnValue)
lastState.flowLogic.progressTracker?.currentStep = ProgressTracker.DONE
changesPublisher.onNext(StateMachineManager.Change.Removed(lastState.flowLogic, Try.Success(removalReason.flowReturnValue)))
}
private fun InnerState.removeFlowError(
flow: Flow,
removalReason: FlowRemovalReason.ErrorFinish,
lastState: StateMachineState
) {
totalErrorFlows.inc()
drainFlowEventQueue(flow)
val flowError = removalReason.flowErrors[0] // TODO what to do with several?
val exception = flowError.exception
(exception as? FlowException)?.originalErrorId = flowError.errorId
flow.resultFuture.setException(exception)
lastState.flowLogic.progressTracker?.endWithError(exception)
changesPublisher.onNext(StateMachineManager.Change.Removed(lastState.flowLogic, Try.Failure<Nothing>(exception)))
}
// The flow's event queue may be non-empty in case it shut down abruptly. We handle outstanding events here.
private fun drainFlowEventQueue(flow: Flow) {
while (true) {
val event = flow.fiber.transientValues!!.value.eventQueue.tryReceive() ?: return
when (event) {
is Event.DoRemainingWork -> {}
is Event.DeliverSessionMessage -> {
// Acknowledge the message so it doesn't leak in the broker.
event.acknowledgeHandle.acknowledge()
when (event.sessionMessage.payload) {
EndSessionMessage -> {
logger.debug { "Unhandled message ${event.sessionMessage} by ${flow.fiber} due to flow shutting down" }
}
else -> {
logger.warn("Unhandled message ${event.sessionMessage} by ${flow.fiber} due to flow shutting down")
}
}
}
else -> {
logger.warn("Unhandled event $event by ${flow.fiber} due to flow shutting down")
}
}
}
}
}

View File

@ -0,0 +1,8 @@
package net.corda.node.services.statemachine
import net.corda.core.CordaException
/**
* An exception propagated and thrown in case a session initiation fails.
*/
class SessionRejectException(reason: String) : CordaException(reason)

View File

@ -6,7 +6,6 @@ import co.paralleluniverse.fibers.Suspendable
import co.paralleluniverse.fibers.instrument.SuspendableHelper import co.paralleluniverse.fibers.instrument.SuspendableHelper
import co.paralleluniverse.strands.channels.Channels import co.paralleluniverse.strands.channels.Channels
import com.codahale.metrics.Gauge import com.codahale.metrics.Gauge
import net.corda.core.CordaException
import net.corda.core.concurrent.CordaFuture import net.corda.core.concurrent.CordaFuture
import net.corda.core.context.InvocationContext import net.corda.core.context.InvocationContext
import net.corda.core.flows.FlowException import net.corda.core.flows.FlowException
@ -46,6 +45,7 @@ import rx.subjects.PublishSubject
import java.security.SecureRandom import java.security.SecureRandom
import java.util.* import java.util.*
import java.util.concurrent.ConcurrentHashMap import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.ExecutorService
import javax.annotation.concurrent.ThreadSafe import javax.annotation.concurrent.ThreadSafe
import kotlin.collections.ArrayList import kotlin.collections.ArrayList
import kotlin.streams.toList import kotlin.streams.toList
@ -55,14 +55,14 @@ import kotlin.streams.toList
* thread actually starts them via [startFlow]. * thread actually starts them via [startFlow].
*/ */
@ThreadSafe @ThreadSafe
class StateMachineManagerImpl( class SingleThreadedStateMachineManager(
val serviceHub: ServiceHubInternal, val serviceHub: ServiceHubInternal,
val checkpointStorage: CheckpointStorage, val checkpointStorage: CheckpointStorage,
val executor: AffinityExecutor, val executor: ExecutorService,
val database: CordaPersistence, val database: CordaPersistence,
val secureRandom: SecureRandom, val secureRandom: SecureRandom,
private val unfinishedFibers: ReusableLatch = ReusableLatch(), private val unfinishedFibers: ReusableLatch = ReusableLatch(),
private val classloader: ClassLoader = StateMachineManagerImpl::class.java.classLoader private val classloader: ClassLoader = SingleThreadedStateMachineManager::class.java.classLoader
) : StateMachineManager, StateMachineManagerInternal { ) : StateMachineManager, StateMachineManagerInternal {
companion object { companion object {
private val logger = contextLogger() private val logger = contextLogger()
@ -145,7 +145,7 @@ class StateMachineManagerImpl(
} }
/** /**
* Start the shutdown process, bringing the [StateMachineManagerImpl] to a controlled stop. When this method returns, * Start the shutdown process, bringing the [SingleThreadedStateMachineManager] to a controlled stop. When this method returns,
* all Fibers have been suspended and checkpointed, or have completed. * all Fibers have been suspended and checkpointed, or have completed.
* *
* @param allowedUnsuspendedFiberCount Optional parameter is used in some tests. * @param allowedUnsuspendedFiberCount Optional parameter is used in some tests.
@ -328,7 +328,6 @@ class StateMachineManagerImpl(
private fun onExistingSessionMessage(sessionMessage: ExistingSessionMessage, acknowledgeHandle: AcknowledgeHandle, sender: Party) { private fun onExistingSessionMessage(sessionMessage: ExistingSessionMessage, acknowledgeHandle: AcknowledgeHandle, sender: Party) {
try { try {
executor.checkOnThread()
val recipientId = sessionMessage.recipientSessionId val recipientId = sessionMessage.recipientSessionId
val flowId = sessionToFlow[recipientId] val flowId = sessionToFlow[recipientId]
if (flowId == null) { if (flowId == null) {
@ -381,7 +380,7 @@ class StateMachineManagerImpl(
} }
if (replyError != null) { if (replyError != null) {
flowMessaging.sendSessionMessage(sender, replyError, DeduplicationId.createRandom(secureRandom), null) flowMessaging.sendSessionMessage(sender, replyError, DeduplicationId.createRandom(secureRandom))
acknowledgeHandle.acknowledge() acknowledgeHandle.acknowledge()
} }
} }
@ -439,7 +438,7 @@ class StateMachineManagerImpl(
// Before we construct the state machine state by freezing the FlowLogic we need to make sure that lazy properties // Before we construct the state machine state by freezing the FlowLogic we need to make sure that lazy properties
// have access to the fiber (and thereby the service hub) // have access to the fiber (and thereby the service hub)
val flowStateMachineImpl = FlowStateMachineImpl(flowId, flowLogic, scheduler, totalSuccessFlows, totalErrorFlows) val flowStateMachineImpl = FlowStateMachineImpl(flowId, flowLogic, scheduler)
val resultFuture = openFuture<Any?>() val resultFuture = openFuture<Any?>()
flowStateMachineImpl.transientValues = TransientReference(createTransientValues(flowId, resultFuture)) flowStateMachineImpl.transientValues = TransientReference(createTransientValues(flowId, resultFuture))
flowLogic.stateMachine = flowStateMachineImpl flowLogic.stateMachine = flowStateMachineImpl
@ -523,7 +522,7 @@ class StateMachineManagerImpl(
isRemoved = false, isRemoved = false,
flowLogic = logic flowLogic = logic
) )
val fiber = FlowStateMachineImpl(id, logic, scheduler, totalSuccessFlows, totalErrorFlows) val fiber = FlowStateMachineImpl(id, logic, scheduler)
fiber.transientValues = TransientReference(createTransientValues(id, resultFuture)) fiber.transientValues = TransientReference(createTransientValues(id, resultFuture))
fiber.transientState = TransientReference(state) fiber.transientState = TransientReference(state)
fiber.logic.stateMachine = fiber fiber.logic.stateMachine = fiber
@ -651,6 +650,7 @@ class StateMachineManagerImpl(
while (true) { while (true) {
val event = flow.fiber.transientValues!!.value.eventQueue.tryReceive() ?: return val event = flow.fiber.transientValues!!.value.eventQueue.tryReceive() ?: return
when (event) { when (event) {
is Event.DoRemainingWork -> {}
is Event.DeliverSessionMessage -> { is Event.DeliverSessionMessage -> {
// Acknowledge the message so it doesn't leak in the broker. // Acknowledge the message so it doesn't leak in the broker.
event.acknowledgeHandle.acknowledge() event.acknowledgeHandle.acknowledge()
@ -670,5 +670,3 @@ class StateMachineManagerImpl(
} }
} }
} }
class SessionRejectException(reason: String) : CordaException(reason)

View File

@ -169,7 +169,7 @@ sealed class FlowState {
val flowStart: FlowStart, val flowStart: FlowStart,
val frozenFlowLogic: SerializedBytes<FlowLogic<*>> val frozenFlowLogic: SerializedBytes<FlowLogic<*>>
) : FlowState() { ) : FlowState() {
override fun toString() = "Unstarted(flowStart=$flowStart, frozenFlowLogic=${frozenFlowLogic.hash}" override fun toString() = "Unstarted(flowStart=$flowStart, frozenFlowLogic=${frozenFlowLogic.hash})"
} }
/** /**
@ -182,7 +182,7 @@ sealed class FlowState {
val flowIORequest: FlowIORequest<*>, val flowIORequest: FlowIORequest<*>,
val frozenFiber: SerializedBytes<FlowStateMachineImpl<*>> val frozenFiber: SerializedBytes<FlowStateMachineImpl<*>>
) : FlowState() { ) : FlowState() {
override fun toString() = "Started(flowIORequest=$flowIORequest, frozenFiber=${frozenFiber.hash}" override fun toString() = "Started(flowIORequest=$flowIORequest, frozenFiber=${frozenFiber.hash})"
} }
} }

View File

@ -20,6 +20,6 @@ interface TransitionExecutor {
} }
/** /**
* An interceptor of a transition. These are currently explicitly hooked up in [StateMachineManagerImpl]. * An interceptor of a transition. These are currently explicitly hooked up in [MultiThreadedStateMachineManager].
*/ */
typealias TransitionInterceptor = (TransitionExecutor) -> TransitionExecutor typealias TransitionInterceptor = (TransitionExecutor) -> TransitionExecutor

View File

@ -35,13 +35,13 @@ class DumpHistoryOnErrorInterceptor(val delegate: TransitionExecutor) : Transiti
} }
if (nextState.checkpoint.errorState is ErrorState.Errored) { if (nextState.checkpoint.errorState is ErrorState.Errored) {
log.warn("Flow ${fiber.id} dirtied, dumping all transitions:\n${record!!.joinToString("\n")}") log.warn("Flow ${fiber.id} errored, dumping all transitions:\n${record!!.joinToString("\n")}")
for (error in nextState.checkpoint.errorState.errors) { for (error in nextState.checkpoint.errorState.errors) {
log.warn("Flow ${fiber.id} error", error.exception) log.warn("Flow ${fiber.id} error", error.exception)
} }
} }
if (transition.newState.isRemoved) { if (nextState.isRemoved) {
records.remove(fiber.id) records.remove(fiber.id)
} }

View File

@ -38,6 +38,9 @@ class HospitalisingInterceptor(
} }
} }
} }
if (nextState.isRemoved) {
hospitalisedFlows.remove(fiber.id)
}
return Pair(continuation, nextState) return Pair(continuation, nextState)
} }
} }

View File

@ -65,7 +65,7 @@ class NodeVaultService(
val updatesPublisher: rx.Observer<Vault.Update<ContractState>> get() = _updatesPublisher.bufferUntilDatabaseCommit().tee(_rawUpdatesPublisher) val updatesPublisher: rx.Observer<Vault.Update<ContractState>> get() = _updatesPublisher.bufferUntilDatabaseCommit().tee(_rawUpdatesPublisher)
} }
private val mutex = ThreadBox(InnerState()) private val concurrentBox = ConcurrentBox(InnerState())
private fun recordUpdate(update: Vault.Update<ContractState>): Vault.Update<ContractState> { private fun recordUpdate(update: Vault.Update<ContractState>): Vault.Update<ContractState> {
if (!update.isEmpty()) { if (!update.isEmpty()) {
@ -103,10 +103,10 @@ class NodeVaultService(
} }
override val rawUpdates: Observable<Vault.Update<ContractState>> override val rawUpdates: Observable<Vault.Update<ContractState>>
get() = mutex.locked { _rawUpdatesPublisher } get() = concurrentBox.content._rawUpdatesPublisher
override val updates: Observable<Vault.Update<ContractState>> override val updates: Observable<Vault.Update<ContractState>>
get() = mutex.locked { _updatesInDbTx } get() = concurrentBox.content._updatesInDbTx
override fun notifyAll(statesToRecord: StatesToRecord, txns: Iterable<CoreTransaction>) { override fun notifyAll(statesToRecord: StatesToRecord, txns: Iterable<CoreTransaction>) {
if (statesToRecord == StatesToRecord.NONE) if (statesToRecord == StatesToRecord.NONE)
@ -205,7 +205,7 @@ class NodeVaultService(
private fun processAndNotify(update: Vault.Update<ContractState>) { private fun processAndNotify(update: Vault.Update<ContractState>) {
if (!update.isEmpty()) { if (!update.isEmpty()) {
recordUpdate(update) recordUpdate(update)
mutex.locked { concurrentBox.concurrent {
// flowId required by SoftLockManager to perform auto-registration of soft locks for new states // flowId required by SoftLockManager to perform auto-registration of soft locks for new states
val uuid = (Strand.currentStrand() as? FlowStateMachineImpl<*>)?.id?.uuid val uuid = (Strand.currentStrand() as? FlowStateMachineImpl<*>)?.id?.uuid
val vaultUpdate = if (uuid != null) update.copy(flowId = uuid) else update val vaultUpdate = if (uuid != null) update.copy(flowId = uuid) else update
@ -387,7 +387,7 @@ class NodeVaultService(
@Throws(VaultQueryException::class) @Throws(VaultQueryException::class)
override fun <T : ContractState> _queryBy(criteria: QueryCriteria, paging: PageSpecification, sorting: Sort, contractStateType: Class<out T>): Vault.Page<T> { override fun <T : ContractState> _queryBy(criteria: QueryCriteria, paging: PageSpecification, sorting: Sort, contractStateType: Class<out T>): Vault.Page<T> {
log.info("Vault Query for contract type: $contractStateType, criteria: $criteria, pagination: $paging, sorting: $sorting") log.debug {"Vault Query for contract type: $contractStateType, criteria: $criteria, pagination: $paging, sorting: $sorting" }
// calculate total results where a page specification has been defined // calculate total results where a page specification has been defined
var totalStates = -1L var totalStates = -1L
if (!paging.isDefault) { if (!paging.isDefault) {
@ -468,7 +468,7 @@ class NodeVaultService(
@Throws(VaultQueryException::class) @Throws(VaultQueryException::class)
override fun <T : ContractState> _trackBy(criteria: QueryCriteria, paging: PageSpecification, sorting: Sort, contractStateType: Class<out T>): DataFeed<Vault.Page<T>, Vault.Update<T>> { override fun <T : ContractState> _trackBy(criteria: QueryCriteria, paging: PageSpecification, sorting: Sort, contractStateType: Class<out T>): DataFeed<Vault.Page<T>, Vault.Update<T>> {
return mutex.locked { return concurrentBox.exclusive {
val snapshotResults = _queryBy(criteria, paging, sorting, contractStateType) val snapshotResults = _queryBy(criteria, paging, sorting, contractStateType)
val updates: Observable<Vault.Update<T>> = uncheckedCast(_updatesPublisher.bufferUntilSubscribed().filter { it.containsType(contractStateType, snapshotResults.stateTypes) }) val updates: Observable<Vault.Update<T>> = uncheckedCast(_updatesPublisher.bufferUntilSubscribed().filter { it.containsType(contractStateType, snapshotResults.stateTypes) })
DataFeed(snapshotResults, updates) DataFeed(snapshotResults, updates)

View File

@ -31,6 +31,14 @@ enterpriseConfiguration = {
updateInterval = 20000 updateInterval = 20000
waitInterval = 40000 waitInterval = 40000
} }
tuning = {
flowThreadPoolSize = 1
rpcThreadPoolSize = 4
maximumMessagingBatchSize = 256
p2pConfirmationWindowSize = 1048576
brokerConnectionTtlCheckIntervalMs = 20
}
useMultiThreadedSMM = true
} }
rpcSettings = { rpcSettings = {
useSsl = false useSsl = false

View File

@ -1,10 +1,7 @@
package net.corda.node.services.events package net.corda.node.services.events
import com.google.common.util.concurrent.MoreExecutors
import com.nhaarman.mockito_kotlin.* import com.nhaarman.mockito_kotlin.*
import net.corda.core.contracts.* import net.corda.core.contracts.*
import net.corda.core.crypto.generateKeyPair
import net.corda.core.crypto.newSecureRandom
import net.corda.core.flows.FlowLogic import net.corda.core.flows.FlowLogic
import net.corda.core.flows.FlowLogicRef import net.corda.core.flows.FlowLogicRef
import net.corda.core.flows.FlowLogicRefFactory import net.corda.core.flows.FlowLogicRefFactory
@ -58,7 +55,6 @@ class NodeSchedulerServiceTest {
database, database,
flowStarter, flowStarter,
stateLoader, stateLoader,
serverThread = MoreExecutors.directExecutor(),
flowLogicRefFactory = flowLogicRefFactory, flowLogicRefFactory = flowLogicRefFactory,
log = log, log = log,
scheduledStates = mutableMapOf()).apply { start() } scheduledStates = mutableMapOf()).apply { start() }

View File

@ -1,17 +1,12 @@
package net.corda.node.services.messaging package net.corda.node.services.messaging
import com.codahale.metrics.MetricRegistry
import com.nhaarman.mockito_kotlin.doReturn import com.nhaarman.mockito_kotlin.doReturn
import com.nhaarman.mockito_kotlin.whenever import com.nhaarman.mockito_kotlin.whenever
import net.corda.core.crypto.generateKeyPair import net.corda.core.crypto.generateKeyPair
import net.corda.core.concurrent.CordaFuture
import com.codahale.metrics.MetricRegistry
import net.corda.core.crypto.generateKeyPair
import net.corda.core.internal.concurrent.openFuture
import net.corda.core.utilities.NetworkHostAndPort import net.corda.core.utilities.NetworkHostAndPort
import net.corda.node.internal.configureDatabase import net.corda.node.internal.configureDatabase
import net.corda.node.services.config.CertChainPolicyConfig import net.corda.node.services.config.*
import net.corda.node.services.config.NodeConfiguration
import net.corda.node.services.config.configureWithDevSSLCertificate
import net.corda.node.services.network.NetworkMapCacheImpl import net.corda.node.services.network.NetworkMapCacheImpl
import net.corda.node.services.network.PersistentNetworkMapCache import net.corda.node.services.network.PersistentNetworkMapCache
import net.corda.node.services.transactions.PersistentUniquenessProvider import net.corda.node.services.transactions.PersistentUniquenessProvider
@ -73,6 +68,7 @@ class ArtemisMessagingTest {
doReturn("").whenever(it).exportJMXto doReturn("").whenever(it).exportJMXto
doReturn(emptyList<CertChainPolicyConfig>()).whenever(it).certificateChainCheckPolicies doReturn(emptyList<CertChainPolicyConfig>()).whenever(it).certificateChainCheckPolicies
doReturn(5).whenever(it).messageRedeliveryDelaySeconds doReturn(5).whenever(it).messageRedeliveryDelaySeconds
doReturn(EnterpriseConfiguration(MutualExclusionConfiguration(false, "", 20000, 40000))).whenever(it).enterpriseConfiguration
} }
LogHelper.setLevel(PersistentUniquenessProvider::class) LogHelper.setLevel(PersistentUniquenessProvider::class)
database = configureDatabase(makeTestDataSourceProperties(), DatabaseConfig(runMigration = true), rigorousMock()) database = configureDatabase(makeTestDataSourceProperties(), DatabaseConfig(runMigration = true), rigorousMock())
@ -176,6 +172,7 @@ class ArtemisMessagingTest {
ServiceAffinityExecutor("ArtemisMessagingTests", 1), ServiceAffinityExecutor("ArtemisMessagingTests", 1),
database, database,
networkMapCache, networkMapCache,
MetricRegistry(),
maxMessageSize = maxMessageSize).apply { maxMessageSize = maxMessageSize).apply {
config.configureWithDevSSLCertificate() config.configureWithDevSSLCertificate()
messagingClient = this messagingClient = this

View File

@ -72,10 +72,6 @@ class FlowFrameworkTests {
private lateinit var alice: Party private lateinit var alice: Party
private lateinit var bob: Party private lateinit var bob: Party
private fun StartedNode<*>.flushSmm() {
(this.smm as StateMachineManagerImpl).executor.flush()
}
@Before @Before
fun start() { fun start() {
mockNet = MockNetwork( mockNet = MockNetwork(
@ -165,7 +161,6 @@ class FlowFrameworkTests {
aliceNode.registerFlowFactory(ReceiveFlow::class) { InitiatedSendFlow("Hello", it) } aliceNode.registerFlowFactory(ReceiveFlow::class) { InitiatedSendFlow("Hello", it) }
bobNode.services.startFlow(ReceiveFlow(alice).nonTerminating()) // Prepare checkpointed receive flow bobNode.services.startFlow(ReceiveFlow(alice).nonTerminating()) // Prepare checkpointed receive flow
// Make sure the add() has finished initial processing. // Make sure the add() has finished initial processing.
bobNode.flushSmm()
bobNode.internals.disableDBCloseOnStop() bobNode.internals.disableDBCloseOnStop()
bobNode.dispose() // kill receiver bobNode.dispose() // kill receiver
val restoredFlow = bobNode.restartAndGetRestoredFlow<ReceiveFlow>() val restoredFlow = bobNode.restartAndGetRestoredFlow<ReceiveFlow>()
@ -191,7 +186,6 @@ class FlowFrameworkTests {
assertEquals(1, bobNode.checkpointStorage.checkpoints().size) assertEquals(1, bobNode.checkpointStorage.checkpoints().size)
} }
// Make sure the add() has finished initial processing. // Make sure the add() has finished initial processing.
bobNode.flushSmm()
bobNode.internals.disableDBCloseOnStop() bobNode.internals.disableDBCloseOnStop()
// Restart node and thus reload the checkpoint and resend the message with same UUID // Restart node and thus reload the checkpoint and resend the message with same UUID
bobNode.dispose() bobNode.dispose()
@ -204,7 +198,6 @@ class FlowFrameworkTests {
val (firstAgain, fut1) = node2b.getSingleFlow<PingPongFlow>() val (firstAgain, fut1) = node2b.getSingleFlow<PingPongFlow>()
// Run the network which will also fire up the second flow. First message should get deduped. So message data stays in sync. // Run the network which will also fire up the second flow. First message should get deduped. So message data stays in sync.
mockNet.runNetwork() mockNet.runNetwork()
node2b.flushSmm()
fut1.getOrThrow() fut1.getOrThrow()
val receivedCount = receivedSessionMessages.count { it.isPayloadTransfer } val receivedCount = receivedSessionMessages.count { it.isPayloadTransfer }
@ -575,7 +568,7 @@ class FlowFrameworkTests {
@Test @Test
fun `customised client flow which has annotated @InitiatingFlow again`() { fun `customised client flow which has annotated @InitiatingFlow again`() {
assertThatExceptionOfType(ExecutionException::class.java).isThrownBy { assertThatExceptionOfType(IllegalArgumentException::class.java).isThrownBy {
aliceNode.services.startFlow(IncorrectCustomSendFlow("Hello", bob)).resultFuture aliceNode.services.startFlow(IncorrectCustomSendFlow("Hello", bob)).resultFuture
}.withMessageContaining(InitiatingFlow::class.java.simpleName) }.withMessageContaining(InitiatingFlow::class.java.simpleName)
} }

View File

@ -43,7 +43,7 @@ dependencies {
// TODO Remove this once we have app configs // TODO Remove this once we have app configs
compile "com.typesafe:config:$typesafe_config_version" compile "com.typesafe:config:$typesafe_config_version"
testCompile project(':test-utils') testCompile project(':node-driver')
testCompile project(path: ':core', configuration: 'testArtifacts') testCompile project(path: ':core', configuration: 'testArtifacts')
testCompile "junit:junit:$junit_version" testCompile "junit:junit:$junit_version"

View File

@ -1,22 +1,19 @@
package net.corda.node package com.r3.corda.enterprise.perftestcordapp
import co.paralleluniverse.fibers.Suspendable import co.paralleluniverse.fibers.Suspendable
import com.google.common.base.Stopwatch import com.google.common.base.Stopwatch
import com.r3.corda.enterprise.perftestcordapp.flows.CashIssueAndPaymentNoSelection
import net.corda.core.flows.FlowLogic import net.corda.core.flows.FlowLogic
import net.corda.core.flows.StartableByRPC import net.corda.core.flows.StartableByRPC
import net.corda.core.internal.concurrent.transpose
import net.corda.core.messaging.startFlow import net.corda.core.messaging.startFlow
import net.corda.core.utilities.OpaqueBytes import net.corda.core.utilities.OpaqueBytes
import net.corda.core.utilities.getOrThrow import net.corda.core.utilities.getOrThrow
import net.corda.core.utilities.minutes import net.corda.core.utilities.minutes
import net.corda.finance.DOLLARS import net.corda.finance.DOLLARS
import net.corda.finance.flows.CashIssueAndPaymentFlow
import net.corda.finance.flows.CashIssueFlow import net.corda.finance.flows.CashIssueFlow
import net.corda.finance.flows.CashPaymentFlow
import net.corda.node.services.Permissions.Companion.startFlow import net.corda.node.services.Permissions.Companion.startFlow
import net.corda.testing.core.ALICE_NAME import net.corda.testing.core.*
import net.corda.testing.core.DUMMY_BANK_A_NAME
import net.corda.testing.core.DUMMY_NOTARY_NAME
import net.corda.testing.core.TestIdentity
import net.corda.testing.driver.NodeHandle import net.corda.testing.driver.NodeHandle
import net.corda.testing.driver.PortAllocation import net.corda.testing.driver.PortAllocation
import net.corda.testing.driver.driver import net.corda.testing.driver.driver
@ -73,7 +70,7 @@ class NodePerformanceTests : IntegrationTest() {
queueBound = 50 queueBound = 50
) { ) {
val timing = Stopwatch.createStarted().apply { val timing = Stopwatch.createStarted().apply {
connection.proxy.startFlow(::EmptyFlow).returnValue.getOrThrow() connection.proxy.startFlow(NodePerformanceTests::EmptyFlow).returnValue.getOrThrow()
}.stop().elapsed(TimeUnit.MICROSECONDS) }.stop().elapsed(TimeUnit.MICROSECONDS)
timings.add(timing) timings.add(timing)
} }
@ -95,8 +92,14 @@ class NodePerformanceTests : IntegrationTest() {
a as NodeHandle.InProcess a as NodeHandle.InProcess
val metricRegistry = startReporter((this as InternalDriverDSL).shutdownManager, a.node.services.monitoringService.metrics) val metricRegistry = startReporter((this as InternalDriverDSL).shutdownManager, a.node.services.monitoringService.metrics)
a.rpcClientToNode().use("A", "A") { connection -> a.rpcClientToNode().use("A", "A") { connection ->
startPublishingFixedRateInjector(metricRegistry, 1, 5.minutes, 2000L / TimeUnit.SECONDS) { startPublishingFixedRateInjector(
connection.proxy.startFlow(::EmptyFlow).returnValue.get() metricRegistry = metricRegistry,
parallelism = 16,
overallDuration = 5.minutes,
injectionRate = 2000L / TimeUnit.SECONDS,
workBound = 50
) {
connection.proxy.startFlow(NodePerformanceTests::EmptyFlow).returnValue
} }
} }
} }
@ -109,8 +112,14 @@ class NodePerformanceTests : IntegrationTest() {
a as NodeHandle.InProcess a as NodeHandle.InProcess
val metricRegistry = startReporter((this as InternalDriverDSL).shutdownManager, a.node.services.monitoringService.metrics) val metricRegistry = startReporter((this as InternalDriverDSL).shutdownManager, a.node.services.monitoringService.metrics)
a.rpcClientToNode().use("A", "A") { connection -> a.rpcClientToNode().use("A", "A") { connection ->
startPublishingFixedRateInjector(metricRegistry, 1, 5.minutes, 2000L / TimeUnit.SECONDS) { startPublishingFixedRateInjector(
connection.proxy.startFlow(::CashIssueFlow, 1.DOLLARS, OpaqueBytes.of(0), ALICE).returnValue.get() metricRegistry = metricRegistry,
parallelism = 16,
overallDuration = 5.minutes,
injectionRate = 2000L / TimeUnit.SECONDS,
workBound = 50
) {
connection.proxy.startFlow(::CashIssueFlow, 1.DOLLARS, OpaqueBytes.of(0), ALICE).returnValue
} }
} }
} }
@ -118,24 +127,50 @@ class NodePerformanceTests : IntegrationTest() {
@Test @Test
fun `self pay rate`() { fun `self pay rate`() {
val user = User("A", "A", setOf(startFlow<CashIssueFlow>(), startFlow<CashPaymentFlow>())) val user = User("A", "A", setOf(startFlow<CashIssueAndPaymentFlow>()))
driver( driver(
notarySpecs = listOf(NotarySpec(DUMMY_NOTARY_NAME, rpcUsers = listOf(user))), notarySpecs = listOf(NotarySpec(DUMMY_NOTARY_NAME, rpcUsers = listOf(user))),
startNodesInProcess = true, startNodesInProcess = true,
extraCordappPackagesToScan = listOf("net.corda.finance"), extraCordappPackagesToScan = listOf("net.corda.finance", "com.r3.corda.enterprise.perftestcordapp"),
portAllocation = PortAllocation.Incremental(20000) portAllocation = PortAllocation.Incremental(20000)
) { ) {
val notary = defaultNotaryNode.getOrThrow() as NodeHandle.InProcess val notary = defaultNotaryNode.getOrThrow() as NodeHandle.InProcess
val metricRegistry = startReporter((this as InternalDriverDSL).shutdownManager, notary.node.services.monitoringService.metrics) val metricRegistry = startReporter((this as InternalDriverDSL).shutdownManager, notary.node.services.monitoringService.metrics)
notary.rpcClientToNode().use("A", "A") { connection -> notary.rpcClientToNode().use("A", "A") { connection ->
println("ISSUING") startPublishingFixedRateInjector(
val doneFutures = (1..100).toList().map { metricRegistry = metricRegistry,
connection.proxy.startFlow(::CashIssueFlow, 1.DOLLARS, OpaqueBytes.of(0), defaultNotaryIdentity).returnValue parallelism = 64,
}.toList() overallDuration = 5.minutes,
doneFutures.transpose().get() injectionRate = 300L / TimeUnit.SECONDS,
println("STARTING PAYMENT") workBound = 50
startPublishingFixedRateInjector(metricRegistry, 8, 5.minutes, 5L / TimeUnit.SECONDS) { ) {
connection.proxy.startFlow(::CashPaymentFlow, 1.DOLLARS, defaultNotaryIdentity).returnValue.get() connection.proxy.startFlow(::CashIssueAndPaymentFlow, 1.DOLLARS, OpaqueBytes.of(0), defaultNotaryIdentity, false, defaultNotaryIdentity).returnValue
}
}
}
}
@Test
fun `self pay rate without selection`() {
val user = User("A", "A", setOf(startFlow<CashIssueAndPaymentNoSelection>()))
driver(
notarySpecs = listOf(NotarySpec(DUMMY_NOTARY_NAME)),
startNodesInProcess = true,
portAllocation = PortAllocation.Incremental(20000)
) {
val aliceFuture = startNode(providedName = ALICE_NAME, rpcUsers = listOf(user), startInSameProcess = true)
val alice = aliceFuture.getOrThrow() as NodeHandle.InProcess
defaultNotaryNode.getOrThrow()
val metricRegistry = startReporter((this as InternalDriverDSL).shutdownManager, alice.node.services.monitoringService.metrics)
alice.rpcClientToNode().use("A", "A") { connection ->
startPublishingFixedRateInjector(
metricRegistry = metricRegistry,
parallelism = 64,
overallDuration = 5.minutes,
injectionRate = 50L / TimeUnit.SECONDS,
workBound = 500
) {
connection.proxy.startFlow(::CashIssueAndPaymentNoSelection, 1.DOLLARS, OpaqueBytes.of(0), alice.nodeInfo.legalIdentities[0], false, defaultNotaryIdentity).returnValue
} }
} }
} }
@ -143,18 +178,19 @@ class NodePerformanceTests : IntegrationTest() {
@Test @Test
fun `single pay`() { fun `single pay`() {
val user = User("A", "A", setOf(startFlow<CashIssueFlow>(), startFlow<CashPaymentFlow>())) val user = User("A", "A", setOf(startFlow<CashIssueAndPaymentNoSelection>()))
driver( driver(
notarySpecs = listOf(NotarySpec(DUMMY_NOTARY_NAME, rpcUsers = listOf(user))), notarySpecs = listOf(NotarySpec(DUMMY_NOTARY_NAME)),
startNodesInProcess = true, startNodesInProcess = true,
extraCordappPackagesToScan = listOf("net.corda.finance"),
portAllocation = PortAllocation.Incremental(20000) portAllocation = PortAllocation.Incremental(20000)
) { ) {
val notary = defaultNotaryNode.getOrThrow() as NodeHandle.InProcess val aliceFuture = startNode(providedName = ALICE_NAME, rpcUsers = listOf(user))
val metricRegistry = startReporter((this as InternalDriverDSL).shutdownManager, notary.node.services.monitoringService.metrics) val bobFuture = startNode(providedName = BOB_NAME, rpcUsers = listOf(user))
notary.rpcClientToNode().use("A", "A") { connection -> val alice = aliceFuture.getOrThrow() as NodeHandle.InProcess
connection.proxy.startFlow(::CashIssueFlow, 1.DOLLARS, OpaqueBytes.of(0), defaultNotaryIdentity).returnValue.getOrThrow() val bob = bobFuture.getOrThrow() as NodeHandle.InProcess
connection.proxy.startFlow(::CashPaymentFlow, 1.DOLLARS, defaultNotaryIdentity).returnValue.getOrThrow() defaultNotaryNode.getOrThrow()
alice.rpcClientToNode().use("A", "A") { connection ->
connection.proxy.startFlow(::CashIssueAndPaymentNoSelection, 1.DOLLARS, OpaqueBytes.of(0), bob.nodeInfo.legalIdentities[0], false, defaultNotaryIdentity).returnValue.getOrThrow()
} }
} }
} }

View File

@ -41,6 +41,7 @@ class CashIssueAndPaymentNoSelection(val amount: Amount<Currency>,
fun deriveState(txState: TransactionState<Cash.State>, amt: Amount<Issued<Currency>>, owner: AbstractParty) fun deriveState(txState: TransactionState<Cash.State>, amt: Amount<Issued<Currency>>, owner: AbstractParty)
= txState.copy(data = txState.data.copy(amount = amt, owner = owner)) = txState.copy(data = txState.data.copy(amount = amt, owner = owner))
progressTracker.currentStep = GENERATING_TX
val issueResult = subFlow(CashIssueFlow(amount, issueRef, notary)) val issueResult = subFlow(CashIssueFlow(amount, issueRef, notary))
val cashStateAndRef = issueResult.stx.tx.outRef<Cash.State>(0) val cashStateAndRef = issueResult.stx.tx.outRef<Cash.State>(0)

View File

@ -71,8 +71,8 @@ object AutoOfferFlow {
// and because in a real life app you'd probably have more complex logic here e.g. describing why the report // and because in a real life app you'd probably have more complex logic here e.g. describing why the report
// was filed, checking that the reportee is a regulated entity and not some random node from the wrong // was filed, checking that the reportee is a regulated entity and not some random node from the wrong
// country and so on. // country and so on.
val regulator = serviceHub.identityService.partiesFromName("Regulator", true).single() // val regulator = serviceHub.identityService.partiesFromName("Regulator", true).single()
subFlow(ReportToRegulatorFlow(regulator, finalTx)) // subFlow(ReportToRegulatorFlow(regulator, finalTx))
return finalTx return finalTx
} }

View File

@ -327,20 +327,18 @@ class InMemoryMessagingNetwork internal constructor(
state.locked { check(handlers.remove(registration as Handler)) } state.locked { check(handlers.remove(registration as Handler)) }
} }
override fun send(message: Message, target: MessageRecipients, retryId: Long?, sequenceKey: Any, acknowledgementHandler: (() -> Unit)?) { override fun send(message: Message, target: MessageRecipients, retryId: Long?, sequenceKey: Any) {
check(running) check(running)
msgSend(this, message, target) msgSend(this, message, target)
acknowledgementHandler?.invoke()
if (!sendManuallyPumped) { if (!sendManuallyPumped) {
pumpSend(false) pumpSend(false)
} }
} }
override fun send(addressedMessages: List<MessagingService.AddressedMessage>, acknowledgementHandler: (() -> Unit)?) { override fun send(addressedMessages: List<MessagingService.AddressedMessage>) {
for ((message, target, retryId, sequenceKey) in addressedMessages) { for ((message, target, retryId, sequenceKey) in addressedMessages) {
send(message, target, retryId, sequenceKey, null) send(message, target, retryId, sequenceKey)
} }
acknowledgementHandler?.invoke()
} }
override fun stop() { override fun stop() {

View File

@ -37,7 +37,6 @@ import net.corda.node.services.messaging.MessagingService
import net.corda.node.services.transactions.BFTNonValidatingNotaryService import net.corda.node.services.transactions.BFTNonValidatingNotaryService
import net.corda.node.services.transactions.BFTSMaRt import net.corda.node.services.transactions.BFTSMaRt
import net.corda.node.services.transactions.InMemoryTransactionVerifierService import net.corda.node.services.transactions.InMemoryTransactionVerifierService
import net.corda.node.utilities.AffinityExecutor
import net.corda.node.utilities.AffinityExecutor.ServiceAffinityExecutor import net.corda.node.utilities.AffinityExecutor.ServiceAffinityExecutor
import net.corda.nodeapi.internal.DevIdentityGenerator import net.corda.nodeapi.internal.DevIdentityGenerator
import net.corda.nodeapi.internal.config.User import net.corda.nodeapi.internal.config.User
@ -45,14 +44,14 @@ import net.corda.nodeapi.internal.network.NetworkParametersCopier
import net.corda.nodeapi.internal.network.NotaryInfo import net.corda.nodeapi.internal.network.NotaryInfo
import net.corda.nodeapi.internal.persistence.CordaPersistence import net.corda.nodeapi.internal.persistence.CordaPersistence
import net.corda.nodeapi.internal.persistence.DatabaseConfig import net.corda.nodeapi.internal.persistence.DatabaseConfig
import net.corda.testing.core.DUMMY_NOTARY_NAME
import net.corda.testing.common.internal.testNetworkParameters import net.corda.testing.common.internal.testNetworkParameters
import net.corda.testing.core.DUMMY_NOTARY_NAME
import net.corda.testing.core.setGlobalSerialization
import net.corda.testing.internal.rigorousMock import net.corda.testing.internal.rigorousMock
import net.corda.testing.internal.testThreadFactory import net.corda.testing.internal.testThreadFactory
import net.corda.testing.node.MockServices.Companion.MOCK_VERSION_INFO import net.corda.testing.node.MockServices.Companion.MOCK_VERSION_INFO
import net.corda.testing.node.MockServices.Companion.makeTestDataSourceProperties import net.corda.testing.node.MockServices.Companion.makeTestDataSourceProperties
import net.corda.testing.node.MockServices.Companion.makeTestDatabaseProperties import net.corda.testing.node.MockServices.Companion.makeTestDatabaseProperties
import net.corda.testing.core.setGlobalSerialization
import org.apache.activemq.artemis.utils.ReusableLatch import org.apache.activemq.artemis.utils.ReusableLatch
import org.apache.sshd.common.util.security.SecurityUtils import org.apache.sshd.common.util.security.SecurityUtils
import rx.internal.schedulers.CachedThreadScheduler import rx.internal.schedulers.CachedThreadScheduler
@ -270,7 +269,7 @@ open class MockNetwork(private val cordappPackages: List<String>,
private val entropyRoot = args.entropyRoot private val entropyRoot = args.entropyRoot
var counter = entropyRoot var counter = entropyRoot
override val log get() = staticLog override val log get() = staticLog
override val serverThread: AffinityExecutor = override val serverThread =
if (mockNet.threadPerNode) { if (mockNet.threadPerNode) {
ServiceAffinityExecutor("Mock node $id thread", 1) ServiceAffinityExecutor("Mock node $id thread", 1)
} else { } else {
@ -514,6 +513,9 @@ private fun mockNodeConfiguration(): NodeConfiguration {
doReturn(5).whenever(it).messageRedeliveryDelaySeconds doReturn(5).whenever(it).messageRedeliveryDelaySeconds
doReturn(5.seconds.toMillis()).whenever(it).additionalNodeInfoPollingFrequencyMsec doReturn(5.seconds.toMillis()).whenever(it).additionalNodeInfoPollingFrequencyMsec
doReturn(null).whenever(it).devModeOptions doReturn(null).whenever(it).devModeOptions
doReturn(EnterpriseConfiguration(MutualExclusionConfiguration(false, "", 20000, 40000))).whenever(it).enterpriseConfiguration doReturn(EnterpriseConfiguration(
mutualExclusionConfiguration = MutualExclusionConfiguration(false, "", 20000, 40000),
useMultiThreadedSMM = false
)).whenever(it).enterpriseConfiguration
} }
} }

View File

@ -8,6 +8,7 @@ import net.corda.core.internal.div
import net.corda.core.node.NodeInfo import net.corda.core.node.NodeInfo
import net.corda.core.utilities.getOrThrow import net.corda.core.utilities.getOrThrow
import net.corda.node.VersionInfo import net.corda.node.VersionInfo
import net.corda.node.internal.EnterpriseNode
import net.corda.node.internal.Node import net.corda.node.internal.Node
import net.corda.node.internal.StartedNode import net.corda.node.internal.StartedNode
import net.corda.node.internal.cordapp.CordappLoader import net.corda.node.internal.cordapp.CordappLoader
@ -128,7 +129,7 @@ abstract class NodeBasedTest(private val cordappPackages: List<String> = emptyLi
} }
class InProcessNode( class InProcessNode(
configuration: NodeConfiguration, versionInfo: VersionInfo, cordappPackages: List<String>) : Node( configuration: NodeConfiguration, versionInfo: VersionInfo, cordappPackages: List<String>) : EnterpriseNode(
configuration, versionInfo, false, CordappLoader.createDefaultWithTestPackages(configuration, cordappPackages)) { configuration, versionInfo, false, CordappLoader.createDefaultWithTestPackages(configuration, cordappPackages)) {
override fun getRxIoScheduler() = CachedThreadScheduler(testThreadFactory()).also { runOnStop += it::shutdown } override fun getRxIoScheduler() = CachedThreadScheduler(testThreadFactory()).also { runOnStop += it::shutdown }
} }

View File

@ -1,10 +1,14 @@
package net.corda.testing.node.internal.performance package net.corda.testing.node.internal.performance
import com.codahale.metrics.Gauge import com.codahale.metrics.Gauge
import com.codahale.metrics.MetricRegistry import com.codahale.metrics.MetricRegistry
import com.google.common.base.Stopwatch import com.google.common.base.Stopwatch
import net.corda.core.concurrent.CordaFuture
import net.corda.core.utilities.getOrThrow
import net.corda.testing.internal.performance.Rate import net.corda.testing.internal.performance.Rate
import net.corda.testing.node.internal.ShutdownManager import net.corda.testing.node.internal.ShutdownManager
import org.slf4j.LoggerFactory
import java.time.Duration import java.time.Duration
import java.util.* import java.util.*
import java.util.concurrent.CountDownLatch import java.util.concurrent.CountDownLatch
@ -16,6 +20,7 @@ import java.util.concurrent.locks.ReentrantLock
import kotlin.concurrent.thread import kotlin.concurrent.thread
import kotlin.concurrent.withLock import kotlin.concurrent.withLock
private val log = LoggerFactory.getLogger("TightLoopInjector")
fun startTightLoopInjector( fun startTightLoopInjector(
parallelism: Int, parallelism: Int,
numberOfInjections: Int, numberOfInjections: Int,
@ -34,7 +39,11 @@ fun startTightLoopInjector(
while (true) { while (true) {
if (leftToSubmit.getAndDecrement() == 0) break if (leftToSubmit.getAndDecrement() == 0) break
executor.submit { executor.submit {
work() try {
work()
} catch (exception: Exception) {
log.error("Error while executing injection", exception)
}
if (queuedCount.decrementAndGet() < queueBound / 2) { if (queuedCount.decrementAndGet() < queueBound / 2) {
lock.withLock { lock.withLock {
canQueueAgain.signal() canQueueAgain.signal()
@ -60,11 +69,13 @@ fun startPublishingFixedRateInjector(
parallelism: Int, parallelism: Int,
overallDuration: Duration, overallDuration: Duration,
injectionRate: Rate, injectionRate: Rate,
workBound: Int,
queueSizeMetricName: String = "QueueSize", queueSizeMetricName: String = "QueueSize",
workDurationMetricName: String = "WorkDuration", workDurationMetricName: String = "WorkDuration",
work: () -> Unit work: () -> CordaFuture<*>
) { ) {
val workSemaphore = Semaphore(0) val workSemaphore = Semaphore(0)
val workBoundSemaphore = Semaphore(workBound)
metricRegistry.register(queueSizeMetricName, Gauge { workSemaphore.availablePermits() }) metricRegistry.register(queueSizeMetricName, Gauge { workSemaphore.availablePermits() })
val workDurationTimer = metricRegistry.timer(workDurationMetricName) val workDurationTimer = metricRegistry.timer(workDurationMetricName)
ShutdownManager.run { ShutdownManager.run {
@ -72,19 +83,16 @@ fun startPublishingFixedRateInjector(
registerShutdown { executor.shutdown() } registerShutdown { executor.shutdown() }
val workExecutor = Executors.newFixedThreadPool(parallelism) val workExecutor = Executors.newFixedThreadPool(parallelism)
registerShutdown { workExecutor.shutdown() } registerShutdown { workExecutor.shutdown() }
val timings = Collections.synchronizedList(ArrayList<Long>())
for (i in 1..parallelism) { for (i in 1..parallelism) {
workExecutor.submit { workExecutor.submit {
try { try {
while (true) { while (true) {
workSemaphore.acquire() workSemaphore.acquire()
workBoundSemaphore.acquire()
workDurationTimer.time { workDurationTimer.time {
timings.add( work().getOrThrow()
Stopwatch.createStarted().apply {
work()
}.stop().elapsed(TimeUnit.MICROSECONDS)
)
} }
workBoundSemaphore.release()
} }
} catch (throwable: Throwable) { } catch (throwable: Throwable) {
throwable.printStackTrace() throwable.printStackTrace()
@ -105,4 +113,3 @@ fun startPublishingFixedRateInjector(
Thread.sleep(overallDuration.toMillis()) Thread.sleep(overallDuration.toMillis())
} }
} }