Cleaned up NonEmptySet and expanded its usage in the codebase

This commit is contained in:
Shams Asari
2017-07-11 19:36:56 +01:00
parent 0ec6f31f94
commit fa4577d236
21 changed files with 205 additions and 280 deletions

View File

@ -4,6 +4,7 @@ import net.corda.core.identity.Party
import net.corda.core.serialization.CordaSerializable import net.corda.core.serialization.CordaSerializable
import net.corda.core.transactions.LedgerTransaction import net.corda.core.transactions.LedgerTransaction
import net.corda.core.transactions.TransactionBuilder import net.corda.core.transactions.TransactionBuilder
import net.corda.core.utilities.toNonEmptySet
import java.security.PublicKey import java.security.PublicKey
/** Defines transaction build & validation logic for a specific transaction type */ /** Defines transaction build & validation logic for a specific transaction type */
@ -20,7 +21,7 @@ sealed class TransactionType {
fun verify(tx: LedgerTransaction) { fun verify(tx: LedgerTransaction) {
require(tx.notary != null || tx.timeWindow == null) { "Transactions with time-windows must be notarised" } require(tx.notary != null || tx.timeWindow == null) { "Transactions with time-windows must be notarised" }
val duplicates = detectDuplicateInputs(tx) val duplicates = detectDuplicateInputs(tx)
if (duplicates.isNotEmpty()) throw TransactionVerificationException.DuplicateInputStates(tx.id, duplicates) if (duplicates.isNotEmpty()) throw TransactionVerificationException.DuplicateInputStates(tx.id, duplicates.toNonEmptySet())
val missing = verifySigners(tx) val missing = verifySigners(tx)
if (missing.isNotEmpty()) throw TransactionVerificationException.SignersMissing(tx.id, missing.toList()) if (missing.isNotEmpty()) throw TransactionVerificationException.SignersMissing(tx.id, missing.toList())
verifyTransaction(tx) verifyTransaction(tx)
@ -51,7 +52,7 @@ sealed class TransactionType {
} }
/** /**
* Return the list of public keys that that require signatures for the transaction type. * Return the set of public keys that require signatures for the transaction type.
* Note: the notary key is checked separately for all transactions and need not be included. * Note: the notary key is checked separately for all transactions and need not be included.
*/ */
abstract fun getRequiredSigners(tx: LedgerTransaction): Set<PublicKey> abstract fun getRequiredSigners(tx: LedgerTransaction): Set<PublicKey>

View File

@ -4,6 +4,7 @@ import net.corda.core.crypto.SecureHash
import net.corda.core.flows.FlowException import net.corda.core.flows.FlowException
import net.corda.core.identity.Party import net.corda.core.identity.Party
import net.corda.core.serialization.CordaSerializable import net.corda.core.serialization.CordaSerializable
import net.corda.core.utilities.NonEmptySet
import java.security.PublicKey import java.security.PublicKey
import java.util.* import java.util.*
@ -107,7 +108,7 @@ sealed class TransactionVerificationException(val txId: SecureHash, cause: Throw
override fun toString(): String = "Signers missing: ${missing.joinToString()}" override fun toString(): String = "Signers missing: ${missing.joinToString()}"
} }
class DuplicateInputStates(txId: SecureHash, val duplicates: Set<StateRef>) : TransactionVerificationException(txId, null) { class DuplicateInputStates(txId: SecureHash, val duplicates: NonEmptySet<StateRef>) : TransactionVerificationException(txId, null) {
override fun toString(): String = "Duplicate inputs: ${duplicates.joinToString()}" override fun toString(): String = "Duplicate inputs: ${duplicates.joinToString()}"
} }

View File

@ -6,6 +6,7 @@ import net.corda.core.node.services.ServiceInfo
import net.corda.core.node.services.ServiceType import net.corda.core.node.services.ServiceType
import net.corda.core.serialization.CordaSerializable import net.corda.core.serialization.CordaSerializable
import net.corda.core.utilities.NetworkHostAndPort import net.corda.core.utilities.NetworkHostAndPort
import net.corda.core.utilities.NonEmptySet
/** /**
* Information for an advertised service including the service specific identity information. * Information for an advertised service including the service specific identity information.
@ -21,7 +22,7 @@ data class ServiceEntry(val info: ServiceInfo, val identity: PartyAndCertificate
@CordaSerializable @CordaSerializable
data class NodeInfo(val addresses: List<NetworkHostAndPort>, data class NodeInfo(val addresses: List<NetworkHostAndPort>,
val legalIdentityAndCert: PartyAndCertificate, //TODO This field will be removed in future PR which gets rid of services. val legalIdentityAndCert: PartyAndCertificate, //TODO This field will be removed in future PR which gets rid of services.
val legalIdentitiesAndCerts: Set<PartyAndCertificate>, val legalIdentitiesAndCerts: NonEmptySet<PartyAndCertificate>,
val platformVersion: Int, val platformVersion: Int,
var advertisedServices: List<ServiceEntry> = emptyList(), var advertisedServices: List<ServiceEntry> = emptyList(),
val worldMapLocation: WorldMapLocation? = null) { val worldMapLocation: WorldMapLocation? = null) {

View File

@ -17,11 +17,12 @@ import net.corda.core.node.services.vault.QueryCriteria
import net.corda.core.node.services.vault.Sort import net.corda.core.node.services.vault.Sort
import net.corda.core.node.services.vault.DEFAULT_PAGE_SIZE import net.corda.core.node.services.vault.DEFAULT_PAGE_SIZE
import net.corda.core.serialization.CordaSerializable import net.corda.core.serialization.CordaSerializable
import net.corda.core.utilities.OpaqueBytes
import net.corda.core.toFuture import net.corda.core.toFuture
import net.corda.core.transactions.LedgerTransaction import net.corda.core.transactions.LedgerTransaction
import net.corda.core.transactions.TransactionBuilder import net.corda.core.transactions.TransactionBuilder
import net.corda.core.transactions.WireTransaction import net.corda.core.transactions.WireTransaction
import net.corda.core.utilities.NonEmptySet
import net.corda.core.utilities.OpaqueBytes
import net.corda.flows.AnonymisedIdentity import net.corda.flows.AnonymisedIdentity
import rx.Observable import rx.Observable
import rx.subjects.PublishSubject import rx.subjects.PublishSubject
@ -294,7 +295,7 @@ interface VaultService {
* @throws [StatesNotAvailableException] when not possible to softLock all of requested [StateRef] * @throws [StatesNotAvailableException] when not possible to softLock all of requested [StateRef]
*/ */
@Throws(StatesNotAvailableException::class) @Throws(StatesNotAvailableException::class)
fun softLockReserve(lockId: UUID, stateRefs: Set<StateRef>) fun softLockReserve(lockId: UUID, stateRefs: NonEmptySet<StateRef>)
/** /**
* Release all or an explicitly specified set of [StateRef] for a given [UUID] unique identifier. * Release all or an explicitly specified set of [StateRef] for a given [UUID] unique identifier.
@ -303,7 +304,7 @@ interface VaultService {
* In the case of coin selection, softLock are automatically released once previously gathered unconsumed input refs * In the case of coin selection, softLock are automatically released once previously gathered unconsumed input refs
* are consumed as part of cash spending. * are consumed as part of cash spending.
*/ */
fun softLockRelease(lockId: UUID, stateRefs: Set<StateRef>? = null) fun softLockRelease(lockId: UUID, stateRefs: NonEmptySet<StateRef>? = null)
/** /**
* Retrieve softLockStates for a given [UUID] or return all softLockStates in vault for a given * Retrieve softLockStates for a given [UUID] or return all softLockStates in vault for a given
@ -318,7 +319,11 @@ interface VaultService {
* is implemented in a separate module (finance) and requires access to it. * is implemented in a separate module (finance) and requires access to it.
*/ */
@Suspendable @Suspendable
fun <T : ContractState> unconsumedStatesForSpending(amount: Amount<Currency>, onlyFromIssuerParties: Set<AbstractParty>? = null, notary: Party? = null, lockId: UUID, withIssuerRefs: Set<OpaqueBytes>? = null): List<StateAndRef<T>> fun <T : ContractState> unconsumedStatesForSpending(amount: Amount<Currency>,
onlyFromIssuerParties: Set<AbstractParty>? = null,
notary: Party? = null,
lockId: UUID,
withIssuerRefs: Set<OpaqueBytes>? = null): List<StateAndRef<T>>
} }
// TODO: Remove this from the interface // TODO: Remove this from the interface

