Automatic session management between two protocols, and removal of explict topics

This commit is contained in:
Shams Asari
2016-09-27 18:25:26 +01:00
parent 4da73e28c7
commit 67fdf9b2ff
54 changed files with 1055 additions and 1113 deletions

View File

@ -24,6 +24,7 @@ import com.r3corda.core.serialization.SingletonSerializeAsToken
import com.r3corda.core.serialization.deserialize
import com.r3corda.core.serialization.serialize
import com.r3corda.core.transactions.SignedTransaction
import com.r3corda.core.utilities.debug
import com.r3corda.node.api.APIServer
import com.r3corda.node.services.api.*
import com.r3corda.node.services.config.NodeConfiguration
@ -54,8 +55,10 @@ import java.nio.file.Path
import java.security.KeyPair
import java.time.Clock
import java.util.*
import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.ExecutorService
import java.util.concurrent.TimeUnit
import kotlin.reflect.KClass
/**
* A base node implementation that can be customised either for production (with real implementations that do real
@ -91,6 +94,8 @@ abstract class AbstractNode(val configuration: NodeConfiguration, val networkMap
protected val _servicesThatAcceptUploads = ArrayList<AcceptsFileUpload>()
val servicesThatAcceptUploads: List<AcceptsFileUpload> = _servicesThatAcceptUploads
private val protocolFactories = ConcurrentHashMap<Class<*>, (Party) -> ProtocolLogic<*>>()
val services = object : ServiceHubInternal() {
override val networkService: MessagingServiceInternal get() = net
override val networkMapCache: NetworkMapCache get() = netMapCache
@ -109,6 +114,16 @@ abstract class AbstractNode(val configuration: NodeConfiguration, val networkMap
return smm.add(loggerName, logic).resultFuture
}
override fun registerProtocolInitiator(markerClass: KClass<*>, protocolFactory: (Party) -> ProtocolLogic<*>) {
require(markerClass !in protocolFactories) { "${markerClass.java.name} has already been used to register a protocol" }
log.debug { "Registering ${markerClass.java.name}" }
protocolFactories[markerClass.java] = protocolFactory
}
override fun getProtocolFactory(markerClass: Class<*>): ((Party) -> ProtocolLogic<*>)? {
return protocolFactories[markerClass]
}
override fun recordTransactions(txs: Iterable<SignedTransaction>) = recordTransactionsInternal(storage, txs)
}

View File

@ -1,11 +1,9 @@
package com.r3corda.node.services
import com.r3corda.core.node.CordaPluginRegistry
import com.r3corda.node.services.api.AbstractNodeService
import com.r3corda.core.serialization.SingletonSerializeAsToken
import com.r3corda.node.services.api.ServiceHubInternal
import com.r3corda.protocols.AbstractStateReplacementProtocol
import com.r3corda.protocols.NotaryChangeProtocol
import com.r3corda.protocols.NotaryChangeProtocol.TOPIC
object NotaryChange {
class Plugin : CordaPluginRegistry() {
@ -16,11 +14,9 @@ object NotaryChange {
* A service that monitors the network for requests for changing the notary of a state,
* and immediately runs the [NotaryChangeProtocol] if the auto-accept criteria are met.
*/
class Service(services: ServiceHubInternal) : AbstractNodeService(services) {
class Service(services: ServiceHubInternal) : SingletonSerializeAsToken() {
init {
addProtocolHandler(TOPIC, TOPIC) { req: AbstractStateReplacementProtocol.Handshake ->
NotaryChangeProtocol.Acceptor(req.replyToParty)
}
services.registerProtocolInitiator(NotaryChangeProtocol.Instigator::class) { NotaryChangeProtocol.Acceptor(it) }
}
}
}

View File

@ -1,16 +1,12 @@
package com.r3corda.node.services.api
import com.google.common.util.concurrent.ListenableFuture
import com.r3corda.core.messaging.Message
import com.r3corda.core.messaging.MessageHandlerRegistration
import com.r3corda.core.messaging.createMessage
import com.r3corda.core.node.services.DEFAULT_SESSION_ID
import com.r3corda.core.protocols.ProtocolLogic
import com.r3corda.core.serialization.SingletonSerializeAsToken
import com.r3corda.core.serialization.deserialize
import com.r3corda.core.serialization.serialize
import com.r3corda.core.utilities.loggerFor
import com.r3corda.protocols.HandshakeMessage
import com.r3corda.protocols.ServiceRequestMessage
import javax.annotation.concurrent.ThreadSafe
@ -20,10 +16,6 @@ import javax.annotation.concurrent.ThreadSafe
@ThreadSafe
abstract class AbstractNodeService(val services: ServiceHubInternal) : SingletonSerializeAsToken() {
companion object {
val logger = loggerFor<AbstractNodeService>()
}
val net: MessagingServiceInternal get() = services.networkService
/**
@ -68,36 +60,4 @@ abstract class AbstractNodeService(val services: ServiceHubInternal) : Singleton
return addMessageHandler(topic, handler, { message: Message, exception: Exception -> throw exception })
}
/**
* Register a handler to kick-off a protocol when a [HandshakeMessage] is received by the node. This performs the
* necessary steps to enable communication between the two protocols, including calling ProtocolLogic.registerSession.
* @param topic the topic on which the handshake is sent from the other party
* @param loggerName the logger name to use when starting the protocol
* @param protocolFactory a function to create the protocol with the given handshake message
* @param onResultFuture provides access to the [ListenableFuture] when the protocol starts
*/
protected inline fun <reified H : HandshakeMessage, R : Any> addProtocolHandler(
topic: String,
loggerName: String,
crossinline protocolFactory: (H) -> ProtocolLogic<R>,
crossinline onResultFuture: ProtocolLogic<R>.(ListenableFuture<R>, H) -> Unit) {
net.addMessageHandler(topic, DEFAULT_SESSION_ID, null) { message, reg ->
try {
val handshake = message.data.deserialize<H>()
val protocol = protocolFactory(handshake)
protocol.registerSession(handshake)
val resultFuture = services.startProtocol(loggerName, protocol)
protocol.onResultFuture(resultFuture, handshake)
} catch (e: Exception) {
logger.error("Unable to process ${H::class.java.name} message", e)
}
}
}
protected inline fun <reified H : HandshakeMessage, R : Any> addProtocolHandler(
topic: String,
loggerName: String,
crossinline protocolFactory: (H) -> ProtocolLogic<R>) {
addProtocolHandler(topic, loggerName, protocolFactory, { future, handshake -> })
}
}

View File

