r3corda wire compatibility

This commit is contained in:
Andras Slemmer 2018-02-08 15:13:25 +00:00
parent 1902a4f11e
commit 0a88b76e46
19 changed files with 489 additions and 485 deletions

3
.idea/compiler.xml generated
View File

@ -72,6 +72,7 @@
<module name="irs-demo-web_test" target="1.8" /> <module name="irs-demo-web_test" target="1.8" />
<module name="irs-demo_integrationTest" target="1.8" /> <module name="irs-demo_integrationTest" target="1.8" />
<module name="irs-demo_main" target="1.8" /> <module name="irs-demo_main" target="1.8" />
<module name="irs-demo_systemTest" target="1.8" />
<module name="irs-demo_test" target="1.8" /> <module name="irs-demo_test" target="1.8" />
<module name="isolated_main" target="1.8" /> <module name="isolated_main" target="1.8" />
<module name="isolated_test" target="1.8" /> <module name="isolated_test" target="1.8" />
@ -159,4 +160,4 @@
<component name="JavacSettings"> <component name="JavacSettings">
<option name="ADDITIONAL_OPTIONS_STRING" value="-parameters" /> <option name="ADDITIONAL_OPTIONS_STRING" value="-parameters" />
</component> </component>
</project> </project>

View File

@ -10,12 +10,18 @@ import net.corda.core.identity.Party
import net.corda.core.messaging.MessageRecipients import net.corda.core.messaging.MessageRecipients
import net.corda.core.serialization.deserialize import net.corda.core.serialization.deserialize
import net.corda.core.serialization.serialize import net.corda.core.serialization.serialize
import net.corda.core.utilities.OpaqueBytes
import net.corda.core.utilities.getOrThrow import net.corda.core.utilities.getOrThrow
import net.corda.core.utilities.unwrap import net.corda.core.utilities.unwrap
import net.corda.node.internal.StartedNode import net.corda.node.internal.StartedNode
import net.corda.node.services.messaging.Message import net.corda.node.services.messaging.Message
import net.corda.node.services.statemachine.SessionData import net.corda.node.services.statemachine.DataSessionMessage
import net.corda.testing.node.* import net.corda.node.services.statemachine.ExistingSessionMessage
import net.corda.testing.node.InMemoryMessagingNetwork
import net.corda.testing.node.MessagingServiceSpy
import net.corda.testing.node.MockNetwork
import net.corda.testing.node.setMessagingServiceSpy
import net.corda.testing.node.startFlow
import org.junit.After import org.junit.After
import org.junit.Before import org.junit.Before
import org.junit.Rule import org.junit.Rule
@ -79,12 +85,12 @@ class TutorialMockNetwork {
// modify message if it's 1 // modify message if it's 1
nodeB.setMessagingServiceSpy(object : MessagingServiceSpy(nodeB.network) { nodeB.setMessagingServiceSpy(object : MessagingServiceSpy(nodeB.network) {
override fun send(message: Message, target: MessageRecipients, retryId: Long?, sequenceKey: Any, acknowledgementHandler: (() -> Unit)?) { override fun send(message: Message, target: MessageRecipients, retryId: Long?, sequenceKey: Any) {
val messageData = message.data.deserialize<Any>() val messageData = message.data.deserialize<Any>() as? ExistingSessionMessage
val payload = messageData?.payload
if (messageData is SessionData && messageData.payload.deserialize() == 1) { if (payload is DataSessionMessage && payload.payload.deserialize() == 1) {
val alteredMessageData = SessionData(messageData.recipientSessionId, 99.serialize()).serialize().bytes val alteredMessageData = messageData.copy(payload = payload.copy(99.serialize())).serialize().bytes
messagingService.send(InMemoryMessagingNetwork.InMemoryMessage(message.topicSession, alteredMessageData, message.uniqueMessageId), target, retryId) messagingService.send(InMemoryMessagingNetwork.InMemoryMessage(message.topic, OpaqueBytes(alteredMessageData), message.uniqueMessageId), target, retryId)
} else { } else {
messagingService.send(message, target, retryId) messagingService.send(message, target, retryId)
} }

View File

@ -3,7 +3,7 @@ package net.corda.nodeapi.internal.serialization
import com.esotericsoftware.kryo.Kryo import com.esotericsoftware.kryo.Kryo
import com.esotericsoftware.kryo.util.DefaultClassResolver import com.esotericsoftware.kryo.util.DefaultClassResolver
import net.corda.core.serialization.* import net.corda.core.serialization.*
import net.corda.node.services.statemachine.SessionData import net.corda.node.services.statemachine.DataSessionMessage
import net.corda.nodeapi.internal.serialization.amqp.DeserializationInput import net.corda.nodeapi.internal.serialization.amqp.DeserializationInput
import net.corda.nodeapi.internal.serialization.amqp.Envelope import net.corda.nodeapi.internal.serialization.amqp.Envelope
import net.corda.nodeapi.internal.serialization.amqp.SerializerFactory import net.corda.nodeapi.internal.serialization.amqp.SerializerFactory
@ -47,17 +47,17 @@ class ListsSerializationTest {
@Test @Test
fun `check list can be serialized as part of SessionData`() { fun `check list can be serialized as part of SessionData`() {
run { run {
val sessionData = SessionData(123, listOf(1).serialize()) val sessionData = DataSessionMessage(listOf(1).serialize())
assertEqualAfterRoundTripSerialization(sessionData) assertEqualAfterRoundTripSerialization(sessionData)
assertEquals(listOf(1), sessionData.payload.deserialize()) assertEquals(listOf(1), sessionData.payload.deserialize())
} }
run { run {
val sessionData = SessionData(123, listOf(1, 2).serialize()) val sessionData = DataSessionMessage(listOf(1, 2).serialize())
assertEqualAfterRoundTripSerialization(sessionData) assertEqualAfterRoundTripSerialization(sessionData)
assertEquals(listOf(1, 2), sessionData.payload.deserialize()) assertEquals(listOf(1, 2), sessionData.payload.deserialize())
} }
run { run {
val sessionData = SessionData(123, emptyList<Int>().serialize()) val sessionData = DataSessionMessage(emptyList<Int>().serialize())
assertEqualAfterRoundTripSerialization(sessionData) assertEqualAfterRoundTripSerialization(sessionData)
assertEquals(emptyList<Int>(), sessionData.payload.deserialize()) assertEquals(emptyList<Int>(), sessionData.payload.deserialize())
} }

View File

@ -6,7 +6,7 @@ import net.corda.core.identity.CordaX500Name
import net.corda.core.serialization.CordaSerializable import net.corda.core.serialization.CordaSerializable
import net.corda.core.serialization.deserialize import net.corda.core.serialization.deserialize
import net.corda.core.serialization.serialize import net.corda.core.serialization.serialize
import net.corda.node.services.statemachine.SessionData import net.corda.node.services.statemachine.DataSessionMessage
import net.corda.nodeapi.internal.serialization.kryo.kryoMagic import net.corda.nodeapi.internal.serialization.kryo.kryoMagic
import net.corda.testing.core.SerializationEnvironmentRule import net.corda.testing.core.SerializationEnvironmentRule
import net.corda.testing.internal.amqpSpecific import net.corda.testing.internal.amqpSpecific
@ -41,7 +41,7 @@ class MapsSerializationTest {
@Test @Test
fun `check list can be serialized as part of SessionData`() { fun `check list can be serialized as part of SessionData`() {
val sessionData = SessionData(123, smallMap.serialize()) val sessionData = DataSessionMessage(smallMap.serialize())
assertEqualAfterRoundTripSerialization(sessionData) assertEqualAfterRoundTripSerialization(sessionData)
assertEquals(smallMap, sessionData.payload.deserialize()) assertEquals(smallMap, sessionData.payload.deserialize())
} }

View File

@ -4,10 +4,10 @@ import com.esotericsoftware.kryo.Kryo
import com.esotericsoftware.kryo.util.DefaultClassResolver import com.esotericsoftware.kryo.util.DefaultClassResolver
import net.corda.core.serialization.deserialize import net.corda.core.serialization.deserialize
import net.corda.core.serialization.serialize import net.corda.core.serialization.serialize
import net.corda.node.services.statemachine.SessionData import net.corda.node.services.statemachine.DataSessionMessage
import net.corda.nodeapi.internal.serialization.kryo.kryoMagic import net.corda.nodeapi.internal.serialization.kryo.kryoMagic
import net.corda.testing.internal.kryoSpecific
import net.corda.testing.core.SerializationEnvironmentRule import net.corda.testing.core.SerializationEnvironmentRule
import net.corda.testing.internal.kryoSpecific
import org.junit.Assert.assertArrayEquals import org.junit.Assert.assertArrayEquals
import org.junit.Assert.assertEquals import org.junit.Assert.assertEquals
import org.junit.Rule import org.junit.Rule
@ -34,17 +34,17 @@ class SetsSerializationTest {
@Test @Test
fun `check set can be serialized as part of SessionData`() { fun `check set can be serialized as part of SessionData`() {
run { run {
val sessionData = SessionData(123, setOf(1).serialize()) val sessionData = DataSessionMessage(setOf(1).serialize())
assertEqualAfterRoundTripSerialization(sessionData) assertEqualAfterRoundTripSerialization(sessionData)
assertEquals(setOf(1), sessionData.payload.deserialize()) assertEquals(setOf(1), sessionData.payload.deserialize())
} }
run { run {
val sessionData = SessionData(123, setOf(1, 2).serialize()) val sessionData = DataSessionMessage(setOf(1, 2).serialize())
assertEqualAfterRoundTripSerialization(sessionData) assertEqualAfterRoundTripSerialization(sessionData)
assertEquals(setOf(1, 2), sessionData.payload.deserialize()) assertEquals(setOf(1, 2), sessionData.payload.deserialize())
} }
run { run {
val sessionData = SessionData(123, emptySet<Int>().serialize()) val sessionData = DataSessionMessage(emptySet<Int>().serialize())
assertEqualAfterRoundTripSerialization(sessionData) assertEqualAfterRoundTripSerialization(sessionData)
assertEquals(emptySet<Int>(), sessionData.payload.deserialize()) assertEquals(emptySet<Int>(), sessionData.payload.deserialize())
} }

View File

@ -1,9 +1,9 @@
package net.corda.services.messaging package net.corda.services.messaging
import net.corda.core.concurrent.CordaFuture import net.corda.core.concurrent.CordaFuture
import net.corda.core.crypto.random63BitValue
import net.corda.core.identity.CordaX500Name import net.corda.core.identity.CordaX500Name
import net.corda.core.internal.concurrent.map import net.corda.core.internal.concurrent.map
import net.corda.core.internal.concurrent.openFuture
import net.corda.core.internal.randomOrNull import net.corda.core.internal.randomOrNull
import net.corda.core.messaging.MessageRecipients import net.corda.core.messaging.MessageRecipients
import net.corda.core.messaging.SingleMessageRecipient import net.corda.core.messaging.SingleMessageRecipient
@ -14,7 +14,9 @@ import net.corda.core.utilities.getOrThrow
import net.corda.core.utilities.seconds import net.corda.core.utilities.seconds
import net.corda.node.internal.Node import net.corda.node.internal.Node
import net.corda.node.internal.StartedNode import net.corda.node.internal.StartedNode
import net.corda.node.services.messaging.* import net.corda.node.services.messaging.MessagingService
import net.corda.node.services.messaging.ReceivedMessage
import net.corda.node.services.messaging.send
import net.corda.testing.core.ALICE_NAME import net.corda.testing.core.ALICE_NAME
import net.corda.testing.core.chooseIdentity import net.corda.testing.core.chooseIdentity
import net.corda.testing.driver.DriverDSL import net.corda.testing.driver.DriverDSL
@ -27,6 +29,7 @@ import org.junit.Test
import java.util.* import java.util.*
import java.util.concurrent.CountDownLatch import java.util.concurrent.CountDownLatch
import java.util.concurrent.TimeUnit import java.util.concurrent.TimeUnit
import java.util.concurrent.atomic.AtomicBoolean
import java.util.concurrent.atomic.AtomicInteger import java.util.concurrent.atomic.AtomicInteger
class P2PMessagingTest { class P2PMessagingTest {
@ -50,19 +53,12 @@ class P2PMessagingTest {
alice.network.getAddressOfParty(getPartyInfo(notaryParty)!!) alice.network.getAddressOfParty(getPartyInfo(notaryParty)!!)
} }
val dummyTopic = "dummy.topic"
val responseMessage = "response" val responseMessage = "response"
val crashingNodes = simulateCrashingNodes(distributedServiceNodes, dummyTopic, responseMessage) val crashingNodes = simulateCrashingNodes(distributedServiceNodes, responseMessage)
// Send a single request with retry // Send a single request with retry
val responseFuture = with(alice.network) { val responseFuture = alice.receiveFrom(serviceAddress, retryId = 0)
val request = TestRequest(replyTo = myAddress)
val responseFuture = onNext<Any>(dummyTopic, request.sessionID)
val msg = createMessage(TopicSession(dummyTopic), data = request.serialize().bytes)
send(msg, serviceAddress, retryId = request.sessionID)
responseFuture
}
crashingNodes.firstRequestReceived.await(5, TimeUnit.SECONDS) crashingNodes.firstRequestReceived.await(5, TimeUnit.SECONDS)
// The request wasn't successful. // The request wasn't successful.
assertThat(responseFuture.isDone).isFalse() assertThat(responseFuture.isDone).isFalse()
@ -83,19 +79,12 @@ class P2PMessagingTest {
alice.network.getAddressOfParty(getPartyInfo(notaryParty)!!) alice.network.getAddressOfParty(getPartyInfo(notaryParty)!!)
} }
val dummyTopic = "dummy.topic"
val responseMessage = "response" val responseMessage = "response"
val crashingNodes = simulateCrashingNodes(distributedServiceNodes, dummyTopic, responseMessage) val crashingNodes = simulateCrashingNodes(distributedServiceNodes, responseMessage)
val sessionId = random63BitValue()
// Send a single request with retry // Send a single request with retry
with(alice.network) { alice.receiveFrom(serviceAddress, retryId = 0)
val request = TestRequest(sessionId, myAddress)
val msg = createMessage(TopicSession(dummyTopic), data = request.serialize().bytes)
send(msg, serviceAddress, retryId = request.sessionID)
}
// Wait until the first request is received // Wait until the first request is received
crashingNodes.firstRequestReceived.await() crashingNodes.firstRequestReceived.await()
@ -108,7 +97,13 @@ class P2PMessagingTest {
// Restart the node and expect a response // Restart the node and expect a response
val aliceRestarted = startAlice() val aliceRestarted = startAlice()
val response = aliceRestarted.network.onNext<Any>(dummyTopic, sessionId).getOrThrow()
val responseFuture = openFuture<Any>()
aliceRestarted.network.runOnNextMessage("test.response") {
responseFuture.set(it.data.deserialize())
}
val response = responseFuture.getOrThrow()
assertThat(crashingNodes.requestsReceived.get()).isGreaterThan(numberOfRequestsReceived) assertThat(crashingNodes.requestsReceived.get()).isGreaterThan(numberOfRequestsReceived)
assertThat(response).isEqualTo(responseMessage) assertThat(response).isEqualTo(responseMessage)
} }
@ -133,11 +128,12 @@ class P2PMessagingTest {
) )
/** /**
* Sets up the [distributedServiceNodes] to respond to [dummyTopic] requests. All nodes will receive requests and * Sets up the [distributedServiceNodes] to respond to "test.request" requests. All nodes will receive requests and
* either ignore them or respond, depending on the value of [CrashingNodes.ignoreRequests], initially set to true. * either ignore them or respond to "test.response", depending on the value of [CrashingNodes.ignoreRequests],
* This may be used to simulate scenarios where nodes receive request messages but crash before sending back a response. * initially set to true. This may be used to simulate scenarios where nodes receive request messages but crash
* before sending back a response.
*/ */
private fun simulateCrashingNodes(distributedServiceNodes: List<StartedNode<*>>, dummyTopic: String, responseMessage: String): CrashingNodes { private fun simulateCrashingNodes(distributedServiceNodes: List<StartedNode<*>>, responseMessage: String): CrashingNodes {
val crashingNodes = CrashingNodes( val crashingNodes = CrashingNodes(
requestsReceived = AtomicInteger(0), requestsReceived = AtomicInteger(0),
firstRequestReceived = CountDownLatch(1), firstRequestReceived = CountDownLatch(1),
@ -146,7 +142,7 @@ class P2PMessagingTest {
distributedServiceNodes.forEach { distributedServiceNodes.forEach {
val nodeName = it.info.chooseIdentity().name val nodeName = it.info.chooseIdentity().name
it.network.addMessageHandler(dummyTopic) { netMessage, _ -> it.network.addMessageHandler("test.request") { netMessage, _ ->
crashingNodes.requestsReceived.incrementAndGet() crashingNodes.requestsReceived.incrementAndGet()
crashingNodes.firstRequestReceived.countDown() crashingNodes.firstRequestReceived.countDown()
// The node which receives the first request will ignore all requests // The node which receives the first request will ignore all requests
@ -158,7 +154,7 @@ class P2PMessagingTest {
} else { } else {
println("sending response") println("sending response")
val request = netMessage.data.deserialize<TestRequest>() val request = netMessage.data.deserialize<TestRequest>()
val response = it.network.createMessage(dummyTopic, request.sessionID, responseMessage.serialize().bytes) val response = it.network.createMessage("test.response", responseMessage.serialize().bytes)
it.network.send(response, request.replyTo) it.network.send(response, request.replyTo)
} }
} }
@ -188,19 +184,39 @@ class P2PMessagingTest {
} }
private fun StartedNode<*>.respondWith(message: Any) { private fun StartedNode<*>.respondWith(message: Any) {
network.addMessageHandler(javaClass.name) { netMessage, _ -> network.addMessageHandler("test.request") { netMessage, _ ->
val request = netMessage.data.deserialize<TestRequest>() val request = netMessage.data.deserialize<TestRequest>()
val response = network.createMessage(javaClass.name, request.sessionID, message.serialize().bytes) val response = network.createMessage("test.response", message.serialize().bytes)
network.send(response, request.replyTo) network.send(response, request.replyTo)
} }
} }
private fun StartedNode<*>.receiveFrom(target: MessageRecipients): CordaFuture<Any> { private fun StartedNode<*>.receiveFrom(target: MessageRecipients, retryId: Long? = null): CordaFuture<Any> {
val request = TestRequest(replyTo = network.myAddress) val response = openFuture<Any>()
return network.sendRequest(javaClass.name, request, target) network.runOnNextMessage("test.response") { netMessage ->
response.set(netMessage.data.deserialize())
}
network.send("test.request", TestRequest(replyTo = network.myAddress), target, retryId = retryId)
return response
}
/**
* Registers a handler for the given topic and session that runs the given callback with the message and then removes
* itself. This is useful for one-shot handlers that aren't supposed to stick around permanently. Note that this callback
* doesn't take the registration object, unlike the callback to [MessagingService.addMessageHandler].
*
* @param topic identifier for the topic and session to listen for messages arriving on.
*/
inline fun MessagingService.runOnNextMessage(topic: String, crossinline callback: (ReceivedMessage) -> Unit) {
val consumed = AtomicBoolean()
addMessageHandler(topic) { msg, reg ->
removeMessageHandler(reg)
check(!consumed.getAndSet(true)) { "Called more than once" }
check(msg.topic == topic) { "Topic/session mismatch: ${msg.topic} vs $topic" }
callback(msg)
}
} }
@CordaSerializable @CordaSerializable
private data class TestRequest(override val sessionID: Long = random63BitValue(), private data class TestRequest(val replyTo: SingleMessageRecipient)
override val replyTo: SingleMessageRecipient) : ServiceRequestMessage
} }

View File

@ -1,18 +1,15 @@
package net.corda.node.services.messaging package net.corda.node.services.messaging
import net.corda.core.concurrent.CordaFuture import co.paralleluniverse.fibers.Suspendable
import net.corda.core.identity.CordaX500Name import net.corda.core.identity.CordaX500Name
import net.corda.core.internal.concurrent.openFuture
import net.corda.core.internal.uncheckedCast
import net.corda.core.messaging.MessageRecipients import net.corda.core.messaging.MessageRecipients
import net.corda.core.messaging.SingleMessageRecipient import net.corda.core.messaging.SingleMessageRecipient
import net.corda.core.node.services.PartyInfo import net.corda.core.node.services.PartyInfo
import net.corda.core.serialization.CordaSerializable import net.corda.core.serialization.CordaSerializable
import net.corda.core.serialization.deserialize
import net.corda.core.serialization.serialize import net.corda.core.serialization.serialize
import net.corda.core.utilities.ByteSequence
import java.time.Instant import java.time.Instant
import java.util.* import java.util.*
import java.util.concurrent.atomic.AtomicBoolean
import javax.annotation.concurrent.ThreadSafe import javax.annotation.concurrent.ThreadSafe
/** /**
@ -27,29 +24,6 @@ import javax.annotation.concurrent.ThreadSafe
*/ */
@ThreadSafe @ThreadSafe
interface MessagingService { interface MessagingService {
companion object {
/**
* Session ID to use for services listening for the first message in a session (before a
* specific session ID has been established).
*/
val DEFAULT_SESSION_ID = 0L
}
/**
* The provided function will be invoked for each received message whose topic matches the given string. The callback
* will run on threads provided by the messaging service, and the callback is expected to be thread safe as a result.
*
* The returned object is an opaque handle that may be used to un-register handlers later with [removeMessageHandler].
* The handle is passed to the callback as well, to avoid race conditions whereby the callback wants to unregister
* itself and yet addMessageHandler hasn't returned the handle yet.
*
* @param topic identifier for the general subject of the message, for example "platform.network_map.fetch".
* The topic can be the empty string to match all messages (session ID must be [DEFAULT_SESSION_ID]).
* @param sessionID identifier for the session the message is part of. For services listening before
* a session is established, use [DEFAULT_SESSION_ID].
*/
fun addMessageHandler(topic: String = "", sessionID: Long = DEFAULT_SESSION_ID, callback: (ReceivedMessage, MessageHandlerRegistration) -> Unit): MessageHandlerRegistration
/** /**
* The provided function will be invoked for each received message whose topic and session matches. The callback * The provided function will be invoked for each received message whose topic and session matches. The callback
* will run on the main server thread provided when the messaging service is constructed, and a database * will run on the main server thread provided when the messaging service is constructed, and a database
@ -59,9 +33,9 @@ interface MessagingService {
* The handle is passed to the callback as well, to avoid race conditions whereby the callback wants to unregister * The handle is passed to the callback as well, to avoid race conditions whereby the callback wants to unregister
* itself and yet addMessageHandler hasn't returned the handle yet. * itself and yet addMessageHandler hasn't returned the handle yet.
* *
* @param topicSession identifier for the topic and session to listen for messages arriving on. * @param topic identifier for the topic to listen for messages arriving on.
*/ */
fun addMessageHandler(topicSession: TopicSession, callback: (ReceivedMessage, MessageHandlerRegistration) -> Unit): MessageHandlerRegistration fun addMessageHandler(topic: String, callback: (ReceivedMessage, MessageHandlerRegistration) -> Unit): MessageHandlerRegistration
/** /**
* Removes a handler given the object returned from [addMessageHandler]. The callback will no longer be invoked once * Removes a handler given the object returned from [addMessageHandler]. The callback will no longer be invoked once
@ -86,15 +60,13 @@ interface MessagingService {
* @param sequenceKey an object that may be used to enable a parallel [MessagingService] implementation. Two * @param sequenceKey an object that may be used to enable a parallel [MessagingService] implementation. Two
* subsequent send()s with the same [sequenceKey] (up to equality) are guaranteed to be delivered in the same * subsequent send()s with the same [sequenceKey] (up to equality) are guaranteed to be delivered in the same
* sequence the send()s were called. By default this is chosen conservatively to be [target]. * sequence the send()s were called. By default this is chosen conservatively to be [target].
* @param acknowledgementHandler if non-null this handler will be called once the sent message has been committed by
* the broker. Note that if specified [send] itself may return earlier than the commit.
*/ */
@Suspendable
fun send( fun send(
message: Message, message: Message,
target: MessageRecipients, target: MessageRecipients,
retryId: Long? = null, retryId: Long? = null,
sequenceKey: Any = target, sequenceKey: Any = target
acknowledgementHandler: (() -> Unit)? = null
) )
/** A message with a target and sequenceKey specified. */ /** A message with a target and sequenceKey specified. */
@ -110,12 +82,9 @@ interface MessagingService {
* implementation. * implementation.
* *
* @param addressedMessages The list of messages together with the recipients, retry ids and sequence keys. * @param addressedMessages The list of messages together with the recipients, retry ids and sequence keys.
* @param retryId if provided the message will be scheduled for redelivery until [cancelRedelivery] is called for this id.
* Note that this feature should only be used when the target is an idempotent distributed service, e.g. a notary.
* @param acknowledgementHandler if non-null this handler will be called once all sent messages have been committed
* by the broker. Note that if specified [send] itself may return earlier than the commit.
*/ */
fun send(addressedMessages: List<AddressedMessage>, acknowledgementHandler: (() -> Unit)? = null) @Suspendable
fun send(addressedMessages: List<AddressedMessage>)
/** Cancels the scheduled message redelivery for the specified [retryId] */ /** Cancels the scheduled message redelivery for the specified [retryId] */
fun cancelRedelivery(retryId: Long) fun cancelRedelivery(retryId: Long)
@ -123,9 +92,9 @@ interface MessagingService {
/** /**
* Returns an initialised [Message] with the current time, etc, already filled in. * Returns an initialised [Message] with the current time, etc, already filled in.
* *
* @param topicSession identifier for the topic and session the message is sent to. * @param topic identifier for the topic the message is sent to.
*/ */
fun createMessage(topicSession: TopicSession, data: ByteArray, uuid: UUID = UUID.randomUUID()): Message fun createMessage(topic: String, data: ByteArray, deduplicationId: String = UUID.randomUUID().toString()): Message
/** Given information about either a specific node or a service returns its corresponding address */ /** Given information about either a specific node or a service returns its corresponding address */
fun getAddressOfParty(partyInfo: PartyInfo): MessageRecipients fun getAddressOfParty(partyInfo: PartyInfo): MessageRecipients
@ -134,86 +103,12 @@ interface MessagingService {
val myAddress: SingleMessageRecipient val myAddress: SingleMessageRecipient
} }
/**
* Returns an initialised [Message] with the current time, etc, already filled in.
*
* @param topic identifier for the general subject of the message, for example "platform.network_map.fetch".
* Must not be blank.
* @param sessionID identifier for the session the message is part of. For messages sent to services before the
* construction of a session, use [DEFAULT_SESSION_ID].
*/
fun MessagingService.createMessage(topic: String, sessionID: Long = MessagingService.DEFAULT_SESSION_ID, data: ByteArray): Message
= createMessage(TopicSession(topic, sessionID), data)
/** fun MessagingService.send(topicSession: String, payload: Any, to: MessageRecipients, deduplicationId: String = UUID.randomUUID().toString(), retryId: Long? = null)
* Registers a handler for the given topic and session ID that runs the given callback with the message and then removes = send(createMessage(topicSession, payload.serialize().bytes, deduplicationId), to, retryId)
* itself. This is useful for one-shot handlers that aren't supposed to stick around permanently. Note that this callback
* doesn't take the registration object, unlike the callback to [MessagingService.addMessageHandler], as the handler is
* automatically deregistered before the callback runs.
*
* @param topic identifier for the general subject of the message, for example "platform.network_map.fetch".
* The topic can be the empty string to match all messages (session ID must be [DEFAULT_SESSION_ID]).
* @param sessionID identifier for the session the message is part of. For services listening before
* a session is established, use [DEFAULT_SESSION_ID].
*/
fun MessagingService.runOnNextMessage(topic: String, sessionID: Long, callback: (ReceivedMessage) -> Unit)
= runOnNextMessage(TopicSession(topic, sessionID), callback)
/**
* Registers a handler for the given topic and session that runs the given callback with the message and then removes
* itself. This is useful for one-shot handlers that aren't supposed to stick around permanently. Note that this callback
* doesn't take the registration object, unlike the callback to [MessagingService.addMessageHandler].
*
* @param topicSession identifier for the topic and session to listen for messages arriving on.
*/
inline fun MessagingService.runOnNextMessage(topicSession: TopicSession, crossinline callback: (ReceivedMessage) -> Unit) {
val consumed = AtomicBoolean()
addMessageHandler(topicSession) { msg, reg ->
removeMessageHandler(reg)
check(!consumed.getAndSet(true)) { "Called more than once" }
check(msg.topicSession == topicSession) { "Topic/session mismatch: ${msg.topicSession} vs $topicSession" }
callback(msg)
}
}
/**
* Returns a [CordaFuture] of the next message payload ([Message.data]) which is received on the given topic and sessionId.
* The payload is deserialized to an object of type [M]. Any exceptions thrown will be captured by the future.
*/
fun <M : Any> MessagingService.onNext(topic: String, sessionId: Long): CordaFuture<M> {
val messageFuture = openFuture<M>()
runOnNextMessage(topic, sessionId) { message ->
messageFuture.capture {
uncheckedCast(message.data.deserialize<Any>())
}
}
return messageFuture
}
fun MessagingService.send(topic: String, sessionID: Long, payload: Any, to: MessageRecipients, uuid: UUID = UUID.randomUUID()) {
send(TopicSession(topic, sessionID), payload, to, uuid)
}
fun MessagingService.send(topicSession: TopicSession, payload: Any, to: MessageRecipients, uuid: UUID = UUID.randomUUID(), retryId: Long? = null) {
send(createMessage(topicSession, payload.serialize().bytes, uuid), to, retryId)
}
interface MessageHandlerRegistration interface MessageHandlerRegistration
/**
* An identifier for the endpoint [MessagingService] message handlers listen at.
*
* @param topic identifier for the general subject of the message, for example "platform.network_map.fetch".
* The topic can be the empty string to match all messages (session ID must be [DEFAULT_SESSION_ID]).
* @param sessionID identifier for the session the message is part of. For services listening before
* a session is established, use [DEFAULT_SESSION_ID].
*/
@CordaSerializable
data class TopicSession(val topic: String, val sessionID: Long = MessagingService.DEFAULT_SESSION_ID) {
fun isBlank() = topic.isBlank() && sessionID == MessagingService.DEFAULT_SESSION_ID
override fun toString(): String = "$topic.$sessionID"
}
/** /**
* A message is defined, at this level, to be a (topic, timestamp, byte arrays) triple, where the topic is a string in * A message is defined, at this level, to be a (topic, timestamp, byte arrays) triple, where the topic is a string in
* Java-style reverse dns form, with "platform." being a prefix reserved by the platform for its own use. Vendor * Java-style reverse dns form, with "platform." being a prefix reserved by the platform for its own use. Vendor
@ -226,10 +121,10 @@ data class TopicSession(val topic: String, val sessionID: Long = MessagingServic
*/ */
@CordaSerializable @CordaSerializable
interface Message { interface Message {
val topicSession: TopicSession val topic: String
val data: ByteArray val data: ByteSequence
val debugTimestamp: Instant val debugTimestamp: Instant
val uniqueMessageId: UUID val uniqueMessageId: String
} }
// TODO Have ReceivedMessage point to the TLS certificate of the peer, and [peer] would simply be the subject DN of that. // TODO Have ReceivedMessage point to the TLS certificate of the peer, and [peer] would simply be the subject DN of that.

View File

@ -13,10 +13,7 @@ import net.corda.core.serialization.SerializationDefaults
import net.corda.core.serialization.SingletonSerializeAsToken import net.corda.core.serialization.SingletonSerializeAsToken
import net.corda.core.serialization.deserialize import net.corda.core.serialization.deserialize
import net.corda.core.serialization.serialize import net.corda.core.serialization.serialize
import net.corda.core.utilities.NetworkHostAndPort import net.corda.core.utilities.*
import net.corda.core.utilities.contextLogger
import net.corda.core.utilities.sequence
import net.corda.core.utilities.trace
import net.corda.node.VersionInfo import net.corda.node.VersionInfo
import net.corda.node.services.api.NetworkMapCacheInternal import net.corda.node.services.api.NetworkMapCacheInternal
import net.corda.node.services.config.NodeConfiguration import net.corda.node.services.config.NodeConfiguration
@ -98,20 +95,19 @@ class P2PMessagingClient(config: NodeConfiguration,
// that will handle messages, like a URL) with the terminology used by underlying MQ libraries, to avoid // that will handle messages, like a URL) with the terminology used by underlying MQ libraries, to avoid
// confusion. // confusion.
private val topicProperty = SimpleString("platform-topic") private val topicProperty = SimpleString("platform-topic")
private val sessionIdProperty = SimpleString("session-id")
private val cordaVendorProperty = SimpleString("corda-vendor") private val cordaVendorProperty = SimpleString("corda-vendor")
private val releaseVersionProperty = SimpleString("release-version") private val releaseVersionProperty = SimpleString("release-version")
private val platformVersionProperty = SimpleString("platform-version") private val platformVersionProperty = SimpleString("platform-version")
private val amqDelayMillis = System.getProperty("amq.delivery.delay.ms", "0").toInt() private val amqDelayMillis = System.getProperty("amq.delivery.delay.ms", "0").toInt()
private val messageMaxRetryCount: Int = 3 private val messageMaxRetryCount: Int = 3
fun createProcessedMessage(): AppendOnlyPersistentMap<UUID, Instant, ProcessedMessage, String> { fun createProcessedMessage(): AppendOnlyPersistentMap<String, Instant, ProcessedMessage, String> {
return AppendOnlyPersistentMap( return AppendOnlyPersistentMap(
toPersistentEntityKey = { it.toString() }, toPersistentEntityKey = { it },
fromPersistentEntity = { Pair(UUID.fromString(it.uuid), it.insertionTime) }, fromPersistentEntity = { Pair(it.uuid, it.insertionTime) },
toPersistentEntity = { key: UUID, value: Instant -> toPersistentEntity = { key: String, value: Instant ->
ProcessedMessage().apply { ProcessedMessage().apply {
uuid = key.toString() uuid = key
insertionTime = value insertionTime = value
} }
}, },
@ -139,9 +135,9 @@ class P2PMessagingClient(config: NodeConfiguration,
) )
} }
private class NodeClientMessage(override val topicSession: TopicSession, override val data: ByteArray, override val uniqueMessageId: UUID) : Message { private class NodeClientMessage(override val topic: String, override val data: ByteSequence, override val uniqueMessageId: String) : Message {
override val debugTimestamp: Instant = Instant.now() override val debugTimestamp: Instant = Instant.now()
override fun toString() = "$topicSession#${String(data)}" override fun toString() = "$topic#${String(data.bytes)}"
} }
} }
@ -160,7 +156,7 @@ class P2PMessagingClient(config: NodeConfiguration,
private val scheduledMessageRedeliveries = ConcurrentHashMap<Long, ScheduledFuture<*>>() private val scheduledMessageRedeliveries = ConcurrentHashMap<Long, ScheduledFuture<*>>()
/** A registration to handle messages of different types */ /** A registration to handle messages of different types */
data class Handler(val topicSession: TopicSession, data class Handler(val topic: String,
val callback: (ReceivedMessage, MessageHandlerRegistration) -> Unit) : MessageHandlerRegistration val callback: (ReceivedMessage, MessageHandlerRegistration) -> Unit) : MessageHandlerRegistration
private val cordaVendor = SimpleString(versionInfo.vendor) private val cordaVendor = SimpleString(versionInfo.vendor)
@ -181,7 +177,7 @@ class P2PMessagingClient(config: NodeConfiguration,
@javax.persistence.Table(name = "${NODE_DATABASE_PREFIX}message_ids") @javax.persistence.Table(name = "${NODE_DATABASE_PREFIX}message_ids")
class ProcessedMessage( class ProcessedMessage(
@Id @Id
@Column(name = "message_id", length = 36) @Column(name = "message_id", length = 64)
var uuid: String = "", var uuid: String = "",
@Column(name = "insertion_time") @Column(name = "insertion_time")
@ -192,7 +188,7 @@ class P2PMessagingClient(config: NodeConfiguration,
@javax.persistence.Table(name = "${NODE_DATABASE_PREFIX}message_retry") @javax.persistence.Table(name = "${NODE_DATABASE_PREFIX}message_retry")
class RetryMessage( class RetryMessage(
@Id @Id
@Column(name = "message_id", length = 36) @Column(name = "message_id", length = 64)
var key: Long = 0, var key: Long = 0,
@Lob @Lob
@ -383,14 +379,13 @@ class P2PMessagingClient(config: NodeConfiguration,
private fun artemisToCordaMessage(message: ClientMessage): ReceivedMessage? { private fun artemisToCordaMessage(message: ClientMessage): ReceivedMessage? {
try { try {
val topic = message.required(topicProperty) { getStringProperty(it) } val topic = message.required(topicProperty) { getStringProperty(it) }
val sessionID = message.required(sessionIdProperty) { getLongProperty(it) }
val user = requireNotNull(message.getStringProperty(HDR_VALIDATED_USER)) { "Message is not authenticated" } val user = requireNotNull(message.getStringProperty(HDR_VALIDATED_USER)) { "Message is not authenticated" }
val platformVersion = message.required(platformVersionProperty) { getIntProperty(it) } val platformVersion = message.required(platformVersionProperty) { getIntProperty(it) }
// Use the magic deduplication property built into Artemis as our message identity too // Use the magic deduplication property built into Artemis as our message identity too
val uuid = message.required(HDR_DUPLICATE_DETECTION_ID) { UUID.fromString(message.getStringProperty(it)) } val uuid = message.required(HDR_DUPLICATE_DETECTION_ID) { message.getStringProperty(it) }
log.info("Received message from: ${message.address} user: $user topic: $topic sessionID: $sessionID uuid: $uuid") log.info("Received message from: ${message.address} user: $user topic: $topic uuid: $uuid")
return ArtemisReceivedMessage(TopicSession(topic, sessionID), CordaX500Name.parse(user), platformVersion, uuid, message) return ArtemisReceivedMessage(topic, CordaX500Name.parse(user), platformVersion, uuid, message)
} catch (e: Exception) { } catch (e: Exception) {
log.error("Unable to process message, ignoring it: $message", e) log.error("Unable to process message, ignoring it: $message", e)
return null return null
@ -402,21 +397,21 @@ class P2PMessagingClient(config: NodeConfiguration,
return extractor(key) return extractor(key)
} }
private class ArtemisReceivedMessage(override val topicSession: TopicSession, private class ArtemisReceivedMessage(override val topic: String,
override val peer: CordaX500Name, override val peer: CordaX500Name,
override val platformVersion: Int, override val platformVersion: Int,
override val uniqueMessageId: UUID, override val uniqueMessageId: String,
private val message: ClientMessage) : ReceivedMessage { private val message: ClientMessage) : ReceivedMessage {
override val data: ByteArray by lazy { ByteArray(message.bodySize).apply { message.bodyBuffer.readBytes(this) } } override val data: ByteSequence by lazy { OpaqueBytes(ByteArray(message.bodySize).apply { message.bodyBuffer.readBytes(this) }) }
override val debugTimestamp: Instant get() = Instant.ofEpochMilli(message.timestamp) override val debugTimestamp: Instant get() = Instant.ofEpochMilli(message.timestamp)
override fun toString() = "${topicSession.topic}#${data.sequence()}" override fun toString() = "$topic#$data"
} }
private fun deliver(msg: ReceivedMessage): Boolean { private fun deliver(msg: ReceivedMessage): Boolean {
state.checkNotLocked() state.checkNotLocked()
// Because handlers is a COW list, the loop inside filter will operate on a snapshot. Handlers being added // Because handlers is a COW list, the loop inside filter will operate on a snapshot. Handlers being added
// or removed whilst the filter is executing will not affect anything. // or removed whilst the filter is executing will not affect anything.
val deliverTo = handlers.filter { it.topicSession.isBlank() || it.topicSession == msg.topicSession } val deliverTo = handlers.filter { it.topic.isBlank() || it.topic== msg.topic }
try { try {
// This will perform a BLOCKING call onto the executor. Thus if the handlers are slow, we will // This will perform a BLOCKING call onto the executor. Thus if the handlers are slow, we will
// be slow, and Artemis can handle that case intelligently. We don't just invoke the handler // be slow, and Artemis can handle that case intelligently. We don't just invoke the handler
@ -429,11 +424,11 @@ class P2PMessagingClient(config: NodeConfiguration,
nodeExecutor.fetchFrom { nodeExecutor.fetchFrom {
database.transaction { database.transaction {
if (msg.uniqueMessageId in processedMessages) { if (msg.uniqueMessageId in processedMessages) {
log.trace { "Discard duplicate message ${msg.uniqueMessageId} for ${msg.topicSession}" } log.trace { "Discard duplicate message ${msg.uniqueMessageId} for ${msg.topic}" }
} else { } else {
if (deliverTo.isEmpty()) { if (deliverTo.isEmpty()) {
// TODO: Implement dead letter queue, and send it there. // TODO: Implement dead letter queue, and send it there.
log.warn("Received message ${msg.uniqueMessageId} for ${msg.topicSession} that doesn't have any registered handlers yet") log.warn("Received message ${msg.uniqueMessageId} for ${msg.topic} that doesn't have any registered handlers yet")
} else { } else {
callHandlers(msg, deliverTo) callHandlers(msg, deliverTo)
} }
@ -443,7 +438,7 @@ class P2PMessagingClient(config: NodeConfiguration,
} }
} }
} catch (e: Exception) { } catch (e: Exception) {
log.error("Caught exception whilst executing message handler for ${msg.topicSession}", e) log.error("Caught exception whilst executing message handler for ${msg.topic}", e)
} }
return true return true
} }
@ -501,7 +496,7 @@ class P2PMessagingClient(config: NodeConfiguration,
} }
} }
override fun send(message: Message, target: MessageRecipients, retryId: Long?, sequenceKey: Any, acknowledgementHandler: (() -> Unit)?) { override fun send(message: Message, target: MessageRecipients, retryId: Long?, sequenceKey: Any) {
// We have to perform sending on a different thread pool, since using the same pool for messaging and // We have to perform sending on a different thread pool, since using the same pool for messaging and
// fibers leads to Netty buffer memory leaks, caused by both Netty and Quasar fiddling with thread-locals. // fibers leads to Netty buffer memory leaks, caused by both Netty and Quasar fiddling with thread-locals.
messagingExecutor.fetchFrom { messagingExecutor.fetchFrom {
@ -512,20 +507,18 @@ class P2PMessagingClient(config: NodeConfiguration,
putStringProperty(cordaVendorProperty, cordaVendor) putStringProperty(cordaVendorProperty, cordaVendor)
putStringProperty(releaseVersionProperty, releaseVersion) putStringProperty(releaseVersionProperty, releaseVersion)
putIntProperty(platformVersionProperty, versionInfo.platformVersion) putIntProperty(platformVersionProperty, versionInfo.platformVersion)
putStringProperty(topicProperty, SimpleString(message.topicSession.topic)) putStringProperty(topicProperty, SimpleString(message.topic))
putLongProperty(sessionIdProperty, message.topicSession.sessionID) writeBodyBufferBytes(message.data.bytes)
writeBodyBufferBytes(message.data)
// Use the magic deduplication property built into Artemis as our message identity too // Use the magic deduplication property built into Artemis as our message identity too
putStringProperty(HDR_DUPLICATE_DETECTION_ID, SimpleString(message.uniqueMessageId.toString())) putStringProperty(HDR_DUPLICATE_DETECTION_ID, SimpleString(message.uniqueMessageId.toString()))
// For demo purposes - if set then add a delay to messages in order to demonstrate that the flows are doing as intended // For demo purposes - if set then add a delay to messages in order to demonstrate that the flows are doing as intended
if (amqDelayMillis > 0 && message.topicSession.topic == StateMachineManagerImpl.sessionTopic.topic) { if (amqDelayMillis > 0 && message.topic == StateMachineManagerImpl.sessionTopic) {
putLongProperty(HDR_SCHEDULED_DELIVERY_TIME, System.currentTimeMillis() + amqDelayMillis) putLongProperty(HDR_SCHEDULED_DELIVERY_TIME, System.currentTimeMillis() + amqDelayMillis)
} }
} }
log.trace { log.trace {
"Send to: $mqAddress topic: ${message.topicSession.topic} " + "Send to: $mqAddress topic: ${message.topic} uuid: ${message.uniqueMessageId}"
"sessionID: ${message.topicSession.sessionID} uuid: ${message.uniqueMessageId}"
} }
artemis.producer.send(mqAddress, artemisMessage) artemis.producer.send(mqAddress, artemisMessage)
retryId?.let { retryId?.let {
@ -539,14 +532,12 @@ class P2PMessagingClient(config: NodeConfiguration,
} }
} }
} }
acknowledgementHandler?.invoke()
} }
override fun send(addressedMessages: List<MessagingService.AddressedMessage>, acknowledgementHandler: (() -> Unit)?) { override fun send(addressedMessages: List<MessagingService.AddressedMessage>) {
for ((message, target, retryId, sequenceKey) in addressedMessages) { for ((message, target, retryId, sequenceKey) in addressedMessages) {
send(message, target, retryId, sequenceKey, null) send(message, target, retryId, sequenceKey)
} }
acknowledgementHandler?.invoke()
} }
private fun sendWithRetry(retryCount: Int, address: String, message: ClientMessage, retryId: Long) { private fun sendWithRetry(retryCount: Int, address: String, message: ClientMessage, retryId: Long) {
@ -622,15 +613,9 @@ class P2PMessagingClient(config: NodeConfiguration,
} }
override fun addMessageHandler(topic: String, override fun addMessageHandler(topic: String,
sessionID: Long,
callback: (ReceivedMessage, MessageHandlerRegistration) -> Unit): MessageHandlerRegistration { callback: (ReceivedMessage, MessageHandlerRegistration) -> Unit): MessageHandlerRegistration {
return addMessageHandler(TopicSession(topic, sessionID), callback) require(!topic.isBlank()) { "Topic must not be blank, as the empty topic is a special case." }
} val handler = Handler(topic, callback)
override fun addMessageHandler(topicSession: TopicSession,
callback: (ReceivedMessage, MessageHandlerRegistration) -> Unit): MessageHandlerRegistration {
require(!topicSession.isBlank()) { "Topic must not be blank, as the empty topic is a special case." }
val handler = Handler(topicSession, callback)
handlers.add(handler) handlers.add(handler)
return handler return handler
} }
@ -639,9 +624,9 @@ class P2PMessagingClient(config: NodeConfiguration,
handlers.remove(registration) handlers.remove(registration)
} }
override fun createMessage(topicSession: TopicSession, data: ByteArray, uuid: UUID): Message { override fun createMessage(topic: String, data: ByteArray, deduplicationId: String): Message {
// TODO: We could write an object that proxies directly to an underlying MQ message here and avoid copying. // TODO: We could write an object that proxies directly to an underlying MQ message here and avoid copying.
return NodeClientMessage(topicSession, data, uuid) return NodeClientMessage(topic, OpaqueBytes(data), deduplicationId)
} }
// TODO Rethink PartyInfo idea and merging PeerAddress/ServiceAddress (the only difference is that Service address doesn't hold host and port) // TODO Rethink PartyInfo idea and merging PeerAddress/ServiceAddress (the only difference is that Service address doesn't hold host and port)

View File

@ -1,27 +0,0 @@
package net.corda.node.services.messaging
import net.corda.core.concurrent.CordaFuture
import net.corda.core.messaging.MessageRecipients
import net.corda.core.messaging.SingleMessageRecipient
import net.corda.core.serialization.CordaSerializable
/**
* Abstract superclass for request messages sent to services which expect a reply.
*/
@CordaSerializable
interface ServiceRequestMessage {
val sessionID: Long
val replyTo: SingleMessageRecipient
}
/**
* Sends a [ServiceRequestMessage] to [target] and returns a [CordaFuture] of the response.
* @param R The type of the response.
*/
fun <R : Any> MessagingService.sendRequest(topic: String,
request: ServiceRequestMessage,
target: MessageRecipients): CordaFuture<R> {
val responseFuture = onNext<R>(topic, request.sessionID)
send(topic, MessagingService.DEFAULT_SESSION_ID, request, target)
return responseFuture
}

View File

@ -2,6 +2,7 @@ package net.corda.node.services.statemachine
import co.paralleluniverse.fibers.Suspendable import co.paralleluniverse.fibers.Suspendable
import net.corda.core.crypto.SecureHash import net.corda.core.crypto.SecureHash
import net.corda.core.identity.Party
import java.time.Instant import java.time.Instant
interface FlowIORequest { interface FlowIORequest {
@ -22,29 +23,28 @@ interface SendRequest : SessionedFlowIORequest {
val message: SessionMessage val message: SessionMessage
} }
interface ReceiveRequest<T : SessionMessage> : SessionedFlowIORequest, WaitingRequest { interface ReceiveRequest : SessionedFlowIORequest, WaitingRequest {
val receiveType: Class<T>
val userReceiveType: Class<*>? val userReceiveType: Class<*>?
override fun shouldResume(message: ExistingSessionMessage, session: FlowSessionInternal): Boolean = this.session === session override fun shouldResume(message: ExistingSessionMessage, session: FlowSessionInternal): Boolean = this.session === session
} }
data class SendAndReceive<T : SessionMessage>(override val session: FlowSessionInternal, data class SendAndReceive(override val session: FlowSessionInternal,
override val message: SessionMessage, override val message: SessionMessage,
override val receiveType: Class<T>, override val userReceiveType: Class<*>?) : SendRequest, ReceiveRequest {
override val userReceiveType: Class<*>?) : SendRequest, ReceiveRequest<T> {
@Transient @Transient
override val stackTraceInCaseOfProblems: StackSnapshot = StackSnapshot() override val stackTraceInCaseOfProblems: StackSnapshot = StackSnapshot()
} }
data class ReceiveOnly<T : SessionMessage>(override val session: FlowSessionInternal, data class ReceiveOnly(
override val receiveType: Class<T>, override val session: FlowSessionInternal,
override val userReceiveType: Class<*>?) : ReceiveRequest<T> { override val userReceiveType: Class<*>?
) : ReceiveRequest {
@Transient @Transient
override val stackTraceInCaseOfProblems: StackSnapshot = StackSnapshot() override val stackTraceInCaseOfProblems: StackSnapshot = StackSnapshot()
} }
class ReceiveAll(val requests: List<ReceiveRequest<SessionData>>) : WaitingRequest { class ReceiveAll(val requests: List<ReceiveRequest>) : WaitingRequest {
@Transient @Transient
override val stackTraceInCaseOfProblems: StackSnapshot = StackSnapshot() override val stackTraceInCaseOfProblems: StackSnapshot = StackSnapshot()
@ -53,8 +53,8 @@ class ReceiveAll(val requests: List<ReceiveRequest<SessionData>>) : WaitingReque
} }
private fun shouldResumeIfRelevant() = requests.all { hasSuccessfulEndMessage(it) } private fun shouldResumeIfRelevant() = requests.all { hasSuccessfulEndMessage(it) }
private fun hasSuccessfulEndMessage(it: ReceiveRequest<SessionData>): Boolean { private fun hasSuccessfulEndMessage(it: ReceiveRequest): Boolean {
return it.session.receivedMessages.map { it.message }.any { it is SessionData || it is SessionEnd } return it.session.receivedMessages.map { it.message.payload }.any { it is DataSessionMessage || it is EndSessionMessage }
} }
@Suspendable @Suspendable
@ -70,7 +70,7 @@ class ReceiveAll(val requests: List<ReceiveRequest<SessionData>>) : WaitingReque
if (isComplete(receivedMessages)) { if (isComplete(receivedMessages)) {
receivedMessages receivedMessages
} else { } else {
throw IllegalStateException(requests.filter { it.session !in receivedMessages.keys }.map { "Was expecting a ${it.receiveType.simpleName} but instead got nothing for $it." }.joinToString { "\n" }) throw IllegalStateException(requests.filter { it.session !in receivedMessages.keys }.map { "Was expecting a message but instead got nothing for $it." }.joinToString { "\n" })
} }
} }
} }
@ -90,15 +90,15 @@ class ReceiveAll(val requests: List<ReceiveRequest<SessionData>>) : WaitingReque
} }
@Suspendable @Suspendable
private fun poll(request: ReceiveRequest<SessionData>): ReceivedSessionMessage<*>? { private fun poll(request: ReceiveRequest): ExistingSessionMessage? {
return request.session.receivedMessages.poll() return request.session.receivedMessages.poll()?.message
} }
override fun shouldResume(message: ExistingSessionMessage, session: FlowSessionInternal): Boolean = isRelevant(session) && shouldResumeIfRelevant() override fun shouldResume(message: ExistingSessionMessage, session: FlowSessionInternal): Boolean = isRelevant(session) && shouldResumeIfRelevant()
private fun isRelevant(session: FlowSessionInternal) = requests.any { it.session === session } private fun isRelevant(session: FlowSessionInternal) = requests.any { it.session === session }
data class RequestMessage(val request: ReceiveRequest<SessionData>, val message: ReceivedSessionMessage<*>) data class RequestMessage(val request: ReceiveRequest, val message: ExistingSessionMessage)
} }
data class SendOnly(override val session: FlowSessionInternal, override val message: SessionMessage) : SendRequest { data class SendOnly(override val session: FlowSessionInternal, override val message: SessionMessage) : SendRequest {
@ -110,7 +110,7 @@ data class WaitForLedgerCommit(val hash: SecureHash, val fiber: FlowStateMachine
@Transient @Transient
override val stackTraceInCaseOfProblems: StackSnapshot = StackSnapshot() override val stackTraceInCaseOfProblems: StackSnapshot = StackSnapshot()
override fun shouldResume(message: ExistingSessionMessage, session: FlowSessionInternal): Boolean = message is ErrorSessionEnd override fun shouldResume(message: ExistingSessionMessage, session: FlowSessionInternal): Boolean = message.payload is ErrorSessionMessage
} }
data class Sleep(val until: Instant, val fiber: FlowStateMachineImpl<*>) : FlowIORequest { data class Sleep(val until: Instant, val fiber: FlowStateMachineImpl<*>) : FlowIORequest {

View File

@ -17,18 +17,28 @@ import java.util.concurrent.ConcurrentLinkedQueue
class FlowSessionInternal( class FlowSessionInternal(
val flow: FlowLogic<*>, val flow: FlowLogic<*>,
val flowSession : FlowSession, val flowSession : FlowSession,
val ourSessionId: Long, val ourSessionId: SessionId,
val initiatingParty: Party?, val initiatingParty: Party?,
var state: FlowSessionState, var state: FlowSessionState,
var retryable: Boolean = false) { var retryable: Boolean = false) {
val receivedMessages = ConcurrentLinkedQueue<ReceivedSessionMessage<*>>() val receivedMessages = ConcurrentLinkedQueue<ReceivedSessionMessage>()
val fiber: FlowStateMachineImpl<*> get() = flow.stateMachine as FlowStateMachineImpl<*> val fiber: FlowStateMachineImpl<*> get() = flow.stateMachine as FlowStateMachineImpl<*>
override fun toString(): String { override fun toString(): String {
return "${javaClass.simpleName}(flow=$flow, ourSessionId=$ourSessionId, initiatingParty=$initiatingParty, state=$state)" return "${javaClass.simpleName}(flow=$flow, ourSessionId=$ourSessionId, initiatingParty=$initiatingParty, state=$state)"
} }
fun getPeerSessionId(): SessionId {
val sessionState = state
return when (sessionState) {
is FlowSessionState.Initiated -> sessionState.peerSessionId
else -> throw IllegalStateException("We've somehow held onto a non-initiated session: $this")
}
}
} }
data class ReceivedSessionMessage(val peerParty: Party, val message: ExistingSessionMessage)
/** /**
* [FlowSessionState] describes the session's state. * [FlowSessionState] describes the session's state.
* *
@ -50,7 +60,7 @@ sealed class FlowSessionState {
override val sendToParty: Party get() = otherParty override val sendToParty: Party get() = otherParty
} }
data class Initiated(val peerParty: Party, val peerSessionId: Long, val context: FlowInfo) : FlowSessionState() { data class Initiated(val peerParty: Party, val peerSessionId: SessionId, val context: FlowInfo) : FlowSessionState() {
override val sendToParty: Party get() = peerParty override val sendToParty: Party get() = peerParty
} }
} }

View File

@ -9,7 +9,7 @@ import com.google.common.primitives.Primitives
import net.corda.core.concurrent.CordaFuture import net.corda.core.concurrent.CordaFuture
import net.corda.core.context.InvocationContext import net.corda.core.context.InvocationContext
import net.corda.core.crypto.SecureHash import net.corda.core.crypto.SecureHash
import net.corda.core.crypto.random63BitValue import net.corda.core.crypto.newSecureRandom
import net.corda.core.flows.* import net.corda.core.flows.*
import net.corda.core.identity.Party import net.corda.core.identity.Party
import net.corda.core.identity.PartyAndCertificate import net.corda.core.identity.PartyAndCertificate
@ -31,6 +31,7 @@ import net.corda.nodeapi.internal.persistence.contextTransaction
import net.corda.nodeapi.internal.persistence.contextTransactionOrNull import net.corda.nodeapi.internal.persistence.contextTransactionOrNull
import org.slf4j.Logger import org.slf4j.Logger
import org.slf4j.LoggerFactory import org.slf4j.LoggerFactory
import java.io.IOException
import java.nio.file.Paths import java.nio.file.Paths
import java.sql.SQLException import java.sql.SQLException
import java.time.Duration import java.time.Duration
@ -180,7 +181,7 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
requireNonPrimitive(receiveType) requireNonPrimitive(receiveType)
logger.debug { "sendAndReceive(${receiveType.name}, $otherParty, ${payload.toString().abbreviate(300)}) ..." } logger.debug { "sendAndReceive(${receiveType.name}, $otherParty, ${payload.toString().abbreviate(300)}) ..." }
val session = getConfirmedSessionIfPresent(otherParty, sessionFlow) val session = getConfirmedSessionIfPresent(otherParty, sessionFlow)
val receivedSessionData: ReceivedSessionMessage<SessionData> = if (session == null) { val receivedSessionMessage: ReceivedSessionMessage = if (session == null) {
val newSession = initiateSession(otherParty, sessionFlow, payload, waitForConfirmation = true, retryable = retrySend) val newSession = initiateSession(otherParty, sessionFlow, payload, waitForConfirmation = true, retryable = retrySend)
// Only do a receive here as the session init has carried the payload // Only do a receive here as the session init has carried the payload
receiveInternal(newSession, receiveType) receiveInternal(newSession, receiveType)
@ -188,8 +189,20 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
val sendData = createSessionData(session, payload) val sendData = createSessionData(session, payload)
sendAndReceiveInternal(session, sendData, receiveType) sendAndReceiveInternal(session, sendData, receiveType)
} }
logger.debug { "Received ${receivedSessionData.message.payload.toString().abbreviate(300)}" } val sessionData = receivedSessionMessage.message.checkDataSessionMessage()
return receivedSessionData.checkPayloadIs(receiveType) logger.debug { "Received ${sessionData.payload.toString().abbreviate(300)}" }
return sessionData.checkPayloadIs(receiveType)
}
private fun ExistingSessionMessage.checkDataSessionMessage(): DataSessionMessage {
when (payload) {
is DataSessionMessage -> {
return payload
}
else -> {
throw IllegalStateException("Was expecting ${DataSessionMessage::class.java.simpleName} but got ${payload.javaClass.simpleName} instead")
}
}
} }
@Suspendable @Suspendable
@ -200,9 +213,9 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
requireNonPrimitive(receiveType) requireNonPrimitive(receiveType)
logger.debug { "receive(${receiveType.name}, $otherParty) ..." } logger.debug { "receive(${receiveType.name}, $otherParty) ..." }
val session = getConfirmedSession(otherParty, sessionFlow) val session = getConfirmedSession(otherParty, sessionFlow)
val sessionData = receiveInternal<SessionData>(session, receiveType) val receivedSessionMessage = receiveInternal(session, receiveType).message.checkDataSessionMessage()
logger.debug { "Received ${sessionData.message.payload.toString().abbreviate(300)}" } logger.debug { "Received ${receivedSessionMessage.payload.toString().abbreviate(300)}" }
return sessionData.checkPayloadIs(receiveType) return receivedSessionMessage.checkPayloadIs(receiveType)
} }
private fun requireNonPrimitive(receiveType: Class<*>) { private fun requireNonPrimitive(receiveType: Class<*>) {
@ -219,7 +232,7 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
// Don't send the payload again if it was already piggy-backed on a session init // Don't send the payload again if it was already piggy-backed on a session init
initiateSession(otherParty, sessionFlow, payload, waitForConfirmation = false) initiateSession(otherParty, sessionFlow, payload, waitForConfirmation = false)
} else { } else {
sendInternal(session, createSessionData(session, payload)) sendInternal(session, ExistingSessionMessage(session.getPeerSessionId(), createSessionData(session, payload)))
} }
} }
@ -236,8 +249,8 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
// If the tx isn't committed then we may have been resumed due to an session ending in an error // If the tx isn't committed then we may have been resumed due to an session ending in an error
for (session in openSessions.values) { for (session in openSessions.values) {
for (receivedMessage in session.receivedMessages) { for (receivedMessage in session.receivedMessages) {
if (receivedMessage.message is ErrorSessionEnd) { if (receivedMessage.message.payload is ErrorSessionMessage) {
session.erroredEnd(receivedMessage.message) session.erroredEnd(receivedMessage.message.payload.flowException)
} }
} }
} }
@ -294,16 +307,18 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
@Suspendable @Suspendable
override fun receiveAll(sessions: Map<FlowSession, Class<out Any>>, sessionFlow: FlowLogic<*>): Map<FlowSession, UntrustworthyData<Any>> { override fun receiveAll(sessions: Map<FlowSession, Class<out Any>>, sessionFlow: FlowLogic<*>): Map<FlowSession, UntrustworthyData<Any>> {
val requests = ArrayList<ReceiveOnly<SessionData>>() val requests = ArrayList<ReceiveOnly>()
for ((session, receiveType) in sessions) { for ((session, receiveType) in sessions) {
val sessionInternal = getConfirmedSession(session.counterparty, sessionFlow) val sessionInternal = getConfirmedSession(session.counterparty, sessionFlow)
requests.add(ReceiveOnly(sessionInternal, SessionData::class.java, receiveType)) requests.add(ReceiveOnly(sessionInternal, receiveType))
} }
val receivedMessages = ReceiveAll(requests).suspendAndExpectReceive(suspend) val receivedMessages = ReceiveAll(requests).suspendAndExpectReceive(suspend)
val result = LinkedHashMap<FlowSession, UntrustworthyData<Any>>() val result = LinkedHashMap<FlowSession, UntrustworthyData<Any>>()
for ((sessionInternal, requestAndMessage) in receivedMessages) { for ((sessionInternal, requestAndMessage) in receivedMessages) {
val message = requestAndMessage.message.confirmReceiveType(requestAndMessage.request) val message = requestAndMessage.message.confirmNoError(requestAndMessage.request.session)
result[sessionInternal.flowSession] = message.checkPayloadIs(requestAndMessage.request.userReceiveType as Class<out Any>) result[sessionInternal.flowSession] = message.checkDataSessionMessage().checkPayloadIs(
requestAndMessage.request.userReceiveType as Class<out Any>
)
} }
return result return result
} }
@ -315,41 +330,46 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
*/ */
@Suspendable @Suspendable
private fun FlowSessionInternal.waitForConfirmation() { private fun FlowSessionInternal.waitForConfirmation() {
val (peerParty, sessionInitResponse) = receiveInternal<SessionInitResponse>(this, null) val sessionInitResponse = receiveInternal(this, null)
if (sessionInitResponse is SessionConfirm) { val payload = sessionInitResponse.message.payload
state = FlowSessionState.Initiated( when (payload) {
peerParty, is ConfirmSessionMessage -> {
sessionInitResponse.initiatedSessionId, state = FlowSessionState.Initiated(
FlowInfo(sessionInitResponse.flowVersion, sessionInitResponse.appName)) sessionInitResponse.
} else { peerParty,
sessionInitResponse as SessionReject payload.initiatedSessionId,
throw UnexpectedFlowEndException("Party ${state.sendToParty} rejected session request: ${sessionInitResponse.errorMessage}") payload.initiatedFlowInfo)
}
is RejectSessionMessage -> {
throw UnexpectedFlowEndException("Party ${state.sendToParty} rejected session request: ${payload.message}")
}
else -> {
throw IllegalStateException("Was expecting ${ConfirmSessionMessage::class.java.simpleName} but got ${payload.javaClass.simpleName} instead")
}
} }
} }
private fun createSessionData(session: FlowSessionInternal, payload: Any): SessionData { private fun createSessionData(session: FlowSessionInternal, payload: Any): DataSessionMessage {
val sessionState = session.state return DataSessionMessage(payload.serialize(context = SerializationDefaults.P2P_CONTEXT))
val peerSessionId = when (sessionState) {
is FlowSessionState.Initiated -> sessionState.peerSessionId
else -> throw IllegalStateException("We've somehow held onto a non-initiated session: $session")
}
return SessionData(peerSessionId, payload.serialize(context = SerializationDefaults.P2P_CONTEXT))
} }
@Suspendable @Suspendable
private fun sendInternal(session: FlowSessionInternal, message: SessionMessage) = suspend(SendOnly(session, message)) private fun sendInternal(session: FlowSessionInternal, message: SessionMessage) = suspend(SendOnly(session, message))
private inline fun <reified M : ExistingSessionMessage> receiveInternal( @Suspendable
private fun receiveInternal(
session: FlowSessionInternal, session: FlowSessionInternal,
userReceiveType: Class<*>?): ReceivedSessionMessage<M> { userReceiveType: Class<*>?): ReceivedSessionMessage {
return waitForMessage(ReceiveOnly(session, M::class.java, userReceiveType)) return waitForMessage(ReceiveOnly(session, userReceiveType))
} }
private inline fun <reified M : ExistingSessionMessage> sendAndReceiveInternal( @Suspendable
private fun sendAndReceiveInternal(
session: FlowSessionInternal, session: FlowSessionInternal,
message: SessionMessage, message: DataSessionMessage,
userReceiveType: Class<*>?): ReceivedSessionMessage<M> { userReceiveType: Class<*>?): ReceivedSessionMessage {
return waitForMessage(SendAndReceive(session, message, M::class.java, userReceiveType)) val sessionMessage = ExistingSessionMessage(session.getPeerSessionId(), message)
return waitForMessage(SendAndReceive(session, sessionMessage, userReceiveType))
} }
@Suspendable @Suspendable
@ -377,7 +397,7 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
sessionFlow: FlowLogic<*> sessionFlow: FlowLogic<*>
) { ) {
logger.trace { "Creating a new session with $otherParty" } logger.trace { "Creating a new session with $otherParty" }
val session = FlowSessionInternal(sessionFlow, flowSession, random63BitValue(), null, FlowSessionState.Uninitiated(otherParty)) val session = FlowSessionInternal(sessionFlow, flowSession, SessionId.createRandom(newSecureRandom()), null, FlowSessionState.Uninitiated(otherParty))
openSessions[Pair(sessionFlow, otherParty)] = session openSessions[Pair(sessionFlow, otherParty)] = session
} }
@ -397,7 +417,7 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
val (version, initiatingFlowClass) = session.flow.javaClass.flowVersionAndInitiatingClass val (version, initiatingFlowClass) = session.flow.javaClass.flowVersionAndInitiatingClass
val payloadBytes = firstPayload?.serialize(context = SerializationDefaults.P2P_CONTEXT) val payloadBytes = firstPayload?.serialize(context = SerializationDefaults.P2P_CONTEXT)
logger.info("Initiating flow session with party ${otherParty.name}. Session id for tracing purposes is ${session.ourSessionId}.") logger.info("Initiating flow session with party ${otherParty.name}. Session id for tracing purposes is ${session.ourSessionId}.")
val sessionInit = SessionInit(session.ourSessionId, initiatingFlowClass.name, version, session.flow.javaClass.appName, payloadBytes) val sessionInit = InitialSessionMessage(session.ourSessionId, newSecureRandom().nextLong(), initiatingFlowClass.name, version, session.flow.javaClass.appName, payloadBytes)
sendInternal(session, sessionInit) sendInternal(session, sessionInit)
if (waitForConfirmation) { if (waitForConfirmation) {
session.waitForConfirmation() session.waitForConfirmation()
@ -406,8 +426,10 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
} }
@Suspendable @Suspendable
private fun <M : ExistingSessionMessage> waitForMessage(receiveRequest: ReceiveRequest<M>): ReceivedSessionMessage<M> { private fun waitForMessage(receiveRequest: ReceiveRequest): ReceivedSessionMessage {
return receiveRequest.suspendAndExpectReceive().confirmReceiveType(receiveRequest) val receivedMessage = receiveRequest.suspendAndExpectReceive()
receivedMessage.message.confirmNoError(receiveRequest.session)
return receivedMessage
} }
private val suspend : ReceiveAll.Suspend = object : ReceiveAll.Suspend { private val suspend : ReceiveAll.Suspend = object : ReceiveAll.Suspend {
@ -418,7 +440,7 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
} }
@Suspendable @Suspendable
private fun ReceiveRequest<*>.suspendAndExpectReceive(): ReceivedSessionMessage<*> { private fun ReceiveRequest.suspendAndExpectReceive(): ReceivedSessionMessage {
val polledMessage = session.receivedMessages.poll() val polledMessage = session.receivedMessages.poll()
return if (polledMessage != null) { return if (polledMessage != null) {
if (this is SendAndReceive) { if (this is SendAndReceive) {
@ -431,35 +453,36 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
// Suspend while we wait for a receive // Suspend while we wait for a receive
suspend(this) suspend(this)
session.receivedMessages.poll() ?: session.receivedMessages.poll() ?:
throw IllegalStateException("Was expecting a ${receiveType.simpleName} but instead got nothing for $this") throw IllegalStateException("Was expecting a message but instead got nothing for $this")
} }
} }
private fun <M : ExistingSessionMessage> ReceivedSessionMessage<*>.confirmReceiveType( private fun ExistingSessionMessage.confirmNoError(session: FlowSessionInternal): ExistingSessionMessage {
receiveRequest: ReceiveRequest<M>): ReceivedSessionMessage<M> { when (payload) {
val session = receiveRequest.session is ConfirmSessionMessage,
val receiveType = receiveRequest.receiveType is DataSessionMessage -> {
if (receiveType.isInstance(message)) { return this
return uncheckedCast(this) }
} else if (message is SessionEnd) { is ErrorSessionMessage -> {
openSessions.values.remove(session) openSessions.values.remove(session)
if (message is ErrorSessionEnd) { session.erroredEnd(payload.flowException)
session.erroredEnd(message) }
} else { is RejectSessionMessage -> {
val expectedType = receiveRequest.userReceiveType?.name ?: receiveType.simpleName session.erroredEnd(UnexpectedFlowEndException("Counterparty sent session rejection message at unexpected time with message ${payload.message}"))
throw UnexpectedFlowEndException("Counterparty flow on ${session.state.sendToParty} has completed without " + }
"sending a $expectedType") EndSessionMessage -> {
openSessions.values.remove(session)
throw UnexpectedFlowEndException("Counterparty flow on ${session.state.sendToParty} has completed without " +
"sending data")
} }
} else {
throw IllegalStateException("Was expecting a ${receiveType.simpleName} but instead got $message for $receiveRequest")
} }
} }
private fun FlowSessionInternal.erroredEnd(end: ErrorSessionEnd): Nothing { private fun FlowSessionInternal.erroredEnd(exception: Throwable?): Nothing {
if (end.errorResponse != null) { if (exception != null) {
@Suppress("PLATFORM_CLASS_MAPPED_TO_KOTLIN") @Suppress("PLATFORM_CLASS_MAPPED_TO_KOTLIN")
(end.errorResponse as java.lang.Throwable).fillInStackTrace() exception.fillInStackTrace()
throw end.errorResponse throw exception
} else { } else {
throw UnexpectedFlowEndException("Counterparty flow on ${state.sendToParty} had an internal error and has terminated") throw UnexpectedFlowEndException("Counterparty flow on ${state.sendToParty} had an internal error and has terminated")
} }
@ -560,3 +583,15 @@ val Class<out FlowLogic<*>>.appName: String
"<unknown>" "<unknown>"
} }
} }
fun <T : Any> DataSessionMessage.checkPayloadIs(type: Class<T>): UntrustworthyData<T> {
val payloadData: T = try {
val serializer = SerializationDefaults.SERIALIZATION_FACTORY
serializer.deserialize(payload, type, SerializationDefaults.P2P_CONTEXT)
} catch (ex: Exception) {
throw IOException("Payload invalid", ex)
}
return type.castIfPossible(payloadData)?.let { UntrustworthyData(it) } ?:
throw UnexpectedFlowEndException("We were expecting a ${type.name} but we instead got a " +
"${payloadData.javaClass.name} (${payloadData})")
}

View File

@ -1,59 +1,122 @@
package net.corda.node.services.statemachine package net.corda.node.services.statemachine
import net.corda.core.crypto.random63BitValue
import net.corda.core.flows.FlowException import net.corda.core.flows.FlowException
import net.corda.core.flows.UnexpectedFlowEndException import net.corda.core.flows.FlowInfo
import net.corda.core.identity.Party
import net.corda.core.internal.castIfPossible
import net.corda.core.serialization.CordaSerializable import net.corda.core.serialization.CordaSerializable
import net.corda.core.serialization.SerializationDefaults
import net.corda.core.serialization.SerializedBytes import net.corda.core.serialization.SerializedBytes
import net.corda.core.utilities.UntrustworthyData import java.security.SecureRandom
import java.io.IOException
/**
* A session between two flows is identified by two session IDs, the initiating and the initiated session ID.
* However after the session has been established the communication is symmetric. From then on we differentiate between
* the two session IDs with "source" ID (the ID from which we receive) and "sink" ID (the ID to which we send).
*
* Flow A (initiating) Flow B (initiated)
* initiatingId=sourceId=0
* send(Initiate(initiatingId=0)) -----> initiatingId=sinkId=0
* initiatedId=sourceId=1
* initiatedId=sinkId=1 <----- send(Confirm(initiatedId=1))
*/
@CordaSerializable
sealed class SessionMessage
@CordaSerializable @CordaSerializable
interface SessionMessage data class SessionId(val toLong: Long) {
companion object {
interface ExistingSessionMessage : SessionMessage { fun createRandom(secureRandom: SecureRandom) = SessionId(secureRandom.nextLong())
val recipientSessionId: Long
}
interface SessionInitResponse : ExistingSessionMessage {
val initiatorSessionId: Long
override val recipientSessionId: Long get() = initiatorSessionId
}
interface SessionEnd : ExistingSessionMessage
data class SessionInit(val initiatorSessionId: Long,
val initiatingFlowClass: String,
val flowVersion: Int,
val appName: String,
val firstPayload: SerializedBytes<Any>?) : SessionMessage
data class SessionConfirm(override val initiatorSessionId: Long,
val initiatedSessionId: Long,
val flowVersion: Int,
val appName: String) : SessionInitResponse
data class SessionReject(override val initiatorSessionId: Long, val errorMessage: String) : SessionInitResponse
data class SessionData(override val recipientSessionId: Long, val payload: SerializedBytes<Any>) : ExistingSessionMessage
data class NormalSessionEnd(override val recipientSessionId: Long) : SessionEnd
data class ErrorSessionEnd(override val recipientSessionId: Long, val errorResponse: FlowException?) : SessionEnd
data class ReceivedSessionMessage<out M : ExistingSessionMessage>(val sender: Party, val message: M)
fun <T : Any> ReceivedSessionMessage<SessionData>.checkPayloadIs(type: Class<T>): UntrustworthyData<T> {
val payloadData: T = try {
val serializer = SerializationDefaults.SERIALIZATION_FACTORY
serializer.deserialize<T>(message.payload, type, SerializationDefaults.P2P_CONTEXT)
} catch (ex: Exception) {
throw IOException("Payload invalid", ex)
} }
return type.castIfPossible(payloadData)?.let { UntrustworthyData(it) } ?:
throw UnexpectedFlowEndException("We were expecting a ${type.name} from $sender but we instead got a " +
"${payloadData.javaClass.name} (${payloadData})")
} }
/**
* The initial message to initiate a session with.
*
* @param initiatorSessionId the session ID of the initiator. On the sending side this is the *source* ID, on the
* receiving side this is the *sink* ID.
* @param initiationEntropy additional randomness to seed the initiated flow's deduplication ID.
* @param initiatorFlowClassName the class name to be used to determine the initiating-initiated mapping on the receiver
* side.
* @param flowVersion the version of the initiating flow.
* @param appName the name of the cordapp defining the initiating flow, or "corda" if it's a core flow.
* @param firstPayload the optional first payload.
*/
data class InitialSessionMessage(
val initiatorSessionId: SessionId,
val initiationEntropy: Long,
val initiatorFlowClassName: String,
val flowVersion: Int,
val appName: String,
val firstPayload: SerializedBytes<Any>?
) : SessionMessage() {
override fun toString() = "InitialSessionMessage(" +
"initiatorSessionId=$initiatorSessionId, " +
"initiationEntropy=$initiationEntropy, " +
"initiatorFlowClassName=$initiatorFlowClassName, " +
"appName=$appName, " +
"firstPayload=${firstPayload?.javaClass}" +
")"
}
/**
* A message sent when a session has been established already.
*
* @param recipientSessionId the recipient session ID. On the sending side this is the *sink* ID, on the receiving side
* this is the *source* ID.
* @param payload the rest of the message.
*/
data class ExistingSessionMessage(
val recipientSessionId: SessionId,
val payload: ExistingSessionMessagePayload
) : SessionMessage()
/**
* The payload of an [ExistingSessionMessage]
*/
@CordaSerializable
sealed class ExistingSessionMessagePayload
/**
* The confirmation message sent by the initiated side.
* @param initiatedSessionId the initiated session ID, the other half of [InitialSessionMessage.initiatorSessionId].
* This is the *source* ID on the sending(initiated) side, and the *sink* ID on the receiving(initiating) side.
*/
data class ConfirmSessionMessage(
val initiatedSessionId: SessionId,
val initiatedFlowInfo: FlowInfo
) : ExistingSessionMessagePayload()
/**
* A message containing flow-related data.
*
* @param payload the serialised payload.
*/
data class DataSessionMessage(val payload: SerializedBytes<Any>) : ExistingSessionMessagePayload() {
override fun toString() = "DataSessionMessage(payload=${payload.javaClass})"
}
/**
* A message indicating that an error has happened.
*
* @param flowException the exception that happened. This is null if the error condition wasn't revealed to the
* receiving side.
* @param errorId the ID of the source error. This is always specified to allow posteriori correlation of error conditions.
*/
data class ErrorSessionMessage(val flowException: FlowException?, val errorId: Long) : ExistingSessionMessagePayload()
/**
* A message indicating that a session initiation has failed.
*
* @param message a message describing the problem to the initator.
* @param errorId an error ID identifying this error condition.
*/
data class RejectSessionMessage(val message: String, val errorId: Long) : ExistingSessionMessagePayload()
/**
* A message indicating that the flow hosting the session has ended. Note that this message is strictly part of the
* session protocol, the flow may be removed before all counter-flows have ended.
*
* The sole purpose of this message currently is to provide diagnostic in cases where the two communicating flows'
* protocols don't match up, e.g. one is waiting for the other, but the other side has already finished.
*/
object EndSessionMessage : ExistingSessionMessagePayload()

View File

@ -13,6 +13,7 @@ import net.corda.core.CordaException
import net.corda.core.concurrent.CordaFuture import net.corda.core.concurrent.CordaFuture
import net.corda.core.context.InvocationContext import net.corda.core.context.InvocationContext
import net.corda.core.crypto.SecureHash import net.corda.core.crypto.SecureHash
import net.corda.core.crypto.newSecureRandom
import net.corda.core.crypto.random63BitValue import net.corda.core.crypto.random63BitValue
import net.corda.core.flows.FlowException import net.corda.core.flows.FlowException
import net.corda.core.flows.FlowInfo import net.corda.core.flows.FlowInfo
@ -37,7 +38,6 @@ import net.corda.node.services.api.CheckpointStorage
import net.corda.node.services.api.ServiceHubInternal import net.corda.node.services.api.ServiceHubInternal
import net.corda.node.services.config.shouldCheckCheckpoints import net.corda.node.services.config.shouldCheckCheckpoints
import net.corda.node.services.messaging.ReceivedMessage import net.corda.node.services.messaging.ReceivedMessage
import net.corda.node.services.messaging.TopicSession
import net.corda.node.utilities.AffinityExecutor import net.corda.node.utilities.AffinityExecutor
import net.corda.node.utilities.newNamedSingleThreadExecutor import net.corda.node.utilities.newNamedSingleThreadExecutor
import net.corda.nodeapi.internal.persistence.CordaPersistence import net.corda.nodeapi.internal.persistence.CordaPersistence
@ -72,7 +72,7 @@ class StateMachineManagerImpl(
companion object { companion object {
private val logger = contextLogger() private val logger = contextLogger()
internal val sessionTopic = TopicSession("platform.session") internal val sessionTopic = "platform.session"
init { init {
Fiber.setDefaultUncaughtExceptionHandler { fiber, throwable -> Fiber.setDefaultUncaughtExceptionHandler { fiber, throwable ->
@ -121,8 +121,8 @@ class StateMachineManagerImpl(
private val totalStartedFlows = metrics.counter("Flows.Started") private val totalStartedFlows = metrics.counter("Flows.Started")
private val totalFinishedFlows = metrics.counter("Flows.Finished") private val totalFinishedFlows = metrics.counter("Flows.Finished")
private val openSessions = ConcurrentHashMap<Long, FlowSessionInternal>() private val openSessions = ConcurrentHashMap<SessionId, FlowSessionInternal>()
private val recentlyClosedSessions = ConcurrentHashMap<Long, Party>() private val recentlyClosedSessions = ConcurrentHashMap<SessionId, Party>()
// Context for tokenized services in checkpoints // Context for tokenized services in checkpoints
private lateinit var tokenizableServices: List<Any> private lateinit var tokenizableServices: List<Any>
@ -281,7 +281,7 @@ class StateMachineManagerImpl(
if (sender != null) { if (sender != null) {
when (sessionMessage) { when (sessionMessage) {
is ExistingSessionMessage -> onExistingSessionMessage(sessionMessage, sender) is ExistingSessionMessage -> onExistingSessionMessage(sessionMessage, sender)
is SessionInit -> onSessionInit(sessionMessage, message, sender) is InitialSessionMessage -> onSessionInit(sessionMessage, message, sender)
} }
} else { } else {
logger.error("Unknown peer $peer in $sessionMessage") logger.error("Unknown peer $peer in $sessionMessage")
@ -294,15 +294,15 @@ class StateMachineManagerImpl(
session.fiber.pushToLoggingContext() session.fiber.pushToLoggingContext()
session.fiber.logger.trace { "Received $message on $session from $sender" } session.fiber.logger.trace { "Received $message on $session from $sender" }
if (session.retryable) { if (session.retryable) {
if (message is SessionConfirm && session.state is FlowSessionState.Initiated) { if (message.payload is ConfirmSessionMessage && session.state is FlowSessionState.Initiated) {
session.fiber.logger.trace { "Ignoring duplicate confirmation for session ${session.ourSessionId} session is idempotent" } session.fiber.logger.trace { "Ignoring duplicate confirmation for session ${session.ourSessionId} session is idempotent" }
return return
} }
if (message !is SessionConfirm) { if (message.payload !is ConfirmSessionMessage) {
serviceHub.networkService.cancelRedelivery(session.ourSessionId) serviceHub.networkService.cancelRedelivery(session.ourSessionId.toLong)
} }
} }
if (message is SessionEnd) { if (message.payload is EndSessionMessage || message.payload is ErrorSessionMessage) {
openSessions.remove(message.recipientSessionId) openSessions.remove(message.recipientSessionId)
} }
session.receivedMessages += ReceivedSessionMessage(sender, message) session.receivedMessages += ReceivedSessionMessage(sender, message)
@ -317,9 +317,9 @@ class StateMachineManagerImpl(
} else { } else {
val peerParty = recentlyClosedSessions.remove(message.recipientSessionId) val peerParty = recentlyClosedSessions.remove(message.recipientSessionId)
if (peerParty != null) { if (peerParty != null) {
if (message is SessionConfirm) { if (message.payload is ConfirmSessionMessage) {
logger.trace { "Received session confirmation but associated fiber has already terminated, so sending session end" } logger.trace { "Received session confirmation but associated fiber has already terminated, so sending session end" }
sendSessionMessage(peerParty, NormalSessionEnd(message.initiatedSessionId)) sendSessionMessage(peerParty, ExistingSessionMessage(message.payload.initiatedSessionId, EndSessionMessage))
} else { } else {
logger.trace { "Ignoring session end message for already closed session: $message" } logger.trace { "Ignoring session end message for already closed session: $message" }
} }
@ -336,12 +336,12 @@ class StateMachineManagerImpl(
return waitingForResponse?.shouldResume(message, session) ?: false return waitingForResponse?.shouldResume(message, session) ?: false
} }
private fun onSessionInit(sessionInit: SessionInit, receivedMessage: ReceivedMessage, sender: Party) { private fun onSessionInit(sessionInit: InitialSessionMessage, receivedMessage: ReceivedMessage, sender: Party) {
logger.trace { "Received $sessionInit from $sender" } logger.trace { "Received $sessionInit from $sender" }
val senderSessionId = sessionInit.initiatorSessionId val senderSessionId = sessionInit.initiatorSessionId
fun sendSessionReject(message: String) = sendSessionMessage(sender, SessionReject(senderSessionId, message)) fun sendSessionReject(message: String) = sendSessionMessage(sender, ExistingSessionMessage(senderSessionId, RejectSessionMessage(message, random63BitValue())))
val (session, initiatedFlowFactory) = try { val (session, initiatedFlowFactory) = try {
val initiatedFlowFactory = getInitiatedFlowFactory(sessionInit) val initiatedFlowFactory = getInitiatedFlowFactory(sessionInit)
@ -354,11 +354,11 @@ class StateMachineManagerImpl(
val session = FlowSessionInternal( val session = FlowSessionInternal(
flow, flow,
flowSession, flowSession,
random63BitValue(), SessionId.createRandom(newSecureRandom()),
sender, sender,
FlowSessionState.Initiated(sender, senderSessionId, FlowInfo(senderFlowVersion, sessionInit.appName))) FlowSessionState.Initiated(sender, senderSessionId, FlowInfo(senderFlowVersion, sessionInit.appName)))
if (sessionInit.firstPayload != null) { if (sessionInit.firstPayload != null) {
session.receivedMessages += ReceivedSessionMessage(sender, SessionData(session.ourSessionId, sessionInit.firstPayload)) session.receivedMessages += ReceivedSessionMessage(sender, ExistingSessionMessage(session.ourSessionId, DataSessionMessage(sessionInit.firstPayload)))
} }
openSessions[session.ourSessionId] = session openSessions[session.ourSessionId] = session
val context = InvocationContext.peer(sender.name) val context = InvocationContext.peer(sender.name)
@ -386,19 +386,19 @@ class StateMachineManagerImpl(
is InitiatedFlowFactory.CorDapp -> initiatedFlowFactory.flowVersion to initiatedFlowFactory.appName is InitiatedFlowFactory.CorDapp -> initiatedFlowFactory.flowVersion to initiatedFlowFactory.appName
} }
sendSessionMessage(sender, SessionConfirm(senderSessionId, session.ourSessionId, ourFlowVersion, appName), session.fiber) sendSessionMessage(sender, ExistingSessionMessage(senderSessionId, ConfirmSessionMessage(session.ourSessionId, FlowInfo(ourFlowVersion, appName))), session.fiber)
session.fiber.logger.debug { "Initiated by $sender using ${sessionInit.initiatingFlowClass}" } session.fiber.logger.debug { "Initiated by $sender using ${sessionInit.initiatorFlowClassName}" }
session.fiber.logger.trace { "Initiated from $sessionInit on $session" } session.fiber.logger.trace { "Initiated from $sessionInit on $session" }
resumeFiber(session.fiber) resumeFiber(session.fiber)
} }
private fun getInitiatedFlowFactory(sessionInit: SessionInit): InitiatedFlowFactory<*> { private fun getInitiatedFlowFactory(sessionInit: InitialSessionMessage): InitiatedFlowFactory<*> {
val initiatingFlowClass = try { val initiatingFlowClass = try {
Class.forName(sessionInit.initiatingFlowClass, true, classloader).asSubclass(FlowLogic::class.java) Class.forName(sessionInit.initiatorFlowClassName, true, classloader).asSubclass(FlowLogic::class.java)
} catch (e: ClassNotFoundException) { } catch (e: ClassNotFoundException) {
throw SessionRejectException("Don't know ${sessionInit.initiatingFlowClass}") throw SessionRejectException("Don't know ${sessionInit.initiatorFlowClassName}")
} catch (e: ClassCastException) { } catch (e: ClassCastException) {
throw SessionRejectException("${sessionInit.initiatingFlowClass} is not a flow") throw SessionRejectException("${sessionInit.initiatorFlowClassName} is not a flow")
} }
return serviceHub.getFlowFactory(initiatingFlowClass) ?: return serviceHub.getFlowFactory(initiatingFlowClass) ?:
throw SessionRejectException("$initiatingFlowClass is not registered") throw SessionRejectException("$initiatingFlowClass is not registered")
@ -492,7 +492,7 @@ class StateMachineManagerImpl(
private fun FlowSessionInternal.endSession(context: InvocationContext, exception: Throwable?, propagated: Boolean) { private fun FlowSessionInternal.endSession(context: InvocationContext, exception: Throwable?, propagated: Boolean) {
val initiatedState = state as? FlowSessionState.Initiated ?: return val initiatedState = state as? FlowSessionState.Initiated ?: return
val sessionEnd = if (exception == null) { val sessionEnd = if (exception == null) {
NormalSessionEnd(initiatedState.peerSessionId) EndSessionMessage
} else { } else {
val errorResponse = if (exception is FlowException && (!propagated || initiatingParty != null)) { val errorResponse = if (exception is FlowException && (!propagated || initiatingParty != null)) {
// Only propagate this FlowException if our local flow threw it or it was propagated to us and we only // Only propagate this FlowException if our local flow threw it or it was propagated to us and we only
@ -501,9 +501,9 @@ class StateMachineManagerImpl(
} else { } else {
null null
} }
ErrorSessionEnd(initiatedState.peerSessionId, errorResponse) ErrorSessionMessage(errorResponse, 0)
} }
sendSessionMessage(initiatedState.peerParty, sessionEnd, fiber) sendSessionMessage(initiatedState.peerParty, ExistingSessionMessage(initiatedState.peerSessionId, sessionEnd), fiber)
recentlyClosedSessions[ourSessionId] = initiatedState.peerParty recentlyClosedSessions[ourSessionId] = initiatedState.peerParty
} }
@ -573,14 +573,14 @@ class StateMachineManagerImpl(
} }
private fun processSendRequest(ioRequest: SendRequest) { private fun processSendRequest(ioRequest: SendRequest) {
val retryId = if (ioRequest.message is SessionInit) { val retryId = if (ioRequest.message is InitialSessionMessage) {
with(ioRequest.session) { with(ioRequest.session) {
openSessions[ourSessionId] = this openSessions[ourSessionId] = this
if (retryable) ourSessionId else null if (retryable) ourSessionId.toLong else null
} }
} else null } else null
sendSessionMessage(ioRequest.session.state.sendToParty, ioRequest.message, ioRequest.session.fiber, retryId) sendSessionMessage(ioRequest.session.state.sendToParty, ioRequest.message, ioRequest.session.fiber, retryId)
if (ioRequest !is ReceiveRequest<*>) { if (ioRequest !is ReceiveRequest) {
// We sent a message, but don't expect a response, so re-enter the continuation to let it keep going. // We sent a message, but don't expect a response, so re-enter the continuation to let it keep going.
resumeFiber(ioRequest.session.fiber) resumeFiber(ioRequest.session.fiber)
} }
@ -625,12 +625,15 @@ class StateMachineManagerImpl(
// Handling Kryo and AMQP serialization problems. Unfortunately the two exception types do not share much of a common exception interface. // Handling Kryo and AMQP serialization problems. Unfortunately the two exception types do not share much of a common exception interface.
is KryoException, is KryoException,
is NotSerializableException -> { is NotSerializableException -> {
if (message !is ErrorSessionEnd || message.errorResponse == null) throw e if (message is ExistingSessionMessage && message.payload is ErrorSessionMessage && message.payload.flowException != null) {
logger.warn("Something in ${message.errorResponse.javaClass.name} is not serialisable. " + logger.warn("Something in ${message.payload.flowException.javaClass.name} is not serialisable. " +
"Instead sending back an exception which is serialisable to ensure session end occurs properly.", e) "Instead sending back an exception which is serialisable to ensure session end occurs properly.", e)
// The subclass may have overridden toString so we use that // The subclass may have overridden toString so we use that
val exMessage = message.errorResponse.let { if (it.javaClass != FlowException::class.java) it.toString() else it.message } val exMessage = message.payload.flowException.message
message.copy(errorResponse = FlowException(exMessage)).serialize() message.copy(payload = message.payload.copy(flowException = FlowException(exMessage))).serialize()
} else {
throw e
}
} }
else -> throw e else -> throw e
} }

View File

@ -3,7 +3,6 @@ package net.corda.node.messaging
import net.corda.core.messaging.AllPossibleRecipients import net.corda.core.messaging.AllPossibleRecipients
import net.corda.node.services.messaging.Message import net.corda.node.services.messaging.Message
import net.corda.node.services.messaging.TopicStringValidator import net.corda.node.services.messaging.TopicStringValidator
import net.corda.node.services.messaging.createMessage
import net.corda.testing.internal.rigorousMock import net.corda.testing.internal.rigorousMock
import net.corda.testing.node.MockNetwork import net.corda.testing.node.MockNetwork
import org.apache.commons.lang.ArrayUtils.EMPTY_BYTE_ARRAY import org.apache.commons.lang.ArrayUtils.EMPTY_BYTE_ARRAY
@ -51,10 +50,10 @@ class InMemoryMessagingTests {
val bits = "test-content".toByteArray() val bits = "test-content".toByteArray()
var finalDelivery: Message? = null var finalDelivery: Message? = null
node2.network.addMessageHandler { msg, _ -> node2.network.addMessageHandler("test.topic") { msg, _ ->
node2.network.send(msg, node3.network.myAddress) node2.network.send(msg, node3.network.myAddress)
} }
node3.network.addMessageHandler { msg, _ -> node3.network.addMessageHandler("test.topic") { msg, _ ->
finalDelivery = msg finalDelivery = msg
} }
@ -63,7 +62,7 @@ class InMemoryMessagingTests {
mockNet.runNetwork(rounds = 1) mockNet.runNetwork(rounds = 1)
assertTrue(Arrays.equals(finalDelivery!!.data, bits)) assertTrue(Arrays.equals(finalDelivery!!.data.bytes, bits))
} }
@Test @Test
@ -75,7 +74,7 @@ class InMemoryMessagingTests {
val bits = "test-content".toByteArray() val bits = "test-content".toByteArray()
var counter = 0 var counter = 0
listOf(node1, node2, node3).forEach { it.network.addMessageHandler { _, _ -> counter++ } } listOf(node1, node2, node3).forEach { it.network.addMessageHandler("test.topic") { _, _ -> counter++ } }
node1.network.send(node2.network.createMessage("test.topic", data = bits), rigorousMock<AllPossibleRecipients>()) node1.network.send(node2.network.createMessage("test.topic", data = bits), rigorousMock<AllPossibleRecipients>())
mockNet.runNetwork(rounds = 1) mockNet.runNetwork(rounds = 1)
assertEquals(3, counter) assertEquals(3, counter)

View File

@ -126,7 +126,7 @@ class ArtemisMessagingTest {
messagingClient.send(message, messagingClient.myAddress) messagingClient.send(message, messagingClient.myAddress)
val actual: Message = receivedMessages.take() val actual: Message = receivedMessages.take()
assertEquals("first msg", String(actual.data)) assertEquals("first msg", String(actual.data.bytes))
assertNull(receivedMessages.poll(200, MILLISECONDS)) assertNull(receivedMessages.poll(200, MILLISECONDS))
} }

View File

@ -37,8 +37,8 @@ import net.corda.testing.node.InMemoryMessagingNetwork.ServicePeerAllocationStra
import net.corda.testing.node.MockNetwork import net.corda.testing.node.MockNetwork
import net.corda.testing.node.MockNetwork.MockNode import net.corda.testing.node.MockNetwork.MockNode
import net.corda.testing.node.MockNodeParameters import net.corda.testing.node.MockNodeParameters
import net.corda.testing.node.pumpReceive
import net.corda.testing.node.internal.startFlow import net.corda.testing.node.internal.startFlow
import net.corda.testing.node.pumpReceive
import org.assertj.core.api.Assertions.assertThat import org.assertj.core.api.Assertions.assertThat
import org.assertj.core.api.Assertions.assertThatThrownBy import org.assertj.core.api.Assertions.assertThatThrownBy
import org.assertj.core.api.AssertionsForClassTypes.assertThatExceptionOfType import org.assertj.core.api.AssertionsForClassTypes.assertThatExceptionOfType
@ -296,14 +296,16 @@ class FlowFrameworkTests {
mockNet.runNetwork() mockNet.runNetwork()
assertThatExceptionOfType(UnexpectedFlowEndException::class.java).isThrownBy { assertThatExceptionOfType(UnexpectedFlowEndException::class.java).isThrownBy {
resultFuture.getOrThrow() resultFuture.getOrThrow()
}.withMessageContaining(String::class.java.name) // Make sure the exception message mentions the type the flow was expecting to receive }
} }
@Test @Test
fun `receiving unexpected session end before entering sendAndReceive`() { fun `receiving unexpected session end before entering sendAndReceive`() {
bobNode.registerFlowFactory(WaitForOtherSideEndBeforeSendAndReceive::class) { NoOpFlow() } bobNode.registerFlowFactory(WaitForOtherSideEndBeforeSendAndReceive::class) { NoOpFlow() }
val sessionEndReceived = Semaphore(0) val sessionEndReceived = Semaphore(0)
receivedSessionMessagesObservable().filter { it.message is SessionEnd }.subscribe { sessionEndReceived.release() } receivedSessionMessagesObservable().filter {
it.message is ExistingSessionMessage && it.message.payload is EndSessionMessage
}.subscribe { sessionEndReceived.release() }
val resultFuture = aliceNode.services.startFlow( val resultFuture = aliceNode.services.startFlow(
WaitForOtherSideEndBeforeSendAndReceive(bob, sessionEndReceived)).resultFuture WaitForOtherSideEndBeforeSendAndReceive(bob, sessionEndReceived)).resultFuture
mockNet.runNetwork() mockNet.runNetwork()
@ -356,7 +358,7 @@ class FlowFrameworkTests {
assertSessionTransfers( assertSessionTransfers(
aliceNode sent sessionInit(ReceiveFlow::class) to bobNode, aliceNode sent sessionInit(ReceiveFlow::class) to bobNode,
bobNode sent sessionConfirm() to aliceNode, bobNode sent sessionConfirm() to aliceNode,
bobNode sent erroredEnd() to aliceNode bobNode sent errorMessage() to aliceNode
) )
} }
@ -389,10 +391,11 @@ class FlowFrameworkTests {
assertSessionTransfers( assertSessionTransfers(
aliceNode sent sessionInit(ReceiveFlow::class) to bobNode, aliceNode sent sessionInit(ReceiveFlow::class) to bobNode,
bobNode sent sessionConfirm() to aliceNode, bobNode sent sessionConfirm() to aliceNode,
bobNode sent erroredEnd(erroringFlow.get().exceptionThrown) to aliceNode bobNode sent errorMessage(erroringFlow.get().exceptionThrown) to aliceNode
) )
// Make sure the original stack trace isn't sent down the wire // Make sure the original stack trace isn't sent down the wire
assertThat((receivedSessionMessages.last().message as ErrorSessionEnd).errorResponse!!.stackTrace).isEmpty() val lastMessage = receivedSessionMessages.last().message as ExistingSessionMessage
assertThat((lastMessage.payload as ErrorSessionMessage).flowException!!.stackTrace).isEmpty()
} }
@Test @Test
@ -438,7 +441,7 @@ class FlowFrameworkTests {
aliceNode sent sessionInit(ReceiveFlow::class) to bobNode, aliceNode sent sessionInit(ReceiveFlow::class) to bobNode,
bobNode sent sessionConfirm() to aliceNode, bobNode sent sessionConfirm() to aliceNode,
bobNode sent sessionData("Hello") to aliceNode, bobNode sent sessionData("Hello") to aliceNode,
aliceNode sent erroredEnd() to bobNode aliceNode sent errorMessage() to bobNode
) )
} }
@ -603,20 +606,20 @@ class FlowFrameworkTests {
@Test @Test
fun `unknown class in session init`() { fun `unknown class in session init`() {
aliceNode.sendSessionMessage(SessionInit(random63BitValue(), "not.a.real.Class", 1, "version", null), bob) aliceNode.sendSessionMessage(InitialSessionMessage(SessionId(random63BitValue()), 0, "not.a.real.Class", 1, "", null), bob)
mockNet.runNetwork() mockNet.runNetwork()
assertThat(receivedSessionMessages).hasSize(2) // Only the session-init and session-reject are expected assertThat(receivedSessionMessages).hasSize(2) // Only the session-init and session-reject are expected
val reject = receivedSessionMessages.last().message as SessionReject val lastMessage = receivedSessionMessages.last().message as ExistingSessionMessage
assertThat(reject.errorMessage).isEqualTo("Don't know not.a.real.Class") assertThat((lastMessage.payload as RejectSessionMessage).message).isEqualTo("Don't know not.a.real.Class")
} }
@Test @Test
fun `non-flow class in session init`() { fun `non-flow class in session init`() {
aliceNode.sendSessionMessage(SessionInit(random63BitValue(), String::class.java.name, 1, "version", null), bob) aliceNode.sendSessionMessage(InitialSessionMessage(SessionId(random63BitValue()), 0, String::class.java.name, 1, "", null), bob)
mockNet.runNetwork() mockNet.runNetwork()
assertThat(receivedSessionMessages).hasSize(2) // Only the session-init and session-reject are expected assertThat(receivedSessionMessages).hasSize(2) // Only the session-init and session-reject are expected
val reject = receivedSessionMessages.last().message as SessionReject val lastMessage = receivedSessionMessages.last().message as ExistingSessionMessage
assertThat(reject.errorMessage).isEqualTo("${String::class.java.name} is not a flow") assertThat((lastMessage.payload as RejectSessionMessage).message).isEqualTo("${String::class.java.name} is not a flow")
} }
@Test @Test
@ -682,14 +685,14 @@ class FlowFrameworkTests {
return observable.toFuture() return observable.toFuture()
} }
private fun sessionInit(clientFlowClass: KClass<out FlowLogic<*>>, flowVersion: Int = 1, payload: Any? = null): SessionInit { private fun sessionInit(clientFlowClass: KClass<out FlowLogic<*>>, flowVersion: Int = 1, payload: Any? = null): InitialSessionMessage {
return SessionInit(0, clientFlowClass.java.name, flowVersion, "", payload?.serialize()) return InitialSessionMessage(SessionId(0), 0, clientFlowClass.java.name, flowVersion, "", payload?.serialize())
} }
private fun sessionConfirm(flowVersion: Int = 1) = SessionConfirm(0, 0, flowVersion, "") private fun sessionConfirm(flowVersion: Int = 1) = ExistingSessionMessage(SessionId(0), ConfirmSessionMessage(SessionId(0), FlowInfo(flowVersion, "")))
private fun sessionData(payload: Any) = SessionData(0, payload.serialize()) private fun sessionData(payload: Any) = ExistingSessionMessage(SessionId(0), DataSessionMessage(payload.serialize()))
private val normalEnd = NormalSessionEnd(0) private val normalEnd = ExistingSessionMessage(SessionId(0), EndSessionMessage) // NormalSessionEnd(0)
private fun erroredEnd(errorResponse: FlowException? = null) = ErrorSessionEnd(0, errorResponse) private fun errorMessage(errorResponse: FlowException? = null) = ExistingSessionMessage(SessionId(0), ErrorSessionMessage(errorResponse, 0))
private fun StartedNode<*>.sendSessionMessage(message: SessionMessage, destination: Party) { private fun StartedNode<*>.sendSessionMessage(message: SessionMessage, destination: Party) {
services.networkService.apply { services.networkService.apply {
@ -709,7 +712,9 @@ class FlowFrameworkTests {
} }
private data class SessionTransfer(val from: Int, val message: SessionMessage, val to: MessageRecipients) { private data class SessionTransfer(val from: Int, val message: SessionMessage, val to: MessageRecipients) {
val isPayloadTransfer: Boolean get() = message is SessionData || message is SessionInit && message.firstPayload != null val isPayloadTransfer: Boolean get() =
message is ExistingSessionMessage && message.payload is DataSessionMessage ||
message is InitialSessionMessage && message.firstPayload != null
override fun toString(): String = "$from sent $message to $to" override fun toString(): String = "$from sent $message to $to"
} }
@ -718,7 +723,7 @@ class FlowFrameworkTests {
} }
private fun Observable<MessageTransfer>.toSessionTransfers(): Observable<SessionTransfer> { private fun Observable<MessageTransfer>.toSessionTransfers(): Observable<SessionTransfer> {
return filter { it.message.topicSession == StateMachineManagerImpl.sessionTopic }.map { return filter { it.message.topic == StateMachineManagerImpl.sessionTopic }.map {
val from = it.sender.id val from = it.sender.id
val message = it.message.data.deserialize<SessionMessage>() val message = it.message.data.deserialize<SessionMessage>()
SessionTransfer(from, sanitise(message), it.recipients) SessionTransfer(from, sanitise(message), it.recipients)
@ -726,12 +731,23 @@ class FlowFrameworkTests {
} }
private fun sanitise(message: SessionMessage) = when (message) { private fun sanitise(message: SessionMessage) = when (message) {
is SessionData -> message.copy(recipientSessionId = 0) is InitialSessionMessage -> message.copy(initiatorSessionId = SessionId(0), initiationEntropy = 0, appName = "")
is SessionInit -> message.copy(initiatorSessionId = 0, appName = "") is ExistingSessionMessage -> {
is SessionConfirm -> message.copy(initiatorSessionId = 0, initiatedSessionId = 0, appName = "") val payload = message.payload
is NormalSessionEnd -> message.copy(recipientSessionId = 0) message.copy(
is ErrorSessionEnd -> message.copy(recipientSessionId = 0) recipientSessionId = SessionId(0),
else -> message payload = when (payload) {
is ConfirmSessionMessage -> payload.copy(
initiatedSessionId = SessionId(0),
initiatedFlowInfo = payload.initiatedFlowInfo.copy(appName = "")
)
is ErrorSessionMessage -> payload.copy(
errorId = 0
)
else -> payload
}
)
}
} }
private infix fun StartedNode<MockNode>.sent(message: SessionMessage): Pair<Int, SessionMessage> = Pair(internals.id, message) private infix fun StartedNode<MockNode>.sent(message: SessionMessage): Pair<Int, SessionMessage> = Pair(internals.id, message)

View File

@ -15,9 +15,7 @@ import net.corda.core.serialization.deserialize
import net.corda.core.utilities.ProgressTracker import net.corda.core.utilities.ProgressTracker
import net.corda.netmap.VisualiserViewModel.Style import net.corda.netmap.VisualiserViewModel.Style
import net.corda.netmap.simulation.IRSSimulation import net.corda.netmap.simulation.IRSSimulation
import net.corda.node.services.statemachine.SessionConfirm import net.corda.node.services.statemachine.*
import net.corda.node.services.statemachine.SessionEnd
import net.corda.node.services.statemachine.SessionInit
import net.corda.testing.core.chooseIdentity import net.corda.testing.core.chooseIdentity
import net.corda.testing.node.InMemoryMessagingNetwork import net.corda.testing.node.InMemoryMessagingNetwork
import net.corda.testing.node.MockNetwork import net.corda.testing.node.MockNetwork
@ -342,12 +340,16 @@ class NetworkMapVisualiser : Application() {
private fun transferIsInteresting(transfer: InMemoryMessagingNetwork.MessageTransfer): Boolean { private fun transferIsInteresting(transfer: InMemoryMessagingNetwork.MessageTransfer): Boolean {
// Loopback messages are boring. // Loopback messages are boring.
if (transfer.sender == transfer.recipients) return false if (transfer.sender == transfer.recipients) return false
val message = transfer.message.data.deserialize<Any>() val message = transfer.message.data.deserialize<SessionMessage>()
return when (message) { return when (message) {
is SessionEnd -> false is InitialSessionMessage -> message.firstPayload != null
is SessionConfirm -> false is ExistingSessionMessage -> when (message.payload) {
is SessionInit -> message.firstPayload != null is ConfirmSessionMessage -> false
else -> true is DataSessionMessage -> true
is ErrorSessionMessage -> true
is RejectSessionMessage -> true
is EndSessionMessage -> false
}
} }
} }
} }

View File

@ -14,11 +14,17 @@ import net.corda.core.messaging.SingleMessageRecipient
import net.corda.core.node.services.PartyInfo import net.corda.core.node.services.PartyInfo
import net.corda.core.serialization.CordaSerializable import net.corda.core.serialization.CordaSerializable
import net.corda.core.serialization.SingletonSerializeAsToken import net.corda.core.serialization.SingletonSerializeAsToken
import net.corda.core.utilities.ByteSequence
import net.corda.core.utilities.OpaqueBytes
import net.corda.core.utilities.getOrThrow import net.corda.core.utilities.getOrThrow
import net.corda.core.utilities.trace import net.corda.core.utilities.trace
import net.corda.node.services.messaging.* import net.corda.node.services.messaging.Message
import net.corda.node.services.messaging.MessageHandlerRegistration
import net.corda.node.services.messaging.MessagingService
import net.corda.node.services.messaging.ReceivedMessage
import net.corda.node.utilities.AffinityExecutor import net.corda.node.utilities.AffinityExecutor
import net.corda.nodeapi.internal.persistence.CordaPersistence import net.corda.nodeapi.internal.persistence.CordaPersistence
import net.corda.testing.node.InMemoryMessagingNetwork.TestMessagingService
import org.apache.activemq.artemis.utils.ReusableLatch import org.apache.activemq.artemis.utils.ReusableLatch
import org.slf4j.LoggerFactory import org.slf4j.LoggerFactory
import rx.Observable import rx.Observable
@ -57,7 +63,7 @@ class InMemoryMessagingNetwork internal constructor(
@CordaSerializable @CordaSerializable
data class MessageTransfer(val sender: PeerHandle, val message: Message, val recipients: MessageRecipients) { data class MessageTransfer(val sender: PeerHandle, val message: Message, val recipients: MessageRecipients) {
override fun toString() = "${message.topicSession} from '$sender' to '$recipients'" override fun toString() = "${message.topic} from '$sender' to '$recipients'"
} }
// All sent messages are kept here until pumpSend is called, or manuallyPumped is set to false // All sent messages are kept here until pumpSend is called, or manuallyPumped is set to false
@ -241,17 +247,17 @@ class InMemoryMessagingNetwork internal constructor(
_sentMessages.onNext(transfer) _sentMessages.onNext(transfer)
} }
data class InMemoryMessage(override val topicSession: TopicSession, data class InMemoryMessage(override val topic: String,
override val data: ByteArray, override val data: ByteSequence,
override val uniqueMessageId: UUID, override val uniqueMessageId: String,
override val debugTimestamp: Instant = Instant.now()) : Message { override val debugTimestamp: Instant = Instant.now()) : Message {
override fun toString() = "$topicSession#${String(data)}" override fun toString() = "$topic#${String(data.bytes)}"
} }
private data class InMemoryReceivedMessage(override val topicSession: TopicSession, private data class InMemoryReceivedMessage(override val topic: String,
override val data: ByteArray, override val data: ByteSequence,
override val platformVersion: Int, override val platformVersion: Int,
override val uniqueMessageId: UUID, override val uniqueMessageId: String,
override val debugTimestamp: Instant, override val debugTimestamp: Instant,
override val peer: CordaX500Name) : ReceivedMessage override val peer: CordaX500Name) : ReceivedMessage
@ -270,8 +276,7 @@ class InMemoryMessagingNetwork internal constructor(
private val peerHandle: PeerHandle, private val peerHandle: PeerHandle,
private val executor: AffinityExecutor, private val executor: AffinityExecutor,
private val database: CordaPersistence) : SingletonSerializeAsToken(), TestMessagingService { private val database: CordaPersistence) : SingletonSerializeAsToken(), TestMessagingService {
private inner class Handler(val topicSession: TopicSession, inner class Handler(val topicSession: String, val callback: (ReceivedMessage, MessageHandlerRegistration) -> Unit) : MessageHandlerRegistration
val callback: (ReceivedMessage, MessageHandlerRegistration) -> Unit) : MessageHandlerRegistration
@Volatile @Volatile
private var running = true private var running = true
@ -282,7 +287,7 @@ class InMemoryMessagingNetwork internal constructor(
} }
private val state = ThreadBox(InnerState()) private val state = ThreadBox(InnerState())
private val processedMessages: MutableSet<UUID> = Collections.synchronizedSet(HashSet<UUID>()) private val processedMessages: MutableSet<String> = Collections.synchronizedSet(HashSet<String>())
override val myAddress: PeerHandle get() = peerHandle override val myAddress: PeerHandle get() = peerHandle
@ -304,13 +309,10 @@ class InMemoryMessagingNetwork internal constructor(
} }
} }
override fun addMessageHandler(topic: String, sessionID: Long, callback: (ReceivedMessage, MessageHandlerRegistration) -> Unit): MessageHandlerRegistration override fun addMessageHandler(topic: String, callback: (ReceivedMessage, MessageHandlerRegistration) -> Unit): MessageHandlerRegistration {
= addMessageHandler(TopicSession(topic, sessionID), callback)
override fun addMessageHandler(topicSession: TopicSession, callback: (ReceivedMessage, MessageHandlerRegistration) -> Unit): MessageHandlerRegistration {
check(running) check(running)
val (handler, transfers) = state.locked { val (handler, transfers) = state.locked {
val handler = Handler(topicSession, callback).apply { handlers.add(this) } val handler = Handler(topic, callback).apply { handlers.add(this) }
val pending = ArrayList<MessageTransfer>() val pending = ArrayList<MessageTransfer>()
database.transaction { database.transaction {
pending.addAll(pendingRedelivery) pending.addAll(pendingRedelivery)
@ -328,20 +330,18 @@ class InMemoryMessagingNetwork internal constructor(
state.locked { check(handlers.remove(registration as Handler)) } state.locked { check(handlers.remove(registration as Handler)) }
} }
override fun send(message: Message, target: MessageRecipients, retryId: Long?, sequenceKey: Any, acknowledgementHandler: (() -> Unit)?) { override fun send(message: Message, target: MessageRecipients, retryId: Long?, sequenceKey: Any) {
check(running) check(running)
msgSend(this, message, target) msgSend(this, message, target)
acknowledgementHandler?.invoke()
if (!sendManuallyPumped) { if (!sendManuallyPumped) {
pumpSend(false) pumpSend(false)
} }
} }
override fun send(addressedMessages: List<MessagingService.AddressedMessage>, acknowledgementHandler: (() -> Unit)?) { override fun send(addressedMessages: List<MessagingService.AddressedMessage>) {
for ((message, target, retryId, sequenceKey) in addressedMessages) { for ((message, target, retryId, sequenceKey) in addressedMessages) {
send(message, target, retryId, sequenceKey, null) send(message, target, retryId, sequenceKey)
} }
acknowledgementHandler?.invoke()
} }
override fun stop() { override fun stop() {
@ -356,8 +356,8 @@ class InMemoryMessagingNetwork internal constructor(
override fun cancelRedelivery(retryId: Long) {} override fun cancelRedelivery(retryId: Long) {}
/** Returns the given (topic & session, data) pair as a newly created message object. */ /** Returns the given (topic & session, data) pair as a newly created message object. */
override fun createMessage(topicSession: TopicSession, data: ByteArray, uuid: UUID): Message { override fun createMessage(topic: String, data: ByteArray, deduplicationId: String): Message {
return InMemoryMessage(topicSession, data, uuid) return InMemoryMessage(topic, OpaqueBytes(data), deduplicationId)
} }
/** /**
@ -390,14 +390,14 @@ class InMemoryMessagingNetwork internal constructor(
while (deliverTo == null) { while (deliverTo == null) {
val transfer = (if (block) q.take() else q.poll()) ?: return null val transfer = (if (block) q.take() else q.poll()) ?: return null
deliverTo = state.locked { deliverTo = state.locked {
val matchingHandlers = handlers.filter { it.topicSession.isBlank() || transfer.message.topicSession == it.topicSession } val matchingHandlers = handlers.filter { it.topicSession.isBlank() || transfer.message.topic == it.topicSession }
if (matchingHandlers.isEmpty()) { if (matchingHandlers.isEmpty()) {
// Got no handlers for this message yet. Keep the message around and attempt redelivery after a new // Got no handlers for this message yet. Keep the message around and attempt redelivery after a new
// handler has been registered. The purpose of this path is to make unit tests that have multi-threading // handler has been registered. The purpose of this path is to make unit tests that have multi-threading
// reliable, as a sender may attempt to send a message to a receiver that hasn't finished setting // reliable, as a sender may attempt to send a message to a receiver that hasn't finished setting
// up a handler for yet. Most unit tests don't run threaded, but we want to test true parallelism at // up a handler for yet. Most unit tests don't run threaded, but we want to test true parallelism at
// least sometimes. // least sometimes.
log.warn("Message to ${transfer.message.topicSession} could not be delivered") log.warn("Message to ${transfer.message.topic} could not be delivered")
database.transaction { database.transaction {
pendingRedelivery.add(transfer) pendingRedelivery.add(transfer)
} }
@ -438,8 +438,8 @@ class InMemoryMessagingNetwork internal constructor(
} }
private fun MessageTransfer.toReceivedMessage(): ReceivedMessage = InMemoryReceivedMessage( private fun MessageTransfer.toReceivedMessage(): ReceivedMessage = InMemoryReceivedMessage(
message.topicSession, message.topic,
message.data.copyOf(), // Kryo messes with the buffer so give each client a unique copy OpaqueBytes(message.data.bytes.copyOf()), // Kryo messes with the buffer so give each client a unique copy
1, 1,
message.uniqueMessageId, message.uniqueMessageId,
message.debugTimestamp, message.debugTimestamp,