Move FiberRequest out to a top level class

Move FiberRequest out to a top level class, both because it is expanding as functionality is added,
and to enable alternative state machine implementations to share it.
This commit is contained in:
Ross Nicoll 2016-07-24 09:30:16 +01:00
parent 31ee8ab60b
commit 2f04d876ae
5 changed files with 92 additions and 88 deletions

View File

@ -1,6 +1,7 @@
package com.r3corda.node.services.api
import com.r3corda.core.serialization.SerializedBytes
import com.r3corda.node.services.statemachine.FiberRequest
import com.r3corda.node.services.statemachine.ProtocolStateMachineImpl
import com.r3corda.node.services.statemachine.StateMachineManager
@ -32,5 +33,5 @@ interface CheckpointStorage {
// This class will be serialised, so everything it points to transitively must also be serialisable (with Kryo).
data class Checkpoint(
val serialisedFiber: SerializedBytes<ProtocolStateMachineImpl<*>>,
val request: StateMachineManager.FiberRequest?
val request: FiberRequest?
)

View File

@ -0,0 +1,81 @@
package com.r3corda.node.services.statemachine
import com.r3corda.core.crypto.Party
// TODO: Clean this up
sealed class FiberRequest(val topic: String,
val destination: Party?,
val sessionIDForSend: Long,
val sessionIDForReceive: Long,
val payload: Any?) {
// This is used to identify where we suspended, in case of message mismatch errors and other things where we
// don't have the original stack trace because it's in a suspended fiber.
@Transient
val stackTraceInCaseOfProblems: StackSnapshot? = StackSnapshot()
val receiveTopic: String
get() = topic + "." + sessionIDForReceive
override fun equals(other: Any?): Boolean
= if (other is FiberRequest) {
topic == other.topic
&& destination == other.destination
&& sessionIDForSend == other.sessionIDForSend
&& sessionIDForReceive == other.sessionIDForReceive
&& payload == other.payload
} else
false
override fun hashCode(): Int {
var hash = 1L
hash = (hash * 31) + topic.hashCode()
hash = (hash * 31) + if (destination == null)
0
else
destination.hashCode()
hash = (hash * 31) + sessionIDForReceive
hash = (hash * 31) + sessionIDForReceive
hash = (hash * 31) + if (payload == null)
0
else
payload.hashCode()
return hash.toInt()
}
/**
* A fiber which is expecting a message response.
*/
class ExpectingResponse<R : Any>(
topic: String,
destination: Party?,
sessionIDForSend: Long,
sessionIDForReceive: Long,
obj: Any?,
type: Class<R>
) : FiberRequest(topic, destination, sessionIDForSend, sessionIDForReceive, obj) {
private val responseTypeName: String = type.name
override fun equals(other: Any?): Boolean
= if (other is ExpectingResponse<*>) {
super.equals(other)
&& responseTypeName == other.responseTypeName
} else
false
override fun toString(): String {
return "Expecting response via topic ${receiveTopic} of type ${responseTypeName}"
}
val responseType: Class<R>
get() = Class.forName(responseTypeName) as Class<R>
}
class NotExpectingResponse(
topic: String,
destination: Party,
sessionIDForSend: Long,
obj: Any?
) : FiberRequest(topic, destination, sessionIDForSend, -1, obj)
}

View File

@ -28,7 +28,7 @@ class ProtocolStateMachineImpl<R>(val logic: ProtocolLogic<R>,
// These fields shouldn't be serialised, so they are marked @Transient.
@Transient lateinit override var serviceHub: ServiceHubInternal
@Transient internal lateinit var suspendAction: (StateMachineManager.FiberRequest) -> Unit
@Transient internal lateinit var suspendAction: (FiberRequest) -> Unit
@Transient internal lateinit var actionOnEnd: () -> Unit
@Transient internal var receivedPayload: Any? = null
@ -72,7 +72,7 @@ class ProtocolStateMachineImpl<R>(val logic: ProtocolLogic<R>,
}
@Suspendable @Suppress("UNCHECKED_CAST")
private fun <T : Any> suspendAndExpectReceive(with: StateMachineManager.FiberRequest): UntrustworthyData<T> {
private fun <T : Any> suspendAndExpectReceive(with: FiberRequest): UntrustworthyData<T> {
suspend(with)
check(receivedPayload != null) { "Expected to receive something" }
val untrustworthy = UntrustworthyData(receivedPayload as T)
@ -87,24 +87,24 @@ class ProtocolStateMachineImpl<R>(val logic: ProtocolLogic<R>,
sessionIDForReceive: Long,
payload: Any,
recvType: Class<T>): UntrustworthyData<T> {
val result = StateMachineManager.FiberRequest.ExpectingResponse(topic, destination, sessionIDForSend, sessionIDForReceive, payload, recvType)
val result = FiberRequest.ExpectingResponse(topic, destination, sessionIDForSend, sessionIDForReceive, payload, recvType)
return suspendAndExpectReceive(result)
}
@Suspendable
override fun <T : Any> receive(topic: String, sessionIDForReceive: Long, recvType: Class<T>): UntrustworthyData<T> {
val result = StateMachineManager.FiberRequest.ExpectingResponse(topic, null, -1, sessionIDForReceive, null, recvType)
val result = FiberRequest.ExpectingResponse(topic, null, -1, sessionIDForReceive, null, recvType)
return suspendAndExpectReceive(result)
}
@Suspendable
override fun send(topic: String, destination: Party, sessionID: Long, payload: Any) {
val result = StateMachineManager.FiberRequest.NotExpectingResponse(topic, destination, sessionID, payload)
val result = FiberRequest.NotExpectingResponse(topic, destination, sessionID, payload)
suspend(result)
}
@Suspendable
private fun suspend(with: StateMachineManager.FiberRequest) {
private fun suspend(with: FiberRequest) {
parkAndSerialize { fiber, serializer ->
try {
suspendAction(with)

View File

@ -8,7 +8,6 @@ import com.esotericsoftware.kryo.Kryo
import com.google.common.base.Throwables
import com.google.common.util.concurrent.ListenableFuture
import com.r3corda.core.abbreviate
import com.r3corda.core.crypto.Party
import com.r3corda.core.messaging.runOnNextMessage
import com.r3corda.core.messaging.send
import com.r3corda.core.protocols.ProtocolLogic
@ -303,7 +302,7 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, tokenizableService
private fun iterateOnResponse(psm: ProtocolStateMachineImpl<*>,
responseType: Class<*>,
serialisedFiber: SerializedBytes<ProtocolStateMachineImpl<*>>,
request: StateMachineManager.FiberRequest.ExpectingResponse<*>,
request: FiberRequest.ExpectingResponse<*>,
resumeAction: (Any?) -> Unit) {
val topic = request.receiveTopic
serviceHub.networkService.runOnNextMessage(topic, executor) { netMsg ->
@ -323,84 +322,6 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, tokenizableService
iterateStateMachine(psm, payload, resumeAction)
}
}
// TODO: Clean this up
sealed class FiberRequest(val topic: String,
val destination: Party?,
val sessionIDForSend: Long,
val sessionIDForReceive: Long,
val payload: Any?) {
// This is used to identify where we suspended, in case of message mismatch errors and other things where we
// don't have the original stack trace because it's in a suspended fiber.
@Transient
val stackTraceInCaseOfProblems: StackSnapshot? = StackSnapshot()
val receiveTopic: String
get() = topic + "." + sessionIDForReceive
override fun equals(other: Any?): Boolean
= if (other is FiberRequest) {
topic == other.topic
&& destination == other.destination
&& sessionIDForSend == other.sessionIDForSend
&& sessionIDForReceive == other.sessionIDForReceive
&& payload == other.payload
} else
false
override fun hashCode(): Int {
var hash = 1L
hash = (hash * 31) + topic.hashCode()
hash = (hash * 31) + if (destination == null)
0
else
destination.hashCode()
hash = (hash * 31) + sessionIDForReceive
hash = (hash * 31) + sessionIDForReceive
hash = (hash * 31) + if (payload == null)
0
else
payload.hashCode()
return hash.toInt()
}
/**
* A fiber which is expecting a message response.
*/
class ExpectingResponse<R : Any>(
topic: String,
destination: Party?,
sessionIDForSend: Long,
sessionIDForReceive: Long,
obj: Any?,
type: Class<R>
) : FiberRequest(topic, destination, sessionIDForSend, sessionIDForReceive, obj) {
private val responseTypeName: String = type.name
override fun equals(other: Any?): Boolean
= if (other is FiberRequest.ExpectingResponse<*>) {
super.equals(other)
&& responseTypeName == other.responseTypeName
} else
false
override fun toString(): String {
return "Expecting response via topic ${receiveTopic} of type ${responseTypeName}"
}
val responseType: Class<R>
get() = Class.forName(responseTypeName) as Class<R>
}
class NotExpectingResponse(
topic: String,
destination: Party,
sessionIDForSend: Long,
obj: Any?
) : FiberRequest(topic, destination, sessionIDForSend, -1, obj)
}
}
class StackSnapshot : Throwable("This is a stack trace to help identify the source of the underlying problem")

View File

@ -6,6 +6,7 @@ import com.google.common.primitives.Ints
import com.r3corda.core.random63BitValue
import com.r3corda.core.serialization.SerializedBytes
import com.r3corda.node.services.api.Checkpoint
import com.r3corda.node.services.statemachine.FiberRequest
import com.r3corda.node.services.statemachine.StateMachineManager
import org.assertj.core.api.Assertions.assertThat
import org.assertj.core.api.Assertions.assertThatExceptionOfType
@ -94,7 +95,7 @@ class PerFileCheckpointStorageTests {
}
private var checkpointCount = 1
private val request = StateMachineManager.FiberRequest.ExpectingResponse("topic", null, random63BitValue(), random63BitValue(), null,
private val request = FiberRequest.ExpectingResponse("topic", null, random63BitValue(), random63BitValue(), null,
java.lang.String::class.java)
private fun newCheckpoint() = Checkpoint(SerializedBytes(Ints.toByteArray(checkpointCount++)), request)