View File

@ -1,6 +1,9 @@
package net.corda.core.serialization package net.corda.core.serialization
import com.esotericsoftware.kryo.Kryo import com.esotericsoftware.kryo.Kryo
import com.esotericsoftware.kryo.Serializer
import com.esotericsoftware.kryo.io.Input
import com.esotericsoftware.kryo.io.Output
import com.esotericsoftware.kryo.serializers.CompatibleFieldSerializer import com.esotericsoftware.kryo.serializers.CompatibleFieldSerializer
import com.esotericsoftware.kryo.serializers.FieldSerializer import com.esotericsoftware.kryo.serializers.FieldSerializer
import com.esotericsoftware.kryo.util.MapReferenceResolver import com.esotericsoftware.kryo.util.MapReferenceResolver
@ -8,13 +11,13 @@ import de.javakaffee.kryoserializers.ArraysAsListSerializer
import de.javakaffee.kryoserializers.BitSetSerializer import de.javakaffee.kryoserializers.BitSetSerializer
import de.javakaffee.kryoserializers.UnmodifiableCollectionsSerializer import de.javakaffee.kryoserializers.UnmodifiableCollectionsSerializer
import de.javakaffee.kryoserializers.guava.* import de.javakaffee.kryoserializers.guava.*
import net.corda.core.crypto.composite.CompositeKey
import net.corda.core.crypto.MetaData import net.corda.core.crypto.MetaData
import net.corda.core.crypto.composite.CompositeKey
import net.corda.core.node.CordaPluginRegistry import net.corda.core.node.CordaPluginRegistry
import net.corda.core.transactions.SignedTransaction import net.corda.core.transactions.SignedTransaction
import net.corda.core.transactions.WireTransaction import net.corda.core.transactions.WireTransaction
import net.corda.core.utilities.NonEmptySet import net.corda.core.utilities.NonEmptySet
import net.corda.core.utilities.NonEmptySetSerializer import net.corda.core.utilities.toNonEmptySet
import net.i2p.crypto.eddsa.EdDSAPrivateKey import net.i2p.crypto.eddsa.EdDSAPrivateKey
import net.i2p.crypto.eddsa.EdDSAPublicKey import net.i2p.crypto.eddsa.EdDSAPublicKey
import org.bouncycastle.asn1.x500.X500Name import org.bouncycastle.asn1.x500.X500Name
@ -36,6 +39,7 @@ import java.io.InputStream
import java.lang.reflect.Modifier.isPublic import java.lang.reflect.Modifier.isPublic
import java.security.cert.CertPath import java.security.cert.CertPath
import java.util.* import java.util.*
import kotlin.collections.ArrayList
object DefaultKryoCustomizer { object DefaultKryoCustomizer {
private val pluginRegistries: List<CordaPluginRegistry> by lazy { private val pluginRegistries: List<CordaPluginRegistry> by lazy {
@ -128,4 +132,22 @@ object DefaultKryoCustomizer {
return strat.newInstantiatorOf(type) return strat.newInstantiatorOf(type)
} }
} }
private object NonEmptySetSerializer : Serializer<NonEmptySet<Any>>() {
override fun write(kryo: Kryo, output: Output, obj: NonEmptySet<Any>) {
// Write out the contents as normal
output.writeInt(obj.size, true)
obj.forEach { kryo.writeClassAndObject(output, it) }
}
override fun read(kryo: Kryo, input: Input, type: Class<NonEmptySet<Any>>): NonEmptySet<Any> {
val size = input.readInt(true)
require(size >= 1) { "Invalid size read off the wire: $size" }
val list = ArrayList<Any>(size)
repeat(size) {
list += kryo.readClassAndObject(input)
}
return list.toNonEmptySet()
}
}
} }

View File

