CORDA-540: Ensure that covariance of type is handled correctly when serializing with AMQP (#1631)

This commit is contained in:
Viktor Kolomeyko 2017-09-27 09:19:25 +01:00 committed by GitHub
parent aff4d35ccb
commit 8a842d1d53
5 changed files with 169 additions and 3 deletions

View File

@ -11,6 +11,8 @@ import org.apache.qpid.proton.codec.Data
import java.io.NotSerializableException
import java.lang.reflect.ParameterizedType
import java.lang.reflect.Type
import java.lang.reflect.TypeVariable
import java.lang.reflect.WildcardType
import java.nio.ByteBuffer
data class ObjectAndEnvelope<out T>(val obj: T, val envelope: Envelope)
@ -142,10 +144,18 @@ class DeserializationInput(internal val serializerFactory: SerializerFactory) {
}
/**
* TODO: Currently performs rather basic checks aimed in particular at [java.util.List<Command<?>>] and
* [java.lang.Class<? extends net.corda.core.contracts.Contract>]
* Currently performs checks aimed at:
* * [java.util.List<Command<?>>] and [java.lang.Class<? extends net.corda.core.contracts.Contract>]
* * [T : Parent] and [Parent]
* * [? extends Parent] and [Parent]
*
* In the future tighter control might be needed
*/
private fun Type.materiallyEquivalentTo(that: Type): Boolean =
asClass() == that.asClass() && that is ParameterizedType
when(that) {
is ParameterizedType -> asClass() == that.asClass()
is TypeVariable<*> -> isSubClassOf(that.bounds.first())
is WildcardType -> isSubClassOf(that.upperBounds.first())
else -> false
}
}

View File

@ -189,6 +189,8 @@ internal fun Type.asClass(): Class<*>? {
this is Class<*> -> this
this is ParameterizedType -> this.rawType.asClass()
this is GenericArrayType -> this.genericComponentType.asClass()?.arrayClass()
this is TypeVariable<*> -> this.bounds.first().asClass()
this is WildcardType -> this.upperBounds.first().asClass()
else -> null
}
}

View File

@ -105,6 +105,8 @@ class SerializerFactory(val whitelist: ClassWhitelist, cl: ClassLoader) {
val declaredComponent = declaredType.genericComponentType
inferTypeVariables(actualClass?.componentType, declaredComponent.asClass()!!, declaredComponent)?.asArray()
}
is TypeVariable<*> -> actualClass
is WildcardType -> actualClass
else -> null
}

View File

@ -0,0 +1,134 @@
package net.corda.nodeapi.internal.serialization.amqp;
import net.corda.core.serialization.CordaSerializable;
import net.corda.core.serialization.SerializedBytes;
import net.corda.nodeapi.internal.serialization.AllWhitelist;
import org.junit.Assert;
import org.junit.Test;
import java.util.ArrayList;
import java.util.List;
public class ListsSerializationJavaTest {
@CordaSerializable
interface Parent {}
public static class Child implements Parent {
private final int value;
Child(int value) {
this.value = value;
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
Child child = (Child) o;
return value == child.value;
}
@Override
public int hashCode() {
return value;
}
// Needed to show that there is a property called "value"
@SuppressWarnings("unused")
public int getValue() {
return value;
}
}
@CordaSerializable
public static class CovariantContainer<T extends Parent> {
private final List<T> content;
CovariantContainer(List<T> content) {
this.content = content;
}
@Override
@SuppressWarnings("unchecked")
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
CovariantContainer<T> that = (CovariantContainer<T>) o;
return content != null ? content.equals(that.content) : that.content == null;
}
@Override
public int hashCode() {
return content != null ? content.hashCode() : 0;
}
// Needed to show that there is a property called "content"
@SuppressWarnings("unused")
public List<T> getContent() {
return content;
}
}
@CordaSerializable
public static class CovariantContainer2 {
private final List<? extends Parent> content;
CovariantContainer2(List<? extends Parent> content) {
this.content = content;
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
CovariantContainer2 that = (CovariantContainer2) o;
return content != null ? content.equals(that.content) : that.content == null;
}
@Override
public int hashCode() {
return content != null ? content.hashCode() : 0;
}
// Needed to show that there is a property called "content"
@SuppressWarnings("unused")
public List<? extends Parent> getContent() {
return content;
}
}
@Test
public void checkCovariance() throws Exception {
List<Child> payload = new ArrayList<>();
payload.add(new Child(1));
payload.add(new Child(2));
CovariantContainer<Child> container = new CovariantContainer<>(payload);
assertEqualAfterRoundTripSerialization(container, CovariantContainer.class);
}
@Test
public void checkCovariance2() throws Exception {
List<Child> payload = new ArrayList<>();
payload.add(new Child(1));
payload.add(new Child(2));
CovariantContainer2 container = new CovariantContainer2(payload);
assertEqualAfterRoundTripSerialization(container, CovariantContainer2.class);
}
// Have to have own version as Kotlin inline functions cannot be easily called from Java
private static<T> void assertEqualAfterRoundTripSerialization(T container, Class<T> clazz) throws Exception {
SerializerFactory factory1 = new SerializerFactory(AllWhitelist.INSTANCE, ClassLoader.getSystemClassLoader());
SerializationOutput ser = new SerializationOutput(factory1);
SerializedBytes<Object> bytes = ser.serialize(container);
DeserializationInput des = new DeserializationInput(factory1);
T deserialized = des.deserialize(bytes, clazz);
Assert.assertEquals(container, deserialized);
}
}

View File

@ -68,6 +68,24 @@ class ListsSerializationTest : TestDependencyInjectionBase() {
Assertions.assertThatThrownBy { wrongPayloadType.serialize() }
.isInstanceOf(NotSerializableException::class.java).hasMessageContaining("Cannot derive collection type for declaredType")
}
@CordaSerializable
interface Parent
data class Child(val value: Int) : Parent
@CordaSerializable
data class CovariantContainer<out T: Parent>(val payload: List<T>)
@Test
fun `check covariance`() {
val payload = ArrayList<Child>()
payload.add(Child(1))
payload.add(Child(2))
val container = CovariantContainer(payload)
assertEqualAfterRoundTripSerialization(container)
}
}
internal inline fun<reified T : Any> assertEqualAfterRoundTripSerialization(obj: T) {