Thread safety and messaging bug fixes.

* Use the new AffinityExecutor code to fix some thread affinity issues where callbacks were running on the wrong threads. Add affinity assertions.
* Remove sleeps from UpdateBusinessDayProtocol.
* Remove a one-shot message handler before the callback is executed.
* Store un-routed messages in memory in ArtemisMessagingService to fix handler registration/message races. This is a temporary kludge until we use Artemis/MQ better.
This commit is contained in:
Mike Hearn 2016-04-25 15:25:58 +02:00
parent 63b8579669
commit 746aca8290
9 changed files with 124 additions and 117 deletions

View File

@ -69,8 +69,8 @@ interface MessagingService {
*/
fun MessagingService.runOnNextMessage(topic: String = "", executor: Executor? = null, callback: (Message) -> Unit) {
addMessageHandler(topic, executor) { msg, reg ->
callback(msg)
removeMessageHandler(reg)
callback(msg)
}
}

View File

@ -17,6 +17,7 @@ import core.serialization.THREAD_LOCAL_KRYO
import core.serialization.createKryo
import core.serialization.deserialize
import core.serialization.serialize
import core.utilities.AffinityExecutor
import core.utilities.ProgressTracker
import core.utilities.trace
import org.slf4j.Logger
@ -24,7 +25,6 @@ import org.slf4j.LoggerFactory
import java.io.PrintWriter
import java.io.StringWriter
import java.util.*
import java.util.concurrent.Executor
import javax.annotation.concurrent.ThreadSafe
/**
@ -38,6 +38,9 @@ import javax.annotation.concurrent.ThreadSafe
* A "state machine" is a class with a single call method. The call method and any others it invokes are rewritten by
* a bytecode rewriting engine called Quasar, to ensure the code can be suspended and resumed at any point.
*
* The SMM will always invoke the protocol fibers on the given [AffinityExecutor], regardless of which thread actually
* starts them via [add].
*
* TODO: Session IDs should be set up and propagated automatically, on demand.
* TODO: Consider the issue of continuation identity more deeply: is it a safe assumption that a serialised
* continuation is always unique?
@ -50,7 +53,7 @@ import javax.annotation.concurrent.ThreadSafe
* TODO: Implement stub/skel classes that provide a basic RPC framework on top of this.
*/
@ThreadSafe
class StateMachineManager(val serviceHub: ServiceHub, val runInThread: Executor) {
class StateMachineManager(val serviceHub: ServiceHub, val executor: AffinityExecutor) {
// This map is backed by a database and will be used to store serialised state machines to disk, so we can resurrect
// them across node restarts.
private val checkpointsMap = serviceHub.storageService.stateMachines
@ -114,7 +117,7 @@ class StateMachineManager(val serviceHub: ServiceHub, val runInThread: Executor)
val topic = checkpoint.awaitingTopic
// And now re-wire the deserialised continuation back up to the network service.
serviceHub.networkService.runOnNextMessage(topic, runInThread) { netMsg ->
serviceHub.networkService.runOnNextMessage(topic, executor) { netMsg ->
// TODO: See security note below.
val obj: Any = THREAD_LOCAL_KRYO.get().readClassAndObject(Input(netMsg.data))
if (!awaitingObjectOfType.isInstance(obj))
@ -154,15 +157,22 @@ class StateMachineManager(val serviceHub: ServiceHub, val runInThread: Executor)
* restarted with checkpointed state machines in the storage service.
*/
fun <T> add(loggerName: String, logic: ProtocolLogic<T>): ListenableFuture<T> {
val logger = LoggerFactory.getLogger(loggerName)
val fiber = ProtocolStateMachine(logic)
// Need to add before iterating in case of immediate completion
_stateMachines.add(logic)
iterateStateMachine(fiber, serviceHub.networkService, logger, null, null) {
it.start()
try {
val logger = LoggerFactory.getLogger(loggerName)
val fiber = ProtocolStateMachine(logic)
// Need to add before iterating in case of immediate completion
_stateMachines.add(logic)
executor.executeASAP {
iterateStateMachine(fiber, serviceHub.networkService, logger, null, null) {
it.start()
}
totalStartedProtocols.inc()
}
return fiber.resultFuture
} catch(e: Throwable) {
e.printStackTrace()
throw e
}
totalStartedProtocols.inc()
return fiber.resultFuture
}
private fun persistCheckpoint(prevCheckpointKey: SecureHash?, new: ByteArray): SecureHash {
@ -178,12 +188,12 @@ class StateMachineManager(val serviceHub: ServiceHub, val runInThread: Executor)
private fun iterateStateMachine(psm: ProtocolStateMachine<*>, net: MessagingService, logger: Logger,
obj: Any?, prevCheckpointKey: SecureHash?, resumeFunc: (ProtocolStateMachine<*>) -> Unit) {
executor.checkOnThread()
val onSuspend = fun(request: FiberRequest, serFiber: ByteArray) {
// We have a request to do something: send, receive, or send-and-receive.
if (request is FiberRequest.ExpectingResponse<*>) {
// Prepare a listener on the network that runs in the background thread when we received a message.
checkpointAndSetupMessageHandler(logger, net, psm, request.responseType,
"${request.topic}.${request.sessionIDForReceive}", prevCheckpointKey, serFiber)
checkpointAndSetupMessageHandler(logger, net, psm, request, prevCheckpointKey, serFiber)
}
// If an object to send was provided (not null), send it now.
request.obj?.let {
@ -217,13 +227,22 @@ class StateMachineManager(val serviceHub: ServiceHub, val runInThread: Executor)
}
private fun checkpointAndSetupMessageHandler(logger: Logger, net: MessagingService, psm: ProtocolStateMachine<*>,
responseType: Class<*>, topic: String, prevCheckpointKey: SecureHash?,
request: FiberRequest.ExpectingResponse<*>, prevCheckpointKey: SecureHash?,
serialisedFiber: ByteArray) {
val checkpoint = Checkpoint(serialisedFiber, logger.name, topic, responseType.name)
executor.checkOnThread()
val topic = "${request.topic}.${request.sessionIDForReceive}"
val checkpoint = Checkpoint(serialisedFiber, logger.name, topic, request.responseType.name)
val curPersistedBytes = checkpoint.serialize().bits
persistCheckpoint(prevCheckpointKey, curPersistedBytes)
val newCheckpointKey = curPersistedBytes.sha256()
net.runOnNextMessage(topic, runInThread) { netMsg ->
logger.trace { "Waiting for message of type ${request.responseType.name} on $topic" }
var consumed = false
net.runOnNextMessage(topic, executor) { netMsg ->
// Some assertions to ensure we don't execute on the wrong thread or get executed more than once.
executor.checkOnThread()
check(netMsg.topic == topic) { "Topic mismatch: ${netMsg.topic} vs $topic" }
check(!consumed)
consumed = true
// TODO: This is insecure: we should not deserialise whatever we find and *then* check.
//
// We should instead verify as we read the data that it's what we are expecting and throw as early as
@ -232,9 +251,8 @@ class StateMachineManager(val serviceHub: ServiceHub, val runInThread: Executor)
// at the last moment when we do the downcast. However this would make protocol code harder to read and
// make it more difficult to migrate to a more explicit serialisation scheme later.
val obj: Any = THREAD_LOCAL_KRYO.get().readClassAndObject(Input(netMsg.data))
if (!responseType.isInstance(obj))
throw ClassCastException("Expected message of type ${responseType.name} but got ${obj.javaClass.name}")
logger.trace { "<- $topic : message of type ${obj.javaClass.name}" }
if (!request.responseType.isInstance(obj))
throw IllegalStateException("Expected message of type ${request.responseType.name} but got ${obj.javaClass.name}", request.stackTraceInCaseOfProblems)
iterateStateMachine(psm, net, logger, obj, newCheckpointKey) {
try {
Fiber.unpark(it, QUASAR_UNBLOCKER)
@ -245,11 +263,16 @@ class StateMachineManager(val serviceHub: ServiceHub, val runInThread: Executor)
}
}
// TODO: Override more of this to avoid the case where Strand.sleep triggers a call to a scheduler that then runs on the wrong thread.
object SameThreadFiberScheduler : FiberExecutorScheduler("Same thread scheduler", MoreExecutors.directExecutor())
// TODO: Clean this up
open class FiberRequest(val topic: String, val destination: MessageRecipients?,
val sessionIDForSend: Long, val sessionIDForReceive: Long, val obj: Any?) {
// This is used to identify where we suspended, in case of message mismatch errors and other things where we
// don't have the original stack trace because it's in a suspended fiber.
val stackTraceInCaseOfProblems = StackSnapshot()
class ExpectingResponse<R : Any>(
topic: String,
destination: MessageRecipients?,
@ -266,4 +289,7 @@ class StateMachineManager(val serviceHub: ServiceHub, val runInThread: Executor)
obj: Any?
) : FiberRequest(topic, destination, sessionIDForSend, -1, obj)
}
}
class StackSnapshot : Throwable("This is a stack trace to help identify the source of the underlying problem")

View File

@ -3,10 +3,7 @@ package core.node
import api.APIServer
import api.APIServerImpl
import com.codahale.metrics.MetricRegistry
import contracts.*
import core.Contract
import core.Party
import core.crypto.SecureHash
import core.crypto.generateKeyPair
import core.messaging.MessagingService
import core.messaging.StateMachineManager
@ -14,6 +11,7 @@ import core.node.services.*
import core.serialization.deserialize
import core.serialization.serialize
import core.testing.MockNetworkMapCache
import core.utilities.AffinityExecutor
import org.slf4j.Logger
import java.nio.file.FileAlreadyExistsException
import java.nio.file.Files
@ -22,7 +20,6 @@ import java.security.KeyPair
import java.security.PublicKey
import java.time.Clock
import java.util.*
import java.util.concurrent.Executors
/**
* A base node implementation that can be customised either for production (with real implementations that do real
@ -36,9 +33,9 @@ abstract class AbstractNode(val dir: Path, val configuration: NodeConfiguration,
protected abstract val log: Logger
// We will run as much stuff in this thread as possible to keep the risk of thread safety bugs low during the
// We will run as much stuff in this single thread as possible to keep the risk of thread safety bugs low during the
// low-performance prototyping period.
protected open val serverThread = Executors.newSingleThreadExecutor()
protected open val serverThread: AffinityExecutor = AffinityExecutor.ServiceAffinityExecutor("Node thread", 1)
// Objects in this list will be scanned by the DataUploadServlet and can be handed new data via HTTP.
// Don't mutate this after startup.

View File

@ -58,9 +58,11 @@ class Node(dir: Path, val p2pAddr: HostAndPort, configuration: NodeConfiguration
// when our process shuts down, but we try in stop() anyway just to be nice.
private var nodeFileLock: FileLock? = null
override fun makeMessagingService(): MessagingService = ArtemisMessagingService(dir, p2pAddr)
override fun makeMessagingService(): MessagingService = ArtemisMessagingService(dir, p2pAddr, serverThread)
private fun initWebServer(): Server {
// Note that the web server handlers will all run concurrently, and not on the node thread.
val port = p2pAddr.port + 1 // TODO: Move this into the node config file.
val server = Server(port)

View File

@ -44,9 +44,14 @@ import javax.annotation.concurrent.ThreadSafe
* The current implementation is skeletal and lacks features like security or firewall tunnelling (that is, you must
* be able to receive TCP connections in order to receive messages). It is good enough for local communication within
* a fully connected network, trusted network or on localhost.
*
* @param directory A place where Artemis can stash its message journal and other files.
* @param myHostPort What host and port to bind to for receiving inbound connections.
* @param defaultExecutor This will be used as the default executor to run message handlers on, if no other is specified.
*/
@ThreadSafe
class ArtemisMessagingService(val directory: Path, val myHostPort: HostAndPort) : MessagingService {
class ArtemisMessagingService(val directory: Path, val myHostPort: HostAndPort,
val defaultExecutor: Executor = RunOnCallerThread) : MessagingService {
// In future: can contain onion routing info, etc.
private data class Address(val hostAndPort: HostAndPort) : SingleMessageRecipient
@ -83,6 +88,9 @@ class ArtemisMessagingService(val directory: Path, val myHostPort: HostAndPort)
private val handlers = CopyOnWriteArrayList<Handler>()
// TODO: This is not robust and needs to be replaced by more intelligently using the message queue server.
private val undeliveredMessages = CopyOnWriteArrayList<Message>()
private fun getSendClient(addr: Address): ClientProducer {
return mutex.locked {
sendClients.getOrPut(addr) {
@ -131,20 +139,10 @@ class ArtemisMessagingService(val directory: Path, val myHostPort: HostAndPort)
// This code runs for every inbound message.
try {
if (!message.containsProperty(TOPIC_PROPERTY)) {
log.warn("Received message without a ${TOPIC_PROPERTY} property, ignoring")
log.warn("Received message without a $TOPIC_PROPERTY property, ignoring")
return@setMessageHandler
}
val topic = message.getStringProperty(TOPIC_PROPERTY)
// Because handlers is a COW list, the loop inside filter will operate on a snapshot. Handlers being added
// or removed whilst the filter is executing will not affect anything.
val deliverTo = handlers.filter { if (it.topic.isBlank()) true else it.topic == topic }
if (deliverTo.isEmpty()) {
// This should probably be downgraded to a trace in future, so the protocol can evolve with new topics
// without causing log spam.
log.warn("Received message for $topic that doesn't have any registered handlers.")
return@setMessageHandler
}
val bits = ByteArray(message.bodySize)
message.bodyBuffer.readBytes(bits)
@ -156,15 +154,8 @@ class ArtemisMessagingService(val directory: Path, val myHostPort: HostAndPort)
override val debugMessageID: String = message.messageID.toString()
override fun serialise(): ByteArray = bits
}
for (handler in deliverTo) {
(handler.executor ?: RunOnCallerThread).execute {
try {
handler.callback(msg, handler)
} catch(e: Exception) {
log.error("Caught exception whilst executing message handler for $topic", e)
}
}
}
deliverMessage(msg)
} finally {
message.acknowledge()
}
@ -174,6 +165,36 @@ class ArtemisMessagingService(val directory: Path, val myHostPort: HostAndPort)
mutex.locked { running = true }
}
private fun deliverMessage(msg: Message): Boolean {
// Because handlers is a COW list, the loop inside filter will operate on a snapshot. Handlers being added
// or removed whilst the filter is executing will not affect anything.
val deliverTo = handlers.filter { if (it.topic.isBlank()) true else it.topic == msg.topic }
if (deliverTo.isEmpty()) {
// This should probably be downgraded to a trace in future, so the protocol can evolve with new topics
// without causing log spam.
log.warn("Received message for ${msg.topic} that doesn't have any registered handlers yet")
// This is a hack; transient messages held in memory isn't crash resistant.
// TODO: Use Artemis API more effectively so we don't pop messages off a queue that we aren't ready to use.
undeliveredMessages += msg
return false
}
for (handler in deliverTo) {
(handler.executor ?: defaultExecutor).execute {
try {
handler.callback(msg, handler)
} catch(e: Exception) {
log.error("Caught exception whilst executing message handler for ${msg.topic}", e)
}
}
}
return true
}
override fun stop() {
mutex.locked {
for (producer in sendClients.values)
@ -200,6 +221,7 @@ class ArtemisMessagingService(val directory: Path, val myHostPort: HostAndPort)
callback: (Message, MessageHandlerRegistration) -> Unit): MessageHandlerRegistration {
val handler = Handler(executor, topic, callback)
handlers.add(handler)
undeliveredMessages.removeIf { deliverMessage(it) }
return handler
}

View File

@ -1,7 +1,6 @@
package core.testing
import com.google.common.jimfs.Jimfs
import com.google.common.util.concurrent.MoreExecutors
import core.Party
import core.messaging.MessagingService
import core.messaging.SingleMessageRecipient
@ -12,14 +11,13 @@ import core.node.PhysicalLocation
import core.testing.MockIdentityService
import core.node.services.ServiceType
import core.node.services.TimestamperService
import core.utilities.AffinityExecutor
import core.utilities.loggerFor
import org.slf4j.Logger
import java.nio.file.Files
import java.nio.file.Path
import java.time.Clock
import java.util.*
import java.util.concurrent.ExecutorService
import java.util.concurrent.Executors
/**
* A mock node brings up a suite of in-memory services in a fast manner suitable for unit testing.
@ -61,11 +59,11 @@ class MockNetwork(private val threadPerNode: Boolean = false,
open class MockNode(dir: Path, config: NodeConfiguration, val mockNet: MockNetwork,
withTimestamper: NodeInfo?, val id: Int) : AbstractNode(dir, config, withTimestamper, Clock.systemUTC()) {
override val log: Logger = loggerFor<MockNode>()
override val serverThread: ExecutorService =
override val serverThread: AffinityExecutor =
if (mockNet.threadPerNode)
Executors.newSingleThreadExecutor()
AffinityExecutor.ServiceAffinityExecutor("Mock node thread", 1)
else
MoreExecutors.newDirectExecutorService()
AffinityExecutor.SAME_THREAD
// We only need to override the messaging service here, as currently everything that hits disk does so
// through the java.nio API which we are already mocking via Jimfs.

View File

@ -1,6 +1,5 @@
package core.utilities
import com.google.common.base.Preconditions.checkState
import com.google.common.util.concurrent.Uninterruptibles
import java.time.Duration
import java.util.*
@ -16,10 +15,18 @@ interface AffinityExecutor : Executor {
val isOnThread: Boolean
/** Throws an IllegalStateException if the current thread is equal to the thread this executor is backed by. */
fun checkOnThread()
fun checkOnThread() {
if (!isOnThread)
throw IllegalStateException("On wrong thread: " + Thread.currentThread())
}
/** If isOnThread() then runnable is invoked immediately, otherwise the closure is queued onto the backing thread. */
fun executeASAP(runnable: () -> Unit)
fun executeASAP(runnable: () -> Unit) {
if (isOnThread)
runnable()
else
execute(runnable)
}
/** Terminates any backing thread (pool) without waiting for tasks to finish. */
fun shutdownNow()
@ -35,43 +42,11 @@ interface AffinityExecutor : Executor {
return CompletableFuture.supplyAsync(Supplier { fetcher() }, this).get()
}
abstract class BaseAffinityExecutor protected constructor() : AffinityExecutor {
protected val exceptionHandler: Thread.UncaughtExceptionHandler
init {
exceptionHandler = Thread.currentThread().uncaughtExceptionHandler
}
abstract override val isOnThread: Boolean
override fun checkOnThread() {
checkState(isOnThread, "On wrong thread: %s", Thread.currentThread())
}
override fun executeASAP(runnable: () -> Unit) {
val command = {
try {
runnable()
} catch (throwable: Throwable) {
exceptionHandler.uncaughtException(Thread.currentThread(), throwable)
}
}
if (isOnThread)
command()
else {
execute(command)
}
}
// Must comply with the Executor definition w.r.t. exceptions here.
abstract override fun execute(command: Runnable)
}
/**
* An executor backed by thread pool (which may often have a single thread) which makes it easy to schedule
* tasks in the future and verify code is running on the executor.
*/
class ServiceAffinityExecutor(threadName: String, numThreads: Int) : BaseAffinityExecutor() {
class ServiceAffinityExecutor(threadName: String, numThreads: Int) : AffinityExecutor {
protected val threads = Collections.synchronizedSet(HashSet<Thread>())
private val handler = Thread.currentThread().uncaughtExceptionHandler
@ -81,8 +56,15 @@ interface AffinityExecutor : Executor {
val threadFactory = fun(runnable: Runnable): Thread {
val thread = object : Thread() {
override fun run() {
runnable.run()
threads -= this
try {
runnable.run()
} catch (e: Throwable) {
e.printStackTrace()
handler.uncaughtException(this, e)
throw e
} finally {
threads -= this
}
}
}
thread.isDaemon = true
@ -100,29 +82,12 @@ interface AffinityExecutor : Executor {
override fun execute(command: Runnable) {
service.execute {
try {
command.run()
} catch (e: Throwable) {
if (handler != null)
handler.uncaughtException(Thread.currentThread(), e)
else
e.printStackTrace()
}
command.run()
}
}
fun <T> executeIn(time: Duration, command: () -> T): ScheduledFuture<T> {
return service.schedule(Callable {
try {
command()
} catch (e: Throwable) {
if (handler != null)
handler.uncaughtException(Thread.currentThread(), e)
else
e.printStackTrace()
throw e
}
}, time.toMillis(), TimeUnit.MILLISECONDS)
return service.schedule(Callable { command() }, time.toMillis(), TimeUnit.MILLISECONDS)
}
override fun shutdownNow() {
@ -140,7 +105,7 @@ interface AffinityExecutor : Executor {
*
* @param alwaysQueue If true, executeASAP will never short-circuit and will always queue up.
*/
class Gate(private val alwaysQueue: Boolean = false) : BaseAffinityExecutor() {
class Gate(private val alwaysQueue: Boolean = false) : AffinityExecutor {
private val thisThread = Thread.currentThread()
private val commandQ = LinkedBlockingQueue<Runnable>()
@ -163,7 +128,7 @@ interface AffinityExecutor : Executor {
}
companion object {
val SAME_THREAD: AffinityExecutor = object : BaseAffinityExecutor() {
val SAME_THREAD: AffinityExecutor = object : AffinityExecutor {
override val isOnThread: Boolean get() = true
override fun execute(command: Runnable) = command.run()
override fun shutdownNow() {

View File

@ -1,7 +1,6 @@
package demos.protocols
import co.paralleluniverse.fibers.Suspendable
import co.paralleluniverse.strands.Strand
import contracts.DealState
import contracts.InterestRateSwap
import core.StateAndRef
@ -100,7 +99,6 @@ object UpdateBusinessDayProtocol {
progressTracker.currentStep = FIXING
val participant = TwoPartyDealProtocol.Floater(party.address, sessionID, serviceHub.networkMapCache.timestampingNodes[0], dealStateAndRef, serviceHub.keyManagementService.freshKey(), sessionID, progressTracker.childrenFor[FIXING]!!)
Strand.sleep(100)
val result = subProtocol(participant)
return result.tx.outRef(0)
}
@ -119,7 +117,6 @@ object UpdateBusinessDayProtocol {
data class UpdateBusinessDayMessage(val date: LocalDate, val sessionID: Long)
object Handler {
fun register(node: Node) {
node.net.addMessageHandler("${TOPIC}.0") { msg, registration ->
// Just to validate we got the message

View File

@ -79,7 +79,7 @@ object TwoPartyDealProtocol {
val sessionID = random63BitValue()
// Make the first message we'll send to kick off the protocol.
val hello = Handshake<U>(payload, myKeyPair.public, sessionID)
val hello = Handshake(payload, myKeyPair.public, sessionID)
val maybeSTX = sendAndReceive<SignedTransaction>(DEAL_TOPIC, otherSide, otherSessionID, sessionID, hello)