diff --git a/node/src/main/kotlin/net/corda/node/services/network/InMemoryNetworkMapCache.kt b/node/src/main/kotlin/net/corda/node/services/network/InMemoryNetworkMapCache.kt index 33d5b911d2..81e19ebace 100644 --- a/node/src/main/kotlin/net/corda/node/services/network/InMemoryNetworkMapCache.kt +++ b/node/src/main/kotlin/net/corda/node/services/network/InMemoryNetworkMapCache.kt @@ -6,7 +6,6 @@ import com.google.common.util.concurrent.SettableFuture import net.corda.core.bufferUntilSubscribed import net.corda.core.crypto.Party import net.corda.core.map -import net.corda.core.messaging.MessageRecipients import net.corda.core.messaging.MessagingService import net.corda.core.messaging.SingleMessageRecipient import net.corda.core.messaging.createMessage diff --git a/node/src/test/kotlin/net/corda/node/messaging/TwoPartyTradeProtocolTests.kt b/node/src/test/kotlin/net/corda/node/messaging/TwoPartyTradeProtocolTests.kt index 93170474bd..8eacae3048 100644 --- a/node/src/test/kotlin/net/corda/node/messaging/TwoPartyTradeProtocolTests.kt +++ b/node/src/test/kotlin/net/corda/node/messaging/TwoPartyTradeProtocolTests.kt @@ -85,7 +85,7 @@ class TwoPartyTradeFlowTests { net = MockNetwork(false, true) ledger { - notaryNode = net.createNotaryNode(DUMMY_NOTARY.name, DUMMY_NOTARY_KEY) + notaryNode = net.createNotaryNode(null, DUMMY_NOTARY.name, DUMMY_NOTARY_KEY) aliceNode = net.createPartyNode(notaryNode.info.address, ALICE.name, ALICE_KEY) bobNode = net.createPartyNode(notaryNode.info.address, BOB.name, BOB_KEY) val aliceKey = aliceNode.services.legalIdentityKey @@ -125,7 +125,7 @@ class TwoPartyTradeFlowTests { @Test fun `shutdown and restore`() { ledger { - notaryNode = net.createNotaryNode(DUMMY_NOTARY.name, DUMMY_NOTARY_KEY) + notaryNode = net.createNotaryNode(null, DUMMY_NOTARY.name, DUMMY_NOTARY_KEY) aliceNode = net.createPartyNode(notaryNode.info.address, ALICE.name, ALICE_KEY) bobNode = net.createPartyNode(notaryNode.info.address, BOB.name, BOB_KEY) aliceNode.disableDBCloseOnStop() @@ -235,7 +235,7 @@ class TwoPartyTradeFlowTests { @Test fun `check dependencies of sale asset are resolved`() { - notaryNode = net.createNotaryNode(DUMMY_NOTARY.name, DUMMY_NOTARY_KEY) + notaryNode = net.createNotaryNode(null, DUMMY_NOTARY.name, DUMMY_NOTARY_KEY) aliceNode = makeNodeWithTracking(notaryNode.info.address, ALICE.name, ALICE_KEY) bobNode = makeNodeWithTracking(notaryNode.info.address, BOB.name, BOB_KEY) val aliceKey = aliceNode.services.legalIdentityKey @@ -327,7 +327,7 @@ class TwoPartyTradeFlowTests { @Test fun `track() works`() { - notaryNode = net.createNotaryNode(DUMMY_NOTARY.name, DUMMY_NOTARY_KEY) + notaryNode = net.createNotaryNode(null, DUMMY_NOTARY.name, DUMMY_NOTARY_KEY) aliceNode = makeNodeWithTracking(notaryNode.info.address, ALICE.name, ALICE_KEY) bobNode = makeNodeWithTracking(notaryNode.info.address, BOB.name, BOB_KEY) val aliceKey = aliceNode.services.legalIdentityKey @@ -427,7 +427,7 @@ class TwoPartyTradeFlowTests { aliceError: Boolean, expectedMessageSubstring: String ) { - notaryNode = net.createNotaryNode(DUMMY_NOTARY.name, DUMMY_NOTARY_KEY) + notaryNode = net.createNotaryNode(null, DUMMY_NOTARY.name, DUMMY_NOTARY_KEY) aliceNode = net.createPartyNode(notaryNode.info.address, ALICE.name, ALICE_KEY) bobNode = net.createPartyNode(notaryNode.info.address, BOB.name, BOB_KEY) val aliceKey = aliceNode.services.legalIdentityKey diff --git a/node/src/test/kotlin/net/corda/node/services/statemachine/StateMachineManagerTests.kt b/node/src/test/kotlin/net/corda/node/services/statemachine/StateMachineManagerTests.kt index ffddf0c4ca..e037803ba3 100644 --- a/node/src/test/kotlin/net/corda/node/services/statemachine/StateMachineManagerTests.kt +++ b/node/src/test/kotlin/net/corda/node/services/statemachine/StateMachineManagerTests.kt @@ -3,20 +3,30 @@ package net.corda.node.services.statemachine import co.paralleluniverse.fibers.Fiber import co.paralleluniverse.fibers.Suspendable import com.google.common.util.concurrent.ListenableFuture +import net.corda.core.contracts.DOLLARS +import net.corda.core.contracts.issuedBy import net.corda.core.crypto.Party +import net.corda.core.crypto.generateKeyPair import net.corda.core.flows.FlowLogic import net.corda.core.flows.FlowSessionException import net.corda.core.getOrThrow import net.corda.core.random63BitValue +import net.corda.core.serialization.OpaqueBytes import net.corda.core.serialization.deserialize +import net.corda.flows.CashCommand +import net.corda.flows.CashFlow +import net.corda.flows.NotaryFlow import net.corda.node.services.persistence.checkpoints import net.corda.node.services.statemachine.StateMachineManager.* import net.corda.node.utilities.databaseTransaction +import net.corda.testing.expect +import net.corda.testing.expectEvents import net.corda.testing.initiateSingleShotFlow import net.corda.testing.node.InMemoryMessagingNetwork import net.corda.testing.node.InMemoryMessagingNetwork.MessageTransfer import net.corda.testing.node.MockNetwork import net.corda.testing.node.MockNetwork.MockNode +import net.corda.testing.sequence import org.assertj.core.api.Assertions.assertThat import org.assertj.core.api.Assertions.assertThatThrownBy import org.junit.After @@ -30,16 +40,24 @@ import kotlin.test.assertTrue class StateMachineManagerTests { - private val net = MockNetwork() + private val net = MockNetwork(servicePeerAllocationStrategy = InMemoryMessagingNetwork.ServicePeerAllocationStrategy.RoundRobin()) private val sessionTransfers = ArrayList() private lateinit var node1: MockNode private lateinit var node2: MockNode + private lateinit var notary1: MockNode + private lateinit var notary2: MockNode @Before fun start() { val nodes = net.createTwoNodes() node1 = nodes.first node2 = nodes.second + val notaryKeyPair = generateKeyPair() + // Note that these notaries don't operate correctly as they don's share their state. They are only used for testing + // service addressing. + notary1 = net.createNotaryNode(networkMapAddr = node1.services.myInfo.address, keyPair = notaryKeyPair, serviceName = "notary-service-2000") + notary2 = net.createNotaryNode(networkMapAddr = node1.services.myInfo.address, keyPair = notaryKeyPair, serviceName = "notary-service-2000") + net.messagingNetwork.receivedMessages.toSessionTransfers().forEach { sessionTransfers += it } net.runNetwork() } @@ -260,6 +278,57 @@ class StateMachineManagerTests { ) } + @Test + fun `different notaries are picked when addressing shared notary identity`() { + assertEquals(notary1.info.notaryIdentity, notary2.info.notaryIdentity) + node1.services.startFlow(CashFlow(CashCommand.IssueCash( + DOLLARS(2000), + OpaqueBytes.of(0x01), + node1.info.legalIdentity, + notary1.info.notaryIdentity))) + // We pay a couple of times, the notary picking should go round robin + for (i in 1 .. 3) { + node1.services.startFlow(CashFlow(CashCommand.PayCash( + DOLLARS(500).issuedBy(node1.info.legalIdentity.ref(0x01)), + node2.info.legalIdentity))) + net.runNetwork() + } + sessionTransfers.expectEvents(isStrict = false) { + sequence( + // First Pay + expect(match = { it.message is SessionInit && it.message.flowName == NotaryFlow.Client::class.java.name }) { + it.message as SessionInit + require(it.from == node1.id) + require(it.to == TransferRecipient.Service(notary1.info.notaryIdentity)) + }, + expect(match = { it.message is SessionConfirm }) { + it.message as SessionConfirm + require(it.from == notary1.id) + }, + // Second pay + expect(match = { it.message is SessionInit && it.message.flowName == NotaryFlow.Client::class.java.name }) { + it.message as SessionInit + require(it.from == node1.id) + require(it.to == TransferRecipient.Service(notary1.info.notaryIdentity)) + }, + expect(match = { it.message is SessionConfirm }) { + it.message as SessionConfirm + require(it.from == notary2.id) + }, + // Third pay + expect(match = { it.message is SessionInit && it.message.flowName == NotaryFlow.Client::class.java.name }) { + it.message as SessionInit + require(it.from == node1.id) + require(it.to == TransferRecipient.Service(notary1.info.notaryIdentity)) + }, + expect(match = { it.message is SessionConfirm }) { + it.message as SessionConfirm + require(it.from == notary1.id) + } + ) + } + } + @Test fun `exception thrown on other side`() { node2.services.registerFlowInitiator(ReceiveThenSuspendFlow::class) { ExceptionFlow } @@ -301,11 +370,16 @@ class StateMachineManagerTests { } private fun assertSessionTransfers(node: MockNode, vararg expected: SessionTransfer) { - val actualForNode = sessionTransfers.filter { it.from == node.id || it.to == node.id } + val actualForNode = sessionTransfers.filter { it.from == node.id || it.to == TransferRecipient.Peer(node.id) } assertThat(actualForNode).containsExactly(*expected) } - private data class SessionTransfer(val from: Int, val message: SessionMessage, val to: Int) { + private interface TransferRecipient { + data class Peer(val id: Int) : TransferRecipient + data class Service(val identity: Party) : TransferRecipient + } + + private data class SessionTransfer(val from: Int, val message: SessionMessage, val to: TransferRecipient) { val isPayloadTransfer: Boolean get() = message is SessionData || message is SessionInit && message.firstPayload != null override fun toString(): String = "$from sent $message to $to" } @@ -314,7 +388,12 @@ class StateMachineManagerTests { return filter { it.message.topicSession == StateMachineManager.sessionTopic }.map { val from = it.sender.id val message = it.message.data.deserialize() - val to = (it.recipients as InMemoryMessagingNetwork.PeerHandle).id + val recipients = it.recipients + val to = when (recipients) { + is InMemoryMessagingNetwork.PeerHandle -> TransferRecipient.Peer(recipients.id) + is InMemoryMessagingNetwork.ServiceHandle -> TransferRecipient.Service(recipients.service.identity) + else -> throw IllegalStateException("Unknown recipients $recipients") + } SessionTransfer(from, sanitise(message), to) } } @@ -330,7 +409,7 @@ class StateMachineManagerTests { } private infix fun MockNode.sent(message: SessionMessage): Pair = Pair(id, message) - private infix fun Pair.to(node: MockNode): SessionTransfer = SessionTransfer(first, second, node.id) + private infix fun Pair.to(node: MockNode): SessionTransfer = SessionTransfer(first, second, TransferRecipient.Peer(node.id)) private class NoOpFlow(val nonTerminating: Boolean = false) : FlowLogic() { diff --git a/samples/bank-of-corda-demo/src/test/kotlin/net/corda/bank/flow/IssuerFlowTest.kt b/samples/bank-of-corda-demo/src/test/kotlin/net/corda/bank/flow/IssuerFlowTest.kt index 5e03a4ea4e..50c1ac83b9 100644 --- a/samples/bank-of-corda-demo/src/test/kotlin/net/corda/bank/flow/IssuerFlowTest.kt +++ b/samples/bank-of-corda-demo/src/test/kotlin/net/corda/bank/flow/IssuerFlowTest.kt @@ -34,7 +34,7 @@ class IssuerFlowTest { fun `test issuer flow`() { net = MockNetwork(false, true) ledger { - notaryNode = net.createNotaryNode(DUMMY_NOTARY.name, DUMMY_NOTARY_KEY) + notaryNode = net.createNotaryNode(null, DUMMY_NOTARY.name, DUMMY_NOTARY_KEY) bankOfCordaNode = net.createPartyNode(notaryNode.info.address, BOC_ISSUER_PARTY.name, BOC_KEY) bankClientNode = net.createPartyNode(notaryNode.info.address, MEGA_CORP.name, MEGA_CORP_KEY) diff --git a/test-utils/src/main/kotlin/net/corda/testing/node/InMemoryMessagingNetwork.kt b/test-utils/src/main/kotlin/net/corda/testing/node/InMemoryMessagingNetwork.kt index 80b228b120..cacf5d33ba 100644 --- a/test-utils/src/main/kotlin/net/corda/testing/node/InMemoryMessagingNetwork.kt +++ b/test-utils/src/main/kotlin/net/corda/testing/node/InMemoryMessagingNetwork.kt @@ -9,7 +9,6 @@ import net.corda.core.getOrThrow import net.corda.core.messaging.* import net.corda.core.node.ServiceEntry import net.corda.core.node.services.PartyInfo -import net.corda.core.node.services.ServiceInfo import net.corda.core.serialization.SingletonSerializeAsToken import net.corda.core.utilities.trace import net.corda.node.services.api.MessagingServiceBuilder @@ -42,7 +41,10 @@ import kotlin.concurrent.thread * @param random The RNG used to choose which node to send to in case one sends to a service. */ @ThreadSafe -class InMemoryMessagingNetwork(val sendManuallyPumped: Boolean, val random: SplittableRandom = SplittableRandom()) : SingletonSerializeAsToken() { +class InMemoryMessagingNetwork( + val sendManuallyPumped: Boolean, + val servicePeerAllocationStrategy: ServicePeerAllocationStrategy = InMemoryMessagingNetwork.ServicePeerAllocationStrategy.Random() +) : SingletonSerializeAsToken() { companion object { val MESSAGES_LOG_NAME = "messages" private val log = LoggerFactory.getLogger(MESSAGES_LOG_NAME) @@ -72,7 +74,7 @@ class InMemoryMessagingNetwork(val sendManuallyPumped: Boolean, val random: Spli private val messageReceiveQueues = HashMap>() private val _receivedMessages = PublishSubject.create() - private val serviceToPeersMapping = HashMap>() + private val serviceToPeersMapping = HashMap>() val messagesInFlight = ReusableLatch() @@ -181,7 +183,7 @@ class InMemoryMessagingNetwork(val sendManuallyPumped: Boolean, val random: Spli val node = InMemoryMessaging(manuallyPumped, id, executor, database) handleEndpointMap[id] = node serviceHandles.forEach { - serviceToPeersMapping.getOrPut(it) { HashSet() }.add(id) + serviceToPeersMapping.getOrPut(it) { LinkedHashSet() }.add(id) Unit } return Futures.immediateFuture(node) @@ -189,7 +191,7 @@ class InMemoryMessagingNetwork(val sendManuallyPumped: Boolean, val random: Spli } } - class PeerHandle(val id: Int, val description: String) : SingleMessageRecipient { + data class PeerHandle(val id: Int, val description: String) : SingleMessageRecipient { override fun toString() = description override fun equals(other: Any?) = other is PeerHandle && other.id == id override fun hashCode() = id.hashCode() @@ -199,6 +201,27 @@ class InMemoryMessagingNetwork(val sendManuallyPumped: Boolean, val random: Spli override fun toString() = "Service($service)" } + /** + * Mock service loadbalancing + */ + sealed class ServicePeerAllocationStrategy { + abstract fun pickNext(service: ServiceHandle, pickFrom: List): A + class Random(val random: SplittableRandom = SplittableRandom()) : ServicePeerAllocationStrategy() { + override fun pickNext(service: ServiceHandle, pickFrom: List): A { + return pickFrom[random.nextInt(pickFrom.size)] + } + } + class RoundRobin : ServicePeerAllocationStrategy() { + val previousPicks = HashMap() + override fun pickNext(service: ServiceHandle, pickFrom: List): A { + val nextIndex = previousPicks.compute(service) { _key, previous -> + (previous?.plus(1) ?: 0) % pickFrom.size + } + return pickFrom[nextIndex] + } + } + } + // If block is set to true this function will only return once a message has been pushed onto the recipients' queues fun pumpSend(block: Boolean): MessageTransfer? { val transfer = (if (block) messageSendQueue.take() else messageSendQueue.poll()) ?: return null @@ -227,8 +250,8 @@ class InMemoryMessagingNetwork(val sendManuallyPumped: Boolean, val random: Spli is PeerHandle -> getQueueForPeerHandle(transfer.recipients).add(transfer) is ServiceHandle -> { val queues = getQueuesForServiceHandle(transfer.recipients) - val chosedPeerIndex = random.nextInt(queues.size) - queues[chosedPeerIndex].add(transfer) + val queue = servicePeerAllocationStrategy.pickNext(transfer.recipients, queues) + queue.add(transfer) } is AllPossibleRecipients -> { // This means all possible recipients _that the network knows about at the time_, not literally everyone diff --git a/test-utils/src/main/kotlin/net/corda/testing/node/MockNode.kt b/test-utils/src/main/kotlin/net/corda/testing/node/MockNode.kt index ef85a0efc7..83b5e6c413 100644 --- a/test-utils/src/main/kotlin/net/corda/testing/node/MockNode.kt +++ b/test-utils/src/main/kotlin/net/corda/testing/node/MockNode.kt @@ -47,10 +47,12 @@ import java.util.concurrent.atomic.AtomicInteger */ class MockNetwork(private val networkSendManuallyPumped: Boolean = false, private val threadPerNode: Boolean = false, + private val servicePeerAllocationStrategy: InMemoryMessagingNetwork.ServicePeerAllocationStrategy = + InMemoryMessagingNetwork.ServicePeerAllocationStrategy.Random(), private val defaultFactory: Factory = MockNetwork.DefaultFactory) { private var nextNodeId = 0 val filesystem: FileSystem = Jimfs.newFileSystem(unix()) - val messagingNetwork = InMemoryMessagingNetwork(networkSendManuallyPumped) + val messagingNetwork = InMemoryMessagingNetwork(networkSendManuallyPumped, servicePeerAllocationStrategy) // A unique identifier for this network to segregate databases with the same nodeID but different networks. private val networkId = random63BitValue() @@ -268,8 +270,8 @@ class MockNetwork(private val networkSendManuallyPumped: Boolean = false, return BasketOfNodes(nodes, notaryNode, mapNode) } - fun createNotaryNode(legalName: String? = null, keyPair: KeyPair? = null): MockNode { - return createNode(null, -1, defaultFactory, true, legalName, keyPair, ServiceInfo(NetworkMapService.type), ServiceInfo(ValidatingNotaryService.type)) + fun createNotaryNode(networkMapAddr: SingleMessageRecipient? = null, legalName: String? = null, keyPair: KeyPair? = null, serviceName: String? = null): MockNode { + return createNode(networkMapAddr, -1, defaultFactory, true, legalName, keyPair, ServiceInfo(NetworkMapService.type), ServiceInfo(ValidatingNotaryService.type, serviceName)) } fun createPartyNode(networkMapAddr: SingleMessageRecipient, legalName: String? = null, keyPair: KeyPair? = null): MockNode {