From 8cdd57e4e02e227f01ebce2ec8a1108251a86704 Mon Sep 17 00:00:00 2001
From: Ross Nicoll <ross.nicoll@r3cev.com>
Date: Tue, 19 Apr 2016 10:55:35 +0100
Subject: [PATCH] Refactor common elements in node services

---
 .../core/node/services/AbstractNodeService.kt | 39 +++++++++++++++++++
 .../core/node/services/DataVendingService.kt  | 29 +++++++-------
 .../core/node/services/NodeInterestRates.kt   | 34 +++++++---------
 .../node/services/NodeTimestamperService.kt   | 27 +++++++------
 .../protocols/AbstractRequestMessage.kt       |  9 +++++
 src/main/kotlin/protocols/RatesFixProtocol.kt | 10 +++--
 .../kotlin/protocols/TimestampingProtocol.kt  |  6 +--
 .../core/node/TimestamperNodeServiceTest.kt   | 10 ++---
 8 files changed, 104 insertions(+), 60 deletions(-)
 create mode 100644 src/main/kotlin/core/node/services/AbstractNodeService.kt
 create mode 100644 src/main/kotlin/protocols/AbstractRequestMessage.kt

diff --git a/src/main/kotlin/core/node/services/AbstractNodeService.kt b/src/main/kotlin/core/node/services/AbstractNodeService.kt
new file mode 100644
index 0000000000..4b0ac71db4
--- /dev/null
+++ b/src/main/kotlin/core/node/services/AbstractNodeService.kt
@@ -0,0 +1,39 @@
+package core.node.services
+
+import core.messaging.Message
+import core.messaging.MessagingService
+import core.serialization.deserialize
+import core.serialization.serialize
+import protocols.AbstractRequestMessage
+import javax.annotation.concurrent.ThreadSafe
+
+/**
+ * Abstract superclass for services that a node can host, which provides helper functions.
+ */
+@ThreadSafe
+abstract class AbstractNodeService(val net: MessagingService) {
+    /**
+     * Postfix for base topics when sending a request to a service.
+     */
+    protected val topicDefaultPostfix = ".0"
+
+    /**
+     * Register a handler for a message topic. In comparison to using net.addMessageHandler() this manages a lot of
+     * common boilerplate code.
+     */
+    protected inline fun <reified Q : AbstractRequestMessage, reified R : Any>
+            addMessageHandler(topic: String,
+                              crossinline handler: (Q) -> R,
+                              crossinline exceptionHandler: (Message, Exception) -> Unit) {
+        net.addMessageHandler(topic + topicDefaultPostfix, null) { message, r ->
+            try {
+                val req = message.data.deserialize<Q>()
+                val data = handler(req)
+                val msg = net.createMessage(topic + "." + req.sessionID, data.serialize().bits)
+                net.send(msg, req.replyTo)
+            } catch(e: Exception) {
+                exceptionHandler(message, e)
+            }
+        }
+    }
+}
\ No newline at end of file
diff --git a/src/main/kotlin/core/node/services/DataVendingService.kt b/src/main/kotlin/core/node/services/DataVendingService.kt
index 9abef4ef17..26bb4c44fa 100644
--- a/src/main/kotlin/core/node/services/DataVendingService.kt
+++ b/src/main/kotlin/core/node/services/DataVendingService.kt
@@ -1,5 +1,6 @@
 package core.node.services
 
+import core.SignedTransaction
 import core.crypto.SecureHash
 import core.messaging.Message
 import core.messaging.MessagingService
@@ -7,6 +8,7 @@ import core.messaging.SingleMessageRecipient
 import core.messaging.send
 import core.serialization.deserialize
 import core.utilities.loggerFor
+import protocols.AbstractRequestMessage
 import protocols.FetchAttachmentsProtocol
 import protocols.FetchTransactionsProtocol
 import java.io.InputStream
@@ -25,36 +27,38 @@ import javax.annotation.concurrent.ThreadSafe
  * Additionally, because nodes do not store invalid transactions, requesting such a transaction will always yield null.
  */
 @ThreadSafe
