Address Rick's comments

This commit is contained in:
Dimos Raptis 2020-09-29 15:38:24 +01:00
parent 1af6e89927
commit 0858b7852d
9 changed files with 46 additions and 27 deletions

View File

@ -89,7 +89,7 @@ class P2PMessageDeduplicatorTest {
processMessage(sessionInitMessage) processMessage(sessionInitMessage)
val sessionDataAfterSessionInit = database.transaction { val sessionDataAfterSessionInit = database.transaction {
entityManager.find(P2PMessageDeduplicator.SessionData::class.java, SESSION_ID.value) entityManager.find(P2PMessageDeduplicator.SessionData::class.java, SESSION_ID.toHex())
} }
assertThat(sessionDataAfterSessionInit.firstSenderSeqNo).isEqualTo(FIRST_SENDER_SEQ_NO) assertThat(sessionDataAfterSessionInit.firstSenderSeqNo).isEqualTo(FIRST_SENDER_SEQ_NO)
assertThat(sessionDataAfterSessionInit.lastSenderSeqNo).isNull() assertThat(sessionDataAfterSessionInit.lastSenderSeqNo).isNull()
@ -100,7 +100,7 @@ class P2PMessageDeduplicatorTest {
} }
val sessionDataAfterSessionEnd = database.transaction { val sessionDataAfterSessionEnd = database.transaction {
entityManager.find(P2PMessageDeduplicator.SessionData::class.java, SESSION_ID.value) entityManager.find(P2PMessageDeduplicator.SessionData::class.java, SESSION_ID.toHex())
} }
assertThat(sessionDataAfterSessionEnd.firstSenderSeqNo).isEqualTo(FIRST_SENDER_SEQ_NO) assertThat(sessionDataAfterSessionEnd.firstSenderSeqNo).isEqualTo(FIRST_SENDER_SEQ_NO)
assertThat(sessionDataAfterSessionEnd.lastSenderSeqNo).isEqualTo(LAST_SENDER_SEQ_NO) assertThat(sessionDataAfterSessionEnd.lastSenderSeqNo).isEqualTo(LAST_SENDER_SEQ_NO)

View File

