From d054238f191604bda6270156157fb12947938286 Mon Sep 17 00:00:00 2001
From: Stefano Franz <roastario@gmail.com>
Date: Fri, 15 Jun 2018 16:01:56 +0100
Subject: [PATCH] ENT-2013 fix issue when a node is restarted with the same
 entity, but different keypair (#3319)

* fix issue when a node is restarted with the same entity, but different keypair

* address review comment

* remove incorrect test
---
 .../network/PersistentNetworkMapCacheTest.kt  |  8 +--
 .../net/corda/node/internal/AbstractNode.kt   | 31 ++++++++---
 .../network/PersistentNetworkMapCache.kt      |  2 +-
 .../net/corda/node/internal/NodeTest.kt       | 53 +++++++++++++++++--
 4 files changed, 75 insertions(+), 19 deletions(-)

diff --git a/node/src/integration-test/kotlin/net/corda/node/services/network/PersistentNetworkMapCacheTest.kt b/node/src/integration-test/kotlin/net/corda/node/services/network/PersistentNetworkMapCacheTest.kt
index 45db0c0397..3dfc00aec9 100644
--- a/node/src/integration-test/kotlin/net/corda/node/services/network/PersistentNetworkMapCacheTest.kt
+++ b/node/src/integration-test/kotlin/net/corda/node/services/network/PersistentNetworkMapCacheTest.kt
@@ -7,13 +7,7 @@ import net.corda.core.node.NodeInfo
 import net.corda.core.utilities.NetworkHostAndPort
 import net.corda.node.internal.Node
 import net.corda.node.internal.StartedNode
-import net.corda.testing.core.ALICE_NAME
-import net.corda.testing.core.BOB_NAME
-import net.corda.testing.core.CHARLIE_NAME
-import net.corda.testing.core.DUMMY_NOTARY_NAME
-import net.corda.testing.core.TestIdentity
-import net.corda.testing.core.getTestPartyAndCertificate
-import net.corda.testing.core.singleIdentity
+import net.corda.testing.core.*
 import net.corda.testing.node.internal.NodeBasedTest
 import org.assertj.core.api.Assertions.assertThat
 import org.assertj.core.api.Assertions.assertThatExceptionOfType
diff --git a/node/src/main/kotlin/net/corda/node/internal/AbstractNode.kt b/node/src/main/kotlin/net/corda/node/internal/AbstractNode.kt
index c8f321f13f..0a824dee58 100644
--- a/node/src/main/kotlin/net/corda/node/internal/AbstractNode.kt
+++ b/node/src/main/kotlin/net/corda/node/internal/AbstractNode.kt
@@ -169,7 +169,8 @@ abstract class AbstractNode(val configuration: NodeConfiguration,
 
     /** Set to non-null once [start] has been successfully called. */
     open val started get() = _started
-    @Volatile private var _started: StartedNode<AbstractNode>? = null
+    @Volatile
+    private var _started: StartedNode<AbstractNode>? = null
 
     /** The implementation of the [CordaRPCOps] interface used by this node. */
     open fun makeRPCOps(flowStarter: FlowStarter, smm: StateMachineManager): CordaRPCOps {
@@ -390,7 +391,8 @@ abstract class AbstractNode(val configuration: NodeConfiguration,
                 serial = 0
         )
 
-        val nodeInfoFromDb = networkMapCache.getNodeByLegalName(identity.name)
+        val nodeInfoFromDb = getPreviousNodeInfoIfPresent(networkMapCache, identity)
+
 
         val nodeInfo = if (potentialNodeInfo == nodeInfoFromDb?.copy(serial = 0)) {
             // The node info hasn't changed. We use the one from the database to preserve the serial.
@@ -420,6 +422,19 @@ abstract class AbstractNode(val configuration: NodeConfiguration,
         return Pair(keyPairs, nodeInfo)
     }
 
+    private fun getPreviousNodeInfoIfPresent(networkMapCache: NetworkMapCacheBaseInternal, identity: PartyAndCertificate): NodeInfo? {
+        val nodeInfosFromDb = networkMapCache.getNodesByLegalName(identity.name)
+
+        return when (nodeInfosFromDb.size) {
+            0 -> null
+            1 -> nodeInfosFromDb[0]
+            else -> {
+                log.warn("Found more than one node registration with our legal name, this is only expected if our keypair has been regenerated")
+                nodeInfosFromDb[0]
+            }
+        }
+    }
+
     // Publish node info on startup and start task that sends every day a heartbeat - republishes node info.
     private fun tryPublishNodeInfoAsync(signedNodeInfo: SignedNodeInfo, networkMapClient: NetworkMapClient) {
         // By default heartbeat interval should be set to 1 day, but for testing we may change it.
@@ -784,7 +799,8 @@ abstract class AbstractNode(val configuration: NodeConfiguration,
     }
 
     private fun makeCoreNotaryService(notaryConfig: NotaryConfig, database: CordaPersistence): NotaryService {
-        val notaryKey = myNotaryIdentity?.owningKey ?: throw IllegalArgumentException("No notary identity initialized when creating a notary service")
+        val notaryKey = myNotaryIdentity?.owningKey
+                ?: throw IllegalArgumentException("No notary identity initialized when creating a notary service")
         return notaryConfig.run {
             when {
                 raft != null -> {
@@ -880,7 +896,7 @@ abstract class AbstractNode(val configuration: NodeConfiguration,
             throw ConfigurationException("The name '$singleName' for $id doesn't match what's in the key store: $subject")
         } else if (notaryConfig != null && notaryConfig.isClusterConfig && notaryConfig.serviceLegalName != null && subject != notaryConfig.serviceLegalName) {
             // Note that we're not checking if `notaryConfig.serviceLegalName` is not present for backwards compatibility.
-            throw ConfigurationException("The name of the notary service '${notaryConfig.serviceLegalName}' for $id doesn't match what's in the key store: $subject. "+
+            throw ConfigurationException("The name of the notary service '${notaryConfig.serviceLegalName}' for $id doesn't match what's in the key store: $subject. " +
                     "You might need to adjust the configuration of `notary.serviceLegalName`.")
         }
 
@@ -902,8 +918,8 @@ abstract class AbstractNode(val configuration: NodeConfiguration,
             log.info("Starting Jolokia agent on HTTP port: $port")
             val libDir = Paths.get(configuration.baseDirectory.toString(), "drivers")
             val jarFilePath = JVMAgentRegistry.resolveAgentJar(
-                    "jolokia-jvm-${NodeBuildProperties.JOLOKIA_AGENT_VERSION}-agent.jar", libDir) ?:
-                    throw Error("Unable to locate agent jar file")
+                    "jolokia-jvm-${NodeBuildProperties.JOLOKIA_AGENT_VERSION}-agent.jar", libDir)
+                    ?: throw Error("Unable to locate agent jar file")
             log.info("Agent jar file: $jarFilePath")
             JVMAgentRegistry.attach("jolokia", "port=$port", jarFilePath)
         }
@@ -939,7 +955,8 @@ abstract class AbstractNode(val configuration: NodeConfiguration,
         override val networkMapUpdater: NetworkMapUpdater get() = this@AbstractNode.networkMapUpdater
         override fun <T : SerializeAsToken> cordaService(type: Class<T>): T {
             require(type.isAnnotationPresent(CordaService::class.java)) { "${type.name} is not a Corda service" }
-            return cordappServices.getInstance(type) ?: throw IllegalArgumentException("Corda service ${type.name} does not exist")
+            return cordappServices.getInstance(type)
+                    ?: throw IllegalArgumentException("Corda service ${type.name} does not exist")
         }
 
         override fun getFlowFactory(initiatingFlowClass: Class<out FlowLogic<*>>): InitiatedFlowFactory<*>? {
diff --git a/node/src/main/kotlin/net/corda/node/services/network/PersistentNetworkMapCache.kt b/node/src/main/kotlin/net/corda/node/services/network/PersistentNetworkMapCache.kt
index 0a79109ce6..bd4f0f8b4d 100644
--- a/node/src/main/kotlin/net/corda/node/services/network/PersistentNetworkMapCache.kt
+++ b/node/src/main/kotlin/net/corda/node/services/network/PersistentNetworkMapCache.kt
@@ -166,7 +166,7 @@ open class PersistentNetworkMapCache(
         }
     }
 
-    override fun getNodesByLegalName(name: CordaX500Name): List<NodeInfo> = database.transaction { queryByLegalName(session, name) }
+    override fun getNodesByLegalName(name: CordaX500Name): List<NodeInfo> = database.transaction { queryByLegalName(session, name) }.sortedByDescending { it.serial }
 
     override fun getNodesByLegalIdentityKey(identityKey: PublicKey): List<NodeInfo> = nodesByKeyCache[identityKey]!!
 
diff --git a/node/src/test/kotlin/net/corda/node/internal/NodeTest.kt b/node/src/test/kotlin/net/corda/node/internal/NodeTest.kt
index ea95fcf464..76e2d0b7e4 100644
--- a/node/src/test/kotlin/net/corda/node/internal/NodeTest.kt
+++ b/node/src/test/kotlin/net/corda/node/internal/NodeTest.kt
@@ -58,7 +58,7 @@ class NodeTest {
 
     @Test
     fun `generateAndSaveNodeInfo works`() {
-        val configuration = createConfig()
+        val configuration = createConfig(ALICE_NAME)
         val platformVersion = 789
         configureDatabase(configuration.dataSourceProperties, configuration.database, { null }, { null }).use { database ->
             val node = Node(configuration, rigorousMock<VersionInfo>().also {
@@ -70,7 +70,7 @@ class NodeTest {
 
     @Test
     fun `clear network map cache works`() {
-        val configuration = createConfig()
+        val configuration = createConfig(ALICE_NAME)
         val (nodeInfo, _) = createNodeInfoAndSigned(ALICE_NAME)
         configureDatabase(configuration.dataSourceProperties, configuration.database, { null }, { null }).use {
             it.transaction {
@@ -96,6 +96,52 @@ class NodeTest {
         }
     }
 
+    @Test
+    fun `Node can start with multiple keypairs for it's identity`() {
+        val configuration = createConfig(ALICE_NAME)
+        val (nodeInfo1, _) = createNodeInfoAndSigned(ALICE_NAME)
+        val (nodeInfo2, _) = createNodeInfoAndSigned(ALICE_NAME)
+
+
+        val persistentNodeInfo2 = NodeInfoSchemaV1.PersistentNodeInfo(
+                id = 0,
+                hash = nodeInfo2.serialize().hash.toString(),
+                addresses = nodeInfo2.addresses.map { NodeInfoSchemaV1.DBHostAndPort.fromHostAndPort(it) },
+                legalIdentitiesAndCerts = nodeInfo2.legalIdentitiesAndCerts.mapIndexed { idx, elem ->
+                    NodeInfoSchemaV1.DBPartyAndCertificate(elem, isMain = idx == 0)
+                },
+                platformVersion = nodeInfo2.platformVersion,
+                serial = nodeInfo2.serial
+        )
+
+        val persistentNodeInfo1 = NodeInfoSchemaV1.PersistentNodeInfo(
+                id = 0,
+                hash = nodeInfo1.serialize().hash.toString(),
+                addresses = nodeInfo1.addresses.map { NodeInfoSchemaV1.DBHostAndPort.fromHostAndPort(it) },
+                legalIdentitiesAndCerts = nodeInfo1.legalIdentitiesAndCerts.mapIndexed { idx, elem ->
+                    NodeInfoSchemaV1.DBPartyAndCertificate(elem, isMain = idx == 0)
+                },
+                platformVersion = nodeInfo1.platformVersion,
+                serial = nodeInfo1.serial
+        )
+
+        configureDatabase(configuration.dataSourceProperties, configuration.database, { null }, { null }).use {
+            it.transaction {
+                session.save(persistentNodeInfo1)
+            }
+            it.transaction {
+                session.save(persistentNodeInfo2)
+            }
+
+            val node = Node(configuration, rigorousMock<VersionInfo>().also {
+                doReturn(10).whenever(it).platformVersion
+            }, initialiseSerialization = false)
+
+            //this throws an exception with old behaviour
+            node.generateNodeInfo()
+        }
+    }
+
     private fun getAllInfos(database: CordaPersistence): List<NodeInfoSchemaV1.PersistentNodeInfo> {
         return database.transaction {
             val criteria = session.criteriaBuilder.createQuery(NodeInfoSchemaV1.PersistentNodeInfo::class.java)
@@ -104,11 +150,10 @@ class NodeTest {
         }
     }
 
-    private fun createConfig(): NodeConfiguration {
+    private fun createConfig(nodeName: CordaX500Name): NodeConfiguration {
         val dataSourceProperties = makeTestDataSourceProperties()
         val databaseConfig = DatabaseConfig()
         val nodeAddress = NetworkHostAndPort("0.1.2.3", 456)
-        val nodeName = CordaX500Name("Manx Blockchain Corp", "Douglas", "IM")
         return rigorousMock<AbstractNodeConfiguration>().also {
             doReturn(nodeAddress).whenever(it).p2pAddress
             doReturn(nodeName).whenever(it).myLegalName