Store protocol checkpoints in the DB, except during Single threaded MockNetwork activity, where we still use the file system based checkpointing. Make the Checkpoint acess a Sequence not an Iterable, so that we don't end up with all the checkpoints permanently resident in memory.

Split up storage initialisation so that there is less code copying in MockNode

Add header comment to DBCheckpointStorage class

Respond to PR comments

Resolve PR comments

Rename iterator on checkpoints

Fix typo

Fixup checkpoints in DB logic after Shams's PR

Delete duplicated code
This commit is contained in:
Matthew Nesbit 2016-09-28 15:18:22 +01:00
parent 67fdf9b2ff
commit 76808d36c3
12 changed files with 325 additions and 38 deletions

View File

@ -151,6 +151,7 @@ abstract class AbstractNode(val configuration: NodeConfiguration, val networkMap
val customServices: ArrayList<Any> = ArrayList() val customServices: ArrayList<Any> = ArrayList()
protected val runOnStop: ArrayList<Runnable> = ArrayList() protected val runOnStop: ArrayList<Runnable> = ArrayList()
lateinit var database: Database lateinit var database: Database
protected var dbCloser: Runnable? = null
/** Locates and returns a service of the given type if loaded, or throws an exception if not found. */ /** Locates and returns a service of the given type if loaded, or throws an exception if not found. */
inline fun <reified T : Any> findService() = customServices.filterIsInstance<T>().single() inline fun <reified T : Any> findService() = customServices.filterIsInstance<T>().single()
@ -238,13 +239,16 @@ abstract class AbstractNode(val configuration: NodeConfiguration, val networkMap
// Add vault observers // Add vault observers
CashBalanceAsMetricsObserver(services) CashBalanceAsMetricsObserver(services)
ScheduledActivityObserver(services) ScheduledActivityObserver(services)
}
startMessagingService(ServerRPCOps(services, smm, database)) checkpointStorage.forEach {
runOnStop += Runnable { net.stop() } isPreviousCheckpointsPresent = true
_networkMapRegistrationFuture.setFuture(registerWithNetworkMap()) false
isPreviousCheckpointsPresent = checkpointStorage.checkpoints.any() }
smm.start() startMessagingService(ServerRPCOps(services, smm, database))
runOnStop += Runnable { net.stop() }
_networkMapRegistrationFuture.setFuture(registerWithNetworkMap())
smm.start()
}
started = true started = true
return this return this
} }
@ -269,7 +273,8 @@ abstract class AbstractNode(val configuration: NodeConfiguration, val networkMap
this.database = database this.database = database
// Now log the vendor string as this will also cause a connection to be tested eagerly. // Now log the vendor string as this will also cause a connection to be tested eagerly.
log.info("Connected to ${database.vendor} database.") log.info("Connected to ${database.vendor} database.")
runOnStop += Runnable { toClose.close() } dbCloser = Runnable { toClose.close() }
runOnStop += dbCloser!!
databaseTransaction(database) { databaseTransaction(database) {
insideTransaction() insideTransaction()
} }
@ -429,9 +434,13 @@ abstract class AbstractNode(val configuration: NodeConfiguration, val networkMap
protected abstract fun startMessagingService(cordaRPCOps: CordaRPCOps?) protected abstract fun startMessagingService(cordaRPCOps: CordaRPCOps?)
protected open fun initialiseCheckpointService(dir: Path): CheckpointStorage {
return DBCheckpointStorage()
}
protected open fun initialiseStorageService(dir: Path): Pair<TxWritableStorageService, CheckpointStorage> { protected open fun initialiseStorageService(dir: Path): Pair<TxWritableStorageService, CheckpointStorage> {
val attachments = makeAttachmentStorage(dir) val attachments = makeAttachmentStorage(dir)
val checkpointStorage = PerFileCheckpointStorage(dir.resolve("checkpoints")) val checkpointStorage = initialiseCheckpointService(dir)
val transactionStorage = PerFileTransactionStorage(dir.resolve("transactions")) val transactionStorage = PerFileTransactionStorage(dir.resolve("transactions"))
_servicesThatAcceptUploads += attachments _servicesThatAcceptUploads += attachments
val (identity, keyPair) = obtainKeyPair(dir) val (identity, keyPair) = obtainKeyPair(dir)

View File

@ -21,11 +21,11 @@ interface CheckpointStorage {
fun removeCheckpoint(checkpoint: Checkpoint) fun removeCheckpoint(checkpoint: Checkpoint)
/** /**
* Returns a snapshot of all the checkpoints in the store. * Allows the caller to process safely in a thread safe fashion the set of all checkpoints.
* This may return more checkpoints than were added to this instance of the store; for example if the store persists * The checkpoints are only valid during the lifetime of a single call to the block, to allow memory management.
* checkpoints to disk. * Return false from the block to terminate further iteration.
*/ */
val checkpoints: Iterable<Checkpoint> fun forEach(block: (Checkpoint)->Boolean)
} }

View File

@ -0,0 +1,39 @@
package com.r3corda.node.services.persistence
import com.r3corda.core.crypto.SecureHash
import com.r3corda.core.serialization.SerializedBytes
import com.r3corda.core.serialization.deserialize
import com.r3corda.core.serialization.serialize
import com.r3corda.node.services.api.Checkpoint
import com.r3corda.node.services.api.CheckpointStorage
import com.r3corda.node.utilities.JDBCHashMap
import java.util.*
/**
* Simple checkpoint key value storage in DB using the underlying JDBCHashMap and transactional context of the call sites.
*/
class DBCheckpointStorage : CheckpointStorage {
private val checkpointStorage = Collections.synchronizedMap(JDBCHashMap<SecureHash, SerializedBytes<Checkpoint>>("checkpoints", loadOnInit = false))
override fun addCheckpoint(checkpoint: Checkpoint) {
val serialisedCheckpoint = checkpoint.serialize()
val id = serialisedCheckpoint.hash
checkpointStorage.put(id, serialisedCheckpoint)
}
override fun removeCheckpoint(checkpoint: Checkpoint) {
val serialisedCheckpoint = checkpoint.serialize()
val id = serialisedCheckpoint.hash
checkpointStorage.remove(id) ?: throw IllegalArgumentException("Checkpoint not found")
}
override fun forEach(block: (Checkpoint) -> Boolean) {
synchronized(checkpointStorage) {
for (checkpoint in checkpointStorage.values) {
if (!block(checkpoint.deserialize())) {
break
}
}
}
}
}

View File

@ -60,9 +60,14 @@ class PerFileCheckpointStorage(val storeDir: Path) : CheckpointStorage {
logger.trace { "Removed $checkpoint ($checkpointFile)" } logger.trace { "Removed $checkpoint ($checkpointFile)" }
} }
override val checkpoints: Iterable<Checkpoint> override fun forEach(block: (Checkpoint)->Boolean) {
get() = synchronized(checkpointFiles) { synchronized(checkpointFiles) {
checkpointFiles.keys.toList() for(checkpoint in checkpointFiles.keys) {
if (!block(checkpoint)) {
break
}
}
} }
}
} }

