r3corda wire compatibility

This commit is contained in:
Andras Slemmer 2018-02-08 15:13:25 +00:00
parent 1902a4f11e
commit 0a88b76e46
19 changed files with 489 additions and 485 deletions

3
.idea/compiler.xml generated
View File

@ -72,6 +72,7 @@
<module name="irs-demo-web_test" target="1.8" />
<module name="irs-demo_integrationTest" target="1.8" />
<module name="irs-demo_main" target="1.8" />
<module name="irs-demo_systemTest" target="1.8" />
<module name="irs-demo_test" target="1.8" />
<module name="isolated_main" target="1.8" />
<module name="isolated_test" target="1.8" />
@ -159,4 +160,4 @@
<component name="JavacSettings">
<option name="ADDITIONAL_OPTIONS_STRING" value="-parameters" />
</component>
</project>
</project>

View File

@ -10,12 +10,18 @@ import net.corda.core.identity.Party
import net.corda.core.messaging.MessageRecipients
import net.corda.core.serialization.deserialize
import net.corda.core.serialization.serialize
import net.corda.core.utilities.OpaqueBytes
import net.corda.core.utilities.getOrThrow
import net.corda.core.utilities.unwrap
import net.corda.node.internal.StartedNode
import net.corda.node.services.messaging.Message
import net.corda.node.services.statemachine.SessionData
import net.corda.testing.node.*
import net.corda.node.services.statemachine.DataSessionMessage
import net.corda.node.services.statemachine.ExistingSessionMessage
import net.corda.testing.node.InMemoryMessagingNetwork
import net.corda.testing.node.MessagingServiceSpy
import net.corda.testing.node.MockNetwork
import net.corda.testing.node.setMessagingServiceSpy
import net.corda.testing.node.startFlow
import org.junit.After
import org.junit.Before
import org.junit.Rule
@ -79,12 +85,12 @@ class TutorialMockNetwork {
// modify message if it's 1
nodeB.setMessagingServiceSpy(object : MessagingServiceSpy(nodeB.network) {
override fun send(message: Message, target: MessageRecipients, retryId: Long?, sequenceKey: Any, acknowledgementHandler: (() -> Unit)?) {
val messageData = message.data.deserialize<Any>()
if (messageData is SessionData && messageData.payload.deserialize() == 1) {
val alteredMessageData = SessionData(messageData.recipientSessionId, 99.serialize()).serialize().bytes
messagingService.send(InMemoryMessagingNetwork.InMemoryMessage(message.topicSession, alteredMessageData, message.uniqueMessageId), target, retryId)
override fun send(message: Message, target: MessageRecipients, retryId: Long?, sequenceKey: Any) {
val messageData = message.data.deserialize<Any>() as? ExistingSessionMessage
val payload = messageData?.payload
if (payload is DataSessionMessage && payload.payload.deserialize() == 1) {
val alteredMessageData = messageData.copy(payload = payload.copy(99.serialize())).serialize().bytes
messagingService.send(InMemoryMessagingNetwork.InMemoryMessage(message.topic, OpaqueBytes(alteredMessageData), message.uniqueMessageId), target, retryId)
} else {
messagingService.send(message, target, retryId)
}

View File

@ -3,7 +3,7 @@ package net.corda.nodeapi.internal.serialization
import com.esotericsoftware.kryo.Kryo
import com.esotericsoftware.kryo.util.DefaultClassResolver
import net.corda.core.serialization.*
import net.corda.node.services.statemachine.SessionData
import net.corda.node.services.statemachine.DataSessionMessage
import net.corda.nodeapi.internal.serialization.amqp.DeserializationInput
import net.corda.nodeapi.internal.serialization.amqp.Envelope
import net.corda.nodeapi.internal.serialization.amqp.SerializerFactory
@ -47,17 +47,17 @@ class ListsSerializationTest {
@Test
fun `check list can be serialized as part of SessionData`() {
run {
val sessionData = SessionData(123, listOf(1).serialize())
val sessionData = DataSessionMessage(listOf(1).serialize())
assertEqualAfterRoundTripSerialization(sessionData)
assertEquals(listOf(1), sessionData.payload.deserialize())
}
run {
val sessionData = SessionData(123, listOf(1, 2).serialize())
val sessionData = DataSessionMessage(listOf(1, 2).serialize())
assertEqualAfterRoundTripSerialization(sessionData)
assertEquals(listOf(1, 2), sessionData.payload.deserialize())
}
run {
val sessionData = SessionData(123, emptyList<Int>().serialize())
val sessionData = DataSessionMessage(emptyList<Int>().serialize())
assertEqualAfterRoundTripSerialization(sessionData)
assertEquals(emptyList<Int>(), sessionData.payload.deserialize())
}

View File

@ -6,7 +6,7 @@ import net.corda.core.identity.CordaX500Name
import net.corda.core.serialization.CordaSerializable
import net.corda.core.serialization.deserialize
import net.corda.core.serialization.serialize
import net.corda.node.services.statemachine.SessionData
import net.corda.node.services.statemachine.DataSessionMessage
import net.corda.nodeapi.internal.serialization.kryo.kryoMagic
import net.corda.testing.core.SerializationEnvironmentRule
import net.corda.testing.internal.amqpSpecific
@ -41,7 +41,7 @@ class MapsSerializationTest {
@Test
fun `check list can be serialized as part of SessionData`() {
val sessionData = SessionData(123, smallMap.serialize())
val sessionData = DataSessionMessage(smallMap.serialize())
assertEqualAfterRoundTripSerialization(sessionData)
assertEquals(smallMap, sessionData.payload.deserialize())
}

View File

@ -4,10 +4,10 @@ import com.esotericsoftware.kryo.Kryo
import com.esotericsoftware.kryo.util.DefaultClassResolver
import net.corda.core.serialization.deserialize
import net.corda.core.serialization.serialize
import net.corda.node.services.statemachine.SessionData
import net.corda.node.services.statemachine.DataSessionMessage
import net.corda.nodeapi.internal.serialization.kryo.kryoMagic
import net.corda.testing.internal.kryoSpecific
import net.corda.testing.core.SerializationEnvironmentRule
import net.corda.testing.internal.kryoSpecific
import org.junit.Assert.assertArrayEquals
import org.junit.Assert.assertEquals
import org.junit.Rule
@ -34,17 +34,17 @@ class SetsSerializationTest {
@Test
fun `check set can be serialized as part of SessionData`() {
run {
val sessionData = SessionData(123, setOf(1).serialize())
val sessionData = DataSessionMessage(setOf(1).serialize())
assertEqualAfterRoundTripSerialization(sessionData)
assertEquals(setOf(1), sessionData.payload.deserialize())
}
run {
val sessionData = SessionData(123, setOf(1, 2).serialize())
val sessionData = DataSessionMessage(setOf(1, 2).serialize())
assertEqualAfterRoundTripSerialization(sessionData)
assertEquals(setOf(1, 2), sessionData.payload.deserialize())
}
run {
val sessionData = SessionData(123, emptySet<Int>().serialize())
val sessionData = DataSessionMessage(emptySet<Int>().serialize())
assertEqualAfterRoundTripSerialization(sessionData)
assertEquals(emptySet<Int>(), sessionData.payload.deserialize())
}

View File

@ -1,9 +1,9 @@
package net.corda.services.messaging
import net.corda.core.concurrent.CordaFuture
import net.corda.core.crypto.random63BitValue
import net.corda.core.identity.CordaX500Name
import net.corda.core.internal.concurrent.map
import net.corda.core.internal.concurrent.openFuture
import net.corda.core.internal.randomOrNull
import net.corda.core.messaging.MessageRecipients
import net.corda.core.messaging.SingleMessageRecipient
@ -14,7 +14,9 @@ import net.corda.core.utilities.getOrThrow
import net.corda.core.utilities.seconds
import net.corda.node.internal.Node
import net.corda.node.internal.StartedNode
import net.corda.node.services.messaging.*
import net.corda.node.services.messaging.MessagingService
import net.corda.node.services.messaging.ReceivedMessage
import net.corda.node.services.messaging.send
import net.corda.testing.core.ALICE_NAME
import net.corda.testing.core.chooseIdentity
import net.corda.testing.driver.DriverDSL
@ -27,6 +29,7 @@ import org.junit.Test
import java.util.*
import java.util.concurrent.CountDownLatch
import java.util.concurrent.TimeUnit
import java.util.concurrent.atomic.AtomicBoolean
import java.util.concurrent.atomic.AtomicInteger
class P2PMessagingTest {
@ -50,19 +53,12 @@ class P2PMessagingTest {
alice.network.getAddressOfParty(getPartyInfo(notaryParty)!!)
}
val dummyTopic = "dummy.topic"
val responseMessage = "response"
val crashingNodes = simulateCrashingNodes(distributedServiceNodes, dummyTopic, responseMessage)
val crashingNodes = simulateCrashingNodes(distributedServiceNodes, responseMessage)
// Send a single request with retry
val responseFuture = with(alice.network) {
val request = TestRequest(replyTo = myAddress)
val responseFuture = onNext<Any>(dummyTopic, request.sessionID)
val msg = createMessage(TopicSession(dummyTopic), data = request.serialize().bytes)
send(msg, serviceAddress, retryId = request.sessionID)
responseFuture
}
val responseFuture = alice.receiveFrom(serviceAddress, retryId = 0)
crashingNodes.firstRequestReceived.await(5, TimeUnit.SECONDS)
// The request wasn't successful.
assertThat(responseFuture.isDone).isFalse()
@ -83,19 +79,12 @@ class P2PMessagingTest {
alice.network.getAddressOfParty(getPartyInfo(notaryParty)!!)
}
val dummyTopic = "dummy.topic"
val responseMessage = "response"
val crashingNodes = simulateCrashingNodes(distributedServiceNodes, dummyTopic, responseMessage)
val sessionId = random63BitValue()
val crashingNodes = simulateCrashingNodes(distributedServiceNodes, responseMessage)
// Send a single request with retry
with(alice.network) {
val request = TestRequest(sessionId, myAddress)
val msg = createMessage(TopicSession(dummyTopic), data = request.serialize().bytes)
send(msg, serviceAddress, retryId = request.sessionID)
}
alice.receiveFrom(serviceAddress, retryId = 0)
// Wait until the first request is received
crashingNodes.firstRequestReceived.await()
@ -108,7 +97,13 @@ class P2PMessagingTest {
// Restart the node and expect a response
val aliceRestarted = startAlice()
val response = aliceRestarted.network.onNext<Any>(dummyTopic, sessionId).getOrThrow()
val responseFuture = openFuture<Any>()
aliceRestarted.network.runOnNextMessage("test.response") {
responseFuture.set(it.data.deserialize())
}
val response = responseFuture.getOrThrow()
assertThat(crashingNodes.requestsReceived.get()).isGreaterThan(numberOfRequestsReceived)
assertThat(response).isEqualTo(responseMessage)
}
@ -133,11 +128,12 @@ class P2PMessagingTest {
)
/**
* Sets up the [distributedServiceNodes] to respond to [dummyTopic] requests. All nodes will receive requests and
* either ignore them or respond, depending on the value of [CrashingNodes.ignoreRequests], initially set to true.
* This may be used to simulate scenarios where nodes receive request messages but crash before sending back a response.
* Sets up the [distributedServiceNodes] to respond to "test.request" requests. All nodes will receive requests and
* either ignore them or respond to "test.response", depending on the value of [CrashingNodes.ignoreRequests],
* initially set to true. This may be used to simulate scenarios where nodes receive request messages but crash
* before sending back a response.
*/
private fun simulateCrashingNodes(distributedServiceNodes: List<StartedNode<*>>, dummyTopic: String, responseMessage: String): CrashingNodes {
private fun simulateCrashingNodes(distributedServiceNodes: List<StartedNode<*>>, responseMessage: String): CrashingNodes {
val crashingNodes = CrashingNodes(
requestsReceived = AtomicInteger(0),
firstRequestReceived = CountDownLatch(1),
@ -146,7 +142,7 @@ class P2PMessagingTest {
distributedServiceNodes.forEach {
val nodeName = it.info.chooseIdentity().name
it.network.addMessageHandler(dummyTopic) { netMessage, _ ->
it.network.addMessageHandler("test.request") { netMessage, _ ->
crashingNodes.requestsReceived.incrementAndGet()
crashingNodes.firstRequestReceived.countDown()
// The node which receives the first request will ignore all requests
@ -158,7 +154,7 @@ class P2PMessagingTest {
} else {
println("sending response")
val request = netMessage.data.deserialize<TestRequest>()
val response = it.network.createMessage(dummyTopic, request.sessionID, responseMessage.serialize().bytes)
val response = it.network.createMessage("test.response", responseMessage.serialize().bytes)
it.network.send(response, request.replyTo)
}
}
@ -188,19 +184,39 @@ class P2PMessagingTest {
}
private fun StartedNode<*>.respondWith(message: Any) {
network.addMessageHandler(javaClass.name) { netMessage, _ ->
network.addMessageHandler("test.request") { netMessage, _ ->
val request = netMessage.data.deserialize<TestRequest>()
val response = network.createMessage(javaClass.name, request.sessionID, message.serialize().bytes)
val response = network.createMessage("test.response", message.serialize().bytes)
network.send(response, request.replyTo)
}
}
private fun StartedNode<*>.receiveFrom(target: MessageRecipients): CordaFuture<Any> {
val request = TestRequest(replyTo = network.myAddress)
return network.sendRequest(javaClass.name, request, target)
private fun StartedNode<*>.receiveFrom(target: MessageRecipients, retryId: Long? = null): CordaFuture<Any> {
val response = openFuture<Any>()
network.runOnNextMessage("test.response") { netMessage ->
response.set(netMessage.data.deserialize())
}
network.send("test.request", TestRequest(replyTo = network.myAddress), target, retryId = retryId)
return response
}
/**
* Registers a handler for the given topic and session that runs the given callback with the message and then removes
* itself. This is useful for one-shot handlers that aren't supposed to stick around permanently. Note that this callback
* doesn't take the registration object, unlike the callback to [MessagingService.addMessageHandler].
*
* @param topic identifier for the topic and session to listen for messages arriving on.
*/
inline fun MessagingService.runOnNextMessage(topic: String, crossinline callback: (ReceivedMessage) -> Unit) {
val consumed = AtomicBoolean()
addMessageHandler(topic) { msg, reg ->
removeMessageHandler(reg)
check(!consumed.getAndSet(true)) { "Called more than once" }
check(msg.topic == topic) { "Topic/session mismatch: ${msg.topic} vs $topic" }
callback(msg)
}
}
@CordaSerializable
private data class TestRequest(override val sessionID: Long = random63BitValue(),
override val replyTo: SingleMessageRecipient) : ServiceRequestMessage
private data class TestRequest(val replyTo: SingleMessageRecipient)
}

View File

@ -1,18 +1,15 @@
package net.corda.node.services.messaging
import net.corda.core.concurrent.CordaFuture
import co.paralleluniverse.fibers.Suspendable
import net.corda.core.identity.CordaX500Name
import net.corda.core.internal.concurrent.openFuture
import net.corda.core.internal.uncheckedCast
import net.corda.core.messaging.MessageRecipients
import net.corda.core.messaging.SingleMessageRecipient
import net.corda.core.node.services.PartyInfo
import net.corda.core.serialization.CordaSerializable
import net.corda.core.serialization.deserialize
import net.corda.core.serialization.serialize
import net.corda.core.utilities.ByteSequence
import java.time.Instant
import java.util.*
import java.util.concurrent.atomic.AtomicBoolean
import javax.annotation.concurrent.ThreadSafe
/**
@ -27,29 +24,6 @@ import javax.annotation.concurrent.ThreadSafe
*/
@ThreadSafe
interface MessagingService {
companion object {
/**
* Session ID to use for services listening for the first message in a session (before a
* specific session ID has been established).
*/
val DEFAULT_SESSION_ID = 0L
}
/**
* The provided function will be invoked for each received message whose topic matches the given string. The callback
* will run on threads provided by the messaging service, and the callback is expected to be thread safe as a result.
*
* The returned object is an opaque handle that may be used to un-register handlers later with [removeMessageHandler].
* The handle is passed to the callback as well, to avoid race conditions whereby the callback wants to unregister
* itself and yet addMessageHandler hasn't returned the handle yet.
*
* @param topic identifier for the general subject of the message, for example "platform.network_map.fetch".
* The topic can be the empty string to match all messages (session ID must be [DEFAULT_SESSION_ID]).
* @param sessionID identifier for the session the message is part of. For services listening before
* a session is established, use [DEFAULT_SESSION_ID].
*/
fun addMessageHandler(topic: String = "", sessionID: Long = DEFAULT_SESSION_ID, callback: (ReceivedMessage, MessageHandlerRegistration) -> Unit): MessageHandlerRegistration
/**
* The provided function will be invoked for each received message whose topic and session matches. The callback
* will run on the main server thread provided when the messaging service is constructed, and a database
@ -59,9 +33,9 @@ interface MessagingService {
* The handle is passed to the callback as well, to avoid race conditions whereby the callback wants to unregister
* itself and yet addMessageHandler hasn't returned the handle yet.
*
* @param topicSession identifier for the topic and session to listen for messages arriving on.
* @param topic identifier for the topic to listen for messages arriving on.
*/
fun addMessageHandler(topicSession: TopicSession, callback: (ReceivedMessage, MessageHandlerRegistration) -> Unit): MessageHandlerRegistration
fun addMessageHandler(topic: String, callback: (ReceivedMessage, MessageHandlerRegistration) -> Unit): MessageHandlerRegistration
/**
* Removes a handler given the object returned from [addMessageHandler]. The callback will no longer be invoked once
@ -86,15 +60,13 @@ interface MessagingService {
* @param sequenceKey an object that may be used to enable a parallel [MessagingService] implementation. Two
* subsequent send()s with the same [sequenceKey] (up to equality) are guaranteed to be delivered in the same
* sequence the send()s were called. By default this is chosen conservatively to be [target].
* @param acknowledgementHandler if non-null this handler will be called once the sent message has been committed by
* the broker. Note that if specified [send] itself may return earlier than the commit.
*/
@Suspendable
fun send(
message: Message,
target: MessageRecipients,
retryId: Long? = null,
sequenceKey: Any = target,
acknowledgementHandler: (() -> Unit)? = null
sequenceKey: Any = target
)
/** A message with a target and sequenceKey specified. */
@ -110,12 +82,9 @@ interface MessagingService {
* implementation.
*
* @param addressedMessages The list of messages together with the recipients, retry ids and sequence keys.
* @param retryId if provided the message will be scheduled for redelivery until [cancelRedelivery] is called for this id.
* Note that this feature should only be used when the target is an idempotent distributed service, e.g. a notary.
* @param acknowledgementHandler if non-null this handler will be called once all sent messages have been committed
* by the broker. Note that if specified [send] itself may return earlier than the commit.
*/
fun send(addressedMessages: List<AddressedMessage>, acknowledgementHandler: (() -> Unit)? = null)
@Suspendable
fun send(addressedMessages: List<AddressedMessage>)
/** Cancels the scheduled message redelivery for the specified [retryId] */
fun cancelRedelivery(retryId: Long)
@ -123,9 +92,9 @@ interface MessagingService {
/**
* Returns an initialised [Message] with the current time, etc, already filled in.
*
* @param topicSession identifier for the topic and session the message is sent to.
* @param topic identifier for the topic the message is sent to.
*/
fun createMessage(topicSession: TopicSession, data: ByteArray, uuid: UUID = UUID.randomUUID()): Message
fun createMessage(topic: String, data: ByteArray, deduplicationId: String = UUID.randomUUID().toString()): Message
/** Given information about either a specific node or a service returns its corresponding address */
fun getAddressOfParty(partyInfo: PartyInfo): MessageRecipients
@ -134,86 +103,12 @@ interface MessagingService {
val myAddress: SingleMessageRecipient
}
/**
* Returns an initialised [Message] with the current time, etc, already filled in.
*
* @param topic identifier for the general subject of the message, for example "platform.network_map.fetch".
* Must not be blank.
* @param sessionID identifier for the session the message is part of. For messages sent to services before the
* construction of a session, use [DEFAULT_SESSION_ID].
*/
fun MessagingService.createMessage(topic: String, sessionID: Long = MessagingService.DEFAULT_SESSION_ID, data: ByteArray): Message
= createMessage(TopicSession(topic, sessionID), data)
/**
* Registers a handler for the given topic and session ID that runs the given callback with the message and then removes
* itself. This is useful for one-shot handlers that aren't supposed to stick around permanently. Note that this callback
* doesn't take the registration object, unlike the callback to [MessagingService.addMessageHandler], as the handler is
* automatically deregistered before the callback runs.
*
* @param topic identifier for the general subject of the message, for example "platform.network_map.fetch".
* The topic can be the empty string to match all messages (session ID must be [DEFAULT_SESSION_ID]).
* @param sessionID identifier for the session the message is part of. For services listening before
* a session is established, use [DEFAULT_SESSION_ID].
*/
fun MessagingService.runOnNextMessage(topic: String, sessionID: Long, callback: (ReceivedMessage) -> Unit)
= runOnNextMessage(TopicSession(topic, sessionID), callback)
/**
* Registers a handler for the given topic and session that runs the given callback with the message and then removes
* itself. This is useful for one-shot handlers that aren't supposed to stick around permanently. Note that this callback
* doesn't take the registration object, unlike the callback to [MessagingService.addMessageHandler].
*
* @param topicSession identifier for the topic and session to listen for messages arriving on.
*/
inline fun MessagingService.runOnNextMessage(topicSession: TopicSession, crossinline callback: (ReceivedMessage) -> Unit) {
val consumed = AtomicBoolean()
addMessageHandler(topicSession) { msg, reg ->
removeMessageHandler(reg)
check(!consumed.getAndSet(true)) { "Called more than once" }
check(msg.topicSession == topicSession) { "Topic/session mismatch: ${msg.topicSession} vs $topicSession" }
callback(msg)
}
}
/**
* Returns a [CordaFuture] of the next message payload ([Message.data]) which is received on the given topic and sessionId.
* The payload is deserialized to an object of type [M]. Any exceptions thrown will be captured by the future.
*/
fun <M : Any> MessagingService.onNext(topic: String, sessionId: Long): CordaFuture<M> {
val messageFuture = openFuture<M>()
runOnNextMessage(topic, sessionId) { message ->
messageFuture.capture {
uncheckedCast(message.data.deserialize<Any>())
}
}
return messageFuture
}
fun MessagingService.send(topic: String, sessionID: Long, payload: Any, to: MessageRecipients, uuid: UUID = UUID.randomUUID()) {
send(TopicSession(topic, sessionID), payload, to, uuid)
}
fun MessagingService.send(topicSession: TopicSession, payload: Any, to: MessageRecipients, uuid: UUID = UUID.randomUUID(), retryId: Long? = null) {
send(createMessage(topicSession, payload.serialize().bytes, uuid), to, retryId)
}
fun MessagingService.send(topicSession: String, payload: Any, to: MessageRecipients, deduplicationId: String = UUID.randomUUID().toString(), retryId: Long? = null)
= send(createMessage(topicSession, payload.serialize().bytes, deduplicationId), to, retryId)
interface MessageHandlerRegistration
/**
* An identifier for the endpoint [MessagingService] message handlers listen at.
*
* @param topic identifier for the general subject of the message, for example "platform.network_map.fetch".
* The topic can be the empty string to match all messages (session ID must be [DEFAULT_SESSION_ID]).
* @param sessionID identifier for the session the message is part of. For services listening before
* a session is established, use [DEFAULT_SESSION_ID].
*/
@CordaSerializable
data class TopicSession(val topic: String, val sessionID: Long = MessagingService.DEFAULT_SESSION_ID) {
fun isBlank() = topic.isBlank() && sessionID == MessagingService.DEFAULT_SESSION_ID
override fun toString(): String = "$topic.$sessionID"
}
/**
* A message is defined, at this level, to be a (topic, timestamp, byte arrays) triple, where the topic is a string in
* Java-style reverse dns form, with "platform." being a prefix reserved by the platform for its own use. Vendor
@ -226,10 +121,10 @@ data class TopicSession(val topic: String, val sessionID: Long = MessagingServic
*/
@CordaSerializable
interface Message {
val topicSession: TopicSession
val data: ByteArray
val topic: String
val data: ByteSequence
val debugTimestamp: Instant
val uniqueMessageId: UUID
val uniqueMessageId: String
}
// TODO Have ReceivedMessage point to the TLS certificate of the peer, and [peer] would simply be the subject DN of that.

View File

@ -13,10 +13,7 @@ import net.corda.core.serialization.SerializationDefaults
import net.corda.core.serialization.SingletonSerializeAsToken
import net.corda.core.serialization.deserialize
import net.corda.core.serialization.serialize
import net.corda.core.utilities.NetworkHostAndPort
import net.corda.core.utilities.contextLogger
import net.corda.core.utilities.sequence
import net.corda.core.utilities.trace
import net.corda.core.utilities.*
import net.corda.node.VersionInfo
import net.corda.node.services.api.NetworkMapCacheInternal
import net.corda.node.services.config.NodeConfiguration
@ -98,20 +95,19 @@ class P2PMessagingClient(config: NodeConfiguration,
// that will handle messages, like a URL) with the terminology used by underlying MQ libraries, to avoid
// confusion.
private val topicProperty = SimpleString("platform-topic")
private val sessionIdProperty = SimpleString("session-id")
private val cordaVendorProperty = SimpleString("corda-vendor")
private val releaseVersionProperty = SimpleString("release-version")
private val platformVersionProperty = SimpleString("platform-version")
private val amqDelayMillis = System.getProperty("amq.delivery.delay.ms", "0").toInt()
private val messageMaxRetryCount: Int = 3
fun createProcessedMessage(): AppendOnlyPersistentMap<UUID, Instant, ProcessedMessage, String> {
fun createProcessedMessage(): AppendOnlyPersistentMap<String, Instant, ProcessedMessage, String> {
return AppendOnlyPersistentMap(
toPersistentEntityKey = { it.toString() },
fromPersistentEntity = { Pair(UUID.fromString(it.uuid), it.insertionTime) },
toPersistentEntity = { key: UUID, value: Instant ->
toPersistentEntityKey = { it },
fromPersistentEntity = { Pair(it.uuid, it.insertionTime) },
toPersistentEntity = { key: String, value: Instant ->
ProcessedMessage().apply {
uuid = key.toString()
uuid = key
insertionTime = value
}
},
@ -139,9 +135,9 @@ class P2PMessagingClient(config: NodeConfiguration,
)
}
private class NodeClientMessage(override val topicSession: TopicSession, override val data: ByteArray, override val uniqueMessageId: UUID) : Message {
private class NodeClientMessage(override val topic: String, override val data: ByteSequence, override val uniqueMessageId: String) : Message {
override val debugTimestamp: Instant = Instant.now()
override fun toString() = "$topicSession#${String(data)}"
override fun toString() = "$topic#${String(data.bytes)}"
}
}
@ -160,7 +156,7 @@ class P2PMessagingClient(config: NodeConfiguration,
private val scheduledMessageRedeliveries = ConcurrentHashMap<Long, ScheduledFuture<*>>()
/** A registration to handle messages of different types */
data class Handler(val topicSession: TopicSession,
data class Handler(val topic: String,
val callback: (ReceivedMessage, MessageHandlerRegistration) -> Unit) : MessageHandlerRegistration
private val cordaVendor = SimpleString(versionInfo.vendor)
@ -181,7 +177,7 @@ class P2PMessagingClient(config: NodeConfiguration,
@javax.persistence.Table(name = "${NODE_DATABASE_PREFIX}message_ids")
class ProcessedMessage(
@Id
@Column(name = "message_id", length = 36)
@Column(name = "message_id", length = 64)
var uuid: String = "",
@Column(name = "insertion_time")
@ -192,7 +188,7 @@ class P2PMessagingClient(config: NodeConfiguration,
@javax.persistence.Table(name = "${NODE_DATABASE_PREFIX}message_retry")
class RetryMessage(
@Id
@Column(name = "message_id", length = 36)
@Column(name = "message_id", length = 64)
var key: Long = 0,
@Lob
@ -383,14 +379,13 @@ class P2PMessagingClient(config: NodeConfiguration,
private fun artemisToCordaMessage(message: ClientMessage): ReceivedMessage? {
try {
val topic = message.required(topicProperty) { getStringProperty(it) }
val sessionID = message.required(sessionIdProperty) { getLongProperty(it) }
val user = requireNotNull(message.getStringProperty(HDR_VALIDATED_USER)) { "Message is not authenticated" }
val platformVersion = message.required(platformVersionProperty) { getIntProperty(it) }
// Use the magic deduplication property built into Artemis as our message identity too
val uuid = message.required(HDR_DUPLICATE_DETECTION_ID) { UUID.fromString(message.getStringProperty(it)) }
log.info("Received message from: ${message.address} user: $user topic: $topic sessionID: $sessionID uuid: $uuid")
val uuid = message.required(HDR_DUPLICATE_DETECTION_ID) { message.getStringProperty(it) }
log.info("Received message from: ${message.address} user: $user topic: $topic uuid: $uuid")
return ArtemisReceivedMessage(TopicSession(topic, sessionID), CordaX500Name.parse(user), platformVersion, uuid, message)
return ArtemisReceivedMessage(topic, CordaX500Name.parse(user), platformVersion, uuid, message)
} catch (e: Exception) {
log.error("Unable to process message, ignoring it: $message", e)
return null
@ -402,21 +397,21 @@ class P2PMessagingClient(config: NodeConfiguration,
return extractor(key)
}
private class ArtemisReceivedMessage(override val topicSession: TopicSession,
private class ArtemisReceivedMessage(override val topic: String,
override val peer: CordaX500Name,
override val platformVersion: Int,
override val uniqueMessageId: UUID,
override val uniqueMessageId: String,
private val message: ClientMessage) : ReceivedMessage {
override val data: ByteArray by lazy { ByteArray(message.bodySize).apply { message.bodyBuffer.readBytes(this) } }
override val data: ByteSequence by lazy { OpaqueBytes(ByteArray(message.bodySize).apply { message.bodyBuffer.readBytes(this) }) }
override val debugTimestamp: Instant get() = Instant.ofEpochMilli(message.timestamp)
override fun toString() = "${topicSession.topic}#${data.sequence()}"
override fun toString() = "$topic#$data"
}
private fun deliver(msg: ReceivedMessage): Boolean {
state.checkNotLocked()
// Because handlers is a COW list, the loop inside filter will operate on a snapshot. Handlers being added
// or removed whilst the filter is executing will not affect anything.
val deliverTo = handlers.filter { it.topicSession.isBlank() || it.topicSession == msg.topicSession }
val deliverTo = handlers.filter { it.topic.isBlank() || it.topic== msg.topic }
try {
// This will perform a BLOCKING call onto the executor. Thus if the handlers are slow, we will
// be slow, and Artemis can handle that case intelligently. We don't just invoke the handler
@ -429,11 +424,11 @@ class P2PMessagingClient(config: NodeConfiguration,
nodeExecutor.fetchFrom {
database.transaction {
if (msg.uniqueMessageId in processedMessages) {
log.trace { "Discard duplicate message ${msg.uniqueMessageId} for ${msg.topicSession}" }
log.trace { "Discard duplicate message ${msg.uniqueMessageId} for ${msg.topic}" }
} else {
if (deliverTo.isEmpty()) {
// TODO: Implement dead letter queue, and send it there.
log.warn("Received message ${msg.uniqueMessageId} for ${msg.topicSession} that doesn't have any registered handlers yet")
log.warn("Received message ${msg.uniqueMessageId} for ${msg.topic} that doesn't have any registered handlers yet")
} else {
callHandlers(msg, deliverTo)
}
@ -443,7 +438,7 @@ class P2PMessagingClient(config: NodeConfiguration,
}
}
} catch (e: Exception) {
log.error("Caught exception whilst executing message handler for ${msg.topicSession}", e)
log.error("Caught exception whilst executing message handler for ${msg.topic}", e)
}
return true
}
@ -501,7 +496,7 @@ class P2PMessagingClient(config: NodeConfiguration,
}
}
override fun send(message: Message, target: MessageRecipients, retryId: Long?, sequenceKey: Any, acknowledgementHandler: (() -> Unit)?) {
override fun send(message: Message, target: MessageRecipients, retryId: Long?, sequenceKey: Any) {
// We have to perform sending on a different thread pool, since using the same pool for messaging and
// fibers leads to Netty buffer memory leaks, caused by both Netty and Quasar fiddling with thread-locals.
messagingExecutor.fetchFrom {
@ -512,20 +507,18 @@ class P2PMessagingClient(config: NodeConfiguration,
putStringProperty(cordaVendorProperty, cordaVendor)
putStringProperty(releaseVersionProperty, releaseVersion)
putIntProperty(platformVersionProperty, versionInfo.platformVersion)
putStringProperty(topicProperty, SimpleString(message.topicSession.topic))
putLongProperty(sessionIdProperty, message.topicSession.sessionID)
writeBodyBufferBytes(message.data)
putStringProperty(topicProperty, SimpleString(message.topic))
writeBodyBufferBytes(message.data.bytes)
// Use the magic deduplication property built into Artemis as our message identity too
putStringProperty(HDR_DUPLICATE_DETECTION_ID, SimpleString(message.uniqueMessageId.toString()))
// For demo purposes - if set then add a delay to messages in order to demonstrate that the flows are doing as intended
if (amqDelayMillis > 0 && message.topicSession.topic == StateMachineManagerImpl.sessionTopic.topic) {
if (amqDelayMillis > 0 && message.topic == StateMachineManagerImpl.sessionTopic) {
putLongProperty(HDR_SCHEDULED_DELIVERY_TIME, System.currentTimeMillis() + amqDelayMillis)
}
}
log.trace {
"Send to: $mqAddress topic: ${message.topicSession.topic} " +
"sessionID: ${message.topicSession.sessionID} uuid: ${message.uniqueMessageId}"
"Send to: $mqAddress topic: ${message.topic} uuid: ${message.uniqueMessageId}"
}
artemis.producer.send(mqAddress, artemisMessage)
retryId?.let {
@ -539,14 +532,12 @@ class P2PMessagingClient(config: NodeConfiguration,
}
}
}
acknowledgementHandler?.invoke()
}
override fun send(addressedMessages: List<MessagingService.AddressedMessage>, acknowledgementHandler: (() -> Unit)?) {
override fun send(addressedMessages: List<MessagingService.AddressedMessage>) {
for ((message, target, retryId, sequenceKey) in addressedMessages) {
send(message, target, retryId, sequenceKey, null)
send(message, target, retryId, sequenceKey)
}
acknowledgementHandler?.invoke()
}
private fun sendWithRetry(retryCount: Int, address: String, message: ClientMessage, retryId: Long) {
@ -622,15 +613,9 @@ class P2PMessagingClient(config: NodeConfiguration,
}
override fun addMessageHandler(topic: String,
sessionID: Long,
callback: (ReceivedMessage, MessageHandlerRegistration) -> Unit): MessageHandlerRegistration {
return addMessageHandler(TopicSession(topic, sessionID), callback)
}
override fun addMessageHandler(topicSession: TopicSession,
callback: (ReceivedMessage, MessageHandlerRegistration) -> Unit): MessageHandlerRegistration {
require(!topicSession.isBlank()) { "Topic must not be blank, as the empty topic is a special case." }
val handler = Handler(topicSession, callback)
require(!topic.isBlank()) { "Topic must not be blank, as the empty topic is a special case." }
val handler = Handler(topic, callback)
handlers.add(handler)
return handler
}
@ -639,9 +624,9 @@ class P2PMessagingClient(config: NodeConfiguration,
handlers.remove(registration)
}
override fun createMessage(topicSession: TopicSession, data: ByteArray, uuid: UUID): Message {
override fun createMessage(topic: String, data: ByteArray, deduplicationId: String): Message {
// TODO: We could write an object that proxies directly to an underlying MQ message here and avoid copying.
return NodeClientMessage(topicSession, data, uuid)
return NodeClientMessage(topic, OpaqueBytes(data), deduplicationId)
}
// TODO Rethink PartyInfo idea and merging PeerAddress/ServiceAddress (the only difference is that Service address doesn't hold host and port)

View File

@ -1,27 +0,0 @@
package net.corda.node.services.messaging
import net.corda.core.concurrent.CordaFuture
import net.corda.core.messaging.MessageRecipients
import net.corda.core.messaging.SingleMessageRecipient
import net.corda.core.serialization.CordaSerializable
/**
* Abstract superclass for request messages sent to services which expect a reply.
*/
@CordaSerializable
interface ServiceRequestMessage {
val sessionID: Long
val replyTo: SingleMessageRecipient
}
/**
* Sends a [ServiceRequestMessage] to [target] and returns a [CordaFuture] of the response.
* @param R The type of the response.
*/
fun <R : Any> MessagingService.sendRequest(topic: String,
request: ServiceRequestMessage,
target: MessageRecipients): CordaFuture<R> {
val responseFuture = onNext<R>(topic, request.sessionID)
send(topic, MessagingService.DEFAULT_SESSION_ID, request, target)
return responseFuture
}

View File

@ -2,6 +2,7 @@ package net.corda.node.services.statemachine
import co.paralleluniverse.fibers.Suspendable
import net.corda.core.crypto.SecureHash
import net.corda.core.identity.Party
import java.time.Instant
interface FlowIORequest {
@ -22,29 +23,28 @@ interface SendRequest : SessionedFlowIORequest {
val message: SessionMessage
}
interface ReceiveRequest<T : SessionMessage> : SessionedFlowIORequest, WaitingRequest {
val receiveType: Class<T>
interface ReceiveRequest : SessionedFlowIORequest, WaitingRequest {
val userReceiveType: Class<*>?
override fun shouldResume(message: ExistingSessionMessage, session: FlowSessionInternal): Boolean = this.session === session
}
data class SendAndReceive<T : SessionMessage>(override val session: FlowSessionInternal,
override val message: SessionMessage,
override val receiveType: Class<T>,
override val userReceiveType: Class<*>?) : SendRequest, ReceiveRequest<T> {
data class SendAndReceive(override val session: FlowSessionInternal,
override val message: SessionMessage,
override val userReceiveType: Class<*>?) : SendRequest, ReceiveRequest {
@Transient
override val stackTraceInCaseOfProblems: StackSnapshot = StackSnapshot()
}
data class ReceiveOnly<T : SessionMessage>(override val session: FlowSessionInternal,
override val receiveType: Class<T>,
override val userReceiveType: Class<*>?) : ReceiveRequest<T> {
data class ReceiveOnly(
override val session: FlowSessionInternal,
override val userReceiveType: Class<*>?
) : ReceiveRequest {
@Transient
override val stackTraceInCaseOfProblems: StackSnapshot = StackSnapshot()
}
class ReceiveAll(val requests: List<ReceiveRequest<SessionData>>) : WaitingRequest {
class ReceiveAll(val requests: List<ReceiveRequest>) : WaitingRequest {
@Transient
override val stackTraceInCaseOfProblems: StackSnapshot = StackSnapshot()
@ -53,8 +53,8 @@ class ReceiveAll(val requests: List<ReceiveRequest<SessionData>>) : WaitingReque
}
private fun shouldResumeIfRelevant() = requests.all { hasSuccessfulEndMessage(it) }
private fun hasSuccessfulEndMessage(it: ReceiveRequest<SessionData>): Boolean {
return it.session.receivedMessages.map { it.message }.any { it is SessionData || it is SessionEnd }
private fun hasSuccessfulEndMessage(it: ReceiveRequest): Boolean {
return it.session.receivedMessages.map { it.message.payload }.any { it is DataSessionMessage || it is EndSessionMessage }
}
@Suspendable
@ -70,7 +70,7 @@ class ReceiveAll(val requests: List<ReceiveRequest<SessionData>>) : WaitingReque
if (isComplete(receivedMessages)) {
receivedMessages
} else {
throw IllegalStateException(requests.filter { it.session !in receivedMessages.keys }.map { "Was expecting a ${it.receiveType.simpleName} but instead got nothing for $it." }.joinToString { "\n" })
throw IllegalStateException(requests.filter { it.session !in receivedMessages.keys }.map { "Was expecting a message but instead got nothing for $it." }.joinToString { "\n" })
}
}
}
@ -90,15 +90,15 @@ class ReceiveAll(val requests: List<ReceiveRequest<SessionData>>) : WaitingReque
}
@Suspendable
private fun poll(request: ReceiveRequest<SessionData>): ReceivedSessionMessage<*>? {
return request.session.receivedMessages.poll()
private fun poll(request: ReceiveRequest): ExistingSessionMessage? {
return request.session.receivedMessages.poll()?.message
}
override fun shouldResume(message: ExistingSessionMessage, session: FlowSessionInternal): Boolean = isRelevant(session) && shouldResumeIfRelevant()
private fun isRelevant(session: FlowSessionInternal) = requests.any { it.session === session }
data class RequestMessage(val request: ReceiveRequest<SessionData>, val message: ReceivedSessionMessage<*>)
data class RequestMessage(val request: ReceiveRequest, val message: ExistingSessionMessage)
}
data class SendOnly(override val session: FlowSessionInternal, override val message: SessionMessage) : SendRequest {
@ -110,7 +110,7 @@ data class WaitForLedgerCommit(val hash: SecureHash, val fiber: FlowStateMachine
@Transient
override val stackTraceInCaseOfProblems: StackSnapshot = StackSnapshot()
override fun shouldResume(message: ExistingSessionMessage, session: FlowSessionInternal): Boolean = message is ErrorSessionEnd
override fun shouldResume(message: ExistingSessionMessage, session: FlowSessionInternal): Boolean = message.payload is ErrorSessionMessage
}
data class Sleep(val until: Instant, val fiber: FlowStateMachineImpl<*>) : FlowIORequest {

View File

@ -17,18 +17,28 @@ import java.util.concurrent.ConcurrentLinkedQueue
class FlowSessionInternal(
val flow: FlowLogic<*>,
val flowSession : FlowSession,
val ourSessionId: Long,
val ourSessionId: SessionId,
val initiatingParty: Party?,
var state: FlowSessionState,
var retryable: Boolean = false) {
val receivedMessages = ConcurrentLinkedQueue<ReceivedSessionMessage<*>>()
val receivedMessages = ConcurrentLinkedQueue<ReceivedSessionMessage>()
val fiber: FlowStateMachineImpl<*> get() = flow.stateMachine as FlowStateMachineImpl<*>
override fun toString(): String {
return "${javaClass.simpleName}(flow=$flow, ourSessionId=$ourSessionId, initiatingParty=$initiatingParty, state=$state)"
}
fun getPeerSessionId(): SessionId {
val sessionState = state
return when (sessionState) {
is FlowSessionState.Initiated -> sessionState.peerSessionId
else -> throw IllegalStateException("We've somehow held onto a non-initiated session: $this")
}
}
}
data class ReceivedSessionMessage(val peerParty: Party, val message: ExistingSessionMessage)
/**
* [FlowSessionState] describes the session's state.
*
@ -50,7 +60,7 @@ sealed class FlowSessionState {
override val sendToParty: Party get() = otherParty
}
data class Initiated(val peerParty: Party, val peerSessionId: Long, val context: FlowInfo) : FlowSessionState() {
data class Initiated(val peerParty: Party, val peerSessionId: SessionId, val context: FlowInfo) : FlowSessionState() {
override val sendToParty: Party get() = peerParty
}
}

View File

@ -9,7 +9,7 @@ import com.google.common.primitives.Primitives
import net.corda.core.concurrent.CordaFuture
import net.corda.core.context.InvocationContext
import net.corda.core.crypto.SecureHash
import net.corda.core.crypto.random63BitValue
import net.corda.core.crypto.newSecureRandom
import net.corda.core.flows.*
import net.corda.core.identity.Party
import net.corda.core.identity.PartyAndCertificate
@ -31,6 +31,7 @@ import net.corda.nodeapi.internal.persistence.contextTransaction
import net.corda.nodeapi.internal.persistence.contextTransactionOrNull
import org.slf4j.Logger
import org.slf4j.LoggerFactory
import java.io.IOException
import java.nio.file.Paths
import java.sql.SQLException
import java.time.Duration
@ -180,7 +181,7 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
requireNonPrimitive(receiveType)
logger.debug { "sendAndReceive(${receiveType.name}, $otherParty, ${payload.toString().abbreviate(300)}) ..." }
val session = getConfirmedSessionIfPresent(otherParty, sessionFlow)
val receivedSessionData: ReceivedSessionMessage<SessionData> = if (session == null) {
val receivedSessionMessage: ReceivedSessionMessage = if (session == null) {
val newSession = initiateSession(otherParty, sessionFlow, payload, waitForConfirmation = true, retryable = retrySend)
// Only do a receive here as the session init has carried the payload
receiveInternal(newSession, receiveType)
@ -188,8 +189,20 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
val sendData = createSessionData(session, payload)
sendAndReceiveInternal(session, sendData, receiveType)
}
logger.debug { "Received ${receivedSessionData.message.payload.toString().abbreviate(300)}" }
return receivedSessionData.checkPayloadIs(receiveType)
val sessionData = receivedSessionMessage.message.checkDataSessionMessage()
logger.debug { "Received ${sessionData.payload.toString().abbreviate(300)}" }
return sessionData.checkPayloadIs(receiveType)
}
private fun ExistingSessionMessage.checkDataSessionMessage(): DataSessionMessage {
when (payload) {
is DataSessionMessage -> {
return payload
}
else -> {
throw IllegalStateException("Was expecting ${DataSessionMessage::class.java.simpleName} but got ${payload.javaClass.simpleName} instead")
}
}
}
@Suspendable
@ -200,9 +213,9 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
requireNonPrimitive(receiveType)
logger.debug { "receive(${receiveType.name}, $otherParty) ..." }
val session = getConfirmedSession(otherParty, sessionFlow)
val sessionData = receiveInternal<SessionData>(session, receiveType)
logger.debug { "Received ${sessionData.message.payload.toString().abbreviate(300)}" }
return sessionData.checkPayloadIs(receiveType)
val receivedSessionMessage = receiveInternal(session, receiveType).message.checkDataSessionMessage()
logger.debug { "Received ${receivedSessionMessage.payload.toString().abbreviate(300)}" }
return receivedSessionMessage.checkPayloadIs(receiveType)
}
private fun requireNonPrimitive(receiveType: Class<*>) {
@ -219,7 +232,7 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
// Don't send the payload again if it was already piggy-backed on a session init
initiateSession(otherParty, sessionFlow, payload, waitForConfirmation = false)
} else {
sendInternal(session, createSessionData(session, payload))
sendInternal(session, ExistingSessionMessage(session.getPeerSessionId(), createSessionData(session, payload)))
}
}
@ -236,8 +249,8 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
// If the tx isn't committed then we may have been resumed due to an session ending in an error
for (session in openSessions.values) {
for (receivedMessage in session.receivedMessages) {
if (receivedMessage.message is ErrorSessionEnd) {
session.erroredEnd(receivedMessage.message)
if (receivedMessage.message.payload is ErrorSessionMessage) {
session.erroredEnd(receivedMessage.message.payload.flowException)
}
}
}
@ -294,16 +307,18 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
@Suspendable
override fun receiveAll(sessions: Map<FlowSession, Class<out Any>>, sessionFlow: FlowLogic<*>): Map<FlowSession, UntrustworthyData<Any>> {
val requests = ArrayList<ReceiveOnly<SessionData>>()
val requests = ArrayList<ReceiveOnly>()
for ((session, receiveType) in sessions) {
val sessionInternal = getConfirmedSession(session.counterparty, sessionFlow)
requests.add(ReceiveOnly(sessionInternal, SessionData::class.java, receiveType))
requests.add(ReceiveOnly(sessionInternal, receiveType))
}
val receivedMessages = ReceiveAll(requests).suspendAndExpectReceive(suspend)
val result = LinkedHashMap<FlowSession, UntrustworthyData<Any>>()
for ((sessionInternal, requestAndMessage) in receivedMessages) {
val message = requestAndMessage.message.confirmReceiveType(requestAndMessage.request)
result[sessionInternal.flowSession] = message.checkPayloadIs(requestAndMessage.request.userReceiveType as Class<out Any>)
val message = requestAndMessage.message.confirmNoError(requestAndMessage.request.session)
result[sessionInternal.flowSession] = message.checkDataSessionMessage().checkPayloadIs(
requestAndMessage.request.userReceiveType as Class<out Any>
)
}
return result
}
@ -315,41 +330,46 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
*/
@Suspendable
private fun FlowSessionInternal.waitForConfirmation() {
val (peerParty, sessionInitResponse) = receiveInternal<SessionInitResponse>(this, null)
if (sessionInitResponse is SessionConfirm) {
state = FlowSessionState.Initiated(
peerParty,
sessionInitResponse.initiatedSessionId,
FlowInfo(sessionInitResponse.flowVersion, sessionInitResponse.appName))
} else {
sessionInitResponse as SessionReject
throw UnexpectedFlowEndException("Party ${state.sendToParty} rejected session request: ${sessionInitResponse.errorMessage}")
val sessionInitResponse = receiveInternal(this, null)
val payload = sessionInitResponse.message.payload
when (payload) {
is ConfirmSessionMessage -> {
state = FlowSessionState.Initiated(
sessionInitResponse.
peerParty,
payload.initiatedSessionId,
payload.initiatedFlowInfo)
}
is RejectSessionMessage -> {
throw UnexpectedFlowEndException("Party ${state.sendToParty} rejected session request: ${payload.message}")
}
else -> {
throw IllegalStateException("Was expecting ${ConfirmSessionMessage::class.java.simpleName} but got ${payload.javaClass.simpleName} instead")
}
}
}
private fun createSessionData(session: FlowSessionInternal, payload: Any): SessionData {
val sessionState = session.state
val peerSessionId = when (sessionState) {
is FlowSessionState.Initiated -> sessionState.peerSessionId
else -> throw IllegalStateException("We've somehow held onto a non-initiated session: $session")
}
return SessionData(peerSessionId, payload.serialize(context = SerializationDefaults.P2P_CONTEXT))
private fun createSessionData(session: FlowSessionInternal, payload: Any): DataSessionMessage {
return DataSessionMessage(payload.serialize(context = SerializationDefaults.P2P_CONTEXT))
}
@Suspendable
private fun sendInternal(session: FlowSessionInternal, message: SessionMessage) = suspend(SendOnly(session, message))
private inline fun <reified M : ExistingSessionMessage> receiveInternal(
@Suspendable
private fun receiveInternal(
session: FlowSessionInternal,
userReceiveType: Class<*>?): ReceivedSessionMessage<M> {
return waitForMessage(ReceiveOnly(session, M::class.java, userReceiveType))
userReceiveType: Class<*>?): ReceivedSessionMessage {
return waitForMessage(ReceiveOnly(session, userReceiveType))
}
private inline fun <reified M : ExistingSessionMessage> sendAndReceiveInternal(
@Suspendable
private fun sendAndReceiveInternal(
session: FlowSessionInternal,
message: SessionMessage,
userReceiveType: Class<*>?): ReceivedSessionMessage<M> {
return waitForMessage(SendAndReceive(session, message, M::class.java, userReceiveType))
message: DataSessionMessage,
userReceiveType: Class<*>?): ReceivedSessionMessage {
val sessionMessage = ExistingSessionMessage(session.getPeerSessionId(), message)
return waitForMessage(SendAndReceive(session, sessionMessage, userReceiveType))
}
@Suspendable
@ -377,7 +397,7 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
sessionFlow: FlowLogic<*>
) {
logger.trace { "Creating a new session with $otherParty" }
val session = FlowSessionInternal(sessionFlow, flowSession, random63BitValue(), null, FlowSessionState.Uninitiated(otherParty))
val session = FlowSessionInternal(sessionFlow, flowSession, SessionId.createRandom(newSecureRandom()), null, FlowSessionState.Uninitiated(otherParty))
openSessions[Pair(sessionFlow, otherParty)] = session
}
@ -397,7 +417,7 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
val (version, initiatingFlowClass) = session.flow.javaClass.flowVersionAndInitiatingClass
val payloadBytes = firstPayload?.serialize(context = SerializationDefaults.P2P_CONTEXT)
logger.info("Initiating flow session with party ${otherParty.name}. Session id for tracing purposes is ${session.ourSessionId}.")
val sessionInit = SessionInit(session.ourSessionId, initiatingFlowClass.name, version, session.flow.javaClass.appName, payloadBytes)
val sessionInit = InitialSessionMessage(session.ourSessionId, newSecureRandom().nextLong(), initiatingFlowClass.name, version, session.flow.javaClass.appName, payloadBytes)
sendInternal(session, sessionInit)
if (waitForConfirmation) {
session.waitForConfirmation()
@ -406,8 +426,10 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
}
@Suspendable
private fun <M : ExistingSessionMessage> waitForMessage(receiveRequest: ReceiveRequest<M>): ReceivedSessionMessage<M> {
return receiveRequest.suspendAndExpectReceive().confirmReceiveType(receiveRequest)
private fun waitForMessage(receiveRequest: ReceiveRequest): ReceivedSessionMessage {
val receivedMessage = receiveRequest.suspendAndExpectReceive()
receivedMessage.message.confirmNoError(receiveRequest.session)
return receivedMessage
}
private val suspend : ReceiveAll.Suspend = object : ReceiveAll.Suspend {
@ -418,7 +440,7 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
}
@Suspendable
private fun ReceiveRequest<*>.suspendAndExpectReceive(): ReceivedSessionMessage<*> {
private fun ReceiveRequest.suspendAndExpectReceive(): ReceivedSessionMessage {
val polledMessage = session.receivedMessages.poll()
return if (polledMessage != null) {
if (this is SendAndReceive) {
@ -431,35 +453,36 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
// Suspend while we wait for a receive
suspend(this)
session.receivedMessages.poll() ?:
throw IllegalStateException("Was expecting a ${receiveType.simpleName} but instead got nothing for $this")
throw IllegalStateException("Was expecting a message but instead got nothing for $this")
}
}
private fun <M : ExistingSessionMessage> ReceivedSessionMessage<*>.confirmReceiveType(
receiveRequest: ReceiveRequest<M>): ReceivedSessionMessage<M> {
val session = receiveRequest.session
val receiveType = receiveRequest.receiveType
if (receiveType.isInstance(message)) {
return uncheckedCast(this)
} else if (message is SessionEnd) {
openSessions.values.remove(session)
if (message is ErrorSessionEnd) {
session.erroredEnd(message)
} else {
val expectedType = receiveRequest.userReceiveType?.name ?: receiveType.simpleName
throw UnexpectedFlowEndException("Counterparty flow on ${session.state.sendToParty} has completed without " +
"sending a $expectedType")
private fun ExistingSessionMessage.confirmNoError(session: FlowSessionInternal): ExistingSessionMessage {
when (payload) {
is ConfirmSessionMessage,
is DataSessionMessage -> {
return this
}
is ErrorSessionMessage -> {
openSessions.values.remove(session)
session.erroredEnd(payload.flowException)
}
is RejectSessionMessage -> {
session.erroredEnd(UnexpectedFlowEndException("Counterparty sent session rejection message at unexpected time with message ${payload.message}"))
}
EndSessionMessage -> {
openSessions.values.remove(session)
throw UnexpectedFlowEndException("Counterparty flow on ${session.state.sendToParty} has completed without " +
"sending data")
}
} else {
throw IllegalStateException("Was expecting a ${receiveType.simpleName} but instead got $message for $receiveRequest")
}
}
private fun FlowSessionInternal.erroredEnd(end: ErrorSessionEnd): Nothing {
if (end.errorResponse != null) {
private fun FlowSessionInternal.erroredEnd(exception: Throwable?): Nothing {
if (exception != null) {
@Suppress("PLATFORM_CLASS_MAPPED_TO_KOTLIN")
(end.errorResponse as java.lang.Throwable).fillInStackTrace()
throw end.errorResponse
exception.fillInStackTrace()
throw exception
} else {
throw UnexpectedFlowEndException("Counterparty flow on ${state.sendToParty} had an internal error and has terminated")
}
@ -560,3 +583,15 @@ val Class<out FlowLogic<*>>.appName: String
"<unknown>"
}
}
fun <T : Any> DataSessionMessage.checkPayloadIs(type: Class<T>): UntrustworthyData<T> {
val payloadData: T = try {
val serializer = SerializationDefaults.SERIALIZATION_FACTORY
serializer.deserialize(payload, type, SerializationDefaults.P2P_CONTEXT)
} catch (ex: Exception) {
throw IOException("Payload invalid", ex)
}
return type.castIfPossible(payloadData)?.let { UntrustworthyData(it) } ?:
throw UnexpectedFlowEndException("We were expecting a ${type.name} but we instead got a " +
"${payloadData.javaClass.name} (${payloadData})")
}

View File

@ -1,59 +1,122 @@
package net.corda.node.services.statemachine
import net.corda.core.crypto.random63BitValue
import net.corda.core.flows.FlowException
import net.corda.core.flows.UnexpectedFlowEndException
import net.corda.core.identity.Party
import net.corda.core.internal.castIfPossible
import net.corda.core.flows.FlowInfo
import net.corda.core.serialization.CordaSerializable
import net.corda.core.serialization.SerializationDefaults
import net.corda.core.serialization.SerializedBytes
import net.corda.core.utilities.UntrustworthyData
import java.io.IOException
import java.security.SecureRandom
/**
* A session between two flows is identified by two session IDs, the initiating and the initiated session ID.
* However after the session has been established the communication is symmetric. From then on we differentiate between
* the two session IDs with "source" ID (the ID from which we receive) and "sink" ID (the ID to which we send).
*
* Flow A (initiating) Flow B (initiated)
* initiatingId=sourceId=0
* send(Initiate(initiatingId=0)) -----> initiatingId=sinkId=0
* initiatedId=sourceId=1
* initiatedId=sinkId=1 <----- send(Confirm(initiatedId=1))
*/
@CordaSerializable
sealed class SessionMessage
@CordaSerializable
interface SessionMessage
interface ExistingSessionMessage : SessionMessage {
val recipientSessionId: Long
}
interface SessionInitResponse : ExistingSessionMessage {
val initiatorSessionId: Long
override val recipientSessionId: Long get() = initiatorSessionId
}
interface SessionEnd : ExistingSessionMessage
data class SessionInit(val initiatorSessionId: Long,
val initiatingFlowClass: String,
val flowVersion: Int,
val appName: String,
val firstPayload: SerializedBytes<Any>?) : SessionMessage
data class SessionConfirm(override val initiatorSessionId: Long,
val initiatedSessionId: Long,
val flowVersion: Int,
val appName: String) : SessionInitResponse
data class SessionReject(override val initiatorSessionId: Long, val errorMessage: String) : SessionInitResponse
data class SessionData(override val recipientSessionId: Long, val payload: SerializedBytes<Any>) : ExistingSessionMessage
data class NormalSessionEnd(override val recipientSessionId: Long) : SessionEnd
data class ErrorSessionEnd(override val recipientSessionId: Long, val errorResponse: FlowException?) : SessionEnd
data class ReceivedSessionMessage<out M : ExistingSessionMessage>(val sender: Party, val message: M)
fun <T : Any> ReceivedSessionMessage<SessionData>.checkPayloadIs(type: Class<T>): UntrustworthyData<T> {
val payloadData: T = try {
val serializer = SerializationDefaults.SERIALIZATION_FACTORY
serializer.deserialize<T>(message.payload, type, SerializationDefaults.P2P_CONTEXT)
} catch (ex: Exception) {
throw IOException("Payload invalid", ex)
data class SessionId(val toLong: Long) {
companion object {
fun createRandom(secureRandom: SecureRandom) = SessionId(secureRandom.nextLong())
}
return type.castIfPossible(payloadData)?.let { UntrustworthyData(it) } ?:
throw UnexpectedFlowEndException("We were expecting a ${type.name} from $sender but we instead got a " +
"${payloadData.javaClass.name} (${payloadData})")
}
/**
* The initial message to initiate a session with.
*
* @param initiatorSessionId the session ID of the initiator. On the sending side this is the *source* ID, on the
* receiving side this is the *sink* ID.
* @param initiationEntropy additional randomness to seed the initiated flow's deduplication ID.
* @param initiatorFlowClassName the class name to be used to determine the initiating-initiated mapping on the receiver
* side.
* @param flowVersion the version of the initiating flow.
* @param appName the name of the cordapp defining the initiating flow, or "corda" if it's a core flow.
* @param firstPayload the optional first payload.
*/
data class InitialSessionMessage(
val initiatorSessionId: SessionId,
val initiationEntropy: Long,
val initiatorFlowClassName: String,
val flowVersion: Int,
val appName: String,
val firstPayload: SerializedBytes<Any>?
) : SessionMessage() {
override fun toString() = "InitialSessionMessage(" +
"initiatorSessionId=$initiatorSessionId, " +
"initiationEntropy=$initiationEntropy, " +
"initiatorFlowClassName=$initiatorFlowClassName, " +
"appName=$appName, " +
"firstPayload=${firstPayload?.javaClass}" +
")"
}
/**
* A message sent when a session has been established already.
*
* @param recipientSessionId the recipient session ID. On the sending side this is the *sink* ID, on the receiving side
* this is the *source* ID.
* @param payload the rest of the message.
*/
data class ExistingSessionMessage(
val recipientSessionId: SessionId,
val payload: ExistingSessionMessagePayload
) : SessionMessage()
/**
* The payload of an [ExistingSessionMessage]
*/
@CordaSerializable
sealed class ExistingSessionMessagePayload
/**
* The confirmation message sent by the initiated side.
* @param initiatedSessionId the initiated session ID, the other half of [InitialSessionMessage.initiatorSessionId].
* This is the *source* ID on the sending(initiated) side, and the *sink* ID on the receiving(initiating) side.
*/
data class ConfirmSessionMessage(
val initiatedSessionId: SessionId,
val initiatedFlowInfo: FlowInfo
) : ExistingSessionMessagePayload()
/**
* A message containing flow-related data.
*
* @param payload the serialised payload.
*/
data class DataSessionMessage(val payload: SerializedBytes<Any>) : ExistingSessionMessagePayload() {
override fun toString() = "DataSessionMessage(payload=${payload.javaClass})"
}
/**
* A message indicating that an error has happened.
*
* @param flowException the exception that happened. This is null if the error condition wasn't revealed to the
* receiving side.
* @param errorId the ID of the source error. This is always specified to allow posteriori correlation of error conditions.
*/
data class ErrorSessionMessage(val flowException: FlowException?, val errorId: Long) : ExistingSessionMessagePayload()
/**
* A message indicating that a session initiation has failed.
*
* @param message a message describing the problem to the initator.
* @param errorId an error ID identifying this error condition.
*/
data class RejectSessionMessage(val message: String, val errorId: Long) : ExistingSessionMessagePayload()
/**
* A message indicating that the flow hosting the session has ended. Note that this message is strictly part of the
* session protocol, the flow may be removed before all counter-flows have ended.
*
* The sole purpose of this message currently is to provide diagnostic in cases where the two communicating flows'
* protocols don't match up, e.g. one is waiting for the other, but the other side has already finished.
*/
object EndSessionMessage : ExistingSessionMessagePayload()

View File

@ -13,6 +13,7 @@ import net.corda.core.CordaException
import net.corda.core.concurrent.CordaFuture
import net.corda.core.context.InvocationContext
import net.corda.core.crypto.SecureHash
import net.corda.core.crypto.newSecureRandom
import net.corda.core.crypto.random63BitValue
import net.corda.core.flows.FlowException
import net.corda.core.flows.FlowInfo
@ -37,7 +38,6 @@ import net.corda.node.services.api.CheckpointStorage
import net.corda.node.services.api.ServiceHubInternal
import net.corda.node.services.config.shouldCheckCheckpoints
import net.corda.node.services.messaging.ReceivedMessage
import net.corda.node.services.messaging.TopicSession
import net.corda.node.utilities.AffinityExecutor
import net.corda.node.utilities.newNamedSingleThreadExecutor
import net.corda.nodeapi.internal.persistence.CordaPersistence
@ -72,7 +72,7 @@ class StateMachineManagerImpl(
companion object {
private val logger = contextLogger()
internal val sessionTopic = TopicSession("platform.session")
internal val sessionTopic = "platform.session"
init {
Fiber.setDefaultUncaughtExceptionHandler { fiber, throwable ->
@ -121,8 +121,8 @@ class StateMachineManagerImpl(
private val totalStartedFlows = metrics.counter("Flows.Started")
private val totalFinishedFlows = metrics.counter("Flows.Finished")
private val openSessions = ConcurrentHashMap<Long, FlowSessionInternal>()
private val recentlyClosedSessions = ConcurrentHashMap<Long, Party>()
private val openSessions = ConcurrentHashMap<SessionId, FlowSessionInternal>()
private val recentlyClosedSessions = ConcurrentHashMap<SessionId, Party>()
// Context for tokenized services in checkpoints
private lateinit var tokenizableServices: List<Any>
@ -281,7 +281,7 @@ class StateMachineManagerImpl(
if (sender != null) {
when (sessionMessage) {
is ExistingSessionMessage -> onExistingSessionMessage(sessionMessage, sender)
is SessionInit -> onSessionInit(sessionMessage, message, sender)
is InitialSessionMessage -> onSessionInit(sessionMessage, message, sender)
}
} else {
logger.error("Unknown peer $peer in $sessionMessage")
@ -294,15 +294,15 @@ class StateMachineManagerImpl(
session.fiber.pushToLoggingContext()
session.fiber.logger.trace { "Received $message on $session from $sender" }
if (session.retryable) {
if (message is SessionConfirm && session.state is FlowSessionState.Initiated) {
if (message.payload is ConfirmSessionMessage && session.state is FlowSessionState.Initiated) {
session.fiber.logger.trace { "Ignoring duplicate confirmation for session ${session.ourSessionId} session is idempotent" }
return
}
if (message !is SessionConfirm) {
serviceHub.networkService.cancelRedelivery(session.ourSessionId)
if (message.payload !is ConfirmSessionMessage) {
serviceHub.networkService.cancelRedelivery(session.ourSessionId.toLong)
}
}
if (message is SessionEnd) {
if (message.payload is EndSessionMessage || message.payload is ErrorSessionMessage) {
openSessions.remove(message.recipientSessionId)
}
session.receivedMessages += ReceivedSessionMessage(sender, message)
@ -317,9 +317,9 @@ class StateMachineManagerImpl(
} else {
val peerParty = recentlyClosedSessions.remove(message.recipientSessionId)
if (peerParty != null) {
if (message is SessionConfirm) {
if (message.payload is ConfirmSessionMessage) {
logger.trace { "Received session confirmation but associated fiber has already terminated, so sending session end" }
sendSessionMessage(peerParty, NormalSessionEnd(message.initiatedSessionId))
sendSessionMessage(peerParty, ExistingSessionMessage(message.payload.initiatedSessionId, EndSessionMessage))
} else {
logger.trace { "Ignoring session end message for already closed session: $message" }
}
@ -336,12 +336,12 @@ class StateMachineManagerImpl(
return waitingForResponse?.shouldResume(message, session) ?: false
}
private fun onSessionInit(sessionInit: SessionInit, receivedMessage: ReceivedMessage, sender: Party) {
private fun onSessionInit(sessionInit: InitialSessionMessage, receivedMessage: ReceivedMessage, sender: Party) {
logger.trace { "Received $sessionInit from $sender" }
val senderSessionId = sessionInit.initiatorSessionId
fun sendSessionReject(message: String) = sendSessionMessage(sender, SessionReject(senderSessionId, message))
fun sendSessionReject(message: String) = sendSessionMessage(sender, ExistingSessionMessage(senderSessionId, RejectSessionMessage(message, random63BitValue())))
val (session, initiatedFlowFactory) = try {
val initiatedFlowFactory = getInitiatedFlowFactory(sessionInit)
@ -354,11 +354,11 @@ class StateMachineManagerImpl(
val session = FlowSessionInternal(
flow,
flowSession,
random63BitValue(),
SessionId.createRandom(newSecureRandom()),
sender,
FlowSessionState.Initiated(sender, senderSessionId, FlowInfo(senderFlowVersion, sessionInit.appName)))
if (sessionInit.firstPayload != null) {
session.receivedMessages += ReceivedSessionMessage(sender, SessionData(session.ourSessionId, sessionInit.firstPayload))
session.receivedMessages += ReceivedSessionMessage(sender, ExistingSessionMessage(session.ourSessionId, DataSessionMessage(sessionInit.firstPayload)))
}
openSessions[session.ourSessionId] = session
val context = InvocationContext.peer(sender.name)
@ -386,19 +386,19 @@ class StateMachineManagerImpl(
is InitiatedFlowFactory.CorDapp -> initiatedFlowFactory.flowVersion to initiatedFlowFactory.appName
}
sendSessionMessage(sender, SessionConfirm(senderSessionId, session.ourSessionId, ourFlowVersion, appName), session.fiber)
session.fiber.logger.debug { "Initiated by $sender using ${sessionInit.initiatingFlowClass}" }
sendSessionMessage(sender, ExistingSessionMessage(senderSessionId, ConfirmSessionMessage(session.ourSessionId, FlowInfo(ourFlowVersion, appName))), session.fiber)
session.fiber.logger.debug { "Initiated by $sender using ${sessionInit.initiatorFlowClassName}" }
session.fiber.logger.trace { "Initiated from $sessionInit on $session" }
resumeFiber(session.fiber)
}
private fun getInitiatedFlowFactory(sessionInit: SessionInit): InitiatedFlowFactory<*> {
private fun getInitiatedFlowFactory(sessionInit: InitialSessionMessage): InitiatedFlowFactory<*> {
val initiatingFlowClass = try {
Class.forName(sessionInit.initiatingFlowClass, true, classloader).asSubclass(FlowLogic::class.java)
Class.forName(sessionInit.initiatorFlowClassName, true, classloader).asSubclass(FlowLogic::class.java)
} catch (e: ClassNotFoundException) {
throw SessionRejectException("Don't know ${sessionInit.initiatingFlowClass}")
throw SessionRejectException("Don't know ${sessionInit.initiatorFlowClassName}")
} catch (e: ClassCastException) {
throw SessionRejectException("${sessionInit.initiatingFlowClass} is not a flow")
throw SessionRejectException("${sessionInit.initiatorFlowClassName} is not a flow")
}
return serviceHub.getFlowFactory(initiatingFlowClass) ?:
throw SessionRejectException("$initiatingFlowClass is not registered")
@ -492,7 +492,7 @@ class StateMachineManagerImpl(
private fun FlowSessionInternal.endSession(context: InvocationContext, exception: Throwable?, propagated: Boolean) {
val initiatedState = state as? FlowSessionState.Initiated ?: return
val sessionEnd = if (exception == null) {
NormalSessionEnd(initiatedState.peerSessionId)
EndSessionMessage
} else {
val errorResponse = if (exception is FlowException && (!propagated || initiatingParty != null)) {
// Only propagate this FlowException if our local flow threw it or it was propagated to us and we only
@ -501,9 +501,9 @@ class StateMachineManagerImpl(
} else {
null
}
ErrorSessionEnd(initiatedState.peerSessionId, errorResponse)
ErrorSessionMessage(errorResponse, 0)
}
sendSessionMessage(initiatedState.peerParty, sessionEnd, fiber)
sendSessionMessage(initiatedState.peerParty, ExistingSessionMessage(initiatedState.peerSessionId, sessionEnd), fiber)
recentlyClosedSessions[ourSessionId] = initiatedState.peerParty
}
@ -573,14 +573,14 @@ class StateMachineManagerImpl(
}
private fun processSendRequest(ioRequest: SendRequest) {
val retryId = if (ioRequest.message is SessionInit) {
val retryId = if (ioRequest.message is InitialSessionMessage) {
with(ioRequest.session) {
openSessions[ourSessionId] = this
if (retryable) ourSessionId else null
if (retryable) ourSessionId.toLong else null
}
} else null
sendSessionMessage(ioRequest.session.state.sendToParty, ioRequest.message, ioRequest.session.fiber, retryId)
if (ioRequest !is ReceiveRequest<*>) {
if (ioRequest !is ReceiveRequest) {
// We sent a message, but don't expect a response, so re-enter the continuation to let it keep going.
resumeFiber(ioRequest.session.fiber)
}
@ -625,12 +625,15 @@ class StateMachineManagerImpl(
// Handling Kryo and AMQP serialization problems. Unfortunately the two exception types do not share much of a common exception interface.
is KryoException,
is NotSerializableException -> {
if (message !is ErrorSessionEnd || message.errorResponse == null) throw e
logger.warn("Something in ${message.errorResponse.javaClass.name} is not serialisable. " +
"Instead sending back an exception which is serialisable to ensure session end occurs properly.", e)
// The subclass may have overridden toString so we use that
val exMessage = message.errorResponse.let { if (it.javaClass != FlowException::class.java) it.toString() else it.message }
message.copy(errorResponse = FlowException(exMessage)).serialize()
if (message is ExistingSessionMessage && message.payload is ErrorSessionMessage && message.payload.flowException != null) {
logger.warn("Something in ${message.payload.flowException.javaClass.name} is not serialisable. " +
"Instead sending back an exception which is serialisable to ensure session end occurs properly.", e)
// The subclass may have overridden toString so we use that
val exMessage = message.payload.flowException.message
message.copy(payload = message.payload.copy(flowException = FlowException(exMessage))).serialize()
} else {
throw e
}
}
else -> throw e
}

View File

@ -3,7 +3,6 @@ package net.corda.node.messaging
import net.corda.core.messaging.AllPossibleRecipients
import net.corda.node.services.messaging.Message
import net.corda.node.services.messaging.TopicStringValidator
import net.corda.node.services.messaging.createMessage
import net.corda.testing.internal.rigorousMock
import net.corda.testing.node.MockNetwork
import org.apache.commons.lang.ArrayUtils.EMPTY_BYTE_ARRAY
@ -51,10 +50,10 @@ class InMemoryMessagingTests {
val bits = "test-content".toByteArray()
var finalDelivery: Message? = null
node2.network.addMessageHandler { msg, _ ->
node2.network.addMessageHandler("test.topic") { msg, _ ->
node2.network.send(msg, node3.network.myAddress)
}
node3.network.addMessageHandler { msg, _ ->
node3.network.addMessageHandler("test.topic") { msg, _ ->
finalDelivery = msg
}
@ -63,7 +62,7 @@ class InMemoryMessagingTests {
mockNet.runNetwork(rounds = 1)
assertTrue(Arrays.equals(finalDelivery!!.data, bits))
assertTrue(Arrays.equals(finalDelivery!!.data.bytes, bits))
}
@Test
@ -75,7 +74,7 @@ class InMemoryMessagingTests {
val bits = "test-content".toByteArray()
var counter = 0
listOf(node1, node2, node3).forEach { it.network.addMessageHandler { _, _ -> counter++ } }
listOf(node1, node2, node3).forEach { it.network.addMessageHandler("test.topic") { _, _ -> counter++ } }
node1.network.send(node2.network.createMessage("test.topic", data = bits), rigorousMock<AllPossibleRecipients>())
mockNet.runNetwork(rounds = 1)
assertEquals(3, counter)

View File

@ -126,7 +126,7 @@ class ArtemisMessagingTest {
messagingClient.send(message, messagingClient.myAddress)
val actual: Message = receivedMessages.take()
assertEquals("first msg", String(actual.data))
assertEquals("first msg", String(actual.data.bytes))
assertNull(receivedMessages.poll(200, MILLISECONDS))
}

View File

@ -37,8 +37,8 @@ import net.corda.testing.node.InMemoryMessagingNetwork.ServicePeerAllocationStra
import net.corda.testing.node.MockNetwork
import net.corda.testing.node.MockNetwork.MockNode
import net.corda.testing.node.MockNodeParameters
import net.corda.testing.node.pumpReceive
import net.corda.testing.node.internal.startFlow
import net.corda.testing.node.pumpReceive
import org.assertj.core.api.Assertions.assertThat
import org.assertj.core.api.Assertions.assertThatThrownBy
import org.assertj.core.api.AssertionsForClassTypes.assertThatExceptionOfType
@ -296,14 +296,16 @@ class FlowFrameworkTests {
mockNet.runNetwork()
assertThatExceptionOfType(UnexpectedFlowEndException::class.java).isThrownBy {
resultFuture.getOrThrow()
}.withMessageContaining(String::class.java.name) // Make sure the exception message mentions the type the flow was expecting to receive
}
}
@Test
fun `receiving unexpected session end before entering sendAndReceive`() {
bobNode.registerFlowFactory(WaitForOtherSideEndBeforeSendAndReceive::class) { NoOpFlow() }
val sessionEndReceived = Semaphore(0)
receivedSessionMessagesObservable().filter { it.message is SessionEnd }.subscribe { sessionEndReceived.release() }
receivedSessionMessagesObservable().filter {
it.message is ExistingSessionMessage && it.message.payload is EndSessionMessage
}.subscribe { sessionEndReceived.release() }
val resultFuture = aliceNode.services.startFlow(
WaitForOtherSideEndBeforeSendAndReceive(bob, sessionEndReceived)).resultFuture
mockNet.runNetwork()
@ -356,7 +358,7 @@ class FlowFrameworkTests {
assertSessionTransfers(
aliceNode sent sessionInit(ReceiveFlow::class) to bobNode,
bobNode sent sessionConfirm() to aliceNode,
bobNode sent erroredEnd() to aliceNode
bobNode sent errorMessage() to aliceNode
)
}
@ -389,10 +391,11 @@ class FlowFrameworkTests {
assertSessionTransfers(
aliceNode sent sessionInit(ReceiveFlow::class) to bobNode,
bobNode sent sessionConfirm() to aliceNode,
bobNode sent erroredEnd(erroringFlow.get().exceptionThrown) to aliceNode
bobNode sent errorMessage(erroringFlow.get().exceptionThrown) to aliceNode
)
// Make sure the original stack trace isn't sent down the wire
assertThat((receivedSessionMessages.last().message as ErrorSessionEnd).errorResponse!!.stackTrace).isEmpty()
val lastMessage = receivedSessionMessages.last().message as ExistingSessionMessage
assertThat((lastMessage.payload as ErrorSessionMessage).flowException!!.stackTrace).isEmpty()
}
@Test
@ -438,7 +441,7 @@ class FlowFrameworkTests {
aliceNode sent sessionInit(ReceiveFlow::class) to bobNode,
bobNode sent sessionConfirm() to aliceNode,
bobNode sent sessionData("Hello") to aliceNode,
aliceNode sent erroredEnd() to bobNode
aliceNode sent errorMessage() to bobNode
)
}
@ -603,20 +606,20 @@ class FlowFrameworkTests {
@Test
fun `unknown class in session init`() {
aliceNode.sendSessionMessage(SessionInit(random63BitValue(), "not.a.real.Class", 1, "version", null), bob)
aliceNode.sendSessionMessage(InitialSessionMessage(SessionId(random63BitValue()), 0, "not.a.real.Class", 1, "", null), bob)
mockNet.runNetwork()
assertThat(receivedSessionMessages).hasSize(2) // Only the session-init and session-reject are expected
val reject = receivedSessionMessages.last().message as SessionReject
assertThat(reject.errorMessage).isEqualTo("Don't know not.a.real.Class")
val lastMessage = receivedSessionMessages.last().message as ExistingSessionMessage
assertThat((lastMessage.payload as RejectSessionMessage).message).isEqualTo("Don't know not.a.real.Class")
}
@Test
fun `non-flow class in session init`() {
aliceNode.sendSessionMessage(SessionInit(random63BitValue(), String::class.java.name, 1, "version", null), bob)
aliceNode.sendSessionMessage(InitialSessionMessage(SessionId(random63BitValue()), 0, String::class.java.name, 1, "", null), bob)
mockNet.runNetwork()
assertThat(receivedSessionMessages).hasSize(2) // Only the session-init and session-reject are expected
val reject = receivedSessionMessages.last().message as SessionReject
assertThat(reject.errorMessage).isEqualTo("${String::class.java.name} is not a flow")
val lastMessage = receivedSessionMessages.last().message as ExistingSessionMessage
assertThat((lastMessage.payload as RejectSessionMessage).message).isEqualTo("${String::class.java.name} is not a flow")
}
@Test
@ -682,14 +685,14 @@ class FlowFrameworkTests {
return observable.toFuture()
}
private fun sessionInit(clientFlowClass: KClass<out FlowLogic<*>>, flowVersion: Int = 1, payload: Any? = null): SessionInit {
return SessionInit(0, clientFlowClass.java.name, flowVersion, "", payload?.serialize())
private fun sessionInit(clientFlowClass: KClass<out FlowLogic<*>>, flowVersion: Int = 1, payload: Any? = null): InitialSessionMessage {
return InitialSessionMessage(SessionId(0), 0, clientFlowClass.java.name, flowVersion, "", payload?.serialize())
}
private fun sessionConfirm(flowVersion: Int = 1) = SessionConfirm(0, 0, flowVersion, "")
private fun sessionData(payload: Any) = SessionData(0, payload.serialize())
private val normalEnd = NormalSessionEnd(0)
private fun erroredEnd(errorResponse: FlowException? = null) = ErrorSessionEnd(0, errorResponse)
private fun sessionConfirm(flowVersion: Int = 1) = ExistingSessionMessage(SessionId(0), ConfirmSessionMessage(SessionId(0), FlowInfo(flowVersion, "")))
private fun sessionData(payload: Any) = ExistingSessionMessage(SessionId(0), DataSessionMessage(payload.serialize()))
private val normalEnd = ExistingSessionMessage(SessionId(0), EndSessionMessage) // NormalSessionEnd(0)
private fun errorMessage(errorResponse: FlowException? = null) = ExistingSessionMessage(SessionId(0), ErrorSessionMessage(errorResponse, 0))
private fun StartedNode<*>.sendSessionMessage(message: SessionMessage, destination: Party) {
services.networkService.apply {
@ -709,7 +712,9 @@ class FlowFrameworkTests {
}
private data class SessionTransfer(val from: Int, val message: SessionMessage, val to: MessageRecipients) {
val isPayloadTransfer: Boolean get() = message is SessionData || message is SessionInit && message.firstPayload != null
val isPayloadTransfer: Boolean get() =
message is ExistingSessionMessage && message.payload is DataSessionMessage ||
message is InitialSessionMessage && message.firstPayload != null
override fun toString(): String = "$from sent $message to $to"
}
@ -718,7 +723,7 @@ class FlowFrameworkTests {
}
private fun Observable<MessageTransfer>.toSessionTransfers(): Observable<SessionTransfer> {
return filter { it.message.topicSession == StateMachineManagerImpl.sessionTopic }.map {
return filter { it.message.topic == StateMachineManagerImpl.sessionTopic }.map {
val from = it.sender.id
val message = it.message.data.deserialize<SessionMessage>()
SessionTransfer(from, sanitise(message), it.recipients)
@ -726,12 +731,23 @@ class FlowFrameworkTests {
}
private fun sanitise(message: SessionMessage) = when (message) {
is SessionData -> message.copy(recipientSessionId = 0)
is SessionInit -> message.copy(initiatorSessionId = 0, appName = "")
is SessionConfirm -> message.copy(initiatorSessionId = 0, initiatedSessionId = 0, appName = "")
is NormalSessionEnd -> message.copy(recipientSessionId = 0)
is ErrorSessionEnd -> message.copy(recipientSessionId = 0)
else -> message
is InitialSessionMessage -> message.copy(initiatorSessionId = SessionId(0), initiationEntropy = 0, appName = "")
is ExistingSessionMessage -> {
val payload = message.payload
message.copy(
recipientSessionId = SessionId(0),
payload = when (payload) {
is ConfirmSessionMessage -> payload.copy(
initiatedSessionId = SessionId(0),
initiatedFlowInfo = payload.initiatedFlowInfo.copy(appName = "")
)
is ErrorSessionMessage -> payload.copy(
errorId = 0
)
else -> payload
}
)
}
}
private infix fun StartedNode<MockNode>.sent(message: SessionMessage): Pair<Int, SessionMessage> = Pair(internals.id, message)

View File

@ -15,9 +15,7 @@ import net.corda.core.serialization.deserialize
import net.corda.core.utilities.ProgressTracker
import net.corda.netmap.VisualiserViewModel.Style
import net.corda.netmap.simulation.IRSSimulation
import net.corda.node.services.statemachine.SessionConfirm
import net.corda.node.services.statemachine.SessionEnd
import net.corda.node.services.statemachine.SessionInit
import net.corda.node.services.statemachine.*
import net.corda.testing.core.chooseIdentity
import net.corda.testing.node.InMemoryMessagingNetwork
import net.corda.testing.node.MockNetwork
@ -342,12 +340,16 @@ class NetworkMapVisualiser : Application() {
private fun transferIsInteresting(transfer: InMemoryMessagingNetwork.MessageTransfer): Boolean {
// Loopback messages are boring.
if (transfer.sender == transfer.recipients) return false
val message = transfer.message.data.deserialize<Any>()
val message = transfer.message.data.deserialize<SessionMessage>()
return when (message) {
is SessionEnd -> false
is SessionConfirm -> false
is SessionInit -> message.firstPayload != null
else -> true
is InitialSessionMessage -> message.firstPayload != null
is ExistingSessionMessage -> when (message.payload) {
is ConfirmSessionMessage -> false
is DataSessionMessage -> true
is ErrorSessionMessage -> true
is RejectSessionMessage -> true
is EndSessionMessage -> false
}
}
}
}

View File

@ -14,11 +14,17 @@ import net.corda.core.messaging.SingleMessageRecipient
import net.corda.core.node.services.PartyInfo
import net.corda.core.serialization.CordaSerializable
import net.corda.core.serialization.SingletonSerializeAsToken
import net.corda.core.utilities.ByteSequence
import net.corda.core.utilities.OpaqueBytes
import net.corda.core.utilities.getOrThrow
import net.corda.core.utilities.trace
import net.corda.node.services.messaging.*
import net.corda.node.services.messaging.Message
import net.corda.node.services.messaging.MessageHandlerRegistration
import net.corda.node.services.messaging.MessagingService
import net.corda.node.services.messaging.ReceivedMessage
import net.corda.node.utilities.AffinityExecutor
import net.corda.nodeapi.internal.persistence.CordaPersistence
import net.corda.testing.node.InMemoryMessagingNetwork.TestMessagingService
import org.apache.activemq.artemis.utils.ReusableLatch
import org.slf4j.LoggerFactory
import rx.Observable
@ -57,7 +63,7 @@ class InMemoryMessagingNetwork internal constructor(
@CordaSerializable
data class MessageTransfer(val sender: PeerHandle, val message: Message, val recipients: MessageRecipients) {
override fun toString() = "${message.topicSession} from '$sender' to '$recipients'"
override fun toString() = "${message.topic} from '$sender' to '$recipients'"
}
// All sent messages are kept here until pumpSend is called, or manuallyPumped is set to false
@ -241,17 +247,17 @@ class InMemoryMessagingNetwork internal constructor(
_sentMessages.onNext(transfer)
}
data class InMemoryMessage(override val topicSession: TopicSession,
override val data: ByteArray,
override val uniqueMessageId: UUID,
override val debugTimestamp: Instant = Instant.now()) : Message {
override fun toString() = "$topicSession#${String(data)}"
data class InMemoryMessage(override val topic: String,
override val data: ByteSequence,
override val uniqueMessageId: String,
override val debugTimestamp: Instant = Instant.now()) : Message {
override fun toString() = "$topic#${String(data.bytes)}"
}
private data class InMemoryReceivedMessage(override val topicSession: TopicSession,
override val data: ByteArray,
private data class InMemoryReceivedMessage(override val topic: String,
override val data: ByteSequence,
override val platformVersion: Int,
override val uniqueMessageId: UUID,
override val uniqueMessageId: String,
override val debugTimestamp: Instant,
override val peer: CordaX500Name) : ReceivedMessage
@ -270,8 +276,7 @@ class InMemoryMessagingNetwork internal constructor(
private val peerHandle: PeerHandle,
private val executor: AffinityExecutor,
private val database: CordaPersistence) : SingletonSerializeAsToken(), TestMessagingService {
private inner class Handler(val topicSession: TopicSession,
val callback: (ReceivedMessage, MessageHandlerRegistration) -> Unit) : MessageHandlerRegistration
inner class Handler(val topicSession: String, val callback: (ReceivedMessage, MessageHandlerRegistration) -> Unit) : MessageHandlerRegistration
@Volatile
private var running = true
@ -282,7 +287,7 @@ class InMemoryMessagingNetwork internal constructor(
}
private val state = ThreadBox(InnerState())
private val processedMessages: MutableSet<UUID> = Collections.synchronizedSet(HashSet<UUID>())
private val processedMessages: MutableSet<String> = Collections.synchronizedSet(HashSet<String>())
override val myAddress: PeerHandle get() = peerHandle
@ -304,13 +309,10 @@ class InMemoryMessagingNetwork internal constructor(
}
}
override fun addMessageHandler(topic: String, sessionID: Long, callback: (ReceivedMessage, MessageHandlerRegistration) -> Unit): MessageHandlerRegistration
= addMessageHandler(TopicSession(topic, sessionID), callback)
override fun addMessageHandler(topicSession: TopicSession, callback: (ReceivedMessage, MessageHandlerRegistration) -> Unit): MessageHandlerRegistration {
override fun addMessageHandler(topic: String, callback: (ReceivedMessage, MessageHandlerRegistration) -> Unit): MessageHandlerRegistration {
check(running)
val (handler, transfers) = state.locked {
val handler = Handler(topicSession, callback).apply { handlers.add(this) }
val handler = Handler(topic, callback).apply { handlers.add(this) }
val pending = ArrayList<MessageTransfer>()
database.transaction {
pending.addAll(pendingRedelivery)
@ -328,20 +330,18 @@ class InMemoryMessagingNetwork internal constructor(
state.locked { check(handlers.remove(registration as Handler)) }
}
override fun send(message: Message, target: MessageRecipients, retryId: Long?, sequenceKey: Any, acknowledgementHandler: (() -> Unit)?) {
override fun send(message: Message, target: MessageRecipients, retryId: Long?, sequenceKey: Any) {
check(running)
msgSend(this, message, target)
acknowledgementHandler?.invoke()
if (!sendManuallyPumped) {
pumpSend(false)
}
}
override fun send(addressedMessages: List<MessagingService.AddressedMessage>, acknowledgementHandler: (() -> Unit)?) {
override fun send(addressedMessages: List<MessagingService.AddressedMessage>) {
for ((message, target, retryId, sequenceKey) in addressedMessages) {
send(message, target, retryId, sequenceKey, null)
send(message, target, retryId, sequenceKey)
}
acknowledgementHandler?.invoke()
}
override fun stop() {
@ -356,8 +356,8 @@ class InMemoryMessagingNetwork internal constructor(
override fun cancelRedelivery(retryId: Long) {}
/** Returns the given (topic & session, data) pair as a newly created message object. */
override fun createMessage(topicSession: TopicSession, data: ByteArray, uuid: UUID): Message {
return InMemoryMessage(topicSession, data, uuid)
override fun createMessage(topic: String, data: ByteArray, deduplicationId: String): Message {
return InMemoryMessage(topic, OpaqueBytes(data), deduplicationId)
}
/**
@ -390,14 +390,14 @@ class InMemoryMessagingNetwork internal constructor(
while (deliverTo == null) {
val transfer = (if (block) q.take() else q.poll()) ?: return null
deliverTo = state.locked {
val matchingHandlers = handlers.filter { it.topicSession.isBlank() || transfer.message.topicSession == it.topicSession }
val matchingHandlers = handlers.filter { it.topicSession.isBlank() || transfer.message.topic == it.topicSession }
if (matchingHandlers.isEmpty()) {
// Got no handlers for this message yet. Keep the message around and attempt redelivery after a new
// handler has been registered. The purpose of this path is to make unit tests that have multi-threading
// reliable, as a sender may attempt to send a message to a receiver that hasn't finished setting
// up a handler for yet. Most unit tests don't run threaded, but we want to test true parallelism at
// least sometimes.
log.warn("Message to ${transfer.message.topicSession} could not be delivered")
log.warn("Message to ${transfer.message.topic} could not be delivered")
database.transaction {
pendingRedelivery.add(transfer)
}
@ -438,8 +438,8 @@ class InMemoryMessagingNetwork internal constructor(
}
private fun MessageTransfer.toReceivedMessage(): ReceivedMessage = InMemoryReceivedMessage(
message.topicSession,
message.data.copyOf(), // Kryo messes with the buffer so give each client a unique copy
message.topic,
OpaqueBytes(message.data.bytes.copyOf()), // Kryo messes with the buffer so give each client a unique copy
1,
message.uniqueMessageId,
message.debugTimestamp,