Merge branches 'colljos-vault-transaction-notes' and 'master' of https://bitbucket.org/R3-CEV/r3prototyping into colljos-vault-transaction-notes

This commit is contained in:
Jose Coll 2016-10-27 14:59:37 +01:00
commit f2e98ffba5
47 changed files with 671 additions and 2619 deletions

View File

@ -110,7 +110,6 @@ dependencies {
compile "org.jetbrains.kotlin:kotlin-reflect:$kotlin_version" compile "org.jetbrains.kotlin:kotlin-reflect:$kotlin_version"
compile "org.jetbrains.kotlin:kotlin-test:$kotlin_version" compile "org.jetbrains.kotlin:kotlin-test:$kotlin_version"
compile "org.jetbrains.kotlinx:kotlinx-support-jdk8:0.2" compile "org.jetbrains.kotlinx:kotlinx-support-jdk8:0.2"
compile 'co.paralleluniverse:capsule:1.0.3'
// Unit testing helpers. // Unit testing helpers.
testCompile 'junit:junit:4.12' testCompile 'junit:junit:4.12'
@ -193,7 +192,7 @@ applicationDistribution.into("bin") {
task buildCordaJAR(type: FatCapsule, dependsOn: ['quasarScan', 'buildCertSigningRequestUtilityJAR']) { task buildCordaJAR(type: FatCapsule, dependsOn: ['quasarScan', 'buildCertSigningRequestUtilityJAR']) {
applicationClass 'com.r3corda.node.MainKt' applicationClass 'com.r3corda.node.MainKt'
archiveName 'corda.jar' archiveName 'corda.jar'
applicationSource = files(project.tasks.findByName('jar'), 'build/classes/main/CordaCaplet.class') applicationSource = files(project.tasks.findByName('jar'), 'node/build/classes/main/CordaCaplet.class')
capsuleManifest { capsuleManifest {
appClassPath = ["jolokia-agent-war-${project.ext.jolokia_version}.war"] appClassPath = ["jolokia-agent-war-${project.ext.jolokia_version}.war"]

View File

@ -0,0 +1,62 @@
package com.r3corda.client
import com.r3corda.core.random63BitValue
import com.r3corda.node.driver.driver
import com.r3corda.node.services.config.configureTestSSL
import com.r3corda.node.services.messaging.ArtemisMessagingComponent.Companion.toHostAndPort
import org.apache.activemq.artemis.api.core.ActiveMQSecurityException
import org.assertj.core.api.Assertions.assertThatExceptionOfType
import org.junit.After
import org.junit.Before
import org.junit.Test
import java.util.concurrent.CountDownLatch
import kotlin.concurrent.thread
class CordaRPCClientTest {
private val validUsername = "user1"
private val validPassword = "test"
private val stopDriver = CountDownLatch(1)
private var driverThread: Thread? = null
private lateinit var client: CordaRPCClient
@Before
fun start() {
val driverStarted = CountDownLatch(1)
driverThread = thread {
driver {
val driverInfo = startNode().get()
client = CordaRPCClient(toHostAndPort(driverInfo.nodeInfo.address), configureTestSSL())
driverStarted.countDown()
stopDriver.await()
}
}
driverStarted.await()
}
@After
fun stop() {
stopDriver.countDown()
driverThread?.join()
}
@Test
fun `log in with valid username and password`() {
client.start(validUsername, validPassword)
}
@Test
fun `log in with unknown user`() {
assertThatExceptionOfType(ActiveMQSecurityException::class.java).isThrownBy {
client.start(random63BitValue().toString(), validPassword)
}
}
@Test
fun `log in with incorrect password`() {
assertThatExceptionOfType(ActiveMQSecurityException::class.java).isThrownBy {
client.start(validUsername, random63BitValue().toString())
}
}
}

View File

@ -1,6 +1,5 @@
package com.r3corda.client package com.r3corda.client
import com.google.common.util.concurrent.SettableFuture
import com.r3corda.client.model.NodeMonitorModel import com.r3corda.client.model.NodeMonitorModel
import com.r3corda.client.model.ProgressTrackingEvent import com.r3corda.client.model.ProgressTrackingEvent
import com.r3corda.core.bufferUntilSubscribed import com.r3corda.core.bufferUntilSubscribed
@ -14,8 +13,7 @@ import com.r3corda.core.protocols.StateMachineRunId
import com.r3corda.core.serialization.OpaqueBytes import com.r3corda.core.serialization.OpaqueBytes
import com.r3corda.core.transactions.SignedTransaction import com.r3corda.core.transactions.SignedTransaction
import com.r3corda.node.driver.driver import com.r3corda.node.driver.driver
import com.r3corda.node.driver.startClient import com.r3corda.node.services.config.configureTestSSL
import com.r3corda.node.services.messaging.NodeMessagingClient
import com.r3corda.node.services.messaging.StateMachineUpdate import com.r3corda.node.services.messaging.StateMachineUpdate
import com.r3corda.node.services.transactions.SimpleNotaryService import com.r3corda.node.services.transactions.SimpleNotaryService
import com.r3corda.testing.expect import com.r3corda.testing.expect
@ -26,16 +24,15 @@ import org.junit.Before
import org.junit.Test import org.junit.Test
import rx.Observable import rx.Observable
import rx.Observer import rx.Observer
import java.util.concurrent.CountDownLatch
import kotlin.concurrent.thread import kotlin.concurrent.thread
class NodeMonitorModelTest { class NodeMonitorModelTest {
lateinit var aliceNode: NodeInfo lateinit var aliceNode: NodeInfo
lateinit var notaryNode: NodeInfo lateinit var notaryNode: NodeInfo
lateinit var aliceClient: NodeMessagingClient val stopDriver = CountDownLatch(1)
val driverStarted = SettableFuture.create<Unit>() var driverThread: Thread? = null
val stopDriver = SettableFuture.create<Unit>()
val driverStopped = SettableFuture.create<Unit>()
lateinit var stateMachineTransactionMapping: Observable<StateMachineTransactionMapping> lateinit var stateMachineTransactionMapping: Observable<StateMachineTransactionMapping>
lateinit var stateMachineUpdates: Observable<StateMachineUpdate> lateinit var stateMachineUpdates: Observable<StateMachineUpdate>
@ -48,15 +45,15 @@ class NodeMonitorModelTest {
@Before @Before
fun start() { fun start() {
thread { val driverStarted = CountDownLatch(1)
driverThread = thread {
driver { driver {
val aliceNodeFuture = startNode("Alice") val aliceNodeFuture = startNode("Alice")
val notaryNodeFuture = startNode("Notary", advertisedServices = setOf(ServiceInfo(SimpleNotaryService.type))) val notaryNodeFuture = startNode("Notary", advertisedServices = setOf(ServiceInfo(SimpleNotaryService.type)))
aliceNode = aliceNodeFuture.get() aliceNode = aliceNodeFuture.get().nodeInfo
notaryNode = notaryNodeFuture.get() notaryNode = notaryNodeFuture.get().nodeInfo
aliceClient = startClient(aliceNode).get() newNode = { nodeName -> startNode(nodeName).get().nodeInfo }
newNode = { nodeName -> startNode(nodeName).get() }
val monitor = NodeMonitorModel() val monitor = NodeMonitorModel()
stateMachineTransactionMapping = monitor.stateMachineTransactionMapping.bufferUntilSubscribed() stateMachineTransactionMapping = monitor.stateMachineTransactionMapping.bufferUntilSubscribed()
@ -67,20 +64,18 @@ class NodeMonitorModelTest {
networkMapUpdates = monitor.networkMap.bufferUntilSubscribed() networkMapUpdates = monitor.networkMap.bufferUntilSubscribed()
clientToService = monitor.clientToService clientToService = monitor.clientToService
monitor.register(aliceNode, aliceClient.config.certificatesPath) monitor.register(aliceNode, configureTestSSL(), "user1", "test")
driverStarted.set(Unit) driverStarted.countDown()
stopDriver.get() stopDriver.await()
} }
driverStopped.set(Unit)
} }
driverStarted.get() driverStarted.await()
} }
@After @After
fun stop() { fun stop() {
stopDriver.set(Unit) stopDriver.countDown()
driverStopped.get() driverThread?.join()
} }
@Test @Test

View File

@ -24,19 +24,12 @@ import kotlin.concurrent.thread
* useful tasks. See the documentation for [proxy] or review the docsite to learn more about how this API works. * useful tasks. See the documentation for [proxy] or review the docsite to learn more about how this API works.
*/ */
@ThreadSafe @ThreadSafe
class CordaRPCClient(val host: HostAndPort, certificatesPath: Path) : Closeable, ArtemisMessagingComponent(sslConfig(certificatesPath)) { class CordaRPCClient(val host: HostAndPort, override val config: NodeSSLConfiguration) : Closeable, ArtemisMessagingComponent() {
companion object { companion object {
private val rpcLog = LoggerFactory.getLogger("com.r3corda.rpc") private val rpcLog = LoggerFactory.getLogger("com.r3corda.rpc")
private fun sslConfig(certificatesPath: Path): NodeSSLConfiguration = object : NodeSSLConfiguration {
override val certificatesPath: Path = certificatesPath
override val keyStorePassword = "cordacadevpass"
override val trustStorePassword = "trustpass"
}
} }
// TODO: Certificate handling for clients needs more work. // TODO: Certificate handling for clients needs more work.
private inner class State { private inner class State {
var running = false var running = false
lateinit var sessionFactory: ClientSessionFactory lateinit var sessionFactory: ClientSessionFactory
@ -57,7 +50,7 @@ class CordaRPCClient(val host: HostAndPort, certificatesPath: Path) : Closeable,
/** Opens the connection to the server and registers a JVM shutdown hook to cleanly disconnect. */ /** Opens the connection to the server and registers a JVM shutdown hook to cleanly disconnect. */
@Throws(ActiveMQNotConnectedException::class) @Throws(ActiveMQNotConnectedException::class)
fun start() { fun start(username: String, password: String) {
state.locked { state.locked {
check(!running) check(!running)
checkStorePasswords() // Check the password. checkStorePasswords() // Check the password.
@ -66,7 +59,7 @@ class CordaRPCClient(val host: HostAndPort, certificatesPath: Path) : Closeable,
sessionFactory = serverLocator.createSessionFactory() sessionFactory = serverLocator.createSessionFactory()
// We use our initial connection ID as the queue namespace. // We use our initial connection ID as the queue namespace.
myID = sessionFactory.connection.id as Int and 0x000000FFFFFF myID = sessionFactory.connection.id as Int and 0x000000FFFFFF
session = sessionFactory.createSession() session = sessionFactory.createSession(username, password, false, true, true, serverLocator.isPreAcknowledge, serverLocator.ackBatchSize)
session.start() session.start()
clientImpl = CordaRPCClientImpl(session, state.lock, myAddressPrefix) clientImpl = CordaRPCClientImpl(session, state.lock, myAddressPrefix)
running = true running = true

View File

@ -8,14 +8,14 @@ import com.r3corda.core.node.services.StateMachineTransactionMapping
import com.r3corda.core.node.services.Vault import com.r3corda.core.node.services.Vault
import com.r3corda.core.protocols.StateMachineRunId import com.r3corda.core.protocols.StateMachineRunId
import com.r3corda.core.transactions.SignedTransaction import com.r3corda.core.transactions.SignedTransaction
import com.r3corda.node.services.messaging.ArtemisMessagingComponent import com.r3corda.node.services.config.NodeSSLConfiguration
import com.r3corda.node.services.messaging.ArtemisMessagingComponent.Companion.toHostAndPort
import com.r3corda.node.services.messaging.CordaRPCOps import com.r3corda.node.services.messaging.CordaRPCOps
import com.r3corda.node.services.messaging.StateMachineInfo import com.r3corda.node.services.messaging.StateMachineInfo
import com.r3corda.node.services.messaging.StateMachineUpdate import com.r3corda.node.services.messaging.StateMachineUpdate
import javafx.beans.property.SimpleObjectProperty import javafx.beans.property.SimpleObjectProperty
import rx.Observable import rx.Observable
import rx.subjects.PublishSubject import rx.subjects.PublishSubject
import java.nio.file.Path
data class ProgressTrackingEvent(val stateMachineId: StateMachineRunId, val message: String) { data class ProgressTrackingEvent(val stateMachineId: StateMachineRunId, val message: String) {
companion object { companion object {
@ -54,14 +54,11 @@ class NodeMonitorModel {
/** /**
* Register for updates to/from a given vault. * Register for updates to/from a given vault.
* @param messagingService The messaging to use for communication.
* @param monitorNodeInfo the [Node] to connect to.
* TODO provide an unsubscribe mechanism * TODO provide an unsubscribe mechanism
*/ */
fun register(vaultMonitorNodeInfo: NodeInfo, certificatesPath: Path) { fun register(vaultMonitorNodeInfo: NodeInfo, sslConfig: NodeSSLConfiguration, username: String, password: String) {
val client = CordaRPCClient(toHostAndPort(vaultMonitorNodeInfo.address), sslConfig)
val client = CordaRPCClient(ArtemisMessagingComponent.toHostAndPort(vaultMonitorNodeInfo.address), certificatesPath) client.start(username, password)
client.start()
val proxy = client.proxy() val proxy = client.proxy()
val (stateMachines, stateMachineUpdates) = proxy.stateMachinesAndUpdates() val (stateMachines, stateMachineUpdates) = proxy.stateMachinesAndUpdates()

View File

@ -1,801 +0,0 @@
package com.r3corda.contracts
import com.r3corda.core.contracts.*
import com.r3corda.core.contracts.clauses.*
import com.r3corda.core.crypto.Party
import com.r3corda.core.crypto.SecureHash
import com.r3corda.core.node.services.ServiceType
import com.r3corda.core.protocols.ProtocolLogicRefFactory
import com.r3corda.core.transactions.TransactionBuilder
import com.r3corda.core.utilities.suggestInterestRateAnnouncementTimeWindow
import com.r3corda.protocols.TwoPartyDealProtocol
import org.apache.commons.jexl3.JexlBuilder
import org.apache.commons.jexl3.MapContext
import java.math.BigDecimal
import java.math.RoundingMode
import java.security.PublicKey
import java.time.LocalDate
import java.util.*
val IRS_PROGRAM_ID = InterestRateSwap()
// This is a placeholder for some types that we haven't identified exactly what they are just yet for things still in discussion
open class UnknownType() {
override fun equals(other: Any?): Boolean {
return (other is UnknownType)
}
override fun hashCode() = 1
}
/**
* Event superclass - everything happens on a date.
*/
open class Event(val date: LocalDate) {
override fun equals(other: Any?): Boolean {
if (this === other) return true
if (other !is Event) return false
if (date != other.date) return false
return true
}
override fun hashCode() = Objects.hash(date)
}
/**
* Top level PaymentEvent class - represents an obligation to pay an amount on a given date, which may be either in the past or the future.
*/
abstract class PaymentEvent(date: LocalDate) : Event(date) {
abstract fun calculate(): Amount<Currency>
}
/**
* A [RatePaymentEvent] represents a dated obligation of payment.
* It is a specialisation / modification of a basic cash flow event (to be written) that has some additional assistance
* functions for interest rate swap legs of the fixed and floating nature.
* For the fixed leg, the rate is already known at creation and therefore the flows can be pre-determined.
* For the floating leg, the rate refers to a reference rate which is to be "fixed" at a point in the future.
*/
abstract class RatePaymentEvent(date: LocalDate,
val accrualStartDate: LocalDate,
val accrualEndDate: LocalDate,
val dayCountBasisDay: DayCountBasisDay,
val dayCountBasisYear: DayCountBasisYear,
val notional: Amount<Currency>,
val rate: Rate) : PaymentEvent(date) {
companion object {
val CSVHeader = "AccrualStartDate,AccrualEndDate,DayCountFactor,Days,Date,Ccy,Notional,Rate,Flow"
}
override fun calculate(): Amount<Currency> = flow
abstract val flow: Amount<Currency>
val days: Int get() = calculateDaysBetween(accrualStartDate, accrualEndDate, dayCountBasisYear, dayCountBasisDay)
// TODO : Fix below (use daycount convention for division, not hardcoded 360 etc)
val dayCountFactor: BigDecimal get() = (BigDecimal(days).divide(BigDecimal(360.0), 8, RoundingMode.HALF_UP)).setScale(4, RoundingMode.HALF_UP)
open fun asCSV() = "$accrualStartDate,$accrualEndDate,$dayCountFactor,$days,$date,${notional.token},$notional,$rate,$flow"
override fun equals(other: Any?): Boolean {
if (this === other) return true
if (other !is RatePaymentEvent) return false
if (accrualStartDate != other.accrualStartDate) return false
if (accrualEndDate != other.accrualEndDate) return false
if (dayCountBasisDay != other.dayCountBasisDay) return false
if (dayCountBasisYear != other.dayCountBasisYear) return false
if (notional != other.notional) return false
if (rate != other.rate) return false
// if (flow != other.flow) return false // Flow is derived
return super.equals(other)
}
override fun hashCode() = super.hashCode() + 31 * Objects.hash(accrualStartDate, accrualEndDate, dayCountBasisDay,
dayCountBasisYear, notional, rate)
}
/**
* Basic class for the Fixed Rate Payments on the fixed leg - see [RatePaymentEvent].
* Assumes that the rate is valid.
*/
class FixedRatePaymentEvent(date: LocalDate,
accrualStartDate: LocalDate,
accrualEndDate: LocalDate,
dayCountBasisDay: DayCountBasisDay,
dayCountBasisYear: DayCountBasisYear,
notional: Amount<Currency>,
rate: Rate) :
RatePaymentEvent(date, accrualStartDate, accrualEndDate, dayCountBasisDay, dayCountBasisYear, notional, rate) {
companion object {
val CSVHeader = RatePaymentEvent.CSVHeader
}
override val flow: Amount<Currency> get() = Amount(dayCountFactor.times(BigDecimal(notional.quantity)).times(rate.ratioUnit!!.value).toLong(), notional.token)
override fun toString(): String =
"FixedRatePaymentEvent $accrualStartDate -> $accrualEndDate : $dayCountFactor : $days : $date : $notional : $rate : $flow"
}
/**
* Basic class for the Floating Rate Payments on the floating leg - see [RatePaymentEvent].
* If the rate is null returns a zero payment. // TODO: Is this the desired behaviour?
*/
class FloatingRatePaymentEvent(date: LocalDate,
accrualStartDate: LocalDate,
accrualEndDate: LocalDate,
dayCountBasisDay: DayCountBasisDay,
dayCountBasisYear: DayCountBasisYear,
val fixingDate: LocalDate,
notional: Amount<Currency>,
rate: Rate) : RatePaymentEvent(date, accrualStartDate, accrualEndDate, dayCountBasisDay, dayCountBasisYear, notional, rate) {
companion object {
val CSVHeader = RatePaymentEvent.CSVHeader + ",FixingDate"
}
override val flow: Amount<Currency> get() {
// TODO: Should an uncalculated amount return a zero ? null ? etc.
val v = rate.ratioUnit?.value ?: return Amount(0, notional.token)
return Amount(dayCountFactor.times(BigDecimal(notional.quantity)).times(v).toLong(), notional.token)
}
override fun toString(): String = "FloatingPaymentEvent $accrualStartDate -> $accrualEndDate : $dayCountFactor : $days : $date : $notional : $rate (fix on $fixingDate): $flow"
override fun asCSV(): String = "$accrualStartDate,$accrualEndDate,$dayCountFactor,$days,$date,${notional.token},$notional,$fixingDate,$rate,$flow"
/**
* Used for making immutables.
*/
fun withNewRate(newRate: Rate): FloatingRatePaymentEvent =
FloatingRatePaymentEvent(date, accrualStartDate, accrualEndDate, dayCountBasisDay,
dayCountBasisYear, fixingDate, notional, newRate)
override fun equals(other: Any?): Boolean {
if (this === other) return true
if (other?.javaClass != javaClass) return false
other as FloatingRatePaymentEvent
if (fixingDate != other.fixingDate) return false
return super.equals(other)
}
override fun hashCode() = super.hashCode() + 31 * Objects.hash(fixingDate)
// Can't autogenerate as not a data class :-(
fun copy(date: LocalDate = this.date,
accrualStartDate: LocalDate = this.accrualStartDate,
accrualEndDate: LocalDate = this.accrualEndDate,
dayCountBasisDay: DayCountBasisDay = this.dayCountBasisDay,
dayCountBasisYear: DayCountBasisYear = this.dayCountBasisYear,
fixingDate: LocalDate = this.fixingDate,
notional: Amount<Currency> = this.notional,
rate: Rate = this.rate) = FloatingRatePaymentEvent(date, accrualStartDate, accrualEndDate, dayCountBasisDay, dayCountBasisYear, fixingDate, notional, rate)
}
/**
* The Interest Rate Swap class. For a quick overview of what an IRS is, see here - http://www.pimco.co.uk/EN/Education/Pages/InterestRateSwapsBasics1-08.aspx (no endorsement).
* This contract has 4 significant data classes within it, the "Common", "Calculation", "FixedLeg" and "FloatingLeg".
* It also has 4 commands, "Agree", "Fix", "Pay" and "Mature".
* Currently, we are not interested (excuse pun) in valuing the swap, calculating the PVs, DFs and all that good stuff (soon though).
* This is just a representation of a vanilla Fixed vs Floating (same currency) IRS in the R3 prototype model.
*/
class InterestRateSwap() : Contract {
override val legalContractReference = SecureHash.sha256("is_this_the_text_of_the_contract ? TBD")
companion object {
val oracleType = ServiceType.corda.getSubType("interest_rates")
}
/**
* This Common area contains all the information that is not leg specific.
*/
data class Common(
val baseCurrency: Currency,
val eligibleCurrency: Currency,
val eligibleCreditSupport: String,
val independentAmounts: Amount<Currency>,
val threshold: Amount<Currency>,
val minimumTransferAmount: Amount<Currency>,
val rounding: Amount<Currency>,
val valuationDateDescription: String, // This describes (in english) how regularly the swap is to be valued, e.g. "every local working day"
val notificationTime: String,
val resolutionTime: String,
val interestRate: ReferenceRate,
val addressForTransfers: String,
val exposure: UnknownType,
val localBusinessDay: BusinessCalendar,
val dailyInterestAmount: Expression,
val tradeID: String,
val hashLegalDocs: String
)
/**
* The Calculation data class is "mutable" through out the life of the swap, as in, it's the only thing that contains
* data that will changed from state to state (Recall that the design insists that everything is immutable, so we actually
* copy / update for each transition).
*/
data class Calculation(
val expression: Expression,
val floatingLegPaymentSchedule: Map<LocalDate, FloatingRatePaymentEvent>,
val fixedLegPaymentSchedule: Map<LocalDate, FixedRatePaymentEvent>
) {
/**
* Gets the date of the next fixing.
* @return LocalDate or null if no more fixings.
*/
fun nextFixingDate(): LocalDate? {
return floatingLegPaymentSchedule.
filter { it.value.rate is ReferenceRate }.// TODO - a better way to determine what fixings remain to be fixed
minBy { it.value.fixingDate.toEpochDay() }?.value?.fixingDate
}
/**
* Returns the fixing for that date.
*/
fun getFixing(date: LocalDate): FloatingRatePaymentEvent =
floatingLegPaymentSchedule.values.single { it.fixingDate == date }
/**
* Returns a copy after modifying (applying) the fixing for that date.
*/
fun applyFixing(date: LocalDate, newRate: FixedRate): Calculation {
val paymentEvent = getFixing(date)
val newFloatingLPS = floatingLegPaymentSchedule + (paymentEvent.date to paymentEvent.withNewRate(newRate))
return Calculation(expression = expression,
floatingLegPaymentSchedule = newFloatingLPS,
fixedLegPaymentSchedule = fixedLegPaymentSchedule)
}
}
abstract class CommonLeg(
val notional: Amount<Currency>,
val paymentFrequency: Frequency,
val effectiveDate: LocalDate,
val effectiveDateAdjustment: DateRollConvention?,
val terminationDate: LocalDate,
val terminationDateAdjustment: DateRollConvention?,
val dayCountBasisDay: DayCountBasisDay,
val dayCountBasisYear: DayCountBasisYear,
val dayInMonth: Int,
val paymentRule: PaymentRule,
val paymentDelay: Int,
val paymentCalendar: BusinessCalendar,
val interestPeriodAdjustment: AccrualAdjustment
) {
override fun toString(): String {
return "Notional=$notional,PaymentFrequency=$paymentFrequency,EffectiveDate=$effectiveDate,EffectiveDateAdjustment:$effectiveDateAdjustment,TerminatationDate=$terminationDate," +
"TerminationDateAdjustment=$terminationDateAdjustment,DayCountBasis=$dayCountBasisDay/$dayCountBasisYear,DayInMonth=$dayInMonth," +
"PaymentRule=$paymentRule,PaymentDelay=$paymentDelay,PaymentCalendar=$paymentCalendar,InterestPeriodAdjustment=$interestPeriodAdjustment"
}
override fun equals(other: Any?): Boolean {
if (this === other) return true
if (other?.javaClass != javaClass) return false
other as CommonLeg
if (notional != other.notional) return false
if (paymentFrequency != other.paymentFrequency) return false
if (effectiveDate != other.effectiveDate) return false
if (effectiveDateAdjustment != other.effectiveDateAdjustment) return false
if (terminationDate != other.terminationDate) return false
if (terminationDateAdjustment != other.terminationDateAdjustment) return false
if (dayCountBasisDay != other.dayCountBasisDay) return false
if (dayCountBasisYear != other.dayCountBasisYear) return false
if (dayInMonth != other.dayInMonth) return false
if (paymentRule != other.paymentRule) return false
if (paymentDelay != other.paymentDelay) return false
if (paymentCalendar != other.paymentCalendar) return false
if (interestPeriodAdjustment != other.interestPeriodAdjustment) return false
return true
}
override fun hashCode() = super.hashCode() + 31 * Objects.hash(notional, paymentFrequency, effectiveDate,
effectiveDateAdjustment, terminationDate, effectiveDateAdjustment, terminationDate, terminationDateAdjustment,
dayCountBasisDay, dayCountBasisYear, dayInMonth, paymentRule, paymentDelay, paymentCalendar, interestPeriodAdjustment)
}
open class FixedLeg(
var fixedRatePayer: Party,
notional: Amount<Currency>,
paymentFrequency: Frequency,
effectiveDate: LocalDate,
effectiveDateAdjustment: DateRollConvention?,
terminationDate: LocalDate,
terminationDateAdjustment: DateRollConvention?,
dayCountBasisDay: DayCountBasisDay,
dayCountBasisYear: DayCountBasisYear,
dayInMonth: Int,
paymentRule: PaymentRule,
paymentDelay: Int,
paymentCalendar: BusinessCalendar,
interestPeriodAdjustment: AccrualAdjustment,
var fixedRate: FixedRate,
var rollConvention: DateRollConvention // TODO - best way of implementing - still awaiting some clarity
) : CommonLeg
(notional, paymentFrequency, effectiveDate, effectiveDateAdjustment, terminationDate, terminationDateAdjustment,
dayCountBasisDay, dayCountBasisYear, dayInMonth, paymentRule, paymentDelay, paymentCalendar, interestPeriodAdjustment) {
override fun toString(): String = "FixedLeg(Payer=$fixedRatePayer," + super.toString() + ",fixedRate=$fixedRate," +
"rollConvention=$rollConvention"
override fun equals(other: Any?): Boolean {
if (this === other) return true
if (other?.javaClass != javaClass) return false
if (!super.equals(other)) return false
other as FixedLeg
if (fixedRatePayer != other.fixedRatePayer) return false
if (fixedRate != other.fixedRate) return false
if (rollConvention != other.rollConvention) return false
return true
}
override fun hashCode() = super.hashCode() + 31 * Objects.hash(fixedRatePayer, fixedRate, rollConvention)
// Can't autogenerate as not a data class :-(
fun copy(fixedRatePayer: Party = this.fixedRatePayer,
notional: Amount<Currency> = this.notional,
paymentFrequency: Frequency = this.paymentFrequency,
effectiveDate: LocalDate = this.effectiveDate,
effectiveDateAdjustment: DateRollConvention? = this.effectiveDateAdjustment,
terminationDate: LocalDate = this.terminationDate,
terminationDateAdjustment: DateRollConvention? = this.terminationDateAdjustment,
dayCountBasisDay: DayCountBasisDay = this.dayCountBasisDay,
dayCountBasisYear: DayCountBasisYear = this.dayCountBasisYear,
dayInMonth: Int = this.dayInMonth,
paymentRule: PaymentRule = this.paymentRule,
paymentDelay: Int = this.paymentDelay,
paymentCalendar: BusinessCalendar = this.paymentCalendar,
interestPeriodAdjustment: AccrualAdjustment = this.interestPeriodAdjustment,
fixedRate: FixedRate = this.fixedRate) = FixedLeg(
fixedRatePayer, notional, paymentFrequency, effectiveDate, effectiveDateAdjustment, terminationDate,
terminationDateAdjustment, dayCountBasisDay, dayCountBasisYear, dayInMonth, paymentRule, paymentDelay,
paymentCalendar, interestPeriodAdjustment, fixedRate, rollConvention)
}
open class FloatingLeg(
var floatingRatePayer: Party,
notional: Amount<Currency>,
paymentFrequency: Frequency,
effectiveDate: LocalDate,
effectiveDateAdjustment: DateRollConvention?,
terminationDate: LocalDate,
terminationDateAdjustment: DateRollConvention?,
dayCountBasisDay: DayCountBasisDay,
dayCountBasisYear: DayCountBasisYear,
dayInMonth: Int,
paymentRule: PaymentRule,
paymentDelay: Int,
paymentCalendar: BusinessCalendar,
interestPeriodAdjustment: AccrualAdjustment,
var rollConvention: DateRollConvention,
var fixingRollConvention: DateRollConvention,
var resetDayInMonth: Int,
var fixingPeriodOffset: Int,
var resetRule: PaymentRule,
var fixingsPerPayment: Frequency,
var fixingCalendar: BusinessCalendar,
var index: String,
var indexSource: String,
var indexTenor: Tenor
) : CommonLeg(notional, paymentFrequency, effectiveDate, effectiveDateAdjustment, terminationDate, terminationDateAdjustment,
dayCountBasisDay, dayCountBasisYear, dayInMonth, paymentRule, paymentDelay, paymentCalendar, interestPeriodAdjustment) {
override fun toString(): String = "FloatingLeg(Payer=$floatingRatePayer," + super.toString() +
"rollConvention=$rollConvention,FixingRollConvention=$fixingRollConvention,ResetDayInMonth=$resetDayInMonth" +
"FixingPeriondOffset=$fixingPeriodOffset,ResetRule=$resetRule,FixingsPerPayment=$fixingsPerPayment,FixingCalendar=$fixingCalendar," +
"Index=$index,IndexSource=$indexSource,IndexTenor=$indexTenor"
override fun equals(other: Any?): Boolean {
if (this === other) return true
if (other?.javaClass != javaClass) return false
if (!super.equals(other)) return false
other as FloatingLeg
if (floatingRatePayer != other.floatingRatePayer) return false
if (rollConvention != other.rollConvention) return false
if (fixingRollConvention != other.fixingRollConvention) return false
if (resetDayInMonth != other.resetDayInMonth) return false
if (fixingPeriodOffset != other.fixingPeriodOffset) return false
if (resetRule != other.resetRule) return false
if (fixingsPerPayment != other.fixingsPerPayment) return false
if (fixingCalendar != other.fixingCalendar) return false
if (index != other.index) return false
if (indexSource != other.indexSource) return false
if (indexTenor != other.indexTenor) return false
return true
}
override fun hashCode() = super.hashCode() + 31 * Objects.hash(floatingRatePayer, rollConvention,
fixingRollConvention, resetDayInMonth, fixingPeriodOffset, resetRule, fixingsPerPayment, fixingCalendar,
index, indexSource, indexTenor)
fun copy(floatingRatePayer: Party = this.floatingRatePayer,
notional: Amount<Currency> = this.notional,
paymentFrequency: Frequency = this.paymentFrequency,
effectiveDate: LocalDate = this.effectiveDate,
effectiveDateAdjustment: DateRollConvention? = this.effectiveDateAdjustment,
terminationDate: LocalDate = this.terminationDate,
terminationDateAdjustment: DateRollConvention? = this.terminationDateAdjustment,
dayCountBasisDay: DayCountBasisDay = this.dayCountBasisDay,
dayCountBasisYear: DayCountBasisYear = this.dayCountBasisYear,
dayInMonth: Int = this.dayInMonth,
paymentRule: PaymentRule = this.paymentRule,
paymentDelay: Int = this.paymentDelay,
paymentCalendar: BusinessCalendar = this.paymentCalendar,
interestPeriodAdjustment: AccrualAdjustment = this.interestPeriodAdjustment,
rollConvention: DateRollConvention = this.rollConvention,
fixingRollConvention: DateRollConvention = this.fixingRollConvention,
resetDayInMonth: Int = this.resetDayInMonth,
fixingPeriod: Int = this.fixingPeriodOffset,
resetRule: PaymentRule = this.resetRule,
fixingsPerPayment: Frequency = this.fixingsPerPayment,
fixingCalendar: BusinessCalendar = this.fixingCalendar,
index: String = this.index,
indexSource: String = this.indexSource,
indexTenor: Tenor = this.indexTenor
) = FloatingLeg(floatingRatePayer, notional, paymentFrequency, effectiveDate, effectiveDateAdjustment,
terminationDate, terminationDateAdjustment, dayCountBasisDay, dayCountBasisYear, dayInMonth,
paymentRule, paymentDelay, paymentCalendar, interestPeriodAdjustment, rollConvention,
fixingRollConvention, resetDayInMonth, fixingPeriod, resetRule, fixingsPerPayment,
fixingCalendar, index, indexSource, indexTenor)
}
override fun verify(tx: TransactionForContract) = verifyClause(tx, AllComposition(Clauses.Timestamped(), Clauses.Group()), tx.commands.select<Commands>())
interface Clauses {
/**
* Common superclass for IRS contract clauses, which defines behaviour on match/no-match, and provides
* helper functions for the clauses.
*/
abstract class AbstractIRSClause : Clause<State, Commands, UniqueIdentifier>() {
// These functions may make more sense to use for basket types, but for now let's leave them here
fun checkLegDates(legs: List<CommonLeg>) {
requireThat {
"Effective date is before termination date" by legs.all { it.effectiveDate < it.terminationDate }
"Effective dates are in alignment" by legs.all { it.effectiveDate == legs[0].effectiveDate }
"Termination dates are in alignment" by legs.all { it.terminationDate == legs[0].terminationDate }
}
}
fun checkLegAmounts(legs: List<CommonLeg>) {
requireThat {
"The notional is non zero" by legs.any { it.notional.quantity > (0).toLong() }
"The notional for all legs must be the same" by legs.all { it.notional == legs[0].notional }
}
for (leg: CommonLeg in legs) {
if (leg is FixedLeg) {
requireThat {
// TODO: Confirm: would someone really enter a swap with a negative fixed rate?
"Fixed leg rate must be positive" by leg.fixedRate.isPositive()
}
}
}
}
// TODO: After business rules discussion, add further checks to the schedules and rates
fun checkSchedules(@Suppress("UNUSED_PARAMETER") legs: List<CommonLeg>): Boolean = true
fun checkRates(@Suppress("UNUSED_PARAMETER") legs: List<CommonLeg>): Boolean = true
/**
* Compares two schedules of Floating Leg Payments, returns the difference (i.e. omissions in either leg or changes to the values).
*/
fun getFloatingLegPaymentsDifferences(payments1: Map<LocalDate, Event>, payments2: Map<LocalDate, Event>): List<Pair<LocalDate, Pair<FloatingRatePaymentEvent, FloatingRatePaymentEvent>>> {
val diff1 = payments1.filter { payments1[it.key] != payments2[it.key] }
val diff2 = payments2.filter { payments1[it.key] != payments2[it.key] }
return (diff1.keys + diff2.keys).map {
it to Pair(diff1[it] as FloatingRatePaymentEvent, diff2[it] as FloatingRatePaymentEvent)
}
}
}
class Group : GroupClauseVerifier<State, Commands, UniqueIdentifier>(AnyComposition(Agree(), Fix(), Pay(), Mature())) {
override fun groupStates(tx: TransactionForContract): List<TransactionForContract.InOutGroup<State, UniqueIdentifier>>
// Group by Trade ID for in / out states
= tx.groupStates() { state -> state.linearId }
}
class Timestamped : Clause<ContractState, Commands, Unit>() {
override fun verify(tx: TransactionForContract,
inputs: List<ContractState>,
outputs: List<ContractState>,
commands: List<AuthenticatedObject<Commands>>,
groupingKey: Unit?): Set<Commands> {
require(tx.timestamp?.midpoint != null) { "must be timestamped" }
// We return an empty set because we don't process any commands
return emptySet()
}
}
class Agree : AbstractIRSClause() {
override val requiredCommands: Set<Class<out CommandData>> = setOf(Commands.Agree::class.java)
override fun verify(tx: TransactionForContract,
inputs: List<State>,
outputs: List<State>,
commands: List<AuthenticatedObject<Commands>>,
groupingKey: UniqueIdentifier?): Set<Commands> {
val command = tx.commands.requireSingleCommand<Commands.Agree>()
val irs = outputs.filterIsInstance<State>().single()
requireThat {
"There are no in states for an agreement" by inputs.isEmpty()
"There are events in the fix schedule" by (irs.calculation.fixedLegPaymentSchedule.size > 0)
"There are events in the float schedule" by (irs.calculation.floatingLegPaymentSchedule.size > 0)
"All notionals must be non zero" by (irs.fixedLeg.notional.quantity > 0 && irs.floatingLeg.notional.quantity > 0)
"The fixed leg rate must be positive" by (irs.fixedLeg.fixedRate.isPositive())
"The currency of the notionals must be the same" by (irs.fixedLeg.notional.token == irs.floatingLeg.notional.token)
"All leg notionals must be the same" by (irs.fixedLeg.notional == irs.floatingLeg.notional)
"The effective date is before the termination date for the fixed leg" by (irs.fixedLeg.effectiveDate < irs.fixedLeg.terminationDate)
"The effective date is before the termination date for the floating leg" by (irs.floatingLeg.effectiveDate < irs.floatingLeg.terminationDate)
"The effective dates are aligned" by (irs.floatingLeg.effectiveDate == irs.fixedLeg.effectiveDate)
"The termination dates are aligned" by (irs.floatingLeg.terminationDate == irs.fixedLeg.terminationDate)
"The rates are valid" by checkRates(listOf(irs.fixedLeg, irs.floatingLeg))
"The schedules are valid" by checkSchedules(listOf(irs.fixedLeg, irs.floatingLeg))
"The fixing period date offset cannot be negative" by (irs.floatingLeg.fixingPeriodOffset >= 0)
// TODO: further tests
}
checkLegAmounts(listOf(irs.fixedLeg, irs.floatingLeg))
checkLegDates(listOf(irs.fixedLeg, irs.floatingLeg))
return setOf(command.value)
}
}
class Fix : AbstractIRSClause() {
override val requiredCommands: Set<Class<out CommandData>> = setOf(Commands.Refix::class.java)
override fun verify(tx: TransactionForContract,
inputs: List<State>,
outputs: List<State>,
commands: List<AuthenticatedObject<Commands>>,
groupingKey: UniqueIdentifier?): Set<Commands> {
val command = tx.commands.requireSingleCommand<Commands.Refix>()
val irs = outputs.filterIsInstance<State>().single()
val prevIrs = inputs.filterIsInstance<State>().single()
val paymentDifferences = getFloatingLegPaymentsDifferences(prevIrs.calculation.floatingLegPaymentSchedule, irs.calculation.floatingLegPaymentSchedule)
// Having both of these tests are "redundant" as far as verify() goes, however, by performing both
// we can relay more information back to the user in the case of failure.
requireThat {
"There is at least one difference in the IRS floating leg payment schedules" by !paymentDifferences.isEmpty()
"There is only one change in the IRS floating leg payment schedule" by (paymentDifferences.size == 1)
}
val changedRates = paymentDifferences.single().second // Ignore the date of the changed rate (we checked that earlier).
val (oldFloatingRatePaymentEvent, newFixedRatePaymentEvent) = changedRates
val fixValue = command.value.fix
// Need to check that everything is the same apart from the new fixed rate entry.
requireThat {
"The fixed leg parties are constant" by (irs.fixedLeg.fixedRatePayer == prevIrs.fixedLeg.fixedRatePayer) // Although superseded by the below test, this is included for a regression issue
"The fixed leg is constant" by (irs.fixedLeg == prevIrs.fixedLeg)
"The floating leg is constant" by (irs.floatingLeg == prevIrs.floatingLeg)
"The common values are constant" by (irs.common == prevIrs.common)
"The fixed leg payment schedule is constant" by (irs.calculation.fixedLegPaymentSchedule == prevIrs.calculation.fixedLegPaymentSchedule)
"The expression is unchanged" by (irs.calculation.expression == prevIrs.calculation.expression)
"There is only one changed payment in the floating leg" by (paymentDifferences.size == 1)
"There changed payment is a floating payment" by (oldFloatingRatePaymentEvent.rate is ReferenceRate)
"The new payment is a fixed payment" by (newFixedRatePaymentEvent.rate is FixedRate)
"The changed payments dates are aligned" by (oldFloatingRatePaymentEvent.date == newFixedRatePaymentEvent.date)
"The new payment has the correct rate" by (newFixedRatePaymentEvent.rate.ratioUnit!!.value == fixValue.value)
"The fixing is for the next required date" by (prevIrs.calculation.nextFixingDate() == fixValue.of.forDay)
"The fix payment has the same currency as the notional" by (newFixedRatePaymentEvent.flow.token == irs.floatingLeg.notional.token)
// "The fixing is not in the future " by (fixCommand) // The oracle should not have signed this .
}
return setOf(command.value)
}
}
class Pay : AbstractIRSClause() {
override val requiredCommands: Set<Class<out CommandData>> = setOf(Commands.Pay::class.java)
override fun verify(tx: TransactionForContract,
inputs: List<State>,
outputs: List<State>,
commands: List<AuthenticatedObject<Commands>>,
groupingKey: UniqueIdentifier?): Set<Commands> {
val command = tx.commands.requireSingleCommand<Commands.Pay>()
requireThat {
"Payments not supported / verifiable yet" by false
}
return setOf(command.value)
}
}
class Mature : AbstractIRSClause() {
override val requiredCommands: Set<Class<out CommandData>> = setOf(Commands.Mature::class.java)
override fun verify(tx: TransactionForContract,
inputs: List<State>,
outputs: List<State>,
commands: List<AuthenticatedObject<Commands>>,
groupingKey: UniqueIdentifier?): Set<Commands> {
val command = tx.commands.requireSingleCommand<Commands.Mature>()
val irs = inputs.filterIsInstance<State>().single()
requireThat {
"No more fixings to be applied" by (irs.calculation.nextFixingDate() == null)
"The irs is fully consumed and there is no id matched output state" by outputs.isEmpty()
}
return setOf(command.value)
}
}
}
interface Commands : CommandData {
data class Refix(val fix: Fix) : Commands // Receive interest rate from oracle, Both sides agree
class Pay : TypeOnlyCommandData(), Commands // Not implemented just yet
class Agree : TypeOnlyCommandData(), Commands // Both sides agree to trade
class Mature : TypeOnlyCommandData(), Commands // Trade has matured; no more actions. Cleanup. // TODO: Do we need this?
}
/**
* The state class contains the 4 major data classes.
*/
data class State(
val fixedLeg: FixedLeg,
val floatingLeg: FloatingLeg,
val calculation: Calculation,
val common: Common,
override val linearId: UniqueIdentifier = UniqueIdentifier(common.tradeID)
) : FixableDealState, SchedulableState {
override val contract = IRS_PROGRAM_ID
override val oracleType: ServiceType
get() = InterestRateSwap.oracleType
override val ref = common.tradeID
override val participants: List<PublicKey>
get() = parties.map { it.owningKey }
override fun isRelevant(ourKeys: Set<PublicKey>): Boolean {
return (fixedLeg.fixedRatePayer.owningKey in ourKeys) || (floatingLeg.floatingRatePayer.owningKey in ourKeys)
}
override val parties: List<Party>
get() = listOf(fixedLeg.fixedRatePayer, floatingLeg.floatingRatePayer)
override fun nextScheduledActivity(thisStateRef: StateRef, protocolLogicRefFactory: ProtocolLogicRefFactory): ScheduledActivity? {
val nextFixingOf = nextFixingOf() ?: return null
// This is perhaps not how we should determine the time point in the business day, but instead expect the schedule to detail some of these aspects
val instant = suggestInterestRateAnnouncementTimeWindow(index = nextFixingOf.name, source = floatingLeg.indexSource, date = nextFixingOf.forDay).start
return ScheduledActivity(protocolLogicRefFactory.create(TwoPartyDealProtocol.FixingRoleDecider::class.java, thisStateRef), instant)
}
override fun generateAgreement(notary: Party): TransactionBuilder = InterestRateSwap().generateAgreement(floatingLeg, fixedLeg, calculation, common, notary)
override fun generateFix(ptx: TransactionBuilder, oldState: StateAndRef<*>, fix: Fix) {
InterestRateSwap().generateFix(ptx, StateAndRef(TransactionState(this, oldState.state.notary), oldState.ref), fix)
}
override fun nextFixingOf(): FixOf? {
val date = calculation.nextFixingDate()
return if (date == null) null else {
val fixingEvent = calculation.getFixing(date)
val oracleRate = fixingEvent.rate as ReferenceRate
FixOf(oracleRate.name, date, oracleRate.tenor)
}
}
/**
* For evaluating arbitrary java on the platform.
*/
fun evaluateCalculation(businessDate: LocalDate, expression: Expression = calculation.expression): Any {
// TODO: Jexl is purely for prototyping. It may be replaced
// TODO: Whatever we do use must be secure and sandboxed
val jexl = JexlBuilder().create()
val expr = jexl.createExpression(expression.expr)
val jc = MapContext()
jc.set("fixedLeg", fixedLeg)
jc.set("floatingLeg", floatingLeg)
jc.set("calculation", calculation)
jc.set("common", common)
jc.set("currentBusinessDate", businessDate)
return expr.evaluate(jc)
}
/**
* Just makes printing it out a bit better for those who don't have 80000 column wide monitors.
*/
fun prettyPrint() = toString().replace(",", "\n")
}
/**
* This generates the agreement state and also the schedules from the initial data.
* Note: The day count, interest rate calculation etc are not finished yet, but they are demonstrable.
*/
fun generateAgreement(floatingLeg: FloatingLeg, fixedLeg: FixedLeg, calculation: Calculation,
common: Common, notary: Party): TransactionBuilder {
val fixedLegPaymentSchedule = HashMap<LocalDate, FixedRatePaymentEvent>()
var dates = BusinessCalendar.createGenericSchedule(fixedLeg.effectiveDate, fixedLeg.paymentFrequency, fixedLeg.paymentCalendar, fixedLeg.rollConvention, endDate = fixedLeg.terminationDate)
var periodStartDate = fixedLeg.effectiveDate
// Create a schedule for the fixed payments
for (periodEndDate in dates) {
val paymentDate = BusinessCalendar.getOffsetDate(periodEndDate, Frequency.Daily, fixedLeg.paymentDelay)
val paymentEvent = FixedRatePaymentEvent(
paymentDate,
periodStartDate,
periodEndDate,
fixedLeg.dayCountBasisDay,
fixedLeg.dayCountBasisYear,
fixedLeg.notional,
fixedLeg.fixedRate
)
fixedLegPaymentSchedule[paymentDate] = paymentEvent
periodStartDate = periodEndDate
}
dates = BusinessCalendar.createGenericSchedule(floatingLeg.effectiveDate,
floatingLeg.fixingsPerPayment,
floatingLeg.fixingCalendar,
floatingLeg.rollConvention,
endDate = floatingLeg.terminationDate)
val floatingLegPaymentSchedule: MutableMap<LocalDate, FloatingRatePaymentEvent> = HashMap()
periodStartDate = floatingLeg.effectiveDate
// Now create a schedule for the floating and fixes.
for (periodEndDate in dates) {
val paymentDate = BusinessCalendar.getOffsetDate(periodEndDate, Frequency.Daily, floatingLeg.paymentDelay)
val paymentEvent = FloatingRatePaymentEvent(
paymentDate,
periodStartDate,
periodEndDate,
floatingLeg.dayCountBasisDay,
floatingLeg.dayCountBasisYear,
calcFixingDate(periodStartDate, floatingLeg.fixingPeriodOffset, floatingLeg.fixingCalendar),
floatingLeg.notional,
ReferenceRate(floatingLeg.indexSource, floatingLeg.indexTenor, floatingLeg.index)
)
floatingLegPaymentSchedule[paymentDate] = paymentEvent
periodStartDate = periodEndDate
}
val newCalculation = Calculation(calculation.expression, floatingLegPaymentSchedule, fixedLegPaymentSchedule)
// Put all the above into a new State object.
val state = State(fixedLeg, floatingLeg, newCalculation, common)
return TransactionType.General.Builder(notary = notary).withItems(state, Command(Commands.Agree(), listOf(state.floatingLeg.floatingRatePayer.owningKey, state.fixedLeg.fixedRatePayer.owningKey)))
}
private fun calcFixingDate(date: LocalDate, fixingPeriodOffset: Int, calendar: BusinessCalendar): LocalDate {
return when (fixingPeriodOffset) {
0 -> date
else -> calendar.moveBusinessDays(date, DateRollDirection.BACKWARD, fixingPeriodOffset)
}
}
fun generateFix(tx: TransactionBuilder, irs: StateAndRef<State>, fixing: Fix) {
tx.addInputState(irs)
val fixedRate = FixedRate(RatioUnit(fixing.value))
tx.addOutputState(
irs.state.data.copy(calculation = irs.state.data.calculation.applyFixing(fixing.of.forDay, fixedRate)),
irs.state.notary
)
tx.addCommand(Commands.Refix(fixing), listOf(irs.state.data.floatingLeg.floatingRatePayer.owningKey, irs.state.data.fixedLeg.fixedRatePayer.owningKey))
}
}

View File

@ -1,7 +0,0 @@
package com.r3corda.contracts
fun InterestRateSwap.State.exportIRSToCSV(): String =
"Fixed Leg\n" + FixedRatePaymentEvent.CSVHeader + "\n" +
this.calculation.fixedLegPaymentSchedule.toSortedMap().values.map { it.asCSV() }.joinToString("\n") + "\n" +
"Floating Leg\n" + FloatingRatePaymentEvent.CSVHeader + "\n" +
this.calculation.floatingLegPaymentSchedule.toSortedMap().values.map { it.asCSV() }.joinToString("\n") + "\n"

View File

@ -1,88 +0,0 @@
package com.r3corda.contracts
import com.r3corda.core.contracts.Amount
import com.r3corda.core.contracts.Tenor
import java.math.BigDecimal
import java.util.*
// Things in here will move to the general utils class when we've hammered out various discussions regarding amounts, dates, oracle etc.
/**
* A utility class to prevent the various mixups between percentages, decimals, bips etc.
*/
open class RatioUnit(val value: BigDecimal) { // TODO: Discuss this type
override fun equals(other: Any?) = (other as? RatioUnit)?.value == value
override fun hashCode() = value.hashCode()
override fun toString() = value.toString()
}
/**
* A class to reprecent a percentage in an unambiguous way.
*/
open class PercentageRatioUnit(percentageAsString: String) : RatioUnit(BigDecimal(percentageAsString).divide(BigDecimal("100"))) {
override fun toString() = value.times(BigDecimal(100)).toString() + "%"
}
/**
* For the convenience of writing "5".percent
* Note that we do not currently allow 10.percent (ie no quotes) as this might get a little confusing if 0.1.percent was
* written. Additionally, there is a possibility of creating a precision error in the implicit conversion.
*/
val String.percent: PercentageRatioUnit get() = PercentageRatioUnit(this)
/**
* Parent of the Rate family. Used to denote fixed rates, floating rates, reference rates etc.
*/
open class Rate(val ratioUnit: RatioUnit? = null) {
override fun equals(other: Any?): Boolean {
if (this === other) return true
if (other?.javaClass != javaClass) return false
other as Rate
if (ratioUnit != other.ratioUnit) return false
return true
}
/**
* @returns the hash code of the ratioUnit or zero if the ratioUnit is null, as is the case for floating rate fixings
* that have not yet happened. Yet-to-be fixed floating rates need to be equal such that schedules can be tested
* for equality.
*/
override fun hashCode() = ratioUnit?.hashCode() ?: 0
override fun toString() = ratioUnit.toString()
}
/**
* A very basic subclass to represent a fixed rate.
*/
class FixedRate(ratioUnit: RatioUnit) : Rate(ratioUnit) {
fun isPositive(): Boolean = ratioUnit!!.value > BigDecimal("0.0")
override fun equals(other: Any?) = other?.javaClass == javaClass && super.equals(other)
override fun hashCode() = super.hashCode()
}
/**
* The parent class of the Floating rate classes.
*/
open class FloatingRate : Rate(null)
/**
* So a reference rate is a rate that takes its value from a source at a given date
* e.g. LIBOR 6M as of 17 March 2016. Hence it requires a source (name) and a value date in the getAsOf(..) method.
*/
class ReferenceRate(val oracle: String, val tenor: Tenor, val name: String) : FloatingRate() {
override fun toString(): String = "$name - $tenor"
}
// TODO: For further discussion.
operator fun Amount<Currency>.times(other: RatioUnit): Amount<Currency> = Amount((BigDecimal(this.quantity).multiply(other.value)).longValueExact(), this.token)
//operator fun Amount<Currency>.times(other: FixedRate): Amount<Currency> = Amount<Currency>((BigDecimal(this.pennies).multiply(other.value)).longValueExact(), this.currency)
//fun Amount<Currency>.times(other: InterestRateSwap.RatioUnit): Amount<Currency> = Amount<Currency>((BigDecimal(this.pennies).multiply(other.value)).longValueExact(), this.currency)
operator fun kotlin.Int.times(other: FixedRate): Int = BigDecimal(this).multiply(other.ratioUnit!!.value).intValueExact()
operator fun Int.times(other: Rate): Int = BigDecimal(this).multiply(other.ratioUnit!!.value).intValueExact()
operator fun Int.times(other: RatioUnit): Int = BigDecimal(this).multiply(other.value).intValueExact()

View File

@ -1,718 +0,0 @@
package com.r3corda.contracts
import com.r3corda.core.contracts.*
import com.r3corda.core.node.recordTransactions
import com.r3corda.core.seconds
import com.r3corda.core.transactions.SignedTransaction
import com.r3corda.core.utilities.DUMMY_NOTARY
import com.r3corda.core.utilities.DUMMY_NOTARY_KEY
import com.r3corda.core.utilities.TEST_TX_TIME
import com.r3corda.testing.*
import com.r3corda.testing.node.MockServices
import org.junit.Test
import java.math.BigDecimal
import java.time.LocalDate
import java.util.*
fun createDummyIRS(irsSelect: Int): InterestRateSwap.State {
return when (irsSelect) {
1 -> {
val fixedLeg = InterestRateSwap.FixedLeg(
fixedRatePayer = MEGA_CORP,
notional = 15900000.DOLLARS,
paymentFrequency = Frequency.SemiAnnual,
effectiveDate = LocalDate.of(2016, 3, 10),
effectiveDateAdjustment = null,
terminationDate = LocalDate.of(2026, 3, 10),
terminationDateAdjustment = null,
fixedRate = FixedRate(PercentageRatioUnit("1.677")),
dayCountBasisDay = DayCountBasisDay.D30,
dayCountBasisYear = DayCountBasisYear.Y360,
rollConvention = DateRollConvention.ModifiedFollowing,
dayInMonth = 10,
paymentRule = PaymentRule.InArrears,
paymentDelay = 3,
paymentCalendar = BusinessCalendar.getInstance("London", "NewYork"),
interestPeriodAdjustment = AccrualAdjustment.Adjusted
)
val floatingLeg = InterestRateSwap.FloatingLeg(
floatingRatePayer = MINI_CORP,
notional = 15900000.DOLLARS,
paymentFrequency = Frequency.Quarterly,
effectiveDate = LocalDate.of(2016, 3, 10),
effectiveDateAdjustment = null,
terminationDate = LocalDate.of(2026, 3, 10),
terminationDateAdjustment = null,
dayCountBasisDay = DayCountBasisDay.D30,
dayCountBasisYear = DayCountBasisYear.Y360,
rollConvention = DateRollConvention.ModifiedFollowing,
fixingRollConvention = DateRollConvention.ModifiedFollowing,
dayInMonth = 10,
resetDayInMonth = 10,
paymentRule = PaymentRule.InArrears,
paymentDelay = 3,
paymentCalendar = BusinessCalendar.getInstance("London", "NewYork"),
interestPeriodAdjustment = AccrualAdjustment.Adjusted,
fixingPeriodOffset = 2,
resetRule = PaymentRule.InAdvance,
fixingsPerPayment = Frequency.Quarterly,
fixingCalendar = BusinessCalendar.getInstance("London"),
index = "LIBOR",
indexSource = "TEL3750",
indexTenor = Tenor("3M")
)
val calculation = InterestRateSwap.Calculation (
// TODO: this seems to fail quite dramatically
//expression = "fixedLeg.notional * fixedLeg.fixedRate",
// TODO: How I want it to look
//expression = "( fixedLeg.notional * (fixedLeg.fixedRate)) - (floatingLeg.notional * (rateSchedule.get(context.getDate('currentDate'))))",
// How it's ended up looking, which I think is now broken but it's a WIP.
expression = Expression("( fixedLeg.notional.pennies * (fixedLeg.fixedRate.ratioUnit.value)) -" +
"(floatingLeg.notional.pennies * (calculation.fixingSchedule.get(context.getDate('currentDate')).rate.ratioUnit.value))"),
floatingLegPaymentSchedule = HashMap(),
fixedLegPaymentSchedule = HashMap()
)
val EUR = currency("EUR")
val common = InterestRateSwap.Common(
baseCurrency = EUR,
eligibleCurrency = EUR,
eligibleCreditSupport = "Cash in an Eligible Currency",
independentAmounts = Amount(0, EUR),
threshold = Amount(0, EUR),
minimumTransferAmount = Amount(250000 * 100, EUR),
rounding = Amount(10000 * 100, EUR),
valuationDateDescription = "Every Local Business Day",
notificationTime = "2:00pm London",
resolutionTime = "2:00pm London time on the first LocalBusiness Day following the date on which the notice is given ",
interestRate = ReferenceRate("T3270", Tenor("6M"), "EONIA"),
addressForTransfers = "",
exposure = UnknownType(),
localBusinessDay = BusinessCalendar.getInstance("London"),
tradeID = "trade1",
hashLegalDocs = "put hash here",
dailyInterestAmount = Expression("(CashAmount * InterestRate ) / (fixedLeg.notional.currency.currencyCode.equals('GBP')) ? 365 : 360")
)
InterestRateSwap.State(fixedLeg = fixedLeg, floatingLeg = floatingLeg, calculation = calculation, common = common)
}
2 -> {
// 10y swap, we pay 1.3% fixed 30/360 semi, rec 3m usd libor act/360 Q on 25m notional (mod foll/adj on both sides)
// I did a mock up start date 10/03/2015 10/03/2025 so you have 5 cashflows on float side that have been preset the rest are unknown
val fixedLeg = InterestRateSwap.FixedLeg(
fixedRatePayer = MEGA_CORP,
notional = 25000000.DOLLARS,
paymentFrequency = Frequency.SemiAnnual,
effectiveDate = LocalDate.of(2015, 3, 10),
effectiveDateAdjustment = null,
terminationDate = LocalDate.of(2025, 3, 10),
terminationDateAdjustment = null,
fixedRate = FixedRate(PercentageRatioUnit("1.3")),
dayCountBasisDay = DayCountBasisDay.D30,
dayCountBasisYear = DayCountBasisYear.Y360,
rollConvention = DateRollConvention.ModifiedFollowing,
dayInMonth = 10,
paymentRule = PaymentRule.InArrears,
paymentDelay = 0,
paymentCalendar = BusinessCalendar.getInstance(),
interestPeriodAdjustment = AccrualAdjustment.Adjusted
)
val floatingLeg = InterestRateSwap.FloatingLeg(
floatingRatePayer = MINI_CORP,
notional = 25000000.DOLLARS,
paymentFrequency = Frequency.Quarterly,
effectiveDate = LocalDate.of(2015, 3, 10),
effectiveDateAdjustment = null,
terminationDate = LocalDate.of(2025, 3, 10),
terminationDateAdjustment = null,
dayCountBasisDay = DayCountBasisDay.DActual,
dayCountBasisYear = DayCountBasisYear.Y360,
rollConvention = DateRollConvention.ModifiedFollowing,
fixingRollConvention = DateRollConvention.ModifiedFollowing,
dayInMonth = 10,
resetDayInMonth = 10,
paymentRule = PaymentRule.InArrears,
paymentDelay = 0,
paymentCalendar = BusinessCalendar.getInstance(),
interestPeriodAdjustment = AccrualAdjustment.Adjusted,
fixingPeriodOffset = 2,
resetRule = PaymentRule.InAdvance,
fixingsPerPayment = Frequency.Quarterly,
fixingCalendar = BusinessCalendar.getInstance(),
index = "USD LIBOR",
indexSource = "TEL3750",
indexTenor = Tenor("3M")
)
val calculation = InterestRateSwap.Calculation (
// TODO: this seems to fail quite dramatically
//expression = "fixedLeg.notional * fixedLeg.fixedRate",
// TODO: How I want it to look
//expression = "( fixedLeg.notional * (fixedLeg.fixedRate)) - (floatingLeg.notional * (rateSchedule.get(context.getDate('currentDate'))))",
// How it's ended up looking, which I think is now broken but it's a WIP.
expression = Expression("( fixedLeg.notional.pennies * (fixedLeg.fixedRate.ratioUnit.value)) -" +
"(floatingLeg.notional.pennies * (calculation.fixingSchedule.get(context.getDate('currentDate')).rate.ratioUnit.value))"),
floatingLegPaymentSchedule = HashMap(),
fixedLegPaymentSchedule = HashMap()
)
val EUR = currency("EUR")
val common = InterestRateSwap.Common(
baseCurrency = EUR,
eligibleCurrency = EUR,
eligibleCreditSupport = "Cash in an Eligible Currency",
independentAmounts = Amount(0, EUR),
threshold = Amount(0, EUR),
minimumTransferAmount = Amount(250000 * 100, EUR),
rounding = Amount(10000 * 100, EUR),
valuationDateDescription = "Every Local Business Day",
notificationTime = "2:00pm London",
resolutionTime = "2:00pm London time on the first LocalBusiness Day following the date on which the notice is given ",
interestRate = ReferenceRate("T3270", Tenor("6M"), "EONIA"),
addressForTransfers = "",
exposure = UnknownType(),
localBusinessDay = BusinessCalendar.getInstance("London"),
tradeID = "trade2",
hashLegalDocs = "put hash here",
dailyInterestAmount = Expression("(CashAmount * InterestRate ) / (fixedLeg.notional.currency.currencyCode.equals('GBP')) ? 365 : 360")
)
return InterestRateSwap.State(fixedLeg = fixedLeg, floatingLeg = floatingLeg, calculation = calculation, common = common)
}
else -> TODO("IRS number $irsSelect not defined")
}
}
class IRSTests {
@Test
fun ok() {
trade().verifies()
}
@Test
fun `ok with groups`() {
tradegroups().verifies()
}
/**
* Generate an IRS txn - we'll need it for a few things.
*/
fun generateIRSTxn(irsSelect: Int): SignedTransaction {
val dummyIRS = createDummyIRS(irsSelect)
val genTX: SignedTransaction = run {
val gtx = InterestRateSwap().generateAgreement(
fixedLeg = dummyIRS.fixedLeg,
floatingLeg = dummyIRS.floatingLeg,
calculation = dummyIRS.calculation,
common = dummyIRS.common,
notary = DUMMY_NOTARY).apply {
setTime(TEST_TX_TIME, 30.seconds)
signWith(MEGA_CORP_KEY)
signWith(MINI_CORP_KEY)
signWith(DUMMY_NOTARY_KEY)
}
gtx.toSignedTransaction()
}
return genTX
}
/**
* Just make sure it's sane.
*/
@Test
fun pprintIRS() {
val irs = singleIRS()
println(irs.prettyPrint())
}
/**
* Utility so I don't have to keep typing this.
*/
fun singleIRS(irsSelector: Int = 1): InterestRateSwap.State {
return generateIRSTxn(irsSelector).tx.outputs.map { it.data }.filterIsInstance<InterestRateSwap.State>().single()
}
/**
* Test the generate. No explicit exception as if something goes wrong, we'll find out anyway.
*/
@Test
fun generateIRS() {
// Tests aren't allowed to return things
generateIRSTxn(1)
}
/**
* Testing a simple IRS, add a few fixings and then display as CSV.
*/
@Test
fun `IRS Export test`() {
// No transactions etc required - we're just checking simple maths and export functionallity
val irs = singleIRS(2)
var newCalculation = irs.calculation
val fixings = mapOf(LocalDate.of(2015, 3, 6) to "0.6",
LocalDate.of(2015, 6, 8) to "0.75",
LocalDate.of(2015, 9, 8) to "0.8",
LocalDate.of(2015, 12, 8) to "0.55",
LocalDate.of(2016, 3, 8) to "0.644")
for ((key, value) in fixings) {
newCalculation = newCalculation.applyFixing(key, FixedRate(PercentageRatioUnit(value)))
}
val newIRS = InterestRateSwap.State(irs.fixedLeg, irs.floatingLeg, newCalculation, irs.common)
println(newIRS.exportIRSToCSV())
}
/**
* Make sure it has a schedule and the schedule has some unfixed rates.
*/
@Test
fun `next fixing date`() {
val irs = singleIRS(1)
println(irs.calculation.nextFixingDate())
}
/**
* Iterate through all the fix dates and add something.
*/
@Test
fun generateIRSandFixSome() {
val services = MockServices()
var previousTXN = generateIRSTxn(1)
previousTXN.toLedgerTransaction(services).verify()
services.recordTransactions(previousTXN)
fun currentIRS() = previousTXN.tx.outputs.map { it.data }.filterIsInstance<InterestRateSwap.State>().single()
while (true) {
val nextFix: FixOf = currentIRS().nextFixingOf() ?: break
val fixTX: SignedTransaction = run {
val tx = TransactionType.General.Builder(DUMMY_NOTARY)
val fixing = Fix(nextFix, "0.052".percent.value)
InterestRateSwap().generateFix(tx, previousTXN.tx.outRef(0), fixing)
with(tx) {
setTime(TEST_TX_TIME, 30.seconds)
signWith(MEGA_CORP_KEY)
signWith(MINI_CORP_KEY)
signWith(DUMMY_NOTARY_KEY)
}
tx.toSignedTransaction()
}
fixTX.toLedgerTransaction(services).verify()
services.recordTransactions(fixTX)
previousTXN = fixTX
}
}
// Move these later as they aren't IRS specific.
@Test
fun `test some rate objects 100 * FixedRate(5%)`() {
val r1 = FixedRate(PercentageRatioUnit("5"))
assert(100 * r1 == 5)
}
@Test
fun `expression calculation testing`() {
val dummyIRS = singleIRS()
val stuffToPrint: ArrayList<String> = arrayListOf(
"fixedLeg.notional.quantity",
"fixedLeg.fixedRate.ratioUnit",
"fixedLeg.fixedRate.ratioUnit.value",
"floatingLeg.notional.quantity",
"fixedLeg.fixedRate",
"currentBusinessDate",
"calculation.floatingLegPaymentSchedule.get(currentBusinessDate)",
"fixedLeg.notional.token.currencyCode",
"fixedLeg.notional.quantity * 10",
"fixedLeg.notional.quantity * fixedLeg.fixedRate.ratioUnit.value",
"(fixedLeg.notional.token.currencyCode.equals('GBP')) ? 365 : 360 ",
"(fixedLeg.notional.quantity * (fixedLeg.fixedRate.ratioUnit.value))"
// "calculation.floatingLegPaymentSchedule.get(context.getDate('currentDate')).rate"
// "calculation.floatingLegPaymentSchedule.get(context.getDate('currentDate')).rate.ratioUnit.value",
//"( fixedLeg.notional.pennies * (fixedLeg.fixedRate.ratioUnit.value)) - (floatingLeg.notional.pennies * (calculation.fixingSchedule.get(context.getDate('currentDate')).rate.ratioUnit.value))",
// "( fixedLeg.notional * fixedLeg.fixedRate )"
)
for (i in stuffToPrint) {
println(i)
val z = dummyIRS.evaluateCalculation(LocalDate.of(2016, 9, 15), Expression(i))
println(z.javaClass)
println(z)
println("-----------")
}
// This does not throw an exception in the test itself; it evaluates the above and they will throw if they do not pass.
}
/**
* Generates a typical transactional history for an IRS.
*/
fun trade(): LedgerDSL<TestTransactionDSLInterpreter, TestLedgerDSLInterpreter> {
val ld = LocalDate.of(2016, 3, 8)
val bd = BigDecimal("0.0063518")
return ledger {
transaction("Agreement") {
output("irs post agreement") { singleIRS() }
command(MEGA_CORP_PUBKEY) { InterestRateSwap.Commands.Agree() }
timestamp(TEST_TX_TIME)
this.verifies()
}
transaction("Fix") {
input("irs post agreement")
val postAgreement = "irs post agreement".output<InterestRateSwap.State>()
output("irs post first fixing") {
postAgreement.copy(
postAgreement.fixedLeg,
postAgreement.floatingLeg,
postAgreement.calculation.applyFixing(ld, FixedRate(RatioUnit(bd))),
postAgreement.common
)
}
command(ORACLE_PUBKEY) {
InterestRateSwap.Commands.Refix(Fix(FixOf("ICE LIBOR", ld, Tenor("3M")), bd))
}
timestamp(TEST_TX_TIME)
this.verifies()
}
}
}
@Test
fun `ensure failure occurs when there are inbound states for an agreement command`() {
val irs = singleIRS()
transaction {
input() { irs }
output("irs post agreement") { irs }
command(MEGA_CORP_PUBKEY) { InterestRateSwap.Commands.Agree() }
timestamp(TEST_TX_TIME)
this `fails with` "There are no in states for an agreement"
}
}
@Test
fun `ensure failure occurs when no events in fix schedule`() {
val irs = singleIRS()
val emptySchedule = HashMap<LocalDate, FixedRatePaymentEvent>()
transaction {
output() {
irs.copy(calculation = irs.calculation.copy(fixedLegPaymentSchedule = emptySchedule))
}
command(MEGA_CORP_PUBKEY) { InterestRateSwap.Commands.Agree() }
timestamp(TEST_TX_TIME)
this `fails with` "There are events in the fix schedule"
}
}
@Test
fun `ensure failure occurs when no events in floating schedule`() {
val irs = singleIRS()
val emptySchedule = HashMap<LocalDate, FloatingRatePaymentEvent>()
transaction {
output() {
irs.copy(calculation = irs.calculation.copy(floatingLegPaymentSchedule = emptySchedule))
}
command(MEGA_CORP_PUBKEY) { InterestRateSwap.Commands.Agree() }
timestamp(TEST_TX_TIME)
this `fails with` "There are events in the float schedule"
}
}
@Test
fun `ensure notionals are non zero`() {
val irs = singleIRS()
transaction {
output() {
irs.copy(irs.fixedLeg.copy(notional = irs.fixedLeg.notional.copy(quantity = 0)))
}
command(MEGA_CORP_PUBKEY) { InterestRateSwap.Commands.Agree() }
timestamp(TEST_TX_TIME)
this `fails with` "All notionals must be non zero"
}
transaction {
output() {
irs.copy(irs.fixedLeg.copy(notional = irs.floatingLeg.notional.copy(quantity = 0)))
}
command(MEGA_CORP_PUBKEY) { InterestRateSwap.Commands.Agree() }
timestamp(TEST_TX_TIME)
this `fails with` "All notionals must be non zero"
}
}
@Test
fun `ensure positive rate on fixed leg`() {
val irs = singleIRS()
val modifiedIRS = irs.copy(fixedLeg = irs.fixedLeg.copy(fixedRate = FixedRate(PercentageRatioUnit("-0.1"))))
transaction {
output() {
modifiedIRS
}
command(MEGA_CORP_PUBKEY) { InterestRateSwap.Commands.Agree() }
timestamp(TEST_TX_TIME)
this `fails with` "The fixed leg rate must be positive"
}
}
/**
* This will be modified once we adapt the IRS to be cross currency.
*/
@Test
fun `ensure same currency notionals`() {
val irs = singleIRS()
val modifiedIRS = irs.copy(fixedLeg = irs.fixedLeg.copy(notional = Amount(irs.fixedLeg.notional.quantity, Currency.getInstance("JPY"))))
transaction {
output() {
modifiedIRS
}
command(MEGA_CORP_PUBKEY) { InterestRateSwap.Commands.Agree() }
timestamp(TEST_TX_TIME)
this `fails with` "The currency of the notionals must be the same"
}
}
@Test
fun `ensure notional amounts are equal`() {
val irs = singleIRS()
val modifiedIRS = irs.copy(fixedLeg = irs.fixedLeg.copy(notional = Amount(irs.floatingLeg.notional.quantity + 1, irs.floatingLeg.notional.token)))
transaction {
output() {
modifiedIRS
}
command(MEGA_CORP_PUBKEY) { InterestRateSwap.Commands.Agree() }
timestamp(TEST_TX_TIME)
this `fails with` "All leg notionals must be the same"
}
}
@Test
fun `ensure trade date and termination date checks are done pt1`() {
val irs = singleIRS()
val modifiedIRS1 = irs.copy(fixedLeg = irs.fixedLeg.copy(terminationDate = irs.fixedLeg.effectiveDate.minusDays(1)))
transaction {
output() {
modifiedIRS1
}
command(MEGA_CORP_PUBKEY) { InterestRateSwap.Commands.Agree() }
timestamp(TEST_TX_TIME)
this `fails with` "The effective date is before the termination date for the fixed leg"
}
val modifiedIRS2 = irs.copy(floatingLeg = irs.floatingLeg.copy(terminationDate = irs.floatingLeg.effectiveDate.minusDays(1)))
transaction {
output() {
modifiedIRS2
}
command(MEGA_CORP_PUBKEY) { InterestRateSwap.Commands.Agree() }
timestamp(TEST_TX_TIME)
this `fails with` "The effective date is before the termination date for the floating leg"
}
}
@Test
fun `ensure trade date and termination date checks are done pt2`() {
val irs = singleIRS()
val modifiedIRS3 = irs.copy(floatingLeg = irs.floatingLeg.copy(terminationDate = irs.fixedLeg.terminationDate.minusDays(1)))
transaction {
output() {
modifiedIRS3
}
command(MEGA_CORP_PUBKEY) { InterestRateSwap.Commands.Agree() }
timestamp(TEST_TX_TIME)
this `fails with` "The termination dates are aligned"
}
val modifiedIRS4 = irs.copy(floatingLeg = irs.floatingLeg.copy(effectiveDate = irs.fixedLeg.effectiveDate.minusDays(1)))
transaction {
output() {
modifiedIRS4
}
command(MEGA_CORP_PUBKEY) { InterestRateSwap.Commands.Agree() }
timestamp(TEST_TX_TIME)
this `fails with` "The effective dates are aligned"
}
}
@Test
fun `various fixing tests`() {
val ld = LocalDate.of(2016, 3, 8)
val bd = BigDecimal("0.0063518")
transaction {
output("irs post agreement") { singleIRS() }
command(MEGA_CORP_PUBKEY) { InterestRateSwap.Commands.Agree() }
timestamp(TEST_TX_TIME)
this.verifies()
}
val oldIRS = singleIRS(1)
val newIRS = oldIRS.copy(oldIRS.fixedLeg,
oldIRS.floatingLeg,
oldIRS.calculation.applyFixing(ld, FixedRate(RatioUnit(bd))),
oldIRS.common)
transaction {
input() {
oldIRS
}
// Templated tweak for reference. A corrent fixing applied should be ok
tweak {
command(ORACLE_PUBKEY) {
InterestRateSwap.Commands.Refix(Fix(FixOf("ICE LIBOR", ld, Tenor("3M")), bd))
}
timestamp(TEST_TX_TIME)
output() { newIRS }
this.verifies()
}
// This test makes sure that verify confirms the fixing was applied and there is a difference in the old and new
tweak {
command(ORACLE_PUBKEY) { InterestRateSwap.Commands.Refix(Fix(FixOf("ICE LIBOR", ld, Tenor("3M")), bd)) }
timestamp(TEST_TX_TIME)
output() { oldIRS }
this `fails with` "There is at least one difference in the IRS floating leg payment schedules"
}
// This tests tries to sneak in a change to another fixing (which may or may not be the latest one)
tweak {
command(ORACLE_PUBKEY) { InterestRateSwap.Commands.Refix(Fix(FixOf("ICE LIBOR", ld, Tenor("3M")), bd)) }
timestamp(TEST_TX_TIME)
val firstResetKey = newIRS.calculation.floatingLegPaymentSchedule.keys.first()
val firstResetValue = newIRS.calculation.floatingLegPaymentSchedule[firstResetKey]
val modifiedFirstResetValue = firstResetValue!!.copy(notional = Amount(firstResetValue.notional.quantity, Currency.getInstance("JPY")))
output() {
newIRS.copy(
newIRS.fixedLeg,
newIRS.floatingLeg,
newIRS.calculation.copy(floatingLegPaymentSchedule = newIRS.calculation.floatingLegPaymentSchedule.plus(
Pair(firstResetKey, modifiedFirstResetValue))),
newIRS.common
)
}
this `fails with` "There is only one change in the IRS floating leg payment schedule"
}
// This tests modifies the payment currency for the fixing
tweak {
command(ORACLE_PUBKEY) { InterestRateSwap.Commands.Refix(Fix(FixOf("ICE LIBOR", ld, Tenor("3M")), bd)) }
timestamp(TEST_TX_TIME)
val latestReset = newIRS.calculation.floatingLegPaymentSchedule.filter { it.value.rate is FixedRate }.maxBy { it.key }
val modifiedLatestResetValue = latestReset!!.value.copy(notional = Amount(latestReset.value.notional.quantity, Currency.getInstance("JPY")))
output() {
newIRS.copy(
newIRS.fixedLeg,
newIRS.floatingLeg,
newIRS.calculation.copy(floatingLegPaymentSchedule = newIRS.calculation.floatingLegPaymentSchedule.plus(
Pair(latestReset.key, modifiedLatestResetValue))),
newIRS.common
)
}
this `fails with` "The fix payment has the same currency as the notional"
}
}
}
/**
* This returns an example of transactions that are grouped by TradeId and then a fixing applied.
* It's important to make the tradeID different for two reasons, the hashes will be the same and all sorts of confusion will
* result and the grouping won't work either.
* In reality, the only fields that should be in common will be the next fixing date and the reference rate.
*/
fun tradegroups(): LedgerDSL<TestTransactionDSLInterpreter, TestLedgerDSLInterpreter> {
val ld1 = LocalDate.of(2016, 3, 8)
val bd1 = BigDecimal("0.0063518")
val irs = singleIRS()
return ledger {
transaction("Agreement") {
output("irs post agreement1") {
irs.copy(
irs.fixedLeg,
irs.floatingLeg,
irs.calculation,
irs.common.copy(tradeID = "t1")
)
}
command(MEGA_CORP_PUBKEY) { InterestRateSwap.Commands.Agree() }
timestamp(TEST_TX_TIME)
this.verifies()
}
transaction("Agreement") {
output("irs post agreement2") {
irs.copy(
linearId = UniqueIdentifier("t2"),
fixedLeg = irs.fixedLeg,
floatingLeg = irs.floatingLeg,
calculation = irs.calculation,
common = irs.common.copy(tradeID = "t2")
)
}
command(MEGA_CORP_PUBKEY) { InterestRateSwap.Commands.Agree() }
timestamp(TEST_TX_TIME)
this.verifies()
}
transaction("Fix") {
input("irs post agreement1")
input("irs post agreement2")
val postAgreement1 = "irs post agreement1".output<InterestRateSwap.State>()
output("irs post first fixing1") {
postAgreement1.copy(
postAgreement1.fixedLeg,
postAgreement1.floatingLeg,
postAgreement1.calculation.applyFixing(ld1, FixedRate(RatioUnit(bd1))),
postAgreement1.common.copy(tradeID = "t1")
)
}
val postAgreement2 = "irs post agreement2".output<InterestRateSwap.State>()
output("irs post first fixing2") {
postAgreement2.copy(
postAgreement2.fixedLeg,
postAgreement2.floatingLeg,
postAgreement2.calculation.applyFixing(ld1, FixedRate(RatioUnit(bd1))),
postAgreement2.common.copy(tradeID = "t2")
)
}
command(ORACLE_PUBKEY) {
InterestRateSwap.Commands.Refix(Fix(FixOf("ICE LIBOR", ld1, Tenor("3M")), bd1))
}
timestamp(TEST_TX_TIME)
this.verifies()
}
}
}
}

View File

@ -15,6 +15,7 @@ import java.io.BufferedInputStream
import java.io.InputStream import java.io.InputStream
import java.math.BigDecimal import java.math.BigDecimal
import java.nio.file.Files import java.nio.file.Files
import java.nio.file.LinkOption
import java.nio.file.Path import java.nio.file.Path
import java.time.Duration import java.time.Duration
import java.time.temporal.Temporal import java.time.temporal.Temporal
@ -89,6 +90,7 @@ inline fun <T> SettableFuture<T>.catch(block: () -> T) {
} }
fun <R> Path.use(block: (InputStream) -> R): R = Files.newInputStream(this).use(block) fun <R> Path.use(block: (InputStream) -> R): R = Files.newInputStream(this).use(block)
fun Path.exists(vararg options: LinkOption): Boolean = Files.exists(this, *options)
// Simple infix function to add back null safety that the JDK lacks: timeA until timeB // Simple infix function to add back null safety that the JDK lacks: timeA until timeB
infix fun Temporal.until(endExclusive: Temporal) = Duration.between(this, endExclusive) infix fun Temporal.until(endExclusive: Temporal) = Duration.between(this, endExclusive)
@ -290,7 +292,7 @@ fun <T, I: Comparable<I>> Iterable<T>.isOrderedAndUnique(extractId: T.() -> I):
if (lastLast == null) { if (lastLast == null) {
true true
} else { } else {
lastLast.compareTo(extractId(it)) < 0 lastLast < extractId(it)
} }
} }
} }

View File

@ -27,7 +27,7 @@ import kotlin.reflect.primaryConstructor
*/ */
class ProtocolLogicRefFactory(private val protocolWhitelist: Map<String, Set<String>>) : SingletonSerializeAsToken() { class ProtocolLogicRefFactory(private val protocolWhitelist: Map<String, Set<String>>) : SingletonSerializeAsToken() {
constructor() : this(mapOf(Pair(TwoPartyDealProtocol.FixingRoleDecider::class.java.name, setOf(StateRef::class.java.name, Duration::class.java.name)))) constructor() : this(mapOf())
// Pending real dependence on AppContext for class loading etc // Pending real dependence on AppContext for class loading etc
@Suppress("UNUSED_PARAMETER") @Suppress("UNUSED_PARAMETER")

View File

@ -1,19 +0,0 @@
package com.r3corda.core.utilities
import java.time.*
/**
* This whole file exists as short cuts to get demos working. In reality we'd have static data and/or rules engine
* defining things like this. It currently resides in the core module because it needs to be visible to the IRS
* contract.
*/
// We at some future point may implement more than just this constant announcement window and thus use the params.
@Suppress("UNUSED_PARAMETER")
fun suggestInterestRateAnnouncementTimeWindow(index: String, source: String, date: LocalDate): TimeWindow {
// TODO: we would ordinarily convert clock to same time zone as the index/source would announce in
// and suggest an announcement time for the interest rate
// Here we apply a blanket announcement time of 11:45 London irrespective of source or index
val time = LocalTime.of(11, 45)
val zoneId = ZoneId.of("Europe/London")
return TimeWindow(ZonedDateTime.of(date, time, zoneId).toInstant(), Duration.ofHours(24))
}

View File

@ -1,109 +0,0 @@
package com.r3corda.protocols
import co.paralleluniverse.fibers.Suspendable
import com.r3corda.core.contracts.Fix
import com.r3corda.core.contracts.FixOf
import com.r3corda.core.crypto.DigitalSignature
import com.r3corda.core.crypto.Party
import com.r3corda.core.protocols.ProtocolLogic
import com.r3corda.core.transactions.TransactionBuilder
import com.r3corda.core.transactions.WireTransaction
import com.r3corda.core.utilities.ProgressTracker
import com.r3corda.core.utilities.suggestInterestRateAnnouncementTimeWindow
import com.r3corda.protocols.RatesFixProtocol.FixOutOfRange
import java.math.BigDecimal
import java.time.Instant
import java.util.*
// This code is unit tested in NodeInterestRates.kt
/**
* This protocol queries the given oracle for an interest rate fix, and if it is within the given tolerance embeds the
* fix in the transaction and then proceeds to get the oracle to sign it. Although the [call] method combines the query
* and signing step, you can run the steps individually by constructing this object and then using the public methods
* for each step.
*
* @throws FixOutOfRange if the returned fix was further away from the expected rate by the given amount.
*/
open class RatesFixProtocol(protected val tx: TransactionBuilder,
private val oracle: Party,
private val fixOf: FixOf,
private val expectedRate: BigDecimal,
private val rateTolerance: BigDecimal,
override val progressTracker: ProgressTracker = RatesFixProtocol.tracker(fixOf.name)) : ProtocolLogic<Unit>() {
companion object {
class QUERYING(val name: String) : ProgressTracker.Step("Querying oracle for $name interest rate")
object WORKING : ProgressTracker.Step("Working with data returned by oracle")
object SIGNING : ProgressTracker.Step("Requesting confirmation signature from interest rate oracle")
fun tracker(fixName: String) = ProgressTracker(QUERYING(fixName), WORKING, SIGNING)
}
class FixOutOfRange(@Suppress("unused") val byAmount: BigDecimal) : Exception("Fix out of range by $byAmount")
data class QueryRequest(val queries: List<FixOf>, val deadline: Instant)
data class SignRequest(val tx: WireTransaction)
@Suspendable
override fun call() {
progressTracker.currentStep = progressTracker.steps[1]
val fix = subProtocol(FixQueryProtocol(fixOf, oracle))
progressTracker.currentStep = WORKING
checkFixIsNearExpected(fix)
tx.addCommand(fix, oracle.owningKey)
beforeSigning(fix)
progressTracker.currentStep = SIGNING
val signature = subProtocol(FixSignProtocol(tx, oracle))
tx.addSignatureUnchecked(signature)
}
/**
* You can override this to perform any additional work needed after the fix is added to the transaction but
* before it's sent back to the oracle for signing (for example, adding output states that depend on the fix).
*/
@Suspendable
protected open fun beforeSigning(fix: Fix) {
}
private fun checkFixIsNearExpected(fix: Fix) {
val delta = (fix.value - expectedRate).abs()
if (delta > rateTolerance) {
// TODO: Kick to a user confirmation / ui flow if it's out of bounds instead of raising an exception.
throw FixOutOfRange(delta)
}
}
class FixQueryProtocol(val fixOf: FixOf, val oracle: Party) : ProtocolLogic<Fix>() {
@Suspendable
override fun call(): Fix {
val deadline = suggestInterestRateAnnouncementTimeWindow(fixOf.name, oracle.name, fixOf.forDay).end
// TODO: add deadline to receive
val resp = sendAndReceive<ArrayList<Fix>>(oracle, QueryRequest(listOf(fixOf), deadline))
return resp.unwrap {
val fix = it.first()
// Check the returned fix is for what we asked for.
check(fix.of == fixOf)
fix
}
}
}
class FixSignProtocol(val tx: TransactionBuilder, val oracle: Party) : ProtocolLogic<DigitalSignature.LegallyIdentifiable>() {
@Suspendable
override fun call(): DigitalSignature.LegallyIdentifiable {
val wtx = tx.toWireTransaction()
val resp = sendAndReceive<DigitalSignature.LegallyIdentifiable>(oracle, SignRequest(wtx))
return resp.unwrap { sig ->
check(sig.signer == oracle)
tx.checkSignature(sig)
sig
}
}
}
}

View File

@ -306,126 +306,4 @@ object TwoPartyDealProtocol {
} }
} }
/**
* One side of the fixing protocol for an interest rate swap, but could easily be generalised further.
*
* Do not infer too much from the name of the class. This is just to indicate that it is the "side"
* of the protocol that is run by the party with the fixed leg of swap deal, which is the basis for deciding
* who does what in the protocol.
*/
class Fixer(override val otherParty: Party,
override val progressTracker: ProgressTracker = Secondary.tracker()) : Secondary<FixingSession>() {
private lateinit var txState: TransactionState<*>
private lateinit var deal: FixableDealState
override fun validateHandshake(handshake: Handshake<FixingSession>): Handshake<FixingSession> {
logger.trace { "Got fixing request for: ${handshake.payload}" }
txState = serviceHub.loadState(handshake.payload.ref)
deal = txState.data as FixableDealState
// validate the party that initiated is the one on the deal and that the recipient corresponds with it.
// TODO: this is in no way secure and will be replaced by general session initiation logic in the future
val myName = serviceHub.myInfo.legalIdentity.name
// Also check we are one of the parties
deal.parties.filter { it.name == myName }.single()
return handshake
}
@Suspendable
override fun assembleSharedTX(handshake: Handshake<FixingSession>): Pair<TransactionBuilder, List<PublicKey>> {
@Suppress("UNCHECKED_CAST")
val fixOf = deal.nextFixingOf()!!
// TODO Do we need/want to substitute in new public keys for the Parties?
val myName = serviceHub.myInfo.legalIdentity.name
val myOldParty = deal.parties.single { it.name == myName }
val newDeal = deal
val ptx = TransactionType.General.Builder(txState.notary)
val oracle = serviceHub.networkMapCache.get(handshake.payload.oracleType).first()
val addFixing = object : RatesFixProtocol(ptx, oracle.serviceIdentities(handshake.payload.oracleType).first(), fixOf, BigDecimal.ZERO, BigDecimal.ONE) {
@Suspendable
override fun beforeSigning(fix: Fix) {
newDeal.generateFix(ptx, StateAndRef(txState, handshake.payload.ref), fix)
// And add a request for timestamping: it may be that none of the contracts need this! But it can't hurt
// to have one.
ptx.setTime(serviceHub.clock.instant(), 30.seconds)
}
}
subProtocol(addFixing)
return Pair(ptx, arrayListOf(myOldParty.owningKey))
}
}
/**
* One side of the fixing protocol for an interest rate swap, but could easily be generalised furher.
*
* As per the [Fixer], do not infer too much from this class name in terms of business roles. This
* is just the "side" of the protocol run by the party with the floating leg as a way of deciding who
* does what in the protocol.
*/
class Floater(override val otherParty: Party,
override val payload: FixingSession,
override val progressTracker: ProgressTracker = Primary.tracker()) : Primary() {
@Suppress("UNCHECKED_CAST")
internal val dealToFix: StateAndRef<FixableDealState> by TransientProperty {
val state = serviceHub.loadState(payload.ref) as TransactionState<FixableDealState>
StateAndRef(state, payload.ref)
}
override val myKeyPair: KeyPair get() {
val myName = serviceHub.myInfo.legalIdentity.name
val publicKey = dealToFix.state.data.parties.filter { it.name == myName }.single().owningKey
return serviceHub.keyManagementService.toKeyPair(publicKey)
}
override val notaryNode: NodeInfo get() =
serviceHub.networkMapCache.notaryNodes.filter { it.notaryIdentity == dealToFix.state.notary }.single()
}
/** Used to set up the session between [Floater] and [Fixer] */
data class FixingSession(val ref: StateRef, val oracleType: ServiceType)
/**
* This protocol looks at the deal and decides whether to be the Fixer or Floater role in agreeing a fixing.
*
* It is kicked off as an activity on both participant nodes by the scheduler when it's time for a fixing. If the
* Fixer role is chosen, then that will be initiated by the [FixingSession] message sent from the other party and
* handled by the [FixingSessionInitiationHandler].
*
* TODO: Replace [FixingSession] and [FixingSessionInitiationHandler] with generic session initiation logic once it exists.
*/
class FixingRoleDecider(val ref: StateRef,
override val progressTracker: ProgressTracker = tracker()) : ProtocolLogic<Unit>() {
companion object {
class LOADING() : ProgressTracker.Step("Loading state to decide fixing role")
fun tracker() = ProgressTracker(LOADING())
}
@Suspendable
override fun call(): Unit {
progressTracker.nextStep()
val dealToFix = serviceHub.loadState(ref)
// TODO: this is not the eventual mechanism for identifying the parties
val fixableDeal = (dealToFix.data as FixableDealState)
val sortedParties = fixableDeal.parties.sortedBy { it.name }
if (sortedParties[0].name == serviceHub.myInfo.legalIdentity.name) {
val fixing = FixingSession(ref, fixableDeal.oracleType)
// Start the Floater which will then kick-off the Fixer
subProtocol(Floater(sortedParties[1], fixing))
}
}
}
} }

View File

@ -0,0 +1,24 @@
package com.r3corda.core
import kotlin.test.assertFalse
import kotlin.test.assertTrue
class UtilsTest {
fun `ordered and unique basic`() {
val basic = listOf(1, 2, 3, 5, 8)
assertTrue(basic.isOrderedAndUnique { this })
val negative = listOf(-1, 2, 5)
assertTrue(negative.isOrderedAndUnique { this })
}
fun `ordered and unique duplicate`() {
val duplicated = listOf(1, 2, 2, 3, 5, 8)
assertFalse(duplicated.isOrderedAndUnique { this })
}
fun `ordered and unique out of sequence`() {
val mixed = listOf(3, 1, 2, 8, 5)
assertFalse(mixed.isOrderedAndUnique { this })
}
}

View File

@ -3,6 +3,7 @@ package com.r3corda.docs
import com.google.common.net.HostAndPort import com.google.common.net.HostAndPort
import com.r3corda.client.CordaRPCClient import com.r3corda.client.CordaRPCClient
import com.r3corda.core.transactions.SignedTransaction import com.r3corda.core.transactions.SignedTransaction
import com.r3corda.node.services.config.NodeSSLConfiguration
import org.graphstream.graph.Edge import org.graphstream.graph.Edge
import org.graphstream.graph.Node import org.graphstream.graph.Node
import org.graphstream.graph.implementations.SingleGraph import org.graphstream.graph.implementations.SingleGraph
@ -26,12 +27,18 @@ fun main(args: Array<String>) {
} }
val nodeAddress = HostAndPort.fromString(args[0]) val nodeAddress = HostAndPort.fromString(args[0])
val printOrVisualise = PrintOrVisualise.valueOf(args[1]) val printOrVisualise = PrintOrVisualise.valueOf(args[1])
val certificatesPath = Paths.get("build/trader-demo/buyer/certificates") val sslConfig = object : NodeSSLConfiguration {
override val certificatesPath = Paths.get("build/trader-demo/buyer/certificates")
override val keyStorePassword = "cordacadevpass"
override val trustStorePassword = "trustpass"
}
// END 1 // END 1
// START 2 // START 2
val client = CordaRPCClient(nodeAddress, certificatesPath) val username = System.console().readLine("Enter username: ")
client.start() val password = String(System.console().readPassword("Enter password: "))
val client = CordaRPCClient(nodeAddress, sslConfig)
client.start(username, password)
val proxy = client.proxy() val proxy = client.proxy()
// END 2 // END 2
@ -65,7 +72,7 @@ fun main(args: Array<String>) {
futureTransactions.subscribe { transaction -> futureTransactions.subscribe { transaction ->
graph.addNode<Node>("${transaction.id}") graph.addNode<Node>("${transaction.id}")
transaction.tx.inputs.forEach { ref -> transaction.tx.inputs.forEach { ref ->
graph.addEdge<Edge>("${ref}", "${ref.txhash}", "${transaction.id}") graph.addEdge<Edge>("$ref", "${ref.txhash}", "${transaction.id}")
} }
} }
graph.display() graph.display()

View File

@ -16,7 +16,7 @@ we also need to access the certificates of the node, we will access the node's `
:start-after: START 1 :start-after: START 1
:end-before: END 1 :end-before: END 1
Now we can connect to the node itself: Now we can connect to the node itself using a valid RPC login. By default the user `user1` is available with password `test`.
.. literalinclude:: example-code/src/main/kotlin/com/r3corda/docs/ClientRpcTutorial.kt .. literalinclude:: example-code/src/main/kotlin/com/r3corda/docs/ClientRpcTutorial.kt
:language: kotlin :language: kotlin

View File

@ -6,7 +6,7 @@ import com.r3corda.core.node.services.ServiceInfo
import com.r3corda.explorer.model.IdentityModel import com.r3corda.explorer.model.IdentityModel
import com.r3corda.node.driver.PortAllocation import com.r3corda.node.driver.PortAllocation
import com.r3corda.node.driver.driver import com.r3corda.node.driver.driver
import com.r3corda.node.driver.startClient import com.r3corda.node.services.config.configureTestSSL
import com.r3corda.node.services.transactions.SimpleNotaryService import com.r3corda.node.services.transactions.SimpleNotaryService
import javafx.stage.Stage import javafx.stage.Stage
import tornadofx.App import tornadofx.App
@ -32,14 +32,12 @@ class Main : App() {
val aliceNodeFuture = startNode("Alice") val aliceNodeFuture = startNode("Alice")
val notaryNodeFuture = startNode("Notary", advertisedServices = setOf(ServiceInfo(SimpleNotaryService.type))) val notaryNodeFuture = startNode("Notary", advertisedServices = setOf(ServiceInfo(SimpleNotaryService.type)))
val aliceNode = aliceNodeFuture.get() val aliceNode = aliceNodeFuture.get().nodeInfo
val notaryNode = notaryNodeFuture.get() val notaryNode = notaryNodeFuture.get().nodeInfo
val aliceClient = startClient(aliceNode).get()
Models.get<IdentityModel>(Main::class).notary.set(notaryNode.notaryIdentity) Models.get<IdentityModel>(Main::class).notary.set(notaryNode.notaryIdentity)
Models.get<IdentityModel>(Main::class).myIdentity.set(aliceNode.legalIdentity) Models.get<IdentityModel>(Main::class).myIdentity.set(aliceNode.legalIdentity)
Models.get<NodeMonitorModel>(Main::class).register(aliceNode, aliceClient.config.certificatesPath) Models.get<NodeMonitorModel>(Main::class).register(aliceNode, configureTestSSL(), "user1", "test")
startNode("Bob").get() startNode("Bob").get()

View File

@ -134,6 +134,9 @@ dependencies {
compile "org.hibernate:hibernate-core:5.2.2.Final" compile "org.hibernate:hibernate-core:5.2.2.Final"
compile "org.hibernate:hibernate-java8:5.2.2.Final" compile "org.hibernate:hibernate-java8:5.2.2.Final"
// Capsule is a library for building independently executable fat JARs.
compile 'co.paralleluniverse:capsule:1.0.3'
// Integration test helpers // Integration test helpers
integrationTestCompile 'junit:junit:4.12' integrationTestCompile 'junit:junit:4.12'

View File

@ -11,14 +11,8 @@ import org.junit.Test
class DriverTests { class DriverTests {
companion object { companion object {
fun nodeMustBeUp(networkMapCache: NetworkMapCache, nodeInfo: NodeInfo, nodeName: String) { fun nodeMustBeUp(nodeInfo: NodeInfo, nodeName: String) {
val hostAndPort = ArtemisMessagingComponent.toHostAndPort(nodeInfo.address) val hostAndPort = ArtemisMessagingComponent.toHostAndPort(nodeInfo.address)
// Check that the node is registered in the network map
poll("network map cache for $nodeName") {
networkMapCache.get().firstOrNull {
it.legalIdentity.name == nodeName
}
}
// Check that the port is bound // Check that the port is bound
addressMustBeBound(hostAndPort) addressMustBeBound(hostAndPort)
} }
@ -36,31 +30,31 @@ class DriverTests {
val notary = startNode("TestNotary", setOf(ServiceInfo(SimpleNotaryService.type))) val notary = startNode("TestNotary", setOf(ServiceInfo(SimpleNotaryService.type)))
val regulator = startNode("Regulator", setOf(ServiceInfo(RegulatorService.type))) val regulator = startNode("Regulator", setOf(ServiceInfo(RegulatorService.type)))
nodeMustBeUp(networkMapCache, notary.get(), "TestNotary") nodeMustBeUp(notary.get().nodeInfo, "TestNotary")
nodeMustBeUp(networkMapCache, regulator.get(), "Regulator") nodeMustBeUp(regulator.get().nodeInfo, "Regulator")
Pair(notary.get(), regulator.get()) Pair(notary.get(), regulator.get())
} }
nodeMustBeDown(notary) nodeMustBeDown(notary.nodeInfo)
nodeMustBeDown(regulator) nodeMustBeDown(regulator.nodeInfo)
} }
@Test @Test
fun startingNodeWithNoServicesWorks() { fun startingNodeWithNoServicesWorks() {
val noService = driver { val noService = driver {
val noService = startNode("NoService") val noService = startNode("NoService")
nodeMustBeUp(networkMapCache, noService.get(), "NoService") nodeMustBeUp(noService.get().nodeInfo, "NoService")
noService.get() noService.get()
} }
nodeMustBeDown(noService) nodeMustBeDown(noService.nodeInfo)
} }
@Test @Test
fun randomFreePortAllocationWorks() { fun randomFreePortAllocationWorks() {
val nodeInfo = driver(portAllocation = PortAllocation.RandomFree()) { val nodeInfo = driver(portAllocation = PortAllocation.RandomFree()) {
val nodeInfo = startNode("NoService") val nodeInfo = startNode("NoService")
nodeMustBeUp(networkMapCache, nodeInfo.get(), "NoService") nodeMustBeUp(nodeInfo.get().nodeInfo, "NoService")
nodeInfo.get() nodeInfo.get()
} }
nodeMustBeDown(nodeInfo) nodeMustBeDown(nodeInfo.nodeInfo)
} }
} }

View File

@ -4,6 +4,7 @@ import com.r3corda.core.contracts.*
import com.r3corda.node.api.StatesQuery import com.r3corda.node.api.StatesQuery
import com.r3corda.core.crypto.DigitalSignature import com.r3corda.core.crypto.DigitalSignature
import com.r3corda.core.crypto.SecureHash import com.r3corda.core.crypto.SecureHash
import com.r3corda.core.node.NodeInfo
import com.r3corda.core.serialization.SerializedBytes import com.r3corda.core.serialization.SerializedBytes
import com.r3corda.core.transactions.SignedTransaction import com.r3corda.core.transactions.SignedTransaction
import com.r3corda.core.transactions.WireTransaction import com.r3corda.core.transactions.WireTransaction
@ -41,6 +42,16 @@ interface APIServer {
@Produces(MediaType.TEXT_PLAIN) @Produces(MediaType.TEXT_PLAIN)
fun status(): Response fun status(): Response
/**
* Report this node's configuration and identities.
* Currently tunnels the NodeInfo as an encoding of the Kryo serialised form.
* TODO this functionality should be available via the RPC
*/
@GET
@Path("info")
@Produces(MediaType.APPLICATION_JSON)
fun info(): NodeInfo
/** /**
* Query your "local" states (containing only outputs involving you) and return the hashes & indexes associated with them * Query your "local" states (containing only outputs involving you) and return the hashes & indexes associated with them
* to probably be later inflated by fetchLedgerTransactions() or fetchStates() although because immutable you can cache them * to probably be later inflated by fetchLedgerTransactions() or fetchStates() although because immutable you can cache them

View File

@ -1,5 +1,7 @@
package com.r3corda.node.driver package com.r3corda.node.driver
import com.fasterxml.jackson.databind.ObjectMapper
import com.fasterxml.jackson.databind.module.SimpleModule
import com.google.common.net.HostAndPort import com.google.common.net.HostAndPort
import com.r3corda.core.ThreadBox import com.r3corda.core.ThreadBox
import com.r3corda.core.crypto.Party import com.r3corda.core.crypto.Party
@ -7,6 +9,7 @@ import com.r3corda.core.crypto.generateKeyPair
import com.r3corda.core.node.NodeInfo import com.r3corda.core.node.NodeInfo
import com.r3corda.core.node.services.NetworkMapCache import com.r3corda.core.node.services.NetworkMapCache
import com.r3corda.core.node.services.ServiceInfo import com.r3corda.core.node.services.ServiceInfo
import com.r3corda.node.serialization.NodeClock
import com.r3corda.node.services.config.ConfigHelper import com.r3corda.node.services.config.ConfigHelper
import com.r3corda.node.services.config.FullNodeConfiguration import com.r3corda.node.services.config.FullNodeConfiguration
import com.r3corda.node.services.messaging.ArtemisMessagingComponent import com.r3corda.node.services.messaging.ArtemisMessagingComponent
@ -14,22 +17,24 @@ import com.r3corda.node.services.messaging.ArtemisMessagingServer
import com.r3corda.node.services.messaging.NodeMessagingClient import com.r3corda.node.services.messaging.NodeMessagingClient
import com.r3corda.node.services.network.InMemoryNetworkMapCache import com.r3corda.node.services.network.InMemoryNetworkMapCache
import com.r3corda.node.services.network.NetworkMapService import com.r3corda.node.services.network.NetworkMapService
import com.r3corda.node.utilities.AffinityExecutor import com.r3corda.node.utilities.JsonSupport
import com.typesafe.config.Config import com.typesafe.config.Config
import com.typesafe.config.ConfigRenderOptions import com.typesafe.config.ConfigRenderOptions
import org.slf4j.Logger import org.slf4j.Logger
import org.slf4j.LoggerFactory import org.slf4j.LoggerFactory
import java.io.File import java.io.File
import java.io.InputStreamReader
import java.net.* import java.net.*
import java.nio.file.Path import java.nio.file.Path
import java.nio.file.Paths import java.nio.file.Paths
import java.text.SimpleDateFormat import java.text.SimpleDateFormat
import java.time.Clock
import java.util.* import java.util.*
import java.util.concurrent.* import java.util.concurrent.*
import kotlin.concurrent.thread import kotlin.concurrent.thread
/** /**
* This file defines a small "Driver" DSL for starting up nodes. * This file defines a small "Driver" DSL for starting up nodes that is only intended for development, demos and tests.
* *
* The process the driver is run in behaves as an Artemis client and starts up other processes. Namely it first * The process the driver is run in behaves as an Artemis client and starts up other processes. Namely it first
* bootstraps a network map service to allow the specified nodes to connect to, then starts up the actual nodes. * bootstraps a network map service to allow the specified nodes to connect to, then starts up the actual nodes.
@ -54,38 +59,18 @@ interface DriverDSLExposedInterface {
* @param advertisedServices The set of services to be advertised by the node. Defaults to empty set. * @param advertisedServices The set of services to be advertised by the node. Defaults to empty set.
* @return The [NodeInfo] of the started up node retrieved from the network map service. * @return The [NodeInfo] of the started up node retrieved from the network map service.
*/ */
fun startNode(providedName: String? = null, advertisedServices: Set<ServiceInfo> = setOf()): Future<NodeInfo> fun startNode(providedName: String? = null, advertisedServices: Set<ServiceInfo> = setOf()): Future<NodeInfoAndConfig>
/**
* Starts an [NodeMessagingClient].
*
* @param providedName name of the client, which will be used for creating its directory.
* @param serverAddress the artemis server to connect to, for example a [Node].
*/
fun startClient(providedName: String, serverAddress: HostAndPort): Future<NodeMessagingClient>
/**
* Starts a local [ArtemisMessagingServer] of which there may only be one.
*/
fun startLocalServer(): Future<ArtemisMessagingServer>
fun waitForAllNodesToFinish() fun waitForAllNodesToFinish()
val networkMapCache: NetworkMapCache
} }
fun DriverDSLExposedInterface.startClient(localServer: ArtemisMessagingServer) =
startClient("driver-local-server-client", localServer.myHostPort)
fun DriverDSLExposedInterface.startClient(remoteNodeInfo: NodeInfo, providedName: String? = null) =
startClient(
providedName = providedName ?: "${remoteNodeInfo.legalIdentity.name}-client",
serverAddress = ArtemisMessagingComponent.toHostAndPort(remoteNodeInfo.address)
)
interface DriverDSLInternalInterface : DriverDSLExposedInterface { interface DriverDSLInternalInterface : DriverDSLExposedInterface {
fun start() fun start()
fun shutdown() fun shutdown()
} }
data class NodeInfoAndConfig(val nodeInfo: NodeInfo, val config: Config)
sealed class PortAllocation { sealed class PortAllocation {
abstract fun nextPort(): Int abstract fun nextPort(): Int
fun nextHostAndPort(): HostAndPort = HostAndPort.fromParts("localhost", nextPort()) fun nextHostAndPort(): HostAndPort = HostAndPort.fromParts("localhost", nextPort())
@ -122,6 +107,7 @@ sealed class PortAllocation {
* and may be specified in [DriverDSL.startNode]. * and may be specified in [DriverDSL.startNode].
* @param portAllocation The port allocation strategy to use for the messaging and the web server addresses. Defaults to incremental. * @param portAllocation The port allocation strategy to use for the messaging and the web server addresses. Defaults to incremental.
* @param debugPortAllocation The port allocation strategy to use for jvm debugging. Defaults to incremental. * @param debugPortAllocation The port allocation strategy to use for jvm debugging. Defaults to incremental.
* @param useTestClock If true the test clock will be used in Node.
* @param isDebug Indicates whether the spawned nodes should start in jdwt debug mode. * @param isDebug Indicates whether the spawned nodes should start in jdwt debug mode.
* @param dsl The dsl itself. * @param dsl The dsl itself.
* @return The value returned in the [dsl] closure. * @return The value returned in the [dsl] closure.
@ -130,6 +116,7 @@ fun <A> driver(
baseDirectory: String = "build/${getTimestampAsDirectoryName()}", baseDirectory: String = "build/${getTimestampAsDirectoryName()}",
portAllocation: PortAllocation = PortAllocation.Incremental(10000), portAllocation: PortAllocation = PortAllocation.Incremental(10000),
debugPortAllocation: PortAllocation = PortAllocation.Incremental(5005), debugPortAllocation: PortAllocation = PortAllocation.Incremental(5005),
useTestClock: Boolean = false,
isDebug: Boolean = false, isDebug: Boolean = false,
dsl: DriverDSLExposedInterface.() -> A dsl: DriverDSLExposedInterface.() -> A
) = genericDriver( ) = genericDriver(
@ -137,6 +124,7 @@ fun <A> driver(
portAllocation = portAllocation, portAllocation = portAllocation,
debugPortAllocation = debugPortAllocation, debugPortAllocation = debugPortAllocation,
baseDirectory = baseDirectory, baseDirectory = baseDirectory,
useTestClock = useTestClock,
isDebug = isDebug isDebug = isDebug
), ),
coerce = { it }, coerce = { it },
@ -216,17 +204,15 @@ fun <A> poll(pollName: String, pollIntervalMs: Long = 500, warnCount: Int = 120,
return result return result
} }
class DriverDSL( open class DriverDSL(
val portAllocation: PortAllocation, val portAllocation: PortAllocation,
val debugPortAllocation: PortAllocation, val debugPortAllocation: PortAllocation,
val baseDirectory: String, val baseDirectory: String,
val useTestClock: Boolean,
val isDebug: Boolean val isDebug: Boolean
) : DriverDSLInternalInterface { ) : DriverDSLInternalInterface {
override val networkMapCache = InMemoryNetworkMapCache()
private val networkMapName = "NetworkMapService" private val networkMapName = "NetworkMapService"
private val networkMapAddress = portAllocation.nextHostAndPort() private val networkMapAddress = portAllocation.nextHostAndPort()
private var networkMapNodeInfo: NodeInfo? = null
private val identity = generateKeyPair()
class State { class State {
val registeredProcesses = LinkedList<Process>() val registeredProcesses = LinkedList<Process>()
@ -284,7 +270,26 @@ class DriverDSL(
addressMustNotBeBound(networkMapAddress) addressMustNotBeBound(networkMapAddress)
} }
override fun startNode(providedName: String?, advertisedServices: Set<ServiceInfo>): Future<NodeInfo> { private fun queryNodeInfo(webAddress: HostAndPort): NodeInfo? {
val url = URL("http://${webAddress.toString()}/api/info")
try {
val conn = url.openConnection() as HttpURLConnection
conn.requestMethod = "GET"
if (conn.responseCode != 200) {
return null
}
// For now the NodeInfo is tunneled in its Kryo format over the Node's Web interface.
val om = ObjectMapper()
val module = SimpleModule("NodeInfo")
module.addDeserializer(NodeInfo::class.java, JsonSupport.NodeInfoDeserializer)
om.registerModule(module)
return om.readValue(conn.inputStream, NodeInfo::class.java)
} catch(e: Exception) {
return null
}
}
override fun startNode(providedName: String?, advertisedServices: Set<ServiceInfo>): Future<NodeInfoAndConfig> {
val messagingAddress = portAllocation.nextHostAndPort() val messagingAddress = portAllocation.nextHostAndPort()
val apiAddress = portAllocation.nextHostAndPort() val apiAddress = portAllocation.nextHostAndPort()
val debugPort = if (isDebug) debugPortAllocation.nextPort() else null val debugPort = if (isDebug) debugPortAllocation.nextPort() else null
@ -301,94 +306,19 @@ class DriverDSL(
"artemisAddress" to messagingAddress.toString(), "artemisAddress" to messagingAddress.toString(),
"webAddress" to apiAddress.toString(), "webAddress" to apiAddress.toString(),
"extraAdvertisedServiceIds" to advertisedServices.joinToString(","), "extraAdvertisedServiceIds" to advertisedServices.joinToString(","),
"networkMapAddress" to networkMapAddress.toString() "networkMapAddress" to networkMapAddress.toString(),
"useTestClock" to useTestClock
) )
) )
return Executors.newSingleThreadExecutor().submit(Callable<NodeInfo> { return Executors.newSingleThreadExecutor().submit(Callable<NodeInfoAndConfig> {
registerProcess(DriverDSL.startNode(config, quasarJarPath, debugPort)) registerProcess(DriverDSL.startNode(config, quasarJarPath, debugPort))
poll("network map cache for $name") { NodeInfoAndConfig(queryNodeInfo(apiAddress)!!, config)
networkMapCache.partyNodes.forEach {
if (it.legalIdentity.name == name) {
return@poll it
}
}
null
}
}) })
} }
override fun startClient(
providedName: String,
serverAddress: HostAndPort
): Future<NodeMessagingClient> {
val nodeConfiguration = FullNodeConfiguration(
ConfigHelper.loadConfig(
baseDirectoryPath = Paths.get(baseDirectory, providedName),
allowMissingConfig = true,
configOverrides = mapOf(
"myLegalName" to providedName
)
)
)
val client = NodeMessagingClient(nodeConfiguration,
serverHostPort = serverAddress,
myIdentity = identity.public,
executor = AffinityExecutor.ServiceAffinityExecutor(providedName, 1),
persistentInbox = false // Do not create a permanent queue for our transient UI identity
)
return Executors.newSingleThreadExecutor().submit(Callable<NodeMessagingClient> {
client.configureWithDevSSLCertificate()
client.start(null)
thread { client.run() }
state.locked {
clients.add(client)
}
client
})
}
override fun startLocalServer(): Future<ArtemisMessagingServer> {
val name = "driver-local-server"
val config = FullNodeConfiguration(
ConfigHelper.loadConfig(
baseDirectoryPath = Paths.get(baseDirectory, name),
allowMissingConfig = true,
configOverrides = mapOf(
"myLegalName" to name
)
)
)
val server = ArtemisMessagingServer(config,
portAllocation.nextHostAndPort(),
networkMapCache
)
return Executors.newSingleThreadExecutor().submit(Callable<ArtemisMessagingServer> {
server.configureWithDevSSLCertificate()
server.start()
state.locked {
localServer = server
}
server
})
}
override fun start() { override fun start() {
startNetworkMapService() startNetworkMapService()
val networkMapClient = startClient("driver-$networkMapName-client", networkMapAddress).get()
val networkMapAddr = NodeMessagingClient.makeNetworkMapAddress(networkMapAddress)
networkMapCache.addMapService(networkMapClient, networkMapAddr, true)
networkMapNodeInfo = poll("network map cache for $networkMapName") {
networkMapCache.partyNodes.forEach {
if (it.legalIdentity.name == networkMapName) {
return@poll it
}
}
null
}
} }
private fun startNetworkMapService() { private fun startNetworkMapService() {
@ -396,7 +326,6 @@ class DriverDSL(
val debugPort = if (isDebug) debugPortAllocation.nextPort() else null val debugPort = if (isDebug) debugPortAllocation.nextPort() else null
val nodeDirectory = "$baseDirectory/$networkMapName" val nodeDirectory = "$baseDirectory/$networkMapName"
val config = ConfigHelper.loadConfig( val config = ConfigHelper.loadConfig(
baseDirectoryPath = Paths.get(nodeDirectory), baseDirectoryPath = Paths.get(nodeDirectory),
allowMissingConfig = true, allowMissingConfig = true,
@ -405,7 +334,8 @@ class DriverDSL(
"basedir" to Paths.get(nodeDirectory).normalize().toString(), "basedir" to Paths.get(nodeDirectory).normalize().toString(),
"artemisAddress" to networkMapAddress.toString(), "artemisAddress" to networkMapAddress.toString(),
"webAddress" to apiAddress.toString(), "webAddress" to apiAddress.toString(),
"extraAdvertisedServiceIds" to "" "extraAdvertisedServiceIds" to "",
"useTestClock" to useTestClock
) )
) )

View File

@ -24,6 +24,8 @@ class APIServerImpl(val node: AbstractNode) : APIServer {
} }
} }
override fun info() = node.services.myInfo
override fun queryStates(query: StatesQuery): List<StateRef> { override fun queryStates(query: StatesQuery): List<StateRef> {
// We're going to hard code two options here for now and assume that all LinearStates are deals // We're going to hard code two options here for now and assume that all LinearStates are deals
// Would like to maybe move to a model where we take something like a JEXL string, although don't want to develop // Would like to maybe move to a model where we take something like a JEXL string, although don't want to develop

View File

@ -209,9 +209,8 @@ abstract class AbstractNode(open val configuration: NodeConfiguration, val netwo
// the identity key. But the infrastructure to make that easy isn't here yet. // the identity key. But the infrastructure to make that easy isn't here yet.
keyManagement = makeKeyManagementService() keyManagement = makeKeyManagementService()
api = APIServerImpl(this@AbstractNode) api = APIServerImpl(this@AbstractNode)
scheduler = NodeSchedulerService(database, services)
protocolLogicFactory = initialiseProtocolLogicFactory() protocolLogicFactory = initialiseProtocolLogicFactory()
scheduler = NodeSchedulerService(database, services, protocolLogicFactory)
val tokenizableServices = mutableListOf(storage, net, vault, keyManagement, identity, platformClock, scheduler) val tokenizableServices = mutableListOf(storage, net, vault, keyManagement, identity, platformClock, scheduler)
@ -435,15 +434,11 @@ abstract class AbstractNode(open val configuration: NodeConfiguration, val netwo
protected abstract fun makeMessagingService(): MessagingServiceInternal protected abstract fun makeMessagingService(): MessagingServiceInternal
protected abstract fun startMessagingService(cordaRPCOps: CordaRPCOps?) protected abstract fun startMessagingService(cordaRPCOps: CordaRPCOps)
protected open fun initialiseCheckpointService(dir: Path): CheckpointStorage {
return DBCheckpointStorage()
}
protected open fun initialiseStorageService(dir: Path): Pair<TxWritableStorageService, CheckpointStorage> { protected open fun initialiseStorageService(dir: Path): Pair<TxWritableStorageService, CheckpointStorage> {
val attachments = makeAttachmentStorage(dir) val attachments = makeAttachmentStorage(dir)
val checkpointStorage = initialiseCheckpointService(dir) val checkpointStorage = DBCheckpointStorage()
val transactionStorage = DBTransactionStorage() val transactionStorage = DBTransactionStorage()
_servicesThatAcceptUploads += attachments _servicesThatAcceptUploads += attachments
// Populate the partyKeys set. // Populate the partyKeys set.

View File

@ -4,6 +4,7 @@ import com.codahale.metrics.JmxReporter
import com.r3corda.core.messaging.SingleMessageRecipient import com.r3corda.core.messaging.SingleMessageRecipient
import com.r3corda.core.node.ServiceHub import com.r3corda.core.node.ServiceHub
import com.r3corda.core.node.services.ServiceInfo import com.r3corda.core.node.services.ServiceInfo
import com.r3corda.core.then
import com.r3corda.core.utilities.loggerFor import com.r3corda.core.utilities.loggerFor
import com.r3corda.node.serialization.NodeClock import com.r3corda.node.serialization.NodeClock
import com.r3corda.node.services.api.MessagingServiceInternal import com.r3corda.node.services.api.MessagingServiceInternal
@ -119,11 +120,10 @@ class Node(override val configuration: FullNodeConfiguration, networkMapAddress:
} }
val legalIdentity = obtainLegalIdentity() val legalIdentity = obtainLegalIdentity()
val myIdentityOrNullIfNetworkMapService = if (networkMapService != null) legalIdentity.owningKey else null val myIdentityOrNullIfNetworkMapService = if (networkMapService != null) legalIdentity.owningKey else null
return NodeMessagingClient(configuration, serverAddr, myIdentityOrNullIfNetworkMapService, serverThread, return NodeMessagingClient(configuration, serverAddr, myIdentityOrNullIfNetworkMapService, serverThread, database)
persistenceTx = { body: () -> Unit -> databaseTransaction(database) { body() } })
} }
override fun startMessagingService(cordaRPCOps: CordaRPCOps?) { override fun startMessagingService(cordaRPCOps: CordaRPCOps) {
// Start up the embedded MQ server // Start up the embedded MQ server
messageBroker?.apply { messageBroker?.apply {
runOnStop += Runnable { messageBroker?.stop() } runOnStop += Runnable { messageBroker?.stop() }
@ -268,6 +268,8 @@ class Node(override val configuration: FullNodeConfiguration, networkMapAddress:
override fun start(): Node { override fun start(): Node {
alreadyRunningNodeCheck() alreadyRunningNodeCheck()
super.start() super.start()
// Only start the service API requests once the network map registration is complete
networkMapRegistrationFuture.then {
webServer = initWebServer() webServer = initWebServer()
// Begin exporting our own metrics via JMX. // Begin exporting our own metrics via JMX.
JmxReporter. JmxReporter.
@ -284,7 +286,7 @@ class Node(override val configuration: FullNodeConfiguration, networkMapAddress:
}. }.
build(). build().
start() start()
}
shutdownThread = thread(start = false) { shutdownThread = thread(start = false) {
stop() stop()
} }

View File

@ -1,25 +0,0 @@
package com.r3corda.node.services.clientapi
import com.r3corda.core.node.CordaPluginRegistry
import com.r3corda.core.serialization.SingletonSerializeAsToken
import com.r3corda.node.services.api.ServiceHubInternal
import com.r3corda.protocols.TwoPartyDealProtocol.Fixer
import com.r3corda.protocols.TwoPartyDealProtocol.Floater
/**
* This is a temporary handler required for establishing random sessionIDs for the [Fixer] and [Floater] as part of
* running scheduled fixings for the [InterestRateSwap] contract.
*
* TODO: This will be replaced with the symmetric session work
*/
object FixingSessionInitiation {
class Plugin: CordaPluginRegistry() {
override val servicePlugins: List<Class<*>> = listOf(Service::class.java)
}
class Service(services: ServiceHubInternal) : SingletonSerializeAsToken() {
init {
services.registerProtocolInitiator(Floater::class) { Fixer(it) }
}
}
}

View File

@ -2,6 +2,7 @@ package com.r3corda.node.services.config
import com.google.common.net.HostAndPort import com.google.common.net.HostAndPort
import com.r3corda.core.crypto.X509Utilities import com.r3corda.core.crypto.X509Utilities
import com.r3corda.core.exists
import com.r3corda.core.utilities.loggerFor import com.r3corda.core.utilities.loggerFor
import com.typesafe.config.Config import com.typesafe.config.Config
import com.typesafe.config.ConfigFactory import com.typesafe.config.ConfigFactory
@ -89,14 +90,24 @@ fun Config.getProperties(path: String): Properties {
*/ */
fun NodeSSLConfiguration.configureWithDevSSLCertificate() { fun NodeSSLConfiguration.configureWithDevSSLCertificate() {
Files.createDirectories(certificatesPath) Files.createDirectories(certificatesPath)
if (!Files.exists(trustStorePath)) { if (!trustStorePath.exists()) {
Files.copy(javaClass.classLoader.getResourceAsStream("com/r3corda/node/internal/certificates/cordatruststore.jks"), Files.copy(javaClass.classLoader.getResourceAsStream("com/r3corda/node/internal/certificates/cordatruststore.jks"),
trustStorePath) trustStorePath)
} }
if (!Files.exists(keyStorePath)) { if (!keyStorePath.exists()) {
val caKeyStore = X509Utilities.loadKeyStore( val caKeyStore = X509Utilities.loadKeyStore(
javaClass.classLoader.getResourceAsStream("com/r3corda/node/internal/certificates/cordadevcakeys.jks"), javaClass.classLoader.getResourceAsStream("com/r3corda/node/internal/certificates/cordadevcakeys.jks"),
"cordacadevpass") "cordacadevpass")
X509Utilities.createKeystoreForSSL(keyStorePath, keyStorePassword, keyStorePassword, caKeyStore, "cordacadevkeypass") X509Utilities.createKeystoreForSSL(keyStorePath, keyStorePassword, keyStorePassword, caKeyStore, "cordacadevkeypass")
} }
} }
// TODO Move this to CoreTestUtils.kt once we can pry this from the explorer
fun configureTestSSL(): NodeSSLConfiguration = object : NodeSSLConfiguration {
override val certificatesPath = Files.createTempDirectory("certs")
override val keyStorePassword: String get() = "cordacadevpass"
override val trustStorePassword: String get() = "trustpass"
init {
configureWithDevSSLCertificate()
}
}

View File

@ -5,10 +5,13 @@ import com.r3corda.core.div
import com.r3corda.core.messaging.SingleMessageRecipient import com.r3corda.core.messaging.SingleMessageRecipient
import com.r3corda.core.node.services.ServiceInfo import com.r3corda.core.node.services.ServiceInfo
import com.r3corda.node.internal.Node import com.r3corda.node.internal.Node
import com.r3corda.node.serialization.NodeClock
import com.r3corda.node.services.messaging.NodeMessagingClient import com.r3corda.node.services.messaging.NodeMessagingClient
import com.r3corda.node.services.network.NetworkMapService import com.r3corda.node.services.network.NetworkMapService
import com.r3corda.node.utilities.TestClock
import com.typesafe.config.Config import com.typesafe.config.Config
import java.nio.file.Path import java.nio.file.Path
import java.time.Clock
import java.util.* import java.util.*
interface NodeSSLConfiguration { interface NodeSSLConfiguration {
@ -46,8 +49,12 @@ class FullNodeConfiguration(config: Config) : NodeConfiguration {
val webAddress: HostAndPort by config val webAddress: HostAndPort by config
val messagingServerAddress: HostAndPort? by config.getOrElse { null } val messagingServerAddress: HostAndPort? by config.getOrElse { null }
val extraAdvertisedServiceIds: String by config val extraAdvertisedServiceIds: String by config
val useTestClock: Boolean by config.getOrElse { false }
fun createNode(): Node { fun createNode(): Node {
// This is a sanity feature do not remove.
require(!useTestClock || devMode) { "Cannot use test clock outside of dev mode" }
val advertisedServices = mutableSetOf<ServiceInfo>() val advertisedServices = mutableSetOf<ServiceInfo>()
if (!extraAdvertisedServiceIds.isNullOrEmpty()) { if (!extraAdvertisedServiceIds.isNullOrEmpty()) {
for (serviceId in extraAdvertisedServiceIds.split(",")) { for (serviceId in extraAdvertisedServiceIds.split(",")) {
@ -56,7 +63,7 @@ class FullNodeConfiguration(config: Config) : NodeConfiguration {
} }
if (networkMapAddress == null) advertisedServices.add(ServiceInfo(NetworkMapService.type)) if (networkMapAddress == null) advertisedServices.add(ServiceInfo(NetworkMapService.type))
val networkMapMessageAddress: SingleMessageRecipient? = if (networkMapAddress == null) null else NodeMessagingClient.makeNetworkMapAddress(networkMapAddress!!) val networkMapMessageAddress: SingleMessageRecipient? = if (networkMapAddress == null) null else NodeMessagingClient.makeNetworkMapAddress(networkMapAddress!!)
return Node(this, networkMapMessageAddress, advertisedServices) return Node(this, networkMapMessageAddress, advertisedServices, if(useTestClock == true) TestClock() else NodeClock())
} }
} }

View File

@ -45,7 +45,7 @@ import javax.annotation.concurrent.ThreadSafe
@ThreadSafe @ThreadSafe
class NodeSchedulerService(private val database: Database, class NodeSchedulerService(private val database: Database,
private val services: ServiceHubInternal, private val services: ServiceHubInternal,
private val protocolLogicRefFactory: ProtocolLogicRefFactory = ProtocolLogicRefFactory(), private val protocolLogicRefFactory: ProtocolLogicRefFactory,
private val schedulerTimerExecutor: Executor = Executors.newSingleThreadExecutor()) private val schedulerTimerExecutor: Executor = Executors.newSingleThreadExecutor())
: SchedulerService, SingletonSerializeAsToken() { : SchedulerService, SingletonSerializeAsToken() {

View File

@ -15,16 +15,15 @@ import org.apache.activemq.artemis.api.core.TransportConfiguration
import org.apache.activemq.artemis.core.remoting.impl.netty.NettyAcceptorFactory import org.apache.activemq.artemis.core.remoting.impl.netty.NettyAcceptorFactory
import org.apache.activemq.artemis.core.remoting.impl.netty.NettyConnectorFactory import org.apache.activemq.artemis.core.remoting.impl.netty.NettyConnectorFactory
import org.apache.activemq.artemis.core.remoting.impl.netty.TransportConstants import org.apache.activemq.artemis.core.remoting.impl.netty.TransportConstants
import java.nio.file.FileSystems
import java.nio.file.Path
import java.security.KeyStore import java.security.KeyStore
import java.security.PublicKey import java.security.PublicKey
/** /**
* The base class for Artemis services that defines shared data structures and transport configuration * The base class for Artemis services that defines shared data structures and transport configuration
*
* @param certificatePath A place where Artemis can stash its message journal and other files.
* @param config The config object is used to pass in the passwords for the certificate KeyStore and TrustStore
*/ */
abstract class ArtemisMessagingComponent(val config: NodeSSLConfiguration) : SingletonSerializeAsToken() { abstract class ArtemisMessagingComponent() : SingletonSerializeAsToken() {
companion object { companion object {
init { init {
@ -36,7 +35,7 @@ abstract class ArtemisMessagingComponent(val config: NodeSSLConfiguration) : Sin
const val RPC_REQUESTS_QUEUE = "rpc.requests" const val RPC_REQUESTS_QUEUE = "rpc.requests"
@JvmStatic @JvmStatic
protected val NETWORK_MAP_ADDRESS = SimpleString(PEERS_PREFIX +"networkmap") protected val NETWORK_MAP_ADDRESS = SimpleString("${PEERS_PREFIX}networkmap")
/** /**
* Assuming the passed in target address is actually an ArtemisAddress will extract the host and port of the node. This should * Assuming the passed in target address is actually an ArtemisAddress will extract the host and port of the node. This should
@ -70,7 +69,7 @@ abstract class ArtemisMessagingComponent(val config: NodeSSLConfiguration) : Sin
} }
protected data class NetworkMapAddress(override val hostAndPort: HostAndPort) : SingleMessageRecipient, ArtemisAddress { protected data class NetworkMapAddress(override val hostAndPort: HostAndPort) : SingleMessageRecipient, ArtemisAddress {
override val queueName: SimpleString = NETWORK_MAP_ADDRESS override val queueName: SimpleString get() = NETWORK_MAP_ADDRESS
} }
/** /**
@ -80,11 +79,11 @@ abstract class ArtemisMessagingComponent(val config: NodeSSLConfiguration) : Sin
*/ */
data class NodeAddress(val identity: PublicKey, override val hostAndPort: HostAndPort) : SingleMessageRecipient, ArtemisAddress { data class NodeAddress(val identity: PublicKey, override val hostAndPort: HostAndPort) : SingleMessageRecipient, ArtemisAddress {
override val queueName: SimpleString by lazy { SimpleString(PEERS_PREFIX+identity.toBase58String()) } override val queueName: SimpleString by lazy { SimpleString(PEERS_PREFIX+identity.toBase58String()) }
override fun toString(): String = "${javaClass.simpleName}(identity = $queueName, $hostAndPort)"
}
override fun toString(): String { /** The config object is used to pass in the passwords for the certificate KeyStore and TrustStore */
return "NodeAddress(identity = $queueName, $hostAndPort" abstract val config: NodeSSLConfiguration
}
}
protected fun parseKeyFromQueueName(name: String): PublicKey { protected fun parseKeyFromQueueName(name: String): PublicKey {
require(name.startsWith(PEERS_PREFIX)) require(name.startsWith(PEERS_PREFIX))
@ -119,8 +118,10 @@ abstract class ArtemisMessagingComponent(val config: NodeSSLConfiguration) : Sin
} }
} }
protected fun tcpTransport(direction: ConnectionDirection, host: String, port: Int) = protected fun tcpTransport(direction: ConnectionDirection, host: String, port: Int): TransportConfiguration {
TransportConfiguration( config.keyStorePath.expectedOnDefaultFileSystem()
config.trustStorePath.expectedOnDefaultFileSystem()
return TransportConfiguration(
when (direction) { when (direction) {
ConnectionDirection.INBOUND -> NettyAcceptorFactory::class.java.name ConnectionDirection.INBOUND -> NettyAcceptorFactory::class.java.name
ConnectionDirection.OUTBOUND -> NettyConnectorFactory::class.java.name ConnectionDirection.OUTBOUND -> NettyConnectorFactory::class.java.name
@ -150,8 +151,13 @@ abstract class ArtemisMessagingComponent(val config: NodeSSLConfiguration) : Sin
TransportConstants.NEED_CLIENT_AUTH_PROP_NAME to true TransportConstants.NEED_CLIENT_AUTH_PROP_NAME to true
) )
) )
}
fun configureWithDevSSLCertificate() { fun configureWithDevSSLCertificate() {
config.configureWithDevSSLCertificate() config.configureWithDevSSLCertificate()
} }
protected fun Path.expectedOnDefaultFileSystem() {
require(fileSystem == FileSystems.getDefault()) { "Artemis only uses the default file system" }
}
} }

View File

@ -3,11 +3,15 @@ package com.r3corda.node.services.messaging
import com.google.common.net.HostAndPort import com.google.common.net.HostAndPort
import com.r3corda.core.ThreadBox import com.r3corda.core.ThreadBox
import com.r3corda.core.crypto.AddressFormatException import com.r3corda.core.crypto.AddressFormatException
import com.r3corda.core.crypto.newSecureRandom import com.r3corda.core.div
import com.r3corda.core.exists
import com.r3corda.core.messaging.SingleMessageRecipient import com.r3corda.core.messaging.SingleMessageRecipient
import com.r3corda.core.node.services.NetworkMapCache import com.r3corda.core.node.services.NetworkMapCache
import com.r3corda.core.use
import com.r3corda.core.utilities.loggerFor import com.r3corda.core.utilities.loggerFor
import com.r3corda.node.services.config.NodeConfiguration import com.r3corda.node.services.config.NodeConfiguration
import com.r3corda.node.services.messaging.ArtemisMessagingServer.NodeLoginModule.Companion.NODE_ROLE_NAME
import com.r3corda.node.services.messaging.ArtemisMessagingServer.NodeLoginModule.Companion.RPC_ROLE_NAME
import org.apache.activemq.artemis.api.core.SimpleString import org.apache.activemq.artemis.api.core.SimpleString
import org.apache.activemq.artemis.core.config.BridgeConfiguration import org.apache.activemq.artemis.core.config.BridgeConfiguration
import org.apache.activemq.artemis.core.config.Configuration import org.apache.activemq.artemis.core.config.Configuration
@ -17,11 +21,25 @@ import org.apache.activemq.artemis.core.security.Role
import org.apache.activemq.artemis.core.server.ActiveMQServer import org.apache.activemq.artemis.core.server.ActiveMQServer
import org.apache.activemq.artemis.core.server.impl.ActiveMQServerImpl import org.apache.activemq.artemis.core.server.impl.ActiveMQServerImpl
import org.apache.activemq.artemis.spi.core.security.ActiveMQJAASSecurityManager import org.apache.activemq.artemis.spi.core.security.ActiveMQJAASSecurityManager
import org.apache.activemq.artemis.spi.core.security.jaas.InVMLoginModule import org.apache.activemq.artemis.spi.core.security.jaas.RolePrincipal
import org.apache.activemq.artemis.spi.core.security.jaas.UserPrincipal
import rx.Subscription import rx.Subscription
import java.math.BigInteger import java.io.IOException
import java.nio.file.Files
import java.nio.file.Path import java.nio.file.Path
import java.security.Principal
import java.util.*
import javax.annotation.concurrent.ThreadSafe import javax.annotation.concurrent.ThreadSafe
import javax.security.auth.Subject
import javax.security.auth.callback.CallbackHandler
import javax.security.auth.callback.NameCallback
import javax.security.auth.callback.PasswordCallback
import javax.security.auth.callback.UnsupportedCallbackException
import javax.security.auth.login.AppConfigurationEntry
import javax.security.auth.login.AppConfigurationEntry.LoginModuleControlFlag.REQUIRED
import javax.security.auth.login.FailedLoginException
import javax.security.auth.login.LoginException
import javax.security.auth.spi.LoginModule
// TODO: Verify that nobody can connect to us and fiddle with our config over the socket due to the secman. // TODO: Verify that nobody can connect to us and fiddle with our config over the socket due to the secman.
// TODO: Implement a discovery engine that can trigger builds of new connections when another node registers? (later) // TODO: Implement a discovery engine that can trigger builds of new connections when another node registers? (later)
@ -37,9 +55,9 @@ import javax.annotation.concurrent.ThreadSafe
* a fully connected network, trusted network or on localhost. * a fully connected network, trusted network or on localhost.
*/ */
@ThreadSafe @ThreadSafe
class ArtemisMessagingServer(config: NodeConfiguration, class ArtemisMessagingServer(override val config: NodeConfiguration,
val myHostPort: HostAndPort, val myHostPort: HostAndPort,
val networkMapCache: NetworkMapCache) : ArtemisMessagingComponent(config) { val networkMapCache: NetworkMapCache) : ArtemisMessagingComponent() {
companion object { companion object {
val log = loggerFor<ArtemisMessagingServer>() val log = loggerFor<ArtemisMessagingServer>()
} }
@ -52,6 +70,10 @@ class ArtemisMessagingServer(config: NodeConfiguration,
private lateinit var activeMQServer: ActiveMQServer private lateinit var activeMQServer: ActiveMQServer
private var networkChangeHandle: Subscription? = null private var networkChangeHandle: Subscription? = null
init {
config.basedir.expectedOnDefaultFileSystem()
}
fun start() = mutex.locked { fun start() = mutex.locked {
if (!running) { if (!running) {
configureAndStartServer() configureAndStartServer()
@ -116,12 +138,7 @@ class ArtemisMessagingServer(config: NodeConfiguration,
} }
private fun configureAndStartServer() { private fun configureAndStartServer() {
val config = createArtemisConfig(config.certificatesPath, myHostPort).apply { val config = createArtemisConfig()
securityRoles = mapOf(
"#" to setOf(Role("internal", true, true, true, true, true, true, true))
)
}
val securityManager = createArtemisSecurityManager() val securityManager = createArtemisSecurityManager()
activeMQServer = ActiveMQServerImpl(config, securityManager).apply { activeMQServer = ActiveMQServerImpl(config, securityManager).apply {
@ -157,28 +174,61 @@ class ArtemisMessagingServer(config: NodeConfiguration,
activeMQServer.start() activeMQServer.start()
} }
private fun createArtemisConfig(directory: Path, hp: HostAndPort): Configuration { private fun createArtemisConfig(): Configuration = ConfigurationImpl().apply {
val config = ConfigurationImpl() val artemisDir = config.basedir / "artemis"
setConfigDirectories(config, directory) bindingsDirectory = (artemisDir / "bindings").toString()
config.acceptorConfigurations = setOf( journalDirectory = (artemisDir / "journal").toString()
tcpTransport(ConnectionDirection.INBOUND, "0.0.0.0", hp.port) largeMessagesDirectory = (artemisDir / "largemessages").toString()
acceptorConfigurations = setOf(
tcpTransport(ConnectionDirection.INBOUND, "0.0.0.0", myHostPort.port)
) )
// Enable built in message deduplication. Note we still have to do our own as the delayed commits // Enable built in message deduplication. Note we still have to do our own as the delayed commits
// and our own definition of commit mean that the built in deduplication cannot remove all duplicates. // and our own definition of commit mean that the built in deduplication cannot remove all duplicates.
config.idCacheSize = 2000 // Artemis Default duplicate cache size i.e. a guess idCacheSize = 2000 // Artemis Default duplicate cache size i.e. a guess
config.isPersistIDCache = true isPersistIDCache = true
return config isPopulateValidatedUser = true
setupUserRoles()
}
// This gives nodes full access and RPC clients only enough to do RPC
private fun ConfigurationImpl.setupUserRoles() {
// TODO COR-307
val nodeRole = Role(NODE_ROLE_NAME, true, true, true, true, true, true, true, true)
val clientRpcRole = restrictedRole(RPC_ROLE_NAME, consume = true, createNonDurableQueue = true, deleteNonDurableQueue = true)
securityRoles = mapOf(
"#" to setOf(nodeRole),
"clients.*.rpc.responses.*" to setOf(nodeRole, clientRpcRole),
"clients.*.rpc.observations.*" to setOf(nodeRole, clientRpcRole),
RPC_REQUESTS_QUEUE to setOf(nodeRole, restrictedRole(RPC_ROLE_NAME, send = true))
)
}
private fun restrictedRole(name: String, send: Boolean = false, consume: Boolean = false, createDurableQueue: Boolean = false,
deleteDurableQueue: Boolean = false, createNonDurableQueue: Boolean = false,
deleteNonDurableQueue: Boolean = false, manage: Boolean = false, browse: Boolean = false): Role {
return Role(name, send, consume, createDurableQueue, deleteDurableQueue, createNonDurableQueue,
deleteNonDurableQueue, manage, browse)
} }
private fun createArtemisSecurityManager(): ActiveMQJAASSecurityManager { private fun createArtemisSecurityManager(): ActiveMQJAASSecurityManager {
// TODO: set up proper security configuration https://r3-cev.atlassian.net/browse/COR-307 val rpcUsersFile = config.basedir / "rpc-users.properties"
val securityConfig = SecurityConfiguration().apply { if (!rpcUsersFile.exists()) {
addUser("internal", BigInteger(128, newSecureRandom()).toString(16)) val users = Properties()
addRole("internal", "internal") users["user1"] = "test"
defaultUser = "internal" Files.newOutputStream(rpcUsersFile).use {
users.store(it, null)
}
} }
return ActiveMQJAASSecurityManager(InVMLoginModule::class.java.name, securityConfig) val securityConfig = object : SecurityConfiguration() {
// Override to make it work with our login module
override fun getAppConfigurationEntry(name: String): Array<AppConfigurationEntry> {
val options = mapOf(NodeLoginModule.FILE_KEY to rpcUsersFile)
return arrayOf(AppConfigurationEntry(name, REQUIRED, options))
}
}
return ActiveMQJAASSecurityManager(NodeLoginModule::class.java.name, securityConfig)
} }
private fun connectorExists(hostAndPort: HostAndPort) = hostAndPort.toString() in activeMQServer.configuration.connectorConfigurations private fun connectorExists(hostAndPort: HostAndPort) = hostAndPort.toString() in activeMQServer.configuration.connectorConfigurations
@ -194,12 +244,11 @@ class ArtemisMessagingServer(config: NodeConfiguration,
private fun bridgeExists(name: SimpleString) = activeMQServer.clusterManager.bridges.containsKey(name.toString()) private fun bridgeExists(name: SimpleString) = activeMQServer.clusterManager.bridges.containsKey(name.toString())
private fun deployBridge(hostAndPort: HostAndPort, name: SimpleString) { private fun deployBridge(hostAndPort: HostAndPort, name: String) {
activeMQServer.deployBridge(BridgeConfiguration().apply { activeMQServer.deployBridge(BridgeConfiguration().apply {
val nameStr = name.toString() setName(name)
setName(nameStr) queueName = name
queueName = nameStr forwardingAddress = name
forwardingAddress = nameStr
staticConnectors = listOf(hostAndPort.toString()) staticConnectors = listOf(hostAndPort.toString())
confirmationWindowSize = 100000 // a guess confirmationWindowSize = 100000 // a guess
isUseDuplicateDetection = true // Enable the bridges automatic deduplication logic isUseDuplicateDetection = true // Enable the bridges automatic deduplication logic
@ -218,7 +267,7 @@ class ArtemisMessagingServer(config: NodeConfiguration,
if (!connectorExists(hostAndPort)) if (!connectorExists(hostAndPort))
addConnector(hostAndPort) addConnector(hostAndPort)
if (!bridgeExists(name)) if (!bridgeExists(name))
deployBridge(hostAndPort, name) deployBridge(hostAndPort, name.toString())
} }
private fun maybeDestroyBridge(name: SimpleString) { private fun maybeDestroyBridge(name: SimpleString) {
@ -227,11 +276,81 @@ class ArtemisMessagingServer(config: NodeConfiguration,
} }
} }
private fun setConfigDirectories(config: Configuration, dir: Path) {
config.apply { class NodeLoginModule : LoginModule {
bindingsDirectory = dir.resolve("bindings").toString()
journalDirectory = dir.resolve("journal").toString() companion object {
largeMessagesDirectory = dir.resolve("largemessages").toString() const val FILE_KEY = "rpc-users-file"
const val NODE_ROLE_NAME = "NodeRole"
const val RPC_ROLE_NAME = "RpcRole"
}
private val users = Properties()
private var loginSucceeded: Boolean = false
private lateinit var subject: Subject
private lateinit var callbackHandler: CallbackHandler
private lateinit var principals: List<Principal>
override fun initialize(subject: Subject, callbackHandler: CallbackHandler, sharedState: Map<String, *>, options: Map<String, *>) {
this.subject = subject
this.callbackHandler = callbackHandler
val rpcUsersFile = options[FILE_KEY] as Path
if (rpcUsersFile.exists()) {
rpcUsersFile.use {
users.load(it)
} }
} }
} }
override fun login(): Boolean {
val nameCallback = NameCallback("Username: ")
val passwordCallback = PasswordCallback("Password: ", false)
try {
callbackHandler.handle(arrayOf(nameCallback, passwordCallback))
} catch (e: IOException) {
throw LoginException(e.message)
} catch (e: UnsupportedCallbackException) {
throw LoginException("${e.message} not available to obtain information from user")
}
val username = nameCallback.name ?: throw FailedLoginException("User name is null")
val receivedPassword = passwordCallback.password ?: throw FailedLoginException("Password is null")
val password = if (username == "Node") "Node" else users[username] ?: throw FailedLoginException("User does not exist")
if (password != String(receivedPassword)) {
throw FailedLoginException("Password does not match")
}
principals = listOf(
UserPrincipal(username),
RolePrincipal(if (username == "Node") NODE_ROLE_NAME else RPC_ROLE_NAME))
loginSucceeded = true
return loginSucceeded
}
override fun commit(): Boolean {
val result = loginSucceeded
if (result) {
subject.principals.addAll(principals)
}
clear()
return result
}
override fun abort(): Boolean {
clear()
return true
}
override fun logout(): Boolean {
subject.principals.removeAll(principals)
return true
}
private fun clear() {
loginSucceeded = false
}
}
}

View File

@ -11,11 +11,13 @@ import com.r3corda.node.services.api.MessagingServiceInternal
import com.r3corda.node.services.config.NodeConfiguration import com.r3corda.node.services.config.NodeConfiguration
import com.r3corda.node.utilities.* import com.r3corda.node.utilities.*
import org.apache.activemq.artemis.api.core.ActiveMQObjectClosedException import org.apache.activemq.artemis.api.core.ActiveMQObjectClosedException
import org.apache.activemq.artemis.api.core.Message.HDR_DUPLICATE_DETECTION_ID
import org.apache.activemq.artemis.api.core.Message.HDR_VALIDATED_USER
import org.apache.activemq.artemis.api.core.SimpleString import org.apache.activemq.artemis.api.core.SimpleString
import org.apache.activemq.artemis.api.core.client.* import org.apache.activemq.artemis.api.core.client.*
import org.jetbrains.exposed.sql.Database
import org.jetbrains.exposed.sql.ResultRow import org.jetbrains.exposed.sql.ResultRow
import org.jetbrains.exposed.sql.statements.InsertStatement import org.jetbrains.exposed.sql.statements.InsertStatement
import java.nio.file.FileSystems
import java.security.PublicKey import java.security.PublicKey
import java.time.Instant import java.time.Instant
import java.util.* import java.util.*
@ -49,12 +51,11 @@ import javax.annotation.concurrent.ThreadSafe
* in this class. * in this class.
*/ */
@ThreadSafe @ThreadSafe
class NodeMessagingClient(config: NodeConfiguration, class NodeMessagingClient(override val config: NodeConfiguration,
val serverHostPort: HostAndPort, val serverHostPort: HostAndPort,
val myIdentity: PublicKey?, val myIdentity: PublicKey?,
val executor: AffinityExecutor, val executor: AffinityExecutor,
val persistentInbox: Boolean = true, val database: Database) : ArtemisMessagingComponent(), MessagingServiceInternal {
val persistenceTx: (() -> Unit) -> Unit = { it() }) : ArtemisMessagingComponent(config), MessagingServiceInternal {
companion object { companion object {
val log = loggerFor<NodeMessagingClient>() val log = loggerFor<NodeMessagingClient>()
@ -86,8 +87,7 @@ class NodeMessagingClient(config: NodeConfiguration,
var rpcConsumer: ClientConsumer? = null var rpcConsumer: ClientConsumer? = null
var rpcNotificationConsumer: ClientConsumer? = null var rpcNotificationConsumer: ClientConsumer? = null
// TODO: This is not robust and needs to be replaced by more intelligently using the message queue server. var pendingRedelivery = JDBCHashSet<Message>("pending_messages",loadOnInit = true)
var undeliveredMessages = listOf<Message>()
} }
/** A registration to handle messages of different types */ /** A registration to handle messages of different types */
@ -106,23 +106,16 @@ class NodeMessagingClient(config: NodeConfiguration,
val uuid = uuidString("message_id") val uuid = uuidString("message_id")
} }
private val processedMessages: MutableSet<UUID> = Collections.synchronizedSet(if (persistentInbox) { private val processedMessages: MutableSet<UUID> = Collections.synchronizedSet(
object : AbstractJDBCHashSet<UUID, Table>(Table, loadOnInit = true) { object : AbstractJDBCHashSet<UUID, Table>(Table, loadOnInit = true) {
override fun elementFromRow(row: ResultRow): UUID = row[table.uuid] override fun elementFromRow(row: ResultRow): UUID = row[table.uuid]
override fun addElementToInsert(insert: InsertStatement, entry: UUID, finalizables: MutableList<() -> Unit>) { override fun addElementToInsert(insert: InsertStatement, entry: UUID, finalizables: MutableList<() -> Unit>) {
insert[table.uuid] = entry insert[table.uuid] = entry
} }
}
} else {
HashSet<UUID>()
}) })
init { fun start(rpcOps: CordaRPCOps) {
require(config.basedir.fileSystem == FileSystems.getDefault()) { "Artemis only uses the default file system" }
}
fun start(rpcOps: CordaRPCOps? = null) {
state.locked { state.locked {
check(!started) { "start can't be called twice" } check(!started) { "start can't be called twice" }
started = true started = true
@ -135,7 +128,7 @@ class NodeMessagingClient(config: NodeConfiguration,
// Create a session. Note that the acknowledgement of messages is not flushed to // Create a session. Note that the acknowledgement of messages is not flushed to
// the Artermis journal until the default buffer size of 1MB is acknowledged. // the Artermis journal until the default buffer size of 1MB is acknowledged.
val session = clientFactory!!.createSession(true, true, ActiveMQClient.DEFAULT_ACK_BATCH_SIZE) val session = clientFactory!!.createSession("Node", "Node", false, true, true, locator.isPreAcknowledge, ActiveMQClient.DEFAULT_ACK_BATCH_SIZE)
this.session = session this.session = session
session.start() session.start()
@ -146,7 +139,7 @@ class NodeMessagingClient(config: NodeConfiguration,
val queueName = toQueueName(myAddress) val queueName = toQueueName(myAddress)
val query = session.queueQuery(queueName) val query = session.queueQuery(queueName)
if (!query.isExists) { if (!query.isExists) {
session.createQueue(queueName, queueName, persistentInbox) session.createQueue(queueName, queueName, true)
} }
knownQueues.add(queueName) knownQueues.add(queueName)
p2pConsumer = session.createConsumer(queueName) p2pConsumer = session.createConsumer(queueName)
@ -154,7 +147,6 @@ class NodeMessagingClient(config: NodeConfiguration,
// Create an RPC queue and consumer: this will service locally connected clients only (not via a // Create an RPC queue and consumer: this will service locally connected clients only (not via a
// bridge) and those clients must have authenticated. We could use a single consumer for everything // bridge) and those clients must have authenticated. We could use a single consumer for everything
// and perhaps we should, but these queues are not worth persisting. // and perhaps we should, but these queues are not worth persisting.
if (rpcOps != null) {
session.createTemporaryQueue(RPC_REQUESTS_QUEUE, RPC_REQUESTS_QUEUE) session.createTemporaryQueue(RPC_REQUESTS_QUEUE, RPC_REQUESTS_QUEUE)
session.createTemporaryQueue("activemq.notifications", "rpc.qremovals", "_AMQ_NotifType = 1") session.createTemporaryQueue("activemq.notifications", "rpc.qremovals", "_AMQ_NotifType = 1")
rpcConsumer = session.createConsumer(RPC_REQUESTS_QUEUE) rpcConsumer = session.createConsumer(RPC_REQUESTS_QUEUE)
@ -162,7 +154,6 @@ class NodeMessagingClient(config: NodeConfiguration,
dispatcher = createRPCDispatcher(state, rpcOps) dispatcher = createRPCDispatcher(state, rpcOps)
} }
} }
}
private var shutdownLatch = CountDownLatch(1) private var shutdownLatch = CountDownLatch(1)
@ -227,8 +218,9 @@ class NodeMessagingClient(config: NodeConfiguration,
val topic = message.getStringProperty(TOPIC_PROPERTY) val topic = message.getStringProperty(TOPIC_PROPERTY)
val sessionID = message.getLongProperty(SESSION_ID_PROPERTY) val sessionID = message.getLongProperty(SESSION_ID_PROPERTY)
// Use the magic deduplication property built into Artemis as our message identity too // Use the magic deduplication property built into Artemis as our message identity too
val uuid = UUID.fromString(message.getStringProperty(org.apache.activemq.artemis.api.core.Message.HDR_DUPLICATE_DETECTION_ID)) val uuid = UUID.fromString(message.getStringProperty(HDR_DUPLICATE_DETECTION_ID))
log.info("received message from: ${message.address} topic: $topic sessionID: $sessionID uuid: $uuid") val user = message.getStringProperty(HDR_VALIDATED_USER)
log.info("Received message from: ${message.address} user: $user topic: $topic sessionID: $sessionID uuid: $uuid")
val body = ByteArray(message.bodySize).apply { message.bodyBuffer.readBytes(this) } val body = ByteArray(message.bodySize).apply { message.bodyBuffer.readBytes(this) }
@ -259,10 +251,10 @@ class NodeMessagingClient(config: NodeConfiguration,
// without causing log spam. // without causing log spam.
log.warn("Received message for ${msg.topicSession} that doesn't have any registered handlers yet") log.warn("Received message for ${msg.topicSession} that doesn't have any registered handlers yet")
// This is a hack; transient messages held in memory isn't crash resistant.
// TODO: Use Artemis API more effectively so we don't pop messages off a queue that we aren't ready to use.
state.locked { state.locked {
undeliveredMessages += msg databaseTransaction(database) {
pendingRedelivery.add(msg)
}
} }
return false return false
} }
@ -277,7 +269,7 @@ class NodeMessagingClient(config: NodeConfiguration,
// Note that handlers may re-enter this class. We aren't holding any locks and methods like // Note that handlers may re-enter this class. We aren't holding any locks and methods like
// start/run/stop have re-entrancy assertions at the top, so it is OK. // start/run/stop have re-entrancy assertions at the top, so it is OK.
executor.fetchFrom { executor.fetchFrom {
persistenceTx { databaseTransaction(database) {
callHandlers(msg, deliverTo) callHandlers(msg, deliverTo)
} }
} }
@ -346,7 +338,7 @@ class NodeMessagingClient(config: NodeConfiguration,
putLongProperty(SESSION_ID_PROPERTY, sessionID) putLongProperty(SESSION_ID_PROPERTY, sessionID)
writeBodyBufferBytes(message.data) writeBodyBufferBytes(message.data)
// Use the magic deduplication property built into Artemis as our message identity too // Use the magic deduplication property built into Artemis as our message identity too
putStringProperty(org.apache.activemq.artemis.api.core.Message.HDR_DUPLICATE_DETECTION_ID, SimpleString(UUID.randomUUID().toString())) putStringProperty(HDR_DUPLICATE_DETECTION_ID, SimpleString(UUID.randomUUID().toString()))
} }
if (knownQueues.add(queueName)) { if (knownQueues.add(queueName)) {
@ -376,9 +368,12 @@ class NodeMessagingClient(config: NodeConfiguration,
val handler = Handler(topicSession, callback) val handler = Handler(topicSession, callback)
handlers.add(handler) handlers.add(handler)
val messagesToRedeliver = state.locked { val messagesToRedeliver = state.locked {
val messagesToRedeliver = undeliveredMessages val pending = ArrayList<Message>()
undeliveredMessages = listOf() databaseTransaction(database) {
messagesToRedeliver pending.addAll(pendingRedelivery)
pendingRedelivery.clear()
}
pending
} }
messagesToRedeliver.forEach { deliver(it) } messagesToRedeliver.forEach { deliver(it) }
return handler return handler
@ -391,8 +386,8 @@ class NodeMessagingClient(config: NodeConfiguration,
override fun createMessage(topicSession: TopicSession, data: ByteArray, uuid: UUID): Message { override fun createMessage(topicSession: TopicSession, data: ByteArray, uuid: UUID): Message {
// TODO: We could write an object that proxies directly to an underlying MQ message here and avoid copying. // TODO: We could write an object that proxies directly to an underlying MQ message here and avoid copying.
return object : Message { return object : Message {
override val topicSession: TopicSession get() = topicSession override val topicSession: TopicSession = topicSession
override val data: ByteArray get() = data override val data: ByteArray = data
override val debugTimestamp: Instant = Instant.now() override val debugTimestamp: Instant = Instant.now()
override fun serialise(): ByteArray = this.serialise() override fun serialise(): ByteArray = this.serialise()
override val uniqueMessageId: UUID = uuid override val uniqueMessageId: UUID = uuid
@ -408,7 +403,7 @@ class NodeMessagingClient(config: NodeConfiguration,
val msg = session!!.createMessage(false).apply { val msg = session!!.createMessage(false).apply {
writeBodyBufferBytes(bits.bits) writeBodyBufferBytes(bits.bits)
// Use the magic deduplication property built into Artemis as our message identity too // Use the magic deduplication property built into Artemis as our message identity too
putStringProperty(org.apache.activemq.artemis.api.core.Message.HDR_DUPLICATE_DETECTION_ID, SimpleString(UUID.randomUUID().toString())) putStringProperty(HDR_DUPLICATE_DETECTION_ID, SimpleString(UUID.randomUUID().toString()))
} }
producer!!.send(toAddress, msg) producer!!.send(toAddress, msg)
} }

View File

@ -1,73 +0,0 @@
package com.r3corda.node.services.persistence
import com.r3corda.core.serialization.SerializedBytes
import com.r3corda.core.serialization.deserialize
import com.r3corda.core.serialization.serialize
import com.r3corda.core.utilities.loggerFor
import com.r3corda.core.utilities.trace
import com.r3corda.node.services.api.Checkpoint
import com.r3corda.node.services.api.CheckpointStorage
import java.nio.file.Files
import java.nio.file.Path
import java.nio.file.StandardCopyOption
import java.util.*
import java.util.Collections.synchronizedMap
import javax.annotation.concurrent.ThreadSafe
/**
* File-based checkpoint storage, storing checkpoints per file.
*/
@ThreadSafe
class PerFileCheckpointStorage(val storeDir: Path) : CheckpointStorage {
companion object {
private val logger = loggerFor<PerFileCheckpointStorage>()
private val fileExtension = ".checkpoint"
}
private val checkpointFiles = synchronizedMap(IdentityHashMap<Checkpoint, Path>())
init {
logger.trace { "Initialising per file checkpoint storage on $storeDir" }
Files.createDirectories(storeDir)
Files.list(storeDir)
.filter { it.toString().toLowerCase().endsWith(fileExtension) }
.forEach {
val checkpoint = Files.readAllBytes(it).deserialize<Checkpoint>()
checkpointFiles[checkpoint] = it
}
}
override fun addCheckpoint(checkpoint: Checkpoint) {
val fileName = "${checkpoint.id.toString().toLowerCase()}$fileExtension"
val checkpointFile = storeDir.resolve(fileName)
atomicWrite(checkpointFile, checkpoint.serialize())
logger.trace { "Stored $checkpoint to $checkpointFile" }
checkpointFiles[checkpoint] = checkpointFile
}
private fun atomicWrite(checkpointFile: Path, serialisedCheckpoint: SerializedBytes<Checkpoint>) {
val tempCheckpointFile = checkpointFile.parent.resolve("${checkpointFile.fileName}.tmp")
serialisedCheckpoint.writeToFile(tempCheckpointFile)
Files.move(tempCheckpointFile, checkpointFile, StandardCopyOption.ATOMIC_MOVE)
}
override fun removeCheckpoint(checkpoint: Checkpoint) {
val checkpointFile = checkpointFiles.remove(checkpoint)
require(checkpointFile != null) { "Trying to removing unknown checkpoint: $checkpoint" }
Files.delete(checkpointFile)
logger.trace { "Removed $checkpoint ($checkpointFile)" }
}
override fun forEach(block: (Checkpoint)->Boolean) {
synchronized(checkpointFiles) {
for(checkpoint in checkpointFiles.keys) {
if (!block(checkpoint)) {
break
}
}
}
}
}

View File

@ -1,70 +0,0 @@
package com.r3corda.node.services.persistence
import com.r3corda.core.ThreadBox
import com.r3corda.core.bufferUntilSubscribed
import com.r3corda.core.crypto.SecureHash
import com.r3corda.core.node.services.TransactionStorage
import com.r3corda.core.serialization.deserialize
import com.r3corda.core.serialization.serialize
import com.r3corda.core.transactions.SignedTransaction
import com.r3corda.core.utilities.loggerFor
import com.r3corda.core.utilities.trace
import rx.Observable
import rx.subjects.PublishSubject
import java.nio.file.Files
import java.nio.file.Path
import java.util.*
import javax.annotation.concurrent.ThreadSafe
/**
* File-based transaction storage, storing transactions per file.
*/
@ThreadSafe
class PerFileTransactionStorage(val storeDir: Path) : TransactionStorage {
companion object {
private val logger = loggerFor<PerFileCheckpointStorage>()
private val fileExtension = ".transaction"
}
private val mutex = ThreadBox(object {
val transactionsMap = HashMap<SecureHash, SignedTransaction>()
val updatesPublisher = PublishSubject.create<SignedTransaction>()
fun notify(transaction: SignedTransaction) = updatesPublisher.onNext(transaction)
})
override val updates: Observable<SignedTransaction>
get() = mutex.content.updatesPublisher
init {
logger.trace { "Initialising per file transaction storage on $storeDir" }
Files.createDirectories(storeDir)
mutex.locked {
Files.list(storeDir)
.filter { it.toString().toLowerCase().endsWith(fileExtension) }
.map { Files.readAllBytes(it).deserialize<SignedTransaction>() }
.forEach { transactionsMap[it.id] = it }
}
}
override fun addTransaction(transaction: SignedTransaction) {
val transactionFile = storeDir.resolve("${transaction.id.toString().toLowerCase()}$fileExtension")
transaction.serialize().writeToFile(transactionFile)
mutex.locked {
transactionsMap[transaction.id] = transaction
notify(transaction)
}
logger.trace { "Stored $transaction to $transactionFile" }
}
override fun getTransaction(id: SecureHash): SignedTransaction? = mutex.locked { transactionsMap[id] }
val transactions: Iterable<SignedTransaction> get() = mutex.locked { transactionsMap.values.toList() }
override fun track(): Pair<List<SignedTransaction>, Observable<SignedTransaction>> {
return mutex.locked {
Pair(transactionsMap.values.toList(), updates.bufferUntilSubscribed())
}
}
}

View File

@ -11,7 +11,10 @@ import com.fasterxml.jackson.databind.module.SimpleModule
import com.fasterxml.jackson.module.kotlin.KotlinModule import com.fasterxml.jackson.module.kotlin.KotlinModule
import com.r3corda.core.contracts.BusinessCalendar import com.r3corda.core.contracts.BusinessCalendar
import com.r3corda.core.crypto.* import com.r3corda.core.crypto.*
import com.r3corda.core.node.NodeInfo
import com.r3corda.core.node.services.IdentityService import com.r3corda.core.node.services.IdentityService
import com.r3corda.core.serialization.deserialize
import com.r3corda.core.serialization.serialize
import net.i2p.crypto.eddsa.EdDSAPublicKey import net.i2p.crypto.eddsa.EdDSAPublicKey
import java.math.BigDecimal import java.math.BigDecimal
import java.time.LocalDate import java.time.LocalDate
@ -54,6 +57,11 @@ object JsonSupport {
cordaModule.addSerializer(PublicKeyTree::class.java, PublicKeyTreeSerializer) cordaModule.addSerializer(PublicKeyTree::class.java, PublicKeyTreeSerializer)
cordaModule.addDeserializer(PublicKeyTree::class.java, PublicKeyTreeDeserializer) cordaModule.addDeserializer(PublicKeyTree::class.java, PublicKeyTreeDeserializer)
// For NodeInfo
// TODO this tunnels the Kryo representation as a Base58 encoded string. Replace when RPC supports this.
cordaModule.addSerializer(NodeInfo::class.java, NodeInfoSerializer)
cordaModule.addDeserializer(NodeInfo::class.java, NodeInfoDeserializer)
mapper.registerModule(timeModule) mapper.registerModule(timeModule)
mapper.registerModule(cordaModule) mapper.registerModule(cordaModule)
mapper.registerModule(KotlinModule()) mapper.registerModule(KotlinModule())
@ -102,6 +110,25 @@ object JsonSupport {
} }
} }
object NodeInfoSerializer : JsonSerializer<NodeInfo>() {
override fun serialize(value: NodeInfo, gen: JsonGenerator, serializers: SerializerProvider) {
gen.writeString(Base58.encode(value.serialize().bits))
}
}
object NodeInfoDeserializer : JsonDeserializer<NodeInfo>() {
override fun deserialize(parser: JsonParser, context: DeserializationContext): NodeInfo {
if (parser.currentToken == JsonToken.FIELD_NAME) {
parser.nextToken()
}
try {
return Base58.decode(parser.text).deserialize<NodeInfo>()
} catch (e: Exception) {
throw JsonParseException(parser, "Invalid NodeInfo ${parser.text}: ${e.message}")
}
}
}
object SecureHashSerializer : JsonSerializer<SecureHash>() { object SecureHashSerializer : JsonSerializer<SecureHash>() {
override fun serialize(obj: SecureHash, generator: JsonGenerator, provider: SerializerProvider) { override fun serialize(obj: SecureHash, generator: JsonGenerator, provider: SerializerProvider) {
generator.writeString(obj.toString()) generator.writeString(obj.toString())

View File

@ -0,0 +1,43 @@
package com.r3corda.node.utilities
import com.r3corda.core.serialization.SerializeAsToken
import com.r3corda.core.serialization.SerializeAsTokenContext
import com.r3corda.core.serialization.SingletonSerializationToken
import java.time.*
import javax.annotation.concurrent.ThreadSafe
/**
* A [Clock] that can have the date advanced for use in demos.
*/
@ThreadSafe
class TestClock(private var delegateClock: Clock = Clock.systemUTC()) : MutableClock(), SerializeAsToken {
private val token = SingletonSerializationToken(this)
override fun toToken(context: SerializeAsTokenContext) = SingletonSerializationToken.registerWithContext(token, this, context)
@Synchronized fun updateDate(date: LocalDate): Boolean {
val currentDate = LocalDate.now(this)
if (currentDate.isBefore(date)) {
// It's ok to increment
delegateClock = Clock.offset(delegateClock, Duration.between(currentDate.atStartOfDay(), date.atStartOfDay()))
notifyMutationObservers()
return true
}
return false
}
@Synchronized override fun instant(): Instant {
return delegateClock.instant()
}
// Do not use this. Instead seek to use ZonedDateTime methods.
override fun withZone(zone: ZoneId): Clock {
throw UnsupportedOperationException("Tokenized clock does not support withZone()")
}
@Synchronized override fun getZone(): ZoneId {
return delegateClock.zone
}
}

View File

@ -1,4 +1,3 @@
# Register a ServiceLoader service extending from com.r3corda.node.CordaPluginRegistry # Register a ServiceLoader service extending from com.r3corda.node.CordaPluginRegistry
com.r3corda.node.services.clientapi.FixingSessionInitiation$Plugin
com.r3corda.node.services.NotaryChange$Plugin com.r3corda.node.services.NotaryChange$Plugin
com.r3corda.node.services.persistence.DataVending$Plugin com.r3corda.node.services.persistence.DataVending$Plugin

View File

@ -14,3 +14,4 @@ devMode = true
certificateSigningService = "https://cordaci-netperm.corda.r3cev.com" certificateSigningService = "https://cordaci-netperm.corda.r3cev.com"
useHTTPS = false useHTTPS = false
h2port = 0 h2port = 0
useTestClock = false

View File

@ -155,16 +155,14 @@ class TwoPartyTradeProtocolTests {
bobNode.pumpReceive() bobNode.pumpReceive()
// OK, now Bob has sent the partial transaction back to Alice and is waiting for Alice's signature. // OK, now Bob has sent the partial transaction back to Alice and is waiting for Alice's signature.
databaseTransaction(bobNode.database) {
assertThat(bobNode.checkpointStorage.checkpoints()).hasSize(1) assertThat(bobNode.checkpointStorage.checkpoints()).hasSize(1)
}
val storage = bobNode.storage.validatedTransactions val storage = bobNode.storage.validatedTransactions
val bobTransactionsBeforeCrash = if (storage is PerFileTransactionStorage) { val bobTransactionsBeforeCrash = databaseTransaction(bobNode.database) {
storage.transactions (storage as DBTransactionStorage).transactions
} else if (storage is DBTransactionStorage) {
databaseTransaction(bobNode.database) {
storage.transactions
} }
} else throw IllegalArgumentException("Unknown storage implementation")
assertThat(bobTransactionsBeforeCrash).isNotEmpty() assertThat(bobTransactionsBeforeCrash).isNotEmpty()
// .. and let's imagine that Bob's computer has a power cut. He now has nothing now beyond what was on disk. // .. and let's imagine that Bob's computer has a power cut. He now has nothing now beyond what was on disk.

View File

@ -1,22 +1,37 @@
package com.r3corda.node.services package com.r3corda.node.services
import com.google.common.net.HostAndPort import com.google.common.net.HostAndPort
import com.r3corda.core.contracts.ClientToServiceCommand
import com.r3corda.core.contracts.ContractState
import com.r3corda.core.contracts.StateAndRef
import com.r3corda.core.crypto.generateKeyPair import com.r3corda.core.crypto.generateKeyPair
import com.r3corda.core.messaging.Message import com.r3corda.core.messaging.Message
import com.r3corda.core.messaging.createMessage import com.r3corda.core.messaging.createMessage
import com.r3corda.core.node.NodeInfo
import com.r3corda.core.node.services.DEFAULT_SESSION_ID import com.r3corda.core.node.services.DEFAULT_SESSION_ID
import com.r3corda.core.node.services.NetworkMapCache
import com.r3corda.core.node.services.StateMachineTransactionMapping
import com.r3corda.core.node.services.Vault
import com.r3corda.core.transactions.SignedTransaction
import com.r3corda.core.utilities.LogHelper
import com.r3corda.node.services.config.NodeConfiguration import com.r3corda.node.services.config.NodeConfiguration
import com.r3corda.node.services.messaging.ArtemisMessagingServer import com.r3corda.node.services.messaging.*
import com.r3corda.node.services.messaging.NodeMessagingClient
import com.r3corda.node.services.network.InMemoryNetworkMapCache import com.r3corda.node.services.network.InMemoryNetworkMapCache
import com.r3corda.node.services.transactions.PersistentUniquenessProvider
import com.r3corda.node.utilities.AffinityExecutor import com.r3corda.node.utilities.AffinityExecutor
import com.r3corda.node.utilities.configureDatabase
import com.r3corda.node.utilities.databaseTransaction
import com.r3corda.testing.freeLocalHostAndPort import com.r3corda.testing.freeLocalHostAndPort
import com.r3corda.testing.node.makeTestDataSourceProperties
import org.assertj.core.api.Assertions.assertThatThrownBy import org.assertj.core.api.Assertions.assertThatThrownBy
import org.jetbrains.exposed.sql.Database
import org.junit.After import org.junit.After
import org.junit.Before import org.junit.Before
import org.junit.Rule import org.junit.Rule
import org.junit.Test import org.junit.Test
import org.junit.rules.TemporaryFolder import org.junit.rules.TemporaryFolder
import rx.Observable
import java.io.Closeable
import java.net.ServerSocket import java.net.ServerSocket
import java.nio.file.Path import java.nio.file.Path
import java.util.concurrent.LinkedBlockingQueue import java.util.concurrent.LinkedBlockingQueue
@ -33,12 +48,44 @@ class ArtemisMessagingTests {
val identity = generateKeyPair() val identity = generateKeyPair()
lateinit var config: NodeConfiguration lateinit var config: NodeConfiguration
lateinit var dataSource: Closeable
lateinit var database: Database
var messagingClient: NodeMessagingClient? = null var messagingClient: NodeMessagingClient? = null
var messagingServer: ArtemisMessagingServer? = null var messagingServer: ArtemisMessagingServer? = null
val networkMapCache = InMemoryNetworkMapCache() val networkMapCache = InMemoryNetworkMapCache()
val rpcOps = object : CordaRPCOps {
override val protocolVersion: Int
get() = throw UnsupportedOperationException()
override fun stateMachinesAndUpdates(): Pair<List<StateMachineInfo>, Observable<StateMachineUpdate>> {
throw UnsupportedOperationException("not implemented") //To change body of created functions use File | Settings | File Templates.
}
override fun vaultAndUpdates(): Pair<List<StateAndRef<ContractState>>, Observable<Vault.Update>> {
throw UnsupportedOperationException("not implemented") //To change body of created functions use File | Settings | File Templates.
}
override fun verifiedTransactions(): Pair<List<SignedTransaction>, Observable<SignedTransaction>> {
throw UnsupportedOperationException("not implemented") //To change body of created functions use File | Settings | File Templates.
}
override fun stateMachineRecordedTransactionMapping(): Pair<List<StateMachineTransactionMapping>, Observable<StateMachineTransactionMapping>> {
throw UnsupportedOperationException("not implemented") //To change body of created functions use File | Settings | File Templates.
}
override fun networkMapUpdates(): Pair<List<NodeInfo>, Observable<NetworkMapCache.MapChange>> {
throw UnsupportedOperationException("not implemented") //To change body of created functions use File | Settings | File Templates.
}
override fun executeCommand(command: ClientToServiceCommand): TransactionBuildResult {
throw UnsupportedOperationException("not implemented") //To change body of created functions use File | Settings | File Templates.
}
}
@Before @Before
fun setUp() { fun setUp() {
// TODO: create a base class that provides a default implementation // TODO: create a base class that provides a default implementation
@ -52,12 +99,18 @@ class ArtemisMessagingTests {
override val keyStorePassword: String = "testpass" override val keyStorePassword: String = "testpass"
override val trustStorePassword: String = "trustpass" override val trustStorePassword: String = "trustpass"
} }
LogHelper.setLevel(PersistentUniquenessProvider::class)
val dataSourceAndDatabase = configureDatabase(makeTestDataSourceProperties())
dataSource = dataSourceAndDatabase.first
database = dataSourceAndDatabase.second
} }
@After @After
fun cleanUp() { fun cleanUp() {
messagingClient?.stop() messagingClient?.stop()
messagingServer?.stop() messagingServer?.stop()
dataSource.close()
LogHelper.reset(PersistentUniquenessProvider::class)
} }
@Test @Test
@ -73,7 +126,7 @@ class ArtemisMessagingTests {
val remoteServerAddress = freeLocalHostAndPort() val remoteServerAddress = freeLocalHostAndPort()
createMessagingServer(remoteServerAddress).start() createMessagingServer(remoteServerAddress).start()
createMessagingClient(server = remoteServerAddress).start() createMessagingClient(server = remoteServerAddress).start(rpcOps)
} }
@Test @Test
@ -84,14 +137,14 @@ class ArtemisMessagingTests {
createMessagingServer(serverAddress).start() createMessagingServer(serverAddress).start()
messagingClient = createMessagingClient(server = invalidServerAddress) messagingClient = createMessagingClient(server = invalidServerAddress)
assertThatThrownBy { messagingClient!!.start() } assertThatThrownBy { messagingClient!!.start(rpcOps) }
messagingClient = null messagingClient = null
} }
@Test @Test
fun `client should connect to local server`() { fun `client should connect to local server`() {
createMessagingServer().start() createMessagingServer().start()
createMessagingClient().start() createMessagingClient().start(rpcOps)
} }
@Test @Test
@ -101,7 +154,7 @@ class ArtemisMessagingTests {
createMessagingServer().start() createMessagingServer().start()
val messagingClient = createMessagingClient() val messagingClient = createMessagingClient()
messagingClient.start() messagingClient.start(rpcOps)
thread { messagingClient.run() } thread { messagingClient.run() }
messagingClient.addMessageHandler(topic) { message, r -> messagingClient.addMessageHandler(topic) { message, r ->
@ -117,11 +170,13 @@ class ArtemisMessagingTests {
} }
private fun createMessagingClient(server: HostAndPort = hostAndPort): NodeMessagingClient { private fun createMessagingClient(server: HostAndPort = hostAndPort): NodeMessagingClient {
return NodeMessagingClient(config, server, identity.public, AffinityExecutor.ServiceAffinityExecutor("ArtemisMessagingTests", 1), false).apply { return databaseTransaction(database) {
NodeMessagingClient(config, server, identity.public, AffinityExecutor.ServiceAffinityExecutor("ArtemisMessagingTests", 1), database).apply {
configureWithDevSSLCertificate() configureWithDevSSLCertificate()
messagingClient = this messagingClient = this
} }
} }
}
private fun createMessagingServer(local: HostAndPort = hostAndPort): ArtemisMessagingServer { private fun createMessagingServer(local: HostAndPort = hostAndPort): ArtemisMessagingServer {
return ArtemisMessagingServer(config, local, networkMapCache).apply { return ArtemisMessagingServer(config, local, networkMapCache).apply {

View File

@ -12,7 +12,7 @@ import com.r3corda.core.protocols.ProtocolLogicRefFactory
import com.r3corda.core.serialization.SingletonSerializeAsToken import com.r3corda.core.serialization.SingletonSerializeAsToken
import com.r3corda.core.utilities.DUMMY_NOTARY import com.r3corda.core.utilities.DUMMY_NOTARY
import com.r3corda.node.services.events.NodeSchedulerService import com.r3corda.node.services.events.NodeSchedulerService
import com.r3corda.node.services.persistence.PerFileCheckpointStorage import com.r3corda.node.services.persistence.DBCheckpointStorage
import com.r3corda.node.services.statemachine.StateMachineManager import com.r3corda.node.services.statemachine.StateMachineManager
import com.r3corda.node.utilities.AddOrRemove import com.r3corda.node.utilities.AddOrRemove
import com.r3corda.node.utilities.AffinityExecutor import com.r3corda.node.utilities.AffinityExecutor
@ -52,8 +52,7 @@ class NodeSchedulerServiceTest : SingletonSerializeAsToken() {
@Suppress("PLATFORM_CLASS_MAPPED_TO_KOTLIN") @Suppress("PLATFORM_CLASS_MAPPED_TO_KOTLIN")
val factory = ProtocolLogicRefFactory(mapOf(Pair(TestProtocolLogic::class.java.name, setOf(NodeSchedulerServiceTest::class.java.name, Integer::class.java.name)))) val factory = ProtocolLogicRefFactory(mapOf(Pair(TestProtocolLogic::class.java.name, setOf(NodeSchedulerServiceTest::class.java.name, Integer::class.java.name))))
val services: MockServiceHubInternal lateinit var services: MockServiceHubInternal
lateinit var scheduler: NodeSchedulerService lateinit var scheduler: NodeSchedulerService
lateinit var smmExecutor: AffinityExecutor.ServiceAffinityExecutor lateinit var smmExecutor: AffinityExecutor.ServiceAffinityExecutor
lateinit var dataSource: Closeable lateinit var dataSource: Closeable
@ -72,13 +71,6 @@ class NodeSchedulerServiceTest : SingletonSerializeAsToken() {
val testReference: NodeSchedulerServiceTest val testReference: NodeSchedulerServiceTest
} }
init {
val kms = MockKeyManagementService(ALICE_KEY)
val mockMessagingService = InMemoryMessagingNetwork(false).InMemoryMessaging(false, InMemoryMessagingNetwork.Handle(0, "None"), AffinityExecutor.ServiceAffinityExecutor("test", 1), persistenceTx = { it() })
services = object : MockServiceHubInternal(overrideClock = testClock, keyManagement = kms, net = mockMessagingService), TestReference {
override val testReference = this@NodeSchedulerServiceTest
}
}
@Before @Before
fun setup() { fun setup() {
@ -89,9 +81,14 @@ class NodeSchedulerServiceTest : SingletonSerializeAsToken() {
dataSource = dataSourceAndDatabase.first dataSource = dataSourceAndDatabase.first
database = dataSourceAndDatabase.second database = dataSourceAndDatabase.second
databaseTransaction(database) { databaseTransaction(database) {
val kms = MockKeyManagementService(ALICE_KEY)
val mockMessagingService = InMemoryMessagingNetwork(false).InMemoryMessaging(false, InMemoryMessagingNetwork.Handle(0, "None"), AffinityExecutor.ServiceAffinityExecutor("test", 1), database)
services = object : MockServiceHubInternal(overrideClock = testClock, keyManagement = kms, net = mockMessagingService), TestReference {
override val testReference = this@NodeSchedulerServiceTest
}
scheduler = NodeSchedulerService(database, services, factory, schedulerGatedExecutor) scheduler = NodeSchedulerService(database, services, factory, schedulerGatedExecutor)
smmExecutor = AffinityExecutor.ServiceAffinityExecutor("test", 1) smmExecutor = AffinityExecutor.ServiceAffinityExecutor("test", 1)
val mockSMM = StateMachineManager(services, listOf(services, scheduler), PerFileCheckpointStorage(fs.getPath("checkpoints")), smmExecutor, database) val mockSMM = StateMachineManager(services, listOf(services, scheduler), DBCheckpointStorage(), smmExecutor, database)
mockSMM.changes.subscribe { change -> mockSMM.changes.subscribe { change ->
if (change.addOrRemove == AddOrRemove.REMOVE && mockSMM.allStateMachines.isEmpty()) { if (change.addOrRemove == AddOrRemove.REMOVE && mockSMM.allStateMachines.isEmpty()) {
smmHasRemovedAllProtocols.countDown() smmHasRemovedAllProtocols.countDown()

View File

@ -1,99 +0,0 @@
package com.r3corda.node.services.persistence
import com.google.common.jimfs.Configuration.unix
import com.google.common.jimfs.Jimfs
import com.google.common.primitives.Ints
import com.r3corda.core.serialization.SerializedBytes
import com.r3corda.node.services.api.Checkpoint
import org.assertj.core.api.Assertions.assertThat
import org.assertj.core.api.Assertions.assertThatExceptionOfType
import org.junit.After
import org.junit.Before
import org.junit.Test
import java.nio.file.FileSystem
import java.nio.file.Files
import java.nio.file.Path
class PerFileCheckpointStorageTests {
val fileSystem: FileSystem = Jimfs.newFileSystem(unix())
val storeDir: Path = fileSystem.getPath("store")
lateinit var checkpointStorage: PerFileCheckpointStorage
@Before
fun setUp() {
newCheckpointStorage()
}
@After
fun cleanUp() {
fileSystem.close()
}
@Test
fun `add new checkpoint`() {
val checkpoint = newCheckpoint()
checkpointStorage.addCheckpoint(checkpoint)
assertThat(checkpointStorage.checkpoints()).containsExactly(checkpoint)
newCheckpointStorage()
assertThat(checkpointStorage.checkpoints()).containsExactly(checkpoint)
}
@Test
fun `remove checkpoint`() {
val checkpoint = newCheckpoint()
checkpointStorage.addCheckpoint(checkpoint)
checkpointStorage.removeCheckpoint(checkpoint)
assertThat(checkpointStorage.checkpoints()).isEmpty()
newCheckpointStorage()
assertThat(checkpointStorage.checkpoints()).isEmpty()
}
@Test
fun `remove unknown checkpoint`() {
val checkpoint = newCheckpoint()
assertThatExceptionOfType(IllegalArgumentException::class.java).isThrownBy {
checkpointStorage.removeCheckpoint(checkpoint)
}
}
@Test
fun `add two checkpoints then remove first one`() {
val firstCheckpoint = newCheckpoint()
checkpointStorage.addCheckpoint(firstCheckpoint)
val secondCheckpoint = newCheckpoint()
checkpointStorage.addCheckpoint(secondCheckpoint)
checkpointStorage.removeCheckpoint(firstCheckpoint)
assertThat(checkpointStorage.checkpoints()).containsExactly(secondCheckpoint)
newCheckpointStorage()
assertThat(checkpointStorage.checkpoints()).containsExactly(secondCheckpoint)
}
@Test
fun `add checkpoint and then remove after 'restart'`() {
val originalCheckpoint = newCheckpoint()
checkpointStorage.addCheckpoint(originalCheckpoint)
newCheckpointStorage()
val reconstructedCheckpoint = checkpointStorage.checkpoints().single()
assertThat(reconstructedCheckpoint).isEqualTo(originalCheckpoint).isNotSameAs(originalCheckpoint)
checkpointStorage.removeCheckpoint(reconstructedCheckpoint)
assertThat(checkpointStorage.checkpoints()).isEmpty()
}
@Test
fun `non-checkpoint files are ignored`() {
val checkpoint = newCheckpoint()
checkpointStorage.addCheckpoint(checkpoint)
Files.write(storeDir.resolve("random-non-checkpoint-file"), "this is not a checkpoint!!".toByteArray())
newCheckpointStorage()
assertThat(checkpointStorage.checkpoints()).containsExactly(checkpoint)
}
private fun newCheckpointStorage() {
checkpointStorage = PerFileCheckpointStorage(storeDir)
}
private var checkpointCount = 1
private fun newCheckpoint() = Checkpoint(SerializedBytes(Ints.toByteArray(checkpointCount++)))
}

View File

@ -1,100 +0,0 @@
package com.r3corda.node.services.persistence
import com.google.common.jimfs.Configuration.unix
import com.google.common.jimfs.Jimfs
import com.google.common.primitives.Ints
import com.google.common.util.concurrent.SettableFuture
import com.r3corda.core.crypto.DigitalSignature
import com.r3corda.core.crypto.NullPublicKey
import com.r3corda.core.serialization.SerializedBytes
import com.r3corda.core.transactions.SignedTransaction
import org.assertj.core.api.Assertions.assertThat
import org.junit.After
import org.junit.Before
import org.junit.Test
import java.nio.file.FileSystem
import java.nio.file.Files
import java.nio.file.Path
import java.util.concurrent.TimeUnit
import kotlin.test.assertEquals
class PerFileTransactionStorageTests {
val fileSystem: FileSystem = Jimfs.newFileSystem(unix())
val storeDir: Path = fileSystem.getPath("store")
lateinit var transactionStorage: PerFileTransactionStorage
@Before
fun setUp() {
newTransactionStorage()
}
@After
fun cleanUp() {
fileSystem.close()
}
@Test
fun `empty store`() {
assertThat(transactionStorage.getTransaction(newTransaction().id)).isNull()
assertThat(transactionStorage.transactions).isEmpty()
newTransactionStorage()
assertThat(transactionStorage.transactions).isEmpty()
}
@Test
fun `one transaction`() {
val transaction = newTransaction()
transactionStorage.addTransaction(transaction)
assertTransactionIsRetrievable(transaction)
assertThat(transactionStorage.transactions).containsExactly(transaction)
newTransactionStorage()
assertTransactionIsRetrievable(transaction)
assertThat(transactionStorage.transactions).containsExactly(transaction)
}
@Test
fun `two transactions across restart`() {
val firstTransaction = newTransaction()
val secondTransaction = newTransaction()
transactionStorage.addTransaction(firstTransaction)
newTransactionStorage()
transactionStorage.addTransaction(secondTransaction)
assertTransactionIsRetrievable(firstTransaction)
assertTransactionIsRetrievable(secondTransaction)
assertThat(transactionStorage.transactions).containsOnly(firstTransaction, secondTransaction)
}
@Test
fun `non-transaction files are ignored`() {
val transactions = newTransaction()
transactionStorage.addTransaction(transactions)
Files.write(storeDir.resolve("random-non-tx-file"), "this is not a transaction!!".toByteArray())
newTransactionStorage()
assertThat(transactionStorage.transactions).containsExactly(transactions)
}
@Test
fun `updates are fired`() {
val future = SettableFuture.create<SignedTransaction>()
transactionStorage.updates.subscribe { tx -> future.set(tx) }
val expected = newTransaction()
transactionStorage.addTransaction(expected)
val actual = future.get(1, TimeUnit.SECONDS)
assertEquals(expected, actual)
}
private fun newTransactionStorage() {
transactionStorage = PerFileTransactionStorage(storeDir)
}
private fun assertTransactionIsRetrievable(transaction: SignedTransaction) {
assertThat(transactionStorage.getTransaction(transaction.id)).isEqualTo(transaction)
}
private var txCount = 0
private fun newTransaction() = SignedTransaction(
SerializedBytes(Ints.toByteArray(++txCount)),
listOf(DigitalSignature.WithKey(NullPublicKey, ByteArray(1))))
}

View File

@ -10,6 +10,7 @@ import com.r3corda.core.random63BitValue
import com.r3corda.core.serialization.deserialize import com.r3corda.core.serialization.deserialize
import com.r3corda.node.services.persistence.checkpoints import com.r3corda.node.services.persistence.checkpoints
import com.r3corda.node.services.statemachine.StateMachineManager.* import com.r3corda.node.services.statemachine.StateMachineManager.*
import com.r3corda.node.utilities.databaseTransaction
import com.r3corda.testing.initiateSingleShotProtocol import com.r3corda.testing.initiateSingleShotProtocol
import com.r3corda.testing.node.InMemoryMessagingNetwork import com.r3corda.testing.node.InMemoryMessagingNetwork
import com.r3corda.testing.node.InMemoryMessagingNetwork.MessageTransfer import com.r3corda.testing.node.InMemoryMessagingNetwork.MessageTransfer
@ -73,6 +74,7 @@ class StateMachineManagerTests {
// We push through just enough messages to get only the payload sent // We push through just enough messages to get only the payload sent
node2.pumpReceive() node2.pumpReceive()
node2.disableDBCloseOnStop()
node2.stop() node2.stop()
net.runNetwork() net.runNetwork()
val restoredProtocol = node2.restartAndGetRestoredProtocol<ReceiveThenSuspendProtocol>(node1) val restoredProtocol = node2.restartAndGetRestoredProtocol<ReceiveThenSuspendProtocol>(node1)
@ -95,6 +97,7 @@ class StateMachineManagerTests {
val protocol = NoOpProtocol() val protocol = NoOpProtocol()
node3.smm.add(protocol) node3.smm.add(protocol)
assertEquals(false, protocol.protocolStarted) // Not started yet as no network activity has been allowed yet assertEquals(false, protocol.protocolStarted) // Not started yet as no network activity has been allowed yet
node3.disableDBCloseOnStop()
node3.stop() node3.stop()
node3 = net.createNode(node1.info.address, forcedID = node3.id) node3 = net.createNode(node1.info.address, forcedID = node3.id)
@ -103,6 +106,7 @@ class StateMachineManagerTests {
net.runNetwork() // Allow network map messages to flow net.runNetwork() // Allow network map messages to flow
node3.smm.executor.flush() node3.smm.executor.flush()
assertEquals(true, restoredProtocol.protocolStarted) // Now we should have run the protocol and hopefully cleared the init checkpoint assertEquals(true, restoredProtocol.protocolStarted) // Now we should have run the protocol and hopefully cleared the init checkpoint
node3.disableDBCloseOnStop()
node3.stop() node3.stop()
// Now it is completed the protocol should leave no Checkpoint. // Now it is completed the protocol should leave no Checkpoint.
@ -119,6 +123,7 @@ class StateMachineManagerTests {
node2.smm.add(ReceiveThenSuspendProtocol(node1.info.legalIdentity)) // Prepare checkpointed receive protocol node2.smm.add(ReceiveThenSuspendProtocol(node1.info.legalIdentity)) // Prepare checkpointed receive protocol
// Make sure the add() has finished initial processing. // Make sure the add() has finished initial processing.
node2.smm.executor.flush() node2.smm.executor.flush()
node2.disableDBCloseOnStop()
node2.stop() // kill receiver node2.stop() // kill receiver
val restoredProtocol = node2.restartAndGetRestoredProtocol<ReceiveThenSuspendProtocol>(node1) val restoredProtocol = node2.restartAndGetRestoredProtocol<ReceiveThenSuspendProtocol>(node1)
assertThat(restoredProtocol.receivedPayloads[0]).isEqualTo(payload) assertThat(restoredProtocol.receivedPayloads[0]).isEqualTo(payload)
@ -138,16 +143,22 @@ class StateMachineManagerTests {
// Kick off first send and receive // Kick off first send and receive
node2.smm.add(PingPongProtocol(node3.info.legalIdentity, payload)) node2.smm.add(PingPongProtocol(node3.info.legalIdentity, payload))
databaseTransaction(node2.database) {
assertEquals(1, node2.checkpointStorage.checkpoints().size) assertEquals(1, node2.checkpointStorage.checkpoints().size)
}
// Make sure the add() has finished initial processing. // Make sure the add() has finished initial processing.
node2.smm.executor.flush() node2.smm.executor.flush()
node2.disableDBCloseOnStop()
// Restart node and thus reload the checkpoint and resend the message with same UUID // Restart node and thus reload the checkpoint and resend the message with same UUID
node2.stop() node2.stop()
databaseTransaction(node2.database) {
assertEquals(1, node2.checkpointStorage.checkpoints().size) // confirm checkpoint
}
val node2b = net.createNode(node1.info.address, node2.id, advertisedServices = *node2.advertisedServices.toTypedArray()) val node2b = net.createNode(node1.info.address, node2.id, advertisedServices = *node2.advertisedServices.toTypedArray())
node2.manuallyCloseDB()
val (firstAgain, fut1) = node2b.getSingleProtocol<PingPongProtocol>() val (firstAgain, fut1) = node2b.getSingleProtocol<PingPongProtocol>()
// Run the network which will also fire up the second protocol. First message should get deduped. So message data stays in sync. // Run the network which will also fire up the second protocol. First message should get deduped. So message data stays in sync.
net.runNetwork() net.runNetwork()
assertEquals(1, node2.checkpointStorage.checkpoints().size)
node2b.smm.executor.flush() node2b.smm.executor.flush()
fut1.get() fut1.get()
@ -156,8 +167,12 @@ class StateMachineManagerTests {
assertEquals(4, receivedCount, "Protocol should have exchanged 4 unique messages")// Two messages each way assertEquals(4, receivedCount, "Protocol should have exchanged 4 unique messages")// Two messages each way
// can't give a precise value as every addMessageHandler re-runs the undelivered messages // can't give a precise value as every addMessageHandler re-runs the undelivered messages
assertTrue(sentCount > receivedCount, "Node restart should have retransmitted messages") assertTrue(sentCount > receivedCount, "Node restart should have retransmitted messages")
databaseTransaction(node2b.database) {
assertEquals(0, node2b.checkpointStorage.checkpoints().size, "Checkpoints left after restored protocol should have ended") assertEquals(0, node2b.checkpointStorage.checkpoints().size, "Checkpoints left after restored protocol should have ended")
}
databaseTransaction(node3.database) {
assertEquals(0, node3.checkpointStorage.checkpoints().size, "Checkpoints left after restored protocol should have ended") assertEquals(0, node3.checkpointStorage.checkpoints().size, "Checkpoints left after restored protocol should have ended")
}
assertEquals(payload2, firstAgain.receivedPayload, "Received payload does not match the first value on Node 3") assertEquals(payload2, firstAgain.receivedPayload, "Received payload does not match the first value on Node 3")
assertEquals(payload2 + 1, firstAgain.receivedPayload2, "Received payload does not match the expected second value on Node 3") assertEquals(payload2 + 1, firstAgain.receivedPayload2, "Received payload does not match the expected second value on Node 3")
assertEquals(payload, secondProtocol.get().receivedPayload, "Received payload does not match the (restarted) first value on Node 2") assertEquals(payload, secondProtocol.get().receivedPayload, "Received payload does not match the (restarted) first value on Node 2")
@ -253,8 +268,10 @@ class StateMachineManagerTests {
private inline fun <reified P : ProtocolLogic<*>> MockNode.restartAndGetRestoredProtocol( private inline fun <reified P : ProtocolLogic<*>> MockNode.restartAndGetRestoredProtocol(
networkMapNode: MockNode? = null): P { networkMapNode: MockNode? = null): P {
disableDBCloseOnStop() //Handover DB to new node copy
stop() stop()
val newNode = mockNet.createNode(networkMapNode?.info?.address, id, advertisedServices = *advertisedServices.toTypedArray()) val newNode = mockNet.createNode(networkMapNode?.info?.address, id, advertisedServices = *advertisedServices.toTypedArray())
manuallyCloseDB()
mockNet.runNetwork() // allow NetworkMapService messages to stabilise and thus start the state machine mockNet.runNetwork() // allow NetworkMapService messages to stabilise and thus start the state machine
return newNode.getSingleProtocol<P>().first return newNode.getSingleProtocol<P>().first
} }

View File

@ -9,7 +9,10 @@ import com.r3corda.core.serialization.SingletonSerializeAsToken
import com.r3corda.core.utilities.trace import com.r3corda.core.utilities.trace
import com.r3corda.node.services.api.MessagingServiceBuilder import com.r3corda.node.services.api.MessagingServiceBuilder
import com.r3corda.node.utilities.AffinityExecutor import com.r3corda.node.utilities.AffinityExecutor
import com.r3corda.node.utilities.JDBCHashSet
import com.r3corda.node.utilities.databaseTransaction
import com.r3corda.testing.node.InMemoryMessagingNetwork.InMemoryMessaging import com.r3corda.testing.node.InMemoryMessagingNetwork.InMemoryMessaging
import org.jetbrains.exposed.sql.Database
import org.slf4j.LoggerFactory import org.slf4j.LoggerFactory
import rx.Observable import rx.Observable
import rx.subjects.PublishSubject import rx.subjects.PublishSubject
@ -80,10 +83,10 @@ class InMemoryMessagingNetwork(val sendManuallyPumped: Boolean) : SingletonSeria
@Synchronized @Synchronized
fun createNode(manuallyPumped: Boolean, fun createNode(manuallyPumped: Boolean,
executor: AffinityExecutor, executor: AffinityExecutor,
persistenceTx: (() -> Unit) -> Unit) database: Database)
: Pair<Handle, com.r3corda.node.services.api.MessagingServiceBuilder<InMemoryMessaging>> { : Pair<Handle, com.r3corda.node.services.api.MessagingServiceBuilder<InMemoryMessaging>> {
check(counter >= 0) { "In memory network stopped: please recreate." } check(counter >= 0) { "In memory network stopped: please recreate." }
val builder = createNodeWithID(manuallyPumped, counter, executor, persistenceTx = persistenceTx) as Builder val builder = createNodeWithID(manuallyPumped, counter, executor, database = database) as Builder
counter++ counter++
val id = builder.id val id = builder.id
return Pair(id, builder) return Pair(id, builder)
@ -98,9 +101,9 @@ class InMemoryMessagingNetwork(val sendManuallyPumped: Boolean) : SingletonSeria
* @param persistenceTx a lambda to wrap message handling in a transaction if necessary. * @param persistenceTx a lambda to wrap message handling in a transaction if necessary.
*/ */
fun createNodeWithID(manuallyPumped: Boolean, id: Int, executor: AffinityExecutor, description: String? = null, fun createNodeWithID(manuallyPumped: Boolean, id: Int, executor: AffinityExecutor, description: String? = null,
persistenceTx: (() -> Unit) -> Unit) database: Database)
: MessagingServiceBuilder<InMemoryMessaging> { : MessagingServiceBuilder<InMemoryMessaging> {
return Builder(manuallyPumped, Handle(id, description ?: "In memory node $id"), executor, persistenceTx) return Builder(manuallyPumped, Handle(id, description ?: "In memory node $id"), executor, database = database)
} }
interface LatencyCalculator { interface LatencyCalculator {
@ -140,11 +143,11 @@ class InMemoryMessagingNetwork(val sendManuallyPumped: Boolean) : SingletonSeria
messageReceiveQueues.clear() messageReceiveQueues.clear()
} }
inner class Builder(val manuallyPumped: Boolean, val id: Handle, val executor: AffinityExecutor, val persistenceTx: (() -> Unit) -> Unit) inner class Builder(val manuallyPumped: Boolean, val id: Handle, val executor: AffinityExecutor, val database: Database)
: com.r3corda.node.services.api.MessagingServiceBuilder<InMemoryMessaging> { : com.r3corda.node.services.api.MessagingServiceBuilder<InMemoryMessaging> {
override fun start(): ListenableFuture<InMemoryMessaging> { override fun start(): ListenableFuture<InMemoryMessaging> {
synchronized(this@InMemoryMessagingNetwork) { synchronized(this@InMemoryMessagingNetwork) {
val node = InMemoryMessaging(manuallyPumped, id, executor, persistenceTx) val node = InMemoryMessaging(manuallyPumped, id, executor, database)
handleEndpointMap[id] = node handleEndpointMap[id] = node
return Futures.immediateFuture(node) return Futures.immediateFuture(node)
} }
@ -208,7 +211,7 @@ class InMemoryMessagingNetwork(val sendManuallyPumped: Boolean) : SingletonSeria
inner class InMemoryMessaging(private val manuallyPumped: Boolean, inner class InMemoryMessaging(private val manuallyPumped: Boolean,
private val handle: Handle, private val handle: Handle,
private val executor: AffinityExecutor, private val executor: AffinityExecutor,
private val persistenceTx: (() -> Unit) -> Unit) private val database: Database)
: SingletonSerializeAsToken(), com.r3corda.node.services.api.MessagingServiceInternal { : SingletonSerializeAsToken(), com.r3corda.node.services.api.MessagingServiceInternal {
inner class Handler(val topicSession: TopicSession, inner class Handler(val topicSession: TopicSession,
val callback: (Message, MessageHandlerRegistration) -> Unit) : MessageHandlerRegistration val callback: (Message, MessageHandlerRegistration) -> Unit) : MessageHandlerRegistration
@ -218,7 +221,7 @@ class InMemoryMessagingNetwork(val sendManuallyPumped: Boolean) : SingletonSeria
private inner class InnerState { private inner class InnerState {
val handlers: MutableList<Handler> = ArrayList() val handlers: MutableList<Handler> = ArrayList()
val pendingRedelivery = LinkedList<MessageTransfer>() val pendingRedelivery = JDBCHashSet<Message>("pending_messages",loadOnInit = true)
} }
private val state = ThreadBox(InnerState()) private val state = ThreadBox(InnerState())
@ -244,11 +247,14 @@ class InMemoryMessagingNetwork(val sendManuallyPumped: Boolean) : SingletonSeria
check(running) check(running)
val (handler, items) = state.locked { val (handler, items) = state.locked {
val handler = Handler(topicSession, callback).apply { handlers.add(this) } val handler = Handler(topicSession, callback).apply { handlers.add(this) }
val items = ArrayList(pendingRedelivery) val pending = ArrayList<Message>()
databaseTransaction(database) {
pending.addAll(pendingRedelivery)
pendingRedelivery.clear() pendingRedelivery.clear()
Pair(handler, items)
} }
for ((sender, message) in items) { Pair(handler, pending)
}
for (message in items) {
send(message, handle) send(message, handle)
} }
return handler return handler
@ -328,7 +334,9 @@ class InMemoryMessagingNetwork(val sendManuallyPumped: Boolean) : SingletonSeria
// up a handler for yet. Most unit tests don't run threaded, but we want to test true parallelism at // up a handler for yet. Most unit tests don't run threaded, but we want to test true parallelism at
// least sometimes. // least sometimes.
log.warn("Message to ${transfer.message.topicSession} could not be delivered") log.warn("Message to ${transfer.message.topicSession} could not be delivered")
pendingRedelivery.add(transfer) databaseTransaction(database) {
pendingRedelivery.add(transfer.message)
}
null null
} else { } else {
h h
@ -348,7 +356,7 @@ class InMemoryMessagingNetwork(val sendManuallyPumped: Boolean) : SingletonSeria
if (transfer.message.uniqueMessageId !in processedMessages) { if (transfer.message.uniqueMessageId !in processedMessages) {
executor.execute { executor.execute {
persistenceTx { databaseTransaction(database) {
for (handler in deliverTo) { for (handler in deliverTo) {
try { try {
handler.callback(transfer.message, handler) handler.callback(transfer.message, handler)

View File

@ -4,7 +4,6 @@ import com.google.common.jimfs.Configuration.unix
import com.google.common.jimfs.Jimfs import com.google.common.jimfs.Jimfs
import com.google.common.util.concurrent.Futures import com.google.common.util.concurrent.Futures
import com.r3corda.core.crypto.Party import com.r3corda.core.crypto.Party
import com.r3corda.core.div
import com.r3corda.core.messaging.SingleMessageRecipient import com.r3corda.core.messaging.SingleMessageRecipient
import com.r3corda.core.node.PhysicalLocation import com.r3corda.core.node.PhysicalLocation
import com.r3corda.core.node.services.KeyManagementService import com.r3corda.core.node.services.KeyManagementService
@ -15,21 +14,17 @@ import com.r3corda.core.testing.InMemoryVaultService
import com.r3corda.core.utilities.DUMMY_NOTARY_KEY import com.r3corda.core.utilities.DUMMY_NOTARY_KEY
import com.r3corda.core.utilities.loggerFor import com.r3corda.core.utilities.loggerFor
import com.r3corda.node.internal.AbstractNode import com.r3corda.node.internal.AbstractNode
import com.r3corda.node.services.api.CheckpointStorage
import com.r3corda.node.services.api.MessagingServiceInternal import com.r3corda.node.services.api.MessagingServiceInternal
import com.r3corda.node.services.config.NodeConfiguration import com.r3corda.node.services.config.NodeConfiguration
import com.r3corda.node.services.keys.E2ETestKeyManagementService import com.r3corda.node.services.keys.E2ETestKeyManagementService
import com.r3corda.node.services.messaging.CordaRPCOps import com.r3corda.node.services.messaging.CordaRPCOps
import com.r3corda.node.services.network.InMemoryNetworkMapService import com.r3corda.node.services.network.InMemoryNetworkMapService
import com.r3corda.node.services.network.NetworkMapService import com.r3corda.node.services.network.NetworkMapService
import com.r3corda.node.services.persistence.DBCheckpointStorage
import com.r3corda.node.services.persistence.PerFileCheckpointStorage
import com.r3corda.node.services.transactions.InMemoryUniquenessProvider import com.r3corda.node.services.transactions.InMemoryUniquenessProvider
import com.r3corda.node.services.transactions.SimpleNotaryService import com.r3corda.node.services.transactions.SimpleNotaryService
import com.r3corda.node.services.transactions.ValidatingNotaryService import com.r3corda.node.services.transactions.ValidatingNotaryService
import com.r3corda.node.utilities.AffinityExecutor import com.r3corda.node.utilities.AffinityExecutor
import com.r3corda.node.utilities.AffinityExecutor.ServiceAffinityExecutor import com.r3corda.node.utilities.AffinityExecutor.ServiceAffinityExecutor
import com.r3corda.node.utilities.databaseTransaction
import org.slf4j.Logger import org.slf4j.Logger
import java.nio.file.FileSystem import java.nio.file.FileSystem
import java.nio.file.Files import java.nio.file.Files
@ -125,16 +120,7 @@ class MockNetwork(private val networkSendManuallyPumped: Boolean = false,
// through the java.nio API which we are already mocking via Jimfs. // through the java.nio API which we are already mocking via Jimfs.
override fun makeMessagingService(): MessagingServiceInternal { override fun makeMessagingService(): MessagingServiceInternal {
require(id >= 0) { "Node ID must be zero or positive, was passed: " + id } require(id >= 0) { "Node ID must be zero or positive, was passed: " + id }
return mockNet.messagingNetwork.createNodeWithID(!mockNet.threadPerNode, id, serverThread, configuration.myLegalName, return mockNet.messagingNetwork.createNodeWithID(!mockNet.threadPerNode, id, serverThread, configuration.myLegalName, database).start().get()
persistenceTx = { body: () -> Unit -> databaseTransaction(database) { body() } }).start().get()
}
override fun initialiseCheckpointService(dir: Path): CheckpointStorage {
return if (mockNet.threadPerNode) {
DBCheckpointStorage()
} else {
PerFileCheckpointStorage(dir / "checkpoints")
}
} }
override fun makeIdentityService() = MockIdentityService(mockNet.identities) override fun makeIdentityService() = MockIdentityService(mockNet.identities)
@ -143,7 +129,7 @@ class MockNetwork(private val networkSendManuallyPumped: Boolean = false,
override fun makeKeyManagementService(): KeyManagementService = E2ETestKeyManagementService(partyKeys) override fun makeKeyManagementService(): KeyManagementService = E2ETestKeyManagementService(partyKeys)
override fun startMessagingService(cordaRPCOps: CordaRPCOps?) { override fun startMessagingService(cordaRPCOps: CordaRPCOps) {
// Nothing to do // Nothing to do
} }