artemis, sessions, mock: Add Service addressing, tests pass

This commit is contained in:
Andras Slemmer 2016-12-12 12:59:08 +00:00 committed by exfalso
parent 978ab7e35e
commit fd436b0cdc
19 changed files with 280 additions and 114 deletions

View File

@ -4,6 +4,7 @@ import com.google.common.util.concurrent.ListenableFuture
import com.google.common.util.concurrent.SettableFuture import com.google.common.util.concurrent.SettableFuture
import net.corda.core.catch import net.corda.core.catch
import net.corda.core.node.services.DEFAULT_SESSION_ID import net.corda.core.node.services.DEFAULT_SESSION_ID
import net.corda.core.node.services.PartyInfo
import net.corda.core.serialization.DeserializeAsKotlinObjectDef import net.corda.core.serialization.DeserializeAsKotlinObjectDef
import net.corda.core.serialization.deserialize import net.corda.core.serialization.deserialize
import net.corda.core.serialization.serialize import net.corda.core.serialization.serialize
@ -79,6 +80,8 @@ interface MessagingService {
*/ */
fun createMessage(topicSession: TopicSession, data: ByteArray, uuid: UUID = UUID.randomUUID()): Message fun createMessage(topicSession: TopicSession, data: ByteArray, uuid: UUID = UUID.randomUUID()): Message
fun getAddressOfParty(partyInfo: PartyInfo): MessageRecipients
/** Returns an address that refers to this node. */ /** Returns an address that refers to this node. */
val myAddress: SingleMessageRecipient val myAddress: SingleMessageRecipient
} }

View File

