From 76808d36c3f2efd6874c910e29fa13f9e892efd6 Mon Sep 17 00:00:00 2001 From: Matthew Nesbit Date: Wed, 28 Sep 2016 15:18:22 +0100 Subject: [PATCH] Store protocol checkpoints in the DB, except during Single threaded MockNetwork activity, where we still use the file system based checkpointing. Make the Checkpoint acess a Sequence not an Iterable, so that we don't end up with all the checkpoints permanently resident in memory. Split up storage initialisation so that there is less code copying in MockNode Add header comment to DBCheckpointStorage class Respond to PR comments Resolve PR comments Rename iterator on checkpoints Fix typo Fixup checkpoints in DB logic after Shams's PR Delete duplicated code --- .../com/r3corda/node/internal/AbstractNode.kt | 25 ++- .../node/services/api/CheckpointStorage.kt | 8 +- .../persistence/DBCheckpointStorage.kt | 39 +++++ .../persistence/PerFileCheckpointStorage.kt | 11 +- .../statemachine/ProtocolStateMachineImpl.kt | 14 +- .../statemachine/StateMachineManager.kt | 15 +- .../r3corda/node/utilities/DatabaseSupport.kt | 20 +++ .../messaging/TwoPartyTradeProtocolTests.kt | 21 ++- .../persistence/DBCheckpointStorageTests.kt | 162 ++++++++++++++++++ .../PerFileCheckpointStorageTests.kt | 18 +- .../statemachine/StateMachineManagerTests.kt | 9 +- .../com/r3corda/testing/node/MockNode.kt | 21 +++ 12 files changed, 325 insertions(+), 38 deletions(-) create mode 100644 node/src/main/kotlin/com/r3corda/node/services/persistence/DBCheckpointStorage.kt create mode 100644 node/src/test/kotlin/com/r3corda/node/services/persistence/DBCheckpointStorageTests.kt diff --git a/node/src/main/kotlin/com/r3corda/node/internal/AbstractNode.kt b/node/src/main/kotlin/com/r3corda/node/internal/AbstractNode.kt index 82a7f4e8a2..084bd96866 100644 --- a/node/src/main/kotlin/com/r3corda/node/internal/AbstractNode.kt +++ b/node/src/main/kotlin/com/r3corda/node/internal/AbstractNode.kt @@ -151,6 +151,7 @@ abstract class AbstractNode(val configuration: NodeConfiguration, val networkMap val customServices: ArrayList = ArrayList() protected val runOnStop: ArrayList = ArrayList() lateinit var database: Database + protected var dbCloser: Runnable? = null /** Locates and returns a service of the given type if loaded, or throws an exception if not found. */ inline fun findService() = customServices.filterIsInstance().single() @@ -238,13 +239,16 @@ abstract class AbstractNode(val configuration: NodeConfiguration, val networkMap // Add vault observers CashBalanceAsMetricsObserver(services) ScheduledActivityObserver(services) - } - startMessagingService(ServerRPCOps(services, smm, database)) - runOnStop += Runnable { net.stop() } - _networkMapRegistrationFuture.setFuture(registerWithNetworkMap()) - isPreviousCheckpointsPresent = checkpointStorage.checkpoints.any() - smm.start() + checkpointStorage.forEach { + isPreviousCheckpointsPresent = true + false + } + startMessagingService(ServerRPCOps(services, smm, database)) + runOnStop += Runnable { net.stop() } + _networkMapRegistrationFuture.setFuture(registerWithNetworkMap()) + smm.start() + } started = true return this } @@ -269,7 +273,8 @@ abstract class AbstractNode(val configuration: NodeConfiguration, val networkMap this.database = database // Now log the vendor string as this will also cause a connection to be tested eagerly. log.info("Connected to ${database.vendor} database.") - runOnStop += Runnable { toClose.close() } + dbCloser = Runnable { toClose.close() } + runOnStop += dbCloser!! databaseTransaction(database) { insideTransaction() } @@ -429,9 +434,13 @@ abstract class AbstractNode(val configuration: NodeConfiguration, val networkMap protected abstract fun startMessagingService(cordaRPCOps: CordaRPCOps?) + protected open fun initialiseCheckpointService(dir: Path): CheckpointStorage { + return DBCheckpointStorage() + } + protected open fun initialiseStorageService(dir: Path): Pair { val attachments = makeAttachmentStorage(dir) - val checkpointStorage = PerFileCheckpointStorage(dir.resolve("checkpoints")) + val checkpointStorage = initialiseCheckpointService(dir) val transactionStorage = PerFileTransactionStorage(dir.resolve("transactions")) _servicesThatAcceptUploads += attachments val (identity, keyPair) = obtainKeyPair(dir) diff --git a/node/src/main/kotlin/com/r3corda/node/services/api/CheckpointStorage.kt b/node/src/main/kotlin/com/r3corda/node/services/api/CheckpointStorage.kt index 7693cd770c..34f4c4643a 100644 --- a/node/src/main/kotlin/com/r3corda/node/services/api/CheckpointStorage.kt +++ b/node/src/main/kotlin/com/r3corda/node/services/api/CheckpointStorage.kt @@ -21,11 +21,11 @@ interface CheckpointStorage { fun removeCheckpoint(checkpoint: Checkpoint) /** - * Returns a snapshot of all the checkpoints in the store. - * This may return more checkpoints than were added to this instance of the store; for example if the store persists - * checkpoints to disk. + * Allows the caller to process safely in a thread safe fashion the set of all checkpoints. + * The checkpoints are only valid during the lifetime of a single call to the block, to allow memory management. + * Return false from the block to terminate further iteration. */ - val checkpoints: Iterable + fun forEach(block: (Checkpoint)->Boolean) } diff --git a/node/src/main/kotlin/com/r3corda/node/services/persistence/DBCheckpointStorage.kt b/node/src/main/kotlin/com/r3corda/node/services/persistence/DBCheckpointStorage.kt new file mode 100644 index 0000000000..0a2155044f --- /dev/null +++ b/node/src/main/kotlin/com/r3corda/node/services/persistence/DBCheckpointStorage.kt @@ -0,0 +1,39 @@ +package com.r3corda.node.services.persistence + +import com.r3corda.core.crypto.SecureHash +import com.r3corda.core.serialization.SerializedBytes +import com.r3corda.core.serialization.deserialize +import com.r3corda.core.serialization.serialize +import com.r3corda.node.services.api.Checkpoint +import com.r3corda.node.services.api.CheckpointStorage +import com.r3corda.node.utilities.JDBCHashMap +import java.util.* + +/** + * Simple checkpoint key value storage in DB using the underlying JDBCHashMap and transactional context of the call sites. + */ +class DBCheckpointStorage : CheckpointStorage { + private val checkpointStorage = Collections.synchronizedMap(JDBCHashMap>("checkpoints", loadOnInit = false)) + + override fun addCheckpoint(checkpoint: Checkpoint) { + val serialisedCheckpoint = checkpoint.serialize() + val id = serialisedCheckpoint.hash + checkpointStorage.put(id, serialisedCheckpoint) + } + + override fun removeCheckpoint(checkpoint: Checkpoint) { + val serialisedCheckpoint = checkpoint.serialize() + val id = serialisedCheckpoint.hash + checkpointStorage.remove(id) ?: throw IllegalArgumentException("Checkpoint not found") + } + + override fun forEach(block: (Checkpoint) -> Boolean) { + synchronized(checkpointStorage) { + for (checkpoint in checkpointStorage.values) { + if (!block(checkpoint.deserialize())) { + break + } + } + } + } +} \ No newline at end of file diff --git a/node/src/main/kotlin/com/r3corda/node/services/persistence/PerFileCheckpointStorage.kt b/node/src/main/kotlin/com/r3corda/node/services/persistence/PerFileCheckpointStorage.kt index 273ee2581c..e1705c41d9 100644 --- a/node/src/main/kotlin/com/r3corda/node/services/persistence/PerFileCheckpointStorage.kt +++ b/node/src/main/kotlin/com/r3corda/node/services/persistence/PerFileCheckpointStorage.kt @@ -60,9 +60,14 @@ class PerFileCheckpointStorage(val storeDir: Path) : CheckpointStorage { logger.trace { "Removed $checkpoint ($checkpointFile)" } } - override val checkpoints: Iterable - get() = synchronized(checkpointFiles) { - checkpointFiles.keys.toList() + override fun forEach(block: (Checkpoint)->Boolean) { + synchronized(checkpointFiles) { + for(checkpoint in checkpointFiles.keys) { + if (!block(checkpoint)) { + break + } + } } + } } diff --git a/node/src/main/kotlin/com/r3corda/node/services/statemachine/ProtocolStateMachineImpl.kt b/node/src/main/kotlin/com/r3corda/node/services/statemachine/ProtocolStateMachineImpl.kt index ef5a07a176..bb4d1dda2a 100644 --- a/node/src/main/kotlin/com/r3corda/node/services/statemachine/ProtocolStateMachineImpl.kt +++ b/node/src/main/kotlin/com/r3corda/node/services/statemachine/ProtocolStateMachineImpl.kt @@ -17,8 +17,10 @@ import com.r3corda.core.utilities.UntrustworthyData import com.r3corda.core.utilities.trace import com.r3corda.node.services.api.ServiceHubInternal import com.r3corda.node.services.statemachine.StateMachineManager.* +import com.r3corda.node.utilities.StrandLocalTransactionManager import com.r3corda.node.utilities.createDatabaseTransaction import org.jetbrains.exposed.sql.Database +import org.jetbrains.exposed.sql.Transaction import org.jetbrains.exposed.sql.transactions.TransactionManager import org.slf4j.Logger import org.slf4j.LoggerFactory @@ -62,6 +64,7 @@ class ProtocolStateMachineImpl(override val id: StateMachineRunId, @Transient internal lateinit var actionOnEnd: () -> Unit @Transient internal lateinit var database: Database @Transient internal var fromCheckpoint: Boolean = false + @Transient internal var txTrampoline: Transaction? = null @Transient private var _logger: Logger? = null override val logger: Logger get() { @@ -113,7 +116,7 @@ class ProtocolStateMachineImpl(override val id: StateMachineRunId, logger.trace { "Starting database transaction ${TransactionManager.currentOrNull()} on ${Strand.currentStrand()}." } } - private fun commitTransaction() { + internal fun commitTransaction() { val transaction = TransactionManager.current() try { logger.trace { "Commiting database transaction $transaction on ${Strand.currentStrand()}." } @@ -210,7 +213,7 @@ class ProtocolStateMachineImpl(override val id: StateMachineRunId, if (receivedMessage is SessionEnd) { openSessions.values.remove(receiveRequest.session) - throw ProtocolSessionException("Counterparty on ${receiveRequest.session.otherParty} has prematurly ended") + throw ProtocolSessionException("Counterparty on ${receiveRequest.session.otherParty} has prematurely ended") } else if (receiveRequest.receiveType.isInstance(receivedMessage)) { return receiveRequest.receiveType.cast(receivedMessage) } else { @@ -220,9 +223,14 @@ class ProtocolStateMachineImpl(override val id: StateMachineRunId, @Suspendable private fun suspend(ioRequest: ProtocolIORequest) { - commitTransaction() + // we have to pass the Thread local Transaction across via a transient field as the Fiber Park swaps them out. + txTrampoline = TransactionManager.currentOrNull() + StrandLocalTransactionManager.setThreadLocalTx(null) parkAndSerialize { fiber, serializer -> logger.trace { "Suspended $id on $ioRequest" } + // restore the Tx onto the ThreadLocal so that we can commit the ensuing checkpoint to the DB + StrandLocalTransactionManager.setThreadLocalTx(txTrampoline) + txTrampoline = null try { actionOnSuspend(ioRequest) } catch (t: Throwable) { diff --git a/node/src/main/kotlin/com/r3corda/node/services/statemachine/StateMachineManager.kt b/node/src/main/kotlin/com/r3corda/node/services/statemachine/StateMachineManager.kt index f61a5a7086..892168580f 100644 --- a/node/src/main/kotlin/com/r3corda/node/services/statemachine/StateMachineManager.kt +++ b/node/src/main/kotlin/com/r3corda/node/services/statemachine/StateMachineManager.kt @@ -28,6 +28,7 @@ import com.r3corda.node.services.api.ServiceHubInternal import com.r3corda.node.utilities.AddOrRemove import com.r3corda.node.utilities.AffinityExecutor import kotlinx.support.jdk8.collections.removeIf +import com.r3corda.node.utilities.isolatedTransaction import org.jetbrains.exposed.sql.Database import rx.Observable import rx.subjects.PublishSubject @@ -157,13 +158,14 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, private fun restoreFibersFromCheckpoints() { mutex.locked { - checkpointStorage.checkpoints.forEach { + checkpointStorage.forEach { // If a protocol is added before start() then don't attempt to restore it if (!stateMachines.containsValue(it)) { val fiber = deserializeFiber(it.serialisedFiber) initFiber(fiber) stateMachines[fiber] = it } + true } } } @@ -276,6 +278,9 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, psm.serviceHub = serviceHub psm.actionOnSuspend = { ioRequest -> updateCheckpoint(psm) + // We commit on the fibers transaction that was copied across ThreadLocals during suspend + // This will free up the ThreadLocal so on return the caller can carry on with other transactions + psm.commitTransaction() processIORequest(ioRequest) } psm.actionOnEnd = { @@ -334,7 +339,13 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, */ fun add(loggerName: String, logic: ProtocolLogic): ProtocolStateMachine { val fiber = createFiber(loggerName, logic) - updateCheckpoint(fiber) + // We swap out the parent transaction context as using this frequently leads to a deadlock as we wait + // on the protocol completion future inside that context. The problem is that any progress checkpoints are + // unable to acquire the table lock and move forward till the calling transaction finishes. + // Committing in line here on a fresh context ensure we can progress. + isolatedTransaction(database) { + updateCheckpoint(fiber) + } // If we are not started then our checkpoint will be picked up during start mutex.locked { if (started) { diff --git a/node/src/main/kotlin/com/r3corda/node/utilities/DatabaseSupport.kt b/node/src/main/kotlin/com/r3corda/node/utilities/DatabaseSupport.kt index fbbf260a31..ba1ca4fce1 100644 --- a/node/src/main/kotlin/com/r3corda/node/utilities/DatabaseSupport.kt +++ b/node/src/main/kotlin/com/r3corda/node/utilities/DatabaseSupport.kt @@ -35,6 +35,15 @@ fun configureDatabase(props: Properties): Pair { return Pair(dataSource, database) } +fun isolatedTransaction(database: Database, block: Transaction.() -> T): T { + val oldContext = StrandLocalTransactionManager.setThreadLocalTx(null) + return try { + databaseTransaction(database, block) + } finally { + StrandLocalTransactionManager.restoreThreadLocalTx(oldContext) + } +} + /** * A relatively close copy of the [ThreadLocalTransactionManager] in Exposed but with the following adjustments to suit * our environment: @@ -51,6 +60,17 @@ class StrandLocalTransactionManager(initWithDatabase: Database) : TransactionMan private val threadLocalDb = ThreadLocal() private val threadLocalTx = ThreadLocal() + fun setThreadLocalTx(tx: Transaction?): Pair { + val oldTx = threadLocalTx.get() + threadLocalTx.set(tx) + return Pair(threadLocalDb.get(), oldTx) + } + + fun restoreThreadLocalTx(context: Pair) { + threadLocalDb.set(context.first) + threadLocalTx.set(context.second) + } + var database: Database get() = threadLocalDb.get() ?: throw IllegalStateException("Was expecting to find database set on current strand: ${Strand.currentStrand()}") set(value: Database) { diff --git a/node/src/test/kotlin/com/r3corda/node/messaging/TwoPartyTradeProtocolTests.kt b/node/src/test/kotlin/com/r3corda/node/messaging/TwoPartyTradeProtocolTests.kt index 8d65900e6a..86bb6b5ce0 100644 --- a/node/src/test/kotlin/com/r3corda/node/messaging/TwoPartyTradeProtocolTests.kt +++ b/node/src/test/kotlin/com/r3corda/node/messaging/TwoPartyTradeProtocolTests.kt @@ -22,11 +22,13 @@ import com.r3corda.node.services.config.NodeConfiguration import com.r3corda.node.services.persistence.NodeAttachmentService import com.r3corda.node.services.persistence.PerFileTransactionStorage import com.r3corda.node.services.persistence.StorageServiceImpl +import com.r3corda.node.utilities.databaseTransaction import com.r3corda.protocols.TwoPartyTradeProtocol.Buyer import com.r3corda.protocols.TwoPartyTradeProtocol.Seller import com.r3corda.testing.* import com.r3corda.testing.node.InMemoryMessagingNetwork import com.r3corda.testing.node.MockNetwork +import com.r3corda.node.services.persistence.checkpoints import org.assertj.core.api.Assertions.assertThat import org.junit.After import org.junit.Before @@ -83,6 +85,9 @@ class TwoPartyTradeProtocolTests { aliceNode = net.createPartyNode(notaryNode.info.address, ALICE.name, ALICE_KEY) bobNode = net.createPartyNode(notaryNode.info.address, BOB.name, BOB_KEY) + aliceNode.disableDBCloseOnStop() + bobNode.disableDBCloseOnStop() + bobNode.services.fillWithSomeTestCash(2000.DOLLARS) val alicesFakePaper = fillUpForSeller(false, aliceNode.storage.myLegalIdentity.owningKey, 1200.DOLLARS `issued by` DUMMY_CASH_ISSUER, null).second @@ -98,8 +103,14 @@ class TwoPartyTradeProtocolTests { aliceNode.stop() bobNode.stop() - assertThat(aliceNode.checkpointStorage.checkpoints).isEmpty() - assertThat(bobNode.checkpointStorage.checkpoints).isEmpty() + databaseTransaction(aliceNode.database) { + assertThat(aliceNode.checkpointStorage.checkpoints()).isEmpty() + } + aliceNode.manuallyCloseDB() + databaseTransaction(bobNode.database) { + assertThat(bobNode.checkpointStorage.checkpoints()).isEmpty() + } + aliceNode.manuallyCloseDB() } } @@ -135,7 +146,7 @@ class TwoPartyTradeProtocolTests { bobNode.pumpReceive(false) // OK, now Bob has sent the partial transaction back to Alice and is waiting for Alice's signature. - assertThat(bobNode.checkpointStorage.checkpoints).hasSize(1) + assertThat(bobNode.checkpointStorage.checkpoints()).hasSize(1) val bobTransactionsBeforeCrash = (bobNode.storage.validatedTransactions as PerFileTransactionStorage).transactions assertThat(bobTransactionsBeforeCrash).isNotEmpty() @@ -167,8 +178,8 @@ class TwoPartyTradeProtocolTests { assertThat(bobNode.smm.findStateMachines(Buyer::class.java)).isEmpty() - assertThat(bobNode.checkpointStorage.checkpoints).isEmpty() - assertThat(aliceNode.checkpointStorage.checkpoints).isEmpty() + assertThat(bobNode.checkpointStorage.checkpoints()).isEmpty() + assertThat(aliceNode.checkpointStorage.checkpoints()).isEmpty() val restoredBobTransactions = bobTransactionsBeforeCrash.filter { bobNode.storage.validatedTransactions.getTransaction(it.id) != null } assertThat(restoredBobTransactions).containsAll(bobTransactionsBeforeCrash) diff --git a/node/src/test/kotlin/com/r3corda/node/services/persistence/DBCheckpointStorageTests.kt b/node/src/test/kotlin/com/r3corda/node/services/persistence/DBCheckpointStorageTests.kt new file mode 100644 index 0000000000..f2137cf202 --- /dev/null +++ b/node/src/test/kotlin/com/r3corda/node/services/persistence/DBCheckpointStorageTests.kt @@ -0,0 +1,162 @@ +package com.r3corda.node.services.persistence + +import com.google.common.primitives.Ints +import com.r3corda.core.serialization.SerializedBytes +import com.r3corda.core.utilities.LogHelper +import com.r3corda.node.services.api.Checkpoint +import com.r3corda.node.services.api.CheckpointStorage +import com.r3corda.node.services.transactions.PersistentUniquenessProvider +import com.r3corda.node.utilities.configureDatabase +import com.r3corda.node.utilities.databaseTransaction +import com.r3corda.testing.node.makeTestDataSourceProperties +import org.assertj.core.api.Assertions.assertThat +import org.assertj.core.api.Assertions.assertThatExceptionOfType +import org.jetbrains.exposed.sql.Database +import org.junit.After +import org.junit.Before +import org.junit.Test +import java.io.Closeable + +internal fun CheckpointStorage.checkpoints(): List { + val checkpoints = mutableListOf() + forEach { + checkpoints += it + true + } + return checkpoints +} + +class DBCheckpointStorageTests { + lateinit var checkpointStorage: DBCheckpointStorage + lateinit var dataSource: Closeable + lateinit var database: Database + + @Before + fun setUp() { + LogHelper.setLevel(PersistentUniquenessProvider::class) + val dataSourceAndDatabase = configureDatabase(makeTestDataSourceProperties()) + dataSource = dataSourceAndDatabase.first + database = dataSourceAndDatabase.second + newCheckpointStorage() + } + + @After + fun cleanUp() { + dataSource.close() + LogHelper.reset(PersistentUniquenessProvider::class) + } + + @Test + fun `add new checkpoint`() { + val checkpoint = newCheckpoint() + databaseTransaction(database) { + checkpointStorage.addCheckpoint(checkpoint) + } + databaseTransaction(database) { + assertThat(checkpointStorage.checkpoints()).containsExactly(checkpoint) + } + newCheckpointStorage() + databaseTransaction(database) { + assertThat(checkpointStorage.checkpoints()).containsExactly(checkpoint) + } + } + + @Test + fun `remove checkpoint`() { + val checkpoint = newCheckpoint() + databaseTransaction(database) { + checkpointStorage.addCheckpoint(checkpoint) + } + databaseTransaction(database) { + checkpointStorage.removeCheckpoint(checkpoint) + } + databaseTransaction(database) { + assertThat(checkpointStorage.checkpoints()).isEmpty() + } + newCheckpointStorage() + databaseTransaction(database) { + assertThat(checkpointStorage.checkpoints()).isEmpty() + } + } + + @Test + fun `add and remove checkpoint in single commit operate`() { + val checkpoint = newCheckpoint() + val checkpoint2 = newCheckpoint() + databaseTransaction(database) { + checkpointStorage.addCheckpoint(checkpoint) + checkpointStorage.addCheckpoint(checkpoint2) + checkpointStorage.removeCheckpoint(checkpoint) + } + databaseTransaction(database) { + assertThat(checkpointStorage.checkpoints()).containsExactly(checkpoint2) + } + newCheckpointStorage() + databaseTransaction(database) { + assertThat(checkpointStorage.checkpoints()).containsExactly(checkpoint2) + } + } + + @Test + fun `remove unknown checkpoint`() { + val checkpoint = newCheckpoint() + databaseTransaction(database) { + assertThatExceptionOfType(IllegalArgumentException::class.java).isThrownBy { + checkpointStorage.removeCheckpoint(checkpoint) + } + } + } + + @Test + fun `add two checkpoints then remove first one`() { + val firstCheckpoint = newCheckpoint() + databaseTransaction(database) { + checkpointStorage.addCheckpoint(firstCheckpoint) + } + val secondCheckpoint = newCheckpoint() + databaseTransaction(database) { + checkpointStorage.addCheckpoint(secondCheckpoint) + } + databaseTransaction(database) { + checkpointStorage.removeCheckpoint(firstCheckpoint) + } + databaseTransaction(database) { + assertThat(checkpointStorage.checkpoints()).containsExactly(secondCheckpoint) + } + newCheckpointStorage() + databaseTransaction(database) { + assertThat(checkpointStorage.checkpoints()).containsExactly(secondCheckpoint) + } + } + + @Test + fun `add checkpoint and then remove after 'restart'`() { + val originalCheckpoint = newCheckpoint() + databaseTransaction(database) { + checkpointStorage.addCheckpoint(originalCheckpoint) + } + newCheckpointStorage() + val reconstructedCheckpoint = databaseTransaction(database) { + checkpointStorage.checkpoints().single() + } + databaseTransaction(database) { + assertThat(reconstructedCheckpoint).isEqualTo(originalCheckpoint).isNotSameAs(originalCheckpoint) + } + databaseTransaction(database) { + checkpointStorage.removeCheckpoint(reconstructedCheckpoint) + } + databaseTransaction(database) { + assertThat(checkpointStorage.checkpoints()).isEmpty() + } + } + + private fun newCheckpointStorage() { + databaseTransaction(database) { + checkpointStorage = DBCheckpointStorage() + } + } + + private var checkpointCount = 1 + private fun newCheckpoint() = Checkpoint(SerializedBytes(Ints.toByteArray(checkpointCount++))) + +} \ No newline at end of file diff --git a/node/src/test/kotlin/com/r3corda/node/services/persistence/PerFileCheckpointStorageTests.kt b/node/src/test/kotlin/com/r3corda/node/services/persistence/PerFileCheckpointStorageTests.kt index e15046bfc1..65311009e2 100644 --- a/node/src/test/kotlin/com/r3corda/node/services/persistence/PerFileCheckpointStorageTests.kt +++ b/node/src/test/kotlin/com/r3corda/node/services/persistence/PerFileCheckpointStorageTests.kt @@ -34,9 +34,9 @@ class PerFileCheckpointStorageTests { fun `add new checkpoint`() { val checkpoint = newCheckpoint() checkpointStorage.addCheckpoint(checkpoint) - assertThat(checkpointStorage.checkpoints).containsExactly(checkpoint) + assertThat(checkpointStorage.checkpoints()).containsExactly(checkpoint) newCheckpointStorage() - assertThat(checkpointStorage.checkpoints).containsExactly(checkpoint) + assertThat(checkpointStorage.checkpoints()).containsExactly(checkpoint) } @Test @@ -44,9 +44,9 @@ class PerFileCheckpointStorageTests { val checkpoint = newCheckpoint() checkpointStorage.addCheckpoint(checkpoint) checkpointStorage.removeCheckpoint(checkpoint) - assertThat(checkpointStorage.checkpoints).isEmpty() + assertThat(checkpointStorage.checkpoints()).isEmpty() newCheckpointStorage() - assertThat(checkpointStorage.checkpoints).isEmpty() + assertThat(checkpointStorage.checkpoints()).isEmpty() } @Test @@ -64,9 +64,9 @@ class PerFileCheckpointStorageTests { val secondCheckpoint = newCheckpoint() checkpointStorage.addCheckpoint(secondCheckpoint) checkpointStorage.removeCheckpoint(firstCheckpoint) - assertThat(checkpointStorage.checkpoints).containsExactly(secondCheckpoint) + assertThat(checkpointStorage.checkpoints()).containsExactly(secondCheckpoint) newCheckpointStorage() - assertThat(checkpointStorage.checkpoints).containsExactly(secondCheckpoint) + assertThat(checkpointStorage.checkpoints()).containsExactly(secondCheckpoint) } @Test @@ -74,10 +74,10 @@ class PerFileCheckpointStorageTests { val originalCheckpoint = newCheckpoint() checkpointStorage.addCheckpoint(originalCheckpoint) newCheckpointStorage() - val reconstructedCheckpoint = checkpointStorage.checkpoints.single() + val reconstructedCheckpoint = checkpointStorage.checkpoints().single() assertThat(reconstructedCheckpoint).isEqualTo(originalCheckpoint).isNotSameAs(originalCheckpoint) checkpointStorage.removeCheckpoint(reconstructedCheckpoint) - assertThat(checkpointStorage.checkpoints).isEmpty() + assertThat(checkpointStorage.checkpoints()).isEmpty() } @Test @@ -86,7 +86,7 @@ class PerFileCheckpointStorageTests { checkpointStorage.addCheckpoint(checkpoint) Files.write(storeDir.resolve("random-non-checkpoint-file"), "this is not a checkpoint!!".toByteArray()) newCheckpointStorage() - assertThat(checkpointStorage.checkpoints).containsExactly(checkpoint) + assertThat(checkpointStorage.checkpoints()).containsExactly(checkpoint) } private fun newCheckpointStorage() { diff --git a/node/src/test/kotlin/com/r3corda/node/services/statemachine/StateMachineManagerTests.kt b/node/src/test/kotlin/com/r3corda/node/services/statemachine/StateMachineManagerTests.kt index b2796b9153..ad4db9413b 100644 --- a/node/src/test/kotlin/com/r3corda/node/services/statemachine/StateMachineManagerTests.kt +++ b/node/src/test/kotlin/com/r3corda/node/services/statemachine/StateMachineManagerTests.kt @@ -12,6 +12,7 @@ import com.r3corda.core.serialization.deserialize import com.r3corda.node.services.statemachine.StateMachineManager.SessionData import com.r3corda.node.services.statemachine.StateMachineManager.SessionMessage import com.r3corda.testing.node.InMemoryMessagingNetwork +import com.r3corda.node.services.persistence.checkpoints import com.r3corda.testing.node.MockNetwork import com.r3corda.testing.node.MockNetwork.MockNode import org.assertj.core.api.Assertions.assertThat @@ -135,13 +136,13 @@ class StateMachineManagerTests { // Kick off first send and receive node2.smm.add("test", PingPongProtocol(node3.info.identity, payload)) - assertEquals(1, node2.checkpointStorage.checkpoints.count()) + assertEquals(1, node2.checkpointStorage.checkpoints().count()) // Restart node and thus reload the checkpoint and resend the message with same UUID node2.stop() val node2b = net.createNode(node1.info.address, node2.id, advertisedServices = *node2.advertisedServices.toTypedArray()) val (firstAgain, fut1) = node2b.getSingleProtocol() net.runNetwork() - assertEquals(1, node2.checkpointStorage.checkpoints.count()) + assertEquals(1, node2.checkpointStorage.checkpoints().count()) // Run the network which will also fire up the second protocol. First message should get deduped. So message data stays in sync. net.runNetwork() node2b.smm.executor.flush() @@ -149,8 +150,8 @@ class StateMachineManagerTests { // Check protocols completed cleanly and didn't get out of phase assertEquals(4, receivedCount, "Protocol should have exchanged 4 unique messages")// Two messages each way assertTrue(sentCount > receivedCount, "Node restart should have retransmitted messages") // can't give a precise value as every addMessageHandler re-runs the undelivered messages - assertEquals(0, node2b.checkpointStorage.checkpoints.count(), "Checkpoints left after restored protocol should have ended") - assertEquals(0, node3.checkpointStorage.checkpoints.count(), "Checkpoints left after restored protocol should have ended") + assertEquals(0, node2b.checkpointStorage.checkpoints().count(), "Checkpoints left after restored protocol should have ended") + assertEquals(0, node3.checkpointStorage.checkpoints().count(), "Checkpoints left after restored protocol should have ended") assertEquals(payload2, firstAgain.receivedPayload, "Received payload does not match the first value on Node 3") assertEquals(payload2 + 1, firstAgain.receivedPayload2, "Received payload does not match the expected second value on Node 3") assertEquals(payload, secondProtocol!!.receivedPayload, "Received payload does not match the (restarted) first value on Node 2") diff --git a/test-utils/src/main/kotlin/com/r3corda/testing/node/MockNode.kt b/test-utils/src/main/kotlin/com/r3corda/testing/node/MockNode.kt index 13029f8d2d..4943d07600 100644 --- a/test-utils/src/main/kotlin/com/r3corda/testing/node/MockNode.kt +++ b/test-utils/src/main/kotlin/com/r3corda/testing/node/MockNode.kt @@ -6,6 +6,7 @@ import com.google.common.util.concurrent.Futures import com.google.common.util.concurrent.ListenableFuture import com.google.common.util.concurrent.SettableFuture import com.r3corda.core.crypto.Party +import com.r3corda.core.div import com.r3corda.core.messaging.SingleMessageRecipient import com.r3corda.core.messaging.TopicSession import com.r3corda.core.messaging.runOnNextMessage @@ -20,10 +21,13 @@ import com.r3corda.core.testing.InMemoryVaultService import com.r3corda.core.utilities.DUMMY_NOTARY_KEY import com.r3corda.core.utilities.loggerFor import com.r3corda.node.internal.AbstractNode +import com.r3corda.node.services.api.CheckpointStorage import com.r3corda.node.services.config.NodeConfiguration import com.r3corda.node.services.keys.E2ETestKeyManagementService import com.r3corda.node.services.messaging.CordaRPCOps import com.r3corda.node.services.network.InMemoryNetworkMapService +import com.r3corda.node.services.persistence.DBCheckpointStorage +import com.r3corda.node.services.persistence.PerFileCheckpointStorage import com.r3corda.node.services.transactions.InMemoryUniquenessProvider import com.r3corda.node.utilities.databaseTransaction import com.r3corda.protocols.ServiceRequestMessage @@ -97,6 +101,14 @@ class MockNetwork(private val networkSendManuallyPumped: Boolean = false, persistenceTx = { body: () -> Unit -> databaseTransaction(database) { body() } }).start().get() } + override fun initialiseCheckpointService(dir: Path): CheckpointStorage { + return if (mockNet.threadPerNode) { + DBCheckpointStorage() + } else { + PerFileCheckpointStorage(dir / "checkpoints") + } + } + override fun makeIdentityService() = MockIdentityService(mockNet.identities) override fun makeVaultService(): VaultService = InMemoryVaultService(services) @@ -153,6 +165,15 @@ class MockNetwork(private val networkSendManuallyPumped: Boolean = false, send(topic, target, payload) return receive(topic, payload.sessionID) } + + fun disableDBCloseOnStop() { + runOnStop.remove(dbCloser) + } + + fun manuallyCloseDB() { + dbCloser?.run() + dbCloser = null + } } /** Returns a node, optionally created by the passed factory method. */