mirror of
https://github.com/corda/corda.git
synced 2025-01-03 03:36:48 +00:00
Add e2e/unit tests for session data persistence
This commit is contained in:
parent
f4fa08ed10
commit
fcee4ed7cb
@ -0,0 +1,134 @@
|
|||||||
|
package net.corda.services.messaging
|
||||||
|
|
||||||
|
import net.corda.core.identity.CordaX500Name
|
||||||
|
import net.corda.core.utilities.ByteSequence
|
||||||
|
import net.corda.node.services.messaging.MessageIdentifier
|
||||||
|
import net.corda.node.services.messaging.P2PMessageDeduplicator
|
||||||
|
import net.corda.node.services.messaging.ReceivedMessage
|
||||||
|
import net.corda.node.services.messaging.generateShardId
|
||||||
|
import net.corda.node.services.statemachine.MessageType
|
||||||
|
import net.corda.node.services.statemachine.SessionId
|
||||||
|
import net.corda.nodeapi.internal.persistence.CordaPersistence
|
||||||
|
import net.corda.nodeapi.internal.persistence.DatabaseConfig
|
||||||
|
import net.corda.testing.core.SerializationEnvironmentRule
|
||||||
|
import net.corda.testing.internal.TestingNamedCacheFactory
|
||||||
|
import net.corda.testing.internal.configureDatabase
|
||||||
|
import net.corda.testing.node.MockServices
|
||||||
|
import org.assertj.core.api.Assertions.assertThat
|
||||||
|
import org.assertj.core.api.Assertions.assertThatThrownBy
|
||||||
|
import org.junit.After
|
||||||
|
import org.junit.Before
|
||||||
|
import org.junit.Rule
|
||||||
|
import org.junit.Test
|
||||||
|
import java.lang.IllegalArgumentException
|
||||||
|
import java.math.BigInteger
|
||||||
|
import java.time.Instant
|
||||||
|
|
||||||
|
class P2PMessageDeduplicatorTest {
|
||||||
|
|
||||||
|
companion object {
|
||||||
|
private const val TOPIC = "whatever"
|
||||||
|
private val DATA = ByteSequence.of("blah blah blah".toByteArray())
|
||||||
|
private val SHARD_ID = generateShardId("some-flow-id")
|
||||||
|
private val SESSION_ID = SessionId(BigInteger.ONE)
|
||||||
|
private val TIMESTAMP = Instant.now()
|
||||||
|
private val SENDER = CordaX500Name("CordaWorld", "The Sea Devil", "NeverLand", "NL")
|
||||||
|
private const val SENDER_UUID = "some-sender-uuid"
|
||||||
|
private const val PLATFORM_VERSION = 42
|
||||||
|
|
||||||
|
private const val FIRST_SENDER_SEQ_NO = 10L
|
||||||
|
private const val LAST_SENDER_SEQ_NO = 35L
|
||||||
|
}
|
||||||
|
|
||||||
|
@Rule
|
||||||
|
@JvmField
|
||||||
|
val testSerialization = SerializationEnvironmentRule()
|
||||||
|
|
||||||
|
private lateinit var database: CordaPersistence
|
||||||
|
private lateinit var deduplicator: P2PMessageDeduplicator
|
||||||
|
|
||||||
|
|
||||||
|
@Before
|
||||||
|
fun setUp() {
|
||||||
|
val dataSourceProps = MockServices.makeTestDataSourceProperties()
|
||||||
|
database = configureDatabase(dataSourceProps, DatabaseConfig(), { null }, { null }, runMigrationScripts = true, allowHibernateToManageAppSchema = false)
|
||||||
|
deduplicator = P2PMessageDeduplicator(TestingNamedCacheFactory(), database)
|
||||||
|
}
|
||||||
|
|
||||||
|
@After
|
||||||
|
fun tearDown() {
|
||||||
|
database.close()
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test(timeout=300_000)
|
||||||
|
fun `correctly deduplicates a session-init message`() {
|
||||||
|
val msgId = MessageIdentifier(MessageType.SESSION_INIT, SHARD_ID, SESSION_ID, 0, TIMESTAMP)
|
||||||
|
val receivedMessage = createMessage(msgId, FIRST_SENDER_SEQ_NO)
|
||||||
|
|
||||||
|
assertThat(deduplicator.isDuplicateSessionInit(receivedMessage)).isFalse()
|
||||||
|
|
||||||
|
processMessage(receivedMessage)
|
||||||
|
|
||||||
|
assertThat(deduplicator.isDuplicateSessionInit(receivedMessage)).isTrue()
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test(timeout=300_000)
|
||||||
|
fun `fails when requested to deduplicate a non session-init message`() {
|
||||||
|
val msgId = MessageIdentifier(MessageType.DATA_MESSAGE, SHARD_ID, SESSION_ID, 3, TIMESTAMP)
|
||||||
|
val receivedMessage = createMessage(msgId, 25)
|
||||||
|
|
||||||
|
assertThatThrownBy { deduplicator.isDuplicateSessionInit(receivedMessage) }.isInstanceOf(IllegalArgumentException::class.java)
|
||||||
|
.hasMessageContaining("was not a session-init message")
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test(timeout=300_000)
|
||||||
|
fun `updates session data correctly when session is completed`() {
|
||||||
|
val msgId = MessageIdentifier(MessageType.SESSION_INIT, SHARD_ID, SESSION_ID, 0, TIMESTAMP)
|
||||||
|
val sessionInitMessage = createMessage(msgId, FIRST_SENDER_SEQ_NO)
|
||||||
|
|
||||||
|
processMessage(sessionInitMessage)
|
||||||
|
|
||||||
|
val sessionDataAfterSessionInit = database.transaction {
|
||||||
|
entityManager.find(P2PMessageDeduplicator.SessionData::class.java, SESSION_ID.value)
|
||||||
|
}
|
||||||
|
assertThat(sessionDataAfterSessionInit.firstSenderSeqNo).isEqualTo(FIRST_SENDER_SEQ_NO)
|
||||||
|
assertThat(sessionDataAfterSessionInit.lastSenderSeqNo).isNull()
|
||||||
|
assertThat(sessionDataAfterSessionInit.generationTime).isEqualTo(TIMESTAMP)
|
||||||
|
|
||||||
|
database.transaction {
|
||||||
|
deduplicator.signalSessionEnd(SESSION_ID, SENDER_UUID, LAST_SENDER_SEQ_NO)
|
||||||
|
}
|
||||||
|
|
||||||
|
val sessionDataAfterSessionEnd = database.transaction {
|
||||||
|
entityManager.find(P2PMessageDeduplicator.SessionData::class.java, SESSION_ID.value)
|
||||||
|
}
|
||||||
|
assertThat(sessionDataAfterSessionEnd.firstSenderSeqNo).isEqualTo(FIRST_SENDER_SEQ_NO)
|
||||||
|
assertThat(sessionDataAfterSessionEnd.lastSenderSeqNo).isEqualTo(LAST_SENDER_SEQ_NO)
|
||||||
|
assertThat(sessionDataAfterSessionEnd.generationTime).isEqualTo(TIMESTAMP)
|
||||||
|
}
|
||||||
|
|
||||||
|
private fun processMessage(receivedMessage: ReceivedMessage) {
|
||||||
|
deduplicator.isDuplicateSessionInit(receivedMessage)
|
||||||
|
|
||||||
|
deduplicator.signalMessageProcessStart(receivedMessage)
|
||||||
|
database.transaction {
|
||||||
|
deduplicator.persistDeduplicationId(receivedMessage.uniqueMessageId)
|
||||||
|
}
|
||||||
|
deduplicator.signalMessageProcessFinish(receivedMessage.uniqueMessageId)
|
||||||
|
}
|
||||||
|
|
||||||
|
private fun createMessage(msgId: MessageIdentifier, senderSeqNo: Long?): ReceivedMessage {
|
||||||
|
return MockReceivedMessage(TOPIC, DATA, TIMESTAMP, msgId, SENDER_UUID, emptyMap(), SENDER, PLATFORM_VERSION, senderSeqNo, msgId.messageType == MessageType.SESSION_INIT)
|
||||||
|
}
|
||||||
|
|
||||||
|
data class MockReceivedMessage(override val topic: String,
|
||||||
|
override val data: ByteSequence,
|
||||||
|
override val debugTimestamp: Instant,
|
||||||
|
override val uniqueMessageId: MessageIdentifier,
|
||||||
|
override val senderUUID: String?,
|
||||||
|
override val additionalHeaders: Map<String, String>,
|
||||||
|
override val peer: CordaX500Name,
|
||||||
|
override val platformVersion: Int,
|
||||||
|
override val senderSeqNo: Long?,
|
||||||
|
override val isSessionInit: Boolean): ReceivedMessage
|
||||||
|
}
|
@ -0,0 +1,101 @@
|
|||||||
|
package net.corda.services.messaging
|
||||||
|
|
||||||
|
import co.paralleluniverse.fibers.Suspendable
|
||||||
|
import net.corda.core.flows.FlowLogic
|
||||||
|
import net.corda.core.flows.FlowSession
|
||||||
|
import net.corda.core.flows.InitiatedBy
|
||||||
|
import net.corda.core.flows.InitiatingFlow
|
||||||
|
import net.corda.core.flows.StartableByRPC
|
||||||
|
import net.corda.core.identity.Party
|
||||||
|
import net.corda.core.internal.concurrent.transpose
|
||||||
|
import net.corda.core.messaging.startFlow
|
||||||
|
import net.corda.core.utilities.getOrThrow
|
||||||
|
import net.corda.core.utilities.unwrap
|
||||||
|
import net.corda.node.services.Permissions
|
||||||
|
import net.corda.testing.core.ALICE_NAME
|
||||||
|
import net.corda.testing.core.BOB_NAME
|
||||||
|
import net.corda.testing.driver.DriverParameters
|
||||||
|
import net.corda.testing.driver.driver
|
||||||
|
import net.corda.testing.node.User
|
||||||
|
import net.corda.testing.node.internal.enclosedCordapp
|
||||||
|
import org.assertj.core.api.Assertions.assertThat
|
||||||
|
import org.junit.Test
|
||||||
|
|
||||||
|
class SessionDataPersistenceTest {
|
||||||
|
|
||||||
|
private val user = User("u", "p", setOf(Permissions.all()))
|
||||||
|
|
||||||
|
@Test(timeout=300_000)
|
||||||
|
fun `session data are persisted successfully and with the appropriate sequence numbers`() {
|
||||||
|
driver(DriverParameters(startNodesInProcess = true, notarySpecs = emptyList(), cordappsForAllNodes = setOf(enclosedCordapp()))) {
|
||||||
|
val (alice, bob) = listOf(
|
||||||
|
startNode(providedName = ALICE_NAME, rpcUsers = listOf(user)),
|
||||||
|
startNode(providedName = BOB_NAME)).transpose().getOrThrow()
|
||||||
|
|
||||||
|
val numberOfMessages = 3
|
||||||
|
alice.rpc.startFlow(::InitiatorFlow, bob.nodeInfo.legalIdentities.first(), numberOfMessages).returnValue.get()
|
||||||
|
|
||||||
|
// session data are not maintained for the initiator side.
|
||||||
|
val aliceLastSeqNumbers = alice.rpc.startFlow(::GetSeqNumbersFlow).returnValue.get()
|
||||||
|
assertThat(aliceLastSeqNumbers).isEmpty()
|
||||||
|
|
||||||
|
// only one flow here, so sender sequence number is expected to start from zero and increment by one for each message.
|
||||||
|
val bobLastSeqNumbers = bob.rpc.startFlow(::GetSeqNumbersFlow).returnValue.get()
|
||||||
|
assertThat(bobLastSeqNumbers).hasSize(1)
|
||||||
|
assertThat(bobLastSeqNumbers).first().isEqualTo(Pair(0, numberOfMessages - 1))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@StartableByRPC
|
||||||
|
@InitiatingFlow
|
||||||
|
class InitiatorFlow(private val otherParty: Party, private val numberOfMessages: Int) : FlowLogic<Unit>() {
|
||||||
|
@Suspendable
|
||||||
|
override fun call() {
|
||||||
|
val session = initiateFlow(otherParty)
|
||||||
|
session.send(numberOfMessages)
|
||||||
|
|
||||||
|
(2 .. numberOfMessages).forEach {
|
||||||
|
session.send("message $it")
|
||||||
|
}
|
||||||
|
|
||||||
|
session.receive<String>()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@InitiatedBy(InitiatorFlow::class)
|
||||||
|
open class ResponderFlow(private val otherPartySession: FlowSession) : FlowLogic<Unit>() {
|
||||||
|
@Suspendable
|
||||||
|
override fun call() {
|
||||||
|
val numberOfMessages = otherPartySession.receive<Int>().unwrap { it }
|
||||||
|
|
||||||
|
(2 .. numberOfMessages).forEach {
|
||||||
|
otherPartySession.receive<String>()
|
||||||
|
}
|
||||||
|
|
||||||
|
otherPartySession.send("Got them all")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@StartableByRPC
|
||||||
|
class GetSeqNumbersFlow: FlowLogic<MutableList<Pair<Int, Int>>>() {
|
||||||
|
@Suspendable
|
||||||
|
override fun call(): MutableList<Pair<Int, Int>> {
|
||||||
|
return getSeqNumbers()
|
||||||
|
}
|
||||||
|
|
||||||
|
private fun getSeqNumbers(): MutableList<Pair<Int, Int>> {
|
||||||
|
val sequenceNumbers = mutableListOf<Pair<Int, Int>>()
|
||||||
|
serviceHub.jdbcSession().createStatement().use { stmt ->
|
||||||
|
stmt.execute("SELECT init_sequence_number, last_sequence_number FROM node_session_data")
|
||||||
|
while (stmt.resultSet.next()) {
|
||||||
|
val firstSeqNo = stmt.resultSet.getInt(1)
|
||||||
|
val lastSeqNo = stmt.resultSet.getInt(2)
|
||||||
|
sequenceNumbers += Pair(firstSeqNo, lastSeqNo)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return sequenceNumbers
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user