mirror of
https://github.com/corda/corda.git
synced 2025-01-18 02:39:51 +00:00
Minor: some small serialisation type safety improvements
This commit is contained in:
parent
5863d489bc
commit
ab9e026053
@ -9,6 +9,7 @@
|
||||
package core
|
||||
|
||||
import com.google.common.io.BaseEncoding
|
||||
import core.serialization.OpaqueBytes
|
||||
import java.math.BigInteger
|
||||
import java.security.*
|
||||
|
||||
|
@ -8,6 +8,7 @@
|
||||
|
||||
package core
|
||||
|
||||
import core.serialization.OpaqueBytes
|
||||
import core.serialization.serialize
|
||||
import java.security.PublicKey
|
||||
|
||||
@ -34,7 +35,7 @@ interface OwnableState : ContractState {
|
||||
}
|
||||
|
||||
/** Returns the SHA-256 hash of the serialised contents of this state (not cached!) */
|
||||
fun ContractState.hash(): SecureHash = SecureHash.sha256((serialize()))
|
||||
fun ContractState.hash(): SecureHash = SecureHash.sha256(serialize().bits)
|
||||
|
||||
/**
|
||||
* A stateref is a pointer to a state, this is an equivalent of an "outpoint" in Bitcoin. It records which transaction
|
||||
|
@ -8,8 +8,7 @@
|
||||
|
||||
package core
|
||||
|
||||
import core.serialization.deserialize
|
||||
import core.serialization.serialize
|
||||
import core.serialization.*
|
||||
import java.security.KeyPair
|
||||
import java.security.PublicKey
|
||||
import java.security.SignatureException
|
||||
@ -21,7 +20,7 @@ import java.util.*
|
||||
* tree passed into a contract.
|
||||
*
|
||||
* TimestampedWireTransaction wraps a serialized SignedWireTransaction. The timestamp is a signature from a timestamping
|
||||
* authority and is what gives the contract a sense of time. This isn't used yet.
|
||||
* authority and is what gives the contract a sense of time. This arrangement may change in future.
|
||||
*
|
||||
* SignedWireTransaction wraps a serialized WireTransaction. It contains one or more ECDSA signatures, each one from
|
||||
* a public key that is mentioned inside a transaction command.
|
||||
@ -56,8 +55,6 @@ data class WireCommand(val command: Command, val pubkeys: List<PublicKey>) {
|
||||
data class WireTransaction(val inputStates: List<ContractStateRef>,
|
||||
val outputStates: List<ContractState>,
|
||||
val commands: List<WireCommand>) {
|
||||
fun serializeForSignature(): ByteArray = serialize()
|
||||
|
||||
fun toLedgerTransaction(timestamp: Instant?, identityService: IdentityService, originalHash: SecureHash): LedgerTransaction {
|
||||
val authenticatedArgs = commands.map {
|
||||
val institutions = it.pubkeys.mapNotNull { pk -> identityService.partyFromKey(pk) }
|
||||
@ -95,8 +92,8 @@ class PartialTransaction(private val inputStates: MutableList<ContractStateRef>
|
||||
fun signWith(key: KeyPair) {
|
||||
check(currentSigs.none { it.by == key.public }) { "This partial transaction was already signed by ${key.public}" }
|
||||
check(commands.count { it.pubkeys.contains(key.public) } > 0) { "Trying to sign with a key that isn't in any command" }
|
||||
val bits = toWireTransaction().serializeForSignature()
|
||||
currentSigs.add(key.private.signWithECDSA(bits, key.public))
|
||||
val data = toWireTransaction().serialize()
|
||||
currentSigs.add(key.private.signWithECDSA(data.bits, key.public))
|
||||
}
|
||||
|
||||
fun toWireTransaction() = WireTransaction(inputStates, outputStates, commands)
|
||||
@ -107,7 +104,7 @@ class PartialTransaction(private val inputStates: MutableList<ContractStateRef>
|
||||
val gotKeys = currentSigs.map { it.by }.toSet()
|
||||
check(gotKeys == requiredKeys) { "The set of required signatures isn't equal to the signatures we've got" }
|
||||
}
|
||||
return SignedWireTransaction(toWireTransaction().serialize().opaque(), ArrayList(currentSigs))
|
||||
return SignedWireTransaction(toWireTransaction().serialize(), ArrayList(currentSigs))
|
||||
}
|
||||
|
||||
fun addInputState(ref: ContractStateRef) {
|
||||
@ -133,7 +130,7 @@ class PartialTransaction(private val inputStates: MutableList<ContractStateRef>
|
||||
fun commands(): List<WireCommand> = ArrayList(commands)
|
||||
}
|
||||
|
||||
data class SignedWireTransaction(val txBits: OpaqueBytes, val sigs: List<DigitalSignature.WithKey>) {
|
||||
data class SignedWireTransaction(val txBits: SerializedBytes<WireTransaction>, val sigs: List<DigitalSignature.WithKey>) {
|
||||
init {
|
||||
check(sigs.isNotEmpty())
|
||||
}
|
||||
@ -158,7 +155,7 @@ data class SignedWireTransaction(val txBits: OpaqueBytes, val sigs: List<Digital
|
||||
*/
|
||||
fun verify(): WireTransaction {
|
||||
verifySignatures()
|
||||
val wtx = txBits.deserialize<WireTransaction>()
|
||||
val wtx = txBits.deserialize()
|
||||
// Verify that every command key was in the set that we just verified: there should be no commands that were
|
||||
// unverified.
|
||||
val cmdKeys = wtx.commands.flatMap { it.pubkeys }.toSet()
|
||||
@ -171,11 +168,11 @@ data class SignedWireTransaction(val txBits: OpaqueBytes, val sigs: List<Digital
|
||||
/** Uses the given timestamper service to calculate a signed timestamp and then returns a wrapper for both */
|
||||
fun toTimestampedTransaction(timestamper: TimestamperService): TimestampedWireTransaction {
|
||||
val bits = serialize()
|
||||
return TimestampedWireTransaction(bits.opaque(), timestamper.timestamp(bits.sha256()).opaque())
|
||||
return TimestampedWireTransaction(bits, timestamper.timestamp(bits.sha256()).opaque())
|
||||
}
|
||||
|
||||
/** Returns a [TimestampedWireTransaction] with an empty byte array as the timestamp: this means, no time was provided. */
|
||||
fun toTimestampedTransactionWithoutTime() = TimestampedWireTransaction(serialize().opaque(), null)
|
||||
fun toTimestampedTransactionWithoutTime() = TimestampedWireTransaction(serialize(), null)
|
||||
}
|
||||
|
||||
/**
|
||||
@ -184,7 +181,7 @@ data class SignedWireTransaction(val txBits: OpaqueBytes, val sigs: List<Digital
|
||||
*/
|
||||
data class TimestampedWireTransaction(
|
||||
/** A serialised SignedWireTransaction */
|
||||
val signedWireTX: OpaqueBytes,
|
||||
val signedWireTXBytes: SerializedBytes<SignedWireTransaction>,
|
||||
|
||||
/** Signature from a timestamping authority. For instance using RFC 3161 */
|
||||
val timestamp: OpaqueBytes?
|
||||
@ -192,9 +189,13 @@ data class TimestampedWireTransaction(
|
||||
val transactionID: SecureHash = serialize().sha256()
|
||||
|
||||
fun verifyToLedgerTransaction(timestamper: TimestamperService, identityService: IdentityService): LedgerTransaction {
|
||||
val stx: SignedWireTransaction = signedWireTX.deserialize()
|
||||
val stx = signedWireTXBytes.deserialize()
|
||||
val wtx: WireTransaction = stx.verify()
|
||||
val instant: Instant? = if (timestamp != null) timestamper.verifyTimestamp(signedWireTX.sha256(), timestamp.bits) else null
|
||||
val instant: Instant? =
|
||||
if (timestamp != null)
|
||||
timestamper.verifyTimestamp(signedWireTXBytes.sha256(), timestamp.bits)
|
||||
else
|
||||
null
|
||||
return wtx.toLedgerTransaction(instant, identityService, transactionID)
|
||||
}
|
||||
}
|
||||
|
@ -8,38 +8,14 @@
|
||||
|
||||
package core
|
||||
|
||||
import com.google.common.io.BaseEncoding
|
||||
import com.google.common.util.concurrent.ListenableFuture
|
||||
import com.google.common.util.concurrent.MoreExecutors
|
||||
import com.google.common.util.concurrent.SettableFuture
|
||||
import org.slf4j.Logger
|
||||
import java.security.SecureRandom
|
||||
import java.time.Duration
|
||||
import java.util.*
|
||||
import java.util.concurrent.Executor
|
||||
|
||||
/** A simple class that wraps a byte array and makes the equals/hashCode/toString methods work as you actually expect */
|
||||
open class OpaqueBytes(val bits: ByteArray) {
|
||||
init { check(bits.isNotEmpty()) }
|
||||
|
||||
companion object {
|
||||
fun of(vararg b: Byte) = OpaqueBytes(byteArrayOf(*b))
|
||||
}
|
||||
|
||||
override fun equals(other: Any?): Boolean{
|
||||
if (this === other) return true
|
||||
if (other !is OpaqueBytes) return false
|
||||
return Arrays.equals(bits, other.bits)
|
||||
}
|
||||
|
||||
override fun hashCode() = Arrays.hashCode(bits)
|
||||
override fun toString() = "[" + BaseEncoding.base16().encode(bits) + "]"
|
||||
|
||||
val size: Int get() = bits.size
|
||||
}
|
||||
|
||||
fun ByteArray.opaque(): OpaqueBytes = OpaqueBytes(this)
|
||||
|
||||
val Int.days: Duration get() = Duration.ofDays(this.toLong())
|
||||
val Int.hours: Duration get() = Duration.ofHours(this.toLong())
|
||||
val Int.minutes: Duration get() = Duration.ofMinutes(this.toLong())
|
||||
|
@ -79,7 +79,7 @@ fun MessagingService.runOnNextMessage(topic: String = "", executor: Executor? =
|
||||
}
|
||||
}
|
||||
|
||||
fun MessagingService.send(topic: String, to: MessageRecipients, obj: Any) = send(createMessage(topic, obj.serialize()), to)
|
||||
fun MessagingService.send(topic: String, to: MessageRecipients, obj: Any) = send(createMessage(topic, obj.serialize().bits), to)
|
||||
|
||||
/**
|
||||
* This class lets you start up a [MessagingService]. Its purpose is to stop you from getting access to the methods
|
||||
|
@ -196,7 +196,7 @@ class StateMachineManager(val serviceHub: ServiceHub, val runInThread: Executor)
|
||||
req.obj?.let {
|
||||
val topic = "${req.topic}.${req.sessionIDForSend}"
|
||||
logger.trace { "-> $topic : message of type ${it.javaClass.name}" }
|
||||
net.send(net.createMessage(topic, it.serialize()), otherSide)
|
||||
net.send(net.createMessage(topic, it.serialize().bits), otherSide)
|
||||
}
|
||||
if (req is ContinuationResult.NotExpectingResponse) {
|
||||
// We sent a message, but don't expect a response, so re-enter the continuation to let it keep going.
|
||||
@ -210,7 +210,7 @@ class StateMachineManager(val serviceHub: ServiceHub, val runInThread: Executor)
|
||||
otherSide: MessageRecipients, responseType: Class<*>,
|
||||
topic: String, prevPersistedBytes: ByteArray?) {
|
||||
val checkpoint = Checkpoint(nextState, otherSide, logger.name, topic, responseType.name)
|
||||
val curPersistedBytes = checkpoint.serialize()
|
||||
val curPersistedBytes = checkpoint.serialize().bits
|
||||
persistCheckpoint(prevPersistedBytes, curPersistedBytes)
|
||||
net.runOnNextMessage(topic, runInThread) { netMsg ->
|
||||
val obj: Any = THREAD_LOCAL_KRYO.get().readObject(Input(netMsg.data), responseType)
|
||||
|
41
src/main/kotlin/core/serialization/ByteArrays.kt
Normal file
41
src/main/kotlin/core/serialization/ByteArrays.kt
Normal file
@ -0,0 +1,41 @@
|
||||
/*
|
||||
* Copyright 2015 Distributed Ledger Group LLC. Distributed as Licensed Company IP to DLG Group Members
|
||||
* pursuant to the August 7, 2015 Advisory Services Agreement and subject to the Company IP License terms
|
||||
* set forth therein.
|
||||
*
|
||||
* All other rights reserved.
|
||||
*/
|
||||
|
||||
package core.serialization
|
||||
|
||||
import com.google.common.io.BaseEncoding
|
||||
import java.util.*
|
||||
|
||||
/**
|
||||
* A simple class that wraps a byte array and makes the equals/hashCode/toString methods work as you actually expect.
|
||||
* In an ideal JVM this would be a value type and be completely overhead free. Project Valhalla is adding such
|
||||
* functionality to Java, but it won't arrive for a few years yet!
|
||||
*/
|
||||
open class OpaqueBytes(val bits: ByteArray) {
|
||||
init { check(bits.isNotEmpty()) }
|
||||
|
||||
companion object {
|
||||
fun of(vararg b: Byte) = OpaqueBytes(byteArrayOf(*b))
|
||||
}
|
||||
|
||||
override fun equals(other: Any?): Boolean{
|
||||
if (this === other) return true
|
||||
if (other !is OpaqueBytes) return false
|
||||
return Arrays.equals(bits, other.bits)
|
||||
}
|
||||
|
||||
override fun hashCode() = Arrays.hashCode(bits)
|
||||
override fun toString() = "[" + BaseEncoding.base16().encode(bits) + "]"
|
||||
|
||||
val size: Int get() = bits.size
|
||||
}
|
||||
|
||||
class SerializedBytes<T : Any>(bits: ByteArray) : OpaqueBytes(bits)
|
||||
|
||||
fun ByteArray.opaque(): OpaqueBytes = OpaqueBytes(this)
|
||||
inline fun <reified T : Any> SerializedBytes<T>.deserialize(): T = bits.deserialize()
|
@ -11,7 +11,6 @@ package core.serialization
|
||||
import com.esotericsoftware.kryo.Kryo
|
||||
import com.esotericsoftware.kryo.io.Input
|
||||
import com.esotericsoftware.kryo.io.Output
|
||||
import core.OpaqueBytes
|
||||
import de.javakaffee.kryoserializers.ArraysAsListSerializer
|
||||
import org.objenesis.strategy.StdInstantiatorStrategy
|
||||
import java.io.ByteArrayOutputStream
|
||||
@ -50,12 +49,12 @@ val THREAD_LOCAL_KRYO = ThreadLocal.withInitial { createKryo() }
|
||||
inline fun <reified T : Any> ByteArray.deserialize(kryo: Kryo = THREAD_LOCAL_KRYO.get()): T = kryo.readObject(Input(this), T::class.java)
|
||||
inline fun <reified T : Any> OpaqueBytes.deserialize(kryo: Kryo = THREAD_LOCAL_KRYO.get()): T = kryo.readObject(Input(this.bits), T::class.java)
|
||||
|
||||
fun Any.serialize(kryo: Kryo = THREAD_LOCAL_KRYO.get()): ByteArray {
|
||||
fun <T : Any> T.serialize(kryo: Kryo = THREAD_LOCAL_KRYO.get()): SerializedBytes<T> {
|
||||
val stream = ByteArrayOutputStream()
|
||||
Output(stream).use {
|
||||
kryo.writeObject(it, this)
|
||||
}
|
||||
return stream.toByteArray()
|
||||
return SerializedBytes<T>(stream.toByteArray())
|
||||
}
|
||||
|
||||
fun createKryo(): Kryo {
|
||||
|
@ -14,6 +14,7 @@ import contracts.Cash
|
||||
import contracts.DummyContract
|
||||
import contracts.InsufficientBalanceException
|
||||
import core.*
|
||||
import core.serialization.OpaqueBytes
|
||||
import core.testutils.*
|
||||
import org.junit.Test
|
||||
import java.security.PublicKey
|
||||
|
Loading…
Reference in New Issue
Block a user