Remove SAME_THREAD executor and its use in MockNetwork etc.

Remove all traces of unused optional Executor in messaging.
rick.parker 2016-10-11 18:10:12 +01:00
parent b7afd54d29
commit 02a9f8fe67
16 changed files with 136 additions and 113 deletions
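A minimal sketch of what handler registration looks like after this change; the topic string, ExamplePayload type and processThreadSafely() are illustrative, not part of the codebase. With the Executor parameter gone, the callback always runs on threads owned by the messaging service and must therefore be thread safe:

// assumes: net: MessagingService is in scope
val registration = net.addMessageHandler("example.topic", DEFAULT_SESSION_ID) { message, reg ->
    val payload = message.data.deserialize<ExamplePayload>()   // ExamplePayload is hypothetical
    processThreadSafely(payload)                               // must tolerate concurrent invocation
}
// Unregister when no longer interested.
net.removeMessageHandler(registration)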

View File

@ -9,7 +9,6 @@ import com.r3corda.core.serialization.deserialize
import com.r3corda.core.serialization.serialize
import java.time.Instant
import java.util.*
import java.util.concurrent.Executor
import java.util.concurrent.atomic.AtomicBoolean
import javax.annotation.concurrent.ThreadSafe
@ -26,11 +25,8 @@ import javax.annotation.concurrent.ThreadSafe
@ThreadSafe
interface MessagingService {
/**
* The provided function will be invoked for each received message whose topic matches the given string, on the given
* executor.
*
* If no executor is received then the callback will run on threads provided by the messaging service, and the
* callback is expected to be thread safe as a result.
* The provided function will be invoked for each received message whose topic matches the given string. The callback
* will run on threads provided by the messaging service, and the callback is expected to be thread safe as a result.
*
* The returned object is an opaque handle that may be used to un-register handlers later with [removeMessageHandler].
* The handle is passed to the callback as well, to avoid race conditions whereby the callback wants to unregister
@ -41,14 +37,11 @@ interface MessagingService {
* @param sessionID identifier for the session the message is part of. For services listening before
* a session is established, use [DEFAULT_SESSION_ID].
*/
fun addMessageHandler(topic: String = "", sessionID: Long = DEFAULT_SESSION_ID, executor: Executor? = null, callback: (Message, MessageHandlerRegistration) -> Unit): MessageHandlerRegistration
fun addMessageHandler(topic: String = "", sessionID: Long = DEFAULT_SESSION_ID, callback: (Message, MessageHandlerRegistration) -> Unit): MessageHandlerRegistration
/**
* The provided function will be invoked for each received message whose topic and session matches, on the
* given executor.
*
* If no executor is received then the callback will run on threads provided by the messaging service, and the
* callback is expected to be thread safe as a result.
* The provided function will be invoked for each received message whose topic and session matches. The callback
* will run on threads provided by the messaging service, and the callback is expected to be thread safe as a result.
*
* The returned object is an opaque handle that may be used to un-register handlers later with [removeMessageHandler].
* The handle is passed to the callback as well, to avoid race conditions whereby the callback wants to unregister
@ -56,7 +49,7 @@ interface MessagingService {
*
* @param topicSession identifier for the topic and session to listen for messages arriving on.
*/
fun addMessageHandler(topicSession: TopicSession, executor: Executor? = null, callback: (Message, MessageHandlerRegistration) -> Unit): MessageHandlerRegistration
fun addMessageHandler(topicSession: TopicSession, callback: (Message, MessageHandlerRegistration) -> Unit): MessageHandlerRegistration
/**
* Removes a handler given the object returned from [addMessageHandler]. The callback will no longer be invoked once
@ -111,8 +104,8 @@ fun MessagingService.createMessage(topic: String, sessionID: Long = DEFAULT_SESS
* @param sessionID identifier for the session the message is part of. For services listening before
* a session is established, use [DEFAULT_SESSION_ID].
*/
fun MessagingService.runOnNextMessage(topic: String, sessionID: Long, executor: Executor? = null, callback: (Message) -> Unit)
= runOnNextMessage(TopicSession(topic, sessionID), executor, callback)
fun MessagingService.runOnNextMessage(topic: String, sessionID: Long, callback: (Message) -> Unit)
= runOnNextMessage(TopicSession(topic, sessionID), callback)
/**
* Registers a handler for the given topic and session that runs the given callback with the message and then removes
@ -121,9 +114,9 @@ fun MessagingService.runOnNextMessage(topic: String, sessionID: Long, executor:
*
* @param topicSession identifier for the topic and session to listen for messages arriving on.
*/
inline fun MessagingService.runOnNextMessage(topicSession: TopicSession, executor: Executor? = null, crossinline callback: (Message) -> Unit) {
inline fun MessagingService.runOnNextMessage(topicSession: TopicSession, crossinline callback: (Message) -> Unit) {
val consumed = AtomicBoolean()
addMessageHandler(topicSession, executor) { msg, reg ->
addMessageHandler(topicSession) { msg, reg ->
removeMessageHandler(reg)
check(!consumed.getAndSet(true)) { "Called more than once" }
check(msg.topicSession == topicSession) { "Topic/session mismatch: ${msg.topicSession} vs $topicSession" }
@ -135,9 +128,9 @@ inline fun MessagingService.runOnNextMessage(topicSession: TopicSession, executo
* Returns a [ListenableFuture] of the next message payload ([Message.data]) which is received on the given topic and sessionId.
* The payload is deserialized to an object of type [M]. Any exceptions thrown will be captured by the future.
*/
fun <M : Any> MessagingService.onNext(topic: String, sessionId: Long, executor: Executor? = null): ListenableFuture<M> {
fun <M : Any> MessagingService.onNext(topic: String, sessionId: Long): ListenableFuture<M> {
val messageFuture = SettableFuture.create<M>()
runOnNextMessage(topic, sessionId, executor) { message ->
runOnNextMessage(topic, sessionId) { message ->
messageFuture.catch {
message.data.deserialize<M>()
}
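A hedged usage sketch of the simplified one-shot helpers above; the topic string and ExampleResponse are placeholders:

// assumes: net: MessagingService and sessionID: Long are in scope
// One-shot handler: unregisters itself after the first matching message.
net.runOnNextMessage("example.topic.response", sessionID) { message ->
    val response = message.data.deserialize<ExampleResponse>()   // still on a messaging-service thread
    // handle the response here, in a thread-safe way
}

// Or obtain the deserialized payload as a future instead:
val responseFuture: ListenableFuture<ExampleResponse> = net.onNext<ExampleResponse>("example.topic.response", sessionID)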

View File

@ -6,7 +6,6 @@ import com.r3corda.core.messaging.SingleMessageRecipient
import com.r3corda.core.messaging.onNext
import com.r3corda.core.messaging.send
import com.r3corda.core.node.services.DEFAULT_SESSION_ID
import java.util.concurrent.Executor
/**
* Abstract superclass for request messages sent to services which expect a reply.
@ -22,9 +21,8 @@ interface ServiceRequestMessage {
*/
fun <R : Any> MessagingService.sendRequest(topic: String,
request: ServiceRequestMessage,
target: SingleMessageRecipient,
executor: Executor? = null): ListenableFuture<R> {
val responseFuture = onNext<R>(topic, request.sessionID, executor)
target: SingleMessageRecipient): ListenableFuture<R> {
val responseFuture = onNext<R>(topic, request.sessionID)
send(topic, DEFAULT_SESSION_ID, request, target)
return responseFuture
}
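A hedged sketch of the request/response helper after the change; ExampleRequest, ExampleReply and serviceAddress are illustrative stand-ins for a concrete ServiceRequestMessage, its reply type and the target SingleMessageRecipient:

// assumes: net: MessagingService and serviceAddress: SingleMessageRecipient are in scope
val request = ExampleRequest(sessionID = random63BitValue(), replyTo = net.myAddress)   // hypothetical request type
val reply: ListenableFuture<ExampleReply> = net.sendRequest<ExampleReply>("example.service.topic", request, serviceAddress)
// The reply future completes on a messaging-service thread; chain follow-up work with map { ... }
// (as InMemoryNetworkMapCache does below) rather than blocking on it.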

View File

@ -4,7 +4,6 @@ import com.codahale.metrics.MetricRegistry
import com.google.common.util.concurrent.ListenableFuture
import com.google.common.util.concurrent.MoreExecutors
import com.google.common.util.concurrent.SettableFuture
import com.r3corda.core.RunOnCallerThread
import com.r3corda.core.crypto.Party
import com.r3corda.core.crypto.X509Utilities
import com.r3corda.core.messaging.SingleMessageRecipient
@ -374,7 +373,7 @@ abstract class AbstractNode(open val configuration: NodeConfiguration, val netwo
val reg = NodeRegistration(info, instant.toEpochMilli(), type, expires)
val legalIdentityKey = obtainLegalIdentityKey()
val request = NetworkMapService.RegistrationRequest(reg.toWire(legalIdentityKey.private), net.myAddress)
return net.sendRequest(REGISTER_PROTOCOL_TOPIC, request, networkMapAddr, RunOnCallerThread)
return net.sendRequest(REGISTER_PROTOCOL_TOPIC, request, networkMapAddr)
}
protected open fun makeKeyManagementService(): KeyManagementService = PersistentKeyManagementService(partyKeys)

View File

@ -31,7 +31,7 @@ abstract class AbstractNodeService(val services: ServiceHubInternal) : Singleton
addMessageHandler(topic: String,
crossinline handler: (Q) -> R,
crossinline exceptionConsumer: (Message, Exception) -> Unit): MessageHandlerRegistration {
return net.addMessageHandler(topic, DEFAULT_SESSION_ID, null) { message, r ->
return net.addMessageHandler(topic, DEFAULT_SESSION_ID) { message, r ->
try {
val request = message.data.deserialize<Q>()
val response = handler(request)

View File

@ -21,7 +21,6 @@ import java.time.Instant
import java.util.*
import java.util.concurrent.CopyOnWriteArrayList
import java.util.concurrent.CountDownLatch
import java.util.concurrent.Executor
import javax.annotation.concurrent.ThreadSafe
// TODO: Stop the wallet explorer and other clients from using this class and get rid of persistentInbox
@ -92,8 +91,7 @@ class NodeMessagingClient(config: NodeConfiguration,
}
/** A registration to handle messages of different types */
data class Handler(val executor: Executor?,
val topicSession: TopicSession,
data class Handler(val topicSession: TopicSession,
val callback: (Message, MessageHandlerRegistration) -> Unit) : MessageHandlerRegistration
/**
@ -369,15 +367,13 @@ class NodeMessagingClient(config: NodeConfiguration,
}
}
override fun addMessageHandler(topic: String, sessionID: Long, executor: Executor?,
callback: (Message, MessageHandlerRegistration) -> Unit): MessageHandlerRegistration
= addMessageHandler(TopicSession(topic, sessionID), executor, callback)
override fun addMessageHandler(topic: String, sessionID: Long, callback: (Message, MessageHandlerRegistration) -> Unit): MessageHandlerRegistration
= addMessageHandler(TopicSession(topic, sessionID), callback)
override fun addMessageHandler(topicSession: TopicSession,
executor: Executor?,
callback: (Message, MessageHandlerRegistration) -> Unit): MessageHandlerRegistration {
require(!topicSession.isBlank()) { "Topic must not be blank, as the empty topic is a special case." }
val handler = Handler(executor, topicSession, callback)
val handler = Handler(topicSession, callback)
handlers.add(handler)
val messagesToRedeliver = state.locked {
val messagesToRedeliver = undeliveredMessages

View File

@ -3,7 +3,6 @@ package com.r3corda.node.services.network
import com.google.common.annotations.VisibleForTesting
import com.google.common.util.concurrent.ListenableFuture
import com.google.common.util.concurrent.SettableFuture
import com.r3corda.core.RunOnCallerThread
import com.r3corda.core.contracts.Contract
import com.r3corda.core.crypto.Party
import com.r3corda.core.map
@ -20,12 +19,10 @@ import com.r3corda.core.node.services.ServiceType
import com.r3corda.core.serialization.SingletonSerializeAsToken
import com.r3corda.core.serialization.deserialize
import com.r3corda.core.serialization.serialize
import com.r3corda.node.services.api.RegulatorService
import com.r3corda.node.services.network.NetworkMapService.Companion.FETCH_PROTOCOL_TOPIC
import com.r3corda.node.services.network.NetworkMapService.Companion.SUBSCRIPTION_PROTOCOL_TOPIC
import com.r3corda.node.services.network.NetworkMapService.FetchMapResponse
import com.r3corda.node.services.network.NetworkMapService.SubscribeResponse
import com.r3corda.node.services.transactions.NotaryService
import com.r3corda.node.utilities.AddOrRemove
import com.r3corda.protocols.sendRequest
import rx.Observable
@ -70,7 +67,7 @@ open class InMemoryNetworkMapCache : SingletonSerializeAsToken(), NetworkMapCach
ifChangedSinceVer: Int?): ListenableFuture<Unit> {
if (subscribe && !registeredForPush) {
// Add handler to the network, for updates received from the remote network map service.
net.addMessageHandler(NetworkMapService.PUSH_PROTOCOL_TOPIC, DEFAULT_SESSION_ID, null) { message, r ->
net.addMessageHandler(NetworkMapService.PUSH_PROTOCOL_TOPIC, DEFAULT_SESSION_ID) { message, r ->
try {
val req = message.data.deserialize<NetworkMapService.Update>()
val ackMessage = net.createMessage(NetworkMapService.PUSH_ACK_PROTOCOL_TOPIC, DEFAULT_SESSION_ID,
@ -88,7 +85,7 @@ open class InMemoryNetworkMapCache : SingletonSerializeAsToken(), NetworkMapCach
// Fetch the network map and register for updates at the same time
val req = NetworkMapService.FetchMapRequest(subscribe, ifChangedSinceVer, net.myAddress)
val future = net.sendRequest<FetchMapResponse>(FETCH_PROTOCOL_TOPIC, req, networkMapAddress, RunOnCallerThread).map { resp ->
val future = net.sendRequest<FetchMapResponse>(FETCH_PROTOCOL_TOPIC, req, networkMapAddress).map { resp ->
// We may not receive any nodes back, if the map hasn't changed since the version specified
resp.nodes?.forEach { processRegistration(it) }
Unit
@ -120,7 +117,7 @@ open class InMemoryNetworkMapCache : SingletonSerializeAsToken(), NetworkMapCach
override fun deregisterForUpdates(net: MessagingService, service: NodeInfo): ListenableFuture<Unit> {
// Fetch the network map and register for updates at the same time
val req = NetworkMapService.SubscribeRequest(false, net.myAddress)
val future = net.sendRequest<SubscribeResponse>(SUBSCRIPTION_PROTOCOL_TOPIC, req, service.address, RunOnCallerThread).map {
val future = net.sendRequest<SubscribeResponse>(SUBSCRIPTION_PROTOCOL_TOPIC, req, service.address).map {
if (it.confirmed) Unit else throw NetworkCacheError.DeregistrationFailed()
}
_registrationFuture.setFuture(future)

View File

@ -154,7 +154,7 @@ abstract class AbstractNetworkMapService
handlers += addMessageHandler(NetworkMapService.SUBSCRIPTION_PROTOCOL_TOPIC,
{ req: NetworkMapService.SubscribeRequest -> processSubscriptionRequest(req) }
)
handlers += net.addMessageHandler(NetworkMapService.PUSH_ACK_PROTOCOL_TOPIC, DEFAULT_SESSION_ID, null) { message, r ->
handlers += net.addMessageHandler(NetworkMapService.PUSH_ACK_PROTOCOL_TOPIC, DEFAULT_SESSION_ID) { message, r ->
val req = message.data.deserialize<NetworkMapService.UpdateAcknowledge>()
processAcknowledge(req)
}

View File

@ -175,7 +175,7 @@ class StateMachineManager(val serviceHub: ServiceHubInternal,
started = true
stateMachines.keys.forEach { resumeRestoredFiber(it) }
}
serviceHub.networkService.addMessageHandler(sessionTopic, executor) { message, reg ->
serviceHub.networkService.addMessageHandler(sessionTopic) { message, reg ->
executor.checkOnThread()
val sessionMessage = message.data.deserialize<SessionMessage>()
when (sessionMessage) {

View File

@ -1,5 +1,6 @@
package com.r3corda.node.utilities
import com.google.common.util.concurrent.SettableFuture
import com.google.common.util.concurrent.Uninterruptibles
import com.r3corda.core.utilities.loggerFor
import java.util.*
@ -41,19 +42,20 @@ interface AffinityExecutor : Executor {
}
/**
* Posts a no-op task to the executor and blocks this thread waiting for it to complete. This can be useful in
* tests when you want to be sure that a previous task submitted via [execute] has completed.
* Run the executor until there are no tasks pending and none executing.
*/
fun flush() {
fetchFrom { }
}
fun flush()
/**
* An executor backed by thread pool (which may often have a single thread) which makes it easy to schedule
* tasks in the future and verify code is running on the executor.
*/
class ServiceAffinityExecutor(threadName: String, numThreads: Int) : AffinityExecutor,
open class ServiceAffinityExecutor(threadName: String, numThreads: Int) : AffinityExecutor,
ThreadPoolExecutor(numThreads, numThreads, 0L, TimeUnit.MILLISECONDS, LinkedBlockingQueue<Runnable>()) {
companion object {
val logger = loggerFor<ServiceAffinityExecutor>()
}
private val threads = Collections.synchronizedSet(HashSet<Thread>())
private val uncaughtExceptionHandler = Thread.currentThread().uncaughtExceptionHandler
@ -82,8 +84,11 @@ interface AffinityExecutor : Executor {
override val isOnThread: Boolean get() = Thread.currentThread() in threads
companion object {
val logger = loggerFor<ServiceAffinityExecutor>()
override fun flush() {
do {
val f = SettableFuture.create<Boolean>()
execute { f.set(queue.isEmpty() && activeCount == 1) }
} while (!f.get())
}
}
@ -110,12 +115,9 @@ interface AffinityExecutor : Executor {
}
val taskQueueSize: Int get() = commandQ.size
}
companion object {
val SAME_THREAD: AffinityExecutor = object : AffinityExecutor {
override val isOnThread: Boolean get() = true
override fun execute(command: Runnable) = command.run()
override fun flush() {
throw UnsupportedOperationException()
}
}
}
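The rewritten flush() above polls the pool until the queue is empty and only the probe task is active, which is what lets the tests below call executor.flush() to wait for background work (including nested execute calls) to settle. A minimal usage sketch, assuming the names shown here:

val executor = AffinityExecutor.ServiceAffinityExecutor("example worker", 1)
executor.execute { doBackgroundWork() }   // doBackgroundWork() is a placeholder
executor.flush()                          // call from outside the executor's own thread;
                                          // returns once no tasks are pending or executing
executor.shutdown()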

View File

@ -117,7 +117,7 @@ class ArtemisMessagingTests {
}
private fun createMessagingClient(server: HostAndPort = hostAndPort): NodeMessagingClient {
return NodeMessagingClient(config, server, identity.public, AffinityExecutor.SAME_THREAD, false).apply {
return NodeMessagingClient(config, server, identity.public, AffinityExecutor.ServiceAffinityExecutor("ArtemisMessagingTests", 1), false).apply {
configureWithDevSSLCertificate()
messagingClient = this
}

View File

@ -71,7 +71,7 @@ class NodeSchedulerServiceTest : SingletonSerializeAsToken() {
init {
val kms = MockKeyManagementService(ALICE_KEY)
val mockMessagingService = InMemoryMessagingNetwork(false).InMemoryMessaging(false, InMemoryMessagingNetwork.Handle(0, "None"), persistenceTx = { it() })
val mockMessagingService = InMemoryMessagingNetwork(false).InMemoryMessaging(false, InMemoryMessagingNetwork.Handle(0, "None"), AffinityExecutor.ServiceAffinityExecutor("test", 1), persistenceTx = { it() })
services = object : MockServiceHubInternal(overrideClock = testClock, keyManagement = kms, net = mockMessagingService), TestReference {
override val testReference = this@NodeSchedulerServiceTest
}

View File

@ -117,6 +117,8 @@ class StateMachineManagerTests {
val payload = random63BitValue()
node1.services.registerProtocolInitiator(ReceiveThenSuspendProtocol::class) { SendProtocol(payload, it) }
node2.smm.add(ReceiveThenSuspendProtocol(node1.info.legalIdentity)) // Prepare checkpointed receive protocol
// Make sure the add() has finished initial processing.
node2.smm.executor.flush()
node2.stop() // kill receiver
val restoredProtocol = node2.restartAndGetRestoredProtocol<ReceiveThenSuspendProtocol>(node1)
assertThat(restoredProtocol.receivedPayloads[0]).isEqualTo(payload)
@ -137,6 +139,8 @@ class StateMachineManagerTests {
// Kick off first send and receive
node2.smm.add(PingPongProtocol(node3.info.legalIdentity, payload))
assertEquals(1, node2.checkpointStorage.checkpoints().size)
// Make sure the add() has finished initial processing.
node2.smm.executor.flush()
// Restart node and thus reload the checkpoint and resend the message with same UUID
node2.stop()
val node2b = net.createNode(node1.info.address, node2.id, advertisedServices = *node2.advertisedServices.toTypedArray())

View File

@ -3,7 +3,6 @@ package com.r3corda.node.utilities
import org.junit.After
import org.junit.Test
import java.util.*
import java.util.concurrent.CompletableFuture
import java.util.concurrent.CountDownLatch
import java.util.concurrent.atomic.AtomicReference
import kotlin.concurrent.thread
@ -12,30 +11,30 @@ import kotlin.test.assertFails
import kotlin.test.assertNotEquals
class AffinityExecutorTests {
@Test fun `AffinityExecutor SAME_THREAD executes on calling thread`() {
assert(AffinityExecutor.SAME_THREAD.isOnThread)
run {
val thatThread = CompletableFuture<Thread>()
AffinityExecutor.SAME_THREAD.execute { thatThread.complete(Thread.currentThread()) }
assertEquals(Thread.currentThread(), thatThread.get())
}
run {
val thatThread = CompletableFuture<Thread>()
AffinityExecutor.SAME_THREAD.executeASAP { thatThread.complete(Thread.currentThread()) }
assertEquals(Thread.currentThread(), thatThread.get())
}
}
var executor: AffinityExecutor.ServiceAffinityExecutor? = null
var _executor: AffinityExecutor.ServiceAffinityExecutor? = null
val executor: AffinityExecutor.ServiceAffinityExecutor get() = _executor!!
@After fun shutdown() {
executor?.shutdown()
_executor?.shutdown()
_executor = null
}
@Test fun `flush handles nested executes`() {
_executor = AffinityExecutor.ServiceAffinityExecutor("test4", 1)
var nestedRan = false
val latch = CountDownLatch(1)
executor.execute {
latch.await()
executor.execute { nestedRan = true }
}
latch.countDown()
executor.flush()
assert(nestedRan)
}
@Test fun `single threaded affinity executor runs on correct thread`() {
val thisThread = Thread.currentThread()
val executor = AffinityExecutor.ServiceAffinityExecutor("test thread", 1)
_executor = AffinityExecutor.ServiceAffinityExecutor("test thread", 1)
assert(!executor.isOnThread)
assertFails { executor.checkOnThread() }
@ -55,7 +54,7 @@ class AffinityExecutorTests {
}
@Test fun `pooled executor`() {
val executor = AffinityExecutor.ServiceAffinityExecutor("test2", 3)
_executor = AffinityExecutor.ServiceAffinityExecutor("test2", 3)
assert(!executor.isOnThread)
val latch = CountDownLatch(1)
@ -89,7 +88,7 @@ class AffinityExecutorTests {
// Run in a separate thread to avoid messing with any default exception handlers in the unit test thread.
thread {
Thread.currentThread().setUncaughtExceptionHandler { thread, throwable -> exception.set(throwable) }
val executor = AffinityExecutor.ServiceAffinityExecutor("test3", 1)
_executor = AffinityExecutor.ServiceAffinityExecutor("test3", 1)
executor.execute {
throw Exception("foo")
}

View File

@ -2,13 +2,13 @@ package com.r3corda.testing.node
import com.google.common.util.concurrent.Futures
import com.google.common.util.concurrent.ListenableFuture
import com.google.common.util.concurrent.MoreExecutors
import com.google.common.util.concurrent.SettableFuture
import com.r3corda.core.ThreadBox
import com.r3corda.core.messaging.*
import com.r3corda.core.serialization.SingletonSerializeAsToken
import com.r3corda.core.utilities.loggerFor
import com.r3corda.core.utilities.trace
import com.r3corda.node.services.api.MessagingServiceBuilder
import com.r3corda.node.utilities.AffinityExecutor
import com.r3corda.testing.node.InMemoryMessagingNetwork.InMemoryMessaging
import org.slf4j.LoggerFactory
import rx.Observable
@ -16,7 +16,6 @@ import rx.subjects.PublishSubject
import java.time.Duration
import java.time.Instant
import java.util.*
import java.util.concurrent.Executor
import java.util.concurrent.LinkedBlockingQueue
import javax.annotation.concurrent.ThreadSafe
import kotlin.concurrent.schedule
@ -80,10 +79,11 @@ class InMemoryMessagingNetwork(val sendManuallyPumped: Boolean) : SingletonSeria
*/
@Synchronized
fun createNode(manuallyPumped: Boolean,
executor: AffinityExecutor,
persistenceTx: (() -> Unit) -> Unit)
: Pair<Handle, com.r3corda.node.services.api.MessagingServiceBuilder<InMemoryMessaging>> {
check(counter >= 0) { "In memory network stopped: please recreate." }
val builder = createNodeWithID(manuallyPumped, counter, persistenceTx = persistenceTx) as Builder
val builder = createNodeWithID(manuallyPumped, counter, executor, persistenceTx = persistenceTx) as Builder
counter++
val id = builder.id
return Pair(id, builder)
@ -97,10 +97,10 @@ class InMemoryMessagingNetwork(val sendManuallyPumped: Boolean) : SingletonSeria
* @param description text string that identifies this node for message logging (if it is enabled) or null to autogenerate.
* @param persistenceTx a lambda to wrap message handling in a transaction if necessary.
*/
fun createNodeWithID(manuallyPumped: Boolean, id: Int, description: String? = null,
fun createNodeWithID(manuallyPumped: Boolean, id: Int, executor: AffinityExecutor, description: String? = null,
persistenceTx: (() -> Unit) -> Unit)
: com.r3corda.node.services.api.MessagingServiceBuilder<InMemoryMessaging> {
return Builder(manuallyPumped, Handle(id, description ?: "In memory node $id"), persistenceTx)
: MessagingServiceBuilder<InMemoryMessaging> {
return Builder(manuallyPumped, Handle(id, description ?: "In memory node $id"), executor, persistenceTx)
}
interface LatencyCalculator {
@ -140,11 +140,11 @@ class InMemoryMessagingNetwork(val sendManuallyPumped: Boolean) : SingletonSeria
messageReceiveQueues.clear()
}
inner class Builder(val manuallyPumped: Boolean, val id: Handle, val persistenceTx: (() -> Unit) -> Unit)
inner class Builder(val manuallyPumped: Boolean, val id: Handle, val executor: AffinityExecutor, val persistenceTx: (() -> Unit) -> Unit)
: com.r3corda.node.services.api.MessagingServiceBuilder<InMemoryMessaging> {
override fun start(): ListenableFuture<InMemoryMessaging> {
synchronized(this@InMemoryMessagingNetwork) {
val node = InMemoryMessaging(manuallyPumped, id, persistenceTx)
val node = InMemoryMessaging(manuallyPumped, id, executor, persistenceTx)
handleEndpointMap[id] = node
return Futures.immediateFuture(node)
}
@ -207,9 +207,10 @@ class InMemoryMessagingNetwork(val sendManuallyPumped: Boolean) : SingletonSeria
@ThreadSafe
inner class InMemoryMessaging(private val manuallyPumped: Boolean,
private val handle: Handle,
private val executor: AffinityExecutor,
private val persistenceTx: (() -> Unit) -> Unit)
: SingletonSerializeAsToken(), com.r3corda.node.services.api.MessagingServiceInternal {
inner class Handler(val executor: Executor?, val topicSession: TopicSession,
inner class Handler(val topicSession: TopicSession,
val callback: (Message, MessageHandlerRegistration) -> Unit) : MessageHandlerRegistration
@Volatile
@ -236,13 +237,13 @@ class InMemoryMessagingNetwork(val sendManuallyPumped: Boolean) : SingletonSeria
}
}
override fun addMessageHandler(topic: String, sessionID: Long, executor: Executor?, callback: (Message, MessageHandlerRegistration) -> Unit): MessageHandlerRegistration
= addMessageHandler(TopicSession(topic, sessionID), executor, callback)
override fun addMessageHandler(topic: String, sessionID: Long, callback: (Message, MessageHandlerRegistration) -> Unit): MessageHandlerRegistration
= addMessageHandler(TopicSession(topic, sessionID), callback)
override fun addMessageHandler(topicSession: TopicSession, executor: Executor?, callback: (Message, MessageHandlerRegistration) -> Unit): MessageHandlerRegistration {
override fun addMessageHandler(topicSession: TopicSession, callback: (Message, MessageHandlerRegistration) -> Unit): MessageHandlerRegistration {
check(running)
val (handler, items) = state.locked {
val handler = Handler(executor, topicSession, callback).apply { handlers.add(this) }
val handler = Handler(topicSession, callback).apply { handlers.add(this) }
val items = ArrayList(pendingRedelivery)
pendingRedelivery.clear()
Pair(handler, items)
@ -298,7 +299,12 @@ class InMemoryMessagingNetwork(val sendManuallyPumped: Boolean) : SingletonSeria
fun pumpReceive(block: Boolean): MessageTransfer? {
check(manuallyPumped)
check(running)
return pumpReceiveInternal(block)
executor.flush()
try {
return pumpReceiveInternal(block)
} finally {
executor.flush()
}
}
/**
@ -341,24 +347,22 @@ class InMemoryMessagingNetwork(val sendManuallyPumped: Boolean) : SingletonSeria
val (transfer, deliverTo) = next
if (transfer.message.uniqueMessageId !in processedMessages) {
for (handler in deliverTo) {
// Now deliver via the requested executor, or on this thread if no executor was provided at registration time.
(handler.executor ?: MoreExecutors.directExecutor()).execute {
try {
persistenceTx {
executor.execute {
persistenceTx {
for (handler in deliverTo) {
try {
handler.callback(transfer.message, handler)
} catch(e: Exception) {
log.error("Caught exception in handler for $this/${handler.topicSession}", e)
}
} catch(e: Exception) {
loggerFor<InMemoryMessagingNetwork>().error("Caught exception in handler for $this/${handler.topicSession}", e)
}
_receivedMessages.onNext(transfer)
processedMessages += transfer.message.uniqueMessageId
}
}
_receivedMessages.onNext(transfer)
processedMessages += transfer.message.uniqueMessageId
} else {
log.info("Drop duplicate message ${transfer.message.uniqueMessageId}")
}
return transfer
}
}
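A hedged sketch of driving a manually pumped in-memory node after this change; names are illustrative. Delivery now always goes through the node's AffinityExecutor, and pumpReceive() flushes it before and after so handler side effects are visible when the call returns:

// assumes DEFAULT_SESSION_ID is imported from core
val network = InMemoryMessagingNetwork(sendManuallyPumped = false)
val (handle, builder) = network.createNode(
        manuallyPumped = true,
        executor = AffinityExecutor.ServiceAffinityExecutor("example node", 1),
        persistenceTx = { it() })
val node = builder.start().get()
node.addMessageHandler("example.topic", DEFAULT_SESSION_ID) { message, reg ->
    // runs on the node's executor while the network is pumped
}
node.pumpReceive(block = false)   // delivers at most one pending message, flushing around it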

View File

@ -16,6 +16,7 @@ import com.r3corda.core.utilities.DUMMY_NOTARY_KEY
import com.r3corda.core.utilities.loggerFor
import com.r3corda.node.internal.AbstractNode
import com.r3corda.node.services.api.CheckpointStorage
import com.r3corda.node.services.api.MessagingServiceInternal
import com.r3corda.node.services.config.NodeConfiguration
import com.r3corda.node.services.keys.E2ETestKeyManagementService
import com.r3corda.node.services.messaging.CordaRPCOps
@ -26,6 +27,8 @@ import com.r3corda.node.services.persistence.PerFileCheckpointStorage
import com.r3corda.node.services.transactions.InMemoryUniquenessProvider
import com.r3corda.node.services.transactions.SimpleNotaryService
import com.r3corda.node.services.transactions.ValidatingNotaryService
import com.r3corda.node.utilities.AffinityExecutor
import com.r3corda.node.utilities.AffinityExecutor.ServiceAffinityExecutor
import com.r3corda.node.utilities.databaseTransaction
import org.slf4j.Logger
import java.nio.file.FileSystem
@ -33,6 +36,8 @@ import java.nio.file.Files
import java.nio.file.Path
import java.security.KeyPair
import java.util.*
import java.util.concurrent.TimeUnit
import java.util.concurrent.atomic.AtomicInteger
/**
* A mock node brings up a suite of in-memory services in a fast manner suitable for unit testing.
@ -50,7 +55,7 @@ import java.util.*
class MockNetwork(private val networkSendManuallyPumped: Boolean = false,
private val threadPerNode: Boolean = false,
private val defaultFactory: Factory = MockNetwork.DefaultFactory) {
private var counter = 0
private var nextNodeId = 0
val filesystem: FileSystem = Jimfs.newFileSystem(unix())
val messagingNetwork = InMemoryMessagingNetwork(networkSendManuallyPumped)
@ -80,21 +85,47 @@ class MockNetwork(private val networkSendManuallyPumped: Boolean = false,
}
}
/**
* Because this executor is shared, we need to be careful about nodes shutting it down.
*/
private val sharedUserCount = AtomicInteger(0)
private val sharedServerThread =
object : ServiceAffinityExecutor("Mock network shared node thread", 1) {
override fun shutdown() {
// We don't actually allow the shutdown of the network-wide shared thread pool until all references to
// it have been shut down.
if (sharedUserCount.decrementAndGet() == 0) {
super.shutdown()
}
}
override fun awaitTermination(timeout: Long, unit: TimeUnit): Boolean {
if (!isShutdown) {
flush()
return true
} else {
return super.awaitTermination(timeout, unit)
}
}
}
open class MockNode(config: NodeConfiguration, val mockNet: MockNetwork, networkMapAddr: SingleMessageRecipient?,
advertisedServices: Set<ServiceInfo>, val id: Int, val keyPair: KeyPair?) : AbstractNode(config, networkMapAddr, advertisedServices, TestClock()) {
override val log: Logger = loggerFor<MockNode>()
override val serverThread: com.r3corda.node.utilities.AffinityExecutor =
override val serverThread: AffinityExecutor =
if (mockNet.threadPerNode)
com.r3corda.node.utilities.AffinityExecutor.ServiceAffinityExecutor("Mock node thread", 1)
else
com.r3corda.node.utilities.AffinityExecutor.Companion.SAME_THREAD
ServiceAffinityExecutor("Mock node thread", 1)
else {
mockNet.sharedUserCount.incrementAndGet()
mockNet.sharedServerThread
}
// We only need to override the messaging service here, as currently everything that hits disk does so
// through the java.nio API which we are already mocking via Jimfs.
override fun makeMessagingService(): com.r3corda.node.services.api.MessagingServiceInternal {
override fun makeMessagingService(): MessagingServiceInternal {
require(id >= 0) { "Node ID must be zero or positive, was passed: " + id }
return mockNet.messagingNetwork.createNodeWithID(!mockNet.threadPerNode, id, configuration.myLegalName,
return mockNet.messagingNetwork.createNodeWithID(!mockNet.threadPerNode, id, serverThread, configuration.myLegalName,
persistenceTx = { body: () -> Unit -> databaseTransaction(database) { body() } }).start().get()
}
@ -159,7 +190,7 @@ class MockNetwork(private val networkSendManuallyPumped: Boolean = false,
start: Boolean = true, legalName: String? = null, keyPair: KeyPair? = null,
vararg advertisedServices: ServiceInfo): MockNode {
val newNode = forcedID == -1
val id = if (newNode) counter++ else forcedID
val id = if (newNode) nextNodeId++ else forcedID
val path = filesystem.getPath("/nodes/$id")
if (newNode)
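An illustrative MockNetwork setup under the new shared-executor model; createTwoNodes() and runNetwork() are assumed to be existing MockNetwork helpers not shown in this diff:

// With threadPerNode = false every node shares sharedServerThread above, so the
// network must be pumped for asynchronous work to complete.
val mockNet = MockNetwork(networkSendManuallyPumped = false, threadPerNode = false)
val (nodeA, nodeB) = mockNet.createTwoNodes()   // assumed helper
mockNet.runNetwork()                            // assumed helper: pump messages until quiescent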

View File

@ -164,7 +164,7 @@ class MockStorageService(override val attachments: AttachmentStorage = MockAttac
fun makeTestDataSourceProperties(nodeName: String = SecureHash.randomSHA256().toString()): Properties {
val props = Properties()
props.setProperty("dataSourceClassName", "org.h2.jdbcx.JdbcDataSource")
props.setProperty("dataSource.url", "jdbc:h2:mem:${nodeName}_persistence;DB_CLOSE_ON_EXIT=FALSE")
props.setProperty("dataSource.url", "jdbc:h2:mem:${nodeName}_persistence;MVCC=TRUE;DB_CLOSE_ON_EXIT=FALSE")
props.setProperty("dataSource.user", "sa")
props.setProperty("dataSource.password", "")
return props