From 93bb24ed17d028d8ac515d3b9d2f10f7ec686c3e Mon Sep 17 00:00:00 2001 From: Shams Asari Date: Mon, 30 Jul 2018 10:35:03 +0100 Subject: [PATCH] Internal mock node clean up (#3715) * InMemoryMessagingNetwork.InMemoryMessaging renamed to MockNodeMessagingService and moved to internal package * start method added to MockNodeMessagingService which enables AbstractNode to call makeMessagingService in its c'tor * Removed TopicStringValidator as it's no longer used * Clean up of TestStartedNode * Merged InMemoryMessagingTests into InternalMockNetworkTests as it's testing InternalMockNetwork --- .idea/compiler.xml | 4 + .../net/corda/core/flows/FlowTestsUtils.kt | 4 +- .../AttachmentSerializationTest.kt | 2 +- .../net/corda/node/internal/AbstractNode.kt | 21 +- .../kotlin/net/corda/node/internal/Node.kt | 13 +- .../node/services/messaging/Messaging.kt | 7 - .../SingleThreadedStateMachineManager.kt | 2 +- .../node/messaging/InMemoryMessagingTests.kt | 115 ------ .../node/messaging/TwoPartyTradeFlowTests.kt | 3 +- .../statemachine/FlowFrameworkTests.kt | 2 +- .../statemachine/RetryFlowMockTest.kt | 13 +- .../ValidatingNotaryServiceTests.kt | 24 +- .../vault/VaultSoftLockManagerTest.kt | 2 +- .../testing/node/InMemoryMessagingNetwork.kt | 346 +++--------------- .../net/corda/testing/node/MockNetwork.kt | 2 +- .../testing/node/internal/InMemoryMessage.kt | 4 +- .../node/internal/InternalMockNetwork.kt | 116 +++--- .../node/internal/InternalTestUtils.kt | 5 - .../node/internal/MockNodeMessagingService.kt | 283 ++++++++++++++ .../node/internal/InternalMockNetworkTests.kt | 96 ++++- 20 files changed, 529 insertions(+), 535 deletions(-) delete mode 100644 node/src/test/kotlin/net/corda/node/messaging/InMemoryMessagingTests.kt create mode 100644 testing/node-driver/src/main/kotlin/net/corda/testing/node/internal/MockNodeMessagingService.kt diff --git a/.idea/compiler.xml b/.idea/compiler.xml index fbd48220f1..bc63206a42 100644 --- a/.idea/compiler.xml +++ b/.idea/compiler.xml @@ -128,6 +128,10 @@ + + + + diff --git a/core/src/test/kotlin/net/corda/core/flows/FlowTestsUtils.kt b/core/src/test/kotlin/net/corda/core/flows/FlowTestsUtils.kt index e058fbcce6..fbb2d3bad2 100644 --- a/core/src/test/kotlin/net/corda/core/flows/FlowTestsUtils.kt +++ b/core/src/test/kotlin/net/corda/core/flows/FlowTestsUtils.kt @@ -38,14 +38,14 @@ class NoAnswer(private val closure: () -> Unit = {}) : FlowLogic() { * Allows to register a flow of type [R] against an initiating flow of type [I]. */ inline fun , reified R : FlowLogic<*>> TestStartedNode.registerInitiatedFlow(initiatingFlowType: KClass, crossinline construct: (session: FlowSession) -> R) { - internalRegisterFlowFactory(initiatingFlowType.java, InitiatedFlowFactory.Core { session -> construct(session) }, R::class.javaObjectType, true) + registerFlowFactory(initiatingFlowType.java, InitiatedFlowFactory.Core { session -> construct(session) }, R::class.javaObjectType, true) } /** * Allows to register a flow of type [Answer] against an initiating flow of type [I], returning a valure of type [R]. */ inline fun , reified R : Any> TestStartedNode.registerAnswer(initiatingFlowType: KClass, value: R) { - internalRegisterFlowFactory(initiatingFlowType.java, InitiatedFlowFactory.Core { session -> Answer(session, value) }, Answer::class.javaObjectType, true) + registerFlowFactory(initiatingFlowType.java, InitiatedFlowFactory.Core { session -> Answer(session, value) }, Answer::class.javaObjectType, true) } /** diff --git a/core/src/test/kotlin/net/corda/core/serialization/AttachmentSerializationTest.kt b/core/src/test/kotlin/net/corda/core/serialization/AttachmentSerializationTest.kt index af7a46d936..92739f70c6 100644 --- a/core/src/test/kotlin/net/corda/core/serialization/AttachmentSerializationTest.kt +++ b/core/src/test/kotlin/net/corda/core/serialization/AttachmentSerializationTest.kt @@ -151,7 +151,7 @@ class AttachmentSerializationTest { } private fun launchFlow(clientLogic: ClientLogic, rounds: Int, sendData: Boolean = false) { - server.internalRegisterFlowFactory( + server.registerFlowFactory( ClientLogic::class.java, InitiatedFlowFactory.Core { ServerLogic(it, sendData) }, ServerLogic::class.java, diff --git a/node/src/main/kotlin/net/corda/node/internal/AbstractNode.kt b/node/src/main/kotlin/net/corda/node/internal/AbstractNode.kt index 7bba2ee345..3c6b59b6dd 100644 --- a/node/src/main/kotlin/net/corda/node/internal/AbstractNode.kt +++ b/node/src/main/kotlin/net/corda/node/internal/AbstractNode.kt @@ -107,15 +107,12 @@ import net.corda.core.crypto.generateKeyPair as cryptoGenerateKeyPair * sweeping up the Node into the Kryo checkpoint serialization via any flows holding a reference to ServiceHub. */ // TODO Log warning if this node is a notary but not one of the ones specified in the network parameters, both for core and custom - -// In theory the NodeInfo for the node should be passed in, instead, however currently this is constructed by the -// AbstractNode. It should be possible to generate the NodeInfo outside of AbstractNode, so it can be passed in. abstract class AbstractNode(val configuration: NodeConfiguration, - val platformClock: CordaClock, - protected val versionInfo: VersionInfo, - protected val cordappLoader: CordappLoader, - protected val serverThread: AffinityExecutor.ServiceAffinityExecutor, - private val busyNodeLatch: ReusableLatch = ReusableLatch()) : SingletonSerializeAsToken() { + val platformClock: CordaClock, + protected val versionInfo: VersionInfo, + protected val cordappLoader: CordappLoader, + protected val serverThread: AffinityExecutor.ServiceAffinityExecutor, + private val busyNodeLatch: ReusableLatch = ReusableLatch()) : SingletonSerializeAsToken() { protected abstract val log: Logger @@ -180,6 +177,8 @@ abstract class AbstractNode(val configuration: NodeConfiguration, val transactionVerifierService = InMemoryTransactionVerifierService(transactionVerifierWorkerCount).tokenize() val contractUpgradeService = ContractUpgradeServiceImpl().tokenize() val auditService = DummyAuditService().tokenize() + @Suppress("LeakingThis") + protected val network: MessagingService = makeMessagingService().tokenize() val services = ServiceHubInternalImpl().tokenize() @Suppress("LeakingThis") val smm = makeStateMachineManager() @@ -194,8 +193,6 @@ abstract class AbstractNode(val configuration: NodeConfiguration, configuration.drainingModePollPeriod, unfinishedSchedules = busyNodeLatch ).tokenize().closeOnStop() - // TODO Making this non-lateinit requires MockNode being able to create a blank InMemoryMessaging instance - protected lateinit var network: MessagingService private val cordappServices = MutableClassToInstanceMap.create() private val flowFactories = ConcurrentHashMap>, InitiatedFlowFactory<*>>() @@ -286,10 +283,6 @@ abstract class AbstractNode(val configuration: NodeConfiguration, } log.info("Node starting up ...") - // TODO First thing we do is create the MessagingService. This should have been done by the c'tor but it's not - // possible (yet) to due restriction from MockNode - network = makeMessagingService().tokenize() - val trustRoot = initKeyStore() val nodeCa = configuration.loadNodeKeyStore().getCertificate(X509Utilities.CORDA_CLIENT_CA) initialiseJVMAgents() diff --git a/node/src/main/kotlin/net/corda/node/internal/Node.kt b/node/src/main/kotlin/net/corda/node/internal/Node.kt index 52dbfe8257..e1f4ae7f81 100644 --- a/node/src/main/kotlin/net/corda/node/internal/Node.kt +++ b/node/src/main/kotlin/net/corda/node/internal/Node.kt @@ -4,7 +4,6 @@ import com.codahale.metrics.JmxReporter import net.corda.client.rpc.internal.serialization.amqp.AMQPClientSerializationScheme import net.corda.core.concurrent.CordaFuture import net.corda.core.flows.FlowLogic -import net.corda.core.flows.InitiatedBy import net.corda.core.identity.CordaX500Name import net.corda.core.identity.PartyAndCertificate import net.corda.core.internal.Emoji @@ -13,7 +12,6 @@ import net.corda.core.internal.concurrent.thenMatch import net.corda.core.internal.div import net.corda.core.internal.errors.AddressBindingException import net.corda.core.internal.notary.NotaryService -import net.corda.node.services.api.StartedNodeServices import net.corda.core.messaging.CordaRPCOps import net.corda.core.messaging.RPCOps import net.corda.core.node.NetworkParameters @@ -39,6 +37,7 @@ import net.corda.node.serialization.kryo.KryoServerSerializationScheme import net.corda.node.services.Permissions import net.corda.node.services.api.FlowStarter import net.corda.node.services.api.ServiceHubInternal +import net.corda.node.services.api.StartedNodeServices import net.corda.node.services.config.NodeConfiguration import net.corda.node.services.config.SecurityConfiguration import net.corda.node.services.config.shouldInitCrashShell @@ -58,6 +57,7 @@ import net.corda.serialization.internal.* import org.h2.jdbc.JdbcSQLException import org.slf4j.Logger import org.slf4j.LoggerFactory +import rx.Observable import rx.Scheduler import rx.schedulers.Schedulers import java.net.BindException @@ -66,7 +66,6 @@ import java.time.Clock import java.util.concurrent.atomic.AtomicInteger import javax.management.ObjectName import kotlin.system.exitProcess -import rx.Observable class NodeWithInfo(val node: Node, val info: NodeInfo) { val services: StartedNodeServices = object : StartedNodeServices, ServiceHubInternal by node.services, FlowStarter by node.flowStarter {} @@ -195,7 +194,7 @@ open class Node(configuration: NodeConfiguration, override fun startMessagingService(rpcOps: RPCOps, nodeInfo: NodeInfo, myNotaryIdentity: PartyAndCertificate?, networkParameters: NetworkParameters) { require(nodeInfo.legalIdentities.size in 1..2) { "Currently nodes must have a primary address and optionally one serviced address" } - val client = network as P2PMessagingClient + network as P2PMessagingClient // Construct security manager reading users data either from the 'security' config section // if present or from rpcUsers list if the former is missing from config. @@ -219,7 +218,7 @@ open class Node(configuration: NodeConfiguration, startLocalRpcBroker(securityManager) } - val bridgeControlListener = BridgeControlListener(configuration, client.serverAddress, networkParameters.maxMessageSize) + val bridgeControlListener = BridgeControlListener(configuration, network.serverAddress, networkParameters.maxMessageSize) printBasicNodeInfo("Advertised P2P messaging addresses", nodeInfo.addresses.joinToString()) val rpcServerConfiguration = RPCServerConfiguration.DEFAULT @@ -248,8 +247,8 @@ open class Node(configuration: NodeConfiguration, closeOnStop() init(rpcOps, securityManager) } - client.closeOnStop() - client.start( + network.closeOnStop() + network.start( myIdentity = nodeInfo.legalIdentities[0].owningKey, serviceIdentity = if (nodeInfo.legalIdentities.size == 1) null else nodeInfo.legalIdentities[1].owningKey, advertisedAddress = nodeInfo.addresses[0], diff --git a/node/src/main/kotlin/net/corda/node/services/messaging/Messaging.kt b/node/src/main/kotlin/net/corda/node/services/messaging/Messaging.kt index 2bb319f70b..b95c9bbbe6 100644 --- a/node/src/main/kotlin/net/corda/node/services/messaging/Messaging.kt +++ b/node/src/main/kotlin/net/corda/node/services/messaging/Messaging.kt @@ -146,13 +146,6 @@ interface ReceivedMessage : Message { val isSessionInit: Boolean } -/** A singleton that's useful for validating topic strings */ -object TopicStringValidator { - private val regex = "[a-zA-Z0-9.]+".toPattern() - /** @throws IllegalArgumentException if the given topic contains invalid characters */ - fun check(tag: String) = require(regex.matcher(tag).matches()) -} - /** * This handler is used to implement exactly-once delivery of an external event on top of an at-least-once delivery. This is done * using two hooks that are called from the event processor, one called from the database transaction committing the diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/SingleThreadedStateMachineManager.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/SingleThreadedStateMachineManager.kt index b3cb8691bc..40d1c92b99 100644 --- a/node/src/main/kotlin/net/corda/node/services/statemachine/SingleThreadedStateMachineManager.kt +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/SingleThreadedStateMachineManager.kt @@ -101,7 +101,7 @@ class SingleThreadedStateMachineManager( private val flowMessaging: FlowMessaging = FlowMessagingImpl(serviceHub) private val fiberDeserializationChecker = if (serviceHub.configuration.shouldCheckCheckpoints()) FiberDeserializationChecker() else null private val transitionExecutor = makeTransitionExecutor() - private val ourSenderUUID get() = serviceHub.networkService.ourSenderUUID // This is a getter since AbstractNode.network is still lateinit + private val ourSenderUUID = serviceHub.networkService.ourSenderUUID private var checkpointSerializationContext: SerializationContext? = null private var actionExecutor: ActionExecutor? = null diff --git a/node/src/test/kotlin/net/corda/node/messaging/InMemoryMessagingTests.kt b/node/src/test/kotlin/net/corda/node/messaging/InMemoryMessagingTests.kt deleted file mode 100644 index 4add5f243c..0000000000 --- a/node/src/test/kotlin/net/corda/node/messaging/InMemoryMessagingTests.kt +++ /dev/null @@ -1,115 +0,0 @@ -package net.corda.node.messaging - -import net.corda.core.messaging.AllPossibleRecipients -import net.corda.node.services.messaging.Message -import net.corda.node.services.messaging.TopicStringValidator -import net.corda.testing.internal.rigorousMock -import net.corda.testing.node.internal.InternalMockNetwork -import org.junit.After -import org.junit.Before -import org.junit.Test -import java.util.* -import kotlin.test.assertEquals -import kotlin.test.assertFails -import kotlin.test.assertTrue - -class InMemoryMessagingTests { - lateinit var mockNet: InternalMockNetwork - - @Before - fun setUp() { - mockNet = InternalMockNetwork() - } - - @After - fun tearDown() { - mockNet.stopNodes() - } - - @Test - fun `topic string validation`() { - TopicStringValidator.check("this.is.ok") - TopicStringValidator.check("this.is.OkAlso") - assertFails { - TopicStringValidator.check("this.is.not-ok") - } - assertFails { - TopicStringValidator.check("") - } - assertFails { - TopicStringValidator.check("this.is not ok") // Spaces - } - } - - @Test - fun basics() { - val node1 = mockNet.createNode() - val node2 = mockNet.createNode() - val node3 = mockNet.createNode() - - val bits = "test-content".toByteArray() - var finalDelivery: Message? = null - node2.network.addMessageHandler("test.topic") { msg, _, _ -> - node2.network.send(msg, node3.network.myAddress) - } - node3.network.addMessageHandler("test.topic") { msg, _, _ -> - finalDelivery = msg - } - - // Node 1 sends a message and it should end up in finalDelivery, after we run the network - node1.network.send(node1.network.createMessage("test.topic", data = bits), node2.network.myAddress) - - mockNet.runNetwork(rounds = 1) - - assertTrue(Arrays.equals(finalDelivery!!.data.bytes, bits)) - } - - @Test - fun broadcast() { - val node1 = mockNet.createNode() - val node2 = mockNet.createNode() - val node3 = mockNet.createNode() - - val bits = "test-content".toByteArray() - - var counter = 0 - listOf(node1, node2, node3).forEach { it.network.addMessageHandler("test.topic") { _, _, _ -> counter++ } } - node1.network.send(node2.network.createMessage("test.topic", data = bits), rigorousMock()) - mockNet.runNetwork(rounds = 1) - assertEquals(3, counter) - } - - /** - * Tests that unhandled messages in the received queue are skipped and the next message processed, rather than - * causing processing to return null as if there was no message. - */ - @Test - fun `skip unhandled messages`() { - val node1 = mockNet.createNode() - val node2 = mockNet.createNode() - var received = 0 - - node1.network.addMessageHandler("valid_message") { _, _, _ -> - received++ - } - - val invalidMessage = node2.network.createMessage("invalid_message", data = ByteArray(1)) - val validMessage = node2.network.createMessage("valid_message", data = ByteArray(1)) - node2.network.send(invalidMessage, node1.network.myAddress) - mockNet.runNetwork() - assertEquals(0, received) - - node2.network.send(validMessage, node1.network.myAddress) - mockNet.runNetwork() - assertEquals(1, received) - - // Here's the core of the test; previously the unhandled message would cause runNetwork() to abort early, so - // this would fail. Make fresh messages to stop duplicate uniqueMessageId causing drops - val invalidMessage2 = node2.network.createMessage("invalid_message", data = ByteArray(1)) - val validMessage2 = node2.network.createMessage("valid_message", data = ByteArray(1)) - node2.network.send(invalidMessage2, node1.network.myAddress) - node2.network.send(validMessage2, node1.network.myAddress) - mockNet.runNetwork() - assertEquals(2, received) - } -} diff --git a/node/src/test/kotlin/net/corda/node/messaging/TwoPartyTradeFlowTests.kt b/node/src/test/kotlin/net/corda/node/messaging/TwoPartyTradeFlowTests.kt index e8a5ae22b6..66a8430516 100644 --- a/node/src/test/kotlin/net/corda/node/messaging/TwoPartyTradeFlowTests.kt +++ b/node/src/test/kotlin/net/corda/node/messaging/TwoPartyTradeFlowTests.kt @@ -47,7 +47,6 @@ import net.corda.testing.internal.LogHelper import net.corda.testing.internal.TEST_TX_TIME import net.corda.testing.internal.rigorousMock import net.corda.testing.internal.vault.VaultFiller -import net.corda.testing.node.InMemoryMessagingNetwork import net.corda.testing.node.MockServices import net.corda.testing.node.internal.* import net.corda.testing.node.ledger @@ -225,7 +224,7 @@ class TwoPartyTradeFlowTests(private val anonymous: Boolean) { aliceNode.internals.disableDBCloseOnStop() bobNode.internals.disableDBCloseOnStop() - val bobAddr = bobNode.network.myAddress as InMemoryMessagingNetwork.PeerHandle + val bobAddr = bobNode.network.myAddress mockNet.runNetwork() // Clear network map registration messages val notary = mockNet.defaultNotaryIdentity diff --git a/node/src/test/kotlin/net/corda/node/services/statemachine/FlowFrameworkTests.kt b/node/src/test/kotlin/net/corda/node/services/statemachine/FlowFrameworkTests.kt index 59098cffed..0d6dc218d8 100644 --- a/node/src/test/kotlin/net/corda/node/services/statemachine/FlowFrameworkTests.kt +++ b/node/src/test/kotlin/net/corda/node/services/statemachine/FlowFrameworkTests.kt @@ -821,7 +821,7 @@ private inline fun > TestStartedNode.registerFlowFactor initiatingFlowClass: KClass>, initiatedFlowVersion: Int = 1, noinline flowFactory: (FlowSession) -> P): CordaFuture

{ - val observable = internalRegisterFlowFactory( + val observable = registerFlowFactory( initiatingFlowClass.java, InitiatedFlowFactory.CorDapp(initiatedFlowVersion, "", flowFactory), P::class.java, diff --git a/node/src/test/kotlin/net/corda/node/services/statemachine/RetryFlowMockTest.kt b/node/src/test/kotlin/net/corda/node/services/statemachine/RetryFlowMockTest.kt index 8053edd357..82f4d45a82 100644 --- a/node/src/test/kotlin/net/corda/node/services/statemachine/RetryFlowMockTest.kt +++ b/node/src/test/kotlin/net/corda/node/services/statemachine/RetryFlowMockTest.kt @@ -15,11 +15,6 @@ import net.corda.node.services.FinalityHandler import net.corda.node.services.messaging.Message import net.corda.node.services.persistence.DBTransactionStorage import net.corda.nodeapi.internal.persistence.contextTransaction -import net.corda.testing.node.internal.cordappsForPackages -import net.corda.testing.node.internal.InternalMockNetwork -import net.corda.testing.node.internal.MessagingServiceSpy -import net.corda.testing.node.internal.newContext -import net.corda.testing.node.internal.setMessagingServiceSpy import net.corda.testing.node.internal.* import org.assertj.core.api.Assertions.assertThat import org.assertj.core.api.Assertions.assertThatThrownBy @@ -79,7 +74,7 @@ class RetryFlowMockTest { fun `Retry does not set senderUUID`() { val messagesSent = Collections.synchronizedList(mutableListOf()) val partyB = nodeB.info.legalIdentities.first() - nodeA.setMessagingServiceSpy(object : MessagingServiceSpy(nodeA.network) { + nodeA.setMessagingServiceSpy(object : MessagingServiceSpy() { override fun send(message: Message, target: MessageRecipients, sequenceKey: Any) { messagesSent.add(message) messagingService.send(message, target) @@ -95,7 +90,7 @@ class RetryFlowMockTest { fun `Restart does not set senderUUID`() { val messagesSent = Collections.synchronizedList(mutableListOf()) val partyB = nodeB.info.legalIdentities.first() - nodeA.setMessagingServiceSpy(object : MessagingServiceSpy(nodeA.network) { + nodeA.setMessagingServiceSpy(object : MessagingServiceSpy() { override fun send(message: Message, target: MessageRecipients, sequenceKey: Any) { messagesSent.add(message) messagingService.send(message, target) @@ -109,7 +104,7 @@ class RetryFlowMockTest { assertNotNull(messagesSent.first().senderUUID) nodeA = mockNet.restartNode(nodeA) // This is a bit racy because restarting the node actually starts it, so we need to make sure there's enough iterations we get here with flow still going. - nodeA.setMessagingServiceSpy(object : MessagingServiceSpy(nodeA.network) { + nodeA.setMessagingServiceSpy(object : MessagingServiceSpy() { override fun send(message: Message, target: MessageRecipients, sequenceKey: Any) { messagesSent.add(message) messagingService.send(message, target) @@ -117,7 +112,7 @@ class RetryFlowMockTest { }) // Now short circuit the iterations so the flow finishes soon. KeepSendingFlow.count.set(count - 2) - while (nodeA.smm.allStateMachines.size > 0) { + while (nodeA.smm.allStateMachines.isNotEmpty()) { Thread.sleep(10) } assertNull(messagesSent.last().senderUUID) diff --git a/node/src/test/kotlin/net/corda/node/services/transactions/ValidatingNotaryServiceTests.kt b/node/src/test/kotlin/net/corda/node/services/transactions/ValidatingNotaryServiceTests.kt index 1979cf962f..4815ace9d8 100644 --- a/node/src/test/kotlin/net/corda/node/services/transactions/ValidatingNotaryServiceTests.kt +++ b/node/src/test/kotlin/net/corda/node/services/transactions/ValidatingNotaryServiceTests.kt @@ -4,18 +4,8 @@ import net.corda.core.concurrent.CordaFuture import net.corda.core.contracts.Command import net.corda.core.contracts.StateAndRef import net.corda.core.contracts.StateRef -import net.corda.core.crypto.Crypto -import net.corda.core.crypto.SecureHash -import net.corda.core.crypto.TransactionSignature -import net.corda.core.crypto.generateKeyPair -import net.corda.core.crypto.sha256 -import net.corda.core.crypto.sign -import net.corda.core.flows.NotarisationPayload -import net.corda.core.flows.NotarisationRequest -import net.corda.core.flows.NotarisationRequestSignature -import net.corda.core.flows.NotaryError -import net.corda.core.flows.NotaryException -import net.corda.core.flows.NotaryFlow +import net.corda.core.crypto.* +import net.corda.core.flows.* import net.corda.core.identity.Party import net.corda.core.internal.notary.generateSignature import net.corda.core.messaging.MessageRecipients @@ -35,13 +25,6 @@ import net.corda.testing.core.ALICE_NAME import net.corda.testing.core.dummyCommand import net.corda.testing.core.singleIdentity import net.corda.testing.node.TestClock -import net.corda.testing.node.internal.cordappsForPackages -import net.corda.testing.node.internal.InMemoryMessage -import net.corda.testing.node.internal.InternalMockNetwork -import net.corda.testing.node.internal.InternalMockNodeParameters -import net.corda.testing.node.internal.MessagingServiceSpy -import net.corda.testing.node.internal.setMessagingServiceSpy -import net.corda.testing.node.internal.startFlow import net.corda.testing.node.internal.* import org.assertj.core.api.Assertions.assertThat import org.junit.After @@ -308,7 +291,7 @@ class ValidatingNotaryServiceTests { } private fun runNotarisationAndInterceptClientPayload(payloadModifier: (NotarisationPayload) -> NotarisationPayload) { - aliceNode.setMessagingServiceSpy(object : MessagingServiceSpy(aliceNode.network) { + aliceNode.setMessagingServiceSpy(object : MessagingServiceSpy() { override fun send(message: Message, target: MessageRecipients, sequenceKey: Any) { val messageData = message.data.deserialize() as? InitialSessionMessage val payload = messageData?.firstPayload!!.deserialize() @@ -318,7 +301,6 @@ class ValidatingNotaryServiceTests { val alteredMessageData = messageData.copy(firstPayload = alteredPayload.serialize()) val alteredMessage = InMemoryMessage(message.topic, OpaqueBytes(alteredMessageData.serialize().bytes), message.uniqueMessageId) messagingService.send(alteredMessage, target) - } else { messagingService.send(message, target) } diff --git a/node/src/test/kotlin/net/corda/node/services/vault/VaultSoftLockManagerTest.kt b/node/src/test/kotlin/net/corda/node/services/vault/VaultSoftLockManagerTest.kt index a6e2bb307a..29c1fccf83 100644 --- a/node/src/test/kotlin/net/corda/node/services/vault/VaultSoftLockManagerTest.kt +++ b/node/src/test/kotlin/net/corda/node/services/vault/VaultSoftLockManagerTest.kt @@ -66,7 +66,7 @@ class NodePair(private val mockNet: InternalMockNetwork) { private set fun communicate(clientLogic: AbstractClientLogic, rebootClient: Boolean): FlowStateMachine { - server.internalRegisterFlowFactory(AbstractClientLogic::class.java, InitiatedFlowFactory.Core { ServerLogic(it, serverRunning) }, ServerLogic::class.java, false) + server.registerFlowFactory(AbstractClientLogic::class.java, InitiatedFlowFactory.Core { ServerLogic(it, serverRunning) }, ServerLogic::class.java, false) client.services.startFlow(clientLogic) while (!serverRunning.get()) mockNet.runNetwork(1) if (rebootClient) { diff --git a/testing/node-driver/src/main/kotlin/net/corda/testing/node/InMemoryMessagingNetwork.kt b/testing/node-driver/src/main/kotlin/net/corda/testing/node/InMemoryMessagingNetwork.kt index fed57f372b..357a200e87 100644 --- a/testing/node-driver/src/main/kotlin/net/corda/testing/node/InMemoryMessagingNetwork.kt +++ b/testing/node-driver/src/main/kotlin/net/corda/testing/node/InMemoryMessagingNetwork.kt @@ -1,45 +1,36 @@ package net.corda.testing.node +import net.corda.core.CordaInternal import net.corda.core.DoNotImplement import net.corda.core.crypto.CompositeKey import net.corda.core.identity.CordaX500Name import net.corda.core.identity.Party import net.corda.core.identity.PartyAndCertificate -import net.corda.core.internal.ThreadBox import net.corda.core.internal.concurrent.openFuture import net.corda.core.messaging.AllPossibleRecipients import net.corda.core.messaging.MessageRecipientGroup import net.corda.core.messaging.MessageRecipients import net.corda.core.messaging.SingleMessageRecipient -import net.corda.core.node.services.PartyInfo import net.corda.core.serialization.CordaSerializable import net.corda.core.serialization.SingletonSerializeAsToken import net.corda.core.utilities.ByteSequence -import net.corda.core.utilities.OpaqueBytes import net.corda.core.utilities.getOrThrow import net.corda.core.utilities.trace -import net.corda.node.services.messaging.* -import net.corda.node.services.statemachine.DeduplicationId -import net.corda.node.services.statemachine.ExternalEvent -import net.corda.node.services.statemachine.SenderDeduplicationId -import net.corda.node.utilities.AffinityExecutor -import net.corda.testing.node.internal.InMemoryMessage -import net.corda.testing.node.internal.InternalMockMessagingService +import net.corda.node.services.messaging.Message +import net.corda.testing.node.internal.MockNodeMessagingService import org.apache.activemq.artemis.utils.ReusableLatch import org.slf4j.LoggerFactory import rx.Observable import rx.subjects.PublishSubject import java.time.Duration -import java.time.Instant import java.util.* import java.util.concurrent.LinkedBlockingQueue import javax.annotation.concurrent.ThreadSafe import kotlin.concurrent.schedule -import kotlin.concurrent.thread /** - * An in-memory network allows you to manufacture [InternalMockMessagingService]s for a set of participants. Each - * [InternalMockMessagingService] maintains a queue of messages it has received, and a background thread that dispatches + * An in-memory network allows you to manufacture [MockNodeMessagingService]s for a set of participants. Each + * [MockNodeMessagingService] maintains a queue of messages it has received, and a background thread that dispatches * messages one by one to registered handlers. Alternatively, a messaging system may be manually pumped, in which * case no thread is created and a caller is expected to force delivery one at a time (this is useful for unit * testing). @@ -57,16 +48,16 @@ class InMemoryMessagingNetwork private constructor( private const val MESSAGES_LOG_NAME = "messages" private val log = LoggerFactory.getLogger(MESSAGES_LOG_NAME) - internal fun create( - sendManuallyPumped: Boolean, - servicePeerAllocationStrategy: ServicePeerAllocationStrategy = InMemoryMessagingNetwork.ServicePeerAllocationStrategy.Random(), - messagesInFlight: ReusableLatch = ReusableLatch()): InMemoryMessagingNetwork { + @CordaInternal + internal fun create(sendManuallyPumped: Boolean, + servicePeerAllocationStrategy: ServicePeerAllocationStrategy = InMemoryMessagingNetwork.ServicePeerAllocationStrategy.Random(), + messagesInFlight: ReusableLatch = ReusableLatch()): InMemoryMessagingNetwork { return InMemoryMessagingNetwork(sendManuallyPumped, servicePeerAllocationStrategy, messagesInFlight) } } private var counter = 0 // -1 means stopped. - private val handleEndpointMap = HashMap() + private val handleEndpointMap = HashMap() /** A class which represents a message being transferred from sender to recipients, within the [InMemoryMessageNetwork]. **/ @CordaSerializable @@ -107,40 +98,35 @@ class InMemoryMessagingNetwork private constructor( /** A stream of (sender, message, recipients) triples containing messages once they have been received. */ val receivedMessages: Observable get() = _receivedMessages - internal val endpoints: List @Synchronized get() = handleEndpointMap.values.toList() + internal val endpoints: List + @CordaInternal + @Synchronized + get() = handleEndpointMap.values.toList() /** Get a [List] of all the [MockMessagingService] endpoints **/ - val endpointsExternal: List @Synchronized get() = handleEndpointMap.values.map { MockMessagingService.createMockMessagingService(it) }.toList() + val endpointsExternal: List + @Synchronized + get() = handleEndpointMap.values.map { MockMessagingService.createMockMessagingService(it) }.toList() - /** - * Creates a node at the given address: useful if you want to recreate a node to simulate a restart. - * - * @param manuallyPumped if set to true, then you are expected to call [InMemoryMessaging.pumpReceive] - * in order to cause the delivery of a single message, which will occur on the thread of the caller. If set to false - * then this class will set up a background thread to deliver messages asynchronously, if the handler specifies no - * executor. - * @param id the numeric ID to use, e.g. set to whatever ID the node used last time. - * @param description text string that identifies this node for message logging (if is enabled) or null to autogenerate. - */ - internal fun createNodeWithID( - manuallyPumped: Boolean, - id: Int, - executor: AffinityExecutor, - description: CordaX500Name = CordaX500Name(organisation = "In memory node $id", locality = "London", country = "UK") - ): InternalMockMessagingService { - val peerHandle = PeerHandle(id, description) - peersMapping[peerHandle.name] = peerHandle // Assume that the same name - the same entity in MockNetwork. + @CordaInternal + internal fun getPeer(name: CordaX500Name): PeerHandle? = peersMapping[name] + + @CordaInternal + internal fun initPeer(messagingService: MockNodeMessagingService): MockNodeMessagingService? { + peersMapping[messagingService.myAddress.name] = messagingService.myAddress // Assume that the same name - the same entity in MockNetwork. return synchronized(this) { - val node = InMemoryMessaging(manuallyPumped, peerHandle, executor) - val oldNode = handleEndpointMap.put(peerHandle, node) - if (oldNode != null) { - node.inheritPendingRedelivery(oldNode) - } - node + handleEndpointMap.put(messagingService.myAddress, messagingService) } } - internal fun onNotaryIdentity(node: InternalMockMessagingService, notaryService: PartyAndCertificate?) { - val peerHandle = (node as InMemoryMessaging).peerHandle + @CordaInternal + internal fun onMessageTransfer(transfer: MessageTransfer) { + _receivedMessages.onNext(transfer) + messagesInFlight.countDown() + } + + @CordaInternal + internal fun addNotaryIdentity(node: MockNodeMessagingService, notaryService: PartyAndCertificate?) { + val peerHandle = node.myAddress notaryService?.let { if (it.owningKey !is CompositeKey) peersMapping[it.name] = peerHandle } val serviceHandles = notaryService?.let { listOf(DistributedServiceHandle(it.party)) } ?: emptyList() //TODO only notary can be distributed? @@ -161,22 +147,31 @@ class InMemoryMessagingNetwork private constructor( private var latencyCalculator: LatencyCalculator? = null private val timer = Timer() - @Synchronized - private fun msgSend(from: InMemoryMessaging, message: Message, recipients: MessageRecipients) { - messagesInFlight.countUp() - messageSendQueue += MessageTransfer.createMessageTransfer(from.myAddress, message, recipients) + @CordaInternal + internal fun msgSend(from: MockNodeMessagingService, message: Message, recipients: MessageRecipients) { + synchronized(this) { + messagesInFlight.countUp() + messageSendQueue += MessageTransfer.createMessageTransfer(from.myAddress, message, recipients) + } + if (!sendManuallyPumped) { + pumpSend(false) + } } + @CordaInternal @Synchronized - private fun netNodeHasShutdown(peerHandle: PeerHandle) { + internal fun netNodeHasShutdown(peerHandle: PeerHandle) { val endpoint = handleEndpointMap[peerHandle] - if (!(endpoint?.hasPendingDeliveries() ?: false)) { + if (endpoint?.hasPendingDeliveries() != true) { handleEndpointMap.remove(peerHandle) } } + @CordaInternal @Synchronized - private fun getQueueForPeerHandle(recipients: PeerHandle) = messageReceiveQueues.getOrPut(recipients) { LinkedBlockingQueue() } + internal fun getQueueForPeerHandle(recipients: PeerHandle): LinkedBlockingQueue { + return messageReceiveQueues.getOrPut(recipients) { LinkedBlockingQueue() } + } @Synchronized private fun getQueuesForServiceHandle(recipients: DistributedServiceHandle): List> { @@ -276,30 +271,6 @@ class InMemoryMessagingNetwork private constructor( return transfer } - /** - * When a new message handler is added, this implies we have started a new node. The add handler logic uses this to - * push back any un-acknowledged messages for this peer onto the head of the queue (rather than the tail) to maintain message - * delivery order. We push them back because their consumption was not complete and a restarted node would - * see them re-delivered if this was Artemis. - */ - @Synchronized - private fun unPopMessages(transfers: Collection, us: PeerHandle) { - messageReceiveQueues.compute(us) { _, existing -> - if (existing == null) { - LinkedBlockingQueue().apply { - addAll(transfers) - } - } else { - existing.apply { - val drained = mutableListOf() - existing.drainTo(drained) - existing.addAll(transfers) - existing.addAll(drained) - } - } - } - } - private fun pumpSendInternal(transfer: MessageTransfer) { when (transfer.recipients) { is PeerHandle -> getQueueForPeerHandle(transfer.recipients).add(transfer) @@ -311,35 +282,25 @@ class InMemoryMessagingNetwork private constructor( is AllPossibleRecipients -> { // This means all possible recipients _that the network knows about at the time_, not literally everyone // who joins into the indefinite future. - for (handle in handleEndpointMap.keys) - getQueueForPeerHandle(handle).add(transfer) + synchronized(this) { + for (handle in handleEndpointMap.keys) { + getQueueForPeerHandle(handle).add(transfer) + } + } } else -> throw IllegalArgumentException("Unknown type of recipient handle") } _sentMessages.onNext(transfer) } - private data class InMemoryReceivedMessage(override val topic: String, - override val data: ByteSequence, - override val platformVersion: Int, - override val uniqueMessageId: DeduplicationId, - override val debugTimestamp: Instant, - override val peer: CordaX500Name, - override val senderUUID: String? = null, - override val senderSeqNo: Long? = null, - /** Note this flag is never set in the in memory network. */ - override val isSessionInit: Boolean = false) : ReceivedMessage { - - override val additionalHeaders: Map = emptyMap() - } - /** * A class that provides an abstraction over the nodes' messaging service that also contains the ability to * receive messages from the queue for testing purposes. */ - class MockMessagingService private constructor(private val messagingService: InternalMockMessagingService) { + class MockMessagingService private constructor(private val messagingService: MockNodeMessagingService) { companion object { - internal fun createMockMessagingService(messagingService: InternalMockMessagingService): MockMessagingService { + @CordaInternal + internal fun createMockMessagingService(messagingService: MockNodeMessagingService): MockMessagingService { return MockMessagingService(messagingService) } } @@ -352,197 +313,4 @@ class InMemoryMessagingNetwork private constructor( */ fun pumpReceive(block: Boolean): InMemoryMessagingNetwork.MessageTransfer? = messagingService.pumpReceive(block) } - - @ThreadSafe - private inner class InMemoryMessaging(private val manuallyPumped: Boolean, - val peerHandle: PeerHandle, - private val executor: AffinityExecutor) : SingletonSerializeAsToken(), InternalMockMessagingService { - private inner class Handler(val topicSession: String, val callback: MessageHandler) : MessageHandlerRegistration - - @Volatile - private var running = true - - private inner class InnerState { - val handlers: MutableList = ArrayList() - val pendingRedelivery = LinkedHashSet() - } - - private val state = ThreadBox(InnerState()) - private val processedMessages: MutableSet = Collections.synchronizedSet(HashSet()) - - override val myAddress: PeerHandle get() = peerHandle - override val ourSenderUUID: String = UUID.randomUUID().toString() - - private val backgroundThread = if (manuallyPumped) null else - thread(isDaemon = true, name = "In-memory message dispatcher") { - while (!Thread.currentThread().isInterrupted) { - try { - pumpReceiveInternal(true) - } catch (e: InterruptedException) { - break - } - } - } - - override fun getAddressOfParty(partyInfo: PartyInfo): MessageRecipients { - return when (partyInfo) { - is PartyInfo.SingleNode -> peersMapping[partyInfo.party.name] - ?: throw IllegalArgumentException("No StartedMockNode for party ${partyInfo.party.name}") - is PartyInfo.DistributedNode -> DistributedServiceHandle(partyInfo.party) - } - } - - override fun addMessageHandler(topic: String, callback: MessageHandler): MessageHandlerRegistration { - check(running) - val (handler, transfers) = state.locked { - val handler = Handler(topic, callback).apply { handlers.add(this) } - val pending = ArrayList() - pending.addAll(pendingRedelivery) - pendingRedelivery.clear() - Pair(handler, pending) - } - - unPopMessages(transfers, peerHandle) - return handler - } - - fun inheritPendingRedelivery(other: InMemoryMessaging) { - state.locked { - pendingRedelivery.addAll(other.state.locked { pendingRedelivery }) - } - } - - override fun removeMessageHandler(registration: MessageHandlerRegistration) { - check(running) - state.locked { check(handlers.remove(registration as Handler)) } - } - - override fun send(message: Message, target: MessageRecipients, sequenceKey: Any) { - check(running) - msgSend(this, message, target) - if (!sendManuallyPumped) { - pumpSend(false) - } - } - - override fun send(addressedMessages: List) { - for ((message, target, sequenceKey) in addressedMessages) { - send(message, target, sequenceKey) - } - } - - override fun close() { - if (backgroundThread != null) { - backgroundThread.interrupt() - backgroundThread.join() - } - running = false - netNodeHasShutdown(peerHandle) - } - - /** Returns the given (topic & session, data) pair as a newly created message object. */ - override fun createMessage(topic: String, data: ByteArray, deduplicationId: SenderDeduplicationId, additionalHeaders: Map): Message { - return InMemoryMessage(topic, OpaqueBytes(data), deduplicationId.deduplicationId, senderUUID = deduplicationId.senderUUID) - } - - /** - * Delivers a single message from the internal queue. If there are no messages waiting to be delivered and block - * is true, waits until one has been provided on a different thread via send. If block is false, the return - * result indicates whether a message was delivered or not. - * - * @return the message that was processed, if any in this round. - */ - override fun pumpReceive(block: Boolean): MessageTransfer? { - check(manuallyPumped) - check(running) - executor.flush() - try { - return pumpReceiveInternal(block) - } finally { - executor.flush() - } - } - - /** - * Get the next transfer, and matching queue, that is ready to handle. Any pending transfers without handlers - * are placed into `pendingRedelivery` to try again later. - * - * @param block if this should block until a message it can process. - */ - private fun getNextQueue(q: LinkedBlockingQueue, block: Boolean): Pair>? { - var deliverTo: List? = null - // Pop transfers off the queue until we run out (and are not blocking), or find something we can process - while (deliverTo == null) { - val transfer = (if (block) q.take() else q.poll()) ?: return null - deliverTo = state.locked { - val matchingHandlers = handlers.filter { it.topicSession.isBlank() || transfer.message.topic == it.topicSession } - if (matchingHandlers.isEmpty()) { - // Got no handlers for this message yet. Keep the message around and attempt redelivery after a new - // handler has been registered. The purpose of this path is to make unit tests that have multi-threading - // reliable, as a sender may attempt to send a message to a receiver that hasn't finished setting - // up a handler for yet. Most unit tests don't run threaded, but we want to test true parallelism at - // least sometimes. - log.warn("Message to ${transfer.message.topic} could not be delivered") - pendingRedelivery.add(transfer) - null - } else { - matchingHandlers - } - } - if (deliverTo != null) { - return Pair(transfer, deliverTo) - } - } - return null - } - - private fun pumpReceiveInternal(block: Boolean): MessageTransfer? { - val q = getQueueForPeerHandle(peerHandle) - val (transfer, deliverTo) = getNextQueue(q, block) ?: return null - if (transfer.message.uniqueMessageId !in processedMessages) { - executor.execute { - for (handler in deliverTo) { - try { - val receivedMessage = transfer.toReceivedMessage() - state.locked { pendingRedelivery.add(transfer) } - handler.callback(receivedMessage, handler, InMemoryDeduplicationHandler(receivedMessage, transfer)) - } catch (e: Exception) { - log.error("Caught exception in handler for $this/${handler.topicSession}", e) - } - } - _receivedMessages.onNext(transfer) - messagesInFlight.countDown() - } - } else { - log.info("Drop duplicate message ${transfer.message.uniqueMessageId}") - } - return transfer - } - - private fun MessageTransfer.toReceivedMessage(): ReceivedMessage = InMemoryReceivedMessage( - message.topic, - OpaqueBytes(message.data.bytes.copyOf()), // Kryo messes with the buffer so give each client a unique copy - 1, - message.uniqueMessageId, - message.debugTimestamp, - sender.name) - - private inner class InMemoryDeduplicationHandler(override val receivedMessage: ReceivedMessage, val transfer: MessageTransfer) : DeduplicationHandler, ExternalEvent.ExternalMessageEvent { - override val externalCause: ExternalEvent - get() = this - override val deduplicationHandler: DeduplicationHandler - get() = this - - override fun afterDatabaseTransaction() { - this@InMemoryMessaging.state.locked { pendingRedelivery.remove(transfer) } - } - - override fun insideDatabaseTransaction() { - processedMessages += transfer.message.uniqueMessageId - } - } - - fun hasPendingDeliveries(): Boolean = state.locked { pendingRedelivery.isNotEmpty() } - } } - diff --git a/testing/node-driver/src/main/kotlin/net/corda/testing/node/MockNetwork.kt b/testing/node-driver/src/main/kotlin/net/corda/testing/node/MockNetwork.kt index 291911a6c6..21f2f6021e 100644 --- a/testing/node-driver/src/main/kotlin/net/corda/testing/node/MockNetwork.kt +++ b/testing/node-driver/src/main/kotlin/net/corda/testing/node/MockNetwork.kt @@ -161,7 +161,7 @@ class StartedMockNode private constructor(private val node: TestStartedNode) { * @return the message that was processed, if any in this round. */ fun pumpReceive(block: Boolean = false): InMemoryMessagingNetwork.MessageTransfer? { - return (node.network as InternalMockMessagingService).pumpReceive(block) + return node.network.pumpReceive(block) } /** Returns the currently live flows of type [flowClass], and their corresponding result future. */ diff --git a/testing/node-driver/src/main/kotlin/net/corda/testing/node/internal/InMemoryMessage.kt b/testing/node-driver/src/main/kotlin/net/corda/testing/node/internal/InMemoryMessage.kt index 52132ebc92..d0c8ae135b 100644 --- a/testing/node-driver/src/main/kotlin/net/corda/testing/node/internal/InMemoryMessage.kt +++ b/testing/node-driver/src/main/kotlin/net/corda/testing/node/internal/InMemoryMessage.kt @@ -6,7 +6,7 @@ import net.corda.node.services.statemachine.DeduplicationId import java.time.Instant /** - * An implementation of [Message] for in memory messaging by the test [InMemoryMessagingNetwork]. + * An implementation of [Message] for in memory messaging by the test [MockNodeMessagingService]. */ data class InMemoryMessage(override val topic: String, override val data: ByteSequence, @@ -17,4 +17,4 @@ data class InMemoryMessage(override val topic: String, override val additionalHeaders: Map = emptyMap() override fun toString() = "$topic#${String(data.bytes)}" -} \ No newline at end of file +} diff --git a/testing/node-driver/src/main/kotlin/net/corda/testing/node/internal/InternalMockNetwork.kt b/testing/node-driver/src/main/kotlin/net/corda/testing/node/internal/InternalMockNetwork.kt index 48e6c0ff40..50b951ef60 100644 --- a/testing/node-driver/src/main/kotlin/net/corda/testing/node/internal/InternalMockNetwork.kt +++ b/testing/node-driver/src/main/kotlin/net/corda/testing/node/internal/InternalMockNetwork.kt @@ -24,7 +24,10 @@ import net.corda.core.node.NodeInfo import net.corda.core.node.NotaryInfo import net.corda.core.node.services.IdentityService import net.corda.core.serialization.SerializationWhitelist -import net.corda.core.utilities.* +import net.corda.core.utilities.NetworkHostAndPort +import net.corda.core.utilities.contextLogger +import net.corda.core.utilities.hours +import net.corda.core.utilities.seconds import net.corda.node.VersionInfo import net.corda.node.cordapp.CordappLoader import net.corda.node.internal.AbstractNode @@ -36,6 +39,7 @@ import net.corda.node.services.api.StartedNodeServices import net.corda.node.services.config.* import net.corda.node.services.keys.E2ETestKeyManagementService import net.corda.node.services.keys.KeyManagementServiceInternal +import net.corda.node.services.messaging.Message import net.corda.node.services.messaging.MessagingService import net.corda.node.services.persistence.NodeAttachmentService import net.corda.node.services.statemachine.StateMachineManager @@ -70,10 +74,6 @@ import java.util.concurrent.atomic.AtomicInteger val MOCK_VERSION_INFO = VersionInfo(1, "Mock release", "Mock revision", "Mock Vendor") -fun TestStartedNode.pumpReceive(block: Boolean = false): InMemoryMessagingNetwork.MessageTransfer? { - return (network as InternalMockMessagingService).pumpReceive(block) -} - data class MockNodeArgs( val config: NodeConfiguration, val network: InternalMockNetwork, @@ -109,12 +109,23 @@ interface TestStartedNode { val smm: StateMachineManager val attachments: NodeAttachmentService val rpcOps: CordaRPCOps - val network: MessagingService + val network: MockNodeMessagingService val database: CordaPersistence val notaryService: NotaryService? fun dispose() = internals.stop() + fun pumpReceive(block: Boolean = false): InMemoryMessagingNetwork.MessageTransfer? { + return network.pumpReceive(block) + } + + /** + * Attach a [MessagingServiceSpy] to the [InternalMockNetwork.MockNode] allowing interception and modification of messages. + */ + fun setMessagingServiceSpy(spy: MessagingServiceSpy) { + internals.setMessagingServiceSpy(spy) + } + /** * Use this method to register your initiated flows in your tests. This is automatically done by the node when it * starts up for all [FlowLogic] classes it finds which are annotated with [InitiatedBy]. @@ -122,10 +133,10 @@ interface TestStartedNode { */ fun > registerInitiatedFlow(initiatedFlowClass: Class): Observable - fun > internalRegisterFlowFactory(initiatingFlowClass: Class>, - flowFactory: InitiatedFlowFactory, - initiatedFlowClass: Class, - track: Boolean): Observable + fun > registerFlowFactory(initiatingFlowClass: Class>, + flowFactory: InitiatedFlowFactory, + initiatedFlowClass: Class, + track: Boolean): Observable } open class InternalMockNetwork(defaultParameters: MockNetworkParameters = MockNetworkParameters(), @@ -145,10 +156,6 @@ open class InternalMockNetwork(defaultParameters: MockNetworkParameters = MockNe require(networkParameters.notaries.isEmpty()) { "Define notaries using notarySpecs" } } - private companion object { - private val logger = loggerFor() - } - var nextNodeId = 0 private set private val filesystem = Jimfs.newFileSystem(unix()) @@ -282,12 +289,15 @@ open class InternalMockNetwork(defaultParameters: MockNetworkParameters = MockNe args.network.getServerThread(args.id), args.network.busyLatch ) { + companion object { + private val staticLog = contextLogger() + } - /** The actual [StartedNode] implementation created by this node */ + /** The actual [TestStartedNode] implementation created by this node */ private class TestStartedNodeImpl( override val internals: MockNode, override val attachments: NodeAttachmentService, - override val network: MessagingService, + override val network: MockNodeMessagingService, override val services: StartedNodeServices, override val info: NodeInfo, override val smm: StateMachineManager, @@ -295,7 +305,7 @@ open class InternalMockNetwork(defaultParameters: MockNetworkParameters = MockNe override val rpcOps: CordaRPCOps, override val notaryService: NotaryService?) : TestStartedNode { - override fun > internalRegisterFlowFactory( + override fun > registerFlowFactory( initiatingFlowClass: Class>, flowFactory: InitiatedFlowFactory, initiatedFlowClass: Class, @@ -308,25 +318,11 @@ open class InternalMockNetwork(defaultParameters: MockNetworkParameters = MockNe internals.registerInitiatedFlow(smm, initiatedFlowClass) } - override fun createStartedNode(nodeInfo: NodeInfo, rpcOps: CordaRPCOps, notaryService: NotaryService?): TestStartedNode = - TestStartedNodeImpl( - this, - attachments, - network, - object : StartedNodeServices, ServiceHubInternal by services, FlowStarter by flowStarter { }, - nodeInfo, - smm, - database, - rpcOps, - notaryService - ) - - companion object { - private val staticLog = contextLogger() - } - val mockNet = args.network val id = args.id + init { + require(id >= 0) { "Node ID must be zero or positive, was passed: $id" } + } private val entropyRoot = args.entropyRoot var counter = entropyRoot override val log get() = staticLog @@ -343,9 +339,23 @@ open class InternalMockNetwork(defaultParameters: MockNetworkParameters = MockNe override val started: TestStartedNode? get() = uncheckedCast(super.started) + override fun createStartedNode(nodeInfo: NodeInfo, rpcOps: CordaRPCOps, notaryService: NotaryService?): TestStartedNode { + return TestStartedNodeImpl( + this, + attachments, + network as MockNodeMessagingService, + object : StartedNodeServices, ServiceHubInternal by services, FlowStarter by flowStarter { }, + nodeInfo, + smm, + database, + rpcOps, + notaryService + ) + } + override fun start(): TestStartedNode { mockNet.networkParametersCopier.install(configuration.baseDirectory) - return super.start().also { advertiseNodeToNetwork(it) } + return super.start().also(::advertiseNodeToNetwork) } private fun advertiseNodeToNetwork(newNode: TestStartedNode) { @@ -357,28 +367,20 @@ open class InternalMockNetwork(defaultParameters: MockNetworkParameters = MockNe } } - override fun makeMessagingService(): MessagingService { - require(id >= 0) { "Node ID must be zero or positive, was passed: $id" } - // TODO AbstractNode is forced to call this method in start(), and not in the c'tor, because the mockNet - // c'tor parameter isn't available. We need to be able to return a InternalMockMessagingService - // here that can be populated properly in startMessagingService. - return mockNet.messagingNetwork.createNodeWithID( - !mockNet.threadPerNode, - id, - serverThread, - configuration.myLegalName - ).closeOnStop() + override fun makeMessagingService(): MockNodeMessagingService { + return MockNodeMessagingService(configuration, serverThread).closeOnStop() } override fun startMessagingService(rpcOps: RPCOps, nodeInfo: NodeInfo, myNotaryIdentity: PartyAndCertificate?, networkParameters: NetworkParameters) { - mockNet.messagingNetwork.onNotaryIdentity(network as InternalMockMessagingService, myNotaryIdentity) + (network as MockNodeMessagingService).start(mockNet.messagingNetwork, !mockNet.threadPerNode, id, myNotaryIdentity) } - fun setMessagingServiceSpy(messagingServiceSpy: MessagingServiceSpy) { - network = messagingServiceSpy + fun setMessagingServiceSpy(spy: MessagingServiceSpy) { + spy._messagingService = network + (network as MockNodeMessagingService).spy = spy } override fun makeKeyManagementService(identityService: IdentityService): KeyManagementServiceInternal { @@ -551,13 +553,15 @@ open class InternalMockNetwork(defaultParameters: MockNetworkParameters = MockNe } } -open class MessagingServiceSpy(val messagingService: MessagingService) : MessagingService by messagingService +abstract class MessagingServiceSpy { + internal var _messagingService: MessagingService? = null + set(value) { + check(field == null) { "Spy has already been attached to a node" } + field = value + } + val messagingService: MessagingService get() = checkNotNull(_messagingService) { "Spy has not been attached to a node" } -/** - * Attach a [MessagingServiceSpy] to the [InternalMockNetwork.MockNode] allowing interception and modification of messages. - */ -fun TestStartedNode.setMessagingServiceSpy(messagingServiceSpy: MessagingServiceSpy) { - internals.setMessagingServiceSpy(messagingServiceSpy) + abstract fun send(message: Message, target: MessageRecipients, sequenceKey: Any) } private fun mockNodeConfiguration(): NodeConfiguration { @@ -580,4 +584,4 @@ private fun mockNodeConfiguration(): NodeConfiguration { doReturn(5.seconds.toMillis()).whenever(it).additionalNodeInfoPollingFrequencyMsec doReturn(null).whenever(it).devModeOptions } -} \ No newline at end of file +} diff --git a/testing/node-driver/src/main/kotlin/net/corda/testing/node/internal/InternalTestUtils.kt b/testing/node-driver/src/main/kotlin/net/corda/testing/node/internal/InternalTestUtils.kt index dbcbb47785..6c2094a2e7 100644 --- a/testing/node-driver/src/main/kotlin/net/corda/testing/node/internal/InternalTestUtils.kt +++ b/testing/node-driver/src/main/kotlin/net/corda/testing/node/internal/InternalTestUtils.kt @@ -14,7 +14,6 @@ import net.corda.core.utilities.millis import net.corda.core.utilities.seconds import net.corda.node.services.api.StartedNodeServices import net.corda.node.services.messaging.Message -import net.corda.node.services.messaging.MessagingService import net.corda.testing.internal.chooseIdentity import net.corda.testing.node.InMemoryMessagingNetwork import net.corda.testing.node.User @@ -109,8 +108,4 @@ fun StartedNodeServices.newContext(): InvocationContext = testContext(myInfo.cho fun InMemoryMessagingNetwork.MessageTransfer.getMessage(): Message = message -internal interface InternalMockMessagingService : MessagingService { - fun pumpReceive(block: Boolean): InMemoryMessagingNetwork.MessageTransfer? -} - fun CordaRPCClient.start(user: User) = start(user.username, user.password) \ No newline at end of file diff --git a/testing/node-driver/src/main/kotlin/net/corda/testing/node/internal/MockNodeMessagingService.kt b/testing/node-driver/src/main/kotlin/net/corda/testing/node/internal/MockNodeMessagingService.kt new file mode 100644 index 0000000000..a33b2823a0 --- /dev/null +++ b/testing/node-driver/src/main/kotlin/net/corda/testing/node/internal/MockNodeMessagingService.kt @@ -0,0 +1,283 @@ +package net.corda.testing.node.internal + +import net.corda.core.identity.CordaX500Name +import net.corda.core.identity.PartyAndCertificate +import net.corda.core.internal.ThreadBox +import net.corda.core.messaging.MessageRecipients +import net.corda.core.node.services.PartyInfo +import net.corda.core.serialization.SingletonSerializeAsToken +import net.corda.core.utilities.ByteSequence +import net.corda.core.utilities.OpaqueBytes +import net.corda.core.utilities.contextLogger +import net.corda.node.services.config.NodeConfiguration +import net.corda.node.services.messaging.* +import net.corda.node.services.statemachine.DeduplicationId +import net.corda.node.services.statemachine.ExternalEvent +import net.corda.node.services.statemachine.SenderDeduplicationId +import net.corda.node.utilities.AffinityExecutor +import net.corda.testing.node.InMemoryMessagingNetwork +import java.time.Instant +import java.util.* +import java.util.concurrent.LinkedBlockingQueue +import javax.annotation.concurrent.ThreadSafe +import kotlin.concurrent.thread + +@ThreadSafe +class MockNodeMessagingService(private val configuration: NodeConfiguration, + private val executor: AffinityExecutor) : SingletonSerializeAsToken(), MessagingService { + private companion object { + private val log = contextLogger() + } + + private inner class Handler(val topicSession: String, val callback: MessageHandler) : MessageHandlerRegistration + + @Volatile + private var running = true + + private inner class InnerState { + val handlers: MutableList = ArrayList() + val pendingRedelivery = LinkedHashSet() + } + + private val state = ThreadBox(InnerState()) + private val processedMessages: MutableSet = Collections.synchronizedSet(HashSet()) + + override val ourSenderUUID: String = UUID.randomUUID().toString() + + private var _myAddress: InMemoryMessagingNetwork.PeerHandle? = null + override val myAddress: InMemoryMessagingNetwork.PeerHandle get() = checkNotNull(_myAddress) { "Not started" } + + private lateinit var network: InMemoryMessagingNetwork + private var backgroundThread: Thread? = null + + var spy: MessagingServiceSpy? = null + + /** + * @param manuallyPumped if set to true, then you are expected to call [MockNodeMessagingService.pumpReceive] + * in order to cause the delivery of a single message, which will occur on the thread of the caller. If set to false + * then this class will set up a background thread to deliver messages asynchronously, if the handler specifies no + * executor. + * @param id the numeric ID to use, e.g. set to whatever ID the node used last time. + */ + fun start(network: InMemoryMessagingNetwork, manuallyPumped: Boolean, id: Int, notaryService: PartyAndCertificate?) { + val peerHandle = InMemoryMessagingNetwork.PeerHandle(id, configuration.myLegalName) + + this.network = network + _myAddress = peerHandle + + val oldNode = network.initPeer(this) + if (oldNode != null) { + inheritPendingRedelivery(oldNode) + } + + if (!manuallyPumped) { + backgroundThread = thread(isDaemon = true, name = "In-memory message dispatcher") { + while (!Thread.currentThread().isInterrupted) { + try { + pumpReceiveInternal(true) + } catch (e: InterruptedException) { + break + } + } + } + } + + network.addNotaryIdentity(this, notaryService) + } + + override fun getAddressOfParty(partyInfo: PartyInfo): MessageRecipients { + return when (partyInfo) { + is PartyInfo.SingleNode -> network.getPeer(partyInfo.party.name) + ?: throw IllegalArgumentException("No StartedMockNode for party ${partyInfo.party.name}") + is PartyInfo.DistributedNode -> InMemoryMessagingNetwork.DistributedServiceHandle(partyInfo.party) + } + } + + override fun addMessageHandler(topic: String, callback: MessageHandler): MessageHandlerRegistration { + check(running) + val (handler, transfers) = state.locked { + val handler = Handler(topic, callback).apply { handlers.add(this) } + val pending = ArrayList() + pending.addAll(pendingRedelivery) + pendingRedelivery.clear() + Pair(handler, pending) + } + + unPopMessages(transfers) + return handler + } + + /** + * When a new message handler is added, this implies we have started a new node. The add handler logic uses this to + * push back any un-acknowledged messages for this peer onto the head of the queue (rather than the tail) to maintain message + * delivery order. We push them back because their consumption was not complete and a restarted node would + * see them re-delivered if this was Artemis. + */ + private fun unPopMessages(transfers: Collection) { + val messageQueue = network.getQueueForPeerHandle(myAddress) + val drained = ArrayList().apply { messageQueue.drainTo(this) } + messageQueue.addAll(transfers) + messageQueue.addAll(drained) + } + + private fun inheritPendingRedelivery(other: MockNodeMessagingService) { + state.locked { + pendingRedelivery.addAll(other.state.locked { pendingRedelivery }) + } + } + + override fun removeMessageHandler(registration: MessageHandlerRegistration) { + check(running) + state.locked { check(handlers.remove(registration as Handler)) } + } + + override fun send(message: Message, target: MessageRecipients, sequenceKey: Any) { + check(running) + val spy = this.spy + if (spy != null) { + this.spy = null + try { + spy.send(message, target, sequenceKey) + } finally { + this.spy = spy + } + } else { + network.msgSend(this, message, target) + } + } + + override fun send(addressedMessages: List) { + for ((message, target, sequenceKey) in addressedMessages) { + send(message, target, sequenceKey) + } + } + + override fun close() { + backgroundThread?.let { + it.interrupt() + it.join() + } + running = false + network.netNodeHasShutdown(myAddress) + } + + /** Returns the given (topic & session, data) pair as a newly created message object. */ + override fun createMessage(topic: String, data: ByteArray, deduplicationId: SenderDeduplicationId, additionalHeaders: Map): Message { + return InMemoryMessage(topic, OpaqueBytes(data), deduplicationId.deduplicationId, senderUUID = deduplicationId.senderUUID) + } + + /** + * Delivers a single message from the internal queue. If there are no messages waiting to be delivered and block + * is true, waits until one has been provided on a different thread via send. If block is false, the return + * result indicates whether a message was delivered or not. + * + * @return the message that was processed, if any in this round. + */ + fun pumpReceive(block: Boolean): InMemoryMessagingNetwork.MessageTransfer? { + check(backgroundThread == null) + check(running) + executor.flush() + try { + return pumpReceiveInternal(block) + } finally { + executor.flush() + } + } + + /** + * Get the next transfer, and matching queue, that is ready to handle. Any pending transfers without handlers + * are placed into `pendingRedelivery` to try again later. + * + * @param block if this should block until a message it can process. + */ + private fun getNextQueue(q: LinkedBlockingQueue, block: Boolean): Pair>? { + var deliverTo: List? = null + // Pop transfers off the queue until we run out (and are not blocking), or find something we can process + while (deliverTo == null) { + val transfer = (if (block) q.take() else q.poll()) ?: return null + deliverTo = state.locked { + val matchingHandlers = handlers.filter { it.topicSession.isBlank() || transfer.message.topic == it.topicSession } + if (matchingHandlers.isEmpty()) { + // Got no handlers for this message yet. Keep the message around and attempt redelivery after a new + // handler has been registered. The purpose of this path is to make unit tests that have multi-threading + // reliable, as a sender may attempt to send a message to a receiver that hasn't finished setting + // up a handler for yet. Most unit tests don't run threaded, but we want to test true parallelism at + // least sometimes. + log.warn("Message to ${transfer.message.topic} could not be delivered") + pendingRedelivery.add(transfer) + null + } else { + matchingHandlers + } + } + if (deliverTo != null) { + return Pair(transfer, deliverTo) + } + } + return null + } + + private fun pumpReceiveInternal(block: Boolean): InMemoryMessagingNetwork.MessageTransfer? { + val q = network.getQueueForPeerHandle(myAddress) + val (transfer, deliverTo) = getNextQueue(q, block) ?: return null + if (transfer.message.uniqueMessageId !in processedMessages) { + executor.execute { + for (handler in deliverTo) { + try { + val receivedMessage = transfer.toReceivedMessage() + state.locked { pendingRedelivery.add(transfer) } + handler.callback(receivedMessage, handler, InMemoryDeduplicationHandler(receivedMessage, transfer)) + } catch (e: Exception) { + log.error("Caught exception in handler for $this/${handler.topicSession}", e) + } + } + network.onMessageTransfer(transfer) + } + } else { + log.info("Drop duplicate message ${transfer.message.uniqueMessageId}") + } + return transfer + } + + private fun InMemoryMessagingNetwork.MessageTransfer.toReceivedMessage(): ReceivedMessage { + return InMemoryReceivedMessage( + message.topic, + OpaqueBytes(message.data.bytes.copyOf()), // Kryo messes with the buffer so give each client a unique copy + 1, + message.uniqueMessageId, + message.debugTimestamp, + sender.name + ) + } + + private data class InMemoryReceivedMessage(override val topic: String, + override val data: ByteSequence, + override val platformVersion: Int, + override val uniqueMessageId: DeduplicationId, + override val debugTimestamp: Instant, + override val peer: CordaX500Name, + override val senderUUID: String? = null, + override val senderSeqNo: Long? = null, + /** Note this flag is never set in the in memory network. */ + override val isSessionInit: Boolean = false) : ReceivedMessage { + + override val additionalHeaders: Map = emptyMap() + } + + private inner class InMemoryDeduplicationHandler(override val receivedMessage: ReceivedMessage, val transfer: InMemoryMessagingNetwork.MessageTransfer) : DeduplicationHandler, ExternalEvent.ExternalMessageEvent { + override val externalCause: ExternalEvent + get() = this + override val deduplicationHandler: DeduplicationHandler + get() = this + + override fun afterDatabaseTransaction() { + this@MockNodeMessagingService.state.locked { pendingRedelivery.remove(transfer) } + } + + override fun insideDatabaseTransaction() { + processedMessages += transfer.message.uniqueMessageId + } + } + + fun hasPendingDeliveries(): Boolean = state.locked { pendingRedelivery.isNotEmpty() } +} diff --git a/testing/node-driver/src/test/kotlin/net/corda/testing/node/internal/InternalMockNetworkTests.kt b/testing/node-driver/src/test/kotlin/net/corda/testing/node/internal/InternalMockNetworkTests.kt index ba5f777102..185a3dcaff 100644 --- a/testing/node-driver/src/test/kotlin/net/corda/testing/node/internal/InternalMockNetworkTests.kt +++ b/testing/node-driver/src/test/kotlin/net/corda/testing/node/internal/InternalMockNetworkTests.kt @@ -1,10 +1,104 @@ package net.corda.testing.node.internal +import net.corda.core.messaging.AllPossibleRecipients import net.corda.core.serialization.internal.effectiveSerializationEnv +import net.corda.node.services.messaging.Message +import net.corda.testing.internal.rigorousMock import org.assertj.core.api.Assertions.assertThatThrownBy +import org.junit.After import org.junit.Test +import java.util.* +import kotlin.test.assertEquals +import kotlin.test.assertTrue class InternalMockNetworkTests { + lateinit var mockNet: InternalMockNetwork + + @After + fun tearDown() { + if (this::mockNet.isInitialized) { + mockNet.stopNodes() + } + } + + @Test + fun basics() { + mockNet = InternalMockNetwork() + + val node1 = mockNet.createNode() + val node2 = mockNet.createNode() + val node3 = mockNet.createNode() + + val bits = "test-content".toByteArray() + var finalDelivery: Message? = null + node2.network.addMessageHandler("test.topic") { msg, _, _ -> + node2.network.send(msg, node3.network.myAddress) + } + node3.network.addMessageHandler("test.topic") { msg, _, _ -> + finalDelivery = msg + } + + // Node 1 sends a message and it should end up in finalDelivery, after we run the network + node1.network.send(node1.network.createMessage("test.topic", data = bits), node2.network.myAddress) + + mockNet.runNetwork(rounds = 1) + + assertTrue(Arrays.equals(finalDelivery!!.data.bytes, bits)) + } + + @Test + fun broadcast() { + mockNet = InternalMockNetwork() + + val node1 = mockNet.createNode() + val node2 = mockNet.createNode() + val node3 = mockNet.createNode() + + val bits = "test-content".toByteArray() + + var counter = 0 + listOf(node1, node2, node3).forEach { it.network.addMessageHandler("test.topic") { _, _, _ -> counter++ } } + node1.network.send(node2.network.createMessage("test.topic", data = bits), rigorousMock()) + mockNet.runNetwork(rounds = 1) + assertEquals(3, counter) + } + + /** + * Tests that unhandled messages in the received queue are skipped and the next message processed, rather than + * causing processing to return null as if there was no message. + */ + @Test + fun `skip unhandled messages`() { + mockNet = InternalMockNetwork() + + val node1 = mockNet.createNode() + val node2 = mockNet.createNode() + var received = 0 + + node1.network.addMessageHandler("valid_message") { _, _, _ -> + received++ + } + + val invalidMessage = node2.network.createMessage("invalid_message", data = ByteArray(1)) + val validMessage = node2.network.createMessage("valid_message", data = ByteArray(1)) + node2.network.send(invalidMessage, node1.network.myAddress) + mockNet.runNetwork() + assertEquals(0, received) + + node2.network.send(validMessage, node1.network.myAddress) + mockNet.runNetwork() + assertEquals(1, received) + + // Here's the core of the test; previously the unhandled message would cause runNetwork() to abort early, so + // this would fail. Make fresh messages to stop duplicate uniqueMessageId causing drops + val invalidMessage2 = node2.network.createMessage("invalid_message", data = ByteArray(1)) + val validMessage2 = node2.network.createMessage("valid_message", data = ByteArray(1)) + node2.network.send(invalidMessage2, node1.network.myAddress) + node2.network.send(validMessage2, node1.network.myAddress) + mockNet.runNetwork() + assertEquals(2, received) + } + @Test fun `does not leak serialization env if init fails`() { val e = Exception("didn't work") @@ -15,4 +109,4 @@ class InternalMockNetworkTests { }.isSameAs(e) assertThatThrownBy { effectiveSerializationEnv }.isInstanceOf(IllegalStateException::class.java) } -} \ No newline at end of file +}