Pool Kryo instances for efficiency. (#352)

Pooled Kryo
This commit is contained in:
Rick Parker 2017-03-16 08:24:06 +00:00 committed by GitHub
parent b8a4c7bea3
commit f3a5f8e659
20 changed files with 259 additions and 200 deletions

View File

@ -3,9 +3,11 @@ package net.corda.core.serialization
import com.esotericsoftware.kryo.* import com.esotericsoftware.kryo.*
import com.esotericsoftware.kryo.io.Input import com.esotericsoftware.kryo.io.Input
import com.esotericsoftware.kryo.io.Output import com.esotericsoftware.kryo.io.Output
import com.esotericsoftware.kryo.pool.KryoPool
import com.esotericsoftware.kryo.serializers.JavaSerializer import com.esotericsoftware.kryo.serializers.JavaSerializer
import com.esotericsoftware.kryo.serializers.MapSerializer import com.esotericsoftware.kryo.serializers.MapSerializer
import com.esotericsoftware.kryo.util.MapReferenceResolver import com.esotericsoftware.kryo.util.MapReferenceResolver
import com.google.common.annotations.VisibleForTesting
import net.corda.core.contracts.* import net.corda.core.contracts.*
import net.corda.core.crypto.* import net.corda.core.crypto.*
import net.corda.core.node.AttachmentsClassLoader import net.corda.core.node.AttachmentsClassLoader
@ -60,12 +62,9 @@ import kotlin.reflect.jvm.javaType
*/ */
// A convenient instance of Kryo pre-configured with some useful things. Used as a default by various functions. // A convenient instance of Kryo pre-configured with some useful things. Used as a default by various functions.
private val THREAD_LOCAL_KRYO: ThreadLocal<Kryo> = ThreadLocal.withInitial { createKryo() } fun p2PKryo(): KryoPool = kryoPool
// Same again, but this has whitelisting turned off for internal storage use only. // Same again, but this has whitelisting turned off for internal storage use only.
private val INTERNAL_THREAD_LOCAL_KRYO: ThreadLocal<Kryo> = ThreadLocal.withInitial { createInternalKryo() } fun storageKryo(): KryoPool = internalKryoPool
fun threadLocalP2PKryo(): Kryo = THREAD_LOCAL_KRYO.get()
fun threadLocalStorageKryo(): Kryo = INTERNAL_THREAD_LOCAL_KRYO.get()
/** /**
* A type safe wrapper around a byte array that contains a serialised object. You can call [SerializedBytes.deserialize] * A type safe wrapper around a byte array that contains a serialised object. You can call [SerializedBytes.deserialize]
@ -82,26 +81,34 @@ class SerializedBytes<T : Any>(bytes: ByteArray, val internalOnly: Boolean = fal
private val KryoHeaderV0_1: OpaqueBytes = OpaqueBytes("corda\u0000\u0000\u0001".toByteArray()) private val KryoHeaderV0_1: OpaqueBytes = OpaqueBytes("corda\u0000\u0000\u0001".toByteArray())
// Some extension functions that make deserialisation convenient and provide auto-casting of the result. // Some extension functions that make deserialisation convenient and provide auto-casting of the result.
fun <T : Any> ByteArray.deserialize(kryo: Kryo = threadLocalP2PKryo()): T { fun <T : Any> ByteArray.deserialize(kryo: KryoPool = p2PKryo()): T {
Input(this).use { Input(this).use {
val header = OpaqueBytes(it.readBytes(8)) val header = OpaqueBytes(it.readBytes(8))
if (header != KryoHeaderV0_1) { if (header != KryoHeaderV0_1) {
throw KryoException("Serialized bytes header does not match any known format.") throw KryoException("Serialized bytes header does not match any known format.")
} }
@Suppress("UNCHECKED_CAST") @Suppress("UNCHECKED_CAST")
return kryo.readClassAndObject(it) as T return kryo.run { k -> k.readClassAndObject(it) as T }
} }
} }
fun <T : Any> OpaqueBytes.deserialize(kryo: Kryo = threadLocalP2PKryo()): T { // TODO: The preferred usage is with a pool. Try and eliminate use of this from RPC.
fun <T : Any> ByteArray.deserialize(kryo: Kryo): T = deserialize(kryo.asPool())
fun <T : Any> OpaqueBytes.deserialize(kryo: KryoPool = p2PKryo()): T {
return this.bytes.deserialize(kryo) return this.bytes.deserialize(kryo)
} }
// The more specific deserialize version results in the bytes being cached, which is faster. // The more specific deserialize version results in the bytes being cached, which is faster.
@JvmName("SerializedBytesWireTransaction") @JvmName("SerializedBytesWireTransaction")
fun SerializedBytes<WireTransaction>.deserialize(kryo: Kryo = threadLocalP2PKryo()): WireTransaction = WireTransaction.deserialize(this, kryo) fun SerializedBytes<WireTransaction>.deserialize(kryo: KryoPool = p2PKryo()): WireTransaction = WireTransaction.deserialize(this, kryo)
fun <T : Any> SerializedBytes<T>.deserialize(kryo: Kryo = if (internalOnly) threadLocalStorageKryo() else threadLocalP2PKryo()): T = bytes.deserialize(kryo) fun <T : Any> SerializedBytes<T>.deserialize(kryo: KryoPool = if (internalOnly) storageKryo() else p2PKryo()): T = bytes.deserialize(kryo)
fun <T : Any> SerializedBytes<T>.deserialize(kryo: Kryo): T = bytes.deserialize(kryo.asPool())
// Internal adapter for use when we haven't yet converted to a pool, or for tests.
private fun Kryo.asPool(): KryoPool = (KryoPool.Builder { this }.build())
/** /**
* A serialiser that avoids writing the wrapper class to the byte stream, thus ensuring [SerializedBytes] is a pure * A serialiser that avoids writing the wrapper class to the byte stream, thus ensuring [SerializedBytes] is a pure
@ -122,7 +129,11 @@ object SerializedBytesSerializer : Serializer<SerializedBytes<Any>>() {
* Can be called on any object to convert it to a byte array (wrapped by [SerializedBytes]), regardless of whether * Can be called on any object to convert it to a byte array (wrapped by [SerializedBytes]), regardless of whether
* the type is marked as serializable or was designed for it (so be careful!). * the type is marked as serializable or was designed for it (so be careful!).
*/ */
fun <T : Any> T.serialize(kryo: Kryo = threadLocalP2PKryo(), internalOnly: Boolean = false): SerializedBytes<T> { fun <T : Any> T.serialize(kryo: KryoPool = p2PKryo(), internalOnly: Boolean = false): SerializedBytes<T> {
return kryo.run { k -> serialize(k, internalOnly) }
}
fun <T : Any> T.serialize(kryo: Kryo, internalOnly: Boolean = false): SerializedBytes<T> {
val stream = ByteArrayOutputStream() val stream = ByteArrayOutputStream()
Output(stream).use { Output(stream).use {
it.writeBytes(KryoHeaderV0_1.bytes) it.writeBytes(KryoHeaderV0_1.bytes)
@ -399,14 +410,12 @@ object KotlinObjectSerializer : Serializer<DeserializeAsKotlinObjectDef>() {
} }
// No ClassResolver only constructor. MapReferenceResolver is the default as used by Kryo in other constructors. // No ClassResolver only constructor. MapReferenceResolver is the default as used by Kryo in other constructors.
fun createInternalKryo(k: Kryo = CordaKryo(makeNoWhitelistClassResolver())): Kryo { private val internalKryoPool = KryoPool.Builder { DefaultKryoCustomizer.customize(CordaKryo(makeNoWhitelistClassResolver())) }.build()
return DefaultKryoCustomizer.customize(k) private val kryoPool = KryoPool.Builder { DefaultKryoCustomizer.customize(CordaKryo(makeStandardClassResolver())) }.build()
}
// No ClassResolver only constructor. MapReferenceResolver is the default as used by Kryo in other constructors. // No ClassResolver only constructor. MapReferenceResolver is the default as used by Kryo in other constructors.
fun createKryo(k: Kryo = CordaKryo(makeStandardClassResolver())): Kryo { @VisibleForTesting
return DefaultKryoCustomizer.customize(k) fun createTestKryo(): Kryo = DefaultKryoCustomizer.customize(CordaKryo(makeNoWhitelistClassResolver()))
}
/** /**
* We need to disable whitelist checking during calls from our Kryo code to register a serializer, since it checks * We need to disable whitelist checking during calls from our Kryo code to register a serializer, since it checks
@ -475,21 +484,20 @@ inline fun <reified T : Any> Kryo.noReferencesWithin() {
class NoReferencesSerializer<T>(val baseSerializer: Serializer<T>) : Serializer<T>() { class NoReferencesSerializer<T>(val baseSerializer: Serializer<T>) : Serializer<T>() {
override fun read(kryo: Kryo, input: Input, type: Class<T>): T { override fun read(kryo: Kryo, input: Input, type: Class<T>): T {
val previousValue = kryo.setReferences(false) return kryo.withoutReferences { baseSerializer.read(kryo, input, type) }
try {
return baseSerializer.read(kryo, input, type)
} finally {
kryo.references = previousValue
}
} }
override fun write(kryo: Kryo, output: Output, obj: T) { override fun write(kryo: Kryo, output: Output, obj: T) {
val previousValue = kryo.setReferences(false) kryo.withoutReferences { baseSerializer.write(kryo, output, obj) }
try { }
baseSerializer.write(kryo, output, obj) }
} finally {
kryo.references = previousValue fun <T> Kryo.withoutReferences(block: () -> T): T {
} val previousValue = setReferences(false)
try {
return block()
} finally {
references = previousValue
} }
} }
@ -524,17 +532,6 @@ var Kryo.attachmentStorage: AttachmentStorage?
this.context.put(ATTACHMENT_STORAGE, value) this.context.put(ATTACHMENT_STORAGE, value)
} }
//TODO: It's a little workaround for serialization of HashMaps inside contract states.
//Used in Merkle tree calculation. It doesn't cover all the cases of unstable serialization format.
fun extendKryoHash(kryo: Kryo): Kryo {
return kryo.apply {
references = false
register(LinkedHashMap::class.java, MapSerializer())
register(HashMap::class.java, OrderedSerializer)
}
}
object OrderedSerializer : Serializer<HashMap<Any, Any>>() { object OrderedSerializer : Serializer<HashMap<Any, Any>>() {
override fun write(kryo: Kryo, output: Output, obj: HashMap<Any, Any>) { override fun write(kryo: Kryo, output: Output, obj: HashMap<Any, Any>) {
//Change a HashMap to LinkedHashMap. //Change a HashMap to LinkedHashMap.

View File

@ -5,7 +5,7 @@ import com.esotericsoftware.kryo.KryoException
import com.esotericsoftware.kryo.Serializer import com.esotericsoftware.kryo.Serializer
import com.esotericsoftware.kryo.io.Input import com.esotericsoftware.kryo.io.Input
import com.esotericsoftware.kryo.io.Output import com.esotericsoftware.kryo.io.Output
import java.util.* import com.esotericsoftware.kryo.pool.KryoPool
/** /**
* The interfaces and classes in this file allow large, singleton style classes to * The interfaces and classes in this file allow large, singleton style classes to
@ -36,8 +36,6 @@ interface SerializationToken {
/** /**
* A Kryo serializer for [SerializeAsToken] implementations. * A Kryo serializer for [SerializeAsToken] implementations.
*
* This is registered in [createKryo].
*/ */
class SerializeAsTokenSerializer<T : SerializeAsToken> : Serializer<T>() { class SerializeAsTokenSerializer<T : SerializeAsToken> : Serializer<T>() {
override fun write(kryo: Kryo, output: Output, obj: T) { override fun write(kryo: Kryo, output: Output, obj: T) {
@ -76,8 +74,8 @@ class SerializeAsTokenSerializer<T : SerializeAsToken> : Serializer<T>() {
* Then it is a case of using the companion object methods on [SerializeAsTokenSerializer] to set and clear context as necessary * Then it is a case of using the companion object methods on [SerializeAsTokenSerializer] to set and clear context as necessary
* on the Kryo instance when serializing to enable/disable tokenization. * on the Kryo instance when serializing to enable/disable tokenization.
*/ */
class SerializeAsTokenContext(toBeTokenized: Any, kryo: Kryo = createKryo()) { class SerializeAsTokenContext(toBeTokenized: Any, kryoPool: KryoPool) {
internal val tokenToTokenized = HashMap<SerializationToken, SerializeAsToken>() internal val tokenToTokenized = mutableMapOf<SerializationToken, SerializeAsToken>()
internal var readOnly = false internal var readOnly = false
init { init {
@ -90,9 +88,11 @@ class SerializeAsTokenContext(toBeTokenized: Any, kryo: Kryo = createKryo()) {
* accidental registrations from occuring as these could not be deserialized in a deserialization-first * accidental registrations from occuring as these could not be deserialized in a deserialization-first
* scenario if they are not part of this iniital context construction serialization. * scenario if they are not part of this iniital context construction serialization.
*/ */
SerializeAsTokenSerializer.setContext(kryo, this) kryoPool.run { kryo ->
toBeTokenized.serialize(kryo) SerializeAsTokenSerializer.setContext(kryo, this)
SerializeAsTokenSerializer.clearContext(kryo) toBeTokenized.serialize(kryo)
SerializeAsTokenSerializer.clearContext(kryo)
}
readOnly = true readOnly = true
} }
} }

View File

@ -3,13 +3,12 @@ package net.corda.core.transactions
import net.corda.core.contracts.* import net.corda.core.contracts.*
import net.corda.core.crypto.* import net.corda.core.crypto.*
import net.corda.core.serialization.CordaSerializable import net.corda.core.serialization.CordaSerializable
import net.corda.core.serialization.createKryo import net.corda.core.serialization.p2PKryo
import net.corda.core.serialization.extendKryoHash
import net.corda.core.serialization.serialize import net.corda.core.serialization.serialize
import net.corda.core.serialization.withoutReferences
fun <T : Any> serializedHash(x: T): SecureHash { fun <T : Any> serializedHash(x: T): SecureHash {
val kryo = extendKryoHash(createKryo()) // Dealing with HashMaps inside states. return p2PKryo().run { kryo -> kryo.withoutReferences { x.serialize(kryo).hash } }
return x.serialize(kryo).hash
} }
/** /**

View File

@ -1,6 +1,6 @@
package net.corda.core.transactions package net.corda.core.transactions
import com.esotericsoftware.kryo.Kryo import com.esotericsoftware.kryo.pool.KryoPool
import net.corda.core.contracts.* import net.corda.core.contracts.*
import net.corda.core.crypto.CompositeKey import net.corda.core.crypto.CompositeKey
import net.corda.core.crypto.MerkleTree import net.corda.core.crypto.MerkleTree
@ -10,8 +10,8 @@ import net.corda.core.indexOfOrThrow
import net.corda.core.node.ServicesForResolution import net.corda.core.node.ServicesForResolution
import net.corda.core.serialization.SerializedBytes import net.corda.core.serialization.SerializedBytes
import net.corda.core.serialization.deserialize import net.corda.core.serialization.deserialize
import net.corda.core.serialization.p2PKryo
import net.corda.core.serialization.serialize import net.corda.core.serialization.serialize
import net.corda.core.serialization.threadLocalP2PKryo
import net.corda.core.utilities.Emoji import net.corda.core.utilities.Emoji
import java.security.PublicKey import java.security.PublicKey
@ -45,7 +45,7 @@ class WireTransaction(
override val id: SecureHash by lazy { merkleTree.hash } override val id: SecureHash by lazy { merkleTree.hash }
companion object { companion object {
fun deserialize(data: SerializedBytes<WireTransaction>, kryo: Kryo = threadLocalP2PKryo()): WireTransaction { fun deserialize(data: SerializedBytes<WireTransaction>, kryo: KryoPool = p2PKryo()): WireTransaction {
val wtx = data.bytes.deserialize<WireTransaction>(kryo) val wtx = data.bytes.deserialize<WireTransaction>(kryo)
wtx.cachedBytes = data wtx.cachedBytes = data
return wtx return wtx

View File

@ -88,7 +88,7 @@ class ProgressTracker(vararg steps: Step) {
@CordaSerializable @CordaSerializable
private data class Child(val tracker: ProgressTracker, @Transient val subscription: Subscription?) private data class Child(val tracker: ProgressTracker, @Transient val subscription: Subscription?)
private val childProgressTrackers = HashMap<Step, Child>() private val childProgressTrackers = mutableMapOf<Step, Child>()
init { init {
steps.forEach { steps.forEach {

View File

@ -1,18 +1,20 @@
package net.corda.core.crypto package net.corda.core.crypto
import com.esotericsoftware.kryo.serializers.MapSerializer import com.esotericsoftware.kryo.KryoException
import net.corda.contracts.asset.Cash import net.corda.contracts.asset.Cash
import net.corda.core.contracts.* import net.corda.core.contracts.*
import net.corda.core.crypto.SecureHash.Companion.zeroHash import net.corda.core.crypto.SecureHash.Companion.zeroHash
import net.corda.core.serialization.* import net.corda.core.serialization.p2PKryo
import net.corda.core.transactions.* import net.corda.core.serialization.serialize
import net.corda.core.utilities.* import net.corda.core.transactions.WireTransaction
import net.corda.core.utilities.DUMMY_NOTARY
import net.corda.core.utilities.DUMMY_PUBKEY_1
import net.corda.core.utilities.TEST_TX_TIME
import net.corda.testing.MEGA_CORP import net.corda.testing.MEGA_CORP
import net.corda.testing.MEGA_CORP_PUBKEY import net.corda.testing.MEGA_CORP_PUBKEY
import net.corda.testing.ledger import net.corda.testing.ledger
import org.junit.Test import org.junit.Test
import java.util.*
import kotlin.test.* import kotlin.test.*
class PartialMerkleTreeTest { class PartialMerkleTreeTest {
@ -208,15 +210,12 @@ class PartialMerkleTreeTest {
assertFalse(pmt.verify(wrongRoot, inclHashes)) assertFalse(pmt.verify(wrongRoot, inclHashes))
} }
@Test @Test(expected = KryoException::class)
fun `hash map serialization`() { fun `hash map serialization not allowed`() {
val hm1 = hashMapOf("a" to 1, "b" to 2, "c" to 3, "e" to 4) val hm1 = hashMapOf("a" to 1, "b" to 2, "c" to 3, "e" to 4)
assert(serializedHash(hm1) == serializedHash(hm1.serialize().deserialize())) // It internally uses the ordered HashMap extension. p2PKryo().run { kryo ->
val kryo = extendKryoHash(createKryo()) hm1.serialize(kryo)
assertTrue(kryo.getSerializer(HashMap::class.java) is OrderedSerializer) }
assertTrue(kryo.getSerializer(LinkedHashMap::class.java) is MapSerializer)
val hm2 = hm1.serialize(kryo).deserialize(kryo)
assert(hm1.hashCode() == hm2.hashCode())
} }
private fun makeSimpleCashWtx(notary: Party, timestamp: Timestamp? = null, attachments: List<SecureHash> = emptyList()): WireTransaction { private fun makeSimpleCashWtx(notary: Party, timestamp: Timestamp? = null, attachments: List<SecureHash> = emptyList()): WireTransaction {

View File

@ -1,5 +1,6 @@
package net.corda.core.node package net.corda.core.node
import com.esotericsoftware.kryo.Kryo
import net.corda.core.contracts.* import net.corda.core.contracts.*
import net.corda.core.crypto.CompositeKey import net.corda.core.crypto.CompositeKey
import net.corda.core.crypto.Party import net.corda.core.crypto.Party
@ -11,7 +12,9 @@ import net.corda.core.utilities.DUMMY_NOTARY
import net.corda.testing.MEGA_CORP import net.corda.testing.MEGA_CORP
import net.corda.testing.node.MockAttachmentStorage import net.corda.testing.node.MockAttachmentStorage
import org.apache.commons.io.IOUtils import org.apache.commons.io.IOUtils
import org.junit.After
import org.junit.Assert import org.junit.Assert
import org.junit.Before
import org.junit.Test import org.junit.Test
import java.io.ByteArrayInputStream import java.io.ByteArrayInputStream
import java.io.ByteArrayOutputStream import java.io.ByteArrayOutputStream
@ -75,6 +78,21 @@ class AttachmentClassLoaderTests {
class ClassLoaderForTests : URLClassLoader(arrayOf(ISOLATED_CONTRACTS_JAR_PATH), FilteringClassLoader) class ClassLoaderForTests : URLClassLoader(arrayOf(ISOLATED_CONTRACTS_JAR_PATH), FilteringClassLoader)
lateinit var kryo: Kryo
lateinit var kryo2: Kryo
@Before
fun setup() {
kryo = p2PKryo().borrow()
kryo2 = p2PKryo().borrow()
}
@After
fun teardown() {
p2PKryo().release(kryo)
p2PKryo().release(kryo2)
}
@Test @Test
fun `dynamically load AnotherDummyContract from isolated contracts jar`() { fun `dynamically load AnotherDummyContract from isolated contracts jar`() {
val child = ClassLoaderForTests() val child = ClassLoaderForTests()
@ -205,7 +223,6 @@ class AttachmentClassLoaderTests {
val cl = AttachmentsClassLoader(arrayOf(att0, att1, att2).map { storage.openAttachment(it)!! }, FilteringClassLoader) val cl = AttachmentsClassLoader(arrayOf(att0, att1, att2).map { storage.openAttachment(it)!! }, FilteringClassLoader)
val kryo = createKryo()
kryo.classLoader = cl kryo.classLoader = cl
kryo.addToWhitelist(contract.javaClass) kryo.addToWhitelist(contract.javaClass)
@ -224,7 +241,6 @@ class AttachmentClassLoaderTests {
assertNotNull(data.contract) assertNotNull(data.contract)
val kryo2 = createKryo()
kryo2.addToWhitelist(data.contract.javaClass) kryo2.addToWhitelist(data.contract.javaClass)
val bytes = data.serialize(kryo2) val bytes = data.serialize(kryo2)
@ -236,7 +252,6 @@ class AttachmentClassLoaderTests {
val cl = AttachmentsClassLoader(arrayOf(att0, att1, att2).map { storage.openAttachment(it)!! }, FilteringClassLoader) val cl = AttachmentsClassLoader(arrayOf(att0, att1, att2).map { storage.openAttachment(it)!! }, FilteringClassLoader)
val kryo = createKryo()
kryo.classLoader = cl kryo.classLoader = cl
kryo.addToWhitelist(Class.forName("net.corda.contracts.isolated.AnotherDummyContract", true, cl)) kryo.addToWhitelist(Class.forName("net.corda.contracts.isolated.AnotherDummyContract", true, cl))
@ -263,7 +278,6 @@ class AttachmentClassLoaderTests {
val contract = contractClass.newInstance() as DummyContractBackdoor val contract = contractClass.newInstance() as DummyContractBackdoor
val tx = contract.generateInitial(MEGA_CORP.ref(0), 42, DUMMY_NOTARY) val tx = contract.generateInitial(MEGA_CORP.ref(0), 42, DUMMY_NOTARY)
val storage = MockAttachmentStorage() val storage = MockAttachmentStorage()
val kryo = createKryo()
kryo.addToWhitelist(contract.javaClass) kryo.addToWhitelist(contract.javaClass)
kryo.addToWhitelist(Class.forName("net.corda.contracts.isolated.AnotherDummyContract\$State", true, child)) kryo.addToWhitelist(Class.forName("net.corda.contracts.isolated.AnotherDummyContract\$State", true, child))
kryo.addToWhitelist(Class.forName("net.corda.contracts.isolated.AnotherDummyContract\$Commands\$Create", true, child)) kryo.addToWhitelist(Class.forName("net.corda.contracts.isolated.AnotherDummyContract\$Commands\$Create", true, child))
@ -279,7 +293,6 @@ class AttachmentClassLoaderTests {
val bytes = wireTransaction.serialize(kryo) val bytes = wireTransaction.serialize(kryo)
val kryo2 = createKryo()
// use empty attachmentStorage // use empty attachmentStorage
kryo2.attachmentStorage = storage kryo2.attachmentStorage = storage
@ -297,7 +310,6 @@ class AttachmentClassLoaderTests {
val contract = contractClass.newInstance() as DummyContractBackdoor val contract = contractClass.newInstance() as DummyContractBackdoor
val tx = contract.generateInitial(MEGA_CORP.ref(0), 42, DUMMY_NOTARY) val tx = contract.generateInitial(MEGA_CORP.ref(0), 42, DUMMY_NOTARY)
val storage = MockAttachmentStorage() val storage = MockAttachmentStorage()
val kryo = createKryo()
// todo - think about better way to push attachmentStorage down to serializer // todo - think about better way to push attachmentStorage down to serializer
kryo.attachmentStorage = storage kryo.attachmentStorage = storage
@ -310,7 +322,6 @@ class AttachmentClassLoaderTests {
val bytes = wireTransaction.serialize(kryo) val bytes = wireTransaction.serialize(kryo)
val kryo2 = createKryo()
// use empty attachmentStorage // use empty attachmentStorage
kryo2.attachmentStorage = MockAttachmentStorage() kryo2.attachmentStorage = MockAttachmentStorage()

View File

@ -1,5 +1,6 @@
package net.corda.core.serialization package net.corda.core.serialization
import com.esotericsoftware.kryo.Kryo
import com.google.common.primitives.Ints import com.google.common.primitives.Ints
import net.corda.core.crypto.* import net.corda.core.crypto.*
import net.corda.core.messaging.Ack import net.corda.core.messaging.Ack
@ -7,6 +8,8 @@ import org.assertj.core.api.Assertions.assertThat
import org.assertj.core.api.Assertions.assertThatThrownBy import org.assertj.core.api.Assertions.assertThatThrownBy
import org.bouncycastle.jce.provider.BouncyCastleProvider import org.bouncycastle.jce.provider.BouncyCastleProvider
import org.bouncycastle.pqc.jcajce.provider.BouncyCastlePQCProvider import org.bouncycastle.pqc.jcajce.provider.BouncyCastlePQCProvider
import org.junit.After
import org.junit.Before
import org.junit.Test import org.junit.Test
import java.io.InputStream import java.io.InputStream
import java.security.Security import java.security.Security
@ -16,7 +19,17 @@ import kotlin.test.assertEquals
class KryoTests { class KryoTests {
private val kryo = createKryo() private lateinit var kryo: Kryo
@Before
fun setup() {
kryo = p2PKryo().borrow()
}
@After
fun teardown() {
p2PKryo().release(kryo)
}
@Test @Test
fun ok() { fun ok() {

View File

@ -15,12 +15,13 @@ class SerializationTokenTest {
@Before @Before
fun setup() { fun setup() {
kryo = threadLocalStorageKryo() kryo = storageKryo().borrow()
} }
@After @After
fun cleanup() { fun cleanup() {
SerializeAsTokenSerializer.clearContext(kryo) SerializeAsTokenSerializer.clearContext(kryo)
storageKryo().release(kryo)
} }
// Large tokenizable object so we can tell from the smaller number of serialized bytes it was actually tokenized // Large tokenizable object so we can tell from the smaller number of serialized bytes it was actually tokenized
@ -38,7 +39,7 @@ class SerializationTokenTest {
@Test @Test
fun `write token and read tokenizable`() { fun `write token and read tokenizable`() {
val tokenizableBefore = LargeTokenizable() val tokenizableBefore = LargeTokenizable()
val context = SerializeAsTokenContext(tokenizableBefore, kryo) val context = SerializeAsTokenContext(tokenizableBefore, storageKryo())
SerializeAsTokenSerializer.setContext(kryo, context) SerializeAsTokenSerializer.setContext(kryo, context)
val serializedBytes = tokenizableBefore.serialize(kryo) val serializedBytes = tokenizableBefore.serialize(kryo)
assertThat(serializedBytes.size).isLessThan(tokenizableBefore.numBytes) assertThat(serializedBytes.size).isLessThan(tokenizableBefore.numBytes)
@ -51,7 +52,7 @@ class SerializationTokenTest {
@Test @Test
fun `write and read singleton`() { fun `write and read singleton`() {
val tokenizableBefore = UnitSerializeAsToken() val tokenizableBefore = UnitSerializeAsToken()
val context = SerializeAsTokenContext(tokenizableBefore, kryo) val context = SerializeAsTokenContext(tokenizableBefore, storageKryo())
SerializeAsTokenSerializer.setContext(kryo, context) SerializeAsTokenSerializer.setContext(kryo, context)
val serializedBytes = tokenizableBefore.serialize(kryo) val serializedBytes = tokenizableBefore.serialize(kryo)
val tokenizableAfter = serializedBytes.deserialize(kryo) val tokenizableAfter = serializedBytes.deserialize(kryo)
@ -61,7 +62,7 @@ class SerializationTokenTest {
@Test(expected = UnsupportedOperationException::class) @Test(expected = UnsupportedOperationException::class)
fun `new token encountered after context init`() { fun `new token encountered after context init`() {
val tokenizableBefore = UnitSerializeAsToken() val tokenizableBefore = UnitSerializeAsToken()
val context = SerializeAsTokenContext(emptyList<Any>(), kryo) val context = SerializeAsTokenContext(emptyList<Any>(), storageKryo())
SerializeAsTokenSerializer.setContext(kryo, context) SerializeAsTokenSerializer.setContext(kryo, context)
tokenizableBefore.serialize(kryo) tokenizableBefore.serialize(kryo)
} }
@ -69,9 +70,9 @@ class SerializationTokenTest {
@Test(expected = UnsupportedOperationException::class) @Test(expected = UnsupportedOperationException::class)
fun `deserialize unregistered token`() { fun `deserialize unregistered token`() {
val tokenizableBefore = UnitSerializeAsToken() val tokenizableBefore = UnitSerializeAsToken()
val context = SerializeAsTokenContext(emptyList<Any>(), kryo) val context = SerializeAsTokenContext(emptyList<Any>(), storageKryo())
SerializeAsTokenSerializer.setContext(kryo, context) SerializeAsTokenSerializer.setContext(kryo, context)
val serializedBytes = tokenizableBefore.toToken(SerializeAsTokenContext(emptyList<Any>(), kryo)).serialize(kryo) val serializedBytes = tokenizableBefore.toToken(SerializeAsTokenContext(emptyList<Any>(), storageKryo())).serialize(kryo)
serializedBytes.deserialize(kryo) serializedBytes.deserialize(kryo)
} }
@ -84,7 +85,7 @@ class SerializationTokenTest {
@Test(expected = KryoException::class) @Test(expected = KryoException::class)
fun `deserialize non-token`() { fun `deserialize non-token`() {
val tokenizableBefore = UnitSerializeAsToken() val tokenizableBefore = UnitSerializeAsToken()
val context = SerializeAsTokenContext(tokenizableBefore, kryo) val context = SerializeAsTokenContext(tokenizableBefore, storageKryo())
SerializeAsTokenSerializer.setContext(kryo, context) SerializeAsTokenSerializer.setContext(kryo, context)
val stream = ByteArrayOutputStream() val stream = ByteArrayOutputStream()
Output(stream).use { Output(stream).use {
@ -106,7 +107,7 @@ class SerializationTokenTest {
@Test(expected = KryoException::class) @Test(expected = KryoException::class)
fun `token returns unexpected type`() { fun `token returns unexpected type`() {
val tokenizableBefore = WrongTypeSerializeAsToken() val tokenizableBefore = WrongTypeSerializeAsToken()
val context = SerializeAsTokenContext(tokenizableBefore, kryo) val context = SerializeAsTokenContext(tokenizableBefore, storageKryo())
SerializeAsTokenSerializer.setContext(kryo, context) SerializeAsTokenSerializer.setContext(kryo, context)
val serializedBytes = tokenizableBefore.serialize(kryo) val serializedBytes = tokenizableBefore.serialize(kryo)
serializedBytes.deserialize(kryo) serializedBytes.deserialize(kryo)

View File

@ -4,7 +4,7 @@ import com.esotericsoftware.kryo.Kryo
import com.esotericsoftware.kryo.KryoSerializable import com.esotericsoftware.kryo.KryoSerializable
import com.esotericsoftware.kryo.io.Input import com.esotericsoftware.kryo.io.Input
import com.esotericsoftware.kryo.io.Output import com.esotericsoftware.kryo.io.Output
import net.corda.core.serialization.createInternalKryo import net.corda.core.serialization.createTestKryo
import net.corda.core.serialization.serialize import net.corda.core.serialization.serialize
import org.junit.Before import org.junit.Before
import org.junit.Test import org.junit.Test
@ -106,7 +106,7 @@ class ProgressTrackerTest {
} }
} }
val kryo = createInternalKryo().apply { val kryo = createTestKryo().apply {
// This is required to make sure Kryo walks through the auto-generated members for the lambda below. // This is required to make sure Kryo walks through the auto-generated members for the lambda below.
fieldSerializerConfig.isIgnoreSyntheticFields = false fieldSerializerConfig.isIgnoreSyntheticFields = false
} }

View File

@ -11,6 +11,9 @@ import java.time.LocalDate
import java.time.Period import java.time.Period
import java.util.* import java.util.*
/**
* NOTE: We do not whitelist [HashMap] or [HashSet] since they are unstable under serialization.
*/
class DefaultWhitelist : CordaPluginRegistry() { class DefaultWhitelist : CordaPluginRegistry() {
override fun customizeSerialization(custom: SerializationCustomization): Boolean { override fun customizeSerialization(custom: SerializationCustomization): Boolean {
custom.apply { custom.apply {
@ -41,7 +44,6 @@ class DefaultWhitelist : CordaPluginRegistry() {
addToWhitelist(java.time.Instant::class.java) addToWhitelist(java.time.Instant::class.java)
addToWhitelist(java.time.LocalDate::class.java) addToWhitelist(java.time.LocalDate::class.java)
addToWhitelist(java.util.Collections.singletonMap("A", "B").javaClass) addToWhitelist(java.util.Collections.singletonMap("A", "B").javaClass)
addToWhitelist(java.util.HashMap::class.java)
addToWhitelist(java.util.LinkedHashMap::class.java) addToWhitelist(java.util.LinkedHashMap::class.java)
addToWhitelist(BigDecimal::class.java) addToWhitelist(BigDecimal::class.java)
addToWhitelist(LocalDate::class.java) addToWhitelist(LocalDate::class.java)

View File

@ -113,18 +113,20 @@ class CordaRPCClientImpl(private val session: ClientSession,
@GuardedBy("sessionLock") @GuardedBy("sessionLock")
private val addressToQueuedObservables = CacheBuilder.newBuilder().weakValues().build<String, QueuedObservable>() private val addressToQueuedObservables = CacheBuilder.newBuilder().weakValues().build<String, QueuedObservable>()
// This is used to hold a reference counted hard reference when we know there are subscribers. // This is used to hold a reference counted hard reference when we know there are subscribers.
private val hardReferencesToQueuedObservables = mutableSetOf<QueuedObservable>() private val hardReferencesToQueuedObservables = Collections.synchronizedSet(mutableSetOf<QueuedObservable>())
private var producer: ClientProducer? = null private var producer: ClientProducer? = null
private inner class ObservableDeserializer(private val qName: String, class ObservableDeserializer() : Serializer<Observable<Any>>() {
private val rpcName: String,
private val rpcLocation: Throwable) : Serializer<Observable<Any>>() {
override fun read(kryo: Kryo, input: Input, type: Class<Observable<Any>>): Observable<Any> { override fun read(kryo: Kryo, input: Input, type: Class<Observable<Any>>): Observable<Any> {
val qName = kryo.context[RPCKryoQNameKey] as String
val rpcName = kryo.context[RPCKryoMethodNameKey] as String
val rpcLocation = kryo.context[RPCKryoLocationKey] as Throwable
val rpcClient = kryo.context[RPCKryoClientKey] as CordaRPCClientImpl
val handle = input.readInt(true) val handle = input.readInt(true)
val ob = sessionLock.withLock { val ob = rpcClient.sessionLock.withLock {
addressToQueuedObservables.getIfPresent(qName) ?: QueuedObservable(qName, rpcName, rpcLocation, this).apply { rpcClient.addressToQueuedObservables.getIfPresent(qName) ?: rpcClient.QueuedObservable(qName, rpcName, rpcLocation).apply {
addressToQueuedObservables.put(qName, this) rpcClient.addressToQueuedObservables.put(qName, this)
} }
} }
val result = ob.getForHandle(handle) val result = ob.getForHandle(handle)
@ -182,9 +184,17 @@ class CordaRPCClientImpl(private val session: ClientSession,
checkMethodVersion(method) checkMethodVersion(method)
// sendRequest may return a reconfigured Kryo if the method returns observables. val msg: ClientMessage = createMessage(method)
val kryo: Kryo = sendRequest(args, location, method) ?: createRPCKryo() // We could of course also check the return type of the method to see if it's Observable, but I'd
val next: ErrorOr<*> = receiveResponse(kryo, method, timeout) // rather haved the annotation be used consistently.
val returnsObservables = method.isAnnotationPresent(RPCReturnsObservables::class.java)
val kryo = if (returnsObservables) maybePrepareForObservables(location, method, msg) else createRPCKryoForDeserialization(this@CordaRPCClientImpl)
val next: ErrorOr<*> = try {
sendRequest(args, msg)
receiveResponse(kryo, method, timeout)
} finally {
releaseRPCKryoForDeserialization(kryo)
}
rpcLog.debug { "<- RPC <- ${method.name} = $next" } rpcLog.debug { "<- RPC <- ${method.name} = $next" }
return unwrapOrThrow(next) return unwrapOrThrow(next)
} }
@ -215,22 +225,18 @@ class CordaRPCClientImpl(private val session: ClientSession,
return next return next
} }
private fun sendRequest(args: Array<out Any>?, location: Throwable, method: Method): Kryo? { private fun sendRequest(args: Array<out Any>?, msg: ClientMessage) {
// We could of course also check the return type of the method to see if it's Observable, but I'd
// rather haved the annotation be used consistently.
val returnsObservables = method.isAnnotationPresent(RPCReturnsObservables::class.java)
sessionLock.withLock { sessionLock.withLock {
val msg: ClientMessage = createMessage(method) val argsKryo = createRPCKryoForDeserialization(this@CordaRPCClientImpl)
val kryo = if (returnsObservables) maybePrepareForObservables(location, method, msg) else null
val serializedArgs = try { val serializedArgs = try {
(args ?: emptyArray<Any?>()).serialize(createRPCKryo()) (args ?: emptyArray<Any?>()).serialize(argsKryo)
} catch (e: KryoException) { } catch (e: KryoException) {
throw RPCException("Could not serialize RPC arguments", e) throw RPCException("Could not serialize RPC arguments", e)
} finally {
releaseRPCKryoForDeserialization(argsKryo)
} }
msg.writeBodyBufferBytes(serializedArgs.bytes) msg.writeBodyBufferBytes(serializedArgs.bytes)
producer!!.send(ArtemisMessagingComponent.RPC_REQUESTS_QUEUE, msg) producer!!.send(ArtemisMessagingComponent.RPC_REQUESTS_QUEUE, msg)
return kryo
} }
} }
@ -242,7 +248,7 @@ class CordaRPCClientImpl(private val session: ClientSession,
msg.putLongProperty(ClientRPCRequestMessage.OBSERVATIONS_TO, observationsId) msg.putLongProperty(ClientRPCRequestMessage.OBSERVATIONS_TO, observationsId)
// And make sure that we deserialise observable handles so that they're linked to the right // And make sure that we deserialise observable handles so that they're linked to the right
// queue. Also record a bit of metadata for debugging purposes. // queue. Also record a bit of metadata for debugging purposes.
return createRPCKryo(observableSerializer = ObservableDeserializer(observationsQueueName, method.name, location)) return createRPCKryoForDeserialization(this@CordaRPCClientImpl, observationsQueueName, method.name, location)
} }
private fun createMessage(method: Method): ClientMessage { private fun createMessage(method: Method): ClientMessage {
@ -278,8 +284,7 @@ class CordaRPCClientImpl(private val session: ClientSession,
@ThreadSafe @ThreadSafe
private inner class QueuedObservable(private val qName: String, private inner class QueuedObservable(private val qName: String,
private val rpcName: String, private val rpcName: String,
private val rpcLocation: Throwable, private val rpcLocation: Throwable) {
private val observableDeserializer: ObservableDeserializer) {
private val root = PublishSubject.create<MarshalledObservation>() private val root = PublishSubject.create<MarshalledObservation>()
private val rootShared = root.doOnUnsubscribe { close() }.share() private val rootShared = root.doOnUnsubscribe { close() }.share()
@ -345,8 +350,10 @@ class CordaRPCClientImpl(private val session: ClientSession,
private fun deliver(msg: ClientMessage) { private fun deliver(msg: ClientMessage) {
msg.acknowledge() msg.acknowledge()
val kryo = createRPCKryo(observableSerializer = observableDeserializer) val kryo = createRPCKryoForDeserialization(this@CordaRPCClientImpl, qName, rpcName, rpcLocation)
val received: MarshalledObservation = msg.deserialize(kryo) val received: MarshalledObservation = try { msg.deserialize(kryo) } finally {
releaseRPCKryoForDeserialization(kryo)
}
rpcLog.debug { "<- Observable [$rpcName] <- Received $received" } rpcLog.debug { "<- Observable [$rpcName] <- Received $received" }
synchronized(observables) { synchronized(observables) {
// Force creation of the buffer if it doesn't already exist. // Force creation of the buffer if it doesn't already exist.

View File

@ -42,6 +42,8 @@ abstract class RPCDispatcher(val ops: RPCOps, val userService: RPCUserService, v
private val queueToSubscription = HashMultimap.create<String, Subscription>() private val queueToSubscription = HashMultimap.create<String, Subscription>()
private val handleCounter = AtomicInteger()
// Created afresh for every RPC that is annotated as returning observables. Every time an observable is // Created afresh for every RPC that is annotated as returning observables. Every time an observable is
// encountered either in the RPC response or in an object graph that is being emitted by one of those // encountered either in the RPC response or in an object graph that is being emitted by one of those
// observables, the handle counter is incremented and the server-side observable is subscribed to. The // observables, the handle counter is incremented and the server-side observable is subscribed to. The
@ -49,41 +51,48 @@ abstract class RPCDispatcher(val ops: RPCOps, val userService: RPCUserService, v
// //
// When the observables are deserialised on the client side, the handle is read from the byte stream and // When the observables are deserialised on the client side, the handle is read from the byte stream and
// the queue is filtered to extract just those observations. // the queue is filtered to extract just those observations.
private inner class ObservableSerializer(private val toQName: String) : Serializer<Observable<Any>>() { class ObservableSerializer() : Serializer<Observable<Any>>() {
private val handleCounter = AtomicInteger() private fun toQName(kryo: Kryo): String = kryo.context[RPCKryoQNameKey] as String
private fun toDispatcher(kryo: Kryo): RPCDispatcher = kryo.context[RPCKryoDispatcherKey] as RPCDispatcher
override fun read(kryo: Kryo, input: Input, type: Class<Observable<Any>>): Observable<Any> { override fun read(kryo: Kryo, input: Input, type: Class<Observable<Any>>): Observable<Any> {
throw UnsupportedOperationException("not implemented") throw UnsupportedOperationException("not implemented")
} }
override fun write(kryo: Kryo, output: Output, obj: Observable<Any>) { override fun write(kryo: Kryo, output: Output, obj: Observable<Any>) {
val handle = handleCounter.andIncrement val qName = toQName(kryo)
val dispatcher = toDispatcher(kryo)
val handle = dispatcher.handleCounter.andIncrement
output.writeInt(handle, true) output.writeInt(handle, true)
// Observables can do three kinds of callback: "next" with a content object, "completed" and "error". // Observables can do three kinds of callback: "next" with a content object, "completed" and "error".
// Materializing the observable converts these three kinds of callback into a single stream of objects // Materializing the observable converts these three kinds of callback into a single stream of objects
// representing what happened, which is useful for us to send over the wire. // representing what happened, which is useful for us to send over the wire.
val subscription = obj.materialize().subscribe { materialised: Notification<out Any> -> val subscription = obj.materialize().subscribe { materialised: Notification<out Any> ->
val newKryo = createRPCKryo(observableSerializer = this@ObservableSerializer) val newKryo = createRPCKryoForSerialization(qName, dispatcher)
val bits = MarshalledObservation(handle, materialised).serialize(newKryo) val bits = try { MarshalledObservation(handle, materialised).serialize(newKryo) } finally {
releaseRPCKryoForSerialization(newKryo)
}
rpcLog.debug("RPC sending observation: $materialised") rpcLog.debug("RPC sending observation: $materialised")
send(bits, toQName) dispatcher.send(bits, qName)
} }
synchronized(queueToSubscription) { synchronized(dispatcher.queueToSubscription) {
queueToSubscription.put(toQName, subscription) dispatcher.queueToSubscription.put(qName, subscription)
} }
} }
} }
fun dispatch(msg: ClientRPCRequestMessage) { fun dispatch(msg: ClientRPCRequestMessage) {
val (argsBytes, replyTo, observationsTo, methodName) = msg val (argsBytes, replyTo, observationsTo, methodName) = msg
val kryo = createRPCKryo(observableSerializer = if (observationsTo != null) ObservableSerializer(observationsTo) else null)
val response: ErrorOr<Any> = ErrorOr.catch { val response: ErrorOr<Any> = ErrorOr.catch {
val method = methodTable[methodName] ?: throw RPCException("Received RPC for unknown method $methodName - possible client/server version skew?") val method = methodTable[methodName] ?: throw RPCException("Received RPC for unknown method $methodName - possible client/server version skew?")
if (method.isAnnotationPresent(RPCReturnsObservables::class.java) && observationsTo == null) if (method.isAnnotationPresent(RPCReturnsObservables::class.java) && observationsTo == null)
throw RPCException("Received RPC without any destination for observations, but the RPC returns observables") throw RPCException("Received RPC without any destination for observations, but the RPC returns observables")
val args = argsBytes.deserialize(kryo) val kryo = createRPCKryoForSerialization(observationsTo, this)
val args = try { argsBytes.deserialize(kryo) } finally {
releaseRPCKryoForSerialization(kryo)
}
rpcLog.debug { "-> RPC -> $methodName(${args.joinToString()}) [reply to $replyTo]" } rpcLog.debug { "-> RPC -> $methodName(${args.joinToString()}) [reply to $replyTo]" }
@ -95,13 +104,15 @@ abstract class RPCDispatcher(val ops: RPCOps, val userService: RPCUserService, v
} }
rpcLog.debug { "<- RPC <- $methodName = $response " } rpcLog.debug { "<- RPC <- $methodName = $response " }
// Serialise, or send back a simple serialised ErrorOr structure if we couldn't do it. // Serialise, or send back a simple serialised ErrorOr structure if we couldn't do it.
val kryo = createRPCKryoForSerialization(observationsTo, this)
val responseBits = try { val responseBits = try {
response.serialize(kryo) response.serialize(kryo)
} catch (e: KryoException) { } catch (e: KryoException) {
rpcLog.error("Failed to respond to inbound RPC $methodName", e) rpcLog.error("Failed to respond to inbound RPC $methodName", e)
ErrorOr.of(e).serialize(kryo) ErrorOr.of(e).serialize(kryo)
} finally {
releaseRPCKryoForSerialization(kryo)
} }
send(responseBits, replyTo) send(responseBits, replyTo)
} }

View File

@ -7,6 +7,7 @@ import com.esotericsoftware.kryo.Registration
import com.esotericsoftware.kryo.Serializer import com.esotericsoftware.kryo.Serializer
import com.esotericsoftware.kryo.io.Input import com.esotericsoftware.kryo.io.Input
import com.esotericsoftware.kryo.io.Output import com.esotericsoftware.kryo.io.Output
import com.esotericsoftware.kryo.pool.KryoPool
import com.google.common.util.concurrent.ListenableFuture import com.google.common.util.concurrent.ListenableFuture
import net.corda.core.flows.FlowException import net.corda.core.flows.FlowException
import net.corda.core.serialization.* import net.corda.core.serialization.*
@ -88,10 +89,16 @@ object ClassSerializer : Serializer<Class<*>>() {
@CordaSerializable @CordaSerializable
class PermissionException(msg: String) : RuntimeException(msg) class PermissionException(msg: String) : RuntimeException(msg)
object RPCKryoClientKey
object RPCKryoDispatcherKey
object RPCKryoQNameKey
object RPCKryoMethodNameKey
object RPCKryoLocationKey
// The Kryo used for the RPC wire protocol. Every type in the wire protocol is listed here explicitly. // The Kryo used for the RPC wire protocol. Every type in the wire protocol is listed here explicitly.
// This is annoying to write out, but will make it easier to formalise the wire protocol when the time comes, // This is annoying to write out, but will make it easier to formalise the wire protocol when the time comes,
// because we can see everything we're using in one place. // because we can see everything we're using in one place.
private class RPCKryo(observableSerializer: Serializer<Observable<Any>>? = null) : CordaKryo(makeStandardClassResolver()) { private class RPCKryo(observableSerializer: Serializer<Observable<Any>>) : CordaKryo(makeStandardClassResolver()) {
init { init {
DefaultKryoCustomizer.customize(this) DefaultKryoCustomizer.customize(this)
@ -99,49 +106,68 @@ private class RPCKryo(observableSerializer: Serializer<Observable<Any>>? = null)
register(Class::class.java, ClassSerializer) register(Class::class.java, ClassSerializer)
register(MultipartStream.ItemInputStream::class.java, InputStreamSerializer) register(MultipartStream.ItemInputStream::class.java, InputStreamSerializer)
register(MarshalledObservation::class.java, ImmutableClassSerializer(MarshalledObservation::class)) register(MarshalledObservation::class.java, ImmutableClassSerializer(MarshalledObservation::class))
} register(Observable::class.java, observableSerializer)
// TODO: workaround to prevent Observable registration conflict when using plugin registered kyro classes
private val observableRegistration: Registration? = observableSerializer?.let { register(Observable::class.java, it, 10000) }
private val listenableFutureRegistration: Registration? = observableSerializer?.let {
// Register ListenableFuture by making use of Observable serialisation.
// TODO Serialisation could be made more efficient as a future can only emit one value (or exception)
@Suppress("UNCHECKED_CAST") @Suppress("UNCHECKED_CAST")
register(ListenableFuture::class, register(ListenableFuture::class,
read = { kryo, input -> it.read(kryo, input, Observable::class.java as Class<Observable<Any>>).toFuture() }, read = { kryo, input -> observableSerializer.read(kryo, input, Observable::class.java as Class<Observable<Any>>).toFuture() },
write = { kryo, output, obj -> it.write(kryo, output, obj.toObservable()) } write = { kryo, output, obj -> observableSerializer.write(kryo, output, obj.toObservable()) }
)
register(
FlowException::class,
read = { kryo, input ->
val message = input.readString()
val cause = kryo.readObjectOrNull(input, Throwable::class.java)
FlowException(message, cause)
},
write = { kryo, output, obj ->
// The subclass may have overridden toString so we use that
val message = if (obj.javaClass != FlowException::class.java) obj.toString() else obj.message
output.writeString(message)
kryo.writeObjectOrNull(output, obj.cause, Throwable::class.java)
}
) )
} }
// Avoid having to worry about the subtypes of FlowException by converting all of them to just FlowException.
// This is a temporary hack until a proper serialisation mechanism is in place.
private val flowExceptionRegistration: Registration = register(
FlowException::class,
read = { kryo, input ->
val message = input.readString()
val cause = kryo.readObjectOrNull(input, Throwable::class.java)
FlowException(message, cause)
},
write = { kryo, output, obj ->
// The subclass may have overridden toString so we use that
val message = if (obj.javaClass != FlowException::class.java) obj.toString() else obj.message
output.writeString(message)
kryo.writeObjectOrNull(output, obj.cause, Throwable::class.java)
}
)
override fun getRegistration(type: Class<*>): Registration { override fun getRegistration(type: Class<*>): Registration {
if (Observable::class.java.isAssignableFrom(type)) val annotated = context[RPCKryoQNameKey] != null
return observableRegistration ?: if (Observable::class.java.isAssignableFrom(type)) {
throw IllegalStateException("This RPC was not annotated with @RPCReturnsObservables") return if (annotated) super.getRegistration(Observable::class.java)
if (ListenableFuture::class.java.isAssignableFrom(type)) else throw IllegalStateException("This RPC was not annotated with @RPCReturnsObservables")
return listenableFutureRegistration ?: }
throw IllegalStateException("This RPC was not annotated with @RPCReturnsObservables") if (ListenableFuture::class.java.isAssignableFrom(type)) {
return if (annotated) super.getRegistration(ListenableFuture::class.java)
else throw IllegalStateException("This RPC was not annotated with @RPCReturnsObservables")
}
if (FlowException::class.java.isAssignableFrom(type)) if (FlowException::class.java.isAssignableFrom(type))
return flowExceptionRegistration return super.getRegistration(FlowException::class.java)
return super.getRegistration(type) return super.getRegistration(type)
} }
} }
fun createRPCKryo(observableSerializer: Serializer<Observable<Any>>? = null): Kryo = RPCKryo(observableSerializer) private val rpcSerKryoPool = KryoPool.Builder { RPCKryo(RPCDispatcher.ObservableSerializer()) }.build()
fun createRPCKryoForSerialization(qName: String? = null, dispatcher: RPCDispatcher? = null): Kryo {
val kryo = rpcSerKryoPool.borrow()
kryo.context.put(RPCKryoQNameKey, qName)
kryo.context.put(RPCKryoDispatcherKey, dispatcher)
return kryo
}
fun releaseRPCKryoForSerialization(kryo: Kryo) {
rpcSerKryoPool.release(kryo)
}
private val rpcDesKryoPool = KryoPool.Builder { RPCKryo(CordaRPCClientImpl.ObservableDeserializer()) }.build()
fun createRPCKryoForDeserialization(rpcClient: CordaRPCClientImpl, qName: String? = null, rpcName: String? = null, rpcLocation: Throwable? = null): Kryo {
val kryo = rpcDesKryoPool.borrow()
kryo.context.put(RPCKryoClientKey, rpcClient)
kryo.context.put(RPCKryoQNameKey, qName)
kryo.context.put(RPCKryoMethodNameKey, rpcName)
kryo.context.put(RPCKryoLocationKey, rpcLocation)
return kryo
}
fun releaseRPCKryoForDeserialization(kryo: Kryo) {
rpcDesKryoPool.release(kryo)
}

View File

@ -4,7 +4,7 @@ import net.corda.core.crypto.SecureHash
import net.corda.core.serialization.SerializedBytes import net.corda.core.serialization.SerializedBytes
import net.corda.core.serialization.deserialize import net.corda.core.serialization.deserialize
import net.corda.core.serialization.serialize import net.corda.core.serialization.serialize
import net.corda.core.serialization.threadLocalStorageKryo import net.corda.core.serialization.storageKryo
import net.corda.node.services.api.Checkpoint import net.corda.node.services.api.Checkpoint
import net.corda.node.services.api.CheckpointStorage import net.corda.node.services.api.CheckpointStorage
import net.corda.node.utilities.* import net.corda.node.utilities.*
@ -39,7 +39,7 @@ class DBCheckpointStorage : CheckpointStorage {
private val checkpointStorage = synchronizedMap(CheckpointMap()) private val checkpointStorage = synchronizedMap(CheckpointMap())
override fun addCheckpoint(checkpoint: Checkpoint) { override fun addCheckpoint(checkpoint: Checkpoint) {
checkpointStorage.put(checkpoint.id, checkpoint.serialize(threadLocalStorageKryo(), true)) checkpointStorage.put(checkpoint.id, checkpoint.serialize(storageKryo(), true))
} }
override fun removeCheckpoint(checkpoint: Checkpoint) { override fun removeCheckpoint(checkpoint: Checkpoint) {

View File

@ -6,6 +6,7 @@ import co.paralleluniverse.io.serialization.kryo.KryoSerializer
import co.paralleluniverse.strands.Strand import co.paralleluniverse.strands.Strand
import com.codahale.metrics.Gauge import com.codahale.metrics.Gauge
import com.esotericsoftware.kryo.Kryo import com.esotericsoftware.kryo.Kryo
import com.esotericsoftware.kryo.pool.KryoPool
import com.google.common.collect.HashMultimap import com.google.common.collect.HashMultimap
import com.google.common.util.concurrent.ListenableFuture import com.google.common.util.concurrent.ListenableFuture
import kotlinx.support.jdk8.collections.removeIf import kotlinx.support.jdk8.collections.removeIf
@ -71,6 +72,11 @@ class StateMachineManager(val serviceHub: ServiceHubInternal,
inner class FiberScheduler : FiberExecutorScheduler("Same thread scheduler", executor) inner class FiberScheduler : FiberExecutorScheduler("Same thread scheduler", executor)
private val quasarKryoPool = KryoPool.Builder {
val serializer = Fiber.getFiberSerializer(false) as KryoSerializer
DefaultKryoCustomizer.customize(serializer.kryo)
}.build()
companion object { companion object {
private val logger = loggerFor<StateMachineManager>() private val logger = loggerFor<StateMachineManager>()
internal val sessionTopic = TopicSession("platform.session") internal val sessionTopic = TopicSession("platform.session")
@ -354,32 +360,23 @@ class StateMachineManager(val serviceHub: ServiceHubInternal,
} }
private fun serializeFiber(fiber: FlowStateMachineImpl<*>): SerializedBytes<FlowStateMachineImpl<*>> { private fun serializeFiber(fiber: FlowStateMachineImpl<*>): SerializedBytes<FlowStateMachineImpl<*>> {
val kryo = quasarKryo() return quasarKryo().run { kryo ->
// add the map of tokens -> tokenizedServices to the kyro context // add the map of tokens -> tokenizedServices to the kyro context
SerializeAsTokenSerializer.setContext(kryo, serializationContext) SerializeAsTokenSerializer.setContext(kryo, serializationContext)
return fiber.serialize(kryo) fiber.serialize(kryo)
}
} }
private fun deserializeFiber(checkpoint: Checkpoint): FlowStateMachineImpl<*> { private fun deserializeFiber(checkpoint: Checkpoint): FlowStateMachineImpl<*> {
val kryo = quasarKryo() return quasarKryo().run { kryo ->
// put the map of token -> tokenized into the kryo context // put the map of token -> tokenized into the kryo context
SerializeAsTokenSerializer.setContext(kryo, serializationContext) SerializeAsTokenSerializer.setContext(kryo, serializationContext)
return checkpoint.serializedFiber.deserialize(kryo).apply { fromCheckpoint = true } checkpoint.serializedFiber.deserialize(kryo).apply { fromCheckpoint = true }
}
private fun quasarKryo(): Kryo {
val serializer = Fiber.getFiberSerializer(false) as KryoSerializer
return createKryo(serializer.kryo).apply {
// Because we like to stick a Kryo object in a ThreadLocal to speed things up a bit, we can end up trying to
// serialise the Kryo object itself when suspending a fiber. That's dumb, useless AND can cause crashes, so
// we avoid it here. This is checkpointing specific.
register(Kryo::class,
read = { kryo, input -> createKryo((Fiber.getFiberSerializer() as KryoSerializer).kryo) },
write = { kryo, output, obj -> }
)
} }
} }
private fun quasarKryo(): KryoPool = quasarKryoPool
private fun <T> createFiber(logic: FlowLogic<T>): FlowStateMachineImpl<T> { private fun <T> createFiber(logic: FlowLogic<T>): FlowStateMachineImpl<T> {
val id = StateMachineRunId.createRandom() val id = StateMachineRunId.createRandom()
return FlowStateMachineImpl(id, logic, scheduler).apply { initFiber(this) } return FlowStateMachineImpl(id, logic, scheduler).apply { initFiber(this) }

View File

@ -8,7 +8,6 @@ import net.corda.core.ThreadBox
import net.corda.core.bufferUntilSubscribed import net.corda.core.bufferUntilSubscribed
import net.corda.core.contracts.* import net.corda.core.contracts.*
import net.corda.core.crypto.AbstractParty import net.corda.core.crypto.AbstractParty
import net.corda.core.crypto.AnonymousParty
import net.corda.core.crypto.CompositeKey import net.corda.core.crypto.CompositeKey
import net.corda.core.crypto.SecureHash import net.corda.core.crypto.SecureHash
import net.corda.core.node.ServiceHub import net.corda.core.node.ServiceHub
@ -16,9 +15,9 @@ import net.corda.core.node.services.Vault
import net.corda.core.node.services.VaultService import net.corda.core.node.services.VaultService
import net.corda.core.node.services.unconsumedStates import net.corda.core.node.services.unconsumedStates
import net.corda.core.serialization.SingletonSerializeAsToken import net.corda.core.serialization.SingletonSerializeAsToken
import net.corda.core.serialization.createKryo
import net.corda.core.serialization.deserialize import net.corda.core.serialization.deserialize
import net.corda.core.serialization.serialize import net.corda.core.serialization.serialize
import net.corda.core.serialization.storageKryo
import net.corda.core.tee import net.corda.core.tee
import net.corda.core.transactions.TransactionBuilder import net.corda.core.transactions.TransactionBuilder
import net.corda.core.transactions.WireTransaction import net.corda.core.transactions.WireTransaction
@ -76,8 +75,7 @@ class NodeVaultService(private val services: ServiceHub, dataSourceProperties: P
index = it.key.index index = it.key.index
stateStatus = Vault.StateStatus.UNCONSUMED stateStatus = Vault.StateStatus.UNCONSUMED
contractStateClassName = it.value.state.data.javaClass.name contractStateClassName = it.value.state.data.javaClass.name
// TODO: revisit Kryo bug when using THREAD_LOCAL_KYRO contractState = it.value.state.serialize(storageKryo()).bytes
contractState = it.value.state.serialize(createKryo()).bytes
notaryName = it.value.state.notary.name notaryName = it.value.state.notary.name
notaryKey = it.value.state.notary.owningKey.toBase58String() notaryKey = it.value.state.notary.owningKey.toBase58String()
recordedTime = services.clock.instant() recordedTime = services.clock.instant()
@ -165,8 +163,7 @@ class NodeVaultService(private val services: ServiceHub, dataSourceProperties: P
Sequence{iterator} Sequence{iterator}
.map { it -> .map { it ->
val stateRef = StateRef(SecureHash.parse(it.txId), it.index) val stateRef = StateRef(SecureHash.parse(it.txId), it.index)
// TODO: revisit Kryo bug when using THREAD_LOCAL_KRYO val state = it.contractState.deserialize<TransactionState<T>>(storageKryo())
val state = it.contractState.deserialize<TransactionState<T>>(createKryo())
StateAndRef(state, stateRef) StateAndRef(state, stateRef)
} }
} }
@ -184,7 +181,7 @@ class NodeVaultService(private val services: ServiceHub, dataSourceProperties: P
.and(VaultSchema.VaultStates::index eq it.index) .and(VaultSchema.VaultStates::index eq it.index)
result.get()?.each { result.get()?.each {
val stateRef = StateRef(SecureHash.parse(it.txId), it.index) val stateRef = StateRef(SecureHash.parse(it.txId), it.index)
val state = it.contractState.deserialize<TransactionState<*>>() val state = it.contractState.deserialize<TransactionState<*>>(storageKryo())
results += StateAndRef(state, stateRef) results += StateAndRef(state, stateRef)
} }
} }
@ -353,7 +350,7 @@ class NodeVaultService(private val services: ServiceHub, dataSourceProperties: P
while (rs.next()) { while (rs.next()) {
val txHash = SecureHash.parse(rs.getString(1)) val txHash = SecureHash.parse(rs.getString(1))
val index = rs.getInt(2) val index = rs.getInt(2)
val state = rs.getBytes(3).deserialize<TransactionState<ContractState>>(createKryo()) val state = rs.getBytes(3).deserialize<TransactionState<ContractState>>(storageKryo())
consumedStates.add(StateAndRef(state, StateRef(txHash, index))) consumedStates.add(StateAndRef(state, StateRef(txHash, index)))
} }
} }

View File

@ -3,7 +3,7 @@ package net.corda.node.utilities
import net.corda.core.serialization.SerializedBytes import net.corda.core.serialization.SerializedBytes
import net.corda.core.serialization.deserialize import net.corda.core.serialization.deserialize
import net.corda.core.serialization.serialize import net.corda.core.serialization.serialize
import net.corda.core.serialization.threadLocalStorageKryo import net.corda.core.serialization.storageKryo
import net.corda.core.utilities.loggerFor import net.corda.core.utilities.loggerFor
import net.corda.core.utilities.trace import net.corda.core.utilities.trace
import org.jetbrains.exposed.sql.* import org.jetbrains.exposed.sql.*
@ -65,7 +65,7 @@ fun bytesToBlob(value: SerializedBytes<*>, finalizables: MutableList<() -> Unit>
return blob return blob
} }
fun serializeToBlob(value: Any, finalizables: MutableList<() -> Unit>): Blob = bytesToBlob(value.serialize(threadLocalStorageKryo(), true), finalizables) fun serializeToBlob(value: Any, finalizables: MutableList<() -> Unit>): Blob = bytesToBlob(value.serialize(storageKryo(), true), finalizables)
fun <T : Any> bytesFromBlob(blob: Blob): SerializedBytes<T> { fun <T : Any> bytesFromBlob(blob: Blob): SerializedBytes<T> {
try { try {

View File

@ -10,8 +10,8 @@ import net.corda.core.crypto.DigitalSignature
import net.corda.core.crypto.NullPublicKey import net.corda.core.crypto.NullPublicKey
import net.corda.core.crypto.SecureHash import net.corda.core.crypto.SecureHash
import net.corda.core.node.services.Vault import net.corda.core.node.services.Vault
import net.corda.core.serialization.createKryo
import net.corda.core.serialization.serialize import net.corda.core.serialization.serialize
import net.corda.core.serialization.storageKryo
import net.corda.core.transactions.SignedTransaction import net.corda.core.transactions.SignedTransaction
import net.corda.core.transactions.WireTransaction import net.corda.core.transactions.WireTransaction
import net.corda.core.utilities.DUMMY_NOTARY import net.corda.core.utilities.DUMMY_NOTARY
@ -128,7 +128,7 @@ class RequeryConfigurationTest {
index = txnState.index index = txnState.index
stateStatus = Vault.StateStatus.UNCONSUMED stateStatus = Vault.StateStatus.UNCONSUMED
contractStateClassName = DummyContract.SingleOwnerState::class.java.name contractStateClassName = DummyContract.SingleOwnerState::class.java.name
contractState = DummyContract.SingleOwnerState(owner = DUMMY_PUBKEY_1).serialize(createKryo()).bytes contractState = DummyContract.SingleOwnerState(owner = DUMMY_PUBKEY_1).serialize(storageKryo()).bytes
notaryName = txn.tx.notary!!.name notaryName = txn.tx.notary!!.name
notaryKey = txn.tx.notary!!.owningKey.toBase58String() notaryKey = txn.tx.notary!!.owningKey.toBase58String()
recordedTime = Instant.now() recordedTime = Instant.now()

View File

@ -1,7 +1,6 @@
package net.corda.irs.testing package net.corda.irs.testing
import net.corda.core.contracts.* import net.corda.core.contracts.*
import net.corda.core.node.recordTransactions
import net.corda.core.seconds import net.corda.core.seconds
import net.corda.core.transactions.SignedTransaction import net.corda.core.transactions.SignedTransaction
import net.corda.core.utilities.DUMMY_NOTARY import net.corda.core.utilities.DUMMY_NOTARY
@ -77,8 +76,8 @@ fun createDummyIRS(irsSelect: Int): InterestRateSwap.State {
expression = Expression("( fixedLeg.notional.pennies * (fixedLeg.fixedRate.ratioUnit.value)) -" + expression = Expression("( fixedLeg.notional.pennies * (fixedLeg.fixedRate.ratioUnit.value)) -" +
"(floatingLeg.notional.pennies * (calculation.fixingSchedule.get(context.getDate('currentDate')).rate.ratioUnit.value))"), "(floatingLeg.notional.pennies * (calculation.fixingSchedule.get(context.getDate('currentDate')).rate.ratioUnit.value))"),
floatingLegPaymentSchedule = HashMap(), floatingLegPaymentSchedule = mutableMapOf(),
fixedLegPaymentSchedule = HashMap() fixedLegPaymentSchedule = mutableMapOf()
) )
val EUR = currency("EUR") val EUR = currency("EUR")
@ -167,8 +166,8 @@ fun createDummyIRS(irsSelect: Int): InterestRateSwap.State {
expression = Expression("( fixedLeg.notional.pennies * (fixedLeg.fixedRate.ratioUnit.value)) -" + expression = Expression("( fixedLeg.notional.pennies * (fixedLeg.fixedRate.ratioUnit.value)) -" +
"(floatingLeg.notional.pennies * (calculation.fixingSchedule.get(context.getDate('currentDate')).rate.ratioUnit.value))"), "(floatingLeg.notional.pennies * (calculation.fixingSchedule.get(context.getDate('currentDate')).rate.ratioUnit.value))"),
floatingLegPaymentSchedule = HashMap(), floatingLegPaymentSchedule = mutableMapOf(),
fixedLegPaymentSchedule = HashMap() fixedLegPaymentSchedule = mutableMapOf()
) )
val EUR = currency("EUR") val EUR = currency("EUR")
@ -413,7 +412,7 @@ class IRSTests {
@Test @Test
fun `ensure failure occurs when no events in fix schedule`() { fun `ensure failure occurs when no events in fix schedule`() {
val irs = singleIRS() val irs = singleIRS()
val emptySchedule = HashMap<LocalDate, FixedRatePaymentEvent>() val emptySchedule = mutableMapOf<LocalDate, FixedRatePaymentEvent>()
transaction { transaction {
output() { output() {
irs.copy(calculation = irs.calculation.copy(fixedLegPaymentSchedule = emptySchedule)) irs.copy(calculation = irs.calculation.copy(fixedLegPaymentSchedule = emptySchedule))
@ -427,7 +426,7 @@ class IRSTests {
@Test @Test
fun `ensure failure occurs when no events in floating schedule`() { fun `ensure failure occurs when no events in floating schedule`() {
val irs = singleIRS() val irs = singleIRS()
val emptySchedule = HashMap<LocalDate, FloatingRatePaymentEvent>() val emptySchedule = mutableMapOf<LocalDate, FloatingRatePaymentEvent>()
transaction { transaction {
output() { output() {
irs.copy(calculation = irs.calculation.copy(floatingLegPaymentSchedule = emptySchedule)) irs.copy(calculation = irs.calculation.copy(floatingLegPaymentSchedule = emptySchedule))