@ -2,7 +2,6 @@ package net.corda.node.services.messaging
import net.corda.node.services.statemachine.MessageType import net.corda.node.services.statemachine.MessageType
import net.corda.node.services.statemachine.SessionId import net.corda.node.services.statemachine.SessionId
import java.math.BigInteger
import java.time.Instant import java.time.Instant
/** /**
@ -31,23 +30,22 @@ data class MessageIdentifier(
companion object { companion object {
const val SHARD_SIZE_IN_CHARS = 8 const val SHARD_SIZE_IN_CHARS = 8
const val LONG_SIZE_IN_HEX = 16 // 64 / 4 const val LONG_SIZE_IN_HEX = 16 // 64 / 4
const val SESSION_ID_SIZE_IN_HEX = SessionId.MAX_BIT_SIZE / 4 private const val HEX_RADIX = 16
const val HEX_RADIX = 16
fun parse(id: String): MessageIdentifier { fun parse(id: String): MessageIdentifier {
val prefix = id.substring(0, 2) val prefix = id.substring(0, 2)
val messageType = MessageType.fromPrefix(prefix) val messageType = MessageType.fromPrefix(prefix)
val timestamp = java.lang.Long.parseUnsignedLong(id.substring(3, 19), HEX_RADIX) val timestamp = java.lang.Long.parseUnsignedLong(id.substring(3, 19), HEX_RADIX)
val shardIdentifier = id.substring(20, 28) val shardIdentifier = id.substring(20, 28)
val sessionId = BigInteger(id.substring(29, 61), HEX_RADIX) val sessionId = SessionId.fromHex(id.substring(29, 61))
val sessionSequenceNumber = Integer.parseInt(id.substring(62), HEX_RADIX) val sessionSequenceNumber = Integer.parseInt(id.substring(62), HEX_RADIX)
return MessageIdentifier(messageType, shardIdentifier, SessionId(sessionId), sessionSequenceNumber, Instant.ofEpochMilli(timestamp)) return MessageIdentifier(messageType, shardIdentifier, sessionId, sessionSequenceNumber, Instant.ofEpochMilli(timestamp))
} }
} }
override fun toString(): String { override fun toString(): String {
val prefix = messageType.prefix val prefix = messageType.prefix
val encodedSessionIdentifier = String.format("%1$0${SESSION_ID_SIZE_IN_HEX}X", sessionIdentifier.value) val encodedSessionIdentifier = sessionIdentifier.toHex()
val encodedSequenceNumber = Integer.toHexString(sessionSequenceNumber).toUpperCase() val encodedSequenceNumber = Integer.toHexString(sessionSequenceNumber).toUpperCase()
val encodedTimestamp = String.format("%1$0${LONG_SIZE_IN_HEX}X", timestamp.toEpochMilli()) val encodedTimestamp = String.format("%1$0${LONG_SIZE_IN_HEX}X", timestamp.toEpochMilli())
return "$prefix-$encodedTimestamp-$shardIdentifier-$encodedSessionIdentifier-$encodedSequenceNumber" return "$prefix-$encodedTimestamp-$shardIdentifier-$encodedSessionIdentifier-$encodedSequenceNumber"

View File

@ -41,15 +41,15 @@ class P2PMessageDeduplicator(cacheFactory: NamedCacheFactory, private val databa
*/ */
private val sessionData = createSessionDataMap(cacheFactory) private val sessionData = createSessionDataMap(cacheFactory)
private fun createSessionDataMap(cacheFactory: NamedCacheFactory): AppendOnlyPersistentMap<SessionId, MessageMeta, SessionData, BigInteger> { private fun createSessionDataMap(cacheFactory: NamedCacheFactory): AppendOnlyPersistentMap<SessionId, MessageMeta, SessionData, String> {
return AppendOnlyPersistentMap( return AppendOnlyPersistentMap(
cacheFactory = cacheFactory, cacheFactory = cacheFactory,
name = "P2PMessageDeduplicator_sessionData", name = "P2PMessageDeduplicator_sessionData",
toPersistentEntityKey = { it.value }, toPersistentEntityKey = { it.toHex() },
fromPersistentEntity = { Pair(SessionId(it.sessionId), MessageMeta(it.generationTime, it.senderHash, it.firstSenderSeqNo, it.lastSenderSeqNo)) }, fromPersistentEntity = { Pair(SessionId.fromHex(it.sessionId), MessageMeta(it.generationTime, it.senderHash, it.firstSenderSeqNo, it.lastSenderSeqNo)) },
toPersistentEntity = { key: SessionId, value: MessageMeta -> toPersistentEntity = { key: SessionId, value: MessageMeta ->
SessionData().apply { SessionData().apply {
sessionId = key.value sessionId = key.toHex()
generationTime = value.generationTime generationTime = value.generationTime
senderHash = value.senderHash senderHash = value.senderHash
firstSenderSeqNo = value.firstSenderSeqNo firstSenderSeqNo = value.firstSenderSeqNo
@ -132,7 +132,7 @@ class P2PMessageDeduplicator(cacheFactory: NamedCacheFactory, private val databa
val criteriaUpdate = criteriaBuilder.createCriteriaUpdate(SessionData::class.java) val criteriaUpdate = criteriaBuilder.createCriteriaUpdate(SessionData::class.java)
val queryRoot = criteriaUpdate.from(SessionData::class.java) val queryRoot = criteriaUpdate.from(SessionData::class.java)
criteriaUpdate.set(SessionData::lastSenderSeqNo.name, value.lastSenderSeqNo) criteriaUpdate.set(SessionData::lastSenderSeqNo.name, value.lastSenderSeqNo)
criteriaUpdate.where(criteriaBuilder.equal(queryRoot.get<BigInteger>(SessionData::sessionId.name), key.value)) criteriaUpdate.where(criteriaBuilder.equal(queryRoot.get<BigInteger>(SessionData::sessionId.name), key.toHex()))
val update = session.createQuery(criteriaUpdate) val update = session.createQuery(criteriaUpdate)
val rowsUpdated = update.executeUpdate() val rowsUpdated = update.executeUpdate()
return rowsUpdated != 0 return rowsUpdated != 0
@ -142,9 +142,12 @@ class P2PMessageDeduplicator(cacheFactory: NamedCacheFactory, private val databa
@Suppress("MagicNumber") // database column width @Suppress("MagicNumber") // database column width
@Table(name = "${NODE_DATABASE_PREFIX}session_data") @Table(name = "${NODE_DATABASE_PREFIX}session_data")
class SessionData ( class SessionData (
/**
* The session identifier in hexadecimal form.
*/
@Id @Id
@Column(name = "session_id", nullable = false) @Column(name = "session_id", nullable = false)
var sessionId: BigInteger = BigInteger.ZERO, var sessionId: String = "",
/** /**
* The time the corresponding session-init message was originally generated on the sender side. * The time the corresponding session-init message was originally generated on the sender side.

View File

@ -40,11 +40,23 @@ data class SessionId(val value: BigInteger) {
SessionId(this.value.plus(BigInteger.ONE)) SessionId(this.value.plus(BigInteger.ONE))
} }
fun toHex(): String {
return String.format("%1$0${SESSION_ID_SIZE_IN_HEX}X", value)
}
companion object { companion object {
const val MAX_BIT_SIZE = 128 const val MAX_BIT_SIZE = 128
const val SESSION_ID_SIZE_IN_HEX = MAX_BIT_SIZE / 4
val LARGEST_SESSION_ID = BigInteger.valueOf(2).pow(MAX_BIT_SIZE).minus(BigInteger.ONE) val LARGEST_SESSION_ID = BigInteger.valueOf(2).pow(MAX_BIT_SIZE).minus(BigInteger.ONE)
fun createRandom(secureRandom: SecureRandom) = SessionId(BigInteger(MAX_BIT_SIZE, secureRandom)) fun createRandom(secureRandom: SecureRandom) = SessionId(BigInteger(MAX_BIT_SIZE, secureRandom))
@Suppress("MagicNumber")
fun fromHex(hexValue: String): SessionId {
require(hexValue.length == SESSION_ID_SIZE_IN_HEX) { "A session identifier in hex form must be $SESSION_ID_SIZE_IN_HEX characters long" }
val value = BigInteger(hexValue, 16)
return SessionId(value)
}
} }
} }

View File

@ -38,9 +38,6 @@ import net.corda.node.internal.InitiatedFlowFactory
import net.corda.node.services.api.CheckpointStorage import net.corda.node.services.api.CheckpointStorage
import net.corda.node.services.api.ServiceHubInternal import net.corda.node.services.api.ServiceHubInternal
import net.corda.node.services.messaging.DeduplicationHandler import net.corda.node.services.messaging.DeduplicationHandler
import net.corda.node.services.messaging.MessageIdentifier
import net.corda.node.services.messaging.SenderSequenceNumber
import net.corda.node.services.messaging.SenderUUID
import net.corda.node.services.statemachine.FlowStateMachineImpl.Companion.currentStateMachine import net.corda.node.services.statemachine.FlowStateMachineImpl.Companion.currentStateMachine
import net.corda.node.services.statemachine.interceptors.DumpHistoryOnErrorInterceptor import net.corda.node.services.statemachine.interceptors.DumpHistoryOnErrorInterceptor
import net.corda.node.services.statemachine.interceptors.HospitalisingInterceptor import net.corda.node.services.statemachine.interceptors.HospitalisingInterceptor
@ -707,8 +704,7 @@ internal class SingleThreadedStateMachineManager(
val sender = serviceHub.networkMapCache.getPeerByLegalName(peer) val sender = serviceHub.networkMapCache.getPeerByLegalName(peer)
if (sender != null) { if (sender != null) {
when (sessionMessage) { when (sessionMessage) {
is ExistingSessionMessage -> onExistingSessionMessage(sessionMessage, sender, event, is ExistingSessionMessage -> onExistingSessionMessage(sessionMessage, sender, event)
event.receivedMessage.uniqueMessageId, event.receivedMessage.senderUUID, event.receivedMessage.senderSeqNo)
is InitialSessionMessage -> onSessionInit(sessionMessage, sender, event) is InitialSessionMessage -> onSessionInit(sessionMessage, sender, event)
} }
} else { } else {
@ -721,10 +717,7 @@ internal class SingleThreadedStateMachineManager(
private fun onExistingSessionMessage( private fun onExistingSessionMessage(
sessionMessage: ExistingSessionMessage, sessionMessage: ExistingSessionMessage,
sender: Party, sender: Party,
externalEvent: ExternalEvent.ExternalMessageEvent, externalEvent: ExternalEvent.ExternalMessageEvent
messageIdentifier: MessageIdentifier,
senderUUID: SenderUUID?,
senderSequenceNumber: SenderSequenceNumber?
) { ) {
try { try {
val deduplicationHandler = externalEvent.deduplicationHandler val deduplicationHandler = externalEvent.deduplicationHandler
@ -742,7 +735,8 @@ internal class SingleThreadedStateMachineManager(
logger.info("Cannot find flow corresponding to session ID - $recipientId.") logger.info("Cannot find flow corresponding to session ID - $recipientId.")
} }
} else { } else {
val event = Event.DeliverSessionMessage(sessionMessage, deduplicationHandler, sender, messageIdentifier, senderUUID, senderSequenceNumber) val event = Event.DeliverSessionMessage(sessionMessage, deduplicationHandler, sender,
externalEvent.receivedMessage.uniqueMessageId, externalEvent.receivedMessage.senderUUID, externalEvent.receivedMessage.senderSeqNo)
innerState.withLock { innerState.withLock {
flows[flowId]?.run { fiber.scheduleEvent(event) } flows[flowId]?.run { fiber.scheduleEvent(event) }
// If flow is not running add it to the list of external events to be processed if/when the flow resumes. // If flow is not running add it to the list of external events to be processed if/when the flow resumes.

View File

@ -354,7 +354,7 @@ sealed class SessionState {
* Returns the new form of the state * Returns the new form of the state
*/ */
fun bufferMessage(messageIdentifier: MessageIdentifier, messagePayload: ExistingSessionMessagePayload): SessionState { fun bufferMessage(messageIdentifier: MessageIdentifier, messagePayload: ExistingSessionMessagePayload): SessionState {
return this.copy(bufferedMessages = bufferedMessages + Pair(messageIdentifier, messagePayload), nextSendingSeqNumber = nextSendingSeqNumber + 1) return bufferMessages(listOf(messageIdentifier to messagePayload))
} }
/** /**

View File

@ -245,7 +245,7 @@ class TopLevelTransition(
checkpointState = checkpoint.checkpointState.copy( checkpointState = checkpoint.checkpointState.copy(
numberOfSuspends = checkpoint.checkpointState.numberOfSuspends + 1, numberOfSuspends = checkpoint.checkpointState.numberOfSuspends + 1,
numberOfCommits = checkpoint.checkpointState.numberOfCommits + 1, numberOfCommits = checkpoint.checkpointState.numberOfCommits + 1,
suspensionTime = context.time suspensionTime = context.time
), ),
flowState = FlowState.Finished, flowState = FlowState.Finished,
result = event.returnValue, result = event.returnValue,

View File

@ -6,7 +6,7 @@
<changeSet author="R3.Corda" id="add_session_data_table"> <changeSet author="R3.Corda" id="add_session_data_table">
<createTable tableName="node_session_data"> <createTable tableName="node_session_data">
<column name="session_id" type="NUMBER(128)"> <column name="session_id" type="NVARCHAR(32)">
<constraints nullable="false"/> <constraints nullable="false"/>
</column> </column>
<column name="init_generation_time" type="timestamp"> <column name="init_generation_time" type="timestamp">

View File

@ -39,4 +39,16 @@ class SessionIdTest {
assertThat(initiatedSessionId.value.toLong()).isEqualTo(0) assertThat(initiatedSessionId.value.toLong()).isEqualTo(0)
} }
@Test(timeout=300_000)
fun `conversion from and to hex form works properly`() {
val sessionId = SessionId(BigInteger.valueOf(42))
val sessionIdHexForm = "0000000000000000000000000000002A"
assertThat(sessionId.toHex()).isEqualTo(sessionIdHexForm)
assertThat(SessionId.fromHex(sessionIdHexForm)).isEqualTo(sessionId)
assertThatThrownBy { SessionId.fromHex("2A") }
.isInstanceOf(IllegalArgumentException::class.java)
.hasMessageContaining("A session identifier in hex form must be 32 characters long")
}
} }