-class DataVendingService(private val net: MessagingService, private val storage: StorageService) {
+class DataVendingService(net: MessagingService, private val storage: StorageService) : AbstractNodeService(net) {
     companion object {
         val logger = loggerFor<DataVendingService>()
     }
 
     init {
-        net.addMessageHandler("${FetchTransactionsProtocol.TOPIC}.0") { msg, registration -> handleTXRequest(msg) }
-        net.addMessageHandler("${FetchAttachmentsProtocol.TOPIC}.0") { msg, registration -> handleAttachmentRequest(msg) }
+        addMessageHandler(FetchTransactionsProtocol.TOPIC,
+                { req: Request -> handleTXRequest(req) },
+                { message, e -> logger.error("Failure processing data vending request.", e) }
+        )
+        addMessageHandler(FetchAttachmentsProtocol.TOPIC,
+                { req: Request -> handleAttachmentRequest(req) },
+                { message, e -> logger.error("Failure processing data vending request.", e) }
+        )
     }
 
-    // TODO: Give all messages a respond-to address+session ID automatically.
-    data class Request(val hashes: List<SecureHash>, val responseTo: SingleMessageRecipient, val sessionID: Long)
+    class Request(val hashes: List<SecureHash>, replyTo: SingleMessageRecipient, sessionID: Long) : AbstractRequestMessage(replyTo, sessionID)
 
-    private fun handleTXRequest(msg: Message) {
-        val req = msg.data.deserialize<Request>()
+    private fun handleTXRequest(req: Request): List<SignedTransaction?> {
         require(req.hashes.isNotEmpty())
-        val answers = req.hashes.map {
+        return req.hashes.map {
             val tx = storage.validatedTransactions[it]
             if (tx == null)
                 logger.info("Got request for unknown tx $it")
             tx
         }
-        net.send("${FetchTransactionsProtocol.TOPIC}.${req.sessionID}", req.responseTo, answers)
     }
 
-    private fun handleAttachmentRequest(msg: Message) {
+    private fun handleAttachmentRequest(req: Request): List<ByteArray?> {
         // TODO: Use Artemis message streaming support here, called "large messages". This avoids the need to buffer.
-        val req = msg.data.deserialize<Request>()
         require(req.hashes.isNotEmpty())
-        val answers: List<ByteArray?> = req.hashes.map {
+        return req.hashes.map {
             val jar: InputStream? = storage.attachments.openAttachment(it)?.open()
             if (jar == null) {
                 logger.info("Got request for unknown attachment $it")
@@ -63,6 +67,5 @@ class DataVendingService(private val net: MessagingService, private val storage:
                 jar.readBytes()
             }
         }
-        net.send("${FetchAttachmentsProtocol.TOPIC}.${req.sessionID}", req.responseTo, answers)
     }
 }
diff --git a/src/main/kotlin/core/node/services/NodeInterestRates.kt b/src/main/kotlin/core/node/services/NodeInterestRates.kt
index e1ed54ad25..7843f338b6 100644
--- a/src/main/kotlin/core/node/services/NodeInterestRates.kt
+++ b/src/main/kotlin/core/node/services/NodeInterestRates.kt
@@ -6,10 +6,13 @@ import core.crypto.signWithECDSA
 import core.math.CubicSplineInterpolator
 import core.math.Interpolator
 import core.math.InterpolatorFactory
+import core.messaging.Message
+import core.messaging.MessagingService
 import core.messaging.send
 import core.node.AbstractNode
 import core.node.AcceptsFileUpload
 import core.serialization.deserialize
+import org.slf4j.LoggerFactory
 import protocols.RatesFixProtocol
 import java.io.InputStream
 import java.math.BigDecimal
@@ -32,30 +35,21 @@ object NodeInterestRates {
     /**
      * The Service that wraps [Oracle] and handles messages/network interaction/request scrubbing.
      */
-    class Service(node: AbstractNode) : AcceptsFileUpload {
+    class Service(node: AbstractNode) : AcceptsFileUpload, AbstractNodeService(node.services.networkService) {
         val ss = node.services.storageService
         val oracle = Oracle(ss.myLegalIdentity, ss.myLegalIdentityKey)
-        val net = node.services.networkService
+
+        private val logger = LoggerFactory.getLogger(NodeInterestRates.Service::class.java)
 
         init {
-            handleQueries()
-            handleSignRequests()
-        }
-
-        private fun handleSignRequests() {
-            net.addMessageHandler(RatesFixProtocol.TOPIC + ".sign.0") { message, registration ->
-                val request = message.data.deserialize<RatesFixProtocol.SignRequest>()
-                val sig = oracle.sign(request.tx)
-                net.send("${RatesFixProtocol.TOPIC}.sign.${request.sessionID}", request.replyTo, sig)
-            }
-        }
-
-        private fun handleQueries() {
-            net.addMessageHandler(RatesFixProtocol.TOPIC + ".query.0") { message, registration ->
-                val request = message.data.deserialize<RatesFixProtocol.QueryRequest>()
-                val answers = oracle.query(request.queries)
-                net.send("${RatesFixProtocol.TOPIC}.query.${request.sessionID}", request.replyTo, answers)
-            }
+            addMessageHandler(RatesFixProtocol.TOPIC_SIGN,
+                    { req: RatesFixProtocol.SignRequest -> oracle.sign(req.tx) },
+                    { message, e -> logger.error("Exception during interest rate oracle request processing", e) }
+            )
+            addMessageHandler(RatesFixProtocol.TOPIC_QUERY,
+                    { req: RatesFixProtocol.QueryRequest -> oracle.query(req.queries) },
+                    { message, e -> logger.error("Exception during interest rate oracle request processing", e) }
+            )
         }
 
         // File upload support
diff --git a/src/main/kotlin/core/node/services/NodeTimestamperService.kt b/src/main/kotlin/core/node/services/NodeTimestamperService.kt
index e0aa3f8e30..ad00a00ab4 100644
--- a/src/main/kotlin/core/node/services/NodeTimestamperService.kt
+++ b/src/main/kotlin/core/node/services/NodeTimestamperService.kt
@@ -5,6 +5,7 @@ import core.Party
 import core.TimestampCommand
 import core.crypto.DigitalSignature
 import core.crypto.signWithECDSA
+import core.messaging.Message
 import core.messaging.MessagingService
 import core.seconds
 import core.serialization.deserialize
@@ -24,11 +25,11 @@ import javax.annotation.concurrent.ThreadSafe
  * See the doc site to learn more about timestamping authorities (nodes) and the role they play in the data model.
  */
 @ThreadSafe
-class NodeTimestamperService(private val net: MessagingService,
+class NodeTimestamperService(net: MessagingService,
                              val identity: Party,
                              val signingKey: KeyPair,
                              val clock: Clock = Clock.systemDefaultZone(),
-                             val tolerance: Duration = 30.seconds) {
+                             val tolerance: Duration = 30.seconds) : AbstractNodeService(net) {
     companion object {
         val TIMESTAMPING_PROTOCOL_TOPIC = "platform.timestamping.request"
 
@@ -37,18 +38,16 @@ class NodeTimestamperService(private val net: MessagingService,
 
     init {
         require(identity.owningKey == signingKey.public)
-        net.addMessageHandler(TIMESTAMPING_PROTOCOL_TOPIC + ".0", null) { message, r ->
-            try {
-                val req = message.data.deserialize<TimestampingProtocol.Request>()
-                val signature = processRequest(req)
-                val msg = net.createMessage(req.replyToTopic, signature.serialize().bits)
-                net.send(msg, req.replyTo)
-            } catch(e: TimestampingError) {
-                logger.warn("Failure during timestamping request due to bad request: ${e.javaClass.name}")
-            } catch(e: Exception) {
-                logger.error("Exception during timestamping", e)
-            }
-        }
+        addMessageHandler(TIMESTAMPING_PROTOCOL_TOPIC,
+                { req: TimestampingProtocol.Request -> processRequest(req) },
+                { message, e ->
+                    if (e is TimestampingError) {
+                        logger.warn("Failure during timestamping request due to bad request: ${e.javaClass.name}")
+                    } else {
+                        logger.error("Exception during timestamping", e)
+                    }
+                }
+        )
     }
 
     @VisibleForTesting
diff --git a/src/main/kotlin/protocols/AbstractRequestMessage.kt b/src/main/kotlin/protocols/AbstractRequestMessage.kt
new file mode 100644
index 0000000000..a1ceac7675
--- /dev/null
+++ b/src/main/kotlin/protocols/AbstractRequestMessage.kt
@@ -0,0 +1,9 @@
+package protocols
+
+import core.messaging.MessageRecipients
+
+/**
+ * Abstract superclass for request messages sent to services, which includes common
+ * fields such as replyTo and replyToTopic.
+ */
+abstract class AbstractRequestMessage(val replyTo: MessageRecipients, val sessionID: Long?)
\ No newline at end of file
diff --git a/src/main/kotlin/protocols/RatesFixProtocol.kt b/src/main/kotlin/protocols/RatesFixProtocol.kt
index 270be3e80c..254aa1b23a 100644
--- a/src/main/kotlin/protocols/RatesFixProtocol.kt
+++ b/src/main/kotlin/protocols/RatesFixProtocol.kt
@@ -28,6 +28,8 @@ open class RatesFixProtocol(protected val tx: TransactionBuilder,
                             override val progressTracker: ProgressTracker = RatesFixProtocol.tracker(fixOf.name)) : ProtocolLogic<Unit>() {
     companion object {
         val TOPIC = "platform.rates.interest.fix"
+        val TOPIC_SIGN = TOPIC + ".sign"
+        val TOPIC_QUERY = TOPIC + ".query"
 
         class QUERYING(val name: String) : ProgressTracker.Step("Querying oracle for $name interest rate")
         object WORKING : ProgressTracker.Step("Working with data returned by oracle")
@@ -38,8 +40,8 @@ open class RatesFixProtocol(protected val tx: TransactionBuilder,
 
     class FixOutOfRange(val byAmount: BigDecimal) : Exception()
 
-    data class QueryRequest(val queries: List<FixOf>, val replyTo: SingleMessageRecipient, val sessionID: Long)
-    data class SignRequest(val tx: WireTransaction, val replyTo: SingleMessageRecipient, val sessionID: Long)
+    class QueryRequest(val queries: List<FixOf>, replyTo: SingleMessageRecipient, sessionID: Long) : AbstractRequestMessage(replyTo, sessionID)
+    class SignRequest(val tx: WireTransaction, replyTo: SingleMessageRecipient, sessionID: Long) : AbstractRequestMessage(replyTo, sessionID)
 
     @Suspendable
     override fun call() {
@@ -74,7 +76,7 @@ open class RatesFixProtocol(protected val tx: TransactionBuilder,
         val sessionID = random63BitValue()
         val wtx = tx.toWireTransaction()
         val req = SignRequest(wtx, serviceHub.networkService.myAddress, sessionID)
-        val resp = sendAndReceive<DigitalSignature.LegallyIdentifiable>(TOPIC + ".sign", oracle.address, 0, sessionID, req)
+        val resp = sendAndReceive<DigitalSignature.LegallyIdentifiable>(TOPIC_SIGN, oracle.address, 0, sessionID, req)
 
         return resp.validate { sig ->
             check(sig.signer == oracle.identity)
@@ -87,7 +89,7 @@ open class RatesFixProtocol(protected val tx: TransactionBuilder,
     fun query(): Fix {
         val sessionID = random63BitValue()
         val req = QueryRequest(listOf(fixOf), serviceHub.networkService.myAddress, sessionID)
-        val resp = sendAndReceive<ArrayList<Fix>>(TOPIC + ".query", oracle.address, 0, sessionID, req)
+        val resp = sendAndReceive<ArrayList<Fix>>(TOPIC_QUERY, oracle.address, 0, sessionID, req)
 
         return resp.validate {
             val fix = it.first()
diff --git a/src/main/kotlin/protocols/TimestampingProtocol.kt b/src/main/kotlin/protocols/TimestampingProtocol.kt
index e840a75818..b6ce07c5b5 100644
--- a/src/main/kotlin/protocols/TimestampingProtocol.kt
+++ b/src/main/kotlin/protocols/TimestampingProtocol.kt
@@ -47,8 +47,7 @@ class TimestampingProtocol(private val node: NodeInfo,
     override fun call(): DigitalSignature.LegallyIdentifiable {
         progressTracker.currentStep = REQUESTING
         val sessionID = random63BitValue()
-        val replyTopic = "${NodeTimestamperService.TIMESTAMPING_PROTOCOL_TOPIC}.$sessionID"
-        val req = Request(wtxBytes, serviceHub.networkService.myAddress, replyTopic)
+        val req = Request(wtxBytes, serviceHub.networkService.myAddress, sessionID)
 
         val maybeSignature = sendAndReceive<DigitalSignature.LegallyIdentifiable>(
                 NodeTimestamperService.TIMESTAMPING_PROTOCOL_TOPIC, node.address, 0, sessionID, req)
@@ -61,6 +60,5 @@ class TimestampingProtocol(private val node: NodeInfo,
         }
     }
 
-    // TODO: Improve the messaging api to have a notion of sender+replyTo topic (optional?)
-    data class Request(val tx: SerializedBytes<WireTransaction>, val replyTo: MessageRecipients, val replyToTopic: String)
+    class Request(val tx: SerializedBytes<WireTransaction>, replyTo: MessageRecipients, sessionID: Long) : AbstractRequestMessage(replyTo, sessionID)
 }
\ No newline at end of file
diff --git a/src/test/kotlin/core/node/TimestamperNodeServiceTest.kt b/src/test/kotlin/core/node/TimestamperNodeServiceTest.kt
index 1f70391b4d..5a25f77518 100644
--- a/src/test/kotlin/core/node/TimestamperNodeServiceTest.kt
+++ b/src/test/kotlin/core/node/TimestamperNodeServiceTest.kt
@@ -90,14 +90,14 @@ class TimestamperNodeServiceTest : TestWithInMemoryNetwork() {
         // Zero commands is not OK.
         assertFailsWith(TimestampingError.RequiresExactlyOneCommand::class) {
             val wtx = ptx.toWireTransaction()
-            service.processRequest(TimestampingProtocol.Request(wtx.serialize(), myMessaging.first, "ignored"))
+            service.processRequest(TimestampingProtocol.Request(wtx.serialize(), myMessaging.first, Long.MIN_VALUE))
         }
         // More than one command is not OK.
         assertFailsWith(TimestampingError.RequiresExactlyOneCommand::class) {
             ptx.addCommand(TimestampCommand(clock.instant(), 30.seconds), ALICE)
             ptx.addCommand(TimestampCommand(clock.instant(), 40.seconds), ALICE)
             val wtx = ptx.toWireTransaction()
-            service.processRequest(TimestampingProtocol.Request(wtx.serialize(), myMessaging.first, "ignored"))
+            service.processRequest(TimestampingProtocol.Request(wtx.serialize(), myMessaging.first, Long.MIN_VALUE))
         }
     }
 
@@ -107,7 +107,7 @@ class TimestamperNodeServiceTest : TestWithInMemoryNetwork() {
             val now = clock.instant()
             ptx.addCommand(TimestampCommand(now - 60.seconds, now - 40.seconds), ALICE)
             val wtx = ptx.toWireTransaction()
-            service.processRequest(TimestampingProtocol.Request(wtx.serialize(), myMessaging.first, "ignored"))
+            service.processRequest(TimestampingProtocol.Request(wtx.serialize(), myMessaging.first, Long.MIN_VALUE))
         }
     }
 
@@ -117,7 +117,7 @@ class TimestamperNodeServiceTest : TestWithInMemoryNetwork() {
             val now = clock.instant()
             ptx.addCommand(TimestampCommand(now - 60.seconds, now - 40.seconds), ALICE)
             val wtx = ptx.toWireTransaction()
-            service.processRequest(TimestampingProtocol.Request(wtx.serialize(), myMessaging.first, "ignored"))
+            service.processRequest(TimestampingProtocol.Request(wtx.serialize(), myMessaging.first, Long.MIN_VALUE))
         }
     }
 
@@ -126,7 +126,7 @@ class TimestamperNodeServiceTest : TestWithInMemoryNetwork() {
         val now = clock.instant()
         ptx.addCommand(TimestampCommand(now - 20.seconds, now + 20.seconds), ALICE)
         val wtx = ptx.toWireTransaction()
-        val sig = service.processRequest(TimestampingProtocol.Request(wtx.serialize(), myMessaging.first, "ignored"))
+        val sig = service.processRequest(TimestampingProtocol.Request(wtx.serialize(), myMessaging.first, Long.MIN_VALUE))
         ptx.checkAndAddSignature(sig)
         ptx.toSignedTransaction(false).verifySignatures()
     }