Add e2e/unit tests for session data persistence

This commit is contained in:
Dimos Raptis 2020-09-24 10:50:33 +01:00
parent f4fa08ed10
commit fcee4ed7cb
2 changed files with 235 additions and 0 deletions

View File

@ -0,0 +1,134 @@
package net.corda.services.messaging
import net.corda.core.identity.CordaX500Name
import net.corda.core.utilities.ByteSequence
import net.corda.node.services.messaging.MessageIdentifier
import net.corda.node.services.messaging.P2PMessageDeduplicator
import net.corda.node.services.messaging.ReceivedMessage
import net.corda.node.services.messaging.generateShardId
import net.corda.node.services.statemachine.MessageType
import net.corda.node.services.statemachine.SessionId
import net.corda.nodeapi.internal.persistence.CordaPersistence
import net.corda.nodeapi.internal.persistence.DatabaseConfig
import net.corda.testing.core.SerializationEnvironmentRule
import net.corda.testing.internal.TestingNamedCacheFactory
import net.corda.testing.internal.configureDatabase
import net.corda.testing.node.MockServices
import org.assertj.core.api.Assertions.assertThat
import org.assertj.core.api.Assertions.assertThatThrownBy
import org.junit.After
import org.junit.Before
import org.junit.Rule
import org.junit.Test
import java.lang.IllegalArgumentException
import java.math.BigInteger
import java.time.Instant
class P2PMessageDeduplicatorTest {
companion object {
private const val TOPIC = "whatever"
private val DATA = ByteSequence.of("blah blah blah".toByteArray())
private val SHARD_ID = generateShardId("some-flow-id")
private val SESSION_ID = SessionId(BigInteger.ONE)
private val TIMESTAMP = Instant.now()
private val SENDER = CordaX500Name("CordaWorld", "The Sea Devil", "NeverLand", "NL")
private const val SENDER_UUID = "some-sender-uuid"
private const val PLATFORM_VERSION = 42
private const val FIRST_SENDER_SEQ_NO = 10L
private const val LAST_SENDER_SEQ_NO = 35L
}
@Rule
@JvmField
val testSerialization = SerializationEnvironmentRule()
private lateinit var database: CordaPersistence
private lateinit var deduplicator: P2PMessageDeduplicator
@Before
fun setUp() {
val dataSourceProps = MockServices.makeTestDataSourceProperties()
database = configureDatabase(dataSourceProps, DatabaseConfig(), { null }, { null }, runMigrationScripts = true, allowHibernateToManageAppSchema = false)
deduplicator = P2PMessageDeduplicator(TestingNamedCacheFactory(), database)
}
@After
fun tearDown() {
database.close()
}
@Test(timeout=300_000)
fun `correctly deduplicates a session-init message`() {
val msgId = MessageIdentifier(MessageType.SESSION_INIT, SHARD_ID, SESSION_ID, 0, TIMESTAMP)
val receivedMessage = createMessage(msgId, FIRST_SENDER_SEQ_NO)
assertThat(deduplicator.isDuplicateSessionInit(receivedMessage)).isFalse()
processMessage(receivedMessage)
assertThat(deduplicator.isDuplicateSessionInit(receivedMessage)).isTrue()
}
@Test(timeout=300_000)
fun `fails when requested to deduplicate a non session-init message`() {
val msgId = MessageIdentifier(MessageType.DATA_MESSAGE, SHARD_ID, SESSION_ID, 3, TIMESTAMP)
val receivedMessage = createMessage(msgId, 25)
assertThatThrownBy { deduplicator.isDuplicateSessionInit(receivedMessage) }.isInstanceOf(IllegalArgumentException::class.java)
.hasMessageContaining("was not a session-init message")
}
@Test(timeout=300_000)
fun `updates session data correctly when session is completed`() {
val msgId = MessageIdentifier(MessageType.SESSION_INIT, SHARD_ID, SESSION_ID, 0, TIMESTAMP)
val sessionInitMessage = createMessage(msgId, FIRST_SENDER_SEQ_NO)
processMessage(sessionInitMessage)
val sessionDataAfterSessionInit = database.transaction {
entityManager.find(P2PMessageDeduplicator.SessionData::class.java, SESSION_ID.value)
}
assertThat(sessionDataAfterSessionInit.firstSenderSeqNo).isEqualTo(FIRST_SENDER_SEQ_NO)
assertThat(sessionDataAfterSessionInit.lastSenderSeqNo).isNull()
assertThat(sessionDataAfterSessionInit.generationTime).isEqualTo(TIMESTAMP)
database.transaction {
deduplicator.signalSessionEnd(SESSION_ID, SENDER_UUID, LAST_SENDER_SEQ_NO)
}
val sessionDataAfterSessionEnd = database.transaction {
entityManager.find(P2PMessageDeduplicator.SessionData::class.java, SESSION_ID.value)
}
assertThat(sessionDataAfterSessionEnd.firstSenderSeqNo).isEqualTo(FIRST_SENDER_SEQ_NO)
assertThat(sessionDataAfterSessionEnd.lastSenderSeqNo).isEqualTo(LAST_SENDER_SEQ_NO)
assertThat(sessionDataAfterSessionEnd.generationTime).isEqualTo(TIMESTAMP)
}
private fun processMessage(receivedMessage: ReceivedMessage) {
deduplicator.isDuplicateSessionInit(receivedMessage)
deduplicator.signalMessageProcessStart(receivedMessage)
database.transaction {
deduplicator.persistDeduplicationId(receivedMessage.uniqueMessageId)
}
deduplicator.signalMessageProcessFinish(receivedMessage.uniqueMessageId)
}
private fun createMessage(msgId: MessageIdentifier, senderSeqNo: Long?): ReceivedMessage {
return MockReceivedMessage(TOPIC, DATA, TIMESTAMP, msgId, SENDER_UUID, emptyMap(), SENDER, PLATFORM_VERSION, senderSeqNo, msgId.messageType == MessageType.SESSION_INIT)
}
data class MockReceivedMessage(override val topic: String,
override val data: ByteSequence,
override val debugTimestamp: Instant,
override val uniqueMessageId: MessageIdentifier,
override val senderUUID: String?,
override val additionalHeaders: Map<String, String>,
override val peer: CordaX500Name,
override val platformVersion: Int,
override val senderSeqNo: Long?,
override val isSessionInit: Boolean): ReceivedMessage
}

