From 3c578550a981d5bd5f738842defe49536f490a9c Mon Sep 17 00:00:00 2001
From: Mike Hearn <mike@r3cev.com>
Date: Mon, 14 Dec 2015 21:41:57 +0100
Subject: [PATCH] State machines: don't leak references to completed state
 machines. Add an extension function for working with futures.

---
 src/main/kotlin/core/Utils.kt                 |  7 ++
 .../kotlin/core/messaging/StateMachines.kt    | 66 ++++++++++++++-----
 .../messaging/TwoPartyTradeProtocolTests.kt   |  2 +
 3 files changed, 58 insertions(+), 17 deletions(-)

diff --git a/src/main/kotlin/core/Utils.kt b/src/main/kotlin/core/Utils.kt
index ed77c5e44b..85dcac6c5a 100644
--- a/src/main/kotlin/core/Utils.kt
+++ b/src/main/kotlin/core/Utils.kt
@@ -9,10 +9,13 @@
 package core
 
 import com.google.common.io.BaseEncoding
+import com.google.common.util.concurrent.ListenableFuture
+import com.google.common.util.concurrent.MoreExecutors
 import com.google.common.util.concurrent.SettableFuture
 import org.slf4j.Logger
 import java.time.Duration
 import java.util.*
+import java.util.concurrent.Executor
 
 /** A simple class that wraps a byte array and makes the equals/hashCode/toString methods work as you actually expect */
 open class OpaqueBytes(val bits: ByteArray) {
@@ -41,6 +44,10 @@ val Int.hours: Duration get() = Duration.ofHours(this.toLong())
 val Int.minutes: Duration get() = Duration.ofMinutes(this.toLong())
 val Int.seconds: Duration get() = Duration.ofSeconds(this.toLong())
 
+/** Runs [body] when this future completes, on [executor] or, if that is null, on Guava's direct executor. */
+fun <T> ListenableFuture<T>.whenComplete(executor: Executor? = null, body: () -> Unit) =
+        addListener(Runnable { body() }, executor ?: MoreExecutors.directExecutor())
+
 /** Executes the given block and sets the future to either the result, or any exception that was thrown. */
 fun <T> SettableFuture<T>.setFrom(logger: Logger? = null, block: () -> T): SettableFuture<T> {
     try {
diff --git a/src/main/kotlin/core/messaging/StateMachines.kt b/src/main/kotlin/core/messaging/StateMachines.kt
index 409be67968..e3e7331354 100644
--- a/src/main/kotlin/core/messaging/StateMachines.kt
+++ b/src/main/kotlin/core/messaging/StateMachines.kt
@@ -18,6 +18,7 @@ import core.serialization.createKryo
 import core.serialization.deserialize
 import core.serialization.serialize
 import core.utilities.trace
+import core.whenComplete
 import org.apache.commons.javaflow.Continuation
 import org.apache.commons.javaflow.ContinuationClassLoader
 import org.objenesis.instantiator.ObjectInstantiator
@@ -35,11 +36,20 @@ import java.util.concurrent.Executor
  * and, if run with a single-threaded executor, will ensure no two state machines run concurrently with each other
  * (bad for performance, good for programmer mental health!).
  *
+ * A "state machine" is a class with a single call method. The call method and any others it invokes are rewritten by
+ * a bytecode rewriting engine called JavaFlow, to ensure the code can be suspended and resumed at any point.
+ *
  * TODO: The framework should propagate exceptions and handle error handling automatically.
+ * TODO: This needs extension to the >2 party case.
+ * TODO: Thread safety
  */
 class StateMachineManager(val serviceHub: ServiceHub, val runInThread: Executor) {
+    // This map is backed by a database and will be used to store serialised state machines to disk, so we can resurrect
+    // them across node restarts.
     private val checkpointsMap = serviceHub.storageService.getMap<SecureHash, ByteArray>("state machines")
-    private val _stateMachines: MutableList<ProtocolStateMachine<*,*>> = ArrayList()
+    // A list of all the state machines being managed by this class. We expose snapshots of it via the stateMachines
+    // property.
+    private val _stateMachines = ArrayList<ProtocolStateMachine<*,*>>()
 
     /** Returns a snapshot of the currently registered state machines. */
     val stateMachines: List<ProtocolStateMachine<*,*>> get() = ArrayList(_stateMachines)
@@ -57,50 +67,65 @@ class StateMachineManager(val serviceHub: ServiceHub, val runInThread: Executor)
         restoreCheckpoints()
     }
 
+    /** Resurrects every serialised state machine in [checkpointsMap] and re-wires it to the network service. */
     private fun restoreCheckpoints() {
         for (bytes in checkpointsMap.values) {
             val kryo = createKryo()
 
             // Set up Kryo to use the JavaFlow classloader when deserialising, so the magical continuation bytecode
             // rewriting is performed correctly.
-            var psm: ProtocolStateMachine<*, *>? = null
+            var _psm: ProtocolStateMachine<*, *>? = null
             kryo.instantiatorStrategy = object : InstantiatorStrategy {
                 val forwardingTo = kryo.instantiatorStrategy
 
                 override fun <T> newInstantiatorOf(type: Class<T>): ObjectInstantiator<T> {
-                    if (ProtocolStateMachine::class.java.isAssignableFrom(type)) {
-                        // The messing around with types we do here confuses the compiler/IDE a bit and it warns us.
-                        @Suppress("UNCHECKED_CAST", "CAST_NEVER_SUCCEEDS")
-                        return ObjectInstantiator<T> {
-                            val p = loadContinuationClass(type as Class<out ProtocolStateMachine<*, *>>).first
-                            p.serviceHub = serviceHub
-                            psm = p
-                            psm as T
-                        }
-                    } else {
+                    // If this is some object that isn't a state machine, use the default behaviour.
+                    if (!ProtocolStateMachine::class.java.isAssignableFrom(type))
                         return forwardingTo.newInstantiatorOf(type)
+
+                    // Otherwise, return an 'object instantiator' (i.e. factory) that uses the JavaFlow classloader.
+                    @Suppress("UNCHECKED_CAST", "CAST_NEVER_SUCCEEDS")
+                    return ObjectInstantiator<T> {
+                        val p = loadContinuationClass(type as Class<out ProtocolStateMachine<*, *>>).first
+                        // Pass the new object a pointer to the service hub where it can find objects that don't
+                        // survive restarts.
+                        p.serviceHub = serviceHub
+                        _psm = p
+                        p as T
                     }
                 }
             }
-            val checkpoint = bytes.deserialize<Checkpoint>(kryo)
 
+            val checkpoint = bytes.deserialize<Checkpoint>(kryo)
             val continuation = checkpoint.continuation
-            _stateMachines.add(psm!!)
+
+            // _psm is set by the instantiator above whenever Kryo creates a ProtocolStateMachine. We always serialise
+            // a ProtocolStateMachine subclass, so it should be non-null by now; the compiler cannot prove that, hence
+            // the !! operator, which would throw if a checkpoint ever contained no state machine.
+            val psm = _psm!!
+            registerStateMachine(psm)
+
             val logger = LoggerFactory.getLogger(checkpoint.loggerName)
             val awaitingObjectOfType = Class.forName(checkpoint.awaitingObjectOfType)
-            // The act of calling this method re-persists the bytes into the in-memory hashmap so re-saving the
-            // StateMachineManager to disk will work even if some state machines didn't wake up in the intervening time.
+
+            // And now re-wire the deserialised continuation back up to the network service.
             setupNextMessageHandler(logger, serviceHub.networkService, continuation, checkpoint.otherSide,
                     awaitingObjectOfType, checkpoint.awaitingTopic, bytes)
         }
     }
 
+    /**
+     * Kicks off a brand new state machine of the given class. It will send messages to the network node identified by
+     * the [otherSide] parameter, log with the named logger, and the [initialArgs] object will be passed to the call
+     * method of the [ProtocolStateMachine] object that is created. The state machine will be persisted when it suspends
+     * and will be removed once it completes.
+     */
     fun <T : ProtocolStateMachine<I, *>, I> add(otherSide: MessageRecipients, initialArgs: I, loggerName: String,
                                                 continuationClass: Class<out T>): T {
         val logger = LoggerFactory.getLogger(loggerName)
         val (sm, continuation) = loadContinuationClass(continuationClass)
         sm.serviceHub = serviceHub
-        _stateMachines.add(sm)
+        registerStateMachine(sm)
         runInThread.execute {
             // The current state of the continuation is held in the closure attached to the messaging system whenever
             // the continuation suspends and tells us it expects a response.
@@ -110,6 +135,13 @@ class StateMachineManager(val serviceHub: ServiceHub, val runInThread: Executor)
         return sm as T
     }
 
+    /** Tracks [psm] until its result future completes, so finished state machines are not retained in the list. */
+    private fun registerStateMachine(psm: ProtocolStateMachine<*, *>) {
+        _stateMachines.add(psm)
+        // The removal is dispatched on runInThread, i.e. the executor this manager was constructed with.
+        psm.resultFuture.whenComplete(runInThread) { _stateMachines.remove(psm) }
+    }
+
     @Suppress("UNCHECKED_CAST")
     private fun loadContinuationClass(continuationClass: Class<out ProtocolStateMachine<*, *>>): Pair<ProtocolStateMachine<*, *>, Continuation> {
         val url = continuationClass.protectionDomain.codeSource.location
diff --git a/src/test/kotlin/core/messaging/TwoPartyTradeProtocolTests.kt b/src/test/kotlin/core/messaging/TwoPartyTradeProtocolTests.kt
index 7da81e145e..e489a50270 100644
--- a/src/test/kotlin/core/messaging/TwoPartyTradeProtocolTests.kt
+++ b/src/test/kotlin/core/messaging/TwoPartyTradeProtocolTests.kt
@@ -188,6 +188,8 @@ class TwoPartyTradeProtocolTests : TestWithInMemoryNetwork() {
             val tx = bobFuture.get()
             txns.add(tx.second)
             verify()
+
+            assertTrue(smm.stateMachines.isEmpty())
         }
     }
 }