mirror of
https://github.com/corda/corda.git
synced 2025-06-13 20:58:19 +00:00
Some preparation work for AMQP serialization integration & a refactor fix (#1214)
This commit is contained in:
@ -0,0 +1,94 @@
|
||||
package net.corda.nodeapi.internal.serialization
|
||||
|
||||
import net.corda.core.serialization.ClassWhitelist
|
||||
import net.corda.core.serialization.SerializationContext
|
||||
import net.corda.core.serialization.SerializationDefaults
|
||||
import net.corda.core.serialization.SerializedBytes
|
||||
import net.corda.core.utilities.ByteSequence
|
||||
import net.corda.nodeapi.internal.serialization.amqp.AmqpHeaderV1_0
|
||||
import net.corda.nodeapi.internal.serialization.amqp.DeserializationInput
|
||||
import net.corda.nodeapi.internal.serialization.amqp.SerializationOutput
|
||||
import net.corda.nodeapi.internal.serialization.amqp.SerializerFactory
|
||||
import java.util.concurrent.ConcurrentHashMap
|
||||
|
||||
private const val AMQP_ENABLED = false
|
||||
|
||||
abstract class AbstractAMQPSerializationScheme : SerializationScheme {
|
||||
private val serializerFactoriesForContexts = ConcurrentHashMap<Pair<ClassWhitelist, ClassLoader>, SerializerFactory>()
|
||||
|
||||
protected abstract fun rpcClientSerializerFactory(context: SerializationContext): SerializerFactory
|
||||
protected abstract fun rpcServerSerializerFactory(context: SerializationContext): SerializerFactory
|
||||
|
||||
private fun getSerializerFactory(context: SerializationContext): SerializerFactory {
|
||||
return serializerFactoriesForContexts.computeIfAbsent(Pair(context.whitelist, context.deserializationClassLoader)) {
|
||||
when (context.useCase) {
|
||||
SerializationContext.UseCase.Checkpoint ->
|
||||
throw IllegalStateException("AMQP should not be used for checkpoint serialization.")
|
||||
SerializationContext.UseCase.RPCClient ->
|
||||
rpcClientSerializerFactory(context)
|
||||
SerializationContext.UseCase.RPCServer ->
|
||||
rpcServerSerializerFactory(context)
|
||||
else -> SerializerFactory(context.whitelist) // TODO pass class loader also
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
override fun <T : Any> deserialize(byteSequence: ByteSequence, clazz: Class<T>, context: SerializationContext): T {
|
||||
val serializerFactory = getSerializerFactory(context)
|
||||
return DeserializationInput(serializerFactory).deserialize(byteSequence, clazz)
|
||||
}
|
||||
|
||||
override fun <T : Any> serialize(obj: T, context: SerializationContext): SerializedBytes<T> {
|
||||
val serializerFactory = getSerializerFactory(context)
|
||||
return SerializationOutput(serializerFactory).serialize(obj)
|
||||
}
|
||||
|
||||
protected fun canDeserializeVersion(byteSequence: ByteSequence): Boolean = AMQP_ENABLED && byteSequence == AmqpHeaderV1_0
|
||||
}
|
||||
|
||||
// TODO: This will eventually cover server RPC as well and move to node module, but for now this is not implemented
|
||||
class AMQPServerSerializationScheme : AbstractAMQPSerializationScheme() {
|
||||
override fun rpcClientSerializerFactory(context: SerializationContext): SerializerFactory {
|
||||
throw UnsupportedOperationException()
|
||||
}
|
||||
|
||||
override fun rpcServerSerializerFactory(context: SerializationContext): SerializerFactory {
|
||||
TODO("not implemented") //To change body of created functions use File | Settings | File Templates.
|
||||
}
|
||||
|
||||
override fun canDeserializeVersion(byteSequence: ByteSequence, target: SerializationContext.UseCase): Boolean {
|
||||
return (canDeserializeVersion(byteSequence) &&
|
||||
(target == SerializationContext.UseCase.P2P || target == SerializationContext.UseCase.Storage))
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// TODO: This will eventually cover client RPC as well and move to client module, but for now this is not implemented
|
||||
class AMQPClientSerializationScheme : AbstractAMQPSerializationScheme() {
|
||||
override fun rpcClientSerializerFactory(context: SerializationContext): SerializerFactory {
|
||||
TODO("not implemented") //To change body of created functions use File | Settings | File Templates.
|
||||
}
|
||||
|
||||
override fun rpcServerSerializerFactory(context: SerializationContext): SerializerFactory {
|
||||
throw UnsupportedOperationException()
|
||||
}
|
||||
|
||||
override fun canDeserializeVersion(byteSequence: ByteSequence, target: SerializationContext.UseCase): Boolean {
|
||||
return (canDeserializeVersion(byteSequence) &&
|
||||
(target == SerializationContext.UseCase.P2P || target == SerializationContext.UseCase.Storage))
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
val AMQP_P2P_CONTEXT = SerializationContextImpl(AmqpHeaderV1_0,
|
||||
SerializationDefaults.javaClass.classLoader,
|
||||
GlobalTransientClassWhiteList(BuiltInExceptionsWhitelist()),
|
||||
emptyMap(),
|
||||
true,
|
||||
SerializationContext.UseCase.P2P)
|
||||
val AMQP_STORAGE_CONTEXT = SerializationContextImpl(AmqpHeaderV1_0,
|
||||
SerializationDefaults.javaClass.classLoader,
|
||||
AllButBlacklisted,
|
||||
emptyMap(),
|
||||
true,
|
||||
SerializationContext.UseCase.Storage)
|
@ -54,6 +54,8 @@ data class SerializationContextImpl(override val preferedSerializationVersion: B
|
||||
}
|
||||
}
|
||||
|
||||
private const val HEADER_SIZE: Int = 8
|
||||
|
||||
open class SerializationFactoryImpl : SerializationFactory {
|
||||
private val creator: List<StackTraceElement> = Exception().stackTrace.asList()
|
||||
|
||||
@ -63,8 +65,8 @@ open class SerializationFactoryImpl : SerializationFactory {
|
||||
private val schemes: ConcurrentHashMap<Pair<ByteSequence, SerializationContext.UseCase>, SerializationScheme> = ConcurrentHashMap()
|
||||
|
||||
private fun schemeFor(byteSequence: ByteSequence, target: SerializationContext.UseCase): SerializationScheme {
|
||||
// truncate sequence to 8 bytes
|
||||
return schemes.computeIfAbsent(byteSequence.take(8).copy() to target) {
|
||||
// truncate sequence to 8 bytes, and make sure it's a copy to avoid holding onto large ByteArrays
|
||||
return schemes.computeIfAbsent(byteSequence.take(HEADER_SIZE).copy() to target) {
|
||||
for (scheme in registeredSchemes) {
|
||||
if (scheme.canDeserializeVersion(it.first, it.second)) {
|
||||
return@computeIfAbsent scheme
|
||||
@ -162,11 +164,12 @@ abstract class AbstractKryoSerializationScheme : SerializationScheme {
|
||||
|
||||
override fun <T : Any> deserialize(byteSequence: ByteSequence, clazz: Class<T>, context: SerializationContext): T {
|
||||
val pool = getPool(context)
|
||||
Input(byteSequence.bytes, byteSequence.offset, byteSequence.size).use { input ->
|
||||
val header = OpaqueBytes(input.readBytes(8))
|
||||
if (header != KryoHeaderV0_1) {
|
||||
throw KryoException("Serialized bytes header does not match expected format.")
|
||||
}
|
||||
val headerSize = KryoHeaderV0_1.size
|
||||
val header = byteSequence.take(headerSize)
|
||||
if (header != KryoHeaderV0_1) {
|
||||
throw KryoException("Serialized bytes header does not match expected format.")
|
||||
}
|
||||
Input(byteSequence.bytes, byteSequence.offset + headerSize, byteSequence.size - headerSize).use { input ->
|
||||
return pool.run { kryo ->
|
||||
withContext(kryo, context) {
|
||||
@Suppress("UNCHECKED_CAST")
|
||||
|
@ -2,6 +2,7 @@ package net.corda.nodeapi.internal.serialization.amqp
|
||||
|
||||
import net.corda.core.internal.getStackTraceAsString
|
||||
import net.corda.core.serialization.SerializedBytes
|
||||
import net.corda.core.utilities.ByteSequence
|
||||
import org.apache.qpid.proton.amqp.Binary
|
||||
import org.apache.qpid.proton.amqp.DescribedType
|
||||
import org.apache.qpid.proton.amqp.UnsignedByte
|
||||
@ -26,17 +27,6 @@ class DeserializationInput(internal val serializerFactory: SerializerFactory = S
|
||||
internal companion object {
|
||||
val BYTES_NEEDED_TO_PEEK: Int = 23
|
||||
|
||||
private fun subArraysEqual(a: ByteArray, aOffset: Int, length: Int, b: ByteArray, bOffset: Int): Boolean {
|
||||
if (aOffset + length > a.size || bOffset + length > b.size) throw IndexOutOfBoundsException()
|
||||
var bytesRemaining = length
|
||||
var aPos = aOffset
|
||||
var bPos = bOffset
|
||||
while (bytesRemaining-- > 0) {
|
||||
if (a[aPos++] != b[bPos++]) return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
fun peekSize(bytes: ByteArray): Int {
|
||||
// There's an 8 byte header, and then a 0 byte plus descriptor followed by constructor
|
||||
val eighth = bytes[8].toInt()
|
||||
@ -69,15 +59,16 @@ class DeserializationInput(internal val serializerFactory: SerializerFactory = S
|
||||
|
||||
|
||||
@Throws(NotSerializableException::class)
|
||||
private fun <T : Any> getEnvelope(bytes: SerializedBytes<T>): Envelope {
|
||||
private fun getEnvelope(bytes: ByteSequence): Envelope {
|
||||
// Check that the lead bytes match expected header
|
||||
if (!subArraysEqual(bytes.bytes, 0, 8, AmqpHeaderV1_0.bytes, 0)) {
|
||||
val headerSize = AmqpHeaderV1_0.size
|
||||
if (bytes.take(headerSize) != AmqpHeaderV1_0) {
|
||||
throw NotSerializableException("Serialization header does not match.")
|
||||
}
|
||||
|
||||
val data = Data.Factory.create()
|
||||
val size = data.decode(ByteBuffer.wrap(bytes.bytes, 8, bytes.size - 8))
|
||||
if (size.toInt() != bytes.size - 8) {
|
||||
val size = data.decode(ByteBuffer.wrap(bytes.bytes, bytes.offset + headerSize, bytes.size - headerSize))
|
||||
if (size.toInt() != bytes.size - headerSize) {
|
||||
throw NotSerializableException("Unexpected size of data")
|
||||
}
|
||||
|
||||
@ -103,7 +94,7 @@ class DeserializationInput(internal val serializerFactory: SerializerFactory = S
|
||||
* be deserialized and a schema describing the types of the objects.
|
||||
*/
|
||||
@Throws(NotSerializableException::class)
|
||||
fun <T : Any> deserialize(bytes: SerializedBytes<T>, clazz: Class<T>): T {
|
||||
fun <T : Any> deserialize(bytes: ByteSequence, clazz: Class<T>): T {
|
||||
return des {
|
||||
val envelope = getEnvelope(bytes)
|
||||
clazz.cast(readObjectOrNull(envelope.obj, envelope.schema, clazz))
|
||||
|
@ -0,0 +1,226 @@
|
||||
package net.corda.nodeapi.internal.serialization.amqp;
|
||||
|
||||
import net.corda.core.serialization.SerializedBytes;
|
||||
import org.apache.qpid.proton.codec.DecoderImpl;
|
||||
import org.apache.qpid.proton.codec.EncoderImpl;
|
||||
import org.junit.Test;
|
||||
|
||||
import javax.annotation.Nonnull;
|
||||
import java.io.NotSerializableException;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.util.Objects;
|
||||
|
||||
import static org.junit.Assert.assertTrue;
|
||||
|
||||
public class JavaSerializationOutputTests {
|
||||
|
||||
static class Foo {
|
||||
private final String bob;
|
||||
private final int count;
|
||||
|
||||
public Foo(String msg, long count) {
|
||||
this.bob = msg;
|
||||
this.count = (int) count;
|
||||
}
|
||||
|
||||
@ConstructorForDeserialization
|
||||
public Foo(String fred, int count) {
|
||||
this.bob = fred;
|
||||
this.count = count;
|
||||
}
|
||||
|
||||
public String getFred() {
|
||||
return bob;
|
||||
}
|
||||
|
||||
public int getCount() {
|
||||
return count;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o) return true;
|
||||
if (o == null || getClass() != o.getClass()) return false;
|
||||
|
||||
Foo foo = (Foo) o;
|
||||
|
||||
if (count != foo.count) return false;
|
||||
return bob != null ? bob.equals(foo.bob) : foo.bob == null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
int result = bob != null ? bob.hashCode() : 0;
|
||||
result = 31 * result + count;
|
||||
return result;
|
||||
}
|
||||
}
|
||||
|
||||
static class UnAnnotatedFoo {
|
||||
private final String bob;
|
||||
private final int count;
|
||||
|
||||
public UnAnnotatedFoo(String fred, int count) {
|
||||
this.bob = fred;
|
||||
this.count = count;
|
||||
}
|
||||
|
||||
public String getFred() {
|
||||
return bob;
|
||||
}
|
||||
|
||||
public int getCount() {
|
||||
return count;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o) return true;
|
||||
if (o == null || getClass() != o.getClass()) return false;
|
||||
|
||||
UnAnnotatedFoo foo = (UnAnnotatedFoo) o;
|
||||
|
||||
if (count != foo.count) return false;
|
||||
return bob != null ? bob.equals(foo.bob) : foo.bob == null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
int result = bob != null ? bob.hashCode() : 0;
|
||||
result = 31 * result + count;
|
||||
return result;
|
||||
}
|
||||
}
|
||||
|
||||
static class BoxedFoo {
|
||||
private final String fred;
|
||||
private final Integer count;
|
||||
|
||||
public BoxedFoo(String fred, Integer count) {
|
||||
this.fred = fred;
|
||||
this.count = count;
|
||||
}
|
||||
|
||||
public String getFred() {
|
||||
return fred;
|
||||
}
|
||||
|
||||
public Integer getCount() {
|
||||
return count;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o) return true;
|
||||
if (o == null || getClass() != o.getClass()) return false;
|
||||
|
||||
BoxedFoo boxedFoo = (BoxedFoo) o;
|
||||
|
||||
if (fred != null ? !fred.equals(boxedFoo.fred) : boxedFoo.fred != null) return false;
|
||||
return count != null ? count.equals(boxedFoo.count) : boxedFoo.count == null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
int result = fred != null ? fred.hashCode() : 0;
|
||||
result = 31 * result + (count != null ? count.hashCode() : 0);
|
||||
return result;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
static class BoxedFooNotNull {
|
||||
private final String fred;
|
||||
private final Integer count;
|
||||
|
||||
public BoxedFooNotNull(String fred, Integer count) {
|
||||
this.fred = fred;
|
||||
this.count = count;
|
||||
}
|
||||
|
||||
public String getFred() {
|
||||
return fred;
|
||||
}
|
||||
|
||||
@Nonnull
|
||||
public Integer getCount() {
|
||||
return count;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o) return true;
|
||||
if (o == null || getClass() != o.getClass()) return false;
|
||||
|
||||
BoxedFooNotNull boxedFoo = (BoxedFooNotNull) o;
|
||||
|
||||
if (fred != null ? !fred.equals(boxedFoo.fred) : boxedFoo.fred != null) return false;
|
||||
return count != null ? count.equals(boxedFoo.count) : boxedFoo.count == null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
int result = fred != null ? fred.hashCode() : 0;
|
||||
result = 31 * result + (count != null ? count.hashCode() : 0);
|
||||
return result;
|
||||
}
|
||||
}
|
||||
|
||||
private Object serdes(Object obj) throws NotSerializableException {
|
||||
SerializerFactory factory = new SerializerFactory();
|
||||
SerializationOutput ser = new SerializationOutput(factory);
|
||||
SerializedBytes<Object> bytes = ser.serialize(obj);
|
||||
|
||||
DecoderImpl decoder = new DecoderImpl();
|
||||
|
||||
decoder.register(Envelope.Companion.getDESCRIPTOR(), Envelope.Companion);
|
||||
decoder.register(Schema.Companion.getDESCRIPTOR(), Schema.Companion);
|
||||
decoder.register(Descriptor.Companion.getDESCRIPTOR(), Descriptor.Companion);
|
||||
decoder.register(Field.Companion.getDESCRIPTOR(), Field.Companion);
|
||||
decoder.register(CompositeType.Companion.getDESCRIPTOR(), CompositeType.Companion);
|
||||
decoder.register(Choice.Companion.getDESCRIPTOR(), Choice.Companion);
|
||||
decoder.register(RestrictedType.Companion.getDESCRIPTOR(), RestrictedType.Companion);
|
||||
|
||||
new EncoderImpl(decoder);
|
||||
decoder.setByteBuffer(ByteBuffer.wrap(bytes.getBytes(), 8, bytes.getSize() - 8));
|
||||
Envelope result = (Envelope) decoder.readObject();
|
||||
assertTrue(result != null);
|
||||
|
||||
DeserializationInput des = new DeserializationInput();
|
||||
Object desObj = des.deserialize(bytes, Object.class);
|
||||
assertTrue(Objects.deepEquals(obj, desObj));
|
||||
|
||||
// Now repeat with a re-used factory
|
||||
SerializationOutput ser2 = new SerializationOutput(factory);
|
||||
DeserializationInput des2 = new DeserializationInput(factory);
|
||||
Object desObj2 = des2.deserialize(ser2.serialize(obj), Object.class);
|
||||
assertTrue(Objects.deepEquals(obj, desObj2));
|
||||
// TODO: check schema is as expected
|
||||
return desObj2;
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testJavaConstructorAnnotations() throws NotSerializableException {
|
||||
Foo obj = new Foo("Hello World!", 123);
|
||||
serdes(obj);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testJavaConstructorWithoutAnnotations() throws NotSerializableException {
|
||||
UnAnnotatedFoo obj = new UnAnnotatedFoo("Hello World!", 123);
|
||||
serdes(obj);
|
||||
}
|
||||
|
||||
|
||||
@Test
|
||||
public void testBoxedTypes() throws NotSerializableException {
|
||||
BoxedFoo obj = new BoxedFoo("Hello World!", 123);
|
||||
serdes(obj);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testBoxedTypesNotNull() throws NotSerializableException {
|
||||
BoxedFooNotNull obj = new BoxedFooNotNull("Hello World!", 123);
|
||||
serdes(obj);
|
||||
}
|
||||
}
|
Reference in New Issue
Block a user