diff --git a/node/src/integration-test/kotlin/net/corda/services/messaging/P2PMessageDeduplicatorTest.kt b/node/src/integration-test/kotlin/net/corda/services/messaging/P2PMessageDeduplicatorTest.kt new file mode 100644 index 0000000000..83db356a8d --- /dev/null +++ b/node/src/integration-test/kotlin/net/corda/services/messaging/P2PMessageDeduplicatorTest.kt @@ -0,0 +1,134 @@ +package net.corda.services.messaging + +import net.corda.core.identity.CordaX500Name +import net.corda.core.utilities.ByteSequence +import net.corda.node.services.messaging.MessageIdentifier +import net.corda.node.services.messaging.P2PMessageDeduplicator +import net.corda.node.services.messaging.ReceivedMessage +import net.corda.node.services.messaging.generateShardId +import net.corda.node.services.statemachine.MessageType +import net.corda.node.services.statemachine.SessionId +import net.corda.nodeapi.internal.persistence.CordaPersistence +import net.corda.nodeapi.internal.persistence.DatabaseConfig +import net.corda.testing.core.SerializationEnvironmentRule +import net.corda.testing.internal.TestingNamedCacheFactory +import net.corda.testing.internal.configureDatabase +import net.corda.testing.node.MockServices +import org.assertj.core.api.Assertions.assertThat +import org.assertj.core.api.Assertions.assertThatThrownBy +import org.junit.After +import org.junit.Before +import org.junit.Rule +import org.junit.Test +import java.lang.IllegalArgumentException +import java.math.BigInteger +import java.time.Instant + +class P2PMessageDeduplicatorTest { + + companion object { + private const val TOPIC = "whatever" + private val DATA = ByteSequence.of("blah blah blah".toByteArray()) + private val SHARD_ID = generateShardId("some-flow-id") + private val SESSION_ID = SessionId(BigInteger.ONE) + private val TIMESTAMP = Instant.now() + private val SENDER = CordaX500Name("CordaWorld", "The Sea Devil", "NeverLand", "NL") + private const val SENDER_UUID = "some-sender-uuid" + private const val PLATFORM_VERSION = 42 + + private const val FIRST_SENDER_SEQ_NO = 10L + private const val LAST_SENDER_SEQ_NO = 35L + } + + @Rule + @JvmField + val testSerialization = SerializationEnvironmentRule() + + private lateinit var database: CordaPersistence + private lateinit var deduplicator: P2PMessageDeduplicator + + + @Before + fun setUp() { + val dataSourceProps = MockServices.makeTestDataSourceProperties() + database = configureDatabase(dataSourceProps, DatabaseConfig(), { null }, { null }, runMigrationScripts = true, allowHibernateToManageAppSchema = false) + deduplicator = P2PMessageDeduplicator(TestingNamedCacheFactory(), database) + } + + @After + fun tearDown() { + database.close() + } + + @Test(timeout=300_000) + fun `correctly deduplicates a session-init message`() { + val msgId = MessageIdentifier(MessageType.SESSION_INIT, SHARD_ID, SESSION_ID, 0, TIMESTAMP) + val receivedMessage = createMessage(msgId, FIRST_SENDER_SEQ_NO) + + assertThat(deduplicator.isDuplicateSessionInit(receivedMessage)).isFalse() + + processMessage(receivedMessage) + + assertThat(deduplicator.isDuplicateSessionInit(receivedMessage)).isTrue() + } + + @Test(timeout=300_000) + fun `fails when requested to deduplicate a non session-init message`() { + val msgId = MessageIdentifier(MessageType.DATA_MESSAGE, SHARD_ID, SESSION_ID, 3, TIMESTAMP) + val receivedMessage = createMessage(msgId, 25) + + assertThatThrownBy { deduplicator.isDuplicateSessionInit(receivedMessage) }.isInstanceOf(IllegalArgumentException::class.java) + .hasMessageContaining("was not a session-init message") + } + + @Test(timeout=300_000) + fun `updates session data correctly when session is completed`() { + val msgId = MessageIdentifier(MessageType.SESSION_INIT, SHARD_ID, SESSION_ID, 0, TIMESTAMP) + val sessionInitMessage = createMessage(msgId, FIRST_SENDER_SEQ_NO) + + processMessage(sessionInitMessage) + + val sessionDataAfterSessionInit = database.transaction { + entityManager.find(P2PMessageDeduplicator.SessionData::class.java, SESSION_ID.value) + } + assertThat(sessionDataAfterSessionInit.firstSenderSeqNo).isEqualTo(FIRST_SENDER_SEQ_NO) + assertThat(sessionDataAfterSessionInit.lastSenderSeqNo).isNull() + assertThat(sessionDataAfterSessionInit.generationTime).isEqualTo(TIMESTAMP) + + database.transaction { + deduplicator.signalSessionEnd(SESSION_ID, SENDER_UUID, LAST_SENDER_SEQ_NO) + } + + val sessionDataAfterSessionEnd = database.transaction { + entityManager.find(P2PMessageDeduplicator.SessionData::class.java, SESSION_ID.value) + } + assertThat(sessionDataAfterSessionEnd.firstSenderSeqNo).isEqualTo(FIRST_SENDER_SEQ_NO) + assertThat(sessionDataAfterSessionEnd.lastSenderSeqNo).isEqualTo(LAST_SENDER_SEQ_NO) + assertThat(sessionDataAfterSessionEnd.generationTime).isEqualTo(TIMESTAMP) + } + + private fun processMessage(receivedMessage: ReceivedMessage) { + deduplicator.isDuplicateSessionInit(receivedMessage) + + deduplicator.signalMessageProcessStart(receivedMessage) + database.transaction { + deduplicator.persistDeduplicationId(receivedMessage.uniqueMessageId) + } + deduplicator.signalMessageProcessFinish(receivedMessage.uniqueMessageId) + } + + private fun createMessage(msgId: MessageIdentifier, senderSeqNo: Long?): ReceivedMessage { + return MockReceivedMessage(TOPIC, DATA, TIMESTAMP, msgId, SENDER_UUID, emptyMap(), SENDER, PLATFORM_VERSION, senderSeqNo, msgId.messageType == MessageType.SESSION_INIT) + } + + data class MockReceivedMessage(override val topic: String, + override val data: ByteSequence, + override val debugTimestamp: Instant, + override val uniqueMessageId: MessageIdentifier, + override val senderUUID: String?, + override val additionalHeaders: Map, + override val peer: CordaX500Name, + override val platformVersion: Int, + override val senderSeqNo: Long?, + override val isSessionInit: Boolean): ReceivedMessage +} \ No newline at end of file diff --git a/node/src/integration-test/kotlin/net/corda/services/messaging/SessionDataPersistenceTest.kt b/node/src/integration-test/kotlin/net/corda/services/messaging/SessionDataPersistenceTest.kt new file mode 100644 index 0000000000..258535ac1e --- /dev/null +++ b/node/src/integration-test/kotlin/net/corda/services/messaging/SessionDataPersistenceTest.kt @@ -0,0 +1,101 @@ +package net.corda.services.messaging + +import co.paralleluniverse.fibers.Suspendable +import net.corda.core.flows.FlowLogic +import net.corda.core.flows.FlowSession +import net.corda.core.flows.InitiatedBy +import net.corda.core.flows.InitiatingFlow +import net.corda.core.flows.StartableByRPC +import net.corda.core.identity.Party +import net.corda.core.internal.concurrent.transpose +import net.corda.core.messaging.startFlow +import net.corda.core.utilities.getOrThrow +import net.corda.core.utilities.unwrap +import net.corda.node.services.Permissions +import net.corda.testing.core.ALICE_NAME +import net.corda.testing.core.BOB_NAME +import net.corda.testing.driver.DriverParameters +import net.corda.testing.driver.driver +import net.corda.testing.node.User +import net.corda.testing.node.internal.enclosedCordapp +import org.assertj.core.api.Assertions.assertThat +import org.junit.Test + +class SessionDataPersistenceTest { + + private val user = User("u", "p", setOf(Permissions.all())) + + @Test(timeout=300_000) + fun `session data are persisted successfully and with the appropriate sequence numbers`() { + driver(DriverParameters(startNodesInProcess = true, notarySpecs = emptyList(), cordappsForAllNodes = setOf(enclosedCordapp()))) { + val (alice, bob) = listOf( + startNode(providedName = ALICE_NAME, rpcUsers = listOf(user)), + startNode(providedName = BOB_NAME)).transpose().getOrThrow() + + val numberOfMessages = 3 + alice.rpc.startFlow(::InitiatorFlow, bob.nodeInfo.legalIdentities.first(), numberOfMessages).returnValue.get() + + // session data are not maintained for the initiator side. + val aliceLastSeqNumbers = alice.rpc.startFlow(::GetSeqNumbersFlow).returnValue.get() + assertThat(aliceLastSeqNumbers).isEmpty() + + // only one flow here, so sender sequence number is expected to start from zero and increment by one for each message. + val bobLastSeqNumbers = bob.rpc.startFlow(::GetSeqNumbersFlow).returnValue.get() + assertThat(bobLastSeqNumbers).hasSize(1) + assertThat(bobLastSeqNumbers).first().isEqualTo(Pair(0, numberOfMessages - 1)) + } + } + + @StartableByRPC + @InitiatingFlow + class InitiatorFlow(private val otherParty: Party, private val numberOfMessages: Int) : FlowLogic() { + @Suspendable + override fun call() { + val session = initiateFlow(otherParty) + session.send(numberOfMessages) + + (2 .. numberOfMessages).forEach { + session.send("message $it") + } + + session.receive() + } + } + + @InitiatedBy(InitiatorFlow::class) + open class ResponderFlow(private val otherPartySession: FlowSession) : FlowLogic() { + @Suspendable + override fun call() { + val numberOfMessages = otherPartySession.receive().unwrap { it } + + (2 .. numberOfMessages).forEach { + otherPartySession.receive() + } + + otherPartySession.send("Got them all") + } + } + + @StartableByRPC + class GetSeqNumbersFlow: FlowLogic>>() { + @Suspendable + override fun call(): MutableList> { + return getSeqNumbers() + } + + private fun getSeqNumbers(): MutableList> { + val sequenceNumbers = mutableListOf>() + serviceHub.jdbcSession().createStatement().use { stmt -> + stmt.execute("SELECT init_sequence_number, last_sequence_number FROM node_session_data") + while (stmt.resultSet.next()) { + val firstSeqNo = stmt.resultSet.getInt(1) + val lastSeqNo = stmt.resultSet.getInt(2) + sequenceNumbers += Pair(firstSeqNo, lastSeqNo) + } + } + + return sequenceNumbers + } + } + +} \ No newline at end of file