mirror of
https://github.com/corda/corda.git
synced 2025-01-29 15:43:55 +00:00
Switch Kryo serialisation to always include the class name, and avoid overhead from writing out SerializedBytes wrapper data.
This simplifies the serialisation code, reduces the use of inline functions, and ensures that running SerializedBytes<SuperClass>.deserialise() will correctly return SubClass if that's what it contained, efficiently.
This commit is contained in:
parent
105f39adb5
commit
2de44a516f
@ -69,34 +69,44 @@ class SerializedBytes<T : Any>(bits: ByteArray) : OpaqueBytes(bits) {
|
||||
}
|
||||
|
||||
// Some extension functions that make deserialisation convenient and provide auto-casting of the result.
|
||||
inline fun <reified T : Any> ByteArray.deserialize(kryo: Kryo = THREAD_LOCAL_KRYO.get(), includeClassName: Boolean = false): T {
|
||||
if (includeClassName)
|
||||
return kryo.readClassAndObject(Input(this)) as T
|
||||
else
|
||||
return kryo.readObject(Input(this), T::class.java)
|
||||
fun <T : Any> ByteArray.deserialize(kryo: Kryo = THREAD_LOCAL_KRYO.get()): T {
|
||||
@Suppress("UNCHECKED_CAST")
|
||||
return kryo.readClassAndObject(Input(this)) as T
|
||||
}
|
||||
|
||||
inline fun <reified T : Any> OpaqueBytes.deserialize(kryo: Kryo = THREAD_LOCAL_KRYO.get(), includeClassName: Boolean = false): T {
|
||||
return this.bits.deserialize(kryo, includeClassName)
|
||||
fun <T : Any> OpaqueBytes.deserialize(kryo: Kryo = THREAD_LOCAL_KRYO.get()): T {
|
||||
return this.bits.deserialize(kryo)
|
||||
}
|
||||
|
||||
// The more specific deserialize version results in the bytes being cached, which is faster.
|
||||
@JvmName("SerializedBytesWireTransaction")
|
||||
fun SerializedBytes<WireTransaction>.deserialize(kryo: Kryo = THREAD_LOCAL_KRYO.get()): WireTransaction = WireTransaction.deserialize(this, kryo)
|
||||
|
||||
inline fun <reified T : Any> SerializedBytes<T>.deserialize(kryo: Kryo = THREAD_LOCAL_KRYO.get(), includeClassName: Boolean = false): T = bits.deserialize(kryo, includeClassName)
|
||||
fun <T : Any> SerializedBytes<T>.deserialize(kryo: Kryo = THREAD_LOCAL_KRYO.get()): T = bits.deserialize(kryo)
|
||||
|
||||
/**
|
||||
* A serialiser that avoids writing the wrapper class to the byte stream, thus ensuring [SerializedBytes] is a pure
|
||||
* type safety hack.
|
||||
*/
|
||||
object SerializedBytesSerializer : Serializer<SerializedBytes<Any>>() {
|
||||
override fun write(kryo: Kryo, output: Output, obj: SerializedBytes<Any>) {
|
||||
output.writeVarInt(obj.bits.size, true)
|
||||
output.writeBytes(obj.bits)
|
||||
}
|
||||
|
||||
override fun read(kryo: Kryo, input: Input, type: Class<SerializedBytes<Any>>): SerializedBytes<Any> {
|
||||
return SerializedBytes(input.readBytes(input.readVarInt(true)))
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Can be called on any object to convert it to a byte array (wrapped by [SerializedBytes]), regardless of whether
|
||||
* the type is marked as serializable or was designed for it (so be careful!)
|
||||
*/
|
||||
fun <T : Any> T.serialize(kryo: Kryo = THREAD_LOCAL_KRYO.get(), includeClassName: Boolean = false): SerializedBytes<T> {
|
||||
fun <T : Any> T.serialize(kryo: Kryo = THREAD_LOCAL_KRYO.get()): SerializedBytes<T> {
|
||||
val stream = ByteArrayOutputStream()
|
||||
Output(stream).use {
|
||||
if (includeClassName)
|
||||
kryo.writeClassAndObject(it, this)
|
||||
else
|
||||
kryo.writeObject(it, this)
|
||||
kryo.writeClassAndObject(it, this)
|
||||
}
|
||||
return SerializedBytes(stream.toByteArray())
|
||||
}
|
||||
@ -258,8 +268,6 @@ fun createKryo(k: Kryo = Kryo()): Kryo {
|
||||
}
|
||||
})
|
||||
|
||||
register(WireTransaction::class.java, WireTransactionSerializer)
|
||||
|
||||
// Some things where the JRE provides an efficient custom serialisation.
|
||||
val ser = JavaSerializer()
|
||||
val keyPair = generateKeyPair()
|
||||
@ -269,16 +277,13 @@ fun createKryo(k: Kryo = Kryo()): Kryo {
|
||||
|
||||
// Some classes have to be handled with the ImmutableClassSerializer because they need to have their
|
||||
// constructors be invoked (typically for lazy members).
|
||||
val immutables = listOf(
|
||||
SignedTransaction::class,
|
||||
SerializedBytes::class
|
||||
)
|
||||
register(SignedTransaction::class.java, ImmutableClassSerializer(SignedTransaction::class))
|
||||
|
||||
immutables.forEach {
|
||||
register(it.java, ImmutableClassSerializer(it))
|
||||
}
|
||||
// This class has special handling.
|
||||
register(WireTransaction::class.java, WireTransactionSerializer)
|
||||
|
||||
// TODO: See if we can make Lazy<T> serialize properly so we can use "by lazy" in serialized object.
|
||||
// This ensures a SerializedBytes<Foo> wrapper is written out as just a byte array.
|
||||
register(SerializedBytes::class.java, SerializedBytesSerializer)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -74,8 +74,8 @@ fun MessagingService.runOnNextMessage(topic: String = "", executor: Executor? =
|
||||
}
|
||||
}
|
||||
|
||||
fun MessagingService.send(topic: String, to: MessageRecipients, obj: Any, includeClassName: Boolean = false) {
|
||||
send(createMessage(topic, obj.serialize(includeClassName = includeClassName).bits), to)
|
||||
fun MessagingService.send(topic: String, to: MessageRecipients, obj: Any) {
|
||||
send(createMessage(topic, obj.serialize().bits), to)
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -115,7 +115,10 @@ class StateMachineManager(val serviceHub: ServiceHub, val runInThread: Executor)
|
||||
|
||||
// And now re-wire the deserialised continuation back up to the network service.
|
||||
serviceHub.networkService.runOnNextMessage(topic, runInThread) { netMsg ->
|
||||
val obj: Any = THREAD_LOCAL_KRYO.get().readObject(Input(netMsg.data), awaitingObjectOfType)
|
||||
// TODO: See security note below.
|
||||
val obj: Any = THREAD_LOCAL_KRYO.get().readClassAndObject(Input(netMsg.data))
|
||||
if (!awaitingObjectOfType.isInstance(obj))
|
||||
throw ClassCastException("Received message of unexpected type: ${obj.javaClass.name} vs ${awaitingObjectOfType.name}")
|
||||
logger.trace { "<- $topic : message of type ${obj.javaClass.name}" }
|
||||
iterateStateMachine(psm, serviceHub.networkService, logger, obj, checkpointKey) {
|
||||
try {
|
||||
@ -221,7 +224,16 @@ class StateMachineManager(val serviceHub: ServiceHub, val runInThread: Executor)
|
||||
persistCheckpoint(prevCheckpointKey, curPersistedBytes)
|
||||
val newCheckpointKey = curPersistedBytes.sha256()
|
||||
net.runOnNextMessage(topic, runInThread) { netMsg ->
|
||||
val obj: Any = THREAD_LOCAL_KRYO.get().readObject(Input(netMsg.data), responseType)
|
||||
// TODO: This is insecure: we should not deserialise whatever we find and *then* check.
|
||||
//
|
||||
// We should instead verify as we read the data that it's what we are expecting and throw as early as
|
||||
// possible. We only do it this way for convenience during the prototyping stage. Note that this means
|
||||
// we could simply not require the programmer to specify the expected return type at all, and catch it
|
||||
// at the last moment when we do the downcast. However this would make protocol code harder to read and
|
||||
// make it more difficult to migrate to a more explicit serialisation scheme later.
|
||||
val obj: Any = THREAD_LOCAL_KRYO.get().readClassAndObject(Input(netMsg.data))
|
||||
if (!responseType.isInstance(obj))
|
||||
throw ClassCastException("Expected message of type ${responseType.name} but got ${obj.javaClass.name}")
|
||||
logger.trace { "<- $topic : message of type ${obj.javaClass.name}" }
|
||||
iterateStateMachine(psm, net, logger, obj, newCheckpointKey) {
|
||||
try {
|
||||
@ -247,7 +259,11 @@ class StateMachineManager(val serviceHub: ServiceHub, val runInThread: Executor)
|
||||
val responseType: Class<R>
|
||||
) : FiberRequest(topic, destination, sessionIDForSend, sessionIDForReceive, obj)
|
||||
|
||||
class NotExpectingResponse(topic: String, destination: MessageRecipients, sessionIDForSend: Long, obj: Any?)
|
||||
: FiberRequest(topic, destination, sessionIDForSend, -1, obj)
|
||||
class NotExpectingResponse(
|
||||
topic: String,
|
||||
destination: MessageRecipients,
|
||||
sessionIDForSend: Long,
|
||||
obj: Any?
|
||||
) : FiberRequest(topic, destination, sessionIDForSend, -1, obj)
|
||||
}
|
||||
}
|
||||
|
@ -174,13 +174,13 @@ abstract class AbstractNode(val dir: Path, val configuration: NodeConfiguration,
|
||||
val myIdentity = Party(configuration.myLegalName, keypair.public)
|
||||
// We include the Party class with the file here to help catch mixups when admins provide files of the
|
||||
// wrong type by mistake.
|
||||
myIdentity.serialize(includeClassName = true).writeToFile(pubIdentityFile)
|
||||
myIdentity.serialize().writeToFile(pubIdentityFile)
|
||||
Pair(myIdentity, keypair)
|
||||
} else {
|
||||
// Check that the identity in the config file matches the identity file we have stored to disk.
|
||||
// This is just a sanity check. It shouldn't fail unless the admin has fiddled with the files and messed
|
||||
// things up for us.
|
||||
val myIdentity = Files.readAllBytes(pubIdentityFile).deserialize<Party>(includeClassName = true)
|
||||
val myIdentity = Files.readAllBytes(pubIdentityFile).deserialize<Party>()
|
||||
if (myIdentity.name != configuration.myLegalName)
|
||||
throw ConfigurationException("The legal name in the config file doesn't match the stored identity file:" +
|
||||
"${configuration.myLegalName} vs ${myIdentity.name}")
|
||||
|
@ -12,8 +12,8 @@ import core.node.services.ArtemisMessagingService
|
||||
import core.node.services.NodeInterestRates
|
||||
import core.node.services.ServiceType
|
||||
import core.node.services.TimestamperService
|
||||
import core.testing.MockNetworkMapCache
|
||||
import core.serialization.deserialize
|
||||
import core.testing.MockNetworkMapCache
|
||||
import core.utilities.BriefLogFormatter
|
||||
import demos.protocols.AutoOfferProtocol
|
||||
import demos.protocols.ExitServerProtocol
|
||||
@ -129,7 +129,7 @@ fun nodeInfo(hostAndPortString: String, identityFile: String, advertisedServices
|
||||
try {
|
||||
val addr = HostAndPort.fromString(hostAndPortString).withDefaultPort(Node.DEFAULT_PORT)
|
||||
val path = Paths.get(identityFile)
|
||||
val party = Files.readAllBytes(path).deserialize<Party>(includeClassName = true)
|
||||
val party = Files.readAllBytes(path).deserialize<Party>()
|
||||
return NodeInfo(ArtemisMessagingService.makeRecipient(addr), party, advertisedServices)
|
||||
} catch (e: Exception) {
|
||||
println("Could not find identify file $identityFile. If the file has just been created as part of starting the demo, please restart this node")
|
||||
|
@ -4,8 +4,8 @@ import contracts.Cash
|
||||
import core.*
|
||||
import core.node.Node
|
||||
import core.node.NodeConfiguration
|
||||
import core.node.services.ArtemisMessagingService
|
||||
import core.node.NodeInfo
|
||||
import core.node.services.ArtemisMessagingService
|
||||
import core.node.services.NodeInterestRates
|
||||
import core.serialization.deserialize
|
||||
import core.utilities.ANSIProgressRenderer
|
||||
@ -51,7 +51,7 @@ fun main(args: Array<String>) {
|
||||
|
||||
// Load oracle stuff (in lieu of having a network map service)
|
||||
val oracleAddr = ArtemisMessagingService.makeRecipient(options.valueOf(oracleAddrArg))
|
||||
val oracleIdentity = Files.readAllBytes(Paths.get(options.valueOf(oracleIdentityArg))).deserialize<Party>(includeClassName = true)
|
||||
val oracleIdentity = Files.readAllBytes(Paths.get(options.valueOf(oracleIdentityArg))).deserialize<Party>()
|
||||
val oracleNode = NodeInfo(oracleAddr, oracleIdentity)
|
||||
|
||||
val fixOf: FixOf = NodeInterestRates.parseFixOf(options.valueOf(fixOfArg))
|
||||
|
@ -87,7 +87,7 @@ fun main(args: Array<String>) {
|
||||
val timestamperId = if (options.has(timestamperIdentityFile)) {
|
||||
val addr = HostAndPort.fromString(options.valueOf(timestamperNetAddr)).withDefaultPort(Node.DEFAULT_PORT)
|
||||
val path = Paths.get(options.valueOf(timestamperIdentityFile))
|
||||
val party = Files.readAllBytes(path).deserialize<Party>(includeClassName = true)
|
||||
val party = Files.readAllBytes(path).deserialize<Party>()
|
||||
NodeInfo(ArtemisMessagingService.makeRecipient(addr), party, advertisedServices = setOf(TimestamperService.Type))
|
||||
} else null
|
||||
|
||||
|
@ -136,7 +136,7 @@ class AttachmentClassLoaderTests {
|
||||
fun `testing Kryo with ClassLoader (with top level class name)`() {
|
||||
val contract = createContract2Cash()
|
||||
|
||||
val bytes = contract.serialize(includeClassName = true)
|
||||
val bytes = contract.serialize()
|
||||
|
||||
val storage = MockAttachmentStorage()
|
||||
|
||||
@ -149,7 +149,7 @@ class AttachmentClassLoaderTests {
|
||||
val kryo = createKryo()
|
||||
kryo.classLoader = cl
|
||||
|
||||
val state2 = bytes.deserialize(kryo, true)
|
||||
val state2 = bytes.deserialize(kryo)
|
||||
assert(state2.javaClass.classLoader is AttachmentsClassLoader)
|
||||
assertNotNull(state2)
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user