CORDA-1747 fix issue around RPC return of generic objects (#3625)

* fix issue around RPC return of generic objects

* address review comments
This commit is contained in:
Stefano Franz 2018-07-17 12:19:06 +01:00 committed by GitHub
parent a8fa232301
commit 829be5dfb6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 253 additions and 16 deletions

View File

@ -0,0 +1,73 @@
package net.corda.serialization.reproduction;
import net.corda.client.rpc.CordaRPCClient;
import net.corda.core.concurrent.CordaFuture;
import net.corda.core.flows.FlowLogic;
import net.corda.core.flows.StartableByRPC;
import net.corda.core.serialization.CordaSerializable;
import net.corda.node.services.Permissions;
import net.corda.testing.driver.Driver;
import net.corda.testing.driver.DriverParameters;
import net.corda.testing.driver.NodeHandle;
import net.corda.testing.driver.NodeParameters;
import net.corda.testing.node.User;
import org.junit.Test;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.ExecutionException;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
public class GenericReturnFailureReproductionIntegrationTest {
@Test()
public void flowShouldReturnGenericList() {
User user = new User("yes", "yes", Collections.singleton(Permissions.startFlow(SuperSimpleGenericFlow.class)));
DriverParameters defaultParameters = new DriverParameters();
Driver.<Void>driver(defaultParameters, (driver) -> {
NodeHandle startedNode = getOrThrow(driver.startNode(new NodeParameters().withRpcUsers(Collections.singletonList(user)).withStartInSameProcess(true)));
(new CordaRPCClient(startedNode.getRpcAddress())).<Void>use("yes", "yes", (cordaRPCConnection -> {
getOrThrow(cordaRPCConnection.getProxy().startFlowDynamic(SuperSimpleGenericFlow.class).getReturnValue());
return null;
}));
return null;
});
}
@StartableByRPC
public static class SuperSimpleGenericFlow extends FlowLogic<GenericHolder<String>> {
public SuperSimpleGenericFlow() {
}
@Override
public GenericHolder<String> call() {
return new GenericHolder<>(IntStream.of(100).mapToObj((i) -> "" + i).collect(Collectors.toList()));
}
}
@CordaSerializable
public static class GenericHolder<S> {
private final List<S> items;
public GenericHolder(List<S> items) {
this.items = items;
}
public List<S> getItems() {
return items;
}
}
private static <Y> Y getOrThrow(CordaFuture<Y> future) {
try {
return future.get();
} catch (InterruptedException | ExecutionException e) {
throw new RuntimeException(e);
}
}
}

View File

@ -63,9 +63,13 @@ class CollectionSerializer(private val declaredType: ParameterizedType, factory:
private val typeNotation: TypeNotation = RestrictedType(SerializerFactory.nameForType(declaredType), null, emptyList(), "list", Descriptor(typeDescriptor), emptyList())
private val outboundType = resolveTypeVariables(declaredType.actualTypeArguments[0], null)
private val inboundType = declaredType.actualTypeArguments[0]
override fun writeClassInfo(output: SerializationOutput) = ifThrowsAppend({ declaredType.typeName }) {
if (output.writeTypeNotations(typeNotation)) {
output.requireSerializer(declaredType.actualTypeArguments[0])
output.requireSerializer(outboundType)
}
}
@ -80,12 +84,13 @@ class CollectionSerializer(private val declaredType: ParameterizedType, factory:
data.withDescribed(typeNotation.descriptor) {
withList {
for (entry in obj as Collection<*>) {
output.writeObjectOrNull(entry, this, declaredType.actualTypeArguments[0], context, debugIndent)
output.writeObjectOrNull(entry, this, outboundType, context, debugIndent)
}
}
}
}
override fun readObject(
obj: Any,
schemas: SerializationSchemas,
@ -93,7 +98,7 @@ class CollectionSerializer(private val declaredType: ParameterizedType, factory:
context: SerializationContext): Any = ifThrowsAppend({ declaredType.typeName }) {
// TODO: Can we verify the entries in the list?
concreteBuilder((obj as List<*>).map {
input.readObjectOrNull(it, schemas, declaredType.actualTypeArguments[0], context)
input.readObjectOrNull(it, schemas, inboundType, context)
})
}
}

View File

@ -70,10 +70,15 @@ class MapSerializer(private val declaredType: ParameterizedType, factory: Serial
private val typeNotation: TypeNotation = RestrictedType(SerializerFactory.nameForType(declaredType), null, emptyList(), "map", Descriptor(typeDescriptor), emptyList())
private val inboundKeyType = declaredType.actualTypeArguments[0]
private val outboundKeyType = resolveTypeVariables(inboundKeyType, null)
private val inboundValueType = declaredType.actualTypeArguments[1]
private val outboundValueType = resolveTypeVariables(inboundValueType, null)
override fun writeClassInfo(output: SerializationOutput) = ifThrowsAppend({ declaredType.typeName }) {
if (output.writeTypeNotations(typeNotation)) {
output.requireSerializer(declaredType.actualTypeArguments[0])
output.requireSerializer(declaredType.actualTypeArguments[1])
output.requireSerializer(outboundKeyType)
output.requireSerializer(outboundValueType)
}
}
@ -91,8 +96,8 @@ class MapSerializer(private val declaredType: ParameterizedType, factory: Serial
data.putMap()
data.enter()
for ((key, value) in obj as Map<*, *>) {
output.writeObjectOrNull(key, data, declaredType.actualTypeArguments[0], context, debugIndent)
output.writeObjectOrNull(value, data, declaredType.actualTypeArguments[1], context, debugIndent)
output.writeObjectOrNull(key, data, outboundKeyType, context, debugIndent)
output.writeObjectOrNull(value, data, outboundValueType, context, debugIndent)
}
data.exit() // exit map
}
@ -108,8 +113,8 @@ class MapSerializer(private val declaredType: ParameterizedType, factory: Serial
private fun readEntry(schemas: SerializationSchemas, input: DeserializationInput, entry: Map.Entry<Any?, Any?>,
context: SerializationContext
) = input.readObjectOrNull(entry.key, schemas, declaredType.actualTypeArguments[0], context) to
input.readObjectOrNull(entry.value, schemas, declaredType.actualTypeArguments[1], context)
) = input.readObjectOrNull(entry.key, schemas, inboundKeyType, context) to
input.readObjectOrNull(entry.value, schemas, inboundValueType, context)
// Cannot use * as a bound for EnumMap and EnumSet since * is not an enum. So, we use a sample enum instead.
// We don't actually care about the type, we just need to make the compiler happier.

View File

@ -427,7 +427,7 @@ fun Data.writeReferencedObject(refObject: ReferencedObject) {
exit() // exit described
}
private fun resolveTypeVariables(actualType: Type, contextType: Type?): Type {
fun resolveTypeVariables(actualType: Type, contextType: Type?): Type {
val resolvedType = if (contextType != null) TypeToken.of(contextType).resolveType(actualType).type else actualType
// TODO: surely we check it is concrete at this point with no TypeVariables
return if (resolvedType is TypeVariable<*>) {

View File

@ -1,26 +1,180 @@
package net.corda.serialization.internal.amqp;
import net.corda.core.serialization.CordaSerializable;
import net.corda.core.serialization.SerializedBytes;
import net.corda.serialization.internal.amqp.custom.BigIntegerSerializer;
import net.corda.serialization.internal.amqp.testutils.AMQPTestUtilsKt;
import net.corda.serialization.internal.amqp.testutils.TestSerializationContext;
import org.hamcrest.CoreMatchers;
import org.junit.Assert;
import org.junit.Test;
import java.io.NotSerializableException;
import java.math.BigInteger;
import java.util.*;
import static net.corda.serialization.internal.amqp.testutils.AMQPTestUtilsKt.testDefaultFactory;
import static org.jgroups.util.Util.assertEquals;
@SuppressWarnings("unchecked")
public class JavaGenericsTest {
private static class Inner {
private final Integer v;
private Inner(Integer v) { this.v = v; }
Integer getV() { return v; }
private Inner(Integer v) {
this.v = v;
}
Integer getV() {
return v;
}
}
private static class A<T> {
private final T t;
private A(T t) { this.t = t; }
public T getT() { return t; }
private A(T t) {
this.t = t;
}
public T getT() {
return t;
}
}
@CordaSerializable
private static class ConcreteClass {
private final String theItem;
private ConcreteClass(String theItem) {
this.theItem = theItem;
}
public String getTheItem() {
return theItem;
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
ConcreteClass that = (ConcreteClass) o;
return Objects.equals(theItem, that.theItem);
}
@Override
public int hashCode() {
return Objects.hash(theItem);
}
}
@CordaSerializable
private static class GenericClassWithList<CC> {
private final List<CC> items;
private GenericClassWithList(List<CC> items) {
this.items = items;
}
public List<CC> getItems() {
return items;
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
GenericClassWithList<?> that = (GenericClassWithList<?>) o;
return Objects.equals(items, that.items);
}
@Override
public int hashCode() {
return Objects.hash(items);
}
}
@CordaSerializable
private static class GenericClassWithMap<CC, GG> {
private final Map<CC, GG> theMap;
private GenericClassWithMap(Map<CC, GG> theMap) {
this.theMap = new LinkedHashMap<>(theMap);
}
public Map<CC, GG> getTheMap() {
return theMap;
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
GenericClassWithMap<?, ?> that = (GenericClassWithMap<?, ?>) o;
return Objects.equals(theMap, that.theMap);
}
@Override
public int hashCode() {
return Objects.hash(theMap);
}
}
@CordaSerializable
private static class HolderOfGeneric<G> {
private final G theGeneric;
private HolderOfGeneric(G theGeneric) {
this.theGeneric = theGeneric;
}
public G getTheGeneric() {
return theGeneric;
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
HolderOfGeneric<?> that = (HolderOfGeneric<?>) o;
return Objects.equals(theGeneric, that.theGeneric);
}
@Override
public int hashCode() {
return Objects.hash(theGeneric);
}
}
@Test
public void shouldSupportNestedGenericsFromJavaWithCollections() throws NotSerializableException {
ConcreteClass concreteClass = new ConcreteClass("How to make concrete, $99/class");
HolderOfGeneric<GenericClassWithList<ConcreteClass>> genericList = new HolderOfGeneric<>(new GenericClassWithList<>(Collections.singletonList(concreteClass)));
SerializerFactory factory = AMQPTestUtilsKt.testDefaultFactoryWithWhitelist();
SerializationOutput ser = new SerializationOutput(factory);
SerializedBytes<?> bytes = ser.serialize(genericList, TestSerializationContext.testSerializationContext);
DeserializationInput des = new DeserializationInput(factory);
HolderOfGeneric<GenericClassWithList<ConcreteClass>> genericList2 = des.deserialize(bytes, HolderOfGeneric.class, TestSerializationContext.testSerializationContext);
Assert.assertThat(genericList, CoreMatchers.is(CoreMatchers.equalTo(genericList2)));
}
@Test
public void shouldSupportNestedGenericsFromJavaWithMaps() throws NotSerializableException {
ConcreteClass concreteClass = new ConcreteClass("How to make concrete, $99/class");
GenericClassWithMap<ConcreteClass, BigInteger> genericMap = new GenericClassWithMap<>(Collections.singletonMap(concreteClass, BigInteger.ONE));
SerializerFactory factory = AMQPTestUtilsKt.testDefaultFactoryWithWhitelist();
factory.register(BigIntegerSerializer.INSTANCE);
SerializationOutput ser = new SerializationOutput(factory);
SerializedBytes<?> bytes = ser.serialize(genericMap, TestSerializationContext.testSerializationContext);
DeserializationInput des = new DeserializationInput(factory);
GenericClassWithMap<ConcreteClass, BigInteger> genericMap2 = des.deserialize(bytes, GenericClassWithMap.class, TestSerializationContext.testSerializationContext);
Assert.assertThat(genericMap2, CoreMatchers.is(CoreMatchers.equalTo(genericMap2)));
}
@Test
@ -67,7 +221,7 @@ public class JavaGenericsTest {
@Test
public void forceWildcard() throws NotSerializableException {
SerializedBytes<?> bytes = forceWildcardSerialize(new A<>(new Inner(29)));
Inner i = (Inner)forceWildcardDeserialize(bytes).getT();
Inner i = (Inner) forceWildcardDeserialize(bytes).getT();
assertEquals(29, i.getV());
}
@ -76,7 +230,7 @@ public class JavaGenericsTest {
SerializerFactory factory = testDefaultFactory();
SerializedBytes<?> bytes = forceWildcardSerializeFactory(new A<>(new Inner(29)), factory);
Inner i = (Inner)forceWildcardDeserializeFactory(bytes, factory).getT();
Inner i = (Inner) forceWildcardDeserializeFactory(bytes, factory).getT();
assertEquals(29, i.getV());
}