From 0a88b76e461fc540038553b77a6f2afd95c04566 Mon Sep 17 00:00:00 2001 From: Andras Slemmer Date: Thu, 8 Feb 2018 15:13:25 +0000 Subject: [PATCH 1/3] r3corda wire compatibility --- .idea/compiler.xml | 3 +- .../mocknetwork/TutorialMockNetwork.kt | 22 ++- .../serialization/ListsSerializationTest.kt | 8 +- .../serialization/MapsSerializationTest.kt | 4 +- .../serialization/SetsSerializationTest.kt | 10 +- .../services/messaging/P2PMessagingTest.kt | 84 +++++---- .../node/services/messaging/Messaging.kt | 135 ++------------- .../services/messaging/P2PMessagingClient.kt | 81 ++++----- .../messaging/ServiceRequestMessage.kt | 27 --- .../services/statemachine/FlowIORequest.kt | 34 ++-- .../statemachine/FlowSessionInternal.kt | 16 +- .../statemachine/FlowStateMachineImpl.kt | 161 ++++++++++------- .../services/statemachine/SessionMessage.kt | 163 ++++++++++++------ .../statemachine/StateMachineManagerImpl.kt | 69 ++++---- .../node/messaging/InMemoryMessagingTests.kt | 9 +- .../messaging/ArtemisMessagingTest.kt | 2 +- .../statemachine/FlowFrameworkTests.kt | 70 +++++--- .../net/corda/netmap/NetworkMapVisualiser.kt | 18 +- .../testing/node/InMemoryMessagingNetwork.kt | 58 +++---- 19 files changed, 489 insertions(+), 485 deletions(-) delete mode 100644 node/src/main/kotlin/net/corda/node/services/messaging/ServiceRequestMessage.kt diff --git a/.idea/compiler.xml b/.idea/compiler.xml index d8d9f1498a..dba1dad5e2 100644 --- a/.idea/compiler.xml +++ b/.idea/compiler.xml @@ -72,6 +72,7 @@ + @@ -159,4 +160,4 @@ - + \ No newline at end of file diff --git a/docs/source/example-code/src/main/kotlin/net/corda/docs/tutorial/mocknetwork/TutorialMockNetwork.kt b/docs/source/example-code/src/main/kotlin/net/corda/docs/tutorial/mocknetwork/TutorialMockNetwork.kt index 6841c5d7fc..1ab3513656 100644 --- a/docs/source/example-code/src/main/kotlin/net/corda/docs/tutorial/mocknetwork/TutorialMockNetwork.kt +++ b/docs/source/example-code/src/main/kotlin/net/corda/docs/tutorial/mocknetwork/TutorialMockNetwork.kt @@ -10,12 +10,18 @@ import net.corda.core.identity.Party import net.corda.core.messaging.MessageRecipients import net.corda.core.serialization.deserialize import net.corda.core.serialization.serialize +import net.corda.core.utilities.OpaqueBytes import net.corda.core.utilities.getOrThrow import net.corda.core.utilities.unwrap import net.corda.node.internal.StartedNode import net.corda.node.services.messaging.Message -import net.corda.node.services.statemachine.SessionData -import net.corda.testing.node.* +import net.corda.node.services.statemachine.DataSessionMessage +import net.corda.node.services.statemachine.ExistingSessionMessage +import net.corda.testing.node.InMemoryMessagingNetwork +import net.corda.testing.node.MessagingServiceSpy +import net.corda.testing.node.MockNetwork +import net.corda.testing.node.setMessagingServiceSpy +import net.corda.testing.node.startFlow import org.junit.After import org.junit.Before import org.junit.Rule @@ -79,12 +85,12 @@ class TutorialMockNetwork { // modify message if it's 1 nodeB.setMessagingServiceSpy(object : MessagingServiceSpy(nodeB.network) { - override fun send(message: Message, target: MessageRecipients, retryId: Long?, sequenceKey: Any, acknowledgementHandler: (() -> Unit)?) { - val messageData = message.data.deserialize() - - if (messageData is SessionData && messageData.payload.deserialize() == 1) { - val alteredMessageData = SessionData(messageData.recipientSessionId, 99.serialize()).serialize().bytes - messagingService.send(InMemoryMessagingNetwork.InMemoryMessage(message.topicSession, alteredMessageData, message.uniqueMessageId), target, retryId) + override fun send(message: Message, target: MessageRecipients, retryId: Long?, sequenceKey: Any) { + val messageData = message.data.deserialize() as? ExistingSessionMessage + val payload = messageData?.payload + if (payload is DataSessionMessage && payload.payload.deserialize() == 1) { + val alteredMessageData = messageData.copy(payload = payload.copy(99.serialize())).serialize().bytes + messagingService.send(InMemoryMessagingNetwork.InMemoryMessage(message.topic, OpaqueBytes(alteredMessageData), message.uniqueMessageId), target, retryId) } else { messagingService.send(message, target, retryId) } diff --git a/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/ListsSerializationTest.kt b/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/ListsSerializationTest.kt index 87b1b50a27..dc53b8fc29 100644 --- a/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/ListsSerializationTest.kt +++ b/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/ListsSerializationTest.kt @@ -3,7 +3,7 @@ package net.corda.nodeapi.internal.serialization import com.esotericsoftware.kryo.Kryo import com.esotericsoftware.kryo.util.DefaultClassResolver import net.corda.core.serialization.* -import net.corda.node.services.statemachine.SessionData +import net.corda.node.services.statemachine.DataSessionMessage import net.corda.nodeapi.internal.serialization.amqp.DeserializationInput import net.corda.nodeapi.internal.serialization.amqp.Envelope import net.corda.nodeapi.internal.serialization.amqp.SerializerFactory @@ -47,17 +47,17 @@ class ListsSerializationTest { @Test fun `check list can be serialized as part of SessionData`() { run { - val sessionData = SessionData(123, listOf(1).serialize()) + val sessionData = DataSessionMessage(listOf(1).serialize()) assertEqualAfterRoundTripSerialization(sessionData) assertEquals(listOf(1), sessionData.payload.deserialize()) } run { - val sessionData = SessionData(123, listOf(1, 2).serialize()) + val sessionData = DataSessionMessage(listOf(1, 2).serialize()) assertEqualAfterRoundTripSerialization(sessionData) assertEquals(listOf(1, 2), sessionData.payload.deserialize()) } run { - val sessionData = SessionData(123, emptyList().serialize()) + val sessionData = DataSessionMessage(emptyList().serialize()) assertEqualAfterRoundTripSerialization(sessionData) assertEquals(emptyList(), sessionData.payload.deserialize()) } diff --git a/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/MapsSerializationTest.kt b/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/MapsSerializationTest.kt index d1b9af493c..a76bb8a52e 100644 --- a/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/MapsSerializationTest.kt +++ b/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/MapsSerializationTest.kt @@ -6,7 +6,7 @@ import net.corda.core.identity.CordaX500Name import net.corda.core.serialization.CordaSerializable import net.corda.core.serialization.deserialize import net.corda.core.serialization.serialize -import net.corda.node.services.statemachine.SessionData +import net.corda.node.services.statemachine.DataSessionMessage import net.corda.nodeapi.internal.serialization.kryo.kryoMagic import net.corda.testing.core.SerializationEnvironmentRule import net.corda.testing.internal.amqpSpecific @@ -41,7 +41,7 @@ class MapsSerializationTest { @Test fun `check list can be serialized as part of SessionData`() { - val sessionData = SessionData(123, smallMap.serialize()) + val sessionData = DataSessionMessage(smallMap.serialize()) assertEqualAfterRoundTripSerialization(sessionData) assertEquals(smallMap, sessionData.payload.deserialize()) } diff --git a/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/SetsSerializationTest.kt b/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/SetsSerializationTest.kt index fb18178b36..48ba75540e 100644 --- a/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/SetsSerializationTest.kt +++ b/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/SetsSerializationTest.kt @@ -4,10 +4,10 @@ import com.esotericsoftware.kryo.Kryo import com.esotericsoftware.kryo.util.DefaultClassResolver import net.corda.core.serialization.deserialize import net.corda.core.serialization.serialize -import net.corda.node.services.statemachine.SessionData +import net.corda.node.services.statemachine.DataSessionMessage import net.corda.nodeapi.internal.serialization.kryo.kryoMagic -import net.corda.testing.internal.kryoSpecific import net.corda.testing.core.SerializationEnvironmentRule +import net.corda.testing.internal.kryoSpecific import org.junit.Assert.assertArrayEquals import org.junit.Assert.assertEquals import org.junit.Rule @@ -34,17 +34,17 @@ class SetsSerializationTest { @Test fun `check set can be serialized as part of SessionData`() { run { - val sessionData = SessionData(123, setOf(1).serialize()) + val sessionData = DataSessionMessage(setOf(1).serialize()) assertEqualAfterRoundTripSerialization(sessionData) assertEquals(setOf(1), sessionData.payload.deserialize()) } run { - val sessionData = SessionData(123, setOf(1, 2).serialize()) + val sessionData = DataSessionMessage(setOf(1, 2).serialize()) assertEqualAfterRoundTripSerialization(sessionData) assertEquals(setOf(1, 2), sessionData.payload.deserialize()) } run { - val sessionData = SessionData(123, emptySet().serialize()) + val sessionData = DataSessionMessage(emptySet().serialize()) assertEqualAfterRoundTripSerialization(sessionData) assertEquals(emptySet(), sessionData.payload.deserialize()) } diff --git a/node/src/integration-test/kotlin/net/corda/services/messaging/P2PMessagingTest.kt b/node/src/integration-test/kotlin/net/corda/services/messaging/P2PMessagingTest.kt index 987277597d..3cea7b31e6 100644 --- a/node/src/integration-test/kotlin/net/corda/services/messaging/P2PMessagingTest.kt +++ b/node/src/integration-test/kotlin/net/corda/services/messaging/P2PMessagingTest.kt @@ -1,9 +1,9 @@ package net.corda.services.messaging import net.corda.core.concurrent.CordaFuture -import net.corda.core.crypto.random63BitValue import net.corda.core.identity.CordaX500Name import net.corda.core.internal.concurrent.map +import net.corda.core.internal.concurrent.openFuture import net.corda.core.internal.randomOrNull import net.corda.core.messaging.MessageRecipients import net.corda.core.messaging.SingleMessageRecipient @@ -14,7 +14,9 @@ import net.corda.core.utilities.getOrThrow import net.corda.core.utilities.seconds import net.corda.node.internal.Node import net.corda.node.internal.StartedNode -import net.corda.node.services.messaging.* +import net.corda.node.services.messaging.MessagingService +import net.corda.node.services.messaging.ReceivedMessage +import net.corda.node.services.messaging.send import net.corda.testing.core.ALICE_NAME import net.corda.testing.core.chooseIdentity import net.corda.testing.driver.DriverDSL @@ -27,6 +29,7 @@ import org.junit.Test import java.util.* import java.util.concurrent.CountDownLatch import java.util.concurrent.TimeUnit +import java.util.concurrent.atomic.AtomicBoolean import java.util.concurrent.atomic.AtomicInteger class P2PMessagingTest { @@ -50,19 +53,12 @@ class P2PMessagingTest { alice.network.getAddressOfParty(getPartyInfo(notaryParty)!!) } - val dummyTopic = "dummy.topic" val responseMessage = "response" - val crashingNodes = simulateCrashingNodes(distributedServiceNodes, dummyTopic, responseMessage) + val crashingNodes = simulateCrashingNodes(distributedServiceNodes, responseMessage) // Send a single request with retry - val responseFuture = with(alice.network) { - val request = TestRequest(replyTo = myAddress) - val responseFuture = onNext(dummyTopic, request.sessionID) - val msg = createMessage(TopicSession(dummyTopic), data = request.serialize().bytes) - send(msg, serviceAddress, retryId = request.sessionID) - responseFuture - } + val responseFuture = alice.receiveFrom(serviceAddress, retryId = 0) crashingNodes.firstRequestReceived.await(5, TimeUnit.SECONDS) // The request wasn't successful. assertThat(responseFuture.isDone).isFalse() @@ -83,19 +79,12 @@ class P2PMessagingTest { alice.network.getAddressOfParty(getPartyInfo(notaryParty)!!) } - val dummyTopic = "dummy.topic" val responseMessage = "response" - val crashingNodes = simulateCrashingNodes(distributedServiceNodes, dummyTopic, responseMessage) - - val sessionId = random63BitValue() + val crashingNodes = simulateCrashingNodes(distributedServiceNodes, responseMessage) // Send a single request with retry - with(alice.network) { - val request = TestRequest(sessionId, myAddress) - val msg = createMessage(TopicSession(dummyTopic), data = request.serialize().bytes) - send(msg, serviceAddress, retryId = request.sessionID) - } + alice.receiveFrom(serviceAddress, retryId = 0) // Wait until the first request is received crashingNodes.firstRequestReceived.await() @@ -108,7 +97,13 @@ class P2PMessagingTest { // Restart the node and expect a response val aliceRestarted = startAlice() - val response = aliceRestarted.network.onNext(dummyTopic, sessionId).getOrThrow() + + val responseFuture = openFuture() + aliceRestarted.network.runOnNextMessage("test.response") { + responseFuture.set(it.data.deserialize()) + } + val response = responseFuture.getOrThrow() + assertThat(crashingNodes.requestsReceived.get()).isGreaterThan(numberOfRequestsReceived) assertThat(response).isEqualTo(responseMessage) } @@ -133,11 +128,12 @@ class P2PMessagingTest { ) /** - * Sets up the [distributedServiceNodes] to respond to [dummyTopic] requests. All nodes will receive requests and - * either ignore them or respond, depending on the value of [CrashingNodes.ignoreRequests], initially set to true. - * This may be used to simulate scenarios where nodes receive request messages but crash before sending back a response. + * Sets up the [distributedServiceNodes] to respond to "test.request" requests. All nodes will receive requests and + * either ignore them or respond to "test.response", depending on the value of [CrashingNodes.ignoreRequests], + * initially set to true. This may be used to simulate scenarios where nodes receive request messages but crash + * before sending back a response. */ - private fun simulateCrashingNodes(distributedServiceNodes: List>, dummyTopic: String, responseMessage: String): CrashingNodes { + private fun simulateCrashingNodes(distributedServiceNodes: List>, responseMessage: String): CrashingNodes { val crashingNodes = CrashingNodes( requestsReceived = AtomicInteger(0), firstRequestReceived = CountDownLatch(1), @@ -146,7 +142,7 @@ class P2PMessagingTest { distributedServiceNodes.forEach { val nodeName = it.info.chooseIdentity().name - it.network.addMessageHandler(dummyTopic) { netMessage, _ -> + it.network.addMessageHandler("test.request") { netMessage, _ -> crashingNodes.requestsReceived.incrementAndGet() crashingNodes.firstRequestReceived.countDown() // The node which receives the first request will ignore all requests @@ -158,7 +154,7 @@ class P2PMessagingTest { } else { println("sending response") val request = netMessage.data.deserialize() - val response = it.network.createMessage(dummyTopic, request.sessionID, responseMessage.serialize().bytes) + val response = it.network.createMessage("test.response", responseMessage.serialize().bytes) it.network.send(response, request.replyTo) } } @@ -188,19 +184,39 @@ class P2PMessagingTest { } private fun StartedNode<*>.respondWith(message: Any) { - network.addMessageHandler(javaClass.name) { netMessage, _ -> + network.addMessageHandler("test.request") { netMessage, _ -> val request = netMessage.data.deserialize() - val response = network.createMessage(javaClass.name, request.sessionID, message.serialize().bytes) + val response = network.createMessage("test.response", message.serialize().bytes) network.send(response, request.replyTo) } } - private fun StartedNode<*>.receiveFrom(target: MessageRecipients): CordaFuture { - val request = TestRequest(replyTo = network.myAddress) - return network.sendRequest(javaClass.name, request, target) + private fun StartedNode<*>.receiveFrom(target: MessageRecipients, retryId: Long? = null): CordaFuture { + val response = openFuture() + network.runOnNextMessage("test.response") { netMessage -> + response.set(netMessage.data.deserialize()) + } + network.send("test.request", TestRequest(replyTo = network.myAddress), target, retryId = retryId) + return response + } + + /** + * Registers a handler for the given topic and session that runs the given callback with the message and then removes + * itself. This is useful for one-shot handlers that aren't supposed to stick around permanently. Note that this callback + * doesn't take the registration object, unlike the callback to [MessagingService.addMessageHandler]. + * + * @param topic identifier for the topic and session to listen for messages arriving on. + */ + inline fun MessagingService.runOnNextMessage(topic: String, crossinline callback: (ReceivedMessage) -> Unit) { + val consumed = AtomicBoolean() + addMessageHandler(topic) { msg, reg -> + removeMessageHandler(reg) + check(!consumed.getAndSet(true)) { "Called more than once" } + check(msg.topic == topic) { "Topic/session mismatch: ${msg.topic} vs $topic" } + callback(msg) + } } @CordaSerializable - private data class TestRequest(override val sessionID: Long = random63BitValue(), - override val replyTo: SingleMessageRecipient) : ServiceRequestMessage + private data class TestRequest(val replyTo: SingleMessageRecipient) } diff --git a/node/src/main/kotlin/net/corda/node/services/messaging/Messaging.kt b/node/src/main/kotlin/net/corda/node/services/messaging/Messaging.kt index 468de4d8f5..6260ea59db 100644 --- a/node/src/main/kotlin/net/corda/node/services/messaging/Messaging.kt +++ b/node/src/main/kotlin/net/corda/node/services/messaging/Messaging.kt @@ -1,18 +1,15 @@ package net.corda.node.services.messaging -import net.corda.core.concurrent.CordaFuture +import co.paralleluniverse.fibers.Suspendable import net.corda.core.identity.CordaX500Name -import net.corda.core.internal.concurrent.openFuture -import net.corda.core.internal.uncheckedCast import net.corda.core.messaging.MessageRecipients import net.corda.core.messaging.SingleMessageRecipient import net.corda.core.node.services.PartyInfo import net.corda.core.serialization.CordaSerializable -import net.corda.core.serialization.deserialize import net.corda.core.serialization.serialize +import net.corda.core.utilities.ByteSequence import java.time.Instant import java.util.* -import java.util.concurrent.atomic.AtomicBoolean import javax.annotation.concurrent.ThreadSafe /** @@ -27,29 +24,6 @@ import javax.annotation.concurrent.ThreadSafe */ @ThreadSafe interface MessagingService { - companion object { - /** - * Session ID to use for services listening for the first message in a session (before a - * specific session ID has been established). - */ - val DEFAULT_SESSION_ID = 0L - } - - /** - * The provided function will be invoked for each received message whose topic matches the given string. The callback - * will run on threads provided by the messaging service, and the callback is expected to be thread safe as a result. - * - * The returned object is an opaque handle that may be used to un-register handlers later with [removeMessageHandler]. - * The handle is passed to the callback as well, to avoid race conditions whereby the callback wants to unregister - * itself and yet addMessageHandler hasn't returned the handle yet. - * - * @param topic identifier for the general subject of the message, for example "platform.network_map.fetch". - * The topic can be the empty string to match all messages (session ID must be [DEFAULT_SESSION_ID]). - * @param sessionID identifier for the session the message is part of. For services listening before - * a session is established, use [DEFAULT_SESSION_ID]. - */ - fun addMessageHandler(topic: String = "", sessionID: Long = DEFAULT_SESSION_ID, callback: (ReceivedMessage, MessageHandlerRegistration) -> Unit): MessageHandlerRegistration - /** * The provided function will be invoked for each received message whose topic and session matches. The callback * will run on the main server thread provided when the messaging service is constructed, and a database @@ -59,9 +33,9 @@ interface MessagingService { * The handle is passed to the callback as well, to avoid race conditions whereby the callback wants to unregister * itself and yet addMessageHandler hasn't returned the handle yet. * - * @param topicSession identifier for the topic and session to listen for messages arriving on. + * @param topic identifier for the topic to listen for messages arriving on. */ - fun addMessageHandler(topicSession: TopicSession, callback: (ReceivedMessage, MessageHandlerRegistration) -> Unit): MessageHandlerRegistration + fun addMessageHandler(topic: String, callback: (ReceivedMessage, MessageHandlerRegistration) -> Unit): MessageHandlerRegistration /** * Removes a handler given the object returned from [addMessageHandler]. The callback will no longer be invoked once @@ -86,15 +60,13 @@ interface MessagingService { * @param sequenceKey an object that may be used to enable a parallel [MessagingService] implementation. Two * subsequent send()s with the same [sequenceKey] (up to equality) are guaranteed to be delivered in the same * sequence the send()s were called. By default this is chosen conservatively to be [target]. - * @param acknowledgementHandler if non-null this handler will be called once the sent message has been committed by - * the broker. Note that if specified [send] itself may return earlier than the commit. */ + @Suspendable fun send( message: Message, target: MessageRecipients, retryId: Long? = null, - sequenceKey: Any = target, - acknowledgementHandler: (() -> Unit)? = null + sequenceKey: Any = target ) /** A message with a target and sequenceKey specified. */ @@ -110,12 +82,9 @@ interface MessagingService { * implementation. * * @param addressedMessages The list of messages together with the recipients, retry ids and sequence keys. - * @param retryId if provided the message will be scheduled for redelivery until [cancelRedelivery] is called for this id. - * Note that this feature should only be used when the target is an idempotent distributed service, e.g. a notary. - * @param acknowledgementHandler if non-null this handler will be called once all sent messages have been committed - * by the broker. Note that if specified [send] itself may return earlier than the commit. */ - fun send(addressedMessages: List, acknowledgementHandler: (() -> Unit)? = null) + @Suspendable + fun send(addressedMessages: List) /** Cancels the scheduled message redelivery for the specified [retryId] */ fun cancelRedelivery(retryId: Long) @@ -123,9 +92,9 @@ interface MessagingService { /** * Returns an initialised [Message] with the current time, etc, already filled in. * - * @param topicSession identifier for the topic and session the message is sent to. + * @param topic identifier for the topic the message is sent to. */ - fun createMessage(topicSession: TopicSession, data: ByteArray, uuid: UUID = UUID.randomUUID()): Message + fun createMessage(topic: String, data: ByteArray, deduplicationId: String = UUID.randomUUID().toString()): Message /** Given information about either a specific node or a service returns its corresponding address */ fun getAddressOfParty(partyInfo: PartyInfo): MessageRecipients @@ -134,86 +103,12 @@ interface MessagingService { val myAddress: SingleMessageRecipient } -/** - * Returns an initialised [Message] with the current time, etc, already filled in. - * - * @param topic identifier for the general subject of the message, for example "platform.network_map.fetch". - * Must not be blank. - * @param sessionID identifier for the session the message is part of. For messages sent to services before the - * construction of a session, use [DEFAULT_SESSION_ID]. - */ -fun MessagingService.createMessage(topic: String, sessionID: Long = MessagingService.DEFAULT_SESSION_ID, data: ByteArray): Message - = createMessage(TopicSession(topic, sessionID), data) -/** - * Registers a handler for the given topic and session ID that runs the given callback with the message and then removes - * itself. This is useful for one-shot handlers that aren't supposed to stick around permanently. Note that this callback - * doesn't take the registration object, unlike the callback to [MessagingService.addMessageHandler], as the handler is - * automatically deregistered before the callback runs. - * - * @param topic identifier for the general subject of the message, for example "platform.network_map.fetch". - * The topic can be the empty string to match all messages (session ID must be [DEFAULT_SESSION_ID]). - * @param sessionID identifier for the session the message is part of. For services listening before - * a session is established, use [DEFAULT_SESSION_ID]. - */ -fun MessagingService.runOnNextMessage(topic: String, sessionID: Long, callback: (ReceivedMessage) -> Unit) - = runOnNextMessage(TopicSession(topic, sessionID), callback) - -/** - * Registers a handler for the given topic and session that runs the given callback with the message and then removes - * itself. This is useful for one-shot handlers that aren't supposed to stick around permanently. Note that this callback - * doesn't take the registration object, unlike the callback to [MessagingService.addMessageHandler]. - * - * @param topicSession identifier for the topic and session to listen for messages arriving on. - */ -inline fun MessagingService.runOnNextMessage(topicSession: TopicSession, crossinline callback: (ReceivedMessage) -> Unit) { - val consumed = AtomicBoolean() - addMessageHandler(topicSession) { msg, reg -> - removeMessageHandler(reg) - check(!consumed.getAndSet(true)) { "Called more than once" } - check(msg.topicSession == topicSession) { "Topic/session mismatch: ${msg.topicSession} vs $topicSession" } - callback(msg) - } -} - -/** - * Returns a [CordaFuture] of the next message payload ([Message.data]) which is received on the given topic and sessionId. - * The payload is deserialized to an object of type [M]. Any exceptions thrown will be captured by the future. - */ -fun MessagingService.onNext(topic: String, sessionId: Long): CordaFuture { - val messageFuture = openFuture() - runOnNextMessage(topic, sessionId) { message -> - messageFuture.capture { - uncheckedCast(message.data.deserialize()) - } - } - return messageFuture -} - -fun MessagingService.send(topic: String, sessionID: Long, payload: Any, to: MessageRecipients, uuid: UUID = UUID.randomUUID()) { - send(TopicSession(topic, sessionID), payload, to, uuid) -} - -fun MessagingService.send(topicSession: TopicSession, payload: Any, to: MessageRecipients, uuid: UUID = UUID.randomUUID(), retryId: Long? = null) { - send(createMessage(topicSession, payload.serialize().bytes, uuid), to, retryId) -} +fun MessagingService.send(topicSession: String, payload: Any, to: MessageRecipients, deduplicationId: String = UUID.randomUUID().toString(), retryId: Long? = null) + = send(createMessage(topicSession, payload.serialize().bytes, deduplicationId), to, retryId) interface MessageHandlerRegistration -/** - * An identifier for the endpoint [MessagingService] message handlers listen at. - * - * @param topic identifier for the general subject of the message, for example "platform.network_map.fetch". - * The topic can be the empty string to match all messages (session ID must be [DEFAULT_SESSION_ID]). - * @param sessionID identifier for the session the message is part of. For services listening before - * a session is established, use [DEFAULT_SESSION_ID]. - */ -@CordaSerializable -data class TopicSession(val topic: String, val sessionID: Long = MessagingService.DEFAULT_SESSION_ID) { - fun isBlank() = topic.isBlank() && sessionID == MessagingService.DEFAULT_SESSION_ID - override fun toString(): String = "$topic.$sessionID" -} - /** * A message is defined, at this level, to be a (topic, timestamp, byte arrays) triple, where the topic is a string in * Java-style reverse dns form, with "platform." being a prefix reserved by the platform for its own use. Vendor @@ -226,10 +121,10 @@ data class TopicSession(val topic: String, val sessionID: Long = MessagingServic */ @CordaSerializable interface Message { - val topicSession: TopicSession - val data: ByteArray + val topic: String + val data: ByteSequence val debugTimestamp: Instant - val uniqueMessageId: UUID + val uniqueMessageId: String } // TODO Have ReceivedMessage point to the TLS certificate of the peer, and [peer] would simply be the subject DN of that. diff --git a/node/src/main/kotlin/net/corda/node/services/messaging/P2PMessagingClient.kt b/node/src/main/kotlin/net/corda/node/services/messaging/P2PMessagingClient.kt index dd7358bad4..de9bf27f46 100644 --- a/node/src/main/kotlin/net/corda/node/services/messaging/P2PMessagingClient.kt +++ b/node/src/main/kotlin/net/corda/node/services/messaging/P2PMessagingClient.kt @@ -13,10 +13,7 @@ import net.corda.core.serialization.SerializationDefaults import net.corda.core.serialization.SingletonSerializeAsToken import net.corda.core.serialization.deserialize import net.corda.core.serialization.serialize -import net.corda.core.utilities.NetworkHostAndPort -import net.corda.core.utilities.contextLogger -import net.corda.core.utilities.sequence -import net.corda.core.utilities.trace +import net.corda.core.utilities.* import net.corda.node.VersionInfo import net.corda.node.services.api.NetworkMapCacheInternal import net.corda.node.services.config.NodeConfiguration @@ -98,20 +95,19 @@ class P2PMessagingClient(config: NodeConfiguration, // that will handle messages, like a URL) with the terminology used by underlying MQ libraries, to avoid // confusion. private val topicProperty = SimpleString("platform-topic") - private val sessionIdProperty = SimpleString("session-id") private val cordaVendorProperty = SimpleString("corda-vendor") private val releaseVersionProperty = SimpleString("release-version") private val platformVersionProperty = SimpleString("platform-version") private val amqDelayMillis = System.getProperty("amq.delivery.delay.ms", "0").toInt() private val messageMaxRetryCount: Int = 3 - fun createProcessedMessage(): AppendOnlyPersistentMap { + fun createProcessedMessage(): AppendOnlyPersistentMap { return AppendOnlyPersistentMap( - toPersistentEntityKey = { it.toString() }, - fromPersistentEntity = { Pair(UUID.fromString(it.uuid), it.insertionTime) }, - toPersistentEntity = { key: UUID, value: Instant -> + toPersistentEntityKey = { it }, + fromPersistentEntity = { Pair(it.uuid, it.insertionTime) }, + toPersistentEntity = { key: String, value: Instant -> ProcessedMessage().apply { - uuid = key.toString() + uuid = key insertionTime = value } }, @@ -139,9 +135,9 @@ class P2PMessagingClient(config: NodeConfiguration, ) } - private class NodeClientMessage(override val topicSession: TopicSession, override val data: ByteArray, override val uniqueMessageId: UUID) : Message { + private class NodeClientMessage(override val topic: String, override val data: ByteSequence, override val uniqueMessageId: String) : Message { override val debugTimestamp: Instant = Instant.now() - override fun toString() = "$topicSession#${String(data)}" + override fun toString() = "$topic#${String(data.bytes)}" } } @@ -160,7 +156,7 @@ class P2PMessagingClient(config: NodeConfiguration, private val scheduledMessageRedeliveries = ConcurrentHashMap>() /** A registration to handle messages of different types */ - data class Handler(val topicSession: TopicSession, + data class Handler(val topic: String, val callback: (ReceivedMessage, MessageHandlerRegistration) -> Unit) : MessageHandlerRegistration private val cordaVendor = SimpleString(versionInfo.vendor) @@ -181,7 +177,7 @@ class P2PMessagingClient(config: NodeConfiguration, @javax.persistence.Table(name = "${NODE_DATABASE_PREFIX}message_ids") class ProcessedMessage( @Id - @Column(name = "message_id", length = 36) + @Column(name = "message_id", length = 64) var uuid: String = "", @Column(name = "insertion_time") @@ -192,7 +188,7 @@ class P2PMessagingClient(config: NodeConfiguration, @javax.persistence.Table(name = "${NODE_DATABASE_PREFIX}message_retry") class RetryMessage( @Id - @Column(name = "message_id", length = 36) + @Column(name = "message_id", length = 64) var key: Long = 0, @Lob @@ -383,14 +379,13 @@ class P2PMessagingClient(config: NodeConfiguration, private fun artemisToCordaMessage(message: ClientMessage): ReceivedMessage? { try { val topic = message.required(topicProperty) { getStringProperty(it) } - val sessionID = message.required(sessionIdProperty) { getLongProperty(it) } val user = requireNotNull(message.getStringProperty(HDR_VALIDATED_USER)) { "Message is not authenticated" } val platformVersion = message.required(platformVersionProperty) { getIntProperty(it) } // Use the magic deduplication property built into Artemis as our message identity too - val uuid = message.required(HDR_DUPLICATE_DETECTION_ID) { UUID.fromString(message.getStringProperty(it)) } - log.info("Received message from: ${message.address} user: $user topic: $topic sessionID: $sessionID uuid: $uuid") + val uuid = message.required(HDR_DUPLICATE_DETECTION_ID) { message.getStringProperty(it) } + log.info("Received message from: ${message.address} user: $user topic: $topic uuid: $uuid") - return ArtemisReceivedMessage(TopicSession(topic, sessionID), CordaX500Name.parse(user), platformVersion, uuid, message) + return ArtemisReceivedMessage(topic, CordaX500Name.parse(user), platformVersion, uuid, message) } catch (e: Exception) { log.error("Unable to process message, ignoring it: $message", e) return null @@ -402,21 +397,21 @@ class P2PMessagingClient(config: NodeConfiguration, return extractor(key) } - private class ArtemisReceivedMessage(override val topicSession: TopicSession, + private class ArtemisReceivedMessage(override val topic: String, override val peer: CordaX500Name, override val platformVersion: Int, - override val uniqueMessageId: UUID, + override val uniqueMessageId: String, private val message: ClientMessage) : ReceivedMessage { - override val data: ByteArray by lazy { ByteArray(message.bodySize).apply { message.bodyBuffer.readBytes(this) } } + override val data: ByteSequence by lazy { OpaqueBytes(ByteArray(message.bodySize).apply { message.bodyBuffer.readBytes(this) }) } override val debugTimestamp: Instant get() = Instant.ofEpochMilli(message.timestamp) - override fun toString() = "${topicSession.topic}#${data.sequence()}" + override fun toString() = "$topic#$data" } private fun deliver(msg: ReceivedMessage): Boolean { state.checkNotLocked() // Because handlers is a COW list, the loop inside filter will operate on a snapshot. Handlers being added // or removed whilst the filter is executing will not affect anything. - val deliverTo = handlers.filter { it.topicSession.isBlank() || it.topicSession == msg.topicSession } + val deliverTo = handlers.filter { it.topic.isBlank() || it.topic== msg.topic } try { // This will perform a BLOCKING call onto the executor. Thus if the handlers are slow, we will // be slow, and Artemis can handle that case intelligently. We don't just invoke the handler @@ -429,11 +424,11 @@ class P2PMessagingClient(config: NodeConfiguration, nodeExecutor.fetchFrom { database.transaction { if (msg.uniqueMessageId in processedMessages) { - log.trace { "Discard duplicate message ${msg.uniqueMessageId} for ${msg.topicSession}" } + log.trace { "Discard duplicate message ${msg.uniqueMessageId} for ${msg.topic}" } } else { if (deliverTo.isEmpty()) { // TODO: Implement dead letter queue, and send it there. - log.warn("Received message ${msg.uniqueMessageId} for ${msg.topicSession} that doesn't have any registered handlers yet") + log.warn("Received message ${msg.uniqueMessageId} for ${msg.topic} that doesn't have any registered handlers yet") } else { callHandlers(msg, deliverTo) } @@ -443,7 +438,7 @@ class P2PMessagingClient(config: NodeConfiguration, } } } catch (e: Exception) { - log.error("Caught exception whilst executing message handler for ${msg.topicSession}", e) + log.error("Caught exception whilst executing message handler for ${msg.topic}", e) } return true } @@ -501,7 +496,7 @@ class P2PMessagingClient(config: NodeConfiguration, } } - override fun send(message: Message, target: MessageRecipients, retryId: Long?, sequenceKey: Any, acknowledgementHandler: (() -> Unit)?) { + override fun send(message: Message, target: MessageRecipients, retryId: Long?, sequenceKey: Any) { // We have to perform sending on a different thread pool, since using the same pool for messaging and // fibers leads to Netty buffer memory leaks, caused by both Netty and Quasar fiddling with thread-locals. messagingExecutor.fetchFrom { @@ -512,20 +507,18 @@ class P2PMessagingClient(config: NodeConfiguration, putStringProperty(cordaVendorProperty, cordaVendor) putStringProperty(releaseVersionProperty, releaseVersion) putIntProperty(platformVersionProperty, versionInfo.platformVersion) - putStringProperty(topicProperty, SimpleString(message.topicSession.topic)) - putLongProperty(sessionIdProperty, message.topicSession.sessionID) - writeBodyBufferBytes(message.data) + putStringProperty(topicProperty, SimpleString(message.topic)) + writeBodyBufferBytes(message.data.bytes) // Use the magic deduplication property built into Artemis as our message identity too putStringProperty(HDR_DUPLICATE_DETECTION_ID, SimpleString(message.uniqueMessageId.toString())) // For demo purposes - if set then add a delay to messages in order to demonstrate that the flows are doing as intended - if (amqDelayMillis > 0 && message.topicSession.topic == StateMachineManagerImpl.sessionTopic.topic) { + if (amqDelayMillis > 0 && message.topic == StateMachineManagerImpl.sessionTopic) { putLongProperty(HDR_SCHEDULED_DELIVERY_TIME, System.currentTimeMillis() + amqDelayMillis) } } log.trace { - "Send to: $mqAddress topic: ${message.topicSession.topic} " + - "sessionID: ${message.topicSession.sessionID} uuid: ${message.uniqueMessageId}" + "Send to: $mqAddress topic: ${message.topic} uuid: ${message.uniqueMessageId}" } artemis.producer.send(mqAddress, artemisMessage) retryId?.let { @@ -539,14 +532,12 @@ class P2PMessagingClient(config: NodeConfiguration, } } } - acknowledgementHandler?.invoke() } - override fun send(addressedMessages: List, acknowledgementHandler: (() -> Unit)?) { + override fun send(addressedMessages: List) { for ((message, target, retryId, sequenceKey) in addressedMessages) { - send(message, target, retryId, sequenceKey, null) + send(message, target, retryId, sequenceKey) } - acknowledgementHandler?.invoke() } private fun sendWithRetry(retryCount: Int, address: String, message: ClientMessage, retryId: Long) { @@ -622,15 +613,9 @@ class P2PMessagingClient(config: NodeConfiguration, } override fun addMessageHandler(topic: String, - sessionID: Long, callback: (ReceivedMessage, MessageHandlerRegistration) -> Unit): MessageHandlerRegistration { - return addMessageHandler(TopicSession(topic, sessionID), callback) - } - - override fun addMessageHandler(topicSession: TopicSession, - callback: (ReceivedMessage, MessageHandlerRegistration) -> Unit): MessageHandlerRegistration { - require(!topicSession.isBlank()) { "Topic must not be blank, as the empty topic is a special case." } - val handler = Handler(topicSession, callback) + require(!topic.isBlank()) { "Topic must not be blank, as the empty topic is a special case." } + val handler = Handler(topic, callback) handlers.add(handler) return handler } @@ -639,9 +624,9 @@ class P2PMessagingClient(config: NodeConfiguration, handlers.remove(registration) } - override fun createMessage(topicSession: TopicSession, data: ByteArray, uuid: UUID): Message { + override fun createMessage(topic: String, data: ByteArray, deduplicationId: String): Message { // TODO: We could write an object that proxies directly to an underlying MQ message here and avoid copying. - return NodeClientMessage(topicSession, data, uuid) + return NodeClientMessage(topic, OpaqueBytes(data), deduplicationId) } // TODO Rethink PartyInfo idea and merging PeerAddress/ServiceAddress (the only difference is that Service address doesn't hold host and port) diff --git a/node/src/main/kotlin/net/corda/node/services/messaging/ServiceRequestMessage.kt b/node/src/main/kotlin/net/corda/node/services/messaging/ServiceRequestMessage.kt deleted file mode 100644 index 15c68a3b66..0000000000 --- a/node/src/main/kotlin/net/corda/node/services/messaging/ServiceRequestMessage.kt +++ /dev/null @@ -1,27 +0,0 @@ -package net.corda.node.services.messaging - -import net.corda.core.concurrent.CordaFuture -import net.corda.core.messaging.MessageRecipients -import net.corda.core.messaging.SingleMessageRecipient -import net.corda.core.serialization.CordaSerializable - -/** - * Abstract superclass for request messages sent to services which expect a reply. - */ -@CordaSerializable -interface ServiceRequestMessage { - val sessionID: Long - val replyTo: SingleMessageRecipient -} - -/** - * Sends a [ServiceRequestMessage] to [target] and returns a [CordaFuture] of the response. - * @param R The type of the response. - */ -fun MessagingService.sendRequest(topic: String, - request: ServiceRequestMessage, - target: MessageRecipients): CordaFuture { - val responseFuture = onNext(topic, request.sessionID) - send(topic, MessagingService.DEFAULT_SESSION_ID, request, target) - return responseFuture -} diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/FlowIORequest.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/FlowIORequest.kt index bd29525072..65b24a7046 100644 --- a/node/src/main/kotlin/net/corda/node/services/statemachine/FlowIORequest.kt +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/FlowIORequest.kt @@ -2,6 +2,7 @@ package net.corda.node.services.statemachine import co.paralleluniverse.fibers.Suspendable import net.corda.core.crypto.SecureHash +import net.corda.core.identity.Party import java.time.Instant interface FlowIORequest { @@ -22,29 +23,28 @@ interface SendRequest : SessionedFlowIORequest { val message: SessionMessage } -interface ReceiveRequest : SessionedFlowIORequest, WaitingRequest { - val receiveType: Class +interface ReceiveRequest : SessionedFlowIORequest, WaitingRequest { val userReceiveType: Class<*>? override fun shouldResume(message: ExistingSessionMessage, session: FlowSessionInternal): Boolean = this.session === session } -data class SendAndReceive(override val session: FlowSessionInternal, - override val message: SessionMessage, - override val receiveType: Class, - override val userReceiveType: Class<*>?) : SendRequest, ReceiveRequest { +data class SendAndReceive(override val session: FlowSessionInternal, + override val message: SessionMessage, + override val userReceiveType: Class<*>?) : SendRequest, ReceiveRequest { @Transient override val stackTraceInCaseOfProblems: StackSnapshot = StackSnapshot() } -data class ReceiveOnly(override val session: FlowSessionInternal, - override val receiveType: Class, - override val userReceiveType: Class<*>?) : ReceiveRequest { +data class ReceiveOnly( + override val session: FlowSessionInternal, + override val userReceiveType: Class<*>? +) : ReceiveRequest { @Transient override val stackTraceInCaseOfProblems: StackSnapshot = StackSnapshot() } -class ReceiveAll(val requests: List>) : WaitingRequest { +class ReceiveAll(val requests: List) : WaitingRequest { @Transient override val stackTraceInCaseOfProblems: StackSnapshot = StackSnapshot() @@ -53,8 +53,8 @@ class ReceiveAll(val requests: List>) : WaitingReque } private fun shouldResumeIfRelevant() = requests.all { hasSuccessfulEndMessage(it) } - private fun hasSuccessfulEndMessage(it: ReceiveRequest): Boolean { - return it.session.receivedMessages.map { it.message }.any { it is SessionData || it is SessionEnd } + private fun hasSuccessfulEndMessage(it: ReceiveRequest): Boolean { + return it.session.receivedMessages.map { it.message.payload }.any { it is DataSessionMessage || it is EndSessionMessage } } @Suspendable @@ -70,7 +70,7 @@ class ReceiveAll(val requests: List>) : WaitingReque if (isComplete(receivedMessages)) { receivedMessages } else { - throw IllegalStateException(requests.filter { it.session !in receivedMessages.keys }.map { "Was expecting a ${it.receiveType.simpleName} but instead got nothing for $it." }.joinToString { "\n" }) + throw IllegalStateException(requests.filter { it.session !in receivedMessages.keys }.map { "Was expecting a message but instead got nothing for $it." }.joinToString { "\n" }) } } } @@ -90,15 +90,15 @@ class ReceiveAll(val requests: List>) : WaitingReque } @Suspendable - private fun poll(request: ReceiveRequest): ReceivedSessionMessage<*>? { - return request.session.receivedMessages.poll() + private fun poll(request: ReceiveRequest): ExistingSessionMessage? { + return request.session.receivedMessages.poll()?.message } override fun shouldResume(message: ExistingSessionMessage, session: FlowSessionInternal): Boolean = isRelevant(session) && shouldResumeIfRelevant() private fun isRelevant(session: FlowSessionInternal) = requests.any { it.session === session } - data class RequestMessage(val request: ReceiveRequest, val message: ReceivedSessionMessage<*>) + data class RequestMessage(val request: ReceiveRequest, val message: ExistingSessionMessage) } data class SendOnly(override val session: FlowSessionInternal, override val message: SessionMessage) : SendRequest { @@ -110,7 +110,7 @@ data class WaitForLedgerCommit(val hash: SecureHash, val fiber: FlowStateMachine @Transient override val stackTraceInCaseOfProblems: StackSnapshot = StackSnapshot() - override fun shouldResume(message: ExistingSessionMessage, session: FlowSessionInternal): Boolean = message is ErrorSessionEnd + override fun shouldResume(message: ExistingSessionMessage, session: FlowSessionInternal): Boolean = message.payload is ErrorSessionMessage } data class Sleep(val until: Instant, val fiber: FlowStateMachineImpl<*>) : FlowIORequest { diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/FlowSessionInternal.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/FlowSessionInternal.kt index dc5b39c6f5..58c134e39c 100644 --- a/node/src/main/kotlin/net/corda/node/services/statemachine/FlowSessionInternal.kt +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/FlowSessionInternal.kt @@ -17,18 +17,28 @@ import java.util.concurrent.ConcurrentLinkedQueue class FlowSessionInternal( val flow: FlowLogic<*>, val flowSession : FlowSession, - val ourSessionId: Long, + val ourSessionId: SessionId, val initiatingParty: Party?, var state: FlowSessionState, var retryable: Boolean = false) { - val receivedMessages = ConcurrentLinkedQueue>() + val receivedMessages = ConcurrentLinkedQueue() val fiber: FlowStateMachineImpl<*> get() = flow.stateMachine as FlowStateMachineImpl<*> override fun toString(): String { return "${javaClass.simpleName}(flow=$flow, ourSessionId=$ourSessionId, initiatingParty=$initiatingParty, state=$state)" } + + fun getPeerSessionId(): SessionId { + val sessionState = state + return when (sessionState) { + is FlowSessionState.Initiated -> sessionState.peerSessionId + else -> throw IllegalStateException("We've somehow held onto a non-initiated session: $this") + } + } } +data class ReceivedSessionMessage(val peerParty: Party, val message: ExistingSessionMessage) + /** * [FlowSessionState] describes the session's state. * @@ -50,7 +60,7 @@ sealed class FlowSessionState { override val sendToParty: Party get() = otherParty } - data class Initiated(val peerParty: Party, val peerSessionId: Long, val context: FlowInfo) : FlowSessionState() { + data class Initiated(val peerParty: Party, val peerSessionId: SessionId, val context: FlowInfo) : FlowSessionState() { override val sendToParty: Party get() = peerParty } } diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/FlowStateMachineImpl.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/FlowStateMachineImpl.kt index fb3b06c737..7fc71e37c7 100644 --- a/node/src/main/kotlin/net/corda/node/services/statemachine/FlowStateMachineImpl.kt +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/FlowStateMachineImpl.kt @@ -9,7 +9,7 @@ import com.google.common.primitives.Primitives import net.corda.core.concurrent.CordaFuture import net.corda.core.context.InvocationContext import net.corda.core.crypto.SecureHash -import net.corda.core.crypto.random63BitValue +import net.corda.core.crypto.newSecureRandom import net.corda.core.flows.* import net.corda.core.identity.Party import net.corda.core.identity.PartyAndCertificate @@ -31,6 +31,7 @@ import net.corda.nodeapi.internal.persistence.contextTransaction import net.corda.nodeapi.internal.persistence.contextTransactionOrNull import org.slf4j.Logger import org.slf4j.LoggerFactory +import java.io.IOException import java.nio.file.Paths import java.sql.SQLException import java.time.Duration @@ -180,7 +181,7 @@ class FlowStateMachineImpl(override val id: StateMachineRunId, requireNonPrimitive(receiveType) logger.debug { "sendAndReceive(${receiveType.name}, $otherParty, ${payload.toString().abbreviate(300)}) ..." } val session = getConfirmedSessionIfPresent(otherParty, sessionFlow) - val receivedSessionData: ReceivedSessionMessage = if (session == null) { + val receivedSessionMessage: ReceivedSessionMessage = if (session == null) { val newSession = initiateSession(otherParty, sessionFlow, payload, waitForConfirmation = true, retryable = retrySend) // Only do a receive here as the session init has carried the payload receiveInternal(newSession, receiveType) @@ -188,8 +189,20 @@ class FlowStateMachineImpl(override val id: StateMachineRunId, val sendData = createSessionData(session, payload) sendAndReceiveInternal(session, sendData, receiveType) } - logger.debug { "Received ${receivedSessionData.message.payload.toString().abbreviate(300)}" } - return receivedSessionData.checkPayloadIs(receiveType) + val sessionData = receivedSessionMessage.message.checkDataSessionMessage() + logger.debug { "Received ${sessionData.payload.toString().abbreviate(300)}" } + return sessionData.checkPayloadIs(receiveType) + } + + private fun ExistingSessionMessage.checkDataSessionMessage(): DataSessionMessage { + when (payload) { + is DataSessionMessage -> { + return payload + } + else -> { + throw IllegalStateException("Was expecting ${DataSessionMessage::class.java.simpleName} but got ${payload.javaClass.simpleName} instead") + } + } } @Suspendable @@ -200,9 +213,9 @@ class FlowStateMachineImpl(override val id: StateMachineRunId, requireNonPrimitive(receiveType) logger.debug { "receive(${receiveType.name}, $otherParty) ..." } val session = getConfirmedSession(otherParty, sessionFlow) - val sessionData = receiveInternal(session, receiveType) - logger.debug { "Received ${sessionData.message.payload.toString().abbreviate(300)}" } - return sessionData.checkPayloadIs(receiveType) + val receivedSessionMessage = receiveInternal(session, receiveType).message.checkDataSessionMessage() + logger.debug { "Received ${receivedSessionMessage.payload.toString().abbreviate(300)}" } + return receivedSessionMessage.checkPayloadIs(receiveType) } private fun requireNonPrimitive(receiveType: Class<*>) { @@ -219,7 +232,7 @@ class FlowStateMachineImpl(override val id: StateMachineRunId, // Don't send the payload again if it was already piggy-backed on a session init initiateSession(otherParty, sessionFlow, payload, waitForConfirmation = false) } else { - sendInternal(session, createSessionData(session, payload)) + sendInternal(session, ExistingSessionMessage(session.getPeerSessionId(), createSessionData(session, payload))) } } @@ -236,8 +249,8 @@ class FlowStateMachineImpl(override val id: StateMachineRunId, // If the tx isn't committed then we may have been resumed due to an session ending in an error for (session in openSessions.values) { for (receivedMessage in session.receivedMessages) { - if (receivedMessage.message is ErrorSessionEnd) { - session.erroredEnd(receivedMessage.message) + if (receivedMessage.message.payload is ErrorSessionMessage) { + session.erroredEnd(receivedMessage.message.payload.flowException) } } } @@ -294,16 +307,18 @@ class FlowStateMachineImpl(override val id: StateMachineRunId, @Suspendable override fun receiveAll(sessions: Map>, sessionFlow: FlowLogic<*>): Map> { - val requests = ArrayList>() + val requests = ArrayList() for ((session, receiveType) in sessions) { val sessionInternal = getConfirmedSession(session.counterparty, sessionFlow) - requests.add(ReceiveOnly(sessionInternal, SessionData::class.java, receiveType)) + requests.add(ReceiveOnly(sessionInternal, receiveType)) } val receivedMessages = ReceiveAll(requests).suspendAndExpectReceive(suspend) val result = LinkedHashMap>() for ((sessionInternal, requestAndMessage) in receivedMessages) { - val message = requestAndMessage.message.confirmReceiveType(requestAndMessage.request) - result[sessionInternal.flowSession] = message.checkPayloadIs(requestAndMessage.request.userReceiveType as Class) + val message = requestAndMessage.message.confirmNoError(requestAndMessage.request.session) + result[sessionInternal.flowSession] = message.checkDataSessionMessage().checkPayloadIs( + requestAndMessage.request.userReceiveType as Class + ) } return result } @@ -315,41 +330,46 @@ class FlowStateMachineImpl(override val id: StateMachineRunId, */ @Suspendable private fun FlowSessionInternal.waitForConfirmation() { - val (peerParty, sessionInitResponse) = receiveInternal(this, null) - if (sessionInitResponse is SessionConfirm) { - state = FlowSessionState.Initiated( - peerParty, - sessionInitResponse.initiatedSessionId, - FlowInfo(sessionInitResponse.flowVersion, sessionInitResponse.appName)) - } else { - sessionInitResponse as SessionReject - throw UnexpectedFlowEndException("Party ${state.sendToParty} rejected session request: ${sessionInitResponse.errorMessage}") + val sessionInitResponse = receiveInternal(this, null) + val payload = sessionInitResponse.message.payload + when (payload) { + is ConfirmSessionMessage -> { + state = FlowSessionState.Initiated( + sessionInitResponse. + peerParty, + payload.initiatedSessionId, + payload.initiatedFlowInfo) + } + is RejectSessionMessage -> { + throw UnexpectedFlowEndException("Party ${state.sendToParty} rejected session request: ${payload.message}") + } + else -> { + throw IllegalStateException("Was expecting ${ConfirmSessionMessage::class.java.simpleName} but got ${payload.javaClass.simpleName} instead") + } } } - private fun createSessionData(session: FlowSessionInternal, payload: Any): SessionData { - val sessionState = session.state - val peerSessionId = when (sessionState) { - is FlowSessionState.Initiated -> sessionState.peerSessionId - else -> throw IllegalStateException("We've somehow held onto a non-initiated session: $session") - } - return SessionData(peerSessionId, payload.serialize(context = SerializationDefaults.P2P_CONTEXT)) + private fun createSessionData(session: FlowSessionInternal, payload: Any): DataSessionMessage { + return DataSessionMessage(payload.serialize(context = SerializationDefaults.P2P_CONTEXT)) } @Suspendable private fun sendInternal(session: FlowSessionInternal, message: SessionMessage) = suspend(SendOnly(session, message)) - private inline fun receiveInternal( + @Suspendable + private fun receiveInternal( session: FlowSessionInternal, - userReceiveType: Class<*>?): ReceivedSessionMessage { - return waitForMessage(ReceiveOnly(session, M::class.java, userReceiveType)) + userReceiveType: Class<*>?): ReceivedSessionMessage { + return waitForMessage(ReceiveOnly(session, userReceiveType)) } - private inline fun sendAndReceiveInternal( + @Suspendable + private fun sendAndReceiveInternal( session: FlowSessionInternal, - message: SessionMessage, - userReceiveType: Class<*>?): ReceivedSessionMessage { - return waitForMessage(SendAndReceive(session, message, M::class.java, userReceiveType)) + message: DataSessionMessage, + userReceiveType: Class<*>?): ReceivedSessionMessage { + val sessionMessage = ExistingSessionMessage(session.getPeerSessionId(), message) + return waitForMessage(SendAndReceive(session, sessionMessage, userReceiveType)) } @Suspendable @@ -377,7 +397,7 @@ class FlowStateMachineImpl(override val id: StateMachineRunId, sessionFlow: FlowLogic<*> ) { logger.trace { "Creating a new session with $otherParty" } - val session = FlowSessionInternal(sessionFlow, flowSession, random63BitValue(), null, FlowSessionState.Uninitiated(otherParty)) + val session = FlowSessionInternal(sessionFlow, flowSession, SessionId.createRandom(newSecureRandom()), null, FlowSessionState.Uninitiated(otherParty)) openSessions[Pair(sessionFlow, otherParty)] = session } @@ -397,7 +417,7 @@ class FlowStateMachineImpl(override val id: StateMachineRunId, val (version, initiatingFlowClass) = session.flow.javaClass.flowVersionAndInitiatingClass val payloadBytes = firstPayload?.serialize(context = SerializationDefaults.P2P_CONTEXT) logger.info("Initiating flow session with party ${otherParty.name}. Session id for tracing purposes is ${session.ourSessionId}.") - val sessionInit = SessionInit(session.ourSessionId, initiatingFlowClass.name, version, session.flow.javaClass.appName, payloadBytes) + val sessionInit = InitialSessionMessage(session.ourSessionId, newSecureRandom().nextLong(), initiatingFlowClass.name, version, session.flow.javaClass.appName, payloadBytes) sendInternal(session, sessionInit) if (waitForConfirmation) { session.waitForConfirmation() @@ -406,8 +426,10 @@ class FlowStateMachineImpl(override val id: StateMachineRunId, } @Suspendable - private fun waitForMessage(receiveRequest: ReceiveRequest): ReceivedSessionMessage { - return receiveRequest.suspendAndExpectReceive().confirmReceiveType(receiveRequest) + private fun waitForMessage(receiveRequest: ReceiveRequest): ReceivedSessionMessage { + val receivedMessage = receiveRequest.suspendAndExpectReceive() + receivedMessage.message.confirmNoError(receiveRequest.session) + return receivedMessage } private val suspend : ReceiveAll.Suspend = object : ReceiveAll.Suspend { @@ -418,7 +440,7 @@ class FlowStateMachineImpl(override val id: StateMachineRunId, } @Suspendable - private fun ReceiveRequest<*>.suspendAndExpectReceive(): ReceivedSessionMessage<*> { + private fun ReceiveRequest.suspendAndExpectReceive(): ReceivedSessionMessage { val polledMessage = session.receivedMessages.poll() return if (polledMessage != null) { if (this is SendAndReceive) { @@ -431,35 +453,36 @@ class FlowStateMachineImpl(override val id: StateMachineRunId, // Suspend while we wait for a receive suspend(this) session.receivedMessages.poll() ?: - throw IllegalStateException("Was expecting a ${receiveType.simpleName} but instead got nothing for $this") + throw IllegalStateException("Was expecting a message but instead got nothing for $this") } } - private fun ReceivedSessionMessage<*>.confirmReceiveType( - receiveRequest: ReceiveRequest): ReceivedSessionMessage { - val session = receiveRequest.session - val receiveType = receiveRequest.receiveType - if (receiveType.isInstance(message)) { - return uncheckedCast(this) - } else if (message is SessionEnd) { - openSessions.values.remove(session) - if (message is ErrorSessionEnd) { - session.erroredEnd(message) - } else { - val expectedType = receiveRequest.userReceiveType?.name ?: receiveType.simpleName - throw UnexpectedFlowEndException("Counterparty flow on ${session.state.sendToParty} has completed without " + - "sending a $expectedType") + private fun ExistingSessionMessage.confirmNoError(session: FlowSessionInternal): ExistingSessionMessage { + when (payload) { + is ConfirmSessionMessage, + is DataSessionMessage -> { + return this + } + is ErrorSessionMessage -> { + openSessions.values.remove(session) + session.erroredEnd(payload.flowException) + } + is RejectSessionMessage -> { + session.erroredEnd(UnexpectedFlowEndException("Counterparty sent session rejection message at unexpected time with message ${payload.message}")) + } + EndSessionMessage -> { + openSessions.values.remove(session) + throw UnexpectedFlowEndException("Counterparty flow on ${session.state.sendToParty} has completed without " + + "sending data") } - } else { - throw IllegalStateException("Was expecting a ${receiveType.simpleName} but instead got $message for $receiveRequest") } } - private fun FlowSessionInternal.erroredEnd(end: ErrorSessionEnd): Nothing { - if (end.errorResponse != null) { + private fun FlowSessionInternal.erroredEnd(exception: Throwable?): Nothing { + if (exception != null) { @Suppress("PLATFORM_CLASS_MAPPED_TO_KOTLIN") - (end.errorResponse as java.lang.Throwable).fillInStackTrace() - throw end.errorResponse + exception.fillInStackTrace() + throw exception } else { throw UnexpectedFlowEndException("Counterparty flow on ${state.sendToParty} had an internal error and has terminated") } @@ -560,3 +583,15 @@ val Class>.appName: String "" } } + +fun DataSessionMessage.checkPayloadIs(type: Class): UntrustworthyData { + val payloadData: T = try { + val serializer = SerializationDefaults.SERIALIZATION_FACTORY + serializer.deserialize(payload, type, SerializationDefaults.P2P_CONTEXT) + } catch (ex: Exception) { + throw IOException("Payload invalid", ex) + } + return type.castIfPossible(payloadData)?.let { UntrustworthyData(it) } ?: + throw UnexpectedFlowEndException("We were expecting a ${type.name} but we instead got a " + + "${payloadData.javaClass.name} (${payloadData})") +} diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/SessionMessage.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/SessionMessage.kt index c321d3768a..5481275f05 100644 --- a/node/src/main/kotlin/net/corda/node/services/statemachine/SessionMessage.kt +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/SessionMessage.kt @@ -1,59 +1,122 @@ package net.corda.node.services.statemachine +import net.corda.core.crypto.random63BitValue import net.corda.core.flows.FlowException -import net.corda.core.flows.UnexpectedFlowEndException -import net.corda.core.identity.Party -import net.corda.core.internal.castIfPossible +import net.corda.core.flows.FlowInfo import net.corda.core.serialization.CordaSerializable -import net.corda.core.serialization.SerializationDefaults import net.corda.core.serialization.SerializedBytes -import net.corda.core.utilities.UntrustworthyData -import java.io.IOException +import java.security.SecureRandom + +/** + * A session between two flows is identified by two session IDs, the initiating and the initiated session ID. + * However after the session has been established the communication is symmetric. From then on we differentiate between + * the two session IDs with "source" ID (the ID from which we receive) and "sink" ID (the ID to which we send). + * + * Flow A (initiating) Flow B (initiated) + * initiatingId=sourceId=0 + * send(Initiate(initiatingId=0)) -----> initiatingId=sinkId=0 + * initiatedId=sourceId=1 + * initiatedId=sinkId=1 <----- send(Confirm(initiatedId=1)) + */ +@CordaSerializable +sealed class SessionMessage + @CordaSerializable -interface SessionMessage - -interface ExistingSessionMessage : SessionMessage { - val recipientSessionId: Long -} - -interface SessionInitResponse : ExistingSessionMessage { - val initiatorSessionId: Long - override val recipientSessionId: Long get() = initiatorSessionId -} - -interface SessionEnd : ExistingSessionMessage - -data class SessionInit(val initiatorSessionId: Long, - val initiatingFlowClass: String, - val flowVersion: Int, - val appName: String, - val firstPayload: SerializedBytes?) : SessionMessage - -data class SessionConfirm(override val initiatorSessionId: Long, - val initiatedSessionId: Long, - val flowVersion: Int, - val appName: String) : SessionInitResponse - -data class SessionReject(override val initiatorSessionId: Long, val errorMessage: String) : SessionInitResponse - -data class SessionData(override val recipientSessionId: Long, val payload: SerializedBytes) : ExistingSessionMessage - -data class NormalSessionEnd(override val recipientSessionId: Long) : SessionEnd - -data class ErrorSessionEnd(override val recipientSessionId: Long, val errorResponse: FlowException?) : SessionEnd - -data class ReceivedSessionMessage(val sender: Party, val message: M) - -fun ReceivedSessionMessage.checkPayloadIs(type: Class): UntrustworthyData { - val payloadData: T = try { - val serializer = SerializationDefaults.SERIALIZATION_FACTORY - serializer.deserialize(message.payload, type, SerializationDefaults.P2P_CONTEXT) - } catch (ex: Exception) { - throw IOException("Payload invalid", ex) +data class SessionId(val toLong: Long) { + companion object { + fun createRandom(secureRandom: SecureRandom) = SessionId(secureRandom.nextLong()) } - return type.castIfPossible(payloadData)?.let { UntrustworthyData(it) } ?: - throw UnexpectedFlowEndException("We were expecting a ${type.name} from $sender but we instead got a " + - "${payloadData.javaClass.name} (${payloadData})") - } + +/** + * The initial message to initiate a session with. + * + * @param initiatorSessionId the session ID of the initiator. On the sending side this is the *source* ID, on the + * receiving side this is the *sink* ID. + * @param initiationEntropy additional randomness to seed the initiated flow's deduplication ID. + * @param initiatorFlowClassName the class name to be used to determine the initiating-initiated mapping on the receiver + * side. + * @param flowVersion the version of the initiating flow. + * @param appName the name of the cordapp defining the initiating flow, or "corda" if it's a core flow. + * @param firstPayload the optional first payload. + */ +data class InitialSessionMessage( + val initiatorSessionId: SessionId, + val initiationEntropy: Long, + val initiatorFlowClassName: String, + val flowVersion: Int, + val appName: String, + val firstPayload: SerializedBytes? +) : SessionMessage() { + override fun toString() = "InitialSessionMessage(" + + "initiatorSessionId=$initiatorSessionId, " + + "initiationEntropy=$initiationEntropy, " + + "initiatorFlowClassName=$initiatorFlowClassName, " + + "appName=$appName, " + + "firstPayload=${firstPayload?.javaClass}" + + ")" +} + +/** + * A message sent when a session has been established already. + * + * @param recipientSessionId the recipient session ID. On the sending side this is the *sink* ID, on the receiving side + * this is the *source* ID. + * @param payload the rest of the message. + */ +data class ExistingSessionMessage( + val recipientSessionId: SessionId, + val payload: ExistingSessionMessagePayload +) : SessionMessage() + +/** + * The payload of an [ExistingSessionMessage] + */ +@CordaSerializable +sealed class ExistingSessionMessagePayload + +/** + * The confirmation message sent by the initiated side. + * @param initiatedSessionId the initiated session ID, the other half of [InitialSessionMessage.initiatorSessionId]. + * This is the *source* ID on the sending(initiated) side, and the *sink* ID on the receiving(initiating) side. + */ +data class ConfirmSessionMessage( + val initiatedSessionId: SessionId, + val initiatedFlowInfo: FlowInfo +) : ExistingSessionMessagePayload() + +/** + * A message containing flow-related data. + * + * @param payload the serialised payload. + */ +data class DataSessionMessage(val payload: SerializedBytes) : ExistingSessionMessagePayload() { + override fun toString() = "DataSessionMessage(payload=${payload.javaClass})" +} + +/** + * A message indicating that an error has happened. + * + * @param flowException the exception that happened. This is null if the error condition wasn't revealed to the + * receiving side. + * @param errorId the ID of the source error. This is always specified to allow posteriori correlation of error conditions. + */ +data class ErrorSessionMessage(val flowException: FlowException?, val errorId: Long) : ExistingSessionMessagePayload() + +/** + * A message indicating that a session initiation has failed. + * + * @param message a message describing the problem to the initator. + * @param errorId an error ID identifying this error condition. + */ +data class RejectSessionMessage(val message: String, val errorId: Long) : ExistingSessionMessagePayload() + +/** + * A message indicating that the flow hosting the session has ended. Note that this message is strictly part of the + * session protocol, the flow may be removed before all counter-flows have ended. + * + * The sole purpose of this message currently is to provide diagnostic in cases where the two communicating flows' + * protocols don't match up, e.g. one is waiting for the other, but the other side has already finished. + */ +object EndSessionMessage : ExistingSessionMessagePayload() diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/StateMachineManagerImpl.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/StateMachineManagerImpl.kt index 500529aa5c..f6b7ebfe01 100644 --- a/node/src/main/kotlin/net/corda/node/services/statemachine/StateMachineManagerImpl.kt +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/StateMachineManagerImpl.kt @@ -13,6 +13,7 @@ import net.corda.core.CordaException import net.corda.core.concurrent.CordaFuture import net.corda.core.context.InvocationContext import net.corda.core.crypto.SecureHash +import net.corda.core.crypto.newSecureRandom import net.corda.core.crypto.random63BitValue import net.corda.core.flows.FlowException import net.corda.core.flows.FlowInfo @@ -37,7 +38,6 @@ import net.corda.node.services.api.CheckpointStorage import net.corda.node.services.api.ServiceHubInternal import net.corda.node.services.config.shouldCheckCheckpoints import net.corda.node.services.messaging.ReceivedMessage -import net.corda.node.services.messaging.TopicSession import net.corda.node.utilities.AffinityExecutor import net.corda.node.utilities.newNamedSingleThreadExecutor import net.corda.nodeapi.internal.persistence.CordaPersistence @@ -72,7 +72,7 @@ class StateMachineManagerImpl( companion object { private val logger = contextLogger() - internal val sessionTopic = TopicSession("platform.session") + internal val sessionTopic = "platform.session" init { Fiber.setDefaultUncaughtExceptionHandler { fiber, throwable -> @@ -121,8 +121,8 @@ class StateMachineManagerImpl( private val totalStartedFlows = metrics.counter("Flows.Started") private val totalFinishedFlows = metrics.counter("Flows.Finished") - private val openSessions = ConcurrentHashMap() - private val recentlyClosedSessions = ConcurrentHashMap() + private val openSessions = ConcurrentHashMap() + private val recentlyClosedSessions = ConcurrentHashMap() // Context for tokenized services in checkpoints private lateinit var tokenizableServices: List @@ -281,7 +281,7 @@ class StateMachineManagerImpl( if (sender != null) { when (sessionMessage) { is ExistingSessionMessage -> onExistingSessionMessage(sessionMessage, sender) - is SessionInit -> onSessionInit(sessionMessage, message, sender) + is InitialSessionMessage -> onSessionInit(sessionMessage, message, sender) } } else { logger.error("Unknown peer $peer in $sessionMessage") @@ -294,15 +294,15 @@ class StateMachineManagerImpl( session.fiber.pushToLoggingContext() session.fiber.logger.trace { "Received $message on $session from $sender" } if (session.retryable) { - if (message is SessionConfirm && session.state is FlowSessionState.Initiated) { + if (message.payload is ConfirmSessionMessage && session.state is FlowSessionState.Initiated) { session.fiber.logger.trace { "Ignoring duplicate confirmation for session ${session.ourSessionId} – session is idempotent" } return } - if (message !is SessionConfirm) { - serviceHub.networkService.cancelRedelivery(session.ourSessionId) + if (message.payload !is ConfirmSessionMessage) { + serviceHub.networkService.cancelRedelivery(session.ourSessionId.toLong) } } - if (message is SessionEnd) { + if (message.payload is EndSessionMessage || message.payload is ErrorSessionMessage) { openSessions.remove(message.recipientSessionId) } session.receivedMessages += ReceivedSessionMessage(sender, message) @@ -317,9 +317,9 @@ class StateMachineManagerImpl( } else { val peerParty = recentlyClosedSessions.remove(message.recipientSessionId) if (peerParty != null) { - if (message is SessionConfirm) { + if (message.payload is ConfirmSessionMessage) { logger.trace { "Received session confirmation but associated fiber has already terminated, so sending session end" } - sendSessionMessage(peerParty, NormalSessionEnd(message.initiatedSessionId)) + sendSessionMessage(peerParty, ExistingSessionMessage(message.payload.initiatedSessionId, EndSessionMessage)) } else { logger.trace { "Ignoring session end message for already closed session: $message" } } @@ -336,12 +336,12 @@ class StateMachineManagerImpl( return waitingForResponse?.shouldResume(message, session) ?: false } - private fun onSessionInit(sessionInit: SessionInit, receivedMessage: ReceivedMessage, sender: Party) { + private fun onSessionInit(sessionInit: InitialSessionMessage, receivedMessage: ReceivedMessage, sender: Party) { logger.trace { "Received $sessionInit from $sender" } val senderSessionId = sessionInit.initiatorSessionId - fun sendSessionReject(message: String) = sendSessionMessage(sender, SessionReject(senderSessionId, message)) + fun sendSessionReject(message: String) = sendSessionMessage(sender, ExistingSessionMessage(senderSessionId, RejectSessionMessage(message, random63BitValue()))) val (session, initiatedFlowFactory) = try { val initiatedFlowFactory = getInitiatedFlowFactory(sessionInit) @@ -354,11 +354,11 @@ class StateMachineManagerImpl( val session = FlowSessionInternal( flow, flowSession, - random63BitValue(), + SessionId.createRandom(newSecureRandom()), sender, FlowSessionState.Initiated(sender, senderSessionId, FlowInfo(senderFlowVersion, sessionInit.appName))) if (sessionInit.firstPayload != null) { - session.receivedMessages += ReceivedSessionMessage(sender, SessionData(session.ourSessionId, sessionInit.firstPayload)) + session.receivedMessages += ReceivedSessionMessage(sender, ExistingSessionMessage(session.ourSessionId, DataSessionMessage(sessionInit.firstPayload))) } openSessions[session.ourSessionId] = session val context = InvocationContext.peer(sender.name) @@ -386,19 +386,19 @@ class StateMachineManagerImpl( is InitiatedFlowFactory.CorDapp -> initiatedFlowFactory.flowVersion to initiatedFlowFactory.appName } - sendSessionMessage(sender, SessionConfirm(senderSessionId, session.ourSessionId, ourFlowVersion, appName), session.fiber) - session.fiber.logger.debug { "Initiated by $sender using ${sessionInit.initiatingFlowClass}" } + sendSessionMessage(sender, ExistingSessionMessage(senderSessionId, ConfirmSessionMessage(session.ourSessionId, FlowInfo(ourFlowVersion, appName))), session.fiber) + session.fiber.logger.debug { "Initiated by $sender using ${sessionInit.initiatorFlowClassName}" } session.fiber.logger.trace { "Initiated from $sessionInit on $session" } resumeFiber(session.fiber) } - private fun getInitiatedFlowFactory(sessionInit: SessionInit): InitiatedFlowFactory<*> { + private fun getInitiatedFlowFactory(sessionInit: InitialSessionMessage): InitiatedFlowFactory<*> { val initiatingFlowClass = try { - Class.forName(sessionInit.initiatingFlowClass, true, classloader).asSubclass(FlowLogic::class.java) + Class.forName(sessionInit.initiatorFlowClassName, true, classloader).asSubclass(FlowLogic::class.java) } catch (e: ClassNotFoundException) { - throw SessionRejectException("Don't know ${sessionInit.initiatingFlowClass}") + throw SessionRejectException("Don't know ${sessionInit.initiatorFlowClassName}") } catch (e: ClassCastException) { - throw SessionRejectException("${sessionInit.initiatingFlowClass} is not a flow") + throw SessionRejectException("${sessionInit.initiatorFlowClassName} is not a flow") } return serviceHub.getFlowFactory(initiatingFlowClass) ?: throw SessionRejectException("$initiatingFlowClass is not registered") @@ -492,7 +492,7 @@ class StateMachineManagerImpl( private fun FlowSessionInternal.endSession(context: InvocationContext, exception: Throwable?, propagated: Boolean) { val initiatedState = state as? FlowSessionState.Initiated ?: return val sessionEnd = if (exception == null) { - NormalSessionEnd(initiatedState.peerSessionId) + EndSessionMessage } else { val errorResponse = if (exception is FlowException && (!propagated || initiatingParty != null)) { // Only propagate this FlowException if our local flow threw it or it was propagated to us and we only @@ -501,9 +501,9 @@ class StateMachineManagerImpl( } else { null } - ErrorSessionEnd(initiatedState.peerSessionId, errorResponse) + ErrorSessionMessage(errorResponse, 0) } - sendSessionMessage(initiatedState.peerParty, sessionEnd, fiber) + sendSessionMessage(initiatedState.peerParty, ExistingSessionMessage(initiatedState.peerSessionId, sessionEnd), fiber) recentlyClosedSessions[ourSessionId] = initiatedState.peerParty } @@ -573,14 +573,14 @@ class StateMachineManagerImpl( } private fun processSendRequest(ioRequest: SendRequest) { - val retryId = if (ioRequest.message is SessionInit) { + val retryId = if (ioRequest.message is InitialSessionMessage) { with(ioRequest.session) { openSessions[ourSessionId] = this - if (retryable) ourSessionId else null + if (retryable) ourSessionId.toLong else null } } else null sendSessionMessage(ioRequest.session.state.sendToParty, ioRequest.message, ioRequest.session.fiber, retryId) - if (ioRequest !is ReceiveRequest<*>) { + if (ioRequest !is ReceiveRequest) { // We sent a message, but don't expect a response, so re-enter the continuation to let it keep going. resumeFiber(ioRequest.session.fiber) } @@ -625,12 +625,15 @@ class StateMachineManagerImpl( // Handling Kryo and AMQP serialization problems. Unfortunately the two exception types do not share much of a common exception interface. is KryoException, is NotSerializableException -> { - if (message !is ErrorSessionEnd || message.errorResponse == null) throw e - logger.warn("Something in ${message.errorResponse.javaClass.name} is not serialisable. " + - "Instead sending back an exception which is serialisable to ensure session end occurs properly.", e) - // The subclass may have overridden toString so we use that - val exMessage = message.errorResponse.let { if (it.javaClass != FlowException::class.java) it.toString() else it.message } - message.copy(errorResponse = FlowException(exMessage)).serialize() + if (message is ExistingSessionMessage && message.payload is ErrorSessionMessage && message.payload.flowException != null) { + logger.warn("Something in ${message.payload.flowException.javaClass.name} is not serialisable. " + + "Instead sending back an exception which is serialisable to ensure session end occurs properly.", e) + // The subclass may have overridden toString so we use that + val exMessage = message.payload.flowException.message + message.copy(payload = message.payload.copy(flowException = FlowException(exMessage))).serialize() + } else { + throw e + } } else -> throw e } diff --git a/node/src/test/kotlin/net/corda/node/messaging/InMemoryMessagingTests.kt b/node/src/test/kotlin/net/corda/node/messaging/InMemoryMessagingTests.kt index 0f57a4dc3c..425d2f0e94 100644 --- a/node/src/test/kotlin/net/corda/node/messaging/InMemoryMessagingTests.kt +++ b/node/src/test/kotlin/net/corda/node/messaging/InMemoryMessagingTests.kt @@ -3,7 +3,6 @@ package net.corda.node.messaging import net.corda.core.messaging.AllPossibleRecipients import net.corda.node.services.messaging.Message import net.corda.node.services.messaging.TopicStringValidator -import net.corda.node.services.messaging.createMessage import net.corda.testing.internal.rigorousMock import net.corda.testing.node.MockNetwork import org.apache.commons.lang.ArrayUtils.EMPTY_BYTE_ARRAY @@ -51,10 +50,10 @@ class InMemoryMessagingTests { val bits = "test-content".toByteArray() var finalDelivery: Message? = null - node2.network.addMessageHandler { msg, _ -> + node2.network.addMessageHandler("test.topic") { msg, _ -> node2.network.send(msg, node3.network.myAddress) } - node3.network.addMessageHandler { msg, _ -> + node3.network.addMessageHandler("test.topic") { msg, _ -> finalDelivery = msg } @@ -63,7 +62,7 @@ class InMemoryMessagingTests { mockNet.runNetwork(rounds = 1) - assertTrue(Arrays.equals(finalDelivery!!.data, bits)) + assertTrue(Arrays.equals(finalDelivery!!.data.bytes, bits)) } @Test @@ -75,7 +74,7 @@ class InMemoryMessagingTests { val bits = "test-content".toByteArray() var counter = 0 - listOf(node1, node2, node3).forEach { it.network.addMessageHandler { _, _ -> counter++ } } + listOf(node1, node2, node3).forEach { it.network.addMessageHandler("test.topic") { _, _ -> counter++ } } node1.network.send(node2.network.createMessage("test.topic", data = bits), rigorousMock()) mockNet.runNetwork(rounds = 1) assertEquals(3, counter) diff --git a/node/src/test/kotlin/net/corda/node/services/messaging/ArtemisMessagingTest.kt b/node/src/test/kotlin/net/corda/node/services/messaging/ArtemisMessagingTest.kt index 7a464d548b..e17bfe86ac 100644 --- a/node/src/test/kotlin/net/corda/node/services/messaging/ArtemisMessagingTest.kt +++ b/node/src/test/kotlin/net/corda/node/services/messaging/ArtemisMessagingTest.kt @@ -126,7 +126,7 @@ class ArtemisMessagingTest { messagingClient.send(message, messagingClient.myAddress) val actual: Message = receivedMessages.take() - assertEquals("first msg", String(actual.data)) + assertEquals("first msg", String(actual.data.bytes)) assertNull(receivedMessages.poll(200, MILLISECONDS)) } diff --git a/node/src/test/kotlin/net/corda/node/services/statemachine/FlowFrameworkTests.kt b/node/src/test/kotlin/net/corda/node/services/statemachine/FlowFrameworkTests.kt index 2ac326ff2a..e16037795a 100644 --- a/node/src/test/kotlin/net/corda/node/services/statemachine/FlowFrameworkTests.kt +++ b/node/src/test/kotlin/net/corda/node/services/statemachine/FlowFrameworkTests.kt @@ -37,8 +37,8 @@ import net.corda.testing.node.InMemoryMessagingNetwork.ServicePeerAllocationStra import net.corda.testing.node.MockNetwork import net.corda.testing.node.MockNetwork.MockNode import net.corda.testing.node.MockNodeParameters -import net.corda.testing.node.pumpReceive import net.corda.testing.node.internal.startFlow +import net.corda.testing.node.pumpReceive import org.assertj.core.api.Assertions.assertThat import org.assertj.core.api.Assertions.assertThatThrownBy import org.assertj.core.api.AssertionsForClassTypes.assertThatExceptionOfType @@ -296,14 +296,16 @@ class FlowFrameworkTests { mockNet.runNetwork() assertThatExceptionOfType(UnexpectedFlowEndException::class.java).isThrownBy { resultFuture.getOrThrow() - }.withMessageContaining(String::class.java.name) // Make sure the exception message mentions the type the flow was expecting to receive + } } @Test fun `receiving unexpected session end before entering sendAndReceive`() { bobNode.registerFlowFactory(WaitForOtherSideEndBeforeSendAndReceive::class) { NoOpFlow() } val sessionEndReceived = Semaphore(0) - receivedSessionMessagesObservable().filter { it.message is SessionEnd }.subscribe { sessionEndReceived.release() } + receivedSessionMessagesObservable().filter { + it.message is ExistingSessionMessage && it.message.payload is EndSessionMessage + }.subscribe { sessionEndReceived.release() } val resultFuture = aliceNode.services.startFlow( WaitForOtherSideEndBeforeSendAndReceive(bob, sessionEndReceived)).resultFuture mockNet.runNetwork() @@ -356,7 +358,7 @@ class FlowFrameworkTests { assertSessionTransfers( aliceNode sent sessionInit(ReceiveFlow::class) to bobNode, bobNode sent sessionConfirm() to aliceNode, - bobNode sent erroredEnd() to aliceNode + bobNode sent errorMessage() to aliceNode ) } @@ -389,10 +391,11 @@ class FlowFrameworkTests { assertSessionTransfers( aliceNode sent sessionInit(ReceiveFlow::class) to bobNode, bobNode sent sessionConfirm() to aliceNode, - bobNode sent erroredEnd(erroringFlow.get().exceptionThrown) to aliceNode + bobNode sent errorMessage(erroringFlow.get().exceptionThrown) to aliceNode ) // Make sure the original stack trace isn't sent down the wire - assertThat((receivedSessionMessages.last().message as ErrorSessionEnd).errorResponse!!.stackTrace).isEmpty() + val lastMessage = receivedSessionMessages.last().message as ExistingSessionMessage + assertThat((lastMessage.payload as ErrorSessionMessage).flowException!!.stackTrace).isEmpty() } @Test @@ -438,7 +441,7 @@ class FlowFrameworkTests { aliceNode sent sessionInit(ReceiveFlow::class) to bobNode, bobNode sent sessionConfirm() to aliceNode, bobNode sent sessionData("Hello") to aliceNode, - aliceNode sent erroredEnd() to bobNode + aliceNode sent errorMessage() to bobNode ) } @@ -603,20 +606,20 @@ class FlowFrameworkTests { @Test fun `unknown class in session init`() { - aliceNode.sendSessionMessage(SessionInit(random63BitValue(), "not.a.real.Class", 1, "version", null), bob) + aliceNode.sendSessionMessage(InitialSessionMessage(SessionId(random63BitValue()), 0, "not.a.real.Class", 1, "", null), bob) mockNet.runNetwork() assertThat(receivedSessionMessages).hasSize(2) // Only the session-init and session-reject are expected - val reject = receivedSessionMessages.last().message as SessionReject - assertThat(reject.errorMessage).isEqualTo("Don't know not.a.real.Class") + val lastMessage = receivedSessionMessages.last().message as ExistingSessionMessage + assertThat((lastMessage.payload as RejectSessionMessage).message).isEqualTo("Don't know not.a.real.Class") } @Test fun `non-flow class in session init`() { - aliceNode.sendSessionMessage(SessionInit(random63BitValue(), String::class.java.name, 1, "version", null), bob) + aliceNode.sendSessionMessage(InitialSessionMessage(SessionId(random63BitValue()), 0, String::class.java.name, 1, "", null), bob) mockNet.runNetwork() assertThat(receivedSessionMessages).hasSize(2) // Only the session-init and session-reject are expected - val reject = receivedSessionMessages.last().message as SessionReject - assertThat(reject.errorMessage).isEqualTo("${String::class.java.name} is not a flow") + val lastMessage = receivedSessionMessages.last().message as ExistingSessionMessage + assertThat((lastMessage.payload as RejectSessionMessage).message).isEqualTo("${String::class.java.name} is not a flow") } @Test @@ -682,14 +685,14 @@ class FlowFrameworkTests { return observable.toFuture() } - private fun sessionInit(clientFlowClass: KClass>, flowVersion: Int = 1, payload: Any? = null): SessionInit { - return SessionInit(0, clientFlowClass.java.name, flowVersion, "", payload?.serialize()) + private fun sessionInit(clientFlowClass: KClass>, flowVersion: Int = 1, payload: Any? = null): InitialSessionMessage { + return InitialSessionMessage(SessionId(0), 0, clientFlowClass.java.name, flowVersion, "", payload?.serialize()) } - private fun sessionConfirm(flowVersion: Int = 1) = SessionConfirm(0, 0, flowVersion, "") - private fun sessionData(payload: Any) = SessionData(0, payload.serialize()) - private val normalEnd = NormalSessionEnd(0) - private fun erroredEnd(errorResponse: FlowException? = null) = ErrorSessionEnd(0, errorResponse) + private fun sessionConfirm(flowVersion: Int = 1) = ExistingSessionMessage(SessionId(0), ConfirmSessionMessage(SessionId(0), FlowInfo(flowVersion, ""))) + private fun sessionData(payload: Any) = ExistingSessionMessage(SessionId(0), DataSessionMessage(payload.serialize())) + private val normalEnd = ExistingSessionMessage(SessionId(0), EndSessionMessage) // NormalSessionEnd(0) + private fun errorMessage(errorResponse: FlowException? = null) = ExistingSessionMessage(SessionId(0), ErrorSessionMessage(errorResponse, 0)) private fun StartedNode<*>.sendSessionMessage(message: SessionMessage, destination: Party) { services.networkService.apply { @@ -709,7 +712,9 @@ class FlowFrameworkTests { } private data class SessionTransfer(val from: Int, val message: SessionMessage, val to: MessageRecipients) { - val isPayloadTransfer: Boolean get() = message is SessionData || message is SessionInit && message.firstPayload != null + val isPayloadTransfer: Boolean get() = + message is ExistingSessionMessage && message.payload is DataSessionMessage || + message is InitialSessionMessage && message.firstPayload != null override fun toString(): String = "$from sent $message to $to" } @@ -718,7 +723,7 @@ class FlowFrameworkTests { } private fun Observable.toSessionTransfers(): Observable { - return filter { it.message.topicSession == StateMachineManagerImpl.sessionTopic }.map { + return filter { it.message.topic == StateMachineManagerImpl.sessionTopic }.map { val from = it.sender.id val message = it.message.data.deserialize() SessionTransfer(from, sanitise(message), it.recipients) @@ -726,12 +731,23 @@ class FlowFrameworkTests { } private fun sanitise(message: SessionMessage) = when (message) { - is SessionData -> message.copy(recipientSessionId = 0) - is SessionInit -> message.copy(initiatorSessionId = 0, appName = "") - is SessionConfirm -> message.copy(initiatorSessionId = 0, initiatedSessionId = 0, appName = "") - is NormalSessionEnd -> message.copy(recipientSessionId = 0) - is ErrorSessionEnd -> message.copy(recipientSessionId = 0) - else -> message + is InitialSessionMessage -> message.copy(initiatorSessionId = SessionId(0), initiationEntropy = 0, appName = "") + is ExistingSessionMessage -> { + val payload = message.payload + message.copy( + recipientSessionId = SessionId(0), + payload = when (payload) { + is ConfirmSessionMessage -> payload.copy( + initiatedSessionId = SessionId(0), + initiatedFlowInfo = payload.initiatedFlowInfo.copy(appName = "") + ) + is ErrorSessionMessage -> payload.copy( + errorId = 0 + ) + else -> payload + } + ) + } } private infix fun StartedNode.sent(message: SessionMessage): Pair = Pair(internals.id, message) diff --git a/samples/network-visualiser/src/main/kotlin/net/corda/netmap/NetworkMapVisualiser.kt b/samples/network-visualiser/src/main/kotlin/net/corda/netmap/NetworkMapVisualiser.kt index c96b08c2f2..bbf21b1ece 100644 --- a/samples/network-visualiser/src/main/kotlin/net/corda/netmap/NetworkMapVisualiser.kt +++ b/samples/network-visualiser/src/main/kotlin/net/corda/netmap/NetworkMapVisualiser.kt @@ -15,9 +15,7 @@ import net.corda.core.serialization.deserialize import net.corda.core.utilities.ProgressTracker import net.corda.netmap.VisualiserViewModel.Style import net.corda.netmap.simulation.IRSSimulation -import net.corda.node.services.statemachine.SessionConfirm -import net.corda.node.services.statemachine.SessionEnd -import net.corda.node.services.statemachine.SessionInit +import net.corda.node.services.statemachine.* import net.corda.testing.core.chooseIdentity import net.corda.testing.node.InMemoryMessagingNetwork import net.corda.testing.node.MockNetwork @@ -342,12 +340,16 @@ class NetworkMapVisualiser : Application() { private fun transferIsInteresting(transfer: InMemoryMessagingNetwork.MessageTransfer): Boolean { // Loopback messages are boring. if (transfer.sender == transfer.recipients) return false - val message = transfer.message.data.deserialize() + val message = transfer.message.data.deserialize() return when (message) { - is SessionEnd -> false - is SessionConfirm -> false - is SessionInit -> message.firstPayload != null - else -> true + is InitialSessionMessage -> message.firstPayload != null + is ExistingSessionMessage -> when (message.payload) { + is ConfirmSessionMessage -> false + is DataSessionMessage -> true + is ErrorSessionMessage -> true + is RejectSessionMessage -> true + is EndSessionMessage -> false + } } } } diff --git a/testing/node-driver/src/main/kotlin/net/corda/testing/node/InMemoryMessagingNetwork.kt b/testing/node-driver/src/main/kotlin/net/corda/testing/node/InMemoryMessagingNetwork.kt index fbfc3cf4d9..f1d45b1028 100644 --- a/testing/node-driver/src/main/kotlin/net/corda/testing/node/InMemoryMessagingNetwork.kt +++ b/testing/node-driver/src/main/kotlin/net/corda/testing/node/InMemoryMessagingNetwork.kt @@ -14,11 +14,17 @@ import net.corda.core.messaging.SingleMessageRecipient import net.corda.core.node.services.PartyInfo import net.corda.core.serialization.CordaSerializable import net.corda.core.serialization.SingletonSerializeAsToken +import net.corda.core.utilities.ByteSequence +import net.corda.core.utilities.OpaqueBytes import net.corda.core.utilities.getOrThrow import net.corda.core.utilities.trace -import net.corda.node.services.messaging.* +import net.corda.node.services.messaging.Message +import net.corda.node.services.messaging.MessageHandlerRegistration +import net.corda.node.services.messaging.MessagingService +import net.corda.node.services.messaging.ReceivedMessage import net.corda.node.utilities.AffinityExecutor import net.corda.nodeapi.internal.persistence.CordaPersistence +import net.corda.testing.node.InMemoryMessagingNetwork.TestMessagingService import org.apache.activemq.artemis.utils.ReusableLatch import org.slf4j.LoggerFactory import rx.Observable @@ -57,7 +63,7 @@ class InMemoryMessagingNetwork internal constructor( @CordaSerializable data class MessageTransfer(val sender: PeerHandle, val message: Message, val recipients: MessageRecipients) { - override fun toString() = "${message.topicSession} from '$sender' to '$recipients'" + override fun toString() = "${message.topic} from '$sender' to '$recipients'" } // All sent messages are kept here until pumpSend is called, or manuallyPumped is set to false @@ -241,17 +247,17 @@ class InMemoryMessagingNetwork internal constructor( _sentMessages.onNext(transfer) } - data class InMemoryMessage(override val topicSession: TopicSession, - override val data: ByteArray, - override val uniqueMessageId: UUID, - override val debugTimestamp: Instant = Instant.now()) : Message { - override fun toString() = "$topicSession#${String(data)}" + data class InMemoryMessage(override val topic: String, + override val data: ByteSequence, + override val uniqueMessageId: String, + override val debugTimestamp: Instant = Instant.now()) : Message { + override fun toString() = "$topic#${String(data.bytes)}" } - private data class InMemoryReceivedMessage(override val topicSession: TopicSession, - override val data: ByteArray, + private data class InMemoryReceivedMessage(override val topic: String, + override val data: ByteSequence, override val platformVersion: Int, - override val uniqueMessageId: UUID, + override val uniqueMessageId: String, override val debugTimestamp: Instant, override val peer: CordaX500Name) : ReceivedMessage @@ -270,8 +276,7 @@ class InMemoryMessagingNetwork internal constructor( private val peerHandle: PeerHandle, private val executor: AffinityExecutor, private val database: CordaPersistence) : SingletonSerializeAsToken(), TestMessagingService { - private inner class Handler(val topicSession: TopicSession, - val callback: (ReceivedMessage, MessageHandlerRegistration) -> Unit) : MessageHandlerRegistration + inner class Handler(val topicSession: String, val callback: (ReceivedMessage, MessageHandlerRegistration) -> Unit) : MessageHandlerRegistration @Volatile private var running = true @@ -282,7 +287,7 @@ class InMemoryMessagingNetwork internal constructor( } private val state = ThreadBox(InnerState()) - private val processedMessages: MutableSet = Collections.synchronizedSet(HashSet()) + private val processedMessages: MutableSet = Collections.synchronizedSet(HashSet()) override val myAddress: PeerHandle get() = peerHandle @@ -304,13 +309,10 @@ class InMemoryMessagingNetwork internal constructor( } } - override fun addMessageHandler(topic: String, sessionID: Long, callback: (ReceivedMessage, MessageHandlerRegistration) -> Unit): MessageHandlerRegistration - = addMessageHandler(TopicSession(topic, sessionID), callback) - - override fun addMessageHandler(topicSession: TopicSession, callback: (ReceivedMessage, MessageHandlerRegistration) -> Unit): MessageHandlerRegistration { + override fun addMessageHandler(topic: String, callback: (ReceivedMessage, MessageHandlerRegistration) -> Unit): MessageHandlerRegistration { check(running) val (handler, transfers) = state.locked { - val handler = Handler(topicSession, callback).apply { handlers.add(this) } + val handler = Handler(topic, callback).apply { handlers.add(this) } val pending = ArrayList() database.transaction { pending.addAll(pendingRedelivery) @@ -328,20 +330,18 @@ class InMemoryMessagingNetwork internal constructor( state.locked { check(handlers.remove(registration as Handler)) } } - override fun send(message: Message, target: MessageRecipients, retryId: Long?, sequenceKey: Any, acknowledgementHandler: (() -> Unit)?) { + override fun send(message: Message, target: MessageRecipients, retryId: Long?, sequenceKey: Any) { check(running) msgSend(this, message, target) - acknowledgementHandler?.invoke() if (!sendManuallyPumped) { pumpSend(false) } } - override fun send(addressedMessages: List, acknowledgementHandler: (() -> Unit)?) { + override fun send(addressedMessages: List) { for ((message, target, retryId, sequenceKey) in addressedMessages) { - send(message, target, retryId, sequenceKey, null) + send(message, target, retryId, sequenceKey) } - acknowledgementHandler?.invoke() } override fun stop() { @@ -356,8 +356,8 @@ class InMemoryMessagingNetwork internal constructor( override fun cancelRedelivery(retryId: Long) {} /** Returns the given (topic & session, data) pair as a newly created message object. */ - override fun createMessage(topicSession: TopicSession, data: ByteArray, uuid: UUID): Message { - return InMemoryMessage(topicSession, data, uuid) + override fun createMessage(topic: String, data: ByteArray, deduplicationId: String): Message { + return InMemoryMessage(topic, OpaqueBytes(data), deduplicationId) } /** @@ -390,14 +390,14 @@ class InMemoryMessagingNetwork internal constructor( while (deliverTo == null) { val transfer = (if (block) q.take() else q.poll()) ?: return null deliverTo = state.locked { - val matchingHandlers = handlers.filter { it.topicSession.isBlank() || transfer.message.topicSession == it.topicSession } + val matchingHandlers = handlers.filter { it.topicSession.isBlank() || transfer.message.topic == it.topicSession } if (matchingHandlers.isEmpty()) { // Got no handlers for this message yet. Keep the message around and attempt redelivery after a new // handler has been registered. The purpose of this path is to make unit tests that have multi-threading // reliable, as a sender may attempt to send a message to a receiver that hasn't finished setting // up a handler for yet. Most unit tests don't run threaded, but we want to test true parallelism at // least sometimes. - log.warn("Message to ${transfer.message.topicSession} could not be delivered") + log.warn("Message to ${transfer.message.topic} could not be delivered") database.transaction { pendingRedelivery.add(transfer) } @@ -438,8 +438,8 @@ class InMemoryMessagingNetwork internal constructor( } private fun MessageTransfer.toReceivedMessage(): ReceivedMessage = InMemoryReceivedMessage( - message.topicSession, - message.data.copyOf(), // Kryo messes with the buffer so give each client a unique copy + message.topic, + OpaqueBytes(message.data.bytes.copyOf()), // Kryo messes with the buffer so give each client a unique copy 1, message.uniqueMessageId, message.debugTimestamp, From 70399eb2ac45e6604a77d69d3f6c8bccb2af77df Mon Sep 17 00:00:00 2001 From: Andras Slemmer Date: Fri, 9 Feb 2018 15:20:06 +0000 Subject: [PATCH 2/3] API changes --- .ci/api-current.txt | 82 ++++++++++++++++++++++++++++++++++++++------- 1 file changed, 69 insertions(+), 13 deletions(-) diff --git a/.ci/api-current.txt b/.ci/api-current.txt index 5ae28f1fe9..3e86538b36 100644 --- a/.ci/api-current.txt +++ b/.ci/api-current.txt @@ -1558,6 +1558,7 @@ public final class net.corda.core.identity.IdentityUtils extends java.lang.Objec @net.corda.core.serialization.CordaSerializable public interface net.corda.core.messaging.AllPossibleRecipients extends net.corda.core.messaging.MessageRecipients ## @net.corda.core.DoNotImplement public interface net.corda.core.messaging.CordaRPCOps extends net.corda.core.messaging.RPCOps + public abstract void acceptNewNetworkParameters(net.corda.core.crypto.SecureHash) public abstract void addVaultTransactionNote(net.corda.core.crypto.SecureHash, String) public abstract boolean attachmentExists(net.corda.core.crypto.SecureHash) public abstract void clearNetworkMapCache() @@ -1568,6 +1569,7 @@ public final class net.corda.core.identity.IdentityUtils extends java.lang.Objec @kotlin.Deprecated @org.jetbrains.annotations.NotNull public abstract List internalVerifiedTransactionsSnapshot() @net.corda.core.messaging.RPCReturnsObservables @org.jetbrains.annotations.NotNull public abstract net.corda.core.messaging.DataFeed networkMapFeed() @org.jetbrains.annotations.NotNull public abstract List networkMapSnapshot() + @net.corda.core.messaging.RPCReturnsObservables @org.jetbrains.annotations.NotNull public abstract net.corda.core.messaging.DataFeed networkParametersFeed() @org.jetbrains.annotations.NotNull public abstract net.corda.core.node.NodeInfo nodeInfo() @org.jetbrains.annotations.Nullable public abstract net.corda.core.node.NodeInfo nodeInfoFromParty(net.corda.core.identity.AbstractParty) @org.jetbrains.annotations.NotNull public abstract List notaryIdentities() @@ -1658,6 +1660,21 @@ public final class net.corda.core.messaging.CordaRPCOpsKt extends java.lang.Obje ## @net.corda.core.serialization.CordaSerializable public interface net.corda.core.messaging.MessageRecipients ## +@net.corda.core.serialization.CordaSerializable public final class net.corda.core.messaging.ParametersUpdateInfo extends java.lang.Object + public (net.corda.core.crypto.SecureHash, net.corda.core.node.NetworkParameters, String, java.time.Instant) + @org.jetbrains.annotations.NotNull public final net.corda.core.crypto.SecureHash component1() + @org.jetbrains.annotations.NotNull public final net.corda.core.node.NetworkParameters component2() + @org.jetbrains.annotations.NotNull public final String component3() + @org.jetbrains.annotations.NotNull public final java.time.Instant component4() + @org.jetbrains.annotations.NotNull public final net.corda.core.messaging.ParametersUpdateInfo copy(net.corda.core.crypto.SecureHash, net.corda.core.node.NetworkParameters, String, java.time.Instant) + public boolean equals(Object) + @org.jetbrains.annotations.NotNull public final String getDescription() + @org.jetbrains.annotations.NotNull public final net.corda.core.crypto.SecureHash getHash() + @org.jetbrains.annotations.NotNull public final net.corda.core.node.NetworkParameters getParameters() + @org.jetbrains.annotations.NotNull public final java.time.Instant getUpdateDeadline() + public int hashCode() + public String toString() +## @net.corda.core.DoNotImplement public interface net.corda.core.messaging.RPCOps public abstract int getProtocolVersion() ## @@ -1723,6 +1740,25 @@ public @interface net.corda.core.messaging.RPCReturnsObservables @org.jetbrains.annotations.NotNull public abstract net.corda.core.messaging.FlowHandle startFlow(net.corda.core.flows.FlowLogic) @org.jetbrains.annotations.NotNull public abstract net.corda.core.messaging.FlowProgressHandle startTrackedFlow(net.corda.core.flows.FlowLogic) ## +@net.corda.core.serialization.CordaSerializable public final class net.corda.core.node.NetworkParameters extends java.lang.Object + public (int, List, int, int, java.time.Instant, int) + public final int component1() + @org.jetbrains.annotations.NotNull public final List component2() + public final int component3() + public final int component4() + @org.jetbrains.annotations.NotNull public final java.time.Instant component5() + public final int component6() + @org.jetbrains.annotations.NotNull public final net.corda.core.node.NetworkParameters copy(int, List, int, int, java.time.Instant, int) + public boolean equals(Object) + public final int getEpoch() + public final int getMaxMessageSize() + public final int getMaxTransactionSize() + public final int getMinimumPlatformVersion() + @org.jetbrains.annotations.NotNull public final java.time.Instant getModifiedTime() + @org.jetbrains.annotations.NotNull public final List getNotaries() + public int hashCode() + public String toString() +## @net.corda.core.serialization.CordaSerializable public final class net.corda.core.node.NodeInfo extends java.lang.Object public (List, List, int, long) @org.jetbrains.annotations.NotNull public final List component1() @@ -1742,6 +1778,17 @@ public @interface net.corda.core.messaging.RPCReturnsObservables public final boolean isLegalIdentity(net.corda.core.identity.Party) public String toString() ## +@net.corda.core.serialization.CordaSerializable public final class net.corda.core.node.NotaryInfo extends java.lang.Object + public (net.corda.core.identity.Party, boolean) + @org.jetbrains.annotations.NotNull public final net.corda.core.identity.Party component1() + public final boolean component2() + @org.jetbrains.annotations.NotNull public final net.corda.core.node.NotaryInfo copy(net.corda.core.identity.Party, boolean) + public boolean equals(Object) + @org.jetbrains.annotations.NotNull public final net.corda.core.identity.Party getIdentity() + public final boolean getValidating() + public int hashCode() + public String toString() +## @net.corda.core.DoNotImplement public interface net.corda.core.node.ServiceHub extends net.corda.core.node.ServicesForResolution @org.jetbrains.annotations.NotNull public abstract net.corda.core.transactions.SignedTransaction addSignature(net.corda.core.transactions.SignedTransaction) @org.jetbrains.annotations.NotNull public abstract net.corda.core.transactions.SignedTransaction addSignature(net.corda.core.transactions.SignedTransaction, java.security.PublicKey) @@ -3148,6 +3195,7 @@ public final class net.corda.core.utilities.ByteArrays extends java.lang.Object @net.corda.core.serialization.CordaSerializable public abstract class net.corda.core.utilities.ByteSequence extends java.lang.Object implements java.lang.Comparable public int compareTo(net.corda.core.utilities.ByteSequence) @org.jetbrains.annotations.NotNull public final net.corda.core.utilities.ByteSequence copy() + @org.jetbrains.annotations.NotNull public final byte[] copyBytes() public boolean equals(Object) @org.jetbrains.annotations.NotNull public abstract byte[] getBytes() public final int getOffset() @@ -3157,9 +3205,12 @@ public final class net.corda.core.utilities.ByteArrays extends java.lang.Object @kotlin.jvm.JvmStatic @org.jetbrains.annotations.NotNull public static final net.corda.core.utilities.ByteSequence of(byte[], int) @kotlin.jvm.JvmStatic @org.jetbrains.annotations.NotNull public static final net.corda.core.utilities.ByteSequence of(byte[], int, int) @org.jetbrains.annotations.NotNull public final java.io.ByteArrayInputStream open() + @org.jetbrains.annotations.NotNull public final java.nio.ByteBuffer putTo(java.nio.ByteBuffer) + @org.jetbrains.annotations.NotNull public final java.nio.ByteBuffer slice(int, int) @org.jetbrains.annotations.NotNull public final net.corda.core.utilities.ByteSequence subSequence(int, int) @org.jetbrains.annotations.NotNull public final net.corda.core.utilities.ByteSequence take(int) @org.jetbrains.annotations.NotNull public String toString() + public final void writeTo(java.io.OutputStream) public static final net.corda.core.utilities.ByteSequence$Companion Companion ## public static final class net.corda.core.utilities.ByteSequence$Companion extends java.lang.Object @@ -3796,20 +3847,25 @@ public static final class net.corda.testing.node.ClusterSpec$Raft extends net.co public static final class net.corda.testing.node.InMemoryMessagingNetwork$Companion extends java.lang.Object ## public static final class net.corda.testing.node.InMemoryMessagingNetwork$InMemoryMessage extends java.lang.Object implements net.corda.node.services.messaging.Message - public (net.corda.node.services.messaging.TopicSession, byte[], UUID, java.time.Instant) - @org.jetbrains.annotations.NotNull public final net.corda.node.services.messaging.TopicSession component1() - @org.jetbrains.annotations.NotNull public final byte[] component2() - @org.jetbrains.annotations.NotNull public final UUID component3() + public (String, net.corda.core.utilities.ByteSequence, String, java.time.Instant) + @org.jetbrains.annotations.NotNull public final String component1() + @org.jetbrains.annotations.NotNull public final net.corda.core.utilities.ByteSequence component2() + @org.jetbrains.annotations.NotNull public final String component3() @org.jetbrains.annotations.NotNull public final java.time.Instant component4() - @org.jetbrains.annotations.NotNull public final net.corda.testing.node.InMemoryMessagingNetwork$InMemoryMessage copy(net.corda.node.services.messaging.TopicSession, byte[], UUID, java.time.Instant) + @org.jetbrains.annotations.NotNull public final net.corda.testing.node.InMemoryMessagingNetwork$InMemoryMessage copy(String, net.corda.core.utilities.ByteSequence, String, java.time.Instant) public boolean equals(Object) - @org.jetbrains.annotations.NotNull public byte[] getData() + @org.jetbrains.annotations.NotNull public net.corda.core.utilities.ByteSequence getData() @org.jetbrains.annotations.NotNull public java.time.Instant getDebugTimestamp() - @org.jetbrains.annotations.NotNull public net.corda.node.services.messaging.TopicSession getTopicSession() - @org.jetbrains.annotations.NotNull public UUID getUniqueMessageId() + @org.jetbrains.annotations.NotNull public String getTopic() + @org.jetbrains.annotations.NotNull public String getUniqueMessageId() public int hashCode() @org.jetbrains.annotations.NotNull public String toString() ## +public final class net.corda.testing.node.InMemoryMessagingNetwork$InMemoryMessaging$Handler extends java.lang.Object implements net.corda.node.services.messaging.MessageHandlerRegistration + public (net.corda.testing.node.InMemoryMessagingNetwork$InMemoryMessaging, String, kotlin.jvm.functions.Function2) + @org.jetbrains.annotations.NotNull public final kotlin.jvm.functions.Function2 getCallback() + @org.jetbrains.annotations.NotNull public final String getTopicSession() +## public static interface net.corda.testing.node.InMemoryMessagingNetwork$LatencyCalculator @org.jetbrains.annotations.NotNull public abstract java.time.Duration between(net.corda.core.messaging.SingleMessageRecipient, net.corda.core.messaging.SingleMessageRecipient) ## @@ -3869,16 +3925,15 @@ public static final class net.corda.testing.node.InMemoryMessagingNetwork$pumpSe ## public class net.corda.testing.node.MessagingServiceSpy extends java.lang.Object implements net.corda.node.services.messaging.MessagingService public (net.corda.node.services.messaging.MessagingService) - @org.jetbrains.annotations.NotNull public net.corda.node.services.messaging.MessageHandlerRegistration addMessageHandler(String, long, kotlin.jvm.functions.Function2) - @org.jetbrains.annotations.NotNull public net.corda.node.services.messaging.MessageHandlerRegistration addMessageHandler(net.corda.node.services.messaging.TopicSession, kotlin.jvm.functions.Function2) + @org.jetbrains.annotations.NotNull public net.corda.node.services.messaging.MessageHandlerRegistration addMessageHandler(String, kotlin.jvm.functions.Function2) public void cancelRedelivery(long) - @org.jetbrains.annotations.NotNull public net.corda.node.services.messaging.Message createMessage(net.corda.node.services.messaging.TopicSession, byte[], UUID) + @org.jetbrains.annotations.NotNull public net.corda.node.services.messaging.Message createMessage(String, byte[], String) @org.jetbrains.annotations.NotNull public net.corda.core.messaging.MessageRecipients getAddressOfParty(net.corda.core.node.services.PartyInfo) @org.jetbrains.annotations.NotNull public final net.corda.node.services.messaging.MessagingService getMessagingService() @org.jetbrains.annotations.NotNull public net.corda.core.messaging.SingleMessageRecipient getMyAddress() public void removeMessageHandler(net.corda.node.services.messaging.MessageHandlerRegistration) - public void send(List, kotlin.jvm.functions.Function0) - public void send(net.corda.node.services.messaging.Message, net.corda.core.messaging.MessageRecipients, Long, Object, kotlin.jvm.functions.Function0) + @co.paralleluniverse.fibers.Suspendable public void send(List) + @co.paralleluniverse.fibers.Suspendable public void send(net.corda.node.services.messaging.Message, net.corda.core.messaging.MessageRecipients, Long, Object) ## public final class net.corda.testing.node.MockKeyManagementService extends net.corda.core.serialization.SingletonSerializeAsToken implements net.corda.core.node.services.KeyManagementService @org.jetbrains.annotations.NotNull public Iterable filterMyKeys(Iterable) @@ -4019,6 +4074,7 @@ public final class net.corda.testing.node.MockNodeKt extends java.lang.Object public long getAttachmentContentCacheSizeBytes() @org.jetbrains.annotations.NotNull public java.nio.file.Path getCertificatesDirectory() public boolean getDetectPublicIp() + public boolean getNoLocalShell() @org.jetbrains.annotations.NotNull public java.nio.file.Path getNodeKeystore() @org.jetbrains.annotations.NotNull public java.nio.file.Path getSslKeystore() public long getTransactionCacheSizeBytes() From d01b2cbe975ef0edbbc3011ba729ddde4844dd8e Mon Sep 17 00:00:00 2001 From: Andras Slemmer Date: Fri, 9 Feb 2018 16:03:36 +0000 Subject: [PATCH 3/3] Address comments, fix test --- .../internal/serialization/SetsSerializationTest.kt | 2 +- .../services/statemachine/StateMachineManagerImpl.kt | 2 +- .../net/corda/node/messaging/InMemoryMessagingTests.kt | 9 ++++----- 3 files changed, 6 insertions(+), 7 deletions(-) diff --git a/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/SetsSerializationTest.kt b/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/SetsSerializationTest.kt index 48ba75540e..edd1eabf58 100644 --- a/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/SetsSerializationTest.kt +++ b/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/SetsSerializationTest.kt @@ -32,7 +32,7 @@ class SetsSerializationTest { } @Test - fun `check set can be serialized as part of SessionData`() { + fun `check set can be serialized as part of DataSessionMessage`() { run { val sessionData = DataSessionMessage(setOf(1).serialize()) assertEqualAfterRoundTripSerialization(sessionData) diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/StateMachineManagerImpl.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/StateMachineManagerImpl.kt index f6b7ebfe01..70ee60d5a5 100644 --- a/node/src/main/kotlin/net/corda/node/services/statemachine/StateMachineManagerImpl.kt +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/StateMachineManagerImpl.kt @@ -341,7 +341,7 @@ class StateMachineManagerImpl( logger.trace { "Received $sessionInit from $sender" } val senderSessionId = sessionInit.initiatorSessionId - fun sendSessionReject(message: String) = sendSessionMessage(sender, ExistingSessionMessage(senderSessionId, RejectSessionMessage(message, random63BitValue()))) + fun sendSessionReject(message: String) = sendSessionMessage(sender, ExistingSessionMessage(senderSessionId, RejectSessionMessage(message, errorId = sessionInit.initiatorSessionId.toLong))) val (session, initiatedFlowFactory) = try { val initiatedFlowFactory = getInitiatedFlowFactory(sessionInit) diff --git a/node/src/test/kotlin/net/corda/node/messaging/InMemoryMessagingTests.kt b/node/src/test/kotlin/net/corda/node/messaging/InMemoryMessagingTests.kt index 425d2f0e94..6981ac1b2f 100644 --- a/node/src/test/kotlin/net/corda/node/messaging/InMemoryMessagingTests.kt +++ b/node/src/test/kotlin/net/corda/node/messaging/InMemoryMessagingTests.kt @@ -5,7 +5,6 @@ import net.corda.node.services.messaging.Message import net.corda.node.services.messaging.TopicStringValidator import net.corda.testing.internal.rigorousMock import net.corda.testing.node.MockNetwork -import org.apache.commons.lang.ArrayUtils.EMPTY_BYTE_ARRAY import org.junit.After import org.junit.Before import org.junit.Test @@ -93,8 +92,8 @@ class InMemoryMessagingTests { node1.network.addMessageHandler("valid_message") { _, _ -> received++ } - val invalidMessage = node2.network.createMessage("invalid_message", data = EMPTY_BYTE_ARRAY) - val validMessage = node2.network.createMessage("valid_message", data = EMPTY_BYTE_ARRAY) + val invalidMessage = node2.network.createMessage("invalid_message", data = ByteArray(1)) + val validMessage = node2.network.createMessage("valid_message", data = ByteArray(1)) node2.network.send(invalidMessage, node1.network.myAddress) mockNet.runNetwork() assertEquals(0, received) @@ -105,8 +104,8 @@ class InMemoryMessagingTests { // Here's the core of the test; previously the unhandled message would cause runNetwork() to abort early, so // this would fail. Make fresh messages to stop duplicate uniqueMessageId causing drops - val invalidMessage2 = node2.network.createMessage("invalid_message", data = EMPTY_BYTE_ARRAY) - val validMessage2 = node2.network.createMessage("valid_message", data = EMPTY_BYTE_ARRAY) + val invalidMessage2 = node2.network.createMessage("invalid_message", data = ByteArray(1)) + val validMessage2 = node2.network.createMessage("valid_message", data = ByteArray(1)) node2.network.send(invalidMessage2, node1.network.myAddress) node2.network.send(validMessage2, node1.network.myAddress) mockNet.runNetwork()