diff --git a/core/src/main/kotlin/net/corda/core/serialization/SerializationAPI.kt b/core/src/main/kotlin/net/corda/core/serialization/SerializationAPI.kt index 101329e05a..5ead84aef6 100644 --- a/core/src/main/kotlin/net/corda/core/serialization/SerializationAPI.kt +++ b/core/src/main/kotlin/net/corda/core/serialization/SerializationAPI.kt @@ -1,8 +1,6 @@ @file:KeepForDJVM package net.corda.core.serialization -import co.paralleluniverse.io.serialization.Serialization -import net.corda.core.CordaInternal import net.corda.core.DeleteForDJVM import net.corda.core.DoNotImplement import net.corda.core.KeepForDJVM @@ -12,6 +10,7 @@ import net.corda.core.serialization.internal.effectiveSerializationEnv import net.corda.core.utilities.ByteSequence import net.corda.core.utilities.OpaqueBytes import net.corda.core.utilities.sequence +import java.io.NotSerializableException import java.sql.Blob data class ObjectWithCompatibleContext(val obj: T, val context: SerializationContext) @@ -152,7 +151,15 @@ interface SerializationContext { */ val lenientCarpenterEnabled: Boolean /** - * If true the serialization evolver will fail if the binary to be deserialized contains more fields then the current object from the classpath. + * If true, deserialization calls using this context will not fallback to using the Class Carpenter to attempt + * to construct classes present in the schema but not on the current classpath. + * + * The default is false. + */ + val carpenterDisabled: Boolean + /** + * If true the serialization evolver will fail if the binary to be deserialized contains more fields then the current object from + * the classpath. * * The default is false. */ @@ -182,6 +189,12 @@ interface SerializationContext { */ fun withLenientCarpenter(): SerializationContext + /** + * Returns a copy of the current context with carpentry of unknown classes disabled. On encountering + * such a class during deserialization the Serialization framework will throw a [NotSerializableException]. + */ + fun withoutCarpenter() : SerializationContext + /** * Return a new context based on this one but with a strict evolution. * @see preventDataLoss diff --git a/core/src/main/kotlin/net/corda/core/serialization/internal/AttachmentsClassLoader.kt b/core/src/main/kotlin/net/corda/core/serialization/internal/AttachmentsClassLoader.kt index ea5f0d5f08..93a61c0f2f 100644 --- a/core/src/main/kotlin/net/corda/core/serialization/internal/AttachmentsClassLoader.kt +++ b/core/src/main/kotlin/net/corda/core/serialization/internal/AttachmentsClassLoader.kt @@ -327,6 +327,7 @@ object AttachmentsClassLoaderBuilder { .withClassLoader(transactionClassLoader) .withWhitelist(whitelistedClasses) .withCustomSerializers(serializers) + .withoutCarpenter() } // Deserialize all relevant classes in the transaction classloader. diff --git a/serialization/src/main/kotlin/net/corda/serialization/internal/SerializationScheme.kt b/serialization/src/main/kotlin/net/corda/serialization/internal/SerializationScheme.kt index 9c6162a874..2c03e2fa56 100644 --- a/serialization/src/main/kotlin/net/corda/serialization/internal/SerializationScheme.kt +++ b/serialization/src/main/kotlin/net/corda/serialization/internal/SerializationScheme.kt @@ -26,6 +26,7 @@ data class SerializationContextImpl @JvmOverloads constructor(override val prefe override val encoding: SerializationEncoding?, override val encodingWhitelist: EncodingWhitelist = NullEncodingWhitelist, override val lenientCarpenterEnabled: Boolean = false, + override val carpenterDisabled: Boolean = false, override val preventDataLoss: Boolean = false, override val customSerializers: Set> = emptySet()) : SerializationContext { /** @@ -45,6 +46,8 @@ data class SerializationContextImpl @JvmOverloads constructor(override val prefe override fun withLenientCarpenter(): SerializationContext = copy(lenientCarpenterEnabled = true) + override fun withoutCarpenter(): SerializationContext = copy(carpenterDisabled = true) + override fun withPreventDataLoss(): SerializationContext = copy(preventDataLoss = true) override fun withClassLoader(classLoader: ClassLoader): SerializationContext { diff --git a/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/DeserializationInput.kt b/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/DeserializationInput.kt index f5ed5dd924..503d370fd6 100644 --- a/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/DeserializationInput.kt +++ b/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/DeserializationInput.kt @@ -169,7 +169,7 @@ class DeserializationInput constructor( val objectRead = when (obj) { is DescribedType -> { // Look up serializer in factory by descriptor - val serializer = serializerFactory.get(obj.descriptor.toString(), schemas) + val serializer = serializerFactory.get(obj.descriptor.toString(), schemas, context) if (type != TypeIdentifier.UnknownType.getLocalType() && serializer.type != type && with(serializer.type) { !isSubClassOf(type) && !materiallyEquivalentTo(type) } diff --git a/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/ObjectSerializer.kt b/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/ObjectSerializer.kt index 03722bbd7f..af241c4efc 100644 --- a/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/ObjectSerializer.kt +++ b/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/ObjectSerializer.kt @@ -17,7 +17,7 @@ interface ObjectSerializer : AMQPSerializer { if (typeInformation is LocalTypeInformation.NonComposable) throw NotSerializableException( "Trying to build an object serializer for ${typeInformation.typeIdentifier.prettyPrint(false)}, " + - "but it is not constructible from its public properties, and so requires a custom serialiser.") + "but it is not constructable from its public properties, and so requires a custom serialiser.") val typeDescriptor = factory.createDescriptor(typeInformation) val typeNotation = TypeNotationGenerator.getTypeNotation(typeInformation, typeDescriptor) diff --git a/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/RemoteSerializerFactory.kt b/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/RemoteSerializerFactory.kt index c92947651b..e7b00f618a 100644 --- a/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/RemoteSerializerFactory.kt +++ b/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/RemoteSerializerFactory.kt @@ -1,5 +1,6 @@ package net.corda.serialization.internal.amqp +import net.corda.core.serialization.SerializationContext import net.corda.core.utilities.contextLogger import net.corda.serialization.internal.model.* import org.hibernate.type.descriptor.java.ByteTypeDescriptor @@ -16,8 +17,8 @@ interface RemoteSerializerFactory { * @param typeDescriptor The type descriptor for the type to obtain a serializer for. * @param schema The schemas sent along with the serialized data. */ - @Throws(NotSerializableException::class) - fun get(typeDescriptor: TypeDescriptor, schema: SerializationSchemas): AMQPSerializer + @Throws(NotSerializableException::class, ClassNotFoundException::class) + fun get(typeDescriptor: TypeDescriptor, schema: SerializationSchemas, context: SerializationContext): AMQPSerializer } /** @@ -57,14 +58,18 @@ class DefaultRemoteSerializerFactory( private val logger = contextLogger() } - override fun get(typeDescriptor: TypeDescriptor, schema: SerializationSchemas): AMQPSerializer = + override fun get( + typeDescriptor: TypeDescriptor, + schema: SerializationSchemas, + context: SerializationContext + ): AMQPSerializer = // If we have seen this descriptor before, we assume we have seen everything in this schema before. descriptorBasedSerializerRegistry.getOrBuild(typeDescriptor) { logger.trace("get Serializer descriptor=$typeDescriptor") // Interpret all of the types in the schema into RemoteTypeInformation, and reflect that into LocalTypeInformation. val remoteTypeInformationMap = remoteTypeModel.interpret(schema) - val reflected = reflect(remoteTypeInformationMap) + val reflected = reflect(remoteTypeInformationMap, context) // Get, and record in the registry, serializers for all of the types contained in the schema. // This will save us having to re-interpret the entire schema on re-entry when deserialising individual property values. @@ -79,7 +84,10 @@ class DefaultRemoteSerializerFactory( "Could not find type matching descriptor $typeDescriptor.") } - private fun getUncached(remoteTypeInformation: RemoteTypeInformation, localTypeInformation: LocalTypeInformation): AMQPSerializer { + private fun getUncached( + remoteTypeInformation: RemoteTypeInformation, + localTypeInformation: LocalTypeInformation + ): AMQPSerializer { val remoteDescriptor = remoteTypeInformation.typeDescriptor // Obtain a serializer and descriptor for the local type. @@ -117,9 +125,9 @@ ${localTypeInformation.prettyPrint(false)} } } - private fun reflect(remoteInformation: Map): + private fun reflect(remoteInformation: Map, context: SerializationContext): Map { - val localInformationByIdentifier = typeLoader.load(remoteInformation.values).mapValues { (_, type) -> + val localInformationByIdentifier = typeLoader.load(remoteInformation.values, context).mapValues { (_, type) -> localTypeModel.inspect(type) } diff --git a/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/SerializerFactoryBuilder.kt b/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/SerializerFactoryBuilder.kt index 665eda5c56..7c5ec79c49 100644 --- a/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/SerializerFactoryBuilder.kt +++ b/serialization/src/main/kotlin/net/corda/serialization/internal/amqp/SerializerFactoryBuilder.kt @@ -17,10 +17,10 @@ object SerializerFactoryBuilder { whitelist, classCarpenter, DefaultDescriptorBasedSerializerRegistry(), - true, - null, - false, - false) + allowEvolution = true, + overrideFingerPrinter = null, + onlyCustomSerializers = false, + mustPreserveDataWhenEvolving = false) } @JvmStatic diff --git a/serialization/src/main/kotlin/net/corda/serialization/internal/model/TypeLoader.kt b/serialization/src/main/kotlin/net/corda/serialization/internal/model/TypeLoader.kt index ca541f6102..500a3555c4 100644 --- a/serialization/src/main/kotlin/net/corda/serialization/internal/model/TypeLoader.kt +++ b/serialization/src/main/kotlin/net/corda/serialization/internal/model/TypeLoader.kt @@ -1,5 +1,6 @@ package net.corda.serialization.internal.model +import net.corda.core.serialization.SerializationContext import net.corda.serialization.internal.carpenter.* import java.io.NotSerializableException import java.lang.ClassCastException @@ -14,7 +15,7 @@ interface TypeLoader { * * @param remoteTypeInformation The type information for the remote types. */ - fun load(remoteTypeInformation: Collection): Map + fun load(remoteTypeInformation: Collection, context: SerializationContext): Map } /** @@ -25,7 +26,10 @@ class ClassCarpentingTypeLoader(private val carpenter: RemoteTypeCarpenter, priv val cache = DefaultCacheProvider.createCache() - override fun load(remoteTypeInformation: Collection): Map { + override fun load( + remoteTypeInformation: Collection, + context: SerializationContext + ): Map { val remoteInformationByIdentifier = remoteTypeInformation.associateBy { it.typeIdentifier } // Grab all the types we can from the cache, or the classloader. @@ -33,6 +37,9 @@ class ClassCarpentingTypeLoader(private val carpenter: RemoteTypeCarpenter, priv try { identifier to cache.computeIfAbsent(identifier) { identifier.getLocalType(classLoader) } } catch (e: ClassNotFoundException) { + if (context.carpenterDisabled) { + throw e + } null } }.toMap() diff --git a/serialization/src/test/java/net/corda/serialization/internal/carpenter/JavaCalculatedValuesToClassCarpenterTest.java b/serialization/src/test/java/net/corda/serialization/internal/carpenter/JavaCalculatedValuesToClassCarpenterTest.java index d0e39bb214..105b017d08 100644 --- a/serialization/src/test/java/net/corda/serialization/internal/carpenter/JavaCalculatedValuesToClassCarpenterTest.java +++ b/serialization/src/test/java/net/corda/serialization/internal/carpenter/JavaCalculatedValuesToClassCarpenterTest.java @@ -7,6 +7,7 @@ import net.corda.core.serialization.SerializedBytes; import net.corda.serialization.internal.AllWhitelist; import net.corda.serialization.internal.amqp.*; import net.corda.serialization.internal.amqp.Schema; +import net.corda.serialization.internal.amqp.testutils.TestSerializationContext; import net.corda.serialization.internal.model.RemoteTypeInformation; import net.corda.serialization.internal.model.TypeIdentifier; import net.corda.testing.core.SerializationEnvironmentRule; @@ -77,7 +78,7 @@ public class JavaCalculatedValuesToClassCarpenterTest extends AmqpCarpenterBase RemoteTypeInformation renamed = rename(typeInformation, typeToMangle, mangle(typeToMangle)); - Class pinochio = load(renamed); + Class pinochio = load(renamed, TestSerializationContext.testSerializationContext); Object p = pinochio.getConstructors()[0].newInstance(4, 2, "4"); assertEquals(2, pinochio.getMethod("getI").invoke(p)); diff --git a/serialization/src/test/kotlin/net/corda/serialization/internal/amqp/DeserializeNeedingCarpentryTests.kt b/serialization/src/test/kotlin/net/corda/serialization/internal/amqp/DeserializeNeedingCarpentryTests.kt index 1d7fa23dd6..0d5fbec240 100644 --- a/serialization/src/test/kotlin/net/corda/serialization/internal/amqp/DeserializeNeedingCarpentryTests.kt +++ b/serialization/src/test/kotlin/net/corda/serialization/internal/amqp/DeserializeNeedingCarpentryTests.kt @@ -1,11 +1,15 @@ package net.corda.serialization.internal.amqp import net.corda.core.serialization.CordaSerializable +import net.corda.core.serialization.SerializationContext +import net.corda.core.serialization.SerializedBytes import net.corda.serialization.internal.AllWhitelist import net.corda.serialization.internal.amqp.testutils.* import net.corda.serialization.internal.carpenter.* import org.junit.Test +import java.io.NotSerializableException import kotlin.test.assertEquals +import kotlin.test.assertFailsWith import kotlin.test.assertNotEquals import kotlin.test.assertTrue @@ -35,6 +39,17 @@ class DeserializeNeedingCarpentryTests : AmqpCarpenterBase(AllWhitelist) { // Deserialize with whitelisting on to check that `CordaSerializable` annotation present. private val sf2 = testDefaultFactoryWithWhitelist() + private inline fun DeserializationInput.deserializeWithoutAndWithCarpenter( + bytes: SerializedBytes, + context: SerializationContext? = null + ) : T { + assertFailsWith(NotSerializableException::class) { + deserialize(bytes, T::class.java, (context ?: testSerializationContext).withoutCarpenter()) + } + + return deserialize(bytes, T::class.java, context ?: testSerializationContext) + } + @Test fun verySimpleType() { val testVal = 10 @@ -52,19 +67,20 @@ class DeserializeNeedingCarpentryTests : AmqpCarpenterBase(AllWhitelist) { assertEquals(deserializedObj1::class.java, deserializedObj2::class.java) assertEquals(testVal, deserializedObj2::class.java.getMethod("getA").invoke(deserializedObj2)) - val deserializedObj3 = DeserializationInput(sf2).deserialize(serialisedBytes) + val deserializedObj3 = DeserializationInput(sf2).deserializeWithoutAndWithCarpenter(serialisedBytes) assertNotEquals(clazz, deserializedObj3::class.java) assertNotEquals(deserializedObj1::class.java, deserializedObj3::class.java) assertNotEquals(deserializedObj2::class.java, deserializedObj3::class.java) assertEquals(testVal, deserializedObj3::class.java.getMethod("getA").invoke(deserializedObj3)) + // NOTE: There is no point attempting this without the carepenter a second time as having carpented things up once + // it will, of course, just succeed even with the carpenter disabled val deserializedObj4 = DeserializationInput(sf2).deserialize(serialisedBytes) assertNotEquals(clazz, deserializedObj4::class.java) assertNotEquals(deserializedObj1::class.java, deserializedObj4::class.java) assertNotEquals(deserializedObj2::class.java, deserializedObj4::class.java) assertEquals(deserializedObj3::class.java, deserializedObj4::class.java) assertEquals(testVal, deserializedObj4::class.java.getMethod("getA").invoke(deserializedObj4)) - } @Test @@ -79,7 +95,7 @@ class DeserializeNeedingCarpentryTests : AmqpCarpenterBase(AllWhitelist) { val concreteB = clazz.constructors[0].newInstance(testValB) val concreteC = clazz.constructors[0].newInstance(testValC) - val deserialisedA = DeserializationInput(sf2).deserialize( + val deserialisedA = DeserializationInput(sf2).deserializeWithoutAndWithCarpenter( TestSerializationOutput(VERBOSE, sf1).serialize(concreteA)) assertEquals(testValA, deserialisedA::class.java.getMethod("getA").invoke(deserialisedA)) @@ -114,7 +130,7 @@ class DeserializeNeedingCarpentryTests : AmqpCarpenterBase(AllWhitelist) { val classInstance = clazz.constructors[0].newInstance(testVal) val serialisedBytes = TestSerializationOutput(VERBOSE, sf1).serialize(classInstance) - val deserializedObj = DeserializationInput(sf2).deserialize(serialisedBytes) + val deserializedObj = DeserializationInput(sf2).deserializeWithoutAndWithCarpenter(serialisedBytes) assertTrue(deserializedObj is I) assertEquals(testVal, (deserializedObj as I).getName()) @@ -133,7 +149,8 @@ class DeserializeNeedingCarpentryTests : AmqpCarpenterBase(AllWhitelist) { clazz.constructors[0].newInstance(2), clazz.constructors[0].newInstance(3))) - val deserializedObj = DeserializationInput(sf2).deserialize(TestSerializationOutput(VERBOSE, sf1).serialize(outer)) + val deserializedObj = DeserializationInput(sf2).deserializeWithoutAndWithCarpenter( + TestSerializationOutput(VERBOSE, sf1).serialize(outer)) assertNotEquals((deserializedObj.a[0])::class.java, (outer.a[0])::class.java) assertNotEquals((deserializedObj.a[1])::class.java, (outer.a[1])::class.java) @@ -164,9 +181,9 @@ class DeserializeNeedingCarpentryTests : AmqpCarpenterBase(AllWhitelist) { val outer = outerType.constructors[0].newInstance(innerType.constructors[0].newInstance(2)) val serializedI = TestSerializationOutput(VERBOSE, sf1).serialize(inner) - val deserialisedI = DeserializationInput(sf2).deserialize(serializedI) + val deserialisedI = DeserializationInput(sf2).deserializeWithoutAndWithCarpenter(serializedI) val serialisedO = TestSerializationOutput(VERBOSE, sf1).serialize(outer) - val deserialisedO = DeserializationInput(sf2).deserialize(serialisedO) + val deserialisedO = DeserializationInput(sf2).deserializeWithoutAndWithCarpenter(serialisedO) // ensure out carpented version of inner is reused assertEquals(deserialisedI::class.java, @@ -184,7 +201,7 @@ class DeserializeNeedingCarpentryTests : AmqpCarpenterBase(AllWhitelist) { val classInstance = outerClass.constructors.first().newInstance(nestedClass.constructors.first().newInstance("name")) val serialisedBytes = TestSerializationOutput(VERBOSE, sf1).serialize(classInstance) - val deserializedObj = DeserializationInput(sf2).deserialize(serialisedBytes) + val deserializedObj = DeserializationInput(sf2).deserializeWithoutAndWithCarpenter(serialisedBytes) val inner = deserializedObj::class.java.getMethod("getInner").invoke(deserializedObj) assertEquals("name", inner::class.java.getMethod("getName").invoke(inner)) @@ -204,7 +221,7 @@ class DeserializeNeedingCarpentryTests : AmqpCarpenterBase(AllWhitelist) { nestedClass.constructors.first().newInstance("bar")) val serialisedBytes = TestSerializationOutput(VERBOSE, sf1).serialize(classInstance) - val deserializedObj = DeserializationInput(sf2).deserialize(serialisedBytes) + val deserializedObj = DeserializationInput(sf2).deserializeWithoutAndWithCarpenter(serialisedBytes) assertEquals("foo", deserializedObj.a::class.java.getMethod("getName").invoke(deserializedObj.a)) assertEquals("bar", deserializedObj.b::class.java.getMethod("getName").invoke(deserializedObj.b)) @@ -226,7 +243,8 @@ class DeserializeNeedingCarpentryTests : AmqpCarpenterBase(AllWhitelist) { unknownClass.constructors.first().newInstance(7, 8))) val serialisedBytes = TestSerializationOutput(VERBOSE, sf1).serialize(toSerialise) - val deserializedObj = DeserializationInput(sf2).deserialize(serialisedBytes) + + val deserializedObj = DeserializationInput(sf2).deserializeWithoutAndWithCarpenter(serialisedBytes) var sentinel = 1 deserializedObj.l.forEach { assertEquals(sentinel++, it::class.java.getMethod("getV1").invoke(it)) @@ -249,8 +267,8 @@ class DeserializeNeedingCarpentryTests : AmqpCarpenterBase(AllWhitelist) { val serialisedBytes = TestSerializationOutput(VERBOSE, sf1).serialize( concreteClass.constructors.first().newInstance(12, "timmy")) - val deserializedObj = DeserializationInput(sf2).deserialize(serialisedBytes) + val deserializedObj = DeserializationInput(sf2).deserializeWithoutAndWithCarpenter(serialisedBytes) assertTrue(deserializedObj is I) assertEquals("timmy", (deserializedObj as I).getName()) assertEquals("timmy", deserializedObj::class.java.getMethod("getName").invoke(deserializedObj)) diff --git a/serialization/src/test/kotlin/net/corda/serialization/internal/amqp/DeserializeSimpleTypesTests.kt b/serialization/src/test/kotlin/net/corda/serialization/internal/amqp/DeserializeSimpleTypesTests.kt index 0baf197d83..df9bc8316b 100644 --- a/serialization/src/test/kotlin/net/corda/serialization/internal/amqp/DeserializeSimpleTypesTests.kt +++ b/serialization/src/test/kotlin/net/corda/serialization/internal/amqp/DeserializeSimpleTypesTests.kt @@ -550,7 +550,7 @@ class DeserializeSimpleTypesTests { @Test fun classHasNoPublicConstructor() { assertFailsWithMessage("Trying to build an object serializer for ${Garbo::class.java.name}, " + - "but it is not constructible from its public properties, and so requires a custom serialiser.") { + "but it is not constructable from its public properties, and so requires a custom serialiser.") { TestSerializationOutput(VERBOSE, sf1).serializeAndReturnSchema(Garbo.make(1)) } } @@ -558,7 +558,7 @@ class DeserializeSimpleTypesTests { @Test fun propertyClassHasNoPublicConstructor() { assertFailsWithMessage("Trying to build an object serializer for ${Greta::class.java.name}, " + - "but it is not constructible from its public properties, and so requires a custom serialiser.") { + "but it is not constructable from its public properties, and so requires a custom serialiser.") { TestSerializationOutput(VERBOSE, sf1).serializeAndReturnSchema(Greta(Garbo.make(1))) } } diff --git a/serialization/src/test/kotlin/net/corda/serialization/internal/carpenter/ClassCarpenterTestUtils.kt b/serialization/src/test/kotlin/net/corda/serialization/internal/carpenter/ClassCarpenterTestUtils.kt index 0ec2fed1ab..189d176d3c 100644 --- a/serialization/src/test/kotlin/net/corda/serialization/internal/carpenter/ClassCarpenterTestUtils.kt +++ b/serialization/src/test/kotlin/net/corda/serialization/internal/carpenter/ClassCarpenterTestUtils.kt @@ -2,6 +2,7 @@ package net.corda.serialization.internal.carpenter import com.google.common.reflect.TypeToken import net.corda.core.serialization.ClassWhitelist +import net.corda.core.serialization.SerializationContext import net.corda.core.serialization.SerializedBytes import net.corda.serialization.internal.amqp.* import net.corda.serialization.internal.amqp.testutils.deserializeAndReturnEnvelope @@ -83,11 +84,11 @@ open class AmqpCarpenterBase(whitelist: ClassWhitelist) { else -> this } - protected fun RemoteTypeInformation.load(): Class<*> = - typeLoader.load(listOf(this))[typeIdentifier]!!.asClass() + protected fun RemoteTypeInformation.load(context : SerializationContext): Class<*> = + typeLoader.load(listOf(this), context)[typeIdentifier]!!.asClass() - protected fun assertCanLoadAll(vararg types: RemoteTypeInformation) { - assertTrue(typeLoader.load(types.asList()).keys.containsAll(types.map { it.typeIdentifier })) + protected fun assertCanLoadAll(context: SerializationContext, vararg types: RemoteTypeInformation) { + assertTrue(typeLoader.load(types.asList(), context).keys.containsAll(types.map { it.typeIdentifier })) } protected fun Class<*>.new(vararg constructorParams: Any?) = diff --git a/serialization/src/test/kotlin/net/corda/serialization/internal/carpenter/CompositeMemberCompositeSchemaToClassCarpenterTests.kt b/serialization/src/test/kotlin/net/corda/serialization/internal/carpenter/CompositeMemberCompositeSchemaToClassCarpenterTests.kt index b1700ae8c2..4b5bbe94b2 100644 --- a/serialization/src/test/kotlin/net/corda/serialization/internal/carpenter/CompositeMemberCompositeSchemaToClassCarpenterTests.kt +++ b/serialization/src/test/kotlin/net/corda/serialization/internal/carpenter/CompositeMemberCompositeSchemaToClassCarpenterTests.kt @@ -2,6 +2,7 @@ package net.corda.serialization.internal.carpenter import net.corda.core.serialization.CordaSerializable import net.corda.serialization.internal.AllWhitelist +import net.corda.serialization.internal.amqp.testutils.testSerializationContext import org.junit.Test import java.io.NotSerializableException import java.util.* @@ -26,7 +27,7 @@ class CompositeMembers : AmqpCarpenterBase(AllWhitelist) { val (_, envelope) = B(A(10), 20).roundTrip() // We load an unknown class, B_mangled, which includes a reference to a known class, A. - assertCanLoadAll(envelope.getMangled()) + assertCanLoadAll(testSerializationContext, envelope.getMangled()) } @Test @@ -41,7 +42,8 @@ class CompositeMembers : AmqpCarpenterBase(AllWhitelist) { // We load an unknown class, B_mangled, which includes a reference to an unknown class, A_mangled. // For this to work, we must include A_mangled in our set of classes to load. - assertCanLoadAll(envelope.getMangled().mangle(), envelope.getMangled()) + assertCanLoadAll(testSerializationContext, + envelope.getMangled().mangle(), envelope.getMangled()) } @Test @@ -56,7 +58,8 @@ class CompositeMembers : AmqpCarpenterBase(AllWhitelist) { // We load an unknown class, B_mangled, which includes a reference to an unknown class, A_mangled. // This will fail, because A_mangled is not included in our set of classes to load. - assertFailsWith { assertCanLoadAll(envelope.getMangled().mangle()) } + assertFailsWith { assertCanLoadAll(testSerializationContext, + envelope.getMangled().mangle()) } } // See https://github.com/corda/corda/issues/4107 @@ -71,7 +74,7 @@ class CompositeMembers : AmqpCarpenterBase(AllWhitelist) { val uuid = UUID.randomUUID() val(_, envelope) = IOUStateData(10, uuid, "new value").roundTrip() - val recarpented = envelope.getMangled().load() + val recarpented = envelope.getMangled().load(testSerializationContext) val instance = recarpented.new(null, uuid, 10) assertEquals(uuid, instance.get("ref")) } @@ -90,7 +93,7 @@ class CompositeMembers : AmqpCarpenterBase(AllWhitelist) { "java.util.Map", mangledMap.prettyPrint(false)) - assertCanLoadAll(infoForD, mangledMap, mangledC) + assertCanLoadAll(testSerializationContext, infoForD, mangledMap, mangledC) } @Test @@ -104,6 +107,6 @@ class CompositeMembers : AmqpCarpenterBase(AllWhitelist) { val mangledNotAMap = envelope.typeInformationFor>().mangle() val mangledC = envelope.getMangled() - assertCanLoadAll(infoForD, mangledNotAMap, mangledC) + assertCanLoadAll(testSerializationContext, infoForD, mangledNotAMap, mangledC) } } diff --git a/serialization/src/test/kotlin/net/corda/serialization/internal/carpenter/InheritanceSchemaToClassCarpenterTests.kt b/serialization/src/test/kotlin/net/corda/serialization/internal/carpenter/InheritanceSchemaToClassCarpenterTests.kt index 336471c9ac..d066315340 100644 --- a/serialization/src/test/kotlin/net/corda/serialization/internal/carpenter/InheritanceSchemaToClassCarpenterTests.kt +++ b/serialization/src/test/kotlin/net/corda/serialization/internal/carpenter/InheritanceSchemaToClassCarpenterTests.kt @@ -2,6 +2,7 @@ package net.corda.serialization.internal.carpenter import net.corda.core.serialization.CordaSerializable import net.corda.serialization.internal.AllWhitelist +import net.corda.serialization.internal.amqp.testutils.testSerializationContext import org.junit.Test import kotlin.test.* import java.io.NotSerializableException @@ -41,7 +42,7 @@ class InheritanceSchemaToClassCarpenterTests : AmqpCarpenterBase(AllWhitelist) { val (_, env) = A(20).roundTrip() val mangledA = env.getMangled() - val carpentedA = mangledA.load() + val carpentedA = mangledA.load(testSerializationContext) val carpentedInstance = carpentedA.new(20) assertEquals(20, carpentedInstance.get("j")) @@ -52,10 +53,11 @@ class InheritanceSchemaToClassCarpenterTests : AmqpCarpenterBase(AllWhitelist) { @Test fun interfaceParent2() { + @Suppress("UNUSED") class A(override val j: Int, val jj: Int) : J val (_, env) = A(23, 42).roundTrip() - val carpentedA = env.getMangled().load() + val carpentedA = env.getMangled().load(testSerializationContext) val carpetedInstance = carpentedA.constructors[0].newInstance(23, 42) assertEquals(23, carpetedInstance.get("j")) @@ -70,7 +72,7 @@ class InheritanceSchemaToClassCarpenterTests : AmqpCarpenterBase(AllWhitelist) { class A(override val i: Int, override val ii: Int) : I, II val (_, env) = A(23, 42).roundTrip() - val carpentedA = env.getMangled().load() + val carpentedA = env.getMangled().load(testSerializationContext) val carpetedInstance = carpentedA.constructors[0].newInstance(23, 42) assertEquals(23, carpetedInstance.get("i")) @@ -88,7 +90,7 @@ class InheritanceSchemaToClassCarpenterTests : AmqpCarpenterBase(AllWhitelist) { class A(override val i: Int, override val iii: Int) : III val (_, env) = A(23, 42).roundTrip() - val carpentedA = env.getMangled().load() + val carpentedA = env.getMangled().load(testSerializationContext) val carpetedInstance = carpentedA.constructors[0].newInstance(23, 42) assertEquals(23, carpetedInstance.get("i")) @@ -108,8 +110,8 @@ class InheritanceSchemaToClassCarpenterTests : AmqpCarpenterBase(AllWhitelist) { class B(override val i: I, override val iiii: Int) : IIII val (_, env) = B(A(23), 42).roundTrip() - val carpentedA = env.getMangled().load() - val carpentedB = env.getMangled().load() + val carpentedA = env.getMangled().load(testSerializationContext) + val carpentedB = env.getMangled().load(testSerializationContext) val carpentedAInstance = carpentedA.new(23) val carpentedBInstance = carpentedB.new(carpentedAInstance, 42) @@ -127,7 +129,9 @@ class InheritanceSchemaToClassCarpenterTests : AmqpCarpenterBase(AllWhitelist) { // if we remove the nested interface we should get an error as it's impossible // to have a concrete class loaded without having access to all of it's elements - assertFailsWith { assertCanLoadAll(env.getMangled().mangle()) } + assertFailsWith { assertCanLoadAll( + testSerializationContext, + env.getMangled().mangle()) } } @Test @@ -137,7 +141,7 @@ class InheritanceSchemaToClassCarpenterTests : AmqpCarpenterBase(AllWhitelist) { val (_, env) = A(23).roundTrip() // This time around we will succeed, because the mangled I is included in the type information to be loaded. - assertCanLoadAll(env.getMangled().mangle(), env.getMangled()) + assertCanLoadAll(testSerializationContext, env.getMangled().mangle(), env.getMangled()) } @Test @@ -146,6 +150,7 @@ class InheritanceSchemaToClassCarpenterTests : AmqpCarpenterBase(AllWhitelist) { val (_, env) = A(23, 42).roundTrip() assertCanLoadAll( + testSerializationContext, env.getMangled().mangle().mangle(), env.getMangled(), env.getMangled() @@ -158,6 +163,7 @@ class InheritanceSchemaToClassCarpenterTests : AmqpCarpenterBase(AllWhitelist) { val (_, env) = A(23, 42).roundTrip() assertCanLoadAll( + testSerializationContext, env.getMangled().mangle().mangle(), env.getMangled(), env.getMangled().mangle() diff --git a/serialization/src/test/kotlin/net/corda/serialization/internal/carpenter/MultiMemberCompositeSchemaToClassCarpenterTests.kt b/serialization/src/test/kotlin/net/corda/serialization/internal/carpenter/MultiMemberCompositeSchemaToClassCarpenterTests.kt index f5848a8f3b..7174c6a04a 100644 --- a/serialization/src/test/kotlin/net/corda/serialization/internal/carpenter/MultiMemberCompositeSchemaToClassCarpenterTests.kt +++ b/serialization/src/test/kotlin/net/corda/serialization/internal/carpenter/MultiMemberCompositeSchemaToClassCarpenterTests.kt @@ -3,6 +3,7 @@ package net.corda.serialization.internal.carpenter import net.corda.core.serialization.CordaSerializable import net.corda.core.serialization.SerializableCalculatedProperty import net.corda.serialization.internal.AllWhitelist +import net.corda.serialization.internal.amqp.testutils.testSerializationContext import org.junit.Test import kotlin.test.assertEquals import kotlin.test.assertNotEquals @@ -15,7 +16,7 @@ class MultiMemberCompositeSchemaToClassCarpenterTests : AmqpCarpenterBase(AllWhi data class A(val a: Int, val b: Long) val (_, env) = A(23, 42).roundTrip() - val carpentedInstance = env.getMangled().load().new(23, 42) + val carpentedInstance = env.getMangled().load(testSerializationContext).new(23, 42) assertEquals(23, carpentedInstance.get("a")) assertEquals(42L, carpentedInstance.get("b")) @@ -27,7 +28,7 @@ class MultiMemberCompositeSchemaToClassCarpenterTests : AmqpCarpenterBase(AllWhi data class A(val a: Int, val b: String) val (_, env) = A(23, "skidoo").roundTrip() - val carpentedInstance = env.getMangled().load().new(23, "skidoo") + val carpentedInstance = env.getMangled().load(testSerializationContext).new(23, "skidoo") assertEquals(23, carpentedInstance.get("a")) assertEquals("skidoo", carpentedInstance.get("b")) @@ -57,7 +58,7 @@ class MultiMemberCompositeSchemaToClassCarpenterTests : AmqpCarpenterBase(AllWhi squared: String """.trimIndent(), remoteTypeInformation.prettyPrint()) - val pinochio = remoteTypeInformation.mangle().load() + val pinochio = remoteTypeInformation.mangle().load(testSerializationContext) assertNotEquals(pinochio.name, C::class.java.name) assertNotEquals(pinochio, C::class.java) @@ -78,7 +79,7 @@ class MultiMemberCompositeSchemaToClassCarpenterTests : AmqpCarpenterBase(AllWhi val (_, env) = C(5).roundTrip() - val pinochio = env.getMangled().load() + val pinochio = env.getMangled().load(testSerializationContext) val p = pinochio.new(5) assertEquals(5, p.get("doubled")) diff --git a/serialization/src/test/kotlin/net/corda/serialization/internal/model/ClassCarpentingTypeLoaderTests.kt b/serialization/src/test/kotlin/net/corda/serialization/internal/model/ClassCarpentingTypeLoaderTests.kt index f7b942ccd1..683241c117 100644 --- a/serialization/src/test/kotlin/net/corda/serialization/internal/model/ClassCarpentingTypeLoaderTests.kt +++ b/serialization/src/test/kotlin/net/corda/serialization/internal/model/ClassCarpentingTypeLoaderTests.kt @@ -4,6 +4,7 @@ import com.fasterxml.jackson.databind.ObjectMapper import com.google.common.reflect.TypeToken import net.corda.serialization.internal.AllWhitelist import net.corda.serialization.internal.amqp.asClass +import net.corda.serialization.internal.amqp.testutils.testSerializationContext import net.corda.serialization.internal.carpenter.ClassCarpenterImpl import org.junit.Test import java.lang.reflect.Type @@ -12,8 +13,8 @@ import kotlin.test.assertEquals class ClassCarpentingTypeLoaderTests { val carpenter = ClassCarpenterImpl(AllWhitelist) - val remoteTypeCarpenter = SchemaBuildingRemoteTypeCarpenter(carpenter) - val typeLoader = ClassCarpentingTypeLoader(remoteTypeCarpenter, carpenter.classloader) + private val remoteTypeCarpenter = SchemaBuildingRemoteTypeCarpenter(carpenter) + private val typeLoader = ClassCarpentingTypeLoader(remoteTypeCarpenter, carpenter.classloader) @Test fun `carpent some related classes`() { @@ -44,7 +45,9 @@ class ClassCarpentingTypeLoaderTests { "previousAddresses" to listOfAddresses.mandatory ), emptyList(), emptyList()) - val types = typeLoader.load(listOf(personInformation, addressInformation, listOfAddresses)) + val types = typeLoader.load(listOf(personInformation, addressInformation, listOfAddresses), + testSerializationContext) + val addressType = types[addressInformation.typeIdentifier]!! val personType = types[personInformation.typeIdentifier]!!