mirror of
https://github.com/corda/corda.git
synced 2024-12-21 05:53:23 +00:00
State machines: don't leak references to completed state machines. Add an extension function for working with futures.
This commit is contained in:
parent
62f7237364
commit
3c578550a9
@ -9,10 +9,13 @@
|
||||
package core
|
||||
|
||||
import com.google.common.io.BaseEncoding
|
||||
import com.google.common.util.concurrent.ListenableFuture
|
||||
import com.google.common.util.concurrent.MoreExecutors
|
||||
import com.google.common.util.concurrent.SettableFuture
|
||||
import org.slf4j.Logger
|
||||
import java.time.Duration
|
||||
import java.util.*
|
||||
import java.util.concurrent.Executor
|
||||
|
||||
/** A simple class that wraps a byte array and makes the equals/hashCode/toString methods work as you actually expect */
|
||||
open class OpaqueBytes(val bits: ByteArray) {
|
||||
@ -41,6 +44,10 @@ val Int.hours: Duration get() = Duration.ofHours(this.toLong())
|
||||
val Int.minutes: Duration get() = Duration.ofMinutes(this.toLong())
|
||||
val Int.seconds: Duration get() = Duration.ofSeconds(this.toLong())
|
||||
|
||||
fun <T> ListenableFuture<T>.whenComplete(executor: Executor? = null, body: () -> Unit) {
|
||||
addListener(Runnable { body() }, executor ?: MoreExecutors.directExecutor())
|
||||
}
|
||||
|
||||
/** Executes the given block and sets the future to either the result, or any exception that was thrown. */
|
||||
fun <T> SettableFuture<T>.setFrom(logger: Logger? = null, block: () -> T): SettableFuture<T> {
|
||||
try {
|
||||
|
@ -18,6 +18,7 @@ import core.serialization.createKryo
|
||||
import core.serialization.deserialize
|
||||
import core.serialization.serialize
|
||||
import core.utilities.trace
|
||||
import core.whenComplete
|
||||
import org.apache.commons.javaflow.Continuation
|
||||
import org.apache.commons.javaflow.ContinuationClassLoader
|
||||
import org.objenesis.instantiator.ObjectInstantiator
|
||||
@ -35,11 +36,20 @@ import java.util.concurrent.Executor
|
||||
* and, if run with a single-threaded executor, will ensure no two state machines run concurrently with each other
|
||||
* (bad for performance, good for programmer mental health!).
|
||||
*
|
||||
* A "state machine" is a class with a single call method. The call method and any others it invokes are rewritten by
|
||||
* a bytecode rewriting engine called JavaFlow, to ensure the code can be suspended and resumed at any point.
|
||||
*
|
||||
* TODO: The framework should propagate exceptions and handle error handling automatically.
|
||||
* TODO: This needs extension to the >2 party case.
|
||||
* TODO: Thread safety
|
||||
*/
|
||||
class StateMachineManager(val serviceHub: ServiceHub, val runInThread: Executor) {
|
||||
// This map is backed by a database and will be used to store serialised state machines to disk, so we can resurrect
|
||||
// them across node restarts.
|
||||
private val checkpointsMap = serviceHub.storageService.getMap<SecureHash, ByteArray>("state machines")
|
||||
private val _stateMachines: MutableList<ProtocolStateMachine<*,*>> = ArrayList()
|
||||
// A list of all the state machines being managed by this class. We expose snapshots of it via the stateMachines
|
||||
// property.
|
||||
private val _stateMachines = ArrayList<ProtocolStateMachine<*,*>>()
|
||||
|
||||
/** Returns a snapshot of the currently registered state machines. */
|
||||
val stateMachines: List<ProtocolStateMachine<*,*>> get() = ArrayList(_stateMachines)
|
||||
@ -57,50 +67,65 @@ class StateMachineManager(val serviceHub: ServiceHub, val runInThread: Executor)
|
||||
restoreCheckpoints()
|
||||
}
|
||||
|
||||
/** Reads the database map and resurrects any serialised state machines. */
|
||||
private fun restoreCheckpoints() {
|
||||
for (bytes in checkpointsMap.values) {
|
||||
val kryo = createKryo()
|
||||
|
||||
// Set up Kryo to use the JavaFlow classloader when deserialising, so the magical continuation bytecode
|
||||
// rewriting is performed correctly.
|
||||
var psm: ProtocolStateMachine<*, *>? = null
|
||||
var _psm: ProtocolStateMachine<*, *>? = null
|
||||
kryo.instantiatorStrategy = object : InstantiatorStrategy {
|
||||
val forwardingTo = kryo.instantiatorStrategy
|
||||
|
||||
override fun <T> newInstantiatorOf(type: Class<T>): ObjectInstantiator<T> {
|
||||
if (ProtocolStateMachine::class.java.isAssignableFrom(type)) {
|
||||
// The messing around with types we do here confuses the compiler/IDE a bit and it warns us.
|
||||
@Suppress("UNCHECKED_CAST", "CAST_NEVER_SUCCEEDS")
|
||||
return ObjectInstantiator<T> {
|
||||
val p = loadContinuationClass(type as Class<out ProtocolStateMachine<*, *>>).first
|
||||
p.serviceHub = serviceHub
|
||||
psm = p
|
||||
psm as T
|
||||
}
|
||||
} else {
|
||||
// If this is some object that isn't a state machine, use the default behaviour.
|
||||
if (!ProtocolStateMachine::class.java.isAssignableFrom(type))
|
||||
return forwardingTo.newInstantiatorOf(type)
|
||||
|
||||
// Otherwise, return an 'object instantiator' (i.e. factory) that uses the JavaFlow classloader.
|
||||
@Suppress("UNCHECKED_CAST", "CAST_NEVER_SUCCEEDS")
|
||||
return ObjectInstantiator<T> {
|
||||
val p = loadContinuationClass(type as Class<out ProtocolStateMachine<*, *>>).first
|
||||
// Pass the new object a pointer to the service hub where it can find objects that don't
|
||||
// survive restarts.
|
||||
p.serviceHub = serviceHub
|
||||
_psm = p
|
||||
p as T
|
||||
}
|
||||
}
|
||||
}
|
||||
val checkpoint = bytes.deserialize<Checkpoint>(kryo)
|
||||
|
||||
val checkpoint = bytes.deserialize<Checkpoint>(kryo)
|
||||
val continuation = checkpoint.continuation
|
||||
_stateMachines.add(psm!!)
|
||||
|
||||
// We know _psm can't be null here, because we always serialise a ProtocolStateMachine subclass, so the
|
||||
// code above that does "_psm = p" will always run. But the Kotlin compiler can't know that so we have to
|
||||
// forcibly cast away the nullness with the !! operator.
|
||||
val psm = _psm!!
|
||||
registerStateMachine(psm)
|
||||
|
||||
val logger = LoggerFactory.getLogger(checkpoint.loggerName)
|
||||
val awaitingObjectOfType = Class.forName(checkpoint.awaitingObjectOfType)
|
||||
// The act of calling this method re-persists the bytes into the in-memory hashmap so re-saving the
|
||||
// StateMachineManager to disk will work even if some state machines didn't wake up in the intervening time.
|
||||
|
||||
// And now re-wire the deserialised continuation back up to the network service.
|
||||
setupNextMessageHandler(logger, serviceHub.networkService, continuation, checkpoint.otherSide,
|
||||
awaitingObjectOfType, checkpoint.awaitingTopic, bytes)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Kicks off a brand new state machine of the given class. It will send messages to the network node identified by
|
||||
* the [otherSide] parameter, log with the named logger, and the [initialArgs] object will be passed to the call
|
||||
* method of the [ProtocolStateMachine] object that is created. The state machine will be persisted when it suspends
|
||||
* and will be removed once it completes.
|
||||
*/
|
||||
fun <T : ProtocolStateMachine<I, *>, I> add(otherSide: MessageRecipients, initialArgs: I, loggerName: String,
|
||||
continuationClass: Class<out T>): T {
|
||||
val logger = LoggerFactory.getLogger(loggerName)
|
||||
val (sm, continuation) = loadContinuationClass(continuationClass)
|
||||
sm.serviceHub = serviceHub
|
||||
_stateMachines.add(sm)
|
||||
registerStateMachine(sm)
|
||||
runInThread.execute {
|
||||
// The current state of the continuation is held in the closure attached to the messaging system whenever
|
||||
// the continuation suspends and tells us it expects a response.
|
||||
@ -110,6 +135,13 @@ class StateMachineManager(val serviceHub: ServiceHub, val runInThread: Executor)
|
||||
return sm as T
|
||||
}
|
||||
|
||||
private fun registerStateMachine(psm: ProtocolStateMachine<*, *>) {
|
||||
_stateMachines.add(psm)
|
||||
psm.resultFuture.whenComplete(runInThread) {
|
||||
_stateMachines.remove(psm)
|
||||
}
|
||||
}
|
||||
|
||||
@Suppress("UNCHECKED_CAST")
|
||||
private fun loadContinuationClass(continuationClass: Class<out ProtocolStateMachine<*, *>>): Pair<ProtocolStateMachine<*, *>, Continuation> {
|
||||
val url = continuationClass.protectionDomain.codeSource.location
|
||||
|
@ -188,6 +188,8 @@ class TwoPartyTradeProtocolTests : TestWithInMemoryNetwork() {
|
||||
val tx = bobFuture.get()
|
||||
txns.add(tx.second)
|
||||
verify()
|
||||
|
||||
assertTrue(smm.stateMachines.isEmpty())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user