mirror of
https://github.com/corda/corda.git
synced 2025-02-20 09:26:41 +00:00
Messaging: another minor simplification to the state machines framework. This is not the end!
This commit is contained in:
parent
9720d4e404
commit
3f19e68b3f
@ -8,7 +8,6 @@
|
||||
|
||||
package contracts.protocols
|
||||
|
||||
import com.google.common.util.concurrent.ListenableFuture
|
||||
import contracts.Cash
|
||||
import contracts.sumCashBy
|
||||
import core.*
|
||||
@ -49,18 +48,14 @@ abstract class TwoPartyTradeProtocol {
|
||||
val myKeyPair: KeyPair
|
||||
)
|
||||
|
||||
abstract fun runSeller(otherSide: SingleMessageRecipient,
|
||||
args: SellerInitialArgs): ListenableFuture<out Pair<TimestampedWireTransaction, LedgerTransaction>>
|
||||
abstract fun runSeller(otherSide: SingleMessageRecipient, args: SellerInitialArgs): Seller
|
||||
|
||||
class BuyerInitialArgs(
|
||||
val acceptablePrice: Amount,
|
||||
val typeToBuy: Class<out OwnableState>
|
||||
)
|
||||
|
||||
abstract fun runBuyer(
|
||||
otherSide: SingleMessageRecipient,
|
||||
args: BuyerInitialArgs
|
||||
): ListenableFuture<out Pair<TimestampedWireTransaction, LedgerTransaction>>
|
||||
abstract fun runBuyer(otherSide: SingleMessageRecipient, args: BuyerInitialArgs): Buyer
|
||||
|
||||
abstract class Buyer : ProtocolStateMachine<BuyerInitialArgs, Pair<TimestampedWireTransaction, LedgerTransaction>>()
|
||||
abstract class Seller : ProtocolStateMachine<SellerInitialArgs, Pair<TimestampedWireTransaction, LedgerTransaction>>()
|
||||
@ -203,11 +198,11 @@ private class TwoPartyTradeProtocolImpl(private val smm: StateMachineManager) :
|
||||
}
|
||||
}
|
||||
|
||||
override fun runSeller(otherSide: SingleMessageRecipient, args: SellerInitialArgs): ListenableFuture<out Pair<TimestampedWireTransaction, LedgerTransaction>> {
|
||||
override fun runSeller(otherSide: SingleMessageRecipient, args: SellerInitialArgs): Seller {
|
||||
return smm.add(otherSide, args, "$TRADE_TOPIC.seller", SellerImpl::class.java)
|
||||
}
|
||||
|
||||
override fun runBuyer(otherSide: SingleMessageRecipient, args: BuyerInitialArgs): ListenableFuture<out Pair<TimestampedWireTransaction, LedgerTransaction>> {
|
||||
override fun runBuyer(otherSide: SingleMessageRecipient, args: BuyerInitialArgs): Buyer {
|
||||
return smm.add(otherSide, args, "$TRADE_TOPIC.buyer", BuyerImpl::class.java)
|
||||
}
|
||||
}
|
@ -35,7 +35,7 @@ import java.util.concurrent.Executor
|
||||
* and, if run with a single-threaded executor, will ensure no two state machines run concurrently with each other
|
||||
* (bad for performance, good for programmer mental health!).
|
||||
*
|
||||
* TODO: The framework should do automatic error handling.
|
||||
* TODO: The framework should propagate exceptions and handle error handling automatically.
|
||||
*/
|
||||
class StateMachineManager(val serviceHub: ServiceHub, val runInThread: Executor) {
|
||||
private val checkpointsMap = serviceHub.storageService.getMap<SecureHash, ByteArray>("state machines")
|
||||
@ -95,8 +95,8 @@ class StateMachineManager(val serviceHub: ServiceHub, val runInThread: Executor)
|
||||
}
|
||||
}
|
||||
|
||||
fun <T, R> add(otherSide: MessageRecipients, initialArgs: T, loggerName: String,
|
||||
continuationClass: Class<out ProtocolStateMachine<*, R>>): ListenableFuture<out R> {
|
||||
fun <T : ProtocolStateMachine<I, *>, I> add(otherSide: MessageRecipients, initialArgs: I, loggerName: String,
|
||||
continuationClass: Class<out T>): T {
|
||||
val logger = LoggerFactory.getLogger(loggerName)
|
||||
val (sm, continuation) = loadContinuationClass(continuationClass)
|
||||
sm.serviceHub = serviceHub
|
||||
@ -107,7 +107,7 @@ class StateMachineManager(val serviceHub: ServiceHub, val runInThread: Executor)
|
||||
iterateStateMachine(continuation, serviceHub.networkService, otherSide, initialArgs, logger, null)
|
||||
}
|
||||
@Suppress("UNCHECKED_CAST")
|
||||
return (sm as ProtocolStateMachine<T, R>).resultFuture
|
||||
return sm as T
|
||||
}
|
||||
|
||||
@Suppress("UNCHECKED_CAST")
|
||||
|
@ -92,9 +92,9 @@ class TwoPartyTradeProtocolTests : TestWithInMemoryNetwork() {
|
||||
)
|
||||
)
|
||||
|
||||
assertEquals(aliceResult.get(), bobResult.get())
|
||||
assertEquals(aliceResult.resultFuture.get(), bobResult.resultFuture.get())
|
||||
|
||||
txns.add(aliceResult.get().second)
|
||||
txns.add(aliceResult.resultFuture.get().second)
|
||||
verify()
|
||||
}
|
||||
backgroundThread.shutdown()
|
||||
|
Loading…
x
Reference in New Issue
Block a user