Testing: make the ledger DSL take a ServiceHub rather than individual services.

It defaults to a fresh UnitTestServices(). Also clear up a few other areas.
This commit is contained in:
Mike Hearn
2016-07-29 14:48:25 +02:00
parent ba05b90b8f
commit 1c3379f508
5 changed files with 35 additions and 36 deletions

View File

@ -4,14 +4,12 @@ package com.r3corda.core.testing
import com.google.common.base.Throwables import com.google.common.base.Throwables
import com.google.common.net.HostAndPort import com.google.common.net.HostAndPort
import com.r3corda.core.contracts.Attachment
import com.r3corda.core.contracts.StateRef import com.r3corda.core.contracts.StateRef
import com.r3corda.core.contracts.TransactionBuilder import com.r3corda.core.contracts.TransactionBuilder
import com.r3corda.core.crypto.* import com.r3corda.core.crypto.*
import com.r3corda.core.node.services.IdentityService import com.r3corda.core.node.ServiceHub
import com.r3corda.core.node.services.StorageService
import com.r3corda.core.node.services.testing.MockIdentityService import com.r3corda.core.node.services.testing.MockIdentityService
import com.r3corda.core.node.services.testing.MockStorageService import com.r3corda.core.node.services.testing.UnitTestServices
import java.math.BigInteger import java.math.BigInteger
import java.net.ServerSocket import java.net.ServerSocket
import java.security.KeyPair import java.security.KeyPair
@ -95,17 +93,14 @@ fun freeLocalHostAndPort(): HostAndPort {
} }
/** /**
* Creates and tests a ledger built by the passed in dsl. * Creates and tests a ledger built by the passed in dsl. The provided services can be customised, otherwise a default
* @param identityService: The [IdentityService] to be used while building the ledger. * of a freshly built [UnitTestServices] is used.
* @param storageService: The [StorageService] to be used for storing e.g. [Attachment]s.
* @param dsl: The dsl building the ledger.
*/ */
@JvmOverloads fun ledger( @JvmOverloads fun ledger(
identityService: IdentityService = MOCK_IDENTITY_SERVICE, services: ServiceHub = UnitTestServices(),
storageService: StorageService = MockStorageService(),
dsl: LedgerDSL<TestTransactionDSLInterpreter, TestLedgerDSLInterpreter>.() -> Unit dsl: LedgerDSL<TestTransactionDSLInterpreter, TestLedgerDSLInterpreter>.() -> Unit
): LedgerDSL<TestTransactionDSLInterpreter, TestLedgerDSLInterpreter> { ): LedgerDSL<TestTransactionDSLInterpreter, TestLedgerDSLInterpreter> {
val ledgerDsl = LedgerDSL(TestLedgerDSLInterpreter(identityService, storageService)) val ledgerDsl = LedgerDSL(TestLedgerDSLInterpreter(services))
dsl(ledgerDsl) dsl(ledgerDsl)
return ledgerDsl return ledgerDsl
} }

View File

@ -40,11 +40,13 @@ interface Verifies {
val exceptionMessage = exception.message val exceptionMessage = exception.message
if (exceptionMessage == null) { if (exceptionMessage == null) {
throw AssertionError( throw AssertionError(
"Expected exception containing '$expectedMessage' but raised exception had no message" "Expected exception containing '$expectedMessage' but raised exception had no message",
exception
) )
} else if (!exceptionMessage.toLowerCase().contains(expectedMessage.toLowerCase())) { } else if (!exceptionMessage.toLowerCase().contains(expectedMessage.toLowerCase())) {
throw AssertionError( throw AssertionError(
"Expected exception containing '$expectedMessage' but raised exception was '$exception'" "Expected exception containing '$expectedMessage' but raised exception was '$exception'",
exception
) )
} }
} }

View File

