mirror of
https://github.com/corda/corda.git
synced 2025-06-22 17:09:00 +00:00
Automatic session management between two protocols, and removal of explict topics
This commit is contained in:
@ -24,6 +24,7 @@ import com.r3corda.core.serialization.SingletonSerializeAsToken
|
||||
import com.r3corda.core.serialization.deserialize
|
||||
import com.r3corda.core.serialization.serialize
|
||||
import com.r3corda.core.transactions.SignedTransaction
|
||||
import com.r3corda.core.utilities.debug
|
||||
import com.r3corda.node.api.APIServer
|
||||
import com.r3corda.node.services.api.*
|
||||
import com.r3corda.node.services.config.NodeConfiguration
|
||||
@ -54,8 +55,10 @@ import java.nio.file.Path
|
||||
import java.security.KeyPair
|
||||
import java.time.Clock
|
||||
import java.util.*
|
||||
import java.util.concurrent.ConcurrentHashMap
|
||||
import java.util.concurrent.ExecutorService
|
||||
import java.util.concurrent.TimeUnit
|
||||
import kotlin.reflect.KClass
|
||||
|
||||
/**
|
||||
* A base node implementation that can be customised either for production (with real implementations that do real
|
||||
@ -91,6 +94,8 @@ abstract class AbstractNode(val configuration: NodeConfiguration, val networkMap
|
||||
protected val _servicesThatAcceptUploads = ArrayList<AcceptsFileUpload>()
|
||||
val servicesThatAcceptUploads: List<AcceptsFileUpload> = _servicesThatAcceptUploads
|
||||
|
||||
private val protocolFactories = ConcurrentHashMap<Class<*>, (Party) -> ProtocolLogic<*>>()
|
||||
|
||||
val services = object : ServiceHubInternal() {
|
||||
override val networkService: MessagingServiceInternal get() = net
|
||||
override val networkMapCache: NetworkMapCache get() = netMapCache
|
||||
@ -109,6 +114,16 @@ abstract class AbstractNode(val configuration: NodeConfiguration, val networkMap
|
||||
return smm.add(loggerName, logic).resultFuture
|
||||
}
|
||||
|
||||
override fun registerProtocolInitiator(markerClass: KClass<*>, protocolFactory: (Party) -> ProtocolLogic<*>) {
|
||||
require(markerClass !in protocolFactories) { "${markerClass.java.name} has already been used to register a protocol" }
|
||||
log.debug { "Registering ${markerClass.java.name}" }
|
||||
protocolFactories[markerClass.java] = protocolFactory
|
||||
}
|
||||
|
||||
override fun getProtocolFactory(markerClass: Class<*>): ((Party) -> ProtocolLogic<*>)? {
|
||||
return protocolFactories[markerClass]
|
||||
}
|
||||
|
||||
override fun recordTransactions(txs: Iterable<SignedTransaction>) = recordTransactionsInternal(storage, txs)
|
||||
}
|
||||
|
||||
|
@ -1,11 +1,9 @@
|
||||
package com.r3corda.node.services
|
||||
|
||||
import com.r3corda.core.node.CordaPluginRegistry
|
||||
import com.r3corda.node.services.api.AbstractNodeService
|
||||
import com.r3corda.core.serialization.SingletonSerializeAsToken
|
||||
import com.r3corda.node.services.api.ServiceHubInternal
|
||||
import com.r3corda.protocols.AbstractStateReplacementProtocol
|
||||
import com.r3corda.protocols.NotaryChangeProtocol
|
||||
import com.r3corda.protocols.NotaryChangeProtocol.TOPIC
|
||||
|
||||
object NotaryChange {
|
||||
class Plugin : CordaPluginRegistry() {
|
||||
@ -16,11 +14,9 @@ object NotaryChange {
|
||||
* A service that monitors the network for requests for changing the notary of a state,
|
||||
* and immediately runs the [NotaryChangeProtocol] if the auto-accept criteria are met.
|
||||
*/
|
||||
class Service(services: ServiceHubInternal) : AbstractNodeService(services) {
|
||||
class Service(services: ServiceHubInternal) : SingletonSerializeAsToken() {
|
||||
init {
|
||||
addProtocolHandler(TOPIC, TOPIC) { req: AbstractStateReplacementProtocol.Handshake ->
|
||||
NotaryChangeProtocol.Acceptor(req.replyToParty)
|
||||
}
|
||||
services.registerProtocolInitiator(NotaryChangeProtocol.Instigator::class) { NotaryChangeProtocol.Acceptor(it) }
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1,16 +1,12 @@
|
||||
package com.r3corda.node.services.api
|
||||
|
||||
import com.google.common.util.concurrent.ListenableFuture
|
||||
import com.r3corda.core.messaging.Message
|
||||
import com.r3corda.core.messaging.MessageHandlerRegistration
|
||||
import com.r3corda.core.messaging.createMessage
|
||||
import com.r3corda.core.node.services.DEFAULT_SESSION_ID
|
||||
import com.r3corda.core.protocols.ProtocolLogic
|
||||
import com.r3corda.core.serialization.SingletonSerializeAsToken
|
||||
import com.r3corda.core.serialization.deserialize
|
||||
import com.r3corda.core.serialization.serialize
|
||||
import com.r3corda.core.utilities.loggerFor
|
||||
import com.r3corda.protocols.HandshakeMessage
|
||||
import com.r3corda.protocols.ServiceRequestMessage
|
||||
import javax.annotation.concurrent.ThreadSafe
|
||||
|
||||
@ -20,10 +16,6 @@ import javax.annotation.concurrent.ThreadSafe
|
||||
@ThreadSafe
|
||||
abstract class AbstractNodeService(val services: ServiceHubInternal) : SingletonSerializeAsToken() {
|
||||
|
||||
companion object {
|
||||
val logger = loggerFor<AbstractNodeService>()
|
||||
}
|
||||
|
||||
val net: MessagingServiceInternal get() = services.networkService
|
||||
|
||||
/**
|
||||
@ -68,36 +60,4 @@ abstract class AbstractNodeService(val services: ServiceHubInternal) : Singleton
|
||||
return addMessageHandler(topic, handler, { message: Message, exception: Exception -> throw exception })
|
||||
}
|
||||
|
||||
/**
|
||||
* Register a handler to kick-off a protocol when a [HandshakeMessage] is received by the node. This performs the
|
||||
* necessary steps to enable communication between the two protocols, including calling ProtocolLogic.registerSession.
|
||||
* @param topic the topic on which the handshake is sent from the other party
|
||||
* @param loggerName the logger name to use when starting the protocol
|
||||
* @param protocolFactory a function to create the protocol with the given handshake message
|
||||
* @param onResultFuture provides access to the [ListenableFuture] when the protocol starts
|
||||
*/
|
||||
protected inline fun <reified H : HandshakeMessage, R : Any> addProtocolHandler(
|
||||
topic: String,
|
||||
loggerName: String,
|
||||
crossinline protocolFactory: (H) -> ProtocolLogic<R>,
|
||||
crossinline onResultFuture: ProtocolLogic<R>.(ListenableFuture<R>, H) -> Unit) {
|
||||
net.addMessageHandler(topic, DEFAULT_SESSION_ID, null) { message, reg ->
|
||||
try {
|
||||
val handshake = message.data.deserialize<H>()
|
||||
val protocol = protocolFactory(handshake)
|
||||
protocol.registerSession(handshake)
|
||||
val resultFuture = services.startProtocol(loggerName, protocol)
|
||||
protocol.onResultFuture(resultFuture, handshake)
|
||||
} catch (e: Exception) {
|
||||
logger.error("Unable to process ${H::class.java.name} message", e)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
protected inline fun <reified H : HandshakeMessage, R : Any> addProtocolHandler(
|
||||
topic: String,
|
||||
loggerName: String,
|
||||
crossinline protocolFactory: (H) -> ProtocolLogic<R>) {
|
||||
addProtocolHandler(topic, loggerName, protocolFactory, { future, handshake -> })
|
||||
}
|
||||
}
|
||||
|
@ -1,7 +1,7 @@
|
||||
package com.r3corda.node.services.api
|
||||
|
||||
import com.r3corda.core.crypto.SecureHash
|
||||
import com.r3corda.core.serialization.SerializedBytes
|
||||
import com.r3corda.node.services.statemachine.ProtocolIORequest
|
||||
import com.r3corda.node.services.statemachine.ProtocolStateMachineImpl
|
||||
|
||||
/**
|
||||
@ -30,14 +30,13 @@ interface CheckpointStorage {
|
||||
}
|
||||
|
||||
// This class will be serialised, so everything it points to transitively must also be serialisable (with Kryo).
|
||||
data class Checkpoint(
|
||||
val serialisedFiber: SerializedBytes<ProtocolStateMachineImpl<*>>,
|
||||
val request: ProtocolIORequest?,
|
||||
val receivedPayload: Any?
|
||||
) {
|
||||
// This flag is always false when loaded from storage as it isn't serialised.
|
||||
// It is used to track when the associated fiber has been created, but not necessarily started when
|
||||
// messages for protocols arrive before the system has fully loaded at startup.
|
||||
@Transient
|
||||
var fiberCreated: Boolean = false
|
||||
}
|
||||
class Checkpoint(val serialisedFiber: SerializedBytes<ProtocolStateMachineImpl<*>>) {
|
||||
|
||||
val id: SecureHash get() = serialisedFiber.hash
|
||||
|
||||
override fun equals(other: Any?): Boolean = other === this || other is Checkpoint && other.id == this.id
|
||||
|
||||
override fun hashCode(): Int = id.hashCode()
|
||||
|
||||
override fun toString(): String = "${javaClass.simpleName}(id=$id)"
|
||||
}
|
||||
|
@ -1,14 +1,16 @@
|
||||
package com.r3corda.node.services.api
|
||||
|
||||
import com.google.common.util.concurrent.ListenableFuture
|
||||
import com.r3corda.core.transactions.SignedTransaction
|
||||
import com.r3corda.core.crypto.Party
|
||||
import com.r3corda.core.messaging.MessagingService
|
||||
import com.r3corda.core.node.ServiceHub
|
||||
import com.r3corda.core.node.services.TxWritableStorageService
|
||||
import com.r3corda.core.protocols.ProtocolLogic
|
||||
import com.r3corda.core.protocols.ProtocolLogicRefFactory
|
||||
import com.r3corda.core.transactions.SignedTransaction
|
||||
import com.r3corda.node.services.statemachine.ProtocolStateMachineImpl
|
||||
import org.slf4j.LoggerFactory
|
||||
import kotlin.reflect.KClass
|
||||
|
||||
interface MessagingServiceInternal : MessagingService {
|
||||
/**
|
||||
@ -49,7 +51,7 @@ abstract class ServiceHubInternal : ServiceHub {
|
||||
* @param txs The transactions to record.
|
||||
*/
|
||||
internal fun recordTransactionsInternal(writableStorageService: TxWritableStorageService, txs: Iterable<SignedTransaction>) {
|
||||
val stateMachineRunId = ProtocolStateMachineImpl.retrieveCurrentStateMachine()?.id
|
||||
val stateMachineRunId = ProtocolStateMachineImpl.currentStateMachine()?.id
|
||||
if (stateMachineRunId != null) {
|
||||
txs.forEach {
|
||||
storageService.stateMachineRecordedTransactionMapping.addMapping(stateMachineRunId, it.id)
|
||||
@ -68,6 +70,23 @@ abstract class ServiceHubInternal : ServiceHub {
|
||||
*/
|
||||
abstract fun <T> startProtocol(loggerName: String, logic: ProtocolLogic<T>): ListenableFuture<T>
|
||||
|
||||
/**
|
||||
* Register the protocol factory we wish to use when a initiating party attempts to communicate with us. The
|
||||
* registration is done against a marker [KClass] which is sent in the session handsake by the other party. If this
|
||||
* marker class has been registered then the corresponding factory will be used to create the protocol which will
|
||||
* communicate with the other side. If there is no mapping then the session attempt is rejected.
|
||||
* @param markerClass The marker [KClass] present in a session initiation attempt, which is a 1:1 mapping to a [Class]
|
||||
* using the <pre>::class</pre> construct. Any marker class can be used, with the default being the class of the initiating
|
||||
* protocol. This enables the registration to be of the form: registerProtocolInitiator(InitiatorProtocol::class, ::InitiatedProtocol)
|
||||
* @param protocolFactory The protocol factory generating the initiated protocol.
|
||||
*/
|
||||
abstract fun registerProtocolInitiator(markerClass: KClass<*>, protocolFactory: (Party) -> ProtocolLogic<*>)
|
||||
|
||||
/**
|
||||
* Return the protocol factory that has been registered with [markerClass], or null if no factory is found.
|
||||
*/
|
||||
abstract fun getProtocolFactory(markerClass: Class<*>): ((Party) -> ProtocolLogic<*>)?
|
||||
|
||||
override fun <T : Any> invokeProtocolAsync(logicType: Class<out ProtocolLogic<T>>, vararg args: Any?): ListenableFuture<T> {
|
||||
val logicRef = protocolLogicRefFactory.create(logicType, *args)
|
||||
@Suppress("UNCHECKED_CAST")
|
||||
|
@ -1,28 +1,25 @@
|
||||
package com.r3corda.node.services.clientapi
|
||||
|
||||
import com.r3corda.core.node.CordaPluginRegistry
|
||||
import com.r3corda.node.services.api.AbstractNodeService
|
||||
import com.r3corda.core.serialization.SingletonSerializeAsToken
|
||||
import com.r3corda.node.services.api.ServiceHubInternal
|
||||
import com.r3corda.protocols.TwoPartyDealProtocol
|
||||
import com.r3corda.protocols.TwoPartyDealProtocol.FIX_INITIATE_TOPIC
|
||||
import com.r3corda.protocols.TwoPartyDealProtocol.FixingSessionInitiation
|
||||
import com.r3corda.protocols.TwoPartyDealProtocol.Fixer
|
||||
import com.r3corda.protocols.TwoPartyDealProtocol.Floater
|
||||
|
||||
/**
|
||||
* This is a temporary handler required for establishing random sessionIDs for the [Fixer] and [Floater] as part of
|
||||
* running scheduled fixings for the [InterestRateSwap] contract.
|
||||
*
|
||||
* TODO: This will be replaced with the automatic sessionID / session setup work.
|
||||
* TODO: This will be replaced with the symmetric session work
|
||||
*/
|
||||
object FixingSessionInitiation {
|
||||
class Plugin: CordaPluginRegistry() {
|
||||
override val servicePlugins: List<Class<*>> = listOf(Service::class.java)
|
||||
}
|
||||
|
||||
class Service(services: ServiceHubInternal) : AbstractNodeService(services) {
|
||||
class Service(services: ServiceHubInternal) : SingletonSerializeAsToken() {
|
||||
init {
|
||||
addProtocolHandler(FIX_INITIATE_TOPIC, "fixings") { initiation: FixingSessionInitiation ->
|
||||
TwoPartyDealProtocol.Fixer(initiation.replyToParty, initiation.oracleType)
|
||||
}
|
||||
services.registerProtocolInitiator(Floater::class) { Fixer(it) }
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -169,7 +169,7 @@ class NodeMonitorService(services: ServiceHubInternal, val smm: StateMachineMana
|
||||
val tx = builder.toSignedTransaction(checkSufficientSignatures = false)
|
||||
val protocol = FinalityProtocol(tx, setOf(req), setOf(req.recipient))
|
||||
return TransactionBuildResult.ProtocolStarted(
|
||||
smm.add(BroadcastTransactionProtocol.TOPIC, protocol).id,
|
||||
smm.add("broadcast", protocol).id,
|
||||
tx,
|
||||
"Cash payment transaction generated"
|
||||
)
|
||||
@ -203,7 +203,7 @@ class NodeMonitorService(services: ServiceHubInternal, val smm: StateMachineMana
|
||||
val tx = builder.toSignedTransaction(checkSufficientSignatures = false)
|
||||
val protocol = FinalityProtocol(tx, setOf(req), participants)
|
||||
return TransactionBuildResult.ProtocolStarted(
|
||||
smm.add(BroadcastTransactionProtocol.TOPIC, protocol).id,
|
||||
smm.add("broadcast", protocol).id,
|
||||
tx,
|
||||
"Cash destruction transaction generated"
|
||||
)
|
||||
@ -222,7 +222,7 @@ class NodeMonitorService(services: ServiceHubInternal, val smm: StateMachineMana
|
||||
// Issuance transactions do not need to be notarised, so we can skip directly to broadcasting it
|
||||
val protocol = BroadcastTransactionProtocol(tx, setOf(req), setOf(req.recipient))
|
||||
return TransactionBuildResult.ProtocolStarted(
|
||||
smm.add(BroadcastTransactionProtocol.TOPIC, protocol).id,
|
||||
smm.add("broadcast", protocol).id,
|
||||
tx,
|
||||
"Cash issuance completed"
|
||||
)
|
||||
|
@ -1,17 +1,12 @@
|
||||
package com.r3corda.node.services.persistence
|
||||
|
||||
import co.paralleluniverse.fibers.Suspendable
|
||||
import com.r3corda.core.crypto.Party
|
||||
import com.r3corda.core.failure
|
||||
import com.r3corda.core.messaging.MessagingService
|
||||
import com.r3corda.core.messaging.TopicSession
|
||||
import com.r3corda.core.node.CordaPluginRegistry
|
||||
import com.r3corda.core.node.NodeInfo
|
||||
import com.r3corda.core.node.recordTransactions
|
||||
import com.r3corda.core.serialization.serialize
|
||||
import com.r3corda.core.success
|
||||
import com.r3corda.core.transactions.SignedTransaction
|
||||
import com.r3corda.core.protocols.ProtocolLogic
|
||||
import com.r3corda.core.serialization.SingletonSerializeAsToken
|
||||
import com.r3corda.core.utilities.loggerFor
|
||||
import com.r3corda.node.services.api.AbstractNodeService
|
||||
import com.r3corda.node.services.api.ServiceHubInternal
|
||||
import com.r3corda.protocols.*
|
||||
import java.io.InputStream
|
||||
@ -39,78 +34,73 @@ object DataVending {
|
||||
// TODO: I don't like that this needs ServiceHubInternal, but passing in a state machine breaks MockServices because
|
||||
// the state machine isn't set when this is constructed. [NodeSchedulerService] has the same problem, and both
|
||||
// should be fixed at the same time.
|
||||
class Service(services: ServiceHubInternal) : AbstractNodeService(services) {
|
||||
class Service(services: ServiceHubInternal) : SingletonSerializeAsToken() {
|
||||
|
||||
companion object {
|
||||
val logger = loggerFor<DataVending.Service>()
|
||||
|
||||
/**
|
||||
* Notify a node of a transaction. Normally any notarisation required would happen before this is called.
|
||||
*/
|
||||
fun notify(net: MessagingService,
|
||||
myIdentity: Party,
|
||||
recipient: NodeInfo,
|
||||
transaction: SignedTransaction) {
|
||||
val msg = BroadcastTransactionProtocol.NotifyTxRequestMessage(transaction, emptySet(), myIdentity)
|
||||
net.send(net.createMessage(TopicSession(BroadcastTransactionProtocol.TOPIC, 0), msg.serialize().bits), recipient.address)
|
||||
}
|
||||
}
|
||||
|
||||
val storage = services.storageService
|
||||
|
||||
class TransactionRejectedError(msg: String) : Exception(msg)
|
||||
|
||||
init {
|
||||
addMessageHandler(FetchTransactionsProtocol.TOPIC,
|
||||
{ req: FetchDataProtocol.Request -> handleTXRequest(req) },
|
||||
{ message, e -> logger.error("Failure processing data vending request.", e) }
|
||||
)
|
||||
|
||||
addMessageHandler(FetchAttachmentsProtocol.TOPIC,
|
||||
{ req: FetchDataProtocol.Request -> handleAttachmentRequest(req) },
|
||||
{ message, e -> logger.error("Failure processing data vending request.", e) }
|
||||
)
|
||||
|
||||
// TODO: We should have a whitelist of contracts we're willing to accept at all, and reject if the transaction
|
||||
// includes us in any outside that list. Potentially just if it includes any outside that list at all.
|
||||
// TODO: Do we want to be able to reject specific transactions on more complex rules, for example reject incoming
|
||||
// cash without from unknown parties?
|
||||
addProtocolHandler(
|
||||
BroadcastTransactionProtocol.TOPIC,
|
||||
"Resolving transactions",
|
||||
{ req: BroadcastTransactionProtocol.NotifyTxRequestMessage ->
|
||||
ResolveTransactionsProtocol(req.tx, req.replyToParty)
|
||||
},
|
||||
{ future, req ->
|
||||
future.success {
|
||||
serviceHub.recordTransactions(req.tx)
|
||||
}.failure { throwable ->
|
||||
logger.warn("Received invalid transaction ${req.tx.id} from ${req.replyToParty}", throwable)
|
||||
}
|
||||
})
|
||||
services.registerProtocolInitiator(FetchTransactionsProtocol::class, ::FetchTransactionsHandler)
|
||||
services.registerProtocolInitiator(FetchAttachmentsProtocol::class, ::FetchAttachmentsHandler)
|
||||
services.registerProtocolInitiator(BroadcastTransactionProtocol::class, ::NotifyTransactionHandler)
|
||||
}
|
||||
|
||||
private fun handleTXRequest(req: FetchDataProtocol.Request): List<SignedTransaction?> {
|
||||
require(req.hashes.isNotEmpty())
|
||||
return req.hashes.map {
|
||||
val tx = storage.validatedTransactions.getTransaction(it)
|
||||
if (tx == null)
|
||||
logger.info("Got request for unknown tx $it")
|
||||
tx
|
||||
|
||||
private class FetchTransactionsHandler(val otherParty: Party) : ProtocolLogic<Unit>() {
|
||||
@Suspendable
|
||||
override fun call() {
|
||||
val request = receive<FetchDataProtocol.Request>(otherParty).unwrap {
|
||||
require(it.hashes.isNotEmpty())
|
||||
it
|
||||
}
|
||||
val txs = request.hashes.map {
|
||||
val tx = serviceHub.storageService.validatedTransactions.getTransaction(it)
|
||||
if (tx == null)
|
||||
logger.info("Got request for unknown tx $it")
|
||||
tx
|
||||
}
|
||||
send(otherParty, txs)
|
||||
}
|
||||
}
|
||||
|
||||
private fun handleAttachmentRequest(req: FetchDataProtocol.Request): List<ByteArray?> {
|
||||
// TODO: Use Artemis message streaming support here, called "large messages". This avoids the need to buffer.
|
||||
require(req.hashes.isNotEmpty())
|
||||
return req.hashes.map {
|
||||
val jar: InputStream? = storage.attachments.openAttachment(it)?.open()
|
||||
if (jar == null) {
|
||||
logger.info("Got request for unknown attachment $it")
|
||||
null
|
||||
} else {
|
||||
jar.readBytes()
|
||||
|
||||
// TODO: Use Artemis message streaming support here, called "large messages". This avoids the need to buffer.
|
||||
private class FetchAttachmentsHandler(val otherParty: Party) : ProtocolLogic<Unit>() {
|
||||
@Suspendable
|
||||
override fun call() {
|
||||
val request = receive<FetchDataProtocol.Request>(otherParty).unwrap {
|
||||
require(it.hashes.isNotEmpty())
|
||||
it
|
||||
}
|
||||
val attachments = request.hashes.map {
|
||||
val jar: InputStream? = serviceHub.storageService.attachments.openAttachment(it)?.open()
|
||||
if (jar == null) {
|
||||
logger.info("Got request for unknown attachment $it")
|
||||
null
|
||||
} else {
|
||||
jar.readBytes()
|
||||
}
|
||||
}
|
||||
send(otherParty, attachments)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// TODO: We should have a whitelist of contracts we're willing to accept at all, and reject if the transaction
|
||||
// includes us in any outside that list. Potentially just if it includes any outside that list at all.
|
||||
// TODO: Do we want to be able to reject specific transactions on more complex rules, for example reject incoming
|
||||
// cash without from unknown parties?
|
||||
class NotifyTransactionHandler(val otherParty: Party) : ProtocolLogic<Unit>() {
|
||||
@Suspendable
|
||||
override fun call() {
|
||||
val request = receive<BroadcastTransactionProtocol.NotifyTxRequest>(otherParty).unwrap { it }
|
||||
subProtocol(ResolveTransactionsProtocol(request.tx, otherParty), shareParentSessions = true)
|
||||
serviceHub.recordTransactions(request.tx)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
@ -1,53 +1,38 @@
|
||||
package com.r3corda.node.services.statemachine
|
||||
|
||||
import com.r3corda.core.crypto.Party
|
||||
import com.r3corda.core.messaging.TopicSession
|
||||
import java.util.*
|
||||
import com.r3corda.node.services.statemachine.StateMachineManager.ProtocolSession
|
||||
import com.r3corda.node.services.statemachine.StateMachineManager.SessionMessage
|
||||
|
||||
// TODO revisit when Kotlin 1.1 is released and data classes can extend other classes
|
||||
interface ProtocolIORequest {
|
||||
// This is used to identify where we suspended, in case of message mismatch errors and other things where we
|
||||
// don't have the original stack trace because it's in a suspended fiber.
|
||||
val stackTraceInCaseOfProblems: StackSnapshot
|
||||
val topic: String
|
||||
val session: ProtocolSession
|
||||
}
|
||||
|
||||
interface SendRequest : ProtocolIORequest {
|
||||
val destination: Party
|
||||
val payload: Any
|
||||
val sendSessionID: Long
|
||||
val uniqueMessageId: UUID
|
||||
val message: SessionMessage
|
||||
}
|
||||
|
||||
interface ReceiveRequest<T> : ProtocolIORequest {
|
||||
interface ReceiveRequest<T : SessionMessage> : ProtocolIORequest {
|
||||
val receiveType: Class<T>
|
||||
val receiveSessionID: Long
|
||||
val receiveTopicSession: TopicSession get() = TopicSession(topic, receiveSessionID)
|
||||
}
|
||||
|
||||
data class SendAndReceive<T>(override val topic: String,
|
||||
override val destination: Party,
|
||||
override val payload: Any,
|
||||
override val sendSessionID: Long,
|
||||
override val uniqueMessageId: UUID,
|
||||
override val receiveType: Class<T>,
|
||||
override val receiveSessionID: Long) : SendRequest, ReceiveRequest<T> {
|
||||
data class SendAndReceive<T : SessionMessage>(override val session: ProtocolSession,
|
||||
override val message: SessionMessage,
|
||||
override val receiveType: Class<T>) : SendRequest, ReceiveRequest<T> {
|
||||
@Transient
|
||||
override val stackTraceInCaseOfProblems: StackSnapshot = StackSnapshot()
|
||||
}
|
||||
|
||||
data class ReceiveOnly<T>(override val topic: String,
|
||||
override val receiveType: Class<T>,
|
||||
override val receiveSessionID: Long) : ReceiveRequest<T> {
|
||||
data class ReceiveOnly<T : SessionMessage>(override val session: ProtocolSession,
|
||||
override val receiveType: Class<T>) : ReceiveRequest<T> {
|
||||
@Transient
|
||||
override val stackTraceInCaseOfProblems: StackSnapshot = StackSnapshot()
|
||||
}
|
||||
|
||||
data class SendOnly(override val destination: Party,
|
||||
override val topic: String,
|
||||
override val payload: Any,
|
||||
override val sendSessionID: Long,
|
||||
override val uniqueMessageId: UUID) : SendRequest {
|
||||
data class SendOnly(override val session: ProtocolSession, override val message: SessionMessage) : SendRequest {
|
||||
@Transient
|
||||
override val stackTraceInCaseOfProblems: StackSnapshot = StackSnapshot()
|
||||
}
|
||||
|
@ -8,16 +8,22 @@ import com.google.common.util.concurrent.ListenableFuture
|
||||
import com.google.common.util.concurrent.SettableFuture
|
||||
import com.r3corda.core.crypto.Party
|
||||
import com.r3corda.core.protocols.ProtocolLogic
|
||||
import com.r3corda.core.protocols.ProtocolSessionException
|
||||
import com.r3corda.core.protocols.ProtocolStateMachine
|
||||
import com.r3corda.core.protocols.StateMachineRunId
|
||||
import com.r3corda.core.random63BitValue
|
||||
import com.r3corda.core.rootCause
|
||||
import com.r3corda.core.utilities.UntrustworthyData
|
||||
import com.r3corda.core.utilities.trace
|
||||
import com.r3corda.node.services.api.ServiceHubInternal
|
||||
import com.r3corda.node.services.statemachine.StateMachineManager.*
|
||||
import com.r3corda.node.utilities.createDatabaseTransaction
|
||||
import org.jetbrains.exposed.sql.Database
|
||||
import org.jetbrains.exposed.sql.transactions.TransactionManager
|
||||
import org.slf4j.Logger
|
||||
import org.slf4j.LoggerFactory
|
||||
import java.io.PrintWriter
|
||||
import java.io.StringWriter
|
||||
import java.sql.SQLException
|
||||
import java.util.*
|
||||
import java.util.concurrent.ExecutionException
|
||||
@ -36,12 +42,26 @@ class ProtocolStateMachineImpl<R>(override val id: StateMachineRunId,
|
||||
private val loggerName: String)
|
||||
: Fiber<R>("protocol", scheduler), ProtocolStateMachine<R> {
|
||||
|
||||
companion object {
|
||||
// Used to work around a small limitation in Quasar.
|
||||
private val QUASAR_UNBLOCKER = run {
|
||||
val field = Fiber::class.java.getDeclaredField("SERIALIZER_BLOCKER")
|
||||
field.isAccessible = true
|
||||
field.get(null)
|
||||
}
|
||||
|
||||
/**
|
||||
* Return the current [ProtocolStateMachineImpl] or null if executing outside of one.
|
||||
*/
|
||||
fun currentStateMachine(): ProtocolStateMachineImpl<*>? = Strand.currentStrand() as? ProtocolStateMachineImpl<*>
|
||||
}
|
||||
|
||||
// These fields shouldn't be serialised, so they are marked @Transient.
|
||||
@Transient lateinit override var serviceHub: ServiceHubInternal
|
||||
@Transient internal lateinit var suspendAction: (ProtocolIORequest) -> Unit
|
||||
@Transient internal lateinit var actionOnSuspend: (ProtocolIORequest) -> Unit
|
||||
@Transient internal lateinit var actionOnEnd: () -> Unit
|
||||
@Transient internal var receivedPayload: Any? = null
|
||||
@Transient internal lateinit var database: Database
|
||||
@Transient internal var fromCheckpoint: Boolean = false
|
||||
|
||||
@Transient private var _logger: Logger? = null
|
||||
override val logger: Logger get() {
|
||||
@ -62,18 +82,20 @@ class ProtocolStateMachineImpl<R>(override val id: StateMachineRunId,
|
||||
}
|
||||
}
|
||||
|
||||
internal val openSessions = HashMap<Pair<ProtocolLogic<*>, Party>, ProtocolSession>()
|
||||
|
||||
init {
|
||||
logic.psm = this
|
||||
name = id.toString()
|
||||
}
|
||||
|
||||
@Suspendable @Suppress("UNCHECKED_CAST")
|
||||
@Suspendable
|
||||
override fun run(): R {
|
||||
createTransaction()
|
||||
val result = try {
|
||||
logic.call()
|
||||
} catch (t: Throwable) {
|
||||
actionOnEnd()
|
||||
_resultFuture?.setException(t)
|
||||
processException(t)
|
||||
commitTransaction()
|
||||
throw ExecutionException(t)
|
||||
}
|
||||
@ -106,56 +128,140 @@ class ProtocolStateMachineImpl<R>(override val id: StateMachineRunId,
|
||||
}
|
||||
|
||||
@Suspendable
|
||||
private fun <T : Any> suspendAndExpectReceive(receiveRequest: ReceiveRequest<T>): UntrustworthyData<T> {
|
||||
suspend(receiveRequest)
|
||||
check(receivedPayload != null) { "Expected to receive something" }
|
||||
val untrustworthy = UntrustworthyData(receiveRequest.receiveType.cast(receivedPayload))
|
||||
receivedPayload = null
|
||||
return untrustworthy
|
||||
}
|
||||
|
||||
@Suspendable
|
||||
override fun <T : Any> sendAndReceive(topic: String,
|
||||
destination: Party,
|
||||
sessionIDForSend: Long,
|
||||
sessionIDForReceive: Long,
|
||||
override fun <T : Any> sendAndReceive(otherParty: Party,
|
||||
payload: Any,
|
||||
receiveType: Class<T>): UntrustworthyData<T> {
|
||||
return suspendAndExpectReceive(SendAndReceive(topic, destination, payload, sessionIDForSend, UUID.randomUUID(), receiveType, sessionIDForReceive))
|
||||
receiveType: Class<T>,
|
||||
sessionProtocol: ProtocolLogic<*>): UntrustworthyData<T> {
|
||||
val session = getSession(otherParty, sessionProtocol)
|
||||
val sendSessionData = createSessionData(session, payload)
|
||||
val receivedSessionData = sendAndReceiveInternal(session, sendSessionData, SessionData::class.java)
|
||||
return UntrustworthyData(receiveType.cast(receivedSessionData.payload))
|
||||
}
|
||||
|
||||
@Suspendable
|
||||
override fun <T : Any> receive(topic: String, sessionIDForReceive: Long, receiveType: Class<T>): UntrustworthyData<T> {
|
||||
return suspendAndExpectReceive(ReceiveOnly(topic, receiveType, sessionIDForReceive))
|
||||
override fun <T : Any> receive(otherParty: Party,
|
||||
receiveType: Class<T>,
|
||||
sessionProtocol: ProtocolLogic<*>): UntrustworthyData<T> {
|
||||
val receivedSessionData = receiveInternal(getSession(otherParty, sessionProtocol), SessionData::class.java)
|
||||
return UntrustworthyData(receiveType.cast(receivedSessionData.payload))
|
||||
}
|
||||
|
||||
@Suspendable
|
||||
override fun send(topic: String, destination: Party, sessionID: Long, payload: Any) {
|
||||
suspend(SendOnly(destination, topic, payload, sessionID, UUID.randomUUID()))
|
||||
override fun send(otherParty: Party, payload: Any, sessionProtocol: ProtocolLogic<*>) {
|
||||
val session = getSession(otherParty, sessionProtocol)
|
||||
val sendSessionData = createSessionData(session, payload)
|
||||
sendInternal(session, sendSessionData)
|
||||
}
|
||||
|
||||
private fun createSessionData(session: ProtocolSession, payload: Any): SessionData {
|
||||
val otherPartySessionId = session.otherPartySessionId
|
||||
?: throw IllegalStateException("We've somehow held onto an unconfirmed session: $session")
|
||||
return SessionData(otherPartySessionId, payload)
|
||||
}
|
||||
|
||||
@Suspendable
|
||||
private fun suspend(protocolIORequest: ProtocolIORequest) {
|
||||
private fun sendInternal(session: ProtocolSession, message: SessionMessage) {
|
||||
suspend(SendOnly(session, message))
|
||||
}
|
||||
|
||||
@Suspendable
|
||||
private fun <T : SessionMessage> receiveInternal(session: ProtocolSession, receiveType: Class<T>): T {
|
||||
return suspendAndExpectReceive(ReceiveOnly(session, receiveType))
|
||||
}
|
||||
|
||||
@Suspendable
|
||||
private fun <T : SessionMessage> sendAndReceiveInternal(session: ProtocolSession, message: SessionMessage, receiveType: Class<T>): T {
|
||||
return suspendAndExpectReceive(SendAndReceive(session, message, receiveType))
|
||||
}
|
||||
|
||||
@Suspendable
|
||||
private fun getSession(otherParty: Party, sessionProtocol: ProtocolLogic<*>): ProtocolSession {
|
||||
return openSessions[Pair(sessionProtocol, otherParty)] ?: startNewSession(otherParty, sessionProtocol)
|
||||
}
|
||||
|
||||
@Suspendable
|
||||
private fun startNewSession(otherParty: Party, sessionProtocol: ProtocolLogic<*>) : ProtocolSession {
|
||||
val session = ProtocolSession(sessionProtocol, otherParty, random63BitValue(), null)
|
||||
openSessions[Pair(sessionProtocol, otherParty)] = session
|
||||
val counterpartyProtocol = sessionProtocol.getCounterpartyMarker(otherParty).name
|
||||
val sessionInit = SessionInit(session.ourSessionId, serviceHub.storageService.myLegalIdentity, counterpartyProtocol)
|
||||
val sessionInitResponse = sendAndReceiveInternal(session, sessionInit, SessionInitResponse::class.java)
|
||||
if (sessionInitResponse is SessionConfirm) {
|
||||
session.otherPartySessionId = sessionInitResponse.initiatedSessionId
|
||||
return session
|
||||
} else {
|
||||
sessionInitResponse as SessionReject
|
||||
throw ProtocolSessionException("Party $otherParty rejected session attempt: ${sessionInitResponse.errorMessage}")
|
||||
}
|
||||
}
|
||||
|
||||
@Suspendable
|
||||
private fun <T : SessionMessage> suspendAndExpectReceive(receiveRequest: ReceiveRequest<T>): T {
|
||||
fun getReceivedMessage(): ExistingSessionMessage? = receiveRequest.session.receivedMessages.poll()
|
||||
|
||||
val receivedMessage = getReceivedMessage() ?: run {
|
||||
// Suspend while we wait for the receive
|
||||
receiveRequest.session.waitingForResponse = true
|
||||
suspend(receiveRequest)
|
||||
receiveRequest.session.waitingForResponse = false
|
||||
getReceivedMessage()
|
||||
?: throw IllegalStateException("Was expecting a ${receiveRequest.receiveType.simpleName} but got nothing: $id $receiveRequest")
|
||||
}
|
||||
|
||||
if (receivedMessage is SessionEnd) {
|
||||
openSessions.values.remove(receiveRequest.session)
|
||||
throw ProtocolSessionException("Counterparty on ${receiveRequest.session.otherParty} has prematurly ended")
|
||||
} else if (receiveRequest.receiveType.isInstance(receivedMessage)) {
|
||||
return receiveRequest.receiveType.cast(receivedMessage)
|
||||
} else {
|
||||
throw IllegalStateException("Was expecting a ${receiveRequest.receiveType.simpleName} but got $receivedMessage: $id $receiveRequest")
|
||||
}
|
||||
}
|
||||
|
||||
@Suspendable
|
||||
private fun suspend(ioRequest: ProtocolIORequest) {
|
||||
commitTransaction()
|
||||
parkAndSerialize { fiber, serializer ->
|
||||
logger.trace { "Suspended $id on $ioRequest" }
|
||||
try {
|
||||
suspendAction(protocolIORequest)
|
||||
actionOnSuspend(ioRequest)
|
||||
} catch (t: Throwable) {
|
||||
// Do not throw exception again - Quasar completely bins it.
|
||||
logger.warn("Captured exception which was swallowed by Quasar", t)
|
||||
actionOnEnd()
|
||||
_resultFuture?.setException(t)
|
||||
// TODO When error handling is introduced, look into whether we should be deleting the checkpoint and
|
||||
// completing the Future
|
||||
processException(t)
|
||||
}
|
||||
}
|
||||
createTransaction()
|
||||
}
|
||||
|
||||
companion object {
|
||||
/**
|
||||
* Retrieves our state machine id if we are running a [ProtocolStateMachineImpl].
|
||||
*/
|
||||
fun retrieveCurrentStateMachine(): ProtocolStateMachineImpl<*>? {
|
||||
return Strand.currentStrand() as? ProtocolStateMachineImpl<*>
|
||||
private fun processException(t: Throwable) {
|
||||
actionOnEnd()
|
||||
_resultFuture?.setException(t)
|
||||
}
|
||||
|
||||
internal fun resume(scheduler: FiberScheduler) {
|
||||
try {
|
||||
if (fromCheckpoint) {
|
||||
logger.info("$id resumed from checkpoint")
|
||||
fromCheckpoint = false
|
||||
Fiber.unparkDeserialized(this, scheduler)
|
||||
} else if (state == State.NEW) {
|
||||
logger.trace { "$id started" }
|
||||
start()
|
||||
} else {
|
||||
logger.trace { "$id resumed" }
|
||||
Fiber.unpark(this, QUASAR_UNBLOCKER)
|
||||
}
|
||||
} catch (t: Throwable) {
|
||||
logger.error("$id threw '${t.rootCause}'")
|
||||
logger.trace {
|
||||
val s = StringWriter()
|
||||
t.rootCause.printStackTrace(PrintWriter(s))
|
||||
"Stack trace of protocol error: $s"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
@ -3,34 +3,38 @@ package com.r3corda.node.services.statemachine
|
||||
import co.paralleluniverse.fibers.Fiber
|
||||
import co.paralleluniverse.fibers.FiberExecutorScheduler
|
||||
import co.paralleluniverse.io.serialization.kryo.KryoSerializer
|
||||
import co.paralleluniverse.strands.Strand
|
||||
import com.codahale.metrics.Gauge
|
||||
import com.esotericsoftware.kryo.Kryo
|
||||
import com.google.common.base.Throwables
|
||||
import com.google.common.util.concurrent.ListenableFuture
|
||||
import com.r3corda.core.ThreadBox
|
||||
import com.r3corda.core.abbreviate
|
||||
import com.r3corda.core.crypto.Party
|
||||
import com.r3corda.core.messaging.TopicSession
|
||||
import com.r3corda.core.messaging.runOnNextMessage
|
||||
import com.r3corda.core.messaging.send
|
||||
import com.r3corda.core.protocols.ProtocolLogic
|
||||
import com.r3corda.core.protocols.ProtocolStateMachine
|
||||
import com.r3corda.core.protocols.StateMachineRunId
|
||||
import com.r3corda.core.random63BitValue
|
||||
import com.r3corda.core.serialization.*
|
||||
import com.r3corda.core.then
|
||||
import com.r3corda.core.utilities.ProgressTracker
|
||||
import com.r3corda.core.utilities.debug
|
||||
import com.r3corda.core.utilities.loggerFor
|
||||
import com.r3corda.core.utilities.trace
|
||||
import com.r3corda.node.services.api.Checkpoint
|
||||
import com.r3corda.node.services.api.CheckpointStorage
|
||||
import com.r3corda.node.services.api.ServiceHubInternal
|
||||
import com.r3corda.node.utilities.AddOrRemove
|
||||
import com.r3corda.node.utilities.AffinityExecutor
|
||||
import kotlinx.support.jdk8.collections.removeIf
|
||||
import org.jetbrains.exposed.sql.Database
|
||||
import rx.Observable
|
||||
import rx.subjects.PublishSubject
|
||||
import rx.subjects.UnicastSubject
|
||||
import java.io.PrintWriter
|
||||
import java.io.StringWriter
|
||||
import java.util.*
|
||||
import java.util.concurrent.ConcurrentHashMap
|
||||
import java.util.concurrent.ConcurrentLinkedQueue
|
||||
import java.util.concurrent.ExecutionException
|
||||
import javax.annotation.concurrent.ThreadSafe
|
||||
|
||||
@ -48,7 +52,6 @@ import javax.annotation.concurrent.ThreadSafe
|
||||
* The SMM will always invoke the protocol fibers on the given [AffinityExecutor], regardless of which thread actually
|
||||
* starts them via [add].
|
||||
*
|
||||
* TODO: Session IDs should be set up and propagated automatically, on demand.
|
||||
* TODO: Consider the issue of continuation identity more deeply: is it a safe assumption that a serialised
|
||||
* continuation is always unique?
|
||||
* TODO: Think about how to bring the system to a clean stop so it can be upgraded without any serialised stacks on disk
|
||||
@ -58,12 +61,19 @@ import javax.annotation.concurrent.ThreadSafe
|
||||
* TODO: Implement stub/skel classes that provide a basic RPC framework on top of this.
|
||||
*/
|
||||
@ThreadSafe
|
||||
class StateMachineManager(val serviceHub: ServiceHubInternal, tokenizableServices: List<Any>,
|
||||
class StateMachineManager(val serviceHub: ServiceHubInternal,
|
||||
tokenizableServices: List<Any>,
|
||||
val checkpointStorage: CheckpointStorage,
|
||||
val executor: AffinityExecutor,
|
||||
val database: Database) {
|
||||
|
||||
inner class FiberScheduler : FiberExecutorScheduler("Same thread scheduler", executor)
|
||||
|
||||
companion object {
|
||||
private val logger = loggerFor<StateMachineManager>()
|
||||
internal val sessionTopic = TopicSession("platform.session")
|
||||
}
|
||||
|
||||
val scheduler = FiberScheduler()
|
||||
|
||||
data class Change(
|
||||
@ -95,6 +105,9 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, tokenizableService
|
||||
private val totalStartedProtocols = metrics.counter("Protocols.Started")
|
||||
private val totalFinishedProtocols = metrics.counter("Protocols.Finished")
|
||||
|
||||
private val openSessions = ConcurrentHashMap<Long, ProtocolSession>()
|
||||
private val recentlyClosedSessions = ConcurrentHashMap<Long, Party>()
|
||||
|
||||
// Context for tokenized services in checkpoints
|
||||
private val serializationContext = SerializeAsTokenContext(tokenizableServices, quasarKryo())
|
||||
|
||||
@ -119,6 +132,17 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, tokenizableService
|
||||
val changes: Observable<Change>
|
||||
get() = mutex.content.changesPublisher
|
||||
|
||||
init {
|
||||
Fiber.setDefaultUncaughtExceptionHandler { fiber, throwable ->
|
||||
(fiber as ProtocolStateMachineImpl<*>).logger.error("Caught exception from protocol", throwable)
|
||||
}
|
||||
}
|
||||
|
||||
fun start() {
|
||||
restoreFibersFromCheckpoints()
|
||||
serviceHub.networkMapCache.mapServiceRegistered.then(executor) { resumeRestoredFibers() }
|
||||
}
|
||||
|
||||
/**
|
||||
* Atomic get snapshot + subscribe. This is needed so we don't miss updates between subscriptions to [changes] and
|
||||
* calls to [allStateMachines]
|
||||
@ -131,69 +155,99 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, tokenizableService
|
||||
}
|
||||
}
|
||||
|
||||
// Used to work around a small limitation in Quasar.
|
||||
private val QUASAR_UNBLOCKER = run {
|
||||
val field = Fiber::class.java.getDeclaredField("SERIALIZER_BLOCKER")
|
||||
field.isAccessible = true
|
||||
field.get(null)
|
||||
}
|
||||
|
||||
init {
|
||||
Fiber.setDefaultUncaughtExceptionHandler { fiber, throwable ->
|
||||
(fiber as ProtocolStateMachineImpl<*>).logger.error("Caught exception from protocol", throwable)
|
||||
}
|
||||
}
|
||||
|
||||
fun start() {
|
||||
checkpointStorage.checkpoints.forEach { createFiberForCheckpoint(it) }
|
||||
serviceHub.networkMapCache.mapServiceRegistered.then(executor) {
|
||||
mutex.locked {
|
||||
started = true
|
||||
stateMachines.forEach { restartFiber(it.key, it.value) }
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private fun createFiberForCheckpoint(checkpoint: Checkpoint) {
|
||||
if (!checkpoint.fiberCreated) {
|
||||
val fiber = deserializeFiber(checkpoint.serialisedFiber)
|
||||
initFiber(fiber, { checkpoint })
|
||||
}
|
||||
}
|
||||
|
||||
private fun restartFiber(fiber: ProtocolStateMachineImpl<*>, checkpoint: Checkpoint) {
|
||||
if (checkpoint.request is ReceiveRequest<*>) {
|
||||
val topicSession = checkpoint.request.receiveTopicSession
|
||||
fiber.logger.info("Restored ${fiber.logic} - it was previously waiting for message of type ${checkpoint.request.receiveType.name} on $topicSession")
|
||||
iterateOnResponse(fiber, checkpoint.serialisedFiber, checkpoint.request) {
|
||||
try {
|
||||
Fiber.unparkDeserialized(fiber, scheduler)
|
||||
} catch (e: Throwable) {
|
||||
logError(e, it, topicSession, fiber)
|
||||
private fun restoreFibersFromCheckpoints() {
|
||||
mutex.locked {
|
||||
checkpointStorage.checkpoints.forEach {
|
||||
// If a protocol is added before start() then don't attempt to restore it
|
||||
if (!stateMachines.containsValue(it)) {
|
||||
val fiber = deserializeFiber(it.serialisedFiber)
|
||||
initFiber(fiber)
|
||||
stateMachines[fiber] = it
|
||||
}
|
||||
}
|
||||
if (checkpoint.request is SendRequest) {
|
||||
sendMessage(fiber, checkpoint.request)
|
||||
}
|
||||
}
|
||||
|
||||
private fun resumeRestoredFibers() {
|
||||
mutex.locked {
|
||||
started = true
|
||||
stateMachines.keys.forEach { resumeRestoredFiber(it) }
|
||||
}
|
||||
serviceHub.networkService.addMessageHandler(sessionTopic, executor) { message, reg ->
|
||||
executor.checkOnThread()
|
||||
val sessionMessage = message.data.deserialize<SessionMessage>()
|
||||
when (sessionMessage) {
|
||||
is ExistingSessionMessage -> onExistingSessionMessage(sessionMessage)
|
||||
is SessionInit -> onSessionInit(sessionMessage)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private fun resumeRestoredFiber(fiber: ProtocolStateMachineImpl<*>) {
|
||||
fiber.openSessions.values.forEach { openSessions[it.ourSessionId] = it }
|
||||
if (fiber.openSessions.values.any { it.waitingForResponse }) {
|
||||
fiber.logger.info("Restored fiber pending on receive ${fiber.id}}")
|
||||
} else {
|
||||
resumeFiber(fiber)
|
||||
}
|
||||
}
|
||||
|
||||
private fun onExistingSessionMessage(message: ExistingSessionMessage) {
|
||||
val session = openSessions[message.recipientSessionId]
|
||||
if (session != null) {
|
||||
session.psm.logger.trace { "${session.psm.id} received $message on $session" }
|
||||
if (message is SessionEnd) {
|
||||
openSessions.remove(message.recipientSessionId)
|
||||
}
|
||||
session.receivedMessages += message
|
||||
if (session.waitingForResponse) {
|
||||
updateCheckpoint(session.psm)
|
||||
resumeFiber(session.psm)
|
||||
}
|
||||
} else {
|
||||
fiber.logger.info("Restored ${fiber.logic} - it was not waiting on any message; received payload: ${checkpoint.receivedPayload.toString().abbreviate(50)}")
|
||||
executor.executeASAP {
|
||||
if (checkpoint.request is SendRequest) {
|
||||
sendMessage(fiber, checkpoint.request)
|
||||
}
|
||||
iterateStateMachine(fiber, checkpoint.receivedPayload) {
|
||||
try {
|
||||
Fiber.unparkDeserialized(fiber, scheduler)
|
||||
} catch (e: Throwable) {
|
||||
logError(e, it, null, fiber)
|
||||
}
|
||||
val otherParty = recentlyClosedSessions.remove(message.recipientSessionId)
|
||||
if (otherParty != null) {
|
||||
if (message is SessionConfirm) {
|
||||
logger.debug { "Received session confirmation but associated fiber has already terminated, so sending session end" }
|
||||
sendSessionMessage(otherParty, SessionEnd(message.initiatedSessionId), null)
|
||||
} else {
|
||||
logger.trace { "Ignoring session end message for already closed session: $message" }
|
||||
}
|
||||
} else {
|
||||
logger.warn("Received a session message for unknown session: $message")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private fun onSessionInit(sessionInit: SessionInit) {
|
||||
logger.trace { "Received $sessionInit" }
|
||||
//TODO Verify the other party are who they say they are from the TLS subsystem
|
||||
val otherParty = sessionInit.initiatorParty
|
||||
val otherPartySessionId = sessionInit.initiatorSessionId
|
||||
try {
|
||||
val markerClass = Class.forName(sessionInit.protocolName)
|
||||
val protocolFactory = serviceHub.getProtocolFactory(markerClass)
|
||||
if (protocolFactory != null) {
|
||||
val protocol = protocolFactory(otherParty)
|
||||
val psm = createFiber(sessionInit.protocolName, protocol)
|
||||
val session = ProtocolSession(protocol, otherParty, random63BitValue(), otherPartySessionId)
|
||||
openSessions[session.ourSessionId] = session
|
||||
psm.openSessions[Pair(protocol, otherParty)] = session
|
||||
updateCheckpoint(psm)
|
||||
sendSessionMessage(otherParty, SessionConfirm(otherPartySessionId, session.ourSessionId), psm)
|
||||
psm.logger.debug { "Starting new ${psm.id} from $sessionInit on $session" }
|
||||
startFiber(psm)
|
||||
} else {
|
||||
logger.warn("Unknown protocol marker class in $sessionInit")
|
||||
sendSessionMessage(otherParty, SessionReject(otherPartySessionId, "Don't know ${markerClass.name}"), null)
|
||||
}
|
||||
} catch (e: Exception) {
|
||||
logger.warn("Received invalid $sessionInit", e)
|
||||
sendSessionMessage(otherParty, SessionReject(otherPartySessionId, "Unable to establish session"), null)
|
||||
}
|
||||
}
|
||||
|
||||
private fun serializeFiber(fiber: ProtocolStateMachineImpl<*>): SerializedBytes<ProtocolStateMachineImpl<*>> {
|
||||
// We don't use the passed-in serializer here, because we need to use our own augmented Kryo.
|
||||
val kryo = quasarKryo()
|
||||
// add the map of tokens -> tokenizedServices to the kyro context
|
||||
SerializeAsTokenSerializer.setContext(kryo, serializationContext)
|
||||
@ -204,7 +258,7 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, tokenizableService
|
||||
val kryo = quasarKryo()
|
||||
// put the map of token -> tokenized into the kryo context
|
||||
SerializeAsTokenSerializer.setContext(kryo, serializationContext)
|
||||
return serialisedFiber.deserialize(kryo)
|
||||
return serialisedFiber.deserialize(kryo).apply { fromCheckpoint = true }
|
||||
}
|
||||
|
||||
private fun quasarKryo(): Kryo {
|
||||
@ -212,70 +266,51 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, tokenizableService
|
||||
return createKryo(serializer.kryo)
|
||||
}
|
||||
|
||||
private fun logError(e: Throwable, payload: Any?, topicSession: TopicSession?, psm: ProtocolStateMachineImpl<*>) {
|
||||
psm.logger.error("Protocol state machine ${psm.javaClass.name} threw '${Throwables.getRootCause(e)}' " +
|
||||
"when handling a message of type ${payload?.javaClass?.name} on queue $topicSession")
|
||||
if (psm.logger.isTraceEnabled) {
|
||||
val s = StringWriter()
|
||||
Throwables.getRootCause(e).printStackTrace(PrintWriter(s))
|
||||
psm.logger.trace("Stack trace of protocol error is: $s")
|
||||
}
|
||||
private fun <T> createFiber(loggerName: String, logic: ProtocolLogic<T>): ProtocolStateMachineImpl<T> {
|
||||
val id = StateMachineRunId.createRandom()
|
||||
return ProtocolStateMachineImpl(id, logic, scheduler, loggerName).apply { initFiber(this) }
|
||||
}
|
||||
|
||||
private fun initFiber(psm: ProtocolStateMachineImpl<*>, startingCheckpoint: () -> Checkpoint): Checkpoint {
|
||||
private fun initFiber(psm: ProtocolStateMachineImpl<*>) {
|
||||
psm.database = database
|
||||
psm.serviceHub = serviceHub
|
||||
psm.suspendAction = { request ->
|
||||
psm.logger.trace { "Suspended fiber ${psm.id} ${psm.logic}" }
|
||||
onNextSuspend(psm, request)
|
||||
psm.actionOnSuspend = { ioRequest ->
|
||||
updateCheckpoint(psm)
|
||||
processIORequest(ioRequest)
|
||||
}
|
||||
psm.actionOnEnd = {
|
||||
psm.logic.progressTracker?.currentStep = ProgressTracker.DONE
|
||||
mutex.locked {
|
||||
val finalCheckpoint = stateMachines.remove(psm)
|
||||
if (finalCheckpoint != null) {
|
||||
checkpointStorage.removeCheckpoint(finalCheckpoint)
|
||||
}
|
||||
stateMachines.remove(psm)?.let { checkpointStorage.removeCheckpoint(it) }
|
||||
totalFinishedProtocols.inc()
|
||||
notifyChangeObservers(psm, AddOrRemove.REMOVE)
|
||||
}
|
||||
endAllFiberSessions(psm)
|
||||
}
|
||||
val checkpoint = startingCheckpoint()
|
||||
checkpoint.fiberCreated = true
|
||||
totalStartedProtocols.inc()
|
||||
mutex.locked {
|
||||
stateMachines[psm] = checkpoint
|
||||
totalStartedProtocols.inc()
|
||||
notifyChangeObservers(psm, AddOrRemove.ADD)
|
||||
}
|
||||
return checkpoint
|
||||
}
|
||||
|
||||
/**
|
||||
* Kicks off a brand new state machine of the given class. It will log with the named logger.
|
||||
* The state machine will be persisted when it suspends, with automated restart if the StateMachineManager is
|
||||
* restarted with checkpointed state machines in the storage service.
|
||||
*/
|
||||
fun <T> add(loggerName: String, logic: ProtocolLogic<T>): ProtocolStateMachine<T> {
|
||||
val id = StateMachineRunId.createRandom()
|
||||
val fiber = ProtocolStateMachineImpl(id, logic, scheduler, loggerName)
|
||||
// Need to add before iterating in case of immediate completion
|
||||
val checkpoint = initFiber(fiber) {
|
||||
val checkpoint = Checkpoint(serializeFiber(fiber), null, null)
|
||||
checkpoint
|
||||
}
|
||||
checkpointStorage.addCheckpoint(checkpoint)
|
||||
mutex.locked { // If we are not started then our checkpoint will be picked up during start
|
||||
if (!started) {
|
||||
return fiber
|
||||
}
|
||||
}
|
||||
|
||||
try {
|
||||
executor.executeASAP {
|
||||
iterateStateMachine(fiber, null) {
|
||||
fiber.start()
|
||||
private fun endAllFiberSessions(psm: ProtocolStateMachineImpl<*>) {
|
||||
openSessions.values.removeIf { session ->
|
||||
if (session.psm == psm) {
|
||||
val otherPartySessionId = session.otherPartySessionId
|
||||
if (otherPartySessionId != null) {
|
||||
sendSessionMessage(session.otherParty, SessionEnd(otherPartySessionId), psm)
|
||||
}
|
||||
recentlyClosedSessions[session.ourSessionId] = session.otherParty
|
||||
true
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private fun startFiber(fiber: ProtocolStateMachineImpl<*>) {
|
||||
try {
|
||||
resumeFiber(fiber)
|
||||
} catch (e: ExecutionException) {
|
||||
// There are two ways we can take exceptions in this method:
|
||||
//
|
||||
@ -290,17 +325,29 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, tokenizableService
|
||||
if (e.cause !is ExecutionException)
|
||||
throw e
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Kicks off a brand new state machine of the given class. It will log with the named logger.
|
||||
* The state machine will be persisted when it suspends, with automated restart if the StateMachineManager is
|
||||
* restarted with checkpointed state machines in the storage service.
|
||||
*/
|
||||
fun <T> add(loggerName: String, logic: ProtocolLogic<T>): ProtocolStateMachine<T> {
|
||||
val fiber = createFiber(loggerName, logic)
|
||||
updateCheckpoint(fiber)
|
||||
// If we are not started then our checkpoint will be picked up during start
|
||||
mutex.locked {
|
||||
if (started) {
|
||||
startFiber(fiber)
|
||||
}
|
||||
}
|
||||
return fiber
|
||||
}
|
||||
|
||||
private fun updateCheckpoint(psm: ProtocolStateMachineImpl<*>,
|
||||
serialisedFiber: SerializedBytes<ProtocolStateMachineImpl<*>>,
|
||||
request: ProtocolIORequest?,
|
||||
receivedPayload: Any?) {
|
||||
val newCheckpoint = Checkpoint(serialisedFiber, request, receivedPayload)
|
||||
val previousCheckpoint = mutex.locked {
|
||||
stateMachines.put(psm, newCheckpoint)
|
||||
}
|
||||
private fun updateCheckpoint(psm: ProtocolStateMachineImpl<*>) {
|
||||
check(psm.state != Strand.State.RUNNING) { "Fiber cannot be running when checkpointing" }
|
||||
val newCheckpoint = Checkpoint(serializeFiber(psm))
|
||||
val previousCheckpoint = mutex.locked { stateMachines.put(psm, newCheckpoint) }
|
||||
if (previousCheckpoint != null) {
|
||||
checkpointStorage.removeCheckpoint(previousCheckpoint)
|
||||
}
|
||||
@ -308,90 +355,70 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, tokenizableService
|
||||
checkpointingMeter.mark()
|
||||
}
|
||||
|
||||
private fun iterateStateMachine(psm: ProtocolStateMachineImpl<*>,
|
||||
receivedPayload: Any?,
|
||||
resumeAction: (Any?) -> Unit) {
|
||||
executor.checkOnThread()
|
||||
psm.receivedPayload = receivedPayload
|
||||
psm.logger.trace { "Waking up fiber ${psm.id} ${psm.logic}" }
|
||||
resumeAction(receivedPayload)
|
||||
}
|
||||
|
||||
private fun onNextSuspend(psm: ProtocolStateMachineImpl<*>, request: ProtocolIORequest) {
|
||||
val serialisedFiber = serializeFiber(psm)
|
||||
updateCheckpoint(psm, serialisedFiber, request, null)
|
||||
// We have a request to do something: send, receive, or send-and-receive.
|
||||
if (request is ReceiveRequest<*>) {
|
||||
// Prepare a listener on the network that runs in the background thread when we receive a message.
|
||||
prepareToReceiveForRequest(psm, serialisedFiber, request)
|
||||
}
|
||||
if (request is SendRequest) {
|
||||
performSendRequest(psm, request)
|
||||
private fun resumeFiber(psm: ProtocolStateMachineImpl<*>) {
|
||||
executor.executeASAP {
|
||||
psm.resume(scheduler)
|
||||
}
|
||||
}
|
||||
|
||||
private fun prepareToReceiveForRequest(psm: ProtocolStateMachineImpl<*>, serialisedFiber: SerializedBytes<ProtocolStateMachineImpl<*>>, request: ReceiveRequest<*>) {
|
||||
executor.checkOnThread()
|
||||
val queueID = request.receiveTopicSession
|
||||
psm.logger.trace { "Preparing to receive message of type ${request.receiveType.name} on queue $queueID" }
|
||||
iterateOnResponse(psm, serialisedFiber, request) {
|
||||
try {
|
||||
Fiber.unpark(psm, QUASAR_UNBLOCKER)
|
||||
} catch(e: Throwable) {
|
||||
logError(e, it, queueID, psm)
|
||||
private fun processIORequest(ioRequest: ProtocolIORequest) {
|
||||
if (ioRequest is SendRequest) {
|
||||
if (ioRequest.message is SessionInit) {
|
||||
openSessions[ioRequest.session.ourSessionId] = ioRequest.session
|
||||
}
|
||||
sendSessionMessage(ioRequest.session.otherParty, ioRequest.message, ioRequest.session.psm)
|
||||
if (ioRequest !is ReceiveRequest<*>) {
|
||||
// We sent a message, but don't expect a response, so re-enter the continuation to let it keep going.
|
||||
resumeFiber(ioRequest.session.psm)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private fun performSendRequest(psm: ProtocolStateMachineImpl<*>, request: SendRequest) {
|
||||
val topicSession = sendMessage(psm, request)
|
||||
private fun sendSessionMessage(party: Party, message: SessionMessage, psm: ProtocolStateMachineImpl<*>?) {
|
||||
val node = serviceHub.networkMapCache.getNodeByLegalName(party.name)
|
||||
?: throw IllegalArgumentException("Don't know about party $party")
|
||||
val logger = psm?.logger ?: logger
|
||||
logger.trace { "${psm?.id} sending $message to party $party" }
|
||||
serviceHub.networkService.send(sessionTopic, message, node.address)
|
||||
}
|
||||
|
||||
if (request is SendOnly) {
|
||||
// We sent a message, but don't expect a response, so re-enter the continuation to let it keep going.
|
||||
iterateStateMachine(psm, null) {
|
||||
try {
|
||||
Fiber.unpark(psm, QUASAR_UNBLOCKER)
|
||||
} catch(e: Throwable) {
|
||||
logError(e, request.payload, topicSession, psm)
|
||||
}
|
||||
}
|
||||
|
||||
interface SessionMessage
|
||||
|
||||
interface ExistingSessionMessage: SessionMessage {
|
||||
val recipientSessionId: Long
|
||||
}
|
||||
|
||||
data class SessionInit(val initiatorSessionId: Long, val initiatorParty: Party, val protocolName: String) : SessionMessage
|
||||
|
||||
interface SessionInitResponse : ExistingSessionMessage
|
||||
|
||||
data class SessionConfirm(val initiatorSessionId: Long, val initiatedSessionId: Long) : SessionInitResponse {
|
||||
override val recipientSessionId: Long get() = initiatorSessionId
|
||||
}
|
||||
|
||||
data class SessionReject(val initiatorSessionId: Long, val errorMessage: String) : SessionInitResponse {
|
||||
override val recipientSessionId: Long get() = initiatorSessionId
|
||||
}
|
||||
|
||||
data class SessionData(override val recipientSessionId: Long, val payload: Any) : ExistingSessionMessage {
|
||||
override fun toString(): String {
|
||||
return "${javaClass.simpleName}(recipientSessionId=$recipientSessionId, payload=${payload.toString().abbreviate(100)})"
|
||||
}
|
||||
}
|
||||
|
||||
private fun sendMessage(psm: ProtocolStateMachineImpl<*>, request: SendRequest): TopicSession {
|
||||
val topicSession = TopicSession(request.topic, request.sendSessionID)
|
||||
val payload = request.payload
|
||||
psm.logger.trace { "Sending message of type ${payload.javaClass.name} using queue $topicSession to ${request.destination} (${payload.toString().abbreviate(50)})" }
|
||||
val node = serviceHub.networkMapCache.getNodeByLegalName(request.destination.name) ?:
|
||||
throw IllegalArgumentException("Don't know about ${request.destination} but trying to send a message of type ${payload.javaClass.name} on $topicSession (${payload.toString().abbreviate(50)})", request.stackTraceInCaseOfProblems)
|
||||
serviceHub.networkService.send(topicSession, payload, node.address, request.uniqueMessageId)
|
||||
return topicSession
|
||||
data class SessionEnd(override val recipientSessionId: Long) : ExistingSessionMessage
|
||||
|
||||
|
||||
data class ProtocolSession(val protocol: ProtocolLogic<*>,
|
||||
val otherParty: Party,
|
||||
val ourSessionId: Long,
|
||||
var otherPartySessionId: Long?,
|
||||
@Volatile var waitingForResponse: Boolean = false) {
|
||||
|
||||
val receivedMessages = ConcurrentLinkedQueue<ExistingSessionMessage>()
|
||||
val psm: ProtocolStateMachineImpl<*> get() = protocol.psm as ProtocolStateMachineImpl<*>
|
||||
|
||||
}
|
||||
|
||||
/**
|
||||
* Add a trigger to the [MessagingService] to deserialize the fiber and pass message content to it, once a message is
|
||||
* received.
|
||||
*/
|
||||
private fun iterateOnResponse(psm: ProtocolStateMachineImpl<*>,
|
||||
serialisedFiber: SerializedBytes<ProtocolStateMachineImpl<*>>,
|
||||
request: ReceiveRequest<*>,
|
||||
resumeAction: (Any?) -> Unit) {
|
||||
val topicSession = request.receiveTopicSession
|
||||
serviceHub.networkService.runOnNextMessage(topicSession, executor) { netMsg ->
|
||||
// Assertion to ensure we don't execute on the wrong thread.
|
||||
executor.checkOnThread()
|
||||
// TODO: This is insecure: we should not deserialise whatever we find and *then* check.
|
||||
// We should instead verify as we read the data that it's what we are expecting and throw as early as
|
||||
// possible. We only do it this way for convenience during the prototyping stage. Note that this means
|
||||
// we could simply not require the programmer to specify the expected return type at all, and catch it
|
||||
// at the last moment when we do the downcast. However this would make protocol code harder to read and
|
||||
// make it more difficult to migrate to a more explicit serialisation scheme later.
|
||||
val payload = netMsg.data.deserialize<Any>()
|
||||
check(request.receiveType.isInstance(payload)) { "Expected message of type ${request.receiveType.name} but got ${payload.javaClass.name}" }
|
||||
// Update the fiber's checkpoint so that it's no longer waiting on a response, but rather has the received payload
|
||||
updateCheckpoint(psm, serialisedFiber, null, payload)
|
||||
psm.logger.trace { "Received message of type ${payload.javaClass.name} on $topicSession (${payload.toString().abbreviate(50)})" }
|
||||
iterateStateMachine(psm, payload, resumeAction)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1,12 +1,11 @@
|
||||
package com.r3corda.node.services.transactions
|
||||
|
||||
import com.r3corda.core.crypto.Party
|
||||
import com.r3corda.core.node.services.ServiceType
|
||||
import com.r3corda.core.node.services.TimestampChecker
|
||||
import com.r3corda.core.node.services.UniquenessProvider
|
||||
import com.r3corda.node.services.api.AbstractNodeService
|
||||
import com.r3corda.core.serialization.SingletonSerializeAsToken
|
||||
import com.r3corda.node.services.api.ServiceHubInternal
|
||||
import com.r3corda.protocols.NotaryProtocol
|
||||
import com.r3corda.protocols.NotaryProtocol.TOPIC
|
||||
import kotlin.reflect.KClass
|
||||
|
||||
/**
|
||||
* A Notary service acts as the final signer of a transaction ensuring two things:
|
||||
@ -17,22 +16,18 @@ import com.r3corda.protocols.NotaryProtocol.TOPIC
|
||||
*
|
||||
* This is the base implementation that can be customised with specific Notary transaction commit protocol.
|
||||
*/
|
||||
abstract class NotaryService(services: ServiceHubInternal,
|
||||
val timestampChecker: TimestampChecker,
|
||||
val uniquenessProvider: UniquenessProvider) : AbstractNodeService(services) {
|
||||
abstract class NotaryService(markerClass: KClass<out NotaryProtocol.Client>, services: ServiceHubInternal) : SingletonSerializeAsToken() {
|
||||
// Do not specify this as an advertised service. Use a concrete implementation.
|
||||
// TODO: We do not want a service type that cannot be used. Fix the type system abuse here.
|
||||
object Type : ServiceType("corda.notary")
|
||||
|
||||
abstract val logger: org.slf4j.Logger
|
||||
|
||||
/** Implement a factory that specifies the transaction commit protocol for the notary service to use */
|
||||
abstract val protocolFactory: NotaryProtocol.Factory
|
||||
|
||||
init {
|
||||
addProtocolHandler(TOPIC, TOPIC) { req: NotaryProtocol.Handshake ->
|
||||
protocolFactory.create(req.replyToParty, timestampChecker, uniquenessProvider)
|
||||
}
|
||||
services.registerProtocolInitiator(markerClass) { createProtocol(it) }
|
||||
}
|
||||
|
||||
/** Implement a factory that specifies the transaction commit protocol for the notary service to use */
|
||||
abstract fun createProtocol(otherParty: Party): NotaryProtocol.Service
|
||||
|
||||
}
|
||||
|
@ -1,5 +1,6 @@
|
||||
package com.r3corda.node.services.transactions
|
||||
|
||||
import com.r3corda.core.crypto.Party
|
||||
import com.r3corda.core.node.services.ServiceType
|
||||
import com.r3corda.core.node.services.TimestampChecker
|
||||
import com.r3corda.core.node.services.UniquenessProvider
|
||||
@ -9,11 +10,13 @@ import com.r3corda.protocols.NotaryProtocol
|
||||
|
||||
/** A simple Notary service that does not perform transaction validation */
|
||||
class SimpleNotaryService(services: ServiceHubInternal,
|
||||
timestampChecker: TimestampChecker,
|
||||
uniquenessProvider: UniquenessProvider) : NotaryService(services, timestampChecker, uniquenessProvider) {
|
||||
val timestampChecker: TimestampChecker,
|
||||
val uniquenessProvider: UniquenessProvider) : NotaryService(NotaryProtocol.Client::class, services) {
|
||||
object Type : ServiceType("corda.notary.simple")
|
||||
|
||||
override val logger = loggerFor<SimpleNotaryService>()
|
||||
|
||||
override val protocolFactory = NotaryProtocol.DefaultFactory
|
||||
override fun createProtocol(otherParty: Party): NotaryProtocol.Service {
|
||||
return NotaryProtocol.Service(otherParty, timestampChecker, uniquenessProvider)
|
||||
}
|
||||
}
|
||||
|
@ -11,17 +11,13 @@ import com.r3corda.protocols.ValidatingNotaryProtocol
|
||||
|
||||
/** A Notary service that validates the transaction chain of he submitted transaction before committing it */
|
||||
class ValidatingNotaryService(services: ServiceHubInternal,
|
||||
timestampChecker: TimestampChecker,
|
||||
uniquenessProvider: UniquenessProvider) : NotaryService(services, timestampChecker, uniquenessProvider) {
|
||||
val timestampChecker: TimestampChecker,
|
||||
val uniquenessProvider: UniquenessProvider) : NotaryService(NotaryProtocol.ValidatingClient::class, services) {
|
||||
object Type : ServiceType("corda.notary.validating")
|
||||
|
||||
override val logger = loggerFor<ValidatingNotaryService>()
|
||||
|
||||
override val protocolFactory = object : NotaryProtocol.Factory {
|
||||
override fun create(otherSide: Party,
|
||||
timestampChecker: TimestampChecker,
|
||||
uniquenessProvider: UniquenessProvider): NotaryProtocol.Service {
|
||||
return ValidatingNotaryProtocol(otherSide, timestampChecker, uniquenessProvider)
|
||||
}
|
||||
override fun createProtocol(otherParty: Party): ValidatingNotaryProtocol {
|
||||
return ValidatingNotaryProtocol(otherParty, timestampChecker, uniquenessProvider)
|
||||
}
|
||||
}
|
||||
|
@ -10,6 +10,7 @@ import com.r3corda.core.days
|
||||
import com.r3corda.core.messaging.SingleMessageRecipient
|
||||
import com.r3corda.core.node.ServiceHub
|
||||
import com.r3corda.core.node.services.*
|
||||
import com.r3corda.core.protocols.ProtocolStateMachine
|
||||
import com.r3corda.core.protocols.StateMachineRunId
|
||||
import com.r3corda.core.transactions.SignedTransaction
|
||||
import com.r3corda.core.transactions.WireTransaction
|
||||
@ -23,7 +24,6 @@ import com.r3corda.node.services.persistence.PerFileTransactionStorage
|
||||
import com.r3corda.node.services.persistence.StorageServiceImpl
|
||||
import com.r3corda.protocols.TwoPartyTradeProtocol.Buyer
|
||||
import com.r3corda.protocols.TwoPartyTradeProtocol.Seller
|
||||
import com.r3corda.protocols.TwoPartyTradeProtocol.TOPIC
|
||||
import com.r3corda.testing.*
|
||||
import com.r3corda.testing.node.InMemoryMessagingNetwork
|
||||
import com.r3corda.testing.node.MockNetwork
|
||||
@ -89,11 +89,11 @@ class TwoPartyTradeProtocolTests {
|
||||
|
||||
insertFakeTransactions(alicesFakePaper, aliceNode.services, aliceNode.storage.myLegalIdentityKey, notaryNode.storage.myLegalIdentityKey)
|
||||
|
||||
val (bobResult, aliceResult) = runBuyerAndSeller("alice's paper".outputStateAndRef())
|
||||
val (bobPsm, aliceResult) = runBuyerAndSeller("alice's paper".outputStateAndRef())
|
||||
|
||||
// TODO: Verify that the result was inserted into the transaction database.
|
||||
// assertEquals(bobResult.get(), aliceNode.storage.validatedTransactions[aliceResult.get().id])
|
||||
assertEquals(aliceResult.get(), bobResult.get())
|
||||
assertEquals(aliceResult.get(), bobPsm.get().resultFuture.get())
|
||||
|
||||
aliceNode.stop()
|
||||
bobNode.stop()
|
||||
@ -120,21 +120,19 @@ class TwoPartyTradeProtocolTests {
|
||||
1200.DOLLARS `issued by` DUMMY_CASH_ISSUER, null).second
|
||||
insertFakeTransactions(alicesFakePaper, aliceNode.services, aliceNode.storage.myLegalIdentityKey)
|
||||
|
||||
val aliceFuture = runBuyerAndSeller("alice's paper".outputStateAndRef()).sellerFuture
|
||||
val aliceFuture = runBuyerAndSeller("alice's paper".outputStateAndRef()).sellerResult
|
||||
|
||||
// Everything is on this thread so we can now step through the protocol one step at a time.
|
||||
// Seller Alice already sent a message to Buyer Bob. Pump once:
|
||||
fun pumpAlice() = (aliceNode.net as InMemoryMessagingNetwork.InMemoryMessaging).pumpReceive(false)
|
||||
|
||||
fun pumpBob() = (bobNode.net as InMemoryMessagingNetwork.InMemoryMessaging).pumpReceive(false)
|
||||
|
||||
pumpBob()
|
||||
bobNode.pumpReceive(false)
|
||||
|
||||
// Bob sends a couple of queries for the dependencies back to Alice. Alice reponds.
|
||||
pumpAlice()
|
||||
pumpBob()
|
||||
pumpAlice()
|
||||
pumpBob()
|
||||
aliceNode.pumpReceive(false)
|
||||
bobNode.pumpReceive(false)
|
||||
aliceNode.pumpReceive(false)
|
||||
bobNode.pumpReceive(false)
|
||||
aliceNode.pumpReceive(false)
|
||||
bobNode.pumpReceive(false)
|
||||
|
||||
// OK, now Bob has sent the partial transaction back to Alice and is waiting for Alice's signature.
|
||||
assertThat(bobNode.checkpointStorage.checkpoints).hasSize(1)
|
||||
@ -147,7 +145,7 @@ class TwoPartyTradeProtocolTests {
|
||||
|
||||
// Alice doesn't know that and carries on: she wants to know about the cash transactions he's trying to use.
|
||||
// She will wait around until Bob comes back.
|
||||
assertThat(pumpAlice()).isNotNull()
|
||||
assertThat(aliceNode.pumpReceive(false)).isNotNull()
|
||||
|
||||
// ... bring the node back up ... the act of constructing the SMM will re-register the message handlers
|
||||
// that Bob was waiting on before the reboot occurred.
|
||||
@ -309,16 +307,16 @@ class TwoPartyTradeProtocolTests {
|
||||
val attachmentID = attachment(ByteArrayInputStream(stream.toByteArray()))
|
||||
|
||||
val bobsFakeCash = fillUpForBuyer(false, bobNode.keyManagement.freshKey().public).second
|
||||
val bobsSignedTxns = insertFakeTransactions(bobsFakeCash, bobNode.services)
|
||||
insertFakeTransactions(bobsFakeCash, bobNode.services)
|
||||
val alicesFakePaper = fillUpForSeller(false, aliceNode.storage.myLegalIdentity.owningKey,
|
||||
1200.DOLLARS `issued by` DUMMY_CASH_ISSUER, attachmentID).second
|
||||
val alicesSignedTxns = insertFakeTransactions(alicesFakePaper, aliceNode.services, aliceNode.storage.myLegalIdentityKey)
|
||||
insertFakeTransactions(alicesFakePaper, aliceNode.services, aliceNode.storage.myLegalIdentityKey)
|
||||
|
||||
net.runNetwork() // Clear network map registration messages
|
||||
|
||||
val aliceTxStream = aliceNode.storage.validatedTransactions.track().second
|
||||
val aliceTxMappings = aliceNode.storage.stateMachineRecordedTransactionMapping.track().second
|
||||
val (bobResult, aliceResult, bobSmId, aliceSmId) = runBuyerAndSeller("alice's paper".outputStateAndRef())
|
||||
val aliceSmId = runBuyerAndSeller("alice's paper".outputStateAndRef()).sellerId
|
||||
|
||||
net.runNetwork()
|
||||
|
||||
@ -367,21 +365,20 @@ class TwoPartyTradeProtocolTests {
|
||||
}
|
||||
}
|
||||
|
||||
data class RunResult(
|
||||
val buyerFuture: Future<SignedTransaction>,
|
||||
val sellerFuture: Future<SignedTransaction>,
|
||||
val buyerSmId: StateMachineRunId,
|
||||
val sellerSmId: StateMachineRunId
|
||||
private data class RunResult(
|
||||
// The buyer is not created immediately, only when the seller starts running
|
||||
val buyer: Future<ProtocolStateMachine<SignedTransaction>>,
|
||||
val sellerResult: Future<SignedTransaction>,
|
||||
val sellerId: StateMachineRunId
|
||||
)
|
||||
|
||||
private fun runBuyerAndSeller(assetToSell: StateAndRef<OwnableState>): RunResult {
|
||||
val buyer = Buyer(aliceNode.info.identity, notaryNode.info.identity, 1000.DOLLARS, CommercialPaper.State::class.java)
|
||||
private fun runBuyerAndSeller(assetToSell: StateAndRef<OwnableState>) : RunResult {
|
||||
val buyerFuture = bobNode.initiateSingleShotProtocol(Seller::class) { otherParty ->
|
||||
Buyer(otherParty, notaryNode.info.identity, 1000.DOLLARS, CommercialPaper.State::class.java)
|
||||
}
|
||||
val seller = Seller(bobNode.info.identity, notaryNode.info, assetToSell, 1000.DOLLARS, ALICE_KEY)
|
||||
connectProtocols(buyer, seller)
|
||||
// We start the Buyer first, as the Seller sends the first message
|
||||
val buyerPsm = bobNode.smm.add("$TOPIC.buyer", buyer)
|
||||
val sellerPsm = aliceNode.smm.add("$TOPIC.seller", seller)
|
||||
return RunResult(buyerPsm.resultFuture, sellerPsm.resultFuture, buyerPsm.id, sellerPsm.id)
|
||||
val sellerResultFuture = aliceNode.smm.add("seller", seller).resultFuture
|
||||
return RunResult(buyerFuture, sellerResultFuture, seller.psm.id)
|
||||
}
|
||||
|
||||
private fun LedgerDSL<TestTransactionDSLInterpreter, TestLedgerDSLInterpreter>.runWithError(
|
||||
@ -404,7 +401,7 @@ class TwoPartyTradeProtocolTests {
|
||||
|
||||
net.runNetwork() // Clear network map registration messages
|
||||
|
||||
val (bobResult, aliceResult) = runBuyerAndSeller("alice's paper".outputStateAndRef())
|
||||
val (bobPsm, aliceResult) = runBuyerAndSeller("alice's paper".outputStateAndRef())
|
||||
|
||||
net.runNetwork()
|
||||
|
||||
@ -412,7 +409,7 @@ class TwoPartyTradeProtocolTests {
|
||||
if (bobError)
|
||||
aliceResult.get()
|
||||
else
|
||||
bobResult.get()
|
||||
bobPsm.get().resultFuture.get()
|
||||
}
|
||||
assertTrue(e.cause is TransactionVerificationException)
|
||||
assertNotNull(e.cause!!.cause)
|
||||
@ -506,6 +503,7 @@ class TwoPartyTradeProtocolTests {
|
||||
return Pair(vault, listOf(ap))
|
||||
}
|
||||
|
||||
|
||||
class RecordingTransactionStorage(val delegate: TransactionStorage) : TransactionStorage {
|
||||
override fun track(): Pair<List<SignedTransaction>, Observable<SignedTransaction>> {
|
||||
return delegate.track()
|
||||
@ -530,4 +528,5 @@ class TwoPartyTradeProtocolTests {
|
||||
data class Add(val transaction: SignedTransaction) : TxRecord
|
||||
data class Get(val id: SecureHash) : TxRecord
|
||||
}
|
||||
|
||||
}
|
||||
|
@ -2,23 +2,24 @@ package com.r3corda.node.services
|
||||
|
||||
import com.codahale.metrics.MetricRegistry
|
||||
import com.google.common.util.concurrent.ListenableFuture
|
||||
import com.r3corda.core.transactions.SignedTransaction
|
||||
import com.r3corda.core.crypto.Party
|
||||
import com.r3corda.core.node.services.*
|
||||
import com.r3corda.core.protocols.ProtocolLogic
|
||||
import com.r3corda.core.protocols.ProtocolLogicRefFactory
|
||||
import com.r3corda.core.protocols.StateMachineRunId
|
||||
import com.r3corda.core.testing.InMemoryVaultService
|
||||
import com.r3corda.core.transactions.SignedTransaction
|
||||
import com.r3corda.node.serialization.NodeClock
|
||||
import com.r3corda.node.services.api.MessagingServiceInternal
|
||||
import com.r3corda.node.services.api.MonitoringService
|
||||
import com.r3corda.node.services.api.ServiceHubInternal
|
||||
import com.r3corda.testing.node.MockNetworkMapCache
|
||||
import com.r3corda.node.services.network.NetworkMapService
|
||||
import com.r3corda.node.services.persistence.DataVending
|
||||
import com.r3corda.node.services.statemachine.StateMachineManager
|
||||
import com.r3corda.core.testing.InMemoryVaultService
|
||||
import com.r3corda.testing.node.MockStorageService
|
||||
import com.r3corda.testing.MOCK_IDENTITY_SERVICE
|
||||
import com.r3corda.testing.node.MockNetworkMapCache
|
||||
import com.r3corda.testing.node.MockStorageService
|
||||
import java.time.Clock
|
||||
import java.util.concurrent.ConcurrentHashMap
|
||||
import kotlin.reflect.KClass
|
||||
|
||||
@Suppress("LeakingThis")
|
||||
open class MockServiceHubInternal(
|
||||
@ -28,7 +29,6 @@ open class MockServiceHubInternal(
|
||||
val identity: IdentityService? = MOCK_IDENTITY_SERVICE,
|
||||
val storage: TxWritableStorageService? = MockStorageService(),
|
||||
val mapCache: NetworkMapCache? = MockNetworkMapCache(),
|
||||
val mapService: NetworkMapService? = null,
|
||||
val scheduler: SchedulerService? = null,
|
||||
val overrideClock: Clock? = NodeClock(),
|
||||
val protocolFactory: ProtocolLogicRefFactory? = ProtocolLogicRefFactory()
|
||||
@ -57,14 +57,10 @@ open class MockServiceHubInternal(
|
||||
private val txStorageService: TxWritableStorageService
|
||||
get() = storage ?: throw UnsupportedOperationException()
|
||||
|
||||
override fun recordTransactions(txs: Iterable<SignedTransaction>) = recordTransactionsInternal(txStorageService, txs)
|
||||
private val protocolFactories = ConcurrentHashMap<Class<*>, (Party) -> ProtocolLogic<*>>()
|
||||
|
||||
lateinit var smm: StateMachineManager
|
||||
|
||||
override fun <T> startProtocol(loggerName: String, logic: ProtocolLogic<T>): ListenableFuture<T> {
|
||||
return smm.add(loggerName, logic).resultFuture
|
||||
}
|
||||
|
||||
init {
|
||||
if (net != null && storage != null) {
|
||||
// Creating this class is sufficient, we don't have to store it anywhere, because it registers a listener
|
||||
@ -72,4 +68,18 @@ open class MockServiceHubInternal(
|
||||
DataVending.Service(this)
|
||||
}
|
||||
}
|
||||
|
||||
override fun recordTransactions(txs: Iterable<SignedTransaction>) = recordTransactionsInternal(txStorageService, txs)
|
||||
|
||||
override fun <T> startProtocol(loggerName: String, logic: ProtocolLogic<T>): ListenableFuture<T> {
|
||||
return smm.add(loggerName, logic).resultFuture
|
||||
}
|
||||
|
||||
override fun registerProtocolInitiator(markerClass: KClass<*>, protocolFactory: (Party) -> ProtocolLogic<*>) {
|
||||
protocolFactories[markerClass.java] = protocolFactory
|
||||
}
|
||||
|
||||
override fun getProtocolFactory(markerClass: Class<*>): ((Party) -> ProtocolLogic<*>)? {
|
||||
return protocolFactories[markerClass]
|
||||
}
|
||||
}
|
||||
|
@ -3,7 +3,6 @@ package com.r3corda.node.services
|
||||
import com.google.common.jimfs.Configuration
|
||||
import com.google.common.jimfs.Jimfs
|
||||
import com.r3corda.core.contracts.*
|
||||
import com.r3corda.core.crypto.SecureHash
|
||||
import com.r3corda.core.days
|
||||
import com.r3corda.core.node.ServiceHub
|
||||
import com.r3corda.core.node.recordTransactions
|
||||
@ -12,18 +11,16 @@ import com.r3corda.core.protocols.ProtocolLogicRef
|
||||
import com.r3corda.core.protocols.ProtocolLogicRefFactory
|
||||
import com.r3corda.core.serialization.SingletonSerializeAsToken
|
||||
import com.r3corda.core.utilities.DUMMY_NOTARY
|
||||
import com.r3corda.core.utilities.LogHelper
|
||||
import com.r3corda.testing.node.TestClock
|
||||
import com.r3corda.node.services.events.NodeSchedulerService
|
||||
import com.r3corda.testing.node.InMemoryMessagingNetwork
|
||||
import com.r3corda.node.services.persistence.PerFileCheckpointStorage
|
||||
import com.r3corda.node.services.statemachine.StateMachineManager
|
||||
import com.r3corda.node.services.vault.NodeVaultService
|
||||
import com.r3corda.node.utilities.AddOrRemove
|
||||
import com.r3corda.node.utilities.AffinityExecutor
|
||||
import com.r3corda.node.utilities.configureDatabase
|
||||
import com.r3corda.testing.ALICE_KEY
|
||||
import com.r3corda.testing.node.InMemoryMessagingNetwork
|
||||
import com.r3corda.testing.node.MockKeyManagementService
|
||||
import com.r3corda.testing.node.TestClock
|
||||
import com.r3corda.testing.node.makeTestDataSourceProperties
|
||||
import org.assertj.core.api.Assertions.assertThat
|
||||
import org.junit.After
|
||||
@ -34,7 +31,9 @@ import java.nio.file.FileSystem
|
||||
import java.security.PublicKey
|
||||
import java.time.Clock
|
||||
import java.time.Instant
|
||||
import java.util.concurrent.*
|
||||
import java.util.concurrent.CountDownLatch
|
||||
import java.util.concurrent.Executors
|
||||
import java.util.concurrent.TimeUnit
|
||||
import kotlin.test.assertTrue
|
||||
|
||||
class NodeSchedulerServiceTest : SingletonSerializeAsToken() {
|
||||
@ -128,8 +127,6 @@ class NodeSchedulerServiceTest : SingletonSerializeAsToken() {
|
||||
(serviceHub as TestReference).testReference.calls += increment
|
||||
(serviceHub as TestReference).testReference.countDown.countDown()
|
||||
}
|
||||
|
||||
override val topic: String get() = throw UnsupportedOperationException()
|
||||
}
|
||||
|
||||
class Command : TypeOnlyCommandData()
|
||||
|
@ -9,7 +9,6 @@ import com.r3corda.core.utilities.DUMMY_NOTARY_KEY
|
||||
import com.r3corda.node.internal.AbstractNode
|
||||
import com.r3corda.node.services.network.NetworkMapService
|
||||
import com.r3corda.node.services.transactions.SimpleNotaryService
|
||||
import com.r3corda.protocols.NotaryChangeProtocol
|
||||
import com.r3corda.protocols.NotaryChangeProtocol.Instigator
|
||||
import com.r3corda.protocols.StateReplacementException
|
||||
import com.r3corda.protocols.StateReplacementRefused
|
||||
@ -49,7 +48,7 @@ class NotaryChangeTests {
|
||||
val state = issueState(clientNodeA)
|
||||
val newNotary = newNotaryNode.info.identity
|
||||
val protocol = Instigator(state, newNotary)
|
||||
val future = clientNodeA.services.startProtocol(NotaryChangeProtocol.TOPIC, protocol)
|
||||
val future = clientNodeA.services.startProtocol("notary-change", protocol)
|
||||
|
||||
net.runNetwork()
|
||||
|
||||
@ -62,7 +61,7 @@ class NotaryChangeTests {
|
||||
val state = issueMultiPartyState(clientNodeA, clientNodeB)
|
||||
val newNotary = newNotaryNode.info.identity
|
||||
val protocol = Instigator(state, newNotary)
|
||||
val future = clientNodeA.services.startProtocol(NotaryChangeProtocol.TOPIC, protocol)
|
||||
val future = clientNodeA.services.startProtocol("notary-change", protocol)
|
||||
|
||||
net.runNetwork()
|
||||
|
||||
@ -78,7 +77,7 @@ class NotaryChangeTests {
|
||||
val state = issueMultiPartyState(clientNodeA, clientNodeB)
|
||||
val newEvilNotary = Party("Evil Notary", generateKeyPair().public)
|
||||
val protocol = Instigator(state, newEvilNotary)
|
||||
val future = clientNodeA.services.startProtocol(NotaryChangeProtocol.TOPIC, protocol)
|
||||
val future = clientNodeA.services.startProtocol("notary-change", protocol)
|
||||
|
||||
net.runNetwork()
|
||||
|
||||
|
@ -1,17 +1,19 @@
|
||||
package com.r3corda.node.services
|
||||
|
||||
import com.r3corda.core.contracts.Timestamp
|
||||
import com.google.common.util.concurrent.ListenableFuture
|
||||
import com.r3corda.core.contracts.TransactionType
|
||||
import com.r3corda.core.crypto.DigitalSignature
|
||||
import com.r3corda.core.seconds
|
||||
import com.r3corda.core.transactions.SignedTransaction
|
||||
import com.r3corda.core.utilities.DUMMY_NOTARY
|
||||
import com.r3corda.core.utilities.DUMMY_NOTARY_KEY
|
||||
import com.r3corda.testing.node.MockNetwork
|
||||
import com.r3corda.node.services.network.NetworkMapService
|
||||
import com.r3corda.node.services.transactions.SimpleNotaryService
|
||||
import com.r3corda.protocols.NotaryError
|
||||
import com.r3corda.protocols.NotaryException
|
||||
import com.r3corda.protocols.NotaryProtocol
|
||||
import com.r3corda.testing.MINI_CORP_KEY
|
||||
import com.r3corda.testing.node.MockNetwork
|
||||
import org.junit.Before
|
||||
import org.junit.Test
|
||||
import java.time.Instant
|
||||
@ -45,10 +47,7 @@ class NotaryServiceTests {
|
||||
tx.toSignedTransaction(false)
|
||||
}
|
||||
|
||||
val protocol = NotaryProtocol.Client(stx)
|
||||
val future = clientNode.services.startProtocol(NotaryProtocol.TOPIC, protocol)
|
||||
net.runNetwork()
|
||||
|
||||
val future = runNotaryClient(stx)
|
||||
val signature = future.get()
|
||||
signature.verifyWithECDSA(stx.txBits)
|
||||
}
|
||||
@ -61,10 +60,7 @@ class NotaryServiceTests {
|
||||
tx.toSignedTransaction(false)
|
||||
}
|
||||
|
||||
val protocol = NotaryProtocol.Client(stx)
|
||||
val future = clientNode.services.startProtocol(NotaryProtocol.TOPIC, protocol)
|
||||
net.runNetwork()
|
||||
|
||||
val future = runNotaryClient(stx)
|
||||
val signature = future.get()
|
||||
signature.verifyWithECDSA(stx.txBits)
|
||||
}
|
||||
@ -78,16 +74,13 @@ class NotaryServiceTests {
|
||||
tx.toSignedTransaction(false)
|
||||
}
|
||||
|
||||
val protocol = NotaryProtocol.Client(stx)
|
||||
val future = clientNode.services.startProtocol(NotaryProtocol.TOPIC, protocol)
|
||||
net.runNetwork()
|
||||
val future = runNotaryClient(stx)
|
||||
|
||||
val ex = assertFailsWith(ExecutionException::class) { future.get() }
|
||||
val error = (ex.cause as NotaryException).error
|
||||
assertTrue(error is NotaryError.TimestampInvalid)
|
||||
}
|
||||
|
||||
|
||||
@Test fun `should report conflict for a duplicate transaction`() {
|
||||
val stx = run {
|
||||
val inputState = issueState(clientNode)
|
||||
@ -98,8 +91,8 @@ class NotaryServiceTests {
|
||||
|
||||
val firstSpend = NotaryProtocol.Client(stx)
|
||||
val secondSpend = NotaryProtocol.Client(stx)
|
||||
clientNode.services.startProtocol("${NotaryProtocol.TOPIC}.first", firstSpend)
|
||||
val future = clientNode.services.startProtocol("${NotaryProtocol.TOPIC}.second", secondSpend)
|
||||
clientNode.services.startProtocol("notary.first", firstSpend)
|
||||
val future = clientNode.services.startProtocol("notary.second", secondSpend)
|
||||
|
||||
net.runNetwork()
|
||||
|
||||
@ -108,4 +101,12 @@ class NotaryServiceTests {
|
||||
assertEquals(notaryError.tx, stx.tx)
|
||||
notaryError.conflict.verified()
|
||||
}
|
||||
|
||||
|
||||
private fun runNotaryClient(stx: SignedTransaction): ListenableFuture<DigitalSignature.LegallyIdentifiable> {
|
||||
val protocol = NotaryProtocol.Client(stx)
|
||||
val future = clientNode.services.startProtocol("notary-test", protocol)
|
||||
net.runNetwork()
|
||||
return future
|
||||
}
|
||||
}
|
@ -1,8 +1,11 @@
|
||||
package com.r3corda.node.services
|
||||
|
||||
import com.google.common.util.concurrent.ListenableFuture
|
||||
import com.r3corda.core.contracts.Command
|
||||
import com.r3corda.core.contracts.DummyContract
|
||||
import com.r3corda.core.contracts.TransactionType
|
||||
import com.r3corda.core.crypto.DigitalSignature
|
||||
import com.r3corda.core.transactions.SignedTransaction
|
||||
import com.r3corda.core.utilities.DUMMY_NOTARY
|
||||
import com.r3corda.core.utilities.DUMMY_NOTARY_KEY
|
||||
import com.r3corda.node.services.network.NetworkMapService
|
||||
@ -44,9 +47,7 @@ class ValidatingNotaryServiceTests {
|
||||
tx.toSignedTransaction(false)
|
||||
}
|
||||
|
||||
val protocol = NotaryProtocol.Client(stx)
|
||||
val future = clientNode.services.startProtocol(NotaryProtocol.TOPIC, protocol)
|
||||
net.runNetwork()
|
||||
val future = runValidatingClient(stx)
|
||||
|
||||
val ex = assertFailsWith(ExecutionException::class) { future.get() }
|
||||
val notaryError = (ex.cause as NotaryException).error
|
||||
@ -64,9 +65,7 @@ class ValidatingNotaryServiceTests {
|
||||
tx.toSignedTransaction(false)
|
||||
}
|
||||
|
||||
val protocol = NotaryProtocol.Client(stx)
|
||||
val future = clientNode.services.startProtocol(NotaryProtocol.TOPIC, protocol)
|
||||
net.runNetwork()
|
||||
val future = runValidatingClient(stx)
|
||||
|
||||
val ex = assertFailsWith(ExecutionException::class) { future.get() }
|
||||
val notaryError = (ex.cause as NotaryException).error
|
||||
@ -75,4 +74,11 @@ class ValidatingNotaryServiceTests {
|
||||
val missingKeys = (notaryError as NotaryError.SignaturesMissing).missingSigners
|
||||
assertEquals(setOf(expectedMissingKey), missingKeys)
|
||||
}
|
||||
|
||||
private fun runValidatingClient(stx: SignedTransaction): ListenableFuture<DigitalSignature.LegallyIdentifiable> {
|
||||
val protocol = NotaryProtocol.ValidatingClient(stx)
|
||||
val future = clientNode.services.startProtocol("notary", protocol)
|
||||
net.runNetwork()
|
||||
return future
|
||||
}
|
||||
}
|
@ -1,13 +1,20 @@
|
||||
package com.r3corda.node.services.persistence
|
||||
|
||||
import co.paralleluniverse.fibers.Suspendable
|
||||
import com.r3corda.contracts.asset.Cash
|
||||
import com.r3corda.core.contracts.Amount
|
||||
import com.r3corda.core.contracts.Issued
|
||||
import com.r3corda.core.contracts.TransactionType
|
||||
import com.r3corda.core.contracts.USD
|
||||
import com.r3corda.core.crypto.Party
|
||||
import com.r3corda.core.protocols.ProtocolLogic
|
||||
import com.r3corda.core.transactions.SignedTransaction
|
||||
import com.r3corda.core.utilities.DUMMY_NOTARY
|
||||
import com.r3corda.testing.node.MockNetwork
|
||||
import com.r3corda.node.services.persistence.DataVending.Service.NotifyTransactionHandler
|
||||
import com.r3corda.protocols.BroadcastTransactionProtocol.NotifyTxRequest
|
||||
import com.r3corda.testing.MEGA_CORP
|
||||
import com.r3corda.testing.node.MockNetwork
|
||||
import com.r3corda.testing.node.MockNetwork.MockNode
|
||||
import org.junit.Before
|
||||
import org.junit.Test
|
||||
import kotlin.test.assertEquals
|
||||
@ -38,9 +45,8 @@ class DataVendingServiceTests {
|
||||
ptx.signWith(registerNode.services.storageService.myLegalIdentityKey)
|
||||
val tx = ptx.toSignedTransaction()
|
||||
assertEquals(0, vaultServiceNode.services.vaultService.currentVault.states.toList().size)
|
||||
DataVending.Service.notify(registerNode.net, registerNode.services.storageService.myLegalIdentity,
|
||||
vaultServiceNode.info, tx)
|
||||
network.runNetwork()
|
||||
|
||||
registerNode.sendNotifyTx(tx, vaultServiceNode)
|
||||
|
||||
// Check the transaction is in the receiving node
|
||||
val actual = vaultServiceNode.services.vaultService.currentVault.states.singleOrNull()
|
||||
@ -67,11 +73,23 @@ class DataVendingServiceTests {
|
||||
ptx.signWith(registerNode.services.storageService.myLegalIdentityKey)
|
||||
val tx = ptx.toSignedTransaction(false)
|
||||
assertEquals(0, vaultServiceNode.services.vaultService.currentVault.states.toList().size)
|
||||
DataVending.Service.notify(registerNode.net, registerNode.services.storageService.myLegalIdentity,
|
||||
vaultServiceNode.info, tx)
|
||||
network.runNetwork()
|
||||
|
||||
registerNode.sendNotifyTx(tx, vaultServiceNode)
|
||||
|
||||
// Check the transaction is not in the receiving node
|
||||
assertEquals(0, vaultServiceNode.services.vaultService.currentVault.states.toList().size)
|
||||
}
|
||||
}
|
||||
|
||||
private fun MockNode.sendNotifyTx(tx: SignedTransaction, walletServiceNode: MockNode) {
|
||||
walletServiceNode.services.registerProtocolInitiator(NotifyTxProtocol::class, ::NotifyTransactionHandler)
|
||||
services.startProtocol("notify-tx", NotifyTxProtocol(walletServiceNode.info.identity, tx))
|
||||
network.runNetwork()
|
||||
}
|
||||
|
||||
|
||||
private class NotifyTxProtocol(val otherParty: Party, val stx: SignedTransaction) : ProtocolLogic<Unit>() {
|
||||
@Suspendable
|
||||
override fun call() = send(otherParty, NotifyTxRequest(stx, emptySet()))
|
||||
}
|
||||
|
||||
}
|
||||
|
@ -10,12 +10,14 @@ import org.assertj.core.api.Assertions.assertThatExceptionOfType
|
||||
import org.junit.After
|
||||
import org.junit.Before
|
||||
import org.junit.Test
|
||||
import java.nio.file.FileSystem
|
||||
import java.nio.file.Files
|
||||
import java.nio.file.Path
|
||||
|
||||
class PerFileCheckpointStorageTests {
|
||||
|
||||
val fileSystem = Jimfs.newFileSystem(unix())
|
||||
val storeDir = fileSystem.getPath("store")
|
||||
val fileSystem: FileSystem = Jimfs.newFileSystem(unix())
|
||||
val storeDir: Path = fileSystem.getPath("store")
|
||||
lateinit var checkpointStorage: PerFileCheckpointStorage
|
||||
|
||||
@Before
|
||||
@ -92,6 +94,6 @@ class PerFileCheckpointStorageTests {
|
||||
}
|
||||
|
||||
private var checkpointCount = 1
|
||||
private fun newCheckpoint() = Checkpoint(SerializedBytes(Ints.toByteArray(checkpointCount++)), null, null)
|
||||
private fun newCheckpoint() = Checkpoint(SerializedBytes(Ints.toByteArray(checkpointCount++)))
|
||||
|
||||
}
|
@ -2,14 +2,20 @@ package com.r3corda.node.services.statemachine
|
||||
|
||||
import co.paralleluniverse.fibers.Fiber
|
||||
import co.paralleluniverse.fibers.Suspendable
|
||||
import com.google.common.util.concurrent.ListenableFuture
|
||||
import com.r3corda.core.crypto.Party
|
||||
import com.r3corda.core.messaging.SingleMessageRecipient
|
||||
import com.r3corda.core.protocols.ProtocolLogic
|
||||
import com.r3corda.core.protocols.ProtocolSessionException
|
||||
import com.r3corda.core.random63BitValue
|
||||
import com.r3corda.testing.connectProtocols
|
||||
import com.r3corda.core.serialization.deserialize
|
||||
import com.r3corda.node.services.statemachine.StateMachineManager.SessionData
|
||||
import com.r3corda.node.services.statemachine.StateMachineManager.SessionMessage
|
||||
import com.r3corda.testing.node.InMemoryMessagingNetwork
|
||||
import com.r3corda.testing.node.MockNetwork
|
||||
import com.r3corda.testing.node.MockNetwork.MockNode
|
||||
import org.assertj.core.api.Assertions.assertThat
|
||||
import org.assertj.core.api.Assertions.assertThatThrownBy
|
||||
import org.junit.After
|
||||
import org.junit.Before
|
||||
import org.junit.Test
|
||||
@ -50,18 +56,18 @@ class StateMachineManagerTests {
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `protocol suspended just after receiving payload`() {
|
||||
val topic = "send-and-receive"
|
||||
fun `protocol restarted just after receiving payload`() {
|
||||
node2.services.registerProtocolInitiator(SendProtocol::class) { ReceiveThenSuspendProtocol(it) }
|
||||
val payload = random63BitValue()
|
||||
val sendProtocol = SendProtocol(topic, node2.info.identity, payload)
|
||||
val receiveProtocol = ReceiveProtocol(topic, node1.info.identity)
|
||||
connectProtocols(sendProtocol, receiveProtocol)
|
||||
node1.smm.add("test", sendProtocol)
|
||||
node2.smm.add("test", receiveProtocol)
|
||||
net.runNetwork()
|
||||
node1.smm.add("test", SendProtocol(payload, node2.info.identity))
|
||||
|
||||
// We push through just enough messages to get only the SessionData sent
|
||||
// TODO We should be able to give runNetwork a predicate for when to stop
|
||||
net.runNetwork(2)
|
||||
node2.stop()
|
||||
val restoredProtocol = node2.restartAndGetRestoredProtocol<ReceiveProtocol>(node1.info.address)
|
||||
assertThat(restoredProtocol.receivedPayload).isEqualTo(payload)
|
||||
net.runNetwork()
|
||||
val restoredProtocol = node2.restartAndGetRestoredProtocol<ReceiveThenSuspendProtocol>(node1.info.address)
|
||||
assertThat(restoredProtocol.receivedPayloads[0]).isEqualTo(payload)
|
||||
}
|
||||
|
||||
@Test
|
||||
@ -83,7 +89,7 @@ class StateMachineManagerTests {
|
||||
node3.stop()
|
||||
|
||||
node3 = net.createNode(node1.info.address, forcedID = node3.id)
|
||||
val restoredProtocol = node3.smm.findStateMachines(ProtocolNoBlocking::class.java).single().first
|
||||
val restoredProtocol = node3.getSingleProtocol<ProtocolNoBlocking>().first
|
||||
assertEquals(false, restoredProtocol.protocolStarted) // Not started yet as no network activity has been allowed yet
|
||||
net.runNetwork() // Allow network map messages to flow
|
||||
node3.smm.executor.flush()
|
||||
@ -99,43 +105,44 @@ class StateMachineManagerTests {
|
||||
|
||||
@Test
|
||||
fun `protocol loaded from checkpoint will respond to messages from before start`() {
|
||||
val topic = "send-and-receive"
|
||||
val payload = random63BitValue()
|
||||
val sendProtocol = SendProtocol(topic, node2.info.identity, payload)
|
||||
val receiveProtocol = ReceiveProtocol(topic, node1.info.identity)
|
||||
connectProtocols(sendProtocol, receiveProtocol)
|
||||
node1.services.registerProtocolInitiator(ReceiveThenSuspendProtocol::class) { SendProtocol(payload, it) }
|
||||
val receiveProtocol = ReceiveThenSuspendProtocol(node1.info.identity)
|
||||
node2.smm.add("test", receiveProtocol) // Prepare checkpointed receive protocol
|
||||
node2.stop() // kill receiver
|
||||
node1.smm.add("test", sendProtocol) // now generate message to spool up and thus come in ahead of messages for NetworkMapService
|
||||
val restoredProtocol = node2.restartAndGetRestoredProtocol<ReceiveProtocol>(node1.info.address)
|
||||
assertThat(restoredProtocol.receivedPayload).isEqualTo(payload)
|
||||
val restoredProtocol = node2.restartAndGetRestoredProtocol<ReceiveThenSuspendProtocol>(node1.info.address)
|
||||
assertThat(restoredProtocol.receivedPayloads[0]).isEqualTo(payload)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `protocol with send will resend on interrupted restart`() {
|
||||
val topic = "send-and-receive"
|
||||
val payload = random63BitValue()
|
||||
val payload2 = random63BitValue()
|
||||
|
||||
var sentCount = 0
|
||||
var receivedCount = 0
|
||||
net.messagingNetwork.sentMessages.subscribe { if (it.message.topicSession.topic == topic) sentCount++ }
|
||||
net.messagingNetwork.receivedMessages.subscribe { if (it.message.topicSession.topic == topic) receivedCount++ }
|
||||
net.messagingNetwork.sentMessages.subscribe { if (isDataMessage(it)) sentCount++ }
|
||||
net.messagingNetwork.receivedMessages.subscribe { if (isDataMessage(it)) receivedCount++ }
|
||||
val node3 = net.createNode(node1.info.address)
|
||||
net.runNetwork()
|
||||
val firstProtocol = PingPongProtocol(topic, node3.info.identity, payload)
|
||||
val secondProtocol = PingPongProtocol(topic, node2.info.identity, payload2)
|
||||
connectProtocols(firstProtocol, secondProtocol)
|
||||
|
||||
var secondProtocol: PingPongProtocol? = null
|
||||
node3.services.registerProtocolInitiator(PingPongProtocol::class) {
|
||||
val protocol = PingPongProtocol(it, payload2)
|
||||
secondProtocol = protocol
|
||||
protocol
|
||||
}
|
||||
|
||||
// Kick off first send and receive
|
||||
node2.smm.add("test", firstProtocol)
|
||||
node2.smm.add("test", PingPongProtocol(node3.info.identity, payload))
|
||||
assertEquals(1, node2.checkpointStorage.checkpoints.count())
|
||||
// Restart node and thus reload the checkpoint and resend the message with same UUID
|
||||
node2.stop()
|
||||
val node2b = net.createNode(node1.info.address, node2.id, advertisedServices = *node2.advertisedServices.toTypedArray())
|
||||
val (firstAgain, fut1) = node2b.smm.findStateMachines(PingPongProtocol::class.java).single()
|
||||
val (firstAgain, fut1) = node2b.getSingleProtocol<PingPongProtocol>()
|
||||
net.runNetwork()
|
||||
assertEquals(1, node2.checkpointStorage.checkpoints.count())
|
||||
// Now add in the other half of the protocol. First message should get deduped. So message data stays in sync.
|
||||
node3.smm.add("test", secondProtocol)
|
||||
// Run the network which will also fire up the second protocol. First message should get deduped. So message data stays in sync.
|
||||
net.runNetwork()
|
||||
node2b.smm.executor.flush()
|
||||
fut1.get()
|
||||
@ -146,15 +153,66 @@ class StateMachineManagerTests {
|
||||
assertEquals(0, node3.checkpointStorage.checkpoints.count(), "Checkpoints left after restored protocol should have ended")
|
||||
assertEquals(payload2, firstAgain.receivedPayload, "Received payload does not match the first value on Node 3")
|
||||
assertEquals(payload2 + 1, firstAgain.receivedPayload2, "Received payload does not match the expected second value on Node 3")
|
||||
assertEquals(payload, secondProtocol.receivedPayload, "Received payload does not match the (restarted) first value on Node 2")
|
||||
assertEquals(payload + 1, secondProtocol.receivedPayload2, "Received payload does not match the expected second value on Node 2")
|
||||
assertEquals(payload, secondProtocol!!.receivedPayload, "Received payload does not match the (restarted) first value on Node 2")
|
||||
assertEquals(payload + 1, secondProtocol!!.receivedPayload2, "Received payload does not match the expected second value on Node 2")
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `sending to multiple parties`() {
|
||||
val node3 = net.createNode(node1.info.address)
|
||||
net.runNetwork()
|
||||
node2.services.registerProtocolInitiator(SendProtocol::class) { ReceiveThenSuspendProtocol(it) }
|
||||
node3.services.registerProtocolInitiator(SendProtocol::class) { ReceiveThenSuspendProtocol(it) }
|
||||
val payload = random63BitValue()
|
||||
node1.smm.add("multiple-send", SendProtocol(payload, node2.info.identity, node3.info.identity))
|
||||
net.runNetwork()
|
||||
val node2Protocol = node2.getSingleProtocol<ReceiveThenSuspendProtocol>().first
|
||||
val node3Protocol = node3.getSingleProtocol<ReceiveThenSuspendProtocol>().first
|
||||
assertThat(node2Protocol.receivedPayloads[0]).isEqualTo(payload)
|
||||
assertThat(node3Protocol.receivedPayloads[0]).isEqualTo(payload)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `receiving from multiple parties`() {
|
||||
val node3 = net.createNode(node1.info.address)
|
||||
net.runNetwork()
|
||||
val node2Payload = random63BitValue()
|
||||
val node3Payload = random63BitValue()
|
||||
node2.services.registerProtocolInitiator(ReceiveThenSuspendProtocol::class) { SendProtocol(node2Payload, it) }
|
||||
node3.services.registerProtocolInitiator(ReceiveThenSuspendProtocol::class) { SendProtocol(node3Payload, it) }
|
||||
val multiReceiveProtocol = ReceiveThenSuspendProtocol(node2.info.identity, node3.info.identity)
|
||||
node1.smm.add("multiple-receive", multiReceiveProtocol)
|
||||
net.runNetwork(1) // session handshaking
|
||||
// have the messages arrive in reverse order of receive
|
||||
node3.pumpReceive(false)
|
||||
node2.pumpReceive(false)
|
||||
net.runNetwork() // pump remaining messages
|
||||
assertThat(multiReceiveProtocol.receivedPayloads[0]).isEqualTo(node2Payload)
|
||||
assertThat(multiReceiveProtocol.receivedPayloads[1]).isEqualTo(node3Payload)
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `exception thrown on other side`() {
|
||||
node2.services.registerProtocolInitiator(ReceiveThenSuspendProtocol::class) { ExceptionProtocol }
|
||||
val future = node1.smm.add("exception", ReceiveThenSuspendProtocol(node2.info.identity)).resultFuture
|
||||
net.runNetwork()
|
||||
assertThatThrownBy { future.get() }.hasCauseInstanceOf(ProtocolSessionException::class.java)
|
||||
}
|
||||
|
||||
private fun isDataMessage(transfer: InMemoryMessagingNetwork.MessageTransfer): Boolean {
|
||||
return transfer.message.topicSession == StateMachineManager.sessionTopic
|
||||
&& transfer.message.data.deserialize<SessionMessage>() is SessionData
|
||||
}
|
||||
|
||||
private inline fun <reified P : NonTerminatingProtocol> MockNode.restartAndGetRestoredProtocol(networkMapAddress: SingleMessageRecipient? = null): P {
|
||||
val servicesArray = advertisedServices.toTypedArray()
|
||||
val node = mockNet.createNode(networkMapAddress, id, advertisedServices = *servicesArray)
|
||||
stop()
|
||||
val newNode = mockNet.createNode(networkMapAddress, id, advertisedServices = *advertisedServices.toTypedArray())
|
||||
mockNet.runNetwork() // allow NetworkMapService messages to stabilise and thus start the state machine
|
||||
return node.smm.findStateMachines(P::class.java).single().first
|
||||
return newNode.getSingleProtocol<P>().first
|
||||
}
|
||||
|
||||
private inline fun <reified P : ProtocolLogic<*>> MockNode.getSingleProtocol(): Pair<P, ListenableFuture<*>> {
|
||||
return smm.findStateMachines(P::class.java).single()
|
||||
}
|
||||
|
||||
|
||||
@ -165,8 +223,6 @@ class StateMachineManagerTests {
|
||||
override fun call() {
|
||||
protocolStarted = true
|
||||
}
|
||||
|
||||
override val topic: String get() = throw UnsupportedOperationException()
|
||||
}
|
||||
|
||||
private class ProtocolWithoutCheckpoints : NonTerminatingProtocol() {
|
||||
@ -177,8 +233,6 @@ class StateMachineManagerTests {
|
||||
override fun doCall() {
|
||||
protocolStarted = true
|
||||
}
|
||||
|
||||
override val topic: String get() = throw UnsupportedOperationException()
|
||||
}
|
||||
|
||||
|
||||
@ -187,30 +241,37 @@ class StateMachineManagerTests {
|
||||
val lazyTime by lazy { serviceHub.clock.instant() }
|
||||
|
||||
@Suspendable
|
||||
override fun call() {
|
||||
override fun call() = Unit
|
||||
}
|
||||
|
||||
|
||||
private class SendProtocol(val payload: Any, vararg val otherParties: Party) : ProtocolLogic<Unit>() {
|
||||
|
||||
init {
|
||||
require(otherParties.isNotEmpty())
|
||||
}
|
||||
|
||||
override val topic: String get() = throw UnsupportedOperationException()
|
||||
}
|
||||
|
||||
|
||||
private class SendProtocol(override val topic: String, val otherParty: Party, val payload: Any) : ProtocolLogic<Unit>() {
|
||||
@Suspendable
|
||||
override fun call() = send(otherParty, payload)
|
||||
override fun call() = otherParties.forEach { send(it, payload) }
|
||||
}
|
||||
|
||||
|
||||
private class ReceiveProtocol(override val topic: String, val otherParty: Party) : NonTerminatingProtocol() {
|
||||
private class ReceiveThenSuspendProtocol(vararg val otherParties: Party) : NonTerminatingProtocol() {
|
||||
|
||||
@Transient var receivedPayload: Any? = null
|
||||
init {
|
||||
require(otherParties.isNotEmpty())
|
||||
}
|
||||
|
||||
@Transient var receivedPayloads: List<Any> = emptyList()
|
||||
|
||||
@Suspendable
|
||||
override fun doCall() {
|
||||
receivedPayload = receive<Any>(otherParty).unwrap { it }
|
||||
receivedPayloads = otherParties.map { receive<Any>(it).unwrap { it } }
|
||||
}
|
||||
}
|
||||
|
||||
private class PingPongProtocol(override val topic: String, val otherParty: Party, val payload: Long) : ProtocolLogic<Unit>() {
|
||||
private class PingPongProtocol(val otherParty: Party, val payload: Long) : ProtocolLogic<Unit>() {
|
||||
|
||||
@Transient var receivedPayload: Long? = null
|
||||
@Transient var receivedPayload2: Long? = null
|
||||
|
||||
@ -219,7 +280,10 @@ class StateMachineManagerTests {
|
||||
receivedPayload = sendAndReceive<Long>(otherParty, payload).unwrap { it }
|
||||
receivedPayload2 = sendAndReceive<Long>(otherParty, (payload + 1)).unwrap { it }
|
||||
}
|
||||
}
|
||||
|
||||
private object ExceptionProtocol : ProtocolLogic<Nothing>() {
|
||||
override fun call(): Nothing = throw Exception()
|
||||
}
|
||||
|
||||
/**
|
||||
|
Reference in New Issue
Block a user