diff --git a/.ci/api-current.txt b/.ci/api-current.txt index 56c3c79be5..39388ad791 100644 --- a/.ci/api-current.txt +++ b/.ci/api-current.txt @@ -5937,6 +5937,13 @@ public @interface net.corda.core.serialization.CordaSerializationTransformRename public @interface net.corda.core.serialization.CordaSerializationTransformRenames public abstract net.corda.core.serialization.CordaSerializationTransformRename[] value() ## +public interface net.corda.core.serialization.CustomSerializationScheme + @NotNull + public abstract T deserialize(net.corda.core.utilities.ByteSequence, Class, net.corda.core.serialization.SerializationSchemeContext) + public abstract int getSchemeId() + @NotNull + public abstract net.corda.core.utilities.ByteSequence serialize(T, net.corda.core.serialization.SerializationSchemeContext) +## public @interface net.corda.core.serialization.DeprecatedConstructorForDeserialization public abstract int version() ## @@ -6078,6 +6085,13 @@ public static final class net.corda.core.serialization.SerializationFactory$Comp @NotNull public final net.corda.core.serialization.SerializationFactory getDefaultFactory() ## +@DoNotImplement +public interface net.corda.core.serialization.SerializationSchemeContext + @NotNull + public abstract ClassLoader getDeserializationClassLoader() + @NotNull + public abstract net.corda.core.serialization.ClassWhitelist getWhitelist() +## public interface net.corda.core.serialization.SerializationToken @NotNull public abstract Object fromToken(net.corda.core.serialization.SerializeAsTokenContext) diff --git a/core-tests/src/test/kotlin/net/corda/coretests/transactions/TransactionBuilderTest.kt b/core-tests/src/test/kotlin/net/corda/coretests/transactions/TransactionBuilderTest.kt index 882466a059..0f3c35300b 100644 --- a/core-tests/src/test/kotlin/net/corda/coretests/transactions/TransactionBuilderTest.kt +++ b/core-tests/src/test/kotlin/net/corda/coretests/transactions/TransactionBuilderTest.kt @@ -3,7 +3,16 @@ package net.corda.coretests.transactions import com.nhaarman.mockito_kotlin.doReturn import com.nhaarman.mockito_kotlin.mock import com.nhaarman.mockito_kotlin.whenever -import net.corda.core.contracts.* +import net.corda.core.contracts.Command +import net.corda.core.contracts.ContractAttachment +import net.corda.core.contracts.HashAttachmentConstraint +import net.corda.core.contracts.PrivacySalt +import net.corda.core.contracts.SignatureAttachmentConstraint +import net.corda.core.contracts.StateAndRef +import net.corda.core.contracts.StateRef +import net.corda.core.contracts.TimeWindow +import net.corda.core.contracts.TransactionState +import net.corda.core.contracts.TransactionVerificationException import net.corda.core.cordapp.CordappProvider import net.corda.core.crypto.CompositeKey import net.corda.core.crypto.DigestService @@ -20,11 +29,16 @@ import net.corda.core.node.services.IdentityService import net.corda.core.node.services.NetworkParametersService import net.corda.core.serialization.serialize import net.corda.core.transactions.TransactionBuilder +import net.corda.coretesting.internal.rigorousMock import net.corda.testing.common.internal.testNetworkParameters import net.corda.testing.contracts.DummyContract import net.corda.testing.contracts.DummyState -import net.corda.testing.core.* -import net.corda.coretesting.internal.rigorousMock +import net.corda.testing.core.ALICE_NAME +import net.corda.testing.core.BOB_NAME +import net.corda.testing.core.DUMMY_NOTARY_NAME +import net.corda.testing.core.DummyCommandData +import net.corda.testing.core.SerializationEnvironmentRule +import net.corda.testing.core.TestIdentity import org.assertj.core.api.Assertions.assertThat import org.assertj.core.api.Assertions.assertThatThrownBy import org.junit.Assert.assertFalse @@ -35,6 +49,7 @@ import org.junit.Rule import org.junit.Test import java.security.PublicKey import java.time.Instant +import kotlin.test.assertFailsWith class TransactionBuilderTest { @Rule @@ -299,4 +314,22 @@ class TransactionBuilderTest { HashAgility.init() } } + + @Test(timeout=300_000) + fun `toWireTransaction fails if no scheme is registered with schemeId`() { + val outputState = TransactionState( + data = DummyState(), + contract = DummyContract.PROGRAM_ID, + notary = notary, + constraint = HashAttachmentConstraint(contractAttachmentId) + ) + val builder = TransactionBuilder() + .addOutputState(outputState) + .addCommand(DummyCommandData, notary.owningKey) + + val schemeId = 7 + assertFailsWith("Could not find custom serialization scheme with SchemeId = $schemeId.") { + builder.toWireTransaction(services, schemeId) + } + } } diff --git a/core/src/main/kotlin/net/corda/core/internal/TransactionUtils.kt b/core/src/main/kotlin/net/corda/core/internal/TransactionUtils.kt index e807329a92..01ea4cf421 100644 --- a/core/src/main/kotlin/net/corda/core/internal/TransactionUtils.kt +++ b/core/src/main/kotlin/net/corda/core/internal/TransactionUtils.kt @@ -154,7 +154,8 @@ fun createComponentGroups(inputs: List, timeWindow: TimeWindow?, references: List, networkParametersHash: SecureHash?): List { - val serialize = { value: Any, _: Int -> value.serialize() } + val serializationContext = SerializationFactory.defaultFactory.defaultContext + val serialize = { value: Any, _: Int -> value.serialize(context = serializationContext) } val componentGroupMap: MutableList = mutableListOf() if (inputs.isNotEmpty()) componentGroupMap.add(ComponentGroup(ComponentGroupEnum.INPUTS_GROUP.ordinal, inputs.lazyMapped(serialize))) if (references.isNotEmpty()) componentGroupMap.add(ComponentGroup(ComponentGroupEnum.REFERENCES_GROUP.ordinal, references.lazyMapped(serialize))) diff --git a/core/src/main/kotlin/net/corda/core/serialization/CustomSerializationScheme.kt b/core/src/main/kotlin/net/corda/core/serialization/CustomSerializationScheme.kt new file mode 100644 index 0000000000..599d250e67 --- /dev/null +++ b/core/src/main/kotlin/net/corda/core/serialization/CustomSerializationScheme.kt @@ -0,0 +1,35 @@ +package net.corda.core.serialization + +import net.corda.core.utilities.ByteSequence +import java.io.NotSerializableException + +/*** + * Implement this interface to add your own Serialization Scheme. This is an experimental feature. All methods in this class MUST be + * thread safe i.e. methods from the same instance of this class can be called in different threads simultaneously. + */ +interface CustomSerializationScheme { + /** + * This method must return an id used to uniquely identify the Scheme. This should be unique within a network as serialized data might + * be sent over the wire. + */ + fun getSchemeId(): Int + + /** + * This method must deserialize the data stored [bytes] into an instance of [T]. + * + * @param bytes the serialized data. + * @param clazz the class to instantiate. + * @param context used to pass information about how the object should be deserialized. + */ + @Throws(NotSerializableException::class) + fun deserialize(bytes: ByteSequence, clazz: Class, context: SerializationSchemeContext): T + + /** + * This method must be able to serialize any object [T] into a ByteSequence. + * + * @param obj the object to be serialized. + * @param context used to pass information about how the object should be serialized. + */ + @Throws(NotSerializableException::class) + fun serialize(obj: T, context: SerializationSchemeContext): ByteSequence +} \ No newline at end of file diff --git a/core/src/main/kotlin/net/corda/core/serialization/SerializationAPI.kt b/core/src/main/kotlin/net/corda/core/serialization/SerializationAPI.kt index c97a511db2..bcae581e66 100644 --- a/core/src/main/kotlin/net/corda/core/serialization/SerializationAPI.kt +++ b/core/src/main/kotlin/net/corda/core/serialization/SerializationAPI.kt @@ -134,7 +134,7 @@ interface SerializationContext { */ val encodingWhitelist: EncodingWhitelist /** - * A map of any addition properties specific to the particular use case. + * A map of any additional properties specific to the particular use case. */ val properties: Map /** @@ -178,6 +178,11 @@ interface SerializationContext { */ fun withProperty(property: Any, value: Any): SerializationContext + /** + * Helper method to return a new context based on this context with the extra properties added. + */ + fun withProperties(extraProperties: Map): SerializationContext + /** * Helper method to return a new context based on this context with object references disabled. */ diff --git a/core/src/main/kotlin/net/corda/core/serialization/SerializationSchemeContext.kt b/core/src/main/kotlin/net/corda/core/serialization/SerializationSchemeContext.kt new file mode 100644 index 0000000000..eb1709390a --- /dev/null +++ b/core/src/main/kotlin/net/corda/core/serialization/SerializationSchemeContext.kt @@ -0,0 +1,30 @@ +package net.corda.core.serialization + +import net.corda.core.DoNotImplement + +/** + * This is used to pass information into [CustomSerializationScheme] about how the object should be (de)serialized. + * This context can change depending on the specific circumstances in the node when (de)serialization occurs. + */ +@DoNotImplement +interface SerializationSchemeContext { + /** + * The class loader to use for deserialization. This is guaranteed to be able to load all the required classes passed into + * [CustomSerializationScheme.deserialize]. + */ + val deserializationClassLoader: ClassLoader + /** + * A whitelist that contains (mostly for security purposes) which classes are authorised to be deserialized. + * A secure implementation will not instantiate any object which is not either whitelisted or annotated with [CordaSerializable] when + * deserializing. To catch classes missing from the whitelist as early as possible it is HIGHLY recommended to also check this + * whitelist when serializing (as well as deserializing) objects. + */ + val whitelist: ClassWhitelist + /** + * A map of any additional properties specific to the particular use case. If these properties are set via + * [toWireTransaction][net.corda.core.transactions.TransactionBuilder.toWireTransaction] then they might not be available when + * deserializing. If the properties are required when deserializing, they can be added into the blob when serializing and read back + * when deserializing. + */ + val properties: Map +} \ No newline at end of file diff --git a/core/src/main/kotlin/net/corda/core/serialization/internal/CustomSerializationSchemeUtils.kt b/core/src/main/kotlin/net/corda/core/serialization/internal/CustomSerializationSchemeUtils.kt new file mode 100644 index 0000000000..b0588755aa --- /dev/null +++ b/core/src/main/kotlin/net/corda/core/serialization/internal/CustomSerializationSchemeUtils.kt @@ -0,0 +1,28 @@ +package net.corda.core.serialization.internal + +import net.corda.core.KeepForDJVM +import net.corda.core.serialization.SerializationMagic +import net.corda.core.utilities.ByteSequence +import java.nio.ByteBuffer + +class CustomSerializationSchemeUtils { + + @KeepForDJVM + companion object { + + private const val SERIALIZATION_SCHEME_ID_SIZE = 4 + private val PREFIX = "CUS".toByteArray() + + fun getCustomSerializationMagicFromSchemeId(schemeId: Int) : SerializationMagic { + return SerializationMagic.of(PREFIX + ByteBuffer.allocate(SERIALIZATION_SCHEME_ID_SIZE).putInt(schemeId).array()) + } + + fun getSchemeIdIfCustomSerializationMagic(magic: SerializationMagic): Int? { + return if (magic.take(PREFIX.size) != ByteSequence.of(PREFIX)) { + null + } else { + return magic.slice(start = PREFIX.size).int + } + } + } +} \ No newline at end of file diff --git a/core/src/main/kotlin/net/corda/core/transactions/TransactionBuilder.kt b/core/src/main/kotlin/net/corda/core/transactions/TransactionBuilder.kt index 2b2f9655c2..71fb6c728b 100644 --- a/core/src/main/kotlin/net/corda/core/transactions/TransactionBuilder.kt +++ b/core/src/main/kotlin/net/corda/core/transactions/TransactionBuilder.kt @@ -16,14 +16,18 @@ import net.corda.core.node.ServicesForResolution import net.corda.core.node.ZoneVersionTooLowException import net.corda.core.node.services.AttachmentId import net.corda.core.node.services.KeyManagementService +import net.corda.core.serialization.CustomSerializationScheme import net.corda.core.serialization.SerializationContext +import net.corda.core.serialization.SerializationDefaults import net.corda.core.serialization.SerializationFactory +import net.corda.core.serialization.SerializationMagic +import net.corda.core.serialization.SerializationSchemeContext +import net.corda.core.serialization.internal.CustomSerializationSchemeUtils.Companion.getCustomSerializationMagicFromSchemeId import net.corda.core.utilities.contextLogger import java.security.PublicKey import java.time.Duration import java.time.Instant -import java.util.ArrayDeque -import java.util.UUID +import java.util.* import java.util.regex.Pattern import kotlin.collections.ArrayList import kotlin.collections.component1 @@ -140,6 +144,41 @@ open class TransactionBuilder( fun toWireTransaction(services: ServicesForResolution): WireTransaction = toWireTransactionWithContext(services, null) .apply { checkSupportedHashType() } + /** + * Generates a [WireTransaction] from this builder, resolves any [AutomaticPlaceholderConstraint], and selects the attachments to use for this transaction. + * + * @param [schemeId] is used to specify the [CustomSerializationScheme] used to serialize each component of the componentGroups of the [WireTransaction]. + * This is an experimental feature. + * + * @returns A new [WireTransaction] that will be unaffected by further changes to this [TransactionBuilder]. + * + * @throws [ZoneVersionTooLowException] if there are reference states and the zone minimum platform version is less than 4. + */ + @Throws(MissingContractAttachments::class) + fun toWireTransaction(services: ServicesForResolution, schemeId: Int): WireTransaction { + return toWireTransaction(services, schemeId, emptyMap()).apply { checkSupportedHashType() } + } + + /** + * Generates a [WireTransaction] from this builder, resolves any [AutomaticPlaceholderConstraint], and selects the attachments to use for this transaction. + * + * @param [schemeId] is used to specify the [CustomSerializationScheme] used to serialize each component of the componentGroups of the [WireTransaction]. + * This is an experimental feature. + * + * @param [properties] a list of properties to add to the [SerializationSchemeContext] these properties can be accessed in [CustomSerializationScheme.serialize] + * when serializing the componentGroups of the wire transaction but might not be available when deserializing. + * + * @returns A new [WireTransaction] that will be unaffected by further changes to this [TransactionBuilder]. + * + * @throws [ZoneVersionTooLowException] if there are reference states and the zone minimum platform version is less than 4. + */ + @Throws(MissingContractAttachments::class) + fun toWireTransaction(services: ServicesForResolution, schemeId: Int, properties: Map): WireTransaction { + val magic: SerializationMagic = getCustomSerializationMagicFromSchemeId(schemeId) + val serializationContext = SerializationDefaults.P2P_CONTEXT.withPreferredSerializationVersion(magic).withProperties(properties) + return toWireTransactionWithContext(services, serializationContext).apply { checkSupportedHashType() } + } + @CordaInternal internal fun toWireTransactionWithContext( services: ServicesForResolution, diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/CustomSerializationSchemeAdapter.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/CustomSerializationSchemeAdapter.kt new file mode 100644 index 0000000000..f656f81502 --- /dev/null +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/CustomSerializationSchemeAdapter.kt @@ -0,0 +1,47 @@ +package net.corda.nodeapi.internal.serialization + +import net.corda.core.serialization.SerializationSchemeContext +import net.corda.core.serialization.CustomSerializationScheme +import net.corda.core.serialization.SerializationContext +import net.corda.core.serialization.SerializedBytes +import net.corda.core.serialization.internal.CustomSerializationSchemeUtils.Companion.getCustomSerializationMagicFromSchemeId +import net.corda.core.utilities.ByteSequence +import net.corda.serialization.internal.CordaSerializationMagic +import net.corda.serialization.internal.SerializationScheme +import java.io.ByteArrayOutputStream +import java.io.NotSerializableException + +class CustomSerializationSchemeAdapter(private val customScheme: CustomSerializationScheme): SerializationScheme { + + val serializationSchemeMagic = getCustomSerializationMagicFromSchemeId(customScheme.getSchemeId()) + + override fun canDeserializeVersion(magic: CordaSerializationMagic, target: SerializationContext.UseCase): Boolean { + return magic == serializationSchemeMagic + } + + override fun deserialize(byteSequence: ByteSequence, clazz: Class, context: SerializationContext): T { + val readMagic = byteSequence.take(serializationSchemeMagic.size) + if (readMagic != serializationSchemeMagic) { + throw NotSerializableException("Scheme ${customScheme::class.java} is incompatible with blob." + + " Magic from blob = $readMagic (Expected = $serializationSchemeMagic)") + } + return customScheme.deserialize( + byteSequence.subSequence(serializationSchemeMagic.size, byteSequence.size - serializationSchemeMagic.size), + clazz, + SerializationSchemeContextAdapter(context) + ) + } + + override fun serialize(obj: T, context: SerializationContext): SerializedBytes { + val stream = ByteArrayOutputStream() + stream.write(serializationSchemeMagic.bytes) + stream.write(customScheme.serialize(obj, SerializationSchemeContextAdapter(context)).bytes) + return SerializedBytes(stream.toByteArray()) + } + + private class SerializationSchemeContextAdapter(context: SerializationContext) : SerializationSchemeContext { + override val deserializationClassLoader = context.deserializationClassLoader + override val whitelist = context.whitelist + override val properties = context.properties + } +} \ No newline at end of file diff --git a/node-api/src/test/java/net/corda/nodeapi/internal/serialization/DummyCustomSerializationSchemeInJava.java b/node-api/src/test/java/net/corda/nodeapi/internal/serialization/DummyCustomSerializationSchemeInJava.java new file mode 100644 index 0000000000..3be21e0b86 --- /dev/null +++ b/node-api/src/test/java/net/corda/nodeapi/internal/serialization/DummyCustomSerializationSchemeInJava.java @@ -0,0 +1,30 @@ +package net.corda.nodeapi.internal.serialization; + +import net.corda.core.serialization.SerializationSchemeContext; +import net.corda.core.serialization.CustomSerializationScheme; +import net.corda.core.serialization.SerializedBytes; +import net.corda.core.utilities.ByteSequence; + +public class DummyCustomSerializationSchemeInJava implements CustomSerializationScheme { + + public class DummyOutput {} + + static final int testMagic = 7; + + @Override + public int getSchemeId() { + return testMagic; + } + + @Override + @SuppressWarnings("unchecked") + public T deserialize(ByteSequence bytes, Class clazz, SerializationSchemeContext context) { + return (T)new DummyOutput(); + } + + @Override + public SerializedBytes serialize(T obj, SerializationSchemeContext context) { + byte[] myBytes = {0xA, 0xA}; + return new SerializedBytes<>(myBytes); + } +} \ No newline at end of file diff --git a/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/CustomSerializationSchemeAdapterTests.kt b/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/CustomSerializationSchemeAdapterTests.kt new file mode 100644 index 0000000000..2d4f751ddf --- /dev/null +++ b/node-api/src/test/kotlin/net/corda/nodeapi/internal/serialization/CustomSerializationSchemeAdapterTests.kt @@ -0,0 +1,93 @@ +package net.corda.nodeapi.internal.serialization + +import net.corda.core.serialization.SerializationSchemeContext +import net.corda.core.serialization.CustomSerializationScheme +import net.corda.core.utilities.ByteSequence +import net.corda.nodeapi.internal.serialization.testutils.serializationContext +import org.junit.Test +import org.junit.jupiter.api.Assertions.assertTrue +import java.io.NotSerializableException +import kotlin.test.assertFailsWith + +class CustomSerializationSchemeAdapterTests { + + companion object { + const val DEFAULT_SCHEME_ID = 7 + } + + class DummyInputClass + class DummyOutputClass + + class SingleInputAndOutputScheme(private val schemeId: Int = DEFAULT_SCHEME_ID): CustomSerializationScheme { + + override fun getSchemeId(): Int { + return schemeId + } + + override fun deserialize(bytes: ByteSequence, clazz: Class, context: SerializationSchemeContext): T { + @Suppress("UNCHECKED_CAST") + return DummyOutputClass() as T + } + + override fun serialize(obj: T, context: SerializationSchemeContext): ByteSequence { + assertTrue(obj is DummyInputClass) + return ByteSequence.of(ByteArray(2) { 0x2 }) + } + } + + class SameBytesInputAndOutputsAndScheme: CustomSerializationScheme { + + private val expectedBytes = "123456789".toByteArray() + + override fun getSchemeId(): Int { + return DEFAULT_SCHEME_ID + } + + override fun deserialize(bytes: ByteSequence, clazz: Class, context: SerializationSchemeContext): T { + bytes.open().use { + val data = ByteArray(expectedBytes.size) { 0 } + it.read(data) + assertTrue(data.contentEquals(expectedBytes)) + } + @Suppress("UNCHECKED_CAST") + return DummyOutputClass() as T + } + + override fun serialize(obj: T, context: SerializationSchemeContext): ByteSequence { + return ByteSequence.of(expectedBytes) + } + } + + @Test(timeout=300_000) + fun `CustomSerializationSchemeAdapter calls the correct methods in CustomSerializationScheme`() { + val scheme = CustomSerializationSchemeAdapter(SingleInputAndOutputScheme()) + val serializedData = scheme.serialize(DummyInputClass(), serializationContext) + val roundTripped = scheme.deserialize(serializedData, Any::class.java, serializationContext) + assertTrue(roundTripped is DummyOutputClass) + } + + @Test(timeout=300_000) + fun `CustomSerializationSchemeAdapter can adapt a Java implementation`() { + val scheme = CustomSerializationSchemeAdapter(DummyCustomSerializationSchemeInJava()) + val serializedData = scheme.serialize(DummyInputClass(), serializationContext) + val roundTripped = scheme.deserialize(serializedData, Any::class.java, serializationContext) + assertTrue(roundTripped is DummyCustomSerializationSchemeInJava.DummyOutput) + } + + @Test(timeout=300_000) + fun `CustomSerializationSchemeAdapter validates the magic`() { + val inScheme = CustomSerializationSchemeAdapter(SingleInputAndOutputScheme()) + val serializedData = inScheme.serialize(DummyInputClass(), serializationContext) + val outScheme = CustomSerializationSchemeAdapter(SingleInputAndOutputScheme(8)) + assertFailsWith { + outScheme.deserialize(serializedData, DummyOutputClass::class.java, serializationContext) + } + } + + @Test(timeout=300_000) + fun `CustomSerializationSchemeAdapter preserves the serialized bytes between deserialize and serialize`() { + val scheme = CustomSerializationSchemeAdapter(SameBytesInputAndOutputsAndScheme()) + val serializedData = scheme.serialize(Any(), serializationContext) + scheme.deserialize(serializedData, Any::class.java, serializationContext) + } +} \ No newline at end of file diff --git a/node/src/integration-test/kotlin/net/corda/node/CustomSerializationSchemeDriverTest.kt b/node/src/integration-test/kotlin/net/corda/node/CustomSerializationSchemeDriverTest.kt new file mode 100644 index 0000000000..430b7dbac8 --- /dev/null +++ b/node/src/integration-test/kotlin/net/corda/node/CustomSerializationSchemeDriverTest.kt @@ -0,0 +1,342 @@ +package net.corda.node + +import co.paralleluniverse.fibers.Suspendable +import com.esotericsoftware.kryo.Kryo +import com.esotericsoftware.kryo.io.Input +import com.esotericsoftware.kryo.io.Output +import de.javakaffee.kryoserializers.ArraysAsListSerializer +import net.corda.core.contracts.AlwaysAcceptAttachmentConstraint +import net.corda.core.contracts.BelongsToContract +import net.corda.core.contracts.Contract +import net.corda.core.contracts.ContractState +import net.corda.core.contracts.TransactionState +import net.corda.core.contracts.TypeOnlyCommandData +import net.corda.core.crypto.Crypto +import net.corda.core.crypto.SecureHash +import net.corda.core.crypto.SignableData +import net.corda.core.crypto.SignatureMetadata +import net.corda.core.flows.CollectSignaturesFlow +import net.corda.core.flows.FinalityFlow +import net.corda.core.flows.FlowLogic +import net.corda.core.flows.FlowSession +import net.corda.core.flows.InitiatedBy +import net.corda.core.flows.InitiatingFlow +import net.corda.core.flows.ReceiveFinalityFlow +import net.corda.core.flows.SignTransactionFlow +import net.corda.core.flows.StartableByRPC +import net.corda.core.identity.AbstractParty +import net.corda.core.identity.Party +import net.corda.core.internal.concurrent.transpose +import net.corda.core.internal.copyBytes +import net.corda.core.messaging.startFlow +import net.corda.core.serialization.CustomSerializationScheme +import net.corda.core.serialization.SerializationSchemeContext +import net.corda.core.serialization.internal.CustomSerializationSchemeUtils.Companion.getSchemeIdIfCustomSerializationMagic +import net.corda.core.transactions.LedgerTransaction +import net.corda.core.transactions.SignedTransaction +import net.corda.core.transactions.TransactionBuilder +import net.corda.core.transactions.WireTransaction +import net.corda.core.utilities.ByteSequence +import net.corda.core.utilities.getOrThrow +import net.corda.core.utilities.unwrap +import net.corda.serialization.internal.CordaSerializationMagic +import net.corda.serialization.internal.SerializationFactoryImpl +import net.corda.testing.core.ALICE_NAME +import net.corda.testing.core.BOB_NAME +import net.corda.testing.core.DUMMY_NOTARY_NAME +import net.corda.testing.core.TestIdentity +import net.corda.testing.driver.DriverParameters +import net.corda.testing.driver.NodeParameters +import net.corda.testing.driver.driver +import net.corda.testing.node.internal.enclosedCordapp +import org.junit.Test +import org.objenesis.instantiator.ObjectInstantiator +import org.objenesis.strategy.InstantiatorStrategy +import org.objenesis.strategy.StdInstantiatorStrategy +import java.io.ByteArrayOutputStream +import java.lang.reflect.Modifier +import java.util.* +import kotlin.test.assertEquals +import kotlin.test.assertTrue + +class CustomSerializationSchemeDriverTest { + + @Test(timeout = 300_000) + fun `flow can send wire transaction serialized with custom kryo serializer`() { + driver(DriverParameters(notarySpecs = emptyList(), startNodesInProcess = true, cordappsForAllNodes = listOf(enclosedCordapp()))) { + val (alice, bob) = listOf( + startNode(NodeParameters(providedName = ALICE_NAME)), + startNode(NodeParameters(providedName = BOB_NAME)) + ).transpose().getOrThrow() + + val flow = alice.rpc.startFlow(::SendFlow, bob.nodeInfo.legalIdentities.single()) + assertTrue { flow.returnValue.getOrThrow() } + } + } + + @Test(timeout = 300_000) + fun `flow can write a wire transaction serialized with custom kryo serializer to the ledger`() { + driver(DriverParameters(startNodesInProcess = true, cordappsForAllNodes = listOf(enclosedCordapp()))) { + val (alice, bob) = listOf( + startNode(NodeParameters(providedName = ALICE_NAME)), + startNode(NodeParameters(providedName = BOB_NAME)) + ).transpose().getOrThrow() + + val flow = alice.rpc.startFlow(::WriteTxToLedgerFlow, bob.nodeInfo.legalIdentities.single(), defaultNotaryIdentity) + val txId = flow.returnValue.getOrThrow() + val transaction = alice.rpc.startFlow(::GetTxFromDBFlow, txId).returnValue.getOrThrow() + + for(group in transaction!!.tx.componentGroups) { + for (item in group.components) { + val magic = CordaSerializationMagic(item.slice(end = SerializationFactoryImpl.magicSize).copyBytes()) + assertEquals( KryoScheme.SCHEME_ID, getSchemeIdIfCustomSerializationMagic(magic)) + } + } + } + } + + @Test(timeout = 300_000) + fun `Component groups are lazily serialized by the CustomSerializationScheme`() { + driver(DriverParameters(notarySpecs = emptyList(), startNodesInProcess = true, cordappsForAllNodes = listOf(enclosedCordapp()))) { + val alice = startNode(NodeParameters(providedName = ALICE_NAME)).getOrThrow() + //We don't need a real notary as we don't verify the transaction in this test. + val dummyNotary = TestIdentity(DUMMY_NOTARY_NAME, 20) + assertTrue { alice.rpc.startFlow(::CheckComponentGroupsFlow, dummyNotary.party).returnValue.getOrThrow() } + } + } + + @Test(timeout = 300_000) + fun `Map in the serialization context can be used by lazily component group serialization`() { + driver(DriverParameters(notarySpecs = emptyList(), startNodesInProcess = true, cordappsForAllNodes = listOf(enclosedCordapp()))) { + val alice = startNode(NodeParameters(providedName = ALICE_NAME)).getOrThrow() + //We don't need a real notary as we don't verify the transaction in this test. + val dummyNotary = TestIdentity(DUMMY_NOTARY_NAME, 20) + assertTrue { alice.rpc.startFlow(::CheckComponentGroupsWithMapFlow, dummyNotary.party).returnValue.getOrThrow() } + } + } + + @StartableByRPC + @InitiatingFlow + class WriteTxToLedgerFlow(val counterparty: Party, val notary: Party) : FlowLogic() { + @Suspendable + override fun call(): SecureHash { + val outputState = TransactionState( + data = DummyContract.DummyState(), + contract = DummyContract::class.java.name, + notary = notary, + constraint = AlwaysAcceptAttachmentConstraint + ) + val builder = TransactionBuilder() + .addOutputState(outputState) + .addCommand(DummyCommandData, counterparty.owningKey) + val wireTx = builder.toWireTransaction(serviceHub, KryoScheme.SCHEME_ID) + val partSignedTx = signWireTx(wireTx) + val session = initiateFlow(counterparty) + val fullySignedTx = subFlow(CollectSignaturesFlow(partSignedTx, setOf(session))) + subFlow(FinalityFlow(fullySignedTx, setOf(session))) + return fullySignedTx.id + } + + fun signWireTx(wireTx: WireTransaction) : SignedTransaction { + val signatureMetadata = SignatureMetadata( + serviceHub.myInfo.platformVersion, + Crypto.findSignatureScheme(serviceHub.myInfo.legalIdentitiesAndCerts.first().owningKey).schemeNumberID + ) + val signableData = SignableData(wireTx.id, signatureMetadata) + val sig = serviceHub.keyManagementService.sign(signableData, serviceHub.myInfo.legalIdentitiesAndCerts.first().owningKey) + return SignedTransaction(wireTx, listOf(sig)) + } + } + + @InitiatedBy(WriteTxToLedgerFlow::class) + class SignWireTxFlow(private val session: FlowSession): FlowLogic() { + @Suspendable + override fun call(): SignedTransaction { + val signTransactionFlow = object : SignTransactionFlow(session) { + override fun checkTransaction(stx: SignedTransaction) { + return + } + } + val txId = subFlow(signTransactionFlow).id + return subFlow(ReceiveFinalityFlow(session, expectedTxId = txId)) + } + } + + @StartableByRPC + class GetTxFromDBFlow(private val txId: SecureHash): FlowLogic() { + override fun call(): SignedTransaction? { + return serviceHub.validatedTransactions.getTransaction(txId) + } + } + + @StartableByRPC + @InitiatingFlow + class CheckComponentGroupsFlow(val notary: Party) : FlowLogic() { + @Suspendable + override fun call(): Boolean { + val outputState = TransactionState( + data = DummyContract.DummyState(), + contract = DummyContract::class.java.name, + notary = notary, + constraint = AlwaysAcceptAttachmentConstraint + ) + val builder = TransactionBuilder() + .addOutputState(outputState) + .addCommand(DummyCommandData, notary.owningKey) + + val wtx = builder.toWireTransaction(serviceHub, KryoScheme.SCHEME_ID) + var success = true + for (group in wtx.componentGroups) { + //Component groups are lazily serialized as we iterate through. + for (item in group.components) { + val magic = CordaSerializationMagic(item.slice(end = SerializationFactoryImpl.magicSize).copyBytes()) + success = success && (getSchemeIdIfCustomSerializationMagic(magic) == KryoScheme.SCHEME_ID) + } + } + return success + } + } + + @StartableByRPC + @InitiatingFlow + class CheckComponentGroupsWithMapFlow(val notary: Party) : FlowLogic() { + @Suspendable + override fun call(): Boolean { + val outputState = TransactionState( + data = DummyContract.DummyState(), + contract = DummyContract::class.java.name, + notary = notary, + constraint = AlwaysAcceptAttachmentConstraint + ) + val builder = TransactionBuilder() + .addOutputState(outputState) + .addCommand(DummyCommandData, notary.owningKey) + val mapToCheckWhenSerializing = mapOf(Pair(KryoSchemeWithMap.KEY, KryoSchemeWithMap.VALUE)) + val wtx = builder.toWireTransaction(serviceHub, KryoSchemeWithMap.SCHEME_ID, mapToCheckWhenSerializing) + var success = true + for (group in wtx.componentGroups) { + //Component groups are lazily serialized as we iterate through. + for (item in group.components) { + val magic = CordaSerializationMagic(item.slice(end = SerializationFactoryImpl.magicSize).copyBytes()) + success = success && (getSchemeIdIfCustomSerializationMagic(magic) == KryoSchemeWithMap.SCHEME_ID) + } + } + return success + } + } + + @StartableByRPC + @InitiatingFlow + class SendFlow(val counterparty: Party) : FlowLogic() { + @Suspendable + override fun call(): Boolean { + val outputState = TransactionState( + data = DummyContract.DummyState(), + contract = DummyContract::class.java.name, + notary = counterparty, + constraint = AlwaysAcceptAttachmentConstraint + ) + val builder = TransactionBuilder() + .addOutputState(outputState) + .addCommand(DummyCommandData, counterparty.owningKey) + + val wtx = builder.toWireTransaction(serviceHub, KryoScheme.SCHEME_ID) + val session = initiateFlow(counterparty) + session.send(wtx) + return session.receive().unwrap {it} + } + } + + @InitiatedBy(SendFlow::class) + class ReceiveFlow(private val session: FlowSession): FlowLogic() { + @Suspendable + override fun call() { + val message = session.receive().unwrap {it} + message.toLedgerTransaction(serviceHub) + session.send(true) + } + } + + class DummyContract: Contract { + @BelongsToContract(DummyContract::class) + class DummyState(override val participants: List = listOf()) : ContractState + override fun verify(tx: LedgerTransaction) { + return + } + } + + object DummyCommandData : TypeOnlyCommandData() + + open class KryoScheme : CustomSerializationScheme { + + companion object { + const val SCHEME_ID = 7 + } + + override fun getSchemeId(): Int { + return SCHEME_ID + } + + override fun deserialize(bytes: ByteSequence, clazz: Class, context: SerializationSchemeContext): T { + val kryo = Kryo() + customiseKryo(kryo, context.deserializationClassLoader) + + val obj = Input(bytes.open()).use { + kryo.readClassAndObject(it) + } + @Suppress("UNCHECKED_CAST") + return obj as T + } + + override fun serialize(obj: T, context: SerializationSchemeContext): ByteSequence { + val kryo = Kryo() + customiseKryo(kryo, context.deserializationClassLoader) + + val outputStream = ByteArrayOutputStream() + Output(outputStream).use { + kryo.writeClassAndObject(it, obj) + } + return ByteSequence.of(outputStream.toByteArray()) + } + + private fun customiseKryo(kryo: Kryo, classLoader: ClassLoader) { + kryo.instantiatorStrategy = CustomInstantiatorStrategy() + kryo.classLoader = classLoader + kryo.register(Arrays.asList("").javaClass, ArraysAsListSerializer()) + } + + //Stolen from DefaultKryoCustomizer.kt + private class CustomInstantiatorStrategy : InstantiatorStrategy { + private val fallbackStrategy = StdInstantiatorStrategy() + + // Use this to allow construction of objects using a JVM backdoor that skips invoking the constructors, if there + // is no no-arg constructor available. + private val defaultStrategy = Kryo.DefaultInstantiatorStrategy(fallbackStrategy) + + override fun newInstantiatorOf(type: Class): ObjectInstantiator { + // However this doesn't work for non-public classes in the java. namespace + val strat = if (type.name.startsWith("java.") && !Modifier.isPublic(type.modifiers)) fallbackStrategy else defaultStrategy + return strat.newInstantiatorOf(type) + } + } + } + + class KryoSchemeWithMap : KryoScheme() { + + companion object { + const val SCHEME_ID = 8 + const val KEY = "Key" + const val VALUE = "Value" + } + + override fun getSchemeId(): Int { + return SCHEME_ID + } + + override fun serialize(obj: T, context: SerializationSchemeContext): ByteSequence { + assertEquals(VALUE, context.properties[KEY]) + return super.serialize(obj, context) + } + + } +} diff --git a/node/src/main/kotlin/net/corda/node/internal/Node.kt b/node/src/main/kotlin/net/corda/node/internal/Node.kt index 504627dad2..4ee0da30ec 100644 --- a/node/src/main/kotlin/net/corda/node/internal/Node.kt +++ b/node/src/main/kotlin/net/corda/node/internal/Node.kt @@ -73,6 +73,7 @@ import net.corda.node.utilities.DemoClock import net.corda.node.utilities.errorAndTerminate import net.corda.nodeapi.internal.ArtemisMessagingClient import net.corda.common.logging.errorReporting.NodeDatabaseErrors +import net.corda.node.internal.classloading.scanForCustomSerializationScheme import net.corda.nodeapi.internal.ShutdownHook import net.corda.nodeapi.internal.addShutdownHook import net.corda.nodeapi.internal.bridging.BridgeControlListener @@ -647,10 +648,14 @@ open class Node(configuration: NodeConfiguration, private fun initialiseSerialization() { if (!initialiseSerialization) return val classloader = cordappLoader.appClassLoader + val customScheme = System.getProperty("experimental.corda.customSerializationScheme")?.let { + scanForCustomSerializationScheme(it, classloader) + } nodeSerializationEnv = SerializationEnvironment.with( SerializationFactoryImpl().apply { registerScheme(AMQPServerSerializationScheme(cordappLoader.cordapps, Caffeine.newBuilder().maximumSize(128).build().asMap())) registerScheme(AMQPClientSerializationScheme(cordappLoader.cordapps, Caffeine.newBuilder().maximumSize(128).build().asMap())) + customScheme?.let{ registerScheme(it) } }, p2pContext = AMQP_P2P_CONTEXT.withClassLoader(classloader), rpcServerContext = AMQP_RPC_SERVER_CONTEXT.withClassLoader(classloader), diff --git a/node/src/main/kotlin/net/corda/node/internal/classloading/Utils.kt b/node/src/main/kotlin/net/corda/node/internal/classloading/Utils.kt index d89dc3a455..07d1631097 100644 --- a/node/src/main/kotlin/net/corda/node/internal/classloading/Utils.kt +++ b/node/src/main/kotlin/net/corda/node/internal/classloading/Utils.kt @@ -2,6 +2,31 @@ package net.corda.node.internal.classloading +import net.corda.core.serialization.CustomSerializationScheme +import net.corda.node.internal.ConfigurationException +import net.corda.nodeapi.internal.serialization.CustomSerializationSchemeAdapter +import net.corda.serialization.internal.SerializationScheme +import java.lang.reflect.Constructor + inline fun Class<*>.requireAnnotation(): A { return requireNotNull(getDeclaredAnnotation(A::class.java)) { "$name needs to be annotated with ${A::class.java.name}" } +} + +fun scanForCustomSerializationScheme(className: String, classLoader: ClassLoader) : SerializationScheme { + val schemaClass = try { + classLoader.loadClass(className) + } catch (exception: ClassNotFoundException) { + throw ConfigurationException("$className was declared as a custom serialization scheme but could not be found.") + } + val constructor = validateScheme(schemaClass, className) + return CustomSerializationSchemeAdapter(constructor.newInstance() as CustomSerializationScheme) +} + +private fun validateScheme(clazz: Class<*>, className: String): Constructor<*> { + if (!clazz.interfaces.contains(CustomSerializationScheme::class.java)) { + throw ConfigurationException("$className was declared as a custom serialization scheme but does not implement" + + " ${CustomSerializationScheme::class.java.canonicalName}") + } + return clazz.constructors.singleOrNull { it.parameters.isEmpty() } ?: throw ConfigurationException("$className was declared as a " + + "custom serialization scheme but does not have a no argument constructor.") } \ No newline at end of file diff --git a/node/src/test/kotlin/net/corda/node/internal/CustomSerializationSchemeScanningTest.kt b/node/src/test/kotlin/net/corda/node/internal/CustomSerializationSchemeScanningTest.kt new file mode 100644 index 0000000000..79eb969b21 --- /dev/null +++ b/node/src/test/kotlin/net/corda/node/internal/CustomSerializationSchemeScanningTest.kt @@ -0,0 +1,78 @@ +package net.corda.node.internal + +import com.nhaarman.mockito_kotlin.whenever +import net.corda.core.serialization.CustomSerializationScheme +import net.corda.core.serialization.SerializationContext +import net.corda.core.serialization.SerializationSchemeContext +import net.corda.core.utilities.ByteSequence +import net.corda.node.internal.classloading.scanForCustomSerializationScheme +import org.junit.Test +import org.mockito.Mockito +import kotlin.test.assertFailsWith + +class CustomSerializationSchemeScanningTest { + + class NonSerializationScheme + + open class DummySerializationScheme : CustomSerializationScheme { + override fun getSchemeId(): Int { + return 7; + } + + override fun deserialize(bytes: ByteSequence, clazz: Class, context: SerializationSchemeContext): T { + throw DummySerializationSchemeException("We should never get here.") + } + + override fun serialize(obj: T, context: SerializationSchemeContext): ByteSequence { + throw DummySerializationSchemeException("Tried to serialize with DummySerializationScheme") + } + } + + class DummySerializationSchemeException(override val message: String) : RuntimeException(message) + + class DummySerializationSchemeWithoutNoArgConstructor(val myArgument: String) : DummySerializationScheme() + + @Test(timeout = 300_000) + fun `Can scan for custom serialization scheme and build a serialization scheme`() { + val classLoader = Mockito.mock(ClassLoader::class.java) + whenever(classLoader.loadClass(DummySerializationScheme::class.java.canonicalName)).thenAnswer { DummySerializationScheme::class.java } + val scheme = scanForCustomSerializationScheme(DummySerializationScheme::class.java.canonicalName, classLoader) + val mockContext = Mockito.mock(SerializationContext::class.java) + assertFailsWith("Tried to serialize with DummySerializationScheme") { + scheme.serialize(Any::class.java, mockContext) + } + } + + @Test(timeout = 300_000) + fun `verification fails with a helpful error if the class is not found in the classloader`() { + val classLoader = Mockito.mock(ClassLoader::class.java) + val missingClassName = DummySerializationScheme::class.java.canonicalName + whenever(classLoader.loadClass(missingClassName)).thenAnswer { throw ClassNotFoundException()} + assertFailsWith("$missingClassName was declared as a custom serialization scheme but could not " + + "be found.") { + scanForCustomSerializationScheme(missingClassName, classLoader) + } + } + + @Test(timeout = 300_000) + fun `verification fails with a helpful error if the class is not a custom serialization scheme`() { + val canonicalName = NonSerializationScheme::class.java.canonicalName + val classLoader = Mockito.mock(ClassLoader::class.java) + whenever(classLoader.loadClass(canonicalName)).thenAnswer { NonSerializationScheme::class.java } + assertFailsWith("$canonicalName was declared as a custom serialization scheme but does not " + + "implement CustomSerializationScheme.") { + scanForCustomSerializationScheme(canonicalName, classLoader) + } + } + + @Test(timeout = 300_000) + fun `verification fails with a helpful error if the class does not have a no arg constructor`() { + val classLoader = Mockito.mock(ClassLoader::class.java) + val canonicalName = DummySerializationSchemeWithoutNoArgConstructor::class.java.canonicalName + whenever(classLoader.loadClass(canonicalName)).thenAnswer { DummySerializationSchemeWithoutNoArgConstructor::class.java } + assertFailsWith("$canonicalName was declared as a custom serialization scheme but does not " + + "have a no argument constructor.") { + scanForCustomSerializationScheme(canonicalName, classLoader) + } + } +} \ No newline at end of file diff --git a/serialization/src/main/kotlin/net/corda/serialization/internal/SerializationScheme.kt b/serialization/src/main/kotlin/net/corda/serialization/internal/SerializationScheme.kt index 2447ed9642..dbadd68339 100644 --- a/serialization/src/main/kotlin/net/corda/serialization/internal/SerializationScheme.kt +++ b/serialization/src/main/kotlin/net/corda/serialization/internal/SerializationScheme.kt @@ -6,6 +6,7 @@ import net.corda.core.crypto.SecureHash import net.corda.core.internal.VisibleForTesting import net.corda.core.internal.copyBytes import net.corda.core.serialization.* +import net.corda.core.serialization.internal.CustomSerializationSchemeUtils.Companion.getSchemeIdIfCustomSerializationMagic import net.corda.core.utilities.ByteSequence import net.corda.serialization.internal.amqp.amqpMagic import org.slf4j.LoggerFactory @@ -47,6 +48,10 @@ data class SerializationContextImpl @JvmOverloads constructor(override val prefe return copy(properties = properties + (property to value)) } + override fun withProperties(extraProperties: Map): SerializationContext { + return copy(properties = properties + extraProperties) + } + override fun withoutReferences(): SerializationContext { return copy(objectReferencesEnabled = false) } @@ -106,7 +111,9 @@ open class SerializationFactoryImpl( registeredSchemes.filter { it.canDeserializeVersion(magic, target) }.forEach { return@computeIfAbsent it } // XXX: Not single? logger.warn("Cannot find serialization scheme for: [$lookupKey, " + "${if (magic == amqpMagic) "AMQP" else "UNKNOWN MAGIC"}] registeredSchemes are: $registeredSchemes") - throw UnsupportedOperationException("Serialization scheme $lookupKey not supported.") + val schemeId = getSchemeIdIfCustomSerializationMagic(magic) ?: throw UnsupportedOperationException("Serialization scheme" + + " $lookupKey not supported.") + throw UnsupportedOperationException("Could not find custom serialization scheme with SchemeId = $schemeId.") }) to magic } diff --git a/testing/core-test-utils/src/main/kotlin/net/corda/coretesting/internal/InternalSerializationTestHelpers.kt b/testing/core-test-utils/src/main/kotlin/net/corda/coretesting/internal/InternalSerializationTestHelpers.kt index 116016b991..6345dd7549 100644 --- a/testing/core-test-utils/src/main/kotlin/net/corda/coretesting/internal/InternalSerializationTestHelpers.kt +++ b/testing/core-test-utils/src/main/kotlin/net/corda/coretesting/internal/InternalSerializationTestHelpers.kt @@ -3,9 +3,11 @@ package net.corda.coretesting.internal import net.corda.nodeapi.internal.rpc.client.AMQPClientSerializationScheme import net.corda.core.internal.createInstancesOfClassesImplementing import net.corda.core.serialization.CheckpointCustomSerializer +import net.corda.core.serialization.CustomSerializationScheme import net.corda.core.serialization.SerializationCustomSerializer import net.corda.core.serialization.SerializationWhitelist import net.corda.core.serialization.internal.SerializationEnvironment +import net.corda.nodeapi.internal.serialization.CustomSerializationSchemeAdapter import net.corda.nodeapi.internal.serialization.amqp.AMQPServerSerializationScheme import net.corda.nodeapi.internal.serialization.kryo.KRYO_CHECKPOINT_CONTEXT import net.corda.nodeapi.internal.serialization.kryo.KryoCheckpointSerializer @@ -14,6 +16,7 @@ import net.corda.serialization.internal.AMQP_RPC_CLIENT_CONTEXT import net.corda.serialization.internal.AMQP_RPC_SERVER_CONTEXT import net.corda.serialization.internal.AMQP_STORAGE_CONTEXT import net.corda.serialization.internal.SerializationFactoryImpl +import net.corda.serialization.internal.SerializationScheme import net.corda.testing.common.internal.asContextEnv import java.util.ServiceLoader import java.util.concurrent.ConcurrentHashMap @@ -27,20 +30,29 @@ fun createTestSerializationEnv(): SerializationEnvironment { fun createTestSerializationEnv(classLoader: ClassLoader?): SerializationEnvironment { var customCheckpointSerializers: Set> = emptySet() - val (clientSerializationScheme, serverSerializationScheme) = if (classLoader != null) { + val serializationSchemes: MutableList = mutableListOf() + if (classLoader != null) { val customSerializers = createInstancesOfClassesImplementing(classLoader, SerializationCustomSerializer::class.java) customCheckpointSerializers = createInstancesOfClassesImplementing(classLoader, CheckpointCustomSerializer::class.java) val serializationWhitelists = ServiceLoader.load(SerializationWhitelist::class.java, classLoader).toSet() - Pair(AMQPClientSerializationScheme(customSerializers, serializationWhitelists), - AMQPServerSerializationScheme(customSerializers, serializationWhitelists)) + serializationSchemes.add(AMQPClientSerializationScheme(customSerializers, serializationWhitelists)) + serializationSchemes.add(AMQPServerSerializationScheme(customSerializers, serializationWhitelists)) + + val customSchemes = createInstancesOfClassesImplementing(classLoader, CustomSerializationScheme::class.java) + for (customScheme in customSchemes) { + serializationSchemes.add(CustomSerializationSchemeAdapter(customScheme)) + } } else { - Pair(AMQPClientSerializationScheme(emptyList()), AMQPServerSerializationScheme(emptyList())) + serializationSchemes.add(AMQPClientSerializationScheme(emptyList())) + serializationSchemes.add(AMQPServerSerializationScheme(emptyList())) } + val factory = SerializationFactoryImpl().apply { - registerScheme(clientSerializationScheme) - registerScheme(serverSerializationScheme) + for (serializationScheme in serializationSchemes) { + registerScheme(serializationScheme) + } } return SerializationEnvironment.with( factory,