Merged in rnicoll-protocol-request-response (pull request #60)

Refactor common elements in node services
This commit is contained in:
Ross Nicoll 2016-04-19 17:58:55 +01:00
commit 5d75a661b1
8 changed files with 104 additions and 60 deletions

View File

@ -0,0 +1,39 @@
package core.node.services
import core.messaging.Message
import core.messaging.MessagingService
import core.serialization.deserialize
import core.serialization.serialize
import protocols.AbstractRequestMessage
import javax.annotation.concurrent.ThreadSafe
/**
* Abstract superclass for services that a node can host, which provides helper functions.
*/
@ThreadSafe
abstract class AbstractNodeService(val net: MessagingService) {
/**
* Postfix for base topics when sending a request to a service.
*/
protected val topicDefaultPostfix = ".0"
/**
* Register a handler for a message topic. In comparison to using net.addMessageHandler() this manages a lot of
* common boilerplate code.
*/
protected inline fun <reified Q : AbstractRequestMessage, reified R : Any>
addMessageHandler(topic: String,
crossinline handler: (Q) -> R,
crossinline exceptionHandler: (Message, Exception) -> Unit) {
net.addMessageHandler(topic + topicDefaultPostfix, null) { message, r ->
try {
val req = message.data.deserialize<Q>()
val data = handler(req)
val msg = net.createMessage(topic + "." + req.sessionID, data.serialize().bits)
net.send(msg, req.replyTo)
} catch(e: Exception) {
exceptionHandler(message, e)
}
}
}
}

View File

@ -1,5 +1,6 @@
package core.node.services
import core.SignedTransaction
import core.crypto.SecureHash
import core.messaging.Message
import core.messaging.MessagingService
@ -7,6 +8,7 @@ import core.messaging.SingleMessageRecipient
import core.messaging.send
import core.serialization.deserialize
import core.utilities.loggerFor
import protocols.AbstractRequestMessage
import protocols.FetchAttachmentsProtocol
import protocols.FetchTransactionsProtocol
import java.io.InputStream
@ -25,36 +27,38 @@ import javax.annotation.concurrent.ThreadSafe
* Additionally, because nodes do not store invalid transactions, requesting such a transaction will always yield null.
*/
@ThreadSafe
class DataVendingService(private val net: MessagingService, private val storage: StorageService) {
class DataVendingService(net: MessagingService, private val storage: StorageService) : AbstractNodeService(net) {
companion object {
val logger = loggerFor<DataVendingService>()
}
init {
net.addMessageHandler("${FetchTransactionsProtocol.TOPIC}.0") { msg, registration -> handleTXRequest(msg) }
net.addMessageHandler("${FetchAttachmentsProtocol.TOPIC}.0") { msg, registration -> handleAttachmentRequest(msg) }
addMessageHandler(FetchTransactionsProtocol.TOPIC,
{ req: Request -> handleTXRequest(req) },
{ message, e -> logger.error("Failure processing data vending request.", e) }
)
addMessageHandler(FetchAttachmentsProtocol.TOPIC,
{ req: Request -> handleAttachmentRequest(req) },
{ message, e -> logger.error("Failure processing data vending request.", e) }
)
}
// TODO: Give all messages a respond-to address+session ID automatically.
data class Request(val hashes: List<SecureHash>, val responseTo: SingleMessageRecipient, val sessionID: Long)
class Request(val hashes: List<SecureHash>, replyTo: SingleMessageRecipient, sessionID: Long) : AbstractRequestMessage(replyTo, sessionID)
private fun handleTXRequest(msg: Message) {
val req = msg.data.deserialize<Request>()
private fun handleTXRequest(req: Request): List<SignedTransaction?> {
require(req.hashes.isNotEmpty())
val answers = req.hashes.map {
return req.hashes.map {
val tx = storage.validatedTransactions[it]
if (tx == null)
logger.info("Got request for unknown tx $it")
tx
}
net.send("${FetchTransactionsProtocol.TOPIC}.${req.sessionID}", req.responseTo, answers)
}
private fun handleAttachmentRequest(msg: Message) {
private fun handleAttachmentRequest(req: Request): List<ByteArray?> {
// TODO: Use Artemis message streaming support here, called "large messages". This avoids the need to buffer.
val req = msg.data.deserialize<Request>()
require(req.hashes.isNotEmpty())
val answers: List<ByteArray?> = req.hashes.map {
return req.hashes.map {
val jar: InputStream? = storage.attachments.openAttachment(it)?.open()
if (jar == null) {
logger.info("Got request for unknown attachment $it")
@ -63,6 +67,5 @@ class DataVendingService(private val net: MessagingService, private val storage:
jar.readBytes()
}
}
net.send("${FetchAttachmentsProtocol.TOPIC}.${req.sessionID}", req.responseTo, answers)
}
}

View File

@ -6,10 +6,13 @@ import core.crypto.signWithECDSA
import core.math.CubicSplineInterpolator
import core.math.Interpolator
import core.math.InterpolatorFactory
import core.messaging.Message
import core.messaging.MessagingService
import core.messaging.send
import core.node.AbstractNode
import core.node.AcceptsFileUpload
import core.serialization.deserialize
import org.slf4j.LoggerFactory
import protocols.RatesFixProtocol
import java.io.InputStream
import java.math.BigDecimal
@ -32,30 +35,21 @@ object NodeInterestRates {
/**
* The Service that wraps [Oracle] and handles messages/network interaction/request scrubbing.
*/
class Service(node: AbstractNode) : AcceptsFileUpload {
class Service(node: AbstractNode) : AcceptsFileUpload, AbstractNodeService(node.services.networkService) {
val ss = node.services.storageService
val oracle = Oracle(ss.myLegalIdentity, ss.myLegalIdentityKey)
val net = node.services.networkService
private val logger = LoggerFactory.getLogger(NodeInterestRates.Service::class.java)
init {
handleQueries()
handleSignRequests()
}
private fun handleSignRequests() {
net.addMessageHandler(RatesFixProtocol.TOPIC + ".sign.0") { message, registration ->
val request = message.data.deserialize<RatesFixProtocol.SignRequest>()
val sig = oracle.sign(request.tx)
net.send("${RatesFixProtocol.TOPIC}.sign.${request.sessionID}", request.replyTo, sig)
}
}
private fun handleQueries() {
net.addMessageHandler(RatesFixProtocol.TOPIC + ".query.0") { message, registration ->
val request = message.data.deserialize<RatesFixProtocol.QueryRequest>()
val answers = oracle.query(request.queries)
net.send("${RatesFixProtocol.TOPIC}.query.${request.sessionID}", request.replyTo, answers)
}
addMessageHandler(RatesFixProtocol.TOPIC_SIGN,
{ req: RatesFixProtocol.SignRequest -> oracle.sign(req.tx) },
{ message, e -> logger.error("Exception during interest rate oracle request processing", e) }
)
addMessageHandler(RatesFixProtocol.TOPIC_QUERY,
{ req: RatesFixProtocol.QueryRequest -> oracle.query(req.queries) },
{ message, e -> logger.error("Exception during interest rate oracle request processing", e) }
)
}
// File upload support

View File

@ -5,6 +5,7 @@ import core.Party
import core.TimestampCommand
import core.crypto.DigitalSignature
import core.crypto.signWithECDSA
import core.messaging.Message
import core.messaging.MessagingService
import core.seconds
import core.serialization.deserialize
@ -24,11 +25,11 @@ import javax.annotation.concurrent.ThreadSafe
* See the doc site to learn more about timestamping authorities (nodes) and the role they play in the data model.
*/
@ThreadSafe
class NodeTimestamperService(private val net: MessagingService,
class NodeTimestamperService(net: MessagingService,
val identity: Party,
val signingKey: KeyPair,
val clock: Clock = Clock.systemDefaultZone(),
val tolerance: Duration = 30.seconds) {
val tolerance: Duration = 30.seconds) : AbstractNodeService(net) {
companion object {
val TIMESTAMPING_PROTOCOL_TOPIC = "platform.timestamping.request"
@ -37,18 +38,16 @@ class NodeTimestamperService(private val net: MessagingService,
init {
require(identity.owningKey == signingKey.public)
net.addMessageHandler(TIMESTAMPING_PROTOCOL_TOPIC + ".0", null) { message, r ->
try {
val req = message.data.deserialize<TimestampingProtocol.Request>()
val signature = processRequest(req)
val msg = net.createMessage(req.replyToTopic, signature.serialize().bits)
net.send(msg, req.replyTo)
} catch(e: TimestampingError) {
logger.warn("Failure during timestamping request due to bad request: ${e.javaClass.name}")
} catch(e: Exception) {
logger.error("Exception during timestamping", e)
}
}
addMessageHandler(TIMESTAMPING_PROTOCOL_TOPIC,
{ req: TimestampingProtocol.Request -> processRequest(req) },
{ message, e ->
if (e is TimestampingError) {
logger.warn("Failure during timestamping request due to bad request: ${e.javaClass.name}")
} else {
logger.error("Exception during timestamping", e)
}
}
)
}
@VisibleForTesting

View File

@ -0,0 +1,9 @@
package protocols
import core.messaging.MessageRecipients
/**
* Abstract superclass for request messages sent to services, which includes common
* fields such as replyTo and replyToTopic.
*/
abstract class AbstractRequestMessage(val replyTo: MessageRecipients, val sessionID: Long?)

View File

@ -28,6 +28,8 @@ open class RatesFixProtocol(protected val tx: TransactionBuilder,
override val progressTracker: ProgressTracker = RatesFixProtocol.tracker(fixOf.name)) : ProtocolLogic<Unit>() {
companion object {
val TOPIC = "platform.rates.interest.fix"
val TOPIC_SIGN = TOPIC + ".sign"
val TOPIC_QUERY = TOPIC + ".query"
class QUERYING(val name: String) : ProgressTracker.Step("Querying oracle for $name interest rate")
object WORKING : ProgressTracker.Step("Working with data returned by oracle")
@ -38,8 +40,8 @@ open class RatesFixProtocol(protected val tx: TransactionBuilder,
class FixOutOfRange(val byAmount: BigDecimal) : Exception()
data class QueryRequest(val queries: List<FixOf>, val replyTo: SingleMessageRecipient, val sessionID: Long)
data class SignRequest(val tx: WireTransaction, val replyTo: SingleMessageRecipient, val sessionID: Long)
class QueryRequest(val queries: List<FixOf>, replyTo: SingleMessageRecipient, sessionID: Long) : AbstractRequestMessage(replyTo, sessionID)
class SignRequest(val tx: WireTransaction, replyTo: SingleMessageRecipient, sessionID: Long) : AbstractRequestMessage(replyTo, sessionID)
@Suspendable
override fun call() {
@ -74,7 +76,7 @@ open class RatesFixProtocol(protected val tx: TransactionBuilder,
val sessionID = random63BitValue()
val wtx = tx.toWireTransaction()
val req = SignRequest(wtx, serviceHub.networkService.myAddress, sessionID)
val resp = sendAndReceive<DigitalSignature.LegallyIdentifiable>(TOPIC + ".sign", oracle.address, 0, sessionID, req)
val resp = sendAndReceive<DigitalSignature.LegallyIdentifiable>(TOPIC_SIGN, oracle.address, 0, sessionID, req)
return resp.validate { sig ->
check(sig.signer == oracle.identity)
@ -87,7 +89,7 @@ open class RatesFixProtocol(protected val tx: TransactionBuilder,
fun query(): Fix {
val sessionID = random63BitValue()
val req = QueryRequest(listOf(fixOf), serviceHub.networkService.myAddress, sessionID)
val resp = sendAndReceive<ArrayList<Fix>>(TOPIC + ".query", oracle.address, 0, sessionID, req)
val resp = sendAndReceive<ArrayList<Fix>>(TOPIC_QUERY, oracle.address, 0, sessionID, req)
return resp.validate {
val fix = it.first()

View File

@ -47,8 +47,7 @@ class TimestampingProtocol(private val node: NodeInfo,
override fun call(): DigitalSignature.LegallyIdentifiable {
progressTracker.currentStep = REQUESTING
val sessionID = random63BitValue()
val replyTopic = "${NodeTimestamperService.TIMESTAMPING_PROTOCOL_TOPIC}.$sessionID"
val req = Request(wtxBytes, serviceHub.networkService.myAddress, replyTopic)
val req = Request(wtxBytes, serviceHub.networkService.myAddress, sessionID)
val maybeSignature = sendAndReceive<DigitalSignature.LegallyIdentifiable>(
NodeTimestamperService.TIMESTAMPING_PROTOCOL_TOPIC, node.address, 0, sessionID, req)
@ -61,6 +60,5 @@ class TimestampingProtocol(private val node: NodeInfo,
}
}
// TODO: Improve the messaging api to have a notion of sender+replyTo topic (optional?)
data class Request(val tx: SerializedBytes<WireTransaction>, val replyTo: MessageRecipients, val replyToTopic: String)
class Request(val tx: SerializedBytes<WireTransaction>, replyTo: MessageRecipients, sessionID: Long) : AbstractRequestMessage(replyTo, sessionID)
}

View File

@ -90,14 +90,14 @@ class TimestamperNodeServiceTest : TestWithInMemoryNetwork() {
// Zero commands is not OK.
assertFailsWith(TimestampingError.RequiresExactlyOneCommand::class) {
val wtx = ptx.toWireTransaction()
service.processRequest(TimestampingProtocol.Request(wtx.serialize(), myMessaging.first, "ignored"))
service.processRequest(TimestampingProtocol.Request(wtx.serialize(), myMessaging.first, Long.MIN_VALUE))
}
// More than one command is not OK.
assertFailsWith(TimestampingError.RequiresExactlyOneCommand::class) {
ptx.addCommand(TimestampCommand(clock.instant(), 30.seconds), ALICE)
ptx.addCommand(TimestampCommand(clock.instant(), 40.seconds), ALICE)
val wtx = ptx.toWireTransaction()
service.processRequest(TimestampingProtocol.Request(wtx.serialize(), myMessaging.first, "ignored"))
service.processRequest(TimestampingProtocol.Request(wtx.serialize(), myMessaging.first, Long.MIN_VALUE))
}
}
@ -107,7 +107,7 @@ class TimestamperNodeServiceTest : TestWithInMemoryNetwork() {
val now = clock.instant()
ptx.addCommand(TimestampCommand(now - 60.seconds, now - 40.seconds), ALICE)
val wtx = ptx.toWireTransaction()
service.processRequest(TimestampingProtocol.Request(wtx.serialize(), myMessaging.first, "ignored"))
service.processRequest(TimestampingProtocol.Request(wtx.serialize(), myMessaging.first, Long.MIN_VALUE))
}
}
@ -117,7 +117,7 @@ class TimestamperNodeServiceTest : TestWithInMemoryNetwork() {
val now = clock.instant()
ptx.addCommand(TimestampCommand(now - 60.seconds, now - 40.seconds), ALICE)
val wtx = ptx.toWireTransaction()
service.processRequest(TimestampingProtocol.Request(wtx.serialize(), myMessaging.first, "ignored"))
service.processRequest(TimestampingProtocol.Request(wtx.serialize(), myMessaging.first, Long.MIN_VALUE))
}
}
@ -126,7 +126,7 @@ class TimestamperNodeServiceTest : TestWithInMemoryNetwork() {
val now = clock.instant()
ptx.addCommand(TimestampCommand(now - 20.seconds, now + 20.seconds), ALICE)
val wtx = ptx.toWireTransaction()
val sig = service.processRequest(TimestampingProtocol.Request(wtx.serialize(), myMessaging.first, "ignored"))
val sig = service.processRequest(TimestampingProtocol.Request(wtx.serialize(), myMessaging.first, Long.MIN_VALUE))
ptx.checkAndAddSignature(sig)
ptx.toSignedTransaction(false).verifySignatures()
}