@ -81,7 +81,6 @@ interface ServiceHub {
* Typical use is during signing in flows and for unit test signing. * Typical use is during signing in flows and for unit test signing.
*/ */
val notaryIdentityKey: KeyPair get() = this.keyManagementService.toKeyPair(this.myInfo.notaryIdentity.owningKey.keys) val notaryIdentityKey: KeyPair get() = this.keyManagementService.toKeyPair(this.myInfo.notaryIdentity.owningKey.keys)
} }
/** /**

View File

@ -5,9 +5,11 @@ import com.google.common.util.concurrent.ListenableFuture
import net.corda.core.contracts.Contract import net.corda.core.contracts.Contract
import net.corda.core.crypto.CompositeKey import net.corda.core.crypto.CompositeKey
import net.corda.core.crypto.Party import net.corda.core.crypto.Party
import net.corda.core.messaging.MessageRecipients
import net.corda.core.messaging.MessagingService import net.corda.core.messaging.MessagingService
import net.corda.core.messaging.SingleMessageRecipient import net.corda.core.messaging.SingleMessageRecipient
import net.corda.core.node.NodeInfo import net.corda.core.node.NodeInfo
import net.corda.core.node.ServiceEntry
import net.corda.core.randomOrNull import net.corda.core.randomOrNull
import rx.Observable import rx.Observable
@ -64,16 +66,28 @@ interface NetworkMapCache {
fun getNodeByLegalName(name: String): NodeInfo? = partyNodes.singleOrNull { it.legalIdentity.name == name } fun getNodeByLegalName(name: String): NodeInfo? = partyNodes.singleOrNull { it.legalIdentity.name == name }
/** Look up the node info for a composite key. */ /** Look up the node info for a composite key. */
fun getNodeByCompositeKey(compositeKey: CompositeKey): NodeInfo? { fun getNodeByLegalIdentityKey(compositeKey: CompositeKey): NodeInfo? {
// Although we should never have more than one match, it is theoretically possible. Report an error if it happens. // Although we should never have more than one match, it is theoretically possible. Report an error if it happens.
val candidates = partyNodes.filter { val candidates = partyNodes.filter { it.legalIdentity.owningKey == compositeKey }
(it.legalIdentity.owningKey == compositeKey)
|| it.advertisedServices.any { it.identity.owningKey == compositeKey }
}
check(candidates.size <= 1) { "Found more than one match for key $compositeKey" } check(candidates.size <= 1) { "Found more than one match for key $compositeKey" }
return candidates.singleOrNull() return candidates.singleOrNull()
} }
/**
* Look up all nodes advertising the service owned by [compositeKey]
*/
fun getNodesByAdvertisedServiceIdentityKey(compositeKey: CompositeKey): List<NodeInfo> {
return partyNodes.filter { it.advertisedServices.any { it.identity.owningKey == compositeKey } }
}
/**
* Returns information about the party, which may be a specific node or a service
*
* @party The party we would like the address of.
* @return The address of the party, if found.
*/
fun getPartyInfo(party: Party): PartyInfo?
/** /**
* Given a [party], returns a node advertising it as an identity. If more than one node found the result * Given a [party], returns a node advertising it as an identity. If more than one node found the result
* is chosen at random. * is chosen at random.

View File

@ -0,0 +1,10 @@
package net.corda.core.node.services
import net.corda.core.crypto.Party
import net.corda.core.node.NodeInfo
import net.corda.core.node.ServiceEntry
sealed class PartyInfo(val party: Party) {
class Node(val node: NodeInfo) : PartyInfo(node.legalIdentity)
class Service(val service: ServiceEntry) : PartyInfo(service.identity)
}

View File

@ -72,7 +72,7 @@ abstract class AbstractStateReplacementFlow<T> {
@Suspendable @Suspendable
private fun collectSignatures(participants: List<CompositeKey>, stx: SignedTransaction): List<DigitalSignature.WithKey> { private fun collectSignatures(participants: List<CompositeKey>, stx: SignedTransaction): List<DigitalSignature.WithKey> {
val parties = participants.map { val parties = participants.map {
val participantNode = serviceHub.networkMapCache.getNodeByCompositeKey(it) ?: val participantNode = serviceHub.networkMapCache.getNodeByLegalIdentityKey(it) ?:
throw IllegalStateException("Participant $it to state $originalState not found on the network") throw IllegalStateException("Participant $it to state $originalState not found on the network")
participantNode.legalIdentity participantNode.legalIdentity
} }

View File

@ -277,7 +277,7 @@ abstract class AbstractNode(open val configuration: NodeConfiguration, val netwo
* A service entry contains the advertised [ServiceInfo] along with the service identity. The identity *name* is * A service entry contains the advertised [ServiceInfo] along with the service identity. The identity *name* is
* taken from the configuration or, if non specified, generated by combining the node's legal name and the service id. * taken from the configuration or, if non specified, generated by combining the node's legal name and the service id.
*/ */
private fun makeServiceEntries(): List<ServiceEntry> { protected fun makeServiceEntries(): List<ServiceEntry> {
return advertisedServices.map { return advertisedServices.map {
val serviceId = it.type.id val serviceId = it.type.id
val serviceName = it.name ?: "$serviceId|${configuration.myLegalName}" val serviceName = it.name ?: "$serviceId|${configuration.myLegalName}"

View File

@ -3,6 +3,7 @@ package net.corda.node.services.messaging
import com.google.common.annotations.VisibleForTesting import com.google.common.annotations.VisibleForTesting
import com.google.common.net.HostAndPort import com.google.common.net.HostAndPort
import net.corda.core.crypto.CompositeKey import net.corda.core.crypto.CompositeKey
import net.corda.core.messaging.MessageRecipientGroup
import net.corda.core.messaging.MessageRecipients import net.corda.core.messaging.MessageRecipients
import net.corda.core.messaging.SingleMessageRecipient import net.corda.core.messaging.SingleMessageRecipient
import net.corda.core.read import net.corda.core.read
@ -34,6 +35,7 @@ abstract class ArtemisMessagingComponent() : SingletonSerializeAsToken() {
const val INTERNAL_PREFIX = "internal." const val INTERNAL_PREFIX = "internal."
const val PEERS_PREFIX = "${INTERNAL_PREFIX}peers." const val PEERS_PREFIX = "${INTERNAL_PREFIX}peers."
const val SERVICES_PREFIX = "${INTERNAL_PREFIX}services."
const val CLIENTS_PREFIX = "clients." const val CLIENTS_PREFIX = "clients."
const val P2P_QUEUE = "p2p.inbound" const val P2P_QUEUE = "p2p.inbound"
const val RPC_REQUESTS_QUEUE = "rpc.requests" const val RPC_REQUESTS_QUEUE = "rpc.requests"
@ -55,7 +57,7 @@ abstract class ArtemisMessagingComponent() : SingletonSerializeAsToken() {
} }
} }
protected interface ArtemisAddress { protected interface ArtemisAddress : SingleMessageRecipient {
val queueName: SimpleString val queueName: SimpleString
val hostAndPort: HostAndPort val hostAndPort: HostAndPort
} }
@ -69,11 +71,20 @@ abstract class ArtemisMessagingComponent() : SingletonSerializeAsToken() {
* may change or evolve and code that relies upon it being a simple host/port may not function correctly. * may change or evolve and code that relies upon it being a simple host/port may not function correctly.
* For instance it may contain onion routing data. * For instance it may contain onion routing data.
*/ */
data class NodeAddress(val identity: CompositeKey, override val hostAndPort: HostAndPort) : SingleMessageRecipient, ArtemisAddress { data class NodeAddress(override val queueName: SimpleString, override val hostAndPort: HostAndPort) : ArtemisAddress {
override val queueName: SimpleString = SimpleString("$PEERS_PREFIX${identity.toBase58String()}") companion object {
fun asPeer(identity: CompositeKey, hostAndPort: HostAndPort) =
NodeAddress(SimpleString("$PEERS_PREFIX${identity.toBase58String()}"), hostAndPort)
fun asService(identity: CompositeKey, hostAndPort: HostAndPort) =
NodeAddress(SimpleString("$SERVICES_PREFIX${identity.toBase58String()}"), hostAndPort)
}
override fun toString(): String = "${javaClass.simpleName}(identity = $queueName, $hostAndPort)" override fun toString(): String = "${javaClass.simpleName}(identity = $queueName, $hostAndPort)"
} }
data class ServiceAddress(val identity: CompositeKey) : MessageRecipientGroup {
val queueName: SimpleString = SimpleString("$SERVICES_PREFIX${identity.toBase58String()}")
}
/** The config object is used to pass in the passwords for the certificate KeyStore and TrustStore */ /** The config object is used to pass in the passwords for the certificate KeyStore and TrustStore */
abstract val config: NodeSSLConfiguration abstract val config: NodeSSLConfiguration

View File

@ -9,6 +9,7 @@ import net.corda.core.crypto.X509Utilities.CORDA_CLIENT_CA
import net.corda.core.crypto.X509Utilities.CORDA_ROOT_CA import net.corda.core.crypto.X509Utilities.CORDA_ROOT_CA
import net.corda.core.crypto.newSecureRandom import net.corda.core.crypto.newSecureRandom
import net.corda.core.div import net.corda.core.div
import net.corda.core.node.NodeInfo
import net.corda.core.node.services.NetworkMapCache import net.corda.core.node.services.NetworkMapCache
import net.corda.core.node.services.NetworkMapCache.MapChange import net.corda.core.node.services.NetworkMapCache.MapChange
import net.corda.core.utilities.debug import net.corda.core.utilities.debug
@ -92,7 +93,7 @@ class ArtemisMessagingServer(override val config: NodeConfiguration,
fun start() = mutex.locked { fun start() = mutex.locked {
if (!running) { if (!running) {
configureAndStartServer() configureAndStartServer()
networkChangeHandle = networkMapCache.changed.subscribe { destroyOrCreateBridge(it) } networkChangeHandle = networkMapCache.changed.subscribe { destroyOrCreateBridges(it) }
running = true running = true
} }
} }
@ -120,14 +121,36 @@ class ArtemisMessagingServer(override val config: NodeConfiguration,
* We create the bridges indirectly now because the network map is not persisted and there are no ways to obtain host and port information on startup. * We create the bridges indirectly now because the network map is not persisted and there are no ways to obtain host and port information on startup.
* TODO : Create the bridge directly from the list of queues on start up when we have a persisted network map service. * TODO : Create the bridge directly from the list of queues on start up when we have a persisted network map service.
*/ */
private fun destroyOrCreateBridge(change: MapChange) { private fun destroyOrCreateBridges(change: MapChange) {
val (newNode, staleNode) = when (change) { fun addAddresses(node: NodeInfo, target: HashSet<ArtemisAddress>) {
is MapChange.Modified -> change.node to change.previousNode val nodeAddress = node.address as ArtemisAddress
is MapChange.Removed -> null to change.node target.add(nodeAddress)
is MapChange.Added -> change.node to null change.node.advertisedServices.forEach {
target.add(NodeAddress.asService(it.identity.owningKey, nodeAddress.hostAndPort))
}
}
val addressesToCreateBridgesTo = HashSet<ArtemisAddress>()
val addressesToRemoveBridgesTo = HashSet<ArtemisAddress>()
when (change) {
is MapChange.Modified -> {
addAddresses(change.node, addressesToCreateBridgesTo)
addAddresses(change.previousNode, addressesToRemoveBridgesTo)
}
is MapChange.Removed -> {
addAddresses(change.node, addressesToRemoveBridgesTo)
}
is MapChange.Added -> {
addAddresses(change.node, addressesToCreateBridgesTo)
}
}
(addressesToRemoveBridgesTo - addressesToCreateBridgesTo).forEach {
maybeDestroyBridge(bridgeNameForAddress(it))
}
addressesToCreateBridgesTo.forEach {
maybeDeployBridgeForAddress(it)
} }
(staleNode?.address as? ArtemisAddress)?.let { maybeDestroyBridge(it.queueName) }
(newNode?.address as? ArtemisAddress)?.let { if (activeMQServer.queueQuery(it.queueName).isExists) maybeDeployBridgeForAddress(it) }
} }
private fun configureAndStartServer() { private fun configureAndStartServer() {
@ -138,31 +161,47 @@ class ArtemisMessagingServer(override val config: NodeConfiguration,
registerActivationFailureListener { exception -> throw exception } registerActivationFailureListener { exception -> throw exception }
// Some types of queue might need special preparation on our side, like dialling back or preparing // Some types of queue might need special preparation on our side, like dialling back or preparing
// a lazily initialised subsystem. // a lazily initialised subsystem.
registerPostQueueCreationCallback { deployBridgeFromNewPeerQueue(it) } registerPostQueueCreationCallback { deployBridgeFromNewQueue(it) }
registerPostQueueDeletionCallback { address, qName -> log.debug { "Queue deleted: $qName for $address" } } registerPostQueueDeletionCallback { address, qName -> log.debug { "Queue deleted: $qName for $address" } }
} }
activeMQServer.start() activeMQServer.start()
printBasicNodeInfo("Node listening on address", myHostPort.toString()) printBasicNodeInfo("Node listening on address", myHostPort.toString())
} }
private fun deployBridgeFromNewPeerQueue(queueName: SimpleString) { private fun maybeDeployBridgeForNode(queueName: SimpleString, nodeInfo: NodeInfo) {
log.debug { "Queue created: $queueName" } log.debug("Deploying bridge for $queueName to $nodeInfo")
if (!queueName.startsWith(PEERS_PREFIX)) return val address = nodeInfo.address
try { if (address is NodeAddress) {
val identity = CompositeKey.parseFromBase58(queueName.substring(PEERS_PREFIX.length)) maybeDeployBridgeForAddress(NodeAddress(queueName, address.hostAndPort))
val nodeInfo = networkMapCache.getNodeByCompositeKey(identity) } else {
if (nodeInfo != null) { log.error("Don't know how to deal with $address")
val address = nodeInfo.address }
if (address is NodeAddress) { }
maybeDeployBridgeForAddress(address)
private fun deployBridgeFromNewQueue(queueName: SimpleString) {
log.debug { "Queue created: $queueName, deploying bridge(s)" }
when {
queueName.startsWith(PEERS_PREFIX) -> try {
val identity = CompositeKey.parseFromBase58(queueName.substring(PEERS_PREFIX.length))
val nodeInfo = networkMapCache.getNodeByLegalIdentityKey(identity)
if (nodeInfo != null) {
maybeDeployBridgeForNode(queueName, nodeInfo)
} else { } else {
log.error("Don't know how to deal with $address") log.error("Queue created for a peer that we don't know from the network map: $queueName")
} }
} else { } catch (e: AddressFormatException) {
log.error("Queue created for a peer that we don't know from the network map: $queueName") log.error("Flow violation: Could not parse peer queue name as Base 58: $queueName")
}
queueName.startsWith(SERVICES_PREFIX) -> try {
val identity = CompositeKey.parseFromBase58(queueName.substring(SERVICES_PREFIX.length))
val nodeInfos = networkMapCache.getNodesByAdvertisedServiceIdentityKey(identity)
for (nodeInfo in nodeInfos) {
maybeDeployBridgeForNode(queueName, nodeInfo)
}
} catch (e: AddressFormatException) {
log.error("Flow violation: Could not parse service queue name as Base 58: $queueName")
} }
} catch (e: AddressFormatException) {
log.error("Flow violation: Could not parse queue name as Base 58: $queueName")
} }
} }
@ -240,26 +279,29 @@ class ArtemisMessagingServer(override val config: NodeConfiguration,
tcpTransport(OUTBOUND, hostAndPort.hostText, hostAndPort.port) tcpTransport(OUTBOUND, hostAndPort.hostText, hostAndPort.port)
) )
private fun bridgeExists(name: SimpleString) = activeMQServer.clusterManager.bridges.containsKey(name.toString()) private fun bridgeExists(name: String) = activeMQServer.clusterManager.bridges.containsKey(name)
private fun maybeDeployBridgeForAddress(address: ArtemisAddress) { private fun maybeDeployBridgeForAddress(address: ArtemisAddress) {
if (!connectorExists(address.hostAndPort)) { if (!connectorExists(address.hostAndPort)) {
addConnector(address.hostAndPort) addConnector(address.hostAndPort)
} }
if (!bridgeExists(address.queueName)) { val bridgeName = bridgeNameForAddress(address)
deployBridge(address) if (!bridgeExists(bridgeName)) {
deployBridge(bridgeName, address)
} }
} }
private fun bridgeNameForAddress(address: ArtemisAddress) = "${address.queueName}-${address.hostAndPort}"
/** /**
* All nodes are expected to have a public facing address called [ArtemisMessagingComponent.P2P_QUEUE] for receiving * All nodes are expected to have a public facing address called [ArtemisMessagingComponent.P2P_QUEUE] for receiving
* messages from other nodes. When we want to send a message to a node we send it to our internal address/queue for it, * messages from other nodes. When we want to send a message to a node we send it to our internal address/queue for it,
* as defined by ArtemisAddress.queueName. A bridge is then created to forward messages from this queue to the node's * as defined by ArtemisAddress.queueName. A bridge is then created to forward messages from this queue to the node's
* P2P address. * P2P address.
*/ */
private fun deployBridge(address: ArtemisAddress) { private fun deployBridge(bridgeName: String, address: ArtemisAddress) {
activeMQServer.deployBridge(BridgeConfiguration().apply { activeMQServer.deployBridge(BridgeConfiguration().apply {
name = address.queueName.toString() name = bridgeName
queueName = address.queueName.toString() queueName = address.queueName.toString()
forwardingAddress = P2P_QUEUE forwardingAddress = P2P_QUEUE
staticConnectors = listOf(address.hostAndPort.toString()) staticConnectors = listOf(address.hostAndPort.toString())
@ -272,9 +314,9 @@ class ArtemisMessagingServer(override val config: NodeConfiguration,
}) })
} }
private fun maybeDestroyBridge(name: SimpleString) { private fun maybeDestroyBridge(name: String) {
if (bridgeExists(name)) { if (bridgeExists(name)) {
activeMQServer.destroyBridge(name.toString()) activeMQServer.destroyBridge(name)
} }
} }

View File

@ -5,6 +5,7 @@ import com.google.common.util.concurrent.ListenableFuture
import net.corda.core.ThreadBox import net.corda.core.ThreadBox
import net.corda.core.crypto.CompositeKey import net.corda.core.crypto.CompositeKey
import net.corda.core.messaging.* import net.corda.core.messaging.*
import net.corda.core.node.services.PartyInfo
import net.corda.core.serialization.SerializedBytes import net.corda.core.serialization.SerializedBytes
import net.corda.core.serialization.opaque import net.corda.core.serialization.opaque
import net.corda.core.success import net.corda.core.success
@ -96,7 +97,7 @@ class NodeMessagingClient(override val config: NodeConfiguration,
/** /**
* Apart from the NetworkMapService this is the only other address accessible to the node outside of lookups against the NetworkMapCache. * Apart from the NetworkMapService this is the only other address accessible to the node outside of lookups against the NetworkMapCache.
*/ */
override val myAddress: SingleMessageRecipient = if (myIdentity != null) NodeAddress(myIdentity, serverHostPort) else NetworkMapAddress(serverHostPort) override val myAddress: SingleMessageRecipient = if (myIdentity != null) NodeAddress.asPeer(myIdentity, serverHostPort) else NetworkMapAddress(serverHostPort)
private val state = ThreadBox(InnerState()) private val state = ThreadBox(InnerState())
private val handlers = CopyOnWriteArrayList<Handler>() private val handlers = CopyOnWriteArrayList<Handler>()
@ -449,4 +450,11 @@ class NodeMessagingClient(override val config: NodeConfiguration,
} }
} }
} }
override fun getAddressOfParty(partyInfo: PartyInfo): MessageRecipients {
return when (partyInfo) {
is PartyInfo.Node -> partyInfo.node.address
is PartyInfo.Service -> ArtemisMessagingComponent.ServiceAddress(partyInfo.service.identity.owningKey)
}
}
} }

