Removed session IDs from the send and receive methods of ProtocolLogic and are now partially managed by HandshakeMessage

This commit is contained in:
Shams Asari
2016-09-13 17:37:42 +01:00
parent f314bab6c8
commit 8ea20dd0d2
32 changed files with 539 additions and 519 deletions

View File

@ -1,12 +1,11 @@
package com.r3corda.node.services
import com.r3corda.core.messaging.Ack
import com.r3corda.core.node.CordaPluginRegistry
import com.r3corda.node.services.api.AbstractNodeService
import com.r3corda.node.services.api.ServiceHubInternal
import com.r3corda.protocols.AbstractStateReplacementProtocol
import com.r3corda.protocols.NotaryChangeProtocol
import com.r3corda.protocols.NotaryChangeProtocol.TOPIC
object NotaryChange {
class Plugin : CordaPluginRegistry() {
@ -19,18 +18,9 @@ object NotaryChange {
*/
class Service(services: ServiceHubInternal) : AbstractNodeService(services) {
init {
addMessageHandler(NotaryChangeProtocol.TOPIC,
{ req: AbstractStateReplacementProtocol.Handshake -> handleChangeNotaryRequest(req) }
)
}
private fun handleChangeNotaryRequest(req: AbstractStateReplacementProtocol.Handshake): Ack {
val protocol = NotaryChangeProtocol.Acceptor(
req.replyToParty,
req.sessionID,
req.sessionIdForSend)
services.startProtocol(NotaryChangeProtocol.TOPIC, protocol)
return Ack
addProtocolHandler(TOPIC, TOPIC) { req: AbstractStateReplacementProtocol.Handshake ->
NotaryChangeProtocol.Acceptor(req.replyToParty)
}
}
}
}

View File

@ -1,10 +1,14 @@
package com.r3corda.node.services.api
import com.google.common.util.concurrent.ListenableFuture
import com.r3corda.core.messaging.Message
import com.r3corda.core.node.services.DEFAULT_SESSION_ID
import com.r3corda.core.protocols.ProtocolLogic
import com.r3corda.core.serialization.SingletonSerializeAsToken
import com.r3corda.core.serialization.deserialize
import com.r3corda.core.serialization.serialize
import com.r3corda.core.utilities.loggerFor
import com.r3corda.protocols.HandshakeMessage
import com.r3corda.protocols.ServiceRequestMessage
import javax.annotation.concurrent.ThreadSafe
@ -14,6 +18,10 @@ import javax.annotation.concurrent.ThreadSafe
@ThreadSafe
abstract class AbstractNodeService(val services: ServiceHubInternal) : SingletonSerializeAsToken() {
companion object {
val logger = loggerFor<AbstractNodeService>()
}
val net: MessagingServiceInternal get() = services.networkService
/**
@ -57,4 +65,37 @@ abstract class AbstractNodeService(val services: ServiceHubInternal) : Singleton
crossinline handler: (Q) -> R) {
addMessageHandler(topic, handler, { message: Message, exception: Exception -> throw exception })
}
/**
* Register a handler to kick-off a protocol when a [HandshakeMessage] is received by the node. This performs the
* necessary steps to enable communication between the two protocols, including calling ProtocolLogic.registerSession.
* @param topic the topic on which the handshake is sent from the other party
* @param loggerName the logger name to use when starting the protocol
* @param protocolFactory a function to create the protocol with the given handshake message
* @param onResultFuture provides access to the [ListenableFuture] when the protocol starts
*/
protected inline fun <reified H : HandshakeMessage, R : Any> addProtocolHandler(
topic: String,
loggerName: String,
crossinline protocolFactory: (H) -> ProtocolLogic<R>,
crossinline onResultFuture: (ListenableFuture<R>, H) -> Unit) {
net.addMessageHandler(topic, DEFAULT_SESSION_ID, null) { message, reg ->
try {
val handshake = message.data.deserialize<H>()
val protocol = protocolFactory(handshake)
protocol.registerSession(handshake)
val resultFuture = services.startProtocol(loggerName, protocol)
onResultFuture(resultFuture, handshake)
} catch (e: Exception) {
logger.error("Unable to process ${H::class.java.name} message", e)
}
}
}
protected inline fun <reified H : HandshakeMessage, R : Any> addProtocolHandler(
topic: String,
loggerName: String,
crossinline protocolFactory: (H) -> ProtocolLogic<R>) {
addProtocolHandler(topic, loggerName, protocolFactory, { future, handshake -> })
}
}

View File

@ -1,10 +1,11 @@
package com.r3corda.node.services.clientapi
import com.r3corda.core.node.CordaPluginRegistry
import com.r3corda.core.node.services.DEFAULT_SESSION_ID
import com.r3corda.core.serialization.deserialize
import com.r3corda.node.services.api.AbstractNodeService
import com.r3corda.node.services.api.ServiceHubInternal
import com.r3corda.protocols.TwoPartyDealProtocol
import com.r3corda.protocols.TwoPartyDealProtocol.FIX_INITIATE_TOPIC
import com.r3corda.protocols.TwoPartyDealProtocol.FixingSessionInitiation
/**
* This is a temporary handler required for establishing random sessionIDs for the [Fixer] and [Floater] as part of
@ -17,12 +18,10 @@ object FixingSessionInitiation {
override val servicePlugins: List<Class<*>> = listOf(Service::class.java)
}
class Service(services: ServiceHubInternal) {
class Service(services: ServiceHubInternal) : AbstractNodeService(services) {
init {
services.networkService.addMessageHandler(TwoPartyDealProtocol.FIX_INITIATE_TOPIC, DEFAULT_SESSION_ID) { msg, registration ->
val initiation = msg.data.deserialize<TwoPartyDealProtocol.FixingSessionInitiation>()
val protocol = TwoPartyDealProtocol.Fixer(initiation)
services.startProtocol("fixings", protocol)
addProtocolHandler(FIX_INITIATE_TOPIC, "fixings") { initiation: FixingSessionInitiation ->
TwoPartyDealProtocol.Fixer(initiation.replyToParty, initiation.oracleType)
}
}
}

View File

@ -6,7 +6,6 @@ import com.r3corda.core.messaging.MessagingService
import com.r3corda.core.messaging.TopicSession
import com.r3corda.core.node.CordaPluginRegistry
import com.r3corda.core.node.NodeInfo
import com.r3corda.core.random63BitValue
import com.r3corda.core.serialization.serialize
import com.r3corda.core.success
import com.r3corda.core.transactions.SignedTransaction
@ -50,8 +49,7 @@ object DataVending {
myIdentity: Party,
recipient: NodeInfo,
transaction: SignedTransaction) {
val sessionID = random63BitValue()
val msg = BroadcastTransactionProtocol.NotifyTxRequestMessage(transaction, emptySet(), myIdentity, sessionID)
val msg = BroadcastTransactionProtocol.NotifyTxRequestMessage(transaction, emptySet(), myIdentity)
net.send(net.createMessage(TopicSession(BroadcastTransactionProtocol.TOPIC, 0), msg.serialize().bits), recipient.address)
}
}
@ -65,29 +63,29 @@ object DataVending {
{ req: FetchDataProtocol.Request -> handleTXRequest(req) },
{ message, e -> logger.error("Failure processing data vending request.", e) }
)
addMessageHandler(FetchAttachmentsProtocol.TOPIC,
{ req: FetchDataProtocol.Request -> handleAttachmentRequest(req) },
{ message, e -> logger.error("Failure processing data vending request.", e) }
)
addMessageHandler(BroadcastTransactionProtocol.TOPIC,
{ req: BroadcastTransactionProtocol.NotifyTxRequestMessage -> handleTXNotification(req) },
{ message, e -> logger.error("Failure processing data vending request.", e) }
)
}
private fun handleTXNotification(req: BroadcastTransactionProtocol.NotifyTxRequestMessage): Unit {
// TODO: We should have a whitelist of contracts we're willing to accept at all, and reject if the transaction
// includes us in any outside that list. Potentially just if it includes any outside that list at all.
// TODO: Do we want to be able to reject specific transactions on more complex rules, for example reject incoming
// cash without from unknown parties?
services.startProtocol("Resolving transactions", ResolveTransactionsProtocol(req.tx, req.replyToParty))
.success {
services.recordTransactions(req.tx)
}.failure { throwable ->
logger.warn("Received invalid transaction ${req.tx.id} from ${req.replyToParty}", throwable)
}
addProtocolHandler(
BroadcastTransactionProtocol.TOPIC,
"Resolving transactions",
{ req: BroadcastTransactionProtocol.NotifyTxRequestMessage ->
ResolveTransactionsProtocol(req.tx, req.replyToParty)
},
{ future, req ->
future.success {
services.recordTransactions(req.tx)
}.failure { throwable ->
logger.warn("Received invalid transaction ${req.tx.id} from ${req.replyToParty}", throwable)
}
})
}
private fun handleTXRequest(req: FetchDataProtocol.Request): List<SignedTransaction?> {

View File

@ -1,12 +1,12 @@
package com.r3corda.node.services.transactions
import com.r3corda.core.messaging.Ack
import com.r3corda.core.node.services.ServiceType
import com.r3corda.core.node.services.TimestampChecker
import com.r3corda.core.node.services.UniquenessProvider
import com.r3corda.node.services.api.AbstractNodeService
import com.r3corda.node.services.api.ServiceHubInternal
import com.r3corda.protocols.NotaryProtocol
import com.r3corda.protocols.NotaryProtocol.TOPIC
/**
* A Notary service acts as the final signer of a transaction ensuring two things:
@ -30,19 +30,9 @@ abstract class NotaryService(services: ServiceHubInternal,
abstract val protocolFactory: NotaryProtocol.Factory
init {
addMessageHandler(NotaryProtocol.TOPIC,
{ req: NotaryProtocol.Handshake -> processRequest(req) }
)
addProtocolHandler(TOPIC, TOPIC) { req: NotaryProtocol.Handshake ->
protocolFactory.create(req.replyToParty, timestampChecker, uniquenessProvider)
}
}
private fun processRequest(req: NotaryProtocol.Handshake): Ack {
val protocol = protocolFactory.create(
req.replyToParty,
req.sessionID,
req.sendSessionID,
timestampChecker,
uniquenessProvider)
services.startProtocol(NotaryProtocol.TOPIC, protocol)
return Ack
}
}

View File

@ -19,11 +19,9 @@ class ValidatingNotaryService(services: ServiceHubInternal,
override val protocolFactory = object : NotaryProtocol.Factory {
override fun create(otherSide: Party,
sendSessionID: Long,
receiveSessionID: Long,
timestampChecker: TimestampChecker,
uniquenessProvider: UniquenessProvider): NotaryProtocol.Service {
return ValidatingNotaryProtocol(otherSide, sendSessionID, receiveSessionID, timestampChecker, uniquenessProvider)
return ValidatingNotaryProtocol(otherSide, timestampChecker, uniquenessProvider)
}
}
}

View File

@ -1,6 +1,5 @@
package com.r3corda.node.messaging
import com.google.common.util.concurrent.ListenableFuture
import com.r3corda.contracts.CommercialPaper
import com.r3corda.contracts.asset.*
import com.r3corda.contracts.testing.fillWithSomeTestCash
@ -9,27 +8,26 @@ import com.r3corda.core.crypto.Party
import com.r3corda.core.crypto.SecureHash
import com.r3corda.core.days
import com.r3corda.core.messaging.SingleMessageRecipient
import com.r3corda.core.node.NodeInfo
import com.r3corda.core.node.ServiceHub
import com.r3corda.core.node.services.ServiceType
import com.r3corda.core.node.services.TransactionStorage
import com.r3corda.core.node.services.Wallet
import com.r3corda.core.random63BitValue
import com.r3corda.core.transactions.SignedTransaction
import com.r3corda.core.transactions.WireTransaction
import com.r3corda.core.utilities.DUMMY_NOTARY
import com.r3corda.core.utilities.DUMMY_NOTARY_KEY
import com.r3corda.core.utilities.LogHelper
import com.r3corda.core.utilities.TEST_TX_TIME
import com.r3corda.testing.node.MockNetwork
import com.r3corda.node.services.config.NodeConfiguration
import com.r3corda.testing.node.InMemoryMessagingNetwork
import com.r3corda.node.services.persistence.NodeAttachmentService
import com.r3corda.node.services.persistence.PerFileTransactionStorage
import com.r3corda.node.services.persistence.StorageServiceImpl
import com.r3corda.node.services.statemachine.StateMachineManager
import com.r3corda.protocols.TwoPartyTradeProtocol
import com.r3corda.protocols.TwoPartyTradeProtocol.Buyer
import com.r3corda.protocols.TwoPartyTradeProtocol.Seller
import com.r3corda.protocols.TwoPartyTradeProtocol.TOPIC
import com.r3corda.testing.*
import com.r3corda.testing.node.InMemoryMessagingNetwork
import com.r3corda.testing.node.MockNetwork
import org.assertj.core.api.Assertions.assertThat
import org.junit.After
import org.junit.Before
@ -42,6 +40,7 @@ import java.security.KeyPair
import java.security.PublicKey
import java.util.*
import java.util.concurrent.ExecutionException
import java.util.concurrent.Future
import java.util.jar.JarOutputStream
import java.util.zip.ZipEntry
import kotlin.test.assertEquals
@ -56,21 +55,11 @@ import kotlin.test.assertTrue
* We assume that Alice and Bob already found each other via some market, and have agreed the details already.
*/
class TwoPartyTradeProtocolTests {
lateinit var net: MockNetwork
private fun runSeller(smm: StateMachineManager, notary: NodeInfo,
otherSide: Party, assetToSell: StateAndRef<OwnableState>, price: Amount<Currency>,
myKeyPair: KeyPair, buyerSessionID: Long): ListenableFuture<SignedTransaction> {
val seller = TwoPartyTradeProtocol.Seller(otherSide, notary, assetToSell, price, myKeyPair, buyerSessionID)
return smm.add("${TwoPartyTradeProtocol.TOPIC}.seller", seller).resultFuture
}
private fun runBuyer(smm: StateMachineManager, notaryNode: NodeInfo,
otherSide: Party, acceptablePrice: Amount<Currency>, typeToBuy: Class<out OwnableState>,
sessionID: Long): ListenableFuture<SignedTransaction> {
val buyer = TwoPartyTradeProtocol.Buyer(otherSide, notaryNode.identity, acceptablePrice, typeToBuy, sessionID)
return smm.add("${TwoPartyTradeProtocol.TOPIC}.buyer", buyer).resultFuture
}
lateinit var notaryNode: MockNetwork.MockNode
lateinit var aliceNode: MockNetwork.MockNode
lateinit var bobNode: MockNetwork.MockNode
@Before
fun before() {
@ -92,10 +81,9 @@ class TwoPartyTradeProtocolTests {
net = MockNetwork(false, true)
ledger {
val notaryNode = net.createNotaryNode(DUMMY_NOTARY.name, DUMMY_NOTARY_KEY)
val aliceNode = net.createPartyNode(notaryNode.info.address, ALICE.name, ALICE_KEY)
val bobNode = net.createPartyNode(notaryNode.info.address, BOB.name, BOB_KEY)
notaryNode = net.createNotaryNode(DUMMY_NOTARY.name, DUMMY_NOTARY_KEY)
aliceNode = net.createPartyNode(notaryNode.info.address, ALICE.name, ALICE_KEY)
bobNode = net.createPartyNode(notaryNode.info.address, BOB.name, BOB_KEY)
bobNode.services.fillWithSomeTestCash(2000.DOLLARS)
val alicesFakePaper = fillUpForSeller(false, aliceNode.storage.myLegalIdentity.owningKey,
@ -103,26 +91,7 @@ class TwoPartyTradeProtocolTests {
insertFakeTransactions(alicesFakePaper, aliceNode.services, aliceNode.storage.myLegalIdentityKey, notaryNode.storage.myLegalIdentityKey)
val buyerSessionID = random63BitValue()
// We start the Buyer first, as the Seller sends the first message
val bobResult = runBuyer(
bobNode.smm,
notaryNode.info,
aliceNode.info.identity,
1000.DOLLARS,
CommercialPaper.State::class.java,
buyerSessionID
)
val aliceResult = runSeller(
aliceNode.smm,
notaryNode.info,
bobNode.info.identity,
"alice's paper".outputStateAndRef(),
1000.DOLLARS,
ALICE_KEY,
buyerSessionID
)
val (bobResult, aliceResult) = runBuyerAndSeller("alice's paper".outputStateAndRef())
// TODO: Verify that the result was inserted into the transaction database.
// assertEquals(bobResult.get(), aliceNode.storage.validatedTransactions[aliceResult.get().id])
@ -139,9 +108,9 @@ class TwoPartyTradeProtocolTests {
@Test
fun `shutdown and restore`() {
ledger {
val notaryNode = net.createNotaryNode(DUMMY_NOTARY.name, DUMMY_NOTARY_KEY)
val aliceNode = net.createPartyNode(notaryNode.info.address, ALICE.name, ALICE_KEY)
var bobNode = net.createPartyNode(notaryNode.info.address, BOB.name, BOB_KEY)
notaryNode = net.createNotaryNode(DUMMY_NOTARY.name, DUMMY_NOTARY_KEY)
aliceNode = net.createPartyNode(notaryNode.info.address, ALICE.name, ALICE_KEY)
bobNode = net.createPartyNode(notaryNode.info.address, BOB.name, BOB_KEY)
val bobAddr = bobNode.net.myAddress as InMemoryMessagingNetwork.Handle
val networkMapAddr = notaryNode.info.address
@ -153,25 +122,7 @@ class TwoPartyTradeProtocolTests {
1200.DOLLARS `issued by` DUMMY_CASH_ISSUER, null).second
insertFakeTransactions(alicesFakePaper, aliceNode.services, aliceNode.storage.myLegalIdentityKey)
val buyerSessionID = random63BitValue()
val aliceFuture = runSeller(
aliceNode.smm,
notaryNode.info,
bobNode.info.identity,
"alice's paper".outputStateAndRef(),
1000.DOLLARS,
ALICE_KEY,
buyerSessionID
)
runBuyer(
bobNode.smm,
notaryNode.info,
aliceNode.info.identity,
1000.DOLLARS,
CommercialPaper.State::class.java,
buyerSessionID
)
val aliceFuture = runBuyerAndSeller("alice's paper".outputStateAndRef()).second
// Everything is on this thread so we can now step through the protocol one step at a time.
// Seller Alice already sent a message to Buyer Bob. Pump once:
@ -210,7 +161,7 @@ class TwoPartyTradeProtocolTests {
}, true, BOB.name, BOB_KEY)
// Find the future representing the result of this state machine again.
val bobFuture = bobNode.smm.findStateMachines(TwoPartyTradeProtocol.Buyer::class.java).single().second
val bobFuture = bobNode.smm.findStateMachines(Buyer::class.java).single().second
// And off we go again.
net.runNetwork()
@ -218,7 +169,7 @@ class TwoPartyTradeProtocolTests {
// Bob is now finished and has the same transaction as Alice.
assertThat(bobFuture.get()).isEqualTo(aliceFuture.get())
assertThat(bobNode.smm.findStateMachines(TwoPartyTradeProtocol.Buyer::class.java)).isEmpty()
assertThat(bobNode.smm.findStateMachines(Buyer::class.java)).isEmpty()
assertThat(bobNode.checkpointStorage.checkpoints).isEmpty()
assertThat(aliceNode.checkpointStorage.checkpoints).isEmpty()
@ -250,9 +201,9 @@ class TwoPartyTradeProtocolTests {
@Test
fun `check dependencies of sale asset are resolved`() {
val notaryNode = net.createNotaryNode(DUMMY_NOTARY.name, DUMMY_NOTARY_KEY)
val aliceNode = makeNodeWithTracking(notaryNode.info.address, ALICE.name, ALICE_KEY)
val bobNode = makeNodeWithTracking(notaryNode.info.address, BOB.name, BOB_KEY)
notaryNode = net.createNotaryNode(DUMMY_NOTARY.name, DUMMY_NOTARY_KEY)
aliceNode = makeNodeWithTracking(notaryNode.info.address, ALICE.name, ALICE_KEY)
bobNode = makeNodeWithTracking(notaryNode.info.address, BOB.name, BOB_KEY)
ledger(aliceNode.services) {
@ -271,27 +222,9 @@ class TwoPartyTradeProtocolTests {
1200.DOLLARS `issued by` DUMMY_CASH_ISSUER, attachmentID).second
val alicesSignedTxns = insertFakeTransactions(alicesFakePaper, aliceNode.services, aliceNode.storage.myLegalIdentityKey)
val buyerSessionID = random63BitValue()
net.runNetwork() // Clear network map registration messages
runSeller(
aliceNode.smm,
notaryNode.info,
bobNode.info.identity,
"alice's paper".outputStateAndRef(),
1000.DOLLARS,
ALICE_KEY,
buyerSessionID
)
runBuyer(
bobNode.smm,
notaryNode.info,
aliceNode.info.identity,
1000.DOLLARS,
CommercialPaper.State::class.java,
buyerSessionID
)
runBuyerAndSeller("alice's paper".outputStateAndRef())
net.runNetwork()
@ -370,14 +303,25 @@ class TwoPartyTradeProtocolTests {
}
}
private fun runBuyerAndSeller(assetToSell: StateAndRef<OwnableState>) : Pair<Future<SignedTransaction>, Future<SignedTransaction>> {
val buyer = Buyer(aliceNode.info.identity, notaryNode.info.identity, 1000.DOLLARS, CommercialPaper.State::class.java)
val seller = Seller(bobNode.info.identity, notaryNode.info, assetToSell, 1000.DOLLARS, ALICE_KEY)
connectProtocols(buyer, seller)
// We start the Buyer first, as the Seller sends the first message
val buyerResult = bobNode.smm.add("$TOPIC.buyer", buyer).resultFuture
val sellerResult = aliceNode.smm.add("$TOPIC.seller", seller).resultFuture
return Pair(buyerResult, sellerResult)
}
private fun LedgerDSL<TestTransactionDSLInterpreter, TestLedgerDSLInterpreter>.runWithError(
bobError: Boolean,
aliceError: Boolean,
expectedMessageSubstring: String
) {
val notaryNode = net.createNotaryNode(DUMMY_NOTARY.name, DUMMY_NOTARY_KEY)
val aliceNode = net.createPartyNode(notaryNode.info.address, ALICE.name, ALICE_KEY)
val bobNode = net.createPartyNode(notaryNode.info.address, BOB.name, BOB_KEY)
notaryNode = net.createNotaryNode(DUMMY_NOTARY.name, DUMMY_NOTARY_KEY)
aliceNode = net.createPartyNode(notaryNode.info.address, ALICE.name, ALICE_KEY)
bobNode = net.createPartyNode(notaryNode.info.address, BOB.name, BOB_KEY)
val issuer = MEGA_CORP.ref(1, 2, 3)
val bobKey = bobNode.keyManagement.freshKey()
@ -388,27 +332,9 @@ class TwoPartyTradeProtocolTests {
insertFakeTransactions(bobsBadCash, bobNode.services, bobNode.storage.myLegalIdentityKey, bobNode.storage.myLegalIdentityKey)
insertFakeTransactions(alicesFakePaper, aliceNode.services, aliceNode.storage.myLegalIdentityKey)
val buyerSessionID = random63BitValue()
net.runNetwork() // Clear network map registration messages
val aliceResult = runSeller(
aliceNode.smm,
notaryNode.info,
bobNode.info.identity,
"alice's paper".outputStateAndRef(),
1000.DOLLARS,
ALICE_KEY,
buyerSessionID
)
val bobResult = runBuyer(
bobNode.smm,
notaryNode.info,
aliceNode.info.identity,
1000.DOLLARS,
CommercialPaper.State::class.java,
buyerSessionID
)
val (bobResult, aliceResult) = runBuyerAndSeller("alice's paper".outputStateAndRef())
net.runNetwork()

View File

@ -1,18 +1,30 @@
package com.r3corda.node.services
import co.paralleluniverse.fibers.Suspendable
import com.r3corda.core.node.NodeInfo
import com.r3corda.core.protocols.ProtocolLogic
import com.google.common.util.concurrent.ListenableFuture
import com.google.common.util.concurrent.SettableFuture
import com.r3corda.core.map
import com.r3corda.core.messaging.TopicSession
import com.r3corda.core.messaging.runOnNextMessage
import com.r3corda.core.messaging.send
import com.r3corda.core.random63BitValue
import com.r3corda.core.serialization.deserialize
import com.r3corda.node.services.network.InMemoryNetworkMapService
import com.r3corda.node.services.network.NetworkMapService
import com.r3corda.node.services.network.NetworkMapService.*
import com.r3corda.node.services.network.NetworkMapService.Companion.FETCH_PROTOCOL_TOPIC
import com.r3corda.node.services.network.NetworkMapService.Companion.PUSH_ACK_PROTOCOL_TOPIC
import com.r3corda.node.services.network.NetworkMapService.Companion.REGISTER_PROTOCOL_TOPIC
import com.r3corda.node.services.network.NetworkMapService.Companion.SUBSCRIPTION_PROTOCOL_TOPIC
import com.r3corda.node.services.network.NodeRegistration
import com.r3corda.node.utilities.AddOrRemove
import com.r3corda.protocols.ServiceRequestMessage
import com.r3corda.testing.node.MockNetwork
import com.r3corda.testing.node.MockNetwork.MockNode
import org.junit.Before
import org.junit.Test
import java.security.PrivateKey
import java.time.Instant
import java.util.concurrent.Future
import kotlin.test.assertEquals
import kotlin.test.assertNotNull
import kotlin.test.assertNull
@ -36,7 +48,7 @@ class InMemoryNetworkMapServiceTest {
// Confirm the service contains only its own node
assertEquals(1, service.nodes.count())
assertNull(service.processQueryRequest(NetworkMapService.QueryIdentityRequest(registerNode.info.identity, mapServiceNode.info.address, Long.MIN_VALUE)).node)
assertNull(service.processQueryRequest(QueryIdentityRequest(registerNode.info.identity, mapServiceNode.info.address, Long.MIN_VALUE)).node)
// Register the second node
var seq = 1L
@ -44,64 +56,22 @@ class InMemoryNetworkMapServiceTest {
val nodeKey = registerNode.storage.myLegalIdentityKey
val addChange = NodeRegistration(registerNode.info, seq++, AddOrRemove.ADD, expires)
val addWireChange = addChange.toWire(nodeKey.private)
service.processRegistrationChangeRequest(NetworkMapService.RegistrationRequest(addWireChange, mapServiceNode.info.address, Long.MIN_VALUE))
service.processRegistrationChangeRequest(RegistrationRequest(addWireChange, mapServiceNode.info.address, Long.MIN_VALUE))
assertEquals(2, service.nodes.count())
assertEquals(mapServiceNode.info, service.processQueryRequest(NetworkMapService.QueryIdentityRequest(mapServiceNode.info.identity, mapServiceNode.info.address, Long.MIN_VALUE)).node)
assertEquals(mapServiceNode.info, service.processQueryRequest(QueryIdentityRequest(mapServiceNode.info.identity, mapServiceNode.info.address, Long.MIN_VALUE)).node)
// Re-registering should be a no-op
service.processRegistrationChangeRequest(NetworkMapService.RegistrationRequest(addWireChange, mapServiceNode.info.address, Long.MIN_VALUE))
service.processRegistrationChangeRequest(RegistrationRequest(addWireChange, mapServiceNode.info.address, Long.MIN_VALUE))
assertEquals(2, service.nodes.count())
// Confirm that de-registering the node succeeds and drops it from the node lists
val removeChange = NodeRegistration(registerNode.info, seq, AddOrRemove.REMOVE, expires)
val removeWireChange = removeChange.toWire(nodeKey.private)
assert(service.processRegistrationChangeRequest(NetworkMapService.RegistrationRequest(removeWireChange, mapServiceNode.info.address, Long.MIN_VALUE)).success)
assertNull(service.processQueryRequest(NetworkMapService.QueryIdentityRequest(registerNode.info.identity, mapServiceNode.info.address, Long.MIN_VALUE)).node)
assert(service.processRegistrationChangeRequest(RegistrationRequest(removeWireChange, mapServiceNode.info.address, Long.MIN_VALUE)).success)
assertNull(service.processQueryRequest(QueryIdentityRequest(registerNode.info.identity, mapServiceNode.info.address, Long.MIN_VALUE)).node)
// Trying to de-register a node that doesn't exist should fail
assert(!service.processRegistrationChangeRequest(NetworkMapService.RegistrationRequest(removeWireChange, mapServiceNode.info.address, Long.MIN_VALUE)).success)
}
class TestAcknowledgePSM(val server: NodeInfo, val mapVersion: Int) : ProtocolLogic<Unit>() {
override val topic: String get() = NetworkMapService.PUSH_ACK_PROTOCOL_TOPIC
@Suspendable
override fun call() {
val req = NetworkMapService.UpdateAcknowledge(mapVersion, serviceHub.networkService.myAddress)
send(server.identity, 0, req)
}
}
class TestFetchPSM(val server: NodeInfo, val subscribe: Boolean, val ifChangedSinceVersion: Int? = null)
: ProtocolLogic<Collection<NodeRegistration>?>() {
override val topic: String get() = NetworkMapService.FETCH_PROTOCOL_TOPIC
@Suspendable
override fun call(): Collection<NodeRegistration>? {
val sessionID = random63BitValue()
val req = NetworkMapService.FetchMapRequest(subscribe, ifChangedSinceVersion, serviceHub.networkService.myAddress, sessionID)
return sendAndReceive<NetworkMapService.FetchMapResponse>(server.identity, 0, sessionID, req).unwrap { it.nodes }
}
}
class TestRegisterPSM(val server: NodeInfo, val reg: NodeRegistration, val privateKey: PrivateKey)
: ProtocolLogic<NetworkMapService.RegistrationResponse>() {
override val topic: String get() = NetworkMapService.REGISTER_PROTOCOL_TOPIC
@Suspendable
override fun call(): NetworkMapService.RegistrationResponse {
val sessionID = random63BitValue()
val req = NetworkMapService.RegistrationRequest(reg.toWire(privateKey), serviceHub.networkService.myAddress, sessionID)
return sendAndReceive<NetworkMapService.RegistrationResponse>(server.identity, 0, sessionID, req).unwrap { it }
}
}
class TestSubscribePSM(val server: NodeInfo, val subscribe: Boolean)
: ProtocolLogic<NetworkMapService.SubscribeResponse>() {
override val topic: String get() = NetworkMapService.SUBSCRIPTION_PROTOCOL_TOPIC
@Suspendable
override fun call(): NetworkMapService.SubscribeResponse {
val sessionID = random63BitValue()
val req = NetworkMapService.SubscribeRequest(subscribe, serviceHub.networkService.myAddress, sessionID)
return sendAndReceive<NetworkMapService.SubscribeResponse>(server.identity, 0, sessionID, req).unwrap { it }
}
assert(!service.processRegistrationChangeRequest(RegistrationRequest(removeWireChange, mapServiceNode.info.address, Long.MIN_VALUE)).success)
}
@Test
@ -113,7 +83,7 @@ class InMemoryNetworkMapServiceTest {
// Confirm all nodes have registered themselves
network.runNetwork()
var fetchPsm = registerNode.services.startProtocol(NetworkMapService.FETCH_PROTOCOL_TOPIC, TestFetchPSM(mapServiceNode.info, false))
var fetchPsm = fetchMap(registerNode, mapServiceNode, false)
network.runNetwork()
assertEquals(2, fetchPsm.get()?.count())
@ -122,12 +92,12 @@ class InMemoryNetworkMapServiceTest {
val expires = Instant.now() + NetworkMapService.DEFAULT_EXPIRATION_PERIOD
val seq = 2L
val reg = NodeRegistration(registerNode.info, seq, AddOrRemove.REMOVE, expires)
val registerPsm = registerNode.services.startProtocol(NetworkMapService.REGISTER_PROTOCOL_TOPIC, TestRegisterPSM(mapServiceNode.info, reg, nodeKey.private))
val registerPsm = registration(registerNode, mapServiceNode, reg, nodeKey.private)
network.runNetwork()
assertTrue(registerPsm.get().success)
// Now only map service node should be registered
fetchPsm = registerNode.services.startProtocol(NetworkMapService.FETCH_PROTOCOL_TOPIC, TestFetchPSM(mapServiceNode.info, false))
fetchPsm = fetchMap(registerNode, mapServiceNode, false)
network.runNetwork()
assertEquals(mapServiceNode.info, fetchPsm.get()?.filter { it.type == AddOrRemove.ADD }?.map { it.node }?.single())
}
@ -139,8 +109,7 @@ class InMemoryNetworkMapServiceTest {
// Test subscribing to updates
network.runNetwork()
val subscribePsm = registerNode.services.startProtocol(NetworkMapService.SUBSCRIPTION_PROTOCOL_TOPIC,
TestSubscribePSM(mapServiceNode.info, true))
val subscribePsm = subscribe(registerNode, mapServiceNode, true)
network.runNetwork()
subscribePsm.get()
@ -161,10 +130,8 @@ class InMemoryNetworkMapServiceTest {
assertEquals(1, service.getUnacknowledgedCount(registerNode.info.address, startingMapVersion + 1))
// Send in an acknowledgment and verify the count goes down
val acknowledgePsm = registerNode.services.startProtocol(NetworkMapService.PUSH_ACK_PROTOCOL_TOPIC,
TestAcknowledgePSM(mapServiceNode.info, startingMapVersion + 1))
updateAcknowlege(registerNode, mapServiceNode, startingMapVersion + 1)
network.runNetwork()
acknowledgePsm.get()
assertEquals(0, service.getUnacknowledgedCount(registerNode.info.address, startingMapVersion + 1))
@ -181,4 +148,25 @@ class InMemoryNetworkMapServiceTest {
}
}
}
}
private fun registration(registerNode: MockNode, mapServiceNode: MockNode, reg: NodeRegistration, privateKey: PrivateKey): ListenableFuture<RegistrationResponse> {
val req = RegistrationRequest(reg.toWire(privateKey), registerNode.services.networkService.myAddress, random63BitValue())
return registerNode.sendAndReceive<RegistrationResponse>(REGISTER_PROTOCOL_TOPIC, mapServiceNode, req)
}
private fun subscribe(registerNode: MockNode, mapServiceNode: MockNode, subscribe: Boolean): ListenableFuture<SubscribeResponse> {
val req = SubscribeRequest(subscribe, registerNode.services.networkService.myAddress, random63BitValue())
return registerNode.sendAndReceive<SubscribeResponse>(SUBSCRIPTION_PROTOCOL_TOPIC, mapServiceNode, req)
}
private fun updateAcknowlege(registerNode: MockNode, mapServiceNode: MockNode, mapVersion: Int) {
val req = UpdateAcknowledge(mapVersion, registerNode.services.networkService.myAddress)
registerNode.send(PUSH_ACK_PROTOCOL_TOPIC, mapServiceNode, req)
}
private fun fetchMap(registerNode: MockNode, mapServiceNode: MockNode, subscribe: Boolean, ifChangedSinceVersion: Int? = null): Future<Collection<NodeRegistration>?> {
val req = FetchMapRequest(subscribe, ifChangedSinceVersion, registerNode.services.networkService.myAddress, random63BitValue())
return registerNode.sendAndReceive<FetchMapResponse>(FETCH_PROTOCOL_TOPIC, mapServiceNode, req).map { it.nodes }
}
}

View File

@ -1,33 +1,36 @@
package com.r3corda.node.services
import co.paralleluniverse.fibers.Suspendable
import com.google.common.util.concurrent.ListenableFuture
import com.r3corda.contracts.asset.Cash
import com.r3corda.core.contracts.*
import com.r3corda.core.crypto.SecureHash
import com.r3corda.core.crypto.newSecureRandom
import com.r3corda.core.messaging.MessageHandlerRegistration
import com.r3corda.core.node.NodeInfo
import com.r3corda.core.node.services.DEFAULT_SESSION_ID
import com.r3corda.core.node.services.Wallet
import com.r3corda.core.protocols.ProtocolLogic
import com.r3corda.core.random63BitValue
import com.r3corda.core.serialization.OpaqueBytes
import com.r3corda.core.serialization.deserialize
import com.r3corda.core.serialization.serialize
import com.r3corda.core.utilities.DUMMY_NOTARY
import com.r3corda.core.utilities.DUMMY_NOTARY_KEY
import com.r3corda.core.utilities.DUMMY_PUBKEY_1
import com.r3corda.testing.node.MockNetwork
import com.r3corda.node.services.monitor.*
import com.r3corda.node.services.monitor.WalletMonitorService.Companion.IN_EVENT_TOPIC
import com.r3corda.node.services.monitor.WalletMonitorService.Companion.REGISTER_TOPIC
import com.r3corda.node.utilities.AddOrRemove
import com.r3corda.testing.*
import com.r3corda.testing.expect
import com.r3corda.testing.expectEvents
import com.r3corda.testing.node.MockNetwork
import com.r3corda.testing.node.MockNetwork.MockNode
import com.r3corda.testing.parallel
import com.r3corda.testing.sequence
import org.junit.Before
import org.junit.Test
import rx.subjects.PublishSubject
import rx.subjects.ReplaySubject
import java.util.*
import java.util.concurrent.TimeUnit
import kotlin.test.*
import kotlin.test.assertEquals
import kotlin.test.assertFalse
import kotlin.test.assertNotNull
import kotlin.test.assertTrue
/**
* Unit tests for the wallet monitoring service.
@ -43,32 +46,13 @@ class WalletMonitorServiceTests {
/**
* Authenticate the register node with the monitor service node.
*/
private fun authenticate(monitorServiceNode: MockNetwork.MockNode, registerNode: MockNetwork.MockNode): Long {
private fun authenticate(monitorServiceNode: MockNode, registerNode: MockNode): Long {
network.runNetwork()
val sessionID = random63BitValue()
val authenticatePsm = registerNode.services.startProtocol(WalletMonitorService.REGISTER_TOPIC,
TestRegisterPSM(monitorServiceNode.info, sessionID))
val sessionId = random63BitValue()
val authenticatePsm = register(registerNode, monitorServiceNode, sessionId)
network.runNetwork()
authenticatePsm.get(1, TimeUnit.SECONDS)
return sessionID
}
class TestReceiveWalletUpdatePSM(val sessionID: Long)
: ProtocolLogic<ServiceToClientEvent.OutputState>() {
override val topic: String get() = WalletMonitorService.IN_EVENT_TOPIC
@Suspendable
override fun call(): ServiceToClientEvent.OutputState
= receive<ServiceToClientEvent.OutputState>(sessionID).unwrap { it }
}
class TestRegisterPSM(val server: NodeInfo, val sessionID: Long)
: ProtocolLogic<RegisterResponse>() {
override val topic: String get() = WalletMonitorService.REGISTER_TOPIC
@Suspendable
override fun call(): RegisterResponse {
val req = RegisterRequest(serviceHub.networkService.myAddress, sessionID)
return sendAndReceive<RegisterResponse>(server.identity, 0, sessionID, req).unwrap { it }
}
return sessionId
}
/**
@ -79,9 +63,7 @@ class WalletMonitorServiceTests {
val (monitorServiceNode, registerNode) = network.createTwoNodes()
network.runNetwork()
val sessionID = random63BitValue()
val authenticatePsm = registerNode.services.startProtocol(WalletMonitorService.REGISTER_TOPIC,
TestRegisterPSM(monitorServiceNode.info, sessionID))
val authenticatePsm = register(registerNode, monitorServiceNode, random63BitValue())
network.runNetwork()
val result = authenticatePsm.get(1, TimeUnit.SECONDS)
assertTrue(result.success)
@ -94,8 +76,7 @@ class WalletMonitorServiceTests {
fun `event received`() {
val (monitorServiceNode, registerNode) = network.createTwoNodes()
val sessionID = authenticate(monitorServiceNode, registerNode)
var receivePsm = registerNode.services.startProtocol(WalletMonitorService.IN_EVENT_TOPIC,
TestReceiveWalletUpdatePSM(sessionID))
var receivePsm = receiveWalletUpdate(registerNode, sessionID)
var expected = Wallet.Update(emptySet(), emptySet())
monitorServiceNode.inNodeWalletMonitorService!!.notifyWalletUpdate(expected)
network.runNetwork()
@ -104,8 +85,7 @@ class WalletMonitorServiceTests {
assertEquals(expected.produced, actual.produced)
// Check that states are passed through correctly
receivePsm = registerNode.services.startProtocol(WalletMonitorService.IN_EVENT_TOPIC,
TestReceiveWalletUpdatePSM(sessionID))
receivePsm = receiveWalletUpdate(registerNode, sessionID)
val consumed = setOf(StateRef(SecureHash.randomSHA256(), 0))
val producedState = TransactionState(DummyContract.SingleOwnerState(newSecureRandom().nextInt(), DUMMY_PUBKEY_1), DUMMY_NOTARY)
val produced = setOf(StateAndRef(producedState, StateRef(SecureHash.randomSHA256(), 0)))
@ -125,7 +105,7 @@ class WalletMonitorServiceTests {
val events = ReplaySubject.create<ServiceToClientEvent>()
val ref = OpaqueBytes(ByteArray(1) {1})
registerNode.net.addMessageHandler(WalletMonitorService.IN_EVENT_TOPIC, sessionID) { msg, reg ->
registerNode.net.addMessageHandler(IN_EVENT_TOPIC, sessionID) { msg, reg ->
events.onNext(msg.data.deserialize<ServiceToClientEvent>())
}
@ -178,7 +158,7 @@ class WalletMonitorServiceTests {
val quantity = 1000L
val events = ReplaySubject.create<ServiceToClientEvent>()
registerNode.net.addMessageHandler(WalletMonitorService.IN_EVENT_TOPIC, sessionID) { msg, reg ->
registerNode.net.addMessageHandler(IN_EVENT_TOPIC, sessionID) { msg, reg ->
events.onNext(msg.data.deserialize<ServiceToClientEvent>())
}
@ -240,4 +220,14 @@ class WalletMonitorServiceTests {
)
}
}
private fun register(registerNode: MockNode, monitorServiceNode: MockNode, sessionId: Long): ListenableFuture<RegisterResponse> {
val req = RegisterRequest(registerNode.services.networkService.myAddress, sessionId)
return registerNode.sendAndReceive<RegisterResponse>(REGISTER_TOPIC, monitorServiceNode, req)
}
private fun receiveWalletUpdate(registerNode: MockNode, sessionId: Long): ListenableFuture<ServiceToClientEvent.OutputState> {
return registerNode.receive<ServiceToClientEvent.OutputState>(IN_EVENT_TOPIC, sessionId)
}
}

View File

@ -6,6 +6,7 @@ import com.r3corda.core.crypto.Party
import com.r3corda.core.messaging.SingleMessageRecipient
import com.r3corda.core.protocols.ProtocolLogic
import com.r3corda.core.random63BitValue
import com.r3corda.testing.connectProtocols
import com.r3corda.testing.node.MockNetwork
import com.r3corda.testing.node.MockNetwork.MockNode
import org.assertj.core.api.Assertions.assertThat
@ -49,10 +50,12 @@ class StateMachineManagerTests {
@Test
fun `protocol suspended just after receiving payload`() {
val topic = "send-and-receive"
val sessionID = random63BitValue()
val payload = random63BitValue()
node1.smm.add("test", SendProtocol(topic, node2.info.identity, sessionID, payload))
node2.smm.add("test", ReceiveProtocol(topic, sessionID))
val sendProtocol = SendProtocol(topic, node2.info.identity, payload)
val receiveProtocol = ReceiveProtocol(topic, node1.info.identity)
connectProtocols(sendProtocol, receiveProtocol)
node1.smm.add("test", sendProtocol)
node2.smm.add("test", receiveProtocol)
net.runNetwork()
node2.stop()
val restoredProtocol = node2.restartAndGetRestoredProtocol<ReceiveProtocol>(node1.info.address)
@ -90,19 +93,19 @@ class StateMachineManagerTests {
}
private class SendProtocol(override val topic: String, val destination: Party, val sessionID: Long, val payload: Any) : ProtocolLogic<Unit>() {
private class SendProtocol(override val topic: String, val otherParty: Party, val payload: Any) : ProtocolLogic<Unit>() {
@Suspendable
override fun call() = send(destination, sessionID, payload)
override fun call() = send(otherParty, payload)
}
private class ReceiveProtocol(override val topic: String, val sessionID: Long) : NonTerminatingProtocol() {
private class ReceiveProtocol(override val topic: String, val otherParty: Party) : NonTerminatingProtocol() {
@Transient var receivedPayload: Any? = null
@Suspendable
override fun doCall() {
receivedPayload = receive<Any>(sessionID).unwrap { it }
receivedPayload = receive<Any>(otherParty).unwrap { it }
}
}