Merged in rnicoll-checkpoint-generic (pull request #241)

Rework checkpoint storage to include the FiberRequest
This commit is contained in:
Ross Nicoll 2016-07-27 15:20:40 +01:00
commit 8bdeda63ae
5 changed files with 141 additions and 66 deletions

View File

@ -1,7 +1,9 @@
package com.r3corda.node.services.api package com.r3corda.node.services.api
import com.r3corda.core.serialization.SerializedBytes import com.r3corda.core.serialization.SerializedBytes
import com.r3corda.node.services.statemachine.FiberRequest
import com.r3corda.node.services.statemachine.ProtocolStateMachineImpl import com.r3corda.node.services.statemachine.ProtocolStateMachineImpl
import com.r3corda.node.services.statemachine.StateMachineManager
/** /**
* Thread-safe storage of fiber checkpoints. * Thread-safe storage of fiber checkpoints.
@ -31,7 +33,5 @@ interface CheckpointStorage {
// This class will be serialised, so everything it points to transitively must also be serialisable (with Kryo). // This class will be serialised, so everything it points to transitively must also be serialisable (with Kryo).
data class Checkpoint( data class Checkpoint(
val serialisedFiber: SerializedBytes<ProtocolStateMachineImpl<*>>, val serialisedFiber: SerializedBytes<ProtocolStateMachineImpl<*>>,
val awaitingTopic: String?, val request: FiberRequest?
val awaitingPayloadType: String?,
val receivedPayload: Any?
) )

View File

@ -0,0 +1,81 @@
package com.r3corda.node.services.statemachine
import com.r3corda.core.crypto.Party
// TODO: Clean this up
sealed class FiberRequest(val topic: String,
val destination: Party?,
val sessionIDForSend: Long,
val sessionIDForReceive: Long,
val payload: Any?) {
// This is used to identify where we suspended, in case of message mismatch errors and other things where we
// don't have the original stack trace because it's in a suspended fiber.
@Transient
val stackTraceInCaseOfProblems: StackSnapshot? = StackSnapshot()
val receiveTopic: String
get() = topic + "." + sessionIDForReceive
override fun equals(other: Any?): Boolean
= if (other is FiberRequest) {
topic == other.topic
&& destination == other.destination
&& sessionIDForSend == other.sessionIDForSend
&& sessionIDForReceive == other.sessionIDForReceive
&& payload == other.payload
} else
false
override fun hashCode(): Int {
var hash = 1L
hash = (hash * 31) + topic.hashCode()
hash = (hash * 31) + if (destination == null)
0
else
destination.hashCode()
hash = (hash * 31) + sessionIDForReceive
hash = (hash * 31) + sessionIDForReceive
hash = (hash * 31) + if (payload == null)
0
else
payload.hashCode()
return hash.toInt()
}
/**
* A fiber which is expecting a message response.
*/
class ExpectingResponse<R : Any>(
topic: String,
destination: Party?,
sessionIDForSend: Long,
sessionIDForReceive: Long,
obj: Any?,
type: Class<R>
) : FiberRequest(topic, destination, sessionIDForSend, sessionIDForReceive, obj) {
private val responseTypeName: String = type.name
override fun equals(other: Any?): Boolean
= if (other is ExpectingResponse<*>) {
super.equals(other)
&& responseTypeName == other.responseTypeName
} else
false
override fun toString(): String {
return "Expecting response via topic ${receiveTopic} of type ${responseTypeName}"
}
val responseType: Class<R>
get() = Class.forName(responseTypeName) as Class<R>
}
class NotExpectingResponse(
topic: String,
destination: Party,
sessionIDForSend: Long,
obj: Any?
) : FiberRequest(topic, destination, sessionIDForSend, -1, obj)
}

View File