View File

@ -200,11 +200,11 @@ private class RPCKryo(observableSerializer: Serializer<Observable<Any>>? = null)
register(ArtemisMessagingComponent.NodeAddress::class.java, register(ArtemisMessagingComponent.NodeAddress::class.java,
read = { kryo, input -> read = { kryo, input ->
ArtemisMessagingComponent.NodeAddress( ArtemisMessagingComponent.NodeAddress(
CompositeKey.parseFromBase58(kryo.readObject(input, String::class.java)), kryo.readObject(input, SimpleString::class.java),
kryo.readObject(input, HostAndPort::class.java)) kryo.readObject(input, HostAndPort::class.java))
}, },
write = { kryo, output, nodeAddress -> write = { kryo, output, nodeAddress ->
kryo.writeObject(output, nodeAddress.identity.toBase58String()) kryo.writeObject(output, nodeAddress.queueName)
kryo.writeObject(output, nodeAddress.hostAndPort) kryo.writeObject(output, nodeAddress.hostAndPort)
} }
) )

View File

@ -6,6 +6,7 @@ import com.google.common.util.concurrent.SettableFuture
import net.corda.core.bufferUntilSubscribed import net.corda.core.bufferUntilSubscribed
import net.corda.core.crypto.Party import net.corda.core.crypto.Party
import net.corda.core.map import net.corda.core.map
import net.corda.core.messaging.MessageRecipients
import net.corda.core.messaging.MessagingService import net.corda.core.messaging.MessagingService
import net.corda.core.messaging.SingleMessageRecipient import net.corda.core.messaging.SingleMessageRecipient
import net.corda.core.messaging.createMessage import net.corda.core.messaging.createMessage
@ -14,6 +15,7 @@ import net.corda.core.node.services.DEFAULT_SESSION_ID
import net.corda.core.node.services.NetworkCacheError import net.corda.core.node.services.NetworkCacheError
import net.corda.core.node.services.NetworkMapCache import net.corda.core.node.services.NetworkMapCache
import net.corda.core.node.services.NetworkMapCache.MapChange import net.corda.core.node.services.NetworkMapCache.MapChange
import net.corda.core.node.services.PartyInfo
import net.corda.core.serialization.SingletonSerializeAsToken import net.corda.core.serialization.SingletonSerializeAsToken
import net.corda.core.serialization.deserialize import net.corda.core.serialization.deserialize
import net.corda.core.serialization.serialize import net.corda.core.serialization.serialize
@ -52,6 +54,21 @@ open class InMemoryNetworkMapCache : SingletonSerializeAsToken(), NetworkMapCach
private var registeredForPush = false private var registeredForPush = false
protected var registeredNodes: MutableMap<Party, NodeInfo> = Collections.synchronizedMap(HashMap<Party, NodeInfo>()) protected var registeredNodes: MutableMap<Party, NodeInfo> = Collections.synchronizedMap(HashMap<Party, NodeInfo>())
override fun getPartyInfo(party: Party): PartyInfo? {
val node = registeredNodes[party]
if (node != null) {
return PartyInfo.Node(node)
}
for (entry in registeredNodes) {
for (service in entry.value.advertisedServices) {
if (service.identity == party) {
return PartyInfo.Service(service)
}
}
}
return null
}
override fun track(): Pair<List<NodeInfo>, Observable<MapChange>> { override fun track(): Pair<List<NodeInfo>, Observable<MapChange>> {
synchronized(_changed) { synchronized(_changed) {
return Pair(partyNodes, _changed.bufferUntilSubscribed()) return Pair(partyNodes, _changed.bufferUntilSubscribed())

View File

@ -154,9 +154,12 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
} }
private fun createSessionData(session: FlowSession, payload: Any): SessionData { private fun createSessionData(session: FlowSession, payload: Any): SessionData {
val otherPartySessionId = session.otherPartySessionId val sessionState = session.state
?: throw IllegalStateException("We've somehow held onto an unconfirmed session: $session") val peerSessionId = when (sessionState) {
return SessionData(otherPartySessionId, payload) is StateMachineManager.FlowSessionState.Initiating -> throw IllegalStateException("We've somehow held onto an unconfirmed session: $session")
is StateMachineManager.FlowSessionState.Initiated -> sessionState.peerSessionId
}
return SessionData(peerSessionId, payload)
} }
@Suspendable @Suspendable
@ -191,20 +194,19 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
*/ */
@Suspendable @Suspendable
private fun startNewSession(otherParty: Party, sessionFlow: FlowLogic<*>, firstPayload: Any?): FlowSession { private fun startNewSession(otherParty: Party, sessionFlow: FlowLogic<*>, firstPayload: Any?): FlowSession {
val node = serviceHub.networkMapCache.getRepresentativeNode(otherParty) ?: throw IllegalArgumentException("Don't know about party $otherParty") logger.trace { "Initiating a new session with $otherParty" }
val nodeIdentity = node.legalIdentity val session = FlowSession(sessionFlow, random63BitValue(), FlowSessionState.Initiating(otherParty))
logger.trace { "Initiating a new session with $nodeIdentity (representative of $otherParty)" } openSessions[Pair(sessionFlow, otherParty)] = session
val session = FlowSession(sessionFlow, nodeIdentity, random63BitValue(), null) val counterpartyFlow = sessionFlow.getCounterpartyMarker(otherParty).name
openSessions[Pair(sessionFlow, nodeIdentity)] = session
val counterpartyFlow = sessionFlow.getCounterpartyMarker(nodeIdentity).name
val sessionInit = SessionInit(session.ourSessionId, counterpartyFlow, firstPayload) val sessionInit = SessionInit(session.ourSessionId, counterpartyFlow, firstPayload)
val sessionInitResponse = sendAndReceiveInternal<SessionInitResponse>(session, sessionInit) val sessionInitResponse = sendAndReceiveInternal<SessionInitResponse>(session, sessionInit)
if (sessionInitResponse is SessionConfirm) { if (sessionInitResponse is SessionConfirm) {
session.otherPartySessionId = sessionInitResponse.initiatedSessionId require(session.state is FlowSessionState.Initiating)
session.state = FlowSessionState.Initiated(sessionInitResponse.peerParty, sessionInitResponse.initiatedSessionId)
return session return session
} else { } else {
sessionInitResponse as SessionReject sessionInitResponse as SessionReject
throw FlowSessionException("Party $nodeIdentity rejected session attempt: ${sessionInitResponse.errorMessage}") throw FlowSessionException("Party $otherParty rejected session attempt: ${sessionInitResponse.errorMessage}")
} }
} }
@ -228,7 +230,7 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
if (receivedMessage is SessionEnd) { if (receivedMessage is SessionEnd) {
openSessions.values.remove(receiveRequest.session) openSessions.values.remove(receiveRequest.session)
throw FlowSessionException("Counterparty on ${receiveRequest.session.otherParty} has prematurely ended on $receiveRequest") throw FlowSessionException("Counterparty on ${receiveRequest.session.state.sendToParty} has prematurely ended on $receiveRequest")
} else if (receiveRequest.receiveType.isInstance(receivedMessage)) { } else if (receiveRequest.receiveType.isInstance(receivedMessage)) {
return receiveRequest.receiveType.cast(receivedMessage) return receiveRequest.receiveType.cast(receivedMessage)
} else { } else {

View File

@ -253,11 +253,11 @@ class StateMachineManager(val serviceHub: ServiceHubInternal,
resumeFiber(session.psm) resumeFiber(session.psm)
} }
} else { } else {
val otherParty = recentlyClosedSessions.remove(message.recipientSessionId) val peerParty = recentlyClosedSessions.remove(message.recipientSessionId)
if (otherParty != null) { if (peerParty != null) {
if (message is SessionConfirm) { if (message is SessionConfirm) {
logger.debug { "Received session confirmation but associated fiber has already terminated, so sending session end" } logger.debug { "Received session confirmation but associated fiber has already terminated, so sending session end" }
sendSessionMessage(otherParty, SessionEnd(message.initiatedSessionId), null) sendSessionMessage(peerParty, SessionEnd(message.initiatedSessionId), null)
} else { } else {
logger.trace { "Ignoring session end message for already closed session: $message" } logger.trace { "Ignoring session end message for already closed session: $message" }
} }
@ -276,14 +276,14 @@ class StateMachineManager(val serviceHub: ServiceHubInternal,
if (flowFactory != null) { if (flowFactory != null) {
val flow = flowFactory(otherParty) val flow = flowFactory(otherParty)
val psm = createFiber(flow) val psm = createFiber(flow)
val session = FlowSession(flow, otherParty, random63BitValue(), otherPartySessionId) val session = FlowSession(flow, random63BitValue(), FlowSessionState.Initiated(otherParty, otherPartySessionId))
if (sessionInit.firstPayload != null) { if (sessionInit.firstPayload != null) {
session.receivedMessages += SessionData(session.ourSessionId, sessionInit.firstPayload) session.receivedMessages += SessionData(session.ourSessionId, sessionInit.firstPayload)
} }
openSessions[session.ourSessionId] = session openSessions[session.ourSessionId] = session
psm.openSessions[Pair(flow, otherParty)] = session psm.openSessions[Pair(flow, otherParty)] = session
updateCheckpoint(psm) updateCheckpoint(psm)
sendSessionMessage(otherParty, SessionConfirm(otherPartySessionId, session.ourSessionId), psm) sendSessionMessage(otherParty, SessionConfirm(otherPartySessionId, session.ourSessionId, serviceHub.myInfo.legalIdentity), psm)
psm.logger.debug { "Initiated from $sessionInit on $session" } psm.logger.debug { "Initiated from $sessionInit on $session" }
startFiber(psm) startFiber(psm)
} else { } else {
@ -355,11 +355,11 @@ class StateMachineManager(val serviceHub: ServiceHubInternal,
private fun endAllFiberSessions(psm: FlowStateMachineImpl<*>) { private fun endAllFiberSessions(psm: FlowStateMachineImpl<*>) {
openSessions.values.removeIf { session -> openSessions.values.removeIf { session ->
if (session.psm == psm) { if (session.psm == psm) {
val otherPartySessionId = session.otherPartySessionId val initiatedState = session.state as? FlowSessionState.Initiated
if (otherPartySessionId != null) { if (initiatedState != null) {
sendSessionMessage(session.otherParty, SessionEnd(otherPartySessionId), psm) sendSessionMessage(initiatedState.peerParty, SessionEnd(initiatedState.peerSessionId), psm)
recentlyClosedSessions[session.ourSessionId] = initiatedState.peerParty
} }
recentlyClosedSessions[session.ourSessionId] = session.otherParty
true true
} else { } else {
false false
@ -437,7 +437,7 @@ class StateMachineManager(val serviceHub: ServiceHubInternal,
if (ioRequest.message is SessionInit) { if (ioRequest.message is SessionInit) {
openSessions[ioRequest.session.ourSessionId] = ioRequest.session openSessions[ioRequest.session.ourSessionId] = ioRequest.session
} }
sendSessionMessage(ioRequest.session.otherParty, ioRequest.message, ioRequest.session.psm) sendSessionMessage(ioRequest.session.state.sendToParty, ioRequest.message, ioRequest.session.psm)
if (ioRequest !is ReceiveRequest<*>) { if (ioRequest !is ReceiveRequest<*>) {
// We sent a message, but don't expect a response, so re-enter the continuation to let it keep going. // We sent a message, but don't expect a response, so re-enter the continuation to let it keep going.
resumeFiber(ioRequest.session.psm) resumeFiber(ioRequest.session.psm)
@ -446,11 +446,12 @@ class StateMachineManager(val serviceHub: ServiceHubInternal,
} }
private fun sendSessionMessage(party: Party, message: SessionMessage, psm: FlowStateMachineImpl<*>?) { private fun sendSessionMessage(party: Party, message: SessionMessage, psm: FlowStateMachineImpl<*>?) {
val node = serviceHub.networkMapCache.getNodeByCompositeKey(party.owningKey) val partyInfo = serviceHub.networkMapCache.getPartyInfo(party)
?: throw IllegalArgumentException("Don't know about party $party") ?: throw IllegalArgumentException("Don't know about party $party")
val address = serviceHub.networkService.getAddressOfParty(partyInfo)
val logger = psm?.logger ?: logger val logger = psm?.logger ?: logger
logger.trace { "Sending $message to party $party" } logger.debug { "Sending $message to party $party, address: $address" }
serviceHub.networkService.send(sessionTopic, message, node.address) serviceHub.networkService.send(sessionTopic, message, address)
} }
@ -464,7 +465,7 @@ class StateMachineManager(val serviceHub: ServiceHubInternal,
interface SessionInitResponse : ExistingSessionMessage interface SessionInitResponse : ExistingSessionMessage
data class SessionConfirm(val initiatorSessionId: Long, val initiatedSessionId: Long) : SessionInitResponse { data class SessionConfirm(val initiatorSessionId: Long, val initiatedSessionId: Long, val peerParty: Party) : SessionInitResponse {
override val recipientSessionId: Long get() = initiatorSessionId override val recipientSessionId: Long get() = initiatorSessionId
} }
@ -480,16 +481,29 @@ class StateMachineManager(val serviceHub: ServiceHubInternal,
data class SessionEnd(override val recipientSessionId: Long) : ExistingSessionMessage data class SessionEnd(override val recipientSessionId: Long) : ExistingSessionMessage
sealed class FlowSessionState {
abstract val sendToParty: Party
class Initiating(
val otherParty: Party /** This may be a specific peer or a service party */
) : FlowSessionState() {
override val sendToParty: Party get() = otherParty
}
class Initiated(
val peerParty: Party, /** This must be a peer party */
val peerSessionId: Long
) : FlowSessionState() {
override val sendToParty: Party get() = peerParty
}
}
data class FlowSession(val flow: FlowLogic<*>, data class FlowSession(
val otherParty: Party, val flow: FlowLogic<*>,
val ourSessionId: Long, val ourSessionId: Long,
var otherPartySessionId: Long?, var state: FlowSessionState,
@Volatile var waitingForResponse: Boolean = false) { @Volatile var waitingForResponse: Boolean = false
) {
val receivedMessages = ConcurrentLinkedQueue<ExistingSessionMessage>() val receivedMessages = ConcurrentLinkedQueue<ExistingSessionMessage>()
val psm: FlowStateMachineImpl<*> get() = flow.fsm as FlowStateMachineImpl<*> val psm: FlowStateMachineImpl<*> get() = flow.fsm as FlowStateMachineImpl<*>
} }
} }

View File

@ -133,7 +133,7 @@ class TwoPartyTradeFlowTests {
val aliceKey = aliceNode.services.legalIdentityKey val aliceKey = aliceNode.services.legalIdentityKey
val notaryKey = notaryNode.services.notaryIdentityKey val notaryKey = notaryNode.services.notaryIdentityKey
val bobAddr = bobNode.net.myAddress as InMemoryMessagingNetwork.Handle val bobAddr = bobNode.net.myAddress as InMemoryMessagingNetwork.PeerHandle
val networkMapAddr = notaryNode.info.address val networkMapAddr = notaryNode.info.address
net.runNetwork() // Clear network map registration messages net.runNetwork() // Clear network map registration messages

View File

@ -29,14 +29,14 @@ class InMemoryNetworkMapCacheTest {
val nodeB = network.createNode(null, -1, MockNetwork.DefaultFactory, true, "Node B", keyPair, ServiceInfo(NetworkMapService.type)) val nodeB = network.createNode(null, -1, MockNetwork.DefaultFactory, true, "Node B", keyPair, ServiceInfo(NetworkMapService.type))
// Node A currently knows only about itself, so this returns node A // Node A currently knows only about itself, so this returns node A
assertEquals(nodeA.netMapCache.getNodeByCompositeKey(keyPair.public.composite), nodeA.info) assertEquals(nodeA.netMapCache.getNodeByLegalIdentityKey(keyPair.public.composite), nodeA.info)
databaseTransaction(nodeA.database) { databaseTransaction(nodeA.database) {
nodeA.netMapCache.addNode(nodeB.info) nodeA.netMapCache.addNode(nodeB.info)
} }
// Now both nodes match, so it throws an error // Now both nodes match, so it throws an error
expect<IllegalStateException> { expect<IllegalStateException> {
nodeA.netMapCache.getNodeByCompositeKey(keyPair.public.composite) nodeA.netMapCache.getNodeByLegalIdentityKey(keyPair.public.composite)
} }
} }
} }

View File

@ -81,7 +81,7 @@ class NodeSchedulerServiceTest : SingletonSerializeAsToken() {
databaseTransaction(database) { databaseTransaction(database) {
val kms = MockKeyManagementService(ALICE_KEY) val kms = MockKeyManagementService(ALICE_KEY)
val mockMessagingService = InMemoryMessagingNetwork(false).InMemoryMessaging(false, InMemoryMessagingNetwork.Handle(0, "None"), AffinityExecutor.ServiceAffinityExecutor("test", 1), database) val mockMessagingService = InMemoryMessagingNetwork(false).InMemoryMessaging(false, InMemoryMessagingNetwork.PeerHandle(0, "None"), AffinityExecutor.ServiceAffinityExecutor("test", 1), database)
services = object : MockServiceHubInternal(overrideClock = testClock, keyManagement = kms, net = mockMessagingService), TestReference { services = object : MockServiceHubInternal(overrideClock = testClock, keyManagement = kms, net = mockMessagingService), TestReference {
override val testReference = this@NodeSchedulerServiceTest override val testReference = this@NodeSchedulerServiceTest
} }

View File

@ -198,14 +198,14 @@ class StateMachineManagerTests {
assertSessionTransfers(node2, assertSessionTransfers(node2,
node1 sent sessionInit(SendFlow::class, payload) to node2, node1 sent sessionInit(SendFlow::class, payload) to node2,
node2 sent sessionConfirm() to node1, node2 sent sessionConfirm(node2) to node1,
node1 sent sessionEnd() to node2 node1 sent sessionEnd() to node2
//There's no session end from the other flows as they're manually suspended //There's no session end from the other flows as they're manually suspended
) )
assertSessionTransfers(node3, assertSessionTransfers(node3,
node1 sent sessionInit(SendFlow::class, payload) to node3, node1 sent sessionInit(SendFlow::class, payload) to node3,
node3 sent sessionConfirm() to node1, node3 sent sessionConfirm(node3) to node1,
node1 sent sessionEnd() to node3 node1 sent sessionEnd() to node3
//There's no session end from the other flows as they're manually suspended //There's no session end from the other flows as they're manually suspended
) )
@ -231,14 +231,14 @@ class StateMachineManagerTests {
assertSessionTransfers(node2, assertSessionTransfers(node2,
node1 sent sessionInit(ReceiveThenSuspendFlow::class) to node2, node1 sent sessionInit(ReceiveThenSuspendFlow::class) to node2,
node2 sent sessionConfirm() to node1, node2 sent sessionConfirm(node2) to node1,
node2 sent sessionData(node2Payload) to node1, node2 sent sessionData(node2Payload) to node1,
node2 sent sessionEnd() to node1 node2 sent sessionEnd() to node1
) )
assertSessionTransfers(node3, assertSessionTransfers(node3,
node1 sent sessionInit(ReceiveThenSuspendFlow::class) to node3, node1 sent sessionInit(ReceiveThenSuspendFlow::class) to node3,
node3 sent sessionConfirm() to node1, node3 sent sessionConfirm(node3) to node1,
node3 sent sessionData(node3Payload) to node1, node3 sent sessionData(node3Payload) to node1,
node3 sent sessionEnd() to node1 node3 sent sessionEnd() to node1
) )
@ -252,7 +252,7 @@ class StateMachineManagerTests {
assertSessionTransfers( assertSessionTransfers(
node1 sent sessionInit(PingPongFlow::class, 10L) to node2, node1 sent sessionInit(PingPongFlow::class, 10L) to node2,
node2 sent sessionConfirm() to node1, node2 sent sessionConfirm(node2) to node1,
node2 sent sessionData(20L) to node1, node2 sent sessionData(20L) to node1,
node1 sent sessionData(11L) to node2, node1 sent sessionData(11L) to node2,
node2 sent sessionData(21L) to node1, node2 sent sessionData(21L) to node1,
@ -268,7 +268,7 @@ class StateMachineManagerTests {
assertThatThrownBy { future.getOrThrow() }.isInstanceOf(FlowSessionException::class.java) assertThatThrownBy { future.getOrThrow() }.isInstanceOf(FlowSessionException::class.java)
assertSessionTransfers( assertSessionTransfers(
node1 sent sessionInit(ReceiveThenSuspendFlow::class) to node2, node1 sent sessionInit(ReceiveThenSuspendFlow::class) to node2,
node2 sent sessionConfirm() to node1, node2 sent sessionConfirm(node2) to node1,
node2 sent sessionEnd() to node1 node2 sent sessionEnd() to node1
) )
} }
@ -290,7 +290,7 @@ class StateMachineManagerTests {
private fun sessionInit(flowMarker: KClass<*>, payload: Any? = null) = SessionInit(0, flowMarker.java.name, payload) private fun sessionInit(flowMarker: KClass<*>, payload: Any? = null) = SessionInit(0, flowMarker.java.name, payload)
private fun sessionConfirm() = SessionConfirm(0, 0) private fun sessionConfirm(mockNode: MockNode) = SessionConfirm(0, 0, mockNode.info.legalIdentity)
private fun sessionData(payload: Any) = SessionData(0, payload) private fun sessionData(payload: Any) = SessionData(0, payload)
@ -314,7 +314,7 @@ class StateMachineManagerTests {
return filter { it.message.topicSession == StateMachineManager.sessionTopic }.map { return filter { it.message.topicSession == StateMachineManager.sessionTopic }.map {
val from = it.sender.id val from = it.sender.id
val message = it.message.data.deserialize<SessionMessage>() val message = it.message.data.deserialize<SessionMessage>()
val to = (it.recipients as InMemoryMessagingNetwork.Handle).id val to = (it.recipients as InMemoryMessagingNetwork.PeerHandle).id
SessionTransfer(from, sanitise(message), to) SessionTransfer(from, sanitise(message), to)
} }
} }

View File

@ -4,9 +4,12 @@ import com.google.common.util.concurrent.Futures
import com.google.common.util.concurrent.ListenableFuture import com.google.common.util.concurrent.ListenableFuture
import com.google.common.util.concurrent.SettableFuture import com.google.common.util.concurrent.SettableFuture
import net.corda.core.ThreadBox import net.corda.core.ThreadBox
import net.corda.core.getOrThrow
import net.corda.core.crypto.X509Utilities import net.corda.core.crypto.X509Utilities
import net.corda.core.getOrThrow
import net.corda.core.messaging.* import net.corda.core.messaging.*
import net.corda.core.node.ServiceEntry
import net.corda.core.node.services.PartyInfo
import net.corda.core.node.services.ServiceInfo
import net.corda.core.serialization.SingletonSerializeAsToken import net.corda.core.serialization.SingletonSerializeAsToken
import net.corda.core.utilities.trace import net.corda.core.utilities.trace
import net.corda.node.services.api.MessagingServiceBuilder import net.corda.node.services.api.MessagingServiceBuilder
@ -35,18 +38,20 @@ import kotlin.concurrent.thread
* messages one by one to registered handlers. Alternatively, a messaging system may be manually pumped, in which * messages one by one to registered handlers. Alternatively, a messaging system may be manually pumped, in which
* case no thread is created and a caller is expected to force delivery one at a time (this is useful for unit * case no thread is created and a caller is expected to force delivery one at a time (this is useful for unit
* testing). * testing).
*
* @param random The RNG used to choose which node to send to in case one sends to a service.
*/ */
@ThreadSafe @ThreadSafe
class InMemoryMessagingNetwork(val sendManuallyPumped: Boolean) : SingletonSerializeAsToken() { class InMemoryMessagingNetwork(val sendManuallyPumped: Boolean, val random: SplittableRandom = SplittableRandom()) : SingletonSerializeAsToken() {
companion object { companion object {
val MESSAGES_LOG_NAME = "messages" val MESSAGES_LOG_NAME = "messages"
private val log = LoggerFactory.getLogger(MESSAGES_LOG_NAME) private val log = LoggerFactory.getLogger(MESSAGES_LOG_NAME)
} }
private var counter = 0 // -1 means stopped. private var counter = 0 // -1 means stopped.
private val handleEndpointMap = HashMap<Handle, InMemoryMessaging>() private val handleEndpointMap = HashMap<PeerHandle, InMemoryMessaging>()
data class MessageTransfer(val sender: Handle, val message: Message, val recipients: MessageRecipients) { data class MessageTransfer(val sender: PeerHandle, val message: Message, val recipients: MessageRecipients) {
override fun toString() = "${message.topicSession} from '$sender' to '$recipients'" override fun toString() = "${message.topicSession} from '$sender' to '$recipients'"
} }
@ -64,9 +69,11 @@ class InMemoryMessagingNetwork(val sendManuallyPumped: Boolean) : SingletonSeria
// been created yet. If the node identified by the given handle has gone away/been shut down then messages // been created yet. If the node identified by the given handle has gone away/been shut down then messages
// stack up here waiting for it to come back. The intent of this is to simulate a reliable messaging network. // stack up here waiting for it to come back. The intent of this is to simulate a reliable messaging network.
// The corresponding stream reflects when a message was pumpReceive'd // The corresponding stream reflects when a message was pumpReceive'd
private val messageReceiveQueues = HashMap<Handle, LinkedBlockingQueue<MessageTransfer>>() private val messageReceiveQueues = HashMap<PeerHandle, LinkedBlockingQueue<MessageTransfer>>()
private val _receivedMessages = PublishSubject.create<MessageTransfer>() private val _receivedMessages = PublishSubject.create<MessageTransfer>()
private val serviceToPeersMapping = HashMap<ServiceHandle, HashSet<PeerHandle>>()
val messagesInFlight = ReusableLatch() val messagesInFlight = ReusableLatch()
@Suppress("unused") // Used by the visualiser tool. @Suppress("unused") // Used by the visualiser tool.
@ -90,9 +97,10 @@ class InMemoryMessagingNetwork(val sendManuallyPumped: Boolean) : SingletonSeria
@Synchronized @Synchronized
fun createNode(manuallyPumped: Boolean, fun createNode(manuallyPumped: Boolean,
executor: AffinityExecutor, executor: AffinityExecutor,
database: Database): Pair<Handle, MessagingServiceBuilder<InMemoryMessaging>> { advertisedServices: List<ServiceEntry>,
database: Database): Pair<PeerHandle, MessagingServiceBuilder<InMemoryMessaging>> {
check(counter >= 0) { "In memory network stopped: please recreate." } check(counter >= 0) { "In memory network stopped: please recreate." }
val builder = createNodeWithID(manuallyPumped, counter, executor, database = database) as Builder val builder = createNodeWithID(manuallyPumped, counter, executor, advertisedServices, database = database) as Builder
counter++ counter++
val id = builder.id val id = builder.id
return Pair(id, builder) return Pair(id, builder)
@ -106,10 +114,15 @@ class InMemoryMessagingNetwork(val sendManuallyPumped: Boolean) : SingletonSeria
* @param description text string that identifies this node for message logging (if is enabled) or null to autogenerate. * @param description text string that identifies this node for message logging (if is enabled) or null to autogenerate.
* @param persistenceTx a lambda to wrap message handling in a transaction if necessary. * @param persistenceTx a lambda to wrap message handling in a transaction if necessary.
*/ */
fun createNodeWithID(manuallyPumped: Boolean, id: Int, executor: AffinityExecutor, description: String? = null, fun createNodeWithID(
database: Database) manuallyPumped: Boolean,
id: Int,
executor: AffinityExecutor,
advertisedServices: List<ServiceEntry>,
description: String? = null,
database: Database)
: MessagingServiceBuilder<InMemoryMessaging> { : MessagingServiceBuilder<InMemoryMessaging> {
return Builder(manuallyPumped, Handle(id, description ?: "In memory node $id"), executor, database = database) return Builder(manuallyPumped, PeerHandle(id, description ?: "In memory node $id"), advertisedServices.map(::ServiceHandle), executor, database = database)
} }
interface LatencyCalculator { interface LatencyCalculator {
@ -127,12 +140,20 @@ class InMemoryMessagingNetwork(val sendManuallyPumped: Boolean) : SingletonSeria
} }
@Synchronized @Synchronized
private fun netNodeHasShutdown(handle: Handle) { private fun netNodeHasShutdown(peerHandle: PeerHandle) {
handleEndpointMap.remove(handle) handleEndpointMap.remove(peerHandle)
} }
@Synchronized @Synchronized
private fun getQueueForHandle(recipients: Handle) = messageReceiveQueues.getOrPut(recipients) { LinkedBlockingQueue() } private fun getQueueForPeerHandle(recipients: PeerHandle) = messageReceiveQueues.getOrPut(recipients) { LinkedBlockingQueue() }
@Synchronized
private fun getQueuesForServiceHandle(recipients: ServiceHandle): List<LinkedBlockingQueue<MessageTransfer>> {
return serviceToPeersMapping[recipients]!!.map {
messageReceiveQueues.getOrPut(it) { LinkedBlockingQueue() }
}
}
val everyoneOnline: AllPossibleRecipients = object : AllPossibleRecipients {} val everyoneOnline: AllPossibleRecipients = object : AllPossibleRecipients {}
@ -149,22 +170,35 @@ class InMemoryMessagingNetwork(val sendManuallyPumped: Boolean) : SingletonSeria
messageReceiveQueues.clear() messageReceiveQueues.clear()
} }
inner class Builder(val manuallyPumped: Boolean, val id: Handle, val executor: AffinityExecutor, val database: Database) : MessagingServiceBuilder<InMemoryMessaging> { inner class Builder(
val manuallyPumped: Boolean,
val id: PeerHandle,
val serviceHandles: List<ServiceHandle>,
val executor: AffinityExecutor,
val database: Database) : MessagingServiceBuilder<InMemoryMessaging> {
override fun start(): ListenableFuture<InMemoryMessaging> { override fun start(): ListenableFuture<InMemoryMessaging> {
synchronized(this@InMemoryMessagingNetwork) { synchronized(this@InMemoryMessagingNetwork) {
val node = InMemoryMessaging(manuallyPumped, id, executor, database) val node = InMemoryMessaging(manuallyPumped, id, executor, database)
handleEndpointMap[id] = node handleEndpointMap[id] = node
serviceHandles.forEach {
serviceToPeersMapping.getOrPut(it) { HashSet<PeerHandle>() }.add(id)
Unit
}
return Futures.immediateFuture(node) return Futures.immediateFuture(node)
} }
} }
} }
class Handle(val id: Int, val description: String) : SingleMessageRecipient { class PeerHandle(val id: Int, val description: String) : SingleMessageRecipient {
override fun toString() = description override fun toString() = description
override fun equals(other: Any?) = other is Handle && other.id == id override fun equals(other: Any?) = other is PeerHandle && other.id == id
override fun hashCode() = id.hashCode() override fun hashCode() = id.hashCode()
} }
data class ServiceHandle(val service: ServiceEntry) : MessageRecipientGroup {
override fun toString() = "Service($service)"
}
// If block is set to true this function will only return once a message has been pushed onto the recipients' queues // If block is set to true this function will only return once a message has been pushed onto the recipients' queues
fun pumpSend(block: Boolean): MessageTransfer? { fun pumpSend(block: Boolean): MessageTransfer? {
val transfer = (if (block) messageSendQueue.take() else messageSendQueue.poll()) ?: return null val transfer = (if (block) messageSendQueue.take() else messageSendQueue.poll()) ?: return null
@ -190,12 +224,17 @@ class InMemoryMessagingNetwork(val sendManuallyPumped: Boolean) : SingletonSeria
fun pumpSendInternal(transfer: MessageTransfer) { fun pumpSendInternal(transfer: MessageTransfer) {
when (transfer.recipients) { when (transfer.recipients) {
is Handle -> getQueueForHandle(transfer.recipients).add(transfer) is PeerHandle -> getQueueForPeerHandle(transfer.recipients).add(transfer)
is ServiceHandle -> {
val queues = getQueuesForServiceHandle(transfer.recipients)
val chosedPeerIndex = random.nextInt(queues.size)
queues[chosedPeerIndex].add(transfer)
}
is AllPossibleRecipients -> { is AllPossibleRecipients -> {
// This means all possible recipients _that the network knows about at the time_, not literally everyone // This means all possible recipients _that the network knows about at the time_, not literally everyone
// who joins into the indefinite future. // who joins into the indefinite future.
for (handle in handleEndpointMap.keys) for (handle in handleEndpointMap.keys)
getQueueForHandle(handle).add(transfer) getQueueForPeerHandle(handle).add(transfer)
} }
else -> throw IllegalArgumentException("Unknown type of recipient handle") else -> throw IllegalArgumentException("Unknown type of recipient handle")
} }
@ -211,7 +250,7 @@ class InMemoryMessagingNetwork(val sendManuallyPumped: Boolean) : SingletonSeria
*/ */
@ThreadSafe @ThreadSafe
inner class InMemoryMessaging(private val manuallyPumped: Boolean, inner class InMemoryMessaging(private val manuallyPumped: Boolean,
private val handle: Handle, private val peerHandle: PeerHandle,
private val executor: AffinityExecutor, private val executor: AffinityExecutor,
private val database: Database) : SingletonSerializeAsToken(), MessagingServiceInternal { private val database: Database) : SingletonSerializeAsToken(), MessagingServiceInternal {
inner class Handler(val topicSession: TopicSession, inner class Handler(val topicSession: TopicSession,
@ -228,7 +267,7 @@ class InMemoryMessagingNetwork(val sendManuallyPumped: Boolean) : SingletonSeria
private val state = ThreadBox(InnerState()) private val state = ThreadBox(InnerState())
private val processedMessages: MutableSet<UUID> = Collections.synchronizedSet(HashSet<UUID>()) private val processedMessages: MutableSet<UUID> = Collections.synchronizedSet(HashSet<UUID>())
override val myAddress: Handle get() = handle override val myAddress: PeerHandle get() = peerHandle
private val backgroundThread = if (manuallyPumped) null else private val backgroundThread = if (manuallyPumped) null else
thread(isDaemon = true, name = "In-memory message dispatcher") { thread(isDaemon = true, name = "In-memory message dispatcher") {
@ -241,6 +280,13 @@ class InMemoryMessagingNetwork(val sendManuallyPumped: Boolean) : SingletonSeria
} }
} }
override fun getAddressOfParty(partyInfo: PartyInfo): MessageRecipients {
return when (partyInfo) {
is PartyInfo.Node -> partyInfo.node.address
is PartyInfo.Service -> ServiceHandle(partyInfo.service)
}
}
override fun addMessageHandler(topic: String, sessionID: Long, callback: (ReceivedMessage, MessageHandlerRegistration) -> Unit): MessageHandlerRegistration override fun addMessageHandler(topic: String, sessionID: Long, callback: (ReceivedMessage, MessageHandlerRegistration) -> Unit): MessageHandlerRegistration
= addMessageHandler(TopicSession(topic, sessionID), callback) = addMessageHandler(TopicSession(topic, sessionID), callback)
@ -279,7 +325,7 @@ class InMemoryMessagingNetwork(val sendManuallyPumped: Boolean) : SingletonSeria
backgroundThread.join() backgroundThread.join()
} }
running = false running = false
netNodeHasShutdown(handle) netNodeHasShutdown(peerHandle)
} }
/** Returns the given (topic & session, data) pair as a newly created message object. */ /** Returns the given (topic & session, data) pair as a newly created message object. */
@ -347,7 +393,7 @@ class InMemoryMessagingNetwork(val sendManuallyPumped: Boolean) : SingletonSeria
} }
private fun pumpReceiveInternal(block: Boolean): MessageTransfer? { private fun pumpReceiveInternal(block: Boolean): MessageTransfer? {
val q = getQueueForHandle(handle) val q = getQueueForPeerHandle(peerHandle)
val next = getNextQueue(q, block) ?: return null val next = getNextQueue(q, block) ?: return null
val (transfer, deliverTo) = next val (transfer, deliverTo) = next

View File

@ -5,6 +5,7 @@ import com.google.common.jimfs.Jimfs
import com.google.common.util.concurrent.Futures import com.google.common.util.concurrent.Futures
import net.corda.core.* import net.corda.core.*
import net.corda.core.crypto.Party import net.corda.core.crypto.Party
import net.corda.core.messaging.RPCOps
import net.corda.core.messaging.SingleMessageRecipient import net.corda.core.messaging.SingleMessageRecipient
import net.corda.core.node.CordaPluginRegistry import net.corda.core.node.CordaPluginRegistry
import net.corda.core.node.PhysicalLocation import net.corda.core.node.PhysicalLocation
@ -15,7 +16,6 @@ import net.corda.node.internal.AbstractNode
import net.corda.node.services.api.MessagingServiceInternal import net.corda.node.services.api.MessagingServiceInternal
import net.corda.node.services.config.NodeConfiguration import net.corda.node.services.config.NodeConfiguration
import net.corda.node.services.keys.E2ETestKeyManagementService import net.corda.node.services.keys.E2ETestKeyManagementService
import net.corda.core.messaging.RPCOps
import net.corda.node.services.network.InMemoryNetworkMapService import net.corda.node.services.network.InMemoryNetworkMapService
import net.corda.node.services.network.NetworkMapService import net.corda.node.services.network.NetworkMapService
import net.corda.node.services.transactions.InMemoryUniquenessProvider import net.corda.node.services.transactions.InMemoryUniquenessProvider
@ -118,7 +118,7 @@ class MockNetwork(private val networkSendManuallyPumped: Boolean = false,
// through the java.nio API which we are already mocking via Jimfs. // through the java.nio API which we are already mocking via Jimfs.
override fun makeMessagingService(): MessagingServiceInternal { override fun makeMessagingService(): MessagingServiceInternal {
require(id >= 0) { "Node ID must be zero or positive, was passed: " + id } require(id >= 0) { "Node ID must be zero or positive, was passed: " + id }
return mockNet.messagingNetwork.createNodeWithID(!mockNet.threadPerNode, id, serverThread, configuration.myLegalName, database).start().getOrThrow() return mockNet.messagingNetwork.createNodeWithID(!mockNet.threadPerNode, id, serverThread, makeServiceEntries(), configuration.myLegalName, database).start().getOrThrow()
} }
override fun makeIdentityService() = MockIdentityService(mockNet.identities) override fun makeIdentityService() = MockIdentityService(mockNet.identities)