testdsl: Use and expose TransactionBuilder in TestTransactionDSLInterpreter

This commit is contained in:
Andras Slemmer 2016-07-06 17:22:38 +01:00
parent 656b06f7f5
commit c3060c11c0
7 changed files with 95 additions and 68 deletions

View File

@ -20,16 +20,31 @@ import java.util.*
* an output state can be added by just passing in a [ContractState] a [TransactionState] with the * an output state can be added by just passing in a [ContractState] a [TransactionState] with the
* default notary will be generated automatically. * default notary will be generated automatically.
*/ */
abstract class TransactionBuilder(protected val type: TransactionType = TransactionType.General(), open class TransactionBuilder(
protected val notary: Party? = null) { protected val type: TransactionType = TransactionType.General(),
protected val inputs: MutableList<StateRef> = arrayListOf() protected val notary: Party? = null,
protected val attachments: MutableList<SecureHash> = arrayListOf() protected val inputs: MutableList<StateRef> = arrayListOf(),
protected val outputs: MutableList<TransactionState<ContractState>> = arrayListOf() protected val attachments: MutableList<SecureHash> = arrayListOf(),
protected val commands: MutableList<Command> = arrayListOf() protected val outputs: MutableList<TransactionState<ContractState>> = arrayListOf(),
protected val signers: MutableSet<PublicKey> = mutableSetOf() protected val commands: MutableList<Command> = arrayListOf(),
protected val signers: MutableSet<PublicKey> = mutableSetOf()) {
val time: TimestampCommand? get() = commands.mapNotNull { it.value as? TimestampCommand }.singleOrNull() val time: TimestampCommand? get() = commands.mapNotNull { it.value as? TimestampCommand }.singleOrNull()
/**
* Creates a copy of the builder
*/
fun copy(): TransactionBuilder =
TransactionBuilder(
type = type,
notary = notary,
inputs = ArrayList(inputs),
attachments = ArrayList(attachments),
outputs = ArrayList(outputs),
commands = ArrayList(commands),
signers = LinkedHashSet(signers)
)
/** /**
* Places a [TimestampCommand] in this transaction, removing any existing command if there is one. * Places a [TimestampCommand] in this transaction, removing any existing command if there is one.
* The command requires a signature from the Notary service, which acts as a Timestamp Authority. * The command requires a signature from the Notary service, which acts as a Timestamp Authority.
@ -112,31 +127,32 @@ abstract class TransactionBuilder(protected val type: TransactionType = Transact
return SignedTransaction(toWireTransaction().serialize(), ArrayList(currentSigs)) return SignedTransaction(toWireTransaction().serialize(), ArrayList(currentSigs))
} }
open fun addInputState(stateAndRef: StateAndRef<*>) { open fun addInputState(stateAndRef: StateAndRef<*>) = addInputState(stateAndRef.ref, stateAndRef.state.notary)
fun addInputState(stateRef: StateRef, notary: Party) {
check(currentSigs.isEmpty()) check(currentSigs.isEmpty())
val notaryKey = stateAndRef.state.notary.owningKey signers.add(notary.owningKey)
signers.add(notaryKey) inputs.add(stateRef)
inputs.add(stateAndRef.ref)
} }
fun addAttachment(attachment: Attachment) { fun addAttachment(attachmentId: SecureHash) {
check(currentSigs.isEmpty()) check(currentSigs.isEmpty())
attachments.add(attachment.id) attachments.add(attachmentId)
} }
fun addOutputState(state: TransactionState<*>) { fun addOutputState(state: TransactionState<*>): Int {
check(currentSigs.isEmpty()) check(currentSigs.isEmpty())
outputs.add(state) outputs.add(state)
return outputs.size - 1
} }
fun addOutputState(state: ContractState, notary: Party) = addOutputState(TransactionState(state, notary)) fun addOutputState(state: ContractState, notary: Party) = addOutputState(TransactionState(state, notary))
/** A default notary must be specified during builder construction to use this method */ /** A default notary must be specified during builder construction to use this method */
fun addOutputState(state: ContractState) { fun addOutputState(state: ContractState): Int {
checkNotNull(notary) { "Need to specify a Notary for the state, or set a default one on TransactionBuilder initialisation" } checkNotNull(notary) { "Need to specify a Notary for the state, or set a default one on TransactionBuilder initialisation" }
addOutputState(state, notary!!) return addOutputState(state, notary!!)
} }
fun addCommand(arg: Command) { fun addCommand(arg: Command) {

View File

@ -9,8 +9,10 @@ interface OutputStateLookup {
} }
interface LedgerDSLInterpreter<R, out T : TransactionDSLInterpreter<R>> : OutputStateLookup { interface LedgerDSLInterpreter<R, out T : TransactionDSLInterpreter<R>> : OutputStateLookup {
fun transaction(transactionLabel: String?, dsl: TransactionDSL<R, T>.() -> R): WireTransaction fun _transaction(transactionLabel: String?, transactionBuilder: TransactionBuilder,
fun unverifiedTransaction(transactionLabel: String?, dsl: TransactionDSL<R, T>.() -> Unit): WireTransaction dsl: TransactionDSL<R, T>.() -> R): WireTransaction
fun _unverifiedTransaction(transactionLabel: String?, transactionBuilder: TransactionBuilder,
dsl: TransactionDSL<R, T>.() -> Unit): WireTransaction
fun tweak(dsl: LedgerDSL<R, T, LedgerDSLInterpreter<R, T>>.() -> Unit) fun tweak(dsl: LedgerDSL<R, T, LedgerDSLInterpreter<R, T>>.() -> Unit)
fun attachment(attachment: InputStream): SecureHash fun attachment(attachment: InputStream): SecureHash
fun verifies() fun verifies()
@ -26,10 +28,14 @@ interface LedgerDSLInterpreter<R, out T : TransactionDSLInterpreter<R>> : Output
class LedgerDSL<R, out T : TransactionDSLInterpreter<R>, out L : LedgerDSLInterpreter<R, T>> (val interpreter: L) : class LedgerDSL<R, out T : TransactionDSLInterpreter<R>, out L : LedgerDSLInterpreter<R, T>> (val interpreter: L) :
LedgerDSLInterpreter<R, TransactionDSLInterpreter<R>> by interpreter { LedgerDSLInterpreter<R, TransactionDSLInterpreter<R>> by interpreter {
fun transaction(dsl: TransactionDSL<R, TransactionDSLInterpreter<R>>.() -> R) = @JvmOverloads
transaction(null, dsl) fun transaction(label: String? = null, transactionBuilder: TransactionBuilder = TransactionBuilder(),
fun unverifiedTransaction(dsl: TransactionDSL<R, TransactionDSLInterpreter<R>>.() -> Unit) = dsl: TransactionDSL<R, TransactionDSLInterpreter<R>>.() -> R) =
unverifiedTransaction(null, dsl) _transaction(label, transactionBuilder, dsl)
@JvmOverloads
fun unverifiedTransaction(label: String? = null, transactionBuilder: TransactionBuilder = TransactionBuilder(),
dsl: TransactionDSL<R, TransactionDSLInterpreter<R>>.() -> Unit) =
_unverifiedTransaction(label, transactionBuilder, dsl)
inline fun <reified S : ContractState> String.outputStateAndRef(): StateAndRef<S> = inline fun <reified S : ContractState> String.outputStateAndRef(): StateAndRef<S> =
retrieveOutputStateAndRef(S::class.java, this) retrieveOutputStateAndRef(S::class.java, this)

View File

@ -16,11 +16,12 @@ import java.util.*
fun transaction( fun transaction(
transactionLabel: String? = null, transactionLabel: String? = null,
transactionBuilder: TransactionBuilder = TransactionBuilder(),
dsl: TransactionDSL< dsl: TransactionDSL<
EnforceVerifyOrFail, EnforceVerifyOrFail,
TransactionDSLInterpreter<EnforceVerifyOrFail> TransactionDSLInterpreter<EnforceVerifyOrFail>
>.() -> EnforceVerifyOrFail >.() -> EnforceVerifyOrFail
) = JavaTestHelpers.transaction(transactionLabel, dsl) ) = JavaTestHelpers.transaction(transactionLabel, transactionBuilder, dsl)
fun ledger( fun ledger(
identityService: IdentityService = MOCK_IDENTITY_SERVICE, identityService: IdentityService = MOCK_IDENTITY_SERVICE,
@ -68,57 +69,55 @@ sealed class EnforceVerifyOrFail {
internal object Token: EnforceVerifyOrFail() internal object Token: EnforceVerifyOrFail()
} }
class DuplicateOutputLabel(label: String) : Exception("Output label '$label' already used")
/** /**
* This interpreter builds a transaction, and [TransactionDSL.verifies] that the resolved transaction is correct. Note * This interpreter builds a transaction, and [TransactionDSL.verifies] that the resolved transaction is correct. Note
* that transactions corresponding to input states are not verified. Use [LedgerDSL.verifies] for that. * that transactions corresponding to input states are not verified. Use [LedgerDSL.verifies] for that.
*/ */
data class TestTransactionDSLInterpreter( data class TestTransactionDSLInterpreter private constructor(
override val ledgerInterpreter: TestLedgerDSLInterpreter, override val ledgerInterpreter: TestLedgerDSLInterpreter,
private val inputStateRefs: ArrayList<StateRef> = arrayListOf(), val transactionBuilder: TransactionBuilder,
internal val outputStates: ArrayList<LabeledOutput> = arrayListOf(), internal val labelToIndexMap: HashMap<String, Int>
private val attachments: ArrayList<SecureHash> = arrayListOf(),
private val commands: ArrayList<Command> = arrayListOf(),
private val signers: LinkedHashSet<PublicKey> = LinkedHashSet(),
private val transactionType: TransactionType = TransactionType.General()
) : TransactionDSLInterpreter<EnforceVerifyOrFail>, OutputStateLookup by ledgerInterpreter { ) : TransactionDSLInterpreter<EnforceVerifyOrFail>, OutputStateLookup by ledgerInterpreter {
constructor(
ledgerInterpreter: TestLedgerDSLInterpreter,
transactionBuilder: TransactionBuilder // = TransactionBuilder()
) : this(ledgerInterpreter, transactionBuilder, HashMap())
private fun copy(): TestTransactionDSLInterpreter = private fun copy(): TestTransactionDSLInterpreter =
TestTransactionDSLInterpreter( TestTransactionDSLInterpreter(
ledgerInterpreter = ledgerInterpreter, ledgerInterpreter = ledgerInterpreter,
inputStateRefs = ArrayList(inputStateRefs), transactionBuilder = transactionBuilder.copy(),
outputStates = ArrayList(outputStates), labelToIndexMap = HashMap(labelToIndexMap)
attachments = ArrayList(attachments),
commands = ArrayList(commands),
signers = LinkedHashSet(signers),
transactionType = transactionType
) )
internal fun toWireTransaction(): WireTransaction = internal fun toWireTransaction() = transactionBuilder.toWireTransaction()
WireTransaction(
inputs = inputStateRefs,
outputs = outputStates.map { it.state },
attachments = attachments,
commands = commands,
signers = signers.toList(),
type = transactionType
)
override fun input(stateRef: StateRef) { override fun input(stateRef: StateRef) {
val notary = ledgerInterpreter.resolveStateRef<ContractState>(stateRef).notary val notary = ledgerInterpreter.resolveStateRef<ContractState>(stateRef).notary
signers.add(notary.owningKey) transactionBuilder.addInputState(stateRef, notary)
inputStateRefs.add(stateRef)
} }
override fun _output(label: String?, notary: Party, contractState: ContractState) { override fun _output(label: String?, notary: Party, contractState: ContractState) {
outputStates.add(LabeledOutput(label, TransactionState(contractState, notary))) val outputIndex = transactionBuilder.addOutputState(contractState, notary)
if (label != null) {
if (labelToIndexMap.contains(label)) {
throw DuplicateOutputLabel(label)
} else {
labelToIndexMap[label] = outputIndex
}
}
} }
override fun attachment(attachmentId: SecureHash) { override fun attachment(attachmentId: SecureHash) {
attachments.add(attachmentId) transactionBuilder.addAttachment(attachmentId)
} }
override fun _command(signers: List<PublicKey>, commandData: CommandData) { override fun _command(signers: List<PublicKey>, commandData: CommandData) {
this.signers.addAll(signers) val command = Command(commandData, signers)
commands.add(Command(commandData, signers)) transactionBuilder.addCommand(command)
} }
override fun verifies(): EnforceVerifyOrFail { override fun verifies(): EnforceVerifyOrFail {
@ -243,9 +242,10 @@ data class TestLedgerDSLInterpreter private constructor (
storageService.attachments.openAttachment(attachmentId) ?: throw AttachmentResolutionException(attachmentId) storageService.attachments.openAttachment(attachmentId) ?: throw AttachmentResolutionException(attachmentId)
private fun <Return> interpretTransactionDsl( private fun <Return> interpretTransactionDsl(
transactionBuilder: TransactionBuilder,
dsl: TransactionDSL<EnforceVerifyOrFail, TestTransactionDSLInterpreter>.() -> Return dsl: TransactionDSL<EnforceVerifyOrFail, TestTransactionDSLInterpreter>.() -> Return
): TestTransactionDSLInterpreter { ): TestTransactionDSLInterpreter {
val transactionInterpreter = TestTransactionDSLInterpreter(this) val transactionInterpreter = TestTransactionDSLInterpreter(this, transactionBuilder)
dsl(TransactionDSL(transactionInterpreter)) dsl(TransactionDSL(transactionInterpreter))
return transactionInterpreter return transactionInterpreter
} }
@ -274,18 +274,20 @@ data class TestLedgerDSLInterpreter private constructor (
private fun <R> recordTransactionWithTransactionMap( private fun <R> recordTransactionWithTransactionMap(
transactionLabel: String?, transactionLabel: String?,
transactionBuilder: TransactionBuilder,
dsl: TransactionDSL<EnforceVerifyOrFail, TestTransactionDSLInterpreter>.() -> R, dsl: TransactionDSL<EnforceVerifyOrFail, TestTransactionDSLInterpreter>.() -> R,
transactionMap: HashMap<SecureHash, WireTransactionWithLocation> = HashMap() transactionMap: HashMap<SecureHash, WireTransactionWithLocation> = HashMap()
): WireTransaction { ): WireTransaction {
val transactionLocation = getCallerLocation(3) val transactionLocation = getCallerLocation(3)
val transactionInterpreter = interpretTransactionDsl(dsl) val transactionInterpreter = interpretTransactionDsl(transactionBuilder, dsl)
// Create the WireTransaction // Create the WireTransaction
val wireTransaction = transactionInterpreter.toWireTransaction() val wireTransaction = transactionInterpreter.toWireTransaction()
// Record the output states // Record the output states
transactionInterpreter.outputStates.forEachIndexed { index, labeledOutput -> transactionInterpreter.labelToIndexMap.forEach { label, index ->
if (labeledOutput.label != null) { if (labelToOutputStateAndRefs.contains(label)) {
labelToOutputStateAndRefs[labeledOutput.label] = wireTransaction.outRef(index) throw DuplicateOutputLabel(label)
} }
labelToOutputStateAndRefs[label] = wireTransaction.outRef(index)
} }
transactionMap[wireTransaction.serialized.hash] = transactionMap[wireTransaction.serialized.hash] =
@ -294,15 +296,17 @@ data class TestLedgerDSLInterpreter private constructor (
return wireTransaction return wireTransaction
} }
override fun transaction( override fun _transaction(
transactionLabel: String?, transactionLabel: String?,
transactionBuilder: TransactionBuilder,
dsl: TransactionDSL<EnforceVerifyOrFail, TestTransactionDSLInterpreter>.() -> EnforceVerifyOrFail dsl: TransactionDSL<EnforceVerifyOrFail, TestTransactionDSLInterpreter>.() -> EnforceVerifyOrFail
) = recordTransactionWithTransactionMap(transactionLabel, dsl, transactionWithLocations) ) = recordTransactionWithTransactionMap(transactionLabel, transactionBuilder, dsl, transactionWithLocations)
override fun unverifiedTransaction( override fun _unverifiedTransaction(
transactionLabel: String?, transactionLabel: String?,
transactionBuilder: TransactionBuilder,
dsl: TransactionDSL<EnforceVerifyOrFail, TestTransactionDSLInterpreter>.() -> Unit dsl: TransactionDSL<EnforceVerifyOrFail, TestTransactionDSLInterpreter>.() -> Unit
) = recordTransactionWithTransactionMap(transactionLabel, dsl, nonVerifiedTransactionWithLocations) ) = recordTransactionWithTransactionMap(transactionLabel, transactionBuilder, dsl, nonVerifiedTransactionWithLocations)
override fun tweak( override fun tweak(
dsl: LedgerDSL<EnforceVerifyOrFail, TestTransactionDSLInterpreter, dsl: LedgerDSL<EnforceVerifyOrFail, TestTransactionDSLInterpreter,

View File

@ -102,11 +102,12 @@ object JavaTestHelpers {
@JvmStatic @JvmOverloads fun transaction( @JvmStatic @JvmOverloads fun transaction(
transactionLabel: String? = null, transactionLabel: String? = null,
transactionBuilder: TransactionBuilder = TransactionBuilder(),
dsl: TransactionDSL< dsl: TransactionDSL<
EnforceVerifyOrFail, EnforceVerifyOrFail,
TransactionDSLInterpreter<EnforceVerifyOrFail> TransactionDSLInterpreter<EnforceVerifyOrFail>
>.() -> EnforceVerifyOrFail >.() -> EnforceVerifyOrFail
) = ledger { transaction(transactionLabel, dsl) } ) = ledger { this.transaction(transactionLabel, transactionBuilder, dsl) }
} }
val TEST_TX_TIME = JavaTestHelpers.TEST_TX_TIME val TEST_TX_TIME = JavaTestHelpers.TEST_TX_TIME

View File

@ -59,7 +59,7 @@ class TransactionDSL<R, out T : TransactionDSLInterpreter<R>> (val interpreter:
* Adds the passed in state as a non-verified transaction output to the ledger and adds that as an input. * Adds the passed in state as a non-verified transaction output to the ledger and adds that as an input.
*/ */
fun input(state: ContractState) { fun input(state: ContractState) {
val transaction = ledgerInterpreter.unverifiedTransaction(null) { val transaction = ledgerInterpreter._unverifiedTransaction(null, TransactionBuilder()) {
output { state } output { state }
} }
input(transaction.outRef<ContractState>(0).ref) input(transaction.outRef<ContractState>(0).ref)

View File

@ -234,7 +234,7 @@ class AttachmentClassLoaderTests {
val attachmentRef = importJar(storage) val attachmentRef = importJar(storage)
tx.addAttachment(storage.openAttachment(attachmentRef)!!) tx.addAttachment(storage.openAttachment(attachmentRef)!!.id)
val wireTransaction = tx.toWireTransaction() val wireTransaction = tx.toWireTransaction()
@ -265,7 +265,7 @@ class AttachmentClassLoaderTests {
val attachmentRef = importJar(storage) val attachmentRef = importJar(storage)
tx.addAttachment(storage.openAttachment(attachmentRef)!!) tx.addAttachment(storage.openAttachment(attachmentRef)!!.id)
val wireTransaction = tx.toWireTransaction() val wireTransaction = tx.toWireTransaction()

View File

@ -341,7 +341,7 @@ private class TraderDemoProtocolSeller(val otherSide: Party,
// TODO: Consider moving these two steps below into generateIssue. // TODO: Consider moving these two steps below into generateIssue.
// Attach the prospectus. // Attach the prospectus.
tx.addAttachment(serviceHub.storageService.attachments.openAttachment(PROSPECTUS_HASH)!!) tx.addAttachment(serviceHub.storageService.attachments.openAttachment(PROSPECTUS_HASH)!!.id)
// Requesting timestamping, all CP must be timestamped. // Requesting timestamping, all CP must be timestamped.
tx.setTime(Instant.now(), notaryNode.identity, 30.seconds) tx.setTime(Instant.now(), notaryNode.identity, 30.seconds)