@ -5,8 +5,7 @@ import com.r3corda.core.crypto.DigitalSignature
import com.r3corda.core.crypto.Party import com.r3corda.core.crypto.Party
import com.r3corda.core.crypto.SecureHash import com.r3corda.core.crypto.SecureHash
import com.r3corda.core.crypto.signWithECDSA import com.r3corda.core.crypto.signWithECDSA
import com.r3corda.core.node.services.IdentityService import com.r3corda.core.node.ServiceHub
import com.r3corda.core.node.services.StorageService
import com.r3corda.core.serialization.serialize import com.r3corda.core.serialization.serialize
import java.io.InputStream import java.io.InputStream
import java.security.KeyPair import java.security.KeyPair
@ -95,6 +94,10 @@ data class TestTransactionDSLInterpreter private constructor(
transactionBuilder: TransactionBuilder transactionBuilder: TransactionBuilder
) : this(ledgerInterpreter, transactionBuilder, HashMap()) ) : this(ledgerInterpreter, transactionBuilder, HashMap())
val services = object : ServiceHub by ledgerInterpreter.services {
override fun loadState(stateRef: StateRef) = ledgerInterpreter.resolveStateRef<ContractState>(stateRef)
}
private fun copy(): TestTransactionDSLInterpreter = private fun copy(): TestTransactionDSLInterpreter =
TestTransactionDSLInterpreter( TestTransactionDSLInterpreter(
ledgerInterpreter = ledgerInterpreter, ledgerInterpreter = ledgerInterpreter,
@ -141,18 +144,15 @@ data class TestTransactionDSLInterpreter private constructor(
} }
data class TestLedgerDSLInterpreter private constructor ( data class TestLedgerDSLInterpreter private constructor (
private val identityService: IdentityService, val services: ServiceHub,
private val storageService: StorageService,
internal val labelToOutputStateAndRefs: HashMap<String, StateAndRef<ContractState>> = HashMap(), internal val labelToOutputStateAndRefs: HashMap<String, StateAndRef<ContractState>> = HashMap(),
private val transactionWithLocations: HashMap<SecureHash, WireTransactionWithLocation> = HashMap(), private val transactionWithLocations: HashMap<SecureHash, WireTransactionWithLocation> = LinkedHashMap(),
private val nonVerifiedTransactionWithLocations: HashMap<SecureHash, WireTransactionWithLocation> = HashMap() private val nonVerifiedTransactionWithLocations: HashMap<SecureHash, WireTransactionWithLocation> = HashMap()
) : LedgerDSLInterpreter<TestTransactionDSLInterpreter> { ) : LedgerDSLInterpreter<TestTransactionDSLInterpreter> {
val wireTransactions: List<WireTransaction> get() = transactionWithLocations.values.map { it.transaction } val wireTransactions: List<WireTransaction> get() = transactionWithLocations.values.map { it.transaction }
// We specify [labelToOutputStateAndRefs] just so that Kotlin picks the primary constructor instead of cycling // We specify [labelToOutputStateAndRefs] just so that Kotlin picks the primary constructor instead of cycling
constructor(identityService: IdentityService, storageService: StorageService) : this( constructor(services: ServiceHub) : this(services, labelToOutputStateAndRefs = HashMap())
identityService, storageService, labelToOutputStateAndRefs = HashMap()
)
companion object { companion object {
private fun getCallerLocation(): String? { private fun getCallerLocation(): String? {
@ -179,8 +179,7 @@ data class TestLedgerDSLInterpreter private constructor (
internal fun copy(): TestLedgerDSLInterpreter = internal fun copy(): TestLedgerDSLInterpreter =
TestLedgerDSLInterpreter( TestLedgerDSLInterpreter(
identityService, services,
storageService,
labelToOutputStateAndRefs = HashMap(labelToOutputStateAndRefs), labelToOutputStateAndRefs = HashMap(labelToOutputStateAndRefs),
transactionWithLocations = HashMap(transactionWithLocations), transactionWithLocations = HashMap(transactionWithLocations),
nonVerifiedTransactionWithLocations = HashMap(nonVerifiedTransactionWithLocations) nonVerifiedTransactionWithLocations = HashMap(nonVerifiedTransactionWithLocations)
@ -189,7 +188,7 @@ data class TestLedgerDSLInterpreter private constructor (
internal fun resolveWireTransaction(wireTransaction: WireTransaction): TransactionForVerification { internal fun resolveWireTransaction(wireTransaction: WireTransaction): TransactionForVerification {
return wireTransaction.run { return wireTransaction.run {
val authenticatedCommands = commands.map { val authenticatedCommands = commands.map {
AuthenticatedObject(it.signers, it.signers.mapNotNull { identityService.partyFromKey(it) }, it.value) AuthenticatedObject(it.signers, it.signers.mapNotNull { services.identityService.partyFromKey(it) }, it.value)
} }
val resolvedInputStates = inputs.map { resolveStateRef<ContractState>(it) } val resolvedInputStates = inputs.map { resolveStateRef<ContractState>(it) }
val resolvedAttachments = attachments.map { resolveAttachment(it) } val resolvedAttachments = attachments.map { resolveAttachment(it) }
@ -220,7 +219,7 @@ data class TestLedgerDSLInterpreter private constructor (
} }
internal fun resolveAttachment(attachmentId: SecureHash): Attachment = internal fun resolveAttachment(attachmentId: SecureHash): Attachment =
storageService.attachments.openAttachment(attachmentId) ?: throw AttachmentResolutionException(attachmentId) services.storageService.attachments.openAttachment(attachmentId) ?: throw AttachmentResolutionException(attachmentId)
private fun <R> interpretTransactionDsl( private fun <R> interpretTransactionDsl(
transactionBuilder: TransactionBuilder, transactionBuilder: TransactionBuilder,
@ -233,10 +232,10 @@ data class TestLedgerDSLInterpreter private constructor (
fun toTransactionGroup(): TransactionGroup { fun toTransactionGroup(): TransactionGroup {
val ledgerTransactions = transactionWithLocations.map { val ledgerTransactions = transactionWithLocations.map {
it.value.transaction.toLedgerTransaction(identityService, storageService.attachments) it.value.transaction.toLedgerTransaction(services.identityService, services.storageService.attachments)
} }
val nonVerifiedLedgerTransactions = nonVerifiedTransactionWithLocations.map { val nonVerifiedLedgerTransactions = nonVerifiedTransactionWithLocations.map {
it.value.transaction.toLedgerTransaction(identityService, storageService.attachments) it.value.transaction.toLedgerTransaction(services.identityService, services.storageService.attachments)
} }
return TransactionGroup(ledgerTransactions.toSet(), nonVerifiedLedgerTransactions.toSet()) return TransactionGroup(ledgerTransactions.toSet(), nonVerifiedLedgerTransactions.toSet())
} }
@ -295,7 +294,7 @@ data class TestLedgerDSLInterpreter private constructor (
dsl(LedgerDSL(copy())) dsl(LedgerDSL(copy()))
override fun attachment(attachment: InputStream): SecureHash { override fun attachment(attachment: InputStream): SecureHash {
return storageService.attachments.importAttachment(attachment) return services.storageService.attachments.importAttachment(attachment)
} }
override fun verifies(): EnforceVerifyOrFail { override fun verifies(): EnforceVerifyOrFail {
@ -322,6 +321,9 @@ data class TestLedgerDSLInterpreter private constructor (
return stateAndRef as StateAndRef<S> return stateAndRef as StateAndRef<S>
} }
} }
val transactionsToVerify: List<WireTransaction> get() = transactionWithLocations.values.map { it.transaction }
val transactionsUnverified: List<WireTransaction> get() = nonVerifiedTransactionWithLocations.values.map { it.transaction }
} }
/** /**
@ -330,7 +332,7 @@ data class TestLedgerDSLInterpreter private constructor (
* @param extraKeys extra keys to sign transactions with. * @param extraKeys extra keys to sign transactions with.
* @return List of [SignedTransaction]s. * @return List of [SignedTransaction]s.
*/ */
fun signAll(transactionsToSign: List<WireTransaction>, extraKeys: Array<out KeyPair>) = transactionsToSign.map { wtx -> fun signAll(transactionsToSign: List<WireTransaction>, extraKeys: List<KeyPair>) = transactionsToSign.map { wtx ->
val allPubKeys = wtx.signers.toMutableSet() val allPubKeys = wtx.signers.toMutableSet()
val bits = wtx.serialize() val bits = wtx.serialize()
require(bits == wtx.serialized) require(bits == wtx.serialized)
@ -350,4 +352,4 @@ fun signAll(transactionsToSign: List<WireTransaction>, extraKeys: Array<out KeyP
* @return List of [SignedTransaction]s. * @return List of [SignedTransaction]s.
*/ */
fun LedgerDSL<TestTransactionDSLInterpreter, TestLedgerDSLInterpreter>.signAll( fun LedgerDSL<TestTransactionDSLInterpreter, TestLedgerDSLInterpreter>.signAll(
vararg extraKeys: KeyPair) = signAll(this.interpreter.wireTransactions, extraKeys) vararg extraKeys: KeyPair) = signAll(this.interpreter.wireTransactions, extraKeys.toList())

View File

@ -246,7 +246,7 @@ class TwoPartyTradeProtocolTests {
val aliceNode = makeNodeWithTracking(notaryNode.info, ALICE.name, ALICE_KEY) val aliceNode = makeNodeWithTracking(notaryNode.info, ALICE.name, ALICE_KEY)
val bobNode = makeNodeWithTracking(notaryNode.info, BOB.name, BOB_KEY) val bobNode = makeNodeWithTracking(notaryNode.info, BOB.name, BOB_KEY)
ledger(storageService = aliceNode.storage) { ledger(aliceNode.services) {
// Insert a prospectus type attachment into the commercial paper transaction. // Insert a prospectus type attachment into the commercial paper transaction.
val stream = ByteArrayOutputStream() val stream = ByteArrayOutputStream()
@ -413,7 +413,7 @@ class TwoPartyTradeProtocolTests {
wtxToSign: List<WireTransaction>, wtxToSign: List<WireTransaction>,
services: ServiceHub, services: ServiceHub,
vararg extraKeys: KeyPair): Map<SecureHash, SignedTransaction> { vararg extraKeys: KeyPair): Map<SecureHash, SignedTransaction> {
val signed: List<SignedTransaction> = signAll(wtxToSign, extraKeys) val signed: List<SignedTransaction> = signAll(wtxToSign, extraKeys.toList())
services.recordTransactions(signed) services.recordTransactions(signed)
val validatedTransactions = services.storageService.validatedTransactions val validatedTransactions = services.storageService.validatedTransactions
if (validatedTransactions is RecordingTransactionStorage) { if (validatedTransactions is RecordingTransactionStorage) {

View File

@ -17,13 +17,13 @@ class GraphVisualiser(val dsl: LedgerDSL<TestTransactionDSLInterpreter, TestLedg
} }
fun convert(): SingleGraph { fun convert(): SingleGraph {
val tg = dsl.interpreter.toTransactionGroup() val testLedger: TestLedgerDSLInterpreter = dsl.interpreter
val graph = createGraph("Transaction group", css) val graph = createGraph("Transaction group", css)
// Map all the transactions, including the bogus non-verified ones (with no inputs) to graph nodes. // Map all the transactions, including the bogus non-verified ones (with no inputs) to graph nodes.
for ((txIndex, tx) in (tg.transactions + tg.nonVerifiedRoots).withIndex()) { for ((txIndex, tx) in (testLedger.transactionsToVerify + testLedger.transactionsUnverified).withIndex()) {
val txNode = graph.addNode<Node>("tx$txIndex") val txNode = graph.addNode<Node>("tx$txIndex")
if (tx !in tg.nonVerifiedRoots) if (tx !in testLedger.transactionsUnverified)
txNode.label = dsl.interpreter.transactionName(tx.id).let { it ?: "TX[${tx.id.prefixChars()}]" } txNode.label = dsl.interpreter.transactionName(tx.id).let { it ?: "TX[${tx.id.prefixChars()}]" }
txNode.styleClass = "tx" txNode.styleClass = "tx"
@ -48,7 +48,7 @@ class GraphVisualiser(val dsl: LedgerDSL<TestTransactionDSLInterpreter, TestLedg
} }
} }
// And now all states and transactions were mapped to graph nodes, hook up the input edges. // And now all states and transactions were mapped to graph nodes, hook up the input edges.
for ((txIndex, tx) in tg.transactions.withIndex()) { for ((txIndex, tx) in testLedger.transactionsToVerify.withIndex()) {
for ((inputIndex, ref) in tx.inputs.withIndex()) { for ((inputIndex, ref) in tx.inputs.withIndex()) {
val edge = graph.addEdge<Edge>("tx$txIndex-in$inputIndex", ref.toString(), "tx$txIndex", true) val edge = graph.addEdge<Edge>("tx$txIndex-in$inputIndex", ref.toString(), "tx$txIndex", true)
edge.weight = 1.2 edge.weight = 1.2