State machines: don't leak references to completed state machines. Add an extension function for working with futures.

This commit is contained in:
Mike Hearn 2015-12-14 21:41:57 +01:00
parent 62f7237364
commit 3c578550a9
3 changed files with 58 additions and 17 deletions

View File

@ -9,10 +9,13 @@
package core package core
import com.google.common.io.BaseEncoding import com.google.common.io.BaseEncoding
import com.google.common.util.concurrent.ListenableFuture
import com.google.common.util.concurrent.MoreExecutors
import com.google.common.util.concurrent.SettableFuture import com.google.common.util.concurrent.SettableFuture
import org.slf4j.Logger import org.slf4j.Logger
import java.time.Duration import java.time.Duration
import java.util.* import java.util.*
import java.util.concurrent.Executor
/** A simple class that wraps a byte array and makes the equals/hashCode/toString methods work as you actually expect */ /** A simple class that wraps a byte array and makes the equals/hashCode/toString methods work as you actually expect */
open class OpaqueBytes(val bits: ByteArray) { open class OpaqueBytes(val bits: ByteArray) {
@ -41,6 +44,10 @@ val Int.hours: Duration get() = Duration.ofHours(this.toLong())
val Int.minutes: Duration get() = Duration.ofMinutes(this.toLong()) val Int.minutes: Duration get() = Duration.ofMinutes(this.toLong())
val Int.seconds: Duration get() = Duration.ofSeconds(this.toLong()) val Int.seconds: Duration get() = Duration.ofSeconds(this.toLong())
fun <T> ListenableFuture<T>.whenComplete(executor: Executor? = null, body: () -> Unit) {
addListener(Runnable { body() }, executor ?: MoreExecutors.directExecutor())
}
/** Executes the given block and sets the future to either the result, or any exception that was thrown. */ /** Executes the given block and sets the future to either the result, or any exception that was thrown. */
fun <T> SettableFuture<T>.setFrom(logger: Logger? = null, block: () -> T): SettableFuture<T> { fun <T> SettableFuture<T>.setFrom(logger: Logger? = null, block: () -> T): SettableFuture<T> {
try { try {

View File

@ -18,6 +18,7 @@ import core.serialization.createKryo
import core.serialization.deserialize import core.serialization.deserialize
import core.serialization.serialize import core.serialization.serialize
import core.utilities.trace import core.utilities.trace
import core.whenComplete
import org.apache.commons.javaflow.Continuation import org.apache.commons.javaflow.Continuation
import org.apache.commons.javaflow.ContinuationClassLoader import org.apache.commons.javaflow.ContinuationClassLoader
import org.objenesis.instantiator.ObjectInstantiator import org.objenesis.instantiator.ObjectInstantiator
@ -35,11 +36,20 @@ import java.util.concurrent.Executor
* and, if run with a single-threaded executor, will ensure no two state machines run concurrently with each other * and, if run with a single-threaded executor, will ensure no two state machines run concurrently with each other
* (bad for performance, good for programmer mental health!). * (bad for performance, good for programmer mental health!).
* *
* A "state machine" is a class with a single call method. The call method and any others it invokes are rewritten by
* a bytecode rewriting engine called JavaFlow, to ensure the code can be suspended and resumed at any point.
*
* TODO: The framework should propagate exceptions and handle error handling automatically. * TODO: The framework should propagate exceptions and handle error handling automatically.
* TODO: This needs extension to the >2 party case.
* TODO: Thread safety
*/ */
class StateMachineManager(val serviceHub: ServiceHub, val runInThread: Executor) { class StateMachineManager(val serviceHub: ServiceHub, val runInThread: Executor) {
// This map is backed by a database and will be used to store serialised state machines to disk, so we can resurrect
// them across node restarts.
private val checkpointsMap = serviceHub.storageService.getMap<SecureHash, ByteArray>("state machines") private val checkpointsMap = serviceHub.storageService.getMap<SecureHash, ByteArray>("state machines")
private val _stateMachines: MutableList<ProtocolStateMachine<*,*>> = ArrayList() // A list of all the state machines being managed by this class. We expose snapshots of it via the stateMachines
// property.
private val _stateMachines = ArrayList<ProtocolStateMachine<*,*>>()
/** Returns a snapshot of the currently registered state machines. */ /** Returns a snapshot of the currently registered state machines. */
val stateMachines: List<ProtocolStateMachine<*,*>> get() = ArrayList(_stateMachines) val stateMachines: List<ProtocolStateMachine<*,*>> get() = ArrayList(_stateMachines)
@ -57,50 +67,65 @@ class StateMachineManager(val serviceHub: ServiceHub, val runInThread: Executor)
restoreCheckpoints() restoreCheckpoints()
} }
/** Reads the database map and resurrects any serialised state machines. */
private fun restoreCheckpoints() { private fun restoreCheckpoints() {
for (bytes in checkpointsMap.values) { for (bytes in checkpointsMap.values) {
val kryo = createKryo() val kryo = createKryo()
// Set up Kryo to use the JavaFlow classloader when deserialising, so the magical continuation bytecode // Set up Kryo to use the JavaFlow classloader when deserialising, so the magical continuation bytecode
// rewriting is performed correctly. // rewriting is performed correctly.
var psm: ProtocolStateMachine<*, *>? = null var _psm: ProtocolStateMachine<*, *>? = null
kryo.instantiatorStrategy = object : InstantiatorStrategy { kryo.instantiatorStrategy = object : InstantiatorStrategy {
val forwardingTo = kryo.instantiatorStrategy val forwardingTo = kryo.instantiatorStrategy
override fun <T> newInstantiatorOf(type: Class<T>): ObjectInstantiator<T> { override fun <T> newInstantiatorOf(type: Class<T>): ObjectInstantiator<T> {
if (ProtocolStateMachine::class.java.isAssignableFrom(type)) { // If this is some object that isn't a state machine, use the default behaviour.
// The messing around with types we do here confuses the compiler/IDE a bit and it warns us. if (!ProtocolStateMachine::class.java.isAssignableFrom(type))
@Suppress("UNCHECKED_CAST", "CAST_NEVER_SUCCEEDS")
return ObjectInstantiator<T> {
val p = loadContinuationClass(type as Class<out ProtocolStateMachine<*, *>>).first
p.serviceHub = serviceHub
psm = p
psm as T
}
} else {
return forwardingTo.newInstantiatorOf(type) return forwardingTo.newInstantiatorOf(type)
// Otherwise, return an 'object instantiator' (i.e. factory) that uses the JavaFlow classloader.
@Suppress("UNCHECKED_CAST", "CAST_NEVER_SUCCEEDS")
return ObjectInstantiator<T> {
val p = loadContinuationClass(type as Class<out ProtocolStateMachine<*, *>>).first
// Pass the new object a pointer to the service hub where it can find objects that don't
// survive restarts.
p.serviceHub = serviceHub
_psm = p
p as T
} }
} }
} }
val checkpoint = bytes.deserialize<Checkpoint>(kryo)
val checkpoint = bytes.deserialize<Checkpoint>(kryo)
val continuation = checkpoint.continuation val continuation = checkpoint.continuation
_stateMachines.add(psm!!)
// We know _psm can't be null here, because we always serialise a ProtocolStateMachine subclass, so the
// code above that does "_psm = p" will always run. But the Kotlin compiler can't know that so we have to
// forcibly cast away the nullness with the !! operator.
val psm = _psm!!
registerStateMachine(psm)
val logger = LoggerFactory.getLogger(checkpoint.loggerName) val logger = LoggerFactory.getLogger(checkpoint.loggerName)
val awaitingObjectOfType = Class.forName(checkpoint.awaitingObjectOfType) val awaitingObjectOfType = Class.forName(checkpoint.awaitingObjectOfType)
// The act of calling this method re-persists the bytes into the in-memory hashmap so re-saving the
// StateMachineManager to disk will work even if some state machines didn't wake up in the intervening time. // And now re-wire the deserialised continuation back up to the network service.
setupNextMessageHandler(logger, serviceHub.networkService, continuation, checkpoint.otherSide, setupNextMessageHandler(logger, serviceHub.networkService, continuation, checkpoint.otherSide,
awaitingObjectOfType, checkpoint.awaitingTopic, bytes) awaitingObjectOfType, checkpoint.awaitingTopic, bytes)
} }
} }
/**
* Kicks off a brand new state machine of the given class. It will send messages to the network node identified by
* the [otherSide] parameter, log with the named logger, and the [initialArgs] object will be passed to the call
* method of the [ProtocolStateMachine] object that is created. The state machine will be persisted when it suspends
* and will be removed once it completes.
*/
fun <T : ProtocolStateMachine<I, *>, I> add(otherSide: MessageRecipients, initialArgs: I, loggerName: String, fun <T : ProtocolStateMachine<I, *>, I> add(otherSide: MessageRecipients, initialArgs: I, loggerName: String,
continuationClass: Class<out T>): T { continuationClass: Class<out T>): T {
val logger = LoggerFactory.getLogger(loggerName) val logger = LoggerFactory.getLogger(loggerName)
val (sm, continuation) = loadContinuationClass(continuationClass) val (sm, continuation) = loadContinuationClass(continuationClass)
sm.serviceHub = serviceHub sm.serviceHub = serviceHub
_stateMachines.add(sm) registerStateMachine(sm)
runInThread.execute { runInThread.execute {
// The current state of the continuation is held in the closure attached to the messaging system whenever // The current state of the continuation is held in the closure attached to the messaging system whenever
// the continuation suspends and tells us it expects a response. // the continuation suspends and tells us it expects a response.
@ -110,6 +135,13 @@ class StateMachineManager(val serviceHub: ServiceHub, val runInThread: Executor)
return sm as T return sm as T
} }
private fun registerStateMachine(psm: ProtocolStateMachine<*, *>) {
_stateMachines.add(psm)
psm.resultFuture.whenComplete(runInThread) {
_stateMachines.remove(psm)
}
}
@Suppress("UNCHECKED_CAST") @Suppress("UNCHECKED_CAST")
private fun loadContinuationClass(continuationClass: Class<out ProtocolStateMachine<*, *>>): Pair<ProtocolStateMachine<*, *>, Continuation> { private fun loadContinuationClass(continuationClass: Class<out ProtocolStateMachine<*, *>>): Pair<ProtocolStateMachine<*, *>, Continuation> {
val url = continuationClass.protectionDomain.codeSource.location val url = continuationClass.protectionDomain.codeSource.location

View File

@ -188,6 +188,8 @@ class TwoPartyTradeProtocolTests : TestWithInMemoryNetwork() {
val tx = bobFuture.get() val tx = bobFuture.get()
txns.add(tx.second) txns.add(tx.second)
verify() verify()
assertTrue(smm.stateMachines.isEmpty())
} }
} }
} }