Merge branches 'colljos-vault-transaction-notes' and 'master' of https://bitbucket.org/R3-CEV/r3prototyping into colljos-vault-transaction-notes

This commit is contained in:
Jose Coll 2016-10-27 14:59:37 +01:00
commit f2e98ffba5
47 changed files with 671 additions and 2619 deletions

View File

@ -110,7 +110,6 @@ dependencies {
compile "org.jetbrains.kotlin:kotlin-reflect:$kotlin_version"
compile "org.jetbrains.kotlin:kotlin-test:$kotlin_version"
compile "org.jetbrains.kotlinx:kotlinx-support-jdk8:0.2"
compile 'co.paralleluniverse:capsule:1.0.3'
// Unit testing helpers.
testCompile 'junit:junit:4.12'
@ -193,7 +192,7 @@ applicationDistribution.into("bin") {
task buildCordaJAR(type: FatCapsule, dependsOn: ['quasarScan', 'buildCertSigningRequestUtilityJAR']) {
applicationClass 'com.r3corda.node.MainKt'
archiveName 'corda.jar'
applicationSource = files(project.tasks.findByName('jar'), 'build/classes/main/CordaCaplet.class')
applicationSource = files(project.tasks.findByName('jar'), 'node/build/classes/main/CordaCaplet.class')
capsuleManifest {
appClassPath = ["jolokia-agent-war-${project.ext.jolokia_version}.war"]

View File

@ -0,0 +1,62 @@
package com.r3corda.client
import com.r3corda.core.random63BitValue
import com.r3corda.node.driver.driver
import com.r3corda.node.services.config.configureTestSSL
import com.r3corda.node.services.messaging.ArtemisMessagingComponent.Companion.toHostAndPort
import org.apache.activemq.artemis.api.core.ActiveMQSecurityException
import org.assertj.core.api.Assertions.assertThatExceptionOfType
import org.junit.After
import org.junit.Before
import org.junit.Test
import java.util.concurrent.CountDownLatch
import kotlin.concurrent.thread
class CordaRPCClientTest {
private val validUsername = "user1"
private val validPassword = "test"
private val stopDriver = CountDownLatch(1)
private var driverThread: Thread? = null
private lateinit var client: CordaRPCClient
@Before
fun start() {
val driverStarted = CountDownLatch(1)
driverThread = thread {
driver {
val driverInfo = startNode().get()
client = CordaRPCClient(toHostAndPort(driverInfo.nodeInfo.address), configureTestSSL())
driverStarted.countDown()
stopDriver.await()
}
}
driverStarted.await()
}
@After
fun stop() {
stopDriver.countDown()
driverThread?.join()
}
@Test
fun `log in with valid username and password`() {
client.start(validUsername, validPassword)
}
@Test
fun `log in with unknown user`() {
assertThatExceptionOfType(ActiveMQSecurityException::class.java).isThrownBy {
client.start(random63BitValue().toString(), validPassword)
}
}
@Test
fun `log in with incorrect password`() {
assertThatExceptionOfType(ActiveMQSecurityException::class.java).isThrownBy {
client.start(validUsername, random63BitValue().toString())
}
}
}

View File

@ -1,6 +1,5 @@
package com.r3corda.client
import com.google.common.util.concurrent.SettableFuture
import com.r3corda.client.model.NodeMonitorModel
import com.r3corda.client.model.ProgressTrackingEvent
import com.r3corda.core.bufferUntilSubscribed
@ -14,8 +13,7 @@ import com.r3corda.core.protocols.StateMachineRunId
import com.r3corda.core.serialization.OpaqueBytes
import com.r3corda.core.transactions.SignedTransaction
import com.r3corda.node.driver.driver
import com.r3corda.node.driver.startClient
import com.r3corda.node.services.messaging.NodeMessagingClient
import com.r3corda.node.services.config.configureTestSSL
import com.r3corda.node.services.messaging.StateMachineUpdate
import com.r3corda.node.services.transactions.SimpleNotaryService
import com.r3corda.testing.expect
@ -26,16 +24,15 @@ import org.junit.Before
import org.junit.Test
import rx.Observable
import rx.Observer
import java.util.concurrent.CountDownLatch
import kotlin.concurrent.thread
class NodeMonitorModelTest {
lateinit var aliceNode: NodeInfo
lateinit var notaryNode: NodeInfo
lateinit var aliceClient: NodeMessagingClient
val driverStarted = SettableFuture.create<Unit>()
val stopDriver = SettableFuture.create<Unit>()
val driverStopped = SettableFuture.create<Unit>()
val stopDriver = CountDownLatch(1)
var driverThread: Thread? = null
lateinit var stateMachineTransactionMapping: Observable<StateMachineTransactionMapping>
lateinit var stateMachineUpdates: Observable<StateMachineUpdate>
@ -48,15 +45,15 @@ class NodeMonitorModelTest {
@Before
fun start() {
thread {
val driverStarted = CountDownLatch(1)
driverThread = thread {
driver {
val aliceNodeFuture = startNode("Alice")
val notaryNodeFuture = startNode("Notary", advertisedServices = setOf(ServiceInfo(SimpleNotaryService.type)))
aliceNode = aliceNodeFuture.get()
notaryNode = notaryNodeFuture.get()
aliceClient = startClient(aliceNode).get()
newNode = { nodeName -> startNode(nodeName).get() }
aliceNode = aliceNodeFuture.get().nodeInfo
notaryNode = notaryNodeFuture.get().nodeInfo
newNode = { nodeName -> startNode(nodeName).get().nodeInfo }
val monitor = NodeMonitorModel()
stateMachineTransactionMapping = monitor.stateMachineTransactionMapping.bufferUntilSubscribed()
@ -67,20 +64,18 @@ class NodeMonitorModelTest {
networkMapUpdates = monitor.networkMap.bufferUntilSubscribed()
clientToService = monitor.clientToService
monitor.register(aliceNode, aliceClient.config.certificatesPath)
driverStarted.set(Unit)
stopDriver.get()
monitor.register(aliceNode, configureTestSSL(), "user1", "test")
driverStarted.countDown()
stopDriver.await()
}
driverStopped.set(Unit)
}
driverStarted.get()
driverStarted.await()
}
@After
fun stop() {
stopDriver.set(Unit)
driverStopped.get()
stopDriver.countDown()
driverThread?.join()
}
@Test

View File

@ -24,19 +24,12 @@ import kotlin.concurrent.thread
* useful tasks. See the documentation for [proxy] or review the docsite to learn more about how this API works.
*/
@ThreadSafe
class CordaRPCClient(val host: HostAndPort, certificatesPath: Path) : Closeable, ArtemisMessagingComponent(sslConfig(certificatesPath)) {
class CordaRPCClient(val host: HostAndPort, override val config: NodeSSLConfiguration) : Closeable, ArtemisMessagingComponent() {
companion object {
private val rpcLog = LoggerFactory.getLogger("com.r3corda.rpc")
private fun sslConfig(certificatesPath: Path): NodeSSLConfiguration = object : NodeSSLConfiguration {
override val certificatesPath: Path = certificatesPath
override val keyStorePassword = "cordacadevpass"
override val trustStorePassword = "trustpass"
}
}
// TODO: Certificate handling for clients needs more work.
private inner class State {
var running = false
lateinit var sessionFactory: ClientSessionFactory
@ -57,7 +50,7 @@ class CordaRPCClient(val host: HostAndPort, certificatesPath: Path) : Closeable,
/** Opens the connection to the server and registers a JVM shutdown hook to cleanly disconnect. */
@Throws(ActiveMQNotConnectedException::class)
fun start() {
fun start(username: String, password: String) {
state.locked {
check(!running)
checkStorePasswords() // Check the password.
@ -66,7 +59,7 @@ class CordaRPCClient(val host: HostAndPort, certificatesPath: Path) : Closeable,
sessionFactory = serverLocator.createSessionFactory()
// We use our initial connection ID as the queue namespace.
myID = sessionFactory.connection.id as Int and 0x000000FFFFFF
session = sessionFactory.createSession()
session = sessionFactory.createSession(username, password, false, true, true, serverLocator.isPreAcknowledge, serverLocator.ackBatchSize)
session.start()
clientImpl = CordaRPCClientImpl(session, state.lock, myAddressPrefix)
running = true

View File

@ -8,14 +8,14 @@ import com.r3corda.core.node.services.StateMachineTransactionMapping
import com.r3corda.core.node.services.Vault
import com.r3corda.core.protocols.StateMachineRunId
import com.r3corda.core.transactions.SignedTransaction
import com.r3corda.node.services.messaging.ArtemisMessagingComponent
import com.r3corda.node.services.config.NodeSSLConfiguration
import com.r3corda.node.services.messaging.ArtemisMessagingComponent.Companion.toHostAndPort
import com.r3corda.node.services.messaging.CordaRPCOps
import com.r3corda.node.services.messaging.StateMachineInfo
import com.r3corda.node.services.messaging.StateMachineUpdate
import javafx.beans.property.SimpleObjectProperty
import rx.Observable
import rx.subjects.PublishSubject
import java.nio.file.Path
data class ProgressTrackingEvent(val stateMachineId: StateMachineRunId, val message: String) {
companion object {
@ -54,14 +54,11 @@ class NodeMonitorModel {
/**
* Register for updates to/from a given vault.
* @param messagingService The messaging to use for communication.
* @param monitorNodeInfo the [Node] to connect to.
* TODO provide an unsubscribe mechanism
*/
fun register(vaultMonitorNodeInfo: NodeInfo, certificatesPath: Path) {
val client = CordaRPCClient(ArtemisMessagingComponent.toHostAndPort(vaultMonitorNodeInfo.address), certificatesPath)
client.start()
fun register(vaultMonitorNodeInfo: NodeInfo, sslConfig: NodeSSLConfiguration, username: String, password: String) {
val client = CordaRPCClient(toHostAndPort(vaultMonitorNodeInfo.address), sslConfig)
client.start(username, password)
val proxy = client.proxy()
val (stateMachines, stateMachineUpdates) = proxy.stateMachinesAndUpdates()

View File

@ -1,801 +0,0 @@
package com.r3corda.contracts
import com.r3corda.core.contracts.*
import com.r3corda.core.contracts.clauses.*
import com.r3corda.core.crypto.Party
import com.r3corda.core.crypto.SecureHash
import com.r3corda.core.node.services.ServiceType
import com.r3corda.core.protocols.ProtocolLogicRefFactory
import com.r3corda.core.transactions.TransactionBuilder
import com.r3corda.core.utilities.suggestInterestRateAnnouncementTimeWindow
import com.r3corda.protocols.TwoPartyDealProtocol
import org.apache.commons.jexl3.JexlBuilder
import org.apache.commons.jexl3.MapContext
import java.math.BigDecimal
import java.math.RoundingMode
import java.security.PublicKey
import java.time.LocalDate
import java.util.*
val IRS_PROGRAM_ID = InterestRateSwap()
// This is a placeholder for some types that we haven't identified exactly what they are just yet for things still in discussion
open class UnknownType() {
override fun equals(other: Any?): Boolean {
return (other is UnknownType)
}
override fun hashCode() = 1
}
/**
* Event superclass - everything happens on a date.
*/
open class Event(val date: LocalDate) {
override fun equals(other: Any?): Boolean {
if (this === other) return true
if (other !is Event) return false
if (date != other.date) return false
return true
}
override fun hashCode() = Objects.hash(date)
}
/**
* Top level PaymentEvent class - represents an obligation to pay an amount on a given date, which may be either in the past or the future.
*/
abstract class PaymentEvent(date: LocalDate) : Event(date) {
abstract fun calculate(): Amount<Currency>
}
/**
* A [RatePaymentEvent] represents a dated obligation of payment.
* It is a specialisation / modification of a basic cash flow event (to be written) that has some additional assistance
* functions for interest rate swap legs of the fixed and floating nature.
* For the fixed leg, the rate is already known at creation and therefore the flows can be pre-determined.
* For the floating leg, the rate refers to a reference rate which is to be "fixed" at a point in the future.
*/
abstract class RatePaymentEvent(date: LocalDate,
val accrualStartDate: LocalDate,
val accrualEndDate: LocalDate,
val dayCountBasisDay: DayCountBasisDay,
val dayCountBasisYear: DayCountBasisYear,
val notional: Amount<Currency>,
val rate: Rate) : PaymentEvent(date) {
companion object {
val CSVHeader = "AccrualStartDate,AccrualEndDate,DayCountFactor,Days,Date,Ccy,Notional,Rate,Flow"
}
override fun calculate(): Amount<Currency> = flow
abstract val flow: Amount<Currency>
val days: Int get() = calculateDaysBetween(accrualStartDate, accrualEndDate, dayCountBasisYear, dayCountBasisDay)
// TODO : Fix below (use daycount convention for division, not hardcoded 360 etc)
val dayCountFactor: BigDecimal get() = (BigDecimal(days).divide(BigDecimal(360.0), 8, RoundingMode.HALF_UP)).setScale(4, RoundingMode.HALF_UP)
open fun asCSV() = "$accrualStartDate,$accrualEndDate,$dayCountFactor,$days,$date,${notional.token},$notional,$rate,$flow"
override fun equals(other: Any?): Boolean {
if (this === other) return true
if (other !is RatePaymentEvent) return false
if (accrualStartDate != other.accrualStartDate) return false
if (accrualEndDate != other.accrualEndDate) return false
if (dayCountBasisDay != other.dayCountBasisDay) return false
if (dayCountBasisYear != other.dayCountBasisYear) return false
if (notional != other.notional) return false
if (rate != other.rate) return false
// if (flow != other.flow) return false // Flow is derived
return super.equals(other)
}
override fun hashCode() = super.hashCode() + 31 * Objects.hash(accrualStartDate, accrualEndDate, dayCountBasisDay,
dayCountBasisYear, notional, rate)
}
/**
* Basic class for the Fixed Rate Payments on the fixed leg - see [RatePaymentEvent].
* Assumes that the rate is valid.
*/
class FixedRatePaymentEvent(date: LocalDate,
accrualStartDate: LocalDate,
accrualEndDate: LocalDate,
dayCountBasisDay: DayCountBasisDay,
dayCountBasisYear: DayCountBasisYear,
notional: Amount<Currency>,
rate: Rate) :
RatePaymentEvent(date, accrualStartDate, accrualEndDate, dayCountBasisDay, dayCountBasisYear, notional, rate) {
companion object {
val CSVHeader = RatePaymentEvent.CSVHeader
}
override val flow: Amount<Currency> get() = Amount(dayCountFactor.times(BigDecimal(notional.quantity)).times(rate.ratioUnit!!.value).toLong(), notional.token)
override fun toString(): String =
"FixedRatePaymentEvent $accrualStartDate -> $accrualEndDate : $dayCountFactor : $days : $date : $notional : $rate : $flow"
}
/**
* Basic class for the Floating Rate Payments on the floating leg - see [RatePaymentEvent].
* If the rate is null returns a zero payment. // TODO: Is this the desired behaviour?
*/
class FloatingRatePaymentEvent(date: LocalDate,
accrualStartDate: LocalDate,
accrualEndDate: LocalDate,
dayCountBasisDay: DayCountBasisDay,
dayCountBasisYear: DayCountBasisYear,
val fixingDate: LocalDate,
notional: Amount<Currency>,
rate: Rate) : RatePaymentEvent(date, accrualStartDate, accrualEndDate, dayCountBasisDay, dayCountBasisYear, notional, rate) {
companion object {
val CSVHeader = RatePaymentEvent.CSVHeader + ",FixingDate"
}
override val flow: Amount<Currency> get() {
// TODO: Should an uncalculated amount return a zero ? null ? etc.
val v = rate.ratioUnit?.value ?: return Amount(0, notional.token)
return Amount(dayCountFactor.times(BigDecimal(notional.quantity)).times(v).toLong(), notional.token)
}
override fun toString(): String = "FloatingPaymentEvent $accrualStartDate -> $accrualEndDate : $dayCountFactor : $days : $date : $notional : $rate (fix on $fixingDate): $flow"
override fun asCSV(): String = "$accrualStartDate,$accrualEndDate,$dayCountFactor,$days,$date,${notional.token},$notional,$fixingDate,$rate,$flow"
/**
* Used for making immutables.
*/
fun withNewRate(newRate: Rate): FloatingRatePaymentEvent =
FloatingRatePaymentEvent(date, accrualStartDate, accrualEndDate, dayCountBasisDay,
dayCountBasisYear, fixingDate, notional, newRate)
override fun equals(other: Any?): Boolean {
if (this === other) return true
if (other?.javaClass != javaClass) return false
other as FloatingRatePaymentEvent
if (fixingDate != other.fixingDate) return false
return super.equals(other)
}
override fun hashCode() = super.hashCode() + 31 * Objects.hash(fixingDate)
// Can't autogenerate as not a data class :-(
fun copy(date: LocalDate = this.date,
accrualStartDate: LocalDate = this.accrualStartDate,
accrualEndDate: LocalDate = this.accrualEndDate,
dayCountBasisDay: DayCountBasisDay = this.dayCountBasisDay,
dayCountBasisYear: DayCountBasisYear = this.dayCountBasisYear,
fixingDate: LocalDate = this.fixingDate,
notional: Amount<Currency> = this.notional,
rate: Rate = this.rate) = FloatingRatePaymentEvent(date, accrualStartDate, accrualEndDate, dayCountBasisDay, dayCountBasisYear, fixingDate, notional, rate)
}
/**
* The Interest Rate Swap class. For a quick overview of what an IRS is, see here - http://www.pimco.co.uk/EN/Education/Pages/InterestRateSwapsBasics1-08.aspx (no endorsement).
* This contract has 4 significant data classes within it, the "Common", "Calculation", "FixedLeg" and "FloatingLeg".
* It also has 4 commands, "Agree", "Fix", "Pay" and "Mature".
* Currently, we are not interested (excuse pun) in valuing the swap, calculating the PVs, DFs and all that good stuff (soon though).
* This is just a representation of a vanilla Fixed vs Floating (same currency) IRS in the R3 prototype model.
*/
class InterestRateSwap() : Contract {
override val legalContractReference = SecureHash.sha256("is_this_the_text_of_the_contract ? TBD")
companion object {
val oracleType = ServiceType.corda.getSubType("interest_rates")
}
/**
* This Common area contains all the information that is not leg specific.
*/
data class Common(
val baseCurrency: Currency,
val eligibleCurrency: Currency,
val eligibleCreditSupport: String,
val independentAmounts: Amount<Currency>,
val threshold: Amount<Currency>,
val minimumTransferAmount: Amount<Currency>,
val rounding: Amount<Currency>,
val valuationDateDescription: String, // This describes (in english) how regularly the swap is to be valued, e.g. "every local working day"
val notificationTime: String,
val resolutionTime: String,
val interestRate: ReferenceRate,
val addressForTransfers: String,
val exposure: UnknownType,
val localBusinessDay: BusinessCalendar,
val dailyInterestAmount: Expression,
val tradeID: String,
val hashLegalDocs: String
)
/**
* The Calculation data class is "mutable" through out the life of the swap, as in, it's the only thing that contains
* data that will changed from state to state (Recall that the design insists that everything is immutable, so we actually
* copy / update for each transition).
*/
data class Calculation(
val expression: Expression,
val floatingLegPaymentSchedule: Map<LocalDate, FloatingRatePaymentEvent>,
val fixedLegPaymentSchedule: Map<LocalDate, FixedRatePaymentEvent>
) {
/**
* Gets the date of the next fixing.
* @return LocalDate or null if no more fixings.
*/
fun nextFixingDate(): LocalDate? {
return floatingLegPaymentSchedule.
filter { it.value.rate is ReferenceRate }.// TODO - a better way to determine what fixings remain to be fixed
minBy { it.value.fixingDate.toEpochDay() }?.value?.fixingDate
}
/**
* Returns the fixing for that date.
*/
fun getFixing(date: LocalDate): FloatingRatePaymentEvent =
floatingLegPaymentSchedule.values.single { it.fixingDate == date }
/**
* Returns a copy after modifying (applying) the fixing for that date.
*/
fun applyFixing(date: LocalDate, newRate: FixedRate): Calculation {
val paymentEvent = getFixing(date)
val newFloatingLPS = floatingLegPaymentSchedule + (paymentEvent.date to paymentEvent.withNewRate(newRate))
return Calculation(expression = expression,
floatingLegPaymentSchedule = newFloatingLPS,
fixedLegPaymentSchedule = fixedLegPaymentSchedule)
}
}
abstract class CommonLeg(
val notional: Amount<Currency>,
val paymentFrequency: Frequency,
val effectiveDate: LocalDate,
val effectiveDateAdjustment: DateRollConvention?,
val terminationDate: LocalDate,
val terminationDateAdjustment: DateRollConvention?,
val dayCountBasisDay: DayCountBasisDay,
val dayCountBasisYear: DayCountBasisYear,
val dayInMonth: Int,
val paymentRule: PaymentRule,
val paymentDelay: Int,
val paymentCalendar: BusinessCalendar,
val interestPeriodAdjustment: AccrualAdjustment
) {
override fun toString(): String {
return "Notional=$notional,PaymentFrequency=$paymentFrequency,EffectiveDate=$effectiveDate,EffectiveDateAdjustment:$effectiveDateAdjustment,TerminatationDate=$terminationDate," +
"TerminationDateAdjustment=$terminationDateAdjustment,DayCountBasis=$dayCountBasisDay/$dayCountBasisYear,DayInMonth=$dayInMonth," +
"PaymentRule=$paymentRule,PaymentDelay=$paymentDelay,PaymentCalendar=$paymentCalendar,InterestPeriodAdjustment=$interestPeriodAdjustment"
}
override fun equals(other: Any?): Boolean {
if (this === other) return true
if (other?.javaClass != javaClass) return false
other as CommonLeg
if (notional != other.notional) return false
if (paymentFrequency != other.paymentFrequency) return false
if (effectiveDate != other.effectiveDate) return false
if (effectiveDateAdjustment != other.effectiveDateAdjustment) return false
if (terminationDate != other.terminationDate) return false
if (terminationDateAdjustment != other.terminationDateAdjustment) return false
if (dayCountBasisDay != other.dayCountBasisDay) return false
if (dayCountBasisYear != other.dayCountBasisYear) return false
if (dayInMonth != other.dayInMonth) return false
if (paymentRule != other.paymentRule) return false
if (paymentDelay != other.paymentDelay) return false
if (paymentCalendar != other.paymentCalendar) return false
if (interestPeriodAdjustment != other.interestPeriodAdjustment) return false
return true
}
override fun hashCode() = super.hashCode() + 31 * Objects.hash(notional, paymentFrequency, effectiveDate,
effectiveDateAdjustment, terminationDate, effectiveDateAdjustment, terminationDate, terminationDateAdjustment,
dayCountBasisDay, dayCountBasisYear, dayInMonth, paymentRule, paymentDelay, paymentCalendar, interestPeriodAdjustment)
}
open class FixedLeg(
var fixedRatePayer: Party,
notional: Amount<Currency>,
paymentFrequency: Frequency,
effectiveDate: LocalDate,
effectiveDateAdjustment: DateRollConvention?,
terminationDate: LocalDate,
terminationDateAdjustment: DateRollConvention?,
dayCountBasisDay: DayCountBasisDay,
dayCountBasisYear: DayCountBasisYear,
dayInMonth: Int,
paymentRule: PaymentRule,
paymentDelay: Int,
paymentCalendar: BusinessCalendar,
interestPeriodAdjustment: AccrualAdjustment,
var fixedRate: FixedRate,
var rollConvention: DateRollConvention // TODO - best way of implementing - still awaiting some clarity
) : CommonLeg
(notional, paymentFrequency, effectiveDate, effectiveDateAdjustment, terminationDate, terminationDateAdjustment,
dayCountBasisDay, dayCountBasisYear, dayInMonth, paymentRule, paymentDelay, paymentCalendar, interestPeriodAdjustment) {
override fun toString(): String = "FixedLeg(Payer=$fixedRatePayer," + super.toString() + ",fixedRate=$fixedRate," +
"rollConvention=$rollConvention"
override fun equals(other: Any?): Boolean {
if (this === other) return true
if (other?.javaClass != javaClass) return false
if (!super.equals(other)) return false
other as FixedLeg
if (fixedRatePayer != other.fixedRatePayer) return false
if (fixedRate != other.fixedRate) return false
if (rollConvention != other.rollConvention) return false
return true
}
override fun hashCode() = super.hashCode() + 31 * Objects.hash(fixedRatePayer, fixedRate, rollConvention)
// Can't autogenerate as not a data class :-(
fun copy(fixedRatePayer: Party = this.fixedRatePayer,
notional: Amount<Currency> = this.notional,
paymentFrequency: Frequency = this.paymentFrequency,
effectiveDate: LocalDate = this.effectiveDate,
effectiveDateAdjustment: DateRollConvention? = this.effectiveDateAdjustment,
terminationDate: LocalDate = this.terminationDate,
terminationDateAdjustment: DateRollConvention? = this.terminationDateAdjustment,
dayCountBasisDay: DayCountBasisDay = this.dayCountBasisDay,
dayCountBasisYear: DayCountBasisYear = this.dayCountBasisYear,
dayInMonth: Int = this.dayInMonth,
paymentRule: PaymentRule = this.paymentRule,
paymentDelay: Int = this.paymentDelay,
paymentCalendar: BusinessCalendar = this.paymentCalendar,
interestPeriodAdjustment: AccrualAdjustment = this.interestPeriodAdjustment,
fixedRate: FixedRate = this.fixedRate) = FixedLeg(
fixedRatePayer, notional, paymentFrequency, effectiveDate, effectiveDateAdjustment, terminationDate,
terminationDateAdjustment, dayCountBasisDay, dayCountBasisYear, dayInMonth, paymentRule, paymentDelay,
paymentCalendar, interestPeriodAdjustment, fixedRate, rollConvention)
}
open class FloatingLeg(
var floatingRatePayer: Party,
notional: Amount<Currency>,
paymentFrequency: Frequency,
effectiveDate: LocalDate,
effectiveDateAdjustment: DateRollConvention?,
terminationDate: LocalDate,
terminationDateAdjustment: DateRollConvention?,
dayCountBasisDay: DayCountBasisDay,
dayCountBasisYear: DayCountBasisYear,
dayInMonth: Int,
paymentRule: PaymentRule,
paymentDelay: Int,
paymentCalendar: BusinessCalendar,
interestPeriodAdjustment: AccrualAdjustment,
var rollConvention: DateRollConvention,
var fixingRollConvention: DateRollConvention,
var resetDayInMonth: Int,
var fixingPeriodOffset: Int,
var resetRule: PaymentRule,
var fixingsPerPayment: Frequency,
var fixingCalendar: BusinessCalendar,
var index: String,
var indexSource: String,
var indexTenor: Tenor
) : CommonLeg(notional, paymentFrequency, effectiveDate, effectiveDateAdjustment, terminationDate, terminationDateAdjustment,
dayCountBasisDay, dayCountBasisYear, dayInMonth, paymentRule, paymentDelay, paymentCalendar, interestPeriodAdjustment) {
override fun toString(): String = "FloatingLeg(Payer=$floatingRatePayer," + super.toString() +
"rollConvention=$rollConvention,FixingRollConvention=$fixingRollConvention,ResetDayInMonth=$resetDayInMonth" +
"FixingPeriondOffset=$fixingPeriodOffset,ResetRule=$resetRule,FixingsPerPayment=$fixingsPerPayment,FixingCalendar=$fixingCalendar," +
"Index=$index,IndexSource=$indexSource,IndexTenor=$indexTenor"
override fun equals(other: Any?): Boolean {
if (this === other) return true
if (other?.javaClass != javaClass) return false
if (!super.equals(other)) return false
other as FloatingLeg
if (floatingRatePayer != other.floatingRatePayer) return false
if (rollConvention != other.rollConvention) return false
if (fixingRollConvention != other.fixingRollConvention) return false
if (resetDayInMonth != other.resetDayInMonth) return false
if (fixingPeriodOffset != other.fixingPeriodOffset) return false
if (resetRule != other.resetRule) return false
if (fixingsPerPayment != other.fixingsPerPayment) return false
if (fixingCalendar != other.fixingCalendar) return false
if (index != other.index) return false
if (indexSource != other.indexSource) return false
if (indexTenor != other.indexTenor) return false
return true
}
override fun hashCode() = super.hashCode() + 31 * Objects.hash(floatingRatePayer, rollConvention,
fixingRollConvention, resetDayInMonth, fixingPeriodOffset, resetRule, fixingsPerPayment, fixingCalendar,
index, indexSource, indexTenor)
fun copy(floatingRatePayer: Party = this.floatingRatePayer,
notional: Amount<Currency> = this.notional,
paymentFrequency: Frequency = this.paymentFrequency,
effectiveDate: LocalDate = this.effectiveDate,
effectiveDateAdjustment: DateRollConvention? = this.effectiveDateAdjustment,
terminationDate: LocalDate = this.terminationDate,
terminationDateAdjustment: DateRollConvention? = this.terminationDateAdjustment,
dayCountBasisDay: DayCountBasisDay = this.dayCountBasisDay,
dayCountBasisYear: DayCountBasisYear = this.dayCountBasisYear,
dayInMonth: Int = this.dayInMonth,
paymentRule: PaymentRule = this.paymentRule,
paymentDelay: Int = this.paymentDelay,
paymentCalendar: BusinessCalendar = this.paymentCalendar,
interestPeriodAdjustment: AccrualAdjustment = this.interestPeriodAdjustment,
rollConvention: DateRollConvention = this.rollConvention,
fixingRollConvention: DateRollConvention = this.fixingRollConvention,
resetDayInMonth: Int = this.resetDayInMonth,
fixingPeriod: Int = this.fixingPeriodOffset,
resetRule: PaymentRule = this.resetRule,
fixingsPerPayment: Frequency = this.fixingsPerPayment,
fixingCalendar: BusinessCalendar = this.fixingCalendar,
index: String = this.index,
indexSource: String = this.indexSource,
indexTenor: Tenor = this.indexTenor
) = FloatingLeg(floatingRatePayer, notional, paymentFrequency, effectiveDate, effectiveDateAdjustment,
terminationDate, terminationDateAdjustment, dayCountBasisDay, dayCountBasisYear, dayInMonth,
paymentRule, paymentDelay, paymentCalendar, interestPeriodAdjustment, rollConvention,
fixingRollConvention, resetDayInMonth, fixingPeriod, resetRule, fixingsPerPayment,
fixingCalendar, index, indexSource, indexTenor)
}
override fun verify(tx: TransactionForContract) = verifyClause(tx, AllComposition(Clauses.Timestamped(), Clauses.Group()), tx.commands.select<Commands>())
interface Clauses {
/**
* Common superclass for IRS contract clauses, which defines behaviour on match/no-match, and provides
* helper functions for the clauses.
*/
abstract class AbstractIRSClause : Clause<State, Commands, UniqueIdentifier>() {
// These functions may make more sense to use for basket types, but for now let's leave them here
fun checkLegDates(legs: List<CommonLeg>) {
requireThat {
"Effective date is before termination date" by legs.all { it.effectiveDate < it.terminationDate }
"Effective dates are in alignment" by legs.all { it.effectiveDate == legs[0].effectiveDate }
"Termination dates are in alignment" by legs.all { it.terminationDate == legs[0].terminationDate }
}
}
fun checkLegAmounts(legs: List<CommonLeg>) {
requireThat {
"The notional is non zero" by legs.any { it.notional.quantity > (0).toLong() }
"The notional for all legs must be the same" by legs.all { it.notional == legs[0].notional }
}
for (leg: CommonLeg in legs) {
if (leg is FixedLeg) {
requireThat {
// TODO: Confirm: would someone really enter a swap with a negative fixed rate?
"Fixed leg rate must be positive" by leg.fixedRate.isPositive()
}
}
}
}
// TODO: After business rules discussion, add further checks to the schedules and rates
fun checkSchedules(@Suppress("UNUSED_PARAMETER") legs: List<CommonLeg>): Boolean = true
fun checkRates(@Suppress("UNUSED_PARAMETER") legs: List<CommonLeg>): Boolean = true
/**
* Compares two schedules of Floating Leg Payments, returns the difference (i.e. omissions in either leg or changes to the values).
*/
fun getFloatingLegPaymentsDifferences(payments1: Map<LocalDate, Event>, payments2: Map<LocalDate, Event>): List<Pair<LocalDate, Pair<FloatingRatePaymentEvent, FloatingRatePaymentEvent>>> {
val diff1 = payments1.filter { payments1[it.key] != payments2[it.key] }
val diff2 = payments2.filter { payments1[it.key] != payments2[it.key] }
return (diff1.keys + diff2.keys).map {
it to Pair(diff1[it] as FloatingRatePaymentEvent, diff2[it] as FloatingRatePaymentEvent)
}
}
}
class Group : GroupClauseVerifier<State, Commands, UniqueIdentifier>(AnyComposition(Agree(), Fix(), Pay(), Mature())) {
override fun groupStates(tx: TransactionForContract): List<TransactionForContract.InOutGroup<State, UniqueIdentifier>>
// Group by Trade ID for in / out states
= tx.groupStates() { state -> state.linearId }
}
class Timestamped : Clause<ContractState, Commands, Unit>() {
override fun verify(tx: TransactionForContract,
inputs: List<ContractState>,
outputs: List<ContractState>,
commands: List<AuthenticatedObject<Commands>>,
groupingKey: Unit?): Set<Commands> {
require(tx.timestamp?.midpoint != null) { "must be timestamped" }
// We return an empty set because we don't process any commands
return emptySet()
}
}
class Agree : AbstractIRSClause() {
override val requiredCommands: Set<Class<out CommandData>> = setOf(Commands.Agree::class.java)
override fun verify(tx: TransactionForContract,
inputs: List<State>,
outputs: List<State>,
commands: List<AuthenticatedObject<Commands>>,
groupingKey: UniqueIdentifier?): Set<Commands> {
val command = tx.commands.requireSingleCommand<Commands.Agree>()
val irs = outputs.filterIsInstance<State>().single()
requireThat {
"There are no in states for an agreement" by inputs.isEmpty()
"There are events in the fix schedule" by (irs.calculation.fixedLegPaymentSchedule.size > 0)
"There are events in the float schedule" by (irs.calculation.floatingLegPaymentSchedule.size > 0)
"All notionals must be non zero" by (irs.fixedLeg.notional.quantity > 0 && irs.floatingLeg.notional.quantity > 0)
"The fixed leg rate must be positive" by (irs.fixedLeg.fixedRate.isPositive())
"The currency of the notionals must be the same" by (irs.fixedLeg.notional.token == irs.floatingLeg.notional.token)
"All leg notionals must be the same" by (irs.fixedLeg.notional == irs.floatingLeg.notional)
"The effective date is before the termination date for the fixed leg" by (irs.fixedLeg.effectiveDate < irs.fixedLeg.terminationDate)
"The effective date is before the termination date for the floating leg" by (irs.floatingLeg.effectiveDate < irs.floatingLeg.terminationDate)
"The effective dates are aligned" by (irs.floatingLeg.effectiveDate == irs.fixedLeg.effectiveDate)
"The termination dates are aligned" by (irs.floatingLeg.terminationDate == irs.fixedLeg.terminationDate)
"The rates are valid" by checkRates(listOf(irs.fixedLeg, irs.floatingLeg))
"The schedules are valid" by checkSchedules(listOf(irs.fixedLeg, irs.floatingLeg))
"The fixing period date offset cannot be negative" by (irs.floatingLeg.fixingPeriodOffset >= 0)
// TODO: further tests
}
checkLegAmounts(listOf(irs.fixedLeg, irs.floatingLeg))
checkLegDates(listOf(irs.fixedLeg, irs.floatingLeg))
return setOf(command.value)
}
}
class Fix : AbstractIRSClause() {
override val requiredCommands: Set<Class<out CommandData>> = setOf(Commands.Refix::class.java)
override fun verify(tx: TransactionForContract,
inputs: List<State>,
outputs: List<State>,
commands: List<AuthenticatedObject<Commands>>,
groupingKey: UniqueIdentifier?): Set<Commands> {
val command = tx.commands.requireSingleCommand<Commands.Refix>()
val irs = outputs.filterIsInstance<State>().single()
val prevIrs = inputs.filterIsInstance<State>().single()
val paymentDifferences = getFloatingLegPaymentsDifferences(prevIrs.calculation.floatingLegPaymentSchedule, irs.calculation.floatingLegPaymentSchedule)
// Having both of these tests are "redundant" as far as verify() goes, however, by performing both
// we can relay more information back to the user in the case of failure.
requireThat {
"There is at least one difference in the IRS floating leg payment schedules" by !paymentDifferences.isEmpty()
"There is only one change in the IRS floating leg payment schedule" by (paymentDifferences.size == 1)
}
val changedRates = paymentDifferences.single().second // Ignore the date of the changed rate (we checked that earlier).
val (oldFloatingRatePaymentEvent, newFixedRatePaymentEvent) = changedRates
val fixValue = command.value.fix
// Need to check that everything is the same apart from the new fixed rate entry.
requireThat {
"The fixed leg parties are constant" by (irs.fixedLeg.fixedRatePayer == prevIrs.fixedLeg.fixedRatePayer) // Although superseded by the below test, this is included for a regression issue
"The fixed leg is constant" by (irs.fixedLeg == prevIrs.fixedLeg)
"The floating leg is constant" by (irs.floatingLeg == prevIrs.floatingLeg)
"The common values are constant" by (irs.common == prevIrs.common)
"The fixed leg payment schedule is constant" by (irs.calculation.fixedLegPaymentSchedule == prevIrs.calculation.fixedLegPaymentSchedule)
"The expression is unchanged" by (irs.calculation.expression == prevIrs.calculation.expression)
"There is only one changed payment in the floating leg" by (paymentDifferences.size == 1)
"There changed payment is a floating payment" by (oldFloatingRatePaymentEvent.rate is ReferenceRate)
"The new payment is a fixed payment" by (newFixedRatePaymentEvent.rate is FixedRate)
"The changed payments dates are aligned" by (oldFloatingRatePaymentEvent.date == newFixedRatePaymentEvent.date)
"The new payment has the correct rate" by (newFixedRatePaymentEvent.rate.ratioUnit!!.value == fixValue.value)
"The fixing is for the next required date" by (prevIrs.calculation.nextFixingDate() == fixValue.of.forDay)
"The fix payment has the same currency as the notional" by (newFixedRatePaymentEvent.flow.token == irs.floatingLeg.notional.token)
// "The fixing is not in the future " by (fixCommand) // The oracle should not have signed this .
}
return setOf(command.value)
}
}
class Pay : AbstractIRSClause() {
override val requiredCommands: Set<Class<out CommandData>> = setOf(Commands.Pay::class.java)
override fun verify(tx: TransactionForContract,
inputs: List<State>,
outputs: List<State>,
commands: List<AuthenticatedObject<Commands>>,
groupingKey: UniqueIdentifier?): Set<Commands> {
val command = tx.commands.requireSingleCommand<Commands.Pay>()
requireThat {
"Payments not supported / verifiable yet" by false
}
return setOf(command.value)
}
}
class Mature : AbstractIRSClause() {
override val requiredCommands: Set<Class<out CommandData>> = setOf(Commands.Mature::class.java)
override fun verify(tx: TransactionForContract,
inputs: List<State>,
outputs: List<State>,
commands: List<AuthenticatedObject<Commands>>,
groupingKey: UniqueIdentifier?): Set<Commands> {
val command = tx.commands.requireSingleCommand<Commands.Mature>()
val irs = inputs.filterIsInstance<State>().single()
requireThat {
"No more fixings to be applied" by (irs.calculation.nextFixingDate() == null)
"The irs is fully consumed and there is no id matched output state" by outputs.isEmpty()
}
return setOf(command.value)
}
}
}
interface Commands : CommandData {
data class Refix(val fix: Fix) : Commands // Receive interest rate from oracle, Both sides agree
class Pay : TypeOnlyCommandData(), Commands // Not implemented just yet
class Agree : TypeOnlyCommandData(), Commands // Both sides agree to trade
class Mature : TypeOnlyCommandData(), Commands // Trade has matured; no more actions. Cleanup. // TODO: Do we need this?
}
/**
* The state class contains the 4 major data classes.
*/
data class State(
val fixedLeg: FixedLeg,
val floatingLeg: FloatingLeg,
val calculation: Calculation,
val common: Common,
override val linearId: UniqueIdentifier = UniqueIdentifier(common.tradeID)
) : FixableDealState, SchedulableState {
override val contract = IRS_PROGRAM_ID
override val oracleType: ServiceType
get() = InterestRateSwap.oracleType
override val ref = common.tradeID
override val participants: List<PublicKey>
get() = parties.map { it.owningKey }
override fun isRelevant(ourKeys: Set<PublicKey>): Boolean {
return (fixedLeg.fixedRatePayer.owningKey in ourKeys) || (floatingLeg.floatingRatePayer.owningKey in ourKeys)
}
override val parties: List<Party>
get() = listOf(fixedLeg.fixedRatePayer, floatingLeg.floatingRatePayer)
override fun nextScheduledActivity(thisStateRef: StateRef, protocolLogicRefFactory: ProtocolLogicRefFactory): ScheduledActivity? {
val nextFixingOf = nextFixingOf() ?: return null
// This is perhaps not how we should determine the time point in the business day, but instead expect the schedule to detail some of these aspects
val instant = suggestInterestRateAnnouncementTimeWindow(index = nextFixingOf.name, source = floatingLeg.indexSource, date = nextFixingOf.forDay).start
return ScheduledActivity(protocolLogicRefFactory.create(TwoPartyDealProtocol.FixingRoleDecider::class.java, thisStateRef), instant)
}
override fun generateAgreement(notary: Party): TransactionBuilder = InterestRateSwap().generateAgreement(floatingLeg, fixedLeg, calculation, common, notary)
override fun generateFix(ptx: TransactionBuilder, oldState: StateAndRef<*>, fix: Fix) {
InterestRateSwap().generateFix(ptx, StateAndRef(TransactionState(this, oldState.state.notary), oldState.ref), fix)
}
override fun nextFixingOf(): FixOf? {
val date = calculation.nextFixingDate()
return if (date == null) null else {
val fixingEvent = calculation.getFixing(date)
val oracleRate = fixingEvent.rate as ReferenceRate
FixOf(oracleRate.name, date, oracleRate.tenor)
}
}
/**
* For evaluating arbitrary java on the platform.
*/
fun evaluateCalculation(businessDate: LocalDate, expression: Expression = calculation.expression): Any {
// TODO: Jexl is purely for prototyping. It may be replaced
// TODO: Whatever we do use must be secure and sandboxed
val jexl = JexlBuilder().create()
val expr = jexl.createExpression(expression.expr)
val jc = MapContext()
jc.set("fixedLeg", fixedLeg)
jc.set("floatingLeg", floatingLeg)
jc.set("calculation", calculation)
jc.set("common", common)
jc.set("currentBusinessDate", businessDate)
return expr.evaluate(jc)
}
/**
* Just makes printing it out a bit better for those who don't have 80000 column wide monitors.
*/
fun prettyPrint() = toString().replace(",", "\n")
}
/**
* This generates the agreement state and also the schedules from the initial data.
* Note: The day count, interest rate calculation etc are not finished yet, but they are demonstrable.
*/
fun generateAgreement(floatingLeg: FloatingLeg, fixedLeg: FixedLeg, calculation: Calculation,
common: Common, notary: Party): TransactionBuilder {
val fixedLegPaymentSchedule = HashMap<LocalDate, FixedRatePaymentEvent>()
var dates = BusinessCalendar.createGenericSchedule(fixedLeg.effectiveDate, fixedLeg.paymentFrequency, fixedLeg.paymentCalendar, fixedLeg.rollConvention, endDate = fixedLeg.terminationDate)
var periodStartDate = fixedLeg.effectiveDate
// Create a schedule for the fixed payments
for (periodEndDate in dates) {
val paymentDate = BusinessCalendar.getOffsetDate(periodEndDate, Frequency.Daily, fixedLeg.paymentDelay)
val paymentEvent = FixedRatePaymentEvent(
paymentDate,
periodStartDate,
periodEndDate,
fixedLeg.dayCountBasisDay,
fixedLeg.dayCountBasisYear,
fixedLeg.notional,
fixedLeg.fixedRate
)
fixedLegPaymentSchedule[paymentDate] = paymentEvent
periodStartDate = periodEndDate
}
dates = BusinessCalendar.createGenericSchedule(floatingLeg.effectiveDate,
floatingLeg.fixingsPerPayment,
floatingLeg.fixingCalendar,
floatingLeg.rollConvention,
endDate = floatingLeg.terminationDate)
val floatingLegPaymentSchedule: MutableMap<LocalDate, FloatingRatePaymentEvent> = HashMap()
periodStartDate = floatingLeg.effectiveDate
// Now create a schedule for the floating and fixes.
for (periodEndDate in dates) {
val paymentDate = BusinessCalendar.getOffsetDate(periodEndDate, Frequency.Daily, floatingLeg.paymentDelay)
val paymentEvent = FloatingRatePaymentEvent(
paymentDate,
periodStartDate,
periodEndDate,
floatingLeg.dayCountBasisDay,
floatingLeg.dayCountBasisYear,
calcFixingDate(periodStartDate, floatingLeg.fixingPeriodOffset, floatingLeg.fixingCalendar),
floatingLeg.notional,
ReferenceRate(floatingLeg.indexSource, floatingLeg.indexTenor, floatingLeg.index)
)
floatingLegPaymentSchedule[paymentDate] = paymentEvent
periodStartDate = periodEndDate
}
val newCalculation = Calculation(calculation.expression, floatingLegPaymentSchedule, fixedLegPaymentSchedule)
// Put all the above into a new State object.
val state = State(fixedLeg, floatingLeg, newCalculation, common)
return TransactionType.General.Builder(notary = notary).withItems(state, Command(Commands.Agree(), listOf(state.floatingLeg.floatingRatePayer.owningKey, state.fixedLeg.fixedRatePayer.owningKey)))
}
private fun calcFixingDate(date: LocalDate, fixingPeriodOffset: Int, calendar: BusinessCalendar): LocalDate {
return when (fixingPeriodOffset) {
0 -> date
else -> calendar.moveBusinessDays(date, DateRollDirection.BACKWARD, fixingPeriodOffset)
}
}
fun generateFix(tx: TransactionBuilder, irs: StateAndRef<State>, fixing: Fix) {
tx.addInputState(irs)
val fixedRate = FixedRate(RatioUnit(fixing.value))
tx.addOutputState(
irs.state.data.copy(calculation = irs.state.data.calculation.applyFixing(fixing.of.forDay, fixedRate)),
irs.state.notary
)
tx.addCommand(Commands.Refix(fixing), listOf(irs.state.data.floatingLeg.floatingRatePayer.owningKey, irs.state.data.fixedLeg.fixedRatePayer.owningKey))
}
}

View File

@ -1,7 +0,0 @@
package com.r3corda.contracts
fun InterestRateSwap.State.exportIRSToCSV(): String =
"Fixed Leg\n" + FixedRatePaymentEvent.CSVHeader + "\n" +
this.calculation.fixedLegPaymentSchedule.toSortedMap().values.map { it.asCSV() }.joinToString("\n") + "\n" +
"Floating Leg\n" + FloatingRatePaymentEvent.CSVHeader + "\n" +
this.calculation.floatingLegPaymentSchedule.toSortedMap().values.map { it.asCSV() }.joinToString("\n") + "\n"

View File

@ -1,88 +0,0 @@
package com.r3corda.contracts
import com.r3corda.core.contracts.Amount
import com.r3corda.core.contracts.Tenor
import java.math.BigDecimal
import java.util.*
// Things in here will move to the general utils class when we've hammered out various discussions regarding amounts, dates, oracle etc.
/**
* A utility class to prevent the various mixups between percentages, decimals, bips etc.
*/
open class RatioUnit(val value: BigDecimal) { // TODO: Discuss this type
override fun equals(other: Any?) = (other as? RatioUnit)?.value == value
override fun hashCode() = value.hashCode()
override fun toString() = value.toString()
}
/**
* A class to reprecent a percentage in an unambiguous way.
*/
open class PercentageRatioUnit(percentageAsString: String) : RatioUnit(BigDecimal(percentageAsString).divide(BigDecimal("100"))) {
override fun toString() = value.times(BigDecimal(100)).toString() + "%"
}
/**
* For the convenience of writing "5".percent
* Note that we do not currently allow 10.percent (ie no quotes) as this might get a little confusing if 0.1.percent was
* written. Additionally, there is a possibility of creating a precision error in the implicit conversion.
*/
val String.percent: PercentageRatioUnit get() = PercentageRatioUnit(this)
/**
* Parent of the Rate family. Used to denote fixed rates, floating rates, reference rates etc.
*/
open class Rate(val ratioUnit: RatioUnit? = null) {
override fun equals(other: Any?): Boolean {
if (this === other) return true
if (other?.javaClass != javaClass) return false
other as Rate
if (ratioUnit != other.ratioUnit) return false
return true
}
/**
* @returns the hash code of the ratioUnit or zero if the ratioUnit is null, as is the case for floating rate fixings
* that have not yet happened. Yet-to-be fixed floating rates need to be equal such that schedules can be tested
* for equality.
*/
override fun hashCode() = ratioUnit?.hashCode() ?: 0
override fun toString() = ratioUnit.toString()
}
/**
* A very basic subclass to represent a fixed rate.
*/
class FixedRate(ratioUnit: RatioUnit) : Rate(ratioUnit) {
fun isPositive(): Boolean = ratioUnit!!.value > BigDecimal("0.0")
override fun equals(other: Any?) = other?.javaClass == javaClass && super.equals(other)
override fun hashCode() = super.hashCode()
}
/**
* The parent class of the Floating rate classes.
*/
open class FloatingRate : Rate(null)
/**
* So a reference rate is a rate that takes its value from a source at a given date
* e.g. LIBOR 6M as of 17 March 2016. Hence it requires a source (name) and a value date in the getAsOf(..) method.
*/
class ReferenceRate(val oracle: String, val tenor: Tenor, val name: String) : FloatingRate() {
override fun toString(): String = "$name - $tenor"
}
// TODO: For further discussion.
operator fun Amount<Currency>.times(other: RatioUnit): Amount<Currency> = Amount((BigDecimal(this.quantity).multiply(other.value)).longValueExact(), this.token)
//operator fun Amount<Currency>.times(other: FixedRate): Amount<Currency> = Amount<Currency>((BigDecimal(this.pennies).multiply(other.value)).longValueExact(), this.currency)
//fun Amount<Currency>.times(other: InterestRateSwap.RatioUnit): Amount<Currency> = Amount<Currency>((BigDecimal(this.pennies).multiply(other.value)).longValueExact(), this.currency)
operator fun kotlin.Int.times(other: FixedRate): Int = BigDecimal(this).multiply(other.ratioUnit!!.value).intValueExact()
operator fun Int.times(other: Rate): Int = BigDecimal(this).multiply(other.ratioUnit!!.value).intValueExact()
operator fun Int.times(other: RatioUnit): Int = BigDecimal(this).multiply(other.value).intValueExact()

View File

@ -1,718 +0,0 @@
package com.r3corda.contracts
import com.r3corda.core.contracts.*
import com.r3corda.core.node.recordTransactions
import com.r3corda.core.seconds
import com.r3corda.core.transactions.SignedTransaction
import com.r3corda.core.utilities.DUMMY_NOTARY
import com.r3corda.core.utilities.DUMMY_NOTARY_KEY
import com.r3corda.core.utilities.TEST_TX_TIME
import com.r3corda.testing.*
import com.r3corda.testing.node.MockServices
import org.junit.Test
import java.math.BigDecimal
import java.time.LocalDate
import java.util.*
fun createDummyIRS(irsSelect: Int): InterestRateSwap.State {
return when (irsSelect) {
1 -> {
val fixedLeg = InterestRateSwap.FixedLeg(
fixedRatePayer = MEGA_CORP,
notional = 15900000.DOLLARS,
paymentFrequency = Frequency.SemiAnnual,
effectiveDate = LocalDate.of(2016, 3, 10),
effectiveDateAdjustment = null,
terminationDate = LocalDate.of(2026, 3, 10),
terminationDateAdjustment = null,
fixedRate = FixedRate(PercentageRatioUnit("1.677")),
dayCountBasisDay = DayCountBasisDay.D30,
dayCountBasisYear = DayCountBasisYear.Y360,
rollConvention = DateRollConvention.ModifiedFollowing,
dayInMonth = 10,
paymentRule = PaymentRule.InArrears,
paymentDelay = 3,
paymentCalendar = BusinessCalendar.getInstance("London", "NewYork"),
interestPeriodAdjustment = AccrualAdjustment.Adjusted
)
val floatingLeg = InterestRateSwap.FloatingLeg(
floatingRatePayer = MINI_CORP,
notional = 15900000.DOLLARS,
paymentFrequency = Frequency.Quarterly,
effectiveDate = LocalDate.of(2016, 3, 10),
effectiveDateAdjustment = null,
terminationDate = LocalDate.of(2026, 3, 10),
terminationDateAdjustment = null,
dayCountBasisDay = DayCountBasisDay.D30,
dayCountBasisYear = DayCountBasisYear.Y360,
rollConvention = DateRollConvention.ModifiedFollowing,
fixingRollConvention = DateRollConvention.ModifiedFollowing,
dayInMonth = 10,
resetDayInMonth = 10,
paymentRule = PaymentRule.InArrears,
paymentDelay = 3,
paymentCalendar = BusinessCalendar.getInstance("London", "NewYork"),
interestPeriodAdjustment = AccrualAdjustment.Adjusted,
fixingPeriodOffset = 2,
resetRule = PaymentRule.InAdvance,
fixingsPerPayment = Frequency.Quarterly,
fixingCalendar = BusinessCalendar.getInstance("London"),
index = "LIBOR",
indexSource = "TEL3750",
indexTenor = Tenor("3M")
)
val calculation = InterestRateSwap.Calculation (
// TODO: this seems to fail quite dramatically
//expression = "fixedLeg.notional * fixedLeg.fixedRate",
// TODO: How I want it to look
//expression = "( fixedLeg.notional * (fixedLeg.fixedRate)) - (floatingLeg.notional * (rateSchedule.get(context.getDate('currentDate'))))",
// How it's ended up looking, which I think is now broken but it's a WIP.
expression = Expression("( fixedLeg.notional.pennies * (fixedLeg.fixedRate.ratioUnit.value)) -" +
"(floatingLeg.notional.pennies * (calculation.fixingSchedule.get(context.getDate('currentDate')).rate.ratioUnit.value))"),
floatingLegPaymentSchedule = HashMap(),
fixedLegPaymentSchedule = HashMap()
)
val EUR = currency("EUR")
val common = InterestRateSwap.Common(
baseCurrency = EUR,
eligibleCurrency = EUR,
eligibleCreditSupport = "Cash in an Eligible Currency",
independentAmounts = Amount(0, EUR),
threshold = Amount(0, EUR),
minimumTransferAmount = Amount(250000 * 100, EUR),
rounding = Amount(10000 * 100, EUR),
valuationDateDescription = "Every Local Business Day",
notificationTime = "2:00pm London",
resolutionTime = "2:00pm London time on the first LocalBusiness Day following the date on which the notice is given ",
interestRate = ReferenceRate("T3270", Tenor("6M"), "EONIA"),
addressForTransfers = "",
exposure = UnknownType(),
localBusinessDay = BusinessCalendar.getInstance("London"),
tradeID = "trade1",
hashLegalDocs = "put hash here",
dailyInterestAmount = Expression("(CashAmount * InterestRate ) / (fixedLeg.notional.currency.currencyCode.equals('GBP')) ? 365 : 360")
)
InterestRateSwap.State(fixedLeg = fixedLeg, floatingLeg = floatingLeg, calculation = calculation, common = common)
}
2 -> {
// 10y swap, we pay 1.3% fixed 30/360 semi, rec 3m usd libor act/360 Q on 25m notional (mod foll/adj on both sides)
// I did a mock up start date 10/03/2015 10/03/2025 so you have 5 cashflows on float side that have been preset the rest are unknown
val fixedLeg = InterestRateSwap.FixedLeg(
fixedRatePayer = MEGA_CORP,
notional = 25000000.DOLLARS,
paymentFrequency = Frequency.SemiAnnual,
effectiveDate = LocalDate.of(2015, 3, 10),
effectiveDateAdjustment = null,
terminationDate = LocalDate.of(2025, 3, 10),
terminationDateAdjustment = null,
fixedRate = FixedRate(PercentageRatioUnit("1.3")),
dayCountBasisDay = DayCountBasisDay.D30,
dayCountBasisYear = DayCountBasisYear.Y360,
rollConvention = DateRollConvention.ModifiedFollowing,
dayInMonth = 10,
paymentRule = PaymentRule.InArrears,
paymentDelay = 0,
paymentCalendar = BusinessCalendar.getInstance(),
interestPeriodAdjustment = AccrualAdjustment.Adjusted
)
val floatingLeg = InterestRateSwap.FloatingLeg(
floatingRatePayer = MINI_CORP,
notional = 25000000.DOLLARS,
paymentFrequency = Frequency.Quarterly,
effectiveDate = LocalDate.of(2015, 3, 10),
effectiveDateAdjustment = null,
terminationDate = LocalDate.of(2025, 3, 10),
terminationDateAdjustment = null,
dayCountBasisDay = DayCountBasisDay.DActual,
dayCountBasisYear = DayCountBasisYear.Y360,
rollConvention = DateRollConvention.ModifiedFollowing,
fixingRollConvention = DateRollConvention.ModifiedFollowing,
dayInMonth = 10,
resetDayInMonth = 10,
paymentRule = PaymentRule.InArrears,
paymentDelay = 0,
paymentCalendar = BusinessCalendar.getInstance(),
interestPeriodAdjustment = AccrualAdjustment.Adjusted,
fixingPeriodOffset = 2,
resetRule = PaymentRule.InAdvance,
fixingsPerPayment = Frequency.Quarterly,
fixingCalendar = BusinessCalendar.getInstance(),
index = "USD LIBOR",
indexSource = "TEL3750",
indexTenor = Tenor("3M")
)
val calculation = InterestRateSwap.Calculation (
// TODO: this seems to fail quite dramatically
//expression = "fixedLeg.notional * fixedLeg.fixedRate",
// TODO: How I want it to look
//expression = "( fixedLeg.notional * (fixedLeg.fixedRate)) - (floatingLeg.notional * (rateSchedule.get(context.getDate('currentDate'))))",
// How it's ended up looking, which I think is now broken but it's a WIP.
expression = Expression("( fixedLeg.notional.pennies * (fixedLeg.fixedRate.ratioUnit.value)) -" +
"(floatingLeg.notional.pennies * (calculation.fixingSchedule.get(context.getDate('currentDate')).rate.ratioUnit.value))"),
floatingLegPaymentSchedule = HashMap(),
fixedLegPaymentSchedule = HashMap()
)
val EUR = currency("EUR")
val common = InterestRateSwap.Common(
baseCurrency = EUR,
eligibleCurrency = EUR,
eligibleCreditSupport = "Cash in an Eligible Currency",
independentAmounts = Amount(0, EUR),
threshold = Amount(0, EUR),
minimumTransferAmount = Amount(250000 * 100, EUR),
rounding = Amount(10000 * 100, EUR),
valuationDateDescription = "Every Local Business Day",
notificationTime = "2:00pm London",
resolutionTime = "2:00pm London time on the first LocalBusiness Day following the date on which the notice is given ",
interestRate = ReferenceRate("T3270", Tenor("6M"), "EONIA"),
addressForTransfers = "",
exposure = UnknownType(),
localBusinessDay = BusinessCalendar.getInstance("London"),
tradeID = "trade2",
hashLegalDocs = "put hash here",
dailyInterestAmount = Expression("(CashAmount * InterestRate ) / (fixedLeg.notional.currency.currencyCode.equals('GBP')) ? 365 : 360")
)
return InterestRateSwap.State(fixedLeg = fixedLeg, floatingLeg = floatingLeg, calculation = calculation, common = common)
}
else -> TODO("IRS number $irsSelect not defined")
}
}
class IRSTests {
@Test
fun ok() {
trade().verifies()
}
@Test
fun `ok with groups`() {
tradegroups().verifies()
}
/**
* Generate an IRS txn - we'll need it for a few things.
*/
fun generateIRSTxn(irsSelect: Int): SignedTransaction {
val dummyIRS = createDummyIRS(irsSelect)
val genTX: SignedTransaction = run {
val gtx = InterestRateSwap().generateAgreement(
fixedLeg = dummyIRS.fixedLeg,
floatingLeg = dummyIRS.floatingLeg,
calculation = dummyIRS.calculation,
common = dummyIRS.common,
notary = DUMMY_NOTARY).apply {
setTime(TEST_TX_TIME, 30.seconds)
signWith(MEGA_CORP_KEY)
signWith(MINI_CORP_KEY)
signWith(DUMMY_NOTARY_KEY)
}
gtx.toSignedTransaction()
}
return genTX
}
/**
* Just make sure it's sane.
*/
@Test
fun pprintIRS() {
val irs = singleIRS()
println(irs.prettyPrint())
}
/**
* Utility so I don't have to keep typing this.
*/
fun singleIRS(irsSelector: Int = 1): InterestRateSwap.State {
return generateIRSTxn(irsSelector).tx.outputs.map { it.data }.filterIsInstance<InterestRateSwap.State>().single()
}
/**
* Test the generate. No explicit exception as if something goes wrong, we'll find out anyway.
*/
@Test
fun generateIRS() {
// Tests aren't allowed to return things
generateIRSTxn(1)
}
/**
* Testing a simple IRS, add a few fixings and then display as CSV.
*/
@Test
fun `IRS Export test`() {
// No transactions etc required - we're just checking simple maths and export functionallity
val irs = singleIRS(2)
var newCalculation = irs.calculation
val fixings = mapOf(LocalDate.of(2015, 3, 6) to "0.6",
LocalDate.of(2015, 6, 8) to "0.75",
LocalDate.of(2015, 9, 8) to "0.8",
LocalDate.of(2015, 12, 8) to "0.55",
LocalDate.of(2016, 3, 8) to "0.644")
for ((key, value) in fixings) {
newCalculation = newCalculation.applyFixing(key, FixedRate(PercentageRatioUnit(value)))
}
val newIRS = InterestRateSwap.State(irs.fixedLeg, irs.floatingLeg, newCalculation, irs.common)
println(newIRS.exportIRSToCSV())
}
/**
* Make sure it has a schedule and the schedule has some unfixed rates.
*/
@Test
fun `next fixing date`() {
val irs = singleIRS(1)
println(irs.calculation.nextFixingDate())
}
/**
* Iterate through all the fix dates and add something.
*/
@Test
fun generateIRSandFixSome() {
val services = MockServices()
var previousTXN = generateIRSTxn(1)
previousTXN.toLedgerTransaction(services).verify()
services.recordTransactions(previousTXN)
fun currentIRS() = previousTXN.tx.outputs.map { it.data }.filterIsInstance<InterestRateSwap.State>().single()
while (true) {
val nextFix: FixOf = currentIRS().nextFixingOf() ?: break
val fixTX: SignedTransaction = run {
val tx = TransactionType.General.Builder(DUMMY_NOTARY)
val fixing = Fix(nextFix, "0.052".percent.value)
InterestRateSwap().generateFix(tx, previousTXN.tx.outRef(0), fixing)
with(tx) {
setTime(TEST_TX_TIME, 30.seconds)
signWith(MEGA_CORP_KEY)
signWith(MINI_CORP_KEY)
signWith(DUMMY_NOTARY_KEY)
}
tx.toSignedTransaction()
}
fixTX.toLedgerTransaction(services).verify()
services.recordTransactions(fixTX)
previousTXN = fixTX
}
}
// Move these later as they aren't IRS specific.
@Test
fun `test some rate objects 100 * FixedRate(5%)`() {
val r1 = FixedRate(PercentageRatioUnit("5"))
assert(100 * r1 == 5)
}
@Test
fun `expression calculation testing`() {
val dummyIRS = singleIRS()
val stuffToPrint: ArrayList<String> = arrayListOf(
"fixedLeg.notional.quantity",
"fixedLeg.fixedRate.ratioUnit",
"fixedLeg.fixedRate.ratioUnit.value",
"floatingLeg.notional.quantity",
"fixedLeg.fixedRate",
"currentBusinessDate",
"calculation.floatingLegPaymentSchedule.get(currentBusinessDate)",
"fixedLeg.notional.token.currencyCode",
"fixedLeg.notional.quantity * 10",
"fixedLeg.notional.quantity * fixedLeg.fixedRate.ratioUnit.value",
"(fixedLeg.notional.token.currencyCode.equals('GBP')) ? 365 : 360 ",
"(fixedLeg.notional.quantity * (fixedLeg.fixedRate.ratioUnit.value))"
// "calculation.floatingLegPaymentSchedule.get(context.getDate('currentDate')).rate"
// "calculation.floatingLegPaymentSchedule.get(context.getDate('currentDate')).rate.ratioUnit.value",
//"( fixedLeg.notional.pennies * (fixedLeg.fixedRate.ratioUnit.value)) - (floatingLeg.notional.pennies * (calculation.fixingSchedule.get(context.getDate('currentDate')).rate.ratioUnit.value))",
// "( fixedLeg.notional * fixedLeg.fixedRate )"
)
for (i in stuffToPrint) {
println(i)
val z = dummyIRS.evaluateCalculation(LocalDate.of(2016, 9, 15), Expression(i))
println(z.javaClass)
println(z)
println("-----------")
}
// This does not throw an exception in the test itself; it evaluates the above and they will throw if they do not pass.
}
/**
* Generates a typical transactional history for an IRS.
*/
fun trade(): LedgerDSL<TestTransactionDSLInterpreter, TestLedgerDSLInterpreter> {
val ld = LocalDate.of(2016, 3, 8)
val bd = BigDecimal("0.0063518")
return ledger {
transaction("Agreement") {
output("irs post agreement") { singleIRS() }
command(MEGA_CORP_PUBKEY) { InterestRateSwap.Commands.Agree() }
timestamp(TEST_TX_TIME)
this.verifies()
}
transaction("Fix") {
input("irs post agreement")
val postAgreement = "irs post agreement".output<InterestRateSwap.State>()
output("irs post first fixing") {
postAgreement.copy(
postAgreement.fixedLeg,
postAgreement.floatingLeg,
postAgreement.calculation.applyFixing(ld, FixedRate(RatioUnit(bd))),
postAgreement.common
)
}
command(ORACLE_PUBKEY) {
InterestRateSwap.Commands.Refix(Fix(FixOf("ICE LIBOR", ld, Tenor("3M")), bd))
}
timestamp(TEST_TX_TIME)
this.verifies()
}
}
}
@Test
fun `ensure failure occurs when there are inbound states for an agreement command`() {
val irs = singleIRS()
transaction {
input() { irs }
output("irs post agreement") { irs }
command(MEGA_CORP_PUBKEY) { InterestRateSwap.Commands.Agree() }
timestamp(TEST_TX_TIME)
this `fails with` "There are no in states for an agreement"
}
}
@Test
fun `ensure failure occurs when no events in fix schedule`() {
val irs = singleIRS()
val emptySchedule = HashMap<LocalDate, FixedRatePaymentEvent>()
transaction {
output() {
irs.copy(calculation = irs.calculation.copy(fixedLegPaymentSchedule = emptySchedule))
}
command(MEGA_CORP_PUBKEY) { InterestRateSwap.Commands.Agree() }
timestamp(TEST_TX_TIME)
this `fails with` "There are events in the fix schedule"
}
}
@Test
fun `ensure failure occurs when no events in floating schedule`() {
val irs = singleIRS()
val emptySchedule = HashMap<LocalDate, FloatingRatePaymentEvent>()
transaction {
output() {
irs.copy(calculation = irs.calculation.copy(floatingLegPaymentSchedule = emptySchedule))
}
command(MEGA_CORP_PUBKEY) { InterestRateSwap.Commands.Agree() }
timestamp(TEST_TX_TIME)
this `fails with` "There are events in the float schedule"
}
}
@Test
fun `ensure notionals are non zero`() {
val irs = singleIRS()
transaction {
output() {
irs.copy(irs.fixedLeg.copy(notional = irs.fixedLeg.notional.copy(quantity = 0)))
}
command(MEGA_CORP_PUBKEY) { InterestRateSwap.Commands.Agree() }
timestamp(TEST_TX_TIME)
this `fails with` "All notionals must be non zero"
}
transaction {
output() {
irs.copy(irs.fixedLeg.copy(notional = irs.floatingLeg.notional.copy(quantity = 0)))
}
command(MEGA_CORP_PUBKEY) { InterestRateSwap.Commands.Agree() }
timestamp(TEST_TX_TIME)
this `fails with` "All notionals must be non zero"
}
}
@Test
fun `ensure positive rate on fixed leg`() {
val irs = singleIRS()
val modifiedIRS = irs.copy(fixedLeg = irs.fixedLeg.copy(fixedRate = FixedRate(PercentageRatioUnit("-0.1"))))
transaction {
output() {
modifiedIRS
}
command(MEGA_CORP_PUBKEY) { InterestRateSwap.Commands.Agree() }
timestamp(TEST_TX_TIME)
this `fails with` "The fixed leg rate must be positive"
}
}
/**
* This will be modified once we adapt the IRS to be cross currency.
*/
@Test
fun `ensure same currency notionals`() {
val irs = singleIRS()
val modifiedIRS = irs.copy(fixedLeg = irs.fixedLeg.copy(notional = Amount(irs.fixedLeg.notional.quantity, Currency.getInstance("JPY"))))
transaction {
output() {
modifiedIRS
}
command(MEGA_CORP_PUBKEY) { InterestRateSwap.Commands.Agree() }
timestamp(TEST_TX_TIME)
this `fails with` "The currency of the notionals must be the same"
}
}
@Test
fun `ensure notional amounts are equal`() {
val irs = singleIRS()
val modifiedIRS = irs.copy(fixedLeg = irs.fixedLeg.copy(notional = Amount(irs.floatingLeg.notional.quantity + 1, irs.floatingLeg.notional.token)))
transaction {
output() {
modifiedIRS
}
command(MEGA_CORP_PUBKEY) { InterestRateSwap.Commands.Agree() }
timestamp(TEST_TX_TIME)
this `fails with` "All leg notionals must be the same"
}
}
@Test
fun `ensure trade date and termination date checks are done pt1`() {
val irs = singleIRS()
val modifiedIRS1 = irs.copy(fixedLeg = irs.fixedLeg.copy(terminationDate = irs.fixedLeg.effectiveDate.minusDays(1)))
transaction {
output() {
modifiedIRS1
}
command(MEGA_CORP_PUBKEY) { InterestRateSwap.Commands.Agree() }
timestamp(TEST_TX_TIME)
this `fails with` "The effective date is before the termination date for the fixed leg"
}
val modifiedIRS2 = irs.copy(floatingLeg = irs.floatingLeg.copy(terminationDate = irs.floatingLeg.effectiveDate.minusDays(1)))
transaction {
output() {
modifiedIRS2
}
command(MEGA_CORP_PUBKEY) { InterestRateSwap.Commands.Agree() }
timestamp(TEST_TX_TIME)
this `fails with` "The effective date is before the termination date for the floating leg"
}
}
@Test
fun `ensure trade date and termination date checks are done pt2`() {
val irs = singleIRS()
val modifiedIRS3 = irs.copy(floatingLeg = irs.floatingLeg.copy(terminationDate = irs.fixedLeg.terminationDate.minusDays(1)))
transaction {
output() {
modifiedIRS3
}
command(MEGA_CORP_PUBKEY) { InterestRateSwap.Commands.Agree() }
timestamp(TEST_TX_TIME)
this `fails with` "The termination dates are aligned"
}
val modifiedIRS4 = irs.copy(floatingLeg = irs.floatingLeg.copy(effectiveDate = irs.fixedLeg.effectiveDate.minusDays(1)))
transaction {
output() {
modifiedIRS4
}
command(MEGA_CORP_PUBKEY) { InterestRateSwap.Commands.Agree() }
timestamp(TEST_TX_TIME)
this `fails with` "The effective dates are aligned"
}
}
@Test
fun `various fixing tests`() {
val ld = LocalDate.of(2016, 3, 8)
val bd = BigDecimal("0.0063518")
transaction {
output("irs post agreement") { singleIRS() }
command(MEGA_CORP_PUBKEY) { InterestRateSwap.Commands.Agree() }
timestamp(TEST_TX_TIME)
this.verifies()
}
val oldIRS = singleIRS(1)
val newIRS = oldIRS.copy(oldIRS.fixedLeg,
oldIRS.floatingLeg,
oldIRS.calculation.applyFixing(ld, FixedRate(RatioUnit(bd))),
oldIRS.common)
transaction {
input() {
oldIRS
}
// Templated tweak for reference. A corrent fixing applied should be ok
tweak {
command(ORACLE_PUBKEY) {
InterestRateSwap.Commands.Refix(Fix(FixOf("ICE LIBOR", ld, Tenor("3M")), bd))
}
timestamp(TEST_TX_TIME)
output() { newIRS }
this.verifies()
}
// This test makes sure that verify confirms the fixing was applied and there is a difference in the old and new
tweak {
command(ORACLE_PUBKEY) { InterestRateSwap.Commands.Refix(Fix(FixOf("ICE LIBOR", ld, Tenor("3M")), bd)) }
timestamp(TEST_TX_TIME)
output() { oldIRS }
this `fails with` "There is at least one difference in the IRS floating leg payment schedules"
}
// This tests tries to sneak in a change to another fixing (which may or may not be the latest one)
tweak {
command(ORACLE_PUBKEY) { InterestRateSwap.Commands.Refix(Fix(FixOf("ICE LIBOR", ld, Tenor("3M")), bd)) }
timestamp(TEST_TX_TIME)
val firstResetKey = newIRS.calculation.floatingLegPaymentSchedule.keys.first()
val firstResetValue = newIRS.calculation.floatingLegPaymentSchedule[firstResetKey]
val modifiedFirstResetValue = firstResetValue!!.copy(notional = Amount(firstResetValue.notional.quantity, Currency.getInstance("JPY")))
output() {
newIRS.copy(
newIRS.fixedLeg,
newIRS.floatingLeg,
newIRS.calculation.copy(floatingLegPaymentSchedule = newIRS.calculation.floatingLegPaymentSchedule.plus(
Pair(firstResetKey, modifiedFirstResetValue))),
newIRS.common
)
}
this `fails with` "There is only one change in the IRS floating leg payment schedule"
}
// This tests modifies the payment currency for the fixing
tweak {
command(ORACLE_PUBKEY) { InterestRateSwap.Commands.Refix(Fix(FixOf("ICE LIBOR", ld, Tenor("3M")), bd)) }
timestamp(TEST_TX_TIME)
val latestReset = newIRS.calculation.floatingLegPaymentSchedule.filter { it.value.rate is FixedRate }.maxBy { it.key }
val modifiedLatestResetValue = latestReset!!.value.copy(notional = Amount(latestReset.value.notional.quantity, Currency.getInstance("JPY")))
output() {
newIRS.copy(
newIRS.fixedLeg,
newIRS.floatingLeg,
newIRS.calculation.copy(floatingLegPaymentSchedule = newIRS.calculation.floatingLegPaymentSchedule.plus(
Pair(latestReset.key, modifiedLatestResetValue))),
newIRS.common
)
}
this `fails with` "The fix payment has the same currency as the notional"
}
}
}
/**
* This returns an example of transactions that are grouped by TradeId and then a fixing applied.
* It's important to make the tradeID different for two reasons, the hashes will be the same and all sorts of confusion will
* result and the grouping won't work either.
* In reality, the only fields that should be in common will be the next fixing date and the reference rate.
*/
fun tradegroups(): LedgerDSL<TestTransactionDSLInterpreter, TestLedgerDSLInterpreter> {
val ld1 = LocalDate.of(2016, 3, 8)
val bd1 = BigDecimal("0.0063518")
val irs = singleIRS()
return ledger {
transaction("Agreement") {
output("irs post agreement1") {
irs.copy(
irs.fixedLeg,
irs.floatingLeg,
irs.calculation,
irs.common.copy(tradeID = "t1")
)
}
command(MEGA_CORP_PUBKEY) { InterestRateSwap.Commands.Agree() }
timestamp(TEST_TX_TIME)
this.verifies()
}
transaction("Agreement") {
output("irs post agreement2") {
irs.copy(
linearId = UniqueIdentifier("t2"),
fixedLeg = irs.fixedLeg,
floatingLeg = irs.floatingLeg,
calculation = irs.calculation,
common = irs.common.copy(tradeID = "t2")
)
}
command(MEGA_CORP_PUBKEY) { InterestRateSwap.Commands.Agree() }
timestamp(TEST_TX_TIME)
this.verifies()
}
transaction("Fix") {
input("irs post agreement1")
input("irs post agreement2")
val postAgreement1 = "irs post agreement1".output<InterestRateSwap.State>()
output("irs post first fixing1") {
postAgreement1.copy(
postAgreement1.fixedLeg,
postAgreement1.floatingLeg,
postAgreement1.calculation.applyFixing(ld1, FixedRate(RatioUnit(bd1))),
postAgreement1.common.copy(tradeID = "t1")
)
}
val postAgreement2 = "irs post agreement2".output<InterestRateSwap.State>()
output("irs post first fixing2") {
postAgreement2.copy(
postAgreement2.fixedLeg,
postAgreement2.floatingLeg,
postAgreement2.calculation.applyFixing(ld1, FixedRate(RatioUnit(bd1))),
postAgreement2.common.copy(tradeID = "t2")
)
}
command(ORACLE_PUBKEY) {
InterestRateSwap.Commands.Refix(Fix(FixOf("ICE LIBOR", ld1, Tenor("3M")), bd1))
}
timestamp(TEST_TX_TIME)
this.verifies()
}
}
}
}

View File

@ -15,6 +15,7 @@ import java.io.BufferedInputStream
import java.io.InputStream
import java.math.BigDecimal
import java.nio.file.Files
import java.nio.file.LinkOption
import java.nio.file.Path
import java.time.Duration
import java.time.temporal.Temporal
@ -89,6 +90,7 @@ inline fun <T> SettableFuture<T>.catch(block: () -> T) {
}
fun <R> Path.use(block: (InputStream) -> R): R = Files.newInputStream(this).use(block)
fun Path.exists(vararg options: LinkOption): Boolean = Files.exists(this, *options)
// Simple infix function to add back null safety that the JDK lacks: timeA until timeB
infix fun Temporal.until(endExclusive: Temporal) = Duration.between(this, endExclusive)
@ -290,7 +292,7 @@ fun <T, I: Comparable<I>> Iterable<T>.isOrderedAndUnique(extractId: T.() -> I):
if (lastLast == null) {
true
} else {
lastLast.compareTo(extractId(it)) < 0
lastLast < extractId(it)
}
}
}

View File

@ -27,7 +27,7 @@ import kotlin.reflect.primaryConstructor
*/
class ProtocolLogicRefFactory(private val protocolWhitelist: Map<String, Set<String>>) : SingletonSerializeAsToken() {
constructor() : this(mapOf(Pair(TwoPartyDealProtocol.FixingRoleDecider::class.java.name, setOf(StateRef::class.java.name, Duration::class.java.name))))
constructor() : this(mapOf())
// Pending real dependence on AppContext for class loading etc
@Suppress("UNUSED_PARAMETER")

View File

@ -1,19 +0,0 @@
package com.r3corda.core.utilities
import java.time.*
/**
* This whole file exists as short cuts to get demos working. In reality we'd have static data and/or rules engine
* defining things like this. It currently resides in the core module because it needs to be visible to the IRS
* contract.
*/
// We at some future point may implement more than just this constant announcement window and thus use the params.
@Suppress("UNUSED_PARAMETER")
fun suggestInterestRateAnnouncementTimeWindow(index: String, source: String, date: LocalDate): TimeWindow {
// TODO: we would ordinarily convert clock to same time zone as the index/source would announce in
// and suggest an announcement time for the interest rate
// Here we apply a blanket announcement time of 11:45 London irrespective of source or index
val time = LocalTime.of(11, 45)
val zoneId = ZoneId.of("Europe/London")
return TimeWindow(ZonedDateTime.of(date, time, zoneId).toInstant(), Duration.ofHours(24))
}

View File

@ -1,109 +0,0 @@
package com.r3corda.protocols
import co.paralleluniverse.fibers.Suspendable
import com.r3corda.core.contracts.Fix
import com.r3corda.core.contracts.FixOf
import com.r3corda.core.crypto.DigitalSignature
import com.r3corda.core.crypto.Party
import com.r3corda.core.protocols.ProtocolLogic
import com.r3corda.core.transactions.TransactionBuilder
import com.r3corda.core.transactions.WireTransaction
import com.r3corda.core.utilities.ProgressTracker
import com.r3corda.core.utilities.suggestInterestRateAnnouncementTimeWindow
import com.r3corda.protocols.RatesFixProtocol.FixOutOfRange
import java.math.BigDecimal
import java.time.Instant
import java.util.*
// This code is unit tested in NodeInterestRates.kt
/**
* This protocol queries the given oracle for an interest rate fix, and if it is within the given tolerance embeds the
* fix in the transaction and then proceeds to get the oracle to sign it. Although the [call] method combines the query
* and signing step, you can run the steps individually by constructing this object and then using the public methods
* for each step.
*
* @throws FixOutOfRange if the returned fix was further away from the expected rate by the given amount.
*/
open class RatesFixProtocol(protected val tx: TransactionBuilder,
private val oracle: Party,
private val fixOf: FixOf,
private val expectedRate: BigDecimal,
private val rateTolerance: BigDecimal,
override val progressTracker: ProgressTracker = RatesFixProtocol.tracker(fixOf.name)) : ProtocolLogic<Unit>() {
companion object {
class QUERYING(val name: String) : ProgressTracker.Step("Querying oracle for $name interest rate")
object WORKING : ProgressTracker.Step("Working with data returned by oracle")
object SIGNING : ProgressTracker.Step("Requesting confirmation signature from interest rate oracle")
fun tracker(fixName: String) = ProgressTracker(QUERYING(fixName), WORKING, SIGNING)
}
class FixOutOfRange(@Suppress("unused") val byAmount: BigDecimal) : Exception("Fix out of range by $byAmount")
data class QueryRequest(val queries: List<FixOf>, val deadline: Instant)
data class SignRequest(val tx: WireTransaction)
@Suspendable
override fun call() {
progressTracker.currentStep = progressTracker.steps[1]
val fix = subProtocol(FixQueryProtocol(fixOf, oracle))
progressTracker.currentStep = WORKING
checkFixIsNearExpected(fix)
tx.addCommand(fix, oracle.owningKey)
beforeSigning(fix)
progressTracker.currentStep = SIGNING
val signature = subProtocol(FixSignProtocol(tx, oracle))
tx.addSignatureUnchecked(signature)
}
/**
* You can override this to perform any additional work needed after the fix is added to the transaction but
* before it's sent back to the oracle for signing (for example, adding output states that depend on the fix).
*/
@Suspendable
protected open fun beforeSigning(fix: Fix) {
}
private fun checkFixIsNearExpected(fix: Fix) {
val delta = (fix.value - expectedRate).abs()
if (delta > rateTolerance) {
// TODO: Kick to a user confirmation / ui flow if it's out of bounds instead of raising an exception.
throw FixOutOfRange(delta)
}
}
class FixQueryProtocol(val fixOf: FixOf, val oracle: Party) : ProtocolLogic<Fix>() {
@Suspendable
override fun call(): Fix {
val deadline = suggestInterestRateAnnouncementTimeWindow(fixOf.name, oracle.name, fixOf.forDay).end
// TODO: add deadline to receive
val resp = sendAndReceive<ArrayList<Fix>>(oracle, QueryRequest(listOf(fixOf), deadline))
return resp.unwrap {
val fix = it.first()
// Check the returned fix is for what we asked for.
check(fix.of == fixOf)
fix
}
}
}
class FixSignProtocol(val tx: TransactionBuilder, val oracle: Party) : ProtocolLogic<DigitalSignature.LegallyIdentifiable>() {
@Suspendable
override fun call(): DigitalSignature.LegallyIdentifiable {
val wtx = tx.toWireTransaction()
val resp = sendAndReceive<DigitalSignature.LegallyIdentifiable>(oracle, SignRequest(wtx))
return resp.unwrap { sig ->
check(sig.signer == oracle)
tx.checkSignature(sig)
sig
}
}
}
}

View File

@ -306,126 +306,4 @@ object TwoPartyDealProtocol {
}
}
/**
* One side of the fixing protocol for an interest rate swap, but could easily be generalised further.
*
* Do not infer too much from the name of the class. This is just to indicate that it is the "side"
* of the protocol that is run by the party with the fixed leg of swap deal, which is the basis for deciding
* who does what in the protocol.
*/
class Fixer(override val otherParty: Party,
override val progressTracker: ProgressTracker = Secondary.tracker()) : Secondary<FixingSession>() {
private lateinit var txState: TransactionState<*>
private lateinit var deal: FixableDealState
override fun validateHandshake(handshake: Handshake<FixingSession>): Handshake<FixingSession> {
logger.trace { "Got fixing request for: ${handshake.payload}" }
txState = serviceHub.loadState(handshake.payload.ref)
deal = txState.data as FixableDealState
// validate the party that initiated is the one on the deal and that the recipient corresponds with it.
// TODO: this is in no way secure and will be replaced by general session initiation logic in the future
val myName = serviceHub.myInfo.legalIdentity.name
// Also check we are one of the parties
deal.parties.filter { it.name == myName }.single()
return handshake
}
@Suspendable
override fun assembleSharedTX(handshake: Handshake<FixingSession>): Pair<TransactionBuilder, List<PublicKey>> {
@Suppress("UNCHECKED_CAST")
val fixOf = deal.nextFixingOf()!!
// TODO Do we need/want to substitute in new public keys for the Parties?
val myName = serviceHub.myInfo.legalIdentity.name
val myOldParty = deal.parties.single { it.name == myName }
val newDeal = deal
val ptx = TransactionType.General.Builder(txState.notary)
val oracle = serviceHub.networkMapCache.get(handshake.payload.oracleType).first()
val addFixing = object : RatesFixProtocol(ptx, oracle.serviceIdentities(handshake.payload.oracleType).first(), fixOf, BigDecimal.ZERO, BigDecimal.ONE) {
@Suspendable
override fun beforeSigning(fix: Fix) {
newDeal.generateFix(ptx, StateAndRef(txState, handshake.payload.ref), fix)
// And add a request for timestamping: it may be that none of the contracts need this! But it can't hurt
// to have one.
ptx.setTime(serviceHub.clock.instant(), 30.seconds)
}
}
subProtocol(addFixing)
return Pair(ptx, arrayListOf(myOldParty.owningKey))
}
}
/**
* One side of the fixing protocol for an interest rate swap, but could easily be generalised furher.
*
* As per the [Fixer], do not infer too much from this class name in terms of business roles. This
* is just the "side" of the protocol run by the party with the floating leg as a way of deciding who
* does what in the protocol.
*/
class Floater(override val otherParty: Party,
override val payload: FixingSession,
override val progressTracker: ProgressTracker = Primary.tracker()) : Primary() {
@Suppress("UNCHECKED_CAST")
internal val dealToFix: StateAndRef<FixableDealState> by TransientProperty {
val state = serviceHub.loadState(payload.ref) as TransactionState<FixableDealState>
StateAndRef(state, payload.ref)
}
override val myKeyPair: KeyPair get() {
val myName = serviceHub.myInfo.legalIdentity.name
val publicKey = dealToFix.state.data.parties.filter { it.name == myName }.single().owningKey
return serviceHub.keyManagementService.toKeyPair(publicKey)
}
override val notaryNode: NodeInfo get() =
serviceHub.networkMapCache.notaryNodes.filter { it.notaryIdentity == dealToFix.state.notary }.single()
}
/** Used to set up the session between [Floater] and [Fixer] */
data class FixingSession(val ref: StateRef, val oracleType: ServiceType)
/**
* This protocol looks at the deal and decides whether to be the Fixer or Floater role in agreeing a fixing.
*
* It is kicked off as an activity on both participant nodes by the scheduler when it's time for a fixing. If the
* Fixer role is chosen, then that will be initiated by the [FixingSession] message sent from the other party and
* handled by the [FixingSessionInitiationHandler].
*
* TODO: Replace [FixingSession] and [FixingSessionInitiationHandler] with generic session initiation logic once it exists.
*/
class FixingRoleDecider(val ref: StateRef,
override val progressTracker: ProgressTracker = tracker()) : ProtocolLogic<Unit>() {
companion object {
class LOADING() : ProgressTracker.Step("Loading state to decide fixing role")
fun tracker() = ProgressTracker(LOADING())
}
@Suspendable
override fun call(): Unit {
progressTracker.nextStep()
val dealToFix = serviceHub.loadState(ref)
// TODO: this is not the eventual mechanism for identifying the parties
val fixableDeal = (dealToFix.data as FixableDealState)
val sortedParties = fixableDeal.parties.sortedBy { it.name }
if (sortedParties[0].name == serviceHub.myInfo.legalIdentity.name) {
val fixing = FixingSession(ref, fixableDeal.oracleType)
// Start the Floater which will then kick-off the Fixer
subProtocol(Floater(sortedParties[1], fixing))
}
}
}
}

View File

@ -0,0 +1,24 @@
package com.r3corda.core
import kotlin.test.assertFalse
import kotlin.test.assertTrue
class UtilsTest {
fun `ordered and unique basic`() {
val basic = listOf(1, 2, 3, 5, 8)
assertTrue(basic.isOrderedAndUnique { this })
val negative = listOf(-1, 2, 5)
assertTrue(negative.isOrderedAndUnique { this })
}
fun `ordered and unique duplicate`() {
val duplicated = listOf(1, 2, 2, 3, 5, 8)
assertFalse(duplicated.isOrderedAndUnique { this })
}
fun `ordered and unique out of sequence`() {
val mixed = listOf(3, 1, 2, 8, 5)
assertFalse(mixed.isOrderedAndUnique { this })
}
}

View File

@ -3,6 +3,7 @@ package com.r3corda.docs
import com.google.common.net.HostAndPort
import com.r3corda.client.CordaRPCClient
import com.r3corda.core.transactions.SignedTransaction
import com.r3corda.node.services.config.NodeSSLConfiguration
import org.graphstream.graph.Edge
import org.graphstream.graph.Node
import org.graphstream.graph.implementations.SingleGraph
@ -26,12 +27,18 @@ fun main(args: Array<String>) {
}
val nodeAddress = HostAndPort.fromString(args[0])
val printOrVisualise = PrintOrVisualise.valueOf(args[1])
val certificatesPath = Paths.get("build/trader-demo/buyer/certificates")
val sslConfig = object : NodeSSLConfiguration {
override val certificatesPath = Paths.get("build/trader-demo/buyer/certificates")
override val keyStorePassword = "cordacadevpass"
override val trustStorePassword = "trustpass"
}
// END 1
// START 2
val client = CordaRPCClient(nodeAddress, certificatesPath)
client.start()
val username = System.console().readLine("Enter username: ")
val password = String(System.console().readPassword("Enter password: "))
val client = CordaRPCClient(nodeAddress, sslConfig)
client.start(username, password)
val proxy = client.proxy()
// END 2
@ -65,7 +72,7 @@ fun main(args: Array<String>) {
futureTransactions.subscribe { transaction ->
graph.addNode<Node>("${transaction.id}")
transaction.tx.inputs.forEach { ref ->
graph.addEdge<Edge>("${ref}", "${ref.txhash}", "${transaction.id}")
graph.addEdge<Edge>("$ref", "${ref.txhash}", "${transaction.id}")
}
}
graph.display()

View File

@ -16,7 +16,7 @@ we also need to access the certificates of the node, we will access the node's `
:start-after: START 1
:end-before: END 1
Now we can connect to the node itself:
Now we can connect to the node itself using a valid RPC login. By default the user `user1` is available with password `test`.
.. literalinclude:: example-code/src/main/kotlin/com/r3corda/docs/ClientRpcTutorial.kt
:language: kotlin

View File

@ -6,7 +6,7 @@ import com.r3corda.core.node.services.ServiceInfo
import com.r3corda.explorer.model.IdentityModel
import com.r3corda.node.driver.PortAllocation
import com.r3corda.node.driver.driver
import com.r3corda.node.driver.startClient
import com.r3corda.node.services.config.configureTestSSL
import com.r3corda.node.services.transactions.SimpleNotaryService
import javafx.stage.Stage
import tornadofx.App
@ -32,14 +32,12 @@ class Main : App() {
val aliceNodeFuture = startNode("Alice")
val notaryNodeFuture = startNode("Notary", advertisedServices = setOf(ServiceInfo(SimpleNotaryService.type)))
val aliceNode = aliceNodeFuture.get()
val notaryNode = notaryNodeFuture.get()
val aliceClient = startClient(aliceNode).get()
val aliceNode = aliceNodeFuture.get().nodeInfo
val notaryNode = notaryNodeFuture.get().nodeInfo
Models.get<IdentityModel>(Main::class).notary.set(notaryNode.notaryIdentity)
Models.get<IdentityModel>(Main::class).myIdentity.set(aliceNode.legalIdentity)
Models.get<NodeMonitorModel>(Main::class).register(aliceNode, aliceClient.config.certificatesPath)
Models.get<NodeMonitorModel>(Main::class).register(aliceNode, configureTestSSL(), "user1", "test")
startNode("Bob").get()

View File

@ -134,6 +134,9 @@ dependencies {
compile "org.hibernate:hibernate-core:5.2.2.Final"
compile "org.hibernate:hibernate-java8:5.2.2.Final"
// Capsule is a library for building independently executable fat JARs.
compile 'co.paralleluniverse:capsule:1.0.3'
// Integration test helpers
integrationTestCompile 'junit:junit:4.12'

View File

@ -11,14 +11,8 @@ import org.junit.Test
class DriverTests {
companion object {
fun nodeMustBeUp(networkMapCache: NetworkMapCache, nodeInfo: NodeInfo, nodeName: String) {
fun nodeMustBeUp(nodeInfo: NodeInfo, nodeName: String) {
val hostAndPort = ArtemisMessagingComponent.toHostAndPort(nodeInfo.address)
// Check that the node is registered in the network map
poll("network map cache for $nodeName") {
networkMapCache.get().firstOrNull {
it.legalIdentity.name == nodeName
}
}
// Check that the port is bound
addressMustBeBound(hostAndPort)
}
@ -36,31 +30,31 @@ class DriverTests {
val notary = startNode("TestNotary", setOf(ServiceInfo(SimpleNotaryService.type)))
val regulator = startNode("Regulator", setOf(ServiceInfo(RegulatorService.type)))
nodeMustBeUp(networkMapCache, notary.get(), "TestNotary")
nodeMustBeUp(networkMapCache, regulator.get(), "Regulator")
nodeMustBeUp(notary.get().nodeInfo, "TestNotary")
nodeMustBeUp(regulator.get().nodeInfo, "Regulator")
Pair(notary.get(), regulator.get())
}
nodeMustBeDown(notary)
nodeMustBeDown(regulator)
nodeMustBeDown(notary.nodeInfo)
nodeMustBeDown(regulator.nodeInfo)
}
@Test
fun startingNodeWithNoServicesWorks() {
val noService = driver {
val noService = startNode("NoService")
nodeMustBeUp(networkMapCache, noService.get(), "NoService")
nodeMustBeUp(noService.get().nodeInfo, "NoService")
noService.get()
}
nodeMustBeDown(noService)
nodeMustBeDown(noService.nodeInfo)
}
@Test
fun randomFreePortAllocationWorks() {
val nodeInfo = driver(portAllocation = PortAllocation.RandomFree()) {
val nodeInfo = startNode("NoService")
nodeMustBeUp(networkMapCache, nodeInfo.get(), "NoService")
nodeMustBeUp(nodeInfo.get().nodeInfo, "NoService")
nodeInfo.get()
}
nodeMustBeDown(nodeInfo)
nodeMustBeDown(nodeInfo.nodeInfo)
}
}

View File

@ -4,6 +4,7 @@ import com.r3corda.core.contracts.*
import com.r3corda.node.api.StatesQuery
import com.r3corda.core.crypto.DigitalSignature
import com.r3corda.core.crypto.SecureHash
import com.r3corda.core.node.NodeInfo
import com.r3corda.core.serialization.SerializedBytes
import com.r3corda.core.transactions.SignedTransaction
import com.r3corda.core.transactions.WireTransaction
@ -41,6 +42,16 @@ interface APIServer {
@Produces(MediaType.TEXT_PLAIN)
fun status(): Response
/**
* Report this node's configuration and identities.
* Currently tunnels the NodeInfo as an encoding of the Kryo serialised form.
* TODO this functionality should be available via the RPC
*/
@GET
@Path("info")
@Produces(MediaType.APPLICATION_JSON)
fun info(): NodeInfo
/**
* Query your "local" states (containing only outputs involving you) and return the hashes & indexes associated with them
* to probably be later inflated by fetchLedgerTransactions() or fetchStates() although because immutable you can cache them

View File

@ -1,5 +1,7 @@
package com.r3corda.node.driver
import com.fasterxml.jackson.databind.ObjectMapper
import com.fasterxml.jackson.databind.module.SimpleModule
import com.google.common.net.HostAndPort
import com.r3corda.core.ThreadBox
import com.r3corda.core.crypto.Party
@ -7,6 +9,7 @@ import com.r3corda.core.crypto.generateKeyPair
import com.r3corda.core.node.NodeInfo
import com.r3corda.core.node.services.NetworkMapCache
import com.r3corda.core.node.services.ServiceInfo
import com.r3corda.node.serialization.NodeClock
import com.r3corda.node.services.config.ConfigHelper
import com.r3corda.node.services.config.FullNodeConfiguration
import com.r3corda.node.services.messaging.ArtemisMessagingComponent
@ -14,22 +17,24 @@ import com.r3corda.node.services.messaging.ArtemisMessagingServer
import com.r3corda.node.services.messaging.NodeMessagingClient
import com.r3corda.node.services.network.InMemoryNetworkMapCache
import com.r3corda.node.services.network.NetworkMapService
import com.r3corda.node.utilities.AffinityExecutor
import com.r3corda.node.utilities.JsonSupport
import com.typesafe.config.Config
import com.typesafe.config.ConfigRenderOptions
import org.slf4j.Logger
import org.slf4j.LoggerFactory
import java.io.File
import java.io.InputStreamReader
import java.net.*
import java.nio.file.Path
import java.nio.file.Paths
import java.text.SimpleDateFormat
import java.time.Clock
import java.util.*
import java.util.concurrent.*
import kotlin.concurrent.thread
/**
* This file defines a small "Driver" DSL for starting up nodes.
* This file defines a small "Driver" DSL for starting up nodes that is only intended for development, demos and tests.
*
* The process the driver is run in behaves as an Artemis client and starts up other processes. Namely it first
* bootstraps a network map service to allow the specified nodes to connect to, then starts up the actual nodes.
@ -54,38 +59,18 @@ interface DriverDSLExposedInterface {
* @param advertisedServices The set of services to be advertised by the node. Defaults to empty set.
* @return The [NodeInfo] of the started up node retrieved from the network map service.
*/
fun startNode(providedName: String? = null, advertisedServices: Set<ServiceInfo> = setOf()): Future<NodeInfo>
fun startNode(providedName: String? = null, advertisedServices: Set<ServiceInfo> = setOf()): Future<NodeInfoAndConfig>
/**
* Starts an [NodeMessagingClient].
*
* @param providedName name of the client, which will be used for creating its directory.
* @param serverAddress the artemis server to connect to, for example a [Node].
*/
fun startClient(providedName: String, serverAddress: HostAndPort): Future<NodeMessagingClient>
/**
* Starts a local [ArtemisMessagingServer] of which there may only be one.
*/
fun startLocalServer(): Future<ArtemisMessagingServer>
fun waitForAllNodesToFinish()
val networkMapCache: NetworkMapCache
}
fun DriverDSLExposedInterface.startClient(localServer: ArtemisMessagingServer) =
startClient("driver-local-server-client", localServer.myHostPort)
fun DriverDSLExposedInterface.startClient(remoteNodeInfo: NodeInfo, providedName: String? = null) =
startClient(
providedName = providedName ?: "${remoteNodeInfo.legalIdentity.name}-client",
serverAddress = ArtemisMessagingComponent.toHostAndPort(remoteNodeInfo.address)
)
interface DriverDSLInternalInterface : DriverDSLExposedInterface {
fun start()
fun shutdown()
}
data class NodeInfoAndConfig(val nodeInfo: NodeInfo, val config: Config)
sealed class PortAllocation {
abstract fun nextPort(): Int
fun nextHostAndPort(): HostAndPort = HostAndPort.fromParts("localhost", nextPort())
@ -122,6 +107,7 @@ sealed class PortAllocation {
* and may be specified in [DriverDSL.startNode].
* @param portAllocation The port allocation strategy to use for the messaging and the web server addresses. Defaults to incremental.
* @param debugPortAllocation The port allocation strategy to use for jvm debugging. Defaults to incremental.
* @param useTestClock If true the test clock will be used in Node.
* @param isDebug Indicates whether the spawned nodes should start in jdwt debug mode.
* @param dsl The dsl itself.
* @return The value returned in the [dsl] closure.
@ -130,6 +116,7 @@ fun <A> driver(
baseDirectory: String = "build/${getTimestampAsDirectoryName()}",
portAllocation: PortAllocation = PortAllocation.Incremental(10000),
debugPortAllocation: PortAllocation = PortAllocation.Incremental(5005),
useTestClock: Boolean = false,
isDebug: Boolean = false,
dsl: DriverDSLExposedInterface.() -> A
) = genericDriver(
@ -137,6 +124,7 @@ fun <A> driver(
portAllocation = portAllocation,
debugPortAllocation = debugPortAllocation,
baseDirectory = baseDirectory,
useTestClock = useTestClock,
isDebug = isDebug
),
coerce = { it },
@ -216,17 +204,15 @@ fun <A> poll(pollName: String, pollIntervalMs: Long = 500, warnCount: Int = 120,
return result
}
class DriverDSL(
open class DriverDSL(
val portAllocation: PortAllocation,
val debugPortAllocation: PortAllocation,
val baseDirectory: String,
val useTestClock: Boolean,
val isDebug: Boolean
) : DriverDSLInternalInterface {
override val networkMapCache = InMemoryNetworkMapCache()
private val networkMapName = "NetworkMapService"
private val networkMapAddress = portAllocation.nextHostAndPort()
private var networkMapNodeInfo: NodeInfo? = null
private val identity = generateKeyPair()
class State {
val registeredProcesses = LinkedList<Process>()
@ -284,7 +270,26 @@ class DriverDSL(
addressMustNotBeBound(networkMapAddress)
}
override fun startNode(providedName: String?, advertisedServices: Set<ServiceInfo>): Future<NodeInfo> {
private fun queryNodeInfo(webAddress: HostAndPort): NodeInfo? {
val url = URL("http://${webAddress.toString()}/api/info")
try {
val conn = url.openConnection() as HttpURLConnection
conn.requestMethod = "GET"
if (conn.responseCode != 200) {
return null
}
// For now the NodeInfo is tunneled in its Kryo format over the Node's Web interface.
val om = ObjectMapper()
val module = SimpleModule("NodeInfo")
module.addDeserializer(NodeInfo::class.java, JsonSupport.NodeInfoDeserializer)
om.registerModule(module)
return om.readValue(conn.inputStream, NodeInfo::class.java)
} catch(e: Exception) {
return null
}
}
override fun startNode(providedName: String?, advertisedServices: Set<ServiceInfo>): Future<NodeInfoAndConfig> {
val messagingAddress = portAllocation.nextHostAndPort()
val apiAddress = portAllocation.nextHostAndPort()
val debugPort = if (isDebug) debugPortAllocation.nextPort() else null
@ -301,94 +306,19 @@ class DriverDSL(
"artemisAddress" to messagingAddress.toString(),
"webAddress" to apiAddress.toString(),
"extraAdvertisedServiceIds" to advertisedServices.joinToString(","),
"networkMapAddress" to networkMapAddress.toString()
"networkMapAddress" to networkMapAddress.toString(),
"useTestClock" to useTestClock
)
)
return Executors.newSingleThreadExecutor().submit(Callable<NodeInfo> {
return Executors.newSingleThreadExecutor().submit(Callable<NodeInfoAndConfig> {
registerProcess(DriverDSL.startNode(config, quasarJarPath, debugPort))
poll("network map cache for $name") {
networkMapCache.partyNodes.forEach {
if (it.legalIdentity.name == name) {
return@poll it
}
}
null
}
NodeInfoAndConfig(queryNodeInfo(apiAddress)!!, config)
})
}
override fun startClient(
providedName: String,
serverAddress: HostAndPort
): Future<NodeMessagingClient> {
val nodeConfiguration = FullNodeConfiguration(
ConfigHelper.loadConfig(
baseDirectoryPath = Paths.get(baseDirectory, providedName),
allowMissingConfig = true,
configOverrides = mapOf(
"myLegalName" to providedName
)
)
)
val client = NodeMessagingClient(nodeConfiguration,
serverHostPort = serverAddress,
myIdentity = identity.public,
executor = AffinityExecutor.ServiceAffinityExecutor(providedName, 1),
persistentInbox = false // Do not create a permanent queue for our transient UI identity
)
return Executors.newSingleThreadExecutor().submit(Callable<NodeMessagingClient> {
client.configureWithDevSSLCertificate()
client.start(null)
thread { client.run() }
state.locked {
clients.add(client)
}
client
})
}
override fun startLocalServer(): Future<ArtemisMessagingServer> {
val name = "driver-local-server"
val config = FullNodeConfiguration(
ConfigHelper.loadConfig(
baseDirectoryPath = Paths.get(baseDirectory, name),
allowMissingConfig = true,
configOverrides = mapOf(
"myLegalName" to name
)
)
)
val server = ArtemisMessagingServer(config,
portAllocation.nextHostAndPort(),
networkMapCache
)
return Executors.newSingleThreadExecutor().submit(Callable<ArtemisMessagingServer> {
server.configureWithDevSSLCertificate()
server.start()
state.locked {
localServer = server
}
server
})
}
override fun start() {
startNetworkMapService()
val networkMapClient = startClient("driver-$networkMapName-client", networkMapAddress).get()
val networkMapAddr = NodeMessagingClient.makeNetworkMapAddress(networkMapAddress)
networkMapCache.addMapService(networkMapClient, networkMapAddr, true)
networkMapNodeInfo = poll("network map cache for $networkMapName") {
networkMapCache.partyNodes.forEach {
if (it.legalIdentity.name == networkMapName) {
return@poll it
}
}
null
}
}
private fun startNetworkMapService() {
@ -396,7 +326,6 @@ class DriverDSL(
val debugPort = if (isDebug) debugPortAllocation.nextPort() else null
val nodeDirectory = "$baseDirectory/$networkMapName"
val config = ConfigHelper.loadConfig(
baseDirectoryPath = Paths.get(nodeDirectory),
allowMissingConfig = true,
@ -405,7 +334,8 @@ class DriverDSL(
"basedir" to Paths.get(nodeDirectory).normalize().toString(),
"artemisAddress" to networkMapAddress.toString(),
"webAddress" to apiAddress.toString(),
"extraAdvertisedServiceIds" to ""
"extraAdvertisedServiceIds" to "",
"useTestClock" to useTestClock
)
)

View File

@ -24,6 +24,8 @@ class APIServerImpl(val node: AbstractNode) : APIServer {
}
}
override fun info() = node.services.myInfo
override fun queryStates(query: StatesQuery): List<StateRef> {
// We're going to hard code two options here for now and assume that all LinearStates are deals
// Would like to maybe move to a model where we take something like a JEXL string, although don't want to develop

View File

@ -209,9 +209,8 @@ abstract class AbstractNode(open val configuration: NodeConfiguration, val netwo
// the identity key. But the infrastructure to make that easy isn't here yet.
keyManagement = makeKeyManagementService()
api = APIServerImpl(this@AbstractNode)
scheduler = NodeSchedulerService(database, services)
protocolLogicFactory = initialiseProtocolLogicFactory()
scheduler = NodeSchedulerService(database, services, protocolLogicFactory)
val tokenizableServices = mutableListOf(storage, net, vault, keyManagement, identity, platformClock, scheduler)
@ -435,15 +434,11 @@ abstract class AbstractNode(open val configuration: NodeConfiguration, val netwo
protected abstract fun makeMessagingService(): MessagingServiceInternal
protected abstract fun startMessagingService(cordaRPCOps: CordaRPCOps?)
protected open fun initialiseCheckpointService(dir: Path): CheckpointStorage {
return DBCheckpointStorage()
}
protected abstract fun startMessagingService(cordaRPCOps: CordaRPCOps)
protected open fun initialiseStorageService(dir: Path): Pair<TxWritableStorageService, CheckpointStorage> {
val attachments = makeAttachmentStorage(dir)
val checkpointStorage = initialiseCheckpointService(dir)
val checkpointStorage = DBCheckpointStorage()
val transactionStorage = DBTransactionStorage()
_servicesThatAcceptUploads += attachments
// Populate the partyKeys set.

View File

@ -4,6 +4,7 @@ import com.codahale.metrics.JmxReporter
import com.r3corda.core.messaging.SingleMessageRecipient
import com.r3corda.core.node.ServiceHub
import com.r3corda.core.node.services.ServiceInfo
import com.r3corda.core.then
import com.r3corda.core.utilities.loggerFor
import com.r3corda.node.serialization.NodeClock
import com.r3corda.node.services.api.MessagingServiceInternal
@ -119,11 +120,10 @@ class Node(override val configuration: FullNodeConfiguration, networkMapAddress:
}
val legalIdentity = obtainLegalIdentity()
val myIdentityOrNullIfNetworkMapService = if (networkMapService != null) legalIdentity.owningKey else null
return NodeMessagingClient(configuration, serverAddr, myIdentityOrNullIfNetworkMapService, serverThread,
persistenceTx = { body: () -> Unit -> databaseTransaction(database) { body() } })
return NodeMessagingClient(configuration, serverAddr, myIdentityOrNullIfNetworkMapService, serverThread, database)
}
override fun startMessagingService(cordaRPCOps: CordaRPCOps?) {
override fun startMessagingService(cordaRPCOps: CordaRPCOps) {
// Start up the embedded MQ server
messageBroker?.apply {
runOnStop += Runnable { messageBroker?.stop() }
@ -268,23 +268,25 @@ class Node(override val configuration: FullNodeConfiguration, networkMapAddress:
override fun start(): Node {
alreadyRunningNodeCheck()
super.start()
webServer = initWebServer()
// Begin exporting our own metrics via JMX.
JmxReporter.
forRegistry(services.monitoringService.metrics).
inDomain("com.r3cev.corda").
createsObjectNamesWith { type, domain, name ->
// Make the JMX hierarchy a bit better organised.
val category = name.substringBefore('.')
val subName = name.substringAfter('.', "")
if (subName == "")
ObjectName("$domain:name=$category")
else
ObjectName("$domain:type=$category,name=$subName")
}.
build().
start()
// Only start the service API requests once the network map registration is complete
networkMapRegistrationFuture.then {
webServer = initWebServer()
// Begin exporting our own metrics via JMX.
JmxReporter.
forRegistry(services.monitoringService.metrics).
inDomain("com.r3cev.corda").
createsObjectNamesWith { type, domain, name ->
// Make the JMX hierarchy a bit better organised.
val category = name.substringBefore('.')
val subName = name.substringAfter('.', "")
if (subName == "")
ObjectName("$domain:name=$category")
else
ObjectName("$domain:type=$category,name=$subName")
}.
build().
start()
}
shutdownThread = thread(start = false) {
stop()
}

View File

@ -1,25 +0,0 @@
package com.r3corda.node.services.clientapi
import com.r3corda.core.node.CordaPluginRegistry
import com.r3corda.core.serialization.SingletonSerializeAsToken
import com.r3corda.node.services.api.ServiceHubInternal
import com.r3corda.protocols.TwoPartyDealProtocol.Fixer
import com.r3corda.protocols.TwoPartyDealProtocol.Floater
/**
* This is a temporary handler required for establishing random sessionIDs for the [Fixer] and [Floater] as part of
* running scheduled fixings for the [InterestRateSwap] contract.
*
* TODO: This will be replaced with the symmetric session work
*/
object FixingSessionInitiation {
class Plugin: CordaPluginRegistry() {
override val servicePlugins: List<Class<*>> = listOf(Service::class.java)
}
class Service(services: ServiceHubInternal) : SingletonSerializeAsToken() {
init {
services.registerProtocolInitiator(Floater::class) { Fixer(it) }
}
}
}

View File

@ -2,6 +2,7 @@ package com.r3corda.node.services.config
import com.google.common.net.HostAndPort
import com.r3corda.core.crypto.X509Utilities
import com.r3corda.core.exists
import com.r3corda.core.utilities.loggerFor
import com.typesafe.config.Config
import com.typesafe.config.ConfigFactory
@ -89,14 +90,24 @@ fun Config.getProperties(path: String): Properties {
*/
fun NodeSSLConfiguration.configureWithDevSSLCertificate() {
Files.createDirectories(certificatesPath)
if (!Files.exists(trustStorePath)) {
if (!trustStorePath.exists()) {
Files.copy(javaClass.classLoader.getResourceAsStream("com/r3corda/node/internal/certificates/cordatruststore.jks"),
trustStorePath)
}
if (!Files.exists(keyStorePath)) {
if (!keyStorePath.exists()) {
val caKeyStore = X509Utilities.loadKeyStore(
javaClass.classLoader.getResourceAsStream("com/r3corda/node/internal/certificates/cordadevcakeys.jks"),
"cordacadevpass")
X509Utilities.createKeystoreForSSL(keyStorePath, keyStorePassword, keyStorePassword, caKeyStore, "cordacadevkeypass")
}
}
// TODO Move this to CoreTestUtils.kt once we can pry this from the explorer
fun configureTestSSL(): NodeSSLConfiguration = object : NodeSSLConfiguration {
override val certificatesPath = Files.createTempDirectory("certs")
override val keyStorePassword: String get() = "cordacadevpass"
override val trustStorePassword: String get() = "trustpass"
init {
configureWithDevSSLCertificate()
}
}

View File

@ -5,10 +5,13 @@ import com.r3corda.core.div
import com.r3corda.core.messaging.SingleMessageRecipient
import com.r3corda.core.node.services.ServiceInfo
import com.r3corda.node.internal.Node
import com.r3corda.node.serialization.NodeClock
import com.r3corda.node.services.messaging.NodeMessagingClient
import com.r3corda.node.services.network.NetworkMapService
import com.r3corda.node.utilities.TestClock
import com.typesafe.config.Config
import java.nio.file.Path
import java.time.Clock
import java.util.*
interface NodeSSLConfiguration {
@ -46,8 +49,12 @@ class FullNodeConfiguration(config: Config) : NodeConfiguration {
val webAddress: HostAndPort by config
val messagingServerAddress: HostAndPort? by config.getOrElse { null }
val extraAdvertisedServiceIds: String by config
val useTestClock: Boolean by config.getOrElse { false }
fun createNode(): Node {
// This is a sanity feature do not remove.
require(!useTestClock || devMode) { "Cannot use test clock outside of dev mode" }
val advertisedServices = mutableSetOf<ServiceInfo>()
if (!extraAdvertisedServiceIds.isNullOrEmpty()) {
for (serviceId in extraAdvertisedServiceIds.split(",")) {
@ -56,7 +63,7 @@ class FullNodeConfiguration(config: Config) : NodeConfiguration {
}
if (networkMapAddress == null) advertisedServices.add(ServiceInfo(NetworkMapService.type))
val networkMapMessageAddress: SingleMessageRecipient? = if (networkMapAddress == null) null else NodeMessagingClient.makeNetworkMapAddress(networkMapAddress!!)
return Node(this, networkMapMessageAddress, advertisedServices)
return Node(this, networkMapMessageAddress, advertisedServices, if(useTestClock == true) TestClock() else NodeClock())
}
}

View File

@ -45,7 +45,7 @@ import javax.annotation.concurrent.ThreadSafe
@ThreadSafe
class NodeSchedulerService(private val database: Database,
private val services: ServiceHubInternal,
private val protocolLogicRefFactory: ProtocolLogicRefFactory = ProtocolLogicRefFactory(),
private val protocolLogicRefFactory: ProtocolLogicRefFactory,
private val schedulerTimerExecutor: Executor = Executors.newSingleThreadExecutor())
: SchedulerService, SingletonSerializeAsToken() {

View File

@ -15,16 +15,15 @@ import org.apache.activemq.artemis.api.core.TransportConfiguration
import org.apache.activemq.artemis.core.remoting.impl.netty.NettyAcceptorFactory
import org.apache.activemq.artemis.core.remoting.impl.netty.NettyConnectorFactory
import org.apache.activemq.artemis.core.remoting.impl.netty.TransportConstants
import java.nio.file.FileSystems
import java.nio.file.Path
import java.security.KeyStore
import java.security.PublicKey
/**
* The base class for Artemis services that defines shared data structures and transport configuration
*
* @param certificatePath A place where Artemis can stash its message journal and other files.
* @param config The config object is used to pass in the passwords for the certificate KeyStore and TrustStore
*/
abstract class ArtemisMessagingComponent(val config: NodeSSLConfiguration) : SingletonSerializeAsToken() {
abstract class ArtemisMessagingComponent() : SingletonSerializeAsToken() {
companion object {
init {
@ -36,7 +35,7 @@ abstract class ArtemisMessagingComponent(val config: NodeSSLConfiguration) : Sin
const val RPC_REQUESTS_QUEUE = "rpc.requests"
@JvmStatic
protected val NETWORK_MAP_ADDRESS = SimpleString(PEERS_PREFIX +"networkmap")
protected val NETWORK_MAP_ADDRESS = SimpleString("${PEERS_PREFIX}networkmap")
/**
* Assuming the passed in target address is actually an ArtemisAddress will extract the host and port of the node. This should
@ -70,7 +69,7 @@ abstract class ArtemisMessagingComponent(val config: NodeSSLConfiguration) : Sin
}
protected data class NetworkMapAddress(override val hostAndPort: HostAndPort) : SingleMessageRecipient, ArtemisAddress {
override val queueName: SimpleString = NETWORK_MAP_ADDRESS
override val queueName: SimpleString get() = NETWORK_MAP_ADDRESS
}
/**
@ -80,12 +79,12 @@ abstract class ArtemisMessagingComponent(val config: NodeSSLConfiguration) : Sin
*/
data class NodeAddress(val identity: PublicKey, override val hostAndPort: HostAndPort) : SingleMessageRecipient, ArtemisAddress {
override val queueName: SimpleString by lazy { SimpleString(PEERS_PREFIX+identity.toBase58String()) }
override fun toString(): String {
return "NodeAddress(identity = $queueName, $hostAndPort"
}
override fun toString(): String = "${javaClass.simpleName}(identity = $queueName, $hostAndPort)"
}
/** The config object is used to pass in the passwords for the certificate KeyStore and TrustStore */
abstract val config: NodeSSLConfiguration
protected fun parseKeyFromQueueName(name: String): PublicKey {
require(name.startsWith(PEERS_PREFIX))
return parsePublicKeyBase58(name.substring(PEERS_PREFIX.length))
@ -119,39 +118,46 @@ abstract class ArtemisMessagingComponent(val config: NodeSSLConfiguration) : Sin
}
}
protected fun tcpTransport(direction: ConnectionDirection, host: String, port: Int) =
TransportConfiguration(
when (direction) {
ConnectionDirection.INBOUND -> NettyAcceptorFactory::class.java.name
ConnectionDirection.OUTBOUND -> NettyConnectorFactory::class.java.name
},
mapOf(
// Basic TCP target details
TransportConstants.HOST_PROP_NAME to host,
TransportConstants.PORT_PROP_NAME to port.toInt(),
protected fun tcpTransport(direction: ConnectionDirection, host: String, port: Int): TransportConfiguration {
config.keyStorePath.expectedOnDefaultFileSystem()
config.trustStorePath.expectedOnDefaultFileSystem()
return TransportConfiguration(
when (direction) {
ConnectionDirection.INBOUND -> NettyAcceptorFactory::class.java.name
ConnectionDirection.OUTBOUND -> NettyConnectorFactory::class.java.name
},
mapOf(
// Basic TCP target details
TransportConstants.HOST_PROP_NAME to host,
TransportConstants.PORT_PROP_NAME to port.toInt(),
// Turn on AMQP support, which needs the protocol jar on the classpath.
// Unfortunately we cannot disable core protocol as artemis only uses AMQP for interop
// It does not use AMQP messages for its own messages e.g. topology and heartbeats
// TODO further investigate how to ensure we use a well defined wire level protocol for Node to Node communications
TransportConstants.PROTOCOLS_PROP_NAME to "CORE,AMQP",
// Turn on AMQP support, which needs the protocol jar on the classpath.
// Unfortunately we cannot disable core protocol as artemis only uses AMQP for interop
// It does not use AMQP messages for its own messages e.g. topology and heartbeats
// TODO further investigate how to ensure we use a well defined wire level protocol for Node to Node communications
TransportConstants.PROTOCOLS_PROP_NAME to "CORE,AMQP",
// Enable TLS transport layer with client certs and restrict to at least SHA256 in handshake
// and AES encryption
TransportConstants.SSL_ENABLED_PROP_NAME to true,
TransportConstants.KEYSTORE_PROVIDER_PROP_NAME to "JKS",
TransportConstants.KEYSTORE_PATH_PROP_NAME to config.keyStorePath,
TransportConstants.KEYSTORE_PASSWORD_PROP_NAME to config.keyStorePassword, // TODO proper management of keystores and password
TransportConstants.TRUSTSTORE_PROVIDER_PROP_NAME to "JKS",
TransportConstants.TRUSTSTORE_PATH_PROP_NAME to config.trustStorePath,
TransportConstants.TRUSTSTORE_PASSWORD_PROP_NAME to config.trustStorePassword,
TransportConstants.ENABLED_CIPHER_SUITES_PROP_NAME to CIPHER_SUITES.joinToString(","),
TransportConstants.ENABLED_PROTOCOLS_PROP_NAME to "TLSv1.2",
TransportConstants.NEED_CLIENT_AUTH_PROP_NAME to true
)
)
// Enable TLS transport layer with client certs and restrict to at least SHA256 in handshake
// and AES encryption
TransportConstants.SSL_ENABLED_PROP_NAME to true,
TransportConstants.KEYSTORE_PROVIDER_PROP_NAME to "JKS",
TransportConstants.KEYSTORE_PATH_PROP_NAME to config.keyStorePath,
TransportConstants.KEYSTORE_PASSWORD_PROP_NAME to config.keyStorePassword, // TODO proper management of keystores and password
TransportConstants.TRUSTSTORE_PROVIDER_PROP_NAME to "JKS",
TransportConstants.TRUSTSTORE_PATH_PROP_NAME to config.trustStorePath,
TransportConstants.TRUSTSTORE_PASSWORD_PROP_NAME to config.trustStorePassword,
TransportConstants.ENABLED_CIPHER_SUITES_PROP_NAME to CIPHER_SUITES.joinToString(","),
TransportConstants.ENABLED_PROTOCOLS_PROP_NAME to "TLSv1.2",
TransportConstants.NEED_CLIENT_AUTH_PROP_NAME to true
)
)
}
fun configureWithDevSSLCertificate() {
config.configureWithDevSSLCertificate()
}
protected fun Path.expectedOnDefaultFileSystem() {
require(fileSystem == FileSystems.getDefault()) { "Artemis only uses the default file system" }
}
}

View File

@ -3,11 +3,15 @@ package com.r3corda.node.services.messaging
import com.google.common.net.HostAndPort
import com.r3corda.core.ThreadBox
import com.r3corda.core.crypto.AddressFormatException
import com.r3corda.core.crypto.newSecureRandom
import com.r3corda.core.div
import com.r3corda.core.exists
import com.r3corda.core.messaging.SingleMessageRecipient
import com.r3corda.core.node.services.NetworkMapCache
import com.r3corda.core.use
import com.r3corda.core.utilities.loggerFor
import com.r3corda.node.services.config.NodeConfiguration
import com.r3corda.node.services.messaging.ArtemisMessagingServer.NodeLoginModule.Companion.NODE_ROLE_NAME
import com.r3corda.node.services.messaging.ArtemisMessagingServer.NodeLoginModule.Companion.RPC_ROLE_NAME
import org.apache.activemq.artemis.api.core.SimpleString
import org.apache.activemq.artemis.core.config.BridgeConfiguration
import org.apache.activemq.artemis.core.config.Configuration
@ -17,11 +21,25 @@ import org.apache.activemq.artemis.core.security.Role
import org.apache.activemq.artemis.core.server.ActiveMQServer
import org.apache.activemq.artemis.core.server.impl.ActiveMQServerImpl
import org.apache.activemq.artemis.spi.core.security.ActiveMQJAASSecurityManager
import org.apache.activemq.artemis.spi.core.security.jaas.InVMLoginModule
import org.apache.activemq.artemis.spi.core.security.jaas.RolePrincipal
import org.apache.activemq.artemis.spi.core.security.jaas.UserPrincipal
import rx.Subscription
import java.math.BigInteger
import java.io.IOException
import java.nio.file.Files
import java.nio.file.Path
import java.security.Principal
import java.util.*
import javax.annotation.concurrent.ThreadSafe
import javax.security.auth.Subject
import javax.security.auth.callback.CallbackHandler
import javax.security.auth.callback.NameCallback
import javax.security.auth.callback.PasswordCallback
import javax.security.auth.callback.UnsupportedCallbackException
import javax.security.auth.login.AppConfigurationEntry
import javax.security.auth.login.AppConfigurationEntry.LoginModuleControlFlag.REQUIRED
import javax.security.auth.login.FailedLoginException
import javax.security.auth.login.LoginException
import javax.security.auth.spi.LoginModule
// TODO: Verify that nobody can connect to us and fiddle with our config over the socket due to the secman.
// TODO: Implement a discovery engine that can trigger builds of new connections when another node registers? (later)
@ -37,9 +55,9 @@ import javax.annotation.concurrent.ThreadSafe
* a fully connected network, trusted network or on localhost.
*/
@ThreadSafe
class ArtemisMessagingServer(config: NodeConfiguration,
class ArtemisMessagingServer(override val config: NodeConfiguration,
val myHostPort: HostAndPort,
val networkMapCache: NetworkMapCache) : ArtemisMessagingComponent(config) {
val networkMapCache: NetworkMapCache) : ArtemisMessagingComponent() {
companion object {
val log = loggerFor<ArtemisMessagingServer>()
}
@ -52,6 +70,10 @@ class ArtemisMessagingServer(config: NodeConfiguration,
private lateinit var activeMQServer: ActiveMQServer
private var networkChangeHandle: Subscription? = null
init {
config.basedir.expectedOnDefaultFileSystem()
}
fun start() = mutex.locked {
if (!running) {
configureAndStartServer()
@ -116,12 +138,7 @@ class ArtemisMessagingServer(config: NodeConfiguration,
}
private fun configureAndStartServer() {
val config = createArtemisConfig(config.certificatesPath, myHostPort).apply {
securityRoles = mapOf(
"#" to setOf(Role("internal", true, true, true, true, true, true, true))
)
}
val config = createArtemisConfig()
val securityManager = createArtemisSecurityManager()
activeMQServer = ActiveMQServerImpl(config, securityManager).apply {
@ -157,28 +174,61 @@ class ArtemisMessagingServer(config: NodeConfiguration,
activeMQServer.start()
}
private fun createArtemisConfig(directory: Path, hp: HostAndPort): Configuration {
val config = ConfigurationImpl()
setConfigDirectories(config, directory)
config.acceptorConfigurations = setOf(
tcpTransport(ConnectionDirection.INBOUND, "0.0.0.0", hp.port)
private fun createArtemisConfig(): Configuration = ConfigurationImpl().apply {
val artemisDir = config.basedir / "artemis"
bindingsDirectory = (artemisDir / "bindings").toString()
journalDirectory = (artemisDir / "journal").toString()
largeMessagesDirectory = (artemisDir / "largemessages").toString()
acceptorConfigurations = setOf(
tcpTransport(ConnectionDirection.INBOUND, "0.0.0.0", myHostPort.port)
)
// Enable built in message deduplication. Note we still have to do our own as the delayed commits
// and our own definition of commit mean that the built in deduplication cannot remove all duplicates.
config.idCacheSize = 2000 // Artemis Default duplicate cache size i.e. a guess
config.isPersistIDCache = true
return config
idCacheSize = 2000 // Artemis Default duplicate cache size i.e. a guess
isPersistIDCache = true
isPopulateValidatedUser = true
setupUserRoles()
}
// This gives nodes full access and RPC clients only enough to do RPC
private fun ConfigurationImpl.setupUserRoles() {
// TODO COR-307
val nodeRole = Role(NODE_ROLE_NAME, true, true, true, true, true, true, true, true)
val clientRpcRole = restrictedRole(RPC_ROLE_NAME, consume = true, createNonDurableQueue = true, deleteNonDurableQueue = true)
securityRoles = mapOf(
"#" to setOf(nodeRole),
"clients.*.rpc.responses.*" to setOf(nodeRole, clientRpcRole),
"clients.*.rpc.observations.*" to setOf(nodeRole, clientRpcRole),
RPC_REQUESTS_QUEUE to setOf(nodeRole, restrictedRole(RPC_ROLE_NAME, send = true))
)
}
private fun restrictedRole(name: String, send: Boolean = false, consume: Boolean = false, createDurableQueue: Boolean = false,
deleteDurableQueue: Boolean = false, createNonDurableQueue: Boolean = false,
deleteNonDurableQueue: Boolean = false, manage: Boolean = false, browse: Boolean = false): Role {
return Role(name, send, consume, createDurableQueue, deleteDurableQueue, createNonDurableQueue,
deleteNonDurableQueue, manage, browse)
}
private fun createArtemisSecurityManager(): ActiveMQJAASSecurityManager {
// TODO: set up proper security configuration https://r3-cev.atlassian.net/browse/COR-307
val securityConfig = SecurityConfiguration().apply {
addUser("internal", BigInteger(128, newSecureRandom()).toString(16))
addRole("internal", "internal")
defaultUser = "internal"
val rpcUsersFile = config.basedir / "rpc-users.properties"
if (!rpcUsersFile.exists()) {
val users = Properties()
users["user1"] = "test"
Files.newOutputStream(rpcUsersFile).use {
users.store(it, null)
}
}
return ActiveMQJAASSecurityManager(InVMLoginModule::class.java.name, securityConfig)
val securityConfig = object : SecurityConfiguration() {
// Override to make it work with our login module
override fun getAppConfigurationEntry(name: String): Array<AppConfigurationEntry> {
val options = mapOf(NodeLoginModule.FILE_KEY to rpcUsersFile)
return arrayOf(AppConfigurationEntry(name, REQUIRED, options))
}
}
return ActiveMQJAASSecurityManager(NodeLoginModule::class.java.name, securityConfig)
}
private fun connectorExists(hostAndPort: HostAndPort) = hostAndPort.toString() in activeMQServer.configuration.connectorConfigurations
@ -194,12 +244,11 @@ class ArtemisMessagingServer(config: NodeConfiguration,
private fun bridgeExists(name: SimpleString) = activeMQServer.clusterManager.bridges.containsKey(name.toString())
private fun deployBridge(hostAndPort: HostAndPort, name: SimpleString) {
private fun deployBridge(hostAndPort: HostAndPort, name: String) {
activeMQServer.deployBridge(BridgeConfiguration().apply {
val nameStr = name.toString()
setName(nameStr)
queueName = nameStr
forwardingAddress = nameStr
setName(name)
queueName = name
forwardingAddress = name
staticConnectors = listOf(hostAndPort.toString())
confirmationWindowSize = 100000 // a guess
isUseDuplicateDetection = true // Enable the bridges automatic deduplication logic
@ -218,7 +267,7 @@ class ArtemisMessagingServer(config: NodeConfiguration,
if (!connectorExists(hostAndPort))
addConnector(hostAndPort)
if (!bridgeExists(name))
deployBridge(hostAndPort, name)
deployBridge(hostAndPort, name.toString())
}
private fun maybeDestroyBridge(name: SimpleString) {
@ -227,11 +276,81 @@ class ArtemisMessagingServer(config: NodeConfiguration,
}
}
private fun setConfigDirectories(config: Configuration, dir: Path) {
config.apply {
bindingsDirectory = dir.resolve("bindings").toString()
journalDirectory = dir.resolve("journal").toString()
largeMessagesDirectory = dir.resolve("largemessages").toString()
class NodeLoginModule : LoginModule {
companion object {
const val FILE_KEY = "rpc-users-file"
const val NODE_ROLE_NAME = "NodeRole"
const val RPC_ROLE_NAME = "RpcRole"
}
private val users = Properties()
private var loginSucceeded: Boolean = false
private lateinit var subject: Subject
private lateinit var callbackHandler: CallbackHandler
private lateinit var principals: List<Principal>
override fun initialize(subject: Subject, callbackHandler: CallbackHandler, sharedState: Map<String, *>, options: Map<String, *>) {
this.subject = subject
this.callbackHandler = callbackHandler
val rpcUsersFile = options[FILE_KEY] as Path
if (rpcUsersFile.exists()) {
rpcUsersFile.use {
users.load(it)
}
}
}
override fun login(): Boolean {
val nameCallback = NameCallback("Username: ")
val passwordCallback = PasswordCallback("Password: ", false)
try {
callbackHandler.handle(arrayOf(nameCallback, passwordCallback))
} catch (e: IOException) {
throw LoginException(e.message)
} catch (e: UnsupportedCallbackException) {
throw LoginException("${e.message} not available to obtain information from user")
}
val username = nameCallback.name ?: throw FailedLoginException("User name is null")
val receivedPassword = passwordCallback.password ?: throw FailedLoginException("Password is null")
val password = if (username == "Node") "Node" else users[username] ?: throw FailedLoginException("User does not exist")
if (password != String(receivedPassword)) {
throw FailedLoginException("Password does not match")
}
principals = listOf(
UserPrincipal(username),
RolePrincipal(if (username == "Node") NODE_ROLE_NAME else RPC_ROLE_NAME))
loginSucceeded = true
return loginSucceeded
}
override fun commit(): Boolean {
val result = loginSucceeded
if (result) {
subject.principals.addAll(principals)
}
clear()
return result
}
override fun abort(): Boolean {
clear()
return true
}
override fun logout(): Boolean {
subject.principals.removeAll(principals)
return true
}
private fun clear() {
loginSucceeded = false
}
}
}

View File

@ -11,11 +11,13 @@ import com.r3corda.node.services.api.MessagingServiceInternal
import com.r3corda.node.services.config.NodeConfiguration
import com.r3corda.node.utilities.*
import org.apache.activemq.artemis.api.core.ActiveMQObjectClosedException
import org.apache.activemq.artemis.api.core.Message.HDR_DUPLICATE_DETECTION_ID
import org.apache.activemq.artemis.api.core.Message.HDR_VALIDATED_USER
import org.apache.activemq.artemis.api.core.SimpleString
import org.apache.activemq.artemis.api.core.client.*
import org.jetbrains.exposed.sql.Database
import org.jetbrains.exposed.sql.ResultRow
import org.jetbrains.exposed.sql.statements.InsertStatement
import java.nio.file.FileSystems
import java.security.PublicKey
import java.time.Instant
import java.util.*
@ -49,12 +51,11 @@ import javax.annotation.concurrent.ThreadSafe
* in this class.
*/
@ThreadSafe
class NodeMessagingClient(config: NodeConfiguration,
class NodeMessagingClient(override val config: NodeConfiguration,
val serverHostPort: HostAndPort,
val myIdentity: PublicKey?,
val executor: AffinityExecutor,
val persistentInbox: Boolean = true,
val persistenceTx: (() -> Unit) -> Unit = { it() }) : ArtemisMessagingComponent(config), MessagingServiceInternal {
val database: Database) : ArtemisMessagingComponent(), MessagingServiceInternal {
companion object {
val log = loggerFor<NodeMessagingClient>()
@ -86,8 +87,7 @@ class NodeMessagingClient(config: NodeConfiguration,
var rpcConsumer: ClientConsumer? = null
var rpcNotificationConsumer: ClientConsumer? = null
// TODO: This is not robust and needs to be replaced by more intelligently using the message queue server.
var undeliveredMessages = listOf<Message>()
var pendingRedelivery = JDBCHashSet<Message>("pending_messages",loadOnInit = true)
}
/** A registration to handle messages of different types */
@ -106,23 +106,16 @@ class NodeMessagingClient(config: NodeConfiguration,
val uuid = uuidString("message_id")
}
private val processedMessages: MutableSet<UUID> = Collections.synchronizedSet(if (persistentInbox) {
private val processedMessages: MutableSet<UUID> = Collections.synchronizedSet(
object : AbstractJDBCHashSet<UUID, Table>(Table, loadOnInit = true) {
override fun elementFromRow(row: ResultRow): UUID = row[table.uuid]
override fun addElementToInsert(insert: InsertStatement, entry: UUID, finalizables: MutableList<() -> Unit>) {
insert[table.uuid] = entry
}
}
} else {
HashSet<UUID>()
})
})
init {
require(config.basedir.fileSystem == FileSystems.getDefault()) { "Artemis only uses the default file system" }
}
fun start(rpcOps: CordaRPCOps? = null) {
fun start(rpcOps: CordaRPCOps) {
state.locked {
check(!started) { "start can't be called twice" }
started = true
@ -135,7 +128,7 @@ class NodeMessagingClient(config: NodeConfiguration,
// Create a session. Note that the acknowledgement of messages is not flushed to
// the Artermis journal until the default buffer size of 1MB is acknowledged.
val session = clientFactory!!.createSession(true, true, ActiveMQClient.DEFAULT_ACK_BATCH_SIZE)
val session = clientFactory!!.createSession("Node", "Node", false, true, true, locator.isPreAcknowledge, ActiveMQClient.DEFAULT_ACK_BATCH_SIZE)
this.session = session
session.start()
@ -146,7 +139,7 @@ class NodeMessagingClient(config: NodeConfiguration,
val queueName = toQueueName(myAddress)
val query = session.queueQuery(queueName)
if (!query.isExists) {
session.createQueue(queueName, queueName, persistentInbox)
session.createQueue(queueName, queueName, true)
}
knownQueues.add(queueName)
p2pConsumer = session.createConsumer(queueName)
@ -154,13 +147,11 @@ class NodeMessagingClient(config: NodeConfiguration,
// Create an RPC queue and consumer: this will service locally connected clients only (not via a
// bridge) and those clients must have authenticated. We could use a single consumer for everything
// and perhaps we should, but these queues are not worth persisting.
if (rpcOps != null) {
session.createTemporaryQueue(RPC_REQUESTS_QUEUE, RPC_REQUESTS_QUEUE)
session.createTemporaryQueue("activemq.notifications", "rpc.qremovals", "_AMQ_NotifType = 1")
rpcConsumer = session.createConsumer(RPC_REQUESTS_QUEUE)
rpcNotificationConsumer = session.createConsumer("rpc.qremovals")
dispatcher = createRPCDispatcher(state, rpcOps)
}
session.createTemporaryQueue(RPC_REQUESTS_QUEUE, RPC_REQUESTS_QUEUE)
session.createTemporaryQueue("activemq.notifications", "rpc.qremovals", "_AMQ_NotifType = 1")
rpcConsumer = session.createConsumer(RPC_REQUESTS_QUEUE)
rpcNotificationConsumer = session.createConsumer("rpc.qremovals")
dispatcher = createRPCDispatcher(state, rpcOps)
}
}
@ -227,8 +218,9 @@ class NodeMessagingClient(config: NodeConfiguration,
val topic = message.getStringProperty(TOPIC_PROPERTY)
val sessionID = message.getLongProperty(SESSION_ID_PROPERTY)
// Use the magic deduplication property built into Artemis as our message identity too
val uuid = UUID.fromString(message.getStringProperty(org.apache.activemq.artemis.api.core.Message.HDR_DUPLICATE_DETECTION_ID))
log.info("received message from: ${message.address} topic: $topic sessionID: $sessionID uuid: $uuid")
val uuid = UUID.fromString(message.getStringProperty(HDR_DUPLICATE_DETECTION_ID))
val user = message.getStringProperty(HDR_VALIDATED_USER)
log.info("Received message from: ${message.address} user: $user topic: $topic sessionID: $sessionID uuid: $uuid")
val body = ByteArray(message.bodySize).apply { message.bodyBuffer.readBytes(this) }
@ -259,10 +251,10 @@ class NodeMessagingClient(config: NodeConfiguration,
// without causing log spam.
log.warn("Received message for ${msg.topicSession} that doesn't have any registered handlers yet")
// This is a hack; transient messages held in memory isn't crash resistant.
// TODO: Use Artemis API more effectively so we don't pop messages off a queue that we aren't ready to use.
state.locked {
undeliveredMessages += msg
databaseTransaction(database) {
pendingRedelivery.add(msg)
}
}
return false
}
@ -277,7 +269,7 @@ class NodeMessagingClient(config: NodeConfiguration,
// Note that handlers may re-enter this class. We aren't holding any locks and methods like
// start/run/stop have re-entrancy assertions at the top, so it is OK.
executor.fetchFrom {
persistenceTx {
databaseTransaction(database) {
callHandlers(msg, deliverTo)
}
}
@ -346,7 +338,7 @@ class NodeMessagingClient(config: NodeConfiguration,
putLongProperty(SESSION_ID_PROPERTY, sessionID)
writeBodyBufferBytes(message.data)
// Use the magic deduplication property built into Artemis as our message identity too
putStringProperty(org.apache.activemq.artemis.api.core.Message.HDR_DUPLICATE_DETECTION_ID, SimpleString(UUID.randomUUID().toString()))
putStringProperty(HDR_DUPLICATE_DETECTION_ID, SimpleString(UUID.randomUUID().toString()))
}
if (knownQueues.add(queueName)) {
@ -376,9 +368,12 @@ class NodeMessagingClient(config: NodeConfiguration,
val handler = Handler(topicSession, callback)
handlers.add(handler)
val messagesToRedeliver = state.locked {
val messagesToRedeliver = undeliveredMessages
undeliveredMessages = listOf()
messagesToRedeliver
val pending = ArrayList<Message>()
databaseTransaction(database) {
pending.addAll(pendingRedelivery)
pendingRedelivery.clear()
}
pending
}
messagesToRedeliver.forEach { deliver(it) }
return handler
@ -391,8 +386,8 @@ class NodeMessagingClient(config: NodeConfiguration,
override fun createMessage(topicSession: TopicSession, data: ByteArray, uuid: UUID): Message {
// TODO: We could write an object that proxies directly to an underlying MQ message here and avoid copying.
return object : Message {
override val topicSession: TopicSession get() = topicSession
override val data: ByteArray get() = data
override val topicSession: TopicSession = topicSession
override val data: ByteArray = data
override val debugTimestamp: Instant = Instant.now()
override fun serialise(): ByteArray = this.serialise()
override val uniqueMessageId: UUID = uuid
@ -408,7 +403,7 @@ class NodeMessagingClient(config: NodeConfiguration,
val msg = session!!.createMessage(false).apply {
writeBodyBufferBytes(bits.bits)
// Use the magic deduplication property built into Artemis as our message identity too
putStringProperty(org.apache.activemq.artemis.api.core.Message.HDR_DUPLICATE_DETECTION_ID, SimpleString(UUID.randomUUID().toString()))
putStringProperty(HDR_DUPLICATE_DETECTION_ID, SimpleString(UUID.randomUUID().toString()))
}
producer!!.send(toAddress, msg)
}

View File

@ -1,73 +0,0 @@
package com.r3corda.node.services.persistence
import com.r3corda.core.serialization.SerializedBytes
import com.r3corda.core.serialization.deserialize
import com.r3corda.core.serialization.serialize
import com.r3corda.core.utilities.loggerFor
import com.r3corda.core.utilities.trace
import com.r3corda.node.services.api.Checkpoint
import com.r3corda.node.services.api.CheckpointStorage
import java.nio.file.Files
import java.nio.file.Path
import java.nio.file.StandardCopyOption
import java.util.*
import java.util.Collections.synchronizedMap
import javax.annotation.concurrent.ThreadSafe
/**
* File-based checkpoint storage, storing checkpoints per file.
*/
@ThreadSafe
class PerFileCheckpointStorage(val storeDir: Path) : CheckpointStorage {
companion object {
private val logger = loggerFor<PerFileCheckpointStorage>()
private val fileExtension = ".checkpoint"
}
private val checkpointFiles = synchronizedMap(IdentityHashMap<Checkpoint, Path>())
init {
logger.trace { "Initialising per file checkpoint storage on $storeDir" }
Files.createDirectories(storeDir)
Files.list(storeDir)
.filter { it.toString().toLowerCase().endsWith(fileExtension) }
.forEach {
val checkpoint = Files.readAllBytes(it).deserialize<Checkpoint>()
checkpointFiles[checkpoint] = it
}
}
override fun addCheckpoint(checkpoint: Checkpoint) {
val fileName = "${checkpoint.id.toString().toLowerCase()}$fileExtension"
val checkpointFile = storeDir.resolve(fileName)
atomicWrite(checkpointFile, checkpoint.serialize())
logger.trace { "Stored $checkpoint to $checkpointFile" }
checkpointFiles[checkpoint] = checkpointFile
}
private fun atomicWrite(checkpointFile: Path, serialisedCheckpoint: SerializedBytes<Checkpoint>) {
val tempCheckpointFile = checkpointFile.parent.resolve("${checkpointFile.fileName}.tmp")
serialisedCheckpoint.writeToFile(tempCheckpointFile)
Files.move(tempCheckpointFile, checkpointFile, StandardCopyOption.ATOMIC_MOVE)
}
override fun removeCheckpoint(checkpoint: Checkpoint) {
val checkpointFile = checkpointFiles.remove(checkpoint)
require(checkpointFile != null) { "Trying to removing unknown checkpoint: $checkpoint" }
Files.delete(checkpointFile)
logger.trace { "Removed $checkpoint ($checkpointFile)" }
}
override fun forEach(block: (Checkpoint)->Boolean) {
synchronized(checkpointFiles) {
for(checkpoint in checkpointFiles.keys) {
if (!block(checkpoint)) {
break
}
}
}
}
}

View File

@ -1,70 +0,0 @@
package com.r3corda.node.services.persistence
import com.r3corda.core.ThreadBox
import com.r3corda.core.bufferUntilSubscribed
import com.r3corda.core.crypto.SecureHash
import com.r3corda.core.node.services.TransactionStorage
import com.r3corda.core.serialization.deserialize
import com.r3corda.core.serialization.serialize
import com.r3corda.core.transactions.SignedTransaction
import com.r3corda.core.utilities.loggerFor
import com.r3corda.core.utilities.trace
import rx.Observable
import rx.subjects.PublishSubject
import java.nio.file.Files
import java.nio.file.Path
import java.util.*
import javax.annotation.concurrent.ThreadSafe
/**
* File-based transaction storage, storing transactions per file.
*/
@ThreadSafe
class PerFileTransactionStorage(val storeDir: Path) : TransactionStorage {
companion object {
private val logger = loggerFor<PerFileCheckpointStorage>()
private val fileExtension = ".transaction"
}
private val mutex = ThreadBox(object {
val transactionsMap = HashMap<SecureHash, SignedTransaction>()
val updatesPublisher = PublishSubject.create<SignedTransaction>()
fun notify(transaction: SignedTransaction) = updatesPublisher.onNext(transaction)
})
override val updates: Observable<SignedTransaction>
get() = mutex.content.updatesPublisher
init {
logger.trace { "Initialising per file transaction storage on $storeDir" }
Files.createDirectories(storeDir)
mutex.locked {
Files.list(storeDir)
.filter { it.toString().toLowerCase().endsWith(fileExtension) }
.map { Files.readAllBytes(it).deserialize<SignedTransaction>() }
.forEach { transactionsMap[it.id] = it }
}
}
override fun addTransaction(transaction: SignedTransaction) {
val transactionFile = storeDir.resolve("${transaction.id.toString().toLowerCase()}$fileExtension")
transaction.serialize().writeToFile(transactionFile)
mutex.locked {
transactionsMap[transaction.id] = transaction
notify(transaction)
}
logger.trace { "Stored $transaction to $transactionFile" }
}
override fun getTransaction(id: SecureHash): SignedTransaction? = mutex.locked { transactionsMap[id] }
val transactions: Iterable<SignedTransaction> get() = mutex.locked { transactionsMap.values.toList() }
override fun track(): Pair<List<SignedTransaction>, Observable<SignedTransaction>> {
return mutex.locked {
Pair(transactionsMap.values.toList(), updates.bufferUntilSubscribed())
}
}
}

View File

@ -11,7 +11,10 @@ import com.fasterxml.jackson.databind.module.SimpleModule
import com.fasterxml.jackson.module.kotlin.KotlinModule
import com.r3corda.core.contracts.BusinessCalendar
import com.r3corda.core.crypto.*
import com.r3corda.core.node.NodeInfo
import com.r3corda.core.node.services.IdentityService
import com.r3corda.core.serialization.deserialize
import com.r3corda.core.serialization.serialize
import net.i2p.crypto.eddsa.EdDSAPublicKey
import java.math.BigDecimal
import java.time.LocalDate
@ -54,6 +57,11 @@ object JsonSupport {
cordaModule.addSerializer(PublicKeyTree::class.java, PublicKeyTreeSerializer)
cordaModule.addDeserializer(PublicKeyTree::class.java, PublicKeyTreeDeserializer)
// For NodeInfo
// TODO this tunnels the Kryo representation as a Base58 encoded string. Replace when RPC supports this.
cordaModule.addSerializer(NodeInfo::class.java, NodeInfoSerializer)
cordaModule.addDeserializer(NodeInfo::class.java, NodeInfoDeserializer)
mapper.registerModule(timeModule)
mapper.registerModule(cordaModule)
mapper.registerModule(KotlinModule())
@ -102,6 +110,25 @@ object JsonSupport {
}
}
object NodeInfoSerializer : JsonSerializer<NodeInfo>() {
override fun serialize(value: NodeInfo, gen: JsonGenerator, serializers: SerializerProvider) {
gen.writeString(Base58.encode(value.serialize().bits))
}
}
object NodeInfoDeserializer : JsonDeserializer<NodeInfo>() {
override fun deserialize(parser: JsonParser, context: DeserializationContext): NodeInfo {
if (parser.currentToken == JsonToken.FIELD_NAME) {
parser.nextToken()
}
try {
return Base58.decode(parser.text).deserialize<NodeInfo>()
} catch (e: Exception) {
throw JsonParseException(parser, "Invalid NodeInfo ${parser.text}: ${e.message}")
}
}
}
object SecureHashSerializer : JsonSerializer<SecureHash>() {
override fun serialize(obj: SecureHash, generator: JsonGenerator, provider: SerializerProvider) {
generator.writeString(obj.toString())

View File

@ -0,0 +1,43 @@
package com.r3corda.node.utilities
import com.r3corda.core.serialization.SerializeAsToken
import com.r3corda.core.serialization.SerializeAsTokenContext
import com.r3corda.core.serialization.SingletonSerializationToken
import java.time.*
import javax.annotation.concurrent.ThreadSafe
/**
* A [Clock] that can have the date advanced for use in demos.
*/
@ThreadSafe
class TestClock(private var delegateClock: Clock = Clock.systemUTC()) : MutableClock(), SerializeAsToken {
private val token = SingletonSerializationToken(this)
override fun toToken(context: SerializeAsTokenContext) = SingletonSerializationToken.registerWithContext(token, this, context)
@Synchronized fun updateDate(date: LocalDate): Boolean {
val currentDate = LocalDate.now(this)
if (currentDate.isBefore(date)) {
// It's ok to increment
delegateClock = Clock.offset(delegateClock, Duration.between(currentDate.atStartOfDay(), date.atStartOfDay()))
notifyMutationObservers()
return true
}
return false
}
@Synchronized override fun instant(): Instant {
return delegateClock.instant()
}
// Do not use this. Instead seek to use ZonedDateTime methods.
override fun withZone(zone: ZoneId): Clock {
throw UnsupportedOperationException("Tokenized clock does not support withZone()")
}
@Synchronized override fun getZone(): ZoneId {
return delegateClock.zone
}
}

View File

@ -1,4 +1,3 @@
# Register a ServiceLoader service extending from com.r3corda.node.CordaPluginRegistry
com.r3corda.node.services.clientapi.FixingSessionInitiation$Plugin
com.r3corda.node.services.NotaryChange$Plugin
com.r3corda.node.services.persistence.DataVending$Plugin

View File

@ -13,4 +13,5 @@ dataSourceProperties = {
devMode = true
certificateSigningService = "https://cordaci-netperm.corda.r3cev.com"
useHTTPS = false
h2port = 0
h2port = 0
useTestClock = false

View File

@ -155,16 +155,14 @@ class TwoPartyTradeProtocolTests {
bobNode.pumpReceive()
// OK, now Bob has sent the partial transaction back to Alice and is waiting for Alice's signature.
assertThat(bobNode.checkpointStorage.checkpoints()).hasSize(1)
databaseTransaction(bobNode.database) {
assertThat(bobNode.checkpointStorage.checkpoints()).hasSize(1)
}
val storage = bobNode.storage.validatedTransactions
val bobTransactionsBeforeCrash = if (storage is PerFileTransactionStorage) {
storage.transactions
} else if (storage is DBTransactionStorage) {
databaseTransaction(bobNode.database) {
storage.transactions
}
} else throw IllegalArgumentException("Unknown storage implementation")
val bobTransactionsBeforeCrash = databaseTransaction(bobNode.database) {
(storage as DBTransactionStorage).transactions
}
assertThat(bobTransactionsBeforeCrash).isNotEmpty()
// .. and let's imagine that Bob's computer has a power cut. He now has nothing now beyond what was on disk.

View File

@ -1,22 +1,37 @@
package com.r3corda.node.services
import com.google.common.net.HostAndPort
import com.r3corda.core.contracts.ClientToServiceCommand
import com.r3corda.core.contracts.ContractState
import com.r3corda.core.contracts.StateAndRef
import com.r3corda.core.crypto.generateKeyPair
import com.r3corda.core.messaging.Message
import com.r3corda.core.messaging.createMessage
import com.r3corda.core.node.NodeInfo
import com.r3corda.core.node.services.DEFAULT_SESSION_ID
import com.r3corda.core.node.services.NetworkMapCache
import com.r3corda.core.node.services.StateMachineTransactionMapping
import com.r3corda.core.node.services.Vault
import com.r3corda.core.transactions.SignedTransaction
import com.r3corda.core.utilities.LogHelper
import com.r3corda.node.services.config.NodeConfiguration
import com.r3corda.node.services.messaging.ArtemisMessagingServer
import com.r3corda.node.services.messaging.NodeMessagingClient
import com.r3corda.node.services.messaging.*
import com.r3corda.node.services.network.InMemoryNetworkMapCache
import com.r3corda.node.services.transactions.PersistentUniquenessProvider
import com.r3corda.node.utilities.AffinityExecutor
import com.r3corda.node.utilities.configureDatabase
import com.r3corda.node.utilities.databaseTransaction
import com.r3corda.testing.freeLocalHostAndPort
import com.r3corda.testing.node.makeTestDataSourceProperties
import org.assertj.core.api.Assertions.assertThatThrownBy
import org.jetbrains.exposed.sql.Database
import org.junit.After
import org.junit.Before
import org.junit.Rule
import org.junit.Test
import org.junit.rules.TemporaryFolder
import rx.Observable
import java.io.Closeable
import java.net.ServerSocket
import java.nio.file.Path
import java.util.concurrent.LinkedBlockingQueue
@ -33,12 +48,44 @@ class ArtemisMessagingTests {
val identity = generateKeyPair()
lateinit var config: NodeConfiguration
lateinit var dataSource: Closeable
lateinit var database: Database
var messagingClient: NodeMessagingClient? = null
var messagingServer: ArtemisMessagingServer? = null
val networkMapCache = InMemoryNetworkMapCache()
val rpcOps = object : CordaRPCOps {
override val protocolVersion: Int
get() = throw UnsupportedOperationException()
override fun stateMachinesAndUpdates(): Pair<List<StateMachineInfo>, Observable<StateMachineUpdate>> {
throw UnsupportedOperationException("not implemented") //To change body of created functions use File | Settings | File Templates.
}
override fun vaultAndUpdates(): Pair<List<StateAndRef<ContractState>>, Observable<Vault.Update>> {
throw UnsupportedOperationException("not implemented") //To change body of created functions use File | Settings | File Templates.
}
override fun verifiedTransactions(): Pair<List<SignedTransaction>, Observable<SignedTransaction>> {
throw UnsupportedOperationException("not implemented") //To change body of created functions use File | Settings | File Templates.
}
override fun stateMachineRecordedTransactionMapping(): Pair<List<StateMachineTransactionMapping>, Observable<StateMachineTransactionMapping>> {
throw UnsupportedOperationException("not implemented") //To change body of created functions use File | Settings | File Templates.
}
override fun networkMapUpdates(): Pair<List<NodeInfo>, Observable<NetworkMapCache.MapChange>> {
throw UnsupportedOperationException("not implemented") //To change body of created functions use File | Settings | File Templates.
}
override fun executeCommand(command: ClientToServiceCommand): TransactionBuildResult {
throw UnsupportedOperationException("not implemented") //To change body of created functions use File | Settings | File Templates.
}
}
@Before
fun setUp() {
// TODO: create a base class that provides a default implementation
@ -52,12 +99,18 @@ class ArtemisMessagingTests {
override val keyStorePassword: String = "testpass"
override val trustStorePassword: String = "trustpass"
}
LogHelper.setLevel(PersistentUniquenessProvider::class)
val dataSourceAndDatabase = configureDatabase(makeTestDataSourceProperties())
dataSource = dataSourceAndDatabase.first
database = dataSourceAndDatabase.second
}
@After
fun cleanUp() {
messagingClient?.stop()
messagingServer?.stop()
dataSource.close()
LogHelper.reset(PersistentUniquenessProvider::class)
}
@Test
@ -73,7 +126,7 @@ class ArtemisMessagingTests {
val remoteServerAddress = freeLocalHostAndPort()
createMessagingServer(remoteServerAddress).start()
createMessagingClient(server = remoteServerAddress).start()
createMessagingClient(server = remoteServerAddress).start(rpcOps)
}
@Test
@ -84,14 +137,14 @@ class ArtemisMessagingTests {
createMessagingServer(serverAddress).start()
messagingClient = createMessagingClient(server = invalidServerAddress)
assertThatThrownBy { messagingClient!!.start() }
assertThatThrownBy { messagingClient!!.start(rpcOps) }
messagingClient = null
}
@Test
fun `client should connect to local server`() {
createMessagingServer().start()
createMessagingClient().start()
createMessagingClient().start(rpcOps)
}
@Test
@ -101,7 +154,7 @@ class ArtemisMessagingTests {
createMessagingServer().start()
val messagingClient = createMessagingClient()
messagingClient.start()
messagingClient.start(rpcOps)
thread { messagingClient.run() }
messagingClient.addMessageHandler(topic) { message, r ->
@ -117,9 +170,11 @@ class ArtemisMessagingTests {
}
private fun createMessagingClient(server: HostAndPort = hostAndPort): NodeMessagingClient {
return NodeMessagingClient(config, server, identity.public, AffinityExecutor.ServiceAffinityExecutor("ArtemisMessagingTests", 1), false).apply {
configureWithDevSSLCertificate()
messagingClient = this
return databaseTransaction(database) {
NodeMessagingClient(config, server, identity.public, AffinityExecutor.ServiceAffinityExecutor("ArtemisMessagingTests", 1), database).apply {
configureWithDevSSLCertificate()
messagingClient = this
}
}
}

View File

@ -12,7 +12,7 @@ import com.r3corda.core.protocols.ProtocolLogicRefFactory
import com.r3corda.core.serialization.SingletonSerializeAsToken
import com.r3corda.core.utilities.DUMMY_NOTARY
import com.r3corda.node.services.events.NodeSchedulerService
import com.r3corda.node.services.persistence.PerFileCheckpointStorage
import com.r3corda.node.services.persistence.DBCheckpointStorage
import com.r3corda.node.services.statemachine.StateMachineManager
import com.r3corda.node.utilities.AddOrRemove
import com.r3corda.node.utilities.AffinityExecutor
@ -52,8 +52,7 @@ class NodeSchedulerServiceTest : SingletonSerializeAsToken() {
@Suppress("PLATFORM_CLASS_MAPPED_TO_KOTLIN")
val factory = ProtocolLogicRefFactory(mapOf(Pair(TestProtocolLogic::class.java.name, setOf(NodeSchedulerServiceTest::class.java.name, Integer::class.java.name))))
val services: MockServiceHubInternal
lateinit var services: MockServiceHubInternal
lateinit var scheduler: NodeSchedulerService
lateinit var smmExecutor: AffinityExecutor.ServiceAffinityExecutor
lateinit var dataSource: Closeable
@ -72,13 +71,6 @@ class NodeSchedulerServiceTest : SingletonSerializeAsToken() {
val testReference: NodeSchedulerServiceTest
}
init {
val kms = MockKeyManagementService(ALICE_KEY)
val mockMessagingService = InMemoryMessagingNetwork(false).InMemoryMessaging(false, InMemoryMessagingNetwork.Handle(0, "None"), AffinityExecutor.ServiceAffinityExecutor("test", 1), persistenceTx = { it() })
services = object : MockServiceHubInternal(overrideClock = testClock, keyManagement = kms, net = mockMessagingService), TestReference {
override val testReference = this@NodeSchedulerServiceTest
}
}
@Before
fun setup() {
@ -89,9 +81,14 @@ class NodeSchedulerServiceTest : SingletonSerializeAsToken() {
dataSource = dataSourceAndDatabase.first
database = dataSourceAndDatabase.second
databaseTransaction(database) {
val kms = MockKeyManagementService(ALICE_KEY)
val mockMessagingService = InMemoryMessagingNetwork(false).InMemoryMessaging(false, InMemoryMessagingNetwork.Handle(0, "None"), AffinityExecutor.ServiceAffinityExecutor("test", 1), database)
services = object : MockServiceHubInternal(overrideClock = testClock, keyManagement = kms, net = mockMessagingService), TestReference {
override val testReference = this@NodeSchedulerServiceTest
}
scheduler = NodeSchedulerService(database, services, factory, schedulerGatedExecutor)
smmExecutor = AffinityExecutor.ServiceAffinityExecutor("test", 1)
val mockSMM = StateMachineManager(services, listOf(services, scheduler), PerFileCheckpointStorage(fs.getPath("checkpoints")), smmExecutor, database)
val mockSMM = StateMachineManager(services, listOf(services, scheduler), DBCheckpointStorage(), smmExecutor, database)
mockSMM.changes.subscribe { change ->
if (change.addOrRemove == AddOrRemove.REMOVE && mockSMM.allStateMachines.isEmpty()) {
smmHasRemovedAllProtocols.countDown()

View File

@ -1,99 +0,0 @@
package com.r3corda.node.services.persistence
import com.google.common.jimfs.Configuration.unix
import com.google.common.jimfs.Jimfs
import com.google.common.primitives.Ints
import com.r3corda.core.serialization.SerializedBytes
import com.r3corda.node.services.api.Checkpoint
import org.assertj.core.api.Assertions.assertThat
import org.assertj.core.api.Assertions.assertThatExceptionOfType
import org.junit.After
import org.junit.Before
import org.junit.Test
import java.nio.file.FileSystem
import java.nio.file.Files
import java.nio.file.Path
class PerFileCheckpointStorageTests {
val fileSystem: FileSystem = Jimfs.newFileSystem(unix())
val storeDir: Path = fileSystem.getPath("store")
lateinit var checkpointStorage: PerFileCheckpointStorage
@Before
fun setUp() {
newCheckpointStorage()
}
@After
fun cleanUp() {
fileSystem.close()
}
@Test
fun `add new checkpoint`() {
val checkpoint = newCheckpoint()
checkpointStorage.addCheckpoint(checkpoint)
assertThat(checkpointStorage.checkpoints()).containsExactly(checkpoint)
newCheckpointStorage()
assertThat(checkpointStorage.checkpoints()).containsExactly(checkpoint)
}
@Test
fun `remove checkpoint`() {
val checkpoint = newCheckpoint()
checkpointStorage.addCheckpoint(checkpoint)
checkpointStorage.removeCheckpoint(checkpoint)
assertThat(checkpointStorage.checkpoints()).isEmpty()
newCheckpointStorage()
assertThat(checkpointStorage.checkpoints()).isEmpty()
}
@Test
fun `remove unknown checkpoint`() {
val checkpoint = newCheckpoint()
assertThatExceptionOfType(IllegalArgumentException::class.java).isThrownBy {
checkpointStorage.removeCheckpoint(checkpoint)
}
}
@Test
fun `add two checkpoints then remove first one`() {
val firstCheckpoint = newCheckpoint()
checkpointStorage.addCheckpoint(firstCheckpoint)
val secondCheckpoint = newCheckpoint()
checkpointStorage.addCheckpoint(secondCheckpoint)
checkpointStorage.removeCheckpoint(firstCheckpoint)
assertThat(checkpointStorage.checkpoints()).containsExactly(secondCheckpoint)
newCheckpointStorage()
assertThat(checkpointStorage.checkpoints()).containsExactly(secondCheckpoint)
}
@Test
fun `add checkpoint and then remove after 'restart'`() {
val originalCheckpoint = newCheckpoint()
checkpointStorage.addCheckpoint(originalCheckpoint)
newCheckpointStorage()
val reconstructedCheckpoint = checkpointStorage.checkpoints().single()
assertThat(reconstructedCheckpoint).isEqualTo(originalCheckpoint).isNotSameAs(originalCheckpoint)
checkpointStorage.removeCheckpoint(reconstructedCheckpoint)
assertThat(checkpointStorage.checkpoints()).isEmpty()
}
@Test
fun `non-checkpoint files are ignored`() {
val checkpoint = newCheckpoint()
checkpointStorage.addCheckpoint(checkpoint)
Files.write(storeDir.resolve("random-non-checkpoint-file"), "this is not a checkpoint!!".toByteArray())
newCheckpointStorage()
assertThat(checkpointStorage.checkpoints()).containsExactly(checkpoint)
}
private fun newCheckpointStorage() {
checkpointStorage = PerFileCheckpointStorage(storeDir)
}
private var checkpointCount = 1
private fun newCheckpoint() = Checkpoint(SerializedBytes(Ints.toByteArray(checkpointCount++)))
}

View File

@ -1,100 +0,0 @@
package com.r3corda.node.services.persistence
import com.google.common.jimfs.Configuration.unix
import com.google.common.jimfs.Jimfs
import com.google.common.primitives.Ints
import com.google.common.util.concurrent.SettableFuture
import com.r3corda.core.crypto.DigitalSignature
import com.r3corda.core.crypto.NullPublicKey
import com.r3corda.core.serialization.SerializedBytes
import com.r3corda.core.transactions.SignedTransaction
import org.assertj.core.api.Assertions.assertThat
import org.junit.After
import org.junit.Before
import org.junit.Test
import java.nio.file.FileSystem
import java.nio.file.Files
import java.nio.file.Path
import java.util.concurrent.TimeUnit
import kotlin.test.assertEquals
class PerFileTransactionStorageTests {
val fileSystem: FileSystem = Jimfs.newFileSystem(unix())
val storeDir: Path = fileSystem.getPath("store")
lateinit var transactionStorage: PerFileTransactionStorage
@Before
fun setUp() {
newTransactionStorage()
}
@After
fun cleanUp() {
fileSystem.close()
}
@Test
fun `empty store`() {
assertThat(transactionStorage.getTransaction(newTransaction().id)).isNull()
assertThat(transactionStorage.transactions).isEmpty()
newTransactionStorage()
assertThat(transactionStorage.transactions).isEmpty()
}
@Test
fun `one transaction`() {
val transaction = newTransaction()
transactionStorage.addTransaction(transaction)
assertTransactionIsRetrievable(transaction)
assertThat(transactionStorage.transactions).containsExactly(transaction)
newTransactionStorage()
assertTransactionIsRetrievable(transaction)
assertThat(transactionStorage.transactions).containsExactly(transaction)
}
@Test
fun `two transactions across restart`() {
val firstTransaction = newTransaction()
val secondTransaction = newTransaction()
transactionStorage.addTransaction(firstTransaction)
newTransactionStorage()
transactionStorage.addTransaction(secondTransaction)
assertTransactionIsRetrievable(firstTransaction)
assertTransactionIsRetrievable(secondTransaction)
assertThat(transactionStorage.transactions).containsOnly(firstTransaction, secondTransaction)
}
@Test
fun `non-transaction files are ignored`() {
val transactions = newTransaction()
transactionStorage.addTransaction(transactions)
Files.write(storeDir.resolve("random-non-tx-file"), "this is not a transaction!!".toByteArray())
newTransactionStorage()
assertThat(transactionStorage.transactions).containsExactly(transactions)
}
@Test
fun `updates are fired`() {
val future = SettableFuture.create<SignedTransaction>()
transactionStorage.updates.subscribe { tx -> future.set(tx) }
val expected = newTransaction()
transactionStorage.addTransaction(expected)
val actual = future.get(1, TimeUnit.SECONDS)
assertEquals(expected, actual)
}
private fun newTransactionStorage() {
transactionStorage = PerFileTransactionStorage(storeDir)
}
private fun assertTransactionIsRetrievable(transaction: SignedTransaction) {
assertThat(transactionStorage.getTransaction(transaction.id)).isEqualTo(transaction)
}
private var txCount = 0
private fun newTransaction() = SignedTransaction(
SerializedBytes(Ints.toByteArray(++txCount)),
listOf(DigitalSignature.WithKey(NullPublicKey, ByteArray(1))))
}

View File

@ -10,6 +10,7 @@ import com.r3corda.core.random63BitValue
import com.r3corda.core.serialization.deserialize
import com.r3corda.node.services.persistence.checkpoints
import com.r3corda.node.services.statemachine.StateMachineManager.*
import com.r3corda.node.utilities.databaseTransaction
import com.r3corda.testing.initiateSingleShotProtocol
import com.r3corda.testing.node.InMemoryMessagingNetwork
import com.r3corda.testing.node.InMemoryMessagingNetwork.MessageTransfer
@ -73,6 +74,7 @@ class StateMachineManagerTests {
// We push through just enough messages to get only the payload sent
node2.pumpReceive()
node2.disableDBCloseOnStop()
node2.stop()
net.runNetwork()
val restoredProtocol = node2.restartAndGetRestoredProtocol<ReceiveThenSuspendProtocol>(node1)
@ -95,6 +97,7 @@ class StateMachineManagerTests {
val protocol = NoOpProtocol()
node3.smm.add(protocol)
assertEquals(false, protocol.protocolStarted) // Not started yet as no network activity has been allowed yet
node3.disableDBCloseOnStop()
node3.stop()
node3 = net.createNode(node1.info.address, forcedID = node3.id)
@ -103,6 +106,7 @@ class StateMachineManagerTests {
net.runNetwork() // Allow network map messages to flow
node3.smm.executor.flush()
assertEquals(true, restoredProtocol.protocolStarted) // Now we should have run the protocol and hopefully cleared the init checkpoint
node3.disableDBCloseOnStop()
node3.stop()
// Now it is completed the protocol should leave no Checkpoint.
@ -119,6 +123,7 @@ class StateMachineManagerTests {
node2.smm.add(ReceiveThenSuspendProtocol(node1.info.legalIdentity)) // Prepare checkpointed receive protocol
// Make sure the add() has finished initial processing.
node2.smm.executor.flush()
node2.disableDBCloseOnStop()
node2.stop() // kill receiver
val restoredProtocol = node2.restartAndGetRestoredProtocol<ReceiveThenSuspendProtocol>(node1)
assertThat(restoredProtocol.receivedPayloads[0]).isEqualTo(payload)
@ -138,16 +143,22 @@ class StateMachineManagerTests {
// Kick off first send and receive
node2.smm.add(PingPongProtocol(node3.info.legalIdentity, payload))
assertEquals(1, node2.checkpointStorage.checkpoints().size)
databaseTransaction(node2.database) {
assertEquals(1, node2.checkpointStorage.checkpoints().size)
}
// Make sure the add() has finished initial processing.
node2.smm.executor.flush()
node2.disableDBCloseOnStop()
// Restart node and thus reload the checkpoint and resend the message with same UUID
node2.stop()
databaseTransaction(node2.database) {
assertEquals(1, node2.checkpointStorage.checkpoints().size) // confirm checkpoint
}
val node2b = net.createNode(node1.info.address, node2.id, advertisedServices = *node2.advertisedServices.toTypedArray())
node2.manuallyCloseDB()
val (firstAgain, fut1) = node2b.getSingleProtocol<PingPongProtocol>()
// Run the network which will also fire up the second protocol. First message should get deduped. So message data stays in sync.
net.runNetwork()
assertEquals(1, node2.checkpointStorage.checkpoints().size)
node2b.smm.executor.flush()
fut1.get()
@ -156,8 +167,12 @@ class StateMachineManagerTests {
assertEquals(4, receivedCount, "Protocol should have exchanged 4 unique messages")// Two messages each way
// can't give a precise value as every addMessageHandler re-runs the undelivered messages
assertTrue(sentCount > receivedCount, "Node restart should have retransmitted messages")
assertEquals(0, node2b.checkpointStorage.checkpoints().size, "Checkpoints left after restored protocol should have ended")
assertEquals(0, node3.checkpointStorage.checkpoints().size, "Checkpoints left after restored protocol should have ended")
databaseTransaction(node2b.database) {
assertEquals(0, node2b.checkpointStorage.checkpoints().size, "Checkpoints left after restored protocol should have ended")
}
databaseTransaction(node3.database) {
assertEquals(0, node3.checkpointStorage.checkpoints().size, "Checkpoints left after restored protocol should have ended")
}
assertEquals(payload2, firstAgain.receivedPayload, "Received payload does not match the first value on Node 3")
assertEquals(payload2 + 1, firstAgain.receivedPayload2, "Received payload does not match the expected second value on Node 3")
assertEquals(payload, secondProtocol.get().receivedPayload, "Received payload does not match the (restarted) first value on Node 2")
@ -253,8 +268,10 @@ class StateMachineManagerTests {
private inline fun <reified P : ProtocolLogic<*>> MockNode.restartAndGetRestoredProtocol(
networkMapNode: MockNode? = null): P {
disableDBCloseOnStop() //Handover DB to new node copy
stop()
val newNode = mockNet.createNode(networkMapNode?.info?.address, id, advertisedServices = *advertisedServices.toTypedArray())
manuallyCloseDB()
mockNet.runNetwork() // allow NetworkMapService messages to stabilise and thus start the state machine
return newNode.getSingleProtocol<P>().first
}

View File

@ -9,7 +9,10 @@ import com.r3corda.core.serialization.SingletonSerializeAsToken
import com.r3corda.core.utilities.trace
import com.r3corda.node.services.api.MessagingServiceBuilder
import com.r3corda.node.utilities.AffinityExecutor
import com.r3corda.node.utilities.JDBCHashSet
import com.r3corda.node.utilities.databaseTransaction
import com.r3corda.testing.node.InMemoryMessagingNetwork.InMemoryMessaging
import org.jetbrains.exposed.sql.Database
import org.slf4j.LoggerFactory
import rx.Observable
import rx.subjects.PublishSubject
@ -80,10 +83,10 @@ class InMemoryMessagingNetwork(val sendManuallyPumped: Boolean) : SingletonSeria
@Synchronized
fun createNode(manuallyPumped: Boolean,
executor: AffinityExecutor,
persistenceTx: (() -> Unit) -> Unit)
database: Database)
: Pair<Handle, com.r3corda.node.services.api.MessagingServiceBuilder<InMemoryMessaging>> {
check(counter >= 0) { "In memory network stopped: please recreate." }
val builder = createNodeWithID(manuallyPumped, counter, executor, persistenceTx = persistenceTx) as Builder
val builder = createNodeWithID(manuallyPumped, counter, executor, database = database) as Builder
counter++
val id = builder.id
return Pair(id, builder)
@ -98,9 +101,9 @@ class InMemoryMessagingNetwork(val sendManuallyPumped: Boolean) : SingletonSeria
* @param persistenceTx a lambda to wrap message handling in a transaction if necessary.
*/
fun createNodeWithID(manuallyPumped: Boolean, id: Int, executor: AffinityExecutor, description: String? = null,
persistenceTx: (() -> Unit) -> Unit)
database: Database)
: MessagingServiceBuilder<InMemoryMessaging> {
return Builder(manuallyPumped, Handle(id, description ?: "In memory node $id"), executor, persistenceTx)
return Builder(manuallyPumped, Handle(id, description ?: "In memory node $id"), executor, database = database)
}
interface LatencyCalculator {
@ -140,11 +143,11 @@ class InMemoryMessagingNetwork(val sendManuallyPumped: Boolean) : SingletonSeria
messageReceiveQueues.clear()
}
inner class Builder(val manuallyPumped: Boolean, val id: Handle, val executor: AffinityExecutor, val persistenceTx: (() -> Unit) -> Unit)
inner class Builder(val manuallyPumped: Boolean, val id: Handle, val executor: AffinityExecutor, val database: Database)
: com.r3corda.node.services.api.MessagingServiceBuilder<InMemoryMessaging> {
override fun start(): ListenableFuture<InMemoryMessaging> {
synchronized(this@InMemoryMessagingNetwork) {
val node = InMemoryMessaging(manuallyPumped, id, executor, persistenceTx)
val node = InMemoryMessaging(manuallyPumped, id, executor, database)
handleEndpointMap[id] = node
return Futures.immediateFuture(node)
}
@ -208,7 +211,7 @@ class InMemoryMessagingNetwork(val sendManuallyPumped: Boolean) : SingletonSeria
inner class InMemoryMessaging(private val manuallyPumped: Boolean,
private val handle: Handle,
private val executor: AffinityExecutor,
private val persistenceTx: (() -> Unit) -> Unit)
private val database: Database)
: SingletonSerializeAsToken(), com.r3corda.node.services.api.MessagingServiceInternal {
inner class Handler(val topicSession: TopicSession,
val callback: (Message, MessageHandlerRegistration) -> Unit) : MessageHandlerRegistration
@ -218,7 +221,7 @@ class InMemoryMessagingNetwork(val sendManuallyPumped: Boolean) : SingletonSeria
private inner class InnerState {
val handlers: MutableList<Handler> = ArrayList()
val pendingRedelivery = LinkedList<MessageTransfer>()
val pendingRedelivery = JDBCHashSet<Message>("pending_messages",loadOnInit = true)
}
private val state = ThreadBox(InnerState())
@ -244,11 +247,14 @@ class InMemoryMessagingNetwork(val sendManuallyPumped: Boolean) : SingletonSeria
check(running)
val (handler, items) = state.locked {
val handler = Handler(topicSession, callback).apply { handlers.add(this) }
val items = ArrayList(pendingRedelivery)
pendingRedelivery.clear()
Pair(handler, items)
val pending = ArrayList<Message>()
databaseTransaction(database) {
pending.addAll(pendingRedelivery)
pendingRedelivery.clear()
}
Pair(handler, pending)
}
for ((sender, message) in items) {
for (message in items) {
send(message, handle)
}
return handler
@ -328,7 +334,9 @@ class InMemoryMessagingNetwork(val sendManuallyPumped: Boolean) : SingletonSeria
// up a handler for yet. Most unit tests don't run threaded, but we want to test true parallelism at
// least sometimes.
log.warn("Message to ${transfer.message.topicSession} could not be delivered")
pendingRedelivery.add(transfer)
databaseTransaction(database) {
pendingRedelivery.add(transfer.message)
}
null
} else {
h
@ -348,7 +356,7 @@ class InMemoryMessagingNetwork(val sendManuallyPumped: Boolean) : SingletonSeria
if (transfer.message.uniqueMessageId !in processedMessages) {
executor.execute {
persistenceTx {
databaseTransaction(database) {
for (handler in deliverTo) {
try {
handler.callback(transfer.message, handler)

View File

@ -4,7 +4,6 @@ import com.google.common.jimfs.Configuration.unix
import com.google.common.jimfs.Jimfs
import com.google.common.util.concurrent.Futures
import com.r3corda.core.crypto.Party
import com.r3corda.core.div
import com.r3corda.core.messaging.SingleMessageRecipient
import com.r3corda.core.node.PhysicalLocation
import com.r3corda.core.node.services.KeyManagementService
@ -15,21 +14,17 @@ import com.r3corda.core.testing.InMemoryVaultService
import com.r3corda.core.utilities.DUMMY_NOTARY_KEY
import com.r3corda.core.utilities.loggerFor
import com.r3corda.node.internal.AbstractNode
import com.r3corda.node.services.api.CheckpointStorage
import com.r3corda.node.services.api.MessagingServiceInternal
import com.r3corda.node.services.config.NodeConfiguration
import com.r3corda.node.services.keys.E2ETestKeyManagementService
import com.r3corda.node.services.messaging.CordaRPCOps
import com.r3corda.node.services.network.InMemoryNetworkMapService
import com.r3corda.node.services.network.NetworkMapService
import com.r3corda.node.services.persistence.DBCheckpointStorage
import com.r3corda.node.services.persistence.PerFileCheckpointStorage
import com.r3corda.node.services.transactions.InMemoryUniquenessProvider
import com.r3corda.node.services.transactions.SimpleNotaryService
import com.r3corda.node.services.transactions.ValidatingNotaryService
import com.r3corda.node.utilities.AffinityExecutor
import com.r3corda.node.utilities.AffinityExecutor.ServiceAffinityExecutor
import com.r3corda.node.utilities.databaseTransaction
import org.slf4j.Logger
import java.nio.file.FileSystem
import java.nio.file.Files
@ -125,16 +120,7 @@ class MockNetwork(private val networkSendManuallyPumped: Boolean = false,
// through the java.nio API which we are already mocking via Jimfs.
override fun makeMessagingService(): MessagingServiceInternal {
require(id >= 0) { "Node ID must be zero or positive, was passed: " + id }
return mockNet.messagingNetwork.createNodeWithID(!mockNet.threadPerNode, id, serverThread, configuration.myLegalName,
persistenceTx = { body: () -> Unit -> databaseTransaction(database) { body() } }).start().get()
}
override fun initialiseCheckpointService(dir: Path): CheckpointStorage {
return if (mockNet.threadPerNode) {
DBCheckpointStorage()
} else {
PerFileCheckpointStorage(dir / "checkpoints")
}
return mockNet.messagingNetwork.createNodeWithID(!mockNet.threadPerNode, id, serverThread, configuration.myLegalName, database).start().get()
}
override fun makeIdentityService() = MockIdentityService(mockNet.identities)
@ -143,7 +129,7 @@ class MockNetwork(private val networkSendManuallyPumped: Boolean = false,
override fun makeKeyManagementService(): KeyManagementService = E2ETestKeyManagementService(partyKeys)
override fun startMessagingService(cordaRPCOps: CordaRPCOps?) {
override fun startMessagingService(cordaRPCOps: CordaRPCOps) {
// Nothing to do
}