Minor: some small serialisation type safety improvements

This commit is contained in:
Mike Hearn 2015-12-15 19:48:29 +01:00
parent 5863d489bc
commit ab9e026053
9 changed files with 66 additions and 46 deletions

View File

@ -9,6 +9,7 @@
package core
import com.google.common.io.BaseEncoding
import core.serialization.OpaqueBytes
import java.math.BigInteger
import java.security.*

View File

@ -8,6 +8,7 @@
package core
import core.serialization.OpaqueBytes
import core.serialization.serialize
import java.security.PublicKey
@ -34,7 +35,7 @@ interface OwnableState : ContractState {
}
/** Returns the SHA-256 hash of the serialised contents of this state (not cached!) */
fun ContractState.hash(): SecureHash = SecureHash.sha256((serialize()))
fun ContractState.hash(): SecureHash = SecureHash.sha256(serialize().bits)
/**
* A stateref is a pointer to a state, this is an equivalent of an "outpoint" in Bitcoin. It records which transaction

View File

@ -8,8 +8,7 @@
package core
import core.serialization.deserialize
import core.serialization.serialize
import core.serialization.*
import java.security.KeyPair
import java.security.PublicKey
import java.security.SignatureException
@ -21,7 +20,7 @@ import java.util.*
* tree passed into a contract.
*
* TimestampedWireTransaction wraps a serialized SignedWireTransaction. The timestamp is a signature from a timestamping
* authority and is what gives the contract a sense of time. This isn't used yet.
* authority and is what gives the contract a sense of time. This arrangement may change in future.
*
* SignedWireTransaction wraps a serialized WireTransaction. It contains one or more ECDSA signatures, each one from
* a public key that is mentioned inside a transaction command.
@ -56,8 +55,6 @@ data class WireCommand(val command: Command, val pubkeys: List<PublicKey>) {
data class WireTransaction(val inputStates: List<ContractStateRef>,
val outputStates: List<ContractState>,
val commands: List<WireCommand>) {
fun serializeForSignature(): ByteArray = serialize()
fun toLedgerTransaction(timestamp: Instant?, identityService: IdentityService, originalHash: SecureHash): LedgerTransaction {
val authenticatedArgs = commands.map {
val institutions = it.pubkeys.mapNotNull { pk -> identityService.partyFromKey(pk) }
@ -95,8 +92,8 @@ class PartialTransaction(private val inputStates: MutableList<ContractStateRef>
fun signWith(key: KeyPair) {
check(currentSigs.none { it.by == key.public }) { "This partial transaction was already signed by ${key.public}" }
check(commands.count { it.pubkeys.contains(key.public) } > 0) { "Trying to sign with a key that isn't in any command" }
val bits = toWireTransaction().serializeForSignature()
currentSigs.add(key.private.signWithECDSA(bits, key.public))
val data = toWireTransaction().serialize()
currentSigs.add(key.private.signWithECDSA(data.bits, key.public))
}
fun toWireTransaction() = WireTransaction(inputStates, outputStates, commands)
@ -107,7 +104,7 @@ class PartialTransaction(private val inputStates: MutableList<ContractStateRef>
val gotKeys = currentSigs.map { it.by }.toSet()
check(gotKeys == requiredKeys) { "The set of required signatures isn't equal to the signatures we've got" }
}
return SignedWireTransaction(toWireTransaction().serialize().opaque(), ArrayList(currentSigs))
return SignedWireTransaction(toWireTransaction().serialize(), ArrayList(currentSigs))
}
fun addInputState(ref: ContractStateRef) {
@ -133,7 +130,7 @@ class PartialTransaction(private val inputStates: MutableList<ContractStateRef>
fun commands(): List<WireCommand> = ArrayList(commands)
}
data class SignedWireTransaction(val txBits: OpaqueBytes, val sigs: List<DigitalSignature.WithKey>) {
data class SignedWireTransaction(val txBits: SerializedBytes<WireTransaction>, val sigs: List<DigitalSignature.WithKey>) {
init {
check(sigs.isNotEmpty())
}
@ -158,7 +155,7 @@ data class SignedWireTransaction(val txBits: OpaqueBytes, val sigs: List<Digital
*/
fun verify(): WireTransaction {
verifySignatures()
val wtx = txBits.deserialize<WireTransaction>()
val wtx = txBits.deserialize()
// Verify that every command key was in the set that we just verified: there should be no commands that were
// unverified.
val cmdKeys = wtx.commands.flatMap { it.pubkeys }.toSet()
@ -171,11 +168,11 @@ data class SignedWireTransaction(val txBits: OpaqueBytes, val sigs: List<Digital
/** Uses the given timestamper service to calculate a signed timestamp and then returns a wrapper for both */
fun toTimestampedTransaction(timestamper: TimestamperService): TimestampedWireTransaction {
val bits = serialize()
return TimestampedWireTransaction(bits.opaque(), timestamper.timestamp(bits.sha256()).opaque())
return TimestampedWireTransaction(bits, timestamper.timestamp(bits.sha256()).opaque())
}
/** Returns a [TimestampedWireTransaction] with an empty byte array as the timestamp: this means, no time was provided. */
fun toTimestampedTransactionWithoutTime() = TimestampedWireTransaction(serialize().opaque(), null)
fun toTimestampedTransactionWithoutTime() = TimestampedWireTransaction(serialize(), null)
}
/**
@ -184,7 +181,7 @@ data class SignedWireTransaction(val txBits: OpaqueBytes, val sigs: List<Digital
*/
data class TimestampedWireTransaction(
/** A serialised SignedWireTransaction */
val signedWireTX: OpaqueBytes,
val signedWireTXBytes: SerializedBytes<SignedWireTransaction>,
/** Signature from a timestamping authority. For instance using RFC 3161 */
val timestamp: OpaqueBytes?
@ -192,9 +189,13 @@ data class TimestampedWireTransaction(
val transactionID: SecureHash = serialize().sha256()
fun verifyToLedgerTransaction(timestamper: TimestamperService, identityService: IdentityService): LedgerTransaction {
val stx: SignedWireTransaction = signedWireTX.deserialize()
val stx = signedWireTXBytes.deserialize()
val wtx: WireTransaction = stx.verify()
val instant: Instant? = if (timestamp != null) timestamper.verifyTimestamp(signedWireTX.sha256(), timestamp.bits) else null
val instant: Instant? =
if (timestamp != null)
timestamper.verifyTimestamp(signedWireTXBytes.sha256(), timestamp.bits)
else
null
return wtx.toLedgerTransaction(instant, identityService, transactionID)
}
}

View File

@ -8,38 +8,14 @@
package core
import com.google.common.io.BaseEncoding
import com.google.common.util.concurrent.ListenableFuture
import com.google.common.util.concurrent.MoreExecutors
import com.google.common.util.concurrent.SettableFuture
import org.slf4j.Logger
import java.security.SecureRandom
import java.time.Duration
import java.util.*
import java.util.concurrent.Executor
/** A simple class that wraps a byte array and makes the equals/hashCode/toString methods work as you actually expect */
open class OpaqueBytes(val bits: ByteArray) {
init { check(bits.isNotEmpty()) }
companion object {
fun of(vararg b: Byte) = OpaqueBytes(byteArrayOf(*b))
}
override fun equals(other: Any?): Boolean{
if (this === other) return true
if (other !is OpaqueBytes) return false
return Arrays.equals(bits, other.bits)
}
override fun hashCode() = Arrays.hashCode(bits)
override fun toString() = "[" + BaseEncoding.base16().encode(bits) + "]"
val size: Int get() = bits.size
}
fun ByteArray.opaque(): OpaqueBytes = OpaqueBytes(this)
val Int.days: Duration get() = Duration.ofDays(this.toLong())
val Int.hours: Duration get() = Duration.ofHours(this.toLong())
val Int.minutes: Duration get() = Duration.ofMinutes(this.toLong())

View File

@ -79,7 +79,7 @@ fun MessagingService.runOnNextMessage(topic: String = "", executor: Executor? =
}
}
fun MessagingService.send(topic: String, to: MessageRecipients, obj: Any) = send(createMessage(topic, obj.serialize()), to)
fun MessagingService.send(topic: String, to: MessageRecipients, obj: Any) = send(createMessage(topic, obj.serialize().bits), to)
/**
* This class lets you start up a [MessagingService]. Its purpose is to stop you from getting access to the methods

View File

@ -196,7 +196,7 @@ class StateMachineManager(val serviceHub: ServiceHub, val runInThread: Executor)
req.obj?.let {
val topic = "${req.topic}.${req.sessionIDForSend}"
logger.trace { "-> $topic : message of type ${it.javaClass.name}" }
net.send(net.createMessage(topic, it.serialize()), otherSide)
net.send(net.createMessage(topic, it.serialize().bits), otherSide)
}
if (req is ContinuationResult.NotExpectingResponse) {
// We sent a message, but don't expect a response, so re-enter the continuation to let it keep going.
@ -210,7 +210,7 @@ class StateMachineManager(val serviceHub: ServiceHub, val runInThread: Executor)
otherSide: MessageRecipients, responseType: Class<*>,
topic: String, prevPersistedBytes: ByteArray?) {
val checkpoint = Checkpoint(nextState, otherSide, logger.name, topic, responseType.name)
val curPersistedBytes = checkpoint.serialize()
val curPersistedBytes = checkpoint.serialize().bits
persistCheckpoint(prevPersistedBytes, curPersistedBytes)
net.runOnNextMessage(topic, runInThread) { netMsg ->
val obj: Any = THREAD_LOCAL_KRYO.get().readObject(Input(netMsg.data), responseType)

View File

@ -0,0 +1,41 @@
/*
* Copyright 2015 Distributed Ledger Group LLC. Distributed as Licensed Company IP to DLG Group Members
* pursuant to the August 7, 2015 Advisory Services Agreement and subject to the Company IP License terms
* set forth therein.
*
* All other rights reserved.
*/
package core.serialization
import com.google.common.io.BaseEncoding
import java.util.*
/**
* A simple class that wraps a byte array and makes the equals/hashCode/toString methods work as you actually expect.
* In an ideal JVM this would be a value type and be completely overhead free. Project Valhalla is adding such
* functionality to Java, but it won't arrive for a few years yet!
*/
open class OpaqueBytes(val bits: ByteArray) {
init { check(bits.isNotEmpty()) }
companion object {
fun of(vararg b: Byte) = OpaqueBytes(byteArrayOf(*b))
}
override fun equals(other: Any?): Boolean{
if (this === other) return true
if (other !is OpaqueBytes) return false
return Arrays.equals(bits, other.bits)
}
override fun hashCode() = Arrays.hashCode(bits)
override fun toString() = "[" + BaseEncoding.base16().encode(bits) + "]"
val size: Int get() = bits.size
}
class SerializedBytes<T : Any>(bits: ByteArray) : OpaqueBytes(bits)
fun ByteArray.opaque(): OpaqueBytes = OpaqueBytes(this)
inline fun <reified T : Any> SerializedBytes<T>.deserialize(): T = bits.deserialize()

View File

@ -11,7 +11,6 @@ package core.serialization
import com.esotericsoftware.kryo.Kryo
import com.esotericsoftware.kryo.io.Input
import com.esotericsoftware.kryo.io.Output
import core.OpaqueBytes
import de.javakaffee.kryoserializers.ArraysAsListSerializer
import org.objenesis.strategy.StdInstantiatorStrategy
import java.io.ByteArrayOutputStream
@ -50,12 +49,12 @@ val THREAD_LOCAL_KRYO = ThreadLocal.withInitial { createKryo() }
inline fun <reified T : Any> ByteArray.deserialize(kryo: Kryo = THREAD_LOCAL_KRYO.get()): T = kryo.readObject(Input(this), T::class.java)
inline fun <reified T : Any> OpaqueBytes.deserialize(kryo: Kryo = THREAD_LOCAL_KRYO.get()): T = kryo.readObject(Input(this.bits), T::class.java)
fun Any.serialize(kryo: Kryo = THREAD_LOCAL_KRYO.get()): ByteArray {
fun <T : Any> T.serialize(kryo: Kryo = THREAD_LOCAL_KRYO.get()): SerializedBytes<T> {
val stream = ByteArrayOutputStream()
Output(stream).use {
kryo.writeObject(it, this)
}
return stream.toByteArray()
return SerializedBytes<T>(stream.toByteArray())
}
fun createKryo(): Kryo {

View File

@ -14,6 +14,7 @@ import contracts.Cash
import contracts.DummyContract
import contracts.InsufficientBalanceException
import core.*
import core.serialization.OpaqueBytes
import core.testutils.*
import org.junit.Test
import java.security.PublicKey