diff --git a/node/src/main/kotlin/com/r3corda/node/internal/AbstractNode.kt b/node/src/main/kotlin/com/r3corda/node/internal/AbstractNode.kt index 82a7f4e8a2..084bd96866 100644 --- a/node/src/main/kotlin/com/r3corda/node/internal/AbstractNode.kt +++ b/node/src/main/kotlin/com/r3corda/node/internal/AbstractNode.kt @@ -151,6 +151,7 @@ abstract class AbstractNode(val configuration: NodeConfiguration, val networkMap val customServices: ArrayList = ArrayList() protected val runOnStop: ArrayList = ArrayList() lateinit var database: Database + protected var dbCloser: Runnable? = null /** Locates and returns a service of the given type if loaded, or throws an exception if not found. */ inline fun findService() = customServices.filterIsInstance().single() @@ -238,13 +239,16 @@ abstract class AbstractNode(val configuration: NodeConfiguration, val networkMap // Add vault observers CashBalanceAsMetricsObserver(services) ScheduledActivityObserver(services) - } - startMessagingService(ServerRPCOps(services, smm, database)) - runOnStop += Runnable { net.stop() } - _networkMapRegistrationFuture.setFuture(registerWithNetworkMap()) - isPreviousCheckpointsPresent = checkpointStorage.checkpoints.any() - smm.start() + checkpointStorage.forEach { + isPreviousCheckpointsPresent = true + false + } + startMessagingService(ServerRPCOps(services, smm, database)) + runOnStop += Runnable { net.stop() } + _networkMapRegistrationFuture.setFuture(registerWithNetworkMap()) + smm.start() + } started = true return this } @@ -269,7 +273,8 @@ abstract class AbstractNode(val configuration: NodeConfiguration, val networkMap this.database = database // Now log the vendor string as this will also cause a connection to be tested eagerly. log.info("Connected to ${database.vendor} database.") - runOnStop += Runnable { toClose.close() } + dbCloser = Runnable { toClose.close() } + runOnStop += dbCloser!! databaseTransaction(database) { insideTransaction() } @@ -429,9 +434,13 @@ abstract class AbstractNode(val configuration: NodeConfiguration, val networkMap protected abstract fun startMessagingService(cordaRPCOps: CordaRPCOps?) + protected open fun initialiseCheckpointService(dir: Path): CheckpointStorage { + return DBCheckpointStorage() + } + protected open fun initialiseStorageService(dir: Path): Pair { val attachments = makeAttachmentStorage(dir) - val checkpointStorage = PerFileCheckpointStorage(dir.resolve("checkpoints")) + val checkpointStorage = initialiseCheckpointService(dir) val transactionStorage = PerFileTransactionStorage(dir.resolve("transactions")) _servicesThatAcceptUploads += attachments val (identity, keyPair) = obtainKeyPair(dir) diff --git a/node/src/main/kotlin/com/r3corda/node/services/api/CheckpointStorage.kt b/node/src/main/kotlin/com/r3corda/node/services/api/CheckpointStorage.kt index 7693cd770c..34f4c4643a 100644 --- a/node/src/main/kotlin/com/r3corda/node/services/api/CheckpointStorage.kt +++ b/node/src/main/kotlin/com/r3corda/node/services/api/CheckpointStorage.kt @@ -21,11 +21,11 @@ interface CheckpointStorage { fun removeCheckpoint(checkpoint: Checkpoint) /** - * Returns a snapshot of all the checkpoints in the store. - * This may return more checkpoints than were added to this instance of the store; for example if the store persists - * checkpoints to disk. + * Allows the caller to process safely in a thread safe fashion the set of all checkpoints. + * The checkpoints are only valid during the lifetime of a single call to the block, to allow memory management. + * Return false from the block to terminate further iteration. */ - val checkpoints: Iterable + fun forEach(block: (Checkpoint)->Boolean) } diff --git a/node/src/main/kotlin/com/r3corda/node/services/persistence/DBCheckpointStorage.kt b/node/src/main/kotlin/com/r3corda/node/services/persistence/DBCheckpointStorage.kt new file mode 100644 index 0000000000..0a2155044f --- /dev/null +++ b/node/src/main/kotlin/com/r3corda/node/services/persistence/DBCheckpointStorage.kt @@ -0,0 +1,39 @@ +package com.r3corda.node.services.persistence + +import com.r3corda.core.crypto.SecureHash +import com.r3corda.core.serialization.SerializedBytes +import com.r3corda.core.serialization.deserialize +import com.r3corda.core.serialization.serialize +import com.r3corda.node.services.api.Checkpoint +import com.r3corda.node.services.api.CheckpointStorage +import com.r3corda.node.utilities.JDBCHashMap +import java.util.* + +/** + * Simple checkpoint key value storage in DB using the underlying JDBCHashMap and transactional context of the call sites. + */ +class DBCheckpointStorage : CheckpointStorage { + private val checkpointStorage = Collections.synchronizedMap(JDBCHashMap>("checkpoints", loadOnInit = false)) + + override fun addCheckpoint(checkpoint: Checkpoint) { + val serialisedCheckpoint = checkpoint.serialize() + val id = serialisedCheckpoint.hash + checkpointStorage.put(id, serialisedCheckpoint) + } + + override fun removeCheckpoint(checkpoint: Checkpoint) { + val serialisedCheckpoint = checkpoint.serialize() + val id = serialisedCheckpoint.hash + checkpointStorage.remove(id) ?: throw IllegalArgumentException("Checkpoint not found") + } + + override fun forEach(block: (Checkpoint) -> Boolean) { + synchronized(checkpointStorage) { + for (checkpoint in checkpointStorage.values) { + if (!block(checkpoint.deserialize())) { + break + } + } + } + } +} \ No newline at end of file diff --git a/node/src/main/kotlin/com/r3corda/node/services/persistence/PerFileCheckpointStorage.kt b/node/src/main/kotlin/com/r3corda/node/services/persistence/PerFileCheckpointStorage.kt index 273ee2581c..e1705c41d9 100644 --- a/node/src/main/kotlin/com/r3corda/node/services/persistence/PerFileCheckpointStorage.kt +++ b/node/src/main/kotlin/com/r3corda/node/services/persistence/PerFileCheckpointStorage.kt @@ -60,9 +60,14 @@ class PerFileCheckpointStorage(val storeDir: Path) : CheckpointStorage { logger.trace { "Removed $checkpoint ($checkpointFile)" } } - override val checkpoints: Iterable - get() = synchronized(checkpointFiles) { - checkpointFiles.keys.toList() + override fun forEach(block: (Checkpoint)->Boolean) { + synchronized(checkpointFiles) { + for(checkpoint in checkpointFiles.keys) { + if (!block(checkpoint)) { + break + } + } } + } } diff --git a/node/src/main/kotlin/com/r3corda/node/services/statemachine/ProtocolStateMachineImpl.kt b/node/src/main/kotlin/com/r3corda/node/services/statemachine/ProtocolStateMachineImpl.kt index ef5a07a176..bb4d1dda2a 100644 --- a/node/src/main/kotlin/com/r3corda/node/services/statemachine/ProtocolStateMachineImpl.kt +++ b/node/src/main/kotlin/com/r3corda/node/services/statemachine/ProtocolStateMachineImpl.kt @@ -17,8 +17,10 @@ import com.r3corda.core.utilities.UntrustworthyData import com.r3corda.core.utilities.trace import com.r3corda.node.services.api.ServiceHubInternal import com.r3corda.node.services.statemachine.StateMachineManager.* +import com.r3corda.node.utilities.StrandLocalTransactionManager import com.r3corda.node.utilities.createDatabaseTransaction import org.jetbrains.exposed.sql.Database +import org.jetbrains.exposed.sql.Transaction import org.jetbrains.exposed.sql.transactions.TransactionManager import org.slf4j.Logger import org.slf4j.LoggerFactory @@ -62,6 +64,7 @@ class ProtocolStateMachineImpl(override val id: StateMachineRunId, @Transient internal lateinit var actionOnEnd: () -> Unit @Transient internal lateinit var database: Database @Transient internal var fromCheckpoint: Boolean = false + @Transient internal var txTrampoline: Transaction? = null @Transient private var _logger: Logger? = null override val logger: Logger get() { @@ -113,7 +116,7 @@ class ProtocolStateMachineImpl(override val id: StateMachineRunId, logger.trace { "Starting database transaction ${TransactionManager.currentOrNull()} on ${Strand.currentStrand()}." } } - private fun commitTransaction() { + internal fun commitTransaction() { val transaction = TransactionManager.current() try { logger.trace { "Commiting database transaction $transaction on ${Strand.currentStrand()}." } @@ -210,7 +213,7 @@ class ProtocolStateMachineImpl(override val id: StateMachineRunId, if (receivedMessage is SessionEnd) { openSessions.values.remove(receiveRequest.session) - throw ProtocolSessionException("Counterparty on ${receiveRequest.session.otherParty} has prematurly ended") + throw ProtocolSessionException("Counterparty on ${receiveRequest.session.otherParty} has prematurely ended") } else if (receiveRequest.receiveType.isInstance(receivedMessage)) { return receiveRequest.receiveType.cast(receivedMessage) } else { @@ -220,9 +223,14 @@ class ProtocolStateMachineImpl(override val id: StateMachineRunId, @Suspendable private fun suspend(ioRequest: ProtocolIORequest) { - commitTransaction() + // we have to pass the Thread local Transaction across via a transient field as the Fiber Park swaps them out. + txTrampoline = TransactionManager.currentOrNull() + StrandLocalTransactionManager.setThreadLocalTx(null) parkAndSerialize { fiber, serializer -> logger.trace { "Suspended $id on $ioRequest" } + // restore the Tx onto the ThreadLocal so that we can commit the ensuing checkpoint to the DB + StrandLocalTransactionManager.setThreadLocalTx(txTrampoline) + txTrampoline = null try { actionOnSuspend(ioRequest) } catch (t: Throwable) { diff --git a/node/src/main/kotlin/com/r3corda/node/services/statemachine/StateMachineManager.kt b/node/src/main/kotlin/com/r3corda/node/services/statemachine/StateMachineManager.kt index f61a5a7086..892168580f 100644 --- a/node/src/main/kotlin/com/r3corda/node/services/statemachine/StateMachineManager.kt +++ b/node/src/main/kotlin/com/r3corda/node/services/statemachine/StateMachineManager.kt @@ -28,6 +28,7 @@ import com.r3corda.node.services.api.ServiceHubInternal import com.r3corda.node.utilities.AddOrRemove import com.r3corda.node.utilities.AffinityExecutor import kotlinx.support.jdk8.collections.removeIf +import com.r3corda.node.utilities.isolatedTransaction import org.jetbrains.exposed.sql.Database import rx.Observable import rx.subjects.PublishSubject @@ -157,13 +158,14 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, private fun restoreFibersFromCheckpoints() { mutex.locked { - checkpointStorage.checkpoints.forEach { + checkpointStorage.forEach { // If a protocol is added before start() then don't attempt to restore it if (!stateMachines.containsValue(it)) { val fiber = deserializeFiber(it.serialisedFiber) initFiber(fiber) stateMachines[fiber] = it } + true } } } @@ -276,6 +278,9 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, psm.serviceHub = serviceHub psm.actionOnSuspend = { ioRequest -> updateCheckpoint(psm) + // We commit on the fibers transaction that was copied across ThreadLocals during suspend + // This will free up the ThreadLocal so on return the caller can carry on with other transactions + psm.commitTransaction() processIORequest(ioRequest) } psm.actionOnEnd = { @@ -334,7 +339,13 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, */ fun add(loggerName: String, logic: ProtocolLogic): ProtocolStateMachine { val fiber = createFiber(loggerName, logic) - updateCheckpoint(fiber) + // We swap out the parent transaction context as using this frequently leads to a deadlock as we wait + // on the protocol completion future inside that context. The problem is that any progress checkpoints are + // unable to acquire the table lock and move forward till the calling transaction finishes. + // Committing in line here on a fresh context ensure we can progress. + isolatedTransaction(database) { + updateCheckpoint(fiber) + } // If we are not started then our checkpoint will be picked up during start mutex.locked { if (started) { diff --git a/node/src/main/kotlin/com/r3corda/node/utilities/DatabaseSupport.kt b/node/src/main/kotlin/com/r3corda/node/utilities/DatabaseSupport.kt index fbbf260a31..ba1ca4fce1 100644 --- a/node/src/main/kotlin/com/r3corda/node/utilities/DatabaseSupport.kt +++ b/node/src/main/kotlin/com/r3corda/node/utilities/DatabaseSupport.kt @@ -35,6 +35,15 @@ fun configureDatabase(props: Properties): Pair { return Pair(dataSource, database) } +fun isolatedTransaction(database: Database, block: Transaction.() -> T): T { + val oldContext = StrandLocalTransactionManager.setThreadLocalTx(null) + return try { + databaseTransaction(database, block) + } finally { + StrandLocalTransactionManager.restoreThreadLocalTx(oldContext) + } +} + /** * A relatively close copy of the [ThreadLocalTransactionManager] in Exposed but with the following adjustments to suit * our environment: @@ -51,6 +60,17 @@ class StrandLocalTransactionManager(initWithDatabase: Database) : TransactionMan private val threadLocalDb = ThreadLocal() private val threadLocalTx = ThreadLocal() + fun setThreadLocalTx(tx: Transaction?): Pair { + val oldTx = threadLocalTx.get() + threadLocalTx.set(tx) + return Pair(threadLocalDb.get(), oldTx) + } + + fun restoreThreadLocalTx(context: Pair) { + threadLocalDb.set(context.first) + threadLocalTx.set(context.second) + } + var database: Database get() = threadLocalDb.get() ?: throw IllegalStateException("Was expecting to find database set on current strand: ${Strand.currentStrand()}") set(value: Database) { diff --git a/node/src/test/kotlin/com/r3corda/node/messaging/TwoPartyTradeProtocolTests.kt b/node/src/test/kotlin/com/r3corda/node/messaging/TwoPartyTradeProtocolTests.kt index 8d65900e6a..86bb6b5ce0 100644 --- a/node/src/test/kotlin/com/r3corda/node/messaging/TwoPartyTradeProtocolTests.kt +++ b/node/src/test/kotlin/com/r3corda/node/messaging/TwoPartyTradeProtocolTests.kt @@ -22,11 +22,13 @@ import com.r3corda.node.services.config.NodeConfiguration import com.r3corda.node.services.persistence.NodeAttachmentService import com.r3corda.node.services.persistence.PerFileTransactionStorage import com.r3corda.node.services.persistence.StorageServiceImpl +import com.r3corda.node.utilities.databaseTransaction import com.r3corda.protocols.TwoPartyTradeProtocol.Buyer import com.r3corda.protocols.TwoPartyTradeProtocol.Seller import com.r3corda.testing.* import com.r3corda.testing.node.InMemoryMessagingNetwork import com.r3corda.testing.node.MockNetwork +import com.r3corda.node.services.persistence.checkpoints import org.assertj.core.api.Assertions.assertThat import org.junit.After import org.junit.Before @@ -83,6 +85,9 @@ class TwoPartyTradeProtocolTests { aliceNode = net.createPartyNode(notaryNode.info.address, ALICE.name, ALICE_KEY) bobNode = net.createPartyNode(notaryNode.info.address, BOB.name, BOB_KEY) + aliceNode.disableDBCloseOnStop() + bobNode.disableDBCloseOnStop() + bobNode.services.fillWithSomeTestCash(2000.DOLLARS) val alicesFakePaper = fillUpForSeller(false, aliceNode.storage.myLegalIdentity.owningKey, 1200.DOLLARS `issued by` DUMMY_CASH_ISSUER, null).second @@ -98,8 +103,14 @@ class TwoPartyTradeProtocolTests { aliceNode.stop() bobNode.stop() - assertThat(aliceNode.checkpointStorage.checkpoints).isEmpty() - assertThat(bobNode.checkpointStorage.checkpoints).isEmpty() + databaseTransaction(aliceNode.database) { + assertThat(aliceNode.checkpointStorage.checkpoints()).isEmpty() + } + aliceNode.manuallyCloseDB() + databaseTransaction(bobNode.database) { + assertThat(bobNode.checkpointStorage.checkpoints()).isEmpty() + } + aliceNode.manuallyCloseDB() } } @@ -135,7 +146,7 @@ class TwoPartyTradeProtocolTests { bobNode.pumpReceive(false) // OK, now Bob has sent the partial transaction back to Alice and is waiting for Alice's signature. - assertThat(bobNode.checkpointStorage.checkpoints).hasSize(1) + assertThat(bobNode.checkpointStorage.checkpoints()).hasSize(1) val bobTransactionsBeforeCrash = (bobNode.storage.validatedTransactions as PerFileTransactionStorage).transactions assertThat(bobTransactionsBeforeCrash).isNotEmpty() @@ -167,8 +178,8 @@ class TwoPartyTradeProtocolTests { assertThat(bobNode.smm.findStateMachines(Buyer::class.java)).isEmpty() - assertThat(bobNode.checkpointStorage.checkpoints).isEmpty() - assertThat(aliceNode.checkpointStorage.checkpoints).isEmpty() + assertThat(bobNode.checkpointStorage.checkpoints()).isEmpty() + assertThat(aliceNode.checkpointStorage.checkpoints()).isEmpty() val restoredBobTransactions = bobTransactionsBeforeCrash.filter { bobNode.storage.validatedTransactions.getTransaction(it.id) != null } assertThat(restoredBobTransactions).containsAll(bobTransactionsBeforeCrash) diff --git a/node/src/test/kotlin/com/r3corda/node/services/persistence/DBCheckpointStorageTests.kt b/node/src/test/kotlin/com/r3corda/node/services/persistence/DBCheckpointStorageTests.kt new file mode 100644 index 0000000000..f2137cf202 --- /dev/null +++ b/node/src/test/kotlin/com/r3corda/node/services/persistence/DBCheckpointStorageTests.kt @@ -0,0 +1,162 @@ +package com.r3corda.node.services.persistence + +import com.google.common.primitives.Ints +import com.r3corda.core.serialization.SerializedBytes +import com.r3corda.core.utilities.LogHelper +import com.r3corda.node.services.api.Checkpoint +import com.r3corda.node.services.api.CheckpointStorage +import com.r3corda.node.services.transactions.PersistentUniquenessProvider +import com.r3corda.node.utilities.configureDatabase +import com.r3corda.node.utilities.databaseTransaction +import com.r3corda.testing.node.makeTestDataSourceProperties +import org.assertj.core.api.Assertions.assertThat +import org.assertj.core.api.Assertions.assertThatExceptionOfType +import org.jetbrains.exposed.sql.Database +import org.junit.After +import org.junit.Before +import org.junit.Test +import java.io.Closeable + +internal fun CheckpointStorage.checkpoints(): List { + val checkpoints = mutableListOf() + forEach { + checkpoints += it + true + } + return checkpoints +} + +class DBCheckpointStorageTests { + lateinit var checkpointStorage: DBCheckpointStorage + lateinit var dataSource: Closeable + lateinit var database: Database + + @Before + fun setUp() { + LogHelper.setLevel(PersistentUniquenessProvider::class) + val dataSourceAndDatabase = configureDatabase(makeTestDataSourceProperties()) + dataSource = dataSourceAndDatabase.first + database = dataSourceAndDatabase.second + newCheckpointStorage() + } + + @After + fun cleanUp() { + dataSource.close() + LogHelper.reset(PersistentUniquenessProvider::class) + } + + @Test + fun `add new checkpoint`() { + val checkpoint = newCheckpoint() + databaseTransaction(database) { + checkpointStorage.addCheckpoint(checkpoint) + } + databaseTransaction(database) { + assertThat(checkpointStorage.checkpoints()).containsExactly(checkpoint) + } + newCheckpointStorage() + databaseTransaction(database) { + assertThat(checkpointStorage.checkpoints()).containsExactly(checkpoint) + } + } + + @Test + fun `remove checkpoint`() { + val checkpoint = newCheckpoint() + databaseTransaction(database) { + checkpointStorage.addCheckpoint(checkpoint) + } + databaseTransaction(database) { + checkpointStorage.removeCheckpoint(checkpoint) + } + databaseTransaction(database) { + assertThat(checkpointStorage.checkpoints()).isEmpty() + } + newCheckpointStorage() + databaseTransaction(database) { + assertThat(checkpointStorage.checkpoints()).isEmpty() + } + } + + @Test + fun `add and remove checkpoint in single commit operate`() { + val checkpoint = newCheckpoint() + val checkpoint2 = newCheckpoint() + databaseTransaction(database) { + checkpointStorage.addCheckpoint(checkpoint) + checkpointStorage.addCheckpoint(checkpoint2) + checkpointStorage.removeCheckpoint(checkpoint) + } + databaseTransaction(database) { + assertThat(checkpointStorage.checkpoints()).containsExactly(checkpoint2) + } + newCheckpointStorage() + databaseTransaction(database) { + assertThat(checkpointStorage.checkpoints()).containsExactly(checkpoint2) + } + } + + @Test + fun `remove unknown checkpoint`() { + val checkpoint = newCheckpoint() + databaseTransaction(database) { + assertThatExceptionOfType(IllegalArgumentException::class.java).isThrownBy { + checkpointStorage.removeCheckpoint(checkpoint) + } + } + } + + @Test + fun `add two checkpoints then remove first one`() { + val firstCheckpoint = newCheckpoint() + databaseTransaction(database) { + checkpointStorage.addCheckpoint(firstCheckpoint) + } + val secondCheckpoint = newCheckpoint() + databaseTransaction(database) { + checkpointStorage.addCheckpoint(secondCheckpoint) + } + databaseTransaction(database) { + checkpointStorage.removeCheckpoint(firstCheckpoint) + } + databaseTransaction(database) { + assertThat(checkpointStorage.checkpoints()).containsExactly(secondCheckpoint) + } + newCheckpointStorage() + databaseTransaction(database) { + assertThat(checkpointStorage.checkpoints()).containsExactly(secondCheckpoint) + } + } + + @Test + fun `add checkpoint and then remove after 'restart'`() { + val originalCheckpoint = newCheckpoint() + databaseTransaction(database) { + checkpointStorage.addCheckpoint(originalCheckpoint) + } + newCheckpointStorage() + val reconstructedCheckpoint = databaseTransaction(database) { + checkpointStorage.checkpoints().single() + } + databaseTransaction(database) { + assertThat(reconstructedCheckpoint).isEqualTo(originalCheckpoint).isNotSameAs(originalCheckpoint) + } + databaseTransaction(database) { + checkpointStorage.removeCheckpoint(reconstructedCheckpoint) + } + databaseTransaction(database) { + assertThat(checkpointStorage.checkpoints()).isEmpty() + } + } + + private fun newCheckpointStorage() { + databaseTransaction(database) { + checkpointStorage = DBCheckpointStorage() + } + } + + private var checkpointCount = 1 + private fun newCheckpoint() = Checkpoint(SerializedBytes(Ints.toByteArray(checkpointCount++))) + +} \ No newline at end of file diff --git a/node/src/test/kotlin/com/r3corda/node/services/persistence/PerFileCheckpointStorageTests.kt b/node/src/test/kotlin/com/r3corda/node/services/persistence/PerFileCheckpointStorageTests.kt index e15046bfc1..65311009e2 100644 --- a/node/src/test/kotlin/com/r3corda/node/services/persistence/PerFileCheckpointStorageTests.kt +++ b/node/src/test/kotlin/com/r3corda/node/services/persistence/PerFileCheckpointStorageTests.kt @@ -34,9 +34,9 @@ class PerFileCheckpointStorageTests { fun `add new checkpoint`() { val checkpoint = newCheckpoint() checkpointStorage.addCheckpoint(checkpoint) - assertThat(checkpointStorage.checkpoints).containsExactly(checkpoint) + assertThat(checkpointStorage.checkpoints()).containsExactly(checkpoint) newCheckpointStorage() - assertThat(checkpointStorage.checkpoints).containsExactly(checkpoint) + assertThat(checkpointStorage.checkpoints()).containsExactly(checkpoint) } @Test @@ -44,9 +44,9 @@ class PerFileCheckpointStorageTests { val checkpoint = newCheckpoint() checkpointStorage.addCheckpoint(checkpoint) checkpointStorage.removeCheckpoint(checkpoint) - assertThat(checkpointStorage.checkpoints).isEmpty() + assertThat(checkpointStorage.checkpoints()).isEmpty() newCheckpointStorage() - assertThat(checkpointStorage.checkpoints).isEmpty() + assertThat(checkpointStorage.checkpoints()).isEmpty() } @Test @@ -64,9 +64,9 @@ class PerFileCheckpointStorageTests { val secondCheckpoint = newCheckpoint() checkpointStorage.addCheckpoint(secondCheckpoint) checkpointStorage.removeCheckpoint(firstCheckpoint) - assertThat(checkpointStorage.checkpoints).containsExactly(secondCheckpoint) + assertThat(checkpointStorage.checkpoints()).containsExactly(secondCheckpoint) newCheckpointStorage() - assertThat(checkpointStorage.checkpoints).containsExactly(secondCheckpoint) + assertThat(checkpointStorage.checkpoints()).containsExactly(secondCheckpoint) } @Test @@ -74,10 +74,10 @@ class PerFileCheckpointStorageTests { val originalCheckpoint = newCheckpoint() checkpointStorage.addCheckpoint(originalCheckpoint) newCheckpointStorage() - val reconstructedCheckpoint = checkpointStorage.checkpoints.single() + val reconstructedCheckpoint = checkpointStorage.checkpoints().single() assertThat(reconstructedCheckpoint).isEqualTo(originalCheckpoint).isNotSameAs(originalCheckpoint) checkpointStorage.removeCheckpoint(reconstructedCheckpoint) - assertThat(checkpointStorage.checkpoints).isEmpty() + assertThat(checkpointStorage.checkpoints()).isEmpty() } @Test @@ -86,7 +86,7 @@ class PerFileCheckpointStorageTests { checkpointStorage.addCheckpoint(checkpoint) Files.write(storeDir.resolve("random-non-checkpoint-file"), "this is not a checkpoint!!".toByteArray()) newCheckpointStorage() - assertThat(checkpointStorage.checkpoints).containsExactly(checkpoint) + assertThat(checkpointStorage.checkpoints()).containsExactly(checkpoint) } private fun newCheckpointStorage() { diff --git a/node/src/test/kotlin/com/r3corda/node/services/statemachine/StateMachineManagerTests.kt b/node/src/test/kotlin/com/r3corda/node/services/statemachine/StateMachineManagerTests.kt index b2796b9153..ad4db9413b 100644 --- a/node/src/test/kotlin/com/r3corda/node/services/statemachine/StateMachineManagerTests.kt +++ b/node/src/test/kotlin/com/r3corda/node/services/statemachine/StateMachineManagerTests.kt @@ -12,6 +12,7 @@ import com.r3corda.core.serialization.deserialize import com.r3corda.node.services.statemachine.StateMachineManager.SessionData import com.r3corda.node.services.statemachine.StateMachineManager.SessionMessage import com.r3corda.testing.node.InMemoryMessagingNetwork +import com.r3corda.node.services.persistence.checkpoints import com.r3corda.testing.node.MockNetwork import com.r3corda.testing.node.MockNetwork.MockNode import org.assertj.core.api.Assertions.assertThat @@ -135,13 +136,13 @@ class StateMachineManagerTests { // Kick off first send and receive node2.smm.add("test", PingPongProtocol(node3.info.identity, payload)) - assertEquals(1, node2.checkpointStorage.checkpoints.count()) + assertEquals(1, node2.checkpointStorage.checkpoints().count()) // Restart node and thus reload the checkpoint and resend the message with same UUID node2.stop() val node2b = net.createNode(node1.info.address, node2.id, advertisedServices = *node2.advertisedServices.toTypedArray()) val (firstAgain, fut1) = node2b.getSingleProtocol() net.runNetwork() - assertEquals(1, node2.checkpointStorage.checkpoints.count()) + assertEquals(1, node2.checkpointStorage.checkpoints().count()) // Run the network which will also fire up the second protocol. First message should get deduped. So message data stays in sync. net.runNetwork() node2b.smm.executor.flush() @@ -149,8 +150,8 @@ class StateMachineManagerTests { // Check protocols completed cleanly and didn't get out of phase assertEquals(4, receivedCount, "Protocol should have exchanged 4 unique messages")// Two messages each way assertTrue(sentCount > receivedCount, "Node restart should have retransmitted messages") // can't give a precise value as every addMessageHandler re-runs the undelivered messages - assertEquals(0, node2b.checkpointStorage.checkpoints.count(), "Checkpoints left after restored protocol should have ended") - assertEquals(0, node3.checkpointStorage.checkpoints.count(), "Checkpoints left after restored protocol should have ended") + assertEquals(0, node2b.checkpointStorage.checkpoints().count(), "Checkpoints left after restored protocol should have ended") + assertEquals(0, node3.checkpointStorage.checkpoints().count(), "Checkpoints left after restored protocol should have ended") assertEquals(payload2, firstAgain.receivedPayload, "Received payload does not match the first value on Node 3") assertEquals(payload2 + 1, firstAgain.receivedPayload2, "Received payload does not match the expected second value on Node 3") assertEquals(payload, secondProtocol!!.receivedPayload, "Received payload does not match the (restarted) first value on Node 2") diff --git a/test-utils/src/main/kotlin/com/r3corda/testing/node/MockNode.kt b/test-utils/src/main/kotlin/com/r3corda/testing/node/MockNode.kt index 13029f8d2d..4943d07600 100644 --- a/test-utils/src/main/kotlin/com/r3corda/testing/node/MockNode.kt +++ b/test-utils/src/main/kotlin/com/r3corda/testing/node/MockNode.kt @@ -6,6 +6,7 @@ import com.google.common.util.concurrent.Futures import com.google.common.util.concurrent.ListenableFuture import com.google.common.util.concurrent.SettableFuture import com.r3corda.core.crypto.Party +import com.r3corda.core.div import com.r3corda.core.messaging.SingleMessageRecipient import com.r3corda.core.messaging.TopicSession import com.r3corda.core.messaging.runOnNextMessage @@ -20,10 +21,13 @@ import com.r3corda.core.testing.InMemoryVaultService import com.r3corda.core.utilities.DUMMY_NOTARY_KEY import com.r3corda.core.utilities.loggerFor import com.r3corda.node.internal.AbstractNode +import com.r3corda.node.services.api.CheckpointStorage import com.r3corda.node.services.config.NodeConfiguration import com.r3corda.node.services.keys.E2ETestKeyManagementService import com.r3corda.node.services.messaging.CordaRPCOps import com.r3corda.node.services.network.InMemoryNetworkMapService +import com.r3corda.node.services.persistence.DBCheckpointStorage +import com.r3corda.node.services.persistence.PerFileCheckpointStorage import com.r3corda.node.services.transactions.InMemoryUniquenessProvider import com.r3corda.node.utilities.databaseTransaction import com.r3corda.protocols.ServiceRequestMessage @@ -97,6 +101,14 @@ class MockNetwork(private val networkSendManuallyPumped: Boolean = false, persistenceTx = { body: () -> Unit -> databaseTransaction(database) { body() } }).start().get() } + override fun initialiseCheckpointService(dir: Path): CheckpointStorage { + return if (mockNet.threadPerNode) { + DBCheckpointStorage() + } else { + PerFileCheckpointStorage(dir / "checkpoints") + } + } + override fun makeIdentityService() = MockIdentityService(mockNet.identities) override fun makeVaultService(): VaultService = InMemoryVaultService(services) @@ -153,6 +165,15 @@ class MockNetwork(private val networkSendManuallyPumped: Boolean = false, send(topic, target, payload) return receive(topic, payload.sessionID) } + + fun disableDBCloseOnStop() { + runOnStop.remove(dbCloser) + } + + fun manuallyCloseDB() { + dbCloser?.run() + dbCloser = null + } } /** Returns a node, optionally created by the passed factory method. */