mirror of
https://github.com/corda/corda.git
synced 2025-06-22 17:09:00 +00:00
Move FiberRequest out to a top level class
Move FiberRequest out to a top level class, both because it is expanding as functionality is added, and to enable alternative state machine implementations to share it.
This commit is contained in:
@ -1,6 +1,7 @@
|
|||||||
package com.r3corda.node.services.api
|
package com.r3corda.node.services.api
|
||||||
|
|
||||||
import com.r3corda.core.serialization.SerializedBytes
|
import com.r3corda.core.serialization.SerializedBytes
|
||||||
|
import com.r3corda.node.services.statemachine.FiberRequest
|
||||||
import com.r3corda.node.services.statemachine.ProtocolStateMachineImpl
|
import com.r3corda.node.services.statemachine.ProtocolStateMachineImpl
|
||||||
import com.r3corda.node.services.statemachine.StateMachineManager
|
import com.r3corda.node.services.statemachine.StateMachineManager
|
||||||
|
|
||||||
@ -32,5 +33,5 @@ interface CheckpointStorage {
|
|||||||
// This class will be serialised, so everything it points to transitively must also be serialisable (with Kryo).
|
// This class will be serialised, so everything it points to transitively must also be serialisable (with Kryo).
|
||||||
data class Checkpoint(
|
data class Checkpoint(
|
||||||
val serialisedFiber: SerializedBytes<ProtocolStateMachineImpl<*>>,
|
val serialisedFiber: SerializedBytes<ProtocolStateMachineImpl<*>>,
|
||||||
val request: StateMachineManager.FiberRequest?
|
val request: FiberRequest?
|
||||||
)
|
)
|
@ -0,0 +1,81 @@
|
|||||||
|
package com.r3corda.node.services.statemachine
|
||||||
|
|
||||||
|
import com.r3corda.core.crypto.Party
|
||||||
|
|
||||||
|
// TODO: Clean this up
|
||||||
|
sealed class FiberRequest(val topic: String,
|
||||||
|
val destination: Party?,
|
||||||
|
val sessionIDForSend: Long,
|
||||||
|
val sessionIDForReceive: Long,
|
||||||
|
val payload: Any?) {
|
||||||
|
// This is used to identify where we suspended, in case of message mismatch errors and other things where we
|
||||||
|
// don't have the original stack trace because it's in a suspended fiber.
|
||||||
|
@Transient
|
||||||
|
val stackTraceInCaseOfProblems: StackSnapshot? = StackSnapshot()
|
||||||
|
|
||||||
|
val receiveTopic: String
|
||||||
|
get() = topic + "." + sessionIDForReceive
|
||||||
|
|
||||||
|
|
||||||
|
override fun equals(other: Any?): Boolean
|
||||||
|
= if (other is FiberRequest) {
|
||||||
|
topic == other.topic
|
||||||
|
&& destination == other.destination
|
||||||
|
&& sessionIDForSend == other.sessionIDForSend
|
||||||
|
&& sessionIDForReceive == other.sessionIDForReceive
|
||||||
|
&& payload == other.payload
|
||||||
|
} else
|
||||||
|
false
|
||||||
|
|
||||||
|
override fun hashCode(): Int {
|
||||||
|
var hash = 1L
|
||||||
|
|
||||||
|
hash = (hash * 31) + topic.hashCode()
|
||||||
|
hash = (hash * 31) + if (destination == null)
|
||||||
|
0
|
||||||
|
else
|
||||||
|
destination.hashCode()
|
||||||
|
hash = (hash * 31) + sessionIDForReceive
|
||||||
|
hash = (hash * 31) + sessionIDForReceive
|
||||||
|
hash = (hash * 31) + if (payload == null)
|
||||||
|
0
|
||||||
|
else
|
||||||
|
payload.hashCode()
|
||||||
|
|
||||||
|
return hash.toInt()
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A fiber which is expecting a message response.
|
||||||
|
*/
|
||||||
|
class ExpectingResponse<R : Any>(
|
||||||
|
topic: String,
|
||||||
|
destination: Party?,
|
||||||
|
sessionIDForSend: Long,
|
||||||
|
sessionIDForReceive: Long,
|
||||||
|
obj: Any?,
|
||||||
|
type: Class<R>
|
||||||
|
) : FiberRequest(topic, destination, sessionIDForSend, sessionIDForReceive, obj) {
|
||||||
|
private val responseTypeName: String = type.name
|
||||||
|
|
||||||
|
override fun equals(other: Any?): Boolean
|
||||||
|
= if (other is ExpectingResponse<*>) {
|
||||||
|
super.equals(other)
|
||||||
|
&& responseTypeName == other.responseTypeName
|
||||||
|
} else
|
||||||
|
false
|
||||||
|
|
||||||
|
override fun toString(): String {
|
||||||
|
return "Expecting response via topic ${receiveTopic} of type ${responseTypeName}"
|
||||||
|
}
|
||||||
|
val responseType: Class<R>
|
||||||
|
get() = Class.forName(responseTypeName) as Class<R>
|
||||||
|
}
|
||||||
|
|
||||||
|
class NotExpectingResponse(
|
||||||
|
topic: String,
|
||||||
|
destination: Party,
|
||||||
|
sessionIDForSend: Long,
|
||||||
|
obj: Any?
|
||||||
|
) : FiberRequest(topic, destination, sessionIDForSend, -1, obj)
|
||||||
|
}
|
@ -28,7 +28,7 @@ class ProtocolStateMachineImpl<R>(val logic: ProtocolLogic<R>,
|
|||||||
|
|
||||||
// These fields shouldn't be serialised, so they are marked @Transient.
|
// These fields shouldn't be serialised, so they are marked @Transient.
|
||||||
@Transient lateinit override var serviceHub: ServiceHubInternal
|
@Transient lateinit override var serviceHub: ServiceHubInternal
|
||||||
@Transient internal lateinit var suspendAction: (StateMachineManager.FiberRequest) -> Unit
|
@Transient internal lateinit var suspendAction: (FiberRequest) -> Unit
|
||||||
@Transient internal lateinit var actionOnEnd: () -> Unit
|
@Transient internal lateinit var actionOnEnd: () -> Unit
|
||||||
@Transient internal var receivedPayload: Any? = null
|
@Transient internal var receivedPayload: Any? = null
|
||||||
|
|
||||||
@ -72,7 +72,7 @@ class ProtocolStateMachineImpl<R>(val logic: ProtocolLogic<R>,
|
|||||||
}
|
}
|
||||||
|
|
||||||
@Suspendable @Suppress("UNCHECKED_CAST")
|
@Suspendable @Suppress("UNCHECKED_CAST")
|
||||||
private fun <T : Any> suspendAndExpectReceive(with: StateMachineManager.FiberRequest): UntrustworthyData<T> {
|
private fun <T : Any> suspendAndExpectReceive(with: FiberRequest): UntrustworthyData<T> {
|
||||||
suspend(with)
|
suspend(with)
|
||||||
check(receivedPayload != null) { "Expected to receive something" }
|
check(receivedPayload != null) { "Expected to receive something" }
|
||||||
val untrustworthy = UntrustworthyData(receivedPayload as T)
|
val untrustworthy = UntrustworthyData(receivedPayload as T)
|
||||||
@ -87,24 +87,24 @@ class ProtocolStateMachineImpl<R>(val logic: ProtocolLogic<R>,
|
|||||||
sessionIDForReceive: Long,
|
sessionIDForReceive: Long,
|
||||||
payload: Any,
|
payload: Any,
|
||||||
recvType: Class<T>): UntrustworthyData<T> {
|
recvType: Class<T>): UntrustworthyData<T> {
|
||||||
val result = StateMachineManager.FiberRequest.ExpectingResponse(topic, destination, sessionIDForSend, sessionIDForReceive, payload, recvType)
|
val result = FiberRequest.ExpectingResponse(topic, destination, sessionIDForSend, sessionIDForReceive, payload, recvType)
|
||||||
return suspendAndExpectReceive(result)
|
return suspendAndExpectReceive(result)
|
||||||
}
|
}
|
||||||
|
|
||||||
@Suspendable
|
@Suspendable
|
||||||
override fun <T : Any> receive(topic: String, sessionIDForReceive: Long, recvType: Class<T>): UntrustworthyData<T> {
|
override fun <T : Any> receive(topic: String, sessionIDForReceive: Long, recvType: Class<T>): UntrustworthyData<T> {
|
||||||
val result = StateMachineManager.FiberRequest.ExpectingResponse(topic, null, -1, sessionIDForReceive, null, recvType)
|
val result = FiberRequest.ExpectingResponse(topic, null, -1, sessionIDForReceive, null, recvType)
|
||||||
return suspendAndExpectReceive(result)
|
return suspendAndExpectReceive(result)
|
||||||
}
|
}
|
||||||
|
|
||||||
@Suspendable
|
@Suspendable
|
||||||
override fun send(topic: String, destination: Party, sessionID: Long, payload: Any) {
|
override fun send(topic: String, destination: Party, sessionID: Long, payload: Any) {
|
||||||
val result = StateMachineManager.FiberRequest.NotExpectingResponse(topic, destination, sessionID, payload)
|
val result = FiberRequest.NotExpectingResponse(topic, destination, sessionID, payload)
|
||||||
suspend(result)
|
suspend(result)
|
||||||
}
|
}
|
||||||
|
|
||||||
@Suspendable
|
@Suspendable
|
||||||
private fun suspend(with: StateMachineManager.FiberRequest) {
|
private fun suspend(with: FiberRequest) {
|
||||||
parkAndSerialize { fiber, serializer ->
|
parkAndSerialize { fiber, serializer ->
|
||||||
try {
|
try {
|
||||||
suspendAction(with)
|
suspendAction(with)
|
||||||
|
@ -8,7 +8,6 @@ import com.esotericsoftware.kryo.Kryo
|
|||||||
import com.google.common.base.Throwables
|
import com.google.common.base.Throwables
|
||||||
import com.google.common.util.concurrent.ListenableFuture
|
import com.google.common.util.concurrent.ListenableFuture
|
||||||
import com.r3corda.core.abbreviate
|
import com.r3corda.core.abbreviate
|
||||||
import com.r3corda.core.crypto.Party
|
|
||||||
import com.r3corda.core.messaging.runOnNextMessage
|
import com.r3corda.core.messaging.runOnNextMessage
|
||||||
import com.r3corda.core.messaging.send
|
import com.r3corda.core.messaging.send
|
||||||
import com.r3corda.core.protocols.ProtocolLogic
|
import com.r3corda.core.protocols.ProtocolLogic
|
||||||
@ -303,7 +302,7 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, tokenizableService
|
|||||||
private fun iterateOnResponse(psm: ProtocolStateMachineImpl<*>,
|
private fun iterateOnResponse(psm: ProtocolStateMachineImpl<*>,
|
||||||
responseType: Class<*>,
|
responseType: Class<*>,
|
||||||
serialisedFiber: SerializedBytes<ProtocolStateMachineImpl<*>>,
|
serialisedFiber: SerializedBytes<ProtocolStateMachineImpl<*>>,
|
||||||
request: StateMachineManager.FiberRequest.ExpectingResponse<*>,
|
request: FiberRequest.ExpectingResponse<*>,
|
||||||
resumeAction: (Any?) -> Unit) {
|
resumeAction: (Any?) -> Unit) {
|
||||||
val topic = request.receiveTopic
|
val topic = request.receiveTopic
|
||||||
serviceHub.networkService.runOnNextMessage(topic, executor) { netMsg ->
|
serviceHub.networkService.runOnNextMessage(topic, executor) { netMsg ->
|
||||||
@ -323,84 +322,6 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, tokenizableService
|
|||||||
iterateStateMachine(psm, payload, resumeAction)
|
iterateStateMachine(psm, payload, resumeAction)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: Clean this up
|
|
||||||
sealed class FiberRequest(val topic: String,
|
|
||||||
val destination: Party?,
|
|
||||||
val sessionIDForSend: Long,
|
|
||||||
val sessionIDForReceive: Long,
|
|
||||||
val payload: Any?) {
|
|
||||||
// This is used to identify where we suspended, in case of message mismatch errors and other things where we
|
|
||||||
// don't have the original stack trace because it's in a suspended fiber.
|
|
||||||
@Transient
|
|
||||||
val stackTraceInCaseOfProblems: StackSnapshot? = StackSnapshot()
|
|
||||||
|
|
||||||
val receiveTopic: String
|
|
||||||
get() = topic + "." + sessionIDForReceive
|
|
||||||
|
|
||||||
|
|
||||||
override fun equals(other: Any?): Boolean
|
|
||||||
= if (other is FiberRequest) {
|
|
||||||
topic == other.topic
|
|
||||||
&& destination == other.destination
|
|
||||||
&& sessionIDForSend == other.sessionIDForSend
|
|
||||||
&& sessionIDForReceive == other.sessionIDForReceive
|
|
||||||
&& payload == other.payload
|
|
||||||
} else
|
|
||||||
false
|
|
||||||
|
|
||||||
override fun hashCode(): Int {
|
|
||||||
var hash = 1L
|
|
||||||
|
|
||||||
hash = (hash * 31) + topic.hashCode()
|
|
||||||
hash = (hash * 31) + if (destination == null)
|
|
||||||
0
|
|
||||||
else
|
|
||||||
destination.hashCode()
|
|
||||||
hash = (hash * 31) + sessionIDForReceive
|
|
||||||
hash = (hash * 31) + sessionIDForReceive
|
|
||||||
hash = (hash * 31) + if (payload == null)
|
|
||||||
0
|
|
||||||
else
|
|
||||||
payload.hashCode()
|
|
||||||
|
|
||||||
return hash.toInt()
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* A fiber which is expecting a message response.
|
|
||||||
*/
|
|
||||||
class ExpectingResponse<R : Any>(
|
|
||||||
topic: String,
|
|
||||||
destination: Party?,
|
|
||||||
sessionIDForSend: Long,
|
|
||||||
sessionIDForReceive: Long,
|
|
||||||
obj: Any?,
|
|
||||||
type: Class<R>
|
|
||||||
) : FiberRequest(topic, destination, sessionIDForSend, sessionIDForReceive, obj) {
|
|
||||||
private val responseTypeName: String = type.name
|
|
||||||
|
|
||||||
override fun equals(other: Any?): Boolean
|
|
||||||
= if (other is FiberRequest.ExpectingResponse<*>) {
|
|
||||||
super.equals(other)
|
|
||||||
&& responseTypeName == other.responseTypeName
|
|
||||||
} else
|
|
||||||
false
|
|
||||||
|
|
||||||
override fun toString(): String {
|
|
||||||
return "Expecting response via topic ${receiveTopic} of type ${responseTypeName}"
|
|
||||||
}
|
|
||||||
val responseType: Class<R>
|
|
||||||
get() = Class.forName(responseTypeName) as Class<R>
|
|
||||||
}
|
|
||||||
|
|
||||||
class NotExpectingResponse(
|
|
||||||
topic: String,
|
|
||||||
destination: Party,
|
|
||||||
sessionIDForSend: Long,
|
|
||||||
obj: Any?
|
|
||||||
) : FiberRequest(topic, destination, sessionIDForSend, -1, obj)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
class StackSnapshot : Throwable("This is a stack trace to help identify the source of the underlying problem")
|
class StackSnapshot : Throwable("This is a stack trace to help identify the source of the underlying problem")
|
||||||
|
@ -6,6 +6,7 @@ import com.google.common.primitives.Ints
|
|||||||
import com.r3corda.core.random63BitValue
|
import com.r3corda.core.random63BitValue
|
||||||
import com.r3corda.core.serialization.SerializedBytes
|
import com.r3corda.core.serialization.SerializedBytes
|
||||||
import com.r3corda.node.services.api.Checkpoint
|
import com.r3corda.node.services.api.Checkpoint
|
||||||
|
import com.r3corda.node.services.statemachine.FiberRequest
|
||||||
import com.r3corda.node.services.statemachine.StateMachineManager
|
import com.r3corda.node.services.statemachine.StateMachineManager
|
||||||
import org.assertj.core.api.Assertions.assertThat
|
import org.assertj.core.api.Assertions.assertThat
|
||||||
import org.assertj.core.api.Assertions.assertThatExceptionOfType
|
import org.assertj.core.api.Assertions.assertThatExceptionOfType
|
||||||
@ -94,7 +95,7 @@ class PerFileCheckpointStorageTests {
|
|||||||
}
|
}
|
||||||
|
|
||||||
private var checkpointCount = 1
|
private var checkpointCount = 1
|
||||||
private val request = StateMachineManager.FiberRequest.ExpectingResponse("topic", null, random63BitValue(), random63BitValue(), null,
|
private val request = FiberRequest.ExpectingResponse("topic", null, random63BitValue(), random63BitValue(), null,
|
||||||
java.lang.String::class.java)
|
java.lang.String::class.java)
|
||||||
private fun newCheckpoint() = Checkpoint(SerializedBytes(Ints.toByteArray(checkpointCount++)), request)
|
private fun newCheckpoint() = Checkpoint(SerializedBytes(Ints.toByteArray(checkpointCount++)), request)
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user