diff --git a/.gitignore b/.gitignore index 8fc24e7bad..5a11950686 100644 --- a/.gitignore +++ b/.gitignore @@ -32,6 +32,7 @@ lib/dokka.jar .idea/libraries .idea/shelf .idea/dataSources +.idea/markdown-navigator /gradle-plugins/.idea/ # Include the -parameters compiler option by default in IntelliJ required for serialization. @@ -66,7 +67,7 @@ lib/dokka.jar ## Plugin-specific files: # IntelliJ -/out/ +**/out/ /classes/ # mpeltonen/sbt-idea plugin diff --git a/.idea/compiler.xml b/.idea/compiler.xml index 175866f66d..75643b8eeb 100644 --- a/.idea/compiler.xml +++ b/.idea/compiler.xml @@ -40,6 +40,8 @@ + + diff --git a/build.gradle b/build.gradle index 51ab3d5f0e..e3a664f5de 100644 --- a/build.gradle +++ b/build.gradle @@ -4,7 +4,7 @@ buildscript { file("$projectDir/constants.properties").withInputStream { constants.load(it) } // Our version: bump this on release. - ext.corda_release_version = "0.14-SNAPSHOT" + ext.corda_release_version = "0.15-SNAPSHOT" // Increment this on any release that changes public APIs anywhere in the Corda platform // TODO This is going to be difficult until we have a clear separation throughout the code of what is public and what is internal ext.corda_platform_version = 1 @@ -31,7 +31,6 @@ buildscript { ext.log4j_version = '2.7' ext.bouncycastle_version = constants.getProperty("bouncycastleVersion") ext.guava_version = constants.getProperty("guavaVersion") - ext.quickcheck_version = '0.7' ext.okhttp_version = '3.5.0' ext.netty_version = '4.1.9.Final' ext.typesafe_config_version = constants.getProperty("typesafeConfigVersion") @@ -168,22 +167,23 @@ repositories { } } +// TODO: Corda root project currently produces a dummy cordapp when it shouldn't. // Required for building out the fat JAR. dependencies { - compile project(':node') + cordaCompile project(':node') compile "com.google.guava:guava:$guava_version" - // Set to compile to ensure it exists now deploy nodes no longer relies on build - compile project(path: ":node:capsule", configuration: 'runtimeArtifacts') - compile project(path: ":webserver:webcapsule", configuration: 'runtimeArtifacts') + // Set to corda compile to ensure it exists now deploy nodes no longer relies on build + cordaCompile project(path: ":node:capsule", configuration: 'runtimeArtifacts') + cordaCompile project(path: ":webserver:webcapsule", configuration: 'runtimeArtifacts') // For the buildCordappDependenciesJar task - runtime project(':client:jfx') - runtime project(':client:mock') - runtime project(':client:rpc') - runtime project(':core') - runtime project(':finance') - runtime project(':webserver') + cordaRuntime project(':client:jfx') + cordaRuntime project(':client:mock') + cordaRuntime project(':client:rpc') + cordaRuntime project(':core') + cordaRuntime project(':finance') + cordaRuntime project(':webserver') testCompile project(':test-utils') } @@ -285,7 +285,7 @@ artifactory { password = System.getenv('CORDA_ARTIFACTORY_PASSWORD') } defaults { - publications('corda-jfx', 'corda-mock', 'corda-rpc', 'corda-core', 'corda', 'cordform-common', 'corda-finance', 'corda-node', 'corda-node-api', 'corda-node-schemas', 'corda-test-utils', 'corda-jackson', 'corda-verifier', 'corda-webserver-impl', 'corda-webserver') + publications('corda-jfx', 'corda-mock', 'corda-rpc', 'corda-core', 'corda', 'cordform-common', 'corda-finance', 'corda-node', 'corda-node-api', 'corda-node-schemas', 'corda-test-common', 'corda-test-utils', 'corda-jackson', 'corda-verifier', 'corda-webserver-impl', 'corda-webserver') } } } diff --git a/client/jackson/build.gradle b/client/jackson/build.gradle index 234f7e1ae0..f591484aa0 100644 --- a/client/jackson/build.gradle +++ b/client/jackson/build.gradle @@ -6,6 +6,8 @@ apply plugin: 'com.jfrog.artifactory' dependencies { compile project(':core') compile project(':finance') + testCompile project(':test-utils') + compile "org.jetbrains.kotlin:kotlin-stdlib-jre8:$kotlin_version" testCompile "org.jetbrains.kotlin:kotlin-test:$kotlin_version" @@ -18,12 +20,6 @@ dependencies { testCompile project(path: ':core', configuration: 'testArtifacts') testCompile "junit:junit:$junit_version" - - // TODO: Upgrade to junit-quickcheck 0.8, once it is released, - // because it depends on org.javassist:javassist instead - // of javassist:javassist. - testCompile "com.pholser:junit-quickcheck-core:$quickcheck_version" - testCompile "com.pholser:junit-quickcheck-generators:$quickcheck_version" } jar { @@ -31,5 +27,5 @@ jar { } publish { - name = jar.baseName + name jar.baseName } \ No newline at end of file diff --git a/client/jackson/src/main/kotlin/net/corda/jackson/JacksonSupport.kt b/client/jackson/src/main/kotlin/net/corda/jackson/JacksonSupport.kt index 1ebb881d19..ee8333f06e 100644 --- a/client/jackson/src/main/kotlin/net/corda/jackson/JacksonSupport.kt +++ b/client/jackson/src/main/kotlin/net/corda/jackson/JacksonSupport.kt @@ -1,5 +1,7 @@ package net.corda.jackson +import com.fasterxml.jackson.annotation.JsonIgnore +import com.fasterxml.jackson.annotation.JsonProperty import com.fasterxml.jackson.core.* import com.fasterxml.jackson.databind.* import com.fasterxml.jackson.databind.deser.std.NumberDeserializers @@ -9,6 +11,8 @@ import com.fasterxml.jackson.datatype.jsr310.JavaTimeModule import com.fasterxml.jackson.module.kotlin.KotlinModule import net.corda.contracts.BusinessCalendar import net.corda.core.contracts.Amount +import net.corda.core.contracts.ContractState +import net.corda.core.contracts.StateRef import net.corda.core.crypto.* import net.corda.core.crypto.composite.CompositeKey import net.corda.core.identity.AbstractParty @@ -17,9 +21,14 @@ import net.corda.core.identity.Party import net.corda.core.messaging.CordaRPCOps import net.corda.core.node.NodeInfo import net.corda.core.node.services.IdentityService -import net.corda.core.utilities.OpaqueBytes +import net.corda.core.serialization.SerializedBytes import net.corda.core.serialization.deserialize import net.corda.core.serialization.serialize +import net.corda.core.transactions.CoreTransaction +import net.corda.core.transactions.NotaryChangeWireTransaction +import net.corda.core.transactions.SignedTransaction +import net.corda.core.transactions.WireTransaction +import net.corda.core.utilities.OpaqueBytes import net.i2p.crypto.eddsa.EdDSAPublicKey import org.bouncycastle.asn1.x500.X500Name import java.math.BigDecimal @@ -38,32 +47,24 @@ object JacksonSupport { // If you change this API please update the docs in the docsite (json.rst) interface PartyObjectMapper { - @Deprecated("Use partyFromX500Name instead") - fun partyFromName(partyName: String): Party? fun partyFromX500Name(name: X500Name): Party? fun partyFromKey(owningKey: PublicKey): Party? fun partiesFromName(query: String): Set } class RpcObjectMapper(val rpc: CordaRPCOps, factory: JsonFactory, val fuzzyIdentityMatch: Boolean) : PartyObjectMapper, ObjectMapper(factory) { - @Suppress("OverridingDeprecatedMember", "DEPRECATION") - override fun partyFromName(partyName: String): Party? = rpc.partyFromName(partyName) override fun partyFromX500Name(name: X500Name): Party? = rpc.partyFromX500Name(name) override fun partyFromKey(owningKey: PublicKey): Party? = rpc.partyFromKey(owningKey) override fun partiesFromName(query: String) = rpc.partiesFromName(query, fuzzyIdentityMatch) } class IdentityObjectMapper(val identityService: IdentityService, factory: JsonFactory, val fuzzyIdentityMatch: Boolean) : PartyObjectMapper, ObjectMapper(factory) { - @Suppress("OverridingDeprecatedMember", "DEPRECATION") - override fun partyFromName(partyName: String): Party? = identityService.partyFromName(partyName) override fun partyFromX500Name(name: X500Name): Party? = identityService.partyFromX500Name(name) override fun partyFromKey(owningKey: PublicKey): Party? = identityService.partyFromKey(owningKey) override fun partiesFromName(query: String) = identityService.partiesFromName(query, fuzzyIdentityMatch) } class NoPartyObjectMapper(factory: JsonFactory) : PartyObjectMapper, ObjectMapper(factory) { - @Suppress("OverridingDeprecatedMember", "DEPRECATION") - override fun partyFromName(partyName: String): Party? = throw UnsupportedOperationException() override fun partyFromX500Name(name: X500Name): Party? = throw UnsupportedOperationException() override fun partyFromKey(owningKey: PublicKey): Party? = throw UnsupportedOperationException() override fun partiesFromName(query: String) = throw UnsupportedOperationException() @@ -109,6 +110,10 @@ object JacksonSupport { // For X.500 distinguished names addDeserializer(X500Name::class.java, X500NameDeserializer) addSerializer(X500Name::class.java, X500NameSerializer) + + // Mixins for transaction types to prevent some properties from being serialized + setMixInAnnotation(SignedTransaction::class.java, SignedTransactionMixin::class.java) + setMixInAnnotation(WireTransaction::class.java, WireTransactionMixin::class.java) } } @@ -278,7 +283,7 @@ object JacksonSupport { object CalendarSerializer : JsonSerializer() { override fun serialize(obj: BusinessCalendar, generator: JsonGenerator, context: SerializerProvider) { val calendarName = BusinessCalendar.calendars.find { BusinessCalendar.getInstance(it) == obj } - if(calendarName != null) { + if (calendarName != null) { generator.writeString(calendarName) } else { generator.writeObject(BusinessCalendarWrapper(obj.holidayDates)) @@ -371,5 +376,24 @@ object JacksonSupport { gen.writeBinary(value.bytes) } } + + abstract class SignedTransactionMixin { + @JsonIgnore abstract fun getTxBits(): SerializedBytes + @JsonProperty("signatures") protected abstract fun getSigs(): List + @JsonProperty protected abstract fun getTransaction(): CoreTransaction + @JsonIgnore abstract fun getTx(): WireTransaction + @JsonIgnore abstract fun getNotaryChangeTx(): NotaryChangeWireTransaction + @JsonIgnore abstract fun getInputs(): List + @JsonIgnore abstract fun getNotary(): Party? + @JsonIgnore abstract fun getId(): SecureHash + @JsonIgnore abstract fun getRequiredSigningKeys(): Set + } + + abstract class WireTransactionMixin { + @JsonIgnore abstract fun getMerkleTree(): MerkleTree + @JsonIgnore abstract fun getAvailableComponents(): List + @JsonIgnore abstract fun getAvailableComponentHashes(): List + @JsonIgnore abstract fun getOutputStates(): List + } } diff --git a/client/jackson/src/test/kotlin/net/corda/jackson/JacksonSupportTest.kt b/client/jackson/src/test/kotlin/net/corda/jackson/JacksonSupportTest.kt index b1bbfa8f67..a92edb7aea 100644 --- a/client/jackson/src/test/kotlin/net/corda/jackson/JacksonSupportTest.kt +++ b/client/jackson/src/test/kotlin/net/corda/jackson/JacksonSupportTest.kt @@ -1,27 +1,31 @@ package net.corda.jackson import com.fasterxml.jackson.databind.SerializationFeature -import com.pholser.junit.quickcheck.From -import com.pholser.junit.quickcheck.Property -import com.pholser.junit.quickcheck.runner.JUnitQuickcheck import net.corda.core.contracts.Amount import net.corda.core.contracts.USD -import net.corda.core.testing.PublicKeyGenerator +import net.corda.core.crypto.Crypto +import net.corda.core.crypto.SignatureMetadata +import net.corda.core.crypto.TransactionSignature +import net.corda.core.crypto.generateKeyPair +import net.corda.core.transactions.SignedTransaction +import net.corda.testing.ALICE_PUBKEY +import net.corda.testing.DUMMY_NOTARY +import net.corda.testing.MINI_CORP +import net.corda.testing.TestDependencyInjectionBase +import net.corda.testing.contracts.DummyContract import net.i2p.crypto.eddsa.EdDSAPublicKey import org.junit.Test -import org.junit.runner.RunWith -import java.security.PublicKey import java.util.* import kotlin.test.assertEquals -@RunWith(JUnitQuickcheck::class) -class JacksonSupportTest { +class JacksonSupportTest : TestDependencyInjectionBase() { companion object { val mapper = JacksonSupport.createNonRpcMapper() } - @Property - fun publicKeySerializingWorks(@From(PublicKeyGenerator::class) publicKey: PublicKey) { + @Test + fun publicKeySerializingWorks() { + val publicKey = generateKeyPair().public val serialized = mapper.writeValueAsString(publicKey) val parsedKey = mapper.readValue(serialized, EdDSAPublicKey::class.java) assertEquals(publicKey, parsedKey) @@ -50,4 +54,24 @@ class JacksonSupportTest { val writer = mapper.writer().without(SerializationFeature.INDENT_OUTPUT) assertEquals("""{"notional":"25000000.00 USD"}""", writer.writeValueAsString(Dummy(Amount.parseCurrency("$25000000")))) } + + @Test + fun writeTransaction() { + fun makeDummyTx(): SignedTransaction { + val wtx = DummyContract.generateInitial(1, DUMMY_NOTARY, MINI_CORP.ref(1)).toWireTransaction() + val signatures = TransactionSignature( + ByteArray(1), + ALICE_PUBKEY, + SignatureMetadata( + 1, + Crypto.findSignatureScheme(ALICE_PUBKEY).schemeNumberID + ) + ) + return SignedTransaction(wtx, listOf(signatures)) + } + + val writer = mapper.writer() + // We don't particularly care about the serialized format, just need to make sure it completes successfully. + writer.writeValueAsString(makeDummyTx()) + } } diff --git a/client/jfx/build.gradle b/client/jfx/build.gradle index 5d12e01f56..9c70ee7a04 100644 --- a/client/jfx/build.gradle +++ b/client/jfx/build.gradle @@ -62,5 +62,5 @@ jar { } publish { - name = jar.baseName + name jar.baseName } \ No newline at end of file diff --git a/client/jfx/src/integration-test/kotlin/net/corda/client/jfx/NodeMonitorModelTest.kt b/client/jfx/src/integration-test/kotlin/net/corda/client/jfx/NodeMonitorModelTest.kt index 5c93d9e820..94662d7378 100644 --- a/client/jfx/src/integration-test/kotlin/net/corda/client/jfx/NodeMonitorModelTest.kt +++ b/client/jfx/src/integration-test/kotlin/net/corda/client/jfx/NodeMonitorModelTest.kt @@ -2,15 +2,15 @@ package net.corda.client.jfx import net.corda.client.jfx.model.NodeMonitorModel import net.corda.client.jfx.model.ProgressTrackingEvent -import net.corda.core.bufferUntilSubscribed +import net.corda.core.internal.bufferUntilSubscribed import net.corda.core.contracts.Amount +import net.corda.core.contracts.ContractState import net.corda.core.contracts.DOLLARS import net.corda.core.contracts.USD import net.corda.core.crypto.isFulfilledBy import net.corda.core.crypto.keys import net.corda.core.flows.FlowInitiator import net.corda.core.flows.StateMachineRunId -import net.corda.core.getOrThrow import net.corda.core.messaging.CordaRPCOps import net.corda.core.messaging.StateMachineTransactionMapping import net.corda.core.messaging.StateMachineUpdate @@ -19,12 +19,9 @@ import net.corda.core.node.NodeInfo import net.corda.core.node.services.NetworkMapCache import net.corda.core.node.services.ServiceInfo import net.corda.core.node.services.Vault -import net.corda.core.utilities.OpaqueBytes import net.corda.core.transactions.SignedTransaction -import net.corda.testing.ALICE -import net.corda.testing.BOB -import net.corda.testing.CHARLIE -import net.corda.testing.DUMMY_NOTARY +import net.corda.core.utilities.OpaqueBytes +import net.corda.core.utilities.getOrThrow import net.corda.flows.CashExitFlow import net.corda.flows.CashIssueFlow import net.corda.flows.CashPaymentFlow @@ -32,11 +29,9 @@ import net.corda.node.services.network.NetworkMapService import net.corda.node.services.startFlowPermission import net.corda.node.services.transactions.SimpleNotaryService import net.corda.nodeapi.User +import net.corda.testing.* import net.corda.testing.driver.driver -import net.corda.testing.expect -import net.corda.testing.expectEvents import net.corda.testing.node.DriverBasedTest -import net.corda.testing.sequence import org.bouncycastle.asn1.x500.X500Name import org.junit.Test import rx.Observable @@ -53,7 +48,7 @@ class NodeMonitorModelTest : DriverBasedTest() { lateinit var stateMachineUpdatesBob: Observable lateinit var progressTracking: Observable lateinit var transactions: Observable - lateinit var vaultUpdates: Observable + lateinit var vaultUpdates: Observable> lateinit var networkMapUpdates: Observable lateinit var newNode: (X500Name) -> NodeInfo @@ -78,14 +73,14 @@ class NodeMonitorModelTest : DriverBasedTest() { vaultUpdates = monitor.vaultUpdates.bufferUntilSubscribed() networkMapUpdates = monitor.networkMap.bufferUntilSubscribed() - monitor.register(aliceNodeHandle.configuration.rpcAddress!!, cashUser.username, cashUser.password) + monitor.register(aliceNodeHandle.configuration.rpcAddress!!, cashUser.username, cashUser.password, initialiseSerialization = false) rpc = monitor.proxyObservable.value!! val bobNodeHandle = startNode(BOB.name, rpcUsers = listOf(cashUser)).getOrThrow() bobNode = bobNodeHandle.nodeInfo val monitorBob = NodeMonitorModel() stateMachineUpdatesBob = monitorBob.stateMachineUpdates.bufferUntilSubscribed() - monitorBob.register(bobNodeHandle.configuration.rpcAddress!!, cashUser.username, cashUser.password) + monitorBob.register(bobNodeHandle.configuration.rpcAddress!!, cashUser.username, cashUser.password, initialiseSerialization = false) rpcBob = monitorBob.proxyObservable.value!! runTest() } @@ -148,7 +143,7 @@ class NodeMonitorModelTest : DriverBasedTest() { var moveSmId: StateMachineRunId? = null var issueTx: SignedTransaction? = null var moveTx: SignedTransaction? = null - stateMachineUpdates.expectEvents { + stateMachineUpdates.expectEvents(isStrict = false) { sequence( // ISSUE expect { add: StateMachineUpdate.Added -> @@ -159,14 +154,13 @@ class NodeMonitorModelTest : DriverBasedTest() { expect { remove: StateMachineUpdate.Removed -> require(remove.id == issueSmId) }, - // MOVE - expect { add: StateMachineUpdate.Added -> + // MOVE - N.B. There are other framework flows that happen in parallel for the remote resolve transactions flow + expect(match = { it is StateMachineUpdate.Added && it.stateMachineInfo.flowLogicClassName == CashPaymentFlow::class.java.name }) { add: StateMachineUpdate.Added -> moveSmId = add.id val initiator = add.stateMachineInfo.initiator require(initiator is FlowInitiator.RPC && initiator.username == "user1") }, - expect { remove: StateMachineUpdate.Removed -> - require(remove.id == moveSmId) + expect(match = { it is StateMachineUpdate.Removed && it.id == moveSmId }) { } ) } diff --git a/client/jfx/src/main/kotlin/net/corda/client/jfx/model/ContractStateModel.kt b/client/jfx/src/main/kotlin/net/corda/client/jfx/model/ContractStateModel.kt index bbe3736c22..4944ca31b9 100644 --- a/client/jfx/src/main/kotlin/net/corda/client/jfx/model/ContractStateModel.kt +++ b/client/jfx/src/main/kotlin/net/corda/client/jfx/model/ContractStateModel.kt @@ -19,7 +19,7 @@ data class Diff( * This model exposes the list of owned contract states. */ class ContractStateModel { - private val vaultUpdates: Observable by observable(NodeMonitorModel::vaultUpdates) + private val vaultUpdates: Observable> by observable(NodeMonitorModel::vaultUpdates) private val contractStatesDiff: Observable> = vaultUpdates.map { Diff(it.produced, it.consumed) diff --git a/client/jfx/src/main/kotlin/net/corda/client/jfx/model/NodeMonitorModel.kt b/client/jfx/src/main/kotlin/net/corda/client/jfx/model/NodeMonitorModel.kt index e2b134bc8c..33e160c742 100644 --- a/client/jfx/src/main/kotlin/net/corda/client/jfx/model/NodeMonitorModel.kt +++ b/client/jfx/src/main/kotlin/net/corda/client/jfx/model/NodeMonitorModel.kt @@ -3,14 +3,13 @@ package net.corda.client.jfx.model import javafx.beans.property.SimpleObjectProperty import net.corda.client.rpc.CordaRPCClient import net.corda.client.rpc.CordaRPCClientConfiguration +import net.corda.core.contracts.ContractState import net.corda.core.flows.StateMachineRunId -import net.corda.core.messaging.CordaRPCOps -import net.corda.core.messaging.StateMachineInfo -import net.corda.core.messaging.StateMachineTransactionMapping -import net.corda.core.messaging.StateMachineUpdate +import net.corda.core.messaging.* import net.corda.core.node.services.NetworkMapCache.MapChange import net.corda.core.node.services.Vault -import net.corda.core.seconds +import net.corda.core.node.services.vault.* +import net.corda.core.utilities.seconds import net.corda.core.transactions.SignedTransaction import net.corda.core.utilities.NetworkHostAndPort import rx.Observable @@ -32,14 +31,14 @@ data class ProgressTrackingEvent(val stateMachineId: StateMachineRunId, val mess class NodeMonitorModel { private val stateMachineUpdatesSubject = PublishSubject.create() - private val vaultUpdatesSubject = PublishSubject.create() + private val vaultUpdatesSubject = PublishSubject.create>() private val transactionsSubject = PublishSubject.create() private val stateMachineTransactionMappingSubject = PublishSubject.create() private val progressTrackingSubject = PublishSubject.create() private val networkMapSubject = PublishSubject.create() val stateMachineUpdates: Observable = stateMachineUpdatesSubject - val vaultUpdates: Observable = vaultUpdatesSubject + val vaultUpdates: Observable> = vaultUpdatesSubject val transactions: Observable = transactionsSubject val stateMachineTransactionMapping: Observable = stateMachineTransactionMappingSubject val progressTracking: Observable = progressTrackingSubject @@ -51,17 +50,18 @@ class NodeMonitorModel { * Register for updates to/from a given vault. * TODO provide an unsubscribe mechanism */ - fun register(nodeHostAndPort: NetworkHostAndPort, username: String, password: String) { + fun register(nodeHostAndPort: NetworkHostAndPort, username: String, password: String, initialiseSerialization: Boolean = true) { val client = CordaRPCClient( hostAndPort = nodeHostAndPort, configuration = CordaRPCClientConfiguration.default.copy( connectionMaxRetryInterval = 10.seconds - ) + ), + initialiseSerialization = initialiseSerialization ) val connection = client.start(username, password) val proxy = connection.proxy - val (stateMachines, stateMachineUpdates) = proxy.stateMachinesAndUpdates() + val (stateMachines, stateMachineUpdates) = proxy.stateMachinesFeed() // Extract the flow tracking stream // TODO is there a nicer way of doing this? Stream of streams in general results in code like this... val currentProgressTrackerUpdates = stateMachines.mapNotNull { stateMachine -> @@ -82,21 +82,22 @@ class NodeMonitorModel { val currentStateMachines = stateMachines.map { StateMachineUpdate.Added(it) } stateMachineUpdates.startWith(currentStateMachines).subscribe(stateMachineUpdatesSubject) - // Vault updates - val (vault, vaultUpdates) = proxy.vaultAndUpdates() - val initialVaultUpdate = Vault.Update(setOf(), vault.toSet()) + // Vault snapshot (force single page load with MAX_PAGE_SIZE) + updates + val (vaultSnapshot, vaultUpdates) = proxy.vaultTrackBy(QueryCriteria.VaultQueryCriteria(Vault.StateStatus.ALL), + PageSpecification(DEFAULT_PAGE_NUM, MAX_PAGE_SIZE)) + val initialVaultUpdate = Vault.Update(setOf(), vaultSnapshot.states.toSet()) vaultUpdates.startWith(initialVaultUpdate).subscribe(vaultUpdatesSubject) // Transactions - val (transactions, newTransactions) = proxy.verifiedTransactions() + val (transactions, newTransactions) = proxy.verifiedTransactionsFeed() newTransactions.startWith(transactions).subscribe(transactionsSubject) // SM -> TX mapping - val (smTxMappings, futureSmTxMappings) = proxy.stateMachineRecordedTransactionMapping() + val (smTxMappings, futureSmTxMappings) = proxy.stateMachineRecordedTransactionMappingFeed() futureSmTxMappings.startWith(smTxMappings).subscribe(stateMachineTransactionMappingSubject) // Parties on network - val (parties, futurePartyUpdate) = proxy.networkMapUpdates() + val (parties, futurePartyUpdate) = proxy.networkMapFeed() futurePartyUpdate.startWith(parties.map { MapChange.Added(it) }).subscribe(networkMapSubject) proxyObservable.set(proxy) diff --git a/client/jfx/src/main/kotlin/net/corda/client/jfx/utils/AmountBindings.kt b/client/jfx/src/main/kotlin/net/corda/client/jfx/utils/AmountBindings.kt index 31cb4bf3e6..26387c1663 100644 --- a/client/jfx/src/main/kotlin/net/corda/client/jfx/utils/AmountBindings.kt +++ b/client/jfx/src/main/kotlin/net/corda/client/jfx/utils/AmountBindings.kt @@ -23,13 +23,14 @@ object AmountBindings { ) { sum -> Amount(sum.toLong(), token) } fun exchange( - currency: ObservableValue, - exchangeRate: ObservableValue + observableCurrency: ObservableValue, + observableExchangeRate: ObservableValue ): ObservableValue) -> Long>> { - return EasyBind.combine(currency, exchangeRate) { currency, exchangeRate -> - Pair(currency) { amount: Amount -> - (exchangeRate.rate(amount.token, currency) * amount.quantity).toLong() - } + return EasyBind.combine(observableCurrency, observableExchangeRate) { currency, exchangeRate -> + Pair) -> Long>( + currency, + { (quantity, _, token) -> (exchangeRate.rate(token, currency) * quantity).toLong() } + ) } } diff --git a/client/jfx/src/main/kotlin/net/corda/client/jfx/utils/ObservableUtilities.kt b/client/jfx/src/main/kotlin/net/corda/client/jfx/utils/ObservableUtilities.kt index 5ccc04a7c4..198fa779c5 100644 --- a/client/jfx/src/main/kotlin/net/corda/client/jfx/utils/ObservableUtilities.kt +++ b/client/jfx/src/main/kotlin/net/corda/client/jfx/utils/ObservableUtilities.kt @@ -1,5 +1,6 @@ package net.corda.client.jfx.utils +import javafx.application.Platform import javafx.beans.binding.Bindings import javafx.beans.binding.BooleanBinding import javafx.beans.property.ReadOnlyObjectWrapper @@ -10,7 +11,13 @@ import javafx.collections.MapChangeListener import javafx.collections.ObservableList import javafx.collections.ObservableMap import javafx.collections.transformation.FilteredList +import net.corda.core.contracts.ContractState +import net.corda.core.contracts.StateAndRef +import net.corda.core.messaging.DataFeed +import net.corda.core.node.services.Vault import org.fxmisc.easybind.EasyBind +import rx.Observable +import rx.schedulers.Schedulers import java.util.function.Predicate /** @@ -313,3 +320,36 @@ fun ObservableList.firstOrDefault(default: ObservableValue, predicate fun ObservableList.firstOrNullObservable(predicate: (A) -> Boolean): ObservableValue { return Bindings.createObjectBinding({ this.firstOrNull(predicate) }, arrayOf(this)) } + +/** + * Modifies the given Rx observable such that emissions are run on the JavaFX GUI thread. Use this when you have an Rx + * observable that may emit in the background e.g. from the network and you wish to link it to the user interface. + * + * Note: you should use the returned observable, not the original one this method is called on. + */ +fun Observable.observeOnFXThread(): Observable = observeOn(Schedulers.from(Platform::runLater)) + +/** + * Given a [DataFeed] that contains the results of a vault query and a subsequent stream of changes, returns a JavaFX + * [ObservableList] that mirrors the streamed results on the UI thread. Note that the paging is *not* respected by this + * function: if a state is added that would not have appeared in the page in the initial query, it will still be added + * to the observable list. + * + * @see toFXListOfStates if you want just the state objects and not the ledger pointers too. + */ +fun DataFeed, Vault.Update>.toFXListOfStateRefs(): ObservableList> { + val list = FXCollections.observableArrayList(snapshot.states) + updates.observeOnFXThread().subscribe { (consumed, produced) -> + list.removeAll(consumed) + list.addAll(produced) + } + return list +} + +/** + * Returns the same list as [toFXListOfStateRefs] but which contains the states instead of [StateAndRef] wrappers. + * The same notes apply as with that function. + */ +fun DataFeed, Vault.Update>.toFXListOfStates(): ObservableList { + return toFXListOfStateRefs().map { it.state.data } +} \ No newline at end of file diff --git a/client/mock/build.gradle b/client/mock/build.gradle index d709d4c911..0c4a2efbb0 100644 --- a/client/mock/build.gradle +++ b/client/mock/build.gradle @@ -30,5 +30,5 @@ jar { } publish { - name = jar.baseName + name jar.baseName } \ No newline at end of file diff --git a/client/mock/src/main/kotlin/net/corda/client/mock/Generator.kt b/client/mock/src/main/kotlin/net/corda/client/mock/Generator.kt index 9748e2a2ca..8da2e4e49b 100644 --- a/client/mock/src/main/kotlin/net/corda/client/mock/Generator.kt +++ b/client/mock/src/main/kotlin/net/corda/client/mock/Generator.kt @@ -167,21 +167,23 @@ fun Generator.Companion.replicate(number: Int, generator: Generator): Gen } -fun Generator.Companion.replicatePoisson(meanSize: Double, generator: Generator) = Generator> { +fun Generator.Companion.replicatePoisson(meanSize: Double, generator: Generator, atLeastOne: Boolean = false) = Generator> { val chance = (meanSize - 1) / meanSize val result = mutableListOf() var finish = false while (!finish) { - val result = Generator.doubleRange(0.0, 1.0).generate(it).flatMap { value -> + val res = Generator.doubleRange(0.0, 1.0).generate(it).flatMap { value -> if (value < chance) { generator.generate(it).map { result.add(it) } } else { finish = true - Try.Success(Unit) + if (result.isEmpty() && atLeastOne) { + generator.generate(it).map { result.add(it) } + } else Try.Success(Unit) } } - if (result is Try.Failure) { - return@Generator result + if (res is Try.Failure) { + return@Generator res } } Try.Success(result) diff --git a/client/rpc/build.gradle b/client/rpc/build.gradle index b2ab10dff4..61159ac74d 100644 --- a/client/rpc/build.gradle +++ b/client/rpc/build.gradle @@ -24,6 +24,11 @@ sourceSets { runtimeClasspath += main.output + test.output srcDir file('src/integration-test/kotlin') } + java { + compileClasspath += main.output + test.output + runtimeClasspath += main.output + test.output + srcDir file('src/integration-test/java') + } } smokeTest { kotlin { @@ -33,6 +38,11 @@ sourceSets { runtimeClasspath += main.output srcDir file('src/smoke-test/kotlin') } + java { + compileClasspath += main.output + runtimeClasspath += main.output + srcDir file('src/smoke-test/java') + } } } @@ -82,5 +92,5 @@ jar { } publish { - name = jar.baseName + name jar.baseName } diff --git a/client/rpc/src/integration-test/java/net/corda/client/rpc/CordaRPCJavaClientTest.java b/client/rpc/src/integration-test/java/net/corda/client/rpc/CordaRPCJavaClientTest.java new file mode 100644 index 0000000000..9f38c17316 --- /dev/null +++ b/client/rpc/src/integration-test/java/net/corda/client/rpc/CordaRPCJavaClientTest.java @@ -0,0 +1,81 @@ +package net.corda.client.rpc; + +import net.corda.core.concurrent.CordaFuture; +import net.corda.client.rpc.internal.RPCClient; +import net.corda.core.contracts.Amount; +import net.corda.core.messaging.CordaRPCOps; +import net.corda.core.messaging.FlowHandle; +import net.corda.core.node.services.ServiceInfo; +import net.corda.core.utilities.OpaqueBytes; +import net.corda.flows.AbstractCashFlow; +import net.corda.flows.CashIssueFlow; +import net.corda.flows.CashPaymentFlow; +import net.corda.node.internal.Node; +import net.corda.node.services.transactions.ValidatingNotaryService; +import net.corda.nodeapi.User; +import net.corda.testing.node.NodeBasedTest; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import java.io.IOException; +import java.util.*; +import java.util.concurrent.ExecutionException; + +import static kotlin.test.AssertionsKt.assertEquals; +import static net.corda.client.rpc.CordaRPCClientConfiguration.getDefault; +import static net.corda.contracts.GetBalances.getCashBalance; +import static net.corda.node.services.RPCUserServiceKt.startFlowPermission; +import static net.corda.testing.TestConstants.getALICE; + +public class CordaRPCJavaClientTest extends NodeBasedTest { + private List perms = Arrays.asList(startFlowPermission(CashPaymentFlow.class), startFlowPermission(CashIssueFlow.class)); + private Set permSet = new HashSet<>(perms); + private User rpcUser = new User("user1", "test", permSet); + + private Node node; + private CordaRPCClient client; + private RPCClient.RPCConnection connection = null; + private CordaRPCOps rpcProxy; + + private void login(String username, String password) { + connection = client.start(username, password); + rpcProxy = connection.getProxy(); + } + + @Before + public void setUp() throws ExecutionException, InterruptedException { + Set services = new HashSet<>(Collections.singletonList(new ServiceInfo(ValidatingNotaryService.Companion.getType(), null))); + CordaFuture nodeFuture = startNode(getALICE().getName(), 1, services, Arrays.asList(rpcUser), Collections.emptyMap()); + node = nodeFuture.get(); + client = new CordaRPCClient(node.getConfiguration().getRpcAddress(), null, getDefault(), false); + } + + @After + public void done() throws IOException { + connection.close(); + } + + @Test + public void testLogin() { + login(rpcUser.getUsername(), rpcUser.getPassword()); + } + + @Test + public void testCashBalances() throws NoSuchFieldException, ExecutionException, InterruptedException { + login(rpcUser.getUsername(), rpcUser.getPassword()); + + Amount dollars123 = new Amount<>(123, Currency.getInstance("USD")); + + FlowHandle flowHandle = rpcProxy.startFlowDynamic(CashIssueFlow.class, + dollars123, OpaqueBytes.of("1".getBytes()), + node.info.getLegalIdentity(), node.info.getLegalIdentity()); + System.out.println("Started issuing cash, waiting on result"); + flowHandle.getReturnValue().get(); + + Amount balance = getCashBalance(rpcProxy, Currency.getInstance("USD")); + System.out.print("Balance: " + balance + "\n"); + + assertEquals(dollars123, balance, "matching"); + } +} diff --git a/client/rpc/src/integration-test/kotlin/net/corda/client/rpc/CordaRPCClientTest.kt b/client/rpc/src/integration-test/kotlin/net/corda/client/rpc/CordaRPCClientTest.kt index 40a179fd29..19e5fdc50e 100644 --- a/client/rpc/src/integration-test/kotlin/net/corda/client/rpc/CordaRPCClientTest.kt +++ b/client/rpc/src/integration-test/kotlin/net/corda/client/rpc/CordaRPCClientTest.kt @@ -1,13 +1,15 @@ package net.corda.client.rpc +import net.corda.contracts.getCashBalance +import net.corda.contracts.getCashBalances import net.corda.core.contracts.DOLLARS +import net.corda.core.contracts.USD +import net.corda.core.crypto.random63BitValue import net.corda.core.flows.FlowInitiator -import net.corda.core.getOrThrow import net.corda.core.messaging.* import net.corda.core.node.services.ServiceInfo -import net.corda.core.crypto.random63BitValue import net.corda.core.utilities.OpaqueBytes -import net.corda.testing.ALICE +import net.corda.core.utilities.getOrThrow import net.corda.flows.CashException import net.corda.flows.CashIssueFlow import net.corda.flows.CashPaymentFlow @@ -15,13 +17,13 @@ import net.corda.node.internal.Node import net.corda.node.services.startFlowPermission import net.corda.node.services.transactions.ValidatingNotaryService import net.corda.nodeapi.User +import net.corda.testing.ALICE import net.corda.testing.node.NodeBasedTest import org.apache.activemq.artemis.api.core.ActiveMQSecurityException import org.assertj.core.api.Assertions.assertThatExceptionOfType import org.junit.After import org.junit.Before import org.junit.Test -import java.util.* import kotlin.test.assertEquals import kotlin.test.assertFalse import kotlin.test.assertTrue @@ -42,7 +44,7 @@ class CordaRPCClientTest : NodeBasedTest() { @Before fun setUp() { node = startNode(ALICE.name, rpcUsers = listOf(rpcUser), advertisedServices = setOf(ServiceInfo(ValidatingNotaryService.type))).getOrThrow() - client = CordaRPCClient(node.configuration.rpcAddress!!) + client = CordaRPCClient(node.configuration.rpcAddress!!, initialiseSerialization = false) } @After @@ -117,20 +119,18 @@ class CordaRPCClientTest : NodeBasedTest() { println("Started issuing cash, waiting on result") flowHandle.returnValue.get() - val finishCash = proxy.getCashBalances() - println("Cash Balances: $finishCash") - assertEquals(1, finishCash.size) - assertEquals(123.DOLLARS, finishCash.get(Currency.getInstance("USD"))) + val cashDollars = proxy.getCashBalance(USD) + println("Balance: $cashDollars") + assertEquals(123.DOLLARS, cashDollars) } @Test fun `flow initiator via RPC`() { login(rpcUser.username, rpcUser.password) val proxy = connection!!.proxy - val smUpdates = proxy.stateMachinesAndUpdates() var countRpcFlows = 0 var countShellFlows = 0 - smUpdates.second.subscribe { + proxy.stateMachinesFeed().updates.subscribe { if (it is StateMachineUpdate.Added) { val initiator = it.stateMachineInfo.initiator if (initiator is FlowInitiator.RPC) diff --git a/client/rpc/src/integration-test/kotlin/net/corda/client/rpc/RPCStabilityTests.kt b/client/rpc/src/integration-test/kotlin/net/corda/client/rpc/RPCStabilityTests.kt index ac524995c7..4433ba03be 100644 --- a/client/rpc/src/integration-test/kotlin/net/corda/client/rpc/RPCStabilityTests.kt +++ b/client/rpc/src/integration-test/kotlin/net/corda/client/rpc/RPCStabilityTests.kt @@ -1,24 +1,15 @@ package net.corda.client.rpc -import com.esotericsoftware.kryo.Kryo -import com.esotericsoftware.kryo.Serializer -import com.esotericsoftware.kryo.io.Input -import com.esotericsoftware.kryo.io.Output -import com.esotericsoftware.kryo.pool.KryoPool -import com.google.common.util.concurrent.Futures import net.corda.client.rpc.internal.RPCClient import net.corda.client.rpc.internal.RPCClientConfiguration import net.corda.core.crypto.random63BitValue -import net.corda.core.future -import net.corda.core.getOrThrow +import net.corda.core.internal.concurrent.fork +import net.corda.core.internal.concurrent.transpose import net.corda.core.messaging.RPCOps -import net.corda.core.millis -import net.corda.core.seconds -import net.corda.core.utilities.NetworkHostAndPort -import net.corda.core.utilities.Try +import net.corda.core.serialization.SerializationDefaults +import net.corda.core.utilities.* import net.corda.node.services.messaging.RPCServerConfiguration import net.corda.nodeapi.RPCApi -import net.corda.nodeapi.RPCKryo import net.corda.testing.* import net.corda.testing.driver.poll import org.apache.activemq.artemis.api.core.SimpleString @@ -29,10 +20,7 @@ import rx.Observable import rx.subjects.PublishSubject import rx.subjects.UnicastSubject import java.time.Duration -import java.util.concurrent.ConcurrentLinkedQueue -import java.util.concurrent.Executors -import java.util.concurrent.ScheduledExecutorService -import java.util.concurrent.TimeUnit +import java.util.concurrent.* import java.util.concurrent.atomic.AtomicInteger class RPCStabilityTests { @@ -238,9 +226,7 @@ class RPCStabilityTests { assertEquals("pong", client.ping()) serverFollower.shutdown() startRpcServer(ops = ops, customPort = serverPort).getOrThrow() - val pingFuture = future { - client.ping() - } + val pingFuture = ForkJoinPool.commonPool().fork(client::ping) assertEquals("pong", pingFuture.getOrThrow(10.seconds)) clientFollower.shutdown() // Driver would do this after the new server, causing hang. } @@ -274,9 +260,9 @@ class RPCStabilityTests { ).get() val numberOfClients = 4 - val clients = Futures.allAsList((1 .. numberOfClients).map { + val clients = (1 .. numberOfClients).map { startRandomRpcClient(server.broker.hostAndPort!!) - }).get() + }.transpose().get() // Poll until all clients connect pollUntilClientNumber(server, numberOfClients) @@ -305,16 +291,8 @@ class RPCStabilityTests { return Observable.interval(interval.toMillis(), TimeUnit.MILLISECONDS).map { chunk } } } - val dummyObservableSerialiser = object : Serializer>() { - override fun write(kryo: Kryo?, output: Output?, `object`: Observable?) { - } - override fun read(kryo: Kryo?, input: Input?, type: Class>?): Observable { - return Observable.empty() - } - } @Test fun `slow consumers are kicked`() { - val kryoPool = KryoPool.Builder { RPCKryo(dummyObservableSerialiser) }.build() rpcDriver { val server = startRpcServer(maxBufferedBytesPerClient = 10 * 1024 * 1024, ops = SlowConsumerRPCOpsImpl()).get() @@ -339,7 +317,7 @@ class RPCStabilityTests { methodName = SlowConsumerRPCOps::streamAtInterval.name, arguments = listOf(10.millis, 123456) ) - request.writeToClientMessage(kryoPool, message) + request.writeToClientMessage(SerializationDefaults.RPC_SERVER_CONTEXT, message) producer.send(message) session.commit() diff --git a/client/rpc/src/main/kotlin/net/corda/client/rpc/CordaRPCClient.kt b/client/rpc/src/main/kotlin/net/corda/client/rpc/CordaRPCClient.kt index 78baa8d906..3ec8268c3e 100644 --- a/client/rpc/src/main/kotlin/net/corda/client/rpc/CordaRPCClient.kt +++ b/client/rpc/src/main/kotlin/net/corda/client/rpc/CordaRPCClient.kt @@ -2,11 +2,17 @@ package net.corda.client.rpc import net.corda.client.rpc.internal.RPCClient import net.corda.client.rpc.internal.RPCClientConfiguration +import net.corda.client.rpc.serialization.KryoClientSerializationScheme import net.corda.core.messaging.CordaRPCOps +import net.corda.core.serialization.SerializationDefaults import net.corda.core.utilities.NetworkHostAndPort import net.corda.nodeapi.ArtemisTcpTransport.Companion.tcpTransport import net.corda.nodeapi.ConnectionDirection import net.corda.nodeapi.config.SSLConfiguration +import net.corda.nodeapi.internal.serialization.AMQPClientSerializationScheme +import net.corda.nodeapi.internal.serialization.KRYO_P2P_CONTEXT +import net.corda.nodeapi.internal.serialization.KRYO_RPC_CLIENT_CONTEXT +import net.corda.nodeapi.internal.serialization.SerializationFactoryImpl import java.time.Duration /** @see RPCClient.RPCConnection */ @@ -35,11 +41,22 @@ data class CordaRPCClientConfiguration( class CordaRPCClient( hostAndPort: NetworkHostAndPort, sslConfiguration: SSLConfiguration? = null, - configuration: CordaRPCClientConfiguration = CordaRPCClientConfiguration.default + configuration: CordaRPCClientConfiguration = CordaRPCClientConfiguration.default, + initialiseSerialization: Boolean = true ) { + init { + // Init serialization. It's plausible there are multiple clients in a single JVM, so be tolerant of + // others having registered first. + // TODO: allow clients to have serialization factory etc injected and align with RPC protocol version? + if (initialiseSerialization) { + initialiseSerialization() + } + } + private val rpcClient = RPCClient( tcpTransport(ConnectionDirection.Outbound(), hostAndPort, sslConfiguration), - configuration.toRpcClientConfiguration() + configuration.toRpcClientConfiguration(), + KRYO_RPC_CLIENT_CONTEXT ) fun start(username: String, password: String): CordaRPCConnection { @@ -49,4 +66,21 @@ class CordaRPCClient( inline fun use(username: String, password: String, block: (CordaRPCConnection) -> A): A { return start(username, password).use(block) } + + companion object { + fun initialiseSerialization() { + try { + SerializationDefaults.SERIALIZATION_FACTORY = SerializationFactoryImpl().apply { + registerScheme(KryoClientSerializationScheme(this)) + registerScheme(AMQPClientSerializationScheme()) + } + SerializationDefaults.P2P_CONTEXT = KRYO_P2P_CONTEXT + SerializationDefaults.RPC_CLIENT_CONTEXT = KRYO_RPC_CLIENT_CONTEXT + } catch(e: IllegalStateException) { + // Check that it's registered as we expect + check(SerializationDefaults.SERIALIZATION_FACTORY is SerializationFactoryImpl) { "RPC client encountered conflicting configuration of serialization subsystem." } + check((SerializationDefaults.SERIALIZATION_FACTORY as SerializationFactoryImpl).alreadyRegisteredSchemes.any { it is KryoClientSerializationScheme }) { "RPC client encountered conflicting configuration of serialization subsystem." } + } + } + } } \ No newline at end of file diff --git a/client/rpc/src/main/kotlin/net/corda/client/rpc/internal/RPCClient.kt b/client/rpc/src/main/kotlin/net/corda/client/rpc/internal/RPCClient.kt index a792c24faa..94fa65a018 100644 --- a/client/rpc/src/main/kotlin/net/corda/client/rpc/internal/RPCClient.kt +++ b/client/rpc/src/main/kotlin/net/corda/client/rpc/internal/RPCClient.kt @@ -1,10 +1,12 @@ package net.corda.client.rpc.internal -import net.corda.core.logElapsedTime -import net.corda.core.messaging.RPCOps -import net.corda.core.minutes import net.corda.core.crypto.random63BitValue -import net.corda.core.seconds +import net.corda.core.internal.logElapsedTime +import net.corda.core.messaging.RPCOps +import net.corda.core.serialization.SerializationContext +import net.corda.core.serialization.SerializationDefaults +import net.corda.core.utilities.minutes +import net.corda.core.utilities.seconds import net.corda.core.utilities.NetworkHostAndPort import net.corda.core.utilities.loggerFor import net.corda.nodeapi.ArtemisTcpTransport.Companion.tcpTransport @@ -85,13 +87,15 @@ data class RPCClientConfiguration( */ class RPCClient( val transport: TransportConfiguration, - val rpcConfiguration: RPCClientConfiguration = RPCClientConfiguration.default + val rpcConfiguration: RPCClientConfiguration = RPCClientConfiguration.default, + val serializationContext: SerializationContext = SerializationDefaults.RPC_CLIENT_CONTEXT ) { constructor( hostAndPort: NetworkHostAndPort, sslConfiguration: SSLConfiguration? = null, - configuration: RPCClientConfiguration = RPCClientConfiguration.default - ) : this(tcpTransport(ConnectionDirection.Outbound(), hostAndPort, sslConfiguration), configuration) + configuration: RPCClientConfiguration = RPCClientConfiguration.default, + serializationContext: SerializationContext = SerializationDefaults.RPC_CLIENT_CONTEXT + ) : this(tcpTransport(ConnectionDirection.Outbound(), hostAndPort, sslConfiguration), configuration, serializationContext) companion object { private val log = loggerFor>() @@ -146,7 +150,7 @@ class RPCClient( minLargeMessageSize = rpcConfiguration.maxFileSize } - val proxyHandler = RPCClientProxyHandler(rpcConfiguration, username, password, serverLocator, clientAddress, rpcOpsClass) + val proxyHandler = RPCClientProxyHandler(rpcConfiguration, username, password, serverLocator, clientAddress, rpcOpsClass, serializationContext) try { proxyHandler.start() diff --git a/client/rpc/src/main/kotlin/net/corda/client/rpc/internal/RPCClientProxyHandler.kt b/client/rpc/src/main/kotlin/net/corda/client/rpc/internal/RPCClientProxyHandler.kt index e83363b7fc..98a9e221ff 100644 --- a/client/rpc/src/main/kotlin/net/corda/client/rpc/internal/RPCClientProxyHandler.kt +++ b/client/rpc/src/main/kotlin/net/corda/client/rpc/internal/RPCClientProxyHandler.kt @@ -4,18 +4,19 @@ import com.esotericsoftware.kryo.Kryo import com.esotericsoftware.kryo.Serializer import com.esotericsoftware.kryo.io.Input import com.esotericsoftware.kryo.io.Output -import com.esotericsoftware.kryo.pool.KryoPool import com.google.common.cache.Cache import com.google.common.cache.CacheBuilder import com.google.common.cache.RemovalCause import com.google.common.cache.RemovalListener import com.google.common.util.concurrent.SettableFuture import com.google.common.util.concurrent.ThreadFactoryBuilder -import net.corda.core.ThreadBox +import net.corda.core.internal.ThreadBox import net.corda.core.crypto.random63BitValue -import net.corda.core.getOrThrow +import net.corda.core.internal.LazyPool +import net.corda.core.internal.LazyStickyPool +import net.corda.core.internal.LifeCycle import net.corda.core.messaging.RPCOps -import net.corda.core.serialization.KryoPoolWithContext +import net.corda.core.serialization.SerializationContext import net.corda.core.utilities.* import net.corda.nodeapi.* import org.apache.activemq.artemis.api.core.SimpleString @@ -61,7 +62,8 @@ class RPCClientProxyHandler( private val rpcPassword: String, private val serverLocator: ServerLocator, private val clientAddress: SimpleString, - private val rpcOpsClass: Class + private val rpcOpsClass: Class, + serializationContext: SerializationContext ) : InvocationHandler { private enum class State { @@ -74,9 +76,6 @@ class RPCClientProxyHandler( private companion object { val log = loggerFor() - // Note that this KryoPool is not yet capable of deserialising Observables, it requires Proxy-specific context - // to do that. However it may still be used for serialisation of RPC requests and related messages. - val kryoPool: KryoPool = KryoPool.Builder { RPCKryo(RpcClientObservableSerializer) }.build() // To check whether toString() is being invoked val toStringMethod: Method = Object::toString.javaMethod!! } @@ -85,7 +84,7 @@ class RPCClientProxyHandler( private var reaperExecutor: ScheduledExecutorService? = null // A sticky pool for running Observable.onNext()s. We need the stickiness to preserve the observation ordering. - private val observationExecutorThreadFactory = ThreadFactoryBuilder().setNameFormat("rpc-client-observation-pool-%d").build() + private val observationExecutorThreadFactory = ThreadFactoryBuilder().setNameFormat("rpc-client-observation-pool-%d").setDaemon(true).build() private val observationExecutorPool = LazyStickyPool(rpcConfiguration.observationExecutorPoolSize) { Executors.newFixedThreadPool(1, observationExecutorThreadFactory) } @@ -109,11 +108,10 @@ class RPCClientProxyHandler( private val observablesToReap = ThreadBox(object { var observables = ArrayList() }) - // A Kryo pool that automatically adds the observable context when an instance is requested. - private val kryoPoolWithObservableContext = RpcClientObservableSerializer.createPoolWithContext(kryoPool, observableContext) + private val serializationContextWithObservableContext = RpcClientObservableSerializer.createContext(serializationContext, observableContext) private fun createRpcObservableMap(): RpcObservableMap { - val onObservableRemove = RemovalListener>> { + val onObservableRemove = RemovalListener>> { val rpcCallSite = callSiteMap?.remove(it.key.toLong) if (it.cause == RemovalCause.COLLECTED) { log.warn(listOf( @@ -194,7 +192,7 @@ class RPCClientProxyHandler( val replyFuture = SettableFuture.create() sessionAndProducerPool.run { val message = it.session.createMessage(false) - request.writeToClientMessage(kryoPool, message) + request.writeToClientMessage(serializationContextWithObservableContext, message) log.debug { val argumentsString = arguments?.joinToString() ?: "" @@ -221,7 +219,7 @@ class RPCClientProxyHandler( // The handler for Artemis messages. private fun artemisMessageHandler(message: ClientMessage) { - val serverToClient = RPCApi.ServerToClient.fromClientMessage(kryoPoolWithObservableContext, message) + val serverToClient = RPCApi.ServerToClient.fromClientMessage(serializationContextWithObservableContext, message) log.debug { "Got message from RPC server $serverToClient" } when (serverToClient) { is RPCApi.ServerToClient.RpcReply -> { @@ -338,7 +336,7 @@ class RPCClientProxyHandler( } } -private typealias RpcObservableMap = Cache>> +private typealias RpcObservableMap = Cache>> private typealias RpcReplyMap = ConcurrentHashMap> private typealias CallSiteMap = ConcurrentHashMap @@ -348,7 +346,7 @@ private typealias CallSiteMap = ConcurrentHashMap * @param observableMap holds the Observables that are ultimately exposed to the user. * @param hardReferenceStore holds references to Observables we want to keep alive while they are subscribed to. */ -private data class ObservableContext( +data class ObservableContext( val callSiteMap: CallSiteMap?, val observableMap: RpcObservableMap, val hardReferenceStore: MutableSet> @@ -357,17 +355,17 @@ private data class ObservableContext( /** * A [Serializer] to deserialise Observables once the corresponding Kryo instance has been provided with an [ObservableContext]. */ -private object RpcClientObservableSerializer : Serializer>() { +object RpcClientObservableSerializer : Serializer>() { private object RpcObservableContextKey - fun createPoolWithContext(kryoPool: KryoPool, observableContext: ObservableContext): KryoPool { - return KryoPoolWithContext(kryoPool, RpcObservableContextKey, observableContext) + + fun createContext(serializationContext: SerializationContext, observableContext: ObservableContext): SerializationContext { + return serializationContext.withProperty(RpcObservableContextKey, observableContext) } - override fun read(kryo: Kryo, input: Input, type: Class>): Observable { - @Suppress("UNCHECKED_CAST") + override fun read(kryo: Kryo, input: Input, type: Class>): Observable { val observableContext = kryo.context[RpcObservableContextKey] as ObservableContext val observableId = RPCApi.ObservableId(input.readLong(true)) - val observable = UnicastSubject.create>() + val observable = UnicastSubject.create>() require(observableContext.observableMap.getIfPresent(observableId) == null) { "Multiple Observables arrived with the same ID $observableId" } @@ -384,7 +382,7 @@ private object RpcClientObservableSerializer : Serializer>() { }.dematerialize() } - override fun write(kryo: Kryo, output: Output, observable: Observable) { + override fun write(kryo: Kryo, output: Output, observable: Observable<*>) { throw UnsupportedOperationException("Cannot serialise Observables on the client side") } diff --git a/client/rpc/src/main/kotlin/net/corda/client/rpc/serialization/SerializationScheme.kt b/client/rpc/src/main/kotlin/net/corda/client/rpc/serialization/SerializationScheme.kt new file mode 100644 index 0000000000..0bb26b93fb --- /dev/null +++ b/client/rpc/src/main/kotlin/net/corda/client/rpc/serialization/SerializationScheme.kt @@ -0,0 +1,28 @@ +package net.corda.client.rpc.serialization + +import com.esotericsoftware.kryo.pool.KryoPool +import net.corda.client.rpc.internal.RpcClientObservableSerializer +import net.corda.core.serialization.SerializationContext +import net.corda.core.serialization.SerializationFactory +import net.corda.core.utilities.ByteSequence +import net.corda.nodeapi.RPCKryo +import net.corda.nodeapi.internal.serialization.AbstractKryoSerializationScheme +import net.corda.nodeapi.internal.serialization.DefaultKryoCustomizer +import net.corda.nodeapi.internal.serialization.KryoHeaderV0_1 + +class KryoClientSerializationScheme(serializationFactory: SerializationFactory) : AbstractKryoSerializationScheme(serializationFactory) { + override fun canDeserializeVersion(byteSequence: ByteSequence, target: SerializationContext.UseCase): Boolean { + return byteSequence == KryoHeaderV0_1 && (target == SerializationContext.UseCase.RPCClient || target == SerializationContext.UseCase.P2P) + } + + override fun rpcClientKryoPool(context: SerializationContext): KryoPool { + return KryoPool.Builder { + DefaultKryoCustomizer.customize(RPCKryo(RpcClientObservableSerializer, serializationFactory, context)).apply { classLoader = context.deserializationClassLoader } + }.build() + } + + // We're on the client and don't have access to server classes. + override fun rpcServerKryoPool(context: SerializationContext): KryoPool { + throw UnsupportedOperationException() + } +} \ No newline at end of file diff --git a/client/rpc/src/smoke-test/java/net/corda/java/rpc/StandaloneCordaRPCJavaClientTest.java b/client/rpc/src/smoke-test/java/net/corda/java/rpc/StandaloneCordaRPCJavaClientTest.java new file mode 100644 index 0000000000..4d5a250ef6 --- /dev/null +++ b/client/rpc/src/smoke-test/java/net/corda/java/rpc/StandaloneCordaRPCJavaClientTest.java @@ -0,0 +1,87 @@ +package net.corda.java.rpc; + +import net.corda.client.rpc.CordaRPCConnection; +import net.corda.core.contracts.Amount; +import net.corda.core.messaging.CordaRPCOps; +import net.corda.core.messaging.DataFeed; +import net.corda.core.messaging.FlowHandle; +import net.corda.core.node.NodeInfo; +import net.corda.core.node.services.NetworkMapCache; +import net.corda.core.utilities.OpaqueBytes; +import net.corda.flows.AbstractCashFlow; +import net.corda.flows.CashIssueFlow; +import net.corda.nodeapi.User; +import net.corda.smoketesting.NodeConfig; +import net.corda.smoketesting.NodeProcess; +import org.bouncycastle.asn1.x500.X500Name; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import java.util.*; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.atomic.AtomicInteger; + +import static kotlin.test.AssertionsKt.assertEquals; +import static net.corda.contracts.GetBalances.getCashBalance; + +public class StandaloneCordaRPCJavaClientTest { + private List perms = Collections.singletonList("ALL"); + private Set permSet = new HashSet<>(perms); + private User rpcUser = new User("user1", "test", permSet); + + private AtomicInteger port = new AtomicInteger(15000); + + private NodeProcess notary; + private CordaRPCOps rpcProxy; + private CordaRPCConnection connection; + private NodeInfo notaryNode; + + private NodeConfig notaryConfig = new NodeConfig( + new X500Name("CN=Notary Service,O=R3,OU=corda,L=Zurich,C=CH"), + port.getAndIncrement(), + port.getAndIncrement(), + port.getAndIncrement(), + Collections.singletonList("corda.notary.validating"), + Arrays.asList(rpcUser), + null + ); + + @Before + public void setUp() { + notary = new NodeProcess.Factory().create(notaryConfig); + connection = notary.connect(); + rpcProxy = connection.getProxy(); + notaryNode = fetchNotaryIdentity(); + } + + @After + public void done() { + try { + connection.close(); + } finally { + notary.close(); + } + } + + private NodeInfo fetchNotaryIdentity() { + DataFeed, NetworkMapCache.MapChange> nodeDataFeed = rpcProxy.networkMapFeed(); + return nodeDataFeed.getSnapshot().get(0); + } + + @Test + public void testCashBalances() throws NoSuchFieldException, ExecutionException, InterruptedException { + Amount dollars123 = new Amount<>(123, Currency.getInstance("USD")); + + FlowHandle flowHandle = rpcProxy.startFlowDynamic(CashIssueFlow.class, + dollars123, OpaqueBytes.of("1".getBytes()), + notaryNode.getLegalIdentity(), notaryNode.getLegalIdentity()); + System.out.println("Started issuing cash, waiting on result"); + flowHandle.getReturnValue().get(); + + Amount balance = getCashBalance(rpcProxy, Currency.getInstance("USD")); + System.out.print("Balance: " + balance + "\n"); + + assertEquals(dollars123, balance, "matching"); + } +} diff --git a/client/rpc/src/smoke-test/kotlin/net/corda/kotlin/rpc/StandaloneCordaRPClientTest.kt b/client/rpc/src/smoke-test/kotlin/net/corda/kotlin/rpc/StandaloneCordaRPClientTest.kt index 182ca10eea..f26ea31ff8 100644 --- a/client/rpc/src/smoke-test/kotlin/net/corda/kotlin/rpc/StandaloneCordaRPClientTest.kt +++ b/client/rpc/src/smoke-test/kotlin/net/corda/kotlin/rpc/StandaloneCordaRPClientTest.kt @@ -5,19 +5,19 @@ import com.google.common.hash.HashingInputStream import net.corda.client.rpc.CordaRPCConnection import net.corda.client.rpc.notUsed import net.corda.contracts.asset.Cash -import net.corda.core.contracts.DOLLARS -import net.corda.core.contracts.POUNDS -import net.corda.core.contracts.SWISS_FRANCS +import net.corda.contracts.getCashBalance +import net.corda.contracts.getCashBalances +import net.corda.core.contracts.* import net.corda.core.crypto.SecureHash -import net.corda.core.getOrThrow +import net.corda.core.internal.InputStreamAndHash import net.corda.core.messaging.* import net.corda.core.node.NodeInfo import net.corda.core.node.services.Vault import net.corda.core.node.services.vault.* -import net.corda.core.seconds import net.corda.core.utilities.OpaqueBytes -import net.corda.core.sizedInputStreamAndHash +import net.corda.core.utilities.getOrThrow import net.corda.core.utilities.loggerFor +import net.corda.core.utilities.seconds import net.corda.flows.CashIssueFlow import net.corda.flows.CashPaymentFlow import net.corda.nodeapi.User @@ -35,6 +35,7 @@ import java.util.concurrent.atomic.AtomicInteger import kotlin.test.assertEquals import kotlin.test.assertFalse import kotlin.test.assertNotEquals +import kotlin.test.assertTrue class StandaloneCordaRPClientTest { private companion object { @@ -78,7 +79,7 @@ class StandaloneCordaRPClientTest { @Test fun `test attachments`() { - val attachment = sizedInputStreamAndHash(attachmentSize) + val attachment = InputStreamAndHash.createInMemoryTestZip(attachmentSize, 1) assertFalse(rpcProxy.attachmentExists(attachment.sha256)) val id = WrapperStream(attachment.inputStream).use { rpcProxy.uploadAttachment(it) } assertEquals(attachment.sha256, id, "Attachment has incorrect SHA256 hash") @@ -117,38 +118,38 @@ class StandaloneCordaRPClientTest { @Test fun `test state machines`() { - val (stateMachines, updates) = rpcProxy.stateMachinesAndUpdates() + val (stateMachines, updates) = rpcProxy.stateMachinesFeed() assertEquals(0, stateMachines.size) - var updateCount = 0 + val updateCount = AtomicInteger(0) updates.subscribe { update -> if (update is StateMachineUpdate.Added) { log.info("StateMachine>> Id=${update.id}") - ++updateCount + updateCount.incrementAndGet() } } // Now issue some cash rpcProxy.startFlow(::CashIssueFlow, 513.SWISS_FRANCS, OpaqueBytes.of(0), notaryNode.legalIdentity, notaryNode.notaryIdentity) .returnValue.getOrThrow(timeout) - assertEquals(1, updateCount) + assertEquals(1, updateCount.get()) } @Test fun `test vault track by`() { - val (vault, vaultUpdates) = rpcProxy.vaultTrackBy() - assertEquals(0, vault.states.size) + val (vault, vaultUpdates) = rpcProxy.vaultTrackBy(paging = PageSpecification(DEFAULT_PAGE_NUM)) + assertEquals(0, vault.totalStatesAvailable) - var updateCount = 0 + val updateCount = AtomicInteger(0) vaultUpdates.subscribe { update -> log.info("Vault>> FlowId=${update.flowId}") - ++updateCount + updateCount.incrementAndGet() } // Now issue some cash rpcProxy.startFlow(::CashIssueFlow, 629.POUNDS, OpaqueBytes.of(0), notaryNode.legalIdentity, notaryNode.notaryIdentity) .returnValue.getOrThrow(timeout) - assertNotEquals(0, updateCount) + assertNotEquals(0, updateCount.get()) // Check that this cash exists in the vault val cashBalance = rpcProxy.getCashBalances() @@ -177,10 +178,27 @@ class StandaloneCordaRPClientTest { assertEquals(3, moreResults.totalStatesAvailable) // 629 - 100 + 100 // Check that this cash exists in the vault - val cashBalance = rpcProxy.getCashBalances() - log.info("Cash Balances: $cashBalance") - assertEquals(1, cashBalance.size) - assertEquals(629.POUNDS, cashBalance[Currency.getInstance("GBP")]) + val cashBalances = rpcProxy.getCashBalances() + log.info("Cash Balances: $cashBalances") + assertEquals(1, cashBalances.size) + assertEquals(629.POUNDS, cashBalances[Currency.getInstance("GBP")]) + } + + @Test + fun `test cash balances`() { + val startCash = rpcProxy.getCashBalances() + assertTrue(startCash.isEmpty(), "Should not start with any cash") + + val flowHandle = rpcProxy.startFlow(::CashIssueFlow, + 629.DOLLARS, OpaqueBytes.of(0), + notaryNode.legalIdentity, notaryNode.legalIdentity + ) + println("Started issuing cash, waiting on result") + flowHandle.returnValue.get() + + val balance = rpcProxy.getCashBalance(USD) + println("Balance: " + balance) + assertEquals(629.DOLLARS, balance) } private fun fetchNotaryIdentity(): NodeInfo { diff --git a/client/rpc/src/smoke-test/kotlin/net/corda/kotlin/rpc/ValidateClasspathTest.kt b/client/rpc/src/smoke-test/kotlin/net/corda/kotlin/rpc/ValidateClasspathTest.kt index ecc534ca8a..f79e97641a 100644 --- a/client/rpc/src/smoke-test/kotlin/net/corda/kotlin/rpc/ValidateClasspathTest.kt +++ b/client/rpc/src/smoke-test/kotlin/net/corda/kotlin/rpc/ValidateClasspathTest.kt @@ -1,6 +1,6 @@ package net.corda.kotlin.rpc -import net.corda.core.div +import net.corda.core.internal.div import org.junit.Test import java.io.File import java.nio.file.Path diff --git a/client/rpc/src/test/kotlin/net/corda/client/rpc/AbstractRPCTest.kt b/client/rpc/src/test/kotlin/net/corda/client/rpc/AbstractRPCTest.kt index 20026ab7c1..c6b5329d8c 100644 --- a/client/rpc/src/test/kotlin/net/corda/client/rpc/AbstractRPCTest.kt +++ b/client/rpc/src/test/kotlin/net/corda/client/rpc/AbstractRPCTest.kt @@ -1,8 +1,8 @@ package net.corda.client.rpc import net.corda.client.rpc.internal.RPCClientConfiguration -import net.corda.core.flatMap -import net.corda.core.map +import net.corda.core.internal.concurrent.flatMap +import net.corda.core.internal.concurrent.map import net.corda.core.messaging.RPCOps import net.corda.node.services.messaging.RPCServerConfiguration import net.corda.nodeapi.User @@ -44,13 +44,13 @@ open class AbstractRPCTest { startInVmRpcClient(rpcUser.username, rpcUser.password, clientConfiguration).map { TestProxy(it, { startInVmArtemisSession(rpcUser.username, rpcUser.password) }) } - }.get() + } RPCTestMode.Netty -> startRpcServer(ops = ops, rpcUser = rpcUser, configuration = serverConfiguration).flatMap { server -> startRpcClient(server.broker.hostAndPort!!, rpcUser.username, rpcUser.password, clientConfiguration).map { TestProxy(it, { startArtemisSession(server.broker.hostAndPort!!, rpcUser.username, rpcUser.password) }) } - }.get() - } + } + }.get() } } diff --git a/client/rpc/src/test/kotlin/net/corda/client/rpc/ClientRPCInfrastructureTests.kt b/client/rpc/src/test/kotlin/net/corda/client/rpc/ClientRPCInfrastructureTests.kt index 0117504c2e..e798438fa7 100644 --- a/client/rpc/src/test/kotlin/net/corda/client/rpc/ClientRPCInfrastructureTests.kt +++ b/client/rpc/src/test/kotlin/net/corda/client/rpc/ClientRPCInfrastructureTests.kt @@ -1,11 +1,11 @@ package net.corda.client.rpc -import com.google.common.util.concurrent.Futures -import com.google.common.util.concurrent.ListenableFuture -import com.google.common.util.concurrent.SettableFuture -import net.corda.core.getOrThrow +import net.corda.core.concurrent.CordaFuture +import net.corda.core.internal.concurrent.doneFuture +import net.corda.core.internal.concurrent.openFuture +import net.corda.core.internal.concurrent.thenMatch import net.corda.core.messaging.RPCOps -import net.corda.core.thenMatch +import net.corda.core.utilities.getOrThrow import net.corda.node.services.messaging.getRpcContext import net.corda.nodeapi.RPCSinceVersion import net.corda.testing.RPCDriverExposedDSLInterface @@ -27,7 +27,9 @@ import kotlin.test.assertTrue class ClientRPCInfrastructureTests : AbstractRPCTest() { // TODO: Test that timeouts work - private fun RPCDriverExposedDSLInterface.testProxy() = testProxy(TestOpsImpl()).ops + private fun RPCDriverExposedDSLInterface.testProxy(): TestOps { + return testProxy(TestOpsImpl()).ops + } interface TestOps : RPCOps { @Throws(IllegalArgumentException::class) @@ -41,9 +43,9 @@ class ClientRPCInfrastructureTests : AbstractRPCTest() { fun makeComplicatedObservable(): Observable>> - fun makeListenableFuture(): ListenableFuture + fun makeListenableFuture(): CordaFuture - fun makeComplicatedListenableFuture(): ListenableFuture>> + fun makeComplicatedListenableFuture(): CordaFuture>> @RPCSinceVersion(2) fun addedLater() @@ -52,7 +54,7 @@ class ClientRPCInfrastructureTests : AbstractRPCTest() { } private lateinit var complicatedObservable: Observable>> - private lateinit var complicatedListenableFuturee: ListenableFuture>> + private lateinit var complicatedListenableFuturee: CordaFuture>> inner class TestOpsImpl : TestOps { override val protocolVersion = 1 @@ -60,9 +62,9 @@ class ClientRPCInfrastructureTests : AbstractRPCTest() { override fun void() {} override fun someCalculation(str: String, num: Int) = "$str $num" override fun makeObservable(): Observable = Observable.just(1, 2, 3, 4) - override fun makeListenableFuture(): ListenableFuture = Futures.immediateFuture(1) + override fun makeListenableFuture() = doneFuture(1) override fun makeComplicatedObservable() = complicatedObservable - override fun makeComplicatedListenableFuture(): ListenableFuture>> = complicatedListenableFuturee + override fun makeComplicatedListenableFuture() = complicatedListenableFuturee override fun addedLater(): Unit = throw IllegalStateException() override fun captureUser(): String = getRpcContext().currentUser.username } @@ -150,10 +152,10 @@ class ClientRPCInfrastructureTests : AbstractRPCTest() { fun `complex ListenableFuture`() { rpcDriver { val proxy = testProxy() - val serverQuote = SettableFuture.create>>() + val serverQuote = openFuture>>() complicatedListenableFuturee = serverQuote - val twainQuote = "Mark Twain" to Futures.immediateFuture("I have never let my schooling interfere with my education.") + val twainQuote = "Mark Twain" to doneFuture("I have never let my schooling interfere with my education.") val clientQuotes = LinkedBlockingQueue() val clientFuture = proxy.makeComplicatedListenableFuture() diff --git a/client/rpc/src/test/kotlin/net/corda/client/rpc/RPCConcurrencyTests.kt b/client/rpc/src/test/kotlin/net/corda/client/rpc/RPCConcurrencyTests.kt index fb283773d1..ea78ef374e 100644 --- a/client/rpc/src/test/kotlin/net/corda/client/rpc/RPCConcurrencyTests.kt +++ b/client/rpc/src/test/kotlin/net/corda/client/rpc/RPCConcurrencyTests.kt @@ -1,10 +1,10 @@ package net.corda.client.rpc import net.corda.client.rpc.internal.RPCClientConfiguration -import net.corda.core.future import net.corda.core.messaging.RPCOps -import net.corda.core.millis +import net.corda.core.utilities.millis import net.corda.core.crypto.random63BitValue +import net.corda.core.internal.concurrent.fork import net.corda.core.serialization.CordaSerializable import net.corda.node.services.messaging.RPCServerConfiguration import net.corda.testing.RPCDriverExposedDSLInterface @@ -17,6 +17,7 @@ import rx.subjects.UnicastSubject import java.util.* import java.util.concurrent.ConcurrentHashMap import java.util.concurrent.CountDownLatch +import java.util.concurrent.ForkJoinPool @RunWith(Parameterized::class) class RPCConcurrencyTests : AbstractRPCTest() { @@ -68,7 +69,7 @@ class RPCConcurrencyTests : AbstractRPCTest() { Observable.empty>() } else { val publish = UnicastSubject.create>() - future { + ForkJoinPool.commonPool().fork { (1..branchingFactor).toList().parallelStream().forEach { publish.onNext(getParallelObservableTree(depth - 1, branchingFactor)) } @@ -105,7 +106,7 @@ class RPCConcurrencyTests : AbstractRPCTest() { val done = CountDownLatch(numberOfBlockedCalls) // Start a couple of blocking RPC calls (1..numberOfBlockedCalls).forEach { - future { + ForkJoinPool.commonPool().fork { proxy.ops.waitLatch(id) done.countDown() } diff --git a/client/rpc/src/test/kotlin/net/corda/client/rpc/RPCPerformanceTests.kt b/client/rpc/src/test/kotlin/net/corda/client/rpc/RPCPerformanceTests.kt index 362f001232..b9d64a3cab 100644 --- a/client/rpc/src/test/kotlin/net/corda/client/rpc/RPCPerformanceTests.kt +++ b/client/rpc/src/test/kotlin/net/corda/client/rpc/RPCPerformanceTests.kt @@ -3,12 +3,11 @@ package net.corda.client.rpc import com.google.common.base.Stopwatch import net.corda.client.rpc.internal.RPCClientConfiguration import net.corda.core.messaging.RPCOps -import net.corda.core.minutes -import net.corda.core.seconds -import net.corda.core.utilities.div +import net.corda.core.utilities.minutes +import net.corda.core.utilities.seconds +import net.corda.testing.performance.div import net.corda.node.services.messaging.RPCServerConfiguration import net.corda.testing.RPCDriverExposedDSLInterface -import net.corda.testing.driver.ShutdownManager import net.corda.testing.measure import net.corda.testing.performance.startPublishingFixedRateInjector import net.corda.testing.performance.startReporter diff --git a/client/rpc/src/test/kotlin/net/corda/client/rpc/RPCPermissionsTests.kt b/client/rpc/src/test/kotlin/net/corda/client/rpc/RPCPermissionsTests.kt index ebc9cef461..f31469bcb4 100644 --- a/client/rpc/src/test/kotlin/net/corda/client/rpc/RPCPermissionsTests.kt +++ b/client/rpc/src/test/kotlin/net/corda/client/rpc/RPCPermissionsTests.kt @@ -1,8 +1,8 @@ package net.corda.client.rpc import net.corda.core.messaging.RPCOps -import net.corda.node.services.messaging.requirePermission import net.corda.node.services.messaging.getRpcContext +import net.corda.node.services.messaging.requirePermission import net.corda.nodeapi.PermissionException import net.corda.nodeapi.User import net.corda.testing.RPCDriverExposedDSLInterface diff --git a/constants.properties b/constants.properties index 6993620217..4567dcffbb 100644 --- a/constants.properties +++ b/constants.properties @@ -1,4 +1,4 @@ -gradlePluginsVersion=0.13.2 +gradlePluginsVersion=0.13.6 kotlinVersion=1.1.1 guavaVersion=21.0 bouncycastleVersion=1.57 diff --git a/cordform-common/build.gradle b/cordform-common/build.gradle index 340a4b6ec6..82274e0d09 100644 --- a/cordform-common/build.gradle +++ b/cordform-common/build.gradle @@ -17,3 +17,7 @@ dependencies { // Bouncy Castle: for X.500 distinguished name manipulation compile "org.bouncycastle:bcprov-jdk15on:$bouncycastle_version" } + +publish { + name project.name +} \ No newline at end of file diff --git a/cordform-common/src/main/java/net/corda/cordform/CordformNode.java b/cordform-common/src/main/java/net/corda/cordform/CordformNode.java index 80a9a3795a..9175bead2f 100644 --- a/cordform-common/src/main/java/net/corda/cordform/CordformNode.java +++ b/cordform-common/src/main/java/net/corda/cordform/CordformNode.java @@ -4,6 +4,7 @@ import static java.util.Collections.emptyList; import com.typesafe.config.Config; import com.typesafe.config.ConfigFactory; import com.typesafe.config.ConfigValueFactory; +import java.util.Collections; import java.util.List; import java.util.Map; @@ -86,6 +87,6 @@ public class CordformNode implements NodeDefinition { * @param id The (0-based) BFT replica ID. */ public void bftReplicaId(Integer id) { - config = config.withValue("bftReplicaId", ConfigValueFactory.fromAnyRef(id)); + config = config.withValue("bftSMaRt", ConfigValueFactory.fromMap(Collections.singletonMap("replicaId", id))); } } diff --git a/core/build.gradle b/core/build.gradle index e7dc4fa641..edf8fb1037 100644 --- a/core/build.gradle +++ b/core/build.gradle @@ -2,6 +2,7 @@ apply plugin: 'kotlin' apply plugin: 'kotlin-jpa' apply plugin: 'net.corda.plugins.quasar-utils' apply plugin: 'net.corda.plugins.publish-utils' +apply plugin: 'com.jfrog.artifactory' description 'Corda core' @@ -40,22 +41,12 @@ dependencies { // AssertJ: for fluent assertions for testing testCompile "org.assertj:assertj-core:${assertj_version}" - // TODO: Upgrade to junit-quickcheck 0.8, once it is released, - // because it depends on org.javassist:javassist instead - // of javassist:javassist. - testCompile "com.pholser:junit-quickcheck-core:$quickcheck_version" - testCompile "com.pholser:junit-quickcheck-generators:$quickcheck_version" - // Guava: Google utilities library. compile "com.google.guava:guava:$guava_version" // RxJava: observable streams of events. compile "io.reactivex:rxjava:$rxjava_version" - // Kryo: object graph serialization. - compile "com.esotericsoftware:kryo:4.0.0" - compile "de.javakaffee:kryo-serializers:0.41" - // Apache JEXL: An embeddable expression evaluation library. // This may be temporary until we experiment with other ways to do on-the-fly contract specialisation via an API. compile "org.apache.commons:commons-jexl3:3.0" @@ -98,5 +89,5 @@ jar { } publish { - name = jar.baseName + name jar.baseName } diff --git a/core/src/main/java/net/corda/core/internal/package-info.java b/core/src/main/java/net/corda/core/internal/package-info.java new file mode 100644 index 0000000000..aa06d1bace --- /dev/null +++ b/core/src/main/java/net/corda/core/internal/package-info.java @@ -0,0 +1,4 @@ +/** + * WARNING: This is an internal package and not part of the public API. Do not use anything found here or in any sub-package. + */ +package net.corda.core.internal; \ No newline at end of file diff --git a/core/src/main/kotlin/net/corda/core/CordaException.kt b/core/src/main/kotlin/net/corda/core/CordaException.kt index 49ed6b6975..4e5353fa77 100644 --- a/core/src/main/kotlin/net/corda/core/CordaException.kt +++ b/core/src/main/kotlin/net/corda/core/CordaException.kt @@ -58,9 +58,9 @@ open class CordaException internal constructor(override var originalExceptionCla } } -open class CordaRuntimeException internal constructor(override var originalExceptionClassName: String?, - private var _message: String? = null, - private var _cause: Throwable? = null) : RuntimeException(null, null, true, true), CordaThrowable { +open class CordaRuntimeException(override var originalExceptionClassName: String?, + private var _message: String? = null, + private var _cause: Throwable? = null) : RuntimeException(null, null, true, true), CordaThrowable { constructor(message: String?, cause: Throwable?) : this(null, message, cause) override val message: String? diff --git a/core/src/main/kotlin/net/corda/core/Streams.kt b/core/src/main/kotlin/net/corda/core/Streams.kt deleted file mode 100644 index 2f33522c35..0000000000 --- a/core/src/main/kotlin/net/corda/core/Streams.kt +++ /dev/null @@ -1,30 +0,0 @@ -package net.corda.core - -import java.util.* -import java.util.Spliterator.* -import java.util.stream.IntStream -import java.util.stream.Stream -import java.util.stream.StreamSupport -import kotlin.streams.asSequence - -private fun IntProgression.spliteratorOfInt(): Spliterator.OfInt { - val kotlinIterator = iterator() - val javaIterator = object : PrimitiveIterator.OfInt { - override fun nextInt() = kotlinIterator.nextInt() - override fun hasNext() = kotlinIterator.hasNext() - override fun remove() = throw UnsupportedOperationException("remove") - } - val spliterator = Spliterators.spliterator( - javaIterator, - (1 + (last - first) / step).toLong(), - SUBSIZED or IMMUTABLE or NONNULL or SIZED or ORDERED or SORTED or DISTINCT - ) - return if (step > 0) spliterator else object : Spliterator.OfInt by spliterator { - override fun getComparator() = Comparator.reverseOrder() - } -} - -fun IntProgression.stream(): IntStream = StreamSupport.intStream(spliteratorOfInt(), false) - -@Suppress("UNCHECKED_CAST") // When toArray has filled in the array, the component type is no longer T? but T (that may itself be nullable). -inline fun Stream.toTypedArray() = toArray { size -> arrayOfNulls(size) } as Array diff --git a/core/src/main/kotlin/net/corda/core/Utils.kt b/core/src/main/kotlin/net/corda/core/Utils.kt index 24d8247df0..d9702f53aa 100644 --- a/core/src/main/kotlin/net/corda/core/Utils.kt +++ b/core/src/main/kotlin/net/corda/core/Utils.kt @@ -1,102 +1,16 @@ -// TODO Move out the Kotlin specific stuff into a separate file @file:JvmName("Utils") package net.corda.core -import com.google.common.base.Throwables -import com.google.common.io.ByteStreams -import com.google.common.util.concurrent.* -import net.corda.core.crypto.SecureHash -import net.corda.core.crypto.sha256 -import net.corda.core.flows.FlowException -import net.corda.core.serialization.CordaSerializable -import org.slf4j.Logger +import net.corda.core.concurrent.CordaFuture +import net.corda.core.internal.concurrent.openFuture +import net.corda.core.internal.concurrent.thenMatch import rx.Observable import rx.Observer -import rx.subjects.PublishSubject -import rx.subjects.UnicastSubject -import java.io.* -import java.math.BigDecimal -import java.nio.charset.Charset -import java.nio.charset.StandardCharsets.UTF_8 -import java.nio.file.* -import java.nio.file.attribute.FileAttribute -import java.time.Duration -import java.time.temporal.Temporal -import java.util.concurrent.CompletableFuture -import java.util.concurrent.ExecutionException -import java.util.concurrent.Future -import java.util.concurrent.TimeUnit -import java.util.concurrent.locks.ReentrantLock -import java.util.stream.Stream -import java.util.zip.Deflater -import java.util.zip.ZipEntry -import java.util.zip.ZipInputStream -import java.util.zip.ZipOutputStream -import kotlin.concurrent.withLock -import kotlin.reflect.KClass -import kotlin.reflect.KProperty -val Int.days: Duration get() = Duration.ofDays(this.toLong()) -@Suppress("unused") // It's here for completeness -val Int.hours: Duration get() = Duration.ofHours(this.toLong()) -val Int.minutes: Duration get() = Duration.ofMinutes(this.toLong()) -val Int.seconds: Duration get() = Duration.ofSeconds(this.toLong()) -val Int.millis: Duration get() = Duration.ofMillis(this.toLong()) +// TODO Delete this file once the Future stuff is out of here - -// TODO: Review by EOY2016 if we ever found these utilities helpful. -val Int.bd: BigDecimal get() = BigDecimal(this) -val Double.bd: BigDecimal get() = BigDecimal(this) -val String.bd: BigDecimal get() = BigDecimal(this) -val Long.bd: BigDecimal get() = BigDecimal(this) - -fun String.abbreviate(maxWidth: Int): String = if (length <= maxWidth) this else take(maxWidth - 1) + "…" - -/** Like the + operator but throws an exception in case of integer overflow. */ -infix fun Int.checkedAdd(b: Int) = Math.addExact(this, b) - -/** Like the + operator but throws an exception in case of integer overflow. */ -@Suppress("unused") -infix fun Long.checkedAdd(b: Long) = Math.addExact(this, b) - -/** Same as [Future.get] but with a more descriptive name, and doesn't throw [ExecutionException], instead throwing its cause */ -fun Future.getOrThrow(timeout: Duration? = null): T { - return try { - if (timeout == null) get() else get(timeout.toNanos(), TimeUnit.NANOSECONDS) - } catch (e: ExecutionException) { - throw e.cause!! - } -} - -fun future(block: () -> V): Future = CompletableFuture.supplyAsync(block) - -fun , V> F.then(block: (F) -> V) = addListener(Runnable { block(this) }, MoreExecutors.directExecutor()) - -fun Future.match(success: (U) -> V, failure: (Throwable) -> V): V { - return success(try { - getOrThrow() - } catch (t: Throwable) { - return failure(t) - }) -} - -fun ListenableFuture.thenMatch(success: (U) -> V, failure: (Throwable) -> W) = then { it.match(success, failure) } -fun ListenableFuture<*>.andForget(log: Logger) = then { it.match({}, { log.error("Background task failed:", it) }) } -@Suppress("UNCHECKED_CAST") // We need the awkward cast because otherwise F cannot be nullable, even though it's safe. -infix fun ListenableFuture.map(mapper: (F) -> T): ListenableFuture = Futures.transform(this, { (mapper as (F?) -> T)(it) }) -infix fun ListenableFuture.flatMap(mapper: (F) -> ListenableFuture): ListenableFuture = Futures.transformAsync(this) { mapper(it!!) } - -/** Executes the given block and sets the future to either the result, or any exception that was thrown. */ -inline fun SettableFuture.catch(block: () -> T) { - try { - set(block()) - } catch (t: Throwable) { - setException(t) - } -} - -fun ListenableFuture.toObservable(): Observable { +fun CordaFuture.toObservable(): Observable { return Observable.create { subscriber -> thenMatch({ subscriber.onNext(it) @@ -107,303 +21,26 @@ fun ListenableFuture.toObservable(): Observable { } } -/** Allows you to write code like: Paths.get("someDir") / "subdir" / "filename" but using the Paths API to avoid platform separator problems. */ -operator fun Path.div(other: String): Path = resolve(other) -operator fun String.div(other: String): Path = Paths.get(this) / other - -fun Path.createDirectory(vararg attrs: FileAttribute<*>): Path = Files.createDirectory(this, *attrs) -fun Path.createDirectories(vararg attrs: FileAttribute<*>): Path = Files.createDirectories(this, *attrs) -fun Path.exists(vararg options: LinkOption): Boolean = Files.exists(this, *options) -fun Path.copyToDirectory(targetDir: Path, vararg options: CopyOption): Path { - require(targetDir.isDirectory()) { "$targetDir is not a directory" } - val targetFile = targetDir.resolve(fileName) - Files.copy(this, targetFile, *options) - return targetFile -} -fun Path.moveTo(target: Path, vararg options: CopyOption): Path = Files.move(this, target, *options) -fun Path.isRegularFile(vararg options: LinkOption): Boolean = Files.isRegularFile(this, *options) -fun Path.isDirectory(vararg options: LinkOption): Boolean = Files.isDirectory(this, *options) -val Path.size: Long get() = Files.size(this) -inline fun Path.list(block: (Stream) -> R): R = Files.list(this).use(block) -fun Path.deleteIfExists(): Boolean = Files.deleteIfExists(this) -fun Path.readAll(): ByteArray = Files.readAllBytes(this) -inline fun Path.read(vararg options: OpenOption, block: (InputStream) -> R): R = Files.newInputStream(this, *options).use(block) -inline fun Path.write(createDirs: Boolean = false, vararg options: OpenOption = emptyArray(), block: (OutputStream) -> Unit) { - if (createDirs) { - normalize().parent?.createDirectories() - } - Files.newOutputStream(this, *options).use(block) -} - -inline fun Path.readLines(charset: Charset = UTF_8, block: (Stream) -> R): R = Files.lines(this, charset).use(block) -fun Path.readAllLines(charset: Charset = UTF_8): List = Files.readAllLines(this, charset) -fun Path.writeLines(lines: Iterable, charset: Charset = UTF_8, vararg options: OpenOption): Path = Files.write(this, lines, charset, *options) - -fun InputStream.copyTo(target: Path, vararg options: CopyOption): Long = Files.copy(this, target, *options) - -// Simple infix function to add back null safety that the JDK lacks: timeA until timeB -infix fun Temporal.until(endExclusive: Temporal): Duration = Duration.between(this, endExclusive) - -/** Returns the index of the given item or throws [IllegalArgumentException] if not found. */ -fun List.indexOfOrThrow(item: T): Int { - val i = indexOf(item) - require(i != -1) - return i -} - /** - * Returns the single element matching the given [predicate], or `null` if element was not found, - * or throws if more than one element was found. - */ -fun Iterable.noneOrSingle(predicate: (T) -> Boolean): T? { - var single: T? = null - for (element in this) { - if (predicate(element)) { - if (single == null) { - single = element - } else throw IllegalArgumentException("Collection contains more than one matching element.") - } - } - return single -} - -/** Returns single element, or `null` if element was not found, or throws if more than one element was found. */ -fun Iterable.noneOrSingle(): T? { - var single: T? = null - for (element in this) { - if (single == null) { - single = element - } else throw IllegalArgumentException("Collection contains more than one matching element.") - } - return single -} - -/** Returns a random element in the list, or null if empty */ -fun List.randomOrNull(): T? { - if (size <= 1) return firstOrNull() - val randomIndex = (Math.random() * size).toInt() - return get(randomIndex) -} - -/** Returns a random element in the list matching the given predicate, or null if none found */ -fun List.randomOrNull(predicate: (T) -> Boolean) = filter(predicate).randomOrNull() - -inline fun elapsedTime(block: () -> Unit): Duration { - val start = System.nanoTime() - block() - val end = System.nanoTime() - return Duration.ofNanos(end - start) -} - -// TODO: Add inline back when a new Kotlin version is released and check if the java.lang.VerifyError -// returns in the IRSSimulationTest. If not, commit the inline back. -fun logElapsedTime(label: String, logger: Logger? = null, body: () -> T): T { - // Use nanoTime as it's monotonic. - val now = System.nanoTime() - try { - return body() - } finally { - val elapsed = Duration.ofNanos(System.nanoTime() - now).toMillis() - if (logger != null) - logger.info("$label took $elapsed msec") - else - println("$label took $elapsed msec") - } -} - -fun Logger.logElapsedTime(label: String, body: () -> T): T = logElapsedTime(label, this, body) - -/** - * A threadbox is a simple utility that makes it harder to forget to take a lock before accessing some shared state. - * Simply define a private class to hold the data that must be grouped under the same lock, and then pass the only - * instance to the ThreadBox constructor. You can now use the [locked] method with a lambda to take the lock in a - * way that ensures it'll be released if there's an exception. - * - * Note that this technique is not infallible: if you capture a reference to the fields in another lambda which then - * gets stored and invoked later, there may still be unsafe multi-threaded access going on, so watch out for that. - * This is just a simple guard rail that makes it harder to slip up. - * - * Example: - * - * private class MutableState { var i = 5 } - * private val state = ThreadBox(MutableState()) - * - * val ii = state.locked { i } - */ -class ThreadBox(val content: T, val lock: ReentrantLock = ReentrantLock()) { - inline fun locked(body: T.() -> R): R = lock.withLock { body(content) } - inline fun alreadyLocked(body: T.() -> R): R { - check(lock.isHeldByCurrentThread, { "Expected $lock to already be locked." }) - return body(content) - } - - fun checkNotLocked() = check(!lock.isHeldByCurrentThread) -} - -/** - * This represents a transient exception or condition that might no longer be thrown if the operation is re-run or called - * again. - * - * We avoid the use of the word transient here to hopefully reduce confusion with the term in relation to (Java) serialization. - */ -@CordaSerializable -abstract class RetryableException(message: String) : FlowException(message) - -/** - * A simple wrapper that enables the use of Kotlin's "val x by TransientProperty { ... }" syntax. Such a property - * will not be serialized to disk, and if it's missing (or the first time it's accessed), the initializer will be - * used to set it up. Note that the initializer will be called with the TransientProperty object locked. - */ -class TransientProperty(private val initializer: () -> T) { - @Transient private var v: T? = null - - @Synchronized - operator fun getValue(thisRef: Any?, property: KProperty<*>) = v ?: initializer().also { v = it } -} - -/** - * Given a path to a zip file, extracts it to the given directory. - */ -fun extractZipFile(zipFile: Path, toDirectory: Path) = extractZipFile(Files.newInputStream(zipFile), toDirectory) - -/** - * Given a zip file input stream, extracts it to the given directory. - */ -fun extractZipFile(inputStream: InputStream, toDirectory: Path) { - val normalisedDirectory = toDirectory.normalize().createDirectories() - ZipInputStream(BufferedInputStream(inputStream)).use { - while (true) { - val e = it.nextEntry ?: break - val outPath = (normalisedDirectory / e.name).normalize() - - // Security checks: we should reject a zip that contains tricksy paths that try to escape toDirectory. - check(outPath.startsWith(normalisedDirectory)) { "ZIP contained a path that resolved incorrectly: ${e.name}" } - - if (e.isDirectory) { - outPath.createDirectories() - continue - } - outPath.write { out -> - ByteStreams.copy(it, out) - } - it.closeEntry() - } - } -} - -/** - * Get a valid InputStream from an in-memory zip as required for tests. - * Note that a slightly bigger than numOfExpectedBytes size is expected. - */ -@Throws(IllegalArgumentException::class) -fun sizedInputStreamAndHash(numOfExpectedBytes: Int): InputStreamAndHash { - if (numOfExpectedBytes <= 0) throw IllegalArgumentException("A positive number of numOfExpectedBytes is required.") - val baos = ByteArrayOutputStream() - ZipOutputStream(baos).use({ zos -> - val arraySize = 1024 - val bytes = ByteArray(arraySize) - val n = (numOfExpectedBytes - 1) / arraySize + 1 // same as Math.ceil(numOfExpectedBytes/arraySize). - zos.setLevel(Deflater.NO_COMPRESSION) - zos.putNextEntry(ZipEntry("z")) - for (i in 0 until n) { - zos.write(bytes, 0, arraySize) - } - zos.closeEntry() - }) - return getInputStreamAndHashFromOutputStream(baos) -} - -/** Convert a [ByteArrayOutputStream] to [InputStreamAndHash]. */ -fun getInputStreamAndHashFromOutputStream(baos: ByteArrayOutputStream): InputStreamAndHash { - // TODO: Consider converting OutputStream to InputStream without creating a ByteArray, probably using piped streams. - val bytes = baos.toByteArray() - // TODO: Consider calculating sha256 on the fly using a DigestInputStream. - return InputStreamAndHash(ByteArrayInputStream(bytes), bytes.sha256()) -} - -data class InputStreamAndHash(val inputStream: InputStream, val sha256: SecureHash.SHA256) - -// TODO: Generic csv printing utility for clases. - -val Throwable.rootCause: Throwable get() = Throwables.getRootCause(this) - -/** - * Returns an Observable that buffers events until subscribed. - * @see UnicastSubject - */ -fun Observable.bufferUntilSubscribed(): Observable { - val subject = UnicastSubject.create() - val subscription = subscribe(subject) - return subject.doOnUnsubscribe { subscription.unsubscribe() } -} - -/** - * Copy an [Observer] to multiple other [Observer]s. - */ -fun Observer.tee(vararg teeTo: Observer): Observer { - val subject = PublishSubject.create() - subject.subscribe(this) - teeTo.forEach { subject.subscribe(it) } - return subject -} - -/** - * Returns a [ListenableFuture] bound to the *first* item emitted by this Observable. The future will complete with a + * Returns a [CordaFuture] bound to the *first* item emitted by this Observable. The future will complete with a * NoSuchElementException if no items are emitted or any other error thrown by the Observable. If it's cancelled then * it will unsubscribe from the observable. */ -fun Observable.toFuture(): ListenableFuture = ObservableToFuture(this) +fun Observable.toFuture(): CordaFuture = openFuture().also { + val subscription = first().subscribe(object : Observer { + override fun onNext(value: T) { + it.set(value) + } -private class ObservableToFuture(observable: Observable) : AbstractFuture(), Observer { - private val subscription = observable.first().subscribe(this) - override fun onNext(value: T) { - set(value) - } + override fun onError(e: Throwable) { + it.setException(e) + } - override fun onError(e: Throwable) { - setException(e) - } - - override fun cancel(mayInterruptIfRunning: Boolean): Boolean { - subscription.unsubscribe() - return super.cancel(mayInterruptIfRunning) - } - - override fun onCompleted() {} -} - -/** Return the sum of an Iterable of [BigDecimal]s. */ -fun Iterable.sum(): BigDecimal = fold(BigDecimal.ZERO) { a, b -> a + b } - -fun codePointsString(vararg codePoints: Int): String { - val builder = StringBuilder() - codePoints.forEach { builder.append(Character.toChars(it)) } - return builder.toString() -} - -fun Class.checkNotUnorderedHashMap() { - if (HashMap::class.java.isAssignableFrom(this) && !LinkedHashMap::class.java.isAssignableFrom(this)) { - throw NotSerializableException("Map type $this is unstable under iteration. Suggested fix: use LinkedHashMap instead.") - } -} - -fun Class<*>.requireExternal(msg: String = "Internal class") - = require(!name.startsWith("net.corda.node.") && !name.contains(".internal.")) { "$msg: $name" } - -interface DeclaredField { - companion object { - inline fun Any?.declaredField(clazz: KClass<*>, name: String): DeclaredField = declaredField(clazz.java, name) - inline fun Any.declaredField(name: String): DeclaredField = declaredField(javaClass, name) - inline fun Any?.declaredField(clazz: Class<*>, name: String): DeclaredField { - val javaField = clazz.getDeclaredField(name).apply { isAccessible = true } - val receiver = this - return object : DeclaredField { - override var value - get() = javaField.get(receiver) as T - set(value) = javaField.set(receiver, value) - } + override fun onCompleted() {} + }) + it.then { + if (it.isCancelled) { + subscription.unsubscribe() } } - - var value: T } diff --git a/core/src/main/kotlin/net/corda/core/concurrent/ConcurrencyUtils.kt b/core/src/main/kotlin/net/corda/core/concurrent/ConcurrencyUtils.kt index 8ab4d6f4e1..fd0947e1ca 100644 --- a/core/src/main/kotlin/net/corda/core/concurrent/ConcurrencyUtils.kt +++ b/core/src/main/kotlin/net/corda/core/concurrent/ConcurrencyUtils.kt @@ -1,34 +1,44 @@ package net.corda.core.concurrent -import com.google.common.annotations.VisibleForTesting -import com.google.common.util.concurrent.ListenableFuture -import com.google.common.util.concurrent.SettableFuture -import net.corda.core.catch -import net.corda.core.match -import net.corda.core.then +import net.corda.core.internal.concurrent.openFuture +import net.corda.core.utilities.getOrThrow +import net.corda.core.internal.VisibleForTesting import org.slf4j.Logger import org.slf4j.LoggerFactory +import java.util.concurrent.Future import java.util.concurrent.atomic.AtomicBoolean +/** Invoke [getOrThrow] and pass the value/throwable to success/failure respectively. */ +fun Future.match(success: (V) -> W, failure: (Throwable) -> W): W { + val value = try { + getOrThrow() + } catch (t: Throwable) { + return failure(t) + } + return success(value) +} + /** * As soon as a given future becomes done, the handler is invoked with that future as its argument. * The result of the handler is copied into the result future, and the handler isn't invoked again. * If a given future errors after the result future is done, the error is automatically logged. */ -fun firstOf(vararg futures: ListenableFuture, handler: (ListenableFuture) -> T) = firstOf(futures, defaultLog, handler) +fun firstOf(vararg futures: CordaFuture, handler: (CordaFuture) -> W) = firstOf(futures, defaultLog, handler) private val defaultLog = LoggerFactory.getLogger("net.corda.core.concurrent") @VisibleForTesting internal val shortCircuitedTaskFailedMessage = "Short-circuited task failed:" -internal fun firstOf(futures: Array>, log: Logger, handler: (ListenableFuture) -> T): ListenableFuture { - val resultFuture = SettableFuture.create() +internal fun firstOf(futures: Array>, log: Logger, handler: (CordaFuture) -> W): CordaFuture { + val resultFuture = openFuture() val winnerChosen = AtomicBoolean() futures.forEach { it.then { if (winnerChosen.compareAndSet(false, true)) { - resultFuture.catch { handler(it) } - } else if (!it.isCancelled) { + resultFuture.capture { handler(it) } + } else if (it.isCancelled) { + // Do nothing. + } else { it.match({}, { log.error(shortCircuitedTaskFailedMessage, it) }) } } diff --git a/core/src/main/kotlin/net/corda/core/concurrent/CordaFuture.kt b/core/src/main/kotlin/net/corda/core/concurrent/CordaFuture.kt new file mode 100644 index 0000000000..3ecab2e395 --- /dev/null +++ b/core/src/main/kotlin/net/corda/core/concurrent/CordaFuture.kt @@ -0,0 +1,22 @@ +package net.corda.core.concurrent + +import java.util.concurrent.CompletableFuture +import java.util.concurrent.Future + +/** + * Same as [Future] with additional methods to provide some of the features of [java.util.concurrent.CompletableFuture] while minimising the API surface area. + * In Kotlin, to avoid compile errors, whenever CordaFuture is used in a parameter or extension method receiver type, its type parameter should be specified with out variance. + */ +interface CordaFuture : Future { + /** + * Run the given callback when this future is done, on the completion thread. + * If the completion thread is problematic for you e.g. deadlock, you can submit to an executor manually. + * If callback fails, its throwable is logged. + */ + fun then(callback: (CordaFuture) -> W): Unit + + /** + * @return a new [CompletableFuture] with the same outcome as this Future. + */ + fun toCompletableFuture(): CompletableFuture +} diff --git a/core/src/main/kotlin/net/corda/core/contracts/Amount.kt b/core/src/main/kotlin/net/corda/core/contracts/Amount.kt index 0569c6751c..e08ddce671 100644 --- a/core/src/main/kotlin/net/corda/core/contracts/Amount.kt +++ b/core/src/main/kotlin/net/corda/core/contracts/Amount.kt @@ -1,5 +1,8 @@ package net.corda.core.contracts +import net.corda.core.crypto.composite.CompositeKey +import net.corda.core.utilities.exactAdd +import net.corda.core.identity.Party import net.corda.core.serialization.CordaSerializable import java.math.BigDecimal import java.math.RoundingMode @@ -168,7 +171,7 @@ data class Amount(val quantity: Long, val displayTokenSize: BigDecimal, */ operator fun plus(other: Amount): Amount { checkToken(other) - return Amount(Math.addExact(quantity, other.quantity), displayTokenSize, token) + return Amount(quantity exactAdd other.quantity, displayTokenSize, token) } /** @@ -268,9 +271,9 @@ data class SourceAndAmount(val source: P, val amount: Amou * but in various scenarios it may be more consistent to allow positive and negative values. * For example it is common for a bank to code asset flows as gains and losses from its perspective i.e. always the destination. * @param token represents the type of asset token as would be used to construct Amount objects. - * @param source is the [Party], [Account], [CompositeKey], or other identifier of the token source if quantityDelta is positive, + * @param source is the [Party], [CompositeKey], or other identifier of the token source if quantityDelta is positive, * or the token sink if quantityDelta is negative. The type P should support value equality. - * @param destination is the [Party], [Account], [CompositeKey], or other identifier of the token sink if quantityDelta is positive, + * @param destination is the [Party], [CompositeKey], or other identifier of the token sink if quantityDelta is positive, * or the token source if quantityDelta is negative. The type P should support value equality. */ @CordaSerializable @@ -329,7 +332,7 @@ class AmountTransfer(val quantityDelta: Long, "Only AmountTransfer between the same two parties can be aggregated/netted" } return if (other.source == source) { - AmountTransfer(Math.addExact(quantityDelta, other.quantityDelta), token, source, destination) + AmountTransfer(quantityDelta exactAdd other.quantityDelta, token, source, destination) } else { AmountTransfer(Math.subtractExact(quantityDelta, other.quantityDelta), token, source, destination) } @@ -388,10 +391,10 @@ class AmountTransfer(val quantityDelta: Long, * relative asset exchange happens, but with each party exchanging versus a central counterparty, or clearing house. * * @param centralParty The central party to face the exchange against. - * @return Returns two new AmountTransfers each between one of the original parties and the centralParty. + * @return Returns a list of two new AmountTransfers each between one of the original parties and the centralParty. * The net total exchange is the same as in the original input. */ - fun novate(centralParty: P): Pair, AmountTransfer> = Pair(copy(destination = centralParty), copy(source = centralParty)) + fun novate(centralParty: P): List> = listOf(copy(destination = centralParty), copy(source = centralParty)) /** * Applies this AmountTransfer to a list of [SourceAndAmount] objects representing balances. diff --git a/core/src/main/kotlin/net/corda/core/contracts/ContractsDSL.kt b/core/src/main/kotlin/net/corda/core/contracts/ContractsDSL.kt index 4446c3c04f..fd38e9aa99 100644 --- a/core/src/main/kotlin/net/corda/core/contracts/ContractsDSL.kt +++ b/core/src/main/kotlin/net/corda/core/contracts/ContractsDSL.kt @@ -2,6 +2,7 @@ package net.corda.core.contracts +import net.corda.core.identity.AbstractParty import net.corda.core.identity.Party import java.math.BigDecimal import java.security.PublicKey @@ -54,13 +55,6 @@ object Requirements { infix inline fun String.using(expr: Boolean) { if (!expr) throw IllegalArgumentException("Failed requirement: $this") } - // Avoid overloading Kotlin keywords - @Deprecated("This function is deprecated, use 'using' instead", - ReplaceWith("using (expr)", "net.corda.core.contracts.Requirements.using")) - @Suppress("NOTHING_TO_INLINE") // Inlining this takes it out of our committed ABI. - infix inline fun String.by(expr: Boolean) { - using(expr) - } } inline fun requireThat(body: Requirements.() -> R) = Requirements.body() @@ -71,7 +65,7 @@ inline fun requireThat(body: Requirements.() -> R) = Requirements.body() /** Filters the command list by type, party and public key all at once. */ inline fun Collection>.select(signer: PublicKey? = null, - party: Party? = null) = + party: AbstractParty? = null) = filter { it.value is T }. filter { if (signer == null) true else signer in it.signers }. filter { if (party == null) true else party in it.signingParties }. diff --git a/core/src/main/kotlin/net/corda/core/contracts/Structures.kt b/core/src/main/kotlin/net/corda/core/contracts/Structures.kt index 0740a4a7e4..f820cdd674 100644 --- a/core/src/main/kotlin/net/corda/core/contracts/Structures.kt +++ b/core/src/main/kotlin/net/corda/core/contracts/Structures.kt @@ -1,19 +1,24 @@ +@file:JvmName("Structures") + package net.corda.core.contracts -import net.corda.core.contracts.clauses.Clause import net.corda.core.crypto.SecureHash +import net.corda.core.crypto.secureRandomBytes import net.corda.core.flows.FlowLogicRef import net.corda.core.flows.FlowLogicRefFactory import net.corda.core.identity.AbstractParty import net.corda.core.identity.Party -import net.corda.core.serialization.* +import net.corda.core.serialization.CordaSerializable +import net.corda.core.serialization.MissingAttachmentsException +import net.corda.core.serialization.SerializeAsTokenContext +import net.corda.core.serialization.serialize +import net.corda.core.transactions.LedgerTransaction import net.corda.core.utilities.OpaqueBytes import java.io.FileNotFoundException import java.io.IOException import java.io.InputStream import java.io.OutputStream import java.security.PublicKey -import java.time.Duration import java.time.Instant import java.util.jar.JarInputStream @@ -76,7 +81,7 @@ interface ContractState { * A _participant_ is any party that is able to consume this state in a valid transaction. * * The list of participants is required for certain types of transactions. For example, when changing the notary - * for this state ([TransactionType.NotaryChange]), every participant has to be involved and approve the transaction + * for this state, every participant has to be involved and approve the transaction * so that they receive the updated state, and don't end up in a situation where they can no longer use a state * they possess, since someone consumed that state during the notary change process. * @@ -141,6 +146,12 @@ data class Issued(val issuer: PartyAndReference, val product: P) { fun Amount>.withoutIssuer(): Amount = Amount(quantity, token.product) // DOCSTART 3 + +/** + * Return structure for [OwnableState.withNewOwner] + */ +data class CommandAndState(val command: CommandData, val ownableState: OwnableState) + /** * A contract state that can have a single owner. */ @@ -149,7 +160,7 @@ interface OwnableState : ContractState { val owner: AbstractParty /** Copies the underlying data structure, replacing the owner field with this new value and leaving the rest alone */ - fun withNewOwner(newOwner: AbstractParty): Pair + fun withNewOwner(newOwner: AbstractParty): CommandAndState } // DOCEND 3 @@ -199,26 +210,6 @@ interface LinearState : ContractState { * True if this should be tracked by our vault(s). */ fun isRelevant(ourKeys: Set): Boolean - - /** - * Standard clause to verify the LinearState safety properties. - */ - @CordaSerializable - class ClauseVerifier : Clause() { - override fun verify(tx: TransactionForContract, - inputs: List, - outputs: List, - commands: List>, - groupingKey: Unit?): Set { - val inputIds = inputs.map { it.linearId }.distinct() - val outputIds = outputs.map { it.linearId }.distinct() - requireThat { - "LinearStates are not merged" using (inputIds.count() == inputs.count()) - "LinearStates are not split" using (outputIds.count() == outputs.count()) - } - return emptySet() - } - } } // DOCEND 2 @@ -281,14 +272,13 @@ abstract class TypeOnlyCommandData : CommandData { /** Command data/content plus pubkey pair: the signature is stored at the end of the serialized bytes */ @CordaSerializable -// DOCSTART 9 -data class Command(val value: CommandData, val signers: List) { -// DOCEND 9 +data class Command(val value: T, val signers: List) { + // TODO Introduce NonEmptyList? init { require(signers.isNotEmpty()) } - constructor(data: CommandData, key: PublicKey) : this(data, listOf(key)) + constructor(data: T, key: PublicKey) : this(data, listOf(key)) private fun commandDataToString() = value.toString().let { if (it.contains("@")) it.replace('$', '.').split("@")[0] else it } override fun toString() = "${commandDataToString()} with pubkeys ${signers.joinToString()}" @@ -324,63 +314,6 @@ data class AuthenticatedObject( ) // DOCEND 6 -/** - * A time-window is required for validation/notarization purposes. - * If present in a transaction, contains a time that was verified by the uniqueness service. The true time must be - * between (fromTime, untilTime). - * Usually, a time-window is required to have both sides set (fromTime, untilTime). - * However, some apps may require that a time-window has a start [Instant] (fromTime), but no end [Instant] (untilTime) and vice versa. - * TODO: Consider refactoring using TimeWindow abstraction like TimeWindow.From, TimeWindow.Until, TimeWindow.Between. - */ -@CordaSerializable -class TimeWindow private constructor( - /** The time at which this transaction is said to have occurred is after this moment. */ - val fromTime: Instant?, - /** The time at which this transaction is said to have occurred is before this moment. */ - val untilTime: Instant? -) { - companion object { - /** Use when the left-side [fromTime] of a [TimeWindow] is only required and we don't need an end instant (untilTime). */ - @JvmStatic - fun fromOnly(fromTime: Instant) = TimeWindow(fromTime, null) - - /** Use when the right-side [untilTime] of a [TimeWindow] is only required and we don't need a start instant (fromTime). */ - @JvmStatic - fun untilOnly(untilTime: Instant) = TimeWindow(null, untilTime) - - /** Use when both sides of a [TimeWindow] must be set ([fromTime], [untilTime]). */ - @JvmStatic - fun between(fromTime: Instant, untilTime: Instant): TimeWindow { - require(fromTime < untilTime) { "fromTime should be earlier than untilTime" } - return TimeWindow(fromTime, untilTime) - } - - /** Use when we have a start time and a period of validity. */ - @JvmStatic - fun fromStartAndDuration(fromTime: Instant, duration: Duration): TimeWindow = between(fromTime, fromTime + duration) - - /** - * When we need to create a [TimeWindow] based on a specific time [Instant] and some tolerance in both sides of this instant. - * The result will be the following time-window: ([time] - [tolerance], [time] + [tolerance]). - */ - @JvmStatic - fun withTolerance(time: Instant, tolerance: Duration) = between(time - tolerance, time + tolerance) - } - - /** The midpoint is calculated as fromTime + (untilTime - fromTime)/2. Note that it can only be computed if both sides are set. */ - val midpoint: Instant get() = fromTime!! + Duration.between(fromTime, untilTime!!).dividedBy(2) - - override fun equals(other: Any?): Boolean { - if (this === other) return true - if (other !is TimeWindow) return false - return (fromTime == other.fromTime && untilTime == other.untilTime) - } - - override fun hashCode() = 31 * (fromTime?.hashCode() ?: 0) + (untilTime?.hashCode() ?: 0) - - override fun toString() = "TimeWindow(fromTime=$fromTime, untilTime=$untilTime)" -} - // DOCSTART 5 /** * Implemented by a program that implements business logic on the shared ledger. All participants run this code for @@ -399,7 +332,7 @@ interface Contract { * existing contract code. */ @Throws(IllegalArgumentException::class) - fun verify(tx: TransactionForContract) + fun verify(tx: LedgerTransaction) /** * Unparsed reference to the natural language contract that this code is supposed to express (usually a hash of @@ -489,3 +422,23 @@ fun JarInputStream.extractFile(path: String, outputTo: OutputStream) { } throw FileNotFoundException(path) } + +/** + * A privacy salt is required to compute nonces per transaction component in order to ensure that an adversary cannot + * use brute force techniques and reveal the content of a Merkle-leaf hashed value. + * Because this salt serves the role of the seed to compute nonces, its size and entropy should be equal to the + * underlying hash function used for Merkle tree generation, currently [SHA256], which has an output of 32 bytes. + * There are two constructors, one that generates a new 32-bytes random salt, and another that takes a [ByteArray] input. + * The latter is required in cases where the salt value needs to be pre-generated (agreed between transacting parties), + * but it is highlighted that one should always ensure it has sufficient entropy. + */ +@CordaSerializable +class PrivacySalt(bytes: ByteArray) : OpaqueBytes(bytes) { + /** Constructs a salt with a randomly-generated 32 byte value. */ + constructor() : this(secureRandomBytes(32)) + + init { + require(bytes.size == 32) { "Privacy salt should be 32 bytes." } + require(!bytes.all { it == 0.toByte() }) { "Privacy salt should not be all zeros." } + } +} diff --git a/core/src/main/kotlin/net/corda/core/contracts/TimeWindow.kt b/core/src/main/kotlin/net/corda/core/contracts/TimeWindow.kt new file mode 100644 index 0000000000..e26ea50cd0 --- /dev/null +++ b/core/src/main/kotlin/net/corda/core/contracts/TimeWindow.kt @@ -0,0 +1,92 @@ +package net.corda.core.contracts + +import net.corda.core.internal.div +import net.corda.core.internal.until +import net.corda.core.serialization.CordaSerializable +import net.corda.core.transactions.WireTransaction +import java.time.Duration +import java.time.Instant + +/** + * An interval on the time-line; not a single instantaneous point. + * + * There is no such thing as _exact_ time in networked systems due to the underlying physics involved and other issues + * such as network latency. The best that can be approximated is "fuzzy time" or an instant of time which has margin of + * tolerance around it. This is what [TimeWindow] represents. Time windows can be open-ended (i.e. specify only one of + * [fromTime] and [untilTime]) or they can be fully bounded. + * + * [WireTransaction] has an optional time-window property, which if specified, restricts the validity of the transaction + * to that time-interval as the Consensus Service will not sign it if it's received outside of this window. + */ +@CordaSerializable +abstract class TimeWindow { + companion object { + /** Creates a [TimeWindow] with null [untilTime], i.e. the time interval `[fromTime, ∞)`. [midpoint] will return null. */ + @JvmStatic + fun fromOnly(fromTime: Instant): TimeWindow = From(fromTime) + + /** Creates a [TimeWindow] with null [fromTime], i.e. the time interval `(∞, untilTime)`. [midpoint] will return null. */ + @JvmStatic + fun untilOnly(untilTime: Instant): TimeWindow = Until(untilTime) + + /** + * Creates a [TimeWindow] with the time interval `[fromTime, untilTime)`. [midpoint] will return + * `fromTime + (untilTime - fromTime) / 2`. + * @throws IllegalArgumentException If [fromTime] ≥ [untilTime] + */ + @JvmStatic + fun between(fromTime: Instant, untilTime: Instant): TimeWindow = Between(fromTime, untilTime) + + /** + * Creates a [TimeWindow] with the time interval `[fromTime, fromTime + duration)`. [midpoint] will return + * `fromTime + duration / 2` + */ + @JvmStatic + fun fromStartAndDuration(fromTime: Instant, duration: Duration): TimeWindow = between(fromTime, fromTime + duration) + + /** + * Creates a [TimeWindow] which is centered around [instant] with the given [tolerance] on both sides, i.e the + * time interval `[instant - tolerance, instant + tolerance)`. [midpoint] will return [instant]. + */ + @JvmStatic + fun withTolerance(instant: Instant, tolerance: Duration) = between(instant - tolerance, instant + tolerance) + } + + /** Returns the inclusive lower-bound of this [TimeWindow]'s interval, with null implying infinity. */ + abstract val fromTime: Instant? + + /** Returns the exclusive upper-bound of this [TimeWindow]'s interval, with null implying infinity. */ + abstract val untilTime: Instant? + + /** + * Returns the midpoint of [fromTime] and [untilTime] if both are non-null, calculated as + * `fromTime + (untilTime - fromTime) / 2`, otherwise returns null. + */ + abstract val midpoint: Instant? + + /** Returns true iff the given [instant] is within the time interval of this [TimeWindow]. */ + abstract operator fun contains(instant: Instant): Boolean + + private data class From(override val fromTime: Instant) : TimeWindow() { + override val untilTime: Instant? get() = null + override val midpoint: Instant? get() = null + override fun contains(instant: Instant): Boolean = instant >= fromTime + override fun toString(): String = "[$fromTime, ∞)" + } + + private data class Until(override val untilTime: Instant) : TimeWindow() { + override val fromTime: Instant? get() = null + override val midpoint: Instant? get() = null + override fun contains(instant: Instant): Boolean = instant < untilTime + override fun toString(): String = "(∞, $untilTime)" + } + + private data class Between(override val fromTime: Instant, override val untilTime: Instant) : TimeWindow() { + init { + require(fromTime < untilTime) { "fromTime must be earlier than untilTime" } + } + override val midpoint: Instant get() = fromTime + (fromTime until untilTime) / 2 + override fun contains(instant: Instant): Boolean = instant >= fromTime && instant < untilTime + override fun toString(): String = "[$fromTime, $untilTime)" + } +} \ No newline at end of file diff --git a/core/src/main/kotlin/net/corda/core/contracts/TransactionTypes.kt b/core/src/main/kotlin/net/corda/core/contracts/TransactionTypes.kt deleted file mode 100644 index db388b5c8c..0000000000 --- a/core/src/main/kotlin/net/corda/core/contracts/TransactionTypes.kt +++ /dev/null @@ -1,176 +0,0 @@ -package net.corda.core.contracts - -import net.corda.core.identity.Party -import net.corda.core.serialization.CordaSerializable -import net.corda.core.transactions.LedgerTransaction -import net.corda.core.transactions.TransactionBuilder -import java.security.PublicKey - -/** Defines transaction build & validation logic for a specific transaction type */ -@CordaSerializable -sealed class TransactionType { - /** - * Check that the transaction is valid based on: - * - General platform rules - * - Rules for the specific transaction type - * - * Note: Presence of _signatures_ is not checked, only the public keys to be signed for. - */ - @Throws(TransactionVerificationException::class) - fun verify(tx: LedgerTransaction) { - require(tx.notary != null || tx.timeWindow == null) { "Transactions with time-windows must be notarised" } - val duplicates = detectDuplicateInputs(tx) - if (duplicates.isNotEmpty()) throw TransactionVerificationException.DuplicateInputStates(tx.id, duplicates) - val missing = verifySigners(tx) - if (missing.isNotEmpty()) throw TransactionVerificationException.SignersMissing(tx.id, missing.toList()) - verifyTransaction(tx) - } - - /** Check that the list of signers includes all the necessary keys */ - fun verifySigners(tx: LedgerTransaction): Set { - val notaryKey = tx.inputs.map { it.state.notary.owningKey }.toSet() - if (notaryKey.size > 1) throw TransactionVerificationException.MoreThanOneNotary(tx.id) - - val requiredKeys = getRequiredSigners(tx) + notaryKey - val missing = requiredKeys - tx.mustSign - - return missing - } - - /** Check that the inputs are unique. */ - private fun detectDuplicateInputs(tx: LedgerTransaction): Set { - var seenInputs = emptySet() - var duplicates = emptySet() - tx.inputs.forEach { state -> - if (seenInputs.contains(state.ref)) { - duplicates += state.ref - } - seenInputs += state.ref - } - return duplicates - } - - /** - * Return the list of public keys that that require signatures for the transaction type. - * Note: the notary key is checked separately for all transactions and need not be included. - */ - abstract fun getRequiredSigners(tx: LedgerTransaction): Set - - /** Implement type specific transaction validation logic */ - abstract fun verifyTransaction(tx: LedgerTransaction) - - /** A general transaction type where transaction validity is determined by custom contract code */ - object General : TransactionType() { - /** Just uses the default [TransactionBuilder] with no special logic */ - class Builder(notary: Party?) : TransactionBuilder(General, notary) - - override fun verifyTransaction(tx: LedgerTransaction) { - verifyNoNotaryChange(tx) - verifyEncumbrances(tx) - verifyContracts(tx) - } - - /** - * Make sure the notary has stayed the same. As we can't tell how inputs and outputs connect, if there - * are any inputs, all outputs must have the same notary. - * - * TODO: Is that the correct set of restrictions? May need to come back to this, see if we can be more - * flexible on output notaries. - */ - private fun verifyNoNotaryChange(tx: LedgerTransaction) { - if (tx.notary != null && tx.inputs.isNotEmpty()) { - tx.outputs.forEach { - if (it.notary != tx.notary) { - throw TransactionVerificationException.NotaryChangeInWrongTransactionType(tx.id, tx.notary, it.notary) - } - } - } - } - - private fun verifyEncumbrances(tx: LedgerTransaction) { - // Validate that all encumbrances exist within the set of input states. - val encumberedInputs = tx.inputs.filter { it.state.encumbrance != null } - encumberedInputs.forEach { (state, ref) -> - val encumbranceStateExists = tx.inputs.any { - it.ref.txhash == ref.txhash && it.ref.index == state.encumbrance - } - if (!encumbranceStateExists) { - throw TransactionVerificationException.TransactionMissingEncumbranceException( - tx.id, - state.encumbrance!!, - TransactionVerificationException.Direction.INPUT - ) - } - } - - // Check that, in the outputs, an encumbered state does not refer to itself as the encumbrance, - // and that the number of outputs can contain the encumbrance. - for ((i, output) in tx.outputs.withIndex()) { - val encumbranceIndex = output.encumbrance ?: continue - if (encumbranceIndex == i || encumbranceIndex >= tx.outputs.size) { - throw TransactionVerificationException.TransactionMissingEncumbranceException( - tx.id, - encumbranceIndex, - TransactionVerificationException.Direction.OUTPUT) - } - } - } - - /** - * Check the transaction is contract-valid by running the verify() for each input and output state contract. - * If any contract fails to verify, the whole transaction is considered to be invalid. - */ - private fun verifyContracts(tx: LedgerTransaction) { - val ctx = tx.toTransactionForContract() - // TODO: This will all be replaced in future once the sandbox and contract constraints work is done. - val contracts = (ctx.inputs.map { it.contract } + ctx.outputs.map { it.contract }).toSet() - for (contract in contracts) { - try { - contract.verify(ctx) - } catch(e: Throwable) { - throw TransactionVerificationException.ContractRejection(tx.id, contract, e) - } - } - } - - override fun getRequiredSigners(tx: LedgerTransaction) = tx.commands.flatMap { it.signers }.toSet() - } - - /** - * A special transaction type for reassigning a notary for a state. Validation does not involve running - * any contract code, it just checks that the states are unmodified apart from the notary field. - */ - object NotaryChange : TransactionType() { - /** - * A transaction builder that automatically sets the transaction type to [NotaryChange] - * and adds the list of participants to the signers set for every input state. - */ - class Builder(notary: Party) : TransactionBuilder(NotaryChange, notary) { - override fun addInputState(stateAndRef: StateAndRef<*>): TransactionBuilder { - signers.addAll(stateAndRef.state.data.participants.map { it.owningKey }) - super.addInputState(stateAndRef) - return this - } - } - - /** - * Check that the difference between inputs and outputs is only the notary field, and that all required signing - * public keys are present. - * - * @throws TransactionVerificationException.InvalidNotaryChange if the validity check fails. - */ - override fun verifyTransaction(tx: LedgerTransaction) { - try { - for ((input, output) in tx.inputs.zip(tx.outputs)) { - check(input.state.data == output.data) - check(input.state.notary != output.notary) - } - check(tx.commands.isEmpty()) - } catch (e: IllegalStateException) { - throw TransactionVerificationException.InvalidNotaryChange(tx.id) - } - } - - override fun getRequiredSigners(tx: LedgerTransaction) = tx.inputs.flatMap { it.state.data.participants }.map { it.owningKey }.toSet() - } -} diff --git a/core/src/main/kotlin/net/corda/core/contracts/TransactionVerification.kt b/core/src/main/kotlin/net/corda/core/contracts/TransactionVerification.kt deleted file mode 100644 index 5417169459..0000000000 --- a/core/src/main/kotlin/net/corda/core/contracts/TransactionVerification.kt +++ /dev/null @@ -1,130 +0,0 @@ -package net.corda.core.contracts - -import net.corda.core.crypto.SecureHash -import net.corda.core.flows.FlowException -import net.corda.core.identity.Party -import net.corda.core.serialization.CordaSerializable -import java.security.PublicKey -import java.util.* - -// TODO: Consider moving this out of the core module and providing a different way for unit tests to test contracts. - -/** - * A transaction to be passed as input to a contract verification function. Defines helper methods to - * simplify verification logic in contracts. - */ -// DOCSTART 1 -data class TransactionForContract(val inputs: List, - val outputs: List, - val attachments: List, - val commands: List>, - val origHash: SecureHash, - val inputNotary: Party? = null, - val timeWindow: TimeWindow? = null) { -// DOCEND 1 - override fun hashCode() = origHash.hashCode() - override fun equals(other: Any?) = other is TransactionForContract && other.origHash == origHash - - /** - * Given a type and a function that returns a grouping key, associates inputs and outputs together so that they - * can be processed as one. The grouping key is any arbitrary object that can act as a map key (so must implement - * equals and hashCode). - * - * The purpose of this function is to simplify the writing of verification logic for transactions that may contain - * similar but unrelated state evolutions which need to be checked independently. Consider a transaction that - * simultaneously moves both dollars and euros (e.g. is an atomic FX trade). There may be multiple dollar inputs and - * multiple dollar outputs, depending on things like how fragmented the owner's vault is and whether various privacy - * techniques are in use. The quantity of dollars on the output side must sum to the same as on the input side, to - * ensure no money is being lost track of. This summation and checking must be repeated independently for each - * currency. To solve this, you would use groupStates with a type of Cash.State and a selector that returns the - * currency field: the resulting list can then be iterated over to perform the per-currency calculation. - */ - // DOCSTART 2 - fun groupStates(ofType: Class, selector: (T) -> K): List> { - val inputs = inputs.filterIsInstance(ofType) - val outputs = outputs.filterIsInstance(ofType) - - val inGroups: Map> = inputs.groupBy(selector) - val outGroups: Map> = outputs.groupBy(selector) - - @Suppress("DEPRECATION") - return groupStatesInternal(inGroups, outGroups) - } - // DOCEND 2 - - /** See the documentation for the reflection-based version of [groupStates] */ - inline fun groupStates(selector: (T) -> K): List> { - val inputs = inputs.filterIsInstance() - val outputs = outputs.filterIsInstance() - - val inGroups: Map> = inputs.groupBy(selector) - val outGroups: Map> = outputs.groupBy(selector) - - @Suppress("DEPRECATION") - return groupStatesInternal(inGroups, outGroups) - } - - @Deprecated("Do not use this directly: exposed as public only due to function inlining") - fun groupStatesInternal(inGroups: Map>, outGroups: Map>): List> { - val result = ArrayList>() - - for ((k, v) in inGroups.entries) - result.add(InOutGroup(v, outGroups[k] ?: emptyList(), k)) - for ((k, v) in outGroups.entries) { - if (inGroups[k] == null) - result.add(InOutGroup(emptyList(), v, k)) - } - - return result - } - - /** Utilities for contract writers to incorporate into their logic. */ - - /** - * A set of related inputs and outputs that are connected by some common attributes. An InOutGroup is calculated - * using [groupStates] and is useful for handling cases where a transaction may contain similar but unrelated - * state evolutions, for example, a transaction that moves cash in two different currencies. The numbers must add - * up on both sides of the transaction, but the values must be summed independently per currency. Grouping can - * be used to simplify this logic. - */ - // DOCSTART 3 - data class InOutGroup(val inputs: List, val outputs: List, val groupingKey: K) - // DOCEND 3 -} - -class TransactionResolutionException(val hash: SecureHash) : FlowException() { - override fun toString(): String = "Transaction resolution failure for $hash" -} - -class AttachmentResolutionException(val hash: SecureHash) : FlowException() { - override fun toString(): String = "Attachment resolution failure for $hash" -} - -sealed class TransactionVerificationException(val txId: SecureHash, cause: Throwable?) : FlowException(cause) { - class ContractRejection(txId: SecureHash, val contract: Contract, cause: Throwable?) : TransactionVerificationException(txId, cause) - class MoreThanOneNotary(txId: SecureHash) : TransactionVerificationException(txId, null) - class SignersMissing(txId: SecureHash, val missing: List) : TransactionVerificationException(txId, null) { - override fun toString(): String = "Signers missing: ${missing.joinToString()}" - } - - class DuplicateInputStates(txId: SecureHash, val duplicates: Set) : TransactionVerificationException(txId, null) { - override fun toString(): String = "Duplicate inputs: ${duplicates.joinToString()}" - } - - class InvalidNotaryChange(txId: SecureHash) : TransactionVerificationException(txId, null) - class NotaryChangeInWrongTransactionType(txId: SecureHash, val txNotary: Party, val outputNotary: Party) : TransactionVerificationException(txId, null) { - override fun toString(): String { - return "Found unexpected notary change in transaction. Tx notary: $txNotary, found: $outputNotary" - } - } - - class TransactionMissingEncumbranceException(txId: SecureHash, val missing: Int, val inOut: Direction) : TransactionVerificationException(txId, null) { - override val message: String get() = "Missing required encumbrance $missing in $inOut" - } - - @CordaSerializable - enum class Direction { - INPUT, - OUTPUT - } -} diff --git a/core/src/main/kotlin/net/corda/core/contracts/TransactionVerificationException.kt b/core/src/main/kotlin/net/corda/core/contracts/TransactionVerificationException.kt new file mode 100644 index 0000000000..6f42461e5c --- /dev/null +++ b/core/src/main/kotlin/net/corda/core/contracts/TransactionVerificationException.kt @@ -0,0 +1,42 @@ +package net.corda.core.contracts + +import net.corda.core.crypto.SecureHash +import net.corda.core.flows.FlowException +import net.corda.core.identity.Party +import net.corda.core.serialization.CordaSerializable +import net.corda.core.utilities.NonEmptySet +import java.security.PublicKey + +class TransactionResolutionException(val hash: SecureHash) : FlowException("Transaction resolution failure for $hash") +class AttachmentResolutionException(val hash: SecureHash) : FlowException("Attachment resolution failure for $hash") + +sealed class TransactionVerificationException(val txId: SecureHash, message: String, cause: Throwable?) + : FlowException("$message, transaction: $txId", cause) { + + class ContractRejection(txId: SecureHash, contract: Contract, cause: Throwable) + : TransactionVerificationException(txId, "Contract verification failed: ${cause.message}, contract: $contract", cause) + + class MoreThanOneNotary(txId: SecureHash) + : TransactionVerificationException(txId, "More than one notary", null) + + class SignersMissing(txId: SecureHash, missing: List) + : TransactionVerificationException(txId, "Signers missing: ${missing.joinToString()}", null) + + class DuplicateInputStates(txId: SecureHash, val duplicates: NonEmptySet) + : TransactionVerificationException(txId, "Duplicate inputs: ${duplicates.joinToString()}", null) + + class InvalidNotaryChange(txId: SecureHash) + : TransactionVerificationException(txId, "Detected a notary change. Outputs must use the same notary as inputs", null) + + class NotaryChangeInWrongTransactionType(txId: SecureHash, txNotary: Party, outputNotary: Party) + : TransactionVerificationException(txId, "Found unexpected notary change in transaction. Tx notary: $txNotary, found: $outputNotary", null) + + class TransactionMissingEncumbranceException(txId: SecureHash, missing: Int, inOut: Direction) + : TransactionVerificationException(txId, "Missing required encumbrance $missing in $inOut", null) + + @CordaSerializable + enum class Direction { + INPUT, + OUTPUT + } +} diff --git a/core/src/main/kotlin/net/corda/core/contracts/UniqueIdentifier.kt b/core/src/main/kotlin/net/corda/core/contracts/UniqueIdentifier.kt index 2be023d047..5f62e064de 100644 --- a/core/src/main/kotlin/net/corda/core/contracts/UniqueIdentifier.kt +++ b/core/src/main/kotlin/net/corda/core/contracts/UniqueIdentifier.kt @@ -1,6 +1,6 @@ package net.corda.core.contracts -import com.google.common.annotations.VisibleForTesting +import net.corda.core.internal.VisibleForTesting import net.corda.core.serialization.CordaSerializable import java.util.* diff --git a/core/src/main/kotlin/net/corda/core/contracts/clauses/AllComposition.kt b/core/src/main/kotlin/net/corda/core/contracts/clauses/AllComposition.kt deleted file mode 100644 index 5be41988e5..0000000000 --- a/core/src/main/kotlin/net/corda/core/contracts/clauses/AllComposition.kt +++ /dev/null @@ -1,10 +0,0 @@ -package net.corda.core.contracts.clauses - -import net.corda.core.contracts.CommandData -import net.corda.core.contracts.ContractState - -/** - * Compose a number of clauses, such that all of the clauses must run for verification to pass. - */ -@Deprecated("Use AllOf") -class AllComposition(firstClause: Clause, vararg remainingClauses: Clause) : AllOf(firstClause, *remainingClauses) \ No newline at end of file diff --git a/core/src/main/kotlin/net/corda/core/contracts/clauses/AllOf.kt b/core/src/main/kotlin/net/corda/core/contracts/clauses/AllOf.kt deleted file mode 100644 index 6fcc9df40b..0000000000 --- a/core/src/main/kotlin/net/corda/core/contracts/clauses/AllOf.kt +++ /dev/null @@ -1,38 +0,0 @@ -package net.corda.core.contracts.clauses - -import net.corda.core.contracts.AuthenticatedObject -import net.corda.core.contracts.CommandData -import net.corda.core.contracts.ContractState -import net.corda.core.contracts.TransactionForContract -import java.util.* - -/** - * Compose a number of clauses, such that all of the clauses must run for verification to pass. - */ -open class AllOf(firstClause: Clause, vararg remainingClauses: Clause) : CompositeClause() { - override val clauses = ArrayList>() - - init { - clauses.add(firstClause) - clauses.addAll(remainingClauses) - } - - override fun matchedClauses(commands: List>): List> { - clauses.forEach { clause -> - check(clause.matches(commands)) { "Failed to match clause $clause" } - } - return clauses - } - - override fun verify(tx: TransactionForContract, - inputs: List, - outputs: List, - commands: List>, - groupingKey: K?): Set { - return matchedClauses(commands).flatMapTo(HashSet()) { clause -> - clause.verify(tx, inputs, outputs, commands, groupingKey) - } - } - - override fun toString() = "All: $clauses.toList()" -} diff --git a/core/src/main/kotlin/net/corda/core/contracts/clauses/AnyComposition.kt b/core/src/main/kotlin/net/corda/core/contracts/clauses/AnyComposition.kt deleted file mode 100644 index fbad044ca3..0000000000 --- a/core/src/main/kotlin/net/corda/core/contracts/clauses/AnyComposition.kt +++ /dev/null @@ -1,10 +0,0 @@ -package net.corda.core.contracts.clauses - -import net.corda.core.contracts.CommandData -import net.corda.core.contracts.ContractState - -/** - * Compose a number of clauses, such that any number of the clauses can run. - */ -@Deprecated("Use AnyOf instead, although note that any of requires at least one matched clause") -class AnyComposition(vararg rawClauses: Clause) : AnyOf(*rawClauses) \ No newline at end of file diff --git a/core/src/main/kotlin/net/corda/core/contracts/clauses/AnyOf.kt b/core/src/main/kotlin/net/corda/core/contracts/clauses/AnyOf.kt deleted file mode 100644 index ceb732bea2..0000000000 --- a/core/src/main/kotlin/net/corda/core/contracts/clauses/AnyOf.kt +++ /dev/null @@ -1,28 +0,0 @@ -package net.corda.core.contracts.clauses - -import net.corda.core.contracts.AuthenticatedObject -import net.corda.core.contracts.CommandData -import net.corda.core.contracts.ContractState -import net.corda.core.contracts.TransactionForContract -import java.util.* - -/** - * Compose a number of clauses, such that one or more of the clauses can run. - */ -open class AnyOf(vararg rawClauses: Clause) : CompositeClause() { - override val clauses: List> = rawClauses.toList() - - override fun matchedClauses(commands: List>): List> { - val matched = clauses.filter { it.matches(commands) } - require(matched.isNotEmpty()) { "At least one clause must match" } - return matched - } - - override fun verify(tx: TransactionForContract, inputs: List, outputs: List, commands: List>, groupingKey: K?): Set { - return matchedClauses(commands).flatMapTo(HashSet()) { clause -> - clause.verify(tx, inputs, outputs, commands, groupingKey) - } - } - - override fun toString(): String = "Any: ${clauses.toList()}" -} diff --git a/core/src/main/kotlin/net/corda/core/contracts/clauses/Clause.kt b/core/src/main/kotlin/net/corda/core/contracts/clauses/Clause.kt deleted file mode 100644 index 73747de7d7..0000000000 --- a/core/src/main/kotlin/net/corda/core/contracts/clauses/Clause.kt +++ /dev/null @@ -1,74 +0,0 @@ -package net.corda.core.contracts.clauses - -import net.corda.core.contracts.AuthenticatedObject -import net.corda.core.contracts.CommandData -import net.corda.core.contracts.ContractState -import net.corda.core.contracts.TransactionForContract -import net.corda.core.utilities.loggerFor -import org.slf4j.Logger - -/** - * A clause of a contract, containing a chunk of verification logic. That logic may be delegated to other clauses, or - * provided directly by this clause. - * - * @param S the type of contract state this clause operates on. - * @param C a common supertype of commands this clause operates on. - * @param K the type of the grouping key for states this clause operates on. Use [Unit] if not applicable. - * - * @see CompositeClause - */ -abstract class Clause { - companion object { - val log: Logger by lazy { loggerFor>() } - } - - /** Determine whether this clause runs or not */ - open val requiredCommands: Set> = emptySet() - - /** - * Determine the subclauses which will be verified as a result of verifying this clause. - * - * @throws IllegalStateException if the given commands do not result in a valid execution (for example no match - * with [FirstOf]). - */ - @Throws(IllegalStateException::class) - open fun getExecutionPath(commands: List>): List> - = listOf(this) - - /** - * Verify the transaction matches the conditions from this clause. For example, a "no zero amount output" clause - * would check each of the output states that it applies to, looking for a zero amount, and throw IllegalStateException - * if any matched. - * - * @param tx the full transaction being verified. This is provided for cases where clauses need to access - * states or commands outside of their normal scope. - * @param inputs input states which are relevant to this clause. By default this is the set passed into [verifyClause], - * but may be further reduced by clauses such as [GroupClauseVerifier]. - * @param outputs output states which are relevant to this clause. By default this is the set passed into [verifyClause], - * but may be further reduced by clauses such as [GroupClauseVerifier]. - * @param commands commands which are relevant to this clause. By default this is the set passed into [verifyClause], - * but may be further reduced by clauses such as [GroupClauseVerifier]. - * @param groupingKey a grouping key applied to states and commands, where applicable. Taken from - * [TransactionForContract.InOutGroup]. - * @return the set of commands that are consumed IF this clause is matched, and cannot be used to match a - * later clause. This would normally be all commands matching "requiredCommands" for this clause, but some - * verify() functions may do further filtering on possible matches, and return a subset. This may also include - * commands that were not required (for example the Exit command for fungible assets is optional). - */ - @Throws(IllegalStateException::class) - abstract fun verify(tx: TransactionForContract, - inputs: List, - outputs: List, - commands: List>, - groupingKey: K?): Set -} - -/** - * Determine if the given list of commands matches the required commands for a clause to trigger. - */ -fun Clause<*, C, *>.matches(commands: List>): Boolean { - return if (requiredCommands.isEmpty()) - true - else - commands.map { it.value.javaClass }.toSet().containsAll(requiredCommands) -} diff --git a/core/src/main/kotlin/net/corda/core/contracts/clauses/ClauseVerifier.kt b/core/src/main/kotlin/net/corda/core/contracts/clauses/ClauseVerifier.kt deleted file mode 100644 index 30fadda3ed..0000000000 --- a/core/src/main/kotlin/net/corda/core/contracts/clauses/ClauseVerifier.kt +++ /dev/null @@ -1,29 +0,0 @@ -@file:JvmName("ClauseVerifier") - -package net.corda.core.contracts.clauses - -import net.corda.core.contracts.AuthenticatedObject -import net.corda.core.contracts.CommandData -import net.corda.core.contracts.ContractState -import net.corda.core.contracts.TransactionForContract - -/** - * Verify a transaction against the given list of clauses. - * - * @param tx transaction to be verified. - * @param clauses the clauses to verify. - * @param commands commands extracted from the transaction, which are relevant to the - * clauses. - */ -fun verifyClause(tx: TransactionForContract, - clause: Clause, - commands: List>) { - if (Clause.log.isTraceEnabled) { - clause.getExecutionPath(commands).forEach { - Clause.log.trace("Tx ${tx.origHash} clause: $clause") - } - } - val matchedCommands = clause.verify(tx, tx.inputs, tx.outputs, commands, null) - - check(matchedCommands.containsAll(commands.map { it.value })) { "The following commands were not matched at the end of execution: " + (commands - matchedCommands) } -} diff --git a/core/src/main/kotlin/net/corda/core/contracts/clauses/CompositeClause.kt b/core/src/main/kotlin/net/corda/core/contracts/clauses/CompositeClause.kt deleted file mode 100644 index be0e711731..0000000000 --- a/core/src/main/kotlin/net/corda/core/contracts/clauses/CompositeClause.kt +++ /dev/null @@ -1,25 +0,0 @@ -package net.corda.core.contracts.clauses - -import net.corda.core.contracts.AuthenticatedObject -import net.corda.core.contracts.CommandData -import net.corda.core.contracts.ContractState - -/** - * Abstract supertype for clauses which compose other clauses together in some logical manner. - */ -abstract class CompositeClause : Clause() { - /** List of clauses under this composite clause */ - abstract val clauses: List> - - override fun getExecutionPath(commands: List>): List> - = matchedClauses(commands).flatMap { it.getExecutionPath(commands) } - - /** - * Determine which clauses are matched by the supplied commands. - * - * @throws IllegalStateException if the given commands do not result in a valid execution (for example no match - * with [FirstOf]). - */ - @Throws(IllegalStateException::class) - abstract fun matchedClauses(commands: List>): List> -} diff --git a/core/src/main/kotlin/net/corda/core/contracts/clauses/FilterOn.kt b/core/src/main/kotlin/net/corda/core/contracts/clauses/FilterOn.kt deleted file mode 100644 index e34f313443..0000000000 --- a/core/src/main/kotlin/net/corda/core/contracts/clauses/FilterOn.kt +++ /dev/null @@ -1,25 +0,0 @@ -package net.corda.core.contracts.clauses - -import net.corda.core.contracts.AuthenticatedObject -import net.corda.core.contracts.CommandData -import net.corda.core.contracts.ContractState -import net.corda.core.contracts.TransactionForContract - -/** - * Filter the states that are passed through to the wrapped clause, to restrict them to a specific type. - */ -class FilterOn(val clause: Clause, - val filterStates: (List) -> List) : Clause() { - override val requiredCommands: Set> - = clause.requiredCommands - - override fun getExecutionPath(commands: List>): List> - = clause.getExecutionPath(commands) - - override fun verify(tx: TransactionForContract, - inputs: List, - outputs: List, - commands: List>, - groupingKey: K?): Set - = clause.verify(tx, filterStates(inputs), filterStates(outputs), commands, groupingKey) -} diff --git a/core/src/main/kotlin/net/corda/core/contracts/clauses/FirstComposition.kt b/core/src/main/kotlin/net/corda/core/contracts/clauses/FirstComposition.kt deleted file mode 100644 index 2edea14625..0000000000 --- a/core/src/main/kotlin/net/corda/core/contracts/clauses/FirstComposition.kt +++ /dev/null @@ -1,28 +0,0 @@ -package net.corda.core.contracts.clauses - -import net.corda.core.contracts.AuthenticatedObject -import net.corda.core.contracts.CommandData -import net.corda.core.contracts.ContractState -import net.corda.core.contracts.TransactionForContract -import java.util.* - -/** - * Compose a number of clauses, such that the first match is run, and it errors if none is run. - */ -@Deprecated("Use FirstOf instead") -class FirstComposition(firstClause: Clause, vararg remainingClauses: Clause) : CompositeClause() { - override val clauses = ArrayList>() - override fun matchedClauses(commands: List>): List> = listOf(clauses.first { it.matches(commands) }) - - init { - clauses.add(firstClause) - clauses.addAll(remainingClauses) - } - - override fun verify(tx: TransactionForContract, inputs: List, outputs: List, commands: List>, groupingKey: K?): Set { - val clause = matchedClauses(commands).singleOrNull() ?: throw IllegalStateException("No delegate clause matched in first composition") - return clause.verify(tx, inputs, outputs, commands, groupingKey) - } - - override fun toString() = "First: ${clauses.toList()}" -} diff --git a/core/src/main/kotlin/net/corda/core/contracts/clauses/FirstOf.kt b/core/src/main/kotlin/net/corda/core/contracts/clauses/FirstOf.kt deleted file mode 100644 index 43a495b026..0000000000 --- a/core/src/main/kotlin/net/corda/core/contracts/clauses/FirstOf.kt +++ /dev/null @@ -1,41 +0,0 @@ -package net.corda.core.contracts.clauses - -import net.corda.core.contracts.AuthenticatedObject -import net.corda.core.contracts.CommandData -import net.corda.core.contracts.ContractState -import net.corda.core.contracts.TransactionForContract -import net.corda.core.utilities.loggerFor -import java.util.* - -/** - * Compose a number of clauses, such that the first match is run, and it errors if none is run. - */ -class FirstOf(firstClause: Clause, vararg remainingClauses: Clause) : CompositeClause() { - companion object { - val logger = loggerFor>() - } - - override val clauses = ArrayList>() - - /** - * Get the single matched clause from the set this composes, based on the given commands. This is provided as - * helper method for internal use, rather than using the exposed [matchedClauses] function which unnecessarily - * wraps the clause in a list. - */ - private fun matchedClause(commands: List>): Clause { - return clauses.firstOrNull { it.matches(commands) } ?: throw IllegalStateException("No delegate clause matched in first composition") - } - - override fun matchedClauses(commands: List>) = listOf(matchedClause(commands)) - - init { - clauses.add(firstClause) - clauses.addAll(remainingClauses) - } - - override fun verify(tx: TransactionForContract, inputs: List, outputs: List, commands: List>, groupingKey: K?): Set { - return matchedClause(commands).verify(tx, inputs, outputs, commands, groupingKey) - } - - override fun toString() = "First: ${clauses.toList()}" -} diff --git a/core/src/main/kotlin/net/corda/core/contracts/clauses/GroupClauseVerifier.kt b/core/src/main/kotlin/net/corda/core/contracts/clauses/GroupClauseVerifier.kt deleted file mode 100644 index f8634a812a..0000000000 --- a/core/src/main/kotlin/net/corda/core/contracts/clauses/GroupClauseVerifier.kt +++ /dev/null @@ -1,29 +0,0 @@ -package net.corda.core.contracts.clauses - -import net.corda.core.contracts.AuthenticatedObject -import net.corda.core.contracts.CommandData -import net.corda.core.contracts.ContractState -import net.corda.core.contracts.TransactionForContract -import java.util.* - -abstract class GroupClauseVerifier(val clause: Clause) : Clause() { - abstract fun groupStates(tx: TransactionForContract): List> - - override fun getExecutionPath(commands: List>): List> - = clause.getExecutionPath(commands) - - override fun verify(tx: TransactionForContract, - inputs: List, - outputs: List, - commands: List>, - groupingKey: Unit?): Set { - val groups = groupStates(tx) - val matchedCommands = HashSet() - - for ((groupInputs, groupOutputs, groupToken) in groups) { - matchedCommands.addAll(clause.verify(tx, groupInputs, groupOutputs, commands, groupToken)) - } - - return matchedCommands - } -} diff --git a/core/src/main/kotlin/net/corda/core/crypto/Crypto.kt b/core/src/main/kotlin/net/corda/core/crypto/Crypto.kt index abfcaa964e..6db86fb563 100644 --- a/core/src/main/kotlin/net/corda/core/crypto/Crypto.kt +++ b/core/src/main/kotlin/net/corda/core/crypto/Crypto.kt @@ -4,6 +4,7 @@ import net.corda.core.crypto.composite.CompositeKey import net.corda.core.crypto.composite.CompositeSignature import net.corda.core.crypto.provider.CordaObjectIdentifier import net.corda.core.crypto.provider.CordaSecurityProvider +import net.corda.core.serialization.serialize import net.i2p.crypto.eddsa.EdDSAEngine import net.i2p.crypto.eddsa.EdDSAPrivateKey import net.i2p.crypto.eddsa.EdDSAPublicKey @@ -13,19 +14,19 @@ import net.i2p.crypto.eddsa.spec.EdDSANamedCurveSpec import net.i2p.crypto.eddsa.spec.EdDSANamedCurveTable import net.i2p.crypto.eddsa.spec.EdDSAPrivateKeySpec import net.i2p.crypto.eddsa.spec.EdDSAPublicKeySpec -import org.bouncycastle.asn1.* +import org.bouncycastle.asn1.ASN1Integer +import org.bouncycastle.asn1.ASN1ObjectIdentifier +import org.bouncycastle.asn1.DERNull +import org.bouncycastle.asn1.DLSequence import org.bouncycastle.asn1.bc.BCObjectIdentifiers import org.bouncycastle.asn1.nist.NISTObjectIdentifiers import org.bouncycastle.asn1.pkcs.PKCSObjectIdentifiers import org.bouncycastle.asn1.pkcs.PrivateKeyInfo import org.bouncycastle.asn1.sec.SECObjectIdentifiers -import org.bouncycastle.asn1.x500.X500Name -import org.bouncycastle.asn1.x509.* +import org.bouncycastle.asn1.x509.AlgorithmIdentifier +import org.bouncycastle.asn1.x509.SubjectPublicKeyInfo import org.bouncycastle.asn1.x9.X9ObjectIdentifiers import org.bouncycastle.cert.X509CertificateHolder -import org.bouncycastle.cert.X509v3CertificateBuilder -import org.bouncycastle.cert.bc.BcX509ExtensionUtils -import org.bouncycastle.cert.jcajce.JcaX509v3CertificateBuilder import org.bouncycastle.jcajce.provider.asymmetric.ec.BCECPrivateKey import org.bouncycastle.jcajce.provider.asymmetric.ec.BCECPublicKey import org.bouncycastle.jcajce.provider.asymmetric.rsa.BCRSAPrivateKey @@ -39,10 +40,6 @@ import org.bouncycastle.jce.spec.ECPublicKeySpec import org.bouncycastle.math.ec.ECConstants import org.bouncycastle.math.ec.FixedPointCombMultiplier import org.bouncycastle.math.ec.WNafUtil -import org.bouncycastle.operator.ContentSigner -import org.bouncycastle.operator.jcajce.JcaContentVerifierProviderBuilder -import org.bouncycastle.pkcs.PKCS10CertificationRequest -import org.bouncycastle.pkcs.jcajce.JcaPKCS10CertificationRequestBuilder import org.bouncycastle.pqc.jcajce.provider.BouncyCastlePQCProvider import org.bouncycastle.pqc.jcajce.provider.sphincs.BCSphincs256PrivateKey import org.bouncycastle.pqc.jcajce.provider.sphincs.BCSphincs256PublicKey @@ -52,7 +49,6 @@ import java.security.* import java.security.spec.InvalidKeySpecException import java.security.spec.PKCS8EncodedKeySpec import java.security.spec.X509EncodedKeySpec -import java.util.* import javax.crypto.Mac import javax.crypto.spec.SecretKeySpec @@ -195,7 +191,7 @@ object Crypto { // that could cause unexpected and suspicious behaviour. // i.e. if someone removes a Provider and then he/she adds a new one with the same name. // The val is private to avoid any harmful state changes. - private val providerMap: Map = mapOf( + val providerMap: Map = mapOf( BouncyCastleProvider.PROVIDER_NAME to getBouncyCastleProvider(), CordaSecurityProvider.PROVIDER_NAME to CordaSecurityProvider(), "BCPQC" to BouncyCastlePQCProvider()) // unfortunately, provider's name is not final in BouncyCastlePQCProvider, so we explicitly set it. @@ -282,7 +278,7 @@ object Crypto { /** * Decode a PKCS8 encoded key to its [PrivateKey] object based on the input scheme code name. - * This should be used when the type key is known, e.g. during Kryo deserialisation or with key caches or key managers. + * This should be used when the type key is known, e.g. during deserialisation or with key caches or key managers. * @param schemeCodeName a [String] that should match a key in supportedSignatureSchemes map (e.g. ECDSA_SECP256K1_SHA256). * @param encodedKey a PKCS8 encoded private key. * @throws IllegalArgumentException on not supported scheme or if the given key specification @@ -293,7 +289,7 @@ object Crypto { /** * Decode a PKCS8 encoded key to its [PrivateKey] object based on the input scheme code name. - * This should be used when the type key is known, e.g. during Kryo deserialisation or with key caches or key managers. + * This should be used when the type key is known, e.g. during deserialisation or with key caches or key managers. * @param signatureScheme a signature scheme (e.g. ECDSA_SECP256K1_SHA256). * @param encodedKey a PKCS8 encoded private key. * @throws IllegalArgumentException on not supported scheme or if the given key specification @@ -325,7 +321,7 @@ object Crypto { /** * Decode an X509 encoded key to its [PrivateKey] object based on the input scheme code name. - * This should be used when the type key is known, e.g. during Kryo deserialisation or with key caches or key managers. + * This should be used when the type key is known, e.g. during deserialisation or with key caches or key managers. * @param schemeCodeName a [String] that should match a key in supportedSignatureSchemes map (e.g. ECDSA_SECP256K1_SHA256). * @param encodedKey an X509 encoded public key. * @throws IllegalArgumentException if the requested scheme is not supported. @@ -337,7 +333,7 @@ object Crypto { /** * Decode an X509 encoded key to its [PrivateKey] object based on the input scheme code name. - * This should be used when the type key is known, e.g. during Kryo deserialisation or with key caches or key managers. + * This should be used when the type key is known, e.g. during deserialisation or with key caches or key managers. * @param signatureScheme a signature scheme (e.g. ECDSA_SECP256K1_SHA256). * @param encodedKey an X509 encoded public key. * @throws IllegalArgumentException if the requested scheme is not supported. @@ -401,23 +397,23 @@ object Crypto { } /** - * Generic way to sign [MetaData] objects with a [PrivateKey]. - * [MetaData] is a wrapper over the transaction's Merkle root in order to attach extra information, such as a timestamp or partial and blind signature indicators. + * Generic way to sign [SignableData] objects with a [PrivateKey]. + * [SignableData] is a wrapper over the transaction's id (Merkle root) in order to attach extra information, such as a timestamp or partial and blind signature indicators. * @param privateKey the signer's [PrivateKey]. - * @param metaData a [MetaData] object that adds extra information to a transaction. - * @return a [TransactionSignature] object than contains the output of a successful signing and the metaData. - * @throws IllegalArgumentException if the signature scheme is not supported for this private key or - * if metaData.schemeCodeName is not aligned with key type. + * @param signableData a [SignableData] object that adds extra information to a transaction. + * @return a [TransactionSignature] object than contains the output of a successful signing, signer's public key and the signature metadata. + * @throws IllegalArgumentException if the signature scheme is not supported for this private key. * @throws InvalidKeyException if the private key is invalid. * @throws SignatureException if signing is not possible due to malformed data or private key. */ @Throws(IllegalArgumentException::class, InvalidKeyException::class, SignatureException::class) - fun doSign(privateKey: PrivateKey, metaData: MetaData): TransactionSignature { - val sigKey: SignatureScheme = findSignatureScheme(privateKey) - val sigMetaData: SignatureScheme = findSignatureScheme(metaData.schemeCodeName) - if (sigKey != sigMetaData) throw IllegalArgumentException("Metadata schemeCodeName: ${metaData.schemeCodeName} is not aligned with the key type.") - val signatureData = doSign(sigKey.schemeCodeName, privateKey, metaData.bytes()) - return TransactionSignature(signatureData, metaData) + fun doSign(keyPair: KeyPair, signableData: SignableData): TransactionSignature { + val sigKey: SignatureScheme = findSignatureScheme(keyPair.private) + val sigMetaData: SignatureScheme = findSignatureScheme(keyPair.public) + if (sigKey != sigMetaData) throw IllegalArgumentException("Metadata schemeCodeName: ${sigMetaData.schemeCodeName}" + + " is not aligned with the key type: ${sigKey.schemeCodeName}.") + val signatureBytes = doSign(sigKey.schemeCodeName, keyPair.private, signableData.serialize().bytes) + return TransactionSignature(signatureBytes, keyPair.public, signableData.signatureMetadata) } /** @@ -434,7 +430,7 @@ object Crypto { * if this signatureData scheme is unable to process the input data provided, if the verification is not possible. * @throws IllegalArgumentException if the signature scheme is not supported or if any of the clear or signature data is empty. */ - @Throws(InvalidKeyException::class, SignatureException::class, IllegalArgumentException::class) + @Throws(InvalidKeyException::class, SignatureException::class) fun doVerify(schemeCodeName: String, publicKey: PublicKey, signatureData: ByteArray, clearData: ByteArray) = doVerify(findSignatureScheme(schemeCodeName), publicKey, signatureData, clearData) /** @@ -485,9 +481,9 @@ object Crypto { /** * Utility to simplify the act of verifying a [TransactionSignature]. * It returns true if it succeeds, but it always throws an exception if verification fails. - * @param publicKey the signer's [PublicKey]. - * @param transactionSignature the signatureData on a message. - * @return true if verification passes or throws an exception if verification fails. + * @param txId transaction's id (Merkle root). + * @param transactionSignature the signature on the transaction. + * @return true if verification passes or throw exception if verification fails. * @throws InvalidKeyException if the key is invalid. * @throws SignatureException if this signatureData object is not initialized properly, * the passed-in signatureData is improperly encoded or of the wrong type, @@ -495,9 +491,26 @@ object Crypto { * @throws IllegalArgumentException if the signature scheme is not supported or if any of the clear or signature data is empty. */ @Throws(InvalidKeyException::class, SignatureException::class, IllegalArgumentException::class) - fun doVerify(publicKey: PublicKey, transactionSignature: TransactionSignature): Boolean { - if (publicKey != transactionSignature.metaData.publicKey) IllegalArgumentException("MetaData's publicKey: ${transactionSignature.metaData.publicKey.toStringShort()} does not match") - return Crypto.doVerify(publicKey, transactionSignature.signatureData, transactionSignature.metaData.bytes()) + fun doVerify(txId: SecureHash, transactionSignature: TransactionSignature): Boolean { + val signableData = SignableData(txId, transactionSignature.signatureMetadata) + return Crypto.doVerify(transactionSignature.by, transactionSignature.bytes, signableData.serialize().bytes) + } + + /** + * Utility to simplify the act of verifying a digital signature by identifying the signature scheme used from the input public key's type. + * It returns true if it succeeds and false if not. In comparison to [doVerify] if the key and signature + * do not match it returns false rather than throwing an exception. Normally you should use the function which throws, + * as it avoids the risk of failing to test the result. + * @param txId transaction's id (Merkle root). + * @param transactionSignature the signature on the transaction. + * @throws SignatureException if this signatureData object is not initialized properly, + * the passed-in signatureData is improperly encoded or of the wrong type, + * if this signatureData scheme is unable to process the input data provided, if the verification is not possible. + */ + @Throws(SignatureException::class) + fun isValid(txId: SecureHash, transactionSignature: TransactionSignature): Boolean { + val signableData = SignableData(txId, transactionSignature.signatureMetadata) + return isValid(findSignatureScheme(transactionSignature.by), transactionSignature.by, transactionSignature.bytes, signableData.serialize().bytes) } /** @@ -752,90 +765,6 @@ object Crypto { return mac.doFinal(seed) } - /** - * Build a partial X.509 certificate ready for signing. - * - * @param issuer name of the issuing entity. - * @param subject name of the certificate subject. - * @param subjectPublicKey public key of the certificate subject. - * @param validityWindow the time period the certificate is valid for. - * @param nameConstraints any name constraints to impose on certificates signed by the generated certificate. - */ - fun createCertificate(certificateType: CertificateType, issuer: X500Name, - subject: X500Name, subjectPublicKey: PublicKey, - validityWindow: Pair, - nameConstraints: NameConstraints? = null): X509v3CertificateBuilder { - - val serial = BigInteger.valueOf(random63BitValue()) - val keyPurposes = DERSequence(ASN1EncodableVector().apply { certificateType.purposes.forEach { add(it) } }) - val subjectPublicKeyInfo = SubjectPublicKeyInfo.getInstance(ASN1Sequence.getInstance(subjectPublicKey.encoded)) - - val builder = JcaX509v3CertificateBuilder(issuer, serial, validityWindow.first, validityWindow.second, subject, subjectPublicKey) - .addExtension(Extension.subjectKeyIdentifier, false, BcX509ExtensionUtils().createSubjectKeyIdentifier(subjectPublicKeyInfo)) - .addExtension(Extension.basicConstraints, certificateType.isCA, BasicConstraints(certificateType.isCA)) - .addExtension(Extension.keyUsage, false, certificateType.keyUsage) - .addExtension(Extension.extendedKeyUsage, false, keyPurposes) - - if (nameConstraints != null) { - builder.addExtension(Extension.nameConstraints, true, nameConstraints) - } - return builder - } - - /** - * Build and sign an X.509 certificate with the given signer. - * - * @param issuer name of the issuing entity. - * @param issuerSigner content signer to sign the certificate with. - * @param subject name of the certificate subject. - * @param subjectPublicKey public key of the certificate subject. - * @param validityWindow the time period the certificate is valid for. - * @param nameConstraints any name constraints to impose on certificates signed by the generated certificate. - */ - fun createCertificate(certificateType: CertificateType, issuer: X500Name, issuerSigner: ContentSigner, - subject: X500Name, subjectPublicKey: PublicKey, - validityWindow: Pair, - nameConstraints: NameConstraints? = null): X509CertificateHolder { - val builder = createCertificate(certificateType, issuer, subject, subjectPublicKey, validityWindow, nameConstraints) - return builder.build(issuerSigner).apply { - require(isValidOn(Date())) - } - } - - /** - * Build and sign an X.509 certificate with CA cert private key. - * - * @param issuer name of the issuing entity. - * @param issuerKeyPair the public & private key to sign the certificate with. - * @param subject name of the certificate subject. - * @param subjectPublicKey public key of the certificate subject. - * @param validityWindow the time period the certificate is valid for. - * @param nameConstraints any name constraints to impose on certificates signed by the generated certificate. - */ - fun createCertificate(certificateType: CertificateType, issuer: X500Name, issuerKeyPair: KeyPair, - subject: X500Name, subjectPublicKey: PublicKey, - validityWindow: Pair, - nameConstraints: NameConstraints? = null): X509CertificateHolder { - - val signatureScheme = findSignatureScheme(issuerKeyPair.private) - val provider = providerMap[signatureScheme.providerName] - val builder = createCertificate(certificateType, issuer, subject, subjectPublicKey, validityWindow, nameConstraints) - - val signer = ContentSignerBuilder.build(signatureScheme, issuerKeyPair.private, provider) - return builder.build(signer).apply { - require(isValidOn(Date())) - require(isSignatureValid(JcaContentVerifierProviderBuilder().build(issuerKeyPair.public))) - } - } - - /** - * Create certificate signing request using provided information. - */ - fun createCertificateSigningRequest(subject: X500Name, keyPair: KeyPair, signatureScheme: SignatureScheme): PKCS10CertificationRequest { - val signer = ContentSignerBuilder.build(signatureScheme, keyPair.private, providerMap[signatureScheme.providerName]) - return JcaPKCS10CertificationRequestBuilder(subject, keyPair.public).build(signer) - } - private class KeyInfoConverter(val signatureScheme: SignatureScheme) : AsymmetricKeyInfoConverter { override fun generatePublic(keyInfo: SubjectPublicKeyInfo?): PublicKey? = keyInfo?.let { decodePublicKey(signatureScheme, it.encoded) } override fun generatePrivate(keyInfo: PrivateKeyInfo?): PrivateKey? = keyInfo?.let { decodePrivateKey(signatureScheme, it.encoded) } diff --git a/core/src/main/kotlin/net/corda/core/crypto/CryptoUtils.kt b/core/src/main/kotlin/net/corda/core/crypto/CryptoUtils.kt index c77131522a..f41fd55f6a 100644 --- a/core/src/main/kotlin/net/corda/core/crypto/CryptoUtils.kt +++ b/core/src/main/kotlin/net/corda/core/crypto/CryptoUtils.kt @@ -35,7 +35,17 @@ fun PrivateKey.sign(bytesToSign: ByteArray, publicKey: PublicKey): DigitalSignat */ @Throws(IllegalArgumentException::class, InvalidKeyException::class, SignatureException::class) fun KeyPair.sign(bytesToSign: ByteArray) = private.sign(bytesToSign, public) -fun KeyPair.sign(bytesToSign: OpaqueBytes) = private.sign(bytesToSign.bytes, public) +fun KeyPair.sign(bytesToSign: OpaqueBytes) = sign(bytesToSign.bytes) +/** + * Helper function for signing a [SignableData] object. + * @param signableData the object to be signed. + * @return a [TransactionSignature] object. + * @throws IllegalArgumentException if the signature scheme is not supported for this private key. + * @throws InvalidKeyException if the private key is invalid. + * @throws SignatureException if signing is not possible due to malformed data or private key. + */ +@Throws(InvalidKeyException::class, SignatureException::class) +fun KeyPair.sign(signableData: SignableData): TransactionSignature = Crypto.doSign(this, signableData) /** * Utility to simplify the act of verifying a signature. @@ -89,7 +99,7 @@ fun PublicKey.containsAny(otherKeys: Iterable): Boolean { } /** Returns the set of all [PublicKey]s of the signatures */ -fun Iterable.byKeys() = map { it.by }.toSet() +fun Iterable.byKeys() = map { it.by }.toSet() // Allow Kotlin destructuring: val (private, public) = keyPair operator fun KeyPair.component1(): PrivateKey = this.private @@ -106,17 +116,6 @@ fun generateKeyPair(): KeyPair = Crypto.generateKeyPair() */ fun entropyToKeyPair(entropy: BigInteger): KeyPair = Crypto.deriveKeyPairFromEntropy(entropy) -/** - * Helper function for signing. - * @param metaData tha attached MetaData object. - * @return a [TransactionSignature ] object. - * @throws IllegalArgumentException if the signature scheme is not supported for this private key. - * @throws InvalidKeyException if the private key is invalid. - * @throws SignatureException if signing is not possible due to malformed data or private key. - */ -@Throws(InvalidKeyException::class, SignatureException::class, IllegalArgumentException::class) -fun PrivateKey.sign(metaData: MetaData): TransactionSignature = Crypto.doSign(this, metaData) - /** * Helper function to verify a signature. * @param signatureData the signature on a message. @@ -130,21 +129,6 @@ fun PrivateKey.sign(metaData: MetaData): TransactionSignature = Crypto.doSign(th @Throws(InvalidKeyException::class, SignatureException::class, IllegalArgumentException::class) fun PublicKey.verify(signatureData: ByteArray, clearData: ByteArray): Boolean = Crypto.doVerify(this, signatureData, clearData) -/** - * Helper function to verify a metadata attached signature. It is noted that the transactionSignature contains - * signatureData and a [MetaData] object that contains the signer's public key and the transaction's Merkle root. - * @param transactionSignature a [TransactionSignature] object that . - * @throws InvalidKeyException if the key is invalid. - * @throws SignatureException if this signatureData object is not initialized properly, - * the passed-in signatureData is improperly encoded or of the wrong type, - * if this signatureData algorithm is unable to process the input data provided, etc. - * @throws IllegalArgumentException if the signature scheme is not supported for this private key or if any of the clear or signature data is empty. - */ -@Throws(InvalidKeyException::class, SignatureException::class, IllegalArgumentException::class) -fun PublicKey.verify(transactionSignature: TransactionSignature): Boolean { - return Crypto.doVerify(this, transactionSignature) -} - /** * Helper function for the signers to verify their own signature. * @param signatureData the signature on a message. diff --git a/core/src/main/kotlin/net/corda/core/crypto/EncodingUtils.kt b/core/src/main/kotlin/net/corda/core/crypto/EncodingUtils.kt index a79821b760..0219142656 100644 --- a/core/src/main/kotlin/net/corda/core/crypto/EncodingUtils.kt +++ b/core/src/main/kotlin/net/corda/core/crypto/EncodingUtils.kt @@ -60,7 +60,7 @@ fun String.hexToBase58(): String = hexToByteArray().toBase58() /** Encoding changer. Hex-[String] to Base64-[String], i.e. "48656C6C6F20576F726C64" -> "SGVsbG8gV29ybGQ=" */ fun String.hexToBase64(): String = hexToByteArray().toBase64() -// TODO We use for both CompositeKeys and EdDSAPublicKey custom Kryo serializers and deserializers. We need to specify encoding. +// TODO We use for both CompositeKeys and EdDSAPublicKey custom serializers and deserializers. We need to specify encoding. // TODO: follow the crypto-conditions ASN.1 spec, some changes are needed to be compatible with the condition // structure, e.g. mapping a PublicKey to a condition with the specific feature (ED25519). fun parsePublicKeyBase58(base58String: String): PublicKey = base58String.base58ToByteArray().deserialize() diff --git a/core/src/main/kotlin/net/corda/core/crypto/MerkleTree.kt b/core/src/main/kotlin/net/corda/core/crypto/MerkleTree.kt index babf164bc6..831e20758e 100644 --- a/core/src/main/kotlin/net/corda/core/crypto/MerkleTree.kt +++ b/core/src/main/kotlin/net/corda/core/crypto/MerkleTree.kt @@ -23,8 +23,10 @@ sealed class MerkleTree { /** * Merkle tree building using hashes, with zero hash padding to full power of 2. */ - @Throws(IllegalArgumentException::class) + @Throws(MerkleTreeException::class) fun getMerkleTree(allLeavesHashes: List): MerkleTree { + if (allLeavesHashes.isEmpty()) + throw MerkleTreeException("Cannot calculate Merkle root on empty hash list.") val leaves = padWithZeros(allLeavesHashes).map { Leaf(it) } return buildMerkleTree(leaves) } @@ -46,8 +48,6 @@ sealed class MerkleTree { * @return Tree root. */ private tailrec fun buildMerkleTree(lastNodesList: List): MerkleTree { - if (lastNodesList.isEmpty()) - throw MerkleTreeException("Cannot calculate Merkle root on empty hash list.") if (lastNodesList.size == 1) { return lastNodesList[0] //Root reached. } else { diff --git a/core/src/main/kotlin/net/corda/core/crypto/MetaData.kt b/core/src/main/kotlin/net/corda/core/crypto/MetaData.kt deleted file mode 100644 index edcf018e82..0000000000 --- a/core/src/main/kotlin/net/corda/core/crypto/MetaData.kt +++ /dev/null @@ -1,71 +0,0 @@ -package net.corda.core.crypto - -import net.corda.core.serialization.CordaSerializable -import net.corda.core.utilities.opaque -import net.corda.core.serialization.serialize -import java.security.PublicKey -import java.time.Instant -import java.util.* - -/** - * A [MetaData] object adds extra information to a transaction. MetaData is used to support a universal - * digital signature model enabling full, partial, fully or partially blind and metaData attached signatures, - * (such as an attached timestamp). A MetaData object contains both the merkle root of the transaction and the signer's public key. - * When signatureType is set to FULL, then visibleInputs and signedInputs can be ignored. - * Note: We could omit signatureType as it can always be defined by combining visibleInputs and signedInputs, - * but it helps to speed up the process when FULL is used, and thus we can bypass the extra check on boolean arrays. - * - * @param schemeCodeName a signature scheme's code name (e.g. ECDSA_SECP256K1_SHA256). - * @param versionID DLT's version. - * @param signatureType type of the signature, see [SignatureType] (e.g. FULL, PARTIAL, BLIND, PARTIAL_AND_BLIND). - * @param timestamp the signature's timestamp as provided by the signer. - * @param visibleInputs for partially/fully blind signatures. We use Merkle tree boolean index flags (from left to right) - * indicating what parts of the transaction were visible when the signature was calculated. - * @param signedInputs for partial signatures. We use Merkle tree boolean index flags (from left to right) - * indicating what parts of the Merkle tree are actually signed. - * @param merkleRoot the Merkle root of the transaction. - * @param publicKey the signer's public key. - */ -@CordaSerializable -open class MetaData( - val schemeCodeName: String, - val versionID: String, - val signatureType: SignatureType = SignatureType.FULL, - val timestamp: Instant?, - val visibleInputs: BitSet?, - val signedInputs: BitSet?, - val merkleRoot: ByteArray, - val publicKey: PublicKey) { - - fun bytes() = this.serialize().bytes - - override fun equals(other: Any?): Boolean { - if (this === other) return true - if (other?.javaClass != javaClass) return false - - other as MetaData - - if (schemeCodeName != other.schemeCodeName) return false - if (versionID != other.versionID) return false - if (signatureType != other.signatureType) return false - if (timestamp != other.timestamp) return false - if (visibleInputs != other.visibleInputs) return false - if (signedInputs != other.signedInputs) return false - if (merkleRoot.opaque() != other.merkleRoot.opaque()) return false - if (publicKey != other.publicKey) return false - return true - } - - override fun hashCode(): Int { - var result = schemeCodeName.hashCode() - result = 31 * result + versionID.hashCode() - result = 31 * result + signatureType.hashCode() - result = 31 * result + (timestamp?.hashCode() ?: 0) - result = 31 * result + (visibleInputs?.hashCode() ?: 0) - result = 31 * result + (signedInputs?.hashCode() ?: 0) - result = 31 * result + Arrays.hashCode(merkleRoot) - result = 31 * result + publicKey.hashCode() - return result - } -} - diff --git a/core/src/main/kotlin/net/corda/core/crypto/PartialMerkleTree.kt b/core/src/main/kotlin/net/corda/core/crypto/PartialMerkleTree.kt index 619921535b..a8ac0356be 100644 --- a/core/src/main/kotlin/net/corda/core/crypto/PartialMerkleTree.kt +++ b/core/src/main/kotlin/net/corda/core/crypto/PartialMerkleTree.kt @@ -5,9 +5,7 @@ import net.corda.core.serialization.CordaSerializable import java.util.* @CordaSerializable -class MerkleTreeException(val reason: String) : Exception() { - override fun toString() = "Partial Merkle Tree exception. Reason: $reason" -} +class MerkleTreeException(val reason: String) : Exception("Partial Merkle Tree exception. Reason: $reason") /** * Building and verification of Partial Merkle Tree. diff --git a/core/src/main/kotlin/net/corda/core/crypto/SecureHash.kt b/core/src/main/kotlin/net/corda/core/crypto/SecureHash.kt index fcefe20a5b..af2114ab6d 100644 --- a/core/src/main/kotlin/net/corda/core/crypto/SecureHash.kt +++ b/core/src/main/kotlin/net/corda/core/crypto/SecureHash.kt @@ -1,8 +1,9 @@ package net.corda.core.crypto -import com.google.common.io.BaseEncoding import net.corda.core.serialization.CordaSerializable import net.corda.core.utilities.OpaqueBytes +import net.corda.core.utilities.parseAsHex +import net.corda.core.utilities.toHexString import java.security.MessageDigest /** @@ -18,7 +19,7 @@ sealed class SecureHash(bytes: ByteArray) : OpaqueBytes(bytes) { } } - override fun toString(): String = BaseEncoding.base16().encode(bytes) + override fun toString(): String = bytes.toHexString() fun prefixChars(prefixLen: Int = 6) = toString().substring(0, prefixLen) fun hashConcat(other: SecureHash) = (this.bytes + other.bytes).sha256() @@ -26,7 +27,7 @@ sealed class SecureHash(bytes: ByteArray) : OpaqueBytes(bytes) { // Like static methods in Java, except the 'companion' is a singleton that can have state. companion object { @JvmStatic - fun parse(str: String) = BaseEncoding.base16().decode(str.toUpperCase()).let { + fun parse(str: String) = str.toUpperCase().parseAsHex().let { when (it.size) { 32 -> SHA256(it) else -> throw IllegalArgumentException("Provided string is ${it.size} bytes not 32 bytes in hex: $str") diff --git a/core/src/main/kotlin/net/corda/core/crypto/SignableData.kt b/core/src/main/kotlin/net/corda/core/crypto/SignableData.kt new file mode 100644 index 0000000000..a893381f7b --- /dev/null +++ b/core/src/main/kotlin/net/corda/core/crypto/SignableData.kt @@ -0,0 +1,14 @@ +package net.corda.core.crypto + +import net.corda.core.serialization.CordaSerializable + +/** + * A [SignableData] object is the packet actually signed. + * It works as a wrapper over transaction id and signature metadata. + * + * @param txId transaction's id. + * @param signatureMetadata meta data required. + */ +@CordaSerializable +data class SignableData(val txId: SecureHash, val signatureMetadata: SignatureMetadata) + diff --git a/core/src/main/kotlin/net/corda/core/crypto/SignatureMetadata.kt b/core/src/main/kotlin/net/corda/core/crypto/SignatureMetadata.kt new file mode 100644 index 0000000000..99335bde4c --- /dev/null +++ b/core/src/main/kotlin/net/corda/core/crypto/SignatureMetadata.kt @@ -0,0 +1,15 @@ +package net.corda.core.crypto + +import net.corda.core.serialization.CordaSerializable + +/** + * SignatureMeta is required to add extra meta-data to a transaction's signature. + * It currently supports platformVersion only, but it can be extended to support a universal digital + * signature model enabling partial signatures and attaching extra information, such as a user's timestamp or other + * application-specific fields. + * + * @param platformVersion current DLT version. + * @param schemeNumberID number id of the signature scheme used based on signer's key-pair, see [SignatureScheme.schemeNumberID]. + */ +@CordaSerializable +data class SignatureMetadata(val platformVersion: Int, val schemeNumberID: Int) diff --git a/core/src/main/kotlin/net/corda/core/crypto/SignatureScheme.kt b/core/src/main/kotlin/net/corda/core/crypto/SignatureScheme.kt index 49493f6d6f..9ec9d26863 100644 --- a/core/src/main/kotlin/net/corda/core/crypto/SignatureScheme.kt +++ b/core/src/main/kotlin/net/corda/core/crypto/SignatureScheme.kt @@ -6,7 +6,7 @@ import java.security.spec.AlgorithmParameterSpec /** * This class is used to define a digital signature scheme. - * @param schemeNumberID we assign a number ID for more efficient on-wire serialisation. Please ensure uniqueness between schemes. + * @param schemeNumberID we assign a number ID for better efficiency on-wire serialisation. Please ensure uniqueness between schemes. * @param schemeCodeName code name for this signature scheme (e.g. RSA_SHA256, ECDSA_SECP256K1_SHA256, ECDSA_SECP256R1_SHA256, EDDSA_ED25519_SHA512, SPHINCS-256_SHA512). * @param signatureOID ASN.1 algorithm identifier of the signature algorithm (e.g 1.3.101.112 for EdDSA) * @param alternativeOIDs ASN.1 algorithm identifiers for keys of the signature, where we want to map multiple keys to diff --git a/core/src/main/kotlin/net/corda/core/crypto/SignatureType.kt b/core/src/main/kotlin/net/corda/core/crypto/SignatureType.kt deleted file mode 100644 index 5d94f8d9e4..0000000000 --- a/core/src/main/kotlin/net/corda/core/crypto/SignatureType.kt +++ /dev/null @@ -1,17 +0,0 @@ -package net.corda.core.crypto - -import net.corda.core.serialization.CordaSerializable - -/** - * Supported Signature types: - *

    - *
  • FULL = signature covers whole transaction, by the convention that signing the Merkle root, it is equivalent to signing all parts of the transaction. - *
  • PARTIAL = signature covers only a part of the transaction, see [MetaData]. - *
  • BLIND = when an entity blindly signs without having full knowledge on the content, see [MetaData]. - *
  • PARTIAL_AND_BLIND = combined PARTIAL and BLIND in the same time. - *
- */ -@CordaSerializable -enum class SignatureType { - FULL, PARTIAL, BLIND, PARTIAL_AND_BLIND -} diff --git a/core/src/main/kotlin/net/corda/core/crypto/SignedData.kt b/core/src/main/kotlin/net/corda/core/crypto/SignedData.kt index f1262f84af..472a8a1024 100644 --- a/core/src/main/kotlin/net/corda/core/crypto/SignedData.kt +++ b/core/src/main/kotlin/net/corda/core/crypto/SignedData.kt @@ -23,7 +23,8 @@ open class SignedData(val raw: SerializedBytes, val sig: DigitalSign @Throws(SignatureException::class) fun verified(): T { sig.by.verify(raw.bytes, sig) - val data = raw.deserialize() + @Suppress("UNCHECKED_CAST") + val data = raw.deserialize() as T verifyData(data) return data } diff --git a/core/src/main/kotlin/net/corda/core/crypto/TransactionSignature.kt b/core/src/main/kotlin/net/corda/core/crypto/TransactionSignature.kt index ccd917513b..28e843a82d 100644 --- a/core/src/main/kotlin/net/corda/core/crypto/TransactionSignature.kt +++ b/core/src/main/kotlin/net/corda/core/crypto/TransactionSignature.kt @@ -1,22 +1,57 @@ package net.corda.core.crypto +import net.corda.core.serialization.CordaSerializable import java.security.InvalidKeyException +import java.security.PublicKey import java.security.SignatureException +import java.util.* /** - * A wrapper around a digital signature accompanied with metadata, see [MetaData.Full] and [DigitalSignature]. - * The signature protocol works as follows: s = sign(MetaData.hashBytes). + * A wrapper over the signature output accompanied by signer's public key and signature metadata. + * This is similar to [DigitalSignature.WithKey], but targeted to DLT transaction signatures. */ -open class TransactionSignature(val signatureData: ByteArray, val metaData: MetaData) : DigitalSignature(signatureData) { +@CordaSerializable +class TransactionSignature(bytes: ByteArray, val by: PublicKey, val signatureMetadata: SignatureMetadata): DigitalSignature(bytes) { /** - * Function to auto-verify a [MetaData] object's signature. - * Note that [MetaData] contains both public key and merkle root of the transaction. + * Function to verify a [SignableData] object's signature. + * Note that [SignableData] contains the id of the transaction and extra metadata, such as DLT's platform version. + * + * @param txId transaction's id (Merkle root), which along with [signatureMetadata] will be used to construct the [SignableData] object to be signed. * @throws InvalidKeyException if the key is invalid. * @throws SignatureException if this signatureData object is not initialized properly, * the passed-in signatureData is improperly encoded or of the wrong type, * if this signatureData algorithm is unable to process the input data provided, etc. * @throws IllegalArgumentException if the signature scheme is not supported for this private key or if any of the clear or signature data is empty. */ - @Throws(InvalidKeyException::class, SignatureException::class, IllegalArgumentException::class) - fun verify(): Boolean = Crypto.doVerify(metaData.publicKey, signatureData, metaData.bytes()) + @Throws(InvalidKeyException::class, SignatureException::class) + fun verify(txId: SecureHash) = Crypto.doVerify(txId, this) + + /** + * Utility to simplify the act of verifying a signature. In comparison to [verify] doesn't throw an + * exception, making it more suitable where a boolean is required, but normally you should use the function + * which throws, as it avoids the risk of failing to test the result. + * + * @throws InvalidKeyException if the key to verify the signature with is not valid (i.e. wrong key type for the + * signature). + * @throws SignatureException if the signature is invalid (i.e. damaged). + * @return whether the signature is correct for this key. + */ + @Throws(InvalidKeyException::class, SignatureException::class) + fun isValid(txId: SecureHash) = Crypto.isValid(txId, this) + + override fun equals(other: Any?): Boolean { + if (this === other) return true + if (other !is TransactionSignature) return false + + return (Arrays.equals(bytes, other.bytes) + && by == other.by + && signatureMetadata == other.signatureMetadata) + } + + override fun hashCode(): Int { + var result = super.hashCode() + result = 31 * result + by.hashCode() + result = 31 * result + signatureMetadata.hashCode() + return result + } } diff --git a/core/src/main/kotlin/net/corda/core/crypto/X500NameUtils.kt b/core/src/main/kotlin/net/corda/core/crypto/X500NameUtils.kt new file mode 100644 index 0000000000..0043988ba4 --- /dev/null +++ b/core/src/main/kotlin/net/corda/core/crypto/X500NameUtils.kt @@ -0,0 +1,77 @@ +@file:JvmName("X500NameUtils") +package net.corda.core.crypto + +import org.bouncycastle.asn1.ASN1Encodable +import org.bouncycastle.asn1.x500.X500Name +import org.bouncycastle.asn1.x500.X500NameBuilder +import org.bouncycastle.asn1.x500.style.BCStyle +import org.bouncycastle.cert.X509CertificateHolder +import org.bouncycastle.cert.jcajce.JcaX509CertificateConverter +import java.security.KeyPair +import java.security.cert.X509Certificate + +/** + * Rebuild the distinguished name, adding a postfix to the common name. If no common name is present. + * @throws IllegalArgumentException if the distinguished name does not contain a common name element. + */ +fun X500Name.appendToCommonName(commonName: String): X500Name = mutateCommonName { attr -> attr.toString() + commonName } + +/** + * Rebuild the distinguished name, replacing the common name with the given value. If no common name is present, this + * adds one. + * @throws IllegalArgumentException if the distinguished name does not contain a common name element. + */ +fun X500Name.replaceCommonName(commonName: String): X500Name = mutateCommonName { _ -> commonName } + +/** + * Rebuild the distinguished name, replacing the common name with a value generated from the provided function. + * + * @param mutator a function to generate the new value from the previous one. + * @throws IllegalArgumentException if the distinguished name does not contain a common name element. + */ +private fun X500Name.mutateCommonName(mutator: (ASN1Encodable) -> String): X500Name { + val builder = X500NameBuilder(BCStyle.INSTANCE) + var matched = false + this.rdNs.forEach { rdn -> + rdn.typesAndValues.forEach { typeAndValue -> + when (typeAndValue.type) { + BCStyle.CN -> { + matched = true + builder.addRDN(typeAndValue.type, mutator(typeAndValue.value)) + } + else -> { + builder.addRDN(typeAndValue) + } + } + } + } + require(matched) { "Input X.500 name must include a common name (CN) attribute: ${this}" } + return builder.build() +} + +val X500Name.commonName: String get() = getRDNs(BCStyle.CN).first().first.value.toString() +val X500Name.orgName: String? get() = getRDNs(BCStyle.O).firstOrNull()?.first?.value?.toString() +val X500Name.location: String get() = getRDNs(BCStyle.L).first().first.value.toString() +val X500Name.locationOrNull: String? get() = try { + location +} catch (e: Exception) { + null +} +val X509Certificate.subject: X500Name get() = X509CertificateHolder(encoded).subject +val X509CertificateHolder.cert: X509Certificate get() = JcaX509CertificateConverter().getCertificate(this) + +/** + * Generate a distinguished name from the provided values. + */ +@JvmOverloads +fun getX509Name(myLegalName: String, nearestCity: String, email: String, country: String? = null): X500Name { + return X500NameBuilder(BCStyle.INSTANCE).let { builder -> + builder.addRDN(BCStyle.CN, myLegalName) + builder.addRDN(BCStyle.L, nearestCity) + country?.let { builder.addRDN(BCStyle.C, it) } + builder.addRDN(BCStyle.E, email) + builder.build() + } +} + +data class CertificateAndKeyPair(val certificate: X509CertificateHolder, val keyPair: KeyPair) diff --git a/core/src/main/kotlin/net/corda/core/crypto/composite/CompositeKey.kt b/core/src/main/kotlin/net/corda/core/crypto/composite/CompositeKey.kt index 51a13a076a..f9ca0dad7c 100644 --- a/core/src/main/kotlin/net/corda/core/crypto/composite/CompositeKey.kt +++ b/core/src/main/kotlin/net/corda/core/crypto/composite/CompositeKey.kt @@ -4,13 +4,13 @@ import net.corda.core.crypto.Crypto import net.corda.core.crypto.composite.CompositeKey.NodeAndWeight import net.corda.core.crypto.keys import net.corda.core.crypto.provider.CordaObjectIdentifier -import net.corda.core.crypto.toSHA256Bytes import net.corda.core.crypto.toStringShort +import net.corda.core.utilities.exactAdd import net.corda.core.serialization.CordaSerializable +import net.corda.core.utilities.sequence import org.bouncycastle.asn1.* import org.bouncycastle.asn1.x509.AlgorithmIdentifier import org.bouncycastle.asn1.x509.SubjectPublicKeyInfo -import java.nio.ByteBuffer import java.security.PublicKey import java.util.* @@ -59,7 +59,7 @@ class CompositeKey private constructor(val threshold: Int, children: List = children.sorted() init { // TODO: replace with the more extensive, but slower, checkValidity() test. @@ -127,7 +127,7 @@ class CompositeKey private constructor(val threshold: Int, children: List 0) { "Non-positive weight: $weight detected." } - sum = Math.addExact(sum, weight) // Add and check for integer overflow. + sum = sum exactAdd weight // Add and check for integer overflow. } return sum } @@ -145,7 +145,7 @@ class CompositeKey private constructor(val threshold: Int, children: List() return if (verifyKey.isFulfilledBy(sig.sigs.map { it.by })) { - val clearData = buffer.toByteArray() + val clearData = SecureHash.SHA256(buffer.toByteArray()) sig.sigs.all { it.isValid(clearData) } } else { false diff --git a/core/src/main/kotlin/net/corda/core/crypto/composite/CompositeSignaturesWithKeys.kt b/core/src/main/kotlin/net/corda/core/crypto/composite/CompositeSignaturesWithKeys.kt index 5a69484ffa..f9c8da8ab3 100644 --- a/core/src/main/kotlin/net/corda/core/crypto/composite/CompositeSignaturesWithKeys.kt +++ b/core/src/main/kotlin/net/corda/core/crypto/composite/CompositeSignaturesWithKeys.kt @@ -1,14 +1,14 @@ package net.corda.core.crypto.composite -import net.corda.core.crypto.DigitalSignature +import net.corda.core.crypto.TransactionSignature import net.corda.core.serialization.CordaSerializable /** * Custom class for holding signature data. This exists for later extension work to provide a standardised cross-platform - * serialization format (i.e. not Kryo). + * serialization format. */ @CordaSerializable -data class CompositeSignaturesWithKeys(val sigs: List) { +data class CompositeSignaturesWithKeys(val sigs: List) { companion object { val EMPTY = CompositeSignaturesWithKeys(emptyList()) } diff --git a/core/src/main/kotlin/net/corda/core/crypto/testing/DummyKeys.kt b/core/src/main/kotlin/net/corda/core/crypto/testing/DummyKeys.kt deleted file mode 100644 index 8b699ef38d..0000000000 --- a/core/src/main/kotlin/net/corda/core/crypto/testing/DummyKeys.kt +++ /dev/null @@ -1,35 +0,0 @@ -package net.corda.core.crypto.testing - -import net.corda.core.crypto.DigitalSignature -import net.corda.core.identity.AnonymousParty -import net.corda.core.serialization.CordaSerializable -import java.math.BigInteger -import java.security.PublicKey - -@CordaSerializable -object NullPublicKey : PublicKey, Comparable { - override fun getAlgorithm() = "NULL" - override fun getEncoded() = byteArrayOf(0) - override fun getFormat() = "NULL" - override fun compareTo(other: PublicKey): Int = if (other == NullPublicKey) 0 else -1 - override fun toString() = "NULL_KEY" -} - -val NULL_PARTY = AnonymousParty(NullPublicKey) - -// TODO: Clean up this duplication between Null and Dummy public key -@CordaSerializable -@Deprecated("Has encoding format problems, consider entropyToKeyPair() instead") -class DummyPublicKey(val s: String) : PublicKey, Comparable { - override fun getAlgorithm() = "DUMMY" - override fun getEncoded() = s.toByteArray() - override fun getFormat() = "ASN.1" - override fun compareTo(other: PublicKey): Int = BigInteger(encoded).compareTo(BigInteger(other.encoded)) - override fun equals(other: Any?) = other is DummyPublicKey && other.s == s - override fun hashCode(): Int = s.hashCode() - override fun toString() = "PUBKEY[$s]" -} - -/** A signature with a key and value of zero. Useful when you want a signature object that you know won't ever be used. */ -@CordaSerializable -object NullSignature : DigitalSignature.WithKey(NullPublicKey, ByteArray(32)) \ No newline at end of file diff --git a/core/src/main/kotlin/net/corda/core/crypto/testing/NullKeys.kt b/core/src/main/kotlin/net/corda/core/crypto/testing/NullKeys.kt new file mode 100644 index 0000000000..de2bc3ee42 --- /dev/null +++ b/core/src/main/kotlin/net/corda/core/crypto/testing/NullKeys.kt @@ -0,0 +1,21 @@ +package net.corda.core.crypto.testing + +import net.corda.core.crypto.SignatureMetadata +import net.corda.core.crypto.TransactionSignature +import net.corda.core.identity.AnonymousParty +import net.corda.core.serialization.CordaSerializable +import java.security.PublicKey + +@CordaSerializable +object NullPublicKey : PublicKey, Comparable { + override fun getAlgorithm() = "NULL" + override fun getEncoded() = byteArrayOf(0) + override fun getFormat() = "NULL" + override fun compareTo(other: PublicKey): Int = if (other == NullPublicKey) 0 else -1 + override fun toString() = "NULL_KEY" +} + +val NULL_PARTY = AnonymousParty(NullPublicKey) + +/** A signature with a key and value of zero. Useful when you want a signature object that you know won't ever be used. */ +val NULL_SIGNATURE = TransactionSignature(ByteArray(32), NullPublicKey, SignatureMetadata(1, -1)) \ No newline at end of file diff --git a/core/src/main/kotlin/net/corda/flows/AbstractStateReplacementFlow.kt b/core/src/main/kotlin/net/corda/core/flows/AbstractStateReplacementFlow.kt similarity index 74% rename from core/src/main/kotlin/net/corda/flows/AbstractStateReplacementFlow.kt rename to core/src/main/kotlin/net/corda/core/flows/AbstractStateReplacementFlow.kt index a035bd256f..c4a84da857 100644 --- a/core/src/main/kotlin/net/corda/flows/AbstractStateReplacementFlow.kt +++ b/core/src/main/kotlin/net/corda/core/flows/AbstractStateReplacementFlow.kt @@ -1,19 +1,14 @@ -package net.corda.flows +package net.corda.core.flows import co.paralleluniverse.fibers.Suspendable import net.corda.core.contracts.ContractState import net.corda.core.contracts.StateAndRef import net.corda.core.contracts.StateRef -import net.corda.core.crypto.DigitalSignature +import net.corda.core.crypto.TransactionSignature import net.corda.core.crypto.isFulfilledBy -import net.corda.core.flows.FlowException -import net.corda.core.flows.FlowLogic -import net.corda.core.identity.AbstractParty -import net.corda.core.identity.AnonymousParty import net.corda.core.identity.Party import net.corda.core.serialization.CordaSerializable import net.corda.core.transactions.SignedTransaction -import net.corda.core.transactions.WireTransaction import net.corda.core.utilities.ProgressTracker import net.corda.core.utilities.UntrustworthyData import net.corda.core.utilities.unwrap @@ -32,7 +27,7 @@ abstract class AbstractStateReplacementFlow { * @param M the type of a class representing proposed modification by the instigator. */ @CordaSerializable - data class Proposal(val stateRef: StateRef, val modification: M, val stx: SignedTransaction) + data class Proposal(val stateRef: StateRef, val modification: M) /** * The assembled transaction for upgrading a contract. @@ -56,7 +51,7 @@ abstract class AbstractStateReplacementFlow { abstract class Instigator( val originalState: StateAndRef, val modification: M, - override val progressTracker: ProgressTracker = tracker()) : FlowLogic>() { + override val progressTracker: ProgressTracker = Instigator.tracker()) : FlowLogic>() { companion object { object SIGNING : ProgressTracker.Step("Requesting signatures from other parties") object NOTARY : ProgressTracker.Step("Requesting notary signature") @@ -79,7 +74,16 @@ abstract class AbstractStateReplacementFlow { val finalTx = stx + signatures serviceHub.recordTransactions(finalTx) - return finalTx.tx.outRef(0) + + val newOutput = run { + if (stx.isNotaryChangeTransaction()) { + stx.resolveNotaryChangeTransaction(serviceHub).outRef(0) + } else { + stx.tx.outRef(0) + } + } + + return newOutput } /** @@ -91,7 +95,7 @@ abstract class AbstractStateReplacementFlow { abstract protected fun assembleTx(): UpgradeTx @Suspendable - private fun collectSignatures(participants: Iterable, stx: SignedTransaction): List { + private fun collectSignatures(participants: Iterable, stx: SignedTransaction): List { val parties = participants.map { val participantNode = serviceHub.networkMapCache.getNodeByLegalIdentityKey(it) ?: throw IllegalStateException("Participant $it to state $originalState not found on the network") @@ -109,10 +113,10 @@ abstract class AbstractStateReplacementFlow { } @Suspendable - private fun getParticipantSignature(party: Party, stx: SignedTransaction): DigitalSignature.WithKey { - val proposal = Proposal(originalState.ref, modification, stx) - val response = sendAndReceive(party, proposal) - return response.unwrap { + private fun getParticipantSignature(party: Party, stx: SignedTransaction): TransactionSignature { + val proposal = Proposal(originalState.ref, modification) + subFlow(SendTransactionFlow(party, stx)) + return sendAndReceive(party, proposal).unwrap { check(party.owningKey.isFulfilledBy(it.by)) { "Not signed by the required participant" } it.verify(stx.id) it @@ -120,7 +124,7 @@ abstract class AbstractStateReplacementFlow { } @Suspendable - private fun getNotarySignatures(stx: SignedTransaction): List { + private fun getNotarySignatures(stx: SignedTransaction): List { progressTracker.currentStep = NOTARY try { return subFlow(NotaryFlow.Client(stx)) @@ -133,7 +137,7 @@ abstract class AbstractStateReplacementFlow { // Type parameter should ideally be Unit but that prevents Java code from subclassing it (https://youtrack.jetbrains.com/issue/KT-15964). // We use Void? instead of Unit? as that's what you'd use in Java. abstract class Acceptor(val otherSide: Party, - override val progressTracker: ProgressTracker = tracker()) : FlowLogic() { + override val progressTracker: ProgressTracker = Acceptor.tracker()) : FlowLogic() { companion object { object VERIFYING : ProgressTracker.Step("Verifying state replacement proposal") object APPROVING : ProgressTracker.Step("State replacement approved") @@ -145,63 +149,61 @@ abstract class AbstractStateReplacementFlow { @Throws(StateReplacementException::class) override fun call(): Void? { progressTracker.currentStep = VERIFYING + // We expect stx to have insufficient signatures here + val stx = subFlow(ReceiveTransactionFlow(otherSide, checkSufficientSignatures = false)) + checkMySignatureRequired(stx) val maybeProposal: UntrustworthyData> = receive(otherSide) - val stx: SignedTransaction = maybeProposal.unwrap { - verifyProposal(it) - verifyTx(it.stx) - it.stx + maybeProposal.unwrap { + verifyProposal(stx, it) } approve(stx) return null } - @Suspendable - private fun verifyTx(stx: SignedTransaction) { - checkMySignatureRequired(stx.tx) - checkDependenciesValid(stx) - // We expect stx to have insufficient signatures, so we convert the WireTransaction to the LedgerTransaction - // here, thus bypassing the sufficient-signatures check. - stx.tx.toLedgerTransaction(serviceHub).verify() - } - @Suspendable private fun approve(stx: SignedTransaction) { progressTracker.currentStep = APPROVING val mySignature = sign(stx) - val swapSignatures = sendAndReceive>(otherSide, mySignature) + val swapSignatures = sendAndReceive>(otherSide, mySignature) - // TODO: This step should not be necessary, as signatures are re-checked in verifySignatures. + // TODO: This step should not be necessary, as signatures are re-checked in verifyRequiredSignatures. val allSignatures = swapSignatures.unwrap { signatures -> signatures.forEach { it.verify(stx.id) } signatures } val finalTx = stx + allSignatures - finalTx.verifySignatures() + if (finalTx.isNotaryChangeTransaction()) { + finalTx.resolveNotaryChangeTransaction(serviceHub).verifyRequiredSignatures() + } else { + finalTx.verifyRequiredSignatures() + } serviceHub.recordTransactions(finalTx) } /** - * Check the state change proposal to confirm that it's acceptable to this node. Rules for verification depend - * on the change proposed, and may further depend on the node itself (for example configuration). The - * proposal is returned if acceptable, otherwise a [StateReplacementException] is thrown. + * Check the state change proposal and the signed transaction to confirm that it's acceptable to this node. + * Rules for verification depend on the change proposed, and may further depend on the node itself (for example configuration). + * The proposal is returned if acceptable, otherwise a [StateReplacementException] is thrown. */ @Throws(StateReplacementException::class) - abstract protected fun verifyProposal(proposal: Proposal) + abstract protected fun verifyProposal(stx: SignedTransaction, proposal: Proposal) - private fun checkMySignatureRequired(tx: WireTransaction) { + private fun checkMySignatureRequired(stx: SignedTransaction) { // TODO: use keys from the keyManagementService instead val myKey = serviceHub.myInfo.legalIdentity.owningKey - require(myKey in tx.mustSign) { "Party is not a participant for any of the input states of transaction ${tx.id}" } + + val requiredKeys = if (stx.isNotaryChangeTransaction()) { + stx.resolveNotaryChangeTransaction(serviceHub).requiredSigningKeys + } else { + stx.tx.requiredSigningKeys + } + + require(myKey in requiredKeys) { "Party is not a participant for any of the input states of transaction ${stx.id}" } } - @Suspendable - private fun checkDependenciesValid(stx: SignedTransaction) { - subFlow(ResolveTransactionsFlow(stx.tx, otherSide)) - } - - private fun sign(stx: SignedTransaction): DigitalSignature.WithKey { + private fun sign(stx: SignedTransaction): TransactionSignature { return serviceHub.createSignature(stx) } } diff --git a/core/src/main/kotlin/net/corda/flows/BroadcastTransactionFlow.kt b/core/src/main/kotlin/net/corda/core/flows/BroadcastTransactionFlow.kt similarity index 68% rename from core/src/main/kotlin/net/corda/flows/BroadcastTransactionFlow.kt rename to core/src/main/kotlin/net/corda/core/flows/BroadcastTransactionFlow.kt index 140da37489..221200341d 100644 --- a/core/src/main/kotlin/net/corda/flows/BroadcastTransactionFlow.kt +++ b/core/src/main/kotlin/net/corda/core/flows/BroadcastTransactionFlow.kt @@ -1,11 +1,9 @@ -package net.corda.flows +package net.corda.core.flows import co.paralleluniverse.fibers.Suspendable -import net.corda.core.flows.FlowLogic -import net.corda.core.flows.InitiatingFlow import net.corda.core.identity.Party -import net.corda.core.serialization.CordaSerializable import net.corda.core.transactions.SignedTransaction +import net.corda.core.utilities.NonEmptySet /** * Notify the specified parties about a transaction. The remote peers will download this transaction and its @@ -18,17 +16,13 @@ import net.corda.core.transactions.SignedTransaction */ @InitiatingFlow class BroadcastTransactionFlow(val notarisedTransaction: SignedTransaction, - val participants: Set) : FlowLogic() { - @CordaSerializable - data class NotifyTxRequest(val tx: SignedTransaction) - + val participants: NonEmptySet) : FlowLogic() { @Suspendable override fun call() { // TODO: Messaging layer should handle this broadcast for us - val msg = NotifyTxRequest(notarisedTransaction) participants.filter { it != serviceHub.myInfo.legalIdentity }.forEach { participant -> - // This pops out the other side in NotifyTransactionHandler - send(participant, msg) + // SendTransactionFlow allows otherParty to access our data to resolve the transaction. + subFlow(SendTransactionFlow(participant, notarisedTransaction)) } } } diff --git a/core/src/main/kotlin/net/corda/flows/CollectSignaturesFlow.kt b/core/src/main/kotlin/net/corda/core/flows/CollectSignaturesFlow.kt similarity index 83% rename from core/src/main/kotlin/net/corda/flows/CollectSignaturesFlow.kt rename to core/src/main/kotlin/net/corda/core/flows/CollectSignaturesFlow.kt index 5bd80d9f98..6d9a2d1b8e 100644 --- a/core/src/main/kotlin/net/corda/flows/CollectSignaturesFlow.kt +++ b/core/src/main/kotlin/net/corda/core/flows/CollectSignaturesFlow.kt @@ -1,11 +1,9 @@ -package net.corda.flows +package net.corda.core.flows import co.paralleluniverse.fibers.Suspendable -import net.corda.core.crypto.DigitalSignature +import net.corda.core.crypto.TransactionSignature import net.corda.core.crypto.isFulfilledBy import net.corda.core.crypto.toBase58String -import net.corda.core.flows.FlowException -import net.corda.core.flows.FlowLogic import net.corda.core.identity.Party import net.corda.core.node.ServiceHub import net.corda.core.transactions.SignedTransaction @@ -19,7 +17,7 @@ import java.security.PublicKey * * You would typically use this flow after you have built a transaction with the TransactionBuilder and signed it with * your key pair. If there are additional signatures to collect then they can be collected using this flow. Signatures - * are collected based upon the [WireTransaction.mustSign] property which contains the union of all the PublicKeys + * are collected based upon the [WireTransaction.requiredSigningKeys] property which contains the union of all the PublicKeys * listed in the transaction's commands as well as a notary's public key, if required. This flow returns a * [SignedTransaction] which can then be passed to the [FinalityFlow] for notarisation. The other side of this flow is * the [SignTransactionFlow]. @@ -44,7 +42,7 @@ import java.security.PublicKey * * Example - issuing a multi-lateral agreement which requires N signatures: * - * val builder = TransactionType.General.Builder(notaryRef) + * val builder = TransactionBuilder(notaryRef) * val issueCommand = Command(Agreement.Commands.Issue(), state.participants) * * builder.withItems(state, issueCommand) @@ -62,7 +60,7 @@ import java.security.PublicKey // TODO: AbstractStateReplacementFlow needs updating to use this flow. // TODO: Update this flow to handle randomly generated keys when that works is complete. class CollectSignaturesFlow(val partiallySignedTx: SignedTransaction, - override val progressTracker: ProgressTracker = tracker()): FlowLogic() { + override val progressTracker: ProgressTracker = CollectSignaturesFlow.tracker()) : FlowLogic() { companion object { object COLLECTING : ProgressTracker.Step("Collecting signatures from counter-parties.") @@ -80,7 +78,7 @@ class CollectSignaturesFlow(val partiallySignedTx: SignedTransaction, // Usually just the Initiator and possibly an oracle would have signed at this point. val myKey = serviceHub.myInfo.legalIdentity.owningKey val signed = partiallySignedTx.sigs.map { it.by } - val notSigned = partiallySignedTx.tx.mustSign - signed + val notSigned = partiallySignedTx.tx.requiredSigningKeys - signed // One of the signatures collected so far MUST be from the initiator of this flow. require(partiallySignedTx.sigs.any { it.by == myKey }) { @@ -88,7 +86,7 @@ class CollectSignaturesFlow(val partiallySignedTx: SignedTransaction, } // The signatures must be valid and the transaction must be valid. - partiallySignedTx.verifySignatures(*notSigned.toTypedArray()) + partiallySignedTx.verifySignaturesExcept(*notSigned.toTypedArray()) partiallySignedTx.tx.toLedgerTransaction(serviceHub).verify() // Determine who still needs to sign. @@ -107,7 +105,7 @@ class CollectSignaturesFlow(val partiallySignedTx: SignedTransaction, // Verify all but the notary's signature if the transaction requires a notary, otherwise verify all signatures. progressTracker.currentStep = VERIFYING - if (notaryKey != null) stx.verifySignatures(notaryKey) else stx.verifySignatures() + if (notaryKey != null) stx.verifySignaturesExcept(notaryKey) else stx.verifyRequiredSignatures() return stx } @@ -115,7 +113,7 @@ class CollectSignaturesFlow(val partiallySignedTx: SignedTransaction, /** * Lookup the [Party] object for each [PublicKey] using the [ServiceHub.networkMapCache]. */ - @Suspendable private fun keysToParties(keys: List): List = keys.map { + @Suspendable private fun keysToParties(keys: Collection): List = keys.map { // TODO: Revisit when IdentityService supports resolution of a (possibly random) public key to a legal identity key. val partyNode = serviceHub.networkMapCache.getNodeByLegalIdentityKey(it) ?: throw IllegalStateException("Party ${it.toBase58String()} not found on the network.") @@ -126,8 +124,10 @@ class CollectSignaturesFlow(val partiallySignedTx: SignedTransaction, /** * Get and check the required signature. */ - @Suspendable private fun collectSignature(counterparty: Party): DigitalSignature.WithKey { - return sendAndReceive(counterparty, partiallySignedTx).unwrap { + @Suspendable private fun collectSignature(counterparty: Party): TransactionSignature { + // SendTransactionFlow allows otherParty to access our data to resolve the transaction. + subFlow(SendTransactionFlow(counterparty, partiallySignedTx)) + return receive(counterparty).unwrap { require(counterparty.owningKey.isFulfilledBy(it.by)) { "Not signed by the required Party." } it } @@ -175,7 +175,7 @@ class CollectSignaturesFlow(val partiallySignedTx: SignedTransaction, * @param otherParty The counter-party which is providing you a transaction to sign. */ abstract class SignTransactionFlow(val otherParty: Party, - override val progressTracker: ProgressTracker = tracker()) : FlowLogic() { + override val progressTracker: ProgressTracker = SignTransactionFlow.tracker()) : FlowLogic() { companion object { object RECEIVING : ProgressTracker.Step("Receiving transaction proposal for signing.") @@ -187,35 +187,30 @@ abstract class SignTransactionFlow(val otherParty: Party, @Suspendable override fun call(): SignedTransaction { progressTracker.currentStep = RECEIVING - val checkedProposal = receive(otherParty).unwrap { proposal -> - progressTracker.currentStep = VERIFYING - // Check that the Responder actually needs to sign. - checkMySignatureRequired(proposal) - // Check the signatures which have already been provided. Usually the Initiators and possibly an Oracle's. - checkSignatures(proposal) - // Resolve dependencies and verify, pass in the WireTransaction as we don't have all signatures. - subFlow(ResolveTransactionsFlow(proposal.tx, otherParty)) - proposal.tx.toLedgerTransaction(serviceHub).verify() - // Perform some custom verification over the transaction. - try { - checkTransaction(proposal) - } catch(e: Exception) { - if (e is IllegalStateException || e is IllegalArgumentException || e is AssertionError) - throw FlowException(e) - else - throw e - } - // All good. Unwrap the proposal. - proposal + // Receive transaction and resolve dependencies, check sufficient signatures is disabled as we don't have all signatures. + val stx = subFlow(ReceiveTransactionFlow(otherParty, checkSufficientSignatures = false)) + progressTracker.currentStep = VERIFYING + // Check that the Responder actually needs to sign. + checkMySignatureRequired(stx) + // Check the signatures which have already been provided. Usually the Initiators and possibly an Oracle's. + checkSignatures(stx) + stx.tx.toLedgerTransaction(serviceHub).verify() + // Perform some custom verification over the transaction. + try { + checkTransaction(stx) + } catch(e: Exception) { + if (e is IllegalStateException || e is IllegalArgumentException || e is AssertionError) + throw FlowException(e) + else + throw e } - // Sign and send back our signature to the Initiator. progressTracker.currentStep = SIGNING - val mySignature = serviceHub.createSignature(checkedProposal) + val mySignature = serviceHub.createSignature(stx) send(otherParty, mySignature) // Return the fully signed transaction once it has been committed. - return waitForLedgerCommit(checkedProposal.id) + return waitForLedgerCommit(stx.id) } @Suspendable private fun checkSignatures(stx: SignedTransaction) { @@ -223,9 +218,9 @@ abstract class SignTransactionFlow(val otherParty: Party, "The Initiator of CollectSignaturesFlow must have signed the transaction." } val signed = stx.sigs.map { it.by } - val allSigners = stx.tx.mustSign + val allSigners = stx.tx.requiredSigningKeys val notSigned = allSigners - signed - stx.verifySignatures(*notSigned.toTypedArray()) + stx.verifySignaturesExcept(*notSigned.toTypedArray()) } /** @@ -253,7 +248,7 @@ abstract class SignTransactionFlow(val otherParty: Party, @Suspendable private fun checkMySignatureRequired(stx: SignedTransaction) { // TODO: Revisit when key management is properly fleshed out. val myKey = serviceHub.myInfo.legalIdentity.owningKey - require(myKey in stx.tx.mustSign) { + require(myKey in stx.tx.requiredSigningKeys) { "Party is not a participant for any of the input states of transaction ${stx.id}" } } diff --git a/core/src/main/kotlin/net/corda/flows/ContractUpgradeFlow.kt b/core/src/main/kotlin/net/corda/core/flows/ContractUpgradeFlow.kt similarity index 82% rename from core/src/main/kotlin/net/corda/flows/ContractUpgradeFlow.kt rename to core/src/main/kotlin/net/corda/core/flows/ContractUpgradeFlow.kt index ce785ed6f2..d6e7024811 100644 --- a/core/src/main/kotlin/net/corda/flows/ContractUpgradeFlow.kt +++ b/core/src/main/kotlin/net/corda/core/flows/ContractUpgradeFlow.kt @@ -1,10 +1,7 @@ -package net.corda.flows +package net.corda.core.flows import net.corda.core.contracts.* -import net.corda.core.flows.InitiatingFlow -import net.corda.core.flows.StartableByRPC -import net.corda.core.identity.AbstractParty -import net.corda.core.transactions.SignedTransaction +import net.corda.core.transactions.LedgerTransaction import net.corda.core.transactions.TransactionBuilder import java.security.PublicKey @@ -25,14 +22,17 @@ class ContractUpgradeFlow().single()) } @JvmStatic - fun verify(input: ContractState, output: ContractState, commandData: Command) { - val command = commandData.value as UpgradeCommand + fun verify(input: ContractState, output: ContractState, commandData: Command) { + val command = commandData.value val participantKeys: Set = input.participants.map { it.owningKey }.toSet() val keysThatSigned: Set = commandData.signers.toSet() @Suppress("UNCHECKED_CAST") @@ -47,19 +47,22 @@ class ContractUpgradeFlow assembleBareTx( stateRef: StateAndRef, - upgradedContractClass: Class> + upgradedContractClass: Class>, + privacySalt: PrivacySalt ): TransactionBuilder { val contractUpgrade = upgradedContractClass.newInstance() - return TransactionType.General.Builder(stateRef.state.notary) + return TransactionBuilder(stateRef.state.notary) .withItems( stateRef, contractUpgrade.upgrade(stateRef.state.data), - Command(UpgradeCommand(upgradedContractClass), stateRef.state.data.participants.map { it.owningKey })) + Command(UpgradeCommand(upgradedContractClass), stateRef.state.data.participants.map { it.owningKey }), + privacySalt + ) } } override fun assembleTx(): AbstractStateReplacementFlow.UpgradeTx { - val baseTx = assembleBareTx(originalState, modification) + val baseTx = assembleBareTx(originalState, modification, PrivacySalt()) val participantKeys = originalState.state.data.participants.map { it.owningKey }.toSet() // TODO: We need a much faster way of finding our key in the transaction val myKey = serviceHub.keyManagementService.filterMyKeys(participantKeys).single() diff --git a/core/src/main/kotlin/net/corda/flows/FinalityFlow.kt b/core/src/main/kotlin/net/corda/core/flows/FinalityFlow.kt similarity index 60% rename from core/src/main/kotlin/net/corda/flows/FinalityFlow.kt rename to core/src/main/kotlin/net/corda/core/flows/FinalityFlow.kt index dfc3c3c20d..883deac3ac 100644 --- a/core/src/main/kotlin/net/corda/flows/FinalityFlow.kt +++ b/core/src/main/kotlin/net/corda/core/flows/FinalityFlow.kt @@ -1,16 +1,18 @@ -package net.corda.flows +package net.corda.core.flows import co.paralleluniverse.fibers.Suspendable import net.corda.core.contracts.ContractState import net.corda.core.contracts.StateRef import net.corda.core.contracts.TransactionState import net.corda.core.crypto.isFulfilledBy -import net.corda.core.flows.FlowLogic +import net.corda.core.identity.AbstractParty import net.corda.core.identity.Party +import net.corda.core.internal.ResolveTransactionsFlow import net.corda.core.node.ServiceHub import net.corda.core.transactions.LedgerTransaction import net.corda.core.transactions.SignedTransaction import net.corda.core.utilities.ProgressTracker +import net.corda.core.utilities.toNonEmptySet /** * Verifies the given transactions, then sends them to the named notary. If the notary agrees that the transactions @@ -32,9 +34,10 @@ import net.corda.core.utilities.ProgressTracker * @param transactions What to commit. * @param extraRecipients A list of additional participants to inform of the transaction. */ -class FinalityFlow(val transactions: Iterable, +open class FinalityFlow(val transactions: Iterable, val extraRecipients: Set, override val progressTracker: ProgressTracker) : FlowLogic>() { + val extraParticipants: Set = extraRecipients.map { it -> Participant(it, it) }.toSet() constructor(transaction: SignedTransaction, extraParticipants: Set) : this(listOf(transaction), extraParticipants, tracker()) constructor(transaction: SignedTransaction) : this(listOf(transaction), emptySet(), tracker()) constructor(transaction: SignedTransaction, progressTracker: ProgressTracker) : this(listOf(transaction), emptySet(), progressTracker) @@ -50,6 +53,9 @@ class FinalityFlow(val transactions: Iterable, fun tracker() = ProgressTracker(NOTARISING, BROADCASTING) } + open protected val me + get() = serviceHub.myInfo.legalIdentity + @Suspendable @Throws(NotaryException::class) override fun call(): List { @@ -59,31 +65,43 @@ class FinalityFlow(val transactions: Iterable, // Lookup the resolved transactions and use them to map each signed transaction to the list of participants. // Then send to the notary if needed, record locally and distribute. progressTracker.currentStep = NOTARISING - val notarisedTxns = notariseAndRecord(lookupParties(resolveDependenciesOf(transactions))) + val notarisedTxns: List>> = resolveDependenciesOf(transactions) + .map { (stx, ltx) -> Pair(notariseAndRecord(stx), lookupParties(ltx)) } // Each transaction has its own set of recipients, but extra recipients get them all. progressTracker.currentStep = BROADCASTING - val me = serviceHub.myInfo.legalIdentity for ((stx, parties) in notarisedTxns) { - subFlow(BroadcastTransactionFlow(stx, parties + extraRecipients - me)) + broadcastTransaction(stx, (parties + extraParticipants).filter { it.wellKnown != me }) } return notarisedTxns.map { it.first } } - // TODO: API: Make some of these protected? + /** + * Broadcast a transaction to the participants. By default calls [BroadcastTransactionFlow], however can be + * overridden for more complex transaction delivery protocols (for example where not all parties know each other). + * This implementation will filter out any participants for whom there is no well known identity. + * + * @param participants the participants to send the transaction to. This is expected to include extra participants + * and exclude the local node. + */ + @Suspendable + open protected fun broadcastTransaction(stx: SignedTransaction, participants: Iterable) { + val wellKnownParticipants = participants.map { it.wellKnown }.filterNotNull() + if (wellKnownParticipants.isNotEmpty()) { + subFlow(BroadcastTransactionFlow(stx, wellKnownParticipants.toNonEmptySet())) + } + } @Suspendable - private fun notariseAndRecord(stxnsAndParties: List>>): List>> { - return stxnsAndParties.map { (stx, parties) -> - val notarised = if (needsNotarySignature(stx)) { - val notarySignatures = subFlow(NotaryFlow.Client(stx)) - stx + notarySignatures - } else { - stx - } - serviceHub.recordTransactions(notarised) - Pair(notarised, parties) + private fun notariseAndRecord(stx: SignedTransaction): SignedTransaction { + val notarised = if (needsNotarySignature(stx)) { + val notarySignatures = subFlow(NotaryFlow.Client(stx)) + stx + notarySignatures + } else { + stx } + serviceHub.recordTransactions(notarised) + return notarised } private fun needsNotarySignature(stx: SignedTransaction): Boolean { @@ -99,14 +117,31 @@ class FinalityFlow(val transactions: Iterable, return !(notaryKey?.isFulfilledBy(signers) ?: false) } - private fun lookupParties(ltxns: List>): List>> { - return ltxns.map { (stx, ltx) -> - // Calculate who is meant to see the results based on the participants involved. - val keys = ltx.outputs.flatMap { it.data.participants } + ltx.inputs.flatMap { it.state.data.participants } - // TODO: Is it safe to drop participants we don't know how to contact? Does not knowing how to contact them count as a reason to fail? - val parties = keys.mapNotNull { serviceHub.identityService.partyFromAnonymous(it) }.toSet() - Pair(stx, parties) - } + /** + * Resolve the parties involved in a transaction. + * + * @return the set of participants and their resolved well known identities (where known). + */ + open protected fun lookupParties(ltx: LedgerTransaction): Set { + // Calculate who is meant to see the results based on the participants involved. + return extractParticipants(ltx) + .map(this::partyFromAnonymous) + .toSet() + } + + /** + * Helper function to extract all participants from a ledger transaction. Intended to help implement [lookupParties] + * overriding functions. + */ + protected fun extractParticipants(ltx: LedgerTransaction): List { + return ltx.outputStates.flatMap { it.participants } + ltx.inputStates.flatMap { it.participants } + } + + /** + * Helper function which wraps [IdentityService.partyFromAnonymous] so it can be called as a lambda function. + */ + protected fun partyFromAnonymous(anon: AbstractParty): Participant { + return Participant(anon, serviceHub.identityService.partyFromAnonymous(anon)) } private fun resolveDependenciesOf(signedTransactions: Iterable): List> { @@ -125,10 +160,12 @@ class FinalityFlow(val transactions: Iterable, return sorted.map { stx -> val notary = stx.tx.notary // The notary signature(s) are allowed to be missing but no others. - val wtx = if (notary != null) stx.verifySignatures(notary.owningKey) else stx.verifySignatures() - val ltx = wtx.toLedgerTransaction(augmentedLookup) + if (notary != null) stx.verifySignaturesExcept(notary.owningKey) else stx.verifyRequiredSignatures() + val ltx = stx.toLedgerTransaction(augmentedLookup, false) ltx.verify() stx to ltx } } + + data class Participant(val participant: AbstractParty, val wellKnown: Party?) } diff --git a/core/src/main/kotlin/net/corda/core/flows/FlowException.kt b/core/src/main/kotlin/net/corda/core/flows/FlowException.kt index e527f22c55..4e20320883 100644 --- a/core/src/main/kotlin/net/corda/core/flows/FlowException.kt +++ b/core/src/main/kotlin/net/corda/core/flows/FlowException.kt @@ -25,6 +25,6 @@ open class FlowException(message: String?, cause: Throwable?) : CordaException(m * that we were not expecting), or the other side had an internal error, or the other side terminated when we * were waiting for a response. */ -class FlowSessionException(message: String?, cause: Throwable?) : CordaRuntimeException(message, cause) { +class UnexpectedFlowEndException(message: String?, cause: Throwable?) : CordaRuntimeException(message, cause) { constructor(msg: String) : this(msg, null) } \ No newline at end of file diff --git a/core/src/main/kotlin/net/corda/core/flows/FlowLogic.kt b/core/src/main/kotlin/net/corda/core/flows/FlowLogic.kt index 4583d14a62..6a8d7f68f9 100644 --- a/core/src/main/kotlin/net/corda/core/flows/FlowLogic.kt +++ b/core/src/main/kotlin/net/corda/core/flows/FlowLogic.kt @@ -4,8 +4,10 @@ import co.paralleluniverse.fibers.Suspendable import net.corda.core.crypto.SecureHash import net.corda.core.identity.Party import net.corda.core.internal.FlowStateMachine +import net.corda.core.internal.abbreviate import net.corda.core.messaging.DataFeed import net.corda.core.node.ServiceHub +import net.corda.core.serialization.CordaSerializable import net.corda.core.transactions.SignedTransaction import net.corda.core.utilities.ProgressTracker import net.corda.core.utilities.UntrustworthyData @@ -51,10 +53,15 @@ abstract class FlowLogic { */ val serviceHub: ServiceHub get() = stateMachine.serviceHub - @Deprecated("This is no longer used and will be removed in a future release. If you are using this to communicate " + - "with the same party but for two different message streams, then the correct way of doing that is to use sub-flows", - level = DeprecationLevel.ERROR) - open fun getCounterpartyMarker(party: Party): Class<*> = javaClass + /** + * Returns a [FlowContext] object describing the flow [otherParty] is using. With [FlowContext.flowVersion] it + * provides the necessary information needed for the evolution of flows and enabling backwards compatibility. + * + * This method can be called before any send or receive has been done with [otherParty]. In such a case this will force + * them to start their flow. + */ + @Suspendable + fun getFlowContext(otherParty: Party): FlowContext = stateMachine.getFlowContext(otherParty, flowUsedForSessions) /** * Serializes and queues the given [payload] object for sending to the [otherParty]. Suspends until a response @@ -89,11 +96,6 @@ abstract class FlowLogic { return stateMachine.sendAndReceive(receiveType, otherParty, payload, flowUsedForSessions) } - /** @see sendAndReceiveWithRetry */ - internal inline fun sendAndReceiveWithRetry(otherParty: Party, payload: Any): UntrustworthyData { - return sendAndReceiveWithRetry(R::class.java, otherParty, payload) - } - /** * Similar to [sendAndReceive] but also instructs the `payload` to be redelivered until the expected message is received. * @@ -103,9 +105,8 @@ abstract class FlowLogic { * oracle services. If one or more nodes in the service cluster go down mid-session, the message will be redelivered * to a different one, so there is no need to wait until the initial node comes back up to obtain a response. */ - @Suspendable - internal open fun sendAndReceiveWithRetry(receiveType: Class, otherParty: Party, payload: Any): UntrustworthyData { - return stateMachine.sendAndReceive(receiveType, otherParty, payload, flowUsedForSessions, true) + internal inline fun sendAndReceiveWithRetry(otherParty: Party, payload: Any): UntrustworthyData { + return stateMachine.sendAndReceive(R::class.java, otherParty, payload, flowUsedForSessions, true) } /** @@ -139,7 +140,7 @@ abstract class FlowLogic { * network's event horizon time. */ @Suspendable - open fun send(otherParty: Party, payload: Any) = stateMachine.send(otherParty, payload, flowUsedForSessions) + open fun send(otherParty: Party, payload: Any): Unit = stateMachine.send(otherParty, payload, flowUsedForSessions) /** * Invokes the given subflow. This function returns once the subflow completes successfully with the result @@ -163,7 +164,7 @@ abstract class FlowLogic { } logger.debug { "Calling subflow: $subLogic" } val result = subLogic.call() - logger.debug { "Subflow finished with result $result" } + logger.debug { "Subflow finished with result ${result.toString().abbreviate(300)}" } // It's easy to forget this when writing flows so we just step it to the DONE state when it completes. subLogic.progressTracker?.currentStep = ProgressTracker.DONE return result @@ -180,7 +181,9 @@ abstract class FlowLogic { * @param extraAuditData in the audit log for this permission check these extra key value pairs will be recorded. */ @Throws(FlowException::class) - fun checkFlowPermission(permissionName: String, extraAuditData: Map) = stateMachine.checkFlowPermission(permissionName, extraAuditData) + fun checkFlowPermission(permissionName: String, extraAuditData: Map) { + stateMachine.checkFlowPermission(permissionName, extraAuditData) + } /** @@ -189,7 +192,9 @@ abstract class FlowLogic { * @param comment a general human readable summary of the event. * @param extraAuditData in the audit log for this permission check these extra key value pairs will be recorded. */ - fun recordAuditEvent(eventType: String, comment: String, extraAuditData: Map) = stateMachine.recordAuditEvent(eventType, comment, extraAuditData) + fun recordAuditEvent(eventType: String, comment: String, extraAuditData: Map) { + stateMachine.recordAuditEvent(eventType, comment, extraAuditData) + } /** * Override this to provide a [ProgressTracker]. If one is provided and stepped, the framework will do something @@ -230,6 +235,29 @@ abstract class FlowLogic { @Suspendable fun waitForLedgerCommit(hash: SecureHash): SignedTransaction = stateMachine.waitForLedgerCommit(hash, this) + /** + * Returns a shallow copy of the Quasar stack frames at the time of call to [flowStackSnapshot]. Use this to inspect + * what objects would be serialised at the time of call to a suspending action (e.g. send/receive). + * Note: This logic is only available during tests and is not meant to be used during the production deployment. + * Therefore the default implementationdoes nothing. + */ + @Suspendable + fun flowStackSnapshot(): FlowStackSnapshot? = stateMachine.flowStackSnapshot(this::class.java) + + /** + * Persists a shallow copy of the Quasar stack frames at the time of call to [persistFlowStackSnapshot]. + * Use this to track the monitor evolution of the quasar stack values during the flow execution. + * The flow stack snapshot is stored in a file located in {baseDir}/flowStackSnapshots/YYYY-MM-DD/{flowId}/ + * where baseDir is the node running directory and flowId is the flow unique identifier generated by the platform. + * + * Note: With respect to the [flowStackSnapshot], the snapshot being persisted by this method is partial, + * meaning that only flow relevant traces and local variables are persisted. + * Also, this logic is only available during tests and is not meant to be used during the production deployment. + * Therefore the default implementation does nothing. + */ + @Suspendable + fun persistFlowStackSnapshot(): Unit = stateMachine.persistFlowStackSnapshot(this::class.java) + //////////////////////////////////////////////////////////////////////////////////////////////////////////// private var _stateMachine: FlowStateMachine<*>? = null @@ -261,3 +289,20 @@ abstract class FlowLogic { } } } + +/** + * Version and name of the CorDapp hosting the other side of the flow. + */ +@CordaSerializable +data class FlowContext( + /** + * The integer flow version the other side is using. + * @see InitiatingFlow + */ + val flowVersion: Int, + /** + * Name of the CorDapp jar hosting the flow, without the .jar extension. It will include a unique identifier + * to deduplicate it from other releases of the same CorDapp, typically a version string. See the + * [CorDapp JAR format](https://docs.corda.net/cordapp-build-systems.html#cordapp-jar-format) for more details. + */ + val appName: String) diff --git a/core/src/main/kotlin/net/corda/core/flows/FlowStackSnapshot.kt b/core/src/main/kotlin/net/corda/core/flows/FlowStackSnapshot.kt new file mode 100644 index 0000000000..c1866adcc8 --- /dev/null +++ b/core/src/main/kotlin/net/corda/core/flows/FlowStackSnapshot.kt @@ -0,0 +1,73 @@ +package net.corda.core.flows + +import net.corda.core.utilities.loggerFor +import java.nio.file.Path +import java.util.* + +interface FlowStackSnapshotFactory { + private object Holder { + val INSTANCE: FlowStackSnapshotFactory + + init { + val serviceFactory = ServiceLoader.load(FlowStackSnapshotFactory::class.java).singleOrNull() + INSTANCE = serviceFactory ?: FlowStackSnapshotDefaultFactory() + } + } + + companion object { + val instance: FlowStackSnapshotFactory by lazy { Holder.INSTANCE } + } + + /** + * Returns flow stack data snapshot extracted from Quasar stack. + * It is designed to be used in the debug mode of the flow execution. + * Note. This logic is only available during tests and is not meant to be used during the production deployment. + * Therefore the default implementation does nothing. + */ + fun getFlowStackSnapshot(flowClass: Class<*>): FlowStackSnapshot? + + /** Stores flow stack snapshot as a json file. The stored shapshot is only partial and consists + * only data (i.e. stack traces and local variables values) relevant to the flow. It does not + * persist corda internal data (e.g. FlowStateMachine). Instead it uses [StackFrameDataToken] to indicate + * the class of the element on the stack. + * The flow stack snapshot is stored in a file located in + * {baseDir}/flowStackSnapshots/YYYY-MM-DD/{flowId}/ + * where baseDir is the node running directory and flowId is the flow unique identifier generated by the platform. + * Note. This logic is only available during tests and is not meant to be used during the production deployment. + * Therefore the default implementation does nothing. + */ + fun persistAsJsonFile(flowClass: Class<*>, baseDir: Path, flowId: String): Unit +} + +private class FlowStackSnapshotDefaultFactory : FlowStackSnapshotFactory { + val log = loggerFor() + + override fun getFlowStackSnapshot(flowClass: Class<*>): FlowStackSnapshot? { + log.warn("Flow stack snapshot are not supposed to be used in a production deployment") + return null + } + + override fun persistAsJsonFile(flowClass: Class<*>, baseDir: Path, flowId: String) { + log.warn("Flow stack snapshot are not supposed to be used in a production deployment") + } +} + +/** + * Main data object representing snapshot of the flow stack, extracted from the Quasar stack. + */ +data class FlowStackSnapshot constructor( + val timestamp: Long = System.currentTimeMillis(), + val flowClass: Class<*>? = null, + val stackFrames: List = listOf() +) { + data class Frame( + val stackTraceElement: StackTraceElement? = null, // This should be the call that *pushed* the frame of [objects] + val stackObjects: List = listOf() + ) +} + +/** + * Token class, used to indicate stack presence of the corda internal data. Since this data is of no use for + * a CordApp developer, it is skipped from serialisation and its presence is only marked by this token. + */ +data class StackFrameDataToken(val className: String) \ No newline at end of file diff --git a/core/src/main/kotlin/net/corda/core/flows/ManualFinalityFlow.kt b/core/src/main/kotlin/net/corda/core/flows/ManualFinalityFlow.kt new file mode 100644 index 0000000000..f91b38b85f --- /dev/null +++ b/core/src/main/kotlin/net/corda/core/flows/ManualFinalityFlow.kt @@ -0,0 +1,20 @@ +package net.corda.core.flows + +import net.corda.core.identity.Party +import net.corda.core.transactions.LedgerTransaction +import net.corda.core.transactions.SignedTransaction +import net.corda.core.utilities.ProgressTracker + +/** + * Alternative finality flow which only does not attempt to take participants from the transaction, but instead all + * participating parties must be provided manually. + * + * @param transactions What to commit. + * @param extraRecipients A list of additional participants to inform of the transaction. + */ +class ManualFinalityFlow(transactions: Iterable, + recipients: Set, + progressTracker: ProgressTracker) : FinalityFlow(transactions, recipients, progressTracker) { + constructor(transaction: SignedTransaction, extraParticipants: Set) : this(listOf(transaction), extraParticipants, tracker()) + override fun lookupParties(ltx: LedgerTransaction): Set = emptySet() +} \ No newline at end of file diff --git a/core/src/main/kotlin/net/corda/core/flows/NotaryChangeFlow.kt b/core/src/main/kotlin/net/corda/core/flows/NotaryChangeFlow.kt new file mode 100644 index 0000000000..a861ad8f84 --- /dev/null +++ b/core/src/main/kotlin/net/corda/core/flows/NotaryChangeFlow.kt @@ -0,0 +1,59 @@ +package net.corda.core.flows + +import net.corda.core.contracts.ContractState +import net.corda.core.contracts.StateAndRef +import net.corda.core.contracts.StateRef +import net.corda.core.crypto.Crypto +import net.corda.core.crypto.SignableData +import net.corda.core.crypto.SignatureMetadata +import net.corda.core.identity.Party +import net.corda.core.transactions.NotaryChangeWireTransaction +import net.corda.core.transactions.SignedTransaction +import net.corda.core.utilities.ProgressTracker + +/** + * A flow to be used for changing a state's Notary. This is required since all input states to a transaction + * must point to the same notary. + * + * This assembles the transaction for notary replacement and sends out change proposals to all participants + * of that state. If participants agree to the proposed change, they each sign the transaction. + * Finally, the transaction containing all signatures is sent back to each participant so they can record it and + * use the new updated state for future transactions. + */ +@InitiatingFlow +class NotaryChangeFlow( + originalState: StateAndRef, + newNotary: Party, + progressTracker: ProgressTracker = AbstractStateReplacementFlow.Instigator.tracker()) + : AbstractStateReplacementFlow.Instigator(originalState, newNotary, progressTracker) { + + override fun assembleTx(): AbstractStateReplacementFlow.UpgradeTx { + val inputs = resolveEncumbrances(originalState) + + val tx = NotaryChangeWireTransaction( + inputs.map { it.ref }, + originalState.state.notary, + modification + ) + + val participantKeys = inputs.flatMap { it.state.data.participants }.map { it.owningKey }.toSet() + // TODO: We need a much faster way of finding our key in the transaction + val myKey = serviceHub.keyManagementService.filterMyKeys(participantKeys).single() + val signableData = SignableData(tx.id, SignatureMetadata(serviceHub.myInfo.platformVersion, Crypto.findSignatureScheme(myKey).schemeNumberID)) + val mySignature = serviceHub.keyManagementService.sign(signableData, myKey) + val stx = SignedTransaction(tx, listOf(mySignature)) + + return AbstractStateReplacementFlow.UpgradeTx(stx, participantKeys, myKey) + } + + /** Resolves the encumbrance state chain for the given [state] */ + private fun resolveEncumbrances(state: StateAndRef): List> { + val states = mutableListOf(state) + while (states.last().state.encumbrance != null) { + val encumbranceStateRef = StateRef(states.last().ref.txhash, states.last().state.encumbrance!!) + val encumbranceState = serviceHub.toStateAndRef(encumbranceStateRef) + states.add(encumbranceState) + } + return states + } +} diff --git a/core/src/main/kotlin/net/corda/flows/NotaryFlow.kt b/core/src/main/kotlin/net/corda/core/flows/NotaryFlow.kt similarity index 65% rename from core/src/main/kotlin/net/corda/flows/NotaryFlow.kt rename to core/src/main/kotlin/net/corda/core/flows/NotaryFlow.kt index 579f9c8125..584692bd3d 100644 --- a/core/src/main/kotlin/net/corda/flows/NotaryFlow.kt +++ b/core/src/main/kotlin/net/corda/core/flows/NotaryFlow.kt @@ -1,21 +1,22 @@ -package net.corda.flows +package net.corda.core.flows import co.paralleluniverse.fibers.Suspendable import net.corda.core.contracts.StateRef import net.corda.core.contracts.TimeWindow -import net.corda.core.crypto.DigitalSignature import net.corda.core.crypto.SecureHash import net.corda.core.crypto.SignedData +import net.corda.core.crypto.TransactionSignature import net.corda.core.crypto.keys -import net.corda.core.flows.FlowException -import net.corda.core.flows.FlowLogic -import net.corda.core.flows.InitiatingFlow import net.corda.core.identity.Party -import net.corda.core.node.services.* +import net.corda.core.internal.FetchDataFlow +import net.corda.core.node.services.NotaryService +import net.corda.core.node.services.TrustedAuthorityNotaryService +import net.corda.core.node.services.UniquenessProvider import net.corda.core.serialization.CordaSerializable import net.corda.core.transactions.SignedTransaction import net.corda.core.utilities.ProgressTracker import net.corda.core.utilities.unwrap +import java.security.SignatureException import java.util.function.Predicate object NotaryFlow { @@ -31,8 +32,8 @@ object NotaryFlow { */ @InitiatingFlow open class Client(private val stx: SignedTransaction, - override val progressTracker: ProgressTracker) : FlowLogic>() { - constructor(stx: SignedTransaction) : this(stx, Client.tracker()) + override val progressTracker: ProgressTracker) : FlowLogic>() { + constructor(stx: SignedTransaction) : this(stx, tracker()) companion object { object REQUESTING : ProgressTracker.Step("Requesting signature by Notary service") @@ -45,27 +46,36 @@ object NotaryFlow { @Suspendable @Throws(NotaryException::class) - override fun call(): List { + override fun call(): List { progressTracker.currentStep = REQUESTING - val wtx = stx.tx - notaryParty = wtx.notary ?: throw IllegalStateException("Transaction does not specify a Notary") - check(wtx.inputs.all { stateRef -> serviceHub.loadState(stateRef).notary == notaryParty }) { + + notaryParty = stx.notary ?: throw IllegalStateException("Transaction does not specify a Notary") + check(stx.inputs.all { stateRef -> serviceHub.loadState(stateRef).notary == notaryParty }) { "Input states must have the same Notary" } - try { - stx.verifySignatures(notaryParty.owningKey) - } catch (ex: SignedTransaction.SignaturesMissingException) { - throw NotaryException(NotaryError.SignaturesMissing(ex)) - } - val payload: Any = if (serviceHub.networkMapCache.isValidatingNotary(notaryParty)) { - stx - } else { - wtx.buildFilteredTransaction(Predicate { it is StateRef || it is TimeWindow }) + try { + if (stx.isNotaryChangeTransaction()) { + stx.resolveNotaryChangeTransaction(serviceHub).verifySignaturesExcept(notaryParty.owningKey) + } else { + stx.verifySignaturesExcept(notaryParty.owningKey) + } + } catch (ex: SignatureException) { + throw NotaryException(NotaryError.TransactionInvalid(ex)) } val response = try { - sendAndReceiveWithRetry>(notaryParty, payload) + if (serviceHub.networkMapCache.isValidatingNotary(notaryParty)) { + subFlow(SendTransactionWithRetry(notaryParty, stx)) + receive>(notaryParty) + } else { + val tx: Any = if (stx.isNotaryChangeTransaction()) { + stx.notaryChangeTx + } else { + stx.tx.buildFilteredTransaction(Predicate { it is StateRef || it is TimeWindow }) + } + sendAndReceiveWithRetry(notaryParty, tx) + } } catch (e: NotaryException) { if (e.error is NotaryError.Conflict) { e.error.conflict.verified() @@ -74,14 +84,14 @@ object NotaryFlow { } return response.unwrap { signatures -> - signatures.forEach { validateSignature(it, stx.id.bytes) } + signatures.forEach { validateSignature(it, stx.id) } signatures } } - private fun validateSignature(sig: DigitalSignature.WithKey, data: ByteArray) { + private fun validateSignature(sig: TransactionSignature, txId: SecureHash) { check(sig.by in notaryParty.owningKey.keys) { "Invalid signer for the notary result" } - sig.verify(data) + sig.verify(txId) } } @@ -114,7 +124,7 @@ object NotaryFlow { @Suspendable private fun signAndSendResponse(txId: SecureHash) { - val signature = service.sign(txId.bytes) + val signature = service.sign(txId) send(otherSide, listOf(signature)) } } @@ -137,10 +147,16 @@ sealed class NotaryError { /** Thrown if the time specified in the [TimeWindow] command is outside the allowed tolerance. */ object TimeWindowInvalid : NotaryError() - data class TransactionInvalid(val msg: String) : NotaryError() - data class SignaturesInvalid(val msg: String) : NotaryError() - - data class SignaturesMissing(val cause: SignedTransaction.SignaturesMissingException) : NotaryError() { + data class TransactionInvalid(val cause: Throwable) : NotaryError() { override fun toString() = cause.toString() } } + +/** + * The [SendTransactionWithRetry] flow is equivalent to [SendTransactionFlow] but using [sendAndReceiveWithRetry] + * instead of [sendAndReceive], [SendTransactionWithRetry] is intended to be use by the notary client only. + */ +private class SendTransactionWithRetry(otherSide: Party, stx: SignedTransaction) : SendTransactionFlow(otherSide, stx) { + @Suspendable + override fun sendPayloadAndReceiveDataRequest(otherSide: Party, payload: Any) = sendAndReceiveWithRetry(otherSide, payload) +} \ No newline at end of file diff --git a/core/src/main/kotlin/net/corda/core/flows/ReceiveTransactionFlow.kt b/core/src/main/kotlin/net/corda/core/flows/ReceiveTransactionFlow.kt new file mode 100644 index 0000000000..48504034d7 --- /dev/null +++ b/core/src/main/kotlin/net/corda/core/flows/ReceiveTransactionFlow.kt @@ -0,0 +1,48 @@ +package net.corda.core.flows + +import co.paralleluniverse.fibers.Suspendable +import net.corda.core.contracts.* +import net.corda.core.identity.Party +import net.corda.core.internal.ResolveTransactionsFlow +import net.corda.core.transactions.SignedTransaction +import net.corda.core.utilities.unwrap +import java.security.SignatureException + +/** + * The [ReceiveTransactionFlow] should be called in response to the [SendTransactionFlow]. + * + * This flow is a combination of [receive], resolve and [SignedTransaction.verify]. This flow will receive the [SignedTransaction] + * and perform the resolution back-and-forth required to check the dependencies and download any missing attachments. + * The flow will return the [SignedTransaction] after it is resolved and then verified using [SignedTransaction.verify]. + */ +class ReceiveTransactionFlow +@JvmOverloads +constructor(private val otherParty: Party, private val checkSufficientSignatures: Boolean = true) : FlowLogic() { + @Suspendable + @Throws(SignatureException::class, AttachmentResolutionException::class, TransactionResolutionException::class, TransactionVerificationException::class) + override fun call(): SignedTransaction { + return receive(otherParty).unwrap { + subFlow(ResolveTransactionsFlow(it, otherParty)) + it.verify(serviceHub, checkSufficientSignatures) + it + } + } +} + +/** + * The [ReceiveStateAndRefFlow] should be called in response to the [SendStateAndRefFlow]. + * + * This flow is a combination of [receive] and resolve. This flow will receive a list of [StateAndRef] + * and perform the resolution back-and-forth required to check the dependencies. + * The flow will return the list of [StateAndRef] after it is resolved. + */ +// @JvmSuppressWildcards is used to suppress wildcards in return type when calling `subFlow(new ReceiveStateAndRef(otherParty))` in java. +class ReceiveStateAndRefFlow(private val otherParty: Party) : FlowLogic<@JvmSuppressWildcards List>>() { + @Suspendable + override fun call(): List> { + return receive>>(otherParty).unwrap { + subFlow(ResolveTransactionsFlow(it.map { it.ref.txhash }.toSet(), otherParty)) + it + } + } +} diff --git a/core/src/main/kotlin/net/corda/core/flows/SendTransactionFlow.kt b/core/src/main/kotlin/net/corda/core/flows/SendTransactionFlow.kt new file mode 100644 index 0000000000..12cc8592d0 --- /dev/null +++ b/core/src/main/kotlin/net/corda/core/flows/SendTransactionFlow.kt @@ -0,0 +1,69 @@ +package net.corda.core.flows + +import co.paralleluniverse.fibers.Suspendable +import net.corda.core.contracts.StateAndRef +import net.corda.core.identity.Party +import net.corda.core.internal.FetchDataFlow +import net.corda.core.transactions.SignedTransaction +import net.corda.core.utilities.unwrap + +/** + * The [SendTransactionFlow] should be used to send a transaction to another peer that wishes to verify that transaction's + * integrity by resolving and checking the dependencies as well. The other side should invoke [ReceiveTransactionFlow] at + * the right point in the conversation to receive the sent transaction and perform the resolution back-and-forth required + * to check the dependencies and download any missing attachments. + * + * @param otherSide the target party. + * @param stx the [SignedTransaction] being sent to the [otherSide]. + */ +open class SendTransactionFlow(otherSide: Party, stx: SignedTransaction) : DataVendingFlow(otherSide, stx) + +/** + * The [SendStateAndRefFlow] should be used to send a list of input [StateAndRef] to another peer that wishes to verify + * the input's integrity by resolving and checking the dependencies as well. The other side should invoke [ReceiveStateAndRefFlow] + * at the right point in the conversation to receive the input state and ref and perform the resolution back-and-forth + * required to check the dependencies. + * + * @param otherSide the target party. + * @param stateAndRefs the list of [StateAndRef] being sent to the [otherSide]. + */ +open class SendStateAndRefFlow(otherSide: Party, stateAndRefs: List>) : DataVendingFlow(otherSide, stateAndRefs) + +sealed class DataVendingFlow(val otherSide: Party, val payload: Any) : FlowLogic() { + @Suspendable + protected open fun sendPayloadAndReceiveDataRequest(otherSide: Party, payload: Any) = sendAndReceive(otherSide, payload) + + @Suspendable + protected open fun verifyDataRequest(dataRequest: FetchDataFlow.Request.Data) { + // User can override this method to perform custom request verification. + } + + @Suspendable + override fun call(): Void? { + // The first payload will be the transaction data, subsequent payload will be the transaction/attachment data. + var payload = payload + // This loop will receive [FetchDataFlow.Request] continuously until the `otherSide` has all the data they need + // to resolve the transaction, a [FetchDataFlow.EndRequest] will be sent from the `otherSide` to indicate end of + // data request. + while (true) { + val dataRequest = sendPayloadAndReceiveDataRequest(otherSide, payload).unwrap { request -> + when (request) { + is FetchDataFlow.Request.Data -> { + // Security TODO: Check for abnormally large or malformed data requests + verifyDataRequest(request) + request + } + FetchDataFlow.Request.End -> return null + } + } + payload = when (dataRequest.dataType) { + FetchDataFlow.DataType.TRANSACTION -> dataRequest.hashes.map { + serviceHub.validatedTransactions.getTransaction(it) ?: throw FetchDataFlow.HashNotFound(it) + } + FetchDataFlow.DataType.ATTACHMENT -> dataRequest.hashes.map { + serviceHub.attachments.openAttachment(it)?.open()?.readBytes() ?: throw FetchDataFlow.HashNotFound(it) + } + } + } + } +} diff --git a/core/src/main/kotlin/net/corda/flows/TransactionKeyFlow.kt b/core/src/main/kotlin/net/corda/core/flows/TransactionKeyFlow.kt similarity index 50% rename from core/src/main/kotlin/net/corda/flows/TransactionKeyFlow.kt rename to core/src/main/kotlin/net/corda/core/flows/TransactionKeyFlow.kt index 1989ef5d82..120d6c6a56 100644 --- a/core/src/main/kotlin/net/corda/flows/TransactionKeyFlow.kt +++ b/core/src/main/kotlin/net/corda/core/flows/TransactionKeyFlow.kt @@ -1,10 +1,10 @@ -package net.corda.flows +package net.corda.core.flows import co.paralleluniverse.fibers.Suspendable -import net.corda.core.flows.FlowLogic -import net.corda.core.flows.InitiatingFlow -import net.corda.core.flows.StartableByRPC +import net.corda.core.identity.AnonymousParty +import net.corda.core.identity.PartyAndCertificate import net.corda.core.identity.Party +import net.corda.core.node.services.IdentityService import net.corda.core.utilities.ProgressTracker import net.corda.core.utilities.unwrap @@ -16,33 +16,37 @@ import net.corda.core.utilities.unwrap @InitiatingFlow class TransactionKeyFlow(val otherSide: Party, val revocationEnabled: Boolean, - override val progressTracker: ProgressTracker) : FlowLogic>() { + override val progressTracker: ProgressTracker) : FlowLogic>() { constructor(otherSide: Party) : this(otherSide, false, tracker()) companion object { object AWAITING_KEY : ProgressTracker.Step("Awaiting key") fun tracker() = ProgressTracker(AWAITING_KEY) - fun validateIdentity(otherSide: Party, anonymousOtherSide: AnonymisedIdentity): AnonymisedIdentity { - require(anonymousOtherSide.certificate.subject == otherSide.name) + fun validateAndRegisterIdentity(identityService: IdentityService, otherSide: Party, anonymousOtherSide: PartyAndCertificate): PartyAndCertificate { + require(anonymousOtherSide.name == otherSide.name) + // Validate then store their identity so that we can prove the key in the transaction is owned by the + // counterparty. + identityService.verifyAndRegisterIdentity(anonymousOtherSide) return anonymousOtherSide } } @Suspendable - override fun call(): LinkedHashMap { + override fun call(): LinkedHashMap { progressTracker.currentStep = AWAITING_KEY val legalIdentityAnonymous = serviceHub.keyManagementService.freshKeyAndCert(serviceHub.myInfo.legalIdentityAndCert, revocationEnabled) - serviceHub.identityService.registerAnonymousIdentity(legalIdentityAnonymous.identity, serviceHub.myInfo.legalIdentity, legalIdentityAnonymous.certPath) // Special case that if we're both parties, a single identity is generated - val identities = LinkedHashMap() + val identities = LinkedHashMap() if (otherSide == serviceHub.myInfo.legalIdentity) { - identities.put(otherSide, legalIdentityAnonymous) + identities.put(otherSide, legalIdentityAnonymous.party.anonymise()) } else { - val otherSideAnonymous = sendAndReceive(otherSide, legalIdentityAnonymous).unwrap { validateIdentity(otherSide, it) } - identities.put(serviceHub.myInfo.legalIdentity, legalIdentityAnonymous) - identities.put(otherSide, otherSideAnonymous) + val anonymousOtherSide = sendAndReceive(otherSide, legalIdentityAnonymous).unwrap { confidentialIdentity -> + validateAndRegisterIdentity(serviceHub.identityService, otherSide, confidentialIdentity) + } + identities.put(serviceHub.myInfo.legalIdentity, legalIdentityAnonymous.party.anonymise()) + identities.put(otherSide, anonymousOtherSide.party.anonymise()) } return identities } diff --git a/core/src/main/kotlin/net/corda/core/identity/AnonymisedIdentity.kt b/core/src/main/kotlin/net/corda/core/identity/AnonymisedIdentity.kt deleted file mode 100644 index 0048917443..0000000000 --- a/core/src/main/kotlin/net/corda/core/identity/AnonymisedIdentity.kt +++ /dev/null @@ -1,16 +0,0 @@ -package net.corda.flows - -import net.corda.core.identity.AnonymousParty -import net.corda.core.serialization.CordaSerializable -import org.bouncycastle.cert.X509CertificateHolder -import java.security.PublicKey -import java.security.cert.CertPath - -@CordaSerializable -data class AnonymisedIdentity( - val certPath: CertPath, - val certificate: X509CertificateHolder, - val identity: AnonymousParty) { - constructor(certPath: CertPath, certificate: X509CertificateHolder, identity: PublicKey) - : this(certPath, certificate, AnonymousParty(identity)) -} \ No newline at end of file diff --git a/core/src/main/kotlin/net/corda/core/identity/AnonymousParty.kt b/core/src/main/kotlin/net/corda/core/identity/AnonymousParty.kt index 33ffffb19b..b2264a1c6f 100644 --- a/core/src/main/kotlin/net/corda/core/identity/AnonymousParty.kt +++ b/core/src/main/kotlin/net/corda/core/identity/AnonymousParty.kt @@ -2,6 +2,7 @@ package net.corda.core.identity import net.corda.core.contracts.PartyAndReference import net.corda.core.crypto.toBase58String +import net.corda.core.crypto.toStringShort import net.corda.core.utilities.OpaqueBytes import org.bouncycastle.asn1.x500.X500Name import java.security.PublicKey @@ -13,7 +14,7 @@ import java.security.PublicKey class AnonymousParty(owningKey: PublicKey) : AbstractParty(owningKey) { // Use the key as the bulk of the toString(), but include a human readable identifier as well, so that [Party] // can put in the key and actual name - override fun toString() = "${owningKey.toBase58String()} " + override fun toString() = "${owningKey.toStringShort()} " override fun nameOrNull(): X500Name? = null diff --git a/core/src/main/kotlin/net/corda/core/identity/Party.kt b/core/src/main/kotlin/net/corda/core/identity/Party.kt index e41c550c84..a2f301be75 100644 --- a/core/src/main/kotlin/net/corda/core/identity/Party.kt +++ b/core/src/main/kotlin/net/corda/core/identity/Party.kt @@ -30,5 +30,6 @@ class Party(val name: X500Name, owningKey: PublicKey) : AbstractParty(owningKey) override fun toString() = name.toString() override fun nameOrNull(): X500Name? = name + fun anonymise(): AnonymousParty = AnonymousParty(owningKey) override fun ref(bytes: OpaqueBytes): PartyAndReference = PartyAndReference(this, bytes) } diff --git a/core/src/main/kotlin/net/corda/core/identity/PartyAndCertificate.kt b/core/src/main/kotlin/net/corda/core/identity/PartyAndCertificate.kt index 12557d562e..557c26d7ee 100644 --- a/core/src/main/kotlin/net/corda/core/identity/PartyAndCertificate.kt +++ b/core/src/main/kotlin/net/corda/core/identity/PartyAndCertificate.kt @@ -4,12 +4,14 @@ import net.corda.core.serialization.CordaSerializable import org.bouncycastle.asn1.x500.X500Name import org.bouncycastle.cert.X509CertificateHolder import java.security.PublicKey -import java.security.cert.CertPath +import java.security.cert.* +import java.util.* /** * A full party plus the X.509 certificate and path linking the party back to a trust root. Equality of * [PartyAndCertificate] instances is based on the party only, as certificate and path are data associated with the party, - * not part of the identifier themselves. + * not part of the identifier themselves. While party and certificate can both be derived from the certificate path, + * this class exists in order to ensure the implementation classes of certificates and party public keys are kept stable. */ @CordaSerializable data class PartyAndCertificate(val party: Party, @@ -30,4 +32,18 @@ data class PartyAndCertificate(val party: Party, override fun hashCode(): Int = party.hashCode() override fun toString(): String = party.toString() + + /** + * Verify that the given certificate path is valid and leads to the owning key of the party. + */ + fun verify(trustAnchor: TrustAnchor): PKIXCertPathValidatorResult { + require(certPath.certificates.first() is X509Certificate) { "Subject certificate must be an X.509 certificate" } + require(Arrays.equals(party.owningKey.encoded, certificate.subjectPublicKeyInfo.encoded)) { "Certificate public key must match party owning key" } + require(Arrays.equals(certPath.certificates.first().encoded, certificate.encoded)) { "Certificate path must link to certificate" } + + val validatorParameters = PKIXParameters(setOf(trustAnchor)) + val validator = CertPathValidator.getInstance("PKIX") + validatorParameters.isRevocationEnabled = false + return validator.validate(certPath, validatorParameters) as PKIXCertPathValidatorResult + } } diff --git a/core/src/main/kotlin/net/corda/core/utilities/Emoji.kt b/core/src/main/kotlin/net/corda/core/internal/Emoji.kt similarity index 83% rename from core/src/main/kotlin/net/corda/core/utilities/Emoji.kt rename to core/src/main/kotlin/net/corda/core/internal/Emoji.kt index f2a6a6b566..7c7e5202f5 100644 --- a/core/src/main/kotlin/net/corda/core/utilities/Emoji.kt +++ b/core/src/main/kotlin/net/corda/core/internal/Emoji.kt @@ -1,6 +1,4 @@ -package net.corda.core.utilities - -import net.corda.core.codePointsString +package net.corda.core.internal /** * A simple wrapper class that contains icons and support for printing them only when we're connected to a terminal. @@ -11,8 +9,8 @@ object Emoji { // Check for that here. DemoBench sets TERM_PROGRAM appropriately. val hasEmojiTerminal by lazy { System.getenv("CORDA_FORCE_EMOJI") != null || - System.getenv("TERM_PROGRAM") in listOf("Apple_Terminal", "iTerm.app") || - (System.getenv("TERM_PROGRAM") == "JediTerm" && System.getProperty("java.vendor") == "JetBrains s.r.o") + System.getenv("TERM_PROGRAM") in listOf("Apple_Terminal", "iTerm.app") || + (System.getenv("TERM_PROGRAM") == "JediTerm" && System.getProperty("java.vendor") == "JetBrains s.r.o") } @JvmStatic val CODE_SANTA_CLAUS: String = codePointsString(0x1F385) @@ -29,6 +27,9 @@ object Emoji { @JvmStatic val CODE_BOOKS: String = codePointsString(0x1F4DA) @JvmStatic val CODE_SLEEPING_FACE: String = codePointsString(0x1F634) @JvmStatic val CODE_LIGHTBULB: String = codePointsString(0x1F4A1) + @JvmStatic val CODE_FREE: String = codePointsString(0x1F193) + @JvmStatic val CODE_SOON: String = codePointsString(0x1F51C) + /** * When non-null, toString() methods are allowed to use emoji in the output as we're going to render them to a @@ -46,6 +47,8 @@ object Emoji { val books: String get() = if (emojiMode.get() != null) "$CODE_BOOKS " else "" val sleepingFace: String get() = if (emojiMode.get() != null) "$CODE_SLEEPING_FACE " else "" val lightBulb: String get() = if (emojiMode.get() != null) "$CODE_LIGHTBULB " else "" + val free: String get() = if (emojiMode.get() != null) "$CODE_FREE " else "" + val soon: String get() = if (emojiMode.get() != null) "$CODE_SOON " else "" // These have old/non-emoji symbols with better platform support. val greenTick: String get() = if (emojiMode.get() != null) "$CODE_GREEN_TICK " else "✓" @@ -79,4 +82,9 @@ object Emoji { } } + private fun codePointsString(vararg codePoints: Int): String { + val builder = StringBuilder() + codePoints.forEach { builder.append(Character.toChars(it)) } + return builder.toString() + } } diff --git a/core/src/main/kotlin/net/corda/core/internal/FetchDataFlow.kt b/core/src/main/kotlin/net/corda/core/internal/FetchDataFlow.kt new file mode 100644 index 0000000000..5006843121 --- /dev/null +++ b/core/src/main/kotlin/net/corda/core/internal/FetchDataFlow.kt @@ -0,0 +1,179 @@ +package net.corda.core.internal + +import co.paralleluniverse.fibers.Suspendable +import net.corda.core.contracts.AbstractAttachment +import net.corda.core.contracts.Attachment +import net.corda.core.contracts.NamedByHash +import net.corda.core.crypto.SecureHash +import net.corda.core.crypto.sha256 +import net.corda.core.flows.FlowException +import net.corda.core.flows.FlowLogic +import net.corda.core.identity.Party +import net.corda.core.internal.FetchDataFlow.DownloadedVsRequestedDataMismatch +import net.corda.core.internal.FetchDataFlow.HashNotFound +import net.corda.core.serialization.CordaSerializable +import net.corda.core.serialization.SerializationToken +import net.corda.core.serialization.SerializeAsToken +import net.corda.core.serialization.SerializeAsTokenContext +import net.corda.core.transactions.SignedTransaction +import net.corda.core.utilities.NonEmptySet +import net.corda.core.utilities.UntrustworthyData +import net.corda.core.utilities.unwrap +import java.util.* + +/** + * An abstract flow for fetching typed data from a remote peer. + * + * Given a set of hashes (IDs), either loads them from local disk or asks the remote peer to provide them. + * + * A malicious response in which the data provided by the remote peer does not hash to the requested hash results in + * [DownloadedVsRequestedDataMismatch] being thrown. If the remote peer doesn't have an entry, it results in a + * [HashNotFound] exception being thrown. + * + * By default this class does not insert data into any local database, if you want to do that after missing items were + * fetched then override [maybeWriteToDisk]. You *must* override [load]. If the wire type is not the same as the + * ultimate type, you must also override [convert]. + * + * @param T The ultimate type of the data being fetched. + * @param W The wire type of the data being fetched, for when it isn't the same as the ultimate type. + */ +sealed class FetchDataFlow( + protected val requests: Set, + protected val otherSide: Party, + protected val dataType: DataType) : FlowLogic>() { + + @CordaSerializable + class DownloadedVsRequestedDataMismatch(val requested: SecureHash, val got: SecureHash) : IllegalArgumentException() + + @CordaSerializable + class DownloadedVsRequestedSizeMismatch(val requested: Int, val got: Int) : IllegalArgumentException() + + class HashNotFound(val requested: SecureHash) : FlowException() + + @CordaSerializable + data class Result(val fromDisk: List, val downloaded: List) + + @CordaSerializable + sealed class Request { + data class Data(val hashes: NonEmptySet, val dataType: DataType) : Request() + object End : Request() + } + + @CordaSerializable + enum class DataType { + TRANSACTION, ATTACHMENT + } + + @Suspendable + @Throws(HashNotFound::class) + override fun call(): Result { + // Load the items we have from disk and figure out which we're missing. + val (fromDisk, toFetch) = loadWhatWeHave() + + return if (toFetch.isEmpty()) { + Result(fromDisk, emptyList()) + } else { + logger.info("Requesting ${toFetch.size} dependency(s) for verification from ${otherSide.name}") + + // TODO: Support "large message" response streaming so response sizes are not limited by RAM. + // We can then switch to requesting items in large batches to minimise the latency penalty. + // This is blocked by bugs ARTEMIS-1278 and ARTEMIS-1279. For now we limit attachments and txns to 10mb each + // and don't request items in batch, which is a performance loss, but works around the issue. We have + // configured Artemis to not fragment messages up to 10mb so we can send 10mb messages without problems. + // Above that, we start losing authentication data on the message fragments and take exceptions in the + // network layer. + val maybeItems = ArrayList(toFetch.size) + for (hash in toFetch) { + // We skip the validation here (with unwrap { it }) because we will do it below in validateFetchResponse. + // The only thing checked is the object type. It is a protocol violation to send results out of order. + maybeItems += sendAndReceive>(otherSide, Request.Data(NonEmptySet.of(hash), dataType)).unwrap { it } + } + // Check for a buggy/malicious peer answering with something that we didn't ask for. + val downloaded = validateFetchResponse(UntrustworthyData(maybeItems), toFetch) + logger.info("Fetched ${downloaded.size} elements from ${otherSide.name}") + maybeWriteToDisk(downloaded) + Result(fromDisk, downloaded) + } + } + + protected open fun maybeWriteToDisk(downloaded: List) { + // Do nothing by default. + } + + private fun loadWhatWeHave(): Pair, List> { + val fromDisk = ArrayList() + val toFetch = ArrayList() + for (txid in requests) { + val stx = load(txid) + if (stx == null) + toFetch += txid + else + fromDisk += stx + } + return Pair(fromDisk, toFetch) + } + + protected abstract fun load(txid: SecureHash): T? + + @Suppress("UNCHECKED_CAST") + protected open fun convert(wire: W): T = wire as T + + private fun validateFetchResponse(maybeItems: UntrustworthyData>, + requests: List): List { + return maybeItems.unwrap { response -> + if (response.size != requests.size) + throw DownloadedVsRequestedSizeMismatch(requests.size, response.size) + val answers = response.map { convert(it) } + // Check transactions actually hash to what we requested, if this fails the remote node + // is a malicious flow violator or buggy. + for ((index, item) in answers.withIndex()) { + if (item.id != requests[index]) + throw DownloadedVsRequestedDataMismatch(requests[index], item.id) + } + answers + } + } +} + + +/** + * Given a set of hashes either loads from from local storage or requests them from the other peer. Downloaded + * attachments are saved to local storage automatically. + */ +class FetchAttachmentsFlow(requests: Set, + otherSide: Party) : FetchDataFlow(requests, otherSide, DataType.ATTACHMENT) { + + override fun load(txid: SecureHash): Attachment? = serviceHub.attachments.openAttachment(txid) + + override fun convert(wire: ByteArray): Attachment = FetchedAttachment({ wire }) + + override fun maybeWriteToDisk(downloaded: List) { + for (attachment in downloaded) { + serviceHub.attachments.importAttachment(attachment.open()) + } + } + + private class FetchedAttachment(dataLoader: () -> ByteArray) : AbstractAttachment(dataLoader), SerializeAsToken { + override val id: SecureHash by lazy { attachmentData.sha256() } + + private class Token(private val id: SecureHash) : SerializationToken { + override fun fromToken(context: SerializeAsTokenContext) = FetchedAttachment(context.attachmentDataLoader(id)) + } + + override fun toToken(context: SerializeAsTokenContext) = Token(id) + } +} + +/** + * Given a set of tx hashes (IDs), either loads them from local disk or asks the remote peer to provide them. + * + * A malicious response in which the data provided by the remote peer does not hash to the requested hash results in + * [FetchDataFlow.DownloadedVsRequestedDataMismatch] being thrown. If the remote peer doesn't have an entry, it + * results in a [FetchDataFlow.HashNotFound] exception. Note that returned transactions are not inserted into + * the database, because it's up to the caller to actually verify the transactions are valid. + */ +class FetchTransactionsFlow(requests: Set, otherSide: Party) : + FetchDataFlow(requests, otherSide, DataType.TRANSACTION) { + + override fun load(txid: SecureHash): SignedTransaction? = serviceHub.validatedTransactions.getTransaction(txid) +} diff --git a/core/src/main/kotlin/net/corda/core/internal/FlowStateMachine.kt b/core/src/main/kotlin/net/corda/core/internal/FlowStateMachine.kt index 8c87cfaa51..7ee183b348 100644 --- a/core/src/main/kotlin/net/corda/core/internal/FlowStateMachine.kt +++ b/core/src/main/kotlin/net/corda/core/internal/FlowStateMachine.kt @@ -1,10 +1,12 @@ package net.corda.core.internal import co.paralleluniverse.fibers.Suspendable -import com.google.common.util.concurrent.ListenableFuture +import net.corda.core.concurrent.CordaFuture import net.corda.core.crypto.SecureHash +import net.corda.core.flows.FlowContext import net.corda.core.flows.FlowInitiator import net.corda.core.flows.FlowLogic +import net.corda.core.flows.FlowStackSnapshot import net.corda.core.flows.StateMachineRunId import net.corda.core.identity.Party import net.corda.core.node.ServiceHub @@ -14,6 +16,9 @@ import org.slf4j.Logger /** This is an internal interface that is implemented by code in the node module. You should look at [FlowLogic]. */ interface FlowStateMachine { + @Suspendable + fun getFlowContext(otherParty: Party, sessionFlow: FlowLogic<*>): FlowContext + @Suspendable fun sendAndReceive(receiveType: Class, otherParty: Party, @@ -25,18 +30,37 @@ interface FlowStateMachine { fun receive(receiveType: Class, otherParty: Party, sessionFlow: FlowLogic<*>): UntrustworthyData @Suspendable - fun send(otherParty: Party, payload: Any, sessionFlow: FlowLogic<*>) + fun send(otherParty: Party, payload: Any, sessionFlow: FlowLogic<*>): Unit @Suspendable fun waitForLedgerCommit(hash: SecureHash, sessionFlow: FlowLogic<*>): SignedTransaction - fun checkFlowPermission(permissionName: String, extraAuditData: Map) + fun checkFlowPermission(permissionName: String, extraAuditData: Map): Unit - fun recordAuditEvent(eventType: String, comment: String, extraAuditData: Map) + fun recordAuditEvent(eventType: String, comment: String, extraAuditData: Map): Unit + + /** + * Returns a shallow copy of the Quasar stack frames at the time of call to [flowStackSnapshot]. Use this to inspect + * what objects would be serialised at the time of call to a suspending action (e.g. send/receive). + */ + @Suspendable + fun flowStackSnapshot(flowClass: Class<*>): FlowStackSnapshot? + + /** + * Persists a shallow copy of the Quasar stack frames at the time of call to [persistFlowStackSnapshot]. + * Use this to track the monitor evolution of the quasar stack values during the flow execution. + * The flow stack snapshot is stored in a file located in {baseDir}/flowStackSnapshots/YYYY-MM-DD/{flowId}/ + * where baseDir is the node running directory and flowId is the flow unique identifier generated by the platform. + * + * Note: With respect to the [flowStackSnapshot], the snapshot being persisted by this method is partial, + * meaning that only flow relevant traces and local variables are persisted. + */ + @Suspendable + fun persistFlowStackSnapshot(flowClass: Class<*>): Unit val serviceHub: ServiceHub val logger: Logger val id: StateMachineRunId - val resultFuture: ListenableFuture + val resultFuture: CordaFuture val flowInitiator: FlowInitiator } diff --git a/core/src/main/kotlin/net/corda/core/internal/InternalUtils.kt b/core/src/main/kotlin/net/corda/core/internal/InternalUtils.kt new file mode 100644 index 0000000000..e0d9f37580 --- /dev/null +++ b/core/src/main/kotlin/net/corda/core/internal/InternalUtils.kt @@ -0,0 +1,258 @@ +package net.corda.core.internal + +import net.corda.core.crypto.SecureHash +import net.corda.core.crypto.sha256 +import org.slf4j.Logger +import rx.Observable +import rx.Observer +import rx.subjects.PublishSubject +import rx.subjects.UnicastSubject +import java.io.* +import java.lang.reflect.Field +import java.math.BigDecimal +import java.nio.charset.Charset +import java.nio.charset.StandardCharsets.UTF_8 +import java.nio.file.* +import java.nio.file.attribute.FileAttribute +import java.time.Duration +import java.time.temporal.Temporal +import java.util.* +import java.util.Spliterator.* +import java.util.stream.IntStream +import java.util.stream.Stream +import java.util.stream.StreamSupport +import java.util.zip.Deflater +import java.util.zip.ZipEntry +import java.util.zip.ZipOutputStream +import kotlin.reflect.KClass + +val Throwable.rootCause: Throwable get() = cause?.rootCause ?: this +fun Throwable.getStackTraceAsString() = StringWriter().also { printStackTrace(PrintWriter(it)) }.toString() + +infix fun Temporal.until(endExclusive: Temporal): Duration = Duration.between(this, endExclusive) + +operator fun Duration.div(divider: Long): Duration = dividedBy(divider) +operator fun Duration.times(multiplicand: Long): Duration = multipliedBy(multiplicand) + +/** + * Allows you to write code like: Paths.get("someDir") / "subdir" / "filename" but using the Paths API to avoid platform + * separator problems. + */ +operator fun Path.div(other: String): Path = resolve(other) +operator fun String.div(other: String): Path = Paths.get(this) / other + +/** + * Returns the single element matching the given [predicate], or `null` if the collection is empty, or throws exception + * if more than one element was found. + */ +inline fun Iterable.noneOrSingle(predicate: (T) -> Boolean): T? { + val iterator = iterator() + for (single in iterator) { + if (predicate(single)) { + while (iterator.hasNext()) { + if (predicate(iterator.next())) throw IllegalArgumentException("Collection contains more than one matching element.") + } + return single + } + } + return null +} + +/** + * Returns the single element, or `null` if the list is empty, or throws an exception if it has more than one element. + */ +fun List.noneOrSingle(): T? { + return when (size) { + 0 -> null + 1 -> this[0] + else -> throw IllegalArgumentException("List has more than one element.") + } +} + +/** Returns a random element in the list, or `null` if empty */ +fun List.randomOrNull(): T? { + return when (size) { + 0 -> null + 1 -> this[0] + else -> this[(Math.random() * size).toInt()] + } +} + +/** Returns the index of the given item or throws [IllegalArgumentException] if not found. */ +fun List.indexOfOrThrow(item: T): Int { + val i = indexOf(item) + require(i != -1) + return i +} + +fun Path.createDirectory(vararg attrs: FileAttribute<*>): Path = Files.createDirectory(this, *attrs) +fun Path.createDirectories(vararg attrs: FileAttribute<*>): Path = Files.createDirectories(this, *attrs) +fun Path.exists(vararg options: LinkOption): Boolean = Files.exists(this, *options) +fun Path.copyToDirectory(targetDir: Path, vararg options: CopyOption): Path { + require(targetDir.isDirectory()) { "$targetDir is not a directory" } + val targetFile = targetDir.resolve(fileName) + Files.copy(this, targetFile, *options) + return targetFile +} +fun Path.moveTo(target: Path, vararg options: CopyOption): Path = Files.move(this, target, *options) +fun Path.isRegularFile(vararg options: LinkOption): Boolean = Files.isRegularFile(this, *options) +fun Path.isDirectory(vararg options: LinkOption): Boolean = Files.isDirectory(this, *options) +inline val Path.size: Long get() = Files.size(this) +inline fun Path.list(block: (Stream) -> R): R = Files.list(this).use(block) +fun Path.deleteIfExists(): Boolean = Files.deleteIfExists(this) +fun Path.readAll(): ByteArray = Files.readAllBytes(this) +inline fun Path.read(vararg options: OpenOption, block: (InputStream) -> R): R = Files.newInputStream(this, *options).use(block) +inline fun Path.write(createDirs: Boolean = false, vararg options: OpenOption = emptyArray(), block: (OutputStream) -> Unit) { + if (createDirs) { + normalize().parent?.createDirectories() + } + Files.newOutputStream(this, *options).use(block) +} + +inline fun Path.readLines(charset: Charset = UTF_8, block: (Stream) -> R): R = Files.lines(this, charset).use(block) +fun Path.readAllLines(charset: Charset = UTF_8): List = Files.readAllLines(this, charset) +fun Path.writeLines(lines: Iterable, charset: Charset = UTF_8, vararg options: OpenOption): Path = Files.write(this, lines, charset, *options) + +fun InputStream.copyTo(target: Path, vararg options: CopyOption): Long = Files.copy(this, target, *options) + +fun String.abbreviate(maxWidth: Int): String = if (length <= maxWidth) this else take(maxWidth - 1) + "…" + +/** Return the sum of an Iterable of [BigDecimal]s. */ +fun Iterable.sum(): BigDecimal = fold(BigDecimal.ZERO) { a, b -> a + b } + +/** + * Returns an Observable that buffers events until subscribed. + * @see UnicastSubject + */ +fun Observable.bufferUntilSubscribed(): Observable { + val subject = UnicastSubject.create() + val subscription = subscribe(subject) + return subject.doOnUnsubscribe { subscription.unsubscribe() } +} + +/** Copy an [Observer] to multiple other [Observer]s. */ +fun Observer.tee(vararg teeTo: Observer): Observer { + val subject = PublishSubject.create() + subject.subscribe(this) + teeTo.forEach { subject.subscribe(it) } + return subject +} + +/** Executes the given code block and returns a [Duration] of how long it took to execute in nanosecond precision. */ +inline fun elapsedTime(block: () -> Unit): Duration { + val start = System.nanoTime() + block() + val end = System.nanoTime() + return Duration.ofNanos(end - start) +} + + +fun Logger.logElapsedTime(label: String, body: () -> T): T = logElapsedTime(label, this, body) + +// TODO: Add inline back when a new Kotlin version is released and check if the java.lang.VerifyError +// returns in the IRSSimulationTest. If not, commit the inline back. +fun logElapsedTime(label: String, logger: Logger? = null, body: () -> T): T { + // Use nanoTime as it's monotonic. + val now = System.nanoTime() + try { + return body() + } finally { + val elapsed = Duration.ofNanos(System.nanoTime() - now).toMillis() + if (logger != null) + logger.info("$label took $elapsed msec") + else + println("$label took $elapsed msec") + } +} + +/** Convert a [ByteArrayOutputStream] to [InputStreamAndHash]. */ +fun ByteArrayOutputStream.toInputStreamAndHash(): InputStreamAndHash { + val bytes = toByteArray() + return InputStreamAndHash(ByteArrayInputStream(bytes), bytes.sha256()) +} + +data class InputStreamAndHash(val inputStream: InputStream, val sha256: SecureHash.SHA256) { + companion object { + /** + * Get a valid InputStream from an in-memory zip as required for some tests. The zip consists of a single file + * called "z" that contains the given content byte repeated the given number of times. + * Note that a slightly bigger than numOfExpectedBytes size is expected. + */ + fun createInMemoryTestZip(numOfExpectedBytes: Int, content: Byte): InputStreamAndHash { + require(numOfExpectedBytes > 0) + val baos = ByteArrayOutputStream() + ZipOutputStream(baos).use { zos -> + val arraySize = 1024 + val bytes = ByteArray(arraySize) { content } + val n = (numOfExpectedBytes - 1) / arraySize + 1 // same as Math.ceil(numOfExpectedBytes/arraySize). + zos.setLevel(Deflater.NO_COMPRESSION) + zos.putNextEntry(ZipEntry("z")) + for (i in 0 until n) { + zos.write(bytes, 0, arraySize) + } + zos.closeEntry() + } + return baos.toInputStreamAndHash() + } + } +} + +fun IntIterator.toJavaIterator(): PrimitiveIterator.OfInt { + return object : PrimitiveIterator.OfInt { + override fun nextInt() = this@toJavaIterator.nextInt() + override fun hasNext() = this@toJavaIterator.hasNext() + override fun remove() = throw UnsupportedOperationException("remove") + } +} + +private fun IntProgression.toSpliterator(): Spliterator.OfInt { + val spliterator = Spliterators.spliterator( + iterator().toJavaIterator(), + (1 + (last - first) / step).toLong(), + SUBSIZED or IMMUTABLE or NONNULL or SIZED or ORDERED or SORTED or DISTINCT + ) + return if (step > 0) spliterator else object : Spliterator.OfInt by spliterator { + override fun getComparator() = Comparator.reverseOrder() + } +} + +fun IntProgression.stream(parallel: Boolean = false): IntStream = StreamSupport.intStream(toSpliterator(), parallel) + +@Suppress("UNCHECKED_CAST") // When toArray has filled in the array, the component type is no longer T? but T (that may itself be nullable). +inline fun Stream.toTypedArray() = toArray { size -> arrayOfNulls(size) } as Array + +fun Class.castIfPossible(obj: Any): T? = if (isInstance(obj)) cast(obj) else null + +/** Returns a [DeclaredField] wrapper around the declared (possibly non-public) static field of the receiver [Class]. */ +fun Class<*>.staticField(name: String): DeclaredField = DeclaredField(this, name, null) +/** Returns a [DeclaredField] wrapper around the declared (possibly non-public) static field of the receiver [KClass]. */ +fun KClass<*>.staticField(name: String): DeclaredField = DeclaredField(java, name, null) +/** Returns a [DeclaredField] wrapper around the declared (possibly non-public) instance field of the receiver object. */ +fun Any.declaredField(name: String): DeclaredField = DeclaredField(javaClass, name, this) +/** + * Returns a [DeclaredField] wrapper around the (possibly non-public) instance field of the receiver object, but declared + * in its superclass [clazz]. + */ +fun Any.declaredField(clazz: KClass<*>, name: String): DeclaredField = DeclaredField(clazz.java, name, this) + +/** + * A simple wrapper around a [Field] object providing type safe read and write access using [value], ignoring the field's + * visibility. + */ +class DeclaredField(clazz: Class<*>, name: String, private val receiver: Any?) { + private val javaField = clazz.getDeclaredField(name).apply { isAccessible = true } + var value: T + @Suppress("UNCHECKED_CAST") + get() = javaField.get(receiver) as T + set(value) = javaField.set(receiver, value) +} + +/** The annotated object would have a more restricted visibility were it not needed in tests. */ +@Target(AnnotationTarget.CLASS, + AnnotationTarget.PROPERTY, + AnnotationTarget.CONSTRUCTOR, + AnnotationTarget.FUNCTION, + AnnotationTarget.TYPEALIAS) +@Retention(AnnotationRetention.SOURCE) +@MustBeDocumented +annotation class VisibleForTesting diff --git a/core/src/main/kotlin/net/corda/core/utilities/LazyPool.kt b/core/src/main/kotlin/net/corda/core/internal/LazyPool.kt similarity index 96% rename from core/src/main/kotlin/net/corda/core/utilities/LazyPool.kt rename to core/src/main/kotlin/net/corda/core/internal/LazyPool.kt index 2649924aa1..3e4e3a526d 100644 --- a/core/src/main/kotlin/net/corda/core/utilities/LazyPool.kt +++ b/core/src/main/kotlin/net/corda/core/internal/LazyPool.kt @@ -1,7 +1,6 @@ -package net.corda.core.utilities +package net.corda.core.internal import java.util.concurrent.ConcurrentLinkedQueue -import java.util.concurrent.LinkedBlockingQueue import java.util.concurrent.Semaphore /** diff --git a/core/src/main/kotlin/net/corda/core/utilities/LazyStickyPool.kt b/core/src/main/kotlin/net/corda/core/internal/LazyStickyPool.kt similarity index 98% rename from core/src/main/kotlin/net/corda/core/utilities/LazyStickyPool.kt rename to core/src/main/kotlin/net/corda/core/internal/LazyStickyPool.kt index f44723b6b8..6746989291 100644 --- a/core/src/main/kotlin/net/corda/core/utilities/LazyStickyPool.kt +++ b/core/src/main/kotlin/net/corda/core/internal/LazyStickyPool.kt @@ -1,4 +1,4 @@ -package net.corda.core.utilities +package net.corda.core.internal import java.util.* import java.util.concurrent.LinkedBlockingQueue diff --git a/core/src/main/kotlin/net/corda/core/utilities/LifeCycle.kt b/core/src/main/kotlin/net/corda/core/internal/LifeCycle.kt similarity index 97% rename from core/src/main/kotlin/net/corda/core/utilities/LifeCycle.kt rename to core/src/main/kotlin/net/corda/core/internal/LifeCycle.kt index bc73e9f51a..96786ea3e9 100644 --- a/core/src/main/kotlin/net/corda/core/utilities/LifeCycle.kt +++ b/core/src/main/kotlin/net/corda/core/internal/LifeCycle.kt @@ -1,4 +1,4 @@ -package net.corda.core.utilities +package net.corda.core.internal import java.util.concurrent.locks.ReentrantReadWriteLock import kotlin.concurrent.withLock diff --git a/core/src/main/kotlin/net/corda/flows/ResolveTransactionsFlow.kt b/core/src/main/kotlin/net/corda/core/internal/ResolveTransactionsFlow.kt similarity index 55% rename from core/src/main/kotlin/net/corda/flows/ResolveTransactionsFlow.kt rename to core/src/main/kotlin/net/corda/core/internal/ResolveTransactionsFlow.kt index 92f3b9ebd0..77fa327e91 100644 --- a/core/src/main/kotlin/net/corda/flows/ResolveTransactionsFlow.kt +++ b/core/src/main/kotlin/net/corda/core/internal/ResolveTransactionsFlow.kt @@ -1,39 +1,34 @@ -package net.corda.flows +package net.corda.core.internal import co.paralleluniverse.fibers.Suspendable -import net.corda.core.checkedAdd import net.corda.core.crypto.SecureHash import net.corda.core.flows.FlowLogic -import net.corda.core.getOrThrow import net.corda.core.identity.Party import net.corda.core.serialization.CordaSerializable -import net.corda.core.transactions.LedgerTransaction import net.corda.core.transactions.SignedTransaction -import net.corda.core.transactions.WireTransaction +import net.corda.core.utilities.exactAdd import java.util.* // TODO: This code is currently unit tested by TwoPartyTradeFlowTests, it should have its own tests. - -// TODO: It may be a clearer API if we make the primary c'tor private here, and only allow a single tx to be "resolved". - /** - * This flow is used to verify the validity of a transaction by recursively checking the validity of all the - * dependencies. Once a transaction is checked it's inserted into local storage so it can be relayed and won't be - * checked again. + * Resolves transactions for the specified [txHashes] along with their full history (dependency graph) from [otherSide]. + * Each retrieved transaction is validated and inserted into the local transaction storage. * - * A couple of constructors are provided that accept a single transaction. When these are used, the dependencies of that - * transaction are resolved and then the transaction itself is verified. Again, if successful, the results are inserted - * into the database as long as a [SignedTransaction] was provided. If only the [WireTransaction] form was provided - * then this isn't enough to put into the local database, so only the dependencies are checked and inserted. This way - * to use the flow is helpful when resolving and verifying a finished but partially signed transaction. - * - * The flow returns a list of verified [LedgerTransaction] objects, in a depth-first order. + * @return a list of verified [SignedTransaction] objects, in a depth-first order. */ class ResolveTransactionsFlow(private val txHashes: Set, - private val otherSide: Party) : FlowLogic>() { - + private val otherSide: Party) : FlowLogic>() { + /** + * Resolves and validates the dependencies of the specified [signedTransaction]. Fetches the attachments, but does + * *not* validate or store the [signedTransaction] itself. + * + * @return a list of verified [SignedTransaction] objects, in a depth-first order. + */ + constructor(signedTransaction: SignedTransaction, otherSide: Party) : this(dependencyIDs(signedTransaction), otherSide) { + this.signedTransaction = signedTransaction + } companion object { - private fun dependencyIDs(wtx: WireTransaction) = wtx.inputs.map { it.txhash }.toSet() + private fun dependencyIDs(stx: SignedTransaction) = stx.inputs.map { it.txhash }.toSet() /** * Topologically sorts the given transactions such that dependencies are listed before dependers. */ @@ -42,7 +37,7 @@ class ResolveTransactionsFlow(private val txHashes: Set, // Construct txhash -> dependent-txs map val forwardGraph = HashMap>() transactions.forEach { stx -> - stx.tx.inputs.forEach { (txhash) -> + stx.inputs.forEach { (txhash) -> // Note that we use a LinkedHashSet here to make the traversal deterministic (as long as the input list is) forwardGraph.getOrPut(txhash) { LinkedHashSet() }.add(stx) } @@ -65,74 +60,48 @@ class ResolveTransactionsFlow(private val txHashes: Set, require(result.size == transactions.size) return result } - } @CordaSerializable class ExcessivelyLargeTransactionGraph : Exception() - // Transactions to verify after the dependencies. - private var stx: SignedTransaction? = null - private var wtx: WireTransaction? = null + /** Transaction for fetch attachments for */ + private var signedTransaction: SignedTransaction? = null // TODO: Figure out a more appropriate DOS limit here, 5000 is simply a very bad guess. /** The maximum number of transactions this flow will try to download before bailing out. */ var transactionCountLimit = 5000 - - /** - * Resolve the full history of a transaction and verify it with its dependencies. - */ - constructor(stx: SignedTransaction, otherSide: Party) : this(stx.tx, otherSide) { - this.stx = stx - } - - /** - * Resolve the full history of a transaction and verify it with its dependencies. - */ - constructor(wtx: WireTransaction, otherSide: Party) : this(dependencyIDs(wtx), otherSide) { - this.wtx = wtx - } + set(value) { + require(value > 0) { "$value is not a valid count limit" } + field = value + } @Suspendable @Throws(FetchDataFlow.HashNotFound::class) - override fun call(): List { - val newTxns: Iterable = topologicalSort(downloadDependencies(txHashes)) + override fun call(): List { + // Start fetching data. + val newTxns = downloadDependencies(txHashes) + fetchMissingAttachments(signedTransaction?.let { newTxns + it } ?: newTxns) + send(otherSide, FetchDataFlow.Request.End) + // Finish fetching data. - // For each transaction, verify it and insert it into the database. As we are iterating over them in a - // depth-first order, we should not encounter any verification failures due to missing data. If we fail - // half way through, it's no big deal, although it might result in us attempting to re-download data - // redundantly next time we attempt verification. - val result = ArrayList() - - for (stx in newTxns) { - // Resolve to a LedgerTransaction and then run all contracts. - val ltx = stx.toLedgerTransaction(serviceHub) - // Block on each verification request. - // TODO We could recover some parallelism from the dependency graph. - serviceHub.transactionVerifierService.verify(ltx).getOrThrow() - serviceHub.recordTransactions(stx) - result += ltx + val result = topologicalSort(newTxns) + result.forEach { + // For each transaction, verify it and insert it into the database. As we are iterating over them in a + // depth-first order, we should not encounter any verification failures due to missing data. If we fail + // half way through, it's no big deal, although it might result in us attempting to re-download data + // redundantly next time we attempt verification. + it.verify(serviceHub) + serviceHub.recordTransactions(it) } - // If this flow is resolving a specific transaction, make sure we have its attachments and then verify - // it as well, but don't insert to the database. Note that when we were given a SignedTransaction (stx != null) - // we *could* insert, because successful verification implies we have everything we need here, and it might - // be a clearer API if we do that. But for consistency with the other c'tor we currently do not. - // - // If 'stx' is set, then 'wtx' is the contents (from the c'tor). - val wtx = stx?.verifySignatures() ?: wtx - wtx?.let { - fetchMissingAttachments(listOf(it)) - val ltx = it.toLedgerTransaction(serviceHub) - ltx.verify() - result += ltx - } - - return result + return signedTransaction?.let { + result + it + } ?: result } @Suspendable - private fun downloadDependencies(depsToCheck: Set): Collection { + private fun downloadDependencies(depsToCheck: Set): List { // Maintain a work queue of all hashes to load/download, initialised with our starting set. Then do a breadth // first traversal across the dependency graph. // @@ -154,7 +123,6 @@ class ResolveTransactionsFlow(private val txHashes: Set, val resultQ = LinkedHashMap() val limit = transactionCountLimit - check(limit > 0) { "$limit is not a valid count limit" } var limitCounter = 0 while (nextRequests.isNotEmpty()) { // Don't re-download the same tx when we haven't verified it yet but it's referenced multiple times in the @@ -168,21 +136,18 @@ class ResolveTransactionsFlow(private val txHashes: Set, // Request the standalone transaction data (which may refer to things we don't yet have). val downloads: List = subFlow(FetchTransactionsFlow(notAlreadyFetched, otherSide)).downloaded - fetchMissingAttachments(downloads.map { it.tx }) - for (stx in downloads) check(resultQ.putIfAbsent(stx.id, stx) == null) // Assert checks the filter at the start. // Add all input states to the work queue. - val inputHashes = downloads.flatMap { it.tx.inputs }.map { it.txhash } + val inputHashes = downloads.flatMap { it.inputs }.map { it.txhash } nextRequests.addAll(inputHashes) - limitCounter = limitCounter checkedAdd nextRequests.size + limitCounter = limitCounter exactAdd nextRequests.size if (limitCounter > limit) throw ExcessivelyLargeTransactionGraph() } - - return resultQ.values + return resultQ.values.toList() } /** @@ -190,9 +155,10 @@ class ResolveTransactionsFlow(private val txHashes: Set, * first in the returned list and thus doesn't have any unverified dependencies. */ @Suspendable - private fun fetchMissingAttachments(downloads: List) { + private fun fetchMissingAttachments(downloads: List) { // TODO: This could be done in parallel with other fetches for extra speed. - val missingAttachments = downloads.flatMap { wtx -> + val wireTransactions = downloads.filterNot { it.isNotaryChangeTransaction() }.map { it.tx } + val missingAttachments = wireTransactions.flatMap { wtx -> wtx.attachments.filter { serviceHub.attachments.openAttachment(it) == null } } if (missingAttachments.isNotEmpty()) diff --git a/core/src/main/kotlin/net/corda/core/internal/ThreadBox.kt b/core/src/main/kotlin/net/corda/core/internal/ThreadBox.kt new file mode 100644 index 0000000000..eedd576694 --- /dev/null +++ b/core/src/main/kotlin/net/corda/core/internal/ThreadBox.kt @@ -0,0 +1,32 @@ +package net.corda.core.internal + +import java.util.concurrent.locks.ReentrantLock +import kotlin.concurrent.withLock + +/** + * A threadbox is a simple utility that makes it harder to forget to take a lock before accessing some shared state. + * Simply define a private class to hold the data that must be grouped under the same lock, and then pass the only + * instance to the ThreadBox constructor. You can now use the [locked] method with a lambda to take the lock in a + * way that ensures it'll be released if there's an exception. + * + * Note that this technique is not infallible: if you capture a reference to the fields in another lambda which then + * gets stored and invoked later, there may still be unsafe multi-threaded access going on, so watch out for that. + * This is just a simple guard rail that makes it harder to slip up. + * + * Example: + *``` + * private class MutableState { var i = 5 } + * private val state = ThreadBox(MutableState()) + * + * val ii = state.locked { i } + * ``` + */ +class ThreadBox(val content: T, val lock: ReentrantLock = ReentrantLock()) { + inline fun locked(body: T.() -> R): R = lock.withLock { body(content) } + inline fun alreadyLocked(body: T.() -> R): R { + check(lock.isHeldByCurrentThread, { "Expected $lock to already be locked." }) + return body(content) + } + + fun checkNotLocked(): Unit = check(!lock.isHeldByCurrentThread) +} diff --git a/core/src/main/kotlin/net/corda/core/internal/WriteOnceProperty.kt b/core/src/main/kotlin/net/corda/core/internal/WriteOnceProperty.kt new file mode 100644 index 0000000000..ad0ae9bc39 --- /dev/null +++ b/core/src/main/kotlin/net/corda/core/internal/WriteOnceProperty.kt @@ -0,0 +1,18 @@ +package net.corda.core.internal + +import kotlin.reflect.KProperty + +/** + * A write-once property to be used as delegate for Kotlin var properties. The expectation is that this is initialised + * prior to the spawning of any threads that may access it and so there's no need for it to be volatile. + */ +class WriteOnceProperty(private val defaultValue:T? = null) { + private var v: T? = defaultValue + + operator fun getValue(thisRef: Any?, property: KProperty<*>) = v ?: throw IllegalStateException("Write-once property $property not set.") + + operator fun setValue(thisRef: Any?, property: KProperty<*>, value: T) { + check(v == defaultValue || v === value) { "Cannot set write-once property $property more than once." } + v = value + } +} \ No newline at end of file diff --git a/core/src/main/kotlin/net/corda/core/internal/concurrent/CordaFutureImpl.kt b/core/src/main/kotlin/net/corda/core/internal/concurrent/CordaFutureImpl.kt new file mode 100644 index 0000000000..cc60db5386 --- /dev/null +++ b/core/src/main/kotlin/net/corda/core/internal/concurrent/CordaFutureImpl.kt @@ -0,0 +1,146 @@ +package net.corda.core.internal.concurrent + +import com.google.common.annotations.VisibleForTesting +import net.corda.core.concurrent.CordaFuture +import net.corda.core.concurrent.match +import net.corda.core.utilities.getOrThrow +import net.corda.core.utilities.loggerFor +import org.slf4j.Logger +import java.time.Duration +import java.util.concurrent.CompletableFuture +import java.util.concurrent.Executor +import java.util.concurrent.Future +import java.util.concurrent.TimeUnit + +/** @return a fresh [OpenFuture]. */ +fun openFuture(): OpenFuture = CordaFutureImpl() + +/** @return a done future with the given value as its outcome. */ +fun doneFuture(value: V): CordaFuture = CordaFutureImpl().apply { set(value) } + +/** @return a future that will have the same outcome as the given block, when this executor has finished running it. */ +fun Executor.fork(block: () -> V): CordaFuture = CordaFutureImpl().also { execute { it.capture(block) } } + +/** When this future is done, do [match]. */ +fun CordaFuture.thenMatch(success: (V) -> W, failure: (Throwable) -> X) = then { match(success, failure) } + +/** When this future is done and the outcome is failure, log the throwable. */ +fun CordaFuture<*>.andForget(log: Logger) = thenMatch({}, { log.error("Background task failed:", it) }) + +/** + * Returns a future that will have an outcome of applying the given transform to this future's value. + * But if this future fails, the transform is not invoked and the returned future becomes done with the same throwable. + */ +fun CordaFuture.map(transform: (V) -> W): CordaFuture = CordaFutureImpl().also { result -> + thenMatch({ + result.capture { transform(it) } + }, { + result.setException(it) + }) +} + +/** + * Returns a future that will have the same outcome as the future returned by the given transform. + * But if this future or the transform fails, the returned future's outcome is the same throwable. + * In the case where this future fails, the transform is not invoked. + */ +fun CordaFuture.flatMap(transform: (V) -> CordaFuture): CordaFuture = CordaFutureImpl().also { result -> + thenMatch(success@ { + result.captureLater(try { + transform(it) + } catch (t: Throwable) { + result.setException(t) + return@success + }) + }, { + result.setException(it) + }) +} + +/** + * If all of the given futures succeed, the returned future's outcome is a list of all their values. + * The values are in the same order as the futures in the collection, not the order of completion. + * If at least one given future fails, the returned future's outcome is the first throwable that was thrown. + * Any subsequent throwables are added to the first one as suppressed throwables, in the order they are thrown. + * If no futures were given, the returned future has an immediate outcome of empty list. + * Otherwise the returned future does not have an outcome until all given futures have an outcome. + * Unlike Guava's Futures.allAsList, this method never hides failures/hangs subsequent to the first failure. + */ +fun Collection>.transpose(): CordaFuture> { + if (isEmpty()) return doneFuture(emptyList()) + val transpose = CordaFutureImpl>() + val stateLock = Any() + var failure: Throwable? = null + var remaining = size + forEach { + it.then { doneFuture -> + synchronized(stateLock) { + doneFuture.match({}, { throwable -> + if (failure == null) failure = throwable else failure!!.addSuppressed(throwable) + }) + if (--remaining == 0) { + if (failure == null) transpose.set(map { it.getOrThrow() }) else transpose.setException(failure!!) + } + } + } + } + return transpose +} + +/** The contravariant members of [OpenFuture]. */ +interface ValueOrException { + /** @return whether this future actually changed. */ + fun set(value: V): Boolean + + /** @return whether this future actually changed. */ + fun setException(t: Throwable): Boolean + + /** When the given future has an outcome, make this future have the same outcome. */ + fun captureLater(f: CordaFuture) = f.then { capture { f.getOrThrow() } } + + /** Run the given block (in the foreground) and set this future to its outcome. */ + fun capture(block: () -> V): Boolean { + return set(try { + block() + } catch (t: Throwable) { + return setException(t) + }) + } +} + +/** A [CordaFuture] with additional methods to complete it with a value, exception or the outcome of another future. */ +interface OpenFuture : ValueOrException, CordaFuture + +/** Unless you really want this particular implementation, use [openFuture] to make one. */ +@VisibleForTesting +internal class CordaFutureImpl(private val impl: CompletableFuture = CompletableFuture()) : Future by impl, OpenFuture { + companion object { + private val defaultLog = loggerFor>() + internal val listenerFailedMessage = "Future listener failed:" + } + + override fun set(value: V) = impl.complete(value) + override fun setException(t: Throwable) = impl.completeExceptionally(t) + override fun then(callback: (CordaFuture) -> W) = thenImpl(defaultLog, callback) + /** For testing only. */ + internal fun thenImpl(log: Logger, callback: (CordaFuture) -> W) { + impl.whenComplete { _, _ -> + try { + callback(this) + } catch (t: Throwable) { + log.error(listenerFailedMessage, t) + } + } + } + + // We don't simply return impl so that the caller can't interfere with it. + override fun toCompletableFuture() = CompletableFuture().also { completable -> + thenMatch({ + completable.complete(it) + }, { + completable.completeExceptionally(it) + }) + } +} + +internal fun Future.get(timeout: Duration? = null): V = if (timeout == null) get() else get(timeout.toNanos(), TimeUnit.NANOSECONDS) diff --git a/core/src/main/kotlin/net/corda/core/messaging/CordaRPCOps.kt b/core/src/main/kotlin/net/corda/core/messaging/CordaRPCOps.kt index bf99ebe570..84e30e9140 100644 --- a/core/src/main/kotlin/net/corda/core/messaging/CordaRPCOps.kt +++ b/core/src/main/kotlin/net/corda/core/messaging/CordaRPCOps.kt @@ -1,7 +1,6 @@ package net.corda.core.messaging -import com.google.common.util.concurrent.ListenableFuture -import net.corda.core.contracts.Amount +import net.corda.core.concurrent.CordaFuture import net.corda.core.contracts.ContractState import net.corda.core.contracts.StateAndRef import net.corda.core.contracts.UpgradedContract @@ -53,10 +52,7 @@ sealed class StateMachineUpdate { @CordaSerializable data class StateMachineTransactionMapping(val stateMachineRunId: StateMachineRunId, val transactionId: SecureHash) -/** - * RPC operations that the node exposes to clients using the Java client library. These can be called from - * client apps and are implemented by the node in the [net.corda.node.internal.CordaRPCOpsImpl] class. - */ +/** RPC operations that the node exposes to clients. */ interface CordaRPCOps : RPCOps { /** * Returns the RPC protocol version, which is the same the node's Platform Version. Exists since version 1 so guaranteed @@ -70,9 +66,6 @@ interface CordaRPCOps : RPCOps { @RPCReturnsObservables fun stateMachinesFeed(): DataFeed, StateMachineUpdate> - @Deprecated("This function will be removed in a future milestone", ReplaceWith("stateMachinesFeed()")) - fun stateMachinesAndUpdates() = stateMachinesFeed() - /** * Returns a snapshot of vault states for a given query criteria (and optional order and paging specification) * @@ -137,45 +130,33 @@ interface CordaRPCOps : RPCOps { fun vaultTrackBy(criteria: QueryCriteria, paging: PageSpecification, sorting: Sort, - contractType: Class): DataFeed, Vault.Update> + contractType: Class): DataFeed, Vault.Update> // DOCEND VaultTrackByAPI // Note: cannot apply @JvmOverloads to interfaces nor interface implementations // Java Helpers // DOCSTART VaultTrackAPIHelpers - fun vaultTrack(contractType: Class): DataFeed, Vault.Update> { + fun vaultTrack(contractType: Class): DataFeed, Vault.Update> { return vaultTrackBy(QueryCriteria.VaultQueryCriteria(), PageSpecification(), Sort(emptySet()), contractType) } - fun vaultTrackByCriteria(contractType: Class, criteria: QueryCriteria): DataFeed, Vault.Update> { + fun vaultTrackByCriteria(contractType: Class, criteria: QueryCriteria): DataFeed, Vault.Update> { return vaultTrackBy(criteria, PageSpecification(), Sort(emptySet()), contractType) } - fun vaultTrackByWithPagingSpec(contractType: Class, criteria: QueryCriteria, paging: PageSpecification): DataFeed, Vault.Update> { + fun vaultTrackByWithPagingSpec(contractType: Class, criteria: QueryCriteria, paging: PageSpecification): DataFeed, Vault.Update> { return vaultTrackBy(criteria, paging, Sort(emptySet()), contractType) } - fun vaultTrackByWithSorting(contractType: Class, criteria: QueryCriteria, sorting: Sort): DataFeed, Vault.Update> { + fun vaultTrackByWithSorting(contractType: Class, criteria: QueryCriteria, sorting: Sort): DataFeed, Vault.Update> { return vaultTrackBy(criteria, PageSpecification(), sorting, contractType) } // DOCEND VaultTrackAPIHelpers - /** - * Returns a data feed of head states in the vault and an observable of future updates to the vault. - */ - @RPCReturnsObservables - // TODO: Remove this from the interface - @Deprecated("This function will be removed in a future milestone", ReplaceWith("vaultTrackBy(QueryCriteria())")) - fun vaultAndUpdates(): DataFeed>, Vault.Update> - /** * Returns a data feed of all recorded transactions and an observable of future recorded ones. */ @RPCReturnsObservables fun verifiedTransactionsFeed(): DataFeed, SignedTransaction> - @Deprecated("This function will be removed in a future milestone", ReplaceWith("verifiedTransactionFeed()")) - fun verifiedTransactions() = verifiedTransactionsFeed() - - /** * Returns a snapshot list of existing state machine id - recorded transaction hash mappings, and a stream of future * such mappings as well. @@ -183,18 +164,12 @@ interface CordaRPCOps : RPCOps { @RPCReturnsObservables fun stateMachineRecordedTransactionMappingFeed(): DataFeed, StateMachineTransactionMapping> - @Deprecated("This function will be removed in a future milestone", ReplaceWith("stateMachineRecordedTransactionMappingFeed()")) - fun stateMachineRecordedTransactionMapping() = stateMachineRecordedTransactionMappingFeed() - /** * Returns all parties currently visible on the network with their advertised services and an observable of future updates to the network. */ @RPCReturnsObservables fun networkMapFeed(): DataFeed, NetworkMapCache.MapChange> - @Deprecated("This function will be removed in a future milestone", ReplaceWith("networkMapFeed()")) - fun networkMapUpdates() = networkMapFeed() - /** * Start the given flow with the given arguments. [logicType] must be annotated with [net.corda.core.flows.StartableByRPC]. */ @@ -223,12 +198,6 @@ interface CordaRPCOps : RPCOps { */ fun getVaultTransactionNotes(txnId: SecureHash): Iterable - /* - * Returns a map of how much cash we have in each currency, ignoring details like issuer. Note: currencies for - * which we have no cash evaluate to null (not present in map), not 0. - */ - fun getCashBalances(): Map> - /** * Checks whether an attachment with the given hash is stored on the node. */ @@ -244,10 +213,6 @@ interface CordaRPCOps : RPCOps { */ fun uploadAttachment(jar: InputStream): SecureHash - // TODO: Remove this from the interface - @Deprecated("This service will be removed in a future milestone") - fun uploadFile(dataType: String, name: String?, file: InputStream): String - /** * Authorise a contract state upgrade. * This will store the upgrade authorisation in the vault, and will be queried by [ContractUpgradeFlow.Acceptor] during contract upgrade process. @@ -268,25 +233,28 @@ interface CordaRPCOps : RPCOps { fun currentNodeTime(): Instant /** - * Returns a [ListenableFuture] which completes when the node has registered wih the network map service. It can also + * Returns a [CordaFuture] which completes when the node has registered wih the network map service. It can also * complete with an exception if it is unable to. */ @RPCReturnsObservables - fun waitUntilRegisteredWithNetworkMap(): ListenableFuture + fun waitUntilRegisteredWithNetworkMap(): CordaFuture // TODO These need rethinking. Instead of these direct calls we should have a way of replicating a subset of // the node's state locally and query that directly. + /** + * Returns the well known identity from an abstract party. This is intended to resolve the well known identity + * from a confidential identity, however it transparently handles returning the well known identity back if + * a well known identity is passed in. + * + * @param party identity to determine well known identity for. + * @return well known identity, if found. + */ + fun partyFromAnonymous(party: AbstractParty): Party? /** * Returns the [Party] corresponding to the given key, if found. */ fun partyFromKey(key: PublicKey): Party? - /** - * Returns the [Party] with the given name as it's [Party.name] - */ - @Deprecated("Use partyFromX500Name instead") - fun partyFromName(name: String): Party? - /** * Returns the [Party] with the X.500 principal as it's [Party.name] */ @@ -321,7 +289,7 @@ inline fun CordaRPCOps.vaultQueryBy(criteria: QueryC inline fun CordaRPCOps.vaultTrackBy(criteria: QueryCriteria = QueryCriteria.VaultQueryCriteria(), paging: PageSpecification = PageSpecification(), - sorting: Sort = Sort(emptySet())): DataFeed, Vault.Update> { + sorting: Sort = Sort(emptySet())): DataFeed, Vault.Update> { return vaultTrackBy(criteria, paging, sorting, T::class.java) } @@ -437,13 +405,4 @@ inline fun > CordaRPCOps.startTrac * The Data feed contains a snapshot of the requested data and an [Observable] of future updates. */ @CordaSerializable -data class DataFeed(val snapshot: A, val updates: Observable) { - @Deprecated("This function will be removed in a future milestone", ReplaceWith("snapshot")) - val first: A get() = snapshot - @Deprecated("This function will be removed in a future milestone", ReplaceWith("updates")) - val second: Observable get() = updates - @Deprecated("This function will be removed in a future milestone", ReplaceWith("snapshot")) - val current: A get() = snapshot - @Deprecated("This function will be removed in a future milestone", ReplaceWith("updates")) - val future: Observable get() = updates -} +data class DataFeed(val snapshot: A, val updates: Observable) diff --git a/core/src/main/kotlin/net/corda/core/messaging/FlowHandle.kt b/core/src/main/kotlin/net/corda/core/messaging/FlowHandle.kt index cf10588939..75cdf0b74e 100644 --- a/core/src/main/kotlin/net/corda/core/messaging/FlowHandle.kt +++ b/core/src/main/kotlin/net/corda/core/messaging/FlowHandle.kt @@ -1,6 +1,6 @@ package net.corda.core.messaging -import com.google.common.util.concurrent.ListenableFuture +import net.corda.core.concurrent.CordaFuture import net.corda.core.flows.StateMachineRunId import net.corda.core.serialization.CordaSerializable import rx.Observable @@ -9,11 +9,11 @@ import rx.Observable * [FlowHandle] is a serialisable handle for the started flow, parameterised by the type of the flow's return value. * * @property id The started state machine's ID. - * @property returnValue A [ListenableFuture] of the flow's return value. + * @property returnValue A [CordaFuture] of the flow's return value. */ interface FlowHandle
: AutoCloseable { val id: StateMachineRunId - val returnValue: ListenableFuture + val returnValue: CordaFuture /** * Use this function for flows whose returnValue is not going to be used, so as to free up server resources. @@ -41,7 +41,7 @@ interface FlowProgressHandle : FlowHandle { @CordaSerializable data class FlowHandleImpl( override val id: StateMachineRunId, - override val returnValue: ListenableFuture) : FlowHandle { + override val returnValue: CordaFuture) : FlowHandle { // Remember to add @Throws to FlowHandle.close() if this throws an exception. override fun close() { @@ -52,7 +52,7 @@ data class FlowHandleImpl( @CordaSerializable data class FlowProgressHandleImpl( override val id: StateMachineRunId, - override val returnValue: ListenableFuture, + override val returnValue: CordaFuture, override val progress: Observable) : FlowProgressHandle { // Remember to add @Throws to FlowProgressHandle.close() if this throws an exception. diff --git a/core/src/main/kotlin/net/corda/core/node/CordaPluginRegistry.kt b/core/src/main/kotlin/net/corda/core/node/CordaPluginRegistry.kt index 89daa9939d..19b482622f 100644 --- a/core/src/main/kotlin/net/corda/core/node/CordaPluginRegistry.kt +++ b/core/src/main/kotlin/net/corda/core/node/CordaPluginRegistry.kt @@ -1,46 +1,18 @@ package net.corda.core.node +import net.corda.core.contracts.ContractState import net.corda.core.messaging.CordaRPCOps +import net.corda.core.node.services.VaultQueryService import net.corda.core.schemas.MappedSchema +import net.corda.core.schemas.QueryableState import net.corda.core.serialization.SerializationCustomization import java.util.function.Function -import net.corda.core.schemas.QueryableState -import net.corda.core.contracts.ContractState -import net.corda.core.node.services.VaultQueryService /** * Implement this interface on a class advertised in a META-INF/services/net.corda.core.node.CordaPluginRegistry file * to extend a Corda node with additional application services. */ abstract class CordaPluginRegistry { - - @Suppress("unused") - @Deprecated("This is no longer in use, moved to WebServerPluginRegistry class in webserver module", - level = DeprecationLevel.ERROR, replaceWith = ReplaceWith("net.corda.webserver.services.WebServerPluginRegistry")) - open val webApis: List> get() = emptyList() - - - @Suppress("unused") - @Deprecated("This is no longer in use, moved to WebServerPluginRegistry class in webserver module", - level = DeprecationLevel.ERROR, replaceWith = ReplaceWith("net.corda.webserver.services.WebServerPluginRegistry")) - open val staticServeDirs: Map get() = emptyMap() - - @Suppress("unused") - @Deprecated("This is no longer needed. Instead annotate any flows that need to be invoked via RPC with " + - "@StartableByRPC and any scheduled flows with @SchedulableFlow", level = DeprecationLevel.ERROR) - open val requiredFlows: Map> get() = emptyMap() - - /** - * List of lambdas constructing additional long lived services to be hosted within the node. - * They expect a single [PluginServiceHub] parameter as input. - * The [PluginServiceHub] will be fully constructed before the plugin service is created and will - * allow access to the Flow factory and Flow initiation entry points there. - */ - @Suppress("unused") - @Deprecated("This is no longer used. If you need to create your own service, such as an oracle, then use the " + - "@CordaService annotation. For flow registrations use @InitiatedBy.", level = DeprecationLevel.ERROR) - open val servicePlugins: List> get() = emptyList() - /** * Optionally whitelist types for use in object serialization, as we lock down the types that can be serialized. * diff --git a/core/src/main/kotlin/net/corda/core/node/NodeInfo.kt b/core/src/main/kotlin/net/corda/core/node/NodeInfo.kt index 342e9997e1..7b71cdcde2 100644 --- a/core/src/main/kotlin/net/corda/core/node/NodeInfo.kt +++ b/core/src/main/kotlin/net/corda/core/node/NodeInfo.kt @@ -6,6 +6,7 @@ import net.corda.core.node.services.ServiceInfo import net.corda.core.node.services.ServiceType import net.corda.core.serialization.CordaSerializable import net.corda.core.utilities.NetworkHostAndPort +import net.corda.core.utilities.NonEmptySet /** * Information for an advertised service including the service specific identity information. @@ -21,7 +22,7 @@ data class ServiceEntry(val info: ServiceInfo, val identity: PartyAndCertificate @CordaSerializable data class NodeInfo(val addresses: List, val legalIdentityAndCert: PartyAndCertificate, //TODO This field will be removed in future PR which gets rid of services. - val legalIdentitiesAndCerts: Set, + val legalIdentitiesAndCerts: NonEmptySet, val platformVersion: Int, var advertisedServices: List = emptyList(), val worldMapLocation: WorldMapLocation? = null) { diff --git a/core/src/main/kotlin/net/corda/core/node/PhysicalLocationStructures.kt b/core/src/main/kotlin/net/corda/core/node/PhysicalLocationStructures.kt index 049845432f..754fbbc44e 100644 --- a/core/src/main/kotlin/net/corda/core/node/PhysicalLocationStructures.kt +++ b/core/src/main/kotlin/net/corda/core/node/PhysicalLocationStructures.kt @@ -3,6 +3,8 @@ package net.corda.core.node import net.corda.core.serialization.CordaSerializable import java.util.* +data class ScreenCoordinate(val screenX: Double, val screenY: Double) + /** A latitude/longitude pair. */ @CordaSerializable data class WorldCoordinate(val latitude: Double, val longitude: Double) { @@ -21,7 +23,7 @@ data class WorldCoordinate(val latitude: Double, val longitude: Double) { */ @Suppress("unused") // Used from the visualiser GUI. fun project(screenWidth: Double, screenHeight: Double, topLatitude: Double, bottomLatitude: Double, - leftLongitude: Double, rightLongitude: Double): Pair { + leftLongitude: Double, rightLongitude: Double): ScreenCoordinate { require(latitude in bottomLatitude..topLatitude) require(longitude in leftLongitude..rightLongitude) @@ -33,7 +35,7 @@ data class WorldCoordinate(val latitude: Double, val longitude: Double) { val topLatRel = screenYRelative(topLatitude) val bottomLatRel = screenYRelative(bottomLatitude) fun latitudeToScreenY(lat: Double) = screenHeight * (screenYRelative(lat) - topLatRel) / (bottomLatRel - topLatRel) - return Pair(longitudeToScreenX(longitude), latitudeToScreenY(latitude)) + return ScreenCoordinate(longitudeToScreenX(longitude), latitudeToScreenY(latitude)) } } diff --git a/core/src/main/kotlin/net/corda/core/node/ServiceHub.kt b/core/src/main/kotlin/net/corda/core/node/ServiceHub.kt index 081aea4d7f..0a885130d9 100644 --- a/core/src/main/kotlin/net/corda/core/node/ServiceHub.kt +++ b/core/src/main/kotlin/net/corda/core/node/ServiceHub.kt @@ -1,13 +1,16 @@ package net.corda.core.node -import com.google.common.collect.Lists import net.corda.core.contracts.* -import net.corda.core.crypto.DigitalSignature +import net.corda.core.crypto.Crypto +import net.corda.core.crypto.SignableData +import net.corda.core.crypto.SignatureMetadata +import net.corda.core.crypto.TransactionSignature import net.corda.core.node.services.* import net.corda.core.serialization.SerializeAsToken import net.corda.core.transactions.SignedTransaction import net.corda.core.transactions.TransactionBuilder import java.security.PublicKey +import java.sql.Connection import java.time.Clock /** @@ -76,7 +79,7 @@ interface ServiceHub : ServicesForResolution { * further processing. This is expected to be run within a database transaction. */ fun recordTransactions(first: SignedTransaction, vararg remaining: SignedTransaction) { - recordTransactions(Lists.asList(first, remaining)) + recordTransactions(listOf(first, *remaining)) } /** @@ -87,7 +90,10 @@ interface ServiceHub : ServicesForResolution { @Throws(TransactionResolutionException::class) override fun loadState(stateRef: StateRef): TransactionState<*> { val stx = validatedTransactions.getTransaction(stateRef.txhash) ?: throw TransactionResolutionException(stateRef.txhash) - return stx.tx.outputs[stateRef.index] + return if (stx.isNotaryChangeTransaction()) { + stx.resolveNotaryChangeTransaction(this).outputs[stateRef.index] + } + else stx.tx.outputs[stateRef.index] } /** @@ -98,7 +104,12 @@ interface ServiceHub : ServicesForResolution { @Throws(TransactionResolutionException::class) fun toStateAndRef(stateRef: StateRef): StateAndRef { val stx = validatedTransactions.getTransaction(stateRef.txhash) ?: throw TransactionResolutionException(stateRef.txhash) - return stx.tx.outRef(stateRef.index) + return if (stx.isNotaryChangeTransaction()) { + stx.resolveNotaryChangeTransaction(this).outRef(stateRef.index) + } + else { + stx.tx.outRef(stateRef.index) + } } /** @@ -115,7 +126,7 @@ interface ServiceHub : ServicesForResolution { /** * Helper property to shorten code for fetching the the [PublicKey] portion of the * Node's Notary signing identity. It is required that the Node hosts a notary service, - * otherwise an IllegalArgumentException will be thrown. + * otherwise an [IllegalArgumentException] will be thrown. * Typical use is during signing in flows and for unit test signing. * When this [PublicKey] is passed into the signing methods below, or on the KeyManagementService * the matching [java.security.PrivateKey] will be looked up internally and used to sign. @@ -124,9 +135,14 @@ interface ServiceHub : ServicesForResolution { */ val notaryIdentityKey: PublicKey get() = this.myInfo.notaryIdentity.owningKey + // Helper method to construct an initial partially signed transaction from a [TransactionBuilder]. + private fun signInitialTransaction(builder: TransactionBuilder, publicKey: PublicKey, signatureMetadata: SignatureMetadata): SignedTransaction { + return builder.toSignedTransaction(keyManagementService, publicKey, signatureMetadata) + } + /** * Helper method to construct an initial partially signed transaction from a [TransactionBuilder] - * using keys stored inside the node. + * using keys stored inside the node. Signature metadata is added automatically. * @param builder The [TransactionBuilder] to seal with the node's signature. * Any existing signatures on the builder will be preserved. * @param publicKey The [PublicKey] matched to the internal [java.security.PrivateKey] to use in signing this transaction. @@ -134,15 +150,12 @@ interface ServiceHub : ServicesForResolution { * to sign with. * @return Returns a SignedTransaction with the new node signature attached. */ - fun signInitialTransaction(builder: TransactionBuilder, publicKey: PublicKey): SignedTransaction { - val sig = keyManagementService.sign(builder.toWireTransaction().id.bytes, publicKey) - builder.addSignatureUnchecked(sig) - return builder.toSignedTransaction(false) - } + fun signInitialTransaction(builder: TransactionBuilder, publicKey: PublicKey) = + signInitialTransaction(builder, publicKey, SignatureMetadata(myInfo.platformVersion, Crypto.findSignatureScheme(publicKey).schemeNumberID)) /** * Helper method to construct an initial partially signed transaction from a TransactionBuilder - * using the default identity key contained in the node. + * using the default identity key contained in the node. The legal Indentity key is used to sign. * @param builder The TransactionBuilder to seal with the node's signature. * Any existing signatures on the builder will be preserved. * @return Returns a SignedTransaction with the new node signature attached. @@ -167,25 +180,30 @@ interface ServiceHub : ServicesForResolution { return stx } + // Helper method to create an additional signature for an existing (partially) [SignedTransaction]. + private fun createSignature(signedTransaction: SignedTransaction, publicKey: PublicKey, signatureMetadata: SignatureMetadata): TransactionSignature { + val signableData = SignableData(signedTransaction.id, signatureMetadata) + return keyManagementService.sign(signableData, publicKey) + } + /** * Helper method to create an additional signature for an existing (partially) [SignedTransaction]. * @param signedTransaction The [SignedTransaction] to which the signature will apply. * @param publicKey The [PublicKey] matching to a signing [java.security.PrivateKey] hosted in the node. * If the [PublicKey] is actually a [net.corda.core.crypto.CompositeKey] the first leaf key found locally will be used * for signing. - * @return The [DigitalSignature.WithKey] generated by signing with the internally held [java.security.PrivateKey]. + * @return The [TransactionSignature] generated by signing with the internally held [java.security.PrivateKey]. */ - fun createSignature(signedTransaction: SignedTransaction, publicKey: PublicKey): DigitalSignature.WithKey { - return keyManagementService.sign(signedTransaction.id.bytes, publicKey) - } + fun createSignature(signedTransaction: SignedTransaction, publicKey: PublicKey) = + createSignature(signedTransaction, publicKey, SignatureMetadata(myInfo.platformVersion, Crypto.findSignatureScheme(publicKey).schemeNumberID)) /** - * Helper method to create an additional signature for an existing (partially) SignedTransaction - * using the default identity signing key of the node. + * Helper method to create an additional signature for an existing (partially) [SignedTransaction] + * using the default identity signing key of the node. The legal identity key is used to sign. * @param signedTransaction The SignedTransaction to which the signature will apply. * @return The DigitalSignature.WithKey generated by signing with the internally held identity PrivateKey. */ - fun createSignature(signedTransaction: SignedTransaction): DigitalSignature.WithKey { + fun createSignature(signedTransaction: SignedTransaction): TransactionSignature { return createSignature(signedTransaction, legalIdentityKey) } @@ -202,10 +220,21 @@ interface ServiceHub : ServicesForResolution { } /** - * Helper method to ap-pend an additional signature for an existing (partially) [SignedTransaction] + * Helper method to append an additional signature for an existing (partially) [SignedTransaction] * using the default identity signing key of the node. * @param signedTransaction The [SignedTransaction] to which the signature will be added. * @return A new [SignedTransaction] with the addition of the new signature. */ fun addSignature(signedTransaction: SignedTransaction): SignedTransaction = addSignature(signedTransaction, legalIdentityKey) + + /** + * Exposes a JDBC connection (session) object using the currently configured database. + * Applications can use this to execute arbitrary SQL queries (native, direct, prepared, callable) + * against its Node database tables (including custom contract tables defined by extending [Queryable]). + * When used within a flow, this session automatically forms part of the enclosing flow transaction boundary, + * and thus queryable data will include everything committed as of the last checkpoint. + * @throws IllegalStateException if called outside of a transaction. + * @return A new [Connection] + */ + fun jdbcSession(): Connection } \ No newline at end of file diff --git a/core/src/main/kotlin/net/corda/core/node/services/IdentityService.kt b/core/src/main/kotlin/net/corda/core/node/services/IdentityService.kt index 3e35feb6f2..18084c66d9 100644 --- a/core/src/main/kotlin/net/corda/core/node/services/IdentityService.kt +++ b/core/src/main/kotlin/net/corda/core/node/services/IdentityService.kt @@ -19,25 +19,23 @@ interface IdentityService { val caCertStore: CertStore /** - * Verify and then store a well known identity. + * Verify and then store an identity. * - * @param party a party representing a legal entity. - * @throws IllegalArgumentException if the certificate path is invalid, or if there is already an existing - * certificate chain for the anonymous party. + * @param party a party representing a legal entity and the certificate path linking them to the network trust root. + * @throws IllegalArgumentException if the certificate path is invalid. */ @Throws(CertificateExpiredException::class, CertificateNotYetValidException::class, InvalidAlgorithmParameterException::class) + @Deprecated("Use verifyAndRegisterIdentity() instead, which is the same function with a better name") fun registerIdentity(party: PartyAndCertificate) /** * Verify and then store an identity. * - * @param anonymousParty a party representing a legal entity in a transaction. - * @param path certificate path from the trusted root to the party. - * @throws IllegalArgumentException if the certificate path is invalid, or if there is already an existing - * certificate chain for the anonymous party. + * @param identity a party representing a legal entity and the certificate path linking them to the network trust root. + * @throws IllegalArgumentException if the certificate path is invalid. */ @Throws(CertificateExpiredException::class, CertificateNotYetValidException::class, InvalidAlgorithmParameterException::class) - fun registerAnonymousIdentity(anonymousParty: AnonymousParty, party: Party, path: CertPath) + fun verifyAndRegisterIdentity(identity: PartyAndCertificate) /** * Asserts that an anonymous party maps to the given full party, by looking up the certificate chain associated with @@ -54,6 +52,13 @@ interface IdentityService { */ fun getAllIdentities(): Iterable + /** + * Get the certificate and path for a well known identity's owning key. + * + * @return the party and certificate, or null if unknown. + */ + fun certificateFromKey(owningKey: PublicKey): PartyAndCertificate? + /** * Get the certificate and path for a well known identity. * @@ -66,15 +71,16 @@ interface IdentityService { // but for now this is not supported. fun partyFromKey(key: PublicKey): Party? - @Deprecated("Use partyFromX500Name or partiesFromName") - fun partyFromName(name: String): Party? + fun partyFromX500Name(principal: X500Name): Party? /** - * Resolve the well known identity of a party. If the party passed in is already a well known identity - * (i.e. a [Party]) this returns it as-is. + * Returns the well known identity from an abstract party. This is intended to resolve the well known identity + * from a confidential identity, however it transparently handles returning the well known identity back if + * a well known identity is passed in. * - * @return the well known identity, or null if unknown. + * @param party identity to determine well known identity for. + * @return well known identity, if found. */ fun partyFromAnonymous(party: AbstractParty): Party? @@ -95,11 +101,6 @@ interface IdentityService { */ fun requirePartyFromAnonymous(party: AbstractParty): Party - /** - * Get the certificate chain showing an anonymous party is owned by the given party. - */ - fun pathForAnonymous(anonymousParty: AnonymousParty): CertPath? - /** * Returns a list of candidate matches for a given string, with optional fuzzy(ish) matching. Fuzzy matching may * get smarter with time e.g. to correct spelling errors, so you should not hard-code indexes into the results diff --git a/core/src/main/kotlin/net/corda/core/node/services/KeyManagementService.kt b/core/src/main/kotlin/net/corda/core/node/services/KeyManagementService.kt new file mode 100644 index 0000000000..ed2d106bae --- /dev/null +++ b/core/src/main/kotlin/net/corda/core/node/services/KeyManagementService.kt @@ -0,0 +1,66 @@ +package net.corda.core.node.services + +import co.paralleluniverse.fibers.Suspendable +import net.corda.core.crypto.DigitalSignature +import net.corda.core.crypto.SignableData +import net.corda.core.crypto.TransactionSignature +import net.corda.core.identity.PartyAndCertificate +import java.security.PublicKey + +/** + * The KMS is responsible for storing and using private keys to sign things. An implementation of this may, for example, + * call out to a hardware security module that enforces various auditing and frequency-of-use requirements. + */ +interface KeyManagementService { + /** + * Returns a snapshot of the current signing [PublicKey]s. + * For each of these keys a [PrivateKey] is available, that can be used later for signing. + */ + val keys: Set + + /** + * Generates a new random [KeyPair] and adds it to the internal key storage. Returns the public part of the pair. + */ + @Suspendable + fun freshKey(): PublicKey + + /** + * Generates a new random [KeyPair], adds it to the internal key storage, then generates a corresponding + * [X509Certificate] and adds it to the identity service. + * + * @param identity identity to generate a key and certificate for. Must be an identity this node has CA privileges for. + * @param revocationEnabled whether to check revocation status of certificates in the certificate path. + * @return X.509 certificate and path to the trust root. + */ + @Suspendable + fun freshKeyAndCert(identity: PartyAndCertificate, revocationEnabled: Boolean): PartyAndCertificate + + /** + * Filter some keys down to the set that this node owns (has private keys for). + * + * @param candidateKeys keys which this node may own. + */ + fun filterMyKeys(candidateKeys: Iterable): Iterable + + /** + * Using the provided signing [PublicKey] internally looks up the matching [PrivateKey] and signs the data. + * @param bytes The data to sign over using the chosen key. + * @param publicKey The [PublicKey] partner to an internally held [PrivateKey], either derived from the node's primary identity, + * or previously generated via the [freshKey] method. + * If the [PublicKey] is actually a [CompositeKey] the first leaf signing key hosted by the node is used. + * @throws IllegalArgumentException if the input key is not a member of [keys]. + */ + @Suspendable + fun sign(bytes: ByteArray, publicKey: PublicKey): DigitalSignature.WithKey + + /** + * Using the provided signing [PublicKey] internally looks up the matching [PrivateKey] and signs the [SignableData]. + * @param signableData a wrapper over transaction id (Merkle root) and signature metadata. + * @param publicKey The [PublicKey] partner to an internally held [PrivateKey], either derived from the node's primary identity, + * or previously generated via the [freshKey] method. + * If the [PublicKey] is actually a [CompositeKey] the first leaf signing key hosted by the node is used. + * @throws IllegalArgumentException if the input key is not a member of [keys]. + */ + @Suspendable + fun sign(signableData: SignableData, publicKey: PublicKey): TransactionSignature +} diff --git a/core/src/main/kotlin/net/corda/core/node/services/NetworkMapCache.kt b/core/src/main/kotlin/net/corda/core/node/services/NetworkMapCache.kt index cda2fc7c5f..33fbd1f10e 100644 --- a/core/src/main/kotlin/net/corda/core/node/services/NetworkMapCache.kt +++ b/core/src/main/kotlin/net/corda/core/node/services/NetworkMapCache.kt @@ -1,13 +1,12 @@ package net.corda.core.node.services -import com.google.common.util.concurrent.ListenableFuture +import net.corda.core.concurrent.CordaFuture import net.corda.core.contracts.Contract import net.corda.core.identity.AbstractParty import net.corda.core.identity.Party +import net.corda.core.internal.randomOrNull import net.corda.core.messaging.DataFeed import net.corda.core.node.NodeInfo -import net.corda.core.node.ServiceHub -import net.corda.core.randomOrNull import net.corda.core.serialization.CordaSerializable import org.bouncycastle.asn1.x500.X500Name import rx.Observable @@ -45,7 +44,7 @@ interface NetworkMapCache { /** Tracks changes to the network map cache */ val changed: Observable /** Future to track completion of the NetworkMapService registration. */ - val mapServiceRegistered: ListenableFuture + val mapServiceRegistered: CordaFuture /** * Atomically get the current party nodes and a stream of updates. Note that the Observable buffers updates until the @@ -98,9 +97,9 @@ interface NetworkMapCache { /** Gets a notary identity by the given name. */ fun getNotary(principal: X500Name): Party? { - val notaryNode = notaryNodes.randomOrNull { + val notaryNode = notaryNodes.filter { it.advertisedServices.any { it.info.type.isSubTypeOf(ServiceType.notary) && it.info.name == principal } - } + }.randomOrNull() return notaryNode?.notaryIdentity } @@ -120,6 +119,20 @@ interface NetworkMapCache { return nodes.randomOrNull()?.notaryIdentity } + /** + * Returns a service identity advertised by one of the nodes on the network + * @param type Specifies the type of the service + */ + fun getAnyServiceOfType(type: ServiceType): Party? { + for (node in partyNodes) { + val serviceIdentities = node.serviceIdentities(type) + if (serviceIdentities.isNotEmpty()) { + return serviceIdentities.randomOrNull() + } + } + return null; + } + /** Checks whether a given party is an advertised notary identity */ fun isNotary(party: Party): Boolean = notaryNodes.any { it.notaryIdentity == party } diff --git a/core/src/main/kotlin/net/corda/core/node/services/NotaryService.kt b/core/src/main/kotlin/net/corda/core/node/services/NotaryService.kt index aa3def742c..03b6803124 100644 --- a/core/src/main/kotlin/net/corda/core/node/services/NotaryService.kt +++ b/core/src/main/kotlin/net/corda/core/node/services/NotaryService.kt @@ -2,17 +2,16 @@ package net.corda.core.node.services import net.corda.core.contracts.StateRef import net.corda.core.contracts.TimeWindow -import net.corda.core.crypto.DigitalSignature -import net.corda.core.crypto.SecureHash -import net.corda.core.crypto.SignedData +import net.corda.core.crypto.* import net.corda.core.flows.FlowLogic +import net.corda.core.flows.NotaryError +import net.corda.core.flows.NotaryException +import net.corda.core.flows.NotaryFlow import net.corda.core.identity.Party import net.corda.core.node.ServiceHub import net.corda.core.serialization.SingletonSerializeAsToken import net.corda.core.serialization.serialize import net.corda.core.utilities.loggerFor -import net.corda.flows.NotaryError -import net.corda.flows.NotaryException import org.slf4j.Logger abstract class NotaryService : SingletonSerializeAsToken() { @@ -23,11 +22,9 @@ abstract class NotaryService : SingletonSerializeAsToken() { /** * Produces a notary service flow which has the corresponding sends and receives as [NotaryFlow.Client]. - * The first parameter is the client [Party] making the request and the second is the platform version - * of the client's node. Use this version parameter to provide backwards compatibility if the notary flow protocol - * changes. + * @param otherParty client [Party] making the request */ - abstract fun createServiceFlow(otherParty: Party, platformVersion: Int): FlowLogic + abstract fun createServiceFlow(otherParty: Party): FlowLogic } /** @@ -75,4 +72,9 @@ abstract class TrustedAuthorityNotaryService : NotaryService() { fun sign(bits: ByteArray): DigitalSignature.WithKey { return services.keyManagementService.sign(bits, services.notaryIdentityKey) } + + fun sign(txId: SecureHash): TransactionSignature { + val signableData = SignableData(txId, SignatureMetadata(services.myInfo.platformVersion, Crypto.findSignatureScheme(services.notaryIdentityKey).schemeNumberID)) + return services.keyManagementService.sign(signableData, services.notaryIdentityKey) + } } \ No newline at end of file diff --git a/core/src/main/kotlin/net/corda/core/node/services/Services.kt b/core/src/main/kotlin/net/corda/core/node/services/Services.kt deleted file mode 100644 index 6d1a486105..0000000000 --- a/core/src/main/kotlin/net/corda/core/node/services/Services.kt +++ /dev/null @@ -1,545 +0,0 @@ -package net.corda.core.node.services - -import co.paralleluniverse.fibers.Suspendable -import com.google.common.util.concurrent.ListenableFuture -import net.corda.core.contracts.* -import net.corda.core.crypto.composite.CompositeKey -import net.corda.core.crypto.DigitalSignature -import net.corda.core.crypto.SecureHash -import net.corda.core.crypto.keys -import net.corda.core.flows.FlowException -import net.corda.core.identity.AbstractParty -import net.corda.core.identity.Party -import net.corda.core.identity.PartyAndCertificate -import net.corda.core.messaging.DataFeed -import net.corda.core.node.services.vault.PageSpecification -import net.corda.core.node.services.vault.QueryCriteria -import net.corda.core.node.services.vault.Sort -import net.corda.core.node.services.vault.DEFAULT_PAGE_SIZE -import net.corda.core.serialization.CordaSerializable -import net.corda.core.utilities.OpaqueBytes -import net.corda.core.toFuture -import net.corda.core.transactions.LedgerTransaction -import net.corda.core.transactions.TransactionBuilder -import net.corda.core.transactions.WireTransaction -import net.corda.flows.AnonymisedIdentity -import rx.Observable -import rx.subjects.PublishSubject -import java.io.InputStream -import java.security.PublicKey -import java.security.cert.X509Certificate -import java.time.Instant -import java.util.* - -/** - * Session ID to use for services listening for the first message in a session (before a - * specific session ID has been established). - */ -val DEFAULT_SESSION_ID = 0L - -/** - * This file defines various 'services' which are not currently fleshed out. A service is a module that provides - * immutable snapshots of data that may be changing in response to user or network events. - */ - -/** - * A vault (name may be temporary) wraps a set of states that are useful for us to keep track of, for instance, - * because we own them. This class represents an immutable, stable state of a vault: it is guaranteed not to - * change out from underneath you, even though the canonical currently-best-known vault may change as we learn - * about new transactions from our peers and generate new transactions that consume states ourselves. - * - * This abstract class has no references to Cash contracts. - * - * [states] Holds a [VaultService] queried subset of states that are *active* and *relevant*. - * Active means they haven't been consumed yet (or we don't know about it). - * Relevant means they contain at least one of our pubkeys. - */ -@CordaSerializable -class Vault(val states: Iterable>) { - - /** - * Represents an update observed by the vault that will be notified to observers. Include the [StateRef]s of - * transaction outputs that were consumed (inputs) and the [ContractState]s produced (outputs) to/by the transaction - * or transactions observed and the vault. - * - * If the vault observes multiple transactions simultaneously, where some transactions consume the outputs of some of the - * other transactions observed, then the changes are observed "net" of those. - */ - @CordaSerializable - data class Update(val consumed: Set>, val produced: Set>, val flowId: UUID? = null) { - /** Checks whether the update contains a state of the specified type. */ - inline fun containsType() = consumed.any { it.state.data is T } || produced.any { it.state.data is T } - - /** Checks whether the update contains a state of the specified type and state status */ - fun containsType(clazz: Class, status: StateStatus) = - when (status) { - StateStatus.UNCONSUMED -> produced.any { clazz.isAssignableFrom(it.state.data.javaClass) } - StateStatus.CONSUMED -> consumed.any { clazz.isAssignableFrom(it.state.data.javaClass) } - else -> consumed.any { clazz.isAssignableFrom(it.state.data.javaClass) } - || produced.any { clazz.isAssignableFrom(it.state.data.javaClass) } - } - - /** - * Combine two updates into a single update with the combined inputs and outputs of the two updates but net - * any outputs of the left-hand-side (this) that are consumed by the inputs of the right-hand-side (rhs). - * - * i.e. the net effect in terms of state live-ness of receiving the combined update is the same as receiving this followed by rhs. - */ - operator fun plus(rhs: Update): Update { - val combined = Vault.Update( - consumed + (rhs.consumed - produced), - // The ordering below matters to preserve ordering of consumed/produced Sets when they are insertion order dependent implementations. - produced.filter { it !in rhs.consumed }.toSet() + rhs.produced) - return combined - } - - override fun toString(): String { - val sb = StringBuilder() - sb.appendln("${consumed.size} consumed, ${produced.size} produced") - sb.appendln("") - sb.appendln("Produced:") - produced.forEach { - sb.appendln("${it.ref}: ${it.state}") - } - return sb.toString() - } - } - - companion object { - val NoUpdate = Update(emptySet(), emptySet()) - } - - @CordaSerializable - enum class StateStatus { - UNCONSUMED, CONSUMED, ALL - } - - /** - * Returned in queries [VaultService.queryBy] and [VaultService.trackBy]. - * A Page contains: - * 1) a [List] of actual [StateAndRef] requested by the specified [QueryCriteria] to a maximum of [MAX_PAGE_SIZE] - * 2) a [List] of associated [Vault.StateMetadata], one per [StateAndRef] result - * 3) a total number of states that met the given [QueryCriteria] if a [PageSpecification] was provided - * (otherwise defaults to -1) - * 4) Status types used in this query: UNCONSUMED, CONSUMED, ALL - * 5) Other results as a [List] of any type (eg. aggregate function results with/without group by) - * - * Note: currently otherResults are used only for Aggregate Functions (in which case, the states and statesMetadata - * results will be empty) - */ - @CordaSerializable - data class Page(val states: List>, - val statesMetadata: List, - val totalStatesAvailable: Long, - val stateTypes: StateStatus, - val otherResults: List) - - @CordaSerializable - data class StateMetadata(val ref: StateRef, - val contractStateClassName: String, - val recordedTime: Instant, - val consumedTime: Instant?, - val status: Vault.StateStatus, - val notaryName: String, - val notaryKey: String, - val lockId: String?, - val lockUpdateTime: Instant?) -} - -/** - * A [VaultService] is responsible for securely and safely persisting the current state of a vault to storage. The - * vault service vends immutable snapshots of the current vault for working with: if you build a transaction based - * on a vault that isn't current, be aware that it may end up being invalid if the states that were used have been - * consumed by someone else first! - * - * Note that transactions we've seen are held by the storage service, not the vault. - */ -interface VaultService { - - /** - * Prefer the use of [updates] unless you know why you want to use this instead. - * - * Get a synchronous Observable of updates. When observations are pushed to the Observer, the Vault will already incorporate - * the update, and the database transaction associated with the update will still be open and current. If for some - * reason the processing crosses outside of the database transaction (for example, the update is pushed outside the current - * JVM or across to another [Thread] which is executing in a different database transaction) then the Vault may - * not incorporate the update due to racing with committing the current database transaction. - */ - val rawUpdates: Observable - - /** - * Get a synchronous Observable of updates. When observations are pushed to the Observer, the Vault will already incorporate - * the update, and the database transaction associated with the update will have been committed and closed. - */ - val updates: Observable - - /** - * Enable creation of observables of updates. - */ - val updatesPublisher: PublishSubject - - /** - * Returns a map of how much cash we have in each currency, ignoring details like issuer. Note: currencies for - * which we have no cash evaluate to null (not present in map), not 0. - */ - val cashBalances: Map> - - /** - * Atomically get the current vault and a stream of updates. Note that the Observable buffers updates until the - * first subscriber is registered so as to avoid racing with early updates. - */ - // TODO: Remove this from the interface - @Deprecated("This function will be removed in a future milestone", ReplaceWith("trackBy(QueryCriteria())")) - fun track(): DataFeed, Vault.Update> - - /** - * Return unconsumed [ContractState]s for a given set of [StateRef]s - */ - // TODO: Remove this from the interface - @Deprecated("This function will be removed in a future milestone", ReplaceWith("queryBy(VaultQueryCriteria(stateRefs = listOf()))")) - fun statesForRefs(refs: List): Map?> - - /** - * Possibly update the vault by marking as spent states that these transactions consume, and adding any relevant - * new states that they create. You should only insert transactions that have been successfully verified here! - * - * TODO: Consider if there's a good way to enforce the must-be-verified requirement in the type system. - */ - fun notifyAll(txns: Iterable) - - /** Same as notifyAll but with a single transaction. */ - fun notify(tx: WireTransaction) = notifyAll(listOf(tx)) - - /** - * Provide a [Future] for when a [StateRef] is consumed, which can be very useful in building tests. - */ - fun whenConsumed(ref: StateRef): ListenableFuture { - return updates.filter { it.consumed.any { it.ref == ref } }.toFuture() - } - - /** Get contracts we would be willing to upgrade the suggested contract to. */ - // TODO: We need a better place to put business logic functions - fun getAuthorisedContractUpgrade(ref: StateRef): Class>? - - /** - * Authorise a contract state upgrade. - * This will store the upgrade authorisation in the vault, and will be queried by [ContractUpgradeFlow.Acceptor] during contract upgrade process. - * Invoking this method indicate the node is willing to upgrade the [state] using the [upgradedContractClass]. - * This method will NOT initiate the upgrade process. To start the upgrade process, see [ContractUpgradeFlow.Instigator]. - */ - fun authoriseContractUpgrade(stateAndRef: StateAndRef<*>, upgradedContractClass: Class>) - - /** - * Authorise a contract state upgrade. - * This will remove the upgrade authorisation from the vault. - */ - fun deauthoriseContractUpgrade(stateAndRef: StateAndRef<*>) - - /** - * Add a note to an existing [LedgerTransaction] given by its unique [SecureHash] id - * Multiple notes may be attached to the same [LedgerTransaction]. - * These are additively and immutably persisted within the node local vault database in a single textual field - * using a semi-colon separator - */ - fun addNoteToTransaction(txnId: SecureHash, noteText: String) - - fun getTransactionNotes(txnId: SecureHash): Iterable - - /** - * Generate a transaction that moves an amount of currency to the given pubkey. - * - * Note: an [Amount] of [Currency] is only fungible for a given Issuer Party within a [FungibleAsset] - * - * @param tx A builder, which may contain inputs, outputs and commands already. The relevant components needed - * to move the cash will be added on top. - * @param amount How much currency to send. - * @param to a key of the recipient. - * @param onlyFromParties if non-null, the asset states will be filtered to only include those issued by the set - * of given parties. This can be useful if the party you're trying to pay has expectations - * about which type of asset claims they are willing to accept. - * @return A [Pair] of the same transaction builder passed in as [tx], and the list of keys that need to sign - * the resulting transaction for it to be valid. - * @throws InsufficientBalanceException when a cash spending transaction fails because - * there is insufficient quantity for a given currency (and optionally set of Issuer Parties). - */ - @Throws(InsufficientBalanceException::class) - @Suspendable - fun generateSpend(tx: TransactionBuilder, - amount: Amount, - to: AbstractParty, - onlyFromParties: Set? = null): Pair> - - // DOCSTART VaultStatesQuery - /** - * Return [ContractState]s of a given [Contract] type and [Iterable] of [Vault.StateStatus]. - * Optionally may specify whether to include [StateRef] that have been marked as soft locked (default is true) - */ - // TODO: Remove this from the interface - @Deprecated("This function will be removed in a future milestone", ReplaceWith("queryBy(QueryCriteria())")) - fun states(clazzes: Set>, statuses: EnumSet, includeSoftLockedStates: Boolean = true): Iterable> - // DOCEND VaultStatesQuery - - /** - * Soft locking is used to prevent multiple transactions trying to use the same output simultaneously. - * Violation of a soft lock would result in a double spend being created and rejected by the notary. - */ - - // DOCSTART SoftLockAPI - - /** - * Reserve a set of [StateRef] for a given [UUID] unique identifier. - * Typically, the unique identifier will refer to a Flow lockId associated with a [Transaction] in an in-flight flow. - * In the case of coin selection, soft locks are automatically taken upon gathering relevant unconsumed input refs. - * - * @throws [StatesNotAvailableException] when not possible to softLock all of requested [StateRef] - */ - @Throws(StatesNotAvailableException::class) - fun softLockReserve(lockId: UUID, stateRefs: Set) - - /** - * Release all or an explicitly specified set of [StateRef] for a given [UUID] unique identifier. - * A vault soft lock manager is automatically notified of a Flows that are terminated, such that any soft locked states - * may be released. - * In the case of coin selection, softLock are automatically released once previously gathered unconsumed input refs - * are consumed as part of cash spending. - */ - fun softLockRelease(lockId: UUID, stateRefs: Set? = null) - - /** - * Retrieve softLockStates for a given [UUID] or return all softLockStates in vault for a given - * [ContractState] type - */ - fun softLockedStates(lockId: UUID? = null): List> - - // DOCEND SoftLockAPI - - /** - * TODO: this function should be private to the vault, but currently Cash Exit functionality - * is implemented in a separate module (finance) and requires access to it. - */ - @Suspendable - fun unconsumedStatesForSpending(amount: Amount, onlyFromIssuerParties: Set? = null, notary: Party? = null, lockId: UUID, withIssuerRefs: Set? = null): List> -} - -// TODO: Remove this from the interface -@Deprecated("This function will be removed in a future milestone", ReplaceWith("queryBy(VaultQueryCriteria())")) -inline fun VaultService.unconsumedStates(includeSoftLockedStates: Boolean = true): Iterable> = - states(setOf(T::class.java), EnumSet.of(Vault.StateStatus.UNCONSUMED), includeSoftLockedStates) - -// TODO: Remove this from the interface -@Deprecated("This function will be removed in a future milestone", ReplaceWith("queryBy(VaultQueryCriteria(status = Vault.StateStatus.CONSUMED))")) -inline fun VaultService.consumedStates(): Iterable> = - states(setOf(T::class.java), EnumSet.of(Vault.StateStatus.CONSUMED)) - -/** Returns the [linearState] heads only when the type of the state would be considered an 'instanceof' the given type. */ -// TODO: Remove this from the interface -@Deprecated("This function will be removed in a future milestone", ReplaceWith("queryBy(LinearStateQueryCriteria(linearId = listOf()))")) -inline fun VaultService.linearHeadsOfType() = - states(setOf(T::class.java), EnumSet.of(Vault.StateStatus.UNCONSUMED)) - .associateBy { it.state.data.linearId }.mapValues { it.value } - -class StatesNotAvailableException(override val message: String?, override val cause: Throwable? = null) : FlowException(message, cause) { - override fun toString() = "Soft locking error: $message" -} - -interface VaultQueryService { - - // DOCSTART VaultQueryAPI - /** - * Generic vault query function which takes a [QueryCriteria] object to define filters, - * optional [PageSpecification] and optional [Sort] modification criteria (default unsorted), - * and returns a [Vault.Page] object containing the following: - * 1. states as a List of (page number and size defined by [PageSpecification]) - * 2. states metadata as a List of [Vault.StateMetadata] held in the Vault States table. - * 3. total number of results available if [PageSpecification] supplied (otherwise returns -1) - * 4. status types used in this query: UNCONSUMED, CONSUMED, ALL - * 5. other results (aggregate functions with/without using value groups) - * - * @throws VaultQueryException if the query cannot be executed for any reason - * (missing criteria or parsing error, paging errors, unsupported query, underlying database error) - * - * Notes - * If no [PageSpecification] is provided, a maximum of [DEFAULT_PAGE_SIZE] results will be returned. - * API users must specify a [PageSpecification] if they are expecting more than [DEFAULT_PAGE_SIZE] results, - * otherwise a [VaultQueryException] will be thrown alerting to this condition. - * It is the responsibility of the API user to request further pages and/or specify a more suitable [PageSpecification]. - */ - @Throws(VaultQueryException::class) - fun _queryBy(criteria: QueryCriteria, - paging: PageSpecification, - sorting: Sort, - contractType: Class): Vault.Page - /** - * Generic vault query function which takes a [QueryCriteria] object to define filters, - * optional [PageSpecification] and optional [Sort] modification criteria (default unsorted), - * and returns a [Vault.PageAndUpdates] object containing - * 1) a snapshot as a [Vault.Page] (described previously in [queryBy]) - * 2) an [Observable] of [Vault.Update] - * - * @throws VaultQueryException if the query cannot be executed for any reason - * - * Notes: the snapshot part of the query adheres to the same behaviour as the [queryBy] function. - * the [QueryCriteria] applies to both snapshot and deltas (streaming updates). - */ - @Throws(VaultQueryException::class) - fun _trackBy(criteria: QueryCriteria, - paging: PageSpecification, - sorting: Sort, - contractType: Class): DataFeed, Vault.Update> - // DOCEND VaultQueryAPI - - // Note: cannot apply @JvmOverloads to interfaces nor interface implementations - // Java Helpers - fun queryBy(contractType: Class): Vault.Page { - return _queryBy(QueryCriteria.VaultQueryCriteria(), PageSpecification(), Sort(emptySet()), contractType) - } - fun queryBy(contractType: Class, criteria: QueryCriteria): Vault.Page { - return _queryBy(criteria, PageSpecification(), Sort(emptySet()), contractType) - } - fun queryBy(contractType: Class, criteria: QueryCriteria, paging: PageSpecification): Vault.Page { - return _queryBy(criteria, paging, Sort(emptySet()), contractType) - } - fun queryBy(contractType: Class, criteria: QueryCriteria, sorting: Sort): Vault.Page { - return _queryBy(criteria, PageSpecification(), sorting, contractType) - } - fun queryBy(contractType: Class, criteria: QueryCriteria, paging: PageSpecification, sorting: Sort): Vault.Page { - return _queryBy(criteria, paging, sorting, contractType) - } - - fun trackBy(contractType: Class): DataFeed, Vault.Update> { - return _trackBy(QueryCriteria.VaultQueryCriteria(), PageSpecification(), Sort(emptySet()), contractType) - } - fun trackBy(contractType: Class, criteria: QueryCriteria): DataFeed, Vault.Update> { - return _trackBy(criteria, PageSpecification(), Sort(emptySet()), contractType) - } - fun trackBy(contractType: Class, criteria: QueryCriteria, paging: PageSpecification): DataFeed, Vault.Update> { - return _trackBy(criteria, paging, Sort(emptySet()), contractType) - } - fun trackBy(contractType: Class, criteria: QueryCriteria, sorting: Sort): DataFeed, Vault.Update> { - return _trackBy(criteria, PageSpecification(), sorting, contractType) - } - fun trackBy(contractType: Class, criteria: QueryCriteria, paging: PageSpecification, sorting: Sort): DataFeed, Vault.Update> { - return _trackBy(criteria, paging, sorting, contractType) - } -} - -inline fun VaultQueryService.queryBy(): Vault.Page { - return _queryBy(QueryCriteria.VaultQueryCriteria(), PageSpecification(), Sort(emptySet()), T::class.java) -} - -inline fun VaultQueryService.queryBy(criteria: QueryCriteria): Vault.Page { - return _queryBy(criteria, PageSpecification(), Sort(emptySet()), T::class.java) -} - -inline fun VaultQueryService.queryBy(criteria: QueryCriteria, paging: PageSpecification): Vault.Page { - return _queryBy(criteria, paging, Sort(emptySet()), T::class.java) -} - -inline fun VaultQueryService.queryBy(criteria: QueryCriteria, sorting: Sort): Vault.Page { - return _queryBy(criteria, PageSpecification(), sorting, T::class.java) -} - -inline fun VaultQueryService.queryBy(criteria: QueryCriteria, paging: PageSpecification, sorting: Sort): Vault.Page { - return _queryBy(criteria, paging, sorting, T::class.java) -} - -inline fun VaultQueryService.trackBy(): DataFeed, Vault.Update> { - return _trackBy(QueryCriteria.VaultQueryCriteria(), PageSpecification(), Sort(emptySet()), T::class.java) -} - -inline fun VaultQueryService.trackBy(criteria: QueryCriteria): DataFeed, Vault.Update> { - return _trackBy(criteria, PageSpecification(), Sort(emptySet()), T::class.java) -} - -inline fun VaultQueryService.trackBy(criteria: QueryCriteria, paging: PageSpecification): DataFeed, Vault.Update> { - return _trackBy(criteria, paging, Sort(emptySet()), T::class.java) -} - -inline fun VaultQueryService.trackBy(criteria: QueryCriteria, sorting: Sort): DataFeed, Vault.Update> { - return _trackBy(criteria, PageSpecification(), sorting, T::class.java) -} - -inline fun VaultQueryService.trackBy(criteria: QueryCriteria, paging: PageSpecification, sorting: Sort): DataFeed, Vault.Update> { - return _trackBy(criteria, paging, sorting, T::class.java) -} - -class VaultQueryException(description: String) : FlowException("$description") - -/** - * The KMS is responsible for storing and using private keys to sign things. An implementation of this may, for example, - * call out to a hardware security module that enforces various auditing and frequency-of-use requirements. - */ - -interface KeyManagementService { - /** - * Returns a snapshot of the current signing [PublicKey]s. - * For each of these keys a [PrivateKey] is available, that can be used later for signing. - */ - val keys: Set - - /** - * Generates a new random [KeyPair] and adds it to the internal key storage. Returns the public part of the pair. - */ - @Suspendable - fun freshKey(): PublicKey - - /** - * Generates a new random [KeyPair], adds it to the internal key storage, then generates a corresponding - * [X509Certificate] and adds it to the identity service. - * - * @param identity identity to generate a key and certificate for. Must be an identity this node has CA privileges for. - * @param revocationEnabled whether to check revocation status of certificates in the certificate path. - * @return X.509 certificate and path to the trust root. - */ - @Suspendable - fun freshKeyAndCert(identity: PartyAndCertificate, revocationEnabled: Boolean): AnonymisedIdentity - - /** - * Filter some keys down to the set that this node owns (has private keys for). - * - * @param candidateKeys keys which this node may own. - */ - fun filterMyKeys(candidateKeys: Iterable): Iterable - - /** - * Using the provided signing [PublicKey] internally looks up the matching [PrivateKey] and signs the data. - * @param bytes The data to sign over using the chosen key. - * @param publicKey The [PublicKey] partner to an internally held [PrivateKey], either derived from the node's primary identity, - * or previously generated via the [freshKey] method. - * If the [PublicKey] is actually a [CompositeKey] the first leaf signing key hosted by the node is used. - * @throws IllegalArgumentException if the input key is not a member of [keys]. - * TODO A full [KeyManagementService] implementation needs to record activity to the [AuditService] and to limit signing to - * appropriately authorised contexts and initiating users. - */ - @Suspendable - fun sign(bytes: ByteArray, publicKey: PublicKey): DigitalSignature.WithKey -} - -/** - * An interface that denotes a service that can accept file uploads. - */ -// TODO This is no longer used and can be removed -interface FileUploader { - /** - * Accepts the data in the given input stream, and returns some sort of useful return message that will be sent - * back to the user in the response. - */ - fun upload(file: InputStream): String - - /** - * Check if this service accepts this type of upload. For example if you are uploading interest rates this could - * be "my-service-interest-rates". Type here does not refer to file extentions or MIME types. - */ - fun accepts(type: String): Boolean -} - -/** - * Provides verification service. The implementation may be a simple in-memory verify() call or perhaps an IPC/RPC. - */ -interface TransactionVerifierService { - /** - * @param transaction The transaction to be verified. - * @return A future that completes successfully if the transaction verified, or sets an exception the verifier threw. - */ - fun verify(transaction: LedgerTransaction): ListenableFuture<*> -} diff --git a/core/src/main/kotlin/net/corda/core/node/services/TimeWindowChecker.kt b/core/src/main/kotlin/net/corda/core/node/services/TimeWindowChecker.kt index b06ee522af..fe1e95d96b 100644 --- a/core/src/main/kotlin/net/corda/core/node/services/TimeWindowChecker.kt +++ b/core/src/main/kotlin/net/corda/core/node/services/TimeWindowChecker.kt @@ -1,26 +1,11 @@ package net.corda.core.node.services import net.corda.core.contracts.TimeWindow -import net.corda.core.seconds -import net.corda.core.until import java.time.Clock -import java.time.Duration /** - * Checks if the given time-window falls within the allowed tolerance interval. + * Checks if the current instant provided by the input clock falls within the provided time-window. */ -class TimeWindowChecker(val clock: Clock = Clock.systemUTC(), - val tolerance: Duration = 30.seconds) { - fun isValid(timeWindow: TimeWindow): Boolean { - val untilTime = timeWindow.untilTime - val fromTime = timeWindow.fromTime - - val now = clock.instant() - - // We don't need to test for (fromTime == null && untilTime == null) or backwards bounds because the TimeWindow - // constructor already checks that. - if (untilTime != null && untilTime until now > tolerance) return false - if (fromTime != null && now until fromTime > tolerance) return false - return true - } +class TimeWindowChecker(val clock: Clock = Clock.systemUTC()) { + fun isValid(timeWindow: TimeWindow): Boolean = clock.instant() in timeWindow } diff --git a/core/src/main/kotlin/net/corda/core/node/services/TransactionVerifierService.kt b/core/src/main/kotlin/net/corda/core/node/services/TransactionVerifierService.kt new file mode 100644 index 0000000000..99f40a0492 --- /dev/null +++ b/core/src/main/kotlin/net/corda/core/node/services/TransactionVerifierService.kt @@ -0,0 +1,15 @@ +package net.corda.core.node.services + +import net.corda.core.concurrent.CordaFuture +import net.corda.core.transactions.LedgerTransaction + +/** + * Provides verification service. The implementation may be a simple in-memory verify() call or perhaps an IPC/RPC. + */ +interface TransactionVerifierService { + /** + * @param transaction The transaction to be verified. + * @return A future that completes successfully if the transaction verified, or sets an exception the verifier threw. + */ + fun verify(transaction: LedgerTransaction): CordaFuture<*> +} \ No newline at end of file diff --git a/core/src/main/kotlin/net/corda/core/node/services/VaultQueryService.kt b/core/src/main/kotlin/net/corda/core/node/services/VaultQueryService.kt new file mode 100644 index 0000000000..884fc7bb3f --- /dev/null +++ b/core/src/main/kotlin/net/corda/core/node/services/VaultQueryService.kt @@ -0,0 +1,140 @@ +package net.corda.core.node.services + +import net.corda.core.contracts.ContractState +import net.corda.core.flows.FlowException +import net.corda.core.messaging.DataFeed +import net.corda.core.node.services.vault.PageSpecification +import net.corda.core.node.services.vault.QueryCriteria +import net.corda.core.node.services.vault.Sort + +interface VaultQueryService { + + // DOCSTART VaultQueryAPI + /** + * Generic vault query function which takes a [QueryCriteria] object to define filters, + * optional [PageSpecification] and optional [Sort] modification criteria (default unsorted), + * and returns a [Vault.Page] object containing the following: + * 1. states as a List of (page number and size defined by [PageSpecification]) + * 2. states metadata as a List of [Vault.StateMetadata] held in the Vault States table. + * 3. total number of results available if [PageSpecification] supplied (otherwise returns -1) + * 4. status types used in this query: UNCONSUMED, CONSUMED, ALL + * 5. other results (aggregate functions with/without using value groups) + * + * @throws VaultQueryException if the query cannot be executed for any reason + * (missing criteria or parsing error, paging errors, unsupported query, underlying database error) + * + * Notes + * If no [PageSpecification] is provided, a maximum of [DEFAULT_PAGE_SIZE] results will be returned. + * API users must specify a [PageSpecification] if they are expecting more than [DEFAULT_PAGE_SIZE] results, + * otherwise a [VaultQueryException] will be thrown alerting to this condition. + * It is the responsibility of the API user to request further pages and/or specify a more suitable [PageSpecification]. + */ + @Throws(VaultQueryException::class) + fun _queryBy(criteria: QueryCriteria, + paging: PageSpecification, + sorting: Sort, + contractType: Class): Vault.Page + + /** + * Generic vault query function which takes a [QueryCriteria] object to define filters, + * optional [PageSpecification] and optional [Sort] modification criteria (default unsorted), + * and returns a [Vault.PageAndUpdates] object containing + * 1) a snapshot as a [Vault.Page] (described previously in [queryBy]) + * 2) an [Observable] of [Vault.Update] + * + * @throws VaultQueryException if the query cannot be executed for any reason + * + * Notes: the snapshot part of the query adheres to the same behaviour as the [queryBy] function. + * the [QueryCriteria] applies to both snapshot and deltas (streaming updates). + */ + @Throws(VaultQueryException::class) + fun _trackBy(criteria: QueryCriteria, + paging: PageSpecification, + sorting: Sort, + contractType: Class): DataFeed, Vault.Update> + // DOCEND VaultQueryAPI + + // Note: cannot apply @JvmOverloads to interfaces nor interface implementations + // Java Helpers + fun queryBy(contractType: Class): Vault.Page { + return _queryBy(QueryCriteria.VaultQueryCriteria(), PageSpecification(), Sort(emptySet()), contractType) + } + + fun queryBy(contractType: Class, criteria: QueryCriteria): Vault.Page { + return _queryBy(criteria, PageSpecification(), Sort(emptySet()), contractType) + } + + fun queryBy(contractType: Class, criteria: QueryCriteria, paging: PageSpecification): Vault.Page { + return _queryBy(criteria, paging, Sort(emptySet()), contractType) + } + + fun queryBy(contractType: Class, criteria: QueryCriteria, sorting: Sort): Vault.Page { + return _queryBy(criteria, PageSpecification(), sorting, contractType) + } + + fun queryBy(contractType: Class, criteria: QueryCriteria, paging: PageSpecification, sorting: Sort): Vault.Page { + return _queryBy(criteria, paging, sorting, contractType) + } + + fun trackBy(contractType: Class): DataFeed, Vault.Update> { + return _trackBy(QueryCriteria.VaultQueryCriteria(), PageSpecification(), Sort(emptySet()), contractType) + } + + fun trackBy(contractType: Class, criteria: QueryCriteria): DataFeed, Vault.Update> { + return _trackBy(criteria, PageSpecification(), Sort(emptySet()), contractType) + } + + fun trackBy(contractType: Class, criteria: QueryCriteria, paging: PageSpecification): DataFeed, Vault.Update> { + return _trackBy(criteria, paging, Sort(emptySet()), contractType) + } + + fun trackBy(contractType: Class, criteria: QueryCriteria, sorting: Sort): DataFeed, Vault.Update> { + return _trackBy(criteria, PageSpecification(), sorting, contractType) + } + + fun trackBy(contractType: Class, criteria: QueryCriteria, paging: PageSpecification, sorting: Sort): DataFeed, Vault.Update> { + return _trackBy(criteria, paging, sorting, contractType) + } +} + +inline fun VaultQueryService.queryBy(): Vault.Page { + return _queryBy(QueryCriteria.VaultQueryCriteria(), PageSpecification(), Sort(emptySet()), T::class.java) +} + +inline fun VaultQueryService.queryBy(criteria: QueryCriteria): Vault.Page { + return _queryBy(criteria, PageSpecification(), Sort(emptySet()), T::class.java) +} + +inline fun VaultQueryService.queryBy(criteria: QueryCriteria, paging: PageSpecification): Vault.Page { + return _queryBy(criteria, paging, Sort(emptySet()), T::class.java) +} + +inline fun VaultQueryService.queryBy(criteria: QueryCriteria, sorting: Sort): Vault.Page { + return _queryBy(criteria, PageSpecification(), sorting, T::class.java) +} + +inline fun VaultQueryService.queryBy(criteria: QueryCriteria, paging: PageSpecification, sorting: Sort): Vault.Page { + return _queryBy(criteria, paging, sorting, T::class.java) +} + +inline fun VaultQueryService.trackBy(): DataFeed, Vault.Update> { + return _trackBy(QueryCriteria.VaultQueryCriteria(), PageSpecification(), Sort(emptySet()), T::class.java) +} + +inline fun VaultQueryService.trackBy(criteria: QueryCriteria): DataFeed, Vault.Update> { + return _trackBy(criteria, PageSpecification(), Sort(emptySet()), T::class.java) +} + +inline fun VaultQueryService.trackBy(criteria: QueryCriteria, paging: PageSpecification): DataFeed, Vault.Update> { + return _trackBy(criteria, paging, Sort(emptySet()), T::class.java) +} + +inline fun VaultQueryService.trackBy(criteria: QueryCriteria, sorting: Sort): DataFeed, Vault.Update> { + return _trackBy(criteria, PageSpecification(), sorting, T::class.java) +} + +inline fun VaultQueryService.trackBy(criteria: QueryCriteria, paging: PageSpecification, sorting: Sort): DataFeed, Vault.Update> { + return _trackBy(criteria, paging, sorting, T::class.java) +} + +class VaultQueryException(description: String) : FlowException(description) \ No newline at end of file diff --git a/core/src/main/kotlin/net/corda/core/node/services/VaultService.kt b/core/src/main/kotlin/net/corda/core/node/services/VaultService.kt new file mode 100644 index 0000000000..183557f355 --- /dev/null +++ b/core/src/main/kotlin/net/corda/core/node/services/VaultService.kt @@ -0,0 +1,280 @@ +package net.corda.core.node.services + +import co.paralleluniverse.fibers.Suspendable +import net.corda.core.concurrent.CordaFuture +import net.corda.core.contracts.* +import net.corda.core.crypto.SecureHash +import net.corda.core.flows.FlowException +import net.corda.core.node.services.vault.QueryCriteria +import net.corda.core.serialization.CordaSerializable +import net.corda.core.toFuture +import net.corda.core.transactions.CoreTransaction +import net.corda.core.utilities.NonEmptySet +import rx.Observable +import rx.subjects.PublishSubject +import java.time.Instant +import java.util.* + +/** + * A vault (name may be temporary) wraps a set of states that are useful for us to keep track of, for instance, + * because we own them. This class represents an immutable, stable state of a vault: it is guaranteed not to + * change out from underneath you, even though the canonical currently-best-known vault may change as we learn + * about new transactions from our peers and generate new transactions that consume states ourselves. + * + * This abstract class has no references to Cash contracts. + * + * [states] Holds a [VaultService] queried subset of states that are *active* and *relevant*. + * Active means they haven't been consumed yet (or we don't know about it). + * Relevant means they contain at least one of our pubkeys. + */ +@CordaSerializable +class Vault(val states: Iterable>) { + /** + * Represents an update observed by the vault that will be notified to observers. Include the [StateRef]s of + * transaction outputs that were consumed (inputs) and the [ContractState]s produced (outputs) to/by the transaction + * or transactions observed and the vault. + * + * If the vault observes multiple transactions simultaneously, where some transactions consume the outputs of some of the + * other transactions observed, then the changes are observed "net" of those. + */ + @CordaSerializable + data class Update( + val consumed: Set>, + val produced: Set>, + val flowId: UUID? = null, + /** + * Specifies the type of update, currently supported types are general and notary change. Notary + * change transactions only modify the notary field on states, and potentially need to be handled + * differently. + */ + val type: UpdateType = UpdateType.GENERAL + ) { + /** Checks whether the update contains a state of the specified type. */ + inline fun containsType() = consumed.any { it.state.data is T } || produced.any { it.state.data is T } + + /** Checks whether the update contains a state of the specified type and state status */ + fun containsType(clazz: Class, status: StateStatus) = + when (status) { + StateStatus.UNCONSUMED -> produced.any { clazz.isAssignableFrom(it.state.data.javaClass) } + StateStatus.CONSUMED -> consumed.any { clazz.isAssignableFrom(it.state.data.javaClass) } + else -> consumed.any { clazz.isAssignableFrom(it.state.data.javaClass) } + || produced.any { clazz.isAssignableFrom(it.state.data.javaClass) } + } + + fun isEmpty() = consumed.isEmpty() && produced.isEmpty() + + /** + * Combine two updates into a single update with the combined inputs and outputs of the two updates but net + * any outputs of the left-hand-side (this) that are consumed by the inputs of the right-hand-side (rhs). + * + * i.e. the net effect in terms of state live-ness of receiving the combined update is the same as receiving this followed by rhs. + */ + operator fun plus(rhs: Update): Update { + require(rhs.type == type) { "Cannot combine updates of different types" } + val combinedConsumed = consumed + (rhs.consumed - produced) + // The ordering below matters to preserve ordering of consumed/produced Sets when they are insertion order dependent implementations. + val combinedProduced = produced.filter { it !in rhs.consumed }.toSet() + rhs.produced + return copy(consumed = combinedConsumed, produced = combinedProduced) + } + + override fun toString(): String { + val sb = StringBuilder() + sb.appendln("${consumed.size} consumed, ${produced.size} produced") + sb.appendln("") + sb.appendln("Consumed:") + consumed.forEach { + sb.appendln("${it.ref}: ${it.state}") + } + sb.appendln("") + sb.appendln("Produced:") + produced.forEach { + sb.appendln("${it.ref}: ${it.state}") + } + return sb.toString() + } + } + + companion object { + val NoUpdate = Update(emptySet(), emptySet(), type = Vault.UpdateType.GENERAL) + val NoNotaryUpdate = Vault.Update(emptySet(), emptySet(), type = Vault.UpdateType.NOTARY_CHANGE) + } + + @CordaSerializable + enum class StateStatus { + UNCONSUMED, CONSUMED, ALL + } + + @CordaSerializable + enum class UpdateType { + GENERAL, NOTARY_CHANGE + } + + /** + * Returned in queries [VaultService.queryBy] and [VaultService.trackBy]. + * A Page contains: + * 1) a [List] of actual [StateAndRef] requested by the specified [QueryCriteria] to a maximum of [MAX_PAGE_SIZE] + * 2) a [List] of associated [Vault.StateMetadata], one per [StateAndRef] result + * 3) a total number of states that met the given [QueryCriteria] if a [PageSpecification] was provided + * (otherwise defaults to -1) + * 4) Status types used in this query: UNCONSUMED, CONSUMED, ALL + * 5) Other results as a [List] of any type (eg. aggregate function results with/without group by) + * + * Note: currently otherResults are used only for Aggregate Functions (in which case, the states and statesMetadata + * results will be empty) + */ + @CordaSerializable + data class Page(val states: List>, + val statesMetadata: List, + val totalStatesAvailable: Long, + val stateTypes: StateStatus, + val otherResults: List) + + @CordaSerializable + data class StateMetadata(val ref: StateRef, + val contractStateClassName: String, + val recordedTime: Instant, + val consumedTime: Instant?, + val status: Vault.StateStatus, + val notaryName: String, + val notaryKey: String, + val lockId: String?, + val lockUpdateTime: Instant?) +} + +/** + * A [VaultService] is responsible for securely and safely persisting the current state of a vault to storage. The + * vault service vends immutable snapshots of the current vault for working with: if you build a transaction based + * on a vault that isn't current, be aware that it may end up being invalid if the states that were used have been + * consumed by someone else first! + * + * Note that transactions we've seen are held by the storage service, not the vault. + */ +interface VaultService { + + /** + * Prefer the use of [updates] unless you know why you want to use this instead. + * + * Get a synchronous Observable of updates. When observations are pushed to the Observer, the Vault will already incorporate + * the update, and the database transaction associated with the update will still be open and current. If for some + * reason the processing crosses outside of the database transaction (for example, the update is pushed outside the current + * JVM or across to another [Thread] which is executing in a different database transaction) then the Vault may + * not incorporate the update due to racing with committing the current database transaction. + */ + val rawUpdates: Observable> + + /** + * Get a synchronous Observable of updates. When observations are pushed to the Observer, the Vault will already incorporate + * the update, and the database transaction associated with the update will have been committed and closed. + */ + val updates: Observable> + + /** + * Enable creation of observables of updates. + */ + val updatesPublisher: PublishSubject> + + /** + * Possibly update the vault by marking as spent states that these transactions consume, and adding any relevant + * new states that they create. You should only insert transactions that have been successfully verified here! + * + * TODO: Consider if there's a good way to enforce the must-be-verified requirement in the type system. + */ + fun notifyAll(txns: Iterable) + + /** Same as notifyAll but with a single transaction. */ + fun notify(tx: CoreTransaction) = notifyAll(listOf(tx)) + + /** + * Provide a [CordaFuture] for when a [StateRef] is consumed, which can be very useful in building tests. + */ + fun whenConsumed(ref: StateRef): CordaFuture> { + return updates.filter { it.consumed.any { it.ref == ref } }.toFuture() + } + + /** Get contracts we would be willing to upgrade the suggested contract to. */ + // TODO: We need a better place to put business logic functions + fun getAuthorisedContractUpgrade(ref: StateRef): Class>? + + /** + * Authorise a contract state upgrade. + * This will store the upgrade authorisation in the vault, and will be queried by [ContractUpgradeFlow.Acceptor] during contract upgrade process. + * Invoking this method indicate the node is willing to upgrade the [state] using the [upgradedContractClass]. + * This method will NOT initiate the upgrade process. To start the upgrade process, see [ContractUpgradeFlow.Instigator]. + */ + fun authoriseContractUpgrade(stateAndRef: StateAndRef<*>, upgradedContractClass: Class>) + + /** + * Authorise a contract state upgrade. + * This will remove the upgrade authorisation from the vault. + */ + fun deauthoriseContractUpgrade(stateAndRef: StateAndRef<*>) + + /** + * Add a note to an existing [LedgerTransaction] given by its unique [SecureHash] id + * Multiple notes may be attached to the same [LedgerTransaction]. + * These are additively and immutably persisted within the node local vault database in a single textual field + * using a semi-colon separator + */ + fun addNoteToTransaction(txnId: SecureHash, noteText: String) + + fun getTransactionNotes(txnId: SecureHash): Iterable + + // DOCEND VaultStatesQuery + + /** + * Soft locking is used to prevent multiple transactions trying to use the same output simultaneously. + * Violation of a soft lock would result in a double spend being created and rejected by the notary. + */ + + // DOCSTART SoftLockAPI + + /** + * Reserve a set of [StateRef] for a given [UUID] unique identifier. + * Typically, the unique identifier will refer to a [FlowLogic.runId.uuid] associated with an in-flight flow. + * In this case if the flow terminates the locks will automatically be freed, even if there is an error. + * However, the user can specify their own [UUID] and manage this manually, possibly across the lifetime of multiple flows, + * or from other thread contexts e.g. [CordaService] instances. + * In the case of coin selection, soft locks are automatically taken upon gathering relevant unconsumed input refs. + * + * @throws [StatesNotAvailableException] when not possible to softLock all of requested [StateRef] + */ + @Throws(StatesNotAvailableException::class) + fun softLockReserve(lockId: UUID, stateRefs: NonEmptySet) + + /** + * Release all or an explicitly specified set of [StateRef] for a given [UUID] unique identifier. + * A vault soft lock manager is automatically notified of a Flows that are terminated, such that any soft locked states + * may be released. + * In the case of coin selection, softLock are automatically released once previously gathered unconsumed input refs + * are consumed as part of cash spending. + */ + fun softLockRelease(lockId: UUID, stateRefs: NonEmptySet? = null) + // DOCEND SoftLockAPI + + /** + * Helper function to combine using [VaultQueryService] calls to determine spendable states and soft locking them. + * Currently performance will be worse than for the hand optimised version in `Cash.unconsumedCashStatesForSpending` + * However, this is fully generic and can operate with custom [FungibleAsset] states. + * @param lockId The [FlowLogic.runId.uuid] of the current flow used to soft lock the states. + * @param eligibleStatesQuery A custom query object that selects down to the appropriate subset of all states of the + * [contractType]. e.g. by selecting on account, issuer, etc. The query is internally augmented with the UNCONSUMED, + * soft lock and contract type requirements. + * @param amount The required amount of the asset, but with the issuer stripped off. + * It is assumed that compatible issuer states will be filtered out by the [eligibleStatesQuery]. + * @param contractType class type of the result set. + * @return Returns a locked subset of the [eligibleStatesQuery] sufficient to satisfy the requested amount, + * or else an empty list and no change in the stored lock states when their are insufficient resources available. + */ + @Suspendable + @Throws(StatesNotAvailableException::class) + fun , U : Any> tryLockFungibleStatesForSpending(lockId: UUID, + eligibleStatesQuery: QueryCriteria, + amount: Amount, + contractType: Class): List> + +} + + +class StatesNotAvailableException(override val message: String?, override val cause: Throwable? = null) : FlowException(message, cause) { + override fun toString() = "Soft locking error: $message" +} \ No newline at end of file diff --git a/core/src/main/kotlin/net/corda/core/node/services/vault/QueryCriteria.kt b/core/src/main/kotlin/net/corda/core/node/services/vault/QueryCriteria.kt index 0bbce100e5..0be25265f6 100644 --- a/core/src/main/kotlin/net/corda/core/node/services/vault/QueryCriteria.kt +++ b/core/src/main/kotlin/net/corda/core/node/services/vault/QueryCriteria.kt @@ -4,7 +4,6 @@ package net.corda.core.node.services.vault import net.corda.core.contracts.ContractState import net.corda.core.contracts.StateRef -import net.corda.core.contracts.UniqueIdentifier import net.corda.core.identity.AbstractParty import net.corda.core.node.services.Vault import net.corda.core.schemas.PersistentState @@ -26,6 +25,19 @@ sealed class QueryCriteria { @CordaSerializable data class TimeCondition(val type: TimeInstantType, val predicate: ColumnPredicate) + // DOCSTART VaultQuerySoftLockingCriteria + @CordaSerializable + data class SoftLockingCondition(val type: SoftLockingType, val lockIds: List = emptyList()) + + @CordaSerializable + enum class SoftLockingType { + UNLOCKED_ONLY, // only unlocked states + LOCKED_ONLY, // only soft locked states + SPECIFIED, // only those soft locked states specified by lock id(s) + UNLOCKED_AND_SPECIFIED // all unlocked states plus those soft locked states specified by lock id(s) + } + // DOCEND VaultQuerySoftLockingCriteria + abstract class CommonQueryCriteria : QueryCriteria() { abstract val status: Vault.StateStatus override fun visit(parser: IQueryCriteriaParser): Collection { @@ -40,7 +52,7 @@ sealed class QueryCriteria { val contractStateTypes: Set>? = null, val stateRefs: List? = null, val notaryName: List? = null, - val includeSoftlockedStates: Boolean = true, + val softLockingCondition: SoftLockingCondition? = null, val timeCondition: TimeCondition? = null) : CommonQueryCriteria() { override fun visit(parser: IQueryCriteriaParser): Collection { return parser.parseCriteria(this as CommonQueryCriteria).plus(parser.parseCriteria(this)) @@ -51,8 +63,8 @@ sealed class QueryCriteria { * LinearStateQueryCriteria: provides query by attributes defined in [VaultSchema.VaultLinearState] */ data class LinearStateQueryCriteria @JvmOverloads constructor(val participants: List? = null, - val linearId: List? = null, - val dealRef: List? = null, + val uuid: List? = null, + val externalId: List? = null, override val status: Vault.StateStatus = Vault.StateStatus.UNCONSUMED) : CommonQueryCriteria() { override fun visit(parser: IQueryCriteriaParser): Collection { return parser.parseCriteria(this as CommonQueryCriteria).plus(parser.parseCriteria(this)) diff --git a/core/src/main/kotlin/net/corda/core/node/services/vault/QueryCriteriaUtils.kt b/core/src/main/kotlin/net/corda/core/node/services/vault/QueryCriteriaUtils.kt index 92817ac424..88364f3bb2 100644 --- a/core/src/main/kotlin/net/corda/core/node/services/vault/QueryCriteriaUtils.kt +++ b/core/src/main/kotlin/net/corda/core/node/services/vault/QueryCriteriaUtils.kt @@ -6,7 +6,7 @@ import net.corda.core.schemas.PersistentState import net.corda.core.serialization.CordaSerializable import java.lang.reflect.Field import kotlin.reflect.KProperty1 -import kotlin.reflect.jvm.javaField +import kotlin.reflect.jvm.javaGetter @CordaSerializable enum class BinaryLogicalOperator { @@ -66,9 +66,9 @@ sealed class CriteriaExpression { } @CordaSerializable -sealed class Column { - data class Java(val field: Field) : Column() - data class Kotlin(val property: KProperty1) : Column() +class Column(val name: String, val declaringClass: Class<*>) { + constructor(field: Field) : this(field.name, field.declaringClass) + constructor(property: KProperty1) : this(property.name, property.javaGetter!!.declaringClass) } @CordaSerializable @@ -92,19 +92,8 @@ fun resolveEnclosingObjectFromExpression(expression: CriteriaExpression resolveEnclosingObjectFromColumn(column: Column): Class { - return when (column) { - is Column.Java -> column.field.declaringClass as Class - is Column.Kotlin -> column.property.javaField!!.declaringClass as Class - } -} - -fun getColumnName(column: Column): String { - return when (column) { - is Column.Java -> column.field.name - is Column.Kotlin -> column.property.name - } -} +fun resolveEnclosingObjectFromColumn(column: Column): Class = column.declaringClass as Class +fun getColumnName(column: Column): String = column.name /** * Pagination and Ordering @@ -173,8 +162,7 @@ data class Sort(val columns: Collection) { enum class LinearStateAttribute(val attributeName: String) : Attribute { /** Vault Linear States */ UUID("uuid"), - EXTERNAL_ID("externalId"), - DEAL_REFERENCE("dealReference") + EXTERNAL_ID("externalId") } enum class FungibleStateAttribute(val attributeName: String) : Attribute { @@ -210,14 +198,14 @@ sealed class SortAttribute { object Builder { fun > compare(operator: BinaryComparisonOperator, value: R) = ColumnPredicate.BinaryComparison(operator, value) - fun KProperty1.predicate(predicate: ColumnPredicate) = CriteriaExpression.ColumnPredicateExpression(Column.Kotlin(this), predicate) + fun KProperty1.predicate(predicate: ColumnPredicate) = CriteriaExpression.ColumnPredicateExpression(Column(this), predicate) - fun Field.predicate(predicate: ColumnPredicate) = CriteriaExpression.ColumnPredicateExpression(Column.Java(this), predicate) + fun Field.predicate(predicate: ColumnPredicate) = CriteriaExpression.ColumnPredicateExpression(Column(this), predicate) - fun KProperty1.functionPredicate(predicate: ColumnPredicate, groupByColumns: List>? = null, orderBy: Sort.Direction? = null) - = CriteriaExpression.AggregateFunctionExpression(Column.Kotlin(this), predicate, groupByColumns, orderBy) - fun Field.functionPredicate(predicate: ColumnPredicate, groupByColumns: List>? = null, orderBy: Sort.Direction? = null) - = CriteriaExpression.AggregateFunctionExpression(Column.Java(this), predicate, groupByColumns, orderBy) + fun KProperty1.functionPredicate(predicate: ColumnPredicate, groupByColumns: List>? = null, orderBy: Sort.Direction? = null) + = CriteriaExpression.AggregateFunctionExpression(Column(this), predicate, groupByColumns, orderBy) + fun Field.functionPredicate(predicate: ColumnPredicate, groupByColumns: List>? = null, orderBy: Sort.Direction? = null) + = CriteriaExpression.AggregateFunctionExpression(Column(this), predicate, groupByColumns, orderBy) fun > KProperty1.comparePredicate(operator: BinaryComparisonOperator, value: R) = predicate(compare(operator, value)) fun > Field.comparePredicate(operator: BinaryComparisonOperator, value: R) = predicate(compare(operator, value)) @@ -264,34 +252,34 @@ object Builder { /** aggregate functions */ fun KProperty1.sum(groupByColumns: List>? = null, orderBy: Sort.Direction? = null) = - functionPredicate(ColumnPredicate.AggregateFunction(AggregateFunctionType.SUM), groupByColumns?.map { Column.Kotlin(it) }, orderBy) + functionPredicate(ColumnPredicate.AggregateFunction(AggregateFunctionType.SUM), groupByColumns?.map { Column(it) }, orderBy) @JvmStatic @JvmOverloads fun Field.sum(groupByColumns: List? = null, orderBy: Sort.Direction? = null) = - functionPredicate(ColumnPredicate.AggregateFunction(AggregateFunctionType.SUM), groupByColumns?.map { Column.Java(it) }, orderBy) + functionPredicate(ColumnPredicate.AggregateFunction(AggregateFunctionType.SUM), groupByColumns?.map { Column(it) }, orderBy) fun KProperty1.count() = functionPredicate(ColumnPredicate.AggregateFunction(AggregateFunctionType.COUNT)) @JvmStatic fun Field.count() = functionPredicate(ColumnPredicate.AggregateFunction(AggregateFunctionType.COUNT)) fun KProperty1.avg(groupByColumns: List>? = null, orderBy: Sort.Direction? = null) = - functionPredicate(ColumnPredicate.AggregateFunction(AggregateFunctionType.AVG), groupByColumns?.map { Column.Kotlin(it) }, orderBy) + functionPredicate(ColumnPredicate.AggregateFunction(AggregateFunctionType.AVG), groupByColumns?.map { Column(it) }, orderBy) @JvmStatic @JvmOverloads fun Field.avg(groupByColumns: List? = null, orderBy: Sort.Direction? = null) = - functionPredicate(ColumnPredicate.AggregateFunction(AggregateFunctionType.AVG), groupByColumns?.map { Column.Java(it) }, orderBy) + functionPredicate(ColumnPredicate.AggregateFunction(AggregateFunctionType.AVG), groupByColumns?.map { Column(it) }, orderBy) fun KProperty1.min(groupByColumns: List>? = null, orderBy: Sort.Direction? = null) = - functionPredicate(ColumnPredicate.AggregateFunction(AggregateFunctionType.MIN), groupByColumns?.map { Column.Kotlin(it) }, orderBy) + functionPredicate(ColumnPredicate.AggregateFunction(AggregateFunctionType.MIN), groupByColumns?.map { Column(it) }, orderBy) @JvmStatic @JvmOverloads fun Field.min(groupByColumns: List? = null, orderBy: Sort.Direction? = null) = - functionPredicate(ColumnPredicate.AggregateFunction(AggregateFunctionType.MIN), groupByColumns?.map { Column.Java(it) }, orderBy) + functionPredicate(ColumnPredicate.AggregateFunction(AggregateFunctionType.MIN), groupByColumns?.map { Column(it) }, orderBy) fun KProperty1.max(groupByColumns: List>? = null, orderBy: Sort.Direction? = null) = - functionPredicate(ColumnPredicate.AggregateFunction(AggregateFunctionType.MAX), groupByColumns?.map { Column.Kotlin(it) }, orderBy) + functionPredicate(ColumnPredicate.AggregateFunction(AggregateFunctionType.MAX), groupByColumns?.map { Column(it) }, orderBy) @JvmStatic @JvmOverloads fun Field.max(groupByColumns: List? = null, orderBy: Sort.Direction? = null) = - functionPredicate(ColumnPredicate.AggregateFunction(AggregateFunctionType.MAX), groupByColumns?.map { Column.Java(it) }, orderBy) + functionPredicate(ColumnPredicate.AggregateFunction(AggregateFunctionType.MAX), groupByColumns?.map { Column(it) }, orderBy) } inline fun builder(block: Builder.() -> A) = block(Builder) diff --git a/core/src/main/kotlin/net/corda/core/schemas/PersistentTypes.kt b/core/src/main/kotlin/net/corda/core/schemas/PersistentTypes.kt index 36a847eec3..a2c2d00e8f 100644 --- a/core/src/main/kotlin/net/corda/core/schemas/PersistentTypes.kt +++ b/core/src/main/kotlin/net/corda/core/schemas/PersistentTypes.kt @@ -3,6 +3,7 @@ package net.corda.core.schemas import io.requery.Persistable import net.corda.core.contracts.ContractState import net.corda.core.contracts.StateRef +import net.corda.core.serialization.CordaSerializable import net.corda.core.utilities.toHexString import java.io.Serializable import javax.persistence.Column @@ -49,7 +50,7 @@ open class MappedSchema(schemaFamily: Class<*>, * A super class for all mapped states exported to a schema that ensures the [StateRef] appears on the database row. The * [StateRef] will be set to the correct value by the framework (there's no need to set during mapping generation by the state itself). */ -@MappedSuperclass open class PersistentState(@EmbeddedId var stateRef: PersistentStateRef? = null) : StatePersistable +@MappedSuperclass @CordaSerializable open class PersistentState(@EmbeddedId var stateRef: PersistentStateRef? = null) : StatePersistable /** * Embedded [StateRef] representation used in state mapping. diff --git a/core/src/main/kotlin/net/corda/core/schemas/converters/AbstractPartyToX500NameAsStringConverter.kt b/core/src/main/kotlin/net/corda/core/schemas/converters/AbstractPartyToX500NameAsStringConverter.kt new file mode 100644 index 0000000000..9ee9503eef --- /dev/null +++ b/core/src/main/kotlin/net/corda/core/schemas/converters/AbstractPartyToX500NameAsStringConverter.kt @@ -0,0 +1,34 @@ +package net.corda.core.schemas.converters + +import net.corda.core.identity.AbstractParty +import net.corda.core.node.services.IdentityService +import org.bouncycastle.asn1.x500.X500Name +import javax.persistence.AttributeConverter +import javax.persistence.Converter + +/** + * Converter to persist a party as its's well known identity (where resolvable) + * Completely anonymous parties are stored as null (to preserve privacy) + */ +@Converter(autoApply = true) +class AbstractPartyToX500NameAsStringConverter(identitySvc: () -> IdentityService) : AttributeConverter { + + private val identityService: IdentityService by lazy { + identitySvc() + } + + override fun convertToDatabaseColumn(party: AbstractParty?): String? { + party?.let { + return identityService.partyFromAnonymous(party)?.toString() + } + return null // non resolvable anonymous parties + } + + override fun convertToEntityAttribute(dbData: String?): AbstractParty? { + dbData?.let { + val party = identityService.partyFromX500Name(X500Name(dbData)) + return party as AbstractParty + } + return null // non resolvable anonymous parties are stored as nulls + } +} diff --git a/core/src/main/kotlin/net/corda/core/node/AttachmentsClassLoader.kt b/core/src/main/kotlin/net/corda/core/serialization/AttachmentsClassLoader.kt similarity index 98% rename from core/src/main/kotlin/net/corda/core/node/AttachmentsClassLoader.kt rename to core/src/main/kotlin/net/corda/core/serialization/AttachmentsClassLoader.kt index bd76fd8bbc..d179ab6e25 100644 --- a/core/src/main/kotlin/net/corda/core/node/AttachmentsClassLoader.kt +++ b/core/src/main/kotlin/net/corda/core/serialization/AttachmentsClassLoader.kt @@ -1,8 +1,7 @@ -package net.corda.core.node +package net.corda.core.serialization import net.corda.core.contracts.Attachment import net.corda.core.crypto.SecureHash -import net.corda.core.serialization.CordaSerializable import java.io.ByteArrayInputStream import java.io.ByteArrayOutputStream import java.io.FileNotFoundException diff --git a/core/src/main/kotlin/net/corda/core/serialization/KryoAMQPSerializer.kt b/core/src/main/kotlin/net/corda/core/serialization/KryoAMQPSerializer.kt deleted file mode 100644 index 42cb65aec7..0000000000 --- a/core/src/main/kotlin/net/corda/core/serialization/KryoAMQPSerializer.kt +++ /dev/null @@ -1,51 +0,0 @@ -package net.corda.core.serialization - -import com.esotericsoftware.kryo.Kryo -import com.esotericsoftware.kryo.Serializer -import com.esotericsoftware.kryo.io.Input -import com.esotericsoftware.kryo.io.Output -import net.corda.core.serialization.amqp.DeserializationInput -import net.corda.core.serialization.amqp.SerializationOutput -import net.corda.core.serialization.amqp.SerializerFactory - -/** - * This [Kryo] custom [Serializer] switches the object graph of anything annotated with `@CordaSerializable` - * to using the AMQP serialization wire format, and simply writes that out as bytes to the wire. - * - * There is no need to write out the length, since this can be peeked out of the first few bytes of the stream. - */ -object KryoAMQPSerializer : Serializer() { - internal fun registerCustomSerializers(factory: SerializerFactory) { - factory.apply { - register(net.corda.core.serialization.amqp.custom.PublicKeySerializer) - register(net.corda.core.serialization.amqp.custom.ThrowableSerializer(this)) - register(net.corda.core.serialization.amqp.custom.X500NameSerializer) - register(net.corda.core.serialization.amqp.custom.BigDecimalSerializer) - register(net.corda.core.serialization.amqp.custom.CurrencySerializer) - register(net.corda.core.serialization.amqp.custom.InstantSerializer(this)) - } - } - - // TODO: need to sort out the whitelist... we currently do not apply the whitelist attached to the [Kryo] - // instance to the factory. We need to do this before turning on AMQP serialization. - private val serializerFactory = SerializerFactory().apply { - registerCustomSerializers(this) - } - - override fun write(kryo: Kryo, output: Output, obj: Any) { - val amqpOutput = SerializationOutput(serializerFactory) - val bytes = amqpOutput.serialize(obj).bytes - // No need to write out the size since it's encoded within the AMQP. - output.write(bytes) - } - - override fun read(kryo: Kryo, input: Input, type: Class): Any { - val amqpInput = DeserializationInput(serializerFactory) - // Use our helper functions to peek the size of the serialized object out of the AMQP byte stream. - val peekedBytes = input.readBytes(DeserializationInput.BYTES_NEEDED_TO_PEEK) - val size = DeserializationInput.peekSize(peekedBytes) - val allBytes = peekedBytes.copyOf(size) - input.readBytes(allBytes, peekedBytes.size, size - peekedBytes.size) - return amqpInput.deserialize(SerializedBytes(allBytes), type) - } -} \ No newline at end of file diff --git a/core/src/main/kotlin/net/corda/core/serialization/MissingAttachmentsException.kt b/core/src/main/kotlin/net/corda/core/serialization/MissingAttachmentsException.kt new file mode 100644 index 0000000000..2094f1aeff --- /dev/null +++ b/core/src/main/kotlin/net/corda/core/serialization/MissingAttachmentsException.kt @@ -0,0 +1,7 @@ +package net.corda.core.serialization + +import net.corda.core.crypto.SecureHash + +/** Thrown during deserialisation to indicate that an attachment needed to construct the [WireTransaction] is not found */ +@CordaSerializable +class MissingAttachmentsException(val ids: List) : Exception() \ No newline at end of file diff --git a/core/src/main/kotlin/net/corda/core/serialization/SerializationAPI.kt b/core/src/main/kotlin/net/corda/core/serialization/SerializationAPI.kt new file mode 100644 index 0000000000..7dab1e0243 --- /dev/null +++ b/core/src/main/kotlin/net/corda/core/serialization/SerializationAPI.kt @@ -0,0 +1,145 @@ +package net.corda.core.serialization + +import net.corda.core.crypto.SecureHash +import net.corda.core.crypto.sha256 +import net.corda.core.internal.WriteOnceProperty +import net.corda.core.serialization.SerializationDefaults.P2P_CONTEXT +import net.corda.core.serialization.SerializationDefaults.SERIALIZATION_FACTORY +import net.corda.core.utilities.ByteSequence +import net.corda.core.utilities.OpaqueBytes +import net.corda.core.utilities.sequence + +/** + * An abstraction for serializing and deserializing objects, with support for versioning of the wire format via + * a header / prefix in the bytes. + */ +interface SerializationFactory { + /** + * Deserialize the bytes in to an object, using the prefixed bytes to determine the format. + * + * @param byteSequence The bytes to deserialize, including a format header prefix. + * @param clazz The class or superclass or the object to be deserialized, or [Any] or [Object] if unknown. + * @param context A context that configures various parameters to deserialization. + */ + fun deserialize(byteSequence: ByteSequence, clazz: Class, context: SerializationContext): T + + /** + * Serialize an object to bytes using the preferred serialization format version from the context. + * + * @param obj The object to be serialized. + * @param context A context that configures various parameters to serialization, including the serialization format version. + */ + fun serialize(obj: T, context: SerializationContext): SerializedBytes +} + +/** + * Parameters to serialization and deserialization. + */ +interface SerializationContext { + /** + * When serializing, use the format this header sequence represents. + */ + val preferedSerializationVersion: ByteSequence + /** + * The class loader to use for deserialization. + */ + val deserializationClassLoader: ClassLoader + /** + * A whitelist that contains (mostly for security purposes) which classes can be serialized and deserialized. + */ + val whitelist: ClassWhitelist + /** + * A map of any addition properties specific to the particular use case. + */ + val properties: Map + /** + * Duplicate references to the same object preserved in the wire format and when deserialized when this is true, + * otherwise they appear as new copies of the object. + */ + val objectReferencesEnabled: Boolean + /** + * The use case we are serializing or deserializing for. See [UseCase]. + */ + val useCase: UseCase + /** + * Helper method to return a new context based on this context with the property added. + */ + fun withProperty(property: Any, value: Any): SerializationContext + + /** + * Helper method to return a new context based on this context with object references disabled. + */ + fun withoutReferences(): SerializationContext + + /** + * Helper method to return a new context based on this context with the deserialization class loader changed. + */ + fun withClassLoader(classLoader: ClassLoader): SerializationContext + + /** + * Helper method to return a new context based on this context with the given class specifically whitelisted. + */ + fun withWhitelisted(clazz: Class<*>): SerializationContext + + /** + * Helper method to return a new context based on this context but with serialization using the format this header sequence represents. + */ + fun withPreferredSerializationVersion(versionHeader: ByteSequence): SerializationContext + + /** + * The use case that we are serializing for, since it influences the implementations chosen. + */ + enum class UseCase { P2P, RPCServer, RPCClient, Storage, Checkpoint } +} + +/** + * Global singletons to be used as defaults that are injected elsewhere (generally, in the node or in RPC client). + */ +object SerializationDefaults { + var SERIALIZATION_FACTORY: SerializationFactory by WriteOnceProperty() + var P2P_CONTEXT: SerializationContext by WriteOnceProperty() + var RPC_SERVER_CONTEXT: SerializationContext by WriteOnceProperty() + var RPC_CLIENT_CONTEXT: SerializationContext by WriteOnceProperty() + var STORAGE_CONTEXT: SerializationContext by WriteOnceProperty() + var CHECKPOINT_CONTEXT: SerializationContext by WriteOnceProperty() +} + +/** + * Convenience extension method for deserializing a ByteSequence, utilising the defaults. + */ +inline fun ByteSequence.deserialize(serializationFactory: SerializationFactory = SERIALIZATION_FACTORY, context: SerializationContext = P2P_CONTEXT): T { + return serializationFactory.deserialize(this, T::class.java, context) +} + +/** + * Convenience extension method for deserializing SerializedBytes with type matching, utilising the defaults. + */ +inline fun SerializedBytes.deserialize(serializationFactory: SerializationFactory = SERIALIZATION_FACTORY, context: SerializationContext = P2P_CONTEXT): T { + return serializationFactory.deserialize(this, T::class.java, context) +} + +/** + * Convenience extension method for deserializing a ByteArray, utilising the defaults. + */ +inline fun ByteArray.deserialize(serializationFactory: SerializationFactory = SERIALIZATION_FACTORY, context: SerializationContext = P2P_CONTEXT): T = this.sequence().deserialize(serializationFactory, context) + +/** + * Convenience extension method for serializing an object of type T, utilising the defaults. + */ +fun T.serialize(serializationFactory: SerializationFactory = SERIALIZATION_FACTORY, context: SerializationContext = P2P_CONTEXT): SerializedBytes { + return serializationFactory.serialize(this, context) +} + +/** + * A type safe wrapper around a byte array that contains a serialised object. You can call [SerializedBytes.deserialize] + * to get the original object back. + */ +@Suppress("unused") // Type parameter is just for documentation purposes. +class SerializedBytes(bytes: ByteArray) : OpaqueBytes(bytes) { + // It's OK to use lazy here because SerializedBytes is configured to use the ImmutableClassSerializer. + val hash: SecureHash by lazy { bytes.sha256() } +} + +interface ClassWhitelist { + fun hasListed(type: Class<*>): Boolean +} \ No newline at end of file diff --git a/core/src/main/kotlin/net/corda/core/serialization/SerializationCustomization.kt b/core/src/main/kotlin/net/corda/core/serialization/SerializationCustomization.kt index 08d497589e..e051f732c3 100644 --- a/core/src/main/kotlin/net/corda/core/serialization/SerializationCustomization.kt +++ b/core/src/main/kotlin/net/corda/core/serialization/SerializationCustomization.kt @@ -1,13 +1,6 @@ package net.corda.core.serialization -import com.esotericsoftware.kryo.Kryo - interface SerializationCustomization { fun addToWhitelist(type: Class<*>) } -class KryoSerializationCustomization(val kryo: Kryo) : SerializationCustomization { - override fun addToWhitelist(type: Class<*>) { - kryo.addToWhitelist(type) - } -} \ No newline at end of file diff --git a/core/src/main/kotlin/net/corda/core/serialization/SerializationToken.kt b/core/src/main/kotlin/net/corda/core/serialization/SerializationToken.kt index c141435a4e..8ca6820711 100644 --- a/core/src/main/kotlin/net/corda/core/serialization/SerializationToken.kt +++ b/core/src/main/kotlin/net/corda/core/serialization/SerializationToken.kt @@ -1,11 +1,5 @@ package net.corda.core.serialization -import com.esotericsoftware.kryo.Kryo -import com.esotericsoftware.kryo.KryoException -import com.esotericsoftware.kryo.Serializer -import com.esotericsoftware.kryo.io.Input -import com.esotericsoftware.kryo.io.Output -import com.esotericsoftware.kryo.pool.KryoPool import net.corda.core.node.ServiceHub import net.corda.core.serialization.SingletonSerializationToken.Companion.singletonSerializationToken @@ -36,86 +30,13 @@ interface SerializationToken { fun fromToken(context: SerializeAsTokenContext): Any } -/** - * A Kryo serializer for [SerializeAsToken] implementations. - */ -class SerializeAsTokenSerializer : Serializer() { - override fun write(kryo: Kryo, output: Output, obj: T) { - kryo.writeClassAndObject(output, obj.toToken(kryo.serializationContext() ?: throw KryoException("Attempt to write a ${SerializeAsToken::class.simpleName} instance of ${obj.javaClass.name} without initialising a context"))) - } - - override fun read(kryo: Kryo, input: Input, type: Class): T { - val token = (kryo.readClassAndObject(input) as? SerializationToken) ?: throw KryoException("Non-token read for tokenized type: ${type.name}") - val fromToken = token.fromToken(kryo.serializationContext() ?: throw KryoException("Attempt to read a token for a ${SerializeAsToken::class.simpleName} instance of ${type.name} without initialising a context")) - if (type.isAssignableFrom(fromToken.javaClass)) { - return type.cast(fromToken) - } else { - throw KryoException("Token read ($token) did not return expected tokenized type: ${type.name}") - } - } -} - -private val serializationContextKey = SerializeAsTokenContext::class.java - -fun Kryo.serializationContext() = context.get(serializationContextKey) as? SerializeAsTokenContext - -fun Kryo.withSerializationContext(serializationContext: SerializeAsTokenContext, block: () -> T) = run { - context.containsKey(serializationContextKey) && throw IllegalStateException("There is already a serialization context.") - context.put(serializationContextKey, serializationContext) - try { - block() - } finally { - context.remove(serializationContextKey) - } -} - /** * A context for mapping SerializationTokens to/from SerializeAsTokens. - * - * A context is initialised with an object containing all the instances of [SerializeAsToken] to eagerly register all the tokens. - * In our case this can be the [ServiceHub]. - * - * Then it is a case of using the companion object methods on [SerializeAsTokenSerializer] to set and clear context as necessary - * on the Kryo instance when serializing to enable/disable tokenization. */ -class SerializeAsTokenContext internal constructor(val serviceHub: ServiceHub, init: SerializeAsTokenContext.() -> Unit) { - constructor(toBeTokenized: Any, kryoPool: KryoPool, serviceHub: ServiceHub) : this(serviceHub, { - kryoPool.run { kryo -> - kryo.withSerializationContext(this) { - toBeTokenized.serialize(kryo) - } - } - }) - - private val classNameToSingleton = mutableMapOf() - private var readOnly = false - - init { - /** - * Go ahead and eagerly serialize the object to register all of the tokens in the context. - * - * This results in the toToken() method getting called for any [SingletonSerializeAsToken] instances which - * are encountered in the object graph as they are serialized by Kryo and will therefore register the token to - * object mapping for those instances. We then immediately set the readOnly flag to stop further adhoc or - * accidental registrations from occuring as these could not be deserialized in a deserialization-first - * scenario if they are not part of this iniital context construction serialization. - */ - init(this) - readOnly = true - } - - internal fun putSingleton(toBeTokenized: SerializeAsToken) { - val className = toBeTokenized.javaClass.name - if (className !in classNameToSingleton) { - // Only allowable if we are in SerializeAsTokenContext init (readOnly == false) - if (readOnly) { - throw UnsupportedOperationException("Attempt to write token for lazy registered ${className}. All tokens should be registered during context construction.") - } - classNameToSingleton[className] = toBeTokenized - } - } - - internal fun getSingleton(className: String) = classNameToSingleton[className] ?: throw IllegalStateException("Unable to find tokenized instance of $className in context $this") +interface SerializeAsTokenContext { + val serviceHub: ServiceHub + fun putSingleton(toBeTokenized: SerializeAsToken) + fun getSingleton(className: String): SerializeAsToken } /** diff --git a/core/src/main/kotlin/net/corda/core/serialization/amqp/ArraySerializer.kt b/core/src/main/kotlin/net/corda/core/serialization/amqp/ArraySerializer.kt deleted file mode 100644 index ca1612bc50..0000000000 --- a/core/src/main/kotlin/net/corda/core/serialization/amqp/ArraySerializer.kt +++ /dev/null @@ -1,50 +0,0 @@ -package net.corda.core.serialization.amqp - -import org.apache.qpid.proton.codec.Data -import java.io.NotSerializableException -import java.lang.reflect.Type - -/** - * Serialization / deserialization of arrays. - */ -class ArraySerializer(override val type: Type, factory: SerializerFactory) : AMQPSerializer { - override val typeDescriptor = "$DESCRIPTOR_DOMAIN:${fingerprintForType(type, factory)}" - - internal val elementType: Type = type.componentType() - - private val typeNotation: TypeNotation = RestrictedType(type.typeName, null, emptyList(), "list", Descriptor(typeDescriptor, null), emptyList()) - - override fun writeClassInfo(output: SerializationOutput) { - if (output.writeTypeNotations(typeNotation)) { - output.requireSerializer(elementType) - } - } - - override fun writeObject(obj: Any, data: Data, type: Type, output: SerializationOutput) { - // Write described - data.withDescribed(typeNotation.descriptor) { - withList { - for (entry in obj as Array<*>) { - output.writeObjectOrNull(entry, this, elementType) - } - } - } - } - - override fun readObject(obj: Any, schema: Schema, input: DeserializationInput): Any { - if (obj is List<*>) { - return obj.map { input.readObjectOrNull(it, schema, elementType) }.toArrayOfType(elementType) - } else throw NotSerializableException("Expected a List but found $obj") - } - - private fun List.toArrayOfType(type: Type): Any { - val elementType = type.asClass() ?: throw NotSerializableException("Unexpected array element type $type") - val list = this - return java.lang.reflect.Array.newInstance(elementType, this.size).apply { - val array = this - for (i in 0..lastIndex) { - java.lang.reflect.Array.set(array, i, list[i]) - } - } - } -} \ No newline at end of file diff --git a/core/src/main/kotlin/net/corda/core/transactions/BaseTransaction.kt b/core/src/main/kotlin/net/corda/core/transactions/BaseTransaction.kt index 876efdda82..276415b041 100644 --- a/core/src/main/kotlin/net/corda/core/transactions/BaseTransaction.kt +++ b/core/src/main/kotlin/net/corda/core/transactions/BaseTransaction.kt @@ -2,58 +2,154 @@ package net.corda.core.transactions import net.corda.core.contracts.* import net.corda.core.identity.Party -import java.security.PublicKey -import java.util.* +import net.corda.core.internal.indexOfOrThrow +import net.corda.core.internal.castIfPossible +import java.util.function.Predicate /** * An abstract class defining fields shared by all transaction types in the system. */ -abstract class BaseTransaction( - /** The inputs of this transaction. Note that in BaseTransaction subclasses the type of this list may change! */ - open val inputs: List<*>, - /** Ordered list of states defined by this transaction, along with the associated notaries. */ - val outputs: List>, - /** - * If present, the notary for this transaction. If absent then the transaction is not notarised at all. - * This is intended for issuance/genesis transactions that don't consume any other states and thus can't - * double spend anything. - */ - val notary: Party?, - /** - * Public keys that need to be fulfilled by signatures in order for the transaction to be valid. - * In a [SignedTransaction] this list is used to check whether there are any missing signatures. Note that - * there is nothing that forces the list to be the _correct_ list of signers for this transaction until - * the transaction is verified by using [LedgerTransaction.verify]. - * - * It includes the notary key, if the notary field is set. - */ - val mustSign: List, - /** - * Pointer to a class that defines the behaviour of this transaction: either normal, or "notary changing". - */ - val type: TransactionType, - /** - * If specified, a time window in which this transaction may have been notarised. Contracts can check this - * time window to find out when a transaction is deemed to have occurred, from the ledger's perspective. - */ - val timeWindow: TimeWindow? -) : NamedByHash { +abstract class BaseTransaction : NamedByHash { + /** The inputs of this transaction. Note that in BaseTransaction subclasses the type of this list may change! */ + abstract val inputs: List<*> + /** Ordered list of states defined by this transaction, along with the associated notaries. */ + abstract val outputs: List> + /** + * If present, the notary for this transaction. If absent then the transaction is not notarised at all. + * This is intended for issuance/genesis transactions that don't consume any other states and thus can't + * double spend anything. + */ + abstract val notary: Party? - protected fun checkInvariants() { - if (notary == null) check(inputs.isEmpty()) { "The notary must be specified explicitly for any transaction that has inputs" } - if (timeWindow != null) check(notary != null) { "If a time-window is provided, there must be a notary" } + protected open fun checkBaseInvariants() { + checkNotarySetIfInputsPresent() + checkNoDuplicateInputs() } - override fun equals(other: Any?): Boolean { - if (other === this) return true - return other is BaseTransaction && - notary == other.notary && - mustSign == other.mustSign && - type == other.type && - timeWindow == other.timeWindow + private fun checkNotarySetIfInputsPresent() { + if (notary == null) { + check(inputs.isEmpty()) { "The notary must be specified explicitly for any transaction that has inputs" } + } } - override fun hashCode() = Objects.hash(notary, mustSign, type, timeWindow) + private fun checkNoDuplicateInputs() { + val duplicates = inputs.groupBy { it }.filter { it.value.size > 1 }.keys + check(duplicates.isEmpty()) { "Duplicate input states detected" } + } + + /** + * Returns a [StateAndRef] for the given output index. + */ + @Suppress("UNCHECKED_CAST") + fun outRef(index: Int): StateAndRef = StateAndRef(outputs[index] as TransactionState, StateRef(id, index)) + + /** + * Returns a [StateAndRef] for the requested output state, or throws [IllegalArgumentException] if not found. + */ + fun outRef(state: ContractState): StateAndRef = outRef(outputStates.indexOfOrThrow(state)) + + /** + * Helper property to return a list of [ContractState] objects, rather than the often less convenient [TransactionState] + */ + val outputStates: List get() = outputs.map { it.data } + + /** + * Helper to simplify getting an indexed output. + * @param index the position of the item in the output. + * @return The ContractState at the requested index + */ + fun getOutput(index: Int): ContractState = outputs[index].data + + /** + * Helper to simplify getting all output states of a particular class, interface, or base class. + * @param clazz The class type used for filtering via an [Class.isInstance] check. + * Clazz must be an extension of [ContractState]. + * @return the possibly empty list of output states matching the clazz restriction. + */ + fun outputsOfType(clazz: Class): List = outputs.mapNotNull { clazz.castIfPossible(it.data) } + + inline fun outputsOfType(): List = outputsOfType(T::class.java) + + /** + * Helper to simplify filtering outputs according to a [Predicate]. + * @param clazz The class type used for filtering via an [Class.isInstance] check. + * Clazz must be an extension of [ContractState]. + * @param predicate A filtering function taking a state of type T and returning true if it should be included in the list. + * The class filtering is applied before the predicate. + * @return the possibly empty list of output states matching the predicate and clazz restrictions. + */ + fun filterOutputs(clazz: Class, predicate: Predicate): List { + return outputsOfType(clazz).filter { predicate.test(it) } + } + + inline fun filterOutputs(crossinline predicate: (T) -> Boolean): List { + return filterOutputs(T::class.java, Predicate { predicate(it) }) + } + + /** + * Helper to simplify finding a single output matching a [Predicate]. + * @param clazz The class type used for filtering via an [Class.isInstance] check. + * Clazz must be an extension of [ContractState]. + * @param predicate A filtering function taking a state of type T and returning true if this is the desired item. + * The class filtering is applied before the predicate. + * @return the single item matching the predicate. + * @throws IllegalArgumentException if no item, or multiple items are found matching the requirements. + */ + fun findOutput(clazz: Class, predicate: Predicate): T { + return outputsOfType(clazz).single { predicate.test(it) } + } + + inline fun findOutput(crossinline predicate: (T) -> Boolean): T { + return findOutput(T::class.java, Predicate { predicate(it) }) + } + + /** + * Helper to simplify getting all output [StateAndRef] items of a particular state class, interface, or base class. + * @param clazz The class type used for filtering via an [Class.isInstance] check. + * Clazz must be an extension of [ContractState]. + * @return the possibly empty list of output [StateAndRef] states matching the clazz restriction. + */ + fun outRefsOfType(clazz: Class): List> { + return outputs.mapIndexedNotNull { index, state -> + @Suppress("UNCHECKED_CAST") + clazz.castIfPossible(state.data)?.let { StateAndRef(state as TransactionState, StateRef(id, index)) } + } + } + + inline fun outRefsOfType(): List> = outRefsOfType(T::class.java) + + /** + * Helper to simplify filtering output [StateAndRef] items according to a [Predicate]. + * @param clazz The class type used for filtering via an [Class.isInstance] check. + * Clazz must be an extension of [ContractState]. + * @param predicate A filtering function taking a state of type T and returning true if it should be included in the list. + * The class filtering is applied before the predicate. + * @return the possibly empty list of output [StateAndRef] states matching the predicate and clazz restrictions. + */ + fun filterOutRefs(clazz: Class, predicate: Predicate): List> { + return outRefsOfType(clazz).filter { predicate.test(it.state.data) } + } + + inline fun filterOutRefs(crossinline predicate: (T) -> Boolean): List> { + return filterOutRefs(T::class.java, Predicate { predicate(it) }) + } + + /** + * Helper to simplify finding a single output [StateAndRef] matching a [Predicate]. + * @param clazz The class type used for filtering via an [Class.isInstance] check. + * Clazz must be an extension of [ContractState]. + * @param predicate A filtering function taking a state of type T and returning true if this is the desired item. + * The class filtering is applied before the predicate. + * @return the single [StateAndRef] item matching the predicate. + * @throws IllegalArgumentException if no item, or multiple items are found matching the requirements. + */ + fun findOutRef(clazz: Class, predicate: Predicate): StateAndRef { + return outRefsOfType(clazz).single { predicate.test(it.state.data) } + } + + inline fun findOutRef(crossinline predicate: (T) -> Boolean): StateAndRef { + return findOutRef(T::class.java, Predicate { predicate(it) }) + } override fun toString(): String = "${javaClass.simpleName}(id=$id)" -} +} \ No newline at end of file diff --git a/core/src/main/kotlin/net/corda/core/transactions/BaseTransactions.kt b/core/src/main/kotlin/net/corda/core/transactions/BaseTransactions.kt new file mode 100644 index 0000000000..2d82c4850b --- /dev/null +++ b/core/src/main/kotlin/net/corda/core/transactions/BaseTransactions.kt @@ -0,0 +1,32 @@ +package net.corda.core.transactions + +import net.corda.core.contracts.ContractState +import net.corda.core.contracts.StateAndRef +import net.corda.core.contracts.StateRef + +/** + * A transaction with the minimal amount of information required to compute the unique transaction [id], and + * resolve a [FullTransaction]. This type of transaction, wrapped in [SignedTransaction], gets transferred across the + * wire and recorded to storage. + */ +abstract class CoreTransaction : BaseTransaction() { + /** The inputs of this transaction, containing state references only **/ + abstract override val inputs: List +} + +/** A transaction with fully resolved components, such as input states. */ +abstract class FullTransaction : BaseTransaction() { + abstract override val inputs: List> + + override fun checkBaseInvariants() { + super.checkBaseInvariants() + checkInputsHaveSameNotary() + } + + private fun checkInputsHaveSameNotary() { + if (inputs.isEmpty()) return + val inputNotaries = inputs.map { it.state.notary }.toHashSet() + check(inputNotaries.size == 1) { "All inputs must point to the same notary" } + check(inputNotaries.single() == notary) { "The specified notary must be the one specified by all inputs" } + } +} \ No newline at end of file diff --git a/core/src/main/kotlin/net/corda/core/transactions/LedgerTransaction.kt b/core/src/main/kotlin/net/corda/core/transactions/LedgerTransaction.kt index 3fd91f1f0e..59d3e8515d 100644 --- a/core/src/main/kotlin/net/corda/core/transactions/LedgerTransaction.kt +++ b/core/src/main/kotlin/net/corda/core/transactions/LedgerTransaction.kt @@ -3,8 +3,10 @@ package net.corda.core.transactions import net.corda.core.contracts.* import net.corda.core.crypto.SecureHash import net.corda.core.identity.Party +import net.corda.core.internal.castIfPossible import net.corda.core.serialization.CordaSerializable -import java.security.PublicKey +import java.util.* +import java.util.function.Predicate /** * A LedgerTransaction is derived from a [WireTransaction]. It is the result of doing the following operations: @@ -19,73 +21,324 @@ import java.security.PublicKey */ // TODO LedgerTransaction is not supposed to be serialisable as it references attachments, etc. The verification logic // currently sends this across to out-of-process verifiers. We'll need to change that first. +// DOCSTART 1 @CordaSerializable -class LedgerTransaction( +data class LedgerTransaction( /** The resolved input states which will be consumed/invalidated by the execution of this transaction. */ - override val inputs: List>, - outputs: List>, + override val inputs: List>, + override val outputs: List>, /** Arbitrary data passed to the program of each input state. */ val commands: List>, /** A list of [Attachment] objects identified by the transaction that are needed for this transaction to verify. */ val attachments: List, /** The hash of the original serialised WireTransaction. */ override val id: SecureHash, - notary: Party?, - signers: List, - timeWindow: TimeWindow?, - type: TransactionType -) : BaseTransaction(inputs, outputs, notary, signers, type, timeWindow) { + override val notary: Party?, + val timeWindow: TimeWindow?, + val privacySalt: PrivacySalt +) : FullTransaction() { + //DOCEND 1 init { - checkInvariants() + checkBaseInvariants() + if (timeWindow != null) check(notary != null) { "Transactions with time-windows must be notarised" } + checkNoNotaryChange() + checkEncumbrancesValid() } - @Suppress("UNCHECKED_CAST") - fun outRef(index: Int) = StateAndRef(outputs[index] as TransactionState, StateRef(id, index)) - - // TODO: Remove this concept. - // There isn't really a good justification for hiding this data from the contract, it's just a backwards compat hack. - /** Strips the transaction down to a form that is usable by the contract verify functions */ - fun toTransactionForContract(): TransactionForContract { - return TransactionForContract(inputs.map { it.state.data }, outputs.map { it.data }, attachments, commands, id, - inputs.map { it.state.notary }.singleOrNull(), timeWindow) - } + val inputStates: List get() = inputs.map { it.state.data } /** - * Verifies this transaction and throws an exception if not valid, depending on the type. For general transactions: - * - * - The contracts are run with the transaction as the input. - * - The list of keys mentioned in commands is compared against the signers list. + * Returns the typed input StateAndRef at the specified index + * @param index The index into the inputs. + * @return The [StateAndRef] + */ + @Suppress("UNCHECKED_CAST") + fun inRef(index: Int): StateAndRef = inputs[index] as StateAndRef + + /** + * Verifies this transaction and runs contract code. At this stage it is assumed that signatures have already been verified. * * @throws TransactionVerificationException if anything goes wrong. */ @Throws(TransactionVerificationException::class) - fun verify() = type.verify(this) + fun verify() = verifyContracts() - // TODO: When we upgrade to Kotlin 1.1 we can make this a data class again and have the compiler generate these. - - override fun equals(other: Any?): Boolean { - if (this === other) return true - if (other?.javaClass != javaClass) return false - if (!super.equals(other)) return false - - other as LedgerTransaction - - if (inputs != other.inputs) return false - if (outputs != other.outputs) return false - if (commands != other.commands) return false - if (attachments != other.attachments) return false - if (id != other.id) return false - - return true + /** + * Check the transaction is contract-valid by running the verify() for each input and output state contract. + * If any contract fails to verify, the whole transaction is considered to be invalid. + */ + private fun verifyContracts() { + val contracts = (inputs.map { it.state.data.contract } + outputs.map { it.data.contract }).toSet() + for (contract in contracts) { + try { + contract.verify(this) + } catch(e: Throwable) { + throw TransactionVerificationException.ContractRejection(id, contract, e) + } + } } - override fun hashCode(): Int { - var result = super.hashCode() - result = 31 * result + inputs.hashCode() - result = 31 * result + outputs.hashCode() - result = 31 * result + commands.hashCode() - result = 31 * result + attachments.hashCode() - result = 31 * result + id.hashCode() + /** + * Make sure the notary has stayed the same. As we can't tell how inputs and outputs connect, if there + * are any inputs, all outputs must have the same notary. + * + * TODO: Is that the correct set of restrictions? May need to come back to this, see if we can be more + * flexible on output notaries. + */ + private fun checkNoNotaryChange() { + if (notary != null && inputs.isNotEmpty()) { + outputs.forEach { + if (it.notary != notary) { + throw TransactionVerificationException.NotaryChangeInWrongTransactionType(id, notary, it.notary) + } + } + } + } + + private fun checkEncumbrancesValid() { + // Validate that all encumbrances exist within the set of input states. + val encumberedInputs = inputs.filter { it.state.encumbrance != null } + encumberedInputs.forEach { (state, ref) -> + val encumbranceStateExists = inputs.any { + it.ref.txhash == ref.txhash && it.ref.index == state.encumbrance + } + if (!encumbranceStateExists) { + throw TransactionVerificationException.TransactionMissingEncumbranceException( + id, + state.encumbrance!!, + TransactionVerificationException.Direction.INPUT + ) + } + } + + // Check that, in the outputs, an encumbered state does not refer to itself as the encumbrance, + // and that the number of outputs can contain the encumbrance. + for ((i, output) in outputs.withIndex()) { + val encumbranceIndex = output.encumbrance ?: continue + if (encumbranceIndex == i || encumbranceIndex >= outputs.size) { + throw TransactionVerificationException.TransactionMissingEncumbranceException( + id, + encumbranceIndex, + TransactionVerificationException.Direction.OUTPUT) + } + } + } + + /** + * Given a type and a function that returns a grouping key, associates inputs and outputs together so that they + * can be processed as one. The grouping key is any arbitrary object that can act as a map key (so must implement + * equals and hashCode). + * + * The purpose of this function is to simplify the writing of verification logic for transactions that may contain + * similar but unrelated state evolutions which need to be checked independently. Consider a transaction that + * simultaneously moves both dollars and euros (e.g. is an atomic FX trade). There may be multiple dollar inputs and + * multiple dollar outputs, depending on things like how fragmented the owner's vault is and whether various privacy + * techniques are in use. The quantity of dollars on the output side must sum to the same as on the input side, to + * ensure no money is being lost track of. This summation and checking must be repeated independently for each + * currency. To solve this, you would use groupStates with a type of Cash.State and a selector that returns the + * currency field: the resulting list can then be iterated over to perform the per-currency calculation. + */ + // DOCSTART 2 + fun groupStates(ofType: Class, selector: (T) -> K): List> { + val inputs = inputsOfType(ofType) + val outputs = outputsOfType(ofType) + + val inGroups: Map> = inputs.groupBy(selector) + val outGroups: Map> = outputs.groupBy(selector) + + val result = ArrayList>() + + for ((k, v) in inGroups.entries) + result.add(InOutGroup(v, outGroups[k] ?: emptyList(), k)) + for ((k, v) in outGroups.entries) { + if (inGroups[k] == null) + result.add(InOutGroup(emptyList(), v, k)) + } + return result } + // DOCEND 2 + + /** See the documentation for the reflection-based version of [groupStates] */ + inline fun groupStates(noinline selector: (T) -> K): List> { + return groupStates(T::class.java, selector) + } + + /** Utilities for contract writers to incorporate into their logic. */ + + /** + * A set of related inputs and outputs that are connected by some common attributes. An InOutGroup is calculated + * using [groupStates] and is useful for handling cases where a transaction may contain similar but unrelated + * state evolutions, for example, a transaction that moves cash in two different currencies. The numbers must add + * up on both sides of the transaction, but the values must be summed independently per currency. Grouping can + * be used to simplify this logic. + */ + // DOCSTART 3 + data class InOutGroup(val inputs: List, val outputs: List, val groupingKey: K) + // DOCEND 3 + + /** + * Helper to simplify getting an indexed input [ContractState]. + * @param index the position of the item in the inputs. + * @return The [StateAndRef] at the requested index + */ + fun getInput(index: Int): ContractState = inputs[index].state.data + + /** + * Helper to simplify getting all inputs states of a particular class, interface, or base class. + * @param clazz The class type used for filtering via an [Class.isInstance] check. + * [clazz] must be an extension of [ContractState]. + * @return the possibly empty list of inputs matching the clazz restriction. + */ + fun inputsOfType(clazz: Class): List = inputs.mapNotNull { clazz.castIfPossible(it.state.data) } + + inline fun inputsOfType(): List = inputsOfType(T::class.java) + + /** + * Helper to simplify getting all inputs states of a particular class, interface, or base class. + * @param clazz The class type used for filtering via an [Class.isInstance] check. + * [clazz] must be an extension of [ContractState]. + * @return the possibly empty list of inputs [StateAndRef] matching the clazz restriction. + */ + fun inRefsOfType(clazz: Class): List> { + @Suppress("UNCHECKED_CAST") + return inputs.mapNotNull { if (clazz.isInstance(it.state.data)) it as StateAndRef else null } + } + + inline fun inRefsOfType(): List> = inRefsOfType(T::class.java) + + /** + * Helper to simplify filtering inputs according to a [Predicate]. + * @param clazz The class type used for filtering via an [Class.isInstance] check. + * [clazz] must be an extension of [ContractState]. + * @param predicate A filtering function taking a state of type T and returning true if it should be included in the list. + * The class filtering is applied before the predicate. + * @return the possibly empty list of input states matching the predicate and clazz restrictions. + */ + fun filterInputs(clazz: Class, predicate: Predicate): List { + return inputsOfType(clazz).filter { predicate.test(it) } + } + + inline fun filterInputs(crossinline predicate: (T) -> Boolean): List { + return filterInputs(T::class.java, Predicate { predicate(it) }) + } + + /** + * Helper to simplify filtering inputs according to a [Predicate]. + * @param predicate A filtering function taking a state of type T and returning true if it should be included in the list. + * The class filtering is applied before the predicate. + * @param clazz The class type used for filtering via an [Class.isInstance] check. + * [clazz] must be an extension of [ContractState]. + * @return the possibly empty list of inputs [StateAndRef] matching the predicate and clazz restrictions. + */ + fun filterInRefs(clazz: Class, predicate: Predicate): List> { + return inRefsOfType(clazz).filter { predicate.test(it.state.data) } + } + + inline fun filterInRefs(crossinline predicate: (T) -> Boolean): List> { + return filterInRefs(T::class.java, Predicate { predicate(it) }) + } + + /** + * Helper to simplify finding a single input [ContractState] matching a [Predicate]. + * @param clazz The class type used for filtering via an [Class.isInstance] check. + * [clazz] must be an extension of ContractState. + * @param predicate A filtering function taking a state of type T and returning true if this is the desired item. + * The class filtering is applied before the predicate. + * @return the single item matching the predicate. + * @throws IllegalArgumentException if no item, or multiple items are found matching the requirements. + */ + fun findInput(clazz: Class, predicate: Predicate): T { + return inputsOfType(clazz).single { predicate.test(it) } + } + + inline fun findInput(crossinline predicate: (T) -> Boolean): T { + return findInput(T::class.java, Predicate { predicate(it) }) + } + + /** + * Helper to simplify finding a single input matching a [Predicate]. + * @param clazz The class type used for filtering via an [Class.isInstance] check. + * [clazz] must be an extension of ContractState. + * @param predicate A filtering function taking a state of type T and returning true if this is the desired item. + * The class filtering is applied before the predicate. + * @return the single item matching the predicate. + * @throws IllegalArgumentException if no item, or multiple items are found matching the requirements. + */ + fun findInRef(clazz: Class, predicate: Predicate): StateAndRef { + return inRefsOfType(clazz).single { predicate.test(it.state.data) } + } + + inline fun findInRef(crossinline predicate: (T) -> Boolean): StateAndRef { + return findInRef(T::class.java, Predicate { predicate(it) }) + } + + /** + * Helper to simplify getting an indexed command. + * @param index the position of the item in the commands. + * @return The Command at the requested index + */ + @Suppress("UNCHECKED_CAST") + fun getCommand(index: Int): Command = Command(commands[index].value as T, commands[index].signers) + + /** + * Helper to simplify getting all [Command] items with a [CommandData] of a particular class, interface, or base class. + * @param clazz The class type used for filtering via an [Class.isInstance] check. + * [clazz] must be an extension of [CommandData]. + * @return the possibly empty list of commands with [CommandData] values matching the clazz restriction. + */ + fun commandsOfType(clazz: Class): List> { + return commands.mapNotNull { (signers, _, value) -> clazz.castIfPossible(value)?.let { Command(it, signers) } } + } + + inline fun commandsOfType(): List> = commandsOfType(T::class.java) + + /** + * Helper to simplify filtering [Command] items according to a [Predicate]. + * @param clazz The class type used for filtering via an [Class.isInstance] check. + * [clazz] must be an extension of [CommandData]. + * @param predicate A filtering function taking a [CommandData] item of type T and returning true if it should be included in the list. + * The class filtering is applied before the predicate. + * @return the possibly empty list of [Command] items with [CommandData] values matching the predicate and clazz restrictions. + */ + fun filterCommands(clazz: Class, predicate: Predicate): List> { + return commandsOfType(clazz).filter { predicate.test(it.value) } + } + + inline fun filterCommands(crossinline predicate: (T) -> Boolean): List> { + return filterCommands(T::class.java, Predicate { predicate(it) }) + } + + /** + * Helper to simplify finding a single [Command] items according to a [Predicate]. + * @param clazz The class type used for filtering via an [Class.isInstance] check. + * [clazz] must be an extension of [CommandData]. + * @param predicate A filtering function taking a [CommandData] item of type T and returning true if it should be included in the list. + * The class filtering is applied before the predicate. + * @return the [Command] item with [CommandData] values matching the predicate and clazz restrictions. + * @throws IllegalArgumentException if no items, or multiple items matched the requirements. + */ + fun findCommand(clazz: Class, predicate: Predicate): Command { + return commandsOfType(clazz).single { predicate.test(it.value) } + } + + inline fun findCommand(crossinline predicate: (T) -> Boolean): Command { + return findCommand(T::class.java, Predicate { predicate(it) }) + } + + /** + * Helper to simplify getting an indexed attachment. + * @param index the position of the item in the attachments. + * @return The Attachment at the requested index. + */ + fun getAttachment(index: Int): Attachment = attachments[index] + + /** + * Helper to simplify getting an indexed attachment. + * @param id the SecureHash of the desired attachment. + * @return The Attachment with the matching id. + * @throws IllegalArgumentException if no item matches the id. + */ + fun getAttachment(id: SecureHash): Attachment = attachments.first { it.id == id } } + diff --git a/core/src/main/kotlin/net/corda/core/transactions/MerkleTransaction.kt b/core/src/main/kotlin/net/corda/core/transactions/MerkleTransaction.kt index 0e611fb242..38b418e877 100644 --- a/core/src/main/kotlin/net/corda/core/transactions/MerkleTransaction.kt +++ b/core/src/main/kotlin/net/corda/core/transactions/MerkleTransaction.kt @@ -1,22 +1,37 @@ package net.corda.core.transactions import net.corda.core.contracts.* -import net.corda.core.crypto.MerkleTree -import net.corda.core.crypto.MerkleTreeException -import net.corda.core.crypto.PartialMerkleTree -import net.corda.core.crypto.SecureHash +import net.corda.core.crypto.* import net.corda.core.identity.Party import net.corda.core.serialization.CordaSerializable -import net.corda.core.serialization.p2PKryo +import net.corda.core.serialization.SerializationDefaults.P2P_CONTEXT import net.corda.core.serialization.serialize -import net.corda.core.serialization.withoutReferences -import java.security.PublicKey +import java.nio.ByteBuffer import java.util.function.Predicate -fun serializedHash(x: T): SecureHash { - return p2PKryo().run { kryo -> kryo.withoutReferences { x.serialize(kryo).hash } } +/** + * If a privacy salt is provided, the resulted output (Merkle-leaf) is computed as + * Hash(serializedObject || Hash(privacy_salt || obj_index_in_merkle_tree)). + */ +fun serializedHash(x: T, privacySalt: PrivacySalt?, index: Int): SecureHash { + return if (privacySalt != null) + serializedHash(x, computeNonce(privacySalt, index)) + else + serializedHash(x) } +fun serializedHash(x: T, nonce: SecureHash): SecureHash { + return if (x !is PrivacySalt) // PrivacySalt is not required to have an accompanied nonce. + (x.serialize(context = P2P_CONTEXT.withoutReferences()).bytes + nonce.bytes).sha256() + else + serializedHash(x) +} + +fun serializedHash(x: T): SecureHash = x.serialize(context = P2P_CONTEXT.withoutReferences()).bytes.sha256() + +/** The nonce is computed as Hash(privacySalt || index). */ +fun computeNonce(privacySalt: PrivacySalt, index: Int) = (privacySalt.bytes + ByteBuffer.allocate(4).putInt(index).array()).sha256() + /** * Implemented by [WireTransaction] and [FilteredLeaves]. A TraversableTransaction allows you to iterate * over the flattened components of the underlying transaction structure, taking into account that some @@ -30,11 +45,20 @@ interface TraversableTransaction { val inputs: List val attachments: List val outputs: List> - val commands: List + val commands: List> val notary: Party? - val mustSign: List - val type: TransactionType? val timeWindow: TimeWindow? + /** + * For privacy purposes, each part of a transaction should be accompanied by a nonce. + * To avoid storing a random number (nonce) per component, an initial "salt" is the sole value utilised, + * so that all component nonces are deterministically computed in the following way: + * nonce1 = H(salt || 1) + * nonce2 = H(salt || 2) + * + * Thus, all of the nonces are "independent" in the sense that knowing one or some of them, you can learn + * nothing about the rest. + */ + val privacySalt: PrivacySalt? /** * Returns a flattened list of all the components that are present in the transaction, in the following order: @@ -44,20 +68,20 @@ interface TraversableTransaction { * - Each output that is present * - Each command that is present * - The notary [Party], if present - * - Each required signer ([mustSign]) that is present - * - The type of the transaction, if present * - The time-window of the transaction, if present + * - The privacy salt required for nonces, always presented in [WireTransaction] and always null in [FilteredLeaves] */ val availableComponents: List + // NOTE: if the order below is altered or components are added/removed in the future, one should also reflect + // this change to the indexOffsets() method in WireTransaction. get() { // We may want to specify our own behaviour on certain tx fields. // Like if we include them at all, what to do with null values, if we treat list as one or not etc. for building // torn-off transaction and id calculation. val result = mutableListOf(inputs, attachments, outputs, commands).flatten().toMutableList() notary?.let { result += it } - result.addAll(mustSign) - type?.let { result += it } timeWindow?.let { result += it } + privacySalt?.let { result += it } return result } @@ -66,24 +90,36 @@ interface TraversableTransaction { * The root of the tree is the transaction identifier. The tree structure is helpful for privacy, please * see the user-guide section "Transaction tear-offs" to learn more about this topic. */ - val availableComponentHashes: List get() = availableComponents.map { serializedHash(it) } + val availableComponentHashes: List get() = availableComponents.mapIndexed { index, it -> serializedHash(it, privacySalt, index) } } /** * Class that holds filtered leaves for a partial Merkle transaction. We assume mixed leaf types, notice that every - * field from [WireTransaction] can be used in [PartialMerkleTree] calculation. + * field from [WireTransaction] can be used in [PartialMerkleTree] calculation, except for the privacySalt. + * A list of nonces is also required to (re)construct component hashes. */ @CordaSerializable class FilteredLeaves( override val inputs: List, override val attachments: List, override val outputs: List>, - override val commands: List, + override val commands: List>, override val notary: Party?, - override val mustSign: List, - override val type: TransactionType?, - override val timeWindow: TimeWindow? + override val timeWindow: TimeWindow?, + val nonces: List ) : TraversableTransaction { + + /** + * PrivacySalt should be always null for FilteredLeaves, because making it accidentally visible would expose all + * nonces (including filtered out components) causing privacy issues, see [serializedHash] and + * [TraversableTransaction.privacySalt]. + */ + override val privacySalt: PrivacySalt? get() = null + + init { + require(availableComponents.size == nonces.size) { "Each visible component should be accompanied by a nonce." } + } + /** * Function that checks the whole filtered structure. * Force type checking on a structure that we obtained, so we don't sign more than expected. @@ -97,6 +133,8 @@ class FilteredLeaves( val checkList = availableComponents.map { checkingFun(it) } return (!checkList.isEmpty()) && checkList.all { it } } + + override val availableComponentHashes: List get() = availableComponents.mapIndexed { index, it -> serializedHash(it, nonces[index]) } } /** diff --git a/core/src/main/kotlin/net/corda/core/transactions/NotaryChangeTransactions.kt b/core/src/main/kotlin/net/corda/core/transactions/NotaryChangeTransactions.kt new file mode 100644 index 0000000000..705e5553ca --- /dev/null +++ b/core/src/main/kotlin/net/corda/core/transactions/NotaryChangeTransactions.kt @@ -0,0 +1,96 @@ +package net.corda.core.transactions + +import net.corda.core.contracts.* +import net.corda.core.crypto.SecureHash +import net.corda.core.crypto.TransactionSignature +import net.corda.core.crypto.toBase58String +import net.corda.core.identity.Party +import net.corda.core.node.ServiceHub +import java.security.PublicKey + +/** + * A special transaction for changing the notary of a state. It only needs specifying the state(s) as input(s), + * old and new notaries. Output states can be computed by applying the notary modification to corresponding inputs + * on the fly. + */ +data class NotaryChangeWireTransaction( + override val inputs: List, + override val notary: Party, + val newNotary: Party +) : CoreTransaction() { + /** + * This transaction does not contain any output states, outputs can be obtained by resolving a + * [NotaryChangeLedgerTransaction] and applying the notary modification to inputs. + */ + override val outputs: List> + get() = emptyList() + + init { + check(inputs.isNotEmpty()) { "A notary change transaction must have inputs" } + check(notary != newNotary) { "The old and new notaries must be different – $newNotary" } + } + + /** + * A privacy salt is not really required in this case, because we already used nonces in normal transactions and + * thus input state refs will always be unique. Also, filtering doesn't apply on this type of transactions. + */ + override val id: SecureHash by lazy { serializedHash(inputs + notary + newNotary) } + + fun resolve(services: ServiceHub, sigs: List): NotaryChangeLedgerTransaction { + val resolvedInputs = inputs.map { ref -> + services.loadState(ref).let { StateAndRef(it, ref) } + } + return NotaryChangeLedgerTransaction(resolvedInputs, notary, newNotary, id, sigs) + } +} + +/** + * A notary change transaction with fully resolved inputs and signatures. In contrast with a regular transaction, + * signatures are checked against the signers specified by input states' *participants* fields, so full resolution is + * needed for signature verification. + */ +data class NotaryChangeLedgerTransaction( + override val inputs: List>, + override val notary: Party, + val newNotary: Party, + override val id: SecureHash, + override val sigs: List +) : FullTransaction(), TransactionWithSignatures { + init { + checkEncumbrances() + } + + /** We compute the outputs on demand by applying the notary field modification to the inputs */ + override val outputs: List> + get() = inputs.mapIndexed { pos, (state) -> + if (state.encumbrance != null) { + state.copy(notary = newNotary, encumbrance = pos + 1) + } else state.copy(notary = newNotary) + } + + override val requiredSigningKeys: Set + get() = inputs.flatMap { it.state.data.participants }.map { it.owningKey }.toSet() + notary.owningKey + + override fun getKeyDescriptions(keys: Set): List { + return keys.map { it.toBase58String() } + } + + /** + * Check that encumbrances have been included in the inputs. The [NotaryChangeFlow] guarantees that an encumbrance + * will follow its encumbered state in the inputs. + */ + private fun checkEncumbrances() { + inputs.forEachIndexed { i, (state, ref) -> + state.encumbrance?.let { + val nextIndex = i + 1 + fun nextStateIsEncumbrance() = (inputs[nextIndex].ref.txhash == ref.txhash) && (inputs[nextIndex].ref.index == it) + if (nextIndex >= inputs.size || !nextStateIsEncumbrance()) { + throw TransactionVerificationException.TransactionMissingEncumbranceException( + id, + it, + TransactionVerificationException.Direction.INPUT) + } + } + } + } +} diff --git a/core/src/main/kotlin/net/corda/core/transactions/SignedTransaction.kt b/core/src/main/kotlin/net/corda/core/transactions/SignedTransaction.kt index 8a4453b2d6..c2afc3f2ec 100644 --- a/core/src/main/kotlin/net/corda/core/transactions/SignedTransaction.kt +++ b/core/src/main/kotlin/net/corda/core/transactions/SignedTransaction.kt @@ -1,16 +1,16 @@ package net.corda.core.transactions -import net.corda.core.contracts.AttachmentResolutionException -import net.corda.core.contracts.NamedByHash -import net.corda.core.contracts.TransactionResolutionException -import net.corda.core.contracts.TransactionVerificationException -import net.corda.core.crypto.DigitalSignature -import net.corda.core.crypto.SecureHash -import net.corda.core.crypto.isFulfilledBy -import net.corda.core.crypto.keys +import net.corda.core.contracts.* +import net.corda.core.crypto.* +import net.corda.core.identity.Party +import net.corda.core.internal.VisibleForTesting import net.corda.core.node.ServiceHub import net.corda.core.serialization.CordaSerializable import net.corda.core.serialization.SerializedBytes +import net.corda.core.serialization.deserialize +import net.corda.core.serialization.serialize +import net.corda.core.utilities.getOrThrow +import java.security.KeyPair import java.security.PublicKey import java.security.SignatureException import java.util.* @@ -29,117 +29,84 @@ import java.util.* * sign. */ // DOCSTART 1 -data class SignedTransaction(val txBits: SerializedBytes, - val sigs: List -) : NamedByHash { -// DOCEND 1 - init { - require(sigs.isNotEmpty()) +data class SignedTransaction(val txBits: SerializedBytes, + override val sigs: List +) : TransactionWithSignatures { + // DOCEND 1 + constructor(ctx: CoreTransaction, sigs: List) : this(ctx.serialize(), sigs) { + cachedTransaction = ctx } - // TODO: This needs to be reworked to ensure that the inner WireTransaction is only ever deserialised sandboxed. + init { + require(sigs.isNotEmpty()) { "Tried to instantiate a ${SignedTransaction::class.java.simpleName} without any signatures " } + } + + /** Cache the deserialized form of the transaction. This is useful when building a transaction or collecting signatures. */ + @Volatile @Transient private var cachedTransaction: CoreTransaction? = null /** Lazily calculated access to the deserialised/hashed transaction data. */ - val tx: WireTransaction by lazy { WireTransaction.deserialize(txBits) } + private val transaction: CoreTransaction get() = cachedTransaction ?: txBits.deserialize().apply { cachedTransaction = this } - /** - * The Merkle root of the inner [WireTransaction]. Note that this is _not_ the same as the simple hash of - * [txBits], which would not use the Merkle tree structure. If the difference isn't clear, please consult - * the user guide section "Transaction tear-offs" to learn more about Merkle trees. - */ - override val id: SecureHash get() = tx.id + /** The id of the contained [WireTransaction]. */ + override val id: SecureHash get() = transaction.id - @CordaSerializable - class SignaturesMissingException(val missing: Set, val descriptions: List, override val id: SecureHash) : NamedByHash, SignatureException() { - override fun toString(): String { - return "Missing signatures for $descriptions on transaction ${id.prefixChars()} for ${missing.joinToString()}" - } - } + /** Returns the contained [WireTransaction], or throws if this is a notary change transaction */ + val tx: WireTransaction get() = transaction as WireTransaction - /** - * Verifies the signatures on this transaction and throws if any are missing which aren't passed as parameters. - * In this context, "verifying" means checking they are valid signatures and that their public keys are in - * the contained transactions [BaseTransaction.mustSign] property. - * - * Normally you would not provide any keys to this function, but if you're in the process of building a partial - * transaction and you want to access the contents before you've signed it, you can specify your own keys here - * to bypass that check. - * - * @throws SignatureException if any signatures are invalid or unrecognised. - * @throws SignaturesMissingException if any signatures should have been present but were not. - */ - // DOCSTART 2 - @Throws(SignatureException::class) - fun verifySignatures(vararg allowedToBeMissing: PublicKey): WireTransaction { - // DOCEND 2 - // Embedded WireTransaction is not deserialised until after we check the signatures. - checkSignaturesAreValid() + /** Returns the contained [NotaryChangeWireTransaction], or throws if this is a normal transaction */ + val notaryChangeTx: NotaryChangeWireTransaction get() = transaction as NotaryChangeWireTransaction - val missing = getMissingSignatures() - if (missing.isNotEmpty()) { - val allowed = allowedToBeMissing.toSet() - val needed = missing - allowed - if (needed.isNotEmpty()) - throw SignaturesMissingException(needed, getMissingKeyDescriptions(needed), id) - } - check(tx.id == id) - return tx - } + /** Helper to access the inputs of the contained transaction */ + val inputs: List get() = transaction.inputs + /** Helper to access the notary of the contained transaction */ + val notary: Party? get() = transaction.notary - /** - * Mathematically validates the signatures that are present on this transaction. This does not imply that - * the signatures are by the right keys, or that there are sufficient signatures, just that they aren't - * corrupt. If you use this function directly you'll need to do the other checks yourself. Probably you - * want [verifySignatures] instead. - * - * @throws SignatureException if a signature fails to verify. - */ - @Throws(SignatureException::class) - fun checkSignaturesAreValid() { - for (sig in sigs) { - sig.verify(id.bytes) - } - } + override val requiredSigningKeys: Set get() = tx.requiredSigningKeys - private fun getMissingSignatures(): Set { - val sigKeys = sigs.map { it.by }.toSet() - // TODO Problem is that we can get single PublicKey wrapped as CompositeKey in allowedToBeMissing/mustSign - // equals on CompositeKey won't catch this case (do we want to single PublicKey be equal to the same key wrapped in CompositeKey with threshold 1?) - val missing = tx.mustSign.filter { !it.isFulfilledBy(sigKeys) }.toSet() - return missing - } - - /** - * Get a human readable description of where signatures are required from, and are missing, to assist in debugging - * the underlying cause. - */ - private fun getMissingKeyDescriptions(missing: Set): ArrayList { + override fun getKeyDescriptions(keys: Set): ArrayList { // TODO: We need a much better way of structuring this data - val missingElements = ArrayList() + val descriptions = ArrayList() this.tx.commands.forEach { command -> - if (command.signers.any { it in missing }) - missingElements.add(command.toString()) + if (command.signers.any { it in keys }) + descriptions.add(command.toString()) } - if (this.tx.notary?.owningKey in missing) - missingElements.add("notary") - return missingElements + if (this.tx.notary?.owningKey in keys) + descriptions.add("notary") + return descriptions + } + + @VisibleForTesting + fun withAdditionalSignature(keyPair: KeyPair, signatureMetadata: SignatureMetadata): SignedTransaction { + val signableData = SignableData(tx.id, signatureMetadata) + return withAdditionalSignature(keyPair.sign(signableData)) } /** Returns the same transaction but with an additional (unchecked) signature. */ - fun withAdditionalSignature(sig: DigitalSignature.WithKey) = copy(sigs = sigs + sig) + fun withAdditionalSignature(sig: TransactionSignature) = copyWithCache(listOf(sig)) /** Returns the same transaction but with an additional (unchecked) signatures. */ - fun withAdditionalSignatures(sigList: Iterable) = copy(sigs = sigs + sigList) - - /** Alias for [withAdditionalSignature] to let you use Kotlin operator overloading. */ - operator fun plus(sig: DigitalSignature.WithKey) = withAdditionalSignature(sig) - - /** Alias for [withAdditionalSignatures] to let you use Kotlin operator overloading. */ - operator fun plus(sigList: Collection) = withAdditionalSignatures(sigList) + fun withAdditionalSignatures(sigList: Iterable) = copyWithCache(sigList) /** - * Checks the transaction's signatures are valid, optionally calls [verifySignatures] to check - * all required signatures are present, and then calls [WireTransaction.toLedgerTransaction] + * Creates a copy of the SignedTransaction that includes the provided [sigList]. Also propagates the [cachedTransaction] + * so the contained transaction does not need to be deserialized again. + */ + private fun copyWithCache(sigList: Iterable): SignedTransaction { + val cached = cachedTransaction + return copy(sigs = sigs + sigList).apply { + cachedTransaction = cached + } + } + + /** Alias for [withAdditionalSignature] to let you use Kotlin operator overloading. */ + operator fun plus(sig: TransactionSignature) = withAdditionalSignature(sig) + + /** Alias for [withAdditionalSignatures] to let you use Kotlin operator overloading. */ + operator fun plus(sigList: Collection) = withAdditionalSignatures(sigList) + + /** + * Checks the transaction's signatures are valid, optionally calls [verifyRequiredSignatures] to + * check all required signatures are present, and then calls [WireTransaction.toLedgerTransaction] * with the passed in [ServiceHub] to resolve the dependencies, returning an unverified * LedgerTransaction. * @@ -156,15 +123,14 @@ data class SignedTransaction(val txBits: SerializedBytes, @Throws(SignatureException::class, AttachmentResolutionException::class, TransactionResolutionException::class) fun toLedgerTransaction(services: ServiceHub, checkSufficientSignatures: Boolean = true): LedgerTransaction { checkSignaturesAreValid() - if (checkSufficientSignatures) verifySignatures() + if (checkSufficientSignatures) verifyRequiredSignatures() return tx.toLedgerTransaction(services) } /** - * Checks the transaction's signatures are valid, optionally calls [verifySignatures] to check - * all required signatures are present, calls [WireTransaction.toLedgerTransaction] with the - * passed in [ServiceHub] to resolve the dependencies and return an unverified - * LedgerTransaction, then verifies the LedgerTransaction. + * Checks the transaction's signatures are valid, optionally calls [verifyRequiredSignatures] to check + * all required signatures are present. Resolves inputs and attachments from the local storage and performs full + * transaction verification, including running the contracts. * * @throws AttachmentResolutionException if a required attachment was not found in storage. * @throws TransactionResolutionException if an input points to a transaction not found in storage. @@ -174,10 +140,41 @@ data class SignedTransaction(val txBits: SerializedBytes, @JvmOverloads @Throws(SignatureException::class, AttachmentResolutionException::class, TransactionResolutionException::class, TransactionVerificationException::class) fun verify(services: ServiceHub, checkSufficientSignatures: Boolean = true) { + if (isNotaryChangeTransaction()) { + verifyNotaryChangeTransaction(checkSufficientSignatures, services) + } else { + verifyRegularTransaction(checkSufficientSignatures, services) + } + } + + private fun verifyRegularTransaction(checkSufficientSignatures: Boolean, services: ServiceHub) { checkSignaturesAreValid() - if (checkSufficientSignatures) verifySignatures() - tx.toLedgerTransaction(services).verify() + if (checkSufficientSignatures) verifyRequiredSignatures() + val ltx = tx.toLedgerTransaction(services) + // TODO: allow non-blocking verification + services.transactionVerifierService.verify(ltx).getOrThrow() + } + + private fun verifyNotaryChangeTransaction(checkSufficientSignatures: Boolean, services: ServiceHub) { + val ntx = resolveNotaryChangeTransaction(services) + if (checkSufficientSignatures) ntx.verifyRequiredSignatures() + } + + fun isNotaryChangeTransaction() = transaction is NotaryChangeWireTransaction + + /** + * If [transaction] is a [NotaryChangeWireTransaction], loads the input states and resolves it to a + * [NotaryChangeLedgerTransaction] so the signatures can be verified. + */ + fun resolveNotaryChangeTransaction(services: ServiceHub): NotaryChangeLedgerTransaction { + val ntx = transaction as? NotaryChangeWireTransaction + ?: throw IllegalStateException("Expected a ${NotaryChangeWireTransaction::class.simpleName} but found ${transaction::class.simpleName}") + return ntx.resolve(services, sigs) } override fun toString(): String = "${javaClass.simpleName}(id=$id)" + + @CordaSerializable + class SignaturesMissingException(val missing: Set, val descriptions: List, override val id: SecureHash) + : NamedByHash, SignatureException("Missing signatures for $descriptions on transaction ${id.prefixChars()} for ${missing.joinToString()}") } diff --git a/core/src/main/kotlin/net/corda/core/transactions/TransactionBuilder.kt b/core/src/main/kotlin/net/corda/core/transactions/TransactionBuilder.kt index 61a42d2dbb..5e10965370 100644 --- a/core/src/main/kotlin/net/corda/core/transactions/TransactionBuilder.kt +++ b/core/src/main/kotlin/net/corda/core/transactions/TransactionBuilder.kt @@ -6,10 +6,9 @@ import net.corda.core.crypto.* import net.corda.core.identity.Party import net.corda.core.internal.FlowStateMachine import net.corda.core.node.ServiceHub -import net.corda.core.serialization.serialize +import net.corda.core.node.services.KeyManagementService import java.security.KeyPair import java.security.PublicKey -import java.security.SignatureException import java.time.Duration import java.time.Instant import java.util.* @@ -25,35 +24,30 @@ import java.util.* * @param notary Notary used for the transaction. If null, this indicates the transaction DOES NOT have a notary. * When this is set to a non-null value, an output state can be added by just passing in a [ContractState] – a * [TransactionState] with this notary specified will be generated automatically. - * - * @param signers The set of public keys the transaction needs signatures for. The logic for building the signers set - * can be customised for every [TransactionType]. E.g. in the general case it contains the command and notary public keys, - * but for the [TransactionType.NotaryChange] transactions it is the set of all input [ContractState.participants]. */ open class TransactionBuilder( - protected val type: TransactionType = TransactionType.General, var notary: Party? = null, var lockId: UUID = (Strand.currentStrand() as? FlowStateMachine<*>)?.id?.uuid ?: UUID.randomUUID(), protected val inputs: MutableList = arrayListOf(), protected val attachments: MutableList = arrayListOf(), protected val outputs: MutableList> = arrayListOf(), - protected val commands: MutableList = arrayListOf(), - protected val signers: MutableSet = mutableSetOf(), - protected var window: TimeWindow? = null) { - constructor(type: TransactionType, notary: Party) : this(type, notary, (Strand.currentStrand() as? FlowStateMachine<*>)?.id?.uuid ?: UUID.randomUUID()) + protected val commands: MutableList> = arrayListOf(), + protected var window: TimeWindow? = null, + protected var privacySalt: PrivacySalt = PrivacySalt() + ) { + constructor(notary: Party) : this (notary, (Strand.currentStrand() as? FlowStateMachine<*>)?.id?.uuid ?: UUID.randomUUID()) /** * Creates a copy of the builder. */ fun copy() = TransactionBuilder( - type = type, notary = notary, inputs = ArrayList(inputs), attachments = ArrayList(attachments), outputs = ArrayList(outputs), commands = ArrayList(commands), - signers = LinkedHashSet(signers), - window = window + window = window, + privacySalt = privacySalt ) // DOCSTART 1 @@ -65,9 +59,10 @@ open class TransactionBuilder( is SecureHash -> addAttachment(t) is TransactionState<*> -> addOutputState(t) is ContractState -> addOutputState(t) - is Command -> addCommand(t) + is Command<*> -> addCommand(t) is CommandData -> throw IllegalArgumentException("You passed an instance of CommandData, but that lacks the pubkey. You need to wrap it in a Command object first.") is TimeWindow -> setTimeWindow(t) + is PrivacySalt -> setPrivacySalt(t) else -> throw IllegalArgumentException("Wrong argument type: ${t.javaClass}") } } @@ -76,7 +71,7 @@ open class TransactionBuilder( // DOCEND 1 fun toWireTransaction() = WireTransaction(ArrayList(inputs), ArrayList(attachments), - ArrayList(outputs), ArrayList(commands), notary, signers.toList(), type, window) + ArrayList(outputs), ArrayList(commands), notary, window, privacySalt) @Throws(AttachmentResolutionException::class, TransactionResolutionException::class) fun toLedgerTransaction(services: ServiceHub) = toWireTransaction().toLedgerTransaction(services) @@ -89,7 +84,6 @@ open class TransactionBuilder( open fun addInputState(stateAndRef: StateAndRef<*>): TransactionBuilder { val notary = stateAndRef.state.notary require(notary == this.notary) { "Input state requires notary \"$notary\" which does not match the transaction notary \"${this.notary}\"." } - signers.add(notary.owningKey) inputs.add(stateAndRef.ref) return this } @@ -105,7 +99,9 @@ open class TransactionBuilder( } @JvmOverloads - fun addOutputState(state: ContractState, notary: Party, encumbrance: Int? = null) = addOutputState(TransactionState(state, notary, encumbrance)) + fun addOutputState(state: ContractState, notary: Party, encumbrance: Int? = null): TransactionBuilder { + return addOutputState(TransactionState(state, notary, encumbrance)) + } /** A default notary must be specified during builder construction to use this method */ fun addOutputState(state: ContractState): TransactionBuilder { @@ -114,9 +110,7 @@ open class TransactionBuilder( return this } - fun addCommand(arg: Command): TransactionBuilder { - // TODO: replace pubkeys in commands with 'pointers' to keys in signers - signers.addAll(arg.signers) + fun addCommand(arg: Command<*>): TransactionBuilder { commands.add(arg) return this } @@ -131,7 +125,6 @@ open class TransactionBuilder( */ fun setTimeWindow(timeWindow: TimeWindow): TransactionBuilder { check(notary != null) { "Only notarised transactions can have a time-window" } - signers.add(notary!!.owningKey) window = timeWindow return this } @@ -145,64 +138,25 @@ open class TransactionBuilder( */ fun setTimeWindow(time: Instant, timeTolerance: Duration) = setTimeWindow(TimeWindow.withTolerance(time, timeTolerance)) + fun setPrivacySalt(privacySalt: PrivacySalt): TransactionBuilder { + this.privacySalt = privacySalt + return this + } + // Accessors that yield immutable snapshots. fun inputStates(): List = ArrayList(inputs) fun attachments(): List = ArrayList(attachments) fun outputStates(): List> = ArrayList(outputs) - fun commands(): List = ArrayList(commands) + fun commands(): List> = ArrayList(commands) - /** The signatures that have been collected so far - might be incomplete! */ - @Deprecated("Signatures should be gathered on a SignedTransaction instead.") - protected val currentSigs = arrayListOf() - - @Deprecated("Use ServiceHub.signInitialTransaction() instead.") - fun signWith(key: KeyPair): TransactionBuilder { - val data = toWireTransaction().id - addSignatureUnchecked(key.sign(data.bytes)) - return this - } - - /** Adds the signature directly to the transaction, without checking it for validity. */ - @Deprecated("Use ServiceHub.signInitialTransaction() instead.") - fun addSignatureUnchecked(sig: DigitalSignature.WithKey): TransactionBuilder { - currentSigs.add(sig) - return this - } - - @Deprecated("Use ServiceHub.signInitialTransaction() instead.") - fun toSignedTransaction(checkSufficientSignatures: Boolean = true): SignedTransaction { - if (checkSufficientSignatures) { - val gotKeys = currentSigs.map { it.by }.toSet() - val missing: Set = signers.filter { !it.isFulfilledBy(gotKeys) }.toSet() - if (missing.isNotEmpty()) - throw IllegalStateException("Missing signatures on the transaction for the public keys: ${missing.joinToString()}") - } + /** + * Sign the built transaction and return it. This is an internal function for use by the service hub, please use + * [ServiceHub.signInitialTransaction] instead. + */ + fun toSignedTransaction(keyManagementService: KeyManagementService, publicKey: PublicKey, signatureMetadata: SignatureMetadata): SignedTransaction { val wtx = toWireTransaction() - return SignedTransaction(wtx.serialize(), ArrayList(currentSigs)) - } - - /** - * Checks that the given signature matches one of the commands and that it is a correct signature over the tx, then - * adds it. - * - * @throws SignatureException if the signature didn't match the transaction contents. - * @throws IllegalArgumentException if the signature key doesn't appear in any command. - */ - @Deprecated("Use WireTransaction.checkSignature() instead.") - fun checkAndAddSignature(sig: DigitalSignature.WithKey) { - checkSignature(sig) - addSignatureUnchecked(sig) - } - - /** - * Checks that the given signature matches one of the commands and that it is a correct signature over the tx. - * - * @throws SignatureException if the signature didn't match the transaction contents. - * @throws IllegalArgumentException if the signature key doesn't appear in any command. - */ - @Deprecated("Use WireTransaction.checkSignature() instead.") - fun checkSignature(sig: DigitalSignature.WithKey) { - require(commands.any { it.signers.any { sig.by in it.keys } }) { "Signature key doesn't match any command" } - sig.verify(toWireTransaction().id) + val signableData = SignableData(wtx.id, signatureMetadata) + val sig = keyManagementService.sign(signableData, publicKey) + return SignedTransaction(wtx, listOf(sig)) } } \ No newline at end of file diff --git a/core/src/main/kotlin/net/corda/core/transactions/TransactionWithSignatures.kt b/core/src/main/kotlin/net/corda/core/transactions/TransactionWithSignatures.kt new file mode 100644 index 0000000000..5bc5f3ce0c --- /dev/null +++ b/core/src/main/kotlin/net/corda/core/transactions/TransactionWithSignatures.kt @@ -0,0 +1,81 @@ +package net.corda.core.transactions + +import net.corda.core.contracts.NamedByHash +import net.corda.core.crypto.TransactionSignature +import net.corda.core.crypto.isFulfilledBy +import net.corda.core.transactions.SignedTransaction.SignaturesMissingException +import net.corda.core.utilities.toNonEmptySet +import java.security.InvalidKeyException +import java.security.PublicKey +import java.security.SignatureException + +/** An interface for transactions containing signatures, with logic for signature verification */ +interface TransactionWithSignatures : NamedByHash { + val sigs: List + + /** Specifies all the public keys that require signatures for the transaction to be valid */ + val requiredSigningKeys: Set + + /** + * Verifies the signatures on this transaction and throws if any are missing. In this context, "verifying" means + * checking they are valid signatures and that their public keys are in the [requiredSigningKeys] set. + * + * @throws SignatureException if any signatures are invalid or unrecognised. + * @throws SignaturesMissingException if any signatures should have been present but were not. + */ + @Throws(SignatureException::class) + fun verifyRequiredSignatures() = verifySignaturesExcept() + + /** + * Verifies the signatures on this transaction and throws if any are missing which aren't passed as parameters. + * In this context, "verifying" means checking they are valid signatures and that their public keys are in + * the [requiredSigningKeys] set. + * + * Normally you would not provide any keys to this function, but if you're in the process of building a partial + * transaction and you want to access the contents before you've signed it, you can specify your own keys here + * to bypass that check. + * + * @throws SignatureException if any signatures are invalid or unrecognised. + * @throws SignaturesMissingException if any signatures should have been present but were not. + */ + @Throws(SignatureException::class) + fun verifySignaturesExcept(vararg allowedToBeMissing: PublicKey) { + checkSignaturesAreValid() + + val needed = getMissingSignatures() - allowedToBeMissing + if (needed.isNotEmpty()) + throw SignaturesMissingException(needed.toNonEmptySet(), getKeyDescriptions(needed), id) + } + + /** + * Mathematically validates the signatures that are present on this transaction. This does not imply that + * the signatures are by the right keys, or that there are sufficient signatures, just that they aren't + * corrupt. If you use this function directly you'll need to do the other checks yourself. Probably you + * want [verifySignatures] instead. + * + * @throws InvalidKeyException if the key on a signature is invalid. + * @throws SignatureException if a signature fails to verify. + */ + @Throws(InvalidKeyException::class, SignatureException::class) + fun checkSignaturesAreValid() { + for (sig in sigs) { + sig.verify(id) + } + } + + /** + * Get a human readable description of where signatures are required from, and are missing, to assist in debugging + * the underlying cause. + * + * Note that the results should not be serialised, parsed or expected to remain stable between Corda versions. + */ + fun getKeyDescriptions(keys: Set): List + + private fun getMissingSignatures(): Set { + val sigKeys = sigs.map { it.by }.toSet() + // TODO Problem is that we can get single PublicKey wrapped as CompositeKey in allowedToBeMissing/mustSign + // equals on CompositeKey won't catch this case (do we want to single PublicKey be equal to the same key wrapped in CompositeKey with threshold 1?) + val missing = requiredSigningKeys.filter { !it.isFulfilledBy(sigKeys) }.toSet() + return missing + } +} \ No newline at end of file diff --git a/core/src/main/kotlin/net/corda/core/transactions/WireTransaction.kt b/core/src/main/kotlin/net/corda/core/transactions/WireTransaction.kt index 684ad86f9b..970ca4d728 100644 --- a/core/src/main/kotlin/net/corda/core/transactions/WireTransaction.kt +++ b/core/src/main/kotlin/net/corda/core/transactions/WireTransaction.kt @@ -1,70 +1,54 @@ package net.corda.core.transactions -import com.esotericsoftware.kryo.pool.KryoPool import net.corda.core.contracts.* -import net.corda.core.crypto.DigitalSignature import net.corda.core.crypto.MerkleTree import net.corda.core.crypto.SecureHash +import net.corda.core.crypto.TransactionSignature import net.corda.core.crypto.keys import net.corda.core.identity.Party -import net.corda.core.indexOfOrThrow +import net.corda.core.internal.Emoji import net.corda.core.node.ServicesForResolution -import net.corda.core.serialization.SerializedBytes -import net.corda.core.serialization.deserialize -import net.corda.core.serialization.p2PKryo -import net.corda.core.serialization.serialize -import net.corda.core.utilities.Emoji import java.security.PublicKey import java.security.SignatureException import java.util.function.Predicate /** * A transaction ready for serialisation, without any signatures attached. A WireTransaction is usually wrapped - * by a [SignedTransaction] that carries the signatures over this payload. The hash of the wire transaction is - * the identity of the transaction, that is, it's possible for two [SignedTransaction]s with different sets of - * signatures to have the same identity hash. + * by a [SignedTransaction] that carries the signatures over this payload. + * The identity of the transaction is the Merkle tree root of its components (see [MerkleTree]). */ -class WireTransaction( +data class WireTransaction( /** Pointers to the input states on the ledger, identified by (tx identity hash, output index). */ override val inputs: List, /** Hashes of the ZIP/JAR files that are needed to interpret the contents of this wire transaction. */ override val attachments: List, - outputs: List>, + override val outputs: List>, /** Ordered list of ([CommandData], [PublicKey]) pairs that instruct the contracts what to do. */ - override val commands: List, - notary: Party?, - signers: List, - type: TransactionType, - timeWindow: TimeWindow? -) : BaseTransaction(inputs, outputs, notary, signers, type, timeWindow), TraversableTransaction { + override val commands: List>, + override val notary: Party?, + override val timeWindow: TimeWindow?, + override val privacySalt: PrivacySalt = PrivacySalt() +) : CoreTransaction(), TraversableTransaction { init { - checkInvariants() + checkBaseInvariants() + if (timeWindow != null) check(notary != null) { "Transactions with time-windows must be notarised" } + check(availableComponents.isNotEmpty()) { "A WireTransaction cannot be empty" } } - // Cache the serialised form of the transaction and its hash to give us fast access to it. - @Volatile @Transient private var cachedBytes: SerializedBytes? = null - val serialized: SerializedBytes get() = cachedBytes ?: serialize().apply { cachedBytes = this } + /** The transaction id is represented by the root hash of Merkle tree over the transaction components. */ + override val id: SecureHash get() = merkleTree.hash - override val id: SecureHash by lazy { merkleTree.hash } - - companion object { - fun deserialize(data: SerializedBytes, kryo: KryoPool = p2PKryo()): WireTransaction { - val wtx = data.bytes.deserialize(kryo) - wtx.cachedBytes = data - return wtx + /** Public keys that need to be fulfilled by signatures in order for the transaction to be valid. */ + val requiredSigningKeys: Set get() { + val commandKeys = commands.flatMap { it.signers }.toSet() + // TODO: prevent notary field from being set if there are no inputs and no timestamp + return if (notary != null && (inputs.isNotEmpty() || timeWindow != null)) { + commandKeys + notary.owningKey + } else { + commandKeys } } - /** Returns a [StateAndRef] for the given output index. */ - @Suppress("UNCHECKED_CAST") - fun outRef(index: Int): StateAndRef { - require(index >= 0 && index < outputs.size) - return StateAndRef(outputs[index] as TransactionState, StateRef(id, index)) - } - - /** Returns a [StateAndRef] for the requested output state, or throws [IllegalArgumentException] if not found. */ - fun outRef(state: ContractState): StateAndRef = outRef(outputs.map { it.data }.indexOfOrThrow(state)) - /** * Looks up identities and attachments from storage to generate a [LedgerTransaction]. A transaction is expected to * have been fully resolved using the resolution flow by this point. @@ -104,7 +88,7 @@ class WireTransaction( val resolvedInputs = inputs.map { ref -> resolveStateRef(ref)?.let { StateAndRef(it, ref) } ?: throw TransactionResolutionException(ref.txhash) } - return LedgerTransaction(resolvedInputs, outputs, authenticatedArgs, attachments, id, notary, mustSign, timeWindow, type) + return LedgerTransaction(resolvedInputs, outputs, authenticatedArgs, attachments, id, notary, timeWindow, privacySalt) } /** @@ -121,30 +105,74 @@ class WireTransaction( /** * Construction of partial transaction from WireTransaction based on filtering. + * Note that list of nonces to be sent is updated on the fly, based on the index of the filtered tx component. * @param filtering filtering over the whole WireTransaction * @returns FilteredLeaves used in PartialMerkleTree calculation and verification. */ fun filterWithFun(filtering: Predicate): FilteredLeaves { - fun notNullFalse(elem: Any?): Any? = if (elem == null || !filtering.test(elem)) null else elem + val nonces: MutableList = mutableListOf() + val offsets = indexOffsets() + fun notNullFalseAndNoncesUpdate(elem: Any?, index: Int): Any? { + return if (elem == null || !filtering.test(elem)) { + null + } else { + nonces.add(computeNonce(privacySalt, index)) + elem + } + } + + fun filterAndNoncesUpdate(t: T, index: Int): Boolean { + return if (filtering.test(t)) { + nonces.add(computeNonce(privacySalt, index)) + true + } else { + false + } + } + + // TODO: We should have a warning (require) if all leaves (excluding salt) are visible after filtering. + // Consider the above after refactoring FilteredTransaction to implement TraversableTransaction, + // so that a WireTransaction can be used when required to send a full tx (e.g. RatesFixFlow in Oracles). return FilteredLeaves( - inputs.filter { filtering.test(it) }, - attachments.filter { filtering.test(it) }, - outputs.filter { filtering.test(it) }, - commands.filter { filtering.test(it) }, - notNullFalse(notary) as Party?, - mustSign.filter { filtering.test(it) }, - notNullFalse(type) as TransactionType?, - notNullFalse(timeWindow) as TimeWindow? + inputs.filterIndexed { index, it -> filterAndNoncesUpdate(it, index) }, + attachments.filterIndexed { index, it -> filterAndNoncesUpdate(it, index + offsets[0]) }, + outputs.filterIndexed { index, it -> filterAndNoncesUpdate(it, index + offsets[1]) }, + commands.filterIndexed { index, it -> filterAndNoncesUpdate(it, index + offsets[2]) }, + notNullFalseAndNoncesUpdate(notary, offsets[3]) as Party?, + notNullFalseAndNoncesUpdate(timeWindow, offsets[4]) as TimeWindow?, + nonces ) } + // We use index offsets, to get the actual leaf-index per transaction component required for nonce computation. + private fun indexOffsets(): List { + // There is no need to add an index offset for inputs, because they are the first components in the + // transaction format and it is always zero. Thus, offsets[0] corresponds to attachments, + // offsets[1] to outputs, offsets[2] to commands and so on. + val offsets = mutableListOf(inputs.size, inputs.size + attachments.size) + offsets.add(offsets.last() + outputs.size) + offsets.add(offsets.last() + commands.size) + if (notary != null) { + offsets.add(offsets.last() + 1) + } else { + offsets.add(offsets.last()) + } + if (timeWindow != null) { + offsets.add(offsets.last() + 1) + } else { + offsets.add(offsets.last()) + } + // No need to add offset for privacySalt as it doesn't require a nonce. + return offsets + } + /** * Checks that the given signature matches one of the commands and that it is a correct signature over the tx. * * @throws SignatureException if the signature didn't match the transaction contents. * @throws IllegalArgumentException if the signature key doesn't appear in any command. */ - fun checkSignature(sig: DigitalSignature.WithKey) { + fun checkSignature(sig: TransactionSignature) { require(commands.any { it.signers.any { sig.by in it.keys } }) { "Signature key doesn't match any command" } sig.verify(id) } @@ -153,35 +181,9 @@ class WireTransaction( val buf = StringBuilder() buf.appendln("Transaction:") for (input in inputs) buf.appendln("${Emoji.rightArrow}INPUT: $input") - for (output in outputs) buf.appendln("${Emoji.leftArrow}OUTPUT: ${output.data}") + for ((data) in outputs) buf.appendln("${Emoji.leftArrow}OUTPUT: $data") for (command in commands) buf.appendln("${Emoji.diamond}COMMAND: $command") for (attachment in attachments) buf.appendln("${Emoji.paperclip}ATTACHMENT: $attachment") return buf.toString() } - - // TODO: When Kotlin 1.1 comes out we can make this class a data class again, and have these be autogenerated. - - override fun equals(other: Any?): Boolean { - if (this === other) return true - if (other?.javaClass != javaClass) return false - if (!super.equals(other)) return false - - other as WireTransaction - - if (inputs != other.inputs) return false - if (attachments != other.attachments) return false - if (outputs != other.outputs) return false - if (commands != other.commands) return false - - return true - } - - override fun hashCode(): Int { - var result = super.hashCode() - result = 31 * result + inputs.hashCode() - result = 31 * result + attachments.hashCode() - result = 31 * result + outputs.hashCode() - result = 31 * result + commands.hashCode() - return result - } } diff --git a/core/src/main/kotlin/net/corda/core/utilities/ByteArrays.kt b/core/src/main/kotlin/net/corda/core/utilities/ByteArrays.kt index 3102086b43..a43f86d4ab 100644 --- a/core/src/main/kotlin/net/corda/core/utilities/ByteArrays.kt +++ b/core/src/main/kotlin/net/corda/core/utilities/ByteArrays.kt @@ -2,42 +2,146 @@ package net.corda.core.utilities -import com.google.common.io.BaseEncoding import net.corda.core.serialization.CordaSerializable import java.io.ByteArrayInputStream -import java.util.* +import javax.xml.bind.DatatypeConverter + +/** + * An abstraction of a byte array, with offset and size that does no copying of bytes unless asked to. + * + * The data of interest typically starts at position [offset] within the [bytes] and is [size] bytes long. + */ +@CordaSerializable +sealed class ByteSequence : Comparable { + /** + * The underlying bytes. + */ + abstract val bytes: ByteArray + /** + * The number of bytes this sequence represents. + */ + abstract val size: Int + /** + * The start position of the sequence within the byte array. + */ + abstract val offset: Int + /** Returns a [ByteArrayInputStream] of the bytes */ + fun open() = ByteArrayInputStream(bytes, offset, size) + + /** + * Create a sub-sequence backed by the same array. + * + * @param offset The offset within this sequence to start the new sequence. Note: not the offset within the backing array. + * @param size The size of the intended sub sequence. + */ + fun subSequence(offset: Int, size: Int): ByteSequence { + require(offset >= 0) + require(offset + size <= this.size) + return if (offset == 0 && size == this.size) this else of(bytes, this.offset + offset, size) + } + + companion object { + /** + * Construct a [ByteSequence] given a [ByteArray] and optional offset and size, that represents that potentially + * sub-sequence of bytes. The returned implementation is optimised when the whole [ByteArray] is the sequence. + */ + @JvmStatic + @JvmOverloads + fun of(bytes: ByteArray, offset: Int = 0, size: Int = bytes.size): ByteSequence { + return if (offset == 0 && size == bytes.size && size != 0) OpaqueBytes(bytes) else OpaqueBytesSubSequence(bytes, offset, size) + } + } + + /** + * Take the first n bytes of this sequence as a sub-sequence. See [subSequence] for further semantics. + */ + fun take(n: Int): ByteSequence { + require(size >= n) + return subSequence(0, n) + } + + /** + * Copy this sequence, complete with new backing array. This can be helpful to break references to potentially + * large backing arrays from small sub-sequences. + */ + fun copy(): ByteSequence = of(bytes.copyOfRange(offset, offset + size)) + + /** + * Compare byte arrays byte by byte. Arrays that are shorter are deemed less than longer arrays if all the bytes + * of the shorter array equal those in the same position of the longer array. + */ + override fun compareTo(other: ByteSequence): Int { + val min = minOf(this.size, other.size) + // Compare min bytes + for (index in 0 until min) { + val unsignedThis = java.lang.Byte.toUnsignedInt(this.bytes[this.offset + index]) + val unsignedOther = java.lang.Byte.toUnsignedInt(other.bytes[other.offset + index]) + if (unsignedThis != unsignedOther) { + return Integer.signum(unsignedThis - unsignedOther) + } + } + // First min bytes is the same, so now resort to size + return Integer.signum(this.size - other.size) + } + + override fun equals(other: Any?): Boolean { + if (this === other) return true + if (other !is ByteSequence) return false + if (this.size != other.size) return false + return subArraysEqual(this.bytes, this.offset, this.size, other.bytes, other.offset) + } + + private fun subArraysEqual(a: ByteArray, aOffset: Int, length: Int, b: ByteArray, bOffset: Int): Boolean { + var bytesRemaining = length + var aPos = aOffset + var bPos = bOffset + while (bytesRemaining-- > 0) { + if (a[aPos++] != b[bPos++]) return false + } + return true + } + + override fun hashCode(): Int { + var result = 1 + for (index in offset until (offset + size)) { + result = 31 * result + bytes[index] + } + return result + } + + override fun toString(): String = "[${bytes.copyOfRange(offset, offset + size).toHexString()}]" +} /** * A simple class that wraps a byte array and makes the equals/hashCode/toString methods work as you actually expect. * In an ideal JVM this would be a value type and be completely overhead free. Project Valhalla is adding such * functionality to Java, but it won't arrive for a few years yet! */ -@CordaSerializable -open class OpaqueBytes(val bytes: ByteArray) { +open class OpaqueBytes(override val bytes: ByteArray) : ByteSequence() { companion object { @JvmStatic fun of(vararg b: Byte) = OpaqueBytes(byteArrayOf(*b)) } init { - check(bytes.isNotEmpty()) + require(bytes.isNotEmpty()) } - override fun equals(other: Any?): Boolean { - if (this === other) return true - if (other !is OpaqueBytes) return false - return Arrays.equals(bytes, other.bytes) - } - - override fun hashCode() = Arrays.hashCode(bytes) - override fun toString() = "[" + bytes.toHexString() + "]" - - val size: Int get() = bytes.size - - /** Returns a [ByteArrayInputStream] of the bytes */ - fun open() = ByteArrayInputStream(bytes) + override val size: Int get() = bytes.size + override val offset: Int get() = 0 } +@Deprecated("Use sequence instead") fun ByteArray.opaque(): OpaqueBytes = OpaqueBytes(this) -fun ByteArray.toHexString(): String = BaseEncoding.base16().encode(this) -fun String.parseAsHex(): ByteArray = BaseEncoding.base16().decode(this) + +fun ByteArray.sequence(offset: Int = 0, size: Int = this.size) = ByteSequence.of(this, offset, size) + +fun ByteArray.toHexString(): String = DatatypeConverter.printHexBinary(this) +fun String.parseAsHex(): ByteArray = DatatypeConverter.parseHexBinary(this) + +private class OpaqueBytesSubSequence(override val bytes: ByteArray, override val offset: Int, override val size: Int) : ByteSequence() { + init { + require(offset >= 0 && offset < bytes.size) + require(size >= 0 && size <= bytes.size) + } +} diff --git a/core/src/main/kotlin/net/corda/core/utilities/KotlinUtils.kt b/core/src/main/kotlin/net/corda/core/utilities/KotlinUtils.kt index e6e656a199..16f29be8b7 100644 --- a/core/src/main/kotlin/net/corda/core/utilities/KotlinUtils.kt +++ b/core/src/main/kotlin/net/corda/core/utilities/KotlinUtils.kt @@ -1,7 +1,26 @@ package net.corda.core.utilities +import net.corda.core.internal.concurrent.get +import net.corda.core.serialization.CordaSerializable import org.slf4j.Logger import org.slf4j.LoggerFactory +import java.time.Duration +import java.util.concurrent.ExecutionException +import java.util.concurrent.Future +import kotlin.reflect.KProperty + +// +// READ ME FIRST: +// This is a collection of public utilities useful only for Kotlin code. Think carefully before adding anything here and +// make sure it's tested and documented. If you're looking to add a public utility that is also relevant to Java then +// don't put it here but in a seperate file called Utils.kt +// + +/** Like the + operator but throws [ArithmeticException] in case of integer overflow. */ +infix fun Int.exactAdd(b: Int): Int = Math.addExact(this, b) + +/** Like the + operator but throws [ArithmeticException] in case of integer overflow. */ +infix fun Long.exactAdd(b: Long): Long = Math.addExact(this, b) /** * Get the [Logger] for a class using the syntax @@ -18,4 +37,69 @@ inline fun Logger.trace(msg: () -> String) { /** Log a DEBUG level message produced by evaluating the given lamdba, but only if DEBUG logging is enabled. */ inline fun Logger.debug(msg: () -> String) { if (isDebugEnabled) debug(msg()) -} \ No newline at end of file +} + +/** + * Extension method for easier construction of [Duration]s in terms of integer days: `val twoDays = 2.days`. + * @see Duration.ofDays + */ +inline val Int.days: Duration get() = Duration.ofDays(toLong()) + +/** + * Extension method for easier construction of [Duration]s in terms of integer hours: `val twoHours = 2.hours`. + * @see Duration.ofHours + */ +inline val Int.hours: Duration get() = Duration.ofHours(toLong()) + +/** + * Extension method for easier construction of [Duration]s in terms of integer minutes: `val twoMinutes = 2.minutes`. + * @see Duration.ofMinutes + */ +inline val Int.minutes: Duration get() = Duration.ofMinutes(toLong()) + +/** + * Extension method for easier construction of [Duration]s in terms of integer seconds: `val twoSeconds = 2.seconds`. + * @see Duration.ofSeconds + */ +inline val Int.seconds: Duration get() = Duration.ofSeconds(toLong()) + +/** + * Extension method for easier construction of [Duration]s in terms of integer milliseconds: `val twoMillis = 2.millis`. + * @see Duration.ofMillis + */ +inline val Int.millis: Duration get() = Duration.ofMillis(toLong()) + +/** + * A simple wrapper that enables the use of Kotlin's `val x by transient { ... }` syntax. Such a property + * will not be serialized, and if it's missing (or the first time it's accessed), the initializer will be + * used to set it up. + */ +@Suppress("DEPRECATION") +fun transient(initializer: () -> T) = TransientProperty(initializer) + +@Deprecated("Use transient") +@CordaSerializable +class TransientProperty(private val initialiser: () -> T) { + @Transient private var initialised = false + @Transient private var value: T? = null + + @Suppress("UNCHECKED_CAST") + @Synchronized + operator fun getValue(thisRef: Any?, property: KProperty<*>): T { + if (!initialised) { + value = initialiser() + initialised = true + } + return value as T + } +} + +/** @see NonEmptySet.copyOf */ +fun Collection.toNonEmptySet(): NonEmptySet = NonEmptySet.copyOf(this) + +/** Same as [Future.get] except that the [ExecutionException] is unwrapped. */ +fun Future.getOrThrow(timeout: Duration? = null): V = try { + get(timeout) +} catch (e: ExecutionException) { + throw e.cause!! +} diff --git a/core/src/main/kotlin/net/corda/core/utilities/LegalNameValidator.kt b/core/src/main/kotlin/net/corda/core/utilities/LegalNameValidator.kt index 9ade89dfaf..bea9a6241d 100644 --- a/core/src/main/kotlin/net/corda/core/utilities/LegalNameValidator.kt +++ b/core/src/main/kotlin/net/corda/core/utilities/LegalNameValidator.kt @@ -113,7 +113,7 @@ private class X500NameRule : Rule { private class MustHaveAtLeastTwoLettersRule : Rule { override fun validate(legalName: String) { // Try to exclude names like "/", "£", "X" etc. - require(legalName.count { it.isLetter() } >= 3) { "Illegal input legal name '$legalName'. Legal name must have at least two letters" } + require(legalName.count { it.isLetter() } >= 2) { "Illegal input legal name '$legalName'. Legal name must have at least two letters" } } } diff --git a/core/src/main/kotlin/net/corda/core/utilities/NonEmptySet.kt b/core/src/main/kotlin/net/corda/core/utilities/NonEmptySet.kt index 13920ba788..9e364769e2 100644 --- a/core/src/main/kotlin/net/corda/core/utilities/NonEmptySet.kt +++ b/core/src/main/kotlin/net/corda/core/utilities/NonEmptySet.kt @@ -1,117 +1,62 @@ package net.corda.core.utilities -import com.esotericsoftware.kryo.Kryo -import com.esotericsoftware.kryo.Serializer -import com.esotericsoftware.kryo.io.Input -import com.esotericsoftware.kryo.io.Output import java.util.* +import java.util.function.Consumer +import java.util.stream.Stream /** - * A set which is constrained to ensure it can never be empty. An initial value must be provided at - * construction, and attempting to remove the last element will cause an IllegalStateException. - * The underlying set is exposed for Kryo to access, but should not be accessed directly. + * An immutable ordered non-empty set. */ -class NonEmptySet(initial: T) : MutableSet { - private val set: MutableSet = HashSet() +class NonEmptySet private constructor(private val elements: Set) : Set by elements { + companion object { + /** + * Returns a singleton set containing [element]. This behaves the same as [Collections.singleton] but returns a + * [NonEmptySet] for the extra type-safety. + */ + @JvmStatic + fun of(element: T): NonEmptySet = NonEmptySet(Collections.singleton(element)) - init { - set.add(initial) - } + /** Returns a non-empty set containing the given elements, minus duplicates, in the order each was specified. */ + @JvmStatic + fun of(first: T, second: T, vararg rest: T): NonEmptySet { + val elements = LinkedHashSet(rest.size + 2) + elements += first + elements += second + elements.addAll(rest) + return NonEmptySet(elements) + } - override val size: Int - get() = set.size - - override fun add(element: T): Boolean = set.add(element) - override fun addAll(elements: Collection): Boolean = set.addAll(elements) - override fun clear() = throw UnsupportedOperationException() - override fun contains(element: T): Boolean = set.contains(element) - override fun containsAll(elements: Collection): Boolean = set.containsAll(elements) - override fun isEmpty(): Boolean = false - - override fun iterator(): MutableIterator = Iterator(set.iterator()) - - override fun remove(element: T): Boolean = - // Test either there's more than one element, or the removal is a no-op - if (size > 1) - set.remove(element) - else if (!contains(element)) - false - else - throw IllegalStateException() - - override fun removeAll(elements: Collection): Boolean = - if (size > elements.size) - set.removeAll(elements) - else if (!containsAll(elements)) - // Remove the common elements - set.removeAll(elements) - else - throw IllegalStateException() - - override fun retainAll(elements: Collection): Boolean { - val iterator = iterator() - val ret = false - - // The iterator will throw an IllegalStateException if we try removing the last element - while (iterator.hasNext()) { - if (!elements.contains(iterator.next())) { - iterator.remove() + /** + * Returns a non-empty set containing each of [elements], minus duplicates, in the order each appears first in + * the source collection. + * @throws IllegalArgumentException If [elements] is empty. + */ + @JvmStatic + fun copyOf(elements: Collection): NonEmptySet { + if (elements is NonEmptySet) return elements + return when (elements.size) { + 0 -> throw IllegalArgumentException("elements is empty") + 1 -> of(elements.first()) + else -> { + val copy = LinkedHashSet(elements.size) + elements.forEach { copy += it } // Can't use Collection.addAll as it doesn't specify insertion order + NonEmptySet(copy) + } } } - - return ret } - override fun equals(other: Any?): Boolean = - if (other is Set<*>) - // Delegate down to the wrapped set's equals() function - set == other - else - false + /** Returns the first element of the set. */ + fun head(): T = elements.iterator().next() + override fun isEmpty(): Boolean = false + override fun iterator() = object : Iterator by elements.iterator() {} - override fun hashCode(): Int = set.hashCode() - override fun toString(): String = set.toString() - - inner class Iterator(val iterator: MutableIterator) : MutableIterator { - override fun hasNext(): Boolean = iterator.hasNext() - override fun next(): T = iterator.next() - override fun remove() = - if (set.size > 1) - iterator.remove() - else - throw IllegalStateException() - } -} - -fun nonEmptySetOf(initial: T, vararg elements: T): NonEmptySet { - val set = NonEmptySet(initial) - // We add the first element twice, but it's a set, so who cares - set.addAll(elements) - return set -} - -/** - * Custom serializer which understands it has to read in an item before - * trying to construct the set. - */ -object NonEmptySetSerializer : Serializer>() { - override fun write(kryo: Kryo, output: Output, obj: NonEmptySet) { - // Write out the contents as normal - output.writeInt(obj.size) - obj.forEach { kryo.writeClassAndObject(output, it) } - } - - override fun read(kryo: Kryo, input: Input, type: Class>): NonEmptySet { - val size = input.readInt() - require(size >= 1) { "Size is positive" } - // TODO: Is there an upper limit we can apply to how big one of these could be? - val first = kryo.readClassAndObject(input) - // Read the first item and use it to construct the NonEmptySet - val set = NonEmptySet(first) - // Read in the rest of the set - for (i in 2..size) { - set.add(kryo.readClassAndObject(input)) - } - return set - } + // Following methods are not delegated by Kotlin's Class delegation + override fun forEach(action: Consumer) = elements.forEach(action) + override fun stream(): Stream = elements.stream() + override fun parallelStream(): Stream = elements.parallelStream() + override fun spliterator(): Spliterator = elements.spliterator() + override fun equals(other: Any?): Boolean = other === this || other == elements + override fun hashCode(): Int = elements.hashCode() + override fun toString(): String = elements.toString() } diff --git a/core/src/main/kotlin/net/corda/core/utilities/ProgressTracker.kt b/core/src/main/kotlin/net/corda/core/utilities/ProgressTracker.kt index fc788eb10c..6cd9526795 100644 --- a/core/src/main/kotlin/net/corda/core/utilities/ProgressTracker.kt +++ b/core/src/main/kotlin/net/corda/core/utilities/ProgressTracker.kt @@ -1,10 +1,8 @@ package net.corda.core.utilities -import net.corda.core.TransientProperty import net.corda.core.serialization.CordaSerializable import rx.Observable import rx.Subscription -import rx.subjects.BehaviorSubject import rx.subjects.PublishSubject import java.util.* @@ -76,7 +74,7 @@ class ProgressTracker(vararg steps: Step) { val steps = arrayOf(UNSTARTED, *steps, DONE) // This field won't be serialized. - private val _changes by TransientProperty { PublishSubject.create() } + private val _changes by transient { PublishSubject.create() } @CordaSerializable private data class Child(val tracker: ProgressTracker, @Transient val subscription: Subscription?) diff --git a/core/src/main/kotlin/net/corda/flows/FetchAttachmentsFlow.kt b/core/src/main/kotlin/net/corda/flows/FetchAttachmentsFlow.kt deleted file mode 100644 index 805e25da14..0000000000 --- a/core/src/main/kotlin/net/corda/flows/FetchAttachmentsFlow.kt +++ /dev/null @@ -1,40 +0,0 @@ -package net.corda.flows - -import net.corda.core.contracts.AbstractAttachment -import net.corda.core.contracts.Attachment -import net.corda.core.crypto.SecureHash -import net.corda.core.crypto.sha256 -import net.corda.core.flows.InitiatingFlow -import net.corda.core.identity.Party -import net.corda.core.serialization.SerializationToken -import net.corda.core.serialization.SerializeAsToken -import net.corda.core.serialization.SerializeAsTokenContext - -/** - * Given a set of hashes either loads from from local storage or requests them from the other peer. Downloaded - * attachments are saved to local storage automatically. - */ -@InitiatingFlow -class FetchAttachmentsFlow(requests: Set, - otherSide: Party) : FetchDataFlow(requests, otherSide) { - - override fun load(txid: SecureHash): Attachment? = serviceHub.attachments.openAttachment(txid) - - override fun convert(wire: ByteArray): Attachment = FetchedAttachment({ wire }) - - override fun maybeWriteToDisk(downloaded: List) { - for (attachment in downloaded) { - serviceHub.attachments.importAttachment(attachment.open()) - } - } - - private class FetchedAttachment(dataLoader: () -> ByteArray) : AbstractAttachment(dataLoader), SerializeAsToken { - override val id: SecureHash by lazy { attachmentData.sha256() } - - private class Token(private val id: SecureHash) : SerializationToken { - override fun fromToken(context: SerializeAsTokenContext) = FetchedAttachment(context.attachmentDataLoader(id)) - } - - override fun toToken(context: SerializeAsTokenContext) = Token(id) - } -} diff --git a/core/src/main/kotlin/net/corda/flows/FetchDataFlow.kt b/core/src/main/kotlin/net/corda/flows/FetchDataFlow.kt deleted file mode 100644 index efbe720210..0000000000 --- a/core/src/main/kotlin/net/corda/flows/FetchDataFlow.kt +++ /dev/null @@ -1,107 +0,0 @@ -package net.corda.flows - -import co.paralleluniverse.fibers.Suspendable -import net.corda.core.contracts.NamedByHash -import net.corda.core.crypto.SecureHash -import net.corda.core.flows.FlowException -import net.corda.core.flows.FlowLogic -import net.corda.core.identity.Party -import net.corda.core.serialization.CordaSerializable -import net.corda.core.utilities.UntrustworthyData -import net.corda.core.utilities.unwrap -import net.corda.flows.FetchDataFlow.DownloadedVsRequestedDataMismatch -import net.corda.flows.FetchDataFlow.HashNotFound -import java.util.* - -/** - * An abstract flow for fetching typed data from a remote peer. - * - * Given a set of hashes (IDs), either loads them from local disk or asks the remote peer to provide them. - * - * A malicious response in which the data provided by the remote peer does not hash to the requested hash results in - * [DownloadedVsRequestedDataMismatch] being thrown. If the remote peer doesn't have an entry, it results in a - * [HashNotFound] exception being thrown. - * - * By default this class does not insert data into any local database, if you want to do that after missing items were - * fetched then override [maybeWriteToDisk]. You *must* override [load]. If the wire type is not the same as the - * ultimate type, you must also override [convert]. - * - * @param T The ultimate type of the data being fetched. - * @param W The wire type of the data being fetched, for when it isn't the same as the ultimate type. - */ -abstract class FetchDataFlow( - protected val requests: Set, - protected val otherSide: Party) : FlowLogic>() { - - @CordaSerializable - class DownloadedVsRequestedDataMismatch(val requested: SecureHash, val got: SecureHash) : IllegalArgumentException() - - @CordaSerializable - class DownloadedVsRequestedSizeMismatch(val requested: Int, val got: Int) : IllegalArgumentException() - - class HashNotFound(val requested: SecureHash) : FlowException() - - @CordaSerializable - data class Request(val hashes: List) - - @CordaSerializable - data class Result(val fromDisk: List, val downloaded: List) - - @Suspendable - @Throws(HashNotFound::class) - override fun call(): Result { - // Load the items we have from disk and figure out which we're missing. - val (fromDisk, toFetch) = loadWhatWeHave() - - return if (toFetch.isEmpty()) { - Result(fromDisk, emptyList()) - } else { - logger.trace("Requesting ${toFetch.size} dependency(s) for verification") - - // TODO: Support "large message" response streaming so response sizes are not limited by RAM. - val maybeItems = sendAndReceive>(otherSide, Request(toFetch)) - // Check for a buggy/malicious peer answering with something that we didn't ask for. - val downloaded = validateFetchResponse(maybeItems, toFetch) - maybeWriteToDisk(downloaded) - Result(fromDisk, downloaded) - } - } - - protected open fun maybeWriteToDisk(downloaded: List) { - // Do nothing by default. - } - - private fun loadWhatWeHave(): Pair, List> { - val fromDisk = ArrayList() - val toFetch = ArrayList() - for (txid in requests) { - val stx = load(txid) - if (stx == null) - toFetch += txid - else - fromDisk += stx - } - return Pair(fromDisk, toFetch) - } - - protected abstract fun load(txid: SecureHash): T? - - @Suppress("UNCHECKED_CAST") - protected open fun convert(wire: W): T = wire as T - - private fun validateFetchResponse(maybeItems: UntrustworthyData>, - requests: List): List { - return maybeItems.unwrap { response -> - if (response.size != requests.size) - throw DownloadedVsRequestedSizeMismatch(requests.size, response.size) - val answers = response.map { convert(it) } - // Check transactions actually hash to what we requested, if this fails the remote node - // is a malicious flow violator or buggy. - for ((index, item) in answers.withIndex()) { - if (item.id != requests[index]) - throw DownloadedVsRequestedDataMismatch(requests[index], item.id) - } - answers - } - } -} diff --git a/core/src/main/kotlin/net/corda/flows/FetchTransactionsFlow.kt b/core/src/main/kotlin/net/corda/flows/FetchTransactionsFlow.kt deleted file mode 100644 index 0f99aad169..0000000000 --- a/core/src/main/kotlin/net/corda/flows/FetchTransactionsFlow.kt +++ /dev/null @@ -1,21 +0,0 @@ -package net.corda.flows - -import net.corda.core.crypto.SecureHash -import net.corda.core.flows.InitiatingFlow -import net.corda.core.identity.Party -import net.corda.core.transactions.SignedTransaction - -/** - * Given a set of tx hashes (IDs), either loads them from local disk or asks the remote peer to provide them. - * - * A malicious response in which the data provided by the remote peer does not hash to the requested hash results in - * [FetchDataFlow.DownloadedVsRequestedDataMismatch] being thrown. If the remote peer doesn't have an entry, it - * results in a [FetchDataFlow.HashNotFound] exception. Note that returned transactions are not inserted into - * the database, because it's up to the caller to actually verify the transactions are valid. - */ -@InitiatingFlow -class FetchTransactionsFlow(requests: Set, otherSide: Party) : - FetchDataFlow(requests, otherSide) { - - override fun load(txid: SecureHash): SignedTransaction? = serviceHub.validatedTransactions.getTransaction(txid) -} diff --git a/core/src/main/kotlin/net/corda/flows/NotaryChangeFlow.kt b/core/src/main/kotlin/net/corda/flows/NotaryChangeFlow.kt deleted file mode 100644 index ee5453d167..0000000000 --- a/core/src/main/kotlin/net/corda/flows/NotaryChangeFlow.kt +++ /dev/null @@ -1,84 +0,0 @@ -package net.corda.flows - -import net.corda.core.contracts.* -import net.corda.core.flows.InitiatingFlow -import net.corda.core.identity.AbstractParty -import net.corda.core.identity.Party -import net.corda.core.transactions.TransactionBuilder -import net.corda.core.utilities.ProgressTracker - -/** - * A flow to be used for changing a state's Notary. This is required since all input states to a transaction - * must point to the same notary. - * - * This assembles the transaction for notary replacement and sends out change proposals to all participants - * of that state. If participants agree to the proposed change, they each sign the transaction. - * Finally, the transaction containing all signatures is sent back to each participant so they can record it and - * use the new updated state for future transactions. - */ -@InitiatingFlow -class NotaryChangeFlow( - originalState: StateAndRef, - newNotary: Party, - progressTracker: ProgressTracker = tracker()) - : AbstractStateReplacementFlow.Instigator(originalState, newNotary, progressTracker) { - - override fun assembleTx(): AbstractStateReplacementFlow.UpgradeTx { - val state = originalState.state - val tx = TransactionType.NotaryChange.Builder(originalState.state.notary) - - val participants: Iterable = if (state.encumbrance == null) { - val modifiedState = TransactionState(state.data, modification) - tx.addInputState(originalState) - tx.addOutputState(modifiedState) - state.data.participants - } else { - resolveEncumbrances(tx) - } - - val stx = serviceHub.signInitialTransaction(tx) - val participantKeys = participants.map { it.owningKey } - // TODO: We need a much faster way of finding our key in the transaction - val myKey = serviceHub.keyManagementService.filterMyKeys(participantKeys).single() - - return AbstractStateReplacementFlow.UpgradeTx(stx, participantKeys, myKey) - } - - /** - * Adds the notary change state transitions to the [tx] builder for the [originalState] and its encumbrance - * state chain (encumbrance states might be themselves encumbered by other states). - * - * @return union of all added states' participants - */ - private fun resolveEncumbrances(tx: TransactionBuilder): Iterable { - val stateRef = originalState.ref - val txId = stateRef.txhash - val issuingTx = serviceHub.validatedTransactions.getTransaction(txId) - ?: throw StateReplacementException("Transaction $txId not found") - val outputs = issuingTx.tx.outputs - - val participants = mutableSetOf() - - var nextStateIndex = stateRef.index - var newOutputPosition = tx.outputStates().size - while (true) { - val nextState = outputs[nextStateIndex] - tx.addInputState(StateAndRef(nextState, StateRef(txId, nextStateIndex))) - participants.addAll(nextState.data.participants) - - if (nextState.encumbrance == null) { - val modifiedState = TransactionState(nextState.data, modification) - tx.addOutputState(modifiedState) - break - } else { - val modifiedState = TransactionState(nextState.data, modification, newOutputPosition + 1) - tx.addOutputState(modifiedState) - nextStateIndex = nextState.encumbrance - } - - newOutputPosition++ - } - - return participants - } -} diff --git a/core/src/main/resources/net/corda/core/node/isolated.jar b/core/src/main/resources/net/corda/core/node/isolated.jar deleted file mode 100644 index e2db13bf3c..0000000000 Binary files a/core/src/main/resources/net/corda/core/node/isolated.jar and /dev/null differ diff --git a/core/src/test/java/net/corda/core/concurrent/CordaFutureInJavaTest.java b/core/src/test/java/net/corda/core/concurrent/CordaFutureInJavaTest.java new file mode 100644 index 0000000000..2b8497cbc4 --- /dev/null +++ b/core/src/test/java/net/corda/core/concurrent/CordaFutureInJavaTest.java @@ -0,0 +1,100 @@ +package net.corda.core.concurrent; + +import net.corda.core.internal.concurrent.OpenFuture; +import org.junit.Test; + +import java.io.EOFException; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; + +import static net.corda.core.internal.concurrent.CordaFutureImplKt.doneFuture; +import static net.corda.core.internal.concurrent.CordaFutureImplKt.openFuture; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; + +public class CordaFutureInJavaTest { + @Test + public void methodsAreNotTooAwkwardToUse() throws InterruptedException, ExecutionException { + { + CordaFuture f = openFuture(); + f.cancel(false); + assertTrue(f.isCancelled()); + } + { + CordaFuture f = openFuture(); + assertThatThrownBy(() -> f.get(1, TimeUnit.MILLISECONDS)).isInstanceOf(TimeoutException.class); + } + { + CordaFuture f = doneFuture(100); + assertEquals(100, f.get()); + } + { + Future f = doneFuture(100); + assertEquals(Integer.valueOf(100), f.get()); + } + { + OpenFuture f = openFuture(); + OpenFuture g = openFuture(); + f.then(done -> { + try { + return g.set(done.get()); + } catch (InterruptedException | ExecutionException e) { + throw new RuntimeException(e); + } + }); + f.set(100); + assertEquals(100, g.get()); + } + } + + @Test + public void toCompletableFutureWorks() throws InterruptedException, ExecutionException { + { + OpenFuture f = openFuture(); + CompletableFuture g = f.toCompletableFuture(); + f.set(100); + assertEquals(100, g.get()); + } + { + OpenFuture f = openFuture(); + CompletableFuture g = f.toCompletableFuture(); + EOFException e = new EOFException(); + f.setException(e); + assertThatThrownBy(g::get).hasCause(e); + } + { + OpenFuture f = openFuture(); + CompletableFuture g = f.toCompletableFuture(); + f.cancel(false); + assertTrue(g.isCancelled()); + } + } + + @Test + public void toCompletableFutureDoesNotHaveThePowerToAffectTheUnderlyingFuture() { + { + OpenFuture f = openFuture(); + CompletableFuture g = f.toCompletableFuture(); + g.complete(100); + assertFalse(f.isDone()); + } + { + OpenFuture f = openFuture(); + CompletableFuture g = f.toCompletableFuture(); + g.completeExceptionally(new EOFException()); + assertFalse(f.isDone()); + } + { + OpenFuture f = openFuture(); + CompletableFuture g = f.toCompletableFuture(); + g.cancel(false); + // For now let's do the most conservative thing i.e. nothing: + assertFalse(f.isDone()); + } + } +} diff --git a/core/src/test/kotlin/net/corda/core/CollectionExtensionTests.kt b/core/src/test/kotlin/net/corda/core/CollectionExtensionTests.kt deleted file mode 100644 index 1fab9fceaa..0000000000 --- a/core/src/test/kotlin/net/corda/core/CollectionExtensionTests.kt +++ /dev/null @@ -1,40 +0,0 @@ -package net.corda.core - -import org.junit.Test -import kotlin.test.assertEquals -import kotlin.test.assertFailsWith - -class CollectionExtensionTests { - @Test - fun `noneOrSingle returns a single item`() { - val collection = listOf(1) - assertEquals(collection.noneOrSingle(), 1) - assertEquals(collection.noneOrSingle { it == 1 }, 1) - } - - @Test - fun `noneOrSingle returns null if item not found`() { - val collection = emptyList() - assertEquals(collection.noneOrSingle(), null) - } - - @Test - fun `noneOrSingle throws if more than one item found`() { - val collection = listOf(1, 2) - assertFailsWith { collection.noneOrSingle() } - assertFailsWith { collection.noneOrSingle { it > 0 } } - } - - @Test - fun `indexOfOrThrow returns index of the given item`() { - val collection = listOf(1, 2) - assertEquals(collection.indexOfOrThrow(1), 0) - assertEquals(collection.indexOfOrThrow(2), 1) - } - - @Test - fun `indexOfOrThrow throws if the given item is not found`() { - val collection = listOf(1) - assertFailsWith { collection.indexOfOrThrow(2) } - } -} diff --git a/core/src/test/kotlin/net/corda/core/StreamsTest.kt b/core/src/test/kotlin/net/corda/core/StreamsTest.kt deleted file mode 100644 index f9b5ebc1ec..0000000000 --- a/core/src/test/kotlin/net/corda/core/StreamsTest.kt +++ /dev/null @@ -1,42 +0,0 @@ -package net.corda.core - -import org.junit.Assert.assertArrayEquals -import org.junit.Test -import java.util.stream.IntStream -import java.util.stream.Stream -import kotlin.test.assertEquals - -class StreamsTest { - @Test - fun `IntProgression stream works`() { - assertArrayEquals(intArrayOf(1, 2, 3, 4), (1..4).stream().toArray()) - assertArrayEquals(intArrayOf(1, 2, 3, 4), (1 until 5).stream().toArray()) - assertArrayEquals(intArrayOf(1, 3), (1..4 step 2).stream().toArray()) - assertArrayEquals(intArrayOf(1, 3), (1..3 step 2).stream().toArray()) - assertArrayEquals(intArrayOf(), (1..0).stream().toArray()) - assertArrayEquals(intArrayOf(1, 0), (1 downTo 0).stream().toArray()) - assertArrayEquals(intArrayOf(3, 1), (3 downTo 0 step 2).stream().toArray()) - assertArrayEquals(intArrayOf(3, 1), (3 downTo 1 step 2).stream().toArray()) - } - - @Test - fun `IntProgression spliterator characteristics and comparator`() { - val rangeCharacteristics = IntStream.range(0, 2).spliterator().characteristics() - val forward = (0..9 step 3).stream().spliterator() - assertEquals(rangeCharacteristics, forward.characteristics()) - assertEquals(null, forward.comparator) - val reverse = (9 downTo 0 step 3).stream().spliterator() - assertEquals(rangeCharacteristics, reverse.characteristics()) - assertEquals(Comparator.reverseOrder(), reverse.comparator) - } - - @Test - fun `Stream toTypedArray works`() { - val a: Array = Stream.of("one", "two").toTypedArray() - assertEquals(Array::class.java, a.javaClass) - assertArrayEquals(arrayOf("one", "two"), a) - val b: Array = Stream.of("one", "two", null).toTypedArray() - assertEquals(Array::class.java, b.javaClass) - assertArrayEquals(arrayOf("one", "two", null), b) - } -} diff --git a/core/src/test/kotlin/net/corda/core/UtilsTest.kt b/core/src/test/kotlin/net/corda/core/UtilsTest.kt index d785102d07..717356f4d4 100644 --- a/core/src/test/kotlin/net/corda/core/UtilsTest.kt +++ b/core/src/test/kotlin/net/corda/core/UtilsTest.kt @@ -1,18 +1,11 @@ package net.corda.core -import com.google.common.util.concurrent.MoreExecutors -import com.nhaarman.mockito_kotlin.mock -import com.nhaarman.mockito_kotlin.same -import com.nhaarman.mockito_kotlin.verify +import net.corda.core.utilities.getOrThrow import org.assertj.core.api.Assertions.* import org.junit.Test -import org.mockito.ArgumentMatchers.anyString -import org.slf4j.Logger import rx.subjects.PublishSubject import java.util.* import java.util.concurrent.CancellationException -import java.util.concurrent.Executors -import java.util.concurrent.TimeUnit class UtilsTest { @Test @@ -65,17 +58,4 @@ class UtilsTest { future.get() } } - - @Test - fun `andForget works`() { - val log = mock() - val throwable = Exception("Boom") - val executor = MoreExecutors.listeningDecorator(Executors.newSingleThreadExecutor()) - executor.submit { throw throwable }.andForget(log) - executor.shutdown() - while (!executor.awaitTermination(1, TimeUnit.SECONDS)) { - // Do nothing. - } - verify(log).error(anyString(), same(throwable)) - } } diff --git a/core/src/test/kotlin/net/corda/core/concurrent/ConcurrencyUtilsTest.kt b/core/src/test/kotlin/net/corda/core/concurrent/ConcurrencyUtilsTest.kt index 722d67184e..a0b4f62453 100644 --- a/core/src/test/kotlin/net/corda/core/concurrent/ConcurrencyUtilsTest.kt +++ b/core/src/test/kotlin/net/corda/core/concurrent/ConcurrencyUtilsTest.kt @@ -1,21 +1,22 @@ package net.corda.core.concurrent -import com.google.common.util.concurrent.SettableFuture import com.nhaarman.mockito_kotlin.* -import net.corda.core.getOrThrow +import net.corda.core.internal.concurrent.openFuture +import net.corda.core.utilities.getOrThrow import org.assertj.core.api.Assertions.assertThatThrownBy import org.junit.Test import org.slf4j.Logger import java.io.EOFException import java.util.concurrent.CancellationException +import java.util.concurrent.CompletableFuture import kotlin.test.assertEquals import kotlin.test.assertTrue class ConcurrencyUtilsTest { - private val f1 = SettableFuture.create() - private val f2 = SettableFuture.create() + private val f1 = openFuture() + private val f2 = openFuture() private var invocations = 0 - private val log: Logger = mock() + private val log = mock() @Test fun `firstOf short circuit`() { // Order not significant in this case: @@ -31,6 +32,7 @@ class ConcurrencyUtilsTest { f2.setException(throwable) assertEquals(1, invocations) // Least astonishing to skip handler side-effects. verify(log).error(eq(shortCircuitedTaskFailedMessage), same(throwable)) + verifyNoMoreInteractions(log) } @Test @@ -48,20 +50,24 @@ class ConcurrencyUtilsTest { assertTrue(f2.isCancelled) } + /** + * Note that if you set CancellationException on CompletableFuture it will report isCancelled. + */ @Test fun `firstOf re-entrant handler attempt not due to cancel`() { val futures = arrayOf(f1, f2) - val fakeCancel = CancellationException() + val nonCancel = IllegalStateException() val g = firstOf(futures, log) { ++invocations - futures.forEach { it.setException(fakeCancel) } // One handler attempt here. + futures.forEach { it.setException(nonCancel) } // One handler attempt here. it.getOrThrow() } f1.set(100) assertEquals(100, g.getOrThrow()) assertEquals(1, invocations) // Handler didn't run as g was already done. - verify(log).error(eq(shortCircuitedTaskFailedMessage), same(fakeCancel)) - assertThatThrownBy { f2.getOrThrow() }.isSameAs(fakeCancel) + verify(log).error(eq(shortCircuitedTaskFailedMessage), same(nonCancel)) + verifyNoMoreInteractions(log) + assertThatThrownBy { f2.getOrThrow() }.isSameAs(nonCancel) } @Test @@ -75,4 +81,37 @@ class ConcurrencyUtilsTest { assertEquals(1, invocations) verifyNoMoreInteractions(log) } + + @Test + fun `match does not pass failure of success block into the failure block`() { + val f = CompletableFuture.completedFuture(100) + val successes = mutableListOf() + val failures = mutableListOf() + val x = Throwable() + assertThatThrownBy { + f.match({ + successes.add(it) + throw x + }, failures::add) + }.isSameAs(x) + assertEquals(listOf(100), successes) + assertEquals(emptyList(), failures) + } + + @Test + fun `match does not pass ExecutionException to failure block`() { + val e = Throwable() + val f = CompletableFuture().apply { completeExceptionally(e) } + val successes = mutableListOf() + val failures = mutableListOf() + val x = Throwable() + assertThatThrownBy { + f.match(successes::add, { + failures.add(it) + throw x + }) + }.isSameAs(x) + assertEquals(emptyList(), successes) + assertEquals(listOf(e), failures) + } } diff --git a/core/src/test/kotlin/net/corda/core/contracts/LedgerTransactionQueryTests.kt b/core/src/test/kotlin/net/corda/core/contracts/LedgerTransactionQueryTests.kt new file mode 100644 index 0000000000..0c1540d089 --- /dev/null +++ b/core/src/test/kotlin/net/corda/core/contracts/LedgerTransactionQueryTests.kt @@ -0,0 +1,289 @@ +package net.corda.core.contracts + +import net.corda.core.identity.AbstractParty +import net.corda.core.node.ServiceHub +import net.corda.core.transactions.LedgerTransaction +import net.corda.core.transactions.TransactionBuilder +import net.corda.testing.DUMMY_NOTARY +import net.corda.testing.TestDependencyInjectionBase +import net.corda.testing.contracts.DummyContract +import net.corda.testing.node.MockServices +import org.junit.Before +import org.junit.Test +import java.util.function.Predicate +import kotlin.test.assertEquals +import kotlin.test.assertFailsWith +import kotlin.test.assertTrue + +class LedgerTransactionQueryTests : TestDependencyInjectionBase() { + + private lateinit var services: ServiceHub + + @Before + fun setup() { + services = MockServices() + } + + interface Commands { + data class Cmd1(val id: Int) : CommandData, Commands + data class Cmd2(val id: Int) : CommandData, Commands + } + + + private class StringTypeDummyState(val data: String) : ContractState { + override val contract: Contract = DummyContract() + override val participants: List = emptyList() + } + + private class IntTypeDummyState(val data: Int) : ContractState { + override val contract: Contract = DummyContract() + override val participants: List = emptyList() + } + + private fun makeDummyState(data: Any): ContractState { + return when (data) { + is String -> StringTypeDummyState(data) + is Int -> IntTypeDummyState(data) + else -> throw IllegalArgumentException() + } + } + + private fun makeDummyStateAndRef(data: Any): StateAndRef<*> { + val dummyState = makeDummyState(data) + val fakeIssueTx = services.signInitialTransaction(TransactionBuilder(notary = DUMMY_NOTARY).addOutputState(dummyState)) + services.recordTransactions(fakeIssueTx) + val dummyStateRef = StateRef(fakeIssueTx.id, 0) + return StateAndRef(TransactionState(dummyState, DUMMY_NOTARY, null), dummyStateRef) + } + + private fun makeDummyTransaction(): LedgerTransaction { + val tx = TransactionBuilder(notary = DUMMY_NOTARY) + for (i in 0..4) { + tx.addInputState(makeDummyStateAndRef(i)) + tx.addInputState(makeDummyStateAndRef(i.toString())) + tx.addOutputState(makeDummyState(i)) + tx.addOutputState(makeDummyState(i.toString())) + tx.addCommand(Commands.Cmd1(i), listOf(services.myInfo.legalIdentity.owningKey)) + tx.addCommand(Commands.Cmd2(i), listOf(services.myInfo.legalIdentity.owningKey)) + } + return tx.toLedgerTransaction(services) + } + + @Test + fun `Simple InRef Indexer tests`() { + val ltx = makeDummyTransaction() + assertEquals(0, ltx.inRef(0).state.data.data) + assertEquals("0", ltx.inRef(1).state.data.data) + assertEquals(3, ltx.inRef(6).state.data.data) + assertEquals("3", ltx.inRef(7).state.data.data) + assertFailsWith { ltx.inRef(10) } + } + + @Test + fun `Simple OutRef Indexer tests`() { + val ltx = makeDummyTransaction() + assertEquals(0, ltx.outRef(0).state.data.data) + assertEquals("0", ltx.outRef(1).state.data.data) + assertEquals(3, ltx.outRef(6).state.data.data) + assertEquals("3", ltx.outRef(7).state.data.data) + assertFailsWith { ltx.outRef(10) } + } + + @Test + fun `Simple Input Indexer tests`() { + val ltx = makeDummyTransaction() + assertEquals(0, (ltx.getInput(0) as IntTypeDummyState).data) + assertEquals("0", (ltx.getInput(1) as StringTypeDummyState).data) + assertEquals(3, (ltx.getInput(6) as IntTypeDummyState).data) + assertEquals("3", (ltx.getInput(7) as StringTypeDummyState).data) + assertFailsWith { ltx.getInput(10) } + } + + @Test + fun `Simple Output Indexer tests`() { + val ltx = makeDummyTransaction() + assertEquals(0, (ltx.getOutput(0) as IntTypeDummyState).data) + assertEquals("0", (ltx.getOutput(1) as StringTypeDummyState).data) + assertEquals(3, (ltx.getOutput(6) as IntTypeDummyState).data) + assertEquals("3", (ltx.getOutput(7) as StringTypeDummyState).data) + assertFailsWith { ltx.getOutput(10) } + } + + @Test + fun `Simple Command Indexer tests`() { + val ltx = makeDummyTransaction() + assertEquals(0, ltx.getCommand(0).value.id) + assertEquals(0, ltx.getCommand(1).value.id) + assertEquals(3, ltx.getCommand(6).value.id) + assertEquals(3, ltx.getCommand(7).value.id) + assertFailsWith { ltx.getOutput(10) } + } + + @Test + fun `Simple Inputs of type tests`() { + val ltx = makeDummyTransaction() + val intStates = ltx.inputsOfType(IntTypeDummyState::class.java) + assertEquals(5, intStates.size) + assertEquals(listOf(0, 1, 2, 3, 4), intStates.map { it.data }) + val stringStates = ltx.inputsOfType() + assertEquals(5, stringStates.size) + assertEquals(listOf("0", "1", "2", "3", "4"), stringStates.map { it.data }) + val notPresentQuery = ltx.inputsOfType(FungibleAsset::class.java) + assertEquals(emptyList(), notPresentQuery) + } + + @Test + fun `Simple InputsRefs of type tests`() { + val ltx = makeDummyTransaction() + val intStates = ltx.inRefsOfType(IntTypeDummyState::class.java) + assertEquals(5, intStates.size) + assertEquals(listOf(0, 1, 2, 3, 4), intStates.map { it.state.data.data }) + assertEquals(listOf(ltx.inputs[0], ltx.inputs[2], ltx.inputs[4], ltx.inputs[6], ltx.inputs[8]), intStates) + val stringStates = ltx.inRefsOfType() + assertEquals(5, stringStates.size) + assertEquals(listOf("0", "1", "2", "3", "4"), stringStates.map { it.state.data.data }) + assertEquals(listOf(ltx.inputs[1], ltx.inputs[3], ltx.inputs[5], ltx.inputs[7], ltx.inputs[9]), stringStates) + } + + @Test + fun `Simple Outputs of type tests`() { + val ltx = makeDummyTransaction() + val intStates = ltx.outputsOfType(IntTypeDummyState::class.java) + assertEquals(5, intStates.size) + assertEquals(listOf(0, 1, 2, 3, 4), intStates.map { it.data }) + val stringStates = ltx.outputsOfType() + assertEquals(5, stringStates.size) + assertEquals(listOf("0", "1", "2", "3", "4"), stringStates.map { it.data }) + val notPresentQuery = ltx.outputsOfType(FungibleAsset::class.java) + assertEquals(emptyList(), notPresentQuery) + } + + @Test + fun `Simple OutputsRefs of type tests`() { + val ltx = makeDummyTransaction() + val intStates = ltx.outRefsOfType(IntTypeDummyState::class.java) + assertEquals(5, intStates.size) + assertEquals(listOf(0, 1, 2, 3, 4), intStates.map { it.state.data.data }) + assertEquals(listOf(0, 2, 4, 6, 8), intStates.map { it.ref.index }) + assertTrue(intStates.all { it.ref.txhash == ltx.id }) + val stringStates = ltx.outRefsOfType() + assertEquals(5, stringStates.size) + assertEquals(listOf("0", "1", "2", "3", "4"), stringStates.map { it.state.data.data }) + assertEquals(listOf(1, 3, 5, 7, 9), stringStates.map { it.ref.index }) + assertTrue(stringStates.all { it.ref.txhash == ltx.id }) + } + + @Test + fun `Simple Commands of type tests`() { + val ltx = makeDummyTransaction() + val intCmd1 = ltx.commandsOfType(Commands.Cmd1::class.java) + assertEquals(5, intCmd1.size) + assertEquals(listOf(0, 1, 2, 3, 4), intCmd1.map { it.value.id }) + val intCmd2 = ltx.commandsOfType() + assertEquals(5, intCmd2.size) + assertEquals(listOf(0, 1, 2, 3, 4), intCmd2.map { it.value.id }) + val notPresentQuery = ltx.commandsOfType(FungibleAsset.Commands.Exit::class.java) + assertEquals(emptyList(), notPresentQuery) + } + + @Test + fun `Filtered Input Tests`() { + val ltx = makeDummyTransaction() + val intStates = ltx.filterInputs(IntTypeDummyState::class.java, Predicate { it.data.rem(2) == 0 }) + assertEquals(3, intStates.size) + assertEquals(listOf(0, 2, 4), intStates.map { it.data }) + val stringStates: List = ltx.filterInputs { it.data == "3" } + assertEquals("3", stringStates.single().data) + } + + @Test + fun `Filtered InRef Tests`() { + val ltx = makeDummyTransaction() + val intStates = ltx.filterInRefs(IntTypeDummyState::class.java, Predicate { it.data.rem(2) == 0 }) + assertEquals(3, intStates.size) + assertEquals(listOf(0, 2, 4), intStates.map { it.state.data.data }) + assertEquals(listOf(ltx.inputs[0], ltx.inputs[4], ltx.inputs[8]), intStates) + val stringStates: List> = ltx.filterInRefs { it.data == "3" } + assertEquals("3", stringStates.single().state.data.data) + assertEquals(ltx.inputs[7], stringStates.single()) + } + + @Test + fun `Filtered Output Tests`() { + val ltx = makeDummyTransaction() + val intStates = ltx.filterOutputs(IntTypeDummyState::class.java, Predicate { it.data.rem(2) == 0 }) + assertEquals(3, intStates.size) + assertEquals(listOf(0, 2, 4), intStates.map { it.data }) + val stringStates: List = ltx.filterOutputs { it.data == "3" } + assertEquals("3", stringStates.single().data) + } + + @Test + fun `Filtered OutRef Tests`() { + val ltx = makeDummyTransaction() + val intStates = ltx.filterOutRefs(IntTypeDummyState::class.java, Predicate { it.data.rem(2) == 0 }) + assertEquals(3, intStates.size) + assertEquals(listOf(0, 2, 4), intStates.map { it.state.data.data }) + assertEquals(listOf(0, 4, 8), intStates.map { it.ref.index }) + assertTrue(intStates.all { it.ref.txhash == ltx.id }) + val stringStates: List> = ltx.filterOutRefs { it.data == "3" } + assertEquals("3", stringStates.single().state.data.data) + assertEquals(7, stringStates.single().ref.index) + assertEquals(ltx.id, stringStates.single().ref.txhash) + } + + @Test + fun `Filtered Commands Tests`() { + val ltx = makeDummyTransaction() + val intCmds1 = ltx.filterCommands(Commands.Cmd1::class.java, Predicate { it.id.rem(2) == 0 }) + assertEquals(3, intCmds1.size) + assertEquals(listOf(0, 2, 4), intCmds1.map { it.value.id }) + val intCmds2 = ltx.filterCommands { it.id == 3 } + assertEquals(3, intCmds2.single().value.id) + } + + @Test + fun `Find Input Tests`() { + val ltx = makeDummyTransaction() + val intState = ltx.findInput(IntTypeDummyState::class.java, Predicate { it.data == 4 }) + assertEquals(ltx.getInput(8), intState) + val stringState: StringTypeDummyState = ltx.findInput { it.data == "3" } + assertEquals(ltx.getInput(7), stringState) + } + + @Test + fun `Find InRef Tests`() { + val ltx = makeDummyTransaction() + val intState = ltx.findInRef(IntTypeDummyState::class.java, Predicate { it.data == 4 }) + assertEquals(ltx.inRef(8), intState) + val stringState: StateAndRef = ltx.findInRef { it.data == "3" } + assertEquals(ltx.inRef(7), stringState) + } + + @Test + fun `Find Output Tests`() { + val ltx = makeDummyTransaction() + val intState = ltx.findOutput(IntTypeDummyState::class.java, Predicate { it.data == 4 }) + assertEquals(ltx.getOutput(8), intState) + val stringState: StringTypeDummyState = ltx.findOutput { it.data == "3" } + assertEquals(ltx.getOutput(7), stringState) + } + + @Test + fun `Find OutRef Tests`() { + val ltx = makeDummyTransaction() + val intState = ltx.findOutRef(IntTypeDummyState::class.java, Predicate { it.data == 4 }) + assertEquals(ltx.outRef(8), intState) + val stringState: StateAndRef = ltx.findOutRef { it.data == "3" } + assertEquals(ltx.outRef(7), stringState) + } + + @Test + fun `Find Commands Tests`() { + val ltx = makeDummyTransaction() + val intCmd1 = ltx.findCommand(Commands.Cmd1::class.java, Predicate { it.id == 2 }) + assertEquals(ltx.getCommand(4), intCmd1) + val intCmd2 = ltx.findCommand { it.id == 3 } + assertEquals(ltx.getCommand(7), intCmd2) + } +} \ No newline at end of file diff --git a/core/src/test/kotlin/net/corda/core/contracts/TimeWindowTest.kt b/core/src/test/kotlin/net/corda/core/contracts/TimeWindowTest.kt new file mode 100644 index 0000000000..453b01eb65 --- /dev/null +++ b/core/src/test/kotlin/net/corda/core/contracts/TimeWindowTest.kt @@ -0,0 +1,67 @@ +package net.corda.core.contracts + +import net.corda.core.utilities.millis +import net.corda.core.utilities.minutes +import org.assertj.core.api.Assertions.assertThat +import org.junit.Test +import java.time.Instant +import java.time.LocalDate +import java.time.ZoneOffset.UTC + +class TimeWindowTest { + private val now = Instant.now() + + @Test + fun fromOnly() { + val timeWindow = TimeWindow.fromOnly(now) + assertThat(timeWindow.fromTime).isEqualTo(now) + assertThat(timeWindow.untilTime).isNull() + assertThat(timeWindow.midpoint).isNull() + assertThat(timeWindow.contains(now - 1.millis)).isFalse() + assertThat(timeWindow.contains(now)).isTrue() + assertThat(timeWindow.contains(now + 1.millis)).isTrue() + } + + @Test + fun untilOnly() { + val timeWindow = TimeWindow.untilOnly(now) + assertThat(timeWindow.fromTime).isNull() + assertThat(timeWindow.untilTime).isEqualTo(now) + assertThat(timeWindow.midpoint).isNull() + assertThat(timeWindow.contains(now - 1.millis)).isTrue() + assertThat(timeWindow.contains(now)).isFalse() + assertThat(timeWindow.contains(now + 1.millis)).isFalse() + } + + @Test + fun between() { + val today = LocalDate.now() + val fromTime = today.atTime(12, 0).toInstant(UTC) + val untilTime = today.atTime(12, 30).toInstant(UTC) + val timeWindow = TimeWindow.between(fromTime, untilTime) + assertThat(timeWindow.fromTime).isEqualTo(fromTime) + assertThat(timeWindow.untilTime).isEqualTo(untilTime) + assertThat(timeWindow.midpoint).isEqualTo(today.atTime(12, 15).toInstant(UTC)) + assertThat(timeWindow.contains(fromTime - 1.millis)).isFalse() + assertThat(timeWindow.contains(fromTime)).isTrue() + assertThat(timeWindow.contains(fromTime + 1.millis)).isTrue() + assertThat(timeWindow.contains(untilTime)).isFalse() + assertThat(timeWindow.contains(untilTime + 1.millis)).isFalse() + } + + @Test + fun fromStartAndDuration() { + val timeWindow = TimeWindow.fromStartAndDuration(now, 10.minutes) + assertThat(timeWindow.fromTime).isEqualTo(now) + assertThat(timeWindow.untilTime).isEqualTo(now + 10.minutes) + assertThat(timeWindow.midpoint).isEqualTo(now + 5.minutes) + } + + @Test + fun withTolerance() { + val timeWindow = TimeWindow.withTolerance(now, 10.minutes) + assertThat(timeWindow.fromTime).isEqualTo(now - 10.minutes) + assertThat(timeWindow.untilTime).isEqualTo(now + 10.minutes) + assertThat(timeWindow.midpoint).isEqualTo(now) + } +} diff --git a/core/src/test/kotlin/net/corda/core/contracts/TransactionEncumbranceTests.kt b/core/src/test/kotlin/net/corda/core/contracts/TransactionEncumbranceTests.kt index c13ebbae09..2d05c6ea3d 100644 --- a/core/src/test/kotlin/net/corda/core/contracts/TransactionEncumbranceTests.kt +++ b/core/src/test/kotlin/net/corda/core/contracts/TransactionEncumbranceTests.kt @@ -3,10 +3,10 @@ package net.corda.core.contracts import net.corda.contracts.asset.Cash import net.corda.core.crypto.SecureHash import net.corda.core.identity.AbstractParty +import net.corda.core.transactions.LedgerTransaction import net.corda.testing.MEGA_CORP import net.corda.testing.MINI_CORP import net.corda.testing.ledger -import net.corda.testing.transaction import org.junit.Test import java.time.Instant import java.time.temporal.ChronoUnit @@ -28,8 +28,8 @@ class TransactionEncumbranceTests { class DummyTimeLock : Contract { override val legalContractReference = SecureHash.sha256("DummyTimeLock") - override fun verify(tx: TransactionForContract) { - val timeLockInput = tx.inputs.filterIsInstance().singleOrNull() ?: return + override fun verify(tx: LedgerTransaction) { + val timeLockInput = tx.inputsOfType().singleOrNull() ?: return val time = tx.timeWindow?.untilTime ?: throw IllegalArgumentException("Transactions containing time-locks must have a time-window") requireThat { "the time specified in the time-lock has passed" using (time >= timeLockInput.validFrom) @@ -114,22 +114,26 @@ class TransactionEncumbranceTests { @Test fun `state cannot be encumbered by itself`() { - transaction { - input { state } - output(encumbrance = 0) { stateWithNewOwner } - command(MEGA_CORP.owningKey) { Cash.Commands.Move() } - this `fails with` "Missing required encumbrance 0 in OUTPUT" + ledger { + transaction { + input { state } + output(encumbrance = 0) { stateWithNewOwner } + command(MEGA_CORP.owningKey) { Cash.Commands.Move() } + this `fails with` "Missing required encumbrance 0 in OUTPUT" + } } } @Test fun `encumbrance state index must be valid`() { - transaction { - input { state } - output(encumbrance = 2) { stateWithNewOwner } - output { timeLock } - command(MEGA_CORP.owningKey) { Cash.Commands.Move() } - this `fails with` "Missing required encumbrance 2 in OUTPUT" + ledger { + transaction { + input { state } + output(encumbrance = 2) { stateWithNewOwner } + output { timeLock } + command(MEGA_CORP.owningKey) { Cash.Commands.Move() } + this `fails with` "Missing required encumbrance 2 in OUTPUT" + } } } diff --git a/core/src/test/kotlin/net/corda/core/contracts/TransactionGraphSearchTests.kt b/core/src/test/kotlin/net/corda/core/contracts/TransactionGraphSearchTests.kt index 8f55aa7317..2e1603b1b8 100644 --- a/core/src/test/kotlin/net/corda/core/contracts/TransactionGraphSearchTests.kt +++ b/core/src/test/kotlin/net/corda/core/contracts/TransactionGraphSearchTests.kt @@ -1,20 +1,18 @@ package net.corda.core.contracts -import net.corda.testing.contracts.DummyContract -import net.corda.testing.contracts.DummyState import net.corda.core.crypto.newSecureRandom import net.corda.core.transactions.SignedTransaction +import net.corda.core.transactions.TransactionBuilder import net.corda.core.transactions.WireTransaction -import net.corda.testing.DUMMY_NOTARY -import net.corda.testing.DUMMY_NOTARY_KEY -import net.corda.testing.MEGA_CORP_KEY -import net.corda.testing.MEGA_CORP_PUBKEY +import net.corda.testing.* +import net.corda.testing.contracts.DummyContract +import net.corda.testing.contracts.DummyState import net.corda.testing.node.MockServices import net.corda.testing.node.MockTransactionStorage import org.junit.Test import kotlin.test.assertEquals -class TransactionGraphSearchTests { +class TransactionGraphSearchTests : TestDependencyInjectionBase() { class GraphTransactionStorage(val originTx: SignedTransaction, val inputTx: SignedTransaction) : MockTransactionStorage() { init { addTransaction(originTx) @@ -35,14 +33,14 @@ class TransactionGraphSearchTests { val megaCorpServices = MockServices(MEGA_CORP_KEY) val notaryServices = MockServices(DUMMY_NOTARY_KEY) - val originBuilder = TransactionType.General.Builder(DUMMY_NOTARY) + val originBuilder = TransactionBuilder(DUMMY_NOTARY) originBuilder.addOutputState(DummyState(random31BitValue())) originBuilder.addCommand(command, MEGA_CORP_PUBKEY) val originPtx = megaCorpServices.signInitialTransaction(originBuilder) val originTx = notaryServices.addSignature(originPtx) - val inputBuilder = TransactionType.General.Builder(DUMMY_NOTARY) + val inputBuilder = TransactionBuilder(DUMMY_NOTARY) inputBuilder.addInputState(originTx.tx.outRef(0)) val inputPtx = megaCorpServices.signInitialTransaction(inputBuilder) diff --git a/core/src/test/kotlin/net/corda/core/contracts/TransactionTests.kt b/core/src/test/kotlin/net/corda/core/contracts/TransactionTests.kt index d4308446e2..2e471ce5db 100644 --- a/core/src/test/kotlin/net/corda/core/contracts/TransactionTests.kt +++ b/core/src/test/kotlin/net/corda/core/contracts/TransactionTests.kt @@ -1,27 +1,29 @@ package net.corda.core.contracts import net.corda.contracts.asset.DUMMY_CASH_ISSUER_KEY -import net.corda.testing.contracts.DummyContract +import net.corda.core.crypto.* import net.corda.core.crypto.composite.CompositeKey -import net.corda.core.crypto.SecureHash -import net.corda.core.crypto.generateKeyPair -import net.corda.core.crypto.sign import net.corda.core.identity.Party -import net.corda.core.serialization.SerializedBytes import net.corda.core.transactions.LedgerTransaction import net.corda.core.transactions.SignedTransaction import net.corda.core.transactions.WireTransaction import net.corda.testing.* +import net.corda.testing.contracts.DummyContract import org.junit.Test import java.security.KeyPair import kotlin.test.assertEquals import kotlin.test.assertFailsWith +import kotlin.test.assertNotEquals -class TransactionTests { - - private fun makeSigned(wtx: WireTransaction, vararg keys: KeyPair): SignedTransaction { - val bytes: SerializedBytes = wtx.serialized - return SignedTransaction(bytes, keys.map { it.sign(wtx.id.bytes) }) +class TransactionTests : TestDependencyInjectionBase() { + private fun makeSigned(wtx: WireTransaction, vararg keys: KeyPair, notarySig: Boolean = true): SignedTransaction { + val keySigs = keys.map { it.sign(SignableData(wtx.id, SignatureMetadata(1, Crypto.findSignatureScheme(it.public).schemeNumberID))) } + val sigs = if (notarySig) { + keySigs + DUMMY_NOTARY_KEY.sign(SignableData(wtx.id, SignatureMetadata(1, Crypto.findSignatureScheme(DUMMY_NOTARY_KEY.public).schemeNumberID))) + } else { + keySigs + } + return SignedTransaction(wtx, sigs) } @Test @@ -38,26 +40,24 @@ class TransactionTests { inputs = listOf(StateRef(SecureHash.randomSHA256(), 0)), attachments = emptyList(), outputs = emptyList(), - commands = emptyList(), + commands = listOf(dummyCommand(compKey, DUMMY_KEY_1.public, DUMMY_KEY_2.public)), notary = DUMMY_NOTARY, - signers = listOf(compKey, DUMMY_KEY_1.public, DUMMY_KEY_2.public), - type = TransactionType.General, timeWindow = null ) assertEquals( setOf(compKey, DUMMY_KEY_2.public), - assertFailsWith { makeSigned(wtx, DUMMY_KEY_1).verifySignatures() }.missing + assertFailsWith { makeSigned(wtx, DUMMY_KEY_1).verifyRequiredSignatures() }.missing ) assertEquals( setOf(compKey, DUMMY_KEY_2.public), - assertFailsWith { makeSigned(wtx, DUMMY_KEY_1, ak).verifySignatures() }.missing + assertFailsWith { makeSigned(wtx, DUMMY_KEY_1, ak).verifyRequiredSignatures() }.missing ) - makeSigned(wtx, DUMMY_KEY_1, DUMMY_KEY_2, ak, bk).verifySignatures() - makeSigned(wtx, DUMMY_KEY_1, DUMMY_KEY_2, ck).verifySignatures() - makeSigned(wtx, DUMMY_KEY_1, DUMMY_KEY_2, ak, bk, ck).verifySignatures() - makeSigned(wtx, DUMMY_KEY_1, DUMMY_KEY_2, ak).verifySignatures(compKey) - makeSigned(wtx, DUMMY_KEY_1, ak).verifySignatures(compKey, DUMMY_KEY_2.public) // Mixed allowed to be missing. + makeSigned(wtx, DUMMY_KEY_1, DUMMY_KEY_2, ak, bk).verifyRequiredSignatures() + makeSigned(wtx, DUMMY_KEY_1, DUMMY_KEY_2, ck).verifyRequiredSignatures() + makeSigned(wtx, DUMMY_KEY_1, DUMMY_KEY_2, ak, bk, ck).verifyRequiredSignatures() + makeSigned(wtx, DUMMY_KEY_1, DUMMY_KEY_2, ak).verifySignaturesExcept(compKey) + makeSigned(wtx, DUMMY_KEY_1, ak).verifySignaturesExcept(compKey, DUMMY_KEY_2.public) // Mixed allowed to be missing. } @Test @@ -66,31 +66,29 @@ class TransactionTests { inputs = listOf(StateRef(SecureHash.randomSHA256(), 0)), attachments = emptyList(), outputs = emptyList(), - commands = emptyList(), + commands = listOf(dummyCommand(DUMMY_KEY_1.public, DUMMY_KEY_2.public)), notary = DUMMY_NOTARY, - signers = listOf(DUMMY_KEY_1.public, DUMMY_KEY_2.public), - type = TransactionType.General, timeWindow = null ) - assertFailsWith { makeSigned(wtx).verifySignatures() } + assertFailsWith { makeSigned(wtx, notarySig = false).verifyRequiredSignatures() } assertEquals( setOf(DUMMY_KEY_1.public), - assertFailsWith { makeSigned(wtx, DUMMY_KEY_2).verifySignatures() }.missing + assertFailsWith { makeSigned(wtx, DUMMY_KEY_2).verifyRequiredSignatures() }.missing ) assertEquals( setOf(DUMMY_KEY_2.public), - assertFailsWith { makeSigned(wtx, DUMMY_KEY_1).verifySignatures() }.missing + assertFailsWith { makeSigned(wtx, DUMMY_KEY_1).verifyRequiredSignatures() }.missing ) assertEquals( setOf(DUMMY_KEY_2.public), - assertFailsWith { makeSigned(wtx, DUMMY_CASH_ISSUER_KEY).verifySignatures(DUMMY_KEY_1.public) }.missing + assertFailsWith { makeSigned(wtx, DUMMY_CASH_ISSUER_KEY).verifySignaturesExcept(DUMMY_KEY_1.public) }.missing ) - makeSigned(wtx, DUMMY_KEY_1).verifySignatures(DUMMY_KEY_2.public) - makeSigned(wtx, DUMMY_KEY_2).verifySignatures(DUMMY_KEY_1.public) + makeSigned(wtx, DUMMY_KEY_1).verifySignaturesExcept(DUMMY_KEY_2.public) + makeSigned(wtx, DUMMY_KEY_2).verifySignaturesExcept(DUMMY_KEY_1.public) - makeSigned(wtx, DUMMY_KEY_1, DUMMY_KEY_2).verifySignatures() + makeSigned(wtx, DUMMY_KEY_1, DUMMY_KEY_2).verifyRequiredSignatures() } @Test @@ -101,8 +99,8 @@ class TransactionTests { val commands = emptyList>() val attachments = emptyList() val id = SecureHash.randomSHA256() - val signers = listOf(DUMMY_NOTARY_KEY.public) val timeWindow: TimeWindow? = null + val privacySalt: PrivacySalt = PrivacySalt() val transaction: LedgerTransaction = LedgerTransaction( inputs, outputs, @@ -110,39 +108,26 @@ class TransactionTests { attachments, id, null, - signers, timeWindow, - TransactionType.General + privacySalt ) - transaction.type.verify(transaction) + transaction.verify() } @Test - fun `transaction verification fails for duplicate inputs`() { - val baseOutState = TransactionState(DummyContract.SingleOwnerState(0, ALICE), DUMMY_NOTARY) + fun `transaction cannot have duplicate inputs`() { val stateRef = StateRef(SecureHash.randomSHA256(), 0) - val stateAndRef = StateAndRef(baseOutState, stateRef) - val inputs = listOf(stateAndRef, stateAndRef) - val outputs = listOf(baseOutState) - val commands = emptyList>() - val attachments = emptyList() - val id = SecureHash.randomSHA256() - val signers = listOf(DUMMY_NOTARY_KEY.public) - val timeWindow: TimeWindow? = null - val transaction: LedgerTransaction = LedgerTransaction( - inputs, - outputs, - commands, - attachments, - id, - DUMMY_NOTARY, - signers, - timeWindow, - TransactionType.General + fun buildTransaction() = WireTransaction( + inputs = listOf(stateRef, stateRef), + attachments = emptyList(), + outputs = emptyList(), + commands = listOf(dummyCommand(DUMMY_KEY_1.public, DUMMY_KEY_2.public)), + notary = DUMMY_NOTARY, + timeWindow = null ) - assertFailsWith { transaction.type.verify(transaction) } + assertFailsWith { buildTransaction() } } @Test @@ -155,20 +140,38 @@ class TransactionTests { val commands = emptyList>() val attachments = emptyList() val id = SecureHash.randomSHA256() - val signers = listOf(DUMMY_NOTARY_KEY.public) val timeWindow: TimeWindow? = null - val transaction: LedgerTransaction = LedgerTransaction( + val privacySalt: PrivacySalt = PrivacySalt() + fun buildTransaction() = LedgerTransaction( inputs, outputs, commands, attachments, id, notary, - signers, timeWindow, - TransactionType.General + privacySalt ) - assertFailsWith { transaction.type.verify(transaction) } + assertFailsWith { buildTransaction() } + } + + @Test + fun `transactions with identical contents must have different ids`() { + val outputState = TransactionState(DummyContract.SingleOwnerState(0, ALICE), DUMMY_NOTARY) + fun buildTransaction() = WireTransaction( + inputs = emptyList(), + attachments = emptyList(), + outputs = listOf(outputState), + commands = listOf(dummyCommand(DUMMY_KEY_1.public, DUMMY_KEY_2.public)), + notary = null, + timeWindow = null, + privacySalt = PrivacySalt() // Randomly-generated – used for calculating the id + ) + + val issueTx1 = buildTransaction() + val issueTx2 = buildTransaction() + + assertNotEquals(issueTx1.id, issueTx2.id) } } diff --git a/core/src/test/kotlin/net/corda/core/contracts/clauses/AllOfTests.kt b/core/src/test/kotlin/net/corda/core/contracts/clauses/AllOfTests.kt deleted file mode 100644 index 7e2636bcfc..0000000000 --- a/core/src/test/kotlin/net/corda/core/contracts/clauses/AllOfTests.kt +++ /dev/null @@ -1,31 +0,0 @@ -package net.corda.core.contracts.clauses - -import net.corda.core.contracts.AuthenticatedObject -import net.corda.core.contracts.CommandData -import net.corda.core.contracts.TransactionForContract -import net.corda.core.crypto.SecureHash -import org.junit.Test -import java.util.concurrent.atomic.AtomicInteger -import kotlin.test.assertEquals -import kotlin.test.assertFailsWith - -class AllOfTests { - - @Test - fun minimal() { - val counter = AtomicInteger(0) - val clause = AllOf(matchedClause(counter), matchedClause(counter)) - val tx = TransactionForContract(emptyList(), emptyList(), emptyList(), emptyList(), SecureHash.randomSHA256()) - verifyClause(tx, clause, emptyList>()) - - // Check that we've run the verify() function of two clauses - assertEquals(2, counter.get()) - } - - @Test - fun `not all match`() { - val clause = AllOf(matchedClause(), unmatchedClause()) - val tx = TransactionForContract(emptyList(), emptyList(), emptyList(), emptyList(), SecureHash.randomSHA256()) - assertFailsWith { verifyClause(tx, clause, emptyList>()) } - } -} diff --git a/core/src/test/kotlin/net/corda/core/contracts/clauses/AnyOfTests.kt b/core/src/test/kotlin/net/corda/core/contracts/clauses/AnyOfTests.kt deleted file mode 100644 index fa7d6be9a8..0000000000 --- a/core/src/test/kotlin/net/corda/core/contracts/clauses/AnyOfTests.kt +++ /dev/null @@ -1,44 +0,0 @@ -package net.corda.core.contracts.clauses - -import net.corda.core.contracts.AuthenticatedObject -import net.corda.core.contracts.CommandData -import net.corda.core.contracts.TransactionForContract -import net.corda.core.crypto.SecureHash -import org.junit.Test -import java.util.concurrent.atomic.AtomicInteger -import kotlin.test.assertEquals -import kotlin.test.assertFailsWith - -class AnyOfTests { - @Test - fun minimal() { - val counter = AtomicInteger(0) - val clause = AnyOf(matchedClause(counter), matchedClause(counter)) - val tx = TransactionForContract(emptyList(), emptyList(), emptyList(), emptyList(), SecureHash.randomSHA256()) - verifyClause(tx, clause, emptyList>()) - - // Check that we've run the verify() function of two clauses - assertEquals(2, counter.get()) - } - - @Test - fun `not all match`() { - val counter = AtomicInteger(0) - val clause = AnyOf(matchedClause(counter), unmatchedClause(counter)) - val tx = TransactionForContract(emptyList(), emptyList(), emptyList(), emptyList(), SecureHash.randomSHA256()) - verifyClause(tx, clause, emptyList>()) - - // Check that we've run the verify() function of one clause - assertEquals(1, counter.get()) - } - - @Test - fun `none match`() { - val counter = AtomicInteger(0) - val clause = AnyOf(unmatchedClause(counter), unmatchedClause(counter)) - val tx = TransactionForContract(emptyList(), emptyList(), emptyList(), emptyList(), SecureHash.randomSHA256()) - assertFailsWith(IllegalArgumentException::class) { - verifyClause(tx, clause, emptyList>()) - } - } -} diff --git a/core/src/test/kotlin/net/corda/core/contracts/clauses/ClauseTestUtils.kt b/core/src/test/kotlin/net/corda/core/contracts/clauses/ClauseTestUtils.kt deleted file mode 100644 index a21e6d2b08..0000000000 --- a/core/src/test/kotlin/net/corda/core/contracts/clauses/ClauseTestUtils.kt +++ /dev/null @@ -1,30 +0,0 @@ -package net.corda.core.contracts.clauses - -import net.corda.core.contracts.AuthenticatedObject -import net.corda.core.contracts.CommandData -import net.corda.core.contracts.ContractState -import net.corda.core.contracts.TransactionForContract -import java.util.concurrent.atomic.AtomicInteger - -internal fun matchedClause(counter: AtomicInteger? = null) = object : Clause() { - override val requiredCommands: Set> = emptySet() - override fun verify(tx: TransactionForContract, - inputs: List, - outputs: List, - commands: List>, groupingKey: Unit?): Set { - counter?.incrementAndGet() - return emptySet() - } -} - -/** A clause that can never be matched */ -internal fun unmatchedClause(counter: AtomicInteger? = null) = object : Clause() { - override val requiredCommands: Set> = setOf(object : CommandData {}.javaClass) - override fun verify(tx: TransactionForContract, - inputs: List, - outputs: List, - commands: List>, groupingKey: Unit?): Set { - counter?.incrementAndGet() - return emptySet() - } -} diff --git a/core/src/test/kotlin/net/corda/core/contracts/clauses/VerifyClausesTests.kt b/core/src/test/kotlin/net/corda/core/contracts/clauses/VerifyClausesTests.kt deleted file mode 100644 index 4627f1baa2..0000000000 --- a/core/src/test/kotlin/net/corda/core/contracts/clauses/VerifyClausesTests.kt +++ /dev/null @@ -1,42 +0,0 @@ -package net.corda.core.contracts.clauses - -import net.corda.core.contracts.AuthenticatedObject -import net.corda.core.contracts.CommandData -import net.corda.core.contracts.ContractState -import net.corda.core.contracts.TransactionForContract -import net.corda.testing.contracts.DummyContract -import net.corda.core.crypto.SecureHash -import org.junit.Test -import kotlin.test.assertFailsWith - -/** - * Tests for the clause verifier. - */ -class VerifyClausesTests { - /** Very simple check that the function doesn't error when given any clause */ - @Test - fun minimal() { - val clause = object : Clause() { - override fun verify(tx: TransactionForContract, - inputs: List, - outputs: List, - commands: List>, groupingKey: Unit?): Set = emptySet() - } - val tx = TransactionForContract(emptyList(), emptyList(), emptyList(), emptyList(), SecureHash.randomSHA256()) - verifyClause(tx, clause, emptyList>()) - } - - @Test - fun errorSuperfluousCommands() { - val clause = object : Clause() { - override fun verify(tx: TransactionForContract, - inputs: List, - outputs: List, - commands: List>, groupingKey: Unit?): Set = emptySet() - } - val command = AuthenticatedObject(emptyList(), emptyList(), DummyContract.Commands.Create()) - val tx = TransactionForContract(emptyList(), emptyList(), emptyList(), listOf(command), SecureHash.randomSHA256()) - // The clause is matched, but doesn't mark the command as consumed, so this should error - assertFailsWith { verifyClause(tx, clause, listOf(command)) } - } -} diff --git a/core/src/test/kotlin/net/corda/core/crypto/CompositeKeyTests.kt b/core/src/test/kotlin/net/corda/core/crypto/CompositeKeyTests.kt index 5294898acd..b34d2cc976 100644 --- a/core/src/test/kotlin/net/corda/core/crypto/CompositeKeyTests.kt +++ b/core/src/test/kotlin/net/corda/core/crypto/CompositeKeyTests.kt @@ -1,21 +1,26 @@ package net.corda.core.crypto import net.corda.core.crypto.composite.CompositeKey +import net.corda.core.crypto.composite.CompositeKey.NodeAndWeight import net.corda.core.crypto.composite.CompositeSignature import net.corda.core.crypto.composite.CompositeSignaturesWithKeys -import net.corda.core.div +import net.corda.core.internal.declaredField +import net.corda.core.internal.div import net.corda.core.serialization.serialize import net.corda.core.utilities.OpaqueBytes +import net.corda.node.utilities.* +import net.corda.testing.TestDependencyInjectionBase import org.bouncycastle.asn1.x500.X500Name import org.junit.Rule import org.junit.Test import org.junit.rules.TemporaryFolder +import java.security.PublicKey import kotlin.test.assertEquals import kotlin.test.assertFailsWith import kotlin.test.assertFalse import kotlin.test.assertTrue -class CompositeKeyTests { +class CompositeKeyTests : TestDependencyInjectionBase() { @Rule @JvmField val tempFolder: TemporaryFolder = TemporaryFolder() @@ -24,18 +29,21 @@ class CompositeKeyTests { val bobKey = generateKeyPair() val charlieKey = generateKeyPair() - val alicePublicKey = aliceKey.public - val bobPublicKey = bobKey.public - val charliePublicKey = charlieKey.public + val alicePublicKey: PublicKey = aliceKey.public + val bobPublicKey: PublicKey = bobKey.public + val charliePublicKey: PublicKey = charlieKey.public val message = OpaqueBytes("Transaction".toByteArray()) + val secureHash = message.sha256() - val aliceSignature = aliceKey.sign(message) - val bobSignature = bobKey.sign(message) - val charlieSignature = charlieKey.sign(message) + // By lazy is required so that the serialisers are configured before vals initialisation takes place (they internally invoke serialise). + val aliceSignature by lazy { aliceKey.sign(SignableData(secureHash, SignatureMetadata(1, Crypto.findSignatureScheme(alicePublicKey).schemeNumberID))) } + val bobSignature by lazy { bobKey.sign(SignableData(secureHash, SignatureMetadata(1, Crypto.findSignatureScheme(bobPublicKey).schemeNumberID))) } + val charlieSignature by lazy { charlieKey.sign(SignableData(secureHash, SignatureMetadata(1, Crypto.findSignatureScheme(charliePublicKey).schemeNumberID))) } @Test fun `(Alice) fulfilled by Alice signature`() { + println(aliceKey.serialize().hash) assertTrue { alicePublicKey.isFulfilledBy(aliceSignature.by) } assertFalse { alicePublicKey.isFulfilledBy(charlieSignature.by) } } @@ -76,7 +84,7 @@ class CompositeKeyTests { } @Test - fun `kryo encoded tree decodes correctly`() { + fun `encoded tree decodes correctly`() { val aliceAndBob = CompositeKey.Builder().addKeys(alicePublicKey, bobPublicKey).build() val aliceAndBobOrCharlie = CompositeKey.Builder().addKeys(aliceAndBob, charliePublicKey).build(threshold = 1) @@ -146,11 +154,12 @@ class CompositeKeyTests { * Check that verifying a composite signature using the [CompositeSignature] engine works. */ @Test - fun `composite signature verification`() { + fun `composite TransactionSignature verification `() { val twoOfThree = CompositeKey.Builder().addKeys(alicePublicKey, bobPublicKey, charliePublicKey).build(threshold = 2) + val engine = CompositeSignature() engine.initVerify(twoOfThree) - engine.update(message.bytes) + engine.update(secureHash.bytes) assertFalse { engine.verify(CompositeSignaturesWithKeys(listOf(aliceSignature)).serialize().bytes) } assertFalse { engine.verify(CompositeSignaturesWithKeys(listOf(bobSignature)).serialize().bytes) } @@ -161,7 +170,7 @@ class CompositeKeyTests { assertTrue { engine.verify(CompositeSignaturesWithKeys(listOf(aliceSignature, bobSignature, charlieSignature)).serialize().bytes) } // Check the underlying signature is validated - val brokenBobSignature = DigitalSignature.WithKey(bobSignature.by, aliceSignature.bytes) + val brokenBobSignature = TransactionSignature(aliceSignature.bytes, bobSignature.by, SignatureMetadata(1, Crypto.findSignatureScheme(bobSignature.by).schemeNumberID)) assertFalse { engine.verify(CompositeSignaturesWithKeys(listOf(aliceSignature, brokenBobSignature)).serialize().bytes) } } @@ -229,10 +238,7 @@ class CompositeKeyTests { // We will create a graph cycle between key5 and key3. Key5 has already a reference to key3 (via key4). // To create a cycle, we add a reference (child) from key3 to key5. // Children list is immutable, so reflection is used to inject key5 as an extra NodeAndWeight child of key3. - val field = key3.javaClass.getDeclaredField("children") - field.isAccessible = true - val fixedChildren = key3.children.plus(CompositeKey.NodeAndWeight(key5, 1)) - field.set(key3, fixedChildren) + key3.declaredField>("children").value = key3.children + NodeAndWeight(key5, 1) /* A view of the example graph cycle. * @@ -278,19 +284,19 @@ class CompositeKeyTests { @Test fun `CompositeKey from multiple signature schemes and signature verification`() { - val (privRSA, pubRSA) = Crypto.generateKeyPair(Crypto.RSA_SHA256) - val (privK1, pubK1) = Crypto.generateKeyPair(Crypto.ECDSA_SECP256K1_SHA256) - val (privR1, pubR1) = Crypto.generateKeyPair(Crypto.ECDSA_SECP256R1_SHA256) - val (privEd, pubEd) = Crypto.generateKeyPair(Crypto.EDDSA_ED25519_SHA512) - val (privSP, pubSP) = Crypto.generateKeyPair(Crypto.SPHINCS256_SHA256) + val keyPairRSA = Crypto.generateKeyPair(Crypto.RSA_SHA256) + val keyPairK1 = Crypto.generateKeyPair(Crypto.ECDSA_SECP256K1_SHA256) + val keyPairR1 = Crypto.generateKeyPair(Crypto.ECDSA_SECP256R1_SHA256) + val keyPairEd = Crypto.generateKeyPair(Crypto.EDDSA_ED25519_SHA512) + val keyPairSP = Crypto.generateKeyPair(Crypto.SPHINCS256_SHA256) - val RSASignature = privRSA.sign(message.bytes, pubRSA) - val K1Signature = privK1.sign(message.bytes, pubK1) - val R1Signature = privR1.sign(message.bytes, pubR1) - val EdSignature = privEd.sign(message.bytes, pubEd) - val SPSignature = privSP.sign(message.bytes, pubSP) + val RSASignature = keyPairRSA.sign(SignableData(secureHash, SignatureMetadata(1, Crypto.findSignatureScheme(keyPairRSA.public).schemeNumberID))) + val K1Signature = keyPairK1.sign(SignableData(secureHash, SignatureMetadata(1, Crypto.findSignatureScheme(keyPairK1.public).schemeNumberID))) + val R1Signature = keyPairR1.sign(SignableData(secureHash, SignatureMetadata(1, Crypto.findSignatureScheme(keyPairR1.public).schemeNumberID))) + val EdSignature = keyPairEd.sign(SignableData(secureHash, SignatureMetadata(1, Crypto.findSignatureScheme(keyPairEd.public).schemeNumberID))) + val SPSignature = keyPairSP.sign(SignableData(secureHash, SignatureMetadata(1, Crypto.findSignatureScheme(keyPairSP.public).schemeNumberID))) - val compositeKey = CompositeKey.Builder().addKeys(pubRSA, pubK1, pubR1, pubEd, pubSP).build() as CompositeKey + val compositeKey = CompositeKey.Builder().addKeys(keyPairRSA.public, keyPairK1.public, keyPairR1.public, keyPairEd.public, keyPairSP.public).build() as CompositeKey val signatures = listOf(RSASignature, K1Signature, R1Signature, EdSignature, SPSignature) assertTrue { compositeKey.isFulfilledBy(signatures.byKeys()) } @@ -303,19 +309,19 @@ class CompositeKeyTests { @Test fun `Test save to keystore`() { // From test case [CompositeKey from multiple signature schemes and signature verification] - val (privRSA, pubRSA) = Crypto.generateKeyPair(Crypto.RSA_SHA256) - val (privK1, pubK1) = Crypto.generateKeyPair(Crypto.ECDSA_SECP256K1_SHA256) - val (privR1, pubR1) = Crypto.generateKeyPair(Crypto.ECDSA_SECP256R1_SHA256) - val (privEd, pubEd) = Crypto.generateKeyPair(Crypto.EDDSA_ED25519_SHA512) - val (privSP, pubSP) = Crypto.generateKeyPair(Crypto.SPHINCS256_SHA256) + val keyPairRSA = Crypto.generateKeyPair(Crypto.RSA_SHA256) + val keyPairK1 = Crypto.generateKeyPair(Crypto.ECDSA_SECP256K1_SHA256) + val keyPairR1 = Crypto.generateKeyPair(Crypto.ECDSA_SECP256R1_SHA256) + val keyPairEd = Crypto.generateKeyPair(Crypto.EDDSA_ED25519_SHA512) + val keyPairSP = Crypto.generateKeyPair(Crypto.SPHINCS256_SHA256) - val RSASignature = privRSA.sign(message.bytes, pubRSA) - val K1Signature = privK1.sign(message.bytes, pubK1) - val R1Signature = privR1.sign(message.bytes, pubR1) - val EdSignature = privEd.sign(message.bytes, pubEd) - val SPSignature = privSP.sign(message.bytes, pubSP) + val RSASignature = keyPairRSA.sign(SignableData(secureHash, SignatureMetadata(1, Crypto.findSignatureScheme(keyPairRSA.public).schemeNumberID))) + val K1Signature = keyPairK1.sign(SignableData(secureHash, SignatureMetadata(1, Crypto.findSignatureScheme(keyPairK1.public).schemeNumberID))) + val R1Signature = keyPairR1.sign(SignableData(secureHash, SignatureMetadata(1, Crypto.findSignatureScheme(keyPairR1.public).schemeNumberID))) + val EdSignature = keyPairEd.sign(SignableData(secureHash, SignatureMetadata(1, Crypto.findSignatureScheme(keyPairEd.public).schemeNumberID))) + val SPSignature = keyPairSP.sign(SignableData(secureHash, SignatureMetadata(1, Crypto.findSignatureScheme(keyPairSP.public).schemeNumberID))) - val compositeKey = CompositeKey.Builder().addKeys(pubRSA, pubK1, pubR1, pubEd, pubSP).build() as CompositeKey + val compositeKey = CompositeKey.Builder().addKeys(keyPairRSA.public, keyPairK1.public, keyPairR1.public, keyPairEd.public, keyPairSP.public).build() as CompositeKey val signatures = listOf(RSASignature, K1Signature, R1Signature, EdSignature, SPSignature) assertTrue { compositeKey.isFulfilledBy(signatures.byKeys()) } @@ -332,12 +338,12 @@ class CompositeKeyTests { // Store certificate to keystore. val keystorePath = tempFolder.root.toPath() / "keystore.jks" - val keystore = KeyStoreUtilities.loadOrCreateKeyStore(keystorePath, "password") + val keystore = loadOrCreateKeyStore(keystorePath, "password") keystore.setCertificateEntry("CompositeKey", compositeKeyCert.cert) keystore.save(keystorePath, "password") // Load keystore from disk. - val keystore2 = KeyStoreUtilities.loadKeyStore(keystorePath, "password") + val keystore2 = loadKeyStore(keystorePath, "password") assertTrue { keystore2.containsAlias("CompositeKey") } val key = keystore2.getCertificate("CompositeKey").publicKey diff --git a/core/src/test/kotlin/net/corda/core/crypto/PartialMerkleTreeTest.kt b/core/src/test/kotlin/net/corda/core/crypto/PartialMerkleTreeTest.kt index 6b69c32e13..84fcc0f9e8 100644 --- a/core/src/test/kotlin/net/corda/core/crypto/PartialMerkleTreeTest.kt +++ b/core/src/test/kotlin/net/corda/core/crypto/PartialMerkleTreeTest.kt @@ -1,26 +1,29 @@ package net.corda.core.crypto -import com.esotericsoftware.kryo.KryoException import net.corda.contracts.asset.Cash import net.corda.core.contracts.* import net.corda.core.crypto.SecureHash.Companion.zeroHash import net.corda.core.identity.Party -import net.corda.core.serialization.p2PKryo +import net.corda.core.serialization.deserialize import net.corda.core.serialization.serialize import net.corda.core.transactions.WireTransaction -import net.corda.testing.DUMMY_NOTARY -import net.corda.testing.DUMMY_PUBKEY_1 -import net.corda.testing.TEST_TX_TIME import net.corda.testing.* import org.junit.Test import java.security.PublicKey import java.util.function.Predicate import kotlin.test.* -class PartialMerkleTreeTest { +class PartialMerkleTreeTest : TestDependencyInjectionBase() { val nodes = "abcdef" - val hashed = nodes.map { it.serialize().sha256() } + val hashed = nodes.map { + initialiseTestSerialization() + try { + it.serialize().sha256() + } finally { + resetTestSerialization() + } + } val expectedRoot = MerkleTree.getMerkleTree(hashed.toMutableList() + listOf(zeroHash, zeroHash)).hash val merkleTree = MerkleTree.getMerkleTree(hashed) @@ -93,37 +96,42 @@ class PartialMerkleTreeTest { } @Test - fun `building Merkle tree for a transaction`() { + fun `building Merkle tree for a tx and nonce test`() { fun filtering(elem: Any): Boolean { return when (elem) { is StateRef -> true is TransactionState<*> -> elem.data.participants[0].owningKey.keys == MINI_CORP_PUBKEY.keys - is Command -> MEGA_CORP_PUBKEY in elem.signers + is Command<*> -> MEGA_CORP_PUBKEY in elem.signers is TimeWindow -> true is PublicKey -> elem == MEGA_CORP_PUBKEY else -> false } } + val d = testTx.serialize().deserialize() + assertEquals(testTx.id, d.id) + val mt = testTx.buildFilteredTransaction(Predicate(::filtering)) val leaves = mt.filteredLeaves - val d = WireTransaction.deserialize(testTx.serialized) - assertEquals(testTx.id, d.id) - assertEquals(1, leaves.commands.size) - assertEquals(1, leaves.outputs.size) + assertEquals(1, leaves.inputs.size) - assertEquals(1, leaves.mustSign.size) assertEquals(0, leaves.attachments.size) - assertTrue(mt.filteredLeaves.timeWindow != null) - assertEquals(null, mt.filteredLeaves.type) - assertEquals(null, mt.filteredLeaves.notary) + assertEquals(1, leaves.outputs.size) + assertEquals(1, leaves.commands.size) + assertNull(mt.filteredLeaves.notary) + assertNotNull(mt.filteredLeaves.timeWindow) + assertNull(mt.filteredLeaves.privacySalt) + assertEquals(4, leaves.nonces.size) assertTrue(mt.verify()) } @Test fun `same transactions with different notaries have different ids`() { - val wtx1 = makeSimpleCashWtx(DUMMY_NOTARY) - val wtx2 = makeSimpleCashWtx(MEGA_CORP) + // We even use the same privacySalt, and thus the only difference between the two transactions is the notary party. + val privacySalt = PrivacySalt() + val wtx1 = makeSimpleCashWtx(DUMMY_NOTARY, privacySalt) + val wtx2 = makeSimpleCashWtx(MEGA_CORP, privacySalt) + assertEquals(wtx1.privacySalt, wtx2.privacySalt) assertNotEquals(wtx1.id, wtx2.id) } @@ -135,10 +143,22 @@ class PartialMerkleTreeTest { assertTrue(mt.filteredLeaves.inputs.isEmpty()) assertTrue(mt.filteredLeaves.outputs.isEmpty()) assertTrue(mt.filteredLeaves.timeWindow == null) + assertTrue(mt.filteredLeaves.availableComponents.isEmpty()) + assertTrue(mt.filteredLeaves.availableComponentHashes.isEmpty()) + assertTrue(mt.filteredLeaves.nonces.isEmpty()) assertFailsWith { mt.verify() } + + // Including only privacySalt still results to an empty FilteredTransaction. + fun filterPrivacySalt(elem: Any): Boolean = elem is PrivacySalt + val mt2 = testTx.buildFilteredTransaction(Predicate(::filterPrivacySalt)) + assertTrue(mt2.filteredLeaves.privacySalt == null) + assertTrue(mt2.filteredLeaves.availableComponents.isEmpty()) + assertTrue(mt2.filteredLeaves.availableComponentHashes.isEmpty()) + assertTrue(mt2.filteredLeaves.nonces.isEmpty()) + assertFailsWith { mt2.verify() } } - // Partial Merkle Tree building tests + // Partial Merkle Tree building tests. @Test fun `build Partial Merkle Tree, only left nodes branch`() { val inclHashes = listOf(hashed[3], hashed[5]) @@ -212,24 +232,26 @@ class PartialMerkleTreeTest { assertFalse(pmt.verify(wrongRoot, inclHashes)) } - @Test(expected = KryoException::class) + @Test(expected = Exception::class) fun `hash map serialization not allowed`() { val hm1 = hashMapOf("a" to 1, "b" to 2, "c" to 3, "e" to 4) - p2PKryo().run { kryo -> - hm1.serialize(kryo) - } + hm1.serialize() } - private fun makeSimpleCashWtx(notary: Party, timeWindow: TimeWindow? = null, attachments: List = emptyList()): WireTransaction { + private fun makeSimpleCashWtx( + notary: Party, + privacySalt: PrivacySalt = PrivacySalt(), + timeWindow: TimeWindow? = null, + attachments: List = emptyList() + ): WireTransaction { return WireTransaction( inputs = testTx.inputs, attachments = attachments, outputs = testTx.outputs, commands = testTx.commands, notary = notary, - signers = listOf(MEGA_CORP_PUBKEY, DUMMY_PUBKEY_1), - type = TransactionType.General, - timeWindow = timeWindow + timeWindow = timeWindow, + privacySalt = privacySalt ) } } diff --git a/core/src/test/kotlin/net/corda/core/crypto/SignedDataTest.kt b/core/src/test/kotlin/net/corda/core/crypto/SignedDataTest.kt index cb83d847da..c8f35a77a5 100644 --- a/core/src/test/kotlin/net/corda/core/crypto/SignedDataTest.kt +++ b/core/src/test/kotlin/net/corda/core/crypto/SignedDataTest.kt @@ -1,13 +1,21 @@ package net.corda.core.crypto +import net.corda.core.serialization.SerializedBytes import net.corda.core.serialization.serialize +import net.corda.testing.TestDependencyInjectionBase +import org.junit.Before import org.junit.Test import java.security.SignatureException import kotlin.test.assertEquals -class SignedDataTest { +class SignedDataTest : TestDependencyInjectionBase() { + @Before + fun initialise() { + serialized = data.serialize() + } + val data = "Just a simple test string" - val serialized = data.serialize() + lateinit var serialized: SerializedBytes @Test fun `make sure correctly signed data is released`() { diff --git a/core/src/test/kotlin/net/corda/core/crypto/TransactionSignatureTest.kt b/core/src/test/kotlin/net/corda/core/crypto/TransactionSignatureTest.kt index 4aad5b6580..287e7de0c9 100644 --- a/core/src/test/kotlin/net/corda/core/crypto/TransactionSignatureTest.kt +++ b/core/src/test/kotlin/net/corda/core/crypto/TransactionSignatureTest.kt @@ -1,74 +1,41 @@ package net.corda.core.crypto +import net.corda.testing.TestDependencyInjectionBase import org.junit.Test import java.security.SignatureException -import java.time.Instant import kotlin.test.assertTrue /** - * Digital signature MetaData tests + * Digital signature MetaData tests. */ -class TransactionSignatureTest { +class TransactionSignatureTest : TestDependencyInjectionBase() { val testBytes = "12345678901234567890123456789012".toByteArray() - /** valid sign and verify. */ + /** Valid sign and verify. */ @Test - fun `MetaData Full sign and verify`() { + fun `Signature metadata full sign and verify`() { val keyPair = Crypto.generateKeyPair("ECDSA_SECP256K1_SHA256") - // create a MetaData.Full object - val meta = MetaData("ECDSA_SECP256K1_SHA256", "M9", SignatureType.FULL, Instant.now(), null, null, testBytes, keyPair.public) + // Create a SignableData object. + val signableData = SignableData(testBytes.sha256(), SignatureMetadata(1, Crypto.findSignatureScheme(keyPair.public).schemeNumberID)) - // sign the message - val transactionSignature: TransactionSignature = keyPair.private.sign(meta) + // Sign the meta object. + val transactionSignature: TransactionSignature = keyPair.sign(signableData) - // check auto-verification - assertTrue(transactionSignature.verify()) + // Check auto-verification. + assertTrue(transactionSignature.verify(testBytes.sha256())) - // check manual verification - assertTrue(keyPair.public.verify(transactionSignature)) + // Check manual verification. + assertTrue(Crypto.doVerify(testBytes.sha256(), transactionSignature)) } - /** Signing should fail, as I sign with a secpK1 key, but set schemeCodeName is set to secpR1. */ - @Test(expected = IllegalArgumentException::class) - fun `MetaData Full failure wrong scheme`() { + /** Verification should fail; corrupted metadata - clearData (Merkle root) has changed. */ + @Test(expected = SignatureException::class) + fun `Signature metadata full failure clearData has changed`() { val keyPair = Crypto.generateKeyPair("ECDSA_SECP256K1_SHA256") - val meta = MetaData("ECDSA_SECP256R1_SHA256", "M9", SignatureType.FULL, Instant.now(), null, null, testBytes, keyPair.public) - keyPair.private.sign(meta) - } - - /** Verification should fail; corrupted metadata - public key has changed. */ - @Test(expected = SignatureException::class) - fun `MetaData Full failure public key has changed`() { - val keyPair1 = Crypto.generateKeyPair("ECDSA_SECP256K1_SHA256") - val keyPair2 = Crypto.generateKeyPair("ECDSA_SECP256K1_SHA256") - val meta = MetaData("ECDSA_SECP256K1_SHA256", "M9", SignatureType.FULL, Instant.now(), null, null, testBytes, keyPair2.public) - val transactionSignature = keyPair1.private.sign(meta) - transactionSignature.verify() - } - - /** Verification should fail; corrupted metadata - clearData has changed. */ - @Test(expected = SignatureException::class) - fun `MetaData Full failure clearData has changed`() { - val keyPair1 = Crypto.generateKeyPair("ECDSA_SECP256K1_SHA256") - val meta = MetaData("ECDSA_SECP256K1_SHA256", "M9", SignatureType.FULL, Instant.now(), null, null, testBytes, keyPair1.public) - val transactionSignature = keyPair1.private.sign(meta) - - val meta2 = MetaData("ECDSA_SECP256K1_SHA256", "M9", SignatureType.FULL, Instant.now(), null, null, testBytes.plus(testBytes), keyPair1.public) - val transactionSignature2 = TransactionSignature(transactionSignature.signatureData, meta2) - keyPair1.public.verify(transactionSignature2) - } - - /** Verification should fail; corrupted metadata - schemeCodeName has changed from K1 to R1. */ - @Test(expected = SignatureException::class) - fun `MetaData Wrong schemeCodeName has changed`() { - val keyPair1 = Crypto.generateKeyPair("ECDSA_SECP256K1_SHA256") - val meta = MetaData("ECDSA_SECP256K1_SHA256", "M9", SignatureType.FULL, Instant.now(), null, null, testBytes, keyPair1.public) - val transactionSignature = keyPair1.private.sign(meta) - - val meta2 = MetaData("ECDSA_SECP256R1_SHA256", "M9", SignatureType.FULL, Instant.now(), null, null, testBytes.plus(testBytes), keyPair1.public) - val transactionSignature2 = TransactionSignature(transactionSignature.signatureData, meta2) - keyPair1.public.verify(transactionSignature2) + val signableData = SignableData(testBytes.sha256(), SignatureMetadata(1, Crypto.findSignatureScheme(keyPair.public).schemeNumberID)) + val transactionSignature = keyPair.sign(signableData) + Crypto.doVerify((testBytes + testBytes).sha256(), transactionSignature) } } diff --git a/core/src/test/kotlin/net/corda/core/crypto/X509NameConstraintsTest.kt b/core/src/test/kotlin/net/corda/core/crypto/X509NameConstraintsTest.kt index 30292b00e7..8d2589119c 100644 --- a/core/src/test/kotlin/net/corda/core/crypto/X509NameConstraintsTest.kt +++ b/core/src/test/kotlin/net/corda/core/crypto/X509NameConstraintsTest.kt @@ -1,6 +1,7 @@ package net.corda.core.crypto -import net.corda.core.toTypedArray +import net.corda.core.internal.toTypedArray +import net.corda.node.utilities.* import org.bouncycastle.asn1.x500.X500Name import org.bouncycastle.asn1.x509.GeneralName import org.bouncycastle.asn1.x509.GeneralSubtree @@ -17,23 +18,23 @@ class X509NameConstraintsTest { private fun makeKeyStores(subjectName: X500Name, nameConstraints: NameConstraints): Pair { val rootKeys = Crypto.generateKeyPair(X509Utilities.DEFAULT_TLS_SIGNATURE_SCHEME) - val rootCACert = X509Utilities.createSelfSignedCACertificate(X509Utilities.getX509Name("Corda Root CA","London","demo@r3.com",null), rootKeys) + val rootCACert = X509Utilities.createSelfSignedCACertificate(getX509Name("Corda Root CA", "London", "demo@r3.com", null), rootKeys) val intermediateCAKeyPair = Crypto.generateKeyPair(X509Utilities.DEFAULT_TLS_SIGNATURE_SCHEME) - val intermediateCACert = X509Utilities.createCertificate(CertificateType.INTERMEDIATE_CA, rootCACert, rootKeys, X509Utilities.getX509Name("Corda Intermediate CA","London","demo@r3.com",null), intermediateCAKeyPair.public) + val intermediateCACert = X509Utilities.createCertificate(CertificateType.INTERMEDIATE_CA, rootCACert, rootKeys, getX509Name("Corda Intermediate CA", "London", "demo@r3.com", null), intermediateCAKeyPair.public) val clientCAKeyPair = Crypto.generateKeyPair(X509Utilities.DEFAULT_TLS_SIGNATURE_SCHEME) - val clientCACert = X509Utilities.createCertificate(CertificateType.INTERMEDIATE_CA, intermediateCACert, intermediateCAKeyPair, X509Utilities.getX509Name("Corda Client CA","London","demo@r3.com",null), clientCAKeyPair.public, nameConstraints = nameConstraints) + val clientCACert = X509Utilities.createCertificate(CertificateType.INTERMEDIATE_CA, intermediateCACert, intermediateCAKeyPair, getX509Name("Corda Client CA", "London", "demo@r3.com", null), clientCAKeyPair.public, nameConstraints = nameConstraints) val keyPass = "password" - val trustStore = KeyStore.getInstance(KeyStoreUtilities.KEYSTORE_TYPE) + val trustStore = KeyStore.getInstance(KEYSTORE_TYPE) trustStore.load(null, keyPass.toCharArray()) trustStore.addOrReplaceCertificate(X509Utilities.CORDA_ROOT_CA, rootCACert.cert) val tlsKey = Crypto.generateKeyPair(X509Utilities.DEFAULT_TLS_SIGNATURE_SCHEME) val tlsCert = X509Utilities.createCertificate(CertificateType.TLS, clientCACert, clientCAKeyPair, subjectName, tlsKey.public) - val keyStore = KeyStore.getInstance(KeyStoreUtilities.KEYSTORE_TYPE) + val keyStore = KeyStore.getInstance(KEYSTORE_TYPE) keyStore.load(null, keyPass.toCharArray()) keyStore.addOrReplaceKey(X509Utilities.CORDA_CLIENT_TLS, tlsKey.private, keyPass.toCharArray(), Stream.of(tlsCert, clientCACert, intermediateCACert, rootCACert).map { it.cert }.toTypedArray()) diff --git a/node/src/test/kotlin/net/corda/node/messaging/AttachmentTests.kt b/core/src/test/kotlin/net/corda/core/flows/AttachmentTests.kt similarity index 65% rename from node/src/test/kotlin/net/corda/node/messaging/AttachmentTests.kt rename to core/src/test/kotlin/net/corda/core/flows/AttachmentTests.kt index a9fb24ce5b..767c4e045d 100644 --- a/node/src/test/kotlin/net/corda/node/messaging/AttachmentTests.kt +++ b/core/src/test/kotlin/net/corda/core/flows/AttachmentTests.kt @@ -1,28 +1,28 @@ -package net.corda.node.messaging +package net.corda.core.flows +import co.paralleluniverse.fibers.Suspendable import net.corda.core.contracts.Attachment import net.corda.core.crypto.SecureHash import net.corda.core.crypto.sha256 -import net.corda.core.getOrThrow import net.corda.core.messaging.SingleMessageRecipient import net.corda.core.node.services.ServiceInfo -import net.corda.flows.FetchAttachmentsFlow -import net.corda.flows.FetchDataFlow +import net.corda.core.utilities.getOrThrow +import net.corda.core.identity.Party +import net.corda.core.internal.FetchAttachmentsFlow +import net.corda.core.internal.FetchDataFlow import net.corda.node.services.config.NodeConfiguration import net.corda.node.services.database.RequeryConfiguration import net.corda.node.services.network.NetworkMapService import net.corda.node.services.persistence.schemas.requery.AttachmentEntity import net.corda.node.services.transactions.SimpleNotaryService -import net.corda.node.utilities.transaction import net.corda.testing.node.MockNetwork import net.corda.testing.node.makeTestDataSourceProperties -import org.jetbrains.exposed.sql.Database +import net.corda.testing.node.makeTestDatabaseProperties import org.junit.After import org.junit.Before import org.junit.Test import java.io.ByteArrayInputStream import java.io.ByteArrayOutputStream -import java.io.Closeable import java.math.BigInteger import java.security.KeyPair import java.util.jar.JarOutputStream @@ -32,15 +32,13 @@ import kotlin.test.assertFailsWith class AttachmentTests { lateinit var mockNet: MockNetwork - lateinit var dataSource: Closeable - lateinit var database: Database lateinit var configuration: RequeryConfiguration @Before fun setUp() { mockNet = MockNetwork() val dataSourceProperties = makeTestDataSourceProperties() - configuration = RequeryConfiguration(dataSourceProperties) + configuration = RequeryConfiguration(dataSourceProperties, databaseProperties = makeTestDatabaseProperties()) } @After @@ -60,7 +58,11 @@ class AttachmentTests { @Test fun `download and store`() { - val (n0, n1) = mockNet.createTwoNodes() + val nodes = mockNet.createSomeNodes(2) + val n0 = nodes.partyNodes[0] + val n1 = nodes.partyNodes[1] + n0.registerInitiatedFlow(FetchAttachmentsResponse::class.java) + n1.registerInitiatedFlow(FetchAttachmentsResponse::class.java) // Insert an attachment into node zero's store directly. val id = n0.database.transaction { @@ -69,7 +71,7 @@ class AttachmentTests { // Get node one to run a flow to fetch it and insert it. mockNet.runNetwork() - val f1 = n1.services.startFlow(FetchAttachmentsFlow(setOf(id), n0.info.legalIdentity)) + val f1 = n1.startAttachmentFlow(setOf(id), n0.info.legalIdentity) mockNet.runNetwork() assertEquals(0, f1.resultFuture.getOrThrow().fromDisk.size) @@ -83,18 +85,22 @@ class AttachmentTests { // Shut down node zero and ensure node one can still resolve the attachment. n0.stop() - val response: FetchDataFlow.Result = n1.services.startFlow(FetchAttachmentsFlow(setOf(id), n0.info.legalIdentity)).resultFuture.getOrThrow() + val response: FetchDataFlow.Result = n1.startAttachmentFlow(setOf(id), n0.info.legalIdentity).resultFuture.getOrThrow() assertEquals(attachment, response.fromDisk[0]) } @Test fun `missing`() { - val (n0, n1) = mockNet.createTwoNodes() + val nodes = mockNet.createSomeNodes(2) + val n0 = nodes.partyNodes[0] + val n1 = nodes.partyNodes[1] + n0.registerInitiatedFlow(FetchAttachmentsResponse::class.java) + n1.registerInitiatedFlow(FetchAttachmentsResponse::class.java) // Get node one to fetch a non-existent attachment. val hash = SecureHash.randomSHA256() mockNet.runNetwork() - val f1 = n1.services.startFlow(FetchAttachmentsFlow(setOf(hash), n0.info.legalIdentity)) + val f1 = n1.startAttachmentFlow(setOf(hash), n0.info.legalIdentity) mockNet.runNetwork() val e = assertFailsWith { f1.resultFuture.getOrThrow() } assertEquals(hash, e.requested) @@ -103,22 +109,24 @@ class AttachmentTests { @Test fun `malicious response`() { // Make a node that doesn't do sanity checking at load time. - val n0 = mockNet.createNode(null, -1, object : MockNetwork.Factory { + val n0 = mockNet.createNode(nodeFactory = object : MockNetwork.Factory { override fun create(config: NodeConfiguration, network: MockNetwork, networkMapAddr: SingleMessageRecipient?, advertisedServices: Set, id: Int, overrideServices: Map?, entropyRoot: BigInteger): MockNetwork.MockNode { return object : MockNetwork.MockNode(config, network, networkMapAddr, advertisedServices, id, overrideServices, entropyRoot) { - override fun start(): MockNetwork.MockNode { + override fun start() { super.start() attachments.checkAttachmentsOnLoad = false - return this } } } - }, true, null, null, ServiceInfo(NetworkMapService.type), ServiceInfo(SimpleNotaryService.type)) + }, advertisedServices = *arrayOf(ServiceInfo(NetworkMapService.type), ServiceInfo(SimpleNotaryService.type))) val n1 = mockNet.createNode(n0.network.myAddress) + n0.registerInitiatedFlow(FetchAttachmentsResponse::class.java) + n1.registerInitiatedFlow(FetchAttachmentsResponse::class.java) + val attachment = fakeAttachment() // Insert an attachment into node zero's store directly. val id = n0.database.transaction { @@ -136,11 +144,24 @@ class AttachmentTests { n0.attachments.session.update(corruptAttachment) } - // Get n1 to fetch the attachment. Should receive corrupted bytes. mockNet.runNetwork() - val f1 = n1.services.startFlow(FetchAttachmentsFlow(setOf(id), n0.info.legalIdentity)) + val f1 = n1.startAttachmentFlow(setOf(id), n0.info.legalIdentity) mockNet.runNetwork() assertFailsWith { f1.resultFuture.getOrThrow() } } + + private fun MockNetwork.MockNode.startAttachmentFlow(hashes: Set, otherSide: Party) = services.startFlow(InitiatingFetchAttachmentsFlow(otherSide, hashes)) + + @InitiatingFlow + private class InitiatingFetchAttachmentsFlow(val otherSide: Party, val hashes: Set) : FlowLogic>() { + @Suspendable + override fun call(): FetchDataFlow.Result = subFlow(FetchAttachmentsFlow(hashes, otherSide)) + } + + @InitiatedBy(InitiatingFetchAttachmentsFlow::class) + private class FetchAttachmentsResponse(val otherSide: Party) : FlowLogic() { + @Suspendable + override fun call() = subFlow(TestDataVendingFlow(otherSide)) + } } diff --git a/core/src/test/kotlin/net/corda/core/flows/CollectSignaturesFlowTests.kt b/core/src/test/kotlin/net/corda/core/flows/CollectSignaturesFlowTests.kt index 8bccd7666c..769de29fb4 100644 --- a/core/src/test/kotlin/net/corda/core/flows/CollectSignaturesFlowTests.kt +++ b/core/src/test/kotlin/net/corda/core/flows/CollectSignaturesFlowTests.kt @@ -2,16 +2,13 @@ package net.corda.core.flows import co.paralleluniverse.fibers.Suspendable import net.corda.core.contracts.Command -import net.corda.core.contracts.TransactionType import net.corda.core.contracts.requireThat import net.corda.testing.contracts.DummyContract -import net.corda.core.getOrThrow import net.corda.core.identity.Party import net.corda.core.transactions.SignedTransaction +import net.corda.core.utilities.getOrThrow +import net.corda.core.transactions.TransactionBuilder import net.corda.core.utilities.unwrap -import net.corda.flows.CollectSignaturesFlow -import net.corda.flows.FinalityFlow -import net.corda.flows.SignTransactionFlow import net.corda.testing.MINI_CORP_KEY import net.corda.testing.node.MockNetwork import net.corda.testing.node.MockServices @@ -64,9 +61,10 @@ class CollectSignaturesFlowTests { val flow = object : SignTransactionFlow(otherParty) { @Suspendable override fun checkTransaction(stx: SignedTransaction) = requireThat { val tx = stx.tx + val ltx = tx.toLedgerTransaction(serviceHub) "There should only be one output state" using (tx.outputs.size == 1) "There should only be one output state" using (tx.inputs.isEmpty()) - val magicNumberState = tx.outputs.single().data as DummyContract.MultiOwnerState + val magicNumberState = ltx.outputsOfType().single() "Must be 1337 or greater" using (magicNumberState.magicNumber >= 1337) } } @@ -86,7 +84,7 @@ class CollectSignaturesFlowTests { val notary = serviceHub.networkMapCache.notaryNodes.single().notaryIdentity val command = Command(DummyContract.Commands.Create(), state.participants.map { it.owningKey }) - val builder = TransactionType.General.Builder(notary = notary).withItems(state, command) + val builder = TransactionBuilder(notary).withItems(state, command) val ptx = serviceHub.signInitialTransaction(builder) val stx = subFlow(CollectSignaturesFlow(ptx)) val ftx = subFlow(FinalityFlow(stx)).single() @@ -106,7 +104,7 @@ class CollectSignaturesFlowTests { override fun call(): SignedTransaction { val notary = serviceHub.networkMapCache.notaryNodes.single().notaryIdentity val command = Command(DummyContract.Commands.Create(), state.participants.map { it.owningKey }) - val builder = TransactionType.General.Builder(notary = notary).withItems(state, command) + val builder = TransactionBuilder(notary).withItems(state, command) val ptx = serviceHub.signInitialTransaction(builder) val stx = subFlow(CollectSignaturesFlow(ptx)) val ftx = subFlow(FinalityFlow(stx)).single() @@ -121,9 +119,10 @@ class CollectSignaturesFlowTests { val flow = object : SignTransactionFlow(otherParty) { @Suspendable override fun checkTransaction(stx: SignedTransaction) = requireThat { val tx = stx.tx + val ltx = tx.toLedgerTransaction(serviceHub) "There should only be one output state" using (tx.outputs.size == 1) "There should only be one output state" using (tx.inputs.isEmpty()) - val magicNumberState = tx.outputs.single().data as DummyContract.MultiOwnerState + val magicNumberState = ltx.outputsOfType().single() "Must be 1337 or greater" using (magicNumberState.magicNumber >= 1337) } } @@ -144,7 +143,7 @@ class CollectSignaturesFlowTests { val flow = a.services.startFlow(TestFlowTwo.Initiator(state)) mockNet.runNetwork() val result = flow.resultFuture.getOrThrow() - result.verifySignatures() + result.verifyRequiredSignatures() println(result.tx) println(result.sigs) } @@ -156,7 +155,7 @@ class CollectSignaturesFlowTests { val flow = a.services.startFlow(CollectSignaturesFlow(ptx)) mockNet.runNetwork() val result = flow.resultFuture.getOrThrow() - result.verifySignatures() + result.verifyRequiredSignatures() println(result.tx) println(result.sigs) } diff --git a/core/src/test/kotlin/net/corda/core/flows/ContractUpgradeFlowTest.kt b/core/src/test/kotlin/net/corda/core/flows/ContractUpgradeFlowTest.kt index 1ee8601e34..bc05d5be77 100644 --- a/core/src/test/kotlin/net/corda/core/flows/ContractUpgradeFlowTest.kt +++ b/core/src/test/kotlin/net/corda/core/flows/ContractUpgradeFlowTest.kt @@ -3,26 +3,24 @@ package net.corda.core.flows import co.paralleluniverse.fibers.Suspendable import net.corda.contracts.asset.Cash import net.corda.core.contracts.* -import net.corda.testing.contracts.DummyContract -import net.corda.testing.contracts.DummyContractV2 import net.corda.core.crypto.SecureHash -import net.corda.core.getOrThrow import net.corda.core.identity.AbstractParty import net.corda.core.identity.Party import net.corda.core.messaging.CordaRPCOps import net.corda.core.messaging.startFlow -import net.corda.core.node.services.unconsumedStates +import net.corda.core.node.services.queryBy import net.corda.core.utilities.OpaqueBytes +import net.corda.core.transactions.LedgerTransaction import net.corda.core.transactions.SignedTransaction -import net.corda.core.utilities.Emoji +import net.corda.core.internal.Emoji +import net.corda.core.utilities.getOrThrow import net.corda.flows.CashIssueFlow -import net.corda.flows.ContractUpgradeFlow -import net.corda.flows.FinalityFlow import net.corda.node.internal.CordaRPCOpsImpl import net.corda.node.services.startFlowPermission -import net.corda.node.utilities.transaction import net.corda.nodeapi.User import net.corda.testing.RPCDriverExposedDSLInterface +import net.corda.testing.contracts.DummyContract +import net.corda.testing.contracts.DummyContractV2 import net.corda.testing.node.MockNetwork import net.corda.testing.rpcDriver import net.corda.testing.rpcTestUser @@ -74,7 +72,7 @@ class ContractUpgradeFlowTest { // The request is expected to be rejected because party B hasn't authorised the upgrade yet. val rejectedFuture = a.services.startFlow(ContractUpgradeFlow(atx!!.tx.outRef(0), DummyContractV2::class.java)).resultFuture mockNet.runNetwork() - assertFailsWith(FlowSessionException::class) { rejectedFuture.getOrThrow() } + assertFailsWith(UnexpectedFlowEndException::class) { rejectedFuture.getOrThrow() } // Party B authorise the contract state upgrade. b.services.vaultService.authoriseContractUpgrade(btx!!.tx.outRef(0), DummyContractV2::class.java) @@ -118,7 +116,7 @@ class ContractUpgradeFlowTest { @Test fun `2 parties contract upgrade using RPC`() { - rpcDriver { + rpcDriver(initialiseSerialization = false) { // Create dummy contract. val twoPartyDummyContract = DummyContract.generateInitial(0, notary, a.info.legalIdentity.ref(1), b.info.legalIdentity.ref(1)) val signedByA = a.services.signInitialTransaction(twoPartyDummyContract) @@ -144,7 +142,7 @@ class ContractUpgradeFlowTest { DummyContractV2::class.java).returnValue mockNet.runNetwork() - assertFailsWith(FlowSessionException::class) { rejectedFuture.getOrThrow() } + assertFailsWith(UnexpectedFlowEndException::class) { rejectedFuture.getOrThrow() } // Party B authorise the contract state upgrade. rpcB.authoriseContractUpgrade(btx!!.tx.outRef(0), DummyContractV2::class.java) @@ -180,14 +178,14 @@ class ContractUpgradeFlowTest { mockNet.runNetwork() val stx = result.getOrThrow().stx val stateAndRef = stx.tx.outRef(0) - val baseState = a.database.transaction { a.services.vaultService.unconsumedStates().single() } + val baseState = a.database.transaction { a.services.vaultQueryService.queryBy().states.single() } assertTrue(baseState.state.data is Cash.State, "Contract state is old version.") // Starts contract upgrade flow. val upgradeResult = a.services.startFlow(ContractUpgradeFlow(stateAndRef, CashV2::class.java)).resultFuture mockNet.runNetwork() upgradeResult.getOrThrow() // Get contract state from the vault. - val firstState = a.database.transaction { a.services.vaultService.unconsumedStates().single() } + val firstState = a.database.transaction { a.services.vaultQueryService.queryBy().states.single() } assertTrue(firstState.state.data is CashV2.State, "Contract state is upgraded to the new version.") assertEquals(Amount(1000000, USD).`issued by`(a.info.legalIdentity.ref(1)), (firstState.state.data as CashV2.State).amount, "Upgraded cash contain the correct amount.") assertEquals>(listOf(a.info.legalIdentity), (firstState.state.data as CashV2.State).owners, "Upgraded cash belongs to the right owner.") @@ -204,12 +202,12 @@ class ContractUpgradeFlowTest { override fun move(newAmount: Amount>, newOwner: AbstractParty) = copy(amount = amount.copy(newAmount.quantity), owners = listOf(newOwner)) override fun toString() = "${Emoji.bagOfCash}New Cash($amount at ${amount.token.issuer} owned by $owner)" - override fun withNewOwner(newOwner: AbstractParty) = Pair(Cash.Commands.Move(), copy(owners = listOf(newOwner))) + override fun withNewOwner(newOwner: AbstractParty) = CommandAndState(Cash.Commands.Move(), copy(owners = listOf(newOwner))) } override fun upgrade(state: Cash.State) = CashV2.State(state.amount.times(1000), listOf(state.owner)) - override fun verify(tx: TransactionForContract) {} + override fun verify(tx: LedgerTransaction) {} // Dummy Cash contract for testing. override val legalContractReference = SecureHash.sha256("") diff --git a/core/src/test/kotlin/net/corda/core/flows/FinalityFlowTests.kt b/core/src/test/kotlin/net/corda/core/flows/FinalityFlowTests.kt new file mode 100644 index 0000000000..b280db5dff --- /dev/null +++ b/core/src/test/kotlin/net/corda/core/flows/FinalityFlowTests.kt @@ -0,0 +1,55 @@ +package net.corda.core.flows + +import net.corda.contracts.asset.Cash +import net.corda.core.contracts.Amount +import net.corda.core.contracts.GBP +import net.corda.core.contracts.Issued +import net.corda.core.identity.Party +import net.corda.core.transactions.TransactionBuilder +import net.corda.core.utilities.getOrThrow +import net.corda.testing.node.MockNetwork +import net.corda.testing.node.MockServices +import org.junit.After +import org.junit.Before +import org.junit.Test +import kotlin.test.assertEquals + +class FinalityFlowTests { + lateinit var mockNet: MockNetwork + lateinit var nodeA: MockNetwork.MockNode + lateinit var nodeB: MockNetwork.MockNode + lateinit var notary: Party + val services = MockServices() + + @Before + fun setup() { + mockNet = MockNetwork() + val nodes = mockNet.createSomeNodes(2) + nodeA = nodes.partyNodes[0] + nodeB = nodes.partyNodes[1] + notary = nodes.notaryNode.info.notaryIdentity + mockNet.runNetwork() + } + + @After + fun tearDown() { + mockNet.stopNodes() + } + + @Test + fun `finalise a simple transaction`() { + val amount = Amount(1000, Issued(nodeA.info.legalIdentity.ref(0), GBP)) + val builder = TransactionBuilder(notary) + Cash().generateIssue(builder, amount, nodeB.info.legalIdentity, notary) + val stx = nodeA.services.signInitialTransaction(builder) + val flow = nodeA.services.startFlow(FinalityFlow(stx)) + mockNet.runNetwork() + val result = flow.resultFuture.getOrThrow() + val notarisedTx = result.single() + notarisedTx.verifyRequiredSignatures() + val transactionSeenByB = nodeB.services.database.transaction { + nodeB.services.validatedTransactions.getTransaction(notarisedTx.id) + } + assertEquals(notarisedTx, transactionSeenByB) + } +} \ No newline at end of file diff --git a/core/src/test/kotlin/net/corda/core/flows/ManualFinalityFlowTests.kt b/core/src/test/kotlin/net/corda/core/flows/ManualFinalityFlowTests.kt new file mode 100644 index 0000000000..01e81137f2 --- /dev/null +++ b/core/src/test/kotlin/net/corda/core/flows/ManualFinalityFlowTests.kt @@ -0,0 +1,63 @@ +package net.corda.core.flows + +import net.corda.contracts.asset.Cash +import net.corda.core.contracts.Amount +import net.corda.core.contracts.GBP +import net.corda.core.contracts.Issued +import net.corda.core.identity.Party +import net.corda.core.transactions.TransactionBuilder +import net.corda.core.utilities.getOrThrow +import net.corda.testing.node.MockNetwork +import net.corda.testing.node.MockServices +import org.junit.After +import org.junit.Before +import org.junit.Test +import kotlin.test.assertEquals +import kotlin.test.assertNull + +class ManualFinalityFlowTests { + lateinit var mockNet: MockNetwork + lateinit var nodeA: MockNetwork.MockNode + lateinit var nodeB: MockNetwork.MockNode + lateinit var nodeC: MockNetwork.MockNode + lateinit var notary: Party + val services = MockServices() + + @Before + fun setup() { + mockNet = MockNetwork() + val nodes = mockNet.createSomeNodes(3) + nodeA = nodes.partyNodes[0] + nodeB = nodes.partyNodes[1] + nodeC = nodes.partyNodes[2] + notary = nodes.notaryNode.info.notaryIdentity + mockNet.runNetwork() + } + + @After + fun tearDown() { + mockNet.stopNodes() + } + + @Test + fun `finalise a simple transaction`() { + val amount = Amount(1000, Issued(nodeA.info.legalIdentity.ref(0), GBP)) + val builder = TransactionBuilder(notary) + Cash().generateIssue(builder, amount, nodeB.info.legalIdentity, notary) + val stx = nodeA.services.signInitialTransaction(builder) + val flow = nodeA.services.startFlow(ManualFinalityFlow(stx, setOf(nodeC.info.legalIdentity))) + mockNet.runNetwork() + val result = flow.resultFuture.getOrThrow() + val notarisedTx = result.single() + notarisedTx.verifyRequiredSignatures() + // We override the participants, so node C will get a copy despite not being involved, and B won't + val transactionSeenByB = nodeB.services.database.transaction { + nodeB.services.validatedTransactions.getTransaction(notarisedTx.id) + } + assertNull(transactionSeenByB) + val transactionSeenByC = nodeC.services.database.transaction { + nodeC.services.validatedTransactions.getTransaction(notarisedTx.id) + } + assertEquals(notarisedTx, transactionSeenByC) + } +} \ No newline at end of file diff --git a/core/src/test/kotlin/net/corda/core/flows/TestDataVendingFlow.kt b/core/src/test/kotlin/net/corda/core/flows/TestDataVendingFlow.kt new file mode 100644 index 0000000000..5c53890038 --- /dev/null +++ b/core/src/test/kotlin/net/corda/core/flows/TestDataVendingFlow.kt @@ -0,0 +1,19 @@ +package net.corda.core.flows + +import co.paralleluniverse.fibers.Suspendable +import net.corda.core.identity.Party +import net.corda.core.internal.FetchDataFlow +import net.corda.core.utilities.UntrustworthyData + +// Flow to start data vending without sending transaction. For testing only. +class TestDataVendingFlow(otherSide: Party) : SendStateAndRefFlow(otherSide, emptyList()) { + @Suspendable + override fun sendPayloadAndReceiveDataRequest(otherSide: Party, payload: Any): UntrustworthyData { + return if (payload is List<*> && payload.isEmpty()) { + // Hack to not send the first message. + receive(otherSide) + } else { + super.sendPayloadAndReceiveDataRequest(otherSide, payload) + } + } +} \ No newline at end of file diff --git a/core/src/test/kotlin/net/corda/flows/TransactionKeyFlowTests.kt b/core/src/test/kotlin/net/corda/core/flows/TransactionKeyFlowTests.kt similarity index 54% rename from core/src/test/kotlin/net/corda/flows/TransactionKeyFlowTests.kt rename to core/src/test/kotlin/net/corda/core/flows/TransactionKeyFlowTests.kt index 486e24b3e8..83defd2fa2 100644 --- a/core/src/test/kotlin/net/corda/flows/TransactionKeyFlowTests.kt +++ b/core/src/test/kotlin/net/corda/core/flows/TransactionKeyFlowTests.kt @@ -1,14 +1,13 @@ -package net.corda.flows +package net.corda.core.flows -import net.corda.core.getOrThrow import net.corda.core.identity.AbstractParty +import net.corda.core.identity.AnonymousParty import net.corda.core.identity.Party +import net.corda.core.utilities.getOrThrow import net.corda.testing.ALICE import net.corda.testing.BOB import net.corda.testing.DUMMY_NOTARY import net.corda.testing.node.MockNetwork -import org.junit.After -import org.junit.Before import org.junit.Test import kotlin.test.assertEquals import kotlin.test.assertFalse @@ -16,22 +15,10 @@ import kotlin.test.assertNotEquals import kotlin.test.assertTrue class TransactionKeyFlowTests { - lateinit var mockNet: MockNetwork - - @Before - fun before() { - mockNet = MockNetwork(false) - } - - @After - fun cleanUp() { - mockNet.stopNodes() - } - @Test fun `issue key`() { // We run this in parallel threads to help catch any race conditions that may exist. - mockNet = MockNetwork(false, true) + val mockNet = MockNetwork(false, true) // Set up values we'll need val notaryNode = mockNet.createNotaryNode(null, DUMMY_NOTARY.name) @@ -39,31 +26,33 @@ class TransactionKeyFlowTests { val bobNode = mockNet.createPartyNode(notaryNode.network.myAddress, BOB.name) val alice: Party = aliceNode.services.myInfo.legalIdentity val bob: Party = bobNode.services.myInfo.legalIdentity - aliceNode.services.identityService.registerIdentity(bobNode.info.legalIdentityAndCert) - aliceNode.services.identityService.registerIdentity(notaryNode.info.legalIdentityAndCert) - bobNode.services.identityService.registerIdentity(aliceNode.info.legalIdentityAndCert) - bobNode.services.identityService.registerIdentity(notaryNode.info.legalIdentityAndCert) + aliceNode.services.identityService.verifyAndRegisterIdentity(bobNode.info.legalIdentityAndCert) + aliceNode.services.identityService.verifyAndRegisterIdentity(notaryNode.info.legalIdentityAndCert) + bobNode.services.identityService.verifyAndRegisterIdentity(aliceNode.info.legalIdentityAndCert) + bobNode.services.identityService.verifyAndRegisterIdentity(notaryNode.info.legalIdentityAndCert) // Run the flows val requesterFlow = aliceNode.services.startFlow(TransactionKeyFlow(bob)) // Get the results - val actual: Map = requesterFlow.resultFuture.getOrThrow().toMap() + val actual: Map = requesterFlow.resultFuture.getOrThrow().toMap() assertEquals(2, actual.size) // Verify that the generated anonymous identities do not match the well known identities val aliceAnonymousIdentity = actual[alice] ?: throw IllegalStateException() val bobAnonymousIdentity = actual[bob] ?: throw IllegalStateException() - assertNotEquals(alice, aliceAnonymousIdentity.identity) - assertNotEquals(bob, bobAnonymousIdentity.identity) + assertNotEquals(alice, aliceAnonymousIdentity) + assertNotEquals(bob, bobAnonymousIdentity) // Verify that the anonymous identities look sane - assertEquals(alice.name, aliceAnonymousIdentity.certificate.subject) - assertEquals(bob.name, bobAnonymousIdentity.certificate.subject) + assertEquals(alice.name, aliceNode.services.identityService.partyFromAnonymous(aliceAnonymousIdentity)!!.name) + assertEquals(bob.name, bobNode.services.identityService.partyFromAnonymous(bobAnonymousIdentity)!!.name) // Verify that the nodes have the right anonymous identities - assertTrue { aliceAnonymousIdentity.identity.owningKey in aliceNode.services.keyManagementService.keys } - assertTrue { bobAnonymousIdentity.identity.owningKey in bobNode.services.keyManagementService.keys } - assertFalse { aliceAnonymousIdentity.identity.owningKey in bobNode.services.keyManagementService.keys } - assertFalse { bobAnonymousIdentity.identity.owningKey in aliceNode.services.keyManagementService.keys } + assertTrue { aliceAnonymousIdentity.owningKey in aliceNode.services.keyManagementService.keys } + assertTrue { bobAnonymousIdentity.owningKey in bobNode.services.keyManagementService.keys } + assertFalse { aliceAnonymousIdentity.owningKey in bobNode.services.keyManagementService.keys } + assertFalse { bobAnonymousIdentity.owningKey in aliceNode.services.keyManagementService.keys } + + mockNet.stopNodes() } } diff --git a/core/src/test/kotlin/net/corda/core/internal/InternalUtilsTest.kt b/core/src/test/kotlin/net/corda/core/internal/InternalUtilsTest.kt new file mode 100644 index 0000000000..b2f31384db --- /dev/null +++ b/core/src/test/kotlin/net/corda/core/internal/InternalUtilsTest.kt @@ -0,0 +1,90 @@ +package net.corda.core.internal + +import org.assertj.core.api.Assertions +import org.junit.Assert.assertArrayEquals +import org.junit.Test +import java.util.stream.IntStream +import java.util.stream.Stream +import kotlin.test.assertEquals +import kotlin.test.assertFailsWith + +class InternalUtilsTest { + @Test + fun `noneOrSingle on an empty collection`() { + val collection = emptyList() + Assertions.assertThat(collection.noneOrSingle()).isNull() + Assertions.assertThat(collection.noneOrSingle { it == 1 }).isNull() + } + + @Test + fun `noneOrSingle on a singleton collection`() { + val collection = listOf(1) + Assertions.assertThat(collection.noneOrSingle()).isEqualTo(1) + Assertions.assertThat(collection.noneOrSingle { it == 1 }).isEqualTo(1) + Assertions.assertThat(collection.noneOrSingle { it == 2 }).isNull() + } + + @Test + fun `noneOrSingle on a collection with two items`() { + val collection = listOf(1, 2) + assertFailsWith { collection.noneOrSingle() } + Assertions.assertThat(collection.noneOrSingle { it == 1 }).isEqualTo(1) + Assertions.assertThat(collection.noneOrSingle { it == 2 }).isEqualTo(2) + Assertions.assertThat(collection.noneOrSingle { it == 3 }).isNull() + assertFailsWith { collection.noneOrSingle { it > 0 } } + } + + @Test + fun `noneOrSingle on a collection with items 1, 2, 1`() { + val collection = listOf(1, 2, 1) + assertFailsWith { collection.noneOrSingle() } + assertFailsWith { collection.noneOrSingle { it == 1 } } + Assertions.assertThat(collection.noneOrSingle { it == 2 }).isEqualTo(2) + } + + @Test + fun `indexOfOrThrow returns index of the given item`() { + val collection = listOf(1, 2) + assertEquals(collection.indexOfOrThrow(1), 0) + assertEquals(collection.indexOfOrThrow(2), 1) + } + + @Test + fun `indexOfOrThrow throws if the given item is not found`() { + val collection = listOf(1) + assertFailsWith { collection.indexOfOrThrow(2) } + } + + @Test + fun `IntProgression stream works`() { + assertArrayEquals(intArrayOf(1, 2, 3, 4), (1..4).stream().toArray()) + assertArrayEquals(intArrayOf(1, 2, 3, 4), (1 until 5).stream().toArray()) + assertArrayEquals(intArrayOf(1, 3), (1..4 step 2).stream().toArray()) + assertArrayEquals(intArrayOf(1, 3), (1..3 step 2).stream().toArray()) + assertArrayEquals(intArrayOf(), (1..0).stream().toArray()) + assertArrayEquals(intArrayOf(1, 0), (1 downTo 0).stream().toArray()) + assertArrayEquals(intArrayOf(3, 1), (3 downTo 0 step 2).stream().toArray()) + assertArrayEquals(intArrayOf(3, 1), (3 downTo 1 step 2).stream().toArray()) + } + + @Test + fun `IntProgression spliterator characteristics and comparator`() { + val rangeCharacteristics = IntStream.range(0, 2).spliterator().characteristics() + val forward = (0..9 step 3).stream().spliterator() + assertEquals(rangeCharacteristics, forward.characteristics()) + assertEquals(null, forward.comparator) + val reverse = (9 downTo 0 step 3).stream().spliterator() + assertEquals(rangeCharacteristics, reverse.characteristics()) + assertEquals(Comparator.reverseOrder(), reverse.comparator) + } + + @Test + fun `Stream toTypedArray works`() { + val a: Array = Stream.of("one", "two").toTypedArray() + assertEquals(Array::class.java, a.javaClass) + assertArrayEquals(arrayOf("one", "two"), a) + val b: Array = Stream.of("one", "two", null).toTypedArray() + assertEquals(Array::class.java, b.javaClass) + assertArrayEquals(arrayOf("one", "two", null), b) + } +} diff --git a/core/src/test/kotlin/net/corda/core/flows/ResolveTransactionsFlowTest.kt b/core/src/test/kotlin/net/corda/core/internal/ResolveTransactionsFlowTest.kt similarity index 76% rename from core/src/test/kotlin/net/corda/core/flows/ResolveTransactionsFlowTest.kt rename to core/src/test/kotlin/net/corda/core/internal/ResolveTransactionsFlowTest.kt index 341f9b0d5d..6ead35bc90 100644 --- a/core/src/test/kotlin/net/corda/core/flows/ResolveTransactionsFlowTest.kt +++ b/core/src/test/kotlin/net/corda/core/internal/ResolveTransactionsFlowTest.kt @@ -1,17 +1,20 @@ -package net.corda.core.flows +package net.corda.core.internal -import net.corda.testing.contracts.DummyContract +import co.paralleluniverse.fibers.Suspendable import net.corda.core.crypto.SecureHash -import net.corda.core.getOrThrow +import net.corda.core.flows.FlowLogic +import net.corda.core.flows.InitiatedBy +import net.corda.core.flows.InitiatingFlow +import net.corda.core.flows.TestDataVendingFlow import net.corda.core.identity.Party -import net.corda.core.utilities.opaque import net.corda.core.transactions.SignedTransaction +import net.corda.core.utilities.getOrThrow +import net.corda.core.utilities.sequence import net.corda.testing.DUMMY_NOTARY_KEY -import net.corda.flows.ResolveTransactionsFlow -import net.corda.node.utilities.transaction import net.corda.testing.MEGA_CORP import net.corda.testing.MEGA_CORP_KEY import net.corda.testing.MINI_CORP +import net.corda.testing.contracts.DummyContract import net.corda.testing.node.MockNetwork import net.corda.testing.node.MockServices import org.junit.After @@ -19,7 +22,6 @@ import org.junit.Before import org.junit.Test import java.io.ByteArrayOutputStream import java.io.InputStream -import java.security.SignatureException import java.util.jar.JarEntry import java.util.jar.JarOutputStream import kotlin.test.assertEquals @@ -41,6 +43,8 @@ class ResolveTransactionsFlowTest { val nodes = mockNet.createSomeNodes() a = nodes.partyNodes[0] b = nodes.partyNodes[1] + a.registerInitiatedFlow(TestResponseFlow::class.java) + b.registerInitiatedFlow(TestResponseFlow::class.java) notary = nodes.notaryNode.info.notaryIdentity mockNet.runNetwork() } @@ -54,7 +58,7 @@ class ResolveTransactionsFlowTest { @Test fun `resolve from two hashes`() { val (stx1, stx2) = makeTransactions() - val p = ResolveTransactionsFlow(setOf(stx2.id), a.info.legalIdentity) + val p = TestFlow(setOf(stx2.id), a.info.legalIdentity) val future = b.services.startFlow(p).resultFuture mockNet.runNetwork() val results = future.getOrThrow() @@ -69,16 +73,16 @@ class ResolveTransactionsFlowTest { @Test fun `dependency with an error`() { val stx = makeTransactions(signFirstTX = false).second - val p = ResolveTransactionsFlow(setOf(stx.id), a.info.legalIdentity) + val p = TestFlow(setOf(stx.id), a.info.legalIdentity) val future = b.services.startFlow(p).resultFuture mockNet.runNetwork() - assertFailsWith(SignatureException::class) { future.getOrThrow() } + assertFailsWith(SignedTransaction.SignaturesMissingException::class) { future.getOrThrow() } } @Test fun `resolve from a signed transaction`() { val (stx1, stx2) = makeTransactions() - val p = ResolveTransactionsFlow(stx2, a.info.legalIdentity) + val p = TestFlow(stx2, a.info.legalIdentity) val future = b.services.startFlow(p).resultFuture mockNet.runNetwork() future.getOrThrow() @@ -103,8 +107,7 @@ class ResolveTransactionsFlowTest { } cursor = stx } - val p = ResolveTransactionsFlow(setOf(cursor.id), a.info.legalIdentity) - p.transactionCountLimit = 40 + val p = TestFlow(setOf(cursor.id), a.info.legalIdentity, 40) val future = b.services.startFlow(p).resultFuture mockNet.runNetwork() assertFailsWith { future.getOrThrow() } @@ -128,7 +131,7 @@ class ResolveTransactionsFlowTest { a.services.recordTransactions(stx2, stx3) } - val p = ResolveTransactionsFlow(setOf(stx3.id), a.info.legalIdentity) + val p = TestFlow(setOf(stx3.id), a.info.legalIdentity) val future = b.services.startFlow(p).resultFuture mockNet.runNetwork() future.getOrThrow() @@ -143,14 +146,14 @@ class ResolveTransactionsFlowTest { jar.write("Some test file".toByteArray()) jar.closeEntry() jar.close() - return bs.toByteArray().opaque().open() + return bs.toByteArray().sequence().open() } // TODO: this operation should not require an explicit transaction val id = a.database.transaction { a.services.attachments.importAttachment(makeJar()) } val stx2 = makeTransactions(withAttachment = id).second - val p = ResolveTransactionsFlow(stx2, a.info.legalIdentity) + val p = TestFlow(stx2, a.info.legalIdentity) val future = b.services.startFlow(p).resultFuture mockNet.runNetwork() future.getOrThrow() @@ -172,7 +175,7 @@ class ResolveTransactionsFlowTest { val ptx = megaCorpServices.signInitialTransaction(it) notaryServices.addSignature(ptx) } - false -> { + false -> { notaryServices.signInitialTransaction(it) } } @@ -187,4 +190,22 @@ class ResolveTransactionsFlowTest { return Pair(dummy1, dummy2) } // DOCEND 2 + + @InitiatingFlow + private class TestFlow(private val resolveTransactionsFlow: ResolveTransactionsFlow, private val txCountLimit: Int? = null) : FlowLogic>() { + constructor(txHashes: Set, otherSide: Party, txCountLimit: Int? = null) : this(ResolveTransactionsFlow(txHashes, otherSide), txCountLimit = txCountLimit) + constructor(stx: SignedTransaction, otherSide: Party) : this(ResolveTransactionsFlow(stx, otherSide)) + + @Suspendable + override fun call(): List { + txCountLimit?.let { resolveTransactionsFlow.transactionCountLimit = it } + return subFlow(resolveTransactionsFlow) + } + } + + @InitiatedBy(TestFlow::class) + private class TestResponseFlow(val otherSide: Party) : FlowLogic() { + @Suspendable + override fun call() = subFlow(TestDataVendingFlow(otherSide)) + } } diff --git a/core/src/test/kotlin/net/corda/core/internal/concurrent/CordaFutureImplTest.kt b/core/src/test/kotlin/net/corda/core/internal/concurrent/CordaFutureImplTest.kt new file mode 100644 index 0000000000..c176372a39 --- /dev/null +++ b/core/src/test/kotlin/net/corda/core/internal/concurrent/CordaFutureImplTest.kt @@ -0,0 +1,163 @@ +package net.corda.core.internal.concurrent + +import com.nhaarman.mockito_kotlin.* +import net.corda.core.concurrent.CordaFuture +import net.corda.core.utilities.getOrThrow +import org.assertj.core.api.Assertions +import org.junit.Test +import org.slf4j.Logger +import java.util.concurrent.Executors +import java.util.concurrent.TimeUnit +import java.util.concurrent.atomic.AtomicBoolean +import kotlin.test.assertEquals +import kotlin.test.assertFalse +import kotlin.test.assertTrue + +class CordaFutureTest { + @Test + fun `fork works`() { + val e = Executors.newSingleThreadExecutor() + try { + assertEquals(100, e.fork { 100 }.getOrThrow()) + val x = Exception() + val f = e.fork { throw x } + Assertions.assertThatThrownBy { f.getOrThrow() }.isSameAs(x) + } finally { + e.shutdown() + } + } + + @Test + fun `if a listener fails its throwable is logged`() { + val f = CordaFutureImpl() + val x = Exception() + val log = mock() + val flag = AtomicBoolean() + f.thenImpl(log) { throw x } + f.thenImpl(log) { flag.set(true) } // Must not be affected by failure of previous listener. + f.set(100) + verify(log).error(eq(CordaFutureImpl.listenerFailedMessage), same(x)) + verifyNoMoreInteractions(log) + assertTrue(flag.get()) + } + + @Test + fun `map works`() { + run { + val f = CordaFutureImpl() + val g = f.map { it * 2 } + f.set(100) + assertEquals(200, g.getOrThrow()) + } + run { + val f = CordaFutureImpl() + val x = Exception() + val g = f.map { throw x } + f.set(100) + Assertions.assertThatThrownBy { g.getOrThrow() }.isSameAs(x) + } + run { + val block = mock<(Any?) -> Any?>() + val f = CordaFutureImpl() + val g = f.map(block) + val x = Exception() + f.setException(x) + Assertions.assertThatThrownBy { g.getOrThrow() }.isSameAs(x) + verifyNoMoreInteractions(block) + } + } + + @Test + fun `flatMap works`() { + run { + val f = CordaFutureImpl() + val g = f.flatMap { CordaFutureImpl().apply { set(it * 2) } } + f.set(100) + assertEquals(200, g.getOrThrow()) + } + run { + val f = CordaFutureImpl() + val x = Exception() + val g = f.flatMap { CordaFutureImpl().apply { setException(x) } } + f.set(100) + Assertions.assertThatThrownBy { g.getOrThrow() }.isSameAs(x) + } + run { + val f = CordaFutureImpl() + val x = Exception() + val g: CordaFuture = f.flatMap { throw x } + f.set(100) + Assertions.assertThatThrownBy { g.getOrThrow() }.isSameAs(x) + } + run { + val block = mock<(Any?) -> CordaFuture<*>>() + val f = CordaFutureImpl() + val g = f.flatMap(block) + val x = Exception() + f.setException(x) + Assertions.assertThatThrownBy { g.getOrThrow() }.isSameAs(x) + verifyNoMoreInteractions(block) + } + } + + @Test + fun `andForget works`() { + val log = mock() + val throwable = Exception("Boom") + val executor = Executors.newSingleThreadExecutor() + executor.fork { throw throwable }.andForget(log) + executor.shutdown() + while (!executor.awaitTermination(1, TimeUnit.SECONDS)) { + // Do nothing. + } + verify(log).error(any(), same(throwable)) + } +} + +class TransposeTest { + private val a = openFuture() + private val b = openFuture() + private val c = openFuture() + private val f = listOf(a, b, c).transpose() + @Test + fun `transpose empty collection`() { + assertEquals(emptyList(), emptyList>().transpose().getOrThrow()) + } + + @Test + fun `transpose values are in the same order as the collection of futures`() { + b.set(2) + c.set(3) + assertFalse(f.isDone) + a.set(1) + assertEquals(listOf(1, 2, 3), f.getOrThrow()) + } + + @Test + fun `transpose throwables are reported in the order they were thrown`() { + val ax = Exception() + val bx = Exception() + val cx = Exception() + b.setException(bx) + c.setException(cx) + assertFalse(f.isDone) + a.setException(ax) + Assertions.assertThatThrownBy { f.getOrThrow() }.isSameAs(bx) + assertEquals(listOf(cx, ax), bx.suppressed.asList()) + assertEquals(emptyList(), ax.suppressed.asList()) + assertEquals(emptyList(), cx.suppressed.asList()) + } + + @Test + fun `transpose mixture of outcomes`() { + val bx = Exception() + val cx = Exception() + b.setException(bx) + c.setException(cx) + assertFalse(f.isDone) + a.set(100) // Discarded. + Assertions.assertThatThrownBy { f.getOrThrow() }.isSameAs(bx) + assertEquals(listOf(cx), bx.suppressed.asList()) + assertEquals(emptyList(), cx.suppressed.asList()) + } +} diff --git a/core/src/test/kotlin/net/corda/core/node/CityDatabaseTest.kt b/core/src/test/kotlin/net/corda/core/node/CityDatabaseTest.kt index 3733ea9e31..a9d3bcda6a 100644 --- a/core/src/test/kotlin/net/corda/core/node/CityDatabaseTest.kt +++ b/core/src/test/kotlin/net/corda/core/node/CityDatabaseTest.kt @@ -1,6 +1,6 @@ package net.corda.core.node -import org.junit.Assert.* +import org.junit.Assert.assertEquals import org.junit.Test class CityDatabaseTest { diff --git a/core/src/test/kotlin/net/corda/core/node/ServiceInfoTests.kt b/core/src/test/kotlin/net/corda/core/node/ServiceInfoTests.kt index b1387cb006..7cf226d988 100644 --- a/core/src/test/kotlin/net/corda/core/node/ServiceInfoTests.kt +++ b/core/src/test/kotlin/net/corda/core/node/ServiceInfoTests.kt @@ -1,10 +1,8 @@ package net.corda.core.node -import net.corda.core.crypto.X509Utilities import net.corda.core.node.services.ServiceInfo import net.corda.core.node.services.ServiceType import net.corda.testing.getTestX509Name -import org.bouncycastle.asn1.x500.X500Name import org.junit.Test import kotlin.test.assertEquals import kotlin.test.assertFailsWith diff --git a/core/src/test/kotlin/net/corda/core/node/VaultUpdateTests.kt b/core/src/test/kotlin/net/corda/core/node/VaultUpdateTests.kt index a2f2e81716..faf5bc13d9 100644 --- a/core/src/test/kotlin/net/corda/core/node/VaultUpdateTests.kt +++ b/core/src/test/kotlin/net/corda/core/node/VaultUpdateTests.kt @@ -4,16 +4,18 @@ import net.corda.core.contracts.* import net.corda.core.crypto.SecureHash import net.corda.core.identity.AbstractParty import net.corda.core.node.services.Vault +import net.corda.core.transactions.LedgerTransaction import net.corda.testing.DUMMY_NOTARY import org.junit.Test import kotlin.test.assertEquals +import kotlin.test.assertFailsWith class VaultUpdateTests { object DummyContract : Contract { - override fun verify(tx: TransactionForContract) { + override fun verify(tx: LedgerTransaction) { } override val legalContractReference: SecureHash = SecureHash.sha256("") @@ -46,7 +48,7 @@ class VaultUpdateTests { @Test fun `something plus nothing is something`() { - val before = Vault.Update(setOf(stateAndRef0, stateAndRef1), setOf(stateAndRef2, stateAndRef3)) + val before = Vault.Update(setOf(stateAndRef0, stateAndRef1), setOf(stateAndRef2, stateAndRef3)) val after = before + Vault.NoUpdate assertEquals(before, after) } @@ -54,32 +56,39 @@ class VaultUpdateTests { @Test fun `nothing plus something is something`() { val before = Vault.NoUpdate - val after = before + Vault.Update(setOf(stateAndRef0, stateAndRef1), setOf(stateAndRef2, stateAndRef3)) - val expected = Vault.Update(setOf(stateAndRef0, stateAndRef1), setOf(stateAndRef2, stateAndRef3)) + val after = before + Vault.Update(setOf(stateAndRef0, stateAndRef1), setOf(stateAndRef2, stateAndRef3)) + val expected = Vault.Update(setOf(stateAndRef0, stateAndRef1), setOf(stateAndRef2, stateAndRef3)) assertEquals(expected, after) } @Test fun `something plus consume state 0 is something without state 0 output`() { - val before = Vault.Update(setOf(stateAndRef2, stateAndRef3), setOf(stateAndRef0, stateAndRef1)) - val after = before + Vault.Update(setOf(stateAndRef0), setOf()) - val expected = Vault.Update(setOf(stateAndRef2, stateAndRef3), setOf(stateAndRef1)) + val before = Vault.Update(setOf(stateAndRef2, stateAndRef3), setOf(stateAndRef0, stateAndRef1)) + val after = before + Vault.Update(setOf(stateAndRef0), setOf()) + val expected = Vault.Update(setOf(stateAndRef2, stateAndRef3), setOf(stateAndRef1)) assertEquals(expected, after) } @Test fun `something plus produce state 4 is something with additional state 4 output`() { - val before = Vault.Update(setOf(stateAndRef2, stateAndRef3), setOf(stateAndRef0, stateAndRef1)) - val after = before + Vault.Update(setOf(), setOf(stateAndRef4)) - val expected = Vault.Update(setOf(stateAndRef2, stateAndRef3), setOf(stateAndRef0, stateAndRef1, stateAndRef4)) + val before = Vault.Update(setOf(stateAndRef2, stateAndRef3), setOf(stateAndRef0, stateAndRef1)) + val after = before + Vault.Update(setOf(), setOf(stateAndRef4)) + val expected = Vault.Update(setOf(stateAndRef2, stateAndRef3), setOf(stateAndRef0, stateAndRef1, stateAndRef4)) assertEquals(expected, after) } @Test fun `something plus consume states 0 and 1, and produce state 4, is something without state 0 and 1 outputs and only state 4 output`() { - val before = Vault.Update(setOf(stateAndRef2, stateAndRef3), setOf(stateAndRef0, stateAndRef1)) - val after = before + Vault.Update(setOf(stateAndRef0, stateAndRef1), setOf(stateAndRef4)) - val expected = Vault.Update(setOf(stateAndRef2, stateAndRef3), setOf(stateAndRef4)) + val before = Vault.Update(setOf(stateAndRef2, stateAndRef3), setOf(stateAndRef0, stateAndRef1)) + val after = before + Vault.Update(setOf(stateAndRef0, stateAndRef1), setOf(stateAndRef4)) + val expected = Vault.Update(setOf(stateAndRef2, stateAndRef3), setOf(stateAndRef4)) assertEquals(expected, after) } + + @Test + fun `can't combine updates of different types`() { + val regularUpdate = Vault.Update(setOf(stateAndRef0, stateAndRef1), setOf(stateAndRef4)) + val notaryChangeUpdate = Vault.Update(setOf(stateAndRef2, stateAndRef3), setOf(stateAndRef0, stateAndRef1), type = Vault.UpdateType.NOTARY_CHANGE) + assertFailsWith { regularUpdate + notaryChangeUpdate } + } } diff --git a/core/src/test/kotlin/net/corda/core/node/services/TimeWindowCheckerTests.kt b/core/src/test/kotlin/net/corda/core/node/services/TimeWindowCheckerTests.kt index 9b174d6b41..61d37b5211 100644 --- a/core/src/test/kotlin/net/corda/core/node/services/TimeWindowCheckerTests.kt +++ b/core/src/test/kotlin/net/corda/core/node/services/TimeWindowCheckerTests.kt @@ -1,33 +1,39 @@ package net.corda.core.node.services import net.corda.core.contracts.TimeWindow -import net.corda.core.seconds +import net.corda.core.utilities.seconds import org.junit.Test import java.time.Clock import java.time.Instant -import java.time.ZoneId +import java.time.ZoneOffset import kotlin.test.assertFalse import kotlin.test.assertTrue class TimeWindowCheckerTests { - val clock = Clock.fixed(Instant.now(), ZoneId.systemDefault()) - val timeWindowChecker = TimeWindowChecker(clock, tolerance = 30.seconds) + val clock: Clock = Clock.fixed(Instant.now(), ZoneOffset.UTC) + val timeWindowChecker = TimeWindowChecker(clock) @Test fun `should return true for valid time-window`() { val now = clock.instant() - val timeWindowPast = TimeWindow.between(now - 60.seconds, now - 29.seconds) - val timeWindowFuture = TimeWindow.between(now + 29.seconds, now + 60.seconds) - assertTrue { timeWindowChecker.isValid(timeWindowPast) } - assertTrue { timeWindowChecker.isValid(timeWindowFuture) } + val timeWindowBetween = TimeWindow.between(now - 10.seconds, now + 10.seconds) + val timeWindowFromOnly = TimeWindow.fromOnly(now - 10.seconds) + val timeWindowUntilOnly = TimeWindow.untilOnly(now + 10.seconds) + assertTrue { timeWindowChecker.isValid(timeWindowBetween) } + assertTrue { timeWindowChecker.isValid(timeWindowFromOnly) } + assertTrue { timeWindowChecker.isValid(timeWindowUntilOnly) } } @Test fun `should return false for invalid time-window`() { val now = clock.instant() - val timeWindowPast = TimeWindow.between(now - 60.seconds, now - 31.seconds) - val timeWindowFuture = TimeWindow.between(now + 31.seconds, now + 60.seconds) - assertFalse { timeWindowChecker.isValid(timeWindowPast) } - assertFalse { timeWindowChecker.isValid(timeWindowFuture) } + val timeWindowBetweenPast = TimeWindow.between(now - 10.seconds, now - 2.seconds) + val timeWindowBetweenFuture = TimeWindow.between(now + 2.seconds, now + 10.seconds) + val timeWindowFromOnlyFuture = TimeWindow.fromOnly(now + 10.seconds) + val timeWindowUntilOnlyPast = TimeWindow.untilOnly(now - 10.seconds) + assertFalse { timeWindowChecker.isValid(timeWindowBetweenPast) } + assertFalse { timeWindowChecker.isValid(timeWindowBetweenFuture) } + assertFalse { timeWindowChecker.isValid(timeWindowFromOnlyFuture) } + assertFalse { timeWindowChecker.isValid(timeWindowUntilOnlyPast) } } } diff --git a/core/src/test/kotlin/net/corda/core/serialization/AttachmentSerializationTest.kt b/core/src/test/kotlin/net/corda/core/serialization/AttachmentSerializationTest.kt index f1259a4f73..822da1b93e 100644 --- a/core/src/test/kotlin/net/corda/core/serialization/AttachmentSerializationTest.kt +++ b/core/src/test/kotlin/net/corda/core/serialization/AttachmentSerializationTest.kt @@ -5,20 +5,20 @@ import net.corda.core.contracts.Attachment import net.corda.core.crypto.SecureHash import net.corda.core.flows.FlowLogic import net.corda.core.flows.InitiatingFlow -import net.corda.core.getOrThrow +import net.corda.core.flows.TestDataVendingFlow import net.corda.core.identity.Party +import net.corda.core.internal.FetchAttachmentsFlow +import net.corda.core.internal.FetchDataFlow import net.corda.core.messaging.RPCOps import net.corda.core.messaging.SingleMessageRecipient import net.corda.core.node.services.ServiceInfo +import net.corda.core.utilities.getOrThrow import net.corda.core.utilities.unwrap -import net.corda.flows.FetchAttachmentsFlow import net.corda.node.internal.InitiatedFlowFactory import net.corda.node.services.config.NodeConfiguration import net.corda.node.services.network.NetworkMapService import net.corda.node.services.persistence.NodeAttachmentService import net.corda.node.services.persistence.schemas.requery.AttachmentEntity -import net.corda.node.services.statemachine.SessionInit -import net.corda.node.utilities.transaction import net.corda.testing.node.MockNetwork import org.junit.After import org.junit.Before @@ -82,9 +82,12 @@ class AttachmentSerializationTest { mockNet.stopNodes() } - private class ServerLogic(private val client: Party) : FlowLogic() { + private class ServerLogic(private val client: Party, private val sendData: Boolean) : FlowLogic() { @Suspendable override fun call() { + if (sendData) { + subFlow(TestDataVendingFlow(client)) + } receive(client).unwrap { assertEquals("ping one", it) } sendAndReceive(client, "pong").unwrap { assertEquals("ping two", it) } } @@ -135,25 +138,28 @@ class AttachmentSerializationTest { @Suspendable override fun getAttachmentContent(): String { val (downloadedAttachment) = subFlow(FetchAttachmentsFlow(setOf(attachmentId), server)).downloaded + send(server, FetchDataFlow.Request.End) communicate() return downloadedAttachment.extractContent() } } - private fun launchFlow(clientLogic: ClientLogic, rounds: Int) { - server.internalRegisterFlowFactory(ClientLogic::class.java, object : InitiatedFlowFactory { - override fun createFlow(platformVersion: Int, otherParty: Party, sessionInit: SessionInit): ServerLogic { - return ServerLogic(otherParty) - } - }, ServerLogic::class.java, track = false) + private fun launchFlow(clientLogic: ClientLogic, rounds: Int, sendData: Boolean = false) { + server.internalRegisterFlowFactory( + ClientLogic::class.java, + InitiatedFlowFactory.Core { ServerLogic(it, sendData) }, + ServerLogic::class.java, + track = false) client.services.startFlow(clientLogic) mockNet.runNetwork(rounds) } private fun rebootClientAndGetAttachmentContent(checkAttachmentsOnLoad: Boolean = true): String { client.stop() - client = mockNet.createNode(server.network.myAddress, client.id, object : MockNetwork.Factory { - override fun create(config: NodeConfiguration, network: MockNetwork, networkMapAddr: SingleMessageRecipient?, advertisedServices: Set, id: Int, overrideServices: Map?, entropyRoot: BigInteger): MockNetwork.MockNode { + client = mockNet.createNode(server.network.myAddress, client.id, object : MockNetwork.Factory { + override fun create(config: NodeConfiguration, network: MockNetwork, networkMapAddr: SingleMessageRecipient?, + advertisedServices: Set, id: Int, overrideServices: Map?, + entropyRoot: BigInteger): MockNetwork.MockNode { return object : MockNetwork.MockNode(config, network, networkMapAddr, advertisedServices, id, overrideServices, entropyRoot) { override fun startMessagingService(rpcOps: RPCOps) { attachments.checkAttachmentsOnLoad = checkAttachmentsOnLoad @@ -192,7 +198,7 @@ class AttachmentSerializationTest { @Test fun `only the hash of a FetchAttachmentsFlow attachment should be saved in checkpoint`() { val attachmentId = server.saveAttachment("genuine") - launchFlow(FetchAttachmentLogic(server, attachmentId), 2) + launchFlow(FetchAttachmentLogic(server, attachmentId), 2, sendData = true) client.hackAttachment(attachmentId, "hacked") assertEquals("hacked", rebootClientAndGetAttachmentContent(false)) } diff --git a/core/src/test/kotlin/net/corda/core/serialization/KryoTests.kt b/core/src/test/kotlin/net/corda/core/serialization/KryoTests.kt deleted file mode 100644 index e27932c757..0000000000 --- a/core/src/test/kotlin/net/corda/core/serialization/KryoTests.kt +++ /dev/null @@ -1,179 +0,0 @@ -package net.corda.core.serialization - -import com.esotericsoftware.kryo.Kryo -import com.google.common.primitives.Ints -import net.corda.core.crypto.* -import net.corda.core.utilities.opaque -import net.corda.node.services.persistence.NodeAttachmentService -import net.corda.testing.ALICE -import net.corda.testing.BOB -import net.corda.testing.BOB_PUBKEY -import org.assertj.core.api.Assertions.assertThat -import org.assertj.core.api.Assertions.assertThatThrownBy -import org.bouncycastle.cert.X509CertificateHolder -import org.junit.Before -import org.junit.Test -import org.slf4j.LoggerFactory -import java.io.ByteArrayInputStream -import java.io.InputStream -import java.security.cert.CertPath -import java.security.cert.CertificateFactory -import java.time.Instant -import java.util.* -import kotlin.test.assertEquals -import kotlin.test.assertTrue - -class KryoTests { - - private lateinit var kryo: Kryo - - @Before - fun setup() { - // We deliberately do not return this, since we do some unorthodox registering below and do not want to pollute the pool. - kryo = p2PKryo().borrow() - } - - @Test - fun ok() { - val birthday = Instant.parse("1984-04-17T00:30:00.00Z") - val mike = Person("mike", birthday) - val bits = mike.serialize(kryo) - assertThat(bits.deserialize(kryo)).isEqualTo(Person("mike", birthday)) - } - - @Test - fun nullables() { - val bob = Person("bob", null) - val bits = bob.serialize(kryo) - assertThat(bits.deserialize(kryo)).isEqualTo(Person("bob", null)) - } - - @Test - fun `serialised form is stable when the same object instance is added to the deserialised object graph`() { - kryo.noReferencesWithin>() - val obj = Ints.toByteArray(0x01234567).opaque() - val originalList = arrayListOf(obj) - val deserialisedList = originalList.serialize(kryo).deserialize(kryo) - originalList += obj - deserialisedList += obj - assertThat(deserialisedList.serialize(kryo)).isEqualTo(originalList.serialize(kryo)) - } - - @Test - fun `serialised form is stable when the same object instance occurs more than once, and using java serialisation`() { - kryo.noReferencesWithin>() - val instant = Instant.ofEpochMilli(123) - val instantCopy = Instant.ofEpochMilli(123) - assertThat(instant).isNotSameAs(instantCopy) - val listWithCopies = arrayListOf(instant, instantCopy) - val listWithSameInstances = arrayListOf(instant, instant) - assertThat(listWithSameInstances.serialize(kryo)).isEqualTo(listWithCopies.serialize(kryo)) - } - - @Test - fun `cyclic object graph`() { - val cyclic = Cyclic(3) - val bits = cyclic.serialize(kryo) - assertThat(bits.deserialize(kryo)).isEqualTo(cyclic) - } - - @Test - fun `deserialised key pair functions the same as serialised one`() { - val keyPair = generateKeyPair() - val bitsToSign: ByteArray = Ints.toByteArray(0x01234567) - val wrongBits: ByteArray = Ints.toByteArray(0x76543210) - val signature = keyPair.sign(bitsToSign) - signature.verify(bitsToSign) - assertThatThrownBy { signature.verify(wrongBits) } - - val deserialisedKeyPair = keyPair.serialize(kryo).deserialize(kryo) - val deserialisedSignature = deserialisedKeyPair.sign(bitsToSign) - deserialisedSignature.verify(bitsToSign) - assertThatThrownBy { deserialisedSignature.verify(wrongBits) } - } - - @Test - fun `write and read Kotlin object singleton`() { - val serialised = TestSingleton.serialize(kryo) - val deserialised = serialised.deserialize(kryo) - assertThat(deserialised).isSameAs(TestSingleton) - } - - @Test - fun `InputStream serialisation`() { - val rubbish = ByteArray(12345, { (it * it * 0.12345).toByte() }) - val readRubbishStream: InputStream = rubbish.inputStream().serialize(kryo).deserialize(kryo) - for (i in 0..12344) { - assertEquals(rubbish[i], readRubbishStream.read().toByte()) - } - assertEquals(-1, readRubbishStream.read()) - } - - @Test - fun `serialize - deserialize MetaData`() { - val testString = "Hello World" - val testBytes = testString.toByteArray() - val keyPair1 = Crypto.generateKeyPair("ECDSA_SECP256K1_SHA256") - val bitSet = java.util.BitSet(10) - bitSet.set(3) - - val meta = MetaData("ECDSA_SECP256K1_SHA256", "M9", SignatureType.FULL, Instant.now(), bitSet, bitSet, testBytes, keyPair1.public) - val serializedMetaData = meta.bytes() - val meta2 = serializedMetaData.deserialize() - assertEquals(meta2, meta) - } - - @Test - fun `serialize - deserialize Logger`() { - val logger = LoggerFactory.getLogger("aName") - val logger2 = logger.serialize(storageKryo()).deserialize(storageKryo()) - assertEquals(logger.name, logger2.name) - assertTrue(logger === logger2) - } - - @Test - fun `HashCheckingStream (de)serialize`() { - val rubbish = ByteArray(12345, { (it * it * 0.12345).toByte() }) - val readRubbishStream: InputStream = NodeAttachmentService.HashCheckingStream(SecureHash.sha256(rubbish), rubbish.size, ByteArrayInputStream(rubbish)).serialize(kryo).deserialize(kryo) - for (i in 0..12344) { - assertEquals(rubbish[i], readRubbishStream.read().toByte()) - } - assertEquals(-1, readRubbishStream.read()) - } - - @Test - fun `serialize - deserialize X509CertififcateHolder`() { - val expected: X509CertificateHolder = X509Utilities.createSelfSignedCACertificate(ALICE.name, Crypto.generateKeyPair(X509Utilities.DEFAULT_TLS_SIGNATURE_SCHEME)) - val serialized = expected.serialize(kryo).bytes - val actual: X509CertificateHolder = serialized.deserialize(kryo) - assertEquals(expected, actual) - } - - @Test - fun `serialize - deserialize X509CertPath`() { - val certFactory = CertificateFactory.getInstance("X509") - val rootCAKey = Crypto.generateKeyPair(X509Utilities.DEFAULT_TLS_SIGNATURE_SCHEME) - val rootCACert = X509Utilities.createSelfSignedCACertificate(ALICE.name, rootCAKey) - val certificate = X509Utilities.createCertificate(CertificateType.TLS, rootCACert, rootCAKey, BOB.name, BOB_PUBKEY) - val expected = certFactory.generateCertPath(listOf(certificate.cert, rootCACert.cert)) - val serialized = expected.serialize(kryo).bytes - val actual: CertPath = serialized.deserialize(kryo) - assertEquals(expected, actual) - } - - @CordaSerializable - private data class Person(val name: String, val birthday: Instant?) - - @Suppress("unused") - @CordaSerializable - private class Cyclic(val value: Int) { - val thisInstance = this - override fun equals(other: Any?): Boolean = (this === other) || (other is Cyclic && this.value == other.value) - override fun hashCode(): Int = value.hashCode() - override fun toString(): String = "Cyclic($value)" - } - - @CordaSerializable - private object TestSingleton - -} diff --git a/core/src/test/kotlin/net/corda/core/serialization/TransactionSerializationTests.kt b/core/src/test/kotlin/net/corda/core/serialization/TransactionSerializationTests.kt index 4e34ebfb07..492b50826f 100644 --- a/core/src/test/kotlin/net/corda/core/serialization/TransactionSerializationTests.kt +++ b/core/src/test/kotlin/net/corda/core/serialization/TransactionSerializationTests.kt @@ -3,8 +3,9 @@ package net.corda.core.serialization import net.corda.core.contracts.* import net.corda.core.crypto.SecureHash import net.corda.core.identity.AbstractParty -import net.corda.core.seconds +import net.corda.core.transactions.LedgerTransaction import net.corda.core.transactions.TransactionBuilder +import net.corda.core.utilities.seconds import net.corda.testing.* import net.corda.testing.node.MockServices import org.junit.Before @@ -16,11 +17,11 @@ import kotlin.test.assertFailsWith val TEST_PROGRAM_ID = TransactionSerializationTests.TestCash() -class TransactionSerializationTests { +class TransactionSerializationTests : TestDependencyInjectionBase() { class TestCash : Contract { override val legalContractReference = SecureHash.sha256("TestCash") - override fun verify(tx: TransactionForContract) { + override fun verify(tx: LedgerTransaction) { } data class State( @@ -31,7 +32,7 @@ class TransactionSerializationTests { override val participants: List get() = listOf(owner) - override fun withNewOwner(newOwner: AbstractParty) = Pair(Commands.Move(), copy(owner = newOwner)) + override fun withNewOwner(newOwner: AbstractParty) = CommandAndState(Commands.Move(), copy(owner = newOwner)) } interface Commands : CommandData { @@ -53,7 +54,7 @@ class TransactionSerializationTests { @Before fun setup() { - tx = TransactionType.General.Builder(DUMMY_NOTARY).withItems( + tx = TransactionBuilder(DUMMY_NOTARY).withItems( inputState, outputState, changeState, Command(TestCash.Commands.Move(), arrayListOf(MEGA_CORP.owningKey)) ) } @@ -64,12 +65,12 @@ class TransactionSerializationTests { val stx = notaryServices.addSignature(ptx) // Now check that the signature we just made verifies. - stx.verifySignatures() + stx.verifyRequiredSignatures() // Corrupt the data and ensure the signature catches the problem. stx.id.bytes[5] = stx.id.bytes[5].inc() assertFailsWith(SignatureException::class) { - stx.verifySignatures() + stx.verifyRequiredSignatures() } } @@ -85,14 +86,14 @@ class TransactionSerializationTests { // If the signature was replaced in transit, we don't like it. assertFailsWith(SignatureException::class) { - val tx2 = TransactionType.General.Builder(DUMMY_NOTARY).withItems(inputState, outputState, changeState, + val tx2 = TransactionBuilder(DUMMY_NOTARY).withItems(inputState, outputState, changeState, Command(TestCash.Commands.Move(), DUMMY_KEY_2.public)) val ptx2 = notaryServices.signInitialTransaction(tx2) val dummyServices = MockServices(DUMMY_KEY_2) val stx2 = dummyServices.addSignature(ptx2) - stx.copy(sigs = stx2.sigs).verifySignatures() + stx.copy(sigs = stx2.sigs).verifyRequiredSignatures() } } @@ -103,4 +104,18 @@ class TransactionSerializationTests { val stx = notaryServices.addSignature(ptx) assertEquals(TEST_TX_TIME, stx.tx.timeWindow?.midpoint) } + + @Test + fun storeAndLoadWhenSigning() { + val ptx = megaCorpServices.signInitialTransaction(tx) + ptx.verifySignaturesExcept(notaryServices.key.public) + + val stored = ptx.serialize() + val loaded = stored.deserialize() + + assertEquals(loaded, ptx) + + val final = notaryServices.addSignature(loaded) + final.verifyRequiredSignatures() + } } diff --git a/core/src/test/kotlin/net/corda/core/testing/Generators.kt b/core/src/test/kotlin/net/corda/core/testing/Generators.kt deleted file mode 100644 index 591dbb13c1..0000000000 --- a/core/src/test/kotlin/net/corda/core/testing/Generators.kt +++ /dev/null @@ -1,156 +0,0 @@ -package net.corda.core.testing - -import com.pholser.junit.quickcheck.generator.GenerationStatus -import com.pholser.junit.quickcheck.generator.Generator -import com.pholser.junit.quickcheck.generator.java.util.ArrayListGenerator -import com.pholser.junit.quickcheck.random.SourceOfRandomness -import net.corda.core.contracts.* -import net.corda.core.crypto.SecureHash -import net.corda.core.crypto.entropyToKeyPair -import net.corda.core.identity.AnonymousParty -import net.corda.core.identity.Party -import net.corda.core.utilities.OpaqueBytes -import net.corda.testing.getTestX509Name -import org.bouncycastle.asn1.x500.X500Name -import java.nio.ByteBuffer -import java.nio.charset.Charset -import java.security.PrivateKey -import java.security.PublicKey -import java.time.Duration -import java.time.Instant -import java.util.* - -/** - * Generators for quickcheck - * - * TODO Split this into several files - */ - -fun Generator.generateList(random: SourceOfRandomness, status: GenerationStatus): List { - val arrayGenerator = ArrayListGenerator() - arrayGenerator.addComponentGenerators(listOf(this)) - @Suppress("UNCHECKED_CAST") - return arrayGenerator.generate(random, status) as List -} - -class PrivateKeyGenerator : Generator(PrivateKey::class.java) { - override fun generate(random: SourceOfRandomness, status: GenerationStatus): PrivateKey { - return entropyToKeyPair(random.nextBigInteger(32)).private - } -} - -// TODO add CompositeKeyGenerator that actually does something useful. -class PublicKeyGenerator : Generator(PublicKey::class.java) { - override fun generate(random: SourceOfRandomness, status: GenerationStatus): PublicKey { - return entropyToKeyPair(random.nextBigInteger(32)).public - } -} - -class AnonymousPartyGenerator : Generator(AnonymousParty::class.java) { - override fun generate(random: SourceOfRandomness, status: GenerationStatus): AnonymousParty { - return AnonymousParty(PublicKeyGenerator().generate(random, status)) - } -} - -class PartyGenerator : Generator(Party::class.java) { - override fun generate(random: SourceOfRandomness, status: GenerationStatus): Party { - return Party(X500NameGenerator().generate(random, status), PublicKeyGenerator().generate(random, status)) - } -} - -class PartyAndReferenceGenerator : Generator(PartyAndReference::class.java) { - override fun generate(random: SourceOfRandomness, status: GenerationStatus): PartyAndReference { - return PartyAndReference(AnonymousPartyGenerator().generate(random, status), OpaqueBytes(random.nextBytes(16))) - } -} - -class SecureHashGenerator : Generator(SecureHash::class.java) { - override fun generate(random: SourceOfRandomness, status: GenerationStatus): SecureHash { - return SecureHash.sha256(random.nextBytes(16)) - } -} - -class StateRefGenerator : Generator(StateRef::class.java) { - override fun generate(random: SourceOfRandomness, status: GenerationStatus): StateRef { - return StateRef(SecureHash.sha256(random.nextBytes(16)), random.nextInt(0, 10)) - } -} - -@Suppress("CAST_NEVER_SUCCEEDS", "UNCHECKED_CAST") -class TransactionStateGenerator(val stateGenerator: Generator) : Generator>(TransactionState::class.java as Class>) { - override fun generate(random: SourceOfRandomness, status: GenerationStatus): TransactionState { - return TransactionState(stateGenerator.generate(random, status), PartyGenerator().generate(random, status)) - } -} - -@Suppress("CAST_NEVER_SUCCEEDS", "UNCHECKED_CAST") -class IssuedGenerator(val productGenerator: Generator) : Generator>(Issued::class.java as Class>) { - override fun generate(random: SourceOfRandomness, status: GenerationStatus): Issued { - return Issued(PartyAndReferenceGenerator().generate(random, status), productGenerator.generate(random, status)) - } -} - -@Suppress("CAST_NEVER_SUCCEEDS", "UNCHECKED_CAST") -class AmountGenerator(val tokenGenerator: Generator) : Generator>(Amount::class.java as Class>) { - override fun generate(random: SourceOfRandomness, status: GenerationStatus): Amount { - return Amount(random.nextLong(0, 1000000), tokenGenerator.generate(random, status)) - } -} - -class CurrencyGenerator : Generator(Currency::class.java) { - companion object { - val currencies = Currency.getAvailableCurrencies().toList() - } - - override fun generate(random: SourceOfRandomness, status: GenerationStatus): Currency { - return currencies[random.nextInt(0, currencies.size - 1)] - } -} - -class InstantGenerator : Generator(Instant::class.java) { - override fun generate(random: SourceOfRandomness, status: GenerationStatus): Instant { - return Instant.ofEpochMilli(random.nextLong(0, 1000000)) - } -} - -class DurationGenerator : Generator(Duration::class.java) { - override fun generate(random: SourceOfRandomness, status: GenerationStatus): Duration { - return Duration.ofMillis(random.nextLong(0, 1000000)) - } -} - -class TimeWindowGenerator : Generator(TimeWindow::class.java) { - override fun generate(random: SourceOfRandomness, status: GenerationStatus): TimeWindow { - return TimeWindow.withTolerance(InstantGenerator().generate(random, status), DurationGenerator().generate(random, status)) - } -} - -class X500NameGenerator : Generator(X500Name::class.java) { - companion object { - private val charset = Charset.forName("US-ASCII") - private val asciiA = charset.encode("A")[0] - private val asciia = charset.encode("a")[0] - } - - /** - * Append something that looks a bit like a proper noun to the string builder. - */ - private fun appendProperNoun(builder: StringBuilder, random: SourceOfRandomness) : StringBuilder { - val length = random.nextByte(1, 8) - val encoded = ByteBuffer.allocate(length.toInt()) - encoded.put((random.nextByte(0, 25) + asciiA).toByte()) - for (charIdx in 1..length - 1) { - encoded.put((random.nextByte(0, 25) + asciia).toByte()) - } - return builder.append(charset.decode(encoded)) - } - - override fun generate(random: SourceOfRandomness, status: GenerationStatus): X500Name { - val wordCount = random.nextByte(1, 3) - val cn = StringBuilder() - for (word in 0..wordCount) { - appendProperNoun(cn, random).append(" ") - } - return getTestX509Name(cn.trim().toString()) - } -} diff --git a/core/src/test/kotlin/net/corda/core/utilities/KotlinUtilsTest.kt b/core/src/test/kotlin/net/corda/core/utilities/KotlinUtilsTest.kt new file mode 100644 index 0000000000..ae4a67c093 --- /dev/null +++ b/core/src/test/kotlin/net/corda/core/utilities/KotlinUtilsTest.kt @@ -0,0 +1,59 @@ +package net.corda.core.utilities + +import net.corda.core.crypto.random63BitValue +import net.corda.core.serialization.CordaSerializable +import net.corda.core.serialization.deserialize +import net.corda.core.serialization.serialize +import net.corda.testing.TestDependencyInjectionBase +import org.assertj.core.api.Assertions.assertThat +import org.junit.Test + +class KotlinUtilsTest : TestDependencyInjectionBase() { + @Test + fun `transient property which is null`() { + val test = NullTransientProperty() + test.transientValue + test.transientValue + assertThat(test.evalCount).isEqualTo(1) + } + + @Test + fun `transient property with non-capturing lamba`() { + val original = NonCapturingTransientProperty() + val originalVal = original.transientVal + val copy = original.serialize().deserialize() + val copyVal = copy.transientVal + assertThat(copyVal).isNotEqualTo(originalVal) + assertThat(copy.transientVal).isEqualTo(copyVal) + } + + @Test + fun `transient property with capturing lamba`() { + val original = CapturingTransientProperty("Hello") + val originalVal = original.transientVal + val copy = original.serialize().deserialize() + val copyVal = copy.transientVal + assertThat(copyVal).isNotEqualTo(originalVal) + assertThat(copy.transientVal).isEqualTo(copyVal) + assertThat(copy.transientVal).startsWith("Hello") + } + + private class NullTransientProperty { + var evalCount = 0 + val transientValue by transient { + evalCount++ + null + } + } + + @CordaSerializable + private class NonCapturingTransientProperty { + val transientVal by transient { random63BitValue() } + } + + @CordaSerializable + private class CapturingTransientProperty(prefix: String) { + private val seed = random63BitValue() + val transientVal by transient { prefix + seed + random63BitValue() } + } +} \ No newline at end of file diff --git a/core/src/test/kotlin/net/corda/core/utilities/NonEmptySetTest.kt b/core/src/test/kotlin/net/corda/core/utilities/NonEmptySetTest.kt index 6a086b39ba..299dec166e 100644 --- a/core/src/test/kotlin/net/corda/core/utilities/NonEmptySetTest.kt +++ b/core/src/test/kotlin/net/corda/core/utilities/NonEmptySetTest.kt @@ -4,119 +4,64 @@ import com.google.common.collect.testing.SetTestSuiteBuilder import com.google.common.collect.testing.TestIntegerSetGenerator import com.google.common.collect.testing.features.CollectionFeature import com.google.common.collect.testing.features.CollectionSize -import com.google.common.collect.testing.testers.* import junit.framework.TestSuite import net.corda.core.serialization.deserialize import net.corda.core.serialization.serialize +import net.corda.testing.initialiseTestSerialization +import net.corda.testing.resetTestSerialization +import org.assertj.core.api.Assertions.assertThat +import org.assertj.core.api.Assertions.assertThatThrownBy import org.junit.Test import org.junit.runner.RunWith import org.junit.runners.Suite -import kotlin.test.assertEquals @RunWith(Suite::class) @Suite.SuiteClasses( NonEmptySetTest.Guava::class, - NonEmptySetTest.Remove::class, - NonEmptySetTest.Serializer::class + NonEmptySetTest.General::class ) class NonEmptySetTest { - /** - * Guava test suite generator for NonEmptySet. - */ - class Guava { - companion object { - @JvmStatic - fun suite(): TestSuite - = SetTestSuiteBuilder - .using(NonEmptySetGenerator()) - .named("test NonEmptySet with several values") + object Guava { + @JvmStatic + fun suite(): TestSuite { + return SetTestSuiteBuilder + .using(NonEmptySetGenerator) + .named("Guava test suite") .withFeatures( CollectionSize.SEVERAL, CollectionFeature.ALLOWS_NULL_VALUES, - CollectionFeature.FAILS_FAST_ON_CONCURRENT_MODIFICATION, - CollectionFeature.GENERAL_PURPOSE + CollectionFeature.KNOWN_ORDER ) - // Kotlin throws the wrong exception in this cases - .suppressing(CollectionAddAllTester::class.java.getMethod("testAddAll_nullCollectionReference")) - // Disable tests that try to remove everything: - .suppressing(CollectionRemoveAllTester::class.java.getMethod("testRemoveAll_nullCollectionReferenceNonEmptySubject")) - .suppressing(CollectionClearTester::class.java.methods.toList()) - .suppressing(CollectionRetainAllTester::class.java.methods.toList()) - .suppressing(CollectionRemoveIfTester::class.java.getMethod("testRemoveIf_allPresent")) .createTestSuite() } - - /** - * For some reason IntelliJ really wants to scan this class for tests and fail when - * it doesn't find any. This stops that error from occurring. - */ - @Test fun dummy() { - } } - /** - * Test removal, which Guava's standard tests can't cover for us. - */ - class Remove { + class General { @Test - fun `construction`() { - val expected = 17 - val basicSet = nonEmptySetOf(expected) - val actual = basicSet.first() - assertEquals(expected, actual) - } - - @Test(expected = IllegalStateException::class) - fun `remove sole element`() { - val basicSet = nonEmptySetOf(-17) - basicSet.remove(-17) + fun `copyOf - empty source`() { + assertThatThrownBy { NonEmptySet.copyOf(HashSet()) }.isInstanceOf(IllegalArgumentException::class.java) } @Test - fun `remove one of two elements`() { - val basicSet = nonEmptySetOf(-17, 17) - basicSet.remove(-17) + fun head() { + assertThat(NonEmptySet.of(1, 2).head()).isEqualTo(1) } @Test - fun `remove element which does not exist`() { - val basicSet = nonEmptySetOf(-17) - basicSet.remove(-5) - assertEquals(1, basicSet.size) - } + fun `serialize deserialize`() { + initialiseTestSerialization() + try { + val original = NonEmptySet.of(-17, 22, 17) + val copy = original.serialize().deserialize() - @Test(expected = IllegalStateException::class) - fun `remove via iterator`() { - val basicSet = nonEmptySetOf(-17, 17) - val iterator = basicSet.iterator() - while (iterator.hasNext()) { - iterator.remove() + assertThat(copy).isEqualTo(original).isNotSameAs(original) + } finally { + resetTestSerialization() } } } - /** - * Test serialization/deserialization. - */ - class Serializer { - @Test - fun `serialize deserialize`() { - val expected: NonEmptySet = nonEmptySetOf(-17, 22, 17) - val serialized = expected.serialize().bytes - val actual = serialized.deserialize>() - - assertEquals(expected, actual) - } - } -} - -/** - * Generator of non empty set instances needed for testing. - */ -class NonEmptySetGenerator : TestIntegerSetGenerator() { - override fun create(elements: Array?): NonEmptySet? { - val set = nonEmptySetOf(elements!!.first()) - set.addAll(elements.toList()) - return set + private object NonEmptySetGenerator : TestIntegerSetGenerator() { + override fun create(elements: Array): NonEmptySet = NonEmptySet.copyOf(elements.asList()) } } diff --git a/core/src/test/kotlin/net/corda/core/utilities/ProgressTrackerTest.kt b/core/src/test/kotlin/net/corda/core/utilities/ProgressTrackerTest.kt index 8b56b4ad13..f307a2375e 100644 --- a/core/src/test/kotlin/net/corda/core/utilities/ProgressTrackerTest.kt +++ b/core/src/test/kotlin/net/corda/core/utilities/ProgressTrackerTest.kt @@ -1,11 +1,5 @@ package net.corda.core.utilities -import com.esotericsoftware.kryo.Kryo -import com.esotericsoftware.kryo.KryoSerializable -import com.esotericsoftware.kryo.io.Input -import com.esotericsoftware.kryo.io.Output -import net.corda.core.serialization.createTestKryo -import net.corda.core.serialization.serialize import org.junit.Before import org.junit.Test import java.util.* @@ -94,31 +88,4 @@ class ProgressTrackerTest { pt.currentStep = SimpleSteps.ONE assertEquals(SimpleSteps.TWO, pt.nextStep()) } - - @Test - fun rxSubscriptionsAreNotSerialized() { - class Unserializable : KryoSerializable { - override fun write(kryo: Kryo?, output: Output?) = throw AssertionError("not called") - override fun read(kryo: Kryo?, input: Input?) = throw AssertionError("not called") - - fun foo() { - println("bar") - } - } - - val kryo = createTestKryo().apply { - // This is required to make sure Kryo walks through the auto-generated members for the lambda below. - fieldSerializerConfig.isIgnoreSyntheticFields = false - } - pt.setChildProgressTracker(SimpleSteps.TWO, pt2) - class Tmp { - val unserializable = Unserializable() - - init { - pt2.changes.subscribe { unserializable.foo() } - } - } - Tmp() - pt.serialize(kryo) - } } diff --git a/docs/source/_static/corda-cheat-sheet.pdf b/docs/source/_static/corda-cheat-sheet.pdf index c6378f1c27..cc00820e04 100644 Binary files a/docs/source/_static/corda-cheat-sheet.pdf and b/docs/source/_static/corda-cheat-sheet.pdf differ diff --git a/docs/source/_static/css/custom.css b/docs/source/_static/css/custom.css index 7dfbdc74cf..c7f2ab9999 100644 --- a/docs/source/_static/css/custom.css +++ b/docs/source/_static/css/custom.css @@ -179,3 +179,11 @@ a:visited { background-position: center top; background-origin: content box; } + + +/* Version dropdown */ + +.version-dropdown { + border-radius: 4px; + border-color: #263673; +} \ No newline at end of file diff --git a/docs/source/_static/versions b/docs/source/_static/versions new file mode 100644 index 0000000000..4110e2dfd6 --- /dev/null +++ b/docs/source/_static/versions @@ -0,0 +1,12 @@ +{ + "https://docs.corda.net/releases/release-M6.0": "M6.0", + "https://docs.corda.net/releases/release-M7.0": "M7.0", + "https://docs.corda.net/releases/release-M8.2": "M8.2", + "https://docs.corda.net/releases/release-M9.2": "M9.2", + "https://docs.corda.net/releases/release-M10.1": "M10.1", + "https://docs.corda.net/releases/release-M11.2": "M11.2", + "https://docs.corda.net/releases/release-M12.1": "M12.1", + "https://docs.corda.net/releases/release-M13.0": "M13.0", + "https://docs.corda.net/releases/release-M14.0": "M14.0", + "https://docs.corda.net/head/": "Master" +} diff --git a/docs/source/_templates/layout_for_doc_website.html b/docs/source/_templates/layout_for_doc_website.html index 841ca935b3..2272ea7fb4 100644 --- a/docs/source/_templates/layout_for_doc_website.html +++ b/docs/source/_templates/layout_for_doc_website.html @@ -10,35 +10,29 @@ API reference: Kotlin/ Slack
- +
+ {% endblock %} {% block footer %} + + + + +Updating the contract +===================== + +Remember that each state references a contract. The contract imposes constraints on transactions involving that state. +If the transaction does not obey the constraints of all the contracts of all its states, it cannot become a valid +ledger update. + +We need to modify our contract so that the lender's signature is required in any IOU creation transaction. This will +only require changing a single line of code. In ``IOUContract.java``/``IOUContract.kt``, update the final line of the +``requireThat`` block as follows: + +.. container:: codeset + + .. code-block:: kotlin + + "The borrower and lender must be signers." using (command.signers.containsAll(listOf( + out.borrower.owningKey, out.lender.owningKey))) + + .. code-block:: java + + check.using("The borrower and lender must be signers.", command.getSigners().containsAll( + ImmutableList.of(borrower.getOwningKey(), lender.getOwningKey()))); + +Progress so far +--------------- +Our contract now imposes an additional constraint - the lender must also sign an IOU creation transaction. Next, we +need to update ``IOUFlow`` so that it actually gathers the counterparty's signature as part of the flow. \ No newline at end of file diff --git a/docs/source/tut-two-party-flow.rst b/docs/source/tut-two-party-flow.rst new file mode 100644 index 0000000000..bd5b5a87fc --- /dev/null +++ b/docs/source/tut-two-party-flow.rst @@ -0,0 +1,203 @@ +.. highlight:: kotlin +.. raw:: html + + + + +Updating the flow +================= + +To update the flow, we'll need to do two things: + +* Update the borrower's side of the flow to request the lender's signature +* Create a flow for the lender to run in response to a signature request from the borrower + +Updating the borrower's flow +---------------------------- +In the original CorDapp, we automated the process of notarising a transaction and recording it in every party's vault +by invoking a built-in flow called ``FinalityFlow`` as a subflow. We're going to use another pre-defined flow, called +``CollectSignaturesFlow``, to gather the lender's signature. + +We also need to add the lender's public key to the transaction's command, making the lender one of the required signers +on the transaction. + +In ``IOUFlow.java``/``IOUFlow.kt``, update ``IOUFlow.call`` as follows: + +.. container:: codeset + + .. code-block:: kotlin + + // We add the items to the builder. + val state = IOUState(iouValue, me, otherParty) + val cmd = Command(IOUContract.Create(), listOf(me.owningKey, otherParty.owningKey)) + txBuilder.withItems(state, cmd) + + // Verifying the transaction. + txBuilder.verify(serviceHub) + + // Signing the transaction. + val signedTx = serviceHub.toSignedTransaction(txBuilder) + + // Obtaining the counterparty's signature + val fullySignedTx = subFlow(CollectSignaturesFlow(signedTx)) + + // Finalising the transaction. + subFlow(FinalityFlow(fullySignedTx)) + + .. code-block:: java + + // We add the items to the builder. + IOUState state = new IOUState(iouValue, me, otherParty); + List requiredSigners = ImmutableList.of(me.getOwningKey(), otherParty.getOwningKey()); + Command cmd = new Command(new IOUContract.Create(), requiredSigners); + txBuilder.withItems(state, cmd); + + // Verifying the transaction. + txBuilder.verify(getServiceHub()); + + // Signing the transaction. + final SignedTransaction signedTx = getServiceHub().toSignedTransaction(txBuilder); + + // Obtaining the counterparty's signature + final SignedTransaction fullySignedTx = subFlow(new CollectSignaturesFlow(signedTx, null)); + + // Finalising the transaction. + subFlow(new FinalityFlow(fullySignedTx)); + +To make the lender a required signer, we simply add the lender's public key to the list of signers on the command. + +``CollectSignaturesFlow``, meanwhile, takes a transaction signed by the flow initiator, and returns a transaction +signed by all the transaction's other required signers. We then pass this fully-signed transaction into +``FinalityFlow``. + +The lender's flow +----------------- +Reorganising our class +^^^^^^^^^^^^^^^^^^^^^^ +Before we define the lender's flow, let's reorganise ``IOUFlow.java``/``IOUFlow.kt`` a little bit: + +* Rename ``IOUFlow`` to ``Initiator`` +* In Java, make the ``Initiator`` class static, rename its constructor to match the new name, and move the definition + inside an enclosing ``IOUFlow`` class +* In Kotlin, move the definition of ``Initiator`` class inside an enclosing ``IOUFlow`` singleton object + +We will end up with the following structure: + +.. container:: codeset + + .. code-block:: kotlin + + object IOUFlow { + @InitiatingFlow + @StartableByRPC + class Initiator(val iouValue: Int, + val otherParty: Party) : FlowLogic() { + + .. code-block:: java + + public class IOUFlow { + @InitiatingFlow + @StartableByRPC + public static class Initiator extends FlowLogic { + +Writing the lender's flow +^^^^^^^^^^^^^^^^^^^^^^^^^ +We're now ready to write the lender's flow, which will respond to the borrower's attempt to gather our signature. + +Inside the ``IOUFlow`` class/singleton object, add the following class: + +.. container:: codeset + + .. code-block:: kotlin + + @InitiatedBy(Initiator::class) + class Acceptor(val otherParty: Party) : FlowLogic() { + @Suspendable + override fun call() { + val signTransactionFlow = object : SignTransactionFlow(otherParty) { + override fun checkTransaction(stx: SignedTransaction) = requireThat { + val output = stx.tx.outputs.single().data + "This must be an IOU transaction." using (output is IOUState) + val iou = output as IOUState + "The IOU's value can't be too high." using (iou.value < 100) + } + } + + subFlow(signTransactionFlow) + } + } + + .. code-block:: java + + @InitiatedBy(Initiator.class) + public static class Acceptor extends FlowLogic { + + private final Party otherParty; + + public Acceptor(Party otherParty) { + this.otherParty = otherParty; + } + + @Suspendable + @Override + public Void call() throws FlowException { + class signTxFlow extends SignTransactionFlow { + private signTxFlow(Party otherParty) { + super(otherParty, null); + } + + @Override + protected void checkTransaction(SignedTransaction stx) { + requireThat(require -> { + ContractState output = stx.getTx().getOutputs().get(0).getData(); + require.using("This must be an IOU transaction.", output instanceof IOUState); + IOUState iou = (IOUState) output; + require.using("The IOU's value can't be too high.", iou.getValue() < 100); + return null; + }); + } + } + + subFlow(new signTxFlow(otherParty)); + + return null; + } + } + +As with the ``Initiator``, our ``Acceptor`` flow is a ``FlowLogic`` subclass where we've overridden ``FlowLogic.call``. + +The flow is annotated with ``InitiatedBy(Initiator.class)``, which means that your node will invoke ``Acceptor.call`` +when it receives a message from a instance of ``Initiator`` running on another node. What will this message from the +``Initiator`` be? If we look at the definition of ``CollectSignaturesFlow``, we can see that we'll be sent a +``SignedTransaction``, and are expected to send back our signature over that transaction. + +We could handle this manually. However, there is also a pre-defined flow called ``SignTransactionFlow`` that can handle +this process for us automatically. ``SignTransactionFlow`` is an abstract class, and we must subclass it and override +``SignTransactionFlow.checkTransaction``. + +Once we've defined the subclass, we invoke it using ``FlowLogic.subFlow``, and the communication with the borrower's +and the lender's flow is conducted automatically. + +CheckTransactions +~~~~~~~~~~~~~~~~~ +``SignTransactionFlow`` will automatically verify the transaction and its signatures before signing it. However, just +because a transaction is valid doesn't mean we necessarily want to sign. What if we don't want to deal with the +counterparty in question, or the value is too high, or we're not happy with the transaction's structure? + +Overriding ``SignTransactionFlow.checkTransaction`` allows us to define these additional checks. In our case, we are +checking that: + +* The transaction involves an ``IOUState`` - this ensures that ``IOUContract`` will be run to verify the transaction +* The IOU's value is less than some amount (100 in this case) + +If either of these conditions are not met, we will not sign the transaction - even if the transaction and its +signatures are valid. + +Conclusion +---------- +We have now updated our flow to gather the lender's signature as well, in line with the constraints in ``IOUContract``. +We can now run our updated CorDapp, using the instructions :doc:`here `. + +Our CorDapp now requires agreement from both the lender and the borrower before an IOU can be created on the ledger. +This prevents either the lender or the borrower from unilaterally updating the ledger in a way that only benefits +themselves. diff --git a/docs/source/tut-two-party-index.rst b/docs/source/tut-two-party-index.rst new file mode 100644 index 0000000000..27c69f3e33 --- /dev/null +++ b/docs/source/tut-two-party-index.rst @@ -0,0 +1,10 @@ +Two-party flows +=============== + +.. toctree:: + :maxdepth: 1 + + tut-two-party-introduction + tut-two-party-contract + tut-two-party-flow + tut-two-party-running \ No newline at end of file diff --git a/docs/source/tut-two-party-introduction.rst b/docs/source/tut-two-party-introduction.rst new file mode 100644 index 0000000000..46f924583f --- /dev/null +++ b/docs/source/tut-two-party-introduction.rst @@ -0,0 +1,25 @@ +Introduction +============ + +.. note:: This tutorial extends the CorDapp built during the :doc:`Hello, World tutorial `. You can + find the final version of the CorDapp produced in that tutorial + `here `_. + +In the Hello, World tutorial, we built a CorDapp allowing us to model IOUs on ledger. Our CorDapp was made up of three +elements: + +* An ``IOUState``, representing IOUs on the ledger +* An ``IOUContract``, controlling the evolution of IOUs over time +* An ``IOUFlow``, orchestrating the process of agreeing the creation of an IOU on-ledger + +However, in our original CorDapp, only the IOU's borrower was required to sign transactions issuing IOUs. The lender +had no say in whether the issuance of the IOU was a valid ledger update or not. + +In this tutorial, we'll update our code so that the borrower requires the lender's agreement before they can issue an +IOU onto the ledger. We'll need to make two changes: + +* The ``IOUContract`` will need to be updated so that transactions involving an ``IOUState`` will require the lender's + signature (as well as the borrower's) to become valid ledger updates +* The ``IOUFlow`` will need to be updated to allow for the gathering of the lender's signature + +We'll start by updating the contract. \ No newline at end of file diff --git a/docs/source/tut-two-party-running.rst b/docs/source/tut-two-party-running.rst new file mode 100644 index 0000000000..a616769fda --- /dev/null +++ b/docs/source/tut-two-party-running.rst @@ -0,0 +1,28 @@ +Running our CorDapp +=================== + + + +Conclusion +---------- +We have written a simple CorDapp that allows IOUs to be issued onto the ledger. Like all CorDapps, our +CorDapp is made up of three key parts: + +* The ``IOUState``, representing IOUs on the ledger +* The ``IOUContract``, controlling the evolution of IOUs over time +* The ``IOUFlow``, orchestrating the process of agreeing the creation of an IOU on-ledger. + +Together, these three parts completely determine how IOUs are created and evolved on the ledger. + +Next steps +---------- +You should now be ready to develop your own CorDapps. There's +`a more fleshed-out version of the IOU CorDapp `_ +with an API and web front-end, and a set of example CorDapps in +`the main Corda repo `_, under ``samples``. An explanation of how to run these +samples :doc:`here `. + +As you write CorDapps, you can learn more about the API available :doc:`here `. + +If you get stuck at any point, please reach out on `Slack `_, +`Discourse `_, or `Stack Overflow `_. \ No newline at end of file diff --git a/docs/source/tutorial-attachments.rst b/docs/source/tutorial-attachments.rst index 8570e309ca..90dd58750c 100644 --- a/docs/source/tutorial-attachments.rst +++ b/docs/source/tutorial-attachments.rst @@ -39,10 +39,9 @@ a JVM client. Protocol -------- -Normally attachments on transactions are fetched automatically via the ``ResolveTransactionsFlow``. Attachments +Normally attachments on transactions are fetched automatically via the ``ReceiveTransactionFlow``. Attachments are needed in order to validate a transaction (they include, for example, the contract code), so must be fetched -before the validation process can run. ``ResolveTransactionsFlow`` calls ``FetchTransactionsFlow`` to perform the -actual retrieval. +before the validation process can run. .. note:: Future versions of Corda may support non-critical attachments that are not used for transaction verification and which are shared explicitly. These are useful for attaching and signing auditing data with a transaction @@ -104,7 +103,7 @@ transaction and send it to the recipient node: // Create a trivial transaction that just passes across the attachment - in normal cases there would be // inputs, outputs and commands that refer to this attachment. - val ptx = TransactionType.General.Builder(notary = null) + val ptx = TransactionBuilder(notary = null) require(rpc.attachmentExists(PROSPECTUS_HASH)) ptx.addAttachment(PROSPECTUS_HASH) // TODO: Add a dummy state and specify a notary, so that the tx hash is randomised each time and the demo can be repeated. diff --git a/docs/source/tutorial-building-transactions.rst b/docs/source/tutorial-building-transactions.rst index 0b6da351b2..18734ca89f 100644 --- a/docs/source/tutorial-building-transactions.rst +++ b/docs/source/tutorial-building-transactions.rst @@ -4,102 +4,102 @@ Building transactions Introduction ------------ -Understanding and implementing transactions in Corda is key to building -and implementing real world smart contracts. It is only through -construction of valid Corda transactions containing appropriate data -that nodes on the ledger can map real world business objects into a -shared digital view of the data in the Corda ledger. More importantly as -the developer of new smart contracts it is the code which determines -what data is well formed and what data should be rejected as mistakes, -or to prevent malicious activity. This document details some of the -considerations and APIs used to when constructing transactions as part -of a flow. +Understanding and implementing transactions in Corda is key to building +and implementing real world smart contracts. It is only through +construction of valid Corda transactions containing appropriate data +that nodes on the ledger can map real world business objects into a +shared digital view of the data in the Corda ledger. More importantly as +the developer of new smart contracts it is the code which determines +what data is well formed and what data should be rejected as mistakes, +or to prevent malicious activity. This document details some of the +considerations and APIs used to when constructing transactions as part +of a flow. The Basic Lifecycle Of Transactions ----------------------------------- -Transactions in Corda are constructed in stages and contain a number of -elements. In particular a transaction’s core data structure is the -``net.corda.core.transactions.WireTransaction``, which is usually -manipulated via a -``net.corda.core.contracts.General.TransactionBuilder`` and contains: +Transactions in Corda are constructed in stages and contain a number of +elements. In particular a transaction’s core data structure is the +``net.corda.core.transactions.WireTransaction``, which is usually +manipulated via a +``net.corda.core.transactions.TransactionBuilder`` and contains: -1. A set of Input state references that will be consumed by the final -accepted transaction. +1. A set of Input state references that will be consumed by the final +accepted transaction. -2. A set of Output states to create/replace the consumed states and thus -become the new latest versions of data on the ledger. +2. A set of Output states to create/replace the consumed states and thus +become the new latest versions of data on the ledger. -3. A set of ``Attachment`` items which can contain legal documents, contract -code, or private encrypted sections as an extension beyond the native -contract states. +3. A set of ``Attachment`` items which can contain legal documents, contract +code, or private encrypted sections as an extension beyond the native +contract states. -4. A set of ``Command`` items which give a context to the type of ledger -transition that is encoded in the transaction. Also each command has an -associated set of signer keys, which will be required to sign the -transaction. +4. A set of ``Command`` items which give a context to the type of ledger +transition that is encoded in the transaction. Also each command has an +associated set of signer keys, which will be required to sign the +transaction. -5. A signers list, which is populated by the ``TransactionBuilder`` to -be the union of the signers on the individual Command objects. +5. A signers list, which is populated by the ``TransactionBuilder`` to +be the union of the signers on the individual Command objects. -6. A notary identity to specify the Notary node which is tracking the -state consumption. (If the input states are registered with different -notary nodes the flow will have to insert additional ``NotaryChange`` -transactions to migrate the states across to a consistent notary node, -before being allowed to mutate any states.) +6. A notary identity to specify the Notary node which is tracking the +state consumption. (If the input states are registered with different +notary nodes the flow will have to insert additional ``NotaryChange`` +transactions to migrate the states across to a consistent notary node, +before being allowed to mutate any states.) -7. Optionally a timestamp that can used in the Notary to time bound the -period in which the proposed transaction stays valid. +7. Optionally a timestamp that can used in the Notary to time bound the +period in which the proposed transaction stays valid. -Typically, the ``WireTransaction`` should be regarded as a proposal and -may need to be exchanged back and forth between parties before it can be -fully populated. This is an immediate consequence of the Corda privacy -model, which means that the input states are likely to be unknown to the -other node. +Typically, the ``WireTransaction`` should be regarded as a proposal and +may need to be exchanged back and forth between parties before it can be +fully populated. This is an immediate consequence of the Corda privacy +model, which means that the input states are likely to be unknown to the +other node. -Once the proposed data is fully populated the flow code should freeze -the ``WireTransaction`` and form a ``SignedTransaction``. This is key to -the ledger agreement process, as once a flow has attached a node’s -signature it has stated that all details of the transaction are -acceptable to it. A flow should take care not to attach signatures to -intermediate data, which might be maliciously used to construct a -different ``SignedTransaction``. For instance in a foreign exchange -scenario we shouldn't send a ``SignedTransaction`` with only our sell -side populated as that could be used to take the money without the -expected return of the other currency. Also, it is best practice for -flows to receive back the ``DigitalSignature.WithKey`` of other parties -rather than a full ``SignedTransaction`` objects, because otherwise we -have to separately check that this is still the same -``SignedTransaction`` and not a malicious substitute. +Once the proposed data is fully populated the flow code should freeze +the ``WireTransaction`` and form a ``SignedTransaction``. This is key to +the ledger agreement process, as once a flow has attached a node’s +signature it has stated that all details of the transaction are +acceptable to it. A flow should take care not to attach signatures to +intermediate data, which might be maliciously used to construct a +different ``SignedTransaction``. For instance in a foreign exchange +scenario we shouldn't send a ``SignedTransaction`` with only our sell +side populated as that could be used to take the money without the +expected return of the other currency. Also, it is best practice for +flows to receive back the ``DigitalSignature.WithKey`` of other parties +rather than a full ``SignedTransaction`` objects, because otherwise we +have to separately check that this is still the same +``SignedTransaction`` and not a malicious substitute. -The final stage of committing the transaction to the ledger is to -notarise the ``SignedTransaction``, distribute this to all appropriate -parties and record the data into the ledger. These actions are best -delegated to the ``FinalityFlow``, rather than calling the individual -steps manually. However, do note that the final broadcast to the other -nodes is asynchronous, so care must be used in unit testing to -correctly await the Vault updates. +The final stage of committing the transaction to the ledger is to +notarise the ``SignedTransaction``, distribute this to all appropriate +parties and record the data into the ledger. These actions are best +delegated to the ``FinalityFlow``, rather than calling the individual +steps manually. However, do note that the final broadcast to the other +nodes is asynchronous, so care must be used in unit testing to +correctly await the Vault updates. Gathering Inputs ---------------- -One of the first steps to forming a transaction is gathering the set of -input references. This process will clearly vary according to the nature -of the business process being captured by the smart contract and the -parameterised details of the request. However, it will generally involve -searching the Vault via the ``VaultService`` interface on the +One of the first steps to forming a transaction is gathering the set of +input references. This process will clearly vary according to the nature +of the business process being captured by the smart contract and the +parameterised details of the request. However, it will generally involve +searching the Vault via the ``VaultQueryService`` interface on the ``ServiceHub`` to locate the input states. -To give a few more specific details consider two simplified real world -scenarios. First, a basic foreign exchange Cash transaction. This -transaction needs to locate a set of funds to exchange. A flow -modelling this is implemented in ``FxTransactionBuildTutorial.kt``. -Second, a simple business model in which parties manually accept, or -reject each other's trade proposals which is implemented in -``WorkflowTransactionBuildTutorial.kt``. To run and explore these -examples using the IntelliJ IDE one can run/step the respective unit -tests in ``FxTransactionBuildTutorialTest.kt`` and -``WorkflowTransactionBuildTutorialTest.kt``, which drive the flows as +To give a few more specific details consider two simplified real world +scenarios. First, a basic foreign exchange Cash transaction. This +transaction needs to locate a set of funds to exchange. A flow +modelling this is implemented in ``FxTransactionBuildTutorial.kt``. +Second, a simple business model in which parties manually accept, or +reject each other's trade proposals which is implemented in +``WorkflowTransactionBuildTutorial.kt``. To run and explore these +examples using the IntelliJ IDE one can run/step the respective unit +tests in ``FxTransactionBuildTutorialTest.kt`` and +``WorkflowTransactionBuildTutorialTest.kt``, which drive the flows as part of a simulated in-memory network of nodes. .. |nbsp| unicode:: 0xA0 @@ -116,112 +116,106 @@ standard ``CashState`` in the ``:financial`` Gradle module. The Cash contract uses ``FungibleAsset`` states to model holdings of interchangeable assets and allow the split/merge and summing of states to meet a contractual obligation. We would normally use the -``generateSpend`` method on the ``VaultService`` to gather the required +``Cash.generateSpend`` method to gather the required amount of cash into a ``TransactionBuilder``, set the outputs and move command. However, to elucidate more clearly example flow code is shown -here that will manually carry out the inputs queries using the lower -level ``VaultService``. +here that will manually carry out the inputs queries by specifying relevant +query criteria filters to the ``tryLockFungibleStatesForSpending`` method +of the ``VaultService``. .. literalinclude:: example-code/src/main/kotlin/net/corda/docs/FxTransactionBuildTutorial.kt :language: kotlin :start-after: DOCSTART 1 :end-before: DOCEND 1 -As a foreign exchange transaction we expect an exchange of two -currencies, so we will also require a set of input states from the other -counterparty. However, the Corda privacy model means we do not know the -other node’s states. Our flow must therefore negotiate with the other -node for them to carry out a similar query and populate the inputs (See -the ``ForeignExchangeFlow`` for more details of the exchange). Having -identified a set of Input ``StateRef`` items we can then create the -output as discussed below. +As a foreign exchange transaction we expect an exchange of two +currencies, so we will also require a set of input states from the other +counterparty. However, the Corda privacy model means we do not know the +other node’s states. Our flow must therefore negotiate with the other +node for them to carry out a similar query and populate the inputs (See +the ``ForeignExchangeFlow`` for more details of the exchange). Having +identified a set of Input ``StateRef`` items we can then create the +output as discussed below. -For the trade approval flow we need to implement a simple workflow -pattern. We start by recording the unconfirmed trade details in a state -object implementing the ``LinearState`` interface. One field of this -record is used to map the business workflow to an enumerated state. -Initially the initiator creates a new state object which receives a new -``UniqueIdentifier`` in its ``linearId`` property and a starting -workflow state of ``NEW``. The ``Contract.verify`` method is written to -allow the initiator to sign this initial transaction and send it to the -other party. This pattern ensures that a permanent copy is recorded on -both ledgers for audit purposes, but the state is prevented from being -maliciously put in an approved state. The subsequent workflow steps then -follow with transactions that consume the state as inputs on one side -and output a new version with whatever state updates, or amendments -match to the business process, the ``linearId`` being preserved across -the changes. Attached ``Command`` objects help the verify method -restrict changes to appropriate fields and signers at each step in the -workflow. In this it is typical to have both parties sign the change -transactions, but it can be valid to allow unilateral signing, if for instance -one side could block a rejection. Commonly the manual initiator of these -workflows will query the Vault for states of the right contract type and -in the right workflow state over the RPC interface. The RPC will then -initiate the relevant flow using ``StateRef``, or ``linearId`` values as +For the trade approval flow we need to implement a simple workflow +pattern. We start by recording the unconfirmed trade details in a state +object implementing the ``LinearState`` interface. One field of this +record is used to map the business workflow to an enumerated state. +Initially the initiator creates a new state object which receives a new +``UniqueIdentifier`` in its ``linearId`` property and a starting +workflow state of ``NEW``. The ``Contract.verify`` method is written to +allow the initiator to sign this initial transaction and send it to the +other party. This pattern ensures that a permanent copy is recorded on +both ledgers for audit purposes, but the state is prevented from being +maliciously put in an approved state. The subsequent workflow steps then +follow with transactions that consume the state as inputs on one side +and output a new version with whatever state updates, or amendments +match to the business process, the ``linearId`` being preserved across +the changes. Attached ``Command`` objects help the verify method +restrict changes to appropriate fields and signers at each step in the +workflow. In this it is typical to have both parties sign the change +transactions, but it can be valid to allow unilateral signing, if for instance +one side could block a rejection. Commonly the manual initiator of these +workflows will query the Vault for states of the right contract type and +in the right workflow state over the RPC interface. The RPC will then +initiate the relevant flow using ``StateRef``, or ``linearId`` values as parameters to the flow to identify the states being operated upon. Thus -code to gather the latest input state would be: +code to gather the latest input state for a given ``StateRef`` would use +the ``VaultQueryService`` as follows: .. literalinclude:: example-code/src/main/kotlin/net/corda/docs/WorkflowTransactionBuildTutorial.kt :language: kotlin :start-after: DOCSTART 1 :end-before: DOCEND 1 -.. container:: codeset - - .. sourcecode:: kotlin - - // Pull in the latest Vault version of the StateRef as a full StateAndRef - val latestRecord = serviceHub.latest(ref) - - Generating Commands ------------------- -For the commands that will be added to the transaction, these will need -to correctly reflect the task at hand. These must match because inside -the ``Contract.verify`` method the command will be used to select the -validation code path. The ``Contract.verify`` method will then restrict -the allowed contents of the transaction to reflect this context. Typical -restrictions might include that the input cash amount must equal the -output cash amount, or that a workflow step is only allowed to change -the status field. Sometimes, the command may capture some data too e.g. -the foreign exchange rate, or the identity of one party, or the StateRef -of the specific input that originates the command in a bulk operation. -This data will be used to further aid the ``Contract.verify``, because -to ensure consistent, secure and reproducible behaviour in a distributed -environment the ``Contract.verify``, transaction is the only allowed to -use the content of the transaction to decide validity. +For the commands that will be added to the transaction, these will need +to correctly reflect the task at hand. These must match because inside +the ``Contract.verify`` method the command will be used to select the +validation code path. The ``Contract.verify`` method will then restrict +the allowed contents of the transaction to reflect this context. Typical +restrictions might include that the input cash amount must equal the +output cash amount, or that a workflow step is only allowed to change +the status field. Sometimes, the command may capture some data too e.g. +the foreign exchange rate, or the identity of one party, or the StateRef +of the specific input that originates the command in a bulk operation. +This data will be used to further aid the ``Contract.verify``, because +to ensure consistent, secure and reproducible behaviour in a distributed +environment the ``Contract.verify``, transaction is the only allowed to +use the content of the transaction to decide validity. -Another essential requirement for commands is that the correct set of -``CompositeKeys`` are added to the Command on the builder, which will be -used to form the set of required signers on the final validated -transaction. These must correctly align with the expectations of the -``Contract.verify`` method, which should be written to defensively check -this. In particular, it is expected that at minimum the owner of an -asset would have to be signing to permission transfer of that asset. In -addition, other signatories will often be required e.g. an Oracle -identity for an Oracle command, or both parties when there is an -exchange of assets. +Another essential requirement for commands is that the correct set of +``CompositeKeys`` are added to the Command on the builder, which will be +used to form the set of required signers on the final validated +transaction. These must correctly align with the expectations of the +``Contract.verify`` method, which should be written to defensively check +this. In particular, it is expected that at minimum the owner of an +asset would have to be signing to permission transfer of that asset. In +addition, other signatories will often be required e.g. an Oracle +identity for an Oracle command, or both parties when there is an +exchange of assets. Generating Outputs ------------------ -Having located a set of ``StateAndRefs`` as the transaction inputs, the -flow has to generate the output states. Typically, this is a simple call -to the Kotlin ``copy`` method to modify the few fields that will -transitioned in the transaction. The contract code may provide a -``generateXXX`` method to help with this process if the task is more -complicated. With a workflow state a slightly modified copy state is -usually sufficient, especially as it is expected that we wish to preserve -the ``linearId`` between state revisions, so that Vault queries can find +Having located a set of ``StateAndRefs`` as the transaction inputs, the +flow has to generate the output states. Typically, this is a simple call +to the Kotlin ``copy`` method to modify the few fields that will +transitioned in the transaction. The contract code may provide a +``generateXXX`` method to help with this process if the task is more +complicated. With a workflow state a slightly modified copy state is +usually sufficient, especially as it is expected that we wish to preserve +the ``linearId`` between state revisions, so that Vault queries can find the latest revision. -For fungible contract states such as ``Cash`` it is common to distribute -and split the total amount e.g. to produce a remaining balance output -state for the original owner when breaking up a large amount input -state. Remember that the result of a successful transaction is always to -fully consume/spend the input states, so this is required to conserve -the total cash. For example from the demo code: +For fungible contract states such as ``Cash`` it is common to distribute +and split the total amount e.g. to produce a remaining balance output +state for the original owner when breaking up a large amount input +state. Remember that the result of a successful transaction is always to +fully consume/spend the input states, so this is required to conserve +the total cash. For example from the demo code: .. literalinclude:: example-code/src/main/kotlin/net/corda/docs/FxTransactionBuildTutorial.kt :language: kotlin @@ -231,15 +225,13 @@ the total cash. For example from the demo code: Building the WireTransaction ---------------------------- -Having gathered all the ingredients for the transaction we now need to -use a ``TransactionBuilder`` to construct the full ``WireTransaction``. -The initial ``TransactionBuilder`` should be created by calling the -``TransactionType.General.Builder`` method. (The other -``TransactionBuilder`` implementation is only used for the ``NotaryChange`` flow where -``ContractStates`` need moving to a different Notary.) At this point the -Notary to associate with the states should be recorded. Then we keep -adding inputs, outputs, commands and attachments to fill the -transaction. Examples of this process are: +Having gathered all the ingredients for the transaction we now need to +use a ``TransactionBuilder`` to construct the full ``WireTransaction``. +The initial ``TransactionBuilder`` should be created by calling the +``TransactionBuilder`` method. At this point the +Notary to associate with the states should be recorded. Then we keep +adding inputs, outputs, commands and attachments to fill the +transaction. Examples of this process are: .. literalinclude:: example-code/src/main/kotlin/net/corda/docs/WorkflowTransactionBuildTutorial.kt :language: kotlin @@ -254,51 +246,51 @@ transaction. Examples of this process are: Completing the SignedTransaction -------------------------------- -Having created an initial ``WireTransaction`` and converted this to an -initial ``SignedTransaction`` the process of verifying and forming a -full ``SignedTransaction`` begins and then completes with the -notarisation. In practice this is a relatively stereotypical process, -because assuming the ``WireTransaction`` is correctly constructed the -verification should be immediate. However, it is also important to -recheck the business details of any data received back from an external -node, because a malicious party could always modify the contents before -returning the transaction. Each remote flow should therefore check as -much as possible of the initial ``SignedTransaction`` inside the ``unwrap`` of -the receive before agreeing to sign. Any issues should immediately throw -an exception to abort the flow. Similarly the originator, should always -apply any new signatures to its original proposal to ensure the contents +Having created an initial ``WireTransaction`` and converted this to an +initial ``SignedTransaction`` the process of verifying and forming a +full ``SignedTransaction`` begins and then completes with the +notarisation. In practice this is a relatively stereotypical process, +because assuming the ``WireTransaction`` is correctly constructed the +verification should be immediate. However, it is also important to +recheck the business details of any data received back from an external +node, because a malicious party could always modify the contents before +returning the transaction. Each remote flow should therefore check as +much as possible of the initial ``SignedTransaction`` inside the ``unwrap`` of +the receive before agreeing to sign. Any issues should immediately throw +an exception to abort the flow. Similarly the originator, should always +apply any new signatures to its original proposal to ensure the contents of the transaction has not been altered by the remote parties. -The typical code therefore checks the received ``SignedTransaction`` -using the ``verifySignatures`` method, but excluding itself, the notary -and any other parties yet to apply their signature. The contents of the -``WireTransaction`` inside the ``SignedTransaction`` should be fully -verified further by expanding with ``toLedgerTransaction`` and calling -``verify``. Further context specific and business checks should then be -made, because the ``Contract.verify`` is not allowed to access external -context. For example the flow may need to check that the parties are the -right ones, or that the ``Command`` present on the transaction is as -expected for this specific flow. An example of this from the demo code is: +The typical code therefore checks the received ``SignedTransaction`` +using the ``verifySignaturesExcept`` method, excluding itself, the +notary and any other parties yet to apply their signature. The contents of the +``WireTransaction`` inside the ``SignedTransaction`` should be fully +verified further by expanding with ``toLedgerTransaction`` and calling +``verify``. Further context specific and business checks should then be +made, because the ``Contract.verify`` is not allowed to access external +context. For example the flow may need to check that the parties are the +right ones, or that the ``Command`` present on the transaction is as +expected for this specific flow. An example of this from the demo code is: .. literalinclude:: example-code/src/main/kotlin/net/corda/docs/WorkflowTransactionBuildTutorial.kt :language: kotlin :start-after: DOCSTART 3 :end-before: DOCEND 3 -After verification the remote flow will return its signature to the -originator. The originator should apply that signature to the starting +After verification the remote flow will return its signature to the +originator. The originator should apply that signature to the starting ``SignedTransaction`` and recheck the signatures match. Committing the Transaction -------------------------- -Once all the party signatures are applied to the SignedTransaction the -final step is notarisation. This involves calling ``NotaryFlow.Client`` -to confirm the transaction, consume the inputs and return its confirming -signature. Then the flow should ensure that all nodes end with all -signatures and that they call ``ServiceHub.recordTransactions``. The -code for this is standardised in the ``FinalityFlow``, or more explicitly -an example is: +Once all the party signatures are applied to the SignedTransaction the +final step is notarisation. This involves calling ``NotaryFlow.Client`` +to confirm the transaction, consume the inputs and return its confirming +signature. Then the flow should ensure that all nodes end with all +signatures and that they call ``ServiceHub.recordTransactions``. The +code for this is standardised in the ``FinalityFlow``, or more explicitly +an example is: .. literalinclude:: example-code/src/main/kotlin/net/corda/docs/WorkflowTransactionBuildTutorial.kt :language: kotlin @@ -308,19 +300,19 @@ an example is: Partially Visible Transactions ------------------------------ -The discussion so far has assumed that the parties need full visibility -of the transaction to sign. However, there may be situations where each -party needs to store private data for audit purposes, or for evidence to -a regulator, but does not wish to share that with the other trading -partner. The tear-off/Merkle tree support in Corda allows flows to send -portions of the full transaction to restrict visibility to remote -parties. To do this one can use the -``WireTransaction.buildFilteredTransaction`` extension method to produce -a ``FilteredTransaction``. The elements of the ``SignedTransaction`` -which we wish to be hide will be replaced with their secure hash. The -overall transaction txid is still provable from the -``FilteredTransaction`` preventing change of the private data, but we do -not expose that data to the other node directly. A full example of this -can be found in the ``NodeInterestRates`` Oracle code from the -``irs-demo`` project which interacts with the ``RatesFixFlow`` flow. -Also, refer to the :doc:`merkle-trees` documentation. +The discussion so far has assumed that the parties need full visibility +of the transaction to sign. However, there may be situations where each +party needs to store private data for audit purposes, or for evidence to +a regulator, but does not wish to share that with the other trading +partner. The tear-off/Merkle tree support in Corda allows flows to send +portions of the full transaction to restrict visibility to remote +parties. To do this one can use the +``WireTransaction.buildFilteredTransaction`` extension method to produce +a ``FilteredTransaction``. The elements of the ``SignedTransaction`` +which we wish to be hide will be replaced with their secure hash. The +overall transaction txid is still provable from the +``FilteredTransaction`` preventing change of the private data, but we do +not expose that data to the other node directly. A full example of this +can be found in the ``NodeInterestRates`` Oracle code from the +``irs-demo`` project which interacts with the ``RatesFixFlow`` flow. +Also, refer to the :doc:`merkle-trees` documentation. diff --git a/docs/source/tutorial-contract-clauses.rst b/docs/source/tutorial-contract-clauses.rst index cb6ea53f46..c8b3ca98cd 100644 --- a/docs/source/tutorial-contract-clauses.rst +++ b/docs/source/tutorial-contract-clauses.rst @@ -68,7 +68,7 @@ We start by defining the ``CommercialPaper`` class. As in the previous tutorial, class CommercialPaper : Contract { override val legalContractReference: SecureHash = SecureHash.sha256("https://en.wikipedia.org/wiki/Commercial_paper") - override fun verify(tx: TransactionForContract) = verifyClause(tx, Clauses.Group(), tx.commands.select()) + override fun verify(tx: LedgerTransaction) = verifyClause(tx, Clauses.Group(), tx.commands.select()) interface Commands : CommandData { data class Move(override val contractHash: SecureHash? = null) : FungibleAsset.Commands.Move, Commands @@ -85,7 +85,7 @@ We start by defining the ``CommercialPaper`` class. As in the previous tutorial, } @Override - public void verify(@NotNull TransactionForContract tx) throws IllegalArgumentException { + public void verify(@NotNull LedgerTransaction tx) throws IllegalArgumentException { ClauseVerifier.verifyClause(tx, new Clauses.Group(), extractCommands(tx)); } @@ -128,7 +128,7 @@ and is included in the ``CommercialPaper.kt`` code. override val requiredCommands: Set> get() = setOf(Commands.Move::class.java) - override fun verify(tx: TransactionForContract, + override fun verify(tx: LedgerTransaction, inputs: List, outputs: List, commands: List>, @@ -158,7 +158,7 @@ and is included in the ``CommercialPaper.kt`` code. @NotNull @Override - public Set verify(@NotNull TransactionForContract tx, + public Set verify(@NotNull LedgerTransaction tx, @NotNull List inputs, @NotNull List outputs, @NotNull List> commands, @@ -229,7 +229,7 @@ its subclauses (wrapped move, issue, redeem). "Any" in this case means that it w Redeem(), Move(), Issue())) { - override fun groupStates(tx: TransactionForContract): List>> + override fun groupStates(tx: LedgerTransaction): List>> = tx.groupStates> { it.token } } @@ -246,7 +246,7 @@ its subclauses (wrapped move, issue, redeem). "Any" in this case means that it w @NotNull @Override - public List> groupStates(@NotNull TransactionForContract tx) { + public List> groupStates(@NotNull LedgerTransaction tx) { return tx.groupStates(State.class, State::withoutOwner); } } diff --git a/docs/source/tutorial-contract.rst b/docs/source/tutorial-contract.rst index b41f65aa25..cdee7aa768 100644 --- a/docs/source/tutorial-contract.rst +++ b/docs/source/tutorial-contract.rst @@ -61,7 +61,7 @@ Kotlin syntax works. class CommercialPaper : Contract { override val legalContractReference: SecureHash = SecureHash.sha256("https://en.wikipedia.org/wiki/Commercial_paper"); - override fun verify(tx: TransactionForContract) { + override fun verify(tx: LedgerTransaction) { TODO() } } @@ -75,7 +75,7 @@ Kotlin syntax works. } @Override - public void verify(TransactionForContract tx) { + public void verify(LedgerTransaction tx) { throw new UnsupportedOperationException(); } } @@ -114,7 +114,7 @@ A state is a class that stores data that is checked by the contract. A commercia override val participants = listOf(owner) fun withoutOwner() = copy(owner = AnonymousParty(NullPublicKey)) - override fun withNewOwner(newOwner: PublicKey) = Pair(Commands.Move(), copy(owner = newOwner)) + override fun withNewOwner(newOwner: AbstractParty) = Pair(Commands.Move(), copy(owner = newOwner)) } .. sourcecode:: java @@ -298,7 +298,7 @@ run two contracts one time each: Cash and CommercialPaper. .. sourcecode:: kotlin - override fun verify(tx: TransactionForContract) { + override fun verify(tx: LedgerTransaction) { // Group by everything except owner: any modification to the CP at all is considered changing it fundamentally. val groups = tx.groupStates(State::withoutOwner) @@ -309,7 +309,7 @@ run two contracts one time each: Cash and CommercialPaper. .. sourcecode:: java @Override - public void verify(TransactionForContract tx) { + public void verify(LedgerTransaction tx) { List> groups = tx.groupStates(State.class, State::withoutOwner); AuthenticatedObject cmd = requireSingleCommand(tx.getCommands(), Commands.class); @@ -356,7 +356,7 @@ inputs e.g. because she received the dollars in two payments. The input and outp the cash smart contract must consider the pounds and the dollars separately because they are not fungible: they cannot be merged together. So we have two groups: A and B. -The ``TransactionForContract.groupStates`` method handles this logic for us: firstly, it selects only states of the +The ``LedgerTransaction.groupStates`` method handles this logic for us: firstly, it selects only states of the given type (as the transaction may include other types of state, such as states representing bond ownership, or a multi-sig state) and then it takes a function that maps a state to a grouping key. All states that share the same key are grouped together. In the case of the cash example above, the grouping key would be the currency. @@ -448,7 +448,7 @@ logic. is Commands.Redeem -> { // Redemption of the paper requires movement of on-ledger cash. val input = inputs.single() - val received = tx.outputs.sumCashBy(input.owner) + val received = tx.outputs.map{ it.data }.sumCashBy(input.owner) val time = timeWindow?.fromTime ?: throw IllegalArgumentException("Redemptions must be timestamped") requireThat { "the paper must have matured" using (time >= input.maturityDate) @@ -680,9 +680,9 @@ Finally, we can do redemption. .. sourcecode:: kotlin @Throws(InsufficientBalanceException::class) - fun generateRedeem(tx: TransactionBuilder, paper: StateAndRef, vault: VaultService) { + fun generateRedeem(tx: TransactionBuilder, paper: StateAndRef, services: ServiceHub) { // Add the cash movement using the states in our vault. - vault.generateSpend(tx, paper.state.data.faceValue.withoutIssuer(), paper.state.data.owner) + Cash.generateSpend(services, tx, paper.state.data.faceValue.withoutIssuer(), paper.state.data.owner) tx.addInputState(paper) tx.addCommand(Command(Commands.Redeem(), paper.state.data.owner.owningKey)) } @@ -698,7 +698,7 @@ from the issuer of the commercial paper to the current owner. If we don't have e an exception is thrown. Then we add the paper itself as an input, but, not an output (as we wish to remove it from the ledger). Finally, we add a Redeem command that should be signed by the owner of the commercial paper. -.. warning:: The amount we pass to the ``generateSpend`` function has to be treated first with ``withoutIssuer``. +.. warning:: The amount we pass to the ``Cash.generateSpend`` function has to be treated first with ``withoutIssuer``. This reflects the fact that the way we handle issuer constraints is still evolving; the commercial paper contract requires payment in the form of a currency issued by a specific party (e.g. the central bank, or the issuers own bank perhaps). But the vault wants to assemble spend transactions using cash states from @@ -707,7 +707,7 @@ from the ledger). Finally, we add a Redeem command that should be signed by the A ``TransactionBuilder`` is not by itself ready to be used anywhere, so first, we must convert it to something that is recognised by the network. The most important next step is for the participating entities to sign it. Typically, -an initiating flow will create an initial partially signed ``SignedTransaction`` by calling the ``serviceHub.signInitialTransaction`` method. +an initiating flow will create an initial partially signed ``SignedTransaction`` by calling the ``serviceHub.toSignedTransaction`` method. Then the frozen ``SignedTransaction`` can be passed to other nodes by the flow, these can sign using ``serviceHub.createSignature`` and distribute. The ``CollectSignaturesFlow`` provides a generic implementation of this process that can be used as a ``subFlow`` . @@ -791,7 +791,7 @@ The time-lock contract mentioned above can be implemented very simply: class TestTimeLock : Contract { ... - override fun verify(tx: TransactionForContract) { + override fun verify(tx: LedgerTransaction) { val time = tx.timestamp.before ?: throw IllegalStateException(...) ... requireThat { diff --git a/docs/source/tutorial-cordapp.rst b/docs/source/tutorial-cordapp.rst index c17c81b874..7ab46f30c1 100644 --- a/docs/source/tutorial-cordapp.rst +++ b/docs/source/tutorial-cordapp.rst @@ -7,832 +7,549 @@ The example CorDapp =================== -This guide covers how to get started with the example CorDapp. Please note there are several Corda repositories: +.. contents:: -* `corda `_ which contains the core platform code and sample CorDapps. -* `cordapp-tutorial `_ which contains an example CorDapp you can use to bootstrap your own CorDapps. It is the subject of this tutorial and should help you understand the basics. -* `cordapp-template `_ which contains a bare-bones template designed for starting new projects by copying. +The example CorDapp allows nodes to agree IOUs with each other. Nodes will always agree to the creation of a new IOU +unless: -We recommend you read the non-technical white paper and technical white paper before you get started with Corda: +* Its value is less than 1, or greater than 99 +* A node tries to issue an IOU to itself -1. `The Introductory white paper `_ describes the - motivating vision and background of the project. It is the kind of document your boss should read. It describes why the - project exists and briefly compares it to alternative systems on the market. -2. `The Technical white paper `_ describes the entire - intended design from beginning to end. It is the kind of document that you should read, or at least, read parts of. Note - that because the technical white paper describes the intended end state, it does not always align with the implementation. - -Background ----------- - -The example CorDapp implements a basic scenario where one party wishes to send an IOU to another party. The scenario -defines four nodes: +By default, the CorDapp is deployed on 4 test nodes: * **Controller**, which hosts the network map service and validating notary service * **NodeA** * **NodeB** * **NodeC** -The nodes can generate IOUs and send them to other nodes. The flows used to facilitate the agreement process always results in -an agreement with the recipient as long as the IOU meets the contract constraints which are defined in ``IOUContract.kt``. +Because data is only propagated on a need-to-know basis, any IOUs agreed between NodeA and NodeB become "shared facts" +between NodeA and NodeB only. NodeC won't be aware of these IOUs. -All agreed IOUs between NodeA and NodeB become "shared facts" between NodeA and NodeB. But note that NodeC -won't see any of these transactions or have copies of any of the resulting ``IOUState`` objects. This is -because data is only propagated on a need-to-know basis. - -Getting started ---------------- - -There are two ways to get started with the example CorDapp. You can either work from a milestone release of Corda or a -SNAPSHOT release of Corda. - -**Using a monthly Corda milestone release.** If you wish to develop your CorDapp using the most recent milestone release -then you can get started simply by cloning the ``cordapp-tutorial`` repository. Gradle will grab all the required dependencies -for you from our `public Maven repository `_. - -**Using a Corda SNAPSHOT build.** Alternatively, if you wish to work from the master branch of the Corda repo which contains -the most up-to-date Corda feature set then you will need to clone the ``corda`` repository and publish the latest master -build (or previously tagged releases) to your local Maven repository. You will then need to ensure that Gradle -grabs the correct dependencies for you from Maven local by changing the ``corda_release_version`` in ``build.gradle``. -This will be covered below in `Using a SNAPSHOT release`_. - -Firstly, follow the :doc:`getting set up ` page to download the JDK, IntelliJ and git if you didn't -already have it. - -Working from milestone releases -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -If you wish to build a CorDapp against a milestone release then please use these instructions. - -The process for developing your CorDapp from a milestone release is the most simple way to get started and is the preferred -approach. - -We publish all our milestone releases to a public Maven repository on a monthly basis. As such, Gradle will automatically -grab the appropriately versioned (specified in the ``cordapp-tutorial``'s ``build.gradle`` file) dependencies for you from Maven. -All you have to do is check out the release tag of the tutorial version you wish to use. - -By default, the ``master`` branch of the ``cordapp-tutorial`` points to a SNAPSHOT release of Corda, this is because it is -being constantly updated to reflect the changes in the master branch of the `corda` repository. - -.. note:: If you wish to use a SNAPSHOT release then follow the instructions below: `Using a SNAPSHOT release`_. - -To clone the ``cordapp-tutorial`` repository, use the following command: +Downloading the example CorDapp +------------------------------- +If you haven't already, set up your machine by following the :doc:`quickstart guide `. Then clone the +example CorDapp from the `cordapp-tutorial repository `_ using the following +command: ``git clone https://github.com/corda/cordapp-tutorial`` -Now change directories to the freshly cloned repo: +And change directories to the freshly cloned repo: ``cd cordapp-tutorial`` -To enumerate all the tagged releases. Use: +We want to work off the latest Milestone release. To enumerate all the Milestone releases, run: ``git tag`` -To checkout a specific tag, use: +And check out the latest (highest-numbered) Milestone release using: -``git checkout -b [local_branch_name] tags/[tag_name]`` +``git checkout [tag_name]`` -where ``local_branch_name`` is a name of your choice and ``tag_name`` is the name of the tag you wish to checkout. +Where ``tag_name`` is the name of the tag you wish to checkout. Gradle will grab all the required dependencies for you +from our `public Maven repository `_. -Gradle will handle all the dependencies for you. Now you are now ready to get started `building the example CorDapp`_. +.. note:: If you wish to build off the latest, unstable version of the codebase, follow the instructions in + `Using a SNAPSHOT release`_. -Using a SNAPSHOT release -~~~~~~~~~~~~~~~~~~~~~~~~ - -If you wish to build a CorDapp against the most current version of Corda, follow these instructions. - -Firstly navigate to the folder on your machine you wish to clone the Corda repository to. Then use the following command -to clone the Corda repository: - -``git clone https://github.com/corda/corda.git`` - -Now change directories: - -``cd corda`` - -Once you've cloned the ``corda`` repository and are in the repo directory you have the option to remain on the master -branch or checkout a specific branch. Use: - -``git branch --all`` - -to enumerate all the branches. To checkout a specific branch, use: - -``git checkout -b [local_branch_name] origin/[remote_branch_name]`` - -where ``local_branch_name`` is a name of your choice and ``remote_branch_name`` is the name of the remote branch you wish -to checkout. - -.. note:: When working with ``master`` you will have access to the most up-to-date feature set. However you will be - potentially sacrificing stability. We will endeavour to keep the ``master`` branch of the ``cordapp-tutorial`` repo in sync - with the ``master`` branch of ``corda`` repo. A milestone tagged release would be more stable for CorDapp development. - -The next step is to publish the Corda JARs to your local Maven repository. By default the Maven local repository can be -found: - -* ``~/.m2/repository`` on Unix/Mac OS X -* ``%HOMEPATH%\.m2`` on windows. - -Publishing can be done with running the following Gradle task from the root project directory: - -Unix/Mac OSX: ``./gradlew install`` - -Windows: ``gradlew.bat install`` - -This will install all required modules, along with sources and JavaDocs to your local Maven repository. The ``version`` -and ``groupid`` of Corda installed to Maven local is specified in the ``build.gradle`` file in the root of the ``corda`` -repository. You shouldn't have to change these values unless you want to publish multiple versions of a SNAPSHOT, e.g. -if you are trying out new features, in this case you can change ``version`` for each SNAPSHOT you publish. - -.. note:: **A quick point on corda version numbers used by Gradle.** - - In the ``build.gradle`` file for your CorDapp, you can specify the ``corda_release_version`` to use. It is important - that when developing your CorDapp that you use the correct version number. For example, when wanting to work from a SNAPSHOT - release, the release numbers are suffixed with 'SNAPSHOT', e.g. if the latest milestone release is M6 then the - SNAPSHOT release will be 0.7-SNAPSHOT, and so on. As such, you will set your ``corda_release_version`` to ``'0.7-SNAPSHOT'`` - in the ``build.gradle`` file in your CorDapp. Gradle will automatically grab the SNAPSHOT dependencies from your local - Maven repository. Alternatively, if working from a milestone release, you will use the version number only, for example - ``0.6`` or ``0.7``. - - Lastly, as the Corda repository evolves on a daily basis up until the next milestone release, it is worth nothing that - the substance of two SNAPSHOT releases of the same number may be different. If you are using a SNAPSHOT and need help - debugging an error then please tell us the **commit** you are working from. This will help us ascertain the issue. - -As additional feature branches are merged into Corda you can ``git pull`` the new changes from the ``corda`` repository. -If you are feeling inquisitive, you may also wish to review some of the current feature branches. All new features are -developed on separate branches. To enumerate all the current branches use: - -``git branch --all`` - -and to check out an open feature branch, use: - -``git checkout -b [local_branch_name] origin/[branch_name]`` - -.. note:: Publishing Corda JARs from unmerged feature branches might cause some unexpected behaviour / broken CorDapps. - It would also replace any previously published SNAPSHOTS of the same version. - -.. warning:: If you do modify Corda after you have previously published it to Maven local then you must republish your - SNAPSHOT build such that Maven reflects the changes you have made. - -Once you have published the Corda JARs to your local Maven repository, you are ready to get started building your -CorDapp using the latest Corda features. - -Opening the example CorDapp with IntelliJ -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -For those familiar with IntelliJ, you can skip this section. - -As noted in the getting started guide, we recommend using the IntelliJ IDE. Assuming you have already downloaded and -installed IntelliJ, lets now open the example CorDapp with IntelliJ. +Opening the example CorDapp in IntelliJ +--------------------------------------- +Let's open the example CorDapp in the IntelliJ IDE. **For those completely new to IntelliJ** -Firstly, load up IntelliJ. A dialogue will appear: +Upon opening IntelliJ, a dialogue will appear: .. image:: resources/intellij-welcome.png :width: 400 Click open, then navigate to the folder where you cloned the ``cordapp-tutorial`` and click OK. -Next, IntelliJ will show a bunch of pop-up windows. One of which requires our attention: +Next, IntelliJ will show several pop-up windows, one of which requires our attention: .. image:: resources/unlinked-gradle-project.png :width: 400 -Click the 'import gradle project' link. A dialogue will pop-up. Press OK. Gradle will now obtain all the -project dependencies and perform some indexing. It usually takes a minute or so. If you miss the 'import gradle project' -dialogue, simply close and re-open IntelliJ again to see it again. +Click the 'import gradle project' link. A dialogue will pop-up. Press OK. Gradle will now download all the +project dependencies and perform some indexing. This usually takes a minute or so. -**Alternative approach** - -Alternatively, one can instruct IntelliJ to create a new project through cloning a repository. From the IntelliJ welcome -dialogue (shown above), opt to 'check out from version control', then select git and enter the git URL for the example CorDapp -(https://github.com/corda/cordapp-tutorial). You'll then need to import the Gradle project when prompted, as explained above. +If the 'import gradle project' pop-up does not appear, click the small green speech bubble at the bottom-right of +the IDE, or simply close and re-open IntelliJ again to make it reappear. **If you already have IntelliJ open** -From the ``File`` menu, navigate to ``Open ...`` and then navigate to the directory where you cloned the ``cordapp-tutorial``. -Alternatively, if you wish to clone from github directly then navigate to ``File > New > Project from existing sources ...`` -and enter the URL to the example CorDapp (specified above). When instructed, be sure to import the Gradle project when prompted. +From the ``File`` menu, navigate to ``Open ...`` and then navigate to the directory where you cloned the +``cordapp-tutorial`` and click OK. -**The Gradle plugin** - -IntelliJ can be used to run Gradle tasks through the Gradle plugin which can be found via ``View > Tool windows > Gradle``. -All the Gradle projects are listed in the window on the right hand side of the IDE. Click on a project, then 'tasks' to -see all available Gradle tasks. - -* For the example CorDapp repo there will only be one Gradle project listed. -* For the Corda repo there will be many project listed, the root project ``corda`` and associated sub-projects: ``core``, - ``finance``, ``node``, etc. - -.. note:: It's worth noting that when you change branch in the example CorDapp, the ``corda_release_version`` will change to - reflect the version of the branch you are working from. - -To execute a task, double click it. The output will be shown in a console window. - -Building the example CorDapp ----------------------------- - -**From the command line** - -Firstly, return to your terminal window used above and make sure you are in the ``cordapp-tutorial`` directory. - -To build the example CorDapp use the following command: - -Unix/Mac OSX: ``./gradlew deployNodes`` - -Windows: ``gradlew.bat deployNodes`` - -This build process will build the example CorDapp defined in the example CorDapp source. CorDapps can be written in -any language targeting the JVM. In our case, we've provided the example source in both Kotlin (``/kotlin-source/src``) and -Java (``/java-source/src``) Since both sets of source files are functionally identical, we will refer to the Kotlin build -throughout the documentation. - -For more information on the example CorDapp see "The Example CorDapp" section below. Gradle will then grab all the -dependencies for you and build the example CorDapp. - -The ``deployNodes`` Gradle task allows you easily create a formation of Corda nodes. In the case of the example CorDapp -we are creating ``four`` nodes. - -After the building process has finished to see the newly built nodes, you can navigate to the ``kotlin-source/build/nodes`` folder -located in the ``cordapp-tutorial`` root directory. You can ignore the other folders in ``/build`` for now. The ``nodes`` -folder has the following structure: +Project structure +----------------- +The example CorDapp has the following directory structure: .. sourcecode:: none - . nodes - ├── controller - │   ├── corda.jar - │   ├── node.conf - │   └── plugins - ├── nodea - │   ├── corda.jar - │   ├── node.conf - │   └── plugins - ├── nodeb - │   ├── corda.jar - │   ├── node.conf - │   └── plugins - ├── nodec - │   ├── corda.jar - │   ├── node.conf - │   └── plugins - ├── runnodes - └── runnodes.bat + . + ├── LICENCE + ├── README.md + ├── TRADEMARK + ├── build.gradle + ├── config + │   ├── dev + │   │   └── log4j2.xml + │   └── test + │   └── log4j2.xml + ├── doc + │   └── example_flow.plantuml + ├── gradle + │   └── wrapper + │   ├── gradle-wrapper.jar + │   └── gradle-wrapper.properties + ├── gradle.properties + ├── gradlew + ├── gradlew.bat + ├── java-source + │   └── ... + ├── kotlin-source + │   ├── build.gradle + │   └── src + │   ├── main + │   │   ├── kotlin + │   │   │   └── com + │   │   │   └── example + │   │   │   ├── api + │   │   │   │   └── ExampleApi.kt + │   │   │   ├── client + │   │   │   │   └── ExampleClientRPC.kt + │   │   │   ├── contract + │   │   │   │   └── IOUContract.kt + │   │   │   ├── flow + │   │   │   │   └── ExampleFlow.kt + │   │   │   ├── model + │   │   │   │   └── IOU.kt + │   │   │   ├── plugin + │   │   │   │   └── ExamplePlugin.kt + │   │   │   ├── schema + │   │   │   │   └── IOUSchema.kt + │   │   │   └── state + │   │   │   └── IOUState.kt + │   │   └── resources + │   │   ├── META-INF + │   │   │   └── services + │   │   │   └── net.corda.webserver.services.WebServerPluginRegistry + │   │   ├── certificates + │   │   │   ├── readme.txt + │   │   │   ├── sslkeystore.jks + │   │   │   └── truststore.jks + │   │   └── exampleWeb + │   │   ├── index.html + │   │   └── js + │   │   └── angular-module.js + │   └── test + │   └── kotlin + │   └── com + │   └── example + │   ├── Main.kt + │   ├── contract + │   │   └── IOUContractTests.kt + │   └── flow + │   └── IOUFlowTests.kt + ├── lib + │   ├── README.txt + │   └── quasar.jar + └── settings.gradle -There will be one folder generated for each node you build (more on later when we get into the detail of the -``deployNodes`` Gradle task) and a ``runnodes`` shell script (batch file on Windows). +The most important files and directories to note are: -Each node folder contains the Corda JAR and a folder for plugins (or CorDapps). There is also -a node.conf file. See :doc:`Corda configuration files `. - -**Building from IntelliJ** - -Open the Gradle window by selecting ``View > Tool windows > Gradle`` from the main menu. You will see the Gradle window -open on the right hand side of the IDE. Expand `tasks` and then expand `other`. Double click on `deployNodes`. Gradle will -start the build process and output progress to a console window in the IDE. +* The **root directory** contains some gradle files, a README and a LICENSE +* **config** contains log4j configs +* **gradle** contains the gradle wrapper, which allows the use of Gradle without installing it yourself and worrying + about which version is required +* **lib** contains the Quasar jar which is required for runtime instrumentation of classes by Quasar +* **kotlin-source** contains the source code for the example CorDapp written in Kotlin + * **kotlin-source/src/main/kotlin** contains the source code for the example CorDapp + * **kotlin-source/src/main/python** contains a python script which accesses nodes via RPC + * **kotlin-source/src/main/resources** contains the certificate store, some static web content to be served by the + nodes and the WebServerPluginRegistry file + * **kotlin-source/src/test/kotlin** contains unit tests for the contracts and flows, and the driver to run the nodes + via IntelliJ +* **java-source** contains the same source code, but written in java. This is an aid for users who do not want to + develop in Kotlin, and serves as an example of how CorDapps can be developed in any language targeting the JVM Running the example CorDapp --------------------------- +There are two ways to run the example CorDapp: -To run the sample CorDapp navigate to the ``kotlin-source/build/nodes`` folder and execute the ``runnodes`` shell script with: +* Via the terminal +* Via IntelliJ -Unix: ``./runnodes`` or ``sh runnodes`` +We explain both below. -Windows: ``runnodes.bat`` +Terminal: Building the example CorDapp +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +Return to your terminal window and make sure you are in the ``cordapp-tutorial`` directory. To build the example +CorDapp use the following command: -The ``runnodes`` scripts should create a terminal tab for each node. In each terminal tab, you'll see the Corda welcome -message and some pertinent config information, see below: +* Unix/Mac OSX: ``./gradlew deployNodes`` +* Windows: ``gradlew.bat deployNodes`` + +This will package up our CorDapp source files into a plugin and automatically build four pre-configured nodes that have +our CorDapp plugin installed. These nodes are meant for local testing only. + +After the build process has finished, you will see the newly-build nodes in the ``kotlin-source/build/nodes``. There +will be one folder generated for each node you built, plus a ``runnodes`` shell script (or batch file on Windows). + +.. note:: CorDapps can be written in any language targeting the JVM. In our case, we've provided the example source in + both Kotlin (``/kotlin-source/src``) and Java (``/java-source/src``) Since both sets of source files are + functionally identical, we will refer to the Kotlin build throughout the documentation. + +Each node in the ``nodes`` folder has the following structure: .. sourcecode:: none - ______ __ - / ____/ _________/ /___ _ - / / __ / ___/ __ / __ `/ Computer science and finance together. - / /___ /_/ / / / /_/ / /_/ / You should see our crazy Christmas parties! - \____/ /_/ \__,_/\__,_/ + . nodeName + ├── corda.jar + ├── node.conf + └── plugins - --- DEVELOPER SNAPSHOT ------------------------------------------------------------ +``corda.jar` is the Corda runtime, ``plugins`` contains our node's CorDapps, and our node's configuration is provided +in ``node.conf``. - Logs can be found in : /Users/rogerwillis/Documents/Corda/cordapp-tutorial/kotlin-source/build/nodes/nodea/logs - Database connection URL is : jdbc:h2:tcp://10.18.0.196:50661/node - Node listening on address : localhost:10004 - Loaded plugins : com.example.plugin.ExamplePlugin - Embedded web server is listening on : http://10.18.0.196:10005/ - Node started up and registered in 39.0 sec +Terminal: Running the example CorDapp +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +To run our nodes, run the following command from the root of the ``cordapp-tutorial`` folder: -You'll need to refer to the above later on for the JDBC connection string and port numbers. +* Unix/Mac OSX: ``kotlin-source/build/nodes/runnodes`` +* Windows: ``call kotlin-source\build\nodes\runnodes.bat`` -Depending on the speed of your machine, it usually takes around 30 seconds for the nodes to finish starting up. If you -want to double check all the nodes are running you can query the 'status' end-point located at -``http://host:post/api/status``. +On Unix/Mac OSX, do not click/change focus until all eight additional terminal windows have opened, or some nodes may +fail to start. -When booted up, the node will generate a bunch of files and directories in addition to the ones covered above: +The ``runnodes`` script creates a terminal tab/window for each node: .. sourcecode:: none - . - ├── artemis - ├── attachments - ├── cache - ├── certificates - ├── corda.jar - ├── identity-private-key - ├── identity-public - ├── logs - ├── node.conf - ├── persistence.mv.db - └── plugins + ______ __ + / ____/ _________/ /___ _ + / / __ / ___/ __ / __ `/ It's kind of like a block chain but + / /___ /_/ / / / /_/ / /_/ / cords sounded healthier than chains. + \____/ /_/ \__,_/\__,_/ -Notably: + --- Corda Open Source 0.12.1 (da47f1c) ----------------------------------------------- -* **artemis** contains the internal files for Artemis MQ, our message broker. -* **attachments** contains any persisted attachments. -* **certificates** contains the certificate store. -* **identity-private-key** is the node's private key. -* **identity-public** is the node's public key. -* **logs** contains the node's log files. -* **persistence.mv.db** is the h2 database where transactions and other data is persisted. + 📚 New! Training now available worldwide, see https://corda.net/corda-training/ -Additional files and folders are added as the node is running. + Logs can be found in : /Users/joeldudley/Desktop/cordapp-tutorial/kotlin-source/build/nodes/NodeA/logs + Database connection url is : jdbc:h2:tcp://10.163.199.132:54763/node + Listening on address : 127.0.0.1:10005 + RPC service listening on address : localhost:10006 + Loaded plugins : com.example.plugin.ExamplePlugin + Node for "NodeA" started up and registered in 35.0 sec -Running the example CorDapp via IntelliJ -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -To run the example CorDapp via IntelliJ you can use the ``Run Example CorDapp`` run configuration. Select it from the drop -down menu at the top right-hand side of the IDE and press the green arrow to start the nodes. See image below: + Welcome to the Corda interactive shell. + Useful commands include 'help' to see what is available, and 'bye' to shut down the node. + + Fri Jul 07 10:33:47 BST 2017>>> + +The script will also create a webserver terminal tab for each node: + +.. sourcecode:: none + + Logs can be found in /Users/joeldudley/Desktop/cordapp-tutorial/kotlin-source/build/nodes/NodeA/logs/web + Starting as webserver: localhost:10007 + Webserver started up in 42.02 sec + +Depending on your machine, it usually takes around 60 seconds for the nodes to finish starting up. If you want to +ensure that all the nodes are running OK, you can query the 'status' end-point located at +``http://localhost:[port]/api/status`` (e.g. ``http://localhost:10007/api/status`` for ``NodeA``). + +IntelliJ: Building and running the example CorDapp +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +To run the example CorDapp via IntelliJ you can use the ``Run Example CorDapp - Kotlin`` run configuration. Select it +from the drop-down menu at the top right-hand side of the IDE and press the green arrow to start the nodes: .. image:: resources/run-config-drop-down.png :width: 400 -The node driver defined in ``/src/main/kotlin-source/com/example/Main.kt`` allows you to specify how many nodes you would like -to run and the various configuration settings for each node. With the example CorDapp, the Node driver starts four nodes -and sets up an RPC user for all but the "Controller" node (which hosts the notary Service and network map service): +The node driver defined in ``/src/test/kotlin/com/example/Main.kt`` allows you to specify how many nodes you would like +to run and the configuration settings for each node. With the example CorDapp, the driver starts up four nodes +and adds an RPC user for all but the "Controller" node (which serves as the notary and network map service): .. sourcecode:: kotlin - fun main(args: Array) { - // No permissions required as we are not invoking flows. - val user = User("user1", "test", permissions = setOf()) - driver(dsl = { - startNode("Controller", setOf(ServiceInfo(ValidatingNotaryService.type))) - startNode("NodeA", rpcUsers = listOf(user)) - startNode("NodeB", rpcUsers = listOf(user)) - startNode("NodeC", rpcUsers = listOf(user)) - waitForAllNodesToFinish() - }, isDebug = true) - } + fun main(args: Array) { + // No permissions required as we are not invoking flows. + val user = User("user1", "test", permissions = setOf()) + driver(isDebug = true) { + startNode(X500Name("CN=Controller,O=R3,OU=corda,L=London,C=UK"), setOf(ServiceInfo(ValidatingNotaryService.type))) + val (nodeA, nodeB, nodeC) = Futures.allAsList( + startNode(X500Name("CN=NodeA,O=NodeA,L=London,C=UK"), rpcUsers = listOf(user)), + startNode(X500Name("CN=NodeB,O=NodeB,L=New York,C=US"), rpcUsers = listOf(user)), + startNode(X500Name("CN=NodeC,O=NodeC,L=Paris,C=FR"), rpcUsers = listOf(user))).getOrThrow() -To stop the nodes, press the red "stop" button at the top right-hand side of the IDE. + startWebserver(nodeA) + startWebserver(nodeB) + startWebserver(nodeC) -The node driver can also be used to as a basis for `debugging your CorDapp`_ + waitForAllNodesToFinish() + } + } + +To stop the nodes, press the red square button at the top right-hand side of the IDE, next to the run configurations. + +We'll look later at how the node driver can be useful for `debugging your CorDapp`_. Interacting with the example CorDapp ------------------------------------ Via HTTP ~~~~~~~~ +The CorDapp defines several HTTP API end-points and a web front-end. The end-points allow you to list your existing +IOUs, agree new IOUs, and see who is on the network. -The CorDapp defines a few HTTP API end-points and also serves some static web content. The end-points allow you to -list IOUs and add IOUs. +The nodes are running locally on the following ports: -The nodes can be found using the following port numbers, defined in build.gradle and the respective node.conf file for -each node found in `kotlin-source/build/nodes/NodeX`` etc: +* Controller: ``localhost:10004`` +* NodeA: ``localhost:10007`` +* NodeB: ``localhost:10010`` +* NodeC: ``localhost:10013`` -* Controller: ``localhost:10003`` -* NodeA: ``localhost:10005`` -* NodeB: ``localhost:10007`` -* NodeC: ``localhost:10009`` +These ports are defined in build.gradle and in each node's node.conf file under ``kotlin-source/build/nodes/NodeX``. -Note that the ``deployNodes`` Gradle task is used to generate the ``node.conf`` files for each node. - -As the nodes start-up they should tell you which host and port the embedded web server is running on. The API endpoints -served are as follows: +As the nodes start up, they should tell you which port their embedded web server is running on. The available API +endpoints are: * ``/api/example/me`` * ``/api/example/peers`` * ``/api/example/ious`` * ``/api/example/{COUNTERPARTY}/create-iou`` -The static web content is served from ``/web/example``. +The web front-end is served from ``/web/example``. -An IOU can be created via accessing the ``api/example/create-iou`` end-point directly or through the +An IOU can be created by sending a PUT request to the ``api/example/create-iou`` end-point directly, or by using the the web form hosted at ``/web/example``. - .. warning:: **The content in ``web/example`` is only available for demonstration purposes and does not implement any - anti-XSS, anti-XSRF or any other security techniques. Do not copy such code directly into products meant for production use.** +.. warning:: The content in ``web/example`` is only available for demonstration purposes and does not implement + anti-XSS, anti-XSRF or any other security techniques. Do not use this code in production. -**Submitting an IOU via HTTP API:** +**Creating an IOU via the HTTP API:** -To create an IOU from NodeA to NodeB, use: +To create an IOU between NodeA and NodeB, we would run the following from the command line: .. sourcecode:: bash - echo '{"value": "1"}' | cURL -T - -H 'Content-Type: application/json' http://localhost:10005/api/example/NodeB/create-iou + echo '{"value": "1"}' | cURL -T - -H 'Content-Type: application/json' http://localhost:10007/api/example/NodeB/create-iou -Note the port number ``10005`` (NodeA) and NodeB referenced in the API end-point path. This command instructs NodeA to -create and send an IOU to NodeB. Upon verification and completion of the process, both nodes (but not NodeC) will -have a signed, notarised copy of the IOU. +Note that both NodeA's port number (``10007``) and NodeB are referenced in the PUT request path. This command instructs +NodeA to agree an IOU with NodeB. Once the process is complete, both nodes will have a signed, notarised copy of the +IOU. NodeC will not. -**Submitting an IOU via web/example:** +**Submitting an IOU via the web front-end:** -Navigate to the "create IOU" button at the top left of the page, and enter the IOU details - e.g. +Navigate to ``/web/example``, click the "create IOU" button at the top-left of the page, and enter the IOU details into +the web-form. The IOU must have a value of between 1 and 99. .. sourcecode:: none Counter-party: Select from list - Order Number: 1 - Delivery Date: 2018-09-15 - City: London - Country Code: GB - Item name: Wow such item - Item amount: 5 + Value (Int): 5 -and click submit (note you can add additional item types and amounts). Upon pressing submit, the modal dialogue should close. -To check what validation is performed over the IOU data, have a look at the ``IOUContract.Create`` class in -``IOUContract.kt`` which defines the following contract constraints (among others not included here): - -.. sourcecode:: kotlin - - // Generic constraints around the IOU transaction. - "No inputs should be consumed when issuing an IOU." using (tx.inputs.isEmpty()) - "Only one output state should be created." using (tx.outputs.size == 1) - val out = tx.outputs.single() as IOUState - "The sender and the recipient cannot be the same entity." using (out.sender != out.recipient) - "All of the participants must be signers." using (command.signers.containsAll(out.participants)) - - // IOU-specific constraints. - "The IOU's value must be non-negative." using (out.iou.value > 0) +And click submit. Upon clicking submit, the modal dialogue will close, and the nodes will agree the IOU. **Once an IOU has been submitted:** -Inspect the terminal windows for the nodes. Assume all of the above contract constraints are met, you should see some -activity in the terminal windows for NodeA and NodeB (note: the green ticks are only visible on unix/mac): - -*NodeA:* +Assuming all went well, you should see some activity in NodeA's web-server terminal window: .. sourcecode:: none - ✅ Generating transaction based on new IOU. - ✅ Verifying contract constraints. - ✅ Signing transaction with our private key. - ✅ Sending proposed transaction to recipient for review. - ✅ Done + >> Generating transaction based on new IOU. + >> Verifying contract constraints. + >> Signing transaction with our private key. + >> Gathering the counterparty's signature. + >> Structural step change in child of Gathering the counterparty's signature. + >> Collecting signatures from counter-parties. + >> Verifying collected signatures. + >> Done + >> Obtaining notary signature and recording transaction. + >> Structural step change in child of Obtaining notary signature and recording transaction. + >> Requesting signature by notary service + >> Broadcasting transaction to participants + >> Done + >> Done -*NodeB:* - -.. sourcecode:: none - - ✅ Receiving proposed transaction from sender. - ✅ Verifying signatures and contract constraints. - ✅ Signing proposed transaction with our private key. - ✅ Obtaining notary signature and recording transaction. - ✅ Requesting signature by notary service - ✅ Requesting signature by Notary service - ✅ Validating response from Notary service - ✅ Broadcasting transaction to participants - ✅ Done - -*NodeC:* - -.. sourcecode:: none - - You shouldn't see any activity. - -Next you can view the newly created IOU by accessing the vault of NodeA or NodeB: +You can view the newly-created IOU by accessing the vault of NodeA or NodeB: *Via the HTTP API:* -For NodeA. navigate to http://localhost:10005/api/example/ious. For NodeB, -navigate to http://localhost:10007/api/example/ious. +* NodeA's vault: Navigate to http://localhost:10007/api/example/ious +* NodeB's vault: Navigate to http://localhost:10010/api/example/ious *Via web/example:* -Navigate to http://localhost:10005/web/example the refresh button in the top left-hand side of the page. You should -see the newly created agreement on the page. +* NodeA: Navigate to http://localhost:10007/web/example and hit the "refresh" button +* NodeA: Navigate to http://localhost:10010/web/example and hit the "refresh" button -Using the h2 web console -~~~~~~~~~~~~~~~~~~~~~~~~ +If you access the vault or web front-end of NodeC (on ``localhost:10013``), there will be no IOUs. This is because +NodeC was not involved in this transaction. -You can connect to the h2 database to see the current state of the ledger, among other data such as the current state of -the network map cache. Firstly, navigate to the folder where you downloaded the h2 web console as part of the pre-requisites -section, above. Change directories to the bin folder: +Via the interactive shell (terminal only) +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +Once a node has been started via the terminal, it will display an interactive shell: -``cd h2/bin`` +.. sourcecode:: none -Where there are a bunch of shell scripts and batch files. Run the web console: + Welcome to the Corda interactive shell. + Useful commands include 'help' to see what is available, and 'bye' to shut down the node. -Unix: + Fri Jul 07 16:36:29 BST 2017>>> -``sh h2.sh`` +You can see a list of the flows that your node can run using `flow list`. In our case, this will return the following +list: -Windows: +.. sourcecode:: none -``h2.bat`` + com.example.flow.ExampleFlow$Initiator + net.corda.flows.CashExitFlow + net.corda.flows.CashIssueFlow + net.corda.flows.CashPaymentFlow + net.corda.flows.ContractUpgradeFlow -The h2 web console should start up in a web browser tab. To connect we first need to obtain a JDBC connection string. Each -node outputs its connection string in the terminal window as it starts up. In a terminal window where a node is running, -look for the following string: +We can create a new IOU using the ``ExampleFlow$Initiator`` flow. For example, from the interactive shell of NodeA, you +can agree an IOU of 50 with NodeB by running ``flow start Initiator iouValue: 50, otherParty: NodeB``. -``Database connection URL is : jdbc:h2:tcp://10.18.0.150:56736/node`` +This will print out the following progress steps: -you can use the string on the right to connect to the h2 database: just paste it in to the JDBC URL field and click Connect. -You will be presented with a web application that enumerates all the available tables and provides an interface for you to -query them using SQL. +.. sourcecode:: none -Using the Example RPC client + ✅ Generating transaction based on new IOU. + ✅ Verifying contract constraints. + ✅ Signing transaction with our private key. + ✅ Gathering the counterparty's signature. + ✅ Collecting signatures from counter-parties. + ✅ Verifying collected signatures. + ✅ Obtaining notary signature and recording transaction. + ✅ Requesting signature by notary service + Requesting signature by Notary service + Validating response from Notary service + ✅ Broadcasting transaction to participants + ✅ Done + +We can also issue RPC operations to the node via the interactive shell. Type ``run`` to see the full list of available +operations. + +We can see a list of the states in our node's vault using ``run vaultAndUpdates``: + +.. sourcecode:: none + + --- + first: + - state: + data: + iou: + value: 50 + sender: "CN=NodeB,O=NodeB,L=New York,C=US" + recipient: "CN=NodeA,O=NodeA,L=London,C=UK" + linearId: + externalId: null + id: "84628565-2688-45ef-bb06-aae70fcf3be7" + contract: + legalContractReference: "4DDE2A47C361106CBAEC06CC40FE418A994822A3C8054851FEECD51207BFAF82" + participants: + - "CN=NodeB,O=NodeB,L=New York,C=US" + - "CN=NodeA,O=NodeA,L=London,C=UK" + notary: "CN=Controller,O=R3,OU=corda,L=London,C=UK,OU=corda.notary.validating" + encumbrance: null + ref: + txhash: "52A1B18E6ABD535EF36B2075469B01D2EF888034F721C4BECD26F40355C8C9DC" + index: 0 + second: "(observable)" + +We can also see the transactions stored in our node's local storage using ``run verifiedTransactions`` (we've +abbreviated the output below): + +.. sourcecode:: none + + first: + - txBits: "Y29yZGEAAAEOAQEAamF2YS51dGlsLkFycmF5TGlz9AABAAABAAEBAW5ldC5jb3JkYS5jb3JlLmNvbnRyYWN0cy5UcmFuc2FjdGlvblN0YXTlA1RyYW5zYWN0aW9uU3RhdGUuZGF04VRyYW5zYWN0aW9uU3RhdGUuZW5jdW1icmFuY+VUcmFuc2FjdGlvblN0YXRlLm5vdGFy+WkBAmNvbS5leGFtcGxlLnN0YXRlLklPVVN0YXTlBElPVVN0YXRlLmlv9UlPVVN0YXRlLmxpbmVhcknkSU9VU3RhdGUucmVjaXBpZW70SU9VU3RhdGUuc2VuZGXyDQEBSU9VLnZhbHXlAWQCAQA0ADIBAlVuaXF1ZUlkZW50aWZpZXIuZXh0ZXJuYWxJ5FVuaXF1ZUlkZW50aWZpZXIuaeQBgDAvAC0BAlVVSUQubGVhc3RTaWdCaXTzVVVJRC5tb3N0U2lnQml08wmxkIaDnsaq+YkNDAsACaHovZfbpr2d9wMCAQACAQBIAEYBAkFic3RyYWN0UGFydHkub3duaW5nS2X5UGFydHkubmFt5SIuIOnhdbFQY3EL/LQD90w6y+kCfj4x8UWXaqKtW68GBPlnREMAQTkwPjEOMAwGA1UEAwwFTm9kZUExDjAMBgNVBAoMBU5vZGVBMQ8wDQYDVQQHDAZMb25kb24xCzAJBgNVBAYTAlVLAgEAJgAkASIuIHI7goTSxPMdaRgJgGJVLQbFEzE++qJeYbEbQjrYxzuVRkUAQzkwQDEOMAwGA1UEAwwFTm9kZUIxDjAMBgNVBAoMBU5vZGVCMREwDwYDVQQHDAhOZXcgWW9yazELMAkGA1UEBhMCVVMCAQABAAABAAAkASIuIMqulslvpZ0PaM6fdyFZm+JsDGkuJ7xWnL3zB6PqpzANdwB1OTByMRMwEQYDVQQDDApDb250cm9sbGVyMQswCQYDVQQKDAJSMzEOMAwGA1UECwwFY29yZGExDzANBgNVBAcMBkxvbmRvbjELMAkGA1UEBhMCVUsxIDAeBgNVBAsMF2NvcmRhLm5vdGFyeS52YWxpZGF0aW5nAQAAAQABAQNuZXQuY29yZGEuY29yZS5jb250cmFjdHMuQ29tbWFu5AJDb21tYW5kLnNpZ25lcvNDb21tYW5kLnZhbHXlRwEAAi4gcjuChNLE8x1pGAmAYlUtBsUTMT76ol5hsRtCOtjHO5UuIOnhdbFQY3EL/LQD90w6y+kCfj4x8UWXaqKtW68GBPlnADMBBGNvbS5leGFtcGxlLmNvbnRyYWN0LklPVUNvbnRyYWN0JENvbW1hbmRzJENyZWF05QAAAQVuZXQuY29yZGEuY29yZS5pZGVudGl0eS5QYXJ0+SIuIMqulslvpZ0PaM6fdyFZm+JsDGkuJ7xWnL3zB6PqpzANAHU5MHIxEzARBgNVBAMMCkNvbnRyb2xsZXIxCzAJBgNVBAoMAlIzMQ4wDAYDVQQLDAVjb3JkYTEPMA0GA1UEBwwGTG9uZG9uMQswCQYDVQQGEwJVSzEgMB4GA1UECwwXY29yZGEubm90YXJ5LnZhbGlkYXRpbmcAAQACLiByO4KE0sTzHWkYCYBiVS0GxRMxPvqiXmGxG0I62Mc7lS4g6eF1sVBjcQv8tAP3TDrL6QJ+PjHxRZdqoq1brwYE+WcBBm5ldC5jb3JkYS5jb3JlLmNvbnRyYWN0cy5UcmFuc2FjdGlvblR5cGUkR2VuZXJh7AA=" + sigs: + - "cRgJlF8cUMMooyaV2OIKmR4/+3XmMsEPsbdlhU5YqngRhqgy9+tLzylh7kvWOhYZ4hjjOfrazLoZ6uOx6BAMCQ==" + - "iGLRDIbhlwguMz6yayX5p6vfQcAsp8haZc1cLGm7DPDIgq6hFyx2fzoI03DjXAV/mBT1upcUjM9UZ4gbRMedAw==" + id: "52A1B18E6ABD535EF36B2075469B01D2EF888034F721C4BECD26F40355C8C9DC" + tx: + inputs: [] + attachments: [] + outputs: + - data: + iou: + value: 50 + sender: "CN=NodeB,O=NodeB,L=New York,C=US" + recipient: "CN=NodeA,O=NodeA,L=London,C=UK" + linearId: + externalId: null + id: "84628565-2688-45ef-bb06-aae70fcf3be7" + contract: + legalContractReference: "4DDE2A47C361106CBAEC06CC40FE418A994822A3C8054851FEECD51207BFAF82" + participants: + - "CN=NodeB,O=NodeB,L=New York,C=US" + - "CN=NodeA,O=NodeA,L=London,C=UK" + notary: "CN=Controller,O=R3,OU=corda,L=London,C=UK,OU=corda.notary.validating" + encumbrance: null + commands: + - value: {} + signers: + - "8Kqd4oWdx4KQAVc3u5qvHZTGJxMtrShFudAzLUTdZUzbF9aPQcCZD5KXViC" + - "8Kqd4oWdx4KQAVcBx98LBHwXwC3a7hNptQomrg9mq2ScY7t1Qqsyk5dCNAr" + notary: "CN=Controller,O=R3,OU=corda,L=London,C=UK,OU=corda.notary.validating" + type: {} + timeWindow: null + mustSign: + - "8Kqd4oWdx4KQAVc3u5qvHZTGJxMtrShFudAzLUTdZUzbF9aPQcCZD5KXViC" + - "8Kqd4oWdx4KQAVcBx98LBHwXwC3a7hNptQomrg9mq2ScY7t1Qqsyk5dCNAr" + id: "52A1B18E6ABD535EF36B2075469B01D2EF888034F721C4BECD26F40355C8C9DC" + merkleTree: ... + availableComponents: ... + availableComponentHashes: ... + serialized: "Y29yZGEAAAEOAQEAamF2YS51dGlsLkFycmF5TGlz9AABAAABAAEBAW5ldC5jb3JkYS5jb3JlLmNvbnRyYWN0cy5UcmFuc2FjdGlvblN0YXTlA1RyYW5zYWN0aW9uU3RhdGUuZGF04VRyYW5zYWN0aW9uU3RhdGUuZW5jdW1icmFuY+VUcmFuc2FjdGlvblN0YXRlLm5vdGFy+WkBAmNvbS5leGFtcGxlLnN0YXRlLklPVVN0YXTlBElPVVN0YXRlLmlv9UlPVVN0YXRlLmxpbmVhcknkSU9VU3RhdGUucmVjaXBpZW70SU9VU3RhdGUuc2VuZGXyDQEBSU9VLnZhbHXlAWQCAQA0ADIBAlVuaXF1ZUlkZW50aWZpZXIuZXh0ZXJuYWxJ5FVuaXF1ZUlkZW50aWZpZXIuaeQBgDAvAC0BAlVVSUQubGVhc3RTaWdCaXTzVVVJRC5tb3N0U2lnQml08wmxkIaDnsaq+YkNDAsACaHovZfbpr2d9wMCAQACAQBIAEYBAkFic3RyYWN0UGFydHkub3duaW5nS2X5UGFydHkubmFt5SIuIOnhdbFQY3EL/LQD90w6y+kCfj4x8UWXaqKtW68GBPlnREMAQTkwPjEOMAwGA1UEAwwFTm9kZUExDjAMBgNVBAoMBU5vZGVBMQ8wDQYDVQQHDAZMb25kb24xCzAJBgNVBAYTAlVLAgEAJgAkASIuIHI7goTSxPMdaRgJgGJVLQbFEzE++qJeYbEbQjrYxzuVRkUAQzkwQDEOMAwGA1UEAwwFTm9kZUIxDjAMBgNVBAoMBU5vZGVCMREwDwYDVQQHDAhOZXcgWW9yazELMAkGA1UEBhMCVVMCAQABAAABAAAkASIuIMqulslvpZ0PaM6fdyFZm+JsDGkuJ7xWnL3zB6PqpzANdwB1OTByMRMwEQYDVQQDDApDb250cm9sbGVyMQswCQYDVQQKDAJSMzEOMAwGA1UECwwFY29yZGExDzANBgNVBAcMBkxvbmRvbjELMAkGA1UEBhMCVUsxIDAeBgNVBAsMF2NvcmRhLm5vdGFyeS52YWxpZGF0aW5nAQAAAQABAQNuZXQuY29yZGEuY29yZS5jb250cmFjdHMuQ29tbWFu5AJDb21tYW5kLnNpZ25lcvNDb21tYW5kLnZhbHXlRwEAAi4gcjuChNLE8x1pGAmAYlUtBsUTMT76ol5hsRtCOtjHO5UuIOnhdbFQY3EL/LQD90w6y+kCfj4x8UWXaqKtW68GBPlnADMBBGNvbS5leGFtcGxlLmNvbnRyYWN0LklPVUNvbnRyYWN0JENvbW1hbmRzJENyZWF05QAAAQVuZXQuY29yZGEuY29yZS5pZGVudGl0eS5QYXJ0+SIuIMqulslvpZ0PaM6fdyFZm+JsDGkuJ7xWnL3zB6PqpzANAHU5MHIxEzARBgNVBAMMCkNvbnRyb2xsZXIxCzAJBgNVBAoMAlIzMQ4wDAYDVQQLDAVjb3JkYTEPMA0GA1UEBwwGTG9uZG9uMQswCQYDVQQGEwJVSzEgMB4GA1UECwwXY29yZGEubm90YXJ5LnZhbGlkYXRpbmcAAQACLiByO4KE0sTzHWkYCYBiVS0GxRMxPvqiXmGxG0I62Mc7lS4g6eF1sVBjcQv8tAP3TDrL6QJ+PjHxRZdqoq1brwYE+WcBBm5ldC5jb3JkYS5jb3JlLmNvbnRyYWN0cy5UcmFuc2FjdGlvblR5cGUkR2VuZXJh7AA=" + second: "(observable)" + +The same states and transactions will be present on NodeB, who was NodeA's counterparty in the creation of the IOU. +However, the vault and local storage of NodeC will remain empty, since NodeC was not involved in the transaction. + +Via the h2 web console +~~~~~~~~~~~~~~~~~~~~~~ +You can connect directly to your node's database to see its stored states, transactions and attachments. To do so, +please follow the instructions in :doc:`node-database`. + +Using the example RPC client ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +The ``/src/main/kotlin-source/com/example/client/ExampleClientRPC.kt`` file is a simple utility that uses the client +RPC library to connect to a node. It will log any existing IOUs and listen for any future IOUs. If you haven't created +any IOUs when you first connect to one of the nodes, the client will simply log any future IOUs that are agreed. -The ``/src/main/kotlin-source/com/example/client/ExampleClientRPC.kt`` file is a simple utility which uses the client RPC library -to connect to a node and log the created IOUs. It will log any existing IOUs and listen for any future -IOUs. If you haven't created any IOUs when you connect to one of the Nodes via RPC then the client will log -and future IOUs which are agreed. +*Running the client via IntelliJ:* -To build the client use the following gradle task: - -``./gradlew runExampleClientRPC`` - -*To run the client, via IntelliJ:* - -Select the 'Run Example RPC Client' run configuration which, by default, connects to NodeA (Artemis port 10004). Click the +Select the 'Run Example RPC Client' run configuration which, by default, connects to NodeA (Artemis port 10007). Click the Green Arrow to run the client. You can edit the run configuration to connect on a different port. -*Via command line:* +*Running the client via the command line:* Run the following gradle task: -``./gradlew runExampleClientRPC localhost:10004`` +``./gradlew runExampleClientRPC localhost:10007`` -To close the application use ``ctrl+C``. For more information on the client RPC interface and how to build an RPC client -application see: +You can close the application using ``ctrl+C``. + +For more information on the client RPC interface and how to build an RPC client application, see: * :doc:`Client RPC documentation ` * :doc:`Client RPC tutorial ` -Extending the example CorDapp ------------------------------ - -cordapp-tutorial project structure -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -The example CorDapp has the following directory structure: - -.. sourcecode:: none - - . cordapp-tutorial - ├── README.md - ├── LICENSE - ├── build.gradle - ├── config - │   ├── ... - ├── gradle - │   └── ... - ├── gradle.properties - ├── gradlew - ├── gradlew.bat - ├── lib - │   ├── ... - ├── settings.gradle - ├── kotlin-source - │ └── src - │ ├── main - │    │   ├── kotlin - │    │   │   └── com - │    │   │   └── example - │    │   │   ├── Main.kt - │    │   │   ├── api - │    │   │   │   └── ExampleApi.kt - │    │   │   ├── client - │    │   │   │   └── ExampleClientRPC.kt - │    │   │   ├── contract - │    │   │   │   ├── IOUContract.kt - │    │   │   │   └── IOUState.kt - │    │   │   ├── model - │    │   │   │   └── IOU.kt - │    │   │   ├── plugin - │    │   │   │   └── ExamplePlugin.kt - │    │   │   └── flow - │    │   │   └── ExampleFlow.kt - │    │   │   └── service - │    │   │   └── ExampleService.kt - │ │   ├── python - │ │   └── resources - │ │   ├── META-INF - │ │   │   └── services - │   │   │   ├── net.corda.core.node.CordaPluginRegistry - │   │ │ └── net.corda.webserver.services.WebServerPluginRegistry - │ │   ├── certificates - │ │   │   ├── readme.txt - │ │   │   ├── sslkeystore.jks - │ │   │   └── truststore.jks - │ │   └── exampleWeb - │ │   ├── index.html - │ │   └── js - │ │   └── example.js - │ └── test - │ ├── java - │ ├── kotlin - │ │   └── com - │ │   └── example - │ │   └── ExampleTest.kt - │    └── resources - └── java-source - └── src - ├── main -    │   ├── java -    │   │   └── com -    │   │   └── example -    │   │   ├── Main.java -    │   │   ├── api -    │   │   │   └── ExampleApi.java -    │   │   ├── client -    │   │   │   └── ExampleClientRPC.java -    │   │   ├── contract -    │   │   │   ├── IOUContract.java -    │   │   │   └── IOUState.java -    │   │   ├── model -    │   │   │   └── IOU.java -    │   │   ├── plugin -    │   │   │   └── ExamplePlugin.java -    │   │   └── flow -    │   │   └── ExampleFlow.java -    │   │   └── service -    │   │   └── ExampleService.java - │   ├── python - │   └── resources - │   ├── META-INF - │   │   └── services -    │   │   ├── net.corda.core.node.CordaPluginRegistry -    │ │ └── net.corda.webserver.services.WebServerPluginRegistry - │   ├── certificates - │   │   ├── readme.txt - │   │   ├── sslkeystore.jks - │   │   └── truststore.jks - │   └── exampleWeb - │   ├── index.html - │   └── js - │   └── example.js - └── test - ├── java - ├── kotlin - │   └── com - │   └── example - │   └── ExampleTest.kt -    └── resources - -In the file structure above, the most important files and directories to note are: - -* The **root directory** contains some gradle files, a README and a LICENSE. -* **config** contains log4j configs. -* **gradle** contains the gradle wrapper, which allows the use of Gradle without installing it yourself and worrying - about which version is required. -* **lib** contains the Quasar.jar which is required for runtime instrumentation of classes by Quasar. -* **kotlin-source** contains the source code for the example CorDapp written in Kotlin. - * **kotlin-source/src/main/kotlin** contains the source code for the example CorDapp. - * **kotlin-source/src/main/python** contains a python script which accesses nodes via RPC. - * **kotlin-source/src/main/resources** contains the certificate store, some static web content to be served by the nodes and the - PluginServiceRegistry file. - * **kotlin-source/src/test/kotlin** contains unit tests for protocols, contracts, etc. -* **java-source** contains the same source code, but written in java. This is an aid for users who do not want to develop in - Kotlin, and serves as an example of how CorDapps can be developed in any language targeting the JVM. - -Some elements are covered in more detail below. - -The build.gradle file -~~~~~~~~~~~~~~~~~~~~~ - -It is usually necessary to make a couple of changes to the ``build.gradle`` file. Here will cover the most pertinent bits. - -**The buildscript** - -The buildscript is always located at the top of the file. It determines which plugins, task classes, and other classes -are available for use in the rest of the build script. It also specifies version numbers for dependencies, among other -things. - -If you are working from a Corda SNAPSHOT release which you have publish to Maven local then ensure that -``corda_release_version`` is the same as the version of the Corda core modules you published to Maven local. If not then -change the ``kotlin_version`` property. Also, if you are working from a previous cordapp-tutorial milestone release, then -be sure to ``git checkout`` the correct version of the example CorDapp from the ``cordapp-tutorial`` repo. - -.. sourcecode:: groovy - - buildscript { - ext.kotlin_version = '1.0.4' - ext.corda_release_version = '0.5-SNAPSHOT' // Ensure this version is the same as the corda core modules you are using. - ext.quasar_version = '0.7.6' - ext.jersey_version = '2.23.1' - - repositories { - ... - } - - dependencies { - ... - } - } - -**Project dependencies** - -If you have any additional external dependencies for your CorDapp then add them below the comment at the end of this -code snippet.package. Use the standard format: - -``compile "{groupId}:{artifactId}:{versionNumber}"`` - -.. sourcecode:: groovy - - dependencies { - compile "org.jetbrains.kotlin:kotlin-stdlib:$kotlin_version" - testCompile group: 'junit', name: 'junit', version: '4.11' - - // Corda integration dependencies - compile "net.corda:client:$corda_release_version" - compile "net.corda:core:$corda_release_version" - compile "net.corda:contracts:$corda_release_version" - compile "net.corda:node:$corda_release_version" - compile "net.corda:corda:$corda_release_version" - compile "net.corda:test-utils:$corda_release_version" - - ... - - // Cordapp dependencies - // Specify your cordapp's dependencies below, including dependent cordapps - } - -For further information about managing dependencies with `look at the Gradle docs `_. - -**CordFormation** - -This is the local node deployment system for CorDapps, the nodes generated are intended to be used for experimenting, -debugging, and testing node configurations but not intended for production or testnet deployment. - -In the CorDapp build.gradle file you'll find a ``deployNodes`` task, this is where you configure the nodes you would -like to deploy for testing. See further details below: - -.. sourcecode:: groovy - - task deployNodes(type: com.r3corda.plugins.Cordform, dependsOn: ['jar']) { - directory "./kotlin-source/build/nodes" // The output directory. - networkMap "CN=Controller,O=R3,OU=corda,L=London,C=GB" // The distinguished name of the node to be used as the network map. - node { - name "CN=Controller,O=R3,OU=corda,L=London,C=GB" // Distinguished name of node to be deployed. - advertisedServices = ["corda.notary.validating"] // A list of services you wish the node to offer. - p2pPort 10002 - rpcPort 10003 // Usually 1 higher than the messaging port. - webPort 10004 // Usually 1 higher than the RPC port. - cordapps = [] // Add package names of CordaApps. - } - node { // Create an additional node. - name "CN=NodeA,O=R3,OU=corda,L=London,C=GB" - advertisedServices = [] - p2pPort 10005 - rpcPort 10006 - webPort 10007 - cordapps = [] - } - ... - } - -You can add any number of nodes, with any number of services / CorDapps by editing the templates in ``deployNodes``. The -only requirement is that you must specify a node to run as the network map service and one as the notary service. - -.. note:: CorDapps in the current cordapp-tutorial project are automatically registered with all nodes defined in - ``deployNodes``, although we expect this to change in the near future. - -.. warning:: Make sure that there are no port clashes! - -When you are finished editing your *CordFormation* the changes will be reflected the next time you run ``./gradlew deployNodes``. - -Service Provider Configuration File -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -If you are building a CorDapp from scratch or adding a new CorDapp to the cordapp-tutorial project then you must provide -a reference to your sub-class of ``CordaPluginRegistry`` or ``WebServerPluginRegistry`` (for Wep API) in the provider-configuration file -located in the ``resources/META-INF/services`` directory. - -Re-Deploying Your Nodes Locally -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -If you need to create any additional nodes you can do it via the ``build.gradle`` file as discussed above in -``the build.gradle file`` and in more detail in the "cordFormation" section. - -You may also wish to edit the ``/kotlin-source/build/nodes//node.conf`` files for your nodes. For more information on -doing this, see the :doc:`Corda configuration file ` page. - -Once you have made some changes to your CorDapp you can redeploy it with the following command: - -Unix/Mac OSX: ``./gradlew deployNodes`` - -Windows: ``gradlew.bat deployNodes`` - Running Nodes Across Machines -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -The nodes can also be set up to communicate between separate machines on the -same subnet. +----------------------------- +The nodes can also be configured to communicate across the network when residing on different machines. After deploying the nodes, navigate to the build folder (``kotlin-source/build/nodes`` or ``java-source/build/nodes``) -and move some of the individual node folders to separate machines on the same subnet (e.g. using a USB key). -It is important that no nodes - including the controller node - end up on more than one machine. Each computer -should also have a copy of ``runnodes`` and ``runnodes.bat``. +and move some of the individual node folders to separate machines (e.g. using a USB key). It is important that none of +the nodes - including the controller node - end up on more than one machine. Each computer should also have a copy of +``runnodes`` and ``runnodes.bat``. For example, you may end up with the following layout: @@ -843,45 +560,49 @@ You must now edit the configuration file for each node, including the controller and make the following changes: * Change the Artemis messaging address to the machine's IP address (e.g. ``p2pAddress="10.18.0.166:10006"``) -* Change the network map service details to the IP address of the machine where the - controller node is running and to its legal name (e.g. ``networkMapService.address="10.18.0.166:10002"`` and - ``networkMapService.legalName=controller``) (please note that the controller will not have the ``networkMapService`` config) +* Change the network map service's address to the IP address of the machine where the controller node is running + (e.g. ``networkMapService { address="10.18.0.166:10002" ...``). The controller will not have the + ``networkMapService`` config -Now run each node. +After starting each node, they should be able to see one another and agree IOUs among themselves. Debugging your CorDapp -~~~~~~~~~~~~~~~~~~~~~~ +---------------------- +Debugging is done via IntelliJ as follows: -Debugging is done via IntelliJ and can be done using the following steps. - -1. Set your breakpoints. -2. Edit the node driver code in ``Main.kt`` to reflect how many nodes you wish to start along with any other - configuration options. For example, the below starts 4 nodes, with one being the network map service / notary and - sets up RPC credentials for 3 of the nodes. +1. Edit the node driver code in ``Main.kt`` to reflect the number of nodes you wish to start, along with any other + configuration options. For example, the code below starts 4 nodes, with one being the network map service and notary. + It also sets up RPC credentials for the three non-notary nodes .. sourcecode:: kotlin fun main(args: Array) { // No permissions required as we are not invoking flows. val user = User("user1", "test", permissions = setOf()) - driver(dsl = { - startNode("Controller", setOf(ServiceInfo(ValidatingNotaryService.type))) - startNode("NodeA", rpcUsers = listOf(user)) - startNode("NodeB", rpcUsers = listOf(user)) - startNode("NodeC", rpcUsers = listOf(user)) + driver(isDebug = true) { + startNode(X500Name("CN=Controller,O=R3,OU=corda,L=London,C=UK"), setOf(ServiceInfo(ValidatingNotaryService.type))) + val (nodeA, nodeB, nodeC) = Futures.allAsList( + startNode(X500Name("CN=NodeA,O=NodeA,L=London,C=UK"), rpcUsers = listOf(user)), + startNode(X500Name("CN=NodeB,O=NodeB,L=New York,C=US"), rpcUsers = listOf(user)), + startNode(X500Name("CN=NodeC,O=NodeC,L=Paris,C=FR"), rpcUsers = listOf(user))).getOrThrow() + + startWebserver(nodeA) + startWebserver(nodeB) + startWebserver(nodeC) + waitForAllNodesToFinish() - }, isDebug = true) + } } -3. Select and run the “Run Example CorDapp” run configuration in IntelliJ. -4. IntelliJ will build and run the CorDapp. Observe the console output for the remote debug ports. The “Controller” - node will generally be on port 5005, with NodeA on port 5006 an so-on. +2. Select and run the “Run Example CorDapp” run configuration in IntelliJ +3. IntelliJ will build and run the CorDapp. The remote debug ports for each node will be automatically generated and + printed to the terminal. For example: .. sourcecode:: none - Listening for transport dt_socket at address: 5008 - Listening for transport dt_socket at address: 5007 - Listening for transport dt_socket at address: 5006 + [INFO ] 15:27:59.533 [main] Node.logStartupInfo - Working Directory: /Users/joeldudley/cordapp-tutorial/build/20170707142746/NodeA + [INFO ] 15:27:59.533 [main] Node.logStartupInfo - Debug port: dt_socket:5007 -5. Edit the “Debug CorDapp” run configuration with the port of the node you wish to connect to. -6. Run the “Debug CorDapp” run configuration. +4. Edit the “Debug CorDapp” run configuration with the port of the node you wish to connect to +5. Run the “Debug CorDapp” run configuration +6. Set your breakpoints and start using your node. When your node hits a breakpoint, execution will pause diff --git a/docs/source/tutorials-index.rst b/docs/source/tutorials-index.rst index bd92064680..faaa498335 100644 --- a/docs/source/tutorials-index.rst +++ b/docs/source/tutorials-index.rst @@ -5,6 +5,7 @@ Tutorials :maxdepth: 1 hello-world-index + tut-two-party-index tutorial-contract tutorial-contract-clauses tutorial-test-dsl diff --git a/docs/source/using-a-notary.rst b/docs/source/using-a-notary.rst index 650755247b..08517c6eaa 100644 --- a/docs/source/using-a-notary.rst +++ b/docs/source/using-a-notary.rst @@ -39,7 +39,7 @@ Then we initialise the transaction builder: .. sourcecode:: kotlin - val builder: TransactionBuilder = TransactionType.General.Builder(notary = ourNotary) + val builder: TransactionBuilder = TransactionBuilder(notary = ourNotary) For any output state we add to this transaction builder, ``ourNotary`` will be assigned as its notary. Next we create a state object and assign ourselves as the owner. For this example we'll use a @@ -64,7 +64,7 @@ We then sign the transaction, build and record it to our transaction storage: .. sourcecode:: kotlin val mySigningKey: PublicKey = serviceHub.legalIdentityKey - val issueTransaction = serviceHub.signInitialTransaction(issueTransaction, mySigningKey) + val issueTransaction = serviceHub.toSignedTransaction(issueTransaction, mySigningKey) serviceHub.recordTransactions(issueTransaction) The transaction is recorded and we now have a state (asset) in possession that we can transfer to someone else. Note @@ -98,7 +98,7 @@ Again we sign the transaction, and build it: // We build it and add our default identity signature without checking if all signatures are present, // Note we know that the notary signature is missing, so thie SignedTransaction is still partial. - val moveTransaction = serviceHub.signInitialTransaction(moveTransactionBuilder) + val moveTransaction = serviceHub.toSignedTransaction(moveTransactionBuilder) Next we need to obtain a signature from the notary for the transaction to be valid. Prior to signing, the notary will commit our old (input) state so it cannot be used again. diff --git a/docs/source/versioning.rst b/docs/source/versioning.rst index 61ce417263..9f5967e1ef 100644 --- a/docs/source/versioning.rst +++ b/docs/source/versioning.rst @@ -31,14 +31,42 @@ for the network. Flow versioning --------------- -A platform which can be extended with CorDapps also requires the ability to version these apps as they evolve from -release to release. This allows users of these apps, whether they're other nodes or RPC users, to select which version -they wish to use and enables nodes to control which app versions they support. Flows have their own version numbers, -independent of other versioning, for example of the platform. In particular it is the initiating flow that can be versioned -using the ``version`` property of the ``InitiatingFlow`` annotation. This assigns an integer version number, similar in -concept to the platform version, which is used in the session handshake process when a flow communicates with another party -for the first time. The other party will only accept the session request if it, firstly, has that flow loaded, and secondly, -for the same version (see also :doc:`flow-state-machine`). +In addition to the evolution of the platform, flows that run on top of the platform can also evolve. It may be that the +flow protocol between an initiating flow and it's intiated flow changes from one CorDapp release to the next in such as +way to be backwards incompatible with existing flows. For example, if a sequence of sends and receives needs to change +or if the semantics of a particular receive changes. + +The ``InitiatingFlow`` annotation (see :doc:`flow-state-machine` for more information on the flow annotations) has a ``version`` +property, which if not specified defaults to 1. This flow version is included in the flow session handshake and exposed +to both parties in the communication via ``FlowLogic.getFlowContext``. This takes in a ``Party`` and will return a +``FlowContext`` object which describes the flow running on the other side. In particular it has the ``flowVersion`` property +which can be used to programmatically evolve flows across versions. + +.. container:: codeset + + .. sourcecode:: kotlin + + @Suspendable + override fun call() { + val flowVersionOfOtherParty = getFlowContext(otherParty).flowVersion + val receivedString = if (flowVersionOfOtherParty == 1) { + receive(otherParty).unwrap { it.toString() } + } else { + receive(otherParty).unwrap { it } + } + } + +The above shows an example evolution of a flow which in the first version was expecting to receive an Int, but then +in subsequent versions was relaxed to receive a String. This flow is still able to communicate with parties which are +running the older flow (or rather older CorDapps containing the older flow). + +.. warning:: It's important that ``InitiatingFlow.version`` be incremented each time the flow protocol changes in an + incompatible way. + +``FlowContext`` also has ``appName`` which is the name of the CorDapp hosting the flow. This can be used to determine +implementation details of the CorDapp. See :doc:`cordapp-build-systems` for more information on the CorDapp filename. + +.. note:: Currently changing any of the properties of a ``CordaSerializable`` type is also backwards incompatible and + requires incrementing of ``InitiatingFlow.version``. This will be relaxed somewhat once the AMQP wire serialisation + format is implemented as it will automatically handle a lot of the data type migration cases. -.. note:: Currently we don't support multiple versions of the same flow loaded in the same node. This will be possible - once we start loading CorDapps in separate class loaders. diff --git a/docs/source/writing-cordapps.rst b/docs/source/writing-cordapps.rst index b3a360efb5..c63be2b40f 100644 --- a/docs/source/writing-cordapps.rst +++ b/docs/source/writing-cordapps.rst @@ -1,7 +1,8 @@ Writing a CorDapp ================= -The source-code for a CorDapp is a set of files written in a JVM language that defines a set of Corda components: +When writing a CorDapp, you are writing a set of files in a JVM language that defines one or more of the following +Corda components: * States (i.e. classes implementing ``ContractState``) * Contracts (i.e. classes implementing ``Contract``) @@ -9,14 +10,14 @@ The source-code for a CorDapp is a set of files written in a JVM language that d * Web APIs * Services -These files should be placed under ``src/main/[java|kotlin]``. The CorDapp's resources folder (``src/main/resources``) -should also include the following subfolders: +CorDapp structure +----------------- +Your CorDapp project's structure should be based on the structure of the +`Java Template CorDapp `_ or the +`Kotlin Template CorDapp `_, depending on which language you intend +to use. -* ``src/main/resources/certificates``, containing the node's certificates -* ``src/main/resources/META-INF/services``, containing a file named ``net.corda.core.node.CordaPluginRegistry`` - -For example, the source-code of the `Template CorDapp `_ has the following -structure: +The ``src`` directory of the Template CorDapp, where we define our CorDapp's source-code, has the following structure: .. parsed-literal:: @@ -59,87 +60,39 @@ structure: └── contract └── TemplateTests.java -Defining a plugin ------------------ -You can specify the transport options (between nodes and between Web Client and a node) for your CorDapp by subclassing -``net.corda.core.node.CordaPluginRegistry``: +Defining plugins +---------------- +Your CorDapp may need to define two types of plugins: -* The ``customizeSerialization`` function allows classes to be whitelisted for object serialisation, over and - above those tagged with the ``@CordaSerializable`` annotation. For instance, new state types will need to be - explicitly registered. In general, the annotation should be preferred. See :doc:`serialization`. +* ``CordaPluginRegistry`` subclasses, which define additional serializable classes and vault schemas +* ``WebServerPluginRegistry`` subclasses, which define the APIs and static web content served by your CorDapp -The fully-qualified class path of each ``CordaPluginRegistry`` subclass must be added to the -``net.corda.core.node.CordaPluginRegistry`` file in the CorDapp's ``resources/META-INF/services`` folder. A CorDapp -can register multiple plugins in a single ``net.corda.core.node.CordaPluginRegistry`` file. +The fully-qualified class path of each ``CordaPluginRegistry`` subclass must then be added to the +``net.corda.core.node.CordaPluginRegistry`` file in the CorDapp's ``resources/META-INF/services`` folder. Meanwhile, +the fully-qualified class path of each ``WebServerPluginRegistry`` subclass must be added to the +``net.corda.webserver.services.WebServerPluginRegistry`` file, again in the CorDapp's ``resources/META-INF/services`` +folder. -You can specify the web APIs and static web content for your CorDapp by implementing -``net.corda.webserver.services.WebServerPluginRegistry`` interface: +The ``CordaPluginRegistry`` class defines the following: -* The ``webApis`` property is a list of JAX-RS annotated REST access classes. These classes will be constructed by - the bundled web server and must have a single argument constructor taking a ``CordaRPCOps`` object. This will - allow the API to communicate with the node process via the RPC interface. These web APIs will not be available if the - bundled web server is not started. +* ``customizeSerialization``, which can be overridden to provide a list of the classes to be whitelisted for object + serialisation, over and above those tagged with the ``@CordaSerializable`` annotation. See :doc:`serialization` + +* ``requiredSchemas``, which can be overridden to return a set of the MappedSchemas to use for persistence and vault + queries + +The ``WebServerPluginRegistry`` class defines the following: + +* ``webApis``, which can be overridden to return a list of JAX-RS annotated REST access classes. These classes will be + constructed by the bundled web server and must have a single argument constructor taking a ``CordaRPCOps`` object. + This will allow the API to communicate with the node process via the RPC interface. These web APIs will not be + available if the bundled web server is not started + +* ``staticServeDirs``, which can be overridden to map static web content to virtual paths and allow simple web demos to + be distributed within the CorDapp jars. This static content will not be available if the bundled web server is not + started -* The ``staticServeDirs`` property maps static web content to virtual paths and allows simple web demos to be - distributed within the CorDapp jars. These static serving directories will not be available if the bundled web server - is not started. * The static web content itself should be placed inside the ``src/main/resources`` directory -The fully-qualified class path of each ``WebServerPluginRegistry`` class must be added to the -``net.corda.webserver.services.WebServerPluginRegistry`` file in the CorDapp's ``resources/META-INF/services`` folder. A CorDapp -can register multiple plugins in a single ``net.corda.webserver.services.WebServerPluginRegistry`` file. - -Installing CorDapps -------------------- -To run a CorDapp, its source is compiled into a JAR by running the gradle ``jar`` task. The CorDapp JAR is then added -to a node by adding it to the node's ``/plugins/`` folder (where ``node_dir`` is the folder in which the -node's JAR and configuration files are stored). - -.. note:: Any external dependencies of your CorDapp will automatically be placed into the - ``/dependencies/`` folder. This will be changed in a future release. - -.. note:: Building nodes using the gradle ``deployNodes`` task will place the CorDapp JAR into each node's ``plugins`` - folder automatically. - -At runtime, nodes will load any plugins present in their ``plugins`` folder. - -RPC permissions ---------------- -If a node's owner needs to interact with their node via RPC (e.g. to read the contents of the node's storage), they -must define one or more RPC users. These users are added to the node's ``node.conf`` file. - -The syntax for adding an RPC user is: - -.. container:: codeset - - .. sourcecode:: groovy - - rpcUsers=[ - { - username=exampleUser - password=examplePass - permissions=[] - } - ... - ] - -Currently, users need special permissions to start flows via RPC. These permissions are added as follows: - -.. container:: codeset - - .. sourcecode:: groovy - - rpcUsers=[ - { - username=exampleUser - password=examplePass - permissions=[ - "StartFlow.net.corda.flows.ExampleFlow1", - "StartFlow.net.corda.flows.ExampleFlow2" - ] - } - ... - ] - -.. note:: Currently, the node's web server has super-user access, meaning that it can run any RPC operation without - logging in. This will be changed in a future release. +To learn about how to use gradle to build your cordapp against Corda and generate an artifact please read +:doc:`cordapp-build-systems`. \ No newline at end of file diff --git a/experimental/src/main/kotlin/net/corda/contracts/universal/UniversalContract.kt b/experimental/src/main/kotlin/net/corda/contracts/universal/UniversalContract.kt index 74e3b0b00c..0e695837ef 100644 --- a/experimental/src/main/kotlin/net/corda/contracts/universal/UniversalContract.kt +++ b/experimental/src/main/kotlin/net/corda/contracts/universal/UniversalContract.kt @@ -6,6 +6,7 @@ import net.corda.core.contracts.* import net.corda.core.crypto.SecureHash import net.corda.core.identity.AbstractParty import net.corda.core.identity.Party +import net.corda.core.transactions.LedgerTransaction import net.corda.core.transactions.TransactionBuilder import java.math.BigDecimal import java.time.Instant @@ -37,14 +38,14 @@ class UniversalContract : Contract { class Split(val ratio: BigDecimal) : Commands } - fun eval(@Suppress("UNUSED_PARAMETER") tx: TransactionForContract, expr: Perceivable): Instant? = when (expr) { + fun eval(@Suppress("UNUSED_PARAMETER") tx: LedgerTransaction, expr: Perceivable): Instant? = when (expr) { is Const -> expr.value is StartDate -> null is EndDate -> null else -> throw Error("Unable to evaluate") } - fun eval(tx: TransactionForContract, expr: Perceivable): Boolean = when (expr) { + fun eval(tx: LedgerTransaction, expr: Perceivable): Boolean = when (expr) { is PerceivableAnd -> eval(tx, expr.left) && eval(tx, expr.right) is PerceivableOr -> eval(tx, expr.left) || eval(tx, expr.right) is Const -> expr.value @@ -57,7 +58,7 @@ class UniversalContract : Contract { else -> throw NotImplementedError("eval - Boolean - " + expr.javaClass.name) } - fun eval(tx: TransactionForContract, expr: Perceivable): BigDecimal = + fun eval(tx: LedgerTransaction, expr: Perceivable): BigDecimal = when (expr) { is Const -> expr.value is UnaryPlus -> { @@ -94,7 +95,7 @@ class UniversalContract : Contract { else -> throw NotImplementedError("eval - BigDecimal - " + expr.javaClass.name) } - fun validateImmediateTransfers(tx: TransactionForContract, arrangement: Arrangement): Arrangement = when (arrangement) { + fun validateImmediateTransfers(tx: LedgerTransaction, arrangement: Arrangement): Arrangement = when (arrangement) { is Obligation -> { val amount = eval(tx, arrangement.amount) requireThat { "transferred quantity is non-negative" using (amount >= BigDecimal.ZERO) } @@ -179,7 +180,7 @@ class UniversalContract : Contract { else -> throw NotImplementedError("replaceNext " + arrangement.javaClass.name) } - override fun verify(tx: TransactionForContract) { + override fun verify(tx: LedgerTransaction) { requireThat { "transaction has a single command".using(tx.commands.size == 1) @@ -191,7 +192,7 @@ class UniversalContract : Contract { when (value) { is Commands.Action -> { - val inState = tx.inputs.single() as State + val inState = tx.inputsOfType().single() val arr = when (inState.details) { is Actions -> inState.details is RollOut -> reduceRollOut(inState.details) @@ -221,7 +222,7 @@ class UniversalContract : Contract { when (tx.outputs.size) { 1 -> { - val outState = tx.outputs.single() as State + val outState = tx.outputsOfType().single() requireThat { "output state must match action result state" using (arrangement.equals(outState.details)) "output state must match action result state" using (rest == zero) @@ -229,7 +230,7 @@ class UniversalContract : Contract { } 0 -> throw IllegalArgumentException("must have at least one out state") else -> { - val allContracts = And(tx.outputs.map { (it as State).details }.toSet()) + val allContracts = And(tx.outputsOfType().map { it.details }.toSet()) requireThat { "output states must match action result state" using (arrangement.equals(allContracts)) @@ -239,15 +240,15 @@ class UniversalContract : Contract { } } is Commands.Issue -> { - val outState = tx.outputs.single() as State + val outState = tx.outputsOfType().single() requireThat { "the transaction is signed by all liable parties" using (liableParties(outState.details).all { it in cmd.signers }) "the transaction has no input states" using tx.inputs.isEmpty() } } is Commands.Move -> { - val inState = tx.inputs.single() as State - val outState = tx.outputs.single() as State + val inState = tx.inputsOfType().single() + val outState = tx.outputsOfType().single() requireThat { "the transaction is signed by all liable parties" using (liableParties(outState.details).all { it in cmd.signers }) @@ -256,13 +257,13 @@ class UniversalContract : Contract { } } is Commands.Fix -> { - val inState = tx.inputs.single() as State + val inState = tx.inputsOfType().single() val arr = when (inState.details) { is Actions -> inState.details is RollOut -> reduceRollOut(inState.details) else -> throw IllegalArgumentException("Unexpected arrangement, " + tx.inputs.single()) } - val outState = tx.outputs.single() as State + val outState = tx.outputsOfType().single() val unusedFixes = value.fixes.map { it.of }.toMutableSet() val expectedArr = replaceFixing(tx, arr, @@ -279,7 +280,7 @@ class UniversalContract : Contract { } @Suppress("UNCHECKED_CAST") - fun replaceFixing(tx: TransactionForContract, perceivable: Perceivable, + fun replaceFixing(tx: LedgerTransaction, perceivable: Perceivable, fixings: Map, unusedFixings: MutableSet): Perceivable = when (perceivable) { is Const -> perceivable @@ -299,11 +300,11 @@ class UniversalContract : Contract { else -> throw NotImplementedError("replaceFixing - " + perceivable.javaClass.name) } - fun replaceFixing(tx: TransactionForContract, arr: Action, + fun replaceFixing(tx: LedgerTransaction, arr: Action, fixings: Map, unusedFixings: MutableSet) = Action(arr.name, replaceFixing(tx, arr.condition, fixings, unusedFixings), replaceFixing(tx, arr.arrangement, fixings, unusedFixings)) - fun replaceFixing(tx: TransactionForContract, arr: Arrangement, + fun replaceFixing(tx: LedgerTransaction, arr: Arrangement, fixings: Map, unusedFixings: MutableSet): Arrangement = when (arr) { is Zero -> arr diff --git a/finance/build.gradle b/finance/build.gradle index 3e4cae5813..9604e46672 100644 --- a/finance/build.gradle +++ b/finance/build.gradle @@ -16,19 +16,6 @@ dependencies { testCompile project(':test-utils') testCompile project(path: ':core', configuration: 'testArtifacts') testCompile "junit:junit:$junit_version" - - // TODO: Upgrade to junit-quickcheck 0.8, once it is released, - // because it depends on org.javassist:javassist instead - // of javassist:javassist. - testCompile "com.pholser:junit-quickcheck-core:$quickcheck_version" - testCompile "com.pholser:junit-quickcheck-generators:$quickcheck_version" -} - -configurations.testCompile { - // Excluding javassist:javassist because it clashes with Hibernate's - // transitive org.javassist:javassist dependency. - // TODO: Remove this exclusion once junit-quickcheck 0.8 is released. - exclude group: 'javassist', module: 'javassist' } configurations { @@ -49,5 +36,5 @@ jar { } publish { - name = jar.baseName + name jar.baseName } \ No newline at end of file diff --git a/finance/isolated/src/main/kotlin/net/corda/contracts/AnotherDummyContract.kt b/finance/isolated/src/main/kotlin/net/corda/contracts/isolated/AnotherDummyContract.kt similarity index 72% rename from finance/isolated/src/main/kotlin/net/corda/contracts/AnotherDummyContract.kt rename to finance/isolated/src/main/kotlin/net/corda/contracts/isolated/AnotherDummyContract.kt index ae94bab519..9d615cc02b 100644 --- a/finance/isolated/src/main/kotlin/net/corda/contracts/AnotherDummyContract.kt +++ b/finance/isolated/src/main/kotlin/net/corda/contracts/isolated/AnotherDummyContract.kt @@ -1,17 +1,16 @@ package net.corda.contracts.isolated import net.corda.core.contracts.* -import net.corda.core.identity.Party import net.corda.core.crypto.SecureHash import net.corda.core.identity.AbstractParty +import net.corda.core.identity.Party +import net.corda.core.transactions.LedgerTransaction import net.corda.core.transactions.TransactionBuilder -import java.security.PublicKey - -// The dummy contract doesn't do anything useful. It exists for testing purposes. +import net.corda.nodeapi.DummyContractBackdoor val ANOTHER_DUMMY_PROGRAM_ID = AnotherDummyContract() -class AnotherDummyContract : Contract, net.corda.core.node.DummyContractBackdoor { +class AnotherDummyContract : Contract, DummyContractBackdoor { data class State(val magicNumber: Int = 0) : ContractState { override val contract = ANOTHER_DUMMY_PROGRAM_ID override val participants: List @@ -22,7 +21,7 @@ class AnotherDummyContract : Contract, net.corda.core.node.DummyContractBackdoor class Create : TypeOnlyCommandData(), Commands } - override fun verify(tx: TransactionForContract) { + override fun verify(tx: LedgerTransaction) { // Always accepts. } @@ -31,7 +30,7 @@ class AnotherDummyContract : Contract, net.corda.core.node.DummyContractBackdoor override fun generateInitial(owner: PartyAndReference, magicNumber: Int, notary: Party): TransactionBuilder { val state = State(magicNumber) - return TransactionType.General.Builder(notary = notary).withItems(state, Command(Commands.Create(), owner.party.owningKey)) + return TransactionBuilder(notary).withItems(state, Command(Commands.Create(), owner.party.owningKey)) } override fun inspectState(state: ContractState): Int = (state as State).magicNumber diff --git a/finance/isolated/src/main/kotlin/net/corda/core/node/DummyContractBackdoor.kt b/finance/isolated/src/main/kotlin/net/corda/nodeapi/DummyContractBackdoor.kt similarity index 93% rename from finance/isolated/src/main/kotlin/net/corda/core/node/DummyContractBackdoor.kt rename to finance/isolated/src/main/kotlin/net/corda/nodeapi/DummyContractBackdoor.kt index d965530bb2..2b6e9c2eac 100644 --- a/finance/isolated/src/main/kotlin/net/corda/core/node/DummyContractBackdoor.kt +++ b/finance/isolated/src/main/kotlin/net/corda/nodeapi/DummyContractBackdoor.kt @@ -1,4 +1,4 @@ -package net.corda.core.node +package net.corda.nodeapi import net.corda.core.contracts.ContractState import net.corda.core.contracts.PartyAndReference diff --git a/finance/src/main/java/net/corda/contracts/JavaCommercialPaper.java b/finance/src/main/java/net/corda/contracts/JavaCommercialPaper.java index df6b29e5be..d4910e67a9 100644 --- a/finance/src/main/java/net/corda/contracts/JavaCommercialPaper.java +++ b/finance/src/main/java/net/corda/contracts/JavaCommercialPaper.java @@ -5,34 +5,27 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.Iterables; import kotlin.Pair; import kotlin.Unit; +import net.corda.contracts.asset.Cash; import net.corda.contracts.asset.CashKt; import net.corda.core.contracts.*; -import net.corda.core.contracts.TransactionForContract.InOutGroup; -import net.corda.core.contracts.clauses.AnyOf; -import net.corda.core.contracts.clauses.Clause; -import net.corda.core.contracts.clauses.ClauseVerifier; -import net.corda.core.contracts.clauses.GroupClauseVerifier; import net.corda.core.crypto.SecureHash; import net.corda.core.crypto.testing.NullPublicKey; import net.corda.core.identity.AbstractParty; import net.corda.core.identity.AnonymousParty; import net.corda.core.identity.Party; -import net.corda.core.node.services.VaultService; +import net.corda.core.node.ServiceHub; +import net.corda.core.transactions.LedgerTransaction; import net.corda.core.transactions.TransactionBuilder; import org.jetbrains.annotations.NotNull; import org.jetbrains.annotations.Nullable; import java.time.Instant; -import java.util.Collections; -import java.util.Currency; -import java.util.List; -import java.util.Set; +import java.util.*; import java.util.stream.Collectors; import static net.corda.core.contracts.ContractsDSL.requireSingleCommand; import static net.corda.core.contracts.ContractsDSL.requireThat; - /** * This is a Java version of the CommercialPaper contract (chosen because it's simple). This demonstrates how the * use of Kotlin for implementation of the framework does not impose the same language choice on contract developers. @@ -69,8 +62,8 @@ public class JavaCommercialPaper implements Contract { @NotNull @Override - public Pair withNewOwner(@NotNull AbstractParty newOwner) { - return new Pair<>(new Commands.Move(), new State(this.issuance, newOwner, this.faceValue, this.maturityDate)); + public CommandAndState withNewOwner(@NotNull AbstractParty newOwner) { + return new CommandAndState(new Commands.Move(), new State(this.issuance, newOwner, this.faceValue, this.maturityDate)); } public ICommercialPaperState withFaceValue(Amount> newFaceValue) { @@ -105,16 +98,18 @@ public class JavaCommercialPaper implements Contract { } @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; + public boolean equals(Object that) { + if (this == that) return true; + if (that == null || getClass() != that.getClass()) return false; - State state = (State) o; + State state = (State) that; if (issuance != null ? !issuance.equals(state.issuance) : state.issuance != null) return false; if (owner != null ? !owner.equals(state.owner) : state.owner != null) return false; if (faceValue != null ? !faceValue.equals(state.faceValue) : state.faceValue != null) return false; - return !(maturityDate != null ? !maturityDate.equals(state.maturityDate) : state.maturityDate != null); + if (maturityDate != null ? !maturityDate.equals(state.maturityDate) : state.maturityDate != null) + return false; + return true; } @Override @@ -137,138 +132,6 @@ public class JavaCommercialPaper implements Contract { } } - public interface Clauses { - @SuppressWarnings("unused") - class Group extends GroupClauseVerifier { - // This complains because we're passing generic types into a varargs, but it is valid so we suppress the - // warning. - @SuppressWarnings("unchecked") - Group() { - super(new AnyOf<>( - new Clauses.Redeem(), - new Clauses.Move(), - new Clauses.Issue() - )); - } - - @NotNull - @Override - public List> groupStates(@NotNull TransactionForContract tx) { - return tx.groupStates(State.class, State::withoutOwner); - } - } - - @SuppressWarnings("unused") - class Move extends Clause { - @NotNull - @Override - public Set> getRequiredCommands() { - return Collections.singleton(Commands.Move.class); - } - - @NotNull - @Override - public Set verify(@NotNull TransactionForContract tx, - @NotNull List inputs, - @NotNull List outputs, - @NotNull List> commands, - State groupingKey) { - AuthenticatedObject cmd = requireSingleCommand(tx.getCommands(), Commands.Move.class); - // There should be only a single input due to aggregation above - State input = Iterables.getOnlyElement(inputs); - - if (!cmd.getSigners().contains(input.getOwner().getOwningKey())) - throw new IllegalStateException("Failed requirement: the transaction is signed by the owner of the CP"); - - // Check the output CP state is the same as the input state, ignoring the owner field. - if (outputs.size() != 1) { - throw new IllegalStateException("the state is propagated"); - } - // Don't need to check anything else, as if outputs.size == 1 then the output is equal to - // the input ignoring the owner field due to the grouping. - return Collections.singleton(cmd.getValue()); - } - } - - @SuppressWarnings("unused") - class Redeem extends Clause { - @NotNull - @Override - public Set> getRequiredCommands() { - return Collections.singleton(Commands.Redeem.class); - } - - @NotNull - @Override - public Set verify(@NotNull TransactionForContract tx, - @NotNull List inputs, - @NotNull List outputs, - @NotNull List> commands, - State groupingKey) { - AuthenticatedObject cmd = requireSingleCommand(tx.getCommands(), Commands.Redeem.class); - - // There should be only a single input due to aggregation above - State input = Iterables.getOnlyElement(inputs); - - if (!cmd.getSigners().contains(input.getOwner().getOwningKey())) - throw new IllegalStateException("Failed requirement: the transaction is signed by the owner of the CP"); - - TimeWindow timeWindow = tx.getTimeWindow(); - Instant time = null == timeWindow - ? null - : timeWindow.getUntilTime(); - Amount> received = CashKt.sumCashBy(tx.getOutputs(), input.getOwner()); - - requireThat(require -> { - require.using("must be timestamped", timeWindow != null); - require.using("received amount equals the face value: " - + received + " vs " + input.getFaceValue(), received.equals(input.getFaceValue())); - require.using("the paper must have matured", time != null && !time.isBefore(input.getMaturityDate())); - require.using("the received amount equals the face value", input.getFaceValue().equals(received)); - require.using("the paper must be destroyed", outputs.isEmpty()); - return Unit.INSTANCE; - }); - - return Collections.singleton(cmd.getValue()); - } - } - - @SuppressWarnings("unused") - class Issue extends Clause { - @NotNull - @Override - public Set> getRequiredCommands() { - return Collections.singleton(Commands.Issue.class); - } - - @NotNull - @Override - public Set verify(@NotNull TransactionForContract tx, - @NotNull List inputs, - @NotNull List outputs, - @NotNull List> commands, - State groupingKey) { - AuthenticatedObject cmd = requireSingleCommand(tx.getCommands(), Commands.Issue.class); - State output = Iterables.getOnlyElement(outputs); - TimeWindow timeWindowCommand = tx.getTimeWindow(); - Instant time = null == timeWindowCommand - ? null - : timeWindowCommand.getUntilTime(); - - requireThat(require -> { - require.using("output values sum to more than the inputs", inputs.isEmpty()); - require.using("output values sum to more than the inputs", output.faceValue.getQuantity() > 0); - require.using("must be timestamped", timeWindowCommand != null); - require.using("the maturity date is not in the past", time != null && time.isBefore(output.getMaturityDate())); - require.using("output states are issued by a command signer", cmd.getSigners().contains(output.issuance.getParty().getOwningKey())); - return Unit.INSTANCE; - }); - - return Collections.singleton(cmd.getValue()); - } - } - } - public interface Commands extends CommandData { class Move implements Commands { @Override @@ -293,7 +156,7 @@ public class JavaCommercialPaper implements Contract { } @NotNull - private List> extractCommands(@NotNull TransactionForContract tx) { + private List> extractCommands(@NotNull LedgerTransaction tx) { return tx.getCommands() .stream() .filter((AuthenticatedObject command) -> command.getValue() instanceof Commands) @@ -302,8 +165,76 @@ public class JavaCommercialPaper implements Contract { } @Override - public void verify(@NotNull TransactionForContract tx) throws IllegalArgumentException { - ClauseVerifier.verifyClause(tx, new Clauses.Group(), extractCommands(tx)); + public void verify(@NotNull LedgerTransaction tx) throws IllegalArgumentException { + + // Group by everything except owner: any modification to the CP at all is considered changing it fundamentally. + final List> groups = tx.groupStates(State.class, State::withoutOwner); + + // There are two possible things that can be done with this CP. The first is trading it. The second is redeeming + // it for cash on or after the maturity date. + final List> commands = tx.getCommands().stream().filter( + it -> { + return it.getValue() instanceof Commands; + } + ).collect(Collectors.toList()); + final AuthenticatedObject command = Iterables.getOnlyElement(commands); + final TimeWindow timeWindow = tx.getTimeWindow(); + + for (final LedgerTransaction.InOutGroup group : groups) { + final List inputs = group.getInputs(); + final List outputs = group.getOutputs(); + if (command.getValue() instanceof Commands.Move) { + final AuthenticatedObject cmd = requireSingleCommand(tx.getCommands(), Commands.Move.class); + // There should be only a single input due to aggregation above + final State input = Iterables.getOnlyElement(inputs); + + if (!cmd.getSigners().contains(input.getOwner().getOwningKey())) + throw new IllegalStateException("Failed requirement: the transaction is signed by the owner of the CP"); + + // Check the output CP state is the same as the input state, ignoring the owner field. + if (outputs.size() != 1) { + throw new IllegalStateException("the state is propagated"); + } + } else if (command.getValue() instanceof Commands.Redeem) { + final AuthenticatedObject cmd = requireSingleCommand(tx.getCommands(), Commands.Redeem.class); + + // There should be only a single input due to aggregation above + final State input = Iterables.getOnlyElement(inputs); + + if (!cmd.getSigners().contains(input.getOwner().getOwningKey())) + throw new IllegalStateException("Failed requirement: the transaction is signed by the owner of the CP"); + + final Instant time = null == timeWindow + ? null + : timeWindow.getUntilTime(); + final Amount> received = CashKt.sumCashBy(tx.getOutputs().stream().map(TransactionState::getData).collect(Collectors.toList()), input.getOwner()); + + requireThat(require -> { + require.using("must be timestamped", timeWindow != null); + require.using("received amount equals the face value: " + + received + " vs " + input.getFaceValue(), received.equals(input.getFaceValue())); + require.using("the paper must have matured", time != null && !time.isBefore(input.getMaturityDate())); + require.using("the received amount equals the face value", input.getFaceValue().equals(received)); + require.using("the paper must be destroyed", outputs.isEmpty()); + return Unit.INSTANCE; + }); + } else if (command.getValue() instanceof Commands.Issue) { + final AuthenticatedObject cmd = requireSingleCommand(tx.getCommands(), Commands.Issue.class); + final State output = Iterables.getOnlyElement(outputs); + final Instant time = null == timeWindow + ? null + : timeWindow.getUntilTime(); + + requireThat(require -> { + require.using("output values sum to more than the inputs", inputs.isEmpty()); + require.using("output values sum to more than the inputs", output.faceValue.getQuantity() > 0); + require.using("must be timestamped", timeWindow != null); + require.using("the maturity date is not in the past", time != null && time.isBefore(output.getMaturityDate())); + require.using("output states are issued by a command signer", cmd.getSigners().contains(output.issuance.getParty().getOwningKey())); + return Unit.INSTANCE; + }); + } + } } @NotNull @@ -316,7 +247,7 @@ public class JavaCommercialPaper implements Contract { public TransactionBuilder generateIssue(@NotNull PartyAndReference issuance, @NotNull Amount> faceValue, @Nullable Instant maturityDate, @NotNull Party notary, Integer encumbrance) { State state = new State(issuance, issuance.getParty(), faceValue, maturityDate); TransactionState output = new TransactionState<>(state, notary, encumbrance); - return new TransactionType.General.Builder(notary).withItems(output, new Command(new Commands.Issue(), issuance.getParty().getOwningKey())); + return new TransactionBuilder(notary).withItems(output, new Command<>(new Commands.Issue(), issuance.getParty().getOwningKey())); } public TransactionBuilder generateIssue(@NotNull PartyAndReference issuance, @NotNull Amount> faceValue, @Nullable Instant maturityDate, @NotNull Party notary) { @@ -324,15 +255,15 @@ public class JavaCommercialPaper implements Contract { } @Suspendable - public void generateRedeem(TransactionBuilder tx, StateAndRef paper, VaultService vault) throws InsufficientBalanceException { - vault.generateSpend(tx, StructuresKt.withoutIssuer(paper.getState().getData().getFaceValue()), paper.getState().getData().getOwner(), null); + public void generateRedeem(TransactionBuilder tx, StateAndRef paper, ServiceHub services) throws InsufficientBalanceException { + Cash.generateSpend(services, tx, Structures.withoutIssuer(paper.getState().getData().getFaceValue()), paper.getState().getData().getOwner(), Collections.EMPTY_SET); tx.addInputState(paper); - tx.addCommand(new Command(new Commands.Redeem(), paper.getState().getData().getOwner().getOwningKey())); + tx.addCommand(new Command<>(new Commands.Redeem(), paper.getState().getData().getOwner().getOwningKey())); } public void generateMove(TransactionBuilder tx, StateAndRef paper, AbstractParty newOwner) { tx.addInputState(paper); tx.addOutputState(new TransactionState<>(new State(paper.getState().getData().getIssuance(), newOwner, paper.getState().getData().getFaceValue(), paper.getState().getData().getMaturityDate()), paper.getState().getNotary(), paper.getState().getEncumbrance())); - tx.addCommand(new Command(new Commands.Move(), paper.getState().getData().getOwner().getOwningKey())); + tx.addCommand(new Command<>(new Commands.Move(), paper.getState().getData().getOwner().getOwningKey())); } } diff --git a/finance/src/main/kotlin/net/corda/contracts/CommercialPaper.kt b/finance/src/main/kotlin/net/corda/contracts/CommercialPaper.kt index f3d2fcdb2e..f94ff09dd5 100644 --- a/finance/src/main/kotlin/net/corda/contracts/CommercialPaper.kt +++ b/finance/src/main/kotlin/net/corda/contracts/CommercialPaper.kt @@ -1,24 +1,21 @@ package net.corda.contracts import co.paralleluniverse.fibers.Suspendable +import net.corda.contracts.asset.Cash import net.corda.contracts.asset.sumCashBy -import net.corda.contracts.clause.AbstractIssue import net.corda.core.contracts.* -import net.corda.core.contracts.clauses.AnyOf -import net.corda.core.contracts.clauses.Clause -import net.corda.core.contracts.clauses.GroupClauseVerifier -import net.corda.core.contracts.clauses.verifyClause import net.corda.core.crypto.SecureHash +import net.corda.core.crypto.testing.NULL_PARTY import net.corda.core.crypto.toBase58String import net.corda.core.identity.AbstractParty import net.corda.core.identity.Party -import net.corda.core.node.services.VaultService -import net.corda.core.crypto.random63BitValue +import net.corda.core.internal.Emoji +import net.corda.core.node.ServiceHub import net.corda.core.schemas.MappedSchema import net.corda.core.schemas.PersistentState import net.corda.core.schemas.QueryableState +import net.corda.core.transactions.LedgerTransaction import net.corda.core.transactions.TransactionBuilder -import net.corda.core.utilities.Emoji import net.corda.schemas.CommercialPaperSchemaV1 import java.time.Instant import java.util.* @@ -44,7 +41,6 @@ import java.util.* * which may need to be tracked. That, in turn, requires validation logic (there is a bean validator that knows how * to do this in the Apache BVal project). */ - val CP_PROGRAM_ID = CommercialPaper() // TODO: Generalise the notion of an owned instrument into a superclass/supercontract. Consider composition vs inheritance. @@ -52,13 +48,6 @@ class CommercialPaper : Contract { // TODO: should reference the content of the legal agreement, not its URI override val legalContractReference: SecureHash = SecureHash.sha256("https://en.wikipedia.org/wiki/Commercial_paper") - data class Terms( - val asset: Issued, - val maturityDate: Instant - ) - - override fun verify(tx: TransactionForContract) = verifyClause(tx, Clauses.Group(), tx.commands.select()) - data class State( val issuance: PartyAndReference, override val owner: AbstractParty, @@ -66,13 +55,10 @@ class CommercialPaper : Contract { val maturityDate: Instant ) : OwnableState, QueryableState, ICommercialPaperState { override val contract = CP_PROGRAM_ID - override val participants: List - get() = listOf(owner) + override val participants = listOf(owner) - val token: Issued - get() = Issued(issuance, Terms(faceValue.token, maturityDate)) - - override fun withNewOwner(newOwner: AbstractParty) = Pair(Commands.Move(), copy(owner = newOwner)) + override fun withNewOwner(newOwner: AbstractParty) = CommandAndState(Commands.Move(), copy(owner = newOwner)) + fun withoutOwner() = copy(owner = NULL_PARTY) override fun toString() = "${Emoji.newspaper}CommercialPaper(of $faceValue redeemable on $maturityDate by '$issuance', owned by $owner)" // Although kotlin is smart enough not to need these, as we are using the ICommercialPaperState, we need to declare them explicitly for use later, @@ -81,7 +67,6 @@ class CommercialPaper : Contract { override fun withFaceValue(newFaceValue: Amount>): ICommercialPaperState = copy(faceValue = newFaceValue) override fun withMaturityDate(newMaturityDate: Instant): ICommercialPaperState = copy(maturityDate = newMaturityDate) - // DOCSTART VaultIndexedQueryCriteria /** Object Relational Mapping support. */ override fun supportedSchemas(): Iterable = listOf(CommercialPaperSchemaV1) /** Additional used schemas would be added here (eg. CommercialPaperV2, ...) */ @@ -99,97 +84,76 @@ class CommercialPaper : Contract { faceValueIssuerParty = this.faceValue.token.issuer.party.owningKey.toBase58String(), faceValueIssuerRef = this.faceValue.token.issuer.reference.bytes ) - /** Additional schema mappings would be added here (eg. CommercialPaperV2, ...) */ + /** Additional schema mappings would be added here (eg. CommercialPaperV2, ...) */ else -> throw IllegalArgumentException("Unrecognised schema $schema") } } - // DOCEND VaultIndexedQueryCriteria - } - - interface Clauses { - class Group : GroupClauseVerifier>( - AnyOf( - Redeem(), - Move(), - Issue())) { - override fun groupStates(tx: TransactionForContract): List>> - = tx.groupStates> { it.token } - } - - class Issue : AbstractIssue( - { map { Amount(it.faceValue.quantity, it.token) }.sumOrThrow() }, - { token -> map { Amount(it.faceValue.quantity, it.token) }.sumOrZero(token) }) { - override val requiredCommands: Set> = setOf(Commands.Issue::class.java) - - override fun verify(tx: TransactionForContract, - inputs: List, - outputs: List, - commands: List>, - groupingKey: Issued?): Set { - val consumedCommands = super.verify(tx, inputs, outputs, commands, groupingKey) - commands.requireSingleCommand() - val timeWindow = tx.timeWindow - val time = timeWindow?.untilTime ?: throw IllegalArgumentException("Issuances must have a time-window") - - require(outputs.all { time < it.maturityDate }) { "maturity date is not in the past" } - - return consumedCommands - } - } - - class Move : Clause>() { - override val requiredCommands: Set> = setOf(Commands.Move::class.java) - - override fun verify(tx: TransactionForContract, - inputs: List, - outputs: List, - commands: List>, - groupingKey: Issued?): Set { - val command = commands.requireSingleCommand() - val input = inputs.single() - requireThat { - "the transaction is signed by the owner of the CP" using (input.owner.owningKey in command.signers) - "the state is propagated" using (outputs.size == 1) - // Don't need to check anything else, as if outputs.size == 1 then the output is equal to - // the input ignoring the owner field due to the grouping. - } - return setOf(command.value) - } - } - - class Redeem : Clause>() { - override val requiredCommands: Set> = setOf(Commands.Redeem::class.java) - - override fun verify(tx: TransactionForContract, - inputs: List, - outputs: List, - commands: List>, - groupingKey: Issued?): Set { - // TODO: This should filter commands down to those with compatible subjects (underlying product and maturity date) - // before requiring a single command - val command = commands.requireSingleCommand() - val timeWindow = tx.timeWindow - - val input = inputs.single() - val received = tx.outputs.sumCashBy(input.owner) - val time = timeWindow?.fromTime ?: throw IllegalArgumentException("Redemptions must have a time-window") - requireThat { - "the paper must have matured" using (time >= input.maturityDate) - "the received amount equals the face value" using (received == input.faceValue) - "the paper must be destroyed" using outputs.isEmpty() - "the transaction is signed by the owner of the CP" using (input.owner.owningKey in command.signers) - } - - return setOf(command.value) - } - - } } interface Commands : CommandData { - data class Move(override val contractHash: SecureHash? = null) : FungibleAsset.Commands.Move, Commands + class Move : TypeOnlyCommandData(), Commands + class Redeem : TypeOnlyCommandData(), Commands - data class Issue(override val nonce: Long = random63BitValue()) : IssueCommand, Commands + // We don't need a nonce in the issue command, because the issuance.reference field should already be unique per CP. + // However, nothing in the platform enforces that uniqueness: it's up to the issuer. + class Issue : TypeOnlyCommandData(), Commands + } + + override fun verify(tx: LedgerTransaction) { + // Group by everything except owner: any modification to the CP at all is considered changing it fundamentally. + val groups = tx.groupStates(State::withoutOwner) + + // There are two possible things that can be done with this CP. The first is trading it. The second is redeeming + // it for cash on or after the maturity date. + val command = tx.commands.requireSingleCommand() + val timeWindow: TimeWindow? = tx.timeWindow + + // Suppress compiler warning as 'key' is an unused variable when destructuring 'groups'. + @Suppress("UNUSED_VARIABLE") + for ((inputs, outputs, key) in groups) { + when (command.value) { + is Commands.Move -> { + val input = inputs.single() + requireThat { + "the transaction is signed by the owner of the CP" using (input.owner.owningKey in command.signers) + "the state is propagated" using (outputs.size == 1) + // Don't need to check anything else, as if outputs.size == 1 then the output is equal to + // the input ignoring the owner field due to the grouping. + } + } + + is Commands.Redeem -> { + // Redemption of the paper requires movement of on-ledger cash. + val input = inputs.single() + val received = tx.outputStates.sumCashBy(input.owner) + val time = timeWindow?.fromTime ?: throw IllegalArgumentException("Redemptions must have a time-window") + requireThat { + "the paper must have matured" using (time >= input.maturityDate) + "the received amount equals the face value" using (received == input.faceValue) + "the paper must be destroyed" using outputs.isEmpty() + "the transaction is signed by the owner of the CP" using (input.owner.owningKey in command.signers) + } + } + + is Commands.Issue -> { + val output = outputs.single() + val time = timeWindow?.untilTime ?: throw IllegalArgumentException("Issuances have a time-window") + requireThat { + // Don't allow people to issue commercial paper under other entities identities. + "output states are issued by a command signer" using + (output.issuance.party.owningKey in command.signers) + "output values sum to more than the inputs" using (output.faceValue.quantity > 0) + "the maturity date is not in the past" using (time < output.maturityDate) + // Don't allow an existing CP state to be replaced by this issuance. + // TODO: Consider how to handle the case of mistaken issuances, or other need to patch. + "output values sum to more than the inputs" using inputs.isEmpty() + } + } + + // TODO: Think about how to evolve contracts over time with new commands. + else -> throw IllegalArgumentException("Unrecognised command") + } + } } /** @@ -197,9 +161,10 @@ class CommercialPaper : Contract { * an existing transaction because you aren't able to issue multiple pieces of CP in a single transaction * at the moment: this restriction is not fundamental and may be lifted later. */ - fun generateIssue(issuance: PartyAndReference, faceValue: Amount>, maturityDate: Instant, notary: Party): TransactionBuilder { - val state = TransactionState(State(issuance, issuance.party, faceValue, maturityDate), notary) - return TransactionType.General.Builder(notary = notary).withItems(state, Command(Commands.Issue(), issuance.party.owningKey)) + fun generateIssue(issuance: PartyAndReference, faceValue: Amount>, maturityDate: Instant, + notary: Party): TransactionBuilder { + val state = State(issuance, issuance.party, faceValue, maturityDate) + return TransactionBuilder(notary = notary).withItems(state, Command(Commands.Issue(), issuance.party.owningKey)) } /** @@ -207,7 +172,7 @@ class CommercialPaper : Contract { */ fun generateMove(tx: TransactionBuilder, paper: StateAndRef, newOwner: AbstractParty) { tx.addInputState(paper) - tx.addOutputState(TransactionState(paper.state.data.copy(owner = newOwner), paper.state.notary)) + tx.addOutputState(paper.state.data.withOwner(newOwner)) tx.addCommand(Commands.Move(), paper.state.data.owner.owningKey) } @@ -220,17 +185,14 @@ class CommercialPaper : Contract { */ @Throws(InsufficientBalanceException::class) @Suspendable - fun generateRedeem(tx: TransactionBuilder, paper: StateAndRef, vault: VaultService) { + fun generateRedeem(tx: TransactionBuilder, paper: StateAndRef, services: ServiceHub) { // Add the cash movement using the states in our vault. - val amount = paper.state.data.faceValue.let { amount -> Amount(amount.quantity, amount.token.product) } - vault.generateSpend(tx, amount, paper.state.data.owner) + Cash.generateSpend(services, tx, paper.state.data.faceValue.withoutIssuer(), paper.state.data.owner) tx.addInputState(paper) - tx.addCommand(CommercialPaper.Commands.Redeem(), paper.state.data.owner.owningKey) + tx.addCommand(Commands.Redeem(), paper.state.data.owner.owningKey) } } infix fun CommercialPaper.State.`owned by`(owner: AbstractParty) = copy(owner = owner) infix fun CommercialPaper.State.`with notary`(notary: Party) = TransactionState(this, notary) -infix fun ICommercialPaperState.`owned by`(newOwner: AbstractParty) = withOwner(newOwner) - - +infix fun ICommercialPaperState.`owned by`(newOwner: AbstractParty) = withOwner(newOwner) \ No newline at end of file diff --git a/finance/src/main/kotlin/net/corda/contracts/CommercialPaperLegacy.kt b/finance/src/main/kotlin/net/corda/contracts/CommercialPaperLegacy.kt deleted file mode 100644 index 0c3811556b..0000000000 --- a/finance/src/main/kotlin/net/corda/contracts/CommercialPaperLegacy.kt +++ /dev/null @@ -1,135 +0,0 @@ -package net.corda.contracts - -import co.paralleluniverse.fibers.Suspendable -import net.corda.contracts.asset.sumCashBy -import net.corda.core.contracts.* -import net.corda.core.crypto.SecureHash -import net.corda.core.crypto.testing.NULL_PARTY -import net.corda.core.identity.AbstractParty -import net.corda.core.identity.Party -import net.corda.core.node.services.VaultService -import net.corda.core.transactions.TransactionBuilder -import net.corda.core.utilities.Emoji -import java.time.Instant -import java.util.* - -/** - * Legacy version of [CommercialPaper] that includes the full verification logic itself, rather than breaking it - * into clauses. This is here just as an example for the contract tutorial. - */ - -val CP_LEGACY_PROGRAM_ID = CommercialPaperLegacy() - -// TODO: Generalise the notion of an owned instrument into a superclass/supercontract. Consider composition vs inheritance. -class CommercialPaperLegacy : Contract { - // TODO: should reference the content of the legal agreement, not its URI - override val legalContractReference: SecureHash = SecureHash.sha256("https://en.wikipedia.org/wiki/Commercial_paper") - - data class State( - val issuance: PartyAndReference, - override val owner: AbstractParty, - val faceValue: Amount>, - val maturityDate: Instant - ) : OwnableState, ICommercialPaperState { - override val contract = CP_LEGACY_PROGRAM_ID - override val participants = listOf(owner) - - fun withoutOwner() = copy(owner = NULL_PARTY) - override fun withNewOwner(newOwner: AbstractParty) = Pair(Commands.Move(), copy(owner = newOwner)) - override fun toString() = "${Emoji.newspaper}CommercialPaper(of $faceValue redeemable on $maturityDate by '$issuance', owned by $owner)" - - // Although kotlin is smart enough not to need these, as we are using the ICommercialPaperState, we need to declare them explicitly for use later, - override fun withOwner(newOwner: AbstractParty): ICommercialPaperState = copy(owner = newOwner) - - override fun withFaceValue(newFaceValue: Amount>): ICommercialPaperState = copy(faceValue = newFaceValue) - override fun withMaturityDate(newMaturityDate: Instant): ICommercialPaperState = copy(maturityDate = newMaturityDate) - } - - interface Commands : CommandData { - class Move : TypeOnlyCommandData(), Commands - - class Redeem : TypeOnlyCommandData(), Commands - // We don't need a nonce in the issue command, because the issuance.reference field should already be unique per CP. - // However, nothing in the platform enforces that uniqueness: it's up to the issuer. - class Issue : TypeOnlyCommandData(), Commands - } - - override fun verify(tx: TransactionForContract) { - // Group by everything except owner: any modification to the CP at all is considered changing it fundamentally. - val groups = tx.groupStates(State::withoutOwner) - - // There are two possible things that can be done with this CP. The first is trading it. The second is redeeming - // it for cash on or after the maturity date. - val command = tx.commands.requireSingleCommand() - val timeWindow: TimeWindow? = tx.timeWindow - - // Suppress compiler warning as 'key' is an unused variable when destructuring 'groups'. - @Suppress("UNUSED_VARIABLE") - for ((inputs, outputs, key) in groups) { - when (command.value) { - is Commands.Move -> { - val input = inputs.single() - requireThat { - "the transaction is signed by the owner of the CP" using (input.owner.owningKey in command.signers) - "the state is propagated" using (outputs.size == 1) - // Don't need to check anything else, as if outputs.size == 1 then the output is equal to - // the input ignoring the owner field due to the grouping. - } - } - - is Commands.Redeem -> { - // Redemption of the paper requires movement of on-ledger cash. - val input = inputs.single() - val received = tx.outputs.sumCashBy(input.owner) - val time = timeWindow?.fromTime ?: throw IllegalArgumentException("Redemptions must have a time-window") - requireThat { - "the paper must have matured" using (time >= input.maturityDate) - "the received amount equals the face value" using (received == input.faceValue) - "the paper must be destroyed" using outputs.isEmpty() - "the transaction is signed by the owner of the CP" using (input.owner.owningKey in command.signers) - } - } - - is Commands.Issue -> { - val output = outputs.single() - val time = timeWindow?.untilTime ?: throw IllegalArgumentException("Issuances have a time-window") - requireThat { - // Don't allow people to issue commercial paper under other entities identities. - "output states are issued by a command signer" using - (output.issuance.party.owningKey in command.signers) - "output values sum to more than the inputs" using (output.faceValue.quantity > 0) - "the maturity date is not in the past" using (time < output.maturityDate) - // Don't allow an existing CP state to be replaced by this issuance. - // TODO: this has a weird/incorrect assertion string because it doesn't quite match the logic in the clause version. - // TODO: Consider how to handle the case of mistaken issuances, or other need to patch. - "output values sum to more than the inputs" using inputs.isEmpty() - } - } - - // TODO: Think about how to evolve contracts over time with new commands. - else -> throw IllegalArgumentException("Unrecognised command") - } - } - } - - fun generateIssue(issuance: PartyAndReference, faceValue: Amount>, maturityDate: Instant, - notary: Party): TransactionBuilder { - val state = State(issuance, issuance.party, faceValue, maturityDate) - return TransactionBuilder(notary = notary).withItems(state, Command(Commands.Issue(), issuance.party.owningKey)) - } - - fun generateMove(tx: TransactionBuilder, paper: StateAndRef, newOwner: AbstractParty) { - tx.addInputState(paper) - tx.addOutputState(paper.state.data.withOwner(newOwner)) - tx.addCommand(Command(Commands.Move(), paper.state.data.owner.owningKey)) - } - - @Throws(InsufficientBalanceException::class) - @Suspendable - fun generateRedeem(tx: TransactionBuilder, paper: StateAndRef, vault: VaultService) { - // Add the cash movement using the states in our vault. - vault.generateSpend(tx, paper.state.data.faceValue.withoutIssuer(), paper.state.data.owner) - tx.addInputState(paper) - tx.addCommand(Command(Commands.Redeem(), paper.state.data.owner.owningKey)) - } -} diff --git a/finance/src/main/kotlin/net/corda/contracts/FinanceTypes.kt b/finance/src/main/kotlin/net/corda/contracts/FinanceTypes.kt index 82b1ff1a41..c4f4fe6546 100644 --- a/finance/src/main/kotlin/net/corda/contracts/FinanceTypes.kt +++ b/finance/src/main/kotlin/net/corda/contracts/FinanceTypes.kt @@ -8,15 +8,13 @@ import com.fasterxml.jackson.databind.JsonSerializer import com.fasterxml.jackson.databind.SerializerProvider import com.fasterxml.jackson.databind.annotation.JsonDeserialize import com.fasterxml.jackson.databind.annotation.JsonSerialize +import net.corda.contracts.asset.CommodityContract import net.corda.core.contracts.CommandData import net.corda.core.contracts.LinearState import net.corda.core.contracts.StateAndRef import net.corda.core.contracts.TokenizableAssetInfo -import net.corda.core.identity.AbstractParty import net.corda.core.identity.Party import net.corda.core.node.services.ServiceType -import net.corda.core.node.services.VaultService -import net.corda.core.node.services.linearHeadsOfType import net.corda.core.serialization.CordaSerializable import net.corda.core.transactions.TransactionBuilder import java.math.BigDecimal @@ -25,7 +23,6 @@ import java.time.LocalDate import java.time.format.DateTimeFormatter import java.util.* - //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// // // Interest rate fixes @@ -83,7 +80,6 @@ data class Tenor(val name: String) { TimeUnit.Week -> startDate.plusWeeks(amount.toLong()) TimeUnit.Month -> startDate.plusMonths(amount.toLong()) TimeUnit.Year -> startDate.plusYears(amount.toLong()) - else -> throw IllegalStateException("Invalid tenor time unit: $unit") } // Move date to the closest business day when it falls on a weekend/holiday val adjustedMaturityDate = calendar.applyRollConvention(maturityDate, DateRollConvention.ModifiedFollowing) @@ -394,9 +390,6 @@ data class Commodity(val commodityCode: String, * implementation of general flows that manipulate many agreement types. */ interface DealState : LinearState { - /** Human readable well known reference (e.g. trade reference) */ - val ref: String - /** * Generate a partial transaction representing an agreement (command) to this deal, allowing a general * deal/agreement flow to generate the necessary transaction for potential implementations. @@ -409,12 +402,6 @@ interface DealState : LinearState { fun generateAgreement(notary: Party): TransactionBuilder } -// TODO: Remove this from the interface -@Deprecated("This function will be removed in a future milestone", ReplaceWith("queryBy(LinearStateQueryCriteria(dealPartyName = listOf()))")) -inline fun VaultService.dealsWith(party: AbstractParty) = linearHeadsOfType().values.filter { - it.state.data.participants.any { it == party } -} - /** * Interface adding fixing specific methods. */ diff --git a/finance/src/main/kotlin/net/corda/contracts/GetBalances.kt b/finance/src/main/kotlin/net/corda/contracts/GetBalances.kt new file mode 100644 index 0000000000..d1245403a5 --- /dev/null +++ b/finance/src/main/kotlin/net/corda/contracts/GetBalances.kt @@ -0,0 +1,76 @@ +@file:JvmName("GetBalances") + +package net.corda.contracts + +import net.corda.core.contracts.Amount +import net.corda.core.contracts.FungibleAsset +import net.corda.core.messaging.CordaRPCOps +import net.corda.core.messaging.vaultQueryBy +import net.corda.core.node.ServiceHub +import net.corda.core.node.services.Vault +import net.corda.core.node.services.queryBy +import net.corda.core.node.services.vault.QueryCriteria +import net.corda.core.node.services.vault.Sort +import net.corda.core.node.services.vault.builder +import net.corda.schemas.CashSchemaV1 +import java.util.* +import kotlin.collections.LinkedHashMap + +private fun generateCashSumCriteria(currency: Currency): QueryCriteria { + val sum = builder { CashSchemaV1.PersistentCashState::pennies.sum(groupByColumns = listOf(CashSchemaV1.PersistentCashState::currency)) } + val sumCriteria = QueryCriteria.VaultCustomQueryCriteria(sum) + + val ccyIndex = builder { CashSchemaV1.PersistentCashState::currency.equal(currency.currencyCode) } + val ccyCriteria = QueryCriteria.VaultCustomQueryCriteria(ccyIndex) + return sumCriteria.and(ccyCriteria) +} + +private fun generateCashSumsCriteria(): QueryCriteria { + val sum = builder { + CashSchemaV1.PersistentCashState::pennies.sum(groupByColumns = listOf(CashSchemaV1.PersistentCashState::currency), + orderBy = Sort.Direction.DESC) + } + return QueryCriteria.VaultCustomQueryCriteria(sum) +} + +private fun rowsToAmount(currency: Currency, rows: Vault.Page>): Amount { + return if (rows.otherResults.isEmpty()) { + Amount(0L, currency) + } else { + require(rows.otherResults.size == 2) + require(rows.otherResults[1] == currency.currencyCode) + @Suppress("UNCHECKED_CAST") + val quantity = rows.otherResults[0] as Long + Amount(quantity, currency) + } +} + +private fun rowsToBalances(rows: List): Map> { + val balances = LinkedHashMap>() + for (index in 0..rows.size - 1 step 2) { + val ccy = Currency.getInstance(rows[index + 1] as String) + balances[ccy] = Amount(rows[index] as Long, ccy) + } + return balances +} + +fun CordaRPCOps.getCashBalance(currency: Currency): Amount { + val results = this.vaultQueryByCriteria(generateCashSumCriteria(currency), FungibleAsset::class.java) + return rowsToAmount(currency, results) +} + +fun ServiceHub.getCashBalance(currency: Currency): Amount { + val results = this.vaultQueryService.queryBy>(generateCashSumCriteria(currency)) + return rowsToAmount(currency, results) +} + +fun CordaRPCOps.getCashBalances(): Map> { + val sums = this.vaultQueryBy>(generateCashSumsCriteria()).otherResults + return rowsToBalances(sums) +} + +fun ServiceHub.getCashBalances(): Map> { + val sums = this.vaultQueryService.queryBy>(generateCashSumsCriteria()).otherResults + return rowsToBalances(sums) +} + diff --git a/finance/src/main/kotlin/net/corda/contracts/asset/Cash.kt b/finance/src/main/kotlin/net/corda/contracts/asset/Cash.kt index f47fb51ec1..02832f6893 100644 --- a/finance/src/main/kotlin/net/corda/contracts/asset/Cash.kt +++ b/finance/src/main/kotlin/net/corda/contracts/asset/Cash.kt @@ -1,13 +1,8 @@ package net.corda.contracts.asset -import net.corda.contracts.clause.AbstractConserveAmount -import net.corda.contracts.clause.AbstractIssue -import net.corda.contracts.clause.NoZeroSizedOutputs +import co.paralleluniverse.fibers.Suspendable +import co.paralleluniverse.strands.Strand import net.corda.core.contracts.* -import net.corda.core.contracts.clauses.AllOf -import net.corda.core.contracts.clauses.FirstOf -import net.corda.core.contracts.clauses.GroupClauseVerifier -import net.corda.core.contracts.clauses.verifyClause import net.corda.core.crypto.SecureHash import net.corda.core.crypto.entropyToKeyPair import net.corda.core.crypto.newSecureRandom @@ -15,16 +10,28 @@ import net.corda.core.crypto.testing.NULL_PARTY import net.corda.core.crypto.toBase58String import net.corda.core.identity.AbstractParty import net.corda.core.identity.Party +import net.corda.core.internal.Emoji +import net.corda.core.node.ServiceHub +import net.corda.core.node.services.StatesNotAvailableException import net.corda.core.schemas.MappedSchema import net.corda.core.schemas.PersistentState import net.corda.core.schemas.QueryableState -import net.corda.core.serialization.CordaSerializable +import net.corda.core.serialization.SerializationDefaults +import net.corda.core.serialization.deserialize +import net.corda.core.transactions.LedgerTransaction import net.corda.core.transactions.TransactionBuilder -import net.corda.core.utilities.Emoji +import net.corda.core.utilities.OpaqueBytes +import net.corda.core.utilities.toHexString +import net.corda.core.utilities.toNonEmptySet +import net.corda.core.utilities.trace import net.corda.schemas.CashSchemaV1 import org.bouncycastle.asn1.x500.X500Name import java.math.BigInteger +import java.security.PublicKey +import java.sql.SQLException import java.util.* +import java.util.concurrent.locks.ReentrantLock +import kotlin.concurrent.withLock ///////////////////////////////////////////////////////////////////////////////////////////////////////////////////// // @@ -60,33 +67,11 @@ class Cash : OnLedgerAsset() { */ // DOCSTART 2 override val legalContractReference: SecureHash = SecureHash.sha256("https://www.big-book-of-banking-law.gov/cash-claims.html") + // DOCEND 2 override fun extractCommands(commands: Collection>): List> = commands.select() - interface Clauses { - class Group : GroupClauseVerifier>(AllOf>( - NoZeroSizedOutputs(), - FirstOf>( - Issue(), - ConserveAmount()) - ) - ) { - override fun groupStates(tx: TransactionForContract): List>> - = tx.groupStates> { it.amount.token } - } - - class Issue : AbstractIssue( - sum = { sumCash() }, - sumOrZero = { sumCashOrZero(it) } - ) { - override val requiredCommands: Set> = setOf(Commands.Issue::class.java) - } - - @CordaSerializable - class ConserveAmount : AbstractConserveAmount() - } - // DOCSTART 1 /** A state representing a cash claim against some party. */ data class State( @@ -107,7 +92,7 @@ class Cash : OnLedgerAsset() { override fun toString() = "${Emoji.bagOfCash}Cash($amount at ${amount.token.issuer} owned by $owner)" - override fun withNewOwner(newOwner: AbstractParty) = Pair(Commands.Move(), copy(owner = newOwner)) + override fun withNewOwner(newOwner: AbstractParty) = CommandAndState(Commands.Move(), copy(owner = newOwner)) /** Object Relational Mapping support. */ override fun generateMappedObject(schema: MappedSchema): PersistentState { @@ -119,7 +104,7 @@ class Cash : OnLedgerAsset() { issuerParty = this.amount.token.issuer.party.owningKey.toBase58String(), issuerRef = this.amount.token.issuer.reference.bytes ) - /** Additional schema mappings would be added here (eg. CashSchemaV2, CashSchemaV3, ...) */ + /** Additional schema mappings would be added here (eg. CashSchemaV2, CashSchemaV3, ...) */ else -> throw IllegalArgumentException("Unrecognised schema $schema") } } @@ -164,7 +149,7 @@ class Cash : OnLedgerAsset() { * Puts together an issuance transaction for the specified amount that starts out being owned by the given pubkey. */ fun generateIssue(tx: TransactionBuilder, amount: Amount>, owner: AbstractParty, notary: Party) - = generateIssue(tx, TransactionState(State(amount, owner), notary), generateIssueCommand()) + = generateIssue(tx, TransactionState(State(amount, owner), notary), generateIssueCommand()) override fun deriveState(txState: TransactionState, amount: Amount>, owner: AbstractParty) = txState.copy(data = txState.data.copy(amount = amount, owner = owner)) @@ -173,8 +158,230 @@ class Cash : OnLedgerAsset() { override fun generateIssueCommand() = Commands.Issue() override fun generateMoveCommand() = Commands.Move() - override fun verify(tx: TransactionForContract) - = verifyClause(tx, Clauses.Group(), extractCommands(tx.commands)) + override fun verify(tx: LedgerTransaction) { + // Each group is a set of input/output states with distinct (reference, currency) attributes. These types + // of cash are not fungible and must be kept separated for bookkeeping purposes. + val groups = tx.groupStates { it: Cash.State -> it.amount.token } + + for ((inputs, outputs, key) in groups) { + // Either inputs or outputs could be empty. + val issuer = key.issuer + val currency = key.product + + requireThat { + "there are no zero sized outputs" using (outputs.none { it.amount.quantity == 0L }) + } + + val issueCommand = tx.commands.select().firstOrNull() + if (issueCommand != null) { + verifyIssueCommand(inputs, outputs, tx, issueCommand, currency, issuer) + } else { + val inputAmount = inputs.sumCashOrNull() ?: throw IllegalArgumentException("there is at least one cash input for this group") + val outputAmount = outputs.sumCashOrZero(Issued(issuer, currency)) + + // If we want to remove cash from the ledger, that must be signed for by the issuer. + // A mis-signed or duplicated exit command will just be ignored here and result in the exit amount being zero. + val exitKeys: Set = inputs.flatMap { it.exitKeys }.toSet() + val exitCommand = tx.commands.select(parties = null, signers = exitKeys).filter { it.value.amount.token == key }.singleOrNull() + val amountExitingLedger = exitCommand?.value?.amount ?: Amount(0, Issued(issuer, currency)) + + requireThat { + "there are no zero sized inputs" using inputs.none { it.amount.quantity == 0L } + "for reference ${issuer.reference} at issuer ${issuer.party} the amounts balance: ${inputAmount.quantity} - ${amountExitingLedger.quantity} != ${outputAmount.quantity}" using + (inputAmount == outputAmount + amountExitingLedger) + } + + verifyMoveCommand(inputs, tx.commands) + } + } + } + + private fun verifyIssueCommand(inputs: List, + outputs: List, + tx: LedgerTransaction, + issueCommand: AuthenticatedObject, + currency: Currency, + issuer: PartyAndReference) { + // If we have an issue command, perform special processing: the group is allowed to have no inputs, + // and the output states must have a deposit reference owned by the signer. + // + // Whilst the transaction *may* have no inputs, it can have them, and in this case the outputs must + // sum to more than the inputs. An issuance of zero size is not allowed. + // + // Note that this means literally anyone with access to the network can issue cash claims of arbitrary + // amounts! It is up to the recipient to decide if the backing party is trustworthy or not, via some + // as-yet-unwritten identity service. See ADP-22 for discussion. + + // The grouping ensures that all outputs have the same deposit reference and currency. + val inputAmount = inputs.sumCashOrZero(Issued(issuer, currency)) + val outputAmount = outputs.sumCash() + val cashCommands = tx.commands.select() + requireThat { + "the issue command has a nonce" using (issueCommand.value.nonce != 0L) + // TODO: This doesn't work with the trader demo, so use the underlying key instead + // "output states are issued by a command signer" by (issuer.party in issueCommand.signingParties) + "output states are issued by a command signer" using (issuer.party.owningKey in issueCommand.signers) + "output values sum to more than the inputs" using (outputAmount > inputAmount) + "there is only a single issue command" using (cashCommands.count() == 1) + } + } + + companion object { + // coin selection retry loop counter, sleep (msecs) and lock for selecting states + private val MAX_RETRIES = 5 + private val RETRY_SLEEP = 100 + private val spendLock: ReentrantLock = ReentrantLock() + /** + * Generate a transaction that moves an amount of currency to the given pubkey. + * + * Note: an [Amount] of [Currency] is only fungible for a given Issuer Party within a [FungibleAsset] + * + * @param services The [ServiceHub] to provide access to the database session. + * @param tx A builder, which may contain inputs, outputs and commands already. The relevant components needed + * to move the cash will be added on top. + * @param amount How much currency to send. + * @param to a key of the recipient. + * @param onlyFromParties if non-null, the asset states will be filtered to only include those issued by the set + * of given parties. This can be useful if the party you're trying to pay has expectations + * about which type of asset claims they are willing to accept. + * @return A [Pair] of the same transaction builder passed in as [tx], and the list of keys that need to sign + * the resulting transaction for it to be valid. + * @throws InsufficientBalanceException when a cash spending transaction fails because + * there is insufficient quantity for a given currency (and optionally set of Issuer Parties). + */ + @JvmStatic + @Throws(InsufficientBalanceException::class) + @Suspendable + fun generateSpend(services: ServiceHub, + tx: TransactionBuilder, + amount: Amount, + to: AbstractParty, + onlyFromParties: Set = emptySet()): Pair> { + + fun deriveState(txState: TransactionState, amt: Amount>, owner: AbstractParty) + = txState.copy(data = txState.data.copy(amount = amt, owner = owner)) + + // Retrieve unspent and unlocked cash states that meet our spending criteria. + val acceptableCoins = Cash.unconsumedCashStatesForSpending(services, amount, onlyFromParties, tx.notary, tx.lockId) + return OnLedgerAsset.generateSpend(tx, amount, to, acceptableCoins, + { state, quantity, owner -> deriveState(state, quantity, owner) }, + { Cash().generateMoveCommand() }) + + } + + /** + * An optimised query to gather Cash states that are available and retry if they are temporarily unavailable. + * @param services The service hub to allow access to the database session + * @param amount The amount of currency desired (ignoring issues, but specifying the currency) + * @param onlyFromIssuerParties If empty the operation ignores the specifics of the issuer, + * otherwise the set of eligible states wil be filtered to only include those from these issuers. + * @param notary If null the notary source is ignored, if specified then only states marked + * with this notary are included. + * @param lockId The [FlowLogic.runId.uuid] of the flow, which is used to soft reserve the states. + * Also, previous outputs of the flow will be eligible as they are implicitly locked with this id until the flow completes. + * @param withIssuerRefs If not empty the specific set of issuer references to match against. + * @return The matching states that were found. If sufficient funds were found these will be locked, + * otherwise what is available is returned unlocked for informational purposes. + */ + @JvmStatic + @Suspendable + fun unconsumedCashStatesForSpending(services: ServiceHub, + amount: Amount, + onlyFromIssuerParties: Set = emptySet(), + notary: Party? = null, + lockId: UUID, + withIssuerRefs: Set = emptySet()): List> { + + val issuerKeysStr = onlyFromIssuerParties.fold("") { left, right -> left + "('${right.owningKey.toBase58String()}')," }.dropLast(1) + val issuerRefsStr = withIssuerRefs.fold("") { left, right -> left + "('${right.bytes.toHexString()}')," }.dropLast(1) + + val stateAndRefs = mutableListOf>() + + // TODO: Need to provide a database provider independent means of performing this function. + // We are using an H2 specific means of selecting a minimum set of rows that match a request amount of coins: + // 1) There is no standard SQL mechanism of calculating a cumulative total on a field and restricting row selection on the + // running total of such an accumulator + // 2) H2 uses session variables to perform this accumulator function: + // http://www.h2database.com/html/functions.html#set + // 3) H2 does not support JOIN's in FOR UPDATE (hence we are forced to execute 2 queries) + + for (retryCount in 1..MAX_RETRIES) { + + spendLock.withLock { + val statement = services.jdbcSession().createStatement() + try { + statement.execute("CALL SET(@t, 0);") + + // we select spendable states irrespective of lock but prioritised by unlocked ones (Eg. null) + // the softLockReserve update will detect whether we try to lock states locked by others + val selectJoin = """ + SELECT vs.transaction_id, vs.output_index, vs.contract_state, ccs.pennies, SET(@t, ifnull(@t,0)+ccs.pennies) total_pennies, vs.lock_id + FROM vault_states AS vs, contract_cash_states AS ccs + WHERE vs.transaction_id = ccs.transaction_id AND vs.output_index = ccs.output_index + AND vs.state_status = 0 + AND ccs.ccy_code = '${amount.token}' and @t < ${amount.quantity} + AND (vs.lock_id = '$lockId' OR vs.lock_id is null) + """ + + (if (notary != null) + " AND vs.notary_key = '${notary.owningKey.toBase58String()}'" else "") + + (if (onlyFromIssuerParties.isNotEmpty()) + " AND ccs.issuer_key IN ($issuerKeysStr)" else "") + + (if (withIssuerRefs.isNotEmpty()) + " AND ccs.issuer_ref IN ($issuerRefsStr)" else "") + + // Retrieve spendable state refs + val rs = statement.executeQuery(selectJoin) + stateAndRefs.clear() + log.debug(selectJoin) + var totalPennies = 0L + while (rs.next()) { + val txHash = SecureHash.parse(rs.getString(1)) + val index = rs.getInt(2) + val stateRef = StateRef(txHash, index) + val state = rs.getBytes(3).deserialize>(context = SerializationDefaults.STORAGE_CONTEXT) + val pennies = rs.getLong(4) + totalPennies = rs.getLong(5) + val rowLockId = rs.getString(6) + stateAndRefs.add(StateAndRef(state, stateRef)) + log.trace { "ROW: $rowLockId ($lockId): $stateRef : $pennies ($totalPennies)" } + } + + if (stateAndRefs.isNotEmpty() && totalPennies >= amount.quantity) { + // we should have a minimum number of states to satisfy our selection `amount` criteria + log.trace("Coin selection for $amount retrieved ${stateAndRefs.count()} states totalling $totalPennies pennies: $stateAndRefs") + + // With the current single threaded state machine available states are guaranteed to lock. + // TODO However, we will have to revisit these methods in the future multi-threaded. + services.vaultService.softLockReserve(lockId, (stateAndRefs.map { it.ref }).toNonEmptySet()) + return stateAndRefs + } + log.trace("Coin selection requested $amount but retrieved $totalPennies pennies with state refs: ${stateAndRefs.map { it.ref }}") + // retry as more states may become available + } catch (e: SQLException) { + log.error("""Failed retrieving unconsumed states for: amount [$amount], onlyFromIssuerParties [$onlyFromIssuerParties], notary [$notary], lockId [$lockId] + $e. + """) + } catch (e: StatesNotAvailableException) { // Should never happen with single threaded state machine + stateAndRefs.clear() + log.warn(e.message) + // retry only if there are locked states that may become available again (or consumed with change) + } finally { + statement.close() + } + } + + log.warn("Coin selection failed on attempt $retryCount") + // TODO: revisit the back off strategy for contended spending. + if (retryCount != MAX_RETRIES) { + Strand.sleep(RETRY_SLEEP * retryCount.toLong()) + } + } + + log.warn("Insufficient spendable states identified for $amount") + return stateAndRefs + } + } + } // Small DSL extensions. diff --git a/finance/src/main/kotlin/net/corda/contracts/asset/CommodityContract.kt b/finance/src/main/kotlin/net/corda/contracts/asset/CommodityContract.kt index efa22e2264..f9f0a13523 100644 --- a/finance/src/main/kotlin/net/corda/contracts/asset/CommodityContract.kt +++ b/finance/src/main/kotlin/net/corda/contracts/asset/CommodityContract.kt @@ -1,18 +1,13 @@ package net.corda.contracts.asset import net.corda.contracts.Commodity -import net.corda.contracts.clause.AbstractConserveAmount -import net.corda.contracts.clause.AbstractIssue -import net.corda.contracts.clause.NoZeroSizedOutputs import net.corda.core.contracts.* -import net.corda.core.contracts.clauses.AnyOf -import net.corda.core.contracts.clauses.GroupClauseVerifier -import net.corda.core.contracts.clauses.verifyClause import net.corda.core.crypto.SecureHash import net.corda.core.crypto.newSecureRandom import net.corda.core.identity.AbstractParty import net.corda.core.identity.Party import net.corda.core.serialization.CordaSerializable +import net.corda.core.transactions.LedgerTransaction import net.corda.core.transactions.TransactionBuilder import java.util.* @@ -48,48 +43,6 @@ class CommodityContract : OnLedgerAsset>(AnyOf( - NoZeroSizedOutputs(), - Issue(), - ConserveAmount())) { - /** - * Group commodity states by issuance definition (issuer and underlying commodity). - */ - override fun groupStates(tx: TransactionForContract) - = tx.groupStates> { it.amount.token } - } - - /** - * Standard issue clause, specialised to match the commodity issue command. - */ - class Issue : AbstractIssue( - sum = { sumCommodities() }, - sumOrZero = { sumCommoditiesOrZero(it) } - ) { - override val requiredCommands: Set> = setOf(Commands.Issue::class.java) - } - - /** - * Standard clause for conserving the amount from input to output. - */ - @CordaSerializable - class ConserveAmount : AbstractConserveAmount() - } - /** A state representing a commodity claim against some party */ data class State( override val amount: Amount>, @@ -109,7 +62,7 @@ class CommodityContract : OnLedgerAsset>) : Commands, FungibleAsset.Commands.Exit } - override fun verify(tx: TransactionForContract) - = verifyClause(tx, Clauses.Group(), extractCommands(tx.commands)) + override fun verify(tx: LedgerTransaction) { + // Each group is a set of input/output states with distinct (reference, commodity) attributes. These types + // of commodity are not fungible and must be kept separated for bookkeeping purposes. + val groups = tx.groupStates { it: CommodityContract.State -> it.amount.token } + + for ((inputs, outputs, key) in groups) { + // Either inputs or outputs could be empty. + val issuer = key.issuer + val commodity = key.product + val party = issuer.party + + requireThat { + "there are no zero sized outputs" using ( outputs.none { it.amount.quantity == 0L } ) + } + + val issueCommand = tx.commands.select().firstOrNull() + if (issueCommand != null) { + verifyIssueCommand(inputs, outputs, tx, issueCommand, commodity, issuer) + } else { + val inputAmount = inputs.sumCommoditiesOrNull() ?: throw IllegalArgumentException("there is at least one commodity input for this group") + val outputAmount = outputs.sumCommoditiesOrZero(Issued(issuer, commodity)) + + // If we want to remove commodity from the ledger, that must be signed for by the issuer. + // A mis-signed or duplicated exit command will just be ignored here and result in the exit amount being zero. + val exitCommand = tx.commands.select(party = party).singleOrNull() + val amountExitingLedger = exitCommand?.value?.amount ?: Amount(0, Issued(issuer, commodity)) + + requireThat { + "there are no zero sized inputs" using ( inputs.none { it.amount.quantity == 0L } ) + "for reference ${issuer.reference} at issuer ${party.nameOrNull()} the amounts balance" using + (inputAmount == outputAmount + amountExitingLedger) + } + + verifyMoveCommand(inputs, tx.commands) + } + } + } + + private fun verifyIssueCommand(inputs: List, + outputs: List, + tx: LedgerTransaction, + issueCommand: AuthenticatedObject, + commodity: Commodity, + issuer: PartyAndReference) { + // If we have an issue command, perform special processing: the group is allowed to have no inputs, + // and the output states must have a deposit reference owned by the signer. + // + // Whilst the transaction *may* have no inputs, it can have them, and in this case the outputs must + // sum to more than the inputs. An issuance of zero size is not allowed. + // + // Note that this means literally anyone with access to the network can issue cash claims of arbitrary + // amounts! It is up to the recipient to decide if the backing party is trustworthy or not, via some + // as-yet-unwritten identity service. See ADP-22 for discussion. + + // The grouping ensures that all outputs have the same deposit reference and currency. + val inputAmount = inputs.sumCommoditiesOrZero(Issued(issuer, commodity)) + val outputAmount = outputs.sumCommodities() + val commodityCommands = tx.commands.select() + requireThat { + "the issue command has a nonce" using (issueCommand.value.nonce != 0L) + "output deposits are owned by a command signer" using (issuer.party in issueCommand.signingParties) + "output values sum to more than the inputs" using (outputAmount > inputAmount) + "there is only a single issue command" using (commodityCommands.count() == 1) + } + } override fun extractCommands(commands: Collection>): List> = commands.select() diff --git a/finance/src/main/kotlin/net/corda/contracts/asset/Obligation.kt b/finance/src/main/kotlin/net/corda/contracts/asset/Obligation.kt index dbc7cf54dc..2a37bda579 100644 --- a/finance/src/main/kotlin/net/corda/contracts/asset/Obligation.kt +++ b/finance/src/main/kotlin/net/corda/contracts/asset/Obligation.kt @@ -5,20 +5,19 @@ import net.corda.contracts.NetCommand import net.corda.contracts.NetType import net.corda.contracts.NettableState import net.corda.contracts.asset.Obligation.Lifecycle.NORMAL -import net.corda.contracts.clause.* import net.corda.core.contracts.* -import net.corda.core.contracts.clauses.* import net.corda.core.crypto.SecureHash import net.corda.core.crypto.entropyToKeyPair -import net.corda.core.crypto.testing.NULL_PARTY +import net.corda.core.crypto.random63BitValue import net.corda.core.identity.AbstractParty import net.corda.core.identity.AnonymousParty import net.corda.core.identity.Party -import net.corda.core.crypto.random63BitValue +import net.corda.core.internal.Emoji import net.corda.core.serialization.CordaSerializable +import net.corda.core.transactions.LedgerTransaction import net.corda.core.transactions.TransactionBuilder -import net.corda.core.utilities.Emoji import net.corda.core.utilities.NonEmptySet +import net.corda.core.utilities.seconds import org.bouncycastle.asn1.x500.X500Name import java.math.BigInteger import java.security.PublicKey @@ -29,6 +28,37 @@ import kotlin.collections.component1 import kotlin.collections.component2 import kotlin.collections.set +/** + * Common interface for the state subsets used when determining nettability of two or more states. Exposes the + * underlying issued thing. + */ +interface NetState

{ + val template: Obligation.Terms

+} + +/** + * Subset of state, containing the elements which must match for two obligation transactions to be nettable. + * If two obligation state objects produce equal bilateral net states, they are considered safe to net directly. + * Bilateral states are used in close-out netting. + */ +data class BilateralNetState

( + val partyKeys: Set, + override val template: Obligation.Terms

+) : NetState

+ +/** + * Subset of state, containing the elements which must match for two or more obligation transactions to be candidates + * for netting (this does not include the checks to enforce that everyone's amounts received are the same at the end, + * which is handled under the verify() function). + * In comparison to [BilateralNetState], this doesn't include the parties' keys, as ensuring balances match on + * input and output is handled elsewhere. + * Used in cases where all parties (or their proxies) are signing, such as central clearing. + */ +data class MultilateralNetState

( + override val template: Obligation.Terms

+) : NetState

+ + // Just a fake program identifier for now. In a real system it could be, for instance, the hash of the program bytecode. val OBLIGATION_PROGRAM_ID = Obligation() @@ -54,186 +84,6 @@ class Obligation

: Contract { */ override val legalContractReference: SecureHash = SecureHash.sha256("https://www.big-book-of-banking-law.example.gov/cash-settlement.html") - interface Clauses { - /** - * Parent clause for clauses that operate on grouped states (those which are fungible). - */ - class Group

: GroupClauseVerifier, Commands, Issued>>( - AllOf( - NoZeroSizedOutputs, Commands, Terms

>(), - FirstOf( - SetLifecycle

(), - AllOf( - VerifyLifecycle, Commands, Issued>, P>(), - FirstOf( - Settle

(), - Issue(), - ConserveAmount() - ) - ) - ) - ) - ) { - override fun groupStates(tx: TransactionForContract): List, Issued>>> - = tx.groupStates, Issued>> { it.amount.token } - } - - /** - * Generic issuance clause - */ - class Issue

: AbstractIssue, Commands, Terms

>({ -> sumObligations() }, { token: Issued> -> sumObligationsOrZero(token) }) { - override val requiredCommands: Set> = setOf(Commands.Issue::class.java) - } - - /** - * Generic move/exit clause for fungible assets - */ - class ConserveAmount

: AbstractConserveAmount, Commands, Terms

>() - - /** - * Clause for supporting netting of obligations. - */ - class Net : NetClause() { - val lifecycleClause = Clauses.VerifyLifecycle() - override fun toString(): String = "Net obligations" - - override fun verify(tx: TransactionForContract, inputs: List, outputs: List, commands: List>, groupingKey: Unit?): Set { - lifecycleClause.verify(tx, inputs, outputs, commands, groupingKey) - return super.verify(tx, inputs, outputs, commands, groupingKey) - } - } - - /** - * Obligation-specific clause for changing the lifecycle of one or more states. - */ - class SetLifecycle

: Clause, Commands, Issued>>() { - override val requiredCommands: Set> = setOf(Commands.SetLifecycle::class.java) - - override fun verify(tx: TransactionForContract, - inputs: List>, - outputs: List>, - commands: List>, - groupingKey: Issued>?): Set { - val command = commands.requireSingleCommand() - Obligation

().verifySetLifecycleCommand(inputs, outputs, tx, command) - return setOf(command.value) - } - - override fun toString(): String = "Set obligation lifecycle" - } - - /** - * Obligation-specific clause for settling an outstanding obligation by witnessing - * change of ownership of other states to fulfil - */ - class Settle

: Clause, Commands, Issued>>() { - override val requiredCommands: Set> = setOf(Commands.Settle::class.java) - override fun verify(tx: TransactionForContract, - inputs: List>, - outputs: List>, - commands: List>, - groupingKey: Issued>?): Set { - require(groupingKey != null) - val command = commands.requireSingleCommand>() - val obligor = groupingKey!!.issuer.party - val template = groupingKey.product - val inputAmount: Amount>> = inputs.sumObligationsOrNull

() ?: throw IllegalArgumentException("there is at least one obligation input for this group") - val outputAmount: Amount>> = outputs.sumObligationsOrZero(groupingKey) - - // Sum up all asset state objects that are moving and fulfil our requirements - - // The fungible asset contract verification handles ensuring there's inputs enough to cover the output states, - // we only care about counting how much is output in this transaction. We then calculate the difference in - // settlement amounts between the transaction inputs and outputs, and the two must match. No elimination is - // done of amounts paid in by each beneficiary, as it's presumed the beneficiaries have enough sense to do that - // themselves. Therefore if someone actually signed the following transaction (using cash just for an example): - // - // Inputs: - // £1m cash owned by B - // £1m owed from A to B - // Outputs: - // £1m cash owned by B - // Commands: - // Settle (signed by A) - // Move (signed by B) - // - // That would pass this check. Ensuring they do not is best addressed in the transaction generation stage. - val assetStates = tx.outputs.filterIsInstance>() - val acceptableAssetStates = assetStates - // TODO: This filter is nonsense, because it just checks there is an asset contract loaded, we need to - // verify the asset contract is the asset contract we expect. - // Something like: - // attachments.mustHaveOneOf(key.acceptableAssetContract) - .filter { it.contract.legalContractReference in template.acceptableContracts } - // Restrict the states to those of the correct issuance definition (this normally - // covers issued product and obligor, but is opaque to us) - .filter { it.amount.token in template.acceptableIssuedProducts } - // Catch that there's nothing useful here, so we can dump out a useful error - requireThat { - "there are fungible asset state outputs" using (assetStates.isNotEmpty()) - "there are defined acceptable fungible asset states" using (acceptableAssetStates.isNotEmpty()) - } - - val amountReceivedByOwner = acceptableAssetStates.groupBy { it.owner } - // Note we really do want to search all commands, because we want move commands of other contracts, not just - // this one. - val moveCommands = tx.commands.select() - var totalPenniesSettled = 0L - val requiredSigners = inputs.map { it.amount.token.issuer.party.owningKey }.toSet() - - for ((beneficiary, obligations) in inputs.groupBy { it.owner }) { - val settled = amountReceivedByOwner[beneficiary]?.sumFungibleOrNull

() - if (settled != null) { - val debt = obligations.sumObligationsOrZero(groupingKey) - require(settled.quantity <= debt.quantity) { "Payment of $settled must not exceed debt $debt" } - totalPenniesSettled += settled.quantity - } - } - - val totalAmountSettled = Amount(totalPenniesSettled, command.value.amount.token) - requireThat { - // Insist that we can be the only contract consuming inputs, to ensure no other contract can think it's being - // settled as well - "all move commands relate to this contract" using (moveCommands.map { it.value.contractHash } - .all { it == null || it == Obligation

().legalContractReference }) - // Settle commands exclude all other commands, so we don't need to check for contracts moving at the same - // time. - "amounts paid must match recipients to settle" using inputs.map { it.owner }.containsAll(amountReceivedByOwner.keys) - "amount in settle command ${command.value.amount} matches settled total $totalAmountSettled" using (command.value.amount == totalAmountSettled) - "signatures are present from all obligors" using command.signers.containsAll(requiredSigners) - "there are no zero sized inputs" using inputs.none { it.amount.quantity == 0L } - "at obligor ${obligor} the obligations after settlement balance" using - (inputAmount == outputAmount + Amount(totalPenniesSettled, groupingKey)) - } - return setOf(command.value) - } - } - - /** - * Obligation-specific clause for verifying that all states are in - * normal lifecycle. In a group clause set, this must be run after - * any lifecycle change clause, which is the only clause that involve - * non-standard lifecycle states on input/output. - */ - class VerifyLifecycle : Clause() { - override fun verify(tx: TransactionForContract, - inputs: List, - outputs: List, - commands: List>, - groupingKey: T?): Set - = verify(inputs.filterIsInstance>(), outputs.filterIsInstance>()) - - private fun verify(inputs: List>, - outputs: List>): Set { - requireThat { - "all inputs are in the normal state " using inputs.all { it.lifecycle == Lifecycle.NORMAL } - "all outputs are in the normal state " using outputs.all { it.lifecycle == Lifecycle.NORMAL } - } - return emptySet() - } - } - } - /** * Represents where in its lifecycle a contract state is, which in turn controls the commands that can be applied * to the state. Most states will not leave the [NORMAL] lifecycle. Note that settled (as an end lifecycle) is @@ -254,6 +104,12 @@ class Obligation

: Contract { * Subset of state, containing the elements specified when issuing a new settlement contract. * * @param P the product the obligation is for payment of. + * @param acceptableContracts is the contract types that can be accepted, such as cash. + * @param acceptableIssuedProducts is the assets which are acceptable forms of payment (i.e. GBP issued by the Bank + * of England). + * @param dueBefore when payment is due by. + * @param timeTolerance tolerance value on [dueBefore], to handle clock skew between distributed systems. Generally + * this would be about 30 seconds. */ @CordaSerializable data class Terms

( @@ -264,7 +120,7 @@ class Obligation

: Contract { /** When the contract must be settled by. */ val dueBefore: Instant, - val timeTolerance: Duration = Duration.ofSeconds(30) + val timeTolerance: Duration = 30.seconds ) { val product: P get() = acceptableIssuedProducts.map { it.product }.toSet().single() @@ -326,7 +182,7 @@ class Obligation

: Contract { } } - override fun withNewOwner(newOwner: AbstractParty) = Pair(Commands.Move(), copy(beneficiary = newOwner)) + override fun withNewOwner(newOwner: AbstractParty) = CommandAndState(Commands.Move(), copy(beneficiary = newOwner)) } // Just for grouping @@ -379,10 +235,209 @@ class Obligation

: Contract { data class Exit

(override val amount: Amount>>) : Commands, FungibleAsset.Commands.Exit> } - override fun verify(tx: TransactionForContract) = verifyClause(tx, FirstOf( - Clauses.Net(), - Clauses.Group

() - ), tx.commands.select()) + override fun verify(tx: LedgerTransaction) { + val netCommand = tx.commands.select().firstOrNull() + if (netCommand != null) { + verifyLifecycleCommand(tx.inputStates, tx.outputStates) + verifyNetCommand(tx, netCommand) + } else { + val groups = tx.groupStates { it: Obligation.State

-> it.amount.token } + for ((inputs, outputs, key) in groups) { + requireThat { + "there are no zero sized outputs" using (outputs.none { it.amount.quantity == 0L }) + } + val setLifecycleCommand = tx.commands.select().firstOrNull() + if (setLifecycleCommand != null) { + verifySetLifecycleCommand(inputs, outputs, tx, setLifecycleCommand) + } else { + verifyLifecycleCommand(inputs, outputs) + val settleCommand = tx.commands.select>().firstOrNull() + if (settleCommand != null) { + verifySettleCommand(tx, inputs, outputs, settleCommand, key) + } else { + val issueCommand = tx.commands.select().firstOrNull() + if (issueCommand != null) { + verifyIssueCommand(tx, inputs, outputs, issueCommand, key) + } else { + conserveAmount(tx, inputs, outputs, key) + } + } + } + } + } + } + + private fun conserveAmount(tx: LedgerTransaction, + inputs: List>>, + outputs: List>>, + key: Issued>) { + val issuer = key.issuer + val terms = key.product + val inputAmount = inputs.sumObligationsOrNull

() ?: throw IllegalArgumentException("there is at least one obligation input for this group") + val outputAmount = outputs.sumObligationsOrZero(Issued(issuer, terms)) + + // If we want to remove obligations from the ledger, that must be signed for by the issuer. + // A mis-signed or duplicated exit command will just be ignored here and result in the exit amount being zero. + val exitKeys: Set = inputs.flatMap { it.exitKeys }.toSet() + val exitCommand = tx.commands.select>(parties = null, signers = exitKeys).filter { it.value.amount.token == key }.singleOrNull() + val amountExitingLedger = exitCommand?.value?.amount ?: Amount(0, Issued(issuer, terms)) + + requireThat { + "there are no zero sized inputs" using (inputs.none { it.amount.quantity == 0L }) + "for reference ${issuer.reference} at issuer ${issuer.party.nameOrNull()} the amounts balance" using + (inputAmount == outputAmount + amountExitingLedger) + } + + verifyMoveCommand(inputs, tx.commands) + } + + private fun verifyIssueCommand(tx: LedgerTransaction, + inputs: List>>, + outputs: List>>, + issueCommand: AuthenticatedObject, + key: Issued>) { + // If we have an issue command, perform special processing: the group is allowed to have no inputs, + // and the output states must have a deposit reference owned by the signer. + // + // Whilst the transaction *may* have no inputs, it can have them, and in this case the outputs must + // sum to more than the inputs. An issuance of zero size is not allowed. + // + // Note that this means literally anyone with access to the network can issue cash claims of arbitrary + // amounts! It is up to the recipient to decide if the backing party is trustworthy or not, via some + // as-yet-unwritten identity service. See ADP-22 for discussion. + + // The grouping ensures that all outputs have the same deposit reference and currency. + val issuer = key.issuer + val terms = key.product + val inputAmount = inputs.sumObligationsOrZero(Issued(issuer, terms)) + val outputAmount = outputs.sumObligations

() + val issueCommands = tx.commands.select() + requireThat { + "the issue command has a nonce" using (issueCommand.value.nonce != 0L) + "output states are issued by a command signer" using (issuer.party in issueCommand.signingParties) + "output values sum to more than the inputs" using (outputAmount > inputAmount) + "there is only a single issue command" using (issueCommands.count() == 1) + } + } + + private fun verifySettleCommand(tx: LedgerTransaction, + inputs: List>>, + outputs: List>>, + command: AuthenticatedObject>, + groupingKey: Issued>) { + val obligor = groupingKey.issuer.party + val template = groupingKey.product + val inputAmount: Amount>> = inputs.sumObligationsOrNull

() ?: throw IllegalArgumentException("there is at least one obligation input for this group") + val outputAmount: Amount>> = outputs.sumObligationsOrZero(groupingKey) + + // Sum up all asset state objects that are moving and fulfil our requirements + + // The fungible asset contract verification handles ensuring there's inputs enough to cover the output states, + // we only care about counting how much is output in this transaction. We then calculate the difference in + // settlement amounts between the transaction inputs and outputs, and the two must match. No elimination is + // done of amounts paid in by each beneficiary, as it's presumed the beneficiaries have enough sense to do that + // themselves. Therefore if someone actually signed the following transaction (using cash just for an example): + // + // Inputs: + // £1m cash owned by B + // £1m owed from A to B + // Outputs: + // £1m cash owned by B + // Commands: + // Settle (signed by A) + // Move (signed by B) + // + // That would pass this check. Ensuring they do not is best addressed in the transaction generation stage. + val assetStates = tx.outputsOfType>() + val acceptableAssetStates = assetStates + // TODO: This filter is nonsense, because it just checks there is an asset contract loaded, we need to + // verify the asset contract is the asset contract we expect. + // Something like: + // attachments.mustHaveOneOf(key.acceptableAssetContract) + .filter { it.contract.legalContractReference in template.acceptableContracts } + // Restrict the states to those of the correct issuance definition (this normally + // covers issued product and obligor, but is opaque to us) + .filter { it.amount.token in template.acceptableIssuedProducts } + // Catch that there's nothing useful here, so we can dump out a useful error + requireThat { + "there are fungible asset state outputs" using (assetStates.isNotEmpty()) + "there are defined acceptable fungible asset states" using (acceptableAssetStates.isNotEmpty()) + } + + val amountReceivedByOwner = acceptableAssetStates.groupBy { it.owner } + // Note we really do want to search all commands, because we want move commands of other contracts, not just + // this one. + val moveCommands = tx.commands.select() + var totalPenniesSettled = 0L + val requiredSigners = inputs.map { it.amount.token.issuer.party.owningKey }.toSet() + + for ((beneficiary, obligations) in inputs.groupBy { it.owner }) { + val settled = amountReceivedByOwner[beneficiary]?.sumFungibleOrNull

() + if (settled != null) { + val debt = obligations.sumObligationsOrZero(groupingKey) + require(settled.quantity <= debt.quantity) { "Payment of $settled must not exceed debt $debt" } + totalPenniesSettled += settled.quantity + } + } + + val totalAmountSettled = Amount(totalPenniesSettled, command.value.amount.token) + requireThat { + // Insist that we can be the only contract consuming inputs, to ensure no other contract can think it's being + // settled as well + "all move commands relate to this contract" using (moveCommands.map { it.value.contractHash } + .all { it == null || it == Obligation

().legalContractReference }) + // Settle commands exclude all other commands, so we don't need to check for contracts moving at the same + // time. + "amounts paid must match recipients to settle" using inputs.map { it.owner }.containsAll(amountReceivedByOwner.keys) + "amount in settle command ${command.value.amount} matches settled total $totalAmountSettled" using (command.value.amount == totalAmountSettled) + "signatures are present from all obligors" using command.signers.containsAll(requiredSigners) + "there are no zero sized inputs" using inputs.none { it.amount.quantity == 0L } + "at obligor $obligor the obligations after settlement balance" using + (inputAmount == outputAmount + Amount(totalPenniesSettled, groupingKey)) + } + } + + private fun verifyLifecycleCommand(inputs: List, outputs: List) { + val filteredInputs = inputs.filterIsInstance>() + val filteredOutputs = outputs.filterIsInstance>() + requireThat { + "all inputs are in the normal state " using filteredInputs.all { it.lifecycle == Lifecycle.NORMAL } + "all outputs are in the normal state " using filteredOutputs.all { it.lifecycle == Lifecycle.NORMAL } + } + } + + private fun verifyNetCommand(tx: LedgerTransaction, command: AuthenticatedObject) { + val groups = when (command.value.type) { + NetType.CLOSE_OUT -> tx.groupStates { it: Obligation.State

-> it.bilateralNetState } + NetType.PAYMENT -> tx.groupStates { it: Obligation.State

-> it.multilateralNetState } + } + for ((groupInputs, groupOutputs, key) in groups) { + + val template = key.template + // Create two maps of balances from obligors to beneficiaries, one for input states, the other for output states. + val inputBalances = extractAmountsDue(template, groupInputs) + val outputBalances = extractAmountsDue(template, groupOutputs) + + // Sum the columns of the matrices. This will yield the net amount payable to/from each party to/from all other participants. + // The two summaries must match, reflecting that the amounts owed match on both input and output. + requireThat { + "all input states use the same template" using (groupInputs.all { it.template == template }) + "all output states use the same template" using (groupOutputs.all { it.template == template }) + "amounts owed on input and output must match" using (sumAmountsDue(inputBalances) == sumAmountsDue + (outputBalances)) + } + + // TODO: Handle proxies nominated by parties, i.e. a central clearing service + val involvedParties: Set = groupInputs.map { it.beneficiary.owningKey }.union(groupInputs.map { it.obligor.owningKey }).toSet() + when (command.value.type) { + // For close-out netting, allow any involved party to sign + NetType.CLOSE_OUT -> require(command.signers.intersect(involvedParties).isNotEmpty()) { "any involved party has signed" } + // Require signatures from all parties (this constraint can be changed for other contracts, and is used as a + // placeholder while exact requirements are established), or fail the transaction. + NetType.PAYMENT -> require(command.signers.containsAll(involvedParties)) { "all involved parties have signed" } + } + } + } /** * A default command mutates inputs and produces identical outputs, except that the lifecycle changes. @@ -390,7 +445,7 @@ class Obligation

: Contract { @VisibleForTesting private fun verifySetLifecycleCommand(inputs: List>>, outputs: List>>, - tx: TransactionForContract, + tx: LedgerTransaction, setLifecycleCommand: AuthenticatedObject) { // Default must not change anything except lifecycle, so number of inputs and outputs must match // exactly. @@ -467,8 +522,39 @@ class Obligation

: Contract { generateExitCommand = { amount -> Commands.Exit(amount) } ) + /** + * Puts together an issuance transaction for the specified currency obligation amount that starts out being owned by + * the given pubkey. + * + * @param tx transaction builder to add states and commands to. + * @param obligor the party who is expected to pay some currency amount to fulfil the obligation (also the owner of + * the obligation). + * @param amount currency amount the obligor is expected to pay. + * @param dueBefore the date on which the obligation is due. The default time tolerance is used (currently this is + * 30 seconds). + * @param beneficiary the party the obligor is expected to pay. + * @param notary the notary for this transaction's outputs. + */ + fun generateCashIssue(tx: TransactionBuilder, + obligor: AbstractParty, + amount: Amount>, + dueBefore: Instant, + beneficiary: AbstractParty, + notary: Party) { + val issuanceDef = Terms(NonEmptySet.of(Cash().legalContractReference), NonEmptySet.of(amount.token), dueBefore) + OnLedgerAsset.generateIssue(tx, TransactionState(State(Lifecycle.NORMAL, obligor, issuanceDef, amount.quantity, beneficiary), notary), Commands.Issue()) + } + /** * Puts together an issuance transaction for the specified amount that starts out being owned by the given pubkey. + * + * @param tx transaction builder to add states and commands to. + * @param obligor the party who is expected to pay some amount to fulfil the obligation. + * @param issuanceDef the terms of the obligation, including which contracts and underlying assets are acceptable + * forms of payment. + * @param pennies the quantity of the asset (in the smallest normal unit of measurement) owed. + * @param beneficiary the party the obligor is expected to pay. + * @param notary the notary for this transaction's outputs. */ fun generateIssue(tx: TransactionBuilder, obligor: AbstractParty, @@ -476,7 +562,7 @@ class Obligation

: Contract { pennies: Long, beneficiary: AbstractParty, notary: Party) - = OnLedgerAsset.generateIssue(tx, TransactionState(State(Lifecycle.NORMAL, obligor, issuanceDef, pennies, beneficiary), notary), Commands.Issue()) + = OnLedgerAsset.generateIssue(tx, TransactionState(State(Lifecycle.NORMAL, obligor, issuanceDef, pennies, beneficiary), notary), Commands.Issue()) fun generatePaymentNetting(tx: TransactionBuilder, issued: Issued>, @@ -644,7 +730,7 @@ fun

extractAmountsDue(product: Obligation.Terms

, states: Iterable netAmountsDue(balances: Map, Amount>): Map, Amount> { +fun

netAmountsDue(balances: Map, Amount>): Map, Amount> { val nettedBalances = HashMap, Amount>() balances.forEach { balance -> @@ -671,7 +757,7 @@ fun netAmountsDue(balances: Map, Amount sumAmountsDue(balances: Map, Amount>): Map { +fun

sumAmountsDue(balances: Map, Amount>): Map { val sum = HashMap() // Fill the map with zeroes initially diff --git a/finance/src/main/kotlin/net/corda/contracts/asset/OnLedgerAsset.kt b/finance/src/main/kotlin/net/corda/contracts/asset/OnLedgerAsset.kt index 383ddab54d..c40649b789 100644 --- a/finance/src/main/kotlin/net/corda/contracts/asset/OnLedgerAsset.kt +++ b/finance/src/main/kotlin/net/corda/contracts/asset/OnLedgerAsset.kt @@ -226,8 +226,6 @@ abstract class OnLedgerAsset> : C * * @param tx transaction builder to add states and commands to. * @param amountIssued the amount to be exited, represented as a quantity of issued currency. - * @param changeKey the key to send any change to. This needs to be explicitly stated as the input states are not - * necessarily owned by us. * @param assetStates the asset states to take funds from. No checks are done about ownership of these states, it is * the responsibility of the caller to check that they do not exit funds held by others. * @return the public keys which must sign the transaction for it to be valid. diff --git a/finance/src/main/kotlin/net/corda/contracts/clause/AbstractConserveAmount.kt b/finance/src/main/kotlin/net/corda/contracts/clause/AbstractConserveAmount.kt deleted file mode 100644 index f2fa484632..0000000000 --- a/finance/src/main/kotlin/net/corda/contracts/clause/AbstractConserveAmount.kt +++ /dev/null @@ -1,70 +0,0 @@ -package net.corda.contracts.clause - -import net.corda.contracts.asset.OnLedgerAsset -import net.corda.core.contracts.* -import net.corda.core.contracts.clauses.Clause -import net.corda.core.identity.AbstractParty -import net.corda.core.transactions.TransactionBuilder -import net.corda.core.utilities.loggerFor -import java.security.PublicKey - -/** - * Standardised clause for checking input/output balances of fungible assets. Requires that a - * Move command is provided, and errors if absent. Must be the last clause under a grouping clause; - * errors on no-match, ends on match. - */ -abstract class AbstractConserveAmount, C : CommandData, T : Any> : Clause>() { - - private companion object { - val log = loggerFor>() - } - - /** - * Generate an transaction exiting fungible assets from the ledger. - * - * @param tx transaction builder to add states and commands to. - * @param amountIssued the amount to be exited, represented as a quantity of issued currency. - * @param assetStates the asset states to take funds from. No checks are done about ownership of these states, it is - * the responsibility of the caller to check that they do not attempt to exit funds held by others. - * @return the public keys which must sign the transaction for it to be valid. - */ - @Deprecated("This function will be removed in a future milestone", ReplaceWith("OnLedgerAsset.generateExit()")) - @Throws(InsufficientBalanceException::class) - fun generateExit(tx: TransactionBuilder, amountIssued: Amount>, - assetStates: List>, - deriveState: (TransactionState, Amount>, AbstractParty) -> TransactionState, - generateMoveCommand: () -> CommandData, - generateExitCommand: (Amount>) -> CommandData): Set - = OnLedgerAsset.generateExit(tx, amountIssued, assetStates, deriveState, generateMoveCommand, generateExitCommand) - - override fun verify(tx: TransactionForContract, - inputs: List, - outputs: List, - commands: List>, - groupingKey: Issued?): Set { - require(groupingKey != null) { "Conserve amount clause can only be used on grouped states" } - val matchedCommands = commands.filter { command -> command.value is FungibleAsset.Commands.Move || command.value is FungibleAsset.Commands.Exit<*> } - val inputAmount: Amount> = inputs.sumFungibleOrNull() ?: throw IllegalArgumentException("there is at least one asset input for group $groupingKey") - val deposit = groupingKey!!.issuer - val outputAmount: Amount> = outputs.sumFungibleOrZero(groupingKey) - - // If we want to remove assets from the ledger, that must be signed for by the issuer and owner. - val exitKeys: Set = inputs.flatMap { it.exitKeys }.toSet() - val exitCommand = matchedCommands.select>(parties = null, signers = exitKeys).filter { it.value.amount.token == groupingKey }.singleOrNull() - val amountExitingLedger: Amount> = exitCommand?.value?.amount ?: Amount(0, groupingKey) - - requireThat { - "there are no zero sized inputs" using inputs.none { it.amount.quantity == 0L } - "for reference ${deposit.reference} at issuer ${deposit.party} the amounts balance: ${inputAmount.quantity} - ${amountExitingLedger.quantity} != ${outputAmount.quantity}" using - (inputAmount == outputAmount + amountExitingLedger) - } - - verifyMoveCommand(inputs, commands) - - // This is safe because we've taken the commands from a collection of C objects at the start - @Suppress("UNCHECKED_CAST") - return matchedCommands.map { it.value }.toSet() - } - - override fun toString(): String = "Conserve amount between inputs and outputs" -} diff --git a/finance/src/main/kotlin/net/corda/contracts/clause/AbstractIssue.kt b/finance/src/main/kotlin/net/corda/contracts/clause/AbstractIssue.kt deleted file mode 100644 index b437daa907..0000000000 --- a/finance/src/main/kotlin/net/corda/contracts/clause/AbstractIssue.kt +++ /dev/null @@ -1,55 +0,0 @@ -package net.corda.contracts.clause - -import net.corda.core.contracts.* -import net.corda.core.contracts.clauses.Clause - -/** - * Standard issue clause for contracts that issue fungible assets. - * - * @param S the type of contract state which is being issued. - * @param T the token underlying the issued state. - * @param sum function to convert a list of states into an amount of the token. Must error if there are no states in - * the list. - * @param sumOrZero function to convert a list of states into an amount of the token, and returns zero if there are - * no states in the list. Takes in an instance of the token definition for constructing the zero amount if needed. - */ -abstract class AbstractIssue( - val sum: List.() -> Amount>, - val sumOrZero: List.(token: Issued) -> Amount> -) : Clause>() { - override fun verify(tx: TransactionForContract, - inputs: List, - outputs: List, - commands: List>, - groupingKey: Issued?): Set { - require(groupingKey != null) - // TODO: Take in matched commands as a parameter - val issueCommand = commands.requireSingleCommand() - - // If we have an issue command, perform special processing: the group is allowed to have no inputs, - // and the output states must have a deposit reference owned by the signer. - // - // Whilst the transaction *may* have no inputs, it can have them, and in this case the outputs must - // sum to more than the inputs. An issuance of zero size is not allowed. - // - // Note that this means literally anyone with access to the network can issue asset claims of arbitrary - // amounts! It is up to the recipient to decide if the backing party is trustworthy or not, via some - // external mechanism (such as locally defined rules on which parties are trustworthy). - - // The grouping already ensures that all outputs have the same deposit reference and token. - val issuer = groupingKey!!.issuer.party - val inputAmount = inputs.sumOrZero(groupingKey) - val outputAmount = outputs.sum() - requireThat { - "the issue command has a nonce" using (issueCommand.value.nonce != 0L) - // TODO: This doesn't work with the trader demo, so use the underlying key instead - // "output states are issued by a command signer" by (issuer in issueCommand.signingParties) - "output states are issued by a command signer" using (issuer.owningKey in issueCommand.signers) - "output values sum to more than the inputs" using (outputAmount > inputAmount) - } - - // This is safe because we've taken the command from a collection of C objects at the start - @Suppress("UNCHECKED_CAST") - return setOf(issueCommand.value as C) - } -} diff --git a/finance/src/main/kotlin/net/corda/contracts/clause/Net.kt b/finance/src/main/kotlin/net/corda/contracts/clause/Net.kt deleted file mode 100644 index 5c791b1a84..0000000000 --- a/finance/src/main/kotlin/net/corda/contracts/clause/Net.kt +++ /dev/null @@ -1,101 +0,0 @@ -package net.corda.contracts.clause - -import com.google.common.annotations.VisibleForTesting -import net.corda.contracts.NetCommand -import net.corda.contracts.NetType -import net.corda.contracts.asset.Obligation -import net.corda.contracts.asset.extractAmountsDue -import net.corda.contracts.asset.sumAmountsDue -import net.corda.core.contracts.* -import net.corda.core.contracts.clauses.Clause -import net.corda.core.identity.AbstractParty -import java.security.PublicKey - -/** - * Common interface for the state subsets used when determining nettability of two or more states. Exposes the - * underlying issued thing. - */ -interface NetState

{ - val template: Obligation.Terms

-} - -/** - * Subset of state, containing the elements which must match for two obligation transactions to be nettable. - * If two obligation state objects produce equal bilateral net states, they are considered safe to net directly. - * Bilateral states are used in close-out netting. - */ -data class BilateralNetState

( - val partyKeys: Set, - override val template: Obligation.Terms

-) : NetState

- -/** - * Subset of state, containing the elements which must match for two or more obligation transactions to be candidates - * for netting (this does not include the checks to enforce that everyone's amounts received are the same at the end, - * which is handled under the verify() function). - * In comparison to [BilateralNetState], this doesn't include the parties' keys, as ensuring balances match on - * input and output is handled elsewhere. - * Used in cases where all parties (or their proxies) are signing, such as central clearing. - */ -data class MultilateralNetState

( - override val template: Obligation.Terms

-) : NetState

- -/** - * Clause for netting contract states. Currently only supports obligation contract. - */ -// TODO: Make this usable for any nettable contract states -open class NetClause : Clause() { - override val requiredCommands: Set> = setOf(Obligation.Commands.Net::class.java) - - @Suppress("ConvertLambdaToReference") - override fun verify(tx: TransactionForContract, - inputs: List, - outputs: List, - commands: List>, - groupingKey: Unit?): Set { - val matchedCommands: List> = commands.filter { it.value is NetCommand } - val command = matchedCommands.requireSingleCommand() - val groups = when (command.value.type) { - NetType.CLOSE_OUT -> tx.groupStates { it: Obligation.State

-> it.bilateralNetState } - NetType.PAYMENT -> tx.groupStates { it: Obligation.State

-> it.multilateralNetState } - } - for ((groupInputs, groupOutputs, key) in groups) { - verifyNetCommand(groupInputs, groupOutputs, command, key) - } - return matchedCommands.map { it.value }.toSet() - } - - /** - * Verify a netting command. This handles both close-out and payment netting. - */ - @VisibleForTesting - fun verifyNetCommand(inputs: List>, - outputs: List>, - command: AuthenticatedObject, - netState: NetState

) { - val template = netState.template - // Create two maps of balances from obligors to beneficiaries, one for input states, the other for output states. - val inputBalances = extractAmountsDue(template, inputs) - val outputBalances = extractAmountsDue(template, outputs) - - // Sum the columns of the matrices. This will yield the net amount payable to/from each party to/from all other participants. - // The two summaries must match, reflecting that the amounts owed match on both input and output. - requireThat { - "all input states use the same template" using (inputs.all { it.template == template }) - "all output states use the same template" using (outputs.all { it.template == template }) - "amounts owed on input and output must match" using (sumAmountsDue(inputBalances) == sumAmountsDue - (outputBalances)) - } - - // TODO: Handle proxies nominated by parties, i.e. a central clearing service - val involvedParties: Set = inputs.map { it.beneficiary.owningKey }.union(inputs.map { it.obligor.owningKey }).toSet() - when (command.value.type) { - // For close-out netting, allow any involved party to sign - NetType.CLOSE_OUT -> require(command.signers.intersect(involvedParties).isNotEmpty()) { "any involved party has signed" } - // Require signatures from all parties (this constraint can be changed for other contracts, and is used as a - // placeholder while exact requirements are established), or fail the transaction. - NetType.PAYMENT -> require(command.signers.containsAll(involvedParties)) { "all involved parties have signed" } - } - } -} diff --git a/finance/src/main/kotlin/net/corda/contracts/clause/NoZeroSizedOutputs.kt b/finance/src/main/kotlin/net/corda/contracts/clause/NoZeroSizedOutputs.kt deleted file mode 100644 index 3987ce4ec5..0000000000 --- a/finance/src/main/kotlin/net/corda/contracts/clause/NoZeroSizedOutputs.kt +++ /dev/null @@ -1,23 +0,0 @@ -package net.corda.contracts.clause - -import net.corda.core.contracts.* -import net.corda.core.contracts.clauses.Clause - -/** - * Clause for fungible asset contracts, which enforces that no output state should have - * a balance of zero. - */ -open class NoZeroSizedOutputs, C : CommandData, T : Any> : Clause>() { - override fun verify(tx: TransactionForContract, - inputs: List, - outputs: List, - commands: List>, - groupingKey: Issued?): Set { - requireThat { - "there are no zero sized outputs" using outputs.none { it.amount.quantity == 0L } - } - return emptySet() - } - - override fun toString(): String = "No zero sized outputs" -} diff --git a/finance/src/main/kotlin/net/corda/flows/AbstractCashFlow.kt b/finance/src/main/kotlin/net/corda/flows/AbstractCashFlow.kt index 52bef42f21..fd809ab7e5 100644 --- a/finance/src/main/kotlin/net/corda/flows/AbstractCashFlow.kt +++ b/finance/src/main/kotlin/net/corda/flows/AbstractCashFlow.kt @@ -1,8 +1,10 @@ package net.corda.flows import co.paralleluniverse.fibers.Suspendable +import net.corda.core.flows.FinalityFlow import net.corda.core.flows.FlowException import net.corda.core.flows.FlowLogic +import net.corda.core.flows.NotaryException import net.corda.core.identity.AbstractParty import net.corda.core.identity.Party import net.corda.core.serialization.CordaSerializable @@ -12,7 +14,7 @@ import net.corda.core.utilities.ProgressTracker /** * Initiates a flow that produces an Issue/Move or Exit Cash transaction. */ -abstract class AbstractCashFlow(override val progressTracker: ProgressTracker) : FlowLogic() { +abstract class AbstractCashFlow(override val progressTracker: ProgressTracker) : FlowLogic() { companion object { object GENERATING_ID : ProgressTracker.Step("Generating anonymous identities") object GENERATING_TX : ProgressTracker.Step("Generating transaction") diff --git a/finance/src/main/kotlin/net/corda/flows/CashExitFlow.kt b/finance/src/main/kotlin/net/corda/flows/CashExitFlow.kt index 9a363fffc2..abf0b511c8 100644 --- a/finance/src/main/kotlin/net/corda/flows/CashExitFlow.kt +++ b/finance/src/main/kotlin/net/corda/flows/CashExitFlow.kt @@ -4,12 +4,15 @@ import co.paralleluniverse.fibers.Suspendable import net.corda.contracts.asset.Cash import net.corda.core.contracts.Amount import net.corda.core.contracts.InsufficientBalanceException -import net.corda.core.contracts.TransactionType import net.corda.core.contracts.issuedBy import net.corda.core.flows.StartableByRPC import net.corda.core.identity.Party -import net.corda.core.utilities.OpaqueBytes +import net.corda.core.node.services.queryBy +import net.corda.core.node.services.vault.DEFAULT_PAGE_NUM +import net.corda.core.node.services.vault.PageSpecification +import net.corda.core.node.services.vault.QueryCriteria.VaultQueryCriteria import net.corda.core.transactions.TransactionBuilder +import net.corda.core.utilities.OpaqueBytes import net.corda.core.utilities.ProgressTracker import java.util.* @@ -36,9 +39,9 @@ class CashExitFlow(val amount: Amount, val issueRef: OpaqueBytes, prog @Throws(CashException::class) override fun call(): AbstractCashFlow.Result { progressTracker.currentStep = GENERATING_TX - val builder: TransactionBuilder = TransactionType.General.Builder(notary = null as Party?) + val builder: TransactionBuilder = TransactionBuilder(notary = null as Party?) val issuer = serviceHub.myInfo.legalIdentity.ref(issueRef) - val exitStates = serviceHub.vaultService.unconsumedStatesForSpending(amount, setOf(issuer.party), builder.notary, builder.lockId, setOf(issuer.reference)) + val exitStates = Cash.unconsumedCashStatesForSpending(serviceHub, amount, setOf(issuer.party), builder.notary, builder.lockId, setOf(issuer.reference)) val signers = try { Cash().generateExit( builder, @@ -48,13 +51,9 @@ class CashExitFlow(val amount: Amount, val issueRef: OpaqueBytes, prog throw CashException("Exiting more cash than exists", e) } - // Work out who the owners of the burnt states were - val inputStatesNullable = serviceHub.vaultService.statesForRefs(builder.inputStates()) - val inputStates = inputStatesNullable.values.filterNotNull().map { it.data } - if (inputStatesNullable.size != inputStates.size) { - val unresolvedStateRefs = inputStatesNullable.filter { it.value == null }.map { it.key } - throw IllegalStateException("Failed to resolve input StateRefs: $unresolvedStateRefs") - } + // Work out who the owners of the burnt states were (specify page size so we don't silently drop any if > DEFAULT_PAGE_SIZE) + val inputStates = serviceHub.vaultQueryService.queryBy(VaultQueryCriteria(stateRefs = builder.inputStates()), + PageSpecification(pageNumber = DEFAULT_PAGE_NUM, pageSize = builder.inputStates().size)).states // TODO: Is it safe to drop participants we don't know how to contact? Does not knowing how to contact them // count as a reason to fail? diff --git a/finance/src/main/kotlin/net/corda/flows/CashIssueFlow.kt b/finance/src/main/kotlin/net/corda/flows/CashIssueFlow.kt index ec9659f1d5..99304f122c 100644 --- a/finance/src/main/kotlin/net/corda/flows/CashIssueFlow.kt +++ b/finance/src/main/kotlin/net/corda/flows/CashIssueFlow.kt @@ -3,12 +3,14 @@ package net.corda.flows import co.paralleluniverse.fibers.Suspendable import net.corda.contracts.asset.Cash import net.corda.core.contracts.Amount -import net.corda.core.contracts.TransactionType import net.corda.core.contracts.issuedBy +import net.corda.core.flows.FinalityFlow import net.corda.core.flows.StartableByRPC +import net.corda.core.flows.TransactionKeyFlow +import net.corda.core.identity.AnonymousParty import net.corda.core.identity.Party -import net.corda.core.utilities.OpaqueBytes import net.corda.core.transactions.TransactionBuilder +import net.corda.core.utilities.OpaqueBytes import net.corda.core.utilities.ProgressTracker import java.util.* @@ -43,11 +45,11 @@ class CashIssueFlow(val amount: Amount, val txIdentities = if (anonymous) { subFlow(TransactionKeyFlow(recipient)) } else { - emptyMap() + emptyMap() } - val anonymousRecipient = txIdentities.get(recipient)?.identity ?: recipient + val anonymousRecipient = txIdentities[recipient] ?: recipient progressTracker.currentStep = GENERATING_TX - val builder: TransactionBuilder = TransactionType.General.Builder(notary = notary) + val builder: TransactionBuilder = TransactionBuilder(notary) val issuer = serviceHub.myInfo.legalIdentity.ref(issueRef) val signers = Cash().generateIssue(builder, amount.issuedBy(issuer), anonymousRecipient, notary) progressTracker.currentStep = SIGNING_TX diff --git a/finance/src/main/kotlin/net/corda/flows/CashPaymentFlow.kt b/finance/src/main/kotlin/net/corda/flows/CashPaymentFlow.kt index 0567e2c77e..1b3032c565 100644 --- a/finance/src/main/kotlin/net/corda/flows/CashPaymentFlow.kt +++ b/finance/src/main/kotlin/net/corda/flows/CashPaymentFlow.kt @@ -1,10 +1,12 @@ package net.corda.flows import co.paralleluniverse.fibers.Suspendable +import net.corda.contracts.asset.Cash import net.corda.core.contracts.Amount import net.corda.core.contracts.InsufficientBalanceException -import net.corda.core.contracts.TransactionType import net.corda.core.flows.StartableByRPC +import net.corda.core.flows.TransactionKeyFlow +import net.corda.core.identity.AnonymousParty import net.corda.core.identity.Party import net.corda.core.transactions.TransactionBuilder import net.corda.core.utilities.ProgressTracker @@ -25,7 +27,7 @@ open class CashPaymentFlow( val recipient: Party, val anonymous: Boolean, progressTracker: ProgressTracker, - val issuerConstraint: Set? = null) : AbstractCashFlow(progressTracker) { + val issuerConstraint: Set = emptySet()) : AbstractCashFlow(progressTracker) { /** A straightforward constructor that constructs spends using cash states of any issuer. */ constructor(amount: Amount, recipient: Party) : this(amount, recipient, true, tracker()) /** A straightforward constructor that constructs spends using cash states of any issuer. */ @@ -37,14 +39,14 @@ open class CashPaymentFlow( val txIdentities = if (anonymous) { subFlow(TransactionKeyFlow(recipient)) } else { - emptyMap() + emptyMap() } - val anonymousRecipient = txIdentities.get(recipient)?.identity ?: recipient + val anonymousRecipient = txIdentities.get(recipient) ?: recipient progressTracker.currentStep = GENERATING_TX - val builder: TransactionBuilder = TransactionType.General.Builder(null as Party?) + val builder: TransactionBuilder = TransactionBuilder(null as Party?) // TODO: Have some way of restricting this to states the caller controls val (spendTX, keysForSigning) = try { - serviceHub.vaultService.generateSpend( + Cash.generateSpend(serviceHub, builder, amount, anonymousRecipient, diff --git a/finance/src/main/kotlin/net/corda/flows/IssuerFlow.kt b/finance/src/main/kotlin/net/corda/flows/IssuerFlow.kt index 72d11fe0c7..ccd6001fdf 100644 --- a/finance/src/main/kotlin/net/corda/flows/IssuerFlow.kt +++ b/finance/src/main/kotlin/net/corda/flows/IssuerFlow.kt @@ -6,8 +6,8 @@ import net.corda.core.contracts.* import net.corda.core.flows.* import net.corda.core.identity.Party import net.corda.core.serialization.CordaSerializable -import net.corda.core.utilities.OpaqueBytes import net.corda.core.transactions.SignedTransaction +import net.corda.core.utilities.OpaqueBytes import net.corda.core.utilities.ProgressTracker import net.corda.core.utilities.unwrap import java.util.* @@ -49,10 +49,7 @@ object IssuerFlow { return sendAndReceive(issuerBankParty, issueRequest).unwrap { res -> val tx = res.stx.tx val expectedAmount = Amount(amount.quantity, Issued(issuerBankParty.ref(issueToPartyRef), amount.token)) - val cashOutputs = tx.outputs - .map { it.data} - .filterIsInstance() - .filter { state -> state.owner == res.recipient } + val cashOutputs = tx.filterOutputs { state -> state.owner == res.recipient } require(cashOutputs.size == 1) { "Require a single cash output paying ${res.recipient}, found ${tx.outputs}" } require(cashOutputs.single().amount == expectedAmount) { "Require payment of $expectedAmount"} res diff --git a/finance/src/main/kotlin/net/corda/flows/TwoPartyDealFlow.kt b/finance/src/main/kotlin/net/corda/flows/TwoPartyDealFlow.kt index 8509e55c9b..358de91328 100644 --- a/finance/src/main/kotlin/net/corda/flows/TwoPartyDealFlow.kt +++ b/finance/src/main/kotlin/net/corda/flows/TwoPartyDealFlow.kt @@ -4,30 +4,30 @@ import co.paralleluniverse.fibers.Suspendable import net.corda.contracts.DealState import net.corda.core.contracts.requireThat import net.corda.core.crypto.SecureHash +import net.corda.core.crypto.TransactionSignature +import net.corda.core.flows.CollectSignaturesFlow +import net.corda.core.flows.FinalityFlow import net.corda.core.flows.FlowLogic +import net.corda.core.flows.SignTransactionFlow import net.corda.core.identity.AbstractParty import net.corda.core.identity.Party import net.corda.core.node.NodeInfo import net.corda.core.node.services.ServiceType -import net.corda.core.seconds import net.corda.core.serialization.CordaSerializable import net.corda.core.transactions.SignedTransaction import net.corda.core.transactions.TransactionBuilder import net.corda.core.utilities.ProgressTracker +import net.corda.core.utilities.seconds import net.corda.core.utilities.trace import net.corda.core.utilities.unwrap import java.security.PublicKey /** * Classes for manipulating a two party deal or agreement. - * - * TODO: The subclasses should probably be broken out into individual flows rather than making this an ever expanding collection of subclasses. - * - * TODO: Also, the term Deal is used here where we might prefer Agreement. - * - * TODO: Make this flow more generic. - * */ +// TODO: The subclasses should probably be broken out into individual flows rather than making this an ever expanding collection of subclasses. +// TODO: Also, the term Deal is used here where we might prefer Agreement. +// TODO: Make this flow more generic. object TwoPartyDealFlow { // This object is serialised to the network and is the first flow message the seller sends to the buyer. @CordaSerializable @@ -93,8 +93,12 @@ object TwoPartyDealFlow { val handshake = receiveAndValidateHandshake() progressTracker.currentStep = SIGNING - val (utx, additionalSigningPubKeys) = assembleSharedTX(handshake) - val ptx = signWithOurKeys(additionalSigningPubKeys, utx) + val (utx, additionalSigningPubKeys, additionalSignatures) = assembleSharedTX(handshake) + val ptx = if (additionalSignatures.any()) { + serviceHub.signInitialTransaction(utx, additionalSigningPubKeys).withAdditionalSignatures(additionalSignatures) + } else { + serviceHub.signInitialTransaction(utx, additionalSigningPubKeys) + } logger.trace { "Signed proposed transaction." } @@ -137,13 +141,8 @@ object TwoPartyDealFlow { return handshake.unwrap { validateHandshake(it) } } - private fun signWithOurKeys(signingPubKeys: List, ptx: TransactionBuilder): SignedTransaction { - // Now sign the transaction with whatever keys we need to move the cash. - return serviceHub.signInitialTransaction(ptx, signingPubKeys) - } - @Suspendable protected abstract fun validateHandshake(handshake: Handshake): Handshake - @Suspendable protected abstract fun assembleSharedTX(handshake: Handshake): Pair> + @Suspendable protected abstract fun assembleSharedTX(handshake: Handshake): Triple, List> } @CordaSerializable @@ -175,18 +174,18 @@ object TwoPartyDealFlow { // What is the seller trying to sell us? val autoOffer = handshake.payload val deal = autoOffer.dealBeingOffered - logger.trace { "Got deal request for: ${deal.ref}" } + logger.trace { "Got deal request for: ${deal.linearId.externalId!!}" } return handshake.copy(payload = autoOffer.copy(dealBeingOffered = deal)) } - override fun assembleSharedTX(handshake: Handshake): Pair> { + override fun assembleSharedTX(handshake: Handshake): Triple, List> { val deal = handshake.payload.dealBeingOffered val ptx = deal.generateAgreement(handshake.payload.notary) // We set the transaction's time-window: it may be that none of the contracts need this! // But it can't hurt to have one. ptx.setTimeWindow(serviceHub.clock.instant(), 30.seconds) - return Pair(ptx, arrayListOf(deal.participants.single { it == serviceHub.myInfo.legalIdentity as AbstractParty }.owningKey)) + return Triple(ptx, arrayListOf(deal.participants.single { it == serviceHub.myInfo.legalIdentity as AbstractParty }.owningKey), emptyList()) } } } diff --git a/finance/src/main/kotlin/net/corda/flows/TwoPartyTradeFlow.kt b/finance/src/main/kotlin/net/corda/flows/TwoPartyTradeFlow.kt index d384e1255e..e63be156e3 100644 --- a/finance/src/main/kotlin/net/corda/flows/TwoPartyTradeFlow.kt +++ b/finance/src/main/kotlin/net/corda/flows/TwoPartyTradeFlow.kt @@ -1,19 +1,22 @@ package net.corda.flows import co.paralleluniverse.fibers.Suspendable +import net.corda.contracts.asset.Cash import net.corda.contracts.asset.sumCashBy -import net.corda.core.contracts.* -import net.corda.core.flows.FlowException -import net.corda.core.flows.FlowLogic +import net.corda.core.contracts.Amount +import net.corda.core.contracts.OwnableState +import net.corda.core.contracts.StateAndRef +import net.corda.core.contracts.withoutIssuer +import net.corda.core.flows.* import net.corda.core.identity.AbstractParty import net.corda.core.identity.AnonymousParty import net.corda.core.identity.Party import net.corda.core.node.NodeInfo -import net.corda.core.seconds import net.corda.core.serialization.CordaSerializable import net.corda.core.transactions.SignedTransaction import net.corda.core.transactions.TransactionBuilder import net.corda.core.utilities.ProgressTracker +import net.corda.core.utilities.seconds import net.corda.core.utilities.unwrap import java.security.PublicKey import java.util.* @@ -48,7 +51,6 @@ object TwoPartyTradeFlow { // This object is serialised to the network and is the first flow message the seller sends to the buyer. @CordaSerializable data class SellerTradeInfo( - val assetForSale: StateAndRef, val price: Amount, val sellerOwner: AbstractParty ) @@ -76,17 +78,18 @@ object TwoPartyTradeFlow { override fun call(): SignedTransaction { progressTracker.currentStep = AWAITING_PROPOSAL // Make the first message we'll send to kick off the flow. - val hello = SellerTradeInfo(assetToSell, price, me) + val hello = SellerTradeInfo(price, me) // What we get back from the other side is a transaction that *might* be valid and acceptable to us, // but we must check it out thoroughly before we sign! + // SendTransactionFlow allows otherParty to access our data to resolve the transaction. + subFlow(SendStateAndRefFlow(otherParty, listOf(assetToSell))) send(otherParty, hello) - // Verify and sign the transaction. progressTracker.currentStep = VERIFYING_AND_SIGNING // DOCSTART 5 val signTransactionFlow = object : SignTransactionFlow(otherParty, VERIFYING_AND_SIGNING.childProgressTracker()) { override fun checkTransaction(stx: SignedTransaction) { - if (stx.tx.outputs.map { it.data }.sumCashBy(me).withoutIssuer() != price) + if (stx.tx.outputStates.sumCashBy(me).withoutIssuer() != price) throw FlowException("Transaction is not sending us the right amount of cash") } } @@ -114,11 +117,13 @@ object TwoPartyTradeFlow { val typeToBuy: Class) : FlowLogic() { // DOCSTART 2 object RECEIVING : ProgressTracker.Step("Waiting for seller trading info") + object VERIFYING : ProgressTracker.Step("Verifying seller assets") object SIGNING : ProgressTracker.Step("Generating and signing transaction proposal") object COLLECTING_SIGNATURES : ProgressTracker.Step("Collecting signatures from other parties") { override fun childProgressTracker() = CollectSignaturesFlow.tracker() } + object RECORDING : ProgressTracker.Step("Recording completed transaction") { // TODO: Currently triggers a race condition on Team City. See https://github.com/corda/corda/issues/733. // override fun childProgressTracker() = FinalityFlow.tracker() @@ -132,69 +137,57 @@ object TwoPartyTradeFlow { override fun call(): SignedTransaction { // Wait for a trade request to come in from the other party. progressTracker.currentStep = RECEIVING - val tradeRequest = receiveAndValidateTradeRequest() + val (assetForSale, tradeRequest) = receiveAndValidateTradeRequest() // Put together a proposed transaction that performs the trade, and sign it. progressTracker.currentStep = SIGNING - val (ptx, cashSigningPubKeys) = assembleSharedTX(tradeRequest) - val partSignedTx = signWithOurKeys(cashSigningPubKeys, ptx) + val (ptx, cashSigningPubKeys) = assembleSharedTX(assetForSale, tradeRequest) + // Now sign the transaction with whatever keys we need to move the cash. + val partSignedTx = serviceHub.signInitialTransaction(ptx, cashSigningPubKeys) // Send the signed transaction to the seller, who must then sign it themselves and commit // it to the ledger by sending it to the notary. progressTracker.currentStep = COLLECTING_SIGNATURES val twiceSignedTx = subFlow(CollectSignaturesFlow(partSignedTx, COLLECTING_SIGNATURES.childProgressTracker())) - // Notarise and record the transaction. progressTracker.currentStep = RECORDING - return subFlow(FinalityFlow(twiceSignedTx, setOf(otherParty, serviceHub.myInfo.legalIdentity))).single() + return subFlow(FinalityFlow(twiceSignedTx)).single() } @Suspendable - private fun receiveAndValidateTradeRequest(): SellerTradeInfo { - val maybeTradeRequest = receive(otherParty) - - progressTracker.currentStep = VERIFYING - maybeTradeRequest.unwrap { - // What is the seller trying to sell us? - val asset = it.assetForSale.state.data + private fun receiveAndValidateTradeRequest(): Pair, SellerTradeInfo> { + val assetForSale = subFlow(ReceiveStateAndRefFlow(otherParty)).single() + return assetForSale to receive(otherParty).unwrap { + progressTracker.currentStep = VERIFYING + val asset = assetForSale.state.data val assetTypeName = asset.javaClass.name - if (it.price > acceptablePrice) throw UnacceptablePriceException(it.price) if (!typeToBuy.isInstance(asset)) throw AssetMismatchException(typeToBuy.name, assetTypeName) - - // Check that the state being sold to us is in a valid chain of transactions, i.e. that the - // seller has a valid chain of custody proving that they own the thing they're selling. - subFlow(ResolveTransactionsFlow(setOf(it.assetForSale.ref.txhash), otherParty)) - - return it + it } } - private fun signWithOurKeys(cashSigningPubKeys: List, ptx: TransactionBuilder): SignedTransaction { - // Now sign the transaction with whatever keys we need to move the cash. - return serviceHub.signInitialTransaction(ptx, cashSigningPubKeys) - } @Suspendable - private fun assembleSharedTX(tradeRequest: SellerTradeInfo): Pair> { - val ptx = TransactionType.General.Builder(notary) + private fun assembleSharedTX(assetForSale: StateAndRef, tradeRequest: SellerTradeInfo): Pair> { + val ptx = TransactionBuilder(notary) // Add input and output states for the movement of cash, by using the Cash contract to generate the states - val (tx, cashSigningPubKeys) = serviceHub.vaultService.generateSpend(ptx, tradeRequest.price, tradeRequest.sellerOwner) + val (tx, cashSigningPubKeys) = Cash.generateSpend(serviceHub, ptx, tradeRequest.price, tradeRequest.sellerOwner) // Add inputs/outputs/a command for the movement of the asset. - tx.addInputState(tradeRequest.assetForSale) + tx.addInputState(assetForSale) // Just pick some new public key for now. This won't be linked with our identity in any way, which is what // we want for privacy reasons: the key is here ONLY to manage and control ownership, it is not intended to // reveal who the owner actually is. The key management service is expected to derive a unique key from some // initial seed in order to provide privacy protection. val freshKey = serviceHub.keyManagementService.freshKey() - val (command, state) = tradeRequest.assetForSale.state.data.withNewOwner(AnonymousParty(freshKey)) - tx.addOutputState(state, tradeRequest.assetForSale.state.notary) - tx.addCommand(command, tradeRequest.assetForSale.state.data.owner.owningKey) + val (command, state) = assetForSale.state.data.withNewOwner(AnonymousParty(freshKey)) + tx.addOutputState(state, assetForSale.state.notary) + tx.addCommand(command, assetForSale.state.data.owner.owningKey) // We set the transaction's time-window: it may be that none of the contracts need this! // But it can't hurt to have one. diff --git a/finance/src/main/kotlin/net/corda/schemas/CashSchemaV1.kt b/finance/src/main/kotlin/net/corda/schemas/CashSchemaV1.kt index e2ede7a16e..8e11e1f8c8 100644 --- a/finance/src/main/kotlin/net/corda/schemas/CashSchemaV1.kt +++ b/finance/src/main/kotlin/net/corda/schemas/CashSchemaV1.kt @@ -2,6 +2,7 @@ package net.corda.schemas import net.corda.core.schemas.MappedSchema import net.corda.core.schemas.PersistentState +import net.corda.core.serialization.CordaSerializable import javax.persistence.* /** @@ -13,6 +14,7 @@ object CashSchema * First version of a cash contract ORM schema that maps all fields of the [Cash] contract state as it stood * at the time of writing. */ +@CordaSerializable object CashSchemaV1 : MappedSchema(schemaFamily = CashSchema.javaClass, version = 1, mappedTypes = listOf(PersistentCashState::class.java)) { @Entity @Table(name = "contract_cash_states", diff --git a/finance/src/main/kotlin/net/corda/schemas/CommercialPaperSchemaV1.kt b/finance/src/main/kotlin/net/corda/schemas/CommercialPaperSchemaV1.kt index 2b24f6be8b..98d600ed09 100644 --- a/finance/src/main/kotlin/net/corda/schemas/CommercialPaperSchemaV1.kt +++ b/finance/src/main/kotlin/net/corda/schemas/CommercialPaperSchemaV1.kt @@ -2,6 +2,7 @@ package net.corda.schemas import net.corda.core.schemas.MappedSchema import net.corda.core.schemas.PersistentState +import net.corda.core.serialization.CordaSerializable import java.time.Instant import javax.persistence.Column import javax.persistence.Entity @@ -17,6 +18,7 @@ object CommercialPaperSchema * First version of a commercial paper contract ORM schema that maps all fields of the [CommercialPaper] contract state * as it stood at the time of writing. */ +@CordaSerializable object CommercialPaperSchemaV1 : MappedSchema(schemaFamily = CommercialPaperSchema.javaClass, version = 1, mappedTypes = listOf(PersistentCommercialPaperState::class.java)) { @Entity @Table(name = "cp_states", diff --git a/finance/src/test/java/net/corda/contracts/asset/CashTestsJava.java b/finance/src/test/java/net/corda/contracts/asset/CashTestsJava.java index ed091c6104..ba479db499 100644 --- a/finance/src/test/java/net/corda/contracts/asset/CashTestsJava.java +++ b/finance/src/test/java/net/corda/contracts/asset/CashTestsJava.java @@ -8,8 +8,6 @@ import org.junit.Test; import static net.corda.core.contracts.ContractsDSL.DOLLARS; import static net.corda.core.contracts.ContractsDSL.issuedBy; -import static net.corda.testing.TestConstants.getDUMMY_PUBKEY_1; -import static net.corda.testing.TestConstants.getDUMMY_PUBKEY_2; import static net.corda.testing.CoreTestUtils.*; /** @@ -18,8 +16,8 @@ import static net.corda.testing.CoreTestUtils.*; public class CashTestsJava { private final OpaqueBytes defaultRef = new OpaqueBytes(new byte[]{1}); private final PartyAndReference defaultIssuer = getMEGA_CORP().ref(defaultRef); - private final Cash.State inState = new Cash.State(issuedBy(DOLLARS(1000), defaultIssuer), new AnonymousParty(getDUMMY_PUBKEY_1())); - private final Cash.State outState = new Cash.State(inState.getAmount(), new AnonymousParty(getDUMMY_PUBKEY_2())); + private final Cash.State inState = new Cash.State(issuedBy(DOLLARS(1000), defaultIssuer), new AnonymousParty(getMEGA_CORP_PUBKEY())); + private final Cash.State outState = new Cash.State(inState.getAmount(), new AnonymousParty(getMINI_CORP_PUBKEY())); @Test public void trivial() { @@ -29,18 +27,18 @@ public class CashTestsJava { tx.failsWith("the amounts balance"); tx.tweak(tw -> { - tw.output(new Cash.State(issuedBy(DOLLARS(2000), defaultIssuer), new AnonymousParty(getDUMMY_PUBKEY_2()))); + tw.output(new Cash.State(issuedBy(DOLLARS(2000), defaultIssuer), new AnonymousParty(getMINI_CORP_PUBKEY()))); return tw.failsWith("the amounts balance"); }); tx.tweak(tw -> { tw.output(outState); // No command arguments - return tw.failsWith("required net.corda.core.contracts.FungibleAsset.Commands.Move command"); + return tw.failsWith("required net.corda.contracts.asset.Cash.Commands.Move command"); }); tx.tweak(tw -> { tw.output(outState); - tw.command(getDUMMY_PUBKEY_2(), new Cash.Commands.Move()); + tw.command(getMINI_CORP_PUBKEY(), new Cash.Commands.Move()); return tw.failsWith("the owning keys are a subset of the signing keys"); }); tx.tweak(tw -> { @@ -48,14 +46,14 @@ public class CashTestsJava { // issuedBy() can't be directly imported because it conflicts with other identically named functions // with different overloads (for some reason). tw.output(CashKt.issuedBy(outState, getMINI_CORP())); - tw.command(getDUMMY_PUBKEY_1(), new Cash.Commands.Move()); - return tw.failsWith("at least one asset input"); + tw.command(getMEGA_CORP_PUBKEY(), new Cash.Commands.Move()); + return tw.failsWith("at least one cash input"); }); // Simple reallocation works. return tx.tweak(tw -> { tw.output(outState); - tw.command(getDUMMY_PUBKEY_1(), new Cash.Commands.Move()); + tw.command(getMEGA_CORP_PUBKEY(), new Cash.Commands.Move()); return tw.verifies(); }); }); diff --git a/finance/src/test/java/net/corda/flows/AbstractStateReplacementFlowTest.java b/finance/src/test/java/net/corda/flows/AbstractStateReplacementFlowTest.java index 9c8d3086e3..c278197812 100644 --- a/finance/src/test/java/net/corda/flows/AbstractStateReplacementFlowTest.java +++ b/finance/src/test/java/net/corda/flows/AbstractStateReplacementFlowTest.java @@ -1,9 +1,10 @@ package net.corda.flows; +import net.corda.core.flows.AbstractStateReplacementFlow; import net.corda.core.identity.Party; -import net.corda.core.identity.Party; -import net.corda.core.utilities.*; -import org.jetbrains.annotations.*; +import net.corda.core.transactions.SignedTransaction; +import net.corda.core.utilities.ProgressTracker; +import org.jetbrains.annotations.NotNull; @SuppressWarnings("unused") public class AbstractStateReplacementFlowTest { @@ -15,7 +16,7 @@ public class AbstractStateReplacementFlowTest { } @Override - protected void verifyProposal(@NotNull AbstractStateReplacementFlow.Proposal proposal) { + protected void verifyProposal(@NotNull SignedTransaction stx, @NotNull AbstractStateReplacementFlow.Proposal proposal) { } } } diff --git a/finance/src/test/kotlin/net/corda/contracts/CommercialPaperTests.kt b/finance/src/test/kotlin/net/corda/contracts/CommercialPaperTests.kt index 35bf37e01e..4f45920180 100644 --- a/finance/src/test/kotlin/net/corda/contracts/CommercialPaperTests.kt +++ b/finance/src/test/kotlin/net/corda/contracts/CommercialPaperTests.kt @@ -1,20 +1,19 @@ package net.corda.contracts import net.corda.contracts.asset.* -import net.corda.testing.contracts.fillWithSomeTestCash import net.corda.core.contracts.* -import net.corda.core.days import net.corda.core.identity.AnonymousParty import net.corda.core.identity.Party import net.corda.core.node.services.Vault import net.corda.core.node.services.VaultService -import net.corda.core.seconds import net.corda.core.transactions.SignedTransaction -import net.corda.node.utilities.configureDatabase -import net.corda.node.utilities.transaction +import net.corda.core.transactions.TransactionBuilder +import net.corda.core.utilities.days +import net.corda.core.utilities.seconds import net.corda.testing.* +import net.corda.testing.contracts.fillWithSomeTestCash import net.corda.testing.node.MockServices -import net.corda.testing.node.makeTestDataSourceProperties +import net.corda.testing.node.makeTestDatabaseAndMockServices import org.junit.Test import org.junit.runner.RunWith import org.junit.runners.Parameterized @@ -59,16 +58,16 @@ class KotlinCommercialPaperTest : ICommercialPaperTestTemplate { } class KotlinCommercialPaperLegacyTest : ICommercialPaperTestTemplate { - override fun getPaper(): ICommercialPaperState = CommercialPaperLegacy.State( + override fun getPaper(): ICommercialPaperState = CommercialPaper.State( issuance = MEGA_CORP.ref(123), owner = MEGA_CORP, faceValue = 1000.DOLLARS `issued by` MEGA_CORP.ref(123), maturityDate = TEST_TX_TIME + 7.days ) - override fun getIssueCommand(notary: Party): CommandData = CommercialPaperLegacy.Commands.Issue() - override fun getRedeemCommand(notary: Party): CommandData = CommercialPaperLegacy.Commands.Redeem() - override fun getMoveCommand(): CommandData = CommercialPaperLegacy.Commands.Move() + override fun getIssueCommand(notary: Party): CommandData = CommercialPaper.Commands.Issue() + override fun getRedeemCommand(notary: Party): CommandData = CommercialPaper.Commands.Redeem() + override fun getMoveCommand(): CommandData = CommercialPaper.Commands.Move() } @RunWith(Parameterized::class) @@ -154,7 +153,7 @@ class CommercialPaperTestsGeneric { fun `key mismatch at issue`() { transaction { output { thisTest.getPaper() } - command(DUMMY_PUBKEY_1) { thisTest.getIssueCommand(DUMMY_NOTARY) } + command(MINI_CORP_PUBKEY) { thisTest.getIssueCommand(DUMMY_NOTARY) } timeWindow(TEST_TX_TIME) this `fails with` "output states are issued by a command signer" } @@ -205,49 +204,30 @@ class CommercialPaperTestsGeneric { private lateinit var alicesVault: Vault private val notaryServices = MockServices(DUMMY_NOTARY_KEY) + private val issuerServices = MockServices(DUMMY_CASH_ISSUER_KEY) private lateinit var moveTX: SignedTransaction @Test fun `issue move and then redeem`() { + initialiseTestSerialization() + val aliceDatabaseAndServices = makeTestDatabaseAndMockServices(keys = listOf(ALICE_KEY)) + val databaseAlice = aliceDatabaseAndServices.first + aliceServices = aliceDatabaseAndServices.second + aliceVaultService = aliceServices.vaultService - val dataSourcePropsAlice = makeTestDataSourceProperties() - val dataSourceAndDatabaseAlice = configureDatabase(dataSourcePropsAlice) - val databaseAlice = dataSourceAndDatabaseAlice.second databaseAlice.transaction { - - aliceServices = object : MockServices(ALICE_KEY) { - override val vaultService: VaultService = makeVaultService(dataSourcePropsAlice) - - override fun recordTransactions(txs: Iterable) { - for (stx in txs) { - validatedTransactions.addTransaction(stx) - } - // Refactored to use notifyAll() as we have no other unit test for that method with multiple transactions. - vaultService.notifyAll(txs.map { it.tx }) - } - } - alicesVault = aliceServices.fillWithSomeTestCash(9000.DOLLARS, atLeastThisManyStates = 1, atMostThisManyStates = 1) + alicesVault = aliceServices.fillWithSomeTestCash(9000.DOLLARS, issuerServices, atLeastThisManyStates = 1, atMostThisManyStates = 1, issuedBy = DUMMY_CASH_ISSUER) aliceVaultService = aliceServices.vaultService } - val dataSourcePropsBigCorp = makeTestDataSourceProperties() - val dataSourceAndDatabaseBigCorp = configureDatabase(dataSourcePropsBigCorp) - val databaseBigCorp = dataSourceAndDatabaseBigCorp.second + val bigCorpDatabaseAndServices = makeTestDatabaseAndMockServices(keys = listOf(BIG_CORP_KEY)) + val databaseBigCorp = bigCorpDatabaseAndServices.first + bigCorpServices = bigCorpDatabaseAndServices.second + bigCorpVaultService = bigCorpServices.vaultService + databaseBigCorp.transaction { - - bigCorpServices = object : MockServices(BIG_CORP_KEY) { - override val vaultService: VaultService = makeVaultService(dataSourcePropsBigCorp) - - override fun recordTransactions(txs: Iterable) { - for (stx in txs) { - validatedTransactions.addTransaction(stx) - } - // Refactored to use notifyAll() as we have no other unit test for that method with multiple transactions. - vaultService.notifyAll(txs.map { it.tx }) - } - } - bigCorpVault = bigCorpServices.fillWithSomeTestCash(13000.DOLLARS, atLeastThisManyStates = 1, atMostThisManyStates = 1) + bigCorpVault = bigCorpServices.fillWithSomeTestCash(13000.DOLLARS, issuerServices, atLeastThisManyStates = 1, atMostThisManyStates = 1, issuedBy = DUMMY_CASH_ISSUER) bigCorpVaultService = bigCorpServices.vaultService } @@ -266,8 +246,8 @@ class CommercialPaperTestsGeneric { databaseAlice.transaction { // Alice pays $9000 to BigCorp to own some of their debt. moveTX = run { - val builder = TransactionType.General.Builder(DUMMY_NOTARY) - aliceVaultService.generateSpend(builder, 9000.DOLLARS, AnonymousParty(bigCorpServices.key.public)) + val builder = TransactionBuilder(DUMMY_NOTARY) + Cash.generateSpend(aliceServices, builder, 9000.DOLLARS, AnonymousParty(bigCorpServices.key.public)) CommercialPaper().generateMove(builder, issueTx.tx.outRef(0), AnonymousParty(aliceServices.key.public)) val ptx = aliceServices.signInitialTransaction(builder) val ptx2 = bigCorpServices.addSignature(ptx) @@ -287,9 +267,9 @@ class CommercialPaperTestsGeneric { databaseBigCorp.transaction { fun makeRedeemTX(time: Instant): Pair { - val builder = TransactionType.General.Builder(DUMMY_NOTARY) + val builder = TransactionBuilder(DUMMY_NOTARY) builder.setTimeWindow(time, 30.seconds) - CommercialPaper().generateRedeem(builder, moveTX.tx.outRef(1), bigCorpVaultService) + CommercialPaper().generateRedeem(builder, moveTX.tx.outRef(1), bigCorpServices) val ptx = aliceServices.signInitialTransaction(builder) val ptx2 = bigCorpServices.addSignature(ptx) val stx = notaryServices.addSignature(ptx2) @@ -310,5 +290,6 @@ class CommercialPaperTestsGeneric { validRedemption.toLedgerTransaction(aliceServices).verify() // soft lock not released after success either!!! (as transaction not recorded) } + resetTestSerialization() } } diff --git a/finance/src/test/kotlin/net/corda/contracts/DummyFungibleContract.kt b/finance/src/test/kotlin/net/corda/contracts/DummyFungibleContract.kt index e73cce41ce..401b8ae324 100644 --- a/finance/src/test/kotlin/net/corda/contracts/DummyFungibleContract.kt +++ b/finance/src/test/kotlin/net/corda/contracts/DummyFungibleContract.kt @@ -1,25 +1,21 @@ package net.corda.contracts.asset -import net.corda.contracts.clause.AbstractConserveAmount -import net.corda.contracts.clause.AbstractIssue -import net.corda.contracts.clause.NoZeroSizedOutputs import net.corda.core.contracts.* -import net.corda.core.contracts.clauses.AllOf -import net.corda.core.contracts.clauses.FirstOf -import net.corda.core.contracts.clauses.GroupClauseVerifier -import net.corda.core.contracts.clauses.verifyClause -import net.corda.core.crypto.* +import net.corda.core.crypto.SecureHash +import net.corda.core.crypto.newSecureRandom +import net.corda.core.crypto.toBase58String import net.corda.core.identity.AbstractParty import net.corda.core.identity.Party +import net.corda.core.internal.Emoji import net.corda.core.schemas.MappedSchema import net.corda.core.schemas.PersistentState import net.corda.core.schemas.QueryableState -import net.corda.core.serialization.CordaSerializable +import net.corda.core.transactions.LedgerTransaction import net.corda.core.transactions.TransactionBuilder -import net.corda.core.utilities.Emoji import net.corda.schemas.SampleCashSchemaV1 import net.corda.schemas.SampleCashSchemaV2 import net.corda.schemas.SampleCashSchemaV3 +import java.security.PublicKey import java.util.* class DummyFungibleContract : OnLedgerAsset() { @@ -28,29 +24,6 @@ class DummyFungibleContract : OnLedgerAsset>): List> = commands.select() - interface Clauses { - class Group : GroupClauseVerifier>(AllOf>( - NoZeroSizedOutputs(), - FirstOf>( - Issue(), - ConserveAmount()) - ) - ) { - override fun groupStates(tx: TransactionForContract): List>> - = tx.groupStates> { it.amount.token } - } - - class Issue : AbstractIssue( - sum = { sumCash() }, - sumOrZero = { sumCashOrZero(it) } - ) { - override val requiredCommands: Set> = setOf(Commands.Issue::class.java) - } - - @CordaSerializable - class ConserveAmount : AbstractConserveAmount() - } - data class State( override val amount: Amount>, @@ -68,7 +41,7 @@ class DummyFungibleContract : OnLedgerAsset it.amount.token } + + for ((inputs, outputs, key) in groups) { + // Either inputs or outputs could be empty. + val issuer = key.issuer + val currency = key.product + + requireThat { + "there are no zero sized outputs" using (outputs.none { it.amount.quantity == 0L }) + } + + val issueCommand = tx.commands.select().firstOrNull() + if (issueCommand != null) { + verifyIssueCommand(inputs, outputs, tx, issueCommand, currency, issuer) + } else { + val inputAmount = inputs.sumCashOrNull() ?: throw IllegalArgumentException("there is at least one input for this group") + val outputAmount = outputs.sumCashOrZero(Issued(issuer, currency)) + + val exitKeys: Set = inputs.flatMap { it.exitKeys }.toSet() + val exitCommand = tx.commands.select(parties = null, signers = exitKeys).filter { it.value.amount.token == key }.singleOrNull() + val amountExitingLedger = exitCommand?.value?.amount ?: Amount(0, Issued(issuer, currency)) + + requireThat { + "there are no zero sized inputs" using inputs.none { it.amount.quantity == 0L } + "for reference ${issuer.reference} at issuer ${issuer.party} the amounts balance: ${inputAmount.quantity} - ${amountExitingLedger.quantity} != ${outputAmount.quantity}" using + (inputAmount == outputAmount + amountExitingLedger) + } + + verifyMoveCommand(inputs, tx.commands) + } + } + } + + private fun verifyIssueCommand(inputs: List, + outputs: List, + tx: LedgerTransaction, + issueCommand: AuthenticatedObject, + currency: Currency, + issuer: PartyAndReference) { + // If we have an issue command, perform special processing: the group is allowed to have no inputs, + // and the output states must have a deposit reference owned by the signer. + // + // Whilst the transaction *may* have no inputs, it can have them, and in this case the outputs must + // sum to more than the inputs. An issuance of zero size is not allowed. + // + // Note that this means literally anyone with access to the network can issue cash claims of arbitrary + // amounts! It is up to the recipient to decide if the backing party is trustworthy or not, via some + // as-yet-unwritten identity service. See ADP-22 for discussion. + + // The grouping ensures that all outputs have the same deposit reference and currency. + val inputAmount = inputs.sumCashOrZero(Issued(issuer, currency)) + val outputAmount = outputs.sumCash() + val cashCommands = tx.commands.select() + requireThat { + "the issue command has a nonce" using (issueCommand.value.nonce != 0L) + // TODO: This doesn't work with the trader demo, so use the underlying key instead + // "output states are issued by a command signer" by (issuer.party in issueCommand.signingParties) + "output states are issued by a command signer" using (issuer.party.owningKey in issueCommand.signers) + "output values sum to more than the inputs" using (outputAmount > inputAmount) + "there is only a single issue command" using (cashCommands.count() == 1) + } + } } diff --git a/finance/src/test/kotlin/net/corda/contracts/asset/CashTests.kt b/finance/src/test/kotlin/net/corda/contracts/asset/CashTests.kt index 5a0771ed3d..c2ab2d3d75 100644 --- a/finance/src/test/kotlin/net/corda/contracts/asset/CashTests.kt +++ b/finance/src/test/kotlin/net/corda/contracts/asset/CashTests.kt @@ -1,86 +1,77 @@ package net.corda.contracts.asset -import net.corda.testing.contracts.fillWithSomeTestCash import net.corda.core.contracts.* -import net.corda.testing.contracts.DummyState import net.corda.core.crypto.SecureHash import net.corda.core.crypto.generateKeyPair import net.corda.core.identity.AbstractParty import net.corda.core.identity.AnonymousParty import net.corda.core.identity.Party import net.corda.core.node.services.VaultService -import net.corda.core.node.services.unconsumedStates -import net.corda.core.utilities.OpaqueBytes -import net.corda.core.transactions.SignedTransaction +import net.corda.core.node.services.queryBy +import net.corda.core.transactions.TransactionBuilder import net.corda.core.transactions.WireTransaction +import net.corda.core.utilities.OpaqueBytes import net.corda.node.services.vault.NodeVaultService -import net.corda.node.utilities.configureDatabase -import net.corda.node.utilities.transaction +import net.corda.node.utilities.CordaPersistence import net.corda.testing.* -import net.corda.testing.node.MockKeyManagementService +import net.corda.testing.contracts.DummyState +import net.corda.testing.contracts.fillWithSomeTestCash import net.corda.testing.node.MockServices -import net.corda.testing.node.makeTestDataSourceProperties -import org.jetbrains.exposed.sql.Database +import net.corda.testing.node.makeTestDatabaseAndMockServices +import org.junit.After import org.junit.Before import org.junit.Test -import java.io.Closeable import java.security.KeyPair import java.util.* import kotlin.test.* -class CashTests { +class CashTests : TestDependencyInjectionBase() { val defaultRef = OpaqueBytes(ByteArray(1, { 1 })) val defaultIssuer = MEGA_CORP.ref(defaultRef) val inState = Cash.State( amount = 1000.DOLLARS `issued by` defaultIssuer, - owner = AnonymousParty(DUMMY_PUBKEY_1) + owner = AnonymousParty(ALICE_PUBKEY) ) // Input state held by the issuer val issuerInState = inState.copy(owner = defaultIssuer.party) - val outState = issuerInState.copy(owner = AnonymousParty(DUMMY_PUBKEY_2)) + val outState = issuerInState.copy(owner = AnonymousParty(BOB_PUBKEY)) fun Cash.State.editDepositRef(ref: Byte) = copy( amount = Amount(amount.quantity, token = amount.token.copy(amount.token.issuer.copy(reference = OpaqueBytes.of(ref)))) ) lateinit var miniCorpServices: MockServices + lateinit var megaCorpServices: MockServices val vault: VaultService get() = miniCorpServices.vaultService - lateinit var dataSource: Closeable - lateinit var database: Database + lateinit var database: CordaPersistence lateinit var vaultStatesUnconsumed: List> @Before fun setUp() { LogHelper.setLevel(NodeVaultService::class) - val dataSourceProps = makeTestDataSourceProperties() - val dataSourceAndDatabase = configureDatabase(dataSourceProps) - dataSource = dataSourceAndDatabase.first - database = dataSourceAndDatabase.second + megaCorpServices = MockServices(MEGA_CORP_KEY) + val databaseAndServices = makeTestDatabaseAndMockServices(keys = listOf(MINI_CORP_KEY, MEGA_CORP_KEY, OUR_KEY)) + database = databaseAndServices.first + miniCorpServices = databaseAndServices.second + database.transaction { - miniCorpServices = object : MockServices(MINI_CORP_KEY) { - override val keyManagementService: MockKeyManagementService = MockKeyManagementService(identityService, MINI_CORP_KEY, MEGA_CORP_KEY, OUR_KEY) - override val vaultService: VaultService = makeVaultService(dataSourceProps) - - override fun recordTransactions(txs: Iterable) { - for (stx in txs) { - validatedTransactions.addTransaction(stx) - } - // Refactored to use notifyAll() as we have no other unit test for that method with multiple transactions. - vaultService.notifyAll(txs.map { it.tx }) - } - } - miniCorpServices.fillWithSomeTestCash(howMuch = 100.DOLLARS, atLeastThisManyStates = 1, atMostThisManyStates = 1, - issuedBy = MEGA_CORP.ref(1), issuerKey = MEGA_CORP_KEY, ownedBy = OUR_IDENTITY_1) + ownedBy = OUR_IDENTITY_1, issuedBy = MEGA_CORP.ref(1), issuerServices = megaCorpServices) miniCorpServices.fillWithSomeTestCash(howMuch = 400.DOLLARS, atLeastThisManyStates = 1, atMostThisManyStates = 1, - issuedBy = MEGA_CORP.ref(1), issuerKey = MEGA_CORP_KEY, ownedBy = OUR_IDENTITY_1) + ownedBy = OUR_IDENTITY_1, issuedBy = MEGA_CORP.ref(1), issuerServices = megaCorpServices) miniCorpServices.fillWithSomeTestCash(howMuch = 80.DOLLARS, atLeastThisManyStates = 1, atMostThisManyStates = 1, - issuedBy = MINI_CORP.ref(1), issuerKey = MINI_CORP_KEY, ownedBy = OUR_IDENTITY_1) + ownedBy = OUR_IDENTITY_1, issuedBy = MINI_CORP.ref(1), issuerServices = miniCorpServices) miniCorpServices.fillWithSomeTestCash(howMuch = 80.SWISS_FRANCS, atLeastThisManyStates = 1, atMostThisManyStates = 1, - issuedBy = MINI_CORP.ref(1), issuerKey = MINI_CORP_KEY, ownedBy = OUR_IDENTITY_1) + ownedBy = OUR_IDENTITY_1, issuedBy = MINI_CORP.ref(1), issuerServices = miniCorpServices) - vaultStatesUnconsumed = miniCorpServices.vaultService.unconsumedStates().toList() + vaultStatesUnconsumed = miniCorpServices.vaultQueryService.queryBy().states } + resetTestSerialization() + } + + @After + fun tearDown() { + database.close() } @Test @@ -96,23 +87,23 @@ class CashTests { tweak { output { outState } // No command arguments - this `fails with` "required net.corda.core.contracts.FungibleAsset.Commands.Move command" + this `fails with` "required net.corda.contracts.asset.Cash.Commands.Move command" } tweak { output { outState } - command(DUMMY_PUBKEY_2) { Cash.Commands.Move() } + command(BOB_PUBKEY) { Cash.Commands.Move() } this `fails with` "the owning keys are a subset of the signing keys" } tweak { output { outState } output { outState `issued by` MINI_CORP } - command(DUMMY_PUBKEY_1) { Cash.Commands.Move() } - this `fails with` "at least one asset input" + command(ALICE_PUBKEY) { Cash.Commands.Move() } + this `fails with` "at least one cash input" } // Simple reallocation works. tweak { output { outState } - command(DUMMY_PUBKEY_1) { Cash.Commands.Move() } + command(ALICE_PUBKEY) { Cash.Commands.Move() } this.verifies() } } @@ -126,7 +117,7 @@ class CashTests { output { outState } command(MINI_CORP_PUBKEY) { Cash.Commands.Move() } - this `fails with` "there is at least one asset input" + this `fails with` "there is at least one cash input for this group" } } @@ -136,14 +127,14 @@ class CashTests { // institution is allowed to issue as much cash as they want. transaction { output { outState } - command(DUMMY_PUBKEY_1) { Cash.Commands.Issue() } + command(ALICE_PUBKEY) { Cash.Commands.Issue() } this `fails with` "output states are issued by a command signer" } transaction { output { Cash.State( amount = 1000.DOLLARS `issued by` MINI_CORP.ref(12, 34), - owner = AnonymousParty(DUMMY_PUBKEY_1) + owner = AnonymousParty(ALICE_PUBKEY) ) } tweak { @@ -157,25 +148,27 @@ class CashTests { @Test fun generateIssueRaw() { + initialiseTestSerialization() // Test generation works. - val tx: WireTransaction = TransactionType.General.Builder(notary = null).apply { - Cash().generateIssue(this, 100.DOLLARS `issued by` MINI_CORP.ref(12, 34), owner = AnonymousParty(DUMMY_PUBKEY_1), notary = DUMMY_NOTARY) + val tx: WireTransaction = TransactionBuilder(notary = null).apply { + Cash().generateIssue(this, 100.DOLLARS `issued by` MINI_CORP.ref(12, 34), owner = AnonymousParty(ALICE_PUBKEY), notary = DUMMY_NOTARY) }.toWireTransaction() assertTrue(tx.inputs.isEmpty()) - val s = tx.outputs[0].data as Cash.State + val s = tx.outputsOfType().single() assertEquals(100.DOLLARS `issued by` MINI_CORP.ref(12, 34), s.amount) assertEquals(MINI_CORP as AbstractParty, s.amount.token.issuer.party) - assertEquals(AnonymousParty(DUMMY_PUBKEY_1), s.owner) + assertEquals(AnonymousParty(ALICE_PUBKEY), s.owner) assertTrue(tx.commands[0].value is Cash.Commands.Issue) assertEquals(MINI_CORP_PUBKEY, tx.commands[0].signers[0]) } @Test fun generateIssueFromAmount() { + initialiseTestSerialization() // Test issuance from an issued amount val amount = 100.DOLLARS `issued by` MINI_CORP.ref(12, 34) - val tx: WireTransaction = TransactionType.General.Builder(notary = null).apply { - Cash().generateIssue(this, amount, owner = AnonymousParty(DUMMY_PUBKEY_1), notary = DUMMY_NOTARY) + val tx: WireTransaction = TransactionBuilder(notary = null).apply { + Cash().generateIssue(this, amount, owner = AnonymousParty(ALICE_PUBKEY), notary = DUMMY_NOTARY) }.toWireTransaction() assertTrue(tx.inputs.isEmpty()) assertEquals(tx.outputs[0], tx.outputs[0]) @@ -190,7 +183,7 @@ class CashTests { // Move fails: not allowed to summon money. tweak { - command(DUMMY_PUBKEY_1) { Cash.Commands.Move() } + command(ALICE_PUBKEY) { Cash.Commands.Move() } this `fails with` "the amounts balance" } @@ -224,15 +217,7 @@ class CashTests { command(MEGA_CORP_PUBKEY) { Cash.Commands.Issue() } tweak { command(MEGA_CORP_PUBKEY) { Cash.Commands.Issue() } - this `fails with` "List has more than one element." - } - tweak { - command(MEGA_CORP_PUBKEY) { Cash.Commands.Move() } - this `fails with` "The following commands were not matched at the end of execution" - } - tweak { - command(MEGA_CORP_PUBKEY) { Cash.Commands.Exit(inState.amount.splitEvenly(2).first()) } - this `fails with` "The following commands were not matched at the end of execution" + this `fails with` "there is only a single issue command" } this.verifies() } @@ -244,14 +229,15 @@ class CashTests { */ @Test(expected = IllegalStateException::class) fun `reject issuance with inputs`() { + initialiseTestSerialization() // Issue some cash - var ptx = TransactionType.General.Builder(DUMMY_NOTARY) + var ptx = TransactionBuilder(DUMMY_NOTARY) Cash().generateIssue(ptx, 100.DOLLARS `issued by` MINI_CORP.ref(12, 34), owner = MINI_CORP, notary = DUMMY_NOTARY) val tx = miniCorpServices.signInitialTransaction(ptx) // Include the previously issued cash in a new issuance command - ptx = TransactionType.General.Builder(DUMMY_NOTARY) + ptx = TransactionBuilder(DUMMY_NOTARY) ptx.addInputState(tx.tx.outRef(0)) Cash().generateIssue(ptx, 100.DOLLARS `issued by` MINI_CORP.ref(12, 34), owner = MINI_CORP, notary = DUMMY_NOTARY) } @@ -260,7 +246,7 @@ class CashTests { fun testMergeSplit() { // Splitting value works. transaction { - command(DUMMY_PUBKEY_1) { Cash.Commands.Move() } + command(ALICE_PUBKEY) { Cash.Commands.Move() } tweak { input { inState } val splits4 = inState.amount.splitEvenly(4) @@ -327,7 +313,7 @@ class CashTests { input { inState.copy( amount = 150.POUNDS `issued by` defaultIssuer, - owner = AnonymousParty(DUMMY_PUBKEY_2) + owner = AnonymousParty(BOB_PUBKEY) ) } output { outState.copy(amount = 1150.DOLLARS `issued by` defaultIssuer) } @@ -338,7 +324,7 @@ class CashTests { input { inState } input { inState `issued by` MINI_CORP } output { outState } - command(DUMMY_PUBKEY_1) { Cash.Commands.Move() } + command(ALICE_PUBKEY) { Cash.Commands.Move() } this `fails with` "the amounts balance" } // Can't combine two different deposits at the same issuer. @@ -365,7 +351,7 @@ class CashTests { tweak { command(MEGA_CORP_PUBKEY) { Cash.Commands.Exit(200.DOLLARS `issued by` defaultIssuer) } - this `fails with` "required net.corda.core.contracts.FungibleAsset.Commands.Move command" + this `fails with` "required net.corda.contracts.asset.Cash.Commands.Move command" tweak { command(MEGA_CORP_PUBKEY) { Cash.Commands.Move() } @@ -404,7 +390,7 @@ class CashTests { input { inState } output { outState.copy(amount = inState.amount - (200.DOLLARS `issued by` defaultIssuer)) } command(MEGA_CORP_PUBKEY) { Cash.Commands.Exit(200.DOLLARS `issued by` defaultIssuer) } - command(DUMMY_PUBKEY_1) { Cash.Commands.Move() } + command(ALICE_PUBKEY) { Cash.Commands.Move() } this `fails with` "the amounts balance" } } @@ -418,20 +404,20 @@ class CashTests { // Can't merge them together. tweak { - output { inState.copy(owner = AnonymousParty(DUMMY_PUBKEY_2), amount = 2000.DOLLARS `issued by` defaultIssuer) } + output { inState.copy(owner = AnonymousParty(BOB_PUBKEY), amount = 2000.DOLLARS `issued by` defaultIssuer) } this `fails with` "the amounts balance" } // Missing MiniCorp deposit tweak { - output { inState.copy(owner = AnonymousParty(DUMMY_PUBKEY_2)) } - output { inState.copy(owner = AnonymousParty(DUMMY_PUBKEY_2)) } + output { inState.copy(owner = AnonymousParty(BOB_PUBKEY)) } + output { inState.copy(owner = AnonymousParty(BOB_PUBKEY)) } this `fails with` "the amounts balance" } // This works. - output { inState.copy(owner = AnonymousParty(DUMMY_PUBKEY_2)) } - output { inState.copy(owner = AnonymousParty(DUMMY_PUBKEY_2)) `issued by` MINI_CORP } - command(DUMMY_PUBKEY_1) { Cash.Commands.Move() } + output { inState.copy(owner = AnonymousParty(BOB_PUBKEY)) } + output { inState.copy(owner = AnonymousParty(BOB_PUBKEY)) `issued by` MINI_CORP } + command(ALICE_PUBKEY) { Cash.Commands.Move() } this.verifies() } } @@ -440,12 +426,12 @@ class CashTests { fun multiCurrency() { // Check we can do an atomic currency trade tx. transaction { - val pounds = Cash.State(658.POUNDS `issued by` MINI_CORP.ref(3, 4, 5), AnonymousParty(DUMMY_PUBKEY_2)) - input { inState `owned by` AnonymousParty(DUMMY_PUBKEY_1) } + val pounds = Cash.State(658.POUNDS `issued by` MINI_CORP.ref(3, 4, 5), AnonymousParty(BOB_PUBKEY)) + input { inState `owned by` AnonymousParty(ALICE_PUBKEY) } input { pounds } - output { inState `owned by` AnonymousParty(DUMMY_PUBKEY_2) } - output { pounds `owned by` AnonymousParty(DUMMY_PUBKEY_1) } - command(DUMMY_PUBKEY_1, DUMMY_PUBKEY_2) { Cash.Commands.Move() } + output { inState `owned by` AnonymousParty(BOB_PUBKEY) } + output { pounds `owned by` AnonymousParty(ALICE_PUBKEY) } + command(ALICE_PUBKEY, BOB_PUBKEY) { Cash.Commands.Move() } this.verifies() } @@ -458,7 +444,7 @@ class CashTests { val OUR_KEY: KeyPair by lazy { generateKeyPair() } val OUR_IDENTITY_1: AbstractParty get() = AnonymousParty(OUR_KEY.public) - val THEIR_IDENTITY_1 = AnonymousParty(DUMMY_PUBKEY_2) + val THEIR_IDENTITY_1 = AnonymousParty(MINI_CORP_PUBKEY) fun makeCash(amount: Amount, corp: Party, depositRef: Byte = 1) = StateAndRef( @@ -477,15 +463,15 @@ class CashTests { * Generate an exit transaction, removing some amount of cash from the ledger. */ fun makeExit(amount: Amount, corp: Party, depositRef: Byte = 1): WireTransaction { - val tx = TransactionType.General.Builder(DUMMY_NOTARY) + val tx = TransactionBuilder(DUMMY_NOTARY) Cash().generateExit(tx, Amount(amount.quantity, Issued(corp.ref(depositRef), amount.token)), WALLET) return tx.toWireTransaction() } fun makeSpend(amount: Amount, dest: AbstractParty): WireTransaction { - val tx = TransactionType.General.Builder(DUMMY_NOTARY) + val tx = TransactionBuilder(DUMMY_NOTARY) database.transaction { - vault.generateSpend(tx, amount, dest) + Cash.generateSpend(miniCorpServices, tx, amount, dest) } return tx.toWireTransaction() } @@ -495,6 +481,7 @@ class CashTests { */ @Test fun generateSimpleExit() { + initialiseTestSerialization() val wtx = makeExit(100.DOLLARS, MEGA_CORP, 1) assertEquals(WALLET[0].ref, wtx.inputs[0]) assertEquals(0, wtx.outputs.size) @@ -510,10 +497,11 @@ class CashTests { */ @Test fun generatePartialExit() { + initialiseTestSerialization() val wtx = makeExit(50.DOLLARS, MEGA_CORP, 1) assertEquals(WALLET[0].ref, wtx.inputs[0]) assertEquals(1, wtx.outputs.size) - assertEquals(WALLET[0].state.data.copy(amount = WALLET[0].state.data.amount.splitEvenly(2).first()), wtx.outputs[0].data) + assertEquals(WALLET[0].state.data.copy(amount = WALLET[0].state.data.amount.splitEvenly(2).first()), wtx.getOutput(0)) } /** @@ -521,6 +509,7 @@ class CashTests { */ @Test fun generateAbsentExit() { + initialiseTestSerialization() assertFailsWith { makeExit(100.POUNDS, MEGA_CORP, 1) } } @@ -529,6 +518,7 @@ class CashTests { */ @Test fun generateInvalidReferenceExit() { + initialiseTestSerialization() assertFailsWith { makeExit(100.POUNDS, MEGA_CORP, 2) } } @@ -537,6 +527,7 @@ class CashTests { */ @Test fun generateInsufficientExit() { + initialiseTestSerialization() assertFailsWith { makeExit(1000.DOLLARS, MEGA_CORP, 1) } } @@ -545,6 +536,7 @@ class CashTests { */ @Test fun generateOwnerWithNoStatesExit() { + initialiseTestSerialization() assertFailsWith { makeExit(100.POUNDS, CHARLIE, 1) } } @@ -553,34 +545,34 @@ class CashTests { */ @Test fun generateExitWithEmptyVault() { + initialiseTestSerialization() assertFailsWith { - val tx = TransactionType.General.Builder(DUMMY_NOTARY) + val tx = TransactionBuilder(DUMMY_NOTARY) Cash().generateExit(tx, Amount(100, Issued(CHARLIE.ref(1), GBP)), emptyList()) } } @Test fun generateSimpleDirectSpend() { - + initialiseTestSerialization() database.transaction { - val wtx = makeSpend(100.DOLLARS, THEIR_IDENTITY_1) @Suppress("UNCHECKED_CAST") val vaultState = vaultStatesUnconsumed.elementAt(0) assertEquals(vaultState.ref, wtx.inputs[0]) - assertEquals(vaultState.state.data.copy(owner = THEIR_IDENTITY_1), wtx.outputs[0].data) + assertEquals(vaultState.state.data.copy(owner = THEIR_IDENTITY_1), wtx.getOutput(0)) assertEquals(OUR_IDENTITY_1.owningKey, wtx.commands.single { it.value is Cash.Commands.Move }.signers[0]) } } @Test fun generateSimpleSpendWithParties() { - + initialiseTestSerialization() database.transaction { - val tx = TransactionType.General.Builder(DUMMY_NOTARY) - vault.generateSpend(tx, 80.DOLLARS, ALICE, setOf(MINI_CORP)) + val tx = TransactionBuilder(DUMMY_NOTARY) + Cash.generateSpend(miniCorpServices, tx, 80.DOLLARS, ALICE, setOf(MINI_CORP)) assertEquals(vaultStatesUnconsumed.elementAt(2).ref, tx.inputStates()[0]) } @@ -588,7 +580,7 @@ class CashTests { @Test fun generateSimpleSpendWithChange() { - + initialiseTestSerialization() database.transaction { val wtx = makeSpend(10.DOLLARS, THEIR_IDENTITY_1) @@ -604,7 +596,7 @@ class CashTests { @Test fun generateSpendWithTwoInputs() { - + initialiseTestSerialization() database.transaction { val wtx = makeSpend(500.DOLLARS, THEIR_IDENTITY_1) @@ -613,14 +605,14 @@ class CashTests { val vaultState1 = vaultStatesUnconsumed.elementAt(1) assertEquals(vaultState0.ref, wtx.inputs[0]) assertEquals(vaultState1.ref, wtx.inputs[1]) - assertEquals(vaultState0.state.data.copy(owner = THEIR_IDENTITY_1, amount = 500.DOLLARS `issued by` defaultIssuer), wtx.outputs[0].data) + assertEquals(vaultState0.state.data.copy(owner = THEIR_IDENTITY_1, amount = 500.DOLLARS `issued by` defaultIssuer), wtx.getOutput(0)) assertEquals(OUR_IDENTITY_1.owningKey, wtx.commands.single { it.value is Cash.Commands.Move }.signers[0]) } } @Test fun generateSpendMixedDeposits() { - + initialiseTestSerialization() database.transaction { val wtx = makeSpend(580.DOLLARS, THEIR_IDENTITY_1) assertEquals(3, wtx.inputs.size) @@ -634,14 +626,14 @@ class CashTests { assertEquals(vaultState1.ref, wtx.inputs[1]) assertEquals(vaultState2.ref, wtx.inputs[2]) assertEquals(vaultState0.state.data.copy(owner = THEIR_IDENTITY_1, amount = 500.DOLLARS `issued by` defaultIssuer), wtx.outputs[1].data) - assertEquals(vaultState2.state.data.copy(owner = THEIR_IDENTITY_1), wtx.outputs[0].data) + assertEquals(vaultState2.state.data.copy(owner = THEIR_IDENTITY_1), wtx.getOutput(0)) assertEquals(OUR_IDENTITY_1.owningKey, wtx.commands.single { it.value is Cash.Commands.Move }.signers[0]) } } @Test fun generateSpendInsufficientBalance() { - + initialiseTestSerialization() database.transaction { val e: InsufficientBalanceException = assertFailsWith("balance") { @@ -753,7 +745,7 @@ class CashTests { transaction { input("MEGA_CORP cash") - output("MEGA_CORP cash".output().copy(owner = AnonymousParty(DUMMY_PUBKEY_1))) + output("MEGA_CORP cash".output().copy(owner = AnonymousParty(ALICE_PUBKEY))) command(MEGA_CORP_PUBKEY) { Cash.Commands.Move() } this.verifies() } diff --git a/finance/src/test/kotlin/net/corda/contracts/asset/ObligationTests.kt b/finance/src/test/kotlin/net/corda/contracts/asset/ObligationTests.kt index 5d393f71db..d661b0f7ab 100644 --- a/finance/src/test/kotlin/net/corda/contracts/asset/ObligationTests.kt +++ b/finance/src/test/kotlin/net/corda/contracts/asset/ObligationTests.kt @@ -4,18 +4,20 @@ import net.corda.contracts.Commodity import net.corda.contracts.NetType import net.corda.contracts.asset.Obligation.Lifecycle import net.corda.core.contracts.* -import net.corda.testing.contracts.DummyState import net.corda.core.crypto.SecureHash -import net.corda.core.hours import net.corda.core.crypto.testing.NULL_PARTY import net.corda.core.identity.AbstractParty import net.corda.core.identity.AnonymousParty +import net.corda.core.transactions.TransactionBuilder +import net.corda.core.utilities.NonEmptySet import net.corda.core.utilities.OpaqueBytes -import net.corda.core.utilities.* +import net.corda.core.utilities.days +import net.corda.core.utilities.hours import net.corda.testing.* +import net.corda.testing.contracts.DummyState import net.corda.testing.node.MockServices +import org.junit.After import org.junit.Test -import java.time.Duration import java.time.Instant import java.time.temporal.ChronoUnit import java.util.* @@ -28,9 +30,9 @@ class ObligationTests { val defaultRef = OpaqueBytes.of(1) val defaultIssuer = MEGA_CORP.ref(defaultRef) val oneMillionDollars = 1000000.DOLLARS `issued by` defaultIssuer - val trustedCashContract = nonEmptySetOf(SecureHash.randomSHA256() as SecureHash) - val megaIssuedDollars = nonEmptySetOf(Issued(defaultIssuer, USD)) - val megaIssuedPounds = nonEmptySetOf(Issued(defaultIssuer, GBP)) + val trustedCashContract = NonEmptySet.of(SecureHash.randomSHA256() as SecureHash) + val megaIssuedDollars = NonEmptySet.of(Issued(defaultIssuer, USD)) + val megaIssuedPounds = NonEmptySet.of(Issued(defaultIssuer, GBP)) val fivePm: Instant = TEST_TX_TIME.truncatedTo(ChronoUnit.DAYS) + 17.hours val sixPm: Instant = fivePm + 1.hours val megaCorpDollarSettlement = Obligation.Terms(trustedCashContract, megaIssuedDollars, fivePm) @@ -42,7 +44,7 @@ class ObligationTests { quantity = 1000.DOLLARS.quantity, beneficiary = CHARLIE ) - val outState = inState.copy(beneficiary = AnonymousParty(DUMMY_PUBKEY_2)) + val outState = inState.copy(beneficiary = AnonymousParty(BOB_PUBKEY)) val miniCorpServices = MockServices(MINI_CORP_KEY) val notaryServices = MockServices(DUMMY_NOTARY_KEY) @@ -57,6 +59,11 @@ class ObligationTests { } } + @After + fun reset() { + resetTestSerialization() + } + @Test fun trivial() { transaction { @@ -70,18 +77,18 @@ class ObligationTests { tweak { output { outState } // No command arguments - this `fails with` "required net.corda.core.contracts.FungibleAsset.Commands.Move command" + this `fails with` "required net.corda.contracts.asset.Obligation.Commands.Move command" } tweak { output { outState } - command(DUMMY_PUBKEY_2) { Obligation.Commands.Move() } + command(BOB_PUBKEY) { Obligation.Commands.Move() } this `fails with` "the owning keys are a subset of the signing keys" } tweak { output { outState } output { outState `issued by` MINI_CORP } command(CHARLIE.owningKey) { Obligation.Commands.Move() } - this `fails with` "at least one asset input" + this `fails with` "at least one obligation input" } // Simple reallocation works. tweak { @@ -100,7 +107,7 @@ class ObligationTests { output { outState } command(MINI_CORP_PUBKEY) { Obligation.Commands.Move() } - this `fails with` "there is at least one asset input" + this `fails with` "at least one obligation input" } // Check we can issue money only as long as the issuer institution is a command signer, i.e. any recognised @@ -127,8 +134,9 @@ class ObligationTests { this.verifies() } + initialiseTestSerialization() // Test generation works. - val tx = TransactionType.General.Builder(notary = null).apply { + val tx = TransactionBuilder(notary = null).apply { Obligation().generateIssue(this, MINI_CORP, megaCorpDollarSettlement, 100.DOLLARS.quantity, beneficiary = CHARLIE, notary = DUMMY_NOTARY) }.toWireTransaction() @@ -139,9 +147,10 @@ class ObligationTests { beneficiary = CHARLIE, template = megaCorpDollarSettlement ) - assertEquals(tx.outputs[0].data, expected) + assertEquals(tx.getOutput(0), expected) assertTrue(tx.commands[0].value is Obligation.Commands.Issue) assertEquals(MINI_CORP_PUBKEY, tx.commands[0].signers[0]) + resetTestSerialization() // We can consume $1000 in a transaction and output $2000 as long as it's signed by an issuer. transaction { @@ -184,15 +193,7 @@ class ObligationTests { command(MEGA_CORP_PUBKEY) { Obligation.Commands.Issue() } tweak { command(MEGA_CORP_PUBKEY) { Obligation.Commands.Issue() } - this `fails with` "List has more than one element." - } - tweak { - command(MEGA_CORP_PUBKEY) { Obligation.Commands.Move() } - this `fails with` "The following commands were not matched at the end of execution" - } - tweak { - command(MEGA_CORP_PUBKEY) { Obligation.Commands.Exit(inState.amount.splitEvenly(2).first()) } - this `fails with` "The following commands were not matched at the end of execution" + this `fails with` "there is only a single issue command" } this.verifies() } @@ -204,15 +205,16 @@ class ObligationTests { */ @Test(expected = IllegalStateException::class) fun `reject issuance with inputs`() { + initialiseTestSerialization() // Issue some obligation - val tx = TransactionType.General.Builder(DUMMY_NOTARY).apply { + val tx = TransactionBuilder(DUMMY_NOTARY).apply { Obligation().generateIssue(this, MINI_CORP, megaCorpDollarSettlement, 100.DOLLARS.quantity, beneficiary = MINI_CORP, notary = DUMMY_NOTARY) }.toWireTransaction() // Include the previously issued obligation in a new issuance command - val ptx = TransactionType.General.Builder(DUMMY_NOTARY) + val ptx = TransactionBuilder(DUMMY_NOTARY) ptx.addInputState(tx.outRef>(0)) Obligation().generateIssue(ptx, MINI_CORP, megaCorpDollarSettlement, 100.DOLLARS.quantity, beneficiary = MINI_CORP, notary = DUMMY_NOTARY) @@ -221,9 +223,10 @@ class ObligationTests { /** Test generating a transaction to net two obligations of the same size, and therefore there are no outputs. */ @Test fun `generate close-out net transaction`() { + initialiseTestSerialization() val obligationAliceToBob = oneMillionDollars.OBLIGATION between Pair(ALICE, BOB) val obligationBobToAlice = oneMillionDollars.OBLIGATION between Pair(BOB, ALICE) - val tx = TransactionType.General.Builder(DUMMY_NOTARY).apply { + val tx = TransactionBuilder(DUMMY_NOTARY).apply { Obligation().generateCloseOutNetting(this, ALICE, obligationAliceToBob, obligationBobToAlice) }.toWireTransaction() assertEquals(0, tx.outputs.size) @@ -232,23 +235,25 @@ class ObligationTests { /** Test generating a transaction to net two obligations of the different sizes, and confirm the balance is correct. */ @Test fun `generate close-out net transaction with remainder`() { + initialiseTestSerialization() val obligationAliceToBob = (2000000.DOLLARS `issued by` defaultIssuer).OBLIGATION between Pair(ALICE, BOB) val obligationBobToAlice = oneMillionDollars.OBLIGATION between Pair(BOB, ALICE) - val tx = TransactionType.General.Builder(DUMMY_NOTARY).apply { + val tx = TransactionBuilder(DUMMY_NOTARY).apply { Obligation().generateCloseOutNetting(this, ALICE, obligationAliceToBob, obligationBobToAlice) }.toWireTransaction() assertEquals(1, tx.outputs.size) - val actual = tx.outputs[0].data + val actual = tx.getOutput(0) assertEquals((1000000.DOLLARS `issued by` defaultIssuer).OBLIGATION between Pair(ALICE, BOB), actual) } /** Test generating a transaction to net two obligations of the same size, and therefore there are no outputs. */ @Test fun `generate payment net transaction`() { + initialiseTestSerialization() val obligationAliceToBob = oneMillionDollars.OBLIGATION between Pair(ALICE, BOB) val obligationBobToAlice = oneMillionDollars.OBLIGATION between Pair(BOB, ALICE) - val tx = TransactionType.General.Builder(DUMMY_NOTARY).apply { + val tx = TransactionBuilder(DUMMY_NOTARY).apply { Obligation().generatePaymentNetting(this, obligationAliceToBob.amount.token, DUMMY_NOTARY, obligationAliceToBob, obligationBobToAlice) }.toWireTransaction() assertEquals(0, tx.outputs.size) @@ -257,69 +262,73 @@ class ObligationTests { /** Test generating a transaction to two obligations, where one is bigger than the other and therefore there is a remainder. */ @Test fun `generate payment net transaction with remainder`() { + initialiseTestSerialization() val obligationAliceToBob = oneMillionDollars.OBLIGATION between Pair(ALICE, BOB) val obligationBobToAlice = (2000000.DOLLARS `issued by` defaultIssuer).OBLIGATION between Pair(BOB, ALICE) - val tx = TransactionType.General.Builder(null).apply { + val tx = TransactionBuilder(null).apply { Obligation().generatePaymentNetting(this, obligationAliceToBob.amount.token, DUMMY_NOTARY, obligationAliceToBob, obligationBobToAlice) }.toWireTransaction() assertEquals(1, tx.outputs.size) val expected = obligationBobToAlice.copy(quantity = obligationBobToAlice.quantity - obligationAliceToBob.quantity) - val actual = tx.outputs[0].data + val actual = tx.getOutput(0) assertEquals(expected, actual) } /** Test generating a transaction to mark outputs as having defaulted. */ @Test fun `generate set lifecycle`() { + initialiseTestSerialization() // We don't actually verify the states, this is just here to make things look sensible - val dueBefore = TEST_TX_TIME - Duration.ofDays(7) + val dueBefore = TEST_TX_TIME - 7.days - // Generate a transaction issuing the obligation - var tx = TransactionType.General.Builder(null).apply { - Obligation().generateIssue(this, MINI_CORP, megaCorpDollarSettlement.copy(dueBefore = dueBefore), 100.DOLLARS.quantity, + // Generate a transaction issuing the obligation. + var tx = TransactionBuilder(null).apply { + val amount = Amount(100, Issued(defaultIssuer, USD)) + Obligation().generateCashIssue(this, ALICE, amount, dueBefore, beneficiary = MINI_CORP, notary = DUMMY_NOTARY) } var stx = miniCorpServices.signInitialTransaction(tx) var stateAndRef = stx.tx.outRef>(0) // Now generate a transaction marking the obligation as having defaulted - tx = TransactionType.General.Builder(DUMMY_NOTARY).apply { + tx = TransactionBuilder(DUMMY_NOTARY).apply { Obligation().generateSetLifecycle(this, listOf(stateAndRef), Lifecycle.DEFAULTED, DUMMY_NOTARY) } var ptx = miniCorpServices.signInitialTransaction(tx, MINI_CORP_PUBKEY) stx = notaryServices.addSignature(ptx) assertEquals(1, stx.tx.outputs.size) - assertEquals(stateAndRef.state.data.copy(lifecycle = Lifecycle.DEFAULTED), stx.tx.outputs[0].data) - stx.verifySignatures() + assertEquals(stateAndRef.state.data.copy(lifecycle = Lifecycle.DEFAULTED), stx.tx.getOutput(0)) + stx.verifyRequiredSignatures() // And set it back stateAndRef = stx.tx.outRef>(0) - tx = TransactionType.General.Builder(DUMMY_NOTARY).apply { + tx = TransactionBuilder(DUMMY_NOTARY).apply { Obligation().generateSetLifecycle(this, listOf(stateAndRef), Lifecycle.NORMAL, DUMMY_NOTARY) } ptx = miniCorpServices.signInitialTransaction(tx) stx = notaryServices.addSignature(ptx) assertEquals(1, stx.tx.outputs.size) - assertEquals(stateAndRef.state.data.copy(lifecycle = Lifecycle.NORMAL), stx.tx.outputs[0].data) - stx.verifySignatures() + assertEquals(stateAndRef.state.data.copy(lifecycle = Lifecycle.NORMAL), stx.tx.getOutput(0)) + stx.verifyRequiredSignatures() } /** Test generating a transaction to settle an obligation. */ @Test fun `generate settlement transaction`() { - val cashTx = TransactionType.General.Builder(null).apply { + initialiseTestSerialization() + val cashTx = TransactionBuilder(null).apply { Cash().generateIssue(this, 100.DOLLARS `issued by` defaultIssuer, MINI_CORP, DUMMY_NOTARY) }.toWireTransaction() // Generate a transaction issuing the obligation - val obligationTx = TransactionType.General.Builder(null).apply { + val obligationTx = TransactionBuilder(null).apply { Obligation().generateIssue(this, MINI_CORP, megaCorpDollarSettlement, 100.DOLLARS.quantity, beneficiary = MINI_CORP, notary = DUMMY_NOTARY) }.toWireTransaction() // Now generate a transaction settling the obligation - val settleTx = TransactionType.General.Builder(DUMMY_NOTARY).apply { + val settleTx = TransactionBuilder(DUMMY_NOTARY).apply { Obligation().generateSettle(this, listOf(obligationTx.outRef(0)), listOf(cashTx.outRef(0)), Cash.Commands.Move(), DUMMY_NOTARY) }.toWireTransaction() assertEquals(2, settleTx.inputs.size) @@ -500,7 +509,7 @@ class ObligationTests { fun `commodity settlement`() { val defaultFcoj = Issued(defaultIssuer, Commodity.getInstance("FCOJ")!!) val oneUnitFcoj = Amount(1, defaultFcoj) - val obligationDef = Obligation.Terms(nonEmptySetOf(CommodityContract().legalContractReference), nonEmptySetOf(defaultFcoj), TEST_TX_TIME) + val obligationDef = Obligation.Terms(NonEmptySet.of(CommodityContract().legalContractReference), NonEmptySet.of(defaultFcoj), TEST_TX_TIME) val oneUnitFcojObligation = Obligation.State(Obligation.Lifecycle.NORMAL, ALICE, obligationDef, oneUnitFcoj.quantity, NULL_PARTY) // Try settling a simple commodity obligation @@ -534,8 +543,8 @@ class ObligationTests { } // Try defaulting an obligation due in the future - val pastTestTime = TEST_TX_TIME - Duration.ofDays(7) - val futureTestTime = TEST_TX_TIME + Duration.ofDays(7) + val pastTestTime = TEST_TX_TIME - 7.days + val futureTestTime = TEST_TX_TIME + 7.days transaction("Settlement") { input(oneMillionDollars.OBLIGATION between Pair(ALICE, BOB) `at` futureTestTime) output("Alice's defaulted $1,000,000 obligation to Bob") { (oneMillionDollars.OBLIGATION between Pair(ALICE, BOB) `at` futureTestTime).copy(lifecycle = Lifecycle.DEFAULTED) } @@ -620,7 +629,7 @@ class ObligationTests { inState.copy( quantity = 15000, template = megaCorpPoundSettlement, - beneficiary = AnonymousParty(DUMMY_PUBKEY_2) + beneficiary = AnonymousParty(BOB_PUBKEY) ) } output { outState.copy(quantity = 115000) } @@ -651,7 +660,7 @@ class ObligationTests { tweak { command(CHARLIE.owningKey) { Obligation.Commands.Exit(Amount(200.DOLLARS.quantity, inState.amount.token)) } - this `fails with` "required net.corda.core.contracts.FungibleAsset.Commands.Move command" + this `fails with` "required net.corda.contracts.asset.Obligation.Commands.Move command" tweak { command(CHARLIE.owningKey) { Obligation.Commands.Move() } @@ -693,19 +702,19 @@ class ObligationTests { // Can't merge them together. tweak { - output { inState.copy(beneficiary = AnonymousParty(DUMMY_PUBKEY_2), quantity = 200000L) } + output { inState.copy(beneficiary = AnonymousParty(BOB_PUBKEY), quantity = 200000L) } this `fails with` "the amounts balance" } // Missing MiniCorp deposit tweak { - output { inState.copy(beneficiary = AnonymousParty(DUMMY_PUBKEY_2)) } - output { inState.copy(beneficiary = AnonymousParty(DUMMY_PUBKEY_2)) } + output { inState.copy(beneficiary = AnonymousParty(BOB_PUBKEY)) } + output { inState.copy(beneficiary = AnonymousParty(BOB_PUBKEY)) } this `fails with` "the amounts balance" } // This works. - output { inState.copy(beneficiary = AnonymousParty(DUMMY_PUBKEY_2)) } - output { inState.copy(beneficiary = AnonymousParty(DUMMY_PUBKEY_2)) `issued by` MINI_CORP } + output { inState.copy(beneficiary = AnonymousParty(BOB_PUBKEY)) } + output { inState.copy(beneficiary = AnonymousParty(BOB_PUBKEY)) `issued by` MINI_CORP } command(CHARLIE.owningKey) { Obligation.Commands.Move() } this.verifies() } @@ -715,12 +724,12 @@ class ObligationTests { fun multiCurrency() { // Check we can do an atomic currency trade tx. transaction { - val pounds = Obligation.State(Lifecycle.NORMAL, MINI_CORP, megaCorpPoundSettlement, 658.POUNDS.quantity, AnonymousParty(DUMMY_PUBKEY_2)) + val pounds = Obligation.State(Lifecycle.NORMAL, MINI_CORP, megaCorpPoundSettlement, 658.POUNDS.quantity, AnonymousParty(BOB_PUBKEY)) input { inState `owned by` CHARLIE } input { pounds } - output { inState `owned by` AnonymousParty(DUMMY_PUBKEY_2) } + output { inState `owned by` AnonymousParty(BOB_PUBKEY) } output { pounds `owned by` CHARLIE } - command(CHARLIE.owningKey, DUMMY_PUBKEY_2) { Obligation.Commands.Move() } + command(CHARLIE.owningKey, BOB_PUBKEY) { Obligation.Commands.Move() } this.verifies() } @@ -755,10 +764,10 @@ class ObligationTests { // States must not be nettable if the cash contract differs assertNotEquals(fiveKDollarsFromMegaToMega.bilateralNetState, - fiveKDollarsFromMegaToMega.copy(template = megaCorpDollarSettlement.copy(acceptableContracts = nonEmptySetOf(SecureHash.randomSHA256()))).bilateralNetState) + fiveKDollarsFromMegaToMega.copy(template = megaCorpDollarSettlement.copy(acceptableContracts = NonEmptySet.of(SecureHash.randomSHA256()))).bilateralNetState) // States must not be nettable if the trusted issuers differ - val miniCorpIssuer = nonEmptySetOf(Issued(MINI_CORP.ref(1), USD)) + val miniCorpIssuer = NonEmptySet.of(Issued(MINI_CORP.ref(1), USD)) assertNotEquals(fiveKDollarsFromMegaToMega.bilateralNetState, fiveKDollarsFromMegaToMega.copy(template = megaCorpDollarSettlement.copy(acceptableIssuedProducts = miniCorpIssuer)).bilateralNetState) } @@ -856,6 +865,7 @@ class ObligationTests { @Test fun `summing balances due between parties`() { + initialiseTestSerialization() val simple: Map, Amount> = mapOf(Pair(Pair(ALICE, BOB), Amount(100000000, GBP))) val expected: Map = mapOf(Pair(ALICE, -100000000L), Pair(BOB, 100000000L)) val actual = sumAmountsDue(simple) @@ -875,7 +885,7 @@ class ObligationTests { } val Issued.OBLIGATION_DEF: Obligation.Terms - get() = Obligation.Terms(nonEmptySetOf(Cash().legalContractReference), nonEmptySetOf(this), TEST_TX_TIME) + get() = Obligation.Terms(NonEmptySet.of(Cash().legalContractReference), NonEmptySet.of(this), TEST_TX_TIME) val Amount>.OBLIGATION: Obligation.State get() = Obligation.State(Obligation.Lifecycle.NORMAL, DUMMY_OBLIGATION_ISSUER, token.OBLIGATION_DEF, quantity, NULL_PARTY) } diff --git a/finance/src/test/kotlin/net/corda/contracts/testing/Generators.kt b/finance/src/test/kotlin/net/corda/contracts/testing/Generators.kt deleted file mode 100644 index 1ca4c5d998..0000000000 --- a/finance/src/test/kotlin/net/corda/contracts/testing/Generators.kt +++ /dev/null @@ -1,91 +0,0 @@ -package net.corda.contracts.testing - -import com.pholser.junit.quickcheck.generator.GenerationStatus -import com.pholser.junit.quickcheck.generator.Generator -import com.pholser.junit.quickcheck.generator.java.util.ArrayListGenerator -import com.pholser.junit.quickcheck.random.SourceOfRandomness -import net.corda.contracts.asset.Cash -import net.corda.core.contracts.Command -import net.corda.core.contracts.CommandData -import net.corda.core.contracts.ContractState -import net.corda.core.contracts.TransactionType -import net.corda.core.crypto.testing.NullSignature -import net.corda.core.identity.AnonymousParty -import net.corda.core.testing.* -import net.corda.core.transactions.SignedTransaction -import net.corda.core.transactions.WireTransaction - -/** - * This file contains generators for quickcheck style testing. The idea is that we can write random instance generators - * for each type we have in the code and test against those instead of predefined mock data. This style of testing can - * catch corner case bugs and test algebraic properties of the code, for example deserialize(serialize(generatedThing)) == generatedThing - * - * TODO add combinators for easier Generator writing - */ -class ContractStateGenerator : Generator(ContractState::class.java) { - override fun generate(random: SourceOfRandomness, status: GenerationStatus): ContractState { - return Cash.State( - amount = AmountGenerator(IssuedGenerator(CurrencyGenerator())).generate(random, status), - owner = AnonymousParty(PublicKeyGenerator().generate(random, status)) - ) - } -} - -class MoveGenerator : Generator(Cash.Commands.Move::class.java) { - override fun generate(random: SourceOfRandomness, status: GenerationStatus): Cash.Commands.Move { - return Cash.Commands.Move(SecureHashGenerator().generate(random, status)) - } -} - -class IssueGenerator : Generator(Cash.Commands.Issue::class.java) { - override fun generate(random: SourceOfRandomness, status: GenerationStatus): Cash.Commands.Issue { - return Cash.Commands.Issue(random.nextLong()) - } -} - -class ExitGenerator : Generator(Cash.Commands.Exit::class.java) { - override fun generate(random: SourceOfRandomness, status: GenerationStatus): Cash.Commands.Exit { - return Cash.Commands.Exit(AmountGenerator(IssuedGenerator(CurrencyGenerator())).generate(random, status)) - } -} - -class CommandDataGenerator : Generator(CommandData::class.java) { - override fun generate(random: SourceOfRandomness, status: GenerationStatus): CommandData { - val generators = listOf(MoveGenerator(), IssueGenerator(), ExitGenerator()) - return generators[random.nextInt(0, generators.size - 1)].generate(random, status) - } -} - -class CommandGenerator : Generator(Command::class.java) { - override fun generate(random: SourceOfRandomness, status: GenerationStatus): Command { - val signersGenerator = ArrayListGenerator() - signersGenerator.addComponentGenerators(listOf(PublicKeyGenerator())) - return Command(CommandDataGenerator().generate(random, status), PublicKeyGenerator().generate(random, status)) - } -} - -class WiredTransactionGenerator : Generator(WireTransaction::class.java) { - override fun generate(random: SourceOfRandomness, status: GenerationStatus): WireTransaction { - val commands = CommandGenerator().generateList(random, status) + listOf(CommandGenerator().generate(random, status)) - return WireTransaction( - inputs = StateRefGenerator().generateList(random, status), - attachments = SecureHashGenerator().generateList(random, status), - outputs = TransactionStateGenerator(ContractStateGenerator()).generateList(random, status), - commands = commands, - notary = PartyGenerator().generate(random, status), - signers = commands.flatMap { it.signers }, - type = TransactionType.General, - timeWindow = TimeWindowGenerator().generate(random, status) - ) - } -} - -class SignedTransactionGenerator : Generator(SignedTransaction::class.java) { - override fun generate(random: SourceOfRandomness, status: GenerationStatus): SignedTransaction { - val wireTransaction = WiredTransactionGenerator().generate(random, status) - return SignedTransaction( - txBits = wireTransaction.serialized, - sigs = listOf(NullSignature) - ) - } -} diff --git a/finance/src/test/kotlin/net/corda/flows/BroadcastTransactionFlowTest.kt b/finance/src/test/kotlin/net/corda/flows/BroadcastTransactionFlowTest.kt deleted file mode 100644 index 49dc19751a..0000000000 --- a/finance/src/test/kotlin/net/corda/flows/BroadcastTransactionFlowTest.kt +++ /dev/null @@ -1,31 +0,0 @@ -package net.corda.flows - -import com.pholser.junit.quickcheck.From -import com.pholser.junit.quickcheck.Property -import com.pholser.junit.quickcheck.generator.GenerationStatus -import com.pholser.junit.quickcheck.generator.Generator -import com.pholser.junit.quickcheck.random.SourceOfRandomness -import com.pholser.junit.quickcheck.runner.JUnitQuickcheck -import net.corda.contracts.testing.SignedTransactionGenerator -import net.corda.core.serialization.deserialize -import net.corda.core.serialization.serialize -import net.corda.flows.BroadcastTransactionFlow.NotifyTxRequest -import org.junit.runner.RunWith -import kotlin.test.assertEquals - -@RunWith(JUnitQuickcheck::class) -class BroadcastTransactionFlowTest { - - class NotifyTxRequestMessageGenerator : Generator(NotifyTxRequest::class.java) { - override fun generate(random: SourceOfRandomness, status: GenerationStatus): NotifyTxRequest { - return NotifyTxRequest(tx = SignedTransactionGenerator().generate(random, status)) - } - } - - @Property - fun serialiseDeserialiseOfNotifyMessageWorks(@From(NotifyTxRequestMessageGenerator::class) message: NotifyTxRequest) { - val serialized = message.serialize().bytes - val deserialized = serialized.deserialize() - assertEquals(deserialized, message) - } -} diff --git a/finance/src/test/kotlin/net/corda/flows/CashExitFlowTests.kt b/finance/src/test/kotlin/net/corda/flows/CashExitFlowTests.kt index b5ff01bff8..ed967899a2 100644 --- a/finance/src/test/kotlin/net/corda/flows/CashExitFlowTests.kt +++ b/finance/src/test/kotlin/net/corda/flows/CashExitFlowTests.kt @@ -3,9 +3,9 @@ package net.corda.flows import net.corda.contracts.asset.Cash import net.corda.core.contracts.DOLLARS import net.corda.core.contracts.`issued by` -import net.corda.core.getOrThrow import net.corda.core.identity.Party import net.corda.core.utilities.OpaqueBytes +import net.corda.core.utilities.getOrThrow import net.corda.testing.node.InMemoryMessagingNetwork.ServicePeerAllocationStrategy.RoundRobin import net.corda.testing.node.MockNetwork import net.corda.testing.node.MockNetwork.MockNode @@ -26,9 +26,9 @@ class CashExitFlowTests { @Before fun start() { - val nodes = mockNet.createTwoNodes() - notaryNode = nodes.first - bankOfCordaNode = nodes.second + val nodes = mockNet.createSomeNodes(1) + notaryNode = nodes.notaryNode + bankOfCordaNode = nodes.partyNodes[0] notary = notaryNode.info.notaryIdentity bankOfCorda = bankOfCordaNode.info.legalIdentity @@ -55,7 +55,7 @@ class CashExitFlowTests { val expected = (initialBalance - exitAmount).`issued by`(bankOfCorda.ref(ref)) assertEquals(1, exitTx.inputs.size) assertEquals(1, exitTx.outputs.size) - val output = exitTx.outputs.map { it.data }.filterIsInstance().single() + val output = exitTx.outputsOfType().single() assertEquals(expected, output.amount) } diff --git a/finance/src/test/kotlin/net/corda/flows/CashIssueFlowTests.kt b/finance/src/test/kotlin/net/corda/flows/CashIssueFlowTests.kt index db183cabff..962632c4fc 100644 --- a/finance/src/test/kotlin/net/corda/flows/CashIssueFlowTests.kt +++ b/finance/src/test/kotlin/net/corda/flows/CashIssueFlowTests.kt @@ -3,9 +3,9 @@ package net.corda.flows import net.corda.contracts.asset.Cash import net.corda.core.contracts.DOLLARS import net.corda.core.contracts.`issued by` -import net.corda.core.getOrThrow import net.corda.core.identity.Party import net.corda.core.utilities.OpaqueBytes +import net.corda.core.utilities.getOrThrow import net.corda.testing.node.InMemoryMessagingNetwork.ServicePeerAllocationStrategy.RoundRobin import net.corda.testing.node.MockNetwork import net.corda.testing.node.MockNetwork.MockNode @@ -24,9 +24,9 @@ class CashIssueFlowTests { @Before fun start() { - val nodes = mockNet.createTwoNodes() - notaryNode = nodes.first - bankOfCordaNode = nodes.second + val nodes = mockNet.createSomeNodes(1) + notaryNode = nodes.notaryNode + bankOfCordaNode = nodes.partyNodes[0] notary = notaryNode.info.notaryIdentity bankOfCorda = bankOfCordaNode.info.legalIdentity @@ -47,7 +47,7 @@ class CashIssueFlowTests { notary)).resultFuture mockNet.runNetwork() val issueTx = future.getOrThrow().stx - val output = issueTx.tx.outputs.single().data as Cash.State + val output = issueTx.tx.outputsOfType().single() assertEquals(expected.`issued by`(bankOfCorda.ref(ref)), output.amount) } diff --git a/finance/src/test/kotlin/net/corda/flows/CashPaymentFlowTests.kt b/finance/src/test/kotlin/net/corda/flows/CashPaymentFlowTests.kt index 3bdf21b4da..e056758ffb 100644 --- a/finance/src/test/kotlin/net/corda/flows/CashPaymentFlowTests.kt +++ b/finance/src/test/kotlin/net/corda/flows/CashPaymentFlowTests.kt @@ -3,13 +3,12 @@ package net.corda.flows import net.corda.contracts.asset.Cash import net.corda.core.contracts.DOLLARS import net.corda.core.contracts.`issued by` -import net.corda.core.getOrThrow import net.corda.core.identity.Party import net.corda.core.node.services.Vault import net.corda.core.node.services.trackBy import net.corda.core.node.services.vault.QueryCriteria import net.corda.core.utilities.OpaqueBytes -import net.corda.node.utilities.transaction +import net.corda.core.utilities.getOrThrow import net.corda.testing.expect import net.corda.testing.expectEvents import net.corda.testing.node.InMemoryMessagingNetwork.ServicePeerAllocationStrategy.RoundRobin @@ -32,14 +31,12 @@ class CashPaymentFlowTests { @Before fun start() { - val nodes = mockNet.createTwoNodes() - notaryNode = nodes.first - bankOfCordaNode = nodes.second + val nodes = mockNet.createSomeNodes(1) + notaryNode = nodes.notaryNode + bankOfCordaNode = nodes.partyNodes[0] notary = notaryNode.info.notaryIdentity bankOfCorda = bankOfCordaNode.info.legalIdentity - notaryNode.services.identityService.registerIdentity(bankOfCordaNode.info.legalIdentityAndCert) - bankOfCordaNode.services.identityService.registerIdentity(notaryNode.info.legalIdentityAndCert) val future = bankOfCordaNode.services.startFlow(CashIssueFlow(initialBalance, ref, bankOfCorda, notary)).resultFuture @@ -75,17 +72,17 @@ class CashPaymentFlowTests { expect { update -> require(update.consumed.size == 1) { "Expected 1 consumed states, actual: $update" } require(update.produced.size == 1) { "Expected 1 produced states, actual: $update" } - val changeState = update.produced.single().state.data as Cash.State + val changeState = update.produced.single().state.data assertEquals(expectedChange.`issued by`(bankOfCorda.ref(ref)), changeState.amount) } } // Check notary node vault updates vaultUpdatesBankClient.expectEvents { - expect { update -> - require(update.consumed.isEmpty()) { update.consumed.size } - require(update.produced.size == 1) { update.produced.size } - val paymentState = update.produced.single().state.data as Cash.State + expect { (consumed, produced) -> + require(consumed.isEmpty()) { consumed.size } + require(produced.size == 1) { produced.size } + val paymentState = produced.single().state.data assertEquals(expectedPayment.`issued by`(bankOfCorda.ref(ref)), paymentState.amount) } } diff --git a/finance/src/test/kotlin/net/corda/flows/IssuerFlowTest.kt b/finance/src/test/kotlin/net/corda/flows/IssuerFlowTest.kt index bc98ce9943..2edfca7884 100644 --- a/finance/src/test/kotlin/net/corda/flows/IssuerFlowTest.kt +++ b/finance/src/test/kotlin/net/corda/flows/IssuerFlowTest.kt @@ -1,24 +1,25 @@ package net.corda.flows -import com.google.common.util.concurrent.ListenableFuture import net.corda.contracts.asset.Cash +import net.corda.core.concurrent.CordaFuture +import net.corda.testing.contracts.calculateRandomlySizedAmounts import net.corda.core.contracts.Amount import net.corda.core.contracts.DOLLARS import net.corda.core.contracts.currency import net.corda.core.flows.FlowException -import net.corda.core.getOrThrow import net.corda.core.identity.Party import net.corda.core.node.services.Vault import net.corda.core.node.services.trackBy import net.corda.core.node.services.vault.QueryCriteria -import net.corda.core.transactions.SignedTransaction import net.corda.core.utilities.OpaqueBytes +import net.corda.core.transactions.SignedTransaction +import net.corda.core.utilities.getOrThrow import net.corda.flows.IssuerFlow.IssuanceRequester -import net.corda.node.utilities.transaction -import net.corda.testing.* -import net.corda.testing.contracts.calculateRandomlySizedAmounts +import net.corda.testing.expect +import net.corda.testing.expectEvents import net.corda.testing.node.MockNetwork import net.corda.testing.node.MockNetwork.MockNode +import net.corda.testing.sequence import org.junit.After import org.junit.Before import org.junit.Test @@ -45,9 +46,10 @@ class IssuerFlowTest(val anonymous: Boolean) { @Before fun start() { mockNet = MockNetwork(threadPerNode = true) - notaryNode = mockNet.createNotaryNode(null, DUMMY_NOTARY.name) - bankOfCordaNode = mockNet.createPartyNode(notaryNode.network.myAddress, BOC.name) - bankClientNode = mockNet.createPartyNode(notaryNode.network.myAddress, MEGA_CORP.name) + val basketOfNodes = mockNet.createSomeNodes(2) + bankOfCordaNode = basketOfNodes.partyNodes[0] + bankClientNode = basketOfNodes.partyNodes[1] + notaryNode = basketOfNodes.notaryNode } @After @@ -79,7 +81,7 @@ class IssuerFlowTest(val anonymous: Boolean) { expect { update -> require(update.consumed.isEmpty()) { "Expected 0 consumed states, actual: $update" } require(update.produced.size == 1) { "Expected 1 produced states, actual: $update" } - val issued = update.produced.single().state.data as Cash.State + val issued = update.produced.single().state.data require(issued.owner.owningKey in bankOfCordaNode.services.keyManagementService.keys) }, // MOVE @@ -93,10 +95,10 @@ class IssuerFlowTest(val anonymous: Boolean) { // Check Bank Client Vault Updates vaultUpdatesBankClient.expectEvents { // MOVE - expect { update -> - require(update.consumed.isEmpty()) { update.consumed.size } - require(update.produced.size == 1) { update.produced.size } - val paidState = update.produced.single().state.data as Cash.State + expect { (consumed, produced) -> + require(consumed.isEmpty()) { consumed.size } + require(produced.size == 1) { produced.size } + val paidState = produced.single().state.data require(paidState.owner.owningKey in bankClientNode.services.keyManagementService.keys) } } @@ -157,7 +159,7 @@ class IssuerFlowTest(val anonymous: Boolean) { amount: Amount, issueToParty: Party, ref: OpaqueBytes, - notaryParty: Party): ListenableFuture { + notaryParty: Party): CordaFuture { val issueToPartyAndRef = issueToParty.ref(ref) val issueRequest = IssuanceRequester(amount, issueToParty, issueToPartyAndRef.reference, issuerNode.info.legalIdentity, notaryParty, anonymous) diff --git a/gradle-plugins/cordformation/build.gradle b/gradle-plugins/cordformation/build.gradle index 72c778d692..390ea7c8ba 100644 --- a/gradle-plugins/cordformation/build.gradle +++ b/gradle-plugins/cordformation/build.gradle @@ -61,3 +61,7 @@ jar { rename { 'net/corda/plugins/runnodes.jar' } } } + +publish { + name project.name +} diff --git a/gradle-plugins/cordformation/src/main/groovy/net/corda/plugins/Cordformation.groovy b/gradle-plugins/cordformation/src/main/groovy/net/corda/plugins/Cordformation.groovy index 6eafcdf3cd..edd74182f2 100644 --- a/gradle-plugins/cordformation/src/main/groovy/net/corda/plugins/Cordformation.groovy +++ b/gradle-plugins/cordformation/src/main/groovy/net/corda/plugins/Cordformation.groovy @@ -10,13 +10,22 @@ import org.gradle.api.artifacts.Configuration */ class Cordformation implements Plugin { void apply(Project project) { - Configuration cordappConf = project.configurations.create("cordapp") - cordappConf.transitive = false - project.configurations.compile.extendsFrom cordappConf + createCompileConfiguration("cordapp", project) + createCompileConfiguration("cordaCompile", project) + + Configuration configuration = project.configurations.create("cordaRuntime") + configuration.transitive = false + project.configurations.runtime.extendsFrom configuration configureCordappJar(project) } + private void createCompileConfiguration(String name, Project project) { + Configuration configuration = project.configurations.create(name) + configuration.transitive = false + project.configurations.compile.extendsFrom configuration + } + /** * Configures this project's JAR as a Cordapp JAR */ @@ -47,26 +56,30 @@ class Cordformation implements Plugin { }, filePathInJar).asFile() } - private static def getDirectNonCordaDependencies(Project project) { - def coreCordaNames = ['jfx', 'mock', 'rpc', 'core', 'corda', 'cordform-common', 'corda-webserver', 'finance', 'node', 'node-api', 'node-schemas', 'test-utils', 'jackson', 'verifier', 'webserver', 'capsule', 'webcapsule'] - def excludes = coreCordaNames.collect { [group: 'net.corda', name: it] } + [ + private static Set getDirectNonCordaDependencies(Project project) { + def excludes = [ [group: 'org.jetbrains.kotlin', name: 'kotlin-stdlib'], [group: 'org.jetbrains.kotlin', name: 'kotlin-stdlib-jre8'], [group: 'co.paralleluniverse', name: 'quasar-core'] ] - // The direct dependencies of this project - def cordappDeps = project.configurations.cordapp.allDependencies - def directDeps = project.configurations.runtime.allDependencies - cordappDeps - // We want to filter out anything Corda related or provided by Corda, like kotlin-stdlib and quasar - def filteredDeps = directDeps.findAll { excludes.collect { exclude -> (exclude.group == it.group) && (exclude.name == it.name) }.findAll { it }.isEmpty() } - filteredDeps.each { - // net.corda may be a core dependency which shouldn't be included in this cordapp so give a warning - if(it.group.contains('net.corda')) { - project.logger.warn("Including a dependency with a net.corda group: $it") - } else { - project.logger.trace("Including dependency: $it") + + project.with { + // The direct dependencies of this project + def excludeDeps = configurations.cordapp.allDependencies + configurations.cordaCompile.allDependencies + configurations.cordaRuntime.allDependencies + def directDeps = configurations.runtime.allDependencies - excludeDeps + // We want to filter out anything Corda related or provided by Corda, like kotlin-stdlib and quasar + def filteredDeps = directDeps.findAll { excludes.collect { exclude -> (exclude.group == it.group) && (exclude.name == it.name) }.findAll { it }.isEmpty() } + filteredDeps.each { + // net.corda may be a core dependency which shouldn't be included in this cordapp so give a warning + if(it.group.contains('net.corda.')) { + logger.warn("You appear to have included a Corda platform component ($it) using a 'compile' or 'runtime' dependency." + + "This can cause node stability problems. Please use 'corda' instead." + + "See http://docs.corda.net/cordapp-build-systems.html") + } else { + logger.trace("Including dependency: $it") + } } + return filteredDeps.collect { configurations.runtime.files it }.flatten().toSet() } - return filteredDeps.collect { project.configurations.runtime.files it }.flatten().toSet() } } diff --git a/gradle-plugins/publish-utils/README.rst b/gradle-plugins/publish-utils/README.rst index 3ca1c2febf..d1657ee5fe 100644 --- a/gradle-plugins/publish-utils/README.rst +++ b/gradle-plugins/publish-utils/README.rst @@ -76,8 +76,8 @@ The project configuration block has the following structure: .. code-block:: text publish { - name = 'non-default-project-name' disableDefaultJar = false // set to true to disable the default JAR being created (e.g. when creating a fat JAR) + name 'non-default-project-name' // Always put this last because it causes configuration to happen } **Artifacts** diff --git a/gradle-plugins/publish-utils/src/main/groovy/net/corda/plugins/ProjectPublishExtension.groovy b/gradle-plugins/publish-utils/src/main/groovy/net/corda/plugins/ProjectPublishExtension.groovy index de5c4dfb90..ee978bdbb8 100644 --- a/gradle-plugins/publish-utils/src/main/groovy/net/corda/plugins/ProjectPublishExtension.groovy +++ b/gradle-plugins/publish-utils/src/main/groovy/net/corda/plugins/ProjectPublishExtension.groovy @@ -1,14 +1,25 @@ package net.corda.plugins class ProjectPublishExtension { + private PublishTasks task + + void setPublishTask(PublishTasks task) { + this.task = task + } + /** - * Use a different name from the current project name for publishing + * Use a different name from the current project name for publishing. + * Set this after all other settings that need to be configured */ - String name + void name(String name) { + task.setPublishName(name) + } + /** * True when we do not want to publish default Java components */ Boolean disableDefaultJar = false + /** * True if publishing a WAR instead of a JAR. Forces disableDefaultJAR to "true" when true */ diff --git a/gradle-plugins/publish-utils/src/main/groovy/net/corda/plugins/PublishTasks.groovy b/gradle-plugins/publish-utils/src/main/groovy/net/corda/plugins/PublishTasks.groovy index f5221ff037..f25e07da6a 100644 --- a/gradle-plugins/publish-utils/src/main/groovy/net/corda/plugins/PublishTasks.groovy +++ b/gradle-plugins/publish-utils/src/main/groovy/net/corda/plugins/PublishTasks.groovy @@ -23,24 +23,17 @@ class PublishTasks implements Plugin { void apply(Project project) { this.project = project + this.publishName = project.name createTasks() createExtensions() createConfigurations() - - project.afterEvaluate { - configurePublishingName() - checkAndConfigurePublishing() - } } - void configurePublishingName() { - if(publishConfig.name != null) { - project.logger.info("Changing publishing name for ${project.name} to ${publishConfig.name}") - publishName = publishConfig.name - } else { - publishName = project.name - } + void setPublishName(String publishName) { + project.logger.info("Changing publishing name from ${project.name} to ${publishName}") + this.publishName = publishName + checkAndConfigurePublishing() } void checkAndConfigurePublishing() { @@ -157,6 +150,7 @@ class PublishTasks implements Plugin { project.extensions.create("bintrayConfig", BintrayConfigExtension) } publishConfig = project.extensions.create("publish", ProjectPublishExtension) + publishConfig.setPublishTask(this) } void createConfigurations() { diff --git a/gradle-plugins/quasar-utils/build.gradle b/gradle-plugins/quasar-utils/build.gradle index e4bba19bcf..7829d47d75 100644 --- a/gradle-plugins/quasar-utils/build.gradle +++ b/gradle-plugins/quasar-utils/build.gradle @@ -12,3 +12,7 @@ dependencies { compile gradleApi() compile localGroovy() } + +publish { + name project.name +} diff --git a/node-api/build.gradle b/node-api/build.gradle index 6bb83dcbf6..43d289940b 100644 --- a/node-api/build.gradle +++ b/node-api/build.gradle @@ -27,6 +27,10 @@ dependencies { // TypeSafe Config: for simple and human friendly config files. compile "com.typesafe:config:$typesafe_config_version" + // Kryo: object graph serialization. + compile "com.esotericsoftware:kryo:4.0.0" + compile "de.javakaffee:kryo-serializers:0.41" + // Unit testing helpers. testCompile "junit:junit:$junit_version" testCompile "org.assertj:assertj-core:${assertj_version}" @@ -38,5 +42,5 @@ jar { } publish { - name = jar.baseName + name jar.baseName } diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/ArtemisMessagingComponent.kt b/node-api/src/main/kotlin/net/corda/nodeapi/ArtemisMessagingComponent.kt index a37a841f83..f1ad6582a1 100644 --- a/node-api/src/main/kotlin/net/corda/nodeapi/ArtemisMessagingComponent.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/ArtemisMessagingComponent.kt @@ -6,7 +6,7 @@ import net.corda.core.messaging.MessageRecipients import net.corda.core.messaging.SingleMessageRecipient import net.corda.core.node.NodeInfo import net.corda.core.node.services.ServiceType -import net.corda.core.read +import net.corda.core.internal.read import net.corda.core.serialization.CordaSerializable import net.corda.core.serialization.SingletonSerializeAsToken import net.corda.core.utilities.NetworkHostAndPort diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/RPCApi.kt b/node-api/src/main/kotlin/net/corda/nodeapi/RPCApi.kt index 8e5aaf8cfb..e3376e482c 100644 --- a/node-api/src/main/kotlin/net/corda/nodeapi/RPCApi.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/RPCApi.kt @@ -1,7 +1,6 @@ package net.corda.nodeapi -import com.esotericsoftware.kryo.pool.KryoPool -import net.corda.core.serialization.KryoPoolWithContext +import net.corda.core.serialization.SerializationContext import net.corda.core.serialization.deserialize import net.corda.core.serialization.serialize import net.corda.core.utilities.Try @@ -96,12 +95,12 @@ object RPCApi { val methodName: String, val arguments: List ) : ClientToServer() { - fun writeToClientMessage(kryoPool: KryoPool, message: ClientMessage) { + fun writeToClientMessage(context: SerializationContext, message: ClientMessage) { MessageUtil.setJMSReplyTo(message, clientAddress) message.putIntProperty(TAG_FIELD_NAME, Tag.RPC_REQUEST.ordinal) message.putLongProperty(RPC_ID_FIELD_NAME, id.toLong) message.putStringProperty(METHOD_NAME_FIELD_NAME, methodName) - message.bodyBuffer.writeBytes(arguments.serialize(kryoPool).bytes) + message.bodyBuffer.writeBytes(arguments.serialize(context = context).bytes) } } @@ -119,14 +118,14 @@ object RPCApi { } companion object { - fun fromClientMessage(kryoPool: KryoPool, message: ClientMessage): ClientToServer { + fun fromClientMessage(context: SerializationContext, message: ClientMessage): ClientToServer { val tag = Tag.values()[message.getIntProperty(TAG_FIELD_NAME)] return when (tag) { RPCApi.ClientToServer.Tag.RPC_REQUEST -> RpcRequest( clientAddress = MessageUtil.getJMSReplyTo(message), id = RpcRequestId(message.getLongProperty(RPC_ID_FIELD_NAME)), methodName = message.getStringProperty(METHOD_NAME_FIELD_NAME), - arguments = message.getBodyAsByteArray().deserialize(kryoPool) + arguments = message.getBodyAsByteArray().deserialize(context = context) ) RPCApi.ClientToServer.Tag.OBSERVABLES_CLOSED -> { val ids = ArrayList() @@ -148,48 +147,48 @@ object RPCApi { OBSERVATION } - abstract fun writeToClientMessage(kryoPool: KryoPool, message: ClientMessage) + abstract fun writeToClientMessage(context: SerializationContext, message: ClientMessage) data class RpcReply( val id: RpcRequestId, val result: Try ) : ServerToClient() { - override fun writeToClientMessage(kryoPool: KryoPool, message: ClientMessage) { + override fun writeToClientMessage(context: SerializationContext, message: ClientMessage) { message.putIntProperty(TAG_FIELD_NAME, Tag.RPC_REPLY.ordinal) message.putLongProperty(RPC_ID_FIELD_NAME, id.toLong) - message.bodyBuffer.writeBytes(result.serialize(kryoPool).bytes) + message.bodyBuffer.writeBytes(result.serialize(context = context).bytes) } } data class Observation( val id: ObservableId, - val content: Notification + val content: Notification<*> ) : ServerToClient() { - override fun writeToClientMessage(kryoPool: KryoPool, message: ClientMessage) { + override fun writeToClientMessage(context: SerializationContext, message: ClientMessage) { message.putIntProperty(TAG_FIELD_NAME, Tag.OBSERVATION.ordinal) message.putLongProperty(OBSERVABLE_ID_FIELD_NAME, id.toLong) - message.bodyBuffer.writeBytes(content.serialize(kryoPool).bytes) + message.bodyBuffer.writeBytes(content.serialize(context = context).bytes) } } companion object { - fun fromClientMessage(kryoPool: KryoPool, message: ClientMessage): ServerToClient { + fun fromClientMessage(context: SerializationContext, message: ClientMessage): ServerToClient { val tag = Tag.values()[message.getIntProperty(TAG_FIELD_NAME)] return when (tag) { RPCApi.ServerToClient.Tag.RPC_REPLY -> { val id = RpcRequestId(message.getLongProperty(RPC_ID_FIELD_NAME)) - val poolWithIdContext = KryoPoolWithContext(kryoPool, RpcRequestOrObservableIdKey, id.toLong) + val poolWithIdContext = context.withProperty(RpcRequestOrObservableIdKey, id.toLong) RpcReply( id = id, - result = message.getBodyAsByteArray().deserialize(poolWithIdContext) + result = message.getBodyAsByteArray().deserialize(context = poolWithIdContext) ) } RPCApi.ServerToClient.Tag.OBSERVATION -> { val id = ObservableId(message.getLongProperty(OBSERVABLE_ID_FIELD_NAME)) - val poolWithIdContext = KryoPoolWithContext(kryoPool, RpcRequestOrObservableIdKey, id.toLong) + val poolWithIdContext = context.withProperty(RpcRequestOrObservableIdKey, id.toLong) Observation( id = id, - content = message.getBodyAsByteArray().deserialize(poolWithIdContext) + content = message.getBodyAsByteArray().deserialize(context = poolWithIdContext) ) } } diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/RPCStructures.kt b/node-api/src/main/kotlin/net/corda/nodeapi/RPCStructures.kt index 0b5236f68f..a4be40c829 100644 --- a/node-api/src/main/kotlin/net/corda/nodeapi/RPCStructures.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/RPCStructures.kt @@ -4,13 +4,16 @@ package net.corda.nodeapi import com.esotericsoftware.kryo.Registration import com.esotericsoftware.kryo.Serializer -import com.google.common.util.concurrent.ListenableFuture -import net.corda.core.requireExternal -import net.corda.core.serialization.* +import net.corda.core.concurrent.CordaFuture +import net.corda.core.CordaRuntimeException +import net.corda.core.serialization.ClassWhitelist +import net.corda.core.serialization.CordaSerializable +import net.corda.core.serialization.SerializationContext +import net.corda.core.serialization.SerializationFactory import net.corda.core.toFuture import net.corda.core.toObservable -import net.corda.core.CordaRuntimeException import net.corda.nodeapi.config.OldConfig +import net.corda.nodeapi.internal.serialization.* import rx.Observable import java.io.InputStream @@ -46,16 +49,15 @@ class PermissionException(msg: String) : RuntimeException(msg) // The Kryo used for the RPC wire protocol. Every type in the wire protocol is listed here explicitly. // This is annoying to write out, but will make it easier to formalise the wire protocol when the time comes, // because we can see everything we're using in one place. -class RPCKryo(observableSerializer: Serializer>) : CordaKryo(makeStandardClassResolver()) { +class RPCKryo(observableSerializer: Serializer>, val serializationFactory: SerializationFactory, val serializationContext: SerializationContext) : CordaKryo(CordaClassResolver(serializationFactory, serializationContext)) { init { DefaultKryoCustomizer.customize(this) // RPC specific classes register(InputStream::class.java, InputStreamSerializer) register(Observable::class.java, observableSerializer) - @Suppress("UNCHECKED_CAST") - register(ListenableFuture::class, - read = { kryo, input -> observableSerializer.read(kryo, input, Observable::class.java as Class>).toFuture() }, + register(CordaFuture::class, + read = { kryo, input -> observableSerializer.read(kryo, input, Observable::class.java).toFuture() }, write = { kryo, output, obj -> observableSerializer.write(kryo, output, obj.toObservable()) } ) } @@ -67,10 +69,14 @@ class RPCKryo(observableSerializer: Serializer>) : CordaKryo(mak if (InputStream::class.java != type && InputStream::class.java.isAssignableFrom(type)) { return super.getRegistration(InputStream::class.java) } - if (ListenableFuture::class.java != type && ListenableFuture::class.java.isAssignableFrom(type)) { - return super.getRegistration(ListenableFuture::class.java) + if (CordaFuture::class.java != type && CordaFuture::class.java.isAssignableFrom(type)) { + return super.getRegistration(CordaFuture::class.java) } type.requireExternal("RPC not allowed to deserialise internal classes") return super.getRegistration(type) } + + private fun Class<*>.requireExternal(msg: String) { + require(!name.startsWith("net.corda.node.") && !name.contains(".internal.")) { "$msg: $name" } + } } diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/config/ConfigUtilities.kt b/node-api/src/main/kotlin/net/corda/nodeapi/config/ConfigUtilities.kt index 81455d6e6a..34a7bbe87e 100644 --- a/node-api/src/main/kotlin/net/corda/nodeapi/config/ConfigUtilities.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/config/ConfigUtilities.kt @@ -2,7 +2,7 @@ package net.corda.nodeapi.config import com.typesafe.config.Config import com.typesafe.config.ConfigUtil -import net.corda.core.noneOrSingle +import net.corda.core.internal.noneOrSingle import net.corda.core.utilities.validateX500Name import net.corda.core.utilities.NetworkHostAndPort import net.corda.core.utilities.parseNetworkHostAndPort @@ -73,7 +73,7 @@ private fun Config.getSingleValue(path: String, type: KType): Any? { Path::class -> Paths.get(getString(path)) URL::class -> URL(getString(path)) Properties::class -> getConfig(path).toProperties() - X500Name::class -> X500Name(getString(path)).apply(::validateX500Name) + X500Name::class -> X500Name(getString(path)) else -> if (typeClass.java.isEnum) { parseEnum(typeClass.java, getString(path)) } else { diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/config/SSLConfiguration.kt b/node-api/src/main/kotlin/net/corda/nodeapi/config/SSLConfiguration.kt index 13a70eb517..6fb82e508f 100644 --- a/node-api/src/main/kotlin/net/corda/nodeapi/config/SSLConfiguration.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/config/SSLConfiguration.kt @@ -1,6 +1,6 @@ package net.corda.nodeapi.config -import net.corda.core.div +import net.corda.core.internal.div import java.nio.file.Path interface SSLConfiguration { diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/AMQPSerializationScheme.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/AMQPSerializationScheme.kt new file mode 100644 index 0000000000..3d1ca75f34 --- /dev/null +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/AMQPSerializationScheme.kt @@ -0,0 +1,107 @@ +package net.corda.nodeapi.internal.serialization + +import net.corda.core.serialization.ClassWhitelist +import net.corda.core.serialization.SerializationContext +import net.corda.core.serialization.SerializationDefaults +import net.corda.core.serialization.SerializedBytes +import net.corda.core.utilities.ByteSequence +import net.corda.nodeapi.internal.serialization.amqp.AmqpHeaderV1_0 +import net.corda.nodeapi.internal.serialization.amqp.DeserializationInput +import net.corda.nodeapi.internal.serialization.amqp.SerializationOutput +import net.corda.nodeapi.internal.serialization.amqp.SerializerFactory +import java.util.concurrent.ConcurrentHashMap + +internal val AMQP_ENABLED get() = SerializationDefaults.P2P_CONTEXT.preferedSerializationVersion == AmqpHeaderV1_0 + +abstract class AbstractAMQPSerializationScheme : SerializationScheme { + internal companion object { + fun registerCustomSerializers(factory: SerializerFactory) { + factory.apply { + register(net.corda.nodeapi.internal.serialization.amqp.custom.PublicKeySerializer) + register(net.corda.nodeapi.internal.serialization.amqp.custom.ThrowableSerializer(this)) + register(net.corda.nodeapi.internal.serialization.amqp.custom.X500NameSerializer) + register(net.corda.nodeapi.internal.serialization.amqp.custom.BigDecimalSerializer) + register(net.corda.nodeapi.internal.serialization.amqp.custom.CurrencySerializer) + register(net.corda.nodeapi.internal.serialization.amqp.custom.InstantSerializer(this)) + } + } + } + + private val serializerFactoriesForContexts = ConcurrentHashMap, SerializerFactory>() + + protected abstract fun rpcClientSerializerFactory(context: SerializationContext): SerializerFactory + protected abstract fun rpcServerSerializerFactory(context: SerializationContext): SerializerFactory + + private fun getSerializerFactory(context: SerializationContext): SerializerFactory { + return serializerFactoriesForContexts.computeIfAbsent(Pair(context.whitelist, context.deserializationClassLoader)) { + when (context.useCase) { + SerializationContext.UseCase.Checkpoint -> + throw IllegalStateException("AMQP should not be used for checkpoint serialization.") + SerializationContext.UseCase.RPCClient -> + rpcClientSerializerFactory(context) + SerializationContext.UseCase.RPCServer -> + rpcServerSerializerFactory(context) + else -> SerializerFactory(context.whitelist) // TODO pass class loader also + } + }.also { registerCustomSerializers(it) } + } + + override fun deserialize(byteSequence: ByteSequence, clazz: Class, context: SerializationContext): T { + val serializerFactory = getSerializerFactory(context) + return DeserializationInput(serializerFactory).deserialize(byteSequence, clazz) + } + + override fun serialize(obj: T, context: SerializationContext): SerializedBytes { + val serializerFactory = getSerializerFactory(context) + return SerializationOutput(serializerFactory).serialize(obj) + } + + protected fun canDeserializeVersion(byteSequence: ByteSequence): Boolean = AMQP_ENABLED && byteSequence == AmqpHeaderV1_0 +} + +// TODO: This will eventually cover server RPC as well and move to node module, but for now this is not implemented +class AMQPServerSerializationScheme : AbstractAMQPSerializationScheme() { + override fun rpcClientSerializerFactory(context: SerializationContext): SerializerFactory { + throw UnsupportedOperationException() + } + + override fun rpcServerSerializerFactory(context: SerializationContext): SerializerFactory { + TODO("not implemented") //To change body of created functions use File | Settings | File Templates. + } + + override fun canDeserializeVersion(byteSequence: ByteSequence, target: SerializationContext.UseCase): Boolean { + return (canDeserializeVersion(byteSequence) && + (target == SerializationContext.UseCase.P2P || target == SerializationContext.UseCase.Storage)) + } + +} + +// TODO: This will eventually cover client RPC as well and move to client module, but for now this is not implemented +class AMQPClientSerializationScheme : AbstractAMQPSerializationScheme() { + override fun rpcClientSerializerFactory(context: SerializationContext): SerializerFactory { + TODO("not implemented") //To change body of created functions use File | Settings | File Templates. + } + + override fun rpcServerSerializerFactory(context: SerializationContext): SerializerFactory { + throw UnsupportedOperationException() + } + + override fun canDeserializeVersion(byteSequence: ByteSequence, target: SerializationContext.UseCase): Boolean { + return (canDeserializeVersion(byteSequence) && + (target == SerializationContext.UseCase.P2P || target == SerializationContext.UseCase.Storage)) + } + +} + +val AMQP_P2P_CONTEXT = SerializationContextImpl(AmqpHeaderV1_0, + SerializationDefaults.javaClass.classLoader, + GlobalTransientClassWhiteList(BuiltInExceptionsWhitelist()), + emptyMap(), + true, + SerializationContext.UseCase.P2P) +val AMQP_STORAGE_CONTEXT = SerializationContextImpl(AmqpHeaderV1_0, + SerializationDefaults.javaClass.classLoader, + AllButBlacklisted, + emptyMap(), + true, + SerializationContext.UseCase.Storage) \ No newline at end of file diff --git a/core/src/main/kotlin/net/corda/core/serialization/AllButBlacklisted.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/AllButBlacklisted.kt similarity index 91% rename from core/src/main/kotlin/net/corda/core/serialization/AllButBlacklisted.kt rename to node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/AllButBlacklisted.kt index 6f3bb8d3dd..81d72f0989 100644 --- a/core/src/main/kotlin/net/corda/core/serialization/AllButBlacklisted.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/AllButBlacklisted.kt @@ -1,12 +1,22 @@ -package net.corda.core.serialization +package net.corda.nodeapi.internal.serialization +import net.corda.core.serialization.ClassWhitelist import sun.misc.Unsafe import sun.security.util.Password import java.io.* import java.lang.invoke.* -import java.lang.reflect.* -import java.net.* -import java.security.* +import java.lang.reflect.AccessibleObject +import java.lang.reflect.Modifier +import java.lang.reflect.Parameter +import java.lang.reflect.ReflectPermission +import java.net.DatagramSocket +import java.net.ServerSocket +import java.net.Socket +import java.net.URLConnection +import java.security.AccessController +import java.security.KeyStore +import java.security.Permission +import java.security.Provider import java.sql.Connection import java.util.* import java.util.logging.Handler diff --git a/core/src/main/kotlin/net/corda/core/serialization/CordaClassResolver.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/CordaClassResolver.kt similarity index 72% rename from core/src/main/kotlin/net/corda/core/serialization/CordaClassResolver.kt rename to node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/CordaClassResolver.kt index 11f450cac9..9c180148fd 100644 --- a/core/src/main/kotlin/net/corda/core/serialization/CordaClassResolver.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/CordaClassResolver.kt @@ -1,14 +1,16 @@ -package net.corda.core.serialization +package net.corda.nodeapi.internal.serialization import com.esotericsoftware.kryo.* import com.esotericsoftware.kryo.io.Input import com.esotericsoftware.kryo.io.Output +import com.esotericsoftware.kryo.serializers.FieldSerializer import com.esotericsoftware.kryo.util.DefaultClassResolver import com.esotericsoftware.kryo.util.Util -import net.corda.core.node.AttachmentsClassLoader +import net.corda.core.serialization.* import net.corda.core.utilities.loggerFor +import net.corda.nodeapi.internal.serialization.amqp.AmqpHeaderV1_0 import java.io.PrintWriter -import java.lang.reflect.Modifier +import java.lang.reflect.Modifier.isAbstract import java.nio.charset.StandardCharsets import java.nio.file.Files import java.nio.file.Paths @@ -19,23 +21,13 @@ fun Kryo.addToWhitelist(type: Class<*>) { ((classResolver as? CordaClassResolver)?.whitelist as? MutableClassWhitelist)?.add(type) } -fun makeStandardClassResolver(): ClassResolver { - return CordaClassResolver(GlobalTransientClassWhiteList(BuiltInExceptionsWhitelist())) -} - -fun makeNoWhitelistClassResolver(): ClassResolver { - return CordaClassResolver(AllWhitelist) -} - -fun makeAllButBlacklistedClassResolver(): ClassResolver { - return CordaClassResolver(AllButBlacklisted) -} - /** * @param amqpEnabled Setting this to true turns on experimental AMQP serialization for any class annotated with * [CordaSerializable]. */ -class CordaClassResolver(val whitelist: ClassWhitelist, val amqpEnabled: Boolean = false) : DefaultClassResolver() { +class CordaClassResolver(val serializationFactory: SerializationFactory, val serializationContext: SerializationContext) : DefaultClassResolver() { + val whitelist: ClassWhitelist = TransientClassWhiteList(serializationContext.whitelist) + /** Returns the registration for the specified class, or null if the class is not registered. */ override fun getRegistration(type: Class<*>): Registration? { return super.getRegistration(type) ?: checkClass(type) @@ -52,18 +44,16 @@ class CordaClassResolver(val whitelist: ClassWhitelist, val amqpEnabled: Boolean } private fun checkClass(type: Class<*>): Registration? { - /** If call path has disabled whitelisting (see [CordaKryo.register]), just return without checking. */ + // If call path has disabled whitelisting (see [CordaKryo.register]), just return without checking. if (!whitelistEnabled) return null // Allow primitives, abstracts and interfaces - if (type.isPrimitive || type == Any::class.java || Modifier.isAbstract(type.modifiers) || type == String::class.java) return null + if (type.isPrimitive || type == Any::class.java || isAbstract(type.modifiers) || type == String::class.java) return null // If array, recurse on element type - if (type.isArray) { - return checkClass(type.componentType) - } - if (!type.isEnum && Enum::class.java.isAssignableFrom(type)) { - // Specialised enum entry, so just resolve the parent Enum type since cannot annotate the specialised entry. - return checkClass(type.superclass) - } + if (type.isArray) return checkClass(type.componentType) + // Specialised enum entry, so just resolve the parent Enum type since cannot annotate the specialised entry. + if (!type.isEnum && Enum::class.java.isAssignableFrom(type)) return checkClass(type.superclass) + // Kotlin lambdas require some special treatment + if (kotlin.jvm.internal.Lambda::class.java.isAssignableFrom(type)) return null // It's safe to have the Class already, since Kryo loads it with initialisation off. // If we use a whitelist with blacklisting capabilities, whitelist.hasListed(type) may throw an IllegalStateException if input class is blacklisted. // Thus, blacklisting precedes annotation checking. @@ -74,34 +64,40 @@ class CordaClassResolver(val whitelist: ClassWhitelist, val amqpEnabled: Boolean } override fun registerImplicit(type: Class<*>): Registration { - val hasAnnotation = checkForAnnotation(type) // If something is not annotated, or AMQP is disabled, we stay serializing with Kryo. This will typically be the // case for flow checkpoints (ignoring all cases where AMQP is disabled) since our top level messaging data structures // are annotated and once we enter AMQP serialisation we stay with it for the entire object subgraph. - if (!hasAnnotation || !amqpEnabled) { - val objectInstance = try { - type.kotlin.objectInstance - } catch (t: Throwable) { - // objectInstance will throw if the type is something like a lambda - null - } - // We have to set reference to true, since the flag influences how String fields are treated and we want it to be consistent. - val references = kryo.references - try { - kryo.references = true - val serializer = if (objectInstance != null) KotlinObjectSerializer(objectInstance) else kryo.getDefaultSerializer(type) - return register(Registration(type, serializer, NAME.toInt())) - } finally { - kryo.references = references - } - } else { + if (checkForAnnotation(type) && AMQP_ENABLED) { // Build AMQP serializer - return register(Registration(type, KryoAMQPSerializer, NAME.toInt())) + return register(Registration(type, KryoAMQPSerializer(serializationFactory, serializationContext), NAME.toInt())) + } + + val objectInstance = try { + type.kotlin.objectInstance + } catch (t: Throwable) { + null // objectInstance will throw if the type is something like a lambda + } + + // We have to set reference to true, since the flag influences how String fields are treated and we want it to be consistent. + val references = kryo.references + try { + kryo.references = true + val serializer = if (objectInstance != null) { + KotlinObjectSerializer(objectInstance) + } else if (kotlin.jvm.internal.Lambda::class.java.isAssignableFrom(type)) { + // Kotlin lambdas extend this class and any captured variables are stored in synthentic fields + FieldSerializer(kryo, type).apply { setIgnoreSyntheticFields(false) } + } else { + kryo.getDefaultSerializer(type) + } + return register(Registration(type, serializer, NAME.toInt())) + } finally { + kryo.references = references } } - // Trivial Serializer which simply returns the given instance which we already know is a Kotlin object - private class KotlinObjectSerializer(val objectInstance: Any) : Serializer() { + // Trivial Serializer which simply returns the given instance, which we already know is a Kotlin object + private class KotlinObjectSerializer(private val objectInstance: Any) : Serializer() { override fun read(kryo: Kryo, input: Input, type: Class): Any = objectInstance override fun write(kryo: Kryo, output: Output, obj: Any) = Unit } @@ -141,10 +137,6 @@ class CordaClassResolver(val whitelist: ClassWhitelist, val amqpEnabled: Boolean } } -interface ClassWhitelist { - fun hasListed(type: Class<*>): Boolean -} - interface MutableClassWhitelist : ClassWhitelist { fun add(entry: Class<*>) } @@ -176,6 +168,21 @@ class GlobalTransientClassWhiteList(val delegate: ClassWhitelist) : MutableClass } } +/** + * A whitelist that can be customised via the [CordaPluginRegistry], since implements [MutableClassWhitelist]. + */ +class TransientClassWhiteList(val delegate: ClassWhitelist) : MutableClassWhitelist, ClassWhitelist by delegate { + val whitelist: MutableSet = Collections.synchronizedSet(mutableSetOf()) + + override fun hasListed(type: Class<*>): Boolean { + return (type.name in whitelist) || delegate.hasListed(type) + } + + override fun add(entry: Class<*>) { + whitelist += entry.name + } +} + /** * This class is not currently used, but can be installed to log a large number of missing entries from the whitelist diff --git a/core/src/main/kotlin/net/corda/core/serialization/DefaultKryoCustomizer.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/DefaultKryoCustomizer.kt similarity index 76% rename from core/src/main/kotlin/net/corda/core/serialization/DefaultKryoCustomizer.kt rename to node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/DefaultKryoCustomizer.kt index 9b1a18be0c..5112ea6a33 100644 --- a/core/src/main/kotlin/net/corda/core/serialization/DefaultKryoCustomizer.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/DefaultKryoCustomizer.kt @@ -1,6 +1,9 @@ -package net.corda.core.serialization +package net.corda.nodeapi.internal.serialization import com.esotericsoftware.kryo.Kryo +import com.esotericsoftware.kryo.Serializer +import com.esotericsoftware.kryo.io.Input +import com.esotericsoftware.kryo.io.Output import com.esotericsoftware.kryo.serializers.CompatibleFieldSerializer import com.esotericsoftware.kryo.serializers.FieldSerializer import com.esotericsoftware.kryo.util.MapReferenceResolver @@ -9,12 +12,14 @@ import de.javakaffee.kryoserializers.BitSetSerializer import de.javakaffee.kryoserializers.UnmodifiableCollectionsSerializer import de.javakaffee.kryoserializers.guava.* import net.corda.core.crypto.composite.CompositeKey -import net.corda.core.crypto.MetaData import net.corda.core.node.CordaPluginRegistry +import net.corda.core.serialization.SerializeAsToken +import net.corda.core.serialization.SerializedBytes +import net.corda.core.transactions.NotaryChangeWireTransaction import net.corda.core.transactions.SignedTransaction import net.corda.core.transactions.WireTransaction import net.corda.core.utilities.NonEmptySet -import net.corda.core.utilities.NonEmptySetSerializer +import net.corda.core.utilities.toNonEmptySet import net.i2p.crypto.eddsa.EdDSAPrivateKey import net.i2p.crypto.eddsa.EdDSAPublicKey import org.bouncycastle.asn1.x500.X500Name @@ -29,6 +34,7 @@ import org.objenesis.instantiator.ObjectInstantiator import org.objenesis.strategy.InstantiatorStrategy import org.objenesis.strategy.StdInstantiatorStrategy import org.slf4j.Logger +import sun.security.ec.ECPublicKeyImpl import sun.security.provider.certpath.X509CertPath import java.io.BufferedInputStream import java.io.FileInputStream @@ -36,13 +42,11 @@ import java.io.InputStream import java.lang.reflect.Modifier.isPublic import java.security.cert.CertPath import java.util.* +import kotlin.collections.ArrayList object DefaultKryoCustomizer { private val pluginRegistries: List by lazy { - // No ClassResolver only constructor. MapReferenceResolver is the default as used by Kryo in other constructors. - val unusedKryo = Kryo(makeStandardClassResolver(), MapReferenceResolver()) - val customization = KryoSerializationCustomization(unusedKryo) - ServiceLoader.load(CordaPluginRegistry::class.java).toList().filter { it.customizeSerialization(customization) } + ServiceLoader.load(CordaPluginRegistry::class.java, this.javaClass.classLoader).toList() } fun customize(kryo: Kryo): Kryo { @@ -55,8 +59,13 @@ object DefaultKryoCustomizer { instantiatorStrategy = CustomInstantiatorStrategy() + // WARNING: reordering the registrations here will cause a change in the serialized form, since classes + // with custom serializers get written as registration ids. This will break backwards-compatibility. + // Please add any new registrations to the end. + // TODO: re-organise registrations into logical groups before v1.0 + register(Arrays.asList("").javaClass, ArraysAsListSerializer()) - register(SignedTransaction::class.java, ImmutableClassSerializer(SignedTransaction::class)) + register(SignedTransaction::class.java, SignedTransactionSerializer) register(WireTransaction::class.java, WireTransactionSerializer) register(SerializedBytes::class.java, SerializedBytesSerializer) @@ -73,7 +82,7 @@ object DefaultKryoCustomizer { noReferencesWithin() - register(sun.security.ec.ECPublicKeyImpl::class.java, ECPublicKeyImplSerializer) + register(ECPublicKeyImpl::class.java, ECPublicKeyImplSerializer) register(EdDSAPublicKey::class.java, Ed25519PublicKeySerializer) register(EdDSAPrivateKey::class.java, Ed25519PrivateKeySerializer) @@ -88,7 +97,6 @@ object DefaultKryoCustomizer { addDefaultSerializer(SerializeAsToken::class.java, SerializeAsTokenSerializer()) - register(MetaData::class.java, MetaDataSerializer) register(BitSet::class.java, BitSetSerializer()) register(Class::class.java, ClassSerializer) @@ -112,6 +120,8 @@ object DefaultKryoCustomizer { register(BCSphincs256PublicKey::class.java, PublicKeySerializer) register(sun.security.ec.ECPublicKeyImpl::class.java, PublicKeySerializer) + register(NotaryChangeWireTransaction::class.java, NotaryChangeWireTransactionSerializer) + val customization = KryoSerializationCustomization(this) pluginRegistries.forEach { it.customizeSerialization(customization) } } @@ -128,4 +138,22 @@ object DefaultKryoCustomizer { return strat.newInstantiatorOf(type) } } -} + + private object NonEmptySetSerializer : Serializer>() { + override fun write(kryo: Kryo, output: Output, obj: NonEmptySet) { + // Write out the contents as normal + output.writeInt(obj.size, true) + obj.forEach { kryo.writeClassAndObject(output, it) } + } + + override fun read(kryo: Kryo, input: Input, type: Class>): NonEmptySet { + val size = input.readInt(true) + require(size >= 1) { "Invalid size read off the wire: $size" } + val list = ArrayList(size) + repeat(size) { + list += kryo.readClassAndObject(input) + } + return list.toNonEmptySet() + } + } +} \ No newline at end of file diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/serialization/DefaultWhitelist.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/DefaultWhitelist.kt similarity index 98% rename from node-api/src/main/kotlin/net/corda/nodeapi/serialization/DefaultWhitelist.kt rename to node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/DefaultWhitelist.kt index 9e5b507cdb..ae10b24d4b 100644 --- a/node-api/src/main/kotlin/net/corda/nodeapi/serialization/DefaultWhitelist.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/DefaultWhitelist.kt @@ -1,4 +1,4 @@ -package net.corda.nodeapi.serialization +package net.corda.nodeapi.internal.serialization import com.esotericsoftware.kryo.KryoException import net.corda.core.node.CordaPluginRegistry diff --git a/core/src/main/kotlin/net/corda/core/serialization/Kryo.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/Kryo.kt similarity index 72% rename from core/src/main/kotlin/net/corda/core/serialization/Kryo.kt rename to node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/Kryo.kt index b055c600da..9066dbed7e 100644 --- a/core/src/main/kotlin/net/corda/core/serialization/Kryo.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/Kryo.kt @@ -1,21 +1,25 @@ -package net.corda.core.serialization +package net.corda.nodeapi.internal.serialization import com.esotericsoftware.kryo.* import com.esotericsoftware.kryo.io.Input import com.esotericsoftware.kryo.io.Output -import com.esotericsoftware.kryo.pool.KryoCallback -import com.esotericsoftware.kryo.pool.KryoPool import com.esotericsoftware.kryo.util.MapReferenceResolver -import com.google.common.annotations.VisibleForTesting import net.corda.core.contracts.* -import net.corda.core.crypto.* +import net.corda.core.crypto.Crypto +import net.corda.core.crypto.SecureHash +import net.corda.core.crypto.TransactionSignature import net.corda.core.crypto.composite.CompositeKey import net.corda.core.identity.Party -import net.corda.core.node.AttachmentsClassLoader +import net.corda.core.internal.VisibleForTesting +import net.corda.core.serialization.AttachmentsClassLoader +import net.corda.core.serialization.MissingAttachmentsException +import net.corda.core.serialization.SerializeAsTokenContext +import net.corda.core.serialization.SerializedBytes +import net.corda.core.transactions.CoreTransaction +import net.corda.core.transactions.NotaryChangeWireTransaction +import net.corda.core.transactions.SignedTransaction import net.corda.core.transactions.WireTransaction -import net.corda.core.utilities.LazyPool import net.corda.core.utilities.SgxSupport -import net.corda.core.utilities.OpaqueBytes import net.i2p.crypto.eddsa.EdDSAPrivateKey import net.i2p.crypto.eddsa.EdDSAPublicKey import net.i2p.crypto.eddsa.spec.EdDSANamedCurveSpec @@ -26,18 +30,15 @@ import org.bouncycastle.asn1.x500.X500Name import org.bouncycastle.cert.X509CertificateHolder import org.slf4j.Logger import org.slf4j.LoggerFactory +import sun.security.ec.ECPublicKeyImpl +import sun.security.util.DerValue import java.io.ByteArrayInputStream -import java.io.ByteArrayOutputStream import java.io.InputStream import java.lang.reflect.InvocationTargetException -import java.nio.file.Files -import java.nio.file.Path import java.security.PrivateKey import java.security.PublicKey import java.security.cert.CertPath import java.security.cert.CertificateFactory -import java.security.spec.InvalidKeySpecException -import java.time.Instant import java.util.* import javax.annotation.concurrent.ThreadSafe import kotlin.reflect.KClass @@ -45,6 +46,7 @@ import kotlin.reflect.KMutableProperty import kotlin.reflect.KParameter import kotlin.reflect.full.memberProperties import kotlin.reflect.full.primaryConstructor +import kotlin.reflect.jvm.isAccessible import kotlin.reflect.jvm.javaType /** @@ -79,58 +81,6 @@ import kotlin.reflect.jvm.javaType * TODO: eliminate internal, storage related whitelist issues, such as private keys in blob storage. */ -// A convenient instance of Kryo pre-configured with some useful things. Used as a default by various functions. -fun p2PKryo(): KryoPool = kryoPool - -// Same again, but this has whitelisting turned off for internal storage use only. -fun storageKryo(): KryoPool = internalKryoPool - - -/** - * A type safe wrapper around a byte array that contains a serialised object. You can call [SerializedBytes.deserialize] - * to get the original object back. - */ -@Suppress("unused") // Type parameter is just for documentation purposes. -class SerializedBytes(bytes: ByteArray, val internalOnly: Boolean = false) : OpaqueBytes(bytes) { - // It's OK to use lazy here because SerializedBytes is configured to use the ImmutableClassSerializer. - val hash: SecureHash by lazy { bytes.sha256() } - - fun writeToFile(path: Path): Path = Files.write(path, bytes) -} - -// "corda" + majorVersionByte + minorVersionMSB + minorVersionLSB -private val KryoHeaderV0_1: OpaqueBytes = OpaqueBytes("corda\u0000\u0000\u0001".toByteArray()) - -// Some extension functions that make deserialisation convenient and provide auto-casting of the result. -fun ByteArray.deserialize(kryo: KryoPool = p2PKryo()): T { - Input(this).use { - val header = OpaqueBytes(it.readBytes(8)) - if (header != KryoHeaderV0_1) { - throw KryoException("Serialized bytes header does not match any known format.") - } - @Suppress("UNCHECKED_CAST") - return kryo.run { k -> k.readClassAndObject(it) as T } - } -} - -// TODO: The preferred usage is with a pool. Try and eliminate use of this from RPC. -fun ByteArray.deserialize(kryo: Kryo): T = deserialize(kryo.asPool()) - -fun OpaqueBytes.deserialize(kryo: KryoPool = p2PKryo()): T { - return this.bytes.deserialize(kryo) -} - -// The more specific deserialize version results in the bytes being cached, which is faster. -@JvmName("SerializedBytesWireTransaction") -fun SerializedBytes.deserialize(kryo: KryoPool = p2PKryo()): WireTransaction = WireTransaction.deserialize(this, kryo) - -fun SerializedBytes.deserialize(kryo: KryoPool = if (internalOnly) storageKryo() else p2PKryo()): T = bytes.deserialize(kryo) - -fun SerializedBytes.deserialize(kryo: Kryo): T = bytes.deserialize(kryo.asPool()) - -// Internal adapter for use when we haven't yet converted to a pool, or for tests. -private fun Kryo.asPool(): KryoPool = (KryoPool.Builder { this }.build()) - /** * A serialiser that avoids writing the wrapper class to the byte stream, thus ensuring [SerializedBytes] is a pure * type safety hack. @@ -146,36 +96,6 @@ object SerializedBytesSerializer : Serializer>() { } } -/** - * Can be called on any object to convert it to a byte array (wrapped by [SerializedBytes]), regardless of whether - * the type is marked as serializable or was designed for it (so be careful!). - */ -fun T.serialize(kryo: KryoPool = p2PKryo(), internalOnly: Boolean = false): SerializedBytes { - return kryo.run { k -> serialize(k, internalOnly) } -} - - -private val serializeBufferPool = LazyPool( - newInstance = { ByteArray(64 * 1024) } -) -private val serializeOutputStreamPool = LazyPool( - clear = ByteArrayOutputStream::reset, - shouldReturnToPool = { it.size() < 256 * 1024 }, // Discard if it grew too large - newInstance = { ByteArrayOutputStream(64 * 1024) } -) -fun T.serialize(kryo: Kryo, internalOnly: Boolean = false): SerializedBytes { - return serializeOutputStreamPool.run { stream -> - serializeBufferPool.run { buffer -> - Output(buffer).use { - it.outputStream = stream - it.writeBytes(KryoHeaderV0_1.bytes) - kryo.writeClassAndObject(it, this) - } - SerializedBytes(stream.toByteArray(), internalOnly) - } - } -} - /** * Serializes properties and deserializes by using the constructor. This assumes that all backed properties are * set via the constructor and the class is immutable. @@ -205,6 +125,7 @@ class ImmutableClassSerializer(val klass: KClass) : Serializer() output.writeInt(hashParameters(constructor.parameters)) for (param in constructor.parameters) { val kProperty = propsByName[param.name!!]!! + kProperty.isAccessible = true when (param.type.javaType.typeName) { "int" -> output.writeVarInt(kProperty.get(obj) as Int, true) "long" -> output.writeVarLong(kProperty.get(obj) as Long, true) @@ -213,6 +134,7 @@ class ImmutableClassSerializer(val klass: KClass) : Serializer() "byte" -> output.writeByte(kProperty.get(obj) as Byte) "double" -> output.writeDouble(kProperty.get(obj) as Double) "float" -> output.writeFloat(kProperty.get(obj) as Float) + "boolean" -> output.writeBoolean(kProperty.get(obj) as Boolean) else -> try { kryo.writeClassAndObject(output, kProperty.get(obj)) } catch (e: Exception) { @@ -246,6 +168,7 @@ class ImmutableClassSerializer(val klass: KClass) : Serializer() "byte" -> input.readByte() "double" -> input.readDouble() "float" -> input.readFloat() + "boolean" -> input.readBoolean() else -> kryo.readClassAndObject(input) } } @@ -258,7 +181,7 @@ class ImmutableClassSerializer(val klass: KClass) : Serializer() } } -// TODO This is a temporary inefficient serialiser for sending InputStreams through RPC. This may be done much more +// TODO This is a temporary inefficient serializer for sending InputStreams through RPC. This may be done much more // efficiently using Artemis's large message feature. object InputStreamSerializer : Serializer() { override fun write(kryo: Kryo, output: Output, stream: InputStream) { @@ -316,10 +239,6 @@ fun Input.readBytesWithLength(): ByteArray { return this.readBytes(size) } -/** Thrown during deserialisation to indicate that an attachment needed to construct the [WireTransaction] is not found */ -@CordaSerializable -class MissingAttachmentsException(val ids: List) : Exception() - /** A serialisation engine that knows how to deserialise code inside a sandbox */ @ThreadSafe object WireTransactionSerializer : Serializer() { @@ -332,9 +251,8 @@ object WireTransactionSerializer : Serializer() { kryo.writeClassAndObject(output, obj.outputs) kryo.writeClassAndObject(output, obj.commands) kryo.writeClassAndObject(output, obj.notary) - kryo.writeClassAndObject(output, obj.mustSign) - kryo.writeClassAndObject(output, obj.type) kryo.writeClassAndObject(output, obj.timeWindow) + kryo.writeClassAndObject(output, obj.privacySalt) } private fun attachmentsClassLoader(kryo: Kryo, attachmentHashes: List): ClassLoader? { @@ -358,21 +276,54 @@ object WireTransactionSerializer : Serializer() { // Otherwise we just assume the code we need is on the classpath already. kryo.useClassLoader(attachmentsClassLoader(kryo, attachmentHashes) ?: javaClass.classLoader) { val outputs = kryo.readClassAndObject(input) as List> - val commands = kryo.readClassAndObject(input) as List + val commands = kryo.readClassAndObject(input) as List> val notary = kryo.readClassAndObject(input) as Party? - val signers = kryo.readClassAndObject(input) as List - val transactionType = kryo.readClassAndObject(input) as TransactionType val timeWindow = kryo.readClassAndObject(input) as TimeWindow? - return WireTransaction(inputs, attachmentHashes, outputs, commands, notary, signers, transactionType, timeWindow) + val privacySalt = kryo.readClassAndObject(input) as PrivacySalt + return WireTransaction(inputs, attachmentHashes, outputs, commands, notary, timeWindow, privacySalt) } } } +@ThreadSafe +object NotaryChangeWireTransactionSerializer : Serializer() { + override fun write(kryo: Kryo, output: Output, obj: NotaryChangeWireTransaction) { + kryo.writeClassAndObject(output, obj.inputs) + kryo.writeClassAndObject(output, obj.notary) + kryo.writeClassAndObject(output, obj.newNotary) + } + + @Suppress("UNCHECKED_CAST") + override fun read(kryo: Kryo, input: Input, type: Class): NotaryChangeWireTransaction { + val inputs = kryo.readClassAndObject(input) as List + val notary = kryo.readClassAndObject(input) as Party + val newNotary = kryo.readClassAndObject(input) as Party + + return NotaryChangeWireTransaction(inputs, notary, newNotary) + } +} + +@ThreadSafe +object SignedTransactionSerializer : Serializer() { + override fun write(kryo: Kryo, output: Output, obj: SignedTransaction) { + kryo.writeClassAndObject(output, obj.txBits) + kryo.writeClassAndObject(output, obj.sigs) + } + + @Suppress("UNCHECKED_CAST") + override fun read(kryo: Kryo, input: Input, type: Class): SignedTransaction { + return SignedTransaction( + kryo.readClassAndObject(input) as SerializedBytes, + kryo.readClassAndObject(input) as List + ) + } +} + /** For serialising an ed25519 private key */ @ThreadSafe object Ed25519PrivateKeySerializer : Serializer() { override fun write(kryo: Kryo, output: Output, obj: EdDSAPrivateKey) { - check(obj.params == Crypto.EDDSA_ED25519_SHA512.algSpec ) + check(obj.params == Crypto.EDDSA_ED25519_SHA512.algSpec) output.writeBytesWithLength(obj.seed) } @@ -398,15 +349,15 @@ object Ed25519PublicKeySerializer : Serializer() { /** For serialising an ed25519 public key */ @ThreadSafe -object ECPublicKeyImplSerializer : Serializer() { - override fun write(kryo: Kryo, output: Output, obj: sun.security.ec.ECPublicKeyImpl) { +object ECPublicKeyImplSerializer : Serializer() { + override fun write(kryo: Kryo, output: Output, obj: ECPublicKeyImpl) { output.writeBytesWithLength(obj.encoded) } - override fun read(kryo: Kryo, input: Input, type: Class): sun.security.ec.ECPublicKeyImpl { + override fun read(kryo: Kryo, input: Input, type: Class): ECPublicKeyImpl { val A = input.readBytesWithLength() - val der = sun.security.util.DerValue(A) - return sun.security.ec.ECPublicKeyImpl.parse(der) as sun.security.ec.ECPublicKeyImpl + val der = DerValue(A) + return ECPublicKeyImpl.parse(der) as ECPublicKeyImpl } } @@ -468,14 +419,6 @@ inline fun readListOfLength(kryo: Kryo, input: Input, minLen: Int = return list } -// No ClassResolver only constructor. MapReferenceResolver is the default as used by Kryo in other constructors. -private val internalKryoPool = KryoPool.Builder { DefaultKryoCustomizer.customize(CordaKryo(makeAllButBlacklistedClassResolver())) }.build() -private val kryoPool = KryoPool.Builder { DefaultKryoCustomizer.customize(CordaKryo(makeStandardClassResolver())) }.build() - -// No ClassResolver only constructor. MapReferenceResolver is the default as used by Kryo in other constructors. -@VisibleForTesting -fun createTestKryo(): Kryo = DefaultKryoCustomizer.customize(CordaKryo(makeNoWhitelistClassResolver())) - /** * We need to disable whitelist checking during calls from our Kryo code to register a serializer, since it checks * for existing registrations and then will enter our [CordaClassResolver.getRegistration] method. @@ -560,35 +503,6 @@ fun Kryo.withoutReferences(block: () -> T): T { } } -/** For serialising a MetaData object. */ -@ThreadSafe -object MetaDataSerializer : Serializer() { - override fun write(kryo: Kryo, output: Output, obj: MetaData) { - output.writeString(obj.schemeCodeName) - output.writeString(obj.versionID) - kryo.writeClassAndObject(output, obj.signatureType) - kryo.writeClassAndObject(output, obj.timestamp) - kryo.writeClassAndObject(output, obj.visibleInputs) - kryo.writeClassAndObject(output, obj.signedInputs) - output.writeBytesWithLength(obj.merkleRoot) - output.writeBytesWithLength(obj.publicKey.encoded) - } - - @Suppress("UNCHECKED_CAST") - @Throws(IllegalArgumentException::class, InvalidKeySpecException::class) - override fun read(kryo: Kryo, input: Input, type: Class): MetaData { - val schemeCodeName = input.readString() - val versionID = input.readString() - val signatureType = kryo.readClassAndObject(input) as SignatureType - val timestamp = kryo.readClassAndObject(input) as Instant? - val visibleInputs = kryo.readClassAndObject(input) as BitSet? - val signedInputs = kryo.readClassAndObject(input) as BitSet? - val merkleRoot = input.readBytesWithLength() - val publicKey = Crypto.decodePublicKey(schemeCodeName, input.readBytesWithLength()) - return MetaData(schemeCodeName, versionID, signatureType, timestamp, visibleInputs, signedInputs, merkleRoot, publicKey) - } -} - /** For serialising a Logger. */ @ThreadSafe object LoggerSerializer : Serializer() { @@ -655,24 +569,4 @@ object X509CertificateSerializer : Serializer() { } } -class KryoPoolWithContext(val baseKryoPool: KryoPool, val contextKey: Any, val context: Any) : KryoPool { - override fun run(callback: KryoCallback): T { - val kryo = borrow() - try { - return callback.execute(kryo) - } finally { - release(kryo) - } - } - - override fun borrow(): Kryo { - val kryo = baseKryoPool.borrow() - require(kryo.context.put(contextKey, context) == null) { "KryoPool already has context" } - return kryo - } - - override fun release(kryo: Kryo) { - requireNotNull(kryo.context.remove(contextKey)) { "Kryo instance lost context while borrowed" } - baseKryoPool.release(kryo) - } -} +fun Kryo.serializationContext(): SerializeAsTokenContext? = context.get(serializationContextKey) as? SerializeAsTokenContext \ No newline at end of file diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/KryoAMQPSerializer.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/KryoAMQPSerializer.kt new file mode 100644 index 0000000000..16dec8a83e --- /dev/null +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/KryoAMQPSerializer.kt @@ -0,0 +1,37 @@ +package net.corda.nodeapi.internal.serialization + +import com.esotericsoftware.kryo.Kryo +import com.esotericsoftware.kryo.Serializer +import com.esotericsoftware.kryo.io.Input +import com.esotericsoftware.kryo.io.Output +import net.corda.core.serialization.SerializationContext +import net.corda.core.serialization.SerializationFactory +import net.corda.core.serialization.SerializedBytes +import net.corda.core.utilities.sequence +import net.corda.nodeapi.internal.serialization.amqp.AmqpHeaderV1_0 +import net.corda.nodeapi.internal.serialization.amqp.DeserializationInput +import net.corda.nodeapi.internal.serialization.amqp.SerializationOutput +import net.corda.nodeapi.internal.serialization.amqp.SerializerFactory + +/** + * This [Kryo] custom [Serializer] switches the object graph of anything annotated with `@CordaSerializable` + * to using the AMQP serialization wire format, and simply writes that out as bytes to the wire. + * + * There is no need to write out the length, since this can be peeked out of the first few bytes of the stream. + */ +class KryoAMQPSerializer(val serializationFactory: SerializationFactory, val serializationContext: SerializationContext) : Serializer() { + override fun write(kryo: Kryo, output: Output, obj: Any) { + val bytes = serializationFactory.serialize(obj, serializationContext.withPreferredSerializationVersion(AmqpHeaderV1_0)).bytes + // No need to write out the size since it's encoded within the AMQP. + output.write(bytes) + } + + override fun read(kryo: Kryo, input: Input, type: Class): Any { + // Use our helper functions to peek the size of the serialized object out of the AMQP byte stream. + val peekedBytes = input.readBytes(DeserializationInput.BYTES_NEEDED_TO_PEEK) + val size = DeserializationInput.peekSize(peekedBytes) + val allBytes = peekedBytes.copyOf(size) + input.readBytes(allBytes, peekedBytes.size, size - peekedBytes.size) + return serializationFactory.deserialize(allBytes.sequence(), type, serializationContext) + } +} \ No newline at end of file diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/KryoSerializationCustomization.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/KryoSerializationCustomization.kt new file mode 100644 index 0000000000..6b1ab209b8 --- /dev/null +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/KryoSerializationCustomization.kt @@ -0,0 +1,10 @@ +package net.corda.nodeapi.internal.serialization + +import com.esotericsoftware.kryo.Kryo +import net.corda.core.serialization.SerializationCustomization + +class KryoSerializationCustomization(val kryo: Kryo) : SerializationCustomization { + override fun addToWhitelist(type: Class<*>) { + kryo.addToWhitelist(type) + } +} \ No newline at end of file diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/SerializationScheme.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/SerializationScheme.kt new file mode 100644 index 0000000000..4ab8c3c27e --- /dev/null +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/SerializationScheme.kt @@ -0,0 +1,267 @@ +package net.corda.nodeapi.internal.serialization + +import co.paralleluniverse.fibers.Fiber +import co.paralleluniverse.io.serialization.kryo.KryoSerializer +import com.esotericsoftware.kryo.Kryo +import com.esotericsoftware.kryo.KryoException +import com.esotericsoftware.kryo.Serializer +import com.esotericsoftware.kryo.io.Input +import com.esotericsoftware.kryo.io.Output +import com.esotericsoftware.kryo.pool.KryoPool +import io.requery.util.CloseableIterator +import net.corda.core.internal.LazyPool +import net.corda.core.serialization.* +import net.corda.core.utilities.ByteSequence +import net.corda.core.utilities.OpaqueBytes +import java.io.ByteArrayOutputStream +import java.io.NotSerializableException +import java.util.* +import java.util.concurrent.ConcurrentHashMap + +object NotSupportedSeralizationScheme : SerializationScheme { + private fun doThrow(): Nothing = throw UnsupportedOperationException("Serialization scheme not supported.") + + override fun canDeserializeVersion(byteSequence: ByteSequence, target: SerializationContext.UseCase): Boolean = doThrow() + + override fun deserialize(byteSequence: ByteSequence, clazz: Class, context: SerializationContext): T = doThrow() + + override fun serialize(obj: T, context: SerializationContext): SerializedBytes = doThrow() +} + +data class SerializationContextImpl(override val preferedSerializationVersion: ByteSequence, + override val deserializationClassLoader: ClassLoader, + override val whitelist: ClassWhitelist, + override val properties: Map, + override val objectReferencesEnabled: Boolean, + override val useCase: SerializationContext.UseCase) : SerializationContext { + override fun withProperty(property: Any, value: Any): SerializationContext { + return copy(properties = properties + (property to value)) + } + + override fun withoutReferences(): SerializationContext { + return copy(objectReferencesEnabled = false) + } + + override fun withClassLoader(classLoader: ClassLoader): SerializationContext { + return copy(deserializationClassLoader = classLoader) + } + + override fun withWhitelisted(clazz: Class<*>): SerializationContext { + return copy(whitelist = object : ClassWhitelist { + override fun hasListed(type: Class<*>): Boolean = whitelist.hasListed(type) || type.name == clazz.name + }) + } + + override fun withPreferredSerializationVersion(versionHeader: ByteSequence) = copy(preferedSerializationVersion = versionHeader) +} + +private const val HEADER_SIZE: Int = 8 + +open class SerializationFactoryImpl : SerializationFactory { + private val creator: List = Exception().stackTrace.asList() + + private val registeredSchemes: MutableCollection = Collections.synchronizedCollection(mutableListOf()) + + // TODO: This is read-mostly. Probably a faster implementation to be found. + private val schemes: ConcurrentHashMap, SerializationScheme> = ConcurrentHashMap() + + private fun schemeFor(byteSequence: ByteSequence, target: SerializationContext.UseCase): SerializationScheme { + // truncate sequence to 8 bytes, and make sure it's a copy to avoid holding onto large ByteArrays + return schemes.computeIfAbsent(byteSequence.take(HEADER_SIZE).copy() to target) { + for (scheme in registeredSchemes) { + if (scheme.canDeserializeVersion(it.first, it.second)) { + return@computeIfAbsent scheme + } + } + NotSupportedSeralizationScheme + } + } + + @Throws(NotSerializableException::class) + override fun deserialize(byteSequence: ByteSequence, clazz: Class, context: SerializationContext): T = schemeFor(byteSequence, context.useCase).deserialize(byteSequence, clazz, context) + + override fun serialize(obj: T, context: SerializationContext): SerializedBytes { + return schemeFor(context.preferedSerializationVersion, context.useCase).serialize(obj, context) + } + + fun registerScheme(scheme: SerializationScheme) { + check(schemes.isEmpty()) { "All serialization schemes must be registered before any scheme is used." } + registeredSchemes += scheme + } + + val alreadyRegisteredSchemes: Collection get() = Collections.unmodifiableCollection(registeredSchemes) + + override fun toString(): String { + return "${this.javaClass.name} registeredSchemes=$registeredSchemes ${creator.joinToString("\n")}" + } + + override fun equals(other: Any?): Boolean { + return other is SerializationFactoryImpl && + other.registeredSchemes == this.registeredSchemes + } + + override fun hashCode(): Int = registeredSchemes.hashCode() +} + +private object AutoCloseableSerialisationDetector : Serializer() { + override fun write(kryo: Kryo, output: Output, closeable: AutoCloseable) { + val message = if (closeable is CloseableIterator<*>) { + "A live Iterator pointing to the database has been detected during flow checkpointing. This may be due " + + "to a Vault query - move it into a private method." + } else { + "${closeable.javaClass.name}, which is a closeable resource, has been detected during flow checkpointing. " + + "Restoring such resources across node restarts is not supported. Make sure code accessing it is " + + "confined to a private method or the reference is nulled out." + } + throw UnsupportedOperationException(message) + } + + override fun read(kryo: Kryo, input: Input, type: Class) = throw IllegalStateException("Should not reach here!") +} + +abstract class AbstractKryoSerializationScheme(val serializationFactory: SerializationFactory) : SerializationScheme { + private val kryoPoolsForContexts = ConcurrentHashMap, KryoPool>() + + protected abstract fun rpcClientKryoPool(context: SerializationContext): KryoPool + protected abstract fun rpcServerKryoPool(context: SerializationContext): KryoPool + + private fun getPool(context: SerializationContext): KryoPool { + return kryoPoolsForContexts.computeIfAbsent(Pair(context.whitelist, context.deserializationClassLoader)) { + when (context.useCase) { + SerializationContext.UseCase.Checkpoint -> + KryoPool.Builder { + val serializer = Fiber.getFiberSerializer(false) as KryoSerializer + val classResolver = CordaClassResolver(serializationFactory, context).apply { setKryo(serializer.kryo) } + // TODO The ClassResolver can only be set in the Kryo constructor and Quasar doesn't provide us with a way of doing that + val field = Kryo::class.java.getDeclaredField("classResolver").apply { isAccessible = true } + serializer.kryo.apply { + field.set(this, classResolver) + DefaultKryoCustomizer.customize(this) + addDefaultSerializer(AutoCloseable::class.java, AutoCloseableSerialisationDetector) + classLoader = it.second + } + }.build() + SerializationContext.UseCase.RPCClient -> + rpcClientKryoPool(context) + SerializationContext.UseCase.RPCServer -> + rpcServerKryoPool(context) + else -> + KryoPool.Builder { + DefaultKryoCustomizer.customize(CordaKryo(CordaClassResolver(serializationFactory, context))).apply { classLoader = it.second } + }.build() + } + } + } + + private fun withContext(kryo: Kryo, context: SerializationContext, block: (Kryo) -> T): T { + kryo.context.ensureCapacity(context.properties.size) + context.properties.forEach { kryo.context.put(it.key, it.value) } + try { + return block(kryo) + } finally { + kryo.context.clear() + } + } + + override fun deserialize(byteSequence: ByteSequence, clazz: Class, context: SerializationContext): T { + val pool = getPool(context) + val headerSize = KryoHeaderV0_1.size + val header = byteSequence.take(headerSize) + if (header != KryoHeaderV0_1) { + throw KryoException("Serialized bytes header does not match expected format.") + } + Input(byteSequence.bytes, byteSequence.offset + headerSize, byteSequence.size - headerSize).use { input -> + return pool.run { kryo -> + withContext(kryo, context) { + @Suppress("UNCHECKED_CAST") + if (context.objectReferencesEnabled) { + kryo.readClassAndObject(input) as T + } else { + kryo.withoutReferences { kryo.readClassAndObject(input) as T } + } + } + } + } + } + + override fun serialize(obj: T, context: SerializationContext): SerializedBytes { + val pool = getPool(context) + return pool.run { kryo -> + withContext(kryo, context) { + serializeOutputStreamPool.run { stream -> + serializeBufferPool.run { buffer -> + Output(buffer).use { + it.outputStream = stream + it.writeBytes(KryoHeaderV0_1.bytes) + if (context.objectReferencesEnabled) { + kryo.writeClassAndObject(it, obj) + } else { + kryo.withoutReferences { kryo.writeClassAndObject(it, obj) } + } + } + SerializedBytes(stream.toByteArray()) + } + } + } + } + } +} + +private val serializeBufferPool = LazyPool( + newInstance = { ByteArray(64 * 1024) } +) +private val serializeOutputStreamPool = LazyPool( + clear = ByteArrayOutputStream::reset, + shouldReturnToPool = { it.size() < 256 * 1024 }, // Discard if it grew too large + newInstance = { ByteArrayOutputStream(64 * 1024) } +) + +// "corda" + majorVersionByte + minorVersionMSB + minorVersionLSB +val KryoHeaderV0_1: OpaqueBytes = OpaqueBytes("corda\u0000\u0000\u0001".toByteArray(Charsets.UTF_8)) + + +val KRYO_P2P_CONTEXT = SerializationContextImpl(KryoHeaderV0_1, + SerializationDefaults.javaClass.classLoader, + GlobalTransientClassWhiteList(BuiltInExceptionsWhitelist()), + emptyMap(), + true, + SerializationContext.UseCase.P2P) +val KRYO_RPC_SERVER_CONTEXT = SerializationContextImpl(KryoHeaderV0_1, + SerializationDefaults.javaClass.classLoader, + GlobalTransientClassWhiteList(BuiltInExceptionsWhitelist()), + emptyMap(), + true, + SerializationContext.UseCase.RPCServer) +val KRYO_RPC_CLIENT_CONTEXT = SerializationContextImpl(KryoHeaderV0_1, + SerializationDefaults.javaClass.classLoader, + GlobalTransientClassWhiteList(BuiltInExceptionsWhitelist()), + emptyMap(), + true, + SerializationContext.UseCase.RPCClient) +val KRYO_STORAGE_CONTEXT = SerializationContextImpl(KryoHeaderV0_1, + SerializationDefaults.javaClass.classLoader, + AllButBlacklisted, + emptyMap(), + true, + SerializationContext.UseCase.Storage) +val KRYO_CHECKPOINT_CONTEXT = SerializationContextImpl(KryoHeaderV0_1, + SerializationDefaults.javaClass.classLoader, + QuasarWhitelist, + emptyMap(), + true, + SerializationContext.UseCase.Checkpoint) + +object QuasarWhitelist : ClassWhitelist { + override fun hasListed(type: Class<*>): Boolean = true +} + +interface SerializationScheme { + // byteSequence expected to just be the 8 bytes necessary for versioning + fun canDeserializeVersion(byteSequence: ByteSequence, target: SerializationContext.UseCase): Boolean + + @Throws(NotSerializableException::class) + fun deserialize(byteSequence: ByteSequence, clazz: Class, context: SerializationContext): T + + @Throws(NotSerializableException::class) + fun serialize(obj: T, context: SerializationContext): SerializedBytes +} \ No newline at end of file diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/SerializeAsTokenContextImpl.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/SerializeAsTokenContextImpl.kt new file mode 100644 index 0000000000..6877239bd5 --- /dev/null +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/SerializeAsTokenContextImpl.kt @@ -0,0 +1,56 @@ +package net.corda.nodeapi.internal.serialization + +import net.corda.core.node.ServiceHub +import net.corda.core.serialization.SerializationContext +import net.corda.core.serialization.SerializationFactory +import net.corda.core.serialization.SerializeAsToken +import net.corda.core.serialization.SerializeAsTokenContext + +val serializationContextKey = SerializeAsTokenContext::class.java + +fun SerializationContext.withTokenContext(serializationContext: SerializeAsTokenContext): SerializationContext = this.withProperty(serializationContextKey, serializationContext) + +/** + * A context for mapping SerializationTokens to/from SerializeAsTokens. + * + * A context is initialised with an object containing all the instances of [SerializeAsToken] to eagerly register all the tokens. + * In our case this can be the [ServiceHub]. + * + * Then it is a case of using the companion object methods on [SerializeAsTokenSerializer] to set and clear context as necessary + * when serializing to enable/disable tokenization. + */ +class SerializeAsTokenContextImpl(override val serviceHub: ServiceHub, init: SerializeAsTokenContext.() -> Unit) : SerializeAsTokenContext { + constructor(toBeTokenized: Any, serializationFactory: SerializationFactory, context: SerializationContext, serviceHub: ServiceHub) : this(serviceHub, { + serializationFactory.serialize(toBeTokenized, context.withTokenContext(this)) + }) + + private val classNameToSingleton = mutableMapOf() + private var readOnly = false + + init { + /** + * Go ahead and eagerly serialize the object to register all of the tokens in the context. + * + * This results in the toToken() method getting called for any [SingletonSerializeAsToken] instances which + * are encountered in the object graph as they are serialized and will therefore register the token to + * object mapping for those instances. We then immediately set the readOnly flag to stop further adhoc or + * accidental registrations from occuring as these could not be deserialized in a deserialization-first + * scenario if they are not part of this iniital context construction serialization. + */ + init(this) + readOnly = true + } + + override fun putSingleton(toBeTokenized: SerializeAsToken) { + val className = toBeTokenized.javaClass.name + if (className !in classNameToSingleton) { + // Only allowable if we are in SerializeAsTokenContext init (readOnly == false) + if (readOnly) { + throw UnsupportedOperationException("Attempt to write token for lazy registered ${className}. All tokens should be registered during context construction.") + } + classNameToSingleton[className] = toBeTokenized + } + } + + override fun getSingleton(className: String) = classNameToSingleton[className] ?: throw IllegalStateException("Unable to find tokenized instance of $className in context $this") +} \ No newline at end of file diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/SerializeAsTokenSerializer.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/SerializeAsTokenSerializer.kt new file mode 100644 index 0000000000..118eb299fb --- /dev/null +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/SerializeAsTokenSerializer.kt @@ -0,0 +1,25 @@ +package net.corda.nodeapi.internal.serialization + +import com.esotericsoftware.kryo.Kryo +import com.esotericsoftware.kryo.KryoException +import com.esotericsoftware.kryo.Serializer +import com.esotericsoftware.kryo.io.Input +import com.esotericsoftware.kryo.io.Output +import net.corda.core.internal.castIfPossible +import net.corda.core.serialization.SerializationToken +import net.corda.core.serialization.SerializeAsToken + +/** + * A Kryo serializer for [SerializeAsToken] implementations. + */ +class SerializeAsTokenSerializer : Serializer() { + override fun write(kryo: Kryo, output: Output, obj: T) { + kryo.writeClassAndObject(output, obj.toToken(kryo.serializationContext() ?: throw KryoException("Attempt to write a ${SerializeAsToken::class.simpleName} instance of ${obj.javaClass.name} without initialising a context"))) + } + + override fun read(kryo: Kryo, input: Input, type: Class): T { + val token = (kryo.readClassAndObject(input) as? SerializationToken) ?: throw KryoException("Non-token read for tokenized type: ${type.name}") + val fromToken = token.fromToken(kryo.serializationContext() ?: throw KryoException("Attempt to read a token for a ${SerializeAsToken::class.simpleName} instance of ${type.name} without initialising a context")) + return type.castIfPossible(fromToken) ?: throw KryoException("Token read ($token) did not return expected tokenized type: ${type.name}") + } +} \ No newline at end of file diff --git a/core/src/main/kotlin/net/corda/core/serialization/amqp/AMQPPrimitiveSerializer.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/AMQPPrimitiveSerializer.kt similarity index 94% rename from core/src/main/kotlin/net/corda/core/serialization/amqp/AMQPPrimitiveSerializer.kt rename to node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/AMQPPrimitiveSerializer.kt index b68d37c935..11f91005ed 100644 --- a/core/src/main/kotlin/net/corda/core/serialization/amqp/AMQPPrimitiveSerializer.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/AMQPPrimitiveSerializer.kt @@ -1,4 +1,4 @@ -package net.corda.core.serialization.amqp +package net.corda.nodeapi.internal.serialization.amqp import org.apache.qpid.proton.amqp.Binary import org.apache.qpid.proton.codec.Data diff --git a/core/src/main/kotlin/net/corda/core/serialization/amqp/AMQPSerializer.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/AMQPSerializer.kt similarity index 95% rename from core/src/main/kotlin/net/corda/core/serialization/amqp/AMQPSerializer.kt rename to node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/AMQPSerializer.kt index b2917c39cd..85cdc24e11 100644 --- a/core/src/main/kotlin/net/corda/core/serialization/amqp/AMQPSerializer.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/AMQPSerializer.kt @@ -1,4 +1,4 @@ -package net.corda.core.serialization.amqp +package net.corda.nodeapi.internal.serialization.amqp import org.apache.qpid.proton.codec.Data import java.lang.reflect.Type diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/ArraySerializer.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/ArraySerializer.kt new file mode 100644 index 0000000000..f1595d70d6 --- /dev/null +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/ArraySerializer.kt @@ -0,0 +1,166 @@ +package net.corda.nodeapi.internal.serialization.amqp + +import org.apache.qpid.proton.codec.Data +import java.io.NotSerializableException +import java.lang.reflect.Type + +/** + * Serialization / deserialization of arrays. + */ +open class ArraySerializer(override val type: Type, factory: SerializerFactory) : AMQPSerializer { + companion object { + fun make(type: Type, factory: SerializerFactory) = when (type) { + Array::class.java -> CharArraySerializer(factory) + else -> ArraySerializer(type, factory) + } + } + + // because this might be an array of array of primitives (to any recursive depth) and + // because we care that the lowest type is unboxed we can't rely on the inbuilt type + // id to generate it properly (it will always return [[[Ljava.lang.type -> type[][][] + // for example). + // + // We *need* to retain knowledge for AMQP deserialisation weather that lowest primitive + // was boxed or unboxed so just infer it recursively + private fun calcTypeName(type: Type) : String = + if (type.componentType().isArray()) { + val typeName = calcTypeName(type.componentType()); "$typeName[]" + } + else { + val arrayType = if (type.asClass()!!.componentType.isPrimitive) "[p]" else "[]" + "${type.componentType().typeName}$arrayType" + } + + override val typeDescriptor by lazy { "$DESCRIPTOR_DOMAIN:${fingerprintForType(type, factory)}" } + internal val elementType: Type by lazy { type.componentType() } + internal open val typeName by lazy { calcTypeName(type) } + + internal val typeNotation: TypeNotation by lazy { + RestrictedType(typeName, null, emptyList(), "list", Descriptor(typeDescriptor, null), emptyList()) + } + + override fun writeClassInfo(output: SerializationOutput) { + if (output.writeTypeNotations(typeNotation)) { + output.requireSerializer(elementType) + } + } + + override fun writeObject(obj: Any, data: Data, type: Type, output: SerializationOutput) { + // Write described + data.withDescribed(typeNotation.descriptor) { + withList { + for (entry in obj as Array<*>) { + output.writeObjectOrNull(entry, this, elementType) + } + } + } + } + + override fun readObject(obj: Any, schema: Schema, input: DeserializationInput): Any { + if (obj is List<*>) { + return obj.map { input.readObjectOrNull(it, schema, elementType) }.toArrayOfType(elementType) + } else throw NotSerializableException("Expected a List but found $obj") + } + + open fun List.toArrayOfType(type: Type): Any { + val elementType = type.asClass() ?: throw NotSerializableException("Unexpected array element type $type") + val list = this + return java.lang.reflect.Array.newInstance(elementType, this.size).apply { + (0..lastIndex).forEach { java.lang.reflect.Array.set(this, it, list[it]) } + } + } +} + +// Boxed Character arrays required a specialisation to handle the type conversion properly when populating +// the array since Kotlin won't allow an implicit cast from Int (as they're stored as 16bit ints) to Char +class CharArraySerializer(factory: SerializerFactory) : ArraySerializer(Array::class.java, factory) { + override fun List.toArrayOfType(type: Type): Any { + val elementType = type.asClass() ?: throw NotSerializableException("Unexpected array element type $type") + val list = this + return java.lang.reflect.Array.newInstance(elementType, this.size).apply { + (0..lastIndex).forEach { java.lang.reflect.Array.set(this, it, (list[it] as Int).toChar()) } + } + } +} + +// Specialisation of [ArraySerializer] that handles arrays of unboxed java primitive types +abstract class PrimArraySerializer(type: Type, factory: SerializerFactory) : ArraySerializer(type, factory) { + companion object { + // We don't need to handle the unboxed byte type as that is coercible to a byte array, but + // the other 7 primitive types we do + val primTypes: Map PrimArraySerializer> = mapOf( + IntArray::class.java to { f -> PrimIntArraySerializer(f) }, + CharArray::class.java to { f -> PrimCharArraySerializer(f) }, + BooleanArray::class.java to { f -> PrimBooleanArraySerializer(f) }, + FloatArray::class.java to { f -> PrimFloatArraySerializer(f) }, + ShortArray::class.java to { f -> PrimShortArraySerializer(f) }, + DoubleArray::class.java to { f -> PrimDoubleArraySerializer(f) }, + LongArray::class.java to { f -> PrimLongArraySerializer(f) } + // ByteArray::class.java <-> NOT NEEDED HERE (see comment above) + ) + + fun make(type: Type, factory: SerializerFactory) = primTypes[type]!!(factory) + } + + fun localWriteObject(data: Data, func: () -> Unit) { + data.withDescribed(typeNotation.descriptor) { withList { func() } } + } +} + +class PrimIntArraySerializer(factory: SerializerFactory) : + PrimArraySerializer(IntArray::class.java, factory) { + override fun writeObject(obj: Any, data: Data, type: Type, output: SerializationOutput) { + localWriteObject(data) { (obj as IntArray).forEach { output.writeObjectOrNull(it, data, elementType) } } + } +} + +class PrimCharArraySerializer(factory: SerializerFactory) : + PrimArraySerializer(CharArray::class.java, factory) { + override fun writeObject(obj: Any, data: Data, type: Type, output: SerializationOutput) { + localWriteObject(data) { (obj as CharArray).forEach { output.writeObjectOrNull(it, data, elementType) } } + } + + override fun List.toArrayOfType(type: Type): Any { + val elementType = type.asClass() ?: throw NotSerializableException("Unexpected array element type $type") + val list = this + return java.lang.reflect.Array.newInstance(elementType, this.size).apply { + val array = this + (0..lastIndex).forEach { java.lang.reflect.Array.set(array, it, (list[it] as Int).toChar()) } + } + } +} + +class PrimBooleanArraySerializer(factory: SerializerFactory) : + PrimArraySerializer(BooleanArray::class.java, factory) { + override fun writeObject(obj: Any, data: Data, type: Type, output: SerializationOutput) { + localWriteObject(data) { (obj as BooleanArray).forEach { output.writeObjectOrNull(it, data, elementType) } } + } +} + +class PrimDoubleArraySerializer(factory: SerializerFactory) : + PrimArraySerializer(DoubleArray::class.java, factory) { + override fun writeObject(obj: Any, data: Data, type: Type, output: SerializationOutput) { + localWriteObject(data) { (obj as DoubleArray).forEach { output.writeObjectOrNull(it, data, elementType) } } + } +} + +class PrimFloatArraySerializer(factory: SerializerFactory) : + PrimArraySerializer(FloatArray::class.java, factory) { + override fun writeObject(obj: Any, data: Data, type: Type, output: SerializationOutput) { + localWriteObject(data) { (obj as FloatArray).forEach { output.writeObjectOrNull(it, data, elementType) } } + } +} + +class PrimShortArraySerializer(factory: SerializerFactory) : + PrimArraySerializer(ShortArray::class.java, factory) { + override fun writeObject(obj: Any, data: Data, type: Type, output: SerializationOutput) { + localWriteObject(data) { (obj as ShortArray).forEach { output.writeObjectOrNull(it, data, elementType) } } + } +} + +class PrimLongArraySerializer(factory: SerializerFactory) : + PrimArraySerializer(LongArray::class.java, factory) { + override fun writeObject(obj: Any, data: Data, type: Type, output: SerializationOutput) { + localWriteObject(data) { (obj as LongArray).forEach { output.writeObjectOrNull(it, data, elementType) } } + } +} diff --git a/core/src/main/kotlin/net/corda/core/serialization/amqp/CollectionSerializer.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/CollectionSerializer.kt similarity index 98% rename from core/src/main/kotlin/net/corda/core/serialization/amqp/CollectionSerializer.kt rename to node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/CollectionSerializer.kt index 76ec0be975..2b8660ed45 100644 --- a/core/src/main/kotlin/net/corda/core/serialization/amqp/CollectionSerializer.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/CollectionSerializer.kt @@ -1,4 +1,4 @@ -package net.corda.core.serialization.amqp +package net.corda.nodeapi.internal.serialization.amqp import org.apache.qpid.proton.codec.Data import java.io.NotSerializableException diff --git a/core/src/main/kotlin/net/corda/core/serialization/amqp/CustomSerializer.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/CustomSerializer.kt similarity index 98% rename from core/src/main/kotlin/net/corda/core/serialization/amqp/CustomSerializer.kt rename to node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/CustomSerializer.kt index d08d3b8e88..58752b0ecc 100644 --- a/core/src/main/kotlin/net/corda/core/serialization/amqp/CustomSerializer.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/CustomSerializer.kt @@ -1,6 +1,6 @@ -package net.corda.core.serialization.amqp +package net.corda.nodeapi.internal.serialization.amqp -import net.corda.core.serialization.amqp.SerializerFactory.Companion.nameForType +import net.corda.nodeapi.internal.serialization.amqp.SerializerFactory.Companion.nameForType import org.apache.qpid.proton.codec.Data import java.lang.reflect.Type diff --git a/core/src/main/kotlin/net/corda/core/serialization/amqp/DeserializationInput.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/DeserializationInput.kt similarity index 72% rename from core/src/main/kotlin/net/corda/core/serialization/amqp/DeserializationInput.kt rename to node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/DeserializationInput.kt index 2859dbb989..5654dfa20d 100644 --- a/core/src/main/kotlin/net/corda/core/serialization/amqp/DeserializationInput.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/DeserializationInput.kt @@ -1,7 +1,8 @@ -package net.corda.core.serialization.amqp +package net.corda.nodeapi.internal.serialization.amqp -import com.google.common.base.Throwables +import net.corda.core.internal.getStackTraceAsString import net.corda.core.serialization.SerializedBytes +import net.corda.core.utilities.ByteSequence import org.apache.qpid.proton.amqp.Binary import org.apache.qpid.proton.amqp.DescribedType import org.apache.qpid.proton.amqp.UnsignedByte @@ -11,7 +12,7 @@ import java.lang.reflect.Type import java.nio.ByteBuffer import java.util.* -data class objectAndEnvelope(val obj: T, val envelope: Envelope) +data class ObjectAndEnvelope(val obj: T, val envelope: Envelope) /** * Main entry point for deserializing an AMQP encoded object. @@ -26,17 +27,6 @@ class DeserializationInput(internal val serializerFactory: SerializerFactory = S internal companion object { val BYTES_NEEDED_TO_PEEK: Int = 23 - private fun subArraysEqual(a: ByteArray, aOffset: Int, length: Int, b: ByteArray, bOffset: Int): Boolean { - if (aOffset + length > a.size || bOffset + length > b.size) throw IndexOutOfBoundsException() - var bytesRemaining = length - var aPos = aOffset - var bPos = bOffset - while (bytesRemaining-- > 0) { - if (a[aPos++] != b[bPos++]) return false - } - return true - } - fun peekSize(bytes: ByteArray): Int { // There's an 8 byte header, and then a 0 byte plus descriptor followed by constructor val eighth = bytes[8].toInt() @@ -64,35 +54,35 @@ class DeserializationInput(internal val serializerFactory: SerializerFactory = S @Throws(NotSerializableException::class) - inline internal fun deserializeAndReturnEnvelope(bytes: SerializedBytes): objectAndEnvelope = + inline internal fun deserializeAndReturnEnvelope(bytes: SerializedBytes): ObjectAndEnvelope = deserializeAndReturnEnvelope(bytes, T::class.java) @Throws(NotSerializableException::class) - private fun getEnvelope(bytes: SerializedBytes): Envelope { + private fun getEnvelope(bytes: ByteSequence): Envelope { // Check that the lead bytes match expected header - if (!subArraysEqual(bytes.bytes, 0, 8, AmqpHeaderV1_0.bytes, 0)) { + val headerSize = AmqpHeaderV1_0.size + if (bytes.take(headerSize) != AmqpHeaderV1_0) { throw NotSerializableException("Serialization header does not match.") } val data = Data.Factory.create() - val size = data.decode(ByteBuffer.wrap(bytes.bytes, 8, bytes.size - 8)) - if (size.toInt() != bytes.size - 8) { + val size = data.decode(ByteBuffer.wrap(bytes.bytes, bytes.offset + headerSize, bytes.size - headerSize)) + if (size.toInt() != bytes.size - headerSize) { throw NotSerializableException("Unexpected size of data") } return Envelope.get(data) } - @Throws(NotSerializableException::class) - private fun des(bytes: SerializedBytes, clazz: Class, generator: (SerializedBytes, Class) -> R): R { + private fun des(generator: () -> R): R { try { - return generator(bytes, clazz) + return generator() } catch(nse: NotSerializableException) { throw nse } catch(t: Throwable) { - throw NotSerializableException("Unexpected throwable: ${t.message} ${Throwables.getStackTraceAsString(t)}") + throw NotSerializableException("Unexpected throwable: ${t.message} ${t.getStackTraceAsString()}") } finally { objectHistory.clear() } @@ -104,28 +94,24 @@ class DeserializationInput(internal val serializerFactory: SerializerFactory = S * be deserialized and a schema describing the types of the objects. */ @Throws(NotSerializableException::class) - fun deserialize(bytes: SerializedBytes, clazz: Class): T { - return des(bytes, clazz) { bytes, clazz -> - var envelope = getEnvelope(bytes) + fun deserialize(bytes: ByteSequence, clazz: Class): T { + return des { + val envelope = getEnvelope(bytes) clazz.cast(readObjectOrNull(envelope.obj, envelope.schema, clazz)) } } @Throws(NotSerializableException::class) - internal fun deserializeAndReturnEnvelope(bytes: SerializedBytes, clazz: Class): objectAndEnvelope { - return des>(bytes, clazz) { bytes, clazz -> + internal fun deserializeAndReturnEnvelope(bytes: SerializedBytes, clazz: Class): ObjectAndEnvelope { + return des { val envelope = getEnvelope(bytes) // Now pick out the obj and schema from the envelope. - objectAndEnvelope(clazz.cast(readObjectOrNull(envelope.obj, envelope.schema, clazz)), envelope) + ObjectAndEnvelope(clazz.cast(readObjectOrNull(envelope.obj, envelope.schema, clazz)), envelope) } } internal fun readObjectOrNull(obj: Any?, schema: Schema, type: Type): Any? { - if (obj == null) { - return null - } else { - return readObject(obj, schema, type) - } + return if (obj == null) null else readObject(obj, schema, type) } internal fun readObject(obj: Any, schema: Schema, type: Type): Any { @@ -133,7 +119,8 @@ class DeserializationInput(internal val serializerFactory: SerializerFactory = S // Look up serializer in factory by descriptor val serializer = serializerFactory.get(obj.descriptor, schema) if (serializer.type != type && !serializer.type.isSubClassOf(type)) - throw NotSerializableException("Described type with descriptor ${obj.descriptor} was expected to be of type $type") + throw NotSerializableException("Described type with descriptor ${obj.descriptor} was " + + "expected to be of type $type but was ${serializer.type}") return serializer.readObject(obj.described, schema, this) } else if (obj is Binary) { return obj.array diff --git a/core/src/main/kotlin/net/corda/core/serialization/amqp/DeserializedGenericArrayType.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/DeserializedGenericArrayType.kt similarity index 79% rename from core/src/main/kotlin/net/corda/core/serialization/amqp/DeserializedGenericArrayType.kt rename to node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/DeserializedGenericArrayType.kt index 5183c00954..4cfaa17e8e 100644 --- a/core/src/main/kotlin/net/corda/core/serialization/amqp/DeserializedGenericArrayType.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/DeserializedGenericArrayType.kt @@ -1,4 +1,4 @@ -package net.corda.core.serialization.amqp +package net.corda.nodeapi.internal.serialization.amqp import java.lang.reflect.GenericArrayType import java.lang.reflect.Type @@ -13,6 +13,6 @@ class DeserializedGenericArrayType(private val componentType: Type) : GenericArr override fun toString(): String = typeName override fun hashCode(): Int = Objects.hashCode(componentType) override fun equals(other: Any?): Boolean { - return other is GenericArrayType && componentType.equals(other.genericComponentType) + return other is GenericArrayType && (componentType == other.genericComponentType) } -} \ No newline at end of file +} diff --git a/core/src/main/kotlin/net/corda/core/serialization/amqp/DeserializedParameterizedType.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/DeserializedParameterizedType.kt similarity index 99% rename from core/src/main/kotlin/net/corda/core/serialization/amqp/DeserializedParameterizedType.kt rename to node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/DeserializedParameterizedType.kt index 8869d9c758..5da78fabeb 100644 --- a/core/src/main/kotlin/net/corda/core/serialization/amqp/DeserializedParameterizedType.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/DeserializedParameterizedType.kt @@ -1,4 +1,4 @@ -package net.corda.core.serialization.amqp +package net.corda.nodeapi.internal.serialization.amqp import com.google.common.primitives.Primitives import java.io.NotSerializableException diff --git a/core/src/main/kotlin/net/corda/core/serialization/amqp/MapSerializer.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/MapSerializer.kt similarity index 83% rename from core/src/main/kotlin/net/corda/core/serialization/amqp/MapSerializer.kt rename to node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/MapSerializer.kt index 95803f3070..c399243445 100644 --- a/core/src/main/kotlin/net/corda/core/serialization/amqp/MapSerializer.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/MapSerializer.kt @@ -1,6 +1,5 @@ -package net.corda.core.serialization.amqp +package net.corda.nodeapi.internal.serialization.amqp -import net.corda.core.checkNotUnorderedHashMap import org.apache.qpid.proton.codec.Data import java.io.NotSerializableException import java.lang.reflect.ParameterizedType @@ -47,9 +46,9 @@ class MapSerializer(val declaredType: ParameterizedType, factory: SerializerFact // Write map data.putMap() data.enter() - for (entry in obj as Map<*, *>) { - output.writeObjectOrNull(entry.key, data, declaredType.actualTypeArguments[0]) - output.writeObjectOrNull(entry.value, data, declaredType.actualTypeArguments[1]) + for ((key, value) in obj as Map<*, *>) { + output.writeObjectOrNull(key, data, declaredType.actualTypeArguments[0]) + output.writeObjectOrNull(value, data, declaredType.actualTypeArguments[1]) } data.exit() // exit map } @@ -64,4 +63,10 @@ class MapSerializer(val declaredType: ParameterizedType, factory: SerializerFact private fun readEntry(schema: Schema, input: DeserializationInput, entry: Map.Entry) = input.readObjectOrNull(entry.key, schema, declaredType.actualTypeArguments[0]) to input.readObjectOrNull(entry.value, schema, declaredType.actualTypeArguments[1]) +} + +internal fun Class<*>.checkNotUnorderedHashMap() { + if (HashMap::class.java.isAssignableFrom(this) && !LinkedHashMap::class.java.isAssignableFrom(this)) { + throw IllegalArgumentException("Map type $this is unstable under iteration. Suggested fix: use java.util.LinkedHashMap instead.") + } } \ No newline at end of file diff --git a/core/src/main/kotlin/net/corda/core/serialization/amqp/ObjectSerializer.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/ObjectSerializer.kt similarity index 95% rename from core/src/main/kotlin/net/corda/core/serialization/amqp/ObjectSerializer.kt rename to node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/ObjectSerializer.kt index d22c968ef6..c942b2e6de 100644 --- a/core/src/main/kotlin/net/corda/core/serialization/amqp/ObjectSerializer.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/ObjectSerializer.kt @@ -1,6 +1,6 @@ -package net.corda.core.serialization.amqp +package net.corda.nodeapi.internal.serialization.amqp -import net.corda.core.serialization.amqp.SerializerFactory.Companion.nameForType +import net.corda.nodeapi.internal.serialization.amqp.SerializerFactory.Companion.nameForType import org.apache.qpid.proton.amqp.UnsignedInteger import org.apache.qpid.proton.codec.Data import java.io.NotSerializableException diff --git a/core/src/main/kotlin/net/corda/core/serialization/amqp/PropertySerializer.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/PropertySerializer.kt similarity index 75% rename from core/src/main/kotlin/net/corda/core/serialization/amqp/PropertySerializer.kt rename to node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/PropertySerializer.kt index 4020ca5cc5..143cd16c5b 100644 --- a/core/src/main/kotlin/net/corda/core/serialization/amqp/PropertySerializer.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/PropertySerializer.kt @@ -1,4 +1,4 @@ -package net.corda.core.serialization.amqp +package net.corda.nodeapi.internal.serialization.amqp import org.apache.qpid.proton.amqp.Binary import org.apache.qpid.proton.codec.Data @@ -55,8 +55,10 @@ sealed class PropertySerializer(val name: String, val readMethod: Method, val re companion object { fun make(name: String, readMethod: Method, resolvedType: Type, factory: SerializerFactory): PropertySerializer { if (SerializerFactory.isPrimitive(resolvedType)) { - // This is a little inefficient for performance since it does a runtime check of type. We could do build time check with lots of subclasses here. - return AMQPPrimitivePropertySerializer(name, readMethod, resolvedType) + return when(resolvedType) { + Char::class.java, Character::class.java -> AMQPCharPropertySerializer(name, readMethod) + else -> AMQPPrimitivePropertySerializer(name, readMethod, resolvedType) + } } else { return DescribedTypePropertySerializer(name, readMethod, resolvedType) { factory.get(null, resolvedType) } } @@ -86,10 +88,9 @@ sealed class PropertySerializer(val name: String, val readMethod: Method, val re } /** - * A property serializer for an AMQP primitive type (Int, String, etc). + * A property serializer for most AMQP primitive type (Int, String, etc). */ class AMQPPrimitivePropertySerializer(name: String, readMethod: Method, resolvedType: Type) : PropertySerializer(name, readMethod, resolvedType) { - override fun writeClassInfo(output: SerializationOutput) {} override fun readProperty(obj: Any?, schema: Schema, input: DeserializationInput): Any? { @@ -105,5 +106,24 @@ sealed class PropertySerializer(val name: String, val readMethod: Method, val re } } } + + /** + * A property serializer for the AMQP char type, needed as a specialisation as the underlying + * value of the character is stored in numeric UTF-16 form and on deserialisation requires explicit + * casting back to a char otherwise it's treated as an Integer and a TypeMismatch occurs + */ + class AMQPCharPropertySerializer(name: String, readMethod: Method) : + PropertySerializer(name, readMethod, Character::class.java) { + override fun writeClassInfo(output: SerializationOutput) {} + + override fun readProperty(obj: Any?, schema: Schema, input: DeserializationInput): Any? { + return if(obj == null) null else (obj as Short).toChar() + } + + override fun writeProperty(obj: Any?, data: Data, output: SerializationOutput) { + val input = readMethod.invoke(obj) + if (input != null) data.putShort((input as Char).toShort()) else data.putNull() + } + } } diff --git a/core/src/main/kotlin/net/corda/core/serialization/amqp/Schema.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/Schema.kt similarity index 98% rename from core/src/main/kotlin/net/corda/core/serialization/amqp/Schema.kt rename to node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/Schema.kt index 844f7ce51b..a67cb8e400 100644 --- a/core/src/main/kotlin/net/corda/core/serialization/amqp/Schema.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/Schema.kt @@ -1,4 +1,4 @@ -package net.corda.core.serialization.amqp +package net.corda.nodeapi.internal.serialization.amqp import com.google.common.hash.Hasher import com.google.common.hash.Hashing @@ -15,6 +15,9 @@ import java.lang.reflect.Type import java.lang.reflect.TypeVariable import java.util.* +import net.corda.nodeapi.internal.serialization.carpenter.Field as CarpenterField +import net.corda.nodeapi.internal.serialization.carpenter.Schema as CarpenterSchema + // TODO: get an assigned number as per AMQP spec val DESCRIPTOR_TOP_32BITS: Long = 0xc0da0000 @@ -402,4 +405,4 @@ private fun fingerprintForObject(type: Type, contextType: Type?, alreadySeen: Mu } interfacesForSerialization(type).map { fingerprintForType(it, type, alreadySeen, hasher, factory) } return hasher -} \ No newline at end of file +} diff --git a/core/src/main/kotlin/net/corda/core/serialization/amqp/SerializationHelper.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/SerializationHelper.kt similarity index 98% rename from core/src/main/kotlin/net/corda/core/serialization/amqp/SerializationHelper.kt rename to node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/SerializationHelper.kt index c77faa5119..a8c15461ca 100644 --- a/core/src/main/kotlin/net/corda/core/serialization/amqp/SerializationHelper.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/SerializationHelper.kt @@ -1,4 +1,4 @@ -package net.corda.core.serialization.amqp +package net.corda.nodeapi.internal.serialization.amqp import com.google.common.reflect.TypeToken import org.apache.qpid.proton.codec.Data @@ -112,7 +112,7 @@ internal fun interfacesForSerialization(type: Type): List { private fun exploreType(type: Type?, interfaces: MutableSet) { val clazz = type?.asClass() if (clazz != null) { - if (clazz.isInterface) interfaces += type!! + if (clazz.isInterface) interfaces += type for (newInterface in clazz.genericInterfaces) { if (newInterface !in interfaces) { exploreType(resolveTypeVariables(newInterface, type), interfaces) diff --git a/core/src/main/kotlin/net/corda/core/serialization/amqp/SerializationOutput.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/SerializationOutput.kt similarity index 93% rename from core/src/main/kotlin/net/corda/core/serialization/amqp/SerializationOutput.kt rename to node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/SerializationOutput.kt index 3cbfad41ba..90011f149c 100644 --- a/core/src/main/kotlin/net/corda/core/serialization/amqp/SerializationOutput.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/SerializationOutput.kt @@ -1,4 +1,4 @@ -package net.corda.core.serialization.amqp +package net.corda.nodeapi.internal.serialization.amqp import net.corda.core.serialization.SerializedBytes import org.apache.qpid.proton.codec.Data @@ -34,7 +34,7 @@ open class SerializationOutput(internal val serializerFactory: SerializerFactory // Our object writeObject(obj, this) // The schema - putObject(Schema(schemaHistory.toList())) + writeSchema(Schema(schemaHistory.toList()), this) } } val bytes = ByteArray(data.encodedSize().toInt() + 8) @@ -53,6 +53,10 @@ open class SerializationOutput(internal val serializerFactory: SerializerFactory writeObject(obj, data, obj.javaClass) } + open fun writeSchema(schema: Schema, data: Data) { + data.putObject(schema) + } + internal fun writeObjectOrNull(obj: Any?, data: Data, type: Type) { if (obj == null) { data.putNull() diff --git a/core/src/main/kotlin/net/corda/core/serialization/amqp/SerializerFactory.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/SerializerFactory.kt similarity index 71% rename from core/src/main/kotlin/net/corda/core/serialization/amqp/SerializerFactory.kt rename to node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/SerializerFactory.kt index a4f887be8b..9b20b08a5c 100644 --- a/core/src/main/kotlin/net/corda/core/serialization/amqp/SerializerFactory.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/SerializerFactory.kt @@ -1,11 +1,14 @@ -package net.corda.core.serialization.amqp +package net.corda.nodeapi.internal.serialization.amqp import com.google.common.primitives.Primitives import com.google.common.reflect.TypeResolver -import net.corda.core.checkNotUnorderedHashMap -import net.corda.core.serialization.AllWhitelist import net.corda.core.serialization.ClassWhitelist import net.corda.core.serialization.CordaSerializable +import net.corda.nodeapi.internal.serialization.AllWhitelist +import net.corda.nodeapi.internal.serialization.carpenter.CarpenterSchemas +import net.corda.nodeapi.internal.serialization.carpenter.ClassCarpenter +import net.corda.nodeapi.internal.serialization.carpenter.MetaCarpenter +import net.corda.nodeapi.internal.serialization.carpenter.carpenterSchema import org.apache.qpid.proton.amqp.* import java.io.NotSerializableException import java.lang.reflect.GenericArrayType @@ -45,6 +48,7 @@ class SerializerFactory(val whitelist: ClassWhitelist = AllWhitelist) { private val serializersByType = ConcurrentHashMap>() private val serializersByDescriptor = ConcurrentHashMap>() private val customSerializers = CopyOnWriteArrayList>() + private val classCarpenter = ClassCarpenter() /** * Look up, and manufacture if necessary, a serializer for the given type. @@ -54,25 +58,29 @@ class SerializerFactory(val whitelist: ClassWhitelist = AllWhitelist) { */ @Throws(NotSerializableException::class) fun get(actualClass: Class<*>?, declaredType: Type): AMQPSerializer { - val declaredClass = declaredType.asClass() - if (declaredClass != null) { - val actualType: Type = inferTypeVariables(actualClass, declaredClass, declaredType) ?: declaredType - if (Collection::class.java.isAssignableFrom(declaredClass)) { - return serializersByType.computeIfAbsent(declaredType) { - CollectionSerializer(declaredType as? ParameterizedType ?: DeserializedParameterizedType(declaredClass, arrayOf(AnyType), null), this) - } - } else if (Map::class.java.isAssignableFrom(declaredClass)) { - return serializersByType.computeIfAbsent(declaredClass) { - makeMapSerializer(declaredType as? ParameterizedType ?: DeserializedParameterizedType(declaredClass, arrayOf(AnyType, AnyType), null)) - } - } else { - return makeClassSerializer(actualClass ?: declaredClass, actualType, declaredType) + val declaredClass = declaredType.asClass() ?: throw NotSerializableException( + "Declared types of $declaredType are not supported.") + + val actualType: Type = inferTypeVariables(actualClass, declaredClass, declaredType) ?: declaredType + + val serializer = if (Collection::class.java.isAssignableFrom(declaredClass)) { + serializersByType.computeIfAbsent(declaredType) { + CollectionSerializer(declaredType as? ParameterizedType ?: DeserializedParameterizedType( + declaredClass, arrayOf(AnyType), null), this) + } + } else if (Map::class.java.isAssignableFrom(declaredClass)) { + serializersByType.computeIfAbsent(declaredClass) { + makeMapSerializer(declaredType as? ParameterizedType ?: DeserializedParameterizedType( + declaredClass, arrayOf(AnyType, AnyType), null)) } } else { - throw NotSerializableException("Declared types of $declaredType are not supported.") + makeClassSerializer(actualClass ?: declaredClass, actualType, declaredType) } - } + serializersByDescriptor.putIfAbsent(serializer.typeDescriptor, serializer) + + return serializer + } /** * Try and infer concrete types for any generics type variables for the actual class encountered, based on the declared @@ -168,51 +176,61 @@ class SerializerFactory(val whitelist: ClassWhitelist = AllWhitelist) { } } - private fun processSchema(schema: Schema) { + private fun processSchema(schema: Schema, sentinal: Boolean = false) { + val carpenterSchemas = CarpenterSchemas.newInstance() for (typeNotation in schema.types) { - processSchemaEntry(typeNotation) + try { + processSchemaEntry(typeNotation, classCarpenter.classloader) + } catch (e: ClassNotFoundException) { + if (sentinal || (typeNotation !is CompositeType)) throw e + typeNotation.carpenterSchema( + classLoaders = listOf(classCarpenter.classloader), carpenterSchemas = carpenterSchemas) + } + } + + if (carpenterSchemas.isNotEmpty()) { + val mc = MetaCarpenter(carpenterSchemas, classCarpenter) + mc.build() + processSchema(schema, true) } } - private fun processSchemaEntry(typeNotation: TypeNotation) { + private fun processSchemaEntry(typeNotation: TypeNotation, + cl: ClassLoader = DeserializedParameterizedType::class.java.classLoader) { when (typeNotation) { - is CompositeType -> processCompositeType(typeNotation) // java.lang.Class (whether a class or interface) + is CompositeType -> processCompositeType(typeNotation, cl) // java.lang.Class (whether a class or interface) is RestrictedType -> processRestrictedType(typeNotation) // Collection / Map, possibly with generics } } private fun processRestrictedType(typeNotation: RestrictedType) { - serializersByDescriptor.computeIfAbsent(typeNotation.descriptor.name!!) { - // TODO: class loader logic, and compare the schema. - val type = typeForName(typeNotation.name) - get(null, type) - } + // TODO: class loader logic, and compare the schema. + val type = typeForName(typeNotation.name) + get(null, type) } - private fun processCompositeType(typeNotation: CompositeType) { - serializersByDescriptor.computeIfAbsent(typeNotation.descriptor.name!!) { - // TODO: class loader logic, and compare the schema. - val type = typeForName(typeNotation.name) - get(type.asClass() ?: throw NotSerializableException("Unable to build composite type for $type"), type) - } + private fun processCompositeType(typeNotation: CompositeType, + cl: ClassLoader = DeserializedParameterizedType::class.java.classLoader) { + // TODO: class loader logic, and compare the schema. + val type = typeForName(typeNotation.name, cl) + get(type.asClass() ?: throw NotSerializableException("Unable to build composite type for $type"), type) } - private fun makeClassSerializer(clazz: Class<*>, type: Type, declaredType: Type): AMQPSerializer { - return serializersByType.computeIfAbsent(type) { - if (isPrimitive(clazz)) { - AMQPPrimitiveSerializer(clazz) - } else { - findCustomSerializer(clazz, declaredType) ?: run { - if (type.isArray()) { - whitelisted(type.componentType()) - ArraySerializer(type, this) - } else if (clazz.kotlin.objectInstance != null) { - whitelisted(clazz) - SingletonSerializer(clazz, clazz.kotlin.objectInstance!!, this) - } else { - whitelisted(type) - ObjectSerializer(type, this) - } + private fun makeClassSerializer(clazz: Class<*>, type: Type, declaredType: Type): AMQPSerializer = serializersByType.computeIfAbsent(type) { + if (isPrimitive(clazz)) { + AMQPPrimitiveSerializer(clazz) + } else { + findCustomSerializer(clazz, declaredType) ?: run { + if (type.isArray()) { + whitelisted(type.componentType()) + if (clazz.componentType.isPrimitive) PrimArraySerializer.make(type, this) + else ArraySerializer.make(type, this) + } else if (clazz.kotlin.objectInstance != null) { + whitelisted(clazz) + SingletonSerializer(clazz, clazz.kotlin.objectInstance!!, this) + } else { + whitelisted(type) + ObjectSerializer(type, this) } } } @@ -269,6 +287,8 @@ class SerializerFactory(val whitelist: ClassWhitelist = AllWhitelist) { } private val primitiveTypeNames: Map, String> = mapOf( + Character::class.java to "char", + Char::class.java to "char", Boolean::class.java to "boolean", Byte::class.java to "byte", UnsignedByte::class.java to "ubyte", @@ -283,7 +303,6 @@ class SerializerFactory(val whitelist: ClassWhitelist = AllWhitelist) { Decimal32::class.java to "decimal32", Decimal64::class.java to "decimal62", Decimal128::class.java to "decimal128", - Char::class.java to "char", Date::class.java to "timestamp", UUID::class.java to "uuid", ByteArray::class.java to "binary", @@ -292,17 +311,20 @@ class SerializerFactory(val whitelist: ClassWhitelist = AllWhitelist) { private val namesOfPrimitiveTypes: Map> = primitiveTypeNames.map { it.value to it.key }.toMap() - fun nameForType(type: Type): String { - if (type is Class<*>) { - return primitiveTypeName(type) ?: if (type.isArray) "${nameForType(type.componentType)}[]" else type.name - } else if (type is ParameterizedType) { - return "${nameForType(type.rawType)}<${type.actualTypeArguments.joinToString { nameForType(it) }}>" - } else if (type is GenericArrayType) { - return "${nameForType(type.genericComponentType)}[]" - } else throw NotSerializableException("Unable to render type $type to a string.") + fun nameForType(type: Type): String = when (type) { + is Class<*> -> { + primitiveTypeName(type) ?: if (type.isArray) { + "${nameForType(type.componentType)}${if (type.componentType.isPrimitive) "[p]" else "[]"}" + } else type.name + } + is ParameterizedType -> "${nameForType(type.rawType)}<${type.actualTypeArguments.joinToString { nameForType(it) }}>" + is GenericArrayType -> "${nameForType(type.genericComponentType)}[]" + else -> throw NotSerializableException("Unable to render type $type to a string.") } - private fun typeForName(name: String): Type { + private fun typeForName( + name: String, + cl: ClassLoader = DeserializedParameterizedType::class.java.classLoader): Type { return if (name.endsWith("[]")) { val elementType = typeForName(name.substring(0, name.lastIndex - 1)) if (elementType is ParameterizedType || elementType is GenericArrayType) { @@ -312,8 +334,22 @@ class SerializerFactory(val whitelist: ClassWhitelist = AllWhitelist) { } else { throw NotSerializableException("Not able to deserialize array type: $name") } + } else if (name.endsWith("[p]")) { + // There is no need to handle the ByteArray case as that type is coercible automatically + // to the binary type and is thus handled by the main serializer and doesn't need a + // special case for a primitive array of bytes + when (name) { + "int[p]" -> IntArray::class.java + "char[p]" -> CharArray::class.java + "boolean[p]" -> BooleanArray::class.java + "float[p]" -> FloatArray::class.java + "double[p]" -> DoubleArray::class.java + "short[p]" -> ShortArray::class.java + "long[p]" -> LongArray::class.java + else -> throw NotSerializableException("Not able to deserialize array type: $name") + } } else { - DeserializedParameterizedType.make(name) + DeserializedParameterizedType.make(name, cl) } } } diff --git a/core/src/main/kotlin/net/corda/core/serialization/amqp/SingletonSerializer.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/SingletonSerializer.kt similarity index 96% rename from core/src/main/kotlin/net/corda/core/serialization/amqp/SingletonSerializer.kt rename to node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/SingletonSerializer.kt index ac7fca8d78..22b4a2da6e 100644 --- a/core/src/main/kotlin/net/corda/core/serialization/amqp/SingletonSerializer.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/SingletonSerializer.kt @@ -1,4 +1,4 @@ -package net.corda.core.serialization.amqp +package net.corda.nodeapi.internal.serialization.amqp import org.apache.qpid.proton.codec.Data import java.lang.reflect.Type diff --git a/core/src/main/kotlin/net/corda/core/serialization/amqp/custom/BigDecimalSerializer.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/custom/BigDecimalSerializer.kt similarity index 70% rename from core/src/main/kotlin/net/corda/core/serialization/amqp/custom/BigDecimalSerializer.kt rename to node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/custom/BigDecimalSerializer.kt index 68d02d2350..f1bd3874b5 100644 --- a/core/src/main/kotlin/net/corda/core/serialization/amqp/custom/BigDecimalSerializer.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/custom/BigDecimalSerializer.kt @@ -1,6 +1,6 @@ -package net.corda.core.serialization.amqp.custom +package net.corda.nodeapi.internal.serialization.amqp.custom -import net.corda.core.serialization.amqp.CustomSerializer +import net.corda.nodeapi.internal.serialization.amqp.CustomSerializer import java.math.BigDecimal /** diff --git a/core/src/main/kotlin/net/corda/core/serialization/amqp/custom/CurrencySerializer.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/custom/CurrencySerializer.kt similarity index 71% rename from core/src/main/kotlin/net/corda/core/serialization/amqp/custom/CurrencySerializer.kt rename to node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/custom/CurrencySerializer.kt index cdad5b2242..c0a970b1cb 100644 --- a/core/src/main/kotlin/net/corda/core/serialization/amqp/custom/CurrencySerializer.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/custom/CurrencySerializer.kt @@ -1,6 +1,6 @@ -package net.corda.core.serialization.amqp.custom +package net.corda.nodeapi.internal.serialization.amqp.custom -import net.corda.core.serialization.amqp.CustomSerializer +import net.corda.nodeapi.internal.serialization.amqp.CustomSerializer import java.util.* /** diff --git a/core/src/main/kotlin/net/corda/core/serialization/amqp/custom/InstantSerializer.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/custom/InstantSerializer.kt similarity index 77% rename from core/src/main/kotlin/net/corda/core/serialization/amqp/custom/InstantSerializer.kt rename to node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/custom/InstantSerializer.kt index aa0e32a927..2690d2b6fa 100644 --- a/core/src/main/kotlin/net/corda/core/serialization/amqp/custom/InstantSerializer.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/custom/InstantSerializer.kt @@ -1,7 +1,7 @@ -package net.corda.core.serialization.amqp.custom +package net.corda.nodeapi.internal.serialization.amqp.custom -import net.corda.core.serialization.amqp.CustomSerializer -import net.corda.core.serialization.amqp.SerializerFactory +import net.corda.nodeapi.internal.serialization.amqp.CustomSerializer +import net.corda.nodeapi.internal.serialization.amqp.SerializerFactory import java.time.Instant /** diff --git a/core/src/main/kotlin/net/corda/core/serialization/amqp/custom/PublicKeySerializer.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/custom/PublicKeySerializer.kt similarity index 90% rename from core/src/main/kotlin/net/corda/core/serialization/amqp/custom/PublicKeySerializer.kt rename to node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/custom/PublicKeySerializer.kt index 747940eb4a..768a35766a 100644 --- a/core/src/main/kotlin/net/corda/core/serialization/amqp/custom/PublicKeySerializer.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/custom/PublicKeySerializer.kt @@ -1,7 +1,7 @@ -package net.corda.core.serialization.amqp.custom +package net.corda.nodeapi.internal.serialization.amqp.custom import net.corda.core.crypto.Crypto -import net.corda.core.serialization.amqp.* +import net.corda.nodeapi.internal.serialization.amqp.* import org.apache.qpid.proton.codec.Data import java.lang.reflect.Type import java.security.PublicKey diff --git a/core/src/main/kotlin/net/corda/core/serialization/amqp/custom/ThrowableSerializer.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/custom/ThrowableSerializer.kt similarity index 91% rename from core/src/main/kotlin/net/corda/core/serialization/amqp/custom/ThrowableSerializer.kt rename to node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/custom/ThrowableSerializer.kt index 7196667a41..ff7bf77740 100644 --- a/core/src/main/kotlin/net/corda/core/serialization/amqp/custom/ThrowableSerializer.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/custom/ThrowableSerializer.kt @@ -1,11 +1,11 @@ -package net.corda.core.serialization.amqp.custom +package net.corda.nodeapi.internal.serialization.amqp.custom -import net.corda.core.serialization.amqp.CustomSerializer -import net.corda.core.serialization.amqp.SerializerFactory -import net.corda.core.serialization.amqp.constructorForDeserialization -import net.corda.core.serialization.amqp.propertiesForSerialization import net.corda.core.CordaRuntimeException import net.corda.core.CordaThrowable +import net.corda.nodeapi.internal.serialization.amqp.CustomSerializer +import net.corda.nodeapi.internal.serialization.amqp.SerializerFactory +import net.corda.nodeapi.internal.serialization.amqp.constructorForDeserialization +import net.corda.nodeapi.internal.serialization.amqp.propertiesForSerialization import java.io.NotSerializableException class ThrowableSerializer(factory: SerializerFactory) : CustomSerializer.Proxy(Throwable::class.java, ThrowableProxy::class.java, factory) { diff --git a/core/src/main/kotlin/net/corda/core/serialization/amqp/custom/X500NameSerializer.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/custom/X500NameSerializer.kt similarity index 90% rename from core/src/main/kotlin/net/corda/core/serialization/amqp/custom/X500NameSerializer.kt rename to node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/custom/X500NameSerializer.kt index e45c45b5e9..399d204592 100644 --- a/core/src/main/kotlin/net/corda/core/serialization/amqp/custom/X500NameSerializer.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/amqp/custom/X500NameSerializer.kt @@ -1,6 +1,6 @@ -package net.corda.core.serialization.amqp.custom +package net.corda.nodeapi.internal.serialization.amqp.custom -import net.corda.core.serialization.amqp.* +import net.corda.nodeapi.internal.serialization.amqp.* import org.apache.qpid.proton.codec.Data import org.bouncycastle.asn1.ASN1InputStream import org.bouncycastle.asn1.x500.X500Name diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/carpenter/AMQPSchemaExtensions.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/carpenter/AMQPSchemaExtensions.kt new file mode 100644 index 0000000000..dbe8bb254d --- /dev/null +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/carpenter/AMQPSchemaExtensions.kt @@ -0,0 +1,142 @@ +package net.corda.nodeapi.internal.serialization.carpenter + +import net.corda.nodeapi.internal.serialization.amqp.CompositeType +import net.corda.nodeapi.internal.serialization.amqp.Field as AMQPField +import net.corda.nodeapi.internal.serialization.amqp.Schema as AMQPSchema + +fun AMQPSchema.carpenterSchema( + loaders: List = listOf(ClassLoader.getSystemClassLoader())) + : CarpenterSchemas { + val rtn = CarpenterSchemas.newInstance() + + types.filterIsInstance().forEach { + it.carpenterSchema(classLoaders = loaders, carpenterSchemas = rtn) + } + + return rtn +} + +/** + * if we can load the class then we MUST know about all of it's composite elements + */ +private fun CompositeType.validatePropertyTypes( + classLoaders: List = listOf(ClassLoader.getSystemClassLoader())) { + fields.forEach { + if (!it.validateType(classLoaders)) throw UncarpentableException(name, it.name, it.type) + } +} + +fun AMQPField.typeAsString() = if (type == "*") requires[0] else type + +/** + * based upon this AMQP schema either + * a) add the corresponding carpenter schema to the [carpenterSchemas] param + * b) add the class to the dependency tree in [carpenterSchemas] if it cannot be instantiated + * at this time + * + * @param classLoaders list of classLoaders, defaulting toe the system class loader, that might + * be used to load objects + * @param carpenterSchemas structure that holds the dependency tree and list of classes that + * need constructing + * @param force by default a schema is not added to [carpenterSchemas] if it already exists + * on the class path. For testing purposes schema generation can be forced + */ +fun CompositeType.carpenterSchema( + classLoaders: List = listOf(ClassLoader.getSystemClassLoader()), + carpenterSchemas: CarpenterSchemas, + force: Boolean = false) { + if (classLoaders.exists(name)) { + validatePropertyTypes(classLoaders) + if (!force) return + } + + val providesList = mutableListOf>() + + var isInterface = false + var isCreatable = true + + provides.forEach { + if (name == it) { + isInterface = true + return@forEach + } + + try { + providesList.add(classLoaders.loadIfExists(it)) + } catch (e: ClassNotFoundException) { + carpenterSchemas.addDepPair(this, name, it) + isCreatable = false + } + } + + val m: MutableMap = mutableMapOf() + + fields.forEach { + try { + m[it.name] = FieldFactory.newInstance(it.mandatory, it.name, it.getTypeAsClass(classLoaders)) + } catch (e: ClassNotFoundException) { + carpenterSchemas.addDepPair(this, name, it.typeAsString()) + isCreatable = false + } + } + + if (isCreatable) { + carpenterSchemas.carpenterSchemas.add(CarpenterSchemaFactory.newInstance( + name = name, + fields = m, + interfaces = providesList, + isInterface = isInterface)) + } +} + +// map a pair of (typename, mandatory) to the corresponding class type +// where the mandatory AMQP flag maps to the types nullability +val typeStrToType: Map, Class> = mapOf( + Pair("int", true) to Int::class.javaPrimitiveType!!, + Pair("int", false) to Integer::class.javaObjectType, + Pair("short", true) to Short::class.javaPrimitiveType!!, + Pair("short", false) to Short::class.javaObjectType, + Pair("long", true) to Long::class.javaPrimitiveType!!, + Pair("long", false) to Long::class.javaObjectType, + Pair("char", true) to Char::class.javaPrimitiveType!!, + Pair("char", false) to java.lang.Character::class.java, + Pair("boolean", true) to Boolean::class.javaPrimitiveType!!, + Pair("boolean", false) to Boolean::class.javaObjectType, + Pair("double", true) to Double::class.javaPrimitiveType!!, + Pair("double", false) to Double::class.javaObjectType, + Pair("float", true) to Float::class.javaPrimitiveType!!, + Pair("float", false) to Float::class.javaObjectType, + Pair("byte", true) to Byte::class.javaPrimitiveType!!, + Pair("byte", false) to Byte::class.javaObjectType +) + +fun AMQPField.getTypeAsClass( + classLoaders: List = listOf(ClassLoader.getSystemClassLoader()) +) = typeStrToType[Pair(type, mandatory)] ?: when (type) { + "string" -> String::class.java + "*" -> classLoaders.loadIfExists(requires[0]) + else -> classLoaders.loadIfExists(type) +} + +fun AMQPField.validateType( + classLoaders: List = listOf(ClassLoader.getSystemClassLoader()) +) = when (type) { + "byte", "int", "string", "short", "long", "char", "boolean", "double", "float" -> true + "*" -> classLoaders.exists(requires[0]) + else -> classLoaders.exists(type) +} + +private fun List.exists(clazz: String) = this.find { + try { it.loadClass(clazz); true } catch (e: ClassNotFoundException) { false } +} != null + +private fun List.loadIfExists(clazz: String): Class<*> { + this.forEach { + try { + return it.loadClass(clazz) + } catch (e: ClassNotFoundException) { + return@forEach + } + } + throw ClassNotFoundException(clazz) +} diff --git a/core/src/main/kotlin/net/corda/core/serialization/carpenter/ClassCarpenter.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/carpenter/ClassCarpenter.kt similarity index 69% rename from core/src/main/kotlin/net/corda/core/serialization/carpenter/ClassCarpenter.kt rename to node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/carpenter/ClassCarpenter.kt index 5d774df240..21641a0ae1 100644 --- a/core/src/main/kotlin/net/corda/core/serialization/carpenter/ClassCarpenter.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/carpenter/ClassCarpenter.kt @@ -1,9 +1,8 @@ -package net.corda.core.serialization.carpenter +package net.corda.nodeapi.internal.serialization.carpenter import org.objectweb.asm.ClassWriter import org.objectweb.asm.MethodVisitor import org.objectweb.asm.Opcodes.* -import org.objectweb.asm.Type import java.lang.Character.isJavaIdentifierPart import java.lang.Character.isJavaIdentifierStart @@ -19,6 +18,9 @@ interface SimpleFieldAccess { operator fun get(name: String): Any? } +class CarpenterClassLoader : ClassLoader(Thread.currentThread().contextClassLoader) { + fun load(name: String, bytes: ByteArray) = defineClass(name, bytes, 0, bytes.size) +} /** * A class carpenter generates JVM bytecodes for a class given a schema and then loads it into a sub-classloader. @@ -71,143 +73,10 @@ class ClassCarpenter { // TODO: Support annotations. // TODO: isFoo getter patterns for booleans (this is what Kotlin generates) - class DuplicateNameException : RuntimeException("An attempt was made to register two classes with the same name within the same ClassCarpenter namespace.") - class InterfaceMismatchException(msg: String) : RuntimeException(msg) - class NullablePrimitiveException(msg: String) : RuntimeException(msg) - - abstract class Field(val field: Class) { - companion object { - const val unsetName = "Unset" - } - - var name: String = unsetName - abstract val nullabilityAnnotation: String - - val descriptor: String - get() = Type.getDescriptor(this.field) - - val type: String - get() = if (this.field.isPrimitive) this.descriptor else "Ljava/lang/Object;" - - fun generateField(cw: ClassWriter) { - val fieldVisitor = cw.visitField(ACC_PROTECTED + ACC_FINAL, name, descriptor, null, null) - fieldVisitor.visitAnnotation(nullabilityAnnotation, true).visitEnd() - fieldVisitor.visitEnd() - } - - fun addNullabilityAnnotation(mv: MethodVisitor) { - mv.visitAnnotation(nullabilityAnnotation, true).visitEnd() - } - - fun visitParameter(mv: MethodVisitor, idx: Int) { - with(mv) { - visitParameter(name, 0) - if (!field.isPrimitive) { - visitParameterAnnotation(idx, nullabilityAnnotation, true).visitEnd() - } - } - } - - abstract fun copy(name: String, field: Class): Field - abstract fun nullTest(mv: MethodVisitor, slot: Int) - } - - class NonNullableField(field: Class) : Field(field) { - override val nullabilityAnnotation = "Ljavax/annotation/Nonnull;" - - constructor(name: String, field: Class) : this(field) { - this.name = name - } - - override fun copy(name: String, field: Class) = NonNullableField(name, field) - - override fun nullTest(mv: MethodVisitor, slot: Int) { - assert(name != unsetName) - - if (!field.isPrimitive) { - with(mv) { - visitVarInsn(ALOAD, 0) // load this - visitVarInsn(ALOAD, slot) // load parameter - visitLdcInsn("param \"$name\" cannot be null") - visitMethodInsn(INVOKESTATIC, - "java/util/Objects", - "requireNonNull", - "(Ljava/lang/Object;Ljava/lang/String;)Ljava/lang/Object;", false) - visitInsn(POP) - } - } - } - } - - - class NullableField(field: Class) : Field(field) { - override val nullabilityAnnotation = "Ljavax/annotation/Nullable;" - - constructor(name: String, field: Class) : this(field) { - if (field.isPrimitive) { - throw NullablePrimitiveException ( - "Field $name is primitive type ${Type.getDescriptor(field)} and thus cannot be nullable") - } - - this.name = name - } - - override fun copy(name: String, field: Class) = NullableField(name, field) - - override fun nullTest(mv: MethodVisitor, slot: Int) { - assert(name != unsetName) - } - } - - /** - * A Schema represents a desired class. - */ - abstract class Schema( - val name: String, - fields: Map, - val superclass: Schema? = null, - val interfaces: List> = emptyList()) - { - private fun Map.descriptors() = - LinkedHashMap(this.mapValues { it.value.descriptor }) - - /* Fix the order up front if the user didn't, inject the name into the field as it's - neater when iterating */ - val fields = LinkedHashMap(fields.mapValues { it.value.copy(it.key, it.value.field) }) - - fun fieldsIncludingSuperclasses(): Map = - (superclass?.fieldsIncludingSuperclasses() ?: emptyMap()) + LinkedHashMap(fields) - - fun descriptorsIncludingSuperclasses(): Map = - (superclass?.descriptorsIncludingSuperclasses() ?: emptyMap()) + fields.descriptors() - - val jvmName: String - get() = name.replace(".", "/") - } - - private val String.jvm: String get() = replace(".", "/") - - class ClassSchema( - name: String, - fields: Map, - superclass: Schema? = null, - interfaces: List> = emptyList() - ) : Schema(name, fields, superclass, interfaces) - - class InterfaceSchema( - name: String, - fields: Map, - superclass: Schema? = null, - interfaces: List> = emptyList() - ) : Schema(name, fields, superclass, interfaces) - - private class CarpenterClassLoader : ClassLoader(Thread.currentThread().contextClassLoader) { - fun load(name: String, bytes: ByteArray) = defineClass(name, bytes, 0, bytes.size) - } - - private val classloader = CarpenterClassLoader() + val classloader = CarpenterClassLoader() private val _loaded = HashMap>() + private val String.jvm: String get() = replace(".", "/") /** Returns a snapshot of the currently loaded classes as a map of full class name (package names+dots) -> class object */ val loaded: Map> = HashMap(_loaded) @@ -216,7 +85,8 @@ class ClassCarpenter { * Generate bytecode for the given schema and load into the JVM. The returned class object can be used to * construct instances of the generated class. * - * @throws DuplicateName if the schema's name is already taken in this namespace (you can create a new ClassCarpenter if you're OK with ambiguous names) + * @throws DuplicateNameException if the schema's name is already taken in this namespace (you can create a + * new ClassCarpenter if you're OK with ambiguous names) */ fun build(schema: Schema): Class<*> { validateSchema(schema) @@ -237,6 +107,8 @@ class ClassCarpenter { } } + assert (schema.name in _loaded) + return _loaded[schema.name]!! } @@ -257,10 +129,12 @@ class ClassCarpenter { private fun generateClass(classSchema: Schema): Class<*> { return generate(classSchema) { cw, schema -> val superName = schema.superclass?.jvmName ?: "java/lang/Object" - val interfaces = arrayOf(SimpleFieldAccess::class.java.name.jvm) + schema.interfaces.map { it.name.jvm } + val interfaces = schema.interfaces.map { it.name.jvm }.toMutableList() + + if (SimpleFieldAccess::class.java !in schema.interfaces) interfaces.add(SimpleFieldAccess::class.java.name.jvm) with(cw) { - visit(V1_8, ACC_PUBLIC + ACC_SUPER, schema.jvmName, null, superName, interfaces) + visit(V1_8, ACC_PUBLIC + ACC_SUPER, schema.jvmName, null, superName, interfaces.toTypedArray()) generateFields(schema) generateConstructor(schema) @@ -291,7 +165,7 @@ class ClassCarpenter { private fun ClassWriter.generateToString(schema: Schema) { val toStringHelper = "com/google/common/base/MoreObjects\$ToStringHelper" - with(visitMethod(ACC_PUBLIC, "toString", "()Ljava/lang/String;", "", null)) { + with(visitMethod(ACC_PUBLIC, "toString", "()Ljava/lang/String;", null, null)) { visitCode() // com.google.common.base.MoreObjects.toStringHelper("TypeName") visitLdcInsn(schema.name.split('.').last()) @@ -374,14 +248,14 @@ class ClassCarpenter { // Calculate the super call. val superclassFields = schema.superclass?.fieldsIncludingSuperclasses() ?: emptyMap() visitVarInsn(ALOAD, 0) - if (schema.superclass == null) { + val sc = schema.superclass + if (sc == null) { visitMethodInsn(INVOKESPECIAL, "java/lang/Object", "", "()V", false) } else { var slot = 1 - for (fieldType in superclassFields.values) - slot += load(slot, fieldType) - val superDesc = schema.superclass.descriptorsIncludingSuperclasses().values.joinToString("") - visitMethodInsn(INVOKESPECIAL, schema.superclass.name.jvm, "", "($superDesc)V", false) + superclassFields.values.forEach { slot += load(slot, it) } + val superDesc = sc.descriptorsIncludingSuperclasses().values.joinToString("") + visitMethodInsn(INVOKESPECIAL, sc.name.jvm, "", "($superDesc)V", false) } // Assign the fields from parameters. @@ -429,13 +303,18 @@ class ClassCarpenter { it.name.startsWith("get") -> it.name.substring(3).decapitalize() else -> throw InterfaceMismatchException( "Requested interfaces must consist only of methods that start " - + "with 'get': ${itf.name}.${it.name}") + + "with 'get': ${itf.name}.${it.name}") } + // If we're trying to carpent a class that prior to serialisation / deserialisation + // was made by a carpenter then we can ignore this (it will implement a plain get + // method from SimpleFieldAccess). + if (fieldNameFromItf.isEmpty() && SimpleFieldAccess::class.java in schema.interfaces) return@forEach + if ((schema is ClassSchema) and (fieldNameFromItf !in allFields)) throw InterfaceMismatchException( "Interface ${itf.name} requires a field named $fieldNameFromItf but that " - + "isn't found in the schema or any superclass schemas") + + "isn't found in the schema or any superclass schemas") } } } diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/carpenter/Exceptions.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/carpenter/Exceptions.kt new file mode 100644 index 0000000000..cfa2f2a4e8 --- /dev/null +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/carpenter/Exceptions.kt @@ -0,0 +1,11 @@ +package net.corda.nodeapi.internal.serialization.carpenter + +class DuplicateNameException : RuntimeException ( + "An attempt was made to register two classes with the same name within the same ClassCarpenter namespace.") + +class InterfaceMismatchException(msg: String) : RuntimeException(msg) + +class NullablePrimitiveException(msg: String) : RuntimeException(msg) + +class UncarpentableException (name: String, field: String, type: String) : + Exception ("Class $name is loadable yet contains field $field of unknown type $type") diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/carpenter/MetaCarpenter.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/carpenter/MetaCarpenter.kt new file mode 100644 index 0000000000..77e866fb32 --- /dev/null +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/carpenter/MetaCarpenter.kt @@ -0,0 +1,107 @@ +package net.corda.nodeapi.internal.serialization.carpenter + +import net.corda.nodeapi.internal.serialization.amqp.CompositeType +import net.corda.nodeapi.internal.serialization.amqp.TypeNotation + +/** + * Generated from an AMQP schema this class represents the classes unknown to the deserialiser and that thusly + * require carpenting up in bytecode form. This is a multi step process as carpenting one object may be depedent + * upon the creation of others, this information is tracked in the dependency tree represented by + * [dependencies] and [dependsOn]. Creatable classes are stored in [carpenterSchemas]. + * + * The state of this class after initial generation is expected to mutate as classes are built by the carpenter + * enablaing the resolution of dependencies and thus new carpenter schemas added whilst those already + * carpented schemas are removed. + * + * @property carpenterSchemas The list of carpentable classes + * @property dependencies Maps a class to a list of classes that depend on it being built first + * @property dependsOn Maps a class to a list of classes it depends on being built before it + * + * Once a class is constructed we can quickly check for resolution by first looking at all of its dependents in the + * [dependencies] map. This will give us a list of classes that depended on that class being carpented. We can then + * in turn look up all of those classes in the [dependsOn] list, remove their dependency on the newly created class, + * and if that list is reduced to zero know we can now generate a [Schema] for them and carpent them up + */ +data class CarpenterSchemas ( + val carpenterSchemas: MutableList, + val dependencies: MutableMap>>, + val dependsOn: MutableMap>) { + companion object CarpenterSchemaConstructor { + fun newInstance(): CarpenterSchemas { + return CarpenterSchemas( + mutableListOf(), + mutableMapOf>>(), + mutableMapOf>()) + } + } + + fun addDepPair(type: TypeNotation, dependant: String, dependee: String) { + dependsOn.computeIfAbsent(dependee, { mutableListOf() }).add(dependant) + dependencies.computeIfAbsent(dependant, { Pair(type, mutableListOf()) }).second.add(dependee) + } + + val size + get() = carpenterSchemas.size + + fun isEmpty() = carpenterSchemas.isEmpty() + fun isNotEmpty() = carpenterSchemas.isNotEmpty() +} + +/** + * Take a dependency tree of [CarpenterSchemas] and reduce it to zero by carpenting those classes that + * require it. As classes are carpented check for depdency resolution, if now free generate a [Schema] for + * that class and add it to the list of classes ([CarpenterSchemas.carpenterSchemas]) that require + * carpenting + * + * @property cc a reference to the actual class carpenter we're using to constuct classes + * @property objects a list of carpented classes loaded into the carpenters class loader + */ +abstract class MetaCarpenterBase (val schemas : CarpenterSchemas, val cc : ClassCarpenter = ClassCarpenter()) { + val objects = mutableMapOf>() + + fun step(newObject: Schema) { + objects[newObject.name] = cc.build (newObject) + + // go over the list of everything that had a dependency on the newly + // carpented class existing and remove it from their dependency list, If that + // list is now empty we have no impediment to carpenting that class up + schemas.dependsOn.remove(newObject.name)?.forEach { dependent -> + assert (newObject.name in schemas.dependencies[dependent]!!.second) + + schemas.dependencies[dependent]?.second?.remove(newObject.name) + + // we're out of blockers so we can now create the type + if (schemas.dependencies[dependent]?.second?.isEmpty() ?: false) { + (schemas.dependencies.remove (dependent)?.first as CompositeType).carpenterSchema ( + classLoaders = listOf ( + ClassLoader.getSystemClassLoader(), + cc.classloader), + carpenterSchemas = schemas) + } + } + } + + abstract fun build() + + val classloader : ClassLoader + get() = cc.classloader +} + +class MetaCarpenter(schemas : CarpenterSchemas, + cc : ClassCarpenter = ClassCarpenter()) : MetaCarpenterBase(schemas, cc) { + override fun build() { + while (schemas.carpenterSchemas.isNotEmpty()) { + val newObject = schemas.carpenterSchemas.removeAt(0) + step (newObject) + } + } +} + +class TestMetaCarpenter(schemas : CarpenterSchemas, + cc : ClassCarpenter = ClassCarpenter()) : MetaCarpenterBase(schemas, cc) { + override fun build() { + if (schemas.carpenterSchemas.isEmpty()) return + step (schemas.carpenterSchemas.removeAt(0)) + } +} + diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/carpenter/Schema.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/carpenter/Schema.kt new file mode 100644 index 0000000000..a0753bbfe4 --- /dev/null +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/carpenter/Schema.kt @@ -0,0 +1,148 @@ +package net.corda.nodeapi.internal.serialization.carpenter + +import jdk.internal.org.objectweb.asm.Opcodes.* +import org.objectweb.asm.ClassWriter +import org.objectweb.asm.MethodVisitor +import org.objectweb.asm.Type +import java.util.* + +/** + * A Schema represents a desired class. + */ +abstract class Schema( + val name: String, + fields: Map, + val superclass: Schema? = null, + val interfaces: List> = emptyList()) +{ + private fun Map.descriptors() = + LinkedHashMap(this.mapValues { it.value.descriptor }) + + /* Fix the order up front if the user didn't, inject the name into the field as it's + neater when iterating */ + val fields = LinkedHashMap(fields.mapValues { it.value.copy(it.key, it.value.field) }) + + fun fieldsIncludingSuperclasses(): Map = + (superclass?.fieldsIncludingSuperclasses() ?: emptyMap()) + LinkedHashMap(fields) + + fun descriptorsIncludingSuperclasses(): Map = + (superclass?.descriptorsIncludingSuperclasses() ?: emptyMap()) + fields.descriptors() + + val jvmName: String + get() = name.replace(".", "/") +} + +class ClassSchema( + name: String, + fields: Map, + superclass: Schema? = null, + interfaces: List> = emptyList() +) : Schema(name, fields, superclass, interfaces) + +class InterfaceSchema( + name: String, + fields: Map, + superclass: Schema? = null, + interfaces: List> = emptyList() +) : Schema(name, fields, superclass, interfaces) + +object CarpenterSchemaFactory { + fun newInstance ( + name: String, + fields: Map, + superclass: Schema? = null, + interfaces: List> = emptyList(), + isInterface: Boolean = false + ) : Schema = + if (isInterface) InterfaceSchema (name, fields, superclass, interfaces) + else ClassSchema (name, fields, superclass, interfaces) +} + +abstract class Field(val field: Class) { + companion object { + const val unsetName = "Unset" + } + + var name: String = unsetName + abstract val nullabilityAnnotation: String + + val descriptor: String + get() = Type.getDescriptor(this.field) + + val type: String + get() = if (this.field.isPrimitive) this.descriptor else "Ljava/lang/Object;" + + fun generateField(cw: ClassWriter) { + val fieldVisitor = cw.visitField(ACC_PROTECTED + ACC_FINAL, name, descriptor, null, null) + fieldVisitor.visitAnnotation(nullabilityAnnotation, true).visitEnd() + fieldVisitor.visitEnd() + } + + fun addNullabilityAnnotation(mv: MethodVisitor) { + mv.visitAnnotation(nullabilityAnnotation, true).visitEnd() + } + + fun visitParameter(mv: MethodVisitor, idx: Int) { + with(mv) { + visitParameter(name, 0) + if (!field.isPrimitive) { + visitParameterAnnotation(idx, nullabilityAnnotation, true).visitEnd() + } + } + } + + abstract fun copy(name: String, field: Class): Field + abstract fun nullTest(mv: MethodVisitor, slot: Int) +} + +class NonNullableField(field: Class) : Field(field) { + override val nullabilityAnnotation = "Ljavax/annotation/Nonnull;" + + constructor(name: String, field: Class) : this(field) { + this.name = name + } + + override fun copy(name: String, field: Class) = NonNullableField(name, field) + + override fun nullTest(mv: MethodVisitor, slot: Int) { + assert(name != unsetName) + + if (!field.isPrimitive) { + with(mv) { + visitVarInsn(ALOAD, 0) // load this + visitVarInsn(ALOAD, slot) // load parameter + visitLdcInsn("param \"$name\" cannot be null") + visitMethodInsn(INVOKESTATIC, + "java/util/Objects", + "requireNonNull", + "(Ljava/lang/Object;Ljava/lang/String;)Ljava/lang/Object;", false) + visitInsn(POP) + } + } + } +} + +class NullableField(field: Class) : Field(field) { + override val nullabilityAnnotation = "Ljavax/annotation/Nullable;" + + constructor(name: String, field: Class) : this(field) { + if (field.isPrimitive) { + throw NullablePrimitiveException ( + "Field $name is primitive type ${Type.getDescriptor(field)} and thus cannot be nullable") + } + + this.name = name + } + + override fun copy(name: String, field: Class) = NullableField(name, field) + + override fun nullTest(mv: MethodVisitor, slot: Int) { + assert(name != unsetName) + } +} + +object FieldFactory { + fun newInstance (mandatory: Boolean, name: String, field: Class) = + if (mandatory) NonNullableField (name, field) else NullableField (name, field) + +} diff --git a/node-api/src/main/resources/META-INF/services/net.corda.core.node.CordaPluginRegistry b/node-api/src/main/resources/META-INF/services/net.corda.core.node.CordaPluginRegistry index afc9681edf..6a4ca5dcd5 100644 --- a/node-api/src/main/resources/META-INF/services/net.corda.core.node.CordaPluginRegistry +++ b/node-api/src/main/resources/META-INF/services/net.corda.core.node.CordaPluginRegistry @@ -1,2 +1,2 @@ # Register a ServiceLoader service extending from net.corda.core.node.CordaPluginRegistry -net.corda.nodeapi.serialization.DefaultWhitelist \ No newline at end of file +net.corda.nodeapi.internal.serialization.DefaultWhitelist \ No newline at end of file diff --git a/core/src/test/kotlin/net/corda/core/node/AttachmentClassLoaderTests.kt b/node-api/src/test/kotlin/net/corda/nodeapi/AttachmentClassLoaderTests.kt similarity index 80% rename from core/src/test/kotlin/net/corda/core/node/AttachmentClassLoaderTests.kt rename to node-api/src/test/kotlin/net/corda/nodeapi/AttachmentClassLoaderTests.kt index e2af5fe66b..284127a29a 100644 --- a/core/src/test/kotlin/net/corda/core/node/AttachmentClassLoaderTests.kt +++ b/node-api/src/test/kotlin/net/corda/nodeapi/AttachmentClassLoaderTests.kt @@ -1,21 +1,26 @@ -package net.corda.core.node +package net.corda.nodeapi -import com.esotericsoftware.kryo.Kryo import com.nhaarman.mockito_kotlin.mock import com.nhaarman.mockito_kotlin.whenever import net.corda.core.contracts.* import net.corda.core.crypto.SecureHash import net.corda.core.identity.AbstractParty import net.corda.core.identity.Party +import net.corda.core.node.ServiceHub import net.corda.core.node.services.AttachmentStorage import net.corda.core.serialization.* +import net.corda.core.serialization.SerializationDefaults.P2P_CONTEXT +import net.corda.core.transactions.LedgerTransaction import net.corda.core.transactions.TransactionBuilder +import net.corda.nodeapi.internal.serialization.SerializeAsTokenContextImpl +import net.corda.nodeapi.internal.serialization.WireTransactionSerializer +import net.corda.nodeapi.internal.serialization.withTokenContext import net.corda.testing.DUMMY_NOTARY import net.corda.testing.MEGA_CORP +import net.corda.testing.TestDependencyInjectionBase import net.corda.testing.node.MockAttachmentStorage import org.apache.commons.io.IOUtils import org.junit.Assert -import org.junit.Before import org.junit.Test import java.io.ByteArrayInputStream import java.io.ByteArrayOutputStream @@ -35,15 +40,14 @@ interface DummyContractBackdoor { val ATTACHMENT_TEST_PROGRAM_ID = AttachmentClassLoaderTests.AttachmentDummyContract() -class AttachmentClassLoaderTests { +class AttachmentClassLoaderTests : TestDependencyInjectionBase() { companion object { val ISOLATED_CONTRACTS_JAR_PATH: URL = AttachmentClassLoaderTests::class.java.getResource("isolated.jar") - private fun Kryo.withAttachmentStorage(attachmentStorage: AttachmentStorage, block: () -> T) = run { - context.put(WireTransactionSerializer.attachmentsClassLoaderEnabled, true) + private fun SerializationContext.withAttachmentStorage(attachmentStorage: AttachmentStorage): SerializationContext { val serviceHub = mock() whenever(serviceHub.attachments).thenReturn(attachmentStorage) - withSerializationContext(SerializeAsTokenContext(serviceHub) {}, block) + return this.withTokenContext(SerializeAsTokenContextImpl(serviceHub) {}).withProperty(WireTransactionSerializer.attachmentsClassLoaderEnabled, true) } } @@ -58,7 +62,7 @@ class AttachmentClassLoaderTests { class Create : TypeOnlyCommandData(), Commands } - override fun verify(tx: TransactionForContract) { + override fun verify(tx: LedgerTransaction) { // Always accepts. } @@ -67,7 +71,7 @@ class AttachmentClassLoaderTests { fun generateInitial(owner: PartyAndReference, magicNumber: Int, notary: Party): TransactionBuilder { val state = State(magicNumber) - return TransactionType.General.Builder(notary = notary).withItems(state, Command(Commands.Create(), owner.party.owningKey)) + return TransactionBuilder(notary).withItems(state, Command(Commands.Create(), owner.party.owningKey)) } } @@ -88,16 +92,6 @@ class AttachmentClassLoaderTests { class ClassLoaderForTests : URLClassLoader(arrayOf(ISOLATED_CONTRACTS_JAR_PATH), FilteringClassLoader) - lateinit var kryo: Kryo - lateinit var kryo2: Kryo - - @Before - fun setup() { - // Do not release these back to the pool, since we do some unorthodox modifications to them below. - kryo = p2PKryo().borrow() - kryo2 = p2PKryo().borrow() - } - @Test fun `dynamically load AnotherDummyContract from isolated contracts jar`() { val child = ClassLoaderForTests() @@ -202,7 +196,7 @@ class AttachmentClassLoaderTests { @Test fun `verify that contract DummyContract is in classPath`() { - val contractClass = Class.forName("net.corda.core.node.AttachmentClassLoaderTests\$AttachmentDummyContract") + val contractClass = Class.forName("net.corda.nodeapi.AttachmentClassLoaderTests\$AttachmentDummyContract") val contract = contractClass.newInstance() as Contract assertNotNull(contract) @@ -228,10 +222,8 @@ class AttachmentClassLoaderTests { val cl = AttachmentsClassLoader(arrayOf(att0, att1, att2).map { storage.openAttachment(it)!! }, FilteringClassLoader) - kryo.classLoader = cl - kryo.addToWhitelist(contract.javaClass) - - val state2 = bytes.deserialize(kryo) + val context = P2P_CONTEXT.withClassLoader(cl).withWhitelisted(contract.javaClass) + val state2 = bytes.deserialize(context = context) assertTrue(state2.javaClass.classLoader is AttachmentsClassLoader) assertNotNull(state2) } @@ -246,8 +238,9 @@ class AttachmentClassLoaderTests { assertNotNull(data.contract) - kryo2.addToWhitelist(data.contract.javaClass) - val bytes = data.serialize(kryo2) + val context2 = P2P_CONTEXT.withWhitelisted(data.contract.javaClass) + + val bytes = data.serialize(context = context2) val storage = MockAttachmentStorage() @@ -257,20 +250,18 @@ class AttachmentClassLoaderTests { val cl = AttachmentsClassLoader(arrayOf(att0, att1, att2).map { storage.openAttachment(it)!! }, FilteringClassLoader) - kryo.classLoader = cl - kryo.addToWhitelist(Class.forName("net.corda.contracts.isolated.AnotherDummyContract", true, cl)) + val context = P2P_CONTEXT.withClassLoader(cl).withWhitelisted(Class.forName("net.corda.contracts.isolated.AnotherDummyContract", true, cl)) - val state2 = bytes.deserialize(kryo) + val state2 = bytes.deserialize(context = context) assertEquals(cl, state2.contract.javaClass.classLoader) assertNotNull(state2) // We should be able to load same class from a different class loader and have them be distinct. val cl2 = AttachmentsClassLoader(arrayOf(att0, att1, att2).map { storage.openAttachment(it)!! }, FilteringClassLoader) - kryo.classLoader = cl2 - kryo.addToWhitelist(Class.forName("net.corda.contracts.isolated.AnotherDummyContract", true, cl2)) + val context3 = P2P_CONTEXT.withClassLoader(cl2).withWhitelisted(Class.forName("net.corda.contracts.isolated.AnotherDummyContract", true, cl2)) - val state3 = bytes.deserialize(kryo) + val state3 = bytes.deserialize(context = context3) assertEquals(cl2, state3.contract.javaClass.classLoader) assertNotNull(state3) } @@ -294,30 +285,22 @@ class AttachmentClassLoaderTests { val contract = contractClass.newInstance() as DummyContractBackdoor val tx = contract.generateInitial(MEGA_CORP.ref(0), 42, DUMMY_NOTARY) val storage = MockAttachmentStorage() - kryo.addToWhitelist(contract.javaClass) - kryo.addToWhitelist(Class.forName("net.corda.contracts.isolated.AnotherDummyContract\$State", true, child)) - kryo.addToWhitelist(Class.forName("net.corda.contracts.isolated.AnotherDummyContract\$Commands\$Create", true, child)) + val context = P2P_CONTEXT.withWhitelisted(contract.javaClass) + .withWhitelisted(Class.forName("net.corda.contracts.isolated.AnotherDummyContract\$State", true, child)) + .withWhitelisted(Class.forName("net.corda.contracts.isolated.AnotherDummyContract\$Commands\$Create", true, child)) + .withAttachmentStorage(storage) // todo - think about better way to push attachmentStorage down to serializer - val bytes = kryo.withAttachmentStorage(storage) { - + val bytes = run { val attachmentRef = importJar(storage) - tx.addAttachment(storage.openAttachment(attachmentRef)!!.id) - val wireTransaction = tx.toWireTransaction() - - wireTransaction.serialize(kryo) - } - // use empty attachmentStorage - kryo2.withAttachmentStorage(storage) { - - val copiedWireTransaction = bytes.deserialize(kryo2) - - assertEquals(1, copiedWireTransaction.outputs.size) - val contract2 = copiedWireTransaction.outputs[0].data.contract as DummyContractBackdoor - assertEquals(42, contract2.inspectState(copiedWireTransaction.outputs[0].data)) + wireTransaction.serialize(context = context) } + val copiedWireTransaction = bytes.deserialize(context = context) + assertEquals(1, copiedWireTransaction.outputs.size) + val contract2 = copiedWireTransaction.getOutput(0).contract as DummyContractBackdoor + assertEquals(42, contract2.inspectState(copiedWireTransaction.outputs[0].data)) } @Test @@ -330,21 +313,19 @@ class AttachmentClassLoaderTests { // todo - think about better way to push attachmentStorage down to serializer val attachmentRef = importJar(storage) - val bytes = kryo.withAttachmentStorage(storage) { + val bytes = run { tx.addAttachment(storage.openAttachment(attachmentRef)!!.id) val wireTransaction = tx.toWireTransaction() - wireTransaction.serialize(kryo) + wireTransaction.serialize(context = P2P_CONTEXT.withAttachmentStorage(storage)) } // use empty attachmentStorage - kryo2.withAttachmentStorage(MockAttachmentStorage()) { - val e = assertFailsWith(MissingAttachmentsException::class) { - bytes.deserialize(kryo2) - } - assertEquals(attachmentRef, e.ids.single()) + val e = assertFailsWith(MissingAttachmentsException::class) { + bytes.deserialize(context = P2P_CONTEXT.withAttachmentStorage(MockAttachmentStorage())) } + assertEquals(attachmentRef, e.ids.single()) } } diff --git a/node-api/src/test/kotlin/net/corda/nodeapi/config/ConfigParsingTest.kt b/node-api/src/test/kotlin/net/corda/nodeapi/config/ConfigParsingTest.kt index a3abc82492..1d7b061fad 100644 --- a/node-api/src/test/kotlin/net/corda/nodeapi/config/ConfigParsingTest.kt +++ b/node-api/src/test/kotlin/net/corda/nodeapi/config/ConfigParsingTest.kt @@ -4,7 +4,7 @@ import com.typesafe.config.Config import com.typesafe.config.ConfigFactory.empty import com.typesafe.config.ConfigRenderOptions.defaults import com.typesafe.config.ConfigValueFactory -import net.corda.core.div +import net.corda.core.internal.div import net.corda.core.utilities.NetworkHostAndPort import net.corda.testing.getTestX509Name import org.assertj.core.api.Assertions.assertThat diff --git a/core/src/test/kotlin/net/corda/core/serialization/CordaClassResolverTests.kt b/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/CordaClassResolverTests.kt similarity index 60% rename from core/src/test/kotlin/net/corda/core/serialization/CordaClassResolverTests.kt rename to node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/CordaClassResolverTests.kt index 46d4aa7499..88ac9f4049 100644 --- a/core/src/test/kotlin/net/corda/core/serialization/CordaClassResolverTests.kt +++ b/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/CordaClassResolverTests.kt @@ -1,12 +1,13 @@ -package net.corda.core.serialization +package net.corda.nodeapi.internal.serialization import com.esotericsoftware.kryo.* import com.esotericsoftware.kryo.io.Input import com.esotericsoftware.kryo.io.Output import com.esotericsoftware.kryo.util.MapReferenceResolver -import net.corda.core.node.AttachmentClassLoaderTests -import net.corda.core.node.AttachmentsClassLoader import net.corda.core.node.services.AttachmentStorage +import net.corda.core.serialization.* +import net.corda.core.utilities.ByteSequence +import net.corda.nodeapi.AttachmentClassLoaderTests import net.corda.testing.node.MockAttachmentStorage import org.junit.Rule import org.junit.Test @@ -75,71 +76,84 @@ class DefaultSerializableSerializer : Serializer() { } class CordaClassResolverTests { + val factory: SerializationFactory = object : SerializationFactory { + override fun deserialize(byteSequence: ByteSequence, clazz: Class, context: SerializationContext): T { + TODO("not implemented") //To change body of created functions use File | Settings | File Templates. + } + + override fun serialize(obj: T, context: SerializationContext): SerializedBytes { + TODO("not implemented") //To change body of created functions use File | Settings | File Templates. + } + + } + + val emptyWhitelistContext: SerializationContext = SerializationContextImpl(KryoHeaderV0_1, this.javaClass.classLoader, EmptyWhitelist, emptyMap(), true, SerializationContext.UseCase.P2P) + val allButBlacklistedContext: SerializationContext = SerializationContextImpl(KryoHeaderV0_1, this.javaClass.classLoader, AllButBlacklisted, emptyMap(), true, SerializationContext.UseCase.P2P) + @Test fun `Annotation on enum works for specialised entries`() { // TODO: Remove this suppress when we upgrade to kotlin 1.1 or when JetBrain fixes the bug. @Suppress("UNSUPPORTED_FEATURE") - CordaClassResolver(EmptyWhitelist).getRegistration(Foo.Bar::class.java) + CordaClassResolver(factory, emptyWhitelistContext).getRegistration(Foo.Bar::class.java) } @Test fun `Annotation on array element works`() { val values = arrayOf(Element()) - CordaClassResolver(EmptyWhitelist).getRegistration(values.javaClass) + CordaClassResolver(factory, emptyWhitelistContext).getRegistration(values.javaClass) } @Test fun `Annotation not needed on abstract class`() { - CordaClassResolver(EmptyWhitelist).getRegistration(AbstractClass::class.java) + CordaClassResolver(factory, emptyWhitelistContext).getRegistration(AbstractClass::class.java) } @Test fun `Annotation not needed on interface`() { - CordaClassResolver(EmptyWhitelist).getRegistration(Interface::class.java) + CordaClassResolver(factory, emptyWhitelistContext).getRegistration(Interface::class.java) } @Test fun `Calling register method on modified Kryo does not consult the whitelist`() { - val kryo = CordaKryo(CordaClassResolver(EmptyWhitelist)) + val kryo = CordaKryo(CordaClassResolver(factory, emptyWhitelistContext)) kryo.register(NotSerializable::class.java) } @Test(expected = KryoException::class) fun `Calling register method on unmodified Kryo does consult the whitelist`() { - val kryo = Kryo(CordaClassResolver(EmptyWhitelist), MapReferenceResolver()) + val kryo = Kryo(CordaClassResolver(factory, emptyWhitelistContext), MapReferenceResolver()) kryo.register(NotSerializable::class.java) } @Test(expected = KryoException::class) fun `Annotation is needed without whitelisting`() { - CordaClassResolver(EmptyWhitelist).getRegistration(NotSerializable::class.java) + CordaClassResolver(factory, emptyWhitelistContext).getRegistration(NotSerializable::class.java) } @Test fun `Annotation is not needed with whitelisting`() { - val resolver = CordaClassResolver(GlobalTransientClassWhiteList(EmptyWhitelist)) - (resolver.whitelist as MutableClassWhitelist).add(NotSerializable::class.java) + val resolver = CordaClassResolver(factory, emptyWhitelistContext.withWhitelisted(NotSerializable::class.java)) resolver.getRegistration(NotSerializable::class.java) } @Test fun `Annotation not needed on Object`() { - CordaClassResolver(EmptyWhitelist).getRegistration(Object::class.java) + CordaClassResolver(factory, emptyWhitelistContext).getRegistration(Object::class.java) } @Test fun `Annotation not needed on primitive`() { - CordaClassResolver(EmptyWhitelist).getRegistration(Integer.TYPE) + CordaClassResolver(factory, emptyWhitelistContext).getRegistration(Integer.TYPE) } @Test(expected = KryoException::class) fun `Annotation does not work for custom serializable`() { - CordaClassResolver(EmptyWhitelist).getRegistration(CustomSerializable::class.java) + CordaClassResolver(factory, emptyWhitelistContext).getRegistration(CustomSerializable::class.java) } @Test(expected = KryoException::class) fun `Annotation does not work in conjunction with Kryo annotation`() { - CordaClassResolver(EmptyWhitelist).getRegistration(DefaultSerializable::class.java) + CordaClassResolver(factory, emptyWhitelistContext).getRegistration(DefaultSerializable::class.java) } private fun importJar(storage: AttachmentStorage) = AttachmentClassLoaderTests.ISOLATED_CONTRACTS_JAR_PATH.openStream().use { storage.importAttachment(it) } @@ -150,20 +164,20 @@ class CordaClassResolverTests { val attachmentHash = importJar(storage) val classLoader = AttachmentsClassLoader(arrayOf(attachmentHash).map { storage.openAttachment(it)!! }) val attachedClass = Class.forName("net.corda.contracts.isolated.AnotherDummyContract", true, classLoader) - CordaClassResolver(EmptyWhitelist).getRegistration(attachedClass) + CordaClassResolver(factory, emptyWhitelistContext).getRegistration(attachedClass) } @Test fun `Annotation is inherited from interfaces`() { - CordaClassResolver(EmptyWhitelist).getRegistration(SerializableViaInterface::class.java) - CordaClassResolver(EmptyWhitelist).getRegistration(SerializableViaSubInterface::class.java) + CordaClassResolver(factory, emptyWhitelistContext).getRegistration(SerializableViaInterface::class.java) + CordaClassResolver(factory, emptyWhitelistContext).getRegistration(SerializableViaSubInterface::class.java) } @Test fun `Annotation is inherited from superclass`() { - CordaClassResolver(EmptyWhitelist).getRegistration(SubElement::class.java) - CordaClassResolver(EmptyWhitelist).getRegistration(SubSubElement::class.java) - CordaClassResolver(EmptyWhitelist).getRegistration(SerializableViaSuperSubInterface::class.java) + CordaClassResolver(factory, emptyWhitelistContext).getRegistration(SubElement::class.java) + CordaClassResolver(factory, emptyWhitelistContext).getRegistration(SubSubElement::class.java) + CordaClassResolver(factory, emptyWhitelistContext).getRegistration(SerializableViaSuperSubInterface::class.java) } // Blacklist tests. @@ -174,7 +188,7 @@ class CordaClassResolverTests { fun `Check blacklisted class`() { expectedEx.expect(IllegalStateException::class.java) expectedEx.expectMessage("Class java.util.HashSet is blacklisted, so it cannot be used in serialization.") - val resolver = CordaClassResolver(AllButBlacklisted) + val resolver = CordaClassResolver(factory, allButBlacklistedContext) // HashSet is blacklisted. resolver.getRegistration(HashSet::class.java) } @@ -183,8 +197,8 @@ class CordaClassResolverTests { @Test fun `Check blacklisted subclass`() { expectedEx.expect(IllegalStateException::class.java) - expectedEx.expectMessage("The superclass java.util.HashSet of net.corda.core.serialization.CordaClassResolverTests\$SubHashSet is blacklisted, so it cannot be used in serialization.") - val resolver = CordaClassResolver(AllButBlacklisted) + expectedEx.expectMessage("The superclass java.util.HashSet of net.corda.nodeapi.internal.serialization.CordaClassResolverTests\$SubHashSet is blacklisted, so it cannot be used in serialization.") + val resolver = CordaClassResolver(factory, allButBlacklistedContext) // SubHashSet extends the blacklisted HashSet. resolver.getRegistration(SubHashSet::class.java) } @@ -193,8 +207,8 @@ class CordaClassResolverTests { @Test fun `Check blacklisted subsubclass`() { expectedEx.expect(IllegalStateException::class.java) - expectedEx.expectMessage("The superclass java.util.HashSet of net.corda.core.serialization.CordaClassResolverTests\$SubSubHashSet is blacklisted, so it cannot be used in serialization.") - val resolver = CordaClassResolver(AllButBlacklisted) + expectedEx.expectMessage("The superclass java.util.HashSet of net.corda.nodeapi.internal.serialization.CordaClassResolverTests\$SubSubHashSet is blacklisted, so it cannot be used in serialization.") + val resolver = CordaClassResolver(factory, allButBlacklistedContext) // SubSubHashSet extends SubHashSet, which extends the blacklisted HashSet. resolver.getRegistration(SubSubHashSet::class.java) } @@ -203,8 +217,8 @@ class CordaClassResolverTests { @Test fun `Check blacklisted interface impl`() { expectedEx.expect(IllegalStateException::class.java) - expectedEx.expectMessage("The superinterface java.sql.Connection of net.corda.core.serialization.CordaClassResolverTests\$ConnectionImpl is blacklisted, so it cannot be used in serialization.") - val resolver = CordaClassResolver(AllButBlacklisted) + expectedEx.expectMessage("The superinterface java.sql.Connection of net.corda.nodeapi.internal.serialization.CordaClassResolverTests\$ConnectionImpl is blacklisted, so it cannot be used in serialization.") + val resolver = CordaClassResolver(factory, allButBlacklistedContext) // ConnectionImpl implements blacklisted Connection. resolver.getRegistration(ConnectionImpl::class.java) } @@ -214,15 +228,15 @@ class CordaClassResolverTests { @Test fun `Check blacklisted super-interface impl`() { expectedEx.expect(IllegalStateException::class.java) - expectedEx.expectMessage("The superinterface java.sql.Connection of net.corda.core.serialization.CordaClassResolverTests\$SubConnectionImpl is blacklisted, so it cannot be used in serialization.") - val resolver = CordaClassResolver(AllButBlacklisted) + expectedEx.expectMessage("The superinterface java.sql.Connection of net.corda.nodeapi.internal.serialization.CordaClassResolverTests\$SubConnectionImpl is blacklisted, so it cannot be used in serialization.") + val resolver = CordaClassResolver(factory, allButBlacklistedContext) // SubConnectionImpl implements SubConnection, which extends the blacklisted Connection. resolver.getRegistration(SubConnectionImpl::class.java) } @Test fun `Check forcibly allowed`() { - val resolver = CordaClassResolver(AllButBlacklisted) + val resolver = CordaClassResolver(factory, allButBlacklistedContext) // LinkedHashSet is allowed for serialization. resolver.getRegistration(LinkedHashSet::class.java) } @@ -232,8 +246,8 @@ class CordaClassResolverTests { @Test fun `Check blacklist precedes CordaSerializable`() { expectedEx.expect(IllegalStateException::class.java) - expectedEx.expectMessage("The superclass java.util.HashSet of net.corda.core.serialization.CordaClassResolverTests\$CordaSerializableHashSet is blacklisted, so it cannot be used in serialization.") - val resolver = CordaClassResolver(AllButBlacklisted) + expectedEx.expectMessage("The superclass java.util.HashSet of net.corda.nodeapi.internal.serialization.CordaClassResolverTests\$CordaSerializableHashSet is blacklisted, so it cannot be used in serialization.") + val resolver = CordaClassResolver(factory, allButBlacklistedContext) // CordaSerializableHashSet is @CordaSerializable, but extends the blacklisted HashSet. resolver.getRegistration(CordaSerializableHashSet::class.java) } diff --git a/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/KryoTests.kt b/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/KryoTests.kt new file mode 100644 index 0000000000..66918ac812 --- /dev/null +++ b/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/KryoTests.kt @@ -0,0 +1,212 @@ +package net.corda.nodeapi.internal.serialization + +import com.esotericsoftware.kryo.Kryo +import com.esotericsoftware.kryo.KryoSerializable +import com.esotericsoftware.kryo.io.Input +import com.esotericsoftware.kryo.io.Output +import com.google.common.primitives.Ints +import net.corda.core.crypto.* +import net.corda.core.serialization.* +import net.corda.core.utilities.ProgressTracker +import net.corda.core.utilities.sequence +import net.corda.node.serialization.KryoServerSerializationScheme +import net.corda.node.services.persistence.NodeAttachmentService +import net.corda.testing.ALICE_PUBKEY +import net.corda.testing.TestDependencyInjectionBase +import org.assertj.core.api.Assertions.assertThat +import org.assertj.core.api.Assertions.assertThatThrownBy +import org.junit.Before +import org.junit.Test +import org.slf4j.LoggerFactory +import java.io.ByteArrayInputStream +import java.io.InputStream +import java.time.Instant +import kotlin.test.assertEquals +import kotlin.test.assertTrue + +class KryoTests : TestDependencyInjectionBase() { + private lateinit var factory: SerializationFactory + private lateinit var context: SerializationContext + + @Before + fun setup() { + factory = SerializationFactoryImpl().apply { registerScheme(KryoServerSerializationScheme(this)) } + context = SerializationContextImpl(KryoHeaderV0_1, + javaClass.classLoader, + AllWhitelist, + emptyMap(), + true, + SerializationContext.UseCase.P2P) + } + + @Test + fun ok() { + val birthday = Instant.parse("1984-04-17T00:30:00.00Z") + val mike = Person("mike", birthday) + val bits = mike.serialize(factory, context) + assertThat(bits.deserialize(factory, context)).isEqualTo(Person("mike", birthday)) + } + + @Test + fun nullables() { + val bob = Person("bob", null) + val bits = bob.serialize(factory, context) + assertThat(bits.deserialize(factory, context)).isEqualTo(Person("bob", null)) + } + + @Test + fun `serialised form is stable when the same object instance is added to the deserialised object graph`() { + val noReferencesContext = context.withoutReferences() + val obj = Ints.toByteArray(0x01234567).sequence() + val originalList = arrayListOf(obj) + val deserialisedList = originalList.serialize(factory, noReferencesContext).deserialize(factory, noReferencesContext) + originalList += obj + deserialisedList += obj + assertThat(deserialisedList.serialize(factory, noReferencesContext)).isEqualTo(originalList.serialize(factory, noReferencesContext)) + } + + @Test + fun `serialised form is stable when the same object instance occurs more than once, and using java serialisation`() { + val noReferencesContext = context.withoutReferences() + val instant = Instant.ofEpochMilli(123) + val instantCopy = Instant.ofEpochMilli(123) + assertThat(instant).isNotSameAs(instantCopy) + val listWithCopies = arrayListOf(instant, instantCopy) + val listWithSameInstances = arrayListOf(instant, instant) + assertThat(listWithSameInstances.serialize(factory, noReferencesContext)).isEqualTo(listWithCopies.serialize(factory, noReferencesContext)) + } + + @Test + fun `cyclic object graph`() { + val cyclic = Cyclic(3) + val bits = cyclic.serialize(factory, context) + assertThat(bits.deserialize(factory, context)).isEqualTo(cyclic) + } + + @Test + fun `deserialised key pair functions the same as serialised one`() { + val keyPair = generateKeyPair() + val bitsToSign: ByteArray = Ints.toByteArray(0x01234567) + val wrongBits: ByteArray = Ints.toByteArray(0x76543210) + val signature = keyPair.sign(bitsToSign) + signature.verify(bitsToSign) + assertThatThrownBy { signature.verify(wrongBits) } + + val deserialisedKeyPair = keyPair.serialize(factory, context).deserialize(factory, context) + val deserialisedSignature = deserialisedKeyPair.sign(bitsToSign) + deserialisedSignature.verify(bitsToSign) + assertThatThrownBy { deserialisedSignature.verify(wrongBits) } + } + + @Test + fun `write and read Kotlin object singleton`() { + val serialised = TestSingleton.serialize(factory, context) + val deserialised = serialised.deserialize(factory, context) + assertThat(deserialised).isSameAs(TestSingleton) + } + + @Test + fun `InputStream serialisation`() { + val rubbish = ByteArray(12345, { (it * it * 0.12345).toByte() }) + val readRubbishStream: InputStream = rubbish.inputStream().serialize(factory, context).deserialize(factory, context) + for (i in 0..12344) { + assertEquals(rubbish[i], readRubbishStream.read().toByte()) + } + assertEquals(-1, readRubbishStream.read()) + } + + @Test + fun `serialize - deserialize SignableData`() { + val testString = "Hello World" + val testBytes = testString.toByteArray() + + val meta = SignableData(testBytes.sha256(), SignatureMetadata(1, Crypto.findSignatureScheme(ALICE_PUBKEY).schemeNumberID)) + val serializedMetaData = meta.serialize(factory, context).bytes + val meta2 = serializedMetaData.deserialize(factory, context) + assertEquals(meta2, meta) + } + + @Test + fun `serialize - deserialize Logger`() { + val storageContext: SerializationContext = context // TODO: make it storage context + val logger = LoggerFactory.getLogger("aName") + val logger2 = logger.serialize(factory, storageContext).deserialize(factory, storageContext) + assertEquals(logger.name, logger2.name) + assertTrue(logger === logger2) + } + + @Test + fun `HashCheckingStream (de)serialize`() { + val rubbish = ByteArray(12345, { (it * it * 0.12345).toByte() }) + val readRubbishStream: InputStream = NodeAttachmentService.HashCheckingStream(SecureHash.sha256(rubbish), rubbish.size, ByteArrayInputStream(rubbish)).serialize(factory, context).deserialize(factory, context) + for (i in 0..12344) { + assertEquals(rubbish[i], readRubbishStream.read().toByte()) + } + assertEquals(-1, readRubbishStream.read()) + } + + @CordaSerializable + private data class Person(val name: String, val birthday: Instant?) + + @Suppress("unused") + @CordaSerializable + private class Cyclic(val value: Int) { + val thisInstance = this + override fun equals(other: Any?): Boolean = (this === other) || (other is Cyclic && this.value == other.value) + override fun hashCode(): Int = value.hashCode() + override fun toString(): String = "Cyclic($value)" + } + + @CordaSerializable + private object TestSingleton + + object SimpleSteps { + object ONE : ProgressTracker.Step("one") + object TWO : ProgressTracker.Step("two") + object THREE : ProgressTracker.Step("three") + object FOUR : ProgressTracker.Step("four") + + fun tracker() = ProgressTracker(ONE, TWO, THREE, FOUR) + } + + object ChildSteps { + object AYY : ProgressTracker.Step("ayy") + object BEE : ProgressTracker.Step("bee") + object SEA : ProgressTracker.Step("sea") + + fun tracker() = ProgressTracker(AYY, BEE, SEA) + } + + @Test + fun rxSubscriptionsAreNotSerialized() { + val pt: ProgressTracker = SimpleSteps.tracker() + val pt2: ProgressTracker = ChildSteps.tracker() + + class Unserializable : KryoSerializable { + override fun write(kryo: Kryo?, output: Output?) = throw AssertionError("not called") + override fun read(kryo: Kryo?, input: Input?) = throw AssertionError("not called") + + fun foo() { + println("bar") + } + } + + pt.setChildProgressTracker(SimpleSteps.TWO, pt2) + class Tmp { + val unserializable = Unserializable() + + init { + pt2.changes.subscribe { unserializable.foo() } + } + } + Tmp() + val factory = SerializationFactoryImpl().apply { registerScheme(KryoServerSerializationScheme(this)) } + val context = SerializationContextImpl(KryoHeaderV0_1, + javaClass.classLoader, + AllWhitelist, + emptyMap(), + true, + SerializationContext.UseCase.P2P) + pt.serialize(factory, context) + } +} diff --git a/core/src/test/kotlin/net/corda/core/serialization/SerializationTokenTest.kt b/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/SerializationTokenTest.kt similarity index 50% rename from core/src/test/kotlin/net/corda/core/serialization/SerializationTokenTest.kt rename to node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/SerializationTokenTest.kt index 9b2517c8d5..03ab48214d 100644 --- a/core/src/test/kotlin/net/corda/core/serialization/SerializationTokenTest.kt +++ b/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/SerializationTokenTest.kt @@ -1,29 +1,33 @@ -package net.corda.core.serialization +package net.corda.nodeapi.internal.serialization import com.esotericsoftware.kryo.Kryo import com.esotericsoftware.kryo.KryoException import com.esotericsoftware.kryo.io.Output import com.nhaarman.mockito_kotlin.mock import net.corda.core.node.ServiceHub +import net.corda.core.serialization.* import net.corda.core.utilities.OpaqueBytes +import net.corda.node.serialization.KryoServerSerializationScheme +import net.corda.testing.TestDependencyInjectionBase import org.assertj.core.api.Assertions.assertThat -import org.junit.After import org.junit.Before import org.junit.Test import java.io.ByteArrayOutputStream -class SerializationTokenTest { +class SerializationTokenTest : TestDependencyInjectionBase() { - lateinit var kryo: Kryo + lateinit var factory: SerializationFactory + lateinit var context: SerializationContext @Before fun setup() { - kryo = storageKryo().borrow() - } - - @After - fun cleanup() { - storageKryo().release(kryo) + factory = SerializationFactoryImpl().apply { registerScheme(KryoServerSerializationScheme(this)) } + context = SerializationContextImpl(KryoHeaderV0_1, + javaClass.classLoader, + AllWhitelist, + emptyMap(), + true, + SerializationContext.UseCase.P2P) } // Large tokenizable object so we can tell from the smaller number of serialized bytes it was actually tokenized @@ -38,20 +42,18 @@ class SerializationTokenTest { override fun equals(other: Any?) = other is LargeTokenizable && other.bytes.size == this.bytes.size } - companion object { - private fun serializeAsTokenContext(toBeTokenized: Any) = SerializeAsTokenContext(toBeTokenized, storageKryo(), mock()) - } + private fun serializeAsTokenContext(toBeTokenized: Any) = SerializeAsTokenContextImpl(toBeTokenized, factory, context, mock()) @Test fun `write token and read tokenizable`() { val tokenizableBefore = LargeTokenizable() val context = serializeAsTokenContext(tokenizableBefore) - kryo.withSerializationContext(context) { - val serializedBytes = tokenizableBefore.serialize(kryo) - assertThat(serializedBytes.size).isLessThan(tokenizableBefore.numBytes) - val tokenizableAfter = serializedBytes.deserialize(kryo) - assertThat(tokenizableAfter).isSameAs(tokenizableBefore) - } + val testContext = this.context.withTokenContext(context) + + val serializedBytes = tokenizableBefore.serialize(factory, testContext) + assertThat(serializedBytes.size).isLessThan(tokenizableBefore.numBytes) + val tokenizableAfter = serializedBytes.deserialize(factory, testContext) + assertThat(tokenizableAfter).isSameAs(tokenizableBefore) } private class UnitSerializeAsToken : SingletonSerializeAsToken() @@ -60,68 +62,65 @@ class SerializationTokenTest { fun `write and read singleton`() { val tokenizableBefore = UnitSerializeAsToken() val context = serializeAsTokenContext(tokenizableBefore) - kryo.withSerializationContext(context) { - val serializedBytes = tokenizableBefore.serialize(kryo) - val tokenizableAfter = serializedBytes.deserialize(kryo) + val testContext = this.context.withTokenContext(context) + val serializedBytes = tokenizableBefore.serialize(factory, testContext) + val tokenizableAfter = serializedBytes.deserialize(factory, testContext) assertThat(tokenizableAfter).isSameAs(tokenizableBefore) - } } @Test(expected = UnsupportedOperationException::class) fun `new token encountered after context init`() { val tokenizableBefore = UnitSerializeAsToken() val context = serializeAsTokenContext(emptyList()) - kryo.withSerializationContext(context) { - tokenizableBefore.serialize(kryo) - } + val testContext = this.context.withTokenContext(context) + tokenizableBefore.serialize(factory, testContext) } @Test(expected = UnsupportedOperationException::class) fun `deserialize unregistered token`() { val tokenizableBefore = UnitSerializeAsToken() val context = serializeAsTokenContext(emptyList()) - kryo.withSerializationContext(context) { - val serializedBytes = tokenizableBefore.toToken(serializeAsTokenContext(emptyList())).serialize(kryo) - serializedBytes.deserialize(kryo) - } + val testContext = this.context.withTokenContext(context) + val serializedBytes = tokenizableBefore.toToken(serializeAsTokenContext(emptyList())).serialize(factory, testContext) + serializedBytes.deserialize(factory, testContext) } @Test(expected = KryoException::class) fun `no context set`() { val tokenizableBefore = UnitSerializeAsToken() - tokenizableBefore.serialize(kryo) + tokenizableBefore.serialize(factory, context) } @Test(expected = KryoException::class) fun `deserialize non-token`() { val tokenizableBefore = UnitSerializeAsToken() val context = serializeAsTokenContext(tokenizableBefore) - kryo.withSerializationContext(context) { - val stream = ByteArrayOutputStream() + val testContext = this.context.withTokenContext(context) + + val kryo: Kryo = DefaultKryoCustomizer.customize(CordaKryo(CordaClassResolver(factory, this.context))) + val stream = ByteArrayOutputStream() Output(stream).use { + it.write(KryoHeaderV0_1.bytes) kryo.writeClass(it, SingletonSerializeAsToken::class.java) kryo.writeObject(it, emptyList()) } - val serializedBytes = SerializedBytes(stream.toByteArray()) - serializedBytes.deserialize(kryo) - } + val serializedBytes = SerializedBytes(stream.toByteArray()) + serializedBytes.deserialize(factory, testContext) } private class WrongTypeSerializeAsToken : SerializeAsToken { - override fun toToken(context: SerializeAsTokenContext): SerializationToken { - return object : SerializationToken { - override fun fromToken(context: SerializeAsTokenContext): Any = UnitSerializeAsToken() - } + object UnitSerializationToken : SerializationToken { + override fun fromToken(context: SerializeAsTokenContext): Any = UnitSerializeAsToken() } + override fun toToken(context: SerializeAsTokenContext): SerializationToken = UnitSerializationToken } @Test(expected = KryoException::class) fun `token returns unexpected type`() { val tokenizableBefore = WrongTypeSerializeAsToken() val context = serializeAsTokenContext(tokenizableBefore) - kryo.withSerializationContext(context) { - val serializedBytes = tokenizableBefore.serialize(kryo) - serializedBytes.deserialize(kryo) - } + val testContext = this.context.withTokenContext(context) + val serializedBytes = tokenizableBefore.serialize(factory, testContext) + serializedBytes.deserialize(factory, testContext) } } diff --git a/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/amqp/AMQPTestUtils.kt b/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/amqp/AMQPTestUtils.kt new file mode 100644 index 0000000000..19b93b6542 --- /dev/null +++ b/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/amqp/AMQPTestUtils.kt @@ -0,0 +1,13 @@ +package net.corda.nodeapi.internal.serialization.amqp + +import org.apache.qpid.proton.codec.Data + +class TestSerializationOutput( + private val verbose: Boolean, + serializerFactory: SerializerFactory = SerializerFactory()) : SerializationOutput(serializerFactory) { + + override fun writeSchema(schema: Schema, data: Data) { + if (verbose) println(schema) + super.writeSchema(schema, data) + } +} diff --git a/core/src/test/kotlin/net/corda/core/serialization/amqp/DeserializeAndReturnEnvelopeTests.kt b/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/amqp/DeserializeAndReturnEnvelopeTests.kt similarity index 78% rename from core/src/test/kotlin/net/corda/core/serialization/amqp/DeserializeAndReturnEnvelopeTests.kt rename to node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/amqp/DeserializeAndReturnEnvelopeTests.kt index ca172680cf..b91fa3fe9e 100644 --- a/core/src/test/kotlin/net/corda/core/serialization/amqp/DeserializeAndReturnEnvelopeTests.kt +++ b/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/amqp/DeserializeAndReturnEnvelopeTests.kt @@ -1,11 +1,15 @@ -package net.corda.core.serialization.amqp +package net.corda.nodeapi.internal.serialization.amqp import org.junit.Test -import kotlin.test.* +import kotlin.test.assertEquals +import kotlin.test.assertNotEquals +import kotlin.test.assertTrue class DeserializeAndReturnEnvelopeTests { - fun testName() = Thread.currentThread().stackTrace[2].methodName + fun testName(): String = Thread.currentThread().stackTrace[2].methodName + + @Suppress("NOTHING_TO_INLINE") inline fun classTestName(clazz: String) = "${this.javaClass.name}\$${testName()}\$$clazz" @Test @@ -14,7 +18,7 @@ class DeserializeAndReturnEnvelopeTests { val a = A(10, "20") - var factory = SerializerFactory() + val factory = SerializerFactory() fun serialise(clazz: Any) = SerializationOutput(factory).serialize(clazz) val obj = DeserializationInput(factory).deserializeAndReturnEnvelope(serialise(a)) @@ -30,7 +34,7 @@ class DeserializeAndReturnEnvelopeTests { val b = B(A(10, "20"), 30.0F) - var factory = SerializerFactory() + val factory = SerializerFactory() fun serialise(clazz: Any) = SerializationOutput(factory).serialize(clazz) val obj = DeserializationInput(factory).deserializeAndReturnEnvelope(serialise(b)) diff --git a/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/amqp/DeserializeNeedingCarpentrySimpleTypesTest.kt b/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/amqp/DeserializeNeedingCarpentrySimpleTypesTest.kt new file mode 100644 index 0000000000..76933fb370 --- /dev/null +++ b/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/amqp/DeserializeNeedingCarpentrySimpleTypesTest.kt @@ -0,0 +1,446 @@ +package net.corda.nodeapi.internal.serialization.amqp + +import org.junit.Test +import kotlin.test.* +import net.corda.nodeapi.internal.serialization.carpenter.* + +// These tests work by having the class carpenter build the classes we serialise and then deserialise. Because +// those classes don't exist within the system's Class Loader the deserialiser will be forced to carpent +// versions of them up using its own internal class carpenter (each carpenter houses it's own loader). This +// replicates the situation where a receiver doesn't have some or all elements of a schema present on it's classpath +class DeserializeNeedingCarpentrySimpleTypesTest { + companion object { + /** + * If you want to see the schema encoded into the envelope after serialisation change this to true + */ + private const val VERBOSE = false + } + + val sf = SerializerFactory() + val sf2 = SerializerFactory() + + @Test + fun singleInt() { + val clazz = ClassCarpenter().build(ClassSchema("single", mapOf( + "int" to NonNullableField(Integer::class.javaPrimitiveType!!) + ))) + + val sb = TestSerializationOutput(VERBOSE, sf).serialize(clazz.constructors.first().newInstance(1)) + val db = DeserializationInput(sf).deserialize(sb) + val db2 = DeserializationInput(sf2).deserialize(sb) + + // despite being carpented, and thus not on the class path, we should've cached clazz + // inside the serialiser object and thus we should have created the same type + assertEquals (db::class.java, clazz) + assertNotEquals (db2::class.java, clazz) + assertNotEquals (db::class.java, db2::class.java) + + assertEquals(1, db::class.java.getMethod("getInt").invoke(db)) + assertEquals(1, db2::class.java.getMethod("getInt").invoke(db2)) + } + + @Test + fun singleIntNullable() { + val clazz = ClassCarpenter().build(ClassSchema("single", mapOf( + "int" to NullableField(Integer::class.java) + ))) + + val sb = TestSerializationOutput(VERBOSE, sf).serialize(clazz.constructors.first().newInstance(1)) + val db1 = DeserializationInput(sf).deserialize(sb) + val db2 = DeserializationInput(sf2).deserialize(sb) + + assertEquals(clazz, db1::class.java) + assertNotEquals(clazz, db2::class.java) + assertEquals(1, db1::class.java.getMethod("getInt").invoke(db1)) + assertEquals(1, db2::class.java.getMethod("getInt").invoke(db2)) + } + + @Test + fun singleIntNullableNull() { + val clazz = ClassCarpenter().build(ClassSchema("single", mapOf( + "int" to NullableField(Integer::class.java) + ))) + + val sb = TestSerializationOutput(VERBOSE, sf).serialize(clazz.constructors.first().newInstance(null)) + val db1 = DeserializationInput(sf).deserialize(sb) + val db2 = DeserializationInput(sf2).deserialize(sb) + + assertEquals(clazz, db1::class.java) + assertNotEquals(clazz, db2::class.java) + assertEquals(null, db1::class.java.getMethod("getInt").invoke(db1)) + assertEquals(null, db2::class.java.getMethod("getInt").invoke(db2)) + } + + @Test + fun singleChar() { + val clazz = ClassCarpenter().build(ClassSchema("single", mapOf( + "char" to NonNullableField(Character::class.javaPrimitiveType!!) + ))) + + val sb = TestSerializationOutput(VERBOSE, sf).serialize(clazz.constructors.first().newInstance('a')) + val db = DeserializationInput(sf2).deserialize(sb) + + assertNotEquals(clazz, db::class.java) + assertEquals('a', db::class.java.getMethod("getChar").invoke(db)) + } + + @Test + fun singleCharNullable() { + val clazz = ClassCarpenter().build(ClassSchema("single", mapOf( + "char" to NullableField(Character::class.javaObjectType) + ))) + + val sb = TestSerializationOutput(VERBOSE, sf).serialize(clazz.constructors.first().newInstance('a')) + val db = DeserializationInput(sf2).deserialize(sb) + + assertNotEquals(clazz, db::class.java) + assertEquals('a', db::class.java.getMethod("getChar").invoke(db)) + } + + @Test + fun singleCharNullableNull() { + val clazz = ClassCarpenter().build(ClassSchema("single", mapOf( + "char" to NullableField(java.lang.Character::class.java) + ))) + + val sb = TestSerializationOutput(VERBOSE, sf).serialize(clazz.constructors.first().newInstance(null)) + val db = DeserializationInput(sf2).deserialize(sb) + + assertNotEquals(clazz, db::class.java) + assertEquals(null, db::class.java.getMethod("getChar").invoke(db)) + } + + @Test + fun singleLong() { + val clazz = ClassCarpenter().build(ClassSchema("single", mapOf( + "long" to NonNullableField(Long::class.javaPrimitiveType!!) + ))) + + val l : Long = 1 + val sb = TestSerializationOutput(VERBOSE, sf).serialize(clazz.constructors.first().newInstance(l)) + val db = DeserializationInput(sf2).deserialize(sb) + + assertNotEquals(clazz, db::class.java) + assertEquals(l, (db::class.java.getMethod("getLong").invoke(db))) + } + + @Test + fun singleLongNullable() { + val clazz = ClassCarpenter().build(ClassSchema("single", mapOf( + "long" to NullableField(Long::class.javaObjectType) + ))) + + val l : Long = 1 + val sb = TestSerializationOutput(VERBOSE, sf).serialize(clazz.constructors.first().newInstance(l)) + val db = DeserializationInput(sf2).deserialize(sb) + + assertNotEquals(clazz, db::class.java) + assertEquals(l, (db::class.java.getMethod("getLong").invoke(db))) + } + + @Test + fun singleLongNullableNull() { + val clazz = ClassCarpenter().build(ClassSchema("single", mapOf( + "long" to NullableField(Long::class.javaObjectType) + ))) + + val sb = TestSerializationOutput(VERBOSE, sf).serialize(clazz.constructors.first().newInstance(null)) + val db = DeserializationInput(sf2).deserialize(sb) + + assertNotEquals(clazz, db::class.java) + assertEquals(null, (db::class.java.getMethod("getLong").invoke(db))) + } + + @Test + fun singleBoolean() { + val clazz = ClassCarpenter().build(ClassSchema("single", mapOf( + "boolean" to NonNullableField(Boolean::class.javaPrimitiveType!!) + ))) + + val sb = TestSerializationOutput(VERBOSE, sf).serialize(clazz.constructors.first().newInstance(true)) + val db = DeserializationInput(sf2).deserialize(sb) + + assertNotEquals(clazz, db::class.java) + assertEquals(true, db::class.java.getMethod("getBoolean").invoke(db)) + } + + @Test + fun singleBooleanNullable() { + val clazz = ClassCarpenter().build(ClassSchema("single", mapOf( + "boolean" to NullableField(Boolean::class.javaObjectType) + ))) + + val sb = TestSerializationOutput(VERBOSE, sf).serialize(clazz.constructors.first().newInstance(true)) + val db = DeserializationInput(sf2).deserialize(sb) + + assertNotEquals(clazz, db::class.java) + assertEquals(true, db::class.java.getMethod("getBoolean").invoke(db)) + } + + @Test + fun singleBooleanNullableNull() { + val clazz = ClassCarpenter().build(ClassSchema("single", mapOf( + "boolean" to NullableField(Boolean::class.javaObjectType) + ))) + + val sb = TestSerializationOutput(VERBOSE, sf).serialize(clazz.constructors.first().newInstance(null)) + val db = DeserializationInput(sf2).deserialize(sb) + + assertNotEquals(clazz, db::class.java) + assertEquals(null, db::class.java.getMethod("getBoolean").invoke(db)) + } + + @Test + fun singleDouble() { + val clazz = ClassCarpenter().build(ClassSchema("single", mapOf( + "double" to NonNullableField(Double::class.javaPrimitiveType!!) + ))) + + val sb = TestSerializationOutput(VERBOSE, sf).serialize(clazz.constructors.first().newInstance(10.0)) + val db = DeserializationInput(sf2).deserialize(sb) + + assertNotEquals(clazz, db::class.java) + assertEquals(10.0, db::class.java.getMethod("getDouble").invoke(db)) + } + + @Test + fun singleDoubleNullable() { + val clazz = ClassCarpenter().build(ClassSchema("single", mapOf( + "double" to NullableField(Double::class.javaObjectType) + ))) + + val sb = TestSerializationOutput(VERBOSE, sf).serialize(clazz.constructors.first().newInstance(10.0)) + val db = DeserializationInput(sf2).deserialize(sb) + + assertNotEquals(clazz, db::class.java) + assertEquals(10.0, db::class.java.getMethod("getDouble").invoke(db)) + } + + @Test + fun singleDoubleNullableNull() { + val clazz = ClassCarpenter().build(ClassSchema("single", mapOf( + "double" to NullableField(Double::class.javaObjectType) + ))) + + val sb = TestSerializationOutput(VERBOSE, sf).serialize(clazz.constructors.first().newInstance(null)) + val db = DeserializationInput(sf2).deserialize(sb) + + assertNotEquals(clazz, db::class.java) + assertEquals(null, db::class.java.getMethod("getDouble").invoke(db)) + } + + @Test + fun singleShort() { + val clazz = ClassCarpenter().build(ClassSchema("single", mapOf( + "short" to NonNullableField(Short::class.javaPrimitiveType!!) + ))) + + val sb = TestSerializationOutput(VERBOSE, sf).serialize(clazz.constructors.first().newInstance(3.toShort())) + val db = DeserializationInput(sf2).deserialize(sb) + + assertNotEquals(clazz, db::class.java) + assertEquals(3.toShort(), db::class.java.getMethod("getShort").invoke(db)) + } + + @Test + fun singleShortNullable() { + val clazz = ClassCarpenter().build(ClassSchema("single", mapOf( + "short" to NullableField(Short::class.javaObjectType) + ))) + + val sb = TestSerializationOutput(VERBOSE, sf).serialize(clazz.constructors.first().newInstance(3.toShort())) + val db = DeserializationInput(sf2).deserialize(sb) + + assertNotEquals(clazz, db::class.java) + assertEquals(3.toShort(), db::class.java.getMethod("getShort").invoke(db)) + } + + @Test + fun singleShortNullableNull() { + val clazz = ClassCarpenter().build(ClassSchema("single", mapOf( + "short" to NullableField(Short::class.javaObjectType) + ))) + + val sb = TestSerializationOutput(VERBOSE, sf).serialize(clazz.constructors.first().newInstance(null)) + val db = DeserializationInput(sf2).deserialize(sb) + + assertNotEquals(clazz, db::class.java) + assertEquals(null, db::class.java.getMethod("getShort").invoke(db)) + } + + @Test + fun singleFloat() { + val clazz = ClassCarpenter().build(ClassSchema("single", mapOf( + "float" to NonNullableField(Float::class.javaPrimitiveType!!) + ))) + + val sb = TestSerializationOutput(VERBOSE, sf).serialize(clazz.constructors.first().newInstance(10.0F)) + val db = DeserializationInput(sf2).deserialize(sb) + + assertNotEquals(clazz, db::class.java) + assertEquals(10.0F, db::class.java.getMethod("getFloat").invoke(db)) + } + + @Test + fun singleFloatNullable() { + val clazz = ClassCarpenter().build(ClassSchema("single", mapOf( + "float" to NullableField(Float::class.javaObjectType) + ))) + + val sb = TestSerializationOutput(VERBOSE, sf).serialize(clazz.constructors.first().newInstance(10.0F)) + val db = DeserializationInput(sf2).deserialize(sb) + + assertNotEquals(clazz, db::class.java) + assertEquals(10.0F, db::class.java.getMethod("getFloat").invoke(db)) + } + + @Test + fun singleFloatNullableNull() { + val clazz = ClassCarpenter().build(ClassSchema("single", mapOf( + "float" to NullableField(Float::class.javaObjectType) + ))) + + val sb = TestSerializationOutput(VERBOSE, sf).serialize(clazz.constructors.first().newInstance(null)) + val db = DeserializationInput(sf2).deserialize(sb) + + assertNotEquals(clazz, db::class.java) + assertEquals(null, db::class.java.getMethod("getFloat").invoke(db)) + } + + @Test + fun singleByte() { + val clazz = ClassCarpenter().build(ClassSchema("single", mapOf( + "byte" to NonNullableField(Byte::class.javaPrimitiveType!!) + ))) + + val b : Byte = 0b0101 + val sb = TestSerializationOutput(VERBOSE, sf).serialize(clazz.constructors.first().newInstance(b)) + val db = DeserializationInput(sf2).deserialize(sb) + + assertNotEquals(clazz, db::class.java) + assertEquals(b, db::class.java.getMethod("getByte").invoke(db)) + assertEquals(0b0101, (db::class.java.getMethod("getByte").invoke(db) as Byte)) + } + + @Test + fun singleByteNullable() { + val clazz = ClassCarpenter().build(ClassSchema("single", mapOf( + "byte" to NullableField(Byte::class.javaObjectType) + ))) + + val b : Byte = 0b0101 + val sb = TestSerializationOutput(VERBOSE, sf).serialize(clazz.constructors.first().newInstance(b)) + val db = DeserializationInput(sf2).deserialize(sb) + + assertNotEquals(clazz, db::class.java) + assertEquals(b, db::class.java.getMethod("getByte").invoke(db)) + assertEquals(0b0101, (db::class.java.getMethod("getByte").invoke(db) as Byte)) + } + + @Test + fun singleByteNullableNull() { + val clazz = ClassCarpenter().build(ClassSchema("single", mapOf( + "byte" to NullableField(Byte::class.javaObjectType) + ))) + + val sb = TestSerializationOutput(VERBOSE, sf).serialize(clazz.constructors.first().newInstance(null)) + val db = DeserializationInput(sf2).deserialize(sb) + + assertNotEquals(clazz, db::class.java) + assertEquals(null, db::class.java.getMethod("getByte").invoke(db)) + } + + @Test + fun simpleTypeKnownInterface() { + val clazz = ClassCarpenter().build (ClassSchema( + "oneType", mapOf("name" to NonNullableField(String::class.java)), + interfaces = listOf (I::class.java))) + val testVal = "Some Person" + val classInstance = clazz.constructors[0].newInstance(testVal) + + val serialisedBytes = TestSerializationOutput(VERBOSE, sf).serialize(classInstance) + val deserializedObj = DeserializationInput(sf2).deserialize(serialisedBytes) + + assertNotEquals(clazz, deserializedObj::class.java) + assertTrue(deserializedObj is I) + assertEquals(testVal, (deserializedObj as I).getName()) + } + + @Test + fun manyTypes() { + val manyClass = ClassCarpenter().build (ClassSchema("many", mapOf( + "intA" to NonNullableField (Int::class.java), + "intB" to NullableField (Integer::class.java), + "intC" to NullableField (Integer::class.java), + "strA" to NonNullableField (String::class.java), + "strB" to NullableField (String::class.java), + "strC" to NullableField (String::class.java), + "charA" to NonNullableField (Char::class.java), + "charB" to NullableField (Character::class.javaObjectType), + "charC" to NullableField (Character::class.javaObjectType), + "shortA" to NonNullableField (Short::class.javaPrimitiveType!!), + "shortB" to NullableField (Short::class.javaObjectType), + "shortC" to NullableField (Short::class.javaObjectType), + "longA" to NonNullableField (Long::class.javaPrimitiveType!!), + "longB" to NullableField(Long::class.javaObjectType), + "longC" to NullableField(Long::class.javaObjectType), + "booleanA" to NonNullableField (Boolean::class.javaPrimitiveType!!), + "booleanB" to NullableField (Boolean::class.javaObjectType), + "booleanC" to NullableField (Boolean::class.javaObjectType), + "doubleA" to NonNullableField (Double::class.javaPrimitiveType!!), + "doubleB" to NullableField (Double::class.javaObjectType), + "doubleC" to NullableField (Double::class.javaObjectType), + "floatA" to NonNullableField (Float::class.javaPrimitiveType!!), + "floatB" to NullableField (Float::class.javaObjectType), + "floatC" to NullableField (Float::class.javaObjectType), + "byteA" to NonNullableField (Byte::class.javaPrimitiveType!!), + "byteB" to NullableField (Byte::class.javaObjectType), + "byteC" to NullableField (Byte::class.javaObjectType)))) + + val serialisedBytes = TestSerializationOutput(VERBOSE, sf).serialize( + manyClass.constructors.first().newInstance( + 1, 2, null, + "a", "b", null, + 'c', 'd', null, + 3.toShort(), 4.toShort(), null, + 100.toLong(), 200.toLong(), null, + true, false, null, + 10.0, 20.0, null, + 10.0F, 20.0F, null, + 0b0101.toByte(), 0b1010.toByte(), null)) + + val deserializedObj = DeserializationInput(sf2).deserialize(serialisedBytes) + + assertNotEquals(manyClass, deserializedObj::class.java) + assertEquals(1, deserializedObj::class.java.getMethod("getIntA").invoke(deserializedObj)) + assertEquals(2, deserializedObj::class.java.getMethod("getIntB").invoke(deserializedObj)) + assertEquals(null, deserializedObj::class.java.getMethod("getIntC").invoke(deserializedObj)) + assertEquals("a", deserializedObj::class.java.getMethod("getStrA").invoke(deserializedObj)) + assertEquals("b", deserializedObj::class.java.getMethod("getStrB").invoke(deserializedObj)) + assertEquals(null, deserializedObj::class.java.getMethod("getStrC").invoke(deserializedObj)) + assertEquals('c', deserializedObj::class.java.getMethod("getCharA").invoke(deserializedObj)) + assertEquals('d', deserializedObj::class.java.getMethod("getCharB").invoke(deserializedObj)) + assertEquals(null, deserializedObj::class.java.getMethod("getCharC").invoke(deserializedObj)) + assertEquals(3.toShort(), deserializedObj::class.java.getMethod("getShortA").invoke(deserializedObj)) + assertEquals(4.toShort(), deserializedObj::class.java.getMethod("getShortB").invoke(deserializedObj)) + assertEquals(null, deserializedObj::class.java.getMethod("getShortC").invoke(deserializedObj)) + assertEquals(100.toLong(), deserializedObj::class.java.getMethod("getLongA").invoke(deserializedObj)) + assertEquals(200.toLong(), deserializedObj::class.java.getMethod("getLongB").invoke(deserializedObj)) + assertEquals(null, deserializedObj::class.java.getMethod("getLongC").invoke(deserializedObj)) + assertEquals(true, deserializedObj::class.java.getMethod("getBooleanA").invoke(deserializedObj)) + assertEquals(false, deserializedObj::class.java.getMethod("getBooleanB").invoke(deserializedObj)) + assertEquals(null, deserializedObj::class.java.getMethod("getBooleanC").invoke(deserializedObj)) + assertEquals(10.0, deserializedObj::class.java.getMethod("getDoubleA").invoke(deserializedObj)) + assertEquals(20.0, deserializedObj::class.java.getMethod("getDoubleB").invoke(deserializedObj)) + assertEquals(null, deserializedObj::class.java.getMethod("getDoubleC").invoke(deserializedObj)) + assertEquals(10.0F, deserializedObj::class.java.getMethod("getFloatA").invoke(deserializedObj)) + assertEquals(20.0F, deserializedObj::class.java.getMethod("getFloatB").invoke(deserializedObj)) + assertEquals(null, deserializedObj::class.java.getMethod("getFloatC").invoke(deserializedObj)) + assertEquals(0b0101.toByte(), deserializedObj::class.java.getMethod("getByteA").invoke(deserializedObj)) + assertEquals(0b1010.toByte(), deserializedObj::class.java.getMethod("getByteB").invoke(deserializedObj)) + assertEquals(null, deserializedObj::class.java.getMethod("getByteC").invoke(deserializedObj)) + } +} + + + diff --git a/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/amqp/DeserializeNeedingCarpentryTests.kt b/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/amqp/DeserializeNeedingCarpentryTests.kt new file mode 100644 index 0000000000..d0cae9e3b6 --- /dev/null +++ b/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/amqp/DeserializeNeedingCarpentryTests.kt @@ -0,0 +1,240 @@ +package net.corda.nodeapi.internal.serialization.amqp + +import org.junit.Test +import kotlin.test.* +import net.corda.nodeapi.internal.serialization.carpenter.* + +interface I { + fun getName() : String +} + +/** + * These tests work by having the class carpenter build the classes we serialise and then deserialise them + * within the context of a second serialiser factory. The second factory is required as the first, having + * been used to serialise the class, will have cached a copy of the class and will thus bypass the need + * to pull it out of the class loader. + * + * However, those classes don't exist within the system's Class Loader and thus the deserialiser will be forced + * to carpent versions of them up using its own internal class carpenter (each carpenter houses it's own loader). This + * replicates the situation where a receiver doesn't have some or all elements of a schema present on it's classpath + */ +class DeserializeNeedingCarpentryTests { + companion object { + /** + * If you want to see the schema encoded into the envelope after serialisation change this to true + */ + private const val VERBOSE = false + } + + val sf1 = SerializerFactory() + val sf2 = SerializerFactory() + + @Test + fun verySimpleType() { + val testVal = 10 + val clazz = ClassCarpenter().build(ClassSchema("oneType", mapOf("a" to NonNullableField(Int::class.java)))) + val classInstance = clazz.constructors[0].newInstance(testVal) + + val serialisedBytes = TestSerializationOutput(VERBOSE, sf1).serialize(classInstance) + + val deserializedObj1 = DeserializationInput(sf1).deserialize(serialisedBytes) + assertEquals(clazz, deserializedObj1::class.java) + assertEquals (testVal, deserializedObj1::class.java.getMethod("getA").invoke(deserializedObj1)) + + val deserializedObj2 = DeserializationInput(sf1).deserialize(serialisedBytes) + assertEquals(clazz, deserializedObj2::class.java) + assertEquals(deserializedObj1::class.java, deserializedObj2::class.java) + assertEquals (testVal, deserializedObj2::class.java.getMethod("getA").invoke(deserializedObj2)) + + val deserializedObj3 = DeserializationInput(sf2).deserialize(serialisedBytes) + assertNotEquals(clazz, deserializedObj3::class.java) + assertNotEquals(deserializedObj1::class.java, deserializedObj3::class.java) + assertNotEquals(deserializedObj2::class.java, deserializedObj3::class.java) + assertEquals (testVal, deserializedObj3::class.java.getMethod("getA").invoke(deserializedObj3)) + + val deserializedObj4 = DeserializationInput(sf2).deserialize(serialisedBytes) + assertNotEquals(clazz, deserializedObj4::class.java) + assertNotEquals(deserializedObj1::class.java, deserializedObj4::class.java) + assertNotEquals(deserializedObj2::class.java, deserializedObj4::class.java) + assertEquals(deserializedObj3::class.java, deserializedObj4::class.java) + assertEquals (testVal, deserializedObj4::class.java.getMethod("getA").invoke(deserializedObj4)) + + } + + @Test + fun repeatedTypesAreRecognised() { + val testValA = 10 + val testValB = 20 + val testValC = 20 + val clazz = ClassCarpenter().build(ClassSchema("oneType", mapOf("a" to NonNullableField(Int::class.java)))) + val concreteA = clazz.constructors[0].newInstance(testValA) + val concreteB = clazz.constructors[0].newInstance(testValB) + val concreteC = clazz.constructors[0].newInstance(testValC) + + val deserialisedA = DeserializationInput(sf2).deserialize(TestSerializationOutput(VERBOSE, sf1).serialize(concreteA)) + + assertEquals (testValA, deserialisedA::class.java.getMethod("getA").invoke(deserialisedA)) + + val deserialisedB = DeserializationInput(sf2).deserialize(TestSerializationOutput(VERBOSE, sf1).serialize(concreteB)) + + assertEquals (testValB, deserialisedA::class.java.getMethod("getA").invoke(deserialisedB)) + assertEquals (deserialisedA::class.java, deserialisedB::class.java) + + // C is deseriliased with a different factory, meaning a different class carpenter, so the type + // won't already exist and it will be carpented a second time showing that when A and B are the + // same underlying class that we didn't create a second instance of the class with the + // second deserialisation + val lsf = SerializerFactory() + val deserialisedC = DeserializationInput(lsf).deserialize(TestSerializationOutput(VERBOSE, lsf).serialize(concreteC)) + assertEquals (testValC, deserialisedC::class.java.getMethod("getA").invoke(deserialisedC)) + assertNotEquals (deserialisedA::class.java, deserialisedC::class.java) + assertNotEquals (deserialisedB::class.java, deserialisedC::class.java) + } + + @Test + fun simpleTypeKnownInterface() { + val clazz = ClassCarpenter().build (ClassSchema( + "oneType", mapOf("name" to NonNullableField(String::class.java)), + interfaces = listOf (I::class.java))) + val testVal = "Some Person" + val classInstance = clazz.constructors[0].newInstance(testVal) + + val serialisedBytes = TestSerializationOutput(VERBOSE, sf1).serialize(classInstance) + val deserializedObj = DeserializationInput(sf2).deserialize(serialisedBytes) + + assertTrue(deserializedObj is I) + assertEquals(testVal, (deserializedObj as I).getName()) + } + + @Test + fun arrayOfTypes() { + val clazz = ClassCarpenter().build(ClassSchema("oneType", mapOf("a" to NonNullableField(Int::class.java)))) + + data class Outer (val a : Array) + + val outer = Outer (arrayOf ( + clazz.constructors[0].newInstance(1), + clazz.constructors[0].newInstance(2), + clazz.constructors[0].newInstance(3))) + + val deserializedObj = DeserializationInput(sf2).deserialize(TestSerializationOutput(VERBOSE, sf1).serialize(outer)) + + assertNotEquals((deserializedObj.a[0])::class.java, (outer.a[0])::class.java) + assertNotEquals((deserializedObj.a[1])::class.java, (outer.a[1])::class.java) + assertNotEquals((deserializedObj.a[2])::class.java, (outer.a[2])::class.java) + + assertEquals((deserializedObj.a[0])::class.java, (deserializedObj.a[1])::class.java) + assertEquals((deserializedObj.a[0])::class.java, (deserializedObj.a[2])::class.java) + assertEquals((deserializedObj.a[1])::class.java, (deserializedObj.a[2])::class.java) + + assertEquals( + outer.a[0]::class.java.getMethod("getA").invoke(outer.a[0]), + deserializedObj.a[0]::class.java.getMethod("getA").invoke(deserializedObj.a[0])) + assertEquals( + outer.a[1]::class.java.getMethod("getA").invoke(outer.a[1]), + deserializedObj.a[1]::class.java.getMethod("getA").invoke(deserializedObj.a[1])) + assertEquals( + outer.a[2]::class.java.getMethod("getA").invoke(outer.a[2]), + deserializedObj.a[2]::class.java.getMethod("getA").invoke(deserializedObj.a[2])) + } + + @Test + fun reusedClasses() { + val cc = ClassCarpenter() + + val innerType = cc.build(ClassSchema("inner", mapOf("a" to NonNullableField(Int::class.java)))) + val outerType = cc.build(ClassSchema("outer", mapOf("a" to NonNullableField(innerType)))) + val inner = innerType.constructors[0].newInstance(1) + val outer = outerType.constructors[0].newInstance(innerType.constructors[0].newInstance(2)) + + val serializedI = TestSerializationOutput(VERBOSE, sf1).serialize(inner) + val deserialisedI = DeserializationInput(sf2).deserialize(serializedI) + val serialisedO = TestSerializationOutput(VERBOSE, sf1).serialize(outer) + val deserialisedO = DeserializationInput(sf2).deserialize(serialisedO) + + // ensure out carpented version of inner is reused + assertEquals (deserialisedI::class.java, + (deserialisedO::class.java.getMethod("getA").invoke(deserialisedO))::class.java) + } + + @Test + fun nestedTypes() { + val cc = ClassCarpenter() + val nestedClass = cc.build (ClassSchema("nestedType", + mapOf("name" to NonNullableField(String::class.java)))) + + val outerClass = cc.build (ClassSchema("outerType", + mapOf("inner" to NonNullableField(nestedClass)))) + + val classInstance = outerClass.constructors.first().newInstance(nestedClass.constructors.first().newInstance("name")) + val serialisedBytes = TestSerializationOutput(VERBOSE, sf1).serialize(classInstance) + val deserializedObj = DeserializationInput(sf2).deserialize(serialisedBytes) + + val inner = deserializedObj::class.java.getMethod("getInner").invoke(deserializedObj) + assertEquals("name", inner::class.java.getMethod("getName").invoke(inner)) + } + + @Test + fun repeatedNestedTypes() { + val cc = ClassCarpenter() + val nestedClass = cc.build (ClassSchema("nestedType", + mapOf("name" to NonNullableField(String::class.java)))) + + data class outer(val a: Any, val b: Any) + + val classInstance = outer ( + nestedClass.constructors.first().newInstance("foo"), + nestedClass.constructors.first().newInstance("bar")) + + val serialisedBytes = TestSerializationOutput(VERBOSE, sf1).serialize(classInstance) + val deserializedObj = DeserializationInput(sf2).deserialize(serialisedBytes) + + assertEquals ("foo", deserializedObj.a::class.java.getMethod("getName").invoke(deserializedObj.a)) + assertEquals ("bar", deserializedObj.b::class.java.getMethod("getName").invoke(deserializedObj.b)) + } + + @Test + fun listOfType() { + val unknownClass = ClassCarpenter().build (ClassSchema("unknownClass", mapOf( + "v1" to NonNullableField(Int::class.java), + "v2" to NonNullableField(Int::class.java)))) + + data class outer (val l : List) + val toSerialise = outer (listOf ( + unknownClass.constructors.first().newInstance(1, 2), + unknownClass.constructors.first().newInstance(3, 4), + unknownClass.constructors.first().newInstance(5, 6), + unknownClass.constructors.first().newInstance(7, 8))) + + val serialisedBytes = TestSerializationOutput(VERBOSE, sf1).serialize(toSerialise) + val deserializedObj = DeserializationInput(sf2).deserialize(serialisedBytes) + var sentinel = 1 + deserializedObj.l.forEach { + assertEquals(sentinel++, it::class.java.getMethod("getV1").invoke(it)) + assertEquals(sentinel++, it::class.java.getMethod("getV2").invoke(it)) + } + } + + @Test + fun unknownInterface() { + val cc = ClassCarpenter() + + val interfaceClass = cc.build (InterfaceSchema( + "gen.Interface", + mapOf("age" to NonNullableField (Int::class.java)))) + + val concreteClass = cc.build (ClassSchema ("gen.Class", mapOf( + "age" to NonNullableField (Int::class.java), + "name" to NonNullableField(String::class.java)), + interfaces = listOf (I::class.java, interfaceClass))) + + val serialisedBytes = TestSerializationOutput(VERBOSE, sf1).serialize( + concreteClass.constructors.first().newInstance(12, "timmy")) + val deserializedObj = DeserializationInput(sf2).deserialize(serialisedBytes) + + assertTrue(deserializedObj is I) + assertEquals("timmy", (deserializedObj as I).getName()) + assertEquals("timmy", deserializedObj::class.java.getMethod("getName").invoke(deserializedObj)) + assertEquals(12, deserializedObj::class.java.getMethod("getAge").invoke(deserializedObj)) + } +} diff --git a/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/amqp/DeserializeSimpleTypesTests.kt b/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/amqp/DeserializeSimpleTypesTests.kt new file mode 100644 index 0000000000..1e7171b31a --- /dev/null +++ b/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/amqp/DeserializeSimpleTypesTests.kt @@ -0,0 +1,484 @@ +package net.corda.nodeapi.internal.serialization.amqp + +import org.junit.Test +import kotlin.test.assertEquals + +// Prior to certain fixes being made within the [PropertySerializaer] classes these simple +// deserialization operations would've blown up with type mismatch errors where the deserlized +// char property of the class would've been treated as an Integer and given to the constructor +// as such +class DeserializeSimpleTypesTests { + + companion object { + /** + * If you want to see the schema encoded into the envelope after serialisation change this to true + */ + private const val VERBOSE = false + } + + val sf1 = SerializerFactory() + val sf2 = SerializerFactory() + + @Test + fun testChar() { + data class C(val c: Char) + + var deserializedC = DeserializationInput().deserialize(SerializationOutput().serialize(C('c'))) + assertEquals('c', deserializedC.c) + + // CYRILLIC CAPITAL LETTER YU (U+042E) + deserializedC = DeserializationInput().deserialize(SerializationOutput().serialize(C('Ю'))) + assertEquals('Ю', deserializedC.c) + + // ARABIC LETTER FEH WITH DOT BELOW (U+06A3) + deserializedC = DeserializationInput().deserialize(SerializationOutput().serialize(C('ڣ'))) + assertEquals('ڣ', deserializedC.c) + + // ARABIC LETTER DAD WITH DOT BELOW (U+06FB) + deserializedC = DeserializationInput().deserialize(SerializationOutput().serialize(C('ۻ'))) + assertEquals('ۻ', deserializedC.c) + + // BENGALI LETTER AA (U+0986) + deserializedC = DeserializationInput().deserialize(SerializationOutput().serialize(C('আ'))) + assertEquals('আ', deserializedC.c) + } + + @Suppress("PLATFORM_CLASS_MAPPED_TO_KOTLIN") + @Test + fun testCharacter() { + data class C(val c: Character) + + val c = C(Character('c')) + val serialisedC = SerializationOutput().serialize(c) + val deserializedC = DeserializationInput().deserialize(serialisedC) + + assertEquals(c.c, deserializedC.c) + } + + @Test + fun testNullCharacter() { + data class C(val c: Char?) + + val c = C(null) + val serialisedC = SerializationOutput().serialize(c) + val deserializedC = DeserializationInput().deserialize(serialisedC) + + assertEquals(c.c, deserializedC.c) + } + + @Test + fun testArrayOfInt() { + class IA(val ia: Array) + + val ia = IA(arrayOf(1, 2, 3)) + + assertEquals("class [Ljava.lang.Integer;", ia.ia::class.java.toString()) + assertEquals(SerializerFactory.nameForType(ia.ia::class.java), "int[]") + + val serialisedIA = TestSerializationOutput(VERBOSE, sf1).serialize(ia) + val deserializedIA = DeserializationInput(sf1).deserialize(serialisedIA) + + assertEquals(ia.ia.size, deserializedIA.ia.size) + assertEquals(ia.ia[0], deserializedIA.ia[0]) + assertEquals(ia.ia[1], deserializedIA.ia[1]) + assertEquals(ia.ia[2], deserializedIA.ia[2]) + } + + @Suppress("PLATFORM_CLASS_MAPPED_TO_KOTLIN") + @Test + fun testArrayOfInteger() { + class IA(val ia: Array) + + val ia = IA(arrayOf(Integer(1), Integer(2), Integer(3))) + + assertEquals("class [Ljava.lang.Integer;", ia.ia::class.java.toString()) + assertEquals(SerializerFactory.nameForType(ia.ia::class.java), "int[]") + + val serialisedIA = TestSerializationOutput(VERBOSE, sf1).serialize(ia) + val deserializedIA = DeserializationInput(sf1).deserialize(serialisedIA) + + assertEquals(ia.ia.size, deserializedIA.ia.size) + assertEquals(ia.ia[0], deserializedIA.ia[0]) + assertEquals(ia.ia[1], deserializedIA.ia[1]) + assertEquals(ia.ia[2], deserializedIA.ia[2]) + } + + /** + * Test unboxed primitives + */ + @Test + fun testIntArray() { + class IA(val ia: IntArray) + + val v = IntArray(3) + v[0] = 1; v[1] = 2; v[2] = 3 + val ia = IA(v) + + assertEquals("class [I", ia.ia::class.java.toString()) + assertEquals(SerializerFactory.nameForType(ia.ia::class.java), "int[p]") + + val serialisedIA = TestSerializationOutput(VERBOSE, sf1).serialize(ia) + val deserializedIA = DeserializationInput(sf1).deserialize(serialisedIA) + + assertEquals(ia.ia.size, deserializedIA.ia.size) + assertEquals(ia.ia[0], deserializedIA.ia[0]) + assertEquals(ia.ia[1], deserializedIA.ia[1]) + assertEquals(ia.ia[2], deserializedIA.ia[2]) + } + + @Test + fun testArrayOfChars() { + class C(val c: Array) + + val c = C(arrayOf('a', 'b', 'c')) + + assertEquals("class [Ljava.lang.Character;", c.c::class.java.toString()) + assertEquals(SerializerFactory.nameForType(c.c::class.java), "char[]") + + val serialisedC = TestSerializationOutput(VERBOSE, sf1).serialize(c) + val deserializedC = DeserializationInput(sf1).deserialize(serialisedC) + + assertEquals(c.c.size, deserializedC.c.size) + assertEquals(c.c[0], deserializedC.c[0]) + assertEquals(c.c[1], deserializedC.c[1]) + assertEquals(c.c[2], deserializedC.c[2]) + } + + @Test + fun testCharArray() { + class C(val c: CharArray) + + val v = CharArray(3) + v[0] = 'a'; v[1] = 'b'; v[2] = 'c' + val c = C(v) + + assertEquals("class [C", c.c::class.java.toString()) + assertEquals(SerializerFactory.nameForType(c.c::class.java), "char[p]") + + val serialisedC = TestSerializationOutput(VERBOSE, sf1).serialize(c) + var deserializedC = DeserializationInput(sf1).deserialize(serialisedC) + + assertEquals(c.c.size, deserializedC.c.size) + assertEquals(c.c[0], deserializedC.c[0]) + assertEquals(c.c[1], deserializedC.c[1]) + assertEquals(c.c[2], deserializedC.c[2]) + + // second test with more interesting characters + v[0] = 'ই'; v[1] = ' '; v[2] = 'ਔ' + val c2 = C(v) + + deserializedC = DeserializationInput(sf1).deserialize(TestSerializationOutput(VERBOSE, sf1).serialize(c2)) + + assertEquals(c2.c.size, deserializedC.c.size) + assertEquals(c2.c[0], deserializedC.c[0]) + assertEquals(c2.c[1], deserializedC.c[1]) + assertEquals(c2.c[2], deserializedC.c[2]) + } + + @Test + fun testArrayOfBoolean() { + class C(val c: Array) + + val c = C(arrayOf(true, false, false, true)) + + assertEquals("class [Ljava.lang.Boolean;", c.c::class.java.toString()) + assertEquals(SerializerFactory.nameForType(c.c::class.java), "boolean[]") + + val serialisedC = TestSerializationOutput(VERBOSE, sf1).serialize(c) + val deserializedC = DeserializationInput(sf1).deserialize(serialisedC) + + assertEquals(c.c.size, deserializedC.c.size) + assertEquals(c.c[0], deserializedC.c[0]) + assertEquals(c.c[1], deserializedC.c[1]) + assertEquals(c.c[2], deserializedC.c[2]) + assertEquals(c.c[3], deserializedC.c[3]) + } + + @Test + fun testBooleanArray() { + class C(val c: BooleanArray) + + val c = C(BooleanArray(4)) + c.c[0] = true; c.c[1] = false; c.c[2] = false; c.c[3] = true + + assertEquals("class [Z", c.c::class.java.toString()) + assertEquals(SerializerFactory.nameForType(c.c::class.java), "boolean[p]") + + val serialisedC = TestSerializationOutput(VERBOSE, sf1).serialize(c) + val deserializedC = DeserializationInput(sf1).deserialize(serialisedC) + + assertEquals(c.c.size, deserializedC.c.size) + assertEquals(c.c[0], deserializedC.c[0]) + assertEquals(c.c[1], deserializedC.c[1]) + assertEquals(c.c[2], deserializedC.c[2]) + assertEquals(c.c[3], deserializedC.c[3]) + } + + @Test + fun testArrayOfByte() { + class C(val c: Array) + + val c = C(arrayOf(0b0001, 0b0101, 0b1111)) + + assertEquals("class [Ljava.lang.Byte;", c.c::class.java.toString()) + assertEquals(SerializerFactory.nameForType(c.c::class.java), "byte[]") + + val serialisedC = TestSerializationOutput(VERBOSE, sf1).serialize(c) + val deserializedC = DeserializationInput(sf1).deserialize(serialisedC) + + assertEquals(c.c.size, deserializedC.c.size) + assertEquals(c.c[0], deserializedC.c[0]) + assertEquals(c.c[1], deserializedC.c[1]) + assertEquals(c.c[2], deserializedC.c[2]) + } + + @Test + fun testByteArray() { + class C(val c: ByteArray) + + val c = C(ByteArray(3)) + c.c[0] = 0b0001; c.c[1] = 0b0101; c.c[2] = 0b1111 + + assertEquals("class [B", c.c::class.java.toString()) + assertEquals(SerializerFactory.nameForType(c.c::class.java), "binary") + + val serialisedC = TestSerializationOutput(VERBOSE, sf1).serialize(c) + val deserializedC = DeserializationInput(sf1).deserialize(serialisedC) + + assertEquals(c.c.size, deserializedC.c.size) + assertEquals(c.c[0], deserializedC.c[0]) + assertEquals(c.c[1], deserializedC.c[1]) + assertEquals(c.c[2], deserializedC.c[2]) + } + + @Test + fun testArrayOfShort() { + class C(val c: Array) + + val c = C(arrayOf(1, 2, 3)) + + assertEquals("class [Ljava.lang.Short;", c.c::class.java.toString()) + assertEquals(SerializerFactory.nameForType(c.c::class.java), "short[]") + + val serialisedC = TestSerializationOutput(VERBOSE, sf1).serialize(c) + val deserializedC = DeserializationInput(sf1).deserialize(serialisedC) + + assertEquals(c.c.size, deserializedC.c.size) + assertEquals(c.c[0], deserializedC.c[0]) + assertEquals(c.c[1], deserializedC.c[1]) + assertEquals(c.c[2], deserializedC.c[2]) + } + + @Test + fun testShortArray() { + class C(val c: ShortArray) + + val c = C(ShortArray(3)) + c.c[0] = 1; c.c[1] = 2; c.c[2] = 5 + + assertEquals("class [S", c.c::class.java.toString()) + assertEquals(SerializerFactory.nameForType(c.c::class.java), "short[p]") + + val serialisedC = TestSerializationOutput(VERBOSE, sf1).serialize(c) + val deserializedC = DeserializationInput(sf1).deserialize(serialisedC) + + assertEquals(c.c.size, deserializedC.c.size) + assertEquals(c.c[0], deserializedC.c[0]) + assertEquals(c.c[1], deserializedC.c[1]) + assertEquals(c.c[2], deserializedC.c[2]) + } + + @Test + fun testArrayOfLong() { + class C(val c: Array) + + val c = C(arrayOf(2147483650, -2147483800, 10)) + + assertEquals("class [Ljava.lang.Long;", c.c::class.java.toString()) + assertEquals(SerializerFactory.nameForType(c.c::class.java), "long[]") + + val serialisedC = TestSerializationOutput(VERBOSE, sf1).serialize(c) + val deserializedC = DeserializationInput(sf1).deserialize(serialisedC) + + assertEquals(c.c.size, deserializedC.c.size) + assertEquals(c.c[0], deserializedC.c[0]) + assertEquals(c.c[1], deserializedC.c[1]) + assertEquals(c.c[2], deserializedC.c[2]) + } + + @Test + fun testLongArray() { + class C(val c: LongArray) + + val c = C(LongArray(3)) + c.c[0] = 2147483650; c.c[1] = -2147483800; c.c[2] = 10 + + assertEquals("class [J", c.c::class.java.toString()) + assertEquals(SerializerFactory.nameForType(c.c::class.java), "long[p]") + + val serialisedC = TestSerializationOutput(VERBOSE, sf1).serialize(c) + val deserializedC = DeserializationInput(sf1).deserialize(serialisedC) + + assertEquals(c.c.size, deserializedC.c.size) + assertEquals(c.c[0], deserializedC.c[0]) + assertEquals(c.c[1], deserializedC.c[1]) + assertEquals(c.c[2], deserializedC.c[2]) + } + + @Test + fun testArrayOfFloat() { + class C(val c: Array) + + val c = C(arrayOf(10F, 100.023232F, -1455.433400F)) + + assertEquals("class [Ljava.lang.Float;", c.c::class.java.toString()) + assertEquals(SerializerFactory.nameForType(c.c::class.java), "float[]") + + val serialisedC = TestSerializationOutput(VERBOSE, sf1).serialize(c) + val deserializedC = DeserializationInput(sf1).deserialize(serialisedC) + + assertEquals(c.c.size, deserializedC.c.size) + assertEquals(c.c[0], deserializedC.c[0]) + assertEquals(c.c[1], deserializedC.c[1]) + assertEquals(c.c[2], deserializedC.c[2]) + } + + @Test + fun testFloatArray() { + class C(val c: FloatArray) + + val c = C(FloatArray(3)) + c.c[0] = 10F; c.c[1] = 100.023232F; c.c[2] = -1455.433400F + + assertEquals("class [F", c.c::class.java.toString()) + assertEquals(SerializerFactory.nameForType(c.c::class.java), "float[p]") + + val serialisedC = TestSerializationOutput(VERBOSE, sf1).serialize(c) + val deserializedC = DeserializationInput(sf1).deserialize(serialisedC) + + assertEquals(c.c.size, deserializedC.c.size) + assertEquals(c.c[0], deserializedC.c[0]) + assertEquals(c.c[1], deserializedC.c[1]) + assertEquals(c.c[2], deserializedC.c[2]) + } + + @Test + fun testArrayOfDouble() { + class C(val c: Array) + + val c = C(arrayOf(10.0, 100.2, -1455.2)) + + assertEquals("class [Ljava.lang.Double;", c.c::class.java.toString()) + assertEquals(SerializerFactory.nameForType(c.c::class.java), "double[]") + + val serialisedC = TestSerializationOutput(VERBOSE, sf1).serialize(c) + val deserializedC = DeserializationInput(sf1).deserialize(serialisedC) + + assertEquals(c.c.size, deserializedC.c.size) + assertEquals(c.c[0], deserializedC.c[0]) + assertEquals(c.c[1], deserializedC.c[1]) + assertEquals(c.c[2], deserializedC.c[2]) + } + + @Test + fun testDoubleArray() { + class C(val c: DoubleArray) + + val c = C(DoubleArray(3)) + c.c[0] = 10.0; c.c[1] = 100.2; c.c[2] = -1455.2 + + assertEquals("class [D", c.c::class.java.toString()) + assertEquals(SerializerFactory.nameForType(c.c::class.java), "double[p]") + + val serialisedC = TestSerializationOutput(VERBOSE, sf1).serialize(c) + val deserializedC = DeserializationInput(sf1).deserialize(serialisedC) + + assertEquals(c.c.size, deserializedC.c.size) + assertEquals(c.c[0], deserializedC.c[0]) + assertEquals(c.c[1], deserializedC.c[1]) + assertEquals(c.c[2], deserializedC.c[2]) + } + + @Test + fun arrayOfArrayOfInt() { + class C(val c: Array>) + val c = C (arrayOf (arrayOf(1,2,3), arrayOf(4,5,6))) + + val serialisedC = TestSerializationOutput(VERBOSE, sf1).serialize(c) + val deserializedC = DeserializationInput(sf1).deserialize(serialisedC) + + assertEquals(c.c.size, deserializedC.c.size) + assertEquals(c.c[0].size, deserializedC.c[0].size) + assertEquals(c.c[0][0], deserializedC.c[0][0]) + assertEquals(c.c[0][1], deserializedC.c[0][1]) + assertEquals(c.c[0][2], deserializedC.c[0][2]) + assertEquals(c.c[1].size, deserializedC.c[1].size) + assertEquals(c.c[1][0], deserializedC.c[1][0]) + assertEquals(c.c[1][1], deserializedC.c[1][1]) + assertEquals(c.c[1][2], deserializedC.c[1][2]) + } + + @Test + fun arrayOfIntArray() { + class C(val c: Array) + val c = C (arrayOf (IntArray(3), IntArray(3))) + c.c[0][0] = 1; c.c[0][1] = 2; c.c[0][2] = 3 + c.c[1][0] = 4; c.c[1][1] = 5; c.c[1][2] = 6 + + val serialisedC = TestSerializationOutput(VERBOSE, sf1).serialize(c) + val deserializedC = DeserializationInput(sf1).deserialize(serialisedC) + + assertEquals(c.c.size, deserializedC.c.size) + assertEquals(c.c[0].size, deserializedC.c[0].size) + assertEquals(c.c[0][0], deserializedC.c[0][0]) + assertEquals(c.c[0][1], deserializedC.c[0][1]) + assertEquals(c.c[0][2], deserializedC.c[0][2]) + assertEquals(c.c[1].size, deserializedC.c[1].size) + assertEquals(c.c[1][0], deserializedC.c[1][0]) + assertEquals(c.c[1][1], deserializedC.c[1][1]) + assertEquals(c.c[1][2], deserializedC.c[1][2]) + } + + @Test + fun arrayOfArrayOfIntArray() { + class C(val c: Array>) + + val c = C(arrayOf(arrayOf(IntArray(3), IntArray(3), IntArray(3)), + arrayOf(IntArray(3), IntArray(3), IntArray(3)), + arrayOf(IntArray(3), IntArray(3), IntArray(3)))) + + for (i in 0..2) { for (j in 0..2) { for (k in 0..2) { c.c[i][j][k] = i + j + k } } } + + val serialisedC = TestSerializationOutput(VERBOSE, sf1).serialize(c) + val deserializedC = DeserializationInput(sf1).deserialize(serialisedC) + + for (i in 0..2) { for (j in 0..2) { for (k in 0..2) { + assertEquals(c.c[i][j][k], deserializedC.c[i][j][k]) + }}} + } + + @Test + fun nestedRepeatedTypes() { + class A(val a : A?, val b: Int) + + var a = A(A(A(A(A(null, 1), 2), 3), 4), 5) + + val sa = TestSerializationOutput(VERBOSE, sf1).serialize(a) + val da1 = DeserializationInput(sf1).deserialize(sa) + val da2 = DeserializationInput(sf2).deserialize(sa) + + assertEquals(5, da1.b) + assertEquals(4, da1.a?.b) + assertEquals(3, da1.a?.a?.b) + assertEquals(2, da1.a?.a?.a?.b) + assertEquals(1, da1.a?.a?.a?.a?.b) + + assertEquals(5, da2.b) + assertEquals(4, da2.a?.b) + assertEquals(3, da2.a?.a?.b) + assertEquals(2, da2.a?.a?.a?.b) + assertEquals(1, da2.a?.a?.a?.a?.b) + + } +} + diff --git a/core/src/test/kotlin/net/corda/core/serialization/amqp/DeserializedParameterizedTypeTests.kt b/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/amqp/DeserializedParameterizedTypeTests.kt similarity index 98% rename from core/src/test/kotlin/net/corda/core/serialization/amqp/DeserializedParameterizedTypeTests.kt rename to node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/amqp/DeserializedParameterizedTypeTests.kt index 35b9f14236..469127061d 100644 --- a/core/src/test/kotlin/net/corda/core/serialization/amqp/DeserializedParameterizedTypeTests.kt +++ b/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/amqp/DeserializedParameterizedTypeTests.kt @@ -1,4 +1,4 @@ -package net.corda.core.serialization.amqp +package net.corda.nodeapi.internal.serialization.amqp import org.junit.Test import java.io.NotSerializableException diff --git a/core/src/test/java/net/corda/core/serialization/amqp/JavaSerializationOutputTests.java b/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/amqp/JavaSerializationOutputTests.java similarity index 99% rename from core/src/test/java/net/corda/core/serialization/amqp/JavaSerializationOutputTests.java rename to node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/amqp/JavaSerializationOutputTests.java index fe9e9f07a1..afb558ebdb 100644 --- a/core/src/test/java/net/corda/core/serialization/amqp/JavaSerializationOutputTests.java +++ b/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/amqp/JavaSerializationOutputTests.java @@ -1,4 +1,4 @@ -package net.corda.core.serialization.amqp; +package net.corda.nodeapi.internal.serialization.amqp; import net.corda.core.serialization.SerializedBytes; import org.apache.qpid.proton.codec.DecoderImpl; diff --git a/core/src/test/kotlin/net/corda/core/serialization/amqp/SerializationOutputTests.kt b/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/amqp/SerializationOutputTests.kt similarity index 80% rename from core/src/test/kotlin/net/corda/core/serialization/amqp/SerializationOutputTests.kt rename to node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/amqp/SerializationOutputTests.kt index 54771ac805..53bdaaafe1 100644 --- a/core/src/test/kotlin/net/corda/core/serialization/amqp/SerializationOutputTests.kt +++ b/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/amqp/SerializationOutputTests.kt @@ -1,16 +1,23 @@ -package net.corda.core.serialization.amqp +package net.corda.nodeapi.internal.serialization.amqp -import net.corda.core.contracts.* +import net.corda.core.CordaRuntimeException +import net.corda.core.contracts.Contract +import net.corda.core.contracts.ContractState +import net.corda.core.contracts.StateRef +import net.corda.core.contracts.TransactionState import net.corda.core.crypto.SecureHash import net.corda.core.flows.FlowException import net.corda.core.identity.AbstractParty import net.corda.core.serialization.CordaSerializable -import net.corda.core.serialization.EmptyWhitelist -import net.corda.core.serialization.KryoAMQPSerializer -import net.corda.core.CordaRuntimeException +import net.corda.core.transactions.LedgerTransaction import net.corda.nodeapi.RPCException +import net.corda.nodeapi.internal.serialization.AbstractAMQPSerializationScheme +import net.corda.nodeapi.internal.serialization.EmptyWhitelist +import net.corda.nodeapi.internal.serialization.amqp.SerializerFactory.Companion.isPrimitive +import net.corda.nodeapi.internal.serialization.amqp.custom.* import net.corda.testing.MEGA_CORP import net.corda.testing.MEGA_CORP_PUBKEY +import org.apache.qpid.proton.amqp.* import org.apache.qpid.proton.codec.DecoderImpl import org.apache.qpid.proton.codec.EncoderImpl import org.junit.Test @@ -27,6 +34,14 @@ import kotlin.test.assertTrue class SerializationOutputTests { data class Foo(val bar: String, val pub: Int) + data class testFloat(val f: Float) + + data class testDouble(val d: Double) + + data class testShort(val s: Short) + + data class testBoolean(val b : Boolean) + interface FooInterface { val pub: Int } @@ -142,7 +157,6 @@ class SerializationOutputTests { // Check that a vanilla AMQP decoder can deserialize without schema. val result = decoder.readObject() as Envelope assertNotNull(result) - println(result.schema) val des = DeserializationInput(freshDeserializationFactory) val desObj = des.deserialize(bytes) @@ -159,12 +173,61 @@ class SerializationOutputTests { return desObj2 } + @Test + fun isPrimitive() { + assertTrue(isPrimitive(Character::class.java)) + assertTrue(isPrimitive(Boolean::class.java)) + assertTrue(isPrimitive(Byte::class.java)) + assertTrue(isPrimitive(UnsignedByte::class.java)) + assertTrue(isPrimitive(Short::class.java)) + assertTrue(isPrimitive(UnsignedShort::class.java)) + assertTrue(isPrimitive(Int::class.java)) + assertTrue(isPrimitive(UnsignedInteger::class.java)) + assertTrue(isPrimitive(Long::class.java)) + assertTrue(isPrimitive(UnsignedLong::class.java)) + assertTrue(isPrimitive(Float::class.java)) + assertTrue(isPrimitive(Double::class.java)) + assertTrue(isPrimitive(Decimal32::class.java)) + assertTrue(isPrimitive(Decimal64::class.java)) + assertTrue(isPrimitive(Decimal128::class.java)) + assertTrue(isPrimitive(Char::class.java)) + assertTrue(isPrimitive(Date::class.java)) + assertTrue(isPrimitive(UUID::class.java)) + assertTrue(isPrimitive(ByteArray::class.java)) + assertTrue(isPrimitive(String::class.java)) + assertTrue(isPrimitive(Symbol::class.java)) + } + @Test fun `test foo`() { val obj = Foo("Hello World!", 123) serdes(obj) } + @Test + fun `test float`() { + val obj = testFloat(10.0F) + serdes(obj) + } + + @Test + fun `test double`() { + val obj = testDouble(10.0) + serdes(obj) + } + + @Test + fun `test short`() { + val obj = testShort(1) + serdes(obj) + } + + @Test + fun `test bool`() { + val obj = testBoolean(true) + serdes(obj) + } + @Test fun `test foo implements`() { val obj = FooImplements("Hello World!", 123) @@ -177,7 +240,7 @@ class SerializationOutputTests { serdes(obj) } - @Test(expected = NotSerializableException::class) + @Test(expected = IllegalArgumentException::class) fun `test dislike of HashMap`() { val obj = WrapHashMap(HashMap()) serdes(obj) @@ -325,9 +388,9 @@ class SerializationOutputTests { @Test fun `test custom serializers on public key`() { val factory = SerializerFactory() - factory.register(net.corda.core.serialization.amqp.custom.PublicKeySerializer) + factory.register(PublicKeySerializer) val factory2 = SerializerFactory() - factory2.register(net.corda.core.serialization.amqp.custom.PublicKeySerializer) + factory2.register(PublicKeySerializer) val obj = MEGA_CORP_PUBKEY serdes(obj, factory, factory2) } @@ -341,10 +404,10 @@ class SerializationOutputTests { @Test fun `test throwables serialize`() { val factory = SerializerFactory() - factory.register(net.corda.core.serialization.amqp.custom.ThrowableSerializer(factory)) + factory.register(ThrowableSerializer(factory)) val factory2 = SerializerFactory() - factory2.register(net.corda.core.serialization.amqp.custom.ThrowableSerializer(factory2)) + factory2.register(ThrowableSerializer(factory2)) val t = IllegalAccessException("message").fillInStackTrace() val desThrowable = serdes(t, factory, factory2, false) as Throwable @@ -354,10 +417,10 @@ class SerializationOutputTests { @Test fun `test complex throwables serialize`() { val factory = SerializerFactory() - factory.register(net.corda.core.serialization.amqp.custom.ThrowableSerializer(factory)) + factory.register(ThrowableSerializer(factory)) val factory2 = SerializerFactory() - factory2.register(net.corda.core.serialization.amqp.custom.ThrowableSerializer(factory2)) + factory2.register(ThrowableSerializer(factory2)) try { try { @@ -385,10 +448,10 @@ class SerializationOutputTests { @Test fun `test suppressed throwables serialize`() { val factory = SerializerFactory() - factory.register(net.corda.core.serialization.amqp.custom.ThrowableSerializer(factory)) + factory.register(ThrowableSerializer(factory)) val factory2 = SerializerFactory() - factory2.register(net.corda.core.serialization.amqp.custom.ThrowableSerializer(factory2)) + factory2.register(ThrowableSerializer(factory2)) try { try { @@ -407,10 +470,10 @@ class SerializationOutputTests { @Test fun `test flow corda exception subclasses serialize`() { val factory = SerializerFactory() - factory.register(net.corda.core.serialization.amqp.custom.ThrowableSerializer(factory)) + factory.register(ThrowableSerializer(factory)) val factory2 = SerializerFactory() - factory2.register(net.corda.core.serialization.amqp.custom.ThrowableSerializer(factory2)) + factory2.register(ThrowableSerializer(factory2)) val obj = FlowException("message").fillInStackTrace() serdes(obj, factory, factory2) @@ -419,10 +482,10 @@ class SerializationOutputTests { @Test fun `test RPC corda exception subclasses serialize`() { val factory = SerializerFactory() - factory.register(net.corda.core.serialization.amqp.custom.ThrowableSerializer(factory)) + factory.register(ThrowableSerializer(factory)) val factory2 = SerializerFactory() - factory2.register(net.corda.core.serialization.amqp.custom.ThrowableSerializer(factory2)) + factory2.register(ThrowableSerializer(factory2)) val obj = RPCException("message").fillInStackTrace() serdes(obj, factory, factory2) @@ -446,7 +509,7 @@ class SerializationOutputTests { } object FooContract : Contract { - override fun verify(tx: TransactionForContract) { + override fun verify(tx: LedgerTransaction) { } @@ -465,10 +528,10 @@ class SerializationOutputTests { val state = TransactionState(FooState(), MEGA_CORP) val factory = SerializerFactory() - KryoAMQPSerializer.registerCustomSerializers(factory) + AbstractAMQPSerializationScheme.registerCustomSerializers(factory) val factory2 = SerializerFactory() - KryoAMQPSerializer.registerCustomSerializers(factory2) + AbstractAMQPSerializationScheme.registerCustomSerializers(factory2) val desState = serdes(state, factory, factory2, expectedEqual = false, expectDeserializedEqual = false) assertTrue(desState is TransactionState<*>) @@ -480,10 +543,10 @@ class SerializationOutputTests { @Test fun `test currencies serialize`() { val factory = SerializerFactory() - factory.register(net.corda.core.serialization.amqp.custom.CurrencySerializer) + factory.register(CurrencySerializer) val factory2 = SerializerFactory() - factory2.register(net.corda.core.serialization.amqp.custom.CurrencySerializer) + factory2.register(CurrencySerializer) val obj = Currency.getInstance("USD") serdes(obj, factory, factory2) @@ -492,10 +555,10 @@ class SerializationOutputTests { @Test fun `test big decimals serialize`() { val factory = SerializerFactory() - factory.register(net.corda.core.serialization.amqp.custom.BigDecimalSerializer) + factory.register(BigDecimalSerializer) val factory2 = SerializerFactory() - factory2.register(net.corda.core.serialization.amqp.custom.BigDecimalSerializer) + factory2.register(BigDecimalSerializer) val obj = BigDecimal("100000000000000000000000000000.00") serdes(obj, factory, factory2) @@ -504,10 +567,10 @@ class SerializationOutputTests { @Test fun `test instants serialize`() { val factory = SerializerFactory() - factory.register(net.corda.core.serialization.amqp.custom.InstantSerializer(factory)) + factory.register(InstantSerializer(factory)) val factory2 = SerializerFactory() - factory2.register(net.corda.core.serialization.amqp.custom.InstantSerializer(factory2)) + factory2.register(InstantSerializer(factory2)) val obj = Instant.now() serdes(obj, factory, factory2) @@ -516,12 +579,12 @@ class SerializationOutputTests { @Test fun `test StateRef serialize`() { val factory = SerializerFactory() - factory.register(net.corda.core.serialization.amqp.custom.InstantSerializer(factory)) + factory.register(InstantSerializer(factory)) val factory2 = SerializerFactory() - factory2.register(net.corda.core.serialization.amqp.custom.InstantSerializer(factory2)) + factory2.register(InstantSerializer(factory2)) val obj = StateRef(SecureHash.randomSHA256(), 0) serdes(obj, factory, factory2) } -} \ No newline at end of file +} diff --git a/core/src/test/kotlin/net/corda/core/serialization/carpenter/ClassCarpenterTest.kt b/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/carpenter/ClassCarpenterTest.kt similarity index 68% rename from core/src/test/kotlin/net/corda/core/serialization/carpenter/ClassCarpenterTest.kt rename to node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/carpenter/ClassCarpenterTest.kt index 6dc4f5b12a..9ffef1ebd1 100644 --- a/core/src/test/kotlin/net/corda/core/serialization/carpenter/ClassCarpenterTest.kt +++ b/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/carpenter/ClassCarpenterTest.kt @@ -1,12 +1,13 @@ -package net.corda.core.serialization.carpenter - +package net.corda.nodeapi.internal.serialization.carpenter import org.junit.Test +import java.beans.Introspector import java.lang.reflect.Field import java.lang.reflect.Method +import javax.annotation.Nonnull +import javax.annotation.Nullable import kotlin.test.assertEquals -import kotlin.test.assertTrue - +import kotlin.test.assertNotEquals class ClassCarpenterTest { interface DummyInterface { @@ -23,7 +24,7 @@ class ClassCarpenterTest { @Test fun empty() { - val clazz = cc.build(ClassCarpenter.ClassSchema("gen.EmptyClass", emptyMap(), null)) + val clazz = cc.build(ClassSchema("gen.EmptyClass", emptyMap(), null)) assertEquals(0, clazz.nonSyntheticFields.size) assertEquals(2, clazz.nonSyntheticMethods.size) // get, toString assertEquals(0, clazz.declaredConstructors[0].parameterCount) @@ -32,7 +33,7 @@ class ClassCarpenterTest { @Test fun prims() { - val clazz = cc.build(ClassCarpenter.ClassSchema( + val clazz = cc.build(ClassSchema( "gen.Prims", mapOf( "anIntField" to Int::class.javaPrimitiveType!!, @@ -43,7 +44,7 @@ class ClassCarpenterTest { "floatMyBoat" to Float::class.javaPrimitiveType!!, "byteMe" to Byte::class.javaPrimitiveType!!, "booleanField" to Boolean::class.javaPrimitiveType!!).mapValues { - ClassCarpenter.NonNullableField (it.value) + NonNullableField(it.value) })) assertEquals(8, clazz.nonSyntheticFields.size) assertEquals(10, clazz.nonSyntheticMethods.size) @@ -70,10 +71,10 @@ class ClassCarpenterTest { } private fun genPerson(): Pair, Any> { - val clazz = cc.build(ClassCarpenter.ClassSchema("gen.Person", mapOf( + val clazz = cc.build(ClassSchema("gen.Person", mapOf( "age" to Int::class.javaPrimitiveType!!, "name" to String::class.java - ).mapValues { ClassCarpenter.NonNullableField (it.value) } )) + ).mapValues { NonNullableField(it.value) })) val i = clazz.constructors[0].newInstance(32, "Mike") return Pair(clazz, i) } @@ -91,17 +92,17 @@ class ClassCarpenterTest { assertEquals("Person{age=32, name=Mike}", i.toString()) } - @Test(expected = ClassCarpenter.DuplicateNameException::class) + @Test(expected = DuplicateNameException::class) fun duplicates() { - cc.build(ClassCarpenter.ClassSchema("gen.EmptyClass", emptyMap(), null)) - cc.build(ClassCarpenter.ClassSchema("gen.EmptyClass", emptyMap(), null)) + cc.build(ClassSchema("gen.EmptyClass", emptyMap(), null)) + cc.build(ClassSchema("gen.EmptyClass", emptyMap(), null)) } @Test fun `can refer to each other`() { val (clazz1, i) = genPerson() - val clazz2 = cc.build(ClassCarpenter.ClassSchema("gen.Referee", mapOf( - "ref" to ClassCarpenter.NonNullableField (clazz1) + val clazz2 = cc.build(ClassSchema("gen.Referee", mapOf( + "ref" to NonNullableField(clazz1) ))) val i2 = clazz2.constructors[0].newInstance(i) assertEquals(i, (i2 as SimpleFieldAccess)["ref"]) @@ -109,13 +110,13 @@ class ClassCarpenterTest { @Test fun superclasses() { - val schema1 = ClassCarpenter.ClassSchema( + val schema1 = ClassSchema( "gen.A", - mapOf("a" to ClassCarpenter.NonNullableField (String::class.java))) + mapOf("a" to NonNullableField(String::class.java))) - val schema2 = ClassCarpenter.ClassSchema( + val schema2 = ClassSchema( "gen.B", - mapOf("b" to ClassCarpenter.NonNullableField (String::class.java)), + mapOf("b" to NonNullableField(String::class.java)), schema1) val clazz = cc.build(schema2) @@ -127,29 +128,30 @@ class ClassCarpenterTest { @Test fun interfaces() { - val schema1 = ClassCarpenter.ClassSchema( + val schema1 = ClassSchema( "gen.A", - mapOf("a" to ClassCarpenter.NonNullableField(String::class.java))) + mapOf("a" to NonNullableField(String::class.java))) - val schema2 = ClassCarpenter.ClassSchema("gen.B", - mapOf("b" to ClassCarpenter.NonNullableField(Int::class.java)), + val schema2 = ClassSchema("gen.B", + mapOf("b" to NonNullableField(Int::class.java)), schema1, interfaces = listOf(DummyInterface::class.java)) + val clazz = cc.build(schema2) val i = clazz.constructors[0].newInstance("xa", 1) as DummyInterface assertEquals("xa", i.a) assertEquals(1, i.b) } - @Test(expected = ClassCarpenter.InterfaceMismatchException::class) + @Test(expected = InterfaceMismatchException::class) fun `mismatched interface`() { - val schema1 = ClassCarpenter.ClassSchema( + val schema1 = ClassSchema( "gen.A", - mapOf("a" to ClassCarpenter.NonNullableField(String::class.java))) + mapOf("a" to NonNullableField(String::class.java))) - val schema2 = ClassCarpenter.ClassSchema( + val schema2 = ClassSchema( "gen.B", - mapOf("c" to ClassCarpenter.NonNullableField(Int::class.java)), + mapOf("c" to NonNullableField(Int::class.java)), schema1, interfaces = listOf(DummyInterface::class.java)) @@ -160,9 +162,9 @@ class ClassCarpenterTest { @Test fun `generate interface`() { - val schema1 = ClassCarpenter.InterfaceSchema( + val schema1 = InterfaceSchema( "gen.Interface", - mapOf("a" to ClassCarpenter.NonNullableField (Int::class.java))) + mapOf("a" to NonNullableField(Int::class.java))) val iface = cc.build(schema1) @@ -171,9 +173,9 @@ class ClassCarpenterTest { assertEquals(iface.declaredMethods.size, 1) assertEquals(iface.declaredMethods[0].name, "getA") - val schema2 = ClassCarpenter.ClassSchema( + val schema2 = ClassSchema( "gen.Derived", - mapOf("a" to ClassCarpenter.NonNullableField (Int::class.java)), + mapOf("a" to NonNullableField(Int::class.java)), interfaces = listOf(iface)) val clazz = cc.build(schema2) @@ -185,25 +187,25 @@ class ClassCarpenterTest { @Test fun `generate multiple interfaces`() { - val iFace1 = ClassCarpenter.InterfaceSchema( + val iFace1 = InterfaceSchema( "gen.Interface1", mapOf( - "a" to ClassCarpenter.NonNullableField(Int::class.java), - "b" to ClassCarpenter.NonNullableField(String::class.java))) + "a" to NonNullableField(Int::class.java), + "b" to NonNullableField(String::class.java))) - val iFace2 = ClassCarpenter.InterfaceSchema( + val iFace2 = InterfaceSchema( "gen.Interface2", mapOf( - "c" to ClassCarpenter.NonNullableField(Int::class.java), - "d" to ClassCarpenter.NonNullableField(String::class.java))) + "c" to NonNullableField(Int::class.java), + "d" to NonNullableField(String::class.java))) - val class1 = ClassCarpenter.ClassSchema( + val class1 = ClassSchema( "gen.Derived", mapOf( - "a" to ClassCarpenter.NonNullableField(Int::class.java), - "b" to ClassCarpenter.NonNullableField(String::class.java), - "c" to ClassCarpenter.NonNullableField(Int::class.java), - "d" to ClassCarpenter.NonNullableField(String::class.java)), + "a" to NonNullableField(Int::class.java), + "b" to NonNullableField(String::class.java), + "c" to NonNullableField(Int::class.java), + "d" to NonNullableField(String::class.java)), interfaces = listOf(cc.build(iFace1), cc.build(iFace2))) val clazz = cc.build(class1) @@ -221,26 +223,26 @@ class ClassCarpenterTest { @Test fun `interface implementing interface`() { - val iFace1 = ClassCarpenter.InterfaceSchema( + val iFace1 = InterfaceSchema( "gen.Interface1", mapOf( - "a" to ClassCarpenter.NonNullableField (Int::class.java), - "b" to ClassCarpenter.NonNullableField(String::class.java))) + "a" to NonNullableField(Int::class.java), + "b" to NonNullableField(String::class.java))) - val iFace2 = ClassCarpenter.InterfaceSchema( + val iFace2 = InterfaceSchema( "gen.Interface2", mapOf( - "c" to ClassCarpenter.NonNullableField(Int::class.java), - "d" to ClassCarpenter.NonNullableField(String::class.java)), + "c" to NonNullableField(Int::class.java), + "d" to NonNullableField(String::class.java)), interfaces = listOf(cc.build(iFace1))) - val class1 = ClassCarpenter.ClassSchema( + val class1 = ClassSchema( "gen.Derived", mapOf( - "a" to ClassCarpenter.NonNullableField(Int::class.java), - "b" to ClassCarpenter.NonNullableField(String::class.java), - "c" to ClassCarpenter.NonNullableField(Int::class.java), - "d" to ClassCarpenter.NonNullableField(String::class.java)), + "a" to NonNullableField(Int::class.java), + "b" to NonNullableField(String::class.java), + "c" to NonNullableField(Int::class.java), + "d" to NonNullableField(String::class.java)), interfaces = listOf(cc.build(iFace2))) val clazz = cc.build(class1) @@ -259,22 +261,21 @@ class ClassCarpenterTest { @Test(expected = java.lang.IllegalArgumentException::class) fun `null parameter small int`() { val className = "iEnjoySwede" - val schema = ClassCarpenter.ClassSchema( + val schema = ClassSchema( "gen.$className", - mapOf("a" to ClassCarpenter.NonNullableField (Int::class.java))) + mapOf("a" to NonNullableField(Int::class.java))) val clazz = cc.build(schema) - val a : Int? = null clazz.constructors[0].newInstance(a) } - @Test(expected = ClassCarpenter.NullablePrimitiveException::class) + @Test(expected = NullablePrimitiveException::class) fun `nullable parameter small int`() { val className = "iEnjoySwede" - val schema = ClassCarpenter.ClassSchema( + val schema = ClassSchema( "gen.$className", - mapOf("a" to ClassCarpenter.NullableField (Int::class.java))) + mapOf("a" to NullableField(Int::class.java))) cc.build(schema) } @@ -282,9 +283,9 @@ class ClassCarpenterTest { @Test fun `nullable parameter integer`() { val className = "iEnjoyWibble" - val schema = ClassCarpenter.ClassSchema( + val schema = ClassSchema( "gen.$className", - mapOf("a" to ClassCarpenter.NullableField (Integer::class.java))) + mapOf("a" to NullableField(Integer::class.java))) val clazz = cc.build(schema) val a1 : Int? = null @@ -297,9 +298,9 @@ class ClassCarpenterTest { @Test fun `non nullable parameter integer with non null`() { val className = "iEnjoyWibble" - val schema = ClassCarpenter.ClassSchema( + val schema = ClassSchema( "gen.$className", - mapOf("a" to ClassCarpenter.NonNullableField (Integer::class.java))) + mapOf("a" to NonNullableField(Integer::class.java))) val clazz = cc.build(schema) @@ -310,9 +311,9 @@ class ClassCarpenterTest { @Test(expected = java.lang.reflect.InvocationTargetException::class) fun `non nullable parameter integer with null`() { val className = "iEnjoyWibble" - val schema = ClassCarpenter.ClassSchema( + val schema = ClassSchema( "gen.$className", - mapOf("a" to ClassCarpenter.NonNullableField (Integer::class.java))) + mapOf("a" to NonNullableField(Integer::class.java))) val clazz = cc.build(schema) @@ -324,9 +325,9 @@ class ClassCarpenterTest { @Suppress("UNCHECKED_CAST") fun `int array`() { val className = "iEnjoyPotato" - val schema = ClassCarpenter.ClassSchema( + val schema = ClassSchema( "gen.$className", - mapOf("a" to ClassCarpenter.NonNullableField(IntArray::class.java))) + mapOf("a" to NonNullableField(IntArray::class.java))) val clazz = cc.build(schema) @@ -343,9 +344,9 @@ class ClassCarpenterTest { @Test(expected = java.lang.reflect.InvocationTargetException::class) fun `nullable int array throws`() { val className = "iEnjoySwede" - val schema = ClassCarpenter.ClassSchema( + val schema = ClassSchema( "gen.$className", - mapOf("a" to ClassCarpenter.NonNullableField(IntArray::class.java))) + mapOf("a" to NonNullableField(IntArray::class.java))) val clazz = cc.build(schema) @@ -357,14 +358,13 @@ class ClassCarpenterTest { @Suppress("UNCHECKED_CAST") fun `integer array`() { val className = "iEnjoyFlan" - val schema = ClassCarpenter.ClassSchema( + val schema = ClassSchema( "gen.$className", - mapOf("a" to ClassCarpenter.NonNullableField(Array::class.java))) + mapOf("a" to NonNullableField(Array::class.java))) val clazz = cc.build(schema) val i = clazz.constructors[0].newInstance(arrayOf(1, 2, 3)) as SimpleFieldAccess - val arr = clazz.getMethod("getA").invoke(i) assertEquals(1, (arr as Array)[0]) @@ -377,21 +377,19 @@ class ClassCarpenterTest { @Suppress("UNCHECKED_CAST") fun `int array with ints`() { val className = "iEnjoyCrumble" - val schema = ClassCarpenter.ClassSchema( + val schema = ClassSchema( "gen.$className", mapOf( "a" to Int::class.java, "b" to IntArray::class.java, - "c" to Int::class.java).mapValues { ClassCarpenter.NonNullableField(it.value) }) + "c" to Int::class.java).mapValues { NonNullableField(it.value) }) val clazz = cc.build(schema) - val i = clazz.constructors[0].newInstance(2, intArrayOf(4, 8), 16) as SimpleFieldAccess assertEquals(2, clazz.getMethod("getA").invoke(i)) assertEquals(4, (clazz.getMethod("getB").invoke(i) as IntArray)[0]) assertEquals(8, (clazz.getMethod("getB").invoke(i) as IntArray)[1]) assertEquals(16, clazz.getMethod("getC").invoke(i)) - assertEquals("$className{a=2, b=[4, 8], c=16}", i.toString()) } @@ -399,11 +397,11 @@ class ClassCarpenterTest { @Suppress("UNCHECKED_CAST") fun `multiple int arrays`() { val className = "iEnjoyJam" - val schema = ClassCarpenter.ClassSchema( + val schema = ClassSchema( "gen.$className", mapOf( "a" to IntArray::class.java, "b" to Int::class.java, - "c" to IntArray::class.java).mapValues { ClassCarpenter.NonNullableField(it.value) }) + "c" to IntArray::class.java).mapValues { NonNullableField(it.value) }) val clazz = cc.build(schema) val i = clazz.constructors[0].newInstance(intArrayOf(1, 2), 3, intArrayOf(4, 5, 6)) @@ -414,7 +412,6 @@ class ClassCarpenterTest { assertEquals(4, (clazz.getMethod("getC").invoke(i) as IntArray)[0]) assertEquals(5, (clazz.getMethod("getC").invoke(i) as IntArray)[1]) assertEquals(6, (clazz.getMethod("getC").invoke(i) as IntArray)[2]) - assertEquals("$className{a=[1, 2], b=3, c=[4, 5, 6]}", i.toString()) } @@ -422,9 +419,9 @@ class ClassCarpenterTest { @Suppress("UNCHECKED_CAST") fun `string array`() { val className = "iEnjoyToast" - val schema = ClassCarpenter.ClassSchema( + val schema = ClassSchema( "gen.$className", - mapOf("a" to ClassCarpenter.NullableField(Array::class.java))) + mapOf("a" to NullableField(Array::class.java))) val clazz = cc.build(schema) @@ -440,12 +437,12 @@ class ClassCarpenterTest { @Suppress("UNCHECKED_CAST") fun `string arrays`() { val className = "iEnjoyToast" - val schema = ClassCarpenter.ClassSchema( + val schema = ClassSchema( "gen.$className", mapOf( "a" to Array::class.java, "b" to String::class.java, - "c" to Array::class.java).mapValues { ClassCarpenter.NullableField (it.value) }) + "c" to Array::class.java).mapValues { NullableField(it.value) }) val clazz = cc.build(schema) @@ -454,7 +451,6 @@ class ClassCarpenterTest { "and on the side", arrayOf("some pickles", "some fries")) - val arr1 = clazz.getMethod("getA").invoke(i) as Array val arr2 = clazz.getMethod("getC").invoke(i) as Array @@ -469,26 +465,34 @@ class ClassCarpenterTest { @Test fun `nullable sets annotations`() { val className = "iEnjoyJam" - val schema = ClassCarpenter.ClassSchema( + val schema = ClassSchema( "gen.$className", - mapOf("a" to ClassCarpenter.NullableField(String::class.java), - "b" to ClassCarpenter.NonNullableField(String::class.java))) + mapOf("a" to NullableField(String::class.java), + "b" to NonNullableField(String::class.java))) val clazz = cc.build(schema) assertEquals (2, clazz.declaredFields.size) - assertEquals (1, clazz.getDeclaredField("a").annotations.size) - assertEquals (javax.annotation.Nullable::class.java, clazz.getDeclaredField("a").annotations[0].annotationClass.java) - + assertEquals(Nullable::class.java, clazz.getDeclaredField("a").annotations[0].annotationClass.java) assertEquals (1, clazz.getDeclaredField("b").annotations.size) - assertEquals (javax.annotation.Nonnull::class.java, clazz.getDeclaredField("b").annotations[0].annotationClass.java) - + assertEquals(Nonnull::class.java, clazz.getDeclaredField("b").annotations[0].annotationClass.java) assertEquals (1, clazz.getMethod("getA").annotations.size) - assertEquals (javax.annotation.Nullable::class.java, clazz.getMethod("getA").annotations[0].annotationClass.java) - + assertEquals(Nullable::class.java, clazz.getMethod("getA").annotations[0].annotationClass.java) assertEquals (1, clazz.getMethod("getB").annotations.size) - assertEquals (javax.annotation.Nonnull::class.java, clazz.getMethod("getB").annotations[0].annotationClass.java) + assertEquals(Nonnull::class.java, clazz.getMethod("getB").annotations[0].annotationClass.java) } + @Test + fun beanTest() { + val schema = ClassSchema( + "pantsPantsPants", + mapOf("a" to NonNullableField(Integer::class.java))) + val clazz = cc.build(schema) + val descriptors = Introspector.getBeanInfo(clazz).propertyDescriptors + + assertEquals(2, descriptors.size) + assertNotEquals(null, descriptors.find { it.name == "a" }) + assertNotEquals(null, descriptors.find { it.name == "class" }) + } } diff --git a/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/carpenter/ClassCarpenterTestUtils.kt b/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/carpenter/ClassCarpenterTestUtils.kt new file mode 100644 index 0000000000..3aae840918 --- /dev/null +++ b/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/carpenter/ClassCarpenterTestUtils.kt @@ -0,0 +1,46 @@ +package net.corda.nodeapi.internal.serialization.carpenter + +import net.corda.nodeapi.internal.serialization.amqp.Field +import net.corda.nodeapi.internal.serialization.amqp.Schema +import net.corda.nodeapi.internal.serialization.amqp.TypeNotation +import net.corda.nodeapi.internal.serialization.amqp.CompositeType +import net.corda.nodeapi.internal.serialization.amqp.SerializerFactory +import net.corda.nodeapi.internal.serialization.amqp.SerializationOutput + +fun mangleName(name: String) = "${name}__carpenter" + +/** + * given a list of class names work through the amqp envelope schema and alter any that + * match in the fashion defined above + */ +fun Schema.mangleNames(names: List): Schema { + val newTypes: MutableList = mutableListOf() + + for (type in types) { + val newName = if (type.name in names) mangleName(type.name) else type.name + val newProvides = type.provides.map { if (it in names) mangleName(it) else it } + val newFields = mutableListOf() + + (type as CompositeType).fields.forEach { + val fieldType = if (it.type in names) mangleName(it.type) else it.type + val requires = + if (it.requires.isNotEmpty() && (it.requires[0] in names)) listOf(mangleName(it.requires[0])) + else it.requires + + newFields.add(it.copy(type = fieldType, requires = requires)) + } + + newTypes.add(type.copy(name = newName, provides = newProvides, fields = newFields)) + } + + return Schema(types = newTypes) +} + +open class AmqpCarpenterBase { + var factory = SerializerFactory() + + fun serialise(clazz: Any) = SerializationOutput(factory).serialize(clazz) + fun testName(): String = Thread.currentThread().stackTrace[2].methodName + @Suppress("NOTHING_TO_INLINE") + inline fun classTestName(clazz: String) = "${this.javaClass.name}\$${testName()}\$$clazz" +} diff --git a/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/carpenter/CompositeMemberCompositeSchemaToClassCarpenterTests.kt b/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/carpenter/CompositeMemberCompositeSchemaToClassCarpenterTests.kt new file mode 100644 index 0000000000..5ec40e0f4e --- /dev/null +++ b/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/carpenter/CompositeMemberCompositeSchemaToClassCarpenterTests.kt @@ -0,0 +1,285 @@ +package net.corda.nodeapi.internal.serialization.carpenter + +import net.corda.core.serialization.CordaSerializable +import net.corda.nodeapi.internal.serialization.amqp.CompositeType +import net.corda.nodeapi.internal.serialization.amqp.DeserializationInput +import org.junit.Test +import kotlin.test.assertEquals +import kotlin.test.assertFalse +import kotlin.test.assertTrue + +@CordaSerializable +interface I_ { + val a: Int +} + +class CompositeMembers : AmqpCarpenterBase() { + @Test + fun bothKnown() { + val testA = 10 + val testB = 20 + + @CordaSerializable + data class A(val a: Int) + + @CordaSerializable + data class B(val a: A, var b: Int) + + val b = B(A(testA), testB) + val obj = DeserializationInput(factory).deserializeAndReturnEnvelope(serialise(b)) + + assert(obj.obj is B) + + val amqpObj = obj.obj as B + + assertEquals(testB, amqpObj.b) + assertEquals(testA, amqpObj.a.a) + assertEquals(2, obj.envelope.schema.types.size) + assert(obj.envelope.schema.types[0] is CompositeType) + assert(obj.envelope.schema.types[1] is CompositeType) + + var amqpSchemaA: CompositeType? = null + var amqpSchemaB: CompositeType? = null + + for (type in obj.envelope.schema.types) { + when (type.name.split ("$").last()) { + "A" -> amqpSchemaA = type as CompositeType + "B" -> amqpSchemaB = type as CompositeType + } + } + + assert(amqpSchemaA != null) + assert(amqpSchemaB != null) + + // Just ensure the amqp schema matches what we want before we go messing + // around with the internals + assertEquals(1, amqpSchemaA?.fields?.size) + assertEquals("a", amqpSchemaA!!.fields[0].name) + assertEquals("int", amqpSchemaA.fields[0].type) + + assertEquals(2, amqpSchemaB?.fields?.size) + assertEquals("a", amqpSchemaB!!.fields[0].name) + assertEquals(classTestName("A"), amqpSchemaB.fields[0].type) + assertEquals("b", amqpSchemaB.fields[1].name) + assertEquals("int", amqpSchemaB.fields[1].type) + + val metaSchema = obj.envelope.schema.carpenterSchema() + + // if we know all the classes there is nothing to really achieve here + assert(metaSchema.carpenterSchemas.isEmpty()) + assert(metaSchema.dependsOn.isEmpty()) + assert(metaSchema.dependencies.isEmpty()) + } + + // you cannot have an element of a composite class we know about + // that is unknown as that should be impossible. If we have the containing + // class in the class path then we must have all of it's constituent elements + @Test(expected = UncarpentableException::class) + fun nestedIsUnknown() { + val testA = 10 + val testB = 20 + + @CordaSerializable + data class A(override val a: Int) : I_ + + @CordaSerializable + data class B(val a: A, var b: Int) + + val b = B(A(testA), testB) + + val obj = DeserializationInput(factory).deserializeAndReturnEnvelope(serialise(b)) + val amqpSchema = obj.envelope.schema.mangleNames(listOf (classTestName ("A"))) + + assert(obj.obj is B) + + amqpSchema.carpenterSchema() + } + + @Test + fun ParentIsUnknown() { + val testA = 10 + val testB = 20 + + @CordaSerializable + data class A(override val a: Int) : I_ + + @CordaSerializable + data class B(val a: A, var b: Int) + + val b = B(A(testA), testB) + val obj = DeserializationInput(factory).deserializeAndReturnEnvelope(serialise(b)) + + assert(obj.obj is B) + + val amqpSchema = obj.envelope.schema.mangleNames(listOf(classTestName("B"))) + val carpenterSchema = amqpSchema.carpenterSchema() + + assertEquals(1, carpenterSchema.size) + + val metaCarpenter = MetaCarpenter(carpenterSchema) + + metaCarpenter.build() + + assert(mangleName(classTestName("B")) in metaCarpenter.objects) + } + + @Test + fun BothUnknown() { + val testA = 10 + val testB = 20 + + @CordaSerializable + data class A(override val a: Int) : I_ + + @CordaSerializable + data class B(val a: A, var b: Int) + + val b = B(A(testA), testB) + val obj = DeserializationInput(factory).deserializeAndReturnEnvelope(serialise(b)) + + assert(obj.obj is B) + + val amqpSchema = obj.envelope.schema.mangleNames(listOf(classTestName("A"), classTestName("B"))) + val carpenterSchema = amqpSchema.carpenterSchema() + + // just verify we're in the expected initial state, A is carpentable, B is not because + // it depends on A and the dependency chains are in place + assertEquals(1, carpenterSchema.size) + assertEquals(mangleName(classTestName("A")), carpenterSchema.carpenterSchemas.first().name) + assertEquals(1, carpenterSchema.dependencies.size) + assert(mangleName(classTestName("B")) in carpenterSchema.dependencies) + assertEquals(1, carpenterSchema.dependsOn.size) + assert(mangleName(classTestName("A")) in carpenterSchema.dependsOn) + + val metaCarpenter = TestMetaCarpenter(carpenterSchema) + + assertEquals(0, metaCarpenter.objects.size) + + // first iteration, carpent A, resolve deps and mark B as carpentable + metaCarpenter.build() + + // one build iteration should have carpetned up A and worked out that B is now buildable + // given it's depedencies have been satisfied + assertTrue(mangleName(classTestName("A")) in metaCarpenter.objects) + assertFalse(mangleName(classTestName("B")) in metaCarpenter.objects) + + assertEquals(1, carpenterSchema.carpenterSchemas.size) + assertEquals(mangleName(classTestName("B")), carpenterSchema.carpenterSchemas.first().name) + assertTrue(carpenterSchema.dependencies.isEmpty()) + assertTrue(carpenterSchema.dependsOn.isEmpty()) + + // second manual iteration, will carpent B + metaCarpenter.build() + assert(mangleName(classTestName("A")) in metaCarpenter.objects) + assert(mangleName(classTestName("B")) in metaCarpenter.objects) + + // and we must be finished + assertTrue(carpenterSchema.carpenterSchemas.isEmpty()) + } + + @Test(expected = UncarpentableException::class) + @Suppress("UNUSED") + fun nestedIsUnknownInherited() { + val testA = 10 + val testB = 20 + val testC = 30 + + @CordaSerializable + open class A(val a: Int) + + @CordaSerializable + class B(a: Int, var b: Int) : A(a) + + @CordaSerializable + data class C(val b: B, var c: Int) + + val c = C(B(testA, testB), testC) + val obj = DeserializationInput(factory).deserializeAndReturnEnvelope(serialise(c)) + + assert(obj.obj is C) + + val amqpSchema = obj.envelope.schema.mangleNames(listOf(classTestName("A"), classTestName("B"))) + + amqpSchema.carpenterSchema() + } + + @Test(expected = UncarpentableException::class) + @Suppress("UNUSED") + fun nestedIsUnknownInheritedUnknown() { + val testA = 10 + val testB = 20 + val testC = 30 + + @CordaSerializable + open class A(val a: Int) + + @CordaSerializable + class B(a: Int, var b: Int) : A(a) + + @CordaSerializable + data class C(val b: B, var c: Int) + + val c = C(B(testA, testB), testC) + val obj = DeserializationInput(factory).deserializeAndReturnEnvelope(serialise(c)) + + assert(obj.obj is C) + + val amqpSchema = obj.envelope.schema.mangleNames(listOf(classTestName("A"), classTestName("B"))) + + amqpSchema.carpenterSchema() + } + + @Suppress("UNUSED") + @Test(expected = UncarpentableException::class) + fun parentsIsUnknownWithUnknownInheritedMember() { + val testA = 10 + val testB = 20 + val testC = 30 + + @CordaSerializable + open class A(val a: Int) + + @CordaSerializable + class B(a: Int, var b: Int) : A(a) + + @CordaSerializable + data class C(val b: B, var c: Int) + + val c = C(B(testA, testB), testC) + val obj = DeserializationInput(factory).deserializeAndReturnEnvelope(serialise(c)) + + assert(obj.obj is C) + + val carpenterSchema = obj.envelope.schema.mangleNames(listOf(classTestName("A"), classTestName("B"))) + TestMetaCarpenter(carpenterSchema.carpenterSchema()) + } + + /* + * TODO serializer doesn't support inheritnace at the moment, when it does this should work + @Test + fun `inheritance`() { + val testA = 10 + val testB = 20 + + @CordaSerializable + open class A(open val a: Int) + + @CordaSerializable + class B(override val a: Int, val b: Int) : A (a) + + val b = B(testA, testB) + val obj = DeserializationInput(factory).deserializeAndReturnEnvelope(serialise(b)) + + assert(obj.obj is B) + + val carpenterSchema = obj.envelope.schema.mangleNames(listOf(classTestName("A"), classTestName("B"))) + val metaCarpenter = TestMetaCarpenter(carpenterSchema.carpenterSchema()) + + assertEquals(1, metaCarpenter.schemas.carpenterSchemas.size) + assertEquals(mangleNames(classTestName("B")), metaCarpenter.schemas.carpenterSchemas.first().name) + assertEquals(1, metaCarpenter.schemas.dependencies.size) + assertTrue(mangleNames(classTestName("A")) in metaCarpenter.schemas.dependencies) + } + */ +} + diff --git a/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/carpenter/InheritanceSchemaToClassCarpenterTests.kt b/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/carpenter/InheritanceSchemaToClassCarpenterTests.kt new file mode 100644 index 0000000000..8176823aa3 --- /dev/null +++ b/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/carpenter/InheritanceSchemaToClassCarpenterTests.kt @@ -0,0 +1,459 @@ +package net.corda.nodeapi.internal.serialization.carpenter + +import net.corda.core.serialization.CordaSerializable +import net.corda.nodeapi.internal.serialization.amqp.DeserializationInput +import org.junit.Test +import kotlin.test.* + +@CordaSerializable +interface J { + val j: Int +} + +@CordaSerializable +interface I { + val i: Int +} + +@CordaSerializable +interface II { + val ii: Int +} + +@CordaSerializable +interface III : I { + val iii: Int + override val i: Int +} + +@CordaSerializable +interface IIII { + val iiii: Int + val i: I +} + +class InheritanceSchemaToClassCarpenterTests : AmqpCarpenterBase() { + @Test + fun interfaceParent1() { + class A(override val j: Int) : J + + val testJ = 20 + val a = A(testJ) + + assertEquals(testJ, a.j) + val obj = DeserializationInput(factory).deserializeAndReturnEnvelope(serialise(a)) + assertTrue(obj.obj is A) + val serSchema = obj.envelope.schema + assertEquals(2, serSchema.types.size) + val l1 = serSchema.carpenterSchema() + + // since we're using an envelope generated by seilaising classes defined locally + // it's extremely unlikely we'd need to carpent any classes + assertEquals(0, l1.size) + + val mangleSchema = serSchema.mangleNames(listOf(classTestName("A"))) + val l2 = mangleSchema.carpenterSchema() + assertEquals(1, l2.size) + + val aSchema = l2.carpenterSchemas.find { it.name == mangleName(classTestName("A")) } + assertNotEquals(null, aSchema) + assertEquals(mangleName(classTestName("A")), aSchema!!.name) + assertEquals(1, aSchema.interfaces.size) + assertEquals(J::class.java, aSchema.interfaces[0]) + + val aBuilder = ClassCarpenter().build(aSchema) + val objJ = aBuilder.constructors[0].newInstance(testJ) + val j = objJ as J + + assertEquals(aBuilder.getMethod("getJ").invoke(objJ), testJ) + assertEquals(a.j, j.j) + } + + @Test + fun interfaceParent2() { + class A(override val j: Int, val jj: Int) : J + + val testJ = 20 + val testJJ = 40 + val a = A(testJ, testJJ) + + assertEquals(testJ, a.j) + assertEquals(testJJ, a.jj) + + val obj = DeserializationInput(factory).deserializeAndReturnEnvelope(serialise(a)) + + assertTrue(obj.obj is A) + + val serSchema = obj.envelope.schema + + assertEquals(2, serSchema.types.size) + + val l1 = serSchema.carpenterSchema() + + assertEquals(0, l1.size) + + val mangleSchema = serSchema.mangleNames(listOf(classTestName("A"))) + val aName = mangleName(classTestName("A")) + val l2 = mangleSchema.carpenterSchema() + + assertEquals(1, l2.size) + + val aSchema = l2.carpenterSchemas.find { it.name == aName } + + assertNotEquals(null, aSchema) + + assertEquals(aName, aSchema!!.name) + assertEquals(1, aSchema.interfaces.size) + assertEquals(J::class.java, aSchema.interfaces[0]) + + val aBuilder = ClassCarpenter().build(aSchema) + val objJ = aBuilder.constructors[0].newInstance(testJ, testJJ) + val j = objJ as J + + assertEquals(aBuilder.getMethod("getJ").invoke(objJ), testJ) + assertEquals(aBuilder.getMethod("getJj").invoke(objJ), testJJ) + + assertEquals(a.j, j.j) + } + + @Test + fun multipleInterfaces() { + val testI = 20 + val testII = 40 + + class A(override val i: Int, override val ii: Int) : I, II + + val a = A(testI, testII) + val obj = DeserializationInput(factory).deserializeAndReturnEnvelope(serialise(a)) + + assertTrue(obj.obj is A) + + val serSchema = obj.envelope.schema + + assertEquals(3, serSchema.types.size) + + val l1 = serSchema.carpenterSchema() + + // since we're using an envelope generated by serialising classes defined locally + // it's extremely unlikely we'd need to carpent any classes + assertEquals(0, l1.size) + + // pretend we don't know the class we've been sent, i.e. it's unknown to the class loader, and thus + // needs some carpentry + val mangleSchema = serSchema.mangleNames(listOf(classTestName("A"))) + val l2 = mangleSchema.carpenterSchema() + val aName = mangleName(classTestName("A")) + + assertEquals(1, l2.size) + + val aSchema = l2.carpenterSchemas.find { it.name == aName } + + assertNotEquals(null, aSchema) + assertEquals(aName, aSchema!!.name) + assertEquals(2, aSchema.interfaces.size) + assertTrue(I::class.java in aSchema.interfaces) + assertTrue(II::class.java in aSchema.interfaces) + + val aBuilder = ClassCarpenter().build(aSchema) + val objA = aBuilder.constructors[0].newInstance(testI, testII) + val i = objA as I + val ii = objA as II + + assertEquals(aBuilder.getMethod("getI").invoke(objA), testI) + assertEquals(aBuilder.getMethod("getIi").invoke(objA), testII) + assertEquals(a.i, i.i) + assertEquals(a.ii, ii.ii) + } + + @Test + fun nestedInterfaces() { + class A(override val i: Int, override val iii: Int) : III + + val testI = 20 + val testIII = 60 + val a = A(testI, testIII) + val obj = DeserializationInput(factory).deserializeAndReturnEnvelope(serialise(a)) + + assertTrue(obj.obj is A) + + val serSchema = obj.envelope.schema + + assertEquals(3, serSchema.types.size) + + val l1 = serSchema.carpenterSchema() + + // since we're using an envelope generated by serialising classes defined locally + // it's extremely unlikely we'd need to carpent any classes + assertEquals(0, l1.size) + + val mangleSchema = serSchema.mangleNames(listOf(classTestName("A"))) + val l2 = mangleSchema.carpenterSchema() + val aName = mangleName(classTestName("A")) + + assertEquals(1, l2.size) + + val aSchema = l2.carpenterSchemas.find { it.name == aName } + + assertNotEquals(null, aSchema) + assertEquals(aName, aSchema!!.name) + assertEquals(2, aSchema.interfaces.size) + assertTrue(I::class.java in aSchema.interfaces) + assertTrue(III::class.java in aSchema.interfaces) + + val aBuilder = ClassCarpenter().build(aSchema) + val objA = aBuilder.constructors[0].newInstance(testI, testIII) + val i = objA as I + val iii = objA as III + + assertEquals(aBuilder.getMethod("getI").invoke(objA), testI) + assertEquals(aBuilder.getMethod("getIii").invoke(objA), testIII) + assertEquals(a.i, i.i) + assertEquals(a.i, iii.i) + assertEquals(a.iii, iii.iii) + } + + @Test + fun memberInterface() { + class A(override val i: Int) : I + class B(override val i: I, override val iiii: Int) : IIII + + val testI = 25 + val testIIII = 50 + val a = A(testI) + val b = B(a, testIIII) + val obj = DeserializationInput(factory).deserializeAndReturnEnvelope(serialise(b)) + + assertTrue(obj.obj is B) + + val serSchema = obj.envelope.schema + + // Expected classes are + // * class A + // * class A's interface (class I) + // * class B + // * class B's interface (class IIII) + assertEquals(4, serSchema.types.size) + + val mangleSchema = serSchema.mangleNames(listOf(classTestName("A"), classTestName("B"))) + val cSchema = mangleSchema.carpenterSchema() + val aName = mangleName(classTestName("A")) + val bName = mangleName(classTestName("B")) + + assertEquals(2, cSchema.size) + + val aCarpenterSchema = cSchema.carpenterSchemas.find { it.name == aName } + val bCarpenterSchema = cSchema.carpenterSchemas.find { it.name == bName } + + assertNotEquals(null, aCarpenterSchema) + assertNotEquals(null, bCarpenterSchema) + + val cc = ClassCarpenter() + val cc2 = ClassCarpenter() + val bBuilder = cc.build(bCarpenterSchema!!) + bBuilder.constructors[0].newInstance(a, testIIII) + + val aBuilder = cc.build(aCarpenterSchema!!) + val objA = aBuilder.constructors[0].newInstance(testI) + + // build a second B this time using our constructed instance of A and not the + // local one we pre defined + bBuilder.constructors[0].newInstance(objA, testIIII) + + // whittle and instantiate a different A with a new class loader + val aBuilder2 = cc2.build(aCarpenterSchema) + val objA2 = aBuilder2.constructors[0].newInstance(testI) + + bBuilder.constructors[0].newInstance(objA2, testIIII) + } + + // if we remove the nested interface we should get an error as it's impossible + // to have a concrete class loaded without having access to all of it's elements + @Test(expected = UncarpentableException::class) + fun memberInterface2() { + class A(override val i: Int) : I + class B(override val i: I, override val iiii: Int) : IIII + + val testI = 25 + val testIIII = 50 + val a = A(testI) + val b = B(a, testIIII) + val obj = DeserializationInput(factory).deserializeAndReturnEnvelope(serialise(b)) + + assertTrue(obj.obj is B) + + val serSchema = obj.envelope.schema + + // The classes we're expecting to find: + // * class A + // * class A's interface (class I) + // * class B + // * class B's interface (class IIII) + assertEquals(4, serSchema.types.size) + + // ignore the return as we expect this to throw + serSchema.mangleNames(listOf( + classTestName("A"), "${this.javaClass.`package`.name}.I")).carpenterSchema() + } + + @Test + fun interfaceAndImplementation() { + class A(override val i: Int) : I + + val testI = 25 + val a = A(testI) + val obj = DeserializationInput(factory).deserializeAndReturnEnvelope(serialise(a)) + + assertTrue(obj.obj is A) + + val serSchema = obj.envelope.schema + + // The classes we're expecting to find: + // * class A + // * class A's interface (class I) + assertEquals(2, serSchema.types.size) + + val amqpSchema = serSchema.mangleNames(listOf(classTestName("A"), "${this.javaClass.`package`.name}.I")) + val aName = mangleName(classTestName("A")) + val iName = mangleName("${this.javaClass.`package`.name}.I") + val carpenterSchema = amqpSchema.carpenterSchema() + + // whilst there are two unknown classes within the envelope A depends on I so we can't construct a + // schema for A until we have for I + assertEquals(1, carpenterSchema.size) + assertNotEquals(null, carpenterSchema.carpenterSchemas.find { it.name == iName }) + + // since we can't build A it should list I as a dependency + assertTrue(aName in carpenterSchema.dependencies) + assertEquals(1, carpenterSchema.dependencies[aName]!!.second.size) + assertEquals(iName, carpenterSchema.dependencies[aName]!!.second[0]) + + // and conversly I should have A listed as a dependent + assertTrue(iName in carpenterSchema.dependsOn) + assertEquals(1, carpenterSchema.dependsOn[iName]!!.size) + assertEquals(aName, carpenterSchema.dependsOn[iName]!![0]) + + val mc = MetaCarpenter(carpenterSchema) + mc.build() + + assertEquals(0, mc.schemas.carpenterSchemas.size) + assertEquals(0, mc.schemas.dependencies.size) + assertEquals(0, mc.schemas.dependsOn.size) + assertEquals(2, mc.objects.size) + assertTrue(aName in mc.objects) + assertTrue(iName in mc.objects) + + mc.objects[aName]!!.constructors[0].newInstance(testI) + } + + @Test + fun twoInterfacesAndImplementation() { + class A(override val i: Int, override val ii: Int) : I, II + + val testI = 69 + val testII = 96 + val a = A(testI, testII) + val obj = DeserializationInput(factory).deserializeAndReturnEnvelope(serialise(a)) + + val amqpSchema = obj.envelope.schema.mangleNames(listOf( + classTestName("A"), + "${this.javaClass.`package`.name}.I", + "${this.javaClass.`package`.name}.II")) + + val aName = mangleName(classTestName("A")) + val iName = mangleName("${this.javaClass.`package`.name}.I") + val iiName = mangleName("${this.javaClass.`package`.name}.II") + val carpenterSchema = amqpSchema.carpenterSchema() + + // there is nothing preventing us from carpenting up the two interfaces so + // our initial list should contain both interface with A being dependent on both + // and each having A as a dependent + assertEquals(2, carpenterSchema.carpenterSchemas.size) + assertNotNull(carpenterSchema.carpenterSchemas.find { it.name == iName }) + assertNotNull(carpenterSchema.carpenterSchemas.find { it.name == iiName }) + assertNull(carpenterSchema.carpenterSchemas.find { it.name == aName }) + + assertTrue(iName in carpenterSchema.dependsOn) + assertEquals(1, carpenterSchema.dependsOn[iName]?.size) + assertNotNull(carpenterSchema.dependsOn[iName]?.find({ it == aName })) + + assertTrue(iiName in carpenterSchema.dependsOn) + assertEquals(1, carpenterSchema.dependsOn[iiName]?.size) + assertNotNull(carpenterSchema.dependsOn[iiName]?.find { it == aName }) + + assertTrue(aName in carpenterSchema.dependencies) + assertEquals(2, carpenterSchema.dependencies[aName]!!.second.size) + assertNotNull(carpenterSchema.dependencies[aName]!!.second.find { it == iName }) + assertNotNull(carpenterSchema.dependencies[aName]!!.second.find { it == iiName }) + + val mc = MetaCarpenter(carpenterSchema) + mc.build() + + assertEquals(0, mc.schemas.carpenterSchemas.size) + assertEquals(0, mc.schemas.dependencies.size) + assertEquals(0, mc.schemas.dependsOn.size) + assertEquals(3, mc.objects.size) + assertTrue(aName in mc.objects) + assertTrue(iName in mc.objects) + assertTrue(iiName in mc.objects) + } + + @Test + fun nestedInterfacesAndImplementation() { + class A(override val i: Int, override val iii: Int) : III + + val testI = 7 + val testIII = 11 + val a = A(testI, testIII) + val obj = DeserializationInput(factory).deserializeAndReturnEnvelope(serialise(a)) + + val amqpSchema = obj.envelope.schema.mangleNames(listOf( + classTestName("A"), + "${this.javaClass.`package`.name}.I", + "${this.javaClass.`package`.name}.III")) + + val aName = mangleName(classTestName("A")) + val iName = mangleName("${this.javaClass.`package`.name}.I") + val iiiName = mangleName("${this.javaClass.`package`.name}.III") + val carpenterSchema = amqpSchema.carpenterSchema() + + // Since A depends on III and III extends I we will have to construct them + // in that reverse order (I -> III -> A) + assertEquals(1, carpenterSchema.carpenterSchemas.size) + assertNotNull(carpenterSchema.carpenterSchemas.find { it.name == iName }) + assertNull(carpenterSchema.carpenterSchemas.find { it.name == iiiName }) + assertNull(carpenterSchema.carpenterSchemas.find { it.name == aName }) + + // I has III as a direct dependent and A as an indirect one + assertTrue(iName in carpenterSchema.dependsOn) + assertEquals(2, carpenterSchema.dependsOn[iName]?.size) + assertNotNull(carpenterSchema.dependsOn[iName]?.find({ it == iiiName })) + assertNotNull(carpenterSchema.dependsOn[iName]?.find({ it == aName })) + + // III has A as a dependent + assertTrue(iiiName in carpenterSchema.dependsOn) + assertEquals(1, carpenterSchema.dependsOn[iiiName]?.size) + assertNotNull(carpenterSchema.dependsOn[iiiName]?.find { it == aName }) + + // conversly III depends on I + assertTrue(iiiName in carpenterSchema.dependencies) + assertEquals(1, carpenterSchema.dependencies[iiiName]!!.second.size) + assertNotNull(carpenterSchema.dependencies[iiiName]!!.second.find { it == iName }) + + // and A depends on III and I + assertTrue(aName in carpenterSchema.dependencies) + assertEquals(2, carpenterSchema.dependencies[aName]!!.second.size) + assertNotNull(carpenterSchema.dependencies[aName]!!.second.find { it == iiiName }) + assertNotNull(carpenterSchema.dependencies[aName]!!.second.find { it == iName }) + + val mc = MetaCarpenter(carpenterSchema) + mc.build() + + assertEquals(0, mc.schemas.carpenterSchemas.size) + assertEquals(0, mc.schemas.dependencies.size) + assertEquals(0, mc.schemas.dependsOn.size) + assertEquals(3, mc.objects.size) + assertTrue(aName in mc.objects) + assertTrue(iName in mc.objects) + assertTrue(iiiName in mc.objects) + } +} diff --git a/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/carpenter/MultiMemberCompositeSchemaToClassCarpenterTests.kt b/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/carpenter/MultiMemberCompositeSchemaToClassCarpenterTests.kt new file mode 100644 index 0000000000..d823e5a8e7 --- /dev/null +++ b/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/carpenter/MultiMemberCompositeSchemaToClassCarpenterTests.kt @@ -0,0 +1,95 @@ +package net.corda.nodeapi.internal.serialization.carpenter + +import net.corda.core.serialization.CordaSerializable +import net.corda.nodeapi.internal.serialization.amqp.CompositeType +import net.corda.nodeapi.internal.serialization.amqp.DeserializationInput +import org.junit.Test +import kotlin.test.assertEquals +import kotlin.test.assertNotEquals + +class MultiMemberCompositeSchemaToClassCarpenterTests : AmqpCarpenterBase() { + + @Test + fun twoInts() { + @CordaSerializable + data class A(val a: Int, val b: Int) + + val testA = 10 + val testB = 20 + val a = A(testA, testB) + val obj = DeserializationInput(factory).deserializeAndReturnEnvelope(serialise(a)) + + assert(obj.obj is A) + val amqpObj = obj.obj as A + + assertEquals(testA, amqpObj.a) + assertEquals(testB, amqpObj.b) + assertEquals(1, obj.envelope.schema.types.size) + assert(obj.envelope.schema.types[0] is CompositeType) + + val amqpSchema = obj.envelope.schema.types[0] as CompositeType + + assertEquals(2, amqpSchema.fields.size) + assertEquals("a", amqpSchema.fields[0].name) + assertEquals("int", amqpSchema.fields[0].type) + assertEquals("b", amqpSchema.fields[1].name) + assertEquals("int", amqpSchema.fields[1].type) + + val carpenterSchema = CarpenterSchemas.newInstance() + amqpSchema.carpenterSchema(carpenterSchemas = carpenterSchema, force = true) + + assertEquals(1, carpenterSchema.size) + val aSchema = carpenterSchema.carpenterSchemas.find { it.name == classTestName("A") } + + assertNotEquals(null, aSchema) + + val pinochio = ClassCarpenter().build(aSchema!!) + val p = pinochio.constructors[0].newInstance(testA, testB) + + assertEquals(pinochio.getMethod("getA").invoke(p), amqpObj.a) + assertEquals(pinochio.getMethod("getB").invoke(p), amqpObj.b) + } + + @Test + fun intAndStr() { + @CordaSerializable + data class A(val a: Int, val b: String) + + val testA = 10 + val testB = "twenty" + val a = A(testA, testB) + val obj = DeserializationInput(factory).deserializeAndReturnEnvelope(serialise(a)) + + assert(obj.obj is A) + val amqpObj = obj.obj as A + + assertEquals(testA, amqpObj.a) + assertEquals(testB, amqpObj.b) + assertEquals(1, obj.envelope.schema.types.size) + assert(obj.envelope.schema.types[0] is CompositeType) + + val amqpSchema = obj.envelope.schema.types[0] as CompositeType + + assertEquals(2, amqpSchema.fields.size) + assertEquals("a", amqpSchema.fields[0].name) + assertEquals("int", amqpSchema.fields[0].type) + assertEquals("b", amqpSchema.fields[1].name) + assertEquals("string", amqpSchema.fields[1].type) + + val carpenterSchema = CarpenterSchemas.newInstance() + amqpSchema.carpenterSchema(carpenterSchemas = carpenterSchema, force = true) + + assertEquals(1, carpenterSchema.size) + val aSchema = carpenterSchema.carpenterSchemas.find { it.name == classTestName("A") } + + assertNotEquals(null, aSchema) + + val pinochio = ClassCarpenter().build(aSchema!!) + val p = pinochio.constructors[0].newInstance(testA, testB) + + assertEquals(pinochio.getMethod("getA").invoke(p), amqpObj.a) + assertEquals(pinochio.getMethod("getB").invoke(p), amqpObj.b) + } + +} + diff --git a/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/carpenter/SingleMemberCompositeSchemaToClassCarpenterTests.kt b/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/carpenter/SingleMemberCompositeSchemaToClassCarpenterTests.kt new file mode 100644 index 0000000000..c68e222568 --- /dev/null +++ b/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/carpenter/SingleMemberCompositeSchemaToClassCarpenterTests.kt @@ -0,0 +1,197 @@ +package net.corda.nodeapi.internal.serialization.carpenter + +import net.corda.core.serialization.CordaSerializable +import net.corda.nodeapi.internal.serialization.amqp.CompositeType +import net.corda.nodeapi.internal.serialization.amqp.DeserializationInput +import org.junit.Test +import kotlin.test.assertEquals + +class SingleMemberCompositeSchemaToClassCarpenterTests : AmqpCarpenterBase() { + @Test + fun singleInteger() { + @CordaSerializable + data class A(val a: Int) + + val test = 10 + val a = A(test) + val obj = DeserializationInput(factory).deserializeAndReturnEnvelope(serialise(a)) + + assert(obj.obj is A) + val amqpObj = obj.obj as A + + assertEquals(test, amqpObj.a) + assertEquals(1, obj.envelope.schema.types.size) + assert(obj.envelope.schema.types[0] is CompositeType) + + val amqpSchema = obj.envelope.schema.types[0] as CompositeType + + assertEquals(1, amqpSchema.fields.size) + assertEquals("a", amqpSchema.fields[0].name) + assertEquals("int", amqpSchema.fields[0].type) + + val carpenterSchema = CarpenterSchemas.newInstance() + amqpSchema.carpenterSchema(carpenterSchemas = carpenterSchema, force = true) + + val aSchema = carpenterSchema.carpenterSchemas.find { it.name == classTestName("A") }!! + val aBuilder = ClassCarpenter().build(aSchema) + val p = aBuilder.constructors[0].newInstance(test) + + assertEquals(aBuilder.getMethod("getA").invoke(p), amqpObj.a) + } + + @Test + fun singleString() { + @CordaSerializable + data class A(val a: String) + + val test = "ten" + val a = A(test) + + val obj = DeserializationInput(factory).deserializeAndReturnEnvelope(serialise(a)) + + assert(obj.obj is A) + val amqpObj = obj.obj as A + + assertEquals(test, amqpObj.a) + assertEquals(1, obj.envelope.schema.types.size) + assert(obj.envelope.schema.types[0] is CompositeType) + + val amqpSchema = obj.envelope.schema.types[0] as CompositeType + val carpenterSchema = CarpenterSchemas.newInstance() + amqpSchema.carpenterSchema(carpenterSchemas = carpenterSchema, force = true) + + val aSchema = carpenterSchema.carpenterSchemas.find { it.name == classTestName("A") }!! + val aBuilder = ClassCarpenter().build(aSchema) + val p = aBuilder.constructors[0].newInstance(test) + + assertEquals(aBuilder.getMethod("getA").invoke(p), amqpObj.a) + } + + @Test + fun singleLong() { + @CordaSerializable + data class A(val a: Long) + + val test = 10L + val a = A(test) + val obj = DeserializationInput(factory).deserializeAndReturnEnvelope(serialise(a)) + + assert(obj.obj is A) + val amqpObj = obj.obj as A + + assertEquals(test, amqpObj.a) + assertEquals(1, obj.envelope.schema.types.size) + assert(obj.envelope.schema.types[0] is CompositeType) + + val amqpSchema = obj.envelope.schema.types[0] as CompositeType + + assertEquals(1, amqpSchema.fields.size) + assertEquals("a", amqpSchema.fields[0].name) + assertEquals("long", amqpSchema.fields[0].type) + + val carpenterSchema = CarpenterSchemas.newInstance() + amqpSchema.carpenterSchema(carpenterSchemas = carpenterSchema, force = true) + + val aSchema = carpenterSchema.carpenterSchemas.find { it.name == classTestName("A") }!! + val aBuilder = ClassCarpenter().build(aSchema) + val p = aBuilder.constructors[0].newInstance(test) + + assertEquals(aBuilder.getMethod("getA").invoke(p), amqpObj.a) + } + + @Test + fun singleShort() { + @CordaSerializable + data class A(val a: Short) + + val test = 10.toShort() + val a = A(test) + val obj = DeserializationInput(factory).deserializeAndReturnEnvelope(serialise(a)) + + assert(obj.obj is A) + val amqpObj = obj.obj as A + + assertEquals(test, amqpObj.a) + assertEquals(1, obj.envelope.schema.types.size) + assert(obj.envelope.schema.types[0] is CompositeType) + + val amqpSchema = obj.envelope.schema.types[0] as CompositeType + + assertEquals(1, amqpSchema.fields.size) + assertEquals("a", amqpSchema.fields[0].name) + assertEquals("short", amqpSchema.fields[0].type) + + val carpenterSchema = CarpenterSchemas.newInstance() + amqpSchema.carpenterSchema(carpenterSchemas = carpenterSchema, force = true) + + val aSchema = carpenterSchema.carpenterSchemas.find { it.name == classTestName("A") }!! + val aBuilder = ClassCarpenter().build(aSchema) + val p = aBuilder.constructors[0].newInstance(test) + + assertEquals(aBuilder.getMethod("getA").invoke(p), amqpObj.a) + } + + @Test + fun singleDouble() { + @CordaSerializable + data class A(val a: Double) + + val test = 10.0 + val a = A(test) + val obj = DeserializationInput(factory).deserializeAndReturnEnvelope(serialise(a)) + + assert(obj.obj is A) + val amqpObj = obj.obj as A + + assertEquals(test, amqpObj.a) + assertEquals(1, obj.envelope.schema.types.size) + assert(obj.envelope.schema.types[0] is CompositeType) + + val amqpSchema = obj.envelope.schema.types[0] as CompositeType + + assertEquals(1, amqpSchema.fields.size) + assertEquals("a", amqpSchema.fields[0].name) + assertEquals("double", amqpSchema.fields[0].type) + + val carpenterSchema = CarpenterSchemas.newInstance() + amqpSchema.carpenterSchema(carpenterSchemas = carpenterSchema, force = true) + + val aSchema = carpenterSchema.carpenterSchemas.find { it.name == classTestName("A") }!! + val aBuilder = ClassCarpenter().build(aSchema) + val p = aBuilder.constructors[0].newInstance(test) + + assertEquals(aBuilder.getMethod("getA").invoke(p), amqpObj.a) + } + + @Test + fun singleFloat() { + @CordaSerializable + data class A(val a: Float) + + val test: Float = 10.0F + val a = A(test) + val obj = DeserializationInput(factory).deserializeAndReturnEnvelope(serialise(a)) + + assert(obj.obj is A) + val amqpObj = obj.obj as A + + assertEquals(test, amqpObj.a) + assertEquals(1, obj.envelope.schema.types.size) + assert(obj.envelope.schema.types[0] is CompositeType) + + val amqpSchema = obj.envelope.schema.types[0] as CompositeType + + assertEquals(1, amqpSchema.fields.size) + assertEquals("a", amqpSchema.fields[0].name) + assertEquals("float", amqpSchema.fields[0].type) + + val carpenterSchema = CarpenterSchemas.newInstance() + amqpSchema.carpenterSchema(carpenterSchemas = carpenterSchema, force = true) + + val aSchema = carpenterSchema.carpenterSchemas.find { it.name == classTestName("A") }!! + val aBuilder = ClassCarpenter().build(aSchema) + val p = aBuilder.constructors[0].newInstance(test) + + assertEquals(aBuilder.getMethod("getA").invoke(p), amqpObj.a) + } +} diff --git a/node-api/src/test/resources/net/corda/nodeapi/isolated.jar b/node-api/src/test/resources/net/corda/nodeapi/isolated.jar new file mode 100644 index 0000000000..567a93ed9c Binary files /dev/null and b/node-api/src/test/resources/net/corda/nodeapi/isolated.jar differ diff --git a/node-schemas/build.gradle b/node-schemas/build.gradle index 88a9102a8b..7f015714a1 100644 --- a/node-schemas/build.gradle +++ b/node-schemas/build.gradle @@ -1,7 +1,7 @@ +apply plugin: 'net.corda.plugins.publish-utils' apply plugin: 'kotlin' apply plugin: 'kotlin-kapt' apply plugin: 'idea' -apply plugin: 'net.corda.plugins.publish-utils' apply plugin: 'com.jfrog.artifactory' description 'Corda node database schemas' @@ -33,5 +33,5 @@ jar { } publish { - name = jar.baseName + name jar.baseName } \ No newline at end of file diff --git a/node-schemas/src/test/kotlin/net/corda/node/services/vault/schemas/VaultSchemaTest.kt b/node-schemas/src/test/kotlin/net/corda/node/services/vault/schemas/VaultSchemaTest.kt index 58a7b3b95d..554353c29e 100644 --- a/node-schemas/src/test/kotlin/net/corda/node/services/vault/schemas/VaultSchemaTest.kt +++ b/node-schemas/src/test/kotlin/net/corda/node/services/vault/schemas/VaultSchemaTest.kt @@ -8,9 +8,8 @@ import io.requery.rx.KotlinRxEntityStore import io.requery.sql.* import io.requery.sql.platform.Generic import net.corda.core.contracts.* -import net.corda.testing.contracts.DummyContract -import net.corda.core.crypto.composite.CompositeKey import net.corda.core.crypto.SecureHash +import net.corda.core.crypto.composite.CompositeKey import net.corda.core.crypto.generateKeyPair import net.corda.core.crypto.toBase58String import net.corda.core.identity.AbstractParty @@ -25,7 +24,8 @@ import net.corda.node.services.vault.schemas.requery.* import net.corda.testing.ALICE import net.corda.testing.BOB import net.corda.testing.DUMMY_NOTARY -import net.corda.testing.DUMMY_NOTARY_KEY +import net.corda.testing.TestDependencyInjectionBase +import net.corda.testing.contracts.DummyContract import org.h2.jdbcx.JdbcDataSource import org.junit.After import org.junit.Assert @@ -40,7 +40,7 @@ import kotlin.test.assertNotNull import kotlin.test.assertNull import kotlin.test.assertTrue -class VaultSchemaTest { +class VaultSchemaTest : TestDependencyInjectionBase() { var instance: KotlinEntityDataStore? = null val data: KotlinEntityDataStore get() = instance!! @@ -87,14 +87,14 @@ class VaultSchemaTest { override val participants: List get() = listOf(owner) - override fun withNewOwner(newOwner: AbstractParty) = Pair(Commands.Create(), copy(owner = newOwner)) + override fun withNewOwner(newOwner: AbstractParty) = CommandAndState(Commands.Create(), copy(owner = newOwner)) } interface Commands : CommandData { class Create : TypeOnlyCommandData(), Commands } - override fun verify(tx: TransactionForContract) { + override fun verify(tx: LedgerTransaction) { // Always accepts. } } @@ -119,8 +119,8 @@ class VaultSchemaTest { val commands = emptyList>() val attachments = emptyList() val id = SecureHash.randomSHA256() - val signers = listOf(DUMMY_NOTARY_KEY.public) val timeWindow: TimeWindow? = null + val privacySalt: PrivacySalt = PrivacySalt() transaction = LedgerTransaction( inputs, outputs, @@ -128,9 +128,8 @@ class VaultSchemaTest { attachments, id, notary, - signers, timeWindow, - TransactionType.General + privacySalt ) } @@ -151,8 +150,8 @@ class VaultSchemaTest { val commands = emptyList>() val attachments = emptyList() val id = SecureHash.randomSHA256() - val signers = listOf(DUMMY_NOTARY_KEY.public) val timeWindow: TimeWindow? = null + val privacySalt: PrivacySalt = PrivacySalt() return LedgerTransaction( inputs, outputs, @@ -160,9 +159,8 @@ class VaultSchemaTest { attachments, id, notary, - signers, timeWindow, - TransactionType.General + privacySalt ) } @@ -466,12 +464,12 @@ class VaultSchemaTest { fun testInsert() { val stateEntity = createStateEntity(transaction!!.inputs[0]) val latch = CountDownLatch(1) - odata.insert(stateEntity).subscribe { stateEntity -> - Assert.assertNotNull(stateEntity.txId) - Assert.assertTrue(stateEntity.txId.isNotEmpty()) + odata.insert(stateEntity).subscribe { + Assert.assertNotNull(it.txId) + Assert.assertTrue(it.txId.isNotEmpty()) val cached = data.select(VaultSchema.VaultStates::class) - .where(VaultSchema.VaultStates::txId.eq(stateEntity.txId)).get().first() - Assert.assertSame(cached, stateEntity) + .where(VaultSchema.VaultStates::txId.eq(it.txId)).get().first() + Assert.assertSame(cached, it) latch.countDown() } latch.await() diff --git a/node/build.gradle b/node/build.gradle index 2b453fb0cc..3fd0ffb452 100644 --- a/node/build.gradle +++ b/node/build.gradle @@ -20,13 +20,6 @@ configurations { exclude group: 'io.netty', module: 'netty-handler' } - testCompile { - // Excluding javassist:javassist because it clashes with Hibernate's - // transitive org.javassist:javassist dependency. - // TODO: Remove this exclusion once junit-quickcheck 0.8 is released. - exclude group: 'javassist', module: 'javassist' - } - integrationTestCompile.extendsFrom testCompile integrationTestRuntime.extendsFrom testRuntime @@ -41,6 +34,11 @@ sourceSets { runtimeClasspath += main.output + test.output srcDir file('src/integration-test/kotlin') } + java { + compileClasspath += main.output + test.output + runtimeClasspath += main.output + test.output + srcDir file('src/integration-test/java') + } resources { srcDir file('src/integration-test/resources') } @@ -53,6 +51,11 @@ sourceSets { runtimeClasspath += main.output srcDir file('src/smoke-test/kotlin') } + java { + compileClasspath += main.output + runtimeClasspath += main.output + srcDir file('src/smoke-test/java') + } } } @@ -63,11 +66,7 @@ processResources { } processSmokeTestResources { - // Build one of the demos so that we can test CorDapp scanning in CordappScanningTest. It doesn't matter which demo - // we use, just make sure the test is updated accordingly. - from(project(':samples:trader-demo').tasks.jar) { - rename 'trader-demo-(.*)', 'trader-demo.jar' - } + // Bring in the fully built corda.jar for use by NodeFactory in the smoke tests from(project(':node:capsule').tasks.buildCordaJAR) { rename 'corda-(.*)', 'corda.jar' } @@ -77,7 +76,6 @@ processSmokeTestResources { // build/reports/project/dependencies/index.html for green highlighted parts of the tree. dependencies { - compile project(':finance') compile project(':node-schemas') compile project(':node-api') compile project(':client:rpc') @@ -137,9 +135,9 @@ dependencies { // Unit testing helpers. testCompile "junit:junit:$junit_version" testCompile "org.assertj:assertj-core:${assertj_version}" - testCompile "com.pholser:junit-quickcheck-core:$quickcheck_version" testCompile project(':test-utils') testCompile project(':client:jfx') + testCompile project(':finance') // sample test schemas testCompile project(path: ':finance', configuration: 'testArtifacts') @@ -206,7 +204,13 @@ task integrationTest(type: Test) { classpath = sourceSets.integrationTest.runtimeClasspath } +task smokeTestJar(type: Jar) { + baseName = project.name + '-smoke-test' + from sourceSets.smokeTest.output +} + task smokeTest(type: Test) { + dependsOn smokeTestJar testClassesDir = sourceSets.smokeTest.output.classesDir classpath = sourceSets.smokeTest.runtimeClasspath } @@ -216,5 +220,5 @@ jar { } publish { - name = jar.baseName -} \ No newline at end of file + name jar.baseName +} diff --git a/node/capsule/build.gradle b/node/capsule/build.gradle index 63ba6e529e..ae66f43dd7 100644 --- a/node/capsule/build.gradle +++ b/node/capsule/build.gradle @@ -67,6 +67,6 @@ artifacts { } publish { - name = 'corda' disableDefaultJar = true + name 'corda' } diff --git a/node/src/integration-test/kotlin/net/corda/node/BootTests.kt b/node/src/integration-test/kotlin/net/corda/node/BootTests.kt index b5366a7ace..00a64b0926 100644 --- a/node/src/integration-test/kotlin/net/corda/node/BootTests.kt +++ b/node/src/integration-test/kotlin/net/corda/node/BootTests.kt @@ -1,21 +1,21 @@ package net.corda.node import co.paralleluniverse.fibers.Suspendable -import net.corda.core.div +import net.corda.core.internal.div import net.corda.core.flows.FlowLogic import net.corda.core.flows.StartableByRPC -import net.corda.core.getOrThrow import net.corda.core.messaging.startFlow import net.corda.core.node.services.ServiceInfo import net.corda.core.node.services.ServiceType +import net.corda.core.utilities.getOrThrow import net.corda.testing.ALICE -import net.corda.testing.driver.driver import net.corda.node.internal.NodeStartup import net.corda.node.services.startFlowPermission import net.corda.nodeapi.User import net.corda.testing.driver.ListenProcessDeathException import net.corda.testing.driver.NetworkMapStartStrategy import net.corda.testing.ProjectStructure.projectRootDir +import net.corda.testing.driver.driver import org.assertj.core.api.Assertions.assertThat import org.assertj.core.api.Assertions.assertThatThrownBy import org.junit.Test diff --git a/node/src/integration-test/kotlin/net/corda/node/CordappScanningDriverTest.kt b/node/src/integration-test/kotlin/net/corda/node/CordappScanningDriverTest.kt index 0d4f786dcf..258c18cfae 100644 --- a/node/src/integration-test/kotlin/net/corda/node/CordappScanningDriverTest.kt +++ b/node/src/integration-test/kotlin/net/corda/node/CordappScanningDriverTest.kt @@ -1,14 +1,14 @@ package net.corda.node import co.paralleluniverse.fibers.Suspendable -import com.google.common.util.concurrent.Futures import net.corda.core.flows.FlowLogic import net.corda.core.flows.InitiatedBy import net.corda.core.flows.InitiatingFlow import net.corda.core.flows.StartableByRPC -import net.corda.core.getOrThrow import net.corda.core.identity.Party +import net.corda.core.internal.concurrent.transpose import net.corda.core.messaging.startFlow +import net.corda.core.utilities.getOrThrow import net.corda.testing.ALICE import net.corda.testing.BOB import net.corda.core.utilities.unwrap @@ -24,9 +24,9 @@ class CordappScanningDriverTest { val user = User("u", "p", setOf(startFlowPermission())) // The driver will automatically pick up the annotated flows below driver { - val (alice, bob) = Futures.allAsList( + val (alice, bob) = listOf( startNode(ALICE.name, rpcUsers = listOf(user)), - startNode(BOB.name)).getOrThrow() + startNode(BOB.name)).transpose().getOrThrow() val initiatedFlowClass = alice.rpcClientToNode() .start(user.username, user.password) .proxy diff --git a/node/src/integration-test/kotlin/net/corda/node/NodePerformanceTests.kt b/node/src/integration-test/kotlin/net/corda/node/NodePerformanceTests.kt index c22c69ec34..b539dfee7a 100644 --- a/node/src/integration-test/kotlin/net/corda/node/NodePerformanceTests.kt +++ b/node/src/integration-test/kotlin/net/corda/node/NodePerformanceTests.kt @@ -2,15 +2,15 @@ package net.corda.node import co.paralleluniverse.fibers.Suspendable import com.google.common.base.Stopwatch -import com.google.common.util.concurrent.Futures import net.corda.core.contracts.DOLLARS import net.corda.core.flows.FlowLogic import net.corda.core.flows.StartableByRPC +import net.corda.core.internal.concurrent.transpose import net.corda.core.messaging.startFlow -import net.corda.core.minutes +import net.corda.core.utilities.minutes import net.corda.core.node.services.ServiceInfo import net.corda.core.utilities.OpaqueBytes -import net.corda.core.utilities.div +import net.corda.testing.performance.div import net.corda.flows.CashIssueFlow import net.corda.flows.CashPaymentFlow import net.corda.node.services.startFlowPermission @@ -113,7 +113,7 @@ class NodePerformanceTests { val doneFutures = (1..100).toList().parallelStream().map { connection.proxy.startFlow(::CashIssueFlow, 1.DOLLARS, OpaqueBytes.of(0), a.nodeInfo.legalIdentity, a.nodeInfo.notaryIdentity).returnValue }.toList() - Futures.allAsList(doneFutures).get() + doneFutures.transpose().get() println("STARTING PAYMENT") startPublishingFixedRateInjector(metricRegistry, 8, 5.minutes, 100L / TimeUnit.SECONDS) { connection.proxy.startFlow(::CashPaymentFlow, 1.DOLLARS, a.nodeInfo.legalIdentity).returnValue.get() diff --git a/node/src/integration-test/kotlin/net/corda/node/services/AdvertisedServiceTests.kt b/node/src/integration-test/kotlin/net/corda/node/services/AdvertisedServiceTests.kt new file mode 100644 index 0000000000..9f422f7e16 --- /dev/null +++ b/node/src/integration-test/kotlin/net/corda/node/services/AdvertisedServiceTests.kt @@ -0,0 +1,41 @@ +package net.corda.node.services + +import co.paralleluniverse.fibers.Suspendable +import net.corda.core.flows.FlowLogic +import net.corda.core.flows.StartableByRPC +import net.corda.core.messaging.startFlow +import net.corda.core.node.services.ServiceInfo +import net.corda.core.node.services.ServiceType +import net.corda.nodeapi.User +import net.corda.testing.driver.driver +import org.bouncycastle.asn1.x500.X500Name +import org.junit.Test +import kotlin.test.assertTrue + +class AdvertisedServiceTests { + private val serviceName = X500Name("CN=Custom Service,O=R3,OU=corda,L=London,C=GB") + private val serviceType = ServiceType.corda.getSubType("custom") + private val user = "bankA" + private val pass = "passA" + + + @StartableByRPC + class ServiceTypeCheckingFlow : FlowLogic() { + @Suspendable + override fun call(): Boolean { + return serviceHub.networkMapCache.getAnyServiceOfType(ServiceType.corda.getSubType("custom")) != null + } + } + + @Test + fun `service is accessible through getAnyServiceOfType`() { + driver(startNodesInProcess = true) { + val bankA = startNode(rpcUsers = listOf(User(user, pass, setOf(startFlowPermission())))).get() + startNode(advertisedServices = setOf(ServiceInfo(serviceType, serviceName))).get() + bankA.rpcClientToNode().use(user, pass) { connection -> + val result = connection.proxy.startFlow(::ServiceTypeCheckingFlow).returnValue.get() + assertTrue(result) + } + } + } +} diff --git a/node/src/integration-test/kotlin/net/corda/node/services/BFTNotaryServiceTests.kt b/node/src/integration-test/kotlin/net/corda/node/services/BFTNotaryServiceTests.kt index 7e7060cbc6..bba066f2be 100644 --- a/node/src/integration-test/kotlin/net/corda/node/services/BFTNotaryServiceTests.kt +++ b/node/src/integration-test/kotlin/net/corda/node/services/BFTNotaryServiceTests.kt @@ -3,26 +3,26 @@ package net.corda.node.services import com.nhaarman.mockito_kotlin.whenever import net.corda.core.contracts.ContractState import net.corda.core.contracts.StateRef -import net.corda.core.contracts.TransactionType -import net.corda.testing.contracts.DummyContract -import net.corda.core.crypto.composite.CompositeKey import net.corda.core.crypto.SecureHash -import net.corda.core.div -import net.corda.core.getOrThrow +import net.corda.core.crypto.composite.CompositeKey +import net.corda.core.internal.div +import net.corda.core.flows.NotaryError +import net.corda.core.flows.NotaryException +import net.corda.core.flows.NotaryFlow import net.corda.core.identity.Party import net.corda.core.node.services.ServiceInfo +import net.corda.core.transactions.TransactionBuilder import net.corda.core.utilities.NetworkHostAndPort import net.corda.core.utilities.Try -import net.corda.flows.NotaryError -import net.corda.flows.NotaryException -import net.corda.flows.NotaryFlow +import net.corda.core.utilities.getOrThrow import net.corda.node.internal.AbstractNode +import net.corda.node.services.config.BFTSMaRtConfiguration import net.corda.node.services.network.NetworkMapService import net.corda.node.services.transactions.BFTNonValidatingNotaryService import net.corda.node.services.transactions.minClusterSize import net.corda.node.services.transactions.minCorrectReplicas import net.corda.node.utilities.ServiceIdentityGenerator -import net.corda.node.utilities.transaction +import net.corda.testing.contracts.DummyContract import net.corda.testing.node.MockNetwork import org.bouncycastle.asn1.x500.X500Name import org.junit.After @@ -44,7 +44,7 @@ class BFTNotaryServiceTests { mockNet.stopNodes() } - private fun bftNotaryCluster(clusterSize: Int): Party { + private fun bftNotaryCluster(clusterSize: Int, exposeRaces: Boolean = false): Party { Files.deleteIfExists("config" / "currentView") // XXX: Make config object warn if this exists? val replicaIds = (0 until clusterSize) val party = ServiceIdentityGenerator.generateToDisk( @@ -58,13 +58,26 @@ class BFTNotaryServiceTests { node.network.myAddress, advertisedServices = bftNotaryService, configOverrides = { - whenever(it.bftReplicaId).thenReturn(replicaId) + whenever(it.bftSMaRt).thenReturn(BFTSMaRtConfiguration(replicaId, false, exposeRaces)) whenever(it.notaryClusterAddresses).thenReturn(notaryClusterAddresses) }) } return party } + /** Failure mode is the redundant replica gets stuck in startup, so we can't dispose it cleanly at the end. */ + @Test + fun `all replicas start even if there is a new consensus during startup`() { + val notary = bftNotaryCluster(minClusterSize(1), true) // This true adds a sleep to expose the race. + val f = node.run { + val trivialTx = signInitialTransaction(notary) {} + // Create a new consensus while the redundant replica is sleeping: + services.startFlow(NotaryFlow.Client(trivialTx)).resultFuture + } + mockNet.runNetwork() + f.getOrThrow() + } + @Test fun `detect double spend 1 faulty`() { detectDoubleSpend(1) @@ -125,8 +138,8 @@ class BFTNotaryServiceTests { private fun AbstractNode.signInitialTransaction( notary: Party, makeUnique: Boolean = false, - block: TransactionType.General.Builder.() -> Any? -) = services.signInitialTransaction(TransactionType.General.Builder(notary).apply { + block: TransactionBuilder.() -> Any? +) = services.signInitialTransaction(TransactionBuilder(notary).apply { block() if (makeUnique) { addAttachment(SecureHash.randomSHA256()) diff --git a/node/src/integration-test/kotlin/net/corda/node/services/DistributedServiceTests.kt b/node/src/integration-test/kotlin/net/corda/node/services/DistributedServiceTests.kt index efb4bf8b80..e9db0c5a85 100644 --- a/node/src/integration-test/kotlin/net/corda/node/services/DistributedServiceTests.kt +++ b/node/src/integration-test/kotlin/net/corda/node/services/DistributedServiceTests.kt @@ -1,27 +1,25 @@ package net.corda.node.services -import net.corda.core.bufferUntilSubscribed import net.corda.core.contracts.Amount import net.corda.core.contracts.POUNDS import net.corda.core.identity.Party -import net.corda.core.getOrThrow +import net.corda.core.internal.bufferUntilSubscribed import net.corda.core.messaging.CordaRPCOps import net.corda.core.messaging.StateMachineUpdate import net.corda.core.messaging.startFlow import net.corda.core.node.NodeInfo import net.corda.core.utilities.OpaqueBytes +import net.corda.core.utilities.getOrThrow import net.corda.testing.ALICE import net.corda.testing.DUMMY_NOTARY import net.corda.flows.CashIssueFlow import net.corda.flows.CashPaymentFlow -import net.corda.testing.driver.NodeHandle -import net.corda.testing.driver.driver import net.corda.node.services.transactions.RaftValidatingNotaryService import net.corda.nodeapi.User -import net.corda.testing.expect -import net.corda.testing.expectEvents +import net.corda.testing.* +import net.corda.testing.driver.NodeHandle +import net.corda.testing.driver.driver import net.corda.testing.node.DriverBasedTest -import net.corda.testing.replicate import org.junit.Test import rx.Observable import java.util.* @@ -65,7 +63,7 @@ class DistributedServiceTests : DriverBasedTest() { aliceProxy = connectRpc(alice) val rpcClientsToNotaries = notaries.map(::connectRpc) notaryStateMachines = Observable.from(rpcClientsToNotaries.map { proxy -> - proxy.stateMachinesAndUpdates().second.map { Pair(proxy.nodeIdentity(), it) } + proxy.stateMachinesFeed().updates.map { Pair(proxy.nodeIdentity(), it) } }).flatMap { it.onErrorResumeNext(Observable.empty()) }.bufferUntilSubscribed() runTest() @@ -86,8 +84,7 @@ class DistributedServiceTests : DriverBasedTest() { val notarisationsPerNotary = HashMap() notaryStateMachines.expectEvents(isStrict = false) { replicate>(50) { - expect(match = { it.second is StateMachineUpdate.Added }) { - val (notary, update) = it + expect(match = { it.second is StateMachineUpdate.Added }) { (notary, update) -> update as StateMachineUpdate.Added notarisationsPerNotary.compute(notary.legalIdentity) { _, number -> number?.plus(1) ?: 1 } } @@ -125,8 +122,7 @@ class DistributedServiceTests : DriverBasedTest() { val notarisationsPerNotary = HashMap() notaryStateMachines.expectEvents(isStrict = false) { replicate>(30) { - expect(match = { it.second is StateMachineUpdate.Added }) { - val (notary, update) = it + expect(match = { it.second is StateMachineUpdate.Added }) { (notary, update) -> update as StateMachineUpdate.Added notarisationsPerNotary.compute(notary.legalIdentity) { _, number -> number?.plus(1) ?: 1 } } diff --git a/node/src/integration-test/kotlin/net/corda/node/services/RaftNotaryServiceTests.kt b/node/src/integration-test/kotlin/net/corda/node/services/RaftNotaryServiceTests.kt index 566c4674e1..b999da66d2 100644 --- a/node/src/integration-test/kotlin/net/corda/node/services/RaftNotaryServiceTests.kt +++ b/node/src/integration-test/kotlin/net/corda/node/services/RaftNotaryServiceTests.kt @@ -1,19 +1,18 @@ package net.corda.node.services -import com.google.common.util.concurrent.Futures import net.corda.core.contracts.StateAndRef import net.corda.core.contracts.StateRef -import net.corda.core.contracts.TransactionType import net.corda.testing.contracts.DummyContract -import net.corda.core.getOrThrow import net.corda.core.identity.Party -import net.corda.core.map import net.corda.testing.DUMMY_BANK_A -import net.corda.flows.NotaryError -import net.corda.flows.NotaryException -import net.corda.flows.NotaryFlow +import net.corda.core.flows.NotaryError +import net.corda.core.flows.NotaryException +import net.corda.core.flows.NotaryFlow +import net.corda.core.internal.concurrent.map +import net.corda.core.internal.concurrent.transpose +import net.corda.core.utilities.getOrThrow +import net.corda.core.transactions.TransactionBuilder import net.corda.node.internal.AbstractNode -import net.corda.node.utilities.transaction import net.corda.testing.node.NodeBasedTest import org.bouncycastle.asn1.x500.X500Name import org.junit.Test @@ -26,22 +25,22 @@ class RaftNotaryServiceTests : NodeBasedTest() { @Test fun `detect double spend`() { - val (bankA) = Futures.allAsList( + val (bankA) = listOf( startNode(DUMMY_BANK_A.name), startNotaryCluster(notaryName, 3).map { it.first() } - ).getOrThrow() + ).transpose().getOrThrow() val notaryParty = bankA.services.networkMapCache.getNotary(notaryName)!! val inputState = issueState(bankA, notaryParty) - val firstTxBuilder = TransactionType.General.Builder(notaryParty).withItems(inputState) + val firstTxBuilder = TransactionBuilder(notaryParty).withItems(inputState) val firstSpendTx = bankA.services.signInitialTransaction(firstTxBuilder) val firstSpend = bankA.services.startFlow(NotaryFlow.Client(firstSpendTx)) firstSpend.resultFuture.getOrThrow() - val secondSpendBuilder = TransactionType.General.Builder(notaryParty).withItems(inputState).run { + val secondSpendBuilder = TransactionBuilder(notaryParty).withItems(inputState).run { val dummyState = DummyContract.SingleOwnerState(0, bankA.info.legalIdentity) addOutputState(dummyState) this diff --git a/node/src/integration-test/kotlin/net/corda/node/services/statemachine/FlowVersioningTest.kt b/node/src/integration-test/kotlin/net/corda/node/services/statemachine/FlowVersioningTest.kt index 459310ecf7..9ea00ba466 100644 --- a/node/src/integration-test/kotlin/net/corda/node/services/statemachine/FlowVersioningTest.kt +++ b/node/src/integration-test/kotlin/net/corda/node/services/statemachine/FlowVersioningTest.kt @@ -1,40 +1,45 @@ package net.corda.node.services.statemachine import co.paralleluniverse.fibers.Suspendable -import com.google.common.util.concurrent.Futures import net.corda.core.flows.FlowLogic import net.corda.core.flows.InitiatingFlow -import net.corda.core.getOrThrow import net.corda.core.identity.Party +import net.corda.core.internal.concurrent.transpose +import net.corda.core.utilities.getOrThrow +import net.corda.core.utilities.unwrap import net.corda.testing.ALICE import net.corda.testing.BOB -import net.corda.core.utilities.unwrap import net.corda.testing.node.NodeBasedTest import org.assertj.core.api.Assertions.assertThat import org.junit.Test class FlowVersioningTest : NodeBasedTest() { @Test - fun `core flows receive platform version of initiator`() { - val (alice, bob) = Futures.allAsList( + fun `getFlowContext returns the platform version for core flows`() { + val (alice, bob) = listOf( startNode(ALICE.name, platformVersion = 2), - startNode(BOB.name, platformVersion = 3)).getOrThrow() - bob.installCoreFlow(ClientFlow::class, ::SendBackPlatformVersionFlow) - val resultFuture = alice.services.startFlow(ClientFlow(bob.info.legalIdentity)).resultFuture - assertThat(resultFuture.getOrThrow()).isEqualTo(2) + startNode(BOB.name, platformVersion = 3)).transpose().getOrThrow() + bob.installCoreFlow(PretendInitiatingCoreFlow::class, ::PretendInitiatedCoreFlow) + val (alicePlatformVersionAccordingToBob, bobPlatformVersionAccordingToAlice) = alice.services.startFlow( + PretendInitiatingCoreFlow(bob.info.legalIdentity)).resultFuture.getOrThrow() + assertThat(alicePlatformVersionAccordingToBob).isEqualTo(2) + assertThat(bobPlatformVersionAccordingToAlice).isEqualTo(3) } @InitiatingFlow - private class ClientFlow(val otherParty: Party) : FlowLogic() { + private class PretendInitiatingCoreFlow(val initiatedParty: Party) : FlowLogic>() { @Suspendable - override fun call(): Any { - return sendAndReceive(otherParty, "This is ignored. We only send to kick off the flow on the other side").unwrap { it } + override fun call(): Pair { + return Pair( + receive(initiatedParty).unwrap { it }, + getFlowContext(initiatedParty).flowVersion + ) } } - private class SendBackPlatformVersionFlow(val otherParty: Party, val otherPartysPlatformVersion: Int) : FlowLogic() { + private class PretendInitiatedCoreFlow(val initiatingParty: Party) : FlowLogic() { @Suspendable - override fun call() = send(otherParty, otherPartysPlatformVersion) + override fun call() = send(initiatingParty, getFlowContext(initiatingParty).flowVersion) } } \ No newline at end of file diff --git a/node/src/integration-test/kotlin/net/corda/node/services/statemachine/LargeTransactionsTest.kt b/node/src/integration-test/kotlin/net/corda/node/services/statemachine/LargeTransactionsTest.kt new file mode 100644 index 0000000000..caf80073d6 --- /dev/null +++ b/node/src/integration-test/kotlin/net/corda/node/services/statemachine/LargeTransactionsTest.kt @@ -0,0 +1,74 @@ +package net.corda.node.services.statemachine + +import co.paralleluniverse.fibers.Suspendable +import net.corda.core.internal.InputStreamAndHash +import net.corda.core.crypto.SecureHash +import net.corda.core.flows.* +import net.corda.core.identity.Party +import net.corda.core.messaging.startFlow +import net.corda.core.transactions.SignedTransaction +import net.corda.core.transactions.TransactionBuilder +import net.corda.core.utilities.unwrap +import net.corda.testing.BOB +import net.corda.testing.DUMMY_NOTARY +import net.corda.testing.aliceBobAndNotary +import net.corda.testing.contracts.DummyState +import net.corda.testing.driver.driver +import org.junit.Test +import kotlin.test.assertEquals + +/** + * Check that we can add lots of large attachments to a transaction and that it works OK, e.g. does not hit the + * transaction size limit (which should only consider the hashes). + */ +class LargeTransactionsTest { + @StartableByRPC @InitiatingFlow + class SendLargeTransactionFlow(val hash1: SecureHash, val hash2: SecureHash, val hash3: SecureHash, val hash4: SecureHash) : FlowLogic() { + @Suspendable + override fun call() { + val tx = TransactionBuilder(notary = DUMMY_NOTARY) + .addOutputState(DummyState()) + .addAttachment(hash1) + .addAttachment(hash2) + .addAttachment(hash3) + .addAttachment(hash4) + val stx = serviceHub.signInitialTransaction(tx, serviceHub.legalIdentityKey) + // Send to the other side and wait for it to trigger resolution from us. + val bob = serviceHub.networkMapCache.getNodeByLegalName(BOB.name)!!.legalIdentity + subFlow(SendTransactionFlow(bob, stx)) + receive(bob) + } + } + + @InitiatedBy(SendLargeTransactionFlow::class) @Suppress("UNUSED") + class ReceiveLargeTransactionFlow(private val counterParty: Party) : FlowLogic() { + @Suspendable + override fun call() { + subFlow(ReceiveTransactionFlow(counterParty)) + // Unblock the other side by sending some dummy object (Unit is fine here as it's a singleton). + send(counterParty, Unit) + } + } + + @Test + fun checkCanSendLargeTransactions() { + // These 4 attachments yield a transaction that's got >10mb attached, so it'd push us over the Artemis + // max message size. + val bigFile1 = InputStreamAndHash.createInMemoryTestZip(1024 * 1024 * 3, 0) + val bigFile2 = InputStreamAndHash.createInMemoryTestZip(1024 * 1024 * 3, 1) + val bigFile3 = InputStreamAndHash.createInMemoryTestZip(1024 * 1024 * 3, 2) + val bigFile4 = InputStreamAndHash.createInMemoryTestZip(1024 * 1024 * 3, 3) + driver(startNodesInProcess = true) { + val (alice, _, _) = aliceBobAndNotary() + alice.useRPC { + val hash1 = it.uploadAttachment(bigFile1.inputStream) + val hash2 = it.uploadAttachment(bigFile2.inputStream) + val hash3 = it.uploadAttachment(bigFile3.inputStream) + val hash4 = it.uploadAttachment(bigFile4.inputStream) + assertEquals(hash1, bigFile1.sha256) + // Should not throw any exceptions. + it.startFlow(::SendLargeTransactionFlow, hash1, hash2, hash3, hash4).returnValue.get() + } + } + } +} \ No newline at end of file diff --git a/node/src/integration-test/kotlin/net/corda/node/utilities/JDBCHashMapTestSuite.kt b/node/src/integration-test/kotlin/net/corda/node/utilities/JDBCHashMapTestSuite.kt index 261aaff04f..0057ba0470 100644 --- a/node/src/integration-test/kotlin/net/corda/node/utilities/JDBCHashMapTestSuite.kt +++ b/node/src/integration-test/kotlin/net/corda/node/utilities/JDBCHashMapTestSuite.kt @@ -10,16 +10,13 @@ import com.google.common.collect.testing.features.MapFeature import com.google.common.collect.testing.features.SetFeature import com.google.common.collect.testing.testers.* import junit.framework.TestSuite +import net.corda.testing.* import net.corda.testing.node.makeTestDataSourceProperties +import net.corda.testing.node.makeTestDatabaseProperties import org.assertj.core.api.Assertions.assertThat -import org.jetbrains.exposed.sql.Database -import org.jetbrains.exposed.sql.Transaction -import org.jetbrains.exposed.sql.transactions.TransactionManager import org.junit.* import org.junit.runner.RunWith import org.junit.runners.Suite -import java.io.Closeable -import java.sql.Connection import java.util.* @RunWith(Suite::class) @@ -32,9 +29,8 @@ import java.util.* JDBCHashMapTestSuite.SetConstrained::class) class JDBCHashMapTestSuite { companion object { - lateinit var dataSource: Closeable - lateinit var transaction: Transaction - lateinit var database: Database + lateinit var transaction: DatabaseTransaction + lateinit var database: CordaPersistence lateinit var loadOnInitFalseMap: JDBCHashMap lateinit var memoryConstrainedMap: JDBCHashMap lateinit var loadOnInitTrueMap: JDBCHashMap @@ -45,9 +41,8 @@ class JDBCHashMapTestSuite { @JvmStatic @BeforeClass fun before() { - val dataSourceAndDatabase = configureDatabase(makeTestDataSourceProperties()) - dataSource = dataSourceAndDatabase.first - database = dataSourceAndDatabase.second + initialiseTestSerialization() + database = configureDatabase(makeTestDataSourceProperties(), makeTestDatabaseProperties(), identitySvc = { throw UnsupportedOperationException("Identity Service should not be in use") }) setUpDatabaseTx() loadOnInitFalseMap = JDBCHashMap("test_map_false", loadOnInit = false) memoryConstrainedMap = JDBCHashMap("test_map_constrained", loadOnInit = false, maxBuckets = 1) @@ -61,7 +56,8 @@ class JDBCHashMapTestSuite { @AfterClass fun after() { closeDatabaseTx() - dataSource.close() + database.close() + resetTestSerialization() } @JvmStatic @@ -105,7 +101,7 @@ class JDBCHashMapTestSuite { .createTestSuite() private fun setUpDatabaseTx() { - transaction = TransactionManager.currentOrNew(Connection.TRANSACTION_REPEATABLE_READ) + transaction = DatabaseTransactionManager.currentOrNew() } private fun closeDatabaseTx() { @@ -203,7 +199,7 @@ class JDBCHashMapTestSuite { * * If the Map reloads, then so will the Set as it just delegates. */ - class MapCanBeReloaded { + class MapCanBeReloaded : TestDependencyInjectionBase() { private val ops = listOf(Triple(AddOrRemove.ADD, "A", "1"), Triple(AddOrRemove.ADD, "B", "2"), Triple(AddOrRemove.ADD, "C", "3"), @@ -228,22 +224,18 @@ class JDBCHashMapTestSuite { private val transientMapForComparison = applyOpsToMap(LinkedHashMap()) - lateinit var dataSource: Closeable - lateinit var database: Database + lateinit var database: CordaPersistence @Before fun before() { - val dataSourceAndDatabase = configureDatabase(makeTestDataSourceProperties()) - dataSource = dataSourceAndDatabase.first - database = dataSourceAndDatabase.second + database = configureDatabase(makeTestDataSourceProperties(), makeTestDatabaseProperties(), identitySvc = { throw UnsupportedOperationException("Identity Service should not be in use") }) } @After fun after() { - dataSource.close() + database.close() } - @Test fun `fill map and check content after reconstruction`() { database.transaction { diff --git a/node/src/integration-test/kotlin/net/corda/services/messaging/MQSecurityAsNodeTest.kt b/node/src/integration-test/kotlin/net/corda/services/messaging/MQSecurityAsNodeTest.kt index d4bb517abf..d053552c1c 100644 --- a/node/src/integration-test/kotlin/net/corda/services/messaging/MQSecurityAsNodeTest.kt +++ b/node/src/integration-test/kotlin/net/corda/services/messaging/MQSecurityAsNodeTest.kt @@ -1,9 +1,10 @@ package net.corda.services.messaging -import net.corda.core.copyTo -import net.corda.core.createDirectories -import net.corda.core.crypto.* -import net.corda.core.exists +import net.corda.core.crypto.Crypto +import net.corda.core.internal.copyTo +import net.corda.core.internal.createDirectories +import net.corda.core.internal.exists +import net.corda.node.utilities.* import net.corda.nodeapi.ArtemisMessagingComponent.Companion.NODE_USER import net.corda.nodeapi.ArtemisMessagingComponent.Companion.PEER_USER import net.corda.nodeapi.RPCApi @@ -19,7 +20,6 @@ import org.assertj.core.api.Assertions.assertThatExceptionOfType import org.bouncycastle.asn1.x509.GeneralName import org.bouncycastle.asn1.x509.GeneralSubtree import org.bouncycastle.asn1.x509.NameConstraints -import org.bouncycastle.cert.path.CertPath import org.junit.Test import java.nio.file.Files @@ -94,7 +94,9 @@ class MQSecurityAsNodeTest : MQSecurityTest() { javaClass.classLoader.getResourceAsStream("net/corda/node/internal/certificates/cordatruststore.jks").copyTo(trustStoreFile) } - val caKeyStore = KeyStoreUtilities.loadKeyStore(javaClass.classLoader.getResourceAsStream("net/corda/node/internal/certificates/cordadevcakeys.jks"), "cordacadevpass") + val caKeyStore = loadKeyStore( + javaClass.classLoader.getResourceAsStream("net/corda/node/internal/certificates/cordadevcakeys.jks"), + "cordacadevpass") val rootCACert = caKeyStore.getX509Certificate(X509Utilities.CORDA_ROOT_CA) val intermediateCA = caKeyStore.getCertificateAndKeyPair(X509Utilities.CORDA_INTERMEDIATE_CA, "cordacadevkeypass") @@ -102,25 +104,26 @@ class MQSecurityAsNodeTest : MQSecurityTest() { // Set name constrain to the legal name. val nameConstraints = NameConstraints(arrayOf(GeneralSubtree(GeneralName(GeneralName.directoryName, legalName))), arrayOf()) - val clientCACert = X509Utilities.createCertificate(CertificateType.INTERMEDIATE_CA, intermediateCA.certificate, intermediateCA.keyPair, legalName, clientKey.public, nameConstraints = nameConstraints) + val clientCACert = X509Utilities.createCertificate(CertificateType.INTERMEDIATE_CA, intermediateCA.certificate, + intermediateCA.keyPair, legalName, clientKey.public, nameConstraints = nameConstraints) val tlsKey = Crypto.generateKeyPair(X509Utilities.DEFAULT_TLS_SIGNATURE_SCHEME) // Using different x500 name in the TLS cert which is not allowed in the name constraints. val clientTLSCert = X509Utilities.createCertificate(CertificateType.TLS, clientCACert, clientKey, MINI_CORP.name, tlsKey.public) val keyPass = keyStorePassword.toCharArray() - val clientCAKeystore = KeyStoreUtilities.loadOrCreateKeyStore(nodeKeystore, keyStorePassword) + val clientCAKeystore = loadOrCreateKeyStore(nodeKeystore, keyStorePassword) clientCAKeystore.addOrReplaceKey( X509Utilities.CORDA_CLIENT_CA, clientKey.private, keyPass, - CertPath(arrayOf(clientCACert, intermediateCA.certificate, rootCACert))) + arrayOf(clientCACert, intermediateCA.certificate, rootCACert)) clientCAKeystore.save(nodeKeystore, keyStorePassword) - val tlsKeystore = KeyStoreUtilities.loadOrCreateKeyStore(sslKeystore, keyStorePassword) + val tlsKeystore = loadOrCreateKeyStore(sslKeystore, keyStorePassword) tlsKeystore.addOrReplaceKey( X509Utilities.CORDA_CLIENT_TLS, tlsKey.private, keyPass, - CertPath(arrayOf(clientTLSCert, clientCACert, intermediateCA.certificate, rootCACert))) + arrayOf(clientTLSCert, clientCACert, intermediateCA.certificate, rootCACert)) tlsKeystore.save(sslKeystore, keyStorePassword) } } diff --git a/node/src/integration-test/kotlin/net/corda/services/messaging/MQSecurityTest.kt b/node/src/integration-test/kotlin/net/corda/services/messaging/MQSecurityTest.kt index 9cae274260..b16f87f806 100644 --- a/node/src/integration-test/kotlin/net/corda/services/messaging/MQSecurityTest.kt +++ b/node/src/integration-test/kotlin/net/corda/services/messaging/MQSecurityTest.kt @@ -7,11 +7,11 @@ import net.corda.core.crypto.toBase58String import net.corda.core.flows.FlowLogic import net.corda.core.flows.InitiatedBy import net.corda.core.flows.InitiatingFlow -import net.corda.core.getOrThrow import net.corda.core.identity.Party import net.corda.core.messaging.CordaRPCOps import net.corda.core.utilities.NetworkHostAndPort import net.corda.core.crypto.random63BitValue +import net.corda.core.utilities.getOrThrow import net.corda.testing.ALICE import net.corda.testing.BOB import net.corda.core.utilities.unwrap @@ -151,7 +151,7 @@ abstract class MQSecurityTest : NodeBasedTest() { } fun loginToRPC(target: NetworkHostAndPort, rpcUser: User, sslConfiguration: SSLConfiguration? = null): CordaRPCOps { - return CordaRPCClient(target, sslConfiguration).start(rpcUser.username, rpcUser.password).proxy + return CordaRPCClient(target, sslConfiguration, initialiseSerialization = false).start(rpcUser.username, rpcUser.password).proxy } fun loginToRPCAndGetClientQueue(): String { diff --git a/node/src/integration-test/kotlin/net/corda/services/messaging/P2PMessagingTest.kt b/node/src/integration-test/kotlin/net/corda/services/messaging/P2PMessagingTest.kt index d7e0dc2c3e..df2cc5f069 100644 --- a/node/src/integration-test/kotlin/net/corda/services/messaging/P2PMessagingTest.kt +++ b/node/src/integration-test/kotlin/net/corda/services/messaging/P2PMessagingTest.kt @@ -1,17 +1,20 @@ package net.corda.services.messaging -import com.google.common.util.concurrent.Futures -import com.google.common.util.concurrent.ListenableFuture -import net.corda.core.* +import net.corda.core.concurrent.CordaFuture import net.corda.core.crypto.random63BitValue +import net.corda.core.internal.concurrent.transpose +import net.corda.core.internal.elapsedTime +import net.corda.core.internal.times import net.corda.core.messaging.MessageRecipients import net.corda.core.messaging.SingleMessageRecipient -import net.corda.core.node.services.DEFAULT_SESSION_ID import net.corda.core.node.services.ServiceInfo import net.corda.core.serialization.CordaSerializable import net.corda.core.serialization.deserialize import net.corda.core.serialization.serialize +import net.corda.core.utilities.getOrThrow +import net.corda.core.utilities.seconds import net.corda.node.internal.Node +import net.corda.node.services.api.DEFAULT_SESSION_ID import net.corda.node.services.messaging.* import net.corda.node.services.transactions.RaftValidatingNotaryService import net.corda.node.services.transactions.SimpleNotaryService @@ -36,13 +39,13 @@ class P2PMessagingTest : NodeBasedTest() { @Test fun `network map will work after restart`() { val identities = listOf(DUMMY_BANK_A, DUMMY_BANK_B, DUMMY_NOTARY) - fun startNodes() = Futures.allAsList(identities.map { startNode(it.name) }) + fun startNodes() = identities.map { startNode(it.name) }.transpose() val startUpDuration = elapsedTime { startNodes().getOrThrow() } // Start the network map a second time - this will restore message queues from the journal. // This will hang and fail prior the fix. https://github.com/corda/corda/issues/37 stopAllNodes() - startNodes().getOrThrow(timeout = startUpDuration.multipliedBy(3)) + startNodes().getOrThrow(timeout = startUpDuration * 3) } // https://github.com/corda/corda/issues/71 @@ -72,7 +75,7 @@ class P2PMessagingTest : NodeBasedTest() { DUMMY_MAP.name, advertisedServices = setOf(distributedService), configOverrides = mapOf("notaryNodeAddress" to notaryClusterAddress.toString())) - val (serviceNode2, alice) = Futures.allAsList( + val (serviceNode2, alice) = listOf( startNode( SERVICE_2_NAME, advertisedServices = setOf(distributedService), @@ -80,7 +83,7 @@ class P2PMessagingTest : NodeBasedTest() { "notaryNodeAddress" to freeLocalHostAndPort().toString(), "notaryClusterAddresses" to listOf(notaryClusterAddress.toString()))), startNode(ALICE.name) - ).getOrThrow() + ).transpose().getOrThrow() assertAllNodesAreUsed(listOf(networkMapNode, serviceNode2), DISTRIBUTED_SERVICE_NAME, alice) } @@ -214,7 +217,7 @@ class P2PMessagingTest : NodeBasedTest() { } } - private fun Node.receiveFrom(target: MessageRecipients): ListenableFuture { + private fun Node.receiveFrom(target: MessageRecipients): CordaFuture { val request = TestRequest(replyTo = network.myAddress) return network.sendRequest(javaClass.name, request, target) } diff --git a/node/src/integration-test/kotlin/net/corda/services/messaging/P2PSecurityTest.kt b/node/src/integration-test/kotlin/net/corda/services/messaging/P2PSecurityTest.kt index 5036db9077..0e32802348 100644 --- a/node/src/integration-test/kotlin/net/corda/services/messaging/P2PSecurityTest.kt +++ b/node/src/integration-test/kotlin/net/corda/services/messaging/P2PSecurityTest.kt @@ -1,13 +1,13 @@ package net.corda.services.messaging -import com.google.common.util.concurrent.ListenableFuture import com.nhaarman.mockito_kotlin.whenever -import net.corda.core.crypto.X509Utilities +import net.corda.core.concurrent.CordaFuture import net.corda.core.crypto.cert -import net.corda.core.getOrThrow -import net.corda.core.node.NodeInfo import net.corda.core.crypto.random63BitValue -import net.corda.core.seconds +import net.corda.core.node.NodeInfo +import net.corda.core.utilities.seconds +import net.corda.core.utilities.NonEmptySet +import net.corda.core.utilities.getOrThrow import net.corda.node.internal.NetworkMapInfo import net.corda.node.services.config.configureWithDevSSLCertificate import net.corda.node.services.messaging.sendRequest @@ -30,7 +30,7 @@ class P2PSecurityTest : NodeBasedTest() { @Test fun `incorrect legal name for the network map service config`() { - val incorrectNetworkMapName = X509Utilities.getDevX509Name("NetworkMap-${random63BitValue()}") + val incorrectNetworkMapName = getTestX509Name("NetworkMap-${random63BitValue()}") val node = startNode(BOB.name, configOverrides = mapOf( "networkMapService" to mapOf( "address" to networkMapNode.configuration.p2pAddress.toString(), @@ -65,9 +65,9 @@ class P2PSecurityTest : NodeBasedTest() { return SimpleNode(config, trustRoot = trustRoot).apply { start() } } - private fun SimpleNode.registerWithNetworkMap(registrationName: X500Name): ListenableFuture { + private fun SimpleNode.registerWithNetworkMap(registrationName: X500Name): CordaFuture { val legalIdentity = getTestPartyAndCertificate(registrationName, identity.public) - val nodeInfo = NodeInfo(listOf(MOCK_HOST_AND_PORT), legalIdentity, setOf(legalIdentity), 1) + val nodeInfo = NodeInfo(listOf(MOCK_HOST_AND_PORT), legalIdentity, NonEmptySet.of(legalIdentity), 1) val registration = NodeRegistration(nodeInfo, System.currentTimeMillis(), AddOrRemove.ADD, Instant.MAX) val request = RegistrationRequest(registration.toWire(keyService, identity.public), network.myAddress) return network.sendRequest(NetworkMapService.REGISTER_TOPIC, request, networkMapNode.network.myAddress) diff --git a/node/src/main/kotlin/net/corda/node/ArgsParser.kt b/node/src/main/kotlin/net/corda/node/ArgsParser.kt index 6168e4b99b..1527beec44 100644 --- a/node/src/main/kotlin/net/corda/node/ArgsParser.kt +++ b/node/src/main/kotlin/net/corda/node/ArgsParser.kt @@ -2,7 +2,7 @@ package net.corda.node import joptsimple.OptionParser import joptsimple.util.EnumConverter -import net.corda.core.div +import net.corda.core.internal.div import net.corda.node.services.config.ConfigHelper import net.corda.node.services.config.FullNodeConfiguration import net.corda.nodeapi.config.parseAs diff --git a/node/src/main/kotlin/net/corda/node/SerialFilter.kt b/node/src/main/kotlin/net/corda/node/SerialFilter.kt index ef6ebe138d..9998eacd7b 100644 --- a/node/src/main/kotlin/net/corda/node/SerialFilter.kt +++ b/node/src/main/kotlin/net/corda/node/SerialFilter.kt @@ -1,7 +1,7 @@ package net.corda.node -import net.corda.core.DeclaredField -import net.corda.core.DeclaredField.Companion.declaredField +import net.corda.core.internal.DeclaredField +import net.corda.core.internal.staticField import net.corda.node.internal.Node import java.lang.reflect.Method import java.lang.reflect.Proxy @@ -32,8 +32,8 @@ internal object SerialFilter { undecided = statusEnum.getField("UNDECIDED").get(null) rejected = statusEnum.getField("REJECTED").get(null) val configClass = Class.forName("${filterInterface.name}\$Config") - serialFilterLock = declaredField(configClass, "serialFilterLock").value - serialFilterField = declaredField(configClass, "serialFilter") + serialFilterLock = configClass.staticField("serialFilterLock").value + serialFilterField = configClass.staticField("serialFilter") } internal fun install(acceptClass: (Class<*>) -> Boolean) { diff --git a/node/src/main/kotlin/net/corda/node/internal/AbstractNode.kt b/node/src/main/kotlin/net/corda/node/internal/AbstractNode.kt index b05b0d6c33..71848d1616 100644 --- a/node/src/main/kotlin/net/corda/node/internal/AbstractNode.kt +++ b/node/src/main/kotlin/net/corda/node/internal/AbstractNode.kt @@ -3,17 +3,17 @@ package net.corda.node.internal import com.codahale.metrics.MetricRegistry import com.google.common.annotations.VisibleForTesting import com.google.common.collect.MutableClassToInstanceMap -import com.google.common.util.concurrent.ListenableFuture import com.google.common.util.concurrent.MoreExecutors -import com.google.common.util.concurrent.SettableFuture import io.github.lukehutch.fastclasspathscanner.FastClasspathScanner import io.github.lukehutch.fastclasspathscanner.scanner.ScanResult -import net.corda.core.* +import net.corda.core.concurrent.CordaFuture import net.corda.core.crypto.* -import net.corda.core.crypto.composite.CompositeKey import net.corda.core.flows.* import net.corda.core.identity.Party import net.corda.core.identity.PartyAndCertificate +import net.corda.core.internal.* +import net.corda.core.internal.concurrent.flatMap +import net.corda.core.internal.concurrent.openFuture import net.corda.core.messaging.CordaRPCOps import net.corda.core.messaging.RPCOps import net.corda.core.messaging.SingleMessageRecipient @@ -22,12 +22,18 @@ import net.corda.core.node.services.* import net.corda.core.node.services.NetworkMapCache.MapChange import net.corda.core.serialization.SerializeAsToken import net.corda.core.serialization.SingletonSerializeAsToken -import net.corda.core.serialization.deserialize import net.corda.core.transactions.SignedTransaction import net.corda.core.utilities.NetworkHostAndPort import net.corda.core.utilities.debug -import net.corda.flows.* -import net.corda.node.services.* +import net.corda.core.utilities.toNonEmptySet +import net.corda.flows.CashExitFlow +import net.corda.flows.CashIssueFlow +import net.corda.flows.CashPaymentFlow +import net.corda.flows.IssuerFlow +import net.corda.node.services.ContractUpgradeHandler +import net.corda.node.services.NotaryChangeHandler +import net.corda.node.services.NotifyTransactionHandler +import net.corda.node.services.TransactionKeyHandler import net.corda.node.services.api.* import net.corda.node.services.config.NodeConfiguration import net.corda.node.services.config.configureWithDevSSLCertificate @@ -54,31 +60,27 @@ import net.corda.node.services.statemachine.FlowStateMachineImpl import net.corda.node.services.statemachine.StateMachineManager import net.corda.node.services.statemachine.flowVersionAndInitiatingClass import net.corda.node.services.transactions.* -import net.corda.node.services.vault.CashBalanceAsMetricsObserver import net.corda.node.services.vault.HibernateVaultQueryImpl import net.corda.node.services.vault.NodeVaultService import net.corda.node.services.vault.VaultSoftLockManager +import net.corda.node.utilities.* import net.corda.node.utilities.AddOrRemove.ADD -import net.corda.node.utilities.AffinityExecutor -import net.corda.node.utilities.configureDatabase -import net.corda.node.utilities.transaction import org.apache.activemq.artemis.utils.ReusableLatch import org.bouncycastle.asn1.x500.X500Name -import org.jetbrains.exposed.sql.Database import org.slf4j.Logger import rx.Observable import java.io.IOException import java.lang.reflect.InvocationTargetException import java.lang.reflect.Modifier.* -import java.math.BigInteger import java.net.JarURLConnection import java.net.URI import java.nio.file.Path import java.nio.file.Paths import java.security.KeyPair -import java.security.KeyStore import java.security.KeyStoreException -import java.security.cert.* +import java.security.cert.CertificateFactory +import java.security.cert.X509Certificate +import java.sql.Connection import java.time.Clock import java.util.* import java.util.concurrent.ConcurrentHashMap @@ -132,15 +134,15 @@ abstract class AbstractNode(open val configuration: NodeConfiguration, var inNodeNetworkMapService: NetworkMapService? = null lateinit var network: MessagingService protected val runOnStop = ArrayList<() -> Any?>() - lateinit var database: Database + lateinit var database: CordaPersistence protected var dbCloser: (() -> Any?)? = null var isPreviousCheckpointsPresent = false private set - protected val _networkMapRegistrationFuture: SettableFuture = SettableFuture.create() + protected val _networkMapRegistrationFuture = openFuture() /** Completes once the node has successfully registered with the network map service */ - val networkMapRegistrationFuture: ListenableFuture + val networkMapRegistrationFuture: CordaFuture get() = _networkMapRegistrationFuture /** Fetch CordaPluginRegistry classes registered in META-INF/services/net.corda.core.node.CordaPluginRegistry files that exist in the classpath */ @@ -159,7 +161,7 @@ abstract class AbstractNode(open val configuration: NodeConfiguration, return configuration.myLegalName.locationOrNull?.let { CityDatabase[it] } } - open fun start(): AbstractNode { + open fun start() { require(!started) { "Node has already been started" } if (configuration.devMode) { @@ -210,17 +212,14 @@ abstract class AbstractNode(open val configuration: NodeConfiguration, // TODO Remove this once the cash stuff is in its own CorDapp registerInitiatedFlow(IssuerFlow.Issuer::class.java) - initUploaders() - runOnStop += network::stop - _networkMapRegistrationFuture.setFuture(registerWithNetworkMapIfConfigured()) + _networkMapRegistrationFuture.captureLater(registerWithNetworkMapIfConfigured()) smm.start() // Shut down the SMM so no Fibers are scheduled. runOnStop += { smm.stop(acceptableLiveFiberCountOnStop()) } _services.schedulerService.start() } started = true - return this } private class ServiceInstantiationException(cause: Throwable?) : Exception(cause) @@ -284,7 +283,7 @@ abstract class AbstractNode(open val configuration: NodeConfiguration, private fun handleCustomNotaryService(service: NotaryService) { runOnStop += service::stop service.start() - installCoreFlow(NotaryFlow.Client::class, { party: Party, version: Int -> service.createServiceFlow(party, version) }) + installCoreFlow(NotaryFlow.Client::class, service::createServiceFlow) } private inline fun Class<*>.requireAnnotation(): A { @@ -346,9 +345,15 @@ abstract class AbstractNode(open val configuration: NodeConfiguration, val initiatingFlow = initiatedFlow.requireAnnotation().value.java val (version, classWithAnnotation) = initiatingFlow.flowVersionAndInitiatingClass require(classWithAnnotation == initiatingFlow) { - "${InitiatingFlow::class.java.name} must be annotated on ${initiatingFlow.name} and not on a super-type" + "${InitiatedBy::class.java.name} must point to ${classWithAnnotation.name} and not ${initiatingFlow.name}" } - val flowFactory = InitiatedFlowFactory.CorDapp(version, { ctor.newInstance(it) }) + val jarFile = Paths.get(initiatedFlow.protectionDomain.codeSource.location.toURI()) + val appName = if (jarFile.isRegularFile() && jarFile.toString().endsWith(".jar")) { + jarFile.fileName.toString().removeSuffix(".jar") + } else { + "" + } + val flowFactory = InitiatedFlowFactory.CorDapp(version, appName, { ctor.newInstance(it) }) val observable = internalRegisterFlowFactory(initiatingFlow, flowFactory, initiatedFlow, track) log.info("Registered ${initiatingFlow.name} to initiate ${initiatedFlow.name} (version $version)") return observable @@ -392,7 +397,7 @@ abstract class AbstractNode(open val configuration: NodeConfiguration, * @suppress */ @VisibleForTesting - fun installCoreFlow(clientFlowClass: KClass>, flowFactory: (Party, Int) -> FlowLogic<*>) { + fun installCoreFlow(clientFlowClass: KClass>, flowFactory: (Party) -> FlowLogic<*>) { require(clientFlowClass.java.flowVersionAndInitiatingClass.first == 1) { "${InitiatingFlow::class.java.name}.version not applicable for core flows; their version is the node's platform version" } @@ -401,12 +406,10 @@ abstract class AbstractNode(open val configuration: NodeConfiguration, } private fun installCoreFlows() { - installCoreFlow(FetchTransactionsFlow::class) { otherParty, _ -> FetchTransactionsHandler(otherParty) } - installCoreFlow(FetchAttachmentsFlow::class) { otherParty, _ -> FetchAttachmentsHandler(otherParty) } - installCoreFlow(BroadcastTransactionFlow::class) { otherParty, _ -> NotifyTransactionHandler(otherParty) } - installCoreFlow(NotaryChangeFlow::class) { otherParty, _ -> NotaryChangeHandler(otherParty) } - installCoreFlow(ContractUpgradeFlow::class) { otherParty, _ -> ContractUpgradeHandler(otherParty) } - installCoreFlow(TransactionKeyFlow::class) { otherParty, _ -> TransactionKeyHandler(otherParty) } + installCoreFlow(BroadcastTransactionFlow::class, ::NotifyTransactionHandler) + installCoreFlow(NotaryChangeFlow::class, ::NotaryChangeHandler) + installCoreFlow(ContractUpgradeFlow::class, ::ContractUpgradeHandler) + installCoreFlow(TransactionKeyFlow::class, ::TransactionKeyHandler) } /** @@ -416,7 +419,7 @@ abstract class AbstractNode(open val configuration: NodeConfiguration, private fun makeServices(): MutableList { checkpointStorage = DBCheckpointStorage() _services = ServiceHubInternalImpl() - attachments = createAttachmentStorage() + attachments = NodeAttachmentService(configuration.dataSourceProperties, services.monitoringService.metrics, configuration.database) network = makeMessagingService() info = makeInfo() @@ -480,22 +483,16 @@ abstract class AbstractNode(open val configuration: NodeConfiguration, .filterNot { isAbstract(it.modifiers) } } - private fun initUploaders() { - _services.uploaders += attachments - cordappServices.values.filterIsInstanceTo(_services.uploaders, AcceptsFileUpload::class.java) - } - private fun makeVaultObservers() { VaultSoftLockManager(services.vaultService, smm) - CashBalanceAsMetricsObserver(services, database) ScheduledActivityObserver(services) - HibernateObserver(services.vaultService.rawUpdates, HibernateConfiguration(services.schemaService)) + HibernateObserver(services.vaultService.rawUpdates, HibernateConfiguration(services.schemaService, configuration.database ?: Properties(), {services.identityService})) } private fun makeInfo(): NodeInfo { val advertisedServiceEntries = makeServiceEntries() val legalIdentity = obtainLegalIdentity() - val allIdentitiesSet = advertisedServiceEntries.map { it.identity }.toSet() + legalIdentity + val allIdentitiesSet = (advertisedServiceEntries.map { it.identity } + legalIdentity).toNonEmptySet() val addresses = myAddresses() // TODO There is no support for multiple IP addresses yet. return NodeInfo(addresses, legalIdentity, allIdentitiesSet, platformVersion, advertisedServiceEntries, findMyLocation()) } @@ -519,8 +516,8 @@ abstract class AbstractNode(open val configuration: NodeConfiguration, private fun validateKeystore() { val containCorrectKeys = try { // This will throw IOException if key file not found or KeyStoreException if keystore password is incorrect. - val sslKeystore = KeyStoreUtilities.loadKeyStore(configuration.sslKeystore, configuration.keyStorePassword) - val identitiesKeystore = KeyStoreUtilities.loadKeyStore(configuration.nodeKeystore, configuration.keyStorePassword) + val sslKeystore = loadKeyStore(configuration.sslKeystore, configuration.keyStorePassword) + val identitiesKeystore = loadKeyStore(configuration.nodeKeystore, configuration.keyStorePassword) sslKeystore.containsAlias(X509Utilities.CORDA_CLIENT_TLS) && identitiesKeystore.containsAlias(X509Utilities.CORDA_CLIENT_CA) } catch (e: KeyStoreException) { log.warn("Certificate key store found but key store password does not match configuration.") @@ -534,7 +531,7 @@ abstract class AbstractNode(open val configuration: NodeConfiguration, "or if you don't have one yet, fill out the config file and run corda.jar --initial-registration. " + "Read more at: https://docs.corda.net/permissioning.html" } - val identitiesKeystore = KeyStoreUtilities.loadKeyStore(configuration.sslKeystore, configuration.keyStorePassword) + val identitiesKeystore = loadKeyStore(configuration.sslKeystore, configuration.keyStorePassword) val tlsIdentity = identitiesKeystore.getX509Certificate(X509Utilities.CORDA_CLIENT_TLS).subject require(tlsIdentity == configuration.myLegalName) { @@ -548,11 +545,12 @@ abstract class AbstractNode(open val configuration: NodeConfiguration, protected open fun initialiseDatabasePersistence(insideTransaction: () -> Unit) { val props = configuration.dataSourceProperties if (props.isNotEmpty()) { - val (toClose, database) = configureDatabase(props) - this.database = database + this.database = configureDatabase(props, configuration.database, identitySvc = { _services.identityService }) // Now log the vendor string as this will also cause a connection to be tested eagerly. - log.info("Connected to ${database.vendor} database.") - toClose::close.let { + database.transaction { + log.info("Connected to ${database.database.vendor} database.") + } + this.database::close.let { dbCloser = it runOnStop += it } @@ -564,25 +562,27 @@ abstract class AbstractNode(open val configuration: NodeConfiguration, } } - /** - * Run any tasks that are needed to ensure the node is in a correct state before running start(). - */ - open fun setup(): AbstractNode { - configuration.baseDirectory.createDirectories() - return this - } - private fun makeAdvertisedServices(tokenizableServices: MutableList) { val serviceTypes = info.advertisedServices.map { it.info.type } if (NetworkMapService.type in serviceTypes) makeNetworkMapService() - val notaryServiceType = serviceTypes.singleOrNull { it.isNotary() } if (notaryServiceType != null) { - makeCoreNotaryService(notaryServiceType, tokenizableServices) + val service = makeCoreNotaryService(notaryServiceType) + if (service != null) { + service.apply { + tokenizableServices.add(this) + runOnStop += this::stop + start() + } + installCoreFlow(NotaryFlow.Client::class, service::createServiceFlow) + } else { + log.info("Notary type ${notaryServiceType.id} does not match any built-in notary types. " + + "It is expected to be loaded via a CorDapp") + } } } - private fun registerWithNetworkMapIfConfigured(): ListenableFuture { + private fun registerWithNetworkMapIfConfigured(): CordaFuture { services.networkMapCache.addNode(info) // In the unit test environment, we may sometimes run without any network map service return if (networkMapAddress == null && inNodeNetworkMapService == null) { @@ -597,7 +597,7 @@ abstract class AbstractNode(open val configuration: NodeConfiguration, * Register this node with the network map cache, and load network map from a remote service (and register for * updates) if one has been supplied. */ - protected open fun registerWithNetworkMap(): ListenableFuture { + protected open fun registerWithNetworkMap(): CordaFuture { require(networkMapAddress != null || NetworkMapService.type in advertisedServices.map { it.type }) { "Initial network map address must indicate a node that provides a network map service" } @@ -611,7 +611,7 @@ abstract class AbstractNode(open val configuration: NodeConfiguration, } } - private fun sendNetworkMapRegistration(networkMapAddress: SingleMessageRecipient): ListenableFuture { + private fun sendNetworkMapRegistration(networkMapAddress: SingleMessageRecipient): CordaFuture { // Register this node against the network val instant = platformClock.instant() val expires = instant + NetworkMapService.DEFAULT_EXPIRATION_PERIOD @@ -625,7 +625,7 @@ abstract class AbstractNode(open val configuration: NodeConfiguration, protected abstract fun myAddresses(): List /** This is overriden by the mock node implementation to enable operation without any network map service */ - protected open fun noNetworkMapConfigured(): ListenableFuture { + protected open fun noNetworkMapConfigured(): CordaFuture { // TODO: There should be a consistent approach to configuration error exceptions. throw IllegalStateException("Configuration error: this node isn't being asked to act as the network map, nor " + "has any other map node been configured.") @@ -639,25 +639,15 @@ abstract class AbstractNode(open val configuration: NodeConfiguration, inNodeNetworkMapService = PersistentNetworkMapService(services, configuration.minimumPlatformVersion) } - open protected fun makeCoreNotaryService(type: ServiceType, tokenizableServices: MutableList) { - val service: NotaryService = when (type) { + open protected fun makeCoreNotaryService(type: ServiceType): NotaryService? { + return when (type) { SimpleNotaryService.type -> SimpleNotaryService(services) ValidatingNotaryService.type -> ValidatingNotaryService(services) RaftNonValidatingNotaryService.type -> RaftNonValidatingNotaryService(services) RaftValidatingNotaryService.type -> RaftValidatingNotaryService(services) BFTNonValidatingNotaryService.type -> BFTNonValidatingNotaryService(services) - else -> { - log.info("Notary type ${type.id} does not match any built-in notary types. " + - "It is expected to be loaded via a CorDapp") - return - } + else -> null } - service.apply { - tokenizableServices.add(this) - runOnStop += this::stop - start() - } - installCoreFlow(NotaryFlow.Client::class, { party: Party, version: Int -> service.createServiceFlow(party, version) }) } protected open fun makeIdentityService(trustRoot: X509Certificate, @@ -709,63 +699,59 @@ abstract class AbstractNode(open val configuration: NodeConfiguration, // the legal name is actually validated in some way. // TODO: Integrate with Key management service? - val certFactory = CertificateFactory.getInstance("X509") val keyStore = KeyStoreWrapper(configuration.nodeKeystore, configuration.keyStorePassword) val privateKeyAlias = "$serviceId-private-key" - val privKeyFile = configuration.baseDirectory / privateKeyAlias - val pubIdentityFile = configuration.baseDirectory / "$serviceId-public" - val certificateAndKeyPair = keyStore.certificateAndKeyPair(privateKeyAlias) - val identityCertPathAndKey: Pair = if (certificateAndKeyPair != null) { - val clientCertPath = keyStore.keyStore.getCertificateChain(X509Utilities.CORDA_CLIENT_CA) - val (cert, keyPair) = certificateAndKeyPair - // Get keys from keystore. - val loadedServiceName = cert.subject - if (loadedServiceName != serviceName) { - throw ConfigurationException("The legal name in the config file doesn't match the stored identity keystore:" + - "$serviceName vs $loadedServiceName") - } - val certPath = certFactory.generateCertPath(listOf(cert.cert) + clientCertPath) - Pair(PartyAndCertificate(loadedServiceName, keyPair.public, cert, certPath), keyPair) - } else if (privKeyFile.exists()) { + val compositeKeyAlias = "$serviceId-composite-key" + + if (!keyStore.containsAlias(privateKeyAlias)) { + val privKeyFile = configuration.baseDirectory / privateKeyAlias + val pubIdentityFile = configuration.baseDirectory / "$serviceId-public" + val compositeKeyFile = configuration.baseDirectory / compositeKeyAlias + // TODO: Remove use of [ServiceIdentityGenerator.generateToDisk]. // Get keys from key file. - // TODO: this is here to smooth out the key storage transition, remove this in future release. - // Check that the identity in the config file matches the identity file we have stored to disk. - // This is just a sanity check. It shouldn't fail unless the admin has fiddled with the files and messed - // things up for us. - val myIdentity = pubIdentityFile.readAll().deserialize() - if (myIdentity.name != serviceName) - throw ConfigurationException("The legal name in the config file doesn't match the stored identity file:" + - "$serviceName vs ${myIdentity.name}") - // Load the private key. - val keyPair = privKeyFile.readAll().deserialize() - if (myIdentity.owningKey !is CompositeKey) { // TODO: Support case where owningKey is a composite key. - keyStore.save(serviceName, privateKeyAlias, keyPair) + // TODO: this is here to smooth out the key storage transition, remove this migration in future release. + if (privKeyFile.exists()) { + migrateKeysFromFile(keyStore, serviceName, pubIdentityFile, privKeyFile, compositeKeyFile, privateKeyAlias, compositeKeyAlias) + } else { + log.info("$privateKeyAlias not found in keystore ${configuration.nodeKeystore}, generating fresh key!") + keyStore.saveNewKeyPair(serviceName, privateKeyAlias, generateKeyPair()) } - val dummyCaKey = entropyToKeyPair(BigInteger.valueOf(111)) - val dummyCa = CertificateAndKeyPair( - X509Utilities.createSelfSignedCACertificate(X500Name("CN=Dummy CA,OU=Corda,O=R3 Ltd,L=London,C=GB"), dummyCaKey), - dummyCaKey) - val partyAndCertificate = getTestPartyAndCertificate(myIdentity, dummyCa) - // Sanity check the certificate and path - val validatorParameters = PKIXParameters(setOf(TrustAnchor(dummyCa.certificate.cert, null))) - val validator = CertPathValidator.getInstance("PKIX") - validatorParameters.isRevocationEnabled = false - validator.validate(partyAndCertificate.certPath, validatorParameters) as PKIXCertPathValidatorResult - Pair(partyAndCertificate, keyPair) - } else { - val clientCertPath = keyStore.keyStore.getCertificateChain(X509Utilities.CORDA_CLIENT_CA) - val clientCA = keyStore.certificateAndKeyPair(X509Utilities.CORDA_CLIENT_CA)!! - // Create new keys and store in keystore. - log.info("Identity key not found, generating fresh key!") - val keyPair: KeyPair = generateKeyPair() - val cert = X509Utilities.createCertificate(CertificateType.IDENTITY, clientCA.certificate, clientCA.keyPair, serviceName, keyPair.public) - val certPath = certFactory.generateCertPath(listOf(cert.cert) + clientCertPath) - keyStore.save(serviceName, privateKeyAlias, keyPair) - require(certPath.certificates.isNotEmpty()) { "Certificate path cannot be empty" } - Pair(PartyAndCertificate(serviceName, keyPair.public, cert, certPath), keyPair) } - partyKeys += identityCertPathAndKey.second - return identityCertPathAndKey + + val (cert, keyPair) = keyStore.certificateAndKeyPair(privateKeyAlias) + + // Get keys from keystore. + val loadedServiceName = cert.subject + if (loadedServiceName != serviceName) + throw ConfigurationException("The legal name in the config file doesn't match the stored identity keystore:$serviceName vs $loadedServiceName") + + val certPath = CertificateFactory.getInstance("X509").generateCertPath(keyStore.getCertificateChain(privateKeyAlias).toList()) + // Use composite key instead if exists + // TODO: Use configuration to indicate composite key should be used instead of public key for the identity. + val publicKey = if (keyStore.containsAlias(compositeKeyAlias)) { + Crypto.toSupportedPublicKey(keyStore.getCertificate(compositeKeyAlias).publicKey) + } else { + keyPair.public + } + + partyKeys += keyPair + return Pair(PartyAndCertificate(loadedServiceName, publicKey, cert, certPath), keyPair) + } + + private fun migrateKeysFromFile(keyStore: KeyStoreWrapper, serviceName: X500Name, + pubKeyFile: Path, privKeyFile: Path, compositeKeyFile:Path, + privateKeyAlias: String, compositeKeyAlias: String) { + log.info("Migrating $privateKeyAlias from file to keystore...") + // Check that the identity in the config file matches the identity file we have stored to disk. + // Load the private key. + val publicKey = Crypto.decodePublicKey(pubKeyFile.readAll()) + val privateKey = Crypto.decodePrivateKey(privKeyFile.readAll()) + keyStore.saveNewKeyPair(serviceName, privateKeyAlias, KeyPair(publicKey, privateKey)) + // Store composite key separately. + if (compositeKeyFile.exists()) { + keyStore.savePublicKey(serviceName, compositeKeyAlias, Crypto.decodePublicKey(compositeKeyFile.readAll())) + } + log.info("Finish migrating $privateKeyAlias from file to keystore.") } private fun getTestPartyAndCertificate(party: Party, trustRoot: CertificateAndKeyPair): PartyAndCertificate { @@ -777,34 +763,29 @@ abstract class AbstractNode(open val configuration: NodeConfiguration, protected open fun generateKeyPair() = cryptoGenerateKeyPair() - private fun createAttachmentStorage(): NodeAttachmentService { - val attachmentsDir = (configuration.baseDirectory / "attachments").createDirectories() - return NodeAttachmentService(attachmentsDir, configuration.dataSourceProperties, services.monitoringService.metrics) - } - private inner class ServiceHubInternalImpl : ServiceHubInternal, SingletonSerializeAsToken() { override val rpcFlows = ArrayList>>() - override val uploaders = ArrayList() override val stateMachineRecordedTransactionMapping = DBTransactionMappingStorage() override val auditService = DummyAuditService() override val monitoringService = MonitoringService(MetricRegistry()) override val validatedTransactions = makeTransactionStorage() override val transactionVerifierService by lazy { makeTransactionVerifierService() } override val networkMapCache by lazy { InMemoryNetworkMapCache(this) } - override val vaultService by lazy { NodeVaultService(this, configuration.dataSourceProperties) } + override val vaultService by lazy { NodeVaultService(this, configuration.dataSourceProperties, configuration.database) } override val vaultQueryService by lazy { - HibernateVaultQueryImpl(HibernateConfiguration(schemaService), vaultService.updatesPublisher) + HibernateVaultQueryImpl(HibernateConfiguration(schemaService, configuration.database ?: Properties(), { identityService }), vaultService.updatesPublisher) } // Place the long term identity key in the KMS. Eventually, this is likely going to be separated again because // the KMS is meant for derived temporary keys used in transactions, and we're not supposed to sign things with // the identity key. But the infrastructure to make that easy isn't here yet. override val keyManagementService by lazy { makeKeyManagementService(identityService) } - override val schedulerService by lazy { NodeSchedulerService(this, unfinishedSchedules = busyNodeLatch) } + override val schedulerService by lazy { NodeSchedulerService(this, unfinishedSchedules = busyNodeLatch, serverThread = serverThread) } override val identityService by lazy { - val keyStoreWrapper = KeyStoreWrapper(configuration.trustStoreFile, configuration.trustStorePassword) + val trustStore = KeyStoreWrapper(configuration.trustStoreFile, configuration.trustStorePassword) + val caKeyStore = KeyStoreWrapper(configuration.nodeKeystore, configuration.keyStorePassword) makeIdentityService( - keyStoreWrapper.keyStore.getCertificate(X509Utilities.CORDA_ROOT_CA)!! as X509Certificate, - keyStoreWrapper.certificateAndKeyPair(X509Utilities.CORDA_CLIENT_CA), + trustStore.getX509Certificate(X509Utilities.CORDA_ROOT_CA).cert, + caKeyStore.certificateAndKeyPair(X509Utilities.CORDA_CLIENT_CA), info.legalIdentityAndCert) } override val attachments: AttachmentStorage get() = this@AbstractNode.attachments @@ -812,7 +793,7 @@ abstract class AbstractNode(open val configuration: NodeConfiguration, override val clock: Clock get() = platformClock override val myInfo: NodeInfo get() = info override val schemaService by lazy { NodeSchemaService(pluginRegistries.flatMap { it.requiredSchemas }.toSet()) } - override val database: Database get() = this@AbstractNode.database + override val database: CordaPersistence get() = this@AbstractNode.database override val configuration: NodeConfiguration get() = this@AbstractNode.configuration override fun cordaService(type: Class): T { @@ -833,21 +814,7 @@ abstract class AbstractNode(open val configuration: NodeConfiguration, super.recordTransactions(txs) } } + override fun jdbcSession(): Connection = database.createSession() } } - -private class KeyStoreWrapper(val keyStore: KeyStore, val storePath: Path, private val storePassword: String) { - constructor(storePath: Path, storePassword: String) : this(KeyStoreUtilities.loadKeyStore(storePath, storePassword), storePath, storePassword) - - fun certificateAndKeyPair(alias: String): CertificateAndKeyPair? { - return if (keyStore.containsAlias(alias)) keyStore.getCertificateAndKeyPair(alias, storePassword) else null - } - - fun save(serviceName: X500Name, privateKeyAlias: String, keyPair: KeyPair) { - val clientCA = keyStore.getCertificateAndKeyPair(X509Utilities.CORDA_CLIENT_CA, storePassword) - val cert = X509Utilities.createCertificate(CertificateType.IDENTITY, clientCA.certificate, clientCA.keyPair, serviceName, keyPair.public).cert - keyStore.addOrReplaceKey(privateKeyAlias, keyPair.private, storePassword.toCharArray(), arrayOf(cert, *keyStore.getCertificateChain(X509Utilities.CORDA_CLIENT_CA))) - keyStore.save(storePath, storePassword) - } -} diff --git a/node/src/main/kotlin/net/corda/node/internal/CordaRPCOpsImpl.kt b/node/src/main/kotlin/net/corda/node/internal/CordaRPCOpsImpl.kt index f881064e27..d934483095 100644 --- a/node/src/main/kotlin/net/corda/node/internal/CordaRPCOpsImpl.kt +++ b/node/src/main/kotlin/net/corda/node/internal/CordaRPCOpsImpl.kt @@ -1,6 +1,5 @@ package net.corda.node.internal -import net.corda.core.contracts.Amount import net.corda.core.contracts.ContractState import net.corda.core.contracts.StateAndRef import net.corda.core.contracts.UpgradedContract @@ -24,14 +23,12 @@ import net.corda.node.services.messaging.requirePermission import net.corda.node.services.startFlowPermission import net.corda.node.services.statemachine.FlowStateMachineImpl import net.corda.node.services.statemachine.StateMachineManager -import net.corda.node.utilities.transaction +import net.corda.node.utilities.CordaPersistence import org.bouncycastle.asn1.x500.X500Name -import org.jetbrains.exposed.sql.Database import rx.Observable import java.io.InputStream import java.security.PublicKey import java.time.Instant -import java.util.* /** * Server side implementations of RPCs available to MQ based client tools. Execution takes place on the server @@ -40,7 +37,7 @@ import java.util.* class CordaRPCOpsImpl( private val services: ServiceHubInternal, private val smm: StateMachineManager, - private val database: Database + private val database: CordaPersistence ) : CordaRPCOps { override fun networkMapFeed(): DataFeed, NetworkMapCache.MapChange> { return database.transaction { @@ -48,13 +45,6 @@ class CordaRPCOpsImpl( } } - override fun vaultAndUpdates(): DataFeed>, Vault.Update> { - return database.transaction { - val (vault, updates) = services.vaultService.track() - DataFeed(vault.states.toList(), updates) - } - } - override fun vaultQueryBy(criteria: QueryCriteria, paging: PageSpecification, sorting: Sort, @@ -68,7 +58,7 @@ class CordaRPCOpsImpl( override fun vaultTrackBy(criteria: QueryCriteria, paging: PageSpecification, sorting: Sort, - contractType: Class): DataFeed, Vault.Update> { + contractType: Class): DataFeed, Vault.Update> { return database.transaction { services.vaultQueryService._trackBy(criteria, paging, sorting, contractType) } @@ -112,18 +102,12 @@ class CordaRPCOpsImpl( } } - override fun getCashBalances(): Map> { - return database.transaction { - services.vaultService.cashBalances - } - } - override fun startTrackedFlowDynamic(logicType: Class>, vararg args: Any?): FlowProgressHandle { val stateMachine = startFlow(logicType, args) return FlowProgressHandleImpl( id = stateMachine.id, returnValue = stateMachine.resultFuture, - progress = stateMachine.logic.track()?.second ?: Observable.empty() + progress = stateMachine.logic.track()?.updates ?: Observable.empty() ) } @@ -164,19 +148,9 @@ class CordaRPCOpsImpl( override fun authoriseContractUpgrade(state: StateAndRef<*>, upgradedContractClass: Class>) = services.vaultService.authoriseContractUpgrade(state, upgradedContractClass) override fun deauthoriseContractUpgrade(state: StateAndRef<*>) = services.vaultService.deauthoriseContractUpgrade(state) override fun currentNodeTime(): Instant = Instant.now(services.clock) - @Suppress("OverridingDeprecatedMember", "DEPRECATION") - override fun uploadFile(dataType: String, name: String?, file: InputStream): String { - val acceptor = services.uploaders.firstOrNull { it.accepts(dataType) } - return database.transaction { - acceptor?.upload(file) ?: throw RuntimeException("Cannot find file upload acceptor for $dataType") - } - } - override fun waitUntilRegisteredWithNetworkMap() = services.networkMapCache.mapServiceRegistered + override fun partyFromAnonymous(party: AbstractParty): Party? = services.identityService.partyFromAnonymous(party) override fun partyFromKey(key: PublicKey) = services.identityService.partyFromKey(key) - @Suppress("DEPRECATION") - @Deprecated("Use partyFromX500Name instead") - override fun partyFromName(name: String) = services.identityService.partyFromName(name) override fun partyFromX500Name(x500Name: X500Name) = services.identityService.partyFromX500Name(x500Name) override fun partiesFromName(query: String, exactMatch: Boolean): Set = services.identityService.partiesFromName(query, exactMatch) override fun nodeIdentityFromParty(party: AbstractParty): NodeInfo? = services.networkMapCache.getNodeByLegalIdentity(party) diff --git a/node/src/main/kotlin/net/corda/node/internal/EnterpriseNode.kt b/node/src/main/kotlin/net/corda/node/internal/EnterpriseNode.kt index cfe0682a77..506a923095 100644 --- a/node/src/main/kotlin/net/corda/node/internal/EnterpriseNode.kt +++ b/node/src/main/kotlin/net/corda/node/internal/EnterpriseNode.kt @@ -2,8 +2,8 @@ package net.corda.node.internal import com.jcraft.jsch.JSch import com.jcraft.jsch.JSchException +import net.corda.core.internal.Emoji import net.corda.core.node.services.ServiceInfo -import net.corda.core.utilities.Emoji import net.corda.core.utilities.loggerFor import net.corda.node.VersionInfo import net.corda.node.services.config.FullNodeConfiguration diff --git a/node/src/main/kotlin/net/corda/node/internal/InitiatedFlowFactory.kt b/node/src/main/kotlin/net/corda/node/internal/InitiatedFlowFactory.kt index 06a0a7cc61..aaa5053627 100644 --- a/node/src/main/kotlin/net/corda/node/internal/InitiatedFlowFactory.kt +++ b/node/src/main/kotlin/net/corda/node/internal/InitiatedFlowFactory.kt @@ -2,26 +2,14 @@ package net.corda.node.internal import net.corda.core.flows.FlowLogic import net.corda.core.identity.Party -import net.corda.node.services.statemachine.SessionInit -interface InitiatedFlowFactory> { - fun createFlow(platformVersion: Int, otherParty: Party, sessionInit: SessionInit): F +sealed class InitiatedFlowFactory> { + protected abstract val factory: (Party) -> F + fun createFlow(otherParty: Party): F = factory(otherParty) - data class Core>(val factory: (Party, Int) -> F) : InitiatedFlowFactory { - override fun createFlow(platformVersion: Int, otherParty: Party, sessionInit: SessionInit): F { - return factory(otherParty, platformVersion) - } - } - - data class CorDapp>(val version: Int, val factory: (Party) -> F) : InitiatedFlowFactory { - override fun createFlow(platformVersion: Int, otherParty: Party, sessionInit: SessionInit): F { - // TODO Add support for multiple versions of the same flow when CorDapps are loaded in separate class loaders - if (sessionInit.flowVerison == version) return factory(otherParty) - throw SessionRejectException( - "Version not supported", - "Version mismatch - ${sessionInit.initiatingFlowClass} is only registered for version $version") - } - } + data class Core>(override val factory: (Party) -> F) : InitiatedFlowFactory() + data class CorDapp>(val flowVersion: Int, + val appName: String, + override val factory: (Party) -> F) : InitiatedFlowFactory() } -class SessionRejectException(val rejectMessage: String, val logMessage: String) : Exception() diff --git a/node/src/main/kotlin/net/corda/node/internal/Node.kt b/node/src/main/kotlin/net/corda/node/internal/Node.kt index 0fb3e90970..a59560cf3b 100644 --- a/node/src/main/kotlin/net/corda/node/internal/Node.kt +++ b/node/src/main/kotlin/net/corda/node/internal/Node.kt @@ -1,19 +1,18 @@ package net.corda.node.internal import com.codahale.metrics.JmxReporter -import com.google.common.util.concurrent.Futures -import com.google.common.util.concurrent.ListenableFuture -import com.google.common.util.concurrent.SettableFuture -import net.corda.core.* +import net.corda.core.concurrent.CordaFuture +import net.corda.core.internal.concurrent.doneFuture +import net.corda.core.internal.concurrent.flatMap +import net.corda.core.internal.concurrent.openFuture +import net.corda.core.internal.concurrent.thenMatch import net.corda.core.messaging.RPCOps import net.corda.core.node.ServiceHub import net.corda.core.node.services.ServiceInfo -import net.corda.core.seconds -import net.corda.core.utilities.NetworkHostAndPort -import net.corda.core.utilities.loggerFor -import net.corda.core.utilities.parseNetworkHostAndPort -import net.corda.core.utilities.trace +import net.corda.core.serialization.SerializationDefaults +import net.corda.core.utilities.* import net.corda.node.VersionInfo +import net.corda.node.serialization.KryoServerSerializationScheme import net.corda.node.serialization.NodeClock import net.corda.node.services.RPCUserService import net.corda.node.services.RPCUserServiceImpl @@ -33,6 +32,7 @@ import net.corda.nodeapi.ArtemisTcpTransport import net.corda.nodeapi.ConnectionDirection import net.corda.nodeapi.internal.ShutdownHook import net.corda.nodeapi.internal.addShutdownHook +import net.corda.nodeapi.internal.serialization.* import org.apache.activemq.artemis.api.core.ActiveMQNotConnectedException import org.apache.activemq.artemis.api.core.RoutingType import org.apache.activemq.artemis.api.core.client.ActiveMQClient @@ -58,7 +58,8 @@ import kotlin.system.exitProcess open class Node(override val configuration: FullNodeConfiguration, advertisedServices: Set, val versionInfo: VersionInfo, - clock: Clock = NodeClock()) : AbstractNode(configuration, advertisedServices, clock) { + clock: Clock = NodeClock(), + val initialiseSerialization: Boolean = true) : AbstractNode(configuration, advertisedServices, clock) { companion object { private val logger = loggerFor() var renderBasicInfoToConsole = true @@ -255,8 +256,8 @@ open class Node(override val configuration: FullNodeConfiguration, * Insert an initial step in the registration process which will throw an exception if a non-recoverable error is * encountered when trying to connect to the network map node. */ - override fun registerWithNetworkMap(): ListenableFuture { - val networkMapConnection = messageBroker?.networkMapConnectionFuture ?: Futures.immediateFuture(Unit) + override fun registerWithNetworkMap(): CordaFuture { + val networkMapConnection = messageBroker?.networkMapConnectionFuture ?: doneFuture(Unit) return networkMapConnection.flatMap { super.registerWithNetworkMap() } } @@ -295,9 +296,13 @@ open class Node(override val configuration: FullNodeConfiguration, super.initialiseDatabasePersistence(insideTransaction) } - val startupComplete: ListenableFuture = SettableFuture.create() + private val _startupComplete = openFuture() + val startupComplete: CordaFuture get() = _startupComplete - override fun start(): Node { + override fun start() { + if (initialiseSerialization) { + initialiseSerialization() + } super.start() networkMapRegistrationFuture.thenMatch({ @@ -320,13 +325,23 @@ open class Node(override val configuration: FullNodeConfiguration, build(). start() - (startupComplete as SettableFuture).set(Unit) + _startupComplete.set(Unit) } }, {}) shutdownHook = addShutdownHook { stop() } - return this + } + + private fun initialiseSerialization() { + SerializationDefaults.SERIALIZATION_FACTORY = SerializationFactoryImpl().apply { + registerScheme(KryoServerSerializationScheme(this)) + registerScheme(AMQPServerSerializationScheme()) + } + SerializationDefaults.P2P_CONTEXT = KRYO_P2P_CONTEXT + SerializationDefaults.RPC_SERVER_CONTEXT = KRYO_RPC_SERVER_CONTEXT + SerializationDefaults.STORAGE_CONTEXT = KRYO_STORAGE_CONTEXT + SerializationDefaults.CHECKPOINT_CONTEXT = KRYO_CHECKPOINT_CONTEXT } /** Starts a blocking event loop for message dispatch. */ @@ -334,12 +349,6 @@ open class Node(override val configuration: FullNodeConfiguration, (network as NodeMessagingClient).run(messageBroker!!.serverControl) } - // TODO: Do we really need setup? - override fun setup(): Node { - super.setup() - return this - } - private var shutdown = false override fun stop() { diff --git a/node/src/main/kotlin/net/corda/node/internal/NodeStartup.kt b/node/src/main/kotlin/net/corda/node/internal/NodeStartup.kt index 5fe737ea35..eab65b579a 100644 --- a/node/src/main/kotlin/net/corda/node/internal/NodeStartup.kt +++ b/node/src/main/kotlin/net/corda/node/internal/NodeStartup.kt @@ -5,12 +5,13 @@ import com.jcraft.jsch.JSch import com.jcraft.jsch.JSchException import com.typesafe.config.ConfigException import joptsimple.OptionException -import net.corda.core.* import net.corda.core.crypto.commonName import net.corda.core.crypto.orgName -import net.corda.node.VersionInfo +import net.corda.core.internal.concurrent.thenMatch +import net.corda.core.internal.createDirectories +import net.corda.core.internal.div +import net.corda.core.internal.* import net.corda.core.node.services.ServiceInfo -import net.corda.core.utilities.Emoji import net.corda.core.utilities.loggerFor import net.corda.node.* import net.corda.node.serialization.NodeClock @@ -32,6 +33,7 @@ import java.lang.management.ManagementFactory import java.net.InetAddress import java.nio.file.Path import java.nio.file.Paths +import java.time.LocalDate import java.util.* import kotlin.system.exitProcess @@ -168,7 +170,7 @@ open class NodeStartup(val args: Array) { } open protected fun banJavaSerialisation(conf: FullNodeConfiguration) { - SerialFilter.install(if (conf.bftReplicaId != null) ::bftSMaRtSerialFilter else ::defaultSerialFilter) + SerialFilter.install(if (conf.bftSMaRt.isValid()) ::bftSMaRtSerialFilter else ::defaultSerialFilter) } open protected fun getVersionInfo(): VersionInfo { @@ -307,6 +309,12 @@ open class NodeStartup(val args: Array) { "Top tip: never say \"oops\", instead\nalways say \"Ah, Interesting!\"", "Computers are useless. They can only\ngive you answers. -- Picasso" ) + + // TODO: Delete this after CordaCon. + val cordaCon2017date = LocalDate.of(2017, 9, 12) + val cordaConBanner = if (LocalDate.now() < cordaCon2017date) + "${Emoji.soon} Register for our Free CordaCon event : see https://goo.gl/Z15S8W" else "" + if (Emoji.hasEmojiTerminal) messages += "Kind of like a regular database but\nwith emojis, colours and ascii art. ${Emoji.coolGuy}" val (msg1, msg2) = messages.randomOrNull()!!.split('\n') @@ -320,9 +328,9 @@ open class NodeStartup(val args: Array) { a("--- ${versionInfo.vendor} ${versionInfo.releaseVersion} (${versionInfo.revision.take(7)}) -----------------------------------------------"). newline(). newline(). - a("${Emoji.books}New! ").reset().a("Training now available worldwide, see https://corda.net/corda-training/"). + a(cordaConBanner). newline(). reset()) } } -} \ No newline at end of file +} diff --git a/node/src/main/kotlin/net/corda/node/serialization/SerializationScheme.kt b/node/src/main/kotlin/net/corda/node/serialization/SerializationScheme.kt new file mode 100644 index 0000000000..14b5ef144e --- /dev/null +++ b/node/src/main/kotlin/net/corda/node/serialization/SerializationScheme.kt @@ -0,0 +1,27 @@ +package net.corda.node.serialization + +import com.esotericsoftware.kryo.pool.KryoPool +import net.corda.core.serialization.SerializationContext +import net.corda.core.serialization.SerializationFactory +import net.corda.core.utilities.ByteSequence +import net.corda.node.services.messaging.RpcServerObservableSerializer +import net.corda.nodeapi.RPCKryo +import net.corda.nodeapi.internal.serialization.AbstractKryoSerializationScheme +import net.corda.nodeapi.internal.serialization.DefaultKryoCustomizer +import net.corda.nodeapi.internal.serialization.KryoHeaderV0_1 + +class KryoServerSerializationScheme(serializationFactory: SerializationFactory) : AbstractKryoSerializationScheme(serializationFactory) { + override fun canDeserializeVersion(byteSequence: ByteSequence, target: SerializationContext.UseCase): Boolean { + return byteSequence == KryoHeaderV0_1 && target != SerializationContext.UseCase.RPCClient + } + + override fun rpcClientKryoPool(context: SerializationContext): KryoPool { + throw UnsupportedOperationException() + } + + override fun rpcServerKryoPool(context: SerializationContext): KryoPool { + return KryoPool.Builder { + DefaultKryoCustomizer.customize(RPCKryo(RpcServerObservableSerializer, serializationFactory, context)).apply { classLoader = context.deserializationClassLoader } + }.build() + } +} \ No newline at end of file diff --git a/node/src/main/kotlin/net/corda/node/services/CoreFlowHandlers.kt b/node/src/main/kotlin/net/corda/node/services/CoreFlowHandlers.kt index 3359d55e58..bbded81fa8 100644 --- a/node/src/main/kotlin/net/corda/node/services/CoreFlowHandlers.kt +++ b/node/src/main/kotlin/net/corda/node/services/CoreFlowHandlers.kt @@ -2,59 +2,15 @@ package net.corda.node.services import co.paralleluniverse.fibers.Suspendable import net.corda.core.contracts.ContractState -import net.corda.core.contracts.TransactionType +import net.corda.core.contracts.UpgradeCommand import net.corda.core.contracts.UpgradedContract import net.corda.core.contracts.requireThat -import net.corda.core.crypto.SecureHash -import net.corda.core.flows.FlowException -import net.corda.core.flows.FlowLogic +import net.corda.core.flows.* +import net.corda.core.identity.PartyAndCertificate import net.corda.core.identity.Party import net.corda.core.transactions.SignedTransaction import net.corda.core.utilities.ProgressTracker import net.corda.core.utilities.unwrap -import net.corda.flows.* - -/** - * This class sets up network message handlers for requests from peers for data keyed by hash. It is a piece of simple - * glue that sits between the network layer and the database layer. - * - * Note that in our data model, to be able to name a thing by hash automatically gives the power to request it. There - * are no access control lists. If you want to keep some data private, then you must be careful who you give its name - * to, and trust that they will not pass the name onwards. If someone suspects some data might exist but does not have - * its name, then the 256-bit search space they'd have to cover makes it physically impossible to enumerate, and as - * such the hash of a piece of data can be seen as a type of password allowing access to it. - * - * Additionally, because nodes do not store invalid transactions, requesting such a transaction will always yield null. - */ -class FetchTransactionsHandler(otherParty: Party) : FetchDataHandler(otherParty) { - override fun getData(id: SecureHash): SignedTransaction? { - return serviceHub.validatedTransactions.getTransaction(id) - } -} - -// TODO: Use Artemis message streaming support here, called "large messages". This avoids the need to buffer. -class FetchAttachmentsHandler(otherParty: Party) : FetchDataHandler(otherParty) { - override fun getData(id: SecureHash): ByteArray? { - return serviceHub.attachments.openAttachment(id)?.open()?.readBytes() - } -} - -abstract class FetchDataHandler(val otherParty: Party) : FlowLogic() { - @Suspendable - @Throws(FetchDataFlow.HashNotFound::class) - override fun call() { - val request = receive(otherParty).unwrap { - if (it.hashes.isEmpty()) throw FlowException("Empty hash list") - it - } - val response = request.hashes.map { - getData(it) ?: throw FetchDataFlow.HashNotFound(it) - } - send(otherParty, response) - } - - protected abstract fun getData(id: SecureHash): T? -} // TODO: We should have a whitelist of contracts we're willing to accept at all, and reject if the transaction // includes us in any outside that list. Potentially just if it includes any outside that list at all. @@ -63,9 +19,8 @@ abstract class FetchDataHandler(val otherParty: Party) : FlowLogic( class NotifyTransactionHandler(val otherParty: Party) : FlowLogic() { @Suspendable override fun call() { - val request = receive(otherParty).unwrap { it } - subFlow(ResolveTransactionsFlow(request.tx, otherParty)) - serviceHub.recordTransactions(request.tx) + val stx = subFlow(ReceiveTransactionFlow(otherParty)) + serviceHub.recordTransactions(stx) } } @@ -77,50 +32,45 @@ class NotaryChangeHandler(otherSide: Party) : AbstractStateReplacementFlow.Accep * and is also in a geographically convenient location we can just automatically approve the change. * TODO: In more difficult cases this should call for human attention to manually verify and approve the proposal */ - override fun verifyProposal(proposal: AbstractStateReplacementFlow.Proposal): Unit { + override fun verifyProposal(stx: SignedTransaction, proposal: AbstractStateReplacementFlow.Proposal): Unit { val state = proposal.stateRef - val proposedTx = proposal.stx.tx + val proposedTx = stx.resolveNotaryChangeTransaction(serviceHub) + val newNotary = proposal.modification - if (proposedTx.type !is TransactionType.NotaryChange) { - throw StateReplacementException("The proposed transaction is not a notary change transaction.") + if (state !in proposedTx.inputs.map { it.ref }) { + throw StateReplacementException("The proposed state $state is not in the proposed transaction inputs") } - val newNotary = proposal.modification + // TODO: load and compare against notary whitelist from config. Remove the check below val isNotary = serviceHub.networkMapCache.notaryNodes.any { it.notaryIdentity == newNotary } if (!isNotary) { throw StateReplacementException("The proposed node $newNotary does not run a Notary service") } - if (state !in proposedTx.inputs) { - throw StateReplacementException("The proposed state $state is not in the proposed transaction inputs") - } - -// // An example requirement -// val blacklist = listOf("Evil Notary") -// checkProposal(newNotary.name !in blacklist) { -// "The proposed new notary $newNotary is not trusted by the party" -// } } } class ContractUpgradeHandler(otherSide: Party) : AbstractStateReplacementFlow.Acceptor>>(otherSide) { @Suspendable @Throws(StateReplacementException::class) - override fun verifyProposal(proposal: AbstractStateReplacementFlow.Proposal>>) { + override fun verifyProposal(stx: SignedTransaction, proposal: AbstractStateReplacementFlow.Proposal>>) { // Retrieve signed transaction from our side, we will apply the upgrade logic to the transaction on our side, and // verify outputs matches the proposed upgrade. - val stx = subFlow(FetchTransactionsFlow(setOf(proposal.stateRef.txhash), otherSide)).fromDisk.singleOrNull() - requireNotNull(stx) { "We don't have a copy of the referenced state" } - val oldStateAndRef = stx!!.tx.outRef(proposal.stateRef.index) + val ourSTX = serviceHub.validatedTransactions.getTransaction(proposal.stateRef.txhash) + requireNotNull(ourSTX) { "We don't have a copy of the referenced state" } + val oldStateAndRef = ourSTX!!.tx.outRef(proposal.stateRef.index) val authorisedUpgrade = serviceHub.vaultService.getAuthorisedContractUpgrade(oldStateAndRef.ref) ?: throw IllegalStateException("Contract state upgrade is unauthorised. State hash : ${oldStateAndRef.ref}") - val proposedTx = proposal.stx.tx - val expectedTx = ContractUpgradeFlow.assembleBareTx(oldStateAndRef, proposal.modification).toWireTransaction() + val proposedTx = stx.tx + val expectedTx = ContractUpgradeFlow.assembleBareTx(oldStateAndRef, proposal.modification, proposedTx.privacySalt).toWireTransaction() requireThat { "The instigator is one of the participants" using (otherSide in oldStateAndRef.state.data.participants) "The proposed upgrade ${proposal.modification.javaClass} is a trusted upgrade path" using (proposal.modification == authorisedUpgrade) "The proposed tx matches the expected tx for this upgrade" using (proposedTx == expectedTx) } - ContractUpgradeFlow.verify(oldStateAndRef.state.data, expectedTx.outRef(0).state.data, expectedTx.commands.single()) + ContractUpgradeFlow.verify( + oldStateAndRef.state.data, + expectedTx.outRef(0).state.data, + expectedTx.toLedgerTransaction(serviceHub).commandsOfType().single()) } } @@ -137,10 +87,8 @@ class TransactionKeyHandler(val otherSide: Party, val revocationEnabled: Boolean val revocationEnabled = false progressTracker.currentStep = SENDING_KEY val legalIdentityAnonymous = serviceHub.keyManagementService.freshKeyAndCert(serviceHub.myInfo.legalIdentityAndCert, revocationEnabled) - val otherSideAnonymous = sendAndReceive(otherSide, legalIdentityAnonymous).unwrap { TransactionKeyFlow.validateIdentity(otherSide, it) } - val (certPath, theirCert, txIdentity) = otherSideAnonymous - // Validate then store their identity so that we can prove the key in the transaction is owned by the - // counterparty. - serviceHub.identityService.registerAnonymousIdentity(txIdentity, otherSide, certPath) + sendAndReceive(otherSide, legalIdentityAnonymous).unwrap { confidentialIdentity -> + TransactionKeyFlow.validateAndRegisterIdentity(serviceHub.identityService, otherSide, confidentialIdentity) + } } } \ No newline at end of file diff --git a/node/src/main/kotlin/net/corda/node/services/api/AbstractNodeService.kt b/node/src/main/kotlin/net/corda/node/services/api/AbstractNodeService.kt index 5c974fcfb1..73acab6e44 100644 --- a/node/src/main/kotlin/net/corda/node/services/api/AbstractNodeService.kt +++ b/node/src/main/kotlin/net/corda/node/services/api/AbstractNodeService.kt @@ -1,6 +1,5 @@ package net.corda.node.services.api -import net.corda.core.node.services.DEFAULT_SESSION_ID import net.corda.core.serialization.SingletonSerializeAsToken import net.corda.core.serialization.deserialize import net.corda.core.serialization.serialize diff --git a/node/src/main/kotlin/net/corda/node/services/api/AcceptsFileUpload.kt b/node/src/main/kotlin/net/corda/node/services/api/AcceptsFileUpload.kt deleted file mode 100644 index a3967ec2ba..0000000000 --- a/node/src/main/kotlin/net/corda/node/services/api/AcceptsFileUpload.kt +++ /dev/null @@ -1,17 +0,0 @@ -package net.corda.node.services.api - -import net.corda.core.node.services.FileUploader - -/** - * A service that implements AcceptsFileUpload can have new binary data provided to it via an HTTP upload. - */ -// TODO This is no longer used and can be removed -interface AcceptsFileUpload : FileUploader { - /** A string that prefixes the URLs, e.g. "attachments" or "interest-rates". Should be OK for URLs. */ - val dataTypePrefix: String - - /** What file extensions are acceptable for the file to be handed to upload() */ - val acceptableFileExtensions: List - - override fun accepts(type: String) = type == dataTypePrefix -} diff --git a/node/src/main/kotlin/net/corda/node/services/api/ServiceHubInternal.kt b/node/src/main/kotlin/net/corda/node/services/api/ServiceHubInternal.kt index 861ff1db34..2005010b83 100644 --- a/node/src/main/kotlin/net/corda/node/services/api/ServiceHubInternal.kt +++ b/node/src/main/kotlin/net/corda/node/services/api/ServiceHubInternal.kt @@ -1,7 +1,7 @@ package net.corda.node.services.api import com.google.common.annotations.VisibleForTesting -import com.google.common.util.concurrent.ListenableFuture +import net.corda.core.concurrent.CordaFuture import net.corda.core.crypto.SecureHash import net.corda.core.flows.FlowInitiator import net.corda.core.flows.FlowLogic @@ -12,7 +12,6 @@ import net.corda.core.messaging.SingleMessageRecipient import net.corda.core.messaging.StateMachineTransactionMapping import net.corda.core.node.NodeInfo import net.corda.core.node.PluginServiceHub -import net.corda.core.node.services.FileUploader import net.corda.core.node.services.NetworkMapCache import net.corda.core.node.services.TransactionStorage import net.corda.core.serialization.CordaSerializable @@ -23,7 +22,13 @@ import net.corda.node.services.config.NodeConfiguration import net.corda.node.services.messaging.MessagingService import net.corda.node.services.statemachine.FlowLogicRefFactoryImpl import net.corda.node.services.statemachine.FlowStateMachineImpl -import org.jetbrains.exposed.sql.Database +import net.corda.node.utilities.CordaPersistence + +/** + * Session ID to use for services listening for the first message in a session (before a + * specific session ID has been established). + */ +val DEFAULT_SESSION_ID = 0L interface NetworkMapCacheInternal : NetworkMapCache { /** @@ -31,7 +36,7 @@ interface NetworkMapCacheInternal : NetworkMapCache { * @param network the network messaging service. * @param service the network map service to fetch current state from. */ - fun deregisterForUpdates(network: MessagingService, service: NodeInfo): ListenableFuture + fun deregisterForUpdates(network: MessagingService, service: NodeInfo): CordaFuture /** * Add a network map service; fetches a copy of the latest map from the service and subscribes to any further @@ -43,7 +48,7 @@ interface NetworkMapCacheInternal : NetworkMapCache { * version is less than or equal to the given version, no update is fetched. */ fun addMapService(network: MessagingService, networkMapAddress: SingleMessageRecipient, - subscribe: Boolean, ifChangedSinceVer: Int? = null): ListenableFuture + subscribe: Boolean, ifChangedSinceVer: Int? = null): CordaFuture /** Adds a node to the local cache (generally only used for adding ourselves). */ fun addNode(node: NodeInfo) @@ -81,16 +86,12 @@ interface ServiceHubInternal : PluginServiceHub { val auditService: AuditService val rpcFlows: List>> val networkService: MessagingService - val database: Database + val database: CordaPersistence val configuration: NodeConfiguration - @Suppress("DEPRECATION") - @Deprecated("This service will be removed in a future milestone") - val uploaders: List - override fun recordTransactions(txs: Iterable) { + require (txs.any()) { "No transactions passed in for recording" } val recordedTransactions = txs.filter { validatedTransactions.addTransaction(it) } - require(recordedTransactions.isNotEmpty()) { "No transactions passed in for recording" } val stateMachineRunId = FlowStateMachineImpl.currentStateMachine()?.id if (stateMachineRunId != null) { recordedTransactions.forEach { @@ -99,7 +100,9 @@ interface ServiceHubInternal : PluginServiceHub { } else { log.warn("Transactions recorded from outside of a state machine") } - vaultService.notifyAll(recordedTransactions.map { it.tx }) + + val toNotify = recordedTransactions.map { if (it.isNotaryChangeTransaction()) it.notaryChangeTx else it.tx } + vaultService.notifyAll(toNotify) } /** diff --git a/node/src/main/kotlin/net/corda/node/services/config/ConfigUtilities.kt b/node/src/main/kotlin/net/corda/node/services/config/ConfigUtilities.kt index 8fdae16bdb..f66e53a7be 100644 --- a/node/src/main/kotlin/net/corda/node/services/config/ConfigUtilities.kt +++ b/node/src/main/kotlin/net/corda/node/services/config/ConfigUtilities.kt @@ -1,25 +1,27 @@ -// TODO: Remove when configureTestSSL() is moved. -@file:JvmName("ConfigUtilities") - package net.corda.node.services.config import com.typesafe.config.Config import com.typesafe.config.ConfigFactory import com.typesafe.config.ConfigParseOptions import com.typesafe.config.ConfigRenderOptions -import net.corda.core.copyTo -import net.corda.core.createDirectories -import net.corda.core.crypto.KeyStoreUtilities -import net.corda.core.crypto.X509Utilities -import net.corda.core.div -import net.corda.core.exists +import net.corda.core.crypto.Crypto +import net.corda.core.crypto.SignatureScheme +import net.corda.core.internal.copyTo +import net.corda.core.internal.createDirectories +import net.corda.core.internal.div +import net.corda.core.internal.exists import net.corda.core.utilities.loggerFor +import net.corda.node.utilities.* import net.corda.nodeapi.config.SSLConfiguration import org.bouncycastle.asn1.x500.X500Name +import org.bouncycastle.asn1.x509.GeneralName +import org.bouncycastle.asn1.x509.GeneralSubtree +import org.bouncycastle.asn1.x509.NameConstraints import java.nio.file.Path +import java.security.KeyStore -fun configOf(vararg pairs: Pair) = ConfigFactory.parseMap(mapOf(*pairs)) -operator fun Config.plus(overrides: Map) = ConfigFactory.parseMap(overrides).withFallback(this) +fun configOf(vararg pairs: Pair): Config = ConfigFactory.parseMap(mapOf(*pairs)) +operator fun Config.plus(overrides: Map): Config = ConfigFactory.parseMap(overrides).withFallback(this) object ConfigHelper { private val log = loggerFor() @@ -55,7 +57,56 @@ fun SSLConfiguration.configureDevKeyAndTrustStores(myLegalName: X500Name) { javaClass.classLoader.getResourceAsStream("net/corda/node/internal/certificates/cordatruststore.jks").copyTo(trustStoreFile) } if (!sslKeystore.exists() || !nodeKeystore.exists()) { - val caKeyStore = KeyStoreUtilities.loadKeyStore(javaClass.classLoader.getResourceAsStream("net/corda/node/internal/certificates/cordadevcakeys.jks"), "cordacadevpass") - X509Utilities.createKeystoreForCordaNode(sslKeystore, nodeKeystore, keyStorePassword, keyStorePassword, caKeyStore, "cordacadevkeypass", myLegalName) + val caKeyStore = loadKeyStore(javaClass.classLoader.getResourceAsStream("net/corda/node/internal/certificates/cordadevcakeys.jks"), "cordacadevpass") + createKeystoreForCordaNode(sslKeystore, nodeKeystore, keyStorePassword, keyStorePassword, caKeyStore, "cordacadevkeypass", myLegalName) } } + +/** + * An all in wrapper to manufacture a server certificate and keys all stored in a KeyStore suitable for running TLS on the local machine. + * @param sslKeyStorePath KeyStore path to save ssl key and cert to. + * @param clientCAKeystorePath KeyStore path to save client CA key and cert to. + * @param storePassword access password for KeyStore. + * @param keyPassword PrivateKey access password for the generated keys. + * It is recommended that this is the same as the storePassword as most TLS libraries assume they are the same. + * @param caKeyStore KeyStore containing CA keys generated by createCAKeyStoreAndTrustStore. + * @param caKeyPassword password to unlock private keys in the CA KeyStore. + * @return The KeyStore created containing a private key, certificate chain and root CA public cert for use in TLS applications. + */ +fun createKeystoreForCordaNode(sslKeyStorePath: Path, + clientCAKeystorePath: Path, + storePassword: String, + keyPassword: String, + caKeyStore: KeyStore, + caKeyPassword: String, + legalName: X500Name, + signatureScheme: SignatureScheme = X509Utilities.DEFAULT_TLS_SIGNATURE_SCHEME) { + + val rootCACert = caKeyStore.getX509Certificate(X509Utilities.CORDA_ROOT_CA) + val (intermediateCACert, intermediateCAKeyPair) = caKeyStore.getCertificateAndKeyPair(X509Utilities.CORDA_INTERMEDIATE_CA, caKeyPassword) + + val clientKey = Crypto.generateKeyPair(signatureScheme) + val nameConstraints = NameConstraints(arrayOf(GeneralSubtree(GeneralName(GeneralName.directoryName, legalName))), arrayOf()) + val clientCACert = X509Utilities.createCertificate(CertificateType.INTERMEDIATE_CA, intermediateCACert, intermediateCAKeyPair, legalName, clientKey.public, nameConstraints = nameConstraints) + + val tlsKey = Crypto.generateKeyPair(signatureScheme) + val clientTLSCert = X509Utilities.createCertificate(CertificateType.TLS, clientCACert, clientKey, legalName, tlsKey.public) + + val keyPass = keyPassword.toCharArray() + + val clientCAKeystore = loadOrCreateKeyStore(clientCAKeystorePath, storePassword) + clientCAKeystore.addOrReplaceKey( + X509Utilities.CORDA_CLIENT_CA, + clientKey.private, + keyPass, + arrayOf(clientCACert, intermediateCACert, rootCACert)) + clientCAKeystore.save(clientCAKeystorePath, storePassword) + + val tlsKeystore = loadOrCreateKeyStore(sslKeyStorePath, storePassword) + tlsKeystore.addOrReplaceKey( + X509Utilities.CORDA_CLIENT_TLS, + tlsKey.private, + keyPass, + arrayOf(clientTLSCert, clientCACert, intermediateCACert, rootCACert)) + tlsKeystore.save(sslKeyStorePath, storePassword) +} \ No newline at end of file diff --git a/node/src/main/kotlin/net/corda/node/services/config/NodeConfiguration.kt b/node/src/main/kotlin/net/corda/node/services/config/NodeConfiguration.kt index b9afbc1553..8c51097c8d 100644 --- a/node/src/main/kotlin/net/corda/node/services/config/NodeConfiguration.kt +++ b/node/src/main/kotlin/net/corda/node/services/config/NodeConfiguration.kt @@ -13,6 +13,11 @@ import java.net.URL import java.nio.file.Path import java.util.* +/** @param exposeRaces for testing only, so its default is not in reference.conf but here. */ +data class BFTSMaRtConfiguration(val replicaId: Int, val debug: Boolean, val exposeRaces: Boolean = false) { + fun isValid() = replicaId >= 0 +} + interface NodeConfiguration : NodeSSLConfiguration { val myLegalName: X500Name val networkMapService: NetworkMapInfo? @@ -20,13 +25,14 @@ interface NodeConfiguration : NodeSSLConfiguration { val emailAddress: String val exportJMXto: String val dataSourceProperties: Properties + val database: Properties? val rpcUsers: List val devMode: Boolean val certificateSigningService: URL val certificateChainCheckPolicies: List val verifierType: VerifierType val messageRedeliveryDelaySeconds: Int - val bftReplicaId: Int? + val bftSMaRt: BFTSMaRtConfiguration val notaryNodeAddress: NetworkHostAndPort? val notaryClusterAddresses: List } @@ -42,6 +48,7 @@ data class FullNodeConfiguration( override val keyStorePassword: String, override val trustStorePassword: String, override val dataSourceProperties: Properties, + override val database: Properties?, override val certificateSigningService: URL, override val networkMapService: NetworkMapInfo?, override val minimumPlatformVersion: Int = 1, @@ -57,7 +64,7 @@ data class FullNodeConfiguration( // Instead this should be a Boolean indicating whether that broker is an internal one started by the node or an external one val messagingServerAddress: NetworkHostAndPort?, val extraAdvertisedServiceIds: List, - override val bftReplicaId: Int?, + override val bftSMaRt: BFTSMaRtConfiguration, override val notaryNodeAddress: NetworkHostAndPort?, override val notaryClusterAddresses: List, override val certificateChainCheckPolicies: List, diff --git a/node/src/main/kotlin/net/corda/node/services/database/HibernateConfiguration.kt b/node/src/main/kotlin/net/corda/node/services/database/HibernateConfiguration.kt index 4d2ba9f95c..d0f300b9ee 100644 --- a/node/src/main/kotlin/net/corda/node/services/database/HibernateConfiguration.kt +++ b/node/src/main/kotlin/net/corda/node/services/database/HibernateConfiguration.kt @@ -1,9 +1,12 @@ package net.corda.node.services.database +import net.corda.core.internal.castIfPossible +import net.corda.core.node.services.IdentityService import net.corda.core.schemas.MappedSchema -import net.corda.core.utilities.debug +import net.corda.core.schemas.converters.AbstractPartyToX500NameAsStringConverter import net.corda.core.utilities.loggerFor import net.corda.node.services.api.SchemaService +import net.corda.node.utilities.DatabaseTransactionManager import org.hibernate.SessionFactory import org.hibernate.boot.MetadataSources import org.hibernate.boot.model.naming.Identifier @@ -13,13 +16,11 @@ import org.hibernate.cfg.Configuration import org.hibernate.engine.jdbc.connections.spi.ConnectionProvider import org.hibernate.engine.jdbc.env.spi.JdbcEnvironment import org.hibernate.service.UnknownUnwrapTypeException -import org.jetbrains.exposed.sql.transactions.TransactionManager import java.sql.Connection +import java.util.* import java.util.concurrent.ConcurrentHashMap -class HibernateConfiguration(val schemaService: SchemaService, val useDefaultLogging: Boolean = false) { - constructor(schemaService: SchemaService) : this(schemaService, false) - +class HibernateConfiguration(val schemaService: SchemaService, val databaseProperties: Properties, private val identitySvc: () -> IdentityService) { companion object { val logger = loggerFor() } @@ -57,9 +58,9 @@ class HibernateConfiguration(val schemaService: SchemaService, val useDefaultLog // necessarily remain and would likely be replaced by something like Liquibase. For now it is very convenient though. // TODO: replace auto schema generation as it isn't intended for production use, according to Hibernate docs. val config = Configuration(metadataSources).setProperty("hibernate.connection.provider_class", HibernateConfiguration.NodeDatabaseConnectionProvider::class.java.name) - .setProperty("hibernate.hbm2ddl.auto", "update") - .setProperty("hibernate.show_sql", "$useDefaultLogging") - .setProperty("hibernate.format_sql", "$useDefaultLogging") + .setProperty("hibernate.hbm2ddl.auto", if (databaseProperties.getProperty("initDatabase","true") == "true") "update" else "validate") + .setProperty("hibernate.format_sql", "true") + schemas.forEach { schema -> // TODO: require mechanism to set schemaOptions (databaseSchema, tablePrefix) which are not global to session schema.mappedTypes.forEach { config.addAnnotatedClass(it) } @@ -78,6 +79,9 @@ class HibernateConfiguration(val schemaService: SchemaService, val useDefaultLog return Identifier.toIdentifier(tablePrefix + default.text, default.isQuoted) } }) + // register custom converters + applyAttributeConverter(AbstractPartyToX500NameAsStringConverter(identitySvc)) + build() } @@ -94,7 +98,7 @@ class HibernateConfiguration(val schemaService: SchemaService, val useDefaultLog // during schema creation / update. class NodeDatabaseConnectionProvider : ConnectionProvider { override fun closeConnection(conn: Connection) { - val tx = TransactionManager.current() + val tx = DatabaseTransactionManager.current() tx.commit() tx.close() } @@ -102,18 +106,13 @@ class HibernateConfiguration(val schemaService: SchemaService, val useDefaultLog override fun supportsAggressiveRelease(): Boolean = true override fun getConnection(): Connection { - val tx = TransactionManager.manager.newTransaction(Connection.TRANSACTION_REPEATABLE_READ) - return tx.connection + return DatabaseTransactionManager.newTransaction().connection } override fun unwrap(unwrapType: Class): T { - try { - return unwrapType.cast(this) - } catch(e: ClassCastException) { - throw UnknownUnwrapTypeException(unwrapType) - } + return unwrapType.castIfPossible(this) ?: throw UnknownUnwrapTypeException(unwrapType) } - override fun isUnwrappableAs(unwrapType: Class<*>?): Boolean = (unwrapType == NodeDatabaseConnectionProvider::class.java) + override fun isUnwrappableAs(unwrapType: Class<*>?): Boolean = unwrapType == NodeDatabaseConnectionProvider::class.java } } \ No newline at end of file diff --git a/node/src/main/kotlin/net/corda/node/services/database/KotlinConfigurationTransactionWrapper.kt b/node/src/main/kotlin/net/corda/node/services/database/KotlinConfigurationTransactionWrapper.kt index e9652e9556..78d4380161 100644 --- a/node/src/main/kotlin/net/corda/node/services/database/KotlinConfigurationTransactionWrapper.kt +++ b/node/src/main/kotlin/net/corda/node/services/database/KotlinConfigurationTransactionWrapper.kt @@ -13,7 +13,7 @@ import net.corda.core.schemas.requery.converters.InstantConverter import net.corda.core.schemas.requery.converters.SecureHashConverter import net.corda.core.schemas.requery.converters.StateRefConverter import net.corda.core.schemas.requery.converters.VaultStateStatusConverter -import org.jetbrains.exposed.sql.transactions.TransactionManager +import net.corda.node.utilities.DatabaseTransactionManager import java.sql.Connection import java.util.* import java.util.concurrent.Executor @@ -128,12 +128,7 @@ class KotlinConfigurationTransactionWrapper(private val model: EntityModel, } class CordaDataSourceConnectionProvider(val dataSource: DataSource) : ConnectionProvider { - override fun getConnection(): Connection { - val tx = TransactionManager.manager.currentOrNull() - return CordaConnection( - tx?.connection ?: throw IllegalStateException("Was expecting to find database transaction: must wrap calling code within a transaction.") - ) - } + override fun getConnection(): Connection = CordaConnection(DatabaseTransactionManager.current().connection) } class CordaConnection(val connection: Connection) : Connection by connection { diff --git a/node/src/main/kotlin/net/corda/node/services/database/RequeryConfiguration.kt b/node/src/main/kotlin/net/corda/node/services/database/RequeryConfiguration.kt index 6bbc6ae0bb..8e2cc192a0 100644 --- a/node/src/main/kotlin/net/corda/node/services/database/RequeryConfiguration.kt +++ b/node/src/main/kotlin/net/corda/node/services/database/RequeryConfiguration.kt @@ -3,17 +3,18 @@ package net.corda.node.services.database import com.zaxxer.hikari.HikariConfig import com.zaxxer.hikari.HikariDataSource import io.requery.Persistable +import io.requery.TransactionIsolation import io.requery.meta.EntityModel import io.requery.sql.KotlinEntityDataStore import io.requery.sql.SchemaModifier import io.requery.sql.TableCreationMode import net.corda.core.utilities.loggerFor -import org.jetbrains.exposed.sql.transactions.TransactionManager +import net.corda.node.utilities.DatabaseTransactionManager import java.sql.Connection import java.util.* import java.util.concurrent.ConcurrentHashMap -class RequeryConfiguration(val properties: Properties, val useDefaultLogging: Boolean = false) { +class RequeryConfiguration(val properties: Properties, val useDefaultLogging: Boolean = false, val databaseProperties: Properties) { companion object { val logger = loggerFor() @@ -40,15 +41,25 @@ class RequeryConfiguration(val properties: Properties, val useDefaultLogging: Bo fun makeSessionFactoryForModel(model: EntityModel): KotlinEntityDataStore { val configuration = KotlinConfigurationTransactionWrapper(model, dataSource, useDefaultLogging = this.useDefaultLogging) val tables = SchemaModifier(configuration) - val mode = TableCreationMode.CREATE_NOT_EXISTS - tables.createTables(mode) + if (databaseProperties.getProperty("initDatabase","true") == "true" ) { + val mode = TableCreationMode.CREATE_NOT_EXISTS + tables.createTables(mode) + } return KotlinEntityDataStore(configuration) } // TODO: remove once Requery supports QUERY WITH COMPOSITE_KEY IN - fun jdbcSession(): Connection { - val ctx = TransactionManager.manager.currentOrNull() - return ctx?.connection ?: throw IllegalStateException("Was expecting to find database transaction: must wrap calling code within a transaction.") - } + fun jdbcSession(): Connection = DatabaseTransactionManager.current().connection } +fun parserTransactionIsolationLevel(property: String?) : TransactionIsolation = + when (property) { + "none" -> TransactionIsolation.NONE + "readUncommitted" -> TransactionIsolation.READ_UNCOMMITTED + "readCommitted" -> TransactionIsolation.READ_COMMITTED + "repeatableRead" -> TransactionIsolation.REPEATABLE_READ + "serializable" -> TransactionIsolation.SERIALIZABLE + else -> { + TransactionIsolation.REPEATABLE_READ + } + } diff --git a/node/src/main/kotlin/net/corda/node/services/events/NodeSchedulerService.kt b/node/src/main/kotlin/net/corda/node/services/events/NodeSchedulerService.kt index b69906afbe..e386a0d61b 100644 --- a/node/src/main/kotlin/net/corda/node/services/events/NodeSchedulerService.kt +++ b/node/src/main/kotlin/net/corda/node/services/events/NodeSchedulerService.kt @@ -1,7 +1,7 @@ package net.corda.node.services.events import com.google.common.util.concurrent.SettableFuture -import net.corda.core.ThreadBox +import net.corda.core.internal.ThreadBox import net.corda.core.contracts.SchedulableState import net.corda.core.contracts.ScheduledActivity import net.corda.core.contracts.ScheduledStateRef @@ -9,7 +9,6 @@ import net.corda.core.contracts.StateRef import net.corda.core.flows.FlowInitiator import net.corda.core.flows.FlowLogic import net.corda.core.serialization.SingletonSerializeAsToken -import net.corda.core.then import net.corda.core.utilities.loggerFor import net.corda.core.utilities.trace import net.corda.node.services.api.SchedulerService @@ -43,7 +42,8 @@ import javax.annotation.concurrent.ThreadSafe @ThreadSafe class NodeSchedulerService(private val services: ServiceHubInternal, private val schedulerTimerExecutor: Executor = Executors.newSingleThreadExecutor(), - private val unfinishedSchedules: ReusableLatch = ReusableLatch()) + private val unfinishedSchedules: ReusableLatch = ReusableLatch(), + private val serverThread: AffinityExecutor) : SchedulerService, SingletonSerializeAsToken() { companion object { @@ -157,16 +157,14 @@ class NodeSchedulerService(private val services: ServiceHubInternal, } private fun onTimeReached(scheduledState: ScheduledStateRef) { - services.database.transaction { - val scheduledFlow = getScheduledFlow(scheduledState) - if (scheduledFlow != null) { - // TODO Because the flow is executed asynchronously, there is a small window between this tx we're in - // committing and the flow's first checkpoint when it starts in which we can lose the flow if the node - // goes down. - // See discussion in https://github.com/corda/corda/pull/639#discussion_r115257437 - val future = services.startFlow(scheduledFlow, FlowInitiator.Scheduled(scheduledState)).resultFuture - future.then { - unfinishedSchedules.countDown() + serverThread.fetchFrom { + services.database.transaction { + val scheduledFlow = getScheduledFlow(scheduledState) + if (scheduledFlow != null) { + val future = services.startFlow(scheduledFlow, FlowInitiator.Scheduled(scheduledState)).resultFuture + future.then { + unfinishedSchedules.countDown() + } } } } diff --git a/node/src/main/kotlin/net/corda/node/services/identity/InMemoryIdentityService.kt b/node/src/main/kotlin/net/corda/node/services/identity/InMemoryIdentityService.kt index 4438f2e9c5..ff786afd95 100644 --- a/node/src/main/kotlin/net/corda/node/services/identity/InMemoryIdentityService.kt +++ b/node/src/main/kotlin/net/corda/node/services/identity/InMemoryIdentityService.kt @@ -1,9 +1,7 @@ package net.corda.node.services.identity import net.corda.core.contracts.PartyAndReference -import net.corda.core.crypto.Crypto import net.corda.core.crypto.cert -import net.corda.core.crypto.subject import net.corda.core.crypto.toStringShort import net.corda.core.identity.AbstractParty import net.corda.core.identity.AnonymousParty @@ -21,8 +19,7 @@ import java.security.cert.* import java.util.* import java.util.concurrent.ConcurrentHashMap import javax.annotation.concurrent.ThreadSafe -import javax.security.auth.x500.X500Principal -import kotlin.collections.ArrayList +import kotlin.collections.LinkedHashSet /** * Simple identity service which caches parties and provides functionality for efficient lookup. @@ -32,12 +29,12 @@ import kotlin.collections.ArrayList */ @ThreadSafe class InMemoryIdentityService(identities: Iterable = emptySet(), - certPaths: Map = emptyMap(), + confidentialIdentities: Iterable = emptySet(), override val trustRoot: X509Certificate, vararg caCertificates: X509Certificate) : SingletonSerializeAsToken(), IdentityService { - constructor(identities: Iterable = emptySet(), - certPaths: Map = emptyMap(), - trustRoot: X509CertificateHolder) : this(identities, certPaths, trustRoot.cert) + constructor(wellKnownIdentities: Iterable = emptySet(), + confidentialIdentities: Iterable = emptySet(), + trustRoot: X509CertificateHolder) : this(wellKnownIdentities, confidentialIdentities, trustRoot.cert) companion object { private val log = loggerFor() } @@ -49,49 +46,60 @@ class InMemoryIdentityService(identities: Iterable = emptyS override val trustRootHolder = X509CertificateHolder(trustRoot.encoded) private val trustAnchor: TrustAnchor = TrustAnchor(trustRoot, null) private val keyToParties = ConcurrentHashMap() + private val keyToIssuingParty = ConcurrentHashMap() private val principalToParties = ConcurrentHashMap() - private val partyToPath = ConcurrentHashMap() init { val caCertificatesWithRoot: Set = caCertificates.toSet() + trustRoot caCertStore = CertStore.getInstance("Collection", CollectionCertStoreParameters(caCertificatesWithRoot)) keyToParties.putAll(identities.associateBy { it.owningKey } ) principalToParties.putAll(identities.associateBy { it.name }) - partyToPath.putAll(certPaths) + confidentialIdentities.forEach { identity -> + require(identity.certPath.certificates.size >= 2) { "Certificate path must at least include subject and issuing certificates" } + keyToIssuingParty[identity.owningKey] = keyToParties[identity.certPath.certificates[1].publicKey]!! + principalToParties.computeIfAbsent(identity.name) { identity } + } } + override fun registerIdentity(party: PartyAndCertificate) = verifyAndRegisterIdentity(party) + // TODO: Check the certificate validation logic @Throws(CertificateExpiredException::class, CertificateNotYetValidException::class, InvalidAlgorithmParameterException::class) - override fun registerIdentity(party: PartyAndCertificate) { - require(party.certPath.certificates.isNotEmpty()) { "Certificate path must contain at least one certificate" } + override fun verifyAndRegisterIdentity(identity: PartyAndCertificate) { + require(identity.certPath.certificates.size >= 2) { "Certificate path must at least include subject and issuing certificates" } // Validate the chain first, before we do anything clever with it - validateCertificatePath(party.party, party.certPath) + identity.verify(trustAnchor) - log.trace { "Registering identity $party" } - require(Arrays.equals(party.certificate.subjectPublicKeyInfo.encoded, party.owningKey.encoded)) { "Party certificate must end with party's public key" } + log.trace { "Registering identity $identity" } + require(Arrays.equals(identity.certificate.subjectPublicKeyInfo.encoded, identity.owningKey.encoded)) { "Party certificate must end with party's public key" } - partyToPath[party.party] = party.certPath - keyToParties[party.owningKey] = party - principalToParties[party.name] = party + keyToParties[identity.owningKey] = identity + // TODO: This map should only be deanonymised parties, not all issuers, but we have no good way of checking for + // confidential vs anonymous identities + val issuer = keyToParties[identity.certPath.certificates[1].publicKey] + if (issuer != null) { + keyToIssuingParty[identity.owningKey] = issuer + } + // Always keep the first party we registered, as that's the well known identity + principalToParties.computeIfAbsent(identity.name) { identity } } + override fun certificateFromKey(owningKey: PublicKey): PartyAndCertificate? = keyToParties[owningKey] override fun certificateFromParty(party: Party): PartyAndCertificate? = principalToParties[party.name] // We give the caller a copy of the data set to avoid any locking problems - override fun getAllIdentities(): Iterable = ArrayList(keyToParties.values) + override fun getAllIdentities(): Iterable = java.util.ArrayList(keyToParties.values) override fun partyFromKey(key: PublicKey): Party? = keyToParties[key]?.party - @Deprecated("Use partyFromX500Name") - override fun partyFromName(name: String): Party? = principalToParties[X500Name(name)]?.party override fun partyFromX500Name(principal: X500Name): Party? = principalToParties[principal]?.party - override fun partyFromAnonymous(party: AbstractParty) = party as? Party ?: partyFromKey(party.owningKey) + override fun partyFromAnonymous(party: AbstractParty) = party as? Party ?: keyToIssuingParty[party.owningKey]?.party override fun partyFromAnonymous(partyRef: PartyAndReference) = partyFromAnonymous(partyRef.party) override fun requirePartyFromAnonymous(party: AbstractParty): Party { return partyFromAnonymous(party) ?: throw IllegalStateException("Could not deanonymise party ${party.owningKey.toStringShort()}") } override fun partiesFromName(query: String, exactMatch: Boolean): Set { - val results = HashSet() + val results = LinkedHashSet() for ((x500name, partyAndCertificate) in principalToParties) { val party = partyAndCertificate.party for (rdn in x500name.rdNs) { @@ -115,48 +123,11 @@ class InMemoryIdentityService(identities: Iterable = emptyS @Throws(IdentityService.UnknownAnonymousPartyException::class) override fun assertOwnership(party: Party, anonymousParty: AnonymousParty) { - val path = partyToPath[anonymousParty] ?: throw IdentityService.UnknownAnonymousPartyException("Unknown anonymous party ${anonymousParty.owningKey.toStringShort()}") + val path = keyToParties[anonymousParty.owningKey]?.certPath ?: throw IdentityService.UnknownAnonymousPartyException("Unknown anonymous party ${anonymousParty.owningKey.toStringShort()}") require(path.certificates.size > 1) { "Certificate path must contain at least two certificates" } val actual = path.certificates[1] require(actual is X509Certificate && actual.publicKey == party.owningKey) { "Next certificate in the path must match the party key ${party.owningKey.toStringShort()}." } val target = path.certificates.first() require(target is X509Certificate && target.publicKey == anonymousParty.owningKey) { "Certificate path starts with a certificate for the anonymous party" } } - - override fun pathForAnonymous(anonymousParty: AnonymousParty): CertPath? = partyToPath[anonymousParty] - - @Throws(CertificateExpiredException::class, CertificateNotYetValidException::class, InvalidAlgorithmParameterException::class) - override fun registerAnonymousIdentity(anonymousParty: AnonymousParty, party: Party, path: CertPath) { - val fullParty = certificateFromParty(party) ?: throw IllegalArgumentException("Unknown identity ${party.name}") - require(path.certificates.isNotEmpty()) { "Certificate path must contain at least one certificate" } - // Validate the chain first, before we do anything clever with it - validateCertificatePath(anonymousParty, path) - val subjectCertificate = path.certificates.first() - require(subjectCertificate is X509Certificate && subjectCertificate.subject == fullParty.name) { "Subject of the transaction certificate must match the well known identity" } - - log.trace { "Registering identity $fullParty" } - - partyToPath[anonymousParty] = path - keyToParties[anonymousParty.owningKey] = fullParty - principalToParties[fullParty.name] = fullParty - } - - /** - * Verify that the given certificate path is valid and leads to the owning key of the party. - */ - private fun validateCertificatePath(party: AbstractParty, path: CertPath): PKIXCertPathValidatorResult { - // Check that the path ends with a certificate for the correct party. - val endCertificate = path.certificates.first() - // Ensure the key is in the correct format for comparison. - // TODO: Replace with a Bouncy Castle cert path so we can avoid Sun internal classes appearing unexpectedly. - // For now we have to deal with this potentially being an [X509Key] which is Sun's equivalent to - // [SubjectPublicKeyInfo] but doesn't compare properly with [PublicKey]. - val endKey = Crypto.decodePublicKey(endCertificate.publicKey.encoded) - require(endKey == party.owningKey) { "Certificate path validation must end at owning key ${party.owningKey.toStringShort()}, found ${endKey.toStringShort()}" } - - val validatorParameters = PKIXParameters(setOf(trustAnchor)) - val validator = CertPathValidator.getInstance("PKIX") - validatorParameters.isRevocationEnabled = false - return validator.validate(path, validatorParameters) as PKIXCertPathValidatorResult - } } diff --git a/node/src/main/kotlin/net/corda/node/services/keys/E2ETestKeyManagementService.kt b/node/src/main/kotlin/net/corda/node/services/keys/E2ETestKeyManagementService.kt index 0222ba6fa5..0c0eb61778 100644 --- a/node/src/main/kotlin/net/corda/node/services/keys/E2ETestKeyManagementService.kt +++ b/node/src/main/kotlin/net/corda/node/services/keys/E2ETestKeyManagementService.kt @@ -1,21 +1,15 @@ package net.corda.node.services.keys -import net.corda.core.ThreadBox -import net.corda.core.crypto.DigitalSignature -import net.corda.core.crypto.generateKeyPair -import net.corda.core.crypto.keys -import net.corda.core.crypto.sign +import net.corda.core.crypto.* import net.corda.core.identity.PartyAndCertificate +import net.corda.core.internal.ThreadBox import net.corda.core.node.services.IdentityService import net.corda.core.node.services.KeyManagementService import net.corda.core.serialization.SingletonSerializeAsToken -import net.corda.flows.AnonymisedIdentity -import org.bouncycastle.cert.X509CertificateHolder import org.bouncycastle.operator.ContentSigner import java.security.KeyPair import java.security.PrivateKey import java.security.PublicKey -import java.security.cert.CertPath import java.util.* import javax.annotation.concurrent.ThreadSafe @@ -58,7 +52,7 @@ class E2ETestKeyManagementService(val identityService: IdentityService, return keyPair.public } - override fun freshKeyAndCert(identity: PartyAndCertificate, revocationEnabled: Boolean): AnonymisedIdentity { + override fun freshKeyAndCert(identity: PartyAndCertificate, revocationEnabled: Boolean): PartyAndCertificate { return freshCertificate(identityService, freshKey(), identity, getSigner(identity.owningKey), revocationEnabled) } @@ -77,7 +71,13 @@ class E2ETestKeyManagementService(val identityService: IdentityService, override fun sign(bytes: ByteArray, publicKey: PublicKey): DigitalSignature.WithKey { val keyPair = getSigningKeyPair(publicKey) - val signature = keyPair.sign(bytes) - return signature + return keyPair.sign(bytes) + } + + // TODO: A full KeyManagementService implementation needs to record activity to the Audit Service and to limit + // signing to appropriately authorised contexts and initiating users. + override fun sign(signableData: SignableData, publicKey: PublicKey): TransactionSignature { + val keyPair = getSigningKeyPair(publicKey) + return keyPair.sign(signableData) } } diff --git a/node/src/main/kotlin/net/corda/node/services/keys/KMSUtils.kt b/node/src/main/kotlin/net/corda/node/services/keys/KMSUtils.kt index 481e9e8246..07bf6b363b 100644 --- a/node/src/main/kotlin/net/corda/node/services/keys/KMSUtils.kt +++ b/node/src/main/kotlin/net/corda/node/services/keys/KMSUtils.kt @@ -1,20 +1,21 @@ package net.corda.node.services.keys -import net.corda.core.crypto.* -import net.corda.core.identity.AnonymousParty +import net.corda.core.crypto.ContentSignerBuilder +import net.corda.core.crypto.Crypto +import net.corda.core.crypto.cert +import net.corda.core.identity.Party import net.corda.core.identity.PartyAndCertificate import net.corda.core.node.services.IdentityService -import net.corda.flows.AnonymisedIdentity -import org.bouncycastle.cert.X509CertificateHolder +import net.corda.core.utilities.days +import net.corda.node.utilities.CertificateType +import net.corda.node.utilities.X509Utilities import org.bouncycastle.operator.ContentSigner import java.security.KeyPair import java.security.PublicKey import java.security.Security -import java.security.cert.CertPath import java.security.cert.CertificateFactory import java.security.cert.X509Certificate import java.time.Duration -import java.util.* /** * Generates a new random [KeyPair], adds it to the internal key storage, then generates a corresponding @@ -31,16 +32,15 @@ fun freshCertificate(identityService: IdentityService, subjectPublicKey: PublicKey, issuer: PartyAndCertificate, issuerSigner: ContentSigner, - revocationEnabled: Boolean = false): AnonymisedIdentity { + revocationEnabled: Boolean = false): PartyAndCertificate { val issuerCertificate = issuer.certificate - val window = X509Utilities.getCertificateValidityWindow(Duration.ZERO, Duration.ofDays(10 * 365), issuerCertificate) - val ourCertificate = Crypto.createCertificate(CertificateType.IDENTITY, issuerCertificate.subject, issuerSigner, issuer.name, subjectPublicKey, window) + val window = X509Utilities.getCertificateValidityWindow(Duration.ZERO, 3650.days, issuerCertificate) + val ourCertificate = X509Utilities.createCertificate(CertificateType.IDENTITY, issuerCertificate.subject, issuerSigner, issuer.name, subjectPublicKey, window) val certFactory = CertificateFactory.getInstance("X509") val ourCertPath = certFactory.generateCertPath(listOf(ourCertificate.cert) + issuer.certPath.certificates) - identityService.registerAnonymousIdentity(AnonymousParty(subjectPublicKey), - issuer.party, - ourCertPath) - return AnonymisedIdentity(ourCertPath, issuerCertificate, subjectPublicKey) + val anonymisedIdentity = PartyAndCertificate(Party(issuer.name, subjectPublicKey), ourCertificate, ourCertPath) + identityService.verifyAndRegisterIdentity(anonymisedIdentity) + return anonymisedIdentity } fun getSigner(issuerKeyPair: KeyPair): ContentSigner { diff --git a/node/src/main/kotlin/net/corda/node/services/keys/PersistentKeyManagementService.kt b/node/src/main/kotlin/net/corda/node/services/keys/PersistentKeyManagementService.kt index 359239bc0f..95c4e0083e 100644 --- a/node/src/main/kotlin/net/corda/node/services/keys/PersistentKeyManagementService.kt +++ b/node/src/main/kotlin/net/corda/node/services/keys/PersistentKeyManagementService.kt @@ -1,22 +1,23 @@ package net.corda.node.services.keys -import net.corda.core.ThreadBox -import net.corda.core.crypto.DigitalSignature -import net.corda.core.crypto.generateKeyPair -import net.corda.core.crypto.keys -import net.corda.core.crypto.sign +import net.corda.core.crypto.* import net.corda.core.identity.PartyAndCertificate import net.corda.core.node.services.IdentityService import net.corda.core.node.services.KeyManagementService +import net.corda.core.serialization.SerializationDefaults import net.corda.core.serialization.SingletonSerializeAsToken -import net.corda.flows.AnonymisedIdentity -import net.corda.node.utilities.* +import net.corda.core.serialization.deserialize +import net.corda.core.serialization.serialize +import net.corda.node.utilities.AppendOnlyPersistentMap +import net.corda.node.utilities.NODE_DATABASE_PREFIX import org.bouncycastle.operator.ContentSigner -import org.jetbrains.exposed.sql.ResultRow -import org.jetbrains.exposed.sql.statements.InsertStatement import java.security.KeyPair import java.security.PrivateKey import java.security.PublicKey +import javax.persistence.Column +import javax.persistence.Entity +import javax.persistence.Id +import javax.persistence.Lob /** * A persistent re-implementation of [E2ETestKeyManagementService] to support node re-start. @@ -28,66 +29,73 @@ import java.security.PublicKey class PersistentKeyManagementService(val identityService: IdentityService, initialKeys: Set) : SingletonSerializeAsToken(), KeyManagementService { - private object Table : JDBCHashedTable("${NODE_DATABASE_PREFIX}our_key_pairs") { - val publicKey = publicKey("public_key") - val privateKey = blob("private_key") - } + @Entity + @javax.persistence.Table(name = "${NODE_DATABASE_PREFIX}our_key_pairs") + class PersistentKey( - private class InnerState { - val keys = object : AbstractJDBCHashMap(Table, loadOnInit = false) { - override fun keyFromRow(row: ResultRow): PublicKey = row[table.publicKey] + @Id + @Column(name = "public_key") + var publicKey: String = "", - override fun valueFromRow(row: ResultRow): PrivateKey = deserializeFromBlob(row[table.privateKey]) + @Lob + @Column(name = "private_key") + var privateKey: ByteArray = ByteArray(0) + ) - override fun addKeyToInsert(insert: InsertStatement, entry: Map.Entry, finalizables: MutableList<() -> Unit>) { - insert[table.publicKey] = entry.key - } - - override fun addValueToInsert(insert: InsertStatement, entry: Map.Entry, finalizables: MutableList<() -> Unit>) { - insert[table.privateKey] = serializeToBlob(entry.value, finalizables) - } + private companion object { + fun createKeyMap(): AppendOnlyPersistentMap { + return AppendOnlyPersistentMap( + toPersistentEntityKey = { it.toBase58String() }, + fromPersistentEntity = { Pair(parsePublicKeyBase58(it.publicKey), + it.privateKey.deserialize(context = SerializationDefaults.STORAGE_CONTEXT)) }, + toPersistentEntity = { key: PublicKey, value: PrivateKey -> + PersistentKey().apply { + publicKey = key.toBase58String() + privateKey = value.serialize(context = SerializationDefaults.STORAGE_CONTEXT).bytes + } + }, + persistentEntityClass = PersistentKey::class.java + ) } } - private val mutex = ThreadBox(InnerState()) + val keysMap = createKeyMap() init { - mutex.locked { - keys.putAll(initialKeys.associate { Pair(it.public, it.private) }) - } + initialKeys.forEach({ it -> keysMap.addWithDuplicatesAllowed(it.public, it.private) }) } - override val keys: Set get() = mutex.locked { keys.keys } + override val keys: Set get() = keysMap.allPersisted().map { it.first }.toSet() - override fun filterMyKeys(candidateKeys: Iterable): Iterable { - return mutex.locked { candidateKeys.filter { it in this.keys } } - } + override fun filterMyKeys(candidateKeys: Iterable): Iterable = + candidateKeys.filter { keysMap[it] != null } override fun freshKey(): PublicKey { val keyPair = generateKeyPair() - mutex.locked { - keys[keyPair.public] = keyPair.private - } + keysMap[keyPair.public] = keyPair.private return keyPair.public } - override fun freshKeyAndCert(identity: PartyAndCertificate, revocationEnabled: Boolean): AnonymisedIdentity { - return freshCertificate(identityService, freshKey(), identity, getSigner(identity.owningKey), revocationEnabled) - } + override fun freshKeyAndCert(identity: PartyAndCertificate, revocationEnabled: Boolean): PartyAndCertificate = + freshCertificate(identityService, freshKey(), identity, getSigner(identity.owningKey), revocationEnabled) private fun getSigner(publicKey: PublicKey): ContentSigner = getSigner(getSigningKeyPair(publicKey)) + //It looks for the PublicKey in the (potentially) CompositeKey that is ours, and then returns the associated PrivateKey to use in signing private fun getSigningKeyPair(publicKey: PublicKey): KeyPair { - return mutex.locked { - val pk = publicKey.keys.first { keys.containsKey(it) } - KeyPair(pk, keys[pk]!!) - } + val pk = publicKey.keys.first { keysMap[it] != null } //TODO here for us to re-write this using an actual query if publicKey.keys.size > 1 + return KeyPair(pk, keysMap[pk]!!) } override fun sign(bytes: ByteArray, publicKey: PublicKey): DigitalSignature.WithKey { val keyPair = getSigningKeyPair(publicKey) - val signature = keyPair.sign(bytes) - return signature + return keyPair.sign(bytes) } + // TODO: A full KeyManagementService implementation needs to record activity to the Audit Service and to limit + // signing to appropriately authorised contexts and initiating users. + override fun sign(signableData: SignableData, publicKey: PublicKey): TransactionSignature { + val keyPair = getSigningKeyPair(publicKey) + return keyPair.sign(signableData) + } } diff --git a/node/src/main/kotlin/net/corda/node/services/messaging/ArtemisMessagingServer.kt b/node/src/main/kotlin/net/corda/node/services/messaging/ArtemisMessagingServer.kt index 66b2da18f6..919a744191 100644 --- a/node/src/main/kotlin/net/corda/node/services/messaging/ArtemisMessagingServer.kt +++ b/node/src/main/kotlin/net/corda/node/services/messaging/ArtemisMessagingServer.kt @@ -1,18 +1,20 @@ package net.corda.node.services.messaging import com.google.common.util.concurrent.ListenableFuture -import com.google.common.util.concurrent.SettableFuture import io.netty.handler.ssl.SslHandler -import net.corda.core.* -import net.corda.core.crypto.* -import net.corda.core.crypto.X509Utilities.CORDA_CLIENT_TLS -import net.corda.core.crypto.X509Utilities.CORDA_ROOT_CA +import net.corda.core.concurrent.CordaFuture +import net.corda.core.crypto.AddressFormatException +import net.corda.core.crypto.newSecureRandom +import net.corda.core.crypto.parsePublicKeyBase58 +import net.corda.core.crypto.random63BitValue +import net.corda.core.internal.ThreadBox +import net.corda.core.internal.concurrent.openFuture +import net.corda.core.internal.div +import net.corda.core.internal.noneOrSingle import net.corda.core.node.NodeInfo import net.corda.core.node.services.NetworkMapCache import net.corda.core.node.services.NetworkMapCache.MapChange -import net.corda.core.utilities.NetworkHostAndPort -import net.corda.core.utilities.debug -import net.corda.core.utilities.loggerFor +import net.corda.core.utilities.* import net.corda.node.internal.Node import net.corda.node.services.RPCUserService import net.corda.node.services.config.NodeConfiguration @@ -20,6 +22,11 @@ import net.corda.node.services.messaging.NodeLoginModule.Companion.NODE_ROLE import net.corda.node.services.messaging.NodeLoginModule.Companion.PEER_ROLE import net.corda.node.services.messaging.NodeLoginModule.Companion.RPC_ROLE import net.corda.node.services.messaging.NodeLoginModule.Companion.VERIFIER_ROLE +import net.corda.node.utilities.X509Utilities +import net.corda.node.utilities.X509Utilities.CORDA_CLIENT_TLS +import net.corda.node.utilities.X509Utilities.CORDA_ROOT_CA +import net.corda.node.utilities.getX509Certificate +import net.corda.node.utilities.loadKeyStore import net.corda.nodeapi.* import net.corda.nodeapi.ArtemisMessagingComponent.Companion.NODE_USER import net.corda.nodeapi.ArtemisMessagingComponent.Companion.PEER_USER @@ -104,12 +111,12 @@ class ArtemisMessagingServer(override val config: NodeConfiguration, private val mutex = ThreadBox(InnerState()) private lateinit var activeMQServer: ActiveMQServer val serverControl: ActiveMQServerControl get() = activeMQServer.activeMQServerControl - private val _networkMapConnectionFuture = config.networkMapService?.let { SettableFuture.create() } + private val _networkMapConnectionFuture = config.networkMapService?.let { openFuture() } /** * A [ListenableFuture] which completes when the server successfully connects to the network map node. If a * non-recoverable error is encountered then the Future will complete with an exception. */ - val networkMapConnectionFuture: SettableFuture? get() = _networkMapConnectionFuture + val networkMapConnectionFuture: CordaFuture? get() = _networkMapConnectionFuture private var networkChangeHandle: Subscription? = null private val nodeRunsNetworkMapService = config.networkMapService == null @@ -264,8 +271,8 @@ class ArtemisMessagingServer(override val config: NodeConfiguration, @Throws(IOException::class, KeyStoreException::class) private fun createArtemisSecurityManager(): ActiveMQJAASSecurityManager { - val keyStore = KeyStoreUtilities.loadKeyStore(config.sslKeystore, config.keyStorePassword) - val trustStore = KeyStoreUtilities.loadKeyStore(config.trustStoreFile, config.trustStorePassword) + val keyStore = loadKeyStore(config.sslKeystore, config.keyStorePassword) + val trustStore = loadKeyStore(config.trustStoreFile, config.trustStorePassword) val ourCertificate = keyStore.getX509Certificate(CORDA_CLIENT_TLS) // This is a sanity check and should not fail unless things have been misconfigured diff --git a/node/src/main/kotlin/net/corda/node/services/messaging/Messaging.kt b/node/src/main/kotlin/net/corda/node/services/messaging/Messaging.kt index c5396a0067..3e65c572f0 100644 --- a/node/src/main/kotlin/net/corda/node/services/messaging/Messaging.kt +++ b/node/src/main/kotlin/net/corda/node/services/messaging/Messaging.kt @@ -1,15 +1,15 @@ package net.corda.node.services.messaging import com.google.common.util.concurrent.ListenableFuture -import com.google.common.util.concurrent.SettableFuture -import net.corda.core.catch +import net.corda.core.concurrent.CordaFuture +import net.corda.core.internal.concurrent.openFuture import net.corda.core.messaging.MessageRecipients import net.corda.core.messaging.SingleMessageRecipient -import net.corda.core.node.services.DEFAULT_SESSION_ID import net.corda.core.node.services.PartyInfo import net.corda.core.serialization.CordaSerializable import net.corda.core.serialization.deserialize import net.corda.core.serialization.serialize +import net.corda.node.services.api.DEFAULT_SESSION_ID import org.bouncycastle.asn1.x500.X500Name import java.time.Instant import java.util.* @@ -147,14 +147,15 @@ inline fun MessagingService.runOnNextMessage(topicSession: TopicSession, crossin } /** - * Returns a [ListenableFuture] of the next message payload ([Message.data]) which is received on the given topic and sessionId. + * Returns a [CordaFuture] of the next message payload ([Message.data]) which is received on the given topic and sessionId. * The payload is deserialized to an object of type [M]. Any exceptions thrown will be captured by the future. */ -fun MessagingService.onNext(topic: String, sessionId: Long): ListenableFuture { - val messageFuture = SettableFuture.create() +fun MessagingService.onNext(topic: String, sessionId: Long): CordaFuture { + val messageFuture = openFuture() runOnNextMessage(topic, sessionId) { message -> - messageFuture.catch { - message.data.deserialize() + messageFuture.capture { + @Suppress("UNCHECKED_CAST") + message.data.deserialize() as M } } return messageFuture diff --git a/node/src/main/kotlin/net/corda/node/services/messaging/NodeMessagingClient.kt b/node/src/main/kotlin/net/corda/node/services/messaging/NodeMessagingClient.kt index 0b9a45d387..c0ddd9f8e5 100644 --- a/node/src/main/kotlin/net/corda/node/services/messaging/NodeMessagingClient.kt +++ b/node/src/main/kotlin/net/corda/node/services/messaging/NodeMessagingClient.kt @@ -1,19 +1,18 @@ package net.corda.node.services.messaging -import com.google.common.util.concurrent.ListenableFuture -import net.corda.core.* +import net.corda.core.concurrent.CordaFuture import net.corda.core.crypto.random63BitValue +import net.corda.core.internal.concurrent.andForget +import net.corda.core.internal.concurrent.thenMatch +import net.corda.core.internal.ThreadBox import net.corda.core.messaging.CordaRPCOps import net.corda.core.messaging.MessageRecipients import net.corda.core.messaging.RPCOps import net.corda.core.messaging.SingleMessageRecipient import net.corda.core.node.services.PartyInfo import net.corda.core.node.services.TransactionVerifierService -import net.corda.core.utilities.opaque import net.corda.core.transactions.LedgerTransaction -import net.corda.core.utilities.NetworkHostAndPort -import net.corda.core.utilities.loggerFor -import net.corda.core.utilities.trace +import net.corda.core.utilities.* import net.corda.node.VersionInfo import net.corda.node.services.RPCUserService import net.corda.node.services.api.MonitoringService @@ -37,7 +36,6 @@ import org.apache.activemq.artemis.api.core.client.* import org.apache.activemq.artemis.api.core.client.ActiveMQClient.DEFAULT_ACK_BATCH_SIZE import org.apache.activemq.artemis.api.core.management.ActiveMQServerControl import org.bouncycastle.asn1.x500.X500Name -import org.jetbrains.exposed.sql.Database import org.jetbrains.exposed.sql.ResultRow import org.jetbrains.exposed.sql.statements.InsertStatement import java.security.PublicKey @@ -74,8 +72,8 @@ class NodeMessagingClient(override val config: NodeConfiguration, val serverAddress: NetworkHostAndPort, val myIdentity: PublicKey?, val nodeExecutor: AffinityExecutor.ServiceAffinityExecutor, - val database: Database, - val networkMapRegistrationFuture: ListenableFuture, + val database: CordaPersistence, + val networkMapRegistrationFuture: CordaFuture, val monitoringService: MonitoringService, advertisedAddress: NetworkHostAndPort = serverAddress ) : ArtemisMessagingComponent(), MessagingService { @@ -347,7 +345,7 @@ class NodeMessagingClient(override val config: NodeConfiguration, private val message: ClientMessage) : ReceivedMessage { override val data: ByteArray by lazy { ByteArray(message.bodySize).apply { message.bodyBuffer.readBytes(this) } } override val debugTimestamp: Instant get() = Instant.ofEpochMilli(message.timestamp) - override fun toString() = "${topicSession.topic}#${data.opaque()}" + override fun toString() = "${topicSession.topic}#${data.sequence()}" } private fun deliver(msg: ReceivedMessage): Boolean { @@ -466,8 +464,8 @@ class NodeMessagingClient(override val config: NodeConfiguration, } private fun sendWithRetry(retryCount: Int, address: String, message: ClientMessage, retryId: Long) { - fun randomiseDuplicateId(message: ClientMessage) { - message.putStringProperty(HDR_DUPLICATE_DETECTION_ID, SimpleString(UUID.randomUUID().toString())) + fun ClientMessage.randomiseDuplicateId() { + putStringProperty(HDR_DUPLICATE_DETECTION_ID, SimpleString(UUID.randomUUID().toString())) } log.trace { "Attempting to retry #$retryCount message delivery for $retryId" } @@ -477,7 +475,7 @@ class NodeMessagingClient(override val config: NodeConfiguration, return } - randomiseDuplicateId(message) + message.randomiseDuplicateId() state.locked { log.trace { "Retry #$retryCount sending message $message to $address for $retryId" } diff --git a/node/src/main/kotlin/net/corda/node/services/messaging/RPCServer.kt b/node/src/main/kotlin/net/corda/node/services/messaging/RPCServer.kt index 71ee340dbb..c04842e494 100644 --- a/node/src/main/kotlin/net/corda/node/services/messaging/RPCServer.kt +++ b/node/src/main/kotlin/net/corda/node/services/messaging/RPCServer.kt @@ -4,7 +4,6 @@ import com.esotericsoftware.kryo.Kryo import com.esotericsoftware.kryo.Serializer import com.esotericsoftware.kryo.io.Input import com.esotericsoftware.kryo.io.Output -import com.esotericsoftware.kryo.pool.KryoPool import com.google.common.cache.Cache import com.google.common.cache.CacheBuilder import com.google.common.cache.RemovalListener @@ -13,9 +12,12 @@ import com.google.common.collect.Multimaps import com.google.common.collect.SetMultimap import com.google.common.util.concurrent.ThreadFactoryBuilder import net.corda.core.crypto.random63BitValue +import net.corda.core.internal.LazyStickyPool +import net.corda.core.internal.LifeCycle import net.corda.core.messaging.RPCOps -import net.corda.core.seconds -import net.corda.core.serialization.KryoPoolWithContext +import net.corda.core.utilities.seconds +import net.corda.core.serialization.SerializationContext +import net.corda.core.serialization.SerializationDefaults.RPC_SERVER_CONTEXT import net.corda.core.utilities.* import net.corda.node.services.RPCUserService import net.corda.nodeapi.* @@ -79,7 +81,6 @@ class RPCServer( ) { private companion object { val log = loggerFor() - val kryoPool = KryoPool.Builder { RPCKryo(RpcServerObservableSerializer) }.build() } private enum class State { UNSTARTED, @@ -256,7 +257,7 @@ class RPCServer( private fun clientArtemisMessageHandler(artemisMessage: ClientMessage) { lifeCycle.requireState(State.STARTED) - val clientToServer = RPCApi.ClientToServer.fromClientMessage(kryoPool, artemisMessage) + val clientToServer = RPCApi.ClientToServer.fromClientMessage(RPC_SERVER_CONTEXT, artemisMessage) log.debug { "-> RPC -> $clientToServer" } when (clientToServer) { is RPCApi.ClientToServer.RpcRequest -> { @@ -300,8 +301,7 @@ class RPCServer( clientAddress, serverControl!!, sessionAndProducerPool, - observationSendExecutor!!, - kryoPool + observationSendExecutor!! ) val buffered = bufferIfQueueNotBound(clientAddress, reply, observableContext) @@ -383,19 +383,19 @@ class ObservableContext( val clientAddress: SimpleString, val serverControl: ActiveMQServerControl, val sessionAndProducerPool: LazyStickyPool, - val observationSendExecutor: ExecutorService, - kryoPool: KryoPool + val observationSendExecutor: ExecutorService ) { private companion object { val log = loggerFor() } - private val kryoPoolWithObservableContext = RpcServerObservableSerializer.createPoolWithContext(kryoPool, this) + private val serializationContextWithObservableContext = RpcServerObservableSerializer.createContext(this) + fun sendMessage(serverToClient: RPCApi.ServerToClient) { try { sessionAndProducerPool.run(rpcRequestId) { val artemisMessage = it.session.createMessage(false) - serverToClient.writeToClientMessage(kryoPoolWithObservableContext, artemisMessage) + serverToClient.writeToClientMessage(serializationContextWithObservableContext, artemisMessage) it.producer.send(clientAddress, artemisMessage) log.debug("<- RPC <- $serverToClient") } @@ -406,19 +406,19 @@ class ObservableContext( } } -private object RpcServerObservableSerializer : Serializer>() { +object RpcServerObservableSerializer : Serializer>() { private object RpcObservableContextKey private val log = loggerFor() - fun createPoolWithContext(kryoPool: KryoPool, observableContext: ObservableContext): KryoPool { - return KryoPoolWithContext(kryoPool, RpcObservableContextKey, observableContext) + fun createContext(observableContext: ObservableContext): SerializationContext { + return RPC_SERVER_CONTEXT.withProperty(RpcServerObservableSerializer.RpcObservableContextKey, observableContext) } - override fun read(kryo: Kryo?, input: Input?, type: Class>?): Observable { + override fun read(kryo: Kryo?, input: Input?, type: Class>?): Observable { throw UnsupportedOperationException() } - override fun write(kryo: Kryo, output: Output, observable: Observable) { + override fun write(kryo: Kryo, output: Output, observable: Observable<*>) { val observableId = RPCApi.ObservableId(random63BitValue()) val observableContext = kryo.context[RpcObservableContextKey] as ObservableContext output.writeLong(observableId.toLong, true) @@ -426,8 +426,8 @@ private object RpcServerObservableSerializer : Serializer>() { // We capture [observableContext] in the subscriber. Note that all synchronisation/kryo borrowing // must be done again within the subscriber subscription = observable.materialize().subscribe( - object : Subscriber>() { - override fun onNext(observation: Notification) { + object : Subscriber>() { + override fun onNext(observation: Notification<*>) { if (!isUnsubscribed) { observableContext.observationSendExecutor.submit { observableContext.sendMessage(RPCApi.ServerToClient.Observation(observableId, observation)) diff --git a/node/src/main/kotlin/net/corda/node/services/messaging/ServiceRequestMessage.kt b/node/src/main/kotlin/net/corda/node/services/messaging/ServiceRequestMessage.kt index 35ce5218d1..7536fcc33c 100644 --- a/node/src/main/kotlin/net/corda/node/services/messaging/ServiceRequestMessage.kt +++ b/node/src/main/kotlin/net/corda/node/services/messaging/ServiceRequestMessage.kt @@ -1,10 +1,10 @@ package net.corda.node.services.messaging -import com.google.common.util.concurrent.ListenableFuture +import net.corda.core.concurrent.CordaFuture import net.corda.core.messaging.MessageRecipients import net.corda.core.messaging.SingleMessageRecipient -import net.corda.core.node.services.DEFAULT_SESSION_ID import net.corda.core.serialization.CordaSerializable +import net.corda.node.services.api.DEFAULT_SESSION_ID /** * Abstract superclass for request messages sent to services which expect a reply. @@ -16,12 +16,12 @@ interface ServiceRequestMessage { } /** - * Sends a [ServiceRequestMessage] to [target] and returns a [ListenableFuture] of the response. + * Sends a [ServiceRequestMessage] to [target] and returns a [CordaFuture] of the response. * @param R The type of the response. */ fun MessagingService.sendRequest(topic: String, request: ServiceRequestMessage, - target: MessageRecipients): ListenableFuture { + target: MessageRecipients): CordaFuture { val responseFuture = onNext(topic, request.sessionID) send(topic, DEFAULT_SESSION_ID, request, target) return responseFuture diff --git a/node/src/main/kotlin/net/corda/node/services/network/InMemoryNetworkMapCache.kt b/node/src/main/kotlin/net/corda/node/services/network/InMemoryNetworkMapCache.kt index 0d3cd6e818..3af8f164e6 100644 --- a/node/src/main/kotlin/net/corda/node/services/network/InMemoryNetworkMapCache.kt +++ b/node/src/main/kotlin/net/corda/node/services/network/InMemoryNetworkMapCache.kt @@ -1,24 +1,23 @@ package net.corda.node.services.network import com.google.common.annotations.VisibleForTesting -import com.google.common.util.concurrent.ListenableFuture -import com.google.common.util.concurrent.SettableFuture -import net.corda.core.bufferUntilSubscribed +import net.corda.core.concurrent.CordaFuture +import net.corda.core.internal.bufferUntilSubscribed import net.corda.core.identity.AbstractParty import net.corda.core.identity.Party -import net.corda.core.map +import net.corda.core.internal.concurrent.map +import net.corda.core.internal.concurrent.openFuture import net.corda.core.messaging.DataFeed import net.corda.core.messaging.SingleMessageRecipient import net.corda.core.node.NodeInfo import net.corda.core.node.ServiceHub -import net.corda.core.node.services.DEFAULT_SESSION_ID -import net.corda.core.node.services.IdentityService import net.corda.core.node.services.NetworkMapCache.MapChange import net.corda.core.node.services.PartyInfo import net.corda.core.serialization.SingletonSerializeAsToken import net.corda.core.serialization.deserialize import net.corda.core.serialization.serialize import net.corda.core.utilities.loggerFor +import net.corda.node.services.api.DEFAULT_SESSION_ID import net.corda.node.services.api.NetworkCacheError import net.corda.node.services.api.NetworkMapCacheInternal import net.corda.node.services.messaging.MessagingService @@ -56,8 +55,8 @@ open class InMemoryNetworkMapCache(private val serviceHub: ServiceHub?) : Single override val changed: Observable = _changed.wrapWithDatabaseTransaction() private val changePublisher: rx.Observer get() = _changed.bufferUntilDatabaseCommit() - private val _registrationFuture = SettableFuture.create() - override val mapServiceRegistered: ListenableFuture get() = _registrationFuture + private val _registrationFuture = openFuture() + override val mapServiceRegistered: CordaFuture get() = _registrationFuture private var registeredForPush = false protected var registeredNodes: MutableMap = Collections.synchronizedMap(HashMap()) @@ -97,7 +96,7 @@ open class InMemoryNetworkMapCache(private val serviceHub: ServiceHub?) : Single } override fun addMapService(network: MessagingService, networkMapAddress: SingleMessageRecipient, subscribe: Boolean, - ifChangedSinceVer: Int?): ListenableFuture { + ifChangedSinceVer: Int?): CordaFuture { if (subscribe && !registeredForPush) { // Add handler to the network, for updates received from the remote network map service. network.addMessageHandler(NetworkMapService.PUSH_TOPIC, DEFAULT_SESSION_ID) { message, _ -> @@ -123,7 +122,7 @@ open class InMemoryNetworkMapCache(private val serviceHub: ServiceHub?) : Single nodes?.forEach { processRegistration(it) } Unit } - _registrationFuture.setFuture(future) + _registrationFuture.captureLater(future.map { null }) return future } @@ -150,7 +149,7 @@ open class InMemoryNetworkMapCache(private val serviceHub: ServiceHub?) : Single * Unsubscribes from updates from the given map service. * @param service the network map service to listen to updates from. */ - override fun deregisterForUpdates(network: MessagingService, service: NodeInfo): ListenableFuture { + override fun deregisterForUpdates(network: MessagingService, service: NodeInfo): CordaFuture { // Fetch the network map and register for updates at the same time val req = NetworkMapService.SubscribeRequest(false, network.myAddress) // `network.getAddressOfParty(partyInfo)` is a work-around for MockNetwork and InMemoryMessaging to get rid of SingleMessageRecipient in NodeInfo. @@ -158,7 +157,7 @@ open class InMemoryNetworkMapCache(private val serviceHub: ServiceHub?) : Single val future = network.sendRequest(NetworkMapService.SUBSCRIPTION_TOPIC, req, address).map { if (it.confirmed) Unit else throw NetworkCacheError.DeregistrationFailed() } - _registrationFuture.setFuture(future) + _registrationFuture.captureLater(future.map { null }) return future } @@ -182,6 +181,6 @@ open class InMemoryNetworkMapCache(private val serviceHub: ServiceHub?) : Single @VisibleForTesting override fun runWithoutMapService() { - _registrationFuture.set(Unit) + _registrationFuture.set(null) } } diff --git a/node/src/main/kotlin/net/corda/node/services/network/NetworkMapService.kt b/node/src/main/kotlin/net/corda/node/services/network/NetworkMapService.kt index d0f51b534e..b5d5aa729b 100644 --- a/node/src/main/kotlin/net/corda/node/services/network/NetworkMapService.kt +++ b/node/src/main/kotlin/net/corda/node/services/network/NetworkMapService.kt @@ -1,19 +1,18 @@ package net.corda.node.services.network import com.google.common.annotations.VisibleForTesting -import net.corda.core.ThreadBox +import net.corda.core.internal.ThreadBox import net.corda.core.crypto.DigitalSignature import net.corda.core.crypto.SignedData import net.corda.core.crypto.isFulfilledBy +import net.corda.core.crypto.random63BitValue import net.corda.core.identity.PartyAndCertificate import net.corda.core.messaging.MessageRecipients import net.corda.core.messaging.SingleMessageRecipient import net.corda.core.node.NodeInfo -import net.corda.core.node.services.DEFAULT_SESSION_ID import net.corda.core.node.services.KeyManagementService import net.corda.core.node.services.NetworkMapCache import net.corda.core.node.services.ServiceType -import net.corda.core.crypto.random63BitValue import net.corda.core.serialization.CordaSerializable import net.corda.core.serialization.SerializedBytes import net.corda.core.serialization.deserialize @@ -21,6 +20,7 @@ import net.corda.core.serialization.serialize import net.corda.core.utilities.debug import net.corda.core.utilities.loggerFor import net.corda.node.services.api.AbstractNodeService +import net.corda.node.services.api.DEFAULT_SESSION_ID import net.corda.node.services.api.ServiceHubInternal import net.corda.node.services.messaging.MessageHandlerRegistration import net.corda.node.services.messaging.ServiceRequestMessage diff --git a/node/src/main/kotlin/net/corda/node/services/network/PersistentNetworkMapService.kt b/node/src/main/kotlin/net/corda/node/services/network/PersistentNetworkMapService.kt index 2220d94a19..1cb7de5292 100644 --- a/node/src/main/kotlin/net/corda/node/services/network/PersistentNetworkMapService.kt +++ b/node/src/main/kotlin/net/corda/node/services/network/PersistentNetworkMapService.kt @@ -1,6 +1,6 @@ package net.corda.node.services.network -import net.corda.core.ThreadBox +import net.corda.core.internal.ThreadBox import net.corda.core.identity.PartyAndCertificate import net.corda.core.messaging.SingleMessageRecipient import net.corda.node.services.api.ServiceHubInternal diff --git a/node/src/main/kotlin/net/corda/node/services/persistence/DBCheckpointStorage.kt b/node/src/main/kotlin/net/corda/node/services/persistence/DBCheckpointStorage.kt index 89db282424..b977f36875 100644 --- a/node/src/main/kotlin/net/corda/node/services/persistence/DBCheckpointStorage.kt +++ b/node/src/main/kotlin/net/corda/node/services/persistence/DBCheckpointStorage.kt @@ -1,57 +1,60 @@ package net.corda.node.services.persistence -import net.corda.core.crypto.SecureHash -import net.corda.core.serialization.SerializedBytes +import net.corda.core.serialization.SerializationDefaults import net.corda.core.serialization.deserialize import net.corda.core.serialization.serialize -import net.corda.core.serialization.storageKryo import net.corda.node.services.api.Checkpoint import net.corda.node.services.api.CheckpointStorage import net.corda.node.utilities.* -import org.jetbrains.exposed.sql.ResultRow -import org.jetbrains.exposed.sql.statements.InsertStatement -import java.util.Collections.synchronizedMap +import javax.persistence.Column +import javax.persistence.Entity +import javax.persistence.Id +import javax.persistence.Lob /** - * Simple checkpoint key value storage in DB using the underlying JDBCHashMap and transactional context of the call sites. + * Simple checkpoint key value storage in DB. */ class DBCheckpointStorage : CheckpointStorage { - private object Table : JDBCHashedTable("${NODE_DATABASE_PREFIX}checkpoints") { - val checkpointId = secureHash("checkpoint_id") - val checkpoint = blob("checkpoint") - } + @Entity + @javax.persistence.Table(name = "${NODE_DATABASE_PREFIX}checkpoints") + class DBCheckpoint( + @Id + @Column(name = "checkpoint_id", length = 64) + var checkpointId: String = "", - private class CheckpointMap : AbstractJDBCHashMap, Table>(Table, loadOnInit = false) { - override fun keyFromRow(row: ResultRow): SecureHash = row[table.checkpointId] + @Lob + @Column(name = "checkpoint") + var checkpoint: ByteArray = ByteArray(0) + ) - override fun valueFromRow(row: ResultRow): SerializedBytes = bytesFromBlob(row[table.checkpoint]) - - override fun addKeyToInsert(insert: InsertStatement, entry: Map.Entry>, finalizables: MutableList<() -> Unit>) { - insert[table.checkpointId] = entry.key - } - - override fun addValueToInsert(insert: InsertStatement, entry: Map.Entry>, finalizables: MutableList<() -> Unit>) { - insert[table.checkpoint] = bytesToBlob(entry.value, finalizables) - } - } - - private val checkpointStorage = synchronizedMap(CheckpointMap()) - - override fun addCheckpoint(checkpoint: Checkpoint) { - checkpointStorage.put(checkpoint.id, checkpoint.serialize(storageKryo(), true)) + override fun addCheckpoint(value: Checkpoint) { + val session = DatabaseTransactionManager.current().session + session.save(DBCheckpoint().apply { + checkpointId = value.id.toString() + checkpoint = value.serialize(context = SerializationDefaults.CHECKPOINT_CONTEXT).bytes + }) } override fun removeCheckpoint(checkpoint: Checkpoint) { - checkpointStorage.remove(checkpoint.id) ?: throw IllegalArgumentException("Checkpoint not found") + val session = DatabaseTransactionManager.current().session + val criteriaBuilder = session.criteriaBuilder + val delete = criteriaBuilder.createCriteriaDelete(DBCheckpoint::class.java) + val root = delete.from(DBCheckpoint::class.java) + delete.where(criteriaBuilder.equal(root.get(DBCheckpoint::checkpointId.name), checkpoint.id.toString())) + session.createQuery(delete).executeUpdate() } override fun forEach(block: (Checkpoint) -> Boolean) { - synchronized(checkpointStorage) { - for (checkpoint in checkpointStorage.values) { - if (!block(checkpoint.deserialize())) { - break - } + val session = DatabaseTransactionManager.current().session + val criteriaQuery = session.criteriaBuilder.createQuery(DBCheckpoint::class.java) + val root = criteriaQuery.from(DBCheckpoint::class.java) + criteriaQuery.select(root) + val query = session.createQuery(criteriaQuery) + val checkpoints = query.resultList.map { e -> e.checkpoint.deserialize(context = SerializationDefaults.CHECKPOINT_CONTEXT) }.asSequence() + for (e in checkpoints) { + if (!block(e)) { + break } } } diff --git a/node/src/main/kotlin/net/corda/node/services/persistence/DBTransactionMappingStorage.kt b/node/src/main/kotlin/net/corda/node/services/persistence/DBTransactionMappingStorage.kt index b9c376b768..6c24e81ab0 100644 --- a/node/src/main/kotlin/net/corda/node/services/persistence/DBTransactionMappingStorage.kt +++ b/node/src/main/kotlin/net/corda/node/services/persistence/DBTransactionMappingStorage.kt @@ -1,66 +1,63 @@ package net.corda.node.services.persistence -import net.corda.core.ThreadBox -import net.corda.core.bufferUntilSubscribed +import net.corda.core.internal.bufferUntilSubscribed import net.corda.core.crypto.SecureHash import net.corda.core.flows.StateMachineRunId import net.corda.core.messaging.DataFeed import net.corda.core.messaging.StateMachineTransactionMapping import net.corda.node.services.api.StateMachineRecordedTransactionMappingStorage import net.corda.node.utilities.* -import org.jetbrains.exposed.sql.ResultRow -import org.jetbrains.exposed.sql.statements.InsertStatement import rx.subjects.PublishSubject +import java.util.* import javax.annotation.concurrent.ThreadSafe +import javax.persistence.* /** * Database storage of a txhash -> state machine id mapping. * * Mappings are added as transactions are persisted by [ServiceHub.recordTransaction], and never deleted. Used in the * RPC API to correlate transaction creation with flows. - * */ @ThreadSafe class DBTransactionMappingStorage : StateMachineRecordedTransactionMappingStorage { - private object Table : JDBCHashedTable("${NODE_DATABASE_PREFIX}transaction_mappings") { - val txId = secureHash("tx_id") - val stateMachineRunId = uuidString("state_machine_run_id") - } + @Entity + @javax.persistence.Table(name = "${NODE_DATABASE_PREFIX}transaction_mappings") + class DBTransactionMapping( + @Id + @Column(name = "tx_id", length = 64) + var txId: String = "", - private class TransactionMappingsMap : AbstractJDBCHashMap(Table, loadOnInit = false) { - override fun keyFromRow(row: ResultRow): SecureHash = row[table.txId] + @Column(name = "state_machine_run_id", length = 36) + var stateMachineRunId: String = "" + ) - override fun valueFromRow(row: ResultRow): StateMachineRunId = StateMachineRunId(row[table.stateMachineRunId]) - - override fun addKeyToInsert(insert: InsertStatement, entry: Map.Entry, finalizables: MutableList<() -> Unit>) { - insert[table.txId] = entry.key - } - - override fun addValueToInsert(insert: InsertStatement, entry: Map.Entry, finalizables: MutableList<() -> Unit>) { - insert[table.stateMachineRunId] = entry.value.uuid - } - } - - private class InnerState { - val stateMachineTransactionMap = TransactionMappingsMap() - val updates: PublishSubject = PublishSubject.create() - } - private val mutex = ThreadBox(InnerState()) - - override fun addMapping(stateMachineRunId: StateMachineRunId, transactionId: SecureHash) { - mutex.locked { - stateMachineTransactionMap[transactionId] = stateMachineRunId - updates.bufferUntilDatabaseCommit().onNext(StateMachineTransactionMapping(stateMachineRunId, transactionId)) - } - } - - override fun track(): DataFeed, StateMachineTransactionMapping> { - mutex.locked { - return DataFeed( - stateMachineTransactionMap.map { StateMachineTransactionMapping(it.value, it.key) }, - updates.bufferUntilSubscribed().wrapWithDatabaseTransaction() + private companion object { + fun createMap(): AppendOnlyPersistentMap { + return AppendOnlyPersistentMap( + toPersistentEntityKey = { it.toString() }, + fromPersistentEntity = { Pair(SecureHash.parse(it.txId), StateMachineRunId(UUID.fromString(it.stateMachineRunId))) }, + toPersistentEntity = { key: SecureHash, value: StateMachineRunId -> + DBTransactionMapping().apply { + txId = key.toString() + stateMachineRunId = value.uuid.toString() + } + }, + persistentEntityClass = DBTransactionMapping::class.java ) } } + + val stateMachineTransactionMap = createMap() + val updates: PublishSubject = PublishSubject.create() + + override fun addMapping(stateMachineRunId: StateMachineRunId, transactionId: SecureHash) { + stateMachineTransactionMap[transactionId] = stateMachineRunId + updates.bufferUntilDatabaseCommit().onNext(StateMachineTransactionMapping(stateMachineRunId, transactionId)) + } + + override fun track(): DataFeed, StateMachineTransactionMapping> = + DataFeed(stateMachineTransactionMap.allPersisted().map { StateMachineTransactionMapping(it.second, it.first) }.toList(), + updates.bufferUntilSubscribed().wrapWithDatabaseTransaction()) + } diff --git a/node/src/main/kotlin/net/corda/node/services/persistence/DBTransactionStorage.kt b/node/src/main/kotlin/net/corda/node/services/persistence/DBTransactionStorage.kt index 219865b00e..fb637abd4c 100644 --- a/node/src/main/kotlin/net/corda/node/services/persistence/DBTransactionStorage.kt +++ b/node/src/main/kotlin/net/corda/node/services/persistence/DBTransactionStorage.kt @@ -1,76 +1,63 @@ package net.corda.node.services.persistence import com.google.common.annotations.VisibleForTesting -import net.corda.core.bufferUntilSubscribed +import net.corda.core.internal.bufferUntilSubscribed import net.corda.core.crypto.SecureHash import net.corda.core.messaging.DataFeed -import net.corda.core.serialization.SingletonSerializeAsToken +import net.corda.core.serialization.* import net.corda.core.transactions.SignedTransaction import net.corda.node.services.api.WritableTransactionStorage import net.corda.node.utilities.* -import org.jetbrains.exposed.sql.ResultRow -import org.jetbrains.exposed.sql.exposedLogger -import org.jetbrains.exposed.sql.statements.InsertStatement import rx.Observable import rx.subjects.PublishSubject -import java.util.Collections.synchronizedMap +import javax.persistence.* class DBTransactionStorage : WritableTransactionStorage, SingletonSerializeAsToken() { - private object Table : JDBCHashedTable("${NODE_DATABASE_PREFIX}transactions") { - val txId = secureHash("tx_id") - val transaction = blob("transaction") - } - private class TransactionsMap : AbstractJDBCHashMap(Table, loadOnInit = false) { - override fun keyFromRow(row: ResultRow): SecureHash = row[table.txId] + @Entity + @Table(name = "${NODE_DATABASE_PREFIX}transactions") + class DBTransaction( + @Id + @Column(name = "tx_id", length = 64) + var txId: String = "", - override fun valueFromRow(row: ResultRow): SignedTransaction = deserializeFromBlob(row[table.transaction]) + @Lob + @Column + var transaction: ByteArray = ByteArray(0) + ) - override fun addKeyToInsert(insert: InsertStatement, entry: Map.Entry, finalizables: MutableList<() -> Unit>) { - insert[table.txId] = entry.key - } - - override fun addValueToInsert(insert: InsertStatement, entry: Map.Entry, finalizables: MutableList<() -> Unit>) { - insert[table.transaction] = serializeToBlob(entry.value, finalizables) + private companion object { + fun createTransactionsMap(): AppendOnlyPersistentMap { + return AppendOnlyPersistentMap( + toPersistentEntityKey = { it.toString() }, + fromPersistentEntity = { Pair(SecureHash.parse(it.txId), + it.transaction.deserialize( context = SerializationDefaults.STORAGE_CONTEXT)) }, + toPersistentEntity = { key: SecureHash, value: SignedTransaction -> + DBTransaction().apply { + txId = key.toString() + transaction = value.serialize(context = SerializationDefaults.STORAGE_CONTEXT).bytes + } + }, + persistentEntityClass = DBTransaction::class.java + ) } } - private val txStorage = synchronizedMap(TransactionsMap()) + private val txStorage = createTransactionsMap() - override fun addTransaction(transaction: SignedTransaction): Boolean { - val recorded = synchronized(txStorage) { - val old = txStorage[transaction.id] - if (old == null) { - txStorage.put(transaction.id, transaction) - updatesPublisher.bufferUntilDatabaseCommit().onNext(transaction) - true - } else { - false - } + override fun addTransaction(transaction: SignedTransaction): Boolean = + txStorage.addWithDuplicatesAllowed(transaction.id, transaction).apply { + updatesPublisher.bufferUntilDatabaseCommit().onNext(transaction) } - if (!recorded) { - exposedLogger.warn("Duplicate recording of transaction ${transaction.id}") - } - return recorded - } - override fun getTransaction(id: SecureHash): SignedTransaction? { - synchronized(txStorage) { - return txStorage[id] - } - } + override fun getTransaction(id: SecureHash): SignedTransaction? = txStorage[id] private val updatesPublisher = PublishSubject.create().toSerialized() override val updates: Observable = updatesPublisher.wrapWithDatabaseTransaction() - override fun track(): DataFeed, SignedTransaction> { - synchronized(txStorage) { - return DataFeed(txStorage.values.toList(), updatesPublisher.bufferUntilSubscribed().wrapWithDatabaseTransaction()) - } - } + override fun track(): DataFeed, SignedTransaction> = + DataFeed(txStorage.allPersisted().map { it.second }.toList(), updatesPublisher.bufferUntilSubscribed().wrapWithDatabaseTransaction()) @VisibleForTesting - val transactions: Iterable get() = synchronized(txStorage) { - txStorage.values.toList() - } + val transactions: Iterable get() = txStorage.allPersisted().map { it.second }.toList() } diff --git a/node/src/main/kotlin/net/corda/node/services/persistence/InMemoryStateMachineRecordedTransactionMappingStorage.kt b/node/src/main/kotlin/net/corda/node/services/persistence/InMemoryStateMachineRecordedTransactionMappingStorage.kt index 168fee3bc8..740373d3d7 100644 --- a/node/src/main/kotlin/net/corda/node/services/persistence/InMemoryStateMachineRecordedTransactionMappingStorage.kt +++ b/node/src/main/kotlin/net/corda/node/services/persistence/InMemoryStateMachineRecordedTransactionMappingStorage.kt @@ -1,7 +1,7 @@ package net.corda.node.services.persistence -import net.corda.core.ThreadBox -import net.corda.core.bufferUntilSubscribed +import net.corda.core.internal.ThreadBox +import net.corda.core.internal.bufferUntilSubscribed import net.corda.core.crypto.SecureHash import net.corda.core.flows.StateMachineRunId import net.corda.core.messaging.DataFeed diff --git a/node/src/main/kotlin/net/corda/node/services/persistence/NodeAttachmentService.kt b/node/src/main/kotlin/net/corda/node/services/persistence/NodeAttachmentService.kt index bbb4ffc340..7dac06f5a0 100644 --- a/node/src/main/kotlin/net/corda/node/services/persistence/NodeAttachmentService.kt +++ b/node/src/main/kotlin/net/corda/node/services/persistence/NodeAttachmentService.kt @@ -9,11 +9,9 @@ import com.google.common.io.CountingInputStream import net.corda.core.contracts.AbstractAttachment import net.corda.core.contracts.Attachment import net.corda.core.crypto.SecureHash -import net.corda.core.isDirectory import net.corda.core.node.services.AttachmentStorage import net.corda.core.serialization.* import net.corda.core.utilities.loggerFor -import net.corda.node.services.api.AcceptsFileUpload import net.corda.node.services.database.RequeryConfiguration import net.corda.node.services.persistence.schemas.requery.AttachmentEntity import net.corda.node.services.persistence.schemas.requery.Models @@ -22,7 +20,6 @@ import java.io.FilterInputStream import java.io.IOException import java.io.InputStream import java.nio.file.FileAlreadyExistsException -import java.nio.file.Path import java.nio.file.Paths import java.util.* import java.util.jar.JarInputStream @@ -32,13 +29,13 @@ import javax.annotation.concurrent.ThreadSafe * Stores attachments in H2 database. */ @ThreadSafe -class NodeAttachmentService(val storePath: Path, dataSourceProperties: Properties, metrics: MetricRegistry) - : AttachmentStorage, AcceptsFileUpload, SingletonSerializeAsToken() { +class NodeAttachmentService(dataSourceProperties: Properties, metrics: MetricRegistry, databaseProperties: Properties?) + : AttachmentStorage, SingletonSerializeAsToken() { companion object { private val log = loggerFor() } - val configuration = RequeryConfiguration(dataSourceProperties) + val configuration = RequeryConfiguration(dataSourceProperties, databaseProperties = databaseProperties ?: Properties()) val session = configuration.sessionForModel(Models.PERSISTENCE) @VisibleForTesting @@ -47,8 +44,6 @@ class NodeAttachmentService(val storePath: Path, dataSourceProperties: Propertie private val attachmentCount = metrics.counter("Attachments") init { - require(storePath.isDirectory()) { "$storePath must be a directory" } - session.withTransaction { attachmentCount.inc(session.count(AttachmentEntity::class).get().value().toLong()) } @@ -200,9 +195,4 @@ class NodeAttachmentService(val storePath: Path, dataSourceProperties: Propertie } require(count > 0) { "Stream is either empty or not a JAR/ZIP" } } - - // Implementations for AcceptsFileUpload - override val dataTypePrefix = "attachment" - override val acceptableFileExtensions = listOf(".jar", ".zip") - override fun upload(file: InputStream) = importAttachment(file).toString() } diff --git a/node/src/main/kotlin/net/corda/node/services/schema/HibernateObserver.kt b/node/src/main/kotlin/net/corda/node/services/schema/HibernateObserver.kt index b152f4c04b..fc168e80b2 100644 --- a/node/src/main/kotlin/net/corda/node/services/schema/HibernateObserver.kt +++ b/node/src/main/kotlin/net/corda/node/services/schema/HibernateObserver.kt @@ -17,7 +17,7 @@ import rx.Observable * A vault observer that extracts Object Relational Mappings for contract states that support it, and persists them with Hibernate. */ // TODO: Manage version evolution of the schemas via additional tooling. -class HibernateObserver(vaultUpdates: Observable, val config: HibernateConfiguration) { +class HibernateObserver(vaultUpdates: Observable>, val config: HibernateConfiguration) { companion object { val logger = loggerFor() diff --git a/node/src/main/kotlin/net/corda/node/services/schema/NodeSchemaService.kt b/node/src/main/kotlin/net/corda/node/services/schema/NodeSchemaService.kt index 1483ac4e9a..4c8e128fbb 100644 --- a/node/src/main/kotlin/net/corda/node/services/schema/NodeSchemaService.kt +++ b/node/src/main/kotlin/net/corda/node/services/schema/NodeSchemaService.kt @@ -1,15 +1,19 @@ package net.corda.node.services.schema -import net.corda.contracts.DealState import net.corda.core.contracts.ContractState import net.corda.core.contracts.FungibleAsset import net.corda.core.contracts.LinearState +import net.corda.core.schemas.CommonSchemaV1 import net.corda.core.schemas.MappedSchema import net.corda.core.schemas.PersistentState import net.corda.core.schemas.QueryableState import net.corda.core.serialization.SingletonSerializeAsToken import net.corda.node.services.api.SchemaService -import net.corda.core.schemas.CommonSchemaV1 +import net.corda.node.services.keys.PersistentKeyManagementService +import net.corda.node.services.persistence.DBCheckpointStorage +import net.corda.node.services.persistence.DBTransactionMappingStorage +import net.corda.node.services.persistence.DBTransactionStorage +import net.corda.node.services.transactions.PersistentUniquenessProvider import net.corda.node.services.vault.VaultSchemaV1 import net.corda.schemas.CashSchemaV1 @@ -23,14 +27,25 @@ import net.corda.schemas.CashSchemaV1 */ class NodeSchemaService(customSchemas: Set = emptySet()) : SchemaService, SingletonSerializeAsToken() { - // Currently does not support configuring schema options. + // Entities for compulsory services + object NodeServices + + object NodeServicesV1 : MappedSchema(schemaFamily = NodeServices.javaClass, version = 1, + mappedTypes = listOf(DBCheckpointStorage.DBCheckpoint::class.java, + DBTransactionStorage.DBTransaction::class.java, + DBTransactionMappingStorage.DBTransactionMapping::class.java, + PersistentKeyManagementService.PersistentKey::class.java, + PersistentUniquenessProvider.PersistentUniqueness::class.java + )) // Required schemas are those used by internal Corda services // For example, cash is used by the vault for coin selection (but will be extracted as a standalone CorDapp in future) val requiredSchemas: Map = mapOf(Pair(CashSchemaV1, SchemaService.SchemaOptions()), Pair(CommonSchemaV1, SchemaService.SchemaOptions()), - Pair(VaultSchemaV1, SchemaService.SchemaOptions())) + Pair(VaultSchemaV1, SchemaService.SchemaOptions()), + Pair(NodeServicesV1, SchemaService.SchemaOptions())) + override val schemaOptions: Map = requiredSchemas.plus(customSchemas.map { mappedSchema -> Pair(mappedSchema, SchemaService.SchemaOptions()) @@ -43,9 +58,6 @@ class NodeSchemaService(customSchemas: Set = emptySet()) : SchemaS schemas += state.supportedSchemas() if (state is LinearState) schemas += VaultSchemaV1 // VaultLinearStates - // TODO: DealState to be deprecated (collapsed into LinearState) - if (state is DealState) - schemas += VaultSchemaV1 // VaultLinearStates if (state is FungibleAsset<*>) schemas += VaultSchemaV1 // VaultFungibleStates @@ -54,11 +66,8 @@ class NodeSchemaService(customSchemas: Set = emptySet()) : SchemaS // Because schema is always one supported by the state, just delegate. override fun generateMappedObject(state: ContractState, schema: MappedSchema): PersistentState { - // TODO: DealState to be deprecated (collapsed into LinearState) - if ((schema is VaultSchemaV1) && (state is DealState)) - return VaultSchemaV1.VaultLinearStates(state.linearId, state.ref, state.participants) if ((schema is VaultSchemaV1) && (state is LinearState)) - return VaultSchemaV1.VaultLinearStates(state.linearId, "", state.participants) + return VaultSchemaV1.VaultLinearStates(state.linearId, state.participants) if ((schema is VaultSchemaV1) && (state is FungibleAsset<*>)) return VaultSchemaV1.VaultFungibleStates(state.owner, state.amount.quantity, state.amount.token.issuer.party, state.amount.token.issuer.reference, state.participants) return (state as QueryableState).generateMappedObject(schema) diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/FlowSession.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/FlowSession.kt index de15e9ce68..e17cf976a1 100644 --- a/node/src/main/kotlin/net/corda/node/services/statemachine/FlowSession.kt +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/FlowSession.kt @@ -1,7 +1,8 @@ package net.corda.node.services.statemachine -import net.corda.core.identity.Party +import net.corda.core.flows.FlowContext import net.corda.core.flows.FlowLogic +import net.corda.core.identity.Party import net.corda.node.services.statemachine.FlowSessionState.Initiated import net.corda.node.services.statemachine.FlowSessionState.Initiating import java.util.concurrent.ConcurrentLinkedQueue @@ -41,7 +42,7 @@ sealed class FlowSessionState { override val sendToParty: Party get() = otherParty } - data class Initiated(val peerParty: Party, val peerSessionId: Long) : FlowSessionState() { + data class Initiated(val peerParty: Party, val peerSessionId: Long, val context: FlowContext) : FlowSessionState() { override val sendToParty: Party get() = peerParty } } diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/FlowStateMachineImpl.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/FlowStateMachineImpl.kt index b8b5825744..3820c0976f 100644 --- a/node/src/main/kotlin/net/corda/node/services/statemachine/FlowStateMachineImpl.kt +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/FlowStateMachineImpl.kt @@ -4,28 +4,27 @@ import co.paralleluniverse.fibers.Fiber import co.paralleluniverse.fibers.FiberScheduler import co.paralleluniverse.fibers.Suspendable import co.paralleluniverse.strands.Strand -import com.google.common.util.concurrent.ListenableFuture -import com.google.common.util.concurrent.SettableFuture -import net.corda.core.DeclaredField.Companion.declaredField -import net.corda.core.abbreviate +import net.corda.core.concurrent.CordaFuture import net.corda.core.crypto.SecureHash import net.corda.core.crypto.random63BitValue import net.corda.core.flows.* import net.corda.core.identity.Party import net.corda.core.internal.FlowStateMachine +import net.corda.core.internal.abbreviate +import net.corda.core.internal.concurrent.OpenFuture +import net.corda.core.internal.concurrent.openFuture +import net.corda.core.internal.staticField import net.corda.core.transactions.SignedTransaction import net.corda.core.utilities.* import net.corda.node.services.api.FlowAppAuditEvent import net.corda.node.services.api.FlowPermissionAuditEvent import net.corda.node.services.api.ServiceHubInternal -import net.corda.node.utilities.StrandLocalTransactionManager -import net.corda.node.utilities.createTransaction -import org.jetbrains.exposed.sql.Database -import org.jetbrains.exposed.sql.Transaction -import org.jetbrains.exposed.sql.transactions.TransactionManager +import net.corda.node.services.statemachine.FlowSessionState.Initiating +import net.corda.node.utilities.CordaPersistence +import net.corda.node.utilities.DatabaseTransaction +import net.corda.node.utilities.DatabaseTransactionManager import org.slf4j.Logger import org.slf4j.LoggerFactory -import java.sql.Connection import java.sql.SQLException import java.util.* import java.util.concurrent.TimeUnit @@ -38,7 +37,7 @@ class FlowStateMachineImpl(override val id: StateMachineRunId, override val flowInitiator: FlowInitiator) : Fiber(id.toString(), scheduler), FlowStateMachine { companion object { // Used to work around a small limitation in Quasar. - private val QUASAR_UNBLOCKER = declaredField(Fiber::class, "SERIALIZER_BLOCKER").value + private val QUASAR_UNBLOCKER = Fiber::class.staticField("SERIALIZER_BLOCKER").value /** * Return the current [FlowStateMachineImpl] or null if executing outside of one. @@ -53,23 +52,23 @@ class FlowStateMachineImpl(override val id: StateMachineRunId, @Suspendable inline fun sleep(millis: Long) { if (currentStateMachine() != null) { - val db = StrandLocalTransactionManager.database - TransactionManager.current().commit() - TransactionManager.current().close() + val db = DatabaseTransactionManager.dataSource + DatabaseTransactionManager.current().commit() + DatabaseTransactionManager.current().close() Strand.sleep(millis) - StrandLocalTransactionManager.database = db - TransactionManager.manager.newTransaction(Connection.TRANSACTION_REPEATABLE_READ) + DatabaseTransactionManager.dataSource = db + DatabaseTransactionManager.newTransaction() } else Strand.sleep(millis) } } // These fields shouldn't be serialised, so they are marked @Transient. @Transient override lateinit var serviceHub: ServiceHubInternal - @Transient internal lateinit var database: Database + @Transient internal lateinit var database: CordaPersistence @Transient internal lateinit var actionOnSuspend: (FlowIORequest) -> Unit @Transient internal lateinit var actionOnEnd: (Try, Boolean) -> Unit @Transient internal var fromCheckpoint: Boolean = false - @Transient private var txTrampoline: Transaction? = null + @Transient private var txTrampoline: DatabaseTransaction? = null /** * Return the logger for this state machine. The logger name incorporates [id] and so including it in the log message @@ -77,10 +76,10 @@ class FlowStateMachineImpl(override val id: StateMachineRunId, */ override val logger: Logger = LoggerFactory.getLogger("net.corda.flow.$id") - @Transient private var _resultFuture: SettableFuture? = SettableFuture.create() + @Transient private var _resultFuture: OpenFuture? = openFuture() /** This future will complete when the call method returns. */ - override val resultFuture: ListenableFuture - get() = _resultFuture ?: SettableFuture.create().also { _resultFuture = it } + override val resultFuture: CordaFuture + get() = _resultFuture ?: openFuture().also { _resultFuture = it } // This state IS serialised, as we need it to know what the fiber is waiting for. internal val openSessions = HashMap, Party>, FlowSession>() @@ -100,7 +99,12 @@ class FlowStateMachineImpl(override val id: StateMachineRunId, logger.debug { "Calling flow: $logic" } val startTime = System.nanoTime() val result = try { - logic.call() + val r = logic.call() + // Only sessions which have done a single send and nothing else will block here + openSessions.values + .filter { it.state is Initiating } + .forEach { it.waitForConfirmation() } + r } catch (e: FlowException) { recordDuration(startTime, success = false) // Check if the FlowException was propagated by looking at where the stack trace originates (see suspendAndExpectReceive). @@ -116,21 +120,17 @@ class FlowStateMachineImpl(override val id: StateMachineRunId, } recordDuration(startTime) - // Only sessions which have done a single send and nothing else will block here - openSessions.values - .filter { it.state is FlowSessionState.Initiating } - .forEach { it.waitForConfirmation() } // This is to prevent actionOnEnd being called twice if it throws an exception actionOnEnd(Try.Success(result), false) _resultFuture?.set(result) logic.progressTracker?.currentStep = ProgressTracker.DONE - logger.debug { "Flow finished with result $result" } + logger.debug { "Flow finished with result ${result.toString().abbreviate(300)}" } } private fun createTransaction() { // Make sure we have a database transaction database.createTransaction() - logger.trace { "Starting database transaction ${TransactionManager.currentOrNull()} on ${Strand.currentStrand()}" } + logger.trace { "Starting database transaction ${DatabaseTransactionManager.currentOrNull()} on ${Strand.currentStrand()}" } } private fun processException(exception: Throwable, propagated: Boolean) { @@ -140,7 +140,7 @@ class FlowStateMachineImpl(override val id: StateMachineRunId, } internal fun commitTransaction() { - val transaction = TransactionManager.current() + val transaction = DatabaseTransactionManager.current() try { logger.trace { "Committing database transaction $transaction on ${Strand.currentStrand()}." } transaction.commit() @@ -153,6 +153,12 @@ class FlowStateMachineImpl(override val id: StateMachineRunId, } } + @Suspendable + override fun getFlowContext(otherParty: Party, sessionFlow: FlowLogic<*>): FlowContext { + val state = getConfirmedSession(otherParty, sessionFlow).state as FlowSessionState.Initiated + return state.context + } + @Suspendable override fun sendAndReceive(receiveType: Class, otherParty: Party, @@ -160,7 +166,7 @@ class FlowStateMachineImpl(override val id: StateMachineRunId, sessionFlow: FlowLogic<*>, retrySend: Boolean): UntrustworthyData { logger.debug { "sendAndReceive(${receiveType.name}, $otherParty, ${payload.toString().abbreviate(300)}) ..." } - val session = getConfirmedSession(otherParty, sessionFlow) + val session = getConfirmedSessionIfPresent(otherParty, sessionFlow) val sessionData = if (session == null) { val newSession = startNewSession(otherParty, sessionFlow, payload, waitForConfirmation = true, retryable = retrySend) // Only do a receive here as the session init has carried the payload @@ -178,8 +184,7 @@ class FlowStateMachineImpl(override val id: StateMachineRunId, otherParty: Party, sessionFlow: FlowLogic<*>): UntrustworthyData { logger.debug { "receive(${receiveType.name}, $otherParty) ..." } - val session = getConfirmedSession(otherParty, sessionFlow) ?: - startNewSession(otherParty, sessionFlow, null, waitForConfirmation = true) + val session = getConfirmedSession(otherParty, sessionFlow) val sessionData = receiveInternal(session, receiveType) logger.debug { "Received ${sessionData.message.payload.toString().abbreviate(300)}" } return sessionData.checkPayloadIs(receiveType) @@ -188,7 +193,7 @@ class FlowStateMachineImpl(override val id: StateMachineRunId, @Suspendable override fun send(otherParty: Party, payload: Any, sessionFlow: FlowLogic<*>) { logger.debug { "send($otherParty, ${payload.toString().abbreviate(300)})" } - val session = getConfirmedSession(otherParty, sessionFlow) + val session = getConfirmedSessionIfPresent(otherParty, sessionFlow) if (session == null) { // Don't send the payload again if it was already piggy-backed on a session init startNewSession(otherParty, sessionFlow, payload, waitForConfirmation = false) @@ -222,14 +227,14 @@ class FlowStateMachineImpl(override val id: StateMachineRunId, override fun checkFlowPermission(permissionName: String, extraAuditData: Map) { val permissionGranted = true // TODO define permission control service on ServiceHubInternal and actually check authorization. val checkPermissionEvent = FlowPermissionAuditEvent( - serviceHub.clock.instant(), - flowInitiator, - "Flow Permission Required: $permissionName", - extraAuditData, - logic.javaClass, - id, - permissionName, - permissionGranted) + serviceHub.clock.instant(), + flowInitiator, + "Flow Permission Required: $permissionName", + extraAuditData, + logic.javaClass, + id, + permissionName, + permissionGranted) serviceHub.auditService.recordAuditEvent(checkPermissionEvent) if (!permissionGranted) { throw FlowPermissionException("User $flowInitiator not permissioned for $permissionName on flow $id") @@ -237,18 +242,29 @@ class FlowStateMachineImpl(override val id: StateMachineRunId, } // TODO Dummy implementation of access to application specific audit logging - override fun recordAuditEvent(eventType: String, comment: String, extraAuditData: Map) { + override fun recordAuditEvent(eventType: String, comment: String, extraAuditData: Map): Unit { val flowAuditEvent = FlowAppAuditEvent( - serviceHub.clock.instant(), - flowInitiator, - comment, - extraAuditData, - logic.javaClass, + serviceHub.clock.instant(), + flowInitiator, + comment, + extraAuditData, + logic.javaClass, id, eventType) serviceHub.auditService.recordAuditEvent(flowAuditEvent) } + @Suspendable + override fun flowStackSnapshot(flowClass: Class<*>): FlowStackSnapshot? { + val factory = FlowStackSnapshotFactory.instance + return factory.getFlowStackSnapshot(flowClass) + } + + override fun persistFlowStackSnapshot(flowClass: Class<*>): Unit { + val factory = FlowStackSnapshotFactory.instance + factory.persistAsJsonFile(flowClass, serviceHub.configuration.baseDirectory, id.toString()) + } + /** * This method will suspend the state machine and wait for incoming session init response from other party. */ @@ -256,10 +272,13 @@ class FlowStateMachineImpl(override val id: StateMachineRunId, private fun FlowSession.waitForConfirmation() { val (peerParty, sessionInitResponse) = receiveInternal(this, null) if (sessionInitResponse is SessionConfirm) { - state = FlowSessionState.Initiated(peerParty, sessionInitResponse.initiatedSessionId) + state = FlowSessionState.Initiated( + peerParty, + sessionInitResponse.initiatedSessionId, + FlowContext(sessionInitResponse.flowVersion, sessionInitResponse.appName)) } else { sessionInitResponse as SessionReject - throw FlowSessionException("Party ${state.sendToParty} rejected session request: ${sessionInitResponse.errorMessage}") + throw UnexpectedFlowEndException("Party ${state.sendToParty} rejected session request: ${sessionInitResponse.errorMessage}") } } @@ -273,9 +292,7 @@ class FlowStateMachineImpl(override val id: StateMachineRunId, } @Suspendable - private fun sendInternal(session: FlowSession, message: SessionMessage) { - suspend(SendOnly(session, message)) - } + private fun sendInternal(session: FlowSession, message: SessionMessage) = suspend(SendOnly(session, message)) private inline fun receiveInternal( session: FlowSession, @@ -291,15 +308,21 @@ class FlowStateMachineImpl(override val id: StateMachineRunId, } @Suspendable - private fun getConfirmedSession(otherParty: Party, sessionFlow: FlowLogic<*>): FlowSession? { + private fun getConfirmedSessionIfPresent(otherParty: Party, sessionFlow: FlowLogic<*>): FlowSession? { return openSessions[Pair(sessionFlow, otherParty)]?.apply { if (state is FlowSessionState.Initiating) { - // Session still initiating, try to retrieve the init response. + // Session still initiating, wait for the confirmation waitForConfirmation() } } } + @Suspendable + private fun getConfirmedSession(otherParty: Party, sessionFlow: FlowLogic<*>): FlowSession { + return getConfirmedSessionIfPresent(otherParty, sessionFlow) ?: + startNewSession(otherParty, sessionFlow, null, waitForConfirmation = true) + } + /** * Creates a new session. The provided [otherParty] can be an identity of any advertised service on the network, * and might be advertised by more than one node. Therefore we first choose a single node that advertises it @@ -316,7 +339,7 @@ class FlowStateMachineImpl(override val id: StateMachineRunId, val session = FlowSession(sessionFlow, random63BitValue(), null, FlowSessionState.Initiating(otherParty), retryable) openSessions[Pair(sessionFlow, otherParty)] = session val (version, initiatingFlowClass) = sessionFlow.javaClass.flowVersionAndInitiatingClass - val sessionInit = SessionInit(session.ourSessionId, initiatingFlowClass, version, firstPayload) + val sessionInit = SessionInit(session.ourSessionId, initiatingFlowClass.name, version, "not defined", firstPayload) sendInternal(session, sessionInit) if (waitForConfirmation) { session.waitForConfirmation() @@ -336,8 +359,9 @@ class FlowStateMachineImpl(override val id: StateMachineRunId, val polledMessage = pollForMessage() return if (polledMessage != null) { if (this is SendAndReceive) { - // We've already received a message but we suspend so that the send can be performed - suspend(this) + // Since we've already received the message, we downgrade to a send only to get the payload out and not + // inadvertently block + suspend(SendOnly(session, message)) } polledMessage } else { @@ -361,7 +385,7 @@ class FlowStateMachineImpl(override val id: StateMachineRunId, session.erroredEnd(message) } else { val expectedType = receiveRequest.userReceiveType?.name ?: receiveType.simpleName - throw FlowSessionException("Counterparty flow on ${session.state.sendToParty} has completed without " + + throw UnexpectedFlowEndException("Counterparty flow on ${session.state.sendToParty} has completed without " + "sending a $expectedType") } } else { @@ -375,7 +399,7 @@ class FlowStateMachineImpl(override val id: StateMachineRunId, (end.errorResponse as java.lang.Throwable).fillInStackTrace() throw end.errorResponse } else { - throw FlowSessionException("Counterparty flow on ${state.sendToParty} had an internal error and has terminated") + throw UnexpectedFlowEndException("Counterparty flow on ${state.sendToParty} had an internal error and has terminated") } } @@ -383,8 +407,7 @@ class FlowStateMachineImpl(override val id: StateMachineRunId, private fun suspend(ioRequest: FlowIORequest) { // We have to pass the thread local database transaction across via a transient field as the fiber park // swaps them out. - txTrampoline = TransactionManager.currentOrNull() - StrandLocalTransactionManager.setThreadLocalTx(null) + txTrampoline = DatabaseTransactionManager.setThreadLocalTx(null) if (ioRequest is WaitingRequest) waitingForResponse = ioRequest @@ -393,7 +416,7 @@ class FlowStateMachineImpl(override val id: StateMachineRunId, logger.trace { "Suspended on $ioRequest" } // restore the Tx onto the ThreadLocal so that we can commit the ensuing checkpoint to the DB try { - StrandLocalTransactionManager.setThreadLocalTx(txTrampoline) + DatabaseTransactionManager.setThreadLocalTx(txTrampoline) txTrampoline = null actionOnSuspend(ioRequest) } catch (t: Throwable) { @@ -456,6 +479,7 @@ val Class>.flowVersionAndInitiatingClass: Pair>, - val flowVerison: Int, - val firstPayload: Any?) : SessionMessage - interface ExistingSessionMessage : SessionMessage { val recipientSessionId: Long } -data class SessionData(override val recipientSessionId: Long, val payload: Any) : ExistingSessionMessage { - override fun toString(): String = "${javaClass.simpleName}(recipientSessionId=$recipientSessionId, payload=$payload)" -} - interface SessionInitResponse : ExistingSessionMessage { val initiatorSessionId: Long override val recipientSessionId: Long get() = initiatorSessionId } -data class SessionConfirm(override val initiatorSessionId: Long, val initiatedSessionId: Long) : SessionInitResponse +interface SessionEnd : ExistingSessionMessage + +data class SessionInit(val initiatorSessionId: Long, + val initiatingFlowClass: String, + val flowVersion: Int, + val appName: String, + val firstPayload: Any?) : SessionMessage + +data class SessionConfirm(override val initiatorSessionId: Long, + val initiatedSessionId: Long, + val flowVersion: Int, + val appName: String) : SessionInitResponse + data class SessionReject(override val initiatorSessionId: Long, val errorMessage: String) : SessionInitResponse -interface SessionEnd : ExistingSessionMessage +data class SessionData(override val recipientSessionId: Long, val payload: Any) : ExistingSessionMessage + data class NormalSessionEnd(override val recipientSessionId: Long) : SessionEnd + data class ErrorSessionEnd(override val recipientSessionId: Long, val errorResponse: FlowException?) : SessionEnd data class ReceivedSessionMessage(val sender: Party, val message: M) fun ReceivedSessionMessage.checkPayloadIs(type: Class): UntrustworthyData { - if (type.isInstance(message.payload)) { - return UntrustworthyData(type.cast(message.payload)) - } else { - throw FlowSessionException("We were expecting a ${type.name} from $sender but we instead got a " + - "${message.payload.javaClass.name} (${message.payload})") - } + return type.castIfPossible(message.payload)?.let { UntrustworthyData(it) } ?: + throw UnexpectedFlowEndException("We were expecting a ${type.name} from $sender but we instead got a " + + "${message.payload.javaClass.name} (${message.payload})") } diff --git a/node/src/main/kotlin/net/corda/node/services/statemachine/StateMachineManager.kt b/node/src/main/kotlin/net/corda/node/services/statemachine/StateMachineManager.kt index 27ce753cd3..fbbbbaf4fd 100644 --- a/node/src/main/kotlin/net/corda/node/services/statemachine/StateMachineManager.kt +++ b/node/src/main/kotlin/net/corda/node/services/statemachine/StateMachineManager.kt @@ -2,44 +2,42 @@ package net.corda.node.services.statemachine import co.paralleluniverse.fibers.Fiber import co.paralleluniverse.fibers.FiberExecutorScheduler -import co.paralleluniverse.io.serialization.kryo.KryoSerializer import co.paralleluniverse.strands.Strand import com.codahale.metrics.Gauge -import com.esotericsoftware.kryo.Kryo import com.esotericsoftware.kryo.KryoException -import com.esotericsoftware.kryo.Serializer -import com.esotericsoftware.kryo.io.Input -import com.esotericsoftware.kryo.io.Output -import com.esotericsoftware.kryo.pool.KryoPool import com.google.common.collect.HashMultimap -import com.google.common.util.concurrent.ListenableFuture import com.google.common.util.concurrent.MoreExecutors -import io.requery.util.CloseableIterator -import net.corda.core.ThreadBox -import net.corda.core.bufferUntilSubscribed +import net.corda.core.concurrent.CordaFuture import net.corda.core.crypto.SecureHash import net.corda.core.crypto.random63BitValue -import net.corda.core.flows.FlowException -import net.corda.core.flows.FlowInitiator -import net.corda.core.flows.FlowLogic -import net.corda.core.flows.StateMachineRunId +import net.corda.core.flows.* import net.corda.core.identity.Party +import net.corda.core.internal.ThreadBox +import net.corda.core.internal.bufferUntilSubscribed +import net.corda.core.internal.castIfPossible import net.corda.core.messaging.DataFeed -import net.corda.core.serialization.* -import net.corda.core.then +import net.corda.core.serialization.SerializationDefaults.CHECKPOINT_CONTEXT +import net.corda.core.serialization.SerializationDefaults.SERIALIZATION_FACTORY +import net.corda.core.serialization.SerializedBytes +import net.corda.core.serialization.deserialize +import net.corda.core.serialization.serialize import net.corda.core.utilities.Try import net.corda.core.utilities.debug import net.corda.core.utilities.loggerFor import net.corda.core.utilities.trace -import net.corda.node.internal.SessionRejectException +import net.corda.node.internal.InitiatedFlowFactory import net.corda.node.services.api.Checkpoint import net.corda.node.services.api.CheckpointStorage import net.corda.node.services.api.ServiceHubInternal import net.corda.node.services.messaging.ReceivedMessage import net.corda.node.services.messaging.TopicSession -import net.corda.node.utilities.* +import net.corda.node.utilities.AffinityExecutor +import net.corda.node.utilities.CordaPersistence +import net.corda.node.utilities.bufferUntilDatabaseCommit +import net.corda.node.utilities.wrapWithDatabaseTransaction +import net.corda.nodeapi.internal.serialization.SerializeAsTokenContextImpl +import net.corda.nodeapi.internal.serialization.withTokenContext import org.apache.activemq.artemis.utils.ReusableLatch -import org.jetbrains.exposed.sql.Database import org.slf4j.Logger import rx.Observable import rx.subjects.PublishSubject @@ -76,40 +74,11 @@ import kotlin.collections.ArrayList class StateMachineManager(val serviceHub: ServiceHubInternal, val checkpointStorage: CheckpointStorage, val executor: AffinityExecutor, - val database: Database, + val database: CordaPersistence, private val unfinishedFibers: ReusableLatch = ReusableLatch()) { inner class FiberScheduler : FiberExecutorScheduler("Same thread scheduler", executor) - private val quasarKryoPool = KryoPool.Builder { - val serializer = Fiber.getFiberSerializer(false) as KryoSerializer - val classResolver = makeNoWhitelistClassResolver().apply { setKryo(serializer.kryo) } - // TODO The ClassResolver can only be set in the Kryo constructor and Quasar doesn't provide us with a way of doing that - val field = Kryo::class.java.getDeclaredField("classResolver").apply { isAccessible = true } - serializer.kryo.apply { - field.set(this, classResolver) - DefaultKryoCustomizer.customize(this) - addDefaultSerializer(AutoCloseable::class.java, AutoCloseableSerialisationDetector) - } - }.build() - - // TODO Move this into the blacklist and upgrade the blacklist to allow custom messages - private object AutoCloseableSerialisationDetector : Serializer() { - override fun write(kryo: Kryo, output: Output, closeable: AutoCloseable) { - val message = if (closeable is CloseableIterator<*>) { - "A live Iterator pointing to the database has been detected during flow checkpointing. This may be due " + - "to a Vault query - move it into a private method." - } else { - "${closeable.javaClass.name}, which is a closeable resource, has been detected during flow checkpointing. " + - "Restoring such resources across node restarts is not supported. Make sure code accessing it is " + - "confined to a private method or the reference is nulled out." - } - throw UnsupportedOperationException(message) - } - - override fun read(kryo: Kryo, input: Input, type: Class) = throw IllegalStateException("Should not reach here!") - } - companion object { private val logger = loggerFor() internal val sessionTopic = TopicSession("platform.session") @@ -170,17 +139,18 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, internal val tokenizableServices = ArrayList() // Context for tokenized services in checkpoints private val serializationContext by lazy { - SerializeAsTokenContext(tokenizableServices, quasarKryoPool, serviceHub) + SerializeAsTokenContextImpl(tokenizableServices, SERIALIZATION_FACTORY, CHECKPOINT_CONTEXT, serviceHub) } + fun findServices(predicate: (Any) -> Boolean) = tokenizableServices.filter(predicate) + /** Returns a list of all state machines executing the given flow logic at the top level (subflows do not count) */ - fun

, T> findStateMachines(flowClass: Class

): List>> { + fun

, T> findStateMachines(flowClass: Class

): List>> { @Suppress("UNCHECKED_CAST") return mutex.locked { - stateMachines.keys - .map { it.logic } - .filterIsInstance(flowClass) - .map { it to (it.stateMachine as FlowStateMachineImpl).resultFuture } + stateMachines.keys.mapNotNull { + flowClass.castIfPossible(it.logic)?.let { it to (it.stateMachine as FlowStateMachineImpl).resultFuture } + } } } @@ -233,7 +203,7 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, * @param allowedUnsuspendedFiberCount Optional parameter is used in some tests. */ fun stop(allowedUnsuspendedFiberCount: Int = 0) { - check(allowedUnsuspendedFiberCount >= 0) + require(allowedUnsuspendedFiberCount >= 0) mutex.locked { if (stopping) throw IllegalStateException("Already stopping!") stopping = true @@ -368,28 +338,30 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, private fun onSessionInit(sessionInit: SessionInit, receivedMessage: ReceivedMessage, sender: Party) { logger.trace { "Received $sessionInit from $sender" } - val otherPartySessionId = sessionInit.initiatorSessionId + val senderSessionId = sessionInit.initiatorSessionId - fun sendSessionReject(message: String) = sendSessionMessage(sender, SessionReject(otherPartySessionId, message)) + fun sendSessionReject(message: String) = sendSessionMessage(sender, SessionReject(senderSessionId, message)) - val initiatedFlowFactory = serviceHub.getFlowFactory(sessionInit.initiatingFlowClass) - if (initiatedFlowFactory == null) { - logger.warn("${sessionInit.initiatingFlowClass} has not been registered: $sessionInit") - sendSessionReject("${sessionInit.initiatingFlowClass.name} has not been registered") - return - } - - val session = try { - val flow = initiatedFlowFactory.createFlow(receivedMessage.platformVersion, sender, sessionInit) - val fiber = createFiber(flow, FlowInitiator.Peer(sender)) - val session = FlowSession(flow, random63BitValue(), sender, FlowSessionState.Initiated(sender, otherPartySessionId)) + val (session, initiatedFlowFactory) = try { + val initiatedFlowFactory = getInitiatedFlowFactory(sessionInit) + val flow = initiatedFlowFactory.createFlow(sender) + val senderFlowVersion = when (initiatedFlowFactory) { + is InitiatedFlowFactory.Core -> receivedMessage.platformVersion // The flow version for the core flows is the platform version + is InitiatedFlowFactory.CorDapp -> sessionInit.flowVersion + } + val session = FlowSession( + flow, + random63BitValue(), + sender, + FlowSessionState.Initiated(sender, senderSessionId, FlowContext(senderFlowVersion, sessionInit.appName))) if (sessionInit.firstPayload != null) { session.receivedMessages += ReceivedSessionMessage(sender, SessionData(session.ourSessionId, sessionInit.firstPayload)) } openSessions[session.ourSessionId] = session + val fiber = createFiber(flow, FlowInitiator.Peer(sender)) fiber.openSessions[Pair(flow, sender)] = session updateCheckpoint(fiber) - session + session to initiatedFlowFactory } catch (e: SessionRejectException) { logger.warn("${e.logMessage}: $sessionInit") sendSessionReject(e.rejectMessage) @@ -400,28 +372,38 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, return } - sendSessionMessage(sender, SessionConfirm(otherPartySessionId, session.ourSessionId), session.fiber) - session.fiber.logger.debug { "Initiated by $sender using ${sessionInit.initiatingFlowClass.name}" } + val (ourFlowVersion, appName) = when (initiatedFlowFactory) { + // The flow version for the core flows is the platform version + is InitiatedFlowFactory.Core -> serviceHub.myInfo.platformVersion to "corda" + is InitiatedFlowFactory.CorDapp -> initiatedFlowFactory.flowVersion to initiatedFlowFactory.appName + } + + sendSessionMessage(sender, SessionConfirm(senderSessionId, session.ourSessionId, ourFlowVersion, appName), session.fiber) + session.fiber.logger.debug { "Initiated by $sender using ${sessionInit.initiatingFlowClass}" } session.fiber.logger.trace { "Initiated from $sessionInit on $session" } resumeFiber(session.fiber) } - private fun serializeFiber(fiber: FlowStateMachineImpl<*>): SerializedBytes> { - return quasarKryoPool.run { kryo -> - // add the map of tokens -> tokenizedServices to the kyro context - kryo.withSerializationContext(serializationContext) { - fiber.serialize(kryo) - } + private fun getInitiatedFlowFactory(sessionInit: SessionInit): InitiatedFlowFactory<*> { + val initiatingFlowClass = try { + Class.forName(sessionInit.initiatingFlowClass).asSubclass(FlowLogic::class.java) + } catch (e: ClassNotFoundException) { + throw SessionRejectException("Don't know ${sessionInit.initiatingFlowClass}") + } catch (e: ClassCastException) { + throw SessionRejectException("${sessionInit.initiatingFlowClass} is not a flow") } + return serviceHub.getFlowFactory(initiatingFlowClass) ?: + throw SessionRejectException("$initiatingFlowClass is not registered") + } + + private fun serializeFiber(fiber: FlowStateMachineImpl<*>): SerializedBytes> { + return fiber.serialize(context = CHECKPOINT_CONTEXT.withTokenContext(serializationContext)) } private fun deserializeFiber(checkpoint: Checkpoint, logger: Logger): FlowStateMachineImpl<*>? { return try { - quasarKryoPool.run { kryo -> - // put the map of token -> tokenized into the kryo context - kryo.withSerializationContext(serializationContext) { - checkpoint.serializedFiber.deserialize(kryo) - }.apply { fromCheckpoint = true } + checkpoint.serializedFiber.deserialize(context = CHECKPOINT_CONTEXT.withTokenContext(serializationContext)).apply { + fromCheckpoint = true } } catch (t: Throwable) { logger.error("Encountered unrestorable checkpoint!", t) @@ -505,11 +487,7 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, fun add(logic: FlowLogic, flowInitiator: FlowInitiator): FlowStateMachineImpl { // TODO: Check that logic has @Suspendable on its call method. executor.checkOnThread() - // We swap out the parent transaction context as using this frequently leads to a deadlock as we wait - // on the flow completion future inside that context. The problem is that any progress checkpoints are - // unable to acquire the table lock and move forward till the calling transaction finishes. - // Committing in line here on a fresh context ensure we can progress. - val fiber = database.isolatedTransaction { + val fiber = database.transaction { val fiber = createFiber(logic, flowInitiator) updateCheckpoint(fiber) fiber @@ -620,3 +598,7 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, } } } + +class SessionRejectException(val rejectMessage: String, val logMessage: String) : Exception() { + constructor(message: String) : this(message, message) +} diff --git a/node/src/main/kotlin/net/corda/node/services/transactions/BFTNonValidatingNotaryService.kt b/node/src/main/kotlin/net/corda/node/services/transactions/BFTNonValidatingNotaryService.kt index a64cfaaa9d..b73868686d 100644 --- a/node/src/main/kotlin/net/corda/node/services/transactions/BFTNonValidatingNotaryService.kt +++ b/node/src/main/kotlin/net/corda/node/services/transactions/BFTNonValidatingNotaryService.kt @@ -2,9 +2,12 @@ package net.corda.node.services.transactions import co.paralleluniverse.fibers.Suspendable import com.google.common.util.concurrent.SettableFuture +import net.corda.core.crypto.Crypto import net.corda.core.crypto.DigitalSignature +import net.corda.core.crypto.SignableData +import net.corda.core.crypto.SignatureMetadata import net.corda.core.flows.FlowLogic -import net.corda.core.getOrThrow +import net.corda.core.flows.NotaryException import net.corda.core.identity.Party import net.corda.core.node.services.NotaryService import net.corda.core.node.services.TimeWindowChecker @@ -12,9 +15,9 @@ import net.corda.core.serialization.deserialize import net.corda.core.serialization.serialize import net.corda.core.transactions.FilteredTransaction import net.corda.core.utilities.debug +import net.corda.core.utilities.getOrThrow import net.corda.core.utilities.loggerFor import net.corda.core.utilities.unwrap -import net.corda.flows.NotaryException import net.corda.node.services.api.ServiceHubInternal import kotlin.concurrent.thread @@ -23,21 +26,25 @@ import kotlin.concurrent.thread * * A transaction is notarised when the consensus is reached by the cluster on its uniqueness, and time-window validity. */ -class BFTNonValidatingNotaryService(override val services: ServiceHubInternal) : NotaryService() { +class BFTNonValidatingNotaryService(override val services: ServiceHubInternal, cluster: BFTSMaRt.Cluster = distributedCluster) : NotaryService() { companion object { val type = SimpleNotaryService.type.getSubType("bft") private val log = loggerFor() + private val distributedCluster = object : BFTSMaRt.Cluster { + override fun waitUntilAllReplicasHaveInitialized() { + log.warn("A replica may still be initializing, in which case the upcoming consensus change may cause it to spin.") + } + } } private val client: BFTSMaRt.Client private val replicaHolder = SettableFuture.create() init { - val replicaId = services.configuration.bftReplicaId ?: throw IllegalArgumentException("bftReplicaId value must be specified in the configuration") - val config = BFTSMaRtConfig(services.configuration.notaryClusterAddresses) - - client = config.use { - val configHandle = config.handle() + require(services.configuration.bftSMaRt.isValid()) { "bftSMaRt replicaId must be specified in the configuration" } + client = BFTSMaRtConfig(services.configuration.notaryClusterAddresses, services.configuration.bftSMaRt.debug, services.configuration.bftSMaRt.exposeRaces).use { + val replicaId = services.configuration.bftSMaRt.replicaId + val configHandle = it.handle() // Replica startup must be in parallel with other replicas, otherwise the constructor may not return: thread(name = "BFT SMaRt replica $replicaId init", isDaemon = true) { configHandle.use { @@ -47,16 +54,18 @@ class BFTNonValidatingNotaryService(override val services: ServiceHubInternal) : log.info("BFT SMaRt replica $replicaId is running.") } } - - BFTSMaRt.Client(it, replicaId) + BFTSMaRt.Client(it, replicaId, cluster) } } + fun waitUntilReplicaHasInitialized() { + log.debug { "Waiting for replica ${services.configuration.bftSMaRt.replicaId} to initialize." } + replicaHolder.getOrThrow() // It's enough to wait for the ServiceReplica constructor to return. + } + fun commitTransaction(tx: Any, otherSide: Party) = client.commitTransaction(tx, otherSide) - override fun createServiceFlow(otherParty: Party, platformVersion: Int): FlowLogic { - return ServiceFlow(otherParty, this) - } + override fun createServiceFlow(otherParty: Party): FlowLogic = ServiceFlow(otherParty, this) private class ServiceFlow(val otherSide: Party, val service: BFTNonValidatingNotaryService) : FlowLogic() { @Suspendable @@ -101,7 +110,8 @@ class BFTNonValidatingNotaryService(override val services: ServiceHubInternal) : commitInputStates(inputs, id, callerIdentity) log.debug { "Inputs committed successfully, signing $id" } - val sig = sign(id.bytes) + val signableData = SignableData(id, SignatureMetadata(services.myInfo.platformVersion, Crypto.findSignatureScheme(services.notaryIdentityKey).schemeNumberID)) + val sig = sign(signableData) BFTSMaRt.ReplicaResponse.Signature(sig) } catch (e: NotaryException) { log.debug { "Error processing transaction: ${e.error}" } diff --git a/node/src/main/kotlin/net/corda/node/services/transactions/BFTSMaRt.kt b/node/src/main/kotlin/net/corda/node/services/transactions/BFTSMaRt.kt index f910c0ef1e..8eafe93ea7 100644 --- a/node/src/main/kotlin/net/corda/node/services/transactions/BFTSMaRt.kt +++ b/node/src/main/kotlin/net/corda/node/services/transactions/BFTSMaRt.kt @@ -3,6 +3,7 @@ package net.corda.node.services.transactions import bftsmart.communication.ServerCommunicationSystem import bftsmart.communication.client.netty.NettyClientServerCommunicationSystemClientSide import bftsmart.communication.client.netty.NettyClientServerSession +import bftsmart.statemanagement.strategy.StandardStateManager import bftsmart.tom.MessageContext import bftsmart.tom.ServiceProxy import bftsmart.tom.ServiceReplica @@ -11,32 +12,28 @@ import bftsmart.tom.core.messages.TOMMessage import bftsmart.tom.server.defaultservices.DefaultRecoverable import bftsmart.tom.server.defaultservices.DefaultReplier import bftsmart.tom.util.Extractor -import net.corda.core.DeclaredField.Companion.declaredField import net.corda.core.contracts.StateRef import net.corda.core.contracts.TimeWindow -import net.corda.core.crypto.DigitalSignature -import net.corda.core.crypto.SecureHash -import net.corda.core.crypto.SignedData -import net.corda.core.crypto.sign +import net.corda.core.crypto.* +import net.corda.core.flows.NotaryError +import net.corda.core.flows.NotaryException import net.corda.core.identity.Party +import net.corda.core.internal.declaredField +import net.corda.core.internal.toTypedArray import net.corda.core.node.services.TimeWindowChecker import net.corda.core.node.services.UniquenessProvider import net.corda.core.serialization.CordaSerializable import net.corda.core.serialization.SingletonSerializeAsToken import net.corda.core.serialization.deserialize import net.corda.core.serialization.serialize -import net.corda.core.toTypedArray import net.corda.core.transactions.FilteredTransaction import net.corda.core.transactions.SignedTransaction import net.corda.core.utilities.debug import net.corda.core.utilities.loggerFor -import net.corda.flows.NotaryError -import net.corda.flows.NotaryException import net.corda.node.services.api.ServiceHubInternal import net.corda.node.services.transactions.BFTSMaRt.Client import net.corda.node.services.transactions.BFTSMaRt.Replica import net.corda.node.utilities.JDBCHashMap -import net.corda.node.utilities.transaction import java.nio.file.Path import java.util.* @@ -71,7 +68,12 @@ object BFTSMaRt { data class Signatures(val txSignatures: List) : ClusterResponse() } - class Client(config: BFTSMaRtConfig, private val clientId: Int) : SingletonSerializeAsToken() { + interface Cluster { + /** Avoid bug where a replica fails to start due to a consensus change during the BFT startup sequence. */ + fun waitUntilAllReplicasHaveInitialized() + } + + class Client(config: BFTSMaRtConfig, private val clientId: Int, private val cluster: Cluster) : SingletonSerializeAsToken() { companion object { private val log = loggerFor() } @@ -101,6 +103,7 @@ object BFTSMaRt { fun commitTransaction(transaction: Any, otherSide: Party): ClusterResponse { require(transaction is FilteredTransaction || transaction is SignedTransaction) { "Unsupported transaction type: ${transaction.javaClass.name}" } awaitClientConnectionToCluster() + cluster.waitUntilAllReplicasHaveInitialized() val requestBytes = CommitRequest(transaction, otherSide).serialize().bytes val responseBytes = proxy.invokeOrdered(requestBytes) return responseBytes.deserialize() @@ -170,12 +173,24 @@ object BFTSMaRt { abstract class Replica(config: BFTSMaRtConfig, replicaId: Int, tableName: String, - private val services: ServiceHubInternal, + protected val services: ServiceHubInternal, private val timeWindowChecker: TimeWindowChecker) : DefaultRecoverable() { companion object { private val log = loggerFor() } + private val stateManagerOverride = run { + // Mock framework shutdown is not in reverse order, and we need to stop the faulty replicas first: + val exposeStartupRace = config.exposeRaces && replicaId < maxFaultyReplicas(config.clusterSize) + object : StandardStateManager() { + override fun askCurrentConsensusId() { + if (exposeStartupRace) Thread.sleep(20000) // Must be long enough for the non-redundant replicas to reach a non-initial consensus. + super.askCurrentConsensusId() + } + } + } + + override fun getStateManager() = stateManagerOverride // TODO: Use Requery with proper DB schema instead of JDBCHashMap. // Must be initialised before ServiceReplica is started private val commitLog = services.database.transaction { JDBCHashMap(tableName) } @@ -235,6 +250,10 @@ object BFTSMaRt { return services.database.transaction { services.keyManagementService.sign(bytes, services.notaryIdentityKey) } } + protected fun sign(signableData: SignableData): TransactionSignature { + return services.database.transaction { services.keyManagementService.sign(signableData, services.notaryIdentityKey) } + } + // TODO: // - Test snapshot functionality with different bft-smart cluster configurations. // - Add streaming to support large data sets. diff --git a/node/src/main/kotlin/net/corda/node/services/transactions/BFTSMaRtConfig.kt b/node/src/main/kotlin/net/corda/node/services/transactions/BFTSMaRtConfig.kt index 72c0212eb2..827c3356ae 100644 --- a/node/src/main/kotlin/net/corda/node/services/transactions/BFTSMaRtConfig.kt +++ b/node/src/main/kotlin/net/corda/node/services/transactions/BFTSMaRtConfig.kt @@ -1,6 +1,6 @@ package net.corda.node.services.transactions -import net.corda.core.div +import net.corda.core.internal.div import net.corda.core.utilities.NetworkHostAndPort import net.corda.core.utilities.debug import net.corda.core.utilities.loggerFor @@ -17,15 +17,17 @@ import java.util.concurrent.TimeUnit.MILLISECONDS * Each instance of this class creates such a configHome, accessible via [path]. * The files are deleted on [close] typically via [use], see [PathManager] for details. */ -class BFTSMaRtConfig(private val replicaAddresses: List, debug: Boolean = false) : PathManager(Files.createTempDirectory("bft-smart-config")) { +class BFTSMaRtConfig(private val replicaAddresses: List, debug: Boolean, val exposeRaces: Boolean) : PathManager(Files.createTempDirectory("bft-smart-config")) { companion object { private val log = loggerFor() internal val portIsClaimedFormat = "Port %s is claimed by another replica: %s" } + val clusterSize get() = replicaAddresses.size + init { val claimedPorts = mutableSetOf() - val n = replicaAddresses.size + val n = clusterSize (0 until n).forEach { replicaId -> // Each replica claims the configured port and the next one: replicaPorts(replicaId).forEach { port -> diff --git a/node/src/main/kotlin/net/corda/node/services/transactions/DistributedImmutableMap.kt b/node/src/main/kotlin/net/corda/node/services/transactions/DistributedImmutableMap.kt index 1b0b528e21..6367346bbc 100644 --- a/node/src/main/kotlin/net/corda/node/services/transactions/DistributedImmutableMap.kt +++ b/node/src/main/kotlin/net/corda/node/services/transactions/DistributedImmutableMap.kt @@ -8,9 +8,8 @@ import io.atomix.copycat.server.StateMachine import io.atomix.copycat.server.storage.snapshot.SnapshotReader import io.atomix.copycat.server.storage.snapshot.SnapshotWriter import net.corda.core.utilities.loggerFor +import net.corda.node.utilities.CordaPersistence import net.corda.node.utilities.JDBCHashMap -import net.corda.node.utilities.transaction -import org.jetbrains.exposed.sql.Database import java.util.* /** @@ -21,7 +20,7 @@ import java.util.* * to disk, and sharing them across the cluster. A new node joining the cluster will have to obtain and install a snapshot * containing the entire JDBC table contents. */ -class DistributedImmutableMap(val db: Database, tableName: String) : StateMachine(), Snapshottable { +class DistributedImmutableMap(val db: CordaPersistence, tableName: String) : StateMachine(), Snapshottable { companion object { private val log = loggerFor>() } @@ -55,7 +54,7 @@ class DistributedImmutableMap(val db: Database, tableName: Str * @return map containing conflicting entries */ fun put(commit: Commit>): Map { - commit.use { commit -> + commit.use { val conflicts = LinkedHashMap() db.transaction { val entries = commit.operation().entries diff --git a/node/src/main/kotlin/net/corda/node/services/transactions/InMemoryTransactionVerifierService.kt b/node/src/main/kotlin/net/corda/node/services/transactions/InMemoryTransactionVerifierService.kt index 57619e66e1..5cc77a3bec 100644 --- a/node/src/main/kotlin/net/corda/node/services/transactions/InMemoryTransactionVerifierService.kt +++ b/node/src/main/kotlin/net/corda/node/services/transactions/InMemoryTransactionVerifierService.kt @@ -1,7 +1,7 @@ package net.corda.node.services.transactions -import com.google.common.util.concurrent.ListenableFuture import com.google.common.util.concurrent.MoreExecutors +import net.corda.core.internal.concurrent.fork import net.corda.core.node.services.TransactionVerifierService import net.corda.core.serialization.SingletonSerializeAsToken import net.corda.core.transactions.LedgerTransaction @@ -10,9 +10,5 @@ import java.util.concurrent.Executors class InMemoryTransactionVerifierService(numberOfWorkers: Int) : SingletonSerializeAsToken(), TransactionVerifierService { private val workerPool = MoreExecutors.listeningDecorator(Executors.newFixedThreadPool(numberOfWorkers)) - override fun verify(transaction: LedgerTransaction): ListenableFuture<*> { - return workerPool.submit { - transaction.verify() - } - } + override fun verify(transaction: LedgerTransaction) = workerPool.fork(transaction::verify) } diff --git a/node/src/main/kotlin/net/corda/node/services/transactions/NonValidatingNotaryFlow.kt b/node/src/main/kotlin/net/corda/node/services/transactions/NonValidatingNotaryFlow.kt index 354ef7799d..97bedc4c46 100644 --- a/node/src/main/kotlin/net/corda/node/services/transactions/NonValidatingNotaryFlow.kt +++ b/node/src/main/kotlin/net/corda/node/services/transactions/NonValidatingNotaryFlow.kt @@ -1,12 +1,13 @@ package net.corda.node.services.transactions import co.paralleluniverse.fibers.Suspendable +import net.corda.core.flows.NotaryFlow +import net.corda.core.flows.TransactionParts import net.corda.core.identity.Party import net.corda.core.node.services.TrustedAuthorityNotaryService import net.corda.core.transactions.FilteredTransaction +import net.corda.core.transactions.NotaryChangeWireTransaction import net.corda.core.utilities.unwrap -import net.corda.flows.NotaryFlow -import net.corda.flows.TransactionParts class NonValidatingNotaryFlow(otherSide: Party, service: TrustedAuthorityNotaryService) : NotaryFlow.Service(otherSide, service) { /** @@ -19,10 +20,19 @@ class NonValidatingNotaryFlow(otherSide: Party, service: TrustedAuthorityNotaryS */ @Suspendable override fun receiveAndVerifyTx(): TransactionParts { - val ftx = receive(otherSide).unwrap { - it.verify() - it + val parts = receive(otherSide).unwrap { + when (it) { + is FilteredTransaction -> { + it.verify() + TransactionParts(it.rootHash, it.filteredLeaves.inputs, it.filteredLeaves.timeWindow) + } + is NotaryChangeWireTransaction -> TransactionParts(it.id, it.inputs, null) + else -> { + throw IllegalArgumentException("Received unexpected transaction type: ${it::class.java.simpleName}," + + "expected either ${FilteredTransaction::class.java.simpleName} or ${NotaryChangeWireTransaction::class.java.simpleName}") + } + } } - return TransactionParts(ftx.rootHash, ftx.filteredLeaves.inputs, ftx.filteredLeaves.timeWindow) + return parts } -} +} \ No newline at end of file diff --git a/node/src/main/kotlin/net/corda/node/services/transactions/OutOfProcessTransactionVerifierService.kt b/node/src/main/kotlin/net/corda/node/services/transactions/OutOfProcessTransactionVerifierService.kt index 1539455fd9..1a6507796c 100644 --- a/node/src/main/kotlin/net/corda/node/services/transactions/OutOfProcessTransactionVerifierService.kt +++ b/node/src/main/kotlin/net/corda/node/services/transactions/OutOfProcessTransactionVerifierService.kt @@ -2,11 +2,12 @@ package net.corda.node.services.transactions import com.codahale.metrics.Gauge import com.codahale.metrics.Timer -import com.google.common.util.concurrent.ListenableFuture -import com.google.common.util.concurrent.SettableFuture +import net.corda.core.concurrent.CordaFuture import net.corda.core.crypto.SecureHash import net.corda.core.node.services.TransactionVerifierService import net.corda.core.crypto.random63BitValue +import net.corda.core.internal.concurrent.OpenFuture +import net.corda.core.internal.concurrent.openFuture import net.corda.core.serialization.SingletonSerializeAsToken import net.corda.core.transactions.LedgerTransaction import net.corda.core.utilities.loggerFor @@ -24,7 +25,7 @@ abstract class OutOfProcessTransactionVerifierService( private data class VerificationHandle( val transactionId: SecureHash, - val resultFuture: SettableFuture, + val resultFuture: OpenFuture, val durationTimerContext: Timer.Context ) @@ -61,9 +62,9 @@ abstract class OutOfProcessTransactionVerifierService( abstract fun sendRequest(nonce: Long, transaction: LedgerTransaction) - override fun verify(transaction: LedgerTransaction): ListenableFuture<*> { + override fun verify(transaction: LedgerTransaction): CordaFuture<*> { log.info("Verifying ${transaction.id}") - val future = SettableFuture.create() + val future = openFuture() val nonce = random63BitValue() verificationHandles[nonce] = VerificationHandle(transaction.id, future, durationTimer.time()) sendRequest(nonce, transaction) diff --git a/node/src/main/kotlin/net/corda/node/services/transactions/PersistentUniquenessProvider.kt b/node/src/main/kotlin/net/corda/node/services/transactions/PersistentUniquenessProvider.kt index f6419d2c24..820f3d5809 100644 --- a/node/src/main/kotlin/net/corda/node/services/transactions/PersistentUniquenessProvider.kt +++ b/node/src/main/kotlin/net/corda/node/services/transactions/PersistentUniquenessProvider.kt @@ -1,82 +1,110 @@ package net.corda.node.services.transactions -import net.corda.core.ThreadBox import net.corda.core.contracts.StateRef import net.corda.core.crypto.SecureHash +import net.corda.core.crypto.parsePublicKeyBase58 import net.corda.core.identity.Party +import net.corda.core.internal.ThreadBox import net.corda.core.node.services.UniquenessException import net.corda.core.node.services.UniquenessProvider import net.corda.core.serialization.SingletonSerializeAsToken import net.corda.core.utilities.loggerFor import net.corda.node.utilities.* import org.bouncycastle.asn1.x500.X500Name -import org.jetbrains.exposed.sql.ResultRow -import org.jetbrains.exposed.sql.statements.InsertStatement +import java.io.Serializable import java.util.* import javax.annotation.concurrent.ThreadSafe +import javax.persistence.* /** A RDBMS backed Uniqueness provider */ @ThreadSafe class PersistentUniquenessProvider : UniquenessProvider, SingletonSerializeAsToken() { + + @Entity + @javax.persistence.Table(name = "${NODE_DATABASE_PREFIX}notary_commit_log") + class PersistentUniqueness ( + + @EmbeddedId + var id: StateRef = StateRef(), + + @Column(name = "consuming_transaction_id") + var consumingTxHash: String = "", + + @Column(name = "consuming_input_index", length = 36) + var consumingIndex: Int = 0, + + @Embedded + var party: Party = Party() + ) { + + @Embeddable + data class StateRef ( + @Column(name = "transaction_id") + var txId: String = "", + + @Column(name = "output_index", length = 36) + var index: Int = 0 + ) : Serializable + + @Embeddable + data class Party ( + @Column(name = "requesting_party_name") + var name: String = "", + + @Column(name = "requesting_party_key", length = 255) + var owningKey: String = "" + ) : Serializable + } + + private class InnerState { + val committedStates = createMap() + } + + private val mutex = ThreadBox(InnerState()) + companion object { - private val TABLE_NAME = "${NODE_DATABASE_PREFIX}notary_commit_log" private val log = loggerFor() - } - /** - * For each input state store the consuming transaction information. - */ - private object Table : JDBCHashedTable(TABLE_NAME) { - val output = stateRef("transaction_id", "output_index") - val consumingTxHash = secureHash("consuming_transaction_id") - val consumingIndex = integer("consuming_input_index") - val requestingParty = party("requesting_party_name", "requesting_party_key") - } - - private val committedStates = ThreadBox(object : AbstractJDBCHashMap(Table, loadOnInit = false) { - override fun keyFromRow(row: ResultRow): StateRef = StateRef(row[table.output.txId], row[table.output.index]) - - override fun valueFromRow(row: ResultRow): UniquenessProvider.ConsumingTx = UniquenessProvider.ConsumingTx( - row[table.consumingTxHash], - row[table.consumingIndex], - Party(X500Name(row[table.requestingParty.name]), row[table.requestingParty.owningKey]) - ) - - override fun addKeyToInsert(insert: InsertStatement, - entry: Map.Entry, - finalizables: MutableList<() -> Unit>) { - insert[table.output.txId] = entry.key.txhash - insert[table.output.index] = entry.key.index + fun createMap(): AppendOnlyPersistentMap { + return AppendOnlyPersistentMap( + toPersistentEntityKey = { PersistentUniqueness.StateRef(it.txhash.toString(), it.index) }, + fromPersistentEntity = { + Pair(StateRef(SecureHash.parse(it.id.txId), it.id.index), + UniquenessProvider.ConsumingTx(SecureHash.parse(it.consumingTxHash), it.consumingIndex, + Party(X500Name(it.party.name), parsePublicKeyBase58(it.party.owningKey)))) + }, + toPersistentEntity = { key: StateRef, value: UniquenessProvider.ConsumingTx -> + PersistentUniqueness().apply { + id = PersistentUniqueness.StateRef(key.txhash.toString(), key.index) + consumingTxHash = value.id.toString() + consumingIndex = value.inputIndex + party = PersistentUniqueness.Party(value.requestingParty.name.toString()) + } + }, + persistentEntityClass = PersistentUniqueness::class.java + ) } - - override fun addValueToInsert(insert: InsertStatement, - entry: Map.Entry, - finalizables: MutableList<() -> Unit>) { - insert[table.consumingTxHash] = entry.value.id - insert[table.consumingIndex] = entry.value.inputIndex - insert[table.requestingParty.name] = entry.value.requestingParty.name.toString() - insert[table.requestingParty.owningKey] = entry.value.requestingParty.owningKey - } - }) + } override fun commit(states: List, txId: SecureHash, callerIdentity: Party) { - val conflict = committedStates.locked { - val conflictingStates = LinkedHashMap() - for (inputState in states) { - val consumingTx = get(inputState) - if (consumingTx != null) conflictingStates[inputState] = consumingTx - } - if (conflictingStates.isNotEmpty()) { - log.debug("Failure, input states already committed: ${conflictingStates.keys}") - UniquenessProvider.Conflict(conflictingStates) - } else { - states.forEachIndexed { i, stateRef -> - put(stateRef, UniquenessProvider.ConsumingTx(txId, i, callerIdentity)) + + val conflict = mutex.locked { + val conflictingStates = LinkedHashMap() + for (inputState in states) { + val consumingTx = committedStates.get(inputState) + if (consumingTx != null) conflictingStates[inputState] = consumingTx + } + if (conflictingStates.isNotEmpty()) { + log.debug("Failure, input states already committed: ${conflictingStates.keys}") + UniquenessProvider.Conflict(conflictingStates) + } else { + states.forEachIndexed { i, stateRef -> + committedStates[stateRef] = UniquenessProvider.ConsumingTx(txId, i, callerIdentity) + } + log.debug("Successfully committed all input states: $states") + null + } } - log.debug("Successfully committed all input states: $states") - null - } - } if (conflict != null) throw UniquenessException(conflict) } diff --git a/node/src/main/kotlin/net/corda/node/services/transactions/RaftNonValidatingNotaryService.kt b/node/src/main/kotlin/net/corda/node/services/transactions/RaftNonValidatingNotaryService.kt index 05bcabe172..ce2c4289c9 100644 --- a/node/src/main/kotlin/net/corda/node/services/transactions/RaftNonValidatingNotaryService.kt +++ b/node/src/main/kotlin/net/corda/node/services/transactions/RaftNonValidatingNotaryService.kt @@ -1,9 +1,9 @@ package net.corda.node.services.transactions +import net.corda.core.flows.NotaryFlow import net.corda.core.identity.Party -import net.corda.core.node.services.TrustedAuthorityNotaryService import net.corda.core.node.services.TimeWindowChecker -import net.corda.flows.NotaryFlow +import net.corda.core.node.services.TrustedAuthorityNotaryService import net.corda.node.services.api.ServiceHubInternal /** A non-validating notary service operated by a group of mutually trusting parties, uses the Raft algorithm to achieve consensus. */ @@ -15,9 +15,7 @@ class RaftNonValidatingNotaryService(override val services: ServiceHubInternal) override val timeWindowChecker: TimeWindowChecker = TimeWindowChecker(services.clock) override val uniquenessProvider: RaftUniquenessProvider = RaftUniquenessProvider(services) - override fun createServiceFlow(otherParty: Party, platformVersion: Int): NotaryFlow.Service { - return NonValidatingNotaryFlow(otherParty, this) - } + override fun createServiceFlow(otherParty: Party): NotaryFlow.Service = NonValidatingNotaryFlow(otherParty, this) override fun start() { uniquenessProvider.start() diff --git a/node/src/main/kotlin/net/corda/node/services/transactions/RaftUniquenessProvider.kt b/node/src/main/kotlin/net/corda/node/services/transactions/RaftUniquenessProvider.kt index b12c7ce477..d7024866ad 100644 --- a/node/src/main/kotlin/net/corda/node/services/transactions/RaftUniquenessProvider.kt +++ b/node/src/main/kotlin/net/corda/node/services/transactions/RaftUniquenessProvider.kt @@ -23,8 +23,8 @@ import net.corda.core.serialization.deserialize import net.corda.core.serialization.serialize import net.corda.core.utilities.loggerFor import net.corda.node.services.api.ServiceHubInternal +import net.corda.node.utilities.CordaPersistence import net.corda.nodeapi.config.SSLConfiguration -import org.jetbrains.exposed.sql.Database import java.nio.file.Path import java.util.concurrent.CompletableFuture import javax.annotation.concurrent.ThreadSafe @@ -55,7 +55,7 @@ class RaftUniquenessProvider(services: ServiceHubInternal) : UniquenessProvider, */ private val clusterAddresses = services.configuration.notaryClusterAddresses /** The database to store the state machine state in */ - private val db: Database = services.database + private val db: CordaPersistence = services.database /** SSL configuration */ private val transportConfiguration: SSLConfiguration = services.configuration diff --git a/node/src/main/kotlin/net/corda/node/services/transactions/RaftValidatingNotaryService.kt b/node/src/main/kotlin/net/corda/node/services/transactions/RaftValidatingNotaryService.kt index deba64d1a3..7267067997 100644 --- a/node/src/main/kotlin/net/corda/node/services/transactions/RaftValidatingNotaryService.kt +++ b/node/src/main/kotlin/net/corda/node/services/transactions/RaftValidatingNotaryService.kt @@ -1,9 +1,9 @@ package net.corda.node.services.transactions +import net.corda.core.flows.NotaryFlow import net.corda.core.identity.Party -import net.corda.core.node.services.TrustedAuthorityNotaryService import net.corda.core.node.services.TimeWindowChecker -import net.corda.flows.NotaryFlow +import net.corda.core.node.services.TrustedAuthorityNotaryService import net.corda.node.services.api.ServiceHubInternal /** A validating notary service operated by a group of mutually trusting parties, uses the Raft algorithm to achieve consensus. */ @@ -15,9 +15,7 @@ class RaftValidatingNotaryService(override val services: ServiceHubInternal) : T override val timeWindowChecker: TimeWindowChecker = TimeWindowChecker(services.clock) override val uniquenessProvider: RaftUniquenessProvider = RaftUniquenessProvider(services) - override fun createServiceFlow(otherParty: Party, platformVersion: Int): NotaryFlow.Service { - return ValidatingNotaryFlow(otherParty, this) - } + override fun createServiceFlow(otherParty: Party): NotaryFlow.Service = ValidatingNotaryFlow(otherParty, this) override fun start() { uniquenessProvider.start() diff --git a/node/src/main/kotlin/net/corda/node/services/transactions/SimpleNotaryService.kt b/node/src/main/kotlin/net/corda/node/services/transactions/SimpleNotaryService.kt index c23d19532b..5b92626f3b 100644 --- a/node/src/main/kotlin/net/corda/node/services/transactions/SimpleNotaryService.kt +++ b/node/src/main/kotlin/net/corda/node/services/transactions/SimpleNotaryService.kt @@ -1,10 +1,10 @@ package net.corda.node.services.transactions +import net.corda.core.flows.NotaryFlow import net.corda.core.identity.Party -import net.corda.core.node.services.TrustedAuthorityNotaryService import net.corda.core.node.services.ServiceType import net.corda.core.node.services.TimeWindowChecker -import net.corda.flows.NotaryFlow +import net.corda.core.node.services.TrustedAuthorityNotaryService import net.corda.node.services.api.ServiceHubInternal /** A simple Notary service that does not perform transaction validation */ @@ -16,9 +16,7 @@ class SimpleNotaryService(override val services: ServiceHubInternal) : TrustedAu override val timeWindowChecker = TimeWindowChecker(services.clock) override val uniquenessProvider = PersistentUniquenessProvider() - override fun createServiceFlow(otherParty: Party, platformVersion: Int): NotaryFlow.Service { - return NonValidatingNotaryFlow(otherParty, this) - } + override fun createServiceFlow(otherParty: Party): NotaryFlow.Service = NonValidatingNotaryFlow(otherParty, this) override fun start() {} override fun stop() {} diff --git a/node/src/main/kotlin/net/corda/node/services/transactions/ValidatingNotaryFlow.kt b/node/src/main/kotlin/net/corda/node/services/transactions/ValidatingNotaryFlow.kt index dca4e5f5ad..d4a1b11e60 100644 --- a/node/src/main/kotlin/net/corda/node/services/transactions/ValidatingNotaryFlow.kt +++ b/node/src/main/kotlin/net/corda/node/services/transactions/ValidatingNotaryFlow.kt @@ -2,12 +2,10 @@ package net.corda.node.services.transactions import co.paralleluniverse.fibers.Suspendable import net.corda.core.contracts.TransactionVerificationException +import net.corda.core.flows.* import net.corda.core.identity.Party import net.corda.core.node.services.TrustedAuthorityNotaryService import net.corda.core.transactions.SignedTransaction -import net.corda.core.transactions.WireTransaction -import net.corda.core.utilities.unwrap -import net.corda.flows.* import java.security.SignatureException /** @@ -24,35 +22,25 @@ class ValidatingNotaryFlow(otherSide: Party, service: TrustedAuthorityNotaryServ */ @Suspendable override fun receiveAndVerifyTx(): TransactionParts { - val stx = receive(otherSide).unwrap { it } - checkSignatures(stx) - val wtx = stx.tx - validateTransaction(wtx) - return TransactionParts(wtx.id, wtx.inputs, wtx.timeWindow) - } - - private fun checkSignatures(stx: SignedTransaction) { try { - stx.verifySignatures(serviceHub.myInfo.notaryIdentity.owningKey) - } catch(e: SignedTransaction.SignaturesMissingException) { - throw NotaryException(NotaryError.SignaturesMissing(e)) - } - } - - @Suspendable - fun validateTransaction(wtx: WireTransaction) { - try { - resolveTransaction(wtx) - wtx.toLedgerTransaction(serviceHub).verify() + val stx = subFlow(ReceiveTransactionFlow(otherSide, checkSufficientSignatures = false)) + checkSignatures(stx) + val wtx = stx.tx + return TransactionParts(wtx.id, wtx.inputs, wtx.timeWindow) } catch (e: Exception) { throw when (e) { - is TransactionVerificationException -> NotaryException(NotaryError.TransactionInvalid(e.toString())) - is SignatureException -> NotaryException(NotaryError.SignaturesInvalid(e.toString())) + is TransactionVerificationException, + is SignatureException -> NotaryException(NotaryError.TransactionInvalid(e)) else -> e } } } - @Suspendable - private fun resolveTransaction(wtx: WireTransaction) = subFlow(ResolveTransactionsFlow(wtx, otherSide)) + private fun checkSignatures(stx: SignedTransaction) { + try { + stx.verifySignaturesExcept(serviceHub.myInfo.notaryIdentity.owningKey) + } catch(e: SignatureException) { + throw NotaryException(NotaryError.TransactionInvalid(e)) + } + } } diff --git a/node/src/main/kotlin/net/corda/node/services/transactions/ValidatingNotaryService.kt b/node/src/main/kotlin/net/corda/node/services/transactions/ValidatingNotaryService.kt index c996a8979d..9bb34b273c 100644 --- a/node/src/main/kotlin/net/corda/node/services/transactions/ValidatingNotaryService.kt +++ b/node/src/main/kotlin/net/corda/node/services/transactions/ValidatingNotaryService.kt @@ -1,10 +1,10 @@ package net.corda.node.services.transactions +import net.corda.core.flows.NotaryFlow import net.corda.core.identity.Party -import net.corda.core.node.services.TrustedAuthorityNotaryService import net.corda.core.node.services.ServiceType import net.corda.core.node.services.TimeWindowChecker -import net.corda.flows.NotaryFlow +import net.corda.core.node.services.TrustedAuthorityNotaryService import net.corda.node.services.api.ServiceHubInternal /** A Notary service that validates the transaction chain of the submitted transaction before committing it */ @@ -16,9 +16,7 @@ class ValidatingNotaryService(override val services: ServiceHubInternal) : Trust override val timeWindowChecker = TimeWindowChecker(services.clock) override val uniquenessProvider = PersistentUniquenessProvider() - override fun createServiceFlow(otherParty: Party, platformVersion: Int): NotaryFlow.Service { - return ValidatingNotaryFlow(otherParty, this) - } + override fun createServiceFlow(otherParty: Party): NotaryFlow.Service = ValidatingNotaryFlow(otherParty, this) override fun start() {} override fun stop() {} diff --git a/node/src/main/kotlin/net/corda/node/services/vault/CashBalanceAsMetricsObserver.kt b/node/src/main/kotlin/net/corda/node/services/vault/CashBalanceAsMetricsObserver.kt deleted file mode 100644 index 9eb1719ef6..0000000000 --- a/node/src/main/kotlin/net/corda/node/services/vault/CashBalanceAsMetricsObserver.kt +++ /dev/null @@ -1,45 +0,0 @@ -package net.corda.node.services.vault - -import com.codahale.metrics.Gauge -import net.corda.core.node.services.VaultService -import net.corda.node.services.api.ServiceHubInternal -import net.corda.node.utilities.transaction -import org.jetbrains.exposed.sql.Database -import java.util.* - -/** - * This class observes the vault and reflect current cash balances as exposed metrics in the monitoring service. - */ -class CashBalanceAsMetricsObserver(val serviceHubInternal: ServiceHubInternal, val database: Database) { - init { - // TODO: Need to consider failure scenarios. This needs to run if the TX is successfully recorded - serviceHubInternal.vaultService.updates.subscribe { _ -> - exportCashBalancesViaMetrics(serviceHubInternal.vaultService) - } - } - - private class BalanceMetric : Gauge { - @Volatile var pennies = 0L - override fun getValue(): Long? = pennies - } - - private val balanceMetrics = HashMap() - - private fun exportCashBalancesViaMetrics(vault: VaultService) { - // This is just for demo purposes. We probably shouldn't expose balances via JMX in a real node as that might - // be commercially sensitive info that the sysadmins aren't even meant to know. - // - // Note: exported as pennies. - val m = serviceHubInternal.monitoringService.metrics - database.transaction { - for ((key, value) in vault.cashBalances) { - val metric = balanceMetrics.getOrPut(key) { - val newMetric = BalanceMetric() - m.register("VaultBalances.${key}Pennies", newMetric) - newMetric - } - metric.pennies = value.quantity - } - } - } -} diff --git a/node/src/main/kotlin/net/corda/node/services/vault/HibernateQueryCriteriaParser.kt b/node/src/main/kotlin/net/corda/node/services/vault/HibernateQueryCriteriaParser.kt index 6ad1d928f5..2ee2218160 100644 --- a/node/src/main/kotlin/net/corda/node/services/vault/HibernateQueryCriteriaParser.kt +++ b/node/src/main/kotlin/net/corda/node/services/vault/HibernateQueryCriteriaParser.kt @@ -2,19 +2,18 @@ package net.corda.node.services.vault import net.corda.core.contracts.ContractState import net.corda.core.contracts.StateRef -import net.corda.core.contracts.UniqueIdentifier import net.corda.core.identity.AbstractParty import net.corda.core.node.services.Vault import net.corda.core.node.services.VaultQueryException import net.corda.core.node.services.vault.* import net.corda.core.node.services.vault.QueryCriteria.CommonQueryCriteria +import net.corda.core.schemas.CommonSchemaV1 import net.corda.core.schemas.PersistentState import net.corda.core.schemas.PersistentStateRef import net.corda.core.utilities.OpaqueBytes -import net.corda.core.utilities.toHexString import net.corda.core.utilities.loggerFor +import net.corda.core.utilities.toHexString import net.corda.core.utilities.trace -import net.corda.core.schemas.CommonSchemaV1 import org.bouncycastle.asn1.x500.X500Name import java.util.* import javax.persistence.Tuple @@ -43,18 +42,30 @@ class HibernateQueryCriteriaParser(val contractType: Class, val predicateSet = mutableSetOf() // contract State Types - val combinedContractTypeTypes = criteria.contractStateTypes?.plus(contractType) ?: setOf(contractType) - combinedContractTypeTypes.filter { it.name != ContractState::class.java.name }.let { - val interfaces = it.flatMap { contractTypeMappings[it.name] ?: emptyList() } - val concrete = it.filter { !it.isInterface }.map { it.name } - val all = interfaces.plus(concrete) - if (all.isNotEmpty()) - predicateSet.add(criteriaBuilder.and(vaultStates.get("contractStateClassName").`in`(all))) - } + val contractTypes = deriveContractTypes(criteria.contractStateTypes) + if (contractTypes.isNotEmpty()) + predicateSet.add(criteriaBuilder.and(vaultStates.get("contractStateClassName").`in`(contractTypes))) // soft locking - if (!criteria.includeSoftlockedStates) - predicateSet.add(criteriaBuilder.and(vaultStates.get("lockId").isNull)) + criteria.softLockingCondition?.let { + val softLocking = criteria.softLockingCondition + val type = softLocking!!.type + when(type) { + QueryCriteria.SoftLockingType.UNLOCKED_ONLY -> + predicateSet.add(criteriaBuilder.and(vaultStates.get("lockId").isNull)) + QueryCriteria.SoftLockingType.LOCKED_ONLY -> + predicateSet.add(criteriaBuilder.and(vaultStates.get("lockId").isNotNull)) + QueryCriteria.SoftLockingType.UNLOCKED_AND_SPECIFIED -> { + require(softLocking.lockIds.isNotEmpty()) { "Must specify one or more lockIds" } + predicateSet.add(criteriaBuilder.or(vaultStates.get("lockId").isNull, + vaultStates.get("lockId").`in`(softLocking.lockIds.map { it.toString() }))) + } + QueryCriteria.SoftLockingType.SPECIFIED -> { + require(softLocking.lockIds.isNotEmpty()) { "Must specify one or more lockIds" } + predicateSet.add(criteriaBuilder.and(vaultStates.get("lockId").`in`(softLocking.lockIds.map { it.toString() }))) + } + } + } // notary names criteria.notaryName?.let { @@ -74,8 +85,8 @@ class HibernateQueryCriteriaParser(val contractType: Class, val timeCondition = criteria.timeCondition val timeInstantType = timeCondition!!.type val timeColumn = when (timeInstantType) { - QueryCriteria.TimeInstantType.RECORDED -> Column.Kotlin(VaultSchemaV1.VaultStates::recordedTime) - QueryCriteria.TimeInstantType.CONSUMED -> Column.Kotlin(VaultSchemaV1.VaultStates::consumedTime) + QueryCriteria.TimeInstantType.RECORDED -> Column(VaultSchemaV1.VaultStates::recordedTime) + QueryCriteria.TimeInstantType.CONSUMED -> Column(VaultSchemaV1.VaultStates::consumedTime) } val expression = CriteriaExpression.ColumnPredicateExpression(timeColumn, timeCondition.predicate) predicateSet.add(parseExpression(vaultStates, expression) as Predicate) @@ -83,6 +94,15 @@ class HibernateQueryCriteriaParser(val contractType: Class, return predicateSet } + private fun deriveContractTypes(contractStateTypes: Set>? = null): List { + val combinedContractStateTypes = contractStateTypes?.plus(contractType) ?: setOf(contractType) + combinedContractStateTypes.filter { it.name != ContractState::class.java.name }.let { + val interfaces = it.flatMap { contractTypeMappings[it.name] ?: emptyList() } + val concrete = it.filter { !it.isInterface }.map { it.name } + return interfaces.plus(concrete) + } + } + private fun columnPredicateToPredicate(column: Path, columnPredicate: ColumnPredicate<*>): Predicate { return when (columnPredicate) { is ColumnPredicate.EqualityComparison -> { @@ -93,9 +113,10 @@ class HibernateQueryCriteriaParser(val contractType: Class, } } is ColumnPredicate.BinaryComparison -> { - column as Path?> @Suppress("UNCHECKED_CAST") val literal = columnPredicate.rightLiteral as Comparable? + @Suppress("UNCHECKED_CAST") + column as Path?> when (columnPredicate.operator) { BinaryComparisonOperator.GREATER_THAN -> criteriaBuilder.greaterThan(column, literal) BinaryComparisonOperator.GREATER_THAN_OR_EQUAL -> criteriaBuilder.greaterThanOrEqualTo(column, literal) @@ -104,6 +125,7 @@ class HibernateQueryCriteriaParser(val contractType: Class, } } is ColumnPredicate.Likeness -> { + @Suppress("UNCHECKED_CAST") column as Path when (columnPredicate.operator) { LikenessOperator.LIKE -> criteriaBuilder.like(column, columnPredicate.rightLiteral) @@ -170,28 +192,28 @@ class HibernateQueryCriteriaParser(val contractType: Class, @Suppress("UNCHECKED_CAST") column as Path? val aggregateExpression = - when (columnPredicate.type) { - AggregateFunctionType.SUM -> criteriaBuilder.sum(column) - AggregateFunctionType.AVG -> criteriaBuilder.avg(column) - AggregateFunctionType.COUNT -> criteriaBuilder.count(column) - AggregateFunctionType.MAX -> criteriaBuilder.max(column) - AggregateFunctionType.MIN -> criteriaBuilder.min(column) - } + when (columnPredicate.type) { + AggregateFunctionType.SUM -> criteriaBuilder.sum(column) + AggregateFunctionType.AVG -> criteriaBuilder.avg(column) + AggregateFunctionType.COUNT -> criteriaBuilder.count(column) + AggregateFunctionType.MAX -> criteriaBuilder.max(column) + AggregateFunctionType.MIN -> criteriaBuilder.min(column) + } aggregateExpressions.add(aggregateExpression) // optionally order by this aggregate function expression.orderBy?.let { val orderCriteria = - when (expression.orderBy!!) { - Sort.Direction.ASC -> criteriaBuilder.asc(aggregateExpression) - Sort.Direction.DESC -> criteriaBuilder.desc(aggregateExpression) - } + when (expression.orderBy!!) { + Sort.Direction.ASC -> criteriaBuilder.asc(aggregateExpression) + Sort.Direction.DESC -> criteriaBuilder.desc(aggregateExpression) + } criteriaQuery.orderBy(orderCriteria) } // add optional group by clauses expression.groupByColumns?.let { columns -> val groupByExpressions = - columns.map { column -> - val path = root.get(getColumnName(column)) + columns.map { _column -> + val path = root.get(getColumnName(_column)) aggregateExpressions.add(path) path } @@ -206,7 +228,7 @@ class HibernateQueryCriteriaParser(val contractType: Class, override fun parseCriteria(criteria: QueryCriteria.FungibleAssetQueryCriteria) : Collection { log.trace { "Parsing FungibleAssetQueryCriteria: $criteria" } - var predicateSet = mutableSetOf() + val predicateSet = mutableSetOf() val vaultFungibleStates = criteriaQuery.from(VaultSchemaV1.VaultFungibleStates::class.java) rootEntities.putIfAbsent(VaultSchemaV1.VaultFungibleStates::class.java, vaultFungibleStates) @@ -214,6 +236,11 @@ class HibernateQueryCriteriaParser(val contractType: Class, val joinPredicate = criteriaBuilder.equal(vaultStates.get("stateRef"), vaultFungibleStates.get("stateRef")) predicateSet.add(joinPredicate) + // contract State Types + val contractTypes = deriveContractTypes() + if (contractTypes.isNotEmpty()) + predicateSet.add(criteriaBuilder.and(vaultStates.get("contractStateClassName").`in`(contractTypes))) + // owner criteria.owner?.let { val ownerKeys = criteria.owner as List @@ -231,8 +258,8 @@ class HibernateQueryCriteriaParser(val contractType: Class, criteria.issuerPartyName?.let { val issuerParties = criteria.issuerPartyName as List val joinFungibleStateToParty = vaultFungibleStates.join("issuerParty") - val dealPartyKeys = issuerParties.map { it.nameOrNull().toString() } - predicateSet.add(criteriaBuilder.equal(joinFungibleStateToParty.get("name"), dealPartyKeys)) + val issuerPartyNames = issuerParties.map { it.nameOrNull().toString() } + predicateSet.add(criteriaBuilder.and(joinFungibleStateToParty.get("name").`in`(issuerPartyNames))) } // issuer reference @@ -263,19 +290,21 @@ class HibernateQueryCriteriaParser(val contractType: Class, val joinPredicate = criteriaBuilder.equal(vaultStates.get("stateRef"), vaultLinearStates.get("stateRef")) joinPredicates.add(joinPredicate) - // linear ids - criteria.linearId?.let { - val uniqueIdentifiers = criteria.linearId as List - val externalIds = uniqueIdentifiers.mapNotNull { it.externalId } - if (externalIds.isNotEmpty()) - predicateSet.add(criteriaBuilder.and(vaultLinearStates.get("externalId").`in`(externalIds))) - predicateSet.add(criteriaBuilder.and(vaultLinearStates.get("uuid").`in`(uniqueIdentifiers.map { it.id }))) + // contract State Types + val contractTypes = deriveContractTypes() + if (contractTypes.isNotEmpty()) + predicateSet.add(criteriaBuilder.and(vaultStates.get("contractStateClassName").`in`(contractTypes))) + + // linear ids UUID + criteria.uuid?.let { + val uuids = criteria.uuid as List + predicateSet.add(criteriaBuilder.and(vaultLinearStates.get("uuid").`in`(uuids))) } - // deal refs - criteria.dealRef?.let { - val dealRefs = criteria.dealRef as List - predicateSet.add(criteriaBuilder.and(vaultLinearStates.get("dealReference").`in`(dealRefs))) + // linear ids externalId + criteria.externalId?.let { + val externalIds = criteria.externalId as List + predicateSet.add(criteriaBuilder.and(vaultLinearStates.get("externalId").`in`(externalIds))) } // deal participants @@ -302,6 +331,11 @@ class HibernateQueryCriteriaParser(val contractType: Class, val joinPredicate = criteriaBuilder.equal(vaultStates.get("stateRef"), entityRoot.get("stateRef")) joinPredicates.add(joinPredicate) + // contract State Types + val contractTypes = deriveContractTypes() + if (contractTypes.isNotEmpty()) + predicateSet.add(criteriaBuilder.and(vaultStates.get("contractStateClassName").`in`(contractTypes))) + // resolve general criteria expressions parseExpression(entityRoot, criteria.expression, predicateSet) } @@ -321,7 +355,7 @@ class HibernateQueryCriteriaParser(val contractType: Class, override fun parseOr(left: QueryCriteria, right: QueryCriteria): Collection { log.trace { "Parsing OR QueryCriteria composition: $left OR $right" } - var predicateSet = mutableSetOf() + val predicateSet = mutableSetOf() val leftPredicates = parse(left) val rightPredicates = parse(right) @@ -334,7 +368,7 @@ class HibernateQueryCriteriaParser(val contractType: Class, override fun parseAnd(left: QueryCriteria, right: QueryCriteria): Collection { log.trace { "Parsing AND QueryCriteria composition: $left AND $right" } - var predicateSet = mutableSetOf() + val predicateSet = mutableSetOf() val leftPredicates = parse(left) val rightPredicates = parse(right) @@ -353,10 +387,10 @@ class HibernateQueryCriteriaParser(val contractType: Class, } val selections = - if (aggregateExpressions.isEmpty()) - listOf(vaultStates).plus(rootEntities.map { it.value }) - else - aggregateExpressions + if (aggregateExpressions.isEmpty()) + listOf(vaultStates).plus(rootEntities.map { it.value }) + else + aggregateExpressions criteriaQuery.multiselect(selections) val combinedPredicates = joinPredicates.plus(predicateSet) criteriaQuery.where(*combinedPredicates.toTypedArray()) @@ -379,7 +413,7 @@ class HibernateQueryCriteriaParser(val contractType: Class, private fun parse(sorting: Sort) { log.trace { "Parsing sorting specification: $sorting" } - var orderCriteria = mutableListOf() + val orderCriteria = mutableListOf() sorting.columns.map { (sortAttribute, direction) -> val (entityStateClass, entityStateAttributeParent, entityStateAttributeChild) = @@ -418,21 +452,21 @@ class HibernateQueryCriteriaParser(val contractType: Class, private fun parse(sortAttribute: Sort.Attribute): Triple, String, String?> { val entityClassAndColumnName : Triple, String, String?> = - when(sortAttribute) { - is Sort.CommonStateAttribute -> { - Triple(VaultSchemaV1.VaultStates::class.java, sortAttribute.attributeParent, sortAttribute.attributeChild) + when(sortAttribute) { + is Sort.CommonStateAttribute -> { + Triple(VaultSchemaV1.VaultStates::class.java, sortAttribute.attributeParent, sortAttribute.attributeChild) + } + is Sort.VaultStateAttribute -> { + Triple(VaultSchemaV1.VaultStates::class.java, sortAttribute.attributeName, null) + } + is Sort.LinearStateAttribute -> { + Triple(VaultSchemaV1.VaultLinearStates::class.java, sortAttribute.attributeName, null) + } + is Sort.FungibleStateAttribute -> { + Triple(VaultSchemaV1.VaultFungibleStates::class.java, sortAttribute.attributeName, null) + } + else -> throw VaultQueryException("Invalid sort attribute: $sortAttribute") } - is Sort.VaultStateAttribute -> { - Triple(VaultSchemaV1.VaultStates::class.java, sortAttribute.attributeName, null) - } - is Sort.LinearStateAttribute -> { - Triple(VaultSchemaV1.VaultLinearStates::class.java, sortAttribute.attributeName, null) - } - is Sort.FungibleStateAttribute -> { - Triple(VaultSchemaV1.VaultFungibleStates::class.java, sortAttribute.attributeName, null) - } - else -> throw VaultQueryException("Invalid sort attribute: $sortAttribute") - } return entityClassAndColumnName } } \ No newline at end of file diff --git a/node/src/main/kotlin/net/corda/node/services/vault/HibernateVaultQueryImpl.kt b/node/src/main/kotlin/net/corda/node/services/vault/HibernateVaultQueryImpl.kt index fe91e3f579..3bccc8597c 100644 --- a/node/src/main/kotlin/net/corda/node/services/vault/HibernateVaultQueryImpl.kt +++ b/node/src/main/kotlin/net/corda/node/services/vault/HibernateVaultQueryImpl.kt @@ -1,7 +1,7 @@ package net.corda.node.services.vault -import net.corda.core.ThreadBox -import net.corda.core.bufferUntilSubscribed +import net.corda.core.internal.ThreadBox +import net.corda.core.internal.bufferUntilSubscribed import net.corda.core.contracts.ContractState import net.corda.core.contracts.StateAndRef import net.corda.core.contracts.StateRef @@ -13,14 +13,15 @@ import net.corda.core.node.services.VaultQueryException import net.corda.core.node.services.VaultQueryService import net.corda.core.node.services.vault.* import net.corda.core.node.services.vault.QueryCriteria.VaultCustomQueryCriteria +import net.corda.core.serialization.SerializationDefaults.STORAGE_CONTEXT import net.corda.core.serialization.SingletonSerializeAsToken import net.corda.core.serialization.deserialize -import net.corda.core.serialization.storageKryo import net.corda.core.utilities.debug import net.corda.core.utilities.loggerFor import net.corda.node.services.database.HibernateConfiguration import org.jetbrains.exposed.sql.transactions.TransactionManager import rx.subjects.PublishSubject +import rx.Observable import java.lang.Exception import java.util.* import javax.persistence.EntityManager @@ -28,7 +29,7 @@ import javax.persistence.Tuple class HibernateVaultQueryImpl(hibernateConfig: HibernateConfiguration, - val updatesPublisher: PublishSubject) : SingletonSerializeAsToken(), VaultQueryService { + val updatesPublisher: PublishSubject>) : SingletonSerializeAsToken(), VaultQueryService { companion object { val log = loggerFor() } @@ -96,18 +97,18 @@ class HibernateVaultQueryImpl(hibernateConfig: HibernateConfiguration, return@forEachIndexed val vaultState = result[0] as VaultSchemaV1.VaultStates val stateRef = StateRef(SecureHash.parse(vaultState.stateRef!!.txId!!), vaultState.stateRef!!.index!!) - val state = vaultState.contractState.deserialize>(storageKryo()) + val state = vaultState.contractState.deserialize>(context = STORAGE_CONTEXT) statesMeta.add(Vault.StateMetadata(stateRef, vaultState.contractStateClassName, vaultState.recordedTime, vaultState.consumedTime, vaultState.stateStatus, vaultState.notaryName, vaultState.notaryKey, vaultState.lockId, vaultState.lockUpdateTime)) statesAndRefs.add(StateAndRef(state, stateRef)) } else { + // TODO: improve typing of returned other results log.debug { "OtherResults: ${Arrays.toString(result.toArray())}" } otherResults.addAll(result.toArray().asList()) } } return Vault.Page(states = statesAndRefs, statesMetadata = statesMeta, stateTypes = criteriaParser.stateTypes, totalStatesAvailable = totalStates, otherResults = otherResults) - } catch (e: Exception) { log.error(e.message) throw e.cause ?: e @@ -118,10 +119,11 @@ class HibernateVaultQueryImpl(hibernateConfig: HibernateConfiguration, private val mutex = ThreadBox({ updatesPublisher }) @Throws(VaultQueryException::class) - override fun _trackBy(criteria: QueryCriteria, paging: PageSpecification, sorting: Sort, contractType: Class): DataFeed, Vault.Update> { + override fun _trackBy(criteria: QueryCriteria, paging: PageSpecification, sorting: Sort, contractType: Class): DataFeed, Vault.Update> { return mutex.locked { - val snapshotResults = _queryBy(criteria, paging, sorting, contractType) - val updates = updatesPublisher.bufferUntilSubscribed().filter { it.containsType(contractType, snapshotResults.stateTypes) } + val snapshotResults = _queryBy(criteria, paging, sorting, contractType) + @Suppress("UNCHECKED_CAST") + val updates = updatesPublisher.bufferUntilSubscribed().filter { it.containsType(contractType, snapshotResults.stateTypes) } as Observable> DataFeed(snapshotResults, updates) } } diff --git a/node/src/main/kotlin/net/corda/node/services/vault/NodeVaultService.kt b/node/src/main/kotlin/net/corda/node/services/vault/NodeVaultService.kt index 2ba51cbf06..04e98b5e75 100644 --- a/node/src/main/kotlin/net/corda/node/services/vault/NodeVaultService.kt +++ b/node/src/main/kotlin/net/corda/node/services/vault/NodeVaultService.kt @@ -4,52 +4,48 @@ import co.paralleluniverse.fibers.Suspendable import co.paralleluniverse.strands.Strand import com.google.common.annotations.VisibleForTesting import io.requery.PersistenceException -import io.requery.TransactionIsolation -import io.requery.kotlin.`in` import io.requery.kotlin.eq -import io.requery.kotlin.isNull -import io.requery.kotlin.notNull import io.requery.query.RowExpression -import net.corda.contracts.asset.Cash -import net.corda.contracts.asset.OnLedgerAsset -import net.corda.core.ThreadBox -import net.corda.core.bufferUntilSubscribed import net.corda.core.contracts.* import net.corda.core.crypto.SecureHash import net.corda.core.crypto.containsAny import net.corda.core.crypto.toBase58String -import net.corda.core.identity.AbstractParty -import net.corda.core.identity.Party -import net.corda.core.messaging.DataFeed +import net.corda.core.internal.ThreadBox +import net.corda.core.internal.tee import net.corda.core.node.ServiceHub import net.corda.core.node.services.StatesNotAvailableException import net.corda.core.node.services.Vault import net.corda.core.node.services.VaultService -import net.corda.core.node.services.unconsumedStates +import net.corda.core.node.services.vault.IQueryCriteriaParser +import net.corda.core.node.services.vault.QueryCriteria +import net.corda.core.node.services.vault.Sort +import net.corda.core.node.services.vault.SortAttribute +import net.corda.core.schemas.PersistentState +import net.corda.core.serialization.SerializationDefaults.STORAGE_CONTEXT import net.corda.core.serialization.SingletonSerializeAsToken import net.corda.core.serialization.deserialize import net.corda.core.serialization.serialize -import net.corda.core.serialization.storageKryo -import net.corda.core.tee -import net.corda.core.transactions.TransactionBuilder +import net.corda.core.transactions.CoreTransaction +import net.corda.core.transactions.NotaryChangeWireTransaction import net.corda.core.transactions.WireTransaction -import net.corda.core.utilities.OpaqueBytes +import net.corda.core.utilities.NonEmptySet import net.corda.core.utilities.loggerFor -import net.corda.core.utilities.toHexString +import net.corda.core.utilities.toNonEmptySet import net.corda.core.utilities.trace import net.corda.node.services.database.RequeryConfiguration +import net.corda.node.services.database.parserTransactionIsolationLevel import net.corda.node.services.statemachine.FlowStateMachineImpl -import net.corda.node.services.vault.schemas.requery.* +import net.corda.node.services.vault.schemas.requery.Models import net.corda.node.services.vault.schemas.requery.VaultSchema +import net.corda.node.services.vault.schemas.requery.VaultStatesEntity +import net.corda.node.services.vault.schemas.requery.VaultTxnNoteEntity import net.corda.node.utilities.bufferUntilDatabaseCommit import net.corda.node.utilities.wrapWithDatabaseTransaction import rx.Observable import rx.subjects.PublishSubject import java.security.PublicKey -import java.sql.SQLException import java.util.* -import java.util.concurrent.locks.ReentrantLock -import kotlin.concurrent.withLock +import javax.persistence.criteria.Predicate /** * Currently, the node vault service is a very simple RDBMS backed implementation. It will change significantly when @@ -62,7 +58,7 @@ import kotlin.concurrent.withLock * TODO: keep an audit trail with time stamps of previously unconsumed states "as of" a particular point in time. * TODO: have transaction storage do some caching. */ -class NodeVaultService(private val services: ServiceHub, dataSourceProperties: Properties) : SingletonSerializeAsToken(), VaultService { +class NodeVaultService(private val services: ServiceHub, dataSourceProperties: Properties, databaseProperties: Properties?) : SingletonSerializeAsToken(), VaultService { private companion object { val log = loggerFor() @@ -71,34 +67,36 @@ class NodeVaultService(private val services: ServiceHub, dataSourceProperties: P val stateRefCompositeColumn: RowExpression = RowExpression.of(listOf(VaultStatesEntity.TX_ID, VaultStatesEntity.INDEX)) } - val configuration = RequeryConfiguration(dataSourceProperties) + val configuration = RequeryConfiguration(dataSourceProperties, databaseProperties = databaseProperties ?: Properties()) val session = configuration.sessionForModel(Models.VAULT) + private val transactionIsolationLevel = parserTransactionIsolationLevel(databaseProperties?.getProperty("transactionIsolationLevel") ?:"") private class InnerState { - val _updatesPublisher = PublishSubject.create()!! - val _rawUpdatesPublisher = PublishSubject.create()!! + val _updatesPublisher = PublishSubject.create>()!! + val _rawUpdatesPublisher = PublishSubject.create>()!! val _updatesInDbTx = _updatesPublisher.wrapWithDatabaseTransaction().asObservable()!! // For use during publishing only. - val updatesPublisher: rx.Observer get() = _updatesPublisher.bufferUntilDatabaseCommit().tee(_rawUpdatesPublisher) + val updatesPublisher: rx.Observer> get() = _updatesPublisher.bufferUntilDatabaseCommit().tee(_rawUpdatesPublisher) } + private val mutex = ThreadBox(InnerState()) - private fun recordUpdate(update: Vault.Update): Vault.Update { - if (update != Vault.NoUpdate) { + private fun recordUpdate(update: Vault.Update): Vault.Update { + if (!update.isEmpty()) { val producedStateRefs = update.produced.map { it.ref } val producedStateRefsMap = update.produced.associateBy { it.ref } val consumedStateRefs = update.consumed.map { it.ref } log.trace { "Removing $consumedStateRefs consumed contract states and adding $producedStateRefs produced contract states to the database." } - session.withTransaction(TransactionIsolation.REPEATABLE_READ) { + session.withTransaction(transactionIsolationLevel) { producedStateRefsMap.forEach { it -> val state = VaultStatesEntity().apply { txId = it.key.txhash.toString() index = it.key.index stateStatus = Vault.StateStatus.UNCONSUMED contractStateClassName = it.value.state.data.javaClass.name - contractState = it.value.state.serialize(storageKryo()).bytes + contractState = it.value.state.serialize(context = STORAGE_CONTEXT).bytes notaryName = it.value.state.notary.name.toString() notaryKey = it.value.state.notary.owningKey.toBase58String() recordedTime = services.clock.instant() @@ -127,126 +125,134 @@ class NodeVaultService(private val services: ServiceHub, dataSourceProperties: P return update } - // TODO: consider moving this logic outside the vault - // TODO: revisit the concurrency safety of this logic when we move beyond single threaded SMM. - // For example, we update currency totals in a non-deterministic order and so expose ourselves to deadlock. - private fun maybeUpdateCashBalances(update: Vault.Update) { - if (update.containsType()) { - val consumed = sumCashStates(update.consumed) - val produced = sumCashStates(update.produced) - (produced.keys + consumed.keys).map { currency -> - val producedAmount = produced[currency] ?: Amount(0, currency) - val consumedAmount = consumed[currency] ?: Amount(0, currency) + override val rawUpdates: Observable> + get() = mutex.locked { _rawUpdatesPublisher } - val cashBalanceEntity = VaultCashBalancesEntity() - cashBalanceEntity.currency = currency.currencyCode - cashBalanceEntity.amount = producedAmount.quantity - consumedAmount.quantity + override val updates: Observable> + get() = mutex.locked { _updatesInDbTx } - session.withTransaction(TransactionIsolation.REPEATABLE_READ) { - val state = findByKey(VaultCashBalancesEntity::class, currency.currencyCode) - state?.run { - amount += producedAmount.quantity - consumedAmount.quantity + override val updatesPublisher: PublishSubject> + get() = mutex.locked { _updatesPublisher } + + /** + * Splits the provided [txns] into batches of [WireTransaction] and [NotaryChangeWireTransaction]. + * This is required because the batches get aggregated into single updates, and we want to be able to + * indicate whether an update consists entirely of regular or notary change transactions, which may require + * different processing logic. + */ + override fun notifyAll(txns: Iterable) { + // It'd be easier to just group by type, but then we'd lose ordering. + val regularTxns = mutableListOf() + val notaryChangeTxns = mutableListOf() + + for (tx in txns) { + when (tx) { + is WireTransaction -> { + regularTxns.add(tx) + if (notaryChangeTxns.isNotEmpty()) { + notifyNotaryChange(notaryChangeTxns.toList()) + notaryChangeTxns.clear() + } + } + is NotaryChangeWireTransaction -> { + notaryChangeTxns.add(tx) + if (regularTxns.isNotEmpty()) { + notifyRegular(regularTxns.toList()) + regularTxns.clear() } - upsert(state ?: cashBalanceEntity) - val total = state?.amount ?: cashBalanceEntity.amount - log.trace { "Updating Cash balance for $currency by ${cashBalanceEntity.amount} pennies (total: $total)" } } } } + + if (regularTxns.isNotEmpty()) notifyRegular(regularTxns.toList()) + if (notaryChangeTxns.isNotEmpty()) notifyNotaryChange(notaryChangeTxns.toList()) } - @Suppress("UNCHECKED_CAST") - private fun sumCashStates(states: Iterable>): Map> { - return states.mapNotNull { (it.state.data as? FungibleAsset)?.amount } - .groupBy { it.token.product } - .mapValues { it.value.map { Amount(it.quantity, it.token.product) }.sumOrThrow() } - } - - override val cashBalances: Map> get() { - val cashBalancesByCurrency = - session.withTransaction(TransactionIsolation.REPEATABLE_READ) { - val balances = select(VaultSchema.VaultCashBalances::class) - balances.get().toList() - } - return cashBalancesByCurrency.associateBy({ Currency.getInstance(it.currency) }, - { Amount(it.amount, Currency.getInstance(it.currency)) }) - } - - override val rawUpdates: Observable - get() = mutex.locked { _rawUpdatesPublisher } - - override val updates: Observable - get() = mutex.locked { _updatesInDbTx } - - override val updatesPublisher: PublishSubject - get() = mutex.locked { _updatesPublisher } - - override fun track(): DataFeed, Vault.Update> { - return mutex.locked { - DataFeed(Vault(unconsumedStates()), _updatesPublisher.bufferUntilSubscribed().wrapWithDatabaseTransaction()) - } - } - - override fun states(clazzes: Set>, statuses: EnumSet, includeSoftLockedStates: Boolean): Iterable> { - val stateAndRefs = - session.withTransaction(TransactionIsolation.REPEATABLE_READ) { - val query = select(VaultSchema.VaultStates::class) - .where(VaultSchema.VaultStates::stateStatus `in` statuses) - // TODO: temporary fix to continue supporting track() function (until becomes Typed) - if (!clazzes.map { it.name }.contains(ContractState::class.java.name)) - query.and(VaultSchema.VaultStates::contractStateClassName `in` (clazzes.map { it.name })) - if (!includeSoftLockedStates) - query.and(VaultSchema.VaultStates::lockId.isNull()) - val iterator = query.get().iterator() - Sequence { iterator } - .map { it -> - val stateRef = StateRef(SecureHash.parse(it.txId), it.index) - val state = it.contractState.deserialize>(storageKryo()) - Vault.StateMetadata(stateRef, it.contractStateClassName, it.recordedTime, it.consumedTime, it.stateStatus, it.notaryName, it.notaryKey, it.lockId, it.lockUpdateTime) - StateAndRef(state, stateRef) - } - } - return stateAndRefs.asIterable() - } - - override fun statesForRefs(refs: List): Map?> { - val stateAndRefs = - session.withTransaction(TransactionIsolation.REPEATABLE_READ) { - var results: List> = emptyList() - refs.forEach { - val result = select(VaultSchema.VaultStates::class) - .where(VaultSchema.VaultStates::stateStatus eq Vault.StateStatus.UNCONSUMED) - .and(VaultSchema.VaultStates::txId eq it.txhash.toString()) - .and(VaultSchema.VaultStates::index eq it.index) - result.get()?.each { - val stateRef = StateRef(SecureHash.parse(it.txId), it.index) - val state = it.contractState.deserialize>(storageKryo()) - results += StateAndRef(state, stateRef) - } - } - results - } - - return stateAndRefs.associateBy({ it.ref }, { it.state }) - } - - override fun notifyAll(txns: Iterable) { + private fun notifyRegular(txns: Iterable) { val ourKeys = services.keyManagementService.keys - val netDelta = txns.fold(Vault.NoUpdate) { netDelta, txn -> netDelta + makeUpdate(txn, ourKeys) } - if (netDelta != Vault.NoUpdate) { - recordUpdate(netDelta) - maybeUpdateCashBalances(netDelta) + fun makeUpdate(tx: WireTransaction): Vault.Update { + val ourNewStates = tx.outputs. + filter { isRelevant(it.data, ourKeys) }. + map { tx.outRef(it.data) } + + // Retrieve all unconsumed states for this transaction's inputs + val consumedStates = loadStates(tx.inputs) + + // Is transaction irrelevant? + if (consumedStates.isEmpty() && ourNewStates.isEmpty()) { + log.trace { "tx ${tx.id} was irrelevant to this vault, ignoring" } + return Vault.NoUpdate + } + + return Vault.Update(consumedStates, ourNewStates.toHashSet()) + } + + val netDelta = txns.fold(Vault.NoUpdate) { netDelta, txn -> netDelta + makeUpdate(txn) } + processAndNotify(netDelta) + } + + private fun notifyNotaryChange(txns: Iterable) { + val ourKeys = services.keyManagementService.keys + fun makeUpdate(tx: NotaryChangeWireTransaction): Vault.Update { + // We need to resolve the full transaction here because outputs are calculated from inputs + // We also can't do filtering beforehand, since output encumbrance pointers get recalculated based on + // input positions + val ltx = tx.resolve(services, emptyList()) + + val (consumedStateAndRefs, producedStates) = ltx.inputs. + zip(ltx.outputs). + filter { + (_, output) -> + isRelevant(output.data, ourKeys) + }. + unzip() + + val producedStateAndRefs = producedStates.map { ltx.outRef(it.data) } + + if (consumedStateAndRefs.isEmpty() && producedStateAndRefs.isEmpty()) { + log.trace { "tx ${tx.id} was irrelevant to this vault, ignoring" } + return Vault.NoNotaryUpdate + } + + return Vault.Update(consumedStateAndRefs.toHashSet(), producedStateAndRefs.toHashSet(), null, Vault.UpdateType.NOTARY_CHANGE) + } + + val netDelta = txns.fold(Vault.NoNotaryUpdate) { netDelta, txn -> netDelta + makeUpdate(txn) } + processAndNotify(netDelta) + } + + private fun loadStates(refs: Collection): HashSet> { + val states = HashSet>() + if (refs.isNotEmpty()) { + session.withTransaction(transactionIsolationLevel) { + val result = select(VaultStatesEntity::class). + where(stateRefCompositeColumn.`in`(stateRefArgs(refs))). + and(VaultSchema.VaultStates::stateStatus eq Vault.StateStatus.UNCONSUMED) + result.get().forEach { + val txHash = SecureHash.parse(it.txId) + val index = it.index + val state = it.contractState.deserialize>(context = STORAGE_CONTEXT) + states.add(StateAndRef(state, StateRef(txHash, index))) + } + } + } + return states + } + + private fun processAndNotify(update: Vault.Update) { + if (!update.isEmpty()) { + recordUpdate(update) mutex.locked { // flowId required by SoftLockManager to perform auto-registration of soft locks for new states val uuid = (Strand.currentStrand() as? FlowStateMachineImpl<*>)?.id?.uuid - val vaultUpdate = if (uuid != null) netDelta.copy(flowId = uuid) else netDelta + val vaultUpdate = if (uuid != null) update.copy(flowId = uuid) else update updatesPublisher.onNext(vaultUpdate) } } } override fun addNoteToTransaction(txnId: SecureHash, noteText: String) { - session.withTransaction(TransactionIsolation.REPEATABLE_READ) { + session.withTransaction(transactionIsolationLevel) { val txnNoteEntity = VaultTxnNoteEntity() txnNoteEntity.txId = txnId.toString() txnNoteEntity.note = noteText @@ -255,52 +261,50 @@ class NodeVaultService(private val services: ServiceHub, dataSourceProperties: P } override fun getTransactionNotes(txnId: SecureHash): Iterable { - return session.withTransaction(TransactionIsolation.REPEATABLE_READ) { + return session.withTransaction(transactionIsolationLevel) { (select(VaultSchema.VaultTxnNote::class) where (VaultSchema.VaultTxnNote::txId eq txnId.toString())).get().asIterable().map { it.note } } } @Throws(StatesNotAvailableException::class) - override fun softLockReserve(lockId: UUID, stateRefs: Set) { - if (stateRefs.isNotEmpty()) { - val softLockTimestamp = services.clock.instant() - val stateRefArgs = stateRefArgs(stateRefs) - try { - session.withTransaction(TransactionIsolation.REPEATABLE_READ) { - val updatedRows = update(VaultStatesEntity::class) - .set(VaultStatesEntity.LOCK_ID, lockId.toString()) - .set(VaultStatesEntity.LOCK_UPDATE_TIME, softLockTimestamp) - .where(VaultStatesEntity.STATE_STATUS eq Vault.StateStatus.UNCONSUMED) - .and((VaultStatesEntity.LOCK_ID eq lockId.toString()) or (VaultStatesEntity.LOCK_ID.isNull())) + override fun softLockReserve(lockId: UUID, stateRefs: NonEmptySet) { + val softLockTimestamp = services.clock.instant() + val stateRefArgs = stateRefArgs(stateRefs) + try { + session.withTransaction(transactionIsolationLevel) { + val updatedRows = update(VaultStatesEntity::class) + .set(VaultStatesEntity.LOCK_ID, lockId.toString()) + .set(VaultStatesEntity.LOCK_UPDATE_TIME, softLockTimestamp) + .where(VaultStatesEntity.STATE_STATUS eq Vault.StateStatus.UNCONSUMED) + .and((VaultStatesEntity.LOCK_ID eq lockId.toString()) or (VaultStatesEntity.LOCK_ID.isNull())) + .and(stateRefCompositeColumn.`in`(stateRefArgs)).get().value() + if (updatedRows > 0 && updatedRows == stateRefs.size) { + log.trace("Reserving soft lock states for $lockId: $stateRefs") + FlowStateMachineImpl.currentStateMachine()?.hasSoftLockedStates = true + } else { + // revert partial soft locks + val revertUpdatedRows = update(VaultStatesEntity::class) + .set(VaultStatesEntity.LOCK_ID, null) + .where(VaultStatesEntity.LOCK_UPDATE_TIME eq softLockTimestamp) + .and(VaultStatesEntity.LOCK_ID eq lockId.toString()) .and(stateRefCompositeColumn.`in`(stateRefArgs)).get().value() - if (updatedRows > 0 && updatedRows == stateRefs.size) { - log.trace("Reserving soft lock states for $lockId: $stateRefs") - FlowStateMachineImpl.currentStateMachine()?.hasSoftLockedStates = true - } else { - // revert partial soft locks - val revertUpdatedRows = update(VaultStatesEntity::class) - .set(VaultStatesEntity.LOCK_ID, null) - .where(VaultStatesEntity.LOCK_UPDATE_TIME eq softLockTimestamp) - .and(VaultStatesEntity.LOCK_ID eq lockId.toString()) - .and(stateRefCompositeColumn.`in`(stateRefArgs)).get().value() - if (revertUpdatedRows > 0) { - log.trace("Reverting $revertUpdatedRows partially soft locked states for $lockId") - } - throw StatesNotAvailableException("Attempted to reserve $stateRefs for $lockId but only $updatedRows rows available") + if (revertUpdatedRows > 0) { + log.trace("Reverting $revertUpdatedRows partially soft locked states for $lockId") } + throw StatesNotAvailableException("Attempted to reserve $stateRefs for $lockId but only $updatedRows rows available") } - } catch (e: PersistenceException) { - log.error("""soft lock update error attempting to reserve states for $lockId and $stateRefs") + } + } catch (e: PersistenceException) { + log.error("""soft lock update error attempting to reserve states for $lockId and $stateRefs") $e. """) - if (e.cause is StatesNotAvailableException) throw (e.cause as StatesNotAvailableException) - } + if (e.cause is StatesNotAvailableException) throw (e.cause as StatesNotAvailableException) } } - override fun softLockRelease(lockId: UUID, stateRefs: Set?) { + override fun softLockRelease(lockId: UUID, stateRefs: NonEmptySet?) { if (stateRefs == null) { - session.withTransaction(TransactionIsolation.REPEATABLE_READ) { + session.withTransaction(transactionIsolationLevel) { val update = update(VaultStatesEntity::class) .set(VaultStatesEntity.LOCK_ID, null) .set(VaultStatesEntity.LOCK_UPDATE_TIME, services.clock.instant()) @@ -310,9 +314,9 @@ class NodeVaultService(private val services: ServiceHub, dataSourceProperties: P log.trace("Releasing ${update.value()} soft locked states for $lockId") } } - } else if (stateRefs.isNotEmpty()) { + } else { try { - session.withTransaction(TransactionIsolation.REPEATABLE_READ) { + session.withTransaction(transactionIsolationLevel) { val updatedRows = update(VaultStatesEntity::class) .set(VaultStatesEntity.LOCK_ID, null) .set(VaultStatesEntity.LOCK_UPDATE_TIME, services.clock.instant()) @@ -331,173 +335,109 @@ class NodeVaultService(private val services: ServiceHub, dataSourceProperties: P } } - // coin selection retry loop counter, sleep (msecs) and lock for selecting states - val MAX_RETRIES = 5 - val RETRY_SLEEP = 100 - val spendLock: ReentrantLock = ReentrantLock() + // TODO We shouldn't need to rewrite the query if we could modify the defaults. + private class QueryEditor(val services: ServiceHub, + val lockId: UUID, + val contractType: Class) : IQueryCriteriaParser { + var alreadyHasVaultQuery: Boolean = false + var modifiedCriteria: QueryCriteria = QueryCriteria.VaultQueryCriteria(contractStateTypes = setOf(contractType), + softLockingCondition = QueryCriteria.SoftLockingCondition(QueryCriteria.SoftLockingType.UNLOCKED_AND_SPECIFIED, listOf(lockId)), + status = Vault.StateStatus.UNCONSUMED) + + override fun parseCriteria(criteria: QueryCriteria.CommonQueryCriteria): Collection { + modifiedCriteria = criteria + return emptyList() + } + + override fun parseCriteria(criteria: QueryCriteria.FungibleAssetQueryCriteria): Collection { + modifiedCriteria = criteria + return emptyList() + } + + override fun parseCriteria(criteria: QueryCriteria.LinearStateQueryCriteria): Collection { + modifiedCriteria = criteria + return emptyList() + } + + override fun parseCriteria(criteria: QueryCriteria.VaultCustomQueryCriteria): Collection { + modifiedCriteria = criteria + return emptyList() + } + + override fun parseCriteria(criteria: QueryCriteria.VaultQueryCriteria): Collection { + modifiedCriteria = criteria.copy(contractStateTypes = setOf(contractType), + softLockingCondition = QueryCriteria.SoftLockingCondition(QueryCriteria.SoftLockingType.UNLOCKED_AND_SPECIFIED, listOf(lockId)), + status = Vault.StateStatus.UNCONSUMED) + alreadyHasVaultQuery = true + return emptyList() + } + + override fun parseOr(left: QueryCriteria, right: QueryCriteria): Collection { + parse(left) + val modifiedLeft = modifiedCriteria + parse(right) + val modifiedRight = modifiedCriteria + modifiedCriteria = modifiedLeft.or(modifiedRight) + return emptyList() + } + + override fun parseAnd(left: QueryCriteria, right: QueryCriteria): Collection { + parse(left) + val modifiedLeft = modifiedCriteria + parse(right) + val modifiedRight = modifiedCriteria + modifiedCriteria = modifiedLeft.and(modifiedRight) + return emptyList() + } + + override fun parse(criteria: QueryCriteria, sorting: Sort?): Collection { + val basicQuery = modifiedCriteria + criteria.visit(this) + modifiedCriteria = if (alreadyHasVaultQuery) modifiedCriteria else criteria.and(basicQuery) + return emptyList() + } + + fun queryForEligibleStates(criteria: QueryCriteria): Vault.Page { + val sortAttribute = SortAttribute.Standard(Sort.CommonStateAttribute.STATE_REF) + val sorter = Sort(setOf(Sort.SortColumn(sortAttribute, Sort.Direction.ASC))) + parse(criteria, sorter) + + return services.vaultQueryService.queryBy(contractType, modifiedCriteria, sorter) + } + } + @Suspendable - override fun unconsumedStatesForSpending(amount: Amount, onlyFromIssuerParties: Set?, notary: Party?, lockId: UUID, withIssuerRefs: Set?): List> { - - val issuerKeysStr = onlyFromIssuerParties?.fold("") { left, right -> left + "('${right.owningKey.toBase58String()}')," }?.dropLast(1) - val issuerRefsStr = withIssuerRefs?.fold("") { left, right -> left + "('${right.bytes.toHexString()}')," }?.dropLast(1) - - val stateAndRefs = mutableListOf>() - - // TODO: Need to provide a database provider independent means of performing this function. - // We are using an H2 specific means of selecting a minimum set of rows that match a request amount of coins: - // 1) There is no standard SQL mechanism of calculating a cumulative total on a field and restricting row selection on the - // running total of such an accumulator - // 2) H2 uses session variables to perform this accumulator function: - // http://www.h2database.com/html/functions.html#set - // 3) H2 does not support JOIN's in FOR UPDATE (hence we are forced to execute 2 queries) - - for (retryCount in 1..MAX_RETRIES) { - - spendLock.withLock { - val statement = configuration.jdbcSession().createStatement() - try { - statement.execute("CALL SET(@t, 0);") - - // we select spendable states irrespective of lock but prioritised by unlocked ones (Eg. null) - // the softLockReserve update will detect whether we try to lock states locked by others - val selectJoin = """ - SELECT vs.transaction_id, vs.output_index, vs.contract_state, ccs.pennies, SET(@t, ifnull(@t,0)+ccs.pennies) total_pennies, vs.lock_id - FROM vault_states AS vs, contract_cash_states AS ccs - WHERE vs.transaction_id = ccs.transaction_id AND vs.output_index = ccs.output_index - AND vs.state_status = 0 - AND ccs.ccy_code = '${amount.token}' and @t < ${amount.quantity} - AND (vs.lock_id = '$lockId' OR vs.lock_id is null) - """ + - (if (notary != null) - " AND vs.notary_key = '${notary.owningKey.toBase58String()}'" else "") + - (if (issuerKeysStr != null) - " AND ccs.issuer_key IN ($issuerKeysStr)" else "") + - (if (issuerRefsStr != null) - " AND ccs.issuer_ref IN ($issuerRefsStr)" else "") - - // Retrieve spendable state refs - val rs = statement.executeQuery(selectJoin) - stateAndRefs.clear() - log.debug(selectJoin) - var totalPennies = 0L - while (rs.next()) { - val txHash = SecureHash.parse(rs.getString(1)) - val index = rs.getInt(2) - val stateRef = StateRef(txHash, index) - val state = rs.getBytes(3).deserialize>(storageKryo()) - val pennies = rs.getLong(4) - totalPennies = rs.getLong(5) - val rowLockId = rs.getString(6) - stateAndRefs.add(StateAndRef(state, stateRef)) - log.trace { "ROW: $rowLockId ($lockId): $stateRef : $pennies ($totalPennies)" } - } - - if (stateAndRefs.isNotEmpty() && totalPennies >= amount.quantity) { - // we should have a minimum number of states to satisfy our selection `amount` criteria - log.trace("Coin selection for $amount retrieved ${stateAndRefs.count()} states totalling $totalPennies pennies: $stateAndRefs") - - // update database - softLockReserve(lockId, stateAndRefs.map { it.ref }.toSet()) - return stateAndRefs - } - log.trace("Coin selection requested $amount but retrieved $totalPennies pennies with state refs: ${stateAndRefs.map { it.ref }}") - // retry as more states may become available - } catch (e: SQLException) { - log.error("""Failed retrieving unconsumed states for: amount [$amount], onlyFromIssuerParties [$onlyFromIssuerParties], notary [$notary], lockId [$lockId] - $e. - """) - } catch (e: StatesNotAvailableException) { - stateAndRefs.clear() - log.warn(e.message) - // retry only if there are locked states that may become available again (or consumed with change) - } finally { - statement.close() - } - } - - log.warn("Coin selection failed on attempt $retryCount") - // TODO: revisit the back off strategy for contended spending. - if (retryCount != MAX_RETRIES) { - FlowStateMachineImpl.sleep(RETRY_SLEEP * retryCount.toLong()) - } + @Throws(StatesNotAvailableException::class) + override fun , U : Any> tryLockFungibleStatesForSpending(lockId: UUID, + eligibleStatesQuery: QueryCriteria, + amount: Amount, + contractType: Class): List> { + if (amount.quantity == 0L) { + return emptyList() } - log.warn("Insufficient spendable states identified for $amount") - return stateAndRefs - } + // TODO This helper code re-writes the query to alter the defaults on things such as soft locks + // and then runs the query. Ideally we would not need to do this. + val results = QueryEditor(services, lockId, contractType).queryForEligibleStates(eligibleStatesQuery) - override fun softLockedStates(lockId: UUID?): List> { - val stateAndRefs = - session.withTransaction(TransactionIsolation.REPEATABLE_READ) { - val query = select(VaultSchema.VaultStates::class) - .where(VaultSchema.VaultStates::stateStatus eq Vault.StateStatus.UNCONSUMED) - .and(VaultSchema.VaultStates::contractStateClassName eq Cash.State::class.java.name) - if (lockId != null) - query.and(VaultSchema.VaultStates::lockId eq lockId) - else - query.and(VaultSchema.VaultStates::lockId.notNull()) - query.get() - .map { it -> - val stateRef = StateRef(SecureHash.parse(it.txId), it.index) - val state = it.contractState.deserialize>(storageKryo()) - StateAndRef(state, stateRef) - }.toList() - } - return stateAndRefs - } - - /** - * Generate a transaction that moves an amount of currency to the given pubkey. - * - * @param onlyFromParties if non-null, the asset states will be filtered to only include those issued by the set - * of given parties. This can be useful if the party you're trying to pay has expectations - * about which type of asset claims they are willing to accept. - */ - @Suspendable - override fun generateSpend(tx: TransactionBuilder, - amount: Amount, - to: AbstractParty, - onlyFromParties: Set?): Pair> { - // Retrieve unspent and unlocked cash states that meet our spending criteria. - val acceptableCoins = unconsumedStatesForSpending(amount, onlyFromParties, tx.notary, tx.lockId) - return OnLedgerAsset.generateSpend(tx, amount, to, acceptableCoins, - { state, amount, owner -> deriveState(state, amount, owner) }, - { Cash().generateMoveCommand() }) - } - - private fun deriveState(txState: TransactionState, amount: Amount>, owner: AbstractParty) - = txState.copy(data = txState.data.copy(amount = amount, owner = owner)) - - @VisibleForTesting - internal fun makeUpdate(tx: WireTransaction, ourKeys: Set): Vault.Update { - val ourNewStates = tx.outputs. - filter { isRelevant(it.data, ourKeys) }. - map { tx.outRef(it.data) } - - // Retrieve all unconsumed states for this transaction's inputs - val consumedStates = HashSet>() - if (tx.inputs.isNotEmpty()) { - session.withTransaction(TransactionIsolation.REPEATABLE_READ) { - val result = select(VaultStatesEntity::class). - where(stateRefCompositeColumn.`in`(stateRefArgs(tx.inputs))). - and(VaultSchema.VaultStates::stateStatus eq Vault.StateStatus.UNCONSUMED) - result.get().forEach { - val txHash = SecureHash.parse(it.txId) - val index = it.index - val state = it.contractState.deserialize>(storageKryo()) - consumedStates.add(StateAndRef(state, StateRef(txHash, index))) + var claimedAmount = 0L + val claimedStates = mutableListOf>() + for (state in results.states) { + val issuedAssetToken = state.state.data.amount.token + if (issuedAssetToken.product == amount.token) { + claimedStates += state + claimedAmount += state.state.data.amount.quantity + if (claimedAmount > amount.quantity) { + break } } } - - // Is transaction irrelevant? - if (consumedStates.isEmpty() && ourNewStates.isEmpty()) { - log.trace { "tx ${tx.id} was irrelevant to this vault, ignoring" } - return Vault.NoUpdate + if (claimedStates.isEmpty() || claimedAmount < amount.quantity) { + return emptyList() } - - return Vault.Update(consumedStates, ourNewStates.toHashSet()) + softLockReserve(lockId, claimedStates.map { it.ref }.toNonEmptySet()) + return claimedStates } // TODO : Persists this in DB. diff --git a/node/src/main/kotlin/net/corda/node/services/vault/VaultSchema.kt b/node/src/main/kotlin/net/corda/node/services/vault/VaultSchema.kt index 5ef516971a..9ead62d2c9 100644 --- a/node/src/main/kotlin/net/corda/node/services/vault/VaultSchema.kt +++ b/node/src/main/kotlin/net/corda/node/services/vault/VaultSchema.kt @@ -6,6 +6,7 @@ import net.corda.core.node.services.Vault import net.corda.core.schemas.CommonSchemaV1 import net.corda.core.schemas.MappedSchema import net.corda.core.schemas.PersistentState +import net.corda.core.serialization.CordaSerializable import net.corda.core.utilities.OpaqueBytes import java.time.Instant import java.util.* @@ -19,6 +20,7 @@ object VaultSchema /** * First version of the Vault ORM schema */ +@CordaSerializable object VaultSchemaV1 : MappedSchema(schemaFamily = VaultSchema.javaClass, version = 1, mappedTypes = listOf(VaultStates::class.java, VaultLinearStates::class.java, VaultFungibleStates::class.java, CommonSchemaV1.Party::class.java)) { @Entity @@ -66,8 +68,7 @@ object VaultSchemaV1 : MappedSchema(schemaFamily = VaultSchema.javaClass, versio @Entity @Table(name = "vault_linear_states", indexes = arrayOf(Index(name = "external_id_index", columnList = "external_id"), - Index(name = "uuid_index", columnList = "uuid"), - Index(name = "deal_reference_index", columnList = "deal_reference"))) + Index(name = "uuid_index", columnList = "uuid"))) class VaultLinearStates( /** [ContractState] attributes */ @OneToMany(cascade = arrayOf(CascadeType.ALL)) @@ -80,18 +81,11 @@ object VaultSchemaV1 : MappedSchema(schemaFamily = VaultSchema.javaClass, versio var externalId: String?, @Column(name = "uuid", nullable = false) - var uuid: UUID, - - // TODO: DealState to be deprecated (collapsed into LinearState) - - /** Deal State attributes **/ - @Column(name = "deal_reference") - var dealReference: String + var uuid: UUID ) : PersistentState() { - constructor(uid: UniqueIdentifier, _dealReference: String, _participants: List) : + constructor(uid: UniqueIdentifier, _participants: List) : this(externalId = uid.externalId, uuid = uid.id, - dealReference = _dealReference, participants = _participants.map{ CommonSchemaV1.Party(it) }.toSet() ) } @@ -103,8 +97,8 @@ object VaultSchemaV1 : MappedSchema(schemaFamily = VaultSchema.javaClass, versio var participants: Set, /** [OwnableState] attributes */ - @OneToOne(cascade = arrayOf(CascadeType.ALL)) - var owner: CommonSchemaV1.Party, + @Column(name = "owner_id") + var owner: AbstractParty, /** [FungibleAsset] attributes * @@ -124,7 +118,7 @@ object VaultSchemaV1 : MappedSchema(schemaFamily = VaultSchema.javaClass, versio var issuerRef: ByteArray ) : PersistentState() { constructor(_owner: AbstractParty, _quantity: Long, _issuerParty: AbstractParty, _issuerRef: OpaqueBytes, _participants: List) : - this(owner = CommonSchemaV1.Party(_owner), + this(owner = _owner, quantity = _quantity, issuerParty = CommonSchemaV1.Party(_issuerParty), issuerRef = _issuerRef.bytes, diff --git a/node/src/main/kotlin/net/corda/node/services/vault/VaultSoftLockManager.kt b/node/src/main/kotlin/net/corda/node/services/vault/VaultSoftLockManager.kt index 475c4f900a..2d063a1468 100644 --- a/node/src/main/kotlin/net/corda/node/services/vault/VaultSoftLockManager.kt +++ b/node/src/main/kotlin/net/corda/node/services/vault/VaultSoftLockManager.kt @@ -4,7 +4,9 @@ import net.corda.core.contracts.StateRef import net.corda.core.flows.FlowLogic import net.corda.core.flows.StateMachineRunId import net.corda.core.node.services.VaultService +import net.corda.core.utilities.NonEmptySet import net.corda.core.utilities.loggerFor +import net.corda.core.utilities.toNonEmptySet import net.corda.core.utilities.trace import net.corda.node.services.statemachine.FlowStateMachineImpl import net.corda.node.services.statemachine.StateMachineManager @@ -36,18 +38,18 @@ class VaultSoftLockManager(val vault: VaultService, smm: StateMachineManager) { // However, the lock can be programmatically released, like any other soft lock, // should we want a long running flow that creates a visible state mid way through. - vault.rawUpdates.subscribe { update -> - update.flowId?.let { - if (update.produced.isNotEmpty()) { - registerSoftLocks(update.flowId as UUID, update.produced.map { it.ref }) + vault.rawUpdates.subscribe { (_, produced, flowId) -> + flowId?.let { + if (produced.isNotEmpty()) { + registerSoftLocks(flowId, (produced.map { it.ref }).toNonEmptySet()) } } } } - private fun registerSoftLocks(flowId: UUID, stateRefs: List) { + private fun registerSoftLocks(flowId: UUID, stateRefs: NonEmptySet) { log.trace("Reserving soft locks for flow id $flowId and states $stateRefs") - vault.softLockReserve(flowId, stateRefs.toSet()) + vault.softLockReserve(flowId, stateRefs) } private fun unregisterSoftLocks(id: StateMachineRunId, logic: FlowLogic<*>) { diff --git a/node/src/main/kotlin/net/corda/node/shell/FlowWatchPrintingSubscriber.kt b/node/src/main/kotlin/net/corda/node/shell/FlowWatchPrintingSubscriber.kt index f249caa619..7211d06d9b 100644 --- a/node/src/main/kotlin/net/corda/node/shell/FlowWatchPrintingSubscriber.kt +++ b/node/src/main/kotlin/net/corda/node/shell/FlowWatchPrintingSubscriber.kt @@ -1,13 +1,12 @@ package net.corda.node.shell -import com.google.common.util.concurrent.SettableFuture import net.corda.core.crypto.commonName import net.corda.core.flows.FlowInitiator import net.corda.core.flows.StateMachineRunId +import net.corda.core.internal.concurrent.openFuture import net.corda.core.messaging.StateMachineUpdate import net.corda.core.messaging.StateMachineUpdate.Added import net.corda.core.messaging.StateMachineUpdate.Removed -import net.corda.core.then import net.corda.core.transactions.SignedTransaction import net.corda.core.utilities.Try import org.crsh.text.Color @@ -22,7 +21,7 @@ import rx.Subscriber class FlowWatchPrintingSubscriber(private val toStream: RenderPrintWriter) : Subscriber() { private val indexMap = HashMap() private val table = createStateMachinesTable() - val future: SettableFuture = SettableFuture.create() + val future = openFuture() init { // The future is public and can be completed by something else to indicate we don't wish to follow diff --git a/node/src/main/kotlin/net/corda/node/shell/InteractiveShell.kt b/node/src/main/kotlin/net/corda/node/shell/InteractiveShell.kt index 553890b946..37386f42f0 100644 --- a/node/src/main/kotlin/net/corda/node/shell/InteractiveShell.kt +++ b/node/src/main/kotlin/net/corda/node/shell/InteractiveShell.kt @@ -7,15 +7,19 @@ import com.fasterxml.jackson.databind.* import com.fasterxml.jackson.databind.module.SimpleModule import com.fasterxml.jackson.dataformat.yaml.YAMLFactory import com.google.common.io.Closeables -import com.google.common.util.concurrent.ListenableFuture -import com.google.common.util.concurrent.SettableFuture -import net.corda.core.* +import net.corda.core.concurrent.CordaFuture +import net.corda.core.contracts.UniqueIdentifier import net.corda.core.flows.FlowInitiator import net.corda.core.flows.FlowLogic import net.corda.core.internal.FlowStateMachine +import net.corda.core.internal.concurrent.OpenFuture +import net.corda.core.internal.concurrent.openFuture +import net.corda.core.internal.createDirectories +import net.corda.core.internal.div +import net.corda.core.internal.write +import net.corda.core.internal.* import net.corda.core.messaging.CordaRPCOps import net.corda.core.messaging.StateMachineUpdate -import net.corda.core.utilities.Emoji import net.corda.core.utilities.loggerFor import net.corda.jackson.JacksonSupport import net.corda.jackson.StringToMethodCallParser @@ -48,7 +52,6 @@ import org.crsh.vfs.spi.url.ClassPathMountFactory import rx.Observable import rx.Subscriber import java.io.* -import java.lang.reflect.Constructor import java.lang.reflect.InvocationTargetException import java.nio.file.Files import java.nio.file.Path @@ -184,6 +187,8 @@ object InteractiveShell { JacksonSupport.createInMemoryMapper(node.services.identityService, YAMLFactory(), true).apply { val rpcModule = SimpleModule() rpcModule.addDeserializer(InputStream::class.java, InputStreamDeserializer) + rpcModule.addDeserializer(UniqueIdentifier::class.java, UniqueIdentifierDeserializer) + rpcModule.addDeserializer(UUID::class.java, UUIDDeserializer) registerModule(rpcModule) } } @@ -274,25 +279,25 @@ object InteractiveShell { val errors = ArrayList() for (ctor in clazz.constructors) { var paramNamesFromConstructor: List? = null - fun getPrototype(ctor: Constructor<*>): List { + fun getPrototype(): List { val argTypes = ctor.parameterTypes.map { it.simpleName } - val prototype = paramNamesFromConstructor!!.zip(argTypes).map { (name, type) -> "$name: $type" } - return prototype + return paramNamesFromConstructor!!.zip(argTypes).map { (name, type) -> "$name: $type" } } + try { // Attempt construction with the given arguments. paramNamesFromConstructor = parser.paramNamesFromConstructor(ctor) val args = parser.parseArguments(clazz.name, paramNamesFromConstructor.zip(ctor.parameterTypes), inputData) if (args.size != ctor.parameterTypes.size) { - errors.add("${getPrototype(ctor)}: Wrong number of arguments (${args.size} provided, ${ctor.parameterTypes.size} needed)") + errors.add("${getPrototype()}: Wrong number of arguments (${args.size} provided, ${ctor.parameterTypes.size} needed)") continue } val flow = ctor.newInstance(*args) as FlowLogic<*> return invoke(flow) } catch(e: StringToMethodCallParser.UnparseableCallException.MissingParameter) { - errors.add("${getPrototype(ctor)}: missing parameter ${e.paramName}") + errors.add("${getPrototype()}: missing parameter ${e.paramName}") } catch(e: StringToMethodCallParser.UnparseableCallException.TooManyParameters) { - errors.add("${getPrototype(ctor)}: too many parameters") + errors.add("${getPrototype()}: too many parameters") } catch(e: StringToMethodCallParser.UnparseableCallException.ReflectionDataMissing) { val argTypes = ctor.parameterTypes.map { it.simpleName } errors.add("$argTypes: ") @@ -308,7 +313,7 @@ object InteractiveShell { @JvmStatic fun runStateMachinesView(out: RenderPrintWriter): Any? { val proxy = node.rpcOps - val (stateMachines, stateMachineUpdates) = proxy.stateMachinesAndUpdates() + val (stateMachines, stateMachineUpdates) = proxy.stateMachinesFeed() val currentStateMachines = stateMachines.map { StateMachineUpdate.Added(it) } val subscriber = FlowWatchPrintingSubscriber(out) stateMachineUpdates.startWith(currentStateMachines).subscribe(subscriber) @@ -380,7 +385,7 @@ object InteractiveShell { return result } - private fun printAndFollowRPCResponse(response: Any?, toStream: PrintWriter): ListenableFuture? { + private fun printAndFollowRPCResponse(response: Any?, toStream: PrintWriter): CordaFuture? { val printerFun = { obj: Any? -> yamlMapper.writeValueAsString(obj) } toStream.println(printerFun(response)) toStream.flush() @@ -389,7 +394,7 @@ object InteractiveShell { private class PrintingSubscriber(private val printerFun: (Any?) -> String, private val toStream: PrintWriter) : Subscriber() { private var count = 0 - val future: SettableFuture = SettableFuture.create() + val future = openFuture() init { // The future is public and can be completed by something else to indicate we don't wish to follow @@ -420,7 +425,7 @@ object InteractiveShell { // Kotlin bug: USELESS_CAST warning is generated below but the IDE won't let us remove it. @Suppress("USELESS_CAST", "UNCHECKED_CAST") - private fun maybeFollow(response: Any?, printerFun: (Any?) -> String, toStream: PrintWriter): SettableFuture? { + private fun maybeFollow(response: Any?, printerFun: (Any?) -> String, toStream: PrintWriter): OpenFuture? { // Match on a couple of common patterns for "important" observables. It's tough to do this in a generic // way because observables can be embedded anywhere in the object graph, and can emit other arbitrary // object graphs that contain yet more observables. So we just look for top level responses that follow @@ -499,5 +504,37 @@ object InteractiveShell { } } + /** + * String value deserialized to [UniqueIdentifier]. + * Any string value used as [UniqueIdentifier.externalId]. + * If string contains underscore(i.e. externalId_uuid) then split with it. + * Index 0 as [UniqueIdentifier.externalId] + * Index 1 as [UniqueIdentifier.id] + * */ + object UniqueIdentifierDeserializer : JsonDeserializer() { + override fun deserialize(p: JsonParser, ctxt: DeserializationContext): UniqueIdentifier { + //Check if externalId and UUID may be separated by underscore. + if (p.text.contains("_")) { + val ids = p.text.split("_") + //Create UUID object from string. + val uuid: UUID = UUID.fromString(ids[1]) + //Create UniqueIdentifier object using externalId and UUID. + return UniqueIdentifier(ids[0], uuid) + } + //Any other string used as externalId. + return UniqueIdentifier(p.text) + } + } + + /** + * String value deserialized to [UUID]. + * */ + object UUIDDeserializer : JsonDeserializer() { + override fun deserialize(p: JsonParser, ctxt: DeserializationContext): UUID { + //Create UUID object from string. + return UUID.fromString(p.text) + } + } + //endregion } diff --git a/node/src/main/kotlin/net/corda/node/utilities/ANSIProgressRenderer.kt b/node/src/main/kotlin/net/corda/node/utilities/ANSIProgressRenderer.kt index e8b6de4269..8339d0ae22 100644 --- a/node/src/main/kotlin/net/corda/node/utilities/ANSIProgressRenderer.kt +++ b/node/src/main/kotlin/net/corda/node/utilities/ANSIProgressRenderer.kt @@ -1,6 +1,6 @@ package net.corda.node.utilities -import net.corda.core.utilities.Emoji +import net.corda.core.internal.Emoji import net.corda.core.utilities.ProgressTracker import net.corda.node.utilities.ANSIProgressRenderer.progressTracker import org.apache.logging.log4j.LogManager diff --git a/node/src/main/kotlin/net/corda/node/utilities/AppendOnlyPersistentMap.kt b/node/src/main/kotlin/net/corda/node/utilities/AppendOnlyPersistentMap.kt new file mode 100644 index 0000000000..c9fb26b8b3 --- /dev/null +++ b/node/src/main/kotlin/net/corda/node/utilities/AppendOnlyPersistentMap.kt @@ -0,0 +1,113 @@ +package net.corda.node.utilities + +import net.corda.core.utilities.loggerFor +import java.util.* + + +/** + * Implements a caching layer on top of an *append-only* table accessed via Hibernate mapping. Note that if the same key is [put] twice the + * behaviour is unpredictable! There is a best-effort check for double inserts, but this should *not* be relied on, so + * ONLY USE THIS IF YOUR TABLE IS APPEND-ONLY + */ +class AppendOnlyPersistentMap ( + val toPersistentEntityKey: (K) -> EK, + val fromPersistentEntity: (E) -> Pair, + val toPersistentEntity: (key: K, value: V) -> E, + val persistentEntityClass: Class, + cacheBound: Long = 1024 +) { //TODO determine cacheBound based on entity class later or with node config allowing tuning, or using some heuristic based on heap size + + private companion object { + val log = loggerFor>() + } + + private val cache = NonInvalidatingCache>( + bound = cacheBound, + concurrencyLevel = 8, + loadFunction = { key -> Optional.ofNullable(loadValue(key)) } + ) + + /** + * Returns the value associated with the key, first loading that value from the storage if necessary. + */ + operator fun get(key: K): V? { + return cache.get(key).orElse(null) + } + + /** + * Returns all key/value pairs from the underlying storage. + */ + fun allPersisted(): Sequence> { + val criteriaQuery = DatabaseTransactionManager.current().session.criteriaBuilder.createQuery(persistentEntityClass) + val root = criteriaQuery.from(persistentEntityClass) + criteriaQuery.select(root) + val query = DatabaseTransactionManager.current().session.createQuery(criteriaQuery) + val result = query.resultList + return result.map { x -> fromPersistentEntity(x) }.asSequence() + } + + private tailrec fun set(key: K, value: V, logWarning: Boolean = true, store: (K,V) -> V?): Boolean { + var insertionAttempt = false + var isUnique = true + val existingInCache = cache.get(key) { // Thread safe, if multiple threads may wait until the first one has loaded. + insertionAttempt = true + // Key wasn't in the cache and might be in the underlying storage. + // Depending on 'store' method, this may insert without checking key duplication or it may avoid inserting a duplicated key. + val existingInDb = store(key, value) + if (existingInDb != null) { // Always reuse an existing value from the storage of a duplicated key. + Optional.of(existingInDb) + } else { + Optional.of(value) + } + } + if (!insertionAttempt) { + if (existingInCache.isPresent) { + // Key already exists in cache, do nothing. + isUnique = false + } else { + // This happens when the key was queried before with no value associated. We invalidate the cached null + // value and recursively call set again. This is to avoid race conditions where another thread queries after + // the invalidate but before the set. + cache.invalidate(key) + return set(key, value, logWarning, store) + } + } + if (logWarning && !isUnique) { + log.warn("Double insert in ${this.javaClass.name} for entity class $persistentEntityClass key $key, not inserting the second time") + } + return isUnique + } + + /** + * Puts the value into the map and the underlying storage. + * Inserting the duplicated key may be unpredictable. + */ + operator fun set(key: K, value: V) = + set(key, value, logWarning = false) { + key,value -> DatabaseTransactionManager.current().session.save(toPersistentEntity(key,value)) + null + } + + /** + * Puts the value into the map and underlying storage. + * Duplicated key is not added into the map and underlying storage. + * @return true if added key was unique, otherwise false + */ + fun addWithDuplicatesAllowed(key: K, value: V): Boolean = + set(key, value) { + key, value -> + val existingEntry = DatabaseTransactionManager.current().session.find(persistentEntityClass, toPersistentEntityKey(key)) + if (existingEntry == null) { + DatabaseTransactionManager.current().session.save(toPersistentEntity(key,value)) + null + } else { + fromPersistentEntity(existingEntry).second + } + } + + private fun loadValue(key: K): V? { + val result = DatabaseTransactionManager.current().session.find(persistentEntityClass, toPersistentEntityKey(key)) + return result?.let(fromPersistentEntity)?.second + } + +} diff --git a/node/src/main/kotlin/net/corda/node/utilities/ClockUtils.kt b/node/src/main/kotlin/net/corda/node/utilities/ClockUtils.kt index 3a36904c31..44f3d0ec2c 100644 --- a/node/src/main/kotlin/net/corda/node/utilities/ClockUtils.kt +++ b/node/src/main/kotlin/net/corda/node/utilities/ClockUtils.kt @@ -3,16 +3,14 @@ package net.corda.node.utilities import co.paralleluniverse.fibers.Suspendable import co.paralleluniverse.strands.SettableFuture import com.google.common.util.concurrent.ListenableFuture -import net.corda.core.then +import net.corda.core.internal.until import rx.Observable import rx.Subscriber import rx.subscriptions.Subscriptions import java.time.Clock -import java.time.Duration import java.time.Instant import java.util.concurrent.* import java.util.concurrent.atomic.AtomicLong -import java.util.function.BiConsumer import com.google.common.util.concurrent.SettableFuture as GuavaSettableFuture /** @@ -80,7 +78,7 @@ fun Clock.awaitWithDeadline(deadline: Instant, future: Future<*> = GuavaSettable } else { null } - nanos = Duration.between(this.instant(), deadline).toNanos() + nanos = (instant() until deadline).toNanos() if (nanos > 0) { try { // This will return when it times out, or when the clock mutates or when when the original future completes. @@ -106,17 +104,11 @@ fun Clock.awaitWithDeadline(deadline: Instant, future: Future<*> = GuavaSettable * We need this so that we do not block the actual thread when calling get(), but instead allow a Quasar context * switch. There's no need to checkpoint our Fibers as there's no external effect of waiting. */ -private fun makeStrandFriendlySettableFuture(future: Future): SettableFuture { - return if (future is ListenableFuture) { - val settable = SettableFuture() - future.then { settable.set(true) } - settable - } else if (future is CompletableFuture) { - val settable = SettableFuture() - future.whenComplete(BiConsumer { _, _ -> settable.set(true) }) - settable - } else { - throw IllegalArgumentException("Cannot make future $future Fiber friendly.") +private fun makeStrandFriendlySettableFuture(future: Future) = SettableFuture().also { g -> + when (future) { + is ListenableFuture -> future.addListener(Runnable { g.set(true) }, Executor { it.run() }) + is CompletionStage<*> -> future.whenComplete { _, _ -> g.set(true) } + else -> throw IllegalArgumentException("Cannot make future $future Fiber friendly.") } } diff --git a/node/src/main/kotlin/net/corda/node/utilities/CordaPersistence.kt b/node/src/main/kotlin/net/corda/node/utilities/CordaPersistence.kt new file mode 100644 index 0000000000..a29fee242e --- /dev/null +++ b/node/src/main/kotlin/net/corda/node/utilities/CordaPersistence.kt @@ -0,0 +1,215 @@ +package net.corda.node.utilities + +import com.zaxxer.hikari.HikariConfig +import com.zaxxer.hikari.HikariDataSource +import net.corda.core.node.services.IdentityService +import net.corda.core.schemas.MappedSchema +import net.corda.node.services.database.HibernateConfiguration +import net.corda.node.services.schema.NodeSchemaService +import org.hibernate.SessionFactory +import org.jetbrains.exposed.sql.Database + +import rx.Observable +import rx.Subscriber +import rx.subjects.UnicastSubject +import java.io.Closeable +import java.sql.Connection +import java.sql.SQLException +import java.util.* +import java.util.concurrent.CopyOnWriteArrayList + + +//HikariDataSource implements Closeable which allows CordaPersistence to be Closeable +class CordaPersistence(var dataSource: HikariDataSource, var nodeSchemaService: NodeSchemaService, val identitySvc: ()-> IdentityService, databaseProperties: Properties): Closeable { + + /** Holds Exposed database, the field will be removed once Exposed library is removed */ + lateinit var database: Database + var transactionIsolationLevel = parserTransactionIsolationLevel(databaseProperties.getProperty("transactionIsolationLevel")) + + val entityManagerFactory: SessionFactory by lazy(LazyThreadSafetyMode.NONE) { + transaction { + HibernateConfiguration(nodeSchemaService, databaseProperties, identitySvc).sessionFactoryForRegisteredSchemas() + } + } + + companion object { + fun connect(dataSource: HikariDataSource, nodeSchemaService: NodeSchemaService, identitySvc: () -> IdentityService, databaseProperties: Properties): CordaPersistence { + return CordaPersistence(dataSource, nodeSchemaService, identitySvc, databaseProperties).apply { + DatabaseTransactionManager(this) + } + } + } + + fun createTransaction(): DatabaseTransaction { + // We need to set the database for the current [Thread] or [Fiber] here as some tests share threads across databases. + DatabaseTransactionManager.dataSource = this + return DatabaseTransactionManager.currentOrNew(transactionIsolationLevel) + } + + fun createSession(): Connection { + // We need to set the database for the current [Thread] or [Fiber] here as some tests share threads across databases. + DatabaseTransactionManager.dataSource = this + val ctx = DatabaseTransactionManager.currentOrNull() + return ctx?.connection ?: throw IllegalStateException("Was expecting to find database transaction: must wrap calling code within a transaction.") + } + + fun transaction(statement: DatabaseTransaction.() -> T): T { + DatabaseTransactionManager.dataSource = this + return transaction(transactionIsolationLevel, 3, statement) + } + + private fun transaction(transactionIsolation: Int, repetitionAttempts: Int, statement: DatabaseTransaction.() -> T): T { + val outer = DatabaseTransactionManager.currentOrNull() + + return if (outer != null) { + outer.statement() + } + else { + inTopLevelTransaction(transactionIsolation, repetitionAttempts, statement) + } + } + + private fun inTopLevelTransaction(transactionIsolation: Int, repetitionAttempts: Int, statement: DatabaseTransaction.() -> T): T { + var repetitions = 0 + while (true) { + val transaction = DatabaseTransactionManager.currentOrNew(transactionIsolation) + try { + val answer = transaction.statement() + transaction.commit() + return answer + } + catch (e: SQLException) { + transaction.rollback() + repetitions++ + if (repetitions >= repetitionAttempts) { + throw e + } + } + catch (e: Throwable) { + transaction.rollback() + throw e + } + finally { + transaction.close() + } + } + } + + override fun close() { + dataSource.close() + } +} + +fun configureDatabase(dataSourceProperties: Properties, databaseProperties: Properties?, entitySchemas: Set = emptySet(), identitySvc: ()-> IdentityService): CordaPersistence { + val config = HikariConfig(dataSourceProperties) + val dataSource = HikariDataSource(config) + val persistence = CordaPersistence.connect(dataSource, NodeSchemaService(entitySchemas), identitySvc, databaseProperties ?: Properties()) + + //org.jetbrains.exposed.sql.Database will be removed once Exposed library is removed + val database = Database.connect(dataSource) { _ -> ExposedTransactionManager() } + persistence.database = database + + // Check not in read-only mode. + persistence.transaction { + persistence.dataSource.connection.use { + check(!it.metaData.isReadOnly) { "Database should not be readonly." } + } + } + return persistence +} + +/** + * Buffer observations until after the current database transaction has been closed. Observations are never + * dropped, simply delayed. + * + * Primarily for use by component authors to publish observations during database transactions without racing against + * closing the database transaction. + * + * For examples, see the call hierarchy of this function. + */ +fun rx.Observer.bufferUntilDatabaseCommit(): rx.Observer { + val currentTxId = DatabaseTransactionManager.transactionId + val databaseTxBoundary: Observable = DatabaseTransactionManager.transactionBoundaries.filter { it.txId == currentTxId }.first() + val subject = UnicastSubject.create() + subject.delaySubscription(databaseTxBoundary).subscribe(this) + databaseTxBoundary.doOnCompleted { subject.onCompleted() } + return subject +} + +// A subscriber that delegates to multiple others, wrapping a database transaction around the combination. +private class DatabaseTransactionWrappingSubscriber(val db: CordaPersistence?) : Subscriber() { + // Some unsubscribes happen inside onNext() so need something that supports concurrent modification. + val delegates = CopyOnWriteArrayList>() + + fun forEachSubscriberWithDbTx(block: Subscriber.() -> Unit) { + (db ?: DatabaseTransactionManager.dataSource).transaction { + delegates.filter { !it.isUnsubscribed }.forEach { + it.block() + } + } + } + + override fun onCompleted() = forEachSubscriberWithDbTx { onCompleted() } + + override fun onError(e: Throwable?) = forEachSubscriberWithDbTx { onError(e) } + + override fun onNext(s: U) = forEachSubscriberWithDbTx { onNext(s) } + + override fun onStart() = forEachSubscriberWithDbTx { onStart() } + + fun cleanUp() { + if (delegates.removeIf { it.isUnsubscribed }) { + if (delegates.isEmpty()) { + unsubscribe() + } + } + } +} + +// A subscriber that wraps another but does not pass on observations to it. +private class NoOpSubscriber(t: Subscriber): Subscriber(t) { + override fun onCompleted() { + } + + override fun onError(e: Throwable?) { + } + + override fun onNext(s: U) { + } +} + +/** + * Wrap delivery of observations in a database transaction. Multiple subscribers will receive the observations inside + * the same database transaction. This also lazily subscribes to the source [rx.Observable] to preserve any buffering + * that might be in place. + */ +fun rx.Observable.wrapWithDatabaseTransaction(db: CordaPersistence? = null): rx.Observable { + var wrappingSubscriber = DatabaseTransactionWrappingSubscriber(db) + // Use lift to add subscribers to a special subscriber that wraps a database transaction around observations. + // Each subscriber will be passed to this lambda when they subscribe, at which point we add them to wrapping subscriber. + return this.lift { toBeWrappedInDbTx: Subscriber -> + // Add the subscriber to the wrapping subscriber, which will invoke the original subscribers together inside a database transaction. + wrappingSubscriber.delegates.add(toBeWrappedInDbTx) + // If we are the first subscriber, return the shared subscriber, otherwise return a subscriber that does nothing. + if (wrappingSubscriber.delegates.size == 1) wrappingSubscriber else NoOpSubscriber(toBeWrappedInDbTx) + // Clean up the shared list of subscribers when they unsubscribe. + }.doOnUnsubscribe { + wrappingSubscriber.cleanUp() + // If cleanup removed the last subscriber reset the system, as future subscribers might need the stream again + if (wrappingSubscriber.delegates.isEmpty()) { + wrappingSubscriber = DatabaseTransactionWrappingSubscriber(db) + } + } +} + +fun parserTransactionIsolationLevel(property: String?): Int = + when (property) { + "none" -> Connection.TRANSACTION_NONE + "readUncommitted" -> Connection.TRANSACTION_READ_UNCOMMITTED + "readCommitted" -> Connection.TRANSACTION_READ_COMMITTED + "repeatableRead" -> Connection.TRANSACTION_REPEATABLE_READ + "serializable" -> Connection.TRANSACTION_SERIALIZABLE + else -> { + Connection.TRANSACTION_REPEATABLE_READ + } + } diff --git a/node/src/main/kotlin/net/corda/node/utilities/DatabaseSupport.kt b/node/src/main/kotlin/net/corda/node/utilities/DatabaseSupport.kt index 2c50bca43f..028117015c 100644 --- a/node/src/main/kotlin/net/corda/node/utilities/DatabaseSupport.kt +++ b/node/src/main/kotlin/net/corda/node/utilities/DatabaseSupport.kt @@ -1,289 +1,26 @@ package net.corda.node.utilities -import co.paralleluniverse.strands.Strand -import com.zaxxer.hikari.HikariConfig -import com.zaxxer.hikari.HikariDataSource import net.corda.core.crypto.SecureHash import net.corda.core.crypto.parsePublicKeyBase58 import net.corda.core.crypto.toBase58String -import net.corda.node.utilities.StrandLocalTransactionManager.Boundary import org.bouncycastle.cert.X509CertificateHolder import org.h2.jdbc.JdbcBlob import org.jetbrains.exposed.sql.* -import org.jetbrains.exposed.sql.transactions.TransactionInterface -import org.jetbrains.exposed.sql.transactions.TransactionManager -import rx.Observable -import rx.Subscriber -import rx.subjects.PublishSubject -import rx.subjects.Subject -import rx.subjects.UnicastSubject import java.io.ByteArrayInputStream -import java.io.Closeable import java.security.PublicKey import java.security.cert.CertPath import java.security.cert.CertificateFactory -import java.sql.Connection import java.time.Instant import java.time.LocalDate import java.time.LocalDateTime import java.time.ZoneOffset import java.util.* -import java.util.concurrent.ConcurrentHashMap -import java.util.concurrent.CopyOnWriteArrayList /** * Table prefix for all tables owned by the node module. */ const val NODE_DATABASE_PREFIX = "node_" -@Deprecated("Use Database.transaction instead.") -fun databaseTransaction(db: Database, statement: Transaction.() -> T) = db.transaction(statement) - -// TODO: Handle commit failure due to database unavailable. Better to shutdown and await database reconnect/recovery. -fun Database.transaction(statement: Transaction.() -> T): T { - // We need to set the database for the current [Thread] or [Fiber] here as some tests share threads across databases. - StrandLocalTransactionManager.database = this - return org.jetbrains.exposed.sql.transactions.transaction(Connection.TRANSACTION_REPEATABLE_READ, 1, statement) -} - -fun Database.createTransaction(): Transaction { - // We need to set the database for the current [Thread] or [Fiber] here as some tests share threads across databases. - StrandLocalTransactionManager.database = this - return TransactionManager.currentOrNew(Connection.TRANSACTION_REPEATABLE_READ) -} - -fun configureDatabase(props: Properties): Pair { - val config = HikariConfig(props) - val dataSource = HikariDataSource(config) - val database = Database.connect(dataSource) { db -> StrandLocalTransactionManager(db) } - // Check not in read-only mode. - database.transaction { - check(!database.metadata.isReadOnly) { "Database should not be readonly." } - } - return Pair(dataSource, database) -} - -fun Database.isolatedTransaction(block: Transaction.() -> T): T { - val oldContext = StrandLocalTransactionManager.setThreadLocalTx(null) - return try { - transaction(block) - } finally { - StrandLocalTransactionManager.restoreThreadLocalTx(oldContext) - } -} - -/** - * Helper method wrapping code in try finally block. A mutable list is used to keep track of functions that need to be executed in finally block. - */ -inline fun withFinalizables(statement: (MutableList<() -> Unit>) -> T): T { - val finalizables = mutableListOf<() -> Unit>() - return try { - statement(finalizables) - } finally { - finalizables.forEach { it() } - } -} - -/** - * A relatively close copy of the [org.jetbrains.exposed.sql.transactions.ThreadLocalTransactionManager] - * in Exposed but with the following adjustments to suit our environment: - * - * Because the construction of a [Database] instance results in replacing the singleton [TransactionManager] instance, - * our tests involving two [MockNode]s effectively replace the database instances of each other and continue to trample - * over each other. So here we use a companion object to hold them as [ThreadLocal] and [StrandLocalTransactionManager] - * is otherwise effectively stateless so it's replacement does not matter. The [ThreadLocal] is then set correctly and - * explicitly just prior to initiating a transaction in [transaction] and [createTransaction] above. - * - * The [StrandLocalTransactionManager] instances have an [Observable] of the transaction close [Boundary]s which - * facilitates the use of [Observable.afterDatabaseCommit] to create event streams that only emit once the database - * transaction is closed and the data has been persisted and becomes visible to other observers. - */ -class StrandLocalTransactionManager(initWithDatabase: Database) : TransactionManager { - - companion object { - private val TX_ID = Key() - - private val threadLocalDb = ThreadLocal() - private val threadLocalTx = ThreadLocal() - private val databaseToInstance = ConcurrentHashMap() - - fun setThreadLocalTx(tx: Transaction?): Pair { - val oldTx = threadLocalTx.get() - threadLocalTx.set(tx) - return Pair(threadLocalDb.get(), oldTx) - } - - fun restoreThreadLocalTx(context: Pair) { - threadLocalDb.set(context.first) - threadLocalTx.set(context.second) - } - - var database: Database - get() = threadLocalDb.get() ?: throw IllegalStateException("Was expecting to find database set on current strand: ${Strand.currentStrand()}") - set(value) { - threadLocalDb.set(value) - } - - val transactionId: UUID - get() = threadLocalTx.get()?.getUserData(TX_ID) ?: throw IllegalStateException("Was expecting to find transaction set on current strand: ${Strand.currentStrand()}") - - val manager: StrandLocalTransactionManager get() = databaseToInstance[database]!! - - val transactionBoundaries: Subject get() = manager._transactionBoundaries - } - - - data class Boundary(val txId: UUID) - - private val _transactionBoundaries = PublishSubject.create().toSerialized() - - init { - // Found a unit test that was forgetting to close the database transactions. When you close() on the top level - // database transaction it will reset the threadLocalTx back to null, so if it isn't then there is still a - // databae transaction open. The [transaction] helper above handles this in a finally clause for you - // but any manual database transaction management is liable to have this problem. - if (threadLocalTx.get() != null) { - throw IllegalStateException("Was not expecting to find existing database transaction on current strand when setting database: ${Strand.currentStrand()}, ${threadLocalTx.get()}") - } - database = initWithDatabase - databaseToInstance[database] = this - } - - override fun newTransaction(isolation: Int): Transaction { - val impl = StrandLocalTransaction(database, isolation, threadLocalTx, transactionBoundaries) - return Transaction(impl).apply { - threadLocalTx.set(this) - putUserData(TX_ID, impl.id) - } - } - - override fun currentOrNull(): Transaction? = threadLocalTx.get() - - // Direct copy of [ThreadLocalTransaction]. - private class StrandLocalTransaction(override val db: Database, isolation: Int, val threadLocal: ThreadLocal, val transactionBoundaries: Subject) : TransactionInterface { - val id = UUID.randomUUID() - - override val connection: Connection by lazy(LazyThreadSafetyMode.NONE) { - db.connector().apply { - autoCommit = false - transactionIsolation = isolation - } - } - - override val outerTransaction = threadLocal.get() - - override fun commit() { - connection.commit() - } - - override fun rollback() { - if (!connection.isClosed) { - connection.rollback() - } - } - - override fun close() { - connection.close() - threadLocal.set(outerTransaction) - if (outerTransaction == null) { - transactionBoundaries.onNext(Boundary(id)) - } - } - } -} - -/** - * Buffer observations until after the current database transaction has been closed. Observations are never - * dropped, simply delayed. - * - * Primarily for use by component authors to publish observations during database transactions without racing against - * closing the database transaction. - * - * For examples, see the call hierarchy of this function. - */ -fun rx.Observer.bufferUntilDatabaseCommit(): rx.Observer { - val currentTxId = StrandLocalTransactionManager.transactionId - val databaseTxBoundary: Observable = StrandLocalTransactionManager.transactionBoundaries.filter { it.txId == currentTxId }.first() - val subject = UnicastSubject.create() - subject.delaySubscription(databaseTxBoundary).subscribe(this) - databaseTxBoundary.doOnCompleted { subject.onCompleted() } - return subject -} - -// A subscriber that delegates to multiple others, wrapping a database transaction around the combination. -private class DatabaseTransactionWrappingSubscriber(val db: Database?) : Subscriber() { - // Some unsubscribes happen inside onNext() so need something that supports concurrent modification. - val delegates = CopyOnWriteArrayList>() - - fun forEachSubscriberWithDbTx(block: Subscriber.() -> Unit) { - (db ?: StrandLocalTransactionManager.database).transaction { - delegates.filter { !it.isUnsubscribed }.forEach { - it.block() - } - } - } - - override fun onCompleted() { - forEachSubscriberWithDbTx { onCompleted() } - } - - override fun onError(e: Throwable?) { - forEachSubscriberWithDbTx { onError(e) } - } - - override fun onNext(s: U) { - forEachSubscriberWithDbTx { onNext(s) } - } - - override fun onStart() { - forEachSubscriberWithDbTx { onStart() } - } - - fun cleanUp() { - if (delegates.removeIf { it.isUnsubscribed }) { - if (delegates.isEmpty()) { - unsubscribe() - } - } - } -} - -// A subscriber that wraps another but does not pass on observations to it. -private class NoOpSubscriber(t: Subscriber) : Subscriber(t) { - override fun onCompleted() { - } - - override fun onError(e: Throwable?) { - } - - override fun onNext(s: U) { - } -} - -/** - * Wrap delivery of observations in a database transaction. Multiple subscribers will receive the observations inside - * the same database transaction. This also lazily subscribes to the source [rx.Observable] to preserve any buffering - * that might be in place. - */ -fun rx.Observable.wrapWithDatabaseTransaction(db: Database? = null): rx.Observable { - var wrappingSubscriber = DatabaseTransactionWrappingSubscriber(db) - // Use lift to add subscribers to a special subscriber that wraps a database transaction around observations. - // Each subscriber will be passed to this lambda when they subscribe, at which point we add them to wrapping subscriber. - return this.lift { toBeWrappedInDbTx: Subscriber -> - // Add the subscriber to the wrapping subscriber, which will invoke the original subscribers together inside a database transaction. - wrappingSubscriber.delegates.add(toBeWrappedInDbTx) - // If we are the first subscriber, return the shared subscriber, otherwise return a subscriber that does nothing. - if (wrappingSubscriber.delegates.size == 1) wrappingSubscriber else NoOpSubscriber(toBeWrappedInDbTx) - // Clean up the shared list of subscribers when they unsubscribe. - }.doOnUnsubscribe { - wrappingSubscriber.cleanUp() - // If cleanup removed the last subscriber reset the system, as future subscribers might need the stream again - if (wrappingSubscriber.delegates.isEmpty()) { - wrappingSubscriber = DatabaseTransactionWrappingSubscriber(db) - } - } -} - // Composite columns for use with below Exposed helpers. data class PartyColumns(val name: Column, val owningKey: Column) data class PartyAndCertificateColumns(val name: Column, val owningKey: Column, @@ -297,7 +34,6 @@ data class TxnNoteColumns(val txId: Column, val note: Column fun Table.certificate(name: String) = this.registerColumn(name, X509CertificateColumnType) fun Table.certificatePath(name: String) = this.registerColumn(name, CertPathColumnType) fun Table.publicKey(name: String) = this.registerColumn(name, PublicKeyColumnType) - fun Table.secureHash(name: String) = this.registerColumn(name, SecureHashColumnType) fun Table.party(nameColumnName: String, keyColumnName: String) = PartyColumns(this.varchar(nameColumnName, length = 255), this.publicKey(keyColumnName)) diff --git a/node/src/main/kotlin/net/corda/node/utilities/DatabaseTransactionManager.kt b/node/src/main/kotlin/net/corda/node/utilities/DatabaseTransactionManager.kt new file mode 100644 index 0000000000..da64097850 --- /dev/null +++ b/node/src/main/kotlin/net/corda/node/utilities/DatabaseTransactionManager.kt @@ -0,0 +1,123 @@ +package net.corda.node.utilities + +import co.paralleluniverse.strands.Strand +import org.hibernate.Session +import org.hibernate.Transaction +import rx.subjects.PublishSubject +import rx.subjects.Subject +import java.sql.Connection +import java.util.* +import java.util.concurrent.ConcurrentHashMap + +class DatabaseTransaction(isolation: Int, val threadLocal: ThreadLocal, + val transactionBoundaries: Subject, + val cordaPersistence: CordaPersistence) { + + val id: UUID = UUID.randomUUID() + + val connection: Connection by lazy(LazyThreadSafetyMode.NONE) { + cordaPersistence.dataSource.connection + .apply { + autoCommit = false + transactionIsolation = isolation + } + } + + private val sessionDelegate = lazy { + val session = cordaPersistence.entityManagerFactory.withOptions().connection(connection).openSession() + hibernateTransaction = session.beginTransaction() + session + } + + val session: Session by sessionDelegate + private lateinit var hibernateTransaction : Transaction + + val outerTransaction: DatabaseTransaction? = threadLocal.get() + + fun commit() { + if (sessionDelegate.isInitialized()) { + hibernateTransaction.commit() + } + connection.commit() + } + + fun rollback() { + if (sessionDelegate.isInitialized() && session.isOpen) { + session.clear() + } + if (!connection.isClosed) { + connection.rollback() + } + } + + fun close() { + connection.close() + threadLocal.set(outerTransaction) + if (outerTransaction == null) { + transactionBoundaries.onNext(DatabaseTransactionManager.Boundary(id)) + } + } +} + +class DatabaseTransactionManager(initDataSource: CordaPersistence) { + companion object { + private val threadLocalDb = ThreadLocal() + private val threadLocalTx = ThreadLocal() + private val databaseToInstance = ConcurrentHashMap() + + fun setThreadLocalTx(tx: DatabaseTransaction?): DatabaseTransaction? { + val oldTx = threadLocalTx.get() + threadLocalTx.set(tx) + return oldTx + } + + fun restoreThreadLocalTx(context: DatabaseTransaction?) { + if (context != null) { + threadLocalDb.set(context.cordaPersistence) + } + threadLocalTx.set(context) + } + + var dataSource: CordaPersistence + get() = threadLocalDb.get() ?: throw IllegalStateException("Was expecting to find CordaPersistence set on current thread: ${Strand.currentStrand()}") + set(value) = threadLocalDb.set(value) + + val transactionId: UUID + get() = threadLocalTx.get()?.id ?: throw IllegalStateException("Was expecting to find transaction set on current strand: ${Strand.currentStrand()}") + + val manager: DatabaseTransactionManager get() = databaseToInstance[dataSource]!! + + val transactionBoundaries: Subject get() = manager._transactionBoundaries + + fun currentOrNull(): DatabaseTransaction? = manager.currentOrNull() + + fun currentOrNew(isolation: Int = dataSource.transactionIsolationLevel) = currentOrNull() ?: manager.newTransaction(isolation) + + fun current(): DatabaseTransaction = currentOrNull() ?: error("No transaction in context.") + + fun newTransaction(isolation: Int = dataSource.transactionIsolationLevel) = manager.newTransaction(isolation) + } + + data class Boundary(val txId: UUID) + + private val _transactionBoundaries = PublishSubject.create().toSerialized() + + init { + // Found a unit test that was forgetting to close the database transactions. When you close() on the top level + // database transaction it will reset the threadLocalTx back to null, so if it isn't then there is still a + // database transaction open. The [transaction] helper above handles this in a finally clause for you + // but any manual database transaction management is liable to have this problem. + if (threadLocalTx.get() != null) { + throw IllegalStateException("Was not expecting to find existing database transaction on current strand when setting database: ${Strand.currentStrand()}, ${threadLocalTx.get()}") + } + dataSource = initDataSource + databaseToInstance[dataSource] = this + } + + private fun newTransaction(isolation: Int) = + DatabaseTransaction(isolation, threadLocalTx, transactionBoundaries, dataSource).apply { + threadLocalTx.set(this) + } + + private fun currentOrNull(): DatabaseTransaction? = threadLocalTx.get() +} diff --git a/node/src/main/kotlin/net/corda/node/utilities/ExposedTransactionManager.kt b/node/src/main/kotlin/net/corda/node/utilities/ExposedTransactionManager.kt new file mode 100644 index 0000000000..1d90449c35 --- /dev/null +++ b/node/src/main/kotlin/net/corda/node/utilities/ExposedTransactionManager.kt @@ -0,0 +1,54 @@ +package net.corda.node.utilities + +import org.jetbrains.exposed.sql.Database +import org.jetbrains.exposed.sql.Transaction +import org.jetbrains.exposed.sql.transactions.TransactionInterface +import org.jetbrains.exposed.sql.transactions.TransactionManager +import java.sql.Connection + +/** + * Wrapper of [DatabaseTransaction], because the class is effectively used for [ExposedTransaction.connection] method only not all methods are implemented. + * The class will obsolete when Exposed library is phased out. + */ +class ExposedTransaction(override val db: Database, val databaseTransaction: DatabaseTransaction): TransactionInterface { + + override val outerTransaction: Transaction? + get() = throw UnsupportedOperationException() + + override val connection: Connection by lazy(LazyThreadSafetyMode.NONE) { + databaseTransaction.connection + } + + override fun commit() { + databaseTransaction.commit() + } + + override fun rollback() { + databaseTransaction.rollback() + } + + override fun close() { + databaseTransaction.close() + } +} + +/** + * Delegates methods to [DatabaseTransactionManager]. + * The class will obsolete when Exposed library is phased out. + */ +class ExposedTransactionManager: TransactionManager { + companion object { + val database: Database + get() = DatabaseTransactionManager.dataSource.database + } + + override fun newTransaction(isolation: Int): Transaction { + var databaseTransaction = DatabaseTransactionManager.newTransaction(isolation) + return Transaction(ExposedTransaction(database, databaseTransaction)) + } + + override fun currentOrNull(): Transaction? { + val databaseTransaction = DatabaseTransactionManager.currentOrNull() + return if (databaseTransaction != null) Transaction(ExposedTransaction(database, databaseTransaction)) else null + } +} \ No newline at end of file diff --git a/node/src/main/kotlin/net/corda/node/utilities/FiberBox.kt b/node/src/main/kotlin/net/corda/node/utilities/FiberBox.kt deleted file mode 100644 index d930fe206d..0000000000 --- a/node/src/main/kotlin/net/corda/node/utilities/FiberBox.kt +++ /dev/null @@ -1,76 +0,0 @@ -package net.corda.node.utilities - -import co.paralleluniverse.fibers.Suspendable -import co.paralleluniverse.strands.concurrent.ReentrantLock -import com.google.common.util.concurrent.SettableFuture -import net.corda.core.RetryableException -import java.time.Clock -import java.time.Instant -import java.util.concurrent.Future -import java.util.concurrent.locks.Lock -import kotlin.concurrent.withLock - -// TODO: We should consider using a Semaphore or CountDownLatch here to make it a little easier to understand, but it seems as though the current version of Quasar does not support suspending on either of their implementations. - -/** - * Modelled on [net.corda.core.ThreadBox], but with support for waiting that is compatible with Quasar [Fiber]s and [MutableClock]s. - * - * It supports 3 main operations, all of which operate in a similar context to the [locked] method - * of [net.corda.core.ThreadBox]. i.e. in the context of the content. - * * [read] operations which acquire the associated lock but do not notify any waiters (see [readWithDeadline]) - * and is a direct equivalent of [net.corda.core.ThreadBox.locked]. - * * [write] operations which are the same as [read] operations but additionally notify any waiters that the content may have changed. - * * [readWithDeadline] operations acquire the lock and are evaluated repeatedly until they no longer throw any subclass - * of [RetryableException]. Between iterations it will wait until woken by a [write] or the deadline is reached. It will eventually - * re-throw a [RetryableException] if the deadline passes without any successful iterations. - * - * The construct also supports [MutableClock]s so it can cope with artificial progress towards the deadline, for simulations - * or testing. - * - * Currently this is intended for use within a node as a simplified way for Oracles to implement subscriptions for changing - * data by running a flow internally to implement the request handler which can then - * effectively relinquish control until the data becomes available. This isn't the most scalable design and is intended - * to be temporary. In addition, it's enitrely possible to envisage a time when we want public [net.corda.core.flows.FlowLogic] - * implementations to be able to wait for some condition to become true outside of message send/receive. At that point - * we may revisit this implementation and indeed the whole model for this, when we understand that requirement more fully. - */ -// TODO This is no longer used and can be removed -class FiberBox(private val content: T, private val lock: Lock = ReentrantLock()) { - private var mutated: SettableFuture? = null - - @Suppress("UNUSED_VALUE") // This is here due to the compiler thinking ourMutated is not used - @Suspendable - fun readWithDeadline(clock: Clock, deadline: Instant, body: T.() -> R): R { - var ex: Exception - var ourMutated: Future? = null - do { - lock.lock() - try { - if (mutated == null || mutated!!.isDone) { - mutated = SettableFuture.create() - } - ourMutated = mutated - return body(content) - } catch(e: RetryableException) { - ex = e - } finally { - lock.unlock() - } - } while (clock.awaitWithDeadline(deadline, ourMutated!!) && clock.instant().isBefore(deadline)) - throw ex - } - - @Suspendable - fun read(body: T.() -> R): R = lock.withLock { body(content) } - - @Suspendable - fun write(body: T.() -> R): R { - lock.lock() - try { - return body(content) - } finally { - mutated?.set(true) - lock.unlock() - } - } -} diff --git a/node/src/main/kotlin/net/corda/node/utilities/JDBCHashMap.kt b/node/src/main/kotlin/net/corda/node/utilities/JDBCHashMap.kt index 26df75020b..5b193ec2e8 100644 --- a/node/src/main/kotlin/net/corda/node/utilities/JDBCHashMap.kt +++ b/node/src/main/kotlin/net/corda/node/utilities/JDBCHashMap.kt @@ -1,14 +1,13 @@ package net.corda.node.utilities +import net.corda.core.serialization.SerializationDefaults.STORAGE_CONTEXT import net.corda.core.serialization.SerializedBytes import net.corda.core.serialization.deserialize import net.corda.core.serialization.serialize -import net.corda.core.serialization.storageKryo import net.corda.core.utilities.loggerFor import net.corda.core.utilities.trace import org.jetbrains.exposed.sql.* import org.jetbrains.exposed.sql.statements.InsertStatement -import org.jetbrains.exposed.sql.transactions.TransactionManager import java.sql.Blob import java.util.* import kotlin.system.measureTimeMillis @@ -60,23 +59,24 @@ class JDBCHashMap(tableName: String, } fun bytesToBlob(value: SerializedBytes<*>, finalizables: MutableList<() -> Unit>): Blob { - val blob = TransactionManager.current().connection.createBlob() + val blob = DatabaseTransactionManager.current().connection.createBlob() finalizables += { blob.free() } blob.setBytes(1, value.bytes) return blob } -fun serializeToBlob(value: Any, finalizables: MutableList<() -> Unit>): Blob = bytesToBlob(value.serialize(storageKryo(), true), finalizables) +fun serializeToBlob(value: Any, finalizables: MutableList<() -> Unit>): Blob = bytesToBlob(value.serialize(context = STORAGE_CONTEXT), finalizables) fun bytesFromBlob(blob: Blob): SerializedBytes { try { - return SerializedBytes(blob.getBytes(0, blob.length().toInt()), true) + return SerializedBytes(blob.getBytes(0, blob.length().toInt())) } finally { blob.free() } } -fun deserializeFromBlob(blob: Blob): T = bytesFromBlob(blob).deserialize() +@Suppress("UNCHECKED_CAST") +fun deserializeFromBlob(blob: Blob): T = bytesFromBlob(blob).deserialize(context = STORAGE_CONTEXT) as T /** * A convenient JDBC table backed hash set with iteration order based on insertion order. diff --git a/core/src/main/kotlin/net/corda/core/crypto/KeyStoreUtilities.kt b/node/src/main/kotlin/net/corda/node/utilities/KeyStoreUtilities.kt similarity index 51% rename from core/src/main/kotlin/net/corda/core/crypto/KeyStoreUtilities.kt rename to node/src/main/kotlin/net/corda/node/utilities/KeyStoreUtilities.kt index 0c88ee2f27..b429f2e09e 100644 --- a/core/src/main/kotlin/net/corda/core/crypto/KeyStoreUtilities.kt +++ b/node/src/main/kotlin/net/corda/node/utilities/KeyStoreUtilities.kt @@ -1,71 +1,72 @@ -package net.corda.core.crypto +package net.corda.node.utilities -import net.corda.core.exists -import net.corda.core.read -import net.corda.core.write +import net.corda.core.crypto.* +import net.corda.core.internal.exists +import net.corda.core.internal.read +import net.corda.core.internal.write +import org.bouncycastle.asn1.x500.X500Name import org.bouncycastle.cert.X509CertificateHolder -import org.bouncycastle.cert.path.CertPath import java.io.IOException import java.io.InputStream import java.io.OutputStream import java.nio.file.Path import java.security.* +import java.security.cert.CertPath import java.security.cert.Certificate +import java.security.cert.CertificateFactory -object KeyStoreUtilities { - val KEYSTORE_TYPE = "JKS" +val KEYSTORE_TYPE = "JKS" - /** - * Helper method to either open an existing keystore for modification, or create a new blank keystore. - * @param keyStoreFilePath location of KeyStore file. - * @param storePassword password to open the store. This does not have to be the same password as any keys stored, - * but for SSL purposes this is recommended. - * @return returns the KeyStore opened/created. - */ - fun loadOrCreateKeyStore(keyStoreFilePath: Path, storePassword: String): KeyStore { - val pass = storePassword.toCharArray() - val keyStore = KeyStore.getInstance(KEYSTORE_TYPE) - if (keyStoreFilePath.exists()) { - keyStoreFilePath.read { keyStore.load(it, pass) } - } else { - keyStore.load(null, pass) - keyStoreFilePath.write { keyStore.store(it, pass) } - } - return keyStore +/** + * Helper method to either open an existing keystore for modification, or create a new blank keystore. + * @param keyStoreFilePath location of KeyStore file. + * @param storePassword password to open the store. This does not have to be the same password as any keys stored, + * but for SSL purposes this is recommended. + * @return returns the KeyStore opened/created. + */ +fun loadOrCreateKeyStore(keyStoreFilePath: Path, storePassword: String): KeyStore { + val pass = storePassword.toCharArray() + val keyStore = KeyStore.getInstance(KEYSTORE_TYPE) + if (keyStoreFilePath.exists()) { + keyStoreFilePath.read { keyStore.load(it, pass) } + } else { + keyStore.load(null, pass) + keyStoreFilePath.write { keyStore.store(it, pass) } } + return keyStore +} - /** - * Helper method to open an existing keystore for modification/read. - * @param keyStoreFilePath location of KeyStore file which must exist, or this will throw FileNotFoundException. - * @param storePassword password to open the store. This does not have to be the same password as any keys stored, - * but for SSL purposes this is recommended. - * @return returns the KeyStore opened. - * @throws IOException if there was an error reading the key store from the file. - * @throws KeyStoreException if the password is incorrect or the key store is damaged. - */ - @Throws(KeyStoreException::class, IOException::class) - fun loadKeyStore(keyStoreFilePath: Path, storePassword: String): KeyStore { - return keyStoreFilePath.read { loadKeyStore(it, storePassword) } - } +/** + * Helper method to open an existing keystore for modification/read. + * @param keyStoreFilePath location of KeyStore file which must exist, or this will throw FileNotFoundException. + * @param storePassword password to open the store. This does not have to be the same password as any keys stored, + * but for SSL purposes this is recommended. + * @return returns the KeyStore opened. + * @throws IOException if there was an error reading the key store from the file. + * @throws KeyStoreException if the password is incorrect or the key store is damaged. + */ +@Throws(KeyStoreException::class, IOException::class) +fun loadKeyStore(keyStoreFilePath: Path, storePassword: String): KeyStore { + return keyStoreFilePath.read { loadKeyStore(it, storePassword) } +} - /** - * Helper method to open an existing keystore for modification/read. - * @param input stream containing a KeyStore e.g. loaded from a resource file. - * @param storePassword password to open the store. This does not have to be the same password as any keys stored, - * but for SSL purposes this is recommended. - * @return returns the KeyStore opened. - * @throws IOException if there was an error reading the key store from the stream. - * @throws KeyStoreException if the password is incorrect or the key store is damaged. - */ - @Throws(KeyStoreException::class, IOException::class) - fun loadKeyStore(input: InputStream, storePassword: String): KeyStore { - val pass = storePassword.toCharArray() - val keyStore = KeyStore.getInstance(KEYSTORE_TYPE) - input.use { - keyStore.load(input, pass) - } - return keyStore +/** + * Helper method to open an existing keystore for modification/read. + * @param input stream containing a KeyStore e.g. loaded from a resource file. + * @param storePassword password to open the store. This does not have to be the same password as any keys stored, + * but for SSL purposes this is recommended. + * @return returns the KeyStore opened. + * @throws IOException if there was an error reading the key store from the stream. + * @throws KeyStoreException if the password is incorrect or the key store is damaged. + */ +@Throws(KeyStoreException::class, IOException::class) +fun loadKeyStore(input: InputStream, storePassword: String): KeyStore { + val pass = storePassword.toCharArray() + val keyStore = KeyStore.getInstance(KEYSTORE_TYPE) + input.use { + keyStore.load(input, pass) } + return keyStore } /** @@ -76,8 +77,8 @@ object KeyStoreUtilities { * but for SSL purposes this is recommended. * @param chain the sequence of certificates starting with the public key certificate for this key and extending to the root CA cert. */ -fun KeyStore.addOrReplaceKey(alias: String, key: Key, password: CharArray, chain: CertPath) { - addOrReplaceKey(alias, key, password, chain.certificates.map { it.cert }.toTypedArray()) +fun KeyStore.addOrReplaceKey(alias: String, key: Key, password: CharArray, chain: Array) { + addOrReplaceKey(alias, key, password, chain.map { it.cert }.toTypedArray()) } /** @@ -88,7 +89,7 @@ fun KeyStore.addOrReplaceKey(alias: String, key: Key, password: CharArray, chain * but for SSL purposes this is recommended. * @param chain the sequence of certificates starting with the public key certificate for this key and extending to the root CA cert. */ -fun KeyStore.addOrReplaceKey(alias: String, key: Key, password: CharArray, chain: Array) { +fun KeyStore.addOrReplaceKey(alias: String, key: Key, password: CharArray, chain: Array) { if (containsAlias(alias)) { this.deleteEntry(alias) } @@ -107,7 +108,6 @@ fun KeyStore.addOrReplaceCertificate(alias: String, cert: Certificate) { this.setCertificateEntry(alias, cert) } - /** * Helper method save KeyStore to storage. * @param keyStoreFilePath the file location to save to. @@ -118,7 +118,6 @@ fun KeyStore.save(keyStoreFilePath: Path, storePassword: String) = keyStoreFileP fun KeyStore.store(out: OutputStream, password: String) = store(out, password.toCharArray()) - /** * Extract public and private keys from a KeyStore file assuming storage alias is known. * @param alias The name to lookup the Key and Certificate chain from. @@ -146,7 +145,7 @@ fun KeyStore.getCertificateAndKeyPair(alias: String, keyPassword: String): Certi * @return The X509Certificate found in the KeyStore under the specified alias. */ fun KeyStore.getX509Certificate(alias: String): X509CertificateHolder { - val encoded = getCertificate(alias)?.encoded ?: throw IllegalArgumentException("No certificate under alias \"${alias}\"") + val encoded = getCertificate(alias)?.encoded ?: throw IllegalArgumentException("No certificate under alias \"$alias\"") return X509CertificateHolder(encoded) } @@ -169,3 +168,43 @@ fun KeyStore.getSupportedKey(alias: String, keyPassword: String): PrivateKey { val key = getKey(alias, keyPass) as PrivateKey return Crypto.toSupportedPrivateKey(key) } + +class KeyStoreWrapper(private val storePath: Path, private val storePassword: String) { + private val keyStore = storePath.read { loadKeyStore(it, storePassword) } + + private fun createCertificate(serviceName: X500Name, pubKey: PublicKey): CertPath { + val clientCertPath = keyStore.getCertificateChain(X509Utilities.CORDA_CLIENT_CA) + // Assume key password = store password. + val clientCA = certificateAndKeyPair(X509Utilities.CORDA_CLIENT_CA) + // Create new keys and store in keystore. + val cert = X509Utilities.createCertificate(CertificateType.IDENTITY, clientCA.certificate, clientCA.keyPair, serviceName, pubKey) + val certPath = CertificateFactory.getInstance("X509").generateCertPath(listOf(cert.cert) + clientCertPath) + require(certPath.certificates.isNotEmpty()) { "Certificate path cannot be empty" } + return certPath + } + + fun saveNewKeyPair(serviceName: X500Name, privateKeyAlias: String, keyPair: KeyPair) { + val certPath = createCertificate(serviceName, keyPair.public) + // Assume key password = store password. + keyStore.addOrReplaceKey(privateKeyAlias, keyPair.private, storePassword.toCharArray(), certPath.certificates.toTypedArray()) + keyStore.save(storePath, storePassword) + } + + fun savePublicKey(serviceName: X500Name, pubKeyAlias: String, pubKey: PublicKey) { + val certPath = createCertificate(serviceName, pubKey) + // Assume key password = store password. + keyStore.addOrReplaceCertificate(pubKeyAlias, certPath.certificates.first()) + keyStore.save(storePath, storePassword) + } + + // Delegate methods to keystore. Sadly keystore doesn't have an interface. + fun containsAlias(alias: String) = keyStore.containsAlias(alias) + + fun getX509Certificate(alias: String) = keyStore.getX509Certificate(alias) + + fun getCertificateChain(alias: String): Array = keyStore.getCertificateChain(alias) + + fun getCertificate(alias: String): Certificate = keyStore.getCertificate(alias) + + fun certificateAndKeyPair(alias: String) = keyStore.getCertificateAndKeyPair(alias, storePassword) +} diff --git a/node/src/main/kotlin/net/corda/node/utilities/NonInvalidatingCache.kt b/node/src/main/kotlin/net/corda/node/utilities/NonInvalidatingCache.kt new file mode 100644 index 0000000000..c456f8af3b --- /dev/null +++ b/node/src/main/kotlin/net/corda/node/utilities/NonInvalidatingCache.kt @@ -0,0 +1,33 @@ +package net.corda.node.utilities + +import com.google.common.cache.CacheBuilder +import com.google.common.cache.CacheLoader +import com.google.common.cache.LoadingCache +import com.google.common.util.concurrent.ListenableFuture + + +class NonInvalidatingCache private constructor( + val cache: LoadingCache +): LoadingCache by cache { + + constructor(bound: Long, concurrencyLevel: Int, loadFunction: (K) -> V) : + this(buildCache(bound, concurrencyLevel, loadFunction)) + + private companion object { + private fun buildCache(bound: Long, concurrencyLevel: Int, loadFunction: (K) -> V): LoadingCache { + val builder = CacheBuilder.newBuilder().maximumSize(bound).concurrencyLevel(concurrencyLevel) + return builder.build(NonInvalidatingCacheLoader(loadFunction)) + } + } + + // TODO look into overriding loadAll() if we ever use it + private class NonInvalidatingCacheLoader(val loadFunction: (K) -> V) : CacheLoader() { + override fun reload(key: K, oldValue: V): ListenableFuture { + throw IllegalStateException("Non invalidating cache refreshed") + } + override fun load(key: K) = loadFunction(key) + override fun loadAll(keys: Iterable): MutableMap { + return super.loadAll(keys) + } + } +} \ No newline at end of file diff --git a/node/src/main/kotlin/net/corda/node/utilities/ServiceIdentityGenerator.kt b/node/src/main/kotlin/net/corda/node/utilities/ServiceIdentityGenerator.kt index 2f2db9d09b..fc87226b3d 100644 --- a/node/src/main/kotlin/net/corda/node/utilities/ServiceIdentityGenerator.kt +++ b/node/src/main/kotlin/net/corda/node/utilities/ServiceIdentityGenerator.kt @@ -3,8 +3,6 @@ package net.corda.node.utilities import net.corda.core.crypto.composite.CompositeKey import net.corda.core.crypto.generateKeyPair import net.corda.core.identity.Party -import net.corda.core.serialization.serialize -import net.corda.core.serialization.storageKryo import net.corda.core.utilities.loggerFor import net.corda.core.utilities.trace import org.bouncycastle.asn1.x500.X500Name @@ -33,16 +31,15 @@ object ServiceIdentityGenerator { val keyPairs = (1..dirs.size).map { generateKeyPair() } val notaryKey = CompositeKey.Builder().addKeys(keyPairs.map { it.public }).build(threshold) // Avoid adding complexity! This class is a hack that needs to stay runnable in the gradle environment. - val notaryParty = Party(serviceName, notaryKey) - val notaryPartyBytes = notaryParty.serialize() val privateKeyFile = "$serviceId-private-key" val publicKeyFile = "$serviceId-public" + val compositeKeyFile = "$serviceId-composite-key" keyPairs.zip(dirs) { keyPair, dir -> Files.createDirectories(dir) - notaryPartyBytes.writeToFile(dir.resolve(publicKeyFile)) - // Use storageKryo as our whitelist is not available in the gradle build environment: - keyPair.serialize(storageKryo()).writeToFile(dir.resolve(privateKeyFile)) + Files.write(dir.resolve(compositeKeyFile), notaryKey.encoded) + Files.write(dir.resolve(privateKeyFile), keyPair.private.encoded) + Files.write(dir.resolve(publicKeyFile), keyPair.public.encoded) } - return notaryParty + return Party(serviceName, notaryKey) } } diff --git a/node/src/main/kotlin/net/corda/node/utilities/TestClock.kt b/node/src/main/kotlin/net/corda/node/utilities/TestClock.kt index 3b0487d3f0..3d1c57d312 100644 --- a/node/src/main/kotlin/net/corda/node/utilities/TestClock.kt +++ b/node/src/main/kotlin/net/corda/node/utilities/TestClock.kt @@ -1,10 +1,13 @@ package net.corda.node.utilities +import net.corda.core.internal.until import net.corda.core.serialization.SerializeAsToken import net.corda.core.serialization.SerializeAsTokenContext -import net.corda.core.serialization.SingletonSerializationToken import net.corda.core.serialization.SingletonSerializationToken.Companion.singletonSerializationToken -import java.time.* +import java.time.Clock +import java.time.Instant +import java.time.LocalDate +import java.time.ZoneId import javax.annotation.concurrent.ThreadSafe /** @@ -21,7 +24,7 @@ class TestClock(private var delegateClock: Clock = Clock.systemUTC()) : MutableC val currentDate = LocalDate.now(this) if (currentDate.isBefore(date)) { // It's ok to increment - delegateClock = Clock.offset(delegateClock, Duration.between(currentDate.atStartOfDay(), date.atStartOfDay())) + delegateClock = Clock.offset(delegateClock, currentDate.atStartOfDay() until date.atStartOfDay()) notifyMutationObservers() return true } diff --git a/core/src/main/kotlin/net/corda/core/crypto/X509Utilities.kt b/node/src/main/kotlin/net/corda/node/utilities/X509Utilities.kt similarity index 50% rename from core/src/main/kotlin/net/corda/core/crypto/X509Utilities.kt rename to node/src/main/kotlin/net/corda/node/utilities/X509Utilities.kt index 79973b514f..53ed5580eb 100644 --- a/core/src/main/kotlin/net/corda/core/crypto/X509Utilities.kt +++ b/node/src/main/kotlin/net/corda/node/utilities/X509Utilities.kt @@ -1,21 +1,30 @@ -package net.corda.core.crypto +package net.corda.node.utilities -import net.corda.core.crypto.Crypto.generateKeyPair -import org.bouncycastle.asn1.ASN1Encodable +import net.corda.core.crypto.* +import net.corda.core.utilities.days +import net.corda.core.utilities.millis +import org.bouncycastle.asn1.ASN1EncodableVector +import org.bouncycastle.asn1.ASN1Sequence +import org.bouncycastle.asn1.DERSequence import org.bouncycastle.asn1.x500.X500Name -import org.bouncycastle.asn1.x500.X500NameBuilder -import org.bouncycastle.asn1.x500.style.BCStyle import org.bouncycastle.asn1.x509.* +import org.bouncycastle.asn1.x509.Extension import org.bouncycastle.cert.X509CertificateHolder -import org.bouncycastle.cert.jcajce.JcaX509CertificateConverter +import org.bouncycastle.cert.X509v3CertificateBuilder +import org.bouncycastle.cert.bc.BcX509ExtensionUtils +import org.bouncycastle.cert.jcajce.JcaX509v3CertificateBuilder import org.bouncycastle.openssl.jcajce.JcaPEMWriter +import org.bouncycastle.operator.ContentSigner +import org.bouncycastle.operator.jcajce.JcaContentVerifierProviderBuilder +import org.bouncycastle.pkcs.PKCS10CertificationRequest +import org.bouncycastle.pkcs.jcajce.JcaPKCS10CertificationRequestBuilder import org.bouncycastle.util.io.pem.PemReader import java.io.FileReader import java.io.FileWriter import java.io.InputStream +import java.math.BigInteger import java.nio.file.Path import java.security.KeyPair -import java.security.KeyStore import java.security.PublicKey import java.security.cert.* import java.security.cert.Certificate @@ -34,7 +43,7 @@ object X509Utilities { val CORDA_CLIENT_TLS = "cordaclienttls" val CORDA_CLIENT_CA = "cordaclientca" - private val DEFAULT_VALIDITY_WINDOW = Pair(Duration.ofMillis(0), Duration.ofDays(365 * 10)) + private val DEFAULT_VALIDITY_WINDOW = Pair(0.millis, 3650.days) /** * Helper function to return the latest out of an instant and an optional date. */ @@ -68,44 +77,13 @@ object X509Utilities { return Pair(notBefore, notAfter) } - /** - * Return a bogus X509 for dev purposes. Use [getX509Name] for something more real. - */ - @Deprecated("Full legal names should be specified in all configurations") - fun getDevX509Name(commonName: String): X500Name { - val nameBuilder = X500NameBuilder(BCStyle.INSTANCE) - nameBuilder.addRDN(BCStyle.CN, commonName) - nameBuilder.addRDN(BCStyle.O, "R3") - nameBuilder.addRDN(BCStyle.OU, "corda") - nameBuilder.addRDN(BCStyle.L, "London") - nameBuilder.addRDN(BCStyle.C, "GB") - return nameBuilder.build() - } - - /** - * Generate a distinguished name from the provided values. - * - * @see [CoreTestUtils.getTestX509Name] for generating distinguished names for test cases. - */ - @JvmOverloads - @JvmStatic - fun getX509Name(myLegalName: String, nearestCity: String, email: String, country: String? = null): X500Name { - return X500NameBuilder(BCStyle.INSTANCE).let { builder -> - builder.addRDN(BCStyle.CN, myLegalName) - builder.addRDN(BCStyle.L, nearestCity) - country?.let { builder.addRDN(BCStyle.C, it) } - builder.addRDN(BCStyle.E, email) - builder.build() - } - } - /* * Create a de novo root self-signed X509 v3 CA cert. */ @JvmStatic fun createSelfSignedCACertificate(subject: X500Name, keyPair: KeyPair, validityWindow: Pair = DEFAULT_VALIDITY_WINDOW): X509CertificateHolder { val window = getCertificateValidityWindow(validityWindow.first, validityWindow.second) - return Crypto.createCertificate(CertificateType.ROOT_CA, subject, keyPair, subject, keyPair.public, window) + return createCertificate(CertificateType.ROOT_CA, subject, keyPair, subject, keyPair.public, window) } /** @@ -125,7 +103,7 @@ object X509Utilities { validityWindow: Pair = DEFAULT_VALIDITY_WINDOW, nameConstraints: NameConstraints? = null): X509CertificateHolder { val window = getCertificateValidityWindow(validityWindow.first, validityWindow.second, issuerCertificate) - return Crypto.createCertificate(certificateType, issuerCertificate.subject, issuerKeyPair, subject, subjectPublicKey, window, nameConstraints) + return createCertificate(certificateType, issuerCertificate.subject, issuerKeyPair, subject, subjectPublicKey, window, nameConstraints) } fun validateCertificateChain(trustedRoot: X509CertificateHolder, vararg certificates: Certificate) { @@ -168,102 +146,92 @@ object X509Utilities { } /** - * An all in wrapper to manufacture a server certificate and keys all stored in a KeyStore suitable for running TLS on the local machine. - * @param sslKeyStorePath KeyStore path to save ssl key and cert to. - * @param clientCAKeystorePath KeyStore path to save client CA key and cert to. - * @param storePassword access password for KeyStore. - * @param keyPassword PrivateKey access password for the generated keys. - * It is recommended that this is the same as the storePassword as most TLS libraries assume they are the same. - * @param caKeyStore KeyStore containing CA keys generated by createCAKeyStoreAndTrustStore. - * @param caKeyPassword password to unlock private keys in the CA KeyStore. - * @return The KeyStore created containing a private key, certificate chain and root CA public cert for use in TLS applications. + * Build a partial X.509 certificate ready for signing. + * + * @param issuer name of the issuing entity. + * @param subject name of the certificate subject. + * @param subjectPublicKey public key of the certificate subject. + * @param validityWindow the time period the certificate is valid for. + * @param nameConstraints any name constraints to impose on certificates signed by the generated certificate. */ - fun createKeystoreForCordaNode(sslKeyStorePath: Path, - clientCAKeystorePath: Path, - storePassword: String, - keyPassword: String, - caKeyStore: KeyStore, - caKeyPassword: String, - legalName: X500Name, - signatureScheme: SignatureScheme = DEFAULT_TLS_SIGNATURE_SCHEME) { + fun createCertificate(certificateType: CertificateType, issuer: X500Name, + subject: X500Name, subjectPublicKey: PublicKey, + validityWindow: Pair, + nameConstraints: NameConstraints? = null): X509v3CertificateBuilder { - val rootCACert = caKeyStore.getX509Certificate(CORDA_ROOT_CA) - val (intermediateCACert, intermediateCAKeyPair) = caKeyStore.getCertificateAndKeyPair(CORDA_INTERMEDIATE_CA, caKeyPassword) + val serial = BigInteger.valueOf(random63BitValue()) + val keyPurposes = DERSequence(ASN1EncodableVector().apply { certificateType.purposes.forEach { add(it) } }) + val subjectPublicKeyInfo = SubjectPublicKeyInfo.getInstance(ASN1Sequence.getInstance(subjectPublicKey.encoded)) - val clientKey = generateKeyPair(signatureScheme) - val nameConstraints = NameConstraints(arrayOf(GeneralSubtree(GeneralName(GeneralName.directoryName, legalName))), arrayOf()) - val clientCACert = createCertificate(CertificateType.INTERMEDIATE_CA, intermediateCACert, intermediateCAKeyPair, legalName, clientKey.public, nameConstraints = nameConstraints) + val builder = JcaX509v3CertificateBuilder(issuer, serial, validityWindow.first, validityWindow.second, subject, subjectPublicKey) + .addExtension(Extension.subjectKeyIdentifier, false, BcX509ExtensionUtils().createSubjectKeyIdentifier(subjectPublicKeyInfo)) + .addExtension(Extension.basicConstraints, certificateType.isCA, BasicConstraints(certificateType.isCA)) + .addExtension(Extension.keyUsage, false, certificateType.keyUsage) + .addExtension(Extension.extendedKeyUsage, false, keyPurposes) - val tlsKey = generateKeyPair(signatureScheme) - val clientTLSCert = createCertificate(CertificateType.TLS, clientCACert, clientKey, legalName, tlsKey.public) - - val keyPass = keyPassword.toCharArray() - - val clientCAKeystore = KeyStoreUtilities.loadOrCreateKeyStore(clientCAKeystorePath, storePassword) - clientCAKeystore.addOrReplaceKey( - CORDA_CLIENT_CA, - clientKey.private, - keyPass, - org.bouncycastle.cert.path.CertPath(arrayOf(clientCACert, intermediateCACert, rootCACert))) - clientCAKeystore.save(clientCAKeystorePath, storePassword) - - val tlsKeystore = KeyStoreUtilities.loadOrCreateKeyStore(sslKeyStorePath, storePassword) - tlsKeystore.addOrReplaceKey( - CORDA_CLIENT_TLS, - tlsKey.private, - keyPass, - org.bouncycastle.cert.path.CertPath(arrayOf(clientTLSCert, clientCACert, intermediateCACert, rootCACert))) - tlsKeystore.save(sslKeyStorePath, storePassword) + if (nameConstraints != null) { + builder.addExtension(Extension.nameConstraints, true, nameConstraints) + } + return builder } - fun createCertificateSigningRequest(subject: X500Name, keyPair: KeyPair, signatureScheme: SignatureScheme = DEFAULT_TLS_SIGNATURE_SCHEME) = Crypto.createCertificateSigningRequest(subject, keyPair, signatureScheme) -} - -/** - * Rebuild the distinguished name, adding a postfix to the common name. If no common name is present. - * @throws IllegalArgumentException if the distinguished name does not contain a common name element. - */ -fun X500Name.appendToCommonName(commonName: String): X500Name = mutateCommonName { attr -> attr.toString() + commonName } - -/** - * Rebuild the distinguished name, replacing the common name with the given value. If no common name is present, this - * adds one. - * @throws IllegalArgumentException if the distinguished name does not contain a common name element. - */ -fun X500Name.replaceCommonName(commonName: String): X500Name = mutateCommonName { _ -> commonName } - -/** - * Rebuild the distinguished name, replacing the common name with a value generated from the provided function. - * - * @param mutator a function to generate the new value from the previous one. - * @throws IllegalArgumentException if the distinguished name does not contain a common name element. - */ -private fun X500Name.mutateCommonName(mutator: (ASN1Encodable) -> String): X500Name { - val builder = X500NameBuilder(BCStyle.INSTANCE) - var matched = false - this.rdNs.forEach { rdn -> - rdn.typesAndValues.forEach { typeAndValue -> - when (typeAndValue.type) { - BCStyle.CN -> { - matched = true - builder.addRDN(typeAndValue.type, mutator(typeAndValue.value)) - } - else -> { - builder.addRDN(typeAndValue) - } - } + /** + * Build and sign an X.509 certificate with the given signer. + * + * @param issuer name of the issuing entity. + * @param issuerSigner content signer to sign the certificate with. + * @param subject name of the certificate subject. + * @param subjectPublicKey public key of the certificate subject. + * @param validityWindow the time period the certificate is valid for. + * @param nameConstraints any name constraints to impose on certificates signed by the generated certificate. + */ + fun createCertificate(certificateType: CertificateType, issuer: X500Name, issuerSigner: ContentSigner, + subject: X500Name, subjectPublicKey: PublicKey, + validityWindow: Pair, + nameConstraints: NameConstraints? = null): X509CertificateHolder { + val builder = createCertificate(certificateType, issuer, subject, subjectPublicKey, validityWindow, nameConstraints) + return builder.build(issuerSigner).apply { + require(isValidOn(Date())) } } - require(matched) { "Input X.500 name must include a common name (CN) attribute: ${this}" } - return builder.build() + + /** + * Build and sign an X.509 certificate with CA cert private key. + * + * @param issuer name of the issuing entity. + * @param issuerKeyPair the public & private key to sign the certificate with. + * @param subject name of the certificate subject. + * @param subjectPublicKey public key of the certificate subject. + * @param validityWindow the time period the certificate is valid for. + * @param nameConstraints any name constraints to impose on certificates signed by the generated certificate. + */ + fun createCertificate(certificateType: CertificateType, issuer: X500Name, issuerKeyPair: KeyPair, + subject: X500Name, subjectPublicKey: PublicKey, + validityWindow: Pair, + nameConstraints: NameConstraints? = null): X509CertificateHolder { + + val signatureScheme = Crypto.findSignatureScheme(issuerKeyPair.private) + val provider = Crypto.providerMap[signatureScheme.providerName] + val builder = createCertificate(certificateType, issuer, subject, subjectPublicKey, validityWindow, nameConstraints) + + val signer = ContentSignerBuilder.build(signatureScheme, issuerKeyPair.private, provider) + return builder.build(signer).apply { + require(isValidOn(Date())) + require(isSignatureValid(JcaContentVerifierProviderBuilder().build(issuerKeyPair.public))) + } + } + + /** + * Create certificate signing request using provided information. + */ + fun createCertificateSigningRequest(subject: X500Name, keyPair: KeyPair, signatureScheme: SignatureScheme): PKCS10CertificationRequest { + val signer = ContentSignerBuilder.build(signatureScheme, keyPair.private, Crypto.providerMap[signatureScheme.providerName]) + return JcaPKCS10CertificationRequestBuilder(subject, keyPair.public).build(signer) + } + + fun createCertificateSigningRequest(subject: X500Name, keyPair: KeyPair) = createCertificateSigningRequest(subject, keyPair, DEFAULT_TLS_SIGNATURE_SCHEME) } -val X500Name.commonName: String get() = getRDNs(BCStyle.CN).first().first.value.toString() -val X500Name.orgName: String? get() = getRDNs(BCStyle.O).firstOrNull()?.first?.value?.toString() -val X500Name.location: String get() = getRDNs(BCStyle.L).first().first.value.toString() -val X500Name.locationOrNull: String? get() = try { location } catch (e: Exception) { null } -val X509Certificate.subject: X500Name get() = X509CertificateHolder(encoded).subject -val X509CertificateHolder.cert: X509Certificate get() = JcaX509CertificateConverter().getCertificate(this) class CertificateStream(val input: InputStream) { private val certificateFactory = CertificateFactory.getInstance("X.509") @@ -271,8 +239,6 @@ class CertificateStream(val input: InputStream) { fun nextCertificate(): X509Certificate = certificateFactory.generateCertificate(input) as X509Certificate } -data class CertificateAndKeyPair(val certificate: X509CertificateHolder, val keyPair: KeyPair) - enum class CertificateType(val keyUsage: KeyUsage, vararg val purposes: KeyPurposeId, val isCA: Boolean) { ROOT_CA(KeyUsage(KeyUsage.digitalSignature or KeyUsage.keyCertSign or KeyUsage.cRLSign), KeyPurposeId.id_kp_serverAuth, KeyPurposeId.id_kp_clientAuth, KeyPurposeId.anyExtendedKeyUsage, isCA = true), INTERMEDIATE_CA(KeyUsage(KeyUsage.digitalSignature or KeyUsage.keyCertSign or KeyUsage.cRLSign), KeyPurposeId.id_kp_serverAuth, KeyPurposeId.id_kp_clientAuth, KeyPurposeId.anyExtendedKeyUsage, isCA = true), diff --git a/node/src/main/kotlin/net/corda/node/utilities/registration/HTTPNetworkRegistrationService.kt b/node/src/main/kotlin/net/corda/node/utilities/registration/HTTPNetworkRegistrationService.kt index ec803884b2..8ddb347d64 100644 --- a/node/src/main/kotlin/net/corda/node/utilities/registration/HTTPNetworkRegistrationService.kt +++ b/node/src/main/kotlin/net/corda/node/utilities/registration/HTTPNetworkRegistrationService.kt @@ -1,7 +1,7 @@ package net.corda.node.utilities.registration import com.google.common.net.MediaType -import net.corda.core.crypto.CertificateStream +import net.corda.node.utilities.CertificateStream import org.apache.commons.io.IOUtils import org.bouncycastle.pkcs.PKCS10CertificationRequest import java.io.IOException diff --git a/node/src/main/kotlin/net/corda/node/utilities/registration/NetworkRegistrationHelper.kt b/node/src/main/kotlin/net/corda/node/utilities/registration/NetworkRegistrationHelper.kt index 9701203d45..268b89fa53 100644 --- a/node/src/main/kotlin/net/corda/node/utilities/registration/NetworkRegistrationHelper.kt +++ b/node/src/main/kotlin/net/corda/node/utilities/registration/NetworkRegistrationHelper.kt @@ -1,24 +1,29 @@ package net.corda.node.utilities.registration -import net.corda.core.* -import net.corda.core.crypto.* -import net.corda.core.crypto.X509Utilities.CORDA_CLIENT_CA -import net.corda.core.crypto.X509Utilities.CORDA_CLIENT_TLS -import net.corda.core.crypto.X509Utilities.CORDA_ROOT_CA +import net.corda.core.crypto.Crypto +import net.corda.core.crypto.cert +import net.corda.core.internal.* +import net.corda.core.utilities.seconds +import net.corda.core.utilities.validateX500Name import net.corda.node.services.config.NodeConfiguration -import org.bouncycastle.cert.path.CertPath +import net.corda.node.utilities.* +import net.corda.node.utilities.X509Utilities.CORDA_CLIENT_CA +import net.corda.node.utilities.X509Utilities.CORDA_CLIENT_TLS +import net.corda.node.utilities.X509Utilities.CORDA_ROOT_CA import org.bouncycastle.openssl.jcajce.JcaPEMWriter import org.bouncycastle.util.io.pem.PemObject import java.io.StringWriter import java.security.KeyPair +import java.security.KeyStore import java.security.cert.Certificate import kotlin.system.exitProcess /** * This checks the config.certificatesDirectory field for certificates required to connect to a Corda network. * If the certificates are not found, a [org.bouncycastle.pkcs.PKCS10CertificationRequest] will be submitted to - * Corda network permissioning server using [NetworkRegistrationService]. This process will enter a polling loop until the request has been approved, and then - * the certificate chain will be downloaded and stored in [Keystore] reside in the certificates directory. + * Corda network permissioning server using [NetworkRegistrationService]. This process will enter a polling loop until + * the request has been approved, and then the certificate chain will be downloaded and stored in [KeyStore] reside in + * the certificates directory. */ class NetworkRegistrationHelper(val config: NodeConfiguration, val certService: NetworkRegistrationService) { companion object { @@ -32,8 +37,9 @@ class NetworkRegistrationHelper(val config: NodeConfiguration, val certService: private val privateKeyPassword = config.keyStorePassword fun buildKeystore() { + validateX500Name(config.myLegalName) config.certificatesDirectory.createDirectories() - val caKeyStore = KeyStoreUtilities.loadOrCreateKeyStore(config.nodeKeystore, keystorePassword) + val caKeyStore = loadOrCreateKeyStore(config.nodeKeystore, keystorePassword) if (!caKeyStore.containsAlias(CORDA_CLIENT_CA)) { // Create or load self signed keypair from the key store. // We use the self sign certificate to store the key temporarily in the keystore while waiting for the request approval. @@ -42,7 +48,7 @@ class NetworkRegistrationHelper(val config: NodeConfiguration, val certService: val selfSignCert = X509Utilities.createSelfSignedCACertificate(config.myLegalName, keyPair) // Save to the key store. caKeyStore.addOrReplaceKey(SELF_SIGNED_PRIVATE_KEY, keyPair.private, privateKeyPassword.toCharArray(), - CertPath(arrayOf(selfSignCert))) + arrayOf(selfSignCert)) caKeyStore.save(config.nodeKeystore, keystorePassword) } val keyPair = caKeyStore.getKeyPair(SELF_SIGNED_PRIVATE_KEY, privateKeyPassword) @@ -64,7 +70,7 @@ class NetworkRegistrationHelper(val config: NodeConfiguration, val certService: caKeyStore.deleteEntry(SELF_SIGNED_PRIVATE_KEY) caKeyStore.save(config.nodeKeystore, keystorePassword) // Save root certificates to trust store. - val trustStore = KeyStoreUtilities.loadOrCreateKeyStore(config.trustStoreFile, config.trustStorePassword) + val trustStore = loadOrCreateKeyStore(config.trustStoreFile, config.trustStorePassword) // Assumes certificate chain always starts with client certificate and end with root certificate. trustStore.addOrReplaceCertificate(CORDA_ROOT_CA, certificates.last()) trustStore.save(config.trustStoreFile, config.trustStorePassword) @@ -74,7 +80,7 @@ class NetworkRegistrationHelper(val config: NodeConfiguration, val certService: val sslKey = Crypto.generateKeyPair(X509Utilities.DEFAULT_TLS_SIGNATURE_SCHEME) val caCert = caKeyStore.getX509Certificate(CORDA_CLIENT_CA) val sslCert = X509Utilities.createCertificate(CertificateType.TLS, caCert, keyPair, caCert.subject, sslKey.public) - val sslKeyStore = KeyStoreUtilities.loadOrCreateKeyStore(config.sslKeystore, keystorePassword) + val sslKeyStore = loadOrCreateKeyStore(config.sslKeystore, keystorePassword) sslKeyStore.addOrReplaceKey(CORDA_CLIENT_TLS, sslKey.private, privateKeyPassword.toCharArray(), arrayOf(sslCert.cert, *certificates)) sslKeyStore.save(config.sslKeystore, config.keyStorePassword) diff --git a/node/src/main/resources/reference.conf b/node/src/main/resources/reference.conf index 7d9f8a2b42..11e2269fce 100644 --- a/node/src/main/resources/reference.conf +++ b/node/src/main/resources/reference.conf @@ -9,9 +9,17 @@ dataSourceProperties = { "dataSource.user" = sa "dataSource.password" = "" } +database = { + transactionIsolationLevel = "repeatableRead" + initDatabase = true +} devMode = true certificateSigningService = "https://cordaci-netperm.corda.r3cev.com" useHTTPS = false h2port = 0 useTestClock = false -verifierType = InMemory \ No newline at end of file +verifierType = InMemory +bftSMaRt = { + replicaId = -1 + debug = false +} diff --git a/node/src/smoke-test/kotlin/net/corda/node/CordappScanningNodeProcessTest.kt b/node/src/smoke-test/kotlin/net/corda/node/CordappScanningNodeProcessTest.kt deleted file mode 100644 index 87b8faad10..0000000000 --- a/node/src/smoke-test/kotlin/net/corda/node/CordappScanningNodeProcessTest.kt +++ /dev/null @@ -1,52 +0,0 @@ -package net.corda.node - -import net.corda.core.copyToDirectory -import net.corda.core.createDirectories -import net.corda.core.div -import net.corda.nodeapi.User -import net.corda.smoketesting.NodeConfig -import net.corda.smoketesting.NodeProcess -import org.assertj.core.api.Assertions.assertThat -import org.bouncycastle.asn1.x500.X500Name -import org.junit.Test -import java.nio.file.Paths -import java.util.concurrent.atomic.AtomicInteger - -class CordappScanningNodeProcessTest { - private companion object { - val user = User("user1", "test", permissions = setOf("ALL")) - val port = AtomicInteger(15100) - } - - private val factory = NodeProcess.Factory() - - private val aliceConfig = NodeConfig( - legalName = X500Name("CN=Alice Corp,O=Alice Corp,L=Madrid,C=ES"), - p2pPort = port.andIncrement, - rpcPort = port.andIncrement, - webPort = port.andIncrement, - extraServices = emptyList(), - users = listOf(user) - ) - - @Test - fun `CorDapp jar in plugins directory is scanned`() { - // If the CorDapp jar does't exist then run the smokeTestClasses gradle task - val cordappJar = Paths.get(javaClass.getResource("/trader-demo.jar").toURI()) - val pluginsDir = (factory.baseDirectory(aliceConfig) / "plugins").createDirectories() - cordappJar.copyToDirectory(pluginsDir) - - factory.create(aliceConfig).use { - it.connect().use { - // If the CorDapp wasn't scanned then SellerFlow won't have been picked up as an RPC flow - assertThat(it.proxy.registeredFlows()).contains("net.corda.traderdemo.flow.SellerFlow") - } - } - } - - @Test - fun `empty plugins directory`() { - (factory.baseDirectory(aliceConfig) / "plugins").createDirectories() - factory.create(aliceConfig).close() - } -} diff --git a/node/src/smoke-test/kotlin/net/corda/node/CordappSmokeTest.kt b/node/src/smoke-test/kotlin/net/corda/node/CordappSmokeTest.kt new file mode 100644 index 0000000000..4093f6b935 --- /dev/null +++ b/node/src/smoke-test/kotlin/net/corda/node/CordappSmokeTest.kt @@ -0,0 +1,76 @@ +package net.corda.node + +import co.paralleluniverse.fibers.Suspendable +import net.corda.core.flows.* +import net.corda.core.identity.Party +import net.corda.core.internal.copyToDirectory +import net.corda.core.internal.createDirectories +import net.corda.core.internal.div +import net.corda.core.internal.list +import net.corda.core.messaging.startFlow +import net.corda.core.utilities.getOrThrow +import net.corda.nodeapi.User +import net.corda.smoketesting.NodeConfig +import net.corda.smoketesting.NodeProcess +import org.assertj.core.api.Assertions.assertThat +import org.bouncycastle.asn1.x500.X500Name +import org.junit.Test +import java.nio.file.Paths +import java.util.concurrent.atomic.AtomicInteger +import kotlin.streams.toList + +class CordappSmokeTest { + private companion object { + val user = User("user1", "test", permissions = setOf("ALL")) + val port = AtomicInteger(15100) + } + + private val factory = NodeProcess.Factory() + + private val aliceConfig = NodeConfig( + legalName = X500Name("CN=Alice Corp,O=Alice Corp,L=Madrid,C=ES"), + p2pPort = port.andIncrement, + rpcPort = port.andIncrement, + webPort = port.andIncrement, + extraServices = emptyList(), + users = listOf(user) + ) + + @Test + fun `FlowContent appName returns the filename of the CorDapp jar`() { + val pluginsDir = (factory.baseDirectory(aliceConfig) / "plugins").createDirectories() + // Find the jar file for the smoke tests of this module + val selfCorDapp = Paths.get("build", "libs").list { + it.filter { "-smoke-test" in it.toString() }.toList().single() + } + selfCorDapp.copyToDirectory(pluginsDir) + + factory.create(aliceConfig).use { alice -> + alice.connect().use { connectionToAlice -> + val aliceIdentity = connectionToAlice.proxy.nodeIdentity().legalIdentity + val future = connectionToAlice.proxy.startFlow(::DummyInitiatingFlow, aliceIdentity).returnValue + assertThat(future.getOrThrow().appName).isEqualTo(selfCorDapp.fileName.toString().removeSuffix(".jar")) + } + } + } + + @Test + fun `empty plugins directory`() { + (factory.baseDirectory(aliceConfig) / "plugins").createDirectories() + factory.create(aliceConfig).close() + } + + @InitiatingFlow + @StartableByRPC + class DummyInitiatingFlow(val otherParty: Party) : FlowLogic() { + @Suspendable + override fun call() = getFlowContext(otherParty) + } + + @Suppress("unused") + @InitiatedBy(DummyInitiatingFlow::class) + class DummyInitiatedFlow(val otherParty: Party) : FlowLogic() { + @Suspendable + override fun call() = Unit + } +} diff --git a/node/src/test/java/net/corda/node/services/vault/VaultQueryJavaTests.java b/node/src/test/java/net/corda/node/services/vault/VaultQueryJavaTests.java index f5b44a3487..1d2cbb9995 100644 --- a/node/src/test/java/net/corda/node/services/vault/VaultQueryJavaTests.java +++ b/node/src/test/java/net/corda/node/services/vault/VaultQueryJavaTests.java @@ -1,97 +1,73 @@ package net.corda.node.services.vault; -import com.google.common.collect.*; -import kotlin.*; -import net.corda.contracts.*; -import net.corda.contracts.asset.*; +import com.google.common.collect.ImmutableSet; +import kotlin.Pair; +import net.corda.contracts.DealState; +import net.corda.contracts.asset.Cash; import net.corda.core.contracts.*; -import net.corda.core.crypto.*; -import net.corda.core.identity.*; -import net.corda.core.messaging.*; -import net.corda.core.node.services.*; +import net.corda.core.crypto.EncodingUtils; +import net.corda.core.identity.AbstractParty; +import net.corda.core.messaging.DataFeed; +import net.corda.core.node.services.Vault; +import net.corda.core.node.services.VaultQueryException; +import net.corda.core.node.services.VaultQueryService; import net.corda.core.node.services.vault.*; -import net.corda.core.node.services.vault.QueryCriteria.*; -import net.corda.core.schemas.*; -import net.corda.core.transactions.*; -import net.corda.core.utilities.*; -import net.corda.node.services.database.*; -import net.corda.node.services.schema.*; -import net.corda.schemas.*; -import net.corda.testing.*; -import net.corda.testing.contracts.*; -import net.corda.testing.node.*; -import net.corda.testing.schemas.DummyLinearStateSchemaV1; -import org.jetbrains.annotations.*; -import org.jetbrains.exposed.sql.*; -import org.junit.*; +import net.corda.core.node.services.vault.QueryCriteria.LinearStateQueryCriteria; +import net.corda.core.node.services.vault.QueryCriteria.VaultCustomQueryCriteria; +import net.corda.core.node.services.vault.QueryCriteria.VaultQueryCriteria; +import net.corda.core.utilities.OpaqueBytes; +import net.corda.node.utilities.CordaPersistence; +import net.corda.schemas.CashSchemaV1; +import net.corda.testing.TestConstants; +import net.corda.testing.TestDependencyInjectionBase; +import net.corda.testing.contracts.DummyLinearContract; +import net.corda.testing.contracts.VaultFiller; +import net.corda.testing.node.MockServices; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; import rx.Observable; -import java.io.*; -import java.lang.reflect.*; +import java.io.IOException; +import java.lang.reflect.Field; +import java.security.KeyPair; import java.util.*; -import java.util.stream.*; +import java.util.stream.Collectors; +import java.util.stream.Stream; +import java.util.stream.StreamSupport; -import static net.corda.contracts.asset.CashKt.*; -import static net.corda.core.contracts.ContractsDSL.*; -import static net.corda.core.node.services.vault.QueryCriteriaUtils.*; -import static net.corda.node.utilities.DatabaseSupportKt.*; -import static net.corda.node.utilities.DatabaseSupportKt.transaction; -import static net.corda.testing.CoreTestUtils.*; -import static net.corda.testing.node.MockServicesKt.*; +import static net.corda.contracts.asset.CashKt.getDUMMY_CASH_ISSUER; +import static net.corda.contracts.asset.CashKt.getDUMMY_CASH_ISSUER_KEY; +import static net.corda.core.node.services.vault.QueryCriteriaUtils.DEFAULT_PAGE_NUM; +import static net.corda.core.node.services.vault.QueryCriteriaUtils.MAX_PAGE_SIZE; import static net.corda.core.utilities.ByteArrays.toHexString; -import static org.assertj.core.api.Assertions.*; +import static net.corda.testing.CoreTestUtils.*; +import static net.corda.testing.TestConstants.*; +import static net.corda.testing.node.MockServicesKt.makeTestDatabaseAndMockServices; +import static org.assertj.core.api.Assertions.assertThat; -public class VaultQueryJavaTests { +public class VaultQueryJavaTests extends TestDependencyInjectionBase { private MockServices services; - private VaultService vaultSvc; + private MockServices issuerServices; private VaultQueryService vaultQuerySvc; - private Closeable dataSource; - private Database database; + private CordaPersistence database; @Before public void setUp() { - Properties dataSourceProps = makeTestDataSourceProperties(SecureHash.randomSHA256().toString()); - Pair dataSourceAndDatabase = configureDatabase(dataSourceProps); - dataSource = dataSourceAndDatabase.getFirst(); - database = dataSourceAndDatabase.getSecond(); - - Set customSchemas = new HashSet<>(Collections.singletonList(DummyLinearStateSchemaV1.INSTANCE)); - HibernateConfiguration hibernateConfig = new HibernateConfiguration(new NodeSchemaService(customSchemas)); - transaction(database, - statement -> { services = new MockServices(getMEGA_CORP_KEY()) { - @NotNull - @Override - public VaultService getVaultService() { - return makeVaultService(dataSourceProps, hibernateConfig); - } - - @NotNull - @Override - public VaultQueryService getVaultQueryService() { - return new HibernateVaultQueryImpl(hibernateConfig, getVaultService().getUpdatesPublisher()); - } - - @Override - public void recordTransactions(@NotNull Iterable txs) { - for (SignedTransaction stx : txs) { - getValidatedTransactions().addTransaction(stx); - } - - Stream wtxn = StreamSupport.stream(txs.spliterator(), false).map(SignedTransaction::getTx); - getVaultService().notifyAll(wtxn.collect(Collectors.toList())); - } - }; - vaultSvc = services.getVaultService(); - vaultQuerySvc = services.getVaultQueryService(); - - return services; - }); + ArrayList keys = new ArrayList<>(); + keys.add(getMEGA_CORP_KEY()); + keys.add(getDUMMY_NOTARY_KEY()); + Pair databaseAndServices = makeTestDatabaseAndMockServices(Collections.EMPTY_SET, keys); + issuerServices = new MockServices(getDUMMY_CASH_ISSUER_KEY(), getBOC_KEY()); + database = databaseAndServices.getFirst(); + services = databaseAndServices.getSecond(); + vaultQuerySvc = services.getVaultQueryService(); } @After public void cleanUp() throws IOException { - dataSource.close(); + database.close(); } /** @@ -104,7 +80,7 @@ public class VaultQueryJavaTests { @Test public void unconsumedLinearStates() throws VaultQueryException { - transaction(database, tx -> { + database.transaction(tx -> { VaultFiller.fillWithSomeTestLinearStates(services, 3); @@ -120,7 +96,7 @@ public class VaultQueryJavaTests { @Test public void unconsumedStatesForStateRefsSortedByTxnId() { - transaction(database, tx -> { + database.transaction(tx -> { VaultFiller.fillWithSomeTestLinearStates(services, 8); Vault issuedStates = VaultFiller.fillWithSomeTestLinearStates(services, 2); @@ -145,22 +121,22 @@ public class VaultQueryJavaTests { @Test public void consumedCashStates() { - transaction(database, tx -> { + database.transaction(tx -> { Amount amount = new Amount<>(100, Currency.getInstance("USD")); VaultFiller.fillWithSomeTestCash(services, new Amount<>(100, Currency.getInstance("USD")), + issuerServices, TestConstants.getDUMMY_NOTARY(), 3, 3, new Random(), new OpaqueBytes("1".getBytes()), null, - getDUMMY_CASH_ISSUER(), - getDUMMY_CASH_ISSUER_KEY() ); + getDUMMY_CASH_ISSUER()); - VaultFiller.consumeCash(services, amount); + VaultFiller.consumeCash(services, amount, getDUMMY_NOTARY()); // DOCSTART VaultJavaQueryExample1 VaultQueryCriteria criteria = new VaultQueryCriteria(Vault.StateStatus.CONSUMED); @@ -175,7 +151,7 @@ public class VaultQueryJavaTests { @Test public void consumedDealStatesPagedSorted() throws VaultQueryException { - transaction(database, tx -> { + database.transaction(tx -> { Vault states = VaultFiller.fillWithSomeTestLinearStates(services, 10, null); StateAndRef linearState = states.getStates().iterator().next(); @@ -185,8 +161,8 @@ public class VaultQueryJavaTests { Vault dealStates = VaultFiller.fillWithSomeTestDeals(services, dealIds); // consume states - VaultFiller.consumeDeals(services, (List>) dealStates.getStates()); - VaultFiller.consumeLinearStates(services, Collections.singletonList(linearState)); + VaultFiller.consumeDeals(services, (List>) dealStates.getStates(), getDUMMY_NOTARY()); + VaultFiller.consumeLinearStates(services, Collections.singletonList(linearState), getDUMMY_NOTARY()); // DOCSTART VaultJavaQueryExample2 Vault.StateStatus status = Vault.StateStatus.CONSUMED; @@ -195,7 +171,7 @@ public class VaultQueryJavaTests { QueryCriteria vaultCriteria = new VaultQueryCriteria(status, contractStateTypes); - List linearIds = Collections.singletonList(uid); + List linearIds = Collections.singletonList(uid.getId()); QueryCriteria linearCriteriaAll = new LinearStateQueryCriteria(null, linearIds); QueryCriteria dealCriteriaAll = new LinearStateQueryCriteria(null, null, dealIds); @@ -217,17 +193,17 @@ public class VaultQueryJavaTests { @Test @SuppressWarnings("unchecked") public void customQueryForCashStatesWithAmountOfCurrencyGreaterOrEqualThanQuantity() { - transaction(database, tx -> { + database.transaction(tx -> { Amount pounds = new Amount<>(100, Currency.getInstance("GBP")); Amount dollars100 = new Amount<>(100, Currency.getInstance("USD")); Amount dollars10 = new Amount<>(10, Currency.getInstance("USD")); Amount dollars1 = new Amount<>(1, Currency.getInstance("USD")); - VaultFiller.fillWithSomeTestCash(services, pounds, TestConstants.getDUMMY_NOTARY(), 1, 1, new Random(0L), new OpaqueBytes("1".getBytes()), null, getDUMMY_CASH_ISSUER(), getDUMMY_CASH_ISSUER_KEY()); - VaultFiller.fillWithSomeTestCash(services, dollars100, TestConstants.getDUMMY_NOTARY(), 1, 1, new Random(0L), new OpaqueBytes("1".getBytes()), null, getDUMMY_CASH_ISSUER(), getDUMMY_CASH_ISSUER_KEY()); - VaultFiller.fillWithSomeTestCash(services, dollars10, TestConstants.getDUMMY_NOTARY(), 1, 1, new Random(0L), new OpaqueBytes("1".getBytes()), null, getDUMMY_CASH_ISSUER(), getDUMMY_CASH_ISSUER_KEY()); - VaultFiller.fillWithSomeTestCash(services, dollars1, TestConstants.getDUMMY_NOTARY(), 1, 1, new Random(0L), new OpaqueBytes("1".getBytes()), null, getDUMMY_CASH_ISSUER(), getDUMMY_CASH_ISSUER_KEY()); + VaultFiller.fillWithSomeTestCash(services, pounds, issuerServices, TestConstants.getDUMMY_NOTARY(), 1, 1, new Random(0L), new OpaqueBytes("1".getBytes()), null, getDUMMY_CASH_ISSUER()); + VaultFiller.fillWithSomeTestCash(services, dollars100, issuerServices, TestConstants.getDUMMY_NOTARY(), 1, 1, new Random(0L), new OpaqueBytes("1".getBytes()), null, getDUMMY_CASH_ISSUER()); + VaultFiller.fillWithSomeTestCash(services, dollars10, issuerServices, TestConstants.getDUMMY_NOTARY(), 1, 1, new Random(0L), new OpaqueBytes("1".getBytes()), null, getDUMMY_CASH_ISSUER()); + VaultFiller.fillWithSomeTestCash(services, dollars1, issuerServices, TestConstants.getDUMMY_NOTARY(), 1, 1, new Random(0L), new OpaqueBytes("1".getBytes()), null, getDUMMY_CASH_ISSUER()); try { // DOCSTART VaultJavaQueryExample3 @@ -261,27 +237,27 @@ public class VaultQueryJavaTests { @Test public void trackCashStates() { - transaction(database, tx -> { + database.transaction(tx -> { VaultFiller.fillWithSomeTestCash(services, new Amount<>(100, Currency.getInstance("USD")), + issuerServices, TestConstants.getDUMMY_NOTARY(), 3, 3, new Random(), new OpaqueBytes("1".getBytes()), null, - getDUMMY_CASH_ISSUER(), - getDUMMY_CASH_ISSUER_KEY() ); + getDUMMY_CASH_ISSUER()); // DOCSTART VaultJavaQueryExample4 @SuppressWarnings("unchecked") Set> contractStateTypes = new HashSet(Collections.singletonList(Cash.State.class)); VaultQueryCriteria criteria = new VaultQueryCriteria(Vault.StateStatus.UNCONSUMED, contractStateTypes); - DataFeed, Vault.Update> results = vaultQuerySvc.trackBy(ContractState.class, criteria); + DataFeed, Vault.Update> results = vaultQuerySvc.trackBy(ContractState.class, criteria); Vault.Page snapshot = results.getSnapshot(); - Observable updates = results.getUpdates(); + Observable> updates = results.getUpdates(); // DOCEND VaultJavaQueryExample4 assertThat(snapshot.getStates()).hasSize(3); @@ -292,7 +268,7 @@ public class VaultQueryJavaTests { @Test public void trackDealStatesPagedSorted() { - transaction(database, tx -> { + database.transaction(tx -> { Vault states = VaultFiller.fillWithSomeTestLinearStates(services, 10, null); UniqueIdentifier uid = states.getStates().iterator().next().component1().getData().getLinearId(); @@ -305,7 +281,7 @@ public class VaultQueryJavaTests { Set> contractStateTypes = new HashSet(Arrays.asList(DealState.class, LinearState.class)); QueryCriteria vaultCriteria = new VaultQueryCriteria(Vault.StateStatus.UNCONSUMED, contractStateTypes); - List linearIds = Collections.singletonList(uid); + List linearIds = Collections.singletonList(uid.getId()); List dealParty = Collections.singletonList(getMEGA_CORP()); QueryCriteria dealCriteria = new LinearStateQueryCriteria(dealParty, null, dealIds); QueryCriteria linearCriteria = new LinearStateQueryCriteria(dealParty, linearIds, null); @@ -315,10 +291,9 @@ public class VaultQueryJavaTests { PageSpecification pageSpec = new PageSpecification(DEFAULT_PAGE_NUM, MAX_PAGE_SIZE); Sort.SortColumn sortByUid = new Sort.SortColumn(new SortAttribute.Standard(Sort.LinearStateAttribute.UUID), Sort.Direction.DESC); Sort sorting = new Sort(ImmutableSet.of(sortByUid)); - DataFeed, Vault.Update> results = vaultQuerySvc.trackBy(ContractState.class, compositeCriteria, pageSpec, sorting); + DataFeed, Vault.Update> results = vaultQuerySvc.trackBy(ContractState.class, compositeCriteria, pageSpec, sorting); Vault.Page snapshot = results.getSnapshot(); - Observable updates = results.getUpdates(); // DOCEND VaultJavaQueryExample5 assertThat(snapshot.getStates()).hasSize(13); @@ -327,66 +302,6 @@ public class VaultQueryJavaTests { }); } - /** - * Deprecated usage - */ - - @Test - public void consumedStatesDeprecated() { - transaction(database, tx -> { - Amount amount = new Amount<>(100, USD); - VaultFiller.fillWithSomeTestCash(services, - new Amount<>(100, USD), - TestConstants.getDUMMY_NOTARY(), - 3, - 3, - new Random(), - new OpaqueBytes("1".getBytes()), - null, - getDUMMY_CASH_ISSUER(), - getDUMMY_CASH_ISSUER_KEY() ); - - VaultFiller.consumeCash(services, amount); - - // DOCSTART VaultDeprecatedJavaQueryExample1 - @SuppressWarnings("unchecked") - Set> contractStateTypes = new HashSet(Collections.singletonList(Cash.State.class)); - EnumSet status = EnumSet.of(Vault.StateStatus.CONSUMED); - - // WARNING! unfortunately cannot use inlined reified Kotlin extension methods. - Iterable> results = vaultSvc.states(contractStateTypes, status, true); - // DOCEND VaultDeprecatedJavaQueryExample1 - - assertThat(results).hasSize(3); - - return tx; - }); - } - - @Test - public void consumedStatesForLinearIdDeprecated() { - transaction(database, tx -> { - - Vault linearStates = VaultFiller.fillWithSomeTestLinearStates(services, 4,null); - linearStates.getStates().iterator().next().component1().getData().getLinearId(); - - VaultFiller.consumeLinearStates(services, (List>) linearStates.getStates()); - - // DOCSTART VaultDeprecatedJavaQueryExample0 - @SuppressWarnings("unchecked") - Set> contractStateTypes = new HashSet(Collections.singletonList(DummyLinearContract.State.class)); - EnumSet status = EnumSet.of(Vault.StateStatus.CONSUMED); - - // WARNING! unfortunately cannot use inlined reified Kotlin extension methods. - Iterable> results = vaultSvc.states(contractStateTypes, status, true); - // DOCEND VaultDeprecatedJavaQueryExample0 - - assertThat(results).hasSize(4); - - return tx; - }); - } - /** * Aggregation Functions */ @@ -394,7 +309,7 @@ public class VaultQueryJavaTests { @Test @SuppressWarnings("unchecked") public void aggregateFunctionsWithoutGroupClause() { - transaction(database, tx -> { + database.transaction(tx -> { Amount dollars100 = new Amount<>(100, Currency.getInstance("USD")); Amount dollars200 = new Amount<>(200, Currency.getInstance("USD")); @@ -402,11 +317,11 @@ public class VaultQueryJavaTests { Amount pounds = new Amount<>(400, Currency.getInstance("GBP")); Amount swissfrancs = new Amount<>(500, Currency.getInstance("CHF")); - VaultFiller.fillWithSomeTestCash(services, dollars100, TestConstants.getDUMMY_NOTARY(), 1, 1, new Random(0L), new OpaqueBytes("1".getBytes()), null, getDUMMY_CASH_ISSUER(), getDUMMY_CASH_ISSUER_KEY()); - VaultFiller.fillWithSomeTestCash(services, dollars200, TestConstants.getDUMMY_NOTARY(), 2, 2, new Random(0L), new OpaqueBytes("1".getBytes()), null, getDUMMY_CASH_ISSUER(), getDUMMY_CASH_ISSUER_KEY()); - VaultFiller.fillWithSomeTestCash(services, dollars300, TestConstants.getDUMMY_NOTARY(), 3, 3, new Random(0L), new OpaqueBytes("1".getBytes()), null, getDUMMY_CASH_ISSUER(), getDUMMY_CASH_ISSUER_KEY()); - VaultFiller.fillWithSomeTestCash(services, pounds, TestConstants.getDUMMY_NOTARY(), 4, 4, new Random(0L), new OpaqueBytes("1".getBytes()), null, getDUMMY_CASH_ISSUER(), getDUMMY_CASH_ISSUER_KEY()); - VaultFiller.fillWithSomeTestCash(services, swissfrancs, TestConstants.getDUMMY_NOTARY(), 5, 5, new Random(0L), new OpaqueBytes("1".getBytes()), null, getDUMMY_CASH_ISSUER(), getDUMMY_CASH_ISSUER_KEY()); + VaultFiller.fillWithSomeTestCash(services, dollars100, issuerServices, TestConstants.getDUMMY_NOTARY(), 1, 1, new Random(0L), new OpaqueBytes("1".getBytes()), null, getDUMMY_CASH_ISSUER()); + VaultFiller.fillWithSomeTestCash(services, dollars200, issuerServices, TestConstants.getDUMMY_NOTARY(), 2, 2, new Random(0L), new OpaqueBytes("1".getBytes()), null, getDUMMY_CASH_ISSUER()); + VaultFiller.fillWithSomeTestCash(services, dollars300, issuerServices, TestConstants.getDUMMY_NOTARY(), 3, 3, new Random(0L), new OpaqueBytes("1".getBytes()), null, getDUMMY_CASH_ISSUER()); + VaultFiller.fillWithSomeTestCash(services, pounds, issuerServices, TestConstants.getDUMMY_NOTARY(), 4, 4, new Random(0L), new OpaqueBytes("1".getBytes()), null, getDUMMY_CASH_ISSUER()); + VaultFiller.fillWithSomeTestCash(services, swissfrancs, issuerServices, TestConstants.getDUMMY_NOTARY(), 5, 5, new Random(0L), new OpaqueBytes("1".getBytes()), null, getDUMMY_CASH_ISSUER()); try { // DOCSTART VaultJavaQueryExample21 @@ -439,7 +354,7 @@ public class VaultQueryJavaTests { @Test @SuppressWarnings("unchecked") public void aggregateFunctionsWithSingleGroupClause() { - transaction(database, tx -> { + database.transaction(tx -> { Amount dollars100 = new Amount<>(100, Currency.getInstance("USD")); Amount dollars200 = new Amount<>(200, Currency.getInstance("USD")); @@ -447,11 +362,11 @@ public class VaultQueryJavaTests { Amount pounds = new Amount<>(400, Currency.getInstance("GBP")); Amount swissfrancs = new Amount<>(500, Currency.getInstance("CHF")); - VaultFiller.fillWithSomeTestCash(services, dollars100, TestConstants.getDUMMY_NOTARY(), 1, 1, new Random(0L), new OpaqueBytes("1".getBytes()), null, getDUMMY_CASH_ISSUER(), getDUMMY_CASH_ISSUER_KEY()); - VaultFiller.fillWithSomeTestCash(services, dollars200, TestConstants.getDUMMY_NOTARY(), 2, 2, new Random(0L), new OpaqueBytes("1".getBytes()), null, getDUMMY_CASH_ISSUER(), getDUMMY_CASH_ISSUER_KEY()); - VaultFiller.fillWithSomeTestCash(services, dollars300, TestConstants.getDUMMY_NOTARY(), 3, 3, new Random(0L), new OpaqueBytes("1".getBytes()), null, getDUMMY_CASH_ISSUER(), getDUMMY_CASH_ISSUER_KEY()); - VaultFiller.fillWithSomeTestCash(services, pounds, TestConstants.getDUMMY_NOTARY(), 4, 4, new Random(0L), new OpaqueBytes("1".getBytes()), null, getDUMMY_CASH_ISSUER(), getDUMMY_CASH_ISSUER_KEY()); - VaultFiller.fillWithSomeTestCash(services, swissfrancs, TestConstants.getDUMMY_NOTARY(), 5, 5, new Random(0L), new OpaqueBytes("1".getBytes()), null, getDUMMY_CASH_ISSUER(), getDUMMY_CASH_ISSUER_KEY()); + VaultFiller.fillWithSomeTestCash(services, dollars100, issuerServices, TestConstants.getDUMMY_NOTARY(), 1, 1, new Random(0L), new OpaqueBytes("1".getBytes()), null, getDUMMY_CASH_ISSUER()); + VaultFiller.fillWithSomeTestCash(services, dollars200, issuerServices, TestConstants.getDUMMY_NOTARY(), 2, 2, new Random(0L), new OpaqueBytes("1".getBytes()), null, getDUMMY_CASH_ISSUER()); + VaultFiller.fillWithSomeTestCash(services, dollars300, issuerServices, TestConstants.getDUMMY_NOTARY(), 3, 3, new Random(0L), new OpaqueBytes("1".getBytes()), null, getDUMMY_CASH_ISSUER()); + VaultFiller.fillWithSomeTestCash(services, pounds, issuerServices, TestConstants.getDUMMY_NOTARY(), 4, 4, new Random(0L), new OpaqueBytes("1".getBytes()), null, getDUMMY_CASH_ISSUER()); + VaultFiller.fillWithSomeTestCash(services, swissfrancs, issuerServices, TestConstants.getDUMMY_NOTARY(), 5, 5, new Random(0L), new OpaqueBytes("1".getBytes()), null, getDUMMY_CASH_ISSUER()); try { // DOCSTART VaultJavaQueryExample22 @@ -510,17 +425,17 @@ public class VaultQueryJavaTests { @Test @SuppressWarnings("unchecked") public void aggregateFunctionsSumByIssuerAndCurrencyAndSortByAggregateSum() { - transaction(database, tx -> { + database.transaction(tx -> { Amount dollars100 = new Amount<>(100, Currency.getInstance("USD")); Amount dollars200 = new Amount<>(200, Currency.getInstance("USD")); Amount pounds300 = new Amount<>(300, Currency.getInstance("GBP")); Amount pounds400 = new Amount<>(400, Currency.getInstance("GBP")); - VaultFiller.fillWithSomeTestCash(services, dollars100, TestConstants.getDUMMY_NOTARY(), 1, 1, new Random(0L), new OpaqueBytes("1".getBytes()), null, getDUMMY_CASH_ISSUER(), getDUMMY_CASH_ISSUER_KEY()); - VaultFiller.fillWithSomeTestCash(services, dollars200, TestConstants.getDUMMY_NOTARY(), 2, 2, new Random(0L), new OpaqueBytes("1".getBytes()), null, getBOC().ref(new OpaqueBytes("1".getBytes())), getBOC_KEY()); - VaultFiller.fillWithSomeTestCash(services, pounds300, TestConstants.getDUMMY_NOTARY(), 3, 3, new Random(0L), new OpaqueBytes("1".getBytes()), null, getDUMMY_CASH_ISSUER(), getDUMMY_CASH_ISSUER_KEY()); - VaultFiller.fillWithSomeTestCash(services, pounds400, TestConstants.getDUMMY_NOTARY(), 4, 4, new Random(0L), new OpaqueBytes("1".getBytes()), null, getBOC().ref(new OpaqueBytes("1".getBytes())), getBOC_KEY()); + VaultFiller.fillWithSomeTestCash(services, dollars100, issuerServices, TestConstants.getDUMMY_NOTARY(), 1, 1, new Random(0L), new OpaqueBytes("1".getBytes()), null, getDUMMY_CASH_ISSUER()); + VaultFiller.fillWithSomeTestCash(services, dollars200, issuerServices, TestConstants.getDUMMY_NOTARY(), 2, 2, new Random(0L), new OpaqueBytes("1".getBytes()), null, getBOC().ref(new OpaqueBytes("1".getBytes()))); + VaultFiller.fillWithSomeTestCash(services, pounds300, issuerServices, TestConstants.getDUMMY_NOTARY(), 3, 3, new Random(0L), new OpaqueBytes("1".getBytes()), null, getDUMMY_CASH_ISSUER()); + VaultFiller.fillWithSomeTestCash(services, pounds400, issuerServices, TestConstants.getDUMMY_NOTARY(), 4, 4, new Random(0L), new OpaqueBytes("1".getBytes()), null, getBOC().ref(new OpaqueBytes("1".getBytes()))); try { // DOCSTART VaultJavaQueryExample23 diff --git a/node/src/test/kotlin/net/corda/node/ArgsParserTest.kt b/node/src/test/kotlin/net/corda/node/ArgsParserTest.kt index 2316f31e92..94b09c11c5 100644 --- a/node/src/test/kotlin/net/corda/node/ArgsParserTest.kt +++ b/node/src/test/kotlin/net/corda/node/ArgsParserTest.kt @@ -1,7 +1,7 @@ package net.corda.node import joptsimple.OptionException -import net.corda.core.div +import net.corda.core.internal.div import org.assertj.core.api.Assertions.assertThat import org.assertj.core.api.Assertions.assertThatExceptionOfType import org.junit.Test diff --git a/node/src/test/kotlin/net/corda/node/CordaRPCOpsImplTest.kt b/node/src/test/kotlin/net/corda/node/CordaRPCOpsImplTest.kt index 11009e3c42..6d2d235610 100644 --- a/node/src/test/kotlin/net/corda/node/CordaRPCOpsImplTest.kt +++ b/node/src/test/kotlin/net/corda/node/CordaRPCOpsImplTest.kt @@ -7,13 +7,13 @@ import net.corda.core.crypto.isFulfilledBy import net.corda.core.crypto.keys import net.corda.core.flows.FlowLogic import net.corda.core.flows.StateMachineRunId -import net.corda.core.getOrThrow import net.corda.core.messaging.* import net.corda.core.node.services.ServiceInfo import net.corda.core.node.services.Vault -import net.corda.core.node.services.unconsumedStates -import net.corda.core.utilities.OpaqueBytes import net.corda.core.transactions.SignedTransaction +import net.corda.core.utilities.getOrThrow +import net.corda.core.node.services.queryBy +import net.corda.core.utilities.OpaqueBytes import net.corda.flows.CashIssueFlow import net.corda.flows.CashPaymentFlow import net.corda.node.internal.CordaRPCOpsImpl @@ -22,7 +22,6 @@ import net.corda.node.services.messaging.RpcContext import net.corda.node.services.network.NetworkMapService import net.corda.node.services.startFlowPermission import net.corda.node.services.transactions.SimpleNotaryService -import net.corda.node.utilities.transaction import net.corda.nodeapi.PermissionException import net.corda.nodeapi.User import net.corda.testing.expect @@ -54,8 +53,7 @@ class CordaRPCOpsImplTest { lateinit var rpc: CordaRPCOps lateinit var stateMachineUpdates: Observable lateinit var transactions: Observable - lateinit var vaultUpdates: Observable // TODO: deprecated - lateinit var vaultTrackCash: Observable + lateinit var vaultTrackCash: Observable> @Before fun setup() { @@ -70,10 +68,9 @@ class CordaRPCOpsImplTest { )))) aliceNode.database.transaction { - stateMachineUpdates = rpc.stateMachinesAndUpdates().second - transactions = rpc.verifiedTransactions().second - vaultUpdates = rpc.vaultAndUpdates().second - vaultTrackCash = rpc.vaultTrackBy().future + stateMachineUpdates = rpc.stateMachinesFeed().updates + transactions = rpc.verifiedTransactionsFeed().updates + vaultTrackCash = rpc.vaultTrackBy().updates } } @@ -89,7 +86,7 @@ class CordaRPCOpsImplTest { // Check the monitoring service wallet is empty aliceNode.database.transaction { - assertFalse(aliceNode.services.vaultService.unconsumedStates().iterator().hasNext()) + assertFalse(aliceNode.services.vaultQueryService.queryBy().totalStatesAvailable > 0) } // Tell the monitoring service node to issue some cash @@ -111,7 +108,7 @@ class CordaRPCOpsImplTest { ) } - val tx = result.returnValue.getOrThrow() + result.returnValue.getOrThrow() val expectedState = Cash.State(Amount(quantity, Issued(aliceNode.info.legalIdentity.ref(ref), GBP)), recipient) @@ -120,14 +117,6 @@ class CordaRPCOpsImplTest { val cash = rpc.vaultQueryBy() assertEquals(expectedState, cash.states.first().state.data) - // TODO: deprecated - vaultUpdates.expectEvents { - expect { update -> - val actual = update.produced.single().state.data - assertEquals(expectedState, actual) - } - } - vaultTrackCash.expectEvents { expect { update -> val actual = update.produced.single().state.data @@ -140,7 +129,7 @@ class CordaRPCOpsImplTest { fun `issue and move`() { val anonymous = false val result = rpc.startFlow(::CashIssueFlow, - Amount(100, USD), + 100.DOLLARS, OpaqueBytes(ByteArray(1, { 1 })), aliceNode.info.legalIdentity, notaryNode.info.notaryIdentity, @@ -149,13 +138,13 @@ class CordaRPCOpsImplTest { mockNet.runNetwork() - rpc.startFlow(::CashPaymentFlow, Amount(100, USD), aliceNode.info.legalIdentity, anonymous) + rpc.startFlow(::CashPaymentFlow, 100.DOLLARS, aliceNode.info.legalIdentity, anonymous) mockNet.runNetwork() var issueSmId: StateMachineRunId? = null var moveSmId: StateMachineRunId? = null - stateMachineUpdates.expectEvents() { + stateMachineUpdates.expectEvents { sequence( // ISSUE expect { add: StateMachineUpdate.Added -> @@ -174,7 +163,7 @@ class CordaRPCOpsImplTest { ) } - val tx = result.returnValue.getOrThrow() + result.returnValue.getOrThrow() transactions.expectEvents { sequence( // ISSUE @@ -199,33 +188,17 @@ class CordaRPCOpsImplTest { ) } - // TODO: deprecated - vaultUpdates.expectEvents { - sequence( - // ISSUE - expect { update -> - require(update.consumed.isEmpty()) { update.consumed.size } - require(update.produced.size == 1) { update.produced.size } - }, - // MOVE - expect { update -> - require(update.consumed.size == 1) { update.consumed.size } - require(update.produced.size == 1) { update.produced.size } - } - ) - } - vaultTrackCash.expectEvents { sequence( // ISSUE - expect { update -> - require(update.consumed.isEmpty()) { update.consumed.size } - require(update.produced.size == 1) { update.produced.size } + expect { (consumed, produced) -> + require(consumed.isEmpty()) { consumed.size } + require(produced.size == 1) { produced.size } }, // MOVE - expect { update -> - require(update.consumed.size == 1) { update.consumed.size } - require(update.produced.size == 1) { update.produced.size } + expect { (consumed, produced) -> + require(consumed.size == 1) { consumed.size } + require(produced.size == 1) { produced.size } } ) } diff --git a/node/src/test/kotlin/net/corda/node/InteractiveShellTest.kt b/node/src/test/kotlin/net/corda/node/InteractiveShellTest.kt index 4a2c2051de..4838ee78b5 100644 --- a/node/src/test/kotlin/net/corda/node/InteractiveShellTest.kt +++ b/node/src/test/kotlin/net/corda/node/InteractiveShellTest.kt @@ -1,21 +1,23 @@ package net.corda.node import com.fasterxml.jackson.dataformat.yaml.YAMLFactory -import com.google.common.util.concurrent.ListenableFuture +import net.corda.core.concurrent.CordaFuture import net.corda.core.contracts.Amount import net.corda.core.crypto.SecureHash +import net.corda.core.flows.FlowContext import net.corda.core.flows.FlowInitiator import net.corda.core.flows.FlowLogic -import net.corda.core.internal.FlowStateMachine +import net.corda.core.flows.FlowStackSnapshot import net.corda.core.flows.StateMachineRunId import net.corda.core.identity.Party +import net.corda.core.internal.FlowStateMachine import net.corda.core.node.ServiceHub import net.corda.core.transactions.SignedTransaction -import net.corda.testing.DUMMY_CA import net.corda.core.utilities.UntrustworthyData import net.corda.jackson.JacksonSupport import net.corda.node.services.identity.InMemoryIdentityService import net.corda.node.shell.InteractiveShell +import net.corda.testing.DUMMY_CA import net.corda.testing.MEGA_CORP import net.corda.testing.MEGA_CORP_IDENTITY import org.junit.Test @@ -31,6 +33,7 @@ class InteractiveShellTest { constructor(amount: Amount) : this(amount.toString()) constructor(pair: Pair, SecureHash.SHA256>) : this(pair.toString()) constructor(party: Party) : this(party.name.toString()) + override fun call() = a } @@ -70,32 +73,26 @@ class InteractiveShellTest { fun party() = check("party: \"${MEGA_CORP.name}\"", MEGA_CORP.name.toString()) class DummyFSM(val logic: FlowA) : FlowStateMachine { + override fun getFlowContext(otherParty: Party, sessionFlow: FlowLogic<*>): FlowContext { + throw UnsupportedOperationException("not implemented") + } override fun sendAndReceive(receiveType: Class, otherParty: Party, payload: Any, sessionFlow: FlowLogic<*>, retrySend: Boolean): UntrustworthyData { throw UnsupportedOperationException("not implemented") } - override fun receive(receiveType: Class, otherParty: Party, sessionFlow: FlowLogic<*>): UntrustworthyData { throw UnsupportedOperationException("not implemented") } - override fun send(otherParty: Party, payload: Any, sessionFlow: FlowLogic<*>) { throw UnsupportedOperationException("not implemented") } - override fun waitForLedgerCommit(hash: SecureHash, sessionFlow: FlowLogic<*>): SignedTransaction { throw UnsupportedOperationException("not implemented") } - - override val serviceHub: ServiceHub - get() = throw UnsupportedOperationException() - override val logger: Logger - get() = throw UnsupportedOperationException() - override val id: StateMachineRunId - get() = throw UnsupportedOperationException() - override val resultFuture: ListenableFuture - get() = throw UnsupportedOperationException() - override val flowInitiator: FlowInitiator - get() = throw UnsupportedOperationException() + override val serviceHub: ServiceHub get() = throw UnsupportedOperationException() + override val logger: Logger get() = throw UnsupportedOperationException() + override val id: StateMachineRunId get() = throw UnsupportedOperationException() + override val resultFuture: CordaFuture get() = throw UnsupportedOperationException() + override val flowInitiator: FlowInitiator get() = throw UnsupportedOperationException() override fun checkFlowPermission(permissionName: String, extraAuditData: Map) { // Do nothing @@ -104,5 +101,13 @@ class InteractiveShellTest { override fun recordAuditEvent(eventType: String, comment: String, extraAuditData: Map) { // Do nothing } + + override fun flowStackSnapshot(flowClass: Class<*>): FlowStackSnapshot? { + return null + } + + override fun persistFlowStackSnapshot(flowClass: Class<*>) { + // Do nothing + } } } \ No newline at end of file diff --git a/node/src/test/kotlin/net/corda/node/messaging/InMemoryMessagingTests.kt b/node/src/test/kotlin/net/corda/node/messaging/InMemoryMessagingTests.kt index b935635105..7dbce0a062 100644 --- a/node/src/test/kotlin/net/corda/node/messaging/InMemoryMessagingTests.kt +++ b/node/src/test/kotlin/net/corda/node/messaging/InMemoryMessagingTests.kt @@ -1,13 +1,15 @@ package net.corda.node.messaging -import net.corda.core.node.services.DEFAULT_SESSION_ID import net.corda.core.node.services.ServiceInfo +import net.corda.node.services.api.DEFAULT_SESSION_ID import net.corda.node.services.messaging.Message import net.corda.node.services.messaging.TopicStringValidator import net.corda.node.services.messaging.createMessage import net.corda.node.services.network.NetworkMapService import net.corda.testing.node.MockNetwork import org.junit.After +import net.corda.testing.resetTestSerialization +import org.junit.Before import org.junit.Test import java.util.* import kotlin.test.assertEquals @@ -15,11 +17,20 @@ import kotlin.test.assertFails import kotlin.test.assertTrue class InMemoryMessagingTests { - val mockNet = MockNetwork() + lateinit var mockNet: MockNetwork + + @Before + fun setUp() { + mockNet = MockNetwork() + } @After - fun cleanUp() { - mockNet.stopNodes() + fun tearDown() { + if (mockNet.nodes.isNotEmpty()) { + mockNet.stopNodes() + } else { + resetTestSerialization() + } } @Test diff --git a/node/src/test/kotlin/net/corda/node/messaging/TwoPartyTradeFlowTests.kt b/node/src/test/kotlin/net/corda/node/messaging/TwoPartyTradeFlowTests.kt index 049e44a463..f86adff8c3 100644 --- a/node/src/test/kotlin/net/corda/node/messaging/TwoPartyTradeFlowTests.kt +++ b/node/src/test/kotlin/net/corda/node/messaging/TwoPartyTradeFlowTests.kt @@ -2,12 +2,13 @@ package net.corda.node.messaging import co.paralleluniverse.fibers.Suspendable import net.corda.contracts.CommercialPaper -import net.corda.contracts.asset.* -import net.corda.core.* +import net.corda.contracts.asset.CASH +import net.corda.contracts.asset.Cash +import net.corda.contracts.asset.`issued by` +import net.corda.contracts.asset.`owned by` +import net.corda.core.concurrent.CordaFuture import net.corda.core.contracts.* -import net.corda.core.crypto.DigitalSignature -import net.corda.core.crypto.SecureHash -import net.corda.core.crypto.sign +import net.corda.core.crypto.* import net.corda.core.flows.FlowLogic import net.corda.core.flows.InitiatedBy import net.corda.core.flows.InitiatingFlow @@ -16,17 +17,21 @@ import net.corda.core.identity.AbstractParty import net.corda.core.identity.AnonymousParty import net.corda.core.identity.Party import net.corda.core.internal.FlowStateMachine +import net.corda.core.internal.concurrent.map +import net.corda.core.internal.rootCause import net.corda.core.messaging.DataFeed import net.corda.core.messaging.SingleMessageRecipient import net.corda.core.messaging.StateMachineTransactionMapping import net.corda.core.node.NodeInfo import net.corda.core.node.services.ServiceInfo import net.corda.core.node.services.Vault -import net.corda.core.serialization.serialize +import net.corda.core.toFuture import net.corda.core.transactions.SignedTransaction import net.corda.core.transactions.TransactionBuilder import net.corda.core.transactions.WireTransaction -import net.corda.testing.LogHelper +import net.corda.core.utilities.days +import net.corda.core.utilities.getOrThrow +import net.corda.core.utilities.toNonEmptySet import net.corda.core.utilities.unwrap import net.corda.flows.TwoPartyTradeFlow.Buyer import net.corda.flows.TwoPartyTradeFlow.Seller @@ -35,14 +40,13 @@ import net.corda.node.services.api.WritableTransactionStorage import net.corda.node.services.config.NodeConfiguration import net.corda.node.services.persistence.DBTransactionStorage import net.corda.node.services.persistence.checkpoints -import net.corda.node.utilities.transaction +import net.corda.node.utilities.CordaPersistence import net.corda.testing.* import net.corda.testing.contracts.fillWithSomeTestCash import net.corda.testing.node.InMemoryMessagingNetwork import net.corda.testing.node.MockNetwork import org.assertj.core.api.Assertions.assertThat import org.bouncycastle.asn1.x500.X500Name -import org.jetbrains.exposed.sql.Database import org.junit.After import org.junit.Before import org.junit.Test @@ -51,9 +55,7 @@ import java.io.ByteArrayInputStream import java.io.ByteArrayOutputStream import java.math.BigInteger import java.security.KeyPair -import java.security.PublicKey import java.util.* -import java.util.concurrent.Future import java.util.jar.JarOutputStream import java.util.zip.ZipEntry import kotlin.test.assertEquals @@ -71,7 +73,6 @@ class TwoPartyTradeFlowTests { @Before fun before() { - mockNet = MockNetwork(false) LogHelper.setLevel("platform.trade", "core.contract.TransactionGroup", "recordingmap") } @@ -88,25 +89,29 @@ class TwoPartyTradeFlowTests { // allow interruption half way through. mockNet = MockNetwork(false, true) - ledger { - val basketOfNodes = mockNet.createSomeNodes(2) + ledger(initialiseSerialization = false) { + val basketOfNodes = mockNet.createSomeNodes(3) val notaryNode = basketOfNodes.notaryNode val aliceNode = basketOfNodes.partyNodes[0] val bobNode = basketOfNodes.partyNodes[1] + val bankNode = basketOfNodes.partyNodes[2] + val cashIssuer = bankNode.info.legalIdentity.ref(1) + val cpIssuer = bankNode.info.legalIdentity.ref(1, 2, 3) aliceNode.disableDBCloseOnStop() bobNode.disableDBCloseOnStop() bobNode.database.transaction { - bobNode.services.fillWithSomeTestCash(2000.DOLLARS, outputNotary = notaryNode.info.notaryIdentity) + bobNode.services.fillWithSomeTestCash(2000.DOLLARS, bankNode.services, outputNotary = notaryNode.info.notaryIdentity, + issuedBy = cashIssuer) } val alicesFakePaper = aliceNode.database.transaction { - fillUpForSeller(false, aliceNode.info.legalIdentity, - 1200.DOLLARS `issued by` DUMMY_CASH_ISSUER, null, notaryNode.info.notaryIdentity).second + fillUpForSeller(false, cpIssuer, aliceNode.info.legalIdentity, + 1200.DOLLARS `issued by` bankNode.info.legalIdentity.ref(0), null, notaryNode.info.notaryIdentity).second } - insertFakeTransactions(alicesFakePaper, aliceNode, notaryNode, MEGA_CORP_PUBKEY) + insertFakeTransactions(alicesFakePaper, aliceNode, notaryNode, bankNode) val (bobStateMachine, aliceResult) = runBuyerAndSeller(notaryNode, aliceNode, bobNode, "alice's paper".outputStateAndRef()) @@ -133,33 +138,40 @@ class TwoPartyTradeFlowTests { fun `trade cash for commercial paper fails using soft locking`() { mockNet = MockNetwork(false, true) - ledger { + ledger(initialiseSerialization = false) { val notaryNode = mockNet.createNotaryNode(null, DUMMY_NOTARY.name) val aliceNode = mockNet.createPartyNode(notaryNode.network.myAddress, ALICE.name) val bobNode = mockNet.createPartyNode(notaryNode.network.myAddress, BOB.name) + val bankNode = mockNet.createPartyNode(notaryNode.network.myAddress, BOC.name) + val cashIssuer = bankNode.info.legalIdentity.ref(1) + val cpIssuer = bankNode.info.legalIdentity.ref(1, 2, 3) aliceNode.disableDBCloseOnStop() bobNode.disableDBCloseOnStop() val cashStates = bobNode.database.transaction { - bobNode.services.fillWithSomeTestCash(2000.DOLLARS, notaryNode.info.notaryIdentity, 3, 3) + bobNode.services.fillWithSomeTestCash(2000.DOLLARS, bankNode.services, notaryNode.info.notaryIdentity, 3, 3, + issuedBy = cashIssuer) } val alicesFakePaper = aliceNode.database.transaction { - fillUpForSeller(false, aliceNode.info.legalIdentity, - 1200.DOLLARS `issued by` DUMMY_CASH_ISSUER, null, notaryNode.info.notaryIdentity).second + fillUpForSeller(false, cpIssuer, aliceNode.info.legalIdentity, + 1200.DOLLARS `issued by` bankNode.info.legalIdentity.ref(0), null, notaryNode.info.notaryIdentity).second } - insertFakeTransactions(alicesFakePaper, aliceNode, notaryNode, MEGA_CORP_PUBKEY) + insertFakeTransactions(alicesFakePaper, aliceNode, notaryNode, bankNode) val cashLockId = UUID.randomUUID() bobNode.database.transaction { // lock the cash states with an arbitrary lockId (to prevent the Buyer flow from claiming the states) - bobNode.services.vaultService.softLockReserve(cashLockId, cashStates.states.map { it.ref }.toSet()) + val refs = cashStates.states.map { it.ref } + if (refs.isNotEmpty()) { + bobNode.services.vaultService.softLockReserve(cashLockId, refs.toNonEmptySet()) + } } val (bobStateMachine, aliceResult) = runBuyerAndSeller(notaryNode, aliceNode, bobNode, - "alice's paper".outputStateAndRef()) + "alice's paper".outputStateAndRef()) assertEquals(aliceResult.getOrThrow(), bobStateMachine.getOrThrow().resultFuture.getOrThrow()) @@ -179,26 +191,34 @@ class TwoPartyTradeFlowTests { @Test fun `shutdown and restore`() { - ledger { + mockNet = MockNetwork(false) + ledger(initialiseSerialization = false) { val notaryNode = mockNet.createNotaryNode(null, DUMMY_NOTARY.name) val aliceNode = mockNet.createPartyNode(notaryNode.network.myAddress, ALICE.name) var bobNode = mockNet.createPartyNode(notaryNode.network.myAddress, BOB.name) + val bankNode = mockNet.createPartyNode(notaryNode.network.myAddress, BOC.name) + val cashIssuer = bankNode.info.legalIdentity.ref(1) + val cpIssuer = bankNode.info.legalIdentity.ref(1, 2, 3) + + aliceNode.services.identityService.registerIdentity(bobNode.info.legalIdentityAndCert) + bobNode.services.identityService.registerIdentity(aliceNode.info.legalIdentityAndCert) aliceNode.disableDBCloseOnStop() bobNode.disableDBCloseOnStop() val bobAddr = bobNode.network.myAddress as InMemoryMessagingNetwork.PeerHandle - val networkMapAddr = notaryNode.network.myAddress + val networkMapAddress = notaryNode.network.myAddress mockNet.runNetwork() // Clear network map registration messages bobNode.database.transaction { - bobNode.services.fillWithSomeTestCash(2000.DOLLARS, outputNotary = notaryNode.info.notaryIdentity) + bobNode.services.fillWithSomeTestCash(2000.DOLLARS, bankNode.services, outputNotary = notaryNode.info.notaryIdentity, + issuedBy = cashIssuer) } val alicesFakePaper = aliceNode.database.transaction { - fillUpForSeller(false, aliceNode.info.legalIdentity, - 1200.DOLLARS `issued by` DUMMY_CASH_ISSUER, null, notaryNode.info.notaryIdentity).second + fillUpForSeller(false, cpIssuer, aliceNode.info.legalIdentity, + 1200.DOLLARS `issued by` bankNode.info.legalIdentity.ref(0), null, notaryNode.info.notaryIdentity).second } - insertFakeTransactions(alicesFakePaper, aliceNode, notaryNode, MEGA_CORP_PUBKEY) + insertFakeTransactions(alicesFakePaper, aliceNode, notaryNode, bankNode) val aliceFuture = runBuyerAndSeller(notaryNode, aliceNode, bobNode, "alice's paper".outputStateAndRef()).sellerResult // Everything is on this thread so we can now step through the flow one step at a time. @@ -233,7 +253,7 @@ class TwoPartyTradeFlowTests { // ... bring the node back up ... the act of constructing the SMM will re-register the message handlers // that Bob was waiting on before the reboot occurred. - bobNode = mockNet.createNode(networkMapAddr, bobAddr.id, object : MockNetwork.Factory { + bobNode = mockNet.createNode(networkMapAddress, bobAddr.id, object : MockNetwork.Factory { override fun create(config: NodeConfiguration, network: MockNetwork, networkMapAddr: SingleMessageRecipient?, advertisedServices: Set, id: Int, overrideServices: Map?, entropyRoot: BigInteger): MockNetwork.MockNode { @@ -273,11 +293,10 @@ class TwoPartyTradeFlowTests { // Creates a mock node with an overridden storage service that uses a RecordingMap, that lets us test the order // of gets and puts. private fun makeNodeWithTracking( - networkMapAddr: SingleMessageRecipient?, - name: X500Name, - overrideServices: Map? = null): MockNetwork.MockNode { + networkMapAddress: SingleMessageRecipient?, + name: X500Name): MockNetwork.MockNode { // Create a node in the mock network ... - return mockNet.createNode(networkMapAddr, -1, object : MockNetwork.Factory { + return mockNet.createNode(networkMapAddress, nodeFactory = object : MockNetwork.Factory { override fun create(config: NodeConfiguration, network: MockNetwork, networkMapAddr: SingleMessageRecipient?, @@ -291,16 +310,20 @@ class TwoPartyTradeFlowTests { } } } - }, true, name, overrideServices) + }, legalName = name) } @Test fun `check dependencies of sale asset are resolved`() { + mockNet = MockNetwork(false) + val notaryNode = mockNet.createNotaryNode(null, DUMMY_NOTARY.name) val aliceNode = makeNodeWithTracking(notaryNode.network.myAddress, ALICE.name) val bobNode = makeNodeWithTracking(notaryNode.network.myAddress, BOB.name) + val bankNode = makeNodeWithTracking(notaryNode.network.myAddress, BOC.name) + val issuer = bankNode.info.legalIdentity.ref(1, 2, 3) - ledger(aliceNode.services) { + ledger(aliceNode.services, initialiseSerialization = false) { // Insert a prospectus type attachment into the commercial paper transaction. val stream = ByteArrayOutputStream() @@ -313,16 +336,14 @@ class TwoPartyTradeFlowTests { attachment(ByteArrayInputStream(stream.toByteArray())) } - val extraKey = bobNode.services.keyManagementService.keys.single() - val bobsFakeCash = fillUpForBuyer(false, AnonymousParty(extraKey), - DUMMY_CASH_ISSUER.party, + val bobsFakeCash = fillUpForBuyer(false, issuer, AnonymousParty(bobNode.info.legalIdentity.owningKey), notaryNode.info.notaryIdentity).second - val bobsSignedTxns = insertFakeTransactions(bobsFakeCash, bobNode, notaryNode, extraKey, DUMMY_CASH_ISSUER_KEY.public, MEGA_CORP_PUBKEY) + val bobsSignedTxns = insertFakeTransactions(bobsFakeCash, bobNode, notaryNode, bankNode) val alicesFakePaper = aliceNode.database.transaction { - fillUpForSeller(false, aliceNode.info.legalIdentity, - 1200.DOLLARS `issued by` DUMMY_CASH_ISSUER, attachmentID, notaryNode.info.notaryIdentity).second + fillUpForSeller(false, issuer, aliceNode.info.legalIdentity, + 1200.DOLLARS `issued by` bankNode.info.legalIdentity.ref(0), attachmentID, notaryNode.info.notaryIdentity).second } - val alicesSignedTxns = insertFakeTransactions(alicesFakePaper, aliceNode, notaryNode, MEGA_CORP_PUBKEY) + val alicesSignedTxns = insertFakeTransactions(alicesFakePaper, aliceNode, notaryNode, bankNode) mockNet.runNetwork() // Clear network map registration messages @@ -395,11 +416,15 @@ class TwoPartyTradeFlowTests { @Test fun `track works`() { + mockNet = MockNetwork(false) + val notaryNode = mockNet.createNotaryNode(null, DUMMY_NOTARY.name) val aliceNode = makeNodeWithTracking(notaryNode.network.myAddress, ALICE.name) val bobNode = makeNodeWithTracking(notaryNode.network.myAddress, BOB.name) + val bankNode = makeNodeWithTracking(notaryNode.network.myAddress, BOC.name) + val issuer = bankNode.info.legalIdentity.ref(1, 2, 3) - ledger(aliceNode.services) { + ledger(aliceNode.services, initialiseSerialization = false) { // Insert a prospectus type attachment into the commercial paper transaction. val stream = ByteArrayOutputStream() @@ -413,17 +438,16 @@ class TwoPartyTradeFlowTests { } val bobsKey = bobNode.services.keyManagementService.keys.single() - val bobsFakeCash = fillUpForBuyer(false, AnonymousParty(bobsKey), - DUMMY_CASH_ISSUER.party, + val bobsFakeCash = fillUpForBuyer(false, issuer, AnonymousParty(bobsKey), notaryNode.info.notaryIdentity).second - insertFakeTransactions(bobsFakeCash, bobNode, notaryNode, DUMMY_CASH_ISSUER_KEY.public, MEGA_CORP_PUBKEY) + insertFakeTransactions(bobsFakeCash, bobNode, notaryNode, bankNode) val alicesFakePaper = aliceNode.database.transaction { - fillUpForSeller(false, aliceNode.info.legalIdentity, - 1200.DOLLARS `issued by` DUMMY_CASH_ISSUER, attachmentID, notaryNode.info.notaryIdentity).second + fillUpForSeller(false, issuer, aliceNode.info.legalIdentity, + 1200.DOLLARS `issued by` bankNode.info.legalIdentity.ref(0), attachmentID, notaryNode.info.notaryIdentity).second } - insertFakeTransactions(alicesFakePaper, aliceNode, notaryNode, MEGA_CORP_PUBKEY) + insertFakeTransactions(alicesFakePaper, aliceNode, notaryNode, bankNode) mockNet.runNetwork() // Clear network map registration messages @@ -469,22 +493,24 @@ class TwoPartyTradeFlowTests { @Test fun `dependency with error on buyer side`() { - ledger { - runWithError(true, false, "at least one asset input") + mockNet = MockNetwork(false) + ledger(initialiseSerialization = false) { + runWithError(true, false, "at least one cash input") } } @Test fun `dependency with error on seller side`() { - ledger { - runWithError(false, true, "Issuances must have a time-window") + mockNet = MockNetwork(false) + ledger(initialiseSerialization = false) { + runWithError(false, true, "Issuances have a time-window") } } private data class RunResult( // The buyer is not created immediately, only when the seller starts running - val buyer: Future>, - val sellerResult: Future, + val buyer: CordaFuture>, + val sellerResult: CordaFuture, val sellerId: StateMachineRunId ) @@ -492,11 +518,12 @@ class TwoPartyTradeFlowTests { sellerNode: MockNetwork.MockNode, buyerNode: MockNetwork.MockNode, assetToSell: StateAndRef): RunResult { - sellerNode.services.identityService.registerIdentity(buyerNode.info.legalIdentityAndCert) - buyerNode.services.identityService.registerIdentity(sellerNode.info.legalIdentityAndCert) + val anonymousSeller = sellerNode.services.let { serviceHub -> + serviceHub.keyManagementService.freshKeyAndCert(serviceHub.myInfo.legalIdentityAndCert, false) + }.party.anonymise() val buyerFlows: Observable = buyerNode.registerInitiatedFlow(BuyerAcceptor::class.java) val firstBuyerFiber = buyerFlows.toFuture().map { it.stateMachine } - val seller = SellerInitiator(buyerNode.info.legalIdentity, notaryNode.info, assetToSell, 1000.DOLLARS) + val seller = SellerInitiator(buyerNode.info.legalIdentity, notaryNode.info, assetToSell, 1000.DOLLARS, anonymousSeller) val sellerResult = sellerNode.services.startFlow(seller).resultFuture return RunResult(firstBuyerFiber, sellerResult, seller.stateMachine.id) } @@ -505,17 +532,17 @@ class TwoPartyTradeFlowTests { class SellerInitiator(val buyer: Party, val notary: NodeInfo, val assetToSell: StateAndRef, - val price: Amount) : FlowLogic() { + val price: Amount, + val me: AnonymousParty) : FlowLogic() { @Suspendable override fun call(): SignedTransaction { send(buyer, Pair(notary.notaryIdentity, price)) - val key = serviceHub.keyManagementService.freshKey() return subFlow(Seller( - buyer, - notary, - assetToSell, - price, - AnonymousParty(key))) + buyer, + notary, + assetToSell, + price, + me)) } } @@ -539,19 +566,20 @@ class TwoPartyTradeFlowTests { val notaryNode = mockNet.createNotaryNode(null, DUMMY_NOTARY.name) val aliceNode = mockNet.createPartyNode(notaryNode.network.myAddress, ALICE.name) val bobNode = mockNet.createPartyNode(notaryNode.network.myAddress, BOB.name) - val issuer = MEGA_CORP.ref(1, 2, 3) + val bankNode = mockNet.createPartyNode(notaryNode.network.myAddress, BOC.name) + val issuer = bankNode.info.legalIdentity.ref(1, 2, 3) val bobsBadCash = bobNode.database.transaction { - fillUpForBuyer(bobError, bobNode.info.legalIdentity, DUMMY_CASH_ISSUER.party, + fillUpForBuyer(bobError, issuer, bobNode.info.legalIdentity, notaryNode.info.notaryIdentity).second } val alicesFakePaper = aliceNode.database.transaction { - fillUpForSeller(aliceError, aliceNode.info.legalIdentity, + fillUpForSeller(aliceError, issuer, aliceNode.info.legalIdentity, 1200.DOLLARS `issued by` issuer, null, notaryNode.info.notaryIdentity).second } - insertFakeTransactions(bobsBadCash, bobNode, notaryNode, DUMMY_CASH_ISSUER_KEY.public, MEGA_CORP_PUBKEY) - insertFakeTransactions(alicesFakePaper, aliceNode, notaryNode, MEGA_CORP_PUBKEY) + insertFakeTransactions(bobsBadCash, bobNode, notaryNode, bankNode) + insertFakeTransactions(alicesFakePaper, aliceNode, notaryNode, bankNode) mockNet.runNetwork() // Clear network map registration messages @@ -575,25 +603,18 @@ class TwoPartyTradeFlowTests { private fun insertFakeTransactions( wtxToSign: List, node: AbstractNode, - notaryNode: MockNetwork.MockNode, - vararg extraKeys: PublicKey): Map { + notaryNode: AbstractNode, + vararg extraSigningNodes: AbstractNode): Map { val signed = wtxToSign.map { - val bits = it.serialize() val id = it.id - val sigs = mutableListOf() - sigs.add(node.services.keyManagementService.sign(id.bytes, node.services.legalIdentityKey)) - sigs.add(notaryNode.services.keyManagementService.sign(id.bytes, notaryNode.services.notaryIdentityKey)) - for (extraKey in extraKeys) { - if (extraKey == DUMMY_CASH_ISSUER_KEY.public) { - sigs.add(DUMMY_CASH_ISSUER_KEY.sign(id.bytes)) - } else if (extraKey == MEGA_CORP_PUBKEY) { - sigs.add(MEGA_CORP_KEY.sign(id.bytes)) - } else { - sigs.add(node.services.keyManagementService.sign(id.bytes, extraKey)) - } + val sigs = mutableListOf() + sigs.add(node.services.keyManagementService.sign(SignableData(id, SignatureMetadata(1, Crypto.findSignatureScheme(node.services.legalIdentityKey).schemeNumberID)), node.services.legalIdentityKey)) + sigs.add(notaryNode.services.keyManagementService.sign(SignableData(id, SignatureMetadata(1, Crypto.findSignatureScheme(notaryNode.services.notaryIdentityKey).schemeNumberID)), notaryNode.services.notaryIdentityKey)) + extraSigningNodes.forEach { currentNode -> + sigs.add(currentNode.services.keyManagementService.sign(SignableData(id, SignatureMetadata(1, Crypto.findSignatureScheme(currentNode.info.legalIdentity.owningKey).schemeNumberID)), currentNode.info.legalIdentity.owningKey)) } - SignedTransaction(bits, sigs) + SignedTransaction(it, sigs) } return node.database.transaction { node.services.recordTransactions(signed) @@ -607,10 +628,10 @@ class TwoPartyTradeFlowTests { private fun LedgerDSL.fillUpForBuyer( withError: Boolean, + issuer: PartyAndReference, owner: AbstractParty, - issuer: AbstractParty, notary: Party): Pair, List> { - val interimOwner = MEGA_CORP + val interimOwner = issuer.party // Bob (Buyer) has some cash he got from the Bank of Elbonia, Alice (Seller) has some commercial paper she // wants to sell to Bob. val eb1 = transaction(transactionBuilder = TransactionBuilder(notary = notary)) { @@ -618,10 +639,10 @@ class TwoPartyTradeFlowTests { output("elbonian money 1", notary = notary) { 800.DOLLARS.CASH `issued by` issuer `owned by` interimOwner } output("elbonian money 2", notary = notary) { 1000.DOLLARS.CASH `issued by` issuer `owned by` interimOwner } if (!withError) { - command(issuer.owningKey) { Cash.Commands.Issue() } + command(issuer.party.owningKey) { Cash.Commands.Issue() } } else { // Put a broken command on so at least a signature is created - command(issuer.owningKey) { Cash.Commands.Move() } + command(issuer.party.owningKey) { Cash.Commands.Move() } } timeWindow(TEST_TX_TIME) if (withError) { @@ -653,15 +674,16 @@ class TwoPartyTradeFlowTests { private fun LedgerDSL.fillUpForSeller( withError: Boolean, + issuer: PartyAndReference, owner: AbstractParty, amount: Amount>, attachmentID: SecureHash?, notary: Party): Pair, List> { val ap = transaction(transactionBuilder = TransactionBuilder(notary = notary)) { output("alice's paper", notary = notary) { - CommercialPaper.State(MEGA_CORP.ref(1, 2, 3), owner, amount, TEST_TX_TIME + 7.days) + CommercialPaper.State(issuer, owner, amount, TEST_TX_TIME + 7.days) } - command(MEGA_CORP_PUBKEY) { CommercialPaper.Commands.Issue() } + command(issuer.party.owningKey) { CommercialPaper.Commands.Issue() } if (!withError) timeWindow(time = TEST_TX_TIME) if (attachmentID != null) @@ -678,7 +700,7 @@ class TwoPartyTradeFlowTests { } - class RecordingTransactionStorage(val database: Database, val delegate: WritableTransactionStorage) : WritableTransactionStorage { + class RecordingTransactionStorage(val database: CordaPersistence, val delegate: WritableTransactionStorage) : WritableTransactionStorage { override fun track(): DataFeed, SignedTransaction> { return database.transaction { delegate.track() diff --git a/node/src/test/kotlin/net/corda/node/services/MockServiceHubInternal.kt b/node/src/test/kotlin/net/corda/node/services/MockServiceHubInternal.kt index 693abdc855..1994ed04c8 100644 --- a/node/src/test/kotlin/net/corda/node/services/MockServiceHubInternal.kt +++ b/node/src/test/kotlin/net/corda/node/services/MockServiceHubInternal.kt @@ -6,6 +6,7 @@ import net.corda.core.flows.FlowLogic import net.corda.core.node.NodeInfo import net.corda.core.node.services.* import net.corda.core.serialization.SerializeAsToken +import net.corda.core.utilities.NonEmptySet import net.corda.node.internal.InitiatedFlowFactory import net.corda.node.serialization.NodeClock import net.corda.node.services.api.* @@ -15,16 +16,19 @@ import net.corda.node.services.schema.NodeSchemaService import net.corda.node.services.statemachine.FlowStateMachineImpl import net.corda.node.services.statemachine.StateMachineManager import net.corda.node.services.transactions.InMemoryTransactionVerifierService +import net.corda.node.utilities.CordaPersistence +import net.corda.testing.DUMMY_IDENTITY_1 +import net.corda.testing.MOCK_HOST_AND_PORT import net.corda.testing.MOCK_IDENTITY_SERVICE import net.corda.testing.node.MockAttachmentStorage import net.corda.testing.node.MockNetworkMapCache import net.corda.testing.node.MockStateMachineRecordedTransactionMappingStorage import net.corda.testing.node.MockTransactionStorage -import org.jetbrains.exposed.sql.Database +import java.sql.Connection import java.time.Clock open class MockServiceHubInternal( - override val database: Database, + override val database: CordaPersistence, override val configuration: NodeConfiguration, val customVault: VaultService? = null, val customVaultQuery: VaultQueryService? = null, @@ -33,7 +37,6 @@ open class MockServiceHubInternal( val identity: IdentityService? = MOCK_IDENTITY_SERVICE, override val attachments: AttachmentStorage = MockAttachmentStorage(), override val validatedTransactions: WritableTransactionStorage = MockTransactionStorage(), - override val uploaders: List = listOf(), override val stateMachineRecordedTransactionMapping: StateMachineRecordedTransactionMappingStorage = MockStateMachineRecordedTransactionMappingStorage(), val mapCache: NetworkMapCacheInternal? = null, val scheduler: SchedulerService? = null, @@ -60,7 +63,7 @@ open class MockServiceHubInternal( override val clock: Clock get() = overrideClock ?: throw UnsupportedOperationException() override val myInfo: NodeInfo - get() = throw UnsupportedOperationException() + get() = NodeInfo(listOf(MOCK_HOST_AND_PORT), DUMMY_IDENTITY_1, NonEmptySet.of(DUMMY_IDENTITY_1), 1) // Required to get a dummy platformVersion when required for tests. override val monitoringService: MonitoringService = MonitoringService(MetricRegistry()) override val rpcFlows: List>> get() = throw UnsupportedOperationException() @@ -77,4 +80,6 @@ open class MockServiceHubInternal( } override fun getFlowFactory(initiatingFlowClass: Class>): InitiatedFlowFactory<*>? = null + + override fun jdbcSession(): Connection = database.createSession() } diff --git a/node/src/test/kotlin/net/corda/node/services/NotaryChangeTests.kt b/node/src/test/kotlin/net/corda/node/services/NotaryChangeTests.kt index 8aef7f445e..fd5fff4e00 100644 --- a/node/src/test/kotlin/net/corda/node/services/NotaryChangeTests.kt +++ b/node/src/test/kotlin/net/corda/node/services/NotaryChangeTests.kt @@ -1,19 +1,20 @@ package net.corda.node.services import net.corda.core.contracts.* -import net.corda.testing.contracts.DummyContract import net.corda.core.crypto.generateKeyPair -import net.corda.core.getOrThrow +import net.corda.core.flows.NotaryChangeFlow +import net.corda.core.flows.StateReplacementException import net.corda.core.identity.Party import net.corda.core.node.services.ServiceInfo -import net.corda.core.seconds +import net.corda.core.transactions.TransactionBuilder import net.corda.core.transactions.WireTransaction -import net.corda.flows.NotaryChangeFlow -import net.corda.flows.StateReplacementException +import net.corda.core.utilities.getOrThrow +import net.corda.core.utilities.seconds import net.corda.node.internal.AbstractNode import net.corda.node.services.network.NetworkMapService import net.corda.node.services.transactions.SimpleNotaryService import net.corda.testing.DUMMY_NOTARY +import net.corda.testing.contracts.DummyContract import net.corda.testing.getTestPartyAndCertificate import net.corda.testing.node.MockNetwork import org.assertj.core.api.Assertions.assertThatExceptionOfType @@ -34,7 +35,7 @@ class NotaryChangeTests { lateinit var clientNodeB: MockNetwork.MockNode @Before - fun setup() { + fun setUp() { mockNet = MockNetwork() oldNotaryNode = mockNet.createNode( legalName = DUMMY_NOTARY.name, @@ -106,11 +107,12 @@ class NotaryChangeTests { val newState = future.resultFuture.getOrThrow() assertEquals(newState.state.notary, newNotary) - val notaryChangeTx = clientNodeA.services.validatedTransactions.getTransaction(newState.ref.txhash)!!.tx + val recordedTx = clientNodeA.services.validatedTransactions.getTransaction(newState.ref.txhash)!! + val notaryChangeTx = recordedTx.resolveNotaryChangeTransaction(clientNodeA.services) // Check that all encumbrances have been propagated to the outputs - val originalOutputs = issueTx.outputs.map { it.data } - val newOutputs = notaryChangeTx.outputs.map { it.data } + val originalOutputs = issueTx.outputStates + val newOutputs = notaryChangeTx.outputStates assertTrue(originalOutputs.minus(newOutputs).isEmpty()) // Check that encumbrance links aren't broken after notary change @@ -135,7 +137,7 @@ class NotaryChangeTests { val stateB = DummyContract.SingleOwnerState(Random().nextInt(), owner.party) val stateC = DummyContract.SingleOwnerState(Random().nextInt(), owner.party) - val tx = TransactionType.General.Builder(null).apply { + val tx = TransactionBuilder(null).apply { addCommand(Command(DummyContract.Commands.Create(), owner.party.owningKey)) addOutputState(stateA, notary, encumbrance = 2) // Encumbered by stateB addOutputState(stateC, notary) @@ -166,7 +168,7 @@ fun issueState(node: AbstractNode, notaryNode: AbstractNode): StateAndRef<*> { fun issueMultiPartyState(nodeA: AbstractNode, nodeB: AbstractNode, notaryNode: AbstractNode): StateAndRef { val state = TransactionState(DummyContract.MultiOwnerState(0, listOf(nodeA.info.legalIdentity, nodeB.info.legalIdentity)), notaryNode.info.notaryIdentity) - val tx = TransactionType.NotaryChange.Builder(notaryNode.info.notaryIdentity).withItems(state) + val tx = TransactionBuilder(notary = notaryNode.info.notaryIdentity).withItems(state) val signedByA = nodeA.services.signInitialTransaction(tx) val signedByAB = nodeB.services.addSignature(signedByA) val stx = notaryNode.services.addSignature(signedByAB, notaryNode.services.notaryIdentityKey) diff --git a/node/src/test/kotlin/net/corda/node/services/config/FullNodeConfigurationTest.kt b/node/src/test/kotlin/net/corda/node/services/config/FullNodeConfigurationTest.kt index fcebec4411..03adc6102d 100644 --- a/node/src/test/kotlin/net/corda/node/services/config/FullNodeConfigurationTest.kt +++ b/node/src/test/kotlin/net/corda/node/services/config/FullNodeConfigurationTest.kt @@ -5,6 +5,7 @@ import net.corda.core.utilities.NetworkHostAndPort import net.corda.testing.ALICE import net.corda.nodeapi.User import net.corda.testing.node.makeTestDataSourceProperties +import net.corda.testing.node.makeTestDatabaseProperties import org.assertj.core.api.Assertions.assertThatThrownBy import org.junit.Test import java.net.URL @@ -21,6 +22,7 @@ class FullNodeConfigurationTest { keyStorePassword = "cordacadevpass", trustStorePassword = "trustpass", dataSourceProperties = makeTestDataSourceProperties(ALICE.name.commonName), + database = makeTestDatabaseProperties(), certificateSigningService = URL("http://localhost"), rpcUsers = emptyList(), verifierType = VerifierType.InMemory, @@ -29,7 +31,7 @@ class FullNodeConfigurationTest { rpcAddress = NetworkHostAndPort("localhost", 1), messagingServerAddress = null, extraAdvertisedServiceIds = emptyList(), - bftReplicaId = null, + bftSMaRt = BFTSMaRtConfiguration(-1, false), notaryNodeAddress = null, notaryClusterAddresses = emptyList(), certificateChainCheckPolicies = emptyList(), diff --git a/node/src/test/kotlin/net/corda/node/services/database/HibernateConfigurationTest.kt b/node/src/test/kotlin/net/corda/node/services/database/HibernateConfigurationTest.kt index cc8c9a98b1..923457f44b 100644 --- a/node/src/test/kotlin/net/corda/node/services/database/HibernateConfigurationTest.kt +++ b/node/src/test/kotlin/net/corda/node/services/database/HibernateConfigurationTest.kt @@ -1,60 +1,53 @@ package net.corda.node.services.database -import net.corda.contracts.asset.Cash -import net.corda.contracts.asset.DUMMY_CASH_ISSUER -import net.corda.contracts.asset.DummyFungibleContract +import net.corda.contracts.asset.* +import net.corda.core.contracts.* +import net.corda.core.crypto.SecureHash +import net.corda.core.crypto.toBase58String +import net.corda.core.node.services.Vault +import net.corda.core.node.services.VaultService +import net.corda.core.schemas.CommonSchemaV1 +import net.corda.core.schemas.PersistentStateRef +import net.corda.core.serialization.deserialize +import net.corda.core.transactions.SignedTransaction +import net.corda.node.services.schema.HibernateObserver +import net.corda.node.services.schema.NodeSchemaService +import net.corda.node.services.vault.VaultSchemaV1 +import net.corda.node.utilities.CordaPersistence +import net.corda.node.utilities.configureDatabase +import net.corda.schemas.CashSchemaV1 +import net.corda.schemas.SampleCashSchemaV2 +import net.corda.schemas.SampleCashSchemaV3 +import net.corda.testing.* import net.corda.testing.contracts.consumeCash import net.corda.testing.contracts.fillWithSomeTestCash import net.corda.testing.contracts.fillWithSomeTestDeals import net.corda.testing.contracts.fillWithSomeTestLinearStates -import net.corda.core.contracts.* -import net.corda.core.crypto.toBase58String -import net.corda.core.node.services.Vault -import net.corda.core.node.services.VaultService -import net.corda.core.schemas.PersistentStateRef -import net.corda.testing.schemas.DummyLinearStateSchemaV1 -import net.corda.testing.schemas.DummyLinearStateSchemaV2 -import net.corda.core.serialization.deserialize -import net.corda.core.serialization.storageKryo -import net.corda.core.transactions.SignedTransaction -import net.corda.testing.ALICE -import net.corda.testing.BOB -import net.corda.testing.BOB_KEY -import net.corda.testing.DUMMY_NOTARY -import net.corda.node.services.schema.HibernateObserver -import net.corda.node.services.schema.NodeSchemaService -import net.corda.node.services.vault.NodeVaultService -import net.corda.core.schemas.CommonSchemaV1 -import net.corda.node.services.vault.VaultSchemaV1 -import net.corda.node.utilities.configureDatabase -import net.corda.node.utilities.transaction -import net.corda.schemas.CashSchemaV1 -import net.corda.schemas.SampleCashSchemaV2 -import net.corda.schemas.SampleCashSchemaV3 -import net.corda.testing.BOB_PUBKEY -import net.corda.testing.BOC -import net.corda.testing.BOC_KEY import net.corda.testing.node.MockServices import net.corda.testing.node.makeTestDataSourceProperties +import net.corda.testing.node.makeTestDatabaseProperties +import net.corda.testing.node.makeTestIdentityService +import net.corda.testing.schemas.DummyLinearStateSchemaV1 +import net.corda.testing.schemas.DummyLinearStateSchemaV2 import org.assertj.core.api.Assertions import org.assertj.core.api.Assertions.assertThat import org.hibernate.SessionFactory -import org.jetbrains.exposed.sql.Database import org.junit.After +import org.junit.Assert import org.junit.Before import org.junit.Test -import java.io.Closeable +import java.math.BigDecimal import java.time.Instant import java.util.* import javax.persistence.EntityManager import javax.persistence.Tuple import javax.persistence.criteria.CriteriaBuilder -class HibernateConfigurationTest { +class HibernateConfigurationTest : TestDependencyInjectionBase() { lateinit var services: MockServices - lateinit var dataSource: Closeable - lateinit var database: Database + lateinit var issuerServices: MockServices + lateinit var database: CordaPersistence val vault: VaultService get() = services.vaultService // Hibernate configuration objects @@ -69,22 +62,15 @@ class HibernateConfigurationTest { @Before fun setUp() { + issuerServices = MockServices(DUMMY_CASH_ISSUER_KEY, BOB_KEY, BOC_KEY) val dataSourceProps = makeTestDataSourceProperties() - val dataSourceAndDatabase = configureDatabase(dataSourceProps) + val defaultDatabaseProperties = makeTestDatabaseProperties() + database = configureDatabase(dataSourceProps, defaultDatabaseProperties, identitySvc = ::makeTestIdentityService) val customSchemas = setOf(VaultSchemaV1, CashSchemaV1, SampleCashSchemaV2, SampleCashSchemaV3) - - dataSource = dataSourceAndDatabase.first - database = dataSourceAndDatabase.second database.transaction { - - hibernateConfig = HibernateConfiguration(NodeSchemaService(customSchemas)) - - services = object : MockServices(BOB_KEY) { - override val vaultService: VaultService get() { - val vaultService = NodeVaultService(this, dataSourceProps) - hibernatePersister = HibernateObserver(vaultService.rawUpdates, hibernateConfig) - return vaultService - } + hibernateConfig = HibernateConfiguration(NodeSchemaService(customSchemas), makeTestDatabaseProperties(), ::makeTestIdentityService) + services = object : MockServices(BOB_KEY, BOC_KEY, DUMMY_NOTARY_KEY) { + override val vaultService: VaultService = makeVaultService(dataSourceProps, hibernateConfig) override fun recordTransactions(txs: Iterable) { for (stx in txs) { @@ -93,7 +79,9 @@ class HibernateConfigurationTest { // Refactored to use notifyAll() as we have no other unit test for that method with multiple transactions. vaultService.notifyAll(txs.map { it.tx }) } + override fun jdbcSession() = database.createSession() } + hibernatePersister = services.hibernatePersister } setUpDb() @@ -104,12 +92,12 @@ class HibernateConfigurationTest { @After fun cleanUp() { - dataSource.close() + database.close() } private fun setUpDb() { database.transaction { - cashStates = services.fillWithSomeTestCash(100.DOLLARS, DUMMY_NOTARY, 10, 10, Random(0L)).states.toList() + cashStates = services.fillWithSomeTestCash(100.DOLLARS, issuerServices, DUMMY_NOTARY, 10, 10, Random(0L)).states.toList() } } @@ -128,7 +116,7 @@ class HibernateConfigurationTest { @Test fun `consumed states`() { database.transaction { - services.consumeCash(50.DOLLARS) + services.consumeCash(50.DOLLARS, notary = DUMMY_NOTARY) } // structure query @@ -139,7 +127,8 @@ class HibernateConfigurationTest { // execute query val queryResults = entityManager.createQuery(criteriaQuery).resultList - assertThat(queryResults.size).isEqualTo(6) + val coins = queryResults.map { it.contractState.deserialize>().data }.sumCash() + assertThat(coins.toDecimal() >= BigDecimal("50.00")) } @Test @@ -206,11 +195,11 @@ class HibernateConfigurationTest { fun `with sorting by state ref desc and asc`() { // generate additional state ref indexes database.transaction { - services.consumeCash(1.DOLLARS) - services.consumeCash(2.DOLLARS) - services.consumeCash(3.DOLLARS) - services.consumeCash(4.DOLLARS) - services.consumeCash(5.DOLLARS) + services.consumeCash(1.DOLLARS, notary = DUMMY_NOTARY) + services.consumeCash(2.DOLLARS, notary = DUMMY_NOTARY) + services.consumeCash(3.DOLLARS, notary = DUMMY_NOTARY) + services.consumeCash(4.DOLLARS, notary = DUMMY_NOTARY) + services.consumeCash(5.DOLLARS, notary = DUMMY_NOTARY) } // structure query @@ -236,11 +225,11 @@ class HibernateConfigurationTest { fun `with sorting by state ref index and txId desc and asc`() { // generate additional state ref indexes database.transaction { - services.consumeCash(1.DOLLARS) - services.consumeCash(2.DOLLARS) - services.consumeCash(3.DOLLARS) - services.consumeCash(4.DOLLARS) - services.consumeCash(5.DOLLARS) + services.consumeCash(1.DOLLARS, notary = DUMMY_NOTARY) + services.consumeCash(2.DOLLARS, notary = DUMMY_NOTARY) + services.consumeCash(3.DOLLARS, notary = DUMMY_NOTARY) + services.consumeCash(4.DOLLARS, notary = DUMMY_NOTARY) + services.consumeCash(5.DOLLARS, notary = DUMMY_NOTARY) } // structure query @@ -267,7 +256,7 @@ class HibernateConfigurationTest { fun `with pagination`() { // add 100 additional cash entries database.transaction { - services.fillWithSomeTestCash(1000.POUNDS, DUMMY_NOTARY, 100, 100, Random(0L)) + services.fillWithSomeTestCash(1000.POUNDS, issuerServices, DUMMY_NOTARY, 100, 100, Random(0L), issuedBy = DUMMY_CASH_ISSUER) } // structure query @@ -369,11 +358,11 @@ class HibernateConfigurationTest { fun `calculate cash balances`() { database.transaction { - services.fillWithSomeTestCash(100.DOLLARS, DUMMY_NOTARY, 10, 10, Random(0L)) // +$100 = $200 - services.fillWithSomeTestCash(50.POUNDS, DUMMY_NOTARY, 5, 5, Random(0L)) // £50 = £50 - services.fillWithSomeTestCash(25.POUNDS, DUMMY_NOTARY, 5, 5, Random(0L)) // +£25 = £175 - services.fillWithSomeTestCash(500.SWISS_FRANCS, DUMMY_NOTARY, 10, 10, Random(0L)) // CHF500 = CHF500 - services.fillWithSomeTestCash(250.SWISS_FRANCS, DUMMY_NOTARY, 5, 5, Random(0L)) // +CHF250 = CHF750 + services.fillWithSomeTestCash(100.DOLLARS, issuerServices, DUMMY_NOTARY, 10, 10, Random(0L)) // +$100 = $200 + services.fillWithSomeTestCash(50.POUNDS, issuerServices, DUMMY_NOTARY, 5, 5, Random(0L)) // £50 = £50 + services.fillWithSomeTestCash(25.POUNDS, issuerServices, DUMMY_NOTARY, 5, 5, Random(0L)) // +£25 = £175 + services.fillWithSomeTestCash(500.SWISS_FRANCS, issuerServices, DUMMY_NOTARY, 10, 10, Random(0L)) // CHF500 = CHF500 + services.fillWithSomeTestCash(250.SWISS_FRANCS, issuerServices, DUMMY_NOTARY, 5, 5, Random(0L)) // +CHF250 = CHF750 } // structure query @@ -402,8 +391,8 @@ class HibernateConfigurationTest { @Test fun `calculate cash balance for single currency`() { database.transaction { - services.fillWithSomeTestCash(50.POUNDS, DUMMY_NOTARY, 5, 5, Random(0L)) // £50 = £50 - services.fillWithSomeTestCash(25.POUNDS, DUMMY_NOTARY, 5, 5, Random(0L)) // +£25 = £175 + services.fillWithSomeTestCash(50.POUNDS, issuerServices, DUMMY_NOTARY, 5, 5, Random(0L)) // £50 = £50 + services.fillWithSomeTestCash(25.POUNDS, issuerServices, DUMMY_NOTARY, 5, 5, Random(0L)) // +£25 = £175 } // structure query @@ -433,9 +422,9 @@ class HibernateConfigurationTest { fun `calculate and order by cash balance for owner and currency`() { database.transaction { - services.fillWithSomeTestCash(200.DOLLARS, DUMMY_NOTARY, 2, 2, Random(0L), issuedBy = BOC.ref(1), issuerKey = BOC_KEY) - services.fillWithSomeTestCash(300.POUNDS, DUMMY_NOTARY, 3, 3, Random(0L), issuedBy = DUMMY_CASH_ISSUER) - services.fillWithSomeTestCash(400.POUNDS, DUMMY_NOTARY, 4, 4, Random(0L), issuedBy = BOC.ref(2), issuerKey = BOC_KEY) + services.fillWithSomeTestCash(200.DOLLARS, issuerServices, DUMMY_NOTARY, 2, 2, Random(0L), issuedBy = BOC.ref(1)) + services.fillWithSomeTestCash(300.POUNDS, issuerServices, DUMMY_NOTARY, 3, 3, Random(0L), issuedBy = DUMMY_CASH_ISSUER) + services.fillWithSomeTestCash(400.POUNDS, issuerServices, DUMMY_NOTARY, 4, 4, Random(0L), issuedBy = BOC.ref(2)) } // structure query @@ -624,9 +613,9 @@ class HibernateConfigurationTest { hibernatePersister.persistStateWithSchema(dummyFungibleState, it.ref, SampleCashSchemaV3) } - services.fillWithSomeTestCash(100.DOLLARS, DUMMY_NOTARY, 2, 2, Random(0L), ownedBy = ALICE) - val cashStates = services.fillWithSomeTestCash(100.DOLLARS, DUMMY_NOTARY, 2, 2, Random(0L), - issuedBy = BOB.ref(0), issuerKey = BOB_KEY, ownedBy = (BOB)).states + services.fillWithSomeTestCash(100.DOLLARS, issuerServices, DUMMY_NOTARY, 2, 2, Random(0L), ownedBy = ALICE) + val cashStates = services.fillWithSomeTestCash(100.DOLLARS, issuerServices, DUMMY_NOTARY, 2, 2, Random(0L), + issuedBy = BOB.ref(0), ownedBy = (BOB)).states // persist additional cash states explicitly with V3 schema cashStates.forEach { val cashState = it.state.data @@ -660,7 +649,7 @@ class HibernateConfigurationTest { val queryResults = entityManager.createQuery(criteriaQuery).resultList queryResults.forEach { - val contractState = it.contractState.deserialize>(storageKryo()) + val contractState = it.contractState.deserialize>() val cashState = contractState.data as Cash.State println("${it.stateRef} with owner: ${cashState.owner.owningKey.toBase58String()}") } @@ -705,8 +694,8 @@ class HibernateConfigurationTest { hibernatePersister.persistStateWithSchema(dummyFungibleState, it.ref, SampleCashSchemaV3) } - val moreCash = services.fillWithSomeTestCash(100.DOLLARS, DUMMY_NOTARY, 2, 2, Random(0L), - issuedBy = BOB.ref(0), issuerKey = BOB_KEY, ownedBy = BOB).states + val moreCash = services.fillWithSomeTestCash(100.DOLLARS, issuerServices, DUMMY_NOTARY, 2, 2, Random(0L), + issuedBy = BOB.ref(0), ownedBy = BOB).states // persist additional cash states explicitly with V3 schema moreCash.forEach { val cashState = it.state.data @@ -714,7 +703,7 @@ class HibernateConfigurationTest { hibernatePersister.persistStateWithSchema(dummyFungibleState, it.ref, SampleCashSchemaV3) } - val cashStates = services.fillWithSomeTestCash(100.DOLLARS, DUMMY_NOTARY, 2, 2, Random(0L), ownedBy = (ALICE)).states + val cashStates = services.fillWithSomeTestCash(100.DOLLARS, issuerServices, DUMMY_NOTARY, 2, 2, Random(0L), ownedBy = (ALICE)).states // persist additional cash states explicitly with V3 schema cashStates.forEach { val cashState = it.state.data @@ -744,7 +733,7 @@ class HibernateConfigurationTest { // execute query val queryResults = entityManager.createQuery(criteriaQuery).resultList queryResults.forEach { - val contractState = it.contractState.deserialize>(storageKryo()) + val contractState = it.contractState.deserialize>() val cashState = contractState.data as Cash.State println("${it.stateRef} with owner ${cashState.owner.owningKey.toBase58String()} and participants ${cashState.participants.map { it.owningKey.toBase58String() }}") } @@ -862,4 +851,26 @@ class HibernateConfigurationTest { assertThat(queryResults).hasSize(6) } + /** + * Test invoking SQL query using JDBC connection (session) + */ + @Test + fun `test calling an arbitrary JDBC native query`() { + // DOCSTART JdbcSession + val nativeQuery = "SELECT v.transaction_id, v.output_index FROM vault_states v WHERE v.state_status = 0" + + database.transaction { + val jdbcSession = database.createSession() + val prepStatement = jdbcSession.prepareStatement(nativeQuery) + val rs = prepStatement.executeQuery() + // DOCEND JdbcSession + var count = 0 + while (rs.next()) { + val stateRef = StateRef(SecureHash.parse(rs.getString(1)), rs.getInt(2)) + Assert.assertTrue(cashStates.map { it.ref }.contains(stateRef)) + count++ + } + Assert.assertEquals(cashStates.count(), count) + } + } } \ No newline at end of file diff --git a/node/src/test/kotlin/net/corda/node/services/database/RequeryConfigurationTest.kt b/node/src/test/kotlin/net/corda/node/services/database/RequeryConfigurationTest.kt index ec8f6a57d2..db05d880db 100644 --- a/node/src/test/kotlin/net/corda/node/services/database/RequeryConfigurationTest.kt +++ b/node/src/test/kotlin/net/corda/node/services/database/RequeryConfigurationTest.kt @@ -4,59 +4,50 @@ import io.requery.Persistable import io.requery.kotlin.eq import io.requery.sql.KotlinEntityDataStore import net.corda.core.contracts.StateRef -import net.corda.core.contracts.TransactionType -import net.corda.testing.contracts.DummyContract -import net.corda.core.crypto.DigitalSignature -import net.corda.core.crypto.SecureHash -import net.corda.core.crypto.testing.NullPublicKey -import net.corda.core.crypto.toBase58String +import net.corda.core.crypto.* import net.corda.core.identity.AnonymousParty import net.corda.core.node.services.Vault import net.corda.core.serialization.serialize -import net.corda.core.serialization.storageKryo import net.corda.core.transactions.SignedTransaction import net.corda.core.transactions.WireTransaction -import net.corda.testing.DUMMY_NOTARY -import net.corda.testing.DUMMY_PUBKEY_1 import net.corda.node.services.persistence.DBTransactionStorage import net.corda.node.services.vault.schemas.requery.Models import net.corda.node.services.vault.schemas.requery.VaultCashBalancesEntity import net.corda.node.services.vault.schemas.requery.VaultSchema import net.corda.node.services.vault.schemas.requery.VaultStatesEntity +import net.corda.node.utilities.CordaPersistence import net.corda.node.utilities.configureDatabase -import net.corda.node.utilities.transaction +import net.corda.testing.* +import net.corda.testing.contracts.DummyContract import net.corda.testing.node.makeTestDataSourceProperties +import net.corda.testing.node.makeTestDatabaseProperties +import net.corda.testing.node.makeTestIdentityService import org.assertj.core.api.Assertions -import org.jetbrains.exposed.sql.Database import org.junit.After import org.junit.Assert.assertEquals import org.junit.Assert.assertTrue import org.junit.Before import org.junit.Test -import java.io.Closeable import java.time.Instant import java.util.* -class RequeryConfigurationTest { +class RequeryConfigurationTest : TestDependencyInjectionBase() { - lateinit var dataSource: Closeable - lateinit var database: Database + lateinit var database: CordaPersistence lateinit var transactionStorage: DBTransactionStorage lateinit var requerySession: KotlinEntityDataStore @Before fun setUp() { val dataSourceProperties = makeTestDataSourceProperties() - val dataSourceAndDatabase = configureDatabase(dataSourceProperties) - dataSource = dataSourceAndDatabase.first - database = dataSourceAndDatabase.second + database = configureDatabase(dataSourceProperties, makeTestDatabaseProperties(), identitySvc = ::makeTestIdentityService) newTransactionStorage() newRequeryStorage(dataSourceProperties) } @After fun cleanUp() { - dataSource.close() + database.close() } @Test @@ -163,7 +154,7 @@ class RequeryConfigurationTest { val nativeQuery = "SELECT v.transaction_id, v.output_index FROM vault_states v WHERE v.state_status = 0" database.transaction { - val configuration = RequeryConfiguration(dataSourceProperties, true) + val configuration = RequeryConfiguration(dataSourceProperties, true, makeTestDatabaseProperties()) val jdbcSession = configuration.jdbcSession() val prepStatement = jdbcSession.prepareStatement(nativeQuery) val rs = prepStatement.executeQuery() @@ -180,7 +171,7 @@ class RequeryConfigurationTest { index = txnState.index stateStatus = Vault.StateStatus.UNCONSUMED contractStateClassName = DummyContract.SingleOwnerState::class.java.name - contractState = DummyContract.SingleOwnerState(owner = AnonymousParty(DUMMY_PUBKEY_1)).serialize(storageKryo()).bytes + contractState = DummyContract.SingleOwnerState(owner = AnonymousParty(MEGA_CORP_PUBKEY)).serialize().bytes notaryName = txn.tx.notary!!.name.toString() notaryKey = txn.tx.notary!!.owningKey.toBase58String() recordedTime = Instant.now() @@ -203,7 +194,7 @@ class RequeryConfigurationTest { private fun newRequeryStorage(dataSourceProperties: Properties) { database.transaction { - val configuration = RequeryConfiguration(dataSourceProperties, true) + val configuration = RequeryConfiguration(dataSourceProperties, true, makeTestDatabaseProperties()) requerySession = configuration.sessionForModel(Models.VAULT) } } @@ -215,10 +206,8 @@ class RequeryConfigurationTest { outputs = emptyList(), commands = emptyList(), notary = DUMMY_NOTARY, - signers = emptyList(), - type = TransactionType.General, timeWindow = null ) - return SignedTransaction(wtx.serialized, listOf(DigitalSignature.WithKey(NullPublicKey, ByteArray(1)))) + return SignedTransaction(wtx, listOf(TransactionSignature(ByteArray(1), ALICE_PUBKEY, SignatureMetadata(1, Crypto.findSignatureScheme(ALICE_PUBKEY).schemeNumberID)))) } } \ No newline at end of file diff --git a/node/src/test/kotlin/net/corda/node/services/events/NodeSchedulerServiceTest.kt b/node/src/test/kotlin/net/corda/node/services/events/NodeSchedulerServiceTest.kt index 46ef15d5c8..2480ad4cab 100644 --- a/node/src/test/kotlin/net/corda/node/services/events/NodeSchedulerServiceTest.kt +++ b/node/src/test/kotlin/net/corda/node/services/events/NodeSchedulerServiceTest.kt @@ -1,7 +1,7 @@ package net.corda.node.services.events import net.corda.core.contracts.* -import net.corda.core.days +import net.corda.core.utilities.days import net.corda.core.flows.FlowLogic import net.corda.core.flows.FlowLogicRef import net.corda.core.flows.FlowLogicRefFactory @@ -9,9 +9,7 @@ import net.corda.core.identity.AbstractParty import net.corda.core.node.ServiceHub import net.corda.core.node.services.VaultService import net.corda.core.serialization.SingletonSerializeAsToken -import net.corda.testing.ALICE_KEY -import net.corda.testing.DUMMY_CA -import net.corda.testing.DUMMY_NOTARY +import net.corda.core.transactions.TransactionBuilder import net.corda.node.services.MockServiceHubInternal import net.corda.node.services.identity.InMemoryIdentityService import net.corda.node.services.persistence.DBCheckpointStorage @@ -19,21 +17,18 @@ import net.corda.node.services.statemachine.FlowLogicRefFactoryImpl import net.corda.node.services.statemachine.StateMachineManager import net.corda.node.services.vault.NodeVaultService import net.corda.node.utilities.AffinityExecutor +import net.corda.node.utilities.CordaPersistence import net.corda.node.utilities.configureDatabase -import net.corda.node.utilities.transaction -import net.corda.testing.getTestX509Name +import net.corda.testing.* import net.corda.testing.node.InMemoryMessagingNetwork import net.corda.testing.node.MockKeyManagementService +import net.corda.testing.node.* import net.corda.testing.node.TestClock -import net.corda.testing.node.makeTestDataSourceProperties -import net.corda.testing.testNodeConfiguration import org.assertj.core.api.Assertions.assertThat import org.bouncycastle.asn1.x500.X500Name -import org.jetbrains.exposed.sql.Database import org.junit.After import org.junit.Before import org.junit.Test -import java.io.Closeable import java.nio.file.Paths import java.security.PublicKey import java.time.Clock @@ -54,8 +49,7 @@ class NodeSchedulerServiceTest : SingletonSerializeAsToken() { lateinit var scheduler: NodeSchedulerService lateinit var smmExecutor: AffinityExecutor.ServiceAffinityExecutor - lateinit var dataSource: Closeable - lateinit var database: Database + lateinit var database: CordaPersistence lateinit var countDown: CountDownLatch lateinit var smmHasRemovedAllFlows: CountDownLatch @@ -72,13 +66,12 @@ class NodeSchedulerServiceTest : SingletonSerializeAsToken() { @Before fun setup() { + initialiseTestSerialization() countDown = CountDownLatch(1) smmHasRemovedAllFlows = CountDownLatch(1) calls = 0 val dataSourceProps = makeTestDataSourceProperties() - val dataSourceAndDatabase = configureDatabase(dataSourceProps) - dataSource = dataSourceAndDatabase.first - database = dataSourceAndDatabase.second + database = configureDatabase(dataSourceProps, makeTestDatabaseProperties(), identitySvc = ::makeTestIdentityService) val identityService = InMemoryIdentityService(trustRoot = DUMMY_CA.certificate) val kms = MockKeyManagementService(identityService, ALICE_KEY) @@ -95,11 +88,11 @@ class NodeSchedulerServiceTest : SingletonSerializeAsToken() { overrideClock = testClock, keyManagement = kms, network = mockMessagingService), TestReference { - override val vaultService: VaultService = NodeVaultService(this, dataSourceProps) + override val vaultService: VaultService = NodeVaultService(this, dataSourceProps, makeTestDatabaseProperties()) override val testReference = this@NodeSchedulerServiceTest } - scheduler = NodeSchedulerService(services, schedulerGatedExecutor) smmExecutor = AffinityExecutor.ServiceAffinityExecutor("test", 1) + scheduler = NodeSchedulerService(services, schedulerGatedExecutor, serverThread = smmExecutor) val mockSMM = StateMachineManager(services, DBCheckpointStorage(), smmExecutor, database) mockSMM.changes.subscribe { change -> if (change is StateMachineManager.Change.Removed && mockSMM.allStateMachines.isEmpty()) { @@ -120,7 +113,8 @@ class NodeSchedulerServiceTest : SingletonSerializeAsToken() { } smmExecutor.shutdown() smmExecutor.awaitTermination(60, TimeUnit.SECONDS) - dataSource.close() + database.close() + resetTestSerialization() } class TestState(val flowLogicRef: FlowLogicRef, val instant: Instant) : LinearState, SchedulableState { @@ -283,7 +277,7 @@ class NodeSchedulerServiceTest : SingletonSerializeAsToken() { apply { val freshKey = services.keyManagementService.freshKey() val state = TestState(FlowLogicRefFactoryImpl.createForRPC(TestFlowLogic::class.java, increment), instant) - val builder = TransactionType.General.Builder(null).apply { + val builder = TransactionBuilder(null).apply { addOutputState(state, DUMMY_NOTARY) addCommand(Command(), freshKey) } diff --git a/node/src/test/kotlin/net/corda/node/services/events/ScheduledFlowTests.kt b/node/src/test/kotlin/net/corda/node/services/events/ScheduledFlowTests.kt index 6fec905bae..97ff1b7182 100644 --- a/node/src/test/kotlin/net/corda/node/services/events/ScheduledFlowTests.kt +++ b/node/src/test/kotlin/net/corda/node/services/events/ScheduledFlowTests.kt @@ -1,23 +1,27 @@ package net.corda.node.services.events import co.paralleluniverse.fibers.Suspendable +import net.corda.core.concurrent.CordaFuture import net.corda.core.contracts.* -import net.corda.testing.contracts.DummyContract import net.corda.core.crypto.containsAny -import net.corda.core.flows.FlowInitiator -import net.corda.core.flows.FlowLogic -import net.corda.core.flows.FlowLogicRefFactory -import net.corda.core.flows.SchedulableFlow +import net.corda.core.flows.* import net.corda.core.identity.AbstractParty import net.corda.core.identity.Party import net.corda.core.node.services.ServiceInfo -import net.corda.core.node.services.linearHeadsOfType -import net.corda.testing.DUMMY_NOTARY -import net.corda.flows.FinalityFlow +import net.corda.core.node.services.VaultQueryService +import net.corda.core.node.services.queryBy +import net.corda.core.node.services.vault.DEFAULT_PAGE_NUM +import net.corda.core.node.services.vault.PageSpecification +import net.corda.core.node.services.vault.QueryCriteria.VaultQueryCriteria +import net.corda.core.node.services.vault.Sort +import net.corda.core.node.services.vault.SortAttribute +import net.corda.core.transactions.TransactionBuilder +import net.corda.core.utilities.getOrThrow import net.corda.node.services.network.NetworkMapService import net.corda.node.services.statemachine.StateMachineManager import net.corda.node.services.transactions.ValidatingNotaryService -import net.corda.node.utilities.transaction +import net.corda.testing.DUMMY_NOTARY +import net.corda.testing.contracts.DummyContract import net.corda.testing.node.MockNetwork import org.junit.After import org.junit.Assert.assertTrue @@ -28,6 +32,10 @@ import java.time.Instant import kotlin.test.assertEquals class ScheduledFlowTests { + companion object { + val PAGE_SIZE = 20 + val SORTING = Sort(listOf(Sort.SortColumn(SortAttribute.Standard(Sort.CommonStateAttribute.STATE_REF_TXN_ID), Sort.Direction.DESC))) + } lateinit var mockNet: MockNetwork lateinit var notaryNode: MockNetwork.MockNode lateinit var nodeA: MockNetwork.MockNode @@ -62,7 +70,7 @@ class ScheduledFlowTests { serviceHub.myInfo.legalIdentity, destination) val notary = serviceHub.networkMapCache.getAnyNotary() - val builder = TransactionType.General.Builder(notary) + val builder = TransactionBuilder(notary) builder.withItems(scheduledState) val tx = serviceHub.signInitialTransaction(builder) subFlow(FinalityFlow(tx, setOf(serviceHub.myInfo.legalIdentity))) @@ -82,7 +90,7 @@ class ScheduledFlowTests { require(!scheduledState.processed) { "State should not have been previously processed" } val notary = state.state.notary val newStateOutput = scheduledState.copy(processed = true) - val builder = TransactionType.General.Builder(notary) + val builder = TransactionBuilder(notary) builder.withItems(state, newStateOutput) val tx = serviceHub.signInitialTransaction(builder) subFlow(FinalityFlow(tx, setOf(scheduledState.source, scheduledState.destination))) @@ -107,9 +115,8 @@ class ScheduledFlowTests { @Test fun `create and run scheduled flow then wait for result`() { - val stateMachines = nodeA.smm.track() var countScheduledFlows = 0 - stateMachines.second.subscribe { + nodeA.smm.track().updates.subscribe { if (it is StateMachineManager.Change.Add) { val initiator = it.logic.stateMachine.flowInitiator if (initiator is FlowInitiator.Scheduled) @@ -119,10 +126,10 @@ class ScheduledFlowTests { nodeA.services.startFlow(InsertInitialStateFlow(nodeB.info.legalIdentity)) mockNet.waitQuiescent() val stateFromA = nodeA.database.transaction { - nodeA.services.vaultService.linearHeadsOfType().values.first() + nodeA.services.vaultQueryService.queryBy().states.single() } val stateFromB = nodeB.database.transaction { - nodeB.services.vaultService.linearHeadsOfType().values.first() + nodeB.services.vaultQueryService.queryBy().states.single() } assertEquals(1, countScheduledFlows) assertEquals(stateFromA, stateFromB, "Must be same copy on both nodes") @@ -132,19 +139,53 @@ class ScheduledFlowTests { @Test fun `run a whole batch of scheduled flows`() { val N = 100 + val futures = mutableListOf>() for (i in 0..N - 1) { - nodeA.services.startFlow(InsertInitialStateFlow(nodeB.info.legalIdentity)) - nodeB.services.startFlow(InsertInitialStateFlow(nodeA.info.legalIdentity)) + futures.add(nodeA.services.startFlow(InsertInitialStateFlow(nodeB.info.legalIdentity)).resultFuture) + futures.add(nodeB.services.startFlow(InsertInitialStateFlow(nodeA.info.legalIdentity)).resultFuture) } mockNet.waitQuiescent() - val statesFromA = nodeA.database.transaction { - nodeA.services.vaultService.linearHeadsOfType() + + // Check all of the flows completed successfully + futures.forEach { it.getOrThrow() } + + // Convert the states into maps to make error reporting easier + val statesFromA: List> = nodeA.database.transaction { + queryStatesWithPaging(nodeA.services.vaultQueryService) } - val statesFromB = nodeB.database.transaction { - nodeB.services.vaultService.linearHeadsOfType() + val statesFromB: List> = nodeB.database.transaction { + queryStatesWithPaging(nodeB.services.vaultQueryService) } assertEquals(2 * N, statesFromA.count(), "Expect all states to be present") + statesFromA.forEach { ref -> + if (ref !in statesFromB) { + throw IllegalStateException("State $ref is only present on node A.") + } + } + statesFromB.forEach { ref -> + if (ref !in statesFromA) { + throw IllegalStateException("State $ref is only present on node B.") + } + } assertEquals(statesFromA, statesFromB, "Expect identical data on both nodes") - assertTrue("Expect all states have run the scheduled task", statesFromB.values.all { it.state.data.processed }) + assertTrue("Expect all states have run the scheduled task", statesFromB.all { it.state.data.processed }) + } + + /** + * Query all states from the Vault, fetching results as a series of pages with ordered states in order to perform + * integration testing of that functionality. + * + * @return states ordered by the transaction ID. + */ + private fun queryStatesWithPaging(vaultQueryService: VaultQueryService): List> { + var pageNumber = DEFAULT_PAGE_NUM + val states = mutableListOf>() + do { + val pageSpec = PageSpecification(pageSize = PAGE_SIZE, pageNumber = pageNumber) + val results = vaultQueryService.queryBy(VaultQueryCriteria(), pageSpec, SORTING) + states.addAll(results.states) + pageNumber++ + } while ((pageSpec.pageSize * (pageNumber)) <= results.totalStatesAvailable) + return states.toList() } } diff --git a/node/src/test/kotlin/net/corda/node/services/messaging/ArtemisMessagingTests.kt b/node/src/test/kotlin/net/corda/node/services/messaging/ArtemisMessagingTests.kt index d92bec8a22..0d19c27a65 100644 --- a/node/src/test/kotlin/net/corda/node/services/messaging/ArtemisMessagingTests.kt +++ b/node/src/test/kotlin/net/corda/node/services/messaging/ArtemisMessagingTests.kt @@ -1,40 +1,37 @@ package net.corda.node.services.messaging import com.codahale.metrics.MetricRegistry -import com.google.common.util.concurrent.Futures -import com.google.common.util.concurrent.ListenableFuture -import com.google.common.util.concurrent.SettableFuture +import net.corda.core.concurrent.CordaFuture import net.corda.core.crypto.generateKeyPair +import net.corda.core.internal.concurrent.doneFuture +import net.corda.core.internal.concurrent.openFuture import net.corda.core.messaging.RPCOps -import net.corda.core.node.services.DEFAULT_SESSION_ID import net.corda.core.utilities.NetworkHostAndPort -import net.corda.testing.ALICE -import net.corda.testing.LogHelper import net.corda.node.services.RPCUserService import net.corda.node.services.RPCUserServiceImpl +import net.corda.node.services.api.DEFAULT_SESSION_ID import net.corda.node.services.api.MonitoringService import net.corda.node.services.config.NodeConfiguration import net.corda.node.services.config.configureWithDevSSLCertificate +import net.corda.node.services.identity.InMemoryIdentityService import net.corda.node.services.network.InMemoryNetworkMapCache import net.corda.node.services.network.NetworkMapService import net.corda.node.services.transactions.PersistentUniquenessProvider import net.corda.node.utilities.AffinityExecutor.ServiceAffinityExecutor +import net.corda.node.utilities.CordaPersistence import net.corda.node.utilities.configureDatabase -import net.corda.node.utilities.transaction -import net.corda.testing.freeLocalHostAndPort -import net.corda.testing.freePort +import net.corda.testing.* import net.corda.testing.node.MOCK_VERSION_INFO import net.corda.testing.node.makeTestDataSourceProperties -import net.corda.testing.testNodeConfiguration +import net.corda.testing.node.makeTestDatabaseProperties +import net.corda.testing.node.makeTestIdentityService import org.assertj.core.api.Assertions.assertThat import org.assertj.core.api.Assertions.assertThatThrownBy -import org.jetbrains.exposed.sql.Database import org.junit.After import org.junit.Before import org.junit.Rule import org.junit.Test import org.junit.rules.TemporaryFolder -import java.io.Closeable import java.net.ServerSocket import java.util.concurrent.LinkedBlockingQueue import java.util.concurrent.TimeUnit.MILLISECONDS @@ -43,7 +40,7 @@ import kotlin.test.assertEquals import kotlin.test.assertNull //TODO This needs to be merged into P2PMessagingTest as that creates a more realistic environment -class ArtemisMessagingTests { +class ArtemisMessagingTests : TestDependencyInjectionBase() { @Rule @JvmField val temporaryFolder = TemporaryFolder() val serverPort = freePort() @@ -52,10 +49,9 @@ class ArtemisMessagingTests { val identity = generateKeyPair() lateinit var config: NodeConfiguration - lateinit var dataSource: Closeable - lateinit var database: Database + lateinit var database: CordaPersistence lateinit var userService: RPCUserService - lateinit var networkMapRegistrationFuture: ListenableFuture + lateinit var networkMapRegistrationFuture: CordaFuture var messagingClient: NodeMessagingClient? = null var messagingServer: ArtemisMessagingServer? = null @@ -75,10 +71,8 @@ class ArtemisMessagingTests { baseDirectory = baseDirectory, myLegalName = ALICE.name) LogHelper.setLevel(PersistentUniquenessProvider::class) - val dataSourceAndDatabase = configureDatabase(makeTestDataSourceProperties()) - dataSource = dataSourceAndDatabase.first - database = dataSourceAndDatabase.second - networkMapRegistrationFuture = Futures.immediateFuture(Unit) + database = configureDatabase(makeTestDataSourceProperties(), makeTestDatabaseProperties(), identitySvc = ::makeTestIdentityService) + networkMapRegistrationFuture = doneFuture(Unit) } @After @@ -87,7 +81,7 @@ class ArtemisMessagingTests { messagingServer?.stop() messagingClient = null messagingServer = null - dataSource.close() + database.close() LogHelper.reset(PersistentUniquenessProvider::class) } @@ -142,7 +136,7 @@ class ArtemisMessagingTests { @Test fun `client should be able to send message to itself before network map is available, and receive after`() { - val settableFuture: SettableFuture = SettableFuture.create() + val settableFuture = openFuture() networkMapRegistrationFuture = settableFuture val receivedMessages = LinkedBlockingQueue() @@ -167,7 +161,7 @@ class ArtemisMessagingTests { fun `client should be able to send large numbers of messages to itself before network map is available and survive restart, then receive messages`() { // Crank the iteration up as high as you want... just takes longer to run. val iterations = 100 - networkMapRegistrationFuture = SettableFuture.create() + networkMapRegistrationFuture = openFuture() val receivedMessages = LinkedBlockingQueue() @@ -188,7 +182,7 @@ class ArtemisMessagingTests { messagingClient.stop() messagingServer?.stop() - networkMapRegistrationFuture = Futures.immediateFuture(Unit) + networkMapRegistrationFuture = doneFuture(Unit) createAndStartClientAndServer(receivedMessages) for (iter in 1..iterations) { diff --git a/node/src/test/kotlin/net/corda/node/services/network/AbstractNetworkMapServiceTest.kt b/node/src/test/kotlin/net/corda/node/services/network/AbstractNetworkMapServiceTest.kt index 9c6df604e7..9d385513fb 100644 --- a/node/src/test/kotlin/net/corda/node/services/network/AbstractNetworkMapServiceTest.kt +++ b/node/src/test/kotlin/net/corda/node/services/network/AbstractNetworkMapServiceTest.kt @@ -1,16 +1,12 @@ package net.corda.node.services.network -import com.google.common.util.concurrent.ListenableFuture -import net.corda.core.getOrThrow +import net.corda.core.concurrent.CordaFuture import net.corda.core.messaging.SingleMessageRecipient import net.corda.core.node.NodeInfo -import net.corda.core.node.services.DEFAULT_SESSION_ID import net.corda.core.node.services.ServiceInfo import net.corda.core.serialization.deserialize -import net.corda.testing.ALICE -import net.corda.testing.BOB -import net.corda.testing.CHARLIE -import net.corda.testing.DUMMY_MAP +import net.corda.core.utilities.getOrThrow +import net.corda.node.services.api.DEFAULT_SESSION_ID import net.corda.node.services.config.NodeConfiguration import net.corda.node.services.messaging.send import net.corda.node.services.messaging.sendRequest @@ -23,20 +19,26 @@ import net.corda.node.services.network.NetworkMapService.Companion.PUSH_TOPIC import net.corda.node.services.network.NetworkMapService.Companion.QUERY_TOPIC import net.corda.node.services.network.NetworkMapService.Companion.REGISTER_TOPIC import net.corda.node.services.network.NetworkMapService.Companion.SUBSCRIPTION_TOPIC +import net.corda.node.services.transactions.SimpleNotaryService import net.corda.node.utilities.AddOrRemove import net.corda.node.utilities.AddOrRemove.ADD import net.corda.node.utilities.AddOrRemove.REMOVE +import net.corda.testing.ALICE +import net.corda.testing.BOB +import net.corda.testing.CHARLIE +import net.corda.testing.DUMMY_MAP import net.corda.testing.node.MockNetwork import net.corda.testing.node.MockNetwork.MockNode import org.assertj.core.api.Assertions.assertThat import org.bouncycastle.asn1.x500.X500Name -import org.eclipse.jetty.util.BlockingArrayQueue import org.junit.After import org.junit.Before import org.junit.Test import java.math.BigInteger import java.security.KeyPair import java.time.Instant +import java.util.* +import java.util.concurrent.LinkedBlockingQueue abstract class AbstractNetworkMapServiceTest { lateinit var mockNet: MockNetwork @@ -50,10 +52,11 @@ abstract class AbstractNetworkMapServiceTest @Before fun setup() { mockNet = MockNetwork(defaultFactory = nodeFactory) - mockNet.createTwoNodes(firstNodeName = DUMMY_MAP.name, secondNodeName = ALICE.name).apply { - mapServiceNode = first - alice = second - } + mapServiceNode = mockNet.createNode( + nodeFactory = nodeFactory, + legalName = DUMMY_MAP.name, + advertisedServices = *arrayOf(ServiceInfo(NetworkMapService.type), ServiceInfo(SimpleNotaryService.type))) + alice = mockNet.createNode(mapServiceNode.network.myAddress, nodeFactory = nodeFactory, legalName = ALICE.name) mockNet.runNetwork() lastSerial = System.currentTimeMillis() } @@ -63,7 +66,7 @@ abstract class AbstractNetworkMapServiceTest mockNet.stopNodes() } - protected abstract val nodeFactory: MockNetwork.Factory + protected abstract val nodeFactory: MockNetwork.Factory<*> protected abstract val networkMapService: S @@ -207,7 +210,7 @@ abstract class AbstractNetworkMapServiceTest private var lastSerial = Long.MIN_VALUE private fun MockNode.registration(addOrRemove: AddOrRemove, - serial: Long? = null): ListenableFuture { + serial: Long? = null): CordaFuture { val distinctSerial = if (serial == null) { ++lastSerial } else { @@ -222,9 +225,9 @@ abstract class AbstractNetworkMapServiceTest return response } - private fun MockNode.subscribe(): List { + private fun MockNode.subscribe(): Queue { val request = SubscribeRequest(true, network.myAddress) - val updates = BlockingArrayQueue() + val updates = LinkedBlockingQueue() services.networkService.addMessageHandler(PUSH_TOPIC, DEFAULT_SESSION_ID) { message, _ -> updates += message.data.deserialize() } @@ -248,7 +251,7 @@ abstract class AbstractNetworkMapServiceTest } private fun addNewNodeToNetworkMap(legalName: X500Name): MockNode { - val node = mockNet.createNode(networkMapAddress = mapServiceNode.network.myAddress, legalName = legalName) + val node = mockNet.createNode(mapServiceNode.network.myAddress, legalName = legalName) mockNet.runNetwork() lastSerial = System.currentTimeMillis() return node @@ -268,7 +271,7 @@ abstract class AbstractNetworkMapServiceTest } } - private object NoNMSNodeFactory : MockNetwork.Factory { + private object NoNMSNodeFactory : MockNetwork.Factory { override fun create(config: NodeConfiguration, network: MockNetwork, networkMapAddr: SingleMessageRecipient?, diff --git a/node/src/test/kotlin/net/corda/node/services/network/InMemoryIdentityServiceTests.kt b/node/src/test/kotlin/net/corda/node/services/network/InMemoryIdentityServiceTests.kt index 053d85bfaa..4507bd395a 100644 --- a/node/src/test/kotlin/net/corda/node/services/network/InMemoryIdentityServiceTests.kt +++ b/node/src/test/kotlin/net/corda/node/services/network/InMemoryIdentityServiceTests.kt @@ -1,12 +1,16 @@ package net.corda.node.services.network -import net.corda.core.crypto.* +import net.corda.core.crypto.CertificateAndKeyPair +import net.corda.core.crypto.Crypto +import net.corda.core.crypto.cert +import net.corda.core.crypto.generateKeyPair import net.corda.core.identity.AnonymousParty import net.corda.core.identity.Party import net.corda.core.identity.PartyAndCertificate import net.corda.core.node.services.IdentityService -import net.corda.flows.AnonymisedIdentity import net.corda.node.services.identity.InMemoryIdentityService +import net.corda.node.utilities.CertificateType +import net.corda.node.utilities.X509Utilities import net.corda.testing.* import org.bouncycastle.asn1.x500.X500Name import org.junit.Test @@ -25,13 +29,13 @@ class InMemoryIdentityServiceTests { // Nothing registered, so empty set assertNull(service.getAllIdentities().firstOrNull()) - service.registerIdentity(ALICE_IDENTITY) + service.verifyAndRegisterIdentity(ALICE_IDENTITY) var expected = setOf(ALICE) var actual = service.getAllIdentities().map { it.party }.toHashSet() assertEquals(expected, actual) // Add a second party and check we get both back - service.registerIdentity(BOB_IDENTITY) + service.verifyAndRegisterIdentity(BOB_IDENTITY) expected = setOf(ALICE, BOB) actual = service.getAllIdentities().map { it.party }.toHashSet() assertEquals(expected, actual) @@ -41,7 +45,7 @@ class InMemoryIdentityServiceTests { fun `get identity by key`() { val service = InMemoryIdentityService(trustRoot = DUMMY_CA.certificate) assertNull(service.partyFromKey(ALICE_PUBKEY)) - service.registerIdentity(ALICE_IDENTITY) + service.verifyAndRegisterIdentity(ALICE_IDENTITY) assertEquals(ALICE, service.partyFromKey(ALICE_PUBKEY)) assertNull(service.partyFromKey(BOB_PUBKEY)) } @@ -56,10 +60,10 @@ class InMemoryIdentityServiceTests { fun `get identity by substring match`() { val trustRoot = DUMMY_CA val service = InMemoryIdentityService(trustRoot = trustRoot.certificate) - service.registerIdentity(ALICE_IDENTITY) - service.registerIdentity(BOB_IDENTITY) + service.verifyAndRegisterIdentity(ALICE_IDENTITY) + service.verifyAndRegisterIdentity(BOB_IDENTITY) val alicente = getTestPartyAndCertificate(X500Name("O=Alicente Worldwide,L=London,C=GB"), generateKeyPair().public) - service.registerIdentity(alicente) + service.verifyAndRegisterIdentity(alicente) assertEquals(setOf(ALICE, alicente.party), service.partiesFromName("Alice", false)) assertEquals(setOf(ALICE), service.partiesFromName("Alice Corp", true)) assertEquals(setOf(BOB), service.partiesFromName("Bob Plc", true)) @@ -71,7 +75,7 @@ class InMemoryIdentityServiceTests { val identities = listOf("Node A", "Node B", "Node C") .map { getTestPartyAndCertificate(X500Name("CN=$it,O=R3,OU=corda,L=London,C=GB"), generateKeyPair().public) } assertNull(service.partyFromX500Name(identities.first().name)) - identities.forEach { service.registerIdentity(it) } + identities.forEach { service.verifyAndRegisterIdentity(it) } identities.forEach { assertEquals(it.party, service.partyFromX500Name(it.name)) } } @@ -80,16 +84,18 @@ class InMemoryIdentityServiceTests { */ @Test fun `assert unknown anonymous key is unrecognised`() { - val rootKey = Crypto.generateKeyPair(X509Utilities.DEFAULT_TLS_SIGNATURE_SCHEME) - val rootCert = X509Utilities.createSelfSignedCACertificate(ALICE.name, rootKey) - val txKey = Crypto.generateKeyPair(X509Utilities.DEFAULT_TLS_SIGNATURE_SCHEME) - val service = InMemoryIdentityService(trustRoot = DUMMY_CA.certificate) - // TODO: Generate certificate with an EdDSA key rather than ECDSA - val identity = Party(CertificateAndKeyPair(rootCert, rootKey)) - val txIdentity = AnonymousParty(txKey.public) + withTestSerialization { + val rootKey = Crypto.generateKeyPair(X509Utilities.DEFAULT_TLS_SIGNATURE_SCHEME) + val rootCert = X509Utilities.createSelfSignedCACertificate(ALICE.name, rootKey) + val txKey = Crypto.generateKeyPair(X509Utilities.DEFAULT_TLS_SIGNATURE_SCHEME) + val service = InMemoryIdentityService(trustRoot = DUMMY_CA.certificate) + // TODO: Generate certificate with an EdDSA key rather than ECDSA + val identity = Party(CertificateAndKeyPair(rootCert, rootKey)) + val txIdentity = AnonymousParty(txKey.public) - assertFailsWith { - service.assertOwnership(identity, txIdentity) + assertFailsWith { + service.assertOwnership(identity, txIdentity) + } } } @@ -98,50 +104,66 @@ class InMemoryIdentityServiceTests { * Also checks that incorrect associations are rejected. */ @Test - fun `assert ownership`() { + fun `get anonymous identity by key`() { val trustRoot = DUMMY_CA val (alice, aliceTxIdentity) = createParty(ALICE.name, trustRoot) - - val certFactory = CertificateFactory.getInstance("X509") - val bobRootKey = Crypto.generateKeyPair() - val bobRoot = getTestPartyAndCertificate(BOB.name, bobRootKey.public) - val bobRootCert = bobRoot.certificate - val bobTxKey = Crypto.generateKeyPair() - val bobTxCert = X509Utilities.createCertificate(CertificateType.IDENTITY, bobRootCert, bobRootKey, BOB.name, bobTxKey.public) - val bobCertPath = certFactory.generateCertPath(listOf(bobTxCert.cert, bobRootCert.cert)) - val bob = PartyAndCertificate(BOB.name, bobRootKey.public, bobRootCert, bobCertPath) + val (bob, bobTxIdentity) = createParty(ALICE.name, trustRoot) // Now we have identities, construct the service and let it know about both - val service = InMemoryIdentityService(setOf(alice, bob), emptyMap(), trustRoot.certificate.cert) - service.registerAnonymousIdentity(aliceTxIdentity.identity, alice.party, aliceTxIdentity.certPath) + val service = InMemoryIdentityService(setOf(alice), emptySet(), trustRoot.certificate.cert) + service.verifyAndRegisterIdentity(aliceTxIdentity) - val anonymousBob = AnonymousParty(bobTxKey.public) - service.registerAnonymousIdentity(anonymousBob, bob.party, bobCertPath) + var actual = service.certificateFromKey(aliceTxIdentity.party.owningKey) + assertEquals(aliceTxIdentity, actual!!) - // Verify that paths are verified - service.assertOwnership(alice.party, aliceTxIdentity.identity) - service.assertOwnership(bob.party, anonymousBob) - assertFailsWith { - service.assertOwnership(alice.party, anonymousBob) - } - assertFailsWith { - service.assertOwnership(bob.party, aliceTxIdentity.identity) - } + assertNull(service.certificateFromKey(bobTxIdentity.party.owningKey)) + service.verifyAndRegisterIdentity(bobTxIdentity) + actual = service.certificateFromKey(bobTxIdentity.party.owningKey) + assertEquals(bobTxIdentity, actual!!) + } - assertFailsWith { - val owningKey = Crypto.decodePublicKey(trustRoot.certificate.subjectPublicKeyInfo.encoded) - service.assertOwnership(Party(trustRoot.certificate.subject, owningKey), aliceTxIdentity.identity) + /** + * Generate a pair of certificate paths from a root CA, down to a transaction key, store and verify the associations. + * Also checks that incorrect associations are rejected. + */ + @Test + fun `assert ownership`() { + withTestSerialization { + val trustRoot = DUMMY_CA + val (alice, anonymousAlice) = createParty(ALICE.name, trustRoot) + val (bob, anonymousBob) = createParty(BOB.name, trustRoot) + + // Now we have identities, construct the service and let it know about both + val service = InMemoryIdentityService(setOf(alice, bob), emptySet(), trustRoot.certificate.cert) + + service.verifyAndRegisterIdentity(anonymousAlice) + service.verifyAndRegisterIdentity(anonymousBob) + + // Verify that paths are verified + service.assertOwnership(alice.party, anonymousAlice.party.anonymise()) + service.assertOwnership(bob.party, anonymousBob.party.anonymise()) + assertFailsWith { + service.assertOwnership(alice.party, anonymousBob.party.anonymise()) + } + assertFailsWith { + service.assertOwnership(bob.party, anonymousAlice.party.anonymise()) + } + + assertFailsWith { + val owningKey = Crypto.decodePublicKey(trustRoot.certificate.subjectPublicKeyInfo.encoded) + service.assertOwnership(Party(trustRoot.certificate.subject, owningKey), anonymousAlice.party.anonymise()) + } } } - private fun createParty(x500Name: X500Name, ca: CertificateAndKeyPair): Pair { + private fun createParty(x500Name: X500Name, ca: CertificateAndKeyPair): Pair { val certFactory = CertificateFactory.getInstance("X509") val issuerKeyPair = generateKeyPair() val issuer = getTestPartyAndCertificate(x500Name, issuerKeyPair.public, ca) val txKey = Crypto.generateKeyPair() val txCert = X509Utilities.createCertificate(CertificateType.IDENTITY, issuer.certificate, issuerKeyPair, x500Name, txKey.public) val txCertPath = certFactory.generateCertPath(listOf(txCert.cert) + issuer.certPath.certificates) - return Pair(issuer, AnonymisedIdentity(txCertPath, txCert, AnonymousParty(txKey.public))) + return Pair(issuer, PartyAndCertificate(Party(x500Name, txKey.public), txCert, txCertPath)) } /** diff --git a/node/src/test/kotlin/net/corda/node/services/network/InMemoryNetworkMapCacheTest.kt b/node/src/test/kotlin/net/corda/node/services/network/InMemoryNetworkMapCacheTest.kt index 948228acb8..0b4cdb9a72 100644 --- a/node/src/test/kotlin/net/corda/node/services/network/InMemoryNetworkMapCacheTest.kt +++ b/node/src/test/kotlin/net/corda/node/services/network/InMemoryNetworkMapCacheTest.kt @@ -1,19 +1,24 @@ package net.corda.node.services.network -import net.corda.core.getOrThrow import net.corda.core.node.services.NetworkMapCache import net.corda.core.node.services.ServiceInfo +import net.corda.core.utilities.getOrThrow import net.corda.testing.ALICE import net.corda.testing.BOB -import net.corda.node.utilities.transaction import net.corda.testing.node.MockNetwork import org.junit.After +import org.junit.Before import org.junit.Test import java.math.BigInteger import kotlin.test.assertEquals class InMemoryNetworkMapCacheTest { - private val mockNet = MockNetwork() + lateinit var mockNet: MockNetwork + + @Before + fun setUp() { + mockNet = MockNetwork() + } @After fun teardown() { @@ -22,7 +27,9 @@ class InMemoryNetworkMapCacheTest { @Test fun registerWithNetwork() { - val (n0, n1) = mockNet.createTwoNodes() + val nodes = mockNet.createSomeNodes(1) + val n0 = nodes.mapNode + val n1 = nodes.partyNodes[0] val future = n1.services.networkMapCache.addMapService(n1.network, n0.network.myAddress, false, null) mockNet.runNetwork() future.getOrThrow() @@ -31,8 +38,8 @@ class InMemoryNetworkMapCacheTest { @Test fun `key collision`() { val entropy = BigInteger.valueOf(24012017L) - val nodeA = mockNet.createNode(null, -1, MockNetwork.DefaultFactory, true, ALICE.name, null, entropy, ServiceInfo(NetworkMapService.type)) - val nodeB = mockNet.createNode(null, -1, MockNetwork.DefaultFactory, true, BOB.name, null, entropy, ServiceInfo(NetworkMapService.type)) + val nodeA = mockNet.createNode(nodeFactory = MockNetwork.DefaultFactory, legalName = ALICE.name, entropyRoot = entropy, advertisedServices = ServiceInfo(NetworkMapService.type)) + val nodeB = mockNet.createNode(nodeFactory = MockNetwork.DefaultFactory, legalName = BOB.name, entropyRoot = entropy, advertisedServices = ServiceInfo(NetworkMapService.type)) assertEquals(nodeA.info.legalIdentity, nodeB.info.legalIdentity) mockNet.runNetwork() @@ -49,7 +56,9 @@ class InMemoryNetworkMapCacheTest { @Test fun `getNodeByLegalIdentity`() { - val (n0, n1) = mockNet.createTwoNodes() + val nodes = mockNet.createSomeNodes(1) + val n0 = nodes.mapNode + val n1 = nodes.partyNodes[0] val node0Cache: NetworkMapCache = n0.services.networkMapCache val expected = n1.info diff --git a/node/src/test/kotlin/net/corda/node/services/network/InMemoryNetworkMapServiceTest.kt b/node/src/test/kotlin/net/corda/node/services/network/InMemoryNetworkMapServiceTest.kt index 0185e4a47c..c6d8566560 100644 --- a/node/src/test/kotlin/net/corda/node/services/network/InMemoryNetworkMapServiceTest.kt +++ b/node/src/test/kotlin/net/corda/node/services/network/InMemoryNetworkMapServiceTest.kt @@ -1,10 +1,9 @@ package net.corda.node.services.network -import net.corda.node.services.network.InMemoryNetworkMapService import net.corda.testing.node.MockNetwork class InMemoryNetworkMapServiceTest : AbstractNetworkMapServiceTest() { - override val nodeFactory: MockNetwork.Factory get() = MockNetwork.DefaultFactory + override val nodeFactory get() = MockNetwork.DefaultFactory override val networkMapService: InMemoryNetworkMapService get() = mapServiceNode.inNodeNetworkMapService as InMemoryNetworkMapService override fun swizzle() = Unit } diff --git a/node/src/test/kotlin/net/corda/node/services/network/PersistentNetworkMapServiceTest.kt b/node/src/test/kotlin/net/corda/node/services/network/PersistentNetworkMapServiceTest.kt index c5af8af7ae..27dac2cd35 100644 --- a/node/src/test/kotlin/net/corda/node/services/network/PersistentNetworkMapServiceTest.kt +++ b/node/src/test/kotlin/net/corda/node/services/network/PersistentNetworkMapServiceTest.kt @@ -4,7 +4,6 @@ import net.corda.core.messaging.SingleMessageRecipient import net.corda.core.node.services.ServiceInfo import net.corda.node.services.api.ServiceHubInternal import net.corda.node.services.config.NodeConfiguration -import net.corda.node.utilities.transaction import net.corda.testing.node.MockNetwork import net.corda.testing.node.MockNetwork.MockNode import java.math.BigInteger @@ -16,7 +15,7 @@ import java.security.KeyPair */ class PersistentNetworkMapServiceTest : AbstractNetworkMapServiceTest() { - override val nodeFactory: MockNetwork.Factory get() = NodeFactory + override val nodeFactory: MockNetwork.Factory<*> get() = NodeFactory override val networkMapService: PersistentNetworkMapService get() = (mapServiceNode.inNodeNetworkMapService as SwizzleNetworkMapService).delegate @@ -27,7 +26,7 @@ class PersistentNetworkMapServiceTest : AbstractNetworkMapServiceTest { override fun create(config: NodeConfiguration, network: MockNetwork, networkMapAddr: SingleMessageRecipient?, diff --git a/node/src/test/kotlin/net/corda/node/services/persistence/DBCheckpointStorageTests.kt b/node/src/test/kotlin/net/corda/node/services/persistence/DBCheckpointStorageTests.kt index 92180fde67..e060680d41 100644 --- a/node/src/test/kotlin/net/corda/node/services/persistence/DBCheckpointStorageTests.kt +++ b/node/src/test/kotlin/net/corda/node/services/persistence/DBCheckpointStorageTests.kt @@ -2,20 +2,20 @@ package net.corda.node.services.persistence import com.google.common.primitives.Ints import net.corda.core.serialization.SerializedBytes -import net.corda.testing.LogHelper import net.corda.node.services.api.Checkpoint import net.corda.node.services.api.CheckpointStorage import net.corda.node.services.transactions.PersistentUniquenessProvider +import net.corda.node.utilities.CordaPersistence import net.corda.node.utilities.configureDatabase -import net.corda.node.utilities.transaction +import net.corda.testing.LogHelper +import net.corda.testing.TestDependencyInjectionBase import net.corda.testing.node.makeTestDataSourceProperties +import net.corda.testing.node.makeTestDatabaseProperties +import net.corda.testing.node.makeTestIdentityService import org.assertj.core.api.Assertions.assertThat -import org.assertj.core.api.Assertions.assertThatExceptionOfType -import org.jetbrains.exposed.sql.Database import org.junit.After import org.junit.Before import org.junit.Test -import java.io.Closeable internal fun CheckpointStorage.checkpoints(): List { val checkpoints = mutableListOf() @@ -26,23 +26,20 @@ internal fun CheckpointStorage.checkpoints(): List { return checkpoints } -class DBCheckpointStorageTests { +class DBCheckpointStorageTests : TestDependencyInjectionBase() { lateinit var checkpointStorage: DBCheckpointStorage - lateinit var dataSource: Closeable - lateinit var database: Database + lateinit var database: CordaPersistence @Before fun setUp() { LogHelper.setLevel(PersistentUniquenessProvider::class) - val dataSourceAndDatabase = configureDatabase(makeTestDataSourceProperties()) - dataSource = dataSourceAndDatabase.first - database = dataSourceAndDatabase.second + database = configureDatabase(makeTestDataSourceProperties(), makeTestDatabaseProperties(), identitySvc = ::makeTestIdentityService) newCheckpointStorage() } @After fun cleanUp() { - dataSource.close() + database.close() LogHelper.reset(PersistentUniquenessProvider::class) } @@ -97,16 +94,6 @@ class DBCheckpointStorageTests { } } - @Test - fun `remove unknown checkpoint`() { - val checkpoint = newCheckpoint() - database.transaction { - assertThatExceptionOfType(IllegalArgumentException::class.java).isThrownBy { - checkpointStorage.removeCheckpoint(checkpoint) - } - } - } - @Test fun `add two checkpoints then remove first one`() { val firstCheckpoint = newCheckpoint() diff --git a/node/src/test/kotlin/net/corda/node/services/persistence/DBTransactionStorageTests.kt b/node/src/test/kotlin/net/corda/node/services/persistence/DBTransactionStorageTests.kt index c98f62dfa3..ba59ceefa1 100644 --- a/node/src/test/kotlin/net/corda/node/services/persistence/DBTransactionStorageTests.kt +++ b/node/src/test/kotlin/net/corda/node/services/persistence/DBTransactionStorageTests.kt @@ -1,45 +1,84 @@ package net.corda.node.services.persistence import net.corda.core.contracts.StateRef -import net.corda.core.contracts.TransactionType -import net.corda.core.crypto.DigitalSignature +import net.corda.core.crypto.Crypto import net.corda.core.crypto.SecureHash -import net.corda.core.crypto.testing.NullPublicKey +import net.corda.core.crypto.SignatureMetadata +import net.corda.core.node.services.VaultService +import net.corda.core.crypto.TransactionSignature +import net.corda.core.schemas.MappedSchema import net.corda.core.toFuture import net.corda.core.transactions.SignedTransaction import net.corda.core.transactions.WireTransaction -import net.corda.testing.DUMMY_NOTARY -import net.corda.testing.LogHelper +import net.corda.node.services.database.HibernateConfiguration +import net.corda.node.services.schema.HibernateObserver +import net.corda.node.services.schema.NodeSchemaService import net.corda.node.services.transactions.PersistentUniquenessProvider +import net.corda.node.services.vault.NodeVaultService +import net.corda.node.services.vault.VaultSchemaV1 +import net.corda.node.utilities.CordaPersistence import net.corda.node.utilities.configureDatabase -import net.corda.node.utilities.transaction +import net.corda.schemas.CashSchemaV1 +import net.corda.schemas.SampleCashSchemaV2 +import net.corda.schemas.SampleCashSchemaV3 +import net.corda.testing.* +import net.corda.testing.node.MockServices import net.corda.testing.node.makeTestDataSourceProperties +import net.corda.testing.node.makeTestDatabaseProperties +import net.corda.testing.node.makeTestIdentityService import org.assertj.core.api.Assertions.assertThat -import org.jetbrains.exposed.sql.Database import org.junit.After import org.junit.Before import org.junit.Test -import java.io.Closeable import java.util.concurrent.TimeUnit import kotlin.test.assertEquals -class DBTransactionStorageTests { - lateinit var dataSource: Closeable - lateinit var database: Database +class DBTransactionStorageTests : TestDependencyInjectionBase() { + lateinit var database: CordaPersistence lateinit var transactionStorage: DBTransactionStorage + lateinit var services: MockServices + val vault: VaultService get() = services.vaultService + // Hibernate configuration objects + lateinit var hibernateConfig: HibernateConfiguration @Before fun setUp() { LogHelper.setLevel(PersistentUniquenessProvider::class) - val dataSourceAndDatabase = configureDatabase(makeTestDataSourceProperties()) - dataSource = dataSourceAndDatabase.first - database = dataSourceAndDatabase.second + val dataSourceProps = makeTestDataSourceProperties() + + val transactionSchema = MappedSchema(schemaFamily = javaClass, version = 1, + mappedTypes = listOf(DBTransactionStorage.DBTransaction::class.java)) + + val customSchemas = setOf(VaultSchemaV1, CashSchemaV1, SampleCashSchemaV2, SampleCashSchemaV3, transactionSchema) + + database = configureDatabase(dataSourceProps, makeTestDatabaseProperties(), customSchemas, identitySvc = ::makeTestIdentityService) + + database.transaction { + + hibernateConfig = HibernateConfiguration(NodeSchemaService(customSchemas), makeTestDatabaseProperties(), identitySvc = ::makeTestIdentityService) + + services = object : MockServices(BOB_KEY) { + override val vaultService: VaultService get() { + val vaultService = NodeVaultService(this, dataSourceProps, makeTestDatabaseProperties()) + hibernatePersister = HibernateObserver(vaultService.rawUpdates, hibernateConfig) + return vaultService + } + + override fun recordTransactions(txs: Iterable) { + for (stx in txs) { + validatedTransactions.addTransaction(stx) + } + // Refactored to use notifyAll() as we have no other unit test for that method with multiple transactions. + vaultService.notifyAll(txs.map { it.tx }) + } + } + } newTransactionStorage() } @After fun cleanUp() { - dataSource.close() + database.close() LogHelper.reset(PersistentUniquenessProvider::class) } @@ -122,6 +161,37 @@ class DBTransactionStorageTests { } } + @Test + fun `transaction saved twice in same DB transaction scope`() { + val firstTransaction = newTransaction() + database.transaction { + transactionStorage.addTransaction(firstTransaction) + transactionStorage.addTransaction(firstTransaction) + } + assertTransactionIsRetrievable(firstTransaction) + database.transaction { + assertThat(transactionStorage.transactions).containsOnly(firstTransaction) + } + } + + @Test + fun `transaction saved twice in two DB transaction scopes`() { + val firstTransaction = newTransaction() + val secondTransaction = newTransaction() + database.transaction { + transactionStorage.addTransaction(firstTransaction) + } + + database.transaction { + transactionStorage.addTransaction(secondTransaction) + transactionStorage.addTransaction(firstTransaction) + } + assertTransactionIsRetrievable(firstTransaction) + database.transaction { + assertThat(transactionStorage.transactions).containsOnly(firstTransaction, secondTransaction) + } + } + @Test fun `updates are fired`() { val future = transactionStorage.updates.toFuture() @@ -152,10 +222,8 @@ class DBTransactionStorageTests { outputs = emptyList(), commands = emptyList(), notary = DUMMY_NOTARY, - signers = emptyList(), - type = TransactionType.General, timeWindow = null ) - return SignedTransaction(wtx.serialized, listOf(DigitalSignature.WithKey(NullPublicKey, ByteArray(1)))) + return SignedTransaction(wtx, listOf(TransactionSignature(ByteArray(1), ALICE_PUBKEY, SignatureMetadata(1, Crypto.findSignatureScheme(ALICE_PUBKEY).schemeNumberID)))) } } diff --git a/node/src/test/kotlin/net/corda/node/services/persistence/DataVendingServiceTests.kt b/node/src/test/kotlin/net/corda/node/services/persistence/DataVendingServiceTests.kt index f912949506..3f1f0e909c 100644 --- a/node/src/test/kotlin/net/corda/node/services/persistence/DataVendingServiceTests.kt +++ b/node/src/test/kotlin/net/corda/node/services/persistence/DataVendingServiceTests.kt @@ -4,17 +4,16 @@ import co.paralleluniverse.fibers.Suspendable import net.corda.contracts.asset.Cash import net.corda.core.contracts.Amount import net.corda.core.contracts.Issued -import net.corda.core.contracts.TransactionType import net.corda.core.contracts.USD import net.corda.core.flows.FlowLogic import net.corda.core.flows.InitiatedBy import net.corda.core.flows.InitiatingFlow +import net.corda.core.flows.SendTransactionFlow import net.corda.core.identity.Party -import net.corda.core.node.services.unconsumedStates +import net.corda.core.node.services.queryBy import net.corda.core.transactions.SignedTransaction -import net.corda.flows.BroadcastTransactionFlow.NotifyTxRequest +import net.corda.core.transactions.TransactionBuilder import net.corda.node.services.NotifyTransactionHandler -import net.corda.node.utilities.transaction import net.corda.testing.DUMMY_NOTARY import net.corda.testing.MEGA_CORP import net.corda.testing.node.MockNetwork @@ -43,24 +42,26 @@ class DataVendingServiceTests { @Test fun `notify of transaction`() { - val (vaultServiceNode, registerNode) = mockNet.createTwoNodes() + val nodes = mockNet.createSomeNodes(2) + val vaultServiceNode = nodes.partyNodes[0] + val registerNode = nodes.partyNodes[1] val beneficiary = vaultServiceNode.info.legalIdentity val deposit = registerNode.info.legalIdentity.ref(1) mockNet.runNetwork() // Generate an issuance transaction - val ptx = TransactionType.General.Builder(null) + val ptx = TransactionBuilder(null) Cash().generateIssue(ptx, Amount(100, Issued(deposit, USD)), beneficiary, DUMMY_NOTARY) // Complete the cash transaction, and then manually relay it val tx = registerNode.services.signInitialTransaction(ptx) vaultServiceNode.database.transaction { - assertThat(vaultServiceNode.services.vaultService.unconsumedStates()).isEmpty() + assertThat(vaultServiceNode.services.vaultQueryService.queryBy().states.isEmpty()) registerNode.sendNotifyTx(tx, vaultServiceNode) // Check the transaction is in the receiving node - val actual = vaultServiceNode.services.vaultService.unconsumedStates().singleOrNull() + val actual = vaultServiceNode.services.vaultQueryService.queryBy().states.singleOrNull() val expected = tx.tx.outRef(0) assertEquals(expected, actual) } @@ -71,24 +72,26 @@ class DataVendingServiceTests { */ @Test fun `notify failure`() { - val (vaultServiceNode, registerNode) = mockNet.createTwoNodes() + val nodes = mockNet.createSomeNodes(2) + val vaultServiceNode = nodes.partyNodes[0] + val registerNode = nodes.partyNodes[1] val beneficiary = vaultServiceNode.info.legalIdentity val deposit = MEGA_CORP.ref(1) mockNet.runNetwork() // Generate an issuance transaction - val ptx = TransactionType.General.Builder(DUMMY_NOTARY) + val ptx = TransactionBuilder(DUMMY_NOTARY) Cash().generateIssue(ptx, Amount(100, Issued(deposit, USD)), beneficiary, DUMMY_NOTARY) // The transaction tries issuing MEGA_CORP cash, but we aren't the issuer, so it's invalid val tx = registerNode.services.signInitialTransaction(ptx) vaultServiceNode.database.transaction { - assertThat(vaultServiceNode.services.vaultService.unconsumedStates()).isEmpty() + assertThat(vaultServiceNode.services.vaultQueryService.queryBy().states.isEmpty()) registerNode.sendNotifyTx(tx, vaultServiceNode) // Check the transaction is not in the receiving node - assertThat(vaultServiceNode.services.vaultService.unconsumedStates()).isEmpty() + assertThat(vaultServiceNode.services.vaultQueryService.queryBy().states.isEmpty()) } } @@ -99,9 +102,9 @@ class DataVendingServiceTests { } @InitiatingFlow - private class NotifyTxFlow(val otherParty: Party, val stx: SignedTransaction) : FlowLogic() { + private class NotifyTxFlow(val otherParty: Party, val stx: SignedTransaction) : FlowLogic() { @Suspendable - override fun call() = send(otherParty, NotifyTxRequest(stx)) + override fun call() = subFlow(SendTransactionFlow(otherParty, stx)) } @InitiatedBy(NotifyTxFlow::class) diff --git a/node/src/test/kotlin/net/corda/node/services/persistence/NodeAttachmentStorageTest.kt b/node/src/test/kotlin/net/corda/node/services/persistence/NodeAttachmentStorageTest.kt index f50738f6c3..287c4638cb 100644 --- a/node/src/test/kotlin/net/corda/node/services/persistence/NodeAttachmentStorageTest.kt +++ b/node/src/test/kotlin/net/corda/node/services/persistence/NodeAttachmentStorageTest.kt @@ -5,22 +5,22 @@ import com.google.common.jimfs.Configuration import com.google.common.jimfs.Jimfs import net.corda.core.crypto.SecureHash import net.corda.core.crypto.sha256 -import net.corda.core.read -import net.corda.core.readAll +import net.corda.core.internal.read +import net.corda.core.internal.readAll import net.corda.testing.LogHelper -import net.corda.core.write -import net.corda.core.writeLines +import net.corda.core.internal.write +import net.corda.core.internal.writeLines import net.corda.node.services.database.RequeryConfiguration import net.corda.node.services.persistence.schemas.requery.AttachmentEntity import net.corda.node.services.transactions.PersistentUniquenessProvider +import net.corda.node.utilities.CordaPersistence import net.corda.node.utilities.configureDatabase -import net.corda.node.utilities.transaction import net.corda.testing.node.makeTestDataSourceProperties -import org.jetbrains.exposed.sql.Database +import net.corda.testing.node.makeTestDatabaseProperties +import net.corda.testing.node.makeTestIdentityService import org.junit.After import org.junit.Before import org.junit.Test -import java.io.Closeable import java.nio.charset.Charset import java.nio.file.FileAlreadyExistsException import java.nio.file.FileSystem @@ -35,8 +35,7 @@ import kotlin.test.assertNull class NodeAttachmentStorageTest { // Use an in memory file system for testing attachment storage. lateinit var fs: FileSystem - lateinit var dataSource: Closeable - lateinit var database: Database + lateinit var database: CordaPersistence lateinit var dataSourceProperties: Properties lateinit var configuration: RequeryConfiguration @@ -45,17 +44,15 @@ class NodeAttachmentStorageTest { LogHelper.setLevel(PersistentUniquenessProvider::class) dataSourceProperties = makeTestDataSourceProperties() - val dataSourceAndDatabase = configureDatabase(dataSourceProperties) - dataSource = dataSourceAndDatabase.first - database = dataSourceAndDatabase.second + database = configureDatabase(dataSourceProperties, makeTestDatabaseProperties(), identitySvc = ::makeTestIdentityService) - configuration = RequeryConfiguration(dataSourceProperties) + configuration = RequeryConfiguration(dataSourceProperties, databaseProperties = makeTestDatabaseProperties()) fs = Jimfs.newFileSystem(Configuration.unix()) } @After fun tearDown() { - dataSource.close() + database.close() } @Test @@ -64,7 +61,7 @@ class NodeAttachmentStorageTest { val expectedHash = testJar.readAll().sha256() database.transaction { - val storage = NodeAttachmentService(fs.getPath("/"), dataSourceProperties, MetricRegistry()) + val storage = NodeAttachmentService(dataSourceProperties, MetricRegistry(), makeTestDatabaseProperties()) val id = testJar.read { storage.importAttachment(it) } assertEquals(expectedHash, id) @@ -90,7 +87,7 @@ class NodeAttachmentStorageTest { fun `duplicates not allowed`() { val testJar = makeTestJar() database.transaction { - val storage = NodeAttachmentService(fs.getPath("/"), dataSourceProperties, MetricRegistry()) + val storage = NodeAttachmentService(dataSourceProperties, MetricRegistry(), makeTestDatabaseProperties()) testJar.read { storage.importAttachment(it) } @@ -106,7 +103,7 @@ class NodeAttachmentStorageTest { fun `corrupt entry throws exception`() { val testJar = makeTestJar() database.transaction { - val storage = NodeAttachmentService(fs.getPath("/"), dataSourceProperties, MetricRegistry()) + val storage = NodeAttachmentService(dataSourceProperties, MetricRegistry(), makeTestDatabaseProperties()) val id = testJar.read { storage.importAttachment(it) } // Corrupt the file in the store. @@ -134,7 +131,7 @@ class NodeAttachmentStorageTest { @Test fun `non jar rejected`() { database.transaction { - val storage = NodeAttachmentService(fs.getPath("/"), dataSourceProperties, MetricRegistry()) + val storage = NodeAttachmentService(dataSourceProperties, MetricRegistry(), makeTestDatabaseProperties()) val path = fs.getPath("notajar") path.writeLines(listOf("Hey", "there!")) path.read { diff --git a/node/src/test/kotlin/net/corda/node/services/schema/HibernateObserverTests.kt b/node/src/test/kotlin/net/corda/node/services/schema/HibernateObserverTests.kt index 6cb9b6c47c..f81aac2b48 100644 --- a/node/src/test/kotlin/net/corda/node/services/schema/HibernateObserverTests.kt +++ b/node/src/test/kotlin/net/corda/node/services/schema/HibernateObserverTests.kt @@ -10,38 +10,38 @@ import net.corda.core.schemas.QueryableState import net.corda.testing.LogHelper import net.corda.node.services.api.SchemaService import net.corda.node.services.database.HibernateConfiguration +import net.corda.node.services.identity.InMemoryIdentityService +import net.corda.node.utilities.CordaPersistence import net.corda.node.utilities.configureDatabase -import net.corda.node.utilities.transaction +import net.corda.testing.DUMMY_CA import net.corda.testing.MEGA_CORP +import net.corda.testing.MOCK_IDENTITIES import net.corda.testing.node.makeTestDataSourceProperties +import net.corda.testing.node.makeTestDatabaseProperties +import net.corda.testing.node.makeTestIdentityService import org.hibernate.annotations.Cascade import org.hibernate.annotations.CascadeType -import org.jetbrains.exposed.sql.Database import org.jetbrains.exposed.sql.transactions.TransactionManager import org.junit.After import org.junit.Before import org.junit.Test import rx.subjects.PublishSubject -import java.io.Closeable import javax.persistence.* import kotlin.test.assertEquals class HibernateObserverTests { - lateinit var dataSource: Closeable - lateinit var database: Database + lateinit var database: CordaPersistence @Before fun setUp() { LogHelper.setLevel(HibernateObserver::class) - val dataSourceAndDatabase = configureDatabase(makeTestDataSourceProperties()) - dataSource = dataSourceAndDatabase.first - database = dataSourceAndDatabase.second + database = configureDatabase(makeTestDataSourceProperties(), makeTestDatabaseProperties(), identitySvc = ::makeTestIdentityService) } @After fun cleanUp() { - dataSource.close() + database.close() LogHelper.reset(HibernateObserver::class) } @@ -91,7 +91,7 @@ class HibernateObserverTests { @Test fun testChildObjectsArePersisted() { val testSchema = object : MappedSchema(SchemaFamily::class.java, 1, setOf(Parent::class.java, Child::class.java)) {} - val rawUpdatesPublisher = PublishSubject.create() + val rawUpdatesPublisher = PublishSubject.create>() val schemaService = object : SchemaService { override val schemaOptions: Map = emptyMap() @@ -106,7 +106,7 @@ class HibernateObserverTests { } @Suppress("UNUSED_VARIABLE") - val observer = HibernateObserver(rawUpdatesPublisher, HibernateConfiguration(schemaService)) + val observer = HibernateObserver(rawUpdatesPublisher, HibernateConfiguration(schemaService, makeTestDatabaseProperties(), ::makeTestIdentityService)) database.transaction { rawUpdatesPublisher.onNext(Vault.Update(emptySet(), setOf(StateAndRef(TransactionState(TestState(), MEGA_CORP), StateRef(SecureHash.sha256("dummy"), 0))))) val parentRowCountResult = TransactionManager.current().connection.prepareStatement("select count(*) from Parents").executeQuery() diff --git a/node/src/test/kotlin/net/corda/node/services/statemachine/FlowFrameworkTests.kt b/node/src/test/kotlin/net/corda/node/services/statemachine/FlowFrameworkTests.kt index f83de31b69..41118d6f83 100644 --- a/node/src/test/kotlin/net/corda/node/services/statemachine/FlowFrameworkTests.kt +++ b/node/src/test/kotlin/net/corda/node/services/statemachine/FlowFrameworkTests.kt @@ -2,51 +2,44 @@ package net.corda.node.services.statemachine import co.paralleluniverse.fibers.Fiber import co.paralleluniverse.fibers.Suspendable -import com.google.common.util.concurrent.ListenableFuture -import net.corda.contracts.asset.Cash -import net.corda.core.* +import co.paralleluniverse.strands.concurrent.Semaphore +import net.corda.core.concurrent.CordaFuture import net.corda.core.contracts.ContractState import net.corda.core.contracts.DOLLARS import net.corda.core.contracts.StateAndRef -import net.corda.testing.contracts.DummyState -import net.corda.core.crypto.SecureHash import net.corda.core.crypto.generateKeyPair import net.corda.core.crypto.random63BitValue -import net.corda.core.flows.FlowException -import net.corda.core.flows.FlowLogic -import net.corda.core.flows.FlowSessionException -import net.corda.core.flows.InitiatingFlow +import net.corda.core.flows.* import net.corda.core.identity.Party +import net.corda.core.internal.concurrent.flatMap +import net.corda.core.internal.concurrent.map import net.corda.core.messaging.MessageRecipients import net.corda.core.node.services.PartyInfo import net.corda.core.node.services.ServiceInfo import net.corda.core.node.services.queryBy -import net.corda.core.node.services.unconsumedStates -import net.corda.core.utilities.OpaqueBytes import net.corda.core.serialization.deserialize +import net.corda.core.serialization.serialize +import net.corda.core.toFuture import net.corda.core.transactions.SignedTransaction import net.corda.core.transactions.TransactionBuilder -import net.corda.testing.LogHelper +import net.corda.core.utilities.OpaqueBytes import net.corda.core.utilities.ProgressTracker import net.corda.core.utilities.ProgressTracker.Change +import net.corda.core.utilities.getOrThrow import net.corda.core.utilities.unwrap import net.corda.flows.CashIssueFlow import net.corda.flows.CashPaymentFlow -import net.corda.flows.FinalityFlow -import net.corda.flows.NotaryFlow import net.corda.node.internal.InitiatedFlowFactory +import net.corda.node.services.network.NetworkMapService import net.corda.node.services.persistence.checkpoints import net.corda.node.services.transactions.ValidatingNotaryService -import net.corda.node.utilities.transaction -import net.corda.testing.expect -import net.corda.testing.expectEvents -import net.corda.testing.getTestX509Name +import net.corda.testing.* +import net.corda.testing.contracts.DummyState import net.corda.testing.node.InMemoryMessagingNetwork import net.corda.testing.node.InMemoryMessagingNetwork.MessageTransfer import net.corda.testing.node.InMemoryMessagingNetwork.ServicePeerAllocationStrategy.RoundRobin import net.corda.testing.node.MockNetwork import net.corda.testing.node.MockNetwork.MockNode -import net.corda.testing.sequence import org.assertj.core.api.Assertions.assertThat import org.assertj.core.api.Assertions.assertThatThrownBy import org.assertj.core.api.AssertionsForClassTypes.assertThatExceptionOfType @@ -70,7 +63,7 @@ class FlowFrameworkTests { } private val mockNet = MockNetwork(servicePeerAllocationStrategy = RoundRobin()) - private val sessionTransfers = ArrayList() + private val receivedSessionMessages = ArrayList() private lateinit var node1: MockNode private lateinit var node2: MockNode private lateinit var notary1: MockNode @@ -78,24 +71,33 @@ class FlowFrameworkTests { @Before fun start() { - val nodes = mockNet.createTwoNodes() - node1 = nodes.first - node2 = nodes.second + node1 = mockNet.createNode(advertisedServices = ServiceInfo(NetworkMapService.type)) + node2 = mockNet.createNode(networkMapAddress = node1.network.myAddress) + // We intentionally create our own notary and ignore the one provided by the network val notaryKeyPair = generateKeyPair() val notaryService = ServiceInfo(ValidatingNotaryService.type, getTestX509Name("notary-service-2000")) val overrideServices = mapOf(Pair(notaryService, notaryKeyPair)) // Note that these notaries don't operate correctly as they don't share their state. They are only used for testing // service addressing. - notary1 = mockNet.createNotaryNode(networkMapAddr = node1.network.myAddress, overrideServices = overrideServices, serviceName = notaryService.name) - notary2 = mockNet.createNotaryNode(networkMapAddr = node1.network.myAddress, overrideServices = overrideServices, serviceName = notaryService.name) + notary1 = mockNet.createNotaryNode(networkMapAddress = node1.network.myAddress, overrideServices = overrideServices, serviceName = notaryService.name) + notary2 = mockNet.createNotaryNode(networkMapAddress = node1.network.myAddress, overrideServices = overrideServices, serviceName = notaryService.name) - mockNet.messagingNetwork.receivedMessages.toSessionTransfers().forEach { sessionTransfers += it } + receivedSessionMessagesObservable().forEach { receivedSessionMessages += it } mockNet.runNetwork() + + // We don't create a network map, so manually handle registrations + val nodes = listOf(node1, node2, notary1, notary2) + nodes.forEach { node -> + nodes.map { it.services.myInfo.legalIdentityAndCert }.forEach { identity -> + node.services.identityService.verifyAndRegisterIdentity(identity) + } + } } @After fun cleanUp() { mockNet.stopNodes() + receivedSessionMessages.clear() } @Test @@ -166,7 +168,7 @@ class FlowFrameworkTests { node3.disableDBCloseOnStop() node3.stop() - node3 = mockNet.createNode(node1.network.myAddress, forcedID = node3.id) + node3 = mockNet.createNode(node1.network.myAddress, node3.id) val restoredFlow = node3.getSingleFlow().first assertEquals(false, restoredFlow.flowStarted) // Not started yet as no network activity has been allowed yet mockNet.runNetwork() // Allow network map messages to flow @@ -176,7 +178,7 @@ class FlowFrameworkTests { node3.stop() // Now it is completed the flow should leave no Checkpoint. - node3 = mockNet.createNode(node1.network.myAddress, forcedID = node3.id) + node3 = mockNet.createNode(node1.network.myAddress, node3.id) mockNet.runNetwork() // Allow network map messages to flow node3.smm.executor.flush() assertTrue(node3.smm.findStateMachines(NoOpFlow::class.java).isEmpty()) @@ -227,7 +229,7 @@ class FlowFrameworkTests { node2b.smm.executor.flush() fut1.getOrThrow() - val receivedCount = sessionTransfers.count { it.isPayloadTransfer } + val receivedCount = receivedSessionMessages.count { it.isPayloadTransfer } // Check flows completed cleanly and didn't get out of phase assertEquals(4, receivedCount, "Flow should have exchanged 4 unique messages")// Two messages each way // can't give a precise value as every addMessageHandler re-runs the undelivered messages @@ -259,15 +261,15 @@ class FlowFrameworkTests { assertThat(node3Flow.receivedPayloads[0]).isEqualTo(payload) assertSessionTransfers(node2, - node1 sent sessionInit(SendFlow::class, 1, payload) to node2, - node2 sent sessionConfirm to node1, + node1 sent sessionInit(SendFlow::class, payload = payload) to node2, + node2 sent sessionConfirm() to node1, node1 sent normalEnd to node2 //There's no session end from the other flows as they're manually suspended ) assertSessionTransfers(node3, - node1 sent sessionInit(SendFlow::class, 1, payload) to node3, - node3 sent sessionConfirm to node1, + node1 sent sessionInit(SendFlow::class, payload = payload) to node3, + node3 sent sessionConfirm() to node1, node1 sent normalEnd to node3 //There's no session end from the other flows as they're manually suspended ) @@ -293,14 +295,14 @@ class FlowFrameworkTests { assertSessionTransfers(node2, node1 sent sessionInit(ReceiveFlow::class) to node2, - node2 sent sessionConfirm to node1, + node2 sent sessionConfirm() to node1, node2 sent sessionData(node2Payload) to node1, node2 sent normalEnd to node1 ) assertSessionTransfers(node3, node1 sent sessionInit(ReceiveFlow::class) to node3, - node3 sent sessionConfirm to node1, + node3 sent sessionConfirm() to node1, node3 sent sessionData(node3Payload) to node1, node3 sent normalEnd to node1 ) @@ -313,12 +315,13 @@ class FlowFrameworkTests { mockNet.runNetwork() assertSessionTransfers( - node1 sent sessionInit(PingPongFlow::class, 1, 10L) to node2, - node2 sent sessionConfirm to node1, + node1 sent sessionInit(PingPongFlow::class, payload = 10L) to node2, + node2 sent sessionConfirm() to node1, node2 sent sessionData(20L) to node1, node1 sent sessionData(11L) to node2, node2 sent sessionData(21L) to node1, - node1 sent normalEnd to node2 + node1 sent normalEnd to node2, + node2 sent normalEnd to node1 ) } @@ -333,8 +336,9 @@ class FlowFrameworkTests { anonymous = false)) // We pay a couple of times, the notary picking should go round robin for (i in 1..3) { - node1.services.startFlow(CashPaymentFlow(500.DOLLARS, node2.info.legalIdentity, anonymous = false)) + val flow = node1.services.startFlow(CashPaymentFlow(500.DOLLARS, node2.info.legalIdentity, anonymous = false)) mockNet.runNetwork() + flow.resultFuture.getOrThrow() } val endpoint = mockNet.messagingNetwork.endpoint(notary1.network.myAddress as InMemoryMessagingNetwork.PeerHandle)!! val party1Info = notary1.services.networkMapCache.getPartyInfo(notary1.info.notaryIdentity)!! @@ -342,10 +346,10 @@ class FlowFrameworkTests { val notary1Address: MessageRecipients = endpoint.getAddressOfParty(notary1.services.networkMapCache.getPartyInfo(notary1.info.notaryIdentity)!!) assertThat(notary1Address).isInstanceOf(InMemoryMessagingNetwork.ServiceHandle::class.java) assertEquals(notary1Address, endpoint.getAddressOfParty(notary2.services.networkMapCache.getPartyInfo(notary2.info.notaryIdentity)!!)) - sessionTransfers.expectEvents(isStrict = false) { + receivedSessionMessages.expectEvents(isStrict = false) { sequence( // First Pay - expect(match = { it.message is SessionInit && it.message.initiatingFlowClass == NotaryFlow.Client::class.java }) { + expect(match = { it.message is SessionInit && it.message.initiatingFlowClass == NotaryFlow.Client::class.java.name }) { it.message as SessionInit assertEquals(node1.id, it.from) assertEquals(notary1Address, it.to) @@ -355,7 +359,7 @@ class FlowFrameworkTests { assertEquals(notary1.id, it.from) }, // Second pay - expect(match = { it.message is SessionInit && it.message.initiatingFlowClass == NotaryFlow.Client::class.java }) { + expect(match = { it.message is SessionInit && it.message.initiatingFlowClass == NotaryFlow.Client::class.java.name }) { it.message as SessionInit assertEquals(node1.id, it.from) assertEquals(notary1Address, it.to) @@ -365,7 +369,7 @@ class FlowFrameworkTests { assertEquals(notary2.id, it.from) }, // Third pay - expect(match = { it.message is SessionInit && it.message.initiatingFlowClass == NotaryFlow.Client::class.java }) { + expect(match = { it.message is SessionInit && it.message.initiatingFlowClass == NotaryFlow.Client::class.java.name }) { it.message as SessionInit assertEquals(node1.id, it.from) assertEquals(notary1Address, it.to) @@ -383,11 +387,37 @@ class FlowFrameworkTests { node2.registerFlowFactory(ReceiveFlow::class) { NoOpFlow() } val resultFuture = node1.services.startFlow(ReceiveFlow(node2.info.legalIdentity)).resultFuture mockNet.runNetwork() - assertThatExceptionOfType(FlowSessionException::class.java).isThrownBy { + assertThatExceptionOfType(UnexpectedFlowEndException::class.java).isThrownBy { resultFuture.getOrThrow() }.withMessageContaining(String::class.java.name) // Make sure the exception message mentions the type the flow was expecting to receive } + @Test + fun `receiving unexpected session end before entering sendAndReceive`() { + node2.registerFlowFactory(WaitForOtherSideEndBeforeSendAndReceive::class) { NoOpFlow() } + val sessionEndReceived = Semaphore(0) + receivedSessionMessagesObservable().filter { it.message is SessionEnd }.subscribe { sessionEndReceived.release() } + val resultFuture = node1.services.startFlow( + WaitForOtherSideEndBeforeSendAndReceive(node2.info.legalIdentity, sessionEndReceived)).resultFuture + mockNet.runNetwork() + assertThatExceptionOfType(UnexpectedFlowEndException::class.java).isThrownBy { + resultFuture.getOrThrow() + } + } + + @InitiatingFlow + private class WaitForOtherSideEndBeforeSendAndReceive(val otherParty: Party, + @Transient val receivedOtherFlowEnd: Semaphore) : FlowLogic() { + @Suspendable + override fun call() { + // Kick off the flow on the other side ... + send(otherParty, 1) + // ... then pause this one until it's received the session-end message from the other side + receivedOtherFlowEnd.acquire() + sendAndReceive(otherParty, 2) + } + } + @Test fun `non-FlowException thrown on other side`() { val erroringFlowFuture = node2.registerFlowFactory(ReceiveFlow::class) { @@ -406,7 +436,7 @@ class FlowFrameworkTests { Notification.createOnError(erroringFlowFuture.get().exceptionThrown) ) - val receiveFlowException = assertFailsWith(FlowSessionException::class) { + val receiveFlowException = assertFailsWith(UnexpectedFlowEndException::class) { receiveFlowResult.getOrThrow() } assertThat(receiveFlowException.message).doesNotContain("evil bug!") @@ -417,7 +447,7 @@ class FlowFrameworkTests { assertSessionTransfers( node1 sent sessionInit(ReceiveFlow::class) to node2, - node2 sent sessionConfirm to node1, + node2 sent sessionConfirm() to node1, node2 sent erroredEnd() to node1 ) } @@ -450,11 +480,11 @@ class FlowFrameworkTests { assertSessionTransfers( node1 sent sessionInit(ReceiveFlow::class) to node2, - node2 sent sessionConfirm to node1, + node2 sent sessionConfirm() to node1, node2 sent erroredEnd(erroringFlow.get().exceptionThrown) to node1 ) // Make sure the original stack trace isn't sent down the wire - assertThat((sessionTransfers.last().message as ErrorSessionEnd).errorResponse!!.stackTrace).isEmpty() + assertThat((receivedSessionMessages.last().message as ErrorSessionEnd).errorResponse!!.stackTrace).isEmpty() } @Test @@ -492,13 +522,13 @@ class FlowFrameworkTests { node1Fiber.resultFuture.getOrThrow() } val node2ResultFuture = node2Fiber.getOrThrow().resultFuture - assertThatExceptionOfType(FlowSessionException::class.java).isThrownBy { + assertThatExceptionOfType(UnexpectedFlowEndException::class.java).isThrownBy { node2ResultFuture.getOrThrow() } assertSessionTransfers(node2, node1 sent sessionInit(ReceiveFlow::class) to node2, - node2 sent sessionConfirm to node1, + node2 sent sessionConfirm() to node1, node2 sent sessionData("Hello") to node1, node1 sent erroredEnd() to node2 ) @@ -545,7 +575,7 @@ class FlowFrameworkTests { node2.registerFlowFactory(ReceiveFlow::class) { SendFlow(NonSerialisableData(1), it) } val result = node1.services.startFlow(ReceiveFlow(node2.info.legalIdentity)).resultFuture mockNet.runNetwork() - assertThatExceptionOfType(FlowSessionException::class.java).isThrownBy { + assertThatExceptionOfType(UnexpectedFlowEndException::class.java).isThrownBy { result.getOrThrow() } } @@ -587,18 +617,11 @@ class FlowFrameworkTests { } val waiter = node2.services.startFlow(WaitingFlows.Waiter(stx, node1.info.legalIdentity)).resultFuture mockNet.runNetwork() - assertThatExceptionOfType(FlowSessionException::class.java).isThrownBy { + assertThatExceptionOfType(UnexpectedFlowEndException::class.java).isThrownBy { waiter.getOrThrow() } } - @Test - fun `lazy db iterator left on stack during checkpointing`() { - val result = node2.services.startFlow(VaultAccessFlow()).resultFuture - mockNet.runNetwork() - assertThatThrownBy { result.getOrThrow() }.hasMessageContaining("Vault").hasMessageContaining("private method") - } - @Test fun `verify vault query service is tokenizable by force checkpointing within a flow`() { val ptx = TransactionBuilder(notary = notary1.info.notaryIdentity) @@ -632,31 +655,62 @@ class FlowFrameworkTests { } @Test - fun `upgraded flow`() { - node1.services.startFlow(UpgradedFlow(node2.info.legalIdentity)) + fun `upgraded initiating flow`() { + node2.registerFlowFactory(UpgradedFlow::class, initiatedFlowVersion = 1) { SendFlow("Old initiated", it) } + val result = node1.services.startFlow(UpgradedFlow(node2.info.legalIdentity)).resultFuture mockNet.runNetwork() - assertThat(sessionTransfers).startsWith( - node1 sent sessionInit(UpgradedFlow::class, 2) to node2 + assertThat(receivedSessionMessages).startsWith( + node1 sent sessionInit(UpgradedFlow::class, flowVersion = 2) to node2, + node2 sent sessionConfirm(flowVersion = 1) to node1 ) + val (receivedPayload, node2FlowVersion) = result.getOrThrow() + assertThat(receivedPayload).isEqualTo("Old initiated") + assertThat(node2FlowVersion).isEqualTo(1) } @Test - fun `unsupported new flow version`() { - node2.internalRegisterFlowFactory( - UpgradedFlow::class.java, - InitiatedFlowFactory.CorDapp(version = 1, factory = ::DoubleInlinedSubFlow), - DoubleInlinedSubFlow::class.java, - track = false) - val result = node1.services.startFlow(UpgradedFlow(node2.info.legalIdentity)).resultFuture + fun `upgraded initiated flow`() { + node2.registerFlowFactory(SendFlow::class, initiatedFlowVersion = 2) { UpgradedFlow(it) } + val initiatingFlow = SendFlow("Old initiating", node2.info.legalIdentity) + node1.services.startFlow(initiatingFlow) mockNet.runNetwork() - assertThatExceptionOfType(FlowSessionException::class.java).isThrownBy { - result.getOrThrow() - }.withMessageContaining("Version") + assertThat(receivedSessionMessages).startsWith( + node1 sent sessionInit(SendFlow::class, flowVersion = 1, payload = "Old initiating") to node2, + node2 sent sessionConfirm(flowVersion = 2) to node1 + ) + assertThat(initiatingFlow.getFlowContext(node2.info.legalIdentity).flowVersion).isEqualTo(2) + } + + @Test + fun `unregistered flow`() { + val future = node1.services.startFlow(SendFlow("Hello", node2.info.legalIdentity)).resultFuture + mockNet.runNetwork() + assertThatExceptionOfType(UnexpectedFlowEndException::class.java) + .isThrownBy { future.getOrThrow() } + .withMessageEndingWith("${SendFlow::class.java.name} is not registered") + } + + @Test + fun `unknown class in session init`() { + node1.sendSessionMessage(SessionInit(random63BitValue(), "not.a.real.Class", 1, "version", null), node2) + mockNet.runNetwork() + assertThat(receivedSessionMessages).hasSize(2) // Only the session-init and session-reject are expected + val reject = receivedSessionMessages.last().message as SessionReject + assertThat(reject.errorMessage).isEqualTo("Don't know not.a.real.Class") + } + + @Test + fun `non-flow class in session init`() { + node1.sendSessionMessage(SessionInit(random63BitValue(), String::class.java.name, 1, "version", null), node2) + mockNet.runNetwork() + assertThat(receivedSessionMessages).hasSize(2) // Only the session-init and session-reject are expected + val reject = receivedSessionMessages.last().message as SessionReject + assertThat(reject.errorMessage).isEqualTo("${String::class.java.name} is not a flow") } @Test fun `single inlined sub-flow`() { - node2.registerFlowFactory(SendAndReceiveFlow::class, ::SingleInlinedSubFlow) + node2.registerFlowFactory(SendAndReceiveFlow::class) { SingleInlinedSubFlow(it) } val result = node1.services.startFlow(SendAndReceiveFlow(node2.info.legalIdentity, "Hello")).resultFuture mockNet.runNetwork() assertThat(result.getOrThrow()).isEqualTo("HelloHello") @@ -664,7 +718,7 @@ class FlowFrameworkTests { @Test fun `double inlined sub-flow`() { - node2.registerFlowFactory(SendAndReceiveFlow::class, ::DoubleInlinedSubFlow) + node2.registerFlowFactory(SendAndReceiveFlow::class) { DoubleInlinedSubFlow(it) } val result = node1.services.startFlow(SendAndReceiveFlow(node2.info.legalIdentity, "Hello")).resultFuture mockNet.runNetwork() assertThat(result.getOrThrow()).isEqualTo("HelloHello") @@ -684,36 +738,44 @@ class FlowFrameworkTests { return newNode.getSingleFlow

().first } - private inline fun > MockNode.getSingleFlow(): Pair> { + private inline fun > MockNode.getSingleFlow(): Pair> { return smm.findStateMachines(P::class.java).single() } private inline fun > MockNode.registerFlowFactory( - initiatingFlowClass: KClass>, - noinline flowFactory: (Party) -> P): ListenableFuture

+ initiatingFlowClass: KClass>, + initiatedFlowVersion: Int = 1, + noinline flowFactory: (Party) -> P): CordaFuture

{ - val observable = internalRegisterFlowFactory(initiatingFlowClass.java, object : InitiatedFlowFactory

{ - override fun createFlow(platformVersion: Int, otherParty: Party, sessionInit: SessionInit): P { - return flowFactory(otherParty) - } - }, P::class.java, track = true) + val observable = internalRegisterFlowFactory( + initiatingFlowClass.java, + InitiatedFlowFactory.CorDapp(initiatedFlowVersion, "", flowFactory), + P::class.java, + track = true) return observable.toFuture() } private fun sessionInit(clientFlowClass: KClass>, flowVersion: Int = 1, payload: Any? = null): SessionInit { - return SessionInit(0, clientFlowClass.java, flowVersion, payload) + return SessionInit(0, clientFlowClass.java.name, flowVersion, "", payload) } - private val sessionConfirm = SessionConfirm(0, 0) + private fun sessionConfirm(flowVersion: Int = 1) = SessionConfirm(0, 0, flowVersion, "") private fun sessionData(payload: Any) = SessionData(0, payload) private val normalEnd = NormalSessionEnd(0) private fun erroredEnd(errorResponse: FlowException? = null) = ErrorSessionEnd(0, errorResponse) + private fun MockNode.sendSessionMessage(message: SessionMessage, destination: MockNode) { + services.networkService.apply { + val address = getAddressOfParty(PartyInfo.Node(destination.info)) + send(createMessage(StateMachineManager.sessionTopic, message.serialize().bytes), address) + } + } + private fun assertSessionTransfers(vararg expected: SessionTransfer) { - assertThat(sessionTransfers).containsExactly(*expected) + assertThat(receivedSessionMessages).containsExactly(*expected) } private fun assertSessionTransfers(node: MockNode, vararg expected: SessionTransfer): List { - val actualForNode = sessionTransfers.filter { it.from == node.id || it.to == node.network.myAddress } + val actualForNode = receivedSessionMessages.filter { it.from == node.id || it.to == node.network.myAddress } assertThat(actualForNode).containsExactly(*expected) return actualForNode } @@ -723,6 +785,10 @@ class FlowFrameworkTests { override fun toString(): String = "$from sent $message to $to" } + private fun receivedSessionMessagesObservable(): Observable { + return mockNet.messagingNetwork.receivedMessages.toSessionTransfers() + } + private fun Observable.toSessionTransfers(): Observable { return filter { it.message.topicSession == StateMachineManager.sessionTopic }.map { val from = it.sender.id @@ -733,8 +799,8 @@ class FlowFrameworkTests { private fun sanitise(message: SessionMessage) = when (message) { is SessionData -> message.copy(recipientSessionId = 0) - is SessionInit -> message.copy(initiatorSessionId = 0) - is SessionConfirm -> message.copy(initiatorSessionId = 0, initiatedSessionId = 0) + is SessionInit -> message.copy(initiatorSessionId = 0, appName = "") + is SessionConfirm -> message.copy(initiatorSessionId = 0, initiatedSessionId = 0, appName = "") is NormalSessionEnd -> message.copy(recipientSessionId = 0) is ErrorSessionEnd -> message.copy(recipientSessionId = 0) else -> message @@ -743,7 +809,7 @@ class FlowFrameworkTests { private infix fun MockNode.sent(message: SessionMessage): Pair = Pair(id, message) private infix fun Pair.to(node: MockNode): SessionTransfer = SessionTransfer(first, second, node.network.myAddress) - private val FlowLogic<*>.progressSteps: ListenableFuture>> get() { + private val FlowLogic<*>.progressSteps: CordaFuture>> get() { return progressTracker!!.changes .ofType(Change.Position::class.java) .map { it.newStep } @@ -770,7 +836,6 @@ class FlowFrameworkTests { } } - @InitiatingFlow private open class SendFlow(val payload: Any, vararg val otherParties: Party) : FlowLogic() { init { @@ -878,14 +943,6 @@ class FlowFrameworkTests { } } - private class VaultAccessFlow : FlowLogic() { - @Suspendable - override fun call() { - serviceHub.vaultService.unconsumedStates().filter { true } - waitForLedgerCommit(SecureHash.zeroHash) - } - } - @InitiatingFlow private class VaultQueryFlow(val stx: SignedTransaction, val otherParty: Party) : FlowLogic>>() { @Suspendable @@ -900,9 +957,13 @@ class FlowFrameworkTests { } @InitiatingFlow(version = 2) - private class UpgradedFlow(val otherParty: Party) : FlowLogic() { + private class UpgradedFlow(val otherParty: Party) : FlowLogic>() { @Suspendable - override fun call(): Any = receive(otherParty).unwrap { it } + override fun call(): Pair { + val received = receive(otherParty).unwrap { it } + val otherFlowVersion = getFlowContext(otherParty).flowVersion + return Pair(received, otherFlowVersion) + } } private class SingleInlinedSubFlow(val otherParty: Party) : FlowLogic() { diff --git a/node/src/test/kotlin/net/corda/node/services/transactions/BFTSMaRtConfigTests.kt b/node/src/test/kotlin/net/corda/node/services/transactions/BFTSMaRtConfigTests.kt index 5fdee29f25..1837953266 100644 --- a/node/src/test/kotlin/net/corda/node/services/transactions/BFTSMaRtConfigTests.kt +++ b/node/src/test/kotlin/net/corda/node/services/transactions/BFTSMaRtConfigTests.kt @@ -28,13 +28,13 @@ class BFTSMaRtConfigTests { @Test fun `overlapping port ranges are rejected`() { - fun addresses(vararg ports: Int) = ports.map { NetworkHostAndPort("localhost", it) } - assertThatThrownBy { BFTSMaRtConfig(addresses(11000, 11001)).use {} } + fun config(vararg ports: Int) = BFTSMaRtConfig(ports.map { NetworkHostAndPort("localhost", it) }, false, false) + assertThatThrownBy { config(11000, 11001).use {} } .isInstanceOf(IllegalArgumentException::class.java) .hasMessage(portIsClaimedFormat.format("localhost:11001", setOf("localhost:11000", "localhost:11001"))) - assertThatThrownBy { BFTSMaRtConfig(addresses(11001, 11000)).use {} } + assertThatThrownBy { config(11001, 11000).use {} } .isInstanceOf(IllegalArgumentException::class.java) .hasMessage(portIsClaimedFormat.format("localhost:11001", setOf("localhost:11001", "localhost:11002", "localhost:11000"))) - BFTSMaRtConfig(addresses(11000, 11002)).use {} // Non-overlapping. + config(11000, 11002).use {} // Non-overlapping. } } diff --git a/node/src/test/kotlin/net/corda/node/services/transactions/DistributedImmutableMapTests.kt b/node/src/test/kotlin/net/corda/node/services/transactions/DistributedImmutableMapTests.kt index 7895f4defc..95082a625b 100644 --- a/node/src/test/kotlin/net/corda/node/services/transactions/DistributedImmutableMapTests.kt +++ b/node/src/test/kotlin/net/corda/node/services/transactions/DistributedImmutableMapTests.kt @@ -6,38 +6,35 @@ import io.atomix.copycat.client.CopycatClient import io.atomix.copycat.server.CopycatServer import io.atomix.copycat.server.storage.Storage import io.atomix.copycat.server.storage.StorageLevel -import net.corda.core.getOrThrow import net.corda.core.utilities.NetworkHostAndPort -import net.corda.testing.LogHelper +import net.corda.core.utilities.getOrThrow import net.corda.node.services.network.NetworkMapService +import net.corda.node.utilities.CordaPersistence import net.corda.node.utilities.configureDatabase -import net.corda.testing.freeLocalHostAndPort +import net.corda.testing.* import net.corda.testing.node.makeTestDataSourceProperties -import org.jetbrains.exposed.sql.Database +import net.corda.testing.node.makeTestDatabaseProperties +import net.corda.testing.node.makeTestIdentityService import org.jetbrains.exposed.sql.Transaction import org.junit.After import org.junit.Before import org.junit.Test -import java.io.Closeable import java.util.concurrent.CompletableFuture import kotlin.test.assertEquals import kotlin.test.assertTrue -class DistributedImmutableMapTests { +class DistributedImmutableMapTests : TestDependencyInjectionBase() { data class Member(val client: CopycatClient, val server: CopycatServer) lateinit var cluster: List - lateinit var dataSource: Closeable lateinit var transaction: Transaction - lateinit var database: Database + lateinit var database: CordaPersistence @Before fun setup() { LogHelper.setLevel("-org.apache.activemq") LogHelper.setLevel(NetworkMapService::class) - val dataSourceAndDatabase = configureDatabase(makeTestDataSourceProperties()) - dataSource = dataSourceAndDatabase.first - database = dataSourceAndDatabase.second + database = configureDatabase(makeTestDataSourceProperties(), makeTestDatabaseProperties(), identitySvc = ::makeTestIdentityService) cluster = setUpCluster() } @@ -49,7 +46,7 @@ class DistributedImmutableMapTests { it.client.close() it.server.shutdown() } - dataSource.close() + database.close() } @Test diff --git a/node/src/test/kotlin/net/corda/node/services/transactions/NotaryServiceTests.kt b/node/src/test/kotlin/net/corda/node/services/transactions/NotaryServiceTests.kt index c0da57821c..e72e273dc1 100644 --- a/node/src/test/kotlin/net/corda/node/services/transactions/NotaryServiceTests.kt +++ b/node/src/test/kotlin/net/corda/node/services/transactions/NotaryServiceTests.kt @@ -1,21 +1,21 @@ package net.corda.node.services.transactions -import com.google.common.util.concurrent.ListenableFuture +import net.corda.core.concurrent.CordaFuture import net.corda.core.contracts.StateAndRef import net.corda.core.contracts.StateRef -import net.corda.core.contracts.TransactionType -import net.corda.testing.contracts.DummyContract -import net.corda.core.crypto.DigitalSignature -import net.corda.core.getOrThrow +import net.corda.core.crypto.TransactionSignature +import net.corda.core.flows.NotaryError +import net.corda.core.flows.NotaryException +import net.corda.core.flows.NotaryFlow import net.corda.core.node.services.ServiceInfo -import net.corda.core.seconds import net.corda.core.transactions.SignedTransaction -import net.corda.flows.NotaryError -import net.corda.flows.NotaryException -import net.corda.flows.NotaryFlow +import net.corda.core.transactions.TransactionBuilder +import net.corda.core.utilities.getOrThrow +import net.corda.core.utilities.seconds import net.corda.node.internal.AbstractNode import net.corda.node.services.network.NetworkMapService import net.corda.testing.DUMMY_NOTARY +import net.corda.testing.contracts.DummyContract import net.corda.testing.node.MockNetwork import org.assertj.core.api.Assertions.assertThat import org.junit.After @@ -37,7 +37,7 @@ class NotaryServiceTests { notaryNode = mockNet.createNode( legalName = DUMMY_NOTARY.name, advertisedServices = *arrayOf(ServiceInfo(NetworkMapService.type), ServiceInfo(SimpleNotaryService.type))) - clientNode = mockNet.createNode(networkMapAddress = notaryNode.network.myAddress) + clientNode = mockNet.createNode(notaryNode.network.myAddress) mockNet.runNetwork() // Clear network map registration messages } @@ -50,7 +50,7 @@ class NotaryServiceTests { fun `should sign a unique transaction with a valid time-window`() { val stx = run { val inputState = issueState(clientNode) - val tx = TransactionType.General.Builder(notaryNode.info.notaryIdentity).withItems(inputState) + val tx = TransactionBuilder(notaryNode.info.notaryIdentity).withItems(inputState) tx.setTimeWindow(Instant.now(), 30.seconds) clientNode.services.signInitialTransaction(tx) } @@ -64,7 +64,7 @@ class NotaryServiceTests { fun `should sign a unique transaction without a time-window`() { val stx = run { val inputState = issueState(clientNode) - val tx = TransactionType.General.Builder(notaryNode.info.notaryIdentity).withItems(inputState) + val tx = TransactionBuilder(notaryNode.info.notaryIdentity).withItems(inputState) clientNode.services.signInitialTransaction(tx) } @@ -77,7 +77,7 @@ class NotaryServiceTests { fun `should report error for transaction with an invalid time-window`() { val stx = run { val inputState = issueState(clientNode) - val tx = TransactionType.General.Builder(notaryNode.info.notaryIdentity).withItems(inputState) + val tx = TransactionBuilder(notaryNode.info.notaryIdentity).withItems(inputState) tx.setTimeWindow(Instant.now().plusSeconds(3600), 30.seconds) clientNode.services.signInitialTransaction(tx) } @@ -92,7 +92,7 @@ class NotaryServiceTests { fun `should sign identical transaction multiple times (signing is idempotent)`() { val stx = run { val inputState = issueState(clientNode) - val tx = TransactionType.General.Builder(notaryNode.info.notaryIdentity).withItems(inputState) + val tx = TransactionBuilder(notaryNode.info.notaryIdentity).withItems(inputState) clientNode.services.signInitialTransaction(tx) } @@ -110,11 +110,11 @@ class NotaryServiceTests { fun `should report conflict when inputs are reused across transactions`() { val inputState = issueState(clientNode) val stx = run { - val tx = TransactionType.General.Builder(notaryNode.info.notaryIdentity).withItems(inputState) + val tx = TransactionBuilder(notaryNode.info.notaryIdentity).withItems(inputState) clientNode.services.signInitialTransaction(tx) } val stx2 = run { - val tx = TransactionType.General.Builder(notaryNode.info.notaryIdentity).withItems(inputState) + val tx = TransactionBuilder(notaryNode.info.notaryIdentity).withItems(inputState) tx.addInputState(issueState(clientNode)) clientNode.services.signInitialTransaction(tx) } @@ -132,7 +132,7 @@ class NotaryServiceTests { notaryError.conflict.verified() } - private fun runNotaryClient(stx: SignedTransaction): ListenableFuture> { + private fun runNotaryClient(stx: SignedTransaction): CordaFuture> { val flow = NotaryFlow.Client(stx) val future = clientNode.services.startFlow(flow).resultFuture mockNet.runNetwork() diff --git a/node/src/test/kotlin/net/corda/node/services/transactions/PathManagerTests.kt b/node/src/test/kotlin/net/corda/node/services/transactions/PathManagerTests.kt index baec774fa0..834126f17d 100644 --- a/node/src/test/kotlin/net/corda/node/services/transactions/PathManagerTests.kt +++ b/node/src/test/kotlin/net/corda/node/services/transactions/PathManagerTests.kt @@ -1,6 +1,6 @@ package net.corda.node.services.transactions -import net.corda.core.exists +import net.corda.core.internal.exists import org.junit.Test import java.nio.file.Files import kotlin.test.assertFailsWith diff --git a/node/src/test/kotlin/net/corda/node/services/transactions/PersistentUniquenessProviderTests.kt b/node/src/test/kotlin/net/corda/node/services/transactions/PersistentUniquenessProviderTests.kt index 1ff8c103f3..57eefb0f22 100644 --- a/node/src/test/kotlin/net/corda/node/services/transactions/PersistentUniquenessProviderTests.kt +++ b/node/src/test/kotlin/net/corda/node/services/transactions/PersistentUniquenessProviderTests.kt @@ -2,38 +2,33 @@ package net.corda.node.services.transactions import net.corda.core.crypto.SecureHash import net.corda.core.node.services.UniquenessException -import net.corda.testing.LogHelper +import net.corda.node.utilities.CordaPersistence import net.corda.node.utilities.configureDatabase -import net.corda.node.utilities.transaction -import net.corda.testing.MEGA_CORP -import net.corda.testing.generateStateRef +import net.corda.testing.* import net.corda.testing.node.makeTestDataSourceProperties -import org.jetbrains.exposed.sql.Database +import net.corda.testing.node.makeTestDatabaseProperties +import net.corda.testing.node.makeTestIdentityService import org.junit.After import org.junit.Before import org.junit.Test -import java.io.Closeable import kotlin.test.assertEquals import kotlin.test.assertFailsWith -class PersistentUniquenessProviderTests { +class PersistentUniquenessProviderTests : TestDependencyInjectionBase() { val identity = MEGA_CORP val txID = SecureHash.randomSHA256() - lateinit var dataSource: Closeable - lateinit var database: Database + lateinit var database: CordaPersistence @Before fun setUp() { LogHelper.setLevel(PersistentUniquenessProvider::class) - val dataSourceAndDatabase = configureDatabase(makeTestDataSourceProperties()) - dataSource = dataSourceAndDatabase.first - database = dataSourceAndDatabase.second + database = configureDatabase(makeTestDataSourceProperties(), makeTestDatabaseProperties(), identitySvc = ::makeTestIdentityService) } @After fun tearDown() { - dataSource.close() + database.close() LogHelper.reset(PersistentUniquenessProvider::class) } diff --git a/node/src/test/kotlin/net/corda/node/services/transactions/ValidatingNotaryServiceTests.kt b/node/src/test/kotlin/net/corda/node/services/transactions/ValidatingNotaryServiceTests.kt index 3072677dfd..335ec1048b 100644 --- a/node/src/test/kotlin/net/corda/node/services/transactions/ValidatingNotaryServiceTests.kt +++ b/node/src/test/kotlin/net/corda/node/services/transactions/ValidatingNotaryServiceTests.kt @@ -1,23 +1,23 @@ package net.corda.node.services.transactions -import com.google.common.util.concurrent.ListenableFuture +import net.corda.core.concurrent.CordaFuture import net.corda.core.contracts.Command import net.corda.core.contracts.StateAndRef import net.corda.core.contracts.StateRef -import net.corda.core.contracts.TransactionType -import net.corda.testing.contracts.DummyContract -import net.corda.core.crypto.DigitalSignature -import net.corda.core.getOrThrow +import net.corda.core.crypto.TransactionSignature +import net.corda.core.flows.NotaryError +import net.corda.core.flows.NotaryException +import net.corda.core.flows.NotaryFlow import net.corda.core.node.services.ServiceInfo import net.corda.core.transactions.SignedTransaction -import net.corda.flows.NotaryError -import net.corda.flows.NotaryException -import net.corda.flows.NotaryFlow +import net.corda.core.utilities.getOrThrow +import net.corda.core.transactions.TransactionBuilder import net.corda.node.internal.AbstractNode import net.corda.node.services.issueInvalidState import net.corda.node.services.network.NetworkMapService import net.corda.testing.DUMMY_NOTARY import net.corda.testing.MEGA_CORP_KEY +import net.corda.testing.contracts.DummyContract import net.corda.testing.node.MockNetwork import org.assertj.core.api.Assertions.assertThat import org.junit.After @@ -39,7 +39,7 @@ class ValidatingNotaryServiceTests { legalName = DUMMY_NOTARY.name, advertisedServices = *arrayOf(ServiceInfo(NetworkMapService.type), ServiceInfo(ValidatingNotaryService.type)) ) - clientNode = mockNet.createNode(networkMapAddress = notaryNode.network.myAddress) + clientNode = mockNet.createNode(notaryNode.network.myAddress) mockNet.runNetwork() // Clear network map registration messages } @@ -52,14 +52,15 @@ class ValidatingNotaryServiceTests { fun `should report error for invalid transaction dependency`() { val stx = run { val inputState = issueInvalidState(clientNode, notaryNode.info.notaryIdentity) - val tx = TransactionType.General.Builder(notaryNode.info.notaryIdentity).withItems(inputState) + val tx = TransactionBuilder(notaryNode.info.notaryIdentity).withItems(inputState) clientNode.services.signInitialTransaction(tx) } val future = runClient(stx) val ex = assertFailsWith(NotaryException::class) { future.getOrThrow() } - assertThat(ex.error).isInstanceOf(NotaryError.SignaturesInvalid::class.java) + val notaryError = ex.error as NotaryError.TransactionInvalid + assertThat(notaryError.cause).isInstanceOf(SignedTransaction.SignaturesMissingException::class.java) } @Test @@ -69,7 +70,7 @@ class ValidatingNotaryServiceTests { val inputState = issueState(clientNode) val command = Command(DummyContract.Commands.Move(), expectedMissingKey) - val tx = TransactionType.General.Builder(notaryNode.info.notaryIdentity).withItems(inputState, command) + val tx = TransactionBuilder(notaryNode.info.notaryIdentity).withItems(inputState, command) clientNode.services.signInitialTransaction(tx) } @@ -77,14 +78,14 @@ class ValidatingNotaryServiceTests { val future = runClient(stx) future.getOrThrow() } - val notaryError = ex.error - assertThat(notaryError).isInstanceOf(NotaryError.SignaturesMissing::class.java) + val notaryError = ex.error as NotaryError.TransactionInvalid + assertThat(notaryError.cause).isInstanceOf(SignedTransaction.SignaturesMissingException::class.java) - val missingKeys = (notaryError as NotaryError.SignaturesMissing).cause.missing + val missingKeys = (notaryError.cause as SignedTransaction.SignaturesMissingException).missing assertEquals(setOf(expectedMissingKey), missingKeys) } - private fun runClient(stx: SignedTransaction): ListenableFuture> { + private fun runClient(stx: SignedTransaction): CordaFuture> { val flow = NotaryFlow.Client(stx) val future = clientNode.services.startFlow(flow).resultFuture mockNet.runNetwork() diff --git a/node/src/test/kotlin/net/corda/node/services/vault/NodeVaultServiceTest.kt b/node/src/test/kotlin/net/corda/node/services/vault/NodeVaultServiceTest.kt index 2e100edf28..418c62f631 100644 --- a/node/src/test/kotlin/net/corda/node/services/vault/NodeVaultServiceTest.kt +++ b/node/src/test/kotlin/net/corda/node/services/vault/NodeVaultServiceTest.kt @@ -1,82 +1,96 @@ package net.corda.node.services.vault +import co.paralleluniverse.fibers.Suspendable import net.corda.contracts.asset.Cash import net.corda.contracts.asset.DUMMY_CASH_ISSUER +import net.corda.contracts.asset.DUMMY_CASH_ISSUER_KEY +import net.corda.contracts.asset.sumCash +import net.corda.contracts.getCashBalance import net.corda.core.contracts.* import net.corda.core.crypto.generateKeyPair +import net.corda.core.identity.AbstractParty import net.corda.core.identity.AnonymousParty -import net.corda.core.node.services.StatesNotAvailableException -import net.corda.core.node.services.Vault -import net.corda.core.node.services.VaultService -import net.corda.core.node.services.unconsumedStates +import net.corda.core.identity.Party +import net.corda.core.node.services.* +import net.corda.core.node.services.vault.QueryCriteria +import net.corda.core.node.services.vault.QueryCriteria.* +import net.corda.core.transactions.NotaryChangeWireTransaction import net.corda.core.transactions.SignedTransaction import net.corda.core.transactions.TransactionBuilder +import net.corda.core.utilities.NonEmptySet import net.corda.core.utilities.OpaqueBytes -import net.corda.node.utilities.configureDatabase -import net.corda.node.utilities.transaction +import net.corda.core.utilities.toNonEmptySet +import net.corda.node.utilities.CordaPersistence import net.corda.testing.* import net.corda.testing.contracts.fillWithSomeTestCash import net.corda.testing.node.MockServices -import net.corda.testing.node.makeTestDataSourceProperties +import net.corda.testing.node.makeTestDatabaseAndMockServices import org.assertj.core.api.Assertions.assertThat import org.assertj.core.api.Assertions.assertThatExceptionOfType -import org.jetbrains.exposed.sql.Database import org.junit.After import org.junit.Before import org.junit.Test -import java.io.Closeable +import rx.observers.TestSubscriber +import java.math.BigDecimal import java.util.* import java.util.concurrent.CountDownLatch import java.util.concurrent.Executors import kotlin.test.assertEquals import kotlin.test.assertFalse -import kotlin.test.assertNull import kotlin.test.assertTrue -class NodeVaultServiceTest { +class NodeVaultServiceTest : TestDependencyInjectionBase() { lateinit var services: MockServices + lateinit var issuerServices: MockServices val vaultSvc: VaultService get() = services.vaultService - lateinit var dataSource: Closeable - lateinit var database: Database + val vaultQuery: VaultQueryService get() = services.vaultQueryService + lateinit var database: CordaPersistence @Before fun setUp() { LogHelper.setLevel(NodeVaultService::class) - val dataSourceProps = makeTestDataSourceProperties() - val dataSourceAndDatabase = configureDatabase(dataSourceProps) - dataSource = dataSourceAndDatabase.first - database = dataSourceAndDatabase.second - database.transaction { - services = object : MockServices() { - override val vaultService: VaultService = makeVaultService(dataSourceProps) - - override fun recordTransactions(txs: Iterable) { - for (stx in txs) { - validatedTransactions.addTransaction(stx) - } - // Refactored to use notifyAll() as we have no other unit test for that method with multiple transactions. - vaultService.notifyAll(txs.map { it.tx }) - } - } - } + val databaseAndServices = makeTestDatabaseAndMockServices(keys = listOf(BOC_KEY, DUMMY_CASH_ISSUER_KEY)) + database = databaseAndServices.first + services = databaseAndServices.second + issuerServices = MockServices(DUMMY_CASH_ISSUER_KEY, BOC_KEY) } @After fun tearDown() { - dataSource.close() + database.close() LogHelper.reset(NodeVaultService::class) } + @Suspendable + private fun VaultService.unconsumedCashStatesForSpending(amount: Amount, + onlyFromIssuerParties: Set? = null, + notary: Party? = null, + lockId: UUID = UUID.randomUUID(), + withIssuerRefs: Set? = null): List> { + + val notaryName = if (notary != null) listOf(notary.name) else null + var baseCriteria: QueryCriteria = QueryCriteria.VaultQueryCriteria(notaryName = notaryName) + if (onlyFromIssuerParties != null || withIssuerRefs != null) { + baseCriteria = baseCriteria.and(QueryCriteria.FungibleAssetQueryCriteria( + issuerPartyName = onlyFromIssuerParties?.toList(), + issuerRef = withIssuerRefs?.toList())) + } + + return tryLockFungibleStatesForSpending(lockId, baseCriteria, amount, Cash.State::class.java) + } + + @Test fun `states not local to instance`() { database.transaction { - services.fillWithSomeTestCash(100.DOLLARS, DUMMY_NOTARY, 3, 3, Random(0L)) + services.fillWithSomeTestCash(100.DOLLARS, issuerServices, DUMMY_NOTARY, 3, 3, Random(0L)) - val w1 = vaultSvc.unconsumedStates() + val w1 = vaultQuery.queryBy().states assertThat(w1).hasSize(3) val originalVault = vaultSvc + val originalVaultQuery = vaultQuery val services2 = object : MockServices() { override val vaultService: VaultService get() = originalVault override fun recordTransactions(txs: Iterable) { @@ -85,9 +99,10 @@ class NodeVaultServiceTest { vaultService.notify(stx.tx) } } + override val vaultQueryService : VaultQueryService get() = originalVaultQuery } - val w2 = services2.vaultService.unconsumedStates() + val w2 = services2.vaultQueryService.queryBy().states assertThat(w2).hasSize(3) } } @@ -96,13 +111,12 @@ class NodeVaultServiceTest { fun `states for refs`() { database.transaction { - services.fillWithSomeTestCash(100.DOLLARS, DUMMY_NOTARY, 3, 3, Random(0L)) + services.fillWithSomeTestCash(100.DOLLARS, issuerServices, DUMMY_NOTARY, 3, 3, Random(0L)) - val w1 = vaultSvc.unconsumedStates().toList() + val w1 = vaultQuery.queryBy().states assertThat(w1).hasSize(3) - val stateRefs = listOf(w1[1].ref, w1[2].ref) - val states = vaultSvc.statesForRefs(stateRefs) + val states = vaultQuery.queryBy(VaultQueryCriteria(stateRefs = listOf(w1[1].ref, w1[2].ref))).states assertThat(states).hasSize(2) } } @@ -111,34 +125,36 @@ class NodeVaultServiceTest { fun `states soft locking reserve and release`() { database.transaction { - services.fillWithSomeTestCash(100.DOLLARS, DUMMY_NOTARY, 3, 3, Random(0L)) + services.fillWithSomeTestCash(100.DOLLARS, issuerServices, DUMMY_NOTARY, 3, 3, Random(0L)) - val unconsumedStates = vaultSvc.unconsumedStates().toList() + val unconsumedStates = vaultQuery.queryBy().states assertThat(unconsumedStates).hasSize(3) - val stateRefsToSoftLock = setOf(unconsumedStates[1].ref, unconsumedStates[2].ref) + val stateRefsToSoftLock = NonEmptySet.of(unconsumedStates[1].ref, unconsumedStates[2].ref) // soft lock two of the three states val softLockId = UUID.randomUUID() vaultSvc.softLockReserve(softLockId, stateRefsToSoftLock) // all softlocked states - assertThat(vaultSvc.softLockedStates()).hasSize(2) + val criteriaLocked = VaultQueryCriteria(softLockingCondition = SoftLockingCondition(SoftLockingType.LOCKED_ONLY)) + assertThat(vaultQuery.queryBy(criteriaLocked).states).hasSize(2) // my softlocked states - assertThat(vaultSvc.softLockedStates(softLockId)).hasSize(2) + val criteriaByLockId = VaultQueryCriteria(softLockingCondition = SoftLockingCondition(SoftLockingType.SPECIFIED, listOf(softLockId))) + assertThat(vaultQuery.queryBy(criteriaByLockId).states).hasSize(2) // excluding softlocked states - val unlockedStates1 = vaultSvc.unconsumedStates(includeSoftLockedStates = false).toList() + val unlockedStates1 = vaultQuery.queryBy(VaultQueryCriteria(softLockingCondition = SoftLockingCondition(SoftLockingType.UNLOCKED_ONLY))).states assertThat(unlockedStates1).hasSize(1) // soft lock release one of the states explicitly - vaultSvc.softLockRelease(softLockId, setOf(unconsumedStates[1].ref)) - val unlockedStates2 = vaultSvc.unconsumedStates(includeSoftLockedStates = false).toList() + vaultSvc.softLockRelease(softLockId, NonEmptySet.of(unconsumedStates[1].ref)) + val unlockedStates2 = vaultQuery.queryBy(VaultQueryCriteria(softLockingCondition = SoftLockingCondition(SoftLockingType.UNLOCKED_ONLY))).states assertThat(unlockedStates2).hasSize(2) // soft lock release the rest by id vaultSvc.softLockRelease(softLockId) - val unlockedStates = vaultSvc.unconsumedStates(includeSoftLockedStates = false).toList() + val unlockedStates = vaultQuery.queryBy(VaultQueryCriteria(softLockingCondition = SoftLockingCondition(SoftLockingType.UNLOCKED_ONLY))).states assertThat(unlockedStates).hasSize(3) // should be back to original states @@ -148,19 +164,21 @@ class NodeVaultServiceTest { @Test fun `soft locking attempt concurrent reserve`() { - val backgroundExecutor = Executors.newFixedThreadPool(2) val countDown = CountDownLatch(2) val softLockId1 = UUID.randomUUID() val softLockId2 = UUID.randomUUID() + val criteriaByLockId1 = VaultQueryCriteria(softLockingCondition = SoftLockingCondition(SoftLockingType.SPECIFIED, listOf(softLockId1))) + val criteriaByLockId2 = VaultQueryCriteria(softLockingCondition = SoftLockingCondition(SoftLockingType.SPECIFIED, listOf(softLockId2))) + val vaultStates = database.transaction { - assertNull(vaultSvc.cashBalances[USD]) - services.fillWithSomeTestCash(100.DOLLARS, DUMMY_NOTARY, 3, 3, Random(0L)) + assertEquals(0.DOLLARS, services.getCashBalance(USD)) + services.fillWithSomeTestCash(100.DOLLARS, issuerServices, DUMMY_NOTARY, 3, 3, Random(0L)) } - val stateRefsToSoftLock = vaultStates.states.map { it.ref }.toSet() + val stateRefsToSoftLock = (vaultStates.states.map { it.ref }).toNonEmptySet() println("State Refs:: $stateRefsToSoftLock") // 1st tx locks states @@ -168,7 +186,7 @@ class NodeVaultServiceTest { try { database.transaction { vaultSvc.softLockReserve(softLockId1, stateRefsToSoftLock) - assertThat(vaultSvc.softLockedStates(softLockId1)).hasSize(3) + assertThat(vaultQuery.queryBy(criteriaByLockId1).states).hasSize(3) } println("SOFT LOCK STATES #1 succeeded") } catch(e: Throwable) { @@ -184,7 +202,7 @@ class NodeVaultServiceTest { Thread.sleep(100) // let 1st thread soft lock them 1st database.transaction { vaultSvc.softLockReserve(softLockId2, stateRefsToSoftLock) - assertThat(vaultSvc.softLockedStates(softLockId2)).hasSize(3) + assertThat(vaultQuery.queryBy(criteriaByLockId2).states).hasSize(3) } println("SOFT LOCK STATES #2 succeeded") } catch(e: Throwable) { @@ -196,10 +214,10 @@ class NodeVaultServiceTest { countDown.await() database.transaction { - val lockStatesId1 = vaultSvc.softLockedStates(softLockId1) + val lockStatesId1 = vaultQuery.queryBy(criteriaByLockId1).states println("SOFT LOCK #1 final states: $lockStatesId1") assertThat(lockStatesId1.size).isIn(0, 3) - val lockStatesId2 = vaultSvc.softLockedStates(softLockId2) + val lockStatesId2 = vaultQuery.queryBy(criteriaByLockId2).states println("SOFT LOCK #2 final states: $lockStatesId2") assertThat(lockStatesId2.size).isIn(0, 3) } @@ -207,55 +225,55 @@ class NodeVaultServiceTest { @Test fun `soft locking partial reserve states fails`() { - val softLockId1 = UUID.randomUUID() val softLockId2 = UUID.randomUUID() val vaultStates = database.transaction { - assertNull(vaultSvc.cashBalances[USD]) - services.fillWithSomeTestCash(100.DOLLARS, DUMMY_NOTARY, 3, 3, Random(0L)) + assertEquals(0.DOLLARS, services.getCashBalance(USD)) + services.fillWithSomeTestCash(100.DOLLARS, issuerServices, DUMMY_NOTARY, 3, 3, Random(0L)) } - val stateRefsToSoftLock = vaultStates.states.map { it.ref }.toSet() + val stateRefsToSoftLock = vaultStates.states.map { it.ref } println("State Refs:: $stateRefsToSoftLock") // lock 1st state with LockId1 database.transaction { - vaultSvc.softLockReserve(softLockId1, setOf(stateRefsToSoftLock.first())) - assertThat(vaultSvc.softLockedStates(softLockId1)).hasSize(1) + vaultSvc.softLockReserve(softLockId1, NonEmptySet.of(stateRefsToSoftLock.first())) + val criteriaByLockId1 = VaultQueryCriteria(softLockingCondition = SoftLockingCondition(SoftLockingType.SPECIFIED, listOf(softLockId1))) + assertThat(vaultQuery.queryBy(criteriaByLockId1).states).hasSize(1) } // attempt to lock all 3 states with LockId2 database.transaction { assertThatExceptionOfType(StatesNotAvailableException::class.java).isThrownBy( - { vaultSvc.softLockReserve(softLockId2, stateRefsToSoftLock) } + { vaultSvc.softLockReserve(softLockId2, stateRefsToSoftLock.toNonEmptySet()) } ).withMessageContaining("only 2 rows available").withNoCause() } } @Test fun `attempt to lock states already soft locked by me`() { - val softLockId1 = UUID.randomUUID() + val criteriaByLockId1 = VaultQueryCriteria(softLockingCondition = SoftLockingCondition(SoftLockingType.SPECIFIED, listOf(softLockId1))) val vaultStates = database.transaction { - assertNull(vaultSvc.cashBalances[USD]) - services.fillWithSomeTestCash(100.DOLLARS, DUMMY_NOTARY, 3, 3, Random(0L)) + assertEquals(0.DOLLARS, services.getCashBalance(USD)) + services.fillWithSomeTestCash(100.DOLLARS, issuerServices, DUMMY_NOTARY, 3, 3, Random(0L)) } - val stateRefsToSoftLock = vaultStates.states.map { it.ref }.toSet() + val stateRefsToSoftLock = (vaultStates.states.map { it.ref }).toNonEmptySet() println("State Refs:: $stateRefsToSoftLock") // lock states with LockId1 database.transaction { vaultSvc.softLockReserve(softLockId1, stateRefsToSoftLock) - assertThat(vaultSvc.softLockedStates(softLockId1)).hasSize(3) + assertThat(vaultQuery.queryBy(criteriaByLockId1).states).hasSize(3) } // attempt to relock same states with LockId1 database.transaction { vaultSvc.softLockReserve(softLockId1, stateRefsToSoftLock) - assertThat(vaultSvc.softLockedStates(softLockId1)).hasSize(3) + assertThat(vaultQuery.queryBy(criteriaByLockId1).states).hasSize(3) } } @@ -263,25 +281,26 @@ class NodeVaultServiceTest { fun `lock additional states to some already soft locked by me`() { val softLockId1 = UUID.randomUUID() + val criteriaByLockId1 = VaultQueryCriteria(softLockingCondition = SoftLockingCondition(SoftLockingType.SPECIFIED, listOf(softLockId1))) val vaultStates = database.transaction { - assertNull(vaultSvc.cashBalances[USD]) - services.fillWithSomeTestCash(100.DOLLARS, DUMMY_NOTARY, 3, 3, Random(0L)) + assertEquals(0.DOLLARS, services.getCashBalance(USD)) + services.fillWithSomeTestCash(100.DOLLARS, issuerServices, DUMMY_NOTARY, 3, 3, Random(0L)) } - val stateRefsToSoftLock = vaultStates.states.map { it.ref }.toSet() + val stateRefsToSoftLock = vaultStates.states.map { it.ref } println("State Refs:: $stateRefsToSoftLock") // lock states with LockId1 database.transaction { - vaultSvc.softLockReserve(softLockId1, setOf(stateRefsToSoftLock.first())) - assertThat(vaultSvc.softLockedStates(softLockId1)).hasSize(1) + vaultSvc.softLockReserve(softLockId1, NonEmptySet.of(stateRefsToSoftLock.first())) + assertThat(vaultQuery.queryBy(criteriaByLockId1).states).hasSize(1) } // attempt to lock all states with LockId1 (including previously already locked one) database.transaction { - vaultSvc.softLockReserve(softLockId1, stateRefsToSoftLock) - assertThat(vaultSvc.softLockedStates(softLockId1)).hasSize(3) + vaultSvc.softLockReserve(softLockId1, stateRefsToSoftLock.toNonEmptySet()) + assertThat(vaultQuery.queryBy(criteriaByLockId1).states).hasSize(3) } } @@ -289,16 +308,17 @@ class NodeVaultServiceTest { fun `unconsumedStatesForSpending exact amount`() { database.transaction { - services.fillWithSomeTestCash(100.DOLLARS, DUMMY_NOTARY, 1, 1, Random(0L)) + services.fillWithSomeTestCash(100.DOLLARS, issuerServices, DUMMY_NOTARY, 1, 1, Random(0L)) - val unconsumedStates = vaultSvc.unconsumedStates().toList() + val unconsumedStates = vaultQuery.queryBy().states assertThat(unconsumedStates).hasSize(1) - val spendableStatesUSD = (vaultSvc as NodeVaultService).unconsumedStatesForSpending(100.DOLLARS, lockId = UUID.randomUUID()) + val spendableStatesUSD = vaultSvc.unconsumedCashStatesForSpending(100.DOLLARS) spendableStatesUSD.forEach(::println) assertThat(spendableStatesUSD).hasSize(1) assertThat(spendableStatesUSD[0].state.data.amount.quantity).isEqualTo(100L * 100) - assertThat(vaultSvc.softLockedStates()).hasSize(1) + val criteriaLocked = VaultQueryCriteria(softLockingCondition = SoftLockingCondition(SoftLockingType.LOCKED_ONLY)) + assertThat(vaultQuery.queryBy(criteriaLocked).states).hasSize(1) } } @@ -306,15 +326,16 @@ class NodeVaultServiceTest { fun `unconsumedStatesForSpending from two issuer parties`() { database.transaction { - services.fillWithSomeTestCash(100.DOLLARS, DUMMY_NOTARY, 1, 1, Random(0L), issuedBy = (DUMMY_CASH_ISSUER)) - services.fillWithSomeTestCash(100.DOLLARS, DUMMY_NOTARY, 1, 1, Random(0L), issuedBy = (BOC.ref(1)), issuerKey = BOC_KEY) + services.fillWithSomeTestCash(100.DOLLARS, issuerServices, DUMMY_NOTARY, 1, 1, Random(0L), issuedBy = (DUMMY_CASH_ISSUER)) + services.fillWithSomeTestCash(100.DOLLARS, issuerServices, DUMMY_NOTARY, 1, 1, Random(0L), issuedBy = (BOC.ref(1))) - val spendableStatesUSD = vaultSvc.unconsumedStatesForSpending(200.DOLLARS, lockId = UUID.randomUUID(), - onlyFromIssuerParties = setOf(DUMMY_CASH_ISSUER.party, BOC)).toList() + val spendableStatesUSD = vaultSvc.unconsumedCashStatesForSpending(200.DOLLARS, + onlyFromIssuerParties = setOf(DUMMY_CASH_ISSUER.party, BOC)) spendableStatesUSD.forEach(::println) assertThat(spendableStatesUSD).hasSize(2) - assertThat(spendableStatesUSD[0].state.data.amount.token.issuer).isEqualTo(DUMMY_CASH_ISSUER) - assertThat(spendableStatesUSD[1].state.data.amount.token.issuer).isEqualTo(BOC.ref(1)) + assertThat(spendableStatesUSD[0].state.data.amount.token.issuer).isIn(DUMMY_CASH_ISSUER, BOC.ref(1)) + assertThat(spendableStatesUSD[1].state.data.amount.token.issuer).isIn(DUMMY_CASH_ISSUER, BOC.ref(1)) + assertThat(spendableStatesUSD[0].state.data.amount.token.issuer).isNotEqualTo(spendableStatesUSD[1].state.data.amount.token.issuer) } } @@ -322,20 +343,21 @@ class NodeVaultServiceTest { fun `unconsumedStatesForSpending from specific issuer party and refs`() { database.transaction { - services.fillWithSomeTestCash(100.DOLLARS, DUMMY_NOTARY, 1, 1, Random(0L), issuedBy = (DUMMY_CASH_ISSUER)) - services.fillWithSomeTestCash(100.DOLLARS, DUMMY_NOTARY, 1, 1, Random(0L), issuedBy = (BOC.ref(1)), issuerKey = BOC_KEY, ref = OpaqueBytes.of(1)) - services.fillWithSomeTestCash(100.DOLLARS, DUMMY_NOTARY, 1, 1, Random(0L), issuedBy = (BOC.ref(2)), issuerKey = BOC_KEY, ref = OpaqueBytes.of(2)) - services.fillWithSomeTestCash(100.DOLLARS, DUMMY_NOTARY, 1, 1, Random(0L), issuedBy = (BOC.ref(3)), issuerKey = BOC_KEY, ref = OpaqueBytes.of(3)) + services.fillWithSomeTestCash(100.DOLLARS, issuerServices, DUMMY_NOTARY, 1, 1, Random(0L), issuedBy = (DUMMY_CASH_ISSUER)) + services.fillWithSomeTestCash(100.DOLLARS, issuerServices, DUMMY_NOTARY, 1, 1, Random(0L), issuedBy = (BOC.ref(1)), ref = OpaqueBytes.of(1)) + services.fillWithSomeTestCash(100.DOLLARS, issuerServices, DUMMY_NOTARY, 1, 1, Random(0L), issuedBy = (BOC.ref(2)), ref = OpaqueBytes.of(2)) + services.fillWithSomeTestCash(100.DOLLARS, issuerServices, DUMMY_NOTARY, 1, 1, Random(0L), issuedBy = (BOC.ref(3)), ref = OpaqueBytes.of(3)) - val unconsumedStates = vaultSvc.unconsumedStates().toList() + val unconsumedStates = vaultQuery.queryBy().states assertThat(unconsumedStates).hasSize(4) - val spendableStatesUSD = vaultSvc.unconsumedStatesForSpending(200.DOLLARS, lockId = UUID.randomUUID(), - onlyFromIssuerParties = setOf(BOC), withIssuerRefs = setOf(OpaqueBytes.of(1), OpaqueBytes.of(2))).toList() + val spendableStatesUSD = vaultSvc.unconsumedCashStatesForSpending(200.DOLLARS, + onlyFromIssuerParties = setOf(BOC), withIssuerRefs = setOf(OpaqueBytes.of(1), OpaqueBytes.of(2))) assertThat(spendableStatesUSD).hasSize(2) assertThat(spendableStatesUSD[0].state.data.amount.token.issuer.party).isEqualTo(BOC) - assertThat(spendableStatesUSD[0].state.data.amount.token.issuer.reference).isEqualTo(BOC.ref(1).reference) - assertThat(spendableStatesUSD[1].state.data.amount.token.issuer.reference).isEqualTo(BOC.ref(2).reference) + assertThat(spendableStatesUSD[0].state.data.amount.token.issuer.reference).isIn(BOC.ref(1).reference, BOC.ref(2).reference) + assertThat(spendableStatesUSD[1].state.data.amount.token.issuer.reference).isIn(BOC.ref(1).reference, BOC.ref(2).reference) + assertThat(spendableStatesUSD[0].state.data.amount.token.issuer.reference).isNotEqualTo(spendableStatesUSD[1].state.data.amount.token.issuer.reference) } } @@ -343,15 +365,16 @@ class NodeVaultServiceTest { fun `unconsumedStatesForSpending insufficient amount`() { database.transaction { - services.fillWithSomeTestCash(100.DOLLARS, DUMMY_NOTARY, 1, 1, Random(0L)) + services.fillWithSomeTestCash(100.DOLLARS, issuerServices, DUMMY_NOTARY, 1, 1, Random(0L)) - val unconsumedStates = vaultSvc.unconsumedStates().toList() + val unconsumedStates = vaultQuery.queryBy().states assertThat(unconsumedStates).hasSize(1) - val spendableStatesUSD = (vaultSvc as NodeVaultService).unconsumedStatesForSpending(110.DOLLARS, lockId = UUID.randomUUID()) + val spendableStatesUSD = vaultSvc.unconsumedCashStatesForSpending(110.DOLLARS) spendableStatesUSD.forEach(::println) - assertThat(spendableStatesUSD).hasSize(1) - assertThat(vaultSvc.softLockedStates()).hasSize(0) + assertThat(spendableStatesUSD).hasSize(0) + val criteriaLocked = VaultQueryCriteria(softLockingCondition = SoftLockingCondition(SoftLockingType.LOCKED_ONLY)) + assertThat(vaultQuery.queryBy(criteriaLocked).states).hasSize(0) } } @@ -359,16 +382,17 @@ class NodeVaultServiceTest { fun `unconsumedStatesForSpending small amount`() { database.transaction { - services.fillWithSomeTestCash(100.DOLLARS, DUMMY_NOTARY, 2, 2, Random(0L)) + services.fillWithSomeTestCash(100.DOLLARS, issuerServices, DUMMY_NOTARY, 2, 2, Random(0L)) - val unconsumedStates = vaultSvc.unconsumedStates().toList() + val unconsumedStates = vaultQuery.queryBy().states assertThat(unconsumedStates).hasSize(2) - val spendableStatesUSD = (vaultSvc as NodeVaultService).unconsumedStatesForSpending(1.DOLLARS, lockId = UUID.randomUUID()) + val spendableStatesUSD = vaultSvc.unconsumedCashStatesForSpending(1.DOLLARS) spendableStatesUSD.forEach(::println) assertThat(spendableStatesUSD).hasSize(1) assertThat(spendableStatesUSD[0].state.data.amount.quantity).isGreaterThanOrEqualTo(100L) - assertThat(vaultSvc.softLockedStates()).hasSize(1) + val criteriaLocked = VaultQueryCriteria(softLockingCondition = SoftLockingCondition(SoftLockingType.LOCKED_ONLY)) + assertThat(vaultQuery.queryBy(criteriaLocked).states).hasSize(1) } } @@ -376,19 +400,34 @@ class NodeVaultServiceTest { fun `states soft locking query granularity`() { database.transaction { - services.fillWithSomeTestCash(100.DOLLARS, DUMMY_NOTARY, 10, 10, Random(0L)) - services.fillWithSomeTestCash(100.POUNDS, DUMMY_NOTARY, 10, 10, Random(0L)) - services.fillWithSomeTestCash(100.SWISS_FRANCS, DUMMY_NOTARY, 10, 10, Random(0L)) + services.fillWithSomeTestCash(100.DOLLARS, issuerServices, DUMMY_NOTARY, 10, 10, Random(0L)) + services.fillWithSomeTestCash(100.POUNDS, issuerServices, DUMMY_NOTARY, 10, 10, Random(0L)) + services.fillWithSomeTestCash(100.SWISS_FRANCS, issuerServices, DUMMY_NOTARY, 10, 10, Random(0L)) - val allStates = vaultSvc.unconsumedStates() - assertThat(allStates).hasSize(30) + var unlockedStates = 30 + val allStates = vaultQuery.queryBy().states + assertThat(allStates).hasSize(unlockedStates) + var lockedCount = 0 for (i in 1..5) { - val spendableStatesUSD = (vaultSvc as NodeVaultService).unconsumedStatesForSpending(20.DOLLARS, lockId = UUID.randomUUID()) + val lockId = UUID.randomUUID() + val spendableStatesUSD = vaultSvc.unconsumedCashStatesForSpending(20.DOLLARS, lockId = lockId) spendableStatesUSD.forEach(::println) + assertThat(spendableStatesUSD.size <= unlockedStates) + unlockedStates -= spendableStatesUSD.size + val criteriaLocked = VaultQueryCriteria(softLockingCondition = SoftLockingCondition(SoftLockingType.SPECIFIED, listOf(lockId))) + val lockedStates = vaultQuery.queryBy(criteriaLocked).states + if (spendableStatesUSD.isNotEmpty()) { + assertEquals(spendableStatesUSD.size, lockedStates.size) + val lockedTotal = lockedStates.map { it.state.data }.sumCash() + val foundAmount = spendableStatesUSD.map { it.state.data }.sumCash() + assertThat(foundAmount.toDecimal() >= BigDecimal("20.00")) + assertThat(lockedTotal == foundAmount) + lockedCount += lockedStates.size + } } - // note only 3 spend attempts succeed with a total of 8 states - assertThat(vaultSvc.softLockedStates()).hasSize(8) + val criteriaLocked = VaultQueryCriteria(softLockingCondition = SoftLockingCondition(SoftLockingType.LOCKED_ONLY)) + assertThat(vaultQuery.queryBy(criteriaLocked).states).hasSize(lockedCount) } } @@ -400,7 +439,7 @@ class NodeVaultServiceTest { val freshKey = services.legalIdentityKey // Issue a txn to Send us some Money - val usefulBuilder = TransactionType.General.Builder(null).apply { + val usefulBuilder = TransactionBuilder(null).apply { Cash().generateIssue(this, 100.DOLLARS `issued by` MEGA_CORP.ref(1), AnonymousParty(freshKey), DUMMY_NOTARY) } val usefulTX = megaCorpServices.signInitialTransaction(usefulBuilder) @@ -413,7 +452,7 @@ class NodeVaultServiceTest { assertEquals(3, vaultSvc.getTransactionNotes(usefulTX.id).count()) // Issue more Money (GBP) - val anotherBuilder = TransactionType.General.Builder(null).apply { + val anotherBuilder = TransactionBuilder(null).apply { Cash().generateIssue(this, 200.POUNDS `issued by` MEGA_CORP.ref(1), AnonymousParty(freshKey), DUMMY_NOTARY) } val anotherTX = megaCorpServices.signInitialTransaction(anotherBuilder) @@ -433,7 +472,7 @@ class NodeVaultServiceTest { assertTrue { service.isRelevant(wellKnownCash, services.keyManagementService.keys) } val anonymousIdentity = services.keyManagementService.freshKeyAndCert(services.myInfo.legalIdentityAndCert, false) - val anonymousCash = Cash.State(amount, anonymousIdentity.identity) + val anonymousCash = Cash.State(amount, anonymousIdentity.party) assertTrue { service.isRelevant(anonymousCash, services.keyManagementService.keys) } val thirdPartyIdentity = AnonymousParty(generateKeyPair().public) @@ -444,34 +483,86 @@ class NodeVaultServiceTest { // TODO: Unit test linear state relevancy checks @Test - fun `make update`() { + fun `correct updates are generated for general transactions`() { val service = (services.vaultService as NodeVaultService) + val vaultSubscriber = TestSubscriber>().apply { + service.updates.subscribe(this) + } + val anonymousIdentity = services.keyManagementService.freshKeyAndCert(services.myInfo.legalIdentityAndCert, false) val thirdPartyIdentity = AnonymousParty(generateKeyPair().public) val amount = Amount(1000, Issued(BOC.ref(1), GBP)) // Issue then move some cash - val issueTx = TransactionBuilder(TransactionType.General, services.myInfo.legalIdentity).apply { + val issueTx = TransactionBuilder(services.myInfo.legalIdentity).apply { Cash().generateIssue(this, - amount, anonymousIdentity.identity, services.myInfo.legalIdentity) + amount, anonymousIdentity.party, services.myInfo.legalIdentity) }.toWireTransaction() val cashState = StateAndRef(issueTx.outputs.single(), StateRef(issueTx.id, 0)) - database.transaction { - val expected = Vault.Update(emptySet(), setOf(cashState), null) - val actual = service.makeUpdate(issueTx, setOf(anonymousIdentity.identity.owningKey)) - assertEquals(expected, actual) - services.vaultService.notify(issueTx) - } + database.transaction { service.notify(issueTx) } + val expectedIssueUpdate = Vault.Update(emptySet(), setOf(cashState), null) database.transaction { - val moveTx = TransactionBuilder(TransactionType.General, services.myInfo.legalIdentity).apply { - services.vaultService.generateSpend(this, Amount(1000, GBP), thirdPartyIdentity) + val moveTx = TransactionBuilder(services.myInfo.legalIdentity).apply { + Cash.generateSpend(services, this, Amount(1000, GBP), thirdPartyIdentity) }.toWireTransaction() - - val expected = Vault.Update(setOf(cashState), emptySet(), null) - val actual = service.makeUpdate(moveTx, setOf(anonymousIdentity.identity.owningKey)) - assertEquals(expected, actual) + service.notify(moveTx) } + val expectedMoveUpdate = Vault.Update(setOf(cashState), emptySet(), null) + + val observedUpdates = vaultSubscriber.onNextEvents + assertEquals(observedUpdates, listOf(expectedIssueUpdate, expectedMoveUpdate)) + } + + @Test + fun `correct updates are generated when changing notaries`() { + val service = (services.vaultService as NodeVaultService) + val notary = services.myInfo.legalIdentity + + val vaultSubscriber = TestSubscriber>().apply { + service.updates.subscribe(this) + } + + val anonymousIdentity = services.keyManagementService.freshKeyAndCert(services.myInfo.legalIdentityAndCert, false) + val thirdPartyIdentity = AnonymousParty(generateKeyPair().public) + val amount = Amount(1000, Issued(BOC.ref(1), GBP)) + + // Issue some cash + val issueTxBuilder = TransactionBuilder(notary).apply { + Cash().generateIssue(this, amount, anonymousIdentity.party, notary) + } + val issueStx = services.signInitialTransaction(issueTxBuilder) + // We need to record the issue transaction so inputs can be resolved for the notary change transaction + services.validatedTransactions.addTransaction(issueStx) + + val initialCashState = StateAndRef(issueStx.tx.outputs.single(), StateRef(issueStx.id, 0)) + + // Change notary + val newNotary = DUMMY_NOTARY + val changeNotaryTx = NotaryChangeWireTransaction(listOf(initialCashState.ref), issueStx.notary!!, newNotary) + val cashStateWithNewNotary = StateAndRef(initialCashState.state.copy(notary = newNotary), StateRef(changeNotaryTx.id, 0)) + + database.transaction { + service.notifyAll(listOf(issueStx.tx, changeNotaryTx)) + } + + // Move cash + val moveTx = database.transaction { + TransactionBuilder(newNotary).apply { + Cash.generateSpend(services, this, Amount(1000, GBP), thirdPartyIdentity) + }.toWireTransaction() + } + + database.transaction { + service.notify(moveTx) + } + + val expectedIssueUpdate = Vault.Update(emptySet(), setOf(initialCashState), null) + val expectedNotaryChangeUpdate = Vault.Update(setOf(initialCashState), setOf(cashStateWithNewNotary), null, Vault.UpdateType.NOTARY_CHANGE) + val expectedMoveUpdate = Vault.Update(setOf(cashStateWithNewNotary), emptySet(), null) + + val observedUpdates = vaultSubscriber.onNextEvents + assertEquals(observedUpdates, listOf(expectedIssueUpdate, expectedNotaryChangeUpdate, expectedMoveUpdate)) } } diff --git a/node/src/test/kotlin/net/corda/node/services/vault/VaultQueryTests.kt b/node/src/test/kotlin/net/corda/node/services/vault/VaultQueryTests.kt index 587e95081a..47ece34c00 100644 --- a/node/src/test/kotlin/net/corda/node/services/vault/VaultQueryTests.kt +++ b/node/src/test/kotlin/net/corda/node/services/vault/VaultQueryTests.kt @@ -5,22 +5,18 @@ import net.corda.contracts.Commodity import net.corda.contracts.DealState import net.corda.contracts.asset.Cash import net.corda.contracts.asset.DUMMY_CASH_ISSUER +import net.corda.contracts.asset.DUMMY_CASH_ISSUER_KEY import net.corda.core.contracts.* +import net.corda.core.crypto.SecureHash import net.corda.core.crypto.entropyToKeyPair import net.corda.core.crypto.toBase58String -import net.corda.core.days import net.corda.core.identity.Party import net.corda.core.node.services.* import net.corda.core.node.services.vault.* import net.corda.core.node.services.vault.QueryCriteria.* -import net.corda.core.seconds -import net.corda.core.transactions.SignedTransaction -import net.corda.core.utilities.OpaqueBytes -import net.corda.core.utilities.toHexString -import net.corda.node.services.database.HibernateConfiguration -import net.corda.node.services.schema.NodeSchemaService +import net.corda.core.utilities.* +import net.corda.node.utilities.CordaPersistence import net.corda.node.utilities.configureDatabase -import net.corda.node.utilities.transaction import net.corda.schemas.CashSchemaV1 import net.corda.schemas.CashSchemaV1.PersistentCashState import net.corda.schemas.CommercialPaperSchemaV1 @@ -28,16 +24,16 @@ import net.corda.schemas.SampleCashSchemaV3 import net.corda.testing.* import net.corda.testing.contracts.* import net.corda.testing.node.MockServices -import net.corda.testing.node.makeTestDataSourceProperties +import net.corda.testing.node.makeTestDatabaseAndMockServices +import net.corda.testing.node.makeTestDatabaseProperties +import net.corda.testing.node.makeTestIdentityService import net.corda.testing.schemas.DummyLinearStateSchemaV1 import org.assertj.core.api.Assertions import org.assertj.core.api.Assertions.assertThat import org.assertj.core.api.Assertions.assertThatThrownBy import org.bouncycastle.asn1.x500.X500Name -import org.jetbrains.exposed.sql.Database import org.junit.* import org.junit.rules.ExpectedException -import java.io.Closeable import java.lang.Thread.sleep import java.math.BigInteger import java.security.KeyPair @@ -47,41 +43,25 @@ import java.time.ZoneOffset import java.time.temporal.ChronoUnit import java.util.* -class VaultQueryTests { +class VaultQueryTests : TestDependencyInjectionBase() { lateinit var services: MockServices + lateinit var notaryServices: MockServices val vaultSvc: VaultService get() = services.vaultService val vaultQuerySvc: VaultQueryService get() = services.vaultQueryService - lateinit var dataSource: Closeable - lateinit var database: Database + lateinit var database: CordaPersistence @Before fun setUp() { - val dataSourceProps = makeTestDataSourceProperties() - val dataSourceAndDatabase = configureDatabase(dataSourceProps) - dataSource = dataSourceAndDatabase.first - database = dataSourceAndDatabase.second - database.transaction { - val customSchemas = setOf(CommercialPaperSchemaV1, DummyLinearStateSchemaV1) - val hibernateConfig = HibernateConfiguration(NodeSchemaService(customSchemas)) - services = object : MockServices(MEGA_CORP_KEY) { - override val vaultService: VaultService = makeVaultService(dataSourceProps, hibernateConfig) - - override fun recordTransactions(txs: Iterable) { - for (stx in txs) { - validatedTransactions.addTransaction(stx) - } - // Refactored to use notifyAll() as we have no other unit test for that method with multiple transactions. - vaultService.notifyAll(txs.map { it.tx }) - } - override val vaultQueryService : VaultQueryService = HibernateVaultQueryImpl(hibernateConfig, vaultService.updatesPublisher) - } - } + val databaseAndServices = makeTestDatabaseAndMockServices(keys = listOf(MEGA_CORP_KEY, DUMMY_NOTARY_KEY)) + database = databaseAndServices.first + services = databaseAndServices.second + notaryServices = MockServices(DUMMY_NOTARY_KEY, DUMMY_CASH_ISSUER_KEY, BOC_KEY, MEGA_CORP_KEY) } @After fun tearDown() { - dataSource.close() + database.close() } /** @@ -90,21 +70,19 @@ class VaultQueryTests { @Ignore @Test fun createPersistentTestDb() { - val dataSourceAndDatabase = configureDatabase(makePersistentDataSourceProperties()) - val dataSource = dataSourceAndDatabase.first - val database = dataSourceAndDatabase.second + val database = configureDatabase(makePersistentDataSourceProperties(), makeTestDatabaseProperties(), identitySvc = ::makeTestIdentityService) setUpDb(database, 5000) - dataSource.close() + database.close() } - private fun setUpDb(_database: Database, delay: Long = 0) { + private fun setUpDb(_database: CordaPersistence, delay: Long = 0) { _database.transaction { // create new states - services.fillWithSomeTestCash(100.DOLLARS, CASH_NOTARY, 10, 10, Random(0L)) + services.fillWithSomeTestCash(100.DOLLARS, notaryServices, DUMMY_NOTARY, 10, 10, Random(0L)) val linearStatesXYZ = services.fillWithSomeTestLinearStates(1, "XYZ") val linearStatesJKL = services.fillWithSomeTestLinearStates(2, "JKL") services.fillWithSomeTestLinearStates(3, "ABC") @@ -115,10 +93,10 @@ class VaultQueryTests { sleep(delay) // consume some states - services.consumeLinearStates(linearStatesXYZ.states.toList()) - services.consumeLinearStates(linearStatesJKL.states.toList()) - services.consumeDeals(dealStates.states.filter { it.state.data.ref == "456" }) - services.consumeCash(50.DOLLARS) + services.consumeLinearStates(linearStatesXYZ.states.toList(), DUMMY_NOTARY) + services.consumeLinearStates(linearStatesJKL.states.toList(), DUMMY_NOTARY) + services.consumeDeals(dealStates.states.filter { it.state.data.linearId.externalId == "456" }, DUMMY_NOTARY) + services.consumeCash(50.DOLLARS, notary = DUMMY_NOTARY) // Total unconsumed states = 4 + 3 + 2 + 1 (new cash change) = 10 // Total consumed states = 6 + 1 + 2 + 1 = 10 @@ -134,6 +112,9 @@ class VaultQueryTests { return props } + @get:Rule + val expectedEx = ExpectedException.none()!! + /** * Query API tests */ @@ -145,7 +126,7 @@ class VaultQueryTests { fun `unconsumed states simple`() { database.transaction { - services.fillWithSomeTestCash(100.DOLLARS, DUMMY_NOTARY, 3, 3, Random(0L)) + services.fillWithSomeTestCash(100.DOLLARS, notaryServices, DUMMY_NOTARY, 3, 3, Random(0L)) services.fillWithSomeTestLinearStates(10) services.fillWithSomeTestDeals(listOf("123", "456", "789")) @@ -172,7 +153,7 @@ class VaultQueryTests { fun `unconsumed states verbose`() { database.transaction { - services.fillWithSomeTestCash(100.DOLLARS, DUMMY_NOTARY, 3, 3, Random(0L)) + services.fillWithSomeTestCash(100.DOLLARS, notaryServices, DUMMY_NOTARY, 3, 3, Random(0L)) services.fillWithSomeTestLinearStates(10) services.fillWithSomeTestDeals(listOf("123", "456", "789")) @@ -188,10 +169,10 @@ class VaultQueryTests { fun `unconsumed states with count`() { database.transaction { - services.fillWithSomeTestCash(25.DOLLARS, DUMMY_NOTARY, 1, 1, Random(0L)) - services.fillWithSomeTestCash(25.DOLLARS, DUMMY_NOTARY, 1, 1, Random(0L)) - services.fillWithSomeTestCash(25.DOLLARS, DUMMY_NOTARY, 1, 1, Random(0L)) - services.fillWithSomeTestCash(25.DOLLARS, DUMMY_NOTARY, 1, 1, Random(0L)) + services.fillWithSomeTestCash(25.DOLLARS, notaryServices, DUMMY_NOTARY, 1, 1, Random(0L)) + services.fillWithSomeTestCash(25.DOLLARS, notaryServices, DUMMY_NOTARY, 1, 1, Random(0L)) + services.fillWithSomeTestCash(25.DOLLARS, notaryServices, DUMMY_NOTARY, 1, 1, Random(0L)) + services.fillWithSomeTestCash(25.DOLLARS, notaryServices, DUMMY_NOTARY, 1, 1, Random(0L)) val criteria = VaultQueryCriteria(status = Vault.StateStatus.ALL) val paging = PageSpecification(DEFAULT_PAGE_NUM, 10) @@ -199,7 +180,7 @@ class VaultQueryTests { assertThat(resultsBeforeConsume.states).hasSize(4) assertThat(resultsBeforeConsume.totalStatesAvailable).isEqualTo(4) - services.consumeCash(75.DOLLARS) + services.consumeCash(75.DOLLARS, notary = DUMMY_NOTARY) val consumedCriteria = VaultQueryCriteria(status = Vault.StateStatus.UNCONSUMED) val resultsAfterConsume = vaultQuerySvc.queryBy(consumedCriteria, paging) @@ -212,7 +193,7 @@ class VaultQueryTests { fun `unconsumed cash states simple`() { database.transaction { - services.fillWithSomeTestCash(100.DOLLARS, DUMMY_NOTARY, 3, 3, Random(0L)) + services.fillWithSomeTestCash(100.DOLLARS, notaryServices, DUMMY_NOTARY, 3, 3, Random(0L)) services.fillWithSomeTestLinearStates(10) services.fillWithSomeTestDeals(listOf("123", "456", "789")) @@ -227,7 +208,7 @@ class VaultQueryTests { fun `unconsumed cash states verbose`() { database.transaction { - services.fillWithSomeTestCash(100.DOLLARS, DUMMY_NOTARY, 3, 3, Random(0L)) + services.fillWithSomeTestCash(100.DOLLARS, notaryServices, DUMMY_NOTARY, 3, 3, Random(0L)) services.fillWithSomeTestLinearStates(10) services.fillWithSomeTestDeals(listOf("123", "456", "789")) @@ -243,15 +224,16 @@ class VaultQueryTests { fun `unconsumed cash states sorted by state ref`() { database.transaction { - var stateRefs : MutableList = mutableListOf() + val stateRefs: MutableList = mutableListOf() - val issuedStates = services.fillWithSomeTestCash(100.DOLLARS, DUMMY_NOTARY, 10, 10, Random(0L)) + val issuedStates = services.fillWithSomeTestCash(100.DOLLARS, notaryServices, DUMMY_NOTARY, 10, 10, Random(0L)) val issuedStateRefs = issuedStates.states.map { it.ref }.toList() stateRefs.addAll(issuedStateRefs) - val spentStates = services.consumeCash(25.DOLLARS) - var spentStateRefs = spentStates.states.map { it.ref }.toList() - stateRefs.addAll(spentStateRefs) + val spentStates = services.consumeCash(25.DOLLARS, notary = DUMMY_NOTARY) + val consumedStateRefs = spentStates.consumed.map { it.ref }.toList() + val producedStateRefs = spentStates.produced.map { it.ref }.toList() + stateRefs.addAll(consumedStateRefs.plus(producedStateRefs)) val sortAttribute = SortAttribute.Standard(Sort.CommonStateAttribute.STATE_REF) val criteria = VaultQueryCriteria() @@ -273,9 +255,10 @@ class VaultQueryTests { @Test fun `unconsumed cash states sorted by state ref txnId and index`() { database.transaction { - services.fillWithSomeTestCash(100.DOLLARS, DUMMY_NOTARY, 10, 10, Random(0L)) - services.consumeCash(10.DOLLARS) - services.consumeCash(10.DOLLARS) + services.fillWithSomeTestCash(100.DOLLARS, notaryServices, DUMMY_NOTARY, 10, 10, Random(0L)) + val consumed = mutableSetOf() + services.consumeCash(10.DOLLARS, notary = DUMMY_NOTARY).consumed.forEach { consumed += it.ref.txhash } + services.consumeCash(10.DOLLARS, notary = DUMMY_NOTARY).consumed.forEach { consumed += it.ref.txhash } val sortAttributeTxnId = SortAttribute.Standard(Sort.CommonStateAttribute.STATE_REF_TXN_ID) val sortAttributeIndex = SortAttribute.Standard(Sort.CommonStateAttribute.STATE_REF_INDEX) @@ -286,13 +269,11 @@ class VaultQueryTests { results.statesMetadata.forEach { println(" ${it.ref}") + assertThat(it.status).isEqualTo(Vault.StateStatus.UNCONSUMED) } - - // explicit sort order asc by txnId and then index: - // order by - // vaultschem1_.transaction_id asc, - // vaultschem1_.output_index asc - assertThat(results.states).hasSize(9) // -2 CONSUMED + 1 NEW UNCONSUMED (change) + val sorted = results.states.sortedBy { it.ref.toString() } + assertThat(results.states).isEqualTo(sorted) + assertThat(results.states).allSatisfy { !consumed.contains(it.ref.txhash) } } } @@ -320,7 +301,7 @@ class VaultQueryTests { @Test fun `unconsumed states for contract state types`() { database.transaction { - services.fillWithSomeTestCash(100.DOLLARS, DUMMY_NOTARY, 3, 3, Random(0L)) + services.fillWithSomeTestCash(100.DOLLARS, notaryServices, DUMMY_NOTARY, 3, 3, Random(0L)) services.fillWithSomeTestLinearStates(10) services.fillWithSomeTestDeals(listOf("123", "456", "789")) @@ -337,14 +318,14 @@ class VaultQueryTests { @Test fun `consumed states`() { database.transaction { - services.fillWithSomeTestCash(100.DOLLARS, DUMMY_NOTARY, 3, 3, Random(0L)) + services.fillWithSomeTestCash(100.DOLLARS, notaryServices, DUMMY_NOTARY, 3, 3, Random(0L)) val linearStates = services.fillWithSomeTestLinearStates(2, "TEST") // create 2 states with same externalId services.fillWithSomeTestLinearStates(8) val dealStates = services.fillWithSomeTestDeals(listOf("123", "456", "789")) - services.consumeLinearStates(linearStates.states.toList()) - services.consumeDeals(dealStates.states.filter { it.state.data.ref == "456" }) - services.consumeCash(50.DOLLARS) + services.consumeLinearStates(linearStates.states.toList(), DUMMY_NOTARY) + services.consumeDeals(dealStates.states.filter { it.state.data.linearId.externalId == "456" }, DUMMY_NOTARY) + services.consumeCash(50.DOLLARS, notary = DUMMY_NOTARY) val criteria = VaultQueryCriteria(status = Vault.StateStatus.CONSUMED) val results = vaultQuerySvc.queryBy(criteria) @@ -356,10 +337,10 @@ class VaultQueryTests { fun `consumed states with count`() { database.transaction { - services.fillWithSomeTestCash(25.DOLLARS, DUMMY_NOTARY, 1, 1, Random(0L)) - services.fillWithSomeTestCash(25.DOLLARS, DUMMY_NOTARY, 1, 1, Random(0L)) - services.fillWithSomeTestCash(25.DOLLARS, DUMMY_NOTARY, 1, 1, Random(0L)) - services.fillWithSomeTestCash(25.DOLLARS, DUMMY_NOTARY, 1, 1, Random(0L)) + services.fillWithSomeTestCash(25.DOLLARS, notaryServices, DUMMY_NOTARY, 1, 1, Random(0L)) + services.fillWithSomeTestCash(25.DOLLARS, notaryServices, DUMMY_NOTARY, 1, 1, Random(0L)) + services.fillWithSomeTestCash(25.DOLLARS, notaryServices, DUMMY_NOTARY, 1, 1, Random(0L)) + services.fillWithSomeTestCash(25.DOLLARS, notaryServices, DUMMY_NOTARY, 1, 1, Random(0L)) val criteria = VaultQueryCriteria(status = Vault.StateStatus.ALL) val paging = PageSpecification(DEFAULT_PAGE_NUM, 10) @@ -367,7 +348,7 @@ class VaultQueryTests { assertThat(resultsBeforeConsume.states).hasSize(4) assertThat(resultsBeforeConsume.totalStatesAvailable).isEqualTo(4) - services.consumeCash(75.DOLLARS) + services.consumeCash(75.DOLLARS, notary = DUMMY_NOTARY) val consumedCriteria = VaultQueryCriteria(status = Vault.StateStatus.CONSUMED) val resultsAfterConsume = vaultQuerySvc.queryBy(consumedCriteria, paging) @@ -379,14 +360,14 @@ class VaultQueryTests { @Test fun `all states`() { database.transaction { - services.fillWithSomeTestCash(100.DOLLARS, DUMMY_NOTARY, 3, 3, Random(0L)) + services.fillWithSomeTestCash(100.DOLLARS, notaryServices, DUMMY_NOTARY, 3, 3, Random(0L)) val linearStates = services.fillWithSomeTestLinearStates(2, "TEST") // create 2 results with same UID services.fillWithSomeTestLinearStates(8) val dealStates = services.fillWithSomeTestDeals(listOf("123", "456", "789")) - services.consumeLinearStates(linearStates.states.toList()) - services.consumeDeals(dealStates.states.filter { it.state.data.ref == "456" }) - services.consumeCash(50.DOLLARS) // generates a new change state! + services.consumeLinearStates(linearStates.states.toList(), DUMMY_NOTARY) + services.consumeDeals(dealStates.states.filter { it.state.data.linearId.externalId == "456" }, DUMMY_NOTARY) + services.consumeCash(50.DOLLARS, notary = DUMMY_NOTARY) // generates a new change state! val criteria = VaultQueryCriteria(status = Vault.StateStatus.ALL) val results = vaultQuerySvc.queryBy(criteria) @@ -398,7 +379,7 @@ class VaultQueryTests { fun `all states with count`() { database.transaction { - services.fillWithSomeTestCash(100.DOLLARS, DUMMY_NOTARY, 1, 1, Random(0L)) + services.fillWithSomeTestCash(100.DOLLARS, notaryServices, DUMMY_NOTARY, 1, 1, Random(0L)) val criteria = VaultQueryCriteria(status = Vault.StateStatus.ALL) val paging = PageSpecification(DEFAULT_PAGE_NUM, 10) @@ -406,7 +387,7 @@ class VaultQueryTests { assertThat(resultsBeforeConsume.states).hasSize(1) assertThat(resultsBeforeConsume.totalStatesAvailable).isEqualTo(1) - services.consumeCash(50.DOLLARS) // consumed 100 (spent), produced 50 (change) + services.consumeCash(50.DOLLARS, notary = DUMMY_NOTARY) // consumed 100 (spent), produced 50 (change) val resultsAfterConsume = vaultQuerySvc.queryBy(criteria, paging) assertThat(resultsAfterConsume.states).hasSize(2) @@ -414,14 +395,14 @@ class VaultQueryTests { } } - val CASH_NOTARY_KEY: KeyPair by lazy { entropyToKeyPair(BigInteger.valueOf(20)) } + val CASH_NOTARY_KEY: KeyPair by lazy { entropyToKeyPair(BigInteger.valueOf(21)) } val CASH_NOTARY: Party get() = Party(X500Name("CN=Cash Notary Service,O=R3,OU=corda,L=Zurich,C=CH"), CASH_NOTARY_KEY.public) @Test fun `unconsumed states by notary`() { database.transaction { - services.fillWithSomeTestCash(100.DOLLARS, CASH_NOTARY, 3, 3, Random(0L)) + services.fillWithSomeTestCash(100.DOLLARS, notaryServices, CASH_NOTARY, 3, 3, Random(0L)) services.fillWithSomeTestLinearStates(10) services.fillWithSomeTestDeals(listOf("123", "456", "789")) @@ -466,15 +447,46 @@ class VaultQueryTests { } @Test - fun `unconsumed states excluding soft locks`() { + fun `unconsumed states with soft locking`() { database.transaction { - val issuedStates = services.fillWithSomeTestCash(100.DOLLARS, CASH_NOTARY, 3, 3, Random(0L)) - vaultSvc.softLockReserve(UUID.randomUUID(), setOf(issuedStates.states.first().ref, issuedStates.states.last().ref)) + val issuedStates = services.fillWithSomeTestCash(100.DOLLARS, notaryServices, CASH_NOTARY, 10, 10, Random(0L)).states.toList() + vaultSvc.softLockReserve(UUID.randomUUID(), NonEmptySet.of(issuedStates[1].ref, issuedStates[2].ref, issuedStates[3].ref)) + val lockId1 = UUID.randomUUID() + vaultSvc.softLockReserve(lockId1, NonEmptySet.of(issuedStates[4].ref, issuedStates[5].ref)) + val lockId2 = UUID.randomUUID() + vaultSvc.softLockReserve(lockId2, NonEmptySet.of(issuedStates[6].ref)) - val criteria = VaultQueryCriteria(includeSoftlockedStates = false) - val results = vaultQuerySvc.queryBy(criteria) - assertThat(results.states).hasSize(1) + // excluding soft locked states + val criteriaExclusive = VaultQueryCriteria(softLockingCondition = SoftLockingCondition(SoftLockingType.UNLOCKED_ONLY)) + val resultsExclusive = vaultQuerySvc.queryBy(criteriaExclusive) + assertThat(resultsExclusive.states).hasSize(4) + + // only soft locked states + val criteriaLockedOnly = VaultQueryCriteria(softLockingCondition = SoftLockingCondition(SoftLockingType.LOCKED_ONLY)) + val resultsLockedOnly = vaultQuerySvc.queryBy(criteriaLockedOnly) + assertThat(resultsLockedOnly.states).hasSize(6) + + // soft locked states by single lock id + val criteriaByLockId = VaultQueryCriteria(softLockingCondition = SoftLockingCondition(SoftLockingType.SPECIFIED, listOf(lockId1))) + val resultsByLockId = vaultQuerySvc.queryBy(criteriaByLockId) + assertThat(resultsByLockId.states).hasSize(2) + + // soft locked states by multiple lock ids + val criteriaByLockIds = VaultQueryCriteria(softLockingCondition = SoftLockingCondition(SoftLockingType.SPECIFIED, listOf(lockId1, lockId2))) + val resultsByLockIds = vaultQuerySvc.queryBy(criteriaByLockIds) + assertThat(resultsByLockIds.states).hasSize(3) + + // unlocked and locked by `lockId2` + val criteriaUnlockedAndByLockId = VaultQueryCriteria(softLockingCondition = SoftLockingCondition(SoftLockingType.UNLOCKED_AND_SPECIFIED, listOf(lockId2))) + val resultsUnlockedAndByLockIds = vaultQuerySvc.queryBy(criteriaUnlockedAndByLockId) + assertThat(resultsUnlockedAndByLockIds.states).hasSize(5) + + // missing lockId + expectedEx.expect(IllegalArgumentException::class.java) + expectedEx.expectMessage("Must specify one or more lockIds") + val criteriaMissingLockId = VaultQueryCriteria(softLockingCondition = SoftLockingCondition(SoftLockingType.UNLOCKED_AND_SPECIFIED)) + vaultQuerySvc.queryBy(criteriaMissingLockId) } } @@ -482,9 +494,9 @@ class VaultQueryTests { fun `logical operator EQUAL`() { database.transaction { - services.fillWithSomeTestCash(100.DOLLARS, DUMMY_NOTARY, 1, 1, Random(0L)) - services.fillWithSomeTestCash(100.POUNDS, DUMMY_NOTARY, 1, 1, Random(0L)) - services.fillWithSomeTestCash(100.SWISS_FRANCS, DUMMY_NOTARY, 1, 1, Random(0L)) + services.fillWithSomeTestCash(100.DOLLARS, notaryServices, DUMMY_NOTARY, 1, 1, Random(0L)) + services.fillWithSomeTestCash(100.POUNDS, notaryServices, DUMMY_NOTARY, 1, 1, Random(0L)) + services.fillWithSomeTestCash(100.SWISS_FRANCS, notaryServices, DUMMY_NOTARY, 1, 1, Random(0L)) val logicalExpression = builder { CashSchemaV1.PersistentCashState::currency.equal(GBP.currencyCode) } val criteria = VaultCustomQueryCriteria(logicalExpression) @@ -497,9 +509,9 @@ class VaultQueryTests { fun `logical operator NOT EQUAL`() { database.transaction { - services.fillWithSomeTestCash(100.DOLLARS, DUMMY_NOTARY, 1, 1, Random(0L)) - services.fillWithSomeTestCash(100.POUNDS, DUMMY_NOTARY, 1, 1, Random(0L)) - services.fillWithSomeTestCash(100.SWISS_FRANCS, DUMMY_NOTARY, 1, 1, Random(0L)) + services.fillWithSomeTestCash(100.DOLLARS, notaryServices, DUMMY_NOTARY, 1, 1, Random(0L)) + services.fillWithSomeTestCash(100.POUNDS, notaryServices, DUMMY_NOTARY, 1, 1, Random(0L)) + services.fillWithSomeTestCash(100.SWISS_FRANCS, notaryServices, DUMMY_NOTARY, 1, 1, Random(0L)) val logicalExpression = builder { CashSchemaV1.PersistentCashState::currency.notEqual(GBP.currencyCode) } val criteria = VaultCustomQueryCriteria(logicalExpression) @@ -512,9 +524,9 @@ class VaultQueryTests { fun `logical operator GREATER_THAN`() { database.transaction { - services.fillWithSomeTestCash(1.DOLLARS, DUMMY_NOTARY, 1, 1, Random(0L)) - services.fillWithSomeTestCash(10.POUNDS, DUMMY_NOTARY, 1, 1, Random(0L)) - services.fillWithSomeTestCash(100.SWISS_FRANCS, DUMMY_NOTARY, 1, 1, Random(0L)) + services.fillWithSomeTestCash(1.DOLLARS, notaryServices, DUMMY_NOTARY, 1, 1, Random(0L)) + services.fillWithSomeTestCash(10.POUNDS, notaryServices, DUMMY_NOTARY, 1, 1, Random(0L)) + services.fillWithSomeTestCash(100.SWISS_FRANCS, notaryServices, DUMMY_NOTARY, 1, 1, Random(0L)) val logicalExpression = builder { CashSchemaV1.PersistentCashState::pennies.greaterThan(1000L) } val criteria = VaultCustomQueryCriteria(logicalExpression) @@ -527,9 +539,9 @@ class VaultQueryTests { fun `logical operator GREATER_THAN_OR_EQUAL`() { database.transaction { - services.fillWithSomeTestCash(1.DOLLARS, DUMMY_NOTARY, 1, 1, Random(0L)) - services.fillWithSomeTestCash(10.POUNDS, DUMMY_NOTARY, 1, 1, Random(0L)) - services.fillWithSomeTestCash(100.SWISS_FRANCS, DUMMY_NOTARY, 1, 1, Random(0L)) + services.fillWithSomeTestCash(1.DOLLARS, notaryServices, DUMMY_NOTARY, 1, 1, Random(0L)) + services.fillWithSomeTestCash(10.POUNDS, notaryServices, DUMMY_NOTARY, 1, 1, Random(0L)) + services.fillWithSomeTestCash(100.SWISS_FRANCS, notaryServices, DUMMY_NOTARY, 1, 1, Random(0L)) val logicalExpression = builder { CashSchemaV1.PersistentCashState::pennies.greaterThanOrEqual(1000L) } val criteria = VaultCustomQueryCriteria(logicalExpression) @@ -542,9 +554,9 @@ class VaultQueryTests { fun `logical operator LESS_THAN`() { database.transaction { - services.fillWithSomeTestCash(1.DOLLARS, DUMMY_NOTARY, 1, 1, Random(0L)) - services.fillWithSomeTestCash(10.POUNDS, DUMMY_NOTARY, 1, 1, Random(0L)) - services.fillWithSomeTestCash(100.SWISS_FRANCS, DUMMY_NOTARY, 1, 1, Random(0L)) + services.fillWithSomeTestCash(1.DOLLARS, notaryServices, DUMMY_NOTARY, 1, 1, Random(0L)) + services.fillWithSomeTestCash(10.POUNDS, notaryServices, DUMMY_NOTARY, 1, 1, Random(0L)) + services.fillWithSomeTestCash(100.SWISS_FRANCS, notaryServices, DUMMY_NOTARY, 1, 1, Random(0L)) val logicalExpression = builder { CashSchemaV1.PersistentCashState::pennies.lessThan(1000L) } val criteria = VaultCustomQueryCriteria(logicalExpression) @@ -557,9 +569,9 @@ class VaultQueryTests { fun `logical operator LESS_THAN_OR_EQUAL`() { database.transaction { - services.fillWithSomeTestCash(1.DOLLARS, DUMMY_NOTARY, 1, 1, Random(0L)) - services.fillWithSomeTestCash(10.POUNDS, DUMMY_NOTARY, 1, 1, Random(0L)) - services.fillWithSomeTestCash(100.SWISS_FRANCS, DUMMY_NOTARY, 1, 1, Random(0L)) + services.fillWithSomeTestCash(1.DOLLARS, notaryServices, DUMMY_NOTARY, 1, 1, Random(0L)) + services.fillWithSomeTestCash(10.POUNDS, notaryServices, DUMMY_NOTARY, 1, 1, Random(0L)) + services.fillWithSomeTestCash(100.SWISS_FRANCS, notaryServices, DUMMY_NOTARY, 1, 1, Random(0L)) val logicalExpression = builder { CashSchemaV1.PersistentCashState::pennies.lessThanOrEqual(1000L) } val criteria = VaultCustomQueryCriteria(logicalExpression) @@ -572,9 +584,9 @@ class VaultQueryTests { fun `logical operator BETWEEN`() { database.transaction { - services.fillWithSomeTestCash(1.DOLLARS, DUMMY_NOTARY, 1, 1, Random(0L)) - services.fillWithSomeTestCash(10.POUNDS, DUMMY_NOTARY, 1, 1, Random(0L)) - services.fillWithSomeTestCash(100.SWISS_FRANCS, DUMMY_NOTARY, 1, 1, Random(0L)) + services.fillWithSomeTestCash(1.DOLLARS, notaryServices, DUMMY_NOTARY, 1, 1, Random(0L)) + services.fillWithSomeTestCash(10.POUNDS, notaryServices, DUMMY_NOTARY, 1, 1, Random(0L)) + services.fillWithSomeTestCash(100.SWISS_FRANCS, notaryServices, DUMMY_NOTARY, 1, 1, Random(0L)) val logicalExpression = builder { CashSchemaV1.PersistentCashState::pennies.between(500L, 1500L) } val criteria = VaultCustomQueryCriteria(logicalExpression) @@ -587,9 +599,9 @@ class VaultQueryTests { fun `logical operator IN`() { database.transaction { - services.fillWithSomeTestCash(100.DOLLARS, DUMMY_NOTARY, 1, 1, Random(0L)) - services.fillWithSomeTestCash(100.POUNDS, DUMMY_NOTARY, 1, 1, Random(0L)) - services.fillWithSomeTestCash(100.SWISS_FRANCS, DUMMY_NOTARY, 1, 1, Random(0L)) + services.fillWithSomeTestCash(100.DOLLARS, notaryServices, DUMMY_NOTARY, 1, 1, Random(0L)) + services.fillWithSomeTestCash(100.POUNDS, notaryServices, DUMMY_NOTARY, 1, 1, Random(0L)) + services.fillWithSomeTestCash(100.SWISS_FRANCS, notaryServices, DUMMY_NOTARY, 1, 1, Random(0L)) val currencies = listOf(CHF.currencyCode, GBP.currencyCode) val logicalExpression = builder { CashSchemaV1.PersistentCashState::currency.`in`(currencies) } @@ -603,9 +615,9 @@ class VaultQueryTests { fun `logical operator NOT IN`() { database.transaction { - services.fillWithSomeTestCash(100.DOLLARS, DUMMY_NOTARY, 1, 1, Random(0L)) - services.fillWithSomeTestCash(100.POUNDS, DUMMY_NOTARY, 1, 1, Random(0L)) - services.fillWithSomeTestCash(100.SWISS_FRANCS, DUMMY_NOTARY, 1, 1, Random(0L)) + services.fillWithSomeTestCash(100.DOLLARS, notaryServices, DUMMY_NOTARY, 1, 1, Random(0L)) + services.fillWithSomeTestCash(100.POUNDS, notaryServices, DUMMY_NOTARY, 1, 1, Random(0L)) + services.fillWithSomeTestCash(100.SWISS_FRANCS, notaryServices, DUMMY_NOTARY, 1, 1, Random(0L)) val currencies = listOf(CHF.currencyCode, GBP.currencyCode) val logicalExpression = builder { CashSchemaV1.PersistentCashState::currency.notIn(currencies) } @@ -619,9 +631,9 @@ class VaultQueryTests { fun `logical operator LIKE`() { database.transaction { - services.fillWithSomeTestCash(100.DOLLARS, DUMMY_NOTARY, 1, 1, Random(0L)) - services.fillWithSomeTestCash(100.POUNDS, DUMMY_NOTARY, 1, 1, Random(0L)) - services.fillWithSomeTestCash(100.SWISS_FRANCS, DUMMY_NOTARY, 1, 1, Random(0L)) + services.fillWithSomeTestCash(100.DOLLARS, notaryServices, DUMMY_NOTARY, 1, 1, Random(0L)) + services.fillWithSomeTestCash(100.POUNDS, notaryServices, DUMMY_NOTARY, 1, 1, Random(0L)) + services.fillWithSomeTestCash(100.SWISS_FRANCS, notaryServices, DUMMY_NOTARY, 1, 1, Random(0L)) val logicalExpression = builder { CashSchemaV1.PersistentCashState::currency.like("%BP") } // GPB val criteria = VaultCustomQueryCriteria(logicalExpression) @@ -634,9 +646,9 @@ class VaultQueryTests { fun `logical operator NOT LIKE`() { database.transaction { - services.fillWithSomeTestCash(100.DOLLARS, DUMMY_NOTARY, 1, 1, Random(0L)) - services.fillWithSomeTestCash(100.POUNDS, DUMMY_NOTARY, 1, 1, Random(0L)) - services.fillWithSomeTestCash(100.SWISS_FRANCS, DUMMY_NOTARY, 1, 1, Random(0L)) + services.fillWithSomeTestCash(100.DOLLARS, notaryServices, DUMMY_NOTARY, 1, 1, Random(0L)) + services.fillWithSomeTestCash(100.POUNDS, notaryServices, DUMMY_NOTARY, 1, 1, Random(0L)) + services.fillWithSomeTestCash(100.SWISS_FRANCS, notaryServices, DUMMY_NOTARY, 1, 1, Random(0L)) val logicalExpression = builder { CashSchemaV1.PersistentCashState::currency.notLike("%BP") } // GPB val criteria = VaultCustomQueryCriteria(logicalExpression) @@ -649,9 +661,9 @@ class VaultQueryTests { fun `logical operator IS_NULL`() { database.transaction { - services.fillWithSomeTestCash(100.DOLLARS, DUMMY_NOTARY, 1, 1, Random(0L)) - services.fillWithSomeTestCash(100.POUNDS, DUMMY_NOTARY, 1, 1, Random(0L)) - services.fillWithSomeTestCash(100.SWISS_FRANCS, DUMMY_NOTARY, 1, 1, Random(0L)) + services.fillWithSomeTestCash(100.DOLLARS, notaryServices, DUMMY_NOTARY, 1, 1, Random(0L)) + services.fillWithSomeTestCash(100.POUNDS, notaryServices, DUMMY_NOTARY, 1, 1, Random(0L)) + services.fillWithSomeTestCash(100.SWISS_FRANCS, notaryServices, DUMMY_NOTARY, 1, 1, Random(0L)) val logicalExpression = builder { CashSchemaV1.PersistentCashState::issuerParty.isNull() } val criteria = VaultCustomQueryCriteria(logicalExpression) @@ -664,9 +676,9 @@ class VaultQueryTests { fun `logical operator NOT_NULL`() { database.transaction { - services.fillWithSomeTestCash(100.DOLLARS, DUMMY_NOTARY, 1, 1, Random(0L)) - services.fillWithSomeTestCash(100.POUNDS, DUMMY_NOTARY, 1, 1, Random(0L)) - services.fillWithSomeTestCash(100.SWISS_FRANCS, DUMMY_NOTARY, 1, 1, Random(0L)) + services.fillWithSomeTestCash(100.DOLLARS, notaryServices, DUMMY_NOTARY, 1, 1, Random(0L)) + services.fillWithSomeTestCash(100.POUNDS, notaryServices, DUMMY_NOTARY, 1, 1, Random(0L)) + services.fillWithSomeTestCash(100.SWISS_FRANCS, notaryServices, DUMMY_NOTARY, 1, 1, Random(0L)) val logicalExpression = builder { CashSchemaV1.PersistentCashState::issuerParty.notNull() } val criteria = VaultCustomQueryCriteria(logicalExpression) @@ -679,11 +691,11 @@ class VaultQueryTests { fun `aggregate functions without group clause`() { database.transaction { - services.fillWithSomeTestCash(100.DOLLARS, DUMMY_NOTARY, 1, 1, Random(0L)) - services.fillWithSomeTestCash(200.DOLLARS, DUMMY_NOTARY, 2, 2, Random(0L)) - services.fillWithSomeTestCash(300.DOLLARS, DUMMY_NOTARY, 3, 3, Random(0L)) - services.fillWithSomeTestCash(400.POUNDS, DUMMY_NOTARY, 4, 4, Random(0L)) - services.fillWithSomeTestCash(500.SWISS_FRANCS, DUMMY_NOTARY, 5, 5, Random(0L)) + services.fillWithSomeTestCash(100.DOLLARS, notaryServices, DUMMY_NOTARY, 1, 1, Random(0L)) + services.fillWithSomeTestCash(200.DOLLARS, notaryServices, DUMMY_NOTARY, 2, 2, Random(0L)) + services.fillWithSomeTestCash(300.DOLLARS, notaryServices, DUMMY_NOTARY, 3, 3, Random(0L)) + services.fillWithSomeTestCash(400.POUNDS, notaryServices, DUMMY_NOTARY, 4, 4, Random(0L)) + services.fillWithSomeTestCash(500.SWISS_FRANCS, notaryServices, DUMMY_NOTARY, 5, 5, Random(0L)) // DOCSTART VaultQueryExample21 val sum = builder { CashSchemaV1.PersistentCashState::pennies.sum() } @@ -721,11 +733,11 @@ class VaultQueryTests { fun `aggregate functions with single group clause`() { database.transaction { - services.fillWithSomeTestCash(100.DOLLARS, DUMMY_NOTARY, 1, 1, Random(0L)) - services.fillWithSomeTestCash(200.DOLLARS, DUMMY_NOTARY, 2, 2, Random(0L)) - services.fillWithSomeTestCash(300.DOLLARS, DUMMY_NOTARY, 3, 3, Random(0L)) - services.fillWithSomeTestCash(400.POUNDS, DUMMY_NOTARY, 4, 4, Random(0L)) - services.fillWithSomeTestCash(500.SWISS_FRANCS, DUMMY_NOTARY, 5, 5, Random(0L)) + services.fillWithSomeTestCash(100.DOLLARS, notaryServices, DUMMY_NOTARY, 1, 1, Random(0L)) + services.fillWithSomeTestCash(200.DOLLARS, notaryServices, DUMMY_NOTARY, 2, 2, Random(0L)) + services.fillWithSomeTestCash(300.DOLLARS, notaryServices, DUMMY_NOTARY, 3, 3, Random(0L)) + services.fillWithSomeTestCash(400.POUNDS, notaryServices, DUMMY_NOTARY, 4, 4, Random(0L)) + services.fillWithSomeTestCash(500.SWISS_FRANCS, notaryServices, DUMMY_NOTARY, 5, 5, Random(0L)) // DOCSTART VaultQueryExample22 val sum = builder { CashSchemaV1.PersistentCashState::pennies.sum(groupByColumns = listOf(CashSchemaV1.PersistentCashState::currency)) } @@ -781,10 +793,10 @@ class VaultQueryTests { fun `aggregate functions sum by issuer and currency and sort by aggregate sum`() { database.transaction { - services.fillWithSomeTestCash(100.DOLLARS, DUMMY_NOTARY, 1, 1, Random(0L), issuedBy = DUMMY_CASH_ISSUER) - services.fillWithSomeTestCash(200.DOLLARS, DUMMY_NOTARY, 2, 2, Random(0L), issuedBy = BOC.ref(1), issuerKey = BOC_KEY) - services.fillWithSomeTestCash(300.POUNDS, DUMMY_NOTARY, 3, 3, Random(0L), issuedBy = DUMMY_CASH_ISSUER) - services.fillWithSomeTestCash(400.POUNDS, DUMMY_NOTARY, 4, 4, Random(0L), issuedBy = BOC.ref(2), issuerKey = BOC_KEY) + services.fillWithSomeTestCash(100.DOLLARS, notaryServices, DUMMY_NOTARY, 1, 1, Random(0L), issuedBy = DUMMY_CASH_ISSUER) + services.fillWithSomeTestCash(200.DOLLARS, notaryServices, DUMMY_NOTARY, 2, 2, Random(0L), issuedBy = BOC.ref(1)) + services.fillWithSomeTestCash(300.POUNDS, notaryServices, DUMMY_NOTARY, 3, 3, Random(0L), issuedBy = DUMMY_CASH_ISSUER) + services.fillWithSomeTestCash(400.POUNDS, notaryServices, DUMMY_NOTARY, 4, 4, Random(0L), issuedBy = BOC.ref(2)) // DOCSTART VaultQueryExample23 val sum = builder { CashSchemaV1.PersistentCashState::pennies.sum(groupByColumns = listOf(CashSchemaV1.PersistentCashState::issuerParty, @@ -812,13 +824,103 @@ class VaultQueryTests { } } + @Test + fun `aggregate functions count by contract type`() { + database.transaction { + // create new states + services.fillWithSomeTestCash(100.DOLLARS, notaryServices, CASH_NOTARY, 10, 10, Random(0L)) + services.fillWithSomeTestLinearStates(1, "XYZ") + services.fillWithSomeTestLinearStates(2, "JKL") + services.fillWithSomeTestLinearStates(3, "ABC") + services.fillWithSomeTestDeals(listOf("123", "456", "789")) + + // count fungible assets + val count = builder { VaultSchemaV1.VaultStates::recordedTime.count() } + val countCriteria = QueryCriteria.VaultCustomQueryCriteria(count) + val fungibleStateCount = vaultQuerySvc.queryBy>(countCriteria).otherResults.single() as Long + assertThat(fungibleStateCount).isEqualTo(10L) + + // count linear states + val linearStateCount = vaultQuerySvc.queryBy(countCriteria).otherResults.single() as Long + assertThat(linearStateCount).isEqualTo(9L) + + // count deal states + val dealStateCount = vaultQuerySvc.queryBy(countCriteria).otherResults.single() as Long + assertThat(dealStateCount).isEqualTo(3L) + } + } + + @Test + fun `aggregate functions count by contract type and state status`() { + database.transaction { + // create new states + services.fillWithSomeTestCash(100.DOLLARS, notaryServices, DUMMY_NOTARY, 10, 10, Random(0L)) + val linearStatesXYZ = services.fillWithSomeTestLinearStates(1, "XYZ") + val linearStatesJKL = services.fillWithSomeTestLinearStates(2, "JKL") + services.fillWithSomeTestLinearStates(3, "ABC") + val dealStates = services.fillWithSomeTestDeals(listOf("123", "456", "789")) + + // ALL states + + // count fungible assets + val count = builder { VaultSchemaV1.VaultStates::recordedTime.count() } + val countCriteria = QueryCriteria.VaultCustomQueryCriteria(count, Vault.StateStatus.ALL) + val fungibleStateCount = vaultQuerySvc.queryBy>(countCriteria).otherResults.single() as Long + assertThat(fungibleStateCount).isEqualTo(10L) + + // count linear states + val linearStateCount = vaultQuerySvc.queryBy(countCriteria).otherResults.single() as Long + assertThat(linearStateCount).isEqualTo(9L) + + // count deal states + val dealStateCount = vaultQuerySvc.queryBy(countCriteria).otherResults.single() as Long + assertThat(dealStateCount).isEqualTo(3L) + + // consume some states + services.consumeLinearStates(linearStatesXYZ.states.toList(), DUMMY_NOTARY) + services.consumeLinearStates(linearStatesJKL.states.toList(), DUMMY_NOTARY) + services.consumeDeals(dealStates.states.filter { it.state.data.linearId.externalId == "456" }, DUMMY_NOTARY) + val cashUpdates = services.consumeCash(50.DOLLARS, notary = DUMMY_NOTARY) + + // UNCONSUMED states (default) + + // count fungible assets + val countCriteriaUnconsumed = QueryCriteria.VaultCustomQueryCriteria(count, Vault.StateStatus.UNCONSUMED) + val fungibleStateCountUnconsumed = vaultQuerySvc.queryBy>(countCriteriaUnconsumed).otherResults.single() as Long + assertThat(fungibleStateCountUnconsumed.toInt()).isEqualTo(10 - cashUpdates.consumed.size + cashUpdates.produced.size) + + // count linear states + val linearStateCountUnconsumed = vaultQuerySvc.queryBy(countCriteriaUnconsumed).otherResults.single() as Long + assertThat(linearStateCountUnconsumed).isEqualTo(5L) + + // count deal states + val dealStateCountUnconsumed = vaultQuerySvc.queryBy(countCriteriaUnconsumed).otherResults.single() as Long + assertThat(dealStateCountUnconsumed).isEqualTo(2L) + + // CONSUMED states + + // count fungible assets + val countCriteriaConsumed = QueryCriteria.VaultCustomQueryCriteria(count, Vault.StateStatus.CONSUMED) + val fungibleStateCountConsumed = vaultQuerySvc.queryBy>(countCriteriaConsumed).otherResults.single() as Long + assertThat(fungibleStateCountConsumed.toInt()).isEqualTo(cashUpdates.consumed.size) + + // count linear states + val linearStateCountConsumed = vaultQuerySvc.queryBy(countCriteriaConsumed).otherResults.single() as Long + assertThat(linearStateCountConsumed).isEqualTo(4L) + + // count deal states + val dealStateCountConsumed = vaultQuerySvc.queryBy(countCriteriaConsumed).otherResults.single() as Long + assertThat(dealStateCountConsumed).isEqualTo(1L) + } + } + private val TODAY = LocalDate.now().atStartOfDay().toInstant(ZoneOffset.UTC) @Test fun `unconsumed states recorded between two time intervals`() { database.transaction { - services.fillWithSomeTestCash(100.DOLLARS, CASH_NOTARY, 3, 3, Random(0L)) + services.fillWithSomeTestCash(100.DOLLARS, notaryServices, CASH_NOTARY, 3, 3, Random(0L)) // DOCSTART VaultQueryExample6 val start = TODAY @@ -844,11 +946,11 @@ class VaultQueryTests { fun `states consumed after time`() { database.transaction { - services.fillWithSomeTestCash(100.DOLLARS, CASH_NOTARY, 3, 3, Random(0L)) + services.fillWithSomeTestCash(100.DOLLARS, notaryServices, DUMMY_NOTARY, 3, 3, Random(0L)) services.fillWithSomeTestLinearStates(10) services.fillWithSomeTestDeals(listOf("123", "456", "789")) - services.consumeCash(100.DOLLARS) + services.consumeCash(100.DOLLARS, notary = DUMMY_NOTARY) val asOfDateTime = TODAY val consumedAfterExpression = TimeCondition( @@ -866,7 +968,7 @@ class VaultQueryTests { fun `all states with paging specification - first page`() { database.transaction { - services.fillWithSomeTestCash(100.DOLLARS, DUMMY_NOTARY, 100, 100, Random(0L)) + services.fillWithSomeTestCash(100.DOLLARS, notaryServices, DUMMY_NOTARY, 100, 100, Random(0L)) // DOCSTART VaultQueryExample7 val pagingSpec = PageSpecification(DEFAULT_PAGE_NUM, 10) @@ -883,7 +985,7 @@ class VaultQueryTests { fun `all states with paging specification - last`() { database.transaction { - services.fillWithSomeTestCash(95.DOLLARS, DUMMY_NOTARY, 95, 95, Random(0L)) + services.fillWithSomeTestCash(95.DOLLARS, notaryServices, DUMMY_NOTARY, 95, 95, Random(0L)) // Last page implies we need to perform a row count for the Query first, // and then re-query for a given offset defined by (count - pageSize) @@ -896,9 +998,6 @@ class VaultQueryTests { } } - @get:Rule - val expectedEx = ExpectedException.none()!! - // pagination: invalid page number @Test fun `invalid page number`() { @@ -907,7 +1006,7 @@ class VaultQueryTests { database.transaction { - services.fillWithSomeTestCash(100.DOLLARS, DUMMY_NOTARY, 100, 100, Random(0L)) + services.fillWithSomeTestCash(100.DOLLARS, notaryServices, DUMMY_NOTARY, 100, 100, Random(0L)) val pagingSpec = PageSpecification(0, 10) @@ -924,7 +1023,7 @@ class VaultQueryTests { database.transaction { - services.fillWithSomeTestCash(100.DOLLARS, DUMMY_NOTARY, 100, 100, Random(0L)) + services.fillWithSomeTestCash(100.DOLLARS, notaryServices, DUMMY_NOTARY, 100, 100, Random(0L)) val pagingSpec = PageSpecification(DEFAULT_PAGE_NUM, MAX_PAGE_SIZE + 1) // overflow = -2147483648 val criteria = VaultQueryCriteria(status = Vault.StateStatus.ALL) @@ -940,7 +1039,7 @@ class VaultQueryTests { database.transaction { - services.fillWithSomeTestCash(201.DOLLARS, DUMMY_NOTARY, 201, 201, Random(0L)) + services.fillWithSomeTestCash(201.DOLLARS, notaryServices, DUMMY_NOTARY, 201, 201, Random(0L)) val criteria = VaultQueryCriteria(status = Vault.StateStatus.ALL) vaultQuerySvc.queryBy(criteria) @@ -980,8 +1079,8 @@ class VaultQueryTests { fun `unconsumed fungible assets`() { database.transaction { - services.fillWithSomeTestCash(100.DOLLARS, DUMMY_NOTARY, 3, 3, Random(0L)) - services.fillWithSomeTestCommodity(Amount(100, Commodity.getInstance("FCOJ")!!)) + services.fillWithSomeTestCash(100.DOLLARS, notaryServices, DUMMY_NOTARY, 3, 3, Random(0L)) + services.fillWithSomeTestCommodity(Amount(100, Commodity.getInstance("FCOJ")!!), notaryServices) services.fillWithSomeTestLinearStates(10) val results = vaultQuerySvc.queryBy>() @@ -993,9 +1092,9 @@ class VaultQueryTests { fun `consumed fungible assets`() { database.transaction { - services.fillWithSomeTestCash(100.DOLLARS, DUMMY_NOTARY, 3, 3, Random(0L)) - services.consumeCash(50.DOLLARS) - services.fillWithSomeTestCommodity(Amount(100, Commodity.getInstance("FCOJ")!!)) + services.fillWithSomeTestCash(100.DOLLARS, notaryServices, DUMMY_NOTARY, 3, 3, Random(0L)) + services.consumeCash(50.DOLLARS, notary = DUMMY_NOTARY) + services.fillWithSomeTestCommodity(Amount(100, Commodity.getInstance("FCOJ")!!), notaryServices) services.fillWithSomeTestLinearStates(10) val criteria = VaultQueryCriteria(status = Vault.StateStatus.CONSUMED) @@ -1008,7 +1107,7 @@ class VaultQueryTests { fun `unconsumed cash fungible assets`() { database.transaction { - services.fillWithSomeTestCash(100.DOLLARS, DUMMY_NOTARY, 3, 3, Random(0L)) + services.fillWithSomeTestCash(100.DOLLARS, notaryServices, DUMMY_NOTARY, 3, 3, Random(0L)) services.fillWithSomeTestLinearStates(10) val results = vaultQuerySvc.queryBy() @@ -1020,8 +1119,8 @@ class VaultQueryTests { fun `unconsumed cash fungible assets after spending`() { database.transaction { - services.fillWithSomeTestCash(100.DOLLARS, DUMMY_NOTARY, 3, 3, Random(0L)) - services.consumeCash(50.DOLLARS) + services.fillWithSomeTestCash(100.DOLLARS, notaryServices, DUMMY_NOTARY, 3, 3, Random(0L)) + services.consumeCash(50.DOLLARS, notary = DUMMY_NOTARY) // should now have x2 CONSUMED + x2 UNCONSUMED (one spent + one change) val results = vaultQuerySvc.queryBy(FungibleAssetQueryCriteria()) @@ -1034,10 +1133,10 @@ class VaultQueryTests { fun `consumed cash fungible assets`() { database.transaction { - services.fillWithSomeTestCash(100.DOLLARS, DUMMY_NOTARY, 3, 3, Random(0L)) - services.consumeCash(50.DOLLARS) + services.fillWithSomeTestCash(100.DOLLARS, notaryServices, DUMMY_NOTARY, 3, 3, Random(0L)) + services.consumeCash(50.DOLLARS, notary = DUMMY_NOTARY) val linearStates = services.fillWithSomeTestLinearStates(10) - services.consumeLinearStates(linearStates.states.toList()) + services.consumeLinearStates(linearStates.states.toList(), DUMMY_NOTARY) val criteria = VaultQueryCriteria(status = Vault.StateStatus.CONSUMED) val results = vaultQuerySvc.queryBy(criteria) @@ -1049,7 +1148,7 @@ class VaultQueryTests { fun `unconsumed linear heads`() { database.transaction { - services.fillWithSomeTestCash(100.DOLLARS, DUMMY_NOTARY, 3, 3, Random(0L)) + services.fillWithSomeTestCash(100.DOLLARS, notaryServices, DUMMY_NOTARY, 3, 3, Random(0L)) services.fillWithSomeTestLinearStates(10) services.fillWithSomeTestDeals(listOf("123", "456", "789")) @@ -1062,14 +1161,14 @@ class VaultQueryTests { fun `consumed linear heads`() { database.transaction { - services.fillWithSomeTestCash(100.DOLLARS, DUMMY_NOTARY, 3, 3, Random(0L)) + services.fillWithSomeTestCash(100.DOLLARS, notaryServices, DUMMY_NOTARY, 3, 3, Random(0L)) val linearStates = services.fillWithSomeTestLinearStates(2, "TEST") // create 2 states with same externalId services.fillWithSomeTestLinearStates(8) val dealStates = services.fillWithSomeTestDeals(listOf("123", "456", "789")) - services.consumeLinearStates(linearStates.states.toList()) - services.consumeDeals(dealStates.states.filter { it.state.data.ref == "456" }) - services.consumeCash(50.DOLLARS) + services.consumeLinearStates(linearStates.states.toList(), DUMMY_NOTARY) + services.consumeDeals(dealStates.states.filter { it.state.data.linearId.externalId == "456" }, DUMMY_NOTARY) + services.consumeCash(50.DOLLARS, notary = DUMMY_NOTARY) val criteria = VaultQueryCriteria(status = Vault.StateStatus.CONSUMED) val results = vaultQuerySvc.queryBy(criteria) @@ -1086,8 +1185,8 @@ class VaultQueryTests { val issuedStates = services.fillWithSomeTestLinearStates(10) // DOCSTART VaultQueryExample8 - val linearIds = issuedStates.states.map { it.state.data.linearId }.toList() - val criteria = LinearStateQueryCriteria(linearId = listOf(linearIds.first(), linearIds.last())) + val linearIds = issuedStates.states.map { it.state.data.linearId.id }.toList() + val criteria = LinearStateQueryCriteria(uuid = listOf(linearIds.first(), linearIds.last())) val results = vaultQuerySvc.queryBy(criteria) // DOCEND VaultQueryExample8 assertThat(results.states).hasSize(2) @@ -1102,8 +1201,8 @@ class VaultQueryTests { services.fillWithSomeTestLinearStates(1, "ID2") val linearState3 = services.fillWithSomeTestLinearStates(1, "ID3") - val linearIds = listOf(linearState1.states.first().state.data.linearId, linearState3.states.first().state.data.linearId) - val criteria = LinearStateQueryCriteria(linearId = linearIds) + val linearIds = listOf(linearState1.states.first().state.data.linearId.id, linearState3.states.first().state.data.linearId.id) + val criteria = LinearStateQueryCriteria(uuid = linearIds) val results = vaultQuerySvc.queryBy(criteria) assertThat(results.states).hasSize(2) } @@ -1116,13 +1215,13 @@ class VaultQueryTests { val txns = services.fillWithSomeTestLinearStates(1, "TEST") val linearState = txns.states.first() val linearId = linearState.state.data.linearId - services.evolveLinearState(linearState) // consume current and produce new state reference - services.evolveLinearState(linearState) // consume current and produce new state reference - services.evolveLinearState(linearState) // consume current and produce new state reference + services.evolveLinearState(linearState, DUMMY_NOTARY) // consume current and produce new state reference + services.evolveLinearState(linearState, DUMMY_NOTARY) // consume current and produce new state reference + services.evolveLinearState(linearState, DUMMY_NOTARY) // consume current and produce new state reference // should now have 1 UNCONSUMED & 3 CONSUMED state refs for Linear State with "TEST" // DOCSTART VaultQueryExample9 - val linearStateCriteria = LinearStateQueryCriteria(linearId = listOf(linearId), status = Vault.StateStatus.ALL) + val linearStateCriteria = LinearStateQueryCriteria(uuid = listOf(linearId.id), status = Vault.StateStatus.ALL) val vaultCriteria = VaultQueryCriteria(status = Vault.StateStatus.ALL) val results = vaultQuerySvc.queryBy(linearStateCriteria and vaultCriteria) // DOCEND VaultQueryExample9 @@ -1136,12 +1235,12 @@ class VaultQueryTests { val txns = services.fillWithSomeTestLinearStates(2, "TEST") val linearStates = txns.states.toList() - services.evolveLinearStates(linearStates) // consume current and produce new state reference - services.evolveLinearStates(linearStates) // consume current and produce new state reference - services.evolveLinearStates(linearStates) // consume current and produce new state reference + services.evolveLinearStates(linearStates, DUMMY_NOTARY) // consume current and produce new state reference + services.evolveLinearStates(linearStates, DUMMY_NOTARY) // consume current and produce new state reference + services.evolveLinearStates(linearStates, DUMMY_NOTARY) // consume current and produce new state reference // should now have 1 UNCONSUMED & 3 CONSUMED state refs for Linear State with "TEST" - val linearStateCriteria = LinearStateQueryCriteria(linearId = linearStates.map { it.state.data.linearId }, status = Vault.StateStatus.ALL) + val linearStateCriteria = LinearStateQueryCriteria(uuid = linearStates.map { it.state.data.linearId.id }, status = Vault.StateStatus.ALL) val vaultCriteria = VaultQueryCriteria(status = Vault.StateStatus.ALL) val sorting = Sort(setOf(Sort.SortColumn(SortAttribute.Standard(Sort.LinearStateAttribute.UUID), Sort.Direction.DESC))) @@ -1163,24 +1262,24 @@ class VaultQueryTests { val sorting = Sort(setOf(Sort.SortColumn(SortAttribute.Standard(Sort.LinearStateAttribute.EXTERNAL_ID), Sort.Direction.DESC))) val results = vaultQuerySvc.queryBy((vaultCriteria), sorting = sorting) - results.states.forEach { println("${it.state.data.linearString}") } + results.states.forEach { println(it.state.data.linearString) } assertThat(results.states).hasSize(6) } } @Test - fun `unconsumed deal states paged and sorted`() { + fun `unconsumed deal states sorted`() { database.transaction { val linearStates = services.fillWithSomeTestLinearStates(10) - val uid = linearStates.states.first().state.data.linearId + val uid = linearStates.states.first().state.data.linearId.id services.fillWithSomeTestDeals(listOf("123", "456", "789")) - val linearStateCriteria = LinearStateQueryCriteria(linearId = listOf(uid)) - val dealStateCriteria = LinearStateQueryCriteria(dealRef = listOf("123", "456", "789")) + val linearStateCriteria = LinearStateQueryCriteria(uuid = listOf(uid)) + val dealStateCriteria = LinearStateQueryCriteria(externalId = listOf("123", "456", "789")) val compositeCriteria = linearStateCriteria or dealStateCriteria - val sorting = Sort(setOf(Sort.SortColumn(SortAttribute.Standard(Sort.LinearStateAttribute.DEAL_REFERENCE), Sort.Direction.DESC))) + val sorting = Sort(setOf(Sort.SortColumn(SortAttribute.Standard(Sort.LinearStateAttribute.EXTERNAL_ID), Sort.Direction.DESC))) val results = vaultQuerySvc.queryBy(compositeCriteria, sorting = sorting) assertThat(results.statesMetadata).hasSize(13) @@ -1200,7 +1299,7 @@ class VaultQueryTests { val sorting = Sort(setOf(Sort.SortColumn(SortAttribute.Custom(DummyLinearStateSchemaV1.PersistentDummyLinearState::class.java, "linearString"), Sort.Direction.DESC))) val results = vaultQuerySvc.queryBy((vaultCriteria), sorting = sorting) - results.states.forEach { println("${it.state.data.linearString}") } + results.states.forEach { println(it.state.data.linearString) } assertThat(results.states).hasSize(6) } } @@ -1211,12 +1310,12 @@ class VaultQueryTests { val txns = services.fillWithSomeTestLinearStates(1, "TEST") val linearState = txns.states.first() - val linearState2 = services.evolveLinearState(linearState) // consume current and produce new state reference - val linearState3 = services.evolveLinearState(linearState2) // consume current and produce new state reference - services.evolveLinearState(linearState3) // consume current and produce new state reference + val linearState2 = services.evolveLinearState(linearState, DUMMY_NOTARY) // consume current and produce new state reference + val linearState3 = services.evolveLinearState(linearState2, DUMMY_NOTARY) // consume current and produce new state reference + services.evolveLinearState(linearState3, DUMMY_NOTARY) // consume current and produce new state reference // should now have 1 UNCONSUMED & 3 CONSUMED state refs for Linear State with "TEST" - val linearStateCriteria = LinearStateQueryCriteria(linearId = txns.states.map { it.state.data.linearId }, status = Vault.StateStatus.CONSUMED) + val linearStateCriteria = LinearStateQueryCriteria(uuid = txns.states.map { it.state.data.linearId.id }, status = Vault.StateStatus.CONSUMED) val vaultCriteria = VaultQueryCriteria(status = Vault.StateStatus.CONSUMED) val sorting = Sort(setOf(Sort.SortColumn(SortAttribute.Standard(Sort.LinearStateAttribute.UUID), Sort.Direction.DESC))) val results = vaultQuerySvc.queryBy(linearStateCriteria.and(vaultCriteria), sorting = sorting) @@ -1224,82 +1323,6 @@ class VaultQueryTests { } } - @Test - fun `DEPRECATED unconsumed linear states for a given id`() { - database.transaction { - - val txns = services.fillWithSomeTestLinearStates(1, "TEST") - val linearState = txns.states.first() - val linearId = linearState.state.data.linearId - val linearState2 = services.evolveLinearState(linearState) // consume current and produce new state reference - val linearState3 = services.evolveLinearState(linearState2) // consume current and produce new state reference - services.evolveLinearState(linearState3) // consume current and produce new state reference - - // should now have 1 UNCONSUMED & 3 CONSUMED state refs for Linear State with "TEST" - - // DOCSTART VaultDeprecatedQueryExample1 - val states = vaultSvc.linearHeadsOfType().filter { it.key == linearId } - // DOCEND VaultDeprecatedQueryExample1 - assertThat(states).hasSize(1) - - // validate against new query api - val results = vaultQuerySvc.queryBy(LinearStateQueryCriteria(linearId = listOf(linearId))) - assertThat(results.statesMetadata).hasSize(1) - assertThat(results.states).hasSize(1) - } - } - - @Test - fun `DEPRECATED consumed linear states for a given id`() { - database.transaction { - - val txns = services.fillWithSomeTestLinearStates(1, "TEST") - val linearState = txns.states.first() - val linearId = linearState.state.data.linearId - val linearState2 = services.evolveLinearState(linearState) // consume current and produce new state reference - val linearState3 = services.evolveLinearState(linearState2) // consume current and produce new state reference - services.evolveLinearState(linearState3) // consume current and produce new state reference - - // should now have 1 UNCONSUMED & 3 CONSUMED state refs for Linear State with "TEST" - - // DOCSTART VaultDeprecatedQueryExample2 - val states = vaultSvc.consumedStates().filter { it.state.data.linearId == linearId } - // DOCEND VaultDeprecatedQueryExample2 - assertThat(states).hasSize(3) - - // validate against new query api - val results = vaultQuerySvc.queryBy(LinearStateQueryCriteria(linearId = listOf(linearId), status = Vault.StateStatus.CONSUMED)) - assertThat(results.statesMetadata).hasSize(3) - assertThat(results.states).hasSize(3) - } - } - - @Test - fun `DEPRECATED all linear states for a given id`() { - database.transaction { - - val txns = services.fillWithSomeTestLinearStates(1, "TEST") - val linearState = txns.states.first() - val linearId = linearState.state.data.linearId - services.evolveLinearState(linearState) // consume current and produce new state reference - services.evolveLinearState(linearState) // consume current and produce new state reference - services.evolveLinearState(linearState) // consume current and produce new state reference - - // should now have 1 UNCONSUMED & 3 CONSUMED state refs for Linear State with "TEST" - - // DOCSTART VaultDeprecatedQueryExample3 - val states = vaultSvc.states(setOf(DummyLinearContract.State::class.java), - EnumSet.of(Vault.StateStatus.CONSUMED, Vault.StateStatus.UNCONSUMED)).filter { it.state.data.linearId == linearId } - // DOCEND VaultDeprecatedQueryExample3 - assertThat(states).hasSize(4) - - // validate against new query api - val results = vaultQuerySvc.queryBy(LinearStateQueryCriteria(linearId = listOf(linearId), status = Vault.StateStatus.ALL)) - assertThat(results.statesMetadata).hasSize(4) - assertThat(results.states).hasSize(4) - } - } - /** * Deal Contract state to be removed as is duplicate of LinearState */ @@ -1321,7 +1344,7 @@ class VaultQueryTests { services.fillWithSomeTestDeals(listOf("123", "456", "789")) // DOCSTART VaultQueryExample10 - val criteria = LinearStateQueryCriteria(dealRef = listOf("456", "789")) + val criteria = LinearStateQueryCriteria(externalId = listOf("456", "789")) val results = vaultQuerySvc.queryBy(criteria) // DOCEND VaultQueryExample10 @@ -1340,7 +1363,7 @@ class VaultQueryTests { val all = vaultQuerySvc.queryBy() all.states.forEach { println(it.state) } - val criteria = LinearStateQueryCriteria(dealRef = listOf("456")) + val criteria = LinearStateQueryCriteria(externalId = listOf("456")) val results = vaultQuerySvc.queryBy(criteria) assertThat(results.states).hasSize(1) } @@ -1372,10 +1395,10 @@ class VaultQueryTests { fun `unconsumed fungible assets for specific issuer party and refs`() { database.transaction { - services.fillWithSomeTestCash(100.DOLLARS, DUMMY_NOTARY, 1, 1, Random(0L), issuedBy = (DUMMY_CASH_ISSUER)) - services.fillWithSomeTestCash(100.DOLLARS, DUMMY_NOTARY, 1, 1, Random(0L), issuedBy = (BOC.ref(1)), issuerKey = BOC_KEY, ref = OpaqueBytes.of(1)) - services.fillWithSomeTestCash(100.DOLLARS, DUMMY_NOTARY, 1, 1, Random(0L), issuedBy = (BOC.ref(2)), issuerKey = BOC_KEY, ref = OpaqueBytes.of(2)) - services.fillWithSomeTestCash(100.DOLLARS, DUMMY_NOTARY, 1, 1, Random(0L), issuedBy = (BOC.ref(3)), issuerKey = BOC_KEY, ref = OpaqueBytes.of(3)) + services.fillWithSomeTestCash(100.DOLLARS, notaryServices, DUMMY_NOTARY, 1, 1, Random(0L), issuedBy = (DUMMY_CASH_ISSUER)) + services.fillWithSomeTestCash(100.DOLLARS, notaryServices, DUMMY_NOTARY, 1, 1, Random(0L), issuedBy = (BOC.ref(1)), ref = OpaqueBytes.of(1)) + services.fillWithSomeTestCash(100.DOLLARS, notaryServices, DUMMY_NOTARY, 1, 1, Random(0L), issuedBy = (BOC.ref(2)), ref = OpaqueBytes.of(2)) + services.fillWithSomeTestCash(100.DOLLARS, notaryServices, DUMMY_NOTARY, 1, 1, Random(0L), issuedBy = (BOC.ref(3)), ref = OpaqueBytes.of(3)) val criteria = FungibleAssetQueryCriteria(issuerPartyName = listOf(BOC), issuerRef = listOf(BOC.ref(1).reference, BOC.ref(2).reference)) @@ -1384,13 +1407,39 @@ class VaultQueryTests { } } + @Test + fun `unconsumed fungible assets for selected issuer parties`() { + // GBP issuer + val gbpCashIssuerKey = entropyToKeyPair(BigInteger.valueOf(1001)) + val gbpCashIssuer = Party(X500Name("CN=British Pounds Cash Issuer,O=R3,OU=corda,L=London,C=GB"), gbpCashIssuerKey.public).ref(1) + val gbpCashIssuerServices = MockServices(gbpCashIssuerKey) + // USD issuer + val usdCashIssuerKey = entropyToKeyPair(BigInteger.valueOf(1002)) + val usdCashIssuer = Party(X500Name("CN=US Dollars Cash Issuer,O=R3,OU=corda,L=New York,C=US"), usdCashIssuerKey.public).ref(1) + val usdCashIssuerServices = MockServices(usdCashIssuerKey) + // CHF issuer + val chfCashIssuerKey = entropyToKeyPair(BigInteger.valueOf(1003)) + val chfCashIssuer = Party(X500Name("CN=Swiss Francs Cash Issuer,O=R3,OU=corda,L=Zurich,C=CH"), chfCashIssuerKey.public).ref(1) + val chfCashIssuerServices = MockServices(chfCashIssuerKey) + + database.transaction { + services.fillWithSomeTestCash(100.POUNDS, gbpCashIssuerServices, DUMMY_NOTARY, 1, 1, Random(0L), issuedBy = (gbpCashIssuer)) + services.fillWithSomeTestCash(100.DOLLARS, usdCashIssuerServices, DUMMY_NOTARY, 1, 1, Random(0L), issuedBy = (usdCashIssuer)) + services.fillWithSomeTestCash(100.SWISS_FRANCS, chfCashIssuerServices, DUMMY_NOTARY, 1, 1, Random(0L), issuedBy = (chfCashIssuer)) + + val criteria = FungibleAssetQueryCriteria(issuerPartyName = listOf(gbpCashIssuer.party, usdCashIssuer.party)) + val results = vaultQuerySvc.queryBy>(criteria) + assertThat(results.states).hasSize(2) + } + } + @Test fun `unconsumed fungible assets by owner`() { database.transaction { - services.fillWithSomeTestCash(100.DOLLARS, DUMMY_NOTARY, 2, 2, Random(0L), issuedBy = (DUMMY_CASH_ISSUER)) - services.fillWithSomeTestCash(100.DOLLARS, DUMMY_NOTARY, 1, 1, Random(0L), - issuedBy = MEGA_CORP.ref(0), issuerKey = MEGA_CORP_KEY, ownedBy = (MEGA_CORP)) + services.fillWithSomeTestCash(100.DOLLARS, notaryServices, DUMMY_NOTARY, 2, 2, Random(0L), issuedBy = BOC.ref(1)) + services.fillWithSomeTestCash(100.DOLLARS, notaryServices, DUMMY_NOTARY, 1, 1, Random(0L), + issuedBy = MEGA_CORP.ref(0), ownedBy = (MEGA_CORP)) val criteria = FungibleAssetQueryCriteria(owner = listOf(MEGA_CORP)) val results = vaultQuerySvc.queryBy>(criteria) @@ -1403,14 +1452,14 @@ class VaultQueryTests { fun `unconsumed fungible states for owners`() { database.transaction { - services.fillWithSomeTestCash(100.DOLLARS, CASH_NOTARY, 1, 1, Random(0L)) - services.fillWithSomeTestCash(100.DOLLARS, DUMMY_NOTARY, 1, 1, Random(0L), - issuedBy = MEGA_CORP.ref(0), issuerKey = MEGA_CORP_KEY, ownedBy = (MEGA_CORP)) - services.fillWithSomeTestCash(100.DOLLARS, DUMMY_NOTARY, 1, 1, Random(0L), - issuedBy = MINI_CORP.ref(0), issuerKey = MINI_CORP_KEY, ownedBy = (MINI_CORP)) // irrelevant to this vault + services.fillWithSomeTestCash(100.DOLLARS, notaryServices, CASH_NOTARY, 1, 1, Random(0L)) + services.fillWithSomeTestCash(100.DOLLARS, notaryServices, DUMMY_NOTARY, 1, 1, Random(0L), + issuedBy = MEGA_CORP.ref(0), ownedBy = (MEGA_CORP)) + services.fillWithSomeTestCash(100.DOLLARS, notaryServices, DUMMY_NOTARY, 1, 1, Random(0L), + issuedBy = BOC.ref(0), ownedBy = (MINI_CORP)) // irrelevant to this vault // DOCSTART VaultQueryExample5.2 - val criteria = FungibleAssetQueryCriteria(owner = listOf(MEGA_CORP,MINI_CORP)) + val criteria = FungibleAssetQueryCriteria(owner = listOf(MEGA_CORP, BOC)) val results = vaultQuerySvc.queryBy(criteria) // DOCEND VaultQueryExample5.2 @@ -1424,9 +1473,9 @@ class VaultQueryTests { database.transaction { services.fillWithSomeTestLinearStates(10) - services.fillWithSomeTestCash(100.DOLLARS, DUMMY_NOTARY, 3, 3, Random(0L)) - services.fillWithSomeTestCash(100.POUNDS, DUMMY_NOTARY, 3, 3, Random(0L)) - services.fillWithSomeTestCash(100.SWISS_FRANCS, DUMMY_NOTARY, 3, 3, Random(0L)) + services.fillWithSomeTestCash(100.DOLLARS, notaryServices, DUMMY_NOTARY, 3, 3, Random(0L)) + services.fillWithSomeTestCash(100.POUNDS, notaryServices, DUMMY_NOTARY, 3, 3, Random(0L)) + services.fillWithSomeTestCash(100.SWISS_FRANCS, notaryServices, DUMMY_NOTARY, 3, 3, Random(0L)) // DOCSTART VaultQueryExample12 val ccyIndex = builder { CashSchemaV1.PersistentCashState::currency.equal(USD.currencyCode) } @@ -1442,8 +1491,8 @@ class VaultQueryTests { fun `unconsumed cash balance for single currency`() { database.transaction { - services.fillWithSomeTestCash(100.DOLLARS, DUMMY_NOTARY, 1, 1, Random(0L)) - services.fillWithSomeTestCash(200.DOLLARS, DUMMY_NOTARY, 2, 2, Random(0L)) + services.fillWithSomeTestCash(100.DOLLARS, notaryServices, DUMMY_NOTARY, 1, 1, Random(0L)) + services.fillWithSomeTestCash(200.DOLLARS, notaryServices, DUMMY_NOTARY, 2, 2, Random(0L)) val sum = builder { CashSchemaV1.PersistentCashState::pennies.sum(groupByColumns = listOf(CashSchemaV1.PersistentCashState::currency)) } val sumCriteria = VaultCustomQueryCriteria(sum) @@ -1463,12 +1512,12 @@ class VaultQueryTests { fun `unconsumed cash balances for all currencies`() { database.transaction { - services.fillWithSomeTestCash(100.DOLLARS, DUMMY_NOTARY, 1, 1, Random(0L)) - services.fillWithSomeTestCash(200.DOLLARS, DUMMY_NOTARY, 2, 2, Random(0L)) - services.fillWithSomeTestCash(300.POUNDS, DUMMY_NOTARY, 3, 3, Random(0L)) - services.fillWithSomeTestCash(400.POUNDS, DUMMY_NOTARY, 4, 4, Random(0L)) - services.fillWithSomeTestCash(500.SWISS_FRANCS, DUMMY_NOTARY, 5, 5, Random(0L)) - services.fillWithSomeTestCash(600.SWISS_FRANCS, DUMMY_NOTARY, 6, 6, Random(0L)) + services.fillWithSomeTestCash(100.DOLLARS, notaryServices, DUMMY_NOTARY, 1, 1, Random(0L)) + services.fillWithSomeTestCash(200.DOLLARS, notaryServices, DUMMY_NOTARY, 2, 2, Random(0L)) + services.fillWithSomeTestCash(300.POUNDS, notaryServices, DUMMY_NOTARY, 3, 3, Random(0L)) + services.fillWithSomeTestCash(400.POUNDS, notaryServices, DUMMY_NOTARY, 4, 4, Random(0L)) + services.fillWithSomeTestCash(500.SWISS_FRANCS, notaryServices, DUMMY_NOTARY, 5, 5, Random(0L)) + services.fillWithSomeTestCash(600.SWISS_FRANCS, notaryServices, DUMMY_NOTARY, 6, 6, Random(0L)) val ccyIndex = builder { CashSchemaV1.PersistentCashState::pennies.sum(groupByColumns = listOf(CashSchemaV1.PersistentCashState::currency)) } val criteria = VaultCustomQueryCriteria(ccyIndex) @@ -1488,10 +1537,10 @@ class VaultQueryTests { fun `unconsumed fungible assets for quantity greater than`() { database.transaction { - services.fillWithSomeTestCash(10.DOLLARS, DUMMY_NOTARY, 3, 3, Random(0L)) - services.fillWithSomeTestCash(25.POUNDS, DUMMY_NOTARY, 1, 1, Random(0L)) - services.fillWithSomeTestCash(50.POUNDS, DUMMY_NOTARY, 1, 1, Random(0L)) - services.fillWithSomeTestCash(100.SWISS_FRANCS, DUMMY_NOTARY, 3, 3, Random(0L)) + services.fillWithSomeTestCash(10.DOLLARS, notaryServices, DUMMY_NOTARY, 3, 3, Random(0L)) + services.fillWithSomeTestCash(25.POUNDS, notaryServices, DUMMY_NOTARY, 1, 1, Random(0L)) + services.fillWithSomeTestCash(50.POUNDS, notaryServices, DUMMY_NOTARY, 1, 1, Random(0L)) + services.fillWithSomeTestCash(100.SWISS_FRANCS, notaryServices, DUMMY_NOTARY, 3, 3, Random(0L)) // DOCSTART VaultQueryExample13 val fungibleAssetCriteria = FungibleAssetQueryCriteria(quantity = builder { greaterThan(2500L) }) @@ -1506,8 +1555,8 @@ class VaultQueryTests { fun `unconsumed fungible assets for issuer party`() { database.transaction { - services.fillWithSomeTestCash(100.DOLLARS, DUMMY_NOTARY, 1, 1, Random(0L), issuedBy = (DUMMY_CASH_ISSUER)) - services.fillWithSomeTestCash(100.DOLLARS, DUMMY_NOTARY, 1, 1, Random(0L), issuedBy = (BOC.ref(1)), issuerKey = BOC_KEY) + services.fillWithSomeTestCash(100.DOLLARS, notaryServices, DUMMY_NOTARY, 1, 1, Random(0L), issuedBy = (DUMMY_CASH_ISSUER)) + services.fillWithSomeTestCash(100.DOLLARS, notaryServices, DUMMY_NOTARY, 1, 1, Random(0L), issuedBy = (BOC.ref(1))) // DOCSTART VaultQueryExample14 val criteria = FungibleAssetQueryCriteria(issuerPartyName = listOf(BOC)) @@ -1522,10 +1571,10 @@ class VaultQueryTests { fun `unconsumed fungible assets for single currency and quantity greater than`() { database.transaction { - services.fillWithSomeTestCash(100.DOLLARS, DUMMY_NOTARY, 1, 1, Random(0L)) - services.fillWithSomeTestCash(100.POUNDS, DUMMY_NOTARY, 1, 1, Random(0L)) - services.fillWithSomeTestCash(50.POUNDS, DUMMY_NOTARY, 1, 1, Random(0L)) - services.fillWithSomeTestCash(100.SWISS_FRANCS, DUMMY_NOTARY, 1, 1, Random(0L)) + services.fillWithSomeTestCash(100.DOLLARS, notaryServices, DUMMY_NOTARY, 1, 1, Random(0L)) + services.fillWithSomeTestCash(100.POUNDS, notaryServices, DUMMY_NOTARY, 1, 1, Random(0L)) + services.fillWithSomeTestCash(50.POUNDS, notaryServices, DUMMY_NOTARY, 1, 1, Random(0L)) + services.fillWithSomeTestCash(100.SWISS_FRANCS, notaryServices, DUMMY_NOTARY, 1, 1, Random(0L)) val ccyIndex = builder { CashSchemaV1.PersistentCashState::currency.equal(GBP.currencyCode) } val customCriteria = VaultCustomQueryCriteria(ccyIndex) @@ -1548,21 +1597,22 @@ class VaultQueryTests { // MegaCorp™ issues $10,000 of commercial paper, to mature in 30 days, owned by itself. val faceValue = 10000.DOLLARS `issued by` DUMMY_CASH_ISSUER val commercialPaper = - CommercialPaper().generateIssue(issuance, faceValue, TEST_TX_TIME + 30.days, DUMMY_NOTARY).apply { - setTimeWindow(TEST_TX_TIME, 30.seconds) - signWith(MEGA_CORP_KEY) - signWith(DUMMY_NOTARY_KEY) - }.toSignedTransaction() + CommercialPaper().generateIssue(issuance, faceValue, TEST_TX_TIME + 30.days, DUMMY_NOTARY).let { builder -> + builder.setTimeWindow(TEST_TX_TIME, 30.seconds) + val stx = services.signInitialTransaction(builder, MEGA_CORP_PUBKEY) + notaryServices.addSignature(stx, DUMMY_NOTARY_KEY.public) + } + services.recordTransactions(commercialPaper) // MegaCorp™ now issues £10,000 of commercial paper, to mature in 30 days, owned by itself. val faceValue2 = 10000.POUNDS `issued by` DUMMY_CASH_ISSUER val commercialPaper2 = - CommercialPaper().generateIssue(issuance, faceValue2, TEST_TX_TIME + 30.days, DUMMY_NOTARY).apply { - setTimeWindow(TEST_TX_TIME, 30.seconds) - signWith(MEGA_CORP_KEY) - signWith(DUMMY_NOTARY_KEY) - }.toSignedTransaction() + CommercialPaper().generateIssue(issuance, faceValue2, TEST_TX_TIME + 30.days, DUMMY_NOTARY).let { builder -> + builder.setTimeWindow(TEST_TX_TIME, 30.seconds) + val stx = services.signInitialTransaction(builder, MEGA_CORP_PUBKEY) + notaryServices.addSignature(stx, DUMMY_NOTARY_KEY.public) + } services.recordTransactions(commercialPaper2) val ccyIndex = builder { CommercialPaperSchemaV1.PersistentCommercialPaperState::currency.equal(USD.currencyCode) } @@ -1585,21 +1635,23 @@ class VaultQueryTests { // MegaCorp™ issues $10,000 of commercial paper, to mature in 30 days, owned by itself. val faceValue = 10000.DOLLARS `issued by` DUMMY_CASH_ISSUER val commercialPaper = - CommercialPaper().generateIssue(issuance, faceValue, TEST_TX_TIME + 30.days, DUMMY_NOTARY).apply { - setTimeWindow(TEST_TX_TIME, 30.seconds) - signWith(MEGA_CORP_KEY) - signWith(DUMMY_NOTARY_KEY) - }.toSignedTransaction() + CommercialPaper().generateIssue(issuance, faceValue, TEST_TX_TIME + 30.days, DUMMY_NOTARY).let { builder -> + builder.setTimeWindow(TEST_TX_TIME, 30.seconds) + val stx = services.signInitialTransaction(builder, MEGA_CORP_PUBKEY) + notaryServices.addSignature(stx, DUMMY_NOTARY_KEY.public) + } + commercialPaper.verifyRequiredSignatures() services.recordTransactions(commercialPaper) // MegaCorp™ now issues £5,000 of commercial paper, to mature in 30 days, owned by itself. val faceValue2 = 5000.POUNDS `issued by` DUMMY_CASH_ISSUER val commercialPaper2 = - CommercialPaper().generateIssue(issuance, faceValue2, TEST_TX_TIME + 30.days, DUMMY_NOTARY).apply { - setTimeWindow(TEST_TX_TIME, 30.seconds) - signWith(MEGA_CORP_KEY) - signWith(DUMMY_NOTARY_KEY) - }.toSignedTransaction() + CommercialPaper().generateIssue(issuance, faceValue2, TEST_TX_TIME + 30.days, DUMMY_NOTARY).let { builder -> + builder.setTimeWindow(TEST_TX_TIME, 30.seconds) + val stx = services.signInitialTransaction(builder, MEGA_CORP_PUBKEY) + notaryServices.addSignature(stx, DUMMY_NOTARY_KEY.public) + } + commercialPaper2.verifyRequiredSignatures() services.recordTransactions(commercialPaper2) val result = builder { @@ -1625,9 +1677,9 @@ class VaultQueryTests { fun `query attempting to use unregistered schema`() { database.transaction { - services.fillWithSomeTestCash(100.DOLLARS, DUMMY_NOTARY, 1, 1, Random(0L)) - services.fillWithSomeTestCash(100.POUNDS, DUMMY_NOTARY, 1, 1, Random(0L)) - services.fillWithSomeTestCash(100.SWISS_FRANCS, DUMMY_NOTARY, 1, 1, Random(0L)) + services.fillWithSomeTestCash(100.DOLLARS, notaryServices, DUMMY_NOTARY, 1, 1, Random(0L)) + services.fillWithSomeTestCash(100.POUNDS, notaryServices, DUMMY_NOTARY, 1, 1, Random(0L)) + services.fillWithSomeTestCash(100.SWISS_FRANCS, notaryServices, DUMMY_NOTARY, 1, 1, Random(0L)) // CashSchemaV3 NOT registered with NodeSchemaService val logicalExpression = builder { SampleCashSchemaV3.PersistentCashState::currency.equal(GBP.currencyCode) } @@ -1647,10 +1699,10 @@ class VaultQueryTests { database.transaction { - services.fillWithSomeTestCash(100.POUNDS, DUMMY_NOTARY, 1, 1, Random(0L)) - services.fillWithSomeTestCash(100.DOLLARS, DUMMY_NOTARY, 1, 1, Random(0L)) - services.fillWithSomeTestCash(10.DOLLARS, DUMMY_NOTARY, 1, 1, Random(0L)) - services.fillWithSomeTestCash(1.DOLLARS, DUMMY_NOTARY, 1, 1, Random(0L)) + services.fillWithSomeTestCash(100.POUNDS, notaryServices, DUMMY_NOTARY, 1, 1, Random(0L)) + services.fillWithSomeTestCash(100.DOLLARS, notaryServices, DUMMY_NOTARY, 1, 1, Random(0L)) + services.fillWithSomeTestCash(10.DOLLARS, notaryServices, DUMMY_NOTARY, 1, 1, Random(0L)) + services.fillWithSomeTestCash(1.DOLLARS, notaryServices, DUMMY_NOTARY, 1, 1, Random(0L)) // DOCSTART VaultQueryExample20 val generalCriteria = VaultQueryCriteria(Vault.StateStatus.ALL) @@ -1749,7 +1801,8 @@ class VaultQueryTests { database.transaction { services.fillWithSomeTestLinearStates(1, "TEST1") - services.fillWithSomeTestLinearStates(1, "TEST2") + val aState = services.fillWithSomeTestLinearStates(1, "TEST2").states + services.consumeLinearStates(aState.toList(), DUMMY_NOTARY) val uuid = services.fillWithSomeTestLinearStates(1, "TEST3").states.first().state.data.linearId.id // 2 unconsumed states with same external ID, 1 with different external ID @@ -1820,7 +1873,7 @@ class VaultQueryTests { val updates = database.transaction { - services.fillWithSomeTestCash(100.DOLLARS, DUMMY_NOTARY, 5, 5, Random(0L)) + services.fillWithSomeTestCash(100.DOLLARS, notaryServices, DUMMY_NOTARY, 5, 5, Random(0L)) val linearStates = services.fillWithSomeTestLinearStates(10).states val dealStates = services.fillWithSomeTestDeals(listOf("123", "456", "789")).states @@ -1832,14 +1885,14 @@ class VaultQueryTests { assertThat(snapshot.statesMetadata).hasSize(5) // add more cash - services.fillWithSomeTestCash(100.POUNDS, DUMMY_NOTARY, 1, 1, Random(0L)) + services.fillWithSomeTestCash(100.POUNDS, notaryServices, DUMMY_NOTARY, 1, 1, Random(0L)) // add another deal services.fillWithSomeTestDeals(listOf("SAMPLE DEAL")) // consume stuff - services.consumeCash(100.DOLLARS) - services.consumeDeals(dealStates.toList()) - services.consumeLinearStates(linearStates.toList()) + services.consumeCash(100.DOLLARS, notary = DUMMY_NOTARY) + services.consumeDeals(dealStates.toList(), DUMMY_NOTARY) + services.consumeLinearStates(linearStates.toList(), DUMMY_NOTARY) updates } @@ -1848,12 +1901,12 @@ class VaultQueryTests { sequence( expect { (consumed, produced, flowId) -> require(flowId == null) {} - require(consumed.size == 0) {} + require(consumed.isEmpty()) {} require(produced.size == 5) {} }, expect { (consumed, produced, flowId) -> require(flowId == null) {} - require(consumed.size == 0) {} + require(consumed.isEmpty()) {} require(produced.size == 1) {} } ) @@ -1865,17 +1918,17 @@ class VaultQueryTests { val updates = database.transaction { - services.fillWithSomeTestCash(100.DOLLARS, DUMMY_NOTARY, 5, 5, Random(0L)) + services.fillWithSomeTestCash(100.DOLLARS, notaryServices, DUMMY_NOTARY, 5, 5, Random(0L)) val linearStates = services.fillWithSomeTestLinearStates(10).states val dealStates = services.fillWithSomeTestDeals(listOf("123", "456", "789")).states // add more cash - services.fillWithSomeTestCash(100.POUNDS, DUMMY_NOTARY, 1, 1, Random(0L)) + services.fillWithSomeTestCash(100.POUNDS, notaryServices, DUMMY_NOTARY, 1, 1, Random(0L)) // add another deal services.fillWithSomeTestDeals(listOf("SAMPLE DEAL")) // consume stuff - services.consumeCash(100.POUNDS) + services.consumeCash(100.POUNDS, notary = DUMMY_NOTARY) val criteria = VaultQueryCriteria(status = Vault.StateStatus.CONSUMED) val (snapshot, updates) = vaultQuerySvc.trackBy(criteria) @@ -1884,9 +1937,9 @@ class VaultQueryTests { assertThat(snapshot.statesMetadata).hasSize(1) // consume more stuff - services.consumeCash(100.DOLLARS) - services.consumeDeals(dealStates.toList()) - services.consumeLinearStates(linearStates.toList()) + services.consumeCash(100.DOLLARS, notary = DUMMY_NOTARY) + services.consumeDeals(dealStates.toList(), DUMMY_NOTARY) + services.consumeLinearStates(linearStates.toList(), DUMMY_NOTARY) updates } @@ -1896,12 +1949,12 @@ class VaultQueryTests { expect { (consumed, produced, flowId) -> require(flowId == null) {} require(consumed.size == 1) {} - require(produced.size == 0) {} + require(produced.isEmpty()) {} }, expect { (consumed, produced, flowId) -> require(flowId == null) {} require(consumed.size == 5) {} - require(produced.size == 0) {} + require(produced.isEmpty()) {} } ) } @@ -1912,17 +1965,17 @@ class VaultQueryTests { val updates = database.transaction { - services.fillWithSomeTestCash(100.DOLLARS, DUMMY_NOTARY, 5, 5, Random(0L)) + services.fillWithSomeTestCash(100.DOLLARS, notaryServices, DUMMY_NOTARY, 5, 5, Random(0L)) val linearStates = services.fillWithSomeTestLinearStates(10).states val dealStates = services.fillWithSomeTestDeals(listOf("123", "456", "789")).states // add more cash - services.fillWithSomeTestCash(100.POUNDS, DUMMY_NOTARY, 1, 1, Random(0L)) + services.fillWithSomeTestCash(100.POUNDS, notaryServices, DUMMY_NOTARY, 1, 1, Random(0L)) // add another deal services.fillWithSomeTestDeals(listOf("SAMPLE DEAL")) // consume stuff - services.consumeCash(99.POUNDS) + services.consumeCash(99.POUNDS, notary = DUMMY_NOTARY) val criteria = VaultQueryCriteria(status = Vault.StateStatus.ALL) val (snapshot, updates) = vaultQuerySvc.trackBy(criteria) @@ -1931,9 +1984,9 @@ class VaultQueryTests { assertThat(snapshot.statesMetadata).hasSize(7) // consume more stuff - services.consumeCash(100.DOLLARS) - services.consumeDeals(dealStates.toList()) - services.consumeLinearStates(linearStates.toList()) + services.consumeCash(100.DOLLARS, notary = DUMMY_NOTARY) + services.consumeDeals(dealStates.toList(), DUMMY_NOTARY) + services.consumeLinearStates(linearStates.toList(), DUMMY_NOTARY) updates } @@ -1942,12 +1995,12 @@ class VaultQueryTests { sequence( expect { (consumed, produced, flowId) -> require(flowId == null) {} - require(consumed.size == 0) {} + require(consumed.isEmpty()) {} require(produced.size == 5) {} }, expect { (consumed, produced, flowId) -> require(flowId == null) {} - require(consumed.size == 0) {} + require(consumed.isEmpty()) {} require(produced.size == 1) {} }, expect { (consumed, produced, flowId) -> @@ -1958,7 +2011,7 @@ class VaultQueryTests { expect { (consumed, produced, flowId) -> require(flowId == null) {} require(consumed.size == 5) {} - require(produced.size == 0) {} + require(produced.isEmpty()) {} } ) } @@ -1969,7 +2022,7 @@ class VaultQueryTests { val updates = database.transaction { - services.fillWithSomeTestCash(100.DOLLARS, DUMMY_NOTARY, 3, 3, Random(0L)) + services.fillWithSomeTestCash(100.DOLLARS, notaryServices, DUMMY_NOTARY, 3, 3, Random(0L)) val linearStates = services.fillWithSomeTestLinearStates(10).states val dealStates = services.fillWithSomeTestDeals(listOf("123", "456", "789")).states @@ -1982,14 +2035,14 @@ class VaultQueryTests { assertThat(snapshot.statesMetadata).hasSize(13) // add more cash - services.fillWithSomeTestCash(100.POUNDS, DUMMY_NOTARY, 1, 1, Random(0L)) + services.fillWithSomeTestCash(100.POUNDS, notaryServices, DUMMY_NOTARY, 1, 1, Random(0L)) // add another deal services.fillWithSomeTestDeals(listOf("SAMPLE DEAL")) // consume stuff - services.consumeCash(100.DOLLARS) - services.consumeDeals(dealStates.toList()) - services.consumeLinearStates(linearStates.toList()) + services.consumeCash(100.DOLLARS, notary = DUMMY_NOTARY) + services.consumeDeals(dealStates.toList(), DUMMY_NOTARY) + services.consumeLinearStates(linearStates.toList(), DUMMY_NOTARY) updates } @@ -1998,17 +2051,17 @@ class VaultQueryTests { sequence( expect { (consumed, produced, flowId) -> require(flowId == null) {} - require(consumed.size == 0) {} + require(consumed.isEmpty()) {} require(produced.size == 10) {} }, expect { (consumed, produced, flowId) -> require(flowId == null) {} - require(consumed.size == 0) {} + require(consumed.isEmpty()) {} require(produced.size == 3) {} }, expect { (consumed, produced, flowId) -> require(flowId == null) {} - require(consumed.size == 0) {} + require(consumed.isEmpty()) {} require(produced.size == 1) {} } ) @@ -2020,7 +2073,7 @@ class VaultQueryTests { val updates = database.transaction { - services.fillWithSomeTestCash(100.DOLLARS, DUMMY_NOTARY, 3, 3, Random(0L)) + services.fillWithSomeTestCash(100.DOLLARS, notaryServices, DUMMY_NOTARY, 3, 3, Random(0L)) val linearStates = services.fillWithSomeTestLinearStates(10).states val dealStates = services.fillWithSomeTestDeals(listOf("123", "456", "789")).states @@ -2032,14 +2085,14 @@ class VaultQueryTests { assertThat(snapshot.statesMetadata).hasSize(3) // add more cash - services.fillWithSomeTestCash(100.POUNDS, DUMMY_NOTARY, 1, 1, Random(0L)) + services.fillWithSomeTestCash(100.POUNDS, notaryServices, DUMMY_NOTARY, 1, 1, Random(0L)) // add another deal services.fillWithSomeTestDeals(listOf("SAMPLE DEAL")) // consume stuff - services.consumeCash(100.DOLLARS) - services.consumeDeals(dealStates.toList()) - services.consumeLinearStates(linearStates.toList()) + services.consumeCash(100.DOLLARS, notary = DUMMY_NOTARY) + services.consumeDeals(dealStates.toList(), DUMMY_NOTARY) + services.consumeLinearStates(linearStates.toList(), DUMMY_NOTARY) updates } @@ -2048,12 +2101,12 @@ class VaultQueryTests { sequence( expect { (consumed, produced, flowId) -> require(flowId == null) {} - require(consumed.size == 0) {} + require(consumed.isEmpty()) {} require(produced.size == 3) {} }, expect { (consumed, produced, flowId) -> require(flowId == null) {} - require(consumed.size == 0) {} + require(consumed.isEmpty()) {} require(produced.size == 1) {} } ) diff --git a/node/src/test/kotlin/net/corda/node/services/vault/VaultWithCashTest.kt b/node/src/test/kotlin/net/corda/node/services/vault/VaultWithCashTest.kt index 4ca05bb8e9..0b35c7a775 100644 --- a/node/src/test/kotlin/net/corda/node/services/vault/VaultWithCashTest.kt +++ b/node/src/test/kotlin/net/corda/node/services/vault/VaultWithCashTest.kt @@ -1,85 +1,65 @@ package net.corda.node.services.vault -import net.corda.testing.contracts.DummyDealContract import net.corda.contracts.asset.Cash import net.corda.contracts.asset.DUMMY_CASH_ISSUER -import net.corda.testing.contracts.fillWithSomeTestCash -import net.corda.testing.contracts.fillWithSomeTestDeals -import net.corda.testing.contracts.fillWithSomeTestLinearStates +import net.corda.contracts.asset.DUMMY_CASH_ISSUER_KEY +import net.corda.contracts.getCashBalance import net.corda.core.contracts.* -import net.corda.testing.contracts.DummyLinearContract import net.corda.core.identity.AnonymousParty +import net.corda.core.node.services.Vault +import net.corda.core.node.services.VaultQueryService import net.corda.core.node.services.VaultService -import net.corda.core.node.services.consumedStates -import net.corda.core.node.services.unconsumedStates -import net.corda.core.transactions.SignedTransaction -import net.corda.testing.BOB -import net.corda.testing.DUMMY_NOTARY -import net.corda.testing.DUMMY_NOTARY_KEY -import net.corda.testing.LogHelper -import net.corda.node.utilities.configureDatabase -import net.corda.node.utilities.transaction -import net.corda.testing.MEGA_CORP -import net.corda.testing.MEGA_CORP_KEY +import net.corda.core.node.services.queryBy +import net.corda.core.node.services.vault.QueryCriteria +import net.corda.core.node.services.vault.QueryCriteria.VaultQueryCriteria +import net.corda.core.transactions.TransactionBuilder +import net.corda.node.utilities.CordaPersistence +import net.corda.testing.* +import net.corda.testing.contracts.* import net.corda.testing.node.MockServices -import net.corda.testing.node.makeTestDataSourceProperties +import net.corda.testing.node.makeTestDatabaseAndMockServices import org.assertj.core.api.Assertions.assertThat import org.assertj.core.api.Assertions.assertThatThrownBy -import org.jetbrains.exposed.sql.Database import org.junit.After import org.junit.Before import org.junit.Test -import java.io.Closeable import java.util.* import java.util.concurrent.CountDownLatch import java.util.concurrent.Executors import kotlin.test.assertEquals -import kotlin.test.assertNull // TODO: Move this to the cash contract tests once mock services are further split up. -class VaultWithCashTest { +class VaultWithCashTest : TestDependencyInjectionBase() { lateinit var services: MockServices + lateinit var issuerServices: MockServices val vault: VaultService get() = services.vaultService - lateinit var dataSource: Closeable - lateinit var database: Database + val vaultQuery: VaultQueryService get() = services.vaultQueryService + lateinit var database: CordaPersistence val notaryServices = MockServices(DUMMY_NOTARY_KEY) @Before fun setUp() { LogHelper.setLevel(VaultWithCashTest::class) - val dataSourceProps = makeTestDataSourceProperties() - val dataSourceAndDatabase = configureDatabase(dataSourceProps) - dataSource = dataSourceAndDatabase.first - database = dataSourceAndDatabase.second - database.transaction { - services = object : MockServices() { - override val vaultService: VaultService = makeVaultService(dataSourceProps) - - override fun recordTransactions(txs: Iterable) { - for (stx in txs) { - validatedTransactions.addTransaction(stx) - } - // Refactored to use notifyAll() as we have no other unit test for that method with multiple transactions. - vaultService.notifyAll(txs.map { it.tx }) - } - } - } + val databaseAndServices = makeTestDatabaseAndMockServices(keys = listOf(DUMMY_CASH_ISSUER_KEY, DUMMY_NOTARY_KEY)) + database = databaseAndServices.first + services = databaseAndServices.second + issuerServices = MockServices(DUMMY_CASH_ISSUER_KEY, MEGA_CORP_KEY) } @After fun tearDown() { LogHelper.reset(VaultWithCashTest::class) - dataSource.close() + database.close() } @Test fun splits() { database.transaction { // Fix the PRNG so that we get the same splits every time. - services.fillWithSomeTestCash(100.DOLLARS, DUMMY_NOTARY, 3, 3, Random(0L)) + services.fillWithSomeTestCash(100.DOLLARS, issuerServices, DUMMY_NOTARY, 3, 3, Random(0L), issuedBy = DUMMY_CASH_ISSUER) - val w = vault.unconsumedStates().toList() + val w = vaultQuery.queryBy().states assertEquals(3, w.size) val state = w[0].state.data @@ -98,33 +78,33 @@ class VaultWithCashTest { database.transaction { // A tx that sends us money. val freshKey = services.keyManagementService.freshKey() - val usefulBuilder = TransactionType.General.Builder(null) + val usefulBuilder = TransactionBuilder(null) Cash().generateIssue(usefulBuilder, 100.DOLLARS `issued by` MEGA_CORP.ref(1), AnonymousParty(freshKey), DUMMY_NOTARY) val usefulTX = megaCorpServices.signInitialTransaction(usefulBuilder) - assertNull(vault.cashBalances[USD]) + assertEquals(0.DOLLARS, services.getCashBalance(USD)) services.recordTransactions(usefulTX) // A tx that spends our money. - val spendTXBuilder = TransactionType.General.Builder(DUMMY_NOTARY) - vault.generateSpend(spendTXBuilder, 80.DOLLARS, BOB) + val spendTXBuilder = TransactionBuilder(DUMMY_NOTARY) + Cash.generateSpend(services, spendTXBuilder, 80.DOLLARS, BOB) val spendPTX = services.signInitialTransaction(spendTXBuilder, freshKey) val spendTX = notaryServices.addSignature(spendPTX) - assertEquals(100.DOLLARS, vault.cashBalances[USD]) + assertEquals(100.DOLLARS, services.getCashBalance(USD)) // A tx that doesn't send us anything. - val irrelevantBuilder = TransactionType.General.Builder(DUMMY_NOTARY) + val irrelevantBuilder = TransactionBuilder(DUMMY_NOTARY) Cash().generateIssue(irrelevantBuilder, 100.DOLLARS `issued by` MEGA_CORP.ref(1), BOB, DUMMY_NOTARY) val irrelevantPTX = megaCorpServices.signInitialTransaction(irrelevantBuilder) val irrelevantTX = notaryServices.addSignature(irrelevantPTX) services.recordTransactions(irrelevantTX) - assertEquals(100.DOLLARS, vault.cashBalances[USD]) + assertEquals(100.DOLLARS, services.getCashBalance(USD)) services.recordTransactions(spendTX) - assertEquals(20.DOLLARS, vault.cashBalances[USD]) + assertEquals(20.DOLLARS, services.getCashBalance(USD)) // TODO: Flesh out these tests as needed. } @@ -133,41 +113,47 @@ class VaultWithCashTest { @Test fun `issue and attempt double spend`() { val freshKey = services.keyManagementService.freshKey() + val criteriaLocked = VaultQueryCriteria(softLockingCondition = QueryCriteria.SoftLockingCondition(QueryCriteria.SoftLockingType.LOCKED_ONLY)) database.transaction { // A tx that sends us money. - services.fillWithSomeTestCash(100.DOLLARS, DUMMY_NOTARY, 10, 10, Random(0L), - issuedBy = MEGA_CORP.ref(1), - issuerKey = MEGA_CORP_KEY, - ownedBy = AnonymousParty(freshKey)) - println("Cash balance: ${vault.cashBalances[USD]}") + services.fillWithSomeTestCash(100.DOLLARS, issuerServices, DUMMY_NOTARY, 10, 10, Random(0L), ownedBy = AnonymousParty(freshKey), + issuedBy = MEGA_CORP.ref(1)) + println("Cash balance: ${services.getCashBalance(USD)}") - assertThat(vault.unconsumedStates()).hasSize(10) - assertThat(vault.softLockedStates()).hasSize(0) + assertThat(vaultQuery.queryBy().states).hasSize(10) + assertThat(vaultQuery.queryBy(criteriaLocked).states).hasSize(0) } val backgroundExecutor = Executors.newFixedThreadPool(2) val countDown = CountDownLatch(2) + // 1st tx that spends our money. backgroundExecutor.submit { database.transaction { try { - val txn1Builder = TransactionType.General.Builder(DUMMY_NOTARY) - vault.generateSpend(txn1Builder, 60.DOLLARS, BOB) + val txn1Builder = TransactionBuilder(DUMMY_NOTARY) + Cash.generateSpend(services, txn1Builder, 60.DOLLARS, BOB) val ptxn1 = notaryServices.signInitialTransaction(txn1Builder) val txn1 = services.addSignature(ptxn1, freshKey) println("txn1: ${txn1.id} spent ${((txn1.tx.outputs[0].data) as Cash.State).amount}") + val unconsumedStates1 = vaultQuery.queryBy() + val consumedStates1 = vaultQuery.queryBy(VaultQueryCriteria(status = Vault.StateStatus.CONSUMED)) + val lockedStates1 = vaultQuery.queryBy(criteriaLocked).states println("""txn1 states: - UNCONSUMED: ${vault.unconsumedStates().count()} : ${vault.unconsumedStates()}, - CONSUMED: ${vault.consumedStates().count()} : ${vault.consumedStates()}, - LOCKED: ${vault.softLockedStates().count()} : ${vault.softLockedStates()} + UNCONSUMED: ${unconsumedStates1.totalStatesAvailable} : $unconsumedStates1, + CONSUMED: ${consumedStates1.totalStatesAvailable} : $consumedStates1, + LOCKED: ${lockedStates1.count()} : $lockedStates1 """) services.recordTransactions(txn1) - println("txn1: Cash balance: ${vault.cashBalances[USD]}") + println("txn1: Cash balance: ${services.getCashBalance(USD)}") + val unconsumedStates2 = vaultQuery.queryBy() + val consumedStates2 = vaultQuery.queryBy(VaultQueryCriteria(status = Vault.StateStatus.CONSUMED)) + val lockedStates2 = vaultQuery.queryBy(criteriaLocked).states println("""txn1 states: - UNCONSUMED: ${vault.unconsumedStates().count()} : ${vault.unconsumedStates()}, - CONSUMED: ${vault.consumedStates().count()} : ${vault.consumedStates()}, - LOCKED: ${vault.softLockedStates().count()} : ${vault.softLockedStates()} + UNCONSUMED: ${unconsumedStates2.totalStatesAvailable} : $unconsumedStates2, + CONSUMED: ${consumedStates2.totalStatesAvailable} : $consumedStates2, + LOCKED: ${lockedStates2.count()} : $lockedStates2 """) txn1 } catch(e: Exception) { @@ -182,22 +168,28 @@ class VaultWithCashTest { backgroundExecutor.submit { database.transaction { try { - val txn2Builder = TransactionType.General.Builder(DUMMY_NOTARY) - vault.generateSpend(txn2Builder, 80.DOLLARS, BOB) + val txn2Builder = TransactionBuilder(DUMMY_NOTARY) + Cash.generateSpend(services, txn2Builder, 80.DOLLARS, BOB) val ptxn2 = notaryServices.signInitialTransaction(txn2Builder) val txn2 = services.addSignature(ptxn2, freshKey) println("txn2: ${txn2.id} spent ${((txn2.tx.outputs[0].data) as Cash.State).amount}") + val unconsumedStates1 = vaultQuery.queryBy() + val consumedStates1 = vaultQuery.queryBy(VaultQueryCriteria(status = Vault.StateStatus.CONSUMED)) + val lockedStates1 = vaultQuery.queryBy(criteriaLocked).states println("""txn2 states: - UNCONSUMED: ${vault.unconsumedStates().count()} : ${vault.unconsumedStates()}, - CONSUMED: ${vault.consumedStates().count()} : ${vault.consumedStates()}, - LOCKED: ${vault.softLockedStates().count()} : ${vault.softLockedStates()} + UNCONSUMED: ${unconsumedStates1.totalStatesAvailable} : $unconsumedStates1, + CONSUMED: ${consumedStates1.totalStatesAvailable} : $consumedStates1, + LOCKED: ${lockedStates1.count()} : $lockedStates1 """) services.recordTransactions(txn2) - println("txn2: Cash balance: ${vault.cashBalances[USD]}") + println("txn2: Cash balance: ${services.getCashBalance(USD)}") + val unconsumedStates2 = vaultQuery.queryBy() + val consumedStates2 = vaultQuery.queryBy() + val lockedStates2 = vaultQuery.queryBy(criteriaLocked).states println("""txn2 states: - UNCONSUMED: ${vault.unconsumedStates().count()} : ${vault.unconsumedStates()}, - CONSUMED: ${vault.consumedStates().count()} : ${vault.consumedStates()}, - LOCKED: ${vault.softLockedStates().count()} : ${vault.softLockedStates()} + UNCONSUMED: ${unconsumedStates2.totalStatesAvailable} : $unconsumedStates2, + CONSUMED: ${consumedStates2.totalStatesAvailable} : $consumedStates2, + LOCKED: ${lockedStates2.count()} : $lockedStates2 """) txn2 } catch(e: Exception) { @@ -211,8 +203,8 @@ class VaultWithCashTest { countDown.await() database.transaction { - println("Cash balance: ${vault.cashBalances[USD]}") - assertThat(vault.cashBalances[USD]).isIn(DOLLARS(20), DOLLARS(40)) + println("Cash balance: ${services.getCashBalance(USD)}") + assertThat(services.getCashBalance(USD)).isIn(DOLLARS(20), DOLLARS(40)) } } @@ -224,7 +216,7 @@ class VaultWithCashTest { val linearId = UniqueIdentifier() // Issue a linear state - val dummyIssueBuilder = TransactionType.General.Builder(notary = DUMMY_NOTARY).apply { + val dummyIssueBuilder = TransactionBuilder(notary = DUMMY_NOTARY).apply { addOutputState(DummyLinearContract.State(linearId = linearId, participants = listOf(freshIdentity))) addOutputState(DummyLinearContract.State(linearId = linearId, participants = listOf(freshIdentity))) } @@ -245,7 +237,7 @@ class VaultWithCashTest { val linearId = UniqueIdentifier() // Issue a linear state - val dummyIssueBuilder = TransactionType.General.Builder(notary = DUMMY_NOTARY) + val dummyIssueBuilder = TransactionBuilder(notary = DUMMY_NOTARY) dummyIssueBuilder.addOutputState(DummyLinearContract.State(linearId = linearId, participants = listOf(freshIdentity))) val dummyIssuePtx = notaryServices.signInitialTransaction(dummyIssueBuilder) val dummyIssue = services.addSignature(dummyIssuePtx) @@ -253,10 +245,10 @@ class VaultWithCashTest { dummyIssue.toLedgerTransaction(services).verify() services.recordTransactions(dummyIssue) - assertThat(vault.unconsumedStates()).hasSize(1) + assertThat(vaultQuery.queryBy().states).hasSize(1) // Move the same state - val dummyMoveBuilder = TransactionType.General.Builder(notary = DUMMY_NOTARY).apply { + val dummyMoveBuilder = TransactionBuilder(notary = DUMMY_NOTARY).apply { addOutputState(DummyLinearContract.State(linearId = linearId, participants = listOf(freshIdentity))) addInputState(dummyIssue.tx.outRef(0)) } @@ -266,7 +258,7 @@ class VaultWithCashTest { dummyIssue.toLedgerTransaction(services).verify() services.recordTransactions(dummyMove) - assertThat(vault.unconsumedStates()).hasSize(1) + assertThat(vaultQuery.queryBy().states).hasSize(1) } } @@ -275,29 +267,29 @@ class VaultWithCashTest { val freshKey = services.keyManagementService.freshKey() database.transaction { - services.fillWithSomeTestCash(100.DOLLARS, DUMMY_NOTARY, 3, 3, Random(0L), ownedBy = AnonymousParty(freshKey)) - services.fillWithSomeTestCash(100.SWISS_FRANCS, DUMMY_NOTARY, 2, 2, Random(0L)) - services.fillWithSomeTestCash(100.POUNDS, DUMMY_NOTARY, 1, 1, Random(0L)) - val cash = vault.unconsumedStates() + services.fillWithSomeTestCash(100.DOLLARS, issuerServices, DUMMY_NOTARY, 3, 3, Random(0L), ownedBy = AnonymousParty(freshKey)) + services.fillWithSomeTestCash(100.SWISS_FRANCS, issuerServices, DUMMY_NOTARY, 2, 2, Random(0L)) + services.fillWithSomeTestCash(100.POUNDS, issuerServices, DUMMY_NOTARY, 1, 1, Random(0L)) + val cash = vaultQuery.queryBy().states cash.forEach { println(it.state.data.amount) } services.fillWithSomeTestDeals(listOf("123", "456", "789")) - val deals = vault.unconsumedStates() - deals.forEach { println(it.state.data.ref) } + val deals = vaultQuery.queryBy().states + deals.forEach { println(it.state.data.linearId.externalId!!) } } database.transaction { // A tx that spends our money. - val spendTXBuilder = TransactionType.General.Builder(DUMMY_NOTARY) - vault.generateSpend(spendTXBuilder, 80.DOLLARS, BOB) + val spendTXBuilder = TransactionBuilder(DUMMY_NOTARY) + Cash.generateSpend(services, spendTXBuilder, 80.DOLLARS, BOB) val spendPTX = notaryServices.signInitialTransaction(spendTXBuilder) val spendTX = services.addSignature(spendPTX, freshKey) services.recordTransactions(spendTX) - val consumedStates = vault.consumedStates() + val consumedStates = vaultQuery.queryBy(VaultQueryCriteria(status = Vault.StateStatus.CONSUMED)).states assertEquals(3, consumedStates.count()) - val unconsumedStates = vault.unconsumedStates() + val unconsumedStates = vaultQuery.queryBy().states assertEquals(7, unconsumedStates.count()) } } @@ -310,15 +302,15 @@ class VaultWithCashTest { database.transaction { services.fillWithSomeTestDeals(listOf("123", "456", "789")) - val deals = vault.unconsumedStates().toList() - deals.forEach { println(it.state.data.ref) } + val deals = vaultQuery.queryBy().states + deals.forEach { println(it.state.data.linearId.externalId!!) } services.fillWithSomeTestLinearStates(3) - val linearStates = vault.unconsumedStates().toList() + val linearStates = vaultQuery.queryBy().states linearStates.forEach { println(it.state.data.linearId) } // Create a txn consuming different contract types - val dummyMoveBuilder = TransactionType.General.Builder(notary = DUMMY_NOTARY).apply { + val dummyMoveBuilder = TransactionBuilder(notary = DUMMY_NOTARY).apply { addOutputState(DummyLinearContract.State(participants = listOf(freshIdentity))) addOutputState(DummyDealContract.State(ref = "999", participants = listOf(freshIdentity))) addInputState(linearStates.first()) @@ -330,10 +322,10 @@ class VaultWithCashTest { dummyMove.toLedgerTransaction(services).verify() services.recordTransactions(dummyMove) - val consumedStates = vault.consumedStates() + val consumedStates = vaultQuery.queryBy(VaultQueryCriteria(status = Vault.StateStatus.CONSUMED)).states assertEquals(2, consumedStates.count()) - val unconsumedStates = vault.unconsumedStates() + val unconsumedStates = vaultQuery.queryBy().states assertEquals(6, unconsumedStates.count()) } } diff --git a/node/src/test/kotlin/net/corda/node/shell/CustomTypeJsonParsingTests.kt b/node/src/test/kotlin/net/corda/node/shell/CustomTypeJsonParsingTests.kt new file mode 100644 index 0000000000..1622f839f8 --- /dev/null +++ b/node/src/test/kotlin/net/corda/node/shell/CustomTypeJsonParsingTests.kt @@ -0,0 +1,71 @@ +package net.corda.node.shell + +import com.fasterxml.jackson.databind.JsonMappingException +import com.fasterxml.jackson.databind.ObjectMapper +import com.fasterxml.jackson.databind.module.SimpleModule +import com.fasterxml.jackson.module.kotlin.readValue +import net.corda.core.contracts.UniqueIdentifier +import org.junit.Before +import org.junit.Test +import java.util.* +import kotlin.test.assertEquals + +class CustomTypeJsonParsingTests { + lateinit var objectMapper: ObjectMapper + + //Dummy classes for testing. + data class State(val linearId: UniqueIdentifier) { + constructor() : this(UniqueIdentifier("required-for-json-deserializer")) + } + + data class UuidState(val uuid: UUID) { + //Default constructor required for json deserializer. + constructor() : this(UUID.randomUUID()) + } + + @Before + fun setup() { + objectMapper = ObjectMapper() + val simpleModule = SimpleModule() + simpleModule.addDeserializer(UniqueIdentifier::class.java, InteractiveShell.UniqueIdentifierDeserializer) + simpleModule.addDeserializer(UUID::class.java, InteractiveShell.UUIDDeserializer) + objectMapper.registerModule(simpleModule) + } + + @Test + fun `Deserializing UniqueIdentifier by parsing string`() { + val json = """{"linearId":"26b37265-a1fd-4c77-b2e0-715917ef619f"}""" + val state = objectMapper.readValue(json) + + assertEquals("26b37265-a1fd-4c77-b2e0-715917ef619f", state.linearId.externalId) + } + + @Test + fun `Deserializing UniqueIdentifier by parsing string with underscore`() { + val json = """{"linearId":"extkey564_26b37265-a1fd-4c77-b2e0-715917ef619f"}""" + val state = objectMapper.readValue(json) + + assertEquals("extkey564", state.linearId.externalId) + assertEquals("26b37265-a1fd-4c77-b2e0-715917ef619f", state.linearId.id.toString()) + } + + @Test(expected = JsonMappingException::class) + fun `Deserializing by parsing string contain invalid uuid with underscore`() { + val json = """{"linearId":"extkey564_26b37265-a1fd-4c77-b2e0"}""" + objectMapper.readValue(json) + } + + @Test + fun `Deserializing UUID by parsing string`() { + val json = """{"uuid":"26b37265-a1fd-4c77-b2e0-715917ef619f"}""" + val state = objectMapper.readValue(json) + + assertEquals("26b37265-a1fd-4c77-b2e0-715917ef619f", state.uuid.toString()) + } + + @Test(expected = JsonMappingException::class) + fun `Deserializing UUID by parsing invalid uuid string`() { + val json = """{"uuid":"26b37265-a1fd-4c77-b2e0"}""" + objectMapper.readValue(json) + } +} \ No newline at end of file diff --git a/node/src/test/kotlin/net/corda/node/utilities/ClockUtilsTest.kt b/node/src/test/kotlin/net/corda/node/utilities/ClockUtilsTest.kt index 38a5c91d67..23723f70f2 100644 --- a/node/src/test/kotlin/net/corda/node/utilities/ClockUtilsTest.kt +++ b/node/src/test/kotlin/net/corda/node/utilities/ClockUtilsTest.kt @@ -5,7 +5,9 @@ import co.paralleluniverse.fibers.FiberExecutorScheduler import co.paralleluniverse.fibers.Suspendable import co.paralleluniverse.strands.Strand import com.google.common.util.concurrent.SettableFuture -import net.corda.core.getOrThrow +import net.corda.core.utilities.getOrThrow +import net.corda.core.utilities.hours +import net.corda.core.utilities.minutes import net.corda.testing.node.TestClock import org.junit.After import org.junit.Before @@ -44,7 +46,7 @@ class ClockUtilsTest { @Test fun `test waiting negative time for a deadline`() { - assertFalse(stoppedClock.awaitWithDeadline(stoppedClock.instant().minus(Duration.ofHours(1))), "Should have reached deadline") + assertFalse(stoppedClock.awaitWithDeadline(stoppedClock.instant().minus(1.hours)), "Should have reached deadline") } @Test @@ -56,13 +58,13 @@ class ClockUtilsTest { @Test fun `test waiting negative time for a deadline with incomplete future`() { val future = SettableFuture.create() - assertFalse(stoppedClock.awaitWithDeadline(stoppedClock.instant().minus(Duration.ofHours(1)), future), "Should have reached deadline") + assertFalse(stoppedClock.awaitWithDeadline(stoppedClock.instant().minus(1.hours), future), "Should have reached deadline") } @Test fun `test waiting for a deadline with future completed before wait`() { - val advancedClock = Clock.offset(stoppedClock, Duration.ofHours(1)) + val advancedClock = Clock.offset(stoppedClock, 1.hours) val future = SettableFuture.create() completeNow(future) assertTrue(stoppedClock.awaitWithDeadline(advancedClock.instant(), future), "Should not have reached deadline") @@ -70,7 +72,7 @@ class ClockUtilsTest { @Test fun `test waiting for a deadline with future completed after wait`() { - val advancedClock = Clock.offset(stoppedClock, Duration.ofHours(1)) + val advancedClock = Clock.offset(stoppedClock, 1.hours) val future = SettableFuture.create() completeAfterWaiting(future) assertTrue(stoppedClock.awaitWithDeadline(advancedClock.instant(), future), "Should not have reached deadline") @@ -78,38 +80,38 @@ class ClockUtilsTest { @Test fun `test waiting for a deadline with clock advance`() { - val advancedClock = Clock.offset(stoppedClock, Duration.ofHours(1)) + val advancedClock = Clock.offset(stoppedClock, 1.hours) val testClock = TestClock(stoppedClock) - advanceClockAfterWait(testClock, Duration.ofHours(1)) + advanceClockAfterWait(testClock, 1.hours) assertFalse(testClock.awaitWithDeadline(advancedClock.instant()), "Should have reached deadline") } @Test fun `test waiting for a deadline with clock advance and incomplete future`() { - val advancedClock = Clock.offset(stoppedClock, Duration.ofHours(1)) + val advancedClock = Clock.offset(stoppedClock, 1.hours) val testClock = TestClock(stoppedClock) val future = SettableFuture.create() - advanceClockAfterWait(testClock, Duration.ofHours(1)) + advanceClockAfterWait(testClock, 1.hours) assertFalse(testClock.awaitWithDeadline(advancedClock.instant(), future), "Should have reached deadline") } @Test fun `test waiting for a deadline with clock advance and complete future`() { - val advancedClock = Clock.offset(stoppedClock, Duration.ofHours(2)) + val advancedClock = Clock.offset(stoppedClock, 2.hours) val testClock = TestClock(stoppedClock) val future = SettableFuture.create() - advanceClockAfterWait(testClock, Duration.ofHours(1)) + advanceClockAfterWait(testClock, 1.hours) completeAfterWaiting(future) assertTrue(testClock.awaitWithDeadline(advancedClock.instant(), future), "Should not have reached deadline") } @Test fun `test waiting for a deadline with multiple clock advance and incomplete future`() { - val advancedClock = Clock.offset(stoppedClock, Duration.ofHours(1)) + val advancedClock = Clock.offset(stoppedClock, 1.hours) val testClock = TestClock(stoppedClock) val future = SettableFuture.create() for (advance in 1..6) { - advanceClockAfterWait(testClock, Duration.ofMinutes(10)) + advanceClockAfterWait(testClock, 10.minutes) } assertFalse(testClock.awaitWithDeadline(advancedClock.instant(), future), "Should have reached deadline") } @@ -126,7 +128,7 @@ class ClockUtilsTest { } val testClock = TestClock(stoppedClock) - val advancedClock = Clock.offset(stoppedClock, Duration.ofHours(10)) + val advancedClock = Clock.offset(stoppedClock, 10.hours) try { testClock.awaitWithDeadline(advancedClock.instant(), SettableFuture.create()) @@ -138,7 +140,7 @@ class ClockUtilsTest { @Test @Suspendable fun `test waiting for a deadline with multiple clock advance and incomplete JDK8 future on Fibers`() { - val advancedClock = Clock.offset(stoppedClock, Duration.ofHours(1)) + val advancedClock = Clock.offset(stoppedClock, 1.hours) val testClock = TestClock(stoppedClock) val future = CompletableFuture() val scheduler = FiberExecutorScheduler("test", executor) @@ -151,7 +153,7 @@ class ClockUtilsTest { while (fiber.state != Strand.State.TIMED_WAITING) { Strand.sleep(1) } - testClock.advanceBy(Duration.ofMinutes(10)) + testClock.advanceBy(10.minutes) }).start() } assertFalse(future.getOrThrow(), "Should have reached deadline") @@ -160,7 +162,7 @@ class ClockUtilsTest { @Test @Suspendable fun `test waiting for a deadline with multiple clock advance and incomplete Guava future on Fibers`() { - val advancedClock = Clock.offset(stoppedClock, Duration.ofHours(1)) + val advancedClock = Clock.offset(stoppedClock, 1.hours) val testClock = TestClock(stoppedClock) val future = SettableFuture.create() val scheduler = FiberExecutorScheduler("test", executor) @@ -173,7 +175,7 @@ class ClockUtilsTest { while (fiber.state != Strand.State.TIMED_WAITING) { Strand.sleep(1) } - testClock.advanceBy(Duration.ofMinutes(10)) + testClock.advanceBy(10.minutes) }).start() } assertFalse(future.getOrThrow(), "Should have reached deadline") diff --git a/node/src/test/kotlin/net/corda/node/utilities/FiberBoxTest.kt b/node/src/test/kotlin/net/corda/node/utilities/FiberBoxTest.kt deleted file mode 100644 index 80211f61aa..0000000000 --- a/node/src/test/kotlin/net/corda/node/utilities/FiberBoxTest.kt +++ /dev/null @@ -1,167 +0,0 @@ -package net.corda.node.utilities - -import co.paralleluniverse.fibers.FiberExecutorScheduler -import co.paralleluniverse.fibers.Suspendable -import co.paralleluniverse.strands.Strand -import net.corda.core.RetryableException -import net.corda.core.getOrThrow -import net.corda.testing.node.TestClock -import org.junit.After -import org.junit.Before -import org.junit.Test -import java.time.Clock -import java.time.Duration -import java.util.concurrent.CompletableFuture -import java.util.concurrent.ExecutorService -import java.util.concurrent.Executors -import kotlin.test.assertEquals - -class FiberBoxTest { - - class Content { - var integer: Int = 0 - } - - class TestRetryableException(message: String) : RetryableException(message) - - lateinit var mutex: FiberBox - lateinit var realClock: Clock - lateinit var stoppedClock: Clock - lateinit var executor: ExecutorService - - @Before - fun setup() { - mutex = FiberBox(Content()) - realClock = Clock.systemUTC() - stoppedClock = Clock.fixed(realClock.instant(), realClock.zone) - executor = Executors.newSingleThreadExecutor() - } - - @After - fun teardown() { - executor.shutdown() - } - - @Test - fun `write and read`() { - mutex.write { integer = 1 } - assertEquals(1, mutex.read { integer }) - } - - @Test - fun `readWithDeadline with no wait`() { - val advancedClock = Clock.offset(stoppedClock, Duration.ofHours(1)) - - mutex.write { integer = 1 } - assertEquals(1, mutex.readWithDeadline(realClock, advancedClock.instant()) { integer }) - } - - @Test - fun `readWithDeadline with stopped clock and background write`() { - val advancedClock = Clock.offset(stoppedClock, Duration.ofHours(1)) - - assertEquals(1, mutex.readWithDeadline(stoppedClock, advancedClock.instant()) { - backgroundWrite() - if (integer == 1) 1 else throw TestRetryableException("Not 1") - }) - } - - @Test(expected = TestRetryableException::class) - fun `readWithDeadline with clock advanced`() { - val advancedClock = Clock.offset(stoppedClock, Duration.ofHours(1)) - val testClock = TestClock(stoppedClock) - - assertEquals(1, mutex.readWithDeadline(testClock, advancedClock.instant()) { - backgroundAdvanceClock(testClock, Duration.ofHours(1)) - if (integer == 1) 0 else throw TestRetryableException("Not 1") - }) - } - - @Test - fun `readWithDeadline with clock advanced 5x and background write`() { - val advancedClock = Clock.offset(stoppedClock, Duration.ofHours(1)) - val testClock = TestClock(stoppedClock) - - assertEquals(5, mutex.readWithDeadline(testClock, advancedClock.instant()) { - backgroundAdvanceClock(testClock, Duration.ofMinutes(10)) - backgroundWrite() - if (integer == 5) 5 else throw TestRetryableException("Not 5") - }) - } - - /** - * If this test seems to hang and throw an NPE, then likely that quasar suspendables scanner has not been - * run on core module (in IntelliJ, open gradle side tab and run: - * r3prototyping -> core -> Tasks -> other -> quasarScan - */ - @Test(expected = TestRetryableException::class) - @Suspendable - fun `readWithDeadline with clock advanced on Fibers`() { - val advancedClock = Clock.offset(stoppedClock, Duration.ofHours(1)) - val testClock = TestClock(stoppedClock) - val future = CompletableFuture() - val scheduler = FiberExecutorScheduler("test", executor) - val fiber = scheduler.newFiber(@Suspendable { - try { - future.complete(mutex.readWithDeadline(testClock, advancedClock.instant()) { - if (integer == 1) 1 else throw TestRetryableException("Not 1") - }) - } catch(e: Exception) { - future.completeExceptionally(e) - } - }).start() - for (advance in 1..6) { - scheduler.newFiber(@Suspendable { - // Wait until fiber is waiting - while (fiber.state != Strand.State.TIMED_WAITING) { - Strand.sleep(1) - } - testClock.advanceBy(Duration.ofMinutes(10)) - }).start() - } - assertEquals(2, future.getOrThrow()) - } - - /** - * If this test seems to hang and throw an NPE, then likely that quasar suspendables scanner has not been - * run on core module (in IntelliJ, open gradle side tab and run: - * r3prototyping -> core -> Tasks -> other -> quasarScan - */ - @Test - @Suspendable - fun `readWithDeadline with background write on Fibers`() { - val advancedClock = Clock.offset(stoppedClock, Duration.ofHours(1)) - val testClock = TestClock(stoppedClock) - val future = CompletableFuture() - val scheduler = FiberExecutorScheduler("test", executor) - val fiber = scheduler.newFiber(@Suspendable { - try { - future.complete(mutex.readWithDeadline(testClock, advancedClock.instant()) { - if (integer == 1) 1 else throw TestRetryableException("Not 1") - }) - } catch(e: Exception) { - future.completeExceptionally(e) - } - }).start() - scheduler.newFiber(@Suspendable { - // Wait until fiber is waiting - while (fiber.state != Strand.State.TIMED_WAITING) { - Strand.sleep(1) - } - mutex.write { integer = 1 } - }).start() - assertEquals(1, future.getOrThrow()) - } - - private fun backgroundWrite() { - executor.execute { - mutex.write { integer += 1 } - } - } - - private fun backgroundAdvanceClock(clock: TestClock, duration: Duration) { - executor.execute { - clock.advanceBy(duration) - } - } -} diff --git a/node/src/test/kotlin/net/corda/node/utilities/ObservablesTests.kt b/node/src/test/kotlin/net/corda/node/utilities/ObservablesTests.kt index 8e0c666796..0d99e3f3b8 100644 --- a/node/src/test/kotlin/net/corda/node/utilities/ObservablesTests.kt +++ b/node/src/test/kotlin/net/corda/node/utilities/ObservablesTests.kt @@ -1,12 +1,12 @@ package net.corda.node.utilities import com.google.common.util.concurrent.SettableFuture -import net.corda.core.bufferUntilSubscribed -import net.corda.core.tee +import net.corda.core.internal.bufferUntilSubscribed +import net.corda.core.internal.tee import net.corda.testing.node.makeTestDataSourceProperties +import net.corda.testing.node.makeTestDatabaseProperties +import net.corda.testing.node.makeTestIdentityService import org.assertj.core.api.Assertions.assertThat -import org.jetbrains.exposed.sql.Database -import org.jetbrains.exposed.sql.transactions.TransactionManager import org.junit.After import org.junit.Test import rx.Observable @@ -16,13 +16,13 @@ import java.util.* class ObservablesTests { - private fun isInDatabaseTransaction(): Boolean = (TransactionManager.currentOrNull() != null) + private fun isInDatabaseTransaction(): Boolean = (DatabaseTransactionManager.currentOrNull() != null) val toBeClosed = mutableListOf() - fun createDatabase(): Database { - val (closeable, database) = configureDatabase(makeTestDataSourceProperties()) - toBeClosed += closeable + fun createDatabase(): CordaPersistence { + val database = configureDatabase(makeTestDataSourceProperties(), makeTestDatabaseProperties(), identitySvc = ::makeTestIdentityService) + toBeClosed += database return database } @@ -167,7 +167,7 @@ class ObservablesTests { observableWithDbTx.first().subscribe { undelayedEvent.set(it to isInDatabaseTransaction()) } fun observeSecondEvent(event: Int, future: SettableFuture>) { - future.set(event to if (isInDatabaseTransaction()) StrandLocalTransactionManager.transactionId else null) + future.set(event to if (isInDatabaseTransaction()) DatabaseTransactionManager.transactionId else null) } observableWithDbTx.skip(1).first().subscribe { observeSecondEvent(it, delayedEventFromSecondObserver) } diff --git a/core/src/test/kotlin/net/corda/core/crypto/X509UtilitiesTest.kt b/node/src/test/kotlin/net/corda/node/utilities/X509UtilitiesTest.kt similarity index 73% rename from core/src/test/kotlin/net/corda/core/crypto/X509UtilitiesTest.kt rename to node/src/test/kotlin/net/corda/node/utilities/X509UtilitiesTest.kt index 780c300044..69ada450ff 100644 --- a/core/src/test/kotlin/net/corda/core/crypto/X509UtilitiesTest.kt +++ b/node/src/test/kotlin/net/corda/node/utilities/X509UtilitiesTest.kt @@ -1,17 +1,28 @@ -package net.corda.core.crypto +package net.corda.node.utilities +import net.corda.core.crypto.Crypto import net.corda.core.crypto.Crypto.EDDSA_ED25519_SHA512 import net.corda.core.crypto.Crypto.generateKeyPair -import net.corda.core.crypto.X509Utilities.DEFAULT_TLS_SIGNATURE_SCHEME -import net.corda.core.crypto.X509Utilities.createSelfSignedCACertificate -import net.corda.core.div -import net.corda.core.toTypedArray -import net.corda.testing.MEGA_CORP -import net.corda.testing.getTestX509Name +import net.corda.core.crypto.cert +import net.corda.core.crypto.commonName +import net.corda.core.crypto.getX509Name +import net.corda.core.internal.div +import net.corda.core.internal.toTypedArray +import net.corda.core.serialization.SerializationContext +import net.corda.core.serialization.deserialize +import net.corda.core.serialization.serialize +import net.corda.node.serialization.KryoServerSerializationScheme +import net.corda.node.services.config.createKeystoreForCordaNode +import net.corda.nodeapi.internal.serialization.AllWhitelist +import net.corda.nodeapi.internal.serialization.KryoHeaderV0_1 +import net.corda.nodeapi.internal.serialization.SerializationContextImpl +import net.corda.nodeapi.internal.serialization.SerializationFactoryImpl +import net.corda.testing.* import org.bouncycastle.asn1.x500.X500Name import org.bouncycastle.asn1.x509.BasicConstraints import org.bouncycastle.asn1.x509.Extension import org.bouncycastle.asn1.x509.KeyUsage +import org.bouncycastle.cert.X509CertificateHolder import org.bouncycastle.operator.jcajce.JcaContentVerifierProviderBuilder import org.junit.Rule import org.junit.Test @@ -25,7 +36,9 @@ import java.nio.file.Path import java.security.KeyStore import java.security.PrivateKey import java.security.SecureRandom +import java.security.cert.CertPath import java.security.cert.Certificate +import java.security.cert.CertificateFactory import java.security.cert.X509Certificate import java.util.* import java.util.stream.Stream @@ -40,8 +53,8 @@ class X509UtilitiesTest { @Test fun `create valid self-signed CA certificate`() { - val caKey = generateKeyPair(DEFAULT_TLS_SIGNATURE_SCHEME) - val caCert = createSelfSignedCACertificate(getTestX509Name("Test Cert"), caKey) + val caKey = generateKeyPair(X509Utilities.DEFAULT_TLS_SIGNATURE_SCHEME) + val caCert = X509Utilities.createSelfSignedCACertificate(getTestX509Name("Test Cert"), caKey) assertTrue { caCert.subject.commonName == "Test Cert" } // using our subject common name assertEquals(caCert.issuer, caCert.subject) //self-signed caCert.isValidOn(Date()) // throws on verification problems @@ -55,8 +68,8 @@ class X509UtilitiesTest { @Test fun `load and save a PEM file certificate`() { val tmpCertificateFile = tempFile("cacert.pem") - val caKey = generateKeyPair(DEFAULT_TLS_SIGNATURE_SCHEME) - val caCert = createSelfSignedCACertificate(getTestX509Name("Test Cert"), caKey) + val caKey = generateKeyPair(X509Utilities.DEFAULT_TLS_SIGNATURE_SCHEME) + val caCert = X509Utilities.createSelfSignedCACertificate(getTestX509Name("Test Cert"), caKey) X509Utilities.saveCertificateAsPEMFile(caCert, tmpCertificateFile) val readCertificate = X509Utilities.loadCertificateFromPEMFile(tmpCertificateFile) assertEquals(caCert, readCertificate) @@ -64,10 +77,10 @@ class X509UtilitiesTest { @Test fun `create valid server certificate chain`() { - val caKey = generateKeyPair(DEFAULT_TLS_SIGNATURE_SCHEME) - val caCert = createSelfSignedCACertificate(getTestX509Name("Test CA Cert"), caKey) + val caKey = generateKeyPair(X509Utilities.DEFAULT_TLS_SIGNATURE_SCHEME) + val caCert = X509Utilities.createSelfSignedCACertificate(getTestX509Name("Test CA Cert"), caKey) val subject = getTestX509Name("Server Cert") - val keyPair = generateKeyPair(DEFAULT_TLS_SIGNATURE_SCHEME) + val keyPair = generateKeyPair(X509Utilities.DEFAULT_TLS_SIGNATURE_SCHEME) val serverCert = X509Utilities.createCertificate(CertificateType.TLS, caCert, caKey, subject, keyPair.public) assertTrue { serverCert.subject.toString().contains("CN=Server Cert") } // using our subject common name assertEquals(caCert.issuer, serverCert.issuer) // Issued by our CA cert @@ -84,18 +97,18 @@ class X509UtilitiesTest { val tmpKeyStore = tempFile("keystore.jks") val keyPair = generateKeyPair(EDDSA_ED25519_SHA512) - val selfSignCert = createSelfSignedCACertificate(X500Name("CN=Test"), keyPair) + val selfSignCert = X509Utilities.createSelfSignedCACertificate(X500Name("CN=Test"), keyPair) assertTrue(Arrays.equals(selfSignCert.subjectPublicKeyInfo.encoded, keyPair.public.encoded)) // Save the EdDSA private key with self sign cert in the keystore. - val keyStore = KeyStoreUtilities.loadOrCreateKeyStore(tmpKeyStore, "keystorepass") + val keyStore = loadOrCreateKeyStore(tmpKeyStore, "keystorepass") keyStore.setKeyEntry("Key", keyPair.private, "password".toCharArray(), Stream.of(selfSignCert).map { it.cert }.toTypedArray()) keyStore.save(tmpKeyStore, "keystorepass") // Load the keystore from file and make sure keys are intact. - val keyStore2 = KeyStoreUtilities.loadOrCreateKeyStore(tmpKeyStore, "keystorepass") + val keyStore2 = loadOrCreateKeyStore(tmpKeyStore, "keystorepass") val privateKey = keyStore2.getKey("Key", "password".toCharArray()) val pubKey = keyStore2.getCertificate("Key").publicKey @@ -109,18 +122,18 @@ class X509UtilitiesTest { fun `signing EdDSA key with EcDSA certificate`() { val tmpKeyStore = tempFile("keystore.jks") val ecDSAKey = generateKeyPair(Crypto.ECDSA_SECP256R1_SHA256) - val ecDSACert = createSelfSignedCACertificate(X500Name("CN=Test"), ecDSAKey) + val ecDSACert = X509Utilities.createSelfSignedCACertificate(X500Name("CN=Test"), ecDSAKey) val edDSAKeypair = generateKeyPair(EDDSA_ED25519_SHA512) val edDSACert = X509Utilities.createCertificate(CertificateType.TLS, ecDSACert, ecDSAKey, X500Name("CN=TestEdDSA"), edDSAKeypair.public) // Save the EdDSA private key with cert chains. - val keyStore = KeyStoreUtilities.loadOrCreateKeyStore(tmpKeyStore, "keystorepass") + val keyStore = loadOrCreateKeyStore(tmpKeyStore, "keystorepass") keyStore.setKeyEntry("Key", edDSAKeypair.private, "password".toCharArray(), Stream.of(ecDSACert, edDSACert).map { it.cert }.toTypedArray()) keyStore.save(tmpKeyStore, "keystorepass") // Load the keystore from file and make sure keys are intact. - val keyStore2 = KeyStoreUtilities.loadOrCreateKeyStore(tmpKeyStore, "keystorepass") + val keyStore2 = loadOrCreateKeyStore(tmpKeyStore, "keystorepass") val privateKey = keyStore2.getKey("Key", "password".toCharArray()) val certs = keyStore2.getCertificateChain("Key") @@ -142,8 +155,8 @@ class X509UtilitiesTest { createCAKeyStoreAndTrustStore(tmpKeyStore, "keystorepass", "keypass", tmpTrustStore, "trustpass") // Load back generated root CA Cert and private key from keystore and check against copy in truststore - val keyStore = KeyStoreUtilities.loadKeyStore(tmpKeyStore, "keystorepass") - val trustStore = KeyStoreUtilities.loadKeyStore(tmpTrustStore, "trustpass") + val keyStore = loadKeyStore(tmpKeyStore, "keystorepass") + val trustStore = loadKeyStore(tmpTrustStore, "trustpass") val rootCaCert = keyStore.getCertificate(X509Utilities.CORDA_ROOT_CA) as X509Certificate val rootCaPrivateKey = keyStore.getKey(X509Utilities.CORDA_ROOT_CA, "keypass".toCharArray()) as PrivateKey val rootCaFromTrustStore = trustStore.getCertificate(X509Utilities.CORDA_ROOT_CA) as X509Certificate @@ -153,8 +166,8 @@ class X509UtilitiesTest { // Now sign something with private key and verify against certificate public key val testData = "12345".toByteArray() - val caSignature = Crypto.doSign(DEFAULT_TLS_SIGNATURE_SCHEME, rootCaPrivateKey, testData) - assertTrue { Crypto.isValid(DEFAULT_TLS_SIGNATURE_SCHEME, rootCaCert.publicKey, caSignature, testData) } + val caSignature = Crypto.doSign(X509Utilities.DEFAULT_TLS_SIGNATURE_SCHEME, rootCaPrivateKey, testData) + assertTrue { Crypto.isValid(X509Utilities.DEFAULT_TLS_SIGNATURE_SCHEME, rootCaCert.publicKey, caSignature, testData) } // Load back generated intermediate CA Cert and private key val intermediateCaCert = keyStore.getCertificate(X509Utilities.CORDA_INTERMEDIATE_CA) as X509Certificate @@ -163,8 +176,8 @@ class X509UtilitiesTest { intermediateCaCert.verify(rootCaCert.publicKey) // Now sign something with private key and verify against certificate public key - val intermediateSignature = Crypto.doSign(DEFAULT_TLS_SIGNATURE_SCHEME, intermediateCaCertPrivateKey, testData) - assertTrue { Crypto.isValid(DEFAULT_TLS_SIGNATURE_SCHEME, intermediateCaCert.publicKey, intermediateSignature, testData) } + val intermediateSignature = Crypto.doSign(X509Utilities.DEFAULT_TLS_SIGNATURE_SCHEME, intermediateCaCertPrivateKey, testData) + assertTrue { Crypto.isValid(X509Utilities.DEFAULT_TLS_SIGNATURE_SCHEME, intermediateCaCert.publicKey, intermediateSignature, testData) } } @Test @@ -182,14 +195,14 @@ class X509UtilitiesTest { "trustpass") // Load signing intermediate CA cert - val caKeyStore = KeyStoreUtilities.loadKeyStore(tmpCAKeyStore, "cakeystorepass") + val caKeyStore = loadKeyStore(tmpCAKeyStore, "cakeystorepass") val caCertAndKey = caKeyStore.getCertificateAndKeyPair(X509Utilities.CORDA_INTERMEDIATE_CA, "cakeypass") // Generate server cert and private key and populate another keystore suitable for SSL - X509Utilities.createKeystoreForCordaNode(tmpSSLKeyStore, tmpServerKeyStore, "serverstorepass", "serverkeypass", caKeyStore, "cakeypass", MEGA_CORP.name) + createKeystoreForCordaNode(tmpSSLKeyStore, tmpServerKeyStore, "serverstorepass", "serverkeypass", caKeyStore, "cakeypass", MEGA_CORP.name) // Load back server certificate - val serverKeyStore = KeyStoreUtilities.loadKeyStore(tmpServerKeyStore, "serverstorepass") + val serverKeyStore = loadKeyStore(tmpServerKeyStore, "serverstorepass") val serverCertAndKey = serverKeyStore.getCertificateAndKeyPair(X509Utilities.CORDA_CLIENT_CA, "serverkeypass") serverCertAndKey.certificate.isValidOn(Date()) @@ -198,7 +211,7 @@ class X509UtilitiesTest { assertTrue { serverCertAndKey.certificate.subject.toString().contains(MEGA_CORP.name.commonName) } // Load back server certificate - val sslKeyStore = KeyStoreUtilities.loadKeyStore(tmpSSLKeyStore, "serverstorepass") + val sslKeyStore = loadKeyStore(tmpSSLKeyStore, "serverstorepass") val sslCertAndKey = sslKeyStore.getCertificateAndKeyPair(X509Utilities.CORDA_CLIENT_TLS, "serverkeypass") sslCertAndKey.certificate.isValidOn(Date()) @@ -207,9 +220,9 @@ class X509UtilitiesTest { assertTrue { sslCertAndKey.certificate.subject.toString().contains(MEGA_CORP.name.commonName) } // Now sign something with private key and verify against certificate public key val testData = "123456".toByteArray() - val signature = Crypto.doSign(DEFAULT_TLS_SIGNATURE_SCHEME, serverCertAndKey.keyPair.private, testData) + val signature = Crypto.doSign(X509Utilities.DEFAULT_TLS_SIGNATURE_SCHEME, serverCertAndKey.keyPair.private, testData) val publicKey = Crypto.toSupportedPublicKey(serverCertAndKey.certificate.subjectPublicKeyInfo) - assertTrue { Crypto.isValid(DEFAULT_TLS_SIGNATURE_SCHEME, publicKey, signature, testData) } + assertTrue { Crypto.isValid(X509Utilities.DEFAULT_TLS_SIGNATURE_SCHEME, publicKey, signature, testData) } } @Test @@ -227,9 +240,9 @@ class X509UtilitiesTest { "trustpass") // Generate server cert and private key and populate another keystore suitable for SSL - X509Utilities.createKeystoreForCordaNode(tmpSSLKeyStore, tmpServerKeyStore, "serverstorepass", "serverstorepass", caKeyStore, "cakeypass", MEGA_CORP.name) - val keyStore = KeyStoreUtilities.loadKeyStore(tmpSSLKeyStore, "serverstorepass") - val trustStore = KeyStoreUtilities.loadKeyStore(tmpTrustStore, "trustpass") + createKeystoreForCordaNode(tmpSSLKeyStore, tmpServerKeyStore, "serverstorepass", "serverstorepass", caKeyStore, "cakeypass", MEGA_CORP.name) + val keyStore = loadKeyStore(tmpSSLKeyStore, "serverstorepass") + val trustStore = loadKeyStore(tmpTrustStore, "trustpass") val context = SSLContext.getInstance("TLS") val keyManagerFactory = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm()) @@ -341,14 +354,14 @@ class X509UtilitiesTest { trustStoreFilePath: Path, trustStorePassword: String ): KeyStore { - val rootCAKey = generateKeyPair(DEFAULT_TLS_SIGNATURE_SCHEME) - val rootCACert = createSelfSignedCACertificate(X509Utilities.getX509Name("Corda Node Root CA","London","demo@r3.com",null), rootCAKey) + val rootCAKey = generateKeyPair(X509Utilities.DEFAULT_TLS_SIGNATURE_SCHEME) + val rootCACert = X509Utilities.createSelfSignedCACertificate(getX509Name("Corda Node Root CA", "London", "demo@r3.com", null), rootCAKey) val intermediateCAKeyPair = Crypto.generateKeyPair(X509Utilities.DEFAULT_TLS_SIGNATURE_SCHEME) - val intermediateCACert = X509Utilities.createCertificate(CertificateType.INTERMEDIATE_CA, rootCACert, rootCAKey, X509Utilities.getX509Name("Corda Node Intermediate CA","London","demo@r3.com",null), intermediateCAKeyPair.public) + val intermediateCACert = X509Utilities.createCertificate(CertificateType.INTERMEDIATE_CA, rootCACert, rootCAKey, getX509Name("Corda Node Intermediate CA", "London", "demo@r3.com", null), intermediateCAKeyPair.public) val keyPass = keyPassword.toCharArray() - val keyStore = KeyStoreUtilities.loadOrCreateKeyStore(keyStoreFilePath, storePassword) + val keyStore = loadOrCreateKeyStore(keyStoreFilePath, storePassword) keyStore.addOrReplaceKey(X509Utilities.CORDA_ROOT_CA, rootCAKey.private, keyPass, arrayOf(rootCACert.cert)) @@ -359,7 +372,7 @@ class X509UtilitiesTest { keyStore.save(keyStoreFilePath, storePassword) - val trustStore = KeyStoreUtilities.loadOrCreateKeyStore(trustStoreFilePath, trustStorePassword) + val trustStore = loadOrCreateKeyStore(trustStoreFilePath, trustStorePassword) trustStore.addOrReplaceCertificate(X509Utilities.CORDA_ROOT_CA, rootCACert.cert) trustStore.addOrReplaceCertificate(X509Utilities.CORDA_INTERMEDIATE_CA, intermediateCACert.cert) @@ -372,8 +385,8 @@ class X509UtilitiesTest { @Test fun `Get correct private key type from Keystore`() { val keyPair = generateKeyPair(Crypto.ECDSA_SECP256R1_SHA256) - val selfSignCert = createSelfSignedCACertificate(X500Name("CN=Test"), keyPair) - val keyStore = KeyStoreUtilities.loadOrCreateKeyStore(tempFile("testKeystore.jks"), "keystorepassword") + val selfSignCert = X509Utilities.createSelfSignedCACertificate(X500Name("CN=Test"), keyPair) + val keyStore = loadOrCreateKeyStore(tempFile("testKeystore.jks"), "keystorepassword") keyStore.setKeyEntry("Key", keyPair.private, "keypassword".toCharArray(), arrayOf(selfSignCert.cert)) val keyFromKeystore = keyStore.getKey("Key", "keypassword".toCharArray()) @@ -383,4 +396,37 @@ class X509UtilitiesTest { assertTrue(keyFromKeystoreCasted is org.bouncycastle.jce.interfaces.ECPrivateKey) } + @Test + fun `serialize - deserialize X509CertififcateHolder`() { + val factory = SerializationFactoryImpl().apply { registerScheme(KryoServerSerializationScheme(this)) } + val context = SerializationContextImpl(KryoHeaderV0_1, + javaClass.classLoader, + AllWhitelist, + emptyMap(), + true, + SerializationContext.UseCase.P2P) + val expected: X509CertificateHolder = X509Utilities.createSelfSignedCACertificate(ALICE.name, Crypto.generateKeyPair(X509Utilities.DEFAULT_TLS_SIGNATURE_SCHEME)) + val serialized = expected.serialize(factory, context).bytes + val actual: X509CertificateHolder = serialized.deserialize(factory, context) + assertEquals(expected, actual) + } + + @Test + fun `serialize - deserialize X509CertPath`() { + val factory = SerializationFactoryImpl().apply { registerScheme(KryoServerSerializationScheme(this)) } + val context = SerializationContextImpl(KryoHeaderV0_1, + javaClass.classLoader, + AllWhitelist, + emptyMap(), + true, + SerializationContext.UseCase.P2P) + val certFactory = CertificateFactory.getInstance("X509") + val rootCAKey = Crypto.generateKeyPair(X509Utilities.DEFAULT_TLS_SIGNATURE_SCHEME) + val rootCACert = X509Utilities.createSelfSignedCACertificate(ALICE.name, rootCAKey) + val certificate = X509Utilities.createCertificate(CertificateType.TLS, rootCACert, rootCAKey, BOB.name, BOB_PUBKEY) + val expected = certFactory.generateCertPath(listOf(certificate.cert, rootCACert.cert)) + val serialized = expected.serialize(factory, context).bytes + val actual: CertPath = serialized.deserialize(factory, context) + assertEquals(expected, actual) + } } diff --git a/node/src/test/kotlin/net/corda/node/utilities/registration/NetworkisRegistrationHelperTest.kt b/node/src/test/kotlin/net/corda/node/utilities/registration/NetworkisRegistrationHelperTest.kt index 201fb206f4..51de115d1f 100644 --- a/node/src/test/kotlin/net/corda/node/utilities/registration/NetworkisRegistrationHelperTest.kt +++ b/node/src/test/kotlin/net/corda/node/utilities/registration/NetworkisRegistrationHelperTest.kt @@ -3,9 +3,14 @@ package net.corda.node.utilities.registration import com.nhaarman.mockito_kotlin.any import com.nhaarman.mockito_kotlin.eq import com.nhaarman.mockito_kotlin.mock -import net.corda.core.crypto.* -import net.corda.core.exists -import net.corda.core.toTypedArray +import net.corda.core.crypto.Crypto +import net.corda.core.crypto.SecureHash +import net.corda.core.crypto.cert +import net.corda.core.crypto.commonName +import net.corda.core.internal.exists +import net.corda.core.internal.toTypedArray +import net.corda.node.utilities.X509Utilities +import net.corda.node.utilities.loadKeyStore import net.corda.testing.ALICE import net.corda.testing.getTestX509Name import net.corda.testing.testNodeConfiguration @@ -52,10 +57,9 @@ class NetworkRegistrationHelperTest { assertTrue(config.sslKeystore.exists()) assertTrue(config.trustStoreFile.exists()) - val nodeKeystore = KeyStoreUtilities.loadKeyStore(config.nodeKeystore, config.keyStorePassword) - val sslKeystore = KeyStoreUtilities.loadKeyStore(config.sslKeystore, config.keyStorePassword) - val trustStore = KeyStoreUtilities.loadKeyStore(config.trustStoreFile, config.trustStorePassword) - + val nodeKeystore = loadKeyStore(config.nodeKeystore, config.keyStorePassword) + val sslKeystore = loadKeyStore(config.sslKeystore, config.keyStorePassword) + val trustStore = loadKeyStore(config.trustStoreFile, config.trustStorePassword) nodeKeystore.run { assertTrue(containsAlias(X509Utilities.CORDA_CLIENT_CA)) diff --git a/samples/attachment-demo/build.gradle b/samples/attachment-demo/build.gradle index 55be04795d..6a8491a923 100644 --- a/samples/attachment-demo/build.gradle +++ b/samples/attachment-demo/build.gradle @@ -26,19 +26,20 @@ dependencies { testCompile "junit:junit:$junit_version" // Corda integration dependencies - compile project(path: ":node:capsule", configuration: 'runtimeArtifacts') - compile project(path: ":webserver:webcapsule", configuration: 'runtimeArtifacts') - compile project(':core') - compile project(':test-utils') + cordaCompile project(path: ":node:capsule", configuration: 'runtimeArtifacts') + cordaCompile project(path: ":webserver:webcapsule", configuration: 'runtimeArtifacts') + cordaCompile project(':core') + cordaCompile project(':webserver') + cordaCompile project(':test-utils') } task deployNodes(type: net.corda.plugins.Cordform, dependsOn: ['jar']) { ext.rpcUsers = [['username': "demo", 'password': "demo", 'permissions': ["StartFlow.net.corda.attachmentdemo.AttachmentDemoFlow"]]] directory "./build/nodes" - networkMap "CN=Notary Service,O=R3,OU=corda,L=London,C=GB" + networkMap "CN=Notary Service,O=R3,OU=corda,L=Zurich,C=CH" node { - name "CN=Notary Service,O=R3,OU=corda,L=London,C=GB" + name "CN=Notary Service,O=R3,OU=corda,L=Zurich,C=CH" advertisedServices["corda.notary.validating"] p2pPort 10002 rpcPort 10003 diff --git a/samples/attachment-demo/src/integration-test/kotlin/net/corda/attachmentdemo/AttachmentDemoTest.kt b/samples/attachment-demo/src/integration-test/kotlin/net/corda/attachmentdemo/AttachmentDemoTest.kt index a7e1f16027..b1bea3b8ab 100644 --- a/samples/attachment-demo/src/integration-test/kotlin/net/corda/attachmentdemo/AttachmentDemoTest.kt +++ b/samples/attachment-demo/src/integration-test/kotlin/net/corda/attachmentdemo/AttachmentDemoTest.kt @@ -1,8 +1,8 @@ package net.corda.attachmentdemo -import com.google.common.util.concurrent.Futures -import net.corda.core.getOrThrow import net.corda.core.node.services.ServiceInfo +import net.corda.core.internal.concurrent.transpose +import net.corda.core.utilities.getOrThrow import net.corda.testing.DUMMY_BANK_A import net.corda.testing.DUMMY_BANK_B import net.corda.testing.DUMMY_NOTARY @@ -19,11 +19,11 @@ class AttachmentDemoTest { val numOfExpectedBytes = 10_000_000 driver(dsl = { val demoUser = listOf(User("demo", "demo", setOf(startFlowPermission()))) - val (nodeA, nodeB) = Futures.allAsList( + val (nodeA, nodeB) = listOf( startNode(DUMMY_BANK_A.name, rpcUsers = demoUser), startNode(DUMMY_BANK_B.name, rpcUsers = demoUser), startNode(DUMMY_NOTARY.name, setOf(ServiceInfo(SimpleNotaryService.type))) - ).getOrThrow() + ).transpose().getOrThrow() val senderThread = CompletableFuture.supplyAsync { nodeA.rpcClientToNode().start(demoUser[0].username, demoUser[0].password).use { diff --git a/samples/attachment-demo/src/main/kotlin/net/corda/attachmentdemo/AttachmentDemo.kt b/samples/attachment-demo/src/main/kotlin/net/corda/attachmentdemo/AttachmentDemo.kt index 22f54033f3..8c03f061f4 100644 --- a/samples/attachment-demo/src/main/kotlin/net/corda/attachmentdemo/AttachmentDemo.kt +++ b/samples/attachment-demo/src/main/kotlin/net/corda/attachmentdemo/AttachmentDemo.kt @@ -3,22 +3,25 @@ package net.corda.attachmentdemo import co.paralleluniverse.fibers.Suspendable import joptsimple.OptionParser import net.corda.client.rpc.CordaRPCClient +import net.corda.core.concurrent.CordaFuture import net.corda.core.contracts.Contract import net.corda.core.contracts.ContractState -import net.corda.core.contracts.TransactionForContract -import net.corda.core.contracts.TransactionType import net.corda.core.crypto.SecureHash +import net.corda.core.flows.FinalityFlow import net.corda.core.flows.FlowLogic import net.corda.core.flows.StartableByRPC -import net.corda.core.getOrThrow import net.corda.core.identity.AbstractParty import net.corda.core.identity.Party +import net.corda.core.internal.Emoji +import net.corda.core.internal.InputStreamAndHash import net.corda.core.messaging.CordaRPCOps import net.corda.core.messaging.startTrackedFlow -import net.corda.core.sizedInputStreamAndHash +import net.corda.core.transactions.LedgerTransaction import net.corda.core.transactions.SignedTransaction -import net.corda.core.utilities.* -import net.corda.flows.FinalityFlow +import net.corda.core.transactions.TransactionBuilder +import net.corda.core.utilities.NetworkHostAndPort +import net.corda.core.utilities.ProgressTracker +import net.corda.core.utilities.getOrThrow import net.corda.testing.DUMMY_BANK_B import net.corda.testing.DUMMY_NOTARY import net.corda.testing.driver.poll @@ -26,6 +29,7 @@ import java.io.InputStream import java.net.HttpURLConnection import java.net.URL import java.util.concurrent.Executors +import java.util.concurrent.ScheduledExecutorService import java.util.jar.JarInputStream import javax.servlet.http.HttpServletResponse.SC_OK import javax.ws.rs.core.HttpHeaders.CONTENT_DISPOSITION @@ -70,14 +74,20 @@ fun main(args: Array) { /** An in memory test zip attachment of at least numOfClearBytes size, will be used. */ fun sender(rpc: CordaRPCOps, numOfClearBytes: Int = 1024) { // default size 1K. - val (inputStream, hash) = sizedInputStreamAndHash(numOfClearBytes) - sender(rpc, inputStream, hash) + val (inputStream, hash) = InputStreamAndHash.createInMemoryTestZip(numOfClearBytes, 0) + val executor = Executors.newScheduledThreadPool(2) + try { + sender(rpc, inputStream, hash, executor) + } finally { + executor.shutdown() + } } -fun sender(rpc: CordaRPCOps, inputStream: InputStream, hash: SecureHash.SHA256) { +private fun sender(rpc: CordaRPCOps, inputStream: InputStream, hash: SecureHash.SHA256, executor: ScheduledExecutorService) { + // Get the identity key of the other side (the recipient). - val executor = Executors.newScheduledThreadPool(1) - val otherSide: Party = poll(executor, DUMMY_BANK_B.name.toString()) { rpc.partyFromX500Name(DUMMY_BANK_B.name) }.get() + val notaryFuture: CordaFuture = poll(executor, DUMMY_NOTARY.name.toString()) { rpc.partyFromX500Name(DUMMY_NOTARY.name) } + val otherSideFuture: CordaFuture = poll(executor, DUMMY_BANK_B.name.toString()) { rpc.partyFromX500Name(DUMMY_BANK_B.name) } // Make sure we have the file in storage if (!rpc.attachmentExists(hash)) { @@ -88,14 +98,14 @@ fun sender(rpc: CordaRPCOps, inputStream: InputStream, hash: SecureHash.SHA256) require(rpc.attachmentExists(hash)) } - val flowHandle = rpc.startTrackedFlow(::AttachmentDemoFlow, otherSide, hash) + val flowHandle = rpc.startTrackedFlow(::AttachmentDemoFlow, otherSideFuture.get(), notaryFuture.get(), hash) flowHandle.progress.subscribe(::println) val stx = flowHandle.returnValue.getOrThrow() println("Sent ${stx.id}") } @StartableByRPC -class AttachmentDemoFlow(val otherSide: Party, val hash: SecureHash.SHA256) : FlowLogic() { +class AttachmentDemoFlow(val otherSide: Party, val notary: Party, val hash: SecureHash.SHA256) : FlowLogic() { object SIGNING : ProgressTracker.Step("Signing transaction") @@ -104,7 +114,7 @@ class AttachmentDemoFlow(val otherSide: Party, val hash: SecureHash.SHA256) : Fl @Suspendable override fun call(): SignedTransaction { // Create a trivial transaction with an output that describes the attachment, and the attachment itself - val ptx = TransactionType.General.Builder(notary = DUMMY_NOTARY) + val ptx = TransactionBuilder(notary) ptx.addOutputState(AttachmentContract.State(hash)) ptx.addAttachment(hash) @@ -119,11 +129,11 @@ class AttachmentDemoFlow(val otherSide: Party, val hash: SecureHash.SHA256) : Fl fun recipient(rpc: CordaRPCOps) { println("Waiting to receive transaction ...") - val stx = rpc.verifiedTransactions().second.toBlocking().first() + val stx = rpc.verifiedTransactionsFeed().updates.toBlocking().first() val wtx = stx.tx if (wtx.attachments.isNotEmpty()) { if (wtx.outputs.isNotEmpty()) { - val state = wtx.outputs.map { it.data }.filterIsInstance().single() + val state = wtx.outputsOfType().single() require(rpc.attachmentExists(state.hash)) // Download the attachment via the Web endpoint. @@ -170,8 +180,8 @@ class AttachmentContract : Contract { override val legalContractReference: SecureHash get() = SecureHash.zeroHash // TODO not implemented - override fun verify(tx: TransactionForContract) { - val state = tx.outputs.filterIsInstance().single() + override fun verify(tx: LedgerTransaction) { + val state = tx.outputsOfType().single() val attachment = tx.attachments.single() require(state.hash == attachment.id) } diff --git a/samples/attachment-demo/src/main/kotlin/net/corda/attachmentdemo/Main.kt b/samples/attachment-demo/src/main/kotlin/net/corda/attachmentdemo/Main.kt index 4528c419c0..bd586c7542 100644 --- a/samples/attachment-demo/src/main/kotlin/net/corda/attachmentdemo/Main.kt +++ b/samples/attachment-demo/src/main/kotlin/net/corda/attachmentdemo/Main.kt @@ -1,6 +1,6 @@ package net.corda.attachmentdemo -import net.corda.core.div +import net.corda.core.internal.div import net.corda.core.node.services.ServiceInfo import net.corda.testing.DUMMY_BANK_A import net.corda.testing.DUMMY_BANK_B diff --git a/samples/bank-of-corda-demo/build.gradle b/samples/bank-of-corda-demo/build.gradle index fd957bbcdc..0b1ddf7076 100644 --- a/samples/bank-of-corda-demo/build.gradle +++ b/samples/bank-of-corda-demo/build.gradle @@ -23,19 +23,25 @@ configurations { dependencies { compile "org.jetbrains.kotlin:kotlin-stdlib-jre8:$kotlin_version" - testCompile "junit:junit:$junit_version" // Corda integration dependencies - compile project(path: ":node:capsule", configuration: 'runtimeArtifacts') - compile project(path: ":webserver:webcapsule", configuration: 'runtimeArtifacts') - compile project(':core') - compile project(':client:jfx') - compile project(':client:rpc') - compile project(':finance') - compile project(':test-utils') + cordaCompile project(path: ":node:capsule", configuration: 'runtimeArtifacts') + cordaCompile project(path: ":webserver:webcapsule", configuration: 'runtimeArtifacts') + cordaCompile project(':core') + cordaCompile project(':client:jfx') + cordaCompile project(':client:rpc') + cordaCompile project(':finance') + cordaCompile project(':webserver') + cordaCompile project(':test-utils') // Javax is required for webapis compile "org.glassfish.jersey.core:jersey-server:${jersey_version}" + + // Cordapp dependencies + // Specify your cordapp's dependencies below, including dependent cordapps + + // Test dependencies + testCompile "junit:junit:$junit_version" } task deployNodes(type: net.corda.plugins.Cordform, dependsOn: ['jar']) { diff --git a/samples/bank-of-corda-demo/src/integration-test/kotlin/net/corda/bank/BankOfCordaHttpAPITest.kt b/samples/bank-of-corda-demo/src/integration-test/kotlin/net/corda/bank/BankOfCordaHttpAPITest.kt index afee0671f7..d6d24385bf 100644 --- a/samples/bank-of-corda-demo/src/integration-test/kotlin/net/corda/bank/BankOfCordaHttpAPITest.kt +++ b/samples/bank-of-corda-demo/src/integration-test/kotlin/net/corda/bank/BankOfCordaHttpAPITest.kt @@ -1,13 +1,13 @@ package net.corda.bank -import com.google.common.util.concurrent.Futures import net.corda.bank.api.BankOfCordaClientApi import net.corda.bank.api.BankOfCordaWebApi.IssueRequestParams -import net.corda.core.getOrThrow import net.corda.core.node.services.ServiceInfo +import net.corda.core.internal.concurrent.transpose +import net.corda.core.utilities.getOrThrow +import net.corda.testing.driver.driver import net.corda.node.services.transactions.SimpleNotaryService import net.corda.testing.BOC -import net.corda.testing.driver.driver import org.junit.Test import kotlin.test.assertTrue @@ -15,10 +15,10 @@ class BankOfCordaHttpAPITest { @Test fun `issuer flow via Http`() { driver(dsl = { - val (nodeBankOfCorda) = Futures.allAsList( + val (nodeBankOfCorda) = listOf( startNode(BOC.name, setOf(ServiceInfo(SimpleNotaryService.type))), startNode(BIGCORP_LEGAL_NAME) - ).getOrThrow() + ).transpose().getOrThrow() val anonymous = false val nodeBankOfCordaApiAddr = startWebserver(nodeBankOfCorda).getOrThrow().listenAddress assertTrue(BankOfCordaClientApi(nodeBankOfCordaApiAddr).requestWebIssue(IssueRequestParams(1000, "USD", BIGCORP_LEGAL_NAME, "1", BOC.name, BOC.name, anonymous))) diff --git a/samples/bank-of-corda-demo/src/integration-test/kotlin/net/corda/bank/BankOfCordaRPCClientTest.kt b/samples/bank-of-corda-demo/src/integration-test/kotlin/net/corda/bank/BankOfCordaRPCClientTest.kt index 7159c077b4..5baaa550e1 100644 --- a/samples/bank-of-corda-demo/src/integration-test/kotlin/net/corda/bank/BankOfCordaRPCClientTest.kt +++ b/samples/bank-of-corda-demo/src/integration-test/kotlin/net/corda/bank/BankOfCordaRPCClientTest.kt @@ -1,20 +1,20 @@ package net.corda.bank -import com.google.common.util.concurrent.Futures import net.corda.contracts.asset.Cash import net.corda.core.contracts.DOLLARS -import net.corda.core.getOrThrow import net.corda.core.messaging.startFlow +import net.corda.core.messaging.vaultTrackBy import net.corda.core.node.services.ServiceInfo import net.corda.core.node.services.Vault -import net.corda.core.node.services.trackBy import net.corda.core.node.services.vault.QueryCriteria +import net.corda.core.internal.concurrent.transpose +import net.corda.core.utilities.getOrThrow import net.corda.flows.IssuerFlow.IssuanceRequester -import net.corda.testing.driver.driver import net.corda.node.services.startFlowPermission import net.corda.node.services.transactions.SimpleNotaryService import net.corda.nodeapi.User import net.corda.testing.* +import net.corda.testing.driver.driver import org.junit.Test class BankOfCordaRPCClientTest { @@ -23,10 +23,10 @@ class BankOfCordaRPCClientTest { driver(dsl = { val bocManager = User("bocManager", "password1", permissions = setOf(startFlowPermission())) val bigCorpCFO = User("bigCorpCFO", "password2", permissions = emptySet()) - val (nodeBankOfCorda, nodeBigCorporation) = Futures.allAsList( + val (nodeBankOfCorda, nodeBigCorporation) = listOf( startNode(BOC.name, setOf(ServiceInfo(SimpleNotaryService.type)), listOf(bocManager)), startNode(BIGCORP_LEGAL_NAME, rpcUsers = listOf(bigCorpCFO)) - ).getOrThrow() + ).transpose().getOrThrow() // Bank of Corda RPC Client val bocClient = nodeBankOfCorda.rpcClientToNode() @@ -38,10 +38,10 @@ class BankOfCordaRPCClientTest { // Register for Bank of Corda Vault updates val criteria = QueryCriteria.VaultQueryCriteria(status = Vault.StateStatus.ALL) - val (_, vaultUpdatesBoc) = bocProxy.vaultTrackByCriteria(Cash.State::class.java, criteria) + val vaultUpdatesBoc = bocProxy.vaultTrackByCriteria(Cash.State::class.java, criteria).updates // Register for Big Corporation Vault updates - val (_, vaultUpdatesBigCorp) = bigCorpProxy.vaultTrackByCriteria(Cash.State::class.java, criteria) + val vaultUpdatesBigCorp = bigCorpProxy.vaultTrackByCriteria(Cash.State::class.java, criteria).updates // Kick-off actual Issuer Flow val anonymous = true diff --git a/samples/bank-of-corda-demo/src/main/kotlin/net/corda/bank/BankOfCordaDriver.kt b/samples/bank-of-corda-demo/src/main/kotlin/net/corda/bank/BankOfCordaDriver.kt index ed44a63309..eda775af3f 100644 --- a/samples/bank-of-corda-demo/src/main/kotlin/net/corda/bank/BankOfCordaDriver.kt +++ b/samples/bank-of-corda-demo/src/main/kotlin/net/corda/bank/BankOfCordaDriver.kt @@ -7,15 +7,15 @@ import net.corda.core.node.services.ServiceInfo import net.corda.core.node.services.ServiceType import net.corda.core.transactions.SignedTransaction import net.corda.core.utilities.NetworkHostAndPort -import net.corda.testing.DUMMY_NOTARY import net.corda.flows.CashExitFlow import net.corda.flows.CashPaymentFlow import net.corda.flows.IssuerFlow -import net.corda.testing.driver.driver import net.corda.node.services.startFlowPermission import net.corda.node.services.transactions.SimpleNotaryService import net.corda.nodeapi.User import net.corda.testing.BOC +import net.corda.testing.DUMMY_NOTARY +import net.corda.testing.driver.driver import org.bouncycastle.asn1.x500.X500Name import kotlin.system.exitProcess @@ -55,38 +55,48 @@ private class BankOfCordaDriver { // The ISSUER will launch a Bank of Corda node // The ISSUE_CASH will request some Cash from the ISSUER on behalf of Big Corporation node val role = options.valueOf(roleArg)!! - if (role == Role.ISSUER) { - driver(dsl = { - val bankUser = User(BANK_USERNAME, "test", permissions = setOf(startFlowPermission(), startFlowPermission(), startFlowPermission())) - val bigCorpUser = User(BIGCORP_USERNAME, "test", permissions = setOf(startFlowPermission())) - startNode(DUMMY_NOTARY.name, setOf(ServiceInfo(SimpleNotaryService.type))) - val bankOfCorda = startNode(BOC.name, rpcUsers = listOf(bankUser), advertisedServices = setOf(ServiceInfo(ServiceType.corda.getSubType("issuer.USD")))) - startNode(BIGCORP_LEGAL_NAME, rpcUsers = listOf(bigCorpUser)) - startWebserver(bankOfCorda.get()) - waitForAllNodesToFinish() - }, isDebug = true) - } else { - try { - val anonymous = true - val requestParams = IssueRequestParams(options.valueOf(quantity), options.valueOf(currency), BIGCORP_LEGAL_NAME, "1", BOC.name, DUMMY_NOTARY.name, anonymous) - when (role) { - Role.ISSUE_CASH_RPC -> { - println("Requesting Cash via RPC ...") - val result = BankOfCordaClientApi(NetworkHostAndPort("localhost", 10006)).requestRPCIssue(requestParams) - if (result is SignedTransaction) - println("Success!! You transaction receipt is ${result.tx.id}") - } - Role.ISSUE_CASH_WEB -> { - println("Requesting Cash via Web ...") - val result = BankOfCordaClientApi(NetworkHostAndPort("localhost", 10007)).requestWebIssue(requestParams) - if (result) - println("Successfully processed Cash Issue request") - } + + val anonymous = true + val requestParams = IssueRequestParams(options.valueOf(quantity), options.valueOf(currency), BIGCORP_LEGAL_NAME, "1", BOC.name, DUMMY_NOTARY.name, anonymous) + + try { + when (role) { + Role.ISSUER -> { + driver(dsl = { + val bankUser = User( + BANK_USERNAME, + "test", + permissions = setOf( + startFlowPermission(), + startFlowPermission(), + startFlowPermission())) + val bigCorpUser = User(BIGCORP_USERNAME, "test", permissions = setOf(startFlowPermission())) + startNode(DUMMY_NOTARY.name, setOf(ServiceInfo(SimpleNotaryService.type))) + val bankOfCorda = startNode( + BOC.name, + rpcUsers = listOf(bankUser), + advertisedServices = setOf(ServiceInfo(ServiceType.corda.getSubType("issuer.USD")))) + startNode(BIGCORP_LEGAL_NAME, rpcUsers = listOf(bigCorpUser)) + startWebserver(bankOfCorda.get()) + waitForAllNodesToFinish() + }, isDebug = true) + } + Role.ISSUE_CASH_RPC -> { + println("Requesting Cash via RPC ...") + val result = BankOfCordaClientApi(NetworkHostAndPort("localhost", 10006)).requestRPCIssue(requestParams) + if (result is SignedTransaction) + println("Success!! You transaction receipt is ${result.tx.id}") + } + Role.ISSUE_CASH_WEB -> { + println("Requesting Cash via Web ...") + val result = BankOfCordaClientApi(NetworkHostAndPort("localhost", 10007)).requestWebIssue(requestParams) + if (result) + println("Successfully processed Cash Issue request") } - } catch (e: Exception) { - println("Exception occurred: $e \n ${e.printStackTrace()}") - exitProcess(1) } + } catch (e: Exception) { + println("Exception occurred: $e \n ${e.printStackTrace()}") + exitProcess(1) } } diff --git a/samples/bank-of-corda-demo/src/main/kotlin/net/corda/bank/api/BankOfCordaClientApi.kt b/samples/bank-of-corda-demo/src/main/kotlin/net/corda/bank/api/BankOfCordaClientApi.kt index 2f4da02b39..a2be1feb62 100644 --- a/samples/bank-of-corda-demo/src/main/kotlin/net/corda/bank/api/BankOfCordaClientApi.kt +++ b/samples/bank-of-corda-demo/src/main/kotlin/net/corda/bank/api/BankOfCordaClientApi.kt @@ -4,12 +4,13 @@ import net.corda.bank.api.BankOfCordaWebApi.IssueRequestParams import net.corda.client.rpc.CordaRPCClient import net.corda.core.contracts.Amount import net.corda.core.contracts.currency -import net.corda.core.getOrThrow import net.corda.core.messaging.startFlow import net.corda.core.utilities.OpaqueBytes import net.corda.core.transactions.SignedTransaction import net.corda.core.utilities.NetworkHostAndPort +import net.corda.core.utilities.getOrThrow import net.corda.flows.IssuerFlow.IssuanceRequester +import net.corda.testing.DUMMY_NOTARY import net.corda.testing.http.HttpApi /** @@ -33,20 +34,22 @@ class BankOfCordaClientApi(val hostAndPort: NetworkHostAndPort) { val client = CordaRPCClient(hostAndPort) // TODO: privileged security controls required client.start("bankUser", "test").use { connection -> - val proxy = connection.proxy + val rpc = connection.proxy // Resolve parties via RPC - val issueToParty = proxy.partyFromX500Name(params.issueToPartyName) + val issueToParty = rpc.partyFromX500Name(params.issueToPartyName) ?: throw Exception("Unable to locate ${params.issueToPartyName} in Network Map Service") - val issuerBankParty = proxy.partyFromX500Name(params.issuerBankName) + val issuerBankParty = rpc.partyFromX500Name(params.issuerBankName) ?: throw Exception("Unable to locate ${params.issuerBankName} in Network Map Service") - val notaryParty = proxy.partyFromX500Name(params.notaryName) - ?: throw Exception("Unable to locate ${params.notaryName} in Network Map Service") + val notaryLegalIdentity = rpc.partyFromX500Name(params.notaryName) + ?: throw IllegalStateException("Unable to locate ${params.notaryName} in Network Map Service") + val notaryNode = rpc.nodeIdentityFromParty(notaryLegalIdentity) + ?: throw IllegalStateException("Unable to locate notary node in network map cache") val amount = Amount(params.amount, currency(params.currency)) val issuerToPartyRef = OpaqueBytes.of(params.issueToPartyRefAsString.toByte()) - return proxy.startFlow(::IssuanceRequester, amount, issueToParty, issuerToPartyRef, issuerBankParty, notaryParty, params.anonymous) + return rpc.startFlow(::IssuanceRequester, amount, issueToParty, issuerToPartyRef, issuerBankParty, notaryNode.notaryIdentity, params.anonymous) .returnValue.getOrThrow().stx } } diff --git a/samples/bank-of-corda-demo/src/main/kotlin/net/corda/bank/api/BankOfCordaWebApi.kt b/samples/bank-of-corda-demo/src/main/kotlin/net/corda/bank/api/BankOfCordaWebApi.kt index 31c8b3abf3..47e7c2bacc 100644 --- a/samples/bank-of-corda-demo/src/main/kotlin/net/corda/bank/api/BankOfCordaWebApi.kt +++ b/samples/bank-of-corda-demo/src/main/kotlin/net/corda/bank/api/BankOfCordaWebApi.kt @@ -2,10 +2,10 @@ package net.corda.bank.api import net.corda.core.contracts.Amount import net.corda.core.contracts.currency -import net.corda.core.getOrThrow import net.corda.core.messaging.CordaRPCOps import net.corda.core.messaging.startFlow import net.corda.core.utilities.OpaqueBytes +import net.corda.core.utilities.getOrThrow import net.corda.core.utilities.loggerFor import net.corda.flows.IssuerFlow.IssuanceRequester import org.bouncycastle.asn1.x500.X500Name @@ -43,11 +43,13 @@ class BankOfCordaWebApi(val rpc: CordaRPCOps) { fun issueAssetRequest(params: IssueRequestParams): Response { // Resolve parties via RPC val issueToParty = rpc.partyFromX500Name(params.issueToPartyName) - ?: return Response.status(Response.Status.FORBIDDEN).entity("Unable to locate ${params.issueToPartyName} in Network Map Service").build() + ?: return Response.status(Response.Status.FORBIDDEN).entity("Unable to locate ${params.issueToPartyName} in identity service").build() val issuerBankParty = rpc.partyFromX500Name(params.issuerBankName) - ?: return Response.status(Response.Status.FORBIDDEN).entity("Unable to locate ${params.issuerBankName} in Network Map Service").build() + ?: return Response.status(Response.Status.FORBIDDEN).entity("Unable to locate ${params.issuerBankName} in identity service").build() val notaryParty = rpc.partyFromX500Name(params.notaryName) - ?: return Response.status(Response.Status.FORBIDDEN).entity("Unable to locate ${params.notaryName} in Network Map Service").build() + ?: return Response.status(Response.Status.FORBIDDEN).entity("Unable to locate ${params.notaryName} in identity service").build() + val notaryNode = rpc.nodeIdentityFromParty(notaryParty) + ?: return Response.status(Response.Status.FORBIDDEN).entity("Unable to locate ${notaryParty} in network map service").build() val amount = Amount(params.amount, currency(params.currency)) val issuerToPartyRef = OpaqueBytes.of(params.issueToPartyRefAsString.toByte()) @@ -56,7 +58,7 @@ class BankOfCordaWebApi(val rpc: CordaRPCOps) { // invoke client side of Issuer Flow: IssuanceRequester // The line below blocks and waits for the future to resolve. return try { - rpc.startFlow(::IssuanceRequester, amount, issueToParty, issuerToPartyRef, issuerBankParty, notaryParty, anonymous).returnValue.getOrThrow() + rpc.startFlow(::IssuanceRequester, amount, issueToParty, issuerToPartyRef, issuerBankParty, notaryNode.notaryIdentity, anonymous).returnValue.getOrThrow() logger.info("Issue request completed successfully: $params") Response.status(Response.Status.CREATED).build() } catch (e: Exception) { diff --git a/samples/irs-demo/build.gradle b/samples/irs-demo/build.gradle index f577151982..70d42a9d19 100644 --- a/samples/irs-demo/build.gradle +++ b/samples/irs-demo/build.gradle @@ -28,11 +28,11 @@ dependencies { compile "org.jetbrains.kotlin:kotlin-stdlib-jre8:$kotlin_version" // Corda integration dependencies - compile project(path: ":node:capsule", configuration: 'runtimeArtifacts') - compile project(path: ":webserver:webcapsule", configuration: 'runtimeArtifacts') - compile project(':core') - compile project(':finance') - compile project(':webserver') + cordaCompile project(path: ":node:capsule", configuration: 'runtimeArtifacts') + cordaCompile project(path: ":webserver:webcapsule", configuration: 'runtimeArtifacts') + cordaCompile project(':core') + cordaCompile project(':finance') + cordaCompile project(':webserver') // Javax is required for webapis compile "org.glassfish.jersey.core:jersey-server:${jersey_version}" @@ -48,9 +48,9 @@ dependencies { task deployNodes(type: net.corda.plugins.Cordform, dependsOn: ['jar']) { directory "./build/nodes" - networkMap "CN=Notary Service,O=R3,OU=corda,L=London,C=GB" + networkMap "CN=Notary Service,O=R3,OU=corda,L=Zurich,C=CH" node { - name "CN=Notary Service,O=R3,OU=corda,L=London,C=GB" + name "CN=Notary Service,O=R3,OU=corda,L=Zurich,C=CH" advertisedServices = ["corda.notary.validating", "corda.interest_rates"] p2pPort 10002 rpcPort 10003 @@ -115,4 +115,4 @@ publishing { jar { from sourceSets.test.output -} \ No newline at end of file +} diff --git a/samples/irs-demo/src/integration-test/kotlin/net/corda/irs/IRSDemoTest.kt b/samples/irs-demo/src/integration-test/kotlin/net/corda/irs/IRSDemoTest.kt index ac93b3ec0b..309490c07d 100644 --- a/samples/irs-demo/src/integration-test/kotlin/net/corda/irs/IRSDemoTest.kt +++ b/samples/irs-demo/src/integration-test/kotlin/net/corda/irs/IRSDemoTest.kt @@ -1,20 +1,22 @@ package net.corda.irs -import com.google.common.util.concurrent.Futures import net.corda.client.rpc.CordaRPCClient -import net.corda.core.getOrThrow +import net.corda.core.messaging.vaultTrackBy import net.corda.core.node.services.ServiceInfo import net.corda.core.toFuture +import net.corda.core.internal.concurrent.transpose import net.corda.core.utilities.NetworkHostAndPort -import net.corda.testing.DUMMY_BANK_A -import net.corda.testing.DUMMY_BANK_B -import net.corda.testing.DUMMY_NOTARY +import net.corda.core.utilities.getOrThrow +import net.corda.core.utilities.seconds import net.corda.irs.api.NodeInterestRates import net.corda.irs.contract.InterestRateSwap import net.corda.irs.utilities.uploadFile import net.corda.node.services.config.FullNodeConfiguration import net.corda.node.services.transactions.SimpleNotaryService import net.corda.nodeapi.User +import net.corda.testing.DUMMY_BANK_A +import net.corda.testing.DUMMY_BANK_B +import net.corda.testing.DUMMY_NOTARY import net.corda.testing.IntegrationTestCategory import net.corda.testing.driver.driver import net.corda.testing.http.HttpApi @@ -25,30 +27,29 @@ import rx.Observable import java.net.URL import java.time.Duration import java.time.LocalDate -import java.time.temporal.ChronoUnit class IRSDemoTest : IntegrationTestCategory { val rpcUser = User("user", "password", emptySet()) val currentDate: LocalDate = LocalDate.now() val futureDate: LocalDate = currentDate.plusMonths(6) - val maxWaitTime: Duration = Duration.of(60, ChronoUnit.SECONDS) + val maxWaitTime: Duration = 60.seconds @Test fun `runs IRS demo`() { driver(useTestClock = true, isDebug = true) { - val (controller, nodeA, nodeB) = Futures.allAsList( + val (controller, nodeA, nodeB) = listOf( startNode(DUMMY_NOTARY.name, setOf(ServiceInfo(SimpleNotaryService.type), ServiceInfo(NodeInterestRates.Oracle.type))), startNode(DUMMY_BANK_A.name, rpcUsers = listOf(rpcUser)), startNode(DUMMY_BANK_B.name) - ).getOrThrow() + ).transpose().getOrThrow() println("All nodes started") - val (controllerAddr, nodeAAddr, nodeBAddr) = Futures.allAsList( + val (controllerAddr, nodeAAddr, nodeBAddr) = listOf( startWebserver(controller), startWebserver(nodeA), startWebserver(nodeB) - ).getOrThrow().map { it.listenAddress } + ).transpose().getOrThrow().map { it.listenAddress } println("All webservers started") @@ -79,9 +80,9 @@ class IRSDemoTest : IntegrationTestCategory { fun getFloatingLegFixCount(nodeApi: HttpApi) = getTrades(nodeApi)[0].calculation.floatingLegPaymentSchedule.count { it.value.rate.ratioUnit != null } fun getFixingDateObservable(config: FullNodeConfiguration): Observable { - val client = CordaRPCClient(config.rpcAddress!!) + val client = CordaRPCClient(config.rpcAddress!!, initialiseSerialization = false) val proxy = client.start("user", "password").proxy - val vaultUpdates = proxy.vaultAndUpdates().second + val vaultUpdates = proxy.vaultTrackBy().updates return vaultUpdates.map { update -> val irsStates = update.produced.map { it.state.data }.filterIsInstance() diff --git a/samples/irs-demo/src/main/kotlin/net/corda/irs/api/InterestRateSwapAPI.kt b/samples/irs-demo/src/main/kotlin/net/corda/irs/api/InterestRateSwapAPI.kt index 29821de3c9..2377a9c86a 100644 --- a/samples/irs-demo/src/main/kotlin/net/corda/irs/api/InterestRateSwapAPI.kt +++ b/samples/irs-demo/src/main/kotlin/net/corda/irs/api/InterestRateSwapAPI.kt @@ -1,10 +1,10 @@ package net.corda.irs.api -import net.corda.client.rpc.notUsed import net.corda.core.contracts.filterStatesOfType -import net.corda.core.getOrThrow import net.corda.core.messaging.CordaRPCOps import net.corda.core.messaging.startFlow +import net.corda.core.messaging.vaultQueryBy +import net.corda.core.utilities.getOrThrow import net.corda.core.utilities.loggerFor import net.corda.irs.contract.InterestRateSwap import net.corda.irs.flows.AutoOfferFlow @@ -36,9 +36,8 @@ class InterestRateSwapAPI(val rpc: CordaRPCOps) { private fun generateDealLink(deal: InterestRateSwap.State) = "/api/irs/deals/" + deal.common.tradeID private fun getDealByRef(ref: String): InterestRateSwap.State? { - val (vault, vaultUpdates) = rpc.vaultAndUpdates() - vaultUpdates.notUsed() - val states = vault.filterStatesOfType().filter { it.state.data.ref == ref } + val vault = rpc.vaultQueryBy().states + val states = vault.filterStatesOfType().filter { it.state.data.linearId.externalId == ref } return if (states.isEmpty()) null else { val deals = states.map { it.state.data } return if (deals.isEmpty()) null else deals[0] @@ -46,8 +45,7 @@ class InterestRateSwapAPI(val rpc: CordaRPCOps) { } private fun getAllDeals(): Array { - val (vault, vaultUpdates) = rpc.vaultAndUpdates() - vaultUpdates.notUsed() + val vault = rpc.vaultQueryBy().states val states = vault.filterStatesOfType() val swaps = states.map { it.state.data }.toTypedArray() return swaps diff --git a/samples/irs-demo/src/main/kotlin/net/corda/irs/api/NodeInterestRates.kt b/samples/irs-demo/src/main/kotlin/net/corda/irs/api/NodeInterestRates.kt index 33604a7454..76e6c6007c 100644 --- a/samples/irs-demo/src/main/kotlin/net/corda/irs/api/NodeInterestRates.kt +++ b/samples/irs-demo/src/main/kotlin/net/corda/irs/api/NodeInterestRates.kt @@ -8,16 +8,14 @@ import net.corda.contracts.Tenor import net.corda.contracts.math.CubicSplineInterpolator import net.corda.contracts.math.Interpolator import net.corda.contracts.math.InterpolatorFactory -import net.corda.core.RetryableException -import net.corda.core.ThreadBox import net.corda.core.contracts.Command -import net.corda.core.crypto.DigitalSignature -import net.corda.core.crypto.MerkleTreeException -import net.corda.core.crypto.keys +import net.corda.core.crypto.* +import net.corda.core.flows.FlowException import net.corda.core.flows.FlowLogic import net.corda.core.flows.InitiatedBy import net.corda.core.flows.StartableByRPC import net.corda.core.identity.Party +import net.corda.core.internal.ThreadBox import net.corda.core.node.PluginServiceHub import net.corda.core.node.ServiceHub import net.corda.core.node.services.CordaService @@ -146,12 +144,12 @@ object NodeInterestRates { // Oracle gets signing request for only some of them with a valid partial tree? We sign over a whole transaction. // It will be fixed by adding partial signatures later. // DOCSTART 1 - fun sign(ftx: FilteredTransaction): DigitalSignature.WithKey { + fun sign(ftx: FilteredTransaction): TransactionSignature { if (!ftx.verify()) { throw MerkleTreeException("Rate Fix Oracle: Couldn't verify partial Merkle tree.") } // Performing validation of obtained FilteredLeaves. - fun commandValidator(elem: Command): Boolean { + fun commandValidator(elem: Command<*>): Boolean { if (!(identity.owningKey in elem.signers && elem.value is Fix)) throw IllegalArgumentException("Oracle received unknown command (not in signers or not Fix).") val fix = elem.value as Fix @@ -163,7 +161,7 @@ object NodeInterestRates { fun check(elem: Any): Boolean { return when (elem) { - is Command -> commandValidator(elem) + is Command<*> -> commandValidator(elem) else -> throw IllegalArgumentException("Oracle received data of different type than expected.") } } @@ -177,8 +175,9 @@ object NodeInterestRates { // Note that we will happily sign an invalid transaction, as we are only being presented with a filtered // version so we can't resolve or check it ourselves. However, that doesn't matter much, as if we sign // an invalid transaction the signature is worthless. - val signature = services.keyManagementService.sign(ftx.rootHash.bytes, signingKey) - return DigitalSignature.WithKey(signingKey, signature.bytes) + val signableData = SignableData(ftx.rootHash, SignatureMetadata(services.myInfo.platformVersion, Crypto.findSignatureScheme(signingKey).schemeNumberID)) + val signature = services.keyManagementService.sign(signableData, signingKey) + return TransactionSignature(signature.bytes, signingKey, signableData.signatureMetadata) } // DOCEND 1 @@ -192,7 +191,7 @@ object NodeInterestRates { } // TODO: can we split into two? Fix not available (retryable/transient) and unknown (permanent) - class UnknownFix(val fix: FixOf) : RetryableException("Unknown fix: $fix") + class UnknownFix(val fix: FixOf) : FlowException("Unknown fix: $fix") // Upload the raw fix data via RPC. In a real system the oracle data would be taken from a database. @StartableByRPC diff --git a/samples/irs-demo/src/main/kotlin/net/corda/irs/contract/IRS.kt b/samples/irs-demo/src/main/kotlin/net/corda/irs/contract/IRS.kt index bc445b4077..8c032f695d 100644 --- a/samples/irs-demo/src/main/kotlin/net/corda/irs/contract/IRS.kt +++ b/samples/irs-demo/src/main/kotlin/net/corda/irs/contract/IRS.kt @@ -1,11 +1,8 @@ package net.corda.irs.contract -import net.corda.contracts.* -import com.fasterxml.jackson.annotation.JsonIgnore import com.fasterxml.jackson.annotation.JsonIgnoreProperties -import com.fasterxml.jackson.annotation.JsonProperty +import net.corda.contracts.* import net.corda.core.contracts.* -import net.corda.core.contracts.clauses.* import net.corda.core.crypto.SecureHash import net.corda.core.crypto.containsAny import net.corda.core.flows.FlowLogicRefFactory @@ -13,6 +10,7 @@ import net.corda.core.identity.AbstractParty import net.corda.core.identity.Party import net.corda.core.node.services.ServiceType import net.corda.core.serialization.CordaSerializable +import net.corda.core.transactions.LedgerTransaction import net.corda.core.transactions.TransactionBuilder import net.corda.irs.api.NodeInterestRates import net.corda.irs.flows.FixingFlow @@ -461,190 +459,137 @@ class InterestRateSwap : Contract { fixingCalendar, index, indexSource, indexTenor) } - override fun verify(tx: TransactionForContract) = verifyClause(tx, AllOf(Clauses.TimeWindow(), Clauses.Group()), tx.commands.select()) + // These functions may make more sense to use for basket types, but for now let's leave them here + private fun checkLegDates(legs: List) { + requireThat { + "Effective date is before termination date" using legs.all { it.effectiveDate < it.terminationDate } + "Effective dates are in alignment" using legs.all { it.effectiveDate == legs[0].effectiveDate } + "Termination dates are in alignment" using legs.all { it.terminationDate == legs[0].terminationDate } + } + } - interface Clauses { - /** - * Common superclass for IRS contract clauses, which defines behaviour on match/no-match, and provides - * helper functions for the clauses. - */ - abstract class AbstractIRSClause : Clause() { - // These functions may make more sense to use for basket types, but for now let's leave them here - fun checkLegDates(legs: List) { + private fun checkLegAmounts(legs: List) { + requireThat { + "The notional is non zero" using legs.any { it.notional.quantity > (0).toLong() } + "The notional for all legs must be the same" using legs.all { it.notional == legs[0].notional } + } + for (leg: CommonLeg in legs) { + if (leg is FixedLeg) { requireThat { - "Effective date is before termination date" using legs.all { it.effectiveDate < it.terminationDate } - "Effective dates are in alignment" using legs.all { it.effectiveDate == legs[0].effectiveDate } - "Termination dates are in alignment" using legs.all { it.terminationDate == legs[0].terminationDate } - } - } - - fun checkLegAmounts(legs: List) { - requireThat { - "The notional is non zero" using legs.any { it.notional.quantity > (0).toLong() } - "The notional for all legs must be the same" using legs.all { it.notional == legs[0].notional } - } - for (leg: CommonLeg in legs) { - if (leg is FixedLeg) { - requireThat { - // TODO: Confirm: would someone really enter a swap with a negative fixed rate? - "Fixed leg rate must be positive" using leg.fixedRate.isPositive() - } - } - } - } - - // TODO: After business rules discussion, add further checks to the schedules and rates - fun checkSchedules(@Suppress("UNUSED_PARAMETER") legs: List): Boolean = true - - fun checkRates(@Suppress("UNUSED_PARAMETER") legs: List): Boolean = true - - /** - * Compares two schedules of Floating Leg Payments, returns the difference (i.e. omissions in either leg or changes to the values). - */ - fun getFloatingLegPaymentsDifferences(payments1: Map, payments2: Map): List>> { - val diff1 = payments1.filter { payments1[it.key] != payments2[it.key] } - val diff2 = payments2.filter { payments1[it.key] != payments2[it.key] } - return (diff1.keys + diff2.keys).map { - it to Pair(diff1[it] as FloatingRatePaymentEvent, diff2[it] as FloatingRatePaymentEvent) + // TODO: Confirm: would someone really enter a swap with a negative fixed rate? + "Fixed leg rate must be positive" using leg.fixedRate.isPositive() } } } + } - class Group : GroupClauseVerifier(AnyOf(Agree(), Fix(), Pay(), Mature())) { - // Group by Trade ID for in / out states - override fun groupStates(tx: TransactionForContract): List> { - return tx.groupStates { state -> state.linearId } - } + /** + * Compares two schedules of Floating Leg Payments, returns the difference (i.e. omissions in either leg or changes to the values). + */ + private fun getFloatingLegPaymentsDifferences(payments1: Map, payments2: Map): List>> { + val diff1 = payments1.filter { payments1[it.key] != payments2[it.key] } + val diff2 = payments2.filter { payments1[it.key] != payments2[it.key] } + return (diff1.keys + diff2.keys).map { + it to Pair(diff1[it] as FloatingRatePaymentEvent, diff2[it] as FloatingRatePaymentEvent) + } + } + + private fun verifyAgreeCommand(inputs: List, outputs: List) { + val irs = outputs.filterIsInstance().single() + requireThat { + "There are no in states for an agreement" using inputs.isEmpty() + "There are events in the fix schedule" using (irs.calculation.fixedLegPaymentSchedule.isNotEmpty()) + "There are events in the float schedule" using (irs.calculation.floatingLegPaymentSchedule.isNotEmpty()) + "All notionals must be non zero" using (irs.fixedLeg.notional.quantity > 0 && irs.floatingLeg.notional.quantity > 0) + "The fixed leg rate must be positive" using (irs.fixedLeg.fixedRate.isPositive()) + "The currency of the notionals must be the same" using (irs.fixedLeg.notional.token == irs.floatingLeg.notional.token) + "All leg notionals must be the same" using (irs.fixedLeg.notional == irs.floatingLeg.notional) + "The effective date is before the termination date for the fixed leg" using (irs.fixedLeg.effectiveDate < irs.fixedLeg.terminationDate) + "The effective date is before the termination date for the floating leg" using (irs.floatingLeg.effectiveDate < irs.floatingLeg.terminationDate) + "The effective dates are aligned" using (irs.floatingLeg.effectiveDate == irs.fixedLeg.effectiveDate) + "The termination dates are aligned" using (irs.floatingLeg.terminationDate == irs.fixedLeg.terminationDate) + "The fixing period date offset cannot be negative" using (irs.floatingLeg.fixingPeriodOffset >= 0) + + // TODO: further tests + } + checkLegAmounts(listOf(irs.fixedLeg, irs.floatingLeg)) + checkLegDates(listOf(irs.fixedLeg, irs.floatingLeg)) + } + + private fun verifyFixCommand(inputs: List, outputs: List, command: AuthenticatedObject) { + val irs = outputs.filterIsInstance().single() + val prevIrs = inputs.filterIsInstance().single() + val paymentDifferences = getFloatingLegPaymentsDifferences(prevIrs.calculation.floatingLegPaymentSchedule, irs.calculation.floatingLegPaymentSchedule) + + // Having both of these tests are "redundant" as far as verify() goes, however, by performing both + // we can relay more information back to the user in the case of failure. + requireThat { + "There is at least one difference in the IRS floating leg payment schedules" using !paymentDifferences.isEmpty() + "There is only one change in the IRS floating leg payment schedule" using (paymentDifferences.size == 1) } - class TimeWindow : Clause() { - override fun verify(tx: TransactionForContract, - inputs: List, - outputs: List, - commands: List>, - groupingKey: Unit?): Set { - require(tx.timeWindow?.midpoint != null) { "must be have a time-window)" } - // We return an empty set because we don't process any commands - return emptySet() + val (oldFloatingRatePaymentEvent, newFixedRatePaymentEvent) = paymentDifferences.single().second // Ignore the date of the changed rate (we checked that earlier). + val fixValue = command.value.fix + // Need to check that everything is the same apart from the new fixed rate entry. + requireThat { + "The fixed leg parties are constant" using (irs.fixedLeg.fixedRatePayer == prevIrs.fixedLeg.fixedRatePayer) // Although superseded by the below test, this is included for a regression issue + "The fixed leg is constant" using (irs.fixedLeg == prevIrs.fixedLeg) + "The floating leg is constant" using (irs.floatingLeg == prevIrs.floatingLeg) + "The common values are constant" using (irs.common == prevIrs.common) + "The fixed leg payment schedule is constant" using (irs.calculation.fixedLegPaymentSchedule == prevIrs.calculation.fixedLegPaymentSchedule) + "The expression is unchanged" using (irs.calculation.expression == prevIrs.calculation.expression) + "There is only one changed payment in the floating leg" using (paymentDifferences.size == 1) + "There changed payment is a floating payment" using (oldFloatingRatePaymentEvent.rate is ReferenceRate) + "The new payment is a fixed payment" using (newFixedRatePaymentEvent.rate is FixedRate) + "The changed payments dates are aligned" using (oldFloatingRatePaymentEvent.date == newFixedRatePaymentEvent.date) + "The new payment has the correct rate" using (newFixedRatePaymentEvent.rate.ratioUnit!!.value == fixValue.value) + "The fixing is for the next required date" using (prevIrs.calculation.nextFixingDate() == fixValue.of.forDay) + "The fix payment has the same currency as the notional" using (newFixedRatePaymentEvent.flow.token == irs.floatingLeg.notional.token) + // "The fixing is not in the future " by (fixCommand) // The oracle should not have signed this . + } + } + + private fun verifyPayCommand() { + requireThat { + "Payments not supported / verifiable yet" using false + } + } + + private fun verifyMatureCommand(inputs: List, outputs: List) { + val irs = inputs.filterIsInstance().single() + requireThat { + "No more fixings to be applied" using (irs.calculation.nextFixingDate() == null) + "The irs is fully consumed and there is no id matched output state" using outputs.isEmpty() + } + } + + override fun verify(tx: LedgerTransaction) { + requireNotNull(tx.timeWindow) { "must be have a time-window)" } + val groups: List> = tx.groupStates { state -> state.linearId } + var atLeastOneCommandProcessed = false + for ((inputs, outputs, _) in groups) { + val agreeCommand = tx.commands.select().firstOrNull() + if (agreeCommand != null) { + verifyAgreeCommand(inputs, outputs) + atLeastOneCommandProcessed = true + } + val fixCommand = tx.commands.select().firstOrNull() + if (fixCommand != null) { + verifyFixCommand(inputs, outputs, fixCommand) + atLeastOneCommandProcessed = true + } + val payCommand = tx.commands.select().firstOrNull() + if (payCommand != null) { + verifyPayCommand() + atLeastOneCommandProcessed = true + } + val matureCommand = tx.commands.select().firstOrNull() + if (matureCommand != null) { + verifyMatureCommand(inputs, outputs) + atLeastOneCommandProcessed = true } } - - class Agree : AbstractIRSClause() { - override val requiredCommands: Set> = setOf(Commands.Agree::class.java) - - override fun verify(tx: TransactionForContract, - inputs: List, - outputs: List, - commands: List>, - groupingKey: UniqueIdentifier?): Set { - val command = tx.commands.requireSingleCommand() - val irs = outputs.filterIsInstance().single() - requireThat { - "There are no in states for an agreement" using inputs.isEmpty() - "There are events in the fix schedule" using (irs.calculation.fixedLegPaymentSchedule.isNotEmpty()) - "There are events in the float schedule" using (irs.calculation.floatingLegPaymentSchedule.isNotEmpty()) - "All notionals must be non zero" using (irs.fixedLeg.notional.quantity > 0 && irs.floatingLeg.notional.quantity > 0) - "The fixed leg rate must be positive" using (irs.fixedLeg.fixedRate.isPositive()) - "The currency of the notionals must be the same" using (irs.fixedLeg.notional.token == irs.floatingLeg.notional.token) - "All leg notionals must be the same" using (irs.fixedLeg.notional == irs.floatingLeg.notional) - - "The effective date is before the termination date for the fixed leg" using (irs.fixedLeg.effectiveDate < irs.fixedLeg.terminationDate) - "The effective date is before the termination date for the floating leg" using (irs.floatingLeg.effectiveDate < irs.floatingLeg.terminationDate) - "The effective dates are aligned" using (irs.floatingLeg.effectiveDate == irs.fixedLeg.effectiveDate) - "The termination dates are aligned" using (irs.floatingLeg.terminationDate == irs.fixedLeg.terminationDate) - "The rates are valid" using checkRates(listOf(irs.fixedLeg, irs.floatingLeg)) - "The schedules are valid" using checkSchedules(listOf(irs.fixedLeg, irs.floatingLeg)) - "The fixing period date offset cannot be negative" using (irs.floatingLeg.fixingPeriodOffset >= 0) - - // TODO: further tests - } - checkLegAmounts(listOf(irs.fixedLeg, irs.floatingLeg)) - checkLegDates(listOf(irs.fixedLeg, irs.floatingLeg)) - - return setOf(command.value) - } - } - - class Fix : AbstractIRSClause() { - override val requiredCommands: Set> = setOf(Commands.Refix::class.java) - - override fun verify(tx: TransactionForContract, - inputs: List, - outputs: List, - commands: List>, - groupingKey: UniqueIdentifier?): Set { - val command = tx.commands.requireSingleCommand() - val irs = outputs.filterIsInstance().single() - val prevIrs = inputs.filterIsInstance().single() - val paymentDifferences = getFloatingLegPaymentsDifferences(prevIrs.calculation.floatingLegPaymentSchedule, irs.calculation.floatingLegPaymentSchedule) - - // Having both of these tests are "redundant" as far as verify() goes, however, by performing both - // we can relay more information back to the user in the case of failure. - requireThat { - "There is at least one difference in the IRS floating leg payment schedules" using !paymentDifferences.isEmpty() - "There is only one change in the IRS floating leg payment schedule" using (paymentDifferences.size == 1) - } - - val changedRates = paymentDifferences.single().second // Ignore the date of the changed rate (we checked that earlier). - val (oldFloatingRatePaymentEvent, newFixedRatePaymentEvent) = changedRates - val fixValue = command.value.fix - // Need to check that everything is the same apart from the new fixed rate entry. - requireThat { - "The fixed leg parties are constant" using (irs.fixedLeg.fixedRatePayer == prevIrs.fixedLeg.fixedRatePayer) // Although superseded by the below test, this is included for a regression issue - "The fixed leg is constant" using (irs.fixedLeg == prevIrs.fixedLeg) - "The floating leg is constant" using (irs.floatingLeg == prevIrs.floatingLeg) - "The common values are constant" using (irs.common == prevIrs.common) - "The fixed leg payment schedule is constant" using (irs.calculation.fixedLegPaymentSchedule == prevIrs.calculation.fixedLegPaymentSchedule) - "The expression is unchanged" using (irs.calculation.expression == prevIrs.calculation.expression) - "There is only one changed payment in the floating leg" using (paymentDifferences.size == 1) - "There changed payment is a floating payment" using (oldFloatingRatePaymentEvent.rate is ReferenceRate) - "The new payment is a fixed payment" using (newFixedRatePaymentEvent.rate is FixedRate) - "The changed payments dates are aligned" using (oldFloatingRatePaymentEvent.date == newFixedRatePaymentEvent.date) - "The new payment has the correct rate" using (newFixedRatePaymentEvent.rate.ratioUnit!!.value == fixValue.value) - "The fixing is for the next required date" using (prevIrs.calculation.nextFixingDate() == fixValue.of.forDay) - "The fix payment has the same currency as the notional" using (newFixedRatePaymentEvent.flow.token == irs.floatingLeg.notional.token) - // "The fixing is not in the future " by (fixCommand) // The oracle should not have signed this . - } - - return setOf(command.value) - } - } - - class Pay : AbstractIRSClause() { - override val requiredCommands: Set> = setOf(Commands.Pay::class.java) - - override fun verify(tx: TransactionForContract, - inputs: List, - outputs: List, - commands: List>, - groupingKey: UniqueIdentifier?): Set { - val command = tx.commands.requireSingleCommand() - requireThat { - "Payments not supported / verifiable yet" using false - } - return setOf(command.value) - } - } - - class Mature : AbstractIRSClause() { - override val requiredCommands: Set> = setOf(Commands.Mature::class.java) - - override fun verify(tx: TransactionForContract, - inputs: List, - outputs: List, - commands: List>, - groupingKey: UniqueIdentifier?): Set { - val command = tx.commands.requireSingleCommand() - val irs = inputs.filterIsInstance().single() - requireThat { - "No more fixings to be applied" using (irs.calculation.nextFixingDate() == null) - "The irs is fully consumed and there is no id matched output state" using outputs.isEmpty() - } - - return setOf(command.value) - } - } - + require(atLeastOneCommandProcessed) { "At least one command needs to present" } } interface Commands : CommandData { @@ -671,7 +616,7 @@ class InterestRateSwap : Contract { override val oracleType: ServiceType get() = NodeInterestRates.Oracle.type - override val ref = common.tradeID + val ref: String get() = linearId.externalId ?: "" override val participants: List get() = listOf(fixedLeg.fixedRatePayer, floatingLeg.floatingRatePayer) @@ -785,7 +730,7 @@ class InterestRateSwap : Contract { // Put all the above into a new State object. val state = State(fixedLeg, floatingLeg, newCalculation, common) - return TransactionType.General.Builder(notary = notary).withItems(state, Command(Commands.Agree(), listOf(state.floatingLeg.floatingRatePayer.owningKey, state.fixedLeg.fixedRatePayer.owningKey))) + return TransactionBuilder(notary).withItems(state, Command(Commands.Agree(), listOf(state.floatingLeg.floatingRatePayer.owningKey, state.fixedLeg.fixedRatePayer.owningKey))) } private fun calcFixingDate(date: LocalDate, fixingPeriodOffset: Int, calendar: BusinessCalendar): LocalDate { diff --git a/samples/irs-demo/src/main/kotlin/net/corda/irs/flows/FixingFlow.kt b/samples/irs-demo/src/main/kotlin/net/corda/irs/flows/FixingFlow.kt index 57110870aa..a902252889 100644 --- a/samples/irs-demo/src/main/kotlin/net/corda/irs/flows/FixingFlow.kt +++ b/samples/irs-demo/src/main/kotlin/net/corda/irs/flows/FixingFlow.kt @@ -3,8 +3,8 @@ package net.corda.irs.flows import co.paralleluniverse.fibers.Suspendable import net.corda.contracts.Fix import net.corda.contracts.FixableDealState -import net.corda.core.TransientProperty import net.corda.core.contracts.* +import net.corda.core.crypto.TransactionSignature import net.corda.core.crypto.toBase58String import net.corda.core.flows.FlowLogic import net.corda.core.flows.InitiatedBy @@ -13,12 +13,13 @@ import net.corda.core.flows.SchedulableFlow import net.corda.core.identity.Party import net.corda.core.node.NodeInfo import net.corda.core.node.services.ServiceType -import net.corda.core.seconds import net.corda.core.serialization.CordaSerializable import net.corda.core.transactions.SignedTransaction import net.corda.core.transactions.TransactionBuilder import net.corda.core.utilities.ProgressTracker +import net.corda.core.utilities.seconds import net.corda.core.utilities.trace +import net.corda.core.utilities.transient import net.corda.flows.TwoPartyDealFlow import java.math.BigDecimal import java.security.PublicKey @@ -52,7 +53,7 @@ object FixingFlow { } @Suspendable - override fun assembleSharedTX(handshake: TwoPartyDealFlow.Handshake): Pair> { + override fun assembleSharedTX(handshake: TwoPartyDealFlow.Handshake): Triple, List> { @Suppress("UNCHECKED_CAST") val fixOf = deal.nextFixingOf()!! @@ -61,7 +62,7 @@ object FixingFlow { val newDeal = deal - val ptx = TransactionType.General.Builder(txState.notary) + val ptx = TransactionBuilder(txState.notary) val oracle = serviceHub.networkMapCache.getNodesWithService(handshake.payload.oracleType).first() val oracleParty = oracle.serviceIdentities(handshake.payload.oracleType).first() @@ -80,14 +81,14 @@ object FixingFlow { @Suspendable override fun filtering(elem: Any): Boolean { return when (elem) { - is Command -> oracleParty.owningKey in elem.signers && elem.value is Fix + is Command<*> -> oracleParty.owningKey in elem.signers && elem.value is Fix else -> false } } } - subFlow(addFixing) + val sig = subFlow(addFixing) // DOCEND 1 - return Pair(ptx, arrayListOf(myOldParty.owningKey)) + return Triple(ptx, arrayListOf(myOldParty.owningKey), listOf(sig)) } } @@ -103,7 +104,7 @@ object FixingFlow { override val progressTracker: ProgressTracker = TwoPartyDealFlow.Primary.tracker()) : TwoPartyDealFlow.Primary() { @Suppress("UNCHECKED_CAST") - internal val dealToFix: StateAndRef by TransientProperty { + internal val dealToFix: StateAndRef by transient { val state = serviceHub.loadState(payload.ref) as TransactionState StateAndRef(state, payload.ref) } diff --git a/samples/irs-demo/src/main/kotlin/net/corda/irs/flows/RatesFixFlow.kt b/samples/irs-demo/src/main/kotlin/net/corda/irs/flows/RatesFixFlow.kt index 0ab867f889..dde01b1226 100644 --- a/samples/irs-demo/src/main/kotlin/net/corda/irs/flows/RatesFixFlow.kt +++ b/samples/irs-demo/src/main/kotlin/net/corda/irs/flows/RatesFixFlow.kt @@ -3,7 +3,7 @@ package net.corda.irs.flows import co.paralleluniverse.fibers.Suspendable import net.corda.contracts.Fix import net.corda.contracts.FixOf -import net.corda.core.crypto.DigitalSignature +import net.corda.core.crypto.TransactionSignature import net.corda.core.crypto.isFulfilledBy import net.corda.core.flows.FlowLogic import net.corda.core.flows.InitiatingFlow @@ -33,7 +33,7 @@ open class RatesFixFlow(protected val tx: TransactionBuilder, protected val fixOf: FixOf, protected val expectedRate: BigDecimal, protected val rateTolerance: BigDecimal, - override val progressTracker: ProgressTracker = RatesFixFlow.tracker(fixOf.name)) : FlowLogic() { + override val progressTracker: ProgressTracker = RatesFixFlow.tracker(fixOf.name)) : FlowLogic() { companion object { class QUERYING(val name: String) : ProgressTracker.Step("Querying oracle for $name interest rate") @@ -54,7 +54,7 @@ open class RatesFixFlow(protected val tx: TransactionBuilder, // DOCSTART 2 @Suspendable - override fun call() { + override fun call(): TransactionSignature { progressTracker.currentStep = progressTracker.steps[1] val fix = subFlow(FixQueryFlow(fixOf, oracle)) progressTracker.currentStep = WORKING @@ -63,8 +63,7 @@ open class RatesFixFlow(protected val tx: TransactionBuilder, beforeSigning(fix) progressTracker.currentStep = SIGNING val mtx = tx.toWireTransaction().buildFilteredTransaction(Predicate { filtering(it) }) - val signature = subFlow(FixSignFlow(tx, oracle, mtx)) - tx.addSignatureUnchecked(signature) + return subFlow(FixSignFlow(tx, oracle, mtx)) } // DOCEND 2 @@ -112,10 +111,10 @@ open class RatesFixFlow(protected val tx: TransactionBuilder, @InitiatingFlow class FixSignFlow(val tx: TransactionBuilder, val oracle: Party, - val partialMerkleTx: FilteredTransaction) : FlowLogic() { + val partialMerkleTx: FilteredTransaction) : FlowLogic() { @Suspendable - override fun call(): DigitalSignature.WithKey { - val resp = sendAndReceive(oracle, SignRequest(partialMerkleTx)) + override fun call(): TransactionSignature { + val resp = sendAndReceive(oracle, SignRequest(partialMerkleTx)) return resp.unwrap { sig -> check(oracle.owningKey.isFulfilledBy(listOf(sig.by))) tx.toWireTransaction().checkSignature(sig) diff --git a/samples/irs-demo/src/main/kotlin/net/corda/irs/utilities/OracleUtils.kt b/samples/irs-demo/src/main/kotlin/net/corda/irs/utilities/OracleUtils.kt index 54413553a8..bb370c9a36 100644 --- a/samples/irs-demo/src/main/kotlin/net/corda/irs/utilities/OracleUtils.kt +++ b/samples/irs-demo/src/main/kotlin/net/corda/irs/utilities/OracleUtils.kt @@ -1,7 +1,11 @@ package net.corda.irs.utilities import net.corda.core.contracts.TimeWindow -import java.time.* +import net.corda.core.utilities.hours +import java.time.LocalDate +import java.time.LocalTime +import java.time.ZoneId +import java.time.ZonedDateTime /** * This whole file exists as short cuts to get demos working. In reality we'd have static data and/or rules engine @@ -16,5 +20,5 @@ fun suggestInterestRateAnnouncementTimeWindow(index: String, source: String, dat // Here we apply a blanket announcement time of 11:45 London irrespective of source or index val time = LocalTime.of(11, 45) val zoneId = ZoneId.of("Europe/London") - return TimeWindow.fromStartAndDuration(ZonedDateTime.of(date, time, zoneId).toInstant(), Duration.ofHours(24)) + return TimeWindow.fromStartAndDuration(ZonedDateTime.of(date, time, zoneId).toInstant(), 24.hours) } diff --git a/samples/irs-demo/src/test/kotlin/net/corda/irs/Main.kt b/samples/irs-demo/src/test/kotlin/net/corda/irs/Main.kt index 321180d97d..f66fbe7a96 100644 --- a/samples/irs-demo/src/test/kotlin/net/corda/irs/Main.kt +++ b/samples/irs-demo/src/test/kotlin/net/corda/irs/Main.kt @@ -1,8 +1,8 @@ package net.corda.irs -import com.google.common.util.concurrent.Futures -import net.corda.core.getOrThrow +import net.corda.core.internal.concurrent.transpose import net.corda.core.node.services.ServiceInfo +import net.corda.core.utilities.getOrThrow import net.corda.testing.DUMMY_BANK_A import net.corda.testing.DUMMY_BANK_B import net.corda.testing.DUMMY_NOTARY @@ -16,11 +16,11 @@ import net.corda.testing.driver.driver */ fun main(args: Array) { driver(dsl = { - val (controller, nodeA, nodeB) = Futures.allAsList( + val (controller, nodeA, nodeB) = listOf( startNode(DUMMY_NOTARY.name, setOf(ServiceInfo(SimpleNotaryService.type), ServiceInfo(NodeInterestRates.Oracle.type))), startNode(DUMMY_BANK_A.name), startNode(DUMMY_BANK_B.name) - ).getOrThrow() + ).transpose().getOrThrow() startWebserver(controller) startWebserver(nodeA) diff --git a/samples/irs-demo/src/test/kotlin/net/corda/irs/api/InterestRatesSwapDemoAPI.kt b/samples/irs-demo/src/test/kotlin/net/corda/irs/api/InterestRatesSwapDemoAPI.kt index 8e9cb6045b..0ea9bfda2f 100644 --- a/samples/irs-demo/src/test/kotlin/net/corda/irs/api/InterestRatesSwapDemoAPI.kt +++ b/samples/irs-demo/src/test/kotlin/net/corda/irs/api/InterestRatesSwapDemoAPI.kt @@ -1,8 +1,8 @@ package net.corda.irs.api -import net.corda.core.getOrThrow import net.corda.core.messaging.CordaRPCOps import net.corda.core.messaging.startFlow +import net.corda.core.utilities.getOrThrow import net.corda.core.utilities.loggerFor import net.corda.irs.flows.UpdateBusinessDayFlow import java.time.LocalDate diff --git a/samples/irs-demo/src/test/kotlin/net/corda/irs/api/NodeInterestRatesTest.kt b/samples/irs-demo/src/test/kotlin/net/corda/irs/api/NodeInterestRatesTest.kt index 9d3035cef3..c501759e09 100644 --- a/samples/irs-demo/src/test/kotlin/net/corda/irs/api/NodeInterestRatesTest.kt +++ b/samples/irs-demo/src/test/kotlin/net/corda/irs/api/NodeInterestRatesTest.kt @@ -6,37 +6,31 @@ import net.corda.contracts.asset.CASH import net.corda.contracts.asset.Cash import net.corda.contracts.asset.`issued by` import net.corda.contracts.asset.`owned by` -import net.corda.core.bd import net.corda.core.contracts.* import net.corda.core.crypto.MerkleTreeException import net.corda.core.crypto.generateKeyPair -import net.corda.core.getOrThrow import net.corda.core.identity.Party import net.corda.core.node.services.ServiceInfo import net.corda.core.transactions.TransactionBuilder -import net.corda.testing.LogHelper import net.corda.core.utilities.ProgressTracker +import net.corda.core.utilities.getOrThrow import net.corda.irs.flows.RatesFixFlow +import net.corda.node.utilities.CordaPersistence import net.corda.node.utilities.configureDatabase -import net.corda.node.utilities.transaction import net.corda.testing.* -import net.corda.testing.node.MockNetwork -import net.corda.testing.node.MockServices -import net.corda.testing.node.makeTestDataSourceProperties +import net.corda.testing.node.* import org.bouncycastle.asn1.x500.X500Name -import org.jetbrains.exposed.sql.Database import org.junit.After import org.junit.Assert import org.junit.Before import org.junit.Test -import java.io.Closeable import java.math.BigDecimal import java.util.function.Predicate import kotlin.test.assertEquals import kotlin.test.assertFailsWith import kotlin.test.assertFalse -class NodeInterestRatesTest { +class NodeInterestRatesTest : TestDependencyInjectionBase() { val TEST_DATA = NodeInterestRates.parseFile(""" LIBOR 2016-03-16 1M = 0.678 LIBOR 2016-03-16 2M = 0.685 @@ -50,23 +44,20 @@ class NodeInterestRatesTest { val DUMMY_CASH_ISSUER = Party(X500Name("CN=Cash issuer,O=R3,OU=corda,L=London,C=GB"), DUMMY_CASH_ISSUER_KEY.public) lateinit var oracle: NodeInterestRates.Oracle - lateinit var dataSource: Closeable - lateinit var database: Database + lateinit var database: CordaPersistence fun fixCmdFilter(elem: Any): Boolean { return when (elem) { - is Command -> oracle.identity.owningKey in elem.signers && elem.value is Fix + is Command<*> -> oracle.identity.owningKey in elem.signers && elem.value is Fix else -> false } } - fun filterCmds(elem: Any): Boolean = elem is Command + fun filterCmds(elem: Any): Boolean = elem is Command<*> @Before fun setUp() { - val dataSourceAndDatabase = configureDatabase(makeTestDataSourceProperties()) - dataSource = dataSourceAndDatabase.first - database = dataSourceAndDatabase.second + database = configureDatabase(makeTestDataSourceProperties(), makeTestDatabaseProperties(), identitySvc = ::makeTestIdentityService) database.transaction { oracle = NodeInterestRates.Oracle( MEGA_CORP, @@ -78,7 +69,7 @@ class NodeInterestRatesTest { @After fun tearDown() { - dataSource.close() + database.close() } @Test @@ -87,7 +78,7 @@ class NodeInterestRatesTest { val q = NodeInterestRates.parseFixOf("LIBOR 2016-03-16 1M") val res = oracle.query(listOf(q)) assertEquals(1, res.size) - assertEquals("0.678".bd, res[0].value) + assertEquals(BigDecimal("0.678"), res[0].value) assertEquals(q, res[0].of) } } @@ -169,7 +160,7 @@ class NodeInterestRatesTest { database.transaction { val tx = makeTX() val fixOf = NodeInterestRates.parseFixOf("LIBOR 2016-03-16 1M") - val badFix = Fix(fixOf, "0.6789".bd) + val badFix = Fix(fixOf, BigDecimal("0.6789")) tx.addCommand(badFix, oracle.identity.owningKey) val wtx = tx.toWireTransaction() val ftx = wtx.buildFilteredTransaction(Predicate { x -> fixCmdFilter(x) }) @@ -185,7 +176,7 @@ class NodeInterestRatesTest { val fix = oracle.query(listOf(NodeInterestRates.parseFixOf("LIBOR 2016-03-16 1M"))).first() fun filtering(elem: Any): Boolean { return when (elem) { - is Command -> oracle.identity.owningKey in elem.signers && elem.value is Fix + is Command<*> -> oracle.identity.owningKey in elem.signers && elem.value is Fix is TransactionState -> true else -> false } @@ -207,7 +198,7 @@ class NodeInterestRatesTest { @Test fun `network tearoff`() { - val mockNet = MockNetwork() + val mockNet = MockNetwork(initialiseSerialization = false) val n1 = mockNet.createNotaryNode() val n2 = mockNet.createNode(n1.network.myAddress, advertisedServices = ServiceInfo(NodeInterestRates.Oracle.type)) n2.registerInitiatedFlow(NodeInterestRates.FixQueryHandler::class.java) @@ -215,10 +206,10 @@ class NodeInterestRatesTest { n2.database.transaction { n2.installCordaService(NodeInterestRates.Oracle::class.java).knownFixes = TEST_DATA } - val tx = TransactionType.General.Builder(null) + val tx = TransactionBuilder(null) val fixOf = NodeInterestRates.parseFixOf("LIBOR 2016-03-16 1M") val oracle = n2.info.serviceIdentities(NodeInterestRates.Oracle.type).first() - val flow = FilteredRatesFlow(tx, oracle, fixOf, "0.675".bd, "0.1".bd) + val flow = FilteredRatesFlow(tx, oracle, fixOf, BigDecimal("0.675"), BigDecimal("0.1")) LogHelper.setLevel("rates") mockNet.runNetwork() val future = n1.services.startFlow(flow).resultFuture @@ -227,7 +218,7 @@ class NodeInterestRatesTest { // We should now have a valid fix of our tx from the oracle. val fix = tx.toWireTransaction().commands.map { it.value as Fix }.first() assertEquals(fixOf, fix.of) - assertEquals("0.678".bd, fix.value) + assertEquals(BigDecimal("0.678"), fix.value) mockNet.stopNodes() } @@ -240,12 +231,12 @@ class NodeInterestRatesTest { : RatesFixFlow(tx, oracle, fixOf, expectedRate, rateTolerance, progressTracker) { override fun filtering(elem: Any): Boolean { return when (elem) { - is Command -> oracle.owningKey in elem.signers && elem.value is Fix + is Command<*> -> oracle.owningKey in elem.signers && elem.value is Fix else -> false } } } - private fun makeTX() = TransactionType.General.Builder(DUMMY_NOTARY).withItems( + private fun makeTX() = TransactionBuilder(DUMMY_NOTARY).withItems( 1000.DOLLARS.CASH `issued by` DUMMY_CASH_ISSUER `owned by` ALICE `with notary` DUMMY_NOTARY) } diff --git a/samples/irs-demo/src/test/kotlin/net/corda/irs/contract/IRSTests.kt b/samples/irs-demo/src/test/kotlin/net/corda/irs/contract/IRSTests.kt index fc01eddcaa..68c4bb195b 100644 --- a/samples/irs-demo/src/test/kotlin/net/corda/irs/contract/IRSTests.kt +++ b/samples/irs-demo/src/test/kotlin/net/corda/irs/contract/IRSTests.kt @@ -2,11 +2,9 @@ package net.corda.irs.contract import net.corda.contracts.* import net.corda.core.contracts.* -import net.corda.core.seconds import net.corda.core.transactions.SignedTransaction -import net.corda.testing.DUMMY_NOTARY -import net.corda.testing.DUMMY_NOTARY_KEY -import net.corda.testing.TEST_TX_TIME +import net.corda.core.transactions.TransactionBuilder +import net.corda.core.utilities.seconds import net.corda.testing.* import net.corda.testing.node.MockServices import org.junit.Test @@ -200,7 +198,7 @@ fun createDummyIRS(irsSelect: Int): InterestRateSwap.State { } } -class IRSTests { +class IRSTests : TestDependencyInjectionBase() { val megaCorpServices = MockServices(MEGA_CORP_KEY) val miniCorpServices = MockServices(MINI_CORP_KEY) val notaryServices = MockServices(DUMMY_NOTARY_KEY) @@ -249,7 +247,7 @@ class IRSTests { * Utility so I don't have to keep typing this. */ fun singleIRS(irsSelector: Int = 1): InterestRateSwap.State { - return generateIRSTxn(irsSelector).tx.outputs.map { it.data }.filterIsInstance().single() + return generateIRSTxn(irsSelector).tx.outputsOfType().single() } /** @@ -303,12 +301,12 @@ class IRSTests { var previousTXN = generateIRSTxn(1) previousTXN.toLedgerTransaction(services).verify() services.recordTransactions(previousTXN) - fun currentIRS() = previousTXN.tx.outputs.map { it.data }.filterIsInstance().single() + fun currentIRS() = previousTXN.tx.outputsOfType().single() while (true) { val nextFix: FixOf = currentIRS().nextFixingOf() ?: break val fixTX: SignedTransaction = run { - val tx = TransactionType.General.Builder(DUMMY_NOTARY) + val tx = TransactionBuilder(DUMMY_NOTARY) val fixing = Fix(nextFix, "0.052".percent.value) InterestRateSwap().generateFix(tx, previousTXN.tx.outRef(0), fixing) tx.setTimeWindow(TEST_TX_TIME, 30.seconds) @@ -370,7 +368,7 @@ class IRSTests { val ld = LocalDate.of(2016, 3, 8) val bd = BigDecimal("0.0063518") - return ledger { + return ledger(initialiseSerialization = false) { transaction("Agreement") { output("irs post agreement") { singleIRS() } command(MEGA_CORP_PUBKEY) { InterestRateSwap.Commands.Agree() } @@ -401,7 +399,7 @@ class IRSTests { @Test fun `ensure failure occurs when there are inbound states for an agreement command`() { val irs = singleIRS() - transaction { + transaction(initialiseSerialization = false) { input { irs } output("irs post agreement") { irs } command(MEGA_CORP_PUBKEY) { InterestRateSwap.Commands.Agree() } @@ -414,7 +412,7 @@ class IRSTests { fun `ensure failure occurs when no events in fix schedule`() { val irs = singleIRS() val emptySchedule = mutableMapOf() - transaction { + transaction(initialiseSerialization = false) { output { irs.copy(calculation = irs.calculation.copy(fixedLegPaymentSchedule = emptySchedule)) } @@ -428,7 +426,7 @@ class IRSTests { fun `ensure failure occurs when no events in floating schedule`() { val irs = singleIRS() val emptySchedule = mutableMapOf() - transaction { + transaction(initialiseSerialization = false) { output { irs.copy(calculation = irs.calculation.copy(floatingLegPaymentSchedule = emptySchedule)) } @@ -441,7 +439,7 @@ class IRSTests { @Test fun `ensure notionals are non zero`() { val irs = singleIRS() - transaction { + transaction(initialiseSerialization = false) { output { irs.copy(irs.fixedLeg.copy(notional = irs.fixedLeg.notional.copy(quantity = 0))) } @@ -450,7 +448,7 @@ class IRSTests { this `fails with` "All notionals must be non zero" } - transaction { + transaction(initialiseSerialization = false) { output { irs.copy(irs.fixedLeg.copy(notional = irs.floatingLeg.notional.copy(quantity = 0))) } @@ -464,7 +462,7 @@ class IRSTests { fun `ensure positive rate on fixed leg`() { val irs = singleIRS() val modifiedIRS = irs.copy(fixedLeg = irs.fixedLeg.copy(fixedRate = FixedRate(PercentageRatioUnit("-0.1")))) - transaction { + transaction(initialiseSerialization = false) { output { modifiedIRS } @@ -481,7 +479,7 @@ class IRSTests { fun `ensure same currency notionals`() { val irs = singleIRS() val modifiedIRS = irs.copy(fixedLeg = irs.fixedLeg.copy(notional = Amount(irs.fixedLeg.notional.quantity, Currency.getInstance("JPY")))) - transaction { + transaction(initialiseSerialization = false) { output { modifiedIRS } @@ -495,7 +493,7 @@ class IRSTests { fun `ensure notional amounts are equal`() { val irs = singleIRS() val modifiedIRS = irs.copy(fixedLeg = irs.fixedLeg.copy(notional = Amount(irs.floatingLeg.notional.quantity + 1, irs.floatingLeg.notional.token))) - transaction { + transaction(initialiseSerialization = false) { output { modifiedIRS } @@ -509,7 +507,7 @@ class IRSTests { fun `ensure trade date and termination date checks are done pt1`() { val irs = singleIRS() val modifiedIRS1 = irs.copy(fixedLeg = irs.fixedLeg.copy(terminationDate = irs.fixedLeg.effectiveDate.minusDays(1))) - transaction { + transaction(initialiseSerialization = false) { output { modifiedIRS1 } @@ -519,7 +517,7 @@ class IRSTests { } val modifiedIRS2 = irs.copy(floatingLeg = irs.floatingLeg.copy(terminationDate = irs.floatingLeg.effectiveDate.minusDays(1))) - transaction { + transaction(initialiseSerialization = false) { output { modifiedIRS2 } @@ -534,7 +532,7 @@ class IRSTests { val irs = singleIRS() val modifiedIRS3 = irs.copy(floatingLeg = irs.floatingLeg.copy(terminationDate = irs.fixedLeg.terminationDate.minusDays(1))) - transaction { + transaction(initialiseSerialization = false) { output { modifiedIRS3 } @@ -545,7 +543,7 @@ class IRSTests { val modifiedIRS4 = irs.copy(floatingLeg = irs.floatingLeg.copy(effectiveDate = irs.fixedLeg.effectiveDate.minusDays(1))) - transaction { + transaction(initialiseSerialization = false) { output { modifiedIRS4 } @@ -561,7 +559,7 @@ class IRSTests { val ld = LocalDate.of(2016, 3, 8) val bd = BigDecimal("0.0063518") - transaction { + transaction(initialiseSerialization = false) { output("irs post agreement") { singleIRS() } command(MEGA_CORP_PUBKEY) { InterestRateSwap.Commands.Agree() } timeWindow(TEST_TX_TIME) @@ -574,7 +572,7 @@ class IRSTests { oldIRS.calculation.applyFixing(ld, FixedRate(RatioUnit(bd))), oldIRS.common) - transaction { + transaction(initialiseSerialization = false) { input { oldIRS @@ -654,7 +652,7 @@ class IRSTests { val irs = singleIRS() - return ledger { + return ledger(initialiseSerialization = false) { transaction("Agreement") { output("irs post agreement1") { irs.copy( diff --git a/samples/network-visualiser/src/main/kotlin/net/corda/netmap/NetworkMapVisualiser.kt b/samples/network-visualiser/src/main/kotlin/net/corda/netmap/NetworkMapVisualiser.kt index 655badd4c6..93fa4b4e95 100644 --- a/samples/network-visualiser/src/main/kotlin/net/corda/netmap/NetworkMapVisualiser.kt +++ b/samples/network-visualiser/src/main/kotlin/net/corda/netmap/NetworkMapVisualiser.kt @@ -13,7 +13,6 @@ import javafx.stage.Stage import javafx.util.Duration import net.corda.core.crypto.commonName import net.corda.core.serialization.deserialize -import net.corda.core.then import net.corda.core.utilities.ProgressTracker import net.corda.netmap.VisualiserViewModel.Style import net.corda.netmap.simulation.IRSSimulation diff --git a/samples/network-visualiser/src/main/kotlin/net/corda/netmap/VisualiserViewModel.kt b/samples/network-visualiser/src/main/kotlin/net/corda/netmap/VisualiserViewModel.kt index afcfac1f43..547d29be7b 100644 --- a/samples/network-visualiser/src/main/kotlin/net/corda/netmap/VisualiserViewModel.kt +++ b/samples/network-visualiser/src/main/kotlin/net/corda/netmap/VisualiserViewModel.kt @@ -8,6 +8,7 @@ import javafx.scene.shape.Circle import javafx.scene.shape.Line import javafx.util.Duration import net.corda.core.crypto.commonName +import net.corda.core.node.ScreenCoordinate import net.corda.core.utilities.ProgressTracker import net.corda.netmap.simulation.IRSSimulation import net.corda.testing.node.MockNetwork @@ -26,8 +27,8 @@ class VisualiserViewModel { inner class NodeWidget(val node: MockNetwork.MockNode, val innerDot: Circle, val outerDot: Circle, val longPulseDot: Circle, val pulseAnim: Animation, val longPulseAnim: Animation, val nameLabel: Label, val statusLabel: Label) { - fun position(index: Int, nodeCoords: (node: MockNetwork.MockNode, index: Int) -> Pair) { - val (x, y) = nodeCoords(node, index) + fun position(nodeCoords: (node: MockNetwork.MockNode) -> ScreenCoordinate) { + val (x, y) = nodeCoords(node) innerDot.centerX = x innerDot.centerY = y outerDot.centerX = x @@ -63,20 +64,20 @@ class VisualiserViewModel { fun repositionNodes() { for ((index, bank) in simulation.banks.withIndex()) { - nodesToWidgets[bank]!!.position(index, when (displayStyle) { - Style.MAP -> { node, _ -> nodeMapCoords(node) } - Style.CIRCLE -> { _, index -> nodeCircleCoords(NetworkMapVisualiser.NodeType.BANK, index) } + nodesToWidgets[bank]!!.position(when (displayStyle) { + Style.MAP -> { node -> nodeMapCoords(node) } + Style.CIRCLE -> { _ -> nodeCircleCoords(NetworkMapVisualiser.NodeType.BANK, index) } }) } for ((index, serviceProvider) in (simulation.serviceProviders + simulation.regulators).withIndex()) { - nodesToWidgets[serviceProvider]!!.position(index, when (displayStyle) { - Style.MAP -> { node, _ -> nodeMapCoords(node) } - Style.CIRCLE -> { _, index -> nodeCircleCoords(NetworkMapVisualiser.NodeType.SERVICE, index) } + nodesToWidgets[serviceProvider]!!.position(when (displayStyle) { + Style.MAP -> { node -> nodeMapCoords(node) } + Style.CIRCLE -> { _ -> nodeCircleCoords(NetworkMapVisualiser.NodeType.SERVICE, index) } }) } } - fun nodeMapCoords(node: MockNetwork.MockNode): Pair { + fun nodeMapCoords(node: MockNetwork.MockNode): ScreenCoordinate { // For an image of the whole world, we use: // return node.place.coordinate.project(mapImage.fitWidth, mapImage.fitHeight, 85.0511, -85.0511, -180.0, 180.0) @@ -90,7 +91,7 @@ class VisualiserViewModel { } } - fun nodeCircleCoords(type: NetworkMapVisualiser.NodeType, index: Int): Pair { + fun nodeCircleCoords(type: NetworkMapVisualiser.NodeType, index: Int): ScreenCoordinate { val stepRad: Double = when (type) { NetworkMapVisualiser.NodeType.BANK -> 2 * Math.PI / bankCount NetworkMapVisualiser.NodeType.SERVICE -> (2 * Math.PI / serviceCount) @@ -109,7 +110,7 @@ class VisualiserViewModel { val circleY = view.stageHeight / 2 + yOffset val x: Double = radius * Math.cos(tangentRad) + circleX val y: Double = radius * Math.sin(tangentRad) + circleY - return Pair(x, y) + return ScreenCoordinate(x, y) } fun createNodes() { @@ -172,8 +173,8 @@ class VisualiserViewModel { val widget = NodeWidget(forNode, innerDot, outerDot, longPulseOuterDot, pulseAnim, longPulseAnim, nameLabel, statusLabel) when (displayStyle) { - Style.CIRCLE -> widget.position(index, { _, index -> nodeCircleCoords(nodeType, index) }) - Style.MAP -> widget.position(index, { node, _ -> nodeMapCoords(node) }) + Style.CIRCLE -> widget.position { _ -> nodeCircleCoords(nodeType, index) } + Style.MAP -> widget.position { node -> nodeMapCoords(node) } } return widget } diff --git a/samples/network-visualiser/src/main/kotlin/net/corda/netmap/simulation/IRSSimulation.kt b/samples/network-visualiser/src/main/kotlin/net/corda/netmap/simulation/IRSSimulation.kt index c4148f403e..d229f24cc7 100644 --- a/samples/network-visualiser/src/main/kotlin/net/corda/netmap/simulation/IRSSimulation.kt +++ b/samples/network-visualiser/src/main/kotlin/net/corda/netmap/simulation/IRSSimulation.kt @@ -3,18 +3,17 @@ package net.corda.netmap.simulation import co.paralleluniverse.fibers.Suspendable import com.fasterxml.jackson.databind.ObjectMapper import com.fasterxml.jackson.module.kotlin.readValue -import com.google.common.util.concurrent.* -import net.corda.core.* +import net.corda.core.concurrent.CordaFuture import net.corda.core.contracts.StateAndRef -import net.corda.core.contracts.UniqueIdentifier import net.corda.core.flows.FlowLogic -import net.corda.core.internal.FlowStateMachine import net.corda.core.flows.InitiatedBy import net.corda.core.flows.InitiatingFlow import net.corda.core.identity.Party -import net.corda.core.node.services.linearHeadsOfType +import net.corda.core.internal.concurrent.* +import net.corda.core.internal.FlowStateMachine +import net.corda.core.node.services.queryBy +import net.corda.core.toFuture import net.corda.core.transactions.SignedTransaction -import net.corda.testing.DUMMY_CA import net.corda.flows.TwoPartyDealFlow.Acceptor import net.corda.flows.TwoPartyDealFlow.AutoOffer import net.corda.flows.TwoPartyDealFlow.Instigator @@ -22,7 +21,7 @@ import net.corda.irs.contract.InterestRateSwap import net.corda.irs.flows.FixingFlow import net.corda.jackson.JacksonSupport import net.corda.node.services.identity.InMemoryIdentityService -import net.corda.node.utilities.transaction +import net.corda.testing.DUMMY_CA import net.corda.testing.node.InMemoryMessagingNetwork import rx.Observable import java.security.PublicKey @@ -42,8 +41,8 @@ class IRSSimulation(networkSendManuallyPumped: Boolean, runAsync: Boolean, laten private val executeOnNextIteration = Collections.synchronizedList(LinkedList<() -> Unit>()) - override fun startMainSimulation(): ListenableFuture { - val future = SettableFuture.create() + override fun startMainSimulation(): CordaFuture { + val future = openFuture() om = JacksonSupport.createInMemoryMapper(InMemoryIdentityService((banks + regulators + networkMap).map { it.info.legalIdentityAndCert }, trustRoot = DUMMY_CA.certificate)) startIRSDealBetween(0, 1).thenMatch({ @@ -52,44 +51,39 @@ class IRSSimulation(networkSendManuallyPumped: Boolean, runAsync: Boolean, laten executeOnNextIteration.add { // Keep fixing until there's no more left to do. val initialFixFuture = doNextFixing(0, 1) + fun onFailure(t: Throwable) { + future.setException(t) // Propagate the error. + } - Futures.addCallback(initialFixFuture, object : FutureCallback { - override fun onFailure(t: Throwable) { - future.setException(t) // Propagate the error. - } - - override fun onSuccess(result: Unit?) { - // Pause for an iteration. - executeOnNextIteration.add {} - executeOnNextIteration.add { - val f = doNextFixing(0, 1) - if (f != null) { - Futures.addCallback(f, this, MoreExecutors.directExecutor()) - } else { - // All done! - future.set(Unit) - } + fun onSuccess(result: Unit?) { + // Pause for an iteration. + executeOnNextIteration.add {} + executeOnNextIteration.add { + val f = doNextFixing(0, 1) + if (f != null) { + f.thenMatch(::onSuccess, ::onFailure) + } else { + // All done! + future.set(Unit) } } - }, MoreExecutors.directExecutor()) + } + initialFixFuture!!.thenMatch(::onSuccess, ::onFailure) } }, {}) return future } - private fun loadLinearHeads(node: SimulatedNode): Map> { - return node.database.transaction { - node.services.vaultService.linearHeadsOfType() - } - } - - private fun doNextFixing(i: Int, j: Int): ListenableFuture? { + private fun doNextFixing(i: Int, j: Int): CordaFuture? { println("Doing a fixing between $i and $j") val node1: SimulatedNode = banks[i] val node2: SimulatedNode = banks[j] - val swaps: Map> = loadLinearHeads(node1) - val theDealRef: StateAndRef = swaps.values.single() + val swaps = + node1.database.transaction { + node1.services.vaultQueryService.queryBy().states + } + val theDealRef: StateAndRef = swaps.single() // Do we have any more days left in this deal's lifetime? If not, return. val nextFixingDate = theDealRef.state.data.calculation.nextFixingDate() ?: return null @@ -106,10 +100,10 @@ class IRSSimulation(networkSendManuallyPumped: Boolean, runAsync: Boolean, laten if (nextFixingDate > currentDateAndTime.toLocalDate()) currentDateAndTime = nextFixingDate.atTime(15, 0) - return Futures.allAsList(futA, futB).map { Unit } + return listOf(futA, futB).transpose().map { Unit } } - private fun startIRSDealBetween(i: Int, j: Int): ListenableFuture { + private fun startIRSDealBetween(i: Int, j: Int): CordaFuture { val node1: SimulatedNode = banks[i] val node2: SimulatedNode = banks[j] @@ -153,7 +147,7 @@ class IRSSimulation(networkSendManuallyPumped: Boolean, runAsync: Boolean, laten node1.services.legalIdentityKey) val instigatorTxFuture = node1.services.startFlow(instigator).resultFuture - return Futures.allAsList(instigatorTxFuture, acceptorTxFuture).flatMap { instigatorTxFuture } + return listOf(instigatorTxFuture, acceptorTxFuture).transpose().flatMap { instigatorTxFuture } } override fun iterate(): InMemoryMessagingNetwork.MessageTransfer? { diff --git a/samples/network-visualiser/src/main/kotlin/net/corda/netmap/simulation/Simulation.kt b/samples/network-visualiser/src/main/kotlin/net/corda/netmap/simulation/Simulation.kt index 1548e83ef6..c70c383102 100644 --- a/samples/network-visualiser/src/main/kotlin/net/corda/netmap/simulation/Simulation.kt +++ b/samples/network-visualiser/src/main/kotlin/net/corda/netmap/simulation/Simulation.kt @@ -1,15 +1,16 @@ package net.corda.netmap.simulation -import com.google.common.util.concurrent.Futures -import com.google.common.util.concurrent.ListenableFuture +import net.corda.core.concurrent.CordaFuture import net.corda.core.crypto.locationOrNull -import net.corda.core.flatMap import net.corda.core.flows.FlowLogic import net.corda.core.messaging.SingleMessageRecipient import net.corda.core.node.CityDatabase import net.corda.core.node.WorldMapLocation import net.corda.core.node.services.ServiceInfo import net.corda.core.node.services.containsType +import net.corda.core.internal.concurrent.doneFuture +import net.corda.core.internal.concurrent.flatMap +import net.corda.core.internal.concurrent.transpose import net.corda.testing.DUMMY_MAP import net.corda.testing.DUMMY_NOTARY import net.corda.testing.DUMMY_REGULATOR @@ -19,7 +20,6 @@ import net.corda.node.services.config.NodeConfiguration import net.corda.node.services.network.NetworkMapService import net.corda.node.services.statemachine.StateMachineManager import net.corda.node.services.transactions.SimpleNotaryService -import net.corda.node.utilities.transaction import net.corda.testing.node.InMemoryMessagingNetwork import net.corda.testing.node.MockNetwork import net.corda.testing.node.TestClock @@ -62,12 +62,12 @@ abstract class Simulation(val networkSendManuallyPumped: Boolean, } } - inner class BankFactory : MockNetwork.Factory { + inner class BankFactory : MockNetwork.Factory { var counter = 0 override fun create(config: NodeConfiguration, network: MockNetwork, networkMapAddr: SingleMessageRecipient?, advertisedServices: Set, id: Int, overrideServices: Map?, - entropyRoot: BigInteger): MockNetwork.MockNode { + entropyRoot: BigInteger): SimulatedNode { val letter = 'A' + counter val (city, country) = bankLocations[counter++ % bankLocations.size] @@ -80,17 +80,17 @@ abstract class Simulation(val networkSendManuallyPumped: Boolean, fun createAll(): List { return bankLocations.mapIndexed { i, _ -> // Use deterministic seeds so the simulation is stable. Needed so that party owning keys are stable. - mockNet.createNode(networkMap.network.myAddress, start = false, nodeFactory = this, entropyRoot = BigInteger.valueOf(i.toLong())) as SimulatedNode + mockNet.createNode(networkMap.network.myAddress, nodeFactory = this, start = false, entropyRoot = BigInteger.valueOf(i.toLong())) } } } val bankFactory = BankFactory() - object NetworkMapNodeFactory : MockNetwork.Factory { + object NetworkMapNodeFactory : MockNetwork.Factory { override fun create(config: NodeConfiguration, network: MockNetwork, networkMapAddr: SingleMessageRecipient?, advertisedServices: Set, id: Int, overrideServices: Map?, - entropyRoot: BigInteger): MockNetwork.MockNode { + entropyRoot: BigInteger): SimulatedNode { require(advertisedServices.containsType(NetworkMapService.type)) val cfg = testNodeConfiguration( baseDirectory = config.baseDirectory, @@ -99,10 +99,10 @@ abstract class Simulation(val networkSendManuallyPumped: Boolean, } } - object NotaryNodeFactory : MockNetwork.Factory { + object NotaryNodeFactory : MockNetwork.Factory { override fun create(config: NodeConfiguration, network: MockNetwork, networkMapAddr: SingleMessageRecipient?, advertisedServices: Set, id: Int, overrideServices: Map?, - entropyRoot: BigInteger): MockNetwork.MockNode { + entropyRoot: BigInteger): SimulatedNode { require(advertisedServices.containsType(SimpleNotaryService.type)) val cfg = testNodeConfiguration( baseDirectory = config.baseDirectory, @@ -111,19 +111,19 @@ abstract class Simulation(val networkSendManuallyPumped: Boolean, } } - object RatesOracleFactory : MockNetwork.Factory { + object RatesOracleFactory : MockNetwork.Factory { // TODO: Make a more realistic legal name val RATES_SERVICE_NAME = X500Name("CN=Rates Service Provider,O=R3,OU=corda,L=Madrid,C=ES") override fun create(config: NodeConfiguration, network: MockNetwork, networkMapAddr: SingleMessageRecipient?, advertisedServices: Set, id: Int, overrideServices: Map?, - entropyRoot: BigInteger): MockNetwork.MockNode { + entropyRoot: BigInteger): SimulatedNode { require(advertisedServices.containsType(NodeInterestRates.Oracle.type)) val cfg = testNodeConfiguration( baseDirectory = config.baseDirectory, myLegalName = RATES_SERVICE_NAME) return object : SimulatedNode(cfg, network, networkMapAddr, advertisedServices, id, overrideServices, entropyRoot) { - override fun start(): MockNetwork.MockNode { + override fun start() { super.start() registerInitiatedFlow(NodeInterestRates.FixQueryHandler::class.java) registerInitiatedFlow(NodeInterestRates.FixSignHandler::class.java) @@ -132,16 +132,15 @@ abstract class Simulation(val networkSendManuallyPumped: Boolean, installCordaService(NodeInterestRates.Oracle::class.java).uploadFixes(it.reader().readText()) } } - return this } } } } - object RegulatorFactory : MockNetwork.Factory { + object RegulatorFactory : MockNetwork.Factory { override fun create(config: NodeConfiguration, network: MockNetwork, networkMapAddr: SingleMessageRecipient?, advertisedServices: Set, id: Int, overrideServices: Map?, - entropyRoot: BigInteger): MockNetwork.MockNode { + entropyRoot: BigInteger): SimulatedNode { val cfg = testNodeConfiguration( baseDirectory = config.baseDirectory, myLegalName = DUMMY_REGULATOR.name) @@ -155,13 +154,10 @@ abstract class Simulation(val networkSendManuallyPumped: Boolean, val mockNet = MockNetwork(networkSendManuallyPumped, runAsync) // This one must come first. - val networkMap: SimulatedNode - = mockNet.createNode(null, nodeFactory = NetworkMapNodeFactory, advertisedServices = ServiceInfo(NetworkMapService.type)) as SimulatedNode - val notary: SimulatedNode - = mockNet.createNode(networkMap.network.myAddress, nodeFactory = NotaryNodeFactory, advertisedServices = ServiceInfo(SimpleNotaryService.type)) as SimulatedNode - val regulators: List = listOf(mockNet.createNode(networkMap.network.myAddress, start = false, nodeFactory = RegulatorFactory) as SimulatedNode) - val ratesOracle: SimulatedNode - = mockNet.createNode(networkMap.network.myAddress, start = false, nodeFactory = RatesOracleFactory, advertisedServices = ServiceInfo(NodeInterestRates.Oracle.type)) as SimulatedNode + val networkMap = mockNet.createNode(nodeFactory = NetworkMapNodeFactory, advertisedServices = ServiceInfo(NetworkMapService.type)) + val notary = mockNet.createNode(networkMap.network.myAddress, nodeFactory = NotaryNodeFactory, advertisedServices = ServiceInfo(SimpleNotaryService.type)) + val regulators = listOf(mockNet.createNode(networkMap.network.myAddress, start = false, nodeFactory = RegulatorFactory)) + val ratesOracle = mockNet.createNode(networkMap.network.myAddress, start = false, nodeFactory = RatesOracleFactory, advertisedServices = ServiceInfo(NodeInterestRates.Oracle.type)) // All nodes must be in one of these two lists for the purposes of the visualiser tool. val serviceProviders: List = listOf(notary, ratesOracle, networkMap) @@ -265,10 +261,9 @@ abstract class Simulation(val networkSendManuallyPumped: Boolean, } } - val networkInitialisationFinished: ListenableFuture<*> = - Futures.allAsList(mockNet.nodes.map { it.networkMapRegistrationFuture }) + val networkInitialisationFinished = mockNet.nodes.map { it.networkMapRegistrationFuture }.transpose() - fun start(): ListenableFuture { + fun start(): CordaFuture { mockNet.startNodes() // Wait for all the nodes to have finished registering with the network map service. return networkInitialisationFinished.flatMap { startMainSimulation() } @@ -278,8 +273,8 @@ abstract class Simulation(val networkSendManuallyPumped: Boolean, * Sub-classes should override this to trigger whatever they want to simulate. This method will be invoked once the * network bringup has been simulated. */ - protected open fun startMainSimulation(): ListenableFuture { - return Futures.immediateFuture(Unit) + protected open fun startMainSimulation(): CordaFuture { + return doneFuture(Unit) } fun stop() { diff --git a/samples/network-visualiser/src/test/kotlin/net/corda/netmap/simulation/IRSSimulationTest.kt b/samples/network-visualiser/src/test/kotlin/net/corda/netmap/simulation/IRSSimulationTest.kt index 6fc24d24e3..3cab8ae88f 100644 --- a/samples/network-visualiser/src/test/kotlin/net/corda/netmap/simulation/IRSSimulationTest.kt +++ b/samples/network-visualiser/src/test/kotlin/net/corda/netmap/simulation/IRSSimulationTest.kt @@ -1,6 +1,6 @@ package net.corda.netmap.simulation -import net.corda.core.getOrThrow +import net.corda.core.utilities.getOrThrow import net.corda.testing.LogHelper import org.junit.Test diff --git a/samples/notary-demo/build.gradle b/samples/notary-demo/build.gradle index 5f9204b252..d6b9266937 100644 --- a/samples/notary-demo/build.gradle +++ b/samples/notary-demo/build.gradle @@ -18,13 +18,13 @@ dependencies { testCompile "junit:junit:$junit_version" // Corda integration dependencies - compile project(path: ":node:capsule", configuration: 'runtimeArtifacts') - compile project(path: ":webserver:webcapsule", configuration: 'runtimeArtifacts') - compile project(':core') - compile project(':client:jfx') - compile project(':client:rpc') - compile project(':test-utils') - compile project(':cordform-common') + cordaCompile project(path: ":node:capsule", configuration: 'runtimeArtifacts') + cordaCompile project(path: ":webserver:webcapsule", configuration: 'runtimeArtifacts') + cordaCompile project(':core') + cordaCompile project(':client:jfx') + cordaCompile project(':client:rpc') + cordaCompile project(':test-utils') + cordaCompile project(':cordform-common') } idea { diff --git a/samples/notary-demo/src/main/kotlin/net/corda/notarydemo/BFTNotaryCordform.kt b/samples/notary-demo/src/main/kotlin/net/corda/notarydemo/BFTNotaryCordform.kt index 3f6af7fc9f..b57011a62d 100644 --- a/samples/notary-demo/src/main/kotlin/net/corda/notarydemo/BFTNotaryCordform.kt +++ b/samples/notary-demo/src/main/kotlin/net/corda/notarydemo/BFTNotaryCordform.kt @@ -1,6 +1,6 @@ package net.corda.notarydemo -import net.corda.core.div +import net.corda.core.internal.div import net.corda.core.node.services.ServiceInfo import net.corda.testing.ALICE import net.corda.testing.BOB @@ -11,8 +11,8 @@ import net.corda.node.utilities.ServiceIdentityGenerator import net.corda.cordform.CordformDefinition import net.corda.cordform.CordformContext import net.corda.cordform.CordformNode -import net.corda.core.stream -import net.corda.core.toTypedArray +import net.corda.core.internal.stream +import net.corda.core.internal.toTypedArray import net.corda.core.utilities.NetworkHostAndPort import net.corda.node.services.transactions.minCorrectReplicas import org.bouncycastle.asn1.x500.X500Name diff --git a/samples/notary-demo/src/main/kotlin/net/corda/notarydemo/Notarise.kt b/samples/notary-demo/src/main/kotlin/net/corda/notarydemo/Notarise.kt index 855fd1dd27..8c7d7afd61 100644 --- a/samples/notary-demo/src/main/kotlin/net/corda/notarydemo/Notarise.kt +++ b/samples/notary-demo/src/main/kotlin/net/corda/notarydemo/Notarise.kt @@ -1,19 +1,19 @@ package net.corda.notarydemo -import com.google.common.util.concurrent.Futures -import com.google.common.util.concurrent.ListenableFuture import net.corda.client.rpc.CordaRPCClient import net.corda.client.rpc.notUsed +import net.corda.core.concurrent.CordaFuture import net.corda.core.crypto.toStringShort -import net.corda.core.getOrThrow -import net.corda.core.map import net.corda.core.messaging.CordaRPCOps import net.corda.core.messaging.startFlow import net.corda.core.transactions.SignedTransaction +import net.corda.core.internal.concurrent.map +import net.corda.core.internal.concurrent.transpose import net.corda.core.utilities.NetworkHostAndPort -import net.corda.testing.BOB +import net.corda.core.utilities.getOrThrow import net.corda.notarydemo.flows.DummyIssueAndMove import net.corda.notarydemo.flows.RPCStartableNotaryFlowClient +import net.corda.testing.BOB import kotlin.streams.asSequence fun main(args: Array) { @@ -27,14 +27,14 @@ fun main(args: Array) { /** Interface for using the notary demo API from a client. */ private class NotaryDemoClientApi(val rpc: CordaRPCOps) { private val notary by lazy { - val (parties, partyUpdates) = rpc.networkMapUpdates() + val (parties, partyUpdates) = rpc.networkMapFeed() partyUpdates.notUsed() val id = parties.stream().filter { it.advertisedServices.any { it.info.type.isNotary() } }.map { it.notaryIdentity }.distinct().asSequence().singleOrNull() checkNotNull(id) { "No unique notary identity, try cleaning the node directories." } } private val counterpartyNode by lazy { - val (parties, partyUpdates) = rpc.networkMapUpdates() + val (parties, partyUpdates) = rpc.networkMapFeed() partyUpdates.notUsed() parties.single { it.legalIdentity.name == BOB.name } } @@ -55,9 +55,9 @@ private class NotaryDemoClientApi(val rpc: CordaRPCOps) { * as it consumes the original asset and creates a copy with the new owner as its output. */ private fun buildTransactions(count: Int): List { - return Futures.allAsList((1..count).map { - rpc.startFlow(::DummyIssueAndMove, notary, counterpartyNode.legalIdentity).returnValue - }).getOrThrow() + return (1..count).map { + rpc.startFlow(::DummyIssueAndMove, notary, counterpartyNode.legalIdentity, it).returnValue + }.transpose().getOrThrow() } /** @@ -66,7 +66,7 @@ private class NotaryDemoClientApi(val rpc: CordaRPCOps) { * * @return a list of encoded signer public keys - one for every transaction */ - private fun notariseTransactions(transactions: List): List>> { + private fun notariseTransactions(transactions: List): List>> { return transactions.map { rpc.startFlow(::RPCStartableNotaryFlowClient, it).returnValue.map { it.map { it.by.toStringShort() } } } diff --git a/samples/notary-demo/src/main/kotlin/net/corda/notarydemo/RaftNotaryCordform.kt b/samples/notary-demo/src/main/kotlin/net/corda/notarydemo/RaftNotaryCordform.kt index 10a94ea6a4..f0f5279def 100644 --- a/samples/notary-demo/src/main/kotlin/net/corda/notarydemo/RaftNotaryCordform.kt +++ b/samples/notary-demo/src/main/kotlin/net/corda/notarydemo/RaftNotaryCordform.kt @@ -1,7 +1,7 @@ package net.corda.notarydemo import net.corda.core.crypto.appendToCommonName -import net.corda.core.div +import net.corda.core.internal.div import net.corda.core.node.services.ServiceInfo import net.corda.testing.ALICE import net.corda.testing.BOB diff --git a/samples/notary-demo/src/main/kotlin/net/corda/notarydemo/SingleNotaryCordform.kt b/samples/notary-demo/src/main/kotlin/net/corda/notarydemo/SingleNotaryCordform.kt index bdd1c21093..567bfe03f5 100644 --- a/samples/notary-demo/src/main/kotlin/net/corda/notarydemo/SingleNotaryCordform.kt +++ b/samples/notary-demo/src/main/kotlin/net/corda/notarydemo/SingleNotaryCordform.kt @@ -1,6 +1,6 @@ package net.corda.notarydemo -import net.corda.core.div +import net.corda.core.internal.div import net.corda.core.node.services.ServiceInfo import net.corda.testing.ALICE import net.corda.testing.BOB diff --git a/samples/notary-demo/src/main/kotlin/net/corda/notarydemo/flows/DummyIssueAndMove.kt b/samples/notary-demo/src/main/kotlin/net/corda/notarydemo/flows/DummyIssueAndMove.kt index 50f51bc9e1..4db198241d 100644 --- a/samples/notary-demo/src/main/kotlin/net/corda/notarydemo/flows/DummyIssueAndMove.kt +++ b/samples/notary-demo/src/main/kotlin/net/corda/notarydemo/flows/DummyIssueAndMove.kt @@ -1,28 +1,41 @@ package net.corda.notarydemo.flows import co.paralleluniverse.fibers.Suspendable -import net.corda.testing.contracts.DummyContract +import net.corda.contracts.asset.Cash +import net.corda.core.contracts.* +import net.corda.core.crypto.sha256 import net.corda.core.flows.FlowLogic import net.corda.core.flows.StartableByRPC +import net.corda.core.identity.AbstractParty import net.corda.core.identity.Party +import net.corda.core.transactions.LedgerTransaction import net.corda.core.transactions.SignedTransaction -import java.util.* +import net.corda.core.transactions.TransactionBuilder @StartableByRPC -class DummyIssueAndMove(private val notary: Party, private val counterpartyNode: Party) : FlowLogic() { +class DummyIssueAndMove(private val notary: Party, private val counterpartyNode: Party, private val discriminator: Int) : FlowLogic() { + object DoNothingContract : Contract { + override val legalContractReference = byteArrayOf().sha256() + override fun verify(tx: LedgerTransaction) {} + } + + data class State(override val participants: List, private val discriminator: Int) : ContractState { + override val contract = DoNothingContract + } + @Suspendable - override fun call(): SignedTransaction { - val random = Random() + override fun call() = serviceHub.run { // Self issue an asset - val issueTxBuilder = DummyContract.generateInitial(random.nextInt(), notary, serviceHub.myInfo.legalIdentity.ref(0)) - val issueTx = serviceHub.signInitialTransaction(issueTxBuilder) + val amount = Amount(1000000, Issued(myInfo.legalIdentity.ref(0), GBP)) + val issueTxBuilder = TransactionBuilder(notary = notary) + val signers = Cash().generateIssue(issueTxBuilder, amount, serviceHub.myInfo.legalIdentity, notary) + val issueTx = serviceHub.signInitialTransaction(issueTxBuilder, signers) serviceHub.recordTransactions(issueTx) // Move ownership of the asset to the counterparty - val counterPartyKey = counterpartyNode.owningKey - val asset = issueTx.tx.outRef(0) - val moveTxBuilder = DummyContract.move(asset, counterpartyNode) - val moveTx = serviceHub.signInitialTransaction(moveTxBuilder) + val moveTxBuilder = TransactionBuilder(notary = notary) + + val (_, keys) = Cash.generateSpend(serviceHub, moveTxBuilder, Amount(amount.quantity, GBP), counterpartyNode) // We don't check signatures because we know that the notary's signature is missing - return moveTx + signInitialTransaction(moveTxBuilder, keys) } } diff --git a/samples/notary-demo/src/main/kotlin/net/corda/notarydemo/flows/RPCStartableNotaryFlowClient.kt b/samples/notary-demo/src/main/kotlin/net/corda/notarydemo/flows/RPCStartableNotaryFlowClient.kt index a3d16c4984..95eb184c09 100644 --- a/samples/notary-demo/src/main/kotlin/net/corda/notarydemo/flows/RPCStartableNotaryFlowClient.kt +++ b/samples/notary-demo/src/main/kotlin/net/corda/notarydemo/flows/RPCStartableNotaryFlowClient.kt @@ -1,8 +1,8 @@ package net.corda.notarydemo.flows +import net.corda.core.flows.NotaryFlow import net.corda.core.flows.StartableByRPC import net.corda.core.transactions.SignedTransaction -import net.corda.flows.NotaryFlow @StartableByRPC class RPCStartableNotaryFlowClient(stx: SignedTransaction) : NotaryFlow.Client(stx) diff --git a/samples/simm-valuation-demo/build.gradle b/samples/simm-valuation-demo/build.gradle index 4be015ab67..41deacb868 100644 --- a/samples/simm-valuation-demo/build.gradle +++ b/samples/simm-valuation-demo/build.gradle @@ -29,11 +29,14 @@ dependencies { compile "org.jetbrains.kotlin:kotlin-stdlib-jre8:$kotlin_version" // Corda integration dependencies - compile project(path: ":node:capsule", configuration: 'runtimeArtifacts') - compile project(path: ":webserver:webcapsule", configuration: 'runtimeArtifacts') - compile project(':core') - compile project(':webserver') - compile project(':finance') + cordaCompile project(path: ":node:capsule", configuration: 'runtimeArtifacts') + cordaCompile project(path: ":webserver:webcapsule", configuration: 'runtimeArtifacts') + cordaCompile project(':core') + cordaCompile project(':finance') + cordaCompile project(':webserver') + + // Javax is required for webapis + compile "org.glassfish.jersey.core:jersey-server:${jersey_version}" // Cordapp dependencies // Specify your cordapp's dependencies below, including dependent cordapps @@ -48,6 +51,7 @@ dependencies { compile "com.opengamma.strata:strata-loader:${strata_version}" compile "com.opengamma.strata:strata-math:${strata_version}" + // Test dependencies testCompile project(':test-utils') testCompile "junit:junit:$junit_version" testCompile "org.assertj:assertj-core:${assertj_version}" @@ -55,9 +59,9 @@ dependencies { task deployNodes(type: net.corda.plugins.Cordform, dependsOn: ['jar']) { directory "./build/nodes" - networkMap "CN=Notary Service,O=R3,OU=corda,L=London,C=GB" + networkMap "CN=Notary Service,O=R3,OU=corda,L=Zurich,C=CH" node { - name "CN=Notary Service,O=R3,OU=corda,L=London,C=GB" + name "CN=Notary Service,O=R3,OU=corda,L=Zurich,C=CH" advertisedServices = ["corda.notary.validating"] p2pPort 10002 cordapps = [] diff --git a/samples/simm-valuation-demo/src/integration-test/kotlin/net/corda/vega/SimmValuationTest.kt b/samples/simm-valuation-demo/src/integration-test/kotlin/net/corda/vega/SimmValuationTest.kt index ec255ccd88..b10b1ea541 100644 --- a/samples/simm-valuation-demo/src/integration-test/kotlin/net/corda/vega/SimmValuationTest.kt +++ b/samples/simm-valuation-demo/src/integration-test/kotlin/net/corda/vega/SimmValuationTest.kt @@ -1,9 +1,9 @@ package net.corda.vega -import com.google.common.util.concurrent.Futures import com.opengamma.strata.product.common.BuySell -import net.corda.core.getOrThrow import net.corda.core.node.services.ServiceInfo +import net.corda.core.internal.concurrent.transpose +import net.corda.core.utilities.getOrThrow import net.corda.testing.DUMMY_BANK_A import net.corda.testing.DUMMY_BANK_B import net.corda.testing.DUMMY_NOTARY @@ -34,8 +34,8 @@ class SimmValuationTest : IntegrationTestCategory { fun `runs SIMM valuation demo`() { driver(isDebug = true) { startNode(DUMMY_NOTARY.name, setOf(ServiceInfo(SimpleNotaryService.type))).getOrThrow() - val (nodeA, nodeB) = Futures.allAsList(startNode(nodeALegalName), startNode(nodeBLegalName)).getOrThrow() - val (nodeAApi, nodeBApi) = Futures.allAsList(startWebserver(nodeA), startWebserver(nodeB)) + val (nodeA, nodeB) = listOf(startNode(nodeALegalName), startNode(nodeBLegalName)).transpose().getOrThrow() + val (nodeAApi, nodeBApi) = listOf(startWebserver(nodeA), startWebserver(nodeB)).transpose() .getOrThrow() .map { HttpApi.fromHostAndPort(it.listenAddress, "api/simmvaluationdemo") } val nodeBParty = getPartyWithName(nodeAApi, nodeBLegalName) diff --git a/samples/simm-valuation-demo/src/main/kotlin/net/corda/vega/analytics/example/OGSwapPricingCcpExample.kt b/samples/simm-valuation-demo/src/main/kotlin/net/corda/vega/analytics/example/OGSwapPricingCcpExample.kt index 3966488829..4c4821d4f5 100644 --- a/samples/simm-valuation-demo/src/main/kotlin/net/corda/vega/analytics/example/OGSwapPricingCcpExample.kt +++ b/samples/simm-valuation-demo/src/main/kotlin/net/corda/vega/analytics/example/OGSwapPricingCcpExample.kt @@ -32,8 +32,8 @@ import com.opengamma.strata.product.common.BuySell import com.opengamma.strata.product.swap.type.FixedIborSwapConventions import com.opengamma.strata.report.ReportCalculationResults import com.opengamma.strata.report.trade.TradeReport -import net.corda.core.div -import net.corda.core.exists +import net.corda.core.internal.div +import net.corda.core.internal.exists import java.nio.file.Path import java.nio.file.Paths import java.time.LocalDate diff --git a/samples/simm-valuation-demo/src/main/kotlin/net/corda/vega/api/PortfolioApi.kt b/samples/simm-valuation-demo/src/main/kotlin/net/corda/vega/api/PortfolioApi.kt index 38ec9efa02..001f57261f 100644 --- a/samples/simm-valuation-demo/src/main/kotlin/net/corda/vega/api/PortfolioApi.kt +++ b/samples/simm-valuation-demo/src/main/kotlin/net/corda/vega/api/PortfolioApi.kt @@ -4,15 +4,15 @@ import com.opengamma.strata.basics.currency.MultiCurrencyAmount import net.corda.client.rpc.notUsed import net.corda.contracts.DealState import net.corda.core.contracts.StateAndRef -import net.corda.core.contracts.filterStatesOfType import net.corda.core.crypto.parsePublicKeyBase58 import net.corda.core.crypto.toBase58String -import net.corda.core.getOrThrow import net.corda.core.identity.AbstractParty import net.corda.core.identity.Party import net.corda.core.messaging.CordaRPCOps import net.corda.core.messaging.startFlow +import net.corda.core.messaging.vaultQueryBy import net.corda.core.node.services.ServiceType +import net.corda.core.utilities.getOrThrow import net.corda.vega.analytics.InitialMarginTriple import net.corda.vega.contracts.IRSState import net.corda.vega.contracts.PortfolioState @@ -38,9 +38,10 @@ class PortfolioApi(val rpc: CordaRPCOps) { private val portfolioUtils = PortfolioApiUtils(ownParty) private inline fun dealsWith(party: AbstractParty): List> { - val (vault, vaultUpdates) = rpc.vaultAndUpdates() - vaultUpdates.notUsed() - return vault.filterStatesOfType().filter { it.state.data.participants.any { it == party } } + val linearStates = rpc.vaultQueryBy().states + // TODO: enhancement to Vault Query to check for any participant in participants attribute + // QueryCriteria.LinearStateQueryCriteria(participants = anyOf(party)) + return linearStates.filter { it.state.data.participants.any { it == party } } } /** diff --git a/samples/simm-valuation-demo/src/main/kotlin/net/corda/vega/contracts/IRSState.kt b/samples/simm-valuation-demo/src/main/kotlin/net/corda/vega/contracts/IRSState.kt index 1b61ac3eed..8d40347f27 100644 --- a/samples/simm-valuation-demo/src/main/kotlin/net/corda/vega/contracts/IRSState.kt +++ b/samples/simm-valuation-demo/src/main/kotlin/net/corda/vega/contracts/IRSState.kt @@ -2,7 +2,6 @@ package net.corda.vega.contracts import net.corda.contracts.DealState import net.corda.core.contracts.Command -import net.corda.core.contracts.TransactionType import net.corda.core.contracts.UniqueIdentifier import net.corda.core.crypto.keys import net.corda.core.identity.AbstractParty @@ -20,7 +19,7 @@ data class IRSState(val swap: SwapData, val seller: AbstractParty, override val contract: OGTrade, override val linearId: UniqueIdentifier = UniqueIdentifier(swap.id.first + swap.id.second)) : DealState { - override val ref: String = linearId.externalId!! // Same as the constructor for UniqueIdentified + val ref: String get() = linearId.externalId!! // Same as the constructor for UniqueIdentified override val participants: List get() = listOf(buyer, seller) override fun isRelevant(ourKeys: Set): Boolean { @@ -29,6 +28,6 @@ data class IRSState(val swap: SwapData, override fun generateAgreement(notary: Party): TransactionBuilder { val state = IRSState(swap, buyer, seller, OGTrade()) - return TransactionType.General.Builder(notary).withItems(state, Command(OGTrade.Commands.Agree(), participants.map { it.owningKey })) + return TransactionBuilder(notary).withItems(state, Command(OGTrade.Commands.Agree(), participants.map { it.owningKey })) } } diff --git a/samples/simm-valuation-demo/src/main/kotlin/net/corda/vega/contracts/OGTrade.kt b/samples/simm-valuation-demo/src/main/kotlin/net/corda/vega/contracts/OGTrade.kt index 488a9dd4a1..9d6cfa56d9 100644 --- a/samples/simm-valuation-demo/src/main/kotlin/net/corda/vega/contracts/OGTrade.kt +++ b/samples/simm-valuation-demo/src/main/kotlin/net/corda/vega/contracts/OGTrade.kt @@ -1,48 +1,22 @@ package net.corda.vega.contracts import net.corda.core.contracts.* -import net.corda.core.contracts.clauses.* import net.corda.core.crypto.SecureHash +import net.corda.core.transactions.LedgerTransaction import java.math.BigDecimal /** * Specifies the contract between two parties that trade an OpenGamma IRS. Currently can only agree to trade. */ data class OGTrade(override val legalContractReference: SecureHash = SecureHash.sha256("OGTRADE.KT")) : Contract { - override fun verify(tx: TransactionForContract) = verifyClause(tx, AllOf(Clauses.TimeWindowed(), Clauses.Group()), tx.commands.select()) - - interface Commands : CommandData { - class Agree : TypeOnlyCommandData(), Commands // Both sides agree to trade - } - - interface Clauses { - class TimeWindowed : Clause() { - override fun verify(tx: TransactionForContract, - inputs: List, - outputs: List, - commands: List>, - groupingKey: Unit?): Set { - require(tx.timeWindow?.midpoint != null) { "must have a time-window" } - // We return an empty set because we don't process any commands - return emptySet() - } - } - - class Group : GroupClauseVerifier(AnyOf(Agree())) { - override fun groupStates(tx: TransactionForContract): List> - // Group by Trade ID for in / out states - = tx.groupStates { state -> state.linearId } - } - - class Agree : Clause() { - override fun verify(tx: TransactionForContract, - inputs: List, - outputs: List, - commands: List>, - groupingKey: UniqueIdentifier?): Set { - val command = tx.commands.requireSingleCommand() - - require(inputs.size == 0) { "Inputs must be empty" } + override fun verify(tx: LedgerTransaction) { + requireNotNull(tx.timeWindow) { "must have a time-window" } + val groups: List> = tx.groupStates { state -> state.linearId } + var atLeastOneCommandProcessed = false + for ((inputs, outputs, key) in groups) { + val command = tx.commands.select().firstOrNull() + if (command != null) { + require(inputs.isEmpty()) { "Inputs must be empty" } require(outputs.size == 1) { "" } require(outputs[0].buyer != outputs[0].seller) require(outputs[0].participants.containsAll(outputs[0].participants)) @@ -50,9 +24,13 @@ data class OGTrade(override val legalContractReference: SecureHash = SecureHash. require(outputs[0].swap.startDate.isBefore(outputs[0].swap.endDate)) require(outputs[0].swap.notional > BigDecimal(0)) require(outputs[0].swap.tradeDate.isBefore(outputs[0].swap.endDate)) - - return setOf(command.value) + atLeastOneCommandProcessed = true } } + require(atLeastOneCommandProcessed) { "At least one command needs to present" } + } + + interface Commands : CommandData { + class Agree : TypeOnlyCommandData(), Commands // Both sides agree to trade } } diff --git a/samples/simm-valuation-demo/src/main/kotlin/net/corda/vega/contracts/PortfolioState.kt b/samples/simm-valuation-demo/src/main/kotlin/net/corda/vega/contracts/PortfolioState.kt index 6fcbb6eba3..5fb811f26e 100644 --- a/samples/simm-valuation-demo/src/main/kotlin/net/corda/vega/contracts/PortfolioState.kt +++ b/samples/simm-valuation-demo/src/main/kotlin/net/corda/vega/contracts/PortfolioState.kt @@ -29,7 +29,7 @@ data class PortfolioState(val portfolio: List, data class Update(val portfolio: List? = null, val valuation: PortfolioValuation? = null) override val participants: List get() = _parties.toList() - override val ref: String = linearId.toString() + val ref: String get() = linearId.toString() val valuer: AbstractParty get() = participants[0] override fun nextScheduledActivity(thisStateRef: StateRef, flowLogicRefFactory: FlowLogicRefFactory): ScheduledActivity { @@ -42,7 +42,7 @@ data class PortfolioState(val portfolio: List, } override fun generateAgreement(notary: Party): TransactionBuilder { - return TransactionType.General.Builder(notary).withItems(copy(), Command(PortfolioSwap.Commands.Agree(), participants.map { it.owningKey })) + return TransactionBuilder(notary).withItems(copy(), Command(PortfolioSwap.Commands.Agree(), participants.map { it.owningKey })) } override fun generateRevision(notary: Party, oldState: StateAndRef<*>, updatedValue: Update): TransactionBuilder { @@ -50,7 +50,7 @@ data class PortfolioState(val portfolio: List, val portfolio = updatedValue.portfolio ?: portfolio val valuation = updatedValue.valuation ?: valuation - val tx = TransactionType.General.Builder(notary) + val tx = TransactionBuilder(notary) tx.addInputState(oldState) tx.addOutputState(copy(portfolio = portfolio, valuation = valuation)) tx.addCommand(PortfolioSwap.Commands.Update(), participants.map { it.owningKey }) diff --git a/samples/simm-valuation-demo/src/main/kotlin/net/corda/vega/contracts/PortfolioSwap.kt b/samples/simm-valuation-demo/src/main/kotlin/net/corda/vega/contracts/PortfolioSwap.kt index c946500748..d417111e7d 100644 --- a/samples/simm-valuation-demo/src/main/kotlin/net/corda/vega/contracts/PortfolioSwap.kt +++ b/samples/simm-valuation-demo/src/main/kotlin/net/corda/vega/contracts/PortfolioSwap.kt @@ -1,8 +1,8 @@ package net.corda.vega.contracts import net.corda.core.contracts.* -import net.corda.core.contracts.clauses.* import net.corda.core.crypto.SecureHash +import net.corda.core.transactions.LedgerTransaction /** * Specifies the contract between two parties that are agreeing to a portfolio of trades and valuating that portfolio. @@ -10,71 +10,34 @@ import net.corda.core.crypto.SecureHash * of the portfolio arbitrarily. */ data class PortfolioSwap(override val legalContractReference: SecureHash = SecureHash.sha256("swordfish")) : Contract { - override fun verify(tx: TransactionForContract) = verifyClause(tx, AllOf(Clauses.TimeWindowed(), Clauses.Group()), tx.commands.select()) + override fun verify(tx: LedgerTransaction) { + requireNotNull(tx.timeWindow) { "must have a time-window)" } + val groups: List> = tx.groupStates { state -> state.linearId } + for ((inputs, outputs, key) in groups) { + val agreeCommand = tx.commands.select().firstOrNull() + if (agreeCommand != null) { + requireThat { + "there are no inputs" using (inputs.isEmpty()) + "there is one output" using (outputs.size == 1) + "valuer must be a party" using (outputs[0].participants.contains(outputs[0].valuer)) + } + } else { + val updateCommand = tx.commands.select().firstOrNull() + if (updateCommand != null) { + requireThat { + "there is only one input" using (inputs.size == 1) + "there is only one output" using (outputs.size == 1) + "the valuer hasn't changed" using (inputs[0].valuer == outputs[0].valuer) + "the linear id hasn't changed" using (inputs[0].linearId == outputs[0].linearId) + } + + } + } + } + } interface Commands : CommandData { class Agree : TypeOnlyCommandData(), Commands // Both sides agree to portfolio class Update : TypeOnlyCommandData(), Commands // Both sides re-agree to portfolio } - - interface Clauses { - class TimeWindowed : Clause() { - override fun verify(tx: TransactionForContract, - inputs: List, - outputs: List, - commands: List>, - groupingKey: Unit?): Set { - require(tx.timeWindow?.midpoint != null) { "must have a time-window)" } - // We return an empty set because we don't process any commands - return emptySet() - } - } - - class Group : GroupClauseVerifier(FirstOf(Agree(), Update())) { - override fun groupStates(tx: TransactionForContract): List> - // Group by Trade ID for in / out states - = tx.groupStates { state -> state.linearId } - } - - class Update : Clause() { - override val requiredCommands: Set> = setOf(Commands.Update::class.java) - - override fun verify(tx: TransactionForContract, - inputs: List, - outputs: List, - commands: List>, - groupingKey: UniqueIdentifier?): Set { - val command = tx.commands.requireSingleCommand() - - requireThat { - "there is only one input" using (inputs.size == 1) - "there is only one output" using (outputs.size == 1) - "the valuer hasn't changed" using (inputs[0].valuer == outputs[0].valuer) - "the linear id hasn't changed" using (inputs[0].linearId == outputs[0].linearId) - } - - return setOf(command.value) - } - } - - class Agree : Clause() { - override val requiredCommands: Set> = setOf(Commands.Agree::class.java) - - override fun verify(tx: TransactionForContract, - inputs: List, - outputs: List, - commands: List>, - groupingKey: UniqueIdentifier?): Set { - val command = tx.commands.requireSingleCommand() - - requireThat { - "there are no inputs" using (inputs.size == 0) - "there is one output" using (outputs.size == 1) - "valuer must be a party" using (outputs[0].participants.contains(outputs[0].valuer)) - } - - return setOf(command.value) - } - } - } } diff --git a/samples/simm-valuation-demo/src/main/kotlin/net/corda/vega/flows/OpenGammaCordaUtils.kt b/samples/simm-valuation-demo/src/main/kotlin/net/corda/vega/flows/OpenGammaCordaUtils.kt index a30bc3f7fe..ca6271ed98 100644 --- a/samples/simm-valuation-demo/src/main/kotlin/net/corda/vega/flows/OpenGammaCordaUtils.kt +++ b/samples/simm-valuation-demo/src/main/kotlin/net/corda/vega/flows/OpenGammaCordaUtils.kt @@ -56,12 +56,12 @@ fun InitialMarginTriple.toCordaCompatible() = InitialMarginTriple(twoDecimalPlac * Utility function to ensure that [CurrencyParameterSensitivities] can be sent over corda and compared */ fun CurrencyParameterSensitivities.toCordaCompatible(): CurrencyParameterSensitivities { - return CurrencyParameterSensitivities.of(this.sensitivities.map { - it.metaBean().builder() - .set("marketDataName", it.marketDataName) - .set("parameterMetadata", it.parameterMetadata) - .set("currency", Currency.of(it.currency.code).serialize().deserialize()) - .set("sensitivity", it.sensitivity.map { it -> twoDecimalPlaces(it) }) + return CurrencyParameterSensitivities.of(this.sensitivities.map { sensitivity -> + sensitivity.metaBean().builder() + .set("marketDataName", sensitivity.marketDataName) + .set("parameterMetadata", sensitivity.parameterMetadata) + .set("currency", Currency.of(sensitivity.currency.code).serialize().deserialize()) + .set("sensitivity", sensitivity.sensitivity.map { twoDecimalPlaces(it) }) .build() }) } diff --git a/samples/simm-valuation-demo/src/main/kotlin/net/corda/vega/flows/SimmFlow.kt b/samples/simm-valuation-demo/src/main/kotlin/net/corda/vega/flows/SimmFlow.kt index 331f4b2b96..9fb4a1ab30 100644 --- a/samples/simm-valuation-demo/src/main/kotlin/net/corda/vega/flows/SimmFlow.kt +++ b/samples/simm-valuation-demo/src/main/kotlin/net/corda/vega/flows/SimmFlow.kt @@ -8,25 +8,22 @@ import com.opengamma.strata.pricer.curve.CalibrationMeasures import com.opengamma.strata.pricer.curve.CurveCalibrator import com.opengamma.strata.pricer.rate.ImmutableRatesProvider import com.opengamma.strata.pricer.swap.DiscountingSwapProductPricer -import net.corda.contracts.dealsWith import net.corda.core.contracts.StateAndRef import net.corda.core.contracts.StateRef -import net.corda.core.flows.FlowLogic -import net.corda.core.flows.InitiatedBy -import net.corda.core.flows.InitiatingFlow -import net.corda.core.flows.StartableByRPC +import net.corda.core.flows.* +import net.corda.core.flows.AbstractStateReplacementFlow.Proposal import net.corda.core.identity.Party +import net.corda.core.node.services.queryBy +import net.corda.core.node.services.vault.QueryCriteria.LinearStateQueryCriteria +import net.corda.core.node.services.vault.QueryCriteria.VaultQueryCriteria import net.corda.core.serialization.CordaSerializable import net.corda.core.transactions.SignedTransaction import net.corda.core.utilities.unwrap -import net.corda.flows.AbstractStateReplacementFlow.Proposal -import net.corda.flows.StateReplacementException import net.corda.flows.TwoPartyDealFlow import net.corda.vega.analytics.* import net.corda.vega.contracts.* import net.corda.vega.portfolio.Portfolio import net.corda.vega.portfolio.toPortfolio -import net.corda.vega.portfolio.toStateAndRef import java.time.LocalDate /** @@ -67,14 +64,17 @@ object SimmFlow { notary = serviceHub.networkMapCache.notaryNodes.first().notaryIdentity myIdentity = serviceHub.myInfo.legalIdentity - val trades = serviceHub.vaultService.dealsWith(otherParty) + val criteria = LinearStateQueryCriteria(participants = listOf(otherParty)) + val trades = serviceHub.vaultQueryService.queryBy(criteria).states + val portfolio = Portfolio(trades, valuationDate) if (existing == null) { agreePortfolio(portfolio) } else { updatePortfolio(portfolio, existing) } - val portfolioStateRef = serviceHub.vaultService.dealsWith(otherParty).first() + val portfolioStateRef = serviceHub.vaultQueryService.queryBy(criteria).states.first() + val state = updateValuation(portfolioStateRef) logger.info("SimmFlow done") return state @@ -104,7 +104,8 @@ object SimmFlow { private fun updateValuation(stateRef: StateAndRef): RevisionedState { logger.info("Agreeing valuations") val state = stateRef.state.data - val portfolio = state.portfolio.toStateAndRef(serviceHub).toPortfolio() + val portfolio = serviceHub.vaultQueryService.queryBy(VaultQueryCriteria(stateRefs = state.portfolio)).states.toPortfolio() + val valuer = serviceHub.identityService.partyFromAnonymous(state.valuer) require(valuer != null) { "Valuer party must be known to this node" } val valuation = agreeValuation(portfolio, valuationDate, valuer!!) @@ -190,7 +191,9 @@ object SimmFlow { @Suspendable override fun call() { ownParty = serviceHub.myInfo.legalIdentity - val trades = serviceHub.vaultService.dealsWith(replyToParty) + + val criteria = LinearStateQueryCriteria(participants = listOf(replyToParty)) + val trades = serviceHub.vaultQueryService.queryBy(criteria).states val portfolio = Portfolio(trades) logger.info("SimmFlow receiver started") offer = receive(replyToParty).unwrap { it } @@ -199,7 +202,7 @@ object SimmFlow { } else { updatePortfolio(portfolio) } - val portfolioStateRef = serviceHub.vaultService.dealsWith(replyToParty).first() + val portfolioStateRef = serviceHub.vaultQueryService.queryBy(criteria).states.first() updateValuation(portfolioStateRef) } @@ -299,8 +302,8 @@ object SimmFlow { logger.info("Handshake finished, awaiting Simm update") send(replyToParty, Ack) // Hack to state that this party is ready. subFlow(object : StateRevisionFlow.Receiver(replyToParty) { - override fun verifyProposal(proposal: Proposal) { - super.verifyProposal(proposal) + override fun verifyProposal(stx:SignedTransaction, proposal: Proposal) { + super.verifyProposal(stx, proposal) if (proposal.modification.portfolio != portfolio.refs) throw StateReplacementException() } }) @@ -308,12 +311,12 @@ object SimmFlow { @Suspendable private fun updateValuation(stateRef: StateAndRef) { - val portfolio = stateRef.state.data.portfolio.toStateAndRef(serviceHub).toPortfolio() + val portfolio = serviceHub.vaultQueryService.queryBy(VaultQueryCriteria(stateRefs = stateRef.state.data.portfolio)).states.toPortfolio() val valuer = serviceHub.identityService.partyFromAnonymous(stateRef.state.data.valuer) ?: throw IllegalStateException("Unknown valuer party ${stateRef.state.data.valuer}") val valuation = agreeValuation(portfolio, offer.valuationDate, valuer) subFlow(object : StateRevisionFlow.Receiver(replyToParty) { - override fun verifyProposal(proposal: Proposal) { - super.verifyProposal(proposal) + override fun verifyProposal(stx: SignedTransaction, proposal: Proposal) { + super.verifyProposal(stx, proposal) if (proposal.modification.valuation != valuation) throw StateReplacementException() } }) diff --git a/samples/simm-valuation-demo/src/main/kotlin/net/corda/vega/flows/SimmRevaluation.kt b/samples/simm-valuation-demo/src/main/kotlin/net/corda/vega/flows/SimmRevaluation.kt index da189a9204..ef2839db21 100644 --- a/samples/simm-valuation-demo/src/main/kotlin/net/corda/vega/flows/SimmRevaluation.kt +++ b/samples/simm-valuation-demo/src/main/kotlin/net/corda/vega/flows/SimmRevaluation.kt @@ -5,7 +5,8 @@ import net.corda.core.contracts.StateRef import net.corda.core.flows.FlowLogic import net.corda.core.flows.SchedulableFlow import net.corda.core.flows.StartableByRPC -import net.corda.core.node.services.linearHeadsOfType +import net.corda.core.node.services.queryBy +import net.corda.core.node.services.vault.QueryCriteria.VaultQueryCriteria import net.corda.vega.contracts.PortfolioState import java.time.LocalDate @@ -19,7 +20,7 @@ object SimmRevaluation { class Initiator(val curStateRef: StateRef, val valuationDate: LocalDate) : FlowLogic() { @Suspendable override fun call(): Unit { - val stateAndRef = serviceHub.vaultService.linearHeadsOfType().values.first { it.ref == curStateRef } + val stateAndRef = serviceHub.vaultQueryService.queryBy(VaultQueryCriteria(stateRefs = listOf(curStateRef))).states.single() val curState = stateAndRef.state.data val myIdentity = serviceHub.myInfo.legalIdentity if (myIdentity == curState.participants[0]) { diff --git a/samples/simm-valuation-demo/src/main/kotlin/net/corda/vega/flows/StateRevisionFlow.kt b/samples/simm-valuation-demo/src/main/kotlin/net/corda/vega/flows/StateRevisionFlow.kt index c9d56ae403..10b2d4ee58 100644 --- a/samples/simm-valuation-demo/src/main/kotlin/net/corda/vega/flows/StateRevisionFlow.kt +++ b/samples/simm-valuation-demo/src/main/kotlin/net/corda/vega/flows/StateRevisionFlow.kt @@ -1,18 +1,17 @@ package net.corda.vega.flows +import net.corda.core.contracts.PrivacySalt import net.corda.core.contracts.StateAndRef -import net.corda.core.identity.AbstractParty +import net.corda.core.flows.AbstractStateReplacementFlow +import net.corda.core.flows.StateReplacementException import net.corda.core.identity.Party -import net.corda.core.seconds import net.corda.core.transactions.SignedTransaction -import net.corda.flows.AbstractStateReplacementFlow -import net.corda.flows.StateReplacementException +import net.corda.core.utilities.seconds import net.corda.vega.contracts.RevisionedState -import java.security.PublicKey /** * Flow that generates an update on a mutable deal state and commits the resulting transaction reaching consensus - * on the update between two parties + * on the update between two parties. */ object StateRevisionFlow { class Requester(curStateRef: StateAndRef>, @@ -21,6 +20,8 @@ object StateRevisionFlow { val state = originalState.state.data val tx = state.generateRevision(originalState.state.notary, originalState, modification) tx.setTimeWindow(serviceHub.clock.instant(), 30.seconds) + val privacySalt = PrivacySalt() + tx.setPrivacySalt(privacySalt) val stx = serviceHub.signInitialTransaction(tx) val participantKeys = state.participants.map { it.owningKey } @@ -31,8 +32,8 @@ object StateRevisionFlow { } open class Receiver(otherParty: Party) : AbstractStateReplacementFlow.Acceptor(otherParty) { - override fun verifyProposal(proposal: AbstractStateReplacementFlow.Proposal) { - val proposedTx = proposal.stx.tx + override fun verifyProposal(stx: SignedTransaction, proposal: AbstractStateReplacementFlow.Proposal) { + val proposedTx = stx.tx val state = proposal.stateRef if (state !in proposedTx.inputs) { throw StateReplacementException("The proposed state $state is not in the proposed transaction inputs") diff --git a/samples/simm-valuation-demo/src/main/kotlin/net/corda/vega/plugin/SimmPlugin.kt b/samples/simm-valuation-demo/src/main/kotlin/net/corda/vega/plugin/SimmPlugin.kt new file mode 100644 index 0000000000..3eda5bef8e --- /dev/null +++ b/samples/simm-valuation-demo/src/main/kotlin/net/corda/vega/plugin/SimmPlugin.kt @@ -0,0 +1,17 @@ +package net.corda.vega.plugin + +import net.corda.core.node.CordaPluginRegistry +import net.corda.vega.api.PortfolioApi +import net.corda.webserver.services.WebServerPluginRegistry +import java.util.function.Function + +/** + * [SimmService] is the object that makes available the flows and services for the Simm agreement / evaluation flow. + * It is loaded via discovery - see [CordaPluginRegistry]. + * It is also the object that enables a human usable web service for demo purpose + * It is loaded via discovery see [WebServerPluginRegistry]. + */ +class SimmPlugin : WebServerPluginRegistry { + override val webApis = listOf(Function(::PortfolioApi)) + override val staticServeDirs: Map = mapOf("simmvaluationdemo" to javaClass.classLoader.getResource("simmvaluationweb").toExternalForm()) +} \ No newline at end of file diff --git a/samples/simm-valuation-demo/src/main/kotlin/net/corda/vega/plugin/SimmPluginRegistry.kt b/samples/simm-valuation-demo/src/main/kotlin/net/corda/vega/plugin/SimmPluginRegistry.kt new file mode 100644 index 0000000000..b965550b9a --- /dev/null +++ b/samples/simm-valuation-demo/src/main/kotlin/net/corda/vega/plugin/SimmPluginRegistry.kt @@ -0,0 +1,44 @@ +package net.corda.vega.plugin + +import com.google.common.collect.Ordering +import com.opengamma.strata.basics.currency.Currency +import com.opengamma.strata.basics.currency.CurrencyAmount +import com.opengamma.strata.basics.currency.MultiCurrencyAmount +import com.opengamma.strata.basics.date.Tenor +import com.opengamma.strata.collect.array.DoubleArray +import com.opengamma.strata.market.curve.CurveName +import com.opengamma.strata.market.param.CurrencyParameterSensitivities +import com.opengamma.strata.market.param.CurrencyParameterSensitivity +import com.opengamma.strata.market.param.TenorDateParameterMetadata +import net.corda.core.node.CordaPluginRegistry +import net.corda.core.serialization.SerializationCustomization +import net.corda.vega.analytics.CordaMarketData +import net.corda.vega.analytics.InitialMarginTriple +import net.corda.webserver.services.WebServerPluginRegistry + +/** + * [SimmService] is the object that makes available the flows and services for the Simm agreement / evaluation flow. + * It is loaded via discovery - see [CordaPluginRegistry]. + * It is also the object that enables a human usable web service for demo purpose + * It is loaded via discovery see [WebServerPluginRegistry]. + */ +class SimmPluginRegistry : CordaPluginRegistry() { + override fun customizeSerialization(custom: SerializationCustomization): Boolean { + custom.apply { + // OpenGamma classes. + addToWhitelist(MultiCurrencyAmount::class.java) + addToWhitelist(Ordering.natural>().javaClass) + addToWhitelist(CurrencyAmount::class.java) + addToWhitelist(Currency::class.java) + addToWhitelist(InitialMarginTriple::class.java) + addToWhitelist(CordaMarketData::class.java) + addToWhitelist(CurrencyParameterSensitivities::class.java) + addToWhitelist(CurrencyParameterSensitivity::class.java) + addToWhitelist(DoubleArray::class.java) + addToWhitelist(CurveName::class.java) + addToWhitelist(TenorDateParameterMetadata::class.java) + addToWhitelist(Tenor::class.java) + } + return true + } +} diff --git a/samples/simm-valuation-demo/src/main/kotlin/net/corda/vega/portfolio/Portfolio.kt b/samples/simm-valuation-demo/src/main/kotlin/net/corda/vega/portfolio/Portfolio.kt index b91186362e..f61995da92 100644 --- a/samples/simm-valuation-demo/src/main/kotlin/net/corda/vega/portfolio/Portfolio.kt +++ b/samples/simm-valuation-demo/src/main/kotlin/net/corda/vega/portfolio/Portfolio.kt @@ -1,13 +1,14 @@ package net.corda.vega.portfolio -import net.corda.client.rpc.notUsed import net.corda.core.contracts.* import net.corda.core.identity.Party +import net.corda.core.internal.sum import net.corda.core.messaging.CordaRPCOps -import net.corda.core.node.ServiceHub -import net.corda.core.sum +import net.corda.core.messaging.vaultQueryBy +import net.corda.core.node.services.vault.QueryCriteria import net.corda.vega.contracts.IRSState import net.corda.vega.contracts.SwapData +import java.math.BigDecimal import java.time.LocalDate /** @@ -22,7 +23,7 @@ data class Portfolio(private val tradeStateAndRefs: List>, val swaps: List by lazy { trades.map { it.swap } } val refs: List by lazy { tradeStateAndRefs.map { it.ref } } - fun getNotionalForParty(party: Party) = trades.map { it.swap.getLegForParty(party).notional }.sum() + fun getNotionalForParty(party: Party): BigDecimal = trades.map { it.swap.getLegForParty(party).notional }.sum() fun update(curTrades: List>): Portfolio { return copy(tradeStateAndRefs = curTrades) @@ -34,16 +35,5 @@ fun List>.toPortfolio(): Portfolio { } inline fun List.toStateAndRef(rpc: CordaRPCOps): List> { - val (vault, vaultUpdates) = rpc.vaultAndUpdates() - vaultUpdates.notUsed() - val stateRefs = vault.associateBy { it.ref } - return mapNotNull { stateRefs[it] }.filterStatesOfType() -} - -// TODO: This should probably have its generics fixed and moved into the core platform API. -@Suppress("UNCHECKED_CAST") -fun List.toStateAndRef(services: ServiceHub): List> { - return services.vaultService.statesForRefs(this).map { - StateAndRef(it.value as TransactionState, it.key) - } + return rpc.vaultQueryBy(QueryCriteria.VaultQueryCriteria(stateRefs = this)).states } diff --git a/samples/simm-valuation-demo/src/main/kotlin/net/corda/vega/services/SimmService.kt b/samples/simm-valuation-demo/src/main/kotlin/net/corda/vega/services/SimmService.kt deleted file mode 100644 index dd318fc617..0000000000 --- a/samples/simm-valuation-demo/src/main/kotlin/net/corda/vega/services/SimmService.kt +++ /dev/null @@ -1,50 +0,0 @@ -package net.corda.vega.services - -import com.google.common.collect.Ordering -import com.opengamma.strata.basics.currency.Currency -import com.opengamma.strata.basics.currency.CurrencyAmount -import com.opengamma.strata.basics.currency.MultiCurrencyAmount -import com.opengamma.strata.basics.date.Tenor -import com.opengamma.strata.collect.array.DoubleArray -import com.opengamma.strata.market.curve.CurveName -import com.opengamma.strata.market.param.CurrencyParameterSensitivities -import com.opengamma.strata.market.param.CurrencyParameterSensitivity -import com.opengamma.strata.market.param.TenorDateParameterMetadata -import net.corda.core.node.CordaPluginRegistry -import net.corda.core.serialization.SerializationCustomization -import net.corda.vega.analytics.CordaMarketData -import net.corda.vega.analytics.InitialMarginTriple -import net.corda.vega.api.PortfolioApi -import net.corda.webserver.services.WebServerPluginRegistry -import java.util.function.Function - -/** - * [SimmService] is the object that makes available the flows and services for the Simm agreement / evaluation flow. - * It is loaded via discovery - see [CordaPluginRegistry]. - * It is also the object that enables a human usable web service for demo purpose - * It is loaded via discovery see [WebServerPluginRegistry]. - */ -object SimmService { - class Plugin : CordaPluginRegistry(), WebServerPluginRegistry { - override val webApis = listOf(Function(::PortfolioApi)) - override val staticServeDirs: Map = mapOf("simmvaluationdemo" to javaClass.classLoader.getResource("simmvaluationweb").toExternalForm()) - override fun customizeSerialization(custom: SerializationCustomization): Boolean { - custom.apply { - // OpenGamma classes. - addToWhitelist(MultiCurrencyAmount::class.java) - addToWhitelist(Ordering.natural>().javaClass) - addToWhitelist(CurrencyAmount::class.java) - addToWhitelist(Currency::class.java) - addToWhitelist(InitialMarginTriple::class.java) - addToWhitelist(CordaMarketData::class.java) - addToWhitelist(CurrencyParameterSensitivities::class.java) - addToWhitelist(CurrencyParameterSensitivity::class.java) - addToWhitelist(DoubleArray::class.java) - addToWhitelist(CurveName::class.java) - addToWhitelist(TenorDateParameterMetadata::class.java) - addToWhitelist(Tenor::class.java) - } - return true - } - } -} diff --git a/samples/simm-valuation-demo/src/main/resources/META-INF/services/net.corda.core.node.CordaPluginRegistry b/samples/simm-valuation-demo/src/main/resources/META-INF/services/net.corda.core.node.CordaPluginRegistry index bdfd21fedd..e2faa7858d 100644 --- a/samples/simm-valuation-demo/src/main/resources/META-INF/services/net.corda.core.node.CordaPluginRegistry +++ b/samples/simm-valuation-demo/src/main/resources/META-INF/services/net.corda.core.node.CordaPluginRegistry @@ -1,2 +1,2 @@ # Register a ServiceLoader service extending from net.corda.core.node.CordaPluginRegistry -net.corda.vega.services.SimmService$Plugin +net.corda.vega.plugin.SimmPluginRegistry diff --git a/samples/simm-valuation-demo/src/main/resources/META-INF/services/net.corda.webserver.services.WebServerPluginRegistry b/samples/simm-valuation-demo/src/main/resources/META-INF/services/net.corda.webserver.services.WebServerPluginRegistry index 7fabecaa0c..95a1afd507 100644 --- a/samples/simm-valuation-demo/src/main/resources/META-INF/services/net.corda.webserver.services.WebServerPluginRegistry +++ b/samples/simm-valuation-demo/src/main/resources/META-INF/services/net.corda.webserver.services.WebServerPluginRegistry @@ -1,2 +1,2 @@ # Register a ServiceLoader service extending from net.corda.webserver.services.WebServerPluginRegistry -net.corda.vega.services.SimmService$Plugin \ No newline at end of file +net.corda.vega.plugin.SimmPlugin diff --git a/samples/simm-valuation-demo/src/test/kotlin/net/corda/vega/Main.kt b/samples/simm-valuation-demo/src/test/kotlin/net/corda/vega/Main.kt index b8a52397ee..2ad314e6c6 100644 --- a/samples/simm-valuation-demo/src/test/kotlin/net/corda/vega/Main.kt +++ b/samples/simm-valuation-demo/src/test/kotlin/net/corda/vega/Main.kt @@ -1,8 +1,8 @@ package net.corda.vega -import com.google.common.util.concurrent.Futures -import net.corda.core.getOrThrow +import net.corda.core.internal.concurrent.transpose import net.corda.core.node.services.ServiceInfo +import net.corda.core.utilities.getOrThrow import net.corda.testing.DUMMY_BANK_A import net.corda.testing.DUMMY_BANK_B import net.corda.testing.DUMMY_BANK_C @@ -18,11 +18,11 @@ import net.corda.testing.driver.driver fun main(args: Array) { driver(dsl = { startNode(DUMMY_NOTARY.name, setOf(ServiceInfo(SimpleNotaryService.type))) - val (nodeA, nodeB, nodeC) = Futures.allAsList( + val (nodeA, nodeB, nodeC) = listOf( startNode(DUMMY_BANK_A.name), startNode(DUMMY_BANK_B.name), startNode(DUMMY_BANK_C.name) - ).getOrThrow() + ).transpose().getOrThrow() startWebserver(nodeA) startWebserver(nodeB) diff --git a/samples/trader-demo/build.gradle b/samples/trader-demo/build.gradle index 94dda10602..5cec069227 100644 --- a/samples/trader-demo/build.gradle +++ b/samples/trader-demo/build.gradle @@ -25,10 +25,10 @@ dependencies { compile "org.jetbrains.kotlin:kotlin-stdlib-jre8:$kotlin_version" // Corda integration dependencies - compile project(path: ":node:capsule", configuration: 'runtimeArtifacts') - compile project(path: ":webserver:webcapsule", configuration: 'runtimeArtifacts') - compile project(':core') - compile project(':finance') + cordaCompile project(path: ":node:capsule", configuration: 'runtimeArtifacts') + cordaCompile project(path: ":webserver:webcapsule", configuration: 'runtimeArtifacts') + cordaCompile project(':core') + cordaCompile project(':finance') // Corda Plugins: dependent flows and services cordapp project(':samples:bank-of-corda-demo') @@ -40,16 +40,17 @@ dependencies { task deployNodes(type: net.corda.plugins.Cordform, dependsOn: ['jar']) { ext.rpcUsers = [['username': "demo", 'password': "demo", 'permissions': [ - 'StartFlow.net.corda.flows.IssuerFlow$IssuanceRequester', - "StartFlow.net.corda.traderdemo.flow.SellerFlow" + 'StartFlow.net.corda.flows.CashIssueFlow', + 'StartFlow.net.corda.traderdemo.flow.CommercialPaperIssueFlow', + 'StartFlow.net.corda.traderdemo.flow.SellerFlow' ]]] directory "./build/nodes" // This name "Notary" is hard-coded into TraderDemoClientApi so if you change it here, change it there too. // In this demo the node that runs a standalone notary also acts as the network map server. - networkMap "CN=Notary Service,O=R3,OU=corda,L=London,C=GB" + networkMap "CN=Notary Service,O=R3,OU=corda,L=Zurich,C=CH" node { - name "CN=Notary Service,O=R3,OU=corda,L=London,C=GB" + name "CN=Notary Service,O=R3,OU=corda,L=Zurich,C=CH" advertisedServices = ["corda.notary.validating"] p2pPort 10002 cordapps = [] @@ -74,7 +75,9 @@ task deployNodes(type: net.corda.plugins.Cordform, dependsOn: ['jar']) { name "CN=BankOfCorda,O=R3,L=New York,C=US" advertisedServices = [] p2pPort 10011 + rpcPort 10012 cordapps = [] + rpcUsers = ext.rpcUsers } } @@ -102,11 +105,11 @@ publishing { } } -task runBuyer(type: JavaExec) { +task runBank(type: JavaExec) { classpath = sourceSets.main.runtimeClasspath main = 'net.corda.traderdemo.TraderDemoKt' args '--role' - args 'BUYER' + args 'BANK' } task runSeller(type: JavaExec) { diff --git a/samples/trader-demo/src/integration-test/kotlin/net/corda/traderdemo/TraderDemoTest.kt b/samples/trader-demo/src/integration-test/kotlin/net/corda/traderdemo/TraderDemoTest.kt index a0198451a9..4955c54207 100644 --- a/samples/trader-demo/src/integration-test/kotlin/net/corda/traderdemo/TraderDemoTest.kt +++ b/samples/trader-demo/src/integration-test/kotlin/net/corda/traderdemo/TraderDemoTest.kt @@ -1,22 +1,24 @@ package net.corda.traderdemo -import com.google.common.util.concurrent.Futures import net.corda.client.rpc.CordaRPCClient import net.corda.core.contracts.DOLLARS -import net.corda.core.getOrThrow -import net.corda.core.millis +import net.corda.core.utilities.millis import net.corda.core.node.services.ServiceInfo +import net.corda.core.internal.concurrent.transpose +import net.corda.core.utilities.getOrThrow +import net.corda.flows.CashIssueFlow import net.corda.testing.DUMMY_BANK_A import net.corda.testing.DUMMY_BANK_B import net.corda.testing.DUMMY_NOTARY import net.corda.flows.IssuerFlow -import net.corda.testing.driver.poll import net.corda.node.services.startFlowPermission import net.corda.node.services.transactions.SimpleNotaryService import net.corda.nodeapi.User import net.corda.testing.BOC +import net.corda.testing.driver.poll import net.corda.testing.node.NodeBasedTest import net.corda.traderdemo.flow.BuyerFlow +import net.corda.traderdemo.flow.CommercialPaperIssueFlow import net.corda.traderdemo.flow.SellerFlow import org.assertj.core.api.Assertions.assertThat import org.junit.Test @@ -25,35 +27,38 @@ import java.util.concurrent.Executors class TraderDemoTest : NodeBasedTest() { @Test fun `runs trader demo`() { - val permissions = setOf( - startFlowPermission(), - startFlowPermission()) - val demoUser = listOf(User("demo", "demo", permissions)) - val user = User("user1", "test", permissions = setOf(startFlowPermission())) - val (nodeA, nodeB) = Futures.allAsList( - startNode(DUMMY_BANK_A.name, rpcUsers = demoUser), - startNode(DUMMY_BANK_B.name, rpcUsers = demoUser), - startNode(BOC.name, rpcUsers = listOf(user)), + val demoUser = User("demo", "demo", setOf(startFlowPermission())) + val bankUser = User("user1", "test", permissions = setOf(startFlowPermission(), + startFlowPermission())) + val (nodeA, nodeB, bankNode, notaryNode) = listOf( + startNode(DUMMY_BANK_A.name, rpcUsers = listOf(demoUser)), + startNode(DUMMY_BANK_B.name, rpcUsers = listOf(demoUser)), + startNode(BOC.name, rpcUsers = listOf(bankUser)), startNode(DUMMY_NOTARY.name, advertisedServices = setOf(ServiceInfo(SimpleNotaryService.type))) - ).getOrThrow() + ).transpose().getOrThrow() nodeA.registerInitiatedFlow(BuyerFlow::class.java) val (nodeARpc, nodeBRpc) = listOf(nodeA, nodeB).map { - val client = CordaRPCClient(it.configuration.rpcAddress!!) - client.start(demoUser[0].username, demoUser[0].password).proxy + val client = CordaRPCClient(it.configuration.rpcAddress!!, initialiseSerialization = false) + client.start(demoUser.username, demoUser.password).proxy + } + val nodeBankRpc = let { + val client = CordaRPCClient(bankNode.configuration.rpcAddress!!, initialiseSerialization = false) + client.start(bankUser.username, bankUser.password).proxy } val clientA = TraderDemoClientApi(nodeARpc) val clientB = TraderDemoClientApi(nodeBRpc) + val clientBank = TraderDemoClientApi(nodeBankRpc) val originalACash = clientA.cashCount // A has random number of issued amount val expectedBCash = clientB.cashCount + 1 val expectedPaper = listOf(clientA.commercialPaperCount + 1, clientB.commercialPaperCount) // TODO: Enable anonymisation - clientA.runBuyer(amount = 100.DOLLARS, anonymous = false) - clientB.runSeller(counterparty = nodeA.info.legalIdentity.name, amount = 5.DOLLARS) + clientBank.runIssuer(amount = 100.DOLLARS, buyerName = nodeA.info.legalIdentity.name, sellerName = nodeB.info.legalIdentity.name, notaryName = notaryNode.info.legalIdentity.name) + clientB.runSeller(buyerName = nodeA.info.legalIdentity.name, amount = 5.DOLLARS) assertThat(clientA.cashCount).isGreaterThan(originalACash) assertThat(clientB.cashCount).isEqualTo(expectedBCash) diff --git a/samples/trader-demo/src/main/kotlin/net/corda/traderdemo/TraderDemo.kt b/samples/trader-demo/src/main/kotlin/net/corda/traderdemo/TraderDemo.kt index 7fba92d602..f7accdc4ae 100644 --- a/samples/trader-demo/src/main/kotlin/net/corda/traderdemo/TraderDemo.kt +++ b/samples/trader-demo/src/main/kotlin/net/corda/traderdemo/TraderDemo.kt @@ -4,8 +4,10 @@ import joptsimple.OptionParser import net.corda.client.rpc.CordaRPCClient import net.corda.core.contracts.DOLLARS import net.corda.core.utilities.NetworkHostAndPort -import net.corda.testing.DUMMY_BANK_A import net.corda.core.utilities.loggerFor +import net.corda.testing.DUMMY_BANK_A +import net.corda.testing.DUMMY_BANK_B +import net.corda.testing.DUMMY_NOTARY import org.slf4j.Logger import kotlin.system.exitProcess @@ -18,12 +20,18 @@ fun main(args: Array) { private class TraderDemo { enum class Role { - BUYER, + BANK, SELLER } companion object { val logger: Logger = loggerFor() + val buyerName = DUMMY_BANK_A.name + val sellerName = DUMMY_BANK_B.name + val notaryName = DUMMY_NOTARY.name + val buyerRpcPort = 10006 + val sellerRpcPort = 10009 + val bankRpcPort = 10012 } fun main(args: Array) { @@ -41,15 +49,15 @@ private class TraderDemo { // What happens next depends on the role. The buyer sits around waiting for a trade to start. The seller role // will contact the buyer and actually make something happen. val role = options.valueOf(roleArg)!! - if (role == Role.BUYER) { - val host = NetworkHostAndPort("localhost", 10006) - CordaRPCClient(host).start("demo", "demo").use { - TraderDemoClientApi(it.proxy).runBuyer() + if (role == Role.BANK) { + val bankHost = NetworkHostAndPort("localhost", bankRpcPort) + CordaRPCClient(bankHost).use("demo", "demo") { + TraderDemoClientApi(it.proxy).runIssuer(1100.DOLLARS, buyerName, sellerName, notaryName) } } else { - val host = NetworkHostAndPort("localhost", 10009) - CordaRPCClient(host).use("demo", "demo") { - TraderDemoClientApi(it.proxy).runSeller(1000.DOLLARS, DUMMY_BANK_A.name) + val sellerHost = NetworkHostAndPort("localhost", sellerRpcPort) + CordaRPCClient(sellerHost).use("demo", "demo") { + TraderDemoClientApi(it.proxy).runSeller(1000.DOLLARS, buyerName) } } } diff --git a/samples/trader-demo/src/main/kotlin/net/corda/traderdemo/TraderDemoClientApi.kt b/samples/trader-demo/src/main/kotlin/net/corda/traderdemo/TraderDemoClientApi.kt index 7f2ddbee4e..e8e5d0d924 100644 --- a/samples/trader-demo/src/main/kotlin/net/corda/traderdemo/TraderDemoClientApi.kt +++ b/samples/trader-demo/src/main/kotlin/net/corda/traderdemo/TraderDemoClientApi.kt @@ -1,23 +1,26 @@ package net.corda.traderdemo -import com.google.common.util.concurrent.Futures -import net.corda.client.rpc.notUsed import net.corda.contracts.CommercialPaper import net.corda.contracts.asset.Cash -import net.corda.testing.contracts.calculateRandomlySizedAmounts +import net.corda.contracts.getCashBalance import net.corda.core.contracts.Amount import net.corda.core.contracts.DOLLARS import net.corda.core.contracts.USD -import net.corda.core.contracts.filterStatesOfType -import net.corda.core.getOrThrow +import net.corda.core.internal.Emoji +import net.corda.core.internal.concurrent.transpose import net.corda.core.messaging.CordaRPCOps import net.corda.core.messaging.startFlow +import net.corda.core.messaging.vaultQueryBy +import net.corda.core.node.services.vault.QueryCriteria +import net.corda.core.node.services.vault.builder import net.corda.core.utilities.OpaqueBytes -import net.corda.core.utilities.Emoji +import net.corda.core.utilities.getOrThrow import net.corda.core.utilities.loggerFor -import net.corda.flows.IssuerFlow.IssuanceRequester -import net.corda.testing.BOC +import net.corda.flows.CashIssueFlow +import net.corda.node.services.vault.VaultSchemaV1 import net.corda.testing.DUMMY_NOTARY +import net.corda.testing.contracts.calculateRandomlySizedAmounts +import net.corda.traderdemo.flow.CommercialPaperIssueFlow import net.corda.traderdemo.flow.SellerFlow import org.bouncycastle.asn1.x500.X500Name import java.util.* @@ -30,39 +33,56 @@ class TraderDemoClientApi(val rpc: CordaRPCOps) { val logger = loggerFor() } - val cashCount: Int get() { - val (vault, vaultUpdates) = rpc.vaultAndUpdates() - vaultUpdates.notUsed() - return vault.filterStatesOfType().size + val cashCount: Long get() { + val count = builder { VaultSchemaV1.VaultStates::recordedTime.count() } + val countCriteria = QueryCriteria.VaultCustomQueryCriteria(count) + return rpc.vaultQueryBy(countCriteria).otherResults.single() as Long } - val dollarCashBalance: Amount get() = rpc.getCashBalances()[USD]!! + val dollarCashBalance: Amount get() = rpc.getCashBalance(USD) - val commercialPaperCount: Int get() { - val (vault, vaultUpdates) = rpc.vaultAndUpdates() - vaultUpdates.notUsed() - return vault.filterStatesOfType().size + val commercialPaperCount: Long get() { + val count = builder { VaultSchemaV1.VaultStates::recordedTime.count() } + val countCriteria = QueryCriteria.VaultCustomQueryCriteria(count) + return rpc.vaultQueryBy(countCriteria).otherResults.single() as Long } - fun runBuyer(amount: Amount = 30000.DOLLARS, anonymous: Boolean = true) { - val bankOfCordaParty = rpc.partyFromX500Name(BOC.name) - ?: throw IllegalStateException("Unable to locate ${BOC.name} in Network Map Service") + fun runIssuer(amount: Amount = 1100.0.DOLLARS, buyerName: X500Name, sellerName: X500Name, notaryName: X500Name) { + val ref = OpaqueBytes.of(1) + val buyer = rpc.partyFromX500Name(buyerName) ?: throw IllegalStateException("Don't know $buyerName") + val seller = rpc.partyFromX500Name(sellerName) ?: throw IllegalStateException("Don't know $sellerName") val notaryLegalIdentity = rpc.partyFromX500Name(DUMMY_NOTARY.name) ?: throw IllegalStateException("Unable to locate ${DUMMY_NOTARY.name} in Network Map Service") val notaryNode = rpc.nodeIdentityFromParty(notaryLegalIdentity) ?: throw IllegalStateException("Unable to locate notary node in network map cache") - val me = rpc.nodeIdentity() val amounts = calculateRandomlySizedAmounts(amount, 3, 10, Random()) - // issuer random amounts of currency totaling 30000.DOLLARS in parallel + val anonymous = false + // issue random amounts of currency up to the requested amount, in parallel val resultFutures = amounts.map { pennies -> - rpc.startFlow(::IssuanceRequester, Amount(pennies, amount.token), me.legalIdentity, OpaqueBytes.of(1), bankOfCordaParty, notaryNode.notaryIdentity, anonymous).returnValue + rpc.startFlow(::CashIssueFlow, amount.copy(quantity = pennies), OpaqueBytes.of(1), buyer, notaryNode.notaryIdentity, anonymous).returnValue } - Futures.allAsList(resultFutures).getOrThrow() + resultFutures.transpose().getOrThrow() + println("Cash issued to buyer") + + // The CP sale transaction comes with a prospectus PDF, which will tag along for the ride in an + // attachment. Make sure we have the transaction prospectus attachment loaded into our store. + // + // This can also be done via an HTTP upload, but here we short-circuit and do it from code. + if (!rpc.attachmentExists(SellerFlow.PROSPECTUS_HASH)) { + javaClass.classLoader.getResourceAsStream("bank-of-london-cp.jar").use { + val id = rpc.uploadAttachment(it) + check(SellerFlow.PROSPECTUS_HASH == id) + } + } + + // The line below blocks and waits for the future to resolve. + val stx = rpc.startFlow(::CommercialPaperIssueFlow, amount, ref, seller, notaryNode.notaryIdentity).returnValue.getOrThrow() + println("Commercial paper issued to seller") } - fun runSeller(amount: Amount = 1000.0.DOLLARS, counterparty: X500Name) { - val otherParty = rpc.partyFromX500Name(counterparty) ?: throw IllegalStateException("Don't know $counterparty") + fun runSeller(amount: Amount = 1000.0.DOLLARS, buyerName: X500Name) { + val otherParty = rpc.partyFromX500Name(buyerName) ?: throw IllegalStateException("Don't know $buyerName") // The seller will sell some commercial paper to the buyer, who will pay with (self issued) cash. // // The CP sale transaction comes with a prospectus PDF, which will tag along for the ride in an diff --git a/samples/trader-demo/src/main/kotlin/net/corda/traderdemo/flow/BuyerFlow.kt b/samples/trader-demo/src/main/kotlin/net/corda/traderdemo/flow/BuyerFlow.kt index 10f2e030f0..4b924af833 100644 --- a/samples/trader-demo/src/main/kotlin/net/corda/traderdemo/flow/BuyerFlow.kt +++ b/samples/trader-demo/src/main/kotlin/net/corda/traderdemo/flow/BuyerFlow.kt @@ -2,14 +2,15 @@ package net.corda.traderdemo.flow import co.paralleluniverse.fibers.Suspendable import net.corda.contracts.CommercialPaper +import net.corda.contracts.getCashBalances import net.corda.core.contracts.Amount import net.corda.core.contracts.TransactionGraphSearch import net.corda.core.flows.FlowLogic import net.corda.core.flows.InitiatedBy import net.corda.core.identity.Party +import net.corda.core.internal.Emoji import net.corda.core.node.NodeInfo import net.corda.core.transactions.SignedTransaction -import net.corda.core.utilities.Emoji import net.corda.core.utilities.ProgressTracker import net.corda.core.utilities.unwrap import net.corda.flows.TwoPartyTradeFlow @@ -37,8 +38,6 @@ class BuyerFlow(val otherParty: Party) : FlowLogic() { // This invokes the trading flow and out pops our finished transaction. val tradeTX: SignedTransaction = subFlow(buyer) - // TODO: This should be moved into the flow itself. - serviceHub.recordTransactions(tradeTX) println("Purchase complete - we are a happy customer! Final transaction is: " + "\n\n${Emoji.renderIfSupported(tradeTX.tx)}") @@ -48,7 +47,7 @@ class BuyerFlow(val otherParty: Party) : FlowLogic() { } private fun logBalance() { - val balances = serviceHub.vaultService.cashBalances.entries.map { "${it.key.currencyCode} ${it.value}" } + val balances = serviceHub.getCashBalances().entries.map { "${it.key.currencyCode} ${it.value}" } println("Remaining balance: ${balances.joinToString()}") } diff --git a/samples/trader-demo/src/main/kotlin/net/corda/traderdemo/flow/CommercialPaperIssueFlow.kt b/samples/trader-demo/src/main/kotlin/net/corda/traderdemo/flow/CommercialPaperIssueFlow.kt new file mode 100644 index 0000000000..56a4b4696d --- /dev/null +++ b/samples/trader-demo/src/main/kotlin/net/corda/traderdemo/flow/CommercialPaperIssueFlow.kt @@ -0,0 +1,76 @@ +package net.corda.traderdemo.flow + +import co.paralleluniverse.fibers.Suspendable +import net.corda.contracts.CommercialPaper +import net.corda.contracts.asset.DUMMY_CASH_ISSUER +import net.corda.core.contracts.Amount +import net.corda.core.contracts.`issued by` +import net.corda.core.crypto.SecureHash +import net.corda.core.flows.FinalityFlow +import net.corda.core.flows.FlowLogic +import net.corda.core.flows.InitiatingFlow +import net.corda.core.flows.StartableByRPC +import net.corda.core.identity.Party +import net.corda.core.node.NodeInfo +import net.corda.core.transactions.SignedTransaction +import net.corda.core.transactions.TransactionBuilder +import net.corda.core.utilities.OpaqueBytes +import net.corda.core.utilities.ProgressTracker +import net.corda.core.utilities.days +import net.corda.core.utilities.seconds +import java.time.Instant +import java.util.* + +/** + * Flow for the Bank of Corda node to issue some commercial paper to the seller's node, to sell to the buyer. + */ +@InitiatingFlow +@StartableByRPC +class CommercialPaperIssueFlow(val amount: Amount, + val issueRef: OpaqueBytes, + val recipient: Party, + val notary: Party, + override val progressTracker: ProgressTracker) : FlowLogic() { + constructor(amount: Amount, issueRef: OpaqueBytes, recipient: Party, notary: Party) : this(amount, issueRef, recipient, notary, tracker()) + + companion object { + val PROSPECTUS_HASH = SecureHash.parse("decd098666b9657314870e192ced0c3519c2c9d395507a238338f8d003929de9") + object ISSUING : ProgressTracker.Step("Issuing and timestamping some commercial paper") + fun tracker() = ProgressTracker(ISSUING) + } + + @Suspendable + override fun call(): SignedTransaction { + progressTracker.currentStep = ISSUING + + val me = serviceHub.myInfo.legalIdentity + val issuance: SignedTransaction = run { + val tx = CommercialPaper().generateIssue(me.ref(issueRef), amount `issued by` me.ref(issueRef), + Instant.now() + 10.days, notary) + + // TODO: Consider moving these two steps below into generateIssue. + + // Attach the prospectus. + tx.addAttachment(serviceHub.attachments.openAttachment(PROSPECTUS_HASH)!!.id) + + // Requesting a time-window to be set, all CP must have a validation window. + tx.setTimeWindow(Instant.now(), 30.seconds) + + // Sign it as ourselves. + val stx = serviceHub.signInitialTransaction(tx) + + subFlow(FinalityFlow(stx)).single() + } + + // Now make a dummy transaction that moves it to a new key, just to show that resolving dependencies works. + val move: SignedTransaction = run { + val builder = TransactionBuilder(notary) + CommercialPaper().generateMove(builder, issuance.tx.outRef(0), recipient) + val stx = serviceHub.signInitialTransaction(builder) + subFlow(FinalityFlow(stx)).single() + } + + return move + } + +} diff --git a/samples/trader-demo/src/main/kotlin/net/corda/traderdemo/flow/SellerFlow.kt b/samples/trader-demo/src/main/kotlin/net/corda/traderdemo/flow/SellerFlow.kt index 720f33ffe5..c51edf30ca 100644 --- a/samples/trader-demo/src/main/kotlin/net/corda/traderdemo/flow/SellerFlow.kt +++ b/samples/trader-demo/src/main/kotlin/net/corda/traderdemo/flow/SellerFlow.kt @@ -2,26 +2,17 @@ package net.corda.traderdemo.flow import co.paralleluniverse.fibers.Suspendable import net.corda.contracts.CommercialPaper -import net.corda.contracts.asset.DUMMY_CASH_ISSUER -import net.corda.core.contracts.* +import net.corda.core.contracts.Amount import net.corda.core.crypto.SecureHash -import net.corda.core.crypto.generateKeyPair -import net.corda.core.days import net.corda.core.flows.FlowLogic import net.corda.core.flows.InitiatingFlow import net.corda.core.flows.StartableByRPC -import net.corda.core.identity.AbstractParty import net.corda.core.identity.AnonymousParty import net.corda.core.identity.Party import net.corda.core.node.NodeInfo -import net.corda.core.seconds import net.corda.core.transactions.SignedTransaction import net.corda.core.utilities.ProgressTracker -import net.corda.flows.FinalityFlow -import net.corda.flows.NotaryFlow import net.corda.flows.TwoPartyTradeFlow -import net.corda.testing.BOC -import java.time.Instant import java.util.* @InitiatingFlow @@ -52,7 +43,8 @@ class SellerFlow(val otherParty: Party, val notary: NodeInfo = serviceHub.networkMapCache.notaryNodes[0] val cpOwnerKey = serviceHub.keyManagementService.freshKey() - val commercialPaper = selfIssueSomeCommercialPaper(serviceHub.myInfo.legalIdentity, notary) + val commercialPaper = serviceHub.vaultQueryService.queryBy(CommercialPaper.State::class.java).states.first() + progressTracker.currentStep = TRADING @@ -67,39 +59,4 @@ class SellerFlow(val otherParty: Party, progressTracker.getChildProgressTracker(TRADING)!!) return subFlow(seller) } - - @Suspendable - fun selfIssueSomeCommercialPaper(ownedBy: AbstractParty, notaryNode: NodeInfo): StateAndRef { - // Make a fake company that's issued its own paper. - val party = Party(BOC.name, serviceHub.legalIdentityKey) - - val issuance: SignedTransaction = run { - val tx = CommercialPaper().generateIssue(party.ref(1, 2, 3), 1100.DOLLARS `issued by` DUMMY_CASH_ISSUER, - Instant.now() + 10.days, notaryNode.notaryIdentity) - - // TODO: Consider moving these two steps below into generateIssue. - - // Attach the prospectus. - tx.addAttachment(serviceHub.attachments.openAttachment(PROSPECTUS_HASH)!!.id) - - // Requesting a time-window to be set, all CP must have a validation window. - tx.setTimeWindow(Instant.now(), 30.seconds) - - // Sign it as ourselves. - val stx = serviceHub.signInitialTransaction(tx) - - subFlow(FinalityFlow(stx)).single() - } - - // Now make a dummy transaction that moves it to a new key, just to show that resolving dependencies works. - val move: SignedTransaction = run { - val builder = TransactionType.General.Builder(notaryNode.notaryIdentity) - CommercialPaper().generateMove(builder, issuance.tx.outRef(0), ownedBy) - val stx = serviceHub.signInitialTransaction(builder) - subFlow(FinalityFlow(stx)).single() - } - - return move.tx.outRef(0) - } - } diff --git a/samples/trader-demo/src/test/kotlin/net/corda/traderdemo/Main.kt b/samples/trader-demo/src/test/kotlin/net/corda/traderdemo/Main.kt index 544055d907..d376d21ca8 100644 --- a/samples/trader-demo/src/test/kotlin/net/corda/traderdemo/Main.kt +++ b/samples/trader-demo/src/test/kotlin/net/corda/traderdemo/Main.kt @@ -1,6 +1,6 @@ package net.corda.traderdemo -import net.corda.core.div +import net.corda.core.internal.div import net.corda.core.node.services.ServiceInfo import net.corda.testing.DUMMY_BANK_A import net.corda.testing.DUMMY_BANK_B diff --git a/settings.gradle b/settings.gradle index a2405379d1..9fe7de00d0 100644 --- a/settings.gradle +++ b/settings.gradle @@ -27,6 +27,7 @@ include 'tools:explorer' include 'tools:explorer:capsule' include 'tools:demobench' include 'tools:loadtest' +include 'tools:graphs' include 'docs/source/example-code' // Note that we are deliberately choosing to use '/' here. With ':' gradle would treat the directories as actual projects. include 'samples:attachment-demo' include 'samples:trader-demo' diff --git a/smoke-test-utils/src/main/kotlin/net/corda/smoketesting/NodeProcess.kt b/smoke-test-utils/src/main/kotlin/net/corda/smoketesting/NodeProcess.kt index f7a56f2be9..6ab3de90d7 100644 --- a/smoke-test-utils/src/main/kotlin/net/corda/smoketesting/NodeProcess.kt +++ b/smoke-test-utils/src/main/kotlin/net/corda/smoketesting/NodeProcess.kt @@ -2,8 +2,8 @@ package net.corda.smoketesting import net.corda.client.rpc.CordaRPCClient import net.corda.client.rpc.CordaRPCConnection -import net.corda.core.createDirectories -import net.corda.core.div +import net.corda.core.internal.createDirectories +import net.corda.core.internal.div import net.corda.core.utilities.NetworkHostAndPort import net.corda.core.utilities.loggerFor import java.nio.file.Path diff --git a/test-common/build.gradle b/test-common/build.gradle index 3a955f1691..05c6d36b20 100644 --- a/test-common/build.gradle +++ b/test-common/build.gradle @@ -1,9 +1,10 @@ apply plugin: 'net.corda.plugins.publish-utils' +apply plugin: 'com.jfrog.artifactory' jar { baseName 'corda-test-common' } publish { - name = jar.baseName + name jar.baseName } diff --git a/test-common/src/main/resources/log4j2-test.xml b/test-common/src/main/resources/log4j2-test.xml index 739a4d6a17..222a4a1778 100644 --- a/test-common/src/main/resources/log4j2-test.xml +++ b/test-common/src/main/resources/log4j2-test.xml @@ -5,7 +5,7 @@ - + @@ -22,5 +22,8 @@ + 
 + 
 + diff --git a/test-utils/build.gradle b/test-utils/build.gradle index f6801db7ef..c1ac72f456 100644 --- a/test-utils/build.gradle +++ b/test-utils/build.gradle @@ -30,10 +30,8 @@ sourceSets { dependencies { compile project(':test-common') - compile project(':finance') compile project(':core') compile project(':node') - compile project(':webserver') compile project(':client:mock') compile "org.jetbrains.kotlin:kotlin-stdlib-jre8:$kotlin_version" @@ -67,5 +65,5 @@ jar { } publish { - name = jar.baseName + name jar.baseName } diff --git a/test-utils/src/integration-test/kotlin/net/corda/testing/FlowStackSnapshotTest.kt b/test-utils/src/integration-test/kotlin/net/corda/testing/FlowStackSnapshotTest.kt new file mode 100644 index 0000000000..613ee0b0ca --- /dev/null +++ b/test-utils/src/integration-test/kotlin/net/corda/testing/FlowStackSnapshotTest.kt @@ -0,0 +1,297 @@ +package net.corda.testing + +import co.paralleluniverse.fibers.Suspendable +import com.fasterxml.jackson.databind.ObjectMapper +import net.corda.core.flows.FlowLogic +import net.corda.core.flows.FlowStackSnapshot +import net.corda.core.flows.StartableByRPC +import net.corda.core.messaging.startFlow +import net.corda.core.serialization.CordaSerializable +import net.corda.node.services.startFlowPermission +import net.corda.nodeapi.User +import net.corda.testing.driver.driver +import org.junit.Ignore +import org.junit.Test +import java.io.File +import java.nio.file.Path +import java.time.LocalDateTime +import java.time.format.DateTimeFormatter +import kotlin.test.assertEquals +import kotlin.test.assertTrue + +@CordaSerializable +data class StackSnapshotFrame(val method: String, val clazz: String, val dataTypes: List, val flowId: String? = null) + +/** + * Calculates the count of full and empty frames. We consider frame to be empty if there is no stack data + * associated with it (i.e. the stackObjects is an empty list). Otherwise (i.e. when the stackObjects is not + * an empty list the frame is considered to be full. + */ +fun convertToStackSnapshotFrames(snapshot: FlowStackSnapshot): List { + return snapshot.stackFrames.map { + val dataTypes = it.stackObjects.map { + if (it == null) null else it::class.qualifiedName + } + val stackTraceElement = it.stackTraceElement!! + StackSnapshotFrame(stackTraceElement.methodName, stackTraceElement.className, dataTypes) + } +} + +/** + * Flow that during its execution performs calls with side effects in terms of Quasar. The presence of + * side effect calls drives Quasar decision on stack optimisation application. The stack optimisation method aims + * to reduce the amount of data stored on Quasar stack to minimum and is based on static code analyses performed during + * the code instrumentation phase, during which Quasar checks if a method performs side effect calls. If not, + * the method is annotated to be optimised, meaning that none of its local variables are stored on the stack and + * during the runtime the method can be replayed with a guarantee to be idempotent. + */ +@StartableByRPC +class SideEffectFlow : FlowLogic>() { + var sideEffectField = "" + + @Suspendable + override fun call(): List { + sideEffectField = "sideEffectInCall" + // Expected to be on stack + @Suppress("UNUSED_VARIABLE") + val unusedVar = Constants.IN_CALL_VALUE + val numberOfFullFrames = retrieveStackSnapshot() + return numberOfFullFrames + } + + @Suspendable + fun retrieveStackSnapshot(): List { + sideEffectField = "sideEffectInRetrieveStackSnapshot" + // Expected to be on stack + @Suppress("UNUSED_VARIABLE") + val unusedVar = Constants.IN_RETRIEVE_STACK_SNAPSHOT_VALUE + val snapshot = flowStackSnapshot() + return convertToStackSnapshotFrames(snapshot!!) + } + +} + +/** + * Flow that during its execution performs calls with no side effects in terms of Quasar. + * Thus empty frames are expected on in the stack snapshot as Quasar will optimise. + */ +@StartableByRPC +class NoSideEffectFlow : FlowLogic>() { + + @Suspendable + override fun call(): List { + // Using the [Constants] object here is considered by Quasar as a side effect. Thus explicit initialization + @Suppress("UNUSED_VARIABLE") + val unusedVar = "inCall" + val numberOfFullFrames = retrieveStackSnapshot() + return numberOfFullFrames + } + + @Suspendable + fun retrieveStackSnapshot(): List { + // Using the [Constants] object here is considered by Quasar as a side effect. Thus explicit initialization + @Suppress("UNUSED_VARIABLE") + val unusedVar = "inRetrieveStackSnapshot" + val snapshot = flowStackSnapshot() + return convertToStackSnapshotFrames(snapshot!!) + } +} + +object Constants { + val IN_PERSIST_VALUE = "inPersist" + val IN_CALL_VALUE = "inCall" + val IN_RETRIEVE_STACK_SNAPSHOT_VALUE = "inRetrieveStackSnapshot" + val USER = "User" + val PASSWORD = "Password" + +} + +/** + * No side effect flow that stores the partial snapshot into a file, path to which is passed in the flow constructor. + */ +@StartableByRPC +class PersistingNoSideEffectFlow : FlowLogic() { + + @Suspendable + override fun call(): String { + // Using the [Constants] object here is considered by Quasar as a side effect. Thus explicit initialization + @Suppress("UNUSED_VARIABLE") + val unusedVar = "inCall" + persist() + return stateMachine.id.toString() + } + + @Suspendable + fun persist() { + // Using the [Constants] object here is considered by Quasar as a side effect. Thus explicit initialization + @Suppress("UNUSED_VARIABLE") + val unusedVar = "inPersist" + persistFlowStackSnapshot() + } +} + +/** + * Flow with side effects that stores the partial snapshot into a file, path to which is passed in the flow constructor. + */ +@StartableByRPC +class PersistingSideEffectFlow : FlowLogic() { + + @Suspendable + override fun call(): String { + @Suppress("UNUSED_VARIABLE") + val unusedVar = Constants.IN_CALL_VALUE + persist() + return stateMachine.id.toString() + } + + @Suspendable + fun persist() { + @Suppress("UNUSED_VARIABLE") + val unusedVar = Constants.IN_PERSIST_VALUE + persistFlowStackSnapshot() + } +} + +/** + * Similar to [PersistingSideEffectFlow] but aims to produce multiple snapshot files. + */ +@StartableByRPC +class MultiplePersistingSideEffectFlow(val persistCallCount: Int) : FlowLogic() { + + @Suspendable + override fun call(): String { + @Suppress("UNUSED_VARIABLE") + val unusedVar = Constants.IN_CALL_VALUE + for (i in 1..persistCallCount) { + persist() + } + return stateMachine.id.toString() + } + + @Suspendable + fun persist() { + @Suppress("UNUSED_VARIABLE") + val unusedVar = Constants.IN_PERSIST_VALUE + persistFlowStackSnapshot() + } +} + +fun readFlowStackSnapshotFromDir(baseDir: Path, flowId: String): FlowStackSnapshot { + val snapshotFile = File(baseDir.toFile(), "flowStackSnapshots/${LocalDateTime.now().format(DateTimeFormatter.ISO_LOCAL_DATE)}/$flowId/flowStackSnapshot.json") + return ObjectMapper().readValue(snapshotFile.inputStream(), FlowStackSnapshot::class.java) +} + +fun countFilesInDir(baseDir: Path, flowId: String): Int { + val flowDir = File(baseDir.toFile(), "flowStackSnapshots/${LocalDateTime.now().format(DateTimeFormatter.ISO_LOCAL_DATE)}/$flowId/") + return flowDir.listFiles().size +} + +fun assertFrame(expectedMethod: String, expectedEmpty: Boolean, frame: StackSnapshotFrame) { + assertEquals(expectedMethod, frame.method) + assertEquals(expectedEmpty, frame.dataTypes.isEmpty()) +} + +class FlowStackSnapshotTest { + + @Test + @Ignore("This test is skipped due to Jacoco agent interference with the quasar instrumentation process. " + + "This violates tested criteria (specifically: extra objects are introduced to the quasar stack by th Jacoco agent)") + fun `flowStackSnapshot contains full frames when methods with side effects are called`() { + driver(startNodesInProcess = true) { + val a = startNode(rpcUsers = listOf(User(Constants.USER, Constants.PASSWORD, setOf(startFlowPermission())))).get() + a.rpcClientToNode().use(Constants.USER, Constants.PASSWORD) { connection -> + val stackSnapshotFrames = connection.proxy.startFlow(::SideEffectFlow).returnValue.get() + val iterator = stackSnapshotFrames.listIterator() + assertFrame("run", false, iterator.next()) + assertFrame("call", false, iterator.next()) + assertFrame("retrieveStackSnapshot", false, iterator.next()) + assertFrame("flowStackSnapshot", false, iterator.next()) + } + } + } + + @Test + @Ignore("This test is skipped due to Jacoco agent interference with the quasar instrumentation process. " + + "This violates tested criteria (specifically extra objects are introduced to the quasar stack by th Jacoco agent)") + fun `flowStackSnapshot contains empty frames when methods with no side effects are called`() { + driver(startNodesInProcess = true) { + val a = startNode(rpcUsers = listOf(User(Constants.USER, Constants.PASSWORD, setOf(startFlowPermission())))).get() + a.rpcClientToNode().use(Constants.USER, Constants.PASSWORD) { connection -> + val stackSnapshotFrames = connection.proxy.startFlow(::NoSideEffectFlow).returnValue.get() + val iterator = stackSnapshotFrames.listIterator() + assertFrame("run", false, iterator.next()) + assertFrame("call", true, iterator.next()) + assertFrame("retrieveStackSnapshot", true, iterator.next()) + assertFrame("flowStackSnapshot", false, iterator.next()) + } + } + } + + @Test + @Ignore("This test is skipped due to Jacoco agent interference with the quasar instrumentation process. " + + "This violates tested criteria (specifically extra objects are introduced to the quasar stack by th Jacoco agent)") + fun `persistFlowStackSnapshot persists empty frames to a file when methods with no side effects are called`() { + driver(startNodesInProcess = true) { + val a = startNode(rpcUsers = listOf(User(Constants.USER, Constants.PASSWORD, setOf(startFlowPermission())))).get() + + a.rpcClientToNode().use(Constants.USER, Constants.PASSWORD) { connection -> + val flowId = connection.proxy.startFlow(::PersistingNoSideEffectFlow).returnValue.get() + val snapshotFromFile = readFlowStackSnapshotFromDir(a.configuration.baseDirectory, flowId) + val stackSnapshotFrames = convertToStackSnapshotFrames(snapshotFromFile) + val iterator = stackSnapshotFrames.listIterator() + assertFrame("call", true, iterator.next()) + assertFrame("persist", true, iterator.next()) + assertFrame("persistFlowStackSnapshot", false, iterator.next()) + } + } + } + + @Test + @Ignore("This test is skipped due to Jacoco agent interference with the quasar instrumentation process. " + + "This violates tested criteria (specifically extra objects are introduced to the quasar stack by th Jacoco agent)") + fun `persistFlowStackSnapshot persists multiple snapshots in different files`() { + driver(startNodesInProcess = true) { + val a = startNode(rpcUsers = listOf(User(Constants.USER, Constants.PASSWORD, setOf(startFlowPermission())))).get() + + a.rpcClientToNode().use(Constants.USER, Constants.PASSWORD) { connection -> + val numberOfFlowSnapshots = 5 + val flowId = connection.proxy.startFlow(::MultiplePersistingSideEffectFlow, 5).returnValue.get() + val fileCount = countFilesInDir(a.configuration.baseDirectory, flowId) + assertEquals(numberOfFlowSnapshots, fileCount) + } + } + } + + @Test + @Ignore("This test is skipped due to Jacoco agent interference with the quasar instrumentation process. " + + "This violates tested criteria (specifically extra objects are introduced to the quasar stack by th Jacoco agent)") + fun `persistFlowStackSnapshot stack traces are aligned with stack objects`() { + driver(startNodesInProcess = true) { + val a = startNode(rpcUsers = listOf(User(Constants.USER, Constants.PASSWORD, setOf(startFlowPermission())))).get() + + a.rpcClientToNode().use(Constants.USER, Constants.PASSWORD) { connection -> + val flowId = connection.proxy.startFlow(::PersistingSideEffectFlow).returnValue.get() + val snapshotFromFile = readFlowStackSnapshotFromDir(a.configuration.baseDirectory, flowId) + var inCallCount = 0 + var inPersistCount = 0 + snapshotFromFile.stackFrames.forEach { + val trace = it.stackTraceElement + it.stackObjects.forEach { + when (it) { + Constants.IN_CALL_VALUE -> { + assertEquals(PersistingSideEffectFlow::call.name, trace!!.methodName) + inCallCount++ + } + Constants.IN_PERSIST_VALUE -> { + assertEquals(PersistingSideEffectFlow::persist.name, trace!!.methodName) + inPersistCount++ + } + } + } + } + assertTrue(inCallCount > 0) + assertTrue(inPersistCount > 0) + } + } + } +} diff --git a/test-utils/src/integration-test/kotlin/net/corda/testing/driver/DriverTests.kt b/test-utils/src/integration-test/kotlin/net/corda/testing/driver/DriverTests.kt index a6ba60fc47..acec6efcb2 100644 --- a/test-utils/src/integration-test/kotlin/net/corda/testing/driver/DriverTests.kt +++ b/test-utils/src/integration-test/kotlin/net/corda/testing/driver/DriverTests.kt @@ -1,11 +1,11 @@ package net.corda.testing.driver -import com.google.common.util.concurrent.ListenableFuture -import net.corda.core.div -import net.corda.core.getOrThrow -import net.corda.core.list +import net.corda.core.concurrent.CordaFuture +import net.corda.core.internal.div +import net.corda.core.internal.list import net.corda.core.node.services.ServiceInfo -import net.corda.core.readLines +import net.corda.core.internal.readLines +import net.corda.core.utilities.getOrThrow import net.corda.testing.DUMMY_BANK_A import net.corda.testing.DUMMY_NOTARY import net.corda.testing.DUMMY_REGULATOR @@ -24,7 +24,7 @@ class DriverTests { private val executorService: ScheduledExecutorService = Executors.newScheduledThreadPool(2) - private fun nodeMustBeUp(handleFuture: ListenableFuture) = handleFuture.getOrThrow().apply { + private fun nodeMustBeUp(handleFuture: CordaFuture) = handleFuture.getOrThrow().apply { val hostAndPort = nodeInfo.addresses.first() // Check that the port is bound addressMustBeBound(executorService, hostAndPort, (this as? NodeHandle.OutOfProcess)?.process) diff --git a/test-utils/src/main/kotlin/net/corda/testing/AlwaysSucceedContract.kt b/test-utils/src/main/kotlin/net/corda/testing/AlwaysSucceedContract.kt index 77f5a358fc..ff26ed9ae8 100644 --- a/test-utils/src/main/kotlin/net/corda/testing/AlwaysSucceedContract.kt +++ b/test-utils/src/main/kotlin/net/corda/testing/AlwaysSucceedContract.kt @@ -1,10 +1,10 @@ package net.corda.testing import net.corda.core.contracts.Contract -import net.corda.core.contracts.TransactionForContract import net.corda.core.crypto.SecureHash +import net.corda.core.transactions.LedgerTransaction class AlwaysSucceedContract(override val legalContractReference: SecureHash = SecureHash.sha256("Always succeed contract")) : Contract { - override fun verify(tx: TransactionForContract) { + override fun verify(tx: LedgerTransaction) { } } diff --git a/test-utils/src/main/kotlin/net/corda/testing/CoreTestUtils.kt b/test-utils/src/main/kotlin/net/corda/testing/CoreTestUtils.kt index 6c36d638c6..ed759632be 100644 --- a/test-utils/src/main/kotlin/net/corda/testing/CoreTestUtils.kt +++ b/test-utils/src/main/kotlin/net/corda/testing/CoreTestUtils.kt @@ -11,16 +11,19 @@ import net.corda.core.identity.Party import net.corda.core.identity.PartyAndCertificate import net.corda.core.node.ServiceHub import net.corda.core.node.services.IdentityService -import net.corda.core.utilities.OpaqueBytes import net.corda.core.transactions.TransactionBuilder import net.corda.core.utilities.NetworkHostAndPort +import net.corda.core.utilities.OpaqueBytes import net.corda.node.services.config.NodeConfiguration import net.corda.node.services.config.VerifierType import net.corda.node.services.config.configureDevKeyAndTrustStores import net.corda.node.services.identity.InMemoryIdentityService +import net.corda.node.utilities.CertificateType +import net.corda.node.utilities.X509Utilities import net.corda.nodeapi.config.SSLConfiguration import net.corda.testing.node.MockServices import net.corda.testing.node.makeTestDataSourceProperties +import net.corda.testing.node.makeTestDatabaseProperties import org.bouncycastle.asn1.x500.X500Name import org.bouncycastle.asn1.x500.X500NameBuilder import org.bouncycastle.asn1.x500.style.BCStyle @@ -66,9 +69,9 @@ val ALICE_PUBKEY: PublicKey get() = ALICE_KEY.public val BOB_PUBKEY: PublicKey get() = BOB_KEY.public val CHARLIE_PUBKEY: PublicKey get() = CHARLIE_KEY.public -val MEGA_CORP_IDENTITY: PartyAndCertificate get() = getTestPartyAndCertificate(X509Utilities.getX509Name("MegaCorp","London","demo@r3.com",null), MEGA_CORP_PUBKEY) +val MEGA_CORP_IDENTITY: PartyAndCertificate get() = getTestPartyAndCertificate(getX509Name("MegaCorp", "London", "demo@r3.com", null), MEGA_CORP_PUBKEY) val MEGA_CORP: Party get() = MEGA_CORP_IDENTITY.party -val MINI_CORP_IDENTITY: PartyAndCertificate get() = getTestPartyAndCertificate(X509Utilities.getX509Name("MiniCorp","London","demo@r3.com",null), MINI_CORP_PUBKEY) +val MINI_CORP_IDENTITY: PartyAndCertificate get() = getTestPartyAndCertificate(getX509Name("MiniCorp", "London", "demo@r3.com", null), MINI_CORP_PUBKEY) val MINI_CORP: Party get() = MINI_CORP_IDENTITY.party val BOC_KEY: KeyPair by lazy { generateKeyPair() } @@ -79,14 +82,14 @@ val BOC_PARTY_REF = BOC.ref(OpaqueBytes.of(1)).reference val BIG_CORP_KEY: KeyPair by lazy { generateKeyPair() } val BIG_CORP_PUBKEY: PublicKey get() = BIG_CORP_KEY.public -val BIG_CORP_IDENTITY: PartyAndCertificate get() = getTestPartyAndCertificate(X509Utilities.getX509Name("BigCorporation","London","demo@r3.com",null), BIG_CORP_PUBKEY) +val BIG_CORP_IDENTITY: PartyAndCertificate get() = getTestPartyAndCertificate(getX509Name("BigCorporation", "London", "demo@r3.com", null), BIG_CORP_PUBKEY) val BIG_CORP: Party get() = BIG_CORP_IDENTITY.party val BIG_CORP_PARTY_REF = BIG_CORP.ref(OpaqueBytes.of(1)).reference val ALL_TEST_KEYS: List get() = listOf(MEGA_CORP_KEY, MINI_CORP_KEY, ALICE_KEY, BOB_KEY, DUMMY_NOTARY_KEY) val MOCK_IDENTITIES = listOf(MEGA_CORP_IDENTITY, MINI_CORP_IDENTITY, DUMMY_NOTARY_IDENTITY) -val MOCK_IDENTITY_SERVICE: IdentityService get() = InMemoryIdentityService(MOCK_IDENTITIES, emptyMap(), DUMMY_CA.certificate.cert) +val MOCK_IDENTITY_SERVICE: IdentityService get() = InMemoryIdentityService(MOCK_IDENTITIES, emptySet(), DUMMY_CA.certificate.cert) val MOCK_HOST_AND_PORT = NetworkHostAndPort("mockHost", 30000) @@ -126,11 +129,17 @@ fun getFreeLocalPorts(hostName: String, numberToAlloc: Int): List.() -> Unit ): LedgerDSL { - val ledgerDsl = LedgerDSL(TestLedgerDSLInterpreter(services)) - dsl(ledgerDsl) - return ledgerDsl + if (initialiseSerialization) initialiseTestSerialization() + try { + val ledgerDsl = LedgerDSL(TestLedgerDSLInterpreter(services)) + dsl(ledgerDsl) + return ledgerDsl + } finally { + if (initialiseSerialization) resetTestSerialization() + } } /** @@ -141,8 +150,9 @@ fun getFreeLocalPorts(hostName: String, numberToAlloc: Int): List.() -> EnforceVerifyOrFail -) = ledger { this.transaction(transactionLabel, transactionBuilder, dsl) } +) = ledger(initialiseSerialization = initialiseSerialization) { this.transaction(transactionLabel, transactionBuilder, dsl) } fun testNodeConfiguration( baseDirectory: Path, @@ -156,6 +166,7 @@ fun testNodeConfiguration( whenever(nc.trustStorePassword).thenReturn("trustpass") whenever(nc.rpcUsers).thenReturn(emptyList()) whenever(nc.dataSourceProperties).thenReturn(makeTestDataSourceProperties(myLegalName.commonName)) + whenever(nc.database).thenReturn(makeTestDatabaseProperties()) whenever(nc.emailAddress).thenReturn("") whenever(nc.exportJMXto).thenReturn("") whenever(nc.devMode).thenReturn(true) diff --git a/test-utils/src/main/kotlin/net/corda/testing/Expect.kt b/test-utils/src/main/kotlin/net/corda/testing/Expect.kt index 697c041684..1128186c67 100644 --- a/test-utils/src/main/kotlin/net/corda/testing/Expect.kt +++ b/test-utils/src/main/kotlin/net/corda/testing/Expect.kt @@ -1,7 +1,7 @@ package net.corda.testing import com.google.common.util.concurrent.SettableFuture -import net.corda.core.getOrThrow +import net.corda.core.utilities.getOrThrow import org.slf4j.Logger import org.slf4j.LoggerFactory import rx.Observable @@ -78,6 +78,11 @@ fun sequence(expectations: List>): ExpectCompose = Expec */ fun parallel(vararg expectations: ExpectCompose): ExpectCompose = ExpectCompose.Parallel(listOf(*expectations)) +/** + * Tests that events arrive in unspecified order. + * + * @param expectations The pieces of DSL all of which should run but in an unspecified order depending on what sequence events arrive. + */ fun parallel(expectations: List>): ExpectCompose = ExpectCompose.Parallel(expectations) /** diff --git a/test-utils/src/main/kotlin/net/corda/testing/FlowStackSnapshot.kt b/test-utils/src/main/kotlin/net/corda/testing/FlowStackSnapshot.kt new file mode 100644 index 0000000000..13ec465284 --- /dev/null +++ b/test-utils/src/main/kotlin/net/corda/testing/FlowStackSnapshot.kt @@ -0,0 +1,190 @@ +package net.corda.testing + +import co.paralleluniverse.fibers.Fiber +import co.paralleluniverse.fibers.Instrumented +import co.paralleluniverse.fibers.Stack +import co.paralleluniverse.fibers.Suspendable +import com.fasterxml.jackson.annotation.JsonInclude +import com.fasterxml.jackson.databind.ObjectMapper +import com.fasterxml.jackson.databind.SerializationFeature +import net.corda.core.flows.FlowLogic +import net.corda.core.flows.FlowStackSnapshot +import net.corda.core.flows.FlowStackSnapshot.Frame +import net.corda.core.flows.FlowStackSnapshotFactory +import net.corda.core.flows.StackFrameDataToken +import net.corda.core.internal.FlowStateMachine +import net.corda.core.serialization.SerializeAsToken +import java.io.File +import java.nio.file.Path +import java.time.LocalDateTime +import java.time.format.DateTimeFormatter + +class FlowStackSnapshotFactoryImpl : FlowStackSnapshotFactory { + @Suspendable + override fun getFlowStackSnapshot(flowClass: Class<*>): FlowStackSnapshot? { + var snapshot: FlowStackSnapshot? = null + val stackTrace = Fiber.currentFiber().stackTrace + Fiber.parkAndSerialize { fiber, _ -> + snapshot = extractStackSnapshotFromFiber(fiber, stackTrace.toList(), flowClass) + Fiber.unparkDeserialized(fiber, fiber.scheduler) + } + // This is because the dump itself is on the stack, which means it creates a loop in the object graph, we set + // it to null to break the loop + val temporarySnapshot = snapshot + snapshot = null + return temporarySnapshot!! + } + + override fun persistAsJsonFile(flowClass: Class<*>, baseDir: Path, flowId: String) { + val flowStackSnapshot = getFlowStackSnapshot(flowClass) + val mapper = ObjectMapper() + mapper.disable(SerializationFeature.FAIL_ON_EMPTY_BEANS) + mapper.enable(SerializationFeature.INDENT_OUTPUT) + mapper.setSerializationInclusion(JsonInclude.Include.NON_NULL) + val file = createFile(baseDir, flowId) + file.bufferedWriter().use { out -> + mapper.writeValue(out, filterOutStackDump(flowStackSnapshot!!)) + } + } + + private fun extractStackSnapshotFromFiber(fiber: Fiber<*>, stackTrace: List, flowClass: Class<*>): FlowStackSnapshot { + val stack = getFiberStack(fiber) + val objectStack = getObjectStack(stack).toList() + val frameOffsets = getFrameOffsets(stack) + val frameObjects = frameOffsets.map { (frameOffset, frameSize) -> + objectStack.subList(frameOffset + 1, frameOffset + frameSize + 1) + } + // We drop the first element as it is corda internal call irrelevant from the perspective of a CordApp developer + val relevantStackTrace = removeConstructorStackTraceElements(stackTrace).drop(1) + val stackTraceToAnnotation = relevantStackTrace.map { + val element = StackTraceElement(it.className, it.methodName, it.fileName, it.lineNumber) + element to getInstrumentedAnnotation(element) + } + val frameObjectsIterator = frameObjects.listIterator() + val frames = stackTraceToAnnotation.reversed().map { (element, annotation) -> + // If annotation is null then the case indicates that this is an entry point - i.e. + // the net.corda.node.services.statemachine.FlowStateMachineImpl.run method + if (frameObjectsIterator.hasNext() && (annotation == null || !annotation.methodOptimized)) { + Frame(element, frameObjectsIterator.next()) + } else { + Frame(element, listOf()) + } + } + return FlowStackSnapshot(flowClass = flowClass, stackFrames = frames) + } + + private fun getInstrumentedAnnotation(element: StackTraceElement): Instrumented? { + Class.forName(element.className).methods.forEach { + if (it.name == element.methodName && it.isAnnotationPresent(Instrumented::class.java)) { + return it.getAnnotation(Instrumented::class.java) + } + } + return null + } + + private fun removeConstructorStackTraceElements(stackTrace: List): List { + val newStackTrace = ArrayList() + var previousElement: StackTraceElement? = null + for (element in stackTrace) { + if (element.methodName == previousElement?.methodName && + element.className == previousElement?.className && + element.fileName == previousElement?.fileName) { + continue + } + newStackTrace.add(element) + previousElement = element + } + return newStackTrace + } + + private fun filterOutStackDump(flowStackSnapshot: FlowStackSnapshot): FlowStackSnapshot { + val framesFilteredByStackTraceElement = flowStackSnapshot.stackFrames.filter { + !FlowStateMachine::class.java.isAssignableFrom(Class.forName(it.stackTraceElement!!.className)) + } + val framesFilteredByObjects = framesFilteredByStackTraceElement.map { + Frame(it.stackTraceElement, it.stackObjects.map { + if (it != null && (it is FlowLogic<*> || it is FlowStateMachine<*> || it is Fiber<*> || it is SerializeAsToken)) { + StackFrameDataToken(it::class.java.name) + } else { + it + } + }) + } + return FlowStackSnapshot(flowStackSnapshot.timestamp, flowStackSnapshot.flowClass, framesFilteredByObjects) + } + + private fun createFile(baseDir: Path, flowId: String): File { + val file: File + val dir = File(baseDir.toFile(), "flowStackSnapshots/${LocalDateTime.now().format(DateTimeFormatter.ISO_LOCAL_DATE)}/$flowId/") + val index = ThreadLocalIndex.currentIndex.get() + if (index == 0) { + dir.mkdirs() + file = File(dir, "flowStackSnapshot.json") + } else { + file = File(dir, "flowStackSnapshot-$index.json") + } + ThreadLocalIndex.currentIndex.set(index + 1) + return file + } + + private class ThreadLocalIndex private constructor() { + + companion object { + val currentIndex = object : ThreadLocal() { + override fun initialValue() = 0 + } + } + } + +} + +private inline fun R.getField(name: String): A { + val field = R::class.java.getDeclaredField(name) + field.isAccessible = true + @Suppress("UNCHECKED_CAST") + return field.get(this) as A +} + +private fun getFiberStack(fiber: Fiber<*>): Stack { + return fiber.getField("stack") +} + +private fun getObjectStack(stack: Stack): Array { + return stack.getField("dataObject") +} + +private fun getPrimitiveStack(stack: Stack): LongArray { + return stack.getField("dataLong") +} + +/* + * Returns pairs of (offset, size of frame) + */ +private fun getFrameOffsets(stack: Stack): List> { + val primitiveStack = getPrimitiveStack(stack) + val offsets = ArrayList>() + var offset = 0 + while (true) { + val record = primitiveStack[offset] + val slots = getNumSlots(record) + if (slots > 0) { + offsets.add(offset to slots) + offset += slots + 1 + } else { + break + } + } + return offsets +} + +private val MASK_FULL: Long = -1L + +private fun getNumSlots(record: Long): Int { + return getUnsignedBits(record, 14, 16).toInt() +} + +private fun getUnsignedBits(word: Long, offset: Int, length: Int): Long { + val a = 64 - length + val b = a - offset + return word.ushr(b) and MASK_FULL.ushr(a) +} diff --git a/test-utils/src/main/kotlin/net/corda/testing/ProjectStructure.kt b/test-utils/src/main/kotlin/net/corda/testing/ProjectStructure.kt index 3e866fdf4e..1809d4dc16 100644 --- a/test-utils/src/main/kotlin/net/corda/testing/ProjectStructure.kt +++ b/test-utils/src/main/kotlin/net/corda/testing/ProjectStructure.kt @@ -1,7 +1,7 @@ package net.corda.testing -import net.corda.core.div -import net.corda.core.isDirectory +import net.corda.core.internal.div +import net.corda.core.internal.isDirectory import java.nio.file.Path import java.nio.file.Paths diff --git a/test-utils/src/main/kotlin/net/corda/testing/RPCDriver.kt b/test-utils/src/main/kotlin/net/corda/testing/RPCDriver.kt index af6844273b..fb940a4233 100644 --- a/test-utils/src/main/kotlin/net/corda/testing/RPCDriver.kt +++ b/test-utils/src/main/kotlin/net/corda/testing/RPCDriver.kt @@ -1,19 +1,20 @@ package net.corda.testing -import com.google.common.util.concurrent.ListenableFuture import net.corda.client.mock.Generator import net.corda.client.mock.generateOrFail import net.corda.client.mock.int import net.corda.client.mock.string +import net.corda.client.rpc.CordaRPCClient import net.corda.client.rpc.internal.RPCClient import net.corda.client.rpc.internal.RPCClientConfiguration -import net.corda.core.div -import net.corda.core.map +import net.corda.core.concurrent.CordaFuture +import net.corda.core.crypto.random63BitValue +import net.corda.core.internal.concurrent.fork +import net.corda.core.internal.concurrent.map +import net.corda.core.internal.div import net.corda.core.messaging.RPCOps import net.corda.core.utilities.NetworkHostAndPort import net.corda.core.utilities.parseNetworkHostAndPort -import net.corda.core.crypto.random63BitValue -import net.corda.testing.driver.ProcessUtilities import net.corda.node.services.RPCUserService import net.corda.node.services.messaging.ArtemisMessagingServer import net.corda.node.services.messaging.RPCServer @@ -22,6 +23,7 @@ import net.corda.nodeapi.ArtemisTcpTransport import net.corda.nodeapi.ConnectionDirection import net.corda.nodeapi.RPCApi import net.corda.nodeapi.User +import net.corda.nodeapi.internal.serialization.KRYO_RPC_CLIENT_CONTEXT import net.corda.testing.driver.* import org.apache.activemq.artemis.api.core.SimpleString import org.apache.activemq.artemis.api.core.TransportConfiguration @@ -66,7 +68,7 @@ interface RPCDriverExposedDSLInterface : DriverDSLExposedInterface { maxBufferedBytesPerClient: Long = 10L * ArtemisMessagingServer.MAX_FILE_SIZE, configuration: RPCServerConfiguration = RPCServerConfiguration.default, ops : I - ): ListenableFuture + ): CordaFuture /** * Starts an In-VM RPC client. @@ -81,7 +83,7 @@ interface RPCDriverExposedDSLInterface : DriverDSLExposedInterface { username: String = rpcTestUser.username, password: String = rpcTestUser.password, configuration: RPCClientConfiguration = RPCClientConfiguration.default - ): ListenableFuture + ): CordaFuture /** * Starts an In-VM Artemis session connecting to the RPC server. @@ -112,7 +114,7 @@ interface RPCDriverExposedDSLInterface : DriverDSLExposedInterface { configuration: RPCServerConfiguration = RPCServerConfiguration.default, customPort: NetworkHostAndPort? = null, ops : I - ) : ListenableFuture + ) : CordaFuture /** * Starts a Netty RPC client. @@ -129,7 +131,7 @@ interface RPCDriverExposedDSLInterface : DriverDSLExposedInterface { username: String = rpcTestUser.username, password: String = rpcTestUser.password, configuration: RPCClientConfiguration = RPCClientConfiguration.default - ): ListenableFuture + ): CordaFuture /** * Starts a Netty RPC client in a new JVM process that calls random RPCs with random arguments. @@ -144,7 +146,7 @@ interface RPCDriverExposedDSLInterface : DriverDSLExposedInterface { rpcAddress: NetworkHostAndPort, username: String = rpcTestUser.username, password: String = rpcTestUser.password - ): ListenableFuture + ): CordaFuture /** * Starts a Netty Artemis session connecting to an RPC server. @@ -165,13 +167,13 @@ interface RPCDriverExposedDSLInterface : DriverDSLExposedInterface { maxFileSize: Int = ArtemisMessagingServer.MAX_FILE_SIZE, maxBufferedBytesPerClient: Long = 10L * ArtemisMessagingServer.MAX_FILE_SIZE, customPort: NetworkHostAndPort? = null - ): ListenableFuture + ): CordaFuture fun startInVmRpcBroker( rpcUser: User = rpcTestUser, maxFileSize: Int = ArtemisMessagingServer.MAX_FILE_SIZE, maxBufferedBytesPerClient: Long = 10L * ArtemisMessagingServer.MAX_FILE_SIZE - ): ListenableFuture + ): CordaFuture fun startRpcServerWithBrokerRunning( rpcUser: User = rpcTestUser, @@ -224,6 +226,7 @@ fun rpcDriver( debugPortAllocation: PortAllocation = globalDebugPortAllocation, systemProperties: Map = emptyMap(), useTestClock: Boolean = false, + initialiseSerialization: Boolean = true, networkMapStartStrategy: NetworkMapStartStrategy = NetworkMapStartStrategy.Dedicated(startAutomatically = false), startNodesInProcess: Boolean = false, dsl: RPCDriverExposedDSLInterface.() -> A @@ -241,7 +244,8 @@ fun rpcDriver( ) ), coerce = { it }, - dsl = dsl + dsl = dsl, + initialiseSerialization = initialiseSerialization ) private class SingleUserSecurityManager(val rpcUser: User) : ActiveMQSecurityManager3 { @@ -331,14 +335,14 @@ data class RPCDriverDSL( maxBufferedBytesPerClient: Long, configuration: RPCServerConfiguration, ops: I - ): ListenableFuture { + ): CordaFuture { return startInVmRpcBroker(rpcUser, maxFileSize, maxBufferedBytesPerClient).map { broker -> startRpcServerWithBrokerRunning(rpcUser, nodeLegalName, configuration, ops, broker) } } - override fun startInVmRpcClient(rpcOpsClass: Class, username: String, password: String, configuration: RPCClientConfiguration): ListenableFuture { - return driverDSL.executorService.submit { + override fun startInVmRpcClient(rpcOpsClass: Class, username: String, password: String, configuration: RPCClientConfiguration): CordaFuture { + return driverDSL.executorService.fork { val client = RPCClient(inVmClientTransportConfiguration, configuration) val connection = client.start(rpcOpsClass, username, password) driverDSL.shutdownManager.registerShutdown { @@ -369,7 +373,7 @@ data class RPCDriverDSL( configuration: RPCServerConfiguration, customPort: NetworkHostAndPort?, ops: I - ): ListenableFuture { + ): CordaFuture { return startRpcBroker(serverName, rpcUser, maxFileSize, maxBufferedBytesPerClient, customPort).map { broker -> startRpcServerWithBrokerRunning(rpcUser, nodeLegalName, configuration, ops, broker) } @@ -381,8 +385,8 @@ data class RPCDriverDSL( username: String, password: String, configuration: RPCClientConfiguration - ): ListenableFuture { - return driverDSL.executorService.submit { + ): CordaFuture { + return driverDSL.executorService.fork { val client = RPCClient(ArtemisTcpTransport.tcpTransport(ConnectionDirection.Outbound(), rpcAddress, null), configuration) val connection = client.start(rpcOpsClass, username, password) driverDSL.shutdownManager.registerShutdown { @@ -392,8 +396,8 @@ data class RPCDriverDSL( } } - override fun startRandomRpcClient(rpcOpsClass: Class, rpcAddress: NetworkHostAndPort, username: String, password: String): ListenableFuture { - val processFuture = driverDSL.executorService.submit { + override fun startRandomRpcClient(rpcOpsClass: Class, rpcAddress: NetworkHostAndPort, username: String, password: String): CordaFuture { + val processFuture = driverDSL.executorService.fork { ProcessUtilities.startJavaProcess(listOf(rpcOpsClass.name, rpcAddress.toString(), username, password)) } driverDSL.shutdownManager.registerProcessShutdown(processFuture) @@ -419,10 +423,10 @@ data class RPCDriverDSL( maxFileSize: Int, maxBufferedBytesPerClient: Long, customPort: NetworkHostAndPort? - ): ListenableFuture { + ): CordaFuture { val hostAndPort = customPort ?: driverDSL.portAllocation.nextHostAndPort() addressMustNotBeBound(driverDSL.executorService, hostAndPort) - return driverDSL.executorService.submit { + return driverDSL.executorService.fork { val artemisConfig = createRpcServerArtemisConfig(maxFileSize, maxBufferedBytesPerClient, driverDSL.driverDirectory / serverName, hostAndPort) val server = ActiveMQServerImpl(artemisConfig, SingleUserSecurityManager(rpcUser)) server.start() @@ -438,8 +442,8 @@ data class RPCDriverDSL( } } - override fun startInVmRpcBroker(rpcUser: User, maxFileSize: Int, maxBufferedBytesPerClient: Long): ListenableFuture { - return driverDSL.executorService.submit { + override fun startInVmRpcBroker(rpcUser: User, maxFileSize: Int, maxBufferedBytesPerClient: Long): CordaFuture { + return driverDSL.executorService.fork { val artemisConfig = createInVmRpcServerArtemisConfig(maxFileSize, maxBufferedBytesPerClient) val server = EmbeddedActiveMQ() server.setConfiguration(artemisConfig) @@ -510,7 +514,8 @@ class RandomRpcUser { val hostAndPort = args[1].parseNetworkHostAndPort() val username = args[2] val password = args[3] - val handle = RPCClient(hostAndPort, null).start(rpcClass, username, password) + CordaRPCClient.initialiseSerialization() + val handle = RPCClient(hostAndPort, null, serializationContext = KRYO_RPC_CLIENT_CONTEXT).start(rpcClass, username, password) val callGenerators = rpcClass.declaredMethods.map { method -> Generator.sequence(method.parameters.map { generatorStore[it.type] ?: throw Exception("No generator for ${it.type}") diff --git a/test-utils/src/main/kotlin/net/corda/testing/SerializationTestHelpers.kt b/test-utils/src/main/kotlin/net/corda/testing/SerializationTestHelpers.kt new file mode 100644 index 0000000000..9ec3760622 --- /dev/null +++ b/test-utils/src/main/kotlin/net/corda/testing/SerializationTestHelpers.kt @@ -0,0 +1,146 @@ +package net.corda.testing + +import net.corda.client.rpc.serialization.KryoClientSerializationScheme +import net.corda.core.serialization.* +import net.corda.core.utilities.ByteSequence +import net.corda.node.serialization.KryoServerSerializationScheme +import net.corda.nodeapi.internal.serialization.* + +inline fun withTestSerialization(block: () -> T): T { + initialiseTestSerialization() + try { + return block() + } finally { + resetTestSerialization() + } +} + +fun initialiseTestSerialization() { + // Check that everything is configured for testing with mutable delegating instances. + try { + check(SerializationDefaults.SERIALIZATION_FACTORY is TestSerializationFactory) { + "Found non-test serialization configuration: ${SerializationDefaults.SERIALIZATION_FACTORY}" + } + } catch(e: IllegalStateException) { + SerializationDefaults.SERIALIZATION_FACTORY = TestSerializationFactory() + } + try { + check(SerializationDefaults.P2P_CONTEXT is TestSerializationContext) + } catch(e: IllegalStateException) { + SerializationDefaults.P2P_CONTEXT = TestSerializationContext() + } + try { + check(SerializationDefaults.RPC_SERVER_CONTEXT is TestSerializationContext) + } catch(e: IllegalStateException) { + SerializationDefaults.RPC_SERVER_CONTEXT = TestSerializationContext() + } + try { + check(SerializationDefaults.RPC_CLIENT_CONTEXT is TestSerializationContext) + } catch(e: IllegalStateException) { + SerializationDefaults.RPC_CLIENT_CONTEXT = TestSerializationContext() + } + try { + check(SerializationDefaults.STORAGE_CONTEXT is TestSerializationContext) + } catch(e: IllegalStateException) { + SerializationDefaults.STORAGE_CONTEXT = TestSerializationContext() + } + try { + check(SerializationDefaults.CHECKPOINT_CONTEXT is TestSerializationContext) + } catch(e: IllegalStateException) { + SerializationDefaults.CHECKPOINT_CONTEXT = TestSerializationContext() + } + + // Check that the previous test, if there was one, cleaned up after itself. + // IF YOU SEE THESE MESSAGES, THEN IT MEANS A TEST HAS NOT CALLED resetTestSerialization() + check((SerializationDefaults.SERIALIZATION_FACTORY as TestSerializationFactory).delegate == null, { "Expected uninitialised serialization framework but found it set from: ${SerializationDefaults.SERIALIZATION_FACTORY}" }) + check((SerializationDefaults.P2P_CONTEXT as TestSerializationContext).delegate == null, { "Expected uninitialised serialization framework but found it set from: ${SerializationDefaults.P2P_CONTEXT}" }) + check((SerializationDefaults.RPC_SERVER_CONTEXT as TestSerializationContext).delegate == null, { "Expected uninitialised serialization framework but found it set from: ${SerializationDefaults.RPC_SERVER_CONTEXT}" }) + check((SerializationDefaults.RPC_CLIENT_CONTEXT as TestSerializationContext).delegate == null, { "Expected uninitialised serialization framework but found it set from: ${SerializationDefaults.RPC_CLIENT_CONTEXT}" }) + check((SerializationDefaults.STORAGE_CONTEXT as TestSerializationContext).delegate == null, { "Expected uninitialised serialization framework but found it set from: ${SerializationDefaults.STORAGE_CONTEXT}" }) + check((SerializationDefaults.CHECKPOINT_CONTEXT as TestSerializationContext).delegate == null, { "Expected uninitialised serialization framework but found it set from: ${SerializationDefaults.CHECKPOINT_CONTEXT}" }) + + // Now configure all the testing related delegates. + (SerializationDefaults.SERIALIZATION_FACTORY as TestSerializationFactory).delegate = SerializationFactoryImpl().apply { + registerScheme(KryoClientSerializationScheme(this)) + registerScheme(KryoServerSerializationScheme(this)) + registerScheme(AMQPClientSerializationScheme()) + registerScheme(AMQPServerSerializationScheme()) + } + (SerializationDefaults.P2P_CONTEXT as TestSerializationContext).delegate = KRYO_P2P_CONTEXT + (SerializationDefaults.RPC_SERVER_CONTEXT as TestSerializationContext).delegate = KRYO_RPC_SERVER_CONTEXT + (SerializationDefaults.RPC_CLIENT_CONTEXT as TestSerializationContext).delegate = KRYO_RPC_CLIENT_CONTEXT + (SerializationDefaults.STORAGE_CONTEXT as TestSerializationContext).delegate = KRYO_STORAGE_CONTEXT + (SerializationDefaults.CHECKPOINT_CONTEXT as TestSerializationContext).delegate = KRYO_CHECKPOINT_CONTEXT +} + +fun resetTestSerialization() { + (SerializationDefaults.SERIALIZATION_FACTORY as TestSerializationFactory).delegate = null + (SerializationDefaults.P2P_CONTEXT as TestSerializationContext).delegate = null + (SerializationDefaults.RPC_SERVER_CONTEXT as TestSerializationContext).delegate = null + (SerializationDefaults.RPC_CLIENT_CONTEXT as TestSerializationContext).delegate = null + (SerializationDefaults.STORAGE_CONTEXT as TestSerializationContext).delegate = null + (SerializationDefaults.CHECKPOINT_CONTEXT as TestSerializationContext).delegate = null +} + +class TestSerializationFactory : SerializationFactory { + var delegate: SerializationFactory? = null + set(value) { + field = value + stackTrace = Exception().stackTrace.asList() + } + private var stackTrace: List? = null + + override fun toString(): String = stackTrace?.joinToString("\n") ?: "null" + + override fun deserialize(byteSequence: ByteSequence, clazz: Class, context: SerializationContext): T { + return delegate!!.deserialize(byteSequence, clazz, context) + } + + override fun serialize(obj: T, context: SerializationContext): SerializedBytes { + return delegate!!.serialize(obj, context) + } +} + +class TestSerializationContext : SerializationContext { + var delegate: SerializationContext? = null + set(value) { + field = value + stackTrace = Exception().stackTrace.asList() + } + private var stackTrace: List? = null + + override fun toString(): String = stackTrace?.joinToString("\n") ?: "null" + + override val preferedSerializationVersion: ByteSequence + get() = delegate!!.preferedSerializationVersion + override val deserializationClassLoader: ClassLoader + get() = delegate!!.deserializationClassLoader + override val whitelist: ClassWhitelist + get() = delegate!!.whitelist + override val properties: Map + get() = delegate!!.properties + override val objectReferencesEnabled: Boolean + get() = delegate!!.objectReferencesEnabled + override val useCase: SerializationContext.UseCase + get() = delegate!!.useCase + + override fun withProperty(property: Any, value: Any): SerializationContext { + return TestSerializationContext().apply { delegate = this@TestSerializationContext.delegate!!.withProperty(property, value) } + } + + override fun withoutReferences(): SerializationContext { + return TestSerializationContext().apply { delegate = this@TestSerializationContext.delegate!!.withoutReferences() } + } + + override fun withClassLoader(classLoader: ClassLoader): SerializationContext { + return TestSerializationContext().apply { delegate = this@TestSerializationContext.delegate!!.withClassLoader(classLoader) } + } + + override fun withWhitelisted(clazz: Class<*>): SerializationContext { + return TestSerializationContext().apply { delegate = this@TestSerializationContext.delegate!!.withWhitelisted(clazz) } + } + + override fun withPreferredSerializationVersion(versionHeader: ByteSequence): SerializationContext { + return TestSerializationContext().apply { delegate = this@TestSerializationContext.delegate!!.withPreferredSerializationVersion(versionHeader) } + } +} diff --git a/test-utils/src/main/kotlin/net/corda/testing/TestConstants.kt b/test-utils/src/main/kotlin/net/corda/testing/TestConstants.kt index d2ebdae030..29d79fe802 100644 --- a/test-utils/src/main/kotlin/net/corda/testing/TestConstants.kt +++ b/test-utils/src/main/kotlin/net/corda/testing/TestConstants.kt @@ -2,10 +2,20 @@ package net.corda.testing -import net.corda.core.crypto.* -import net.corda.core.crypto.testing.DummyPublicKey +import net.corda.core.contracts.Command +import net.corda.core.contracts.TypeOnlyCommandData +import net.corda.core.crypto.CertificateAndKeyPair +import net.corda.core.crypto.entropyToKeyPair +import net.corda.core.crypto.generateKeyPair import net.corda.core.identity.Party import net.corda.core.identity.PartyAndCertificate +import net.corda.core.internal.concurrent.transpose +import net.corda.core.messaging.CordaRPCOps +import net.corda.core.node.services.ServiceInfo +import net.corda.node.services.transactions.ValidatingNotaryService +import net.corda.node.utilities.X509Utilities +import net.corda.nodeapi.User +import net.corda.testing.driver.DriverDSLExposedInterface import org.bouncycastle.asn1.x500.X500Name import java.math.BigInteger import java.security.KeyPair @@ -15,9 +25,6 @@ import java.time.Instant // A dummy time at which we will be pretending test transactions are created. val TEST_TX_TIME: Instant get() = Instant.parse("2015-04-17T12:00:00.00Z") -val DUMMY_PUBKEY_1: PublicKey get() = DummyPublicKey("x1") -val DUMMY_PUBKEY_2: PublicKey get() = DummyPublicKey("x2") - val DUMMY_KEY_1: KeyPair by lazy { generateKeyPair() } val DUMMY_KEY_2: KeyPair by lazy { generateKeyPair() } @@ -67,3 +74,55 @@ val DUMMY_CA: CertificateAndKeyPair by lazy { CertificateAndKeyPair(cert, DUMMY_CA_KEY) } +fun dummyCommand(vararg signers: PublicKey) = Command(object : TypeOnlyCommandData() {}, signers.toList()) + +val DUMMY_IDENTITY_1: PartyAndCertificate get() = getTestPartyAndCertificate(DUMMY_PARTY) +val DUMMY_PARTY: Party get() = Party(X500Name("CN=Dummy,O=Dummy,L=Madrid,C=ES"), DUMMY_KEY_1.public) + +// +// Extensions to the Driver DSL to auto-manufacture nodes by name. +// + +/** + * A simple wrapper for objects provided by the integration test driver DSL. The fields are lazy so + * node construction won't start until you access the members. You can get one of these from the + * [alice], [bob] and [aliceBobAndNotary] functions. + */ +class PredefinedTestNode internal constructor(party: Party, driver: DriverDSLExposedInterface, services: Set) { + val rpcUsers = listOf(User("admin", "admin", setOf("ALL"))) // TODO: Randomize? + val nodeFuture by lazy { driver.startNode(party.name, rpcUsers = rpcUsers, advertisedServices = services) } + val node by lazy { nodeFuture.get()!! } + val rpc by lazy { node.rpcClientToNode() } + + fun useRPC(block: (CordaRPCOps) -> R) = rpc.use(rpcUsers[0].username, rpcUsers[0].password) { block(it.proxy) } +} + +// TODO: Probably we should inject the above keys through the driver to make the nodes use it, rather than have the warnings below. + +/** + * Returns a plain, entirely stock node pre-configured with the [ALICE] identity. Note that a random key will be generated + * for it: you won't have [ALICE_KEY]. + */ +fun DriverDSLExposedInterface.alice(): PredefinedTestNode = PredefinedTestNode(ALICE, this, emptySet()) +/** + * Returns a plain, entirely stock node pre-configured with the [BOB] identity. Note that a random key will be generated + * for it: you won't have [BOB_KEY]. + */ +fun DriverDSLExposedInterface.bob(): PredefinedTestNode = PredefinedTestNode(BOB, this, emptySet()) +/** + * Returns a plain single node notary pre-configured with the [DUMMY_NOTARY] identity. Note that a random key will be generated + * for it: you won't have [DUMMY_NOTARY_KEY]. + */ +fun DriverDSLExposedInterface.notary(): PredefinedTestNode = PredefinedTestNode(DUMMY_NOTARY, this, setOf(ServiceInfo(ValidatingNotaryService.type))) + +/** + * Returns plain, entirely stock nodes pre-configured with the [ALICE], [BOB] and [DUMMY_NOTARY] X.500 names in that + * order. They have been started up in parallel and are now ready to use. + */ +fun DriverDSLExposedInterface.aliceBobAndNotary(): List { + val alice = alice() + val bob = bob() + val notary = notary() + listOf(alice.nodeFuture, bob.nodeFuture, notary.nodeFuture).transpose().get() + return listOf(alice, bob, notary) +} \ No newline at end of file diff --git a/test-utils/src/main/kotlin/net/corda/testing/TestDSL.kt b/test-utils/src/main/kotlin/net/corda/testing/TestDSL.kt index ea574f79a0..d3dd0e6bfb 100644 --- a/test-utils/src/main/kotlin/net/corda/testing/TestDSL.kt +++ b/test-utils/src/main/kotlin/net/corda/testing/TestDSL.kt @@ -3,10 +3,9 @@ package net.corda.testing import net.corda.core.contracts.* import net.corda.core.crypto.* import net.corda.core.crypto.composite.expandedCompositeKeys -import net.corda.core.crypto.testing.NullSignature +import net.corda.core.crypto.testing.NULL_SIGNATURE import net.corda.core.identity.Party import net.corda.core.node.ServiceHub -import net.corda.core.serialization.serialize import net.corda.core.transactions.SignedTransaction import net.corda.core.transactions.TransactionBuilder import net.corda.core.transactions.WireTransaction @@ -14,6 +13,9 @@ import java.io.InputStream import java.security.KeyPair import java.security.PublicKey import java.util.* +import kotlin.collections.component1 +import kotlin.collections.component2 +import kotlin.collections.set //////////////////////////////////////////////////////////////////////////////////////////////////////////////////// // @@ -284,7 +286,7 @@ data class TestLedgerDSLInterpreter private constructor( override fun verifies(): EnforceVerifyOrFail { try { val usedInputs = mutableSetOf() - services.recordTransactions(transactionsUnverified.map { SignedTransaction(it.serialized, listOf(NullSignature)) }) + services.recordTransactions(transactionsUnverified.map { SignedTransaction(it, listOf(NULL_SIGNATURE)) }) for ((_, value) in transactionWithLocations) { val wtx = value.transaction val ltx = wtx.toLedgerTransaction(services) @@ -296,7 +298,7 @@ data class TestLedgerDSLInterpreter private constructor( throw DoubleSpentInputs(txIds) } usedInputs.addAll(wtx.inputs) - services.recordTransactions(SignedTransaction(wtx.serialized, listOf(NullSignature))) + services.recordTransactions(SignedTransaction(wtx, listOf(NULL_SIGNATURE))) } return EnforceVerifyOrFail.Token } catch (exception: TransactionVerificationException) { @@ -329,20 +331,18 @@ data class TestLedgerDSLInterpreter private constructor( * @return List of [SignedTransaction]s. */ fun signAll(transactionsToSign: List, extraKeys: List) = transactionsToSign.map { wtx -> - check(wtx.mustSign.isNotEmpty()) - val bits = wtx.serialize() - require(bits == wtx.serialized) - val signatures = ArrayList() + check(wtx.requiredSigningKeys.isNotEmpty()) + val signatures = ArrayList() val keyLookup = HashMap() (ALL_TEST_KEYS + extraKeys).forEach { keyLookup[it.public] = it } - wtx.mustSign.expandedCompositeKeys.forEach { + wtx.requiredSigningKeys.expandedCompositeKeys.forEach { val key = keyLookup[it] ?: throw IllegalArgumentException("Missing required key for ${it.toStringShort()}") - signatures += key.sign(wtx.id) + signatures += key.sign(SignableData(wtx.id, SignatureMetadata(1, Crypto.findSignatureScheme(it).schemeNumberID))) } - SignedTransaction(bits, signatures) + SignedTransaction(wtx, signatures) } /** diff --git a/test-utils/src/main/kotlin/net/corda/testing/TestDependencyInjectionBase.kt b/test-utils/src/main/kotlin/net/corda/testing/TestDependencyInjectionBase.kt new file mode 100644 index 0000000000..549cd2ac6d --- /dev/null +++ b/test-utils/src/main/kotlin/net/corda/testing/TestDependencyInjectionBase.kt @@ -0,0 +1,19 @@ +package net.corda.testing + +import org.junit.After +import org.junit.Before + +/** + * The beginnings of somewhere to inject implementations for unit tests. + */ +abstract class TestDependencyInjectionBase { + @Before + fun initialiseSerialization() { + initialiseTestSerialization() + } + + @After + fun resetInitialisation() { + resetTestSerialization() + } +} \ No newline at end of file diff --git a/test-utils/src/main/kotlin/net/corda/testing/TransactionDSLInterpreter.kt b/test-utils/src/main/kotlin/net/corda/testing/TransactionDSLInterpreter.kt index cb973e022f..92b9604350 100644 --- a/test-utils/src/main/kotlin/net/corda/testing/TransactionDSLInterpreter.kt +++ b/test-utils/src/main/kotlin/net/corda/testing/TransactionDSLInterpreter.kt @@ -4,7 +4,7 @@ import net.corda.core.contracts.* import net.corda.testing.contracts.DummyContract import net.corda.core.crypto.SecureHash import net.corda.core.identity.Party -import net.corda.core.seconds +import net.corda.core.utilities.seconds import net.corda.core.transactions.TransactionBuilder import java.security.PublicKey import java.time.Duration diff --git a/test-utils/src/main/kotlin/net/corda/testing/contracts/DummyContract.kt b/test-utils/src/main/kotlin/net/corda/testing/contracts/DummyContract.kt index e7a96d9a0b..92fb2c1084 100644 --- a/test-utils/src/main/kotlin/net/corda/testing/contracts/DummyContract.kt +++ b/test-utils/src/main/kotlin/net/corda/testing/contracts/DummyContract.kt @@ -4,6 +4,7 @@ import net.corda.core.contracts.* import net.corda.core.crypto.SecureHash import net.corda.core.identity.AbstractParty import net.corda.core.identity.Party +import net.corda.core.transactions.LedgerTransaction import net.corda.core.transactions.TransactionBuilder // The dummy contract doesn't do anything useful. It exists for testing purposes. @@ -20,7 +21,7 @@ data class DummyContract(override val legalContractReference: SecureHash = Secur override val participants: List get() = listOf(owner) - override fun withNewOwner(newOwner: AbstractParty) = Pair(Commands.Move(), copy(owner = newOwner)) + override fun withNewOwner(newOwner: AbstractParty) = CommandAndState(Commands.Move(), copy(owner = newOwner)) } /** @@ -39,7 +40,7 @@ data class DummyContract(override val legalContractReference: SecureHash = Secur class Move : TypeOnlyCommandData(), Commands } - override fun verify(tx: TransactionForContract) { + override fun verify(tx: LedgerTransaction) { // Always accepts. } @@ -49,10 +50,10 @@ data class DummyContract(override val legalContractReference: SecureHash = Secur val owners = listOf(owner) + otherOwners return if (owners.size == 1) { val state = SingleOwnerState(magicNumber, owners.first().party) - TransactionType.General.Builder(notary = notary).withItems(state, Command(Commands.Create(), owners.first().party.owningKey)) + TransactionBuilder(notary).withItems(state, Command(Commands.Create(), owners.first().party.owningKey)) } else { val state = MultiOwnerState(magicNumber, owners.map { it.party }) - TransactionType.General.Builder(notary = notary).withItems(state, Command(Commands.Create(), owners.map { it.party.owningKey })) + TransactionBuilder(notary).withItems(state, Command(Commands.Create(), owners.map { it.party.owningKey })) } } @@ -61,7 +62,7 @@ data class DummyContract(override val legalContractReference: SecureHash = Secur require(priors.isNotEmpty()) val priorState = priors[0].state.data val (cmd, state) = priorState.withNewOwner(newOwner) - return TransactionType.General.Builder(notary = priors[0].state.notary).withItems( + return TransactionBuilder(notary = priors[0].state.notary).withItems( /* INPUTS */ *priors.toTypedArray(), /* COMMAND */ Command(cmd, priorState.owner.owningKey), /* OUTPUT */ state diff --git a/test-utils/src/main/kotlin/net/corda/testing/contracts/DummyContractV2.kt b/test-utils/src/main/kotlin/net/corda/testing/contracts/DummyContractV2.kt index b14b55937f..2f616589ed 100644 --- a/test-utils/src/main/kotlin/net/corda/testing/contracts/DummyContractV2.kt +++ b/test-utils/src/main/kotlin/net/corda/testing/contracts/DummyContractV2.kt @@ -2,9 +2,11 @@ package net.corda.testing.contracts import net.corda.core.contracts.* import net.corda.core.crypto.SecureHash +import net.corda.core.flows.ContractUpgradeFlow import net.corda.core.identity.AbstractParty +import net.corda.core.transactions.LedgerTransaction +import net.corda.core.transactions.TransactionBuilder import net.corda.core.transactions.WireTransaction -import net.corda.flows.ContractUpgradeFlow // The dummy contract doesn't do anything useful. It exists for testing purposes. val DUMMY_V2_PROGRAM_ID = DummyContractV2() @@ -30,7 +32,7 @@ class DummyContractV2 : UpgradedContract = states.flatMap { it.state.data.participants }.distinct().toSet() - return Pair(TransactionType.General.Builder(notary).apply { + return Pair(TransactionBuilder(notary).apply { states.forEach { addInputState(it) addOutputState(upgrade(it.state.data)) diff --git a/test-utils/src/main/kotlin/net/corda/testing/contracts/DummyDealContract.kt b/test-utils/src/main/kotlin/net/corda/testing/contracts/DummyDealContract.kt index 57d155f38d..421de71ba6 100644 --- a/test-utils/src/main/kotlin/net/corda/testing/contracts/DummyDealContract.kt +++ b/test-utils/src/main/kotlin/net/corda/testing/contracts/DummyDealContract.kt @@ -2,7 +2,6 @@ package net.corda.testing.contracts import net.corda.contracts.DealState import net.corda.core.contracts.Contract -import net.corda.core.contracts.TransactionForContract import net.corda.core.contracts.UniqueIdentifier import net.corda.core.crypto.SecureHash import net.corda.core.crypto.containsAny @@ -11,21 +10,25 @@ import net.corda.core.identity.Party import net.corda.core.schemas.MappedSchema import net.corda.core.schemas.PersistentState import net.corda.core.schemas.QueryableState -import net.corda.testing.schemas.DummyDealStateSchemaV1 +import net.corda.core.transactions.LedgerTransaction import net.corda.core.transactions.TransactionBuilder +import net.corda.testing.schemas.DummyDealStateSchemaV1 import java.security.PublicKey class DummyDealContract : Contract { override val legalContractReference: SecureHash = SecureHash.sha256("TestDeal") - override fun verify(tx: TransactionForContract) {} + override fun verify(tx: LedgerTransaction) {} data class State( - override val contract: Contract = DummyDealContract(), - override val participants: List = listOf(), - override val linearId: UniqueIdentifier = UniqueIdentifier(), - override val ref: String) : DealState, QueryableState + override val contract: Contract, + override val participants: List, + override val linearId: UniqueIdentifier) : DealState, QueryableState { + constructor(contract: Contract = DummyDealContract(), + participants: List = listOf(), + ref: String) : this(contract, participants, UniqueIdentifier(ref)) + override fun isRelevant(ourKeys: Set): Boolean { return participants.any { it.owningKey.containsAny(ourKeys) } } @@ -39,8 +42,7 @@ class DummyDealContract : Contract { override fun generateMappedObject(schema: MappedSchema): PersistentState { return when (schema) { is DummyDealStateSchemaV1 -> DummyDealStateSchemaV1.PersistentDummyDealState( - uid = linearId, - dealReference = ref + uid = linearId ) else -> throw IllegalArgumentException("Unrecognised schema $schema") } diff --git a/test-utils/src/main/kotlin/net/corda/testing/contracts/DummyLinearContract.kt b/test-utils/src/main/kotlin/net/corda/testing/contracts/DummyLinearContract.kt index 4f0d31d676..2ed7cc7a9d 100644 --- a/test-utils/src/main/kotlin/net/corda/testing/contracts/DummyLinearContract.kt +++ b/test-utils/src/main/kotlin/net/corda/testing/contracts/DummyLinearContract.kt @@ -1,15 +1,16 @@ package net.corda.testing.contracts -import net.corda.core.contracts.* -import net.corda.core.contracts.clauses.Clause -import net.corda.core.contracts.clauses.FilterOn -import net.corda.core.contracts.clauses.verifyClause +import net.corda.core.contracts.Contract +import net.corda.core.contracts.LinearState +import net.corda.core.contracts.UniqueIdentifier +import net.corda.core.contracts.requireThat import net.corda.core.crypto.SecureHash import net.corda.core.crypto.containsAny import net.corda.core.identity.AbstractParty import net.corda.core.schemas.MappedSchema import net.corda.core.schemas.PersistentState import net.corda.core.schemas.QueryableState +import net.corda.core.transactions.LedgerTransaction import net.corda.testing.schemas.DummyLinearStateSchemaV1 import net.corda.testing.schemas.DummyLinearStateSchemaV2 import java.time.LocalDateTime @@ -18,10 +19,17 @@ import java.time.ZoneOffset.UTC class DummyLinearContract : Contract { override val legalContractReference: SecureHash = SecureHash.sha256("Test") - val clause: Clause = LinearState.ClauseVerifier() - override fun verify(tx: TransactionForContract) = verifyClause(tx, - FilterOn(clause, { states -> states.filterIsInstance() }), - emptyList()) + override fun verify(tx: LedgerTransaction) { + val inputs = tx.inputs.map { it.state.data }.filterIsInstance() + val outputs = tx.outputs.map { it.data }.filterIsInstance() + + val inputIds = inputs.map { it.linearId }.distinct() + val outputIds = outputs.map { it.linearId }.distinct() + requireThat { + "LinearStates are not merged" using (inputIds.count() == inputs.count()) + "LinearStates are not split" using (outputIds.count() == outputs.count()) + } + } data class State( override val linearId: UniqueIdentifier = UniqueIdentifier(), diff --git a/test-utils/src/main/kotlin/net/corda/testing/contracts/VaultFiller.kt b/test-utils/src/main/kotlin/net/corda/testing/contracts/VaultFiller.kt index 0994505b0b..dfb30e6ebb 100644 --- a/test-utils/src/main/kotlin/net/corda/testing/contracts/VaultFiller.kt +++ b/test-utils/src/main/kotlin/net/corda/testing/contracts/VaultFiller.kt @@ -6,35 +6,41 @@ import net.corda.contracts.Commodity import net.corda.contracts.DealState import net.corda.contracts.asset.* import net.corda.core.contracts.* +import net.corda.core.crypto.Crypto +import net.corda.core.crypto.SignatureMetadata +import net.corda.core.utilities.getOrThrow import net.corda.core.identity.AbstractParty import net.corda.core.identity.AnonymousParty import net.corda.core.identity.Party import net.corda.core.node.ServiceHub import net.corda.core.node.services.Vault -import net.corda.core.utilities.OpaqueBytes +import net.corda.core.toFuture import net.corda.core.transactions.SignedTransaction +import net.corda.core.transactions.TransactionBuilder +import net.corda.core.utilities.OpaqueBytes import net.corda.testing.CHARLIE import net.corda.testing.DUMMY_NOTARY import net.corda.testing.DUMMY_NOTARY_KEY -import java.security.KeyPair import java.security.PublicKey +import java.time.Duration import java.time.Instant import java.time.Instant.now import java.util.* @JvmOverloads fun ServiceHub.fillWithSomeTestDeals(dealIds: List, - participants: List = emptyList()) : Vault { + participants: List = emptyList(), + notary: Party = DUMMY_NOTARY) : Vault { val myKey: PublicKey = myInfo.legalIdentity.owningKey val me = AnonymousParty(myKey) val transactions: List = dealIds.map { // Issue a deal state - val dummyIssue = TransactionType.General.Builder(notary = DUMMY_NOTARY).apply { + val dummyIssue = TransactionBuilder(notary = notary).apply { addOutputState(DummyDealContract.State(ref = it, participants = participants.plus(me))) - signWith(DUMMY_NOTARY_KEY) } - return@map signInitialTransaction(dummyIssue) + val stx = signInitialTransaction(dummyIssue) + return@map addSignature(stx, notary.owningKey) } recordTransactions(transactions) @@ -57,10 +63,12 @@ fun ServiceHub.fillWithSomeTestLinearStates(numberToCreate: Int, linearTimestamp: Instant = now()) : Vault { val myKey: PublicKey = myInfo.legalIdentity.owningKey val me = AnonymousParty(myKey) + val issuerKey = DUMMY_NOTARY_KEY + val signatureMetadata = SignatureMetadata(myInfo.platformVersion, Crypto.findSignatureScheme(issuerKey.public).schemeNumberID) val transactions: List = (1..numberToCreate).map { // Issue a Linear state - val dummyIssue = TransactionType.General.Builder(notary = DUMMY_NOTARY).apply { + val dummyIssue = TransactionBuilder(notary = DUMMY_NOTARY).apply { addOutputState(DummyLinearContract.State( linearId = UniqueIdentifier(externalId), participants = participants.plus(me), @@ -68,10 +76,9 @@ fun ServiceHub.fillWithSomeTestLinearStates(numberToCreate: Int, linearNumber = linearNumber, linearBoolean = linearBoolean, linearTimestamp = linearTimestamp)) - signWith(DUMMY_NOTARY_KEY) } - return@map signInitialTransaction(dummyIssue) + return@map signInitialTransaction(dummyIssue).withAdditionalSignature(issuerKey, signatureMetadata) } recordTransactions(transactions) @@ -91,18 +98,19 @@ fun ServiceHub.fillWithSomeTestLinearStates(numberToCreate: Int, * * The service hub needs to provide at least a key management service and a storage service. * + * @param issuerServices service hub of the issuer node, which will be used to sign the transaction. * @param outputNotary the notary to use for output states. The transaction is NOT signed by this notary. * @return a vault object that represents the generated states (it will NOT be the full vault from the service hub!). */ fun ServiceHub.fillWithSomeTestCash(howMuch: Amount, + issuerServices: ServiceHub = this, outputNotary: Party = DUMMY_NOTARY, atLeastThisManyStates: Int = 3, atMostThisManyStates: Int = 10, rng: Random = Random(), ref: OpaqueBytes = OpaqueBytes(ByteArray(1, { 1 })), ownedBy: AbstractParty? = null, - issuedBy: PartyAndReference = DUMMY_CASH_ISSUER, - issuerKey: KeyPair = DUMMY_CASH_ISSUER_KEY): Vault { + issuedBy: PartyAndReference = DUMMY_CASH_ISSUER): Vault { val amounts = calculateRandomlySizedAmounts(howMuch, atLeastThisManyStates, atMostThisManyStates, rng) val myKey: PublicKey = ownedBy?.owningKey ?: myInfo.legalIdentity.owningKey @@ -111,11 +119,10 @@ fun ServiceHub.fillWithSomeTestCash(howMuch: Amount, // We will allocate one state to one transaction, for simplicities sake. val cash = Cash() val transactions: List = amounts.map { pennies -> - val issuance = TransactionType.General.Builder(null as Party?) + val issuance = TransactionBuilder(null as Party?) cash.generateIssue(issuance, Amount(pennies, Issued(issuedBy.copy(reference = ref), howMuch.token)), me, outputNotary) - issuance.signWith(issuerKey) - return@map issuance.toSignedTransaction(true) + return@map issuerServices.signInitialTransaction(issuance, issuedBy.party.owningKey) } recordTransactions(transactions) @@ -128,21 +135,26 @@ fun ServiceHub.fillWithSomeTestCash(howMuch: Amount, return Vault(states) } +/** + * + * @param issuerServices service hub of the issuer node, which will be used to sign the transaction. + * @param outputNotary the notary to use for output states. The transaction is NOT signed by this notary. + * @return a vault object that represents the generated states (it will NOT be the full vault from the service hub!). + */ // TODO: need to make all FungibleAsset commands (issue, move, exit) generic fun ServiceHub.fillWithSomeTestCommodity(amount: Amount, + issuerServices: ServiceHub = this, outputNotary: Party = DUMMY_NOTARY, ref: OpaqueBytes = OpaqueBytes(ByteArray(1, { 1 })), ownedBy: AbstractParty? = null, - issuedBy: PartyAndReference = DUMMY_OBLIGATION_ISSUER.ref(1), - issuerKey: KeyPair = DUMMY_OBLIGATION_ISSUER_KEY): Vault { + issuedBy: PartyAndReference = DUMMY_OBLIGATION_ISSUER.ref(1)): Vault { val myKey: PublicKey = ownedBy?.owningKey ?: myInfo.legalIdentity.owningKey val me = AnonymousParty(myKey) val commodity = CommodityContract() - val issuance = TransactionType.General.Builder(null as Party?) + val issuance = TransactionBuilder(null as Party?) commodity.generateIssue(issuance, Amount(amount.quantity, Issued(issuedBy.copy(reference = ref), amount.token)), me, outputNotary) - issuance.signWith(issuerKey) - val transaction = issuance.toSignedTransaction(true) + val transaction = issuerServices.signInitialTransaction(issuance, issuedBy.party.owningKey) recordTransactions(transaction) @@ -177,63 +189,62 @@ fun calculateRandomlySizedAmounts(howMuch: Amount, min: Int, max: Int, return amounts } -fun ServiceHub.consume(states: List>) { +fun ServiceHub.consume(states: List>, notary: Party) { // Create a txn consuming different contract types states.forEach { - val consumedTx = TransactionType.General.Builder(notary = DUMMY_NOTARY).apply { + val builder = TransactionBuilder(notary = notary).apply { addInputState(it) - signWith(DUMMY_NOTARY_KEY) - }.toSignedTransaction() + } + val consumedTx = signInitialTransaction(builder, notary.owningKey) recordTransactions(consumedTx) } } -fun ServiceHub.consumeAndProduce(stateAndRef: StateAndRef): StateAndRef { +fun ServiceHub.consumeAndProduce(stateAndRef: StateAndRef, notary: Party): StateAndRef { // Create a txn consuming different contract types - val consumedTx = TransactionType.General.Builder(notary = DUMMY_NOTARY).apply { + var builder = TransactionBuilder(notary = notary).apply { addInputState(stateAndRef) - signWith(DUMMY_NOTARY_KEY) - }.toSignedTransaction() + } + val consumedTx = signInitialTransaction(builder, notary.owningKey) recordTransactions(consumedTx) // Create a txn consuming different contract types - val producedTx = TransactionType.General.Builder(notary = DUMMY_NOTARY).apply { + builder = TransactionBuilder(notary = notary).apply { addOutputState(DummyLinearContract.State(linearId = stateAndRef.state.data.linearId, participants = stateAndRef.state.data.participants)) - signWith(DUMMY_NOTARY_KEY) - }.toSignedTransaction() + } + val producedTx = signInitialTransaction(builder, notary.owningKey) recordTransactions(producedTx) return producedTx.tx.outRef(0) } -fun ServiceHub.consumeAndProduce(states: List>) { +fun ServiceHub.consumeAndProduce(states: List>, notary: Party) { states.forEach { - consumeAndProduce(it) + consumeAndProduce(it, notary) } } -fun ServiceHub.consumeDeals(dealStates: List>) = consume(dealStates) -fun ServiceHub.consumeLinearStates(linearStates: List>) = consume(linearStates) -fun ServiceHub.evolveLinearStates(linearStates: List>) = consumeAndProduce(linearStates) -fun ServiceHub.evolveLinearState(linearState: StateAndRef) : StateAndRef = consumeAndProduce(linearState) +fun ServiceHub.consumeDeals(dealStates: List>, notary: Party) = consume(dealStates, notary) +fun ServiceHub.consumeLinearStates(linearStates: List>, notary: Party) = consume(linearStates, notary) +fun ServiceHub.evolveLinearStates(linearStates: List>, notary: Party) = consumeAndProduce(linearStates, notary) +fun ServiceHub.evolveLinearState(linearState: StateAndRef, notary: Party) : StateAndRef = consumeAndProduce(linearState, notary) @JvmOverloads -fun ServiceHub.consumeCash(amount: Amount, to: Party = CHARLIE): Vault { +fun ServiceHub.consumeCash(amount: Amount, to: Party = CHARLIE, notary: Party): Vault.Update { + val update = vaultService.rawUpdates.toFuture() + val services = this + // A tx that spends our money. - val spendTX = TransactionType.General.Builder(DUMMY_NOTARY).apply { - vaultService.generateSpend(this, amount, to) - signWith(DUMMY_NOTARY_KEY) - }.toSignedTransaction(checkSufficientSignatures = false) + val builder = TransactionBuilder(notary).apply { + Cash.generateSpend(services, this, amount, to) + } + val spendTx = signInitialTransaction(builder, notary.owningKey) - recordTransactions(spendTX) + recordTransactions(spendTx) - // Get all the StateRefs of all the generated transactions. - val states = spendTX.tx.outputs.indices.map { i -> spendTX.tx.outRef(i) } - - return Vault(states) + return update.getOrThrow(Duration.ofSeconds(3)) } - diff --git a/test-utils/src/main/kotlin/net/corda/testing/driver/Driver.kt b/test-utils/src/main/kotlin/net/corda/testing/driver/Driver.kt index 809f3fd4b6..3a14d967e0 100644 --- a/test-utils/src/main/kotlin/net/corda/testing/driver/Driver.kt +++ b/test-utils/src/main/kotlin/net/corda/testing/driver/Driver.kt @@ -2,19 +2,23 @@ package net.corda.testing.driver -import com.google.common.util.concurrent.* +import com.google.common.util.concurrent.ThreadFactoryBuilder import com.typesafe.config.Config import com.typesafe.config.ConfigRenderOptions import net.corda.client.rpc.CordaRPCClient import net.corda.cordform.CordformContext import net.corda.cordform.CordformNode import net.corda.cordform.NodeDefinition -import net.corda.core.* +import net.corda.core.concurrent.CordaFuture import net.corda.core.concurrent.firstOf -import net.corda.core.crypto.X509Utilities import net.corda.core.crypto.appendToCommonName import net.corda.core.crypto.commonName +import net.corda.core.crypto.getX509Name import net.corda.core.identity.Party +import net.corda.core.internal.ThreadBox +import net.corda.core.internal.concurrent.* +import net.corda.core.internal.div +import net.corda.core.internal.times import net.corda.core.messaging.CordaRPCOps import net.corda.core.node.NodeInfo import net.corda.core.node.services.ServiceInfo @@ -33,10 +37,7 @@ import net.corda.nodeapi.User import net.corda.nodeapi.config.SSLConfiguration import net.corda.nodeapi.config.parseAs import net.corda.nodeapi.internal.addShutdownHook -import net.corda.testing.ALICE -import net.corda.testing.BOB -import net.corda.testing.DUMMY_BANK_A -import net.corda.testing.DUMMY_NOTARY +import net.corda.testing.* import net.corda.testing.node.MOCK_VERSION_INFO import okhttp3.OkHttpClient import okhttp3.Request @@ -52,9 +53,12 @@ import java.time.Instant import java.time.ZoneOffset.UTC import java.time.format.DateTimeFormatter import java.util.* -import java.util.concurrent.* +import java.util.concurrent.ExecutorService +import java.util.concurrent.Executors +import java.util.concurrent.ScheduledExecutorService import java.util.concurrent.TimeUnit.MILLISECONDS import java.util.concurrent.TimeUnit.SECONDS +import java.util.concurrent.TimeoutException import java.util.concurrent.atomic.AtomicInteger import kotlin.concurrent.thread @@ -91,12 +95,12 @@ interface DriverDSLExposedInterface : CordformContext { rpcUsers: List = emptyList(), verifierType: VerifierType = VerifierType.InMemory, customOverrides: Map = emptyMap(), - startInSameProcess: Boolean? = null): ListenableFuture + startInSameProcess: Boolean? = null): CordaFuture fun startNodes( nodes: List, startInSameProcess: Boolean? = null - ): List> + ): List> /** * Starts a distributed notary cluster. @@ -116,14 +120,14 @@ interface DriverDSLExposedInterface : CordformContext { type: ServiceType = RaftValidatingNotaryService.type, verifierType: VerifierType = VerifierType.InMemory, rpcUsers: List = emptyList(), - startInSameProcess: Boolean? = null): ListenableFuture>> + startInSameProcess: Boolean? = null): CordaFuture>> /** * Starts a web server for a node * * @param handle The handle for the node that this webserver connects to via RPC. */ - fun startWebserver(handle: NodeHandle): ListenableFuture + fun startWebserver(handle: NodeHandle): CordaFuture /** * Starts a network map service node. Note that only a single one should ever be running, so you will probably want @@ -131,7 +135,7 @@ interface DriverDSLExposedInterface : CordformContext { * @param startInProcess Determines if the node should be started inside this process. If null the Driver-level * value will be used. */ - fun startDedicatedNetworkMapService(startInProcess: Boolean? = null): ListenableFuture + fun startDedicatedNetworkMapService(startInProcess: Boolean? = null): CordaFuture fun waitForAllNodesToFinish() @@ -144,12 +148,12 @@ interface DriverDSLExposedInterface : CordformContext { * @param check The function being polled. * @return A future that completes with the non-null value [check] has returned. */ - fun pollUntilNonNull(pollName: String, pollInterval: Duration = 500.millis, warnCount: Int = 120, check: () -> A?): ListenableFuture + fun pollUntilNonNull(pollName: String, pollInterval: Duration = 500.millis, warnCount: Int = 120, check: () -> A?): CordaFuture /** * Polls the given function until it returns true. * @see pollUntilNonNull */ - fun pollUntilTrue(pollName: String, pollInterval: Duration = 500.millis, warnCount: Int = 120, check: () -> Boolean): ListenableFuture { + fun pollUntilTrue(pollName: String, pollInterval: Duration = 500.millis, warnCount: Int = 120, check: () -> Boolean): CordaFuture { return pollUntilNonNull(pollName, pollInterval, warnCount) { if (check()) Unit else null } } @@ -185,7 +189,7 @@ sealed class NodeHandle { val nodeThread: Thread ) : NodeHandle() - fun rpcClientToNode(): CordaRPCClient = CordaRPCClient(configuration.rpcAddress!!) + fun rpcClientToNode(): CordaRPCClient = CordaRPCClient(configuration.rpcAddress!!, initialiseSerialization = false) } data class WebserverHandle( @@ -221,7 +225,7 @@ sealed class PortAllocation { * (...) * } * - * Note that [DriverDSL.startNode] does not wait for the node to start up synchronously, but rather returns a [Future] + * Note that [DriverDSL.startNode] does not wait for the node to start up synchronously, but rather returns a [CordaFuture] * of the [NodeInfo] that may be waited on, which completes when the new node registered with the network map service. * * The driver implicitly bootstraps a [NetworkMapService]. @@ -248,6 +252,7 @@ fun driver( debugPortAllocation: PortAllocation = PortAllocation.Incremental(5005), systemProperties: Map = emptyMap(), useTestClock: Boolean = false, + initialiseSerialization: Boolean = true, networkMapStartStrategy: NetworkMapStartStrategy = NetworkMapStartStrategy.Dedicated(startAutomatically = true), startNodesInProcess: Boolean = false, dsl: DriverDSLExposedInterface.() -> A @@ -263,7 +268,8 @@ fun driver( isDebug = isDebug ), coerce = { it }, - dsl = dsl + dsl = dsl, + initialiseSerialization = initialiseSerialization ) /** @@ -276,9 +282,11 @@ fun driver( */ fun genericDriver( driverDsl: D, + initialiseSerialization: Boolean = true, coerce: (D) -> DI, dsl: DI.() -> A ): A { + if (initialiseSerialization) initialiseTestSerialization() val shutdownHook = addShutdownHook(driverDsl::shutdown) try { driverDsl.start() @@ -289,6 +297,7 @@ fun genericD } finally { driverDsl.shutdown() shutdownHook.cancel() + if (initialiseSerialization) resetTestSerialization() } } @@ -305,7 +314,7 @@ fun addressMustBeBound(executorService: ScheduledExecutorService, hostAndPort: N addressMustBeBoundFuture(executorService, hostAndPort, listenProcess).getOrThrow() } -fun addressMustBeBoundFuture(executorService: ScheduledExecutorService, hostAndPort: NetworkHostAndPort, listenProcess: Process? = null): ListenableFuture { +fun addressMustBeBoundFuture(executorService: ScheduledExecutorService, hostAndPort: NetworkHostAndPort, listenProcess: Process? = null): CordaFuture { return poll(executorService, "address $hostAndPort to bind") { if (listenProcess != null && !listenProcess.isAlive) { throw ListenProcessDeathException(hostAndPort, listenProcess) @@ -323,7 +332,7 @@ fun addressMustNotBeBound(executorService: ScheduledExecutorService, hostAndPort addressMustNotBeBoundFuture(executorService, hostAndPort).getOrThrow() } -fun addressMustNotBeBoundFuture(executorService: ScheduledExecutorService, hostAndPort: NetworkHostAndPort): ListenableFuture { +fun addressMustNotBeBoundFuture(executorService: ScheduledExecutorService, hostAndPort: NetworkHostAndPort): CordaFuture { return poll(executorService, "address $hostAndPort to unbind") { try { Socket(hostAndPort.host, hostAndPort.port).close() @@ -340,14 +349,14 @@ fun poll( pollInterval: Duration = 500.millis, warnCount: Int = 120, check: () -> A? -): ListenableFuture { - val resultFuture = SettableFuture.create() +): CordaFuture { + val resultFuture = openFuture() val task = object : Runnable { var counter = -1 override fun run() { if (resultFuture.isCancelled) return // Give up, caller can no longer get the result. if (++counter == warnCount) { - log.warn("Been polling $pollName for ${pollInterval.multipliedBy(warnCount.toLong()).seconds} seconds...") + log.warn("Been polling $pollName for ${(pollInterval * warnCount.toLong()).seconds} seconds...") } try { val checkResult = check() @@ -367,7 +376,7 @@ fun poll( class ShutdownManager(private val executorService: ExecutorService) { private class State { - val registeredShutdowns = ArrayList Unit>>() + val registeredShutdowns = ArrayList Unit>>() var isShutdown = false } @@ -389,7 +398,7 @@ class ShutdownManager(private val executorService: ExecutorService) { fun shutdown() { val shutdownActionFutures = state.locked { if (isShutdown) { - emptyList Unit>>() + emptyList Unit>>() } else { isShutdown = true registeredShutdowns @@ -407,15 +416,15 @@ class ShutdownManager(private val executorService: ExecutorService) { } } } - fun registerShutdown(shutdown: ListenableFuture<() -> Unit>) { + fun registerShutdown(shutdown: CordaFuture<() -> Unit>) { state.locked { require(!isShutdown) registeredShutdowns.add(shutdown) } } - fun registerShutdown(shutdown: () -> Unit) = registerShutdown(Futures.immediateFuture(shutdown)) + fun registerShutdown(shutdown: () -> Unit) = registerShutdown(doneFuture(shutdown)) - fun registerProcessShutdown(processFuture: ListenableFuture) { + fun registerProcessShutdown(processFuture: CordaFuture) { val processShutdown = processFuture.map { process -> { process.destroy() @@ -450,7 +459,7 @@ class ShutdownManager(private val executorService: ExecutorService) { registeredShutdowns.subList(start, end).listIterator(end - start).run { while (hasPrevious()) { previous().getOrThrow().invoke() - set(Futures.immediateFuture {}) // Don't break other followers by doing a remove. + set(doneFuture {}) // Don't break other followers by doing a remove. } } } @@ -469,14 +478,14 @@ class DriverDSL( val startNodesInProcess: Boolean ) : DriverDSLInternalInterface { private val dedicatedNetworkMapAddress = portAllocation.nextHostAndPort() - private var _executorService: ListeningScheduledExecutorService? = null + private var _executorService: ScheduledExecutorService? = null val executorService get() = _executorService!! private var _shutdownManager: ShutdownManager? = null override val shutdownManager get() = _shutdownManager!! private val callerPackage = getCallerPackage() class State { - val processes = ArrayList>() + val processes = ArrayList>() } private val state = ThreadBox(State()) @@ -490,7 +499,7 @@ class DriverDSL( Paths.get(quasarFileUrl.toURI()).toString() } - fun registerProcess(process: ListenableFuture) { + fun registerProcess(process: CordaFuture) { shutdownManager.registerProcessShutdown(process) state.locked { processes.add(process) @@ -498,7 +507,7 @@ class DriverDSL( } override fun waitForAllNodesToFinish() = state.locked { - Futures.allAsList(processes).get().forEach { + processes.transpose().get().forEach { it.waitFor() } } @@ -508,8 +517,8 @@ class DriverDSL( _executorService?.shutdownNow() } - private fun establishRpc(nodeAddress: NetworkHostAndPort, sslConfig: SSLConfiguration, processDeathFuture: ListenableFuture): ListenableFuture { - val client = CordaRPCClient(nodeAddress, sslConfig) + private fun establishRpc(nodeAddress: NetworkHostAndPort, sslConfig: SSLConfiguration, processDeathFuture: CordaFuture): CordaFuture { + val client = CordaRPCClient(nodeAddress, sslConfig, initialiseSerialization = false) val connectionFuture = poll(executorService, "RPC connection") { try { client.start(ArtemisMessagingComponent.NODE_USER, ArtemisMessagingComponent.NODE_USER) @@ -521,7 +530,7 @@ class DriverDSL( } return firstOf(connectionFuture, processDeathFuture) { if (it == processDeathFuture) { - throw processDeathFuture.getOrThrow() + throw ListenProcessDeathException(nodeAddress, processDeathFuture.getOrThrow()) } val connection = connectionFuture.getOrThrow() shutdownManager.registerShutdown(connection::close) @@ -555,12 +564,12 @@ class DriverDSL( verifierType: VerifierType, customOverrides: Map, startInSameProcess: Boolean? - ): ListenableFuture { + ): CordaFuture { val p2pAddress = portAllocation.nextHostAndPort() val rpcAddress = portAllocation.nextHostAndPort() val webAddress = portAllocation.nextHostAndPort() // TODO: Derive name from the full picked name, don't just wrap the common name - val name = providedName ?: X509Utilities.getX509Name("${oneOf(names).commonName}-${p2pAddress.port}","London","demo@r3.com",null) + val name = providedName ?: getX509Name("${oneOf(names).commonName}-${p2pAddress.port}", "London", "demo@r3.com", null) val networkMapServiceConfigLookup = networkMapServiceConfigLookup(listOf(object : NodeDefinition { override fun getName() = name.toString() override fun getConfig() = configOf("p2pAddress" to p2pAddress.toString()) @@ -583,7 +592,7 @@ class DriverDSL( return startNodeInternal(config, webAddress, startInSameProcess) } - override fun startNodes(nodes: List, startInSameProcess: Boolean?): List> { + override fun startNodes(nodes: List, startInSameProcess: Boolean?): List> { val networkMapServiceConfigLookup = networkMapServiceConfigLookup(nodes) return nodes.map { node -> portAllocation.nextHostAndPort() // rpcAddress @@ -611,7 +620,7 @@ class DriverDSL( verifierType: VerifierType, rpcUsers: List, startInSameProcess: Boolean? - ): ListenableFuture>> { + ): CordaFuture>> { val nodeNames = (0 until clusterSize).map { DUMMY_NOTARY.name.appendToCommonName(" $it") } val paths = nodeNames.map { baseDirectory(it) } ServiceIdentityGenerator.generateToDisk(paths, type.id, notaryName) @@ -636,7 +645,7 @@ class DriverDSL( return firstNotaryFuture.flatMap { firstNotary -> val notaryParty = firstNotary.nodeInfo.notaryIdentity - Futures.allAsList(restNotaryFutures).map { restNotaries -> + restNotaryFutures.transpose().map { restNotaries -> Pair(notaryParty, listOf(firstNotary) + restNotaries) } } @@ -659,7 +668,7 @@ class DriverDSL( throw IllegalStateException("Webserver at ${handle.webAddress} has died") } - override fun startWebserver(handle: NodeHandle): ListenableFuture { + override fun startWebserver(handle: NodeHandle): CordaFuture { val debugPort = if (isDebug) debugPortAllocation.nextPort() else null val processFuture = DriverDSL.startWebserver(executorService, handle, debugPort) registerProcess(processFuture) @@ -667,9 +676,7 @@ class DriverDSL( } override fun start() { - _executorService = MoreExecutors.listeningDecorator( - Executors.newScheduledThreadPool(2, ThreadFactoryBuilder().setNameFormat("driver-pool-thread-%d").build()) - ) + _executorService = Executors.newScheduledThreadPool(2, ThreadFactoryBuilder().setNameFormat("driver-pool-thread-%d").build()) _shutdownManager = ShutdownManager(executorService) // We set this property so that in-process nodes find cordapps. Out-of-process nodes need this passed in when started. System.setProperty("net.corda.node.cordapp.scan.package", callerPackage) @@ -680,7 +687,7 @@ class DriverDSL( override fun baseDirectory(nodeName: X500Name): Path = driverDirectory / nodeName.commonName.replace(WHITESPACE, "") - override fun startDedicatedNetworkMapService(startInProcess: Boolean?): ListenableFuture { + override fun startDedicatedNetworkMapService(startInProcess: Boolean?): CordaFuture { val webAddress = portAllocation.nextHostAndPort() val networkMapLegalName = networkMapStartStrategy.legalName val config = ConfigHelper.loadConfig( @@ -698,7 +705,7 @@ class DriverDSL( return startNodeInternal(config, webAddress, startInProcess) } - private fun startNodeInternal(config: Config, webAddress: NetworkHostAndPort, startInProcess: Boolean?): ListenableFuture { + private fun startNodeInternal(config: Config, webAddress: NetworkHostAndPort, startInProcess: Boolean?): CordaFuture { val nodeConfiguration = config.parseAs() if (startInProcess ?: startNodesInProcess) { val nodeAndThreadFuture = startInProcessNode(executorService, nodeConfiguration, config) @@ -709,7 +716,7 @@ class DriverDSL( } } ) return nodeAndThreadFuture.flatMap { (node, thread) -> - establishRpc(nodeConfiguration.p2pAddress, nodeConfiguration, SettableFuture.create()).flatMap { rpc -> + establishRpc(nodeConfiguration.p2pAddress, nodeConfiguration, openFuture()).flatMap { rpc -> rpc.waitUntilRegisteredWithNetworkMap().map { NodeHandle.InProcess(rpc.nodeIdentity(), rpc, nodeConfiguration, webAddress, node, thread) } @@ -721,17 +728,17 @@ class DriverDSL( registerProcess(processFuture) return processFuture.flatMap { process -> val processDeathFuture = poll(executorService, "process death") { - if (process.isAlive) null else ListenProcessDeathException(nodeConfiguration.p2pAddress, process) + if (process.isAlive) null else process } // We continue to use SSL enabled port for RPC when its for node user. establishRpc(nodeConfiguration.p2pAddress, nodeConfiguration, processDeathFuture).flatMap { rpc -> // Call waitUntilRegisteredWithNetworkMap in background in case RPC is failing over: - val networkMapFuture = executorService.submit(Callable { + val networkMapFuture = executorService.fork { rpc.waitUntilRegisteredWithNetworkMap() - }).flatMap { it } + }.flatMap { it } firstOf(processDeathFuture, networkMapFuture) { if (it == processDeathFuture) { - throw processDeathFuture.getOrThrow() + throw ListenProcessDeathException(nodeConfiguration.p2pAddress, process) } processDeathFuture.cancel(false) NodeHandle.OutOfProcess(rpc.nodeIdentity(), rpc, nodeConfiguration, webAddress, debugPort, process) @@ -741,7 +748,7 @@ class DriverDSL( } } - override fun pollUntilNonNull(pollName: String, pollInterval: Duration, warnCount: Int, check: () -> A?): ListenableFuture { + override fun pollUntilNonNull(pollName: String, pollInterval: Duration, warnCount: Int, check: () -> A?): CordaFuture { val pollFuture = poll(executorService, pollName, pollInterval, warnCount, check) shutdownManager.registerShutdown { pollFuture.cancel(true) } return pollFuture @@ -757,17 +764,17 @@ class DriverDSL( private fun oneOf(array: Array) = array[Random().nextInt(array.size)] private fun startInProcessNode( - executorService: ListeningScheduledExecutorService, + executorService: ScheduledExecutorService, nodeConf: FullNodeConfiguration, config: Config - ): ListenableFuture> { - return executorService.submit> { + ): CordaFuture> { + return executorService.fork { log.info("Starting in-process Node ${nodeConf.myLegalName.commonName}") // Write node.conf writeConfig(nodeConf.baseDirectory, "node.conf", config) val clock: Clock = if (nodeConf.useTestClock) TestClock() else NodeClock() // TODO pass the version in? - val node = Node(nodeConf, nodeConf.calculateServices(), MOCK_VERSION_INFO, clock) + val node = Node(nodeConf, nodeConf.calculateServices(), MOCK_VERSION_INFO, clock, initialiseSerialization = false) node.start() val nodeThread = thread(name = nodeConf.myLegalName.commonName) { node.run() @@ -777,15 +784,15 @@ class DriverDSL( } private fun startOutOfProcessNode( - executorService: ListeningScheduledExecutorService, + executorService: ScheduledExecutorService, nodeConf: FullNodeConfiguration, config: Config, quasarJarPath: String, debugPort: Int?, overriddenSystemProperties: Map, callerPackage: String - ): ListenableFuture { - val processFuture = executorService.submit { + ): CordaFuture { + val processFuture = executorService.fork { log.info("Starting out-of-process Node ${nodeConf.myLegalName.commonName}") // Write node.conf writeConfig(nodeConf.baseDirectory, "node.conf", config) @@ -823,11 +830,11 @@ class DriverDSL( } private fun startWebserver( - executorService: ListeningScheduledExecutorService, + executorService: ScheduledExecutorService, handle: NodeHandle, debugPort: Int? - ): ListenableFuture { - return executorService.submit { + ): CordaFuture { + return executorService.fork { val className = "net.corda.webserver.WebServer" ProcessUtilities.startCordaProcess( className = className, // cannot directly get class for this, so just use string diff --git a/test-utils/src/main/kotlin/net/corda/testing/driver/ProcessUtilities.kt b/test-utils/src/main/kotlin/net/corda/testing/driver/ProcessUtilities.kt index 7ac9eedf94..c020d830a1 100644 --- a/test-utils/src/main/kotlin/net/corda/testing/driver/ProcessUtilities.kt +++ b/test-utils/src/main/kotlin/net/corda/testing/driver/ProcessUtilities.kt @@ -1,7 +1,7 @@ package net.corda.testing.driver -import net.corda.core.div -import net.corda.core.exists +import net.corda.core.internal.div +import net.corda.core.internal.exists import java.io.File.pathSeparator import java.nio.file.Path diff --git a/test-utils/src/main/kotlin/net/corda/testing/node/DriverBasedTest.kt b/test-utils/src/main/kotlin/net/corda/testing/node/DriverBasedTest.kt index 6c9219dfb4..d9b4807845 100644 --- a/test-utils/src/main/kotlin/net/corda/testing/node/DriverBasedTest.kt +++ b/test-utils/src/main/kotlin/net/corda/testing/node/DriverBasedTest.kt @@ -1,7 +1,7 @@ package net.corda.testing.node import com.google.common.util.concurrent.SettableFuture -import net.corda.core.getOrThrow +import net.corda.core.utilities.getOrThrow import net.corda.testing.driver.DriverDSLExposedInterface import org.junit.After import org.junit.Before diff --git a/test-utils/src/main/kotlin/net/corda/testing/node/InMemoryMessagingNetwork.kt b/test-utils/src/main/kotlin/net/corda/testing/node/InMemoryMessagingNetwork.kt index b5546e9fc4..440b4da95a 100644 --- a/test-utils/src/main/kotlin/net/corda/testing/node/InMemoryMessagingNetwork.kt +++ b/test-utils/src/main/kotlin/net/corda/testing/node/InMemoryMessagingNetwork.kt @@ -3,9 +3,8 @@ package net.corda.testing.node import com.google.common.util.concurrent.Futures import com.google.common.util.concurrent.ListenableFuture import com.google.common.util.concurrent.SettableFuture -import net.corda.core.ThreadBox -import net.corda.core.crypto.X509Utilities -import net.corda.core.getOrThrow +import net.corda.core.crypto.getX509Name +import net.corda.core.internal.ThreadBox import net.corda.core.messaging.AllPossibleRecipients import net.corda.core.messaging.MessageRecipientGroup import net.corda.core.messaging.MessageRecipients @@ -14,15 +13,15 @@ import net.corda.core.node.ServiceEntry import net.corda.core.node.services.PartyInfo import net.corda.core.serialization.CordaSerializable import net.corda.core.serialization.SingletonSerializeAsToken +import net.corda.core.utilities.getOrThrow import net.corda.core.utilities.trace import net.corda.node.services.messaging.* import net.corda.node.utilities.AffinityExecutor +import net.corda.node.utilities.CordaPersistence import net.corda.node.utilities.JDBCHashSet -import net.corda.node.utilities.transaction import net.corda.testing.node.InMemoryMessagingNetwork.InMemoryMessaging import org.apache.activemq.artemis.utils.ReusableLatch import org.bouncycastle.asn1.x500.X500Name -import org.jetbrains.exposed.sql.Database import org.slf4j.LoggerFactory import rx.Observable import rx.subjects.PublishSubject @@ -108,7 +107,7 @@ class InMemoryMessagingNetwork( fun createNode(manuallyPumped: Boolean, executor: AffinityExecutor, advertisedServices: List, - database: Database): Pair> { + database: CordaPersistence): Pair> { check(counter >= 0) { "In memory network stopped: please recreate." } val builder = createNodeWithID(manuallyPumped, counter, executor, advertisedServices, database = database) as Builder counter++ @@ -129,8 +128,8 @@ class InMemoryMessagingNetwork( id: Int, executor: AffinityExecutor, advertisedServices: List, - description: X500Name = X509Utilities.getX509Name("In memory node $id","London","demo@r3.com",null), - database: Database) + description: X500Name = getX509Name("In memory node $id", "London", "demo@r3.com", null), + database: CordaPersistence) : MessagingServiceBuilder { val peerHandle = PeerHandle(id, description) peersMapping[peerHandle.description] = peerHandle // Assume that the same name - the same entity in MockNetwork. @@ -187,7 +186,7 @@ class InMemoryMessagingNetwork( val id: PeerHandle, val serviceHandles: List, val executor: AffinityExecutor, - val database: Database) : MessagingServiceBuilder { + val database: CordaPersistence) : MessagingServiceBuilder { override fun start(): ListenableFuture { synchronized(this@InMemoryMessagingNetwork) { val node = InMemoryMessaging(manuallyPumped, id, executor, database) @@ -304,7 +303,7 @@ class InMemoryMessagingNetwork( inner class InMemoryMessaging(private val manuallyPumped: Boolean, private val peerHandle: PeerHandle, private val executor: AffinityExecutor, - private val database: Database) : SingletonSerializeAsToken(), MessagingService { + private val database: CordaPersistence) : SingletonSerializeAsToken(), MessagingService { inner class Handler(val topicSession: TopicSession, val callback: (ReceivedMessage, MessageHandlerRegistration) -> Unit) : MessageHandlerRegistration diff --git a/test-utils/src/main/kotlin/net/corda/testing/node/MockNetworkMapCache.kt b/test-utils/src/main/kotlin/net/corda/testing/node/MockNetworkMapCache.kt index a9b26e75bb..ae80cc56da 100644 --- a/test-utils/src/main/kotlin/net/corda/testing/node/MockNetworkMapCache.kt +++ b/test-utils/src/main/kotlin/net/corda/testing/node/MockNetworkMapCache.kt @@ -7,6 +7,7 @@ import net.corda.core.node.NodeInfo import net.corda.core.node.ServiceHub import net.corda.core.node.services.NetworkMapCache import net.corda.core.utilities.NetworkHostAndPort +import net.corda.core.utilities.NonEmptySet import net.corda.node.services.network.InMemoryNetworkMapCache import net.corda.testing.getTestPartyAndCertificate import net.corda.testing.getTestX509Name @@ -28,8 +29,8 @@ class MockNetworkMapCache(serviceHub: ServiceHub) : InMemoryNetworkMapCache(serv override val changed: Observable = PublishSubject.create() init { - val mockNodeA = NodeInfo(listOf(BANK_C_ADDR), BANK_C, setOf(BANK_C), 1) - val mockNodeB = NodeInfo(listOf(BANK_D_ADDR), BANK_D, setOf(BANK_D), 1) + val mockNodeA = NodeInfo(listOf(BANK_C_ADDR), BANK_C, NonEmptySet.of(BANK_C), 1) + val mockNodeB = NodeInfo(listOf(BANK_D_ADDR), BANK_D, NonEmptySet.of(BANK_D), 1) registeredNodes[mockNodeA.legalIdentity.owningKey] = mockNodeA registeredNodes[mockNodeB.legalIdentity.owningKey] = mockNodeB runWithoutMapService() diff --git a/test-utils/src/main/kotlin/net/corda/testing/node/MockNode.kt b/test-utils/src/main/kotlin/net/corda/testing/node/MockNode.kt index aaf1abb4ad..addd075e7e 100644 --- a/test-utils/src/main/kotlin/net/corda/testing/node/MockNode.kt +++ b/test-utils/src/main/kotlin/net/corda/testing/node/MockNode.kt @@ -2,25 +2,24 @@ package net.corda.testing.node import com.google.common.jimfs.Configuration.unix import com.google.common.jimfs.Jimfs -import com.google.common.util.concurrent.Futures -import com.google.common.util.concurrent.ListenableFuture import com.nhaarman.mockito_kotlin.whenever -import net.corda.core.* import net.corda.core.crypto.CertificateAndKeyPair import net.corda.core.crypto.cert import net.corda.core.crypto.entropyToKeyPair import net.corda.core.crypto.random63BitValue import net.corda.core.identity.PartyAndCertificate +import net.corda.core.internal.concurrent.doneFuture +import net.corda.core.internal.createDirectories +import net.corda.core.internal.createDirectory import net.corda.core.messaging.MessageRecipients import net.corda.core.messaging.RPCOps import net.corda.core.messaging.SingleMessageRecipient import net.corda.core.node.CordaPluginRegistry import net.corda.core.node.ServiceEntry -import net.corda.core.utilities.NetworkHostAndPort import net.corda.core.node.WorldMapLocation -import net.corda.core.node.services.IdentityService -import net.corda.core.node.services.KeyManagementService -import net.corda.core.node.services.ServiceInfo +import net.corda.core.node.services.* +import net.corda.core.utilities.NetworkHostAndPort +import net.corda.core.utilities.getOrThrow import net.corda.core.utilities.loggerFor import net.corda.node.internal.AbstractNode import net.corda.node.services.config.NodeConfiguration @@ -29,9 +28,7 @@ import net.corda.node.services.keys.E2ETestKeyManagementService import net.corda.node.services.messaging.MessagingService import net.corda.node.services.network.InMemoryNetworkMapService import net.corda.node.services.network.NetworkMapService -import net.corda.node.services.transactions.InMemoryTransactionVerifierService -import net.corda.node.services.transactions.SimpleNotaryService -import net.corda.node.services.transactions.ValidatingNotaryService +import net.corda.node.services.transactions.* import net.corda.node.utilities.AffinityExecutor import net.corda.node.utilities.AffinityExecutor.ServiceAffinityExecutor import net.corda.testing.* @@ -39,11 +36,9 @@ import org.apache.activemq.artemis.utils.ReusableLatch import org.bouncycastle.asn1.x500.X500Name import org.slf4j.Logger import java.math.BigInteger -import java.nio.file.FileSystem import java.nio.file.Path import java.security.KeyPair import java.security.cert.X509Certificate -import java.util.* import java.util.concurrent.TimeUnit import java.util.concurrent.atomic.AtomicInteger @@ -64,29 +59,27 @@ class MockNetwork(private val networkSendManuallyPumped: Boolean = false, private val threadPerNode: Boolean = false, servicePeerAllocationStrategy: InMemoryMessagingNetwork.ServicePeerAllocationStrategy = InMemoryMessagingNetwork.ServicePeerAllocationStrategy.Random(), - private val defaultFactory: Factory = MockNetwork.DefaultFactory) { - val nextNodeId - get() = _nextNodeId - private var _nextNodeId = 0 - val filesystem: FileSystem = Jimfs.newFileSystem(unix()) - private val busyLatch: ReusableLatch = ReusableLatch() + private val defaultFactory: Factory<*> = MockNetwork.DefaultFactory, + private val initialiseSerialization: Boolean = true) { + var nextNodeId = 0 + private set + private val filesystem = Jimfs.newFileSystem(unix()) + private val busyLatch = ReusableLatch() val messagingNetwork = InMemoryMessagingNetwork(networkSendManuallyPumped, servicePeerAllocationStrategy, busyLatch) - // A unique identifier for this network to segregate databases with the same nodeID but different networks. private val networkId = random63BitValue() - - val identities = ArrayList() - - private val _nodes = ArrayList() + private val identities = mutableListOf() + private val _nodes = mutableListOf() /** A read only view of the current set of executing nodes. */ - val nodes: List = _nodes + val nodes: List get() = _nodes init { + if (initialiseSerialization) initialiseTestSerialization() filesystem.getPath("/nodes").createDirectory() } /** Allows customisation of how nodes are created. */ - interface Factory { + interface Factory { /** * @param overrideServices a set of service entries to use in place of the node's default service entries, * for example where a node's service is part of a cluster. @@ -95,10 +88,10 @@ class MockNetwork(private val networkSendManuallyPumped: Boolean = false, */ fun create(config: NodeConfiguration, network: MockNetwork, networkMapAddr: SingleMessageRecipient?, advertisedServices: Set, id: Int, overrideServices: Map?, - entropyRoot: BigInteger): MockNode + entropyRoot: BigInteger): N } - object DefaultFactory : Factory { + object DefaultFactory : Factory { override fun create(config: NodeConfiguration, network: MockNetwork, networkMapAddr: SingleMessageRecipient?, advertisedServices: Set, id: Int, overrideServices: Map?, entropyRoot: BigInteger): MockNode { @@ -215,7 +208,7 @@ class MockNetwork(private val networkSendManuallyPumped: Boolean = false, } // It's OK to not have a network map service in the mock network. - override fun noNetworkMapConfigured(): ListenableFuture = Futures.immediateFuture(Unit) + override fun noNetworkMapConfigured() = doneFuture(Unit) // There is no need to slow down the unit tests by initialising CityDatabase override fun findMyLocation(): WorldMapLocation? = null @@ -224,10 +217,9 @@ class MockNetwork(private val networkSendManuallyPumped: Boolean = false, override fun myAddresses() = emptyList() - override fun start(): MockNode { + override fun start() { super.start() mockNet.identities.add(info.legalIdentityAndCert) - return this } // Allow unit tests to modify the plugin list before the node start, @@ -257,17 +249,25 @@ class MockNetwork(private val networkSendManuallyPumped: Boolean = false, var acceptableLiveFiberCountOnStop: Int = 0 override fun acceptableLiveFiberCountOnStop(): Int = acceptableLiveFiberCountOnStop - } - /** - * Returns a node, optionally created by the passed factory method. - * @param overrideServices a set of service entries to use in place of the node's default service entries, - * for example where a node's service is part of a cluster. - */ - fun createNode(networkMapAddress: SingleMessageRecipient? = null, forcedID: Int = -1, nodeFactory: Factory = defaultFactory, - start: Boolean = true, legalName: X500Name? = null, overrideServices: Map? = null, - vararg advertisedServices: ServiceInfo): MockNode - = createNode(networkMapAddress, forcedID, nodeFactory, start, legalName, overrideServices, BigInteger.valueOf(random63BitValue()), *advertisedServices) + override fun makeCoreNotaryService(type: ServiceType): NotaryService? { + if (type != BFTNonValidatingNotaryService.type) return super.makeCoreNotaryService(type) + return BFTNonValidatingNotaryService(services, object : BFTSMaRt.Cluster { + override fun waitUntilAllReplicasHaveInitialized() { + val clusterNodes = mockNet.nodes.filter { + services.notaryIdentityKey in it.info.serviceIdentities(BFTNonValidatingNotaryService.type).map { it.owningKey } + } + if (clusterNodes.size != configuration.notaryClusterAddresses.size) { + throw IllegalStateException("Unable to enumerate all nodes in BFT cluster.") + } + clusterNodes.forEach { + val notaryService = it.smm.findServices { it is BFTNonValidatingNotaryService }.single() as BFTNonValidatingNotaryService + notaryService.waitUntilReplicaHasInitialized() + } + } + }) + } + } /** * Returns a node, optionally created by the passed factory method. @@ -277,32 +277,34 @@ class MockNetwork(private val networkSendManuallyPumped: Boolean = false, * but can be overridden to cause nodes to have stable or colliding identity/service keys. * @param configOverrides add/override behaviour of the [NodeConfiguration] mock object. */ - fun createNode(networkMapAddress: SingleMessageRecipient? = null, forcedID: Int = -1, nodeFactory: Factory = defaultFactory, + fun createNode(networkMapAddress: SingleMessageRecipient? = null, forcedID: Int? = null, start: Boolean = true, legalName: X500Name? = null, overrideServices: Map? = null, entropyRoot: BigInteger = BigInteger.valueOf(random63BitValue()), vararg advertisedServices: ServiceInfo, configOverrides: (NodeConfiguration) -> Any? = {}): MockNode { - val newNode = forcedID == -1 - val id = if (newNode) _nextNodeId++ else forcedID - - val path = baseDirectory(id) - if (newNode) - (path / "attachments").createDirectories() + return createNode(networkMapAddress, forcedID, defaultFactory, start, legalName, overrideServices, entropyRoot, *advertisedServices, configOverrides = configOverrides) + } + /** Like the other [createNode] but takes a [Factory] and propagates its [MockNode] subtype. */ + fun createNode(networkMapAddress: SingleMessageRecipient? = null, forcedID: Int? = null, nodeFactory: Factory, + start: Boolean = true, legalName: X500Name? = null, overrideServices: Map? = null, + entropyRoot: BigInteger = BigInteger.valueOf(random63BitValue()), + vararg advertisedServices: ServiceInfo, + configOverrides: (NodeConfiguration) -> Any? = {}): N { + val id = forcedID ?: nextNodeId++ val config = testNodeConfiguration( - baseDirectory = path, + baseDirectory = baseDirectory(id).createDirectories(), myLegalName = legalName ?: getTestX509Name("Mock Company $id")).also { whenever(it.dataSourceProperties).thenReturn(makeTestDataSourceProperties("node_${id}_net_$networkId")) configOverrides(it) } - val node = nodeFactory.create(config, this, networkMapAddress, advertisedServices.toSet(), id, overrideServices, entropyRoot) - if (start) { - node.setup().start() - if (threadPerNode && networkMapAddress != null) - node.networkMapRegistrationFuture.getOrThrow() // Block and wait for the node to register in the net map. + return nodeFactory.create(config, this, networkMapAddress, advertisedServices.toSet(), id, overrideServices, entropyRoot).apply { + if (start) { + start() + if (threadPerNode && networkMapAddress != null) networkMapRegistrationFuture.getOrThrow() // XXX: What about manually-started nodes? + } + _nodes.add(this) } - _nodes.add(node) - return node } fun baseDirectory(nodeId: Int): Path = filesystem.getPath("/nodes/$nodeId") @@ -328,27 +330,6 @@ class MockNetwork(private val networkSendManuallyPumped: Boolean = false, } } - // TODO: Move this to using createSomeNodes which doesn't conflate network services with network users. - /** - * Sets up a two node network, in which the first node runs network map and notary services and the other - * doesn't. - */ - fun createTwoNodes(firstNodeName: X500Name? = null, - secondNodeName: X500Name? = null, - nodeFactory: Factory = defaultFactory, - notaryKeyPair: KeyPair? = null): Pair { - require(nodes.isEmpty()) - val notaryServiceInfo = ServiceInfo(SimpleNotaryService.type) - val notaryOverride = if (notaryKeyPair != null) - mapOf(Pair(notaryServiceInfo, notaryKeyPair)) - else - null - return Pair( - createNode(null, -1, nodeFactory, true, firstNodeName, notaryOverride, BigInteger.valueOf(random63BitValue()), ServiceInfo(NetworkMapService.type), notaryServiceInfo), - createNode(nodes[0].network.myAddress, -1, nodeFactory, true, secondNodeName) - ) - } - /** * A bundle that separates the generic user nodes and service-providing nodes. A real network might not be so * clearly separated, but this is convenient for testing. @@ -357,42 +338,41 @@ class MockNetwork(private val networkSendManuallyPumped: Boolean = false, /** * Sets up a network with the requested number of nodes (defaulting to two), with one or more service nodes that - * run a notary, network map, any oracles etc. Can't be combined with [createTwoNodes]. + * run a notary, network map, any oracles etc. */ @JvmOverloads - fun createSomeNodes(numPartyNodes: Int = 2, nodeFactory: Factory = defaultFactory, notaryKeyPair: KeyPair? = DUMMY_NOTARY_KEY): BasketOfNodes { + fun createSomeNodes(numPartyNodes: Int = 2, nodeFactory: Factory<*> = defaultFactory, notaryKeyPair: KeyPair? = DUMMY_NOTARY_KEY): BasketOfNodes { require(nodes.isEmpty()) val notaryServiceInfo = ServiceInfo(SimpleNotaryService.type) val notaryOverride = if (notaryKeyPair != null) mapOf(Pair(notaryServiceInfo, notaryKeyPair)) else null - val mapNode = createNode(null, nodeFactory = nodeFactory, advertisedServices = ServiceInfo(NetworkMapService.type)) + val mapNode = createNode(nodeFactory = nodeFactory, advertisedServices = ServiceInfo(NetworkMapService.type)) val mapAddress = mapNode.network.myAddress - val notaryNode = createNode(mapAddress, nodeFactory = nodeFactory, overrideServices = notaryOverride, - advertisedServices = notaryServiceInfo) + val notaryNode = createNode(mapAddress, nodeFactory = nodeFactory, overrideServices = notaryOverride, advertisedServices = notaryServiceInfo) val nodes = ArrayList() repeat(numPartyNodes) { nodes += createPartyNode(mapAddress) } nodes.forEach { itNode -> - nodes.map { it.info.legalIdentityAndCert }.forEach(itNode.services.identityService::registerIdentity) + nodes.map { it.info.legalIdentityAndCert }.forEach(itNode.services.identityService::verifyAndRegisterIdentity) } return BasketOfNodes(nodes, notaryNode, mapNode) } - fun createNotaryNode(networkMapAddr: SingleMessageRecipient? = null, + fun createNotaryNode(networkMapAddress: SingleMessageRecipient? = null, legalName: X500Name? = null, overrideServices: Map? = null, serviceName: X500Name? = null): MockNode { - return createNode(networkMapAddr, -1, defaultFactory, true, legalName, overrideServices, BigInteger.valueOf(random63BitValue()), - ServiceInfo(NetworkMapService.type), ServiceInfo(ValidatingNotaryService.type, serviceName)) + return createNode(networkMapAddress, legalName = legalName, overrideServices = overrideServices, + advertisedServices = *arrayOf(ServiceInfo(NetworkMapService.type), ServiceInfo(ValidatingNotaryService.type, serviceName))) } - fun createPartyNode(networkMapAddr: SingleMessageRecipient, + fun createPartyNode(networkMapAddress: SingleMessageRecipient, legalName: X500Name? = null, overrideServices: Map? = null): MockNode { - return createNode(networkMapAddr, -1, defaultFactory, true, legalName, overrideServices) + return createNode(networkMapAddress, legalName = legalName, overrideServices = overrideServices) } @Suppress("unused") // This is used from the network visualiser tool. @@ -414,6 +394,7 @@ class MockNetwork(private val networkSendManuallyPumped: Boolean = false, fun stopNodes() { nodes.forEach { if (it.started) it.stop() } + if (initialiseSerialization) resetTestSerialization() } // Test method to block until all scheduled activity, active flows diff --git a/test-utils/src/main/kotlin/net/corda/testing/node/MockServices.kt b/test-utils/src/main/kotlin/net/corda/testing/node/MockServices.kt index 9e1fba3192..43db1c6550 100644 --- a/test-utils/src/main/kotlin/net/corda/testing/node/MockServices.kt +++ b/test-utils/src/main/kotlin/net/corda/testing/node/MockServices.kt @@ -8,10 +8,11 @@ import net.corda.core.messaging.DataFeed import net.corda.core.node.NodeInfo import net.corda.core.node.ServiceHub import net.corda.core.node.services.* +import net.corda.core.schemas.MappedSchema import net.corda.core.serialization.SerializeAsToken import net.corda.core.serialization.SingletonSerializeAsToken import net.corda.core.transactions.SignedTransaction -import net.corda.flows.AnonymisedIdentity +import net.corda.core.utilities.NonEmptySet import net.corda.node.VersionInfo import net.corda.node.services.api.StateMachineRecordedTransactionMappingStorage import net.corda.node.services.api.WritableTransactionStorage @@ -23,11 +24,14 @@ import net.corda.node.services.persistence.InMemoryStateMachineRecordedTransacti import net.corda.node.services.schema.HibernateObserver import net.corda.node.services.schema.NodeSchemaService import net.corda.node.services.transactions.InMemoryTransactionVerifierService +import net.corda.node.services.vault.HibernateVaultQueryImpl import net.corda.node.services.vault.NodeVaultService -import net.corda.testing.DUMMY_CA -import net.corda.testing.MEGA_CORP -import net.corda.testing.MOCK_IDENTITIES -import net.corda.testing.getTestPartyAndCertificate +import net.corda.node.utilities.CordaPersistence +import net.corda.node.utilities.configureDatabase +import net.corda.schemas.CashSchemaV1 +import net.corda.schemas.CommercialPaperSchemaV1 +import net.corda.testing.* +import net.corda.testing.schemas.DummyLinearStateSchemaV1 import org.bouncycastle.operator.ContentSigner import rx.Observable import rx.subjects.PublishSubject @@ -38,6 +42,7 @@ import java.io.InputStream import java.security.KeyPair import java.security.PrivateKey import java.security.PublicKey +import java.sql.Connection import java.time.Clock import java.util.* import java.util.jar.JarInputStream @@ -50,6 +55,7 @@ import java.util.jar.JarInputStream * building chains of transactions and verifying them. It isn't sufficient for testing flows however. */ open class MockServices(vararg val keys: KeyPair) : ServiceHub { + constructor() : this(generateKeyPair()) val key: KeyPair get() = keys.first() @@ -75,17 +81,21 @@ open class MockServices(vararg val keys: KeyPair) : ServiceHub { override val clock: Clock get() = Clock.systemUTC() override val myInfo: NodeInfo get() { val identity = getTestPartyAndCertificate(MEGA_CORP.name, key.public) - return NodeInfo(emptyList(), identity, setOf(identity), 1) + return NodeInfo(emptyList(), identity, NonEmptySet.of(identity), 1) } override val transactionVerifierService: TransactionVerifierService get() = InMemoryTransactionVerifierService(2) - fun makeVaultService(dataSourceProps: Properties, hibernateConfig: HibernateConfiguration = HibernateConfiguration(NodeSchemaService())): VaultService { - val vaultService = NodeVaultService(this, dataSourceProps) - HibernateObserver(vaultService.rawUpdates, hibernateConfig) + lateinit var hibernatePersister: HibernateObserver + + fun makeVaultService(dataSourceProps: Properties, hibernateConfig: HibernateConfiguration = HibernateConfiguration(NodeSchemaService(), makeTestDatabaseProperties(), { identityService })): VaultService { + val vaultService = NodeVaultService(this, dataSourceProps, makeTestDatabaseProperties()) + hibernatePersister = HibernateObserver(vaultService.rawUpdates, hibernateConfig) return vaultService } override fun cordaService(type: Class): T = throw IllegalArgumentException("${type.name} not found") + + override fun jdbcSession(): Connection = throw UnsupportedOperationException() } class MockKeyManagementService(val identityService: IdentityService, @@ -104,7 +114,7 @@ class MockKeyManagementService(val identityService: IdentityService, override fun filterMyKeys(candidateKeys: Iterable): Iterable = candidateKeys.filter { it in this.keys } - override fun freshKeyAndCert(identity: PartyAndCertificate, revocationEnabled: Boolean): AnonymisedIdentity { + override fun freshKeyAndCert(identity: PartyAndCertificate, revocationEnabled: Boolean): PartyAndCertificate { return freshCertificate(identityService, freshKey(), identity, getSigner(identity.owningKey), revocationEnabled) } @@ -117,8 +127,12 @@ class MockKeyManagementService(val identityService: IdentityService, override fun sign(bytes: ByteArray, publicKey: PublicKey): DigitalSignature.WithKey { val keyPair = getSigningKeyPair(publicKey) - val signature = keyPair.sign(bytes) - return signature + return keyPair.sign(bytes) + } + + override fun sign(signableData: SignableData, publicKey: PublicKey): TransactionSignature { + val keyPair = getSigningKeyPair(publicKey) + return keyPair.sign(signableData) } } @@ -195,4 +209,38 @@ fun makeTestDataSourceProperties(nodeName: String = SecureHash.randomSHA256().to return props } +fun makeTestDatabaseProperties(): Properties { + val props = Properties() + props.setProperty("transactionIsolationLevel", "repeatableRead") //for other possible values see net.corda.node.utilities.CordaPeristence.parserTransactionIsolationLevel(String) + return props +} + +fun makeTestIdentityService() = InMemoryIdentityService(MOCK_IDENTITIES, trustRoot = DUMMY_CA.certificate) + +fun makeTestDatabaseAndMockServices(customSchemas: Set = setOf(CommercialPaperSchemaV1, DummyLinearStateSchemaV1, CashSchemaV1), keys: List = listOf(MEGA_CORP_KEY)): Pair { + val dataSourceProps = makeTestDataSourceProperties() + val databaseProperties = makeTestDatabaseProperties() + + val database = configureDatabase(dataSourceProps, databaseProperties, identitySvc = ::makeTestIdentityService) + val mockService = database.transaction { + val hibernateConfig = HibernateConfiguration(NodeSchemaService(customSchemas), databaseProperties, identitySvc = ::makeTestIdentityService) + object : MockServices(*(keys.toTypedArray())) { + override val vaultService: VaultService = makeVaultService(dataSourceProps, hibernateConfig) + + override fun recordTransactions(txs: Iterable) { + for (stx in txs) { + validatedTransactions.addTransaction(stx) + } + // Refactored to use notifyAll() as we have no other unit test for that method with multiple transactions. + vaultService.notifyAll(txs.map { it.tx }) + } + + override val vaultQueryService: VaultQueryService = HibernateVaultQueryImpl(hibernateConfig, vaultService.updatesPublisher) + + override fun jdbcSession(): Connection = database.createSession() + } + } + return Pair(database, mockService) +} + val MOCK_VERSION_INFO = VersionInfo(1, "Mock release", "Mock revision", "Mock Vendor") diff --git a/test-utils/src/main/kotlin/net/corda/testing/node/NodeBasedTest.kt b/test-utils/src/main/kotlin/net/corda/testing/node/NodeBasedTest.kt index 81b11bd7fc..7e8d3bedd1 100644 --- a/test-utils/src/main/kotlin/net/corda/testing/node/NodeBasedTest.kt +++ b/test-utils/src/main/kotlin/net/corda/testing/node/NodeBasedTest.kt @@ -1,15 +1,19 @@ package net.corda.testing.node -import com.google.common.util.concurrent.Futures -import com.google.common.util.concurrent.ListenableFuture -import com.google.common.util.concurrent.MoreExecutors.listeningDecorator -import net.corda.core.* -import net.corda.core.crypto.X509Utilities +import net.corda.core.concurrent.CordaFuture import net.corda.core.crypto.appendToCommonName import net.corda.core.crypto.commonName +import net.corda.core.crypto.getX509Name +import net.corda.core.internal.concurrent.flatMap +import net.corda.core.internal.concurrent.fork +import net.corda.core.internal.concurrent.map +import net.corda.core.internal.concurrent.transpose +import net.corda.core.internal.createDirectories +import net.corda.core.internal.div import net.corda.core.node.services.ServiceInfo import net.corda.core.node.services.ServiceType import net.corda.core.utilities.WHITESPACE +import net.corda.core.utilities.getOrThrow import net.corda.node.internal.Node import net.corda.node.serialization.NodeClock import net.corda.node.services.config.ConfigHelper @@ -21,6 +25,7 @@ import net.corda.node.utilities.ServiceIdentityGenerator import net.corda.nodeapi.User import net.corda.nodeapi.config.parseAs import net.corda.testing.DUMMY_MAP +import net.corda.testing.TestDependencyInjectionBase import net.corda.testing.driver.addressMustNotBeBoundFuture import net.corda.testing.getFreeLocalPorts import org.apache.logging.log4j.Level @@ -37,7 +42,7 @@ import kotlin.concurrent.thread * purposes. Use the driver if you need to run the nodes in separate processes otherwise this class will suffice. */ // TODO Some of the logic here duplicates what's in the driver -abstract class NodeBasedTest { +abstract class NodeBasedTest : TestDependencyInjectionBase() { @Rule @JvmField val tempFolder = TemporaryFolder() @@ -57,8 +62,8 @@ abstract class NodeBasedTest { */ @After fun stopAllNodes() { - val shutdownExecutor = listeningDecorator(Executors.newScheduledThreadPool(nodes.size)) - Futures.allAsList(nodes.map { shutdownExecutor.submit(it::stop) }).getOrThrow() + val shutdownExecutor = Executors.newScheduledThreadPool(nodes.size) + nodes.map { shutdownExecutor.fork(it::stop) }.transpose().getOrThrow() // Wait until ports are released val portNotBoundChecks = nodes.flatMap { listOf( @@ -68,7 +73,7 @@ abstract class NodeBasedTest { }.filterNotNull() nodes.clear() _networkMapNode = null - Futures.allAsList(portNotBoundChecks).getOrThrow() + portNotBoundChecks.transpose().getOrThrow() } /** @@ -90,7 +95,7 @@ abstract class NodeBasedTest { platformVersion: Int = 1, advertisedServices: Set = emptySet(), rpcUsers: List = emptyList(), - configOverrides: Map = emptyMap()): ListenableFuture { + configOverrides: Map = emptyMap()): CordaFuture { val node = startNodeInternal( legalName, platformVersion, @@ -108,7 +113,7 @@ abstract class NodeBasedTest { fun startNotaryCluster(notaryName: X500Name, clusterSize: Int, - serviceType: ServiceType = RaftValidatingNotaryService.type): ListenableFuture> { + serviceType: ServiceType = RaftValidatingNotaryService.type): CordaFuture> { ServiceIdentityGenerator.generateToDisk( (0 until clusterSize).map { baseDirectory(notaryName.appendToCommonName("-$it")) }, serviceType.id, @@ -118,20 +123,20 @@ abstract class NodeBasedTest { val nodeAddresses = getFreeLocalPorts("localhost", clusterSize).map { it.toString() } val masterNodeFuture = startNode( - X509Utilities.getX509Name("${notaryName.commonName}-0","London","demo@r3.com",null), + getX509Name("${notaryName.commonName}-0", "London", "demo@r3.com", null), advertisedServices = setOf(serviceInfo), configOverrides = mapOf("notaryNodeAddress" to nodeAddresses[0])) val remainingNodesFutures = (1 until clusterSize).map { startNode( - X509Utilities.getX509Name("${notaryName.commonName}-$it","London","demo@r3.com",null), + getX509Name("${notaryName.commonName}-$it", "London", "demo@r3.com", null), advertisedServices = setOf(serviceInfo), configOverrides = mapOf( "notaryNodeAddress" to nodeAddresses[it], "notaryClusterAddresses" to listOf(nodeAddresses[0]))) } - return Futures.allAsList(remainingNodesFutures).flatMap { remainingNodes -> + return remainingNodesFutures.transpose().flatMap { remainingNodes -> masterNodeFuture.map { masterNode -> listOf(masterNode) + remainingNodes } } } @@ -159,7 +164,7 @@ abstract class NodeBasedTest { val parsedConfig = config.parseAs() val node = Node(parsedConfig, parsedConfig.calculateServices(), MOCK_VERSION_INFO.copy(platformVersion = platformVersion), - if (parsedConfig.useTestClock) TestClock() else NodeClock()) + if (parsedConfig.useTestClock) TestClock() else NodeClock(), initialiseSerialization = false) node.start() nodes += node thread(name = legalName.commonName) { diff --git a/test-utils/src/main/kotlin/net/corda/testing/node/SimpleNode.kt b/test-utils/src/main/kotlin/net/corda/testing/node/SimpleNode.kt index d006e380ac..0f55f1bdb3 100644 --- a/test-utils/src/main/kotlin/net/corda/testing/node/SimpleNode.kt +++ b/test-utils/src/main/kotlin/net/corda/testing/node/SimpleNode.kt @@ -1,9 +1,9 @@ package net.corda.testing.node import com.codahale.metrics.MetricRegistry -import com.google.common.util.concurrent.SettableFuture import net.corda.core.crypto.commonName import net.corda.core.crypto.generateKeyPair +import net.corda.core.internal.concurrent.openFuture import net.corda.core.messaging.RPCOps import net.corda.core.node.services.IdentityService import net.corda.core.node.services.KeyManagementService @@ -17,11 +17,9 @@ import net.corda.node.services.messaging.ArtemisMessagingServer import net.corda.node.services.messaging.NodeMessagingClient import net.corda.node.services.network.InMemoryNetworkMapCache import net.corda.node.utilities.AffinityExecutor.ServiceAffinityExecutor +import net.corda.node.utilities.CordaPersistence import net.corda.node.utilities.configureDatabase -import net.corda.node.utilities.transaction import net.corda.testing.freeLocalHostAndPort -import org.jetbrains.exposed.sql.Database -import java.io.Closeable import java.security.KeyPair import java.security.cert.X509Certificate import kotlin.concurrent.thread @@ -34,17 +32,16 @@ class SimpleNode(val config: NodeConfiguration, val address: NetworkHostAndPort rpcAddress: NetworkHostAndPort = freeLocalHostAndPort(), trustRoot: X509Certificate) : AutoCloseable { - private val databaseWithCloseable: Pair = configureDatabase(config.dataSourceProperties) - val database: Database get() = databaseWithCloseable.second val userService = RPCUserServiceImpl(config.rpcUsers) val monitoringService = MonitoringService(MetricRegistry()) val identity: KeyPair = generateKeyPair() val identityService: IdentityService = InMemoryIdentityService(trustRoot = trustRoot) + val database: CordaPersistence = configureDatabase(config.dataSourceProperties, config.database, identitySvc = {InMemoryIdentityService(trustRoot = trustRoot)}) val keyService: KeyManagementService = E2ETestKeyManagementService(identityService, setOf(identity)) val executor = ServiceAffinityExecutor(config.myLegalName.commonName, 1) // TODO: We should have a dummy service hub rather than change behaviour in tests val broker = ArtemisMessagingServer(config, address.port, rpcAddress.port, InMemoryNetworkMapCache(serviceHub = null), userService) - val networkMapRegistrationFuture: SettableFuture = SettableFuture.create() + val networkMapRegistrationFuture = openFuture() val network = database.transaction { NodeMessagingClient( config, @@ -72,7 +69,7 @@ class SimpleNode(val config: NodeConfiguration, val address: NetworkHostAndPort override fun close() { network.stop() broker.stop() - databaseWithCloseable.first.close() + database.close() executor.shutdownNow() } } diff --git a/test-utils/src/main/kotlin/net/corda/testing/node/TestClock.kt b/test-utils/src/main/kotlin/net/corda/testing/node/TestClock.kt index 46d26859e9..d0303085d8 100644 --- a/test-utils/src/main/kotlin/net/corda/testing/node/TestClock.kt +++ b/test-utils/src/main/kotlin/net/corda/testing/node/TestClock.kt @@ -1,9 +1,10 @@ package net.corda.testing.node +import net.corda.core.internal.until import net.corda.core.serialization.SerializeAsToken import net.corda.core.serialization.SerializeAsTokenContext -import net.corda.core.serialization.SingletonSerializationToken import net.corda.core.serialization.SingletonSerializationToken.Companion.singletonSerializationToken +import net.corda.core.internal.until import net.corda.node.utilities.MutableClock import java.time.Clock import java.time.Duration @@ -35,7 +36,7 @@ class TestClock(private var delegateClock: Clock = Clock.systemUTC()) : MutableC * * This will only be approximate due to the time ticking away, but will be some time shortly after the requested [Instant]. */ - @Synchronized fun setTo(newInstant: Instant) = advanceBy(Duration.between(instant(), newInstant)) + @Synchronized fun setTo(newInstant: Instant) = advanceBy(instant() until newInstant) @Synchronized override fun instant(): Instant { return delegateClock.instant() diff --git a/test-utils/src/main/kotlin/net/corda/testing/performance/Injectors.kt b/test-utils/src/main/kotlin/net/corda/testing/performance/Injectors.kt index 5690d0dba9..c57e9a165f 100644 --- a/test-utils/src/main/kotlin/net/corda/testing/performance/Injectors.kt +++ b/test-utils/src/main/kotlin/net/corda/testing/performance/Injectors.kt @@ -3,7 +3,6 @@ package net.corda.testing.performance import com.codahale.metrics.Gauge import com.codahale.metrics.MetricRegistry import com.google.common.base.Stopwatch -import net.corda.core.utilities.Rate import net.corda.testing.driver.ShutdownManager import java.time.Duration import java.util.* diff --git a/core/src/main/kotlin/net/corda/core/utilities/Rate.kt b/test-utils/src/main/kotlin/net/corda/testing/performance/Rate.kt similarity index 71% rename from core/src/main/kotlin/net/corda/core/utilities/Rate.kt rename to test-utils/src/main/kotlin/net/corda/testing/performance/Rate.kt index 1936a27fa3..8a01bd4c27 100644 --- a/core/src/main/kotlin/net/corda/core/utilities/Rate.kt +++ b/test-utils/src/main/kotlin/net/corda/testing/performance/Rate.kt @@ -1,4 +1,4 @@ -package net.corda.core.utilities +package net.corda.testing.performance import java.time.Duration import java.time.temporal.ChronoUnit @@ -21,9 +21,9 @@ data class Rate( /** * Converts the number of events to the given unit. */ - operator fun times(inUnit: TimeUnit): Long { - return inUnit.convert(numberOfEvents, perTimeUnit) - } + operator fun times(inUnit: TimeUnit): Long = inUnit.convert(numberOfEvents, perTimeUnit) + + override fun toString(): String = "$numberOfEvents / ${perTimeUnit.name.dropLast(1).toLowerCase()}" // drop the "s" at the end } operator fun Long.div(timeUnit: TimeUnit) = Rate(this, timeUnit) diff --git a/test-utils/src/main/kotlin/net/corda/testing/schemas/DummyDealStateSchemaV1.kt b/test-utils/src/main/kotlin/net/corda/testing/schemas/DummyDealStateSchemaV1.kt index c59a09b192..7bf754473c 100644 --- a/test-utils/src/main/kotlin/net/corda/testing/schemas/DummyDealStateSchemaV1.kt +++ b/test-utils/src/main/kotlin/net/corda/testing/schemas/DummyDealStateSchemaV1.kt @@ -3,7 +3,6 @@ package net.corda.testing.schemas import net.corda.core.contracts.UniqueIdentifier import net.corda.core.schemas.CommonSchemaV1 import net.corda.core.schemas.MappedSchema -import javax.persistence.Column import javax.persistence.Entity import javax.persistence.Table import javax.persistence.Transient @@ -21,10 +20,6 @@ object DummyDealStateSchemaV1 : MappedSchema(schemaFamily = DummyDealStateSchema @Entity @Table(name = "dummy_deal_states") class PersistentDummyDealState( - - @Column(name = "deal_reference") - var dealReference: String, - /** parent attributes */ @Transient val uid: UniqueIdentifier diff --git a/test-utils/src/main/resources/META-INF/services/net.corda.core.flows.FlowStackSnapshotFactory b/test-utils/src/main/resources/META-INF/services/net.corda.core.flows.FlowStackSnapshotFactory new file mode 100644 index 0000000000..37ed9eb7fc --- /dev/null +++ b/test-utils/src/main/resources/META-INF/services/net.corda.core.flows.FlowStackSnapshotFactory @@ -0,0 +1 @@ +net.corda.testing.FlowStackSnapshotFactoryImpl \ No newline at end of file diff --git a/tools/demobench/build.gradle b/tools/demobench/build.gradle index 8d71db2137..1eb6b3a8d8 100644 --- a/tools/demobench/build.gradle +++ b/tools/demobench/build.gradle @@ -189,6 +189,10 @@ task javapackage(dependsOn: distZip) { include(name: 'plugins/*.jar') include(name: 'explorer/*.jar') } + + fileset(dir: "$pkg_source/package", type: 'data') { + include(name: "bugfixes/**") + } } // This is specific to MacOSX packager. diff --git a/tools/demobench/package-demobench-dmg.sh b/tools/demobench/package-demobench-dmg.sh index c72c6d3abe..ba0cddaa93 100755 --- a/tools/demobench/package-demobench-dmg.sh +++ b/tools/demobench/package-demobench-dmg.sh @@ -7,7 +7,12 @@ if [ -z "$JAVA_HOME" -o ! -x $JAVA_HOME/bin/java ]; then exit 1 fi -$DIRNAME/../../gradlew -PpackageType=dmg javapackage $* -echo -echo "Wrote installer to '$(find build/javapackage/bundles -type f)'" -echo +if ($DIRNAME/../../gradlew -PpackageType=dmg javapackage $*); then + echo + echo "Wrote installer to '$(find build/javapackage/bundles -type f)'" + echo +else + echo "Failed to create installer." + exit 1 +fi + diff --git a/tools/demobench/package-demobench-exe.bat b/tools/demobench/package-demobench-exe.bat index bb6c21191d..c2f0b41e80 100644 --- a/tools/demobench/package-demobench-exe.bat +++ b/tools/demobench/package-demobench-exe.bat @@ -8,13 +8,19 @@ if not defined JAVA_HOME goto NoJavaHome set DIRNAME=%~dp0 if "%DIRNAME%" == "" set DIRNAME=. -call %DIRNAME%\..\..\gradlew -PpackageType=exe javapackage +call %DIRNAME%\..\..\gradlew -PpackageType=exe javapackage %* +if ERRORLEVEL 1 goto Fail @echo -@echo "Wrote installer to %DIRNAME%\build\javapackage\bundles\" +@echo Wrote installer to %DIRNAME%\build\javapackage\bundles\ @echo goto end :NoJavaHome -@echo "Please set JAVA_HOME correctly" +@echo Please set JAVA_HOME correctly. +exit /b 1 + +:Fail +@echo Failed to write installer. +exit /b 1 :end diff --git a/tools/demobench/package-demobench-rpm.sh b/tools/demobench/package-demobench-rpm.sh index 3d14661206..5bcc9c2167 100755 --- a/tools/demobench/package-demobench-rpm.sh +++ b/tools/demobench/package-demobench-rpm.sh @@ -7,7 +7,11 @@ if [ -z "$JAVA_HOME" -o ! -x $JAVA_HOME/bin/java ]; then exit 1 fi -$DIRNAME/../../gradlew -PpackageType=rpm javapackage $* -echo -echo "Wrote installer to '$(find $DIRNAME/build/javapackage/bundles -type f)'" -echo +if ($DIRNAME/../../gradlew -PpackageType=rpm javapackage $*); then + echo + echo "Wrote installer to '$(find $DIRNAME/build/javapackage/bundles -type f)'" + echo +else + echo "Failed to create installer." + exit 1 +fi diff --git a/tools/demobench/package/bugfixes/apply.bat b/tools/demobench/package/bugfixes/apply.bat new file mode 100644 index 0000000000..f770222122 --- /dev/null +++ b/tools/demobench/package/bugfixes/apply.bat @@ -0,0 +1,39 @@ +@echo off + +set DIRNAME=%~dp0 +if "%DIRNAME%" == "" set DIRNAME=. + +set SOURCEDIR=%DIRNAME%java +set BUILDDIR=%DIRNAME%build + +if '%1' == '' ( + @echo Need location of rt.jar + exit /b 1 +) +if not "%~nx1" == "rt.jar" ( + @echo File '%1' is not rt.jar + exit /b 1 +) +if not exist %1 ( + @echo %1 not found. + exit /b 1 +) + +if exist "%BUILDDIR%" rmdir /s /q "%BUILDDIR%" +mkdir "%BUILDDIR%" + +for /r "%SOURCEDIR%" %%j in (*.java) do ( + javac -O -d "%BUILDDIR%" "%%j" + if ERRORLEVEL 1 ( + @echo "Failed to compile %%j" + exit /b 1 + ) +) + +jar uvf %1 -C "%BUILDDIR%" . +if ERRORLEVEL 1 ( + @echo "Failed to update %1" + exit /b 1 +) + +@echo "Completed" diff --git a/tools/demobench/package/bugfixes/apply.sh b/tools/demobench/package/bugfixes/apply.sh new file mode 100755 index 0000000000..560f83abc7 --- /dev/null +++ b/tools/demobench/package/bugfixes/apply.sh @@ -0,0 +1,32 @@ +#!/bin/sh + +BASEDIR=$(dirname $0) +SOURCEDIR=$BASEDIR/java +BUILDDIR=$BASEDIR/build +RTJAR=$1 + +if [ -z "$RTJAR" ]; then + echo "Need location of rt.jar" + exit 1 +elif [ $(basename $RTJAR) != "rt.jar" ]; then + echo "File is not rt.jar" + exit 1 +elif [ ! -f $RTJAR ]; then + echo "$RTJAR not found" + exit 1 +fi + +# Bugfixes: +# ========= +# +# sun.swing.JLightweightFrame:473 +# https://github.com/JetBrains/jdk8u_jdk/issues/6 +# https://github.com/JetBrains/jdk8u/issues/8 + +rm -rf $BUILDDIR && mkdir $BUILDDIR +if (javac -O -d $BUILDDIR $(find $SOURCEDIR -name "*.java")); then + chmod u+w $RTJAR + jar uvf $RTJAR -C $BUILDDIR . + chmod ugo-w $RTJAR +fi + diff --git a/tools/demobench/package/bugfixes/java/sun/swing/JLightweightFrame.java b/tools/demobench/package/bugfixes/java/sun/swing/JLightweightFrame.java new file mode 100644 index 0000000000..6f39eb5d62 --- /dev/null +++ b/tools/demobench/package/bugfixes/java/sun/swing/JLightweightFrame.java @@ -0,0 +1,515 @@ +/* + * Copyright (c) 2014, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ + +package sun.swing; + +import java.awt.BorderLayout; +import java.awt.Color; +import java.awt.Component; +import java.awt.Container; +import java.awt.Dimension; +import java.awt.EventQueue; +import java.awt.Graphics; +import java.awt.Graphics2D; +import java.awt.MouseInfo; +import java.awt.Point; +import java.awt.PointerInfo; +import java.awt.Rectangle; +import java.awt.Window; +import java.awt.dnd.DragGestureEvent; +import java.awt.dnd.DragGestureListener; +import java.awt.dnd.DragGestureRecognizer; +import java.awt.dnd.DragSource; +import java.awt.dnd.DropTarget; +import java.awt.dnd.InvalidDnDOperationException; +import java.awt.dnd.peer.DragSourceContextPeer; +import java.awt.event.ContainerEvent; +import java.awt.event.ContainerListener; +import java.awt.image.BufferedImage; +import java.awt.image.DataBufferInt; +import java.beans.PropertyChangeEvent; +import java.beans.PropertyChangeListener; +import java.security.AccessController; +import java.util.logging.Logger; +import javax.swing.JComponent; + +import javax.swing.JLayeredPane; +import javax.swing.JPanel; +import javax.swing.JRootPane; +import javax.swing.LayoutFocusTraversalPolicy; +import javax.swing.RepaintManager; +import javax.swing.RootPaneContainer; +import javax.swing.SwingUtilities; + +import sun.awt.DisplayChangedListener; +import sun.awt.LightweightFrame; +import sun.security.action.GetPropertyAction; +import sun.swing.SwingUtilities2.RepaintListener; + +/** + * The frame serves as a lightweight container which paints its content + * to an offscreen image and provides access to the image's data via the + * {@link LightweightContent} interface. Note, that it may not be shown + * as a standalone toplevel frame. Its purpose is to provide functionality + * for lightweight embedding. + * + * @author Artem Ananiev + * @author Anton Tarasov + */ +public final class JLightweightFrame extends LightweightFrame implements RootPaneContainer { + + private final JRootPane rootPane = new JRootPane(); + + private LightweightContent content; + + private Component component; + private JPanel contentPane; + + private BufferedImage bbImage; + + private volatile int scaleFactor = 1; + + /** + * {@code copyBufferEnabled}, true by default, defines the following strategy. + * A duplicating (copy) buffer is created for the original pixel buffer. + * The copy buffer is synchronized with the original buffer every time the + * latter changes. {@code JLightweightFrame} passes the copy buffer array + * to the {@link LightweightContent#imageBufferReset} method. The code spot + * which synchronizes two buffers becomes the only critical section guarded + * by the lock (managed with the {@link LightweightContent#paintLock()}, + * {@link LightweightContent#paintUnlock()} methods). + */ + private static boolean copyBufferEnabled; + private int[] copyBuffer; + + private PropertyChangeListener layoutSizeListener; + private RepaintListener repaintListener; + + static { + SwingAccessor.setJLightweightFrameAccessor(new SwingAccessor.JLightweightFrameAccessor() { + @Override + public void updateCursor(JLightweightFrame frame) { + frame.updateClientCursor(); + } + }); + copyBufferEnabled = "true".equals(AccessController. + doPrivileged(new GetPropertyAction("swing.jlf.copyBufferEnabled", "true"))); + } + + /** + * Constructs a new, initially invisible {@code JLightweightFrame} + * instance. + */ + public JLightweightFrame() { + super(); + copyBufferEnabled = "true".equals(AccessController. + doPrivileged(new GetPropertyAction("swing.jlf.copyBufferEnabled", "true"))); + + add(rootPane, BorderLayout.CENTER); + setFocusTraversalPolicy(new LayoutFocusTraversalPolicy()); + if (getGraphicsConfiguration().isTranslucencyCapable()) { + setBackground(new Color(0, 0, 0, 0)); + } + + layoutSizeListener = new PropertyChangeListener() { + @Override + public void propertyChange(PropertyChangeEvent e) { + Dimension d = (Dimension)e.getNewValue(); + + if ("preferredSize".equals(e.getPropertyName())) { + content.preferredSizeChanged(d.width, d.height); + + } else if ("maximumSize".equals(e.getPropertyName())) { + content.maximumSizeChanged(d.width, d.height); + + } else if ("minimumSize".equals(e.getPropertyName())) { + content.minimumSizeChanged(d.width, d.height); + } + } + }; + + repaintListener = (JComponent c, int x, int y, int w, int h) -> { + Window jlf = SwingUtilities.getWindowAncestor(c); + if (jlf != JLightweightFrame.this) { + return; + } + Point p = SwingUtilities.convertPoint(c, x, y, jlf); + Rectangle r = new Rectangle(p.x, p.y, w, h).intersection( + new Rectangle(0, 0, bbImage.getWidth() / scaleFactor, + bbImage.getHeight() / scaleFactor)); + + if (!r.isEmpty()) { + notifyImageUpdated(r.x, r.y, r.width, r.height); + } + }; + + SwingAccessor.getRepaintManagerAccessor().addRepaintListener( + RepaintManager.currentManager(this), repaintListener); + } + + @Override + public void dispose() { + SwingAccessor.getRepaintManagerAccessor().removeRepaintListener( + RepaintManager.currentManager(this), repaintListener); + super.dispose(); + } + + /** + * Sets the {@link LightweightContent} instance for this frame. + * The {@code JComponent} object returned by the + * {@link LightweightContent#getComponent()} method is immediately + * added to the frame's content pane. + * + * @param content the {@link LightweightContent} instance + */ + public void setContent(final LightweightContent content) { + if (content == null) { + System.err.println("JLightweightFrame.setContent: content may not be null!"); + return; + } + this.content = content; + this.component = content.getComponent(); + + Dimension d = this.component.getPreferredSize(); + content.preferredSizeChanged(d.width, d.height); + + d = this.component.getMaximumSize(); + content.maximumSizeChanged(d.width, d.height); + + d = this.component.getMinimumSize(); + content.minimumSizeChanged(d.width, d.height); + + initInterior(); + } + + @Override + public Graphics getGraphics() { + if (bbImage == null) return null; + + Graphics2D g = bbImage.createGraphics(); + g.setBackground(getBackground()); + g.setColor(getForeground()); + g.setFont(getFont()); + g.scale(scaleFactor, scaleFactor); + return g; + } + + /** + * {@inheritDoc} + * + * @see LightweightContent#focusGrabbed() + */ + @Override + public void grabFocus() { + if (content != null) content.focusGrabbed(); + } + + /** + * {@inheritDoc} + * + * @see LightweightContent#focusUngrabbed() + */ + @Override + public void ungrabFocus() { + if (content != null) content.focusUngrabbed(); + } + + @Override + public int getScaleFactor() { + return scaleFactor; + } + + @Override + public void notifyDisplayChanged(final int scaleFactor) { + if (scaleFactor != this.scaleFactor) { + if (!copyBufferEnabled) content.paintLock(); + try { + if (bbImage != null) { + resizeBuffer(getWidth(), getHeight(), scaleFactor); + } + } finally { + if (!copyBufferEnabled) content.paintUnlock(); + } + this.scaleFactor = scaleFactor; + } + if (getPeer() instanceof DisplayChangedListener) { + ((DisplayChangedListener)getPeer()).displayChanged(); + } + repaint(); + } + + @Override + public void addNotify() { + super.addNotify(); + if (getPeer() instanceof DisplayChangedListener) { + ((DisplayChangedListener)getPeer()).displayChanged(); + } + } + + private void syncCopyBuffer(boolean reset, int x, int y, int w, int h, int scale) { + content.paintLock(); + try { + int[] srcBuffer = ((DataBufferInt)bbImage.getRaster().getDataBuffer()).getData(); + if (reset) { + copyBuffer = new int[srcBuffer.length]; + } + int linestride = bbImage.getWidth(); + + x *= scale; + y *= scale; + w *= scale; + h *= scale; + + for (int i=0; i= newW) && (oldH >= newH)) { + createBB = false; + } else { + if (oldW >= newW) { + newW = oldW; + } else { + newW = Math.max((int)(oldW * 1.2), width); + } + if (oldH >= newH) { + newH = oldH; + } else { + newH = Math.max((int)(oldH * 1.2), height); + } + } + } + } + } + if (createBB) { + resizeBuffer(newW, newH, scaleFactor); + return; + } + content.imageReshaped(0, 0, width, height); + + } finally { + if (!copyBufferEnabled) { + content.paintUnlock(); + } + } + } + + private void resizeBuffer(int width, int height, int newScaleFactor) { + bbImage = new BufferedImage(width*newScaleFactor,height*newScaleFactor, + BufferedImage.TYPE_INT_ARGB_PRE); + int[] pixels= ((DataBufferInt)bbImage.getRaster().getDataBuffer()).getData(); + if (copyBufferEnabled) { + syncCopyBuffer(true, 0, 0, width, height, newScaleFactor); + pixels = copyBuffer; + } + content.imageBufferReset(pixels, 0, 0, width, height, + width * newScaleFactor, newScaleFactor); + } + + @Override + public JRootPane getRootPane() { + return rootPane; + } + + @Override + public void setContentPane(Container contentPane) { + getRootPane().setContentPane(contentPane); + } + + @Override + public Container getContentPane() { + return getRootPane().getContentPane(); + } + + @Override + public void setLayeredPane(JLayeredPane layeredPane) { + getRootPane().setLayeredPane(layeredPane); + } + + @Override + public JLayeredPane getLayeredPane() { + return getRootPane().getLayeredPane(); + } + + @Override + public void setGlassPane(Component glassPane) { + getRootPane().setGlassPane(glassPane); + } + + @Override + public Component getGlassPane() { + return getRootPane().getGlassPane(); + } + + + /* + * Notifies client toolkit that it should change a cursor. + * + * Called from the peer via SwingAccessor, because the + * Component.updateCursorImmediately method is final + * and could not be overridden. + */ + private void updateClientCursor() { + PointerInfo pointerInfo = MouseInfo.getPointerInfo(); + /* + * BUGFIX not yet applied upstream! + */ + if (pointerInfo == null) { + Logger log = Logger.getLogger(getClass().getName()); + log.warning("BUGFIX - NPE avoided"); + return; + } + Point p = pointerInfo.getLocation(); + SwingUtilities.convertPointFromScreen(p, this); + Component target = SwingUtilities.getDeepestComponentAt(this, p.x, p.y); + if (target != null) { + content.setCursor(target.getCursor()); + } + } + + public T createDragGestureRecognizer( + Class abstractRecognizerClass, + DragSource ds, Component c, int srcActions, + DragGestureListener dgl) + { + return content == null ? null : content.createDragGestureRecognizer( + abstractRecognizerClass, ds, c, srcActions, dgl); + } + + public DragSourceContextPeer createDragSourceContextPeer(DragGestureEvent dge) throws InvalidDnDOperationException { + return content == null ? null : content.createDragSourceContextPeer(dge); + } + + public void addDropTarget(DropTarget dt) { + if (content == null) return; + content.addDropTarget(dt); + } + + public void removeDropTarget(DropTarget dt) { + if (content == null) return; + content.removeDropTarget(dt); + } +} + diff --git a/tools/demobench/package/linux/CordaDemoBench.spec b/tools/demobench/package/linux/CordaDemoBench.spec index d000f25216..b3d6c3f0e3 100644 --- a/tools/demobench/package/linux/CordaDemoBench.spec +++ b/tools/demobench/package/linux/CordaDemoBench.spec @@ -20,6 +20,8 @@ Autoreq: 0 %define __jar_repack %{nil} %define _javaHome %{getenv:JAVA_HOME} +%define _bugfixdir %{_sourcedir}/CordaDemoBench/app/bugfixes +%define _rtJar %{_sourcedir}/CordaDemoBench/runtime/lib/rt.jar %description Corda DemoBench @@ -29,6 +31,12 @@ Corda DemoBench %build %install +# Apply bugfixes to installed rt.jar +if [ -f %{_bugfixdir}/apply.sh ]; then + chmod ugo+x %{_bugfixdir}/apply.sh + %{_bugfixdir}/apply.sh %{_rtJar} + rm -rf %{_bugfixdir} +fi rm -rf %{buildroot} mkdir -p %{buildroot}/opt cp -r %{_sourcedir}/CordaDemoBench %{buildroot}/opt diff --git a/tools/demobench/package/macosx/Corda DemoBench-post-image.sh b/tools/demobench/package/macosx/Corda DemoBench-post-image.sh index 2d1c42fd2f..43a94041f4 100644 --- a/tools/demobench/package/macosx/Corda DemoBench-post-image.sh +++ b/tools/demobench/package/macosx/Corda DemoBench-post-image.sh @@ -13,14 +13,27 @@ function signApplication() { echo "**** Failed to re-sign the embedded JVM" return 1 fi + + # Resign the application because we've deleted the bugfixes directory. + if ! (codesign --force --sign "$IDENTITY" --preserve-metadata=identifier,entitlements,requirements --verbose "$APPDIR"); then + echo "*** Failed to resign DemoBench application" + return 1 + fi } # Switch to folder containing application. cd ../images/image-*/Corda\ DemoBench.app -INSTALL_HOME=Contents/PlugIns/Java.runtime/Contents/Home/jre/bin -if (mkdir -p $INSTALL_HOME); then - cp $JAVA_HOME/bin/java $INSTALL_HOME +JRE_HOME=Contents/PlugIns/Java.runtime/Contents/Home/jre +if (mkdir -p $JRE_HOME/bin); then + cp $JAVA_HOME/bin/java $JRE_HOME/bin +fi + +BUGFIX_HOME=Contents/Java/bugfixes +if [ -f $BUGFIX_HOME/apply.sh ]; then + chmod ugo+x $BUGFIX_HOME/apply.sh + $BUGFIX_HOME/apply.sh $JRE_HOME/lib/rt.jar + rm -rf $BUGFIX_HOME fi # Switch to image directory in order to sign it. diff --git a/tools/demobench/package/windows/Corda DemoBench-post-image.wsf b/tools/demobench/package/windows/Corda DemoBench-post-image.wsf index 6d54bde85f..54302824e1 100644 --- a/tools/demobench/package/windows/Corda DemoBench-post-image.wsf +++ b/tools/demobench/package/windows/Corda DemoBench-post-image.wsf @@ -4,13 +4,15 @@ diff --git a/tools/demobench/src/main/kotlin/net/corda/demobench/explorer/Explorer.kt b/tools/demobench/src/main/kotlin/net/corda/demobench/explorer/Explorer.kt index 52f76905be..3bf445b827 100644 --- a/tools/demobench/src/main/kotlin/net/corda/demobench/explorer/Explorer.kt +++ b/tools/demobench/src/main/kotlin/net/corda/demobench/explorer/Explorer.kt @@ -1,8 +1,8 @@ package net.corda.demobench.explorer -import net.corda.core.createDirectories -import net.corda.core.div -import net.corda.core.list +import net.corda.core.internal.createDirectories +import net.corda.core.internal.div +import net.corda.core.internal.list import net.corda.core.utilities.loggerFor import net.corda.demobench.model.JVMConfig import net.corda.demobench.model.NodeConfig diff --git a/tools/demobench/src/main/kotlin/net/corda/demobench/model/NodeController.kt b/tools/demobench/src/main/kotlin/net/corda/demobench/model/NodeController.kt index b593bc32ce..0b484aacda 100644 --- a/tools/demobench/src/main/kotlin/net/corda/demobench/model/NodeController.kt +++ b/tools/demobench/src/main/kotlin/net/corda/demobench/model/NodeController.kt @@ -1,6 +1,6 @@ package net.corda.demobench.model -import net.corda.core.crypto.X509Utilities.getX509Name +import net.corda.core.crypto.getX509Name import net.corda.demobench.plugin.PluginController import net.corda.demobench.pty.R3Pty import tornadofx.* diff --git a/tools/demobench/src/main/kotlin/net/corda/demobench/views/NodeTabView.kt b/tools/demobench/src/main/kotlin/net/corda/demobench/views/NodeTabView.kt index 0049c9e194..cf9822f969 100644 --- a/tools/demobench/src/main/kotlin/net/corda/demobench/views/NodeTabView.kt +++ b/tools/demobench/src/main/kotlin/net/corda/demobench/views/NodeTabView.kt @@ -15,14 +15,14 @@ import javafx.scene.layout.Priority import javafx.stage.FileChooser import javafx.util.StringConverter import net.corda.core.crypto.commonName -import net.corda.core.div -import net.corda.core.exists +import net.corda.core.internal.div +import net.corda.core.internal.exists import net.corda.core.node.CityDatabase import net.corda.core.node.WorldMapLocation -import net.corda.core.readAllLines +import net.corda.core.internal.readAllLines import net.corda.core.utilities.normaliseLegalName import net.corda.core.utilities.validateLegalName -import net.corda.core.writeLines +import net.corda.core.internal.writeLines import net.corda.demobench.model.* import net.corda.demobench.ui.CloseableTab import org.controlsfx.control.CheckListView diff --git a/tools/demobench/src/main/kotlin/net/corda/demobench/views/NodeTerminalView.kt b/tools/demobench/src/main/kotlin/net/corda/demobench/views/NodeTerminalView.kt index d2d81031c6..174d464a78 100644 --- a/tools/demobench/src/main/kotlin/net/corda/demobench/views/NodeTerminalView.kt +++ b/tools/demobench/src/main/kotlin/net/corda/demobench/views/NodeTerminalView.kt @@ -3,24 +3,23 @@ package net.corda.demobench.views import com.jediterm.terminal.TerminalColor import com.jediterm.terminal.TextStyle import com.jediterm.terminal.ui.settings.DefaultSettingsProvider -import java.awt.Dimension -import java.net.URI -import java.util.logging.Level -import javax.swing.SwingUtilities import javafx.application.Platform import javafx.embed.swing.SwingNode import javafx.scene.control.Button import javafx.scene.control.Label import javafx.scene.control.ProgressIndicator import javafx.scene.image.ImageView -import javafx.scene.layout.StackPane import javafx.scene.layout.HBox +import javafx.scene.layout.StackPane import javafx.scene.layout.VBox import javafx.util.Duration +import net.corda.contracts.getCashBalances +import net.corda.core.concurrent.match +import net.corda.core.contracts.ContractState import net.corda.core.crypto.commonName -import net.corda.core.match -import net.corda.core.then import net.corda.core.messaging.CordaRPCOps +import net.corda.core.messaging.vaultTrackBy +import net.corda.core.node.services.vault.PageSpecification import net.corda.demobench.explorer.ExplorerController import net.corda.demobench.model.NodeConfig import net.corda.demobench.model.NodeController @@ -33,10 +32,18 @@ import net.corda.demobench.web.WebServerController import rx.Subscription import rx.schedulers.Schedulers import tornadofx.* +import java.awt.Dimension +import java.net.URI +import java.util.logging.Level +import javax.swing.SwingUtilities class NodeTerminalView : Fragment() { override val root by fxml() + private companion object { + val pageSpecification = PageSpecification(1, 1) + } + private val nodeController by inject() private val explorerController by inject() private val webServerController by inject() @@ -53,7 +60,7 @@ class NodeTerminalView : Fragment() { private val subscriptions: MutableList = mutableListOf() private var txCount: Int = 0 - private var stateCount: Int = 0 + private var stateCount: Long = 0 private var isDestroyed: Boolean = false private val explorer = explorerController.explorer() private val webServer = webServerController.webServer() @@ -159,17 +166,15 @@ class NodeTerminalView : Fragment() { webServer.open(config).then { Platform.runLater { launchWebButton.graphic = null - } - it.match({ - log.info("Web server for ${config.legalName} started on $it") - Platform.runLater { + it.match(success = { + log.info("Web server for ${config.legalName} started on $it") webURL = it launchWebButton.text = "Reopen\nweb site" app.hostServices.showDocument(it.toString()) - } - }, { - launchWebButton.text = oldLabel - }) + }, failure = { + launchWebButton.text = oldLabel + }) + } } } } @@ -182,11 +187,12 @@ class NodeTerminalView : Fragment() { private fun initialise(config: NodeConfig, ops: CordaRPCOps) { try { - val (txInit, txNext) = ops.verifiedTransactions() - val (stateInit, stateNext) = ops.vaultAndUpdates() + val (txInit, txNext) = ops.verifiedTransactionsFeed() + val (stateInit, stateNext) = ops.vaultTrackBy(paging = pageSpecification) txCount = txInit.size - stateCount = stateInit.size + // This is the total number of states in the vault, regardless of pagination. + stateCount = stateInit.totalStatesAvailable Platform.runLater { logo.opacityProperty().animate(1.0, Duration.seconds(2.5)) @@ -194,7 +200,7 @@ class NodeTerminalView : Fragment() { states.value = stateCount.toString() } - val fxScheduler = Schedulers.from({ Platform.runLater(it) }) + val fxScheduler = Schedulers.from(Platform::runLater) subscriptions.add(txNext.observeOn(fxScheduler).subscribe { transactions.value = (++txCount).toString() }) diff --git a/tools/demobench/src/main/kotlin/net/corda/demobench/web/WebServer.kt b/tools/demobench/src/main/kotlin/net/corda/demobench/web/WebServer.kt index aad88033bc..4ff4e23e01 100644 --- a/tools/demobench/src/main/kotlin/net/corda/demobench/web/WebServer.kt +++ b/tools/demobench/src/main/kotlin/net/corda/demobench/web/WebServer.kt @@ -1,11 +1,10 @@ package net.corda.demobench.web -import com.google.common.util.concurrent.ListenableFuture import com.google.common.util.concurrent.RateLimiter -import com.google.common.util.concurrent.SettableFuture -import net.corda.core.catch -import net.corda.core.minutes -import net.corda.core.until +import net.corda.core.concurrent.CordaFuture +import net.corda.core.utilities.minutes +import net.corda.core.internal.concurrent.openFuture +import net.corda.core.internal.until import net.corda.core.utilities.loggerFor import net.corda.demobench.model.NodeConfig import net.corda.demobench.readErrorLines @@ -26,12 +25,12 @@ class WebServer internal constructor(private val webServerController: WebServerC private var process: Process? = null @Throws(IOException::class) - fun open(config: NodeConfig): ListenableFuture { + fun open(config: NodeConfig): CordaFuture { val nodeDir = config.nodeDir.toFile() if (!nodeDir.isDirectory) { log.warn("Working directory '{}' does not exist.", nodeDir.absolutePath) - return SettableFuture.create() + return openFuture() } try { @@ -58,9 +57,9 @@ class WebServer internal constructor(private val webServerController: WebServerC } } - val future = SettableFuture.create() + val future = openFuture() thread { - future.catch { + future.capture { log.info("Waiting for web server for ${config.legalName} to start ...") waitForStart(config.webPort) } diff --git a/tools/demobench/src/test/kotlin/net/corda/demobench/model/NodeConfigTest.kt b/tools/demobench/src/test/kotlin/net/corda/demobench/model/NodeConfigTest.kt index 904a90f717..6fee64384a 100644 --- a/tools/demobench/src/test/kotlin/net/corda/demobench/model/NodeConfigTest.kt +++ b/tools/demobench/src/test/kotlin/net/corda/demobench/model/NodeConfigTest.kt @@ -5,7 +5,7 @@ import com.fasterxml.jackson.databind.ObjectMapper import com.fasterxml.jackson.databind.SerializationFeature import com.typesafe.config.ConfigFactory import com.typesafe.config.ConfigValueFactory -import net.corda.core.div +import net.corda.core.internal.div import net.corda.core.utilities.NetworkHostAndPort import net.corda.testing.DUMMY_NOTARY import net.corda.node.internal.NetworkMapInfo diff --git a/tools/demobench/src/test/kotlin/net/corda/demobench/model/NodeControllerTest.kt b/tools/demobench/src/test/kotlin/net/corda/demobench/model/NodeControllerTest.kt index 6c77fa3b05..d476ccc93a 100644 --- a/tools/demobench/src/test/kotlin/net/corda/demobench/model/NodeControllerTest.kt +++ b/tools/demobench/src/test/kotlin/net/corda/demobench/model/NodeControllerTest.kt @@ -1,8 +1,8 @@ package net.corda.demobench.model -import net.corda.core.crypto.X509Utilities.getX509Name -import net.corda.testing.DUMMY_NOTARY +import net.corda.core.crypto.getX509Name import net.corda.nodeapi.User +import net.corda.testing.DUMMY_NOTARY import org.junit.Test import java.nio.file.Path import java.nio.file.Paths diff --git a/tools/explorer/src/main/kotlin/net/corda/explorer/ExplorerSimulation.kt b/tools/explorer/src/main/kotlin/net/corda/explorer/ExplorerSimulation.kt index 1ea92e3f48..4b17207b51 100644 --- a/tools/explorer/src/main/kotlin/net/corda/explorer/ExplorerSimulation.kt +++ b/tools/explorer/src/main/kotlin/net/corda/explorer/ExplorerSimulation.kt @@ -11,11 +11,11 @@ import net.corda.core.contracts.Amount import net.corda.core.contracts.GBP import net.corda.core.contracts.USD import net.corda.core.identity.Party +import net.corda.core.internal.concurrent.thenMatch import net.corda.core.messaging.CordaRPCOps import net.corda.core.messaging.FlowHandle import net.corda.core.node.services.ServiceInfo import net.corda.core.node.services.ServiceType -import net.corda.core.thenMatch import net.corda.core.utilities.OpaqueBytes import net.corda.flows.* import net.corda.node.services.startFlowPermission diff --git a/tools/explorer/src/main/kotlin/net/corda/explorer/model/SettingsModel.kt b/tools/explorer/src/main/kotlin/net/corda/explorer/model/SettingsModel.kt index 2ee2e62814..2e85b434d8 100644 --- a/tools/explorer/src/main/kotlin/net/corda/explorer/model/SettingsModel.kt +++ b/tools/explorer/src/main/kotlin/net/corda/explorer/model/SettingsModel.kt @@ -5,9 +5,9 @@ import javafx.beans.Observable import javafx.beans.property.ObjectProperty import javafx.beans.property.SimpleObjectProperty import net.corda.core.contracts.currency -import net.corda.core.createDirectories -import net.corda.core.div -import net.corda.core.exists +import net.corda.core.internal.createDirectories +import net.corda.core.internal.div +import net.corda.core.internal.exists import tornadofx.* import java.nio.file.Files import java.nio.file.Path diff --git a/tools/explorer/src/main/kotlin/net/corda/explorer/views/Network.kt b/tools/explorer/src/main/kotlin/net/corda/explorer/views/Network.kt index 44e08e45b2..9df6973b25 100644 --- a/tools/explorer/src/main/kotlin/net/corda/explorer/views/Network.kt +++ b/tools/explorer/src/main/kotlin/net/corda/explorer/views/Network.kt @@ -27,9 +27,10 @@ import javafx.util.Duration import net.corda.client.jfx.model.* import net.corda.client.jfx.utils.* import net.corda.core.contracts.ContractState -import net.corda.core.identity.Party import net.corda.core.crypto.toBase58String +import net.corda.core.identity.Party import net.corda.core.node.NodeInfo +import net.corda.core.node.ScreenCoordinate import net.corda.explorer.formatters.PartyNameFormatter import net.corda.explorer.model.CordaView import tornadofx.* @@ -77,7 +78,7 @@ class Network : CordaView() { .map { it as? PartiallyResolvedTransaction.InputResolution.Resolved } .filterNotNull() .map { it.stateAndRef.state.data }.getParties() - val outputParties = it.transaction.tx.outputs.map { it.data }.observable().getParties() + val outputParties = it.transaction.tx.outputStates.observable().getParties() val signingParties = it.transaction.sigs.map { getModel().lookup(it.by) } // Input parties fire a bullets to all output parties, and to the signing parties. !! This is a rough guess of how the message moves in the network. // TODO : Expose artemis queue to get real message information. @@ -122,11 +123,11 @@ class Network : CordaView() { contentDisplay = ContentDisplay.TOP val coordinate = Bindings.createObjectBinding({ // These coordinates are obtained when we generate the map using TileMill. - node.worldMapLocation?.coordinate?.project(mapPane.width, mapPane.height, 85.0511, -85.0511, -180.0, 180.0) ?: Pair(0.0, 0.0) + node.worldMapLocation?.coordinate?.project(mapPane.width, mapPane.height, 85.0511, -85.0511, -180.0, 180.0) ?: ScreenCoordinate(0.0, 0.0) }, arrayOf(mapPane.widthProperty(), mapPane.heightProperty())) // Center point of the label. - layoutXProperty().bind(coordinate.map { it.first - width / 2 }) - layoutYProperty().bind(coordinate.map { it.second - height / 4 }) + layoutXProperty().bind(coordinate.map { it.screenX - width / 2 }) + layoutYProperty().bind(coordinate.map { it.screenY - height / 4 }) } val button = node.renderButton(mapLabel) @@ -211,45 +212,43 @@ class Network : CordaView() { private fun List.getParties() = map { it.participants.map { getModel().lookup(it.owningKey) } }.flatten() private fun fireBulletBetweenNodes(senderNode: Party, destNode: Party, startType: String, endType: String) { - allComponentMap[senderNode]?.let { senderNode -> - allComponentMap[destNode]?.let { destNode -> - val sender = senderNode.label.boundsInParentProperty().map { Point2D(it.width / 2 + it.minX, it.height / 4 - 2.5 + it.minY) } - val receiver = destNode.label.boundsInParentProperty().map { Point2D(it.width / 2 + it.minX, it.height / 4 - 2.5 + it.minY) } - val bullet = Circle(3.0) - bullet.styleClass += "bullet" - bullet.styleClass += "connection-$startType-to-$endType" - with(TranslateTransition(stepDuration, bullet)) { - fromXProperty().bind(sender.map { it.x }) - fromYProperty().bind(sender.map { it.y }) - toXProperty().bind(receiver.map { it.x }) - toYProperty().bind(receiver.map { it.y }) - setOnFinished { mapPane.children.remove(bullet) } + val senderNodeComp = allComponentMap[senderNode] ?: return + val destNodeComp = allComponentMap[destNode] ?: return + val sender = senderNodeComp.label.boundsInParentProperty().map { Point2D(it.width / 2 + it.minX, it.height / 4 - 2.5 + it.minY) } + val receiver = destNodeComp.label.boundsInParentProperty().map { Point2D(it.width / 2 + it.minX, it.height / 4 - 2.5 + it.minY) } + val bullet = Circle(3.0) + bullet.styleClass += "bullet" + bullet.styleClass += "connection-$startType-to-$endType" + with(TranslateTransition(stepDuration, bullet)) { + fromXProperty().bind(sender.map { it.x }) + fromYProperty().bind(sender.map { it.y }) + toXProperty().bind(receiver.map { it.x }) + toYProperty().bind(receiver.map { it.y }) + setOnFinished { mapPane.children.remove(bullet) } + play() + } + val line = Line().apply { + styleClass += "message-line" + startXProperty().bind(sender.map { it.x }) + startYProperty().bind(sender.map { it.y }) + endXProperty().bind(receiver.map { it.x }) + endYProperty().bind(receiver.map { it.y }) + } + // Fade in quick, then fade out slow. + with(FadeTransition(stepDuration.divide(5.0), line)) { + fromValue = 0.0 + toValue = 1.0 + play() + setOnFinished { + with(FadeTransition(stepDuration.multiply(6.0), line)) { + fromValue = 1.0 + toValue = 0.0 play() + setOnFinished { mapPane.children.remove(line) } } - val line = Line().apply { - styleClass += "message-line" - startXProperty().bind(sender.map { it.x }) - startYProperty().bind(sender.map { it.y }) - endXProperty().bind(receiver.map { it.x }) - endYProperty().bind(receiver.map { it.y }) - } - // Fade in quick, then fade out slow. - with(FadeTransition(stepDuration.divide(5.0), line)) { - fromValue = 0.0 - toValue = 1.0 - play() - setOnFinished { - with(FadeTransition(stepDuration.multiply(6.0), line)) { - fromValue = 1.0 - toValue = 0.0 - play() - setOnFinished { mapPane.children.remove(line) } - } - } - } - mapPane.children.add(1, line) - mapPane.children.add(bullet) } } + mapPane.children.add(1, line) + mapPane.children.add(bullet) } } diff --git a/tools/explorer/src/main/kotlin/net/corda/explorer/views/TransactionViewer.kt b/tools/explorer/src/main/kotlin/net/corda/explorer/views/TransactionViewer.kt index 858117de3f..ae8a6252e7 100644 --- a/tools/explorer/src/main/kotlin/net/corda/explorer/views/TransactionViewer.kt +++ b/tools/explorer/src/main/kotlin/net/corda/explorer/views/TransactionViewer.kt @@ -22,7 +22,10 @@ import net.corda.client.jfx.utils.map import net.corda.client.jfx.utils.sequence import net.corda.contracts.asset.Cash import net.corda.core.contracts.* -import net.corda.core.crypto.* +import net.corda.core.crypto.SecureHash +import net.corda.core.crypto.commonName +import net.corda.core.crypto.toBase58String +import net.corda.core.crypto.toStringShort import net.corda.core.identity.AbstractParty import net.corda.core.node.NodeInfo import net.corda.explorer.AmountDiff @@ -124,7 +127,7 @@ class TransactionViewer : CordaView("Transactions") { totalValueEquiv = ::calculateTotalEquiv.lift(myIdentity, reportingExchange, resolved.map { it.state.data }.lift(), - it.transaction.tx.outputs.map { it.data }.lift()) + it.transaction.tx.outputStates.lift()) ) } diff --git a/tools/explorer/src/main/kotlin/net/corda/explorer/views/cordapps/cash/NewTransaction.kt b/tools/explorer/src/main/kotlin/net/corda/explorer/views/cordapps/cash/NewTransaction.kt index d2a59c1cea..0057bbe945 100644 --- a/tools/explorer/src/main/kotlin/net/corda/explorer/views/cordapps/cash/NewTransaction.kt +++ b/tools/explorer/src/main/kotlin/net/corda/explorer/views/cordapps/cash/NewTransaction.kt @@ -21,14 +21,13 @@ import net.corda.core.contracts.Amount import net.corda.core.contracts.sumOrNull import net.corda.core.contracts.withoutIssuer import net.corda.core.flows.FlowException -import net.corda.core.getOrThrow -import net.corda.core.identity.AbstractParty import net.corda.core.identity.Party import net.corda.core.messaging.FlowHandle import net.corda.core.messaging.startFlow import net.corda.core.node.NodeInfo -import net.corda.core.utilities.OpaqueBytes import net.corda.core.transactions.SignedTransaction +import net.corda.core.utilities.getOrThrow +import net.corda.core.utilities.OpaqueBytes import net.corda.explorer.formatters.PartyNameFormatter import net.corda.explorer.model.CashTransaction import net.corda.explorer.model.IssuerModel @@ -111,7 +110,7 @@ class NewTransaction : Fragment() { } finally { dialog.dialogPane.isDisable = false } - }.ui { it -> + }.ui { val stx: SignedTransaction = it.stx val type = when (command) { is CashFlowCommand.IssueCash -> "Cash Issued" @@ -193,7 +192,7 @@ class NewTransaction : Fragment() { // Issuer issuerLabel.visibleProperty().bind(transactionTypeCB.valueProperty().isNotNull) issuerChoiceBox.apply { - items = issuers.map { it.legalIdentity as Party }.unique().sorted() + items = issuers.map { it.legalIdentity }.unique().sorted() converter = stringConverter { PartyNameFormatter.short.format(it.name) } visibleProperty().bind(transactionTypeCB.valueProperty().map { it == CashTransaction.Pay }) } @@ -220,7 +219,7 @@ class NewTransaction : Fragment() { ) availableAmount.textProperty() .bind(Bindings.createStringBinding({ - val filteredCash = cash.filtered { it.token.issuer.party as AbstractParty == issuer.value && it.token.product == currencyChoiceBox.value } + val filteredCash = cash.filtered { it.token.issuer.party == issuer.value && it.token.product == currencyChoiceBox.value } .map { it.withoutIssuer() }.sumOrNull() "${filteredCash ?: "None"} Available" }, arrayOf(currencyChoiceBox.valueProperty(), issuerChoiceBox.valueProperty()))) diff --git a/tools/explorer/src/test/kotlin/net/corda/explorer/model/SettingsModelTest.kt b/tools/explorer/src/test/kotlin/net/corda/explorer/model/SettingsModelTest.kt index 5d78027bf0..d09cde66cb 100644 --- a/tools/explorer/src/test/kotlin/net/corda/explorer/model/SettingsModelTest.kt +++ b/tools/explorer/src/test/kotlin/net/corda/explorer/model/SettingsModelTest.kt @@ -1,6 +1,6 @@ package net.corda.explorer.model -import net.corda.core.div +import net.corda.core.internal.div import org.junit.Rule import org.junit.Test import org.junit.rules.TemporaryFolder diff --git a/tools/graphs/build.gradle b/tools/graphs/build.gradle new file mode 100644 index 0000000000..cd947894b4 --- /dev/null +++ b/tools/graphs/build.gradle @@ -0,0 +1,75 @@ +class GraphProject { + def projects, project, nodeName, safeName + GraphProject(projects, project) { + def path = project.path.split(':').findAll { it } + if (!path) path.add project.rootProject.name + path = path.collect { it.split('[/\\\\]')[-1] } + nodeName = path.join(':') + safeName = path.join('_') + this.projects = projects + this.project = project + } + def getCompileDeps() { + project.configurations.compile.dependencies.matching { it in ProjectDependency }.collect { projects[it.dependencyProject] } + } +} + +class Graph { + def arcs = new LinkedHashSet() + def dotFile, imgFile, project + Graph(graphsDir, project) { + initArcs(project) + dotFile = new File(graphsDir, "${project.safeName}.dot") + imgFile = new File(graphsDir, "${project.safeName}.png") + this.project = project + } + def initArcs(project) { + project.compileDeps.each { + arcs.add([project, it]) + initArcs(it) + } + } + def output() { + dotFile.text = '' + dotFile << "digraph \"$project.nodeName\" {\n" + dotFile << ' rankdir=LR;\n' + arcs.collect { it.collect { it.nodeName } }.each { + dotFile << " \"${it[0]}\" -> \"${it[1]}\";\n" + } + dotFile << '}\n' + project.project.exec { + commandLine 'dot', '-Tpng', '-o', imgFile, dotFile + } + } +} + +def walkProjects(project, block) { + block(project) + project.childProjects.each { walkProjects(it.value, block) } +} + +task graphs { + doLast { + def projects = new LinkedHashMap() + walkProjects(rootProject) { projects[it] = new GraphProject(projects, it) } + def graphsDir = reporting.baseDir + graphsDir.mkdirs() + def graphs = projects.collect { new Graph(graphsDir, it.value) } + graphs.each { graph -> + if (!graph.arcs) { + logger.info "$graph.project.nodeName is a leaf." + return + } + for (def that : graphs) { + if (that != graph && that.arcs + graph.arcs == that.arcs) { + logger.info "$graph.project.nodeName is included in: $that.imgFile" + return + } + } + graph.output() + } + exec { + commandLine 'eog', graphsDir + } + } +} diff --git a/tools/loadtest/src/main/kotlin/net/corda/loadtest/ConnectionManager.kt b/tools/loadtest/src/main/kotlin/net/corda/loadtest/ConnectionManager.kt index 80167f0416..5b4f0f57ba 100644 --- a/tools/loadtest/src/main/kotlin/net/corda/loadtest/ConnectionManager.kt +++ b/tools/loadtest/src/main/kotlin/net/corda/loadtest/ConnectionManager.kt @@ -40,13 +40,12 @@ fun setupJSchWithSshAgent(): JSch { override fun getName() = connector.name override fun getIdentities(): Vector = Vector(listOf( object : Identity { - override fun clear() { - } - + override fun clear() {} override fun getAlgName() = String(Buffer(identity.blob).string) override fun getName() = String(identity.comment) override fun isEncrypted() = false override fun getSignature(data: ByteArray?) = agentProxy.sign(identity.blob, data) + @Suppress("OverridingDeprecatedMember") override fun decrypt() = true override fun getPublicKeyBlob() = identity.blob override fun setPassphrase(passphrase: ByteArray?) = true @@ -55,7 +54,7 @@ fun setupJSchWithSshAgent(): JSch { override fun remove(blob: ByteArray?) = throw UnsupportedOperationException() override fun removeAll() = throw UnsupportedOperationException() - override fun add(identity: ByteArray?) = throw UnsupportedOperationException() + override fun add(bytes: ByteArray?) = throw UnsupportedOperationException() } } } @@ -85,9 +84,6 @@ class ConnectionManager(private val jSch: JSch) { * Connects to a list of nodes and executes the passed in action with the connections as parameter. The connections are * safely cleaned up if an exception is thrown. * - * @param username The UNIX username to use for SSH authentication. - * @param nodeHosts The list of hosts. - * @param remoteMessagingPort The Artemis messaging port nodes are listening on. * @param tunnelPortAllocation A local port allocation strategy for creating SSH tunnels. * @param withConnections An action to run once we're connected to the nodes. * @return The return value of [withConnections] diff --git a/tools/loadtest/src/main/kotlin/net/corda/loadtest/LoadTest.kt b/tools/loadtest/src/main/kotlin/net/corda/loadtest/LoadTest.kt index 34c9972798..1a31342c1e 100644 --- a/tools/loadtest/src/main/kotlin/net/corda/loadtest/LoadTest.kt +++ b/tools/loadtest/src/main/kotlin/net/corda/loadtest/LoadTest.kt @@ -181,7 +181,7 @@ fun runLoadTests(configuration: LoadTestConfiguration, tests: List { + private fun runShellCommand(command: String, stdout: OutputStream, stderr: OutputStream): CordaFuture { log.info("Running '$command' on ${remoteNode.hostname}") - return future { + return ForkJoinPool.commonPool().fork { val (exitCode, _) = withChannelExec(command) { channel -> channel.outputStream = stdout channel.setErrStream(stderr) diff --git a/tools/loadtest/src/main/kotlin/net/corda/loadtest/tests/CrossCashTest.kt b/tools/loadtest/src/main/kotlin/net/corda/loadtest/tests/CrossCashTest.kt index f2d90594ab..f8b8ae4a61 100644 --- a/tools/loadtest/src/main/kotlin/net/corda/loadtest/tests/CrossCashTest.kt +++ b/tools/loadtest/src/main/kotlin/net/corda/loadtest/tests/CrossCashTest.kt @@ -2,13 +2,13 @@ package net.corda.loadtest.tests import net.corda.client.mock.Generator import net.corda.client.mock.pickN -import net.corda.client.rpc.notUsed import net.corda.contracts.asset.Cash import net.corda.core.contracts.Issued import net.corda.core.contracts.PartyAndReference import net.corda.core.contracts.USD import net.corda.core.identity.AbstractParty -import net.corda.core.thenMatch +import net.corda.core.internal.concurrent.thenMatch +import net.corda.core.messaging.vaultQueryBy import net.corda.core.utilities.OpaqueBytes import net.corda.flows.CashFlowCommand import net.corda.loadtest.LoadTest @@ -218,14 +218,11 @@ val crossCashTest = LoadTest( val currentNodeVaults = HashMap>() simpleNodes.forEach { val quantities = HashMap() - val (vault, vaultUpdates) = it.proxy.vaultAndUpdates() - vaultUpdates.notUsed() + val vault = it.proxy.vaultQueryBy().states vault.forEach { val state = it.state.data - if (state is Cash.State) { - val issuer = state.amount.token.issuer.party - quantities.put(issuer, (quantities[issuer] ?: 0L) + state.amount.quantity) - } + val issuer = state.amount.token.issuer.party + quantities.put(issuer, (quantities[issuer] ?: 0L) + state.amount.quantity) } currentNodeVaults.put(it.info.legalIdentity, quantities) } @@ -257,9 +254,7 @@ val crossCashTest = LoadTest( if (minimum == null) { HashMap(next) } else { - next.forEach { entry -> - minimum.merge(entry.key, entry.value, Math::min) - } + next.forEach { minimum.merge(it.key, it.value, Math::min) } minimum } }!! diff --git a/tools/loadtest/src/main/kotlin/net/corda/loadtest/tests/NotaryTest.kt b/tools/loadtest/src/main/kotlin/net/corda/loadtest/tests/NotaryTest.kt index f684b08bb4..c8b58371e8 100644 --- a/tools/loadtest/src/main/kotlin/net/corda/loadtest/tests/NotaryTest.kt +++ b/tools/loadtest/src/main/kotlin/net/corda/loadtest/tests/NotaryTest.kt @@ -6,11 +6,11 @@ import net.corda.client.mock.pickOne import net.corda.client.mock.replicate import net.corda.contracts.asset.DUMMY_CASH_ISSUER import net.corda.contracts.asset.DUMMY_CASH_ISSUER_KEY +import net.corda.core.flows.FinalityFlow import net.corda.core.flows.FlowException +import net.corda.core.internal.concurrent.thenMatch import net.corda.core.messaging.startFlow -import net.corda.core.thenMatch import net.corda.core.transactions.SignedTransaction -import net.corda.flows.FinalityFlow import net.corda.loadtest.LoadTest import net.corda.loadtest.NodeConnection import net.corda.testing.contracts.DummyContract @@ -43,7 +43,7 @@ val dummyNotarisationTest = LoadTest( val proxy = node.proxy val issueFlow = proxy.startFlow(::FinalityFlow, issueTx) issueFlow.returnValue.thenMatch({ - val moveFlow = proxy.startFlow(::FinalityFlow, moveTx) + proxy.startFlow(::FinalityFlow, moveTx) }, {}) } catch (e: FlowException) { log.error("Failure", e) diff --git a/tools/loadtest/src/main/kotlin/net/corda/loadtest/tests/SelfIssueTest.kt b/tools/loadtest/src/main/kotlin/net/corda/loadtest/tests/SelfIssueTest.kt index 61194d7ea8..2402179a7b 100644 --- a/tools/loadtest/src/main/kotlin/net/corda/loadtest/tests/SelfIssueTest.kt +++ b/tools/loadtest/src/main/kotlin/net/corda/loadtest/tests/SelfIssueTest.kt @@ -4,12 +4,12 @@ import de.danielbechler.diff.ObjectDifferFactory import net.corda.client.mock.Generator import net.corda.client.mock.pickOne import net.corda.client.mock.replicatePoisson -import net.corda.client.rpc.notUsed import net.corda.contracts.asset.Cash import net.corda.core.contracts.USD import net.corda.core.flows.FlowException -import net.corda.core.getOrThrow import net.corda.core.identity.AbstractParty +import net.corda.core.utilities.getOrThrow +import net.corda.core.messaging.vaultQueryBy import net.corda.flows.CashFlowCommand import net.corda.loadtest.LoadTest import net.corda.loadtest.NodeConnection @@ -71,15 +71,12 @@ val selfIssueTest = LoadTest( gatherRemoteState = { previousState -> val selfIssueVaults = HashMap() simpleNodes.forEach { connection -> - val (vault, vaultUpdates) = connection.proxy.vaultAndUpdates() - vaultUpdates.notUsed() + val vault = connection.proxy.vaultQueryBy().states vault.forEach { val state = it.state.data - if (state is Cash.State) { - val issuer = state.amount.token.issuer.party - if (issuer == connection.info.legalIdentity as AbstractParty) { - selfIssueVaults.put(issuer, (selfIssueVaults[issuer] ?: 0L) + state.amount.quantity) - } + val issuer = state.amount.token.issuer.party + if (issuer == connection.info.legalIdentity as AbstractParty) { + selfIssueVaults.put(issuer, (selfIssueVaults[issuer] ?: 0L) + state.amount.quantity) } } } diff --git a/tools/loadtest/src/main/kotlin/net/corda/loadtest/tests/StabilityTest.kt b/tools/loadtest/src/main/kotlin/net/corda/loadtest/tests/StabilityTest.kt index b24f0216b4..26d8fe66dd 100644 --- a/tools/loadtest/src/main/kotlin/net/corda/loadtest/tests/StabilityTest.kt +++ b/tools/loadtest/src/main/kotlin/net/corda/loadtest/tests/StabilityTest.kt @@ -4,9 +4,9 @@ import net.corda.client.mock.Generator import net.corda.core.contracts.Amount import net.corda.core.contracts.USD import net.corda.core.flows.FlowException -import net.corda.core.getOrThrow -import net.corda.core.thenMatch +import net.corda.core.internal.concurrent.thenMatch import net.corda.core.utilities.OpaqueBytes +import net.corda.core.utilities.getOrThrow import net.corda.core.utilities.loggerFor import net.corda.flows.CashFlowCommand import net.corda.loadtest.LoadTest diff --git a/verifier/build.gradle b/verifier/build.gradle index e0035ffa33..41ba1988a5 100644 --- a/verifier/build.gradle +++ b/verifier/build.gradle @@ -75,6 +75,6 @@ artifacts { } publish { - name = 'corda-verifier' disableDefaultJar = true + name 'corda-verifier' } diff --git a/verifier/src/integration-test/kotlin/net/corda/verifier/GeneratedLedger.kt b/verifier/src/integration-test/kotlin/net/corda/verifier/GeneratedLedger.kt index 3c0ccd7460..a5890f8a9b 100644 --- a/verifier/src/integration-test/kotlin/net/corda/verifier/GeneratedLedger.kt +++ b/verifier/src/integration-test/kotlin/net/corda/verifier/GeneratedLedger.kt @@ -2,9 +2,7 @@ package net.corda.verifier import net.corda.client.mock.* import net.corda.core.contracts.* -import net.corda.testing.contracts.DummyContract import net.corda.core.crypto.SecureHash -import net.corda.core.crypto.X509Utilities import net.corda.core.crypto.entropyToKeyPair import net.corda.core.crypto.sha256 import net.corda.core.identity.AbstractParty @@ -12,6 +10,8 @@ import net.corda.core.identity.AnonymousParty import net.corda.core.identity.Party import net.corda.core.transactions.LedgerTransaction import net.corda.core.transactions.WireTransaction +import net.corda.testing.contracts.DummyContract +import net.corda.testing.getTestX509Name import java.io.ByteArrayInputStream import java.math.BigInteger import java.security.PublicKey @@ -49,8 +49,8 @@ data class GeneratedLedger( Generator.replicatePoisson(1.0, pickOneOrMaybeNew(attachments, attachmentGenerator)) } - val commandsGenerator: Generator>> by lazy { - Generator.replicatePoisson(4.0, commandGenerator(identities)) + val commandsGenerator: Generator, Party>>> by lazy { + Generator.replicatePoisson(4.0, commandGenerator(identities), atLeastOne = true) } /** @@ -68,15 +68,12 @@ data class GeneratedLedger( ) } attachmentsGenerator.combine(outputsGen, commandsGenerator) { txAttachments, outputs, commands -> - val signers = commands.flatMap { it.first.signers } val newTransaction = WireTransaction( emptyList(), txAttachments.map { it.id }, outputs, commands.map { it.first }, null, - signers, - TransactionType.General, null ) val newOutputStateAndRefs = outputs.mapIndexed { i, state -> @@ -90,10 +87,42 @@ data class GeneratedLedger( } } + /** + * Generates an exit transaction. + * Invariants: + * * The output list must be empty + */ + fun exitTransactionGenerator(inputNotary: Party, inputsToChooseFrom: List>): Generator> { + val inputsGen = Generator.sampleBernoulli(inputsToChooseFrom) + return inputsGen.combine(attachmentsGenerator, commandsGenerator) { inputs, txAttachments, commands -> + val newTransaction = WireTransaction( + inputs.map { it.ref }, + txAttachments.map { it.id }, + emptyList(), + commands.map { it.first }, + inputNotary, + null + ) + + val availableOutputsMinusConsumed = HashMap(availableOutputs) + if (inputs.size == inputsToChooseFrom.size) { + availableOutputsMinusConsumed.remove(inputNotary) + } else { + availableOutputsMinusConsumed[inputNotary] = inputsToChooseFrom - inputs + } + val newAvailableOutputs = availableOutputsMinusConsumed + val newAttachments = attachments + txAttachments + val newIdentities = identities + commands.map { it.second } + val newLedger = GeneratedLedger(transactions + newTransaction, newAvailableOutputs, newAttachments, newIdentities) + Pair(newTransaction, newLedger) + } + } + /** * Generates a regular non-issue transaction. * Invariants: * * Input and output notaries must be one and the same. + * * There must be at least one input and output state. */ fun regularTransactionGenerator(inputNotary: Party, inputsToChooseFrom: List>): Generator> { val outputsGen = outputsGenerator.map { outputs -> @@ -103,15 +132,12 @@ data class GeneratedLedger( } val inputsGen = Generator.sampleBernoulli(inputsToChooseFrom) return inputsGen.combine(attachmentsGenerator, outputsGen, commandsGenerator) { inputs, txAttachments, outputs, commands -> - val signers = commands.flatMap { it.first.signers } + inputNotary.owningKey val newTransaction = WireTransaction( inputs.map { it.ref }, txAttachments.map { it.id }, outputs, commands.map { it.first }, inputNotary, - signers, - TransactionType.General, null ) val newOutputStateAndRefs = outputs.mapIndexed { i, state -> @@ -132,45 +158,7 @@ data class GeneratedLedger( } /** - * Generates a notary change transaction. - * Invariants: - * * Input notary must be different from the output ones. - * * All other data must stay the same. - */ - fun notaryChangeTransactionGenerator(inputNotary: Party, inputsToChooseFrom: List>): Generator> { - val newNotaryGen = pickOneOrMaybeNew(identities - inputNotary, partyGenerator) - val inputsGen = Generator.sampleBernoulli(inputsToChooseFrom) - return inputsGen.flatMap { inputs -> - val signers: List = (inputs.flatMap { it.state.data.participants } + inputNotary).map { it.owningKey } - val outputsGen = Generator.sequence(inputs.map { input -> newNotaryGen.map { TransactionState(input.state.data, it, null) } }) - outputsGen.combine(attachmentsGenerator) { outputs, txAttachments -> - val newNotaries = outputs.map { it.notary } - val newTransaction = WireTransaction( - inputs.map { it.ref }, - txAttachments.map { it.id }, - outputs, - emptyList(), - inputNotary, - signers, - TransactionType.NotaryChange, - null - ) - val newOutputStateAndRefs = outputs.mapIndexed { i, state -> - StateAndRef(state, StateRef(newTransaction.id, i)) - } - val availableOutputsMinusConsumed = HashMap(availableOutputs) - availableOutputsMinusConsumed[inputNotary] = inputsToChooseFrom - inputs - val newAvailableOutputs = availableOutputsMinusConsumed + newOutputStateAndRefs.groupBy { it.state.notary } - val newAttachments = attachments + txAttachments - val newIdentities = identities + newNotaries - val newLedger = GeneratedLedger(transactions + newTransaction, newAvailableOutputs, newAttachments, newIdentities) - Pair(newTransaction, newLedger) - } - } - } - - /** - * Generates a valid transaction. It may be one of three types of issuance, regular and notary change. These have + * Generates a valid transaction. It may be either an issuance or a regular spend transaction. These have * different invariants on notary fields. */ val transactionGenerator: Generator> by lazy { @@ -181,8 +169,8 @@ data class GeneratedLedger( val inputsToChooseFrom = availableOutputs[inputNotary]!! Generator.frequency( 0.3 to issuanceGenerator, - 0.4 to regularTransactionGenerator(inputNotary, inputsToChooseFrom), - 0.3 to notaryChangeTransactionGenerator(inputNotary, inputsToChooseFrom) + 0.3 to exitTransactionGenerator(inputNotary, inputsToChooseFrom), + 0.4 to regularTransactionGenerator(inputNotary, inputsToChooseFrom) ) } } @@ -214,7 +202,7 @@ val stateGenerator: Generator = GeneratedState(nonce, participants.map { AnonymousParty(it) }) } -fun commandGenerator(partiesToPickFrom: Collection): Generator> { +fun commandGenerator(partiesToPickFrom: Collection): Generator, Party>> { return pickOneOrMaybeNew(partiesToPickFrom, partyGenerator).combine(Generator.long()) { signer, nonce -> Pair( Command(GeneratedCommandData(nonce), signer.owningKey), @@ -224,7 +212,7 @@ fun commandGenerator(partiesToPickFrom: Collection): Generator = Generator.int().combine(publicKeyGenerator) { n, key -> - Party(X509Utilities.getDevX509Name("Party$n"), key) + Party(getTestX509Name("Party$n"), key) } fun pickOneOrMaybeNew(from: Collection, generator: Generator): Generator { @@ -238,4 +226,4 @@ fun pickOneOrMaybeNew(from: Collection, generator: Generator): Generat } val attachmentGenerator: Generator = Generator.bytes(16).map(::GeneratedAttachment) -val outputsGenerator = Generator.replicatePoisson(3.0, stateGenerator) +val outputsGenerator = Generator.replicatePoisson(3.0, stateGenerator, atLeastOne = true) diff --git a/verifier/src/integration-test/kotlin/net/corda/verifier/VerifierDriver.kt b/verifier/src/integration-test/kotlin/net/corda/verifier/VerifierDriver.kt index 66315907aa..b97f5845dc 100644 --- a/verifier/src/integration-test/kotlin/net/corda/verifier/VerifierDriver.kt +++ b/verifier/src/integration-test/kotlin/net/corda/verifier/VerifierDriver.kt @@ -1,19 +1,14 @@ package net.corda.verifier -import com.google.common.util.concurrent.Futures -import com.google.common.util.concurrent.ListenableFuture -import com.google.common.util.concurrent.ListeningScheduledExecutorService -import com.google.common.util.concurrent.SettableFuture import com.typesafe.config.Config import com.typesafe.config.ConfigFactory -import net.corda.core.crypto.X509Utilities +import net.corda.core.concurrent.CordaFuture import net.corda.core.crypto.commonName -import net.corda.core.div -import net.corda.core.map import net.corda.core.crypto.random63BitValue +import net.corda.core.internal.concurrent.* +import net.corda.core.internal.div import net.corda.core.transactions.LedgerTransaction import net.corda.core.utilities.NetworkHostAndPort -import net.corda.testing.driver.ProcessUtilities import net.corda.core.utilities.loggerFor import net.corda.node.services.config.configureDevKeyAndTrustStores import net.corda.nodeapi.ArtemisMessagingComponent.Companion.NODE_USER @@ -23,6 +18,7 @@ import net.corda.nodeapi.VerifierApi import net.corda.nodeapi.config.NodeSSLConfiguration import net.corda.nodeapi.config.SSLConfiguration import net.corda.testing.driver.* +import net.corda.testing.getTestX509Name import org.apache.activemq.artemis.api.core.SimpleString import org.apache.activemq.artemis.api.core.client.ActiveMQClient import org.apache.activemq.artemis.api.core.client.ClientProducer @@ -39,6 +35,7 @@ import org.bouncycastle.asn1.x500.X500Name import java.nio.file.Path import java.nio.file.Paths import java.util.concurrent.ConcurrentHashMap +import java.util.concurrent.ScheduledExecutorService import java.util.concurrent.atomic.AtomicInteger /** @@ -47,10 +44,10 @@ import java.util.concurrent.atomic.AtomicInteger */ interface VerifierExposedDSLInterface : DriverDSLExposedInterface { /** Starts a lightweight verification requestor that implements the Node's Verifier API */ - fun startVerificationRequestor(name: X500Name): ListenableFuture + fun startVerificationRequestor(name: X500Name): CordaFuture /** Starts an out of process verifier connected to [address] */ - fun startVerifier(address: NetworkHostAndPort): ListenableFuture + fun startVerifier(address: NetworkHostAndPort): CordaFuture /** * Waits until [number] verifiers are listening for verification requests coming from the Node. Check @@ -110,15 +107,15 @@ data class VerificationRequestorHandle( private val responseAddress: SimpleString, private val session: ClientSession, private val requestProducer: ClientProducer, - private val addVerificationFuture: (Long, SettableFuture) -> Unit, - private val executorService: ListeningScheduledExecutorService + private val addVerificationFuture: (Long, OpenFuture) -> Unit, + private val executorService: ScheduledExecutorService ) { - fun verifyTransaction(transaction: LedgerTransaction): ListenableFuture { + fun verifyTransaction(transaction: LedgerTransaction): CordaFuture { val message = session.createMessage(false) val verificationId = random63BitValue() val request = VerifierApi.VerificationRequest(verificationId, transaction, responseAddress) request.writeToClientMessage(message) - val verificationFuture = SettableFuture.create() + val verificationFuture = openFuture() addVerificationFuture(verificationId, verificationFuture) requestProducer.send(message) return verificationFuture @@ -176,9 +173,9 @@ data class VerifierDriverDSL( } } - override fun startVerificationRequestor(name: X500Name): ListenableFuture { + override fun startVerificationRequestor(name: X500Name): CordaFuture { val hostAndPort = driverDSL.portAllocation.nextHostAndPort() - return driverDSL.executorService.submit { + return driverDSL.executorService.fork { startVerificationRequestorInternal(name, hostAndPort) } } @@ -207,7 +204,7 @@ data class VerifierDriverDSL( val server = ActiveMQServerImpl(artemisConfig, securityManager) log.info("Starting verification requestor Artemis server with base dir $baseDir") server.start() - driverDSL.shutdownManager.registerShutdown(Futures.immediateFuture { + driverDSL.shutdownManager.registerShutdown(doneFuture { server.stop() }) @@ -215,7 +212,7 @@ data class VerifierDriverDSL( val transport = ArtemisTcpTransport.tcpTransport(ConnectionDirection.Outbound(), hostAndPort, sslConfig) val sessionFactory = locator.createSessionFactory(transport) val session = sessionFactory.createSession() - driverDSL.shutdownManager.registerShutdown(Futures.immediateFuture { + driverDSL.shutdownManager.registerShutdown(doneFuture { session.stop() sessionFactory.close() }) @@ -223,7 +220,7 @@ data class VerifierDriverDSL( val consumer = session.createConsumer(responseAddress) // We demux the individual txs ourselves to avoid race when a new verifier is added - val verificationResponseFutures = ConcurrentHashMap>() + val verificationResponseFutures = ConcurrentHashMap>() consumer.setMessageHandler { val result = VerifierApi.VerificationResponse.fromClientMessage(it) val resultFuture = verificationResponseFutures.remove(result.verificationId) @@ -247,12 +244,12 @@ data class VerifierDriverDSL( ) } - override fun startVerifier(address: NetworkHostAndPort): ListenableFuture { + override fun startVerifier(address: NetworkHostAndPort): CordaFuture { log.info("Starting verifier connecting to address $address") val id = verifierCount.andIncrement val jdwpPort = if (driverDSL.isDebug) driverDSL.debugPortAllocation.nextPort() else null - val processFuture = driverDSL.executorService.submit { - val verifierName = X509Utilities.getDevX509Name("verifier$id") + val processFuture = driverDSL.executorService.fork { + val verifierName = getTestX509Name("verifier$id") val baseDirectory = driverDSL.driverDirectory / verifierName.commonName val config = createConfiguration(baseDirectory, address) val configFilename = "verifier.conf" diff --git a/verifier/src/integration-test/kotlin/net/corda/verifier/VerifierTests.kt b/verifier/src/integration-test/kotlin/net/corda/verifier/VerifierTests.kt index da594e5422..084e0140e4 100644 --- a/verifier/src/integration-test/kotlin/net/corda/verifier/VerifierTests.kt +++ b/verifier/src/integration-test/kotlin/net/corda/verifier/VerifierTests.kt @@ -1,14 +1,14 @@ package net.corda.verifier -import com.google.common.util.concurrent.Futures import net.corda.client.mock.generateOrFail import net.corda.core.contracts.DOLLARS -import net.corda.core.map import net.corda.core.messaging.startFlow import net.corda.core.node.services.ServiceInfo import net.corda.core.utilities.OpaqueBytes import net.corda.core.transactions.LedgerTransaction import net.corda.core.transactions.WireTransaction +import net.corda.core.internal.concurrent.map +import net.corda.core.internal.concurrent.transpose import net.corda.testing.ALICE import net.corda.testing.DUMMY_NOTARY import net.corda.flows.CashIssueFlow @@ -41,7 +41,7 @@ class VerifierTests { val alice = aliceFuture.get() startVerifier(alice) alice.waitUntilNumberOfVerifiers(1) - val results = Futures.allAsList(transactions.map { alice.verifyTransaction(it) }).get() + val results = transactions.map { alice.verifyTransaction(it) }.transpose().get() results.forEach { if (it != null) { throw it @@ -61,7 +61,7 @@ class VerifierTests { startVerifier(alice) } alice.waitUntilNumberOfVerifiers(numberOfVerifiers) - val results = Futures.allAsList(transactions.map { alice.verifyTransaction(it) }).get() + val results = transactions.map { alice.verifyTransaction(it) }.transpose().get() results.forEach { if (it != null) { throw it @@ -94,7 +94,7 @@ class VerifierTests { it } } - Futures.allAsList(futures).get() + futures.transpose().get() } } @@ -106,7 +106,7 @@ class VerifierTests { val alice = aliceFuture.get() val futures = transactions.map { alice.verifyTransaction(it) } startVerifier(alice) - Futures.allAsList(futures).get() + futures.transpose().get() } } diff --git a/verifier/src/main/kotlin/net/corda/verifier/Verifier.kt b/verifier/src/main/kotlin/net/corda/verifier/Verifier.kt index b8df7f891b..fec528f883 100644 --- a/verifier/src/main/kotlin/net/corda/verifier/Verifier.kt +++ b/verifier/src/main/kotlin/net/corda/verifier/Verifier.kt @@ -1,9 +1,14 @@ package net.corda.verifier +import com.esotericsoftware.kryo.pool.KryoPool import com.typesafe.config.Config import com.typesafe.config.ConfigFactory import com.typesafe.config.ConfigParseOptions -import net.corda.core.div +import net.corda.core.internal.div +import net.corda.core.serialization.SerializationContext +import net.corda.core.serialization.SerializationDefaults +import net.corda.core.serialization.SerializationFactory +import net.corda.core.utilities.ByteSequence import net.corda.core.utilities.NetworkHostAndPort import net.corda.core.utilities.debug import net.corda.core.utilities.loggerFor @@ -14,6 +19,10 @@ import net.corda.nodeapi.VerifierApi.VERIFICATION_REQUESTS_QUEUE_NAME import net.corda.nodeapi.config.NodeSSLConfiguration import net.corda.nodeapi.config.getValue import net.corda.nodeapi.internal.addShutdownHook +import net.corda.nodeapi.internal.serialization.AbstractKryoSerializationScheme +import net.corda.nodeapi.internal.serialization.KRYO_P2P_CONTEXT +import net.corda.nodeapi.internal.serialization.KryoHeaderV0_1 +import net.corda.nodeapi.internal.serialization.SerializationFactoryImpl import org.apache.activemq.artemis.api.core.client.ActiveMQClient import java.nio.file.Path import java.nio.file.Paths @@ -55,6 +64,7 @@ class Verifier { session.close() sessionFactory.close() } + initialiseSerialization() val consumer = session.createConsumer(VERIFICATION_REQUESTS_QUEUE_NAME) val replyProducer = session.createProducer() consumer.setMessageHandler { @@ -77,5 +87,26 @@ class Verifier { log.info("Verifier started") Thread.sleep(Long.MAX_VALUE) } + + private fun initialiseSerialization() { + SerializationDefaults.SERIALIZATION_FACTORY = SerializationFactoryImpl().apply { + registerScheme(KryoVerifierSerializationScheme(this)) + } + SerializationDefaults.P2P_CONTEXT = KRYO_P2P_CONTEXT + } + } + + class KryoVerifierSerializationScheme(serializationFactory: SerializationFactory) : AbstractKryoSerializationScheme(serializationFactory) { + override fun canDeserializeVersion(byteSequence: ByteSequence, target: SerializationContext.UseCase): Boolean { + return byteSequence.equals(KryoHeaderV0_1) && target == SerializationContext.UseCase.P2P + } + + override fun rpcClientKryoPool(context: SerializationContext): KryoPool { + throw UnsupportedOperationException() + } + + override fun rpcServerKryoPool(context: SerializationContext): KryoPool { + throw UnsupportedOperationException() + } } } \ No newline at end of file diff --git a/verify-enclave/src/main/kotlin/com/r3/enclaves/txverify/Enclavelet.kt b/verify-enclave/src/main/kotlin/com/r3/enclaves/txverify/Enclavelet.kt index bac47aaff2..4905285b08 100644 --- a/verify-enclave/src/main/kotlin/com/r3/enclaves/txverify/Enclavelet.kt +++ b/verify-enclave/src/main/kotlin/com/r3/enclaves/txverify/Enclavelet.kt @@ -5,7 +5,6 @@ package com.r3.enclaves.txverify import com.esotericsoftware.minlog.Log import net.corda.core.serialization.CordaSerializable import net.corda.core.serialization.SerializedBytes -import net.corda.core.serialization.createTestKryo import net.corda.core.serialization.deserialize import net.corda.core.transactions.WireTransaction import java.io.File @@ -31,10 +30,9 @@ class TransactionVerificationRequest(val wtxToVerify: SerializedBytes(kryo) - val wtxToVerify = req.wtxToVerify.deserialize(kryo) - val dependencies = req.dependencies.map { it.deserialize(kryo) }.associateBy { it.id } + val req = reqBytes.deserialize() + val wtxToVerify = req.wtxToVerify.deserialize() + val dependencies = req.dependencies.map { it.deserialize() }.associateBy { it.id } val ltx = wtxToVerify.toLedgerTransaction( resolveIdentity = { null }, resolveAttachment = { null }, diff --git a/webserver/build.gradle b/webserver/build.gradle index 0868a6ec8a..05c2daf185 100644 --- a/webserver/build.gradle +++ b/webserver/build.gradle @@ -72,5 +72,5 @@ jar { } publish { - name = jar.baseName + name jar.baseName } diff --git a/webserver/src/integration-test/kotlin/net/corda/webserver/WebserverDriverTests.kt b/webserver/src/integration-test/kotlin/net/corda/webserver/WebserverDriverTests.kt index dca27b20f1..bac83b0eab 100644 --- a/webserver/src/integration-test/kotlin/net/corda/webserver/WebserverDriverTests.kt +++ b/webserver/src/integration-test/kotlin/net/corda/webserver/WebserverDriverTests.kt @@ -1,7 +1,7 @@ package net.corda.webserver -import net.corda.core.getOrThrow import net.corda.core.utilities.NetworkHostAndPort +import net.corda.core.utilities.getOrThrow import net.corda.testing.DUMMY_BANK_A import net.corda.testing.driver.WebserverHandle import net.corda.testing.driver.addressMustBeBound diff --git a/webserver/src/main/kotlin/net/corda/webserver/WebArgsParser.kt b/webserver/src/main/kotlin/net/corda/webserver/WebArgsParser.kt index 539b812a44..d78f9087a7 100644 --- a/webserver/src/main/kotlin/net/corda/webserver/WebArgsParser.kt +++ b/webserver/src/main/kotlin/net/corda/webserver/WebArgsParser.kt @@ -6,7 +6,7 @@ import com.typesafe.config.ConfigParseOptions import com.typesafe.config.ConfigRenderOptions import joptsimple.OptionParser import joptsimple.util.EnumConverter -import net.corda.core.div +import net.corda.core.internal.div import net.corda.core.utilities.loggerFor import org.slf4j.event.Level import java.io.PrintStream diff --git a/webserver/src/main/kotlin/net/corda/webserver/WebServer.kt b/webserver/src/main/kotlin/net/corda/webserver/WebServer.kt index 4f7a072466..25388e96c1 100644 --- a/webserver/src/main/kotlin/net/corda/webserver/WebServer.kt +++ b/webserver/src/main/kotlin/net/corda/webserver/WebServer.kt @@ -3,8 +3,8 @@ package net.corda.webserver import com.typesafe.config.ConfigException -import net.corda.core.div -import net.corda.core.rootCause +import net.corda.core.internal.div +import net.corda.core.internal.rootCause import net.corda.webserver.internal.NodeWebServer import org.slf4j.LoggerFactory import java.lang.management.ManagementFactory diff --git a/webserver/src/main/kotlin/net/corda/webserver/servlets/DataUploadServlet.kt b/webserver/src/main/kotlin/net/corda/webserver/servlets/DataUploadServlet.kt index a81adb9a1b..e2c4c49e3d 100644 --- a/webserver/src/main/kotlin/net/corda/webserver/servlets/DataUploadServlet.kt +++ b/webserver/src/main/kotlin/net/corda/webserver/servlets/DataUploadServlet.kt @@ -17,7 +17,6 @@ class DataUploadServlet : HttpServlet() { @Throws(IOException::class) override fun doPost(req: HttpServletRequest, resp: HttpServletResponse) { - @Suppress("DEPRECATION") // Bogus warning due to superclass static method being deprecated. val isMultipart = ServletFileUpload.isMultipartContent(req) val rpc = servletContext.getAttribute("rpc") as CordaRPCOps @@ -34,20 +33,25 @@ class DataUploadServlet : HttpServlet() { resp.sendError(HttpServletResponse.SC_BAD_REQUEST, "Got an upload request with no files") return } - + fun reportError(message: String) { + println(message) // Show in webserver window. + resp.sendError(HttpServletResponse.SC_BAD_REQUEST, message) + } while (iterator.hasNext()) { val item = iterator.next() log.info("Receiving ${item.name}") - - try { - val dataType = req.pathInfo.substring(1).substringBefore('/') - @Suppress("DEPRECATION") // TODO: Replace the use of uploadFile - messages += rpc.uploadFile(dataType, item.name, item.openStream()) - log.info("${item.name} successfully accepted: ${messages.last()}") - } catch(e: RuntimeException) { - println(e) - resp.sendError(HttpServletResponse.SC_BAD_REQUEST, "Got a file upload request for an unknown data type") + val dataType = req.pathInfo.substring(1).substringBefore('/') + if (dataType != "attachment") { + reportError("Got a file upload request for an unknown data type $dataType") + continue } + try { + messages += rpc.uploadAttachment(item.openStream()).toString() + } catch (e: RuntimeException) { + reportError(e.toString()) + continue + } + log.info("${item.name} successfully accepted: ${messages.last()}") } // Send back the hashes as a convenience for the user. diff --git a/webserver/webcapsule/build.gradle b/webserver/webcapsule/build.gradle index 9cfc1f8e0c..9f5fa8df55 100644 --- a/webserver/webcapsule/build.gradle +++ b/webserver/webcapsule/build.gradle @@ -57,7 +57,6 @@ artifacts { } publish { - name = 'corda-webserver' - publishWar = false // TODO: Use WAR instead of JAR disableDefaultJar = true + name 'corda-webserver' }