Store protocol checkpoints in the DB, except during Single threaded MockNetwork activity, where we still use the file system based checkpointing. Make the Checkpoint acess a Sequence not an Iterable, so that we don't end up with all the checkpoints permanently resident in memory.

Split up storage initialisation so that there is less code copying in MockNode

Add header comment to DBCheckpointStorage class

Respond to PR comments

Resolve PR comments

Rename iterator on checkpoints

Fix typo

Fixup checkpoints in DB logic after Shams's PR

Delete duplicated code
This commit is contained in:
Matthew Nesbit 2016-09-28 15:18:22 +01:00
parent 67fdf9b2ff
commit 76808d36c3
12 changed files with 325 additions and 38 deletions

View File

@ -151,6 +151,7 @@ abstract class AbstractNode(val configuration: NodeConfiguration, val networkMap
val customServices: ArrayList<Any> = ArrayList()
protected val runOnStop: ArrayList<Runnable> = ArrayList()
lateinit var database: Database
protected var dbCloser: Runnable? = null
/** Locates and returns a service of the given type if loaded, or throws an exception if not found. */
inline fun <reified T : Any> findService() = customServices.filterIsInstance<T>().single()
@ -238,13 +239,16 @@ abstract class AbstractNode(val configuration: NodeConfiguration, val networkMap
// Add vault observers
CashBalanceAsMetricsObserver(services)
ScheduledActivityObserver(services)
}
startMessagingService(ServerRPCOps(services, smm, database))
runOnStop += Runnable { net.stop() }
_networkMapRegistrationFuture.setFuture(registerWithNetworkMap())
isPreviousCheckpointsPresent = checkpointStorage.checkpoints.any()
smm.start()
checkpointStorage.forEach {
isPreviousCheckpointsPresent = true
false
}
startMessagingService(ServerRPCOps(services, smm, database))
runOnStop += Runnable { net.stop() }
_networkMapRegistrationFuture.setFuture(registerWithNetworkMap())
smm.start()
}
started = true
return this
}
@ -269,7 +273,8 @@ abstract class AbstractNode(val configuration: NodeConfiguration, val networkMap
this.database = database
// Now log the vendor string as this will also cause a connection to be tested eagerly.
log.info("Connected to ${database.vendor} database.")
runOnStop += Runnable { toClose.close() }
dbCloser = Runnable { toClose.close() }
runOnStop += dbCloser!!
databaseTransaction(database) {
insideTransaction()
}
@ -429,9 +434,13 @@ abstract class AbstractNode(val configuration: NodeConfiguration, val networkMap
protected abstract fun startMessagingService(cordaRPCOps: CordaRPCOps?)
protected open fun initialiseCheckpointService(dir: Path): CheckpointStorage {
return DBCheckpointStorage()
}
protected open fun initialiseStorageService(dir: Path): Pair<TxWritableStorageService, CheckpointStorage> {
val attachments = makeAttachmentStorage(dir)
val checkpointStorage = PerFileCheckpointStorage(dir.resolve("checkpoints"))
val checkpointStorage = initialiseCheckpointService(dir)
val transactionStorage = PerFileTransactionStorage(dir.resolve("transactions"))
_servicesThatAcceptUploads += attachments
val (identity, keyPair) = obtainKeyPair(dir)

View File

@ -21,11 +21,11 @@ interface CheckpointStorage {
fun removeCheckpoint(checkpoint: Checkpoint)
/**
* Returns a snapshot of all the checkpoints in the store.
* This may return more checkpoints than were added to this instance of the store; for example if the store persists
* checkpoints to disk.
* Allows the caller to process safely in a thread safe fashion the set of all checkpoints.
* The checkpoints are only valid during the lifetime of a single call to the block, to allow memory management.
* Return false from the block to terminate further iteration.
*/
val checkpoints: Iterable<Checkpoint>
fun forEach(block: (Checkpoint)->Boolean)
}

View File

@ -0,0 +1,39 @@
package com.r3corda.node.services.persistence
import com.r3corda.core.crypto.SecureHash
import com.r3corda.core.serialization.SerializedBytes
import com.r3corda.core.serialization.deserialize
import com.r3corda.core.serialization.serialize
import com.r3corda.node.services.api.Checkpoint
import com.r3corda.node.services.api.CheckpointStorage
import com.r3corda.node.utilities.JDBCHashMap
import java.util.*
/**
* Simple checkpoint key value storage in DB using the underlying JDBCHashMap and transactional context of the call sites.
*/
class DBCheckpointStorage : CheckpointStorage {
private val checkpointStorage = Collections.synchronizedMap(JDBCHashMap<SecureHash, SerializedBytes<Checkpoint>>("checkpoints", loadOnInit = false))
override fun addCheckpoint(checkpoint: Checkpoint) {
val serialisedCheckpoint = checkpoint.serialize()
val id = serialisedCheckpoint.hash
checkpointStorage.put(id, serialisedCheckpoint)
}
override fun removeCheckpoint(checkpoint: Checkpoint) {
val serialisedCheckpoint = checkpoint.serialize()
val id = serialisedCheckpoint.hash
checkpointStorage.remove(id) ?: throw IllegalArgumentException("Checkpoint not found")
}
override fun forEach(block: (Checkpoint) -> Boolean) {
synchronized(checkpointStorage) {
for (checkpoint in checkpointStorage.values) {
if (!block(checkpoint.deserialize())) {
break
}
}
}
}
}

View File

@ -60,9 +60,14 @@ class PerFileCheckpointStorage(val storeDir: Path) : CheckpointStorage {
logger.trace { "Removed $checkpoint ($checkpointFile)" }
}
override val checkpoints: Iterable<Checkpoint>
get() = synchronized(checkpointFiles) {
checkpointFiles.keys.toList()
override fun forEach(block: (Checkpoint)->Boolean) {
synchronized(checkpointFiles) {
for(checkpoint in checkpointFiles.keys) {
if (!block(checkpoint)) {
break
}
}
}
}
}

View File

@ -17,8 +17,10 @@ import com.r3corda.core.utilities.UntrustworthyData
import com.r3corda.core.utilities.trace
import com.r3corda.node.services.api.ServiceHubInternal
import com.r3corda.node.services.statemachine.StateMachineManager.*
import com.r3corda.node.utilities.StrandLocalTransactionManager
import com.r3corda.node.utilities.createDatabaseTransaction
import org.jetbrains.exposed.sql.Database
import org.jetbrains.exposed.sql.Transaction
import org.jetbrains.exposed.sql.transactions.TransactionManager
import org.slf4j.Logger
import org.slf4j.LoggerFactory
@ -62,6 +64,7 @@ class ProtocolStateMachineImpl<R>(override val id: StateMachineRunId,
@Transient internal lateinit var actionOnEnd: () -> Unit
@Transient internal lateinit var database: Database
@Transient internal var fromCheckpoint: Boolean = false
@Transient internal var txTrampoline: Transaction? = null
@Transient private var _logger: Logger? = null
override val logger: Logger get() {
@ -113,7 +116,7 @@ class ProtocolStateMachineImpl<R>(override val id: StateMachineRunId,
logger.trace { "Starting database transaction ${TransactionManager.currentOrNull()} on ${Strand.currentStrand()}." }
}
private fun commitTransaction() {
internal fun commitTransaction() {
val transaction = TransactionManager.current()
try {
logger.trace { "Commiting database transaction $transaction on ${Strand.currentStrand()}." }
@ -210,7 +213,7 @@ class ProtocolStateMachineImpl<R>(override val id: StateMachineRunId,
if (receivedMessage is SessionEnd) {
openSessions.values.remove(receiveRequest.session)
throw ProtocolSessionException("Counterparty on ${receiveRequest.session.otherParty} has prematurly ended")
throw ProtocolSessionException("Counterparty on ${receiveRequest.session.otherParty} has prematurely ended")
} else if (receiveRequest.receiveType.isInstance(receivedMessage)) {
return receiveRequest.receiveType.cast(receivedMessage)
} else {
@ -220,9 +223,14 @@ class ProtocolStateMachineImpl<R>(override val id: StateMachineRunId,
@Suspendable
private fun suspend(ioRequest: ProtocolIORequest) {
commitTransaction()
// we have to pass the Thread local Transaction across via a transient field as the Fiber Park swaps them out.
txTrampoline = TransactionManager.currentOrNull()
StrandLocalTransactionManager.setThreadLocalTx(null)
parkAndSerialize { fiber, serializer ->
logger.trace { "Suspended $id on $ioRequest" }
// restore the Tx onto the ThreadLocal so that we can commit the ensuing checkpoint to the DB
StrandLocalTransactionManager.setThreadLocalTx(txTrampoline)
txTrampoline = null
try {
actionOnSuspend(ioRequest)
} catch (t: Throwable) {

View File

@ -28,6 +28,7 @@ import com.r3corda.node.services.api.ServiceHubInternal
import com.r3corda.node.utilities.AddOrRemove
import com.r3corda.node.utilities.AffinityExecutor
import kotlinx.support.jdk8.collections.removeIf
import com.r3corda.node.utilities.isolatedTransaction
import org.jetbrains.exposed.sql.Database
import rx.Observable
import rx.subjects.PublishSubject
@ -157,13 +158,14 @@ class StateMachineManager(val serviceHub: ServiceHubInternal,
private fun restoreFibersFromCheckpoints() {
mutex.locked {
checkpointStorage.checkpoints.forEach {
checkpointStorage.forEach {
// If a protocol is added before start() then don't attempt to restore it
if (!stateMachines.containsValue(it)) {
val fiber = deserializeFiber(it.serialisedFiber)
initFiber(fiber)
stateMachines[fiber] = it
}
true
}
}
}
@ -276,6 +278,9 @@ class StateMachineManager(val serviceHub: ServiceHubInternal,
psm.serviceHub = serviceHub
psm.actionOnSuspend = { ioRequest ->
updateCheckpoint(psm)
// We commit on the fibers transaction that was copied across ThreadLocals during suspend
// This will free up the ThreadLocal so on return the caller can carry on with other transactions
psm.commitTransaction()
processIORequest(ioRequest)
}
psm.actionOnEnd = {
@ -334,7 +339,13 @@ class StateMachineManager(val serviceHub: ServiceHubInternal,
*/
fun <T> add(loggerName: String, logic: ProtocolLogic<T>): ProtocolStateMachine<T> {
val fiber = createFiber(loggerName, logic)
updateCheckpoint(fiber)
// We swap out the parent transaction context as using this frequently leads to a deadlock as we wait
// on the protocol completion future inside that context. The problem is that any progress checkpoints are
// unable to acquire the table lock and move forward till the calling transaction finishes.
// Committing in line here on a fresh context ensure we can progress.
isolatedTransaction(database) {
updateCheckpoint(fiber)
}
// If we are not started then our checkpoint will be picked up during start
mutex.locked {
if (started) {

View File

@ -35,6 +35,15 @@ fun configureDatabase(props: Properties): Pair<Closeable, Database> {
return Pair(dataSource, database)
}
fun <T> isolatedTransaction(database: Database, block: Transaction.() -> T): T {
val oldContext = StrandLocalTransactionManager.setThreadLocalTx(null)
return try {
databaseTransaction(database, block)
} finally {
StrandLocalTransactionManager.restoreThreadLocalTx(oldContext)
}
}
/**
* A relatively close copy of the [ThreadLocalTransactionManager] in Exposed but with the following adjustments to suit
* our environment:
@ -51,6 +60,17 @@ class StrandLocalTransactionManager(initWithDatabase: Database) : TransactionMan
private val threadLocalDb = ThreadLocal<Database>()
private val threadLocalTx = ThreadLocal<Transaction>()
fun setThreadLocalTx(tx: Transaction?): Pair<Database?, Transaction?> {
val oldTx = threadLocalTx.get()
threadLocalTx.set(tx)
return Pair(threadLocalDb.get(), oldTx)
}
fun restoreThreadLocalTx(context: Pair<Database?, Transaction?>) {
threadLocalDb.set(context.first)
threadLocalTx.set(context.second)
}
var database: Database
get() = threadLocalDb.get() ?: throw IllegalStateException("Was expecting to find database set on current strand: ${Strand.currentStrand()}")
set(value: Database) {

View File

@ -22,11 +22,13 @@ import com.r3corda.node.services.config.NodeConfiguration
import com.r3corda.node.services.persistence.NodeAttachmentService
import com.r3corda.node.services.persistence.PerFileTransactionStorage
import com.r3corda.node.services.persistence.StorageServiceImpl
import com.r3corda.node.utilities.databaseTransaction
import com.r3corda.protocols.TwoPartyTradeProtocol.Buyer
import com.r3corda.protocols.TwoPartyTradeProtocol.Seller
import com.r3corda.testing.*
import com.r3corda.testing.node.InMemoryMessagingNetwork
import com.r3corda.testing.node.MockNetwork
import com.r3corda.node.services.persistence.checkpoints
import org.assertj.core.api.Assertions.assertThat
import org.junit.After
import org.junit.Before
@ -83,6 +85,9 @@ class TwoPartyTradeProtocolTests {
aliceNode = net.createPartyNode(notaryNode.info.address, ALICE.name, ALICE_KEY)
bobNode = net.createPartyNode(notaryNode.info.address, BOB.name, BOB_KEY)
aliceNode.disableDBCloseOnStop()
bobNode.disableDBCloseOnStop()
bobNode.services.fillWithSomeTestCash(2000.DOLLARS)
val alicesFakePaper = fillUpForSeller(false, aliceNode.storage.myLegalIdentity.owningKey,
1200.DOLLARS `issued by` DUMMY_CASH_ISSUER, null).second
@ -98,8 +103,14 @@ class TwoPartyTradeProtocolTests {
aliceNode.stop()
bobNode.stop()
assertThat(aliceNode.checkpointStorage.checkpoints).isEmpty()
assertThat(bobNode.checkpointStorage.checkpoints).isEmpty()
databaseTransaction(aliceNode.database) {
assertThat(aliceNode.checkpointStorage.checkpoints()).isEmpty()
}
aliceNode.manuallyCloseDB()
databaseTransaction(bobNode.database) {
assertThat(bobNode.checkpointStorage.checkpoints()).isEmpty()
}
aliceNode.manuallyCloseDB()
}
}
@ -135,7 +146,7 @@ class TwoPartyTradeProtocolTests {
bobNode.pumpReceive(false)
// OK, now Bob has sent the partial transaction back to Alice and is waiting for Alice's signature.
assertThat(bobNode.checkpointStorage.checkpoints).hasSize(1)
assertThat(bobNode.checkpointStorage.checkpoints()).hasSize(1)
val bobTransactionsBeforeCrash = (bobNode.storage.validatedTransactions as PerFileTransactionStorage).transactions
assertThat(bobTransactionsBeforeCrash).isNotEmpty()
@ -167,8 +178,8 @@ class TwoPartyTradeProtocolTests {
assertThat(bobNode.smm.findStateMachines(Buyer::class.java)).isEmpty()
assertThat(bobNode.checkpointStorage.checkpoints).isEmpty()
assertThat(aliceNode.checkpointStorage.checkpoints).isEmpty()
assertThat(bobNode.checkpointStorage.checkpoints()).isEmpty()
assertThat(aliceNode.checkpointStorage.checkpoints()).isEmpty()
val restoredBobTransactions = bobTransactionsBeforeCrash.filter { bobNode.storage.validatedTransactions.getTransaction(it.id) != null }
assertThat(restoredBobTransactions).containsAll(bobTransactionsBeforeCrash)

View File

@ -0,0 +1,162 @@
package com.r3corda.node.services.persistence
import com.google.common.primitives.Ints
import com.r3corda.core.serialization.SerializedBytes
import com.r3corda.core.utilities.LogHelper
import com.r3corda.node.services.api.Checkpoint
import com.r3corda.node.services.api.CheckpointStorage
import com.r3corda.node.services.transactions.PersistentUniquenessProvider
import com.r3corda.node.utilities.configureDatabase
import com.r3corda.node.utilities.databaseTransaction
import com.r3corda.testing.node.makeTestDataSourceProperties
import org.assertj.core.api.Assertions.assertThat
import org.assertj.core.api.Assertions.assertThatExceptionOfType
import org.jetbrains.exposed.sql.Database
import org.junit.After
import org.junit.Before
import org.junit.Test
import java.io.Closeable
internal fun CheckpointStorage.checkpoints(): List<Checkpoint> {
val checkpoints = mutableListOf<Checkpoint>()
forEach {
checkpoints += it
true
}
return checkpoints
}
class DBCheckpointStorageTests {
lateinit var checkpointStorage: DBCheckpointStorage
lateinit var dataSource: Closeable
lateinit var database: Database
@Before
fun setUp() {
LogHelper.setLevel(PersistentUniquenessProvider::class)
val dataSourceAndDatabase = configureDatabase(makeTestDataSourceProperties())
dataSource = dataSourceAndDatabase.first
database = dataSourceAndDatabase.second
newCheckpointStorage()
}
@After
fun cleanUp() {
dataSource.close()
LogHelper.reset(PersistentUniquenessProvider::class)
}
@Test
fun `add new checkpoint`() {
val checkpoint = newCheckpoint()
databaseTransaction(database) {
checkpointStorage.addCheckpoint(checkpoint)
}
databaseTransaction(database) {
assertThat(checkpointStorage.checkpoints()).containsExactly(checkpoint)
}
newCheckpointStorage()
databaseTransaction(database) {
assertThat(checkpointStorage.checkpoints()).containsExactly(checkpoint)
}
}
@Test
fun `remove checkpoint`() {
val checkpoint = newCheckpoint()
databaseTransaction(database) {
checkpointStorage.addCheckpoint(checkpoint)
}
databaseTransaction(database) {
checkpointStorage.removeCheckpoint(checkpoint)
}
databaseTransaction(database) {
assertThat(checkpointStorage.checkpoints()).isEmpty()
}
newCheckpointStorage()
databaseTransaction(database) {
assertThat(checkpointStorage.checkpoints()).isEmpty()
}
}
@Test
fun `add and remove checkpoint in single commit operate`() {
val checkpoint = newCheckpoint()
val checkpoint2 = newCheckpoint()
databaseTransaction(database) {
checkpointStorage.addCheckpoint(checkpoint)
checkpointStorage.addCheckpoint(checkpoint2)
checkpointStorage.removeCheckpoint(checkpoint)
}
databaseTransaction(database) {
assertThat(checkpointStorage.checkpoints()).containsExactly(checkpoint2)
}
newCheckpointStorage()
databaseTransaction(database) {
assertThat(checkpointStorage.checkpoints()).containsExactly(checkpoint2)
}
}
@Test
fun `remove unknown checkpoint`() {
val checkpoint = newCheckpoint()
databaseTransaction(database) {
assertThatExceptionOfType(IllegalArgumentException::class.java).isThrownBy {
checkpointStorage.removeCheckpoint(checkpoint)
}
}
}
@Test
fun `add two checkpoints then remove first one`() {
val firstCheckpoint = newCheckpoint()
databaseTransaction(database) {
checkpointStorage.addCheckpoint(firstCheckpoint)
}
val secondCheckpoint = newCheckpoint()
databaseTransaction(database) {
checkpointStorage.addCheckpoint(secondCheckpoint)
}
databaseTransaction(database) {
checkpointStorage.removeCheckpoint(firstCheckpoint)
}
databaseTransaction(database) {
assertThat(checkpointStorage.checkpoints()).containsExactly(secondCheckpoint)
}
newCheckpointStorage()
databaseTransaction(database) {
assertThat(checkpointStorage.checkpoints()).containsExactly(secondCheckpoint)
}
}
@Test
fun `add checkpoint and then remove after 'restart'`() {
val originalCheckpoint = newCheckpoint()
databaseTransaction(database) {
checkpointStorage.addCheckpoint(originalCheckpoint)
}
newCheckpointStorage()
val reconstructedCheckpoint = databaseTransaction(database) {
checkpointStorage.checkpoints().single()
}
databaseTransaction(database) {
assertThat(reconstructedCheckpoint).isEqualTo(originalCheckpoint).isNotSameAs(originalCheckpoint)
}
databaseTransaction(database) {
checkpointStorage.removeCheckpoint(reconstructedCheckpoint)
}
databaseTransaction(database) {
assertThat(checkpointStorage.checkpoints()).isEmpty()
}
}
private fun newCheckpointStorage() {
databaseTransaction(database) {
checkpointStorage = DBCheckpointStorage()
}
}
private var checkpointCount = 1
private fun newCheckpoint() = Checkpoint(SerializedBytes(Ints.toByteArray(checkpointCount++)))
}

View File

@ -34,9 +34,9 @@ class PerFileCheckpointStorageTests {
fun `add new checkpoint`() {
val checkpoint = newCheckpoint()
checkpointStorage.addCheckpoint(checkpoint)
assertThat(checkpointStorage.checkpoints).containsExactly(checkpoint)
assertThat(checkpointStorage.checkpoints()).containsExactly(checkpoint)
newCheckpointStorage()
assertThat(checkpointStorage.checkpoints).containsExactly(checkpoint)
assertThat(checkpointStorage.checkpoints()).containsExactly(checkpoint)
}
@Test
@ -44,9 +44,9 @@ class PerFileCheckpointStorageTests {
val checkpoint = newCheckpoint()
checkpointStorage.addCheckpoint(checkpoint)
checkpointStorage.removeCheckpoint(checkpoint)
assertThat(checkpointStorage.checkpoints).isEmpty()
assertThat(checkpointStorage.checkpoints()).isEmpty()
newCheckpointStorage()
assertThat(checkpointStorage.checkpoints).isEmpty()
assertThat(checkpointStorage.checkpoints()).isEmpty()
}
@Test
@ -64,9 +64,9 @@ class PerFileCheckpointStorageTests {
val secondCheckpoint = newCheckpoint()
checkpointStorage.addCheckpoint(secondCheckpoint)
checkpointStorage.removeCheckpoint(firstCheckpoint)
assertThat(checkpointStorage.checkpoints).containsExactly(secondCheckpoint)
assertThat(checkpointStorage.checkpoints()).containsExactly(secondCheckpoint)
newCheckpointStorage()
assertThat(checkpointStorage.checkpoints).containsExactly(secondCheckpoint)
assertThat(checkpointStorage.checkpoints()).containsExactly(secondCheckpoint)
}
@Test
@ -74,10 +74,10 @@ class PerFileCheckpointStorageTests {
val originalCheckpoint = newCheckpoint()
checkpointStorage.addCheckpoint(originalCheckpoint)
newCheckpointStorage()
val reconstructedCheckpoint = checkpointStorage.checkpoints.single()
val reconstructedCheckpoint = checkpointStorage.checkpoints().single()
assertThat(reconstructedCheckpoint).isEqualTo(originalCheckpoint).isNotSameAs(originalCheckpoint)
checkpointStorage.removeCheckpoint(reconstructedCheckpoint)
assertThat(checkpointStorage.checkpoints).isEmpty()
assertThat(checkpointStorage.checkpoints()).isEmpty()
}
@Test
@ -86,7 +86,7 @@ class PerFileCheckpointStorageTests {
checkpointStorage.addCheckpoint(checkpoint)
Files.write(storeDir.resolve("random-non-checkpoint-file"), "this is not a checkpoint!!".toByteArray())
newCheckpointStorage()
assertThat(checkpointStorage.checkpoints).containsExactly(checkpoint)
assertThat(checkpointStorage.checkpoints()).containsExactly(checkpoint)
}
private fun newCheckpointStorage() {

View File

@ -12,6 +12,7 @@ import com.r3corda.core.serialization.deserialize
import com.r3corda.node.services.statemachine.StateMachineManager.SessionData
import com.r3corda.node.services.statemachine.StateMachineManager.SessionMessage
import com.r3corda.testing.node.InMemoryMessagingNetwork
import com.r3corda.node.services.persistence.checkpoints
import com.r3corda.testing.node.MockNetwork
import com.r3corda.testing.node.MockNetwork.MockNode
import org.assertj.core.api.Assertions.assertThat
@ -135,13 +136,13 @@ class StateMachineManagerTests {
// Kick off first send and receive
node2.smm.add("test", PingPongProtocol(node3.info.identity, payload))
assertEquals(1, node2.checkpointStorage.checkpoints.count())
assertEquals(1, node2.checkpointStorage.checkpoints().count())
// Restart node and thus reload the checkpoint and resend the message with same UUID
node2.stop()
val node2b = net.createNode(node1.info.address, node2.id, advertisedServices = *node2.advertisedServices.toTypedArray())
val (firstAgain, fut1) = node2b.getSingleProtocol<PingPongProtocol>()
net.runNetwork()
assertEquals(1, node2.checkpointStorage.checkpoints.count())
assertEquals(1, node2.checkpointStorage.checkpoints().count())
// Run the network which will also fire up the second protocol. First message should get deduped. So message data stays in sync.
net.runNetwork()
node2b.smm.executor.flush()
@ -149,8 +150,8 @@ class StateMachineManagerTests {
// Check protocols completed cleanly and didn't get out of phase
assertEquals(4, receivedCount, "Protocol should have exchanged 4 unique messages")// Two messages each way
assertTrue(sentCount > receivedCount, "Node restart should have retransmitted messages") // can't give a precise value as every addMessageHandler re-runs the undelivered messages
assertEquals(0, node2b.checkpointStorage.checkpoints.count(), "Checkpoints left after restored protocol should have ended")
assertEquals(0, node3.checkpointStorage.checkpoints.count(), "Checkpoints left after restored protocol should have ended")
assertEquals(0, node2b.checkpointStorage.checkpoints().count(), "Checkpoints left after restored protocol should have ended")
assertEquals(0, node3.checkpointStorage.checkpoints().count(), "Checkpoints left after restored protocol should have ended")
assertEquals(payload2, firstAgain.receivedPayload, "Received payload does not match the first value on Node 3")
assertEquals(payload2 + 1, firstAgain.receivedPayload2, "Received payload does not match the expected second value on Node 3")
assertEquals(payload, secondProtocol!!.receivedPayload, "Received payload does not match the (restarted) first value on Node 2")

View File

@ -6,6 +6,7 @@ import com.google.common.util.concurrent.Futures
import com.google.common.util.concurrent.ListenableFuture
import com.google.common.util.concurrent.SettableFuture
import com.r3corda.core.crypto.Party
import com.r3corda.core.div
import com.r3corda.core.messaging.SingleMessageRecipient
import com.r3corda.core.messaging.TopicSession
import com.r3corda.core.messaging.runOnNextMessage
@ -20,10 +21,13 @@ import com.r3corda.core.testing.InMemoryVaultService
import com.r3corda.core.utilities.DUMMY_NOTARY_KEY
import com.r3corda.core.utilities.loggerFor
import com.r3corda.node.internal.AbstractNode
import com.r3corda.node.services.api.CheckpointStorage
import com.r3corda.node.services.config.NodeConfiguration
import com.r3corda.node.services.keys.E2ETestKeyManagementService
import com.r3corda.node.services.messaging.CordaRPCOps
import com.r3corda.node.services.network.InMemoryNetworkMapService
import com.r3corda.node.services.persistence.DBCheckpointStorage
import com.r3corda.node.services.persistence.PerFileCheckpointStorage
import com.r3corda.node.services.transactions.InMemoryUniquenessProvider
import com.r3corda.node.utilities.databaseTransaction
import com.r3corda.protocols.ServiceRequestMessage
@ -97,6 +101,14 @@ class MockNetwork(private val networkSendManuallyPumped: Boolean = false,
persistenceTx = { body: () -> Unit -> databaseTransaction(database) { body() } }).start().get()
}
override fun initialiseCheckpointService(dir: Path): CheckpointStorage {
return if (mockNet.threadPerNode) {
DBCheckpointStorage()
} else {
PerFileCheckpointStorage(dir / "checkpoints")
}
}
override fun makeIdentityService() = MockIdentityService(mockNet.identities)
override fun makeVaultService(): VaultService = InMemoryVaultService(services)
@ -153,6 +165,15 @@ class MockNetwork(private val networkSendManuallyPumped: Boolean = false,
send(topic, target, payload)
return receive(topic, payload.sessionID)
}
fun disableDBCloseOnStop() {
runOnStop.remove(dbCloser)
}
fun manuallyCloseDB() {
dbCloser?.run()
dbCloser = null
}
}
/** Returns a node, optionally created by the passed factory method. */