Merged in mnesbit-cor-249-serviceloader-for-web-apis (pull request #204)

Make the IRS Demo web api an api plugin (scanned from the Node classpath) and use the same permission checking entry point for web api's as the scheduler.
This commit is contained in:
Matthew Nesbit 2016-07-08 10:47:02 +01:00
commit fb2efd8fc1
16 changed files with 196 additions and 55 deletions

View File

@ -0,0 +1,23 @@
package com.r3corda.core.node
/**
* Implement this interface on a class advertised in a META-INF/services/com.r3corda.core.node.CordaPluginRegistry file
* to extend a Corda node with additional application services.
*/
interface CordaPluginRegistry {
/**
* List of JAX-RS classes inside the contract jar. They are expected to have a single parameter constructor that takes a ServiceHub as input.
* These are listed as Class<*>, because they will be instantiated inside an AttachmentClassLoader so that subsequent protocols, contracts, etc
* will be running in the appropriate isolated context.
*/
val webApis: List<Class<*>>
/**
* A Map with an entry for each consumed protocol used by the webAPIs.
* The key of each map entry should contain the ProtocolLogic<T> class name.
* The associated map values are the union of all concrete class names passed to the protocol constructor.
* Standard java.lang.* and kotlin.* types do not need to be included explicitly
* This is used to extend the white listed protocols that can be initiated from the ServiceHub invokeProtocolAsync method
*/
val requiredProtocols: Map<String, Set<String>>
}

View File

@ -1,8 +1,10 @@
package com.r3corda.core.node
import com.google.common.util.concurrent.ListenableFuture
import com.r3corda.core.contracts.*
import com.r3corda.core.messaging.MessagingService
import com.r3corda.core.node.services.*
import com.r3corda.core.protocols.ProtocolLogic
import java.time.Clock
/**
@ -61,4 +63,11 @@ interface ServiceHub {
val definingTx = storageService.validatedTransactions.getTransaction(stateRef.txhash) ?: throw TransactionResolutionException(stateRef.txhash)
return definingTx.tx.outputs[stateRef.index]
}
/**
* Will check [logicType] and [args] against a whitelist and if acceptable then construct and initiate the protocol.
*
* @throws IllegalProtocolLogicException or IllegalArgumentException if there are problems with the [logicType] or [args]
*/
fun <T : Any> invokeProtocolAsync(logicType: Class<out ProtocolLogic<T>>, vararg args: Any?): ListenableFuture<T>
}

View File