@ -1,7 +1,7 @@
package com.r3corda.node.services.api
import com.r3corda.core.crypto.SecureHash
import com.r3corda.core.serialization.SerializedBytes
import com.r3corda.node.services.statemachine.ProtocolIORequest
import com.r3corda.node.services.statemachine.ProtocolStateMachineImpl
/**
@ -30,14 +30,13 @@ interface CheckpointStorage {
}
// This class will be serialised, so everything it points to transitively must also be serialisable (with Kryo).
data class Checkpoint(
val serialisedFiber: SerializedBytes<ProtocolStateMachineImpl<*>>,
val request: ProtocolIORequest?,
val receivedPayload: Any?
) {
// This flag is always false when loaded from storage as it isn't serialised.
// It is used to track when the associated fiber has been created, but not necessarily started when
// messages for protocols arrive before the system has fully loaded at startup.
@Transient
var fiberCreated: Boolean = false
}
class Checkpoint(val serialisedFiber: SerializedBytes<ProtocolStateMachineImpl<*>>) {
val id: SecureHash get() = serialisedFiber.hash
override fun equals(other: Any?): Boolean = other === this || other is Checkpoint && other.id == this.id
override fun hashCode(): Int = id.hashCode()
override fun toString(): String = "${javaClass.simpleName}(id=$id)"
}

View File

@ -1,14 +1,16 @@
package com.r3corda.node.services.api
import com.google.common.util.concurrent.ListenableFuture
import com.r3corda.core.transactions.SignedTransaction
import com.r3corda.core.crypto.Party
import com.r3corda.core.messaging.MessagingService
import com.r3corda.core.node.ServiceHub
import com.r3corda.core.node.services.TxWritableStorageService
import com.r3corda.core.protocols.ProtocolLogic
import com.r3corda.core.protocols.ProtocolLogicRefFactory
import com.r3corda.core.transactions.SignedTransaction
import com.r3corda.node.services.statemachine.ProtocolStateMachineImpl
import org.slf4j.LoggerFactory
import kotlin.reflect.KClass
interface MessagingServiceInternal : MessagingService {
/**
@ -49,7 +51,7 @@ abstract class ServiceHubInternal : ServiceHub {
* @param txs The transactions to record.
*/
internal fun recordTransactionsInternal(writableStorageService: TxWritableStorageService, txs: Iterable<SignedTransaction>) {
val stateMachineRunId = ProtocolStateMachineImpl.retrieveCurrentStateMachine()?.id
val stateMachineRunId = ProtocolStateMachineImpl.currentStateMachine()?.id
if (stateMachineRunId != null) {
txs.forEach {
storageService.stateMachineRecordedTransactionMapping.addMapping(stateMachineRunId, it.id)
@ -68,6 +70,23 @@ abstract class ServiceHubInternal : ServiceHub {
*/
abstract fun <T> startProtocol(loggerName: String, logic: ProtocolLogic<T>): ListenableFuture<T>
/**
* Register the protocol factory we wish to use when a initiating party attempts to communicate with us. The
* registration is done against a marker [KClass] which is sent in the session handsake by the other party. If this
* marker class has been registered then the corresponding factory will be used to create the protocol which will
* communicate with the other side. If there is no mapping then the session attempt is rejected.
* @param markerClass The marker [KClass] present in a session initiation attempt, which is a 1:1 mapping to a [Class]
* using the <pre>::class</pre> construct. Any marker class can be used, with the default being the class of the initiating
* protocol. This enables the registration to be of the form: registerProtocolInitiator(InitiatorProtocol::class, ::InitiatedProtocol)
* @param protocolFactory The protocol factory generating the initiated protocol.
*/
abstract fun registerProtocolInitiator(markerClass: KClass<*>, protocolFactory: (Party) -> ProtocolLogic<*>)
/**
* Return the protocol factory that has been registered with [markerClass], or null if no factory is found.
*/
abstract fun getProtocolFactory(markerClass: Class<*>): ((Party) -> ProtocolLogic<*>)?
override fun <T : Any> invokeProtocolAsync(logicType: Class<out ProtocolLogic<T>>, vararg args: Any?): ListenableFuture<T> {
val logicRef = protocolLogicRefFactory.create(logicType, *args)
@Suppress("UNCHECKED_CAST")

View File

@ -1,28 +1,25 @@
package com.r3corda.node.services.clientapi
import com.r3corda.core.node.CordaPluginRegistry
import com.r3corda.node.services.api.AbstractNodeService
import com.r3corda.core.serialization.SingletonSerializeAsToken
import com.r3corda.node.services.api.ServiceHubInternal
import com.r3corda.protocols.TwoPartyDealProtocol
import com.r3corda.protocols.TwoPartyDealProtocol.FIX_INITIATE_TOPIC
import com.r3corda.protocols.TwoPartyDealProtocol.FixingSessionInitiation
import com.r3corda.protocols.TwoPartyDealProtocol.Fixer
import com.r3corda.protocols.TwoPartyDealProtocol.Floater
/**
* This is a temporary handler required for establishing random sessionIDs for the [Fixer] and [Floater] as part of
* running scheduled fixings for the [InterestRateSwap] contract.
*
* TODO: This will be replaced with the automatic sessionID / session setup work.
* TODO: This will be replaced with the symmetric session work
*/
object FixingSessionInitiation {
class Plugin: CordaPluginRegistry() {
override val servicePlugins: List<Class<*>> = listOf(Service::class.java)
}
class Service(services: ServiceHubInternal) : AbstractNodeService(services) {
class Service(services: ServiceHubInternal) : SingletonSerializeAsToken() {
init {
addProtocolHandler(FIX_INITIATE_TOPIC, "fixings") { initiation: FixingSessionInitiation ->
TwoPartyDealProtocol.Fixer(initiation.replyToParty, initiation.oracleType)
}
services.registerProtocolInitiator(Floater::class) { Fixer(it) }
}
}
}

View File

@ -169,7 +169,7 @@ class NodeMonitorService(services: ServiceHubInternal, val smm: StateMachineMana
val tx = builder.toSignedTransaction(checkSufficientSignatures = false)
val protocol = FinalityProtocol(tx, setOf(req), setOf(req.recipient))
return TransactionBuildResult.ProtocolStarted(
smm.add(BroadcastTransactionProtocol.TOPIC, protocol).id,
smm.add("broadcast", protocol).id,
tx,
"Cash payment transaction generated"
)
@ -203,7 +203,7 @@ class NodeMonitorService(services: ServiceHubInternal, val smm: StateMachineMana
val tx = builder.toSignedTransaction(checkSufficientSignatures = false)
val protocol = FinalityProtocol(tx, setOf(req), participants)
return TransactionBuildResult.ProtocolStarted(
smm.add(BroadcastTransactionProtocol.TOPIC, protocol).id,
smm.add("broadcast", protocol).id,
tx,
"Cash destruction transaction generated"
)
@ -222,7 +222,7 @@ class NodeMonitorService(services: ServiceHubInternal, val smm: StateMachineMana
// Issuance transactions do not need to be notarised, so we can skip directly to broadcasting it
val protocol = BroadcastTransactionProtocol(tx, setOf(req), setOf(req.recipient))
return TransactionBuildResult.ProtocolStarted(
smm.add(BroadcastTransactionProtocol.TOPIC, protocol).id,
smm.add("broadcast", protocol).id,
tx,
"Cash issuance completed"
)

View File

@ -1,17 +1,12 @@
package com.r3corda.node.services.persistence
import co.paralleluniverse.fibers.Suspendable
import com.r3corda.core.crypto.Party
import com.r3corda.core.failure
import com.r3corda.core.messaging.MessagingService
import com.r3corda.core.messaging.TopicSession
import com.r3corda.core.node.CordaPluginRegistry
import com.r3corda.core.node.NodeInfo
import com.r3corda.core.node.recordTransactions
import com.r3corda.core.serialization.serialize
import com.r3corda.core.success
import com.r3corda.core.transactions.SignedTransaction
import com.r3corda.core.protocols.ProtocolLogic
import com.r3corda.core.serialization.SingletonSerializeAsToken
import com.r3corda.core.utilities.loggerFor
import com.r3corda.node.services.api.AbstractNodeService
import com.r3corda.node.services.api.ServiceHubInternal
import com.r3corda.protocols.*
import java.io.InputStream
@ -39,78 +34,73 @@ object DataVending {
// TODO: I don't like that this needs ServiceHubInternal, but passing in a state machine breaks MockServices because
// the state machine isn't set when this is constructed. [NodeSchedulerService] has the same problem, and both
// should be fixed at the same time.
class Service(services: ServiceHubInternal) : AbstractNodeService(services) {
class Service(services: ServiceHubInternal) : SingletonSerializeAsToken() {
companion object {
val logger = loggerFor<DataVending.Service>()
/**
* Notify a node of a transaction. Normally any notarisation required would happen before this is called.
*/
fun notify(net: MessagingService,
myIdentity: Party,
recipient: NodeInfo,
transaction: SignedTransaction) {
val msg = BroadcastTransactionProtocol.NotifyTxRequestMessage(transaction, emptySet(), myIdentity)
net.send(net.createMessage(TopicSession(BroadcastTransactionProtocol.TOPIC, 0), msg.serialize().bits), recipient.address)
}
}
val storage = services.storageService
class TransactionRejectedError(msg: String) : Exception(msg)
init {
addMessageHandler(FetchTransactionsProtocol.TOPIC,
{ req: FetchDataProtocol.Request -> handleTXRequest(req) },
{ message, e -> logger.error("Failure processing data vending request.", e) }
)
addMessageHandler(FetchAttachmentsProtocol.TOPIC,
{ req: FetchDataProtocol.Request -> handleAttachmentRequest(req) },
{ message, e -> logger.error("Failure processing data vending request.", e) }
)
// TODO: We should have a whitelist of contracts we're willing to accept at all, and reject if the transaction
// includes us in any outside that list. Potentially just if it includes any outside that list at all.
// TODO: Do we want to be able to reject specific transactions on more complex rules, for example reject incoming
// cash without from unknown parties?
addProtocolHandler(
BroadcastTransactionProtocol.TOPIC,
"Resolving transactions",
{ req: BroadcastTransactionProtocol.NotifyTxRequestMessage ->
ResolveTransactionsProtocol(req.tx, req.replyToParty)
},
{ future, req ->
future.success {
serviceHub.recordTransactions(req.tx)
}.failure { throwable ->
logger.warn("Received invalid transaction ${req.tx.id} from ${req.replyToParty}", throwable)
}
})
services.registerProtocolInitiator(FetchTransactionsProtocol::class, ::FetchTransactionsHandler)
services.registerProtocolInitiator(FetchAttachmentsProtocol::class, ::FetchAttachmentsHandler)
services.registerProtocolInitiator(BroadcastTransactionProtocol::class, ::NotifyTransactionHandler)
}
private fun handleTXRequest(req: FetchDataProtocol.Request): List<SignedTransaction?> {
require(req.hashes.isNotEmpty())
return req.hashes.map {
val tx = storage.validatedTransactions.getTransaction(it)
if (tx == null)
logger.info("Got request for unknown tx $it")
tx
private class FetchTransactionsHandler(val otherParty: Party) : ProtocolLogic<Unit>() {
@Suspendable
override fun call() {
val request = receive<FetchDataProtocol.Request>(otherParty).unwrap {
require(it.hashes.isNotEmpty())
it
}
val txs = request.hashes.map {
val tx = serviceHub.storageService.validatedTransactions.getTransaction(it)
if (tx == null)
logger.info("Got request for unknown tx $it")
tx
}
send(otherParty, txs)
}
}
private fun handleAttachmentRequest(req: FetchDataProtocol.Request): List<ByteArray?> {
// TODO: Use Artemis message streaming support here, called "large messages". This avoids the need to buffer.
require(req.hashes.isNotEmpty())
return req.hashes.map {
val jar: InputStream? = storage.attachments.openAttachment(it)?.open()
if (jar == null) {
logger.info("Got request for unknown attachment $it")
null
} else {
jar.readBytes()
// TODO: Use Artemis message streaming support here, called "large messages". This avoids the need to buffer.
private class FetchAttachmentsHandler(val otherParty: Party) : ProtocolLogic<Unit>() {
@Suspendable
override fun call() {
val request = receive<FetchDataProtocol.Request>(otherParty).unwrap {
require(it.hashes.isNotEmpty())
it
}
val attachments = request.hashes.map {
val jar: InputStream? = serviceHub.storageService.attachments.openAttachment(it)?.open()
if (jar == null) {
logger.info("Got request for unknown attachment $it")
null
} else {
jar.readBytes()
}
}
send(otherParty, attachments)
}
}
// TODO: We should have a whitelist of contracts we're willing to accept at all, and reject if the transaction
// includes us in any outside that list. Potentially just if it includes any outside that list at all.
// TODO: Do we want to be able to reject specific transactions on more complex rules, for example reject incoming
// cash without from unknown parties?
class NotifyTransactionHandler(val otherParty: Party) : ProtocolLogic<Unit>() {
@Suspendable
override fun call() {
val request = receive<BroadcastTransactionProtocol.NotifyTxRequest>(otherParty).unwrap { it }
subProtocol(ResolveTransactionsProtocol(request.tx, otherParty), shareParentSessions = true)
serviceHub.recordTransactions(request.tx)
}
}
}
}

View File

@ -1,53 +1,38 @@
package com.r3corda.node.services.statemachine
import com.r3corda.core.crypto.Party
import com.r3corda.core.messaging.TopicSession
import java.util.*
import com.r3corda.node.services.statemachine.StateMachineManager.ProtocolSession
import com.r3corda.node.services.statemachine.StateMachineManager.SessionMessage
// TODO revisit when Kotlin 1.1 is released and data classes can extend other classes
interface ProtocolIORequest {
// This is used to identify where we suspended, in case of message mismatch errors and other things where we
// don't have the original stack trace because it's in a suspended fiber.
val stackTraceInCaseOfProblems: StackSnapshot
val topic: String
val session: ProtocolSession
}
interface SendRequest : ProtocolIORequest {
val destination: Party
val payload: Any
val sendSessionID: Long
val uniqueMessageId: UUID
val message: SessionMessage
}
interface ReceiveRequest<T> : ProtocolIORequest {
interface ReceiveRequest<T : SessionMessage> : ProtocolIORequest {
val receiveType: Class<T>
val receiveSessionID: Long
val receiveTopicSession: TopicSession get() = TopicSession(topic, receiveSessionID)
}
data class SendAndReceive<T>(override val topic: String,
override val destination: Party,
override val payload: Any,
override val sendSessionID: Long,
override val uniqueMessageId: UUID,
override val receiveType: Class<T>,
override val receiveSessionID: Long) : SendRequest, ReceiveRequest<T> {
data class SendAndReceive<T : SessionMessage>(override val session: ProtocolSession,
override val message: SessionMessage,
override val receiveType: Class<T>) : SendRequest, ReceiveRequest<T> {
@Transient
override val stackTraceInCaseOfProblems: StackSnapshot = StackSnapshot()
}
data class ReceiveOnly<T>(override val topic: String,
override val receiveType: Class<T>,
override val receiveSessionID: Long) : ReceiveRequest<T> {
data class ReceiveOnly<T : SessionMessage>(override val session: ProtocolSession,
override val receiveType: Class<T>) : ReceiveRequest<T> {
@Transient
override val stackTraceInCaseOfProblems: StackSnapshot = StackSnapshot()
}
data class SendOnly(override val destination: Party,
override val topic: String,
override val payload: Any,
override val sendSessionID: Long,
override val uniqueMessageId: UUID) : SendRequest {
data class SendOnly(override val session: ProtocolSession, override val message: SessionMessage) : SendRequest {
@Transient
override val stackTraceInCaseOfProblems: StackSnapshot = StackSnapshot()
}

View File

@ -8,16 +8,22 @@ import com.google.common.util.concurrent.ListenableFuture
import com.google.common.util.concurrent.SettableFuture
import com.r3corda.core.crypto.Party
import com.r3corda.core.protocols.ProtocolLogic
import com.r3corda.core.protocols.ProtocolSessionException
import com.r3corda.core.protocols.ProtocolStateMachine
import com.r3corda.core.protocols.StateMachineRunId
import com.r3corda.core.random63BitValue
import com.r3corda.core.rootCause
import com.r3corda.core.utilities.UntrustworthyData
import com.r3corda.core.utilities.trace
import com.r3corda.node.services.api.ServiceHubInternal
import com.r3corda.node.services.statemachine.StateMachineManager.*
import com.r3corda.node.utilities.createDatabaseTransaction
import org.jetbrains.exposed.sql.Database
import org.jetbrains.exposed.sql.transactions.TransactionManager
import org.slf4j.Logger
import org.slf4j.LoggerFactory
import java.io.PrintWriter
import java.io.StringWriter
import java.sql.SQLException
import java.util.*
import java.util.concurrent.ExecutionException
@ -36,12 +42,26 @@ class ProtocolStateMachineImpl<R>(override val id: StateMachineRunId,
private val loggerName: String)
: Fiber<R>("protocol", scheduler), ProtocolStateMachine<R> {
companion object {
// Used to work around a small limitation in Quasar.
private val QUASAR_UNBLOCKER = run {
val field = Fiber::class.java.getDeclaredField("SERIALIZER_BLOCKER")
field.isAccessible = true
field.get(null)
}
/**
* Return the current [ProtocolStateMachineImpl] or null if executing outside of one.
*/
fun currentStateMachine(): ProtocolStateMachineImpl<*>? = Strand.currentStrand() as? ProtocolStateMachineImpl<*>
}
// These fields shouldn't be serialised, so they are marked @Transient.
@Transient lateinit override var serviceHub: ServiceHubInternal
@Transient internal lateinit var suspendAction: (ProtocolIORequest) -> Unit
@Transient internal lateinit var actionOnSuspend: (ProtocolIORequest) -> Unit
@Transient internal lateinit var actionOnEnd: () -> Unit
@Transient internal var receivedPayload: Any? = null
@Transient internal lateinit var database: Database
@Transient internal var fromCheckpoint: Boolean = false
@Transient private var _logger: Logger? = null
override val logger: Logger get() {
@ -62,18 +82,20 @@ class ProtocolStateMachineImpl<R>(override val id: StateMachineRunId,
}
}
internal val openSessions = HashMap<Pair<ProtocolLogic<*>, Party>, ProtocolSession>()
init {
logic.psm = this
name = id.toString()
}
@Suspendable @Suppress("UNCHECKED_CAST")
@Suspendable
override fun run(): R {
createTransaction()
val result = try {
logic.call()
} catch (t: Throwable) {
actionOnEnd()
_resultFuture?.setException(t)
processException(t)
commitTransaction()
throw ExecutionException(t)
}
@ -106,56 +128,140 @@ class ProtocolStateMachineImpl<R>(override val id: StateMachineRunId,
}
@Suspendable
private fun <T : Any> suspendAndExpectReceive(receiveRequest: ReceiveRequest<T>): UntrustworthyData<T> {
suspend(receiveRequest)
check(receivedPayload != null) { "Expected to receive something" }
val untrustworthy = UntrustworthyData(receiveRequest.receiveType.cast(receivedPayload))
receivedPayload = null
return untrustworthy
}
@Suspendable
override fun <T : Any> sendAndReceive(topic: String,
destination: Party,
sessionIDForSend: Long,
sessionIDForReceive: Long,
override fun <T : Any> sendAndReceive(otherParty: Party,
payload: Any,
receiveType: Class<T>): UntrustworthyData<T> {
return suspendAndExpectReceive(SendAndReceive(topic, destination, payload, sessionIDForSend, UUID.randomUUID(), receiveType, sessionIDForReceive))
receiveType: Class<T>,
sessionProtocol: ProtocolLogic<*>): UntrustworthyData<T> {
val session = getSession(otherParty, sessionProtocol)
val sendSessionData = createSessionData(session, payload)
val receivedSessionData = sendAndReceiveInternal(session, sendSessionData, SessionData::class.java)
return UntrustworthyData(receiveType.cast(receivedSessionData.payload))
}
@Suspendable
override fun <T : Any> receive(topic: String, sessionIDForReceive: Long, receiveType: Class<T>): UntrustworthyData<T> {
return suspendAndExpectReceive(ReceiveOnly(topic, receiveType, sessionIDForReceive))
override fun <T : Any> receive(otherParty: Party,
receiveType: Class<T>,
sessionProtocol: ProtocolLogic<*>): UntrustworthyData<T> {
val receivedSessionData = receiveInternal(getSession(otherParty, sessionProtocol), SessionData::class.java)
return UntrustworthyData(receiveType.cast(receivedSessionData.payload))
}
@Suspendable
override fun send(topic: String, destination: Party, sessionID: Long, payload: Any) {
suspend(SendOnly(destination, topic, payload, sessionID, UUID.randomUUID()))
override fun send(otherParty: Party, payload: Any, sessionProtocol: ProtocolLogic<*>) {
val session = getSession(otherParty, sessionProtocol)
val sendSessionData = createSessionData(session, payload)
sendInternal(session, sendSessionData)
}
private fun createSessionData(session: ProtocolSession, payload: Any): SessionData {
val otherPartySessionId = session.otherPartySessionId
?: throw IllegalStateException("We've somehow held onto an unconfirmed session: $session")
return SessionData(otherPartySessionId, payload)
}
@Suspendable
private fun suspend(protocolIORequest: ProtocolIORequest) {
private fun sendInternal(session: ProtocolSession, message: SessionMessage) {
suspend(SendOnly(session, message))
}
@Suspendable
private fun <T : SessionMessage> receiveInternal(session: ProtocolSession, receiveType: Class<T>): T {
return suspendAndExpectReceive(ReceiveOnly(session, receiveType))
}
@Suspendable
private fun <T : SessionMessage> sendAndReceiveInternal(session: ProtocolSession, message: SessionMessage, receiveType: Class<T>): T {
return suspendAndExpectReceive(SendAndReceive(session, message, receiveType))
}
@Suspendable
private fun getSession(otherParty: Party, sessionProtocol: ProtocolLogic<*>): ProtocolSession {
return openSessions[Pair(sessionProtocol, otherParty)] ?: startNewSession(otherParty, sessionProtocol)
}
@Suspendable
private fun startNewSession(otherParty: Party, sessionProtocol: ProtocolLogic<*>) : ProtocolSession {
val session = ProtocolSession(sessionProtocol, otherParty, random63BitValue(), null)
openSessions[Pair(sessionProtocol, otherParty)] = session
val counterpartyProtocol = sessionProtocol.getCounterpartyMarker(otherParty).name
val sessionInit = SessionInit(session.ourSessionId, serviceHub.storageService.myLegalIdentity, counterpartyProtocol)
val sessionInitResponse = sendAndReceiveInternal(session, sessionInit, SessionInitResponse::class.java)
if (sessionInitResponse is SessionConfirm) {
session.otherPartySessionId = sessionInitResponse.initiatedSessionId
return session
} else {
sessionInitResponse as SessionReject
throw ProtocolSessionException("Party $otherParty rejected session attempt: ${sessionInitResponse.errorMessage}")
}
}
@Suspendable
private fun <T : SessionMessage> suspendAndExpectReceive(receiveRequest: ReceiveRequest<T>): T {
fun getReceivedMessage(): ExistingSessionMessage? = receiveRequest.session.receivedMessages.poll()
val receivedMessage = getReceivedMessage() ?: run {
// Suspend while we wait for the receive
receiveRequest.session.waitingForResponse = true
suspend(receiveRequest)
receiveRequest.session.waitingForResponse = false
getReceivedMessage()
?: throw IllegalStateException("Was expecting a ${receiveRequest.receiveType.simpleName} but got nothing: $id $receiveRequest")
}
if (receivedMessage is SessionEnd) {
openSessions.values.remove(receiveRequest.session)
throw ProtocolSessionException("Counterparty on ${receiveRequest.session.otherParty} has prematurly ended")
} else if (receiveRequest.receiveType.isInstance(receivedMessage)) {
return receiveRequest.receiveType.cast(receivedMessage)
} else {
throw IllegalStateException("Was expecting a ${receiveRequest.receiveType.simpleName} but got $receivedMessage: $id $receiveRequest")
}
}
@Suspendable
private fun suspend(ioRequest: ProtocolIORequest) {
commitTransaction()
parkAndSerialize { fiber, serializer ->
logger.trace { "Suspended $id on $ioRequest" }
try {
suspendAction(protocolIORequest)
actionOnSuspend(ioRequest)
} catch (t: Throwable) {
// Do not throw exception again - Quasar completely bins it.
logger.warn("Captured exception which was swallowed by Quasar", t)
actionOnEnd()
_resultFuture?.setException(t)
// TODO When error handling is introduced, look into whether we should be deleting the checkpoint and
// completing the Future
processException(t)
}
}
createTransaction()
}
companion object {
/**
* Retrieves our state machine id if we are running a [ProtocolStateMachineImpl].
*/
fun retrieveCurrentStateMachine(): ProtocolStateMachineImpl<*>? {
return Strand.currentStrand() as? ProtocolStateMachineImpl<*>
private fun processException(t: Throwable) {
actionOnEnd()
_resultFuture?.setException(t)
}
internal fun resume(scheduler: FiberScheduler) {
try {
if (fromCheckpoint) {
logger.info("$id resumed from checkpoint")
fromCheckpoint = false
Fiber.unparkDeserialized(this, scheduler)
} else if (state == State.NEW) {
logger.trace { "$id started" }
start()
} else {
logger.trace { "$id resumed" }
Fiber.unpark(this, QUASAR_UNBLOCKER)
}
} catch (t: Throwable) {
logger.error("$id threw '${t.rootCause}'")
logger.trace {
val s = StringWriter()
t.rootCause.printStackTrace(PrintWriter(s))
"Stack trace of protocol error: $s"
}
}
}
}

View File

@ -3,34 +3,38 @@ package com.r3corda.node.services.statemachine
import co.paralleluniverse.fibers.Fiber
import co.paralleluniverse.fibers.FiberExecutorScheduler
import co.paralleluniverse.io.serialization.kryo.KryoSerializer
import co.paralleluniverse.strands.Strand
import com.codahale.metrics.Gauge
import com.esotericsoftware.kryo.Kryo
import com.google.common.base.Throwables
import com.google.common.util.concurrent.ListenableFuture
import com.r3corda.core.ThreadBox
import com.r3corda.core.abbreviate
import com.r3corda.core.crypto.Party
import com.r3corda.core.messaging.TopicSession
import com.r3corda.core.messaging.runOnNextMessage
import com.r3corda.core.messaging.send
import com.r3corda.core.protocols.ProtocolLogic
import com.r3corda.core.protocols.ProtocolStateMachine
import com.r3corda.core.protocols.StateMachineRunId
import com.r3corda.core.random63BitValue
import com.r3corda.core.serialization.*
import com.r3corda.core.then
import com.r3corda.core.utilities.ProgressTracker
import com.r3corda.core.utilities.debug
import com.r3corda.core.utilities.loggerFor
import com.r3corda.core.utilities.trace
import com.r3corda.node.services.api.Checkpoint
import com.r3corda.node.services.api.CheckpointStorage
import com.r3corda.node.services.api.ServiceHubInternal
import com.r3corda.node.utilities.AddOrRemove
import com.r3corda.node.utilities.AffinityExecutor
import kotlinx.support.jdk8.collections.removeIf
import org.jetbrains.exposed.sql.Database
import rx.Observable
import rx.subjects.PublishSubject
import rx.subjects.UnicastSubject
import java.io.PrintWriter
import java.io.StringWriter
import java.util.*
import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.ConcurrentLinkedQueue
import java.util.concurrent.ExecutionException
import javax.annotation.concurrent.ThreadSafe
@ -48,7 +52,6 @@ import javax.annotation.concurrent.ThreadSafe
* The SMM will always invoke the protocol fibers on the given [AffinityExecutor], regardless of which thread actually
* starts them via [add].
*
* TODO: Session IDs should be set up and propagated automatically, on demand.
* TODO: Consider the issue of continuation identity more deeply: is it a safe assumption that a serialised
* continuation is always unique?
* TODO: Think about how to bring the system to a clean stop so it can be upgraded without any serialised stacks on disk
@ -58,12 +61,19 @@ import javax.annotation.concurrent.ThreadSafe
* TODO: Implement stub/skel classes that provide a basic RPC framework on top of this.
*/
@ThreadSafe
class StateMachineManager(val serviceHub: ServiceHubInternal, tokenizableServices: List<Any>,
class StateMachineManager(val serviceHub: ServiceHubInternal,
tokenizableServices: List<Any>,
val checkpointStorage: CheckpointStorage,
val executor: AffinityExecutor,
val database: Database) {
inner class FiberScheduler : FiberExecutorScheduler("Same thread scheduler", executor)
companion object {
private val logger = loggerFor<StateMachineManager>()
internal val sessionTopic = TopicSession("platform.session")
}
val scheduler = FiberScheduler()
data class Change(
@ -95,6 +105,9 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, tokenizableService
private val totalStartedProtocols = metrics.counter("Protocols.Started")
private val totalFinishedProtocols = metrics.counter("Protocols.Finished")
private val openSessions = ConcurrentHashMap<Long, ProtocolSession>()
private val recentlyClosedSessions = ConcurrentHashMap<Long, Party>()
// Context for tokenized services in checkpoints
private val serializationContext = SerializeAsTokenContext(tokenizableServices, quasarKryo())
@ -119,6 +132,17 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, tokenizableService
val changes: Observable<Change>
get() = mutex.content.changesPublisher
init {
Fiber.setDefaultUncaughtExceptionHandler { fiber, throwable ->
(fiber as ProtocolStateMachineImpl<*>).logger.error("Caught exception from protocol", throwable)
}
}
fun start() {
restoreFibersFromCheckpoints()
serviceHub.networkMapCache.mapServiceRegistered.then(executor) { resumeRestoredFibers() }
}
/**
* Atomic get snapshot + subscribe. This is needed so we don't miss updates between subscriptions to [changes] and
* calls to [allStateMachines]
@ -131,69 +155,99 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, tokenizableService
}
}
// Used to work around a small limitation in Quasar.
private val QUASAR_UNBLOCKER = run {
val field = Fiber::class.java.getDeclaredField("SERIALIZER_BLOCKER")
field.isAccessible = true
field.get(null)
}
init {
Fiber.setDefaultUncaughtExceptionHandler { fiber, throwable ->
(fiber as ProtocolStateMachineImpl<*>).logger.error("Caught exception from protocol", throwable)
}
}
fun start() {
checkpointStorage.checkpoints.forEach { createFiberForCheckpoint(it) }
serviceHub.networkMapCache.mapServiceRegistered.then(executor) {
mutex.locked {
started = true
stateMachines.forEach { restartFiber(it.key, it.value) }
}
}
}
private fun createFiberForCheckpoint(checkpoint: Checkpoint) {
if (!checkpoint.fiberCreated) {
val fiber = deserializeFiber(checkpoint.serialisedFiber)
initFiber(fiber, { checkpoint })
}
}
private fun restartFiber(fiber: ProtocolStateMachineImpl<*>, checkpoint: Checkpoint) {
if (checkpoint.request is ReceiveRequest<*>) {
val topicSession = checkpoint.request.receiveTopicSession
fiber.logger.info("Restored ${fiber.logic} - it was previously waiting for message of type ${checkpoint.request.receiveType.name} on $topicSession")
iterateOnResponse(fiber, checkpoint.serialisedFiber, checkpoint.request) {
try {
Fiber.unparkDeserialized(fiber, scheduler)
} catch (e: Throwable) {
logError(e, it, topicSession, fiber)
private fun restoreFibersFromCheckpoints() {
mutex.locked {
checkpointStorage.checkpoints.forEach {
// If a protocol is added before start() then don't attempt to restore it
if (!stateMachines.containsValue(it)) {
val fiber = deserializeFiber(it.serialisedFiber)
initFiber(fiber)
stateMachines[fiber] = it
}
}
if (checkpoint.request is SendRequest) {
sendMessage(fiber, checkpoint.request)
}
}
private fun resumeRestoredFibers() {
mutex.locked {
started = true
stateMachines.keys.forEach { resumeRestoredFiber(it) }
}
serviceHub.networkService.addMessageHandler(sessionTopic, executor) { message, reg ->
executor.checkOnThread()
val sessionMessage = message.data.deserialize<SessionMessage>()
when (sessionMessage) {
is ExistingSessionMessage -> onExistingSessionMessage(sessionMessage)
is SessionInit -> onSessionInit(sessionMessage)
}
}
}
private fun resumeRestoredFiber(fiber: ProtocolStateMachineImpl<*>) {
fiber.openSessions.values.forEach { openSessions[it.ourSessionId] = it }
if (fiber.openSessions.values.any { it.waitingForResponse }) {
fiber.logger.info("Restored fiber pending on receive ${fiber.id}}")
} else {
resumeFiber(fiber)
}
}
private fun onExistingSessionMessage(message: ExistingSessionMessage) {
val session = openSessions[message.recipientSessionId]
if (session != null) {
session.psm.logger.trace { "${session.psm.id} received $message on $session" }
if (message is SessionEnd) {
openSessions.remove(message.recipientSessionId)
}
session.receivedMessages += message
if (session.waitingForResponse) {
updateCheckpoint(session.psm)
resumeFiber(session.psm)
}
} else {
fiber.logger.info("Restored ${fiber.logic} - it was not waiting on any message; received payload: ${checkpoint.receivedPayload.toString().abbreviate(50)}")
executor.executeASAP {
if (checkpoint.request is SendRequest) {
sendMessage(fiber, checkpoint.request)
}
iterateStateMachine(fiber, checkpoint.receivedPayload) {
try {
Fiber.unparkDeserialized(fiber, scheduler)
} catch (e: Throwable) {
logError(e, it, null, fiber)
}
val otherParty = recentlyClosedSessions.remove(message.recipientSessionId)
if (otherParty != null) {
if (message is SessionConfirm) {
logger.debug { "Received session confirmation but associated fiber has already terminated, so sending session end" }
sendSessionMessage(otherParty, SessionEnd(message.initiatedSessionId), null)
} else {
logger.trace { "Ignoring session end message for already closed session: $message" }
}
} else {
logger.warn("Received a session message for unknown session: $message")
}
}
}
private fun onSessionInit(sessionInit: SessionInit) {
logger.trace { "Received $sessionInit" }
//TODO Verify the other party are who they say they are from the TLS subsystem
val otherParty = sessionInit.initiatorParty
val otherPartySessionId = sessionInit.initiatorSessionId
try {
val markerClass = Class.forName(sessionInit.protocolName)
val protocolFactory = serviceHub.getProtocolFactory(markerClass)
if (protocolFactory != null) {
val protocol = protocolFactory(otherParty)
val psm = createFiber(sessionInit.protocolName, protocol)
val session = ProtocolSession(protocol, otherParty, random63BitValue(), otherPartySessionId)
openSessions[session.ourSessionId] = session
psm.openSessions[Pair(protocol, otherParty)] = session
updateCheckpoint(psm)
sendSessionMessage(otherParty, SessionConfirm(otherPartySessionId, session.ourSessionId), psm)
psm.logger.debug { "Starting new ${psm.id} from $sessionInit on $session" }
startFiber(psm)
} else {
logger.warn("Unknown protocol marker class in $sessionInit")
sendSessionMessage(otherParty, SessionReject(otherPartySessionId, "Don't know ${markerClass.name}"), null)
}
} catch (e: Exception) {
logger.warn("Received invalid $sessionInit", e)
sendSessionMessage(otherParty, SessionReject(otherPartySessionId, "Unable to establish session"), null)
}
}
private fun serializeFiber(fiber: ProtocolStateMachineImpl<*>): SerializedBytes<ProtocolStateMachineImpl<*>> {
// We don't use the passed-in serializer here, because we need to use our own augmented Kryo.
val kryo = quasarKryo()
// add the map of tokens -> tokenizedServices to the kyro context
SerializeAsTokenSerializer.setContext(kryo, serializationContext)
@ -204,7 +258,7 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, tokenizableService
val kryo = quasarKryo()
// put the map of token -> tokenized into the kryo context
SerializeAsTokenSerializer.setContext(kryo, serializationContext)
return serialisedFiber.deserialize(kryo)
return serialisedFiber.deserialize(kryo).apply { fromCheckpoint = true }
}
private fun quasarKryo(): Kryo {
@ -212,70 +266,51 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, tokenizableService
return createKryo(serializer.kryo)
}
private fun logError(e: Throwable, payload: Any?, topicSession: TopicSession?, psm: ProtocolStateMachineImpl<*>) {
psm.logger.error("Protocol state machine ${psm.javaClass.name} threw '${Throwables.getRootCause(e)}' " +
"when handling a message of type ${payload?.javaClass?.name} on queue $topicSession")
if (psm.logger.isTraceEnabled) {
val s = StringWriter()
Throwables.getRootCause(e).printStackTrace(PrintWriter(s))
psm.logger.trace("Stack trace of protocol error is: $s")
}
private fun <T> createFiber(loggerName: String, logic: ProtocolLogic<T>): ProtocolStateMachineImpl<T> {
val id = StateMachineRunId.createRandom()
return ProtocolStateMachineImpl(id, logic, scheduler, loggerName).apply { initFiber(this) }
}
private fun initFiber(psm: ProtocolStateMachineImpl<*>, startingCheckpoint: () -> Checkpoint): Checkpoint {
private fun initFiber(psm: ProtocolStateMachineImpl<*>) {
psm.database = database
psm.serviceHub = serviceHub
psm.suspendAction = { request ->
psm.logger.trace { "Suspended fiber ${psm.id} ${psm.logic}" }
onNextSuspend(psm, request)
psm.actionOnSuspend = { ioRequest ->
updateCheckpoint(psm)
processIORequest(ioRequest)
}
psm.actionOnEnd = {
psm.logic.progressTracker?.currentStep = ProgressTracker.DONE
mutex.locked {
val finalCheckpoint = stateMachines.remove(psm)
if (finalCheckpoint != null) {
checkpointStorage.removeCheckpoint(finalCheckpoint)
}
stateMachines.remove(psm)?.let { checkpointStorage.removeCheckpoint(it) }
totalFinishedProtocols.inc()
notifyChangeObservers(psm, AddOrRemove.REMOVE)
}
endAllFiberSessions(psm)
}
val checkpoint = startingCheckpoint()
checkpoint.fiberCreated = true
totalStartedProtocols.inc()
mutex.locked {
stateMachines[psm] = checkpoint
totalStartedProtocols.inc()
notifyChangeObservers(psm, AddOrRemove.ADD)
}
return checkpoint
}
/**
* Kicks off a brand new state machine of the given class. It will log with the named logger.
* The state machine will be persisted when it suspends, with automated restart if the StateMachineManager is
* restarted with checkpointed state machines in the storage service.
*/
fun <T> add(loggerName: String, logic: ProtocolLogic<T>): ProtocolStateMachine<T> {
val id = StateMachineRunId.createRandom()
val fiber = ProtocolStateMachineImpl(id, logic, scheduler, loggerName)
// Need to add before iterating in case of immediate completion
val checkpoint = initFiber(fiber) {
val checkpoint = Checkpoint(serializeFiber(fiber), null, null)
checkpoint
}
checkpointStorage.addCheckpoint(checkpoint)
mutex.locked { // If we are not started then our checkpoint will be picked up during start
if (!started) {
return fiber
}
}
try {
executor.executeASAP {
iterateStateMachine(fiber, null) {
fiber.start()
private fun endAllFiberSessions(psm: ProtocolStateMachineImpl<*>) {
openSessions.values.removeIf { session ->
if (session.psm == psm) {
val otherPartySessionId = session.otherPartySessionId
if (otherPartySessionId != null) {
sendSessionMessage(session.otherParty, SessionEnd(otherPartySessionId), psm)
}
recentlyClosedSessions[session.ourSessionId] = session.otherParty
true
} else {
false
}
}
}
private fun startFiber(fiber: ProtocolStateMachineImpl<*>) {
try {
resumeFiber(fiber)
} catch (e: ExecutionException) {
// There are two ways we can take exceptions in this method:
//
@ -290,17 +325,29 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, tokenizableService
if (e.cause !is ExecutionException)
throw e
}
}
/**
* Kicks off a brand new state machine of the given class. It will log with the named logger.
* The state machine will be persisted when it suspends, with automated restart if the StateMachineManager is
* restarted with checkpointed state machines in the storage service.
*/
fun <T> add(loggerName: String, logic: ProtocolLogic<T>): ProtocolStateMachine<T> {
val fiber = createFiber(loggerName, logic)
updateCheckpoint(fiber)
// If we are not started then our checkpoint will be picked up during start
mutex.locked {
if (started) {
startFiber(fiber)
}
}
return fiber
}
private fun updateCheckpoint(psm: ProtocolStateMachineImpl<*>,
serialisedFiber: SerializedBytes<ProtocolStateMachineImpl<*>>,
request: ProtocolIORequest?,
receivedPayload: Any?) {
val newCheckpoint = Checkpoint(serialisedFiber, request, receivedPayload)
val previousCheckpoint = mutex.locked {
stateMachines.put(psm, newCheckpoint)
}
private fun updateCheckpoint(psm: ProtocolStateMachineImpl<*>) {
check(psm.state != Strand.State.RUNNING) { "Fiber cannot be running when checkpointing" }
val newCheckpoint = Checkpoint(serializeFiber(psm))
val previousCheckpoint = mutex.locked { stateMachines.put(psm, newCheckpoint) }
if (previousCheckpoint != null) {
checkpointStorage.removeCheckpoint(previousCheckpoint)
}
@ -308,90 +355,70 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, tokenizableService
checkpointingMeter.mark()
}
private fun iterateStateMachine(psm: ProtocolStateMachineImpl<*>,
receivedPayload: Any?,
resumeAction: (Any?) -> Unit) {
executor.checkOnThread()
psm.receivedPayload = receivedPayload
psm.logger.trace { "Waking up fiber ${psm.id} ${psm.logic}" }
resumeAction(receivedPayload)
}
private fun onNextSuspend(psm: ProtocolStateMachineImpl<*>, request: ProtocolIORequest) {
val serialisedFiber = serializeFiber(psm)
updateCheckpoint(psm, serialisedFiber, request, null)
// We have a request to do something: send, receive, or send-and-receive.
if (request is ReceiveRequest<*>) {
// Prepare a listener on the network that runs in the background thread when we receive a message.
prepareToReceiveForRequest(psm, serialisedFiber, request)
}
if (request is SendRequest) {
performSendRequest(psm, request)
private fun resumeFiber(psm: ProtocolStateMachineImpl<*>) {
executor.executeASAP {
psm.resume(scheduler)
}
}
private fun prepareToReceiveForRequest(psm: ProtocolStateMachineImpl<*>, serialisedFiber: SerializedBytes<ProtocolStateMachineImpl<*>>, request: ReceiveRequest<*>) {
executor.checkOnThread()
val queueID = request.receiveTopicSession
psm.logger.trace { "Preparing to receive message of type ${request.receiveType.name} on queue $queueID" }
iterateOnResponse(psm, serialisedFiber, request) {
try {
Fiber.unpark(psm, QUASAR_UNBLOCKER)
} catch(e: Throwable) {
logError(e, it, queueID, psm)
private fun processIORequest(ioRequest: ProtocolIORequest) {
if (ioRequest is SendRequest) {
if (ioRequest.message is SessionInit) {
openSessions[ioRequest.session.ourSessionId] = ioRequest.session
}
sendSessionMessage(ioRequest.session.otherParty, ioRequest.message, ioRequest.session.psm)
if (ioRequest !is ReceiveRequest<*>) {
// We sent a message, but don't expect a response, so re-enter the continuation to let it keep going.
resumeFiber(ioRequest.session.psm)
}
}
}
private fun performSendRequest(psm: ProtocolStateMachineImpl<*>, request: SendRequest) {
val topicSession = sendMessage(psm, request)
private fun sendSessionMessage(party: Party, message: SessionMessage, psm: ProtocolStateMachineImpl<*>?) {
val node = serviceHub.networkMapCache.getNodeByLegalName(party.name)
?: throw IllegalArgumentException("Don't know about party $party")
val logger = psm?.logger ?: logger
logger.trace { "${psm?.id} sending $message to party $party" }
serviceHub.networkService.send(sessionTopic, message, node.address)
}
if (request is SendOnly) {
// We sent a message, but don't expect a response, so re-enter the continuation to let it keep going.
iterateStateMachine(psm, null) {
try {
Fiber.unpark(psm, QUASAR_UNBLOCKER)
} catch(e: Throwable) {
logError(e, request.payload, topicSession, psm)
}
}
interface SessionMessage
interface ExistingSessionMessage: SessionMessage {
val recipientSessionId: Long
}
data class SessionInit(val initiatorSessionId: Long, val initiatorParty: Party, val protocolName: String) : SessionMessage
interface SessionInitResponse : ExistingSessionMessage
data class SessionConfirm(val initiatorSessionId: Long, val initiatedSessionId: Long) : SessionInitResponse {
override val recipientSessionId: Long get() = initiatorSessionId
}
data class SessionReject(val initiatorSessionId: Long, val errorMessage: String) : SessionInitResponse {
override val recipientSessionId: Long get() = initiatorSessionId
}
data class SessionData(override val recipientSessionId: Long, val payload: Any) : ExistingSessionMessage {
override fun toString(): String {
return "${javaClass.simpleName}(recipientSessionId=$recipientSessionId, payload=${payload.toString().abbreviate(100)})"
}
}
private fun sendMessage(psm: ProtocolStateMachineImpl<*>, request: SendRequest): TopicSession {
val topicSession = TopicSession(request.topic, request.sendSessionID)
val payload = request.payload
psm.logger.trace { "Sending message of type ${payload.javaClass.name} using queue $topicSession to ${request.destination} (${payload.toString().abbreviate(50)})" }
val node = serviceHub.networkMapCache.getNodeByLegalName(request.destination.name) ?:
throw IllegalArgumentException("Don't know about ${request.destination} but trying to send a message of type ${payload.javaClass.name} on $topicSession (${payload.toString().abbreviate(50)})", request.stackTraceInCaseOfProblems)
serviceHub.networkService.send(topicSession, payload, node.address, request.uniqueMessageId)
return topicSession
data class SessionEnd(override val recipientSessionId: Long) : ExistingSessionMessage
data class ProtocolSession(val protocol: ProtocolLogic<*>,
val otherParty: Party,
val ourSessionId: Long,
var otherPartySessionId: Long?,
@Volatile var waitingForResponse: Boolean = false) {
val receivedMessages = ConcurrentLinkedQueue<ExistingSessionMessage>()
val psm: ProtocolStateMachineImpl<*> get() = protocol.psm as ProtocolStateMachineImpl<*>
}
/**
* Add a trigger to the [MessagingService] to deserialize the fiber and pass message content to it, once a message is
* received.
*/
private fun iterateOnResponse(psm: ProtocolStateMachineImpl<*>,
serialisedFiber: SerializedBytes<ProtocolStateMachineImpl<*>>,
request: ReceiveRequest<*>,
resumeAction: (Any?) -> Unit) {
val topicSession = request.receiveTopicSession
serviceHub.networkService.runOnNextMessage(topicSession, executor) { netMsg ->
// Assertion to ensure we don't execute on the wrong thread.
executor.checkOnThread()
// TODO: This is insecure: we should not deserialise whatever we find and *then* check.
// We should instead verify as we read the data that it's what we are expecting and throw as early as
// possible. We only do it this way for convenience during the prototyping stage. Note that this means
// we could simply not require the programmer to specify the expected return type at all, and catch it
// at the last moment when we do the downcast. However this would make protocol code harder to read and
// make it more difficult to migrate to a more explicit serialisation scheme later.
val payload = netMsg.data.deserialize<Any>()
check(request.receiveType.isInstance(payload)) { "Expected message of type ${request.receiveType.name} but got ${payload.javaClass.name}" }
// Update the fiber's checkpoint so that it's no longer waiting on a response, but rather has the received payload
updateCheckpoint(psm, serialisedFiber, null, payload)
psm.logger.trace { "Received message of type ${payload.javaClass.name} on $topicSession (${payload.toString().abbreviate(50)})" }
iterateStateMachine(psm, payload, resumeAction)
}
}
}

View File

@ -1,12 +1,11 @@
package com.r3corda.node.services.transactions
import com.r3corda.core.crypto.Party
import com.r3corda.core.node.services.ServiceType
import com.r3corda.core.node.services.TimestampChecker
import com.r3corda.core.node.services.UniquenessProvider
import com.r3corda.node.services.api.AbstractNodeService
import com.r3corda.core.serialization.SingletonSerializeAsToken
import com.r3corda.node.services.api.ServiceHubInternal
import com.r3corda.protocols.NotaryProtocol
import com.r3corda.protocols.NotaryProtocol.TOPIC
import kotlin.reflect.KClass
/**
* A Notary service acts as the final signer of a transaction ensuring two things:
@ -17,22 +16,18 @@ import com.r3corda.protocols.NotaryProtocol.TOPIC
*
* This is the base implementation that can be customised with specific Notary transaction commit protocol.
*/
abstract class NotaryService(services: ServiceHubInternal,
val timestampChecker: TimestampChecker,
val uniquenessProvider: UniquenessProvider) : AbstractNodeService(services) {
abstract class NotaryService(markerClass: KClass<out NotaryProtocol.Client>, services: ServiceHubInternal) : SingletonSerializeAsToken() {
// Do not specify this as an advertised service. Use a concrete implementation.
// TODO: We do not want a service type that cannot be used. Fix the type system abuse here.
object Type : ServiceType("corda.notary")
abstract val logger: org.slf4j.Logger
/** Implement a factory that specifies the transaction commit protocol for the notary service to use */
abstract val protocolFactory: NotaryProtocol.Factory
init {
addProtocolHandler(TOPIC, TOPIC) { req: NotaryProtocol.Handshake ->
protocolFactory.create(req.replyToParty, timestampChecker, uniquenessProvider)
}
services.registerProtocolInitiator(markerClass) { createProtocol(it) }
}
/** Implement a factory that specifies the transaction commit protocol for the notary service to use */
abstract fun createProtocol(otherParty: Party): NotaryProtocol.Service
}

View File

@ -1,5 +1,6 @@
package com.r3corda.node.services.transactions
import com.r3corda.core.crypto.Party
import com.r3corda.core.node.services.ServiceType
import com.r3corda.core.node.services.TimestampChecker
import com.r3corda.core.node.services.UniquenessProvider
@ -9,11 +10,13 @@ import com.r3corda.protocols.NotaryProtocol
/** A simple Notary service that does not perform transaction validation */
class SimpleNotaryService(services: ServiceHubInternal,
timestampChecker: TimestampChecker,
uniquenessProvider: UniquenessProvider) : NotaryService(services, timestampChecker, uniquenessProvider) {
val timestampChecker: TimestampChecker,
val uniquenessProvider: UniquenessProvider) : NotaryService(NotaryProtocol.Client::class, services) {
object Type : ServiceType("corda.notary.simple")
override val logger = loggerFor<SimpleNotaryService>()
override val protocolFactory = NotaryProtocol.DefaultFactory
override fun createProtocol(otherParty: Party): NotaryProtocol.Service {
return NotaryProtocol.Service(otherParty, timestampChecker, uniquenessProvider)
}
}

View File

@ -11,17 +11,13 @@ import com.r3corda.protocols.ValidatingNotaryProtocol
/** A Notary service that validates the transaction chain of he submitted transaction before committing it */
class ValidatingNotaryService(services: ServiceHubInternal,
timestampChecker: TimestampChecker,
uniquenessProvider: UniquenessProvider) : NotaryService(services, timestampChecker, uniquenessProvider) {
val timestampChecker: TimestampChecker,
val uniquenessProvider: UniquenessProvider) : NotaryService(NotaryProtocol.ValidatingClient::class, services) {
object Type : ServiceType("corda.notary.validating")
override val logger = loggerFor<ValidatingNotaryService>()
override val protocolFactory = object : NotaryProtocol.Factory {
override fun create(otherSide: Party,
timestampChecker: TimestampChecker,
uniquenessProvider: UniquenessProvider): NotaryProtocol.Service {
return ValidatingNotaryProtocol(otherSide, timestampChecker, uniquenessProvider)
}
override fun createProtocol(otherParty: Party): ValidatingNotaryProtocol {
return ValidatingNotaryProtocol(otherParty, timestampChecker, uniquenessProvider)
}
}

View File

@ -10,6 +10,7 @@ import com.r3corda.core.days
import com.r3corda.core.messaging.SingleMessageRecipient
import com.r3corda.core.node.ServiceHub
import com.r3corda.core.node.services.*
import com.r3corda.core.protocols.ProtocolStateMachine
import com.r3corda.core.protocols.StateMachineRunId
import com.r3corda.core.transactions.SignedTransaction
import com.r3corda.core.transactions.WireTransaction
@ -23,7 +24,6 @@ import com.r3corda.node.services.persistence.PerFileTransactionStorage
import com.r3corda.node.services.persistence.StorageServiceImpl
import com.r3corda.protocols.TwoPartyTradeProtocol.Buyer
import com.r3corda.protocols.TwoPartyTradeProtocol.Seller
import com.r3corda.protocols.TwoPartyTradeProtocol.TOPIC
import com.r3corda.testing.*
import com.r3corda.testing.node.InMemoryMessagingNetwork
import com.r3corda.testing.node.MockNetwork
@ -89,11 +89,11 @@ class TwoPartyTradeProtocolTests {
insertFakeTransactions(alicesFakePaper, aliceNode.services, aliceNode.storage.myLegalIdentityKey, notaryNode.storage.myLegalIdentityKey)
val (bobResult, aliceResult) = runBuyerAndSeller("alice's paper".outputStateAndRef())
val (bobPsm, aliceResult) = runBuyerAndSeller("alice's paper".outputStateAndRef())
// TODO: Verify that the result was inserted into the transaction database.
// assertEquals(bobResult.get(), aliceNode.storage.validatedTransactions[aliceResult.get().id])
assertEquals(aliceResult.get(), bobResult.get())
assertEquals(aliceResult.get(), bobPsm.get().resultFuture.get())
aliceNode.stop()
bobNode.stop()
@ -120,21 +120,19 @@ class TwoPartyTradeProtocolTests {
1200.DOLLARS `issued by` DUMMY_CASH_ISSUER, null).second
insertFakeTransactions(alicesFakePaper, aliceNode.services, aliceNode.storage.myLegalIdentityKey)
val aliceFuture = runBuyerAndSeller("alice's paper".outputStateAndRef()).sellerFuture
val aliceFuture = runBuyerAndSeller("alice's paper".outputStateAndRef()).sellerResult
// Everything is on this thread so we can now step through the protocol one step at a time.
// Seller Alice already sent a message to Buyer Bob. Pump once:
fun pumpAlice() = (aliceNode.net as InMemoryMessagingNetwork.InMemoryMessaging).pumpReceive(false)
fun pumpBob() = (bobNode.net as InMemoryMessagingNetwork.InMemoryMessaging).pumpReceive(false)
pumpBob()
bobNode.pumpReceive(false)
// Bob sends a couple of queries for the dependencies back to Alice. Alice reponds.
pumpAlice()
pumpBob()
pumpAlice()
pumpBob()
aliceNode.pumpReceive(false)
bobNode.pumpReceive(false)
aliceNode.pumpReceive(false)
bobNode.pumpReceive(false)
aliceNode.pumpReceive(false)
bobNode.pumpReceive(false)
// OK, now Bob has sent the partial transaction back to Alice and is waiting for Alice's signature.
assertThat(bobNode.checkpointStorage.checkpoints).hasSize(1)
@ -147,7 +145,7 @@ class TwoPartyTradeProtocolTests {
// Alice doesn't know that and carries on: she wants to know about the cash transactions he's trying to use.
// She will wait around until Bob comes back.
assertThat(pumpAlice()).isNotNull()
assertThat(aliceNode.pumpReceive(false)).isNotNull()
// ... bring the node back up ... the act of constructing the SMM will re-register the message handlers
// that Bob was waiting on before the reboot occurred.
@ -309,16 +307,16 @@ class TwoPartyTradeProtocolTests {
val attachmentID = attachment(ByteArrayInputStream(stream.toByteArray()))
val bobsFakeCash = fillUpForBuyer(false, bobNode.keyManagement.freshKey().public).second
val bobsSignedTxns = insertFakeTransactions(bobsFakeCash, bobNode.services)
insertFakeTransactions(bobsFakeCash, bobNode.services)
val alicesFakePaper = fillUpForSeller(false, aliceNode.storage.myLegalIdentity.owningKey,
1200.DOLLARS `issued by` DUMMY_CASH_ISSUER, attachmentID).second
val alicesSignedTxns = insertFakeTransactions(alicesFakePaper, aliceNode.services, aliceNode.storage.myLegalIdentityKey)
insertFakeTransactions(alicesFakePaper, aliceNode.services, aliceNode.storage.myLegalIdentityKey)
net.runNetwork() // Clear network map registration messages
val aliceTxStream = aliceNode.storage.validatedTransactions.track().second
val aliceTxMappings = aliceNode.storage.stateMachineRecordedTransactionMapping.track().second
val (bobResult, aliceResult, bobSmId, aliceSmId) = runBuyerAndSeller("alice's paper".outputStateAndRef())
val aliceSmId = runBuyerAndSeller("alice's paper".outputStateAndRef()).sellerId
net.runNetwork()
@ -367,21 +365,20 @@ class TwoPartyTradeProtocolTests {
}
}
data class RunResult(
val buyerFuture: Future<SignedTransaction>,
val sellerFuture: Future<SignedTransaction>,
val buyerSmId: StateMachineRunId,
val sellerSmId: StateMachineRunId
private data class RunResult(
// The buyer is not created immediately, only when the seller starts running
val buyer: Future<ProtocolStateMachine<SignedTransaction>>,
val sellerResult: Future<SignedTransaction>,
val sellerId: StateMachineRunId
)
private fun runBuyerAndSeller(assetToSell: StateAndRef<OwnableState>): RunResult {
val buyer = Buyer(aliceNode.info.identity, notaryNode.info.identity, 1000.DOLLARS, CommercialPaper.State::class.java)
private fun runBuyerAndSeller(assetToSell: StateAndRef<OwnableState>) : RunResult {
val buyerFuture = bobNode.initiateSingleShotProtocol(Seller::class) { otherParty ->
Buyer(otherParty, notaryNode.info.identity, 1000.DOLLARS, CommercialPaper.State::class.java)
}
val seller = Seller(bobNode.info.identity, notaryNode.info, assetToSell, 1000.DOLLARS, ALICE_KEY)
connectProtocols(buyer, seller)
// We start the Buyer first, as the Seller sends the first message
val buyerPsm = bobNode.smm.add("$TOPIC.buyer", buyer)
val sellerPsm = aliceNode.smm.add("$TOPIC.seller", seller)
return RunResult(buyerPsm.resultFuture, sellerPsm.resultFuture, buyerPsm.id, sellerPsm.id)
val sellerResultFuture = aliceNode.smm.add("seller", seller).resultFuture
return RunResult(buyerFuture, sellerResultFuture, seller.psm.id)
}
private fun LedgerDSL<TestTransactionDSLInterpreter, TestLedgerDSLInterpreter>.runWithError(
@ -404,7 +401,7 @@ class TwoPartyTradeProtocolTests {
net.runNetwork() // Clear network map registration messages
val (bobResult, aliceResult) = runBuyerAndSeller("alice's paper".outputStateAndRef())
val (bobPsm, aliceResult) = runBuyerAndSeller("alice's paper".outputStateAndRef())
net.runNetwork()
@ -412,7 +409,7 @@ class TwoPartyTradeProtocolTests {
if (bobError)
aliceResult.get()
else
bobResult.get()
bobPsm.get().resultFuture.get()
}
assertTrue(e.cause is TransactionVerificationException)
assertNotNull(e.cause!!.cause)
@ -506,6 +503,7 @@ class TwoPartyTradeProtocolTests {
return Pair(vault, listOf(ap))
}
class RecordingTransactionStorage(val delegate: TransactionStorage) : TransactionStorage {
override fun track(): Pair<List<SignedTransaction>, Observable<SignedTransaction>> {
return delegate.track()
@ -530,4 +528,5 @@ class TwoPartyTradeProtocolTests {
data class Add(val transaction: SignedTransaction) : TxRecord
data class Get(val id: SecureHash) : TxRecord
}
}

View File

@ -2,23 +2,24 @@ package com.r3corda.node.services
import com.codahale.metrics.MetricRegistry
import com.google.common.util.concurrent.ListenableFuture
import com.r3corda.core.transactions.SignedTransaction
import com.r3corda.core.crypto.Party
import com.r3corda.core.node.services.*
import com.r3corda.core.protocols.ProtocolLogic
import com.r3corda.core.protocols.ProtocolLogicRefFactory
import com.r3corda.core.protocols.StateMachineRunId
import com.r3corda.core.testing.InMemoryVaultService
import com.r3corda.core.transactions.SignedTransaction
import com.r3corda.node.serialization.NodeClock
import com.r3corda.node.services.api.MessagingServiceInternal
import com.r3corda.node.services.api.MonitoringService
import com.r3corda.node.services.api.ServiceHubInternal
import com.r3corda.testing.node.MockNetworkMapCache
import com.r3corda.node.services.network.NetworkMapService
import com.r3corda.node.services.persistence.DataVending
import com.r3corda.node.services.statemachine.StateMachineManager
import com.r3corda.core.testing.InMemoryVaultService
import com.r3corda.testing.node.MockStorageService
import com.r3corda.testing.MOCK_IDENTITY_SERVICE
import com.r3corda.testing.node.MockNetworkMapCache
import com.r3corda.testing.node.MockStorageService
import java.time.Clock
import java.util.concurrent.ConcurrentHashMap
import kotlin.reflect.KClass
@Suppress("LeakingThis")
open class MockServiceHubInternal(
@ -28,7 +29,6 @@ open class MockServiceHubInternal(
val identity: IdentityService? = MOCK_IDENTITY_SERVICE,
val storage: TxWritableStorageService? = MockStorageService(),
val mapCache: NetworkMapCache? = MockNetworkMapCache(),
val mapService: NetworkMapService? = null,
val scheduler: SchedulerService? = null,
val overrideClock: Clock? = NodeClock(),
val protocolFactory: ProtocolLogicRefFactory? = ProtocolLogicRefFactory()
@ -57,14 +57,10 @@ open class MockServiceHubInternal(
private val txStorageService: TxWritableStorageService
get() = storage ?: throw UnsupportedOperationException()
override fun recordTransactions(txs: Iterable<SignedTransaction>) = recordTransactionsInternal(txStorageService, txs)
private val protocolFactories = ConcurrentHashMap<Class<*>, (Party) -> ProtocolLogic<*>>()
lateinit var smm: StateMachineManager
override fun <T> startProtocol(loggerName: String, logic: ProtocolLogic<T>): ListenableFuture<T> {
return smm.add(loggerName, logic).resultFuture
}
init {
if (net != null && storage != null) {
// Creating this class is sufficient, we don't have to store it anywhere, because it registers a listener
@ -72,4 +68,18 @@ open class MockServiceHubInternal(
DataVending.Service(this)
}
}
override fun recordTransactions(txs: Iterable<SignedTransaction>) = recordTransactionsInternal(txStorageService, txs)
override fun <T> startProtocol(loggerName: String, logic: ProtocolLogic<T>): ListenableFuture<T> {
return smm.add(loggerName, logic).resultFuture
}
override fun registerProtocolInitiator(markerClass: KClass<*>, protocolFactory: (Party) -> ProtocolLogic<*>) {
protocolFactories[markerClass.java] = protocolFactory
}
override fun getProtocolFactory(markerClass: Class<*>): ((Party) -> ProtocolLogic<*>)? {
return protocolFactories[markerClass]
}
}

View File

@ -3,7 +3,6 @@ package com.r3corda.node.services
import com.google.common.jimfs.Configuration
import com.google.common.jimfs.Jimfs
import com.r3corda.core.contracts.*
import com.r3corda.core.crypto.SecureHash
import com.r3corda.core.days
import com.r3corda.core.node.ServiceHub
import com.r3corda.core.node.recordTransactions
@ -12,18 +11,16 @@ import com.r3corda.core.protocols.ProtocolLogicRef
import com.r3corda.core.protocols.ProtocolLogicRefFactory
import com.r3corda.core.serialization.SingletonSerializeAsToken
import com.r3corda.core.utilities.DUMMY_NOTARY
import com.r3corda.core.utilities.LogHelper
import com.r3corda.testing.node.TestClock
import com.r3corda.node.services.events.NodeSchedulerService
import com.r3corda.testing.node.InMemoryMessagingNetwork
import com.r3corda.node.services.persistence.PerFileCheckpointStorage
import com.r3corda.node.services.statemachine.StateMachineManager
import com.r3corda.node.services.vault.NodeVaultService
import com.r3corda.node.utilities.AddOrRemove
import com.r3corda.node.utilities.AffinityExecutor
import com.r3corda.node.utilities.configureDatabase
import com.r3corda.testing.ALICE_KEY
import com.r3corda.testing.node.InMemoryMessagingNetwork
import com.r3corda.testing.node.MockKeyManagementService
import com.r3corda.testing.node.TestClock
import com.r3corda.testing.node.makeTestDataSourceProperties
import org.assertj.core.api.Assertions.assertThat
import org.junit.After
@ -34,7 +31,9 @@ import java.nio.file.FileSystem
import java.security.PublicKey
import java.time.Clock
import java.time.Instant
import java.util.concurrent.*
import java.util.concurrent.CountDownLatch
import java.util.concurrent.Executors
import java.util.concurrent.TimeUnit
import kotlin.test.assertTrue
class NodeSchedulerServiceTest : SingletonSerializeAsToken() {
@ -128,8 +127,6 @@ class NodeSchedulerServiceTest : SingletonSerializeAsToken() {
(serviceHub as TestReference).testReference.calls += increment
(serviceHub as TestReference).testReference.countDown.countDown()
}
override val topic: String get() = throw UnsupportedOperationException()
}
class Command : TypeOnlyCommandData()

View File

@ -9,7 +9,6 @@ import com.r3corda.core.utilities.DUMMY_NOTARY_KEY
import com.r3corda.node.internal.AbstractNode
import com.r3corda.node.services.network.NetworkMapService
import com.r3corda.node.services.transactions.SimpleNotaryService
import com.r3corda.protocols.NotaryChangeProtocol
import com.r3corda.protocols.NotaryChangeProtocol.Instigator
import com.r3corda.protocols.StateReplacementException
import com.r3corda.protocols.StateReplacementRefused
@ -49,7 +48,7 @@ class NotaryChangeTests {
val state = issueState(clientNodeA)
val newNotary = newNotaryNode.info.identity
val protocol = Instigator(state, newNotary)
val future = clientNodeA.services.startProtocol(NotaryChangeProtocol.TOPIC, protocol)
val future = clientNodeA.services.startProtocol("notary-change", protocol)
net.runNetwork()
@ -62,7 +61,7 @@ class NotaryChangeTests {
val state = issueMultiPartyState(clientNodeA, clientNodeB)
val newNotary = newNotaryNode.info.identity
val protocol = Instigator(state, newNotary)
val future = clientNodeA.services.startProtocol(NotaryChangeProtocol.TOPIC, protocol)
val future = clientNodeA.services.startProtocol("notary-change", protocol)
net.runNetwork()
@ -78,7 +77,7 @@ class NotaryChangeTests {
val state = issueMultiPartyState(clientNodeA, clientNodeB)
val newEvilNotary = Party("Evil Notary", generateKeyPair().public)
val protocol = Instigator(state, newEvilNotary)
val future = clientNodeA.services.startProtocol(NotaryChangeProtocol.TOPIC, protocol)
val future = clientNodeA.services.startProtocol("notary-change", protocol)
net.runNetwork()

View File

@ -1,17 +1,19 @@
package com.r3corda.node.services
import com.r3corda.core.contracts.Timestamp
import com.google.common.util.concurrent.ListenableFuture
import com.r3corda.core.contracts.TransactionType
import com.r3corda.core.crypto.DigitalSignature
import com.r3corda.core.seconds
import com.r3corda.core.transactions.SignedTransaction
import com.r3corda.core.utilities.DUMMY_NOTARY
import com.r3corda.core.utilities.DUMMY_NOTARY_KEY
import com.r3corda.testing.node.MockNetwork
import com.r3corda.node.services.network.NetworkMapService
import com.r3corda.node.services.transactions.SimpleNotaryService
import com.r3corda.protocols.NotaryError
import com.r3corda.protocols.NotaryException
import com.r3corda.protocols.NotaryProtocol
import com.r3corda.testing.MINI_CORP_KEY
import com.r3corda.testing.node.MockNetwork
import org.junit.Before
import org.junit.Test
import java.time.Instant
@ -45,10 +47,7 @@ class NotaryServiceTests {
tx.toSignedTransaction(false)
}
val protocol = NotaryProtocol.Client(stx)
val future = clientNode.services.startProtocol(NotaryProtocol.TOPIC, protocol)
net.runNetwork()
val future = runNotaryClient(stx)
val signature = future.get()
signature.verifyWithECDSA(stx.txBits)
}
@ -61,10 +60,7 @@ class NotaryServiceTests {
tx.toSignedTransaction(false)
}
val protocol = NotaryProtocol.Client(stx)
val future = clientNode.services.startProtocol(NotaryProtocol.TOPIC, protocol)
net.runNetwork()
val future = runNotaryClient(stx)
val signature = future.get()
signature.verifyWithECDSA(stx.txBits)
}
@ -78,16 +74,13 @@ class NotaryServiceTests {
tx.toSignedTransaction(false)
}
val protocol = NotaryProtocol.Client(stx)
val future = clientNode.services.startProtocol(NotaryProtocol.TOPIC, protocol)
net.runNetwork()
val future = runNotaryClient(stx)
val ex = assertFailsWith(ExecutionException::class) { future.get() }
val error = (ex.cause as NotaryException).error
assertTrue(error is NotaryError.TimestampInvalid)
}
@Test fun `should report conflict for a duplicate transaction`() {
val stx = run {
val inputState = issueState(clientNode)
@ -98,8 +91,8 @@ class NotaryServiceTests {
val firstSpend = NotaryProtocol.Client(stx)
val secondSpend = NotaryProtocol.Client(stx)
clientNode.services.startProtocol("${NotaryProtocol.TOPIC}.first", firstSpend)
val future = clientNode.services.startProtocol("${NotaryProtocol.TOPIC}.second", secondSpend)
clientNode.services.startProtocol("notary.first", firstSpend)
val future = clientNode.services.startProtocol("notary.second", secondSpend)
net.runNetwork()
@ -108,4 +101,12 @@ class NotaryServiceTests {
assertEquals(notaryError.tx, stx.tx)
notaryError.conflict.verified()
}
private fun runNotaryClient(stx: SignedTransaction): ListenableFuture<DigitalSignature.LegallyIdentifiable> {
val protocol = NotaryProtocol.Client(stx)
val future = clientNode.services.startProtocol("notary-test", protocol)
net.runNetwork()
return future
}
}

View File

@ -1,8 +1,11 @@
package com.r3corda.node.services
import com.google.common.util.concurrent.ListenableFuture
import com.r3corda.core.contracts.Command
import com.r3corda.core.contracts.DummyContract
import com.r3corda.core.contracts.TransactionType
import com.r3corda.core.crypto.DigitalSignature
import com.r3corda.core.transactions.SignedTransaction
import com.r3corda.core.utilities.DUMMY_NOTARY
import com.r3corda.core.utilities.DUMMY_NOTARY_KEY
import com.r3corda.node.services.network.NetworkMapService
@ -44,9 +47,7 @@ class ValidatingNotaryServiceTests {
tx.toSignedTransaction(false)
}
val protocol = NotaryProtocol.Client(stx)
val future = clientNode.services.startProtocol(NotaryProtocol.TOPIC, protocol)
net.runNetwork()
val future = runValidatingClient(stx)
val ex = assertFailsWith(ExecutionException::class) { future.get() }
val notaryError = (ex.cause as NotaryException).error
@ -64,9 +65,7 @@ class ValidatingNotaryServiceTests {
tx.toSignedTransaction(false)
}
val protocol = NotaryProtocol.Client(stx)
val future = clientNode.services.startProtocol(NotaryProtocol.TOPIC, protocol)
net.runNetwork()
val future = runValidatingClient(stx)
val ex = assertFailsWith(ExecutionException::class) { future.get() }
val notaryError = (ex.cause as NotaryException).error
@ -75,4 +74,11 @@ class ValidatingNotaryServiceTests {
val missingKeys = (notaryError as NotaryError.SignaturesMissing).missingSigners
assertEquals(setOf(expectedMissingKey), missingKeys)
}
private fun runValidatingClient(stx: SignedTransaction): ListenableFuture<DigitalSignature.LegallyIdentifiable> {
val protocol = NotaryProtocol.ValidatingClient(stx)
val future = clientNode.services.startProtocol("notary", protocol)
net.runNetwork()
return future
}
}

View File

@ -1,13 +1,20 @@
package com.r3corda.node.services.persistence
import co.paralleluniverse.fibers.Suspendable
import com.r3corda.contracts.asset.Cash
import com.r3corda.core.contracts.Amount
import com.r3corda.core.contracts.Issued
import com.r3corda.core.contracts.TransactionType
import com.r3corda.core.contracts.USD
import com.r3corda.core.crypto.Party
import com.r3corda.core.protocols.ProtocolLogic
import com.r3corda.core.transactions.SignedTransaction
import com.r3corda.core.utilities.DUMMY_NOTARY
import com.r3corda.testing.node.MockNetwork
import com.r3corda.node.services.persistence.DataVending.Service.NotifyTransactionHandler
import com.r3corda.protocols.BroadcastTransactionProtocol.NotifyTxRequest
import com.r3corda.testing.MEGA_CORP
import com.r3corda.testing.node.MockNetwork
import com.r3corda.testing.node.MockNetwork.MockNode
import org.junit.Before
import org.junit.Test
import kotlin.test.assertEquals
@ -38,9 +45,8 @@ class DataVendingServiceTests {
ptx.signWith(registerNode.services.storageService.myLegalIdentityKey)
val tx = ptx.toSignedTransaction()
assertEquals(0, vaultServiceNode.services.vaultService.currentVault.states.toList().size)
DataVending.Service.notify(registerNode.net, registerNode.services.storageService.myLegalIdentity,
vaultServiceNode.info, tx)
network.runNetwork()
registerNode.sendNotifyTx(tx, vaultServiceNode)
// Check the transaction is in the receiving node
val actual = vaultServiceNode.services.vaultService.currentVault.states.singleOrNull()
@ -67,11 +73,23 @@ class DataVendingServiceTests {
ptx.signWith(registerNode.services.storageService.myLegalIdentityKey)
val tx = ptx.toSignedTransaction(false)
assertEquals(0, vaultServiceNode.services.vaultService.currentVault.states.toList().size)
DataVending.Service.notify(registerNode.net, registerNode.services.storageService.myLegalIdentity,
vaultServiceNode.info, tx)
network.runNetwork()
registerNode.sendNotifyTx(tx, vaultServiceNode)
// Check the transaction is not in the receiving node
assertEquals(0, vaultServiceNode.services.vaultService.currentVault.states.toList().size)
}
}
private fun MockNode.sendNotifyTx(tx: SignedTransaction, walletServiceNode: MockNode) {
walletServiceNode.services.registerProtocolInitiator(NotifyTxProtocol::class, ::NotifyTransactionHandler)
services.startProtocol("notify-tx", NotifyTxProtocol(walletServiceNode.info.identity, tx))
network.runNetwork()
}
private class NotifyTxProtocol(val otherParty: Party, val stx: SignedTransaction) : ProtocolLogic<Unit>() {
@Suspendable
override fun call() = send(otherParty, NotifyTxRequest(stx, emptySet()))
}
}

View File

@ -10,12 +10,14 @@ import org.assertj.core.api.Assertions.assertThatExceptionOfType
import org.junit.After
import org.junit.Before
import org.junit.Test
import java.nio.file.FileSystem
import java.nio.file.Files
import java.nio.file.Path
class PerFileCheckpointStorageTests {
val fileSystem = Jimfs.newFileSystem(unix())
val storeDir = fileSystem.getPath("store")
val fileSystem: FileSystem = Jimfs.newFileSystem(unix())
val storeDir: Path = fileSystem.getPath("store")
lateinit var checkpointStorage: PerFileCheckpointStorage
@Before
@ -92,6 +94,6 @@ class PerFileCheckpointStorageTests {
}
private var checkpointCount = 1
private fun newCheckpoint() = Checkpoint(SerializedBytes(Ints.toByteArray(checkpointCount++)), null, null)
private fun newCheckpoint() = Checkpoint(SerializedBytes(Ints.toByteArray(checkpointCount++)))
}

View File

@ -2,14 +2,20 @@ package com.r3corda.node.services.statemachine
import co.paralleluniverse.fibers.Fiber
import co.paralleluniverse.fibers.Suspendable
import com.google.common.util.concurrent.ListenableFuture
import com.r3corda.core.crypto.Party
import com.r3corda.core.messaging.SingleMessageRecipient
import com.r3corda.core.protocols.ProtocolLogic
import com.r3corda.core.protocols.ProtocolSessionException
import com.r3corda.core.random63BitValue
import com.r3corda.testing.connectProtocols
import com.r3corda.core.serialization.deserialize
import com.r3corda.node.services.statemachine.StateMachineManager.SessionData
import com.r3corda.node.services.statemachine.StateMachineManager.SessionMessage
import com.r3corda.testing.node.InMemoryMessagingNetwork
import com.r3corda.testing.node.MockNetwork
import com.r3corda.testing.node.MockNetwork.MockNode
import org.assertj.core.api.Assertions.assertThat
import org.assertj.core.api.Assertions.assertThatThrownBy
import org.junit.After
import org.junit.Before
import org.junit.Test
@ -50,18 +56,18 @@ class StateMachineManagerTests {
}
@Test
fun `protocol suspended just after receiving payload`() {
val topic = "send-and-receive"
fun `protocol restarted just after receiving payload`() {
node2.services.registerProtocolInitiator(SendProtocol::class) { ReceiveThenSuspendProtocol(it) }
val payload = random63BitValue()
val sendProtocol = SendProtocol(topic, node2.info.identity, payload)
val receiveProtocol = ReceiveProtocol(topic, node1.info.identity)
connectProtocols(sendProtocol, receiveProtocol)
node1.smm.add("test", sendProtocol)
node2.smm.add("test", receiveProtocol)
net.runNetwork()
node1.smm.add("test", SendProtocol(payload, node2.info.identity))
// We push through just enough messages to get only the SessionData sent
// TODO We should be able to give runNetwork a predicate for when to stop
net.runNetwork(2)
node2.stop()
val restoredProtocol = node2.restartAndGetRestoredProtocol<ReceiveProtocol>(node1.info.address)
assertThat(restoredProtocol.receivedPayload).isEqualTo(payload)
net.runNetwork()
val restoredProtocol = node2.restartAndGetRestoredProtocol<ReceiveThenSuspendProtocol>(node1.info.address)
assertThat(restoredProtocol.receivedPayloads[0]).isEqualTo(payload)
}
@Test
@ -83,7 +89,7 @@ class StateMachineManagerTests {
node3.stop()
node3 = net.createNode(node1.info.address, forcedID = node3.id)
val restoredProtocol = node3.smm.findStateMachines(ProtocolNoBlocking::class.java).single().first
val restoredProtocol = node3.getSingleProtocol<ProtocolNoBlocking>().first
assertEquals(false, restoredProtocol.protocolStarted) // Not started yet as no network activity has been allowed yet
net.runNetwork() // Allow network map messages to flow
node3.smm.executor.flush()
@ -99,43 +105,44 @@ class StateMachineManagerTests {
@Test
fun `protocol loaded from checkpoint will respond to messages from before start`() {
val topic = "send-and-receive"
val payload = random63BitValue()
val sendProtocol = SendProtocol(topic, node2.info.identity, payload)
val receiveProtocol = ReceiveProtocol(topic, node1.info.identity)
connectProtocols(sendProtocol, receiveProtocol)
node1.services.registerProtocolInitiator(ReceiveThenSuspendProtocol::class) { SendProtocol(payload, it) }
val receiveProtocol = ReceiveThenSuspendProtocol(node1.info.identity)
node2.smm.add("test", receiveProtocol) // Prepare checkpointed receive protocol
node2.stop() // kill receiver
node1.smm.add("test", sendProtocol) // now generate message to spool up and thus come in ahead of messages for NetworkMapService
val restoredProtocol = node2.restartAndGetRestoredProtocol<ReceiveProtocol>(node1.info.address)
assertThat(restoredProtocol.receivedPayload).isEqualTo(payload)
val restoredProtocol = node2.restartAndGetRestoredProtocol<ReceiveThenSuspendProtocol>(node1.info.address)
assertThat(restoredProtocol.receivedPayloads[0]).isEqualTo(payload)
}
@Test
fun `protocol with send will resend on interrupted restart`() {
val topic = "send-and-receive"
val payload = random63BitValue()
val payload2 = random63BitValue()
var sentCount = 0
var receivedCount = 0
net.messagingNetwork.sentMessages.subscribe { if (it.message.topicSession.topic == topic) sentCount++ }
net.messagingNetwork.receivedMessages.subscribe { if (it.message.topicSession.topic == topic) receivedCount++ }
net.messagingNetwork.sentMessages.subscribe { if (isDataMessage(it)) sentCount++ }
net.messagingNetwork.receivedMessages.subscribe { if (isDataMessage(it)) receivedCount++ }
val node3 = net.createNode(node1.info.address)
net.runNetwork()
val firstProtocol = PingPongProtocol(topic, node3.info.identity, payload)
val secondProtocol = PingPongProtocol(topic, node2.info.identity, payload2)
connectProtocols(firstProtocol, secondProtocol)
var secondProtocol: PingPongProtocol? = null
node3.services.registerProtocolInitiator(PingPongProtocol::class) {
val protocol = PingPongProtocol(it, payload2)
secondProtocol = protocol
protocol
}
// Kick off first send and receive
node2.smm.add("test", firstProtocol)
node2.smm.add("test", PingPongProtocol(node3.info.identity, payload))
assertEquals(1, node2.checkpointStorage.checkpoints.count())
// Restart node and thus reload the checkpoint and resend the message with same UUID
node2.stop()
val node2b = net.createNode(node1.info.address, node2.id, advertisedServices = *node2.advertisedServices.toTypedArray())
val (firstAgain, fut1) = node2b.smm.findStateMachines(PingPongProtocol::class.java).single()
val (firstAgain, fut1) = node2b.getSingleProtocol<PingPongProtocol>()
net.runNetwork()
assertEquals(1, node2.checkpointStorage.checkpoints.count())
// Now add in the other half of the protocol. First message should get deduped. So message data stays in sync.
node3.smm.add("test", secondProtocol)
// Run the network which will also fire up the second protocol. First message should get deduped. So message data stays in sync.
net.runNetwork()
node2b.smm.executor.flush()
fut1.get()
@ -146,15 +153,66 @@ class StateMachineManagerTests {
assertEquals(0, node3.checkpointStorage.checkpoints.count(), "Checkpoints left after restored protocol should have ended")
assertEquals(payload2, firstAgain.receivedPayload, "Received payload does not match the first value on Node 3")
assertEquals(payload2 + 1, firstAgain.receivedPayload2, "Received payload does not match the expected second value on Node 3")
assertEquals(payload, secondProtocol.receivedPayload, "Received payload does not match the (restarted) first value on Node 2")
assertEquals(payload + 1, secondProtocol.receivedPayload2, "Received payload does not match the expected second value on Node 2")
assertEquals(payload, secondProtocol!!.receivedPayload, "Received payload does not match the (restarted) first value on Node 2")
assertEquals(payload + 1, secondProtocol!!.receivedPayload2, "Received payload does not match the expected second value on Node 2")
}
@Test
fun `sending to multiple parties`() {
val node3 = net.createNode(node1.info.address)
net.runNetwork()
node2.services.registerProtocolInitiator(SendProtocol::class) { ReceiveThenSuspendProtocol(it) }
node3.services.registerProtocolInitiator(SendProtocol::class) { ReceiveThenSuspendProtocol(it) }
val payload = random63BitValue()
node1.smm.add("multiple-send", SendProtocol(payload, node2.info.identity, node3.info.identity))
net.runNetwork()
val node2Protocol = node2.getSingleProtocol<ReceiveThenSuspendProtocol>().first
val node3Protocol = node3.getSingleProtocol<ReceiveThenSuspendProtocol>().first
assertThat(node2Protocol.receivedPayloads[0]).isEqualTo(payload)
assertThat(node3Protocol.receivedPayloads[0]).isEqualTo(payload)
}
@Test
fun `receiving from multiple parties`() {
val node3 = net.createNode(node1.info.address)
net.runNetwork()
val node2Payload = random63BitValue()
val node3Payload = random63BitValue()
node2.services.registerProtocolInitiator(ReceiveThenSuspendProtocol::class) { SendProtocol(node2Payload, it) }
node3.services.registerProtocolInitiator(ReceiveThenSuspendProtocol::class) { SendProtocol(node3Payload, it) }
val multiReceiveProtocol = ReceiveThenSuspendProtocol(node2.info.identity, node3.info.identity)
node1.smm.add("multiple-receive", multiReceiveProtocol)
net.runNetwork(1) // session handshaking
// have the messages arrive in reverse order of receive
node3.pumpReceive(false)
node2.pumpReceive(false)
net.runNetwork() // pump remaining messages
assertThat(multiReceiveProtocol.receivedPayloads[0]).isEqualTo(node2Payload)
assertThat(multiReceiveProtocol.receivedPayloads[1]).isEqualTo(node3Payload)
}
@Test
fun `exception thrown on other side`() {
node2.services.registerProtocolInitiator(ReceiveThenSuspendProtocol::class) { ExceptionProtocol }
val future = node1.smm.add("exception", ReceiveThenSuspendProtocol(node2.info.identity)).resultFuture
net.runNetwork()
assertThatThrownBy { future.get() }.hasCauseInstanceOf(ProtocolSessionException::class.java)
}
private fun isDataMessage(transfer: InMemoryMessagingNetwork.MessageTransfer): Boolean {
return transfer.message.topicSession == StateMachineManager.sessionTopic
&& transfer.message.data.deserialize<SessionMessage>() is SessionData
}
private inline fun <reified P : NonTerminatingProtocol> MockNode.restartAndGetRestoredProtocol(networkMapAddress: SingleMessageRecipient? = null): P {
val servicesArray = advertisedServices.toTypedArray()
val node = mockNet.createNode(networkMapAddress, id, advertisedServices = *servicesArray)
stop()
val newNode = mockNet.createNode(networkMapAddress, id, advertisedServices = *advertisedServices.toTypedArray())
mockNet.runNetwork() // allow NetworkMapService messages to stabilise and thus start the state machine
return node.smm.findStateMachines(P::class.java).single().first
return newNode.getSingleProtocol<P>().first
}
private inline fun <reified P : ProtocolLogic<*>> MockNode.getSingleProtocol(): Pair<P, ListenableFuture<*>> {
return smm.findStateMachines(P::class.java).single()
}
@ -165,8 +223,6 @@ class StateMachineManagerTests {
override fun call() {
protocolStarted = true
}
override val topic: String get() = throw UnsupportedOperationException()
}
private class ProtocolWithoutCheckpoints : NonTerminatingProtocol() {
@ -177,8 +233,6 @@ class StateMachineManagerTests {
override fun doCall() {
protocolStarted = true
}
override val topic: String get() = throw UnsupportedOperationException()
}
@ -187,30 +241,37 @@ class StateMachineManagerTests {
val lazyTime by lazy { serviceHub.clock.instant() }
@Suspendable
override fun call() {
override fun call() = Unit
}
private class SendProtocol(val payload: Any, vararg val otherParties: Party) : ProtocolLogic<Unit>() {
init {
require(otherParties.isNotEmpty())
}
override val topic: String get() = throw UnsupportedOperationException()
}
private class SendProtocol(override val topic: String, val otherParty: Party, val payload: Any) : ProtocolLogic<Unit>() {
@Suspendable
override fun call() = send(otherParty, payload)
override fun call() = otherParties.forEach { send(it, payload) }
}
private class ReceiveProtocol(override val topic: String, val otherParty: Party) : NonTerminatingProtocol() {
private class ReceiveThenSuspendProtocol(vararg val otherParties: Party) : NonTerminatingProtocol() {
@Transient var receivedPayload: Any? = null
init {
require(otherParties.isNotEmpty())
}
@Transient var receivedPayloads: List<Any> = emptyList()
@Suspendable
override fun doCall() {
receivedPayload = receive<Any>(otherParty).unwrap { it }
receivedPayloads = otherParties.map { receive<Any>(it).unwrap { it } }
}
}
private class PingPongProtocol(override val topic: String, val otherParty: Party, val payload: Long) : ProtocolLogic<Unit>() {
private class PingPongProtocol(val otherParty: Party, val payload: Long) : ProtocolLogic<Unit>() {
@Transient var receivedPayload: Long? = null
@Transient var receivedPayload2: Long? = null
@ -219,7 +280,10 @@ class StateMachineManagerTests {
receivedPayload = sendAndReceive<Long>(otherParty, payload).unwrap { it }
receivedPayload2 = sendAndReceive<Long>(otherParty, (payload + 1)).unwrap { it }
}
}
private object ExceptionProtocol : ProtocolLogic<Nothing>() {
override fun call(): Nothing = throw Exception()
}
/**