diff --git a/core/src/main/kotlin/com/r3corda/core/node/CordaPluginRegistry.kt b/core/src/main/kotlin/com/r3corda/core/node/CordaPluginRegistry.kt new file mode 100644 index 0000000000..89492ee373 --- /dev/null +++ b/core/src/main/kotlin/com/r3corda/core/node/CordaPluginRegistry.kt @@ -0,0 +1,23 @@ +package com.r3corda.core.node + +/** + * Implement this interface on a class advertised in a META-INF/services/com.r3corda.core.node.CordaPluginRegistry file + * to extend a Corda node with additional application services. + */ +interface CordaPluginRegistry { + /** + * List of JAX-RS classes inside the contract jar. They are expected to have a single parameter constructor that takes a ServiceHub as input. + * These are listed as Class<*>, because they will be instantiated inside an AttachmentClassLoader so that subsequent protocols, contracts, etc + * will be running in the appropriate isolated context. + */ + val webApis: List> + + /** + * A Map with an entry for each consumed protocol used by the webAPIs. + * The key of each map entry should contain the ProtocolLogic class name. + * The associated map values are the union of all concrete class names passed to the protocol constructor. + * Standard java.lang.* and kotlin.* types do not need to be included explicitly + * This is used to extend the white listed protocols that can be initiated from the ServiceHub invokeProtocolAsync method + */ + val requiredProtocols: Map> +} \ No newline at end of file diff --git a/core/src/main/kotlin/com/r3corda/core/node/ServiceHub.kt b/core/src/main/kotlin/com/r3corda/core/node/ServiceHub.kt index 893f83a868..d1cf7fa0fd 100644 --- a/core/src/main/kotlin/com/r3corda/core/node/ServiceHub.kt +++ b/core/src/main/kotlin/com/r3corda/core/node/ServiceHub.kt @@ -1,8 +1,10 @@ package com.r3corda.core.node +import com.google.common.util.concurrent.ListenableFuture import com.r3corda.core.contracts.* import com.r3corda.core.messaging.MessagingService import com.r3corda.core.node.services.* +import com.r3corda.core.protocols.ProtocolLogic import java.time.Clock /** @@ -61,4 +63,11 @@ interface ServiceHub { val definingTx = storageService.validatedTransactions.getTransaction(stateRef.txhash) ?: throw TransactionResolutionException(stateRef.txhash) return definingTx.tx.outputs[stateRef.index] } + + /** + * Will check [logicType] and [args] against a whitelist and if acceptable then construct and initiate the protocol. + * + * @throws IllegalProtocolLogicException or IllegalArgumentException if there are problems with the [logicType] or [args] + */ + fun invokeProtocolAsync(logicType: Class>, vararg args: Any?): ListenableFuture } \ No newline at end of file diff --git a/core/src/main/kotlin/com/r3corda/core/protocols/ProtocolLogicRef.kt b/core/src/main/kotlin/com/r3corda/core/protocols/ProtocolLogicRef.kt index 2e8bffb531..5b0154cb81 100644 --- a/core/src/main/kotlin/com/r3corda/core/protocols/ProtocolLogicRef.kt +++ b/core/src/main/kotlin/com/r3corda/core/protocols/ProtocolLogicRef.kt @@ -25,22 +25,28 @@ import kotlin.reflect.primaryConstructor * TODO: Align with API related logic for passing in ProtocolLogic references (ProtocolRef) * TODO: Actual support for AppContext / AttachmentsClassLoader */ -class ProtocolLogicRefFactory(private val protocolLogicClassNameWhitelist: Set, private val argsClassNameWhitelist: Set) : SingletonSerializeAsToken() { +class ProtocolLogicRefFactory(private val protocolWhitelist: Map>) : SingletonSerializeAsToken() { - constructor() : this(setOf(TwoPartyDealProtocol.FixingRoleDecider::class.java.name), setOf(StateRef::class.java.name, Duration::class.java.name)) + constructor() : this(mapOf(Pair(TwoPartyDealProtocol.FixingRoleDecider::class.java.name, setOf(StateRef::class.java.name, Duration::class.java.name)))) // Pending real dependence on AppContext for class loading etc @Suppress("UNUSED_PARAMETER") private fun validateProtocolClassName(className: String, appContext: AppContext) { // TODO: make this specific to the attachments in the [AppContext] by including [SecureHash] in whitelist check - require(className in protocolLogicClassNameWhitelist) { "${ProtocolLogic::class.java.simpleName} of ${ProtocolLogicRef::class.java.simpleName} must have type on the whitelist: $className" } + require(protocolWhitelist.containsKey(className)) { "${ProtocolLogic::class.java.simpleName} of ${ProtocolLogicRef::class.java.simpleName} must have type on the whitelist: $className" } } // Pending real dependence on AppContext for class loading etc @Suppress("UNUSED_PARAMETER") - private fun validateArgClassName(className: String, appContext: AppContext) { + private fun validateArgClassName(className: String, argClassName: String, appContext: AppContext) { + // TODO: consider more carefully what to whitelist and how to secure protocols + // For now automatically accept standard java.lang.* and kotlin.* types. + // All other types require manual specification at ProtocolLogicRefFactory construction time. + if (argClassName.startsWith("java.lang.") || argClassName.startsWith("kotlin.")) { + return + } // TODO: make this specific to the attachments in the [AppContext] by including [SecureHash] in whitelist check - require(className in argsClassNameWhitelist) { "Args to ${ProtocolLogicRef::class.java.simpleName} must have types on the args whitelist: $className" } + require(protocolWhitelist[className]!!.contains(argClassName)) { "Args to ${className} must have types on the args whitelist: $argClassName" } } /** @@ -90,14 +96,14 @@ class ProtocolLogicRefFactory(private val protocolLogicClassNameWhitelist: Set>, args: Map): () -> ProtocolLogic<*> { for (constructor in clazz.kotlin.constructors) { - val params = buildParams(appContext, constructor, args) ?: continue + val params = buildParams(appContext, clazz, constructor, args) ?: continue // If we get here then we matched every parameter return { constructor.callBy(params) } } throw IllegalProtocolLogicException(clazz, "as could not find matching constructor for: $args") } - private fun buildParams(appContext: AppContext, constructor: KFunction>, args: Map): HashMap? { + private fun buildParams(appContext: AppContext, clazz: Class>, constructor: KFunction>, args: Map): HashMap? { val params = hashMapOf() val usedKeys = hashSetOf() for (parameter in constructor.parameters) { @@ -111,7 +117,7 @@ class ProtocolLogicRefFactory(private val protocolLogicClassNameWhitelist: Set { - public JavaProtocolLogic(int A, String b) { + public JavaProtocolLogic(ParamType1 A, ParamType2 b) { } @Override @@ -43,13 +63,21 @@ public class ProtocolLogicRefFromJavaTest { @Test public void test() { - ProtocolLogicRefFactory factory = new ProtocolLogicRefFactory(Sets.newHashSet(JavaProtocolLogic.class.getName()), Sets.newHashSet(Integer.class.getName(), String.class.getName())); - factory.create(JavaProtocolLogic.class, 1, "Hello Jack"); + Map> whiteList = new HashMap<>(); + Set argsList = new HashSet<>(); + argsList.add(ParamType1.class.getName()); + argsList.add(ParamType2.class.getName()); + whiteList.put(JavaProtocolLogic.class.getName(), argsList); + ProtocolLogicRefFactory factory = new ProtocolLogicRefFactory(whiteList); + factory.create(JavaProtocolLogic.class, new ParamType1(1), new ParamType2("Hello Jack")); } @Test public void testNoArg() { - ProtocolLogicRefFactory factory = new ProtocolLogicRefFactory(Sets.newHashSet(JavaNoArgProtocolLogic.class.getName()), Sets.newHashSet(Integer.class.getName(), String.class.getName())); + Map> whiteList = new HashMap<>(); + Set argsList = new HashSet<>(); + whiteList.put(JavaNoArgProtocolLogic.class.getName(), argsList); + ProtocolLogicRefFactory factory = new ProtocolLogicRefFactory(whiteList); factory.create(JavaNoArgProtocolLogic.class); } } diff --git a/core/src/test/kotlin/com/r3corda/core/protocols/ProtocolLogicRefTest.kt b/core/src/test/kotlin/com/r3corda/core/protocols/ProtocolLogicRefTest.kt index bcede610b1..aa5651f12d 100644 --- a/core/src/test/kotlin/com/r3corda/core/protocols/ProtocolLogicRefTest.kt +++ b/core/src/test/kotlin/com/r3corda/core/protocols/ProtocolLogicRefTest.kt @@ -1,6 +1,5 @@ package com.r3corda.core.protocols -import com.google.common.collect.Sets import com.r3corda.core.days import org.junit.Before import org.junit.Test @@ -8,13 +7,20 @@ import java.time.Duration class ProtocolLogicRefTest { + data class ParamType1(val value: Int) + data class ParamType2(val value: String) + @Suppress("UNUSED_PARAMETER") // We will never use A or b - class KotlinProtocolLogic(A: Int, b: String) : ProtocolLogic() { - constructor() : this(1, "2") + class KotlinProtocolLogic(A: ParamType1, b: ParamType2) : ProtocolLogic() { + constructor() : this(ParamType1(1), ParamType2("2")) - constructor(C: String) : this(1, C) + constructor(C: ParamType2) : this(ParamType1(1), C) - constructor(illegal: Duration) : this(1, illegal.toString()) + constructor(illegal: Duration) : this(ParamType1(1), ParamType2(illegal.toString())) + + constructor(primitive: String) : this(ParamType1(1), ParamType2(primitive)) + + constructor(kotlinType: Int) : this(ParamType1(kotlinType), ParamType2("b")) override fun call(): Unit { } @@ -40,8 +46,8 @@ class ProtocolLogicRefTest { @Before fun setup() { // We have to allow Java boxed primitives but Kotlin warns we shouldn't be using them - factory = ProtocolLogicRefFactory(Sets.newHashSet(KotlinProtocolLogic::class.java.name, KotlinNoArgProtocolLogic::class.java.name), - Sets.newHashSet(@Suppress("PLATFORM_CLASS_MAPPED_TO_KOTLIN") Integer::class.java.name, String::class.java.name)) + factory = ProtocolLogicRefFactory(mapOf(Pair(KotlinProtocolLogic::class.java.name, setOf(ParamType1::class.java.name, ParamType2::class.java.name)), + Pair(KotlinNoArgProtocolLogic::class.java.name, setOf()))) } @Test @@ -51,18 +57,18 @@ class ProtocolLogicRefTest { @Test fun testCreateKotlin() { - val args = mapOf(Pair("A", 1), Pair("b", "Hello Jack")) + val args = mapOf(Pair("A", ParamType1(1)), Pair("b", ParamType2("Hello Jack"))) factory.createKotlin(KotlinProtocolLogic::class.java, args) } @Test fun testCreatePrimary() { - factory.create(KotlinProtocolLogic::class.java, 1, "Hello Jack") + factory.create(KotlinProtocolLogic::class.java, ParamType1(1), ParamType2("Hello Jack")) } @Test(expected = IllegalArgumentException::class) fun testCreateNotWhiteListed() { - factory.create(NotWhiteListedKotlinProtocolLogic::class.java, 1, "Hello Jack") + factory.create(NotWhiteListedKotlinProtocolLogic::class.java, ParamType1(1), ParamType2("Hello Jack")) } @Test @@ -72,7 +78,7 @@ class ProtocolLogicRefTest { @Test fun testCreateKotlinNonPrimary() { - val args = mapOf(Pair("C", "Hello Jack")) + val args = mapOf(Pair("C", ParamType2("Hello Jack"))) factory.createKotlin(KotlinProtocolLogic::class.java, args) } @@ -81,4 +87,17 @@ class ProtocolLogicRefTest { val args = mapOf(Pair("illegal", 1.days)) factory.createKotlin(KotlinProtocolLogic::class.java, args) } + + @Test + fun testCreateJavaPrimitiveNoRegistrationRequired() { + val args = mapOf(Pair("primitive", "A string")) + factory.createKotlin(KotlinProtocolLogic::class.java, args) + } + + @Test + fun testCreateKotlinPrimitiveNoRegistrationRequired() { + val args = mapOf(Pair("kotlinType", 3)) + factory.createKotlin(KotlinProtocolLogic::class.java, args) + } + } diff --git a/node/src/main/kotlin/com/r3corda/node/internal/AbstractNode.kt b/node/src/main/kotlin/com/r3corda/node/internal/AbstractNode.kt index 7838807b5a..ae4276cbeb 100644 --- a/node/src/main/kotlin/com/r3corda/node/internal/AbstractNode.kt +++ b/node/src/main/kotlin/com/r3corda/node/internal/AbstractNode.kt @@ -5,10 +5,12 @@ import com.google.common.util.concurrent.ListenableFuture import com.google.common.util.concurrent.SettableFuture import com.r3corda.core.RunOnCallerThread import com.r3corda.core.contracts.SignedTransaction +import com.r3corda.core.contracts.StateRef import com.r3corda.core.crypto.Party import com.r3corda.core.messaging.MessagingService import com.r3corda.core.messaging.runOnNextMessage import com.r3corda.core.node.CityDatabase +import com.r3corda.core.node.CordaPluginRegistry import com.r3corda.core.node.NodeInfo import com.r3corda.core.node.PhysicalLocation import com.r3corda.core.node.services.* @@ -46,12 +48,14 @@ import com.r3corda.node.services.wallet.NodeWalletService import com.r3corda.node.utilities.ANSIProgressObserver import com.r3corda.node.utilities.AddOrRemove import com.r3corda.node.utilities.AffinityExecutor +import com.r3corda.protocols.TwoPartyDealProtocol import org.slf4j.Logger import java.nio.file.FileAlreadyExistsException import java.nio.file.Files import java.nio.file.Path import java.security.KeyPair import java.time.Clock +import java.time.Duration import java.util.* /** @@ -97,7 +101,7 @@ abstract class AbstractNode(val dir: Path, val configuration: NodeConfiguration, // Internal only override val monitoringService: MonitoringService = MonitoringService(MetricRegistry()) - override val protocolLogicRefFactory = ProtocolLogicRefFactory() + override val protocolLogicRefFactory: ProtocolLogicRefFactory get() = protocolLogicFactory override fun startProtocol(loggerName: String, logic: ProtocolLogic): ListenableFuture { return smm.add(loggerName, logic) @@ -124,6 +128,7 @@ abstract class AbstractNode(val dir: Path, val configuration: NodeConfiguration, lateinit var net: MessagingService lateinit var api: APIServer lateinit var scheduler: SchedulerService + lateinit var protocolLogicFactory: ProtocolLogicRefFactory var isPreviousCheckpointsPresent = false private set @@ -132,6 +137,11 @@ abstract class AbstractNode(val dir: Path, val configuration: NodeConfiguration, val networkMapRegistrationFuture: ListenableFuture get() = _networkMapRegistrationFuture + /** Fetch CordaPluginRegistry classes registered in META-INF/services/com.r3corda.core.node.CordaPluginRegistry files that exist in the classpath */ + protected val pluginRegistries: List by lazy { + ServiceLoader.load(CordaPluginRegistry::class.java).toList() + } + /** Set to true once [start] has been successfully called. */ @Volatile var started = false private set @@ -158,6 +168,8 @@ abstract class AbstractNode(val dir: Path, val configuration: NodeConfiguration, checkpointStorage, serverThread) + protocolLogicFactory = initialiseProtocolLogicFactory() + // This object doesn't need to be referenced from this class because it registers handlers on the network // service and so that keeps it from being collected. DataVendingService(net, storage, services.networkMapCache) @@ -180,6 +192,18 @@ abstract class AbstractNode(val dir: Path, val configuration: NodeConfiguration, return this } + private fun initialiseProtocolLogicFactory(): ProtocolLogicRefFactory { + val protocolWhitelist = HashMap>() + for (plugin in pluginRegistries) { + for (protocol in plugin.requiredProtocols) { + protocolWhitelist.merge(protocol.key, protocol.value, { x, y -> x + y }) + } + } + + return ProtocolLogicRefFactory(protocolWhitelist) + } + + /** * Run any tasks that are needed to ensure the node is in a correct state before running start() */ diff --git a/node/src/main/kotlin/com/r3corda/node/internal/Node.kt b/node/src/main/kotlin/com/r3corda/node/internal/Node.kt index 85cfac29a3..1c5c167b76 100644 --- a/node/src/main/kotlin/com/r3corda/node/internal/Node.kt +++ b/node/src/main/kotlin/com/r3corda/node/internal/Node.kt @@ -3,10 +3,11 @@ package com.r3corda.node.internal import com.codahale.metrics.JmxReporter import com.google.common.net.HostAndPort import com.r3corda.core.messaging.MessagingService +import com.r3corda.core.node.CordaPluginRegistry import com.r3corda.core.node.NodeInfo +import com.r3corda.core.node.ServiceHub import com.r3corda.core.node.services.ServiceType import com.r3corda.core.utilities.loggerFor -import com.r3corda.node.api.APIServer import com.r3corda.node.serialization.NodeClock import com.r3corda.node.services.config.NodeConfiguration import com.r3corda.node.services.messaging.ArtemisMessagingService @@ -27,9 +28,9 @@ import java.io.RandomAccessFile import java.lang.management.ManagementFactory import java.net.InetSocketAddress import java.nio.channels.FileLock -import java.nio.file.Files import java.nio.file.Path import java.time.Clock +import java.util.* import javax.management.ObjectName class ConfigurationException(message: String) : Exception(message) @@ -55,8 +56,7 @@ class ConfigurationException(message: String) : Exception(message) */ class Node(dir: Path, val p2pAddr: HostAndPort, val webServerAddr: HostAndPort, configuration: NodeConfiguration, networkMapAddress: NodeInfo?, advertisedServices: Set, - clock: Clock = NodeClock(), - val clientAPIs: List> = listOf()) : AbstractNode(dir, configuration, networkMapAddress, advertisedServices, clock) { + clock: Clock = NodeClock()) : AbstractNode(dir, configuration, networkMapAddress, advertisedServices, clock) { companion object { /** The port that is used by default if none is specified. As you know, 31337 is the most elite number. */ val DEFAULT_PORT = 31337 @@ -109,12 +109,13 @@ class Node(dir: Path, val p2pAddr: HostAndPort, val webServerAddr: HostAndPort, resourceConfig.register(ResponseFilter()) resourceConfig.register(api) - for(customAPIClass in clientAPIs) { - val customAPI = customAPIClass.getConstructor(APIServer::class.java).newInstance(api) + val webAPIsOnClasspath = pluginRegistries.flatMap { x -> x.webApis } + for (webapi in webAPIsOnClasspath) { + log.info("Add Plugin web API from attachment ${webapi.name}") + val customAPI = webapi.getConstructor(ServiceHub::class.java).newInstance(services) resourceConfig.register(customAPI) } - // Give the app a slightly better name in JMX rather than a randomly generated one and enable JMX resourceConfig.addProperties(mapOf(ServerProperties.APPLICATION_NAME to "node.api", ServerProperties.MONITORING_STATISTICS_MBEANS_ENABLED to "true")) @@ -187,5 +188,5 @@ class Node(dir: Path, val p2pAddr: HostAndPort, val webServerAddr: HostAndPort, val ourProcessID: String = ManagementFactory.getRuntimeMXBean().name.split("@")[0] f.setLength(0) f.write(ourProcessID.toByteArray()) - } + } } diff --git a/node/src/main/kotlin/com/r3corda/node/services/api/ServiceHubInternal.kt b/node/src/main/kotlin/com/r3corda/node/services/api/ServiceHubInternal.kt index e206a28644..d60c20ab67 100644 --- a/node/src/main/kotlin/com/r3corda/node/services/api/ServiceHubInternal.kt +++ b/node/src/main/kotlin/com/r3corda/node/services/api/ServiceHubInternal.kt @@ -29,4 +29,11 @@ abstract class ServiceHubInternal : ServiceHub { * itself, at which point this method would not be needed (by the scheduler) */ abstract fun startProtocol(loggerName: String, logic: ProtocolLogic): ListenableFuture + + override fun invokeProtocolAsync(logicType: Class>, vararg args: Any?): ListenableFuture { + val logicRef = protocolLogicRefFactory.create(logicType, *args) + @Suppress("UNCHECKED_CAST") + val logic = protocolLogicRefFactory.toProtocolLogic(logicRef) as ProtocolLogic + return startProtocol(logicType.simpleName, logic) + } } \ No newline at end of file diff --git a/node/src/main/kotlin/com/r3corda/node/services/clientapi/NodeInterestRates.kt b/node/src/main/kotlin/com/r3corda/node/services/clientapi/NodeInterestRates.kt index 94d4c9dbeb..d69dd39e2a 100644 --- a/node/src/main/kotlin/com/r3corda/node/services/clientapi/NodeInterestRates.kt +++ b/node/src/main/kotlin/com/r3corda/node/services/clientapi/NodeInterestRates.kt @@ -9,6 +9,7 @@ import com.r3corda.core.crypto.signWithECDSA import com.r3corda.core.math.CubicSplineInterpolator import com.r3corda.core.math.Interpolator import com.r3corda.core.math.InterpolatorFactory +import com.r3corda.core.node.CordaPluginRegistry import com.r3corda.core.node.services.ServiceType import com.r3corda.core.protocols.ProtocolLogic import com.r3corda.core.utilities.ProgressTracker @@ -18,11 +19,13 @@ import com.r3corda.node.services.api.AcceptsFileUpload import com.r3corda.node.utilities.FiberBox import com.r3corda.protocols.RatesFixProtocol import com.r3corda.protocols.ServiceRequestMessage +import com.r3corda.protocols.TwoPartyDealProtocol import org.slf4j.LoggerFactory import java.io.InputStream import java.math.BigDecimal import java.security.KeyPair import java.time.Clock +import java.time.Duration import java.time.Instant import java.time.LocalDate import java.util.* @@ -93,6 +96,15 @@ object NodeInterestRates { } } + /** + * Register the protocol that is used with the Fixing integration tests + */ + class FixingServicePlugin : CordaPluginRegistry { + override val webApis: List> = emptyList() + override val requiredProtocols: Map> = mapOf(Pair(TwoPartyDealProtocol.FixingRoleDecider::class.java.name, setOf(Duration::class.java.name, StateRef::class.java.name))) + + } + // File upload support override val dataTypePrefix = "interest-rates" override val acceptableFileExtensions = listOf(".rates", ".txt") diff --git a/node/src/main/resources/META-INF/services/com.r3corda.core.node.CordaPluginRegistry b/node/src/main/resources/META-INF/services/com.r3corda.core.node.CordaPluginRegistry new file mode 100644 index 0000000000..6c27482af8 --- /dev/null +++ b/node/src/main/resources/META-INF/services/com.r3corda.core.node.CordaPluginRegistry @@ -0,0 +1,2 @@ +# Register a ServiceLoader service extending from com.r3corda.node.CordaPluginRegistry +com.r3corda.node.services.clientapi.NodeInterestRates$Service$FixingServicePlugin \ No newline at end of file diff --git a/node/src/test/kotlin/com/r3corda/node/services/NodeSchedulerServiceTest.kt b/node/src/test/kotlin/com/r3corda/node/services/NodeSchedulerServiceTest.kt index 566d2aff3f..c2f20ec8b3 100644 --- a/node/src/test/kotlin/com/r3corda/node/services/NodeSchedulerServiceTest.kt +++ b/node/src/test/kotlin/com/r3corda/node/services/NodeSchedulerServiceTest.kt @@ -43,7 +43,7 @@ class NodeSchedulerServiceTest : SingletonSerializeAsToken() { // We have to allow Java boxed primitives but Kotlin warns we shouldn't be using them @Suppress("PLATFORM_CLASS_MAPPED_TO_KOTLIN") - val factory = ProtocolLogicRefFactory(setOf(TestProtocolLogic::class.java.name), setOf(NodeSchedulerServiceTest::class.java.name, Integer::class.java.name)) + val factory = ProtocolLogicRefFactory(mapOf(Pair(TestProtocolLogic::class.java.name, setOf(NodeSchedulerServiceTest::class.java.name, Integer::class.java.name)))) val scheduler: NodeSchedulerService val services: ServiceHub diff --git a/src/main/kotlin/com/r3corda/demos/IRSDemo.kt b/src/main/kotlin/com/r3corda/demos/IRSDemo.kt index c2a60d4609..d3824c39e0 100644 --- a/src/main/kotlin/com/r3corda/demos/IRSDemo.kt +++ b/src/main/kotlin/com/r3corda/demos/IRSDemo.kt @@ -1,9 +1,11 @@ package com.r3corda.demos import com.google.common.net.HostAndPort +import com.r3corda.contracts.InterestRateSwap import com.r3corda.core.crypto.Party import com.r3corda.core.logElapsedTime import com.r3corda.core.messaging.SingleMessageRecipient +import com.r3corda.core.node.CordaPluginRegistry import com.r3corda.core.node.NodeInfo import com.r3corda.core.node.services.ServiceType import com.r3corda.core.serialization.deserialize @@ -241,6 +243,14 @@ object CliParamsSpec { val nonOptions = parser.nonOptions() } +class IRSDemoPluginRegistry : CordaPluginRegistry { + override val webApis: List> = listOf(InterestRateSwapAPI::class.java) + override val requiredProtocols: Map> = mapOf( + Pair(AutoOfferProtocol.Requester::class.java.name, setOf(InterestRateSwap.State::class.java.name)), + Pair(UpdateBusinessDayProtocol.Broadcast::class.java.name, setOf(java.time.LocalDate::class.java.name)), + Pair(ExitServerProtocol.Broadcast::class.java.name, setOf(kotlin.Int::class.java.name))) +} + private class NotSetupException: Throwable { constructor(message: String): super(message) {} } @@ -374,8 +384,7 @@ private fun startNode(params: CliParams.RunNode, networkMap: SingleMessageRecipi } val node = logElapsedTime("Node startup") { - Node(params.dir, params.networkAddress, params.apiAddress, config, networkMapId, advertisedServices, DemoClock(), - listOf(InterestRateSwapAPI::class.java)).start() + Node(params.dir, params.networkAddress, params.apiAddress, config, networkMapId, advertisedServices, DemoClock()).start() } // TODO: This should all be replaced by the identity service being updated diff --git a/src/main/kotlin/com/r3corda/demos/RateFixDemo.kt b/src/main/kotlin/com/r3corda/demos/RateFixDemo.kt index 96473b033c..1374c58dd4 100644 --- a/src/main/kotlin/com/r3corda/demos/RateFixDemo.kt +++ b/src/main/kotlin/com/r3corda/demos/RateFixDemo.kt @@ -76,8 +76,7 @@ fun main(args: Array) { val apiAddr = HostAndPort.fromParts(myNetAddr.hostText, myNetAddr.port + 1) val node = logElapsedTime("Node startup") { Node(dir, myNetAddr, apiAddr, config, networkMapAddress, - advertisedServices, DemoClock(), - listOf(InterestRateSwapAPI::class.java)).setup().start() } + advertisedServices, DemoClock()).setup().start() } val notary = node.services.networkMapCache.notaryNodes[0] diff --git a/src/main/kotlin/com/r3corda/demos/api/InterestRateSwapAPI.kt b/src/main/kotlin/com/r3corda/demos/api/InterestRateSwapAPI.kt index 9197216da7..c5e3574360 100644 --- a/src/main/kotlin/com/r3corda/demos/api/InterestRateSwapAPI.kt +++ b/src/main/kotlin/com/r3corda/demos/api/InterestRateSwapAPI.kt @@ -1,15 +1,16 @@ package com.r3corda.demos.api import com.r3corda.contracts.InterestRateSwap +import com.r3corda.core.contracts.SignedTransaction +import com.r3corda.core.node.ServiceHub +import com.r3corda.core.node.services.linearHeadsOfType import com.r3corda.core.utilities.loggerFor import com.r3corda.demos.protocols.AutoOfferProtocol import com.r3corda.demos.protocols.ExitServerProtocol import com.r3corda.demos.protocols.UpdateBusinessDayProtocol -import com.r3corda.node.api.APIServer -import com.r3corda.node.api.ProtocolClassRef -import com.r3corda.node.api.StatesQuery import java.net.URI import java.time.LocalDate +import java.time.LocalDateTime import javax.ws.rs.* import javax.ws.rs.core.MediaType import javax.ws.rs.core.Response @@ -35,23 +36,23 @@ import javax.ws.rs.core.Response * or if the demodate or population of deals should be reset (will only work while persistence is disabled). */ @Path("irs") -class InterestRateSwapAPI(val api: APIServer) { +class InterestRateSwapAPI(val services: ServiceHub) { private val logger = loggerFor() private fun generateDealLink(deal: InterestRateSwap.State) = "/api/irs/deals/" + deal.common.tradeID private fun getDealByRef(ref: String): InterestRateSwap.State? { - val states = api.queryStates(StatesQuery.selectDeal(ref)) + val states = services.walletService.linearHeadsOfType().filterValues { it.state.data.ref == ref } return if (states.isEmpty()) null else { - val deals = api.fetchStates(states).values.map { it?.data as InterestRateSwap.State }.filterNotNull() + val deals = states.values.map { it.state.data } return if (deals.isEmpty()) null else deals[0] } } private fun getAllDeals(): Array { - val states = api.queryStates(StatesQuery.selectAllDeals()) - val swaps = api.fetchStates(states).values.map { it?.data as InterestRateSwap.State }.filterNotNull().toTypedArray() + val states = services.walletService.linearHeadsOfType() + val swaps = states.values.map { it.state.data }.toTypedArray() return swaps } @@ -64,7 +65,7 @@ class InterestRateSwapAPI(val api: APIServer) { @Path("deals") @Consumes(MediaType.APPLICATION_JSON) fun storeDeal(newDeal: InterestRateSwap.State): Response { - api.invokeProtocolSync(ProtocolClassRef(AutoOfferProtocol.Requester::class.java.name!!), mapOf("dealToBeOffered" to newDeal)) + services.invokeProtocolAsync(AutoOfferProtocol.Requester::class.java, newDeal).get() return Response.created(URI.create(generateDealLink(newDeal))).build() } @@ -84,10 +85,10 @@ class InterestRateSwapAPI(val api: APIServer) { @Path("demodate") @Consumes(MediaType.APPLICATION_JSON) fun storeDemoDate(newDemoDate: LocalDate): Response { - val priorDemoDate = api.serverTime().toLocalDate() + val priorDemoDate = fetchDemoDate() // Can only move date forwards if (newDemoDate.isAfter(priorDemoDate)) { - api.invokeProtocolSync(ProtocolClassRef(UpdateBusinessDayProtocol.Broadcast::class.java.name!!), mapOf("date" to newDemoDate)) + services.invokeProtocolAsync(UpdateBusinessDayProtocol.Broadcast::class.java, newDemoDate).get() return Response.ok().build() } val msg = "demodate is already $priorDemoDate and can only be updated with a later date" @@ -99,14 +100,14 @@ class InterestRateSwapAPI(val api: APIServer) { @Path("demodate") @Produces(MediaType.APPLICATION_JSON) fun fetchDemoDate(): LocalDate { - return api.serverTime().toLocalDate() + return LocalDateTime.now(services.clock).toLocalDate() } @PUT @Path("restart") @Consumes(MediaType.APPLICATION_JSON) fun exitServer(): Response { - api.invokeProtocolSync(ProtocolClassRef(ExitServerProtocol.Broadcast::class.java.name!!), mapOf("exitCode" to 83)) + services.invokeProtocolAsync(ExitServerProtocol.Broadcast::class.java, 83).get() return Response.ok().build() } } diff --git a/src/main/kotlin/com/r3corda/demos/protocols/ExitServerProtocol.kt b/src/main/kotlin/com/r3corda/demos/protocols/ExitServerProtocol.kt index 007aebd7e2..5258e3dead 100644 --- a/src/main/kotlin/com/r3corda/demos/protocols/ExitServerProtocol.kt +++ b/src/main/kotlin/com/r3corda/demos/protocols/ExitServerProtocol.kt @@ -37,22 +37,21 @@ object ExitServerProtocol { * This takes a Java Integer rather than Kotlin Int as that is what we end up with in the calling map and currently * we do not support coercing numeric types in the reflective search for matching constructors */ - class Broadcast(@Suppress("PLATFORM_CLASS_MAPPED_TO_KOTLIN") val exitCode: Integer) : ProtocolLogic() { + class Broadcast(val exitCode: Int) : ProtocolLogic() { override val topic: String get() = TOPIC @Suspendable override fun call(): Boolean { if (enabled) { - val rc = exitCode.toInt() - val message = ExitMessage(rc) + val message = ExitMessage(exitCode) for (recipient in serviceHub.networkMapCache.partyNodes) { doNextRecipient(recipient, message) } // Sleep a little in case any async message delivery to other nodes needs to happen Strand.sleep(1, TimeUnit.SECONDS) - System.exit(rc) + System.exit(exitCode) } return enabled } diff --git a/src/main/resources/META-INF/services/com.r3corda.core.node.CordaPluginRegistry b/src/main/resources/META-INF/services/com.r3corda.core.node.CordaPluginRegistry new file mode 100644 index 0000000000..279300f9f9 --- /dev/null +++ b/src/main/resources/META-INF/services/com.r3corda.core.node.CordaPluginRegistry @@ -0,0 +1,2 @@ +# Register a ServiceLoader service extending from com.r3corda.node.CordaPluginRegistry +com.r3corda.demos.IRSDemoPluginRegistry \ No newline at end of file