@ -7,10 +7,11 @@ import net.corda.core.contracts.TransactionVerificationException
import net.corda.core.crypto.DigitalSignature import net.corda.core.crypto.DigitalSignature
import net.corda.core.crypto.SecureHash import net.corda.core.crypto.SecureHash
import net.corda.core.crypto.isFulfilledBy import net.corda.core.crypto.isFulfilledBy
import net.corda.core.crypto.keys
import net.corda.core.node.ServiceHub import net.corda.core.node.ServiceHub
import net.corda.core.serialization.CordaSerializable import net.corda.core.serialization.CordaSerializable
import net.corda.core.serialization.SerializedBytes import net.corda.core.serialization.SerializedBytes
import net.corda.core.utilities.NonEmptySet
import net.corda.core.utilities.toNonEmptySet
import java.security.PublicKey import java.security.PublicKey
import java.security.SignatureException import java.security.SignatureException
import java.util.* import java.util.*
@ -50,11 +51,8 @@ data class SignedTransaction(val txBits: SerializedBytes<WireTransaction>,
override val id: SecureHash get() = tx.id override val id: SecureHash get() = tx.id
@CordaSerializable @CordaSerializable
class SignaturesMissingException(val missing: Set<PublicKey>, val descriptions: List<String>, override val id: SecureHash) : NamedByHash, SignatureException() { class SignaturesMissingException(val missing: NonEmptySet<PublicKey>, val descriptions: List<String>, override val id: SecureHash)
override fun toString(): String { : NamedByHash, SignatureException("Missing signatures for $descriptions on transaction ${id.prefixChars()} for ${missing.joinToString()}")
return "Missing signatures for $descriptions on transaction ${id.prefixChars()} for ${missing.joinToString()}"
}
}
/** /**
* Verifies the signatures on this transaction and throws if any are missing which aren't passed as parameters. * Verifies the signatures on this transaction and throws if any are missing which aren't passed as parameters.
@ -80,7 +78,7 @@ data class SignedTransaction(val txBits: SerializedBytes<WireTransaction>,
val allowed = allowedToBeMissing.toSet() val allowed = allowedToBeMissing.toSet()
val needed = missing - allowed val needed = missing - allowed
if (needed.isNotEmpty()) if (needed.isNotEmpty())
throw SignaturesMissingException(needed, getMissingKeyDescriptions(needed), id) throw SignaturesMissingException(needed.toNonEmptySet(), getMissingKeyDescriptions(needed), id)
} }
check(tx.id == id) check(tx.id == id)
return tx return tx

View File

@ -19,3 +19,6 @@ inline fun Logger.trace(msg: () -> String) {
inline fun Logger.debug(msg: () -> String) { inline fun Logger.debug(msg: () -> String) {
if (isDebugEnabled) debug(msg()) if (isDebugEnabled) debug(msg())
} }
/** @see NonEmptySet.copyOf */
fun <T> Collection<T>.toNonEmptySet(): NonEmptySet<T> = NonEmptySet.copyOf(this)

View File

@ -1,117 +1,63 @@
package net.corda.core.utilities package net.corda.core.utilities
import com.esotericsoftware.kryo.Kryo import com.google.common.collect.Iterators
import com.esotericsoftware.kryo.Serializer
import com.esotericsoftware.kryo.io.Input
import com.esotericsoftware.kryo.io.Output
import java.util.* import java.util.*
import java.util.function.Consumer
import java.util.stream.Stream
/** /**
* A set which is constrained to ensure it can never be empty. An initial value must be provided at * An immutable ordered non-empty set.
* construction, and attempting to remove the last element will cause an IllegalStateException.
* The underlying set is exposed for Kryo to access, but should not be accessed directly.
*/ */
class NonEmptySet<T>(initial: T) : MutableSet<T> { class NonEmptySet<T> private constructor(private val elements: Set<T>) : Set<T> by elements {
private val set: MutableSet<T> = HashSet() companion object {
/**
* Returns a singleton set containing [element]. This behaves the same as [Collections.singleton] but returns a
* [NonEmptySet] for the extra type-safety.
*/
@JvmStatic
fun <T> of(element: T): NonEmptySet<T> = NonEmptySet(Collections.singleton(element))
init { /** Returns a non-empty set containing the given elements, minus duplicates, in the order each was specified. */
set.add(initial) @JvmStatic
fun <T> of(first: T, second: T, vararg rest: T): NonEmptySet<T> {
val elements = LinkedHashSet<T>(rest.size + 2)
elements += first
elements += second
elements.addAll(rest)
return NonEmptySet(elements)
} }
override val size: Int /**
get() = set.size * Returns a non-empty set containing each of [elements], minus duplicates, in the order each appears first in
* the source collection.
* @throws IllegalArgumentException If [elements] is empty.
*/
@JvmStatic
fun <T> copyOf(elements: Collection<T>): NonEmptySet<T> {
if (elements is NonEmptySet) return elements
return when (elements.size) {
0 -> throw IllegalArgumentException("elements is empty")
1 -> of(elements.first())
else -> {
val copy = LinkedHashSet<T>(elements.size)
elements.forEach { copy += it } // Can't use Collection.addAll as it doesn't specify insertion order
NonEmptySet(copy)
}
}
}
}
override fun add(element: T): Boolean = set.add(element) /** Returns the first element of the set. */
override fun addAll(elements: Collection<T>): Boolean = set.addAll(elements) fun head(): T = elements.iterator().next()
override fun clear() = throw UnsupportedOperationException()
override fun contains(element: T): Boolean = set.contains(element)
override fun containsAll(elements: Collection<T>): Boolean = set.containsAll(elements)
override fun isEmpty(): Boolean = false override fun isEmpty(): Boolean = false
override fun iterator(): Iterator<T> = Iterators.unmodifiableIterator(elements.iterator())
override fun iterator(): MutableIterator<T> = Iterator(set.iterator()) // Following methods are not delegated by Kotlin's Class delegation
override fun forEach(action: Consumer<in T>) = elements.forEach(action)
override fun remove(element: T): Boolean = override fun stream(): Stream<T> = elements.stream()
// Test either there's more than one element, or the removal is a no-op override fun parallelStream(): Stream<T> = elements.parallelStream()
if (size > 1) override fun spliterator(): Spliterator<T> = elements.spliterator()
set.remove(element) override fun equals(other: Any?): Boolean = other === this || other == elements
else if (!contains(element)) override fun hashCode(): Int = elements.hashCode()
false override fun toString(): String = elements.toString()
else
throw IllegalStateException()
override fun removeAll(elements: Collection<T>): Boolean =
if (size > elements.size)
set.removeAll(elements)
else if (!containsAll(elements))
// Remove the common elements
set.removeAll(elements)
else
throw IllegalStateException()
override fun retainAll(elements: Collection<T>): Boolean {
val iterator = iterator()
val ret = false
// The iterator will throw an IllegalStateException if we try removing the last element
while (iterator.hasNext()) {
if (!elements.contains(iterator.next())) {
iterator.remove()
}
}
return ret
}
override fun equals(other: Any?): Boolean =
if (other is Set<*>)
// Delegate down to the wrapped set's equals() function
set == other
else
false
override fun hashCode(): Int = set.hashCode()
override fun toString(): String = set.toString()
inner class Iterator<out T>(val iterator: MutableIterator<T>) : MutableIterator<T> {
override fun hasNext(): Boolean = iterator.hasNext()
override fun next(): T = iterator.next()
override fun remove() =
if (set.size > 1)
iterator.remove()
else
throw IllegalStateException()
}
}
fun <T> nonEmptySetOf(initial: T, vararg elements: T): NonEmptySet<T> {
val set = NonEmptySet(initial)
// We add the first element twice, but it's a set, so who cares
set.addAll(elements)
return set
}
/**
* Custom serializer which understands it has to read in an item before
* trying to construct the set.
*/
object NonEmptySetSerializer : Serializer<NonEmptySet<Any>>() {
override fun write(kryo: Kryo, output: Output, obj: NonEmptySet<Any>) {
// Write out the contents as normal
output.writeInt(obj.size)
obj.forEach { kryo.writeClassAndObject(output, it) }
}
override fun read(kryo: Kryo, input: Input, type: Class<NonEmptySet<Any>>): NonEmptySet<Any> {
val size = input.readInt()
require(size >= 1) { "Size is positive" }
// TODO: Is there an upper limit we can apply to how big one of these could be?
val first = kryo.readClassAndObject(input)
// Read the first item and use it to construct the NonEmptySet
val set = NonEmptySet(first)
// Read in the rest of the set
for (i in 2..size) {
set.add(kryo.readClassAndObject(input))
}
return set
}
} }

View File

@ -6,6 +6,7 @@ import net.corda.core.flows.InitiatingFlow
import net.corda.core.identity.Party import net.corda.core.identity.Party
import net.corda.core.serialization.CordaSerializable import net.corda.core.serialization.CordaSerializable
import net.corda.core.transactions.SignedTransaction import net.corda.core.transactions.SignedTransaction
import net.corda.core.utilities.NonEmptySet
/** /**
* Notify the specified parties about a transaction. The remote peers will download this transaction and its * Notify the specified parties about a transaction. The remote peers will download this transaction and its
@ -18,7 +19,7 @@ import net.corda.core.transactions.SignedTransaction
*/ */
@InitiatingFlow @InitiatingFlow
class BroadcastTransactionFlow(val notarisedTransaction: SignedTransaction, class BroadcastTransactionFlow(val notarisedTransaction: SignedTransaction,
val participants: Set<Party>) : FlowLogic<Unit>() { val participants: NonEmptySet<Party>) : FlowLogic<Unit>() {
@CordaSerializable @CordaSerializable
data class NotifyTxRequest(val tx: SignedTransaction) data class NotifyTxRequest(val tx: SignedTransaction)

View File

@ -11,6 +11,7 @@ import net.corda.core.node.ServiceHub
import net.corda.core.transactions.LedgerTransaction import net.corda.core.transactions.LedgerTransaction
import net.corda.core.transactions.SignedTransaction import net.corda.core.transactions.SignedTransaction
import net.corda.core.utilities.ProgressTracker import net.corda.core.utilities.ProgressTracker
import net.corda.core.utilities.toNonEmptySet
/** /**
* Verifies the given transactions, then sends them to the named notary. If the notary agrees that the transactions * Verifies the given transactions, then sends them to the named notary. If the notary agrees that the transactions
@ -65,7 +66,10 @@ class FinalityFlow(val transactions: Iterable<SignedTransaction>,
progressTracker.currentStep = BROADCASTING progressTracker.currentStep = BROADCASTING
val me = serviceHub.myInfo.legalIdentity val me = serviceHub.myInfo.legalIdentity
for ((stx, parties) in notarisedTxns) { for ((stx, parties) in notarisedTxns) {
subFlow(BroadcastTransactionFlow(stx, parties + extraRecipients - me)) val participants = parties + extraRecipients - me
if (participants.isNotEmpty()) {
subFlow(BroadcastTransactionFlow(stx, participants.toNonEmptySet()))
}
} }
return notarisedTxns.map { it.first } return notarisedTxns.map { it.first }
} }

View File

@ -4,119 +4,56 @@ import com.google.common.collect.testing.SetTestSuiteBuilder
import com.google.common.collect.testing.TestIntegerSetGenerator import com.google.common.collect.testing.TestIntegerSetGenerator
import com.google.common.collect.testing.features.CollectionFeature import com.google.common.collect.testing.features.CollectionFeature
import com.google.common.collect.testing.features.CollectionSize import com.google.common.collect.testing.features.CollectionSize
import com.google.common.collect.testing.testers.*
import junit.framework.TestSuite import junit.framework.TestSuite
import net.corda.core.serialization.deserialize import net.corda.core.serialization.deserialize
import net.corda.core.serialization.serialize import net.corda.core.serialization.serialize
import org.assertj.core.api.Assertions.assertThat
import org.assertj.core.api.Assertions.assertThatThrownBy
import org.junit.Test import org.junit.Test
import org.junit.runner.RunWith import org.junit.runner.RunWith
import org.junit.runners.Suite import org.junit.runners.Suite
import kotlin.test.assertEquals
@RunWith(Suite::class) @RunWith(Suite::class)
@Suite.SuiteClasses( @Suite.SuiteClasses(
NonEmptySetTest.Guava::class, NonEmptySetTest.Guava::class,
NonEmptySetTest.Remove::class, NonEmptySetTest.General::class
NonEmptySetTest.Serializer::class
) )
class NonEmptySetTest { class NonEmptySetTest {
/** object Guava {
* Guava test suite generator for NonEmptySet.
*/
class Guava {
companion object {
@JvmStatic @JvmStatic
fun suite(): TestSuite fun suite(): TestSuite {
= SetTestSuiteBuilder return SetTestSuiteBuilder
.using(NonEmptySetGenerator()) .using(NonEmptySetGenerator)
.named("test NonEmptySet with several values") .named("Guava test suite")
.withFeatures( .withFeatures(
CollectionSize.SEVERAL, CollectionSize.SEVERAL,
CollectionFeature.ALLOWS_NULL_VALUES, CollectionFeature.ALLOWS_NULL_VALUES,
CollectionFeature.FAILS_FAST_ON_CONCURRENT_MODIFICATION, CollectionFeature.KNOWN_ORDER
CollectionFeature.GENERAL_PURPOSE
) )
// Kotlin throws the wrong exception in this cases
.suppressing(CollectionAddAllTester::class.java.getMethod("testAddAll_nullCollectionReference"))
// Disable tests that try to remove everything:
.suppressing(CollectionRemoveAllTester::class.java.getMethod("testRemoveAll_nullCollectionReferenceNonEmptySubject"))
.suppressing(CollectionClearTester::class.java.methods.toList())
.suppressing(CollectionRetainAllTester::class.java.methods.toList())
.suppressing(CollectionRemoveIfTester::class.java.getMethod("testRemoveIf_allPresent"))
.createTestSuite() .createTestSuite()
} }
/**
* For some reason IntelliJ really wants to scan this class for tests and fail when
* it doesn't find any. This stops that error from occurring.
*/
@Test fun dummy() {
}
} }
/** class General {
* Test removal, which Guava's standard tests can't cover for us.
*/
class Remove {
@Test @Test
fun `construction`() { fun `copyOf - empty source`() {
val expected = 17 assertThatThrownBy { NonEmptySet.copyOf(HashSet<Int>()) }.isInstanceOf(IllegalArgumentException::class.java)
val basicSet = nonEmptySetOf(expected)
val actual = basicSet.first()
assertEquals(expected, actual)
}
@Test(expected = IllegalStateException::class)
fun `remove sole element`() {
val basicSet = nonEmptySetOf(-17)
basicSet.remove(-17)
} }
@Test @Test
fun `remove one of two elements`() { fun head() {
val basicSet = nonEmptySetOf(-17, 17) assertThat(NonEmptySet.of(1, 2).head()).isEqualTo(1)
basicSet.remove(-17)
} }
@Test
fun `remove element which does not exist`() {
val basicSet = nonEmptySetOf(-17)
basicSet.remove(-5)
assertEquals(1, basicSet.size)
}
@Test(expected = IllegalStateException::class)
fun `remove via iterator`() {
val basicSet = nonEmptySetOf(-17, 17)
val iterator = basicSet.iterator()
while (iterator.hasNext()) {
iterator.remove()
}
}
}
/**
* Test serialization/deserialization.
*/
class Serializer {
@Test @Test
fun `serialize deserialize`() { fun `serialize deserialize`() {
val expected: NonEmptySet<Int> = nonEmptySetOf(-17, 22, 17) val original = NonEmptySet.of(-17, 22, 17)
val serialized = expected.serialize().bytes val copy = original.serialize().deserialize()
val actual = serialized.deserialize<NonEmptySet<Int>>() assertThat(copy).isEqualTo(original).isNotSameAs(original)
assertEquals(expected, actual)
} }
} }
}
private object NonEmptySetGenerator : TestIntegerSetGenerator() {
/** override fun create(elements: Array<out Int?>): NonEmptySet<Int?> = NonEmptySet.copyOf(elements.asList())
* Generator of non empty set instances needed for testing.
*/
class NonEmptySetGenerator : TestIntegerSetGenerator() {
override fun create(elements: Array<out Int?>?): NonEmptySet<Int?>? {
val set = nonEmptySetOf(elements!!.first())
set.addAll(elements.toList())
return set
} }
} }

