mirror of
https://github.com/corda/corda.git
synced 2024-12-20 13:33:12 +00:00
Merged in rnicoll-protocol-request-response (pull request #60)
Refactor common elements in node services
This commit is contained in:
commit
5d75a661b1
39
src/main/kotlin/core/node/services/AbstractNodeService.kt
Normal file
39
src/main/kotlin/core/node/services/AbstractNodeService.kt
Normal file
@ -0,0 +1,39 @@
|
||||
package core.node.services
|
||||
|
||||
import core.messaging.Message
|
||||
import core.messaging.MessagingService
|
||||
import core.serialization.deserialize
|
||||
import core.serialization.serialize
|
||||
import protocols.AbstractRequestMessage
|
||||
import javax.annotation.concurrent.ThreadSafe
|
||||
|
||||
/**
|
||||
* Abstract superclass for services that a node can host, which provides helper functions.
|
||||
*/
|
||||
@ThreadSafe
|
||||
abstract class AbstractNodeService(val net: MessagingService) {
|
||||
/**
|
||||
* Postfix for base topics when sending a request to a service.
|
||||
*/
|
||||
protected val topicDefaultPostfix = ".0"
|
||||
|
||||
/**
|
||||
* Register a handler for a message topic. In comparison to using net.addMessageHandler() this manages a lot of
|
||||
* common boilerplate code.
|
||||
*/
|
||||
protected inline fun <reified Q : AbstractRequestMessage, reified R : Any>
|
||||
addMessageHandler(topic: String,
|
||||
crossinline handler: (Q) -> R,
|
||||
crossinline exceptionHandler: (Message, Exception) -> Unit) {
|
||||
net.addMessageHandler(topic + topicDefaultPostfix, null) { message, r ->
|
||||
try {
|
||||
val req = message.data.deserialize<Q>()
|
||||
val data = handler(req)
|
||||
val msg = net.createMessage(topic + "." + req.sessionID, data.serialize().bits)
|
||||
net.send(msg, req.replyTo)
|
||||
} catch(e: Exception) {
|
||||
exceptionHandler(message, e)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
@ -1,5 +1,6 @@
|
||||
package core.node.services
|
||||
|
||||
import core.SignedTransaction
|
||||
import core.crypto.SecureHash
|
||||
import core.messaging.Message
|
||||
import core.messaging.MessagingService
|
||||
@ -7,6 +8,7 @@ import core.messaging.SingleMessageRecipient
|
||||
import core.messaging.send
|
||||
import core.serialization.deserialize
|
||||
import core.utilities.loggerFor
|
||||
import protocols.AbstractRequestMessage
|
||||
import protocols.FetchAttachmentsProtocol
|
||||
import protocols.FetchTransactionsProtocol
|
||||
import java.io.InputStream
|
||||
@ -25,36 +27,38 @@ import javax.annotation.concurrent.ThreadSafe
|
||||
* Additionally, because nodes do not store invalid transactions, requesting such a transaction will always yield null.
|
||||
*/
|
||||
@ThreadSafe
|
||||
class DataVendingService(private val net: MessagingService, private val storage: StorageService) {
|
||||
class DataVendingService(net: MessagingService, private val storage: StorageService) : AbstractNodeService(net) {
|
||||
companion object {
|
||||
val logger = loggerFor<DataVendingService>()
|
||||
}
|
||||
|
||||
init {
|
||||
net.addMessageHandler("${FetchTransactionsProtocol.TOPIC}.0") { msg, registration -> handleTXRequest(msg) }
|
||||
net.addMessageHandler("${FetchAttachmentsProtocol.TOPIC}.0") { msg, registration -> handleAttachmentRequest(msg) }
|
||||
addMessageHandler(FetchTransactionsProtocol.TOPIC,
|
||||
{ req: Request -> handleTXRequest(req) },
|
||||
{ message, e -> logger.error("Failure processing data vending request.", e) }
|
||||
)
|
||||
addMessageHandler(FetchAttachmentsProtocol.TOPIC,
|
||||
{ req: Request -> handleAttachmentRequest(req) },
|
||||
{ message, e -> logger.error("Failure processing data vending request.", e) }
|
||||
)
|
||||
}
|
||||
|
||||
// TODO: Give all messages a respond-to address+session ID automatically.
|
||||
data class Request(val hashes: List<SecureHash>, val responseTo: SingleMessageRecipient, val sessionID: Long)
|
||||
class Request(val hashes: List<SecureHash>, replyTo: SingleMessageRecipient, sessionID: Long) : AbstractRequestMessage(replyTo, sessionID)
|
||||
|
||||
private fun handleTXRequest(msg: Message) {
|
||||
val req = msg.data.deserialize<Request>()
|
||||
private fun handleTXRequest(req: Request): List<SignedTransaction?> {
|
||||
require(req.hashes.isNotEmpty())
|
||||
val answers = req.hashes.map {
|
||||
return req.hashes.map {
|
||||
val tx = storage.validatedTransactions[it]
|
||||
if (tx == null)
|
||||
logger.info("Got request for unknown tx $it")
|
||||
tx
|
||||
}
|
||||
net.send("${FetchTransactionsProtocol.TOPIC}.${req.sessionID}", req.responseTo, answers)
|
||||
}
|
||||
|
||||
private fun handleAttachmentRequest(msg: Message) {
|
||||
private fun handleAttachmentRequest(req: Request): List<ByteArray?> {
|
||||
// TODO: Use Artemis message streaming support here, called "large messages". This avoids the need to buffer.
|
||||
val req = msg.data.deserialize<Request>()
|
||||
require(req.hashes.isNotEmpty())
|
||||
val answers: List<ByteArray?> = req.hashes.map {
|
||||
return req.hashes.map {
|
||||
val jar: InputStream? = storage.attachments.openAttachment(it)?.open()
|
||||
if (jar == null) {
|
||||
logger.info("Got request for unknown attachment $it")
|
||||
@ -63,6 +67,5 @@ class DataVendingService(private val net: MessagingService, private val storage:
|
||||
jar.readBytes()
|
||||
}
|
||||
}
|
||||
net.send("${FetchAttachmentsProtocol.TOPIC}.${req.sessionID}", req.responseTo, answers)
|
||||
}
|
||||
}
|
||||
|
@ -6,10 +6,13 @@ import core.crypto.signWithECDSA
|
||||
import core.math.CubicSplineInterpolator
|
||||
import core.math.Interpolator
|
||||
import core.math.InterpolatorFactory
|
||||
import core.messaging.Message
|
||||
import core.messaging.MessagingService
|
||||
import core.messaging.send
|
||||
import core.node.AbstractNode
|
||||
import core.node.AcceptsFileUpload
|
||||
import core.serialization.deserialize
|
||||
import org.slf4j.LoggerFactory
|
||||
import protocols.RatesFixProtocol
|
||||
import java.io.InputStream
|
||||
import java.math.BigDecimal
|
||||
@ -32,30 +35,21 @@ object NodeInterestRates {
|
||||
/**
|
||||
* The Service that wraps [Oracle] and handles messages/network interaction/request scrubbing.
|
||||
*/
|
||||
class Service(node: AbstractNode) : AcceptsFileUpload {
|
||||
class Service(node: AbstractNode) : AcceptsFileUpload, AbstractNodeService(node.services.networkService) {
|
||||
val ss = node.services.storageService
|
||||
val oracle = Oracle(ss.myLegalIdentity, ss.myLegalIdentityKey)
|
||||
val net = node.services.networkService
|
||||
|
||||
private val logger = LoggerFactory.getLogger(NodeInterestRates.Service::class.java)
|
||||
|
||||
init {
|
||||
handleQueries()
|
||||
handleSignRequests()
|
||||
}
|
||||
|
||||
private fun handleSignRequests() {
|
||||
net.addMessageHandler(RatesFixProtocol.TOPIC + ".sign.0") { message, registration ->
|
||||
val request = message.data.deserialize<RatesFixProtocol.SignRequest>()
|
||||
val sig = oracle.sign(request.tx)
|
||||
net.send("${RatesFixProtocol.TOPIC}.sign.${request.sessionID}", request.replyTo, sig)
|
||||
}
|
||||
}
|
||||
|
||||
private fun handleQueries() {
|
||||
net.addMessageHandler(RatesFixProtocol.TOPIC + ".query.0") { message, registration ->
|
||||
val request = message.data.deserialize<RatesFixProtocol.QueryRequest>()
|
||||
val answers = oracle.query(request.queries)
|
||||
net.send("${RatesFixProtocol.TOPIC}.query.${request.sessionID}", request.replyTo, answers)
|
||||
}
|
||||
addMessageHandler(RatesFixProtocol.TOPIC_SIGN,
|
||||
{ req: RatesFixProtocol.SignRequest -> oracle.sign(req.tx) },
|
||||
{ message, e -> logger.error("Exception during interest rate oracle request processing", e) }
|
||||
)
|
||||
addMessageHandler(RatesFixProtocol.TOPIC_QUERY,
|
||||
{ req: RatesFixProtocol.QueryRequest -> oracle.query(req.queries) },
|
||||
{ message, e -> logger.error("Exception during interest rate oracle request processing", e) }
|
||||
)
|
||||
}
|
||||
|
||||
// File upload support
|
||||
|
@ -5,6 +5,7 @@ import core.Party
|
||||
import core.TimestampCommand
|
||||
import core.crypto.DigitalSignature
|
||||
import core.crypto.signWithECDSA
|
||||
import core.messaging.Message
|
||||
import core.messaging.MessagingService
|
||||
import core.seconds
|
||||
import core.serialization.deserialize
|
||||
@ -24,11 +25,11 @@ import javax.annotation.concurrent.ThreadSafe
|
||||
* See the doc site to learn more about timestamping authorities (nodes) and the role they play in the data model.
|
||||
*/
|
||||
@ThreadSafe
|
||||
class NodeTimestamperService(private val net: MessagingService,
|
||||
class NodeTimestamperService(net: MessagingService,
|
||||
val identity: Party,
|
||||
val signingKey: KeyPair,
|
||||
val clock: Clock = Clock.systemDefaultZone(),
|
||||
val tolerance: Duration = 30.seconds) {
|
||||
val tolerance: Duration = 30.seconds) : AbstractNodeService(net) {
|
||||
companion object {
|
||||
val TIMESTAMPING_PROTOCOL_TOPIC = "platform.timestamping.request"
|
||||
|
||||
@ -37,18 +38,16 @@ class NodeTimestamperService(private val net: MessagingService,
|
||||
|
||||
init {
|
||||
require(identity.owningKey == signingKey.public)
|
||||
net.addMessageHandler(TIMESTAMPING_PROTOCOL_TOPIC + ".0", null) { message, r ->
|
||||
try {
|
||||
val req = message.data.deserialize<TimestampingProtocol.Request>()
|
||||
val signature = processRequest(req)
|
||||
val msg = net.createMessage(req.replyToTopic, signature.serialize().bits)
|
||||
net.send(msg, req.replyTo)
|
||||
} catch(e: TimestampingError) {
|
||||
logger.warn("Failure during timestamping request due to bad request: ${e.javaClass.name}")
|
||||
} catch(e: Exception) {
|
||||
logger.error("Exception during timestamping", e)
|
||||
}
|
||||
}
|
||||
addMessageHandler(TIMESTAMPING_PROTOCOL_TOPIC,
|
||||
{ req: TimestampingProtocol.Request -> processRequest(req) },
|
||||
{ message, e ->
|
||||
if (e is TimestampingError) {
|
||||
logger.warn("Failure during timestamping request due to bad request: ${e.javaClass.name}")
|
||||
} else {
|
||||
logger.error("Exception during timestamping", e)
|
||||
}
|
||||
}
|
||||
)
|
||||
}
|
||||
|
||||
@VisibleForTesting
|
||||
|
9
src/main/kotlin/protocols/AbstractRequestMessage.kt
Normal file
9
src/main/kotlin/protocols/AbstractRequestMessage.kt
Normal file
@ -0,0 +1,9 @@
|
||||
package protocols
|
||||
|
||||
import core.messaging.MessageRecipients
|
||||
|
||||
/**
|
||||
* Abstract superclass for request messages sent to services, which includes common
|
||||
* fields such as replyTo and replyToTopic.
|
||||
*/
|
||||
abstract class AbstractRequestMessage(val replyTo: MessageRecipients, val sessionID: Long?)
|
@ -28,6 +28,8 @@ open class RatesFixProtocol(protected val tx: TransactionBuilder,
|
||||
override val progressTracker: ProgressTracker = RatesFixProtocol.tracker(fixOf.name)) : ProtocolLogic<Unit>() {
|
||||
companion object {
|
||||
val TOPIC = "platform.rates.interest.fix"
|
||||
val TOPIC_SIGN = TOPIC + ".sign"
|
||||
val TOPIC_QUERY = TOPIC + ".query"
|
||||
|
||||
class QUERYING(val name: String) : ProgressTracker.Step("Querying oracle for $name interest rate")
|
||||
object WORKING : ProgressTracker.Step("Working with data returned by oracle")
|
||||
@ -38,8 +40,8 @@ open class RatesFixProtocol(protected val tx: TransactionBuilder,
|
||||
|
||||
class FixOutOfRange(val byAmount: BigDecimal) : Exception()
|
||||
|
||||
data class QueryRequest(val queries: List<FixOf>, val replyTo: SingleMessageRecipient, val sessionID: Long)
|
||||
data class SignRequest(val tx: WireTransaction, val replyTo: SingleMessageRecipient, val sessionID: Long)
|
||||
class QueryRequest(val queries: List<FixOf>, replyTo: SingleMessageRecipient, sessionID: Long) : AbstractRequestMessage(replyTo, sessionID)
|
||||
class SignRequest(val tx: WireTransaction, replyTo: SingleMessageRecipient, sessionID: Long) : AbstractRequestMessage(replyTo, sessionID)
|
||||
|
||||
@Suspendable
|
||||
override fun call() {
|
||||
@ -74,7 +76,7 @@ open class RatesFixProtocol(protected val tx: TransactionBuilder,
|
||||
val sessionID = random63BitValue()
|
||||
val wtx = tx.toWireTransaction()
|
||||
val req = SignRequest(wtx, serviceHub.networkService.myAddress, sessionID)
|
||||
val resp = sendAndReceive<DigitalSignature.LegallyIdentifiable>(TOPIC + ".sign", oracle.address, 0, sessionID, req)
|
||||
val resp = sendAndReceive<DigitalSignature.LegallyIdentifiable>(TOPIC_SIGN, oracle.address, 0, sessionID, req)
|
||||
|
||||
return resp.validate { sig ->
|
||||
check(sig.signer == oracle.identity)
|
||||
@ -87,7 +89,7 @@ open class RatesFixProtocol(protected val tx: TransactionBuilder,
|
||||
fun query(): Fix {
|
||||
val sessionID = random63BitValue()
|
||||
val req = QueryRequest(listOf(fixOf), serviceHub.networkService.myAddress, sessionID)
|
||||
val resp = sendAndReceive<ArrayList<Fix>>(TOPIC + ".query", oracle.address, 0, sessionID, req)
|
||||
val resp = sendAndReceive<ArrayList<Fix>>(TOPIC_QUERY, oracle.address, 0, sessionID, req)
|
||||
|
||||
return resp.validate {
|
||||
val fix = it.first()
|
||||
|
@ -47,8 +47,7 @@ class TimestampingProtocol(private val node: NodeInfo,
|
||||
override fun call(): DigitalSignature.LegallyIdentifiable {
|
||||
progressTracker.currentStep = REQUESTING
|
||||
val sessionID = random63BitValue()
|
||||
val replyTopic = "${NodeTimestamperService.TIMESTAMPING_PROTOCOL_TOPIC}.$sessionID"
|
||||
val req = Request(wtxBytes, serviceHub.networkService.myAddress, replyTopic)
|
||||
val req = Request(wtxBytes, serviceHub.networkService.myAddress, sessionID)
|
||||
|
||||
val maybeSignature = sendAndReceive<DigitalSignature.LegallyIdentifiable>(
|
||||
NodeTimestamperService.TIMESTAMPING_PROTOCOL_TOPIC, node.address, 0, sessionID, req)
|
||||
@ -61,6 +60,5 @@ class TimestampingProtocol(private val node: NodeInfo,
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: Improve the messaging api to have a notion of sender+replyTo topic (optional?)
|
||||
data class Request(val tx: SerializedBytes<WireTransaction>, val replyTo: MessageRecipients, val replyToTopic: String)
|
||||
class Request(val tx: SerializedBytes<WireTransaction>, replyTo: MessageRecipients, sessionID: Long) : AbstractRequestMessage(replyTo, sessionID)
|
||||
}
|
@ -90,14 +90,14 @@ class TimestamperNodeServiceTest : TestWithInMemoryNetwork() {
|
||||
// Zero commands is not OK.
|
||||
assertFailsWith(TimestampingError.RequiresExactlyOneCommand::class) {
|
||||
val wtx = ptx.toWireTransaction()
|
||||
service.processRequest(TimestampingProtocol.Request(wtx.serialize(), myMessaging.first, "ignored"))
|
||||
service.processRequest(TimestampingProtocol.Request(wtx.serialize(), myMessaging.first, Long.MIN_VALUE))
|
||||
}
|
||||
// More than one command is not OK.
|
||||
assertFailsWith(TimestampingError.RequiresExactlyOneCommand::class) {
|
||||
ptx.addCommand(TimestampCommand(clock.instant(), 30.seconds), ALICE)
|
||||
ptx.addCommand(TimestampCommand(clock.instant(), 40.seconds), ALICE)
|
||||
val wtx = ptx.toWireTransaction()
|
||||
service.processRequest(TimestampingProtocol.Request(wtx.serialize(), myMessaging.first, "ignored"))
|
||||
service.processRequest(TimestampingProtocol.Request(wtx.serialize(), myMessaging.first, Long.MIN_VALUE))
|
||||
}
|
||||
}
|
||||
|
||||
@ -107,7 +107,7 @@ class TimestamperNodeServiceTest : TestWithInMemoryNetwork() {
|
||||
val now = clock.instant()
|
||||
ptx.addCommand(TimestampCommand(now - 60.seconds, now - 40.seconds), ALICE)
|
||||
val wtx = ptx.toWireTransaction()
|
||||
service.processRequest(TimestampingProtocol.Request(wtx.serialize(), myMessaging.first, "ignored"))
|
||||
service.processRequest(TimestampingProtocol.Request(wtx.serialize(), myMessaging.first, Long.MIN_VALUE))
|
||||
}
|
||||
}
|
||||
|
||||
@ -117,7 +117,7 @@ class TimestamperNodeServiceTest : TestWithInMemoryNetwork() {
|
||||
val now = clock.instant()
|
||||
ptx.addCommand(TimestampCommand(now - 60.seconds, now - 40.seconds), ALICE)
|
||||
val wtx = ptx.toWireTransaction()
|
||||
service.processRequest(TimestampingProtocol.Request(wtx.serialize(), myMessaging.first, "ignored"))
|
||||
service.processRequest(TimestampingProtocol.Request(wtx.serialize(), myMessaging.first, Long.MIN_VALUE))
|
||||
}
|
||||
}
|
||||
|
||||
@ -126,7 +126,7 @@ class TimestamperNodeServiceTest : TestWithInMemoryNetwork() {
|
||||
val now = clock.instant()
|
||||
ptx.addCommand(TimestampCommand(now - 20.seconds, now + 20.seconds), ALICE)
|
||||
val wtx = ptx.toWireTransaction()
|
||||
val sig = service.processRequest(TimestampingProtocol.Request(wtx.serialize(), myMessaging.first, "ignored"))
|
||||
val sig = service.processRequest(TimestampingProtocol.Request(wtx.serialize(), myMessaging.first, Long.MIN_VALUE))
|
||||
ptx.checkAndAddSignature(sig)
|
||||
ptx.toSignedTransaction(false).verifySignatures()
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user