persisting psm checkpoints to disk (one per file)

This commit is contained in:
Shams Asari 2016-05-09 10:15:08 +01:00
parent 48137d325b
commit 4271693b85
9 changed files with 321 additions and 108 deletions

View File

@ -7,21 +7,22 @@ import com.codahale.metrics.Gauge
import com.esotericsoftware.kryo.io.Input
import com.google.common.base.Throwables
import com.google.common.util.concurrent.ListenableFuture
import core.crypto.SecureHash
import core.crypto.sha256
import core.node.ServiceHub
import core.node.storage.Checkpoint
import core.protocols.ProtocolLogic
import core.protocols.ProtocolStateMachine
import core.serialization.SerializedBytes
import core.serialization.THREAD_LOCAL_KRYO
import core.serialization.createKryo
import core.serialization.deserialize
import core.serialization.serialize
import core.then
import core.utilities.AffinityExecutor
import core.utilities.ProgressTracker
import core.utilities.trace
import java.io.PrintWriter
import java.io.StringWriter
import java.util.*
import java.util.Collections.synchronizedMap
import java.util.concurrent.atomic.AtomicBoolean
import javax.annotation.concurrent.ThreadSafe
@ -58,10 +59,10 @@ class StateMachineManager(val serviceHub: ServiceHub, val executor: AffinityExec
// This map is backed by a database and will be used to store serialised state machines to disk, so we can resurrect
// them across node restarts.
private val checkpointsMap = serviceHub.storageService.stateMachines
private val checkpointStorage = serviceHub.storageService.checkpointStorage
// A list of all the state machines being managed by this class. We expose snapshots of it via the stateMachines
// property.
private val stateMachines = Collections.synchronizedList(ArrayList<ProtocolLogic<*>>())
private val stateMachines = synchronizedMap(HashMap<ProtocolStateMachine<*>, Checkpoint>())
// Monitoring support.
private val metrics = serviceHub.monitoringService.metrics
@ -78,7 +79,10 @@ class StateMachineManager(val serviceHub: ServiceHub, val executor: AffinityExec
fun <T> findStateMachines(klass: Class<out ProtocolLogic<T>>): List<Pair<ProtocolLogic<T>, ListenableFuture<T>>> {
synchronized(stateMachines) {
@Suppress("UNCHECKED_CAST")
return stateMachines.filterIsInstance(klass).map { it to (it.psm as ProtocolStateMachine<T>).resultFuture }
return stateMachines.keys
.map { it.logic }
.filterIsInstance(klass)
.map { it to (it.psm as ProtocolStateMachine<T>).resultFuture }
}
}
@ -89,13 +93,6 @@ class StateMachineManager(val serviceHub: ServiceHub, val executor: AffinityExec
field.get(null)
}
// This class will be serialised, so everything it points to transitively must also be serialisable (with Kryo).
private class Checkpoint(
val serialisedFiber: ByteArray,
val awaitingTopic: String,
val awaitingObjectOfType: String // java class name
)
init {
Fiber.setDefaultUncaughtExceptionHandler { fiber, throwable ->
(fiber as ProtocolStateMachine<*>).logger.error("Caught exception from protocol", throwable)
@ -105,17 +102,16 @@ class StateMachineManager(val serviceHub: ServiceHub, val executor: AffinityExec
/** Reads the database map and resurrects any serialised state machines. */
private fun restoreCheckpoints() {
for (bytes in checkpointsMap.values) {
val checkpoint = bytes.deserialize<Checkpoint>()
val checkpointKey = SecureHash.sha256(bytes)
for (checkpoint in checkpointStorage.checkpoints) {
// Grab the Kryo engine configured by Quasar for its own stuff, and then do our own configuration on top
// so we can deserialise the nested stream that holds the fiber.
val psm = deserializeFiber(checkpoint.serialisedFiber)
stateMachines.add(psm.logic)
initFiber(psm, checkpoint)
val awaitingObjectOfType = Class.forName(checkpoint.awaitingObjectOfType)
val topic = checkpoint.awaitingTopic
psm.logger.info("restored ${psm.logic} - was previously awaiting on topic $topic")
// And now re-wire the deserialised continuation back up to the network service.
serviceHub.networkService.runOnNextMessage(topic, executor) { netMsg ->
// TODO: See security note below.
@ -123,7 +119,7 @@ class StateMachineManager(val serviceHub: ServiceHub, val executor: AffinityExec
if (!awaitingObjectOfType.isInstance(obj))
throw ClassCastException("Received message of unexpected type: ${obj.javaClass.name} vs ${awaitingObjectOfType.name}")
psm.logger.trace { "<- $topic : message of type ${obj.javaClass.name}" }
iterateStateMachine(psm, obj, checkpointKey) {
iterateStateMachine(psm, obj) {
try {
Fiber.unparkDeserialized(it, scheduler)
} catch(e: Throwable) {
@ -134,6 +130,12 @@ class StateMachineManager(val serviceHub: ServiceHub, val executor: AffinityExec
}
}
/**
 * Rebuilds a [ProtocolStateMachine] from its serialised form.
 *
 * The Kryo instance used is derived from the one Quasar configures for fiber serialisation, passed through
 * [createKryo] before being handed to the deserialiser.
 */
private fun deserializeFiber(serialisedFiber: SerializedBytes<ProtocolStateMachine<*>>): ProtocolStateMachine<*> {
    val quasarSerializer = Fiber.getFiberSerializer(false) as KryoSerializer
    return serialisedFiber.deserialize(createKryo(quasarSerializer.kryo))
}
private fun logError(e: Throwable, obj: Any, topic: String, psm: ProtocolStateMachine<*>) {
psm.logger.error("Protocol state machine ${psm.javaClass.name} threw '${Throwables.getRootCause(e)}' " +
"when handling a message of type ${obj.javaClass.name} on topic $topic")
@ -144,11 +146,16 @@ class StateMachineManager(val serviceHub: ServiceHub, val executor: AffinityExec
}
}
private fun deserializeFiber(bits: ByteArray): ProtocolStateMachine<*> {
val deserializer = Fiber.getFiberSerializer(false) as KryoSerializer
val kryo = createKryo(deserializer.kryo)
val psm = kryo.readClassAndObject(Input(bits)) as ProtocolStateMachine<*>
return psm
/**
 * Starts tracking [psm] together with its current [checkpoint], and arranges the clean-up that runs once the
 * state machine's result future completes.
 *
 * @param checkpoint the checkpoint currently associated with the state machine, or null if it has none yet.
 */
private fun initFiber(psm: ProtocolStateMachine<*>, checkpoint: Checkpoint?) {
    stateMachines[psm] = checkpoint
    psm.resultFuture.then(executor) {
        psm.logic.progressTracker?.currentStep = ProgressTracker.DONE
        // Stop tracking the state machine; if it still had a checkpoint, drop that from storage as well.
        stateMachines.remove(psm)?.let { checkpointStorage.removeCheckpoint(it) }
        totalFinishedProtocols.inc()
    }
}
/**
@ -160,9 +167,9 @@ class StateMachineManager(val serviceHub: ServiceHub, val executor: AffinityExec
try {
val fiber = ProtocolStateMachine(logic, scheduler, loggerName)
// Need to add before iterating in case of immediate completion
stateMachines.add(logic)
initFiber(fiber, null)
executor.executeASAP {
iterateStateMachine(fiber, null, null) {
iterateStateMachine(fiber, null) {
it.start()
}
totalStartedProtocols.inc()
@ -174,27 +181,26 @@ class StateMachineManager(val serviceHub: ServiceHub, val executor: AffinityExec
}
}
private fun persistCheckpoint(prevCheckpointKey: SecureHash?, new: ByteArray): SecureHash {
private fun replaceCheckpoint(psm: ProtocolStateMachine<*>, newCheckpoint: Checkpoint) {
// It's OK for this to be unsynchronised, as the prev/new byte arrays are specific to a continuation instance,
// and the underlying map provided by the database layer is expected to be thread safe.
if (prevCheckpointKey != null)
checkpointsMap.remove(prevCheckpointKey)
val key = SecureHash.sha256(new)
checkpointsMap[key] = new
val previousCheckpoint = stateMachines.put(psm, newCheckpoint)
if (previousCheckpoint != null) {
checkpointStorage.removeCheckpoint(previousCheckpoint)
}
checkpointStorage.addCheckpoint(newCheckpoint)
checkpointingMeter.mark()
return key
}
private fun iterateStateMachine(psm: ProtocolStateMachine<*>,
obj: Any?,
prevCheckpointKey: SecureHash?,
resumeFunc: (ProtocolStateMachine<*>) -> Unit) {
executor.checkOnThread()
val onSuspend = fun(request: FiberRequest, serFiber: ByteArray) {
val onSuspend = fun(request: FiberRequest, serialisedFiber: SerializedBytes<ProtocolStateMachine<*>>) {
// We have a request to do something: send, receive, or send-and-receive.
if (request is FiberRequest.ExpectingResponse<*>) {
// Prepare a listener on the network that runs in the background thread when we received a message.
checkpointAndSetupMessageHandler(psm, request, prevCheckpointKey, serFiber)
checkpointAndSetupMessageHandler(psm, request, serialisedFiber)
}
// If an object to send was provided (not null), send it now.
request.obj?.let {
@ -204,7 +210,7 @@ class StateMachineManager(val serviceHub: ServiceHub, val executor: AffinityExec
}
if (request is FiberRequest.NotExpectingResponse) {
// We sent a message, but don't expect a response, so re-enter the continuation to let it keep going.
iterateStateMachine(psm, null, prevCheckpointKey) {
iterateStateMachine(psm, null) {
try {
Fiber.unpark(it, QUASAR_UNBLOCKER)
} catch(e: Throwable) {
@ -217,26 +223,15 @@ class StateMachineManager(val serviceHub: ServiceHub, val executor: AffinityExec
psm.prepareForResumeWith(serviceHub, obj, onSuspend)
resumeFunc(psm)
// We're back! Check if the fiber is finished and if so, clean up.
if (psm.isTerminated) {
psm.logic.progressTracker?.currentStep = ProgressTracker.DONE
stateMachines.remove(psm.logic)
checkpointsMap.remove(prevCheckpointKey)
totalFinishedProtocols.inc()
}
}
private fun checkpointAndSetupMessageHandler(psm: ProtocolStateMachine<*>,
request: FiberRequest.ExpectingResponse<*>,
prevCheckpointKey: SecureHash?,
serialisedFiber: ByteArray) {
serialisedFiber: SerializedBytes<ProtocolStateMachine<*>>) {
executor.checkOnThread()
val topic = "${request.topic}.${request.sessionIDForReceive}"
val checkpoint = Checkpoint(serialisedFiber, topic, request.responseType.name)
val curPersistedBytes = checkpoint.serialize().bits
persistCheckpoint(prevCheckpointKey, curPersistedBytes)
val newCheckpointKey = curPersistedBytes.sha256()
val newCheckpoint = Checkpoint(serialisedFiber, topic, request.responseType.name)
replaceCheckpoint(psm, newCheckpoint)
psm.logger.trace { "Waiting for message of type ${request.responseType.name} on $topic" }
val consumed = AtomicBoolean()
serviceHub.networkService.runOnNextMessage(topic, executor) { netMsg ->
@ -254,7 +249,7 @@ class StateMachineManager(val serviceHub: ServiceHub, val executor: AffinityExec
val obj: Any = THREAD_LOCAL_KRYO.get().readClassAndObject(Input(netMsg.data))
if (!request.responseType.isInstance(obj))
throw IllegalStateException("Expected message of type ${request.responseType.name} but got ${obj.javaClass.name}", request.stackTraceInCaseOfProblems)
iterateStateMachine(psm, obj, newCheckpointKey) {
iterateStateMachine(psm, obj) {
try {
Fiber.unpark(it, QUASAR_UNBLOCKER)
} catch(e: Throwable) {
@ -289,4 +284,4 @@ class StateMachineManager(val serviceHub: ServiceHub, val executor: AffinityExec
}
}
class StackSnapshot : Throwable("This is a stack trace to help identify the source of the underlying problem")
class StackSnapshot : Throwable("This is a stack trace to help identify the source of the underlying problem")

View File

@ -11,8 +11,10 @@ import core.crypto.generateKeyPair
import core.messaging.MessagingService
import core.messaging.StateMachineManager
import core.messaging.runOnNextMessage
import core.node.subsystems.*
import core.node.services.*
import core.node.storage.CheckpointStorage
import core.node.storage.PerFileCheckpointStorage
import core.node.subsystems.*
import core.random63BitValue
import core.serialization.deserialize
import core.serialization.serialize
@ -200,13 +202,14 @@ abstract class AbstractNode(val dir: Path, val configuration: NodeConfiguration,
protected open fun initialiseStorageService(dir: Path): StorageService {
val attachments = makeAttachmentStorage(dir)
val checkpointStorage = PerFileCheckpointStorage(dir.resolve("checkpoints"))
_servicesThatAcceptUploads += attachments
val (identity, keypair) = obtainKeyPair(dir)
return constructStorageService(attachments, keypair, identity)
return constructStorageService(attachments, checkpointStorage, keypair, identity)
}
protected open fun constructStorageService(attachments: NodeAttachmentService, keypair: KeyPair, identity: Party) =
StorageServiceImpl(attachments, keypair, identity)
protected open fun constructStorageService(attachments: NodeAttachmentService, checkpointStorage: CheckpointStorage, keypair: KeyPair, identity: Party) =
StorageServiceImpl(attachments, checkpointStorage, keypair, identity)
private fun obtainKeyPair(dir: Path): Pair<Party, KeyPair> {
// Load the private identity key, creating it if necessary. The identity key is a long term well known key that

View File

@ -0,0 +1,109 @@
package core.node.storage
import core.crypto.sha256
import core.protocols.ProtocolStateMachine
import core.serialization.SerializedBytes
import core.serialization.deserialize
import core.serialization.serialize
import core.utilities.loggerFor
import core.utilities.trace
import java.nio.file.Files
import java.nio.file.Path
import java.nio.file.StandardCopyOption.ATOMIC_MOVE
import java.util.*
import java.util.Collections.synchronizedMap
import javax.annotation.concurrent.ThreadSafe
/**
 * A store of fiber checkpoints. Implementations are expected to be thread-safe.
 */
interface CheckpointStorage {

    /**
     * A snapshot of every checkpoint currently held by the store.
     *
     * The snapshot may contain more checkpoints than were added through this particular instance of the store,
     * for example when the store persists checkpoints to disk.
     */
    val checkpoints: Iterable<Checkpoint>

    /**
     * Adds a new [checkpoint] to the store.
     */
    fun addCheckpoint(checkpoint: Checkpoint)

    /**
     * Removes an existing [checkpoint] from the store.
     *
     * @throws IllegalArgumentException if the checkpoint does not exist in the store; attempting to remove an
     * unknown checkpoint is an error.
     */
    fun removeCheckpoint(checkpoint: Checkpoint)
}
/**
 * File-based checkpoint storage, storing each checkpoint in its own file inside [storeDir].
 *
 * A checkpoint file is named after the lower-cased hash of the serialised checkpoint followed by the `.checkpoint`
 * extension. On construction [storeDir] is created if it does not exist, and every file in it whose name ends with
 * that extension is deserialised and registered.
 */
@ThreadSafe
class PerFileCheckpointStorage(val storeDir: Path) : CheckpointStorage {

    companion object {
        private val logger = loggerFor<PerFileCheckpointStorage>()
        private val fileExtension = ".checkpoint"
    }

    // Maps each known checkpoint to the file backing it. Keys are compared by reference (IdentityHashMap), so a
    // checkpoint can only be removed using the very instance that was added or returned by [checkpoints].
    private val checkpointFiles = synchronizedMap(IdentityHashMap<Checkpoint, Path>())

    init {
        logger.trace { "Initialising per file checkpoint storage on $storeDir" }
        Files.createDirectories(storeDir)
        // Files.list keeps the directory open until the returned stream is closed, so release it explicitly
        // instead of leaking the handle for the lifetime of the process.
        val directoryListing = Files.list(storeDir)
        try {
            directoryListing
                    .filter { it.toString().toLowerCase().endsWith(fileExtension) }
                    .forEach {
                        val checkpoint = Files.readAllBytes(it).deserialize<Checkpoint>()
                        checkpointFiles[checkpoint] = it
                    }
        } finally {
            directoryListing.close()
        }
    }

    /**
     * Serialises [checkpoint], writes it to a file named after the hash of the serialised bytes, and registers it.
     */
    override fun addCheckpoint(checkpoint: Checkpoint) {
        val serialisedCheckpoint = checkpoint.serialize()
        val fileName = "${serialisedCheckpoint.hash.toString().toLowerCase()}$fileExtension"
        val checkpointFile = storeDir.resolve(fileName)
        atomicWrite(checkpointFile, serialisedCheckpoint)
        logger.trace { "Stored $checkpoint to $checkpointFile" }
        checkpointFiles[checkpoint] = checkpointFile
    }

    /**
     * Writes the bytes to a sibling `<name>.tmp` file first and then moves it over [checkpointFile] with
     * ATOMIC_MOVE. The temporary name does not end with the checkpoint extension, so a leftover temporary file is
     * not picked up by the directory scan in `init`.
     */
    private fun atomicWrite(checkpointFile: Path, serialisedCheckpoint: SerializedBytes<Checkpoint>) {
        val tempCheckpointFile = checkpointFile.parent.resolve("${checkpointFile.fileName}.tmp")
        serialisedCheckpoint.writeToFile(tempCheckpointFile)
        Files.move(tempCheckpointFile, checkpointFile, ATOMIC_MOVE)
    }

    /**
     * Unregisters [checkpoint] and deletes its backing file.
     *
     * @throws IllegalArgumentException if this exact checkpoint instance is not known to the store.
     */
    override fun removeCheckpoint(checkpoint: Checkpoint) {
        val checkpointFile = checkpointFiles.remove(checkpoint)
        require(checkpointFile != null) { "Trying to removing unknown checkpoint: $checkpoint" }
        Files.delete(checkpointFile)
        logger.trace { "Removed $checkpoint ($checkpointFile)" }
    }

    // Iterating a view of a synchronizedMap requires holding the map's lock, hence the explicit synchronized block
    // while the key set is copied into a snapshot list.
    override val checkpoints: Iterable<Checkpoint>
        get() = synchronized(checkpointFiles) {
            checkpointFiles.keys.toList()
        }
}
/**
 * A persisted snapshot of a suspended protocol state machine, together with what it is waiting for.
 *
 * This class will be serialised, so everything it points to transitively must also be serialisable (with Kryo).
 *
 * @property serialisedFiber the serialised [ProtocolStateMachine].
 * @property awaitingTopic the topic the state machine is awaiting a message on.
 * @property awaitingObjectOfType the Java class name of the object it is awaiting.
 */
data class Checkpoint(
        val serialisedFiber: SerializedBytes<ProtocolStateMachine<*>>,
        val awaitingTopic: String,
        val awaitingObjectOfType: String // java class name
) {
    // Prints the hash of the serialised fiber in place of the serialised bytes themselves.
    override fun toString(): String {
        val fiberHash = serialisedFiber.sha256()
        return "Checkpoint(#serialisedFiber=$fiberHash, awaitingTopic=$awaitingTopic, awaitingObjectOfType=$awaitingObjectOfType)"
    }
}

View File

@ -4,14 +4,11 @@ import com.codahale.metrics.MetricRegistry
import contracts.Cash
import core.*
import core.crypto.SecureHash
import core.messaging.MessagingService
import core.node.subsystems.NetworkMapCache
import core.node.services.AttachmentStorage
import core.utilities.RecordingMap
import core.node.storage.CheckpointStorage
import java.security.KeyPair
import java.security.PrivateKey
import java.security.PublicKey
import java.time.Clock
import java.util.*
/**
@ -132,7 +129,7 @@ interface StorageService {
*/
val validatedTransactions: MutableMap<SecureHash, SignedTransaction>
val stateMachines: MutableMap<SecureHash, ByteArray>
val checkpointStorage: CheckpointStorage
/** Provides access to storage of arbitrary JAR files (which may contain only data, no code). */
val attachments: AttachmentStorage

View File

@ -4,15 +4,16 @@ import core.Party
import core.SignedTransaction
import core.crypto.SecureHash
import core.node.services.AttachmentStorage
import core.node.subsystems.StorageService
import core.node.storage.CheckpointStorage
import core.utilities.RecordingMap
import org.slf4j.LoggerFactory
import java.security.KeyPair
import java.util.*
open class StorageServiceImpl(attachments: AttachmentStorage,
keypair: KeyPair,
identity: Party = Party("Unit test party", keypair.public),
open class StorageServiceImpl(override val attachments: AttachmentStorage,
override val checkpointStorage: CheckpointStorage,
override val myLegalIdentityKey: KeyPair,
override val myLegalIdentity: Party = Party("Unit test party", myLegalIdentityKey.public),
// This parameter is for unit tests that want to observe operation details.
val recordingAs: (String) -> String = { tableName -> "" })
: StorageService {
@ -36,10 +37,5 @@ open class StorageServiceImpl(attachments: AttachmentStorage,
override val validatedTransactions: MutableMap<SecureHash, SignedTransaction>
get() = getMapOriginal("validated-transactions")
override val stateMachines: MutableMap<SecureHash, ByteArray>
get() = getMapOriginal("state-machines")
override val attachments: AttachmentStorage = attachments
override val myLegalIdentity = identity
override val myLegalIdentityKey = keypair
}

View File

@ -4,17 +4,17 @@ import co.paralleluniverse.fibers.Fiber
import co.paralleluniverse.fibers.FiberScheduler
import co.paralleluniverse.fibers.Suspendable
import co.paralleluniverse.io.serialization.kryo.KryoSerializer
import com.esotericsoftware.kryo.io.Output
import com.google.common.util.concurrent.ListenableFuture
import com.google.common.util.concurrent.SettableFuture
import core.messaging.MessageRecipients
import core.messaging.StateMachineManager
import core.node.ServiceHub
import core.serialization.SerializedBytes
import core.serialization.createKryo
import core.serialization.serialize
import core.utilities.UntrustworthyData
import org.slf4j.Logger
import org.slf4j.LoggerFactory
import java.io.ByteArrayOutputStream
/**
* A ProtocolStateMachine instance is a suspendable fiber that delegates all actual logic to a [ProtocolLogic] instance.
@ -27,7 +27,7 @@ import java.io.ByteArrayOutputStream
class ProtocolStateMachine<R>(val logic: ProtocolLogic<R>, scheduler: FiberScheduler, val loggerName: String) : Fiber<R>("protocol", scheduler) {
// These fields shouldn't be serialised, so they are marked @Transient.
@Transient private var suspendFunc: ((result: StateMachineManager.FiberRequest, serFiber: ByteArray) -> Unit)? = null
@Transient private var suspendAction: ((result: StateMachineManager.FiberRequest, serialisedFiber: SerializedBytes<ProtocolStateMachine<*>>) -> Unit)? = null
@Transient private var resumeWithObject: Any? = null
@Transient lateinit var serviceHub: ServiceHub
@ -54,9 +54,10 @@ class ProtocolStateMachine<R>(val logic: ProtocolLogic<R>, scheduler: FiberSched
logic.psm = this
}
fun prepareForResumeWith(serviceHub: ServiceHub, withObject: Any?,
suspendFunc: (StateMachineManager.FiberRequest, ByteArray) -> Unit) {
this.suspendFunc = suspendFunc
fun prepareForResumeWith(serviceHub: ServiceHub,
withObject: Any?,
suspendAction: (StateMachineManager.FiberRequest, SerializedBytes<ProtocolStateMachine<*>>) -> Unit) {
this.suspendAction = suspendAction
this.resumeWithObject = withObject
this.serviceHub = serviceHub
}
@ -76,16 +77,7 @@ class ProtocolStateMachine<R>(val logic: ProtocolLogic<R>, scheduler: FiberSched
@Suspendable @Suppress("UNCHECKED_CAST")
private fun <T : Any> suspendAndExpectReceive(with: StateMachineManager.FiberRequest): UntrustworthyData<T> {
parkAndSerialize { fiber, serializer ->
// We don't use the passed-in serializer here, because we need to use our own augmented Kryo.
val deserializer = getFiberSerializer(false) as KryoSerializer
val kryo = createKryo(deserializer.kryo)
val stream = ByteArrayOutputStream()
Output(stream).use {
kryo.writeClassAndObject(it, this)
}
suspendFunc!!(with, stream.toByteArray())
}
suspend(with)
val tmp = resumeWithObject ?: throw IllegalStateException("Expected to receive something")
resumeWithObject = null
return UntrustworthyData(tmp as T)
@ -107,6 +99,17 @@ class ProtocolStateMachine<R>(val logic: ProtocolLogic<R>, scheduler: FiberSched
@Suspendable
fun send(topic: String, destination: MessageRecipients, sessionID: Long, obj: Any) {
val result = StateMachineManager.FiberRequest.NotExpectingResponse(topic, destination, sessionID, obj)
parkAndSerialize { fiber, writer -> suspendFunc!!(result, writer.write(fiber)) }
suspend(result)
}
/**
 * Parks this fiber and hands the request [with], along with the serialised form of this state machine, to the
 * suspend action installed by [prepareForResumeWith].
 */
@Suspendable
private fun suspend(with: StateMachineManager.FiberRequest) {
    parkAndSerialize { fiber, serializer ->
        // The serializer passed in is deliberately ignored: we need our own augmented Kryo, which is built on
        // top of the one Quasar uses for fibers.
        val quasarKryo = (getFiberSerializer(false) as KryoSerializer).kryo
        suspendAction!!(with, this.serialize(createKryo(quasarKryo)))
    }
}
}

View File

@ -4,8 +4,10 @@ import com.codahale.metrics.MetricRegistry
import core.crypto.*
import core.messaging.MessagingService
import core.node.ServiceHub
import core.node.subsystems.*
import core.node.services.*
import core.node.storage.Checkpoint
import core.node.storage.CheckpointStorage
import core.node.subsystems.*
import core.serialization.SerializedBytes
import core.serialization.deserialize
import core.testing.MockNetworkMapCache
@ -22,6 +24,7 @@ import java.time.Clock
import java.time.Duration
import java.time.ZoneId
import java.util.*
import java.util.concurrent.ConcurrentLinkedQueue
import java.util.jar.JarInputStream
import javax.annotation.concurrent.ThreadSafe
@ -89,8 +92,25 @@ class MockAttachmentStorage : AttachmentStorage {
}
}
/**
 * In-memory [CheckpointStorage] for tests. Checkpoints are held in a [ConcurrentLinkedQueue] and nothing is
 * written to disk.
 */
class MockCheckpointStorage : CheckpointStorage {

    private val _checkpoints = ConcurrentLinkedQueue<Checkpoint>()

    /** A snapshot copy of the checkpoints currently held. */
    override val checkpoints: Iterable<Checkpoint>
        get() = _checkpoints.toList()

    override fun addCheckpoint(checkpoint: Checkpoint) {
        _checkpoints.add(checkpoint)
    }

    /**
     * Removes [checkpoint] from the store.
     *
     * @throws IllegalArgumentException if the checkpoint is not in the store.
     */
    override fun removeCheckpoint(checkpoint: Checkpoint) {
        // Include the offending checkpoint in the failure so a broken test is diagnosable from the message alone.
        require(_checkpoints.remove(checkpoint)) { "Trying to remove unknown checkpoint: $checkpoint" }
    }
}
@ThreadSafe
class MockStorageService : StorageServiceImpl(MockAttachmentStorage(), generateKeyPair())
class MockStorageService : StorageServiceImpl(MockAttachmentStorage(), MockCheckpointStorage(), generateKeyPair())
class MockServices(
customWallet: WalletService? = null,

View File

@ -7,16 +7,18 @@ import core.crypto.SecureHash
import core.node.NodeConfiguration
import core.node.NodeInfo
import core.node.ServiceHub
import core.node.services.NodeAttachmentService
import core.node.services.ServiceType
import core.node.storage.CheckpointStorage
import core.node.subsystems.NodeWalletService
import core.node.subsystems.StorageService
import core.node.subsystems.StorageServiceImpl
import core.node.subsystems.Wallet
import core.node.services.*
import core.testing.InMemoryMessagingNetwork
import core.testing.MockNetwork
import core.testutils.*
import core.utilities.BriefLogFormatter
import core.utilities.RecordingMap
import org.assertj.core.api.Assertions.assertThat
import org.junit.After
import org.junit.Before
import org.junit.Test
@ -26,7 +28,6 @@ import java.io.ByteArrayOutputStream
import java.nio.file.Path
import java.security.KeyPair
import java.security.PublicKey
import java.util.*
import java.util.concurrent.ExecutionException
import java.util.jar.JarOutputStream
import java.util.zip.ZipEntry
@ -94,11 +95,14 @@ class TwoPartyTradeProtocolTests {
aliceNode.stop()
bobNode.stop()
assertThat(aliceNode.storage.checkpointStorage.checkpoints).isEmpty()
assertThat(bobNode.storage.checkpointStorage.checkpoints).isEmpty()
}
}
@Test
fun shutdownAndRestore() {
fun `shutdown and restore`() {
transactionGroupFor<ContractState> {
var (aliceNode, bobNode) = net.createTwoNodes()
val aliceAddr = aliceNode.net.myAddress
@ -149,9 +153,7 @@ class TwoPartyTradeProtocolTests {
pumpBob()
// OK, now Bob has sent the partial transaction back to Alice and is waiting for Alice's signature.
// Save the state machine to "disk" (i.e. a variable, here)
val savedCheckpoints = HashMap(bobNode.storage.stateMachines)
assertEquals(1, savedCheckpoints.size)
assertThat(bobNode.storage.checkpointStorage.checkpoints).hasSize(1)
// .. and let's imagine that Bob's computer has a power cut. He now has nothing beyond what was on disk.
bobNode.stop()
@ -165,14 +167,7 @@ class TwoPartyTradeProtocolTests {
bobNode = net.createNode(networkMapAddr, bobAddr.id, object : MockNetwork.Factory {
override fun create(dir: Path, config: NodeConfiguration, network: MockNetwork, networkMapAddr: NodeInfo?,
advertisedServices: Set<ServiceType>, id: Int): MockNetwork.MockNode {
return object : MockNetwork.MockNode(dir, config, network, networkMapAddr, advertisedServices, bobAddr.id) {
override fun initialiseStorageService(dir: Path): StorageService {
val ss = super.initialiseStorageService(dir)
val smMap = ss.stateMachines
smMap.putAll(savedCheckpoints)
return ss
}
}
return MockNetwork.MockNode(dir, config, network, networkMapAddr, advertisedServices, bobAddr.id)
}
})
@ -185,8 +180,7 @@ class TwoPartyTradeProtocolTests {
// Bob is now finished and has the same transaction as Alice.
assertEquals(bobFuture.get(), aliceFuture.get())
assertTrue(bobNode.smm.findStateMachines(TwoPartyTradeProtocol.Buyer::class.java).isEmpty())
assertThat(bobNode.smm.findStateMachines(TwoPartyTradeProtocol.Buyer::class.java)).isEmpty()
}
}
@ -200,9 +194,9 @@ class TwoPartyTradeProtocolTests {
advertisedServices: Set<ServiceType>, id: Int): MockNetwork.MockNode {
return object : MockNetwork.MockNode(dir, config, network, networkMapAddr, advertisedServices, id) {
// That constructs the storage service object in a customised way ...
override fun constructStorageService(attachments: NodeAttachmentService, keypair: KeyPair, identity: Party): StorageServiceImpl {
override fun constructStorageService(attachments: NodeAttachmentService, checkpointStorage: CheckpointStorage, keypair: KeyPair, identity: Party): StorageServiceImpl {
// To use RecordingMaps instead of ordinary HashMaps.
return StorageServiceImpl(attachments, keypair, identity, { tableName -> name })
return StorageServiceImpl(attachments, checkpointStorage, keypair, identity, { tableName -> name })
}
}
}

View File

@ -0,0 +1,96 @@
package core.node.storage
import com.google.common.jimfs.Configuration.unix
import com.google.common.jimfs.Jimfs
import com.google.common.primitives.Ints
import core.serialization.SerializedBytes
import org.assertj.core.api.Assertions.assertThat
import org.assertj.core.api.Assertions.assertThatExceptionOfType
import org.junit.After
import org.junit.Before
import org.junit.Test
import java.nio.file.Files
/**
 * Tests for [PerFileCheckpointStorage], run against an in-memory Jimfs file system.
 *
 * A node "restart" is simulated by constructing a fresh storage instance over the same directory.
 */
class PerFileCheckpointStorageTests {

    val fileSystem = Jimfs.newFileSystem(unix())
    val storeDir = fileSystem.getPath("store")
    lateinit var checkpointStorage: PerFileCheckpointStorage

    // Incremented for every generated checkpoint so that each one carries a distinct payload.
    private var nextCheckpointId = 1

    @Before
    fun setUp() {
        reopenStorage()
    }

    @After
    fun cleanUp() {
        fileSystem.close()
    }

    @Test
    fun `add new checkpoint`() {
        val stored = createCheckpoint()
        checkpointStorage.addCheckpoint(stored)
        assertThat(checkpointStorage.checkpoints).containsExactly(stored)

        // The checkpoint must still be present after a restart.
        reopenStorage()
        assertThat(checkpointStorage.checkpoints).containsExactly(stored)
    }

    @Test
    fun `remove checkpoint`() {
        val transient = createCheckpoint()
        checkpointStorage.addCheckpoint(transient)
        checkpointStorage.removeCheckpoint(transient)
        assertThat(checkpointStorage.checkpoints).isEmpty()

        // ... and it must not reappear after a restart.
        reopenStorage()
        assertThat(checkpointStorage.checkpoints).isEmpty()
    }

    @Test
    fun `remove unknown checkpoint`() {
        val neverAdded = createCheckpoint()
        assertThatExceptionOfType(IllegalArgumentException::class.java).isThrownBy {
            checkpointStorage.removeCheckpoint(neverAdded)
        }
    }

    @Test
    fun `add two checkpoints then remove first one`() {
        val removed = createCheckpoint()
        checkpointStorage.addCheckpoint(removed)
        val survivor = createCheckpoint()
        checkpointStorage.addCheckpoint(survivor)

        checkpointStorage.removeCheckpoint(removed)
        assertThat(checkpointStorage.checkpoints).containsExactly(survivor)

        reopenStorage()
        assertThat(checkpointStorage.checkpoints).containsExactly(survivor)
    }

    @Test
    fun `add checkpoint and then remove after 'restart'`() {
        val beforeRestart = createCheckpoint()
        checkpointStorage.addCheckpoint(beforeRestart)
        reopenStorage()

        // After the restart the store holds an equal, but distinct, instance read back from disk.
        val afterRestart = checkpointStorage.checkpoints.single()
        assertThat(afterRestart).isEqualTo(beforeRestart).isNotSameAs(beforeRestart)

        checkpointStorage.removeCheckpoint(afterRestart)
        assertThat(checkpointStorage.checkpoints).isEmpty()
    }

    @Test
    fun `non-checkpoint files are ignored`() {
        val genuine = createCheckpoint()
        checkpointStorage.addCheckpoint(genuine)
        val strayFile = storeDir.resolve("random-non-checkpoint-file")
        Files.write(strayFile, "this is not a checkpoint!!".toByteArray())

        reopenStorage()
        assertThat(checkpointStorage.checkpoints).containsExactly(genuine)
    }

    /** Replaces [checkpointStorage] with a new instance over the same [storeDir]. */
    private fun reopenStorage() {
        checkpointStorage = PerFileCheckpointStorage(storeDir)
    }

    /** Builds a checkpoint whose payload differs from every checkpoint created before it in this test. */
    private fun createCheckpoint(): Checkpoint {
        val payload = Ints.toByteArray(nextCheckpointId++)
        return Checkpoint(SerializedBytes(payload), "topic", "javaType")
    }
}