View File

@ -4,15 +4,15 @@ import net.corda.contracts.Commodity
import net.corda.contracts.NetType import net.corda.contracts.NetType
import net.corda.contracts.asset.Obligation.Lifecycle import net.corda.contracts.asset.Obligation.Lifecycle
import net.corda.core.contracts.* import net.corda.core.contracts.*
import net.corda.testing.contracts.DummyState
import net.corda.core.crypto.SecureHash import net.corda.core.crypto.SecureHash
import net.corda.core.hours
import net.corda.core.crypto.testing.NULL_PARTY import net.corda.core.crypto.testing.NULL_PARTY
import net.corda.core.hours
import net.corda.core.identity.AbstractParty import net.corda.core.identity.AbstractParty
import net.corda.core.identity.AnonymousParty import net.corda.core.identity.AnonymousParty
import net.corda.core.utilities.NonEmptySet
import net.corda.core.utilities.OpaqueBytes import net.corda.core.utilities.OpaqueBytes
import net.corda.core.utilities.*
import net.corda.testing.* import net.corda.testing.*
import net.corda.testing.contracts.DummyState
import net.corda.testing.node.MockServices import net.corda.testing.node.MockServices
import org.junit.Test import org.junit.Test
import java.time.Duration import java.time.Duration
@ -28,9 +28,9 @@ class ObligationTests {
val defaultRef = OpaqueBytes.of(1) val defaultRef = OpaqueBytes.of(1)
val defaultIssuer = MEGA_CORP.ref(defaultRef) val defaultIssuer = MEGA_CORP.ref(defaultRef)
val oneMillionDollars = 1000000.DOLLARS `issued by` defaultIssuer val oneMillionDollars = 1000000.DOLLARS `issued by` defaultIssuer
val trustedCashContract = nonEmptySetOf(SecureHash.randomSHA256() as SecureHash) val trustedCashContract = NonEmptySet.of(SecureHash.randomSHA256() as SecureHash)
val megaIssuedDollars = nonEmptySetOf(Issued(defaultIssuer, USD)) val megaIssuedDollars = NonEmptySet.of(Issued(defaultIssuer, USD))
val megaIssuedPounds = nonEmptySetOf(Issued(defaultIssuer, GBP)) val megaIssuedPounds = NonEmptySet.of(Issued(defaultIssuer, GBP))
val fivePm: Instant = TEST_TX_TIME.truncatedTo(ChronoUnit.DAYS) + 17.hours val fivePm: Instant = TEST_TX_TIME.truncatedTo(ChronoUnit.DAYS) + 17.hours
val sixPm: Instant = fivePm + 1.hours val sixPm: Instant = fivePm + 1.hours
val megaCorpDollarSettlement = Obligation.Terms(trustedCashContract, megaIssuedDollars, fivePm) val megaCorpDollarSettlement = Obligation.Terms(trustedCashContract, megaIssuedDollars, fivePm)
@ -500,7 +500,7 @@ class ObligationTests {
fun `commodity settlement`() { fun `commodity settlement`() {
val defaultFcoj = Issued(defaultIssuer, Commodity.getInstance("FCOJ")!!) val defaultFcoj = Issued(defaultIssuer, Commodity.getInstance("FCOJ")!!)
val oneUnitFcoj = Amount(1, defaultFcoj) val oneUnitFcoj = Amount(1, defaultFcoj)
val obligationDef = Obligation.Terms(nonEmptySetOf(CommodityContract().legalContractReference), nonEmptySetOf(defaultFcoj), TEST_TX_TIME) val obligationDef = Obligation.Terms(NonEmptySet.of(CommodityContract().legalContractReference), NonEmptySet.of(defaultFcoj), TEST_TX_TIME)
val oneUnitFcojObligation = Obligation.State(Obligation.Lifecycle.NORMAL, ALICE, val oneUnitFcojObligation = Obligation.State(Obligation.Lifecycle.NORMAL, ALICE,
obligationDef, oneUnitFcoj.quantity, NULL_PARTY) obligationDef, oneUnitFcoj.quantity, NULL_PARTY)
// Try settling a simple commodity obligation // Try settling a simple commodity obligation
@ -755,10 +755,10 @@ class ObligationTests {
// States must not be nettable if the cash contract differs // States must not be nettable if the cash contract differs
assertNotEquals(fiveKDollarsFromMegaToMega.bilateralNetState, assertNotEquals(fiveKDollarsFromMegaToMega.bilateralNetState,
fiveKDollarsFromMegaToMega.copy(template = megaCorpDollarSettlement.copy(acceptableContracts = nonEmptySetOf(SecureHash.randomSHA256()))).bilateralNetState) fiveKDollarsFromMegaToMega.copy(template = megaCorpDollarSettlement.copy(acceptableContracts = NonEmptySet.of(SecureHash.randomSHA256()))).bilateralNetState)
// States must not be nettable if the trusted issuers differ // States must not be nettable if the trusted issuers differ
val miniCorpIssuer = nonEmptySetOf(Issued(MINI_CORP.ref(1), USD)) val miniCorpIssuer = NonEmptySet.of(Issued(MINI_CORP.ref(1), USD))
assertNotEquals(fiveKDollarsFromMegaToMega.bilateralNetState, assertNotEquals(fiveKDollarsFromMegaToMega.bilateralNetState,
fiveKDollarsFromMegaToMega.copy(template = megaCorpDollarSettlement.copy(acceptableIssuedProducts = miniCorpIssuer)).bilateralNetState) fiveKDollarsFromMegaToMega.copy(template = megaCorpDollarSettlement.copy(acceptableIssuedProducts = miniCorpIssuer)).bilateralNetState)
} }
@ -875,7 +875,7 @@ class ObligationTests {
} }
val Issued<Currency>.OBLIGATION_DEF: Obligation.Terms<Currency> val Issued<Currency>.OBLIGATION_DEF: Obligation.Terms<Currency>
get() = Obligation.Terms(nonEmptySetOf(Cash().legalContractReference), nonEmptySetOf(this), TEST_TX_TIME) get() = Obligation.Terms(NonEmptySet.of(Cash().legalContractReference), NonEmptySet.of(this), TEST_TX_TIME)
val Amount<Issued<Currency>>.OBLIGATION: Obligation.State<Currency> val Amount<Issued<Currency>>.OBLIGATION: Obligation.State<Currency>
get() = Obligation.State(Obligation.Lifecycle.NORMAL, DUMMY_OBLIGATION_ISSUER, token.OBLIGATION_DEF, quantity, NULL_PARTY) get() = Obligation.State(Obligation.Lifecycle.NORMAL, DUMMY_OBLIGATION_ISSUER, token.OBLIGATION_DEF, quantity, NULL_PARTY)
} }