View File

@ -17,8 +17,10 @@ import com.r3corda.core.utilities.UntrustworthyData
import com.r3corda.core.utilities.trace import com.r3corda.core.utilities.trace
import com.r3corda.node.services.api.ServiceHubInternal import com.r3corda.node.services.api.ServiceHubInternal
import com.r3corda.node.services.statemachine.StateMachineManager.* import com.r3corda.node.services.statemachine.StateMachineManager.*
import com.r3corda.node.utilities.StrandLocalTransactionManager
import com.r3corda.node.utilities.createDatabaseTransaction import com.r3corda.node.utilities.createDatabaseTransaction
import org.jetbrains.exposed.sql.Database import org.jetbrains.exposed.sql.Database
import org.jetbrains.exposed.sql.Transaction
import org.jetbrains.exposed.sql.transactions.TransactionManager import org.jetbrains.exposed.sql.transactions.TransactionManager
import org.slf4j.Logger import org.slf4j.Logger
import org.slf4j.LoggerFactory import org.slf4j.LoggerFactory
@ -62,6 +64,7 @@ class ProtocolStateMachineImpl<R>(override val id: StateMachineRunId,
@Transient internal lateinit var actionOnEnd: () -> Unit @Transient internal lateinit var actionOnEnd: () -> Unit
@Transient internal lateinit var database: Database @Transient internal lateinit var database: Database
@Transient internal var fromCheckpoint: Boolean = false @Transient internal var fromCheckpoint: Boolean = false
@Transient internal var txTrampoline: Transaction? = null
@Transient private var _logger: Logger? = null @Transient private var _logger: Logger? = null
override val logger: Logger get() { override val logger: Logger get() {
@ -113,7 +116,7 @@ class ProtocolStateMachineImpl<R>(override val id: StateMachineRunId,
logger.trace { "Starting database transaction ${TransactionManager.currentOrNull()} on ${Strand.currentStrand()}." } logger.trace { "Starting database transaction ${TransactionManager.currentOrNull()} on ${Strand.currentStrand()}." }
} }
private fun commitTransaction() { internal fun commitTransaction() {
val transaction = TransactionManager.current() val transaction = TransactionManager.current()
try { try {
logger.trace { "Commiting database transaction $transaction on ${Strand.currentStrand()}." } logger.trace { "Commiting database transaction $transaction on ${Strand.currentStrand()}." }
@ -210,7 +213,7 @@ class ProtocolStateMachineImpl<R>(override val id: StateMachineRunId,
if (receivedMessage is SessionEnd) { if (receivedMessage is SessionEnd) {
openSessions.values.remove(receiveRequest.session) openSessions.values.remove(receiveRequest.session)
throw ProtocolSessionException("Counterparty on ${receiveRequest.session.otherParty} has prematurly ended") throw ProtocolSessionException("Counterparty on ${receiveRequest.session.otherParty} has prematurely ended")
} else if (receiveRequest.receiveType.isInstance(receivedMessage)) { } else if (receiveRequest.receiveType.isInstance(receivedMessage)) {
return receiveRequest.receiveType.cast(receivedMessage) return receiveRequest.receiveType.cast(receivedMessage)
} else { } else {
@ -220,9 +223,14 @@ class ProtocolStateMachineImpl<R>(override val id: StateMachineRunId,
@Suspendable @Suspendable
private fun suspend(ioRequest: ProtocolIORequest) { private fun suspend(ioRequest: ProtocolIORequest) {
commitTransaction() // we have to pass the Thread local Transaction across via a transient field as the Fiber Park swaps them out.
txTrampoline = TransactionManager.currentOrNull()
StrandLocalTransactionManager.setThreadLocalTx(null)
parkAndSerialize { fiber, serializer -> parkAndSerialize { fiber, serializer ->
logger.trace { "Suspended $id on $ioRequest" } logger.trace { "Suspended $id on $ioRequest" }
// restore the Tx onto the ThreadLocal so that we can commit the ensuing checkpoint to the DB
StrandLocalTransactionManager.setThreadLocalTx(txTrampoline)
txTrampoline = null
try { try {
actionOnSuspend(ioRequest) actionOnSuspend(ioRequest)
} catch (t: Throwable) { } catch (t: Throwable) {

View File

@ -28,6 +28,7 @@ import com.r3corda.node.services.api.ServiceHubInternal
import com.r3corda.node.utilities.AddOrRemove import com.r3corda.node.utilities.AddOrRemove
import com.r3corda.node.utilities.AffinityExecutor import com.r3corda.node.utilities.AffinityExecutor
import kotlinx.support.jdk8.collections.removeIf import kotlinx.support.jdk8.collections.removeIf
import com.r3corda.node.utilities.isolatedTransaction
import org.jetbrains.exposed.sql.Database import org.jetbrains.exposed.sql.Database
import rx.Observable import rx.Observable
import rx.subjects.PublishSubject import rx.subjects.PublishSubject
@ -157,13 +158,14 @@ class StateMachineManager(val serviceHub: ServiceHubInternal,
private fun restoreFibersFromCheckpoints() { private fun restoreFibersFromCheckpoints() {
mutex.locked { mutex.locked {
checkpointStorage.checkpoints.forEach { checkpointStorage.forEach {
// If a protocol is added before start() then don't attempt to restore it // If a protocol is added before start() then don't attempt to restore it
if (!stateMachines.containsValue(it)) { if (!stateMachines.containsValue(it)) {
val fiber = deserializeFiber(it.serialisedFiber) val fiber = deserializeFiber(it.serialisedFiber)
initFiber(fiber) initFiber(fiber)
stateMachines[fiber] = it stateMachines[fiber] = it
} }
true
} }
} }
} }
@ -276,6 +278,9 @@ class StateMachineManager(val serviceHub: ServiceHubInternal,
psm.serviceHub = serviceHub psm.serviceHub = serviceHub
psm.actionOnSuspend = { ioRequest -> psm.actionOnSuspend = { ioRequest ->
updateCheckpoint(psm) updateCheckpoint(psm)
// We commit on the fibers transaction that was copied across ThreadLocals during suspend
// This will free up the ThreadLocal so on return the caller can carry on with other transactions
psm.commitTransaction()
processIORequest(ioRequest) processIORequest(ioRequest)
} }
psm.actionOnEnd = { psm.actionOnEnd = {
@ -334,7 +339,13 @@ class StateMachineManager(val serviceHub: ServiceHubInternal,
*/ */
fun <T> add(loggerName: String, logic: ProtocolLogic<T>): ProtocolStateMachine<T> { fun <T> add(loggerName: String, logic: ProtocolLogic<T>): ProtocolStateMachine<T> {
val fiber = createFiber(loggerName, logic) val fiber = createFiber(loggerName, logic)
updateCheckpoint(fiber) // We swap out the parent transaction context as using this frequently leads to a deadlock as we wait
// on the protocol completion future inside that context. The problem is that any progress checkpoints are
// unable to acquire the table lock and move forward till the calling transaction finishes.
// Committing in line here on a fresh context ensure we can progress.
isolatedTransaction(database) {
updateCheckpoint(fiber)
}
// If we are not started then our checkpoint will be picked up during start // If we are not started then our checkpoint will be picked up during start
mutex.locked { mutex.locked {
if (started) { if (started) {

View File

@ -35,6 +35,15 @@ fun configureDatabase(props: Properties): Pair<Closeable, Database> {
return Pair(dataSource, database) return Pair(dataSource, database)
} }
fun <T> isolatedTransaction(database: Database, block: Transaction.() -> T): T {
val oldContext = StrandLocalTransactionManager.setThreadLocalTx(null)
return try {
databaseTransaction(database, block)
} finally {
StrandLocalTransactionManager.restoreThreadLocalTx(oldContext)
}
}
/** /**
* A relatively close copy of the [ThreadLocalTransactionManager] in Exposed but with the following adjustments to suit * A relatively close copy of the [ThreadLocalTransactionManager] in Exposed but with the following adjustments to suit
* our environment: * our environment:
@ -51,6 +60,17 @@ class StrandLocalTransactionManager(initWithDatabase: Database) : TransactionMan
private val threadLocalDb = ThreadLocal<Database>() private val threadLocalDb = ThreadLocal<Database>()
private val threadLocalTx = ThreadLocal<Transaction>() private val threadLocalTx = ThreadLocal<Transaction>()
fun setThreadLocalTx(tx: Transaction?): Pair<Database?, Transaction?> {
val oldTx = threadLocalTx.get()
threadLocalTx.set(tx)
return Pair(threadLocalDb.get(), oldTx)
}
fun restoreThreadLocalTx(context: Pair<Database?, Transaction?>) {
threadLocalDb.set(context.first)
threadLocalTx.set(context.second)
}
var database: Database var database: Database
get() = threadLocalDb.get() ?: throw IllegalStateException("Was expecting to find database set on current strand: ${Strand.currentStrand()}") get() = threadLocalDb.get() ?: throw IllegalStateException("Was expecting to find database set on current strand: ${Strand.currentStrand()}")
set(value: Database) { set(value: Database) {

View File

@ -22,11 +22,13 @@ import com.r3corda.node.services.config.NodeConfiguration
import com.r3corda.node.services.persistence.NodeAttachmentService import com.r3corda.node.services.persistence.NodeAttachmentService
import com.r3corda.node.services.persistence.PerFileTransactionStorage import com.r3corda.node.services.persistence.PerFileTransactionStorage
import com.r3corda.node.services.persistence.StorageServiceImpl import com.r3corda.node.services.persistence.StorageServiceImpl
import com.r3corda.node.utilities.databaseTransaction
import com.r3corda.protocols.TwoPartyTradeProtocol.Buyer import com.r3corda.protocols.TwoPartyTradeProtocol.Buyer
import com.r3corda.protocols.TwoPartyTradeProtocol.Seller import com.r3corda.protocols.TwoPartyTradeProtocol.Seller
import com.r3corda.testing.* import com.r3corda.testing.*
import com.r3corda.testing.node.InMemoryMessagingNetwork import com.r3corda.testing.node.InMemoryMessagingNetwork
import com.r3corda.testing.node.MockNetwork import com.r3corda.testing.node.MockNetwork
import com.r3corda.node.services.persistence.checkpoints
import org.assertj.core.api.Assertions.assertThat import org.assertj.core.api.Assertions.assertThat
import org.junit.After import org.junit.After
import org.junit.Before import org.junit.Before
@ -83,6 +85,9 @@ class TwoPartyTradeProtocolTests {
aliceNode = net.createPartyNode(notaryNode.info.address, ALICE.name, ALICE_KEY) aliceNode = net.createPartyNode(notaryNode.info.address, ALICE.name, ALICE_KEY)
bobNode = net.createPartyNode(notaryNode.info.address, BOB.name, BOB_KEY) bobNode = net.createPartyNode(notaryNode.info.address, BOB.name, BOB_KEY)
aliceNode.disableDBCloseOnStop()
bobNode.disableDBCloseOnStop()
bobNode.services.fillWithSomeTestCash(2000.DOLLARS) bobNode.services.fillWithSomeTestCash(2000.DOLLARS)
val alicesFakePaper = fillUpForSeller(false, aliceNode.storage.myLegalIdentity.owningKey, val alicesFakePaper = fillUpForSeller(false, aliceNode.storage.myLegalIdentity.owningKey,
1200.DOLLARS `issued by` DUMMY_CASH_ISSUER, null).second 1200.DOLLARS `issued by` DUMMY_CASH_ISSUER, null).second
@ -98,8 +103,14 @@ class TwoPartyTradeProtocolTests {
aliceNode.stop() aliceNode.stop()
bobNode.stop() bobNode.stop()
assertThat(aliceNode.checkpointStorage.checkpoints).isEmpty() databaseTransaction(aliceNode.database) {
assertThat(bobNode.checkpointStorage.checkpoints).isEmpty() assertThat(aliceNode.checkpointStorage.checkpoints()).isEmpty()
}
aliceNode.manuallyCloseDB()
databaseTransaction(bobNode.database) {
assertThat(bobNode.checkpointStorage.checkpoints()).isEmpty()
}
aliceNode.manuallyCloseDB()
} }
} }
@ -135,7 +146,7 @@ class TwoPartyTradeProtocolTests {
bobNode.pumpReceive(false) bobNode.pumpReceive(false)
// OK, now Bob has sent the partial transaction back to Alice and is waiting for Alice's signature. // OK, now Bob has sent the partial transaction back to Alice and is waiting for Alice's signature.
assertThat(bobNode.checkpointStorage.checkpoints).hasSize(1) assertThat(bobNode.checkpointStorage.checkpoints()).hasSize(1)
val bobTransactionsBeforeCrash = (bobNode.storage.validatedTransactions as PerFileTransactionStorage).transactions val bobTransactionsBeforeCrash = (bobNode.storage.validatedTransactions as PerFileTransactionStorage).transactions
assertThat(bobTransactionsBeforeCrash).isNotEmpty() assertThat(bobTransactionsBeforeCrash).isNotEmpty()
@ -167,8 +178,8 @@ class TwoPartyTradeProtocolTests {
assertThat(bobNode.smm.findStateMachines(Buyer::class.java)).isEmpty() assertThat(bobNode.smm.findStateMachines(Buyer::class.java)).isEmpty()
assertThat(bobNode.checkpointStorage.checkpoints).isEmpty() assertThat(bobNode.checkpointStorage.checkpoints()).isEmpty()
assertThat(aliceNode.checkpointStorage.checkpoints).isEmpty() assertThat(aliceNode.checkpointStorage.checkpoints()).isEmpty()
val restoredBobTransactions = bobTransactionsBeforeCrash.filter { bobNode.storage.validatedTransactions.getTransaction(it.id) != null } val restoredBobTransactions = bobTransactionsBeforeCrash.filter { bobNode.storage.validatedTransactions.getTransaction(it.id) != null }
assertThat(restoredBobTransactions).containsAll(bobTransactionsBeforeCrash) assertThat(restoredBobTransactions).containsAll(bobTransactionsBeforeCrash)

View File

@ -0,0 +1,162 @@
package com.r3corda.node.services.persistence
import com.google.common.primitives.Ints
import com.r3corda.core.serialization.SerializedBytes
import com.r3corda.core.utilities.LogHelper
import com.r3corda.node.services.api.Checkpoint
import com.r3corda.node.services.api.CheckpointStorage
import com.r3corda.node.services.transactions.PersistentUniquenessProvider
import com.r3corda.node.utilities.configureDatabase
import com.r3corda.node.utilities.databaseTransaction
import com.r3corda.testing.node.makeTestDataSourceProperties
import org.assertj.core.api.Assertions.assertThat
import org.assertj.core.api.Assertions.assertThatExceptionOfType
import org.jetbrains.exposed.sql.Database
import org.junit.After
import org.junit.Before
import org.junit.Test
import java.io.Closeable
internal fun CheckpointStorage.checkpoints(): List<Checkpoint> {
val checkpoints = mutableListOf<Checkpoint>()
forEach {
checkpoints += it
true
}
return checkpoints
}
class DBCheckpointStorageTests {
lateinit var checkpointStorage: DBCheckpointStorage
lateinit var dataSource: Closeable
lateinit var database: Database
@Before
fun setUp() {
LogHelper.setLevel(PersistentUniquenessProvider::class)
val dataSourceAndDatabase = configureDatabase(makeTestDataSourceProperties())
dataSource = dataSourceAndDatabase.first
database = dataSourceAndDatabase.second
newCheckpointStorage()
}
@After
fun cleanUp() {
dataSource.close()
LogHelper.reset(PersistentUniquenessProvider::class)
}
@Test
fun `add new checkpoint`() {
val checkpoint = newCheckpoint()
databaseTransaction(database) {
checkpointStorage.addCheckpoint(checkpoint)
}
databaseTransaction(database) {
assertThat(checkpointStorage.checkpoints()).containsExactly(checkpoint)
}
newCheckpointStorage()
databaseTransaction(database) {
assertThat(checkpointStorage.checkpoints()).containsExactly(checkpoint)
}
}
@Test
fun `remove checkpoint`() {
val checkpoint = newCheckpoint()
databaseTransaction(database) {
checkpointStorage.addCheckpoint(checkpoint)
}
databaseTransaction(database) {
checkpointStorage.removeCheckpoint(checkpoint)
}
databaseTransaction(database) {
assertThat(checkpointStorage.checkpoints()).isEmpty()
}
newCheckpointStorage()
databaseTransaction(database) {
assertThat(checkpointStorage.checkpoints()).isEmpty()
}
}
@Test
fun `add and remove checkpoint in single commit operate`() {
val checkpoint = newCheckpoint()
val checkpoint2 = newCheckpoint()
databaseTransaction(database) {
checkpointStorage.addCheckpoint(checkpoint)
checkpointStorage.addCheckpoint(checkpoint2)
checkpointStorage.removeCheckpoint(checkpoint)
}
databaseTransaction(database) {
assertThat(checkpointStorage.checkpoints()).containsExactly(checkpoint2)
}
newCheckpointStorage()
databaseTransaction(database) {
assertThat(checkpointStorage.checkpoints()).containsExactly(checkpoint2)
}
}
@Test
fun `remove unknown checkpoint`() {
val checkpoint = newCheckpoint()
databaseTransaction(database) {
assertThatExceptionOfType(IllegalArgumentException::class.java).isThrownBy {
checkpointStorage.removeCheckpoint(checkpoint)
}
}
}
@Test
fun `add two checkpoints then remove first one`() {
val firstCheckpoint = newCheckpoint()
databaseTransaction(database) {
checkpointStorage.addCheckpoint(firstCheckpoint)
}
val secondCheckpoint = newCheckpoint()
databaseTransaction(database) {
checkpointStorage.addCheckpoint(secondCheckpoint)
}
databaseTransaction(database) {
checkpointStorage.removeCheckpoint(firstCheckpoint)
}
databaseTransaction(database) {
assertThat(checkpointStorage.checkpoints()).containsExactly(secondCheckpoint)
}
newCheckpointStorage()
databaseTransaction(database) {
assertThat(checkpointStorage.checkpoints()).containsExactly(secondCheckpoint)
}
}
@Test
fun `add checkpoint and then remove after 'restart'`() {
val originalCheckpoint = newCheckpoint()
databaseTransaction(database) {
checkpointStorage.addCheckpoint(originalCheckpoint)
}
newCheckpointStorage()
val reconstructedCheckpoint = databaseTransaction(database) {
checkpointStorage.checkpoints().single()
}
databaseTransaction(database) {
assertThat(reconstructedCheckpoint).isEqualTo(originalCheckpoint).isNotSameAs(originalCheckpoint)
}
databaseTransaction(database) {
checkpointStorage.removeCheckpoint(reconstructedCheckpoint)
}
databaseTransaction(database) {
assertThat(checkpointStorage.checkpoints()).isEmpty()
}
}
private fun newCheckpointStorage() {
databaseTransaction(database) {
checkpointStorage = DBCheckpointStorage()
}
}
private var checkpointCount = 1
private fun newCheckpoint() = Checkpoint(SerializedBytes(Ints.toByteArray(checkpointCount++)))
}

View File

@ -34,9 +34,9 @@ class PerFileCheckpointStorageTests {
fun `add new checkpoint`() { fun `add new checkpoint`() {
val checkpoint = newCheckpoint() val checkpoint = newCheckpoint()
checkpointStorage.addCheckpoint(checkpoint) checkpointStorage.addCheckpoint(checkpoint)
assertThat(checkpointStorage.checkpoints).containsExactly(checkpoint) assertThat(checkpointStorage.checkpoints()).containsExactly(checkpoint)
newCheckpointStorage() newCheckpointStorage()
assertThat(checkpointStorage.checkpoints).containsExactly(checkpoint) assertThat(checkpointStorage.checkpoints()).containsExactly(checkpoint)
} }
@Test @Test
@ -44,9 +44,9 @@ class PerFileCheckpointStorageTests {
val checkpoint = newCheckpoint() val checkpoint = newCheckpoint()
checkpointStorage.addCheckpoint(checkpoint) checkpointStorage.addCheckpoint(checkpoint)
checkpointStorage.removeCheckpoint(checkpoint) checkpointStorage.removeCheckpoint(checkpoint)
assertThat(checkpointStorage.checkpoints).isEmpty() assertThat(checkpointStorage.checkpoints()).isEmpty()
newCheckpointStorage() newCheckpointStorage()
assertThat(checkpointStorage.checkpoints).isEmpty() assertThat(checkpointStorage.checkpoints()).isEmpty()
} }
@Test @Test
@ -64,9 +64,9 @@ class PerFileCheckpointStorageTests {
val secondCheckpoint = newCheckpoint() val secondCheckpoint = newCheckpoint()
checkpointStorage.addCheckpoint(secondCheckpoint) checkpointStorage.addCheckpoint(secondCheckpoint)
checkpointStorage.removeCheckpoint(firstCheckpoint) checkpointStorage.removeCheckpoint(firstCheckpoint)
assertThat(checkpointStorage.checkpoints).containsExactly(secondCheckpoint) assertThat(checkpointStorage.checkpoints()).containsExactly(secondCheckpoint)
newCheckpointStorage() newCheckpointStorage()
assertThat(checkpointStorage.checkpoints).containsExactly(secondCheckpoint) assertThat(checkpointStorage.checkpoints()).containsExactly(secondCheckpoint)
} }
@Test @Test
@ -74,10 +74,10 @@ class PerFileCheckpointStorageTests {
val originalCheckpoint = newCheckpoint() val originalCheckpoint = newCheckpoint()
checkpointStorage.addCheckpoint(originalCheckpoint) checkpointStorage.addCheckpoint(originalCheckpoint)
newCheckpointStorage() newCheckpointStorage()
val reconstructedCheckpoint = checkpointStorage.checkpoints.single() val reconstructedCheckpoint = checkpointStorage.checkpoints().single()
assertThat(reconstructedCheckpoint).isEqualTo(originalCheckpoint).isNotSameAs(originalCheckpoint) assertThat(reconstructedCheckpoint).isEqualTo(originalCheckpoint).isNotSameAs(originalCheckpoint)
checkpointStorage.removeCheckpoint(reconstructedCheckpoint) checkpointStorage.removeCheckpoint(reconstructedCheckpoint)
assertThat(checkpointStorage.checkpoints).isEmpty() assertThat(checkpointStorage.checkpoints()).isEmpty()
} }
@Test @Test
@ -86,7 +86,7 @@ class PerFileCheckpointStorageTests {
checkpointStorage.addCheckpoint(checkpoint) checkpointStorage.addCheckpoint(checkpoint)
Files.write(storeDir.resolve("random-non-checkpoint-file"), "this is not a checkpoint!!".toByteArray()) Files.write(storeDir.resolve("random-non-checkpoint-file"), "this is not a checkpoint!!".toByteArray())
newCheckpointStorage() newCheckpointStorage()
assertThat(checkpointStorage.checkpoints).containsExactly(checkpoint) assertThat(checkpointStorage.checkpoints()).containsExactly(checkpoint)
} }
private fun newCheckpointStorage() { private fun newCheckpointStorage() {

View File

@ -12,6 +12,7 @@ import com.r3corda.core.serialization.deserialize
import com.r3corda.node.services.statemachine.StateMachineManager.SessionData import com.r3corda.node.services.statemachine.StateMachineManager.SessionData
import com.r3corda.node.services.statemachine.StateMachineManager.SessionMessage import com.r3corda.node.services.statemachine.StateMachineManager.SessionMessage
import com.r3corda.testing.node.InMemoryMessagingNetwork import com.r3corda.testing.node.InMemoryMessagingNetwork
import com.r3corda.node.services.persistence.checkpoints
import com.r3corda.testing.node.MockNetwork import com.r3corda.testing.node.MockNetwork
import com.r3corda.testing.node.MockNetwork.MockNode import com.r3corda.testing.node.MockNetwork.MockNode
import org.assertj.core.api.Assertions.assertThat import org.assertj.core.api.Assertions.assertThat
@ -135,13 +136,13 @@ class StateMachineManagerTests {
// Kick off first send and receive // Kick off first send and receive
node2.smm.add("test", PingPongProtocol(node3.info.identity, payload)) node2.smm.add("test", PingPongProtocol(node3.info.identity, payload))
assertEquals(1, node2.checkpointStorage.checkpoints.count()) assertEquals(1, node2.checkpointStorage.checkpoints().count())
// Restart node and thus reload the checkpoint and resend the message with same UUID // Restart node and thus reload the checkpoint and resend the message with same UUID
node2.stop() node2.stop()
val node2b = net.createNode(node1.info.address, node2.id, advertisedServices = *node2.advertisedServices.toTypedArray()) val node2b = net.createNode(node1.info.address, node2.id, advertisedServices = *node2.advertisedServices.toTypedArray())
val (firstAgain, fut1) = node2b.getSingleProtocol<PingPongProtocol>() val (firstAgain, fut1) = node2b.getSingleProtocol<PingPongProtocol>()
net.runNetwork() net.runNetwork()
assertEquals(1, node2.checkpointStorage.checkpoints.count()) assertEquals(1, node2.checkpointStorage.checkpoints().count())
// Run the network which will also fire up the second protocol. First message should get deduped. So message data stays in sync. // Run the network which will also fire up the second protocol. First message should get deduped. So message data stays in sync.
net.runNetwork() net.runNetwork()
node2b.smm.executor.flush() node2b.smm.executor.flush()
@ -149,8 +150,8 @@ class StateMachineManagerTests {
// Check protocols completed cleanly and didn't get out of phase // Check protocols completed cleanly and didn't get out of phase
assertEquals(4, receivedCount, "Protocol should have exchanged 4 unique messages")// Two messages each way assertEquals(4, receivedCount, "Protocol should have exchanged 4 unique messages")// Two messages each way
assertTrue(sentCount > receivedCount, "Node restart should have retransmitted messages") // can't give a precise value as every addMessageHandler re-runs the undelivered messages assertTrue(sentCount > receivedCount, "Node restart should have retransmitted messages") // can't give a precise value as every addMessageHandler re-runs the undelivered messages
assertEquals(0, node2b.checkpointStorage.checkpoints.count(), "Checkpoints left after restored protocol should have ended") assertEquals(0, node2b.checkpointStorage.checkpoints().count(), "Checkpoints left after restored protocol should have ended")
assertEquals(0, node3.checkpointStorage.checkpoints.count(), "Checkpoints left after restored protocol should have ended") assertEquals(0, node3.checkpointStorage.checkpoints().count(), "Checkpoints left after restored protocol should have ended")
assertEquals(payload2, firstAgain.receivedPayload, "Received payload does not match the first value on Node 3") assertEquals(payload2, firstAgain.receivedPayload, "Received payload does not match the first value on Node 3")
assertEquals(payload2 + 1, firstAgain.receivedPayload2, "Received payload does not match the expected second value on Node 3") assertEquals(payload2 + 1, firstAgain.receivedPayload2, "Received payload does not match the expected second value on Node 3")
assertEquals(payload, secondProtocol!!.receivedPayload, "Received payload does not match the (restarted) first value on Node 2") assertEquals(payload, secondProtocol!!.receivedPayload, "Received payload does not match the (restarted) first value on Node 2")

View File

@ -6,6 +6,7 @@ import com.google.common.util.concurrent.Futures
import com.google.common.util.concurrent.ListenableFuture import com.google.common.util.concurrent.ListenableFuture
import com.google.common.util.concurrent.SettableFuture import com.google.common.util.concurrent.SettableFuture
import com.r3corda.core.crypto.Party import com.r3corda.core.crypto.Party
import com.r3corda.core.div
import com.r3corda.core.messaging.SingleMessageRecipient import com.r3corda.core.messaging.SingleMessageRecipient
import com.r3corda.core.messaging.TopicSession import com.r3corda.core.messaging.TopicSession
import com.r3corda.core.messaging.runOnNextMessage import com.r3corda.core.messaging.runOnNextMessage
@ -20,10 +21,13 @@ import com.r3corda.core.testing.InMemoryVaultService
import com.r3corda.core.utilities.DUMMY_NOTARY_KEY import com.r3corda.core.utilities.DUMMY_NOTARY_KEY
import com.r3corda.core.utilities.loggerFor import com.r3corda.core.utilities.loggerFor
import com.r3corda.node.internal.AbstractNode import com.r3corda.node.internal.AbstractNode
import com.r3corda.node.services.api.CheckpointStorage
import com.r3corda.node.services.config.NodeConfiguration import com.r3corda.node.services.config.NodeConfiguration
import com.r3corda.node.services.keys.E2ETestKeyManagementService import com.r3corda.node.services.keys.E2ETestKeyManagementService
import com.r3corda.node.services.messaging.CordaRPCOps import com.r3corda.node.services.messaging.CordaRPCOps
import com.r3corda.node.services.network.InMemoryNetworkMapService import com.r3corda.node.services.network.InMemoryNetworkMapService
import com.r3corda.node.services.persistence.DBCheckpointStorage
import com.r3corda.node.services.persistence.PerFileCheckpointStorage
import com.r3corda.node.services.transactions.InMemoryUniquenessProvider import com.r3corda.node.services.transactions.InMemoryUniquenessProvider
import com.r3corda.node.utilities.databaseTransaction import com.r3corda.node.utilities.databaseTransaction
import com.r3corda.protocols.ServiceRequestMessage import com.r3corda.protocols.ServiceRequestMessage
@ -97,6 +101,14 @@ class MockNetwork(private val networkSendManuallyPumped: Boolean = false,
persistenceTx = { body: () -> Unit -> databaseTransaction(database) { body() } }).start().get() persistenceTx = { body: () -> Unit -> databaseTransaction(database) { body() } }).start().get()
} }
override fun initialiseCheckpointService(dir: Path): CheckpointStorage {
return if (mockNet.threadPerNode) {
DBCheckpointStorage()
} else {
PerFileCheckpointStorage(dir / "checkpoints")
}
}
override fun makeIdentityService() = MockIdentityService(mockNet.identities) override fun makeIdentityService() = MockIdentityService(mockNet.identities)
override fun makeVaultService(): VaultService = InMemoryVaultService(services) override fun makeVaultService(): VaultService = InMemoryVaultService(services)
@ -153,6 +165,15 @@ class MockNetwork(private val networkSendManuallyPumped: Boolean = false,
send(topic, target, payload) send(topic, target, payload)
return receive(topic, payload.sessionID) return receive(topic, payload.sessionID)
} }
fun disableDBCloseOnStop() {
runOnStop.remove(dbCloser)
}
fun manuallyCloseDB() {
dbCloser?.run()
dbCloser = null
}
} }
/** Returns a node, optionally created by the passed factory method. */ /** Returns a node, optionally created by the passed factory method. */