mirror of
https://github.com/corda/corda.git
synced 2024-12-23 06:42:33 +00:00
CORDA-4103 Feature Branch: Serialization injection for transaction building (#6867)
* CORDA-4105 Add public API to allow custom serialization schemes (#6848) * CORDA-4105 Add public API to allow custom serialization schemes * Fix Detekt * Suppress warning * Fix usused import * Improve API to use generics This does not break Java support (only Intelij gets confused). * Add more detailed documentation to public interfaces * Change internal variable name after rename public API * Update Public API to use ByteSquence instead of SerializedBytes * Remove unused import * Fix whitespace. * Add added public API to .ci/api-current.txt * Improve public interfaces Rename CustomSchemeContext to SerializationSchemeContext to improve clarity and move to it's own file. Improve kdoc to make things less confusing. * Update API current with changed API * CORDA-4104 Implement custom serialization scheme discovery mechanism (#6854) * CORDA-4104 Implement CustomSerializationScheme Discovery Mechanism Discovers a single CustomSerializationScheme implementation inside the drivers dir using a system property. * Started MockNetwork test * Add driver test of Custom Serialization Scheme * Fix detekt and minor style error * Respond to review comments Allow non-single arg constructors (there must be one no args constructor), move code from SerializationEnviroment into its own file, improve exceptions to be more user friendly. * Fix minor bug in Scheme finding code + improve error messages * CORDA-4104 Improve test coverage of custom serialization scheme discovery (#6855) * CORDA-4104 Add test of classloader scanning for CustomSerializationSchemes * Fix Detekt * NOTICK Clarify KDOC on SerializationSchemeContext (#6865) * CORDA-4111 Change Component Group Serialization to use contex when the lazy map is constructed (#6856) Currently the component group will recheck the thread local (global) serialization context when component groups are serialized lazily. Instead store the serialization context when the lazy map is constructed and use that latter when doing serialization lazily. * CORDA-4106 Test wire transaction can still be written to the ledger (#6860) * Add test that writes transaction to the Database * Improve test check serialization scheme in test body * CORDA-4119 Minor changes to serialisation injection for transaction building (#6868) * CORDA-4119 Minor changes to serialisation injection for transaction building Scan the CorDapp classloader instead of the drivers classloader. Add properties map to CustomSerialiaztionContext (copied from SerializationContext). Change API to let a user pass in the serialization context in TransactionBuilder.toLedgerTransaction * Improve KDOC + fix shawdowing issue in CordaUtils * Pass only the properties map into theTransactionBuilder.toWireTransaction Not the entire serializationContext * Revert change to CordaUtils * Improve KDOC explain pitfalls of setting properties
This commit is contained in:
parent
4f336a1a67
commit
20dbbf008d
.ci
core-tests/src/test/kotlin/net/corda/coretests/transactions
core/src/main/kotlin/net/corda/core
internal
serialization
transactions
node-api/src
main/kotlin/net/corda/nodeapi/internal/serialization
test
java/net/corda/nodeapi/internal/serialization
kotlin/net/corda/nodeapi/internal/serialization
node/src
integration-test/kotlin/net/corda/node
main/kotlin/net/corda/node/internal
test/kotlin/net/corda/node/internal
serialization/src/main/kotlin/net/corda/serialization/internal
testing/core-test-utils/src/main/kotlin/net/corda/coretesting/internal
@ -5937,6 +5937,13 @@ public @interface net.corda.core.serialization.CordaSerializationTransformRename
|
||||
public @interface net.corda.core.serialization.CordaSerializationTransformRenames
|
||||
public abstract net.corda.core.serialization.CordaSerializationTransformRename[] value()
|
||||
##
|
||||
public interface net.corda.core.serialization.CustomSerializationScheme
|
||||
@NotNull
|
||||
public abstract T deserialize(net.corda.core.utilities.ByteSequence, Class<T>, net.corda.core.serialization.SerializationSchemeContext)
|
||||
public abstract int getSchemeId()
|
||||
@NotNull
|
||||
public abstract net.corda.core.utilities.ByteSequence serialize(T, net.corda.core.serialization.SerializationSchemeContext)
|
||||
##
|
||||
public @interface net.corda.core.serialization.DeprecatedConstructorForDeserialization
|
||||
public abstract int version()
|
||||
##
|
||||
@ -6078,6 +6085,13 @@ public static final class net.corda.core.serialization.SerializationFactory$Comp
|
||||
@NotNull
|
||||
public final net.corda.core.serialization.SerializationFactory getDefaultFactory()
|
||||
##
|
||||
@DoNotImplement
|
||||
public interface net.corda.core.serialization.SerializationSchemeContext
|
||||
@NotNull
|
||||
public abstract ClassLoader getDeserializationClassLoader()
|
||||
@NotNull
|
||||
public abstract net.corda.core.serialization.ClassWhitelist getWhitelist()
|
||||
##
|
||||
public interface net.corda.core.serialization.SerializationToken
|
||||
@NotNull
|
||||
public abstract Object fromToken(net.corda.core.serialization.SerializeAsTokenContext)
|
||||
|
@ -3,7 +3,16 @@ package net.corda.coretests.transactions
|
||||
import com.nhaarman.mockito_kotlin.doReturn
|
||||
import com.nhaarman.mockito_kotlin.mock
|
||||
import com.nhaarman.mockito_kotlin.whenever
|
||||
import net.corda.core.contracts.*
|
||||
import net.corda.core.contracts.Command
|
||||
import net.corda.core.contracts.ContractAttachment
|
||||
import net.corda.core.contracts.HashAttachmentConstraint
|
||||
import net.corda.core.contracts.PrivacySalt
|
||||
import net.corda.core.contracts.SignatureAttachmentConstraint
|
||||
import net.corda.core.contracts.StateAndRef
|
||||
import net.corda.core.contracts.StateRef
|
||||
import net.corda.core.contracts.TimeWindow
|
||||
import net.corda.core.contracts.TransactionState
|
||||
import net.corda.core.contracts.TransactionVerificationException
|
||||
import net.corda.core.cordapp.CordappProvider
|
||||
import net.corda.core.crypto.CompositeKey
|
||||
import net.corda.core.crypto.DigestService
|
||||
@ -20,11 +29,16 @@ import net.corda.core.node.services.IdentityService
|
||||
import net.corda.core.node.services.NetworkParametersService
|
||||
import net.corda.core.serialization.serialize
|
||||
import net.corda.core.transactions.TransactionBuilder
|
||||
import net.corda.coretesting.internal.rigorousMock
|
||||
import net.corda.testing.common.internal.testNetworkParameters
|
||||
import net.corda.testing.contracts.DummyContract
|
||||
import net.corda.testing.contracts.DummyState
|
||||
import net.corda.testing.core.*
|
||||
import net.corda.coretesting.internal.rigorousMock
|
||||
import net.corda.testing.core.ALICE_NAME
|
||||
import net.corda.testing.core.BOB_NAME
|
||||
import net.corda.testing.core.DUMMY_NOTARY_NAME
|
||||
import net.corda.testing.core.DummyCommandData
|
||||
import net.corda.testing.core.SerializationEnvironmentRule
|
||||
import net.corda.testing.core.TestIdentity
|
||||
import org.assertj.core.api.Assertions.assertThat
|
||||
import org.assertj.core.api.Assertions.assertThatThrownBy
|
||||
import org.junit.Assert.assertFalse
|
||||
@ -35,6 +49,7 @@ import org.junit.Rule
|
||||
import org.junit.Test
|
||||
import java.security.PublicKey
|
||||
import java.time.Instant
|
||||
import kotlin.test.assertFailsWith
|
||||
|
||||
class TransactionBuilderTest {
|
||||
@Rule
|
||||
@ -299,4 +314,22 @@ class TransactionBuilderTest {
|
||||
HashAgility.init()
|
||||
}
|
||||
}
|
||||
|
||||
@Test(timeout=300_000)
|
||||
fun `toWireTransaction fails if no scheme is registered with schemeId`() {
|
||||
val outputState = TransactionState(
|
||||
data = DummyState(),
|
||||
contract = DummyContract.PROGRAM_ID,
|
||||
notary = notary,
|
||||
constraint = HashAttachmentConstraint(contractAttachmentId)
|
||||
)
|
||||
val builder = TransactionBuilder()
|
||||
.addOutputState(outputState)
|
||||
.addCommand(DummyCommandData, notary.owningKey)
|
||||
|
||||
val schemeId = 7
|
||||
assertFailsWith<UnsupportedOperationException>("Could not find custom serialization scheme with SchemeId = $schemeId.") {
|
||||
builder.toWireTransaction(services, schemeId)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -154,7 +154,8 @@ fun createComponentGroups(inputs: List<StateRef>,
|
||||
timeWindow: TimeWindow?,
|
||||
references: List<StateRef>,
|
||||
networkParametersHash: SecureHash?): List<ComponentGroup> {
|
||||
val serialize = { value: Any, _: Int -> value.serialize() }
|
||||
val serializationContext = SerializationFactory.defaultFactory.defaultContext
|
||||
val serialize = { value: Any, _: Int -> value.serialize(context = serializationContext) }
|
||||
val componentGroupMap: MutableList<ComponentGroup> = mutableListOf()
|
||||
if (inputs.isNotEmpty()) componentGroupMap.add(ComponentGroup(ComponentGroupEnum.INPUTS_GROUP.ordinal, inputs.lazyMapped(serialize)))
|
||||
if (references.isNotEmpty()) componentGroupMap.add(ComponentGroup(ComponentGroupEnum.REFERENCES_GROUP.ordinal, references.lazyMapped(serialize)))
|
||||
|
@ -0,0 +1,35 @@
|
||||
package net.corda.core.serialization
|
||||
|
||||
import net.corda.core.utilities.ByteSequence
|
||||
import java.io.NotSerializableException
|
||||
|
||||
/***
|
||||
* Implement this interface to add your own Serialization Scheme. This is an experimental feature. All methods in this class MUST be
|
||||
* thread safe i.e. methods from the same instance of this class can be called in different threads simultaneously.
|
||||
*/
|
||||
interface CustomSerializationScheme {
|
||||
/**
|
||||
* This method must return an id used to uniquely identify the Scheme. This should be unique within a network as serialized data might
|
||||
* be sent over the wire.
|
||||
*/
|
||||
fun getSchemeId(): Int
|
||||
|
||||
/**
|
||||
* This method must deserialize the data stored [bytes] into an instance of [T].
|
||||
*
|
||||
* @param bytes the serialized data.
|
||||
* @param clazz the class to instantiate.
|
||||
* @param context used to pass information about how the object should be deserialized.
|
||||
*/
|
||||
@Throws(NotSerializableException::class)
|
||||
fun <T : Any> deserialize(bytes: ByteSequence, clazz: Class<T>, context: SerializationSchemeContext): T
|
||||
|
||||
/**
|
||||
* This method must be able to serialize any object [T] into a ByteSequence.
|
||||
*
|
||||
* @param obj the object to be serialized.
|
||||
* @param context used to pass information about how the object should be serialized.
|
||||
*/
|
||||
@Throws(NotSerializableException::class)
|
||||
fun <T : Any> serialize(obj: T, context: SerializationSchemeContext): ByteSequence
|
||||
}
|
@ -134,7 +134,7 @@ interface SerializationContext {
|
||||
*/
|
||||
val encodingWhitelist: EncodingWhitelist
|
||||
/**
|
||||
* A map of any addition properties specific to the particular use case.
|
||||
* A map of any additional properties specific to the particular use case.
|
||||
*/
|
||||
val properties: Map<Any, Any>
|
||||
/**
|
||||
@ -178,6 +178,11 @@ interface SerializationContext {
|
||||
*/
|
||||
fun withProperty(property: Any, value: Any): SerializationContext
|
||||
|
||||
/**
|
||||
* Helper method to return a new context based on this context with the extra properties added.
|
||||
*/
|
||||
fun withProperties(extraProperties: Map<Any, Any>): SerializationContext
|
||||
|
||||
/**
|
||||
* Helper method to return a new context based on this context with object references disabled.
|
||||
*/
|
||||
|
@ -0,0 +1,30 @@
|
||||
package net.corda.core.serialization
|
||||
|
||||
import net.corda.core.DoNotImplement
|
||||
|
||||
/**
|
||||
* This is used to pass information into [CustomSerializationScheme] about how the object should be (de)serialized.
|
||||
* This context can change depending on the specific circumstances in the node when (de)serialization occurs.
|
||||
*/
|
||||
@DoNotImplement
|
||||
interface SerializationSchemeContext {
|
||||
/**
|
||||
* The class loader to use for deserialization. This is guaranteed to be able to load all the required classes passed into
|
||||
* [CustomSerializationScheme.deserialize].
|
||||
*/
|
||||
val deserializationClassLoader: ClassLoader
|
||||
/**
|
||||
* A whitelist that contains (mostly for security purposes) which classes are authorised to be deserialized.
|
||||
* A secure implementation will not instantiate any object which is not either whitelisted or annotated with [CordaSerializable] when
|
||||
* deserializing. To catch classes missing from the whitelist as early as possible it is HIGHLY recommended to also check this
|
||||
* whitelist when serializing (as well as deserializing) objects.
|
||||
*/
|
||||
val whitelist: ClassWhitelist
|
||||
/**
|
||||
* A map of any additional properties specific to the particular use case. If these properties are set via
|
||||
* [toWireTransaction][net.corda.core.transactions.TransactionBuilder.toWireTransaction] then they might not be available when
|
||||
* deserializing. If the properties are required when deserializing, they can be added into the blob when serializing and read back
|
||||
* when deserializing.
|
||||
*/
|
||||
val properties: Map<Any, Any>
|
||||
}
|
@ -0,0 +1,28 @@
|
||||
package net.corda.core.serialization.internal
|
||||
|
||||
import net.corda.core.KeepForDJVM
|
||||
import net.corda.core.serialization.SerializationMagic
|
||||
import net.corda.core.utilities.ByteSequence
|
||||
import java.nio.ByteBuffer
|
||||
|
||||
class CustomSerializationSchemeUtils {
|
||||
|
||||
@KeepForDJVM
|
||||
companion object {
|
||||
|
||||
private const val SERIALIZATION_SCHEME_ID_SIZE = 4
|
||||
private val PREFIX = "CUS".toByteArray()
|
||||
|
||||
fun getCustomSerializationMagicFromSchemeId(schemeId: Int) : SerializationMagic {
|
||||
return SerializationMagic.of(PREFIX + ByteBuffer.allocate(SERIALIZATION_SCHEME_ID_SIZE).putInt(schemeId).array())
|
||||
}
|
||||
|
||||
fun getSchemeIdIfCustomSerializationMagic(magic: SerializationMagic): Int? {
|
||||
return if (magic.take(PREFIX.size) != ByteSequence.of(PREFIX)) {
|
||||
null
|
||||
} else {
|
||||
return magic.slice(start = PREFIX.size).int
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
@ -16,14 +16,18 @@ import net.corda.core.node.ServicesForResolution
|
||||
import net.corda.core.node.ZoneVersionTooLowException
|
||||
import net.corda.core.node.services.AttachmentId
|
||||
import net.corda.core.node.services.KeyManagementService
|
||||
import net.corda.core.serialization.CustomSerializationScheme
|
||||
import net.corda.core.serialization.SerializationContext
|
||||
import net.corda.core.serialization.SerializationDefaults
|
||||
import net.corda.core.serialization.SerializationFactory
|
||||
import net.corda.core.serialization.SerializationMagic
|
||||
import net.corda.core.serialization.SerializationSchemeContext
|
||||
import net.corda.core.serialization.internal.CustomSerializationSchemeUtils.Companion.getCustomSerializationMagicFromSchemeId
|
||||
import net.corda.core.utilities.contextLogger
|
||||
import java.security.PublicKey
|
||||
import java.time.Duration
|
||||
import java.time.Instant
|
||||
import java.util.ArrayDeque
|
||||
import java.util.UUID
|
||||
import java.util.*
|
||||
import java.util.regex.Pattern
|
||||
import kotlin.collections.ArrayList
|
||||
import kotlin.collections.component1
|
||||
@ -140,6 +144,41 @@ open class TransactionBuilder(
|
||||
fun toWireTransaction(services: ServicesForResolution): WireTransaction = toWireTransactionWithContext(services, null)
|
||||
.apply { checkSupportedHashType() }
|
||||
|
||||
/**
|
||||
* Generates a [WireTransaction] from this builder, resolves any [AutomaticPlaceholderConstraint], and selects the attachments to use for this transaction.
|
||||
*
|
||||
* @param [schemeId] is used to specify the [CustomSerializationScheme] used to serialize each component of the componentGroups of the [WireTransaction].
|
||||
* This is an experimental feature.
|
||||
*
|
||||
* @returns A new [WireTransaction] that will be unaffected by further changes to this [TransactionBuilder].
|
||||
*
|
||||
* @throws [ZoneVersionTooLowException] if there are reference states and the zone minimum platform version is less than 4.
|
||||
*/
|
||||
@Throws(MissingContractAttachments::class)
|
||||
fun toWireTransaction(services: ServicesForResolution, schemeId: Int): WireTransaction {
|
||||
return toWireTransaction(services, schemeId, emptyMap()).apply { checkSupportedHashType() }
|
||||
}
|
||||
|
||||
/**
|
||||
* Generates a [WireTransaction] from this builder, resolves any [AutomaticPlaceholderConstraint], and selects the attachments to use for this transaction.
|
||||
*
|
||||
* @param [schemeId] is used to specify the [CustomSerializationScheme] used to serialize each component of the componentGroups of the [WireTransaction].
|
||||
* This is an experimental feature.
|
||||
*
|
||||
* @param [properties] a list of properties to add to the [SerializationSchemeContext] these properties can be accessed in [CustomSerializationScheme.serialize]
|
||||
* when serializing the componentGroups of the wire transaction but might not be available when deserializing.
|
||||
*
|
||||
* @returns A new [WireTransaction] that will be unaffected by further changes to this [TransactionBuilder].
|
||||
*
|
||||
* @throws [ZoneVersionTooLowException] if there are reference states and the zone minimum platform version is less than 4.
|
||||
*/
|
||||
@Throws(MissingContractAttachments::class)
|
||||
fun toWireTransaction(services: ServicesForResolution, schemeId: Int, properties: Map<Any, Any>): WireTransaction {
|
||||
val magic: SerializationMagic = getCustomSerializationMagicFromSchemeId(schemeId)
|
||||
val serializationContext = SerializationDefaults.P2P_CONTEXT.withPreferredSerializationVersion(magic).withProperties(properties)
|
||||
return toWireTransactionWithContext(services, serializationContext).apply { checkSupportedHashType() }
|
||||
}
|
||||
|
||||
@CordaInternal
|
||||
internal fun toWireTransactionWithContext(
|
||||
services: ServicesForResolution,
|
||||
|
@ -0,0 +1,47 @@
|
||||
package net.corda.nodeapi.internal.serialization
|
||||
|
||||
import net.corda.core.serialization.SerializationSchemeContext
|
||||
import net.corda.core.serialization.CustomSerializationScheme
|
||||
import net.corda.core.serialization.SerializationContext
|
||||
import net.corda.core.serialization.SerializedBytes
|
||||
import net.corda.core.serialization.internal.CustomSerializationSchemeUtils.Companion.getCustomSerializationMagicFromSchemeId
|
||||
import net.corda.core.utilities.ByteSequence
|
||||
import net.corda.serialization.internal.CordaSerializationMagic
|
||||
import net.corda.serialization.internal.SerializationScheme
|
||||
import java.io.ByteArrayOutputStream
|
||||
import java.io.NotSerializableException
|
||||
|
||||
class CustomSerializationSchemeAdapter(private val customScheme: CustomSerializationScheme): SerializationScheme {
|
||||
|
||||
val serializationSchemeMagic = getCustomSerializationMagicFromSchemeId(customScheme.getSchemeId())
|
||||
|
||||
override fun canDeserializeVersion(magic: CordaSerializationMagic, target: SerializationContext.UseCase): Boolean {
|
||||
return magic == serializationSchemeMagic
|
||||
}
|
||||
|
||||
override fun <T : Any> deserialize(byteSequence: ByteSequence, clazz: Class<T>, context: SerializationContext): T {
|
||||
val readMagic = byteSequence.take(serializationSchemeMagic.size)
|
||||
if (readMagic != serializationSchemeMagic) {
|
||||
throw NotSerializableException("Scheme ${customScheme::class.java} is incompatible with blob." +
|
||||
" Magic from blob = $readMagic (Expected = $serializationSchemeMagic)")
|
||||
}
|
||||
return customScheme.deserialize(
|
||||
byteSequence.subSequence(serializationSchemeMagic.size, byteSequence.size - serializationSchemeMagic.size),
|
||||
clazz,
|
||||
SerializationSchemeContextAdapter(context)
|
||||
)
|
||||
}
|
||||
|
||||
override fun <T : Any> serialize(obj: T, context: SerializationContext): SerializedBytes<T> {
|
||||
val stream = ByteArrayOutputStream()
|
||||
stream.write(serializationSchemeMagic.bytes)
|
||||
stream.write(customScheme.serialize(obj, SerializationSchemeContextAdapter(context)).bytes)
|
||||
return SerializedBytes(stream.toByteArray())
|
||||
}
|
||||
|
||||
private class SerializationSchemeContextAdapter(context: SerializationContext) : SerializationSchemeContext {
|
||||
override val deserializationClassLoader = context.deserializationClassLoader
|
||||
override val whitelist = context.whitelist
|
||||
override val properties = context.properties
|
||||
}
|
||||
}
|
@ -0,0 +1,30 @@
|
||||
package net.corda.nodeapi.internal.serialization;
|
||||
|
||||
import net.corda.core.serialization.SerializationSchemeContext;
|
||||
import net.corda.core.serialization.CustomSerializationScheme;
|
||||
import net.corda.core.serialization.SerializedBytes;
|
||||
import net.corda.core.utilities.ByteSequence;
|
||||
|
||||
public class DummyCustomSerializationSchemeInJava implements CustomSerializationScheme {
|
||||
|
||||
public class DummyOutput {}
|
||||
|
||||
static final int testMagic = 7;
|
||||
|
||||
@Override
|
||||
public int getSchemeId() {
|
||||
return testMagic;
|
||||
}
|
||||
|
||||
@Override
|
||||
@SuppressWarnings("unchecked")
|
||||
public <T> T deserialize(ByteSequence bytes, Class<T> clazz, SerializationSchemeContext context) {
|
||||
return (T)new DummyOutput();
|
||||
}
|
||||
|
||||
@Override
|
||||
public <T> SerializedBytes<T> serialize(T obj, SerializationSchemeContext context) {
|
||||
byte[] myBytes = {0xA, 0xA};
|
||||
return new SerializedBytes<>(myBytes);
|
||||
}
|
||||
}
|
@ -0,0 +1,93 @@
|
||||
package net.corda.nodeapi.internal.serialization
|
||||
|
||||
import net.corda.core.serialization.SerializationSchemeContext
|
||||
import net.corda.core.serialization.CustomSerializationScheme
|
||||
import net.corda.core.utilities.ByteSequence
|
||||
import net.corda.nodeapi.internal.serialization.testutils.serializationContext
|
||||
import org.junit.Test
|
||||
import org.junit.jupiter.api.Assertions.assertTrue
|
||||
import java.io.NotSerializableException
|
||||
import kotlin.test.assertFailsWith
|
||||
|
||||
class CustomSerializationSchemeAdapterTests {
|
||||
|
||||
companion object {
|
||||
const val DEFAULT_SCHEME_ID = 7
|
||||
}
|
||||
|
||||
class DummyInputClass
|
||||
class DummyOutputClass
|
||||
|
||||
class SingleInputAndOutputScheme(private val schemeId: Int = DEFAULT_SCHEME_ID): CustomSerializationScheme {
|
||||
|
||||
override fun getSchemeId(): Int {
|
||||
return schemeId
|
||||
}
|
||||
|
||||
override fun <T: Any> deserialize(bytes: ByteSequence, clazz: Class<T>, context: SerializationSchemeContext): T {
|
||||
@Suppress("UNCHECKED_CAST")
|
||||
return DummyOutputClass() as T
|
||||
}
|
||||
|
||||
override fun <T: Any> serialize(obj: T, context: SerializationSchemeContext): ByteSequence {
|
||||
assertTrue(obj is DummyInputClass)
|
||||
return ByteSequence.of(ByteArray(2) { 0x2 })
|
||||
}
|
||||
}
|
||||
|
||||
class SameBytesInputAndOutputsAndScheme: CustomSerializationScheme {
|
||||
|
||||
private val expectedBytes = "123456789".toByteArray()
|
||||
|
||||
override fun getSchemeId(): Int {
|
||||
return DEFAULT_SCHEME_ID
|
||||
}
|
||||
|
||||
override fun <T: Any> deserialize(bytes: ByteSequence, clazz: Class<T>, context: SerializationSchemeContext): T {
|
||||
bytes.open().use {
|
||||
val data = ByteArray(expectedBytes.size) { 0 }
|
||||
it.read(data)
|
||||
assertTrue(data.contentEquals(expectedBytes))
|
||||
}
|
||||
@Suppress("UNCHECKED_CAST")
|
||||
return DummyOutputClass() as T
|
||||
}
|
||||
|
||||
override fun <T: Any> serialize(obj: T, context: SerializationSchemeContext): ByteSequence {
|
||||
return ByteSequence.of(expectedBytes)
|
||||
}
|
||||
}
|
||||
|
||||
@Test(timeout=300_000)
|
||||
fun `CustomSerializationSchemeAdapter calls the correct methods in CustomSerializationScheme`() {
|
||||
val scheme = CustomSerializationSchemeAdapter(SingleInputAndOutputScheme())
|
||||
val serializedData = scheme.serialize(DummyInputClass(), serializationContext)
|
||||
val roundTripped = scheme.deserialize(serializedData, Any::class.java, serializationContext)
|
||||
assertTrue(roundTripped is DummyOutputClass)
|
||||
}
|
||||
|
||||
@Test(timeout=300_000)
|
||||
fun `CustomSerializationSchemeAdapter can adapt a Java implementation`() {
|
||||
val scheme = CustomSerializationSchemeAdapter(DummyCustomSerializationSchemeInJava())
|
||||
val serializedData = scheme.serialize(DummyInputClass(), serializationContext)
|
||||
val roundTripped = scheme.deserialize(serializedData, Any::class.java, serializationContext)
|
||||
assertTrue(roundTripped is DummyCustomSerializationSchemeInJava.DummyOutput)
|
||||
}
|
||||
|
||||
@Test(timeout=300_000)
|
||||
fun `CustomSerializationSchemeAdapter validates the magic`() {
|
||||
val inScheme = CustomSerializationSchemeAdapter(SingleInputAndOutputScheme())
|
||||
val serializedData = inScheme.serialize(DummyInputClass(), serializationContext)
|
||||
val outScheme = CustomSerializationSchemeAdapter(SingleInputAndOutputScheme(8))
|
||||
assertFailsWith<NotSerializableException> {
|
||||
outScheme.deserialize(serializedData, DummyOutputClass::class.java, serializationContext)
|
||||
}
|
||||
}
|
||||
|
||||
@Test(timeout=300_000)
|
||||
fun `CustomSerializationSchemeAdapter preserves the serialized bytes between deserialize and serialize`() {
|
||||
val scheme = CustomSerializationSchemeAdapter(SameBytesInputAndOutputsAndScheme())
|
||||
val serializedData = scheme.serialize(Any(), serializationContext)
|
||||
scheme.deserialize(serializedData, Any::class.java, serializationContext)
|
||||
}
|
||||
}
|
@ -0,0 +1,342 @@
|
||||
package net.corda.node
|
||||
|
||||
import co.paralleluniverse.fibers.Suspendable
|
||||
import com.esotericsoftware.kryo.Kryo
|
||||
import com.esotericsoftware.kryo.io.Input
|
||||
import com.esotericsoftware.kryo.io.Output
|
||||
import de.javakaffee.kryoserializers.ArraysAsListSerializer
|
||||
import net.corda.core.contracts.AlwaysAcceptAttachmentConstraint
|
||||
import net.corda.core.contracts.BelongsToContract
|
||||
import net.corda.core.contracts.Contract
|
||||
import net.corda.core.contracts.ContractState
|
||||
import net.corda.core.contracts.TransactionState
|
||||
import net.corda.core.contracts.TypeOnlyCommandData
|
||||
import net.corda.core.crypto.Crypto
|
||||
import net.corda.core.crypto.SecureHash
|
||||
import net.corda.core.crypto.SignableData
|
||||
import net.corda.core.crypto.SignatureMetadata
|
||||
import net.corda.core.flows.CollectSignaturesFlow
|
||||
import net.corda.core.flows.FinalityFlow
|
||||
import net.corda.core.flows.FlowLogic
|
||||
import net.corda.core.flows.FlowSession
|
||||
import net.corda.core.flows.InitiatedBy
|
||||
import net.corda.core.flows.InitiatingFlow
|
||||
import net.corda.core.flows.ReceiveFinalityFlow
|
||||
import net.corda.core.flows.SignTransactionFlow
|
||||
import net.corda.core.flows.StartableByRPC
|
||||
import net.corda.core.identity.AbstractParty
|
||||
import net.corda.core.identity.Party
|
||||
import net.corda.core.internal.concurrent.transpose
|
||||
import net.corda.core.internal.copyBytes
|
||||
import net.corda.core.messaging.startFlow
|
||||
import net.corda.core.serialization.CustomSerializationScheme
|
||||
import net.corda.core.serialization.SerializationSchemeContext
|
||||
import net.corda.core.serialization.internal.CustomSerializationSchemeUtils.Companion.getSchemeIdIfCustomSerializationMagic
|
||||
import net.corda.core.transactions.LedgerTransaction
|
||||
import net.corda.core.transactions.SignedTransaction
|
||||
import net.corda.core.transactions.TransactionBuilder
|
||||
import net.corda.core.transactions.WireTransaction
|
||||
import net.corda.core.utilities.ByteSequence
|
||||
import net.corda.core.utilities.getOrThrow
|
||||
import net.corda.core.utilities.unwrap
|
||||
import net.corda.serialization.internal.CordaSerializationMagic
|
||||
import net.corda.serialization.internal.SerializationFactoryImpl
|
||||
import net.corda.testing.core.ALICE_NAME
|
||||
import net.corda.testing.core.BOB_NAME
|
||||
import net.corda.testing.core.DUMMY_NOTARY_NAME
|
||||
import net.corda.testing.core.TestIdentity
|
||||
import net.corda.testing.driver.DriverParameters
|
||||
import net.corda.testing.driver.NodeParameters
|
||||
import net.corda.testing.driver.driver
|
||||
import net.corda.testing.node.internal.enclosedCordapp
|
||||
import org.junit.Test
|
||||
import org.objenesis.instantiator.ObjectInstantiator
|
||||
import org.objenesis.strategy.InstantiatorStrategy
|
||||
import org.objenesis.strategy.StdInstantiatorStrategy
|
||||
import java.io.ByteArrayOutputStream
|
||||
import java.lang.reflect.Modifier
|
||||
import java.util.*
|
||||
import kotlin.test.assertEquals
|
||||
import kotlin.test.assertTrue
|
||||
|
||||
class CustomSerializationSchemeDriverTest {
|
||||
|
||||
@Test(timeout = 300_000)
|
||||
fun `flow can send wire transaction serialized with custom kryo serializer`() {
|
||||
driver(DriverParameters(notarySpecs = emptyList(), startNodesInProcess = true, cordappsForAllNodes = listOf(enclosedCordapp()))) {
|
||||
val (alice, bob) = listOf(
|
||||
startNode(NodeParameters(providedName = ALICE_NAME)),
|
||||
startNode(NodeParameters(providedName = BOB_NAME))
|
||||
).transpose().getOrThrow()
|
||||
|
||||
val flow = alice.rpc.startFlow(::SendFlow, bob.nodeInfo.legalIdentities.single())
|
||||
assertTrue { flow.returnValue.getOrThrow() }
|
||||
}
|
||||
}
|
||||
|
||||
@Test(timeout = 300_000)
|
||||
fun `flow can write a wire transaction serialized with custom kryo serializer to the ledger`() {
|
||||
driver(DriverParameters(startNodesInProcess = true, cordappsForAllNodes = listOf(enclosedCordapp()))) {
|
||||
val (alice, bob) = listOf(
|
||||
startNode(NodeParameters(providedName = ALICE_NAME)),
|
||||
startNode(NodeParameters(providedName = BOB_NAME))
|
||||
).transpose().getOrThrow()
|
||||
|
||||
val flow = alice.rpc.startFlow(::WriteTxToLedgerFlow, bob.nodeInfo.legalIdentities.single(), defaultNotaryIdentity)
|
||||
val txId = flow.returnValue.getOrThrow()
|
||||
val transaction = alice.rpc.startFlow(::GetTxFromDBFlow, txId).returnValue.getOrThrow()
|
||||
|
||||
for(group in transaction!!.tx.componentGroups) {
|
||||
for (item in group.components) {
|
||||
val magic = CordaSerializationMagic(item.slice(end = SerializationFactoryImpl.magicSize).copyBytes())
|
||||
assertEquals( KryoScheme.SCHEME_ID, getSchemeIdIfCustomSerializationMagic(magic))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Test(timeout = 300_000)
|
||||
fun `Component groups are lazily serialized by the CustomSerializationScheme`() {
|
||||
driver(DriverParameters(notarySpecs = emptyList(), startNodesInProcess = true, cordappsForAllNodes = listOf(enclosedCordapp()))) {
|
||||
val alice = startNode(NodeParameters(providedName = ALICE_NAME)).getOrThrow()
|
||||
//We don't need a real notary as we don't verify the transaction in this test.
|
||||
val dummyNotary = TestIdentity(DUMMY_NOTARY_NAME, 20)
|
||||
assertTrue { alice.rpc.startFlow(::CheckComponentGroupsFlow, dummyNotary.party).returnValue.getOrThrow() }
|
||||
}
|
||||
}
|
||||
|
||||
@Test(timeout = 300_000)
|
||||
fun `Map in the serialization context can be used by lazily component group serialization`() {
|
||||
driver(DriverParameters(notarySpecs = emptyList(), startNodesInProcess = true, cordappsForAllNodes = listOf(enclosedCordapp()))) {
|
||||
val alice = startNode(NodeParameters(providedName = ALICE_NAME)).getOrThrow()
|
||||
//We don't need a real notary as we don't verify the transaction in this test.
|
||||
val dummyNotary = TestIdentity(DUMMY_NOTARY_NAME, 20)
|
||||
assertTrue { alice.rpc.startFlow(::CheckComponentGroupsWithMapFlow, dummyNotary.party).returnValue.getOrThrow() }
|
||||
}
|
||||
}
|
||||
|
||||
@StartableByRPC
|
||||
@InitiatingFlow
|
||||
class WriteTxToLedgerFlow(val counterparty: Party, val notary: Party) : FlowLogic<SecureHash>() {
|
||||
@Suspendable
|
||||
override fun call(): SecureHash {
|
||||
val outputState = TransactionState(
|
||||
data = DummyContract.DummyState(),
|
||||
contract = DummyContract::class.java.name,
|
||||
notary = notary,
|
||||
constraint = AlwaysAcceptAttachmentConstraint
|
||||
)
|
||||
val builder = TransactionBuilder()
|
||||
.addOutputState(outputState)
|
||||
.addCommand(DummyCommandData, counterparty.owningKey)
|
||||
val wireTx = builder.toWireTransaction(serviceHub, KryoScheme.SCHEME_ID)
|
||||
val partSignedTx = signWireTx(wireTx)
|
||||
val session = initiateFlow(counterparty)
|
||||
val fullySignedTx = subFlow(CollectSignaturesFlow(partSignedTx, setOf(session)))
|
||||
subFlow(FinalityFlow(fullySignedTx, setOf(session)))
|
||||
return fullySignedTx.id
|
||||
}
|
||||
|
||||
fun signWireTx(wireTx: WireTransaction) : SignedTransaction {
|
||||
val signatureMetadata = SignatureMetadata(
|
||||
serviceHub.myInfo.platformVersion,
|
||||
Crypto.findSignatureScheme(serviceHub.myInfo.legalIdentitiesAndCerts.first().owningKey).schemeNumberID
|
||||
)
|
||||
val signableData = SignableData(wireTx.id, signatureMetadata)
|
||||
val sig = serviceHub.keyManagementService.sign(signableData, serviceHub.myInfo.legalIdentitiesAndCerts.first().owningKey)
|
||||
return SignedTransaction(wireTx, listOf(sig))
|
||||
}
|
||||
}
|
||||
|
||||
@InitiatedBy(WriteTxToLedgerFlow::class)
|
||||
class SignWireTxFlow(private val session: FlowSession): FlowLogic<SignedTransaction>() {
|
||||
@Suspendable
|
||||
override fun call(): SignedTransaction {
|
||||
val signTransactionFlow = object : SignTransactionFlow(session) {
|
||||
override fun checkTransaction(stx: SignedTransaction) {
|
||||
return
|
||||
}
|
||||
}
|
||||
val txId = subFlow(signTransactionFlow).id
|
||||
return subFlow(ReceiveFinalityFlow(session, expectedTxId = txId))
|
||||
}
|
||||
}
|
||||
|
||||
@StartableByRPC
|
||||
class GetTxFromDBFlow(private val txId: SecureHash): FlowLogic<SignedTransaction?>() {
|
||||
override fun call(): SignedTransaction? {
|
||||
return serviceHub.validatedTransactions.getTransaction(txId)
|
||||
}
|
||||
}
|
||||
|
||||
@StartableByRPC
|
||||
@InitiatingFlow
|
||||
class CheckComponentGroupsFlow(val notary: Party) : FlowLogic<Boolean>() {
|
||||
@Suspendable
|
||||
override fun call(): Boolean {
|
||||
val outputState = TransactionState(
|
||||
data = DummyContract.DummyState(),
|
||||
contract = DummyContract::class.java.name,
|
||||
notary = notary,
|
||||
constraint = AlwaysAcceptAttachmentConstraint
|
||||
)
|
||||
val builder = TransactionBuilder()
|
||||
.addOutputState(outputState)
|
||||
.addCommand(DummyCommandData, notary.owningKey)
|
||||
|
||||
val wtx = builder.toWireTransaction(serviceHub, KryoScheme.SCHEME_ID)
|
||||
var success = true
|
||||
for (group in wtx.componentGroups) {
|
||||
//Component groups are lazily serialized as we iterate through.
|
||||
for (item in group.components) {
|
||||
val magic = CordaSerializationMagic(item.slice(end = SerializationFactoryImpl.magicSize).copyBytes())
|
||||
success = success && (getSchemeIdIfCustomSerializationMagic(magic) == KryoScheme.SCHEME_ID)
|
||||
}
|
||||
}
|
||||
return success
|
||||
}
|
||||
}
|
||||
|
||||
@StartableByRPC
|
||||
@InitiatingFlow
|
||||
class CheckComponentGroupsWithMapFlow(val notary: Party) : FlowLogic<Boolean>() {
|
||||
@Suspendable
|
||||
override fun call(): Boolean {
|
||||
val outputState = TransactionState(
|
||||
data = DummyContract.DummyState(),
|
||||
contract = DummyContract::class.java.name,
|
||||
notary = notary,
|
||||
constraint = AlwaysAcceptAttachmentConstraint
|
||||
)
|
||||
val builder = TransactionBuilder()
|
||||
.addOutputState(outputState)
|
||||
.addCommand(DummyCommandData, notary.owningKey)
|
||||
val mapToCheckWhenSerializing = mapOf<Any, Any>(Pair(KryoSchemeWithMap.KEY, KryoSchemeWithMap.VALUE))
|
||||
val wtx = builder.toWireTransaction(serviceHub, KryoSchemeWithMap.SCHEME_ID, mapToCheckWhenSerializing)
|
||||
var success = true
|
||||
for (group in wtx.componentGroups) {
|
||||
//Component groups are lazily serialized as we iterate through.
|
||||
for (item in group.components) {
|
||||
val magic = CordaSerializationMagic(item.slice(end = SerializationFactoryImpl.magicSize).copyBytes())
|
||||
success = success && (getSchemeIdIfCustomSerializationMagic(magic) == KryoSchemeWithMap.SCHEME_ID)
|
||||
}
|
||||
}
|
||||
return success
|
||||
}
|
||||
}
|
||||
|
||||
@StartableByRPC
|
||||
@InitiatingFlow
|
||||
class SendFlow(val counterparty: Party) : FlowLogic<Boolean>() {
|
||||
@Suspendable
|
||||
override fun call(): Boolean {
|
||||
val outputState = TransactionState(
|
||||
data = DummyContract.DummyState(),
|
||||
contract = DummyContract::class.java.name,
|
||||
notary = counterparty,
|
||||
constraint = AlwaysAcceptAttachmentConstraint
|
||||
)
|
||||
val builder = TransactionBuilder()
|
||||
.addOutputState(outputState)
|
||||
.addCommand(DummyCommandData, counterparty.owningKey)
|
||||
|
||||
val wtx = builder.toWireTransaction(serviceHub, KryoScheme.SCHEME_ID)
|
||||
val session = initiateFlow(counterparty)
|
||||
session.send(wtx)
|
||||
return session.receive<Boolean>().unwrap {it}
|
||||
}
|
||||
}
|
||||
|
||||
@InitiatedBy(SendFlow::class)
|
||||
class ReceiveFlow(private val session: FlowSession): FlowLogic<Unit>() {
|
||||
@Suspendable
|
||||
override fun call() {
|
||||
val message = session.receive<WireTransaction>().unwrap {it}
|
||||
message.toLedgerTransaction(serviceHub)
|
||||
session.send(true)
|
||||
}
|
||||
}
|
||||
|
||||
class DummyContract: Contract {
|
||||
@BelongsToContract(DummyContract::class)
|
||||
class DummyState(override val participants: List<AbstractParty> = listOf()) : ContractState
|
||||
override fun verify(tx: LedgerTransaction) {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
object DummyCommandData : TypeOnlyCommandData()
|
||||
|
||||
open class KryoScheme : CustomSerializationScheme {
|
||||
|
||||
companion object {
|
||||
const val SCHEME_ID = 7
|
||||
}
|
||||
|
||||
override fun getSchemeId(): Int {
|
||||
return SCHEME_ID
|
||||
}
|
||||
|
||||
override fun <T : Any> deserialize(bytes: ByteSequence, clazz: Class<T>, context: SerializationSchemeContext): T {
|
||||
val kryo = Kryo()
|
||||
customiseKryo(kryo, context.deserializationClassLoader)
|
||||
|
||||
val obj = Input(bytes.open()).use {
|
||||
kryo.readClassAndObject(it)
|
||||
}
|
||||
@Suppress("UNCHECKED_CAST")
|
||||
return obj as T
|
||||
}
|
||||
|
||||
override fun <T : Any> serialize(obj: T, context: SerializationSchemeContext): ByteSequence {
|
||||
val kryo = Kryo()
|
||||
customiseKryo(kryo, context.deserializationClassLoader)
|
||||
|
||||
val outputStream = ByteArrayOutputStream()
|
||||
Output(outputStream).use {
|
||||
kryo.writeClassAndObject(it, obj)
|
||||
}
|
||||
return ByteSequence.of(outputStream.toByteArray())
|
||||
}
|
||||
|
||||
private fun customiseKryo(kryo: Kryo, classLoader: ClassLoader) {
|
||||
kryo.instantiatorStrategy = CustomInstantiatorStrategy()
|
||||
kryo.classLoader = classLoader
|
||||
kryo.register(Arrays.asList("").javaClass, ArraysAsListSerializer())
|
||||
}
|
||||
|
||||
//Stolen from DefaultKryoCustomizer.kt
|
||||
private class CustomInstantiatorStrategy : InstantiatorStrategy {
|
||||
private val fallbackStrategy = StdInstantiatorStrategy()
|
||||
|
||||
// Use this to allow construction of objects using a JVM backdoor that skips invoking the constructors, if there
|
||||
// is no no-arg constructor available.
|
||||
private val defaultStrategy = Kryo.DefaultInstantiatorStrategy(fallbackStrategy)
|
||||
|
||||
override fun <T> newInstantiatorOf(type: Class<T>): ObjectInstantiator<T> {
|
||||
// However this doesn't work for non-public classes in the java. namespace
|
||||
val strat = if (type.name.startsWith("java.") && !Modifier.isPublic(type.modifiers)) fallbackStrategy else defaultStrategy
|
||||
return strat.newInstantiatorOf(type)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
class KryoSchemeWithMap : KryoScheme() {
|
||||
|
||||
companion object {
|
||||
const val SCHEME_ID = 8
|
||||
const val KEY = "Key"
|
||||
const val VALUE = "Value"
|
||||
}
|
||||
|
||||
override fun getSchemeId(): Int {
|
||||
return SCHEME_ID
|
||||
}
|
||||
|
||||
override fun <T : Any> serialize(obj: T, context: SerializationSchemeContext): ByteSequence {
|
||||
assertEquals(VALUE, context.properties[KEY])
|
||||
return super.serialize(obj, context)
|
||||
}
|
||||
|
||||
}
|
||||
}
|
@ -73,6 +73,7 @@ import net.corda.node.utilities.DemoClock
|
||||
import net.corda.node.utilities.errorAndTerminate
|
||||
import net.corda.nodeapi.internal.ArtemisMessagingClient
|
||||
import net.corda.common.logging.errorReporting.NodeDatabaseErrors
|
||||
import net.corda.node.internal.classloading.scanForCustomSerializationScheme
|
||||
import net.corda.nodeapi.internal.ShutdownHook
|
||||
import net.corda.nodeapi.internal.addShutdownHook
|
||||
import net.corda.nodeapi.internal.bridging.BridgeControlListener
|
||||
@ -647,10 +648,14 @@ open class Node(configuration: NodeConfiguration,
|
||||
private fun initialiseSerialization() {
|
||||
if (!initialiseSerialization) return
|
||||
val classloader = cordappLoader.appClassLoader
|
||||
val customScheme = System.getProperty("experimental.corda.customSerializationScheme")?.let {
|
||||
scanForCustomSerializationScheme(it, classloader)
|
||||
}
|
||||
nodeSerializationEnv = SerializationEnvironment.with(
|
||||
SerializationFactoryImpl().apply {
|
||||
registerScheme(AMQPServerSerializationScheme(cordappLoader.cordapps, Caffeine.newBuilder().maximumSize(128).build<SerializationFactoryCacheKey, SerializerFactory>().asMap()))
|
||||
registerScheme(AMQPClientSerializationScheme(cordappLoader.cordapps, Caffeine.newBuilder().maximumSize(128).build<SerializationFactoryCacheKey, SerializerFactory>().asMap()))
|
||||
customScheme?.let{ registerScheme(it) }
|
||||
},
|
||||
p2pContext = AMQP_P2P_CONTEXT.withClassLoader(classloader),
|
||||
rpcServerContext = AMQP_RPC_SERVER_CONTEXT.withClassLoader(classloader),
|
||||
|
@ -2,6 +2,31 @@
|
||||
|
||||
package net.corda.node.internal.classloading
|
||||
|
||||
import net.corda.core.serialization.CustomSerializationScheme
|
||||
import net.corda.node.internal.ConfigurationException
|
||||
import net.corda.nodeapi.internal.serialization.CustomSerializationSchemeAdapter
|
||||
import net.corda.serialization.internal.SerializationScheme
|
||||
import java.lang.reflect.Constructor
|
||||
|
||||
inline fun <reified A : Annotation> Class<*>.requireAnnotation(): A {
|
||||
return requireNotNull(getDeclaredAnnotation(A::class.java)) { "$name needs to be annotated with ${A::class.java.name}" }
|
||||
}
|
||||
|
||||
fun scanForCustomSerializationScheme(className: String, classLoader: ClassLoader) : SerializationScheme {
|
||||
val schemaClass = try {
|
||||
classLoader.loadClass(className)
|
||||
} catch (exception: ClassNotFoundException) {
|
||||
throw ConfigurationException("$className was declared as a custom serialization scheme but could not be found.")
|
||||
}
|
||||
val constructor = validateScheme(schemaClass, className)
|
||||
return CustomSerializationSchemeAdapter(constructor.newInstance() as CustomSerializationScheme)
|
||||
}
|
||||
|
||||
private fun validateScheme(clazz: Class<*>, className: String): Constructor<*> {
|
||||
if (!clazz.interfaces.contains(CustomSerializationScheme::class.java)) {
|
||||
throw ConfigurationException("$className was declared as a custom serialization scheme but does not implement" +
|
||||
" ${CustomSerializationScheme::class.java.canonicalName}")
|
||||
}
|
||||
return clazz.constructors.singleOrNull { it.parameters.isEmpty() } ?: throw ConfigurationException("$className was declared as a " +
|
||||
"custom serialization scheme but does not have a no argument constructor.")
|
||||
}
|
@ -0,0 +1,78 @@
|
||||
package net.corda.node.internal
|
||||
|
||||
import com.nhaarman.mockito_kotlin.whenever
|
||||
import net.corda.core.serialization.CustomSerializationScheme
|
||||
import net.corda.core.serialization.SerializationContext
|
||||
import net.corda.core.serialization.SerializationSchemeContext
|
||||
import net.corda.core.utilities.ByteSequence
|
||||
import net.corda.node.internal.classloading.scanForCustomSerializationScheme
|
||||
import org.junit.Test
|
||||
import org.mockito.Mockito
|
||||
import kotlin.test.assertFailsWith
|
||||
|
||||
class CustomSerializationSchemeScanningTest {
|
||||
|
||||
class NonSerializationScheme
|
||||
|
||||
open class DummySerializationScheme : CustomSerializationScheme {
|
||||
override fun getSchemeId(): Int {
|
||||
return 7;
|
||||
}
|
||||
|
||||
override fun <T : Any> deserialize(bytes: ByteSequence, clazz: Class<T>, context: SerializationSchemeContext): T {
|
||||
throw DummySerializationSchemeException("We should never get here.")
|
||||
}
|
||||
|
||||
override fun <T : Any> serialize(obj: T, context: SerializationSchemeContext): ByteSequence {
|
||||
throw DummySerializationSchemeException("Tried to serialize with DummySerializationScheme")
|
||||
}
|
||||
}
|
||||
|
||||
class DummySerializationSchemeException(override val message: String) : RuntimeException(message)
|
||||
|
||||
class DummySerializationSchemeWithoutNoArgConstructor(val myArgument: String) : DummySerializationScheme()
|
||||
|
||||
@Test(timeout = 300_000)
|
||||
fun `Can scan for custom serialization scheme and build a serialization scheme`() {
|
||||
val classLoader = Mockito.mock(ClassLoader::class.java)
|
||||
whenever(classLoader.loadClass(DummySerializationScheme::class.java.canonicalName)).thenAnswer { DummySerializationScheme::class.java }
|
||||
val scheme = scanForCustomSerializationScheme(DummySerializationScheme::class.java.canonicalName, classLoader)
|
||||
val mockContext = Mockito.mock(SerializationContext::class.java)
|
||||
assertFailsWith<DummySerializationSchemeException>("Tried to serialize with DummySerializationScheme") {
|
||||
scheme.serialize(Any::class.java, mockContext)
|
||||
}
|
||||
}
|
||||
|
||||
@Test(timeout = 300_000)
|
||||
fun `verification fails with a helpful error if the class is not found in the classloader`() {
|
||||
val classLoader = Mockito.mock(ClassLoader::class.java)
|
||||
val missingClassName = DummySerializationScheme::class.java.canonicalName
|
||||
whenever(classLoader.loadClass(missingClassName)).thenAnswer { throw ClassNotFoundException()}
|
||||
assertFailsWith<ConfigurationException>("$missingClassName was declared as a custom serialization scheme but could not " +
|
||||
"be found.") {
|
||||
scanForCustomSerializationScheme(missingClassName, classLoader)
|
||||
}
|
||||
}
|
||||
|
||||
@Test(timeout = 300_000)
|
||||
fun `verification fails with a helpful error if the class is not a custom serialization scheme`() {
|
||||
val canonicalName = NonSerializationScheme::class.java.canonicalName
|
||||
val classLoader = Mockito.mock(ClassLoader::class.java)
|
||||
whenever(classLoader.loadClass(canonicalName)).thenAnswer { NonSerializationScheme::class.java }
|
||||
assertFailsWith<ConfigurationException>("$canonicalName was declared as a custom serialization scheme but does not " +
|
||||
"implement CustomSerializationScheme.") {
|
||||
scanForCustomSerializationScheme(canonicalName, classLoader)
|
||||
}
|
||||
}
|
||||
|
||||
@Test(timeout = 300_000)
|
||||
fun `verification fails with a helpful error if the class does not have a no arg constructor`() {
|
||||
val classLoader = Mockito.mock(ClassLoader::class.java)
|
||||
val canonicalName = DummySerializationSchemeWithoutNoArgConstructor::class.java.canonicalName
|
||||
whenever(classLoader.loadClass(canonicalName)).thenAnswer { DummySerializationSchemeWithoutNoArgConstructor::class.java }
|
||||
assertFailsWith<ConfigurationException>("$canonicalName was declared as a custom serialization scheme but does not " +
|
||||
"have a no argument constructor.") {
|
||||
scanForCustomSerializationScheme(canonicalName, classLoader)
|
||||
}
|
||||
}
|
||||
}
|
@ -6,6 +6,7 @@ import net.corda.core.crypto.SecureHash
|
||||
import net.corda.core.internal.VisibleForTesting
|
||||
import net.corda.core.internal.copyBytes
|
||||
import net.corda.core.serialization.*
|
||||
import net.corda.core.serialization.internal.CustomSerializationSchemeUtils.Companion.getSchemeIdIfCustomSerializationMagic
|
||||
import net.corda.core.utilities.ByteSequence
|
||||
import net.corda.serialization.internal.amqp.amqpMagic
|
||||
import org.slf4j.LoggerFactory
|
||||
@ -47,6 +48,10 @@ data class SerializationContextImpl @JvmOverloads constructor(override val prefe
|
||||
return copy(properties = properties + (property to value))
|
||||
}
|
||||
|
||||
override fun withProperties(extraProperties: Map<Any, Any>): SerializationContext {
|
||||
return copy(properties = properties + extraProperties)
|
||||
}
|
||||
|
||||
override fun withoutReferences(): SerializationContext {
|
||||
return copy(objectReferencesEnabled = false)
|
||||
}
|
||||
@ -106,7 +111,9 @@ open class SerializationFactoryImpl(
|
||||
registeredSchemes.filter { it.canDeserializeVersion(magic, target) }.forEach { return@computeIfAbsent it } // XXX: Not single?
|
||||
logger.warn("Cannot find serialization scheme for: [$lookupKey, " +
|
||||
"${if (magic == amqpMagic) "AMQP" else "UNKNOWN MAGIC"}] registeredSchemes are: $registeredSchemes")
|
||||
throw UnsupportedOperationException("Serialization scheme $lookupKey not supported.")
|
||||
val schemeId = getSchemeIdIfCustomSerializationMagic(magic) ?: throw UnsupportedOperationException("Serialization scheme" +
|
||||
" $lookupKey not supported.")
|
||||
throw UnsupportedOperationException("Could not find custom serialization scheme with SchemeId = $schemeId.")
|
||||
}) to magic
|
||||
}
|
||||
|
||||
|
@ -3,9 +3,11 @@ package net.corda.coretesting.internal
|
||||
import net.corda.nodeapi.internal.rpc.client.AMQPClientSerializationScheme
|
||||
import net.corda.core.internal.createInstancesOfClassesImplementing
|
||||
import net.corda.core.serialization.CheckpointCustomSerializer
|
||||
import net.corda.core.serialization.CustomSerializationScheme
|
||||
import net.corda.core.serialization.SerializationCustomSerializer
|
||||
import net.corda.core.serialization.SerializationWhitelist
|
||||
import net.corda.core.serialization.internal.SerializationEnvironment
|
||||
import net.corda.nodeapi.internal.serialization.CustomSerializationSchemeAdapter
|
||||
import net.corda.nodeapi.internal.serialization.amqp.AMQPServerSerializationScheme
|
||||
import net.corda.nodeapi.internal.serialization.kryo.KRYO_CHECKPOINT_CONTEXT
|
||||
import net.corda.nodeapi.internal.serialization.kryo.KryoCheckpointSerializer
|
||||
@ -14,6 +16,7 @@ import net.corda.serialization.internal.AMQP_RPC_CLIENT_CONTEXT
|
||||
import net.corda.serialization.internal.AMQP_RPC_SERVER_CONTEXT
|
||||
import net.corda.serialization.internal.AMQP_STORAGE_CONTEXT
|
||||
import net.corda.serialization.internal.SerializationFactoryImpl
|
||||
import net.corda.serialization.internal.SerializationScheme
|
||||
import net.corda.testing.common.internal.asContextEnv
|
||||
import java.util.ServiceLoader
|
||||
import java.util.concurrent.ConcurrentHashMap
|
||||
@ -27,20 +30,29 @@ fun createTestSerializationEnv(): SerializationEnvironment {
|
||||
|
||||
fun createTestSerializationEnv(classLoader: ClassLoader?): SerializationEnvironment {
|
||||
var customCheckpointSerializers: Set<CheckpointCustomSerializer<*, *>> = emptySet()
|
||||
val (clientSerializationScheme, serverSerializationScheme) = if (classLoader != null) {
|
||||
val serializationSchemes: MutableList<SerializationScheme> = mutableListOf()
|
||||
if (classLoader != null) {
|
||||
val customSerializers = createInstancesOfClassesImplementing(classLoader, SerializationCustomSerializer::class.java)
|
||||
customCheckpointSerializers = createInstancesOfClassesImplementing(classLoader, CheckpointCustomSerializer::class.java)
|
||||
|
||||
val serializationWhitelists = ServiceLoader.load(SerializationWhitelist::class.java, classLoader).toSet()
|
||||
|
||||
Pair(AMQPClientSerializationScheme(customSerializers, serializationWhitelists),
|
||||
AMQPServerSerializationScheme(customSerializers, serializationWhitelists))
|
||||
serializationSchemes.add(AMQPClientSerializationScheme(customSerializers, serializationWhitelists))
|
||||
serializationSchemes.add(AMQPServerSerializationScheme(customSerializers, serializationWhitelists))
|
||||
|
||||
val customSchemes = createInstancesOfClassesImplementing(classLoader, CustomSerializationScheme::class.java)
|
||||
for (customScheme in customSchemes) {
|
||||
serializationSchemes.add(CustomSerializationSchemeAdapter(customScheme))
|
||||
}
|
||||
} else {
|
||||
Pair(AMQPClientSerializationScheme(emptyList()), AMQPServerSerializationScheme(emptyList()))
|
||||
serializationSchemes.add(AMQPClientSerializationScheme(emptyList()))
|
||||
serializationSchemes.add(AMQPServerSerializationScheme(emptyList()))
|
||||
}
|
||||
|
||||
val factory = SerializationFactoryImpl().apply {
|
||||
registerScheme(clientSerializationScheme)
|
||||
registerScheme(serverSerializationScheme)
|
||||
for (serializationScheme in serializationSchemes) {
|
||||
registerScheme(serializationScheme)
|
||||
}
|
||||
}
|
||||
return SerializationEnvironment.with(
|
||||
factory,
|
||||
|
Loading…
Reference in New Issue
Block a user