mirror of
https://github.com/corda/corda.git
synced 2024-12-22 06:17:55 +00:00
Add e2e/unit tests for session data persistence
This commit is contained in:
parent
f4fa08ed10
commit
fcee4ed7cb
@ -0,0 +1,134 @@
|
||||
package net.corda.services.messaging
|
||||
|
||||
import net.corda.core.identity.CordaX500Name
|
||||
import net.corda.core.utilities.ByteSequence
|
||||
import net.corda.node.services.messaging.MessageIdentifier
|
||||
import net.corda.node.services.messaging.P2PMessageDeduplicator
|
||||
import net.corda.node.services.messaging.ReceivedMessage
|
||||
import net.corda.node.services.messaging.generateShardId
|
||||
import net.corda.node.services.statemachine.MessageType
|
||||
import net.corda.node.services.statemachine.SessionId
|
||||
import net.corda.nodeapi.internal.persistence.CordaPersistence
|
||||
import net.corda.nodeapi.internal.persistence.DatabaseConfig
|
||||
import net.corda.testing.core.SerializationEnvironmentRule
|
||||
import net.corda.testing.internal.TestingNamedCacheFactory
|
||||
import net.corda.testing.internal.configureDatabase
|
||||
import net.corda.testing.node.MockServices
|
||||
import org.assertj.core.api.Assertions.assertThat
|
||||
import org.assertj.core.api.Assertions.assertThatThrownBy
|
||||
import org.junit.After
|
||||
import org.junit.Before
|
||||
import org.junit.Rule
|
||||
import org.junit.Test
|
||||
import java.lang.IllegalArgumentException
|
||||
import java.math.BigInteger
|
||||
import java.time.Instant
|
||||
|
||||
class P2PMessageDeduplicatorTest {
|
||||
|
||||
companion object {
|
||||
private const val TOPIC = "whatever"
|
||||
private val DATA = ByteSequence.of("blah blah blah".toByteArray())
|
||||
private val SHARD_ID = generateShardId("some-flow-id")
|
||||
private val SESSION_ID = SessionId(BigInteger.ONE)
|
||||
private val TIMESTAMP = Instant.now()
|
||||
private val SENDER = CordaX500Name("CordaWorld", "The Sea Devil", "NeverLand", "NL")
|
||||
private const val SENDER_UUID = "some-sender-uuid"
|
||||
private const val PLATFORM_VERSION = 42
|
||||
|
||||
private const val FIRST_SENDER_SEQ_NO = 10L
|
||||
private const val LAST_SENDER_SEQ_NO = 35L
|
||||
}
|
||||
|
||||
@Rule
|
||||
@JvmField
|
||||
val testSerialization = SerializationEnvironmentRule()
|
||||
|
||||
private lateinit var database: CordaPersistence
|
||||
private lateinit var deduplicator: P2PMessageDeduplicator
|
||||
|
||||
|
||||
@Before
|
||||
fun setUp() {
|
||||
val dataSourceProps = MockServices.makeTestDataSourceProperties()
|
||||
database = configureDatabase(dataSourceProps, DatabaseConfig(), { null }, { null }, runMigrationScripts = true, allowHibernateToManageAppSchema = false)
|
||||
deduplicator = P2PMessageDeduplicator(TestingNamedCacheFactory(), database)
|
||||
}
|
||||
|
||||
@After
|
||||
fun tearDown() {
|
||||
database.close()
|
||||
}
|
||||
|
||||
@Test(timeout=300_000)
|
||||
fun `correctly deduplicates a session-init message`() {
|
||||
val msgId = MessageIdentifier(MessageType.SESSION_INIT, SHARD_ID, SESSION_ID, 0, TIMESTAMP)
|
||||
val receivedMessage = createMessage(msgId, FIRST_SENDER_SEQ_NO)
|
||||
|
||||
assertThat(deduplicator.isDuplicateSessionInit(receivedMessage)).isFalse()
|
||||
|
||||
processMessage(receivedMessage)
|
||||
|
||||
assertThat(deduplicator.isDuplicateSessionInit(receivedMessage)).isTrue()
|
||||
}
|
||||
|
||||
@Test(timeout=300_000)
|
||||
fun `fails when requested to deduplicate a non session-init message`() {
|
||||
val msgId = MessageIdentifier(MessageType.DATA_MESSAGE, SHARD_ID, SESSION_ID, 3, TIMESTAMP)
|
||||
val receivedMessage = createMessage(msgId, 25)
|
||||
|
||||
assertThatThrownBy { deduplicator.isDuplicateSessionInit(receivedMessage) }.isInstanceOf(IllegalArgumentException::class.java)
|
||||
.hasMessageContaining("was not a session-init message")
|
||||
}
|
||||
|
||||
@Test(timeout=300_000)
|
||||
fun `updates session data correctly when session is completed`() {
|
||||
val msgId = MessageIdentifier(MessageType.SESSION_INIT, SHARD_ID, SESSION_ID, 0, TIMESTAMP)
|
||||
val sessionInitMessage = createMessage(msgId, FIRST_SENDER_SEQ_NO)
|
||||
|
||||
processMessage(sessionInitMessage)
|
||||
|
||||
val sessionDataAfterSessionInit = database.transaction {
|
||||
entityManager.find(P2PMessageDeduplicator.SessionData::class.java, SESSION_ID.value)
|
||||
}
|
||||
assertThat(sessionDataAfterSessionInit.firstSenderSeqNo).isEqualTo(FIRST_SENDER_SEQ_NO)
|
||||
assertThat(sessionDataAfterSessionInit.lastSenderSeqNo).isNull()
|
||||
assertThat(sessionDataAfterSessionInit.generationTime).isEqualTo(TIMESTAMP)
|
||||
|
||||
database.transaction {
|
||||
deduplicator.signalSessionEnd(SESSION_ID, SENDER_UUID, LAST_SENDER_SEQ_NO)
|
||||
}
|
||||
|
||||
val sessionDataAfterSessionEnd = database.transaction {
|
||||
entityManager.find(P2PMessageDeduplicator.SessionData::class.java, SESSION_ID.value)
|
||||
}
|
||||
assertThat(sessionDataAfterSessionEnd.firstSenderSeqNo).isEqualTo(FIRST_SENDER_SEQ_NO)
|
||||
assertThat(sessionDataAfterSessionEnd.lastSenderSeqNo).isEqualTo(LAST_SENDER_SEQ_NO)
|
||||
assertThat(sessionDataAfterSessionEnd.generationTime).isEqualTo(TIMESTAMP)
|
||||
}
|
||||
|
||||
private fun processMessage(receivedMessage: ReceivedMessage) {
|
||||
deduplicator.isDuplicateSessionInit(receivedMessage)
|
||||
|
||||
deduplicator.signalMessageProcessStart(receivedMessage)
|
||||
database.transaction {
|
||||
deduplicator.persistDeduplicationId(receivedMessage.uniqueMessageId)
|
||||
}
|
||||
deduplicator.signalMessageProcessFinish(receivedMessage.uniqueMessageId)
|
||||
}
|
||||
|
||||
private fun createMessage(msgId: MessageIdentifier, senderSeqNo: Long?): ReceivedMessage {
|
||||
return MockReceivedMessage(TOPIC, DATA, TIMESTAMP, msgId, SENDER_UUID, emptyMap(), SENDER, PLATFORM_VERSION, senderSeqNo, msgId.messageType == MessageType.SESSION_INIT)
|
||||
}
|
||||
|
||||
data class MockReceivedMessage(override val topic: String,
|
||||
override val data: ByteSequence,
|
||||
override val debugTimestamp: Instant,
|
||||
override val uniqueMessageId: MessageIdentifier,
|
||||
override val senderUUID: String?,
|
||||
override val additionalHeaders: Map<String, String>,
|
||||
override val peer: CordaX500Name,
|
||||
override val platformVersion: Int,
|
||||
override val senderSeqNo: Long?,
|
||||
override val isSessionInit: Boolean): ReceivedMessage
|
||||
}
|
@ -0,0 +1,101 @@
|
||||
package net.corda.services.messaging
|
||||
|
||||
import co.paralleluniverse.fibers.Suspendable
|
||||
import net.corda.core.flows.FlowLogic
|
||||
import net.corda.core.flows.FlowSession
|
||||
import net.corda.core.flows.InitiatedBy
|
||||
import net.corda.core.flows.InitiatingFlow
|
||||
import net.corda.core.flows.StartableByRPC
|
||||
import net.corda.core.identity.Party
|
||||
import net.corda.core.internal.concurrent.transpose
|
||||
import net.corda.core.messaging.startFlow
|
||||
import net.corda.core.utilities.getOrThrow
|
||||
import net.corda.core.utilities.unwrap
|
||||
import net.corda.node.services.Permissions
|
||||
import net.corda.testing.core.ALICE_NAME
|
||||
import net.corda.testing.core.BOB_NAME
|
||||
import net.corda.testing.driver.DriverParameters
|
||||
import net.corda.testing.driver.driver
|
||||
import net.corda.testing.node.User
|
||||
import net.corda.testing.node.internal.enclosedCordapp
|
||||
import org.assertj.core.api.Assertions.assertThat
|
||||
import org.junit.Test
|
||||
|
||||
class SessionDataPersistenceTest {
|
||||
|
||||
private val user = User("u", "p", setOf(Permissions.all()))
|
||||
|
||||
@Test(timeout=300_000)
|
||||
fun `session data are persisted successfully and with the appropriate sequence numbers`() {
|
||||
driver(DriverParameters(startNodesInProcess = true, notarySpecs = emptyList(), cordappsForAllNodes = setOf(enclosedCordapp()))) {
|
||||
val (alice, bob) = listOf(
|
||||
startNode(providedName = ALICE_NAME, rpcUsers = listOf(user)),
|
||||
startNode(providedName = BOB_NAME)).transpose().getOrThrow()
|
||||
|
||||
val numberOfMessages = 3
|
||||
alice.rpc.startFlow(::InitiatorFlow, bob.nodeInfo.legalIdentities.first(), numberOfMessages).returnValue.get()
|
||||
|
||||
// session data are not maintained for the initiator side.
|
||||
val aliceLastSeqNumbers = alice.rpc.startFlow(::GetSeqNumbersFlow).returnValue.get()
|
||||
assertThat(aliceLastSeqNumbers).isEmpty()
|
||||
|
||||
// only one flow here, so sender sequence number is expected to start from zero and increment by one for each message.
|
||||
val bobLastSeqNumbers = bob.rpc.startFlow(::GetSeqNumbersFlow).returnValue.get()
|
||||
assertThat(bobLastSeqNumbers).hasSize(1)
|
||||
assertThat(bobLastSeqNumbers).first().isEqualTo(Pair(0, numberOfMessages - 1))
|
||||
}
|
||||
}
|
||||
|
||||
@StartableByRPC
|
||||
@InitiatingFlow
|
||||
class InitiatorFlow(private val otherParty: Party, private val numberOfMessages: Int) : FlowLogic<Unit>() {
|
||||
@Suspendable
|
||||
override fun call() {
|
||||
val session = initiateFlow(otherParty)
|
||||
session.send(numberOfMessages)
|
||||
|
||||
(2 .. numberOfMessages).forEach {
|
||||
session.send("message $it")
|
||||
}
|
||||
|
||||
session.receive<String>()
|
||||
}
|
||||
}
|
||||
|
||||
@InitiatedBy(InitiatorFlow::class)
|
||||
open class ResponderFlow(private val otherPartySession: FlowSession) : FlowLogic<Unit>() {
|
||||
@Suspendable
|
||||
override fun call() {
|
||||
val numberOfMessages = otherPartySession.receive<Int>().unwrap { it }
|
||||
|
||||
(2 .. numberOfMessages).forEach {
|
||||
otherPartySession.receive<String>()
|
||||
}
|
||||
|
||||
otherPartySession.send("Got them all")
|
||||
}
|
||||
}
|
||||
|
||||
@StartableByRPC
|
||||
class GetSeqNumbersFlow: FlowLogic<MutableList<Pair<Int, Int>>>() {
|
||||
@Suspendable
|
||||
override fun call(): MutableList<Pair<Int, Int>> {
|
||||
return getSeqNumbers()
|
||||
}
|
||||
|
||||
private fun getSeqNumbers(): MutableList<Pair<Int, Int>> {
|
||||
val sequenceNumbers = mutableListOf<Pair<Int, Int>>()
|
||||
serviceHub.jdbcSession().createStatement().use { stmt ->
|
||||
stmt.execute("SELECT init_sequence_number, last_sequence_number FROM node_session_data")
|
||||
while (stmt.resultSet.next()) {
|
||||
val firstSeqNo = stmt.resultSet.getInt(1)
|
||||
val lastSeqNo = stmt.resultSet.getInt(2)
|
||||
sequenceNumbers += Pair(firstSeqNo, lastSeqNo)
|
||||
}
|
||||
}
|
||||
|
||||
return sequenceNumbers
|
||||
}
|
||||
}
|
||||
|
||||
}
|
Loading…
Reference in New Issue
Block a user