mirror of
https://github.com/corda/corda.git
synced 2024-12-20 05:28:21 +00:00
Rework checkpoint storage to include the FiberRequest
Rework checkpoint storage to include the FiberRequest, so that different requests can be supported.
This commit is contained in:
parent
2c139ae40c
commit
31ee8ab60b
@ -2,6 +2,7 @@ package com.r3corda.node.services.api
|
||||
|
||||
import com.r3corda.core.serialization.SerializedBytes
|
||||
import com.r3corda.node.services.statemachine.ProtocolStateMachineImpl
|
||||
import com.r3corda.node.services.statemachine.StateMachineManager
|
||||
|
||||
/**
|
||||
* Thread-safe storage of fiber checkpoints.
|
||||
@ -31,7 +32,5 @@ interface CheckpointStorage {
|
||||
// This class will be serialised, so everything it points to transitively must also be serialisable (with Kryo).
|
||||
data class Checkpoint(
|
||||
val serialisedFiber: SerializedBytes<ProtocolStateMachineImpl<*>>,
|
||||
val awaitingTopic: String?,
|
||||
val awaitingPayloadType: String?,
|
||||
val receivedPayload: Any?
|
||||
val request: StateMachineManager.FiberRequest?
|
||||
)
|
@ -119,28 +119,40 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, tokenizableService
|
||||
val fiber = deserializeFiber(checkpoint.serialisedFiber)
|
||||
initFiber(fiber, { checkpoint })
|
||||
|
||||
val topic = checkpoint.awaitingTopic
|
||||
if (topic != null) {
|
||||
val awaitingPayloadType = Class.forName(checkpoint.awaitingPayloadType)
|
||||
fiber.logger.info("Restored ${fiber.logic} - it was previously waiting for message of type ${awaitingPayloadType.name} on topic $topic")
|
||||
iterateOnResponse(fiber, awaitingPayloadType, checkpoint.serialisedFiber, topic) {
|
||||
try {
|
||||
Fiber.unparkDeserialized(fiber, scheduler)
|
||||
} catch (e: Throwable) {
|
||||
logError(e, it, topic, fiber)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
fiber.logger.info("Restored ${fiber.logic} - it was not waiting on any message; received payload: ${checkpoint.receivedPayload.toString().abbreviate(50)}")
|
||||
executor.executeASAP {
|
||||
iterateStateMachine(fiber, checkpoint.receivedPayload) {
|
||||
when (checkpoint.request) {
|
||||
is FiberRequest.ExpectingResponse<*> -> {
|
||||
val topic = checkpoint.request.receiveTopic
|
||||
val awaitingPayloadType = checkpoint.request.responseType
|
||||
fiber.logger.info("Restored ${fiber.logic} - it was previously waiting for message of type ${awaitingPayloadType.name} on topic $topic")
|
||||
iterateOnResponse(fiber, awaitingPayloadType, checkpoint.serialisedFiber, checkpoint.request) {
|
||||
try {
|
||||
Fiber.unparkDeserialized(fiber, scheduler)
|
||||
} catch (e: Throwable) {
|
||||
logError(e, it, null, fiber)
|
||||
logError(e, it, topic, fiber)
|
||||
}
|
||||
}
|
||||
}
|
||||
is FiberRequest.NotExpectingResponse -> restoreNotExpectingResponse(fiber, checkpoint.request)
|
||||
null -> restoreNotExpectingResponse(fiber)
|
||||
else -> throw IllegalStateException("Unknown fiber request type " + checkpoint.request)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Restore a Fiber which was not expecting a response (either because it hasn't asked for one, such as a new
|
||||
* Fiber, or because the detail it needed has arrived).
|
||||
*/
|
||||
private fun restoreNotExpectingResponse(fiber: ProtocolStateMachineImpl<*>, request: FiberRequest? = null) {
|
||||
val payload = request?.payload
|
||||
fiber.logger.info("Restored ${fiber.logic} - it was not waiting on any message; received payload: ${payload.toString().abbreviate(50)}")
|
||||
executor.executeASAP {
|
||||
iterateStateMachine(fiber, payload) {
|
||||
try {
|
||||
Fiber.unparkDeserialized(fiber, scheduler)
|
||||
} catch (e: Throwable) {
|
||||
logError(e, it, null, fiber)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -203,7 +215,7 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, tokenizableService
|
||||
val fiber = ProtocolStateMachineImpl(logic, scheduler, loggerName)
|
||||
// Need to add before iterating in case of immediate completion
|
||||
initFiber(fiber) {
|
||||
val checkpoint = Checkpoint(serializeFiber(fiber), null, null, null)
|
||||
val checkpoint = Checkpoint(serializeFiber(fiber), null)
|
||||
checkpointStorage.addCheckpoint(checkpoint)
|
||||
checkpoint
|
||||
}
|
||||
@ -222,10 +234,8 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, tokenizableService
|
||||
|
||||
private fun updateCheckpoint(psm: ProtocolStateMachineImpl<*>,
|
||||
serialisedFiber: SerializedBytes<ProtocolStateMachineImpl<*>>,
|
||||
awaitingTopic: String?,
|
||||
awaitingPayloadType: Class<*>?,
|
||||
receivedPayload: Any?) {
|
||||
val newCheckpoint = Checkpoint(serialisedFiber, awaitingTopic, awaitingPayloadType?.name, receivedPayload)
|
||||
request: FiberRequest) {
|
||||
val newCheckpoint = Checkpoint(serialisedFiber, request)
|
||||
val previousCheckpoint = stateMachines.put(psm, newCheckpoint)
|
||||
if (previousCheckpoint != null) {
|
||||
checkpointStorage.removeCheckpoint(previousCheckpoint)
|
||||
@ -245,9 +255,11 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, tokenizableService
|
||||
|
||||
private fun onNextSuspend(psm: ProtocolStateMachineImpl<*>, request: FiberRequest) {
|
||||
// We have a request to do something: send, receive, or send-and-receive.
|
||||
if (request is FiberRequest.ExpectingResponse<*>) {
|
||||
// Prepare a listener on the network that runs in the background thread when we receive a message.
|
||||
checkpointOnExpectingResponse(psm, request)
|
||||
when (request) {
|
||||
is FiberRequest.ExpectingResponse<*> -> {
|
||||
// Prepare a listener on the network that runs in the background thread when we receive a message.
|
||||
checkpointOnExpectingResponse(psm, request)
|
||||
}
|
||||
}
|
||||
// If a non-null payload to send was provided, send it now.
|
||||
request.payload?.let {
|
||||
@ -273,9 +285,9 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, tokenizableService
|
||||
executor.checkOnThread()
|
||||
val topic = "${request.topic}.${request.sessionIDForReceive}"
|
||||
val serialisedFiber = serializeFiber(psm)
|
||||
updateCheckpoint(psm, serialisedFiber, topic, request.responseType, null)
|
||||
updateCheckpoint(psm, serialisedFiber, request)
|
||||
psm.logger.trace { "Preparing to receive message of type ${request.responseType.name} on topic $topic" }
|
||||
iterateOnResponse(psm, request.responseType, serialisedFiber, topic) {
|
||||
iterateOnResponse(psm, request.responseType, serialisedFiber, request) {
|
||||
try {
|
||||
Fiber.unpark(psm, QUASAR_UNBLOCKER)
|
||||
} catch(e: Throwable) {
|
||||
@ -284,11 +296,16 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, tokenizableService
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Add a trigger to the [MessagingService] to deserialize the fiber and pass message content to it, once a message is
|
||||
* received.
|
||||
*/
|
||||
private fun iterateOnResponse(psm: ProtocolStateMachineImpl<*>,
|
||||
responseType: Class<*>,
|
||||
serialisedFiber: SerializedBytes<ProtocolStateMachineImpl<*>>,
|
||||
topic: String,
|
||||
request: StateMachineManager.FiberRequest.ExpectingResponse<*>,
|
||||
resumeAction: (Any?) -> Unit) {
|
||||
val topic = request.receiveTopic
|
||||
serviceHub.networkService.runOnNextMessage(topic, executor) { netMsg ->
|
||||
// Assertion to ensure we don't execute on the wrong thread.
|
||||
executor.checkOnThread()
|
||||
@ -301,30 +318,81 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, tokenizableService
|
||||
val payload = netMsg.data.deserialize<Any>()
|
||||
check(responseType.isInstance(payload)) { "Expected message of type ${responseType.name} but got ${payload.javaClass.name}" }
|
||||
// Update the fiber's checkpoint so that it's no longer waiting on a response, but rather has the received payload
|
||||
updateCheckpoint(psm, serialisedFiber, null, null, payload)
|
||||
updateCheckpoint(psm, serialisedFiber, request)
|
||||
psm.logger.trace { "Received message of type ${payload.javaClass.name} on topic $topic (${payload.toString().abbreviate(50)})" }
|
||||
iterateStateMachine(psm, payload, resumeAction)
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: Clean this up
|
||||
open class FiberRequest(val topic: String,
|
||||
val destination: Party?,
|
||||
val sessionIDForSend: Long,
|
||||
val sessionIDForReceive: Long,
|
||||
val payload: Any?) {
|
||||
sealed class FiberRequest(val topic: String,
|
||||
val destination: Party?,
|
||||
val sessionIDForSend: Long,
|
||||
val sessionIDForReceive: Long,
|
||||
val payload: Any?) {
|
||||
// This is used to identify where we suspended, in case of message mismatch errors and other things where we
|
||||
// don't have the original stack trace because it's in a suspended fiber.
|
||||
val stackTraceInCaseOfProblems = StackSnapshot()
|
||||
@Transient
|
||||
val stackTraceInCaseOfProblems: StackSnapshot? = StackSnapshot()
|
||||
|
||||
val receiveTopic: String
|
||||
get() = topic + "." + sessionIDForReceive
|
||||
|
||||
|
||||
override fun equals(other: Any?): Boolean
|
||||
= if (other is FiberRequest) {
|
||||
topic == other.topic
|
||||
&& destination == other.destination
|
||||
&& sessionIDForSend == other.sessionIDForSend
|
||||
&& sessionIDForReceive == other.sessionIDForReceive
|
||||
&& payload == other.payload
|
||||
} else
|
||||
false
|
||||
|
||||
override fun hashCode(): Int {
|
||||
var hash = 1L
|
||||
|
||||
hash = (hash * 31) + topic.hashCode()
|
||||
hash = (hash * 31) + if (destination == null)
|
||||
0
|
||||
else
|
||||
destination.hashCode()
|
||||
hash = (hash * 31) + sessionIDForReceive
|
||||
hash = (hash * 31) + sessionIDForReceive
|
||||
hash = (hash * 31) + if (payload == null)
|
||||
0
|
||||
else
|
||||
payload.hashCode()
|
||||
|
||||
return hash.toInt()
|
||||
}
|
||||
|
||||
/**
|
||||
* A fiber which is expecting a message response.
|
||||
*/
|
||||
class ExpectingResponse<R : Any>(
|
||||
topic: String,
|
||||
destination: Party?,
|
||||
sessionIDForSend: Long,
|
||||
sessionIDForReceive: Long,
|
||||
obj: Any?,
|
||||
val responseType: Class<R>
|
||||
) : FiberRequest(topic, destination, sessionIDForSend, sessionIDForReceive, obj)
|
||||
type: Class<R>
|
||||
) : FiberRequest(topic, destination, sessionIDForSend, sessionIDForReceive, obj) {
|
||||
private val responseTypeName: String = type.name
|
||||
|
||||
override fun equals(other: Any?): Boolean
|
||||
= if (other is FiberRequest.ExpectingResponse<*>) {
|
||||
super.equals(other)
|
||||
&& responseTypeName == other.responseTypeName
|
||||
} else
|
||||
false
|
||||
|
||||
override fun toString(): String {
|
||||
return "Expecting response via topic ${receiveTopic} of type ${responseTypeName}"
|
||||
}
|
||||
val responseType: Class<R>
|
||||
get() = Class.forName(responseTypeName) as Class<R>
|
||||
}
|
||||
|
||||
class NotExpectingResponse(
|
||||
topic: String,
|
||||
|
@ -3,8 +3,10 @@ package com.r3corda.node.services.persistence
|
||||
import com.google.common.jimfs.Configuration.unix
|
||||
import com.google.common.jimfs.Jimfs
|
||||
import com.google.common.primitives.Ints
|
||||
import com.r3corda.core.random63BitValue
|
||||
import com.r3corda.core.serialization.SerializedBytes
|
||||
import com.r3corda.node.services.api.Checkpoint
|
||||
import com.r3corda.node.services.statemachine.StateMachineManager
|
||||
import org.assertj.core.api.Assertions.assertThat
|
||||
import org.assertj.core.api.Assertions.assertThatExceptionOfType
|
||||
import org.junit.After
|
||||
@ -92,6 +94,8 @@ class PerFileCheckpointStorageTests {
|
||||
}
|
||||
|
||||
private var checkpointCount = 1
|
||||
private fun newCheckpoint() = Checkpoint(SerializedBytes(Ints.toByteArray(checkpointCount++)), "topic", "javaType", null)
|
||||
private val request = StateMachineManager.FiberRequest.ExpectingResponse("topic", null, random63BitValue(), random63BitValue(), null,
|
||||
java.lang.String::class.java)
|
||||
private fun newCheckpoint() = Checkpoint(SerializedBytes(Ints.toByteArray(checkpointCount++)), request)
|
||||
|
||||
}
|
Loading…
Reference in New Issue
Block a user