View File

@ -2,12 +2,12 @@ package net.corda.services.messaging
import com.google.common.util.concurrent.ListenableFuture import com.google.common.util.concurrent.ListenableFuture
import com.nhaarman.mockito_kotlin.whenever import com.nhaarman.mockito_kotlin.whenever
import net.corda.core.crypto.X509Utilities
import net.corda.core.crypto.cert import net.corda.core.crypto.cert
import net.corda.core.crypto.random63BitValue
import net.corda.core.getOrThrow import net.corda.core.getOrThrow
import net.corda.core.node.NodeInfo import net.corda.core.node.NodeInfo
import net.corda.core.crypto.random63BitValue
import net.corda.core.seconds import net.corda.core.seconds
import net.corda.core.utilities.NonEmptySet
import net.corda.node.internal.NetworkMapInfo import net.corda.node.internal.NetworkMapInfo
import net.corda.node.services.config.configureWithDevSSLCertificate import net.corda.node.services.config.configureWithDevSSLCertificate
import net.corda.node.services.messaging.sendRequest import net.corda.node.services.messaging.sendRequest
@ -30,7 +30,7 @@ class P2PSecurityTest : NodeBasedTest() {
@Test @Test
fun `incorrect legal name for the network map service config`() { fun `incorrect legal name for the network map service config`() {
val incorrectNetworkMapName = X509Utilities.getDevX509Name("NetworkMap-${random63BitValue()}") val incorrectNetworkMapName = getTestX509Name("NetworkMap-${random63BitValue()}")
val node = startNode(BOB.name, configOverrides = mapOf( val node = startNode(BOB.name, configOverrides = mapOf(
"networkMapService" to mapOf( "networkMapService" to mapOf(
"address" to networkMapNode.configuration.p2pAddress.toString(), "address" to networkMapNode.configuration.p2pAddress.toString(),
@ -67,7 +67,7 @@ class P2PSecurityTest : NodeBasedTest() {
private fun SimpleNode.registerWithNetworkMap(registrationName: X500Name): ListenableFuture<NetworkMapService.RegistrationResponse> { private fun SimpleNode.registerWithNetworkMap(registrationName: X500Name): ListenableFuture<NetworkMapService.RegistrationResponse> {
val legalIdentity = getTestPartyAndCertificate(registrationName, identity.public) val legalIdentity = getTestPartyAndCertificate(registrationName, identity.public)
val nodeInfo = NodeInfo(listOf(MOCK_HOST_AND_PORT), legalIdentity, setOf(legalIdentity), 1) val nodeInfo = NodeInfo(listOf(MOCK_HOST_AND_PORT), legalIdentity, NonEmptySet.of(legalIdentity), 1)
val registration = NodeRegistration(nodeInfo, System.currentTimeMillis(), AddOrRemove.ADD, Instant.MAX) val registration = NodeRegistration(nodeInfo, System.currentTimeMillis(), AddOrRemove.ADD, Instant.MAX)
val request = RegistrationRequest(registration.toWire(keyService, identity.public), network.myAddress) val request = RegistrationRequest(registration.toWire(keyService, identity.public), network.myAddress)
return network.sendRequest<NetworkMapService.RegistrationResponse>(NetworkMapService.REGISTER_TOPIC, request, networkMapNode.network.myAddress) return network.sendRequest<NetworkMapService.RegistrationResponse>(NetworkMapService.REGISTER_TOPIC, request, networkMapNode.network.myAddress)

View File

@ -26,6 +26,7 @@ import net.corda.core.serialization.deserialize
import net.corda.core.transactions.SignedTransaction import net.corda.core.transactions.SignedTransaction
import net.corda.core.utilities.NetworkHostAndPort import net.corda.core.utilities.NetworkHostAndPort
import net.corda.core.utilities.debug import net.corda.core.utilities.debug
import net.corda.core.utilities.toNonEmptySet
import net.corda.flows.* import net.corda.flows.*
import net.corda.node.services.* import net.corda.node.services.*
import net.corda.node.services.api.* import net.corda.node.services.api.*
@ -495,7 +496,7 @@ abstract class AbstractNode(open val configuration: NodeConfiguration,
private fun makeInfo(): NodeInfo { private fun makeInfo(): NodeInfo {
val advertisedServiceEntries = makeServiceEntries() val advertisedServiceEntries = makeServiceEntries()
val legalIdentity = obtainLegalIdentity() val legalIdentity = obtainLegalIdentity()
val allIdentitiesSet = advertisedServiceEntries.map { it.identity }.toSet() + legalIdentity val allIdentitiesSet = (advertisedServiceEntries.map { it.identity } + legalIdentity).toNonEmptySet()
val addresses = myAddresses() // TODO There is no support for multiple IP addresses yet. val addresses = myAddresses() // TODO There is no support for multiple IP addresses yet.
return NodeInfo(addresses, legalIdentity, allIdentitiesSet, platformVersion, advertisedServiceEntries, findMyLocation()) return NodeInfo(addresses, legalIdentity, allIdentitiesSet, platformVersion, advertisedServiceEntries, findMyLocation())
} }

View File

@ -33,10 +33,7 @@ import net.corda.core.serialization.storageKryo
import net.corda.core.tee import net.corda.core.tee
import net.corda.core.transactions.TransactionBuilder import net.corda.core.transactions.TransactionBuilder
import net.corda.core.transactions.WireTransaction import net.corda.core.transactions.WireTransaction
import net.corda.core.utilities.OpaqueBytes import net.corda.core.utilities.*
import net.corda.core.utilities.loggerFor
import net.corda.core.utilities.toHexString
import net.corda.core.utilities.trace
import net.corda.node.services.database.RequeryConfiguration import net.corda.node.services.database.RequeryConfiguration
import net.corda.node.services.statemachine.FlowStateMachineImpl import net.corda.node.services.statemachine.FlowStateMachineImpl
import net.corda.node.services.vault.schemas.requery.* import net.corda.node.services.vault.schemas.requery.*
@ -261,8 +258,7 @@ class NodeVaultService(private val services: ServiceHub, dataSourceProperties: P
} }
@Throws(StatesNotAvailableException::class) @Throws(StatesNotAvailableException::class)
override fun softLockReserve(lockId: UUID, stateRefs: Set<StateRef>) { override fun softLockReserve(lockId: UUID, stateRefs: NonEmptySet<StateRef>) {
if (stateRefs.isNotEmpty()) {
val softLockTimestamp = services.clock.instant() val softLockTimestamp = services.clock.instant()
val stateRefArgs = stateRefArgs(stateRefs) val stateRefArgs = stateRefArgs(stateRefs)
try { try {
@ -296,9 +292,8 @@ class NodeVaultService(private val services: ServiceHub, dataSourceProperties: P
if (e.cause is StatesNotAvailableException) throw (e.cause as StatesNotAvailableException) if (e.cause is StatesNotAvailableException) throw (e.cause as StatesNotAvailableException)
} }
} }
}
override fun softLockRelease(lockId: UUID, stateRefs: Set<StateRef>?) { override fun softLockRelease(lockId: UUID, stateRefs: NonEmptySet<StateRef>?) {
if (stateRefs == null) { if (stateRefs == null) {
session.withTransaction(TransactionIsolation.REPEATABLE_READ) { session.withTransaction(TransactionIsolation.REPEATABLE_READ) {
val update = update(VaultStatesEntity::class) val update = update(VaultStatesEntity::class)
@ -310,7 +305,7 @@ class NodeVaultService(private val services: ServiceHub, dataSourceProperties: P
log.trace("Releasing ${update.value()} soft locked states for $lockId") log.trace("Releasing ${update.value()} soft locked states for $lockId")
} }
} }
} else if (stateRefs.isNotEmpty()) { } else {
try { try {
session.withTransaction(TransactionIsolation.REPEATABLE_READ) { session.withTransaction(TransactionIsolation.REPEATABLE_READ) {
val updatedRows = update(VaultStatesEntity::class) val updatedRows = update(VaultStatesEntity::class)
@ -398,7 +393,7 @@ class NodeVaultService(private val services: ServiceHub, dataSourceProperties: P
log.trace("Coin selection for $amount retrieved ${stateAndRefs.count()} states totalling $totalPennies pennies: $stateAndRefs") log.trace("Coin selection for $amount retrieved ${stateAndRefs.count()} states totalling $totalPennies pennies: $stateAndRefs")
// update database // update database
softLockReserve(lockId, stateAndRefs.map { it.ref }.toSet()) softLockReserve(lockId, (stateAndRefs.map { it.ref }).toNonEmptySet())
return stateAndRefs return stateAndRefs
} }
log.trace("Coin selection requested $amount but retrieved $totalPennies pennies with state refs: ${stateAndRefs.map { it.ref }}") log.trace("Coin selection requested $amount but retrieved $totalPennies pennies with state refs: ${stateAndRefs.map { it.ref }}")

View File

@ -4,7 +4,9 @@ import net.corda.core.contracts.StateRef
import net.corda.core.flows.FlowLogic import net.corda.core.flows.FlowLogic
import net.corda.core.flows.StateMachineRunId import net.corda.core.flows.StateMachineRunId
import net.corda.core.node.services.VaultService import net.corda.core.node.services.VaultService
import net.corda.core.utilities.NonEmptySet
import net.corda.core.utilities.loggerFor import net.corda.core.utilities.loggerFor
import net.corda.core.utilities.toNonEmptySet
import net.corda.core.utilities.trace import net.corda.core.utilities.trace
import net.corda.node.services.statemachine.FlowStateMachineImpl import net.corda.node.services.statemachine.FlowStateMachineImpl
import net.corda.node.services.statemachine.StateMachineManager import net.corda.node.services.statemachine.StateMachineManager
@ -36,18 +38,18 @@ class VaultSoftLockManager(val vault: VaultService, smm: StateMachineManager) {
// However, the lock can be programmatically released, like any other soft lock, // However, the lock can be programmatically released, like any other soft lock,
// should we want a long running flow that creates a visible state mid way through. // should we want a long running flow that creates a visible state mid way through.
vault.rawUpdates.subscribe { update -> vault.rawUpdates.subscribe { (_, produced, flowId) ->
update.flowId?.let { flowId?.let {
if (update.produced.isNotEmpty()) { if (produced.isNotEmpty()) {
registerSoftLocks(update.flowId as UUID, update.produced.map { it.ref }) registerSoftLocks(flowId, (produced.map { it.ref }).toNonEmptySet())
} }
} }
} }
} }
private fun registerSoftLocks(flowId: UUID, stateRefs: List<StateRef>) { private fun registerSoftLocks(flowId: UUID, stateRefs: NonEmptySet<StateRef>) {
log.trace("Reserving soft locks for flow id $flowId and states $stateRefs") log.trace("Reserving soft locks for flow id $flowId and states $stateRefs")
vault.softLockReserve(flowId, stateRefs.toSet()) vault.softLockReserve(flowId, stateRefs)
} }
private fun unregisterSoftLocks(id: StateMachineRunId, logic: FlowLogic<*>) { private fun unregisterSoftLocks(id: StateMachineRunId, logic: FlowLogic<*>) {

View File

@ -26,7 +26,7 @@ import net.corda.core.serialization.serialize
import net.corda.core.transactions.SignedTransaction import net.corda.core.transactions.SignedTransaction
import net.corda.core.transactions.TransactionBuilder import net.corda.core.transactions.TransactionBuilder
import net.corda.core.transactions.WireTransaction import net.corda.core.transactions.WireTransaction
import net.corda.testing.LogHelper import net.corda.core.utilities.toNonEmptySet
import net.corda.core.utilities.unwrap import net.corda.core.utilities.unwrap
import net.corda.flows.TwoPartyTradeFlow.Buyer import net.corda.flows.TwoPartyTradeFlow.Buyer
import net.corda.flows.TwoPartyTradeFlow.Seller import net.corda.flows.TwoPartyTradeFlow.Seller
@ -155,7 +155,10 @@ class TwoPartyTradeFlowTests {
val cashLockId = UUID.randomUUID() val cashLockId = UUID.randomUUID()
bobNode.database.transaction { bobNode.database.transaction {
// lock the cash states with an arbitrary lockId (to prevent the Buyer flow from claiming the states) // lock the cash states with an arbitrary lockId (to prevent the Buyer flow from claiming the states)
bobNode.services.vaultService.softLockReserve(cashLockId, cashStates.states.map { it.ref }.toSet()) val refs = cashStates.states.map { it.ref }
if (refs.isNotEmpty()) {
bobNode.services.vaultService.softLockReserve(cashLockId, refs.toNonEmptySet())
}
} }
val (bobStateMachine, aliceResult) = runBuyerAndSeller(notaryNode, aliceNode, bobNode, val (bobStateMachine, aliceResult) = runBuyerAndSeller(notaryNode, aliceNode, bobNode,

View File

@ -11,7 +11,9 @@ import net.corda.core.node.services.VaultService
import net.corda.core.node.services.unconsumedStates import net.corda.core.node.services.unconsumedStates
import net.corda.core.transactions.SignedTransaction import net.corda.core.transactions.SignedTransaction
import net.corda.core.transactions.TransactionBuilder import net.corda.core.transactions.TransactionBuilder
import net.corda.core.utilities.NonEmptySet
import net.corda.core.utilities.OpaqueBytes import net.corda.core.utilities.OpaqueBytes
import net.corda.core.utilities.toNonEmptySet
import net.corda.node.utilities.configureDatabase import net.corda.node.utilities.configureDatabase
import net.corda.node.utilities.transaction import net.corda.node.utilities.transaction
import net.corda.testing.* import net.corda.testing.*
@ -116,7 +118,7 @@ class NodeVaultServiceTest {
val unconsumedStates = vaultSvc.unconsumedStates<Cash.State>().toList() val unconsumedStates = vaultSvc.unconsumedStates<Cash.State>().toList()
assertThat(unconsumedStates).hasSize(3) assertThat(unconsumedStates).hasSize(3)
val stateRefsToSoftLock = setOf(unconsumedStates[1].ref, unconsumedStates[2].ref) val stateRefsToSoftLock = NonEmptySet.of(unconsumedStates[1].ref, unconsumedStates[2].ref)
// soft lock two of the three states // soft lock two of the three states
val softLockId = UUID.randomUUID() val softLockId = UUID.randomUUID()
@ -132,7 +134,7 @@ class NodeVaultServiceTest {
assertThat(unlockedStates1).hasSize(1) assertThat(unlockedStates1).hasSize(1)
// soft lock release one of the states explicitly // soft lock release one of the states explicitly
vaultSvc.softLockRelease(softLockId, setOf(unconsumedStates[1].ref)) vaultSvc.softLockRelease(softLockId, NonEmptySet.of(unconsumedStates[1].ref))
val unlockedStates2 = vaultSvc.unconsumedStates<Cash.State>(includeSoftLockedStates = false).toList() val unlockedStates2 = vaultSvc.unconsumedStates<Cash.State>(includeSoftLockedStates = false).toList()
assertThat(unlockedStates2).hasSize(2) assertThat(unlockedStates2).hasSize(2)
@ -160,7 +162,7 @@ class NodeVaultServiceTest {
assertNull(vaultSvc.cashBalances[USD]) assertNull(vaultSvc.cashBalances[USD])
services.fillWithSomeTestCash(100.DOLLARS, DUMMY_NOTARY, 3, 3, Random(0L)) services.fillWithSomeTestCash(100.DOLLARS, DUMMY_NOTARY, 3, 3, Random(0L))
} }
val stateRefsToSoftLock = vaultStates.states.map { it.ref }.toSet() val stateRefsToSoftLock = (vaultStates.states.map { it.ref }).toNonEmptySet()
println("State Refs:: $stateRefsToSoftLock") println("State Refs:: $stateRefsToSoftLock")
// 1st tx locks states // 1st tx locks states
@ -216,19 +218,19 @@ class NodeVaultServiceTest {
assertNull(vaultSvc.cashBalances[USD]) assertNull(vaultSvc.cashBalances[USD])
services.fillWithSomeTestCash(100.DOLLARS, DUMMY_NOTARY, 3, 3, Random(0L)) services.fillWithSomeTestCash(100.DOLLARS, DUMMY_NOTARY, 3, 3, Random(0L))
} }
val stateRefsToSoftLock = vaultStates.states.map { it.ref }.toSet() val stateRefsToSoftLock = vaultStates.states.map { it.ref }
println("State Refs:: $stateRefsToSoftLock") println("State Refs:: $stateRefsToSoftLock")
// lock 1st state with LockId1 // lock 1st state with LockId1
database.transaction { database.transaction {
vaultSvc.softLockReserve(softLockId1, setOf(stateRefsToSoftLock.first())) vaultSvc.softLockReserve(softLockId1, NonEmptySet.of(stateRefsToSoftLock.first()))
assertThat(vaultSvc.softLockedStates<Cash.State>(softLockId1)).hasSize(1) assertThat(vaultSvc.softLockedStates<Cash.State>(softLockId1)).hasSize(1)
} }
// attempt to lock all 3 states with LockId2 // attempt to lock all 3 states with LockId2
database.transaction { database.transaction {
assertThatExceptionOfType(StatesNotAvailableException::class.java).isThrownBy( assertThatExceptionOfType(StatesNotAvailableException::class.java).isThrownBy(
{ vaultSvc.softLockReserve(softLockId2, stateRefsToSoftLock) } { vaultSvc.softLockReserve(softLockId2, stateRefsToSoftLock.toNonEmptySet()) }
).withMessageContaining("only 2 rows available").withNoCause() ).withMessageContaining("only 2 rows available").withNoCause()
} }
} }
@ -243,7 +245,7 @@ class NodeVaultServiceTest {
assertNull(vaultSvc.cashBalances[USD]) assertNull(vaultSvc.cashBalances[USD])
services.fillWithSomeTestCash(100.DOLLARS, DUMMY_NOTARY, 3, 3, Random(0L)) services.fillWithSomeTestCash(100.DOLLARS, DUMMY_NOTARY, 3, 3, Random(0L))
} }
val stateRefsToSoftLock = vaultStates.states.map { it.ref }.toSet() val stateRefsToSoftLock = (vaultStates.states.map { it.ref }).toNonEmptySet()
println("State Refs:: $stateRefsToSoftLock") println("State Refs:: $stateRefsToSoftLock")
// lock states with LockId1 // lock states with LockId1
@ -269,18 +271,18 @@ class NodeVaultServiceTest {
assertNull(vaultSvc.cashBalances[USD]) assertNull(vaultSvc.cashBalances[USD])
services.fillWithSomeTestCash(100.DOLLARS, DUMMY_NOTARY, 3, 3, Random(0L)) services.fillWithSomeTestCash(100.DOLLARS, DUMMY_NOTARY, 3, 3, Random(0L))
} }
val stateRefsToSoftLock = vaultStates.states.map { it.ref }.toSet() val stateRefsToSoftLock = vaultStates.states.map { it.ref }
println("State Refs:: $stateRefsToSoftLock") println("State Refs:: $stateRefsToSoftLock")
// lock states with LockId1 // lock states with LockId1
database.transaction { database.transaction {
vaultSvc.softLockReserve(softLockId1, setOf(stateRefsToSoftLock.first())) vaultSvc.softLockReserve(softLockId1, NonEmptySet.of(stateRefsToSoftLock.first()))
assertThat(vaultSvc.softLockedStates<Cash.State>(softLockId1)).hasSize(1) assertThat(vaultSvc.softLockedStates<Cash.State>(softLockId1)).hasSize(1)
} }
// attempt to lock all states with LockId1 (including previously already locked one) // attempt to lock all states with LockId1 (including previously already locked one)
database.transaction { database.transaction {
vaultSvc.softLockReserve(softLockId1, stateRefsToSoftLock) vaultSvc.softLockReserve(softLockId1, stateRefsToSoftLock.toNonEmptySet())
assertThat(vaultSvc.softLockedStates<Cash.State>(softLockId1)).hasSize(3) assertThat(vaultSvc.softLockedStates<Cash.State>(softLockId1)).hasSize(3)
} }
} }

View File

@ -16,6 +16,7 @@ import net.corda.core.node.services.vault.QueryCriteria.*
import net.corda.core.seconds import net.corda.core.seconds
import net.corda.core.transactions.SignedTransaction import net.corda.core.transactions.SignedTransaction
import net.corda.core.utilities.OpaqueBytes import net.corda.core.utilities.OpaqueBytes
import net.corda.core.utilities.NonEmptySet
import net.corda.core.utilities.toHexString import net.corda.core.utilities.toHexString
import net.corda.node.services.database.HibernateConfiguration import net.corda.node.services.database.HibernateConfiguration
import net.corda.node.services.schema.NodeSchemaService import net.corda.node.services.schema.NodeSchemaService
@ -470,7 +471,7 @@ class VaultQueryTests {
database.transaction { database.transaction {
val issuedStates = services.fillWithSomeTestCash(100.DOLLARS, CASH_NOTARY, 3, 3, Random(0L)) val issuedStates = services.fillWithSomeTestCash(100.DOLLARS, CASH_NOTARY, 3, 3, Random(0L))
vaultSvc.softLockReserve(UUID.randomUUID(), setOf(issuedStates.states.first().ref, issuedStates.states.last().ref)) vaultSvc.softLockReserve(UUID.randomUUID(), NonEmptySet.of(issuedStates.states.first().ref, issuedStates.states.last().ref))
val criteria = VaultQueryCriteria(includeSoftlockedStates = false) val criteria = VaultQueryCriteria(includeSoftlockedStates = false)
val results = vaultQuerySvc.queryBy<ContractState>(criteria) val results = vaultQuerySvc.queryBy<ContractState>(criteria)

View File

@ -7,6 +7,7 @@ import net.corda.core.node.NodeInfo
import net.corda.core.node.ServiceHub import net.corda.core.node.ServiceHub
import net.corda.core.node.services.NetworkMapCache import net.corda.core.node.services.NetworkMapCache
import net.corda.core.utilities.NetworkHostAndPort import net.corda.core.utilities.NetworkHostAndPort
import net.corda.core.utilities.NonEmptySet
import net.corda.node.services.network.InMemoryNetworkMapCache import net.corda.node.services.network.InMemoryNetworkMapCache
import net.corda.testing.getTestPartyAndCertificate import net.corda.testing.getTestPartyAndCertificate
import net.corda.testing.getTestX509Name import net.corda.testing.getTestX509Name
@ -28,8 +29,8 @@ class MockNetworkMapCache(serviceHub: ServiceHub) : InMemoryNetworkMapCache(serv
override val changed: Observable<NetworkMapCache.MapChange> = PublishSubject.create<NetworkMapCache.MapChange>() override val changed: Observable<NetworkMapCache.MapChange> = PublishSubject.create<NetworkMapCache.MapChange>()
init { init {
val mockNodeA = NodeInfo(listOf(BANK_C_ADDR), BANK_C, setOf(BANK_C), 1) val mockNodeA = NodeInfo(listOf(BANK_C_ADDR), BANK_C, NonEmptySet.of(BANK_C), 1)
val mockNodeB = NodeInfo(listOf(BANK_D_ADDR), BANK_D, setOf(BANK_D), 1) val mockNodeB = NodeInfo(listOf(BANK_D_ADDR), BANK_D, NonEmptySet.of(BANK_D), 1)
registeredNodes[mockNodeA.legalIdentity.owningKey] = mockNodeA registeredNodes[mockNodeA.legalIdentity.owningKey] = mockNodeA
registeredNodes[mockNodeB.legalIdentity.owningKey] = mockNodeB registeredNodes[mockNodeB.legalIdentity.owningKey] = mockNodeB
runWithoutMapService() runWithoutMapService()

View File

@ -11,6 +11,7 @@ import net.corda.core.node.services.*
import net.corda.core.serialization.SerializeAsToken import net.corda.core.serialization.SerializeAsToken
import net.corda.core.serialization.SingletonSerializeAsToken import net.corda.core.serialization.SingletonSerializeAsToken
import net.corda.core.transactions.SignedTransaction import net.corda.core.transactions.SignedTransaction
import net.corda.core.utilities.NonEmptySet
import net.corda.flows.AnonymisedIdentity import net.corda.flows.AnonymisedIdentity
import net.corda.node.VersionInfo import net.corda.node.VersionInfo
import net.corda.node.services.api.StateMachineRecordedTransactionMappingStorage import net.corda.node.services.api.StateMachineRecordedTransactionMappingStorage
@ -75,7 +76,7 @@ open class MockServices(vararg val keys: KeyPair) : ServiceHub {
override val clock: Clock get() = Clock.systemUTC() override val clock: Clock get() = Clock.systemUTC()
override val myInfo: NodeInfo get() { override val myInfo: NodeInfo get() {
val identity = getTestPartyAndCertificate(MEGA_CORP.name, key.public) val identity = getTestPartyAndCertificate(MEGA_CORP.name, key.public)
return NodeInfo(emptyList(), identity, setOf(identity), 1) return NodeInfo(emptyList(), identity, NonEmptySet.of(identity), 1)
} }
override val transactionVerifierService: TransactionVerifierService get() = InMemoryTransactionVerifierService(2) override val transactionVerifierService: TransactionVerifierService get() = InMemoryTransactionVerifierService(2)