@ -25,22 +25,28 @@ import kotlin.reflect.primaryConstructor
* TODO: Align with API related logic for passing in ProtocolLogic references (ProtocolRef)
* TODO: Actual support for AppContext / AttachmentsClassLoader
*/
class ProtocolLogicRefFactory(private val protocolLogicClassNameWhitelist: Set<String>, private val argsClassNameWhitelist: Set<String>) : SingletonSerializeAsToken() {
class ProtocolLogicRefFactory(private val protocolWhitelist: Map<String, Set<String>>) : SingletonSerializeAsToken() {
constructor() : this(setOf(TwoPartyDealProtocol.FixingRoleDecider::class.java.name), setOf(StateRef::class.java.name, Duration::class.java.name))
constructor() : this(mapOf(Pair(TwoPartyDealProtocol.FixingRoleDecider::class.java.name, setOf(StateRef::class.java.name, Duration::class.java.name))))
// Pending real dependence on AppContext for class loading etc
@Suppress("UNUSED_PARAMETER")
private fun validateProtocolClassName(className: String, appContext: AppContext) {
// TODO: make this specific to the attachments in the [AppContext] by including [SecureHash] in whitelist check
require(className in protocolLogicClassNameWhitelist) { "${ProtocolLogic::class.java.simpleName} of ${ProtocolLogicRef::class.java.simpleName} must have type on the whitelist: $className" }
require(protocolWhitelist.containsKey(className)) { "${ProtocolLogic::class.java.simpleName} of ${ProtocolLogicRef::class.java.simpleName} must have type on the whitelist: $className" }
}
// Pending real dependence on AppContext for class loading etc
@Suppress("UNUSED_PARAMETER")
private fun validateArgClassName(className: String, appContext: AppContext) {
private fun validateArgClassName(className: String, argClassName: String, appContext: AppContext) {
// TODO: consider more carefully what to whitelist and how to secure protocols
// For now automatically accept standard java.lang.* and kotlin.* types.
// All other types require manual specification at ProtocolLogicRefFactory construction time.
if (argClassName.startsWith("java.lang.") || argClassName.startsWith("kotlin.")) {
return
}
// TODO: make this specific to the attachments in the [AppContext] by including [SecureHash] in whitelist check
require(className in argsClassNameWhitelist) { "Args to ${ProtocolLogicRef::class.java.simpleName} must have types on the args whitelist: $className" }
require(protocolWhitelist[className]!!.contains(argClassName)) { "Args to ${className} must have types on the args whitelist: $argClassName" }
}
/**
@ -90,14 +96,14 @@ class ProtocolLogicRefFactory(private val protocolLogicClassNameWhitelist: Set<S
private fun createConstructor(appContext: AppContext, clazz: Class<out ProtocolLogic<*>>, args: Map<String, Any?>): () -> ProtocolLogic<*> {
for (constructor in clazz.kotlin.constructors) {
val params = buildParams(appContext, constructor, args) ?: continue
val params = buildParams(appContext, clazz, constructor, args) ?: continue
// If we get here then we matched every parameter
return { constructor.callBy(params) }
}
throw IllegalProtocolLogicException(clazz, "as could not find matching constructor for: $args")
}
private fun buildParams(appContext: AppContext, constructor: KFunction<ProtocolLogic<*>>, args: Map<String, Any?>): HashMap<KParameter, Any?>? {
private fun buildParams(appContext: AppContext, clazz: Class<out ProtocolLogic<*>>, constructor: KFunction<ProtocolLogic<*>>, args: Map<String, Any?>): HashMap<KParameter, Any?>? {
val params = hashMapOf<KParameter, Any?>()
val usedKeys = hashSetOf<String>()
for (parameter in constructor.parameters) {
@ -111,7 +117,7 @@ class ProtocolLogicRefFactory(private val protocolLogicClassNameWhitelist: Set<S
// Not all args were used
return null
}
params.values.forEach { if (it is Any) validateArgClassName(it.javaClass.name, appContext) }
params.values.forEach { if (it is Any) validateArgClassName(clazz.name, it.javaClass.name, appContext) }
return params
}

View File

@ -1,15 +1,35 @@
package com.r3corda.core.protocols;
import com.google.common.collect.Sets;
import org.jetbrains.annotations.NotNull;
import org.junit.Test;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
public class ProtocolLogicRefFromJavaTest {
public static class ParamType1 {
public final int value;
ParamType1(int v) {
value = v;
}
}
public static class ParamType2 {
public final String value;
ParamType2(String v) {
value = v;
}
}
public static class JavaProtocolLogic extends ProtocolLogic<Void> {
public JavaProtocolLogic(int A, String b) {
public JavaProtocolLogic(ParamType1 A, ParamType2 b) {
}
@Override
@ -43,13 +63,21 @@ public class ProtocolLogicRefFromJavaTest {
@Test
public void test() {
ProtocolLogicRefFactory factory = new ProtocolLogicRefFactory(Sets.newHashSet(JavaProtocolLogic.class.getName()), Sets.newHashSet(Integer.class.getName(), String.class.getName()));
factory.create(JavaProtocolLogic.class, 1, "Hello Jack");
Map<String, Set<String>> whiteList = new HashMap<>();
Set<String> argsList = new HashSet<>();
argsList.add(ParamType1.class.getName());
argsList.add(ParamType2.class.getName());
whiteList.put(JavaProtocolLogic.class.getName(), argsList);
ProtocolLogicRefFactory factory = new ProtocolLogicRefFactory(whiteList);
factory.create(JavaProtocolLogic.class, new ParamType1(1), new ParamType2("Hello Jack"));
}
@Test
public void testNoArg() {
ProtocolLogicRefFactory factory = new ProtocolLogicRefFactory(Sets.newHashSet(JavaNoArgProtocolLogic.class.getName()), Sets.newHashSet(Integer.class.getName(), String.class.getName()));
Map<String, Set<String>> whiteList = new HashMap<>();
Set<String> argsList = new HashSet<>();
whiteList.put(JavaNoArgProtocolLogic.class.getName(), argsList);
ProtocolLogicRefFactory factory = new ProtocolLogicRefFactory(whiteList);
factory.create(JavaNoArgProtocolLogic.class);
}
}

View File

@ -1,6 +1,5 @@
package com.r3corda.core.protocols
import com.google.common.collect.Sets
import com.r3corda.core.days
import org.junit.Before
import org.junit.Test
@ -8,13 +7,20 @@ import java.time.Duration
class ProtocolLogicRefTest {
data class ParamType1(val value: Int)
data class ParamType2(val value: String)
@Suppress("UNUSED_PARAMETER") // We will never use A or b
class KotlinProtocolLogic(A: Int, b: String) : ProtocolLogic<Unit>() {
constructor() : this(1, "2")
class KotlinProtocolLogic(A: ParamType1, b: ParamType2) : ProtocolLogic<Unit>() {
constructor() : this(ParamType1(1), ParamType2("2"))
constructor(C: String) : this(1, C)
constructor(C: ParamType2) : this(ParamType1(1), C)
constructor(illegal: Duration) : this(1, illegal.toString())
constructor(illegal: Duration) : this(ParamType1(1), ParamType2(illegal.toString()))
constructor(primitive: String) : this(ParamType1(1), ParamType2(primitive))
constructor(kotlinType: Int) : this(ParamType1(kotlinType), ParamType2("b"))
override fun call(): Unit {
}
@ -40,8 +46,8 @@ class ProtocolLogicRefTest {
@Before
fun setup() {
// We have to allow Java boxed primitives but Kotlin warns we shouldn't be using them
factory = ProtocolLogicRefFactory(Sets.newHashSet(KotlinProtocolLogic::class.java.name, KotlinNoArgProtocolLogic::class.java.name),
Sets.newHashSet(@Suppress("PLATFORM_CLASS_MAPPED_TO_KOTLIN") Integer::class.java.name, String::class.java.name))
factory = ProtocolLogicRefFactory(mapOf(Pair(KotlinProtocolLogic::class.java.name, setOf(ParamType1::class.java.name, ParamType2::class.java.name)),
Pair(KotlinNoArgProtocolLogic::class.java.name, setOf())))
}
@Test
@ -51,18 +57,18 @@ class ProtocolLogicRefTest {
@Test
fun testCreateKotlin() {
val args = mapOf(Pair("A", 1), Pair("b", "Hello Jack"))
val args = mapOf(Pair("A", ParamType1(1)), Pair("b", ParamType2("Hello Jack")))
factory.createKotlin(KotlinProtocolLogic::class.java, args)
}
@Test
fun testCreatePrimary() {
factory.create(KotlinProtocolLogic::class.java, 1, "Hello Jack")
factory.create(KotlinProtocolLogic::class.java, ParamType1(1), ParamType2("Hello Jack"))
}
@Test(expected = IllegalArgumentException::class)
fun testCreateNotWhiteListed() {
factory.create(NotWhiteListedKotlinProtocolLogic::class.java, 1, "Hello Jack")
factory.create(NotWhiteListedKotlinProtocolLogic::class.java, ParamType1(1), ParamType2("Hello Jack"))
}
@Test
@ -72,7 +78,7 @@ class ProtocolLogicRefTest {
@Test
fun testCreateKotlinNonPrimary() {
val args = mapOf(Pair("C", "Hello Jack"))
val args = mapOf(Pair("C", ParamType2("Hello Jack")))
factory.createKotlin(KotlinProtocolLogic::class.java, args)
}
@ -81,4 +87,17 @@ class ProtocolLogicRefTest {
val args = mapOf(Pair("illegal", 1.days))
factory.createKotlin(KotlinProtocolLogic::class.java, args)
}
@Test
fun testCreateJavaPrimitiveNoRegistrationRequired() {
val args = mapOf(Pair("primitive", "A string"))
factory.createKotlin(KotlinProtocolLogic::class.java, args)
}
@Test
fun testCreateKotlinPrimitiveNoRegistrationRequired() {
val args = mapOf(Pair("kotlinType", 3))
factory.createKotlin(KotlinProtocolLogic::class.java, args)
}
}

View File

@ -5,10 +5,12 @@ import com.google.common.util.concurrent.ListenableFuture
import com.google.common.util.concurrent.SettableFuture
import com.r3corda.core.RunOnCallerThread
import com.r3corda.core.contracts.SignedTransaction
import com.r3corda.core.contracts.StateRef
import com.r3corda.core.crypto.Party
import com.r3corda.core.messaging.MessagingService
import com.r3corda.core.messaging.runOnNextMessage
import com.r3corda.core.node.CityDatabase
import com.r3corda.core.node.CordaPluginRegistry
import com.r3corda.core.node.NodeInfo
import com.r3corda.core.node.PhysicalLocation
import com.r3corda.core.node.services.*
@ -46,12 +48,14 @@ import com.r3corda.node.services.wallet.NodeWalletService
import com.r3corda.node.utilities.ANSIProgressObserver
import com.r3corda.node.utilities.AddOrRemove
import com.r3corda.node.utilities.AffinityExecutor
import com.r3corda.protocols.TwoPartyDealProtocol
import org.slf4j.Logger
import java.nio.file.FileAlreadyExistsException
import java.nio.file.Files
import java.nio.file.Path
import java.security.KeyPair
import java.time.Clock
import java.time.Duration
import java.util.*
/**
@ -97,7 +101,7 @@ abstract class AbstractNode(val dir: Path, val configuration: NodeConfiguration,
// Internal only
override val monitoringService: MonitoringService = MonitoringService(MetricRegistry())
override val protocolLogicRefFactory = ProtocolLogicRefFactory()
override val protocolLogicRefFactory: ProtocolLogicRefFactory get() = protocolLogicFactory
override fun <T> startProtocol(loggerName: String, logic: ProtocolLogic<T>): ListenableFuture<T> {
return smm.add(loggerName, logic)
@ -124,6 +128,7 @@ abstract class AbstractNode(val dir: Path, val configuration: NodeConfiguration,
lateinit var net: MessagingService
lateinit var api: APIServer
lateinit var scheduler: SchedulerService
lateinit var protocolLogicFactory: ProtocolLogicRefFactory
var isPreviousCheckpointsPresent = false
private set
@ -132,6 +137,11 @@ abstract class AbstractNode(val dir: Path, val configuration: NodeConfiguration,
val networkMapRegistrationFuture: ListenableFuture<Unit>
get() = _networkMapRegistrationFuture
/** Fetch CordaPluginRegistry classes registered in META-INF/services/com.r3corda.core.node.CordaPluginRegistry files that exist in the classpath */
protected val pluginRegistries: List<CordaPluginRegistry> by lazy {
ServiceLoader.load(CordaPluginRegistry::class.java).toList()
}
/** Set to true once [start] has been successfully called. */
@Volatile var started = false
private set
@ -158,6 +168,8 @@ abstract class AbstractNode(val dir: Path, val configuration: NodeConfiguration,
checkpointStorage,
serverThread)
protocolLogicFactory = initialiseProtocolLogicFactory()
// This object doesn't need to be referenced from this class because it registers handlers on the network
// service and so that keeps it from being collected.
DataVendingService(net, storage, services.networkMapCache)
@ -180,6 +192,18 @@ abstract class AbstractNode(val dir: Path, val configuration: NodeConfiguration,
return this
}
private fun initialiseProtocolLogicFactory(): ProtocolLogicRefFactory {
val protocolWhitelist = HashMap<String, Set<String>>()
for (plugin in pluginRegistries) {
for (protocol in plugin.requiredProtocols) {
protocolWhitelist.merge(protocol.key, protocol.value, { x, y -> x + y })
}
}
return ProtocolLogicRefFactory(protocolWhitelist)
}
/**
* Run any tasks that are needed to ensure the node is in a correct state before running start()
*/

View File

@ -3,10 +3,11 @@ package com.r3corda.node.internal
import com.codahale.metrics.JmxReporter
import com.google.common.net.HostAndPort
import com.r3corda.core.messaging.MessagingService
import com.r3corda.core.node.CordaPluginRegistry
import com.r3corda.core.node.NodeInfo
import com.r3corda.core.node.ServiceHub
import com.r3corda.core.node.services.ServiceType
import com.r3corda.core.utilities.loggerFor
import com.r3corda.node.api.APIServer
import com.r3corda.node.serialization.NodeClock
import com.r3corda.node.services.config.NodeConfiguration
import com.r3corda.node.services.messaging.ArtemisMessagingService
@ -27,9 +28,9 @@ import java.io.RandomAccessFile
import java.lang.management.ManagementFactory
import java.net.InetSocketAddress
import java.nio.channels.FileLock
import java.nio.file.Files
import java.nio.file.Path
import java.time.Clock
import java.util.*
import javax.management.ObjectName
class ConfigurationException(message: String) : Exception(message)
@ -55,8 +56,7 @@ class ConfigurationException(message: String) : Exception(message)
*/
class Node(dir: Path, val p2pAddr: HostAndPort, val webServerAddr: HostAndPort, configuration: NodeConfiguration,
networkMapAddress: NodeInfo?, advertisedServices: Set<ServiceType>,
clock: Clock = NodeClock(),
val clientAPIs: List<Class<*>> = listOf()) : AbstractNode(dir, configuration, networkMapAddress, advertisedServices, clock) {
clock: Clock = NodeClock()) : AbstractNode(dir, configuration, networkMapAddress, advertisedServices, clock) {
companion object {
/** The port that is used by default if none is specified. As you know, 31337 is the most elite number. */
val DEFAULT_PORT = 31337
@ -109,12 +109,13 @@ class Node(dir: Path, val p2pAddr: HostAndPort, val webServerAddr: HostAndPort,
resourceConfig.register(ResponseFilter())
resourceConfig.register(api)
for(customAPIClass in clientAPIs) {
val customAPI = customAPIClass.getConstructor(APIServer::class.java).newInstance(api)
val webAPIsOnClasspath = pluginRegistries.flatMap { x -> x.webApis }
for (webapi in webAPIsOnClasspath) {
log.info("Add Plugin web API from attachment ${webapi.name}")
val customAPI = webapi.getConstructor(ServiceHub::class.java).newInstance(services)
resourceConfig.register(customAPI)
}
// Give the app a slightly better name in JMX rather than a randomly generated one and enable JMX
resourceConfig.addProperties(mapOf(ServerProperties.APPLICATION_NAME to "node.api",
ServerProperties.MONITORING_STATISTICS_MBEANS_ENABLED to "true"))
@ -187,5 +188,5 @@ class Node(dir: Path, val p2pAddr: HostAndPort, val webServerAddr: HostAndPort,
val ourProcessID: String = ManagementFactory.getRuntimeMXBean().name.split("@")[0]
f.setLength(0)
f.write(ourProcessID.toByteArray())
}
}
}

View File

@ -29,4 +29,11 @@ abstract class ServiceHubInternal : ServiceHub {
* itself, at which point this method would not be needed (by the scheduler)
*/
abstract fun <T> startProtocol(loggerName: String, logic: ProtocolLogic<T>): ListenableFuture<T>
override fun <T : Any> invokeProtocolAsync(logicType: Class<out ProtocolLogic<T>>, vararg args: Any?): ListenableFuture<T> {
val logicRef = protocolLogicRefFactory.create(logicType, *args)
@Suppress("UNCHECKED_CAST")
val logic = protocolLogicRefFactory.toProtocolLogic(logicRef) as ProtocolLogic<T>
return startProtocol(logicType.simpleName, logic)
}
}

View File

@ -9,6 +9,7 @@ import com.r3corda.core.crypto.signWithECDSA
import com.r3corda.core.math.CubicSplineInterpolator
import com.r3corda.core.math.Interpolator
import com.r3corda.core.math.InterpolatorFactory
import com.r3corda.core.node.CordaPluginRegistry
import com.r3corda.core.node.services.ServiceType
import com.r3corda.core.protocols.ProtocolLogic
import com.r3corda.core.utilities.ProgressTracker
@ -18,11 +19,13 @@ import com.r3corda.node.services.api.AcceptsFileUpload
import com.r3corda.node.utilities.FiberBox
import com.r3corda.protocols.RatesFixProtocol
import com.r3corda.protocols.ServiceRequestMessage
import com.r3corda.protocols.TwoPartyDealProtocol
import org.slf4j.LoggerFactory
import java.io.InputStream
import java.math.BigDecimal
import java.security.KeyPair
import java.time.Clock
import java.time.Duration
import java.time.Instant
import java.time.LocalDate
import java.util.*
@ -93,6 +96,15 @@ object NodeInterestRates {
}
}
/**
* Register the protocol that is used with the Fixing integration tests
*/
class FixingServicePlugin : CordaPluginRegistry {
override val webApis: List<Class<*>> = emptyList()
override val requiredProtocols: Map<String, Set<String>> = mapOf(Pair(TwoPartyDealProtocol.FixingRoleDecider::class.java.name, setOf(Duration::class.java.name, StateRef::class.java.name)))
}
// File upload support
override val dataTypePrefix = "interest-rates"
override val acceptableFileExtensions = listOf(".rates", ".txt")

View File

@ -0,0 +1,2 @@
# Register a ServiceLoader service extending from com.r3corda.node.CordaPluginRegistry
com.r3corda.node.services.clientapi.NodeInterestRates$Service$FixingServicePlugin

View File

@ -43,7 +43,7 @@ class NodeSchedulerServiceTest : SingletonSerializeAsToken() {
// We have to allow Java boxed primitives but Kotlin warns we shouldn't be using them
@Suppress("PLATFORM_CLASS_MAPPED_TO_KOTLIN")
val factory = ProtocolLogicRefFactory(setOf(TestProtocolLogic::class.java.name), setOf(NodeSchedulerServiceTest::class.java.name, Integer::class.java.name))
val factory = ProtocolLogicRefFactory(mapOf(Pair(TestProtocolLogic::class.java.name, setOf(NodeSchedulerServiceTest::class.java.name, Integer::class.java.name))))
val scheduler: NodeSchedulerService
val services: ServiceHub

View File

@ -1,9 +1,11 @@
package com.r3corda.demos
import com.google.common.net.HostAndPort
import com.r3corda.contracts.InterestRateSwap
import com.r3corda.core.crypto.Party
import com.r3corda.core.logElapsedTime
import com.r3corda.core.messaging.SingleMessageRecipient
import com.r3corda.core.node.CordaPluginRegistry
import com.r3corda.core.node.NodeInfo
import com.r3corda.core.node.services.ServiceType
import com.r3corda.core.serialization.deserialize
@ -241,6 +243,14 @@ object CliParamsSpec {
val nonOptions = parser.nonOptions()
}
class IRSDemoPluginRegistry : CordaPluginRegistry {
override val webApis: List<Class<*>> = listOf(InterestRateSwapAPI::class.java)
override val requiredProtocols: Map<String, Set<String>> = mapOf(
Pair(AutoOfferProtocol.Requester::class.java.name, setOf(InterestRateSwap.State::class.java.name)),
Pair(UpdateBusinessDayProtocol.Broadcast::class.java.name, setOf(java.time.LocalDate::class.java.name)),
Pair(ExitServerProtocol.Broadcast::class.java.name, setOf(kotlin.Int::class.java.name)))
}
private class NotSetupException: Throwable {
constructor(message: String): super(message) {}
}
@ -374,8 +384,7 @@ private fun startNode(params: CliParams.RunNode, networkMap: SingleMessageRecipi
}
val node = logElapsedTime("Node startup") {
Node(params.dir, params.networkAddress, params.apiAddress, config, networkMapId, advertisedServices, DemoClock(),
listOf(InterestRateSwapAPI::class.java)).start()
Node(params.dir, params.networkAddress, params.apiAddress, config, networkMapId, advertisedServices, DemoClock()).start()
}
// TODO: This should all be replaced by the identity service being updated

View File

@ -76,8 +76,7 @@ fun main(args: Array<String>) {
val apiAddr = HostAndPort.fromParts(myNetAddr.hostText, myNetAddr.port + 1)
val node = logElapsedTime("Node startup") { Node(dir, myNetAddr, apiAddr, config, networkMapAddress,
advertisedServices, DemoClock(),
listOf(InterestRateSwapAPI::class.java)).setup().start() }
advertisedServices, DemoClock()).setup().start() }
val notary = node.services.networkMapCache.notaryNodes[0]

View File

@ -1,15 +1,16 @@
package com.r3corda.demos.api
import com.r3corda.contracts.InterestRateSwap
import com.r3corda.core.contracts.SignedTransaction
import com.r3corda.core.node.ServiceHub
import com.r3corda.core.node.services.linearHeadsOfType
import com.r3corda.core.utilities.loggerFor
import com.r3corda.demos.protocols.AutoOfferProtocol
import com.r3corda.demos.protocols.ExitServerProtocol
import com.r3corda.demos.protocols.UpdateBusinessDayProtocol
import com.r3corda.node.api.APIServer
import com.r3corda.node.api.ProtocolClassRef
import com.r3corda.node.api.StatesQuery
import java.net.URI
import java.time.LocalDate
import java.time.LocalDateTime
import javax.ws.rs.*
import javax.ws.rs.core.MediaType
import javax.ws.rs.core.Response
@ -35,23 +36,23 @@ import javax.ws.rs.core.Response
* or if the demodate or population of deals should be reset (will only work while persistence is disabled).
*/
@Path("irs")
class InterestRateSwapAPI(val api: APIServer) {
class InterestRateSwapAPI(val services: ServiceHub) {
private val logger = loggerFor<InterestRateSwapAPI>()
private fun generateDealLink(deal: InterestRateSwap.State) = "/api/irs/deals/" + deal.common.tradeID
private fun getDealByRef(ref: String): InterestRateSwap.State? {
val states = api.queryStates(StatesQuery.selectDeal(ref))
val states = services.walletService.linearHeadsOfType<InterestRateSwap.State>().filterValues { it.state.data.ref == ref }
return if (states.isEmpty()) null else {
val deals = api.fetchStates(states).values.map { it?.data as InterestRateSwap.State }.filterNotNull()
val deals = states.values.map { it.state.data }
return if (deals.isEmpty()) null else deals[0]
}
}
private fun getAllDeals(): Array<InterestRateSwap.State> {
val states = api.queryStates(StatesQuery.selectAllDeals())
val swaps = api.fetchStates(states).values.map { it?.data as InterestRateSwap.State }.filterNotNull().toTypedArray()
val states = services.walletService.linearHeadsOfType<InterestRateSwap.State>()
val swaps = states.values.map { it.state.data }.toTypedArray()
return swaps
}
@ -64,7 +65,7 @@ class InterestRateSwapAPI(val api: APIServer) {
@Path("deals")
@Consumes(MediaType.APPLICATION_JSON)
fun storeDeal(newDeal: InterestRateSwap.State): Response {
api.invokeProtocolSync(ProtocolClassRef(AutoOfferProtocol.Requester::class.java.name!!), mapOf("dealToBeOffered" to newDeal))
services.invokeProtocolAsync<SignedTransaction>(AutoOfferProtocol.Requester::class.java, newDeal).get()
return Response.created(URI.create(generateDealLink(newDeal))).build()
}
@ -84,10 +85,10 @@ class InterestRateSwapAPI(val api: APIServer) {
@Path("demodate")
@Consumes(MediaType.APPLICATION_JSON)
fun storeDemoDate(newDemoDate: LocalDate): Response {
val priorDemoDate = api.serverTime().toLocalDate()
val priorDemoDate = fetchDemoDate()
// Can only move date forwards
if (newDemoDate.isAfter(priorDemoDate)) {
api.invokeProtocolSync(ProtocolClassRef(UpdateBusinessDayProtocol.Broadcast::class.java.name!!), mapOf("date" to newDemoDate))
services.invokeProtocolAsync<Unit>(UpdateBusinessDayProtocol.Broadcast::class.java, newDemoDate).get()
return Response.ok().build()
}
val msg = "demodate is already $priorDemoDate and can only be updated with a later date"
@ -99,14 +100,14 @@ class InterestRateSwapAPI(val api: APIServer) {
@Path("demodate")
@Produces(MediaType.APPLICATION_JSON)
fun fetchDemoDate(): LocalDate {
return api.serverTime().toLocalDate()
return LocalDateTime.now(services.clock).toLocalDate()
}
@PUT
@Path("restart")
@Consumes(MediaType.APPLICATION_JSON)
fun exitServer(): Response {
api.invokeProtocolSync(ProtocolClassRef(ExitServerProtocol.Broadcast::class.java.name!!), mapOf("exitCode" to 83))
services.invokeProtocolAsync<Boolean>(ExitServerProtocol.Broadcast::class.java, 83).get()
return Response.ok().build()
}
}

View File

@ -37,22 +37,21 @@ object ExitServerProtocol {
* This takes a Java Integer rather than Kotlin Int as that is what we end up with in the calling map and currently
* we do not support coercing numeric types in the reflective search for matching constructors
*/
class Broadcast(@Suppress("PLATFORM_CLASS_MAPPED_TO_KOTLIN") val exitCode: Integer) : ProtocolLogic<Boolean>() {
class Broadcast(val exitCode: Int) : ProtocolLogic<Boolean>() {
override val topic: String get() = TOPIC
@Suspendable
override fun call(): Boolean {
if (enabled) {
val rc = exitCode.toInt()
val message = ExitMessage(rc)
val message = ExitMessage(exitCode)
for (recipient in serviceHub.networkMapCache.partyNodes) {
doNextRecipient(recipient, message)
}
// Sleep a little in case any async message delivery to other nodes needs to happen
Strand.sleep(1, TimeUnit.SECONDS)
System.exit(rc)
System.exit(exitCode)
}
return enabled
}

View File

@ -0,0 +1,2 @@
# Register a ServiceLoader service extending from com.r3corda.node.CordaPluginRegistry
com.r3corda.demos.IRSDemoPluginRegistry