mirror of
https://github.com/corda/corda.git
synced 2024-12-21 05:53:23 +00:00
State machines: don't leak references to completed state machines. Add an extension function for working with futures.
This commit is contained in:
parent
62f7237364
commit
3c578550a9
@ -9,10 +9,13 @@
|
|||||||
package core
|
package core
|
||||||
|
|
||||||
import com.google.common.io.BaseEncoding
|
import com.google.common.io.BaseEncoding
|
||||||
|
import com.google.common.util.concurrent.ListenableFuture
|
||||||
|
import com.google.common.util.concurrent.MoreExecutors
|
||||||
import com.google.common.util.concurrent.SettableFuture
|
import com.google.common.util.concurrent.SettableFuture
|
||||||
import org.slf4j.Logger
|
import org.slf4j.Logger
|
||||||
import java.time.Duration
|
import java.time.Duration
|
||||||
import java.util.*
|
import java.util.*
|
||||||
|
import java.util.concurrent.Executor
|
||||||
|
|
||||||
/** A simple class that wraps a byte array and makes the equals/hashCode/toString methods work as you actually expect */
|
/** A simple class that wraps a byte array and makes the equals/hashCode/toString methods work as you actually expect */
|
||||||
open class OpaqueBytes(val bits: ByteArray) {
|
open class OpaqueBytes(val bits: ByteArray) {
|
||||||
@ -41,6 +44,10 @@ val Int.hours: Duration get() = Duration.ofHours(this.toLong())
|
|||||||
val Int.minutes: Duration get() = Duration.ofMinutes(this.toLong())
|
val Int.minutes: Duration get() = Duration.ofMinutes(this.toLong())
|
||||||
val Int.seconds: Duration get() = Duration.ofSeconds(this.toLong())
|
val Int.seconds: Duration get() = Duration.ofSeconds(this.toLong())
|
||||||
|
|
||||||
|
fun <T> ListenableFuture<T>.whenComplete(executor: Executor? = null, body: () -> Unit) {
|
||||||
|
addListener(Runnable { body() }, executor ?: MoreExecutors.directExecutor())
|
||||||
|
}
|
||||||
|
|
||||||
/** Executes the given block and sets the future to either the result, or any exception that was thrown. */
|
/** Executes the given block and sets the future to either the result, or any exception that was thrown. */
|
||||||
fun <T> SettableFuture<T>.setFrom(logger: Logger? = null, block: () -> T): SettableFuture<T> {
|
fun <T> SettableFuture<T>.setFrom(logger: Logger? = null, block: () -> T): SettableFuture<T> {
|
||||||
try {
|
try {
|
||||||
|
@ -18,6 +18,7 @@ import core.serialization.createKryo
|
|||||||
import core.serialization.deserialize
|
import core.serialization.deserialize
|
||||||
import core.serialization.serialize
|
import core.serialization.serialize
|
||||||
import core.utilities.trace
|
import core.utilities.trace
|
||||||
|
import core.whenComplete
|
||||||
import org.apache.commons.javaflow.Continuation
|
import org.apache.commons.javaflow.Continuation
|
||||||
import org.apache.commons.javaflow.ContinuationClassLoader
|
import org.apache.commons.javaflow.ContinuationClassLoader
|
||||||
import org.objenesis.instantiator.ObjectInstantiator
|
import org.objenesis.instantiator.ObjectInstantiator
|
||||||
@ -35,11 +36,20 @@ import java.util.concurrent.Executor
|
|||||||
* and, if run with a single-threaded executor, will ensure no two state machines run concurrently with each other
|
* and, if run with a single-threaded executor, will ensure no two state machines run concurrently with each other
|
||||||
* (bad for performance, good for programmer mental health!).
|
* (bad for performance, good for programmer mental health!).
|
||||||
*
|
*
|
||||||
|
* A "state machine" is a class with a single call method. The call method and any others it invokes are rewritten by
|
||||||
|
* a bytecode rewriting engine called JavaFlow, to ensure the code can be suspended and resumed at any point.
|
||||||
|
*
|
||||||
* TODO: The framework should propagate exceptions and handle error handling automatically.
|
* TODO: The framework should propagate exceptions and handle error handling automatically.
|
||||||
|
* TODO: This needs extension to the >2 party case.
|
||||||
|
* TODO: Thread safety
|
||||||
*/
|
*/
|
||||||
class StateMachineManager(val serviceHub: ServiceHub, val runInThread: Executor) {
|
class StateMachineManager(val serviceHub: ServiceHub, val runInThread: Executor) {
|
||||||
|
// This map is backed by a database and will be used to store serialised state machines to disk, so we can resurrect
|
||||||
|
// them across node restarts.
|
||||||
private val checkpointsMap = serviceHub.storageService.getMap<SecureHash, ByteArray>("state machines")
|
private val checkpointsMap = serviceHub.storageService.getMap<SecureHash, ByteArray>("state machines")
|
||||||
private val _stateMachines: MutableList<ProtocolStateMachine<*,*>> = ArrayList()
|
// A list of all the state machines being managed by this class. We expose snapshots of it via the stateMachines
|
||||||
|
// property.
|
||||||
|
private val _stateMachines = ArrayList<ProtocolStateMachine<*,*>>()
|
||||||
|
|
||||||
/** Returns a snapshot of the currently registered state machines. */
|
/** Returns a snapshot of the currently registered state machines. */
|
||||||
val stateMachines: List<ProtocolStateMachine<*,*>> get() = ArrayList(_stateMachines)
|
val stateMachines: List<ProtocolStateMachine<*,*>> get() = ArrayList(_stateMachines)
|
||||||
@ -57,50 +67,65 @@ class StateMachineManager(val serviceHub: ServiceHub, val runInThread: Executor)
|
|||||||
restoreCheckpoints()
|
restoreCheckpoints()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/** Reads the database map and resurrects any serialised state machines. */
|
||||||
private fun restoreCheckpoints() {
|
private fun restoreCheckpoints() {
|
||||||
for (bytes in checkpointsMap.values) {
|
for (bytes in checkpointsMap.values) {
|
||||||
val kryo = createKryo()
|
val kryo = createKryo()
|
||||||
|
|
||||||
// Set up Kryo to use the JavaFlow classloader when deserialising, so the magical continuation bytecode
|
// Set up Kryo to use the JavaFlow classloader when deserialising, so the magical continuation bytecode
|
||||||
// rewriting is performed correctly.
|
// rewriting is performed correctly.
|
||||||
var psm: ProtocolStateMachine<*, *>? = null
|
var _psm: ProtocolStateMachine<*, *>? = null
|
||||||
kryo.instantiatorStrategy = object : InstantiatorStrategy {
|
kryo.instantiatorStrategy = object : InstantiatorStrategy {
|
||||||
val forwardingTo = kryo.instantiatorStrategy
|
val forwardingTo = kryo.instantiatorStrategy
|
||||||
|
|
||||||
override fun <T> newInstantiatorOf(type: Class<T>): ObjectInstantiator<T> {
|
override fun <T> newInstantiatorOf(type: Class<T>): ObjectInstantiator<T> {
|
||||||
if (ProtocolStateMachine::class.java.isAssignableFrom(type)) {
|
// If this is some object that isn't a state machine, use the default behaviour.
|
||||||
// The messing around with types we do here confuses the compiler/IDE a bit and it warns us.
|
if (!ProtocolStateMachine::class.java.isAssignableFrom(type))
|
||||||
|
return forwardingTo.newInstantiatorOf(type)
|
||||||
|
|
||||||
|
// Otherwise, return an 'object instantiator' (i.e. factory) that uses the JavaFlow classloader.
|
||||||
@Suppress("UNCHECKED_CAST", "CAST_NEVER_SUCCEEDS")
|
@Suppress("UNCHECKED_CAST", "CAST_NEVER_SUCCEEDS")
|
||||||
return ObjectInstantiator<T> {
|
return ObjectInstantiator<T> {
|
||||||
val p = loadContinuationClass(type as Class<out ProtocolStateMachine<*, *>>).first
|
val p = loadContinuationClass(type as Class<out ProtocolStateMachine<*, *>>).first
|
||||||
|
// Pass the new object a pointer to the service hub where it can find objects that don't
|
||||||
|
// survive restarts.
|
||||||
p.serviceHub = serviceHub
|
p.serviceHub = serviceHub
|
||||||
psm = p
|
_psm = p
|
||||||
psm as T
|
p as T
|
||||||
}
|
|
||||||
} else {
|
|
||||||
return forwardingTo.newInstantiatorOf(type)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
val checkpoint = bytes.deserialize<Checkpoint>(kryo)
|
|
||||||
|
|
||||||
|
val checkpoint = bytes.deserialize<Checkpoint>(kryo)
|
||||||
val continuation = checkpoint.continuation
|
val continuation = checkpoint.continuation
|
||||||
_stateMachines.add(psm!!)
|
|
||||||
|
// We know _psm can't be null here, because we always serialise a ProtocolStateMachine subclass, so the
|
||||||
|
// code above that does "_psm = p" will always run. But the Kotlin compiler can't know that so we have to
|
||||||
|
// forcibly cast away the nullness with the !! operator.
|
||||||
|
val psm = _psm!!
|
||||||
|
registerStateMachine(psm)
|
||||||
|
|
||||||
val logger = LoggerFactory.getLogger(checkpoint.loggerName)
|
val logger = LoggerFactory.getLogger(checkpoint.loggerName)
|
||||||
val awaitingObjectOfType = Class.forName(checkpoint.awaitingObjectOfType)
|
val awaitingObjectOfType = Class.forName(checkpoint.awaitingObjectOfType)
|
||||||
// The act of calling this method re-persists the bytes into the in-memory hashmap so re-saving the
|
|
||||||
// StateMachineManager to disk will work even if some state machines didn't wake up in the intervening time.
|
// And now re-wire the deserialised continuation back up to the network service.
|
||||||
setupNextMessageHandler(logger, serviceHub.networkService, continuation, checkpoint.otherSide,
|
setupNextMessageHandler(logger, serviceHub.networkService, continuation, checkpoint.otherSide,
|
||||||
awaitingObjectOfType, checkpoint.awaitingTopic, bytes)
|
awaitingObjectOfType, checkpoint.awaitingTopic, bytes)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Kicks off a brand new state machine of the given class. It will send messages to the network node identified by
|
||||||
|
* the [otherSide] parameter, log with the named logger, and the [initialArgs] object will be passed to the call
|
||||||
|
* method of the [ProtocolStateMachine] object that is created. The state machine will be persisted when it suspends
|
||||||
|
* and will be removed once it completes.
|
||||||
|
*/
|
||||||
fun <T : ProtocolStateMachine<I, *>, I> add(otherSide: MessageRecipients, initialArgs: I, loggerName: String,
|
fun <T : ProtocolStateMachine<I, *>, I> add(otherSide: MessageRecipients, initialArgs: I, loggerName: String,
|
||||||
continuationClass: Class<out T>): T {
|
continuationClass: Class<out T>): T {
|
||||||
val logger = LoggerFactory.getLogger(loggerName)
|
val logger = LoggerFactory.getLogger(loggerName)
|
||||||
val (sm, continuation) = loadContinuationClass(continuationClass)
|
val (sm, continuation) = loadContinuationClass(continuationClass)
|
||||||
sm.serviceHub = serviceHub
|
sm.serviceHub = serviceHub
|
||||||
_stateMachines.add(sm)
|
registerStateMachine(sm)
|
||||||
runInThread.execute {
|
runInThread.execute {
|
||||||
// The current state of the continuation is held in the closure attached to the messaging system whenever
|
// The current state of the continuation is held in the closure attached to the messaging system whenever
|
||||||
// the continuation suspends and tells us it expects a response.
|
// the continuation suspends and tells us it expects a response.
|
||||||
@ -110,6 +135,13 @@ class StateMachineManager(val serviceHub: ServiceHub, val runInThread: Executor)
|
|||||||
return sm as T
|
return sm as T
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private fun registerStateMachine(psm: ProtocolStateMachine<*, *>) {
|
||||||
|
_stateMachines.add(psm)
|
||||||
|
psm.resultFuture.whenComplete(runInThread) {
|
||||||
|
_stateMachines.remove(psm)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
@Suppress("UNCHECKED_CAST")
|
@Suppress("UNCHECKED_CAST")
|
||||||
private fun loadContinuationClass(continuationClass: Class<out ProtocolStateMachine<*, *>>): Pair<ProtocolStateMachine<*, *>, Continuation> {
|
private fun loadContinuationClass(continuationClass: Class<out ProtocolStateMachine<*, *>>): Pair<ProtocolStateMachine<*, *>, Continuation> {
|
||||||
val url = continuationClass.protectionDomain.codeSource.location
|
val url = continuationClass.protectionDomain.codeSource.location
|
||||||
|
@ -188,6 +188,8 @@ class TwoPartyTradeProtocolTests : TestWithInMemoryNetwork() {
|
|||||||
val tx = bobFuture.get()
|
val tx = bobFuture.get()
|
||||||
txns.add(tx.second)
|
txns.add(tx.second)
|
||||||
verify()
|
verify()
|
||||||
|
|
||||||
|
assertTrue(smm.stateMachines.isEmpty())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user