View File

@ -0,0 +1,101 @@
package net.corda.services.messaging
import co.paralleluniverse.fibers.Suspendable
import net.corda.core.flows.FlowLogic
import net.corda.core.flows.FlowSession
import net.corda.core.flows.InitiatedBy
import net.corda.core.flows.InitiatingFlow
import net.corda.core.flows.StartableByRPC
import net.corda.core.identity.Party
import net.corda.core.internal.concurrent.transpose
import net.corda.core.messaging.startFlow
import net.corda.core.utilities.getOrThrow
import net.corda.core.utilities.unwrap
import net.corda.node.services.Permissions
import net.corda.testing.core.ALICE_NAME
import net.corda.testing.core.BOB_NAME
import net.corda.testing.driver.DriverParameters
import net.corda.testing.driver.driver
import net.corda.testing.node.User
import net.corda.testing.node.internal.enclosedCordapp
import org.assertj.core.api.Assertions.assertThat
import org.junit.Test
class SessionDataPersistenceTest {
private val user = User("u", "p", setOf(Permissions.all()))
@Test(timeout=300_000)
fun `session data are persisted successfully and with the appropriate sequence numbers`() {
driver(DriverParameters(startNodesInProcess = true, notarySpecs = emptyList(), cordappsForAllNodes = setOf(enclosedCordapp()))) {
val (alice, bob) = listOf(
startNode(providedName = ALICE_NAME, rpcUsers = listOf(user)),
startNode(providedName = BOB_NAME)).transpose().getOrThrow()
val numberOfMessages = 3
alice.rpc.startFlow(::InitiatorFlow, bob.nodeInfo.legalIdentities.first(), numberOfMessages).returnValue.get()
// session data are not maintained for the initiator side.
val aliceLastSeqNumbers = alice.rpc.startFlow(::GetSeqNumbersFlow).returnValue.get()
assertThat(aliceLastSeqNumbers).isEmpty()
// only one flow here, so sender sequence number is expected to start from zero and increment by one for each message.
val bobLastSeqNumbers = bob.rpc.startFlow(::GetSeqNumbersFlow).returnValue.get()
assertThat(bobLastSeqNumbers).hasSize(1)
assertThat(bobLastSeqNumbers).first().isEqualTo(Pair(0, numberOfMessages - 1))
}
}
@StartableByRPC
@InitiatingFlow
class InitiatorFlow(private val otherParty: Party, private val numberOfMessages: Int) : FlowLogic<Unit>() {
@Suspendable
override fun call() {
val session = initiateFlow(otherParty)
session.send(numberOfMessages)
(2 .. numberOfMessages).forEach {
session.send("message $it")
}
session.receive<String>()
}
}
@InitiatedBy(InitiatorFlow::class)
open class ResponderFlow(private val otherPartySession: FlowSession) : FlowLogic<Unit>() {
@Suspendable
override fun call() {
val numberOfMessages = otherPartySession.receive<Int>().unwrap { it }
(2 .. numberOfMessages).forEach {
otherPartySession.receive<String>()
}
otherPartySession.send("Got them all")
}
}
@StartableByRPC
class GetSeqNumbersFlow: FlowLogic<MutableList<Pair<Int, Int>>>() {
@Suspendable
override fun call(): MutableList<Pair<Int, Int>> {
return getSeqNumbers()
}
private fun getSeqNumbers(): MutableList<Pair<Int, Int>> {
val sequenceNumbers = mutableListOf<Pair<Int, Int>>()
serviceHub.jdbcSession().createStatement().use { stmt ->
stmt.execute("SELECT init_sequence_number, last_sequence_number FROM node_session_data")
while (stmt.resultSet.next()) {
val firstSeqNo = stmt.resultSet.getInt(1)
val lastSeqNo = stmt.resultSet.getInt(2)
sequenceNumbers += Pair(firstSeqNo, lastSeqNo)
}
}
return sequenceNumbers
}
}
}