@ -28,7 +28,7 @@ class ProtocolStateMachineImpl<R>(val logic: ProtocolLogic<R>,
// These fields shouldn't be serialised, so they are marked @Transient. // These fields shouldn't be serialised, so they are marked @Transient.
@Transient lateinit override var serviceHub: ServiceHubInternal @Transient lateinit override var serviceHub: ServiceHubInternal
@Transient internal lateinit var suspendAction: (StateMachineManager.FiberRequest) -> Unit @Transient internal lateinit var suspendAction: (FiberRequest) -> Unit
@Transient internal lateinit var actionOnEnd: () -> Unit @Transient internal lateinit var actionOnEnd: () -> Unit
@Transient internal var receivedPayload: Any? = null @Transient internal var receivedPayload: Any? = null
@ -72,7 +72,7 @@ class ProtocolStateMachineImpl<R>(val logic: ProtocolLogic<R>,
} }
@Suspendable @Suppress("UNCHECKED_CAST") @Suspendable @Suppress("UNCHECKED_CAST")
private fun <T : Any> suspendAndExpectReceive(with: StateMachineManager.FiberRequest): UntrustworthyData<T> { private fun <T : Any> suspendAndExpectReceive(with: FiberRequest): UntrustworthyData<T> {
suspend(with) suspend(with)
check(receivedPayload != null) { "Expected to receive something" } check(receivedPayload != null) { "Expected to receive something" }
val untrustworthy = UntrustworthyData(receivedPayload as T) val untrustworthy = UntrustworthyData(receivedPayload as T)
@ -87,24 +87,24 @@ class ProtocolStateMachineImpl<R>(val logic: ProtocolLogic<R>,
sessionIDForReceive: Long, sessionIDForReceive: Long,
payload: Any, payload: Any,
recvType: Class<T>): UntrustworthyData<T> { recvType: Class<T>): UntrustworthyData<T> {
val result = StateMachineManager.FiberRequest.ExpectingResponse(topic, destination, sessionIDForSend, sessionIDForReceive, payload, recvType) val result = FiberRequest.ExpectingResponse(topic, destination, sessionIDForSend, sessionIDForReceive, payload, recvType)
return suspendAndExpectReceive(result) return suspendAndExpectReceive(result)
} }
@Suspendable @Suspendable
override fun <T : Any> receive(topic: String, sessionIDForReceive: Long, recvType: Class<T>): UntrustworthyData<T> { override fun <T : Any> receive(topic: String, sessionIDForReceive: Long, recvType: Class<T>): UntrustworthyData<T> {
val result = StateMachineManager.FiberRequest.ExpectingResponse(topic, null, -1, sessionIDForReceive, null, recvType) val result = FiberRequest.ExpectingResponse(topic, null, -1, sessionIDForReceive, null, recvType)
return suspendAndExpectReceive(result) return suspendAndExpectReceive(result)
} }
@Suspendable @Suspendable
override fun send(topic: String, destination: Party, sessionID: Long, payload: Any) { override fun send(topic: String, destination: Party, sessionID: Long, payload: Any) {
val result = StateMachineManager.FiberRequest.NotExpectingResponse(topic, destination, sessionID, payload) val result = FiberRequest.NotExpectingResponse(topic, destination, sessionID, payload)
suspend(result) suspend(result)
} }
@Suspendable @Suspendable
private fun suspend(with: StateMachineManager.FiberRequest) { private fun suspend(with: FiberRequest) {
parkAndSerialize { fiber, serializer -> parkAndSerialize { fiber, serializer ->
try { try {
suspendAction(with) suspendAction(with)

View File

@ -8,7 +8,6 @@ import com.esotericsoftware.kryo.Kryo
import com.google.common.base.Throwables import com.google.common.base.Throwables
import com.google.common.util.concurrent.ListenableFuture import com.google.common.util.concurrent.ListenableFuture
import com.r3corda.core.abbreviate import com.r3corda.core.abbreviate
import com.r3corda.core.crypto.Party
import com.r3corda.core.messaging.runOnNextMessage import com.r3corda.core.messaging.runOnNextMessage
import com.r3corda.core.messaging.send import com.r3corda.core.messaging.send
import com.r3corda.core.protocols.ProtocolLogic import com.r3corda.core.protocols.ProtocolLogic
@ -123,28 +122,40 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, tokenizableService
val fiber = deserializeFiber(checkpoint.serialisedFiber) val fiber = deserializeFiber(checkpoint.serialisedFiber)
initFiber(fiber, { checkpoint }) initFiber(fiber, { checkpoint })
val topic = checkpoint.awaitingTopic when (checkpoint.request) {
if (topic != null) { is FiberRequest.ExpectingResponse<*> -> {
val awaitingPayloadType = Class.forName(checkpoint.awaitingPayloadType) val topic = checkpoint.request.receiveTopic
fiber.logger.info("Restored ${fiber.logic} - it was previously waiting for message of type ${awaitingPayloadType.name} on topic $topic") val awaitingPayloadType = checkpoint.request.responseType
iterateOnResponse(fiber, awaitingPayloadType, checkpoint.serialisedFiber, topic) { fiber.logger.info("Restored ${fiber.logic} - it was previously waiting for message of type ${awaitingPayloadType.name} on topic $topic")
try { iterateOnResponse(fiber, awaitingPayloadType, checkpoint.serialisedFiber, checkpoint.request) {
Fiber.unparkDeserialized(fiber, scheduler)
} catch (e: Throwable) {
logError(e, it, topic, fiber)
}
}
} else {
fiber.logger.info("Restored ${fiber.logic} - it was not waiting on any message; received payload: ${checkpoint.receivedPayload.toString().abbreviate(50)}")
executor.executeASAP {
iterateStateMachine(fiber, checkpoint.receivedPayload) {
try { try {
Fiber.unparkDeserialized(fiber, scheduler) Fiber.unparkDeserialized(fiber, scheduler)
} catch (e: Throwable) { } catch (e: Throwable) {
logError(e, it, null, fiber) logError(e, it, topic, fiber)
} }
} }
} }
is FiberRequest.NotExpectingResponse -> restoreNotExpectingResponse(fiber, checkpoint.request)
null -> restoreNotExpectingResponse(fiber)
else -> throw IllegalStateException("Unknown fiber request type " + checkpoint.request)
}
}
/**
* Restore a Fiber which was not expecting a response (either because it hasn't asked for one, such as a new
* Fiber, or because the detail it needed has arrived).
*/
private fun restoreNotExpectingResponse(fiber: ProtocolStateMachineImpl<*>, request: FiberRequest? = null) {
val payload = request?.payload
fiber.logger.info("Restored ${fiber.logic} - it was not waiting on any message; received payload: ${payload.toString().abbreviate(50)}")
executor.executeASAP {
iterateStateMachine(fiber, payload) {
try {
Fiber.unparkDeserialized(fiber, scheduler)
} catch (e: Throwable) {
logError(e, it, null, fiber)
}
}
} }
} }
@ -207,7 +218,7 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, tokenizableService
val fiber = ProtocolStateMachineImpl(logic, scheduler, loggerName) val fiber = ProtocolStateMachineImpl(logic, scheduler, loggerName)
// Need to add before iterating in case of immediate completion // Need to add before iterating in case of immediate completion
initFiber(fiber) { initFiber(fiber) {
val checkpoint = Checkpoint(serializeFiber(fiber), null, null, null) val checkpoint = Checkpoint(serializeFiber(fiber), null)
checkpointStorage.addCheckpoint(checkpoint) checkpointStorage.addCheckpoint(checkpoint)
checkpoint checkpoint
} }
@ -226,10 +237,8 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, tokenizableService
private fun updateCheckpoint(psm: ProtocolStateMachineImpl<*>, private fun updateCheckpoint(psm: ProtocolStateMachineImpl<*>,
serialisedFiber: SerializedBytes<ProtocolStateMachineImpl<*>>, serialisedFiber: SerializedBytes<ProtocolStateMachineImpl<*>>,
awaitingTopic: String?, request: FiberRequest) {
awaitingPayloadType: Class<*>?, val newCheckpoint = Checkpoint(serialisedFiber, request)
receivedPayload: Any?) {
val newCheckpoint = Checkpoint(serialisedFiber, awaitingTopic, awaitingPayloadType?.name, receivedPayload)
val previousCheckpoint = stateMachines.put(psm, newCheckpoint) val previousCheckpoint = stateMachines.put(psm, newCheckpoint)
if (previousCheckpoint != null) { if (previousCheckpoint != null) {
checkpointStorage.removeCheckpoint(previousCheckpoint) checkpointStorage.removeCheckpoint(previousCheckpoint)
@ -249,9 +258,11 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, tokenizableService
private fun onNextSuspend(psm: ProtocolStateMachineImpl<*>, request: FiberRequest) { private fun onNextSuspend(psm: ProtocolStateMachineImpl<*>, request: FiberRequest) {
// We have a request to do something: send, receive, or send-and-receive. // We have a request to do something: send, receive, or send-and-receive.
if (request is FiberRequest.ExpectingResponse<*>) { when (request) {
// Prepare a listener on the network that runs in the background thread when we receive a message. is FiberRequest.ExpectingResponse<*> -> {
checkpointOnExpectingResponse(psm, request) // Prepare a listener on the network that runs in the background thread when we receive a message.
checkpointOnExpectingResponse(psm, request)
}
} }
// If a non-null payload to send was provided, send it now. // If a non-null payload to send was provided, send it now.
request.payload?.let { request.payload?.let {
@ -277,9 +288,9 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, tokenizableService
executor.checkOnThread() executor.checkOnThread()
val topic = "${request.topic}.${request.sessionIDForReceive}" val topic = "${request.topic}.${request.sessionIDForReceive}"
val serialisedFiber = serializeFiber(psm) val serialisedFiber = serializeFiber(psm)
updateCheckpoint(psm, serialisedFiber, topic, request.responseType, null) updateCheckpoint(psm, serialisedFiber, request)
psm.logger.trace { "Preparing to receive message of type ${request.responseType.name} on topic $topic" } psm.logger.trace { "Preparing to receive message of type ${request.responseType.name} on topic $topic" }
iterateOnResponse(psm, request.responseType, serialisedFiber, topic) { iterateOnResponse(psm, request.responseType, serialisedFiber, request) {
try { try {
Fiber.unpark(psm, QUASAR_UNBLOCKER) Fiber.unpark(psm, QUASAR_UNBLOCKER)
} catch(e: Throwable) { } catch(e: Throwable) {
@ -288,11 +299,16 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, tokenizableService
} }
} }
/**
* Add a trigger to the [MessagingService] to deserialize the fiber and pass message content to it, once a message is
* received.
*/
private fun iterateOnResponse(psm: ProtocolStateMachineImpl<*>, private fun iterateOnResponse(psm: ProtocolStateMachineImpl<*>,
responseType: Class<*>, responseType: Class<*>,
serialisedFiber: SerializedBytes<ProtocolStateMachineImpl<*>>, serialisedFiber: SerializedBytes<ProtocolStateMachineImpl<*>>,
topic: String, request: FiberRequest.ExpectingResponse<*>,
resumeAction: (Any?) -> Unit) { resumeAction: (Any?) -> Unit) {
val topic = request.receiveTopic
serviceHub.networkService.runOnNextMessage(topic, executor) { netMsg -> serviceHub.networkService.runOnNextMessage(topic, executor) { netMsg ->
// Assertion to ensure we don't execute on the wrong thread. // Assertion to ensure we don't execute on the wrong thread.
executor.checkOnThread() executor.checkOnThread()
@ -305,38 +321,11 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, tokenizableService
val payload = netMsg.data.deserialize<Any>() val payload = netMsg.data.deserialize<Any>()
check(responseType.isInstance(payload)) { "Expected message of type ${responseType.name} but got ${payload.javaClass.name}" } check(responseType.isInstance(payload)) { "Expected message of type ${responseType.name} but got ${payload.javaClass.name}" }
// Update the fiber's checkpoint so that it's no longer waiting on a response, but rather has the received payload // Update the fiber's checkpoint so that it's no longer waiting on a response, but rather has the received payload
updateCheckpoint(psm, serialisedFiber, null, null, payload) updateCheckpoint(psm, serialisedFiber, request)
psm.logger.trace { "Received message of type ${payload.javaClass.name} on topic $topic (${payload.toString().abbreviate(50)})" } psm.logger.trace { "Received message of type ${payload.javaClass.name} on topic $topic (${payload.toString().abbreviate(50)})" }
iterateStateMachine(psm, payload, resumeAction) iterateStateMachine(psm, payload, resumeAction)
} }
} }
// TODO: Clean this up
open class FiberRequest(val topic: String,
val destination: Party?,
val sessionIDForSend: Long,
val sessionIDForReceive: Long,
val payload: Any?) {
// This is used to identify where we suspended, in case of message mismatch errors and other things where we
// don't have the original stack trace because it's in a suspended fiber.
val stackTraceInCaseOfProblems = StackSnapshot()
class ExpectingResponse<R : Any>(
topic: String,
destination: Party?,
sessionIDForSend: Long,
sessionIDForReceive: Long,
obj: Any?,
val responseType: Class<R>
) : FiberRequest(topic, destination, sessionIDForSend, sessionIDForReceive, obj)
class NotExpectingResponse(
topic: String,
destination: Party,
sessionIDForSend: Long,
obj: Any?
) : FiberRequest(topic, destination, sessionIDForSend, -1, obj)
}
} }
class StackSnapshot : Throwable("This is a stack trace to help identify the source of the underlying problem") class StackSnapshot : Throwable("This is a stack trace to help identify the source of the underlying problem")

View File

@ -3,8 +3,11 @@ package com.r3corda.node.services.persistence
import com.google.common.jimfs.Configuration.unix import com.google.common.jimfs.Configuration.unix
import com.google.common.jimfs.Jimfs import com.google.common.jimfs.Jimfs
import com.google.common.primitives.Ints import com.google.common.primitives.Ints
import com.r3corda.core.random63BitValue
import com.r3corda.core.serialization.SerializedBytes import com.r3corda.core.serialization.SerializedBytes
import com.r3corda.node.services.api.Checkpoint import com.r3corda.node.services.api.Checkpoint
import com.r3corda.node.services.statemachine.FiberRequest
import com.r3corda.node.services.statemachine.StateMachineManager
import org.assertj.core.api.Assertions.assertThat import org.assertj.core.api.Assertions.assertThat
import org.assertj.core.api.Assertions.assertThatExceptionOfType import org.assertj.core.api.Assertions.assertThatExceptionOfType
import org.junit.After import org.junit.After
@ -92,6 +95,8 @@ class PerFileCheckpointStorageTests {
} }
private var checkpointCount = 1 private var checkpointCount = 1
private fun newCheckpoint() = Checkpoint(SerializedBytes(Ints.toByteArray(checkpointCount++)), "topic", "javaType", null) private val request = FiberRequest.ExpectingResponse("topic", null, random63BitValue(), random63BitValue(), null,
java.lang.String::class.java)
private fun newCheckpoint() = Checkpoint(SerializedBytes(Ints.toByteArray(checkpointCount++)), request)
} }