mirror of
https://github.com/corda/corda.git
synced 2025-06-13 04:38:19 +00:00
Merged in mnesbit-cor-389-checkpoints-in-db (pull request #385)
Store protocol checkpoints in the DB
This commit is contained in:
@ -151,6 +151,7 @@ abstract class AbstractNode(val configuration: NodeConfiguration, val networkMap
|
|||||||
val customServices: ArrayList<Any> = ArrayList()
|
val customServices: ArrayList<Any> = ArrayList()
|
||||||
protected val runOnStop: ArrayList<Runnable> = ArrayList()
|
protected val runOnStop: ArrayList<Runnable> = ArrayList()
|
||||||
lateinit var database: Database
|
lateinit var database: Database
|
||||||
|
protected var dbCloser: Runnable? = null
|
||||||
|
|
||||||
/** Locates and returns a service of the given type if loaded, or throws an exception if not found. */
|
/** Locates and returns a service of the given type if loaded, or throws an exception if not found. */
|
||||||
inline fun <reified T : Any> findService() = customServices.filterIsInstance<T>().single()
|
inline fun <reified T : Any> findService() = customServices.filterIsInstance<T>().single()
|
||||||
@ -238,13 +239,16 @@ abstract class AbstractNode(val configuration: NodeConfiguration, val networkMap
|
|||||||
// Add vault observers
|
// Add vault observers
|
||||||
CashBalanceAsMetricsObserver(services)
|
CashBalanceAsMetricsObserver(services)
|
||||||
ScheduledActivityObserver(services)
|
ScheduledActivityObserver(services)
|
||||||
}
|
|
||||||
|
|
||||||
|
checkpointStorage.forEach {
|
||||||
|
isPreviousCheckpointsPresent = true
|
||||||
|
false
|
||||||
|
}
|
||||||
startMessagingService(ServerRPCOps(services, smm, database))
|
startMessagingService(ServerRPCOps(services, smm, database))
|
||||||
runOnStop += Runnable { net.stop() }
|
runOnStop += Runnable { net.stop() }
|
||||||
_networkMapRegistrationFuture.setFuture(registerWithNetworkMap())
|
_networkMapRegistrationFuture.setFuture(registerWithNetworkMap())
|
||||||
isPreviousCheckpointsPresent = checkpointStorage.checkpoints.any()
|
|
||||||
smm.start()
|
smm.start()
|
||||||
|
}
|
||||||
started = true
|
started = true
|
||||||
return this
|
return this
|
||||||
}
|
}
|
||||||
@ -269,7 +273,8 @@ abstract class AbstractNode(val configuration: NodeConfiguration, val networkMap
|
|||||||
this.database = database
|
this.database = database
|
||||||
// Now log the vendor string as this will also cause a connection to be tested eagerly.
|
// Now log the vendor string as this will also cause a connection to be tested eagerly.
|
||||||
log.info("Connected to ${database.vendor} database.")
|
log.info("Connected to ${database.vendor} database.")
|
||||||
runOnStop += Runnable { toClose.close() }
|
dbCloser = Runnable { toClose.close() }
|
||||||
|
runOnStop += dbCloser!!
|
||||||
databaseTransaction(database) {
|
databaseTransaction(database) {
|
||||||
insideTransaction()
|
insideTransaction()
|
||||||
}
|
}
|
||||||
@ -429,9 +434,13 @@ abstract class AbstractNode(val configuration: NodeConfiguration, val networkMap
|
|||||||
|
|
||||||
protected abstract fun startMessagingService(cordaRPCOps: CordaRPCOps?)
|
protected abstract fun startMessagingService(cordaRPCOps: CordaRPCOps?)
|
||||||
|
|
||||||
|
protected open fun initialiseCheckpointService(dir: Path): CheckpointStorage {
|
||||||
|
return DBCheckpointStorage()
|
||||||
|
}
|
||||||
|
|
||||||
protected open fun initialiseStorageService(dir: Path): Pair<TxWritableStorageService, CheckpointStorage> {
|
protected open fun initialiseStorageService(dir: Path): Pair<TxWritableStorageService, CheckpointStorage> {
|
||||||
val attachments = makeAttachmentStorage(dir)
|
val attachments = makeAttachmentStorage(dir)
|
||||||
val checkpointStorage = PerFileCheckpointStorage(dir.resolve("checkpoints"))
|
val checkpointStorage = initialiseCheckpointService(dir)
|
||||||
val transactionStorage = PerFileTransactionStorage(dir.resolve("transactions"))
|
val transactionStorage = PerFileTransactionStorage(dir.resolve("transactions"))
|
||||||
_servicesThatAcceptUploads += attachments
|
_servicesThatAcceptUploads += attachments
|
||||||
val (identity, keyPair) = obtainKeyPair(dir)
|
val (identity, keyPair) = obtainKeyPair(dir)
|
||||||
|
@ -21,11 +21,11 @@ interface CheckpointStorage {
|
|||||||
fun removeCheckpoint(checkpoint: Checkpoint)
|
fun removeCheckpoint(checkpoint: Checkpoint)
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Returns a snapshot of all the checkpoints in the store.
|
* Allows the caller to process safely in a thread safe fashion the set of all checkpoints.
|
||||||
* This may return more checkpoints than were added to this instance of the store; for example if the store persists
|
* The checkpoints are only valid during the lifetime of a single call to the block, to allow memory management.
|
||||||
* checkpoints to disk.
|
* Return false from the block to terminate further iteration.
|
||||||
*/
|
*/
|
||||||
val checkpoints: Iterable<Checkpoint>
|
fun forEach(block: (Checkpoint)->Boolean)
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -0,0 +1,39 @@
|
|||||||
|
package com.r3corda.node.services.persistence
|
||||||
|
|
||||||
|
import com.r3corda.core.crypto.SecureHash
|
||||||
|
import com.r3corda.core.serialization.SerializedBytes
|
||||||
|
import com.r3corda.core.serialization.deserialize
|
||||||
|
import com.r3corda.core.serialization.serialize
|
||||||
|
import com.r3corda.node.services.api.Checkpoint
|
||||||
|
import com.r3corda.node.services.api.CheckpointStorage
|
||||||
|
import com.r3corda.node.utilities.JDBCHashMap
|
||||||
|
import java.util.*
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Simple checkpoint key value storage in DB using the underlying JDBCHashMap and transactional context of the call sites.
|
||||||
|
*/
|
||||||
|
class DBCheckpointStorage : CheckpointStorage {
|
||||||
|
private val checkpointStorage = Collections.synchronizedMap(JDBCHashMap<SecureHash, SerializedBytes<Checkpoint>>("checkpoints", loadOnInit = false))
|
||||||
|
|
||||||
|
override fun addCheckpoint(checkpoint: Checkpoint) {
|
||||||
|
val serialisedCheckpoint = checkpoint.serialize()
|
||||||
|
val id = serialisedCheckpoint.hash
|
||||||
|
checkpointStorage.put(id, serialisedCheckpoint)
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun removeCheckpoint(checkpoint: Checkpoint) {
|
||||||
|
val serialisedCheckpoint = checkpoint.serialize()
|
||||||
|
val id = serialisedCheckpoint.hash
|
||||||
|
checkpointStorage.remove(id) ?: throw IllegalArgumentException("Checkpoint not found")
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun forEach(block: (Checkpoint) -> Boolean) {
|
||||||
|
synchronized(checkpointStorage) {
|
||||||
|
for (checkpoint in checkpointStorage.values) {
|
||||||
|
if (!block(checkpoint.deserialize())) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
@ -60,9 +60,14 @@ class PerFileCheckpointStorage(val storeDir: Path) : CheckpointStorage {
|
|||||||
logger.trace { "Removed $checkpoint ($checkpointFile)" }
|
logger.trace { "Removed $checkpoint ($checkpointFile)" }
|
||||||
}
|
}
|
||||||
|
|
||||||
override val checkpoints: Iterable<Checkpoint>
|
override fun forEach(block: (Checkpoint)->Boolean) {
|
||||||
get() = synchronized(checkpointFiles) {
|
synchronized(checkpointFiles) {
|
||||||
checkpointFiles.keys.toList()
|
for(checkpoint in checkpointFiles.keys) {
|
||||||
|
if (!block(checkpoint)) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@ -17,8 +17,10 @@ import com.r3corda.core.utilities.UntrustworthyData
|
|||||||
import com.r3corda.core.utilities.trace
|
import com.r3corda.core.utilities.trace
|
||||||
import com.r3corda.node.services.api.ServiceHubInternal
|
import com.r3corda.node.services.api.ServiceHubInternal
|
||||||
import com.r3corda.node.services.statemachine.StateMachineManager.*
|
import com.r3corda.node.services.statemachine.StateMachineManager.*
|
||||||
|
import com.r3corda.node.utilities.StrandLocalTransactionManager
|
||||||
import com.r3corda.node.utilities.createDatabaseTransaction
|
import com.r3corda.node.utilities.createDatabaseTransaction
|
||||||
import org.jetbrains.exposed.sql.Database
|
import org.jetbrains.exposed.sql.Database
|
||||||
|
import org.jetbrains.exposed.sql.Transaction
|
||||||
import org.jetbrains.exposed.sql.transactions.TransactionManager
|
import org.jetbrains.exposed.sql.transactions.TransactionManager
|
||||||
import org.slf4j.Logger
|
import org.slf4j.Logger
|
||||||
import org.slf4j.LoggerFactory
|
import org.slf4j.LoggerFactory
|
||||||
@ -62,6 +64,7 @@ class ProtocolStateMachineImpl<R>(override val id: StateMachineRunId,
|
|||||||
@Transient internal lateinit var actionOnEnd: () -> Unit
|
@Transient internal lateinit var actionOnEnd: () -> Unit
|
||||||
@Transient internal lateinit var database: Database
|
@Transient internal lateinit var database: Database
|
||||||
@Transient internal var fromCheckpoint: Boolean = false
|
@Transient internal var fromCheckpoint: Boolean = false
|
||||||
|
@Transient internal var txTrampoline: Transaction? = null
|
||||||
|
|
||||||
@Transient private var _logger: Logger? = null
|
@Transient private var _logger: Logger? = null
|
||||||
override val logger: Logger get() {
|
override val logger: Logger get() {
|
||||||
@ -113,7 +116,7 @@ class ProtocolStateMachineImpl<R>(override val id: StateMachineRunId,
|
|||||||
logger.trace { "Starting database transaction ${TransactionManager.currentOrNull()} on ${Strand.currentStrand()}." }
|
logger.trace { "Starting database transaction ${TransactionManager.currentOrNull()} on ${Strand.currentStrand()}." }
|
||||||
}
|
}
|
||||||
|
|
||||||
private fun commitTransaction() {
|
internal fun commitTransaction() {
|
||||||
val transaction = TransactionManager.current()
|
val transaction = TransactionManager.current()
|
||||||
try {
|
try {
|
||||||
logger.trace { "Commiting database transaction $transaction on ${Strand.currentStrand()}." }
|
logger.trace { "Commiting database transaction $transaction on ${Strand.currentStrand()}." }
|
||||||
@ -210,7 +213,7 @@ class ProtocolStateMachineImpl<R>(override val id: StateMachineRunId,
|
|||||||
|
|
||||||
if (receivedMessage is SessionEnd) {
|
if (receivedMessage is SessionEnd) {
|
||||||
openSessions.values.remove(receiveRequest.session)
|
openSessions.values.remove(receiveRequest.session)
|
||||||
throw ProtocolSessionException("Counterparty on ${receiveRequest.session.otherParty} has prematurly ended")
|
throw ProtocolSessionException("Counterparty on ${receiveRequest.session.otherParty} has prematurely ended")
|
||||||
} else if (receiveRequest.receiveType.isInstance(receivedMessage)) {
|
} else if (receiveRequest.receiveType.isInstance(receivedMessage)) {
|
||||||
return receiveRequest.receiveType.cast(receivedMessage)
|
return receiveRequest.receiveType.cast(receivedMessage)
|
||||||
} else {
|
} else {
|
||||||
@ -220,9 +223,14 @@ class ProtocolStateMachineImpl<R>(override val id: StateMachineRunId,
|
|||||||
|
|
||||||
@Suspendable
|
@Suspendable
|
||||||
private fun suspend(ioRequest: ProtocolIORequest) {
|
private fun suspend(ioRequest: ProtocolIORequest) {
|
||||||
commitTransaction()
|
// we have to pass the Thread local Transaction across via a transient field as the Fiber Park swaps them out.
|
||||||
|
txTrampoline = TransactionManager.currentOrNull()
|
||||||
|
StrandLocalTransactionManager.setThreadLocalTx(null)
|
||||||
parkAndSerialize { fiber, serializer ->
|
parkAndSerialize { fiber, serializer ->
|
||||||
logger.trace { "Suspended $id on $ioRequest" }
|
logger.trace { "Suspended $id on $ioRequest" }
|
||||||
|
// restore the Tx onto the ThreadLocal so that we can commit the ensuing checkpoint to the DB
|
||||||
|
StrandLocalTransactionManager.setThreadLocalTx(txTrampoline)
|
||||||
|
txTrampoline = null
|
||||||
try {
|
try {
|
||||||
actionOnSuspend(ioRequest)
|
actionOnSuspend(ioRequest)
|
||||||
} catch (t: Throwable) {
|
} catch (t: Throwable) {
|
||||||
|
@ -28,6 +28,7 @@ import com.r3corda.node.services.api.ServiceHubInternal
|
|||||||
import com.r3corda.node.utilities.AddOrRemove
|
import com.r3corda.node.utilities.AddOrRemove
|
||||||
import com.r3corda.node.utilities.AffinityExecutor
|
import com.r3corda.node.utilities.AffinityExecutor
|
||||||
import kotlinx.support.jdk8.collections.removeIf
|
import kotlinx.support.jdk8.collections.removeIf
|
||||||
|
import com.r3corda.node.utilities.isolatedTransaction
|
||||||
import org.jetbrains.exposed.sql.Database
|
import org.jetbrains.exposed.sql.Database
|
||||||
import rx.Observable
|
import rx.Observable
|
||||||
import rx.subjects.PublishSubject
|
import rx.subjects.PublishSubject
|
||||||
@ -157,13 +158,14 @@ class StateMachineManager(val serviceHub: ServiceHubInternal,
|
|||||||
|
|
||||||
private fun restoreFibersFromCheckpoints() {
|
private fun restoreFibersFromCheckpoints() {
|
||||||
mutex.locked {
|
mutex.locked {
|
||||||
checkpointStorage.checkpoints.forEach {
|
checkpointStorage.forEach {
|
||||||
// If a protocol is added before start() then don't attempt to restore it
|
// If a protocol is added before start() then don't attempt to restore it
|
||||||
if (!stateMachines.containsValue(it)) {
|
if (!stateMachines.containsValue(it)) {
|
||||||
val fiber = deserializeFiber(it.serialisedFiber)
|
val fiber = deserializeFiber(it.serialisedFiber)
|
||||||
initFiber(fiber)
|
initFiber(fiber)
|
||||||
stateMachines[fiber] = it
|
stateMachines[fiber] = it
|
||||||
}
|
}
|
||||||
|
true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -276,6 +278,9 @@ class StateMachineManager(val serviceHub: ServiceHubInternal,
|
|||||||
psm.serviceHub = serviceHub
|
psm.serviceHub = serviceHub
|
||||||
psm.actionOnSuspend = { ioRequest ->
|
psm.actionOnSuspend = { ioRequest ->
|
||||||
updateCheckpoint(psm)
|
updateCheckpoint(psm)
|
||||||
|
// We commit on the fibers transaction that was copied across ThreadLocals during suspend
|
||||||
|
// This will free up the ThreadLocal so on return the caller can carry on with other transactions
|
||||||
|
psm.commitTransaction()
|
||||||
processIORequest(ioRequest)
|
processIORequest(ioRequest)
|
||||||
}
|
}
|
||||||
psm.actionOnEnd = {
|
psm.actionOnEnd = {
|
||||||
@ -334,7 +339,13 @@ class StateMachineManager(val serviceHub: ServiceHubInternal,
|
|||||||
*/
|
*/
|
||||||
fun <T> add(loggerName: String, logic: ProtocolLogic<T>): ProtocolStateMachine<T> {
|
fun <T> add(loggerName: String, logic: ProtocolLogic<T>): ProtocolStateMachine<T> {
|
||||||
val fiber = createFiber(loggerName, logic)
|
val fiber = createFiber(loggerName, logic)
|
||||||
|
// We swap out the parent transaction context as using this frequently leads to a deadlock as we wait
|
||||||
|
// on the protocol completion future inside that context. The problem is that any progress checkpoints are
|
||||||
|
// unable to acquire the table lock and move forward till the calling transaction finishes.
|
||||||
|
// Committing in line here on a fresh context ensure we can progress.
|
||||||
|
isolatedTransaction(database) {
|
||||||
updateCheckpoint(fiber)
|
updateCheckpoint(fiber)
|
||||||
|
}
|
||||||
// If we are not started then our checkpoint will be picked up during start
|
// If we are not started then our checkpoint will be picked up during start
|
||||||
mutex.locked {
|
mutex.locked {
|
||||||
if (started) {
|
if (started) {
|
||||||
|
@ -35,6 +35,15 @@ fun configureDatabase(props: Properties): Pair<Closeable, Database> {
|
|||||||
return Pair(dataSource, database)
|
return Pair(dataSource, database)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fun <T> isolatedTransaction(database: Database, block: Transaction.() -> T): T {
|
||||||
|
val oldContext = StrandLocalTransactionManager.setThreadLocalTx(null)
|
||||||
|
return try {
|
||||||
|
databaseTransaction(database, block)
|
||||||
|
} finally {
|
||||||
|
StrandLocalTransactionManager.restoreThreadLocalTx(oldContext)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A relatively close copy of the [ThreadLocalTransactionManager] in Exposed but with the following adjustments to suit
|
* A relatively close copy of the [ThreadLocalTransactionManager] in Exposed but with the following adjustments to suit
|
||||||
* our environment:
|
* our environment:
|
||||||
@ -51,6 +60,17 @@ class StrandLocalTransactionManager(initWithDatabase: Database) : TransactionMan
|
|||||||
private val threadLocalDb = ThreadLocal<Database>()
|
private val threadLocalDb = ThreadLocal<Database>()
|
||||||
private val threadLocalTx = ThreadLocal<Transaction>()
|
private val threadLocalTx = ThreadLocal<Transaction>()
|
||||||
|
|
||||||
|
fun setThreadLocalTx(tx: Transaction?): Pair<Database?, Transaction?> {
|
||||||
|
val oldTx = threadLocalTx.get()
|
||||||
|
threadLocalTx.set(tx)
|
||||||
|
return Pair(threadLocalDb.get(), oldTx)
|
||||||
|
}
|
||||||
|
|
||||||
|
fun restoreThreadLocalTx(context: Pair<Database?, Transaction?>) {
|
||||||
|
threadLocalDb.set(context.first)
|
||||||
|
threadLocalTx.set(context.second)
|
||||||
|
}
|
||||||
|
|
||||||
var database: Database
|
var database: Database
|
||||||
get() = threadLocalDb.get() ?: throw IllegalStateException("Was expecting to find database set on current strand: ${Strand.currentStrand()}")
|
get() = threadLocalDb.get() ?: throw IllegalStateException("Was expecting to find database set on current strand: ${Strand.currentStrand()}")
|
||||||
set(value: Database) {
|
set(value: Database) {
|
||||||
|
@ -22,11 +22,13 @@ import com.r3corda.node.services.config.NodeConfiguration
|
|||||||
import com.r3corda.node.services.persistence.NodeAttachmentService
|
import com.r3corda.node.services.persistence.NodeAttachmentService
|
||||||
import com.r3corda.node.services.persistence.PerFileTransactionStorage
|
import com.r3corda.node.services.persistence.PerFileTransactionStorage
|
||||||
import com.r3corda.node.services.persistence.StorageServiceImpl
|
import com.r3corda.node.services.persistence.StorageServiceImpl
|
||||||
|
import com.r3corda.node.utilities.databaseTransaction
|
||||||
import com.r3corda.protocols.TwoPartyTradeProtocol.Buyer
|
import com.r3corda.protocols.TwoPartyTradeProtocol.Buyer
|
||||||
import com.r3corda.protocols.TwoPartyTradeProtocol.Seller
|
import com.r3corda.protocols.TwoPartyTradeProtocol.Seller
|
||||||
import com.r3corda.testing.*
|
import com.r3corda.testing.*
|
||||||
import com.r3corda.testing.node.InMemoryMessagingNetwork
|
import com.r3corda.testing.node.InMemoryMessagingNetwork
|
||||||
import com.r3corda.testing.node.MockNetwork
|
import com.r3corda.testing.node.MockNetwork
|
||||||
|
import com.r3corda.node.services.persistence.checkpoints
|
||||||
import org.assertj.core.api.Assertions.assertThat
|
import org.assertj.core.api.Assertions.assertThat
|
||||||
import org.junit.After
|
import org.junit.After
|
||||||
import org.junit.Before
|
import org.junit.Before
|
||||||
@ -83,6 +85,9 @@ class TwoPartyTradeProtocolTests {
|
|||||||
aliceNode = net.createPartyNode(notaryNode.info.address, ALICE.name, ALICE_KEY)
|
aliceNode = net.createPartyNode(notaryNode.info.address, ALICE.name, ALICE_KEY)
|
||||||
bobNode = net.createPartyNode(notaryNode.info.address, BOB.name, BOB_KEY)
|
bobNode = net.createPartyNode(notaryNode.info.address, BOB.name, BOB_KEY)
|
||||||
|
|
||||||
|
aliceNode.disableDBCloseOnStop()
|
||||||
|
bobNode.disableDBCloseOnStop()
|
||||||
|
|
||||||
bobNode.services.fillWithSomeTestCash(2000.DOLLARS)
|
bobNode.services.fillWithSomeTestCash(2000.DOLLARS)
|
||||||
val alicesFakePaper = fillUpForSeller(false, aliceNode.storage.myLegalIdentity.owningKey,
|
val alicesFakePaper = fillUpForSeller(false, aliceNode.storage.myLegalIdentity.owningKey,
|
||||||
1200.DOLLARS `issued by` DUMMY_CASH_ISSUER, null).second
|
1200.DOLLARS `issued by` DUMMY_CASH_ISSUER, null).second
|
||||||
@ -98,8 +103,14 @@ class TwoPartyTradeProtocolTests {
|
|||||||
aliceNode.stop()
|
aliceNode.stop()
|
||||||
bobNode.stop()
|
bobNode.stop()
|
||||||
|
|
||||||
assertThat(aliceNode.checkpointStorage.checkpoints).isEmpty()
|
databaseTransaction(aliceNode.database) {
|
||||||
assertThat(bobNode.checkpointStorage.checkpoints).isEmpty()
|
assertThat(aliceNode.checkpointStorage.checkpoints()).isEmpty()
|
||||||
|
}
|
||||||
|
aliceNode.manuallyCloseDB()
|
||||||
|
databaseTransaction(bobNode.database) {
|
||||||
|
assertThat(bobNode.checkpointStorage.checkpoints()).isEmpty()
|
||||||
|
}
|
||||||
|
aliceNode.manuallyCloseDB()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -135,7 +146,7 @@ class TwoPartyTradeProtocolTests {
|
|||||||
bobNode.pumpReceive(false)
|
bobNode.pumpReceive(false)
|
||||||
|
|
||||||
// OK, now Bob has sent the partial transaction back to Alice and is waiting for Alice's signature.
|
// OK, now Bob has sent the partial transaction back to Alice and is waiting for Alice's signature.
|
||||||
assertThat(bobNode.checkpointStorage.checkpoints).hasSize(1)
|
assertThat(bobNode.checkpointStorage.checkpoints()).hasSize(1)
|
||||||
|
|
||||||
val bobTransactionsBeforeCrash = (bobNode.storage.validatedTransactions as PerFileTransactionStorage).transactions
|
val bobTransactionsBeforeCrash = (bobNode.storage.validatedTransactions as PerFileTransactionStorage).transactions
|
||||||
assertThat(bobTransactionsBeforeCrash).isNotEmpty()
|
assertThat(bobTransactionsBeforeCrash).isNotEmpty()
|
||||||
@ -167,8 +178,8 @@ class TwoPartyTradeProtocolTests {
|
|||||||
|
|
||||||
assertThat(bobNode.smm.findStateMachines(Buyer::class.java)).isEmpty()
|
assertThat(bobNode.smm.findStateMachines(Buyer::class.java)).isEmpty()
|
||||||
|
|
||||||
assertThat(bobNode.checkpointStorage.checkpoints).isEmpty()
|
assertThat(bobNode.checkpointStorage.checkpoints()).isEmpty()
|
||||||
assertThat(aliceNode.checkpointStorage.checkpoints).isEmpty()
|
assertThat(aliceNode.checkpointStorage.checkpoints()).isEmpty()
|
||||||
|
|
||||||
val restoredBobTransactions = bobTransactionsBeforeCrash.filter { bobNode.storage.validatedTransactions.getTransaction(it.id) != null }
|
val restoredBobTransactions = bobTransactionsBeforeCrash.filter { bobNode.storage.validatedTransactions.getTransaction(it.id) != null }
|
||||||
assertThat(restoredBobTransactions).containsAll(bobTransactionsBeforeCrash)
|
assertThat(restoredBobTransactions).containsAll(bobTransactionsBeforeCrash)
|
||||||
|
@ -0,0 +1,162 @@
|
|||||||
|
package com.r3corda.node.services.persistence
|
||||||
|
|
||||||
|
import com.google.common.primitives.Ints
|
||||||
|
import com.r3corda.core.serialization.SerializedBytes
|
||||||
|
import com.r3corda.core.utilities.LogHelper
|
||||||
|
import com.r3corda.node.services.api.Checkpoint
|
||||||
|
import com.r3corda.node.services.api.CheckpointStorage
|
||||||
|
import com.r3corda.node.services.transactions.PersistentUniquenessProvider
|
||||||
|
import com.r3corda.node.utilities.configureDatabase
|
||||||
|
import com.r3corda.node.utilities.databaseTransaction
|
||||||
|
import com.r3corda.testing.node.makeTestDataSourceProperties
|
||||||
|
import org.assertj.core.api.Assertions.assertThat
|
||||||
|
import org.assertj.core.api.Assertions.assertThatExceptionOfType
|
||||||
|
import org.jetbrains.exposed.sql.Database
|
||||||
|
import org.junit.After
|
||||||
|
import org.junit.Before
|
||||||
|
import org.junit.Test
|
||||||
|
import java.io.Closeable
|
||||||
|
|
||||||
|
internal fun CheckpointStorage.checkpoints(): List<Checkpoint> {
|
||||||
|
val checkpoints = mutableListOf<Checkpoint>()
|
||||||
|
forEach {
|
||||||
|
checkpoints += it
|
||||||
|
true
|
||||||
|
}
|
||||||
|
return checkpoints
|
||||||
|
}
|
||||||
|
|
||||||
|
class DBCheckpointStorageTests {
|
||||||
|
lateinit var checkpointStorage: DBCheckpointStorage
|
||||||
|
lateinit var dataSource: Closeable
|
||||||
|
lateinit var database: Database
|
||||||
|
|
||||||
|
@Before
|
||||||
|
fun setUp() {
|
||||||
|
LogHelper.setLevel(PersistentUniquenessProvider::class)
|
||||||
|
val dataSourceAndDatabase = configureDatabase(makeTestDataSourceProperties())
|
||||||
|
dataSource = dataSourceAndDatabase.first
|
||||||
|
database = dataSourceAndDatabase.second
|
||||||
|
newCheckpointStorage()
|
||||||
|
}
|
||||||
|
|
||||||
|
@After
|
||||||
|
fun cleanUp() {
|
||||||
|
dataSource.close()
|
||||||
|
LogHelper.reset(PersistentUniquenessProvider::class)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun `add new checkpoint`() {
|
||||||
|
val checkpoint = newCheckpoint()
|
||||||
|
databaseTransaction(database) {
|
||||||
|
checkpointStorage.addCheckpoint(checkpoint)
|
||||||
|
}
|
||||||
|
databaseTransaction(database) {
|
||||||
|
assertThat(checkpointStorage.checkpoints()).containsExactly(checkpoint)
|
||||||
|
}
|
||||||
|
newCheckpointStorage()
|
||||||
|
databaseTransaction(database) {
|
||||||
|
assertThat(checkpointStorage.checkpoints()).containsExactly(checkpoint)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun `remove checkpoint`() {
|
||||||
|
val checkpoint = newCheckpoint()
|
||||||
|
databaseTransaction(database) {
|
||||||
|
checkpointStorage.addCheckpoint(checkpoint)
|
||||||
|
}
|
||||||
|
databaseTransaction(database) {
|
||||||
|
checkpointStorage.removeCheckpoint(checkpoint)
|
||||||
|
}
|
||||||
|
databaseTransaction(database) {
|
||||||
|
assertThat(checkpointStorage.checkpoints()).isEmpty()
|
||||||
|
}
|
||||||
|
newCheckpointStorage()
|
||||||
|
databaseTransaction(database) {
|
||||||
|
assertThat(checkpointStorage.checkpoints()).isEmpty()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun `add and remove checkpoint in single commit operate`() {
|
||||||
|
val checkpoint = newCheckpoint()
|
||||||
|
val checkpoint2 = newCheckpoint()
|
||||||
|
databaseTransaction(database) {
|
||||||
|
checkpointStorage.addCheckpoint(checkpoint)
|
||||||
|
checkpointStorage.addCheckpoint(checkpoint2)
|
||||||
|
checkpointStorage.removeCheckpoint(checkpoint)
|
||||||
|
}
|
||||||
|
databaseTransaction(database) {
|
||||||
|
assertThat(checkpointStorage.checkpoints()).containsExactly(checkpoint2)
|
||||||
|
}
|
||||||
|
newCheckpointStorage()
|
||||||
|
databaseTransaction(database) {
|
||||||
|
assertThat(checkpointStorage.checkpoints()).containsExactly(checkpoint2)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun `remove unknown checkpoint`() {
|
||||||
|
val checkpoint = newCheckpoint()
|
||||||
|
databaseTransaction(database) {
|
||||||
|
assertThatExceptionOfType(IllegalArgumentException::class.java).isThrownBy {
|
||||||
|
checkpointStorage.removeCheckpoint(checkpoint)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun `add two checkpoints then remove first one`() {
|
||||||
|
val firstCheckpoint = newCheckpoint()
|
||||||
|
databaseTransaction(database) {
|
||||||
|
checkpointStorage.addCheckpoint(firstCheckpoint)
|
||||||
|
}
|
||||||
|
val secondCheckpoint = newCheckpoint()
|
||||||
|
databaseTransaction(database) {
|
||||||
|
checkpointStorage.addCheckpoint(secondCheckpoint)
|
||||||
|
}
|
||||||
|
databaseTransaction(database) {
|
||||||
|
checkpointStorage.removeCheckpoint(firstCheckpoint)
|
||||||
|
}
|
||||||
|
databaseTransaction(database) {
|
||||||
|
assertThat(checkpointStorage.checkpoints()).containsExactly(secondCheckpoint)
|
||||||
|
}
|
||||||
|
newCheckpointStorage()
|
||||||
|
databaseTransaction(database) {
|
||||||
|
assertThat(checkpointStorage.checkpoints()).containsExactly(secondCheckpoint)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
fun `add checkpoint and then remove after 'restart'`() {
|
||||||
|
val originalCheckpoint = newCheckpoint()
|
||||||
|
databaseTransaction(database) {
|
||||||
|
checkpointStorage.addCheckpoint(originalCheckpoint)
|
||||||
|
}
|
||||||
|
newCheckpointStorage()
|
||||||
|
val reconstructedCheckpoint = databaseTransaction(database) {
|
||||||
|
checkpointStorage.checkpoints().single()
|
||||||
|
}
|
||||||
|
databaseTransaction(database) {
|
||||||
|
assertThat(reconstructedCheckpoint).isEqualTo(originalCheckpoint).isNotSameAs(originalCheckpoint)
|
||||||
|
}
|
||||||
|
databaseTransaction(database) {
|
||||||
|
checkpointStorage.removeCheckpoint(reconstructedCheckpoint)
|
||||||
|
}
|
||||||
|
databaseTransaction(database) {
|
||||||
|
assertThat(checkpointStorage.checkpoints()).isEmpty()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private fun newCheckpointStorage() {
|
||||||
|
databaseTransaction(database) {
|
||||||
|
checkpointStorage = DBCheckpointStorage()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private var checkpointCount = 1
|
||||||
|
private fun newCheckpoint() = Checkpoint(SerializedBytes(Ints.toByteArray(checkpointCount++)))
|
||||||
|
|
||||||
|
}
|
@ -34,9 +34,9 @@ class PerFileCheckpointStorageTests {
|
|||||||
fun `add new checkpoint`() {
|
fun `add new checkpoint`() {
|
||||||
val checkpoint = newCheckpoint()
|
val checkpoint = newCheckpoint()
|
||||||
checkpointStorage.addCheckpoint(checkpoint)
|
checkpointStorage.addCheckpoint(checkpoint)
|
||||||
assertThat(checkpointStorage.checkpoints).containsExactly(checkpoint)
|
assertThat(checkpointStorage.checkpoints()).containsExactly(checkpoint)
|
||||||
newCheckpointStorage()
|
newCheckpointStorage()
|
||||||
assertThat(checkpointStorage.checkpoints).containsExactly(checkpoint)
|
assertThat(checkpointStorage.checkpoints()).containsExactly(checkpoint)
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
@ -44,9 +44,9 @@ class PerFileCheckpointStorageTests {
|
|||||||
val checkpoint = newCheckpoint()
|
val checkpoint = newCheckpoint()
|
||||||
checkpointStorage.addCheckpoint(checkpoint)
|
checkpointStorage.addCheckpoint(checkpoint)
|
||||||
checkpointStorage.removeCheckpoint(checkpoint)
|
checkpointStorage.removeCheckpoint(checkpoint)
|
||||||
assertThat(checkpointStorage.checkpoints).isEmpty()
|
assertThat(checkpointStorage.checkpoints()).isEmpty()
|
||||||
newCheckpointStorage()
|
newCheckpointStorage()
|
||||||
assertThat(checkpointStorage.checkpoints).isEmpty()
|
assertThat(checkpointStorage.checkpoints()).isEmpty()
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
@ -64,9 +64,9 @@ class PerFileCheckpointStorageTests {
|
|||||||
val secondCheckpoint = newCheckpoint()
|
val secondCheckpoint = newCheckpoint()
|
||||||
checkpointStorage.addCheckpoint(secondCheckpoint)
|
checkpointStorage.addCheckpoint(secondCheckpoint)
|
||||||
checkpointStorage.removeCheckpoint(firstCheckpoint)
|
checkpointStorage.removeCheckpoint(firstCheckpoint)
|
||||||
assertThat(checkpointStorage.checkpoints).containsExactly(secondCheckpoint)
|
assertThat(checkpointStorage.checkpoints()).containsExactly(secondCheckpoint)
|
||||||
newCheckpointStorage()
|
newCheckpointStorage()
|
||||||
assertThat(checkpointStorage.checkpoints).containsExactly(secondCheckpoint)
|
assertThat(checkpointStorage.checkpoints()).containsExactly(secondCheckpoint)
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
@ -74,10 +74,10 @@ class PerFileCheckpointStorageTests {
|
|||||||
val originalCheckpoint = newCheckpoint()
|
val originalCheckpoint = newCheckpoint()
|
||||||
checkpointStorage.addCheckpoint(originalCheckpoint)
|
checkpointStorage.addCheckpoint(originalCheckpoint)
|
||||||
newCheckpointStorage()
|
newCheckpointStorage()
|
||||||
val reconstructedCheckpoint = checkpointStorage.checkpoints.single()
|
val reconstructedCheckpoint = checkpointStorage.checkpoints().single()
|
||||||
assertThat(reconstructedCheckpoint).isEqualTo(originalCheckpoint).isNotSameAs(originalCheckpoint)
|
assertThat(reconstructedCheckpoint).isEqualTo(originalCheckpoint).isNotSameAs(originalCheckpoint)
|
||||||
checkpointStorage.removeCheckpoint(reconstructedCheckpoint)
|
checkpointStorage.removeCheckpoint(reconstructedCheckpoint)
|
||||||
assertThat(checkpointStorage.checkpoints).isEmpty()
|
assertThat(checkpointStorage.checkpoints()).isEmpty()
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
@ -86,7 +86,7 @@ class PerFileCheckpointStorageTests {
|
|||||||
checkpointStorage.addCheckpoint(checkpoint)
|
checkpointStorage.addCheckpoint(checkpoint)
|
||||||
Files.write(storeDir.resolve("random-non-checkpoint-file"), "this is not a checkpoint!!".toByteArray())
|
Files.write(storeDir.resolve("random-non-checkpoint-file"), "this is not a checkpoint!!".toByteArray())
|
||||||
newCheckpointStorage()
|
newCheckpointStorage()
|
||||||
assertThat(checkpointStorage.checkpoints).containsExactly(checkpoint)
|
assertThat(checkpointStorage.checkpoints()).containsExactly(checkpoint)
|
||||||
}
|
}
|
||||||
|
|
||||||
private fun newCheckpointStorage() {
|
private fun newCheckpointStorage() {
|
||||||
|
@ -12,6 +12,7 @@ import com.r3corda.core.serialization.deserialize
|
|||||||
import com.r3corda.node.services.statemachine.StateMachineManager.SessionData
|
import com.r3corda.node.services.statemachine.StateMachineManager.SessionData
|
||||||
import com.r3corda.node.services.statemachine.StateMachineManager.SessionMessage
|
import com.r3corda.node.services.statemachine.StateMachineManager.SessionMessage
|
||||||
import com.r3corda.testing.node.InMemoryMessagingNetwork
|
import com.r3corda.testing.node.InMemoryMessagingNetwork
|
||||||
|
import com.r3corda.node.services.persistence.checkpoints
|
||||||
import com.r3corda.testing.node.MockNetwork
|
import com.r3corda.testing.node.MockNetwork
|
||||||
import com.r3corda.testing.node.MockNetwork.MockNode
|
import com.r3corda.testing.node.MockNetwork.MockNode
|
||||||
import org.assertj.core.api.Assertions.assertThat
|
import org.assertj.core.api.Assertions.assertThat
|
||||||
@ -135,13 +136,13 @@ class StateMachineManagerTests {
|
|||||||
|
|
||||||
// Kick off first send and receive
|
// Kick off first send and receive
|
||||||
node2.smm.add("test", PingPongProtocol(node3.info.identity, payload))
|
node2.smm.add("test", PingPongProtocol(node3.info.identity, payload))
|
||||||
assertEquals(1, node2.checkpointStorage.checkpoints.count())
|
assertEquals(1, node2.checkpointStorage.checkpoints().count())
|
||||||
// Restart node and thus reload the checkpoint and resend the message with same UUID
|
// Restart node and thus reload the checkpoint and resend the message with same UUID
|
||||||
node2.stop()
|
node2.stop()
|
||||||
val node2b = net.createNode(node1.info.address, node2.id, advertisedServices = *node2.advertisedServices.toTypedArray())
|
val node2b = net.createNode(node1.info.address, node2.id, advertisedServices = *node2.advertisedServices.toTypedArray())
|
||||||
val (firstAgain, fut1) = node2b.getSingleProtocol<PingPongProtocol>()
|
val (firstAgain, fut1) = node2b.getSingleProtocol<PingPongProtocol>()
|
||||||
net.runNetwork()
|
net.runNetwork()
|
||||||
assertEquals(1, node2.checkpointStorage.checkpoints.count())
|
assertEquals(1, node2.checkpointStorage.checkpoints().count())
|
||||||
// Run the network which will also fire up the second protocol. First message should get deduped. So message data stays in sync.
|
// Run the network which will also fire up the second protocol. First message should get deduped. So message data stays in sync.
|
||||||
net.runNetwork()
|
net.runNetwork()
|
||||||
node2b.smm.executor.flush()
|
node2b.smm.executor.flush()
|
||||||
@ -149,8 +150,8 @@ class StateMachineManagerTests {
|
|||||||
// Check protocols completed cleanly and didn't get out of phase
|
// Check protocols completed cleanly and didn't get out of phase
|
||||||
assertEquals(4, receivedCount, "Protocol should have exchanged 4 unique messages")// Two messages each way
|
assertEquals(4, receivedCount, "Protocol should have exchanged 4 unique messages")// Two messages each way
|
||||||
assertTrue(sentCount > receivedCount, "Node restart should have retransmitted messages") // can't give a precise value as every addMessageHandler re-runs the undelivered messages
|
assertTrue(sentCount > receivedCount, "Node restart should have retransmitted messages") // can't give a precise value as every addMessageHandler re-runs the undelivered messages
|
||||||
assertEquals(0, node2b.checkpointStorage.checkpoints.count(), "Checkpoints left after restored protocol should have ended")
|
assertEquals(0, node2b.checkpointStorage.checkpoints().count(), "Checkpoints left after restored protocol should have ended")
|
||||||
assertEquals(0, node3.checkpointStorage.checkpoints.count(), "Checkpoints left after restored protocol should have ended")
|
assertEquals(0, node3.checkpointStorage.checkpoints().count(), "Checkpoints left after restored protocol should have ended")
|
||||||
assertEquals(payload2, firstAgain.receivedPayload, "Received payload does not match the first value on Node 3")
|
assertEquals(payload2, firstAgain.receivedPayload, "Received payload does not match the first value on Node 3")
|
||||||
assertEquals(payload2 + 1, firstAgain.receivedPayload2, "Received payload does not match the expected second value on Node 3")
|
assertEquals(payload2 + 1, firstAgain.receivedPayload2, "Received payload does not match the expected second value on Node 3")
|
||||||
assertEquals(payload, secondProtocol!!.receivedPayload, "Received payload does not match the (restarted) first value on Node 2")
|
assertEquals(payload, secondProtocol!!.receivedPayload, "Received payload does not match the (restarted) first value on Node 2")
|
||||||
|
@ -6,6 +6,7 @@ import com.google.common.util.concurrent.Futures
|
|||||||
import com.google.common.util.concurrent.ListenableFuture
|
import com.google.common.util.concurrent.ListenableFuture
|
||||||
import com.google.common.util.concurrent.SettableFuture
|
import com.google.common.util.concurrent.SettableFuture
|
||||||
import com.r3corda.core.crypto.Party
|
import com.r3corda.core.crypto.Party
|
||||||
|
import com.r3corda.core.div
|
||||||
import com.r3corda.core.messaging.SingleMessageRecipient
|
import com.r3corda.core.messaging.SingleMessageRecipient
|
||||||
import com.r3corda.core.messaging.TopicSession
|
import com.r3corda.core.messaging.TopicSession
|
||||||
import com.r3corda.core.messaging.runOnNextMessage
|
import com.r3corda.core.messaging.runOnNextMessage
|
||||||
@ -20,10 +21,13 @@ import com.r3corda.core.testing.InMemoryVaultService
|
|||||||
import com.r3corda.core.utilities.DUMMY_NOTARY_KEY
|
import com.r3corda.core.utilities.DUMMY_NOTARY_KEY
|
||||||
import com.r3corda.core.utilities.loggerFor
|
import com.r3corda.core.utilities.loggerFor
|
||||||
import com.r3corda.node.internal.AbstractNode
|
import com.r3corda.node.internal.AbstractNode
|
||||||
|
import com.r3corda.node.services.api.CheckpointStorage
|
||||||
import com.r3corda.node.services.config.NodeConfiguration
|
import com.r3corda.node.services.config.NodeConfiguration
|
||||||
import com.r3corda.node.services.keys.E2ETestKeyManagementService
|
import com.r3corda.node.services.keys.E2ETestKeyManagementService
|
||||||
import com.r3corda.node.services.messaging.CordaRPCOps
|
import com.r3corda.node.services.messaging.CordaRPCOps
|
||||||
import com.r3corda.node.services.network.InMemoryNetworkMapService
|
import com.r3corda.node.services.network.InMemoryNetworkMapService
|
||||||
|
import com.r3corda.node.services.persistence.DBCheckpointStorage
|
||||||
|
import com.r3corda.node.services.persistence.PerFileCheckpointStorage
|
||||||
import com.r3corda.node.services.transactions.InMemoryUniquenessProvider
|
import com.r3corda.node.services.transactions.InMemoryUniquenessProvider
|
||||||
import com.r3corda.node.utilities.databaseTransaction
|
import com.r3corda.node.utilities.databaseTransaction
|
||||||
import com.r3corda.protocols.ServiceRequestMessage
|
import com.r3corda.protocols.ServiceRequestMessage
|
||||||
@ -97,6 +101,14 @@ class MockNetwork(private val networkSendManuallyPumped: Boolean = false,
|
|||||||
persistenceTx = { body: () -> Unit -> databaseTransaction(database) { body() } }).start().get()
|
persistenceTx = { body: () -> Unit -> databaseTransaction(database) { body() } }).start().get()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
override fun initialiseCheckpointService(dir: Path): CheckpointStorage {
|
||||||
|
return if (mockNet.threadPerNode) {
|
||||||
|
DBCheckpointStorage()
|
||||||
|
} else {
|
||||||
|
PerFileCheckpointStorage(dir / "checkpoints")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
override fun makeIdentityService() = MockIdentityService(mockNet.identities)
|
override fun makeIdentityService() = MockIdentityService(mockNet.identities)
|
||||||
|
|
||||||
override fun makeVaultService(): VaultService = InMemoryVaultService(services)
|
override fun makeVaultService(): VaultService = InMemoryVaultService(services)
|
||||||
@ -153,6 +165,15 @@ class MockNetwork(private val networkSendManuallyPumped: Boolean = false,
|
|||||||
send(topic, target, payload)
|
send(topic, target, payload)
|
||||||
return receive(topic, payload.sessionID)
|
return receive(topic, payload.sessionID)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fun disableDBCloseOnStop() {
|
||||||
|
runOnStop.remove(dbCloser)
|
||||||
|
}
|
||||||
|
|
||||||
|
fun manuallyCloseDB() {
|
||||||
|
dbCloser?.run()
|
||||||
|
dbCloser = null
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Returns a node, optionally created by the passed factory method. */
|
/** Returns a node, optionally created by the passed factory method. */
|
||||||
|
Reference in New Issue
Block a user