From 3c578550a981d5bd5f738842defe49536f490a9c Mon Sep 17 00:00:00 2001 From: Mike Hearn <mike@r3cev.com> Date: Mon, 14 Dec 2015 21:41:57 +0100 Subject: [PATCH] State machines: don't leak references to completed state machines. Add an extension function for working with futures. --- src/main/kotlin/core/Utils.kt | 7 ++ .../kotlin/core/messaging/StateMachines.kt | 66 ++++++++++++++----- .../messaging/TwoPartyTradeProtocolTests.kt | 2 + 3 files changed, 58 insertions(+), 17 deletions(-) diff --git a/src/main/kotlin/core/Utils.kt b/src/main/kotlin/core/Utils.kt index ed77c5e44b..85dcac6c5a 100644 --- a/src/main/kotlin/core/Utils.kt +++ b/src/main/kotlin/core/Utils.kt @@ -9,10 +9,13 @@ package core import com.google.common.io.BaseEncoding +import com.google.common.util.concurrent.ListenableFuture +import com.google.common.util.concurrent.MoreExecutors import com.google.common.util.concurrent.SettableFuture import org.slf4j.Logger import java.time.Duration import java.util.* +import java.util.concurrent.Executor /** A simple class that wraps a byte array and makes the equals/hashCode/toString methods work as you actually expect */ open class OpaqueBytes(val bits: ByteArray) { @@ -41,6 +44,10 @@ val Int.hours: Duration get() = Duration.ofHours(this.toLong()) val Int.minutes: Duration get() = Duration.ofMinutes(this.toLong()) val Int.seconds: Duration get() = Duration.ofSeconds(this.toLong()) +fun <T> ListenableFuture<T>.whenComplete(executor: Executor? = null, body: () -> Unit) { + addListener(Runnable { body() }, executor ?: MoreExecutors.directExecutor()) +} + /** Executes the given block and sets the future to either the result, or any exception that was thrown. */ fun <T> SettableFuture<T>.setFrom(logger: Logger? = null, block: () -> T): SettableFuture<T> { try { diff --git a/src/main/kotlin/core/messaging/StateMachines.kt b/src/main/kotlin/core/messaging/StateMachines.kt index 409be67968..e3e7331354 100644 --- a/src/main/kotlin/core/messaging/StateMachines.kt +++ b/src/main/kotlin/core/messaging/StateMachines.kt @@ -18,6 +18,7 @@ import core.serialization.createKryo import core.serialization.deserialize import core.serialization.serialize import core.utilities.trace +import core.whenComplete import org.apache.commons.javaflow.Continuation import org.apache.commons.javaflow.ContinuationClassLoader import org.objenesis.instantiator.ObjectInstantiator @@ -35,11 +36,20 @@ import java.util.concurrent.Executor * and, if run with a single-threaded executor, will ensure no two state machines run concurrently with each other * (bad for performance, good for programmer mental health!). * + * A "state machine" is a class with a single call method. The call method and any others it invokes are rewritten by + * a bytecode rewriting engine called JavaFlow, to ensure the code can be suspended and resumed at any point. + * * TODO: The framework should propagate exceptions and handle error handling automatically. + * TODO: This needs extension to the >2 party case. + * TODO: Thread safety */ class StateMachineManager(val serviceHub: ServiceHub, val runInThread: Executor) { + // This map is backed by a database and will be used to store serialised state machines to disk, so we can resurrect + // them across node restarts. private val checkpointsMap = serviceHub.storageService.getMap<SecureHash, ByteArray>("state machines") - private val _stateMachines: MutableList<ProtocolStateMachine<*,*>> = ArrayList() + // A list of all the state machines being managed by this class. We expose snapshots of it via the stateMachines + // property. + private val _stateMachines = ArrayList<ProtocolStateMachine<*,*>>() /** Returns a snapshot of the currently registered state machines. */ val stateMachines: List<ProtocolStateMachine<*,*>> get() = ArrayList(_stateMachines) @@ -57,50 +67,65 @@ class StateMachineManager(val serviceHub: ServiceHub, val runInThread: Executor) restoreCheckpoints() } + /** Reads the database map and resurrects any serialised state machines. */ private fun restoreCheckpoints() { for (bytes in checkpointsMap.values) { val kryo = createKryo() // Set up Kryo to use the JavaFlow classloader when deserialising, so the magical continuation bytecode // rewriting is performed correctly. - var psm: ProtocolStateMachine<*, *>? = null + var _psm: ProtocolStateMachine<*, *>? = null kryo.instantiatorStrategy = object : InstantiatorStrategy { val forwardingTo = kryo.instantiatorStrategy override fun <T> newInstantiatorOf(type: Class<T>): ObjectInstantiator<T> { - if (ProtocolStateMachine::class.java.isAssignableFrom(type)) { - // The messing around with types we do here confuses the compiler/IDE a bit and it warns us. - @Suppress("UNCHECKED_CAST", "CAST_NEVER_SUCCEEDS") - return ObjectInstantiator<T> { - val p = loadContinuationClass(type as Class<out ProtocolStateMachine<*, *>>).first - p.serviceHub = serviceHub - psm = p - psm as T - } - } else { + // If this is some object that isn't a state machine, use the default behaviour. + if (!ProtocolStateMachine::class.java.isAssignableFrom(type)) return forwardingTo.newInstantiatorOf(type) + + // Otherwise, return an 'object instantiator' (i.e. factory) that uses the JavaFlow classloader. + @Suppress("UNCHECKED_CAST", "CAST_NEVER_SUCCEEDS") + return ObjectInstantiator<T> { + val p = loadContinuationClass(type as Class<out ProtocolStateMachine<*, *>>).first + // Pass the new object a pointer to the service hub where it can find objects that don't + // survive restarts. + p.serviceHub = serviceHub + _psm = p + p as T } } } - val checkpoint = bytes.deserialize<Checkpoint>(kryo) + val checkpoint = bytes.deserialize<Checkpoint>(kryo) val continuation = checkpoint.continuation - _stateMachines.add(psm!!) + + // We know _psm can't be null here, because we always serialise a ProtocolStateMachine subclass, so the + // code above that does "_psm = p" will always run. But the Kotlin compiler can't know that so we have to + // forcibly cast away the nullness with the !! operator. + val psm = _psm!! + registerStateMachine(psm) + val logger = LoggerFactory.getLogger(checkpoint.loggerName) val awaitingObjectOfType = Class.forName(checkpoint.awaitingObjectOfType) - // The act of calling this method re-persists the bytes into the in-memory hashmap so re-saving the - // StateMachineManager to disk will work even if some state machines didn't wake up in the intervening time. + + // And now re-wire the deserialised continuation back up to the network service. setupNextMessageHandler(logger, serviceHub.networkService, continuation, checkpoint.otherSide, awaitingObjectOfType, checkpoint.awaitingTopic, bytes) } } + /** + * Kicks off a brand new state machine of the given class. It will send messages to the network node identified by + * the [otherSide] parameter, log with the named logger, and the [initialArgs] object will be passed to the call + * method of the [ProtocolStateMachine] object that is created. The state machine will be persisted when it suspends + * and will be removed once it completes. + */ fun <T : ProtocolStateMachine<I, *>, I> add(otherSide: MessageRecipients, initialArgs: I, loggerName: String, continuationClass: Class<out T>): T { val logger = LoggerFactory.getLogger(loggerName) val (sm, continuation) = loadContinuationClass(continuationClass) sm.serviceHub = serviceHub - _stateMachines.add(sm) + registerStateMachine(sm) runInThread.execute { // The current state of the continuation is held in the closure attached to the messaging system whenever // the continuation suspends and tells us it expects a response. @@ -110,6 +135,13 @@ class StateMachineManager(val serviceHub: ServiceHub, val runInThread: Executor) return sm as T } + private fun registerStateMachine(psm: ProtocolStateMachine<*, *>) { + _stateMachines.add(psm) + psm.resultFuture.whenComplete(runInThread) { + _stateMachines.remove(psm) + } + } + @Suppress("UNCHECKED_CAST") private fun loadContinuationClass(continuationClass: Class<out ProtocolStateMachine<*, *>>): Pair<ProtocolStateMachine<*, *>, Continuation> { val url = continuationClass.protectionDomain.codeSource.location diff --git a/src/test/kotlin/core/messaging/TwoPartyTradeProtocolTests.kt b/src/test/kotlin/core/messaging/TwoPartyTradeProtocolTests.kt index 7da81e145e..e489a50270 100644 --- a/src/test/kotlin/core/messaging/TwoPartyTradeProtocolTests.kt +++ b/src/test/kotlin/core/messaging/TwoPartyTradeProtocolTests.kt @@ -188,6 +188,8 @@ class TwoPartyTradeProtocolTests : TestWithInMemoryNetwork() { val tx = bobFuture.get() txns.add(tx.second) verify() + + assertTrue(smm.stateMachines.isEmpty()) } } }