Merge community-master

This commit is contained in:
Michal Kit 2017-08-15 12:04:09 +01:00
commit b6fd5ede58
649 changed files with 21249 additions and 14346 deletions

3
.gitignore vendored
View File

@ -32,6 +32,7 @@ lib/dokka.jar
.idea/libraries .idea/libraries
.idea/shelf .idea/shelf
.idea/dataSources .idea/dataSources
.idea/markdown-navigator
/gradle-plugins/.idea/ /gradle-plugins/.idea/
# Include the -parameters compiler option by default in IntelliJ required for serialization. # Include the -parameters compiler option by default in IntelliJ required for serialization.
@ -66,7 +67,7 @@ lib/dokka.jar
## Plugin-specific files: ## Plugin-specific files:
# IntelliJ # IntelliJ
/out/ **/out/
/classes/ /classes/
# mpeltonen/sbt-idea plugin # mpeltonen/sbt-idea plugin

2
.idea/compiler.xml generated
View File

@ -40,6 +40,8 @@
<module name="finance_test" target="1.8" /> <module name="finance_test" target="1.8" />
<module name="intellij-plugin_main" target="1.8" /> <module name="intellij-plugin_main" target="1.8" />
<module name="intellij-plugin_test" target="1.8" /> <module name="intellij-plugin_test" target="1.8" />
<module name="graphs_main" target="1.8" />
<module name="graphs_test" target="1.8" />
<module name="irs-demo_integrationTest" target="1.8" /> <module name="irs-demo_integrationTest" target="1.8" />
<module name="irs-demo_main" target="1.8" /> <module name="irs-demo_main" target="1.8" />
<module name="irs-demo_test" target="1.8" /> <module name="irs-demo_test" target="1.8" />

View File

@ -4,7 +4,7 @@ buildscript {
file("$projectDir/constants.properties").withInputStream { constants.load(it) } file("$projectDir/constants.properties").withInputStream { constants.load(it) }
// Our version: bump this on release. // Our version: bump this on release.
ext.corda_release_version = "0.14-SNAPSHOT" ext.corda_release_version = "0.15-SNAPSHOT"
// Increment this on any release that changes public APIs anywhere in the Corda platform // Increment this on any release that changes public APIs anywhere in the Corda platform
// TODO This is going to be difficult until we have a clear separation throughout the code of what is public and what is internal // TODO This is going to be difficult until we have a clear separation throughout the code of what is public and what is internal
ext.corda_platform_version = 1 ext.corda_platform_version = 1
@ -31,7 +31,6 @@ buildscript {
ext.log4j_version = '2.7' ext.log4j_version = '2.7'
ext.bouncycastle_version = constants.getProperty("bouncycastleVersion") ext.bouncycastle_version = constants.getProperty("bouncycastleVersion")
ext.guava_version = constants.getProperty("guavaVersion") ext.guava_version = constants.getProperty("guavaVersion")
ext.quickcheck_version = '0.7'
ext.okhttp_version = '3.5.0' ext.okhttp_version = '3.5.0'
ext.netty_version = '4.1.9.Final' ext.netty_version = '4.1.9.Final'
ext.typesafe_config_version = constants.getProperty("typesafeConfigVersion") ext.typesafe_config_version = constants.getProperty("typesafeConfigVersion")
@ -168,22 +167,23 @@ repositories {
} }
} }
// TODO: Corda root project currently produces a dummy cordapp when it shouldn't.
// Required for building out the fat JAR. // Required for building out the fat JAR.
dependencies { dependencies {
compile project(':node') cordaCompile project(':node')
compile "com.google.guava:guava:$guava_version" compile "com.google.guava:guava:$guava_version"
// Set to compile to ensure it exists now deploy nodes no longer relies on build // Set to corda compile to ensure it exists now deploy nodes no longer relies on build
compile project(path: ":node:capsule", configuration: 'runtimeArtifacts') cordaCompile project(path: ":node:capsule", configuration: 'runtimeArtifacts')
compile project(path: ":webserver:webcapsule", configuration: 'runtimeArtifacts') cordaCompile project(path: ":webserver:webcapsule", configuration: 'runtimeArtifacts')
// For the buildCordappDependenciesJar task // For the buildCordappDependenciesJar task
runtime project(':client:jfx') cordaRuntime project(':client:jfx')
runtime project(':client:mock') cordaRuntime project(':client:mock')
runtime project(':client:rpc') cordaRuntime project(':client:rpc')
runtime project(':core') cordaRuntime project(':core')
runtime project(':finance') cordaRuntime project(':finance')
runtime project(':webserver') cordaRuntime project(':webserver')
testCompile project(':test-utils') testCompile project(':test-utils')
} }
@ -285,7 +285,7 @@ artifactory {
password = System.getenv('CORDA_ARTIFACTORY_PASSWORD') password = System.getenv('CORDA_ARTIFACTORY_PASSWORD')
} }
defaults { defaults {
publications('corda-jfx', 'corda-mock', 'corda-rpc', 'corda-core', 'corda', 'cordform-common', 'corda-finance', 'corda-node', 'corda-node-api', 'corda-node-schemas', 'corda-test-utils', 'corda-jackson', 'corda-verifier', 'corda-webserver-impl', 'corda-webserver') publications('corda-jfx', 'corda-mock', 'corda-rpc', 'corda-core', 'corda', 'cordform-common', 'corda-finance', 'corda-node', 'corda-node-api', 'corda-node-schemas', 'corda-test-common', 'corda-test-utils', 'corda-jackson', 'corda-verifier', 'corda-webserver-impl', 'corda-webserver')
} }
} }
} }

View File

@ -6,6 +6,8 @@ apply plugin: 'com.jfrog.artifactory'
dependencies { dependencies {
compile project(':core') compile project(':core')
compile project(':finance') compile project(':finance')
testCompile project(':test-utils')
compile "org.jetbrains.kotlin:kotlin-stdlib-jre8:$kotlin_version" compile "org.jetbrains.kotlin:kotlin-stdlib-jre8:$kotlin_version"
testCompile "org.jetbrains.kotlin:kotlin-test:$kotlin_version" testCompile "org.jetbrains.kotlin:kotlin-test:$kotlin_version"
@ -18,12 +20,6 @@ dependencies {
testCompile project(path: ':core', configuration: 'testArtifacts') testCompile project(path: ':core', configuration: 'testArtifacts')
testCompile "junit:junit:$junit_version" testCompile "junit:junit:$junit_version"
// TODO: Upgrade to junit-quickcheck 0.8, once it is released,
// because it depends on org.javassist:javassist instead
// of javassist:javassist.
testCompile "com.pholser:junit-quickcheck-core:$quickcheck_version"
testCompile "com.pholser:junit-quickcheck-generators:$quickcheck_version"
} }
jar { jar {
@ -31,5 +27,5 @@ jar {
} }
publish { publish {
name = jar.baseName name jar.baseName
} }

View File

@ -1,5 +1,7 @@
package net.corda.jackson package net.corda.jackson
import com.fasterxml.jackson.annotation.JsonIgnore
import com.fasterxml.jackson.annotation.JsonProperty
import com.fasterxml.jackson.core.* import com.fasterxml.jackson.core.*
import com.fasterxml.jackson.databind.* import com.fasterxml.jackson.databind.*
import com.fasterxml.jackson.databind.deser.std.NumberDeserializers import com.fasterxml.jackson.databind.deser.std.NumberDeserializers
@ -9,6 +11,8 @@ import com.fasterxml.jackson.datatype.jsr310.JavaTimeModule
import com.fasterxml.jackson.module.kotlin.KotlinModule import com.fasterxml.jackson.module.kotlin.KotlinModule
import net.corda.contracts.BusinessCalendar import net.corda.contracts.BusinessCalendar
import net.corda.core.contracts.Amount import net.corda.core.contracts.Amount
import net.corda.core.contracts.ContractState
import net.corda.core.contracts.StateRef
import net.corda.core.crypto.* import net.corda.core.crypto.*
import net.corda.core.crypto.composite.CompositeKey import net.corda.core.crypto.composite.CompositeKey
import net.corda.core.identity.AbstractParty import net.corda.core.identity.AbstractParty
@ -17,9 +21,14 @@ import net.corda.core.identity.Party
import net.corda.core.messaging.CordaRPCOps import net.corda.core.messaging.CordaRPCOps
import net.corda.core.node.NodeInfo import net.corda.core.node.NodeInfo
import net.corda.core.node.services.IdentityService import net.corda.core.node.services.IdentityService
import net.corda.core.utilities.OpaqueBytes import net.corda.core.serialization.SerializedBytes
import net.corda.core.serialization.deserialize import net.corda.core.serialization.deserialize
import net.corda.core.serialization.serialize import net.corda.core.serialization.serialize
import net.corda.core.transactions.CoreTransaction
import net.corda.core.transactions.NotaryChangeWireTransaction
import net.corda.core.transactions.SignedTransaction
import net.corda.core.transactions.WireTransaction
import net.corda.core.utilities.OpaqueBytes
import net.i2p.crypto.eddsa.EdDSAPublicKey import net.i2p.crypto.eddsa.EdDSAPublicKey
import org.bouncycastle.asn1.x500.X500Name import org.bouncycastle.asn1.x500.X500Name
import java.math.BigDecimal import java.math.BigDecimal
@ -38,32 +47,24 @@ object JacksonSupport {
// If you change this API please update the docs in the docsite (json.rst) // If you change this API please update the docs in the docsite (json.rst)
interface PartyObjectMapper { interface PartyObjectMapper {
@Deprecated("Use partyFromX500Name instead")
fun partyFromName(partyName: String): Party?
fun partyFromX500Name(name: X500Name): Party? fun partyFromX500Name(name: X500Name): Party?
fun partyFromKey(owningKey: PublicKey): Party? fun partyFromKey(owningKey: PublicKey): Party?
fun partiesFromName(query: String): Set<Party> fun partiesFromName(query: String): Set<Party>
} }
class RpcObjectMapper(val rpc: CordaRPCOps, factory: JsonFactory, val fuzzyIdentityMatch: Boolean) : PartyObjectMapper, ObjectMapper(factory) { class RpcObjectMapper(val rpc: CordaRPCOps, factory: JsonFactory, val fuzzyIdentityMatch: Boolean) : PartyObjectMapper, ObjectMapper(factory) {
@Suppress("OverridingDeprecatedMember", "DEPRECATION")
override fun partyFromName(partyName: String): Party? = rpc.partyFromName(partyName)
override fun partyFromX500Name(name: X500Name): Party? = rpc.partyFromX500Name(name) override fun partyFromX500Name(name: X500Name): Party? = rpc.partyFromX500Name(name)
override fun partyFromKey(owningKey: PublicKey): Party? = rpc.partyFromKey(owningKey) override fun partyFromKey(owningKey: PublicKey): Party? = rpc.partyFromKey(owningKey)
override fun partiesFromName(query: String) = rpc.partiesFromName(query, fuzzyIdentityMatch) override fun partiesFromName(query: String) = rpc.partiesFromName(query, fuzzyIdentityMatch)
} }
class IdentityObjectMapper(val identityService: IdentityService, factory: JsonFactory, val fuzzyIdentityMatch: Boolean) : PartyObjectMapper, ObjectMapper(factory) { class IdentityObjectMapper(val identityService: IdentityService, factory: JsonFactory, val fuzzyIdentityMatch: Boolean) : PartyObjectMapper, ObjectMapper(factory) {
@Suppress("OverridingDeprecatedMember", "DEPRECATION")
override fun partyFromName(partyName: String): Party? = identityService.partyFromName(partyName)
override fun partyFromX500Name(name: X500Name): Party? = identityService.partyFromX500Name(name) override fun partyFromX500Name(name: X500Name): Party? = identityService.partyFromX500Name(name)
override fun partyFromKey(owningKey: PublicKey): Party? = identityService.partyFromKey(owningKey) override fun partyFromKey(owningKey: PublicKey): Party? = identityService.partyFromKey(owningKey)
override fun partiesFromName(query: String) = identityService.partiesFromName(query, fuzzyIdentityMatch) override fun partiesFromName(query: String) = identityService.partiesFromName(query, fuzzyIdentityMatch)
} }
class NoPartyObjectMapper(factory: JsonFactory) : PartyObjectMapper, ObjectMapper(factory) { class NoPartyObjectMapper(factory: JsonFactory) : PartyObjectMapper, ObjectMapper(factory) {
@Suppress("OverridingDeprecatedMember", "DEPRECATION")
override fun partyFromName(partyName: String): Party? = throw UnsupportedOperationException()
override fun partyFromX500Name(name: X500Name): Party? = throw UnsupportedOperationException() override fun partyFromX500Name(name: X500Name): Party? = throw UnsupportedOperationException()
override fun partyFromKey(owningKey: PublicKey): Party? = throw UnsupportedOperationException() override fun partyFromKey(owningKey: PublicKey): Party? = throw UnsupportedOperationException()
override fun partiesFromName(query: String) = throw UnsupportedOperationException() override fun partiesFromName(query: String) = throw UnsupportedOperationException()
@ -109,6 +110,10 @@ object JacksonSupport {
// For X.500 distinguished names // For X.500 distinguished names
addDeserializer(X500Name::class.java, X500NameDeserializer) addDeserializer(X500Name::class.java, X500NameDeserializer)
addSerializer(X500Name::class.java, X500NameSerializer) addSerializer(X500Name::class.java, X500NameSerializer)
// Mixins for transaction types to prevent some properties from being serialized
setMixInAnnotation(SignedTransaction::class.java, SignedTransactionMixin::class.java)
setMixInAnnotation(WireTransaction::class.java, WireTransactionMixin::class.java)
} }
} }
@ -278,7 +283,7 @@ object JacksonSupport {
object CalendarSerializer : JsonSerializer<BusinessCalendar>() { object CalendarSerializer : JsonSerializer<BusinessCalendar>() {
override fun serialize(obj: BusinessCalendar, generator: JsonGenerator, context: SerializerProvider) { override fun serialize(obj: BusinessCalendar, generator: JsonGenerator, context: SerializerProvider) {
val calendarName = BusinessCalendar.calendars.find { BusinessCalendar.getInstance(it) == obj } val calendarName = BusinessCalendar.calendars.find { BusinessCalendar.getInstance(it) == obj }
if(calendarName != null) { if (calendarName != null) {
generator.writeString(calendarName) generator.writeString(calendarName)
} else { } else {
generator.writeObject(BusinessCalendarWrapper(obj.holidayDates)) generator.writeObject(BusinessCalendarWrapper(obj.holidayDates))
@ -371,5 +376,24 @@ object JacksonSupport {
gen.writeBinary(value.bytes) gen.writeBinary(value.bytes)
} }
} }
abstract class SignedTransactionMixin {
@JsonIgnore abstract fun getTxBits(): SerializedBytes<CoreTransaction>
@JsonProperty("signatures") protected abstract fun getSigs(): List<TransactionSignature>
@JsonProperty protected abstract fun getTransaction(): CoreTransaction
@JsonIgnore abstract fun getTx(): WireTransaction
@JsonIgnore abstract fun getNotaryChangeTx(): NotaryChangeWireTransaction
@JsonIgnore abstract fun getInputs(): List<StateRef>
@JsonIgnore abstract fun getNotary(): Party?
@JsonIgnore abstract fun getId(): SecureHash
@JsonIgnore abstract fun getRequiredSigningKeys(): Set<PublicKey>
}
abstract class WireTransactionMixin {
@JsonIgnore abstract fun getMerkleTree(): MerkleTree
@JsonIgnore abstract fun getAvailableComponents(): List<Any>
@JsonIgnore abstract fun getAvailableComponentHashes(): List<SecureHash>
@JsonIgnore abstract fun getOutputStates(): List<ContractState>
}
} }

View File

@ -1,27 +1,31 @@
package net.corda.jackson package net.corda.jackson
import com.fasterxml.jackson.databind.SerializationFeature import com.fasterxml.jackson.databind.SerializationFeature
import com.pholser.junit.quickcheck.From
import com.pholser.junit.quickcheck.Property
import com.pholser.junit.quickcheck.runner.JUnitQuickcheck
import net.corda.core.contracts.Amount import net.corda.core.contracts.Amount
import net.corda.core.contracts.USD import net.corda.core.contracts.USD
import net.corda.core.testing.PublicKeyGenerator import net.corda.core.crypto.Crypto
import net.corda.core.crypto.SignatureMetadata
import net.corda.core.crypto.TransactionSignature
import net.corda.core.crypto.generateKeyPair
import net.corda.core.transactions.SignedTransaction
import net.corda.testing.ALICE_PUBKEY
import net.corda.testing.DUMMY_NOTARY
import net.corda.testing.MINI_CORP
import net.corda.testing.TestDependencyInjectionBase
import net.corda.testing.contracts.DummyContract
import net.i2p.crypto.eddsa.EdDSAPublicKey import net.i2p.crypto.eddsa.EdDSAPublicKey
import org.junit.Test import org.junit.Test
import org.junit.runner.RunWith
import java.security.PublicKey
import java.util.* import java.util.*
import kotlin.test.assertEquals import kotlin.test.assertEquals
@RunWith(JUnitQuickcheck::class) class JacksonSupportTest : TestDependencyInjectionBase() {
class JacksonSupportTest {
companion object { companion object {
val mapper = JacksonSupport.createNonRpcMapper() val mapper = JacksonSupport.createNonRpcMapper()
} }
@Property @Test
fun publicKeySerializingWorks(@From(PublicKeyGenerator::class) publicKey: PublicKey) { fun publicKeySerializingWorks() {
val publicKey = generateKeyPair().public
val serialized = mapper.writeValueAsString(publicKey) val serialized = mapper.writeValueAsString(publicKey)
val parsedKey = mapper.readValue(serialized, EdDSAPublicKey::class.java) val parsedKey = mapper.readValue(serialized, EdDSAPublicKey::class.java)
assertEquals(publicKey, parsedKey) assertEquals(publicKey, parsedKey)
@ -50,4 +54,24 @@ class JacksonSupportTest {
val writer = mapper.writer().without(SerializationFeature.INDENT_OUTPUT) val writer = mapper.writer().without(SerializationFeature.INDENT_OUTPUT)
assertEquals("""{"notional":"25000000.00 USD"}""", writer.writeValueAsString(Dummy(Amount.parseCurrency("$25000000")))) assertEquals("""{"notional":"25000000.00 USD"}""", writer.writeValueAsString(Dummy(Amount.parseCurrency("$25000000"))))
} }
@Test
fun writeTransaction() {
fun makeDummyTx(): SignedTransaction {
val wtx = DummyContract.generateInitial(1, DUMMY_NOTARY, MINI_CORP.ref(1)).toWireTransaction()
val signatures = TransactionSignature(
ByteArray(1),
ALICE_PUBKEY,
SignatureMetadata(
1,
Crypto.findSignatureScheme(ALICE_PUBKEY).schemeNumberID
)
)
return SignedTransaction(wtx, listOf(signatures))
}
val writer = mapper.writer()
// We don't particularly care about the serialized format, just need to make sure it completes successfully.
writer.writeValueAsString(makeDummyTx())
}
} }

View File

@ -62,5 +62,5 @@ jar {
} }
publish { publish {
name = jar.baseName name jar.baseName
} }

View File

@ -2,15 +2,15 @@ package net.corda.client.jfx
import net.corda.client.jfx.model.NodeMonitorModel import net.corda.client.jfx.model.NodeMonitorModel
import net.corda.client.jfx.model.ProgressTrackingEvent import net.corda.client.jfx.model.ProgressTrackingEvent
import net.corda.core.bufferUntilSubscribed import net.corda.core.internal.bufferUntilSubscribed
import net.corda.core.contracts.Amount import net.corda.core.contracts.Amount
import net.corda.core.contracts.ContractState
import net.corda.core.contracts.DOLLARS import net.corda.core.contracts.DOLLARS
import net.corda.core.contracts.USD import net.corda.core.contracts.USD
import net.corda.core.crypto.isFulfilledBy import net.corda.core.crypto.isFulfilledBy
import net.corda.core.crypto.keys import net.corda.core.crypto.keys
import net.corda.core.flows.FlowInitiator import net.corda.core.flows.FlowInitiator
import net.corda.core.flows.StateMachineRunId import net.corda.core.flows.StateMachineRunId
import net.corda.core.getOrThrow
import net.corda.core.messaging.CordaRPCOps import net.corda.core.messaging.CordaRPCOps
import net.corda.core.messaging.StateMachineTransactionMapping import net.corda.core.messaging.StateMachineTransactionMapping
import net.corda.core.messaging.StateMachineUpdate import net.corda.core.messaging.StateMachineUpdate
@ -19,12 +19,9 @@ import net.corda.core.node.NodeInfo
import net.corda.core.node.services.NetworkMapCache import net.corda.core.node.services.NetworkMapCache
import net.corda.core.node.services.ServiceInfo import net.corda.core.node.services.ServiceInfo
import net.corda.core.node.services.Vault import net.corda.core.node.services.Vault
import net.corda.core.utilities.OpaqueBytes
import net.corda.core.transactions.SignedTransaction import net.corda.core.transactions.SignedTransaction
import net.corda.testing.ALICE import net.corda.core.utilities.OpaqueBytes
import net.corda.testing.BOB import net.corda.core.utilities.getOrThrow
import net.corda.testing.CHARLIE
import net.corda.testing.DUMMY_NOTARY
import net.corda.flows.CashExitFlow import net.corda.flows.CashExitFlow
import net.corda.flows.CashIssueFlow import net.corda.flows.CashIssueFlow
import net.corda.flows.CashPaymentFlow import net.corda.flows.CashPaymentFlow
@ -32,11 +29,9 @@ import net.corda.node.services.network.NetworkMapService
import net.corda.node.services.startFlowPermission import net.corda.node.services.startFlowPermission
import net.corda.node.services.transactions.SimpleNotaryService import net.corda.node.services.transactions.SimpleNotaryService
import net.corda.nodeapi.User import net.corda.nodeapi.User
import net.corda.testing.*
import net.corda.testing.driver.driver import net.corda.testing.driver.driver
import net.corda.testing.expect
import net.corda.testing.expectEvents
import net.corda.testing.node.DriverBasedTest import net.corda.testing.node.DriverBasedTest
import net.corda.testing.sequence
import org.bouncycastle.asn1.x500.X500Name import org.bouncycastle.asn1.x500.X500Name
import org.junit.Test import org.junit.Test
import rx.Observable import rx.Observable
@ -53,7 +48,7 @@ class NodeMonitorModelTest : DriverBasedTest() {
lateinit var stateMachineUpdatesBob: Observable<StateMachineUpdate> lateinit var stateMachineUpdatesBob: Observable<StateMachineUpdate>
lateinit var progressTracking: Observable<ProgressTrackingEvent> lateinit var progressTracking: Observable<ProgressTrackingEvent>
lateinit var transactions: Observable<SignedTransaction> lateinit var transactions: Observable<SignedTransaction>
lateinit var vaultUpdates: Observable<Vault.Update> lateinit var vaultUpdates: Observable<Vault.Update<ContractState>>
lateinit var networkMapUpdates: Observable<NetworkMapCache.MapChange> lateinit var networkMapUpdates: Observable<NetworkMapCache.MapChange>
lateinit var newNode: (X500Name) -> NodeInfo lateinit var newNode: (X500Name) -> NodeInfo
@ -78,14 +73,14 @@ class NodeMonitorModelTest : DriverBasedTest() {
vaultUpdates = monitor.vaultUpdates.bufferUntilSubscribed() vaultUpdates = monitor.vaultUpdates.bufferUntilSubscribed()
networkMapUpdates = monitor.networkMap.bufferUntilSubscribed() networkMapUpdates = monitor.networkMap.bufferUntilSubscribed()
monitor.register(aliceNodeHandle.configuration.rpcAddress!!, cashUser.username, cashUser.password) monitor.register(aliceNodeHandle.configuration.rpcAddress!!, cashUser.username, cashUser.password, initialiseSerialization = false)
rpc = monitor.proxyObservable.value!! rpc = monitor.proxyObservable.value!!
val bobNodeHandle = startNode(BOB.name, rpcUsers = listOf(cashUser)).getOrThrow() val bobNodeHandle = startNode(BOB.name, rpcUsers = listOf(cashUser)).getOrThrow()
bobNode = bobNodeHandle.nodeInfo bobNode = bobNodeHandle.nodeInfo
val monitorBob = NodeMonitorModel() val monitorBob = NodeMonitorModel()
stateMachineUpdatesBob = monitorBob.stateMachineUpdates.bufferUntilSubscribed() stateMachineUpdatesBob = monitorBob.stateMachineUpdates.bufferUntilSubscribed()
monitorBob.register(bobNodeHandle.configuration.rpcAddress!!, cashUser.username, cashUser.password) monitorBob.register(bobNodeHandle.configuration.rpcAddress!!, cashUser.username, cashUser.password, initialiseSerialization = false)
rpcBob = monitorBob.proxyObservable.value!! rpcBob = monitorBob.proxyObservable.value!!
runTest() runTest()
} }
@ -148,7 +143,7 @@ class NodeMonitorModelTest : DriverBasedTest() {
var moveSmId: StateMachineRunId? = null var moveSmId: StateMachineRunId? = null
var issueTx: SignedTransaction? = null var issueTx: SignedTransaction? = null
var moveTx: SignedTransaction? = null var moveTx: SignedTransaction? = null
stateMachineUpdates.expectEvents { stateMachineUpdates.expectEvents(isStrict = false) {
sequence( sequence(
// ISSUE // ISSUE
expect { add: StateMachineUpdate.Added -> expect { add: StateMachineUpdate.Added ->
@ -159,14 +154,13 @@ class NodeMonitorModelTest : DriverBasedTest() {
expect { remove: StateMachineUpdate.Removed -> expect { remove: StateMachineUpdate.Removed ->
require(remove.id == issueSmId) require(remove.id == issueSmId)
}, },
// MOVE // MOVE - N.B. There are other framework flows that happen in parallel for the remote resolve transactions flow
expect { add: StateMachineUpdate.Added -> expect(match = { it is StateMachineUpdate.Added && it.stateMachineInfo.flowLogicClassName == CashPaymentFlow::class.java.name }) { add: StateMachineUpdate.Added ->
moveSmId = add.id moveSmId = add.id
val initiator = add.stateMachineInfo.initiator val initiator = add.stateMachineInfo.initiator
require(initiator is FlowInitiator.RPC && initiator.username == "user1") require(initiator is FlowInitiator.RPC && initiator.username == "user1")
}, },
expect { remove: StateMachineUpdate.Removed -> expect(match = { it is StateMachineUpdate.Removed && it.id == moveSmId }) {
require(remove.id == moveSmId)
} }
) )
} }

View File

@ -19,7 +19,7 @@ data class Diff<out T : ContractState>(
* This model exposes the list of owned contract states. * This model exposes the list of owned contract states.
*/ */
class ContractStateModel { class ContractStateModel {
private val vaultUpdates: Observable<Vault.Update> by observable(NodeMonitorModel::vaultUpdates) private val vaultUpdates: Observable<Vault.Update<ContractState>> by observable(NodeMonitorModel::vaultUpdates)
private val contractStatesDiff: Observable<Diff<ContractState>> = vaultUpdates.map { private val contractStatesDiff: Observable<Diff<ContractState>> = vaultUpdates.map {
Diff(it.produced, it.consumed) Diff(it.produced, it.consumed)

View File

@ -3,14 +3,13 @@ package net.corda.client.jfx.model
import javafx.beans.property.SimpleObjectProperty import javafx.beans.property.SimpleObjectProperty
import net.corda.client.rpc.CordaRPCClient import net.corda.client.rpc.CordaRPCClient
import net.corda.client.rpc.CordaRPCClientConfiguration import net.corda.client.rpc.CordaRPCClientConfiguration
import net.corda.core.contracts.ContractState
import net.corda.core.flows.StateMachineRunId import net.corda.core.flows.StateMachineRunId
import net.corda.core.messaging.CordaRPCOps import net.corda.core.messaging.*
import net.corda.core.messaging.StateMachineInfo
import net.corda.core.messaging.StateMachineTransactionMapping
import net.corda.core.messaging.StateMachineUpdate
import net.corda.core.node.services.NetworkMapCache.MapChange import net.corda.core.node.services.NetworkMapCache.MapChange
import net.corda.core.node.services.Vault import net.corda.core.node.services.Vault
import net.corda.core.seconds import net.corda.core.node.services.vault.*
import net.corda.core.utilities.seconds
import net.corda.core.transactions.SignedTransaction import net.corda.core.transactions.SignedTransaction
import net.corda.core.utilities.NetworkHostAndPort import net.corda.core.utilities.NetworkHostAndPort
import rx.Observable import rx.Observable
@ -32,14 +31,14 @@ data class ProgressTrackingEvent(val stateMachineId: StateMachineRunId, val mess
class NodeMonitorModel { class NodeMonitorModel {
private val stateMachineUpdatesSubject = PublishSubject.create<StateMachineUpdate>() private val stateMachineUpdatesSubject = PublishSubject.create<StateMachineUpdate>()
private val vaultUpdatesSubject = PublishSubject.create<Vault.Update>() private val vaultUpdatesSubject = PublishSubject.create<Vault.Update<ContractState>>()
private val transactionsSubject = PublishSubject.create<SignedTransaction>() private val transactionsSubject = PublishSubject.create<SignedTransaction>()
private val stateMachineTransactionMappingSubject = PublishSubject.create<StateMachineTransactionMapping>() private val stateMachineTransactionMappingSubject = PublishSubject.create<StateMachineTransactionMapping>()
private val progressTrackingSubject = PublishSubject.create<ProgressTrackingEvent>() private val progressTrackingSubject = PublishSubject.create<ProgressTrackingEvent>()
private val networkMapSubject = PublishSubject.create<MapChange>() private val networkMapSubject = PublishSubject.create<MapChange>()
val stateMachineUpdates: Observable<StateMachineUpdate> = stateMachineUpdatesSubject val stateMachineUpdates: Observable<StateMachineUpdate> = stateMachineUpdatesSubject
val vaultUpdates: Observable<Vault.Update> = vaultUpdatesSubject val vaultUpdates: Observable<Vault.Update<ContractState>> = vaultUpdatesSubject
val transactions: Observable<SignedTransaction> = transactionsSubject val transactions: Observable<SignedTransaction> = transactionsSubject
val stateMachineTransactionMapping: Observable<StateMachineTransactionMapping> = stateMachineTransactionMappingSubject val stateMachineTransactionMapping: Observable<StateMachineTransactionMapping> = stateMachineTransactionMappingSubject
val progressTracking: Observable<ProgressTrackingEvent> = progressTrackingSubject val progressTracking: Observable<ProgressTrackingEvent> = progressTrackingSubject
@ -51,17 +50,18 @@ class NodeMonitorModel {
* Register for updates to/from a given vault. * Register for updates to/from a given vault.
* TODO provide an unsubscribe mechanism * TODO provide an unsubscribe mechanism
*/ */
fun register(nodeHostAndPort: NetworkHostAndPort, username: String, password: String) { fun register(nodeHostAndPort: NetworkHostAndPort, username: String, password: String, initialiseSerialization: Boolean = true) {
val client = CordaRPCClient( val client = CordaRPCClient(
hostAndPort = nodeHostAndPort, hostAndPort = nodeHostAndPort,
configuration = CordaRPCClientConfiguration.default.copy( configuration = CordaRPCClientConfiguration.default.copy(
connectionMaxRetryInterval = 10.seconds connectionMaxRetryInterval = 10.seconds
) ),
initialiseSerialization = initialiseSerialization
) )
val connection = client.start(username, password) val connection = client.start(username, password)
val proxy = connection.proxy val proxy = connection.proxy
val (stateMachines, stateMachineUpdates) = proxy.stateMachinesAndUpdates() val (stateMachines, stateMachineUpdates) = proxy.stateMachinesFeed()
// Extract the flow tracking stream // Extract the flow tracking stream
// TODO is there a nicer way of doing this? Stream of streams in general results in code like this... // TODO is there a nicer way of doing this? Stream of streams in general results in code like this...
val currentProgressTrackerUpdates = stateMachines.mapNotNull { stateMachine -> val currentProgressTrackerUpdates = stateMachines.mapNotNull { stateMachine ->
@ -82,21 +82,22 @@ class NodeMonitorModel {
val currentStateMachines = stateMachines.map { StateMachineUpdate.Added(it) } val currentStateMachines = stateMachines.map { StateMachineUpdate.Added(it) }
stateMachineUpdates.startWith(currentStateMachines).subscribe(stateMachineUpdatesSubject) stateMachineUpdates.startWith(currentStateMachines).subscribe(stateMachineUpdatesSubject)
// Vault updates // Vault snapshot (force single page load with MAX_PAGE_SIZE) + updates
val (vault, vaultUpdates) = proxy.vaultAndUpdates() val (vaultSnapshot, vaultUpdates) = proxy.vaultTrackBy<ContractState>(QueryCriteria.VaultQueryCriteria(Vault.StateStatus.ALL),
val initialVaultUpdate = Vault.Update(setOf(), vault.toSet()) PageSpecification(DEFAULT_PAGE_NUM, MAX_PAGE_SIZE))
val initialVaultUpdate = Vault.Update(setOf(), vaultSnapshot.states.toSet())
vaultUpdates.startWith(initialVaultUpdate).subscribe(vaultUpdatesSubject) vaultUpdates.startWith(initialVaultUpdate).subscribe(vaultUpdatesSubject)
// Transactions // Transactions
val (transactions, newTransactions) = proxy.verifiedTransactions() val (transactions, newTransactions) = proxy.verifiedTransactionsFeed()
newTransactions.startWith(transactions).subscribe(transactionsSubject) newTransactions.startWith(transactions).subscribe(transactionsSubject)
// SM -> TX mapping // SM -> TX mapping
val (smTxMappings, futureSmTxMappings) = proxy.stateMachineRecordedTransactionMapping() val (smTxMappings, futureSmTxMappings) = proxy.stateMachineRecordedTransactionMappingFeed()
futureSmTxMappings.startWith(smTxMappings).subscribe(stateMachineTransactionMappingSubject) futureSmTxMappings.startWith(smTxMappings).subscribe(stateMachineTransactionMappingSubject)
// Parties on network // Parties on network
val (parties, futurePartyUpdate) = proxy.networkMapUpdates() val (parties, futurePartyUpdate) = proxy.networkMapFeed()
futurePartyUpdate.startWith(parties.map { MapChange.Added(it) }).subscribe(networkMapSubject) futurePartyUpdate.startWith(parties.map { MapChange.Added(it) }).subscribe(networkMapSubject)
proxyObservable.set(proxy) proxyObservable.set(proxy)

View File

@ -23,13 +23,14 @@ object AmountBindings {
) { sum -> Amount(sum.toLong(), token) } ) { sum -> Amount(sum.toLong(), token) }
fun exchange( fun exchange(
currency: ObservableValue<Currency>, observableCurrency: ObservableValue<Currency>,
exchangeRate: ObservableValue<ExchangeRate> observableExchangeRate: ObservableValue<ExchangeRate>
): ObservableValue<Pair<Currency, (Amount<Currency>) -> Long>> { ): ObservableValue<Pair<Currency, (Amount<Currency>) -> Long>> {
return EasyBind.combine(currency, exchangeRate) { currency, exchangeRate -> return EasyBind.combine(observableCurrency, observableExchangeRate) { currency, exchangeRate ->
Pair(currency) { amount: Amount<Currency> -> Pair<Currency, (Amount<Currency>) -> Long>(
(exchangeRate.rate(amount.token, currency) * amount.quantity).toLong() currency,
} { (quantity, _, token) -> (exchangeRate.rate(token, currency) * quantity).toLong() }
)
} }
} }

View File

@ -1,5 +1,6 @@
package net.corda.client.jfx.utils package net.corda.client.jfx.utils
import javafx.application.Platform
import javafx.beans.binding.Bindings import javafx.beans.binding.Bindings
import javafx.beans.binding.BooleanBinding import javafx.beans.binding.BooleanBinding
import javafx.beans.property.ReadOnlyObjectWrapper import javafx.beans.property.ReadOnlyObjectWrapper
@ -10,7 +11,13 @@ import javafx.collections.MapChangeListener
import javafx.collections.ObservableList import javafx.collections.ObservableList
import javafx.collections.ObservableMap import javafx.collections.ObservableMap
import javafx.collections.transformation.FilteredList import javafx.collections.transformation.FilteredList
import net.corda.core.contracts.ContractState
import net.corda.core.contracts.StateAndRef
import net.corda.core.messaging.DataFeed
import net.corda.core.node.services.Vault
import org.fxmisc.easybind.EasyBind import org.fxmisc.easybind.EasyBind
import rx.Observable
import rx.schedulers.Schedulers
import java.util.function.Predicate import java.util.function.Predicate
/** /**
@ -313,3 +320,36 @@ fun <A> ObservableList<A>.firstOrDefault(default: ObservableValue<A?>, predicate
fun <A> ObservableList<A>.firstOrNullObservable(predicate: (A) -> Boolean): ObservableValue<A?> { fun <A> ObservableList<A>.firstOrNullObservable(predicate: (A) -> Boolean): ObservableValue<A?> {
return Bindings.createObjectBinding({ this.firstOrNull(predicate) }, arrayOf(this)) return Bindings.createObjectBinding({ this.firstOrNull(predicate) }, arrayOf(this))
} }
/**
* Modifies the given Rx observable such that emissions are run on the JavaFX GUI thread. Use this when you have an Rx
* observable that may emit in the background e.g. from the network and you wish to link it to the user interface.
*
* Note: you should use the returned observable, not the original one this method is called on.
*/
fun <T> Observable<T>.observeOnFXThread(): Observable<T> = observeOn(Schedulers.from(Platform::runLater))
/**
* Given a [DataFeed] that contains the results of a vault query and a subsequent stream of changes, returns a JavaFX
* [ObservableList] that mirrors the streamed results on the UI thread. Note that the paging is *not* respected by this
* function: if a state is added that would not have appeared in the page in the initial query, it will still be added
* to the observable list.
*
* @see toFXListOfStates if you want just the state objects and not the ledger pointers too.
*/
fun <T : ContractState> DataFeed<Vault.Page<T>, Vault.Update<T>>.toFXListOfStateRefs(): ObservableList<StateAndRef<T>> {
val list = FXCollections.observableArrayList(snapshot.states)
updates.observeOnFXThread().subscribe { (consumed, produced) ->
list.removeAll(consumed)
list.addAll(produced)
}
return list
}
/**
* Returns the same list as [toFXListOfStateRefs] but which contains the states instead of [StateAndRef] wrappers.
* The same notes apply as with that function.
*/
fun <T : ContractState> DataFeed<Vault.Page<T>, Vault.Update<T>>.toFXListOfStates(): ObservableList<T> {
return toFXListOfStateRefs().map { it.state.data }
}

View File

@ -30,5 +30,5 @@ jar {
} }
publish { publish {
name = jar.baseName name jar.baseName
} }

View File

@ -167,21 +167,23 @@ fun <A> Generator.Companion.replicate(number: Int, generator: Generator<A>): Gen
} }
fun <A> Generator.Companion.replicatePoisson(meanSize: Double, generator: Generator<A>) = Generator<List<A>> { fun <A> Generator.Companion.replicatePoisson(meanSize: Double, generator: Generator<A>, atLeastOne: Boolean = false) = Generator<List<A>> {
val chance = (meanSize - 1) / meanSize val chance = (meanSize - 1) / meanSize
val result = mutableListOf<A>() val result = mutableListOf<A>()
var finish = false var finish = false
while (!finish) { while (!finish) {
val result = Generator.doubleRange(0.0, 1.0).generate(it).flatMap { value -> val res = Generator.doubleRange(0.0, 1.0).generate(it).flatMap { value ->
if (value < chance) { if (value < chance) {
generator.generate(it).map { result.add(it) } generator.generate(it).map { result.add(it) }
} else { } else {
finish = true finish = true
Try.Success(Unit) if (result.isEmpty() && atLeastOne) {
generator.generate(it).map { result.add(it) }
} else Try.Success(Unit)
} }
} }
if (result is Try.Failure) { if (res is Try.Failure) {
return@Generator result return@Generator res
} }
} }
Try.Success(result) Try.Success(result)

View File

@ -24,6 +24,11 @@ sourceSets {
runtimeClasspath += main.output + test.output runtimeClasspath += main.output + test.output
srcDir file('src/integration-test/kotlin') srcDir file('src/integration-test/kotlin')
} }
java {
compileClasspath += main.output + test.output
runtimeClasspath += main.output + test.output
srcDir file('src/integration-test/java')
}
} }
smokeTest { smokeTest {
kotlin { kotlin {
@ -33,6 +38,11 @@ sourceSets {
runtimeClasspath += main.output runtimeClasspath += main.output
srcDir file('src/smoke-test/kotlin') srcDir file('src/smoke-test/kotlin')
} }
java {
compileClasspath += main.output
runtimeClasspath += main.output
srcDir file('src/smoke-test/java')
}
} }
} }
@ -82,5 +92,5 @@ jar {
} }
publish { publish {
name = jar.baseName name jar.baseName
} }

View File

@ -0,0 +1,81 @@
package net.corda.client.rpc;
import net.corda.core.concurrent.CordaFuture;
import net.corda.client.rpc.internal.RPCClient;
import net.corda.core.contracts.Amount;
import net.corda.core.messaging.CordaRPCOps;
import net.corda.core.messaging.FlowHandle;
import net.corda.core.node.services.ServiceInfo;
import net.corda.core.utilities.OpaqueBytes;
import net.corda.flows.AbstractCashFlow;
import net.corda.flows.CashIssueFlow;
import net.corda.flows.CashPaymentFlow;
import net.corda.node.internal.Node;
import net.corda.node.services.transactions.ValidatingNotaryService;
import net.corda.nodeapi.User;
import net.corda.testing.node.NodeBasedTest;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import java.io.IOException;
import java.util.*;
import java.util.concurrent.ExecutionException;
import static kotlin.test.AssertionsKt.assertEquals;
import static net.corda.client.rpc.CordaRPCClientConfiguration.getDefault;
import static net.corda.contracts.GetBalances.getCashBalance;
import static net.corda.node.services.RPCUserServiceKt.startFlowPermission;
import static net.corda.testing.TestConstants.getALICE;
public class CordaRPCJavaClientTest extends NodeBasedTest {
private List<String> perms = Arrays.asList(startFlowPermission(CashPaymentFlow.class), startFlowPermission(CashIssueFlow.class));
private Set<String> permSet = new HashSet<>(perms);
private User rpcUser = new User("user1", "test", permSet);
private Node node;
private CordaRPCClient client;
private RPCClient.RPCConnection<CordaRPCOps> connection = null;
private CordaRPCOps rpcProxy;
private void login(String username, String password) {
connection = client.start(username, password);
rpcProxy = connection.getProxy();
}
@Before
public void setUp() throws ExecutionException, InterruptedException {
Set<ServiceInfo> services = new HashSet<>(Collections.singletonList(new ServiceInfo(ValidatingNotaryService.Companion.getType(), null)));
CordaFuture<Node> nodeFuture = startNode(getALICE().getName(), 1, services, Arrays.asList(rpcUser), Collections.emptyMap());
node = nodeFuture.get();
client = new CordaRPCClient(node.getConfiguration().getRpcAddress(), null, getDefault(), false);
}
@After
public void done() throws IOException {
connection.close();
}
@Test
public void testLogin() {
login(rpcUser.getUsername(), rpcUser.getPassword());
}
@Test
public void testCashBalances() throws NoSuchFieldException, ExecutionException, InterruptedException {
login(rpcUser.getUsername(), rpcUser.getPassword());
Amount<Currency> dollars123 = new Amount<>(123, Currency.getInstance("USD"));
FlowHandle<AbstractCashFlow.Result> flowHandle = rpcProxy.startFlowDynamic(CashIssueFlow.class,
dollars123, OpaqueBytes.of("1".getBytes()),
node.info.getLegalIdentity(), node.info.getLegalIdentity());
System.out.println("Started issuing cash, waiting on result");
flowHandle.getReturnValue().get();
Amount<Currency> balance = getCashBalance(rpcProxy, Currency.getInstance("USD"));
System.out.print("Balance: " + balance + "\n");
assertEquals(dollars123, balance, "matching");
}
}

View File

@ -1,13 +1,15 @@
package net.corda.client.rpc package net.corda.client.rpc
import net.corda.contracts.getCashBalance
import net.corda.contracts.getCashBalances
import net.corda.core.contracts.DOLLARS import net.corda.core.contracts.DOLLARS
import net.corda.core.contracts.USD
import net.corda.core.crypto.random63BitValue
import net.corda.core.flows.FlowInitiator import net.corda.core.flows.FlowInitiator
import net.corda.core.getOrThrow
import net.corda.core.messaging.* import net.corda.core.messaging.*
import net.corda.core.node.services.ServiceInfo import net.corda.core.node.services.ServiceInfo
import net.corda.core.crypto.random63BitValue
import net.corda.core.utilities.OpaqueBytes import net.corda.core.utilities.OpaqueBytes
import net.corda.testing.ALICE import net.corda.core.utilities.getOrThrow
import net.corda.flows.CashException import net.corda.flows.CashException
import net.corda.flows.CashIssueFlow import net.corda.flows.CashIssueFlow
import net.corda.flows.CashPaymentFlow import net.corda.flows.CashPaymentFlow
@ -15,13 +17,13 @@ import net.corda.node.internal.Node
import net.corda.node.services.startFlowPermission import net.corda.node.services.startFlowPermission
import net.corda.node.services.transactions.ValidatingNotaryService import net.corda.node.services.transactions.ValidatingNotaryService
import net.corda.nodeapi.User import net.corda.nodeapi.User
import net.corda.testing.ALICE
import net.corda.testing.node.NodeBasedTest import net.corda.testing.node.NodeBasedTest
import org.apache.activemq.artemis.api.core.ActiveMQSecurityException import org.apache.activemq.artemis.api.core.ActiveMQSecurityException
import org.assertj.core.api.Assertions.assertThatExceptionOfType import org.assertj.core.api.Assertions.assertThatExceptionOfType
import org.junit.After import org.junit.After
import org.junit.Before import org.junit.Before
import org.junit.Test import org.junit.Test
import java.util.*
import kotlin.test.assertEquals import kotlin.test.assertEquals
import kotlin.test.assertFalse import kotlin.test.assertFalse
import kotlin.test.assertTrue import kotlin.test.assertTrue
@ -42,7 +44,7 @@ class CordaRPCClientTest : NodeBasedTest() {
@Before @Before
fun setUp() { fun setUp() {
node = startNode(ALICE.name, rpcUsers = listOf(rpcUser), advertisedServices = setOf(ServiceInfo(ValidatingNotaryService.type))).getOrThrow() node = startNode(ALICE.name, rpcUsers = listOf(rpcUser), advertisedServices = setOf(ServiceInfo(ValidatingNotaryService.type))).getOrThrow()
client = CordaRPCClient(node.configuration.rpcAddress!!) client = CordaRPCClient(node.configuration.rpcAddress!!, initialiseSerialization = false)
} }
@After @After
@ -117,20 +119,18 @@ class CordaRPCClientTest : NodeBasedTest() {
println("Started issuing cash, waiting on result") println("Started issuing cash, waiting on result")
flowHandle.returnValue.get() flowHandle.returnValue.get()
val finishCash = proxy.getCashBalances() val cashDollars = proxy.getCashBalance(USD)
println("Cash Balances: $finishCash") println("Balance: $cashDollars")
assertEquals(1, finishCash.size) assertEquals(123.DOLLARS, cashDollars)
assertEquals(123.DOLLARS, finishCash.get(Currency.getInstance("USD")))
} }
@Test @Test
fun `flow initiator via RPC`() { fun `flow initiator via RPC`() {
login(rpcUser.username, rpcUser.password) login(rpcUser.username, rpcUser.password)
val proxy = connection!!.proxy val proxy = connection!!.proxy
val smUpdates = proxy.stateMachinesAndUpdates()
var countRpcFlows = 0 var countRpcFlows = 0
var countShellFlows = 0 var countShellFlows = 0
smUpdates.second.subscribe { proxy.stateMachinesFeed().updates.subscribe {
if (it is StateMachineUpdate.Added) { if (it is StateMachineUpdate.Added) {
val initiator = it.stateMachineInfo.initiator val initiator = it.stateMachineInfo.initiator
if (initiator is FlowInitiator.RPC) if (initiator is FlowInitiator.RPC)

View File

@ -1,24 +1,15 @@
package net.corda.client.rpc package net.corda.client.rpc
import com.esotericsoftware.kryo.Kryo
import com.esotericsoftware.kryo.Serializer
import com.esotericsoftware.kryo.io.Input
import com.esotericsoftware.kryo.io.Output
import com.esotericsoftware.kryo.pool.KryoPool
import com.google.common.util.concurrent.Futures
import net.corda.client.rpc.internal.RPCClient import net.corda.client.rpc.internal.RPCClient
import net.corda.client.rpc.internal.RPCClientConfiguration import net.corda.client.rpc.internal.RPCClientConfiguration
import net.corda.core.crypto.random63BitValue import net.corda.core.crypto.random63BitValue
import net.corda.core.future import net.corda.core.internal.concurrent.fork
import net.corda.core.getOrThrow import net.corda.core.internal.concurrent.transpose
import net.corda.core.messaging.RPCOps import net.corda.core.messaging.RPCOps
import net.corda.core.millis import net.corda.core.serialization.SerializationDefaults
import net.corda.core.seconds import net.corda.core.utilities.*
import net.corda.core.utilities.NetworkHostAndPort
import net.corda.core.utilities.Try
import net.corda.node.services.messaging.RPCServerConfiguration import net.corda.node.services.messaging.RPCServerConfiguration
import net.corda.nodeapi.RPCApi import net.corda.nodeapi.RPCApi
import net.corda.nodeapi.RPCKryo
import net.corda.testing.* import net.corda.testing.*
import net.corda.testing.driver.poll import net.corda.testing.driver.poll
import org.apache.activemq.artemis.api.core.SimpleString import org.apache.activemq.artemis.api.core.SimpleString
@ -29,10 +20,7 @@ import rx.Observable
import rx.subjects.PublishSubject import rx.subjects.PublishSubject
import rx.subjects.UnicastSubject import rx.subjects.UnicastSubject
import java.time.Duration import java.time.Duration
import java.util.concurrent.ConcurrentLinkedQueue import java.util.concurrent.*
import java.util.concurrent.Executors
import java.util.concurrent.ScheduledExecutorService
import java.util.concurrent.TimeUnit
import java.util.concurrent.atomic.AtomicInteger import java.util.concurrent.atomic.AtomicInteger
class RPCStabilityTests { class RPCStabilityTests {
@ -238,9 +226,7 @@ class RPCStabilityTests {
assertEquals("pong", client.ping()) assertEquals("pong", client.ping())
serverFollower.shutdown() serverFollower.shutdown()
startRpcServer<ReconnectOps>(ops = ops, customPort = serverPort).getOrThrow() startRpcServer<ReconnectOps>(ops = ops, customPort = serverPort).getOrThrow()
val pingFuture = future { val pingFuture = ForkJoinPool.commonPool().fork(client::ping)
client.ping()
}
assertEquals("pong", pingFuture.getOrThrow(10.seconds)) assertEquals("pong", pingFuture.getOrThrow(10.seconds))
clientFollower.shutdown() // Driver would do this after the new server, causing hang. clientFollower.shutdown() // Driver would do this after the new server, causing hang.
} }
@ -274,9 +260,9 @@ class RPCStabilityTests {
).get() ).get()
val numberOfClients = 4 val numberOfClients = 4
val clients = Futures.allAsList((1 .. numberOfClients).map { val clients = (1 .. numberOfClients).map {
startRandomRpcClient<TrackSubscriberOps>(server.broker.hostAndPort!!) startRandomRpcClient<TrackSubscriberOps>(server.broker.hostAndPort!!)
}).get() }.transpose().get()
// Poll until all clients connect // Poll until all clients connect
pollUntilClientNumber(server, numberOfClients) pollUntilClientNumber(server, numberOfClients)
@ -305,16 +291,8 @@ class RPCStabilityTests {
return Observable.interval(interval.toMillis(), TimeUnit.MILLISECONDS).map { chunk } return Observable.interval(interval.toMillis(), TimeUnit.MILLISECONDS).map { chunk }
} }
} }
val dummyObservableSerialiser = object : Serializer<Observable<Any>>() {
override fun write(kryo: Kryo?, output: Output?, `object`: Observable<Any>?) {
}
override fun read(kryo: Kryo?, input: Input?, type: Class<Observable<Any>>?): Observable<Any> {
return Observable.empty()
}
}
@Test @Test
fun `slow consumers are kicked`() { fun `slow consumers are kicked`() {
val kryoPool = KryoPool.Builder { RPCKryo(dummyObservableSerialiser) }.build()
rpcDriver { rpcDriver {
val server = startRpcServer(maxBufferedBytesPerClient = 10 * 1024 * 1024, ops = SlowConsumerRPCOpsImpl()).get() val server = startRpcServer(maxBufferedBytesPerClient = 10 * 1024 * 1024, ops = SlowConsumerRPCOpsImpl()).get()
@ -339,7 +317,7 @@ class RPCStabilityTests {
methodName = SlowConsumerRPCOps::streamAtInterval.name, methodName = SlowConsumerRPCOps::streamAtInterval.name,
arguments = listOf(10.millis, 123456) arguments = listOf(10.millis, 123456)
) )
request.writeToClientMessage(kryoPool, message) request.writeToClientMessage(SerializationDefaults.RPC_SERVER_CONTEXT, message)
producer.send(message) producer.send(message)
session.commit() session.commit()

View File

@ -2,11 +2,17 @@ package net.corda.client.rpc
import net.corda.client.rpc.internal.RPCClient import net.corda.client.rpc.internal.RPCClient
import net.corda.client.rpc.internal.RPCClientConfiguration import net.corda.client.rpc.internal.RPCClientConfiguration
import net.corda.client.rpc.serialization.KryoClientSerializationScheme
import net.corda.core.messaging.CordaRPCOps import net.corda.core.messaging.CordaRPCOps
import net.corda.core.serialization.SerializationDefaults
import net.corda.core.utilities.NetworkHostAndPort import net.corda.core.utilities.NetworkHostAndPort
import net.corda.nodeapi.ArtemisTcpTransport.Companion.tcpTransport import net.corda.nodeapi.ArtemisTcpTransport.Companion.tcpTransport
import net.corda.nodeapi.ConnectionDirection import net.corda.nodeapi.ConnectionDirection
import net.corda.nodeapi.config.SSLConfiguration import net.corda.nodeapi.config.SSLConfiguration
import net.corda.nodeapi.internal.serialization.AMQPClientSerializationScheme
import net.corda.nodeapi.internal.serialization.KRYO_P2P_CONTEXT
import net.corda.nodeapi.internal.serialization.KRYO_RPC_CLIENT_CONTEXT
import net.corda.nodeapi.internal.serialization.SerializationFactoryImpl
import java.time.Duration import java.time.Duration
/** @see RPCClient.RPCConnection */ /** @see RPCClient.RPCConnection */
@ -35,11 +41,22 @@ data class CordaRPCClientConfiguration(
class CordaRPCClient( class CordaRPCClient(
hostAndPort: NetworkHostAndPort, hostAndPort: NetworkHostAndPort,
sslConfiguration: SSLConfiguration? = null, sslConfiguration: SSLConfiguration? = null,
configuration: CordaRPCClientConfiguration = CordaRPCClientConfiguration.default configuration: CordaRPCClientConfiguration = CordaRPCClientConfiguration.default,
initialiseSerialization: Boolean = true
) { ) {
init {
// Init serialization. It's plausible there are multiple clients in a single JVM, so be tolerant of
// others having registered first.
// TODO: allow clients to have serialization factory etc injected and align with RPC protocol version?
if (initialiseSerialization) {
initialiseSerialization()
}
}
private val rpcClient = RPCClient<CordaRPCOps>( private val rpcClient = RPCClient<CordaRPCOps>(
tcpTransport(ConnectionDirection.Outbound(), hostAndPort, sslConfiguration), tcpTransport(ConnectionDirection.Outbound(), hostAndPort, sslConfiguration),
configuration.toRpcClientConfiguration() configuration.toRpcClientConfiguration(),
KRYO_RPC_CLIENT_CONTEXT
) )
fun start(username: String, password: String): CordaRPCConnection { fun start(username: String, password: String): CordaRPCConnection {
@ -49,4 +66,21 @@ class CordaRPCClient(
inline fun <A> use(username: String, password: String, block: (CordaRPCConnection) -> A): A { inline fun <A> use(username: String, password: String, block: (CordaRPCConnection) -> A): A {
return start(username, password).use(block) return start(username, password).use(block)
} }
companion object {
fun initialiseSerialization() {
try {
SerializationDefaults.SERIALIZATION_FACTORY = SerializationFactoryImpl().apply {
registerScheme(KryoClientSerializationScheme(this))
registerScheme(AMQPClientSerializationScheme())
}
SerializationDefaults.P2P_CONTEXT = KRYO_P2P_CONTEXT
SerializationDefaults.RPC_CLIENT_CONTEXT = KRYO_RPC_CLIENT_CONTEXT
} catch(e: IllegalStateException) {
// Check that it's registered as we expect
check(SerializationDefaults.SERIALIZATION_FACTORY is SerializationFactoryImpl) { "RPC client encountered conflicting configuration of serialization subsystem." }
check((SerializationDefaults.SERIALIZATION_FACTORY as SerializationFactoryImpl).alreadyRegisteredSchemes.any { it is KryoClientSerializationScheme }) { "RPC client encountered conflicting configuration of serialization subsystem." }
}
}
}
} }

View File

@ -1,10 +1,12 @@
package net.corda.client.rpc.internal package net.corda.client.rpc.internal
import net.corda.core.logElapsedTime
import net.corda.core.messaging.RPCOps
import net.corda.core.minutes
import net.corda.core.crypto.random63BitValue import net.corda.core.crypto.random63BitValue
import net.corda.core.seconds import net.corda.core.internal.logElapsedTime
import net.corda.core.messaging.RPCOps
import net.corda.core.serialization.SerializationContext
import net.corda.core.serialization.SerializationDefaults
import net.corda.core.utilities.minutes
import net.corda.core.utilities.seconds
import net.corda.core.utilities.NetworkHostAndPort import net.corda.core.utilities.NetworkHostAndPort
import net.corda.core.utilities.loggerFor import net.corda.core.utilities.loggerFor
import net.corda.nodeapi.ArtemisTcpTransport.Companion.tcpTransport import net.corda.nodeapi.ArtemisTcpTransport.Companion.tcpTransport
@ -85,13 +87,15 @@ data class RPCClientConfiguration(
*/ */
class RPCClient<I : RPCOps>( class RPCClient<I : RPCOps>(
val transport: TransportConfiguration, val transport: TransportConfiguration,
val rpcConfiguration: RPCClientConfiguration = RPCClientConfiguration.default val rpcConfiguration: RPCClientConfiguration = RPCClientConfiguration.default,
val serializationContext: SerializationContext = SerializationDefaults.RPC_CLIENT_CONTEXT
) { ) {
constructor( constructor(
hostAndPort: NetworkHostAndPort, hostAndPort: NetworkHostAndPort,
sslConfiguration: SSLConfiguration? = null, sslConfiguration: SSLConfiguration? = null,
configuration: RPCClientConfiguration = RPCClientConfiguration.default configuration: RPCClientConfiguration = RPCClientConfiguration.default,
) : this(tcpTransport(ConnectionDirection.Outbound(), hostAndPort, sslConfiguration), configuration) serializationContext: SerializationContext = SerializationDefaults.RPC_CLIENT_CONTEXT
) : this(tcpTransport(ConnectionDirection.Outbound(), hostAndPort, sslConfiguration), configuration, serializationContext)
companion object { companion object {
private val log = loggerFor<RPCClient<*>>() private val log = loggerFor<RPCClient<*>>()
@ -146,7 +150,7 @@ class RPCClient<I : RPCOps>(
minLargeMessageSize = rpcConfiguration.maxFileSize minLargeMessageSize = rpcConfiguration.maxFileSize
} }
val proxyHandler = RPCClientProxyHandler(rpcConfiguration, username, password, serverLocator, clientAddress, rpcOpsClass) val proxyHandler = RPCClientProxyHandler(rpcConfiguration, username, password, serverLocator, clientAddress, rpcOpsClass, serializationContext)
try { try {
proxyHandler.start() proxyHandler.start()

View File

@ -4,18 +4,19 @@ import com.esotericsoftware.kryo.Kryo
import com.esotericsoftware.kryo.Serializer import com.esotericsoftware.kryo.Serializer
import com.esotericsoftware.kryo.io.Input import com.esotericsoftware.kryo.io.Input
import com.esotericsoftware.kryo.io.Output import com.esotericsoftware.kryo.io.Output
import com.esotericsoftware.kryo.pool.KryoPool
import com.google.common.cache.Cache import com.google.common.cache.Cache
import com.google.common.cache.CacheBuilder import com.google.common.cache.CacheBuilder
import com.google.common.cache.RemovalCause import com.google.common.cache.RemovalCause
import com.google.common.cache.RemovalListener import com.google.common.cache.RemovalListener
import com.google.common.util.concurrent.SettableFuture import com.google.common.util.concurrent.SettableFuture
import com.google.common.util.concurrent.ThreadFactoryBuilder import com.google.common.util.concurrent.ThreadFactoryBuilder
import net.corda.core.ThreadBox import net.corda.core.internal.ThreadBox
import net.corda.core.crypto.random63BitValue import net.corda.core.crypto.random63BitValue
import net.corda.core.getOrThrow import net.corda.core.internal.LazyPool
import net.corda.core.internal.LazyStickyPool
import net.corda.core.internal.LifeCycle
import net.corda.core.messaging.RPCOps import net.corda.core.messaging.RPCOps
import net.corda.core.serialization.KryoPoolWithContext import net.corda.core.serialization.SerializationContext
import net.corda.core.utilities.* import net.corda.core.utilities.*
import net.corda.nodeapi.* import net.corda.nodeapi.*
import org.apache.activemq.artemis.api.core.SimpleString import org.apache.activemq.artemis.api.core.SimpleString
@ -61,7 +62,8 @@ class RPCClientProxyHandler(
private val rpcPassword: String, private val rpcPassword: String,
private val serverLocator: ServerLocator, private val serverLocator: ServerLocator,
private val clientAddress: SimpleString, private val clientAddress: SimpleString,
private val rpcOpsClass: Class<out RPCOps> private val rpcOpsClass: Class<out RPCOps>,
serializationContext: SerializationContext
) : InvocationHandler { ) : InvocationHandler {
private enum class State { private enum class State {
@ -74,9 +76,6 @@ class RPCClientProxyHandler(
private companion object { private companion object {
val log = loggerFor<RPCClientProxyHandler>() val log = loggerFor<RPCClientProxyHandler>()
// Note that this KryoPool is not yet capable of deserialising Observables, it requires Proxy-specific context
// to do that. However it may still be used for serialisation of RPC requests and related messages.
val kryoPool: KryoPool = KryoPool.Builder { RPCKryo(RpcClientObservableSerializer) }.build()
// To check whether toString() is being invoked // To check whether toString() is being invoked
val toStringMethod: Method = Object::toString.javaMethod!! val toStringMethod: Method = Object::toString.javaMethod!!
} }
@ -85,7 +84,7 @@ class RPCClientProxyHandler(
private var reaperExecutor: ScheduledExecutorService? = null private var reaperExecutor: ScheduledExecutorService? = null
// A sticky pool for running Observable.onNext()s. We need the stickiness to preserve the observation ordering. // A sticky pool for running Observable.onNext()s. We need the stickiness to preserve the observation ordering.
private val observationExecutorThreadFactory = ThreadFactoryBuilder().setNameFormat("rpc-client-observation-pool-%d").build() private val observationExecutorThreadFactory = ThreadFactoryBuilder().setNameFormat("rpc-client-observation-pool-%d").setDaemon(true).build()
private val observationExecutorPool = LazyStickyPool(rpcConfiguration.observationExecutorPoolSize) { private val observationExecutorPool = LazyStickyPool(rpcConfiguration.observationExecutorPoolSize) {
Executors.newFixedThreadPool(1, observationExecutorThreadFactory) Executors.newFixedThreadPool(1, observationExecutorThreadFactory)
} }
@ -109,11 +108,10 @@ class RPCClientProxyHandler(
private val observablesToReap = ThreadBox(object { private val observablesToReap = ThreadBox(object {
var observables = ArrayList<RPCApi.ObservableId>() var observables = ArrayList<RPCApi.ObservableId>()
}) })
// A Kryo pool that automatically adds the observable context when an instance is requested. private val serializationContextWithObservableContext = RpcClientObservableSerializer.createContext(serializationContext, observableContext)
private val kryoPoolWithObservableContext = RpcClientObservableSerializer.createPoolWithContext(kryoPool, observableContext)
private fun createRpcObservableMap(): RpcObservableMap { private fun createRpcObservableMap(): RpcObservableMap {
val onObservableRemove = RemovalListener<RPCApi.ObservableId, UnicastSubject<Notification<Any>>> { val onObservableRemove = RemovalListener<RPCApi.ObservableId, UnicastSubject<Notification<*>>> {
val rpcCallSite = callSiteMap?.remove(it.key.toLong) val rpcCallSite = callSiteMap?.remove(it.key.toLong)
if (it.cause == RemovalCause.COLLECTED) { if (it.cause == RemovalCause.COLLECTED) {
log.warn(listOf( log.warn(listOf(
@ -194,7 +192,7 @@ class RPCClientProxyHandler(
val replyFuture = SettableFuture.create<Any>() val replyFuture = SettableFuture.create<Any>()
sessionAndProducerPool.run { sessionAndProducerPool.run {
val message = it.session.createMessage(false) val message = it.session.createMessage(false)
request.writeToClientMessage(kryoPool, message) request.writeToClientMessage(serializationContextWithObservableContext, message)
log.debug { log.debug {
val argumentsString = arguments?.joinToString() ?: "" val argumentsString = arguments?.joinToString() ?: ""
@ -221,7 +219,7 @@ class RPCClientProxyHandler(
// The handler for Artemis messages. // The handler for Artemis messages.
private fun artemisMessageHandler(message: ClientMessage) { private fun artemisMessageHandler(message: ClientMessage) {
val serverToClient = RPCApi.ServerToClient.fromClientMessage(kryoPoolWithObservableContext, message) val serverToClient = RPCApi.ServerToClient.fromClientMessage(serializationContextWithObservableContext, message)
log.debug { "Got message from RPC server $serverToClient" } log.debug { "Got message from RPC server $serverToClient" }
when (serverToClient) { when (serverToClient) {
is RPCApi.ServerToClient.RpcReply -> { is RPCApi.ServerToClient.RpcReply -> {
@ -338,7 +336,7 @@ class RPCClientProxyHandler(
} }
} }
private typealias RpcObservableMap = Cache<RPCApi.ObservableId, UnicastSubject<Notification<Any>>> private typealias RpcObservableMap = Cache<RPCApi.ObservableId, UnicastSubject<Notification<*>>>
private typealias RpcReplyMap = ConcurrentHashMap<RPCApi.RpcRequestId, SettableFuture<Any?>> private typealias RpcReplyMap = ConcurrentHashMap<RPCApi.RpcRequestId, SettableFuture<Any?>>
private typealias CallSiteMap = ConcurrentHashMap<Long, Throwable?> private typealias CallSiteMap = ConcurrentHashMap<Long, Throwable?>
@ -348,7 +346,7 @@ private typealias CallSiteMap = ConcurrentHashMap<Long, Throwable?>
* @param observableMap holds the Observables that are ultimately exposed to the user. * @param observableMap holds the Observables that are ultimately exposed to the user.
* @param hardReferenceStore holds references to Observables we want to keep alive while they are subscribed to. * @param hardReferenceStore holds references to Observables we want to keep alive while they are subscribed to.
*/ */
private data class ObservableContext( data class ObservableContext(
val callSiteMap: CallSiteMap?, val callSiteMap: CallSiteMap?,
val observableMap: RpcObservableMap, val observableMap: RpcObservableMap,
val hardReferenceStore: MutableSet<Observable<*>> val hardReferenceStore: MutableSet<Observable<*>>
@ -357,17 +355,17 @@ private data class ObservableContext(
/** /**
* A [Serializer] to deserialise Observables once the corresponding Kryo instance has been provided with an [ObservableContext]. * A [Serializer] to deserialise Observables once the corresponding Kryo instance has been provided with an [ObservableContext].
*/ */
private object RpcClientObservableSerializer : Serializer<Observable<Any>>() { object RpcClientObservableSerializer : Serializer<Observable<*>>() {
private object RpcObservableContextKey private object RpcObservableContextKey
fun createPoolWithContext(kryoPool: KryoPool, observableContext: ObservableContext): KryoPool {
return KryoPoolWithContext(kryoPool, RpcObservableContextKey, observableContext) fun createContext(serializationContext: SerializationContext, observableContext: ObservableContext): SerializationContext {
return serializationContext.withProperty(RpcObservableContextKey, observableContext)
} }
override fun read(kryo: Kryo, input: Input, type: Class<Observable<Any>>): Observable<Any> { override fun read(kryo: Kryo, input: Input, type: Class<Observable<*>>): Observable<Any> {
@Suppress("UNCHECKED_CAST")
val observableContext = kryo.context[RpcObservableContextKey] as ObservableContext val observableContext = kryo.context[RpcObservableContextKey] as ObservableContext
val observableId = RPCApi.ObservableId(input.readLong(true)) val observableId = RPCApi.ObservableId(input.readLong(true))
val observable = UnicastSubject.create<Notification<Any>>() val observable = UnicastSubject.create<Notification<*>>()
require(observableContext.observableMap.getIfPresent(observableId) == null) { require(observableContext.observableMap.getIfPresent(observableId) == null) {
"Multiple Observables arrived with the same ID $observableId" "Multiple Observables arrived with the same ID $observableId"
} }
@ -384,7 +382,7 @@ private object RpcClientObservableSerializer : Serializer<Observable<Any>>() {
}.dematerialize() }.dematerialize()
} }
override fun write(kryo: Kryo, output: Output, observable: Observable<Any>) { override fun write(kryo: Kryo, output: Output, observable: Observable<*>) {
throw UnsupportedOperationException("Cannot serialise Observables on the client side") throw UnsupportedOperationException("Cannot serialise Observables on the client side")
} }

View File

@ -0,0 +1,28 @@
package net.corda.client.rpc.serialization
import com.esotericsoftware.kryo.pool.KryoPool
import net.corda.client.rpc.internal.RpcClientObservableSerializer
import net.corda.core.serialization.SerializationContext
import net.corda.core.serialization.SerializationFactory
import net.corda.core.utilities.ByteSequence
import net.corda.nodeapi.RPCKryo
import net.corda.nodeapi.internal.serialization.AbstractKryoSerializationScheme
import net.corda.nodeapi.internal.serialization.DefaultKryoCustomizer
import net.corda.nodeapi.internal.serialization.KryoHeaderV0_1
class KryoClientSerializationScheme(serializationFactory: SerializationFactory) : AbstractKryoSerializationScheme(serializationFactory) {
override fun canDeserializeVersion(byteSequence: ByteSequence, target: SerializationContext.UseCase): Boolean {
return byteSequence == KryoHeaderV0_1 && (target == SerializationContext.UseCase.RPCClient || target == SerializationContext.UseCase.P2P)
}
override fun rpcClientKryoPool(context: SerializationContext): KryoPool {
return KryoPool.Builder {
DefaultKryoCustomizer.customize(RPCKryo(RpcClientObservableSerializer, serializationFactory, context)).apply { classLoader = context.deserializationClassLoader }
}.build()
}
// We're on the client and don't have access to server classes.
override fun rpcServerKryoPool(context: SerializationContext): KryoPool {
throw UnsupportedOperationException()
}
}

View File

@ -0,0 +1,87 @@
package net.corda.java.rpc;
import net.corda.client.rpc.CordaRPCConnection;
import net.corda.core.contracts.Amount;
import net.corda.core.messaging.CordaRPCOps;
import net.corda.core.messaging.DataFeed;
import net.corda.core.messaging.FlowHandle;
import net.corda.core.node.NodeInfo;
import net.corda.core.node.services.NetworkMapCache;
import net.corda.core.utilities.OpaqueBytes;
import net.corda.flows.AbstractCashFlow;
import net.corda.flows.CashIssueFlow;
import net.corda.nodeapi.User;
import net.corda.smoketesting.NodeConfig;
import net.corda.smoketesting.NodeProcess;
import org.bouncycastle.asn1.x500.X500Name;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import java.util.*;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.atomic.AtomicInteger;
import static kotlin.test.AssertionsKt.assertEquals;
import static net.corda.contracts.GetBalances.getCashBalance;
public class StandaloneCordaRPCJavaClientTest {
private List<String> perms = Collections.singletonList("ALL");
private Set<String> permSet = new HashSet<>(perms);
private User rpcUser = new User("user1", "test", permSet);
private AtomicInteger port = new AtomicInteger(15000);
private NodeProcess notary;
private CordaRPCOps rpcProxy;
private CordaRPCConnection connection;
private NodeInfo notaryNode;
private NodeConfig notaryConfig = new NodeConfig(
new X500Name("CN=Notary Service,O=R3,OU=corda,L=Zurich,C=CH"),
port.getAndIncrement(),
port.getAndIncrement(),
port.getAndIncrement(),
Collections.singletonList("corda.notary.validating"),
Arrays.asList(rpcUser),
null
);
@Before
public void setUp() {
notary = new NodeProcess.Factory().create(notaryConfig);
connection = notary.connect();
rpcProxy = connection.getProxy();
notaryNode = fetchNotaryIdentity();
}
@After
public void done() {
try {
connection.close();
} finally {
notary.close();
}
}
private NodeInfo fetchNotaryIdentity() {
DataFeed<List<NodeInfo>, NetworkMapCache.MapChange> nodeDataFeed = rpcProxy.networkMapFeed();
return nodeDataFeed.getSnapshot().get(0);
}
@Test
public void testCashBalances() throws NoSuchFieldException, ExecutionException, InterruptedException {
Amount<Currency> dollars123 = new Amount<>(123, Currency.getInstance("USD"));
FlowHandle<AbstractCashFlow.Result> flowHandle = rpcProxy.startFlowDynamic(CashIssueFlow.class,
dollars123, OpaqueBytes.of("1".getBytes()),
notaryNode.getLegalIdentity(), notaryNode.getLegalIdentity());
System.out.println("Started issuing cash, waiting on result");
flowHandle.getReturnValue().get();
Amount<Currency> balance = getCashBalance(rpcProxy, Currency.getInstance("USD"));
System.out.print("Balance: " + balance + "\n");
assertEquals(dollars123, balance, "matching");
}
}

View File

@ -5,19 +5,19 @@ import com.google.common.hash.HashingInputStream
import net.corda.client.rpc.CordaRPCConnection import net.corda.client.rpc.CordaRPCConnection
import net.corda.client.rpc.notUsed import net.corda.client.rpc.notUsed
import net.corda.contracts.asset.Cash import net.corda.contracts.asset.Cash
import net.corda.core.contracts.DOLLARS import net.corda.contracts.getCashBalance
import net.corda.core.contracts.POUNDS import net.corda.contracts.getCashBalances
import net.corda.core.contracts.SWISS_FRANCS import net.corda.core.contracts.*
import net.corda.core.crypto.SecureHash import net.corda.core.crypto.SecureHash
import net.corda.core.getOrThrow import net.corda.core.internal.InputStreamAndHash
import net.corda.core.messaging.* import net.corda.core.messaging.*
import net.corda.core.node.NodeInfo import net.corda.core.node.NodeInfo
import net.corda.core.node.services.Vault import net.corda.core.node.services.Vault
import net.corda.core.node.services.vault.* import net.corda.core.node.services.vault.*
import net.corda.core.seconds
import net.corda.core.utilities.OpaqueBytes import net.corda.core.utilities.OpaqueBytes
import net.corda.core.sizedInputStreamAndHash import net.corda.core.utilities.getOrThrow
import net.corda.core.utilities.loggerFor import net.corda.core.utilities.loggerFor
import net.corda.core.utilities.seconds
import net.corda.flows.CashIssueFlow import net.corda.flows.CashIssueFlow
import net.corda.flows.CashPaymentFlow import net.corda.flows.CashPaymentFlow
import net.corda.nodeapi.User import net.corda.nodeapi.User
@ -35,6 +35,7 @@ import java.util.concurrent.atomic.AtomicInteger
import kotlin.test.assertEquals import kotlin.test.assertEquals
import kotlin.test.assertFalse import kotlin.test.assertFalse
import kotlin.test.assertNotEquals import kotlin.test.assertNotEquals
import kotlin.test.assertTrue
class StandaloneCordaRPClientTest { class StandaloneCordaRPClientTest {
private companion object { private companion object {
@ -78,7 +79,7 @@ class StandaloneCordaRPClientTest {
@Test @Test
fun `test attachments`() { fun `test attachments`() {
val attachment = sizedInputStreamAndHash(attachmentSize) val attachment = InputStreamAndHash.createInMemoryTestZip(attachmentSize, 1)
assertFalse(rpcProxy.attachmentExists(attachment.sha256)) assertFalse(rpcProxy.attachmentExists(attachment.sha256))
val id = WrapperStream(attachment.inputStream).use { rpcProxy.uploadAttachment(it) } val id = WrapperStream(attachment.inputStream).use { rpcProxy.uploadAttachment(it) }
assertEquals(attachment.sha256, id, "Attachment has incorrect SHA256 hash") assertEquals(attachment.sha256, id, "Attachment has incorrect SHA256 hash")
@ -117,38 +118,38 @@ class StandaloneCordaRPClientTest {
@Test @Test
fun `test state machines`() { fun `test state machines`() {
val (stateMachines, updates) = rpcProxy.stateMachinesAndUpdates() val (stateMachines, updates) = rpcProxy.stateMachinesFeed()
assertEquals(0, stateMachines.size) assertEquals(0, stateMachines.size)
var updateCount = 0 val updateCount = AtomicInteger(0)
updates.subscribe { update -> updates.subscribe { update ->
if (update is StateMachineUpdate.Added) { if (update is StateMachineUpdate.Added) {
log.info("StateMachine>> Id=${update.id}") log.info("StateMachine>> Id=${update.id}")
++updateCount updateCount.incrementAndGet()
} }
} }
// Now issue some cash // Now issue some cash
rpcProxy.startFlow(::CashIssueFlow, 513.SWISS_FRANCS, OpaqueBytes.of(0), notaryNode.legalIdentity, notaryNode.notaryIdentity) rpcProxy.startFlow(::CashIssueFlow, 513.SWISS_FRANCS, OpaqueBytes.of(0), notaryNode.legalIdentity, notaryNode.notaryIdentity)
.returnValue.getOrThrow(timeout) .returnValue.getOrThrow(timeout)
assertEquals(1, updateCount) assertEquals(1, updateCount.get())
} }
@Test @Test
fun `test vault track by`() { fun `test vault track by`() {
val (vault, vaultUpdates) = rpcProxy.vaultTrackBy<Cash.State>() val (vault, vaultUpdates) = rpcProxy.vaultTrackBy<Cash.State>(paging = PageSpecification(DEFAULT_PAGE_NUM))
assertEquals(0, vault.states.size) assertEquals(0, vault.totalStatesAvailable)
var updateCount = 0 val updateCount = AtomicInteger(0)
vaultUpdates.subscribe { update -> vaultUpdates.subscribe { update ->
log.info("Vault>> FlowId=${update.flowId}") log.info("Vault>> FlowId=${update.flowId}")
++updateCount updateCount.incrementAndGet()
} }
// Now issue some cash // Now issue some cash
rpcProxy.startFlow(::CashIssueFlow, 629.POUNDS, OpaqueBytes.of(0), notaryNode.legalIdentity, notaryNode.notaryIdentity) rpcProxy.startFlow(::CashIssueFlow, 629.POUNDS, OpaqueBytes.of(0), notaryNode.legalIdentity, notaryNode.notaryIdentity)
.returnValue.getOrThrow(timeout) .returnValue.getOrThrow(timeout)
assertNotEquals(0, updateCount) assertNotEquals(0, updateCount.get())
// Check that this cash exists in the vault // Check that this cash exists in the vault
val cashBalance = rpcProxy.getCashBalances() val cashBalance = rpcProxy.getCashBalances()
@ -177,10 +178,27 @@ class StandaloneCordaRPClientTest {
assertEquals(3, moreResults.totalStatesAvailable) // 629 - 100 + 100 assertEquals(3, moreResults.totalStatesAvailable) // 629 - 100 + 100
// Check that this cash exists in the vault // Check that this cash exists in the vault
val cashBalance = rpcProxy.getCashBalances() val cashBalances = rpcProxy.getCashBalances()
log.info("Cash Balances: $cashBalance") log.info("Cash Balances: $cashBalances")
assertEquals(1, cashBalance.size) assertEquals(1, cashBalances.size)
assertEquals(629.POUNDS, cashBalance[Currency.getInstance("GBP")]) assertEquals(629.POUNDS, cashBalances[Currency.getInstance("GBP")])
}
@Test
fun `test cash balances`() {
val startCash = rpcProxy.getCashBalances()
assertTrue(startCash.isEmpty(), "Should not start with any cash")
val flowHandle = rpcProxy.startFlow(::CashIssueFlow,
629.DOLLARS, OpaqueBytes.of(0),
notaryNode.legalIdentity, notaryNode.legalIdentity
)
println("Started issuing cash, waiting on result")
flowHandle.returnValue.get()
val balance = rpcProxy.getCashBalance(USD)
println("Balance: " + balance)
assertEquals(629.DOLLARS, balance)
} }
private fun fetchNotaryIdentity(): NodeInfo { private fun fetchNotaryIdentity(): NodeInfo {

View File

@ -1,6 +1,6 @@
package net.corda.kotlin.rpc package net.corda.kotlin.rpc
import net.corda.core.div import net.corda.core.internal.div
import org.junit.Test import org.junit.Test
import java.io.File import java.io.File
import java.nio.file.Path import java.nio.file.Path

View File

@ -1,8 +1,8 @@
package net.corda.client.rpc package net.corda.client.rpc
import net.corda.client.rpc.internal.RPCClientConfiguration import net.corda.client.rpc.internal.RPCClientConfiguration
import net.corda.core.flatMap import net.corda.core.internal.concurrent.flatMap
import net.corda.core.map import net.corda.core.internal.concurrent.map
import net.corda.core.messaging.RPCOps import net.corda.core.messaging.RPCOps
import net.corda.node.services.messaging.RPCServerConfiguration import net.corda.node.services.messaging.RPCServerConfiguration
import net.corda.nodeapi.User import net.corda.nodeapi.User
@ -44,13 +44,13 @@ open class AbstractRPCTest {
startInVmRpcClient<I>(rpcUser.username, rpcUser.password, clientConfiguration).map { startInVmRpcClient<I>(rpcUser.username, rpcUser.password, clientConfiguration).map {
TestProxy(it, { startInVmArtemisSession(rpcUser.username, rpcUser.password) }) TestProxy(it, { startInVmArtemisSession(rpcUser.username, rpcUser.password) })
} }
}.get() }
RPCTestMode.Netty -> RPCTestMode.Netty ->
startRpcServer(ops = ops, rpcUser = rpcUser, configuration = serverConfiguration).flatMap { server -> startRpcServer(ops = ops, rpcUser = rpcUser, configuration = serverConfiguration).flatMap { server ->
startRpcClient<I>(server.broker.hostAndPort!!, rpcUser.username, rpcUser.password, clientConfiguration).map { startRpcClient<I>(server.broker.hostAndPort!!, rpcUser.username, rpcUser.password, clientConfiguration).map {
TestProxy(it, { startArtemisSession(server.broker.hostAndPort!!, rpcUser.username, rpcUser.password) }) TestProxy(it, { startArtemisSession(server.broker.hostAndPort!!, rpcUser.username, rpcUser.password) })
} }
}
}.get() }.get()
} }
}
} }

View File

@ -1,11 +1,11 @@
package net.corda.client.rpc package net.corda.client.rpc
import com.google.common.util.concurrent.Futures import net.corda.core.concurrent.CordaFuture
import com.google.common.util.concurrent.ListenableFuture import net.corda.core.internal.concurrent.doneFuture
import com.google.common.util.concurrent.SettableFuture import net.corda.core.internal.concurrent.openFuture
import net.corda.core.getOrThrow import net.corda.core.internal.concurrent.thenMatch
import net.corda.core.messaging.RPCOps import net.corda.core.messaging.RPCOps
import net.corda.core.thenMatch import net.corda.core.utilities.getOrThrow
import net.corda.node.services.messaging.getRpcContext import net.corda.node.services.messaging.getRpcContext
import net.corda.nodeapi.RPCSinceVersion import net.corda.nodeapi.RPCSinceVersion
import net.corda.testing.RPCDriverExposedDSLInterface import net.corda.testing.RPCDriverExposedDSLInterface
@ -27,7 +27,9 @@ import kotlin.test.assertTrue
class ClientRPCInfrastructureTests : AbstractRPCTest() { class ClientRPCInfrastructureTests : AbstractRPCTest() {
// TODO: Test that timeouts work // TODO: Test that timeouts work
private fun RPCDriverExposedDSLInterface.testProxy() = testProxy<TestOps>(TestOpsImpl()).ops private fun RPCDriverExposedDSLInterface.testProxy(): TestOps {
return testProxy<TestOps>(TestOpsImpl()).ops
}
interface TestOps : RPCOps { interface TestOps : RPCOps {
@Throws(IllegalArgumentException::class) @Throws(IllegalArgumentException::class)
@ -41,9 +43,9 @@ class ClientRPCInfrastructureTests : AbstractRPCTest() {
fun makeComplicatedObservable(): Observable<Pair<String, Observable<String>>> fun makeComplicatedObservable(): Observable<Pair<String, Observable<String>>>
fun makeListenableFuture(): ListenableFuture<Int> fun makeListenableFuture(): CordaFuture<Int>
fun makeComplicatedListenableFuture(): ListenableFuture<Pair<String, ListenableFuture<String>>> fun makeComplicatedListenableFuture(): CordaFuture<Pair<String, CordaFuture<String>>>
@RPCSinceVersion(2) @RPCSinceVersion(2)
fun addedLater() fun addedLater()
@ -52,7 +54,7 @@ class ClientRPCInfrastructureTests : AbstractRPCTest() {
} }
private lateinit var complicatedObservable: Observable<Pair<String, Observable<String>>> private lateinit var complicatedObservable: Observable<Pair<String, Observable<String>>>
private lateinit var complicatedListenableFuturee: ListenableFuture<Pair<String, ListenableFuture<String>>> private lateinit var complicatedListenableFuturee: CordaFuture<Pair<String, CordaFuture<String>>>
inner class TestOpsImpl : TestOps { inner class TestOpsImpl : TestOps {
override val protocolVersion = 1 override val protocolVersion = 1
@ -60,9 +62,9 @@ class ClientRPCInfrastructureTests : AbstractRPCTest() {
override fun void() {} override fun void() {}
override fun someCalculation(str: String, num: Int) = "$str $num" override fun someCalculation(str: String, num: Int) = "$str $num"
override fun makeObservable(): Observable<Int> = Observable.just(1, 2, 3, 4) override fun makeObservable(): Observable<Int> = Observable.just(1, 2, 3, 4)
override fun makeListenableFuture(): ListenableFuture<Int> = Futures.immediateFuture(1) override fun makeListenableFuture() = doneFuture(1)
override fun makeComplicatedObservable() = complicatedObservable override fun makeComplicatedObservable() = complicatedObservable
override fun makeComplicatedListenableFuture(): ListenableFuture<Pair<String, ListenableFuture<String>>> = complicatedListenableFuturee override fun makeComplicatedListenableFuture() = complicatedListenableFuturee
override fun addedLater(): Unit = throw IllegalStateException() override fun addedLater(): Unit = throw IllegalStateException()
override fun captureUser(): String = getRpcContext().currentUser.username override fun captureUser(): String = getRpcContext().currentUser.username
} }
@ -150,10 +152,10 @@ class ClientRPCInfrastructureTests : AbstractRPCTest() {
fun `complex ListenableFuture`() { fun `complex ListenableFuture`() {
rpcDriver { rpcDriver {
val proxy = testProxy() val proxy = testProxy()
val serverQuote = SettableFuture.create<Pair<String, ListenableFuture<String>>>() val serverQuote = openFuture<Pair<String, CordaFuture<String>>>()
complicatedListenableFuturee = serverQuote complicatedListenableFuturee = serverQuote
val twainQuote = "Mark Twain" to Futures.immediateFuture("I have never let my schooling interfere with my education.") val twainQuote = "Mark Twain" to doneFuture("I have never let my schooling interfere with my education.")
val clientQuotes = LinkedBlockingQueue<String>() val clientQuotes = LinkedBlockingQueue<String>()
val clientFuture = proxy.makeComplicatedListenableFuture() val clientFuture = proxy.makeComplicatedListenableFuture()

View File

@ -1,10 +1,10 @@
package net.corda.client.rpc package net.corda.client.rpc
import net.corda.client.rpc.internal.RPCClientConfiguration import net.corda.client.rpc.internal.RPCClientConfiguration
import net.corda.core.future
import net.corda.core.messaging.RPCOps import net.corda.core.messaging.RPCOps
import net.corda.core.millis import net.corda.core.utilities.millis
import net.corda.core.crypto.random63BitValue import net.corda.core.crypto.random63BitValue
import net.corda.core.internal.concurrent.fork
import net.corda.core.serialization.CordaSerializable import net.corda.core.serialization.CordaSerializable
import net.corda.node.services.messaging.RPCServerConfiguration import net.corda.node.services.messaging.RPCServerConfiguration
import net.corda.testing.RPCDriverExposedDSLInterface import net.corda.testing.RPCDriverExposedDSLInterface
@ -17,6 +17,7 @@ import rx.subjects.UnicastSubject
import java.util.* import java.util.*
import java.util.concurrent.ConcurrentHashMap import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.CountDownLatch import java.util.concurrent.CountDownLatch
import java.util.concurrent.ForkJoinPool
@RunWith(Parameterized::class) @RunWith(Parameterized::class)
class RPCConcurrencyTests : AbstractRPCTest() { class RPCConcurrencyTests : AbstractRPCTest() {
@ -68,7 +69,7 @@ class RPCConcurrencyTests : AbstractRPCTest() {
Observable.empty<ObservableRose<Int>>() Observable.empty<ObservableRose<Int>>()
} else { } else {
val publish = UnicastSubject.create<ObservableRose<Int>>() val publish = UnicastSubject.create<ObservableRose<Int>>()
future { ForkJoinPool.commonPool().fork {
(1..branchingFactor).toList().parallelStream().forEach { (1..branchingFactor).toList().parallelStream().forEach {
publish.onNext(getParallelObservableTree(depth - 1, branchingFactor)) publish.onNext(getParallelObservableTree(depth - 1, branchingFactor))
} }
@ -105,7 +106,7 @@ class RPCConcurrencyTests : AbstractRPCTest() {
val done = CountDownLatch(numberOfBlockedCalls) val done = CountDownLatch(numberOfBlockedCalls)
// Start a couple of blocking RPC calls // Start a couple of blocking RPC calls
(1..numberOfBlockedCalls).forEach { (1..numberOfBlockedCalls).forEach {
future { ForkJoinPool.commonPool().fork {
proxy.ops.waitLatch(id) proxy.ops.waitLatch(id)
done.countDown() done.countDown()
} }

View File

@ -3,12 +3,11 @@ package net.corda.client.rpc
import com.google.common.base.Stopwatch import com.google.common.base.Stopwatch
import net.corda.client.rpc.internal.RPCClientConfiguration import net.corda.client.rpc.internal.RPCClientConfiguration
import net.corda.core.messaging.RPCOps import net.corda.core.messaging.RPCOps
import net.corda.core.minutes import net.corda.core.utilities.minutes
import net.corda.core.seconds import net.corda.core.utilities.seconds
import net.corda.core.utilities.div import net.corda.testing.performance.div
import net.corda.node.services.messaging.RPCServerConfiguration import net.corda.node.services.messaging.RPCServerConfiguration
import net.corda.testing.RPCDriverExposedDSLInterface import net.corda.testing.RPCDriverExposedDSLInterface
import net.corda.testing.driver.ShutdownManager
import net.corda.testing.measure import net.corda.testing.measure
import net.corda.testing.performance.startPublishingFixedRateInjector import net.corda.testing.performance.startPublishingFixedRateInjector
import net.corda.testing.performance.startReporter import net.corda.testing.performance.startReporter

View File

@ -1,8 +1,8 @@
package net.corda.client.rpc package net.corda.client.rpc
import net.corda.core.messaging.RPCOps import net.corda.core.messaging.RPCOps
import net.corda.node.services.messaging.requirePermission
import net.corda.node.services.messaging.getRpcContext import net.corda.node.services.messaging.getRpcContext
import net.corda.node.services.messaging.requirePermission
import net.corda.nodeapi.PermissionException import net.corda.nodeapi.PermissionException
import net.corda.nodeapi.User import net.corda.nodeapi.User
import net.corda.testing.RPCDriverExposedDSLInterface import net.corda.testing.RPCDriverExposedDSLInterface

View File

@ -1,4 +1,4 @@
gradlePluginsVersion=0.13.2 gradlePluginsVersion=0.13.6
kotlinVersion=1.1.1 kotlinVersion=1.1.1
guavaVersion=21.0 guavaVersion=21.0
bouncycastleVersion=1.57 bouncycastleVersion=1.57

View File

@ -17,3 +17,7 @@ dependencies {
// Bouncy Castle: for X.500 distinguished name manipulation // Bouncy Castle: for X.500 distinguished name manipulation
compile "org.bouncycastle:bcprov-jdk15on:$bouncycastle_version" compile "org.bouncycastle:bcprov-jdk15on:$bouncycastle_version"
} }
publish {
name project.name
}

View File

@ -4,6 +4,7 @@ import static java.util.Collections.emptyList;
import com.typesafe.config.Config; import com.typesafe.config.Config;
import com.typesafe.config.ConfigFactory; import com.typesafe.config.ConfigFactory;
import com.typesafe.config.ConfigValueFactory; import com.typesafe.config.ConfigValueFactory;
import java.util.Collections;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
@ -86,6 +87,6 @@ public class CordformNode implements NodeDefinition {
* @param id The (0-based) BFT replica ID. * @param id The (0-based) BFT replica ID.
*/ */
public void bftReplicaId(Integer id) { public void bftReplicaId(Integer id) {
config = config.withValue("bftReplicaId", ConfigValueFactory.fromAnyRef(id)); config = config.withValue("bftSMaRt", ConfigValueFactory.fromMap(Collections.singletonMap("replicaId", id)));
} }
} }

View File

@ -2,6 +2,7 @@ apply plugin: 'kotlin'
apply plugin: 'kotlin-jpa' apply plugin: 'kotlin-jpa'
apply plugin: 'net.corda.plugins.quasar-utils' apply plugin: 'net.corda.plugins.quasar-utils'
apply plugin: 'net.corda.plugins.publish-utils' apply plugin: 'net.corda.plugins.publish-utils'
apply plugin: 'com.jfrog.artifactory'
description 'Corda core' description 'Corda core'
@ -40,22 +41,12 @@ dependencies {
// AssertJ: for fluent assertions for testing // AssertJ: for fluent assertions for testing
testCompile "org.assertj:assertj-core:${assertj_version}" testCompile "org.assertj:assertj-core:${assertj_version}"
// TODO: Upgrade to junit-quickcheck 0.8, once it is released,
// because it depends on org.javassist:javassist instead
// of javassist:javassist.
testCompile "com.pholser:junit-quickcheck-core:$quickcheck_version"
testCompile "com.pholser:junit-quickcheck-generators:$quickcheck_version"
// Guava: Google utilities library. // Guava: Google utilities library.
compile "com.google.guava:guava:$guava_version" compile "com.google.guava:guava:$guava_version"
// RxJava: observable streams of events. // RxJava: observable streams of events.
compile "io.reactivex:rxjava:$rxjava_version" compile "io.reactivex:rxjava:$rxjava_version"
// Kryo: object graph serialization.
compile "com.esotericsoftware:kryo:4.0.0"
compile "de.javakaffee:kryo-serializers:0.41"
// Apache JEXL: An embeddable expression evaluation library. // Apache JEXL: An embeddable expression evaluation library.
// This may be temporary until we experiment with other ways to do on-the-fly contract specialisation via an API. // This may be temporary until we experiment with other ways to do on-the-fly contract specialisation via an API.
compile "org.apache.commons:commons-jexl3:3.0" compile "org.apache.commons:commons-jexl3:3.0"
@ -98,5 +89,5 @@ jar {
} }
publish { publish {
name = jar.baseName name jar.baseName
} }

View File

@ -0,0 +1,4 @@
/**
* WARNING: This is an internal package and not part of the public API. Do not use anything found here or in any sub-package.
*/
package net.corda.core.internal;

View File

@ -58,7 +58,7 @@ open class CordaException internal constructor(override var originalExceptionCla
} }
} }
open class CordaRuntimeException internal constructor(override var originalExceptionClassName: String?, open class CordaRuntimeException(override var originalExceptionClassName: String?,
private var _message: String? = null, private var _message: String? = null,
private var _cause: Throwable? = null) : RuntimeException(null, null, true, true), CordaThrowable { private var _cause: Throwable? = null) : RuntimeException(null, null, true, true), CordaThrowable {
constructor(message: String?, cause: Throwable?) : this(null, message, cause) constructor(message: String?, cause: Throwable?) : this(null, message, cause)

View File

@ -1,30 +0,0 @@
package net.corda.core
import java.util.*
import java.util.Spliterator.*
import java.util.stream.IntStream
import java.util.stream.Stream
import java.util.stream.StreamSupport
import kotlin.streams.asSequence
private fun IntProgression.spliteratorOfInt(): Spliterator.OfInt {
val kotlinIterator = iterator()
val javaIterator = object : PrimitiveIterator.OfInt {
override fun nextInt() = kotlinIterator.nextInt()
override fun hasNext() = kotlinIterator.hasNext()
override fun remove() = throw UnsupportedOperationException("remove")
}
val spliterator = Spliterators.spliterator(
javaIterator,
(1 + (last - first) / step).toLong(),
SUBSIZED or IMMUTABLE or NONNULL or SIZED or ORDERED or SORTED or DISTINCT
)
return if (step > 0) spliterator else object : Spliterator.OfInt by spliterator {
override fun getComparator() = Comparator.reverseOrder<Int>()
}
}
fun IntProgression.stream(): IntStream = StreamSupport.intStream(spliteratorOfInt(), false)
@Suppress("UNCHECKED_CAST") // When toArray has filled in the array, the component type is no longer T? but T (that may itself be nullable).
inline fun <reified T> Stream<out T>.toTypedArray() = toArray { size -> arrayOfNulls<T>(size) } as Array<T>

View File

@ -1,102 +1,16 @@
// TODO Move out the Kotlin specific stuff into a separate file
@file:JvmName("Utils") @file:JvmName("Utils")
package net.corda.core package net.corda.core
import com.google.common.base.Throwables import net.corda.core.concurrent.CordaFuture
import com.google.common.io.ByteStreams import net.corda.core.internal.concurrent.openFuture
import com.google.common.util.concurrent.* import net.corda.core.internal.concurrent.thenMatch
import net.corda.core.crypto.SecureHash
import net.corda.core.crypto.sha256
import net.corda.core.flows.FlowException
import net.corda.core.serialization.CordaSerializable
import org.slf4j.Logger
import rx.Observable import rx.Observable
import rx.Observer import rx.Observer
import rx.subjects.PublishSubject
import rx.subjects.UnicastSubject
import java.io.*
import java.math.BigDecimal
import java.nio.charset.Charset
import java.nio.charset.StandardCharsets.UTF_8
import java.nio.file.*
import java.nio.file.attribute.FileAttribute
import java.time.Duration
import java.time.temporal.Temporal
import java.util.concurrent.CompletableFuture
import java.util.concurrent.ExecutionException
import java.util.concurrent.Future
import java.util.concurrent.TimeUnit
import java.util.concurrent.locks.ReentrantLock
import java.util.stream.Stream
import java.util.zip.Deflater
import java.util.zip.ZipEntry
import java.util.zip.ZipInputStream
import java.util.zip.ZipOutputStream
import kotlin.concurrent.withLock
import kotlin.reflect.KClass
import kotlin.reflect.KProperty
val Int.days: Duration get() = Duration.ofDays(this.toLong()) // TODO Delete this file once the Future stuff is out of here
@Suppress("unused") // It's here for completeness
val Int.hours: Duration get() = Duration.ofHours(this.toLong())
val Int.minutes: Duration get() = Duration.ofMinutes(this.toLong())
val Int.seconds: Duration get() = Duration.ofSeconds(this.toLong())
val Int.millis: Duration get() = Duration.ofMillis(this.toLong())
fun <A> CordaFuture<out A>.toObservable(): Observable<A> {
// TODO: Review by EOY2016 if we ever found these utilities helpful.
val Int.bd: BigDecimal get() = BigDecimal(this)
val Double.bd: BigDecimal get() = BigDecimal(this)
val String.bd: BigDecimal get() = BigDecimal(this)
val Long.bd: BigDecimal get() = BigDecimal(this)
fun String.abbreviate(maxWidth: Int): String = if (length <= maxWidth) this else take(maxWidth - 1) + ""
/** Like the + operator but throws an exception in case of integer overflow. */
infix fun Int.checkedAdd(b: Int) = Math.addExact(this, b)
/** Like the + operator but throws an exception in case of integer overflow. */
@Suppress("unused")
infix fun Long.checkedAdd(b: Long) = Math.addExact(this, b)
/** Same as [Future.get] but with a more descriptive name, and doesn't throw [ExecutionException], instead throwing its cause */
fun <T> Future<T>.getOrThrow(timeout: Duration? = null): T {
return try {
if (timeout == null) get() else get(timeout.toNanos(), TimeUnit.NANOSECONDS)
} catch (e: ExecutionException) {
throw e.cause!!
}
}
fun <V> future(block: () -> V): Future<V> = CompletableFuture.supplyAsync(block)
fun <F : ListenableFuture<*>, V> F.then(block: (F) -> V) = addListener(Runnable { block(this) }, MoreExecutors.directExecutor())
fun <U, V> Future<U>.match(success: (U) -> V, failure: (Throwable) -> V): V {
return success(try {
getOrThrow()
} catch (t: Throwable) {
return failure(t)
})
}
fun <U, V, W> ListenableFuture<U>.thenMatch(success: (U) -> V, failure: (Throwable) -> W) = then { it.match(success, failure) }
fun ListenableFuture<*>.andForget(log: Logger) = then { it.match({}, { log.error("Background task failed:", it) }) }
@Suppress("UNCHECKED_CAST") // We need the awkward cast because otherwise F cannot be nullable, even though it's safe.
infix fun <F, T> ListenableFuture<F>.map(mapper: (F) -> T): ListenableFuture<T> = Futures.transform(this, { (mapper as (F?) -> T)(it) })
infix fun <F, T> ListenableFuture<F>.flatMap(mapper: (F) -> ListenableFuture<T>): ListenableFuture<T> = Futures.transformAsync(this) { mapper(it!!) }
/** Executes the given block and sets the future to either the result, or any exception that was thrown. */
inline fun <T> SettableFuture<T>.catch(block: () -> T) {
try {
set(block())
} catch (t: Throwable) {
setException(t)
}
}
fun <A> ListenableFuture<out A>.toObservable(): Observable<A> {
return Observable.create { subscriber -> return Observable.create { subscriber ->
thenMatch({ thenMatch({
subscriber.onNext(it) subscriber.onNext(it)
@ -107,303 +21,26 @@ fun <A> ListenableFuture<out A>.toObservable(): Observable<A> {
} }
} }
/** Allows you to write code like: Paths.get("someDir") / "subdir" / "filename" but using the Paths API to avoid platform separator problems. */
operator fun Path.div(other: String): Path = resolve(other)
operator fun String.div(other: String): Path = Paths.get(this) / other
fun Path.createDirectory(vararg attrs: FileAttribute<*>): Path = Files.createDirectory(this, *attrs)
fun Path.createDirectories(vararg attrs: FileAttribute<*>): Path = Files.createDirectories(this, *attrs)
fun Path.exists(vararg options: LinkOption): Boolean = Files.exists(this, *options)
fun Path.copyToDirectory(targetDir: Path, vararg options: CopyOption): Path {
require(targetDir.isDirectory()) { "$targetDir is not a directory" }
val targetFile = targetDir.resolve(fileName)
Files.copy(this, targetFile, *options)
return targetFile
}
fun Path.moveTo(target: Path, vararg options: CopyOption): Path = Files.move(this, target, *options)
fun Path.isRegularFile(vararg options: LinkOption): Boolean = Files.isRegularFile(this, *options)
fun Path.isDirectory(vararg options: LinkOption): Boolean = Files.isDirectory(this, *options)
val Path.size: Long get() = Files.size(this)
inline fun <R> Path.list(block: (Stream<Path>) -> R): R = Files.list(this).use(block)
fun Path.deleteIfExists(): Boolean = Files.deleteIfExists(this)
fun Path.readAll(): ByteArray = Files.readAllBytes(this)
inline fun <R> Path.read(vararg options: OpenOption, block: (InputStream) -> R): R = Files.newInputStream(this, *options).use(block)
inline fun Path.write(createDirs: Boolean = false, vararg options: OpenOption = emptyArray(), block: (OutputStream) -> Unit) {
if (createDirs) {
normalize().parent?.createDirectories()
}
Files.newOutputStream(this, *options).use(block)
}
inline fun <R> Path.readLines(charset: Charset = UTF_8, block: (Stream<String>) -> R): R = Files.lines(this, charset).use(block)
fun Path.readAllLines(charset: Charset = UTF_8): List<String> = Files.readAllLines(this, charset)
fun Path.writeLines(lines: Iterable<CharSequence>, charset: Charset = UTF_8, vararg options: OpenOption): Path = Files.write(this, lines, charset, *options)
fun InputStream.copyTo(target: Path, vararg options: CopyOption): Long = Files.copy(this, target, *options)
// Simple infix function to add back null safety that the JDK lacks: timeA until timeB
infix fun Temporal.until(endExclusive: Temporal): Duration = Duration.between(this, endExclusive)
/** Returns the index of the given item or throws [IllegalArgumentException] if not found. */
fun <T> List<T>.indexOfOrThrow(item: T): Int {
val i = indexOf(item)
require(i != -1)
return i
}
/** /**
* Returns the single element matching the given [predicate], or `null` if element was not found, * Returns a [CordaFuture] bound to the *first* item emitted by this Observable. The future will complete with a
* or throws if more than one element was found.
*/
fun <T> Iterable<T>.noneOrSingle(predicate: (T) -> Boolean): T? {
var single: T? = null
for (element in this) {
if (predicate(element)) {
if (single == null) {
single = element
} else throw IllegalArgumentException("Collection contains more than one matching element.")
}
}
return single
}
/** Returns single element, or `null` if element was not found, or throws if more than one element was found. */
fun <T> Iterable<T>.noneOrSingle(): T? {
var single: T? = null
for (element in this) {
if (single == null) {
single = element
} else throw IllegalArgumentException("Collection contains more than one matching element.")
}
return single
}
/** Returns a random element in the list, or null if empty */
fun <T> List<T>.randomOrNull(): T? {
if (size <= 1) return firstOrNull()
val randomIndex = (Math.random() * size).toInt()
return get(randomIndex)
}
/** Returns a random element in the list matching the given predicate, or null if none found */
fun <T> List<T>.randomOrNull(predicate: (T) -> Boolean) = filter(predicate).randomOrNull()
inline fun elapsedTime(block: () -> Unit): Duration {
val start = System.nanoTime()
block()
val end = System.nanoTime()
return Duration.ofNanos(end - start)
}
// TODO: Add inline back when a new Kotlin version is released and check if the java.lang.VerifyError
// returns in the IRSSimulationTest. If not, commit the inline back.
fun <T> logElapsedTime(label: String, logger: Logger? = null, body: () -> T): T {
// Use nanoTime as it's monotonic.
val now = System.nanoTime()
try {
return body()
} finally {
val elapsed = Duration.ofNanos(System.nanoTime() - now).toMillis()
if (logger != null)
logger.info("$label took $elapsed msec")
else
println("$label took $elapsed msec")
}
}
fun <T> Logger.logElapsedTime(label: String, body: () -> T): T = logElapsedTime(label, this, body)
/**
* A threadbox is a simple utility that makes it harder to forget to take a lock before accessing some shared state.
* Simply define a private class to hold the data that must be grouped under the same lock, and then pass the only
* instance to the ThreadBox constructor. You can now use the [locked] method with a lambda to take the lock in a
* way that ensures it'll be released if there's an exception.
*
* Note that this technique is not infallible: if you capture a reference to the fields in another lambda which then
* gets stored and invoked later, there may still be unsafe multi-threaded access going on, so watch out for that.
* This is just a simple guard rail that makes it harder to slip up.
*
* Example:
*
* private class MutableState { var i = 5 }
* private val state = ThreadBox(MutableState())
*
* val ii = state.locked { i }
*/
class ThreadBox<out T>(val content: T, val lock: ReentrantLock = ReentrantLock()) {
inline fun <R> locked(body: T.() -> R): R = lock.withLock { body(content) }
inline fun <R> alreadyLocked(body: T.() -> R): R {
check(lock.isHeldByCurrentThread, { "Expected $lock to already be locked." })
return body(content)
}
fun checkNotLocked() = check(!lock.isHeldByCurrentThread)
}
/**
* This represents a transient exception or condition that might no longer be thrown if the operation is re-run or called
* again.
*
* We avoid the use of the word transient here to hopefully reduce confusion with the term in relation to (Java) serialization.
*/
@CordaSerializable
abstract class RetryableException(message: String) : FlowException(message)
/**
* A simple wrapper that enables the use of Kotlin's "val x by TransientProperty { ... }" syntax. Such a property
* will not be serialized to disk, and if it's missing (or the first time it's accessed), the initializer will be
* used to set it up. Note that the initializer will be called with the TransientProperty object locked.
*/
class TransientProperty<out T>(private val initializer: () -> T) {
@Transient private var v: T? = null
@Synchronized
operator fun getValue(thisRef: Any?, property: KProperty<*>) = v ?: initializer().also { v = it }
}
/**
* Given a path to a zip file, extracts it to the given directory.
*/
fun extractZipFile(zipFile: Path, toDirectory: Path) = extractZipFile(Files.newInputStream(zipFile), toDirectory)
/**
* Given a zip file input stream, extracts it to the given directory.
*/
fun extractZipFile(inputStream: InputStream, toDirectory: Path) {
val normalisedDirectory = toDirectory.normalize().createDirectories()
ZipInputStream(BufferedInputStream(inputStream)).use {
while (true) {
val e = it.nextEntry ?: break
val outPath = (normalisedDirectory / e.name).normalize()
// Security checks: we should reject a zip that contains tricksy paths that try to escape toDirectory.
check(outPath.startsWith(normalisedDirectory)) { "ZIP contained a path that resolved incorrectly: ${e.name}" }
if (e.isDirectory) {
outPath.createDirectories()
continue
}
outPath.write { out ->
ByteStreams.copy(it, out)
}
it.closeEntry()
}
}
}
/**
* Get a valid InputStream from an in-memory zip as required for tests.
* Note that a slightly bigger than numOfExpectedBytes size is expected.
*/
@Throws(IllegalArgumentException::class)
fun sizedInputStreamAndHash(numOfExpectedBytes: Int): InputStreamAndHash {
if (numOfExpectedBytes <= 0) throw IllegalArgumentException("A positive number of numOfExpectedBytes is required.")
val baos = ByteArrayOutputStream()
ZipOutputStream(baos).use({ zos ->
val arraySize = 1024
val bytes = ByteArray(arraySize)
val n = (numOfExpectedBytes - 1) / arraySize + 1 // same as Math.ceil(numOfExpectedBytes/arraySize).
zos.setLevel(Deflater.NO_COMPRESSION)
zos.putNextEntry(ZipEntry("z"))
for (i in 0 until n) {
zos.write(bytes, 0, arraySize)
}
zos.closeEntry()
})
return getInputStreamAndHashFromOutputStream(baos)
}
/** Convert a [ByteArrayOutputStream] to [InputStreamAndHash]. */
fun getInputStreamAndHashFromOutputStream(baos: ByteArrayOutputStream): InputStreamAndHash {
// TODO: Consider converting OutputStream to InputStream without creating a ByteArray, probably using piped streams.
val bytes = baos.toByteArray()
// TODO: Consider calculating sha256 on the fly using a DigestInputStream.
return InputStreamAndHash(ByteArrayInputStream(bytes), bytes.sha256())
}
data class InputStreamAndHash(val inputStream: InputStream, val sha256: SecureHash.SHA256)
// TODO: Generic csv printing utility for clases.
val Throwable.rootCause: Throwable get() = Throwables.getRootCause(this)
/**
* Returns an Observable that buffers events until subscribed.
* @see UnicastSubject
*/
fun <T> Observable<T>.bufferUntilSubscribed(): Observable<T> {
val subject = UnicastSubject.create<T>()
val subscription = subscribe(subject)
return subject.doOnUnsubscribe { subscription.unsubscribe() }
}
/**
* Copy an [Observer] to multiple other [Observer]s.
*/
fun <T> Observer<T>.tee(vararg teeTo: Observer<T>): Observer<T> {
val subject = PublishSubject.create<T>()
subject.subscribe(this)
teeTo.forEach { subject.subscribe(it) }
return subject
}
/**
* Returns a [ListenableFuture] bound to the *first* item emitted by this Observable. The future will complete with a
* NoSuchElementException if no items are emitted or any other error thrown by the Observable. If it's cancelled then * NoSuchElementException if no items are emitted or any other error thrown by the Observable. If it's cancelled then
* it will unsubscribe from the observable. * it will unsubscribe from the observable.
*/ */
fun <T> Observable<T>.toFuture(): ListenableFuture<T> = ObservableToFuture(this) fun <T> Observable<T>.toFuture(): CordaFuture<T> = openFuture<T>().also {
val subscription = first().subscribe(object : Observer<T> {
private class ObservableToFuture<T>(observable: Observable<T>) : AbstractFuture<T>(), Observer<T> {
private val subscription = observable.first().subscribe(this)
override fun onNext(value: T) { override fun onNext(value: T) {
set(value) it.set(value)
} }
override fun onError(e: Throwable) { override fun onError(e: Throwable) {
setException(e) it.setException(e)
}
override fun cancel(mayInterruptIfRunning: Boolean): Boolean {
subscription.unsubscribe()
return super.cancel(mayInterruptIfRunning)
} }
override fun onCompleted() {} override fun onCompleted() {}
} })
it.then {
/** Return the sum of an Iterable of [BigDecimal]s. */ if (it.isCancelled) {
fun Iterable<BigDecimal>.sum(): BigDecimal = fold(BigDecimal.ZERO) { a, b -> a + b } subscription.unsubscribe()
fun codePointsString(vararg codePoints: Int): String {
val builder = StringBuilder()
codePoints.forEach { builder.append(Character.toChars(it)) }
return builder.toString()
}
fun <T> Class<T>.checkNotUnorderedHashMap() {
if (HashMap::class.java.isAssignableFrom(this) && !LinkedHashMap::class.java.isAssignableFrom(this)) {
throw NotSerializableException("Map type $this is unstable under iteration. Suggested fix: use LinkedHashMap instead.")
}
}
fun Class<*>.requireExternal(msg: String = "Internal class")
= require(!name.startsWith("net.corda.node.") && !name.contains(".internal.")) { "$msg: $name" }
interface DeclaredField<T> {
companion object {
inline fun <reified T> Any?.declaredField(clazz: KClass<*>, name: String): DeclaredField<T> = declaredField(clazz.java, name)
inline fun <reified T> Any.declaredField(name: String): DeclaredField<T> = declaredField(javaClass, name)
inline fun <reified T> Any?.declaredField(clazz: Class<*>, name: String): DeclaredField<T> {
val javaField = clazz.getDeclaredField(name).apply { isAccessible = true }
val receiver = this
return object : DeclaredField<T> {
override var value
get() = javaField.get(receiver) as T
set(value) = javaField.set(receiver, value)
} }
} }
}
var value: T
} }

View File

@ -1,34 +1,44 @@
package net.corda.core.concurrent package net.corda.core.concurrent
import com.google.common.annotations.VisibleForTesting import net.corda.core.internal.concurrent.openFuture
import com.google.common.util.concurrent.ListenableFuture import net.corda.core.utilities.getOrThrow
import com.google.common.util.concurrent.SettableFuture import net.corda.core.internal.VisibleForTesting
import net.corda.core.catch
import net.corda.core.match
import net.corda.core.then
import org.slf4j.Logger import org.slf4j.Logger
import org.slf4j.LoggerFactory import org.slf4j.LoggerFactory
import java.util.concurrent.Future
import java.util.concurrent.atomic.AtomicBoolean import java.util.concurrent.atomic.AtomicBoolean
/** Invoke [getOrThrow] and pass the value/throwable to success/failure respectively. */
fun <V, W> Future<V>.match(success: (V) -> W, failure: (Throwable) -> W): W {
val value = try {
getOrThrow()
} catch (t: Throwable) {
return failure(t)
}
return success(value)
}
/** /**
* As soon as a given future becomes done, the handler is invoked with that future as its argument. * As soon as a given future becomes done, the handler is invoked with that future as its argument.
* The result of the handler is copied into the result future, and the handler isn't invoked again. * The result of the handler is copied into the result future, and the handler isn't invoked again.
* If a given future errors after the result future is done, the error is automatically logged. * If a given future errors after the result future is done, the error is automatically logged.
*/ */
fun <S, T> firstOf(vararg futures: ListenableFuture<out S>, handler: (ListenableFuture<out S>) -> T) = firstOf(futures, defaultLog, handler) fun <V, W> firstOf(vararg futures: CordaFuture<out V>, handler: (CordaFuture<out V>) -> W) = firstOf(futures, defaultLog, handler)
private val defaultLog = LoggerFactory.getLogger("net.corda.core.concurrent") private val defaultLog = LoggerFactory.getLogger("net.corda.core.concurrent")
@VisibleForTesting @VisibleForTesting
internal val shortCircuitedTaskFailedMessage = "Short-circuited task failed:" internal val shortCircuitedTaskFailedMessage = "Short-circuited task failed:"
internal fun <S, T> firstOf(futures: Array<out ListenableFuture<out S>>, log: Logger, handler: (ListenableFuture<out S>) -> T): ListenableFuture<T> { internal fun <V, W> firstOf(futures: Array<out CordaFuture<out V>>, log: Logger, handler: (CordaFuture<out V>) -> W): CordaFuture<W> {
val resultFuture = SettableFuture.create<T>() val resultFuture = openFuture<W>()
val winnerChosen = AtomicBoolean() val winnerChosen = AtomicBoolean()
futures.forEach { futures.forEach {
it.then { it.then {
if (winnerChosen.compareAndSet(false, true)) { if (winnerChosen.compareAndSet(false, true)) {
resultFuture.catch { handler(it) } resultFuture.capture { handler(it) }
} else if (!it.isCancelled) { } else if (it.isCancelled) {
// Do nothing.
} else {
it.match({}, { log.error(shortCircuitedTaskFailedMessage, it) }) it.match({}, { log.error(shortCircuitedTaskFailedMessage, it) })
} }
} }

View File

@ -0,0 +1,22 @@
package net.corda.core.concurrent
import java.util.concurrent.CompletableFuture
import java.util.concurrent.Future
/**
* Same as [Future] with additional methods to provide some of the features of [java.util.concurrent.CompletableFuture] while minimising the API surface area.
* In Kotlin, to avoid compile errors, whenever CordaFuture is used in a parameter or extension method receiver type, its type parameter should be specified with out variance.
*/
interface CordaFuture<V> : Future<V> {
/**
* Run the given callback when this future is done, on the completion thread.
* If the completion thread is problematic for you e.g. deadlock, you can submit to an executor manually.
* If callback fails, its throwable is logged.
*/
fun <W> then(callback: (CordaFuture<V>) -> W): Unit
/**
* @return a new [CompletableFuture] with the same outcome as this Future.
*/
fun toCompletableFuture(): CompletableFuture<V>
}

View File

@ -1,5 +1,8 @@
package net.corda.core.contracts package net.corda.core.contracts
import net.corda.core.crypto.composite.CompositeKey
import net.corda.core.utilities.exactAdd
import net.corda.core.identity.Party
import net.corda.core.serialization.CordaSerializable import net.corda.core.serialization.CordaSerializable
import java.math.BigDecimal import java.math.BigDecimal
import java.math.RoundingMode import java.math.RoundingMode
@ -168,7 +171,7 @@ data class Amount<T : Any>(val quantity: Long, val displayTokenSize: BigDecimal,
*/ */
operator fun plus(other: Amount<T>): Amount<T> { operator fun plus(other: Amount<T>): Amount<T> {
checkToken(other) checkToken(other)
return Amount(Math.addExact(quantity, other.quantity), displayTokenSize, token) return Amount(quantity exactAdd other.quantity, displayTokenSize, token)
} }
/** /**
@ -268,9 +271,9 @@ data class SourceAndAmount<T : Any, out P : Any>(val source: P, val amount: Amou
* but in various scenarios it may be more consistent to allow positive and negative values. * but in various scenarios it may be more consistent to allow positive and negative values.
* For example it is common for a bank to code asset flows as gains and losses from its perspective i.e. always the destination. * For example it is common for a bank to code asset flows as gains and losses from its perspective i.e. always the destination.
* @param token represents the type of asset token as would be used to construct Amount<T> objects. * @param token represents the type of asset token as would be used to construct Amount<T> objects.
* @param source is the [Party], [Account], [CompositeKey], or other identifier of the token source if quantityDelta is positive, * @param source is the [Party], [CompositeKey], or other identifier of the token source if quantityDelta is positive,
* or the token sink if quantityDelta is negative. The type P should support value equality. * or the token sink if quantityDelta is negative. The type P should support value equality.
* @param destination is the [Party], [Account], [CompositeKey], or other identifier of the token sink if quantityDelta is positive, * @param destination is the [Party], [CompositeKey], or other identifier of the token sink if quantityDelta is positive,
* or the token source if quantityDelta is negative. The type P should support value equality. * or the token source if quantityDelta is negative. The type P should support value equality.
*/ */
@CordaSerializable @CordaSerializable
@ -329,7 +332,7 @@ class AmountTransfer<T : Any, P : Any>(val quantityDelta: Long,
"Only AmountTransfer between the same two parties can be aggregated/netted" "Only AmountTransfer between the same two parties can be aggregated/netted"
} }
return if (other.source == source) { return if (other.source == source) {
AmountTransfer(Math.addExact(quantityDelta, other.quantityDelta), token, source, destination) AmountTransfer(quantityDelta exactAdd other.quantityDelta, token, source, destination)
} else { } else {
AmountTransfer(Math.subtractExact(quantityDelta, other.quantityDelta), token, source, destination) AmountTransfer(Math.subtractExact(quantityDelta, other.quantityDelta), token, source, destination)
} }
@ -388,10 +391,10 @@ class AmountTransfer<T : Any, P : Any>(val quantityDelta: Long,
* relative asset exchange happens, but with each party exchanging versus a central counterparty, or clearing house. * relative asset exchange happens, but with each party exchanging versus a central counterparty, or clearing house.
* *
* @param centralParty The central party to face the exchange against. * @param centralParty The central party to face the exchange against.
* @return Returns two new AmountTransfers each between one of the original parties and the centralParty. * @return Returns a list of two new AmountTransfers each between one of the original parties and the centralParty.
* The net total exchange is the same as in the original input. * The net total exchange is the same as in the original input.
*/ */
fun novate(centralParty: P): Pair<AmountTransfer<T, P>, AmountTransfer<T, P>> = Pair(copy(destination = centralParty), copy(source = centralParty)) fun novate(centralParty: P): List<AmountTransfer<T, P>> = listOf(copy(destination = centralParty), copy(source = centralParty))
/** /**
* Applies this AmountTransfer to a list of [SourceAndAmount] objects representing balances. * Applies this AmountTransfer to a list of [SourceAndAmount] objects representing balances.

View File

@ -2,6 +2,7 @@
package net.corda.core.contracts package net.corda.core.contracts
import net.corda.core.identity.AbstractParty
import net.corda.core.identity.Party import net.corda.core.identity.Party
import java.math.BigDecimal import java.math.BigDecimal
import java.security.PublicKey import java.security.PublicKey
@ -54,13 +55,6 @@ object Requirements {
infix inline fun String.using(expr: Boolean) { infix inline fun String.using(expr: Boolean) {
if (!expr) throw IllegalArgumentException("Failed requirement: $this") if (!expr) throw IllegalArgumentException("Failed requirement: $this")
} }
// Avoid overloading Kotlin keywords
@Deprecated("This function is deprecated, use 'using' instead",
ReplaceWith("using (expr)", "net.corda.core.contracts.Requirements.using"))
@Suppress("NOTHING_TO_INLINE") // Inlining this takes it out of our committed ABI.
infix inline fun String.by(expr: Boolean) {
using(expr)
}
} }
inline fun <R> requireThat(body: Requirements.() -> R) = Requirements.body() inline fun <R> requireThat(body: Requirements.() -> R) = Requirements.body()
@ -71,7 +65,7 @@ inline fun <R> requireThat(body: Requirements.() -> R) = Requirements.body()
/** Filters the command list by type, party and public key all at once. */ /** Filters the command list by type, party and public key all at once. */
inline fun <reified T : CommandData> Collection<AuthenticatedObject<CommandData>>.select(signer: PublicKey? = null, inline fun <reified T : CommandData> Collection<AuthenticatedObject<CommandData>>.select(signer: PublicKey? = null,
party: Party? = null) = party: AbstractParty? = null) =
filter { it.value is T }. filter { it.value is T }.
filter { if (signer == null) true else signer in it.signers }. filter { if (signer == null) true else signer in it.signers }.
filter { if (party == null) true else party in it.signingParties }. filter { if (party == null) true else party in it.signingParties }.

View File

@ -1,19 +1,24 @@
@file:JvmName("Structures")
package net.corda.core.contracts package net.corda.core.contracts
import net.corda.core.contracts.clauses.Clause
import net.corda.core.crypto.SecureHash import net.corda.core.crypto.SecureHash
import net.corda.core.crypto.secureRandomBytes
import net.corda.core.flows.FlowLogicRef import net.corda.core.flows.FlowLogicRef
import net.corda.core.flows.FlowLogicRefFactory import net.corda.core.flows.FlowLogicRefFactory
import net.corda.core.identity.AbstractParty import net.corda.core.identity.AbstractParty
import net.corda.core.identity.Party import net.corda.core.identity.Party
import net.corda.core.serialization.* import net.corda.core.serialization.CordaSerializable
import net.corda.core.serialization.MissingAttachmentsException
import net.corda.core.serialization.SerializeAsTokenContext
import net.corda.core.serialization.serialize
import net.corda.core.transactions.LedgerTransaction
import net.corda.core.utilities.OpaqueBytes import net.corda.core.utilities.OpaqueBytes
import java.io.FileNotFoundException import java.io.FileNotFoundException
import java.io.IOException import java.io.IOException
import java.io.InputStream import java.io.InputStream
import java.io.OutputStream import java.io.OutputStream
import java.security.PublicKey import java.security.PublicKey
import java.time.Duration
import java.time.Instant import java.time.Instant
import java.util.jar.JarInputStream import java.util.jar.JarInputStream
@ -76,7 +81,7 @@ interface ContractState {
* A _participant_ is any party that is able to consume this state in a valid transaction. * A _participant_ is any party that is able to consume this state in a valid transaction.
* *
* The list of participants is required for certain types of transactions. For example, when changing the notary * The list of participants is required for certain types of transactions. For example, when changing the notary
* for this state ([TransactionType.NotaryChange]), every participant has to be involved and approve the transaction * for this state, every participant has to be involved and approve the transaction
* so that they receive the updated state, and don't end up in a situation where they can no longer use a state * so that they receive the updated state, and don't end up in a situation where they can no longer use a state
* they possess, since someone consumed that state during the notary change process. * they possess, since someone consumed that state during the notary change process.
* *
@ -141,6 +146,12 @@ data class Issued<out P : Any>(val issuer: PartyAndReference, val product: P) {
fun <T : Any> Amount<Issued<T>>.withoutIssuer(): Amount<T> = Amount(quantity, token.product) fun <T : Any> Amount<Issued<T>>.withoutIssuer(): Amount<T> = Amount(quantity, token.product)
// DOCSTART 3 // DOCSTART 3
/**
* Return structure for [OwnableState.withNewOwner]
*/
data class CommandAndState(val command: CommandData, val ownableState: OwnableState)
/** /**
* A contract state that can have a single owner. * A contract state that can have a single owner.
*/ */
@ -149,7 +160,7 @@ interface OwnableState : ContractState {
val owner: AbstractParty val owner: AbstractParty
/** Copies the underlying data structure, replacing the owner field with this new value and leaving the rest alone */ /** Copies the underlying data structure, replacing the owner field with this new value and leaving the rest alone */
fun withNewOwner(newOwner: AbstractParty): Pair<CommandData, OwnableState> fun withNewOwner(newOwner: AbstractParty): CommandAndState
} }
// DOCEND 3 // DOCEND 3
@ -199,26 +210,6 @@ interface LinearState : ContractState {
* True if this should be tracked by our vault(s). * True if this should be tracked by our vault(s).
*/ */
fun isRelevant(ourKeys: Set<PublicKey>): Boolean fun isRelevant(ourKeys: Set<PublicKey>): Boolean
/**
* Standard clause to verify the LinearState safety properties.
*/
@CordaSerializable
class ClauseVerifier<in S : LinearState, C : CommandData> : Clause<S, C, Unit>() {
override fun verify(tx: TransactionForContract,
inputs: List<S>,
outputs: List<S>,
commands: List<AuthenticatedObject<C>>,
groupingKey: Unit?): Set<C> {
val inputIds = inputs.map { it.linearId }.distinct()
val outputIds = outputs.map { it.linearId }.distinct()
requireThat {
"LinearStates are not merged" using (inputIds.count() == inputs.count())
"LinearStates are not split" using (outputIds.count() == outputs.count())
}
return emptySet()
}
}
} }
// DOCEND 2 // DOCEND 2
@ -281,14 +272,13 @@ abstract class TypeOnlyCommandData : CommandData {
/** Command data/content plus pubkey pair: the signature is stored at the end of the serialized bytes */ /** Command data/content plus pubkey pair: the signature is stored at the end of the serialized bytes */
@CordaSerializable @CordaSerializable
// DOCSTART 9 data class Command<T : CommandData>(val value: T, val signers: List<PublicKey>) {
data class Command(val value: CommandData, val signers: List<PublicKey>) { // TODO Introduce NonEmptyList?
// DOCEND 9
init { init {
require(signers.isNotEmpty()) require(signers.isNotEmpty())
} }
constructor(data: CommandData, key: PublicKey) : this(data, listOf(key)) constructor(data: T, key: PublicKey) : this(data, listOf(key))
private fun commandDataToString() = value.toString().let { if (it.contains("@")) it.replace('$', '.').split("@")[0] else it } private fun commandDataToString() = value.toString().let { if (it.contains("@")) it.replace('$', '.').split("@")[0] else it }
override fun toString() = "${commandDataToString()} with pubkeys ${signers.joinToString()}" override fun toString() = "${commandDataToString()} with pubkeys ${signers.joinToString()}"
@ -324,63 +314,6 @@ data class AuthenticatedObject<out T : Any>(
) )
// DOCEND 6 // DOCEND 6
/**
* A time-window is required for validation/notarization purposes.
* If present in a transaction, contains a time that was verified by the uniqueness service. The true time must be
* between (fromTime, untilTime).
* Usually, a time-window is required to have both sides set (fromTime, untilTime).
* However, some apps may require that a time-window has a start [Instant] (fromTime), but no end [Instant] (untilTime) and vice versa.
* TODO: Consider refactoring using TimeWindow abstraction like TimeWindow.From, TimeWindow.Until, TimeWindow.Between.
*/
@CordaSerializable
class TimeWindow private constructor(
/** The time at which this transaction is said to have occurred is after this moment. */
val fromTime: Instant?,
/** The time at which this transaction is said to have occurred is before this moment. */
val untilTime: Instant?
) {
companion object {
/** Use when the left-side [fromTime] of a [TimeWindow] is only required and we don't need an end instant (untilTime). */
@JvmStatic
fun fromOnly(fromTime: Instant) = TimeWindow(fromTime, null)
/** Use when the right-side [untilTime] of a [TimeWindow] is only required and we don't need a start instant (fromTime). */
@JvmStatic
fun untilOnly(untilTime: Instant) = TimeWindow(null, untilTime)
/** Use when both sides of a [TimeWindow] must be set ([fromTime], [untilTime]). */
@JvmStatic
fun between(fromTime: Instant, untilTime: Instant): TimeWindow {
require(fromTime < untilTime) { "fromTime should be earlier than untilTime" }
return TimeWindow(fromTime, untilTime)
}
/** Use when we have a start time and a period of validity. */
@JvmStatic
fun fromStartAndDuration(fromTime: Instant, duration: Duration): TimeWindow = between(fromTime, fromTime + duration)
/**
* When we need to create a [TimeWindow] based on a specific time [Instant] and some tolerance in both sides of this instant.
* The result will be the following time-window: ([time] - [tolerance], [time] + [tolerance]).
*/
@JvmStatic
fun withTolerance(time: Instant, tolerance: Duration) = between(time - tolerance, time + tolerance)
}
/** The midpoint is calculated as fromTime + (untilTime - fromTime)/2. Note that it can only be computed if both sides are set. */
val midpoint: Instant get() = fromTime!! + Duration.between(fromTime, untilTime!!).dividedBy(2)
override fun equals(other: Any?): Boolean {
if (this === other) return true
if (other !is TimeWindow) return false
return (fromTime == other.fromTime && untilTime == other.untilTime)
}
override fun hashCode() = 31 * (fromTime?.hashCode() ?: 0) + (untilTime?.hashCode() ?: 0)
override fun toString() = "TimeWindow(fromTime=$fromTime, untilTime=$untilTime)"
}
// DOCSTART 5 // DOCSTART 5
/** /**
* Implemented by a program that implements business logic on the shared ledger. All participants run this code for * Implemented by a program that implements business logic on the shared ledger. All participants run this code for
@ -399,7 +332,7 @@ interface Contract {
* existing contract code. * existing contract code.
*/ */
@Throws(IllegalArgumentException::class) @Throws(IllegalArgumentException::class)
fun verify(tx: TransactionForContract) fun verify(tx: LedgerTransaction)
/** /**
* Unparsed reference to the natural language contract that this code is supposed to express (usually a hash of * Unparsed reference to the natural language contract that this code is supposed to express (usually a hash of
@ -489,3 +422,23 @@ fun JarInputStream.extractFile(path: String, outputTo: OutputStream) {
} }
throw FileNotFoundException(path) throw FileNotFoundException(path)
} }
/**
* A privacy salt is required to compute nonces per transaction component in order to ensure that an adversary cannot
* use brute force techniques and reveal the content of a Merkle-leaf hashed value.
* Because this salt serves the role of the seed to compute nonces, its size and entropy should be equal to the
* underlying hash function used for Merkle tree generation, currently [SHA256], which has an output of 32 bytes.
* There are two constructors, one that generates a new 32-bytes random salt, and another that takes a [ByteArray] input.
* The latter is required in cases where the salt value needs to be pre-generated (agreed between transacting parties),
* but it is highlighted that one should always ensure it has sufficient entropy.
*/
@CordaSerializable
class PrivacySalt(bytes: ByteArray) : OpaqueBytes(bytes) {
/** Constructs a salt with a randomly-generated 32 byte value. */
constructor() : this(secureRandomBytes(32))
init {
require(bytes.size == 32) { "Privacy salt should be 32 bytes." }
require(!bytes.all { it == 0.toByte() }) { "Privacy salt should not be all zeros." }
}
}

View File

@ -0,0 +1,92 @@
package net.corda.core.contracts
import net.corda.core.internal.div
import net.corda.core.internal.until
import net.corda.core.serialization.CordaSerializable
import net.corda.core.transactions.WireTransaction
import java.time.Duration
import java.time.Instant
/**
* An interval on the time-line; not a single instantaneous point.
*
* There is no such thing as _exact_ time in networked systems due to the underlying physics involved and other issues
* such as network latency. The best that can be approximated is "fuzzy time" or an instant of time which has margin of
* tolerance around it. This is what [TimeWindow] represents. Time windows can be open-ended (i.e. specify only one of
* [fromTime] and [untilTime]) or they can be fully bounded.
*
* [WireTransaction] has an optional time-window property, which if specified, restricts the validity of the transaction
* to that time-interval as the Consensus Service will not sign it if it's received outside of this window.
*/
@CordaSerializable
abstract class TimeWindow {
companion object {
/** Creates a [TimeWindow] with null [untilTime], i.e. the time interval `[fromTime, ∞)`. [midpoint] will return null. */
@JvmStatic
fun fromOnly(fromTime: Instant): TimeWindow = From(fromTime)
/** Creates a [TimeWindow] with null [fromTime], i.e. the time interval `(∞, untilTime)`. [midpoint] will return null. */
@JvmStatic
fun untilOnly(untilTime: Instant): TimeWindow = Until(untilTime)
/**
* Creates a [TimeWindow] with the time interval `[fromTime, untilTime)`. [midpoint] will return
* `fromTime + (untilTime - fromTime) / 2`.
* @throws IllegalArgumentException If [fromTime] [untilTime]
*/
@JvmStatic
fun between(fromTime: Instant, untilTime: Instant): TimeWindow = Between(fromTime, untilTime)
/**
* Creates a [TimeWindow] with the time interval `[fromTime, fromTime + duration)`. [midpoint] will return
* `fromTime + duration / 2`
*/
@JvmStatic
fun fromStartAndDuration(fromTime: Instant, duration: Duration): TimeWindow = between(fromTime, fromTime + duration)
/**
* Creates a [TimeWindow] which is centered around [instant] with the given [tolerance] on both sides, i.e the
* time interval `[instant - tolerance, instant + tolerance)`. [midpoint] will return [instant].
*/
@JvmStatic
fun withTolerance(instant: Instant, tolerance: Duration) = between(instant - tolerance, instant + tolerance)
}
/** Returns the inclusive lower-bound of this [TimeWindow]'s interval, with null implying infinity. */
abstract val fromTime: Instant?
/** Returns the exclusive upper-bound of this [TimeWindow]'s interval, with null implying infinity. */
abstract val untilTime: Instant?
/**
* Returns the midpoint of [fromTime] and [untilTime] if both are non-null, calculated as
* `fromTime + (untilTime - fromTime) / 2`, otherwise returns null.
*/
abstract val midpoint: Instant?
/** Returns true iff the given [instant] is within the time interval of this [TimeWindow]. */
abstract operator fun contains(instant: Instant): Boolean
private data class From(override val fromTime: Instant) : TimeWindow() {
override val untilTime: Instant? get() = null
override val midpoint: Instant? get() = null
override fun contains(instant: Instant): Boolean = instant >= fromTime
override fun toString(): String = "[$fromTime, ∞)"
}
private data class Until(override val untilTime: Instant) : TimeWindow() {
override val fromTime: Instant? get() = null
override val midpoint: Instant? get() = null
override fun contains(instant: Instant): Boolean = instant < untilTime
override fun toString(): String = "(∞, $untilTime)"
}
private data class Between(override val fromTime: Instant, override val untilTime: Instant) : TimeWindow() {
init {
require(fromTime < untilTime) { "fromTime must be earlier than untilTime" }
}
override val midpoint: Instant get() = fromTime + (fromTime until untilTime) / 2
override fun contains(instant: Instant): Boolean = instant >= fromTime && instant < untilTime
override fun toString(): String = "[$fromTime, $untilTime)"
}
}

View File

@ -1,176 +0,0 @@
package net.corda.core.contracts
import net.corda.core.identity.Party
import net.corda.core.serialization.CordaSerializable
import net.corda.core.transactions.LedgerTransaction
import net.corda.core.transactions.TransactionBuilder
import java.security.PublicKey
/** Defines transaction build & validation logic for a specific transaction type */
@CordaSerializable
sealed class TransactionType {
/**
* Check that the transaction is valid based on:
* - General platform rules
* - Rules for the specific transaction type
*
* Note: Presence of _signatures_ is not checked, only the public keys to be signed for.
*/
@Throws(TransactionVerificationException::class)
fun verify(tx: LedgerTransaction) {
require(tx.notary != null || tx.timeWindow == null) { "Transactions with time-windows must be notarised" }
val duplicates = detectDuplicateInputs(tx)
if (duplicates.isNotEmpty()) throw TransactionVerificationException.DuplicateInputStates(tx.id, duplicates)
val missing = verifySigners(tx)
if (missing.isNotEmpty()) throw TransactionVerificationException.SignersMissing(tx.id, missing.toList())
verifyTransaction(tx)
}
/** Check that the list of signers includes all the necessary keys */
fun verifySigners(tx: LedgerTransaction): Set<PublicKey> {
val notaryKey = tx.inputs.map { it.state.notary.owningKey }.toSet()
if (notaryKey.size > 1) throw TransactionVerificationException.MoreThanOneNotary(tx.id)
val requiredKeys = getRequiredSigners(tx) + notaryKey
val missing = requiredKeys - tx.mustSign
return missing
}
/** Check that the inputs are unique. */
private fun detectDuplicateInputs(tx: LedgerTransaction): Set<StateRef> {
var seenInputs = emptySet<StateRef>()
var duplicates = emptySet<StateRef>()
tx.inputs.forEach { state ->
if (seenInputs.contains(state.ref)) {
duplicates += state.ref
}
seenInputs += state.ref
}
return duplicates
}
/**
* Return the list of public keys that that require signatures for the transaction type.
* Note: the notary key is checked separately for all transactions and need not be included.
*/
abstract fun getRequiredSigners(tx: LedgerTransaction): Set<PublicKey>
/** Implement type specific transaction validation logic */
abstract fun verifyTransaction(tx: LedgerTransaction)
/** A general transaction type where transaction validity is determined by custom contract code */
object General : TransactionType() {
/** Just uses the default [TransactionBuilder] with no special logic */
class Builder(notary: Party?) : TransactionBuilder(General, notary)
override fun verifyTransaction(tx: LedgerTransaction) {
verifyNoNotaryChange(tx)
verifyEncumbrances(tx)
verifyContracts(tx)
}
/**
* Make sure the notary has stayed the same. As we can't tell how inputs and outputs connect, if there
* are any inputs, all outputs must have the same notary.
*
* TODO: Is that the correct set of restrictions? May need to come back to this, see if we can be more
* flexible on output notaries.
*/
private fun verifyNoNotaryChange(tx: LedgerTransaction) {
if (tx.notary != null && tx.inputs.isNotEmpty()) {
tx.outputs.forEach {
if (it.notary != tx.notary) {
throw TransactionVerificationException.NotaryChangeInWrongTransactionType(tx.id, tx.notary, it.notary)
}
}
}
}
private fun verifyEncumbrances(tx: LedgerTransaction) {
// Validate that all encumbrances exist within the set of input states.
val encumberedInputs = tx.inputs.filter { it.state.encumbrance != null }
encumberedInputs.forEach { (state, ref) ->
val encumbranceStateExists = tx.inputs.any {
it.ref.txhash == ref.txhash && it.ref.index == state.encumbrance
}
if (!encumbranceStateExists) {
throw TransactionVerificationException.TransactionMissingEncumbranceException(
tx.id,
state.encumbrance!!,
TransactionVerificationException.Direction.INPUT
)
}
}
// Check that, in the outputs, an encumbered state does not refer to itself as the encumbrance,
// and that the number of outputs can contain the encumbrance.
for ((i, output) in tx.outputs.withIndex()) {
val encumbranceIndex = output.encumbrance ?: continue
if (encumbranceIndex == i || encumbranceIndex >= tx.outputs.size) {
throw TransactionVerificationException.TransactionMissingEncumbranceException(
tx.id,
encumbranceIndex,
TransactionVerificationException.Direction.OUTPUT)
}
}
}
/**
* Check the transaction is contract-valid by running the verify() for each input and output state contract.
* If any contract fails to verify, the whole transaction is considered to be invalid.
*/
private fun verifyContracts(tx: LedgerTransaction) {
val ctx = tx.toTransactionForContract()
// TODO: This will all be replaced in future once the sandbox and contract constraints work is done.
val contracts = (ctx.inputs.map { it.contract } + ctx.outputs.map { it.contract }).toSet()
for (contract in contracts) {
try {
contract.verify(ctx)
} catch(e: Throwable) {
throw TransactionVerificationException.ContractRejection(tx.id, contract, e)
}
}
}
override fun getRequiredSigners(tx: LedgerTransaction) = tx.commands.flatMap { it.signers }.toSet()
}
/**
* A special transaction type for reassigning a notary for a state. Validation does not involve running
* any contract code, it just checks that the states are unmodified apart from the notary field.
*/
object NotaryChange : TransactionType() {
/**
* A transaction builder that automatically sets the transaction type to [NotaryChange]
* and adds the list of participants to the signers set for every input state.
*/
class Builder(notary: Party) : TransactionBuilder(NotaryChange, notary) {
override fun addInputState(stateAndRef: StateAndRef<*>): TransactionBuilder {
signers.addAll(stateAndRef.state.data.participants.map { it.owningKey })
super.addInputState(stateAndRef)
return this
}
}
/**
* Check that the difference between inputs and outputs is only the notary field, and that all required signing
* public keys are present.
*
* @throws TransactionVerificationException.InvalidNotaryChange if the validity check fails.
*/
override fun verifyTransaction(tx: LedgerTransaction) {
try {
for ((input, output) in tx.inputs.zip(tx.outputs)) {
check(input.state.data == output.data)
check(input.state.notary != output.notary)
}
check(tx.commands.isEmpty())
} catch (e: IllegalStateException) {
throw TransactionVerificationException.InvalidNotaryChange(tx.id)
}
}
override fun getRequiredSigners(tx: LedgerTransaction) = tx.inputs.flatMap { it.state.data.participants }.map { it.owningKey }.toSet()
}
}

View File

@ -1,130 +0,0 @@
package net.corda.core.contracts
import net.corda.core.crypto.SecureHash
import net.corda.core.flows.FlowException
import net.corda.core.identity.Party
import net.corda.core.serialization.CordaSerializable
import java.security.PublicKey
import java.util.*
// TODO: Consider moving this out of the core module and providing a different way for unit tests to test contracts.
/**
* A transaction to be passed as input to a contract verification function. Defines helper methods to
* simplify verification logic in contracts.
*/
// DOCSTART 1
data class TransactionForContract(val inputs: List<ContractState>,
val outputs: List<ContractState>,
val attachments: List<Attachment>,
val commands: List<AuthenticatedObject<CommandData>>,
val origHash: SecureHash,
val inputNotary: Party? = null,
val timeWindow: TimeWindow? = null) {
// DOCEND 1
override fun hashCode() = origHash.hashCode()
override fun equals(other: Any?) = other is TransactionForContract && other.origHash == origHash
/**
* Given a type and a function that returns a grouping key, associates inputs and outputs together so that they
* can be processed as one. The grouping key is any arbitrary object that can act as a map key (so must implement
* equals and hashCode).
*
* The purpose of this function is to simplify the writing of verification logic for transactions that may contain
* similar but unrelated state evolutions which need to be checked independently. Consider a transaction that
* simultaneously moves both dollars and euros (e.g. is an atomic FX trade). There may be multiple dollar inputs and
* multiple dollar outputs, depending on things like how fragmented the owner's vault is and whether various privacy
* techniques are in use. The quantity of dollars on the output side must sum to the same as on the input side, to
* ensure no money is being lost track of. This summation and checking must be repeated independently for each
* currency. To solve this, you would use groupStates with a type of Cash.State and a selector that returns the
* currency field: the resulting list can then be iterated over to perform the per-currency calculation.
*/
// DOCSTART 2
fun <T : ContractState, K : Any> groupStates(ofType: Class<T>, selector: (T) -> K): List<InOutGroup<T, K>> {
val inputs = inputs.filterIsInstance(ofType)
val outputs = outputs.filterIsInstance(ofType)
val inGroups: Map<K, List<T>> = inputs.groupBy(selector)
val outGroups: Map<K, List<T>> = outputs.groupBy(selector)
@Suppress("DEPRECATION")
return groupStatesInternal(inGroups, outGroups)
}
// DOCEND 2
/** See the documentation for the reflection-based version of [groupStates] */
inline fun <reified T : ContractState, K : Any> groupStates(selector: (T) -> K): List<InOutGroup<T, K>> {
val inputs = inputs.filterIsInstance<T>()
val outputs = outputs.filterIsInstance<T>()
val inGroups: Map<K, List<T>> = inputs.groupBy(selector)
val outGroups: Map<K, List<T>> = outputs.groupBy(selector)
@Suppress("DEPRECATION")
return groupStatesInternal(inGroups, outGroups)
}
@Deprecated("Do not use this directly: exposed as public only due to function inlining")
fun <T : ContractState, K : Any> groupStatesInternal(inGroups: Map<K, List<T>>, outGroups: Map<K, List<T>>): List<InOutGroup<T, K>> {
val result = ArrayList<InOutGroup<T, K>>()
for ((k, v) in inGroups.entries)
result.add(InOutGroup(v, outGroups[k] ?: emptyList(), k))
for ((k, v) in outGroups.entries) {
if (inGroups[k] == null)
result.add(InOutGroup(emptyList(), v, k))
}
return result
}
/** Utilities for contract writers to incorporate into their logic. */
/**
* A set of related inputs and outputs that are connected by some common attributes. An InOutGroup is calculated
* using [groupStates] and is useful for handling cases where a transaction may contain similar but unrelated
* state evolutions, for example, a transaction that moves cash in two different currencies. The numbers must add
* up on both sides of the transaction, but the values must be summed independently per currency. Grouping can
* be used to simplify this logic.
*/
// DOCSTART 3
data class InOutGroup<out T : ContractState, out K : Any>(val inputs: List<T>, val outputs: List<T>, val groupingKey: K)
// DOCEND 3
}
class TransactionResolutionException(val hash: SecureHash) : FlowException() {
override fun toString(): String = "Transaction resolution failure for $hash"
}
class AttachmentResolutionException(val hash: SecureHash) : FlowException() {
override fun toString(): String = "Attachment resolution failure for $hash"
}
sealed class TransactionVerificationException(val txId: SecureHash, cause: Throwable?) : FlowException(cause) {
class ContractRejection(txId: SecureHash, val contract: Contract, cause: Throwable?) : TransactionVerificationException(txId, cause)
class MoreThanOneNotary(txId: SecureHash) : TransactionVerificationException(txId, null)
class SignersMissing(txId: SecureHash, val missing: List<PublicKey>) : TransactionVerificationException(txId, null) {
override fun toString(): String = "Signers missing: ${missing.joinToString()}"
}
class DuplicateInputStates(txId: SecureHash, val duplicates: Set<StateRef>) : TransactionVerificationException(txId, null) {
override fun toString(): String = "Duplicate inputs: ${duplicates.joinToString()}"
}
class InvalidNotaryChange(txId: SecureHash) : TransactionVerificationException(txId, null)
class NotaryChangeInWrongTransactionType(txId: SecureHash, val txNotary: Party, val outputNotary: Party) : TransactionVerificationException(txId, null) {
override fun toString(): String {
return "Found unexpected notary change in transaction. Tx notary: $txNotary, found: $outputNotary"
}
}
class TransactionMissingEncumbranceException(txId: SecureHash, val missing: Int, val inOut: Direction) : TransactionVerificationException(txId, null) {
override val message: String get() = "Missing required encumbrance $missing in $inOut"
}
@CordaSerializable
enum class Direction {
INPUT,
OUTPUT
}
}

View File

@ -0,0 +1,42 @@
package net.corda.core.contracts
import net.corda.core.crypto.SecureHash
import net.corda.core.flows.FlowException
import net.corda.core.identity.Party
import net.corda.core.serialization.CordaSerializable
import net.corda.core.utilities.NonEmptySet
import java.security.PublicKey
class TransactionResolutionException(val hash: SecureHash) : FlowException("Transaction resolution failure for $hash")
class AttachmentResolutionException(val hash: SecureHash) : FlowException("Attachment resolution failure for $hash")
sealed class TransactionVerificationException(val txId: SecureHash, message: String, cause: Throwable?)
: FlowException("$message, transaction: $txId", cause) {
class ContractRejection(txId: SecureHash, contract: Contract, cause: Throwable)
: TransactionVerificationException(txId, "Contract verification failed: ${cause.message}, contract: $contract", cause)
class MoreThanOneNotary(txId: SecureHash)
: TransactionVerificationException(txId, "More than one notary", null)
class SignersMissing(txId: SecureHash, missing: List<PublicKey>)
: TransactionVerificationException(txId, "Signers missing: ${missing.joinToString()}", null)
class DuplicateInputStates(txId: SecureHash, val duplicates: NonEmptySet<StateRef>)
: TransactionVerificationException(txId, "Duplicate inputs: ${duplicates.joinToString()}", null)
class InvalidNotaryChange(txId: SecureHash)
: TransactionVerificationException(txId, "Detected a notary change. Outputs must use the same notary as inputs", null)
class NotaryChangeInWrongTransactionType(txId: SecureHash, txNotary: Party, outputNotary: Party)
: TransactionVerificationException(txId, "Found unexpected notary change in transaction. Tx notary: $txNotary, found: $outputNotary", null)
class TransactionMissingEncumbranceException(txId: SecureHash, missing: Int, inOut: Direction)
: TransactionVerificationException(txId, "Missing required encumbrance $missing in $inOut", null)
@CordaSerializable
enum class Direction {
INPUT,
OUTPUT
}
}

View File

@ -1,6 +1,6 @@
package net.corda.core.contracts package net.corda.core.contracts
import com.google.common.annotations.VisibleForTesting import net.corda.core.internal.VisibleForTesting
import net.corda.core.serialization.CordaSerializable import net.corda.core.serialization.CordaSerializable
import java.util.* import java.util.*

View File

@ -1,10 +0,0 @@
package net.corda.core.contracts.clauses
import net.corda.core.contracts.CommandData
import net.corda.core.contracts.ContractState
/**
* Compose a number of clauses, such that all of the clauses must run for verification to pass.
*/
@Deprecated("Use AllOf")
class AllComposition<S : ContractState, C : CommandData, K : Any>(firstClause: Clause<S, C, K>, vararg remainingClauses: Clause<S, C, K>) : AllOf<S, C, K>(firstClause, *remainingClauses)

View File

@ -1,38 +0,0 @@
package net.corda.core.contracts.clauses
import net.corda.core.contracts.AuthenticatedObject
import net.corda.core.contracts.CommandData
import net.corda.core.contracts.ContractState
import net.corda.core.contracts.TransactionForContract
import java.util.*
/**
* Compose a number of clauses, such that all of the clauses must run for verification to pass.
*/
open class AllOf<S : ContractState, C : CommandData, K : Any>(firstClause: Clause<S, C, K>, vararg remainingClauses: Clause<S, C, K>) : CompositeClause<S, C, K>() {
override val clauses = ArrayList<Clause<S, C, K>>()
init {
clauses.add(firstClause)
clauses.addAll(remainingClauses)
}
override fun matchedClauses(commands: List<AuthenticatedObject<C>>): List<Clause<S, C, K>> {
clauses.forEach { clause ->
check(clause.matches(commands)) { "Failed to match clause $clause" }
}
return clauses
}
override fun verify(tx: TransactionForContract,
inputs: List<S>,
outputs: List<S>,
commands: List<AuthenticatedObject<C>>,
groupingKey: K?): Set<C> {
return matchedClauses(commands).flatMapTo(HashSet<C>()) { clause ->
clause.verify(tx, inputs, outputs, commands, groupingKey)
}
}
override fun toString() = "All: $clauses.toList()"
}

View File

@ -1,10 +0,0 @@
package net.corda.core.contracts.clauses
import net.corda.core.contracts.CommandData
import net.corda.core.contracts.ContractState
/**
* Compose a number of clauses, such that any number of the clauses can run.
*/
@Deprecated("Use AnyOf instead, although note that any of requires at least one matched clause")
class AnyComposition<in S : ContractState, C : CommandData, in K : Any>(vararg rawClauses: Clause<S, C, K>) : AnyOf<S, C, K>(*rawClauses)

View File

@ -1,28 +0,0 @@
package net.corda.core.contracts.clauses
import net.corda.core.contracts.AuthenticatedObject
import net.corda.core.contracts.CommandData
import net.corda.core.contracts.ContractState
import net.corda.core.contracts.TransactionForContract
import java.util.*
/**
* Compose a number of clauses, such that one or more of the clauses can run.
*/
open class AnyOf<in S : ContractState, C : CommandData, in K : Any>(vararg rawClauses: Clause<S, C, K>) : CompositeClause<S, C, K>() {
override val clauses: List<Clause<S, C, K>> = rawClauses.toList()
override fun matchedClauses(commands: List<AuthenticatedObject<C>>): List<Clause<S, C, K>> {
val matched = clauses.filter { it.matches(commands) }
require(matched.isNotEmpty()) { "At least one clause must match" }
return matched
}
override fun verify(tx: TransactionForContract, inputs: List<S>, outputs: List<S>, commands: List<AuthenticatedObject<C>>, groupingKey: K?): Set<C> {
return matchedClauses(commands).flatMapTo(HashSet<C>()) { clause ->
clause.verify(tx, inputs, outputs, commands, groupingKey)
}
}
override fun toString(): String = "Any: ${clauses.toList()}"
}

View File

@ -1,74 +0,0 @@
package net.corda.core.contracts.clauses
import net.corda.core.contracts.AuthenticatedObject
import net.corda.core.contracts.CommandData
import net.corda.core.contracts.ContractState
import net.corda.core.contracts.TransactionForContract
import net.corda.core.utilities.loggerFor
import org.slf4j.Logger
/**
* A clause of a contract, containing a chunk of verification logic. That logic may be delegated to other clauses, or
* provided directly by this clause.
*
* @param S the type of contract state this clause operates on.
* @param C a common supertype of commands this clause operates on.
* @param K the type of the grouping key for states this clause operates on. Use [Unit] if not applicable.
*
* @see CompositeClause
*/
abstract class Clause<in S : ContractState, C : CommandData, in K : Any> {
companion object {
val log: Logger by lazy { loggerFor<Clause<*, *, *>>() }
}
/** Determine whether this clause runs or not */
open val requiredCommands: Set<Class<out CommandData>> = emptySet()
/**
* Determine the subclauses which will be verified as a result of verifying this clause.
*
* @throws IllegalStateException if the given commands do not result in a valid execution (for example no match
* with [FirstOf]).
*/
@Throws(IllegalStateException::class)
open fun getExecutionPath(commands: List<AuthenticatedObject<C>>): List<Clause<*, *, *>>
= listOf(this)
/**
* Verify the transaction matches the conditions from this clause. For example, a "no zero amount output" clause
* would check each of the output states that it applies to, looking for a zero amount, and throw IllegalStateException
* if any matched.
*
* @param tx the full transaction being verified. This is provided for cases where clauses need to access
* states or commands outside of their normal scope.
* @param inputs input states which are relevant to this clause. By default this is the set passed into [verifyClause],
* but may be further reduced by clauses such as [GroupClauseVerifier].
* @param outputs output states which are relevant to this clause. By default this is the set passed into [verifyClause],
* but may be further reduced by clauses such as [GroupClauseVerifier].
* @param commands commands which are relevant to this clause. By default this is the set passed into [verifyClause],
* but may be further reduced by clauses such as [GroupClauseVerifier].
* @param groupingKey a grouping key applied to states and commands, where applicable. Taken from
* [TransactionForContract.InOutGroup].
* @return the set of commands that are consumed IF this clause is matched, and cannot be used to match a
* later clause. This would normally be all commands matching "requiredCommands" for this clause, but some
* verify() functions may do further filtering on possible matches, and return a subset. This may also include
* commands that were not required (for example the Exit command for fungible assets is optional).
*/
@Throws(IllegalStateException::class)
abstract fun verify(tx: TransactionForContract,
inputs: List<S>,
outputs: List<S>,
commands: List<AuthenticatedObject<C>>,
groupingKey: K?): Set<C>
}
/**
* Determine if the given list of commands matches the required commands for a clause to trigger.
*/
fun <C : CommandData> Clause<*, C, *>.matches(commands: List<AuthenticatedObject<C>>): Boolean {
return if (requiredCommands.isEmpty())
true
else
commands.map { it.value.javaClass }.toSet().containsAll(requiredCommands)
}

View File

@ -1,29 +0,0 @@
@file:JvmName("ClauseVerifier")
package net.corda.core.contracts.clauses
import net.corda.core.contracts.AuthenticatedObject
import net.corda.core.contracts.CommandData
import net.corda.core.contracts.ContractState
import net.corda.core.contracts.TransactionForContract
/**
* Verify a transaction against the given list of clauses.
*
* @param tx transaction to be verified.
* @param clauses the clauses to verify.
* @param commands commands extracted from the transaction, which are relevant to the
* clauses.
*/
fun <C : CommandData> verifyClause(tx: TransactionForContract,
clause: Clause<ContractState, C, Unit>,
commands: List<AuthenticatedObject<C>>) {
if (Clause.log.isTraceEnabled) {
clause.getExecutionPath(commands).forEach {
Clause.log.trace("Tx ${tx.origHash} clause: $clause")
}
}
val matchedCommands = clause.verify(tx, tx.inputs, tx.outputs, commands, null)
check(matchedCommands.containsAll(commands.map { it.value })) { "The following commands were not matched at the end of execution: " + (commands - matchedCommands) }
}

View File

@ -1,25 +0,0 @@
package net.corda.core.contracts.clauses
import net.corda.core.contracts.AuthenticatedObject
import net.corda.core.contracts.CommandData
import net.corda.core.contracts.ContractState
/**
* Abstract supertype for clauses which compose other clauses together in some logical manner.
*/
abstract class CompositeClause<in S : ContractState, C : CommandData, in K : Any> : Clause<S, C, K>() {
/** List of clauses under this composite clause */
abstract val clauses: List<Clause<S, C, K>>
override fun getExecutionPath(commands: List<AuthenticatedObject<C>>): List<Clause<*, *, *>>
= matchedClauses(commands).flatMap { it.getExecutionPath(commands) }
/**
* Determine which clauses are matched by the supplied commands.
*
* @throws IllegalStateException if the given commands do not result in a valid execution (for example no match
* with [FirstOf]).
*/
@Throws(IllegalStateException::class)
abstract fun matchedClauses(commands: List<AuthenticatedObject<C>>): List<Clause<S, C, K>>
}

View File

@ -1,25 +0,0 @@
package net.corda.core.contracts.clauses
import net.corda.core.contracts.AuthenticatedObject
import net.corda.core.contracts.CommandData
import net.corda.core.contracts.ContractState
import net.corda.core.contracts.TransactionForContract
/**
* Filter the states that are passed through to the wrapped clause, to restrict them to a specific type.
*/
class FilterOn<S : ContractState, C : CommandData, in K : Any>(val clause: Clause<S, C, K>,
val filterStates: (List<ContractState>) -> List<S>) : Clause<ContractState, C, K>() {
override val requiredCommands: Set<Class<out CommandData>>
= clause.requiredCommands
override fun getExecutionPath(commands: List<AuthenticatedObject<C>>): List<Clause<*, *, *>>
= clause.getExecutionPath(commands)
override fun verify(tx: TransactionForContract,
inputs: List<ContractState>,
outputs: List<ContractState>,
commands: List<AuthenticatedObject<C>>,
groupingKey: K?): Set<C>
= clause.verify(tx, filterStates(inputs), filterStates(outputs), commands, groupingKey)
}

View File

@ -1,28 +0,0 @@
package net.corda.core.contracts.clauses
import net.corda.core.contracts.AuthenticatedObject
import net.corda.core.contracts.CommandData
import net.corda.core.contracts.ContractState
import net.corda.core.contracts.TransactionForContract
import java.util.*
/**
* Compose a number of clauses, such that the first match is run, and it errors if none is run.
*/
@Deprecated("Use FirstOf instead")
class FirstComposition<S : ContractState, C : CommandData, K : Any>(firstClause: Clause<S, C, K>, vararg remainingClauses: Clause<S, C, K>) : CompositeClause<S, C, K>() {
override val clauses = ArrayList<Clause<S, C, K>>()
override fun matchedClauses(commands: List<AuthenticatedObject<C>>): List<Clause<S, C, K>> = listOf(clauses.first { it.matches(commands) })
init {
clauses.add(firstClause)
clauses.addAll(remainingClauses)
}
override fun verify(tx: TransactionForContract, inputs: List<S>, outputs: List<S>, commands: List<AuthenticatedObject<C>>, groupingKey: K?): Set<C> {
val clause = matchedClauses(commands).singleOrNull() ?: throw IllegalStateException("No delegate clause matched in first composition")
return clause.verify(tx, inputs, outputs, commands, groupingKey)
}
override fun toString() = "First: ${clauses.toList()}"
}

View File

@ -1,41 +0,0 @@
package net.corda.core.contracts.clauses
import net.corda.core.contracts.AuthenticatedObject
import net.corda.core.contracts.CommandData
import net.corda.core.contracts.ContractState
import net.corda.core.contracts.TransactionForContract
import net.corda.core.utilities.loggerFor
import java.util.*
/**
* Compose a number of clauses, such that the first match is run, and it errors if none is run.
*/
class FirstOf<S : ContractState, C : CommandData, K : Any>(firstClause: Clause<S, C, K>, vararg remainingClauses: Clause<S, C, K>) : CompositeClause<S, C, K>() {
companion object {
val logger = loggerFor<FirstOf<*, *, *>>()
}
override val clauses = ArrayList<Clause<S, C, K>>()
/**
* Get the single matched clause from the set this composes, based on the given commands. This is provided as
* helper method for internal use, rather than using the exposed [matchedClauses] function which unnecessarily
* wraps the clause in a list.
*/
private fun matchedClause(commands: List<AuthenticatedObject<C>>): Clause<S, C, K> {
return clauses.firstOrNull { it.matches(commands) } ?: throw IllegalStateException("No delegate clause matched in first composition")
}
override fun matchedClauses(commands: List<AuthenticatedObject<C>>) = listOf(matchedClause(commands))
init {
clauses.add(firstClause)
clauses.addAll(remainingClauses)
}
override fun verify(tx: TransactionForContract, inputs: List<S>, outputs: List<S>, commands: List<AuthenticatedObject<C>>, groupingKey: K?): Set<C> {
return matchedClause(commands).verify(tx, inputs, outputs, commands, groupingKey)
}
override fun toString() = "First: ${clauses.toList()}"
}

View File

@ -1,29 +0,0 @@
package net.corda.core.contracts.clauses
import net.corda.core.contracts.AuthenticatedObject
import net.corda.core.contracts.CommandData
import net.corda.core.contracts.ContractState
import net.corda.core.contracts.TransactionForContract
import java.util.*
abstract class GroupClauseVerifier<S : ContractState, C : CommandData, K : Any>(val clause: Clause<S, C, K>) : Clause<ContractState, C, Unit>() {
abstract fun groupStates(tx: TransactionForContract): List<TransactionForContract.InOutGroup<S, K>>
override fun getExecutionPath(commands: List<AuthenticatedObject<C>>): List<Clause<*, *, *>>
= clause.getExecutionPath(commands)
override fun verify(tx: TransactionForContract,
inputs: List<ContractState>,
outputs: List<ContractState>,
commands: List<AuthenticatedObject<C>>,
groupingKey: Unit?): Set<C> {
val groups = groupStates(tx)
val matchedCommands = HashSet<C>()
for ((groupInputs, groupOutputs, groupToken) in groups) {
matchedCommands.addAll(clause.verify(tx, groupInputs, groupOutputs, commands, groupToken))
}
return matchedCommands
}
}

View File

@ -4,6 +4,7 @@ import net.corda.core.crypto.composite.CompositeKey
import net.corda.core.crypto.composite.CompositeSignature import net.corda.core.crypto.composite.CompositeSignature
import net.corda.core.crypto.provider.CordaObjectIdentifier import net.corda.core.crypto.provider.CordaObjectIdentifier
import net.corda.core.crypto.provider.CordaSecurityProvider import net.corda.core.crypto.provider.CordaSecurityProvider
import net.corda.core.serialization.serialize
import net.i2p.crypto.eddsa.EdDSAEngine import net.i2p.crypto.eddsa.EdDSAEngine
import net.i2p.crypto.eddsa.EdDSAPrivateKey import net.i2p.crypto.eddsa.EdDSAPrivateKey
import net.i2p.crypto.eddsa.EdDSAPublicKey import net.i2p.crypto.eddsa.EdDSAPublicKey
@ -13,19 +14,19 @@ import net.i2p.crypto.eddsa.spec.EdDSANamedCurveSpec
import net.i2p.crypto.eddsa.spec.EdDSANamedCurveTable import net.i2p.crypto.eddsa.spec.EdDSANamedCurveTable
import net.i2p.crypto.eddsa.spec.EdDSAPrivateKeySpec import net.i2p.crypto.eddsa.spec.EdDSAPrivateKeySpec
import net.i2p.crypto.eddsa.spec.EdDSAPublicKeySpec import net.i2p.crypto.eddsa.spec.EdDSAPublicKeySpec
import org.bouncycastle.asn1.* import org.bouncycastle.asn1.ASN1Integer
import org.bouncycastle.asn1.ASN1ObjectIdentifier
import org.bouncycastle.asn1.DERNull
import org.bouncycastle.asn1.DLSequence
import org.bouncycastle.asn1.bc.BCObjectIdentifiers import org.bouncycastle.asn1.bc.BCObjectIdentifiers
import org.bouncycastle.asn1.nist.NISTObjectIdentifiers import org.bouncycastle.asn1.nist.NISTObjectIdentifiers
import org.bouncycastle.asn1.pkcs.PKCSObjectIdentifiers import org.bouncycastle.asn1.pkcs.PKCSObjectIdentifiers
import org.bouncycastle.asn1.pkcs.PrivateKeyInfo import org.bouncycastle.asn1.pkcs.PrivateKeyInfo
import org.bouncycastle.asn1.sec.SECObjectIdentifiers import org.bouncycastle.asn1.sec.SECObjectIdentifiers
import org.bouncycastle.asn1.x500.X500Name import org.bouncycastle.asn1.x509.AlgorithmIdentifier
import org.bouncycastle.asn1.x509.* import org.bouncycastle.asn1.x509.SubjectPublicKeyInfo
import org.bouncycastle.asn1.x9.X9ObjectIdentifiers import org.bouncycastle.asn1.x9.X9ObjectIdentifiers
import org.bouncycastle.cert.X509CertificateHolder import org.bouncycastle.cert.X509CertificateHolder
import org.bouncycastle.cert.X509v3CertificateBuilder
import org.bouncycastle.cert.bc.BcX509ExtensionUtils
import org.bouncycastle.cert.jcajce.JcaX509v3CertificateBuilder
import org.bouncycastle.jcajce.provider.asymmetric.ec.BCECPrivateKey import org.bouncycastle.jcajce.provider.asymmetric.ec.BCECPrivateKey
import org.bouncycastle.jcajce.provider.asymmetric.ec.BCECPublicKey import org.bouncycastle.jcajce.provider.asymmetric.ec.BCECPublicKey
import org.bouncycastle.jcajce.provider.asymmetric.rsa.BCRSAPrivateKey import org.bouncycastle.jcajce.provider.asymmetric.rsa.BCRSAPrivateKey
@ -39,10 +40,6 @@ import org.bouncycastle.jce.spec.ECPublicKeySpec
import org.bouncycastle.math.ec.ECConstants import org.bouncycastle.math.ec.ECConstants
import org.bouncycastle.math.ec.FixedPointCombMultiplier import org.bouncycastle.math.ec.FixedPointCombMultiplier
import org.bouncycastle.math.ec.WNafUtil import org.bouncycastle.math.ec.WNafUtil
import org.bouncycastle.operator.ContentSigner
import org.bouncycastle.operator.jcajce.JcaContentVerifierProviderBuilder
import org.bouncycastle.pkcs.PKCS10CertificationRequest
import org.bouncycastle.pkcs.jcajce.JcaPKCS10CertificationRequestBuilder
import org.bouncycastle.pqc.jcajce.provider.BouncyCastlePQCProvider import org.bouncycastle.pqc.jcajce.provider.BouncyCastlePQCProvider
import org.bouncycastle.pqc.jcajce.provider.sphincs.BCSphincs256PrivateKey import org.bouncycastle.pqc.jcajce.provider.sphincs.BCSphincs256PrivateKey
import org.bouncycastle.pqc.jcajce.provider.sphincs.BCSphincs256PublicKey import org.bouncycastle.pqc.jcajce.provider.sphincs.BCSphincs256PublicKey
@ -52,7 +49,6 @@ import java.security.*
import java.security.spec.InvalidKeySpecException import java.security.spec.InvalidKeySpecException
import java.security.spec.PKCS8EncodedKeySpec import java.security.spec.PKCS8EncodedKeySpec
import java.security.spec.X509EncodedKeySpec import java.security.spec.X509EncodedKeySpec
import java.util.*
import javax.crypto.Mac import javax.crypto.Mac
import javax.crypto.spec.SecretKeySpec import javax.crypto.spec.SecretKeySpec
@ -195,7 +191,7 @@ object Crypto {
// that could cause unexpected and suspicious behaviour. // that could cause unexpected and suspicious behaviour.
// i.e. if someone removes a Provider and then he/she adds a new one with the same name. // i.e. if someone removes a Provider and then he/she adds a new one with the same name.
// The val is private to avoid any harmful state changes. // The val is private to avoid any harmful state changes.
private val providerMap: Map<String, Provider> = mapOf( val providerMap: Map<String, Provider> = mapOf(
BouncyCastleProvider.PROVIDER_NAME to getBouncyCastleProvider(), BouncyCastleProvider.PROVIDER_NAME to getBouncyCastleProvider(),
CordaSecurityProvider.PROVIDER_NAME to CordaSecurityProvider(), CordaSecurityProvider.PROVIDER_NAME to CordaSecurityProvider(),
"BCPQC" to BouncyCastlePQCProvider()) // unfortunately, provider's name is not final in BouncyCastlePQCProvider, so we explicitly set it. "BCPQC" to BouncyCastlePQCProvider()) // unfortunately, provider's name is not final in BouncyCastlePQCProvider, so we explicitly set it.
@ -282,7 +278,7 @@ object Crypto {
/** /**
* Decode a PKCS8 encoded key to its [PrivateKey] object based on the input scheme code name. * Decode a PKCS8 encoded key to its [PrivateKey] object based on the input scheme code name.
* This should be used when the type key is known, e.g. during Kryo deserialisation or with key caches or key managers. * This should be used when the type key is known, e.g. during deserialisation or with key caches or key managers.
* @param schemeCodeName a [String] that should match a key in supportedSignatureSchemes map (e.g. ECDSA_SECP256K1_SHA256). * @param schemeCodeName a [String] that should match a key in supportedSignatureSchemes map (e.g. ECDSA_SECP256K1_SHA256).
* @param encodedKey a PKCS8 encoded private key. * @param encodedKey a PKCS8 encoded private key.
* @throws IllegalArgumentException on not supported scheme or if the given key specification * @throws IllegalArgumentException on not supported scheme or if the given key specification
@ -293,7 +289,7 @@ object Crypto {
/** /**
* Decode a PKCS8 encoded key to its [PrivateKey] object based on the input scheme code name. * Decode a PKCS8 encoded key to its [PrivateKey] object based on the input scheme code name.
* This should be used when the type key is known, e.g. during Kryo deserialisation or with key caches or key managers. * This should be used when the type key is known, e.g. during deserialisation or with key caches or key managers.
* @param signatureScheme a signature scheme (e.g. ECDSA_SECP256K1_SHA256). * @param signatureScheme a signature scheme (e.g. ECDSA_SECP256K1_SHA256).
* @param encodedKey a PKCS8 encoded private key. * @param encodedKey a PKCS8 encoded private key.
* @throws IllegalArgumentException on not supported scheme or if the given key specification * @throws IllegalArgumentException on not supported scheme or if the given key specification
@ -325,7 +321,7 @@ object Crypto {
/** /**
* Decode an X509 encoded key to its [PrivateKey] object based on the input scheme code name. * Decode an X509 encoded key to its [PrivateKey] object based on the input scheme code name.
* This should be used when the type key is known, e.g. during Kryo deserialisation or with key caches or key managers. * This should be used when the type key is known, e.g. during deserialisation or with key caches or key managers.
* @param schemeCodeName a [String] that should match a key in supportedSignatureSchemes map (e.g. ECDSA_SECP256K1_SHA256). * @param schemeCodeName a [String] that should match a key in supportedSignatureSchemes map (e.g. ECDSA_SECP256K1_SHA256).
* @param encodedKey an X509 encoded public key. * @param encodedKey an X509 encoded public key.
* @throws IllegalArgumentException if the requested scheme is not supported. * @throws IllegalArgumentException if the requested scheme is not supported.
@ -337,7 +333,7 @@ object Crypto {
/** /**
* Decode an X509 encoded key to its [PrivateKey] object based on the input scheme code name. * Decode an X509 encoded key to its [PrivateKey] object based on the input scheme code name.
* This should be used when the type key is known, e.g. during Kryo deserialisation or with key caches or key managers. * This should be used when the type key is known, e.g. during deserialisation or with key caches or key managers.
* @param signatureScheme a signature scheme (e.g. ECDSA_SECP256K1_SHA256). * @param signatureScheme a signature scheme (e.g. ECDSA_SECP256K1_SHA256).
* @param encodedKey an X509 encoded public key. * @param encodedKey an X509 encoded public key.
* @throws IllegalArgumentException if the requested scheme is not supported. * @throws IllegalArgumentException if the requested scheme is not supported.
@ -401,23 +397,23 @@ object Crypto {
} }
/** /**
* Generic way to sign [MetaData] objects with a [PrivateKey]. * Generic way to sign [SignableData] objects with a [PrivateKey].
* [MetaData] is a wrapper over the transaction's Merkle root in order to attach extra information, such as a timestamp or partial and blind signature indicators. * [SignableData] is a wrapper over the transaction's id (Merkle root) in order to attach extra information, such as a timestamp or partial and blind signature indicators.
* @param privateKey the signer's [PrivateKey]. * @param privateKey the signer's [PrivateKey].
* @param metaData a [MetaData] object that adds extra information to a transaction. * @param signableData a [SignableData] object that adds extra information to a transaction.
* @return a [TransactionSignature] object than contains the output of a successful signing and the metaData. * @return a [TransactionSignature] object than contains the output of a successful signing, signer's public key and the signature metadata.
* @throws IllegalArgumentException if the signature scheme is not supported for this private key or * @throws IllegalArgumentException if the signature scheme is not supported for this private key.
* if metaData.schemeCodeName is not aligned with key type.
* @throws InvalidKeyException if the private key is invalid. * @throws InvalidKeyException if the private key is invalid.
* @throws SignatureException if signing is not possible due to malformed data or private key. * @throws SignatureException if signing is not possible due to malformed data or private key.
*/ */
@Throws(IllegalArgumentException::class, InvalidKeyException::class, SignatureException::class) @Throws(IllegalArgumentException::class, InvalidKeyException::class, SignatureException::class)
fun doSign(privateKey: PrivateKey, metaData: MetaData): TransactionSignature { fun doSign(keyPair: KeyPair, signableData: SignableData): TransactionSignature {
val sigKey: SignatureScheme = findSignatureScheme(privateKey) val sigKey: SignatureScheme = findSignatureScheme(keyPair.private)
val sigMetaData: SignatureScheme = findSignatureScheme(metaData.schemeCodeName) val sigMetaData: SignatureScheme = findSignatureScheme(keyPair.public)
if (sigKey != sigMetaData) throw IllegalArgumentException("Metadata schemeCodeName: ${metaData.schemeCodeName} is not aligned with the key type.") if (sigKey != sigMetaData) throw IllegalArgumentException("Metadata schemeCodeName: ${sigMetaData.schemeCodeName}" +
val signatureData = doSign(sigKey.schemeCodeName, privateKey, metaData.bytes()) " is not aligned with the key type: ${sigKey.schemeCodeName}.")
return TransactionSignature(signatureData, metaData) val signatureBytes = doSign(sigKey.schemeCodeName, keyPair.private, signableData.serialize().bytes)
return TransactionSignature(signatureBytes, keyPair.public, signableData.signatureMetadata)
} }
/** /**
@ -434,7 +430,7 @@ object Crypto {
* if this signatureData scheme is unable to process the input data provided, if the verification is not possible. * if this signatureData scheme is unable to process the input data provided, if the verification is not possible.
* @throws IllegalArgumentException if the signature scheme is not supported or if any of the clear or signature data is empty. * @throws IllegalArgumentException if the signature scheme is not supported or if any of the clear or signature data is empty.
*/ */
@Throws(InvalidKeyException::class, SignatureException::class, IllegalArgumentException::class) @Throws(InvalidKeyException::class, SignatureException::class)
fun doVerify(schemeCodeName: String, publicKey: PublicKey, signatureData: ByteArray, clearData: ByteArray) = doVerify(findSignatureScheme(schemeCodeName), publicKey, signatureData, clearData) fun doVerify(schemeCodeName: String, publicKey: PublicKey, signatureData: ByteArray, clearData: ByteArray) = doVerify(findSignatureScheme(schemeCodeName), publicKey, signatureData, clearData)
/** /**
@ -485,9 +481,9 @@ object Crypto {
/** /**
* Utility to simplify the act of verifying a [TransactionSignature]. * Utility to simplify the act of verifying a [TransactionSignature].
* It returns true if it succeeds, but it always throws an exception if verification fails. * It returns true if it succeeds, but it always throws an exception if verification fails.
* @param publicKey the signer's [PublicKey]. * @param txId transaction's id (Merkle root).
* @param transactionSignature the signatureData on a message. * @param transactionSignature the signature on the transaction.
* @return true if verification passes or throws an exception if verification fails. * @return true if verification passes or throw exception if verification fails.
* @throws InvalidKeyException if the key is invalid. * @throws InvalidKeyException if the key is invalid.
* @throws SignatureException if this signatureData object is not initialized properly, * @throws SignatureException if this signatureData object is not initialized properly,
* the passed-in signatureData is improperly encoded or of the wrong type, * the passed-in signatureData is improperly encoded or of the wrong type,
@ -495,9 +491,26 @@ object Crypto {
* @throws IllegalArgumentException if the signature scheme is not supported or if any of the clear or signature data is empty. * @throws IllegalArgumentException if the signature scheme is not supported or if any of the clear or signature data is empty.
*/ */
@Throws(InvalidKeyException::class, SignatureException::class, IllegalArgumentException::class) @Throws(InvalidKeyException::class, SignatureException::class, IllegalArgumentException::class)
fun doVerify(publicKey: PublicKey, transactionSignature: TransactionSignature): Boolean { fun doVerify(txId: SecureHash, transactionSignature: TransactionSignature): Boolean {
if (publicKey != transactionSignature.metaData.publicKey) IllegalArgumentException("MetaData's publicKey: ${transactionSignature.metaData.publicKey.toStringShort()} does not match") val signableData = SignableData(txId, transactionSignature.signatureMetadata)
return Crypto.doVerify(publicKey, transactionSignature.signatureData, transactionSignature.metaData.bytes()) return Crypto.doVerify(transactionSignature.by, transactionSignature.bytes, signableData.serialize().bytes)
}
/**
* Utility to simplify the act of verifying a digital signature by identifying the signature scheme used from the input public key's type.
* It returns true if it succeeds and false if not. In comparison to [doVerify] if the key and signature
* do not match it returns false rather than throwing an exception. Normally you should use the function which throws,
* as it avoids the risk of failing to test the result.
* @param txId transaction's id (Merkle root).
* @param transactionSignature the signature on the transaction.
* @throws SignatureException if this signatureData object is not initialized properly,
* the passed-in signatureData is improperly encoded or of the wrong type,
* if this signatureData scheme is unable to process the input data provided, if the verification is not possible.
*/
@Throws(SignatureException::class)
fun isValid(txId: SecureHash, transactionSignature: TransactionSignature): Boolean {
val signableData = SignableData(txId, transactionSignature.signatureMetadata)
return isValid(findSignatureScheme(transactionSignature.by), transactionSignature.by, transactionSignature.bytes, signableData.serialize().bytes)
} }
/** /**
@ -752,90 +765,6 @@ object Crypto {
return mac.doFinal(seed) return mac.doFinal(seed)
} }
/**
* Build a partial X.509 certificate ready for signing.
*
* @param issuer name of the issuing entity.
* @param subject name of the certificate subject.
* @param subjectPublicKey public key of the certificate subject.
* @param validityWindow the time period the certificate is valid for.
* @param nameConstraints any name constraints to impose on certificates signed by the generated certificate.
*/
fun createCertificate(certificateType: CertificateType, issuer: X500Name,
subject: X500Name, subjectPublicKey: PublicKey,
validityWindow: Pair<Date, Date>,
nameConstraints: NameConstraints? = null): X509v3CertificateBuilder {
val serial = BigInteger.valueOf(random63BitValue())
val keyPurposes = DERSequence(ASN1EncodableVector().apply { certificateType.purposes.forEach { add(it) } })
val subjectPublicKeyInfo = SubjectPublicKeyInfo.getInstance(ASN1Sequence.getInstance(subjectPublicKey.encoded))
val builder = JcaX509v3CertificateBuilder(issuer, serial, validityWindow.first, validityWindow.second, subject, subjectPublicKey)
.addExtension(Extension.subjectKeyIdentifier, false, BcX509ExtensionUtils().createSubjectKeyIdentifier(subjectPublicKeyInfo))
.addExtension(Extension.basicConstraints, certificateType.isCA, BasicConstraints(certificateType.isCA))
.addExtension(Extension.keyUsage, false, certificateType.keyUsage)
.addExtension(Extension.extendedKeyUsage, false, keyPurposes)
if (nameConstraints != null) {
builder.addExtension(Extension.nameConstraints, true, nameConstraints)
}
return builder
}
/**
* Build and sign an X.509 certificate with the given signer.
*
* @param issuer name of the issuing entity.
* @param issuerSigner content signer to sign the certificate with.
* @param subject name of the certificate subject.
* @param subjectPublicKey public key of the certificate subject.
* @param validityWindow the time period the certificate is valid for.
* @param nameConstraints any name constraints to impose on certificates signed by the generated certificate.
*/
fun createCertificate(certificateType: CertificateType, issuer: X500Name, issuerSigner: ContentSigner,
subject: X500Name, subjectPublicKey: PublicKey,
validityWindow: Pair<Date, Date>,
nameConstraints: NameConstraints? = null): X509CertificateHolder {
val builder = createCertificate(certificateType, issuer, subject, subjectPublicKey, validityWindow, nameConstraints)
return builder.build(issuerSigner).apply {
require(isValidOn(Date()))
}
}
/**
* Build and sign an X.509 certificate with CA cert private key.
*
* @param issuer name of the issuing entity.
* @param issuerKeyPair the public & private key to sign the certificate with.
* @param subject name of the certificate subject.
* @param subjectPublicKey public key of the certificate subject.
* @param validityWindow the time period the certificate is valid for.
* @param nameConstraints any name constraints to impose on certificates signed by the generated certificate.
*/
fun createCertificate(certificateType: CertificateType, issuer: X500Name, issuerKeyPair: KeyPair,
subject: X500Name, subjectPublicKey: PublicKey,
validityWindow: Pair<Date, Date>,
nameConstraints: NameConstraints? = null): X509CertificateHolder {
val signatureScheme = findSignatureScheme(issuerKeyPair.private)
val provider = providerMap[signatureScheme.providerName]
val builder = createCertificate(certificateType, issuer, subject, subjectPublicKey, validityWindow, nameConstraints)
val signer = ContentSignerBuilder.build(signatureScheme, issuerKeyPair.private, provider)
return builder.build(signer).apply {
require(isValidOn(Date()))
require(isSignatureValid(JcaContentVerifierProviderBuilder().build(issuerKeyPair.public)))
}
}
/**
* Create certificate signing request using provided information.
*/
fun createCertificateSigningRequest(subject: X500Name, keyPair: KeyPair, signatureScheme: SignatureScheme): PKCS10CertificationRequest {
val signer = ContentSignerBuilder.build(signatureScheme, keyPair.private, providerMap[signatureScheme.providerName])
return JcaPKCS10CertificationRequestBuilder(subject, keyPair.public).build(signer)
}
private class KeyInfoConverter(val signatureScheme: SignatureScheme) : AsymmetricKeyInfoConverter { private class KeyInfoConverter(val signatureScheme: SignatureScheme) : AsymmetricKeyInfoConverter {
override fun generatePublic(keyInfo: SubjectPublicKeyInfo?): PublicKey? = keyInfo?.let { decodePublicKey(signatureScheme, it.encoded) } override fun generatePublic(keyInfo: SubjectPublicKeyInfo?): PublicKey? = keyInfo?.let { decodePublicKey(signatureScheme, it.encoded) }
override fun generatePrivate(keyInfo: PrivateKeyInfo?): PrivateKey? = keyInfo?.let { decodePrivateKey(signatureScheme, it.encoded) } override fun generatePrivate(keyInfo: PrivateKeyInfo?): PrivateKey? = keyInfo?.let { decodePrivateKey(signatureScheme, it.encoded) }

View File

@ -35,7 +35,17 @@ fun PrivateKey.sign(bytesToSign: ByteArray, publicKey: PublicKey): DigitalSignat
*/ */
@Throws(IllegalArgumentException::class, InvalidKeyException::class, SignatureException::class) @Throws(IllegalArgumentException::class, InvalidKeyException::class, SignatureException::class)
fun KeyPair.sign(bytesToSign: ByteArray) = private.sign(bytesToSign, public) fun KeyPair.sign(bytesToSign: ByteArray) = private.sign(bytesToSign, public)
fun KeyPair.sign(bytesToSign: OpaqueBytes) = private.sign(bytesToSign.bytes, public) fun KeyPair.sign(bytesToSign: OpaqueBytes) = sign(bytesToSign.bytes)
/**
* Helper function for signing a [SignableData] object.
* @param signableData the object to be signed.
* @return a [TransactionSignature] object.
* @throws IllegalArgumentException if the signature scheme is not supported for this private key.
* @throws InvalidKeyException if the private key is invalid.
* @throws SignatureException if signing is not possible due to malformed data or private key.
*/
@Throws(InvalidKeyException::class, SignatureException::class)
fun KeyPair.sign(signableData: SignableData): TransactionSignature = Crypto.doSign(this, signableData)
/** /**
* Utility to simplify the act of verifying a signature. * Utility to simplify the act of verifying a signature.
@ -89,7 +99,7 @@ fun PublicKey.containsAny(otherKeys: Iterable<PublicKey>): Boolean {
} }
/** Returns the set of all [PublicKey]s of the signatures */ /** Returns the set of all [PublicKey]s of the signatures */
fun Iterable<DigitalSignature.WithKey>.byKeys() = map { it.by }.toSet() fun Iterable<TransactionSignature>.byKeys() = map { it.by }.toSet()
// Allow Kotlin destructuring: val (private, public) = keyPair // Allow Kotlin destructuring: val (private, public) = keyPair
operator fun KeyPair.component1(): PrivateKey = this.private operator fun KeyPair.component1(): PrivateKey = this.private
@ -106,17 +116,6 @@ fun generateKeyPair(): KeyPair = Crypto.generateKeyPair()
*/ */
fun entropyToKeyPair(entropy: BigInteger): KeyPair = Crypto.deriveKeyPairFromEntropy(entropy) fun entropyToKeyPair(entropy: BigInteger): KeyPair = Crypto.deriveKeyPairFromEntropy(entropy)
/**
* Helper function for signing.
* @param metaData tha attached MetaData object.
* @return a [TransactionSignature ] object.
* @throws IllegalArgumentException if the signature scheme is not supported for this private key.
* @throws InvalidKeyException if the private key is invalid.
* @throws SignatureException if signing is not possible due to malformed data or private key.
*/
@Throws(InvalidKeyException::class, SignatureException::class, IllegalArgumentException::class)
fun PrivateKey.sign(metaData: MetaData): TransactionSignature = Crypto.doSign(this, metaData)
/** /**
* Helper function to verify a signature. * Helper function to verify a signature.
* @param signatureData the signature on a message. * @param signatureData the signature on a message.
@ -130,21 +129,6 @@ fun PrivateKey.sign(metaData: MetaData): TransactionSignature = Crypto.doSign(th
@Throws(InvalidKeyException::class, SignatureException::class, IllegalArgumentException::class) @Throws(InvalidKeyException::class, SignatureException::class, IllegalArgumentException::class)
fun PublicKey.verify(signatureData: ByteArray, clearData: ByteArray): Boolean = Crypto.doVerify(this, signatureData, clearData) fun PublicKey.verify(signatureData: ByteArray, clearData: ByteArray): Boolean = Crypto.doVerify(this, signatureData, clearData)
/**
* Helper function to verify a metadata attached signature. It is noted that the transactionSignature contains
* signatureData and a [MetaData] object that contains the signer's public key and the transaction's Merkle root.
* @param transactionSignature a [TransactionSignature] object that .
* @throws InvalidKeyException if the key is invalid.
* @throws SignatureException if this signatureData object is not initialized properly,
* the passed-in signatureData is improperly encoded or of the wrong type,
* if this signatureData algorithm is unable to process the input data provided, etc.
* @throws IllegalArgumentException if the signature scheme is not supported for this private key or if any of the clear or signature data is empty.
*/
@Throws(InvalidKeyException::class, SignatureException::class, IllegalArgumentException::class)
fun PublicKey.verify(transactionSignature: TransactionSignature): Boolean {
return Crypto.doVerify(this, transactionSignature)
}
/** /**
* Helper function for the signers to verify their own signature. * Helper function for the signers to verify their own signature.
* @param signatureData the signature on a message. * @param signatureData the signature on a message.

View File

@ -60,7 +60,7 @@ fun String.hexToBase58(): String = hexToByteArray().toBase58()
/** Encoding changer. Hex-[String] to Base64-[String], i.e. "48656C6C6F20576F726C64" -> "SGVsbG8gV29ybGQ=" */ /** Encoding changer. Hex-[String] to Base64-[String], i.e. "48656C6C6F20576F726C64" -> "SGVsbG8gV29ybGQ=" */
fun String.hexToBase64(): String = hexToByteArray().toBase64() fun String.hexToBase64(): String = hexToByteArray().toBase64()
// TODO We use for both CompositeKeys and EdDSAPublicKey custom Kryo serializers and deserializers. We need to specify encoding. // TODO We use for both CompositeKeys and EdDSAPublicKey custom serializers and deserializers. We need to specify encoding.
// TODO: follow the crypto-conditions ASN.1 spec, some changes are needed to be compatible with the condition // TODO: follow the crypto-conditions ASN.1 spec, some changes are needed to be compatible with the condition
// structure, e.g. mapping a PublicKey to a condition with the specific feature (ED25519). // structure, e.g. mapping a PublicKey to a condition with the specific feature (ED25519).
fun parsePublicKeyBase58(base58String: String): PublicKey = base58String.base58ToByteArray().deserialize<PublicKey>() fun parsePublicKeyBase58(base58String: String): PublicKey = base58String.base58ToByteArray().deserialize<PublicKey>()

View File

@ -23,8 +23,10 @@ sealed class MerkleTree {
/** /**
* Merkle tree building using hashes, with zero hash padding to full power of 2. * Merkle tree building using hashes, with zero hash padding to full power of 2.
*/ */
@Throws(IllegalArgumentException::class) @Throws(MerkleTreeException::class)
fun getMerkleTree(allLeavesHashes: List<SecureHash>): MerkleTree { fun getMerkleTree(allLeavesHashes: List<SecureHash>): MerkleTree {
if (allLeavesHashes.isEmpty())
throw MerkleTreeException("Cannot calculate Merkle root on empty hash list.")
val leaves = padWithZeros(allLeavesHashes).map { Leaf(it) } val leaves = padWithZeros(allLeavesHashes).map { Leaf(it) }
return buildMerkleTree(leaves) return buildMerkleTree(leaves)
} }
@ -46,8 +48,6 @@ sealed class MerkleTree {
* @return Tree root. * @return Tree root.
*/ */
private tailrec fun buildMerkleTree(lastNodesList: List<MerkleTree>): MerkleTree { private tailrec fun buildMerkleTree(lastNodesList: List<MerkleTree>): MerkleTree {
if (lastNodesList.isEmpty())
throw MerkleTreeException("Cannot calculate Merkle root on empty hash list.")
if (lastNodesList.size == 1) { if (lastNodesList.size == 1) {
return lastNodesList[0] //Root reached. return lastNodesList[0] //Root reached.
} else { } else {

View File

@ -1,71 +0,0 @@
package net.corda.core.crypto
import net.corda.core.serialization.CordaSerializable
import net.corda.core.utilities.opaque
import net.corda.core.serialization.serialize
import java.security.PublicKey
import java.time.Instant
import java.util.*
/**
* A [MetaData] object adds extra information to a transaction. MetaData is used to support a universal
* digital signature model enabling full, partial, fully or partially blind and metaData attached signatures,
* (such as an attached timestamp). A MetaData object contains both the merkle root of the transaction and the signer's public key.
* When signatureType is set to FULL, then visibleInputs and signedInputs can be ignored.
* Note: We could omit signatureType as it can always be defined by combining visibleInputs and signedInputs,
* but it helps to speed up the process when FULL is used, and thus we can bypass the extra check on boolean arrays.
*
* @param schemeCodeName a signature scheme's code name (e.g. ECDSA_SECP256K1_SHA256).
* @param versionID DLT's version.
* @param signatureType type of the signature, see [SignatureType] (e.g. FULL, PARTIAL, BLIND, PARTIAL_AND_BLIND).
* @param timestamp the signature's timestamp as provided by the signer.
* @param visibleInputs for partially/fully blind signatures. We use Merkle tree boolean index flags (from left to right)
* indicating what parts of the transaction were visible when the signature was calculated.
* @param signedInputs for partial signatures. We use Merkle tree boolean index flags (from left to right)
* indicating what parts of the Merkle tree are actually signed.
* @param merkleRoot the Merkle root of the transaction.
* @param publicKey the signer's public key.
*/
@CordaSerializable
open class MetaData(
val schemeCodeName: String,
val versionID: String,
val signatureType: SignatureType = SignatureType.FULL,
val timestamp: Instant?,
val visibleInputs: BitSet?,
val signedInputs: BitSet?,
val merkleRoot: ByteArray,
val publicKey: PublicKey) {
fun bytes() = this.serialize().bytes
override fun equals(other: Any?): Boolean {
if (this === other) return true
if (other?.javaClass != javaClass) return false
other as MetaData
if (schemeCodeName != other.schemeCodeName) return false
if (versionID != other.versionID) return false
if (signatureType != other.signatureType) return false
if (timestamp != other.timestamp) return false
if (visibleInputs != other.visibleInputs) return false
if (signedInputs != other.signedInputs) return false
if (merkleRoot.opaque() != other.merkleRoot.opaque()) return false
if (publicKey != other.publicKey) return false
return true
}
override fun hashCode(): Int {
var result = schemeCodeName.hashCode()
result = 31 * result + versionID.hashCode()
result = 31 * result + signatureType.hashCode()
result = 31 * result + (timestamp?.hashCode() ?: 0)
result = 31 * result + (visibleInputs?.hashCode() ?: 0)
result = 31 * result + (signedInputs?.hashCode() ?: 0)
result = 31 * result + Arrays.hashCode(merkleRoot)
result = 31 * result + publicKey.hashCode()
return result
}
}

View File

@ -5,9 +5,7 @@ import net.corda.core.serialization.CordaSerializable
import java.util.* import java.util.*
@CordaSerializable @CordaSerializable
class MerkleTreeException(val reason: String) : Exception() { class MerkleTreeException(val reason: String) : Exception("Partial Merkle Tree exception. Reason: $reason")
override fun toString() = "Partial Merkle Tree exception. Reason: $reason"
}
/** /**
* Building and verification of Partial Merkle Tree. * Building and verification of Partial Merkle Tree.

View File

@ -1,8 +1,9 @@
package net.corda.core.crypto package net.corda.core.crypto
import com.google.common.io.BaseEncoding
import net.corda.core.serialization.CordaSerializable import net.corda.core.serialization.CordaSerializable
import net.corda.core.utilities.OpaqueBytes import net.corda.core.utilities.OpaqueBytes
import net.corda.core.utilities.parseAsHex
import net.corda.core.utilities.toHexString
import java.security.MessageDigest import java.security.MessageDigest
/** /**
@ -18,7 +19,7 @@ sealed class SecureHash(bytes: ByteArray) : OpaqueBytes(bytes) {
} }
} }
override fun toString(): String = BaseEncoding.base16().encode(bytes) override fun toString(): String = bytes.toHexString()
fun prefixChars(prefixLen: Int = 6) = toString().substring(0, prefixLen) fun prefixChars(prefixLen: Int = 6) = toString().substring(0, prefixLen)
fun hashConcat(other: SecureHash) = (this.bytes + other.bytes).sha256() fun hashConcat(other: SecureHash) = (this.bytes + other.bytes).sha256()
@ -26,7 +27,7 @@ sealed class SecureHash(bytes: ByteArray) : OpaqueBytes(bytes) {
// Like static methods in Java, except the 'companion' is a singleton that can have state. // Like static methods in Java, except the 'companion' is a singleton that can have state.
companion object { companion object {
@JvmStatic @JvmStatic
fun parse(str: String) = BaseEncoding.base16().decode(str.toUpperCase()).let { fun parse(str: String) = str.toUpperCase().parseAsHex().let {
when (it.size) { when (it.size) {
32 -> SHA256(it) 32 -> SHA256(it)
else -> throw IllegalArgumentException("Provided string is ${it.size} bytes not 32 bytes in hex: $str") else -> throw IllegalArgumentException("Provided string is ${it.size} bytes not 32 bytes in hex: $str")

View File

@ -0,0 +1,14 @@
package net.corda.core.crypto
import net.corda.core.serialization.CordaSerializable
/**
* A [SignableData] object is the packet actually signed.
* It works as a wrapper over transaction id and signature metadata.
*
* @param txId transaction's id.
* @param signatureMetadata meta data required.
*/
@CordaSerializable
data class SignableData(val txId: SecureHash, val signatureMetadata: SignatureMetadata)

View File

@ -0,0 +1,15 @@
package net.corda.core.crypto
import net.corda.core.serialization.CordaSerializable
/**
* SignatureMeta is required to add extra meta-data to a transaction's signature.
* It currently supports platformVersion only, but it can be extended to support a universal digital
* signature model enabling partial signatures and attaching extra information, such as a user's timestamp or other
* application-specific fields.
*
* @param platformVersion current DLT version.
* @param schemeNumberID number id of the signature scheme used based on signer's key-pair, see [SignatureScheme.schemeNumberID].
*/
@CordaSerializable
data class SignatureMetadata(val platformVersion: Int, val schemeNumberID: Int)

View File

@ -6,7 +6,7 @@ import java.security.spec.AlgorithmParameterSpec
/** /**
* This class is used to define a digital signature scheme. * This class is used to define a digital signature scheme.
* @param schemeNumberID we assign a number ID for more efficient on-wire serialisation. Please ensure uniqueness between schemes. * @param schemeNumberID we assign a number ID for better efficiency on-wire serialisation. Please ensure uniqueness between schemes.
* @param schemeCodeName code name for this signature scheme (e.g. RSA_SHA256, ECDSA_SECP256K1_SHA256, ECDSA_SECP256R1_SHA256, EDDSA_ED25519_SHA512, SPHINCS-256_SHA512). * @param schemeCodeName code name for this signature scheme (e.g. RSA_SHA256, ECDSA_SECP256K1_SHA256, ECDSA_SECP256R1_SHA256, EDDSA_ED25519_SHA512, SPHINCS-256_SHA512).
* @param signatureOID ASN.1 algorithm identifier of the signature algorithm (e.g 1.3.101.112 for EdDSA) * @param signatureOID ASN.1 algorithm identifier of the signature algorithm (e.g 1.3.101.112 for EdDSA)
* @param alternativeOIDs ASN.1 algorithm identifiers for keys of the signature, where we want to map multiple keys to * @param alternativeOIDs ASN.1 algorithm identifiers for keys of the signature, where we want to map multiple keys to

View File

@ -1,17 +0,0 @@
package net.corda.core.crypto
import net.corda.core.serialization.CordaSerializable
/**
* Supported Signature types:
* <p><ul>
* <li>FULL = signature covers whole transaction, by the convention that signing the Merkle root, it is equivalent to signing all parts of the transaction.
* <li>PARTIAL = signature covers only a part of the transaction, see [MetaData].
* <li>BLIND = when an entity blindly signs without having full knowledge on the content, see [MetaData].
* <li>PARTIAL_AND_BLIND = combined PARTIAL and BLIND in the same time.
* </ul>
*/
@CordaSerializable
enum class SignatureType {
FULL, PARTIAL, BLIND, PARTIAL_AND_BLIND
}

View File

@ -23,7 +23,8 @@ open class SignedData<T : Any>(val raw: SerializedBytes<T>, val sig: DigitalSign
@Throws(SignatureException::class) @Throws(SignatureException::class)
fun verified(): T { fun verified(): T {
sig.by.verify(raw.bytes, sig) sig.by.verify(raw.bytes, sig)
val data = raw.deserialize() @Suppress("UNCHECKED_CAST")
val data = raw.deserialize<Any>() as T
verifyData(data) verifyData(data)
return data return data
} }

View File

@ -1,22 +1,57 @@
package net.corda.core.crypto package net.corda.core.crypto
import net.corda.core.serialization.CordaSerializable
import java.security.InvalidKeyException import java.security.InvalidKeyException
import java.security.PublicKey
import java.security.SignatureException import java.security.SignatureException
import java.util.*
/** /**
* A wrapper around a digital signature accompanied with metadata, see [MetaData.Full] and [DigitalSignature]. * A wrapper over the signature output accompanied by signer's public key and signature metadata.
* The signature protocol works as follows: s = sign(MetaData.hashBytes). * This is similar to [DigitalSignature.WithKey], but targeted to DLT transaction signatures.
*/ */
open class TransactionSignature(val signatureData: ByteArray, val metaData: MetaData) : DigitalSignature(signatureData) { @CordaSerializable
class TransactionSignature(bytes: ByteArray, val by: PublicKey, val signatureMetadata: SignatureMetadata): DigitalSignature(bytes) {
/** /**
* Function to auto-verify a [MetaData] object's signature. * Function to verify a [SignableData] object's signature.
* Note that [MetaData] contains both public key and merkle root of the transaction. * Note that [SignableData] contains the id of the transaction and extra metadata, such as DLT's platform version.
*
* @param txId transaction's id (Merkle root), which along with [signatureMetadata] will be used to construct the [SignableData] object to be signed.
* @throws InvalidKeyException if the key is invalid. * @throws InvalidKeyException if the key is invalid.
* @throws SignatureException if this signatureData object is not initialized properly, * @throws SignatureException if this signatureData object is not initialized properly,
* the passed-in signatureData is improperly encoded or of the wrong type, * the passed-in signatureData is improperly encoded or of the wrong type,
* if this signatureData algorithm is unable to process the input data provided, etc. * if this signatureData algorithm is unable to process the input data provided, etc.
* @throws IllegalArgumentException if the signature scheme is not supported for this private key or if any of the clear or signature data is empty. * @throws IllegalArgumentException if the signature scheme is not supported for this private key or if any of the clear or signature data is empty.
*/ */
@Throws(InvalidKeyException::class, SignatureException::class, IllegalArgumentException::class) @Throws(InvalidKeyException::class, SignatureException::class)
fun verify(): Boolean = Crypto.doVerify(metaData.publicKey, signatureData, metaData.bytes()) fun verify(txId: SecureHash) = Crypto.doVerify(txId, this)
/**
* Utility to simplify the act of verifying a signature. In comparison to [verify] doesn't throw an
* exception, making it more suitable where a boolean is required, but normally you should use the function
* which throws, as it avoids the risk of failing to test the result.
*
* @throws InvalidKeyException if the key to verify the signature with is not valid (i.e. wrong key type for the
* signature).
* @throws SignatureException if the signature is invalid (i.e. damaged).
* @return whether the signature is correct for this key.
*/
@Throws(InvalidKeyException::class, SignatureException::class)
fun isValid(txId: SecureHash) = Crypto.isValid(txId, this)
override fun equals(other: Any?): Boolean {
if (this === other) return true
if (other !is TransactionSignature) return false
return (Arrays.equals(bytes, other.bytes)
&& by == other.by
&& signatureMetadata == other.signatureMetadata)
}
override fun hashCode(): Int {
var result = super.hashCode()
result = 31 * result + by.hashCode()
result = 31 * result + signatureMetadata.hashCode()
return result
}
} }

View File

@ -0,0 +1,77 @@
@file:JvmName("X500NameUtils")
package net.corda.core.crypto
import org.bouncycastle.asn1.ASN1Encodable
import org.bouncycastle.asn1.x500.X500Name
import org.bouncycastle.asn1.x500.X500NameBuilder
import org.bouncycastle.asn1.x500.style.BCStyle
import org.bouncycastle.cert.X509CertificateHolder
import org.bouncycastle.cert.jcajce.JcaX509CertificateConverter
import java.security.KeyPair
import java.security.cert.X509Certificate
/**
* Rebuild the distinguished name, adding a postfix to the common name. If no common name is present.
* @throws IllegalArgumentException if the distinguished name does not contain a common name element.
*/
fun X500Name.appendToCommonName(commonName: String): X500Name = mutateCommonName { attr -> attr.toString() + commonName }
/**
* Rebuild the distinguished name, replacing the common name with the given value. If no common name is present, this
* adds one.
* @throws IllegalArgumentException if the distinguished name does not contain a common name element.
*/
fun X500Name.replaceCommonName(commonName: String): X500Name = mutateCommonName { _ -> commonName }
/**
* Rebuild the distinguished name, replacing the common name with a value generated from the provided function.
*
* @param mutator a function to generate the new value from the previous one.
* @throws IllegalArgumentException if the distinguished name does not contain a common name element.
*/
private fun X500Name.mutateCommonName(mutator: (ASN1Encodable) -> String): X500Name {
val builder = X500NameBuilder(BCStyle.INSTANCE)
var matched = false
this.rdNs.forEach { rdn ->
rdn.typesAndValues.forEach { typeAndValue ->
when (typeAndValue.type) {
BCStyle.CN -> {
matched = true
builder.addRDN(typeAndValue.type, mutator(typeAndValue.value))
}
else -> {
builder.addRDN(typeAndValue)
}
}
}
}
require(matched) { "Input X.500 name must include a common name (CN) attribute: ${this}" }
return builder.build()
}
val X500Name.commonName: String get() = getRDNs(BCStyle.CN).first().first.value.toString()
val X500Name.orgName: String? get() = getRDNs(BCStyle.O).firstOrNull()?.first?.value?.toString()
val X500Name.location: String get() = getRDNs(BCStyle.L).first().first.value.toString()
val X500Name.locationOrNull: String? get() = try {
location
} catch (e: Exception) {
null
}
val X509Certificate.subject: X500Name get() = X509CertificateHolder(encoded).subject
val X509CertificateHolder.cert: X509Certificate get() = JcaX509CertificateConverter().getCertificate(this)
/**
* Generate a distinguished name from the provided values.
*/
@JvmOverloads
fun getX509Name(myLegalName: String, nearestCity: String, email: String, country: String? = null): X500Name {
return X500NameBuilder(BCStyle.INSTANCE).let { builder ->
builder.addRDN(BCStyle.CN, myLegalName)
builder.addRDN(BCStyle.L, nearestCity)
country?.let { builder.addRDN(BCStyle.C, it) }
builder.addRDN(BCStyle.E, email)
builder.build()
}
}
data class CertificateAndKeyPair(val certificate: X509CertificateHolder, val keyPair: KeyPair)

View File

@ -4,13 +4,13 @@ import net.corda.core.crypto.Crypto
import net.corda.core.crypto.composite.CompositeKey.NodeAndWeight import net.corda.core.crypto.composite.CompositeKey.NodeAndWeight
import net.corda.core.crypto.keys import net.corda.core.crypto.keys
import net.corda.core.crypto.provider.CordaObjectIdentifier import net.corda.core.crypto.provider.CordaObjectIdentifier
import net.corda.core.crypto.toSHA256Bytes
import net.corda.core.crypto.toStringShort import net.corda.core.crypto.toStringShort
import net.corda.core.utilities.exactAdd
import net.corda.core.serialization.CordaSerializable import net.corda.core.serialization.CordaSerializable
import net.corda.core.utilities.sequence
import org.bouncycastle.asn1.* import org.bouncycastle.asn1.*
import org.bouncycastle.asn1.x509.AlgorithmIdentifier import org.bouncycastle.asn1.x509.AlgorithmIdentifier
import org.bouncycastle.asn1.x509.SubjectPublicKeyInfo import org.bouncycastle.asn1.x509.SubjectPublicKeyInfo
import java.nio.ByteBuffer
import java.security.PublicKey import java.security.PublicKey
import java.util.* import java.util.*
@ -59,7 +59,7 @@ class CompositeKey private constructor(val threshold: Int, children: List<NodeAn
} }
} }
val children = children.sorted() val children: List<NodeAndWeight> = children.sorted()
init { init {
// TODO: replace with the more extensive, but slower, checkValidity() test. // TODO: replace with the more extensive, but slower, checkValidity() test.
@ -127,7 +127,7 @@ class CompositeKey private constructor(val threshold: Int, children: List<NodeAn
var sum = 0 var sum = 0
for ((_, weight) in children) { for ((_, weight) in children) {
require(weight > 0) { "Non-positive weight: $weight detected." } require(weight > 0) { "Non-positive weight: $weight detected." }
sum = Math.addExact(sum, weight) // Add and check for integer overflow. sum = sum exactAdd weight // Add and check for integer overflow.
} }
return sum return sum
} }
@ -145,7 +145,7 @@ class CompositeKey private constructor(val threshold: Int, children: List<NodeAn
override fun compareTo(other: NodeAndWeight): Int { override fun compareTo(other: NodeAndWeight): Int {
return if (weight == other.weight) return if (weight == other.weight)
ByteBuffer.wrap(node.toSHA256Bytes()).compareTo(ByteBuffer.wrap(other.node.toSHA256Bytes())) node.encoded.sequence().compareTo(other.node.encoded.sequence())
else else
weight.compareTo(other.weight) weight.compareTo(other.weight)
} }

View File

@ -1,8 +1,7 @@
package net.corda.core.crypto.composite package net.corda.core.crypto.composite
import net.corda.core.crypto.SecureHash
import net.corda.core.serialization.deserialize import net.corda.core.serialization.deserialize
import org.bouncycastle.asn1.ASN1ObjectIdentifier
import org.bouncycastle.asn1.x509.AlgorithmIdentifier
import java.io.ByteArrayOutputStream import java.io.ByteArrayOutputStream
import java.security.* import java.security.*
import java.security.spec.AlgorithmParameterSpec import java.security.spec.AlgorithmParameterSpec
@ -77,7 +76,7 @@ class CompositeSignature : Signature(SIGNATURE_ALGORITHM) {
fun engineVerify(sigBytes: ByteArray): Boolean { fun engineVerify(sigBytes: ByteArray): Boolean {
val sig = sigBytes.deserialize<CompositeSignaturesWithKeys>() val sig = sigBytes.deserialize<CompositeSignaturesWithKeys>()
return if (verifyKey.isFulfilledBy(sig.sigs.map { it.by })) { return if (verifyKey.isFulfilledBy(sig.sigs.map { it.by })) {
val clearData = buffer.toByteArray() val clearData = SecureHash.SHA256(buffer.toByteArray())
sig.sigs.all { it.isValid(clearData) } sig.sigs.all { it.isValid(clearData) }
} else { } else {
false false

View File

@ -1,14 +1,14 @@
package net.corda.core.crypto.composite package net.corda.core.crypto.composite
import net.corda.core.crypto.DigitalSignature import net.corda.core.crypto.TransactionSignature
import net.corda.core.serialization.CordaSerializable import net.corda.core.serialization.CordaSerializable
/** /**
* Custom class for holding signature data. This exists for later extension work to provide a standardised cross-platform * Custom class for holding signature data. This exists for later extension work to provide a standardised cross-platform
* serialization format (i.e. not Kryo). * serialization format.
*/ */
@CordaSerializable @CordaSerializable
data class CompositeSignaturesWithKeys(val sigs: List<DigitalSignature.WithKey>) { data class CompositeSignaturesWithKeys(val sigs: List<TransactionSignature>) {
companion object { companion object {
val EMPTY = CompositeSignaturesWithKeys(emptyList()) val EMPTY = CompositeSignaturesWithKeys(emptyList())
} }

View File

@ -1,35 +0,0 @@
package net.corda.core.crypto.testing
import net.corda.core.crypto.DigitalSignature
import net.corda.core.identity.AnonymousParty
import net.corda.core.serialization.CordaSerializable
import java.math.BigInteger
import java.security.PublicKey
@CordaSerializable
object NullPublicKey : PublicKey, Comparable<PublicKey> {
override fun getAlgorithm() = "NULL"
override fun getEncoded() = byteArrayOf(0)
override fun getFormat() = "NULL"
override fun compareTo(other: PublicKey): Int = if (other == NullPublicKey) 0 else -1
override fun toString() = "NULL_KEY"
}
val NULL_PARTY = AnonymousParty(NullPublicKey)
// TODO: Clean up this duplication between Null and Dummy public key
@CordaSerializable
@Deprecated("Has encoding format problems, consider entropyToKeyPair() instead")
class DummyPublicKey(val s: String) : PublicKey, Comparable<PublicKey> {
override fun getAlgorithm() = "DUMMY"
override fun getEncoded() = s.toByteArray()
override fun getFormat() = "ASN.1"
override fun compareTo(other: PublicKey): Int = BigInteger(encoded).compareTo(BigInteger(other.encoded))
override fun equals(other: Any?) = other is DummyPublicKey && other.s == s
override fun hashCode(): Int = s.hashCode()
override fun toString() = "PUBKEY[$s]"
}
/** A signature with a key and value of zero. Useful when you want a signature object that you know won't ever be used. */
@CordaSerializable
object NullSignature : DigitalSignature.WithKey(NullPublicKey, ByteArray(32))

View File

@ -0,0 +1,21 @@
package net.corda.core.crypto.testing
import net.corda.core.crypto.SignatureMetadata
import net.corda.core.crypto.TransactionSignature
import net.corda.core.identity.AnonymousParty
import net.corda.core.serialization.CordaSerializable
import java.security.PublicKey
@CordaSerializable
object NullPublicKey : PublicKey, Comparable<PublicKey> {
override fun getAlgorithm() = "NULL"
override fun getEncoded() = byteArrayOf(0)
override fun getFormat() = "NULL"
override fun compareTo(other: PublicKey): Int = if (other == NullPublicKey) 0 else -1
override fun toString() = "NULL_KEY"
}
val NULL_PARTY = AnonymousParty(NullPublicKey)
/** A signature with a key and value of zero. Useful when you want a signature object that you know won't ever be used. */
val NULL_SIGNATURE = TransactionSignature(ByteArray(32), NullPublicKey, SignatureMetadata(1, -1))

View File

@ -1,19 +1,14 @@
package net.corda.flows package net.corda.core.flows
import co.paralleluniverse.fibers.Suspendable import co.paralleluniverse.fibers.Suspendable
import net.corda.core.contracts.ContractState import net.corda.core.contracts.ContractState
import net.corda.core.contracts.StateAndRef import net.corda.core.contracts.StateAndRef
import net.corda.core.contracts.StateRef import net.corda.core.contracts.StateRef
import net.corda.core.crypto.DigitalSignature import net.corda.core.crypto.TransactionSignature
import net.corda.core.crypto.isFulfilledBy import net.corda.core.crypto.isFulfilledBy
import net.corda.core.flows.FlowException
import net.corda.core.flows.FlowLogic
import net.corda.core.identity.AbstractParty
import net.corda.core.identity.AnonymousParty
import net.corda.core.identity.Party import net.corda.core.identity.Party
import net.corda.core.serialization.CordaSerializable import net.corda.core.serialization.CordaSerializable
import net.corda.core.transactions.SignedTransaction import net.corda.core.transactions.SignedTransaction
import net.corda.core.transactions.WireTransaction
import net.corda.core.utilities.ProgressTracker import net.corda.core.utilities.ProgressTracker
import net.corda.core.utilities.UntrustworthyData import net.corda.core.utilities.UntrustworthyData
import net.corda.core.utilities.unwrap import net.corda.core.utilities.unwrap
@ -32,7 +27,7 @@ abstract class AbstractStateReplacementFlow {
* @param M the type of a class representing proposed modification by the instigator. * @param M the type of a class representing proposed modification by the instigator.
*/ */
@CordaSerializable @CordaSerializable
data class Proposal<out M>(val stateRef: StateRef, val modification: M, val stx: SignedTransaction) data class Proposal<out M>(val stateRef: StateRef, val modification: M)
/** /**
* The assembled transaction for upgrading a contract. * The assembled transaction for upgrading a contract.
@ -56,7 +51,7 @@ abstract class AbstractStateReplacementFlow {
abstract class Instigator<out S : ContractState, out T : ContractState, out M>( abstract class Instigator<out S : ContractState, out T : ContractState, out M>(
val originalState: StateAndRef<S>, val originalState: StateAndRef<S>,
val modification: M, val modification: M,
override val progressTracker: ProgressTracker = tracker()) : FlowLogic<StateAndRef<T>>() { override val progressTracker: ProgressTracker = Instigator.tracker()) : FlowLogic<StateAndRef<T>>() {
companion object { companion object {
object SIGNING : ProgressTracker.Step("Requesting signatures from other parties") object SIGNING : ProgressTracker.Step("Requesting signatures from other parties")
object NOTARY : ProgressTracker.Step("Requesting notary signature") object NOTARY : ProgressTracker.Step("Requesting notary signature")
@ -79,7 +74,16 @@ abstract class AbstractStateReplacementFlow {
val finalTx = stx + signatures val finalTx = stx + signatures
serviceHub.recordTransactions(finalTx) serviceHub.recordTransactions(finalTx)
return finalTx.tx.outRef(0)
val newOutput = run {
if (stx.isNotaryChangeTransaction()) {
stx.resolveNotaryChangeTransaction(serviceHub).outRef<T>(0)
} else {
stx.tx.outRef<T>(0)
}
}
return newOutput
} }
/** /**
@ -91,7 +95,7 @@ abstract class AbstractStateReplacementFlow {
abstract protected fun assembleTx(): UpgradeTx abstract protected fun assembleTx(): UpgradeTx
@Suspendable @Suspendable
private fun collectSignatures(participants: Iterable<PublicKey>, stx: SignedTransaction): List<DigitalSignature.WithKey> { private fun collectSignatures(participants: Iterable<PublicKey>, stx: SignedTransaction): List<TransactionSignature> {
val parties = participants.map { val parties = participants.map {
val participantNode = serviceHub.networkMapCache.getNodeByLegalIdentityKey(it) ?: val participantNode = serviceHub.networkMapCache.getNodeByLegalIdentityKey(it) ?:
throw IllegalStateException("Participant $it to state $originalState not found on the network") throw IllegalStateException("Participant $it to state $originalState not found on the network")
@ -109,10 +113,10 @@ abstract class AbstractStateReplacementFlow {
} }
@Suspendable @Suspendable
private fun getParticipantSignature(party: Party, stx: SignedTransaction): DigitalSignature.WithKey { private fun getParticipantSignature(party: Party, stx: SignedTransaction): TransactionSignature {
val proposal = Proposal(originalState.ref, modification, stx) val proposal = Proposal(originalState.ref, modification)
val response = sendAndReceive<DigitalSignature.WithKey>(party, proposal) subFlow(SendTransactionFlow(party, stx))
return response.unwrap { return sendAndReceive<TransactionSignature>(party, proposal).unwrap {
check(party.owningKey.isFulfilledBy(it.by)) { "Not signed by the required participant" } check(party.owningKey.isFulfilledBy(it.by)) { "Not signed by the required participant" }
it.verify(stx.id) it.verify(stx.id)
it it
@ -120,7 +124,7 @@ abstract class AbstractStateReplacementFlow {
} }
@Suspendable @Suspendable
private fun getNotarySignatures(stx: SignedTransaction): List<DigitalSignature.WithKey> { private fun getNotarySignatures(stx: SignedTransaction): List<TransactionSignature> {
progressTracker.currentStep = NOTARY progressTracker.currentStep = NOTARY
try { try {
return subFlow(NotaryFlow.Client(stx)) return subFlow(NotaryFlow.Client(stx))
@ -133,7 +137,7 @@ abstract class AbstractStateReplacementFlow {
// Type parameter should ideally be Unit but that prevents Java code from subclassing it (https://youtrack.jetbrains.com/issue/KT-15964). // Type parameter should ideally be Unit but that prevents Java code from subclassing it (https://youtrack.jetbrains.com/issue/KT-15964).
// We use Void? instead of Unit? as that's what you'd use in Java. // We use Void? instead of Unit? as that's what you'd use in Java.
abstract class Acceptor<in T>(val otherSide: Party, abstract class Acceptor<in T>(val otherSide: Party,
override val progressTracker: ProgressTracker = tracker()) : FlowLogic<Void?>() { override val progressTracker: ProgressTracker = Acceptor.tracker()) : FlowLogic<Void?>() {
companion object { companion object {
object VERIFYING : ProgressTracker.Step("Verifying state replacement proposal") object VERIFYING : ProgressTracker.Step("Verifying state replacement proposal")
object APPROVING : ProgressTracker.Step("State replacement approved") object APPROVING : ProgressTracker.Step("State replacement approved")
@ -145,63 +149,61 @@ abstract class AbstractStateReplacementFlow {
@Throws(StateReplacementException::class) @Throws(StateReplacementException::class)
override fun call(): Void? { override fun call(): Void? {
progressTracker.currentStep = VERIFYING progressTracker.currentStep = VERIFYING
// We expect stx to have insufficient signatures here
val stx = subFlow(ReceiveTransactionFlow(otherSide, checkSufficientSignatures = false))
checkMySignatureRequired(stx)
val maybeProposal: UntrustworthyData<Proposal<T>> = receive(otherSide) val maybeProposal: UntrustworthyData<Proposal<T>> = receive(otherSide)
val stx: SignedTransaction = maybeProposal.unwrap { maybeProposal.unwrap {
verifyProposal(it) verifyProposal(stx, it)
verifyTx(it.stx)
it.stx
} }
approve(stx) approve(stx)
return null return null
} }
@Suspendable
private fun verifyTx(stx: SignedTransaction) {
checkMySignatureRequired(stx.tx)
checkDependenciesValid(stx)
// We expect stx to have insufficient signatures, so we convert the WireTransaction to the LedgerTransaction
// here, thus bypassing the sufficient-signatures check.
stx.tx.toLedgerTransaction(serviceHub).verify()
}
@Suspendable @Suspendable
private fun approve(stx: SignedTransaction) { private fun approve(stx: SignedTransaction) {
progressTracker.currentStep = APPROVING progressTracker.currentStep = APPROVING
val mySignature = sign(stx) val mySignature = sign(stx)
val swapSignatures = sendAndReceive<List<DigitalSignature.WithKey>>(otherSide, mySignature) val swapSignatures = sendAndReceive<List<TransactionSignature>>(otherSide, mySignature)
// TODO: This step should not be necessary, as signatures are re-checked in verifySignatures. // TODO: This step should not be necessary, as signatures are re-checked in verifyRequiredSignatures.
val allSignatures = swapSignatures.unwrap { signatures -> val allSignatures = swapSignatures.unwrap { signatures ->
signatures.forEach { it.verify(stx.id) } signatures.forEach { it.verify(stx.id) }
signatures signatures
} }
val finalTx = stx + allSignatures val finalTx = stx + allSignatures
finalTx.verifySignatures() if (finalTx.isNotaryChangeTransaction()) {
finalTx.resolveNotaryChangeTransaction(serviceHub).verifyRequiredSignatures()
} else {
finalTx.verifyRequiredSignatures()
}
serviceHub.recordTransactions(finalTx) serviceHub.recordTransactions(finalTx)
} }
/** /**
* Check the state change proposal to confirm that it's acceptable to this node. Rules for verification depend * Check the state change proposal and the signed transaction to confirm that it's acceptable to this node.
* on the change proposed, and may further depend on the node itself (for example configuration). The * Rules for verification depend on the change proposed, and may further depend on the node itself (for example configuration).
* proposal is returned if acceptable, otherwise a [StateReplacementException] is thrown. * The proposal is returned if acceptable, otherwise a [StateReplacementException] is thrown.
*/ */
@Throws(StateReplacementException::class) @Throws(StateReplacementException::class)
abstract protected fun verifyProposal(proposal: Proposal<T>) abstract protected fun verifyProposal(stx: SignedTransaction, proposal: Proposal<T>)
private fun checkMySignatureRequired(tx: WireTransaction) { private fun checkMySignatureRequired(stx: SignedTransaction) {
// TODO: use keys from the keyManagementService instead // TODO: use keys from the keyManagementService instead
val myKey = serviceHub.myInfo.legalIdentity.owningKey val myKey = serviceHub.myInfo.legalIdentity.owningKey
require(myKey in tx.mustSign) { "Party is not a participant for any of the input states of transaction ${tx.id}" }
val requiredKeys = if (stx.isNotaryChangeTransaction()) {
stx.resolveNotaryChangeTransaction(serviceHub).requiredSigningKeys
} else {
stx.tx.requiredSigningKeys
} }
@Suspendable require(myKey in requiredKeys) { "Party is not a participant for any of the input states of transaction ${stx.id}" }
private fun checkDependenciesValid(stx: SignedTransaction) {
subFlow(ResolveTransactionsFlow(stx.tx, otherSide))
} }
private fun sign(stx: SignedTransaction): DigitalSignature.WithKey { private fun sign(stx: SignedTransaction): TransactionSignature {
return serviceHub.createSignature(stx) return serviceHub.createSignature(stx)
} }
} }

View File

@ -1,11 +1,9 @@
package net.corda.flows package net.corda.core.flows
import co.paralleluniverse.fibers.Suspendable import co.paralleluniverse.fibers.Suspendable
import net.corda.core.flows.FlowLogic
import net.corda.core.flows.InitiatingFlow
import net.corda.core.identity.Party import net.corda.core.identity.Party
import net.corda.core.serialization.CordaSerializable
import net.corda.core.transactions.SignedTransaction import net.corda.core.transactions.SignedTransaction
import net.corda.core.utilities.NonEmptySet
/** /**
* Notify the specified parties about a transaction. The remote peers will download this transaction and its * Notify the specified parties about a transaction. The remote peers will download this transaction and its
@ -18,17 +16,13 @@ import net.corda.core.transactions.SignedTransaction
*/ */
@InitiatingFlow @InitiatingFlow
class BroadcastTransactionFlow(val notarisedTransaction: SignedTransaction, class BroadcastTransactionFlow(val notarisedTransaction: SignedTransaction,
val participants: Set<Party>) : FlowLogic<Unit>() { val participants: NonEmptySet<Party>) : FlowLogic<Unit>() {
@CordaSerializable
data class NotifyTxRequest(val tx: SignedTransaction)
@Suspendable @Suspendable
override fun call() { override fun call() {
// TODO: Messaging layer should handle this broadcast for us // TODO: Messaging layer should handle this broadcast for us
val msg = NotifyTxRequest(notarisedTransaction)
participants.filter { it != serviceHub.myInfo.legalIdentity }.forEach { participant -> participants.filter { it != serviceHub.myInfo.legalIdentity }.forEach { participant ->
// This pops out the other side in NotifyTransactionHandler // SendTransactionFlow allows otherParty to access our data to resolve the transaction.
send(participant, msg) subFlow(SendTransactionFlow(participant, notarisedTransaction))
} }
} }
} }

View File

@ -1,11 +1,9 @@
package net.corda.flows package net.corda.core.flows
import co.paralleluniverse.fibers.Suspendable import co.paralleluniverse.fibers.Suspendable
import net.corda.core.crypto.DigitalSignature import net.corda.core.crypto.TransactionSignature
import net.corda.core.crypto.isFulfilledBy import net.corda.core.crypto.isFulfilledBy
import net.corda.core.crypto.toBase58String import net.corda.core.crypto.toBase58String
import net.corda.core.flows.FlowException
import net.corda.core.flows.FlowLogic
import net.corda.core.identity.Party import net.corda.core.identity.Party
import net.corda.core.node.ServiceHub import net.corda.core.node.ServiceHub
import net.corda.core.transactions.SignedTransaction import net.corda.core.transactions.SignedTransaction
@ -19,7 +17,7 @@ import java.security.PublicKey
* *
* You would typically use this flow after you have built a transaction with the TransactionBuilder and signed it with * You would typically use this flow after you have built a transaction with the TransactionBuilder and signed it with
* your key pair. If there are additional signatures to collect then they can be collected using this flow. Signatures * your key pair. If there are additional signatures to collect then they can be collected using this flow. Signatures
* are collected based upon the [WireTransaction.mustSign] property which contains the union of all the PublicKeys * are collected based upon the [WireTransaction.requiredSigningKeys] property which contains the union of all the PublicKeys
* listed in the transaction's commands as well as a notary's public key, if required. This flow returns a * listed in the transaction's commands as well as a notary's public key, if required. This flow returns a
* [SignedTransaction] which can then be passed to the [FinalityFlow] for notarisation. The other side of this flow is * [SignedTransaction] which can then be passed to the [FinalityFlow] for notarisation. The other side of this flow is
* the [SignTransactionFlow]. * the [SignTransactionFlow].
@ -44,7 +42,7 @@ import java.security.PublicKey
* *
* Example - issuing a multi-lateral agreement which requires N signatures: * Example - issuing a multi-lateral agreement which requires N signatures:
* *
* val builder = TransactionType.General.Builder(notaryRef) * val builder = TransactionBuilder(notaryRef)
* val issueCommand = Command(Agreement.Commands.Issue(), state.participants) * val issueCommand = Command(Agreement.Commands.Issue(), state.participants)
* *
* builder.withItems(state, issueCommand) * builder.withItems(state, issueCommand)
@ -62,7 +60,7 @@ import java.security.PublicKey
// TODO: AbstractStateReplacementFlow needs updating to use this flow. // TODO: AbstractStateReplacementFlow needs updating to use this flow.
// TODO: Update this flow to handle randomly generated keys when that works is complete. // TODO: Update this flow to handle randomly generated keys when that works is complete.
class CollectSignaturesFlow(val partiallySignedTx: SignedTransaction, class CollectSignaturesFlow(val partiallySignedTx: SignedTransaction,
override val progressTracker: ProgressTracker = tracker()): FlowLogic<SignedTransaction>() { override val progressTracker: ProgressTracker = CollectSignaturesFlow.tracker()) : FlowLogic<SignedTransaction>() {
companion object { companion object {
object COLLECTING : ProgressTracker.Step("Collecting signatures from counter-parties.") object COLLECTING : ProgressTracker.Step("Collecting signatures from counter-parties.")
@ -80,7 +78,7 @@ class CollectSignaturesFlow(val partiallySignedTx: SignedTransaction,
// Usually just the Initiator and possibly an oracle would have signed at this point. // Usually just the Initiator and possibly an oracle would have signed at this point.
val myKey = serviceHub.myInfo.legalIdentity.owningKey val myKey = serviceHub.myInfo.legalIdentity.owningKey
val signed = partiallySignedTx.sigs.map { it.by } val signed = partiallySignedTx.sigs.map { it.by }
val notSigned = partiallySignedTx.tx.mustSign - signed val notSigned = partiallySignedTx.tx.requiredSigningKeys - signed
// One of the signatures collected so far MUST be from the initiator of this flow. // One of the signatures collected so far MUST be from the initiator of this flow.
require(partiallySignedTx.sigs.any { it.by == myKey }) { require(partiallySignedTx.sigs.any { it.by == myKey }) {
@ -88,7 +86,7 @@ class CollectSignaturesFlow(val partiallySignedTx: SignedTransaction,
} }
// The signatures must be valid and the transaction must be valid. // The signatures must be valid and the transaction must be valid.
partiallySignedTx.verifySignatures(*notSigned.toTypedArray()) partiallySignedTx.verifySignaturesExcept(*notSigned.toTypedArray())
partiallySignedTx.tx.toLedgerTransaction(serviceHub).verify() partiallySignedTx.tx.toLedgerTransaction(serviceHub).verify()
// Determine who still needs to sign. // Determine who still needs to sign.
@ -107,7 +105,7 @@ class CollectSignaturesFlow(val partiallySignedTx: SignedTransaction,
// Verify all but the notary's signature if the transaction requires a notary, otherwise verify all signatures. // Verify all but the notary's signature if the transaction requires a notary, otherwise verify all signatures.
progressTracker.currentStep = VERIFYING progressTracker.currentStep = VERIFYING
if (notaryKey != null) stx.verifySignatures(notaryKey) else stx.verifySignatures() if (notaryKey != null) stx.verifySignaturesExcept(notaryKey) else stx.verifyRequiredSignatures()
return stx return stx
} }
@ -115,7 +113,7 @@ class CollectSignaturesFlow(val partiallySignedTx: SignedTransaction,
/** /**
* Lookup the [Party] object for each [PublicKey] using the [ServiceHub.networkMapCache]. * Lookup the [Party] object for each [PublicKey] using the [ServiceHub.networkMapCache].
*/ */
@Suspendable private fun keysToParties(keys: List<PublicKey>): List<Party> = keys.map { @Suspendable private fun keysToParties(keys: Collection<PublicKey>): List<Party> = keys.map {
// TODO: Revisit when IdentityService supports resolution of a (possibly random) public key to a legal identity key. // TODO: Revisit when IdentityService supports resolution of a (possibly random) public key to a legal identity key.
val partyNode = serviceHub.networkMapCache.getNodeByLegalIdentityKey(it) val partyNode = serviceHub.networkMapCache.getNodeByLegalIdentityKey(it)
?: throw IllegalStateException("Party ${it.toBase58String()} not found on the network.") ?: throw IllegalStateException("Party ${it.toBase58String()} not found on the network.")
@ -126,8 +124,10 @@ class CollectSignaturesFlow(val partiallySignedTx: SignedTransaction,
/** /**
* Get and check the required signature. * Get and check the required signature.
*/ */
@Suspendable private fun collectSignature(counterparty: Party): DigitalSignature.WithKey { @Suspendable private fun collectSignature(counterparty: Party): TransactionSignature {
return sendAndReceive<DigitalSignature.WithKey>(counterparty, partiallySignedTx).unwrap { // SendTransactionFlow allows otherParty to access our data to resolve the transaction.
subFlow(SendTransactionFlow(counterparty, partiallySignedTx))
return receive<TransactionSignature>(counterparty).unwrap {
require(counterparty.owningKey.isFulfilledBy(it.by)) { "Not signed by the required Party." } require(counterparty.owningKey.isFulfilledBy(it.by)) { "Not signed by the required Party." }
it it
} }
@ -175,7 +175,7 @@ class CollectSignaturesFlow(val partiallySignedTx: SignedTransaction,
* @param otherParty The counter-party which is providing you a transaction to sign. * @param otherParty The counter-party which is providing you a transaction to sign.
*/ */
abstract class SignTransactionFlow(val otherParty: Party, abstract class SignTransactionFlow(val otherParty: Party,
override val progressTracker: ProgressTracker = tracker()) : FlowLogic<SignedTransaction>() { override val progressTracker: ProgressTracker = SignTransactionFlow.tracker()) : FlowLogic<SignedTransaction>() {
companion object { companion object {
object RECEIVING : ProgressTracker.Step("Receiving transaction proposal for signing.") object RECEIVING : ProgressTracker.Step("Receiving transaction proposal for signing.")
@ -187,35 +187,30 @@ abstract class SignTransactionFlow(val otherParty: Party,
@Suspendable override fun call(): SignedTransaction { @Suspendable override fun call(): SignedTransaction {
progressTracker.currentStep = RECEIVING progressTracker.currentStep = RECEIVING
val checkedProposal = receive<SignedTransaction>(otherParty).unwrap { proposal -> // Receive transaction and resolve dependencies, check sufficient signatures is disabled as we don't have all signatures.
val stx = subFlow(ReceiveTransactionFlow(otherParty, checkSufficientSignatures = false))
progressTracker.currentStep = VERIFYING progressTracker.currentStep = VERIFYING
// Check that the Responder actually needs to sign. // Check that the Responder actually needs to sign.
checkMySignatureRequired(proposal) checkMySignatureRequired(stx)
// Check the signatures which have already been provided. Usually the Initiators and possibly an Oracle's. // Check the signatures which have already been provided. Usually the Initiators and possibly an Oracle's.
checkSignatures(proposal) checkSignatures(stx)
// Resolve dependencies and verify, pass in the WireTransaction as we don't have all signatures. stx.tx.toLedgerTransaction(serviceHub).verify()
subFlow(ResolveTransactionsFlow(proposal.tx, otherParty))
proposal.tx.toLedgerTransaction(serviceHub).verify()
// Perform some custom verification over the transaction. // Perform some custom verification over the transaction.
try { try {
checkTransaction(proposal) checkTransaction(stx)
} catch(e: Exception) { } catch(e: Exception) {
if (e is IllegalStateException || e is IllegalArgumentException || e is AssertionError) if (e is IllegalStateException || e is IllegalArgumentException || e is AssertionError)
throw FlowException(e) throw FlowException(e)
else else
throw e throw e
} }
// All good. Unwrap the proposal.
proposal
}
// Sign and send back our signature to the Initiator. // Sign and send back our signature to the Initiator.
progressTracker.currentStep = SIGNING progressTracker.currentStep = SIGNING
val mySignature = serviceHub.createSignature(checkedProposal) val mySignature = serviceHub.createSignature(stx)
send(otherParty, mySignature) send(otherParty, mySignature)
// Return the fully signed transaction once it has been committed. // Return the fully signed transaction once it has been committed.
return waitForLedgerCommit(checkedProposal.id) return waitForLedgerCommit(stx.id)
} }
@Suspendable private fun checkSignatures(stx: SignedTransaction) { @Suspendable private fun checkSignatures(stx: SignedTransaction) {
@ -223,9 +218,9 @@ abstract class SignTransactionFlow(val otherParty: Party,
"The Initiator of CollectSignaturesFlow must have signed the transaction." "The Initiator of CollectSignaturesFlow must have signed the transaction."
} }
val signed = stx.sigs.map { it.by } val signed = stx.sigs.map { it.by }
val allSigners = stx.tx.mustSign val allSigners = stx.tx.requiredSigningKeys
val notSigned = allSigners - signed val notSigned = allSigners - signed
stx.verifySignatures(*notSigned.toTypedArray()) stx.verifySignaturesExcept(*notSigned.toTypedArray())
} }
/** /**
@ -253,7 +248,7 @@ abstract class SignTransactionFlow(val otherParty: Party,
@Suspendable private fun checkMySignatureRequired(stx: SignedTransaction) { @Suspendable private fun checkMySignatureRequired(stx: SignedTransaction) {
// TODO: Revisit when key management is properly fleshed out. // TODO: Revisit when key management is properly fleshed out.
val myKey = serviceHub.myInfo.legalIdentity.owningKey val myKey = serviceHub.myInfo.legalIdentity.owningKey
require(myKey in stx.tx.mustSign) { require(myKey in stx.tx.requiredSigningKeys) {
"Party is not a participant for any of the input states of transaction ${stx.id}" "Party is not a participant for any of the input states of transaction ${stx.id}"
} }
} }

View File

@ -1,10 +1,7 @@
package net.corda.flows package net.corda.core.flows
import net.corda.core.contracts.* import net.corda.core.contracts.*
import net.corda.core.flows.InitiatingFlow import net.corda.core.transactions.LedgerTransaction
import net.corda.core.flows.StartableByRPC
import net.corda.core.identity.AbstractParty
import net.corda.core.transactions.SignedTransaction
import net.corda.core.transactions.TransactionBuilder import net.corda.core.transactions.TransactionBuilder
import java.security.PublicKey import java.security.PublicKey
@ -25,14 +22,17 @@ class ContractUpgradeFlow<OldState : ContractState, out NewState : ContractState
companion object { companion object {
@JvmStatic @JvmStatic
fun verify(tx: TransactionForContract) { fun verify(tx: LedgerTransaction) {
// Contract Upgrade transaction should have 1 input, 1 output and 1 command. // Contract Upgrade transaction should have 1 input, 1 output and 1 command.
verify(tx.inputs.single(), tx.outputs.single(), tx.commands.map { Command(it.value, it.signers) }.single()) verify(
tx.inputStates.single(),
tx.outputStates.single(),
tx.commandsOfType<UpgradeCommand>().single())
} }
@JvmStatic @JvmStatic
fun verify(input: ContractState, output: ContractState, commandData: Command) { fun verify(input: ContractState, output: ContractState, commandData: Command<UpgradeCommand>) {
val command = commandData.value as UpgradeCommand val command = commandData.value
val participantKeys: Set<PublicKey> = input.participants.map { it.owningKey }.toSet() val participantKeys: Set<PublicKey> = input.participants.map { it.owningKey }.toSet()
val keysThatSigned: Set<PublicKey> = commandData.signers.toSet() val keysThatSigned: Set<PublicKey> = commandData.signers.toSet()
@Suppress("UNCHECKED_CAST") @Suppress("UNCHECKED_CAST")
@ -47,19 +47,22 @@ class ContractUpgradeFlow<OldState : ContractState, out NewState : ContractState
fun <OldState : ContractState, NewState : ContractState> assembleBareTx( fun <OldState : ContractState, NewState : ContractState> assembleBareTx(
stateRef: StateAndRef<OldState>, stateRef: StateAndRef<OldState>,
upgradedContractClass: Class<out UpgradedContract<OldState, NewState>> upgradedContractClass: Class<out UpgradedContract<OldState, NewState>>,
privacySalt: PrivacySalt
): TransactionBuilder { ): TransactionBuilder {
val contractUpgrade = upgradedContractClass.newInstance() val contractUpgrade = upgradedContractClass.newInstance()
return TransactionType.General.Builder(stateRef.state.notary) return TransactionBuilder(stateRef.state.notary)
.withItems( .withItems(
stateRef, stateRef,
contractUpgrade.upgrade(stateRef.state.data), contractUpgrade.upgrade(stateRef.state.data),
Command(UpgradeCommand(upgradedContractClass), stateRef.state.data.participants.map { it.owningKey })) Command(UpgradeCommand(upgradedContractClass), stateRef.state.data.participants.map { it.owningKey }),
privacySalt
)
} }
} }
override fun assembleTx(): AbstractStateReplacementFlow.UpgradeTx { override fun assembleTx(): AbstractStateReplacementFlow.UpgradeTx {
val baseTx = assembleBareTx(originalState, modification) val baseTx = assembleBareTx(originalState, modification, PrivacySalt())
val participantKeys = originalState.state.data.participants.map { it.owningKey }.toSet() val participantKeys = originalState.state.data.participants.map { it.owningKey }.toSet()
// TODO: We need a much faster way of finding our key in the transaction // TODO: We need a much faster way of finding our key in the transaction
val myKey = serviceHub.keyManagementService.filterMyKeys(participantKeys).single() val myKey = serviceHub.keyManagementService.filterMyKeys(participantKeys).single()

View File

@ -1,16 +1,18 @@
package net.corda.flows package net.corda.core.flows
import co.paralleluniverse.fibers.Suspendable import co.paralleluniverse.fibers.Suspendable
import net.corda.core.contracts.ContractState import net.corda.core.contracts.ContractState
import net.corda.core.contracts.StateRef import net.corda.core.contracts.StateRef
import net.corda.core.contracts.TransactionState import net.corda.core.contracts.TransactionState
import net.corda.core.crypto.isFulfilledBy import net.corda.core.crypto.isFulfilledBy
import net.corda.core.flows.FlowLogic import net.corda.core.identity.AbstractParty
import net.corda.core.identity.Party import net.corda.core.identity.Party
import net.corda.core.internal.ResolveTransactionsFlow
import net.corda.core.node.ServiceHub import net.corda.core.node.ServiceHub
import net.corda.core.transactions.LedgerTransaction import net.corda.core.transactions.LedgerTransaction
import net.corda.core.transactions.SignedTransaction import net.corda.core.transactions.SignedTransaction
import net.corda.core.utilities.ProgressTracker import net.corda.core.utilities.ProgressTracker
import net.corda.core.utilities.toNonEmptySet
/** /**
* Verifies the given transactions, then sends them to the named notary. If the notary agrees that the transactions * Verifies the given transactions, then sends them to the named notary. If the notary agrees that the transactions
@ -32,9 +34,10 @@ import net.corda.core.utilities.ProgressTracker
* @param transactions What to commit. * @param transactions What to commit.
* @param extraRecipients A list of additional participants to inform of the transaction. * @param extraRecipients A list of additional participants to inform of the transaction.
*/ */
class FinalityFlow(val transactions: Iterable<SignedTransaction>, open class FinalityFlow(val transactions: Iterable<SignedTransaction>,
val extraRecipients: Set<Party>, val extraRecipients: Set<Party>,
override val progressTracker: ProgressTracker) : FlowLogic<List<SignedTransaction>>() { override val progressTracker: ProgressTracker) : FlowLogic<List<SignedTransaction>>() {
val extraParticipants: Set<Participant> = extraRecipients.map { it -> Participant(it, it) }.toSet()
constructor(transaction: SignedTransaction, extraParticipants: Set<Party>) : this(listOf(transaction), extraParticipants, tracker()) constructor(transaction: SignedTransaction, extraParticipants: Set<Party>) : this(listOf(transaction), extraParticipants, tracker())
constructor(transaction: SignedTransaction) : this(listOf(transaction), emptySet(), tracker()) constructor(transaction: SignedTransaction) : this(listOf(transaction), emptySet(), tracker())
constructor(transaction: SignedTransaction, progressTracker: ProgressTracker) : this(listOf(transaction), emptySet(), progressTracker) constructor(transaction: SignedTransaction, progressTracker: ProgressTracker) : this(listOf(transaction), emptySet(), progressTracker)
@ -50,6 +53,9 @@ class FinalityFlow(val transactions: Iterable<SignedTransaction>,
fun tracker() = ProgressTracker(NOTARISING, BROADCASTING) fun tracker() = ProgressTracker(NOTARISING, BROADCASTING)
} }
open protected val me
get() = serviceHub.myInfo.legalIdentity
@Suspendable @Suspendable
@Throws(NotaryException::class) @Throws(NotaryException::class)
override fun call(): List<SignedTransaction> { override fun call(): List<SignedTransaction> {
@ -59,22 +65,35 @@ class FinalityFlow(val transactions: Iterable<SignedTransaction>,
// Lookup the resolved transactions and use them to map each signed transaction to the list of participants. // Lookup the resolved transactions and use them to map each signed transaction to the list of participants.
// Then send to the notary if needed, record locally and distribute. // Then send to the notary if needed, record locally and distribute.
progressTracker.currentStep = NOTARISING progressTracker.currentStep = NOTARISING
val notarisedTxns = notariseAndRecord(lookupParties(resolveDependenciesOf(transactions))) val notarisedTxns: List<Pair<SignedTransaction, Set<Participant>>> = resolveDependenciesOf(transactions)
.map { (stx, ltx) -> Pair(notariseAndRecord(stx), lookupParties(ltx)) }
// Each transaction has its own set of recipients, but extra recipients get them all. // Each transaction has its own set of recipients, but extra recipients get them all.
progressTracker.currentStep = BROADCASTING progressTracker.currentStep = BROADCASTING
val me = serviceHub.myInfo.legalIdentity
for ((stx, parties) in notarisedTxns) { for ((stx, parties) in notarisedTxns) {
subFlow(BroadcastTransactionFlow(stx, parties + extraRecipients - me)) broadcastTransaction(stx, (parties + extraParticipants).filter { it.wellKnown != me })
} }
return notarisedTxns.map { it.first } return notarisedTxns.map { it.first }
} }
// TODO: API: Make some of these protected? /**
* Broadcast a transaction to the participants. By default calls [BroadcastTransactionFlow], however can be
* overridden for more complex transaction delivery protocols (for example where not all parties know each other).
* This implementation will filter out any participants for whom there is no well known identity.
*
* @param participants the participants to send the transaction to. This is expected to include extra participants
* and exclude the local node.
*/
@Suspendable
open protected fun broadcastTransaction(stx: SignedTransaction, participants: Iterable<Participant>) {
val wellKnownParticipants = participants.map { it.wellKnown }.filterNotNull()
if (wellKnownParticipants.isNotEmpty()) {
subFlow(BroadcastTransactionFlow(stx, wellKnownParticipants.toNonEmptySet()))
}
}
@Suspendable @Suspendable
private fun notariseAndRecord(stxnsAndParties: List<Pair<SignedTransaction, Set<Party>>>): List<Pair<SignedTransaction, Set<Party>>> { private fun notariseAndRecord(stx: SignedTransaction): SignedTransaction {
return stxnsAndParties.map { (stx, parties) ->
val notarised = if (needsNotarySignature(stx)) { val notarised = if (needsNotarySignature(stx)) {
val notarySignatures = subFlow(NotaryFlow.Client(stx)) val notarySignatures = subFlow(NotaryFlow.Client(stx))
stx + notarySignatures stx + notarySignatures
@ -82,8 +101,7 @@ class FinalityFlow(val transactions: Iterable<SignedTransaction>,
stx stx
} }
serviceHub.recordTransactions(notarised) serviceHub.recordTransactions(notarised)
Pair(notarised, parties) return notarised
}
} }
private fun needsNotarySignature(stx: SignedTransaction): Boolean { private fun needsNotarySignature(stx: SignedTransaction): Boolean {
@ -99,14 +117,31 @@ class FinalityFlow(val transactions: Iterable<SignedTransaction>,
return !(notaryKey?.isFulfilledBy(signers) ?: false) return !(notaryKey?.isFulfilledBy(signers) ?: false)
} }
private fun lookupParties(ltxns: List<Pair<SignedTransaction, LedgerTransaction>>): List<Pair<SignedTransaction, Set<Party>>> { /**
return ltxns.map { (stx, ltx) -> * Resolve the parties involved in a transaction.
*
* @return the set of participants and their resolved well known identities (where known).
*/
open protected fun lookupParties(ltx: LedgerTransaction): Set<Participant> {
// Calculate who is meant to see the results based on the participants involved. // Calculate who is meant to see the results based on the participants involved.
val keys = ltx.outputs.flatMap { it.data.participants } + ltx.inputs.flatMap { it.state.data.participants } return extractParticipants(ltx)
// TODO: Is it safe to drop participants we don't know how to contact? Does not knowing how to contact them count as a reason to fail? .map(this::partyFromAnonymous)
val parties = keys.mapNotNull { serviceHub.identityService.partyFromAnonymous(it) }.toSet() .toSet()
Pair(stx, parties)
} }
/**
* Helper function to extract all participants from a ledger transaction. Intended to help implement [lookupParties]
* overriding functions.
*/
protected fun extractParticipants(ltx: LedgerTransaction): List<AbstractParty> {
return ltx.outputStates.flatMap { it.participants } + ltx.inputStates.flatMap { it.participants }
}
/**
* Helper function which wraps [IdentityService.partyFromAnonymous] so it can be called as a lambda function.
*/
protected fun partyFromAnonymous(anon: AbstractParty): Participant {
return Participant(anon, serviceHub.identityService.partyFromAnonymous(anon))
} }
private fun resolveDependenciesOf(signedTransactions: Iterable<SignedTransaction>): List<Pair<SignedTransaction, LedgerTransaction>> { private fun resolveDependenciesOf(signedTransactions: Iterable<SignedTransaction>): List<Pair<SignedTransaction, LedgerTransaction>> {
@ -125,10 +160,12 @@ class FinalityFlow(val transactions: Iterable<SignedTransaction>,
return sorted.map { stx -> return sorted.map { stx ->
val notary = stx.tx.notary val notary = stx.tx.notary
// The notary signature(s) are allowed to be missing but no others. // The notary signature(s) are allowed to be missing but no others.
val wtx = if (notary != null) stx.verifySignatures(notary.owningKey) else stx.verifySignatures() if (notary != null) stx.verifySignaturesExcept(notary.owningKey) else stx.verifyRequiredSignatures()
val ltx = wtx.toLedgerTransaction(augmentedLookup) val ltx = stx.toLedgerTransaction(augmentedLookup, false)
ltx.verify() ltx.verify()
stx to ltx stx to ltx
} }
} }
data class Participant(val participant: AbstractParty, val wellKnown: Party?)
} }

View File

@ -25,6 +25,6 @@ open class FlowException(message: String?, cause: Throwable?) : CordaException(m
* that we were not expecting), or the other side had an internal error, or the other side terminated when we * that we were not expecting), or the other side had an internal error, or the other side terminated when we
* were waiting for a response. * were waiting for a response.
*/ */
class FlowSessionException(message: String?, cause: Throwable?) : CordaRuntimeException(message, cause) { class UnexpectedFlowEndException(message: String?, cause: Throwable?) : CordaRuntimeException(message, cause) {
constructor(msg: String) : this(msg, null) constructor(msg: String) : this(msg, null)
} }

View File

@ -4,8 +4,10 @@ import co.paralleluniverse.fibers.Suspendable
import net.corda.core.crypto.SecureHash import net.corda.core.crypto.SecureHash
import net.corda.core.identity.Party import net.corda.core.identity.Party
import net.corda.core.internal.FlowStateMachine import net.corda.core.internal.FlowStateMachine
import net.corda.core.internal.abbreviate
import net.corda.core.messaging.DataFeed import net.corda.core.messaging.DataFeed
import net.corda.core.node.ServiceHub import net.corda.core.node.ServiceHub
import net.corda.core.serialization.CordaSerializable
import net.corda.core.transactions.SignedTransaction import net.corda.core.transactions.SignedTransaction
import net.corda.core.utilities.ProgressTracker import net.corda.core.utilities.ProgressTracker
import net.corda.core.utilities.UntrustworthyData import net.corda.core.utilities.UntrustworthyData
@ -51,10 +53,15 @@ abstract class FlowLogic<out T> {
*/ */
val serviceHub: ServiceHub get() = stateMachine.serviceHub val serviceHub: ServiceHub get() = stateMachine.serviceHub
@Deprecated("This is no longer used and will be removed in a future release. If you are using this to communicate " + /**
"with the same party but for two different message streams, then the correct way of doing that is to use sub-flows", * Returns a [FlowContext] object describing the flow [otherParty] is using. With [FlowContext.flowVersion] it
level = DeprecationLevel.ERROR) * provides the necessary information needed for the evolution of flows and enabling backwards compatibility.
open fun getCounterpartyMarker(party: Party): Class<*> = javaClass *
* This method can be called before any send or receive has been done with [otherParty]. In such a case this will force
* them to start their flow.
*/
@Suspendable
fun getFlowContext(otherParty: Party): FlowContext = stateMachine.getFlowContext(otherParty, flowUsedForSessions)
/** /**
* Serializes and queues the given [payload] object for sending to the [otherParty]. Suspends until a response * Serializes and queues the given [payload] object for sending to the [otherParty]. Suspends until a response
@ -89,11 +96,6 @@ abstract class FlowLogic<out T> {
return stateMachine.sendAndReceive(receiveType, otherParty, payload, flowUsedForSessions) return stateMachine.sendAndReceive(receiveType, otherParty, payload, flowUsedForSessions)
} }
/** @see sendAndReceiveWithRetry */
internal inline fun <reified R : Any> sendAndReceiveWithRetry(otherParty: Party, payload: Any): UntrustworthyData<R> {
return sendAndReceiveWithRetry(R::class.java, otherParty, payload)
}
/** /**
* Similar to [sendAndReceive] but also instructs the `payload` to be redelivered until the expected message is received. * Similar to [sendAndReceive] but also instructs the `payload` to be redelivered until the expected message is received.
* *
@ -103,9 +105,8 @@ abstract class FlowLogic<out T> {
* oracle services. If one or more nodes in the service cluster go down mid-session, the message will be redelivered * oracle services. If one or more nodes in the service cluster go down mid-session, the message will be redelivered
* to a different one, so there is no need to wait until the initial node comes back up to obtain a response. * to a different one, so there is no need to wait until the initial node comes back up to obtain a response.
*/ */
@Suspendable internal inline fun <reified R : Any> sendAndReceiveWithRetry(otherParty: Party, payload: Any): UntrustworthyData<R> {
internal open fun <R : Any> sendAndReceiveWithRetry(receiveType: Class<R>, otherParty: Party, payload: Any): UntrustworthyData<R> { return stateMachine.sendAndReceive(R::class.java, otherParty, payload, flowUsedForSessions, true)
return stateMachine.sendAndReceive(receiveType, otherParty, payload, flowUsedForSessions, true)
} }
/** /**
@ -139,7 +140,7 @@ abstract class FlowLogic<out T> {
* network's event horizon time. * network's event horizon time.
*/ */
@Suspendable @Suspendable
open fun send(otherParty: Party, payload: Any) = stateMachine.send(otherParty, payload, flowUsedForSessions) open fun send(otherParty: Party, payload: Any): Unit = stateMachine.send(otherParty, payload, flowUsedForSessions)
/** /**
* Invokes the given subflow. This function returns once the subflow completes successfully with the result * Invokes the given subflow. This function returns once the subflow completes successfully with the result
@ -163,7 +164,7 @@ abstract class FlowLogic<out T> {
} }
logger.debug { "Calling subflow: $subLogic" } logger.debug { "Calling subflow: $subLogic" }
val result = subLogic.call() val result = subLogic.call()
logger.debug { "Subflow finished with result $result" } logger.debug { "Subflow finished with result ${result.toString().abbreviate(300)}" }
// It's easy to forget this when writing flows so we just step it to the DONE state when it completes. // It's easy to forget this when writing flows so we just step it to the DONE state when it completes.
subLogic.progressTracker?.currentStep = ProgressTracker.DONE subLogic.progressTracker?.currentStep = ProgressTracker.DONE
return result return result
@ -180,7 +181,9 @@ abstract class FlowLogic<out T> {
* @param extraAuditData in the audit log for this permission check these extra key value pairs will be recorded. * @param extraAuditData in the audit log for this permission check these extra key value pairs will be recorded.
*/ */
@Throws(FlowException::class) @Throws(FlowException::class)
fun checkFlowPermission(permissionName: String, extraAuditData: Map<String, String>) = stateMachine.checkFlowPermission(permissionName, extraAuditData) fun checkFlowPermission(permissionName: String, extraAuditData: Map<String, String>) {
stateMachine.checkFlowPermission(permissionName, extraAuditData)
}
/** /**
@ -189,7 +192,9 @@ abstract class FlowLogic<out T> {
* @param comment a general human readable summary of the event. * @param comment a general human readable summary of the event.
* @param extraAuditData in the audit log for this permission check these extra key value pairs will be recorded. * @param extraAuditData in the audit log for this permission check these extra key value pairs will be recorded.
*/ */
fun recordAuditEvent(eventType: String, comment: String, extraAuditData: Map<String, String>) = stateMachine.recordAuditEvent(eventType, comment, extraAuditData) fun recordAuditEvent(eventType: String, comment: String, extraAuditData: Map<String, String>) {
stateMachine.recordAuditEvent(eventType, comment, extraAuditData)
}
/** /**
* Override this to provide a [ProgressTracker]. If one is provided and stepped, the framework will do something * Override this to provide a [ProgressTracker]. If one is provided and stepped, the framework will do something
@ -230,6 +235,29 @@ abstract class FlowLogic<out T> {
@Suspendable @Suspendable
fun waitForLedgerCommit(hash: SecureHash): SignedTransaction = stateMachine.waitForLedgerCommit(hash, this) fun waitForLedgerCommit(hash: SecureHash): SignedTransaction = stateMachine.waitForLedgerCommit(hash, this)
/**
* Returns a shallow copy of the Quasar stack frames at the time of call to [flowStackSnapshot]. Use this to inspect
* what objects would be serialised at the time of call to a suspending action (e.g. send/receive).
* Note: This logic is only available during tests and is not meant to be used during the production deployment.
* Therefore the default implementationdoes nothing.
*/
@Suspendable
fun flowStackSnapshot(): FlowStackSnapshot? = stateMachine.flowStackSnapshot(this::class.java)
/**
* Persists a shallow copy of the Quasar stack frames at the time of call to [persistFlowStackSnapshot].
* Use this to track the monitor evolution of the quasar stack values during the flow execution.
* The flow stack snapshot is stored in a file located in {baseDir}/flowStackSnapshots/YYYY-MM-DD/{flowId}/
* where baseDir is the node running directory and flowId is the flow unique identifier generated by the platform.
*
* Note: With respect to the [flowStackSnapshot], the snapshot being persisted by this method is partial,
* meaning that only flow relevant traces and local variables are persisted.
* Also, this logic is only available during tests and is not meant to be used during the production deployment.
* Therefore the default implementation does nothing.
*/
@Suspendable
fun persistFlowStackSnapshot(): Unit = stateMachine.persistFlowStackSnapshot(this::class.java)
//////////////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////////////
private var _stateMachine: FlowStateMachine<*>? = null private var _stateMachine: FlowStateMachine<*>? = null
@ -261,3 +289,20 @@ abstract class FlowLogic<out T> {
} }
} }
} }
/**
* Version and name of the CorDapp hosting the other side of the flow.
*/
@CordaSerializable
data class FlowContext(
/**
* The integer flow version the other side is using.
* @see InitiatingFlow
*/
val flowVersion: Int,
/**
* Name of the CorDapp jar hosting the flow, without the .jar extension. It will include a unique identifier
* to deduplicate it from other releases of the same CorDapp, typically a version string. See the
* [CorDapp JAR format](https://docs.corda.net/cordapp-build-systems.html#cordapp-jar-format) for more details.
*/
val appName: String)

View File

@ -0,0 +1,73 @@
package net.corda.core.flows
import net.corda.core.utilities.loggerFor
import java.nio.file.Path
import java.util.*
interface FlowStackSnapshotFactory {
private object Holder {
val INSTANCE: FlowStackSnapshotFactory
init {
val serviceFactory = ServiceLoader.load(FlowStackSnapshotFactory::class.java).singleOrNull()
INSTANCE = serviceFactory ?: FlowStackSnapshotDefaultFactory()
}
}
companion object {
val instance: FlowStackSnapshotFactory by lazy { Holder.INSTANCE }
}
/**
* Returns flow stack data snapshot extracted from Quasar stack.
* It is designed to be used in the debug mode of the flow execution.
* Note. This logic is only available during tests and is not meant to be used during the production deployment.
* Therefore the default implementation does nothing.
*/
fun getFlowStackSnapshot(flowClass: Class<*>): FlowStackSnapshot?
/** Stores flow stack snapshot as a json file. The stored shapshot is only partial and consists
* only data (i.e. stack traces and local variables values) relevant to the flow. It does not
* persist corda internal data (e.g. FlowStateMachine). Instead it uses [StackFrameDataToken] to indicate
* the class of the element on the stack.
* The flow stack snapshot is stored in a file located in
* {baseDir}/flowStackSnapshots/YYYY-MM-DD/{flowId}/
* where baseDir is the node running directory and flowId is the flow unique identifier generated by the platform.
* Note. This logic is only available during tests and is not meant to be used during the production deployment.
* Therefore the default implementation does nothing.
*/
fun persistAsJsonFile(flowClass: Class<*>, baseDir: Path, flowId: String): Unit
}
private class FlowStackSnapshotDefaultFactory : FlowStackSnapshotFactory {
val log = loggerFor<FlowStackSnapshotDefaultFactory>()
override fun getFlowStackSnapshot(flowClass: Class<*>): FlowStackSnapshot? {
log.warn("Flow stack snapshot are not supposed to be used in a production deployment")
return null
}
override fun persistAsJsonFile(flowClass: Class<*>, baseDir: Path, flowId: String) {
log.warn("Flow stack snapshot are not supposed to be used in a production deployment")
}
}
/**
* Main data object representing snapshot of the flow stack, extracted from the Quasar stack.
*/
data class FlowStackSnapshot constructor(
val timestamp: Long = System.currentTimeMillis(),
val flowClass: Class<*>? = null,
val stackFrames: List<Frame> = listOf()
) {
data class Frame(
val stackTraceElement: StackTraceElement? = null, // This should be the call that *pushed* the frame of [objects]
val stackObjects: List<Any?> = listOf()
)
}
/**
* Token class, used to indicate stack presence of the corda internal data. Since this data is of no use for
* a CordApp developer, it is skipped from serialisation and its presence is only marked by this token.
*/
data class StackFrameDataToken(val className: String)

View File

@ -0,0 +1,20 @@
package net.corda.core.flows
import net.corda.core.identity.Party
import net.corda.core.transactions.LedgerTransaction
import net.corda.core.transactions.SignedTransaction
import net.corda.core.utilities.ProgressTracker
/**
* Alternative finality flow which only does not attempt to take participants from the transaction, but instead all
* participating parties must be provided manually.
*
* @param transactions What to commit.
* @param extraRecipients A list of additional participants to inform of the transaction.
*/
class ManualFinalityFlow(transactions: Iterable<SignedTransaction>,
recipients: Set<Party>,
progressTracker: ProgressTracker) : FinalityFlow(transactions, recipients, progressTracker) {
constructor(transaction: SignedTransaction, extraParticipants: Set<Party>) : this(listOf(transaction), extraParticipants, tracker())
override fun lookupParties(ltx: LedgerTransaction): Set<Participant> = emptySet()
}

View File

@ -0,0 +1,59 @@
package net.corda.core.flows
import net.corda.core.contracts.ContractState
import net.corda.core.contracts.StateAndRef
import net.corda.core.contracts.StateRef
import net.corda.core.crypto.Crypto
import net.corda.core.crypto.SignableData
import net.corda.core.crypto.SignatureMetadata
import net.corda.core.identity.Party
import net.corda.core.transactions.NotaryChangeWireTransaction
import net.corda.core.transactions.SignedTransaction
import net.corda.core.utilities.ProgressTracker
/**
* A flow to be used for changing a state's Notary. This is required since all input states to a transaction
* must point to the same notary.
*
* This assembles the transaction for notary replacement and sends out change proposals to all participants
* of that state. If participants agree to the proposed change, they each sign the transaction.
* Finally, the transaction containing all signatures is sent back to each participant so they can record it and
* use the new updated state for future transactions.
*/
@InitiatingFlow
class NotaryChangeFlow<out T : ContractState>(
originalState: StateAndRef<T>,
newNotary: Party,
progressTracker: ProgressTracker = AbstractStateReplacementFlow.Instigator.tracker())
: AbstractStateReplacementFlow.Instigator<T, T, Party>(originalState, newNotary, progressTracker) {
override fun assembleTx(): AbstractStateReplacementFlow.UpgradeTx {
val inputs = resolveEncumbrances(originalState)
val tx = NotaryChangeWireTransaction(
inputs.map { it.ref },
originalState.state.notary,
modification
)
val participantKeys = inputs.flatMap { it.state.data.participants }.map { it.owningKey }.toSet()
// TODO: We need a much faster way of finding our key in the transaction
val myKey = serviceHub.keyManagementService.filterMyKeys(participantKeys).single()
val signableData = SignableData(tx.id, SignatureMetadata(serviceHub.myInfo.platformVersion, Crypto.findSignatureScheme(myKey).schemeNumberID))
val mySignature = serviceHub.keyManagementService.sign(signableData, myKey)
val stx = SignedTransaction(tx, listOf(mySignature))
return AbstractStateReplacementFlow.UpgradeTx(stx, participantKeys, myKey)
}
/** Resolves the encumbrance state chain for the given [state] */
private fun resolveEncumbrances(state: StateAndRef<T>): List<StateAndRef<T>> {
val states = mutableListOf(state)
while (states.last().state.encumbrance != null) {
val encumbranceStateRef = StateRef(states.last().ref.txhash, states.last().state.encumbrance!!)
val encumbranceState = serviceHub.toStateAndRef<T>(encumbranceStateRef)
states.add(encumbranceState)
}
return states
}
}

View File

@ -1,21 +1,22 @@
package net.corda.flows package net.corda.core.flows
import co.paralleluniverse.fibers.Suspendable import co.paralleluniverse.fibers.Suspendable
import net.corda.core.contracts.StateRef import net.corda.core.contracts.StateRef
import net.corda.core.contracts.TimeWindow import net.corda.core.contracts.TimeWindow
import net.corda.core.crypto.DigitalSignature
import net.corda.core.crypto.SecureHash import net.corda.core.crypto.SecureHash
import net.corda.core.crypto.SignedData import net.corda.core.crypto.SignedData
import net.corda.core.crypto.TransactionSignature
import net.corda.core.crypto.keys import net.corda.core.crypto.keys
import net.corda.core.flows.FlowException
import net.corda.core.flows.FlowLogic
import net.corda.core.flows.InitiatingFlow
import net.corda.core.identity.Party import net.corda.core.identity.Party
import net.corda.core.node.services.* import net.corda.core.internal.FetchDataFlow
import net.corda.core.node.services.NotaryService
import net.corda.core.node.services.TrustedAuthorityNotaryService
import net.corda.core.node.services.UniquenessProvider
import net.corda.core.serialization.CordaSerializable import net.corda.core.serialization.CordaSerializable
import net.corda.core.transactions.SignedTransaction import net.corda.core.transactions.SignedTransaction
import net.corda.core.utilities.ProgressTracker import net.corda.core.utilities.ProgressTracker
import net.corda.core.utilities.unwrap import net.corda.core.utilities.unwrap
import java.security.SignatureException
import java.util.function.Predicate import java.util.function.Predicate
object NotaryFlow { object NotaryFlow {
@ -31,8 +32,8 @@ object NotaryFlow {
*/ */
@InitiatingFlow @InitiatingFlow
open class Client(private val stx: SignedTransaction, open class Client(private val stx: SignedTransaction,
override val progressTracker: ProgressTracker) : FlowLogic<List<DigitalSignature.WithKey>>() { override val progressTracker: ProgressTracker) : FlowLogic<List<TransactionSignature>>() {
constructor(stx: SignedTransaction) : this(stx, Client.tracker()) constructor(stx: SignedTransaction) : this(stx, tracker())
companion object { companion object {
object REQUESTING : ProgressTracker.Step("Requesting signature by Notary service") object REQUESTING : ProgressTracker.Step("Requesting signature by Notary service")
@ -45,27 +46,36 @@ object NotaryFlow {
@Suspendable @Suspendable
@Throws(NotaryException::class) @Throws(NotaryException::class)
override fun call(): List<DigitalSignature.WithKey> { override fun call(): List<TransactionSignature> {
progressTracker.currentStep = REQUESTING progressTracker.currentStep = REQUESTING
val wtx = stx.tx
notaryParty = wtx.notary ?: throw IllegalStateException("Transaction does not specify a Notary") notaryParty = stx.notary ?: throw IllegalStateException("Transaction does not specify a Notary")
check(wtx.inputs.all { stateRef -> serviceHub.loadState(stateRef).notary == notaryParty }) { check(stx.inputs.all { stateRef -> serviceHub.loadState(stateRef).notary == notaryParty }) {
"Input states must have the same Notary" "Input states must have the same Notary"
} }
try {
stx.verifySignatures(notaryParty.owningKey)
} catch (ex: SignedTransaction.SignaturesMissingException) {
throw NotaryException(NotaryError.SignaturesMissing(ex))
}
val payload: Any = if (serviceHub.networkMapCache.isValidatingNotary(notaryParty)) { try {
stx if (stx.isNotaryChangeTransaction()) {
stx.resolveNotaryChangeTransaction(serviceHub).verifySignaturesExcept(notaryParty.owningKey)
} else { } else {
wtx.buildFilteredTransaction(Predicate { it is StateRef || it is TimeWindow }) stx.verifySignaturesExcept(notaryParty.owningKey)
}
} catch (ex: SignatureException) {
throw NotaryException(NotaryError.TransactionInvalid(ex))
} }
val response = try { val response = try {
sendAndReceiveWithRetry<List<DigitalSignature.WithKey>>(notaryParty, payload) if (serviceHub.networkMapCache.isValidatingNotary(notaryParty)) {
subFlow(SendTransactionWithRetry(notaryParty, stx))
receive<List<TransactionSignature>>(notaryParty)
} else {
val tx: Any = if (stx.isNotaryChangeTransaction()) {
stx.notaryChangeTx
} else {
stx.tx.buildFilteredTransaction(Predicate { it is StateRef || it is TimeWindow })
}
sendAndReceiveWithRetry(notaryParty, tx)
}
} catch (e: NotaryException) { } catch (e: NotaryException) {
if (e.error is NotaryError.Conflict) { if (e.error is NotaryError.Conflict) {
e.error.conflict.verified() e.error.conflict.verified()
@ -74,14 +84,14 @@ object NotaryFlow {
} }
return response.unwrap { signatures -> return response.unwrap { signatures ->
signatures.forEach { validateSignature(it, stx.id.bytes) } signatures.forEach { validateSignature(it, stx.id) }
signatures signatures
} }
} }
private fun validateSignature(sig: DigitalSignature.WithKey, data: ByteArray) { private fun validateSignature(sig: TransactionSignature, txId: SecureHash) {
check(sig.by in notaryParty.owningKey.keys) { "Invalid signer for the notary result" } check(sig.by in notaryParty.owningKey.keys) { "Invalid signer for the notary result" }
sig.verify(data) sig.verify(txId)
} }
} }
@ -114,7 +124,7 @@ object NotaryFlow {
@Suspendable @Suspendable
private fun signAndSendResponse(txId: SecureHash) { private fun signAndSendResponse(txId: SecureHash) {
val signature = service.sign(txId.bytes) val signature = service.sign(txId)
send(otherSide, listOf(signature)) send(otherSide, listOf(signature))
} }
} }
@ -137,10 +147,16 @@ sealed class NotaryError {
/** Thrown if the time specified in the [TimeWindow] command is outside the allowed tolerance. */ /** Thrown if the time specified in the [TimeWindow] command is outside the allowed tolerance. */
object TimeWindowInvalid : NotaryError() object TimeWindowInvalid : NotaryError()
data class TransactionInvalid(val msg: String) : NotaryError() data class TransactionInvalid(val cause: Throwable) : NotaryError() {
data class SignaturesInvalid(val msg: String) : NotaryError()
data class SignaturesMissing(val cause: SignedTransaction.SignaturesMissingException) : NotaryError() {
override fun toString() = cause.toString() override fun toString() = cause.toString()
} }
} }
/**
* The [SendTransactionWithRetry] flow is equivalent to [SendTransactionFlow] but using [sendAndReceiveWithRetry]
* instead of [sendAndReceive], [SendTransactionWithRetry] is intended to be use by the notary client only.
*/
private class SendTransactionWithRetry(otherSide: Party, stx: SignedTransaction) : SendTransactionFlow(otherSide, stx) {
@Suspendable
override fun sendPayloadAndReceiveDataRequest(otherSide: Party, payload: Any) = sendAndReceiveWithRetry<FetchDataFlow.Request>(otherSide, payload)
}

View File

@ -0,0 +1,48 @@
package net.corda.core.flows
import co.paralleluniverse.fibers.Suspendable
import net.corda.core.contracts.*
import net.corda.core.identity.Party
import net.corda.core.internal.ResolveTransactionsFlow
import net.corda.core.transactions.SignedTransaction
import net.corda.core.utilities.unwrap
import java.security.SignatureException
/**
* The [ReceiveTransactionFlow] should be called in response to the [SendTransactionFlow].
*
* This flow is a combination of [receive], resolve and [SignedTransaction.verify]. This flow will receive the [SignedTransaction]
* and perform the resolution back-and-forth required to check the dependencies and download any missing attachments.
* The flow will return the [SignedTransaction] after it is resolved and then verified using [SignedTransaction.verify].
*/
class ReceiveTransactionFlow
@JvmOverloads
constructor(private val otherParty: Party, private val checkSufficientSignatures: Boolean = true) : FlowLogic<SignedTransaction>() {
@Suspendable
@Throws(SignatureException::class, AttachmentResolutionException::class, TransactionResolutionException::class, TransactionVerificationException::class)
override fun call(): SignedTransaction {
return receive<SignedTransaction>(otherParty).unwrap {
subFlow(ResolveTransactionsFlow(it, otherParty))
it.verify(serviceHub, checkSufficientSignatures)
it
}
}
}
/**
* The [ReceiveStateAndRefFlow] should be called in response to the [SendStateAndRefFlow].
*
* This flow is a combination of [receive] and resolve. This flow will receive a list of [StateAndRef]
* and perform the resolution back-and-forth required to check the dependencies.
* The flow will return the list of [StateAndRef] after it is resolved.
*/
// @JvmSuppressWildcards is used to suppress wildcards in return type when calling `subFlow(new ReceiveStateAndRef<T>(otherParty))` in java.
class ReceiveStateAndRefFlow<out T : ContractState>(private val otherParty: Party) : FlowLogic<@JvmSuppressWildcards List<StateAndRef<T>>>() {
@Suspendable
override fun call(): List<StateAndRef<T>> {
return receive<List<StateAndRef<T>>>(otherParty).unwrap {
subFlow(ResolveTransactionsFlow(it.map { it.ref.txhash }.toSet(), otherParty))
it
}
}
}

View File

@ -0,0 +1,69 @@
package net.corda.core.flows
import co.paralleluniverse.fibers.Suspendable
import net.corda.core.contracts.StateAndRef
import net.corda.core.identity.Party
import net.corda.core.internal.FetchDataFlow
import net.corda.core.transactions.SignedTransaction
import net.corda.core.utilities.unwrap
/**
* The [SendTransactionFlow] should be used to send a transaction to another peer that wishes to verify that transaction's
* integrity by resolving and checking the dependencies as well. The other side should invoke [ReceiveTransactionFlow] at
* the right point in the conversation to receive the sent transaction and perform the resolution back-and-forth required
* to check the dependencies and download any missing attachments.
*
* @param otherSide the target party.
* @param stx the [SignedTransaction] being sent to the [otherSide].
*/
open class SendTransactionFlow(otherSide: Party, stx: SignedTransaction) : DataVendingFlow(otherSide, stx)
/**
* The [SendStateAndRefFlow] should be used to send a list of input [StateAndRef] to another peer that wishes to verify
* the input's integrity by resolving and checking the dependencies as well. The other side should invoke [ReceiveStateAndRefFlow]
* at the right point in the conversation to receive the input state and ref and perform the resolution back-and-forth
* required to check the dependencies.
*
* @param otherSide the target party.
* @param stateAndRefs the list of [StateAndRef] being sent to the [otherSide].
*/
open class SendStateAndRefFlow(otherSide: Party, stateAndRefs: List<StateAndRef<*>>) : DataVendingFlow(otherSide, stateAndRefs)
sealed class DataVendingFlow(val otherSide: Party, val payload: Any) : FlowLogic<Void?>() {
@Suspendable
protected open fun sendPayloadAndReceiveDataRequest(otherSide: Party, payload: Any) = sendAndReceive<FetchDataFlow.Request>(otherSide, payload)
@Suspendable
protected open fun verifyDataRequest(dataRequest: FetchDataFlow.Request.Data) {
// User can override this method to perform custom request verification.
}
@Suspendable
override fun call(): Void? {
// The first payload will be the transaction data, subsequent payload will be the transaction/attachment data.
var payload = payload
// This loop will receive [FetchDataFlow.Request] continuously until the `otherSide` has all the data they need
// to resolve the transaction, a [FetchDataFlow.EndRequest] will be sent from the `otherSide` to indicate end of
// data request.
while (true) {
val dataRequest = sendPayloadAndReceiveDataRequest(otherSide, payload).unwrap { request ->
when (request) {
is FetchDataFlow.Request.Data -> {
// Security TODO: Check for abnormally large or malformed data requests
verifyDataRequest(request)
request
}
FetchDataFlow.Request.End -> return null
}
}
payload = when (dataRequest.dataType) {
FetchDataFlow.DataType.TRANSACTION -> dataRequest.hashes.map {
serviceHub.validatedTransactions.getTransaction(it) ?: throw FetchDataFlow.HashNotFound(it)
}
FetchDataFlow.DataType.ATTACHMENT -> dataRequest.hashes.map {
serviceHub.attachments.openAttachment(it)?.open()?.readBytes() ?: throw FetchDataFlow.HashNotFound(it)
}
}
}
}
}

View File

@ -1,10 +1,10 @@
package net.corda.flows package net.corda.core.flows
import co.paralleluniverse.fibers.Suspendable import co.paralleluniverse.fibers.Suspendable
import net.corda.core.flows.FlowLogic import net.corda.core.identity.AnonymousParty
import net.corda.core.flows.InitiatingFlow import net.corda.core.identity.PartyAndCertificate
import net.corda.core.flows.StartableByRPC
import net.corda.core.identity.Party import net.corda.core.identity.Party
import net.corda.core.node.services.IdentityService
import net.corda.core.utilities.ProgressTracker import net.corda.core.utilities.ProgressTracker
import net.corda.core.utilities.unwrap import net.corda.core.utilities.unwrap
@ -16,33 +16,37 @@ import net.corda.core.utilities.unwrap
@InitiatingFlow @InitiatingFlow
class TransactionKeyFlow(val otherSide: Party, class TransactionKeyFlow(val otherSide: Party,
val revocationEnabled: Boolean, val revocationEnabled: Boolean,
override val progressTracker: ProgressTracker) : FlowLogic<LinkedHashMap<Party, AnonymisedIdentity>>() { override val progressTracker: ProgressTracker) : FlowLogic<LinkedHashMap<Party, AnonymousParty>>() {
constructor(otherSide: Party) : this(otherSide, false, tracker()) constructor(otherSide: Party) : this(otherSide, false, tracker())
companion object { companion object {
object AWAITING_KEY : ProgressTracker.Step("Awaiting key") object AWAITING_KEY : ProgressTracker.Step("Awaiting key")
fun tracker() = ProgressTracker(AWAITING_KEY) fun tracker() = ProgressTracker(AWAITING_KEY)
fun validateIdentity(otherSide: Party, anonymousOtherSide: AnonymisedIdentity): AnonymisedIdentity { fun validateAndRegisterIdentity(identityService: IdentityService, otherSide: Party, anonymousOtherSide: PartyAndCertificate): PartyAndCertificate {
require(anonymousOtherSide.certificate.subject == otherSide.name) require(anonymousOtherSide.name == otherSide.name)
// Validate then store their identity so that we can prove the key in the transaction is owned by the
// counterparty.
identityService.verifyAndRegisterIdentity(anonymousOtherSide)
return anonymousOtherSide return anonymousOtherSide
} }
} }
@Suspendable @Suspendable
override fun call(): LinkedHashMap<Party, AnonymisedIdentity> { override fun call(): LinkedHashMap<Party, AnonymousParty> {
progressTracker.currentStep = AWAITING_KEY progressTracker.currentStep = AWAITING_KEY
val legalIdentityAnonymous = serviceHub.keyManagementService.freshKeyAndCert(serviceHub.myInfo.legalIdentityAndCert, revocationEnabled) val legalIdentityAnonymous = serviceHub.keyManagementService.freshKeyAndCert(serviceHub.myInfo.legalIdentityAndCert, revocationEnabled)
serviceHub.identityService.registerAnonymousIdentity(legalIdentityAnonymous.identity, serviceHub.myInfo.legalIdentity, legalIdentityAnonymous.certPath)
// Special case that if we're both parties, a single identity is generated // Special case that if we're both parties, a single identity is generated
val identities = LinkedHashMap<Party, AnonymisedIdentity>() val identities = LinkedHashMap<Party, AnonymousParty>()
if (otherSide == serviceHub.myInfo.legalIdentity) { if (otherSide == serviceHub.myInfo.legalIdentity) {
identities.put(otherSide, legalIdentityAnonymous) identities.put(otherSide, legalIdentityAnonymous.party.anonymise())
} else { } else {
val otherSideAnonymous = sendAndReceive<AnonymisedIdentity>(otherSide, legalIdentityAnonymous).unwrap { validateIdentity(otherSide, it) } val anonymousOtherSide = sendAndReceive<PartyAndCertificate>(otherSide, legalIdentityAnonymous).unwrap { confidentialIdentity ->
identities.put(serviceHub.myInfo.legalIdentity, legalIdentityAnonymous) validateAndRegisterIdentity(serviceHub.identityService, otherSide, confidentialIdentity)
identities.put(otherSide, otherSideAnonymous) }
identities.put(serviceHub.myInfo.legalIdentity, legalIdentityAnonymous.party.anonymise())
identities.put(otherSide, anonymousOtherSide.party.anonymise())
} }
return identities return identities
} }

View File

@ -1,16 +0,0 @@
package net.corda.flows
import net.corda.core.identity.AnonymousParty
import net.corda.core.serialization.CordaSerializable
import org.bouncycastle.cert.X509CertificateHolder
import java.security.PublicKey
import java.security.cert.CertPath
@CordaSerializable
data class AnonymisedIdentity(
val certPath: CertPath,
val certificate: X509CertificateHolder,
val identity: AnonymousParty) {
constructor(certPath: CertPath, certificate: X509CertificateHolder, identity: PublicKey)
: this(certPath, certificate, AnonymousParty(identity))
}

View File

@ -2,6 +2,7 @@ package net.corda.core.identity
import net.corda.core.contracts.PartyAndReference import net.corda.core.contracts.PartyAndReference
import net.corda.core.crypto.toBase58String import net.corda.core.crypto.toBase58String
import net.corda.core.crypto.toStringShort
import net.corda.core.utilities.OpaqueBytes import net.corda.core.utilities.OpaqueBytes
import org.bouncycastle.asn1.x500.X500Name import org.bouncycastle.asn1.x500.X500Name
import java.security.PublicKey import java.security.PublicKey
@ -13,7 +14,7 @@ import java.security.PublicKey
class AnonymousParty(owningKey: PublicKey) : AbstractParty(owningKey) { class AnonymousParty(owningKey: PublicKey) : AbstractParty(owningKey) {
// Use the key as the bulk of the toString(), but include a human readable identifier as well, so that [Party] // Use the key as the bulk of the toString(), but include a human readable identifier as well, so that [Party]
// can put in the key and actual name // can put in the key and actual name
override fun toString() = "${owningKey.toBase58String()} <Anonymous>" override fun toString() = "${owningKey.toStringShort()} <Anonymous>"
override fun nameOrNull(): X500Name? = null override fun nameOrNull(): X500Name? = null

View File

@ -30,5 +30,6 @@ class Party(val name: X500Name, owningKey: PublicKey) : AbstractParty(owningKey)
override fun toString() = name.toString() override fun toString() = name.toString()
override fun nameOrNull(): X500Name? = name override fun nameOrNull(): X500Name? = name
fun anonymise(): AnonymousParty = AnonymousParty(owningKey)
override fun ref(bytes: OpaqueBytes): PartyAndReference = PartyAndReference(this, bytes) override fun ref(bytes: OpaqueBytes): PartyAndReference = PartyAndReference(this, bytes)
} }

View File

@ -4,12 +4,14 @@ import net.corda.core.serialization.CordaSerializable
import org.bouncycastle.asn1.x500.X500Name import org.bouncycastle.asn1.x500.X500Name
import org.bouncycastle.cert.X509CertificateHolder import org.bouncycastle.cert.X509CertificateHolder
import java.security.PublicKey import java.security.PublicKey
import java.security.cert.CertPath import java.security.cert.*
import java.util.*
/** /**
* A full party plus the X.509 certificate and path linking the party back to a trust root. Equality of * A full party plus the X.509 certificate and path linking the party back to a trust root. Equality of
* [PartyAndCertificate] instances is based on the party only, as certificate and path are data associated with the party, * [PartyAndCertificate] instances is based on the party only, as certificate and path are data associated with the party,
* not part of the identifier themselves. * not part of the identifier themselves. While party and certificate can both be derived from the certificate path,
* this class exists in order to ensure the implementation classes of certificates and party public keys are kept stable.
*/ */
@CordaSerializable @CordaSerializable
data class PartyAndCertificate(val party: Party, data class PartyAndCertificate(val party: Party,
@ -30,4 +32,18 @@ data class PartyAndCertificate(val party: Party,
override fun hashCode(): Int = party.hashCode() override fun hashCode(): Int = party.hashCode()
override fun toString(): String = party.toString() override fun toString(): String = party.toString()
/**
* Verify that the given certificate path is valid and leads to the owning key of the party.
*/
fun verify(trustAnchor: TrustAnchor): PKIXCertPathValidatorResult {
require(certPath.certificates.first() is X509Certificate) { "Subject certificate must be an X.509 certificate" }
require(Arrays.equals(party.owningKey.encoded, certificate.subjectPublicKeyInfo.encoded)) { "Certificate public key must match party owning key" }
require(Arrays.equals(certPath.certificates.first().encoded, certificate.encoded)) { "Certificate path must link to certificate" }
val validatorParameters = PKIXParameters(setOf(trustAnchor))
val validator = CertPathValidator.getInstance("PKIX")
validatorParameters.isRevocationEnabled = false
return validator.validate(certPath, validatorParameters) as PKIXCertPathValidatorResult
}
} }

View File

@ -1,6 +1,4 @@
package net.corda.core.utilities package net.corda.core.internal
import net.corda.core.codePointsString
/** /**
* A simple wrapper class that contains icons and support for printing them only when we're connected to a terminal. * A simple wrapper class that contains icons and support for printing them only when we're connected to a terminal.
@ -29,6 +27,9 @@ object Emoji {
@JvmStatic val CODE_BOOKS: String = codePointsString(0x1F4DA) @JvmStatic val CODE_BOOKS: String = codePointsString(0x1F4DA)
@JvmStatic val CODE_SLEEPING_FACE: String = codePointsString(0x1F634) @JvmStatic val CODE_SLEEPING_FACE: String = codePointsString(0x1F634)
@JvmStatic val CODE_LIGHTBULB: String = codePointsString(0x1F4A1) @JvmStatic val CODE_LIGHTBULB: String = codePointsString(0x1F4A1)
@JvmStatic val CODE_FREE: String = codePointsString(0x1F193)
@JvmStatic val CODE_SOON: String = codePointsString(0x1F51C)
/** /**
* When non-null, toString() methods are allowed to use emoji in the output as we're going to render them to a * When non-null, toString() methods are allowed to use emoji in the output as we're going to render them to a
@ -46,6 +47,8 @@ object Emoji {
val books: String get() = if (emojiMode.get() != null) "$CODE_BOOKS " else "" val books: String get() = if (emojiMode.get() != null) "$CODE_BOOKS " else ""
val sleepingFace: String get() = if (emojiMode.get() != null) "$CODE_SLEEPING_FACE " else "" val sleepingFace: String get() = if (emojiMode.get() != null) "$CODE_SLEEPING_FACE " else ""
val lightBulb: String get() = if (emojiMode.get() != null) "$CODE_LIGHTBULB " else "" val lightBulb: String get() = if (emojiMode.get() != null) "$CODE_LIGHTBULB " else ""
val free: String get() = if (emojiMode.get() != null) "$CODE_FREE " else ""
val soon: String get() = if (emojiMode.get() != null) "$CODE_SOON " else ""
// These have old/non-emoji symbols with better platform support. // These have old/non-emoji symbols with better platform support.
val greenTick: String get() = if (emojiMode.get() != null) "$CODE_GREEN_TICK " else "" val greenTick: String get() = if (emojiMode.get() != null) "$CODE_GREEN_TICK " else ""
@ -79,4 +82,9 @@ object Emoji {
} }
} }
private fun codePointsString(vararg codePoints: Int): String {
val builder = StringBuilder()
codePoints.forEach { builder.append(Character.toChars(it)) }
return builder.toString()
}
} }

View File

@ -0,0 +1,179 @@
package net.corda.core.internal
import co.paralleluniverse.fibers.Suspendable
import net.corda.core.contracts.AbstractAttachment
import net.corda.core.contracts.Attachment
import net.corda.core.contracts.NamedByHash
import net.corda.core.crypto.SecureHash
import net.corda.core.crypto.sha256
import net.corda.core.flows.FlowException
import net.corda.core.flows.FlowLogic
import net.corda.core.identity.Party
import net.corda.core.internal.FetchDataFlow.DownloadedVsRequestedDataMismatch
import net.corda.core.internal.FetchDataFlow.HashNotFound
import net.corda.core.serialization.CordaSerializable
import net.corda.core.serialization.SerializationToken
import net.corda.core.serialization.SerializeAsToken
import net.corda.core.serialization.SerializeAsTokenContext
import net.corda.core.transactions.SignedTransaction
import net.corda.core.utilities.NonEmptySet
import net.corda.core.utilities.UntrustworthyData
import net.corda.core.utilities.unwrap
import java.util.*
/**
* An abstract flow for fetching typed data from a remote peer.
*
* Given a set of hashes (IDs), either loads them from local disk or asks the remote peer to provide them.
*
* A malicious response in which the data provided by the remote peer does not hash to the requested hash results in
* [DownloadedVsRequestedDataMismatch] being thrown. If the remote peer doesn't have an entry, it results in a
* [HashNotFound] exception being thrown.
*
* By default this class does not insert data into any local database, if you want to do that after missing items were
* fetched then override [maybeWriteToDisk]. You *must* override [load]. If the wire type is not the same as the
* ultimate type, you must also override [convert].
*
* @param T The ultimate type of the data being fetched.
* @param W The wire type of the data being fetched, for when it isn't the same as the ultimate type.
*/
sealed class FetchDataFlow<T : NamedByHash, in W : Any>(
protected val requests: Set<SecureHash>,
protected val otherSide: Party,
protected val dataType: DataType) : FlowLogic<FetchDataFlow.Result<T>>() {
@CordaSerializable
class DownloadedVsRequestedDataMismatch(val requested: SecureHash, val got: SecureHash) : IllegalArgumentException()
@CordaSerializable
class DownloadedVsRequestedSizeMismatch(val requested: Int, val got: Int) : IllegalArgumentException()
class HashNotFound(val requested: SecureHash) : FlowException()
@CordaSerializable
data class Result<out T : NamedByHash>(val fromDisk: List<T>, val downloaded: List<T>)
@CordaSerializable
sealed class Request {
data class Data(val hashes: NonEmptySet<SecureHash>, val dataType: DataType) : Request()
object End : Request()
}
@CordaSerializable
enum class DataType {
TRANSACTION, ATTACHMENT
}
@Suspendable
@Throws(HashNotFound::class)
override fun call(): Result<T> {
// Load the items we have from disk and figure out which we're missing.
val (fromDisk, toFetch) = loadWhatWeHave()
return if (toFetch.isEmpty()) {
Result(fromDisk, emptyList())
} else {
logger.info("Requesting ${toFetch.size} dependency(s) for verification from ${otherSide.name}")
// TODO: Support "large message" response streaming so response sizes are not limited by RAM.
// We can then switch to requesting items in large batches to minimise the latency penalty.
// This is blocked by bugs ARTEMIS-1278 and ARTEMIS-1279. For now we limit attachments and txns to 10mb each
// and don't request items in batch, which is a performance loss, but works around the issue. We have
// configured Artemis to not fragment messages up to 10mb so we can send 10mb messages without problems.
// Above that, we start losing authentication data on the message fragments and take exceptions in the
// network layer.
val maybeItems = ArrayList<W>(toFetch.size)
for (hash in toFetch) {
// We skip the validation here (with unwrap { it }) because we will do it below in validateFetchResponse.
// The only thing checked is the object type. It is a protocol violation to send results out of order.
maybeItems += sendAndReceive<List<W>>(otherSide, Request.Data(NonEmptySet.of(hash), dataType)).unwrap { it }
}
// Check for a buggy/malicious peer answering with something that we didn't ask for.
val downloaded = validateFetchResponse(UntrustworthyData(maybeItems), toFetch)
logger.info("Fetched ${downloaded.size} elements from ${otherSide.name}")
maybeWriteToDisk(downloaded)
Result(fromDisk, downloaded)
}
}
protected open fun maybeWriteToDisk(downloaded: List<T>) {
// Do nothing by default.
}
private fun loadWhatWeHave(): Pair<List<T>, List<SecureHash>> {
val fromDisk = ArrayList<T>()
val toFetch = ArrayList<SecureHash>()
for (txid in requests) {
val stx = load(txid)
if (stx == null)
toFetch += txid
else
fromDisk += stx
}
return Pair(fromDisk, toFetch)
}
protected abstract fun load(txid: SecureHash): T?
@Suppress("UNCHECKED_CAST")
protected open fun convert(wire: W): T = wire as T
private fun validateFetchResponse(maybeItems: UntrustworthyData<ArrayList<W>>,
requests: List<SecureHash>): List<T> {
return maybeItems.unwrap { response ->
if (response.size != requests.size)
throw DownloadedVsRequestedSizeMismatch(requests.size, response.size)
val answers = response.map { convert(it) }
// Check transactions actually hash to what we requested, if this fails the remote node
// is a malicious flow violator or buggy.
for ((index, item) in answers.withIndex()) {
if (item.id != requests[index])
throw DownloadedVsRequestedDataMismatch(requests[index], item.id)
}
answers
}
}
}
/**
* Given a set of hashes either loads from from local storage or requests them from the other peer. Downloaded
* attachments are saved to local storage automatically.
*/
class FetchAttachmentsFlow(requests: Set<SecureHash>,
otherSide: Party) : FetchDataFlow<Attachment, ByteArray>(requests, otherSide, DataType.ATTACHMENT) {
override fun load(txid: SecureHash): Attachment? = serviceHub.attachments.openAttachment(txid)
override fun convert(wire: ByteArray): Attachment = FetchedAttachment({ wire })
override fun maybeWriteToDisk(downloaded: List<Attachment>) {
for (attachment in downloaded) {
serviceHub.attachments.importAttachment(attachment.open())
}
}
private class FetchedAttachment(dataLoader: () -> ByteArray) : AbstractAttachment(dataLoader), SerializeAsToken {
override val id: SecureHash by lazy { attachmentData.sha256() }
private class Token(private val id: SecureHash) : SerializationToken {
override fun fromToken(context: SerializeAsTokenContext) = FetchedAttachment(context.attachmentDataLoader(id))
}
override fun toToken(context: SerializeAsTokenContext) = Token(id)
}
}
/**
* Given a set of tx hashes (IDs), either loads them from local disk or asks the remote peer to provide them.
*
* A malicious response in which the data provided by the remote peer does not hash to the requested hash results in
* [FetchDataFlow.DownloadedVsRequestedDataMismatch] being thrown. If the remote peer doesn't have an entry, it
* results in a [FetchDataFlow.HashNotFound] exception. Note that returned transactions are not inserted into
* the database, because it's up to the caller to actually verify the transactions are valid.
*/
class FetchTransactionsFlow(requests: Set<SecureHash>, otherSide: Party) :
FetchDataFlow<SignedTransaction, SignedTransaction>(requests, otherSide, DataType.TRANSACTION) {
override fun load(txid: SecureHash): SignedTransaction? = serviceHub.validatedTransactions.getTransaction(txid)
}

View File

@ -1,10 +1,12 @@
package net.corda.core.internal package net.corda.core.internal
import co.paralleluniverse.fibers.Suspendable import co.paralleluniverse.fibers.Suspendable
import com.google.common.util.concurrent.ListenableFuture import net.corda.core.concurrent.CordaFuture
import net.corda.core.crypto.SecureHash import net.corda.core.crypto.SecureHash
import net.corda.core.flows.FlowContext
import net.corda.core.flows.FlowInitiator import net.corda.core.flows.FlowInitiator
import net.corda.core.flows.FlowLogic import net.corda.core.flows.FlowLogic
import net.corda.core.flows.FlowStackSnapshot
import net.corda.core.flows.StateMachineRunId import net.corda.core.flows.StateMachineRunId
import net.corda.core.identity.Party import net.corda.core.identity.Party
import net.corda.core.node.ServiceHub import net.corda.core.node.ServiceHub
@ -14,6 +16,9 @@ import org.slf4j.Logger
/** This is an internal interface that is implemented by code in the node module. You should look at [FlowLogic]. */ /** This is an internal interface that is implemented by code in the node module. You should look at [FlowLogic]. */
interface FlowStateMachine<R> { interface FlowStateMachine<R> {
@Suspendable
fun getFlowContext(otherParty: Party, sessionFlow: FlowLogic<*>): FlowContext
@Suspendable @Suspendable
fun <T : Any> sendAndReceive(receiveType: Class<T>, fun <T : Any> sendAndReceive(receiveType: Class<T>,
otherParty: Party, otherParty: Party,
@ -25,18 +30,37 @@ interface FlowStateMachine<R> {
fun <T : Any> receive(receiveType: Class<T>, otherParty: Party, sessionFlow: FlowLogic<*>): UntrustworthyData<T> fun <T : Any> receive(receiveType: Class<T>, otherParty: Party, sessionFlow: FlowLogic<*>): UntrustworthyData<T>
@Suspendable @Suspendable
fun send(otherParty: Party, payload: Any, sessionFlow: FlowLogic<*>) fun send(otherParty: Party, payload: Any, sessionFlow: FlowLogic<*>): Unit
@Suspendable @Suspendable
fun waitForLedgerCommit(hash: SecureHash, sessionFlow: FlowLogic<*>): SignedTransaction fun waitForLedgerCommit(hash: SecureHash, sessionFlow: FlowLogic<*>): SignedTransaction
fun checkFlowPermission(permissionName: String, extraAuditData: Map<String,String>) fun checkFlowPermission(permissionName: String, extraAuditData: Map<String, String>): Unit
fun recordAuditEvent(eventType: String, comment: String, extraAuditData: Map<String,String>) fun recordAuditEvent(eventType: String, comment: String, extraAuditData: Map<String, String>): Unit
/**
* Returns a shallow copy of the Quasar stack frames at the time of call to [flowStackSnapshot]. Use this to inspect
* what objects would be serialised at the time of call to a suspending action (e.g. send/receive).
*/
@Suspendable
fun flowStackSnapshot(flowClass: Class<*>): FlowStackSnapshot?
/**
* Persists a shallow copy of the Quasar stack frames at the time of call to [persistFlowStackSnapshot].
* Use this to track the monitor evolution of the quasar stack values during the flow execution.
* The flow stack snapshot is stored in a file located in {baseDir}/flowStackSnapshots/YYYY-MM-DD/{flowId}/
* where baseDir is the node running directory and flowId is the flow unique identifier generated by the platform.
*
* Note: With respect to the [flowStackSnapshot], the snapshot being persisted by this method is partial,
* meaning that only flow relevant traces and local variables are persisted.
*/
@Suspendable
fun persistFlowStackSnapshot(flowClass: Class<*>): Unit
val serviceHub: ServiceHub val serviceHub: ServiceHub
val logger: Logger val logger: Logger
val id: StateMachineRunId val id: StateMachineRunId
val resultFuture: ListenableFuture<R> val resultFuture: CordaFuture<R>
val flowInitiator: FlowInitiator val flowInitiator: FlowInitiator
} }

View File

@ -0,0 +1,258 @@
package net.corda.core.internal
import net.corda.core.crypto.SecureHash
import net.corda.core.crypto.sha256
import org.slf4j.Logger
import rx.Observable
import rx.Observer
import rx.subjects.PublishSubject
import rx.subjects.UnicastSubject
import java.io.*
import java.lang.reflect.Field
import java.math.BigDecimal
import java.nio.charset.Charset
import java.nio.charset.StandardCharsets.UTF_8
import java.nio.file.*
import java.nio.file.attribute.FileAttribute
import java.time.Duration
import java.time.temporal.Temporal
import java.util.*
import java.util.Spliterator.*
import java.util.stream.IntStream
import java.util.stream.Stream
import java.util.stream.StreamSupport
import java.util.zip.Deflater
import java.util.zip.ZipEntry
import java.util.zip.ZipOutputStream
import kotlin.reflect.KClass
val Throwable.rootCause: Throwable get() = cause?.rootCause ?: this
fun Throwable.getStackTraceAsString() = StringWriter().also { printStackTrace(PrintWriter(it)) }.toString()
infix fun Temporal.until(endExclusive: Temporal): Duration = Duration.between(this, endExclusive)
operator fun Duration.div(divider: Long): Duration = dividedBy(divider)
operator fun Duration.times(multiplicand: Long): Duration = multipliedBy(multiplicand)
/**
* Allows you to write code like: Paths.get("someDir") / "subdir" / "filename" but using the Paths API to avoid platform
* separator problems.
*/
operator fun Path.div(other: String): Path = resolve(other)
operator fun String.div(other: String): Path = Paths.get(this) / other
/**
* Returns the single element matching the given [predicate], or `null` if the collection is empty, or throws exception
* if more than one element was found.
*/
inline fun <T> Iterable<T>.noneOrSingle(predicate: (T) -> Boolean): T? {
val iterator = iterator()
for (single in iterator) {
if (predicate(single)) {
while (iterator.hasNext()) {
if (predicate(iterator.next())) throw IllegalArgumentException("Collection contains more than one matching element.")
}
return single
}
}
return null
}
/**
* Returns the single element, or `null` if the list is empty, or throws an exception if it has more than one element.
*/
fun <T> List<T>.noneOrSingle(): T? {
return when (size) {
0 -> null
1 -> this[0]
else -> throw IllegalArgumentException("List has more than one element.")
}
}
/** Returns a random element in the list, or `null` if empty */
fun <T> List<T>.randomOrNull(): T? {
return when (size) {
0 -> null
1 -> this[0]
else -> this[(Math.random() * size).toInt()]
}
}
/** Returns the index of the given item or throws [IllegalArgumentException] if not found. */
fun <T> List<T>.indexOfOrThrow(item: T): Int {
val i = indexOf(item)
require(i != -1)
return i
}
fun Path.createDirectory(vararg attrs: FileAttribute<*>): Path = Files.createDirectory(this, *attrs)
fun Path.createDirectories(vararg attrs: FileAttribute<*>): Path = Files.createDirectories(this, *attrs)
fun Path.exists(vararg options: LinkOption): Boolean = Files.exists(this, *options)
fun Path.copyToDirectory(targetDir: Path, vararg options: CopyOption): Path {
require(targetDir.isDirectory()) { "$targetDir is not a directory" }
val targetFile = targetDir.resolve(fileName)
Files.copy(this, targetFile, *options)
return targetFile
}
fun Path.moveTo(target: Path, vararg options: CopyOption): Path = Files.move(this, target, *options)
fun Path.isRegularFile(vararg options: LinkOption): Boolean = Files.isRegularFile(this, *options)
fun Path.isDirectory(vararg options: LinkOption): Boolean = Files.isDirectory(this, *options)
inline val Path.size: Long get() = Files.size(this)
inline fun <R> Path.list(block: (Stream<Path>) -> R): R = Files.list(this).use(block)
fun Path.deleteIfExists(): Boolean = Files.deleteIfExists(this)
fun Path.readAll(): ByteArray = Files.readAllBytes(this)
inline fun <R> Path.read(vararg options: OpenOption, block: (InputStream) -> R): R = Files.newInputStream(this, *options).use(block)
inline fun Path.write(createDirs: Boolean = false, vararg options: OpenOption = emptyArray(), block: (OutputStream) -> Unit) {
if (createDirs) {
normalize().parent?.createDirectories()
}
Files.newOutputStream(this, *options).use(block)
}
inline fun <R> Path.readLines(charset: Charset = UTF_8, block: (Stream<String>) -> R): R = Files.lines(this, charset).use(block)
fun Path.readAllLines(charset: Charset = UTF_8): List<String> = Files.readAllLines(this, charset)
fun Path.writeLines(lines: Iterable<CharSequence>, charset: Charset = UTF_8, vararg options: OpenOption): Path = Files.write(this, lines, charset, *options)
fun InputStream.copyTo(target: Path, vararg options: CopyOption): Long = Files.copy(this, target, *options)
fun String.abbreviate(maxWidth: Int): String = if (length <= maxWidth) this else take(maxWidth - 1) + ""
/** Return the sum of an Iterable of [BigDecimal]s. */
fun Iterable<BigDecimal>.sum(): BigDecimal = fold(BigDecimal.ZERO) { a, b -> a + b }
/**
* Returns an Observable that buffers events until subscribed.
* @see UnicastSubject
*/
fun <T> Observable<T>.bufferUntilSubscribed(): Observable<T> {
val subject = UnicastSubject.create<T>()
val subscription = subscribe(subject)
return subject.doOnUnsubscribe { subscription.unsubscribe() }
}
/** Copy an [Observer] to multiple other [Observer]s. */
fun <T> Observer<T>.tee(vararg teeTo: Observer<T>): Observer<T> {
val subject = PublishSubject.create<T>()
subject.subscribe(this)
teeTo.forEach { subject.subscribe(it) }
return subject
}
/** Executes the given code block and returns a [Duration] of how long it took to execute in nanosecond precision. */
inline fun elapsedTime(block: () -> Unit): Duration {
val start = System.nanoTime()
block()
val end = System.nanoTime()
return Duration.ofNanos(end - start)
}
fun <T> Logger.logElapsedTime(label: String, body: () -> T): T = logElapsedTime(label, this, body)
// TODO: Add inline back when a new Kotlin version is released and check if the java.lang.VerifyError
// returns in the IRSSimulationTest. If not, commit the inline back.
fun <T> logElapsedTime(label: String, logger: Logger? = null, body: () -> T): T {
// Use nanoTime as it's monotonic.
val now = System.nanoTime()
try {
return body()
} finally {
val elapsed = Duration.ofNanos(System.nanoTime() - now).toMillis()
if (logger != null)
logger.info("$label took $elapsed msec")
else
println("$label took $elapsed msec")
}
}
/** Convert a [ByteArrayOutputStream] to [InputStreamAndHash]. */
fun ByteArrayOutputStream.toInputStreamAndHash(): InputStreamAndHash {
val bytes = toByteArray()
return InputStreamAndHash(ByteArrayInputStream(bytes), bytes.sha256())
}
data class InputStreamAndHash(val inputStream: InputStream, val sha256: SecureHash.SHA256) {
companion object {
/**
* Get a valid InputStream from an in-memory zip as required for some tests. The zip consists of a single file
* called "z" that contains the given content byte repeated the given number of times.
* Note that a slightly bigger than numOfExpectedBytes size is expected.
*/
fun createInMemoryTestZip(numOfExpectedBytes: Int, content: Byte): InputStreamAndHash {
require(numOfExpectedBytes > 0)
val baos = ByteArrayOutputStream()
ZipOutputStream(baos).use { zos ->
val arraySize = 1024
val bytes = ByteArray(arraySize) { content }
val n = (numOfExpectedBytes - 1) / arraySize + 1 // same as Math.ceil(numOfExpectedBytes/arraySize).
zos.setLevel(Deflater.NO_COMPRESSION)
zos.putNextEntry(ZipEntry("z"))
for (i in 0 until n) {
zos.write(bytes, 0, arraySize)
}
zos.closeEntry()
}
return baos.toInputStreamAndHash()
}
}
}
fun IntIterator.toJavaIterator(): PrimitiveIterator.OfInt {
return object : PrimitiveIterator.OfInt {
override fun nextInt() = this@toJavaIterator.nextInt()
override fun hasNext() = this@toJavaIterator.hasNext()
override fun remove() = throw UnsupportedOperationException("remove")
}
}
private fun IntProgression.toSpliterator(): Spliterator.OfInt {
val spliterator = Spliterators.spliterator(
iterator().toJavaIterator(),
(1 + (last - first) / step).toLong(),
SUBSIZED or IMMUTABLE or NONNULL or SIZED or ORDERED or SORTED or DISTINCT
)
return if (step > 0) spliterator else object : Spliterator.OfInt by spliterator {
override fun getComparator() = Comparator.reverseOrder<Int>()
}
}
fun IntProgression.stream(parallel: Boolean = false): IntStream = StreamSupport.intStream(toSpliterator(), parallel)
@Suppress("UNCHECKED_CAST") // When toArray has filled in the array, the component type is no longer T? but T (that may itself be nullable).
inline fun <reified T> Stream<out T>.toTypedArray() = toArray { size -> arrayOfNulls<T>(size) } as Array<T>
fun <T> Class<T>.castIfPossible(obj: Any): T? = if (isInstance(obj)) cast(obj) else null
/** Returns a [DeclaredField] wrapper around the declared (possibly non-public) static field of the receiver [Class]. */
fun <T> Class<*>.staticField(name: String): DeclaredField<T> = DeclaredField(this, name, null)
/** Returns a [DeclaredField] wrapper around the declared (possibly non-public) static field of the receiver [KClass]. */
fun <T> KClass<*>.staticField(name: String): DeclaredField<T> = DeclaredField(java, name, null)
/** Returns a [DeclaredField] wrapper around the declared (possibly non-public) instance field of the receiver object. */
fun <T> Any.declaredField(name: String): DeclaredField<T> = DeclaredField(javaClass, name, this)
/**
* Returns a [DeclaredField] wrapper around the (possibly non-public) instance field of the receiver object, but declared
* in its superclass [clazz].
*/
fun <T> Any.declaredField(clazz: KClass<*>, name: String): DeclaredField<T> = DeclaredField(clazz.java, name, this)
/**
* A simple wrapper around a [Field] object providing type safe read and write access using [value], ignoring the field's
* visibility.
*/
class DeclaredField<T>(clazz: Class<*>, name: String, private val receiver: Any?) {
private val javaField = clazz.getDeclaredField(name).apply { isAccessible = true }
var value: T
@Suppress("UNCHECKED_CAST")
get() = javaField.get(receiver) as T
set(value) = javaField.set(receiver, value)
}
/** The annotated object would have a more restricted visibility were it not needed in tests. */
@Target(AnnotationTarget.CLASS,
AnnotationTarget.PROPERTY,
AnnotationTarget.CONSTRUCTOR,
AnnotationTarget.FUNCTION,
AnnotationTarget.TYPEALIAS)
@Retention(AnnotationRetention.SOURCE)
@MustBeDocumented
annotation class VisibleForTesting

Some files were not shown because too many files have changed in this diff Show More