Add SMM test for round robin node picking

This commit is contained in:
Andras Slemmer
2016-12-13 11:25:33 +00:00
committed by exfalso
parent fd436b0cdc
commit 7ee88b6ec8
6 changed files with 125 additions and 22 deletions

View File

@ -6,7 +6,6 @@ import com.google.common.util.concurrent.SettableFuture
import net.corda.core.bufferUntilSubscribed import net.corda.core.bufferUntilSubscribed
import net.corda.core.crypto.Party import net.corda.core.crypto.Party
import net.corda.core.map import net.corda.core.map
import net.corda.core.messaging.MessageRecipients
import net.corda.core.messaging.MessagingService import net.corda.core.messaging.MessagingService
import net.corda.core.messaging.SingleMessageRecipient import net.corda.core.messaging.SingleMessageRecipient
import net.corda.core.messaging.createMessage import net.corda.core.messaging.createMessage

View File

@ -85,7 +85,7 @@ class TwoPartyTradeFlowTests {
net = MockNetwork(false, true) net = MockNetwork(false, true)
ledger { ledger {
notaryNode = net.createNotaryNode(DUMMY_NOTARY.name, DUMMY_NOTARY_KEY) notaryNode = net.createNotaryNode(null, DUMMY_NOTARY.name, DUMMY_NOTARY_KEY)
aliceNode = net.createPartyNode(notaryNode.info.address, ALICE.name, ALICE_KEY) aliceNode = net.createPartyNode(notaryNode.info.address, ALICE.name, ALICE_KEY)
bobNode = net.createPartyNode(notaryNode.info.address, BOB.name, BOB_KEY) bobNode = net.createPartyNode(notaryNode.info.address, BOB.name, BOB_KEY)
val aliceKey = aliceNode.services.legalIdentityKey val aliceKey = aliceNode.services.legalIdentityKey
@ -125,7 +125,7 @@ class TwoPartyTradeFlowTests {
@Test @Test
fun `shutdown and restore`() { fun `shutdown and restore`() {
ledger { ledger {
notaryNode = net.createNotaryNode(DUMMY_NOTARY.name, DUMMY_NOTARY_KEY) notaryNode = net.createNotaryNode(null, DUMMY_NOTARY.name, DUMMY_NOTARY_KEY)
aliceNode = net.createPartyNode(notaryNode.info.address, ALICE.name, ALICE_KEY) aliceNode = net.createPartyNode(notaryNode.info.address, ALICE.name, ALICE_KEY)
bobNode = net.createPartyNode(notaryNode.info.address, BOB.name, BOB_KEY) bobNode = net.createPartyNode(notaryNode.info.address, BOB.name, BOB_KEY)
aliceNode.disableDBCloseOnStop() aliceNode.disableDBCloseOnStop()
@ -235,7 +235,7 @@ class TwoPartyTradeFlowTests {
@Test @Test
fun `check dependencies of sale asset are resolved`() { fun `check dependencies of sale asset are resolved`() {
notaryNode = net.createNotaryNode(DUMMY_NOTARY.name, DUMMY_NOTARY_KEY) notaryNode = net.createNotaryNode(null, DUMMY_NOTARY.name, DUMMY_NOTARY_KEY)
aliceNode = makeNodeWithTracking(notaryNode.info.address, ALICE.name, ALICE_KEY) aliceNode = makeNodeWithTracking(notaryNode.info.address, ALICE.name, ALICE_KEY)
bobNode = makeNodeWithTracking(notaryNode.info.address, BOB.name, BOB_KEY) bobNode = makeNodeWithTracking(notaryNode.info.address, BOB.name, BOB_KEY)
val aliceKey = aliceNode.services.legalIdentityKey val aliceKey = aliceNode.services.legalIdentityKey
@ -327,7 +327,7 @@ class TwoPartyTradeFlowTests {
@Test @Test
fun `track() works`() { fun `track() works`() {
notaryNode = net.createNotaryNode(DUMMY_NOTARY.name, DUMMY_NOTARY_KEY) notaryNode = net.createNotaryNode(null, DUMMY_NOTARY.name, DUMMY_NOTARY_KEY)
aliceNode = makeNodeWithTracking(notaryNode.info.address, ALICE.name, ALICE_KEY) aliceNode = makeNodeWithTracking(notaryNode.info.address, ALICE.name, ALICE_KEY)
bobNode = makeNodeWithTracking(notaryNode.info.address, BOB.name, BOB_KEY) bobNode = makeNodeWithTracking(notaryNode.info.address, BOB.name, BOB_KEY)
val aliceKey = aliceNode.services.legalIdentityKey val aliceKey = aliceNode.services.legalIdentityKey
@ -427,7 +427,7 @@ class TwoPartyTradeFlowTests {
aliceError: Boolean, aliceError: Boolean,
expectedMessageSubstring: String expectedMessageSubstring: String
) { ) {
notaryNode = net.createNotaryNode(DUMMY_NOTARY.name, DUMMY_NOTARY_KEY) notaryNode = net.createNotaryNode(null, DUMMY_NOTARY.name, DUMMY_NOTARY_KEY)
aliceNode = net.createPartyNode(notaryNode.info.address, ALICE.name, ALICE_KEY) aliceNode = net.createPartyNode(notaryNode.info.address, ALICE.name, ALICE_KEY)
bobNode = net.createPartyNode(notaryNode.info.address, BOB.name, BOB_KEY) bobNode = net.createPartyNode(notaryNode.info.address, BOB.name, BOB_KEY)
val aliceKey = aliceNode.services.legalIdentityKey val aliceKey = aliceNode.services.legalIdentityKey

View File

@ -3,20 +3,30 @@ package net.corda.node.services.statemachine
import co.paralleluniverse.fibers.Fiber import co.paralleluniverse.fibers.Fiber
import co.paralleluniverse.fibers.Suspendable import co.paralleluniverse.fibers.Suspendable
import com.google.common.util.concurrent.ListenableFuture import com.google.common.util.concurrent.ListenableFuture
import net.corda.core.contracts.DOLLARS
import net.corda.core.contracts.issuedBy
import net.corda.core.crypto.Party import net.corda.core.crypto.Party
import net.corda.core.crypto.generateKeyPair
import net.corda.core.flows.FlowLogic import net.corda.core.flows.FlowLogic
import net.corda.core.flows.FlowSessionException import net.corda.core.flows.FlowSessionException
import net.corda.core.getOrThrow import net.corda.core.getOrThrow
import net.corda.core.random63BitValue import net.corda.core.random63BitValue
import net.corda.core.serialization.OpaqueBytes
import net.corda.core.serialization.deserialize import net.corda.core.serialization.deserialize
import net.corda.flows.CashCommand
import net.corda.flows.CashFlow
import net.corda.flows.NotaryFlow
import net.corda.node.services.persistence.checkpoints import net.corda.node.services.persistence.checkpoints
import net.corda.node.services.statemachine.StateMachineManager.* import net.corda.node.services.statemachine.StateMachineManager.*
import net.corda.node.utilities.databaseTransaction import net.corda.node.utilities.databaseTransaction
import net.corda.testing.expect
import net.corda.testing.expectEvents
import net.corda.testing.initiateSingleShotFlow import net.corda.testing.initiateSingleShotFlow
import net.corda.testing.node.InMemoryMessagingNetwork import net.corda.testing.node.InMemoryMessagingNetwork
import net.corda.testing.node.InMemoryMessagingNetwork.MessageTransfer import net.corda.testing.node.InMemoryMessagingNetwork.MessageTransfer
import net.corda.testing.node.MockNetwork import net.corda.testing.node.MockNetwork
import net.corda.testing.node.MockNetwork.MockNode import net.corda.testing.node.MockNetwork.MockNode
import net.corda.testing.sequence
import org.assertj.core.api.Assertions.assertThat import org.assertj.core.api.Assertions.assertThat
import org.assertj.core.api.Assertions.assertThatThrownBy import org.assertj.core.api.Assertions.assertThatThrownBy
import org.junit.After import org.junit.After
@ -30,16 +40,24 @@ import kotlin.test.assertTrue
class StateMachineManagerTests { class StateMachineManagerTests {
private val net = MockNetwork() private val net = MockNetwork(servicePeerAllocationStrategy = InMemoryMessagingNetwork.ServicePeerAllocationStrategy.RoundRobin())
private val sessionTransfers = ArrayList<SessionTransfer>() private val sessionTransfers = ArrayList<SessionTransfer>()
private lateinit var node1: MockNode private lateinit var node1: MockNode
private lateinit var node2: MockNode private lateinit var node2: MockNode
private lateinit var notary1: MockNode
private lateinit var notary2: MockNode
@Before @Before
fun start() { fun start() {
val nodes = net.createTwoNodes() val nodes = net.createTwoNodes()
node1 = nodes.first node1 = nodes.first
node2 = nodes.second node2 = nodes.second
val notaryKeyPair = generateKeyPair()
// Note that these notaries don't operate correctly as they don's share their state. They are only used for testing
// service addressing.
notary1 = net.createNotaryNode(networkMapAddr = node1.services.myInfo.address, keyPair = notaryKeyPair, serviceName = "notary-service-2000")
notary2 = net.createNotaryNode(networkMapAddr = node1.services.myInfo.address, keyPair = notaryKeyPair, serviceName = "notary-service-2000")
net.messagingNetwork.receivedMessages.toSessionTransfers().forEach { sessionTransfers += it } net.messagingNetwork.receivedMessages.toSessionTransfers().forEach { sessionTransfers += it }
net.runNetwork() net.runNetwork()
} }
@ -260,6 +278,57 @@ class StateMachineManagerTests {
) )
} }
@Test
fun `different notaries are picked when addressing shared notary identity`() {
assertEquals(notary1.info.notaryIdentity, notary2.info.notaryIdentity)
node1.services.startFlow(CashFlow(CashCommand.IssueCash(
DOLLARS(2000),
OpaqueBytes.of(0x01),
node1.info.legalIdentity,
notary1.info.notaryIdentity)))
// We pay a couple of times, the notary picking should go round robin
for (i in 1 .. 3) {
node1.services.startFlow(CashFlow(CashCommand.PayCash(
DOLLARS(500).issuedBy(node1.info.legalIdentity.ref(0x01)),
node2.info.legalIdentity)))
net.runNetwork()
}
sessionTransfers.expectEvents(isStrict = false) {
sequence(
// First Pay
expect(match = { it.message is SessionInit && it.message.flowName == NotaryFlow.Client::class.java.name }) {
it.message as SessionInit
require(it.from == node1.id)
require(it.to == TransferRecipient.Service(notary1.info.notaryIdentity))
},
expect(match = { it.message is SessionConfirm }) {
it.message as SessionConfirm
require(it.from == notary1.id)
},
// Second pay
expect(match = { it.message is SessionInit && it.message.flowName == NotaryFlow.Client::class.java.name }) {
it.message as SessionInit
require(it.from == node1.id)
require(it.to == TransferRecipient.Service(notary1.info.notaryIdentity))
},
expect(match = { it.message is SessionConfirm }) {
it.message as SessionConfirm
require(it.from == notary2.id)
},
// Third pay
expect(match = { it.message is SessionInit && it.message.flowName == NotaryFlow.Client::class.java.name }) {
it.message as SessionInit
require(it.from == node1.id)
require(it.to == TransferRecipient.Service(notary1.info.notaryIdentity))
},
expect(match = { it.message is SessionConfirm }) {
it.message as SessionConfirm
require(it.from == notary1.id)
}
)
}
}
@Test @Test
fun `exception thrown on other side`() { fun `exception thrown on other side`() {
node2.services.registerFlowInitiator(ReceiveThenSuspendFlow::class) { ExceptionFlow } node2.services.registerFlowInitiator(ReceiveThenSuspendFlow::class) { ExceptionFlow }
@ -301,11 +370,16 @@ class StateMachineManagerTests {
} }
private fun assertSessionTransfers(node: MockNode, vararg expected: SessionTransfer) { private fun assertSessionTransfers(node: MockNode, vararg expected: SessionTransfer) {
val actualForNode = sessionTransfers.filter { it.from == node.id || it.to == node.id } val actualForNode = sessionTransfers.filter { it.from == node.id || it.to == TransferRecipient.Peer(node.id) }
assertThat(actualForNode).containsExactly(*expected) assertThat(actualForNode).containsExactly(*expected)
} }
private data class SessionTransfer(val from: Int, val message: SessionMessage, val to: Int) { private interface TransferRecipient {
data class Peer(val id: Int) : TransferRecipient
data class Service(val identity: Party) : TransferRecipient
}
private data class SessionTransfer(val from: Int, val message: SessionMessage, val to: TransferRecipient) {
val isPayloadTransfer: Boolean get() = message is SessionData || message is SessionInit && message.firstPayload != null val isPayloadTransfer: Boolean get() = message is SessionData || message is SessionInit && message.firstPayload != null
override fun toString(): String = "$from sent $message to $to" override fun toString(): String = "$from sent $message to $to"
} }
@ -314,7 +388,12 @@ class StateMachineManagerTests {
return filter { it.message.topicSession == StateMachineManager.sessionTopic }.map { return filter { it.message.topicSession == StateMachineManager.sessionTopic }.map {
val from = it.sender.id val from = it.sender.id
val message = it.message.data.deserialize<SessionMessage>() val message = it.message.data.deserialize<SessionMessage>()
val to = (it.recipients as InMemoryMessagingNetwork.PeerHandle).id val recipients = it.recipients
val to = when (recipients) {
is InMemoryMessagingNetwork.PeerHandle -> TransferRecipient.Peer(recipients.id)
is InMemoryMessagingNetwork.ServiceHandle -> TransferRecipient.Service(recipients.service.identity)
else -> throw IllegalStateException("Unknown recipients $recipients")
}
SessionTransfer(from, sanitise(message), to) SessionTransfer(from, sanitise(message), to)
} }
} }
@ -330,7 +409,7 @@ class StateMachineManagerTests {
} }
private infix fun MockNode.sent(message: SessionMessage): Pair<Int, SessionMessage> = Pair(id, message) private infix fun MockNode.sent(message: SessionMessage): Pair<Int, SessionMessage> = Pair(id, message)
private infix fun Pair<Int, SessionMessage>.to(node: MockNode): SessionTransfer = SessionTransfer(first, second, node.id) private infix fun Pair<Int, SessionMessage>.to(node: MockNode): SessionTransfer = SessionTransfer(first, second, TransferRecipient.Peer(node.id))
private class NoOpFlow(val nonTerminating: Boolean = false) : FlowLogic<Unit>() { private class NoOpFlow(val nonTerminating: Boolean = false) : FlowLogic<Unit>() {

View File

@ -34,7 +34,7 @@ class IssuerFlowTest {
fun `test issuer flow`() { fun `test issuer flow`() {
net = MockNetwork(false, true) net = MockNetwork(false, true)
ledger { ledger {
notaryNode = net.createNotaryNode(DUMMY_NOTARY.name, DUMMY_NOTARY_KEY) notaryNode = net.createNotaryNode(null, DUMMY_NOTARY.name, DUMMY_NOTARY_KEY)
bankOfCordaNode = net.createPartyNode(notaryNode.info.address, BOC_ISSUER_PARTY.name, BOC_KEY) bankOfCordaNode = net.createPartyNode(notaryNode.info.address, BOC_ISSUER_PARTY.name, BOC_KEY)
bankClientNode = net.createPartyNode(notaryNode.info.address, MEGA_CORP.name, MEGA_CORP_KEY) bankClientNode = net.createPartyNode(notaryNode.info.address, MEGA_CORP.name, MEGA_CORP_KEY)

View File

@ -9,7 +9,6 @@ import net.corda.core.getOrThrow
import net.corda.core.messaging.* import net.corda.core.messaging.*
import net.corda.core.node.ServiceEntry import net.corda.core.node.ServiceEntry
import net.corda.core.node.services.PartyInfo import net.corda.core.node.services.PartyInfo
import net.corda.core.node.services.ServiceInfo
import net.corda.core.serialization.SingletonSerializeAsToken import net.corda.core.serialization.SingletonSerializeAsToken
import net.corda.core.utilities.trace import net.corda.core.utilities.trace
import net.corda.node.services.api.MessagingServiceBuilder import net.corda.node.services.api.MessagingServiceBuilder
@ -42,7 +41,10 @@ import kotlin.concurrent.thread
* @param random The RNG used to choose which node to send to in case one sends to a service. * @param random The RNG used to choose which node to send to in case one sends to a service.
*/ */
@ThreadSafe @ThreadSafe
class InMemoryMessagingNetwork(val sendManuallyPumped: Boolean, val random: SplittableRandom = SplittableRandom()) : SingletonSerializeAsToken() { class InMemoryMessagingNetwork(
val sendManuallyPumped: Boolean,
val servicePeerAllocationStrategy: ServicePeerAllocationStrategy = InMemoryMessagingNetwork.ServicePeerAllocationStrategy.Random()
) : SingletonSerializeAsToken() {
companion object { companion object {
val MESSAGES_LOG_NAME = "messages" val MESSAGES_LOG_NAME = "messages"
private val log = LoggerFactory.getLogger(MESSAGES_LOG_NAME) private val log = LoggerFactory.getLogger(MESSAGES_LOG_NAME)
@ -72,7 +74,7 @@ class InMemoryMessagingNetwork(val sendManuallyPumped: Boolean, val random: Spli
private val messageReceiveQueues = HashMap<PeerHandle, LinkedBlockingQueue<MessageTransfer>>() private val messageReceiveQueues = HashMap<PeerHandle, LinkedBlockingQueue<MessageTransfer>>()
private val _receivedMessages = PublishSubject.create<MessageTransfer>() private val _receivedMessages = PublishSubject.create<MessageTransfer>()
private val serviceToPeersMapping = HashMap<ServiceHandle, HashSet<PeerHandle>>() private val serviceToPeersMapping = HashMap<ServiceHandle, LinkedHashSet<PeerHandle>>()
val messagesInFlight = ReusableLatch() val messagesInFlight = ReusableLatch()
@ -181,7 +183,7 @@ class InMemoryMessagingNetwork(val sendManuallyPumped: Boolean, val random: Spli
val node = InMemoryMessaging(manuallyPumped, id, executor, database) val node = InMemoryMessaging(manuallyPumped, id, executor, database)
handleEndpointMap[id] = node handleEndpointMap[id] = node
serviceHandles.forEach { serviceHandles.forEach {
serviceToPeersMapping.getOrPut(it) { HashSet<PeerHandle>() }.add(id) serviceToPeersMapping.getOrPut(it) { LinkedHashSet<PeerHandle>() }.add(id)
Unit Unit
} }
return Futures.immediateFuture(node) return Futures.immediateFuture(node)
@ -189,7 +191,7 @@ class InMemoryMessagingNetwork(val sendManuallyPumped: Boolean, val random: Spli
} }
} }
class PeerHandle(val id: Int, val description: String) : SingleMessageRecipient { data class PeerHandle(val id: Int, val description: String) : SingleMessageRecipient {
override fun toString() = description override fun toString() = description
override fun equals(other: Any?) = other is PeerHandle && other.id == id override fun equals(other: Any?) = other is PeerHandle && other.id == id
override fun hashCode() = id.hashCode() override fun hashCode() = id.hashCode()
@ -199,6 +201,27 @@ class InMemoryMessagingNetwork(val sendManuallyPumped: Boolean, val random: Spli
override fun toString() = "Service($service)" override fun toString() = "Service($service)"
} }
/**
* Mock service loadbalancing
*/
sealed class ServicePeerAllocationStrategy {
abstract fun <A> pickNext(service: ServiceHandle, pickFrom: List<A>): A
class Random(val random: SplittableRandom = SplittableRandom()) : ServicePeerAllocationStrategy() {
override fun <A> pickNext(service: ServiceHandle, pickFrom: List<A>): A {
return pickFrom[random.nextInt(pickFrom.size)]
}
}
class RoundRobin : ServicePeerAllocationStrategy() {
val previousPicks = HashMap<ServiceHandle, Int>()
override fun <A> pickNext(service: ServiceHandle, pickFrom: List<A>): A {
val nextIndex = previousPicks.compute(service) { _key, previous ->
(previous?.plus(1) ?: 0) % pickFrom.size
}
return pickFrom[nextIndex]
}
}
}
// If block is set to true this function will only return once a message has been pushed onto the recipients' queues // If block is set to true this function will only return once a message has been pushed onto the recipients' queues
fun pumpSend(block: Boolean): MessageTransfer? { fun pumpSend(block: Boolean): MessageTransfer? {
val transfer = (if (block) messageSendQueue.take() else messageSendQueue.poll()) ?: return null val transfer = (if (block) messageSendQueue.take() else messageSendQueue.poll()) ?: return null
@ -227,8 +250,8 @@ class InMemoryMessagingNetwork(val sendManuallyPumped: Boolean, val random: Spli
is PeerHandle -> getQueueForPeerHandle(transfer.recipients).add(transfer) is PeerHandle -> getQueueForPeerHandle(transfer.recipients).add(transfer)
is ServiceHandle -> { is ServiceHandle -> {
val queues = getQueuesForServiceHandle(transfer.recipients) val queues = getQueuesForServiceHandle(transfer.recipients)
val chosedPeerIndex = random.nextInt(queues.size) val queue = servicePeerAllocationStrategy.pickNext(transfer.recipients, queues)
queues[chosedPeerIndex].add(transfer) queue.add(transfer)
} }
is AllPossibleRecipients -> { is AllPossibleRecipients -> {
// This means all possible recipients _that the network knows about at the time_, not literally everyone // This means all possible recipients _that the network knows about at the time_, not literally everyone

View File

@ -47,10 +47,12 @@ import java.util.concurrent.atomic.AtomicInteger
*/ */
class MockNetwork(private val networkSendManuallyPumped: Boolean = false, class MockNetwork(private val networkSendManuallyPumped: Boolean = false,
private val threadPerNode: Boolean = false, private val threadPerNode: Boolean = false,
private val servicePeerAllocationStrategy: InMemoryMessagingNetwork.ServicePeerAllocationStrategy =
InMemoryMessagingNetwork.ServicePeerAllocationStrategy.Random(),
private val defaultFactory: Factory = MockNetwork.DefaultFactory) { private val defaultFactory: Factory = MockNetwork.DefaultFactory) {
private var nextNodeId = 0 private var nextNodeId = 0
val filesystem: FileSystem = Jimfs.newFileSystem(unix()) val filesystem: FileSystem = Jimfs.newFileSystem(unix())
val messagingNetwork = InMemoryMessagingNetwork(networkSendManuallyPumped) val messagingNetwork = InMemoryMessagingNetwork(networkSendManuallyPumped, servicePeerAllocationStrategy)
// A unique identifier for this network to segregate databases with the same nodeID but different networks. // A unique identifier for this network to segregate databases with the same nodeID but different networks.
private val networkId = random63BitValue() private val networkId = random63BitValue()
@ -268,8 +270,8 @@ class MockNetwork(private val networkSendManuallyPumped: Boolean = false,
return BasketOfNodes(nodes, notaryNode, mapNode) return BasketOfNodes(nodes, notaryNode, mapNode)
} }
fun createNotaryNode(legalName: String? = null, keyPair: KeyPair? = null): MockNode { fun createNotaryNode(networkMapAddr: SingleMessageRecipient? = null, legalName: String? = null, keyPair: KeyPair? = null, serviceName: String? = null): MockNode {
return createNode(null, -1, defaultFactory, true, legalName, keyPair, ServiceInfo(NetworkMapService.type), ServiceInfo(ValidatingNotaryService.type)) return createNode(networkMapAddr, -1, defaultFactory, true, legalName, keyPair, ServiceInfo(NetworkMapService.type), ServiceInfo(ValidatingNotaryService.type, serviceName))
} }
fun createPartyNode(networkMapAddr: SingleMessageRecipient, legalName: String? = null, keyPair: KeyPair? = null): MockNode { fun createPartyNode(networkMapAddr: SingleMessageRecipient, legalName: String? = null, keyPair: KeyPair? = null): MockNode {