diff --git a/core/src/main/kotlin/net/corda/core/contracts/ContractsDSL.kt b/core/src/main/kotlin/net/corda/core/contracts/ContractsDSL.kt index 82a1d348a0..c22b1dadd3 100644 --- a/core/src/main/kotlin/net/corda/core/contracts/ContractsDSL.kt +++ b/core/src/main/kotlin/net/corda/core/contracts/ContractsDSL.kt @@ -12,7 +12,6 @@ import java.util.* * Defines a simple domain specific language for the specification of financial contracts. Currently covers: * * - Some utilities for working with commands. - * - An Amount type that represents a positive quantity of a specific token. * - A simple language extension for specifying requirements in English, along with logic to enforce them. */ @@ -30,37 +29,44 @@ inline fun requireThat(body: Requirements.() -> R) = Requirements.body() //// Authenticated commands /////////////////////////////////////////////////////////////////////////////////////////// -// TODO: Provide a version of select that interops with Java - /** Filters the command list by type, party and public key all at once. */ inline fun Collection>.select(signer: PublicKey? = null, party: AbstractParty? = null) = - filter { it.value is T }. + select(T::class.java, signer, party) + +/** Filters the command list by type, party and public key all at once. */ +fun Collection>.select(klass: Class, + signer: PublicKey? = null, + party: AbstractParty? = null) = + mapNotNull { if (klass.isInstance(it.value)) uncheckedCast, CommandWithParties>(it) else null }. filter { if (signer == null) true else signer in it.signers }. filter { if (party == null) true else party in it.signingParties }. - map { CommandWithParties(it.signers, it.signingParties, it.value as T) } - -// TODO: Provide a version of select that interops with Java + map { CommandWithParties(it.signers, it.signingParties, it.value) } /** Filters the command list by type, parties and public keys all at once. */ inline fun Collection>.select(signers: Collection?, parties: Collection?) = - filter { it.value is T }. + select(T::class.java, signers, parties) + +/** Filters the command list by type, parties and public keys all at once. */ +fun Collection>.select(klass: Class, + signers: Collection?, + parties: Collection?) = + mapNotNull { if (klass.isInstance(it.value)) uncheckedCast, CommandWithParties>(it) else null }. filter { if (signers == null) true else it.signers.containsAll(signers) }. filter { if (parties == null) true else it.signingParties.containsAll(parties) }. - map { CommandWithParties(it.signers, it.signingParties, it.value as T) } + map { CommandWithParties(it.signers, it.signingParties, it.value) } /** Ensures that a transaction has only one command that is of the given type, otherwise throws an exception. */ -inline fun Collection>.requireSingleCommand() = try { - select().single() +inline fun Collection>.requireSingleCommand() = requireSingleCommand(T::class.java) + +/** Ensures that a transaction has only one command that is of the given type, otherwise throws an exception. */ +fun Collection>.requireSingleCommand(klass: Class) = try { + select(klass).single() } catch (e: NoSuchElementException) { - throw IllegalStateException("Required ${T::class.qualifiedName} command") // Better error message. + throw IllegalStateException("Required ${klass.kotlin.qualifiedName} command") // Better error message. } -/** Ensures that a transaction has only one command that is of the given type, otherwise throws an exception. */ -fun Collection>.requireSingleCommand(klass: Class) = - mapNotNull { if (klass.isInstance(it.value)) uncheckedCast, CommandWithParties>(it) else null }.single() - /** * Simple functionality for verifying a move command. Verifies that each input has a signature from its owning key. * diff --git a/core/src/test/kotlin/net/corda/core/contracts/ContractsDSLTests.kt b/core/src/test/kotlin/net/corda/core/contracts/ContractsDSLTests.kt new file mode 100644 index 0000000000..c4e75a3b05 --- /dev/null +++ b/core/src/test/kotlin/net/corda/core/contracts/ContractsDSLTests.kt @@ -0,0 +1,179 @@ +package net.corda.core.contracts + +import net.corda.core.identity.AbstractParty +import net.corda.core.identity.CordaX500Name +import net.corda.core.identity.Party +import net.corda.testing.TestIdentity +import org.assertj.core.api.Assertions +import org.junit.Test +import org.junit.runner.RunWith +import org.junit.runners.Parameterized +import java.security.PublicKey +import kotlin.test.assertEquals +import kotlin.test.assertFalse +import kotlin.test.assertTrue + +class ContractsDSLTests { + class UnwantedCommand : CommandData + + interface TestCommands : CommandData { + class CommandOne : TypeOnlyCommandData(), TestCommands + class CommandTwo : TypeOnlyCommandData(), TestCommands + } + + private companion object { + val megaCorp = TestIdentity(CordaX500Name("MegaCorp", "London", "GB")) + val miniCorp = TestIdentity(CordaX500Name("MiniCorp", "London", "GB")) + + val validCommandOne = CommandWithParties(listOf(megaCorp.publicKey, miniCorp.publicKey), listOf(megaCorp.party, miniCorp.party), TestCommands.CommandOne()) + val validCommandTwo = CommandWithParties(listOf(megaCorp.publicKey), listOf(megaCorp.party), TestCommands.CommandTwo()) + val invalidCommand = CommandWithParties(emptyList(), emptyList(), UnwantedCommand()) + } + + @RunWith(Parameterized::class) + class RequireSingleCommandTests(private val testFunction: (Collection>) -> CommandWithParties, description: String) { + companion object { + @JvmStatic + @Parameterized.Parameters(name = "{1}") + fun data(): Collection> = listOf( + arrayOf({ commands: Collection> -> commands.requireSingleCommand() }, "Inline version"), + arrayOf({ commands: Collection> -> commands.requireSingleCommand(TestCommands::class.java) }, "Interop version") + ) + } + + @Test + fun `check function returns one value`() { + val commands = listOf(validCommandOne, invalidCommand) + val returnedCommand = testFunction(commands) + assertEquals(returnedCommand, validCommandOne, "they should be the same") + } + + @Test(expected = IllegalArgumentException::class) + fun `check error is thrown if more than one valid command`() { + val commands = listOf(validCommandOne, validCommandTwo) + testFunction(commands) + } + + @Test + fun `check error is thrown when command is of wrong type`() { + val commands = listOf(invalidCommand) + Assertions.assertThatThrownBy { testFunction(commands) } + .isInstanceOf(IllegalStateException::class.java) + .hasMessage("Required net.corda.core.contracts.ContractsDSLTests.TestCommands command") + } + } + + @RunWith(Parameterized::class) + class SelectWithSingleInputsTests(private val testFunction: (Collection>, PublicKey?, AbstractParty?) -> Iterable>, description: String) { + companion object { + @JvmStatic + @Parameterized.Parameters(name = "{1}") + fun data(): Collection> = listOf( + arrayOf({ commands: Collection>, signer: PublicKey?, party: AbstractParty? -> commands.select(signer, party) }, "Inline version"), + arrayOf({ commands: Collection>, signer: PublicKey?, party: AbstractParty? -> commands.select(TestCommands::class.java, signer, party) }, "Interop version") + ) + } + + @Test + fun `check that function returns all values`() { + val commands = listOf(validCommandOne, validCommandTwo) + testFunction(commands, null, null) + assertEquals(2, commands.size) + assertTrue(commands.contains(validCommandOne)) + assertTrue(commands.contains(validCommandTwo)) + } + + @Test + fun `check that function does not return invalid command types`() { + val commands = listOf(validCommandOne, invalidCommand) + val filteredCommands = testFunction(commands, null, null).toList() + assertEquals(1, filteredCommands.size) + assertTrue(filteredCommands.contains(validCommandOne)) + assertFalse(filteredCommands.contains(invalidCommand)) + } + + @Test + fun `check that function returns commands from valid signers`() { + val commands = listOf(validCommandOne, validCommandTwo) + val filteredCommands = testFunction(commands, miniCorp.publicKey, null).toList() + assertEquals(1, filteredCommands.size) + assertTrue(filteredCommands.contains(validCommandOne)) + assertFalse(filteredCommands.contains(validCommandTwo)) + } + + @Test + fun `check that function returns commands from valid parties`() { + val commands = listOf(validCommandOne, validCommandTwo) + val filteredCommands = testFunction(commands, null, miniCorp.party).toList() + assertEquals(1, filteredCommands.size) + assertTrue(filteredCommands.contains(validCommandOne)) + assertFalse(filteredCommands.contains(validCommandTwo)) + } + } + + @RunWith(Parameterized::class) + class SelectWithMultipleInputsTests(private val testFunction: (Collection>, Collection?, Collection?) -> Iterable>, description: String) { + companion object { + @JvmStatic + @Parameterized.Parameters(name = "{1}") + fun data(): Collection> = listOf( + arrayOf({ commands: Collection>, signers: Collection?, party: Collection? -> commands.select(signers, party) }, "Inline version"), + arrayOf({ commands: Collection>, signers: Collection?, party: Collection? -> commands.select(TestCommands::class.java, signers, party) }, "Interop version") + ) + } + + @Test + fun `check that function returns all values`() { + val commands = listOf(validCommandOne, validCommandTwo) + testFunction(commands, null, null) + assertEquals(2, commands.size) + assertTrue(commands.contains(validCommandOne)) + assertTrue(commands.contains(validCommandTwo)) + } + + @Test + fun `check that function does not return invalid command types`() { + val commands = listOf(validCommandOne, invalidCommand) + val filteredCommands = testFunction(commands, null, null).toList() + assertEquals(1, filteredCommands.size) + assertTrue(filteredCommands.contains(validCommandOne)) + assertFalse(filteredCommands.contains(invalidCommand)) + } + + @Test + fun `check that function returns commands from valid signers`() { + val commands = listOf(validCommandOne, validCommandTwo) + val filteredCommands = testFunction(commands, listOf(megaCorp.publicKey), null).toList() + assertEquals(2, filteredCommands.size) + assertTrue(filteredCommands.contains(validCommandOne)) + assertTrue(filteredCommands.contains(validCommandTwo)) + } + + @Test + fun `check that function returns commands from all valid signers`() { + val commands = listOf(validCommandOne, validCommandTwo) + val filteredCommands = testFunction(commands, listOf(miniCorp.publicKey, megaCorp.publicKey), null).toList() + assertEquals(1, filteredCommands.size) + assertTrue(filteredCommands.contains(validCommandOne)) + assertFalse(filteredCommands.contains(validCommandTwo)) + } + + @Test + fun `check that function returns commands from valid parties`() { + val commands = listOf(validCommandOne, validCommandTwo) + val filteredCommands = testFunction(commands, null, listOf(megaCorp.party)).toList() + assertEquals(2, filteredCommands.size) + assertTrue(filteredCommands.contains(validCommandOne)) + assertTrue(filteredCommands.contains(validCommandTwo)) + } + + @Test + fun `check that function returns commands from all valid parties`() { + val commands = listOf(validCommandOne, validCommandTwo) + val filteredCommands = testFunction(commands, null, listOf(miniCorp.party, megaCorp.party)).toList() + assertEquals(1, filteredCommands.size) + assertTrue(filteredCommands.contains(validCommandOne)) + assertFalse(filteredCommands.contains(validCommandTwo)) + } + } +} \ No newline at end of file diff --git a/node/src/integration-test/kotlin/net/corda/node/services/BFTNotaryServiceTests.kt b/node/src/integration-test/kotlin/net/corda/node/services/BFTNotaryServiceTests.kt index 13d560fc3c..0577bff135 100644 --- a/node/src/integration-test/kotlin/net/corda/node/services/BFTNotaryServiceTests.kt +++ b/node/src/integration-test/kotlin/net/corda/node/services/BFTNotaryServiceTests.kt @@ -50,6 +50,7 @@ class BFTNotaryServiceTests : IntegrationTest() { val databaseSchemas = IntegrationTestSchemas("node_0", "node_1", "node_2", "node_3", "node_4", "node_5", "node_6", "node_7", "node_8", "node_9") } + private lateinit var mockNet: MockNetwork private lateinit var notary: Party private lateinit var node: StartedNode @@ -58,6 +59,7 @@ class BFTNotaryServiceTests : IntegrationTest() { fun before() { mockNet = MockNetwork(emptyList()) } + @After fun stopNodes() { mockNet.stopNodes()