CORDA-4103 Feature Branch: Serialization injection for transaction building (#6867)

* CORDA-4105 Add public API to allow custom serialization schemes (#6848)

* CORDA-4105 Add public API to allow custom serialization schemes

* Fix Detekt

* Suppress warning

* Fix usused import

* Improve API to use generics

This does not break Java support (only Intelij gets confused).

* Add more detailed documentation to public interfaces

* Change internal variable name after rename public API

* Update Public API to use ByteSquence instead of SerializedBytes

* Remove unused import

* Fix whitespace.

* Add added public API to .ci/api-current.txt

* Improve public interfaces

Rename CustomSchemeContext to SerializationSchemeContext to improve
clarity and move to it's own file. Improve kdoc to make things less
confusing.

* Update API current with changed API

* CORDA-4104 Implement custom serialization scheme discovery mechanism (#6854)

* CORDA-4104 Implement CustomSerializationScheme Discovery Mechanism

Discovers a single CustomSerializationScheme implementation inside
the drivers dir using a system property.

* Started MockNetwork test

* Add driver test of Custom Serialization Scheme

* Fix detekt and minor style error

* Respond to review comments

Allow non-single arg constructors (there must be one no args
constructor), move code from SerializationEnviroment into its
own file, improve exceptions to be more user friendly.

* Fix minor bug in Scheme finding code  + improve error messages

* CORDA-4104 Improve test coverage of custom serialization scheme discovery (#6855)

* CORDA-4104 Add test of classloader scanning for CustomSerializationSchemes

* Fix Detekt

* NOTICK Clarify KDOC on SerializationSchemeContext (#6865)

* CORDA-4111 Change Component Group Serialization to use contex when the lazy map is constructed (#6856)

Currently the component group will recheck the thread local (global)
serialization context when component groups are serialized lazily.
Instead store the serialization context when the lazy map is constructed
and use that latter when doing serialization lazily.

* CORDA-4106 Test wire transaction can still be written to the ledger (#6860)

* Add test that writes transaction to the Database

* Improve test check serialization scheme in test body

* CORDA-4119 Minor changes to serialisation injection for transaction building (#6868)

* CORDA-4119 Minor changes to serialisation injection for transaction building

Scan the CorDapp classloader instead of the drivers classloader.
Add properties map to CustomSerialiaztionContext (copied from SerializationContext).
Change API to let a user pass in the serialization context in TransactionBuilder.toLedgerTransaction

* Improve KDOC + fix shawdowing issue in CordaUtils

* Pass only the properties map into theTransactionBuilder.toWireTransaction

Not the entire serializationContext

* Revert change to CordaUtils

* Improve KDOC explain pitfalls of setting properties
This commit is contained in:
William Vigor 2021-02-11 15:27:03 +00:00 committed by GitHub
parent 4f336a1a67
commit 20dbbf008d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
17 changed files with 838 additions and 14 deletions

View File

@ -5937,6 +5937,13 @@ public @interface net.corda.core.serialization.CordaSerializationTransformRename
public @interface net.corda.core.serialization.CordaSerializationTransformRenames
public abstract net.corda.core.serialization.CordaSerializationTransformRename[] value()
##
public interface net.corda.core.serialization.CustomSerializationScheme
@NotNull
public abstract T deserialize(net.corda.core.utilities.ByteSequence, Class<T>, net.corda.core.serialization.SerializationSchemeContext)
public abstract int getSchemeId()
@NotNull
public abstract net.corda.core.utilities.ByteSequence serialize(T, net.corda.core.serialization.SerializationSchemeContext)
##
public @interface net.corda.core.serialization.DeprecatedConstructorForDeserialization
public abstract int version()
##
@ -6078,6 +6085,13 @@ public static final class net.corda.core.serialization.SerializationFactory$Comp
@NotNull
public final net.corda.core.serialization.SerializationFactory getDefaultFactory()
##
@DoNotImplement
public interface net.corda.core.serialization.SerializationSchemeContext
@NotNull
public abstract ClassLoader getDeserializationClassLoader()
@NotNull
public abstract net.corda.core.serialization.ClassWhitelist getWhitelist()
##
public interface net.corda.core.serialization.SerializationToken
@NotNull
public abstract Object fromToken(net.corda.core.serialization.SerializeAsTokenContext)

View File

@ -3,7 +3,16 @@ package net.corda.coretests.transactions
import com.nhaarman.mockito_kotlin.doReturn
import com.nhaarman.mockito_kotlin.mock
import com.nhaarman.mockito_kotlin.whenever
import net.corda.core.contracts.*
import net.corda.core.contracts.Command
import net.corda.core.contracts.ContractAttachment
import net.corda.core.contracts.HashAttachmentConstraint
import net.corda.core.contracts.PrivacySalt
import net.corda.core.contracts.SignatureAttachmentConstraint
import net.corda.core.contracts.StateAndRef
import net.corda.core.contracts.StateRef
import net.corda.core.contracts.TimeWindow
import net.corda.core.contracts.TransactionState
import net.corda.core.contracts.TransactionVerificationException
import net.corda.core.cordapp.CordappProvider
import net.corda.core.crypto.CompositeKey
import net.corda.core.crypto.DigestService
@ -20,11 +29,16 @@ import net.corda.core.node.services.IdentityService
import net.corda.core.node.services.NetworkParametersService
import net.corda.core.serialization.serialize
import net.corda.core.transactions.TransactionBuilder
import net.corda.coretesting.internal.rigorousMock
import net.corda.testing.common.internal.testNetworkParameters
import net.corda.testing.contracts.DummyContract
import net.corda.testing.contracts.DummyState
import net.corda.testing.core.*
import net.corda.coretesting.internal.rigorousMock
import net.corda.testing.core.ALICE_NAME
import net.corda.testing.core.BOB_NAME
import net.corda.testing.core.DUMMY_NOTARY_NAME
import net.corda.testing.core.DummyCommandData
import net.corda.testing.core.SerializationEnvironmentRule
import net.corda.testing.core.TestIdentity
import org.assertj.core.api.Assertions.assertThat
import org.assertj.core.api.Assertions.assertThatThrownBy
import org.junit.Assert.assertFalse
@ -35,6 +49,7 @@ import org.junit.Rule
import org.junit.Test
import java.security.PublicKey
import java.time.Instant
import kotlin.test.assertFailsWith
class TransactionBuilderTest {
@Rule
@ -299,4 +314,22 @@ class TransactionBuilderTest {
HashAgility.init()
}
}
@Test(timeout=300_000)
fun `toWireTransaction fails if no scheme is registered with schemeId`() {
val outputState = TransactionState(
data = DummyState(),
contract = DummyContract.PROGRAM_ID,
notary = notary,
constraint = HashAttachmentConstraint(contractAttachmentId)
)
val builder = TransactionBuilder()
.addOutputState(outputState)
.addCommand(DummyCommandData, notary.owningKey)
val schemeId = 7
assertFailsWith<UnsupportedOperationException>("Could not find custom serialization scheme with SchemeId = $schemeId.") {
builder.toWireTransaction(services, schemeId)
}
}
}

View File

@ -154,7 +154,8 @@ fun createComponentGroups(inputs: List<StateRef>,
timeWindow: TimeWindow?,
references: List<StateRef>,
networkParametersHash: SecureHash?): List<ComponentGroup> {
val serialize = { value: Any, _: Int -> value.serialize() }
val serializationContext = SerializationFactory.defaultFactory.defaultContext
val serialize = { value: Any, _: Int -> value.serialize(context = serializationContext) }
val componentGroupMap: MutableList<ComponentGroup> = mutableListOf()
if (inputs.isNotEmpty()) componentGroupMap.add(ComponentGroup(ComponentGroupEnum.INPUTS_GROUP.ordinal, inputs.lazyMapped(serialize)))
if (references.isNotEmpty()) componentGroupMap.add(ComponentGroup(ComponentGroupEnum.REFERENCES_GROUP.ordinal, references.lazyMapped(serialize)))

View File

@ -0,0 +1,35 @@
package net.corda.core.serialization
import net.corda.core.utilities.ByteSequence
import java.io.NotSerializableException
/***
* Implement this interface to add your own Serialization Scheme. This is an experimental feature. All methods in this class MUST be
* thread safe i.e. methods from the same instance of this class can be called in different threads simultaneously.
*/
interface CustomSerializationScheme {
/**
* This method must return an id used to uniquely identify the Scheme. This should be unique within a network as serialized data might
* be sent over the wire.
*/
fun getSchemeId(): Int
/**
* This method must deserialize the data stored [bytes] into an instance of [T].
*
* @param bytes the serialized data.
* @param clazz the class to instantiate.
* @param context used to pass information about how the object should be deserialized.
*/
@Throws(NotSerializableException::class)
fun <T : Any> deserialize(bytes: ByteSequence, clazz: Class<T>, context: SerializationSchemeContext): T
/**
* This method must be able to serialize any object [T] into a ByteSequence.
*
* @param obj the object to be serialized.
* @param context used to pass information about how the object should be serialized.
*/
@Throws(NotSerializableException::class)
fun <T : Any> serialize(obj: T, context: SerializationSchemeContext): ByteSequence
}

View File

@ -134,7 +134,7 @@ interface SerializationContext {
*/
val encodingWhitelist: EncodingWhitelist
/**
* A map of any addition properties specific to the particular use case.
* A map of any additional properties specific to the particular use case.
*/
val properties: Map<Any, Any>
/**
@ -178,6 +178,11 @@ interface SerializationContext {
*/
fun withProperty(property: Any, value: Any): SerializationContext
/**
* Helper method to return a new context based on this context with the extra properties added.
*/
fun withProperties(extraProperties: Map<Any, Any>): SerializationContext
/**
* Helper method to return a new context based on this context with object references disabled.
*/

View File

@ -0,0 +1,30 @@
package net.corda.core.serialization
import net.corda.core.DoNotImplement
/**
* This is used to pass information into [CustomSerializationScheme] about how the object should be (de)serialized.
* This context can change depending on the specific circumstances in the node when (de)serialization occurs.
*/
@DoNotImplement
interface SerializationSchemeContext {
/**
* The class loader to use for deserialization. This is guaranteed to be able to load all the required classes passed into
* [CustomSerializationScheme.deserialize].
*/
val deserializationClassLoader: ClassLoader
/**
* A whitelist that contains (mostly for security purposes) which classes are authorised to be deserialized.
* A secure implementation will not instantiate any object which is not either whitelisted or annotated with [CordaSerializable] when
* deserializing. To catch classes missing from the whitelist as early as possible it is HIGHLY recommended to also check this
* whitelist when serializing (as well as deserializing) objects.
*/
val whitelist: ClassWhitelist
/**
* A map of any additional properties specific to the particular use case. If these properties are set via
* [toWireTransaction][net.corda.core.transactions.TransactionBuilder.toWireTransaction] then they might not be available when
* deserializing. If the properties are required when deserializing, they can be added into the blob when serializing and read back
* when deserializing.
*/
val properties: Map<Any, Any>
}

View File

@ -0,0 +1,28 @@
package net.corda.core.serialization.internal
import net.corda.core.KeepForDJVM
import net.corda.core.serialization.SerializationMagic
import net.corda.core.utilities.ByteSequence
import java.nio.ByteBuffer
class CustomSerializationSchemeUtils {
@KeepForDJVM
companion object {
private const val SERIALIZATION_SCHEME_ID_SIZE = 4
private val PREFIX = "CUS".toByteArray()
fun getCustomSerializationMagicFromSchemeId(schemeId: Int) : SerializationMagic {
return SerializationMagic.of(PREFIX + ByteBuffer.allocate(SERIALIZATION_SCHEME_ID_SIZE).putInt(schemeId).array())
}
fun getSchemeIdIfCustomSerializationMagic(magic: SerializationMagic): Int? {
return if (magic.take(PREFIX.size) != ByteSequence.of(PREFIX)) {
null
} else {
return magic.slice(start = PREFIX.size).int
}
}
}
}

View File

@ -16,14 +16,18 @@ import net.corda.core.node.ServicesForResolution
import net.corda.core.node.ZoneVersionTooLowException
import net.corda.core.node.services.AttachmentId
import net.corda.core.node.services.KeyManagementService
import net.corda.core.serialization.CustomSerializationScheme
import net.corda.core.serialization.SerializationContext
import net.corda.core.serialization.SerializationDefaults
import net.corda.core.serialization.SerializationFactory
import net.corda.core.serialization.SerializationMagic
import net.corda.core.serialization.SerializationSchemeContext
import net.corda.core.serialization.internal.CustomSerializationSchemeUtils.Companion.getCustomSerializationMagicFromSchemeId
import net.corda.core.utilities.contextLogger
import java.security.PublicKey
import java.time.Duration
import java.time.Instant
import java.util.ArrayDeque
import java.util.UUID
import java.util.*
import java.util.regex.Pattern
import kotlin.collections.ArrayList
import kotlin.collections.component1
@ -140,6 +144,41 @@ open class TransactionBuilder(
fun toWireTransaction(services: ServicesForResolution): WireTransaction = toWireTransactionWithContext(services, null)
.apply { checkSupportedHashType() }
/**
* Generates a [WireTransaction] from this builder, resolves any [AutomaticPlaceholderConstraint], and selects the attachments to use for this transaction.
*
* @param [schemeId] is used to specify the [CustomSerializationScheme] used to serialize each component of the componentGroups of the [WireTransaction].
* This is an experimental feature.
*
* @returns A new [WireTransaction] that will be unaffected by further changes to this [TransactionBuilder].
*
* @throws [ZoneVersionTooLowException] if there are reference states and the zone minimum platform version is less than 4.
*/
@Throws(MissingContractAttachments::class)
fun toWireTransaction(services: ServicesForResolution, schemeId: Int): WireTransaction {
return toWireTransaction(services, schemeId, emptyMap()).apply { checkSupportedHashType() }
}
/**
* Generates a [WireTransaction] from this builder, resolves any [AutomaticPlaceholderConstraint], and selects the attachments to use for this transaction.
*
* @param [schemeId] is used to specify the [CustomSerializationScheme] used to serialize each component of the componentGroups of the [WireTransaction].
* This is an experimental feature.
*
* @param [properties] a list of properties to add to the [SerializationSchemeContext] these properties can be accessed in [CustomSerializationScheme.serialize]
* when serializing the componentGroups of the wire transaction but might not be available when deserializing.
*
* @returns A new [WireTransaction] that will be unaffected by further changes to this [TransactionBuilder].
*
* @throws [ZoneVersionTooLowException] if there are reference states and the zone minimum platform version is less than 4.
*/
@Throws(MissingContractAttachments::class)
fun toWireTransaction(services: ServicesForResolution, schemeId: Int, properties: Map<Any, Any>): WireTransaction {
val magic: SerializationMagic = getCustomSerializationMagicFromSchemeId(schemeId)
val serializationContext = SerializationDefaults.P2P_CONTEXT.withPreferredSerializationVersion(magic).withProperties(properties)
return toWireTransactionWithContext(services, serializationContext).apply { checkSupportedHashType() }
}
@CordaInternal
internal fun toWireTransactionWithContext(
services: ServicesForResolution,

View File

@ -0,0 +1,47 @@
package net.corda.nodeapi.internal.serialization
import net.corda.core.serialization.SerializationSchemeContext
import net.corda.core.serialization.CustomSerializationScheme
import net.corda.core.serialization.SerializationContext
import net.corda.core.serialization.SerializedBytes
import net.corda.core.serialization.internal.CustomSerializationSchemeUtils.Companion.getCustomSerializationMagicFromSchemeId
import net.corda.core.utilities.ByteSequence
import net.corda.serialization.internal.CordaSerializationMagic
import net.corda.serialization.internal.SerializationScheme
import java.io.ByteArrayOutputStream
import java.io.NotSerializableException
class CustomSerializationSchemeAdapter(private val customScheme: CustomSerializationScheme): SerializationScheme {
val serializationSchemeMagic = getCustomSerializationMagicFromSchemeId(customScheme.getSchemeId())
override fun canDeserializeVersion(magic: CordaSerializationMagic, target: SerializationContext.UseCase): Boolean {
return magic == serializationSchemeMagic
}
override fun <T : Any> deserialize(byteSequence: ByteSequence, clazz: Class<T>, context: SerializationContext): T {
val readMagic = byteSequence.take(serializationSchemeMagic.size)
if (readMagic != serializationSchemeMagic) {
throw NotSerializableException("Scheme ${customScheme::class.java} is incompatible with blob." +
" Magic from blob = $readMagic (Expected = $serializationSchemeMagic)")
}
return customScheme.deserialize(
byteSequence.subSequence(serializationSchemeMagic.size, byteSequence.size - serializationSchemeMagic.size),
clazz,
SerializationSchemeContextAdapter(context)
)
}
override fun <T : Any> serialize(obj: T, context: SerializationContext): SerializedBytes<T> {
val stream = ByteArrayOutputStream()
stream.write(serializationSchemeMagic.bytes)
stream.write(customScheme.serialize(obj, SerializationSchemeContextAdapter(context)).bytes)
return SerializedBytes(stream.toByteArray())
}
private class SerializationSchemeContextAdapter(context: SerializationContext) : SerializationSchemeContext {
override val deserializationClassLoader = context.deserializationClassLoader
override val whitelist = context.whitelist
override val properties = context.properties
}
}

View File

@ -0,0 +1,30 @@
package net.corda.nodeapi.internal.serialization;
import net.corda.core.serialization.SerializationSchemeContext;
import net.corda.core.serialization.CustomSerializationScheme;
import net.corda.core.serialization.SerializedBytes;
import net.corda.core.utilities.ByteSequence;
public class DummyCustomSerializationSchemeInJava implements CustomSerializationScheme {
public class DummyOutput {}
static final int testMagic = 7;
@Override
public int getSchemeId() {
return testMagic;
}
@Override
@SuppressWarnings("unchecked")
public <T> T deserialize(ByteSequence bytes, Class<T> clazz, SerializationSchemeContext context) {
return (T)new DummyOutput();
}
@Override
public <T> SerializedBytes<T> serialize(T obj, SerializationSchemeContext context) {
byte[] myBytes = {0xA, 0xA};
return new SerializedBytes<>(myBytes);
}
}

View File

@ -0,0 +1,93 @@
package net.corda.nodeapi.internal.serialization
import net.corda.core.serialization.SerializationSchemeContext
import net.corda.core.serialization.CustomSerializationScheme
import net.corda.core.utilities.ByteSequence
import net.corda.nodeapi.internal.serialization.testutils.serializationContext
import org.junit.Test
import org.junit.jupiter.api.Assertions.assertTrue
import java.io.NotSerializableException
import kotlin.test.assertFailsWith
class CustomSerializationSchemeAdapterTests {
companion object {
const val DEFAULT_SCHEME_ID = 7
}
class DummyInputClass
class DummyOutputClass
class SingleInputAndOutputScheme(private val schemeId: Int = DEFAULT_SCHEME_ID): CustomSerializationScheme {
override fun getSchemeId(): Int {
return schemeId
}
override fun <T: Any> deserialize(bytes: ByteSequence, clazz: Class<T>, context: SerializationSchemeContext): T {
@Suppress("UNCHECKED_CAST")
return DummyOutputClass() as T
}
override fun <T: Any> serialize(obj: T, context: SerializationSchemeContext): ByteSequence {
assertTrue(obj is DummyInputClass)
return ByteSequence.of(ByteArray(2) { 0x2 })
}
}
class SameBytesInputAndOutputsAndScheme: CustomSerializationScheme {
private val expectedBytes = "123456789".toByteArray()
override fun getSchemeId(): Int {
return DEFAULT_SCHEME_ID
}
override fun <T: Any> deserialize(bytes: ByteSequence, clazz: Class<T>, context: SerializationSchemeContext): T {
bytes.open().use {
val data = ByteArray(expectedBytes.size) { 0 }
it.read(data)
assertTrue(data.contentEquals(expectedBytes))
}
@Suppress("UNCHECKED_CAST")
return DummyOutputClass() as T
}
override fun <T: Any> serialize(obj: T, context: SerializationSchemeContext): ByteSequence {
return ByteSequence.of(expectedBytes)
}
}
@Test(timeout=300_000)
fun `CustomSerializationSchemeAdapter calls the correct methods in CustomSerializationScheme`() {
val scheme = CustomSerializationSchemeAdapter(SingleInputAndOutputScheme())
val serializedData = scheme.serialize(DummyInputClass(), serializationContext)
val roundTripped = scheme.deserialize(serializedData, Any::class.java, serializationContext)
assertTrue(roundTripped is DummyOutputClass)
}
@Test(timeout=300_000)
fun `CustomSerializationSchemeAdapter can adapt a Java implementation`() {
val scheme = CustomSerializationSchemeAdapter(DummyCustomSerializationSchemeInJava())
val serializedData = scheme.serialize(DummyInputClass(), serializationContext)
val roundTripped = scheme.deserialize(serializedData, Any::class.java, serializationContext)
assertTrue(roundTripped is DummyCustomSerializationSchemeInJava.DummyOutput)
}
@Test(timeout=300_000)
fun `CustomSerializationSchemeAdapter validates the magic`() {
val inScheme = CustomSerializationSchemeAdapter(SingleInputAndOutputScheme())
val serializedData = inScheme.serialize(DummyInputClass(), serializationContext)
val outScheme = CustomSerializationSchemeAdapter(SingleInputAndOutputScheme(8))
assertFailsWith<NotSerializableException> {
outScheme.deserialize(serializedData, DummyOutputClass::class.java, serializationContext)
}
}
@Test(timeout=300_000)
fun `CustomSerializationSchemeAdapter preserves the serialized bytes between deserialize and serialize`() {
val scheme = CustomSerializationSchemeAdapter(SameBytesInputAndOutputsAndScheme())
val serializedData = scheme.serialize(Any(), serializationContext)
scheme.deserialize(serializedData, Any::class.java, serializationContext)
}
}

View File

@ -0,0 +1,342 @@
package net.corda.node
import co.paralleluniverse.fibers.Suspendable
import com.esotericsoftware.kryo.Kryo
import com.esotericsoftware.kryo.io.Input
import com.esotericsoftware.kryo.io.Output
import de.javakaffee.kryoserializers.ArraysAsListSerializer
import net.corda.core.contracts.AlwaysAcceptAttachmentConstraint
import net.corda.core.contracts.BelongsToContract
import net.corda.core.contracts.Contract
import net.corda.core.contracts.ContractState
import net.corda.core.contracts.TransactionState
import net.corda.core.contracts.TypeOnlyCommandData
import net.corda.core.crypto.Crypto
import net.corda.core.crypto.SecureHash
import net.corda.core.crypto.SignableData
import net.corda.core.crypto.SignatureMetadata
import net.corda.core.flows.CollectSignaturesFlow
import net.corda.core.flows.FinalityFlow
import net.corda.core.flows.FlowLogic
import net.corda.core.flows.FlowSession
import net.corda.core.flows.InitiatedBy
import net.corda.core.flows.InitiatingFlow
import net.corda.core.flows.ReceiveFinalityFlow
import net.corda.core.flows.SignTransactionFlow
import net.corda.core.flows.StartableByRPC
import net.corda.core.identity.AbstractParty
import net.corda.core.identity.Party
import net.corda.core.internal.concurrent.transpose
import net.corda.core.internal.copyBytes
import net.corda.core.messaging.startFlow
import net.corda.core.serialization.CustomSerializationScheme
import net.corda.core.serialization.SerializationSchemeContext
import net.corda.core.serialization.internal.CustomSerializationSchemeUtils.Companion.getSchemeIdIfCustomSerializationMagic
import net.corda.core.transactions.LedgerTransaction
import net.corda.core.transactions.SignedTransaction
import net.corda.core.transactions.TransactionBuilder
import net.corda.core.transactions.WireTransaction
import net.corda.core.utilities.ByteSequence
import net.corda.core.utilities.getOrThrow
import net.corda.core.utilities.unwrap
import net.corda.serialization.internal.CordaSerializationMagic
import net.corda.serialization.internal.SerializationFactoryImpl
import net.corda.testing.core.ALICE_NAME
import net.corda.testing.core.BOB_NAME
import net.corda.testing.core.DUMMY_NOTARY_NAME
import net.corda.testing.core.TestIdentity
import net.corda.testing.driver.DriverParameters
import net.corda.testing.driver.NodeParameters
import net.corda.testing.driver.driver
import net.corda.testing.node.internal.enclosedCordapp
import org.junit.Test
import org.objenesis.instantiator.ObjectInstantiator
import org.objenesis.strategy.InstantiatorStrategy
import org.objenesis.strategy.StdInstantiatorStrategy
import java.io.ByteArrayOutputStream
import java.lang.reflect.Modifier
import java.util.*
import kotlin.test.assertEquals
import kotlin.test.assertTrue
class CustomSerializationSchemeDriverTest {
@Test(timeout = 300_000)
fun `flow can send wire transaction serialized with custom kryo serializer`() {
driver(DriverParameters(notarySpecs = emptyList(), startNodesInProcess = true, cordappsForAllNodes = listOf(enclosedCordapp()))) {
val (alice, bob) = listOf(
startNode(NodeParameters(providedName = ALICE_NAME)),
startNode(NodeParameters(providedName = BOB_NAME))
).transpose().getOrThrow()
val flow = alice.rpc.startFlow(::SendFlow, bob.nodeInfo.legalIdentities.single())
assertTrue { flow.returnValue.getOrThrow() }
}
}
@Test(timeout = 300_000)
fun `flow can write a wire transaction serialized with custom kryo serializer to the ledger`() {
driver(DriverParameters(startNodesInProcess = true, cordappsForAllNodes = listOf(enclosedCordapp()))) {
val (alice, bob) = listOf(
startNode(NodeParameters(providedName = ALICE_NAME)),
startNode(NodeParameters(providedName = BOB_NAME))
).transpose().getOrThrow()
val flow = alice.rpc.startFlow(::WriteTxToLedgerFlow, bob.nodeInfo.legalIdentities.single(), defaultNotaryIdentity)
val txId = flow.returnValue.getOrThrow()
val transaction = alice.rpc.startFlow(::GetTxFromDBFlow, txId).returnValue.getOrThrow()
for(group in transaction!!.tx.componentGroups) {
for (item in group.components) {
val magic = CordaSerializationMagic(item.slice(end = SerializationFactoryImpl.magicSize).copyBytes())
assertEquals( KryoScheme.SCHEME_ID, getSchemeIdIfCustomSerializationMagic(magic))
}
}
}
}
@Test(timeout = 300_000)
fun `Component groups are lazily serialized by the CustomSerializationScheme`() {
driver(DriverParameters(notarySpecs = emptyList(), startNodesInProcess = true, cordappsForAllNodes = listOf(enclosedCordapp()))) {
val alice = startNode(NodeParameters(providedName = ALICE_NAME)).getOrThrow()
//We don't need a real notary as we don't verify the transaction in this test.
val dummyNotary = TestIdentity(DUMMY_NOTARY_NAME, 20)
assertTrue { alice.rpc.startFlow(::CheckComponentGroupsFlow, dummyNotary.party).returnValue.getOrThrow() }
}
}
@Test(timeout = 300_000)
fun `Map in the serialization context can be used by lazily component group serialization`() {
driver(DriverParameters(notarySpecs = emptyList(), startNodesInProcess = true, cordappsForAllNodes = listOf(enclosedCordapp()))) {
val alice = startNode(NodeParameters(providedName = ALICE_NAME)).getOrThrow()
//We don't need a real notary as we don't verify the transaction in this test.
val dummyNotary = TestIdentity(DUMMY_NOTARY_NAME, 20)
assertTrue { alice.rpc.startFlow(::CheckComponentGroupsWithMapFlow, dummyNotary.party).returnValue.getOrThrow() }
}
}
@StartableByRPC
@InitiatingFlow
class WriteTxToLedgerFlow(val counterparty: Party, val notary: Party) : FlowLogic<SecureHash>() {
@Suspendable
override fun call(): SecureHash {
val outputState = TransactionState(
data = DummyContract.DummyState(),
contract = DummyContract::class.java.name,
notary = notary,
constraint = AlwaysAcceptAttachmentConstraint
)
val builder = TransactionBuilder()
.addOutputState(outputState)
.addCommand(DummyCommandData, counterparty.owningKey)
val wireTx = builder.toWireTransaction(serviceHub, KryoScheme.SCHEME_ID)
val partSignedTx = signWireTx(wireTx)
val session = initiateFlow(counterparty)
val fullySignedTx = subFlow(CollectSignaturesFlow(partSignedTx, setOf(session)))
subFlow(FinalityFlow(fullySignedTx, setOf(session)))
return fullySignedTx.id
}
fun signWireTx(wireTx: WireTransaction) : SignedTransaction {
val signatureMetadata = SignatureMetadata(
serviceHub.myInfo.platformVersion,
Crypto.findSignatureScheme(serviceHub.myInfo.legalIdentitiesAndCerts.first().owningKey).schemeNumberID
)
val signableData = SignableData(wireTx.id, signatureMetadata)
val sig = serviceHub.keyManagementService.sign(signableData, serviceHub.myInfo.legalIdentitiesAndCerts.first().owningKey)
return SignedTransaction(wireTx, listOf(sig))
}
}
@InitiatedBy(WriteTxToLedgerFlow::class)
class SignWireTxFlow(private val session: FlowSession): FlowLogic<SignedTransaction>() {
@Suspendable
override fun call(): SignedTransaction {
val signTransactionFlow = object : SignTransactionFlow(session) {
override fun checkTransaction(stx: SignedTransaction) {
return
}
}
val txId = subFlow(signTransactionFlow).id
return subFlow(ReceiveFinalityFlow(session, expectedTxId = txId))
}
}
@StartableByRPC
class GetTxFromDBFlow(private val txId: SecureHash): FlowLogic<SignedTransaction?>() {
override fun call(): SignedTransaction? {
return serviceHub.validatedTransactions.getTransaction(txId)
}
}
@StartableByRPC
@InitiatingFlow
class CheckComponentGroupsFlow(val notary: Party) : FlowLogic<Boolean>() {
@Suspendable
override fun call(): Boolean {
val outputState = TransactionState(
data = DummyContract.DummyState(),
contract = DummyContract::class.java.name,
notary = notary,
constraint = AlwaysAcceptAttachmentConstraint
)
val builder = TransactionBuilder()
.addOutputState(outputState)
.addCommand(DummyCommandData, notary.owningKey)
val wtx = builder.toWireTransaction(serviceHub, KryoScheme.SCHEME_ID)
var success = true
for (group in wtx.componentGroups) {
//Component groups are lazily serialized as we iterate through.
for (item in group.components) {
val magic = CordaSerializationMagic(item.slice(end = SerializationFactoryImpl.magicSize).copyBytes())
success = success && (getSchemeIdIfCustomSerializationMagic(magic) == KryoScheme.SCHEME_ID)
}
}
return success
}
}
@StartableByRPC
@InitiatingFlow
class CheckComponentGroupsWithMapFlow(val notary: Party) : FlowLogic<Boolean>() {
@Suspendable
override fun call(): Boolean {
val outputState = TransactionState(
data = DummyContract.DummyState(),
contract = DummyContract::class.java.name,
notary = notary,
constraint = AlwaysAcceptAttachmentConstraint
)
val builder = TransactionBuilder()
.addOutputState(outputState)
.addCommand(DummyCommandData, notary.owningKey)
val mapToCheckWhenSerializing = mapOf<Any, Any>(Pair(KryoSchemeWithMap.KEY, KryoSchemeWithMap.VALUE))
val wtx = builder.toWireTransaction(serviceHub, KryoSchemeWithMap.SCHEME_ID, mapToCheckWhenSerializing)
var success = true
for (group in wtx.componentGroups) {
//Component groups are lazily serialized as we iterate through.
for (item in group.components) {
val magic = CordaSerializationMagic(item.slice(end = SerializationFactoryImpl.magicSize).copyBytes())
success = success && (getSchemeIdIfCustomSerializationMagic(magic) == KryoSchemeWithMap.SCHEME_ID)
}
}
return success
}
}
@StartableByRPC
@InitiatingFlow
class SendFlow(val counterparty: Party) : FlowLogic<Boolean>() {
@Suspendable
override fun call(): Boolean {
val outputState = TransactionState(
data = DummyContract.DummyState(),
contract = DummyContract::class.java.name,
notary = counterparty,
constraint = AlwaysAcceptAttachmentConstraint
)
val builder = TransactionBuilder()
.addOutputState(outputState)
.addCommand(DummyCommandData, counterparty.owningKey)
val wtx = builder.toWireTransaction(serviceHub, KryoScheme.SCHEME_ID)
val session = initiateFlow(counterparty)
session.send(wtx)
return session.receive<Boolean>().unwrap {it}
}
}
@InitiatedBy(SendFlow::class)
class ReceiveFlow(private val session: FlowSession): FlowLogic<Unit>() {
@Suspendable
override fun call() {
val message = session.receive<WireTransaction>().unwrap {it}
message.toLedgerTransaction(serviceHub)
session.send(true)
}
}
class DummyContract: Contract {
@BelongsToContract(DummyContract::class)
class DummyState(override val participants: List<AbstractParty> = listOf()) : ContractState
override fun verify(tx: LedgerTransaction) {
return
}
}
object DummyCommandData : TypeOnlyCommandData()
open class KryoScheme : CustomSerializationScheme {
companion object {
const val SCHEME_ID = 7
}
override fun getSchemeId(): Int {
return SCHEME_ID
}
override fun <T : Any> deserialize(bytes: ByteSequence, clazz: Class<T>, context: SerializationSchemeContext): T {
val kryo = Kryo()
customiseKryo(kryo, context.deserializationClassLoader)
val obj = Input(bytes.open()).use {
kryo.readClassAndObject(it)
}
@Suppress("UNCHECKED_CAST")
return obj as T
}
override fun <T : Any> serialize(obj: T, context: SerializationSchemeContext): ByteSequence {
val kryo = Kryo()
customiseKryo(kryo, context.deserializationClassLoader)
val outputStream = ByteArrayOutputStream()
Output(outputStream).use {
kryo.writeClassAndObject(it, obj)
}
return ByteSequence.of(outputStream.toByteArray())
}
private fun customiseKryo(kryo: Kryo, classLoader: ClassLoader) {
kryo.instantiatorStrategy = CustomInstantiatorStrategy()
kryo.classLoader = classLoader
kryo.register(Arrays.asList("").javaClass, ArraysAsListSerializer())
}
//Stolen from DefaultKryoCustomizer.kt
private class CustomInstantiatorStrategy : InstantiatorStrategy {
private val fallbackStrategy = StdInstantiatorStrategy()
// Use this to allow construction of objects using a JVM backdoor that skips invoking the constructors, if there
// is no no-arg constructor available.
private val defaultStrategy = Kryo.DefaultInstantiatorStrategy(fallbackStrategy)
override fun <T> newInstantiatorOf(type: Class<T>): ObjectInstantiator<T> {
// However this doesn't work for non-public classes in the java. namespace
val strat = if (type.name.startsWith("java.") && !Modifier.isPublic(type.modifiers)) fallbackStrategy else defaultStrategy
return strat.newInstantiatorOf(type)
}
}
}
class KryoSchemeWithMap : KryoScheme() {
companion object {
const val SCHEME_ID = 8
const val KEY = "Key"
const val VALUE = "Value"
}
override fun getSchemeId(): Int {
return SCHEME_ID
}
override fun <T : Any> serialize(obj: T, context: SerializationSchemeContext): ByteSequence {
assertEquals(VALUE, context.properties[KEY])
return super.serialize(obj, context)
}
}
}

View File

@ -73,6 +73,7 @@ import net.corda.node.utilities.DemoClock
import net.corda.node.utilities.errorAndTerminate
import net.corda.nodeapi.internal.ArtemisMessagingClient
import net.corda.common.logging.errorReporting.NodeDatabaseErrors
import net.corda.node.internal.classloading.scanForCustomSerializationScheme
import net.corda.nodeapi.internal.ShutdownHook
import net.corda.nodeapi.internal.addShutdownHook
import net.corda.nodeapi.internal.bridging.BridgeControlListener
@ -647,10 +648,14 @@ open class Node(configuration: NodeConfiguration,
private fun initialiseSerialization() {
if (!initialiseSerialization) return
val classloader = cordappLoader.appClassLoader
val customScheme = System.getProperty("experimental.corda.customSerializationScheme")?.let {
scanForCustomSerializationScheme(it, classloader)
}
nodeSerializationEnv = SerializationEnvironment.with(
SerializationFactoryImpl().apply {
registerScheme(AMQPServerSerializationScheme(cordappLoader.cordapps, Caffeine.newBuilder().maximumSize(128).build<SerializationFactoryCacheKey, SerializerFactory>().asMap()))
registerScheme(AMQPClientSerializationScheme(cordappLoader.cordapps, Caffeine.newBuilder().maximumSize(128).build<SerializationFactoryCacheKey, SerializerFactory>().asMap()))
customScheme?.let{ registerScheme(it) }
},
p2pContext = AMQP_P2P_CONTEXT.withClassLoader(classloader),
rpcServerContext = AMQP_RPC_SERVER_CONTEXT.withClassLoader(classloader),

View File

@ -2,6 +2,31 @@
package net.corda.node.internal.classloading
import net.corda.core.serialization.CustomSerializationScheme
import net.corda.node.internal.ConfigurationException
import net.corda.nodeapi.internal.serialization.CustomSerializationSchemeAdapter
import net.corda.serialization.internal.SerializationScheme
import java.lang.reflect.Constructor
inline fun <reified A : Annotation> Class<*>.requireAnnotation(): A {
return requireNotNull(getDeclaredAnnotation(A::class.java)) { "$name needs to be annotated with ${A::class.java.name}" }
}
fun scanForCustomSerializationScheme(className: String, classLoader: ClassLoader) : SerializationScheme {
val schemaClass = try {
classLoader.loadClass(className)
} catch (exception: ClassNotFoundException) {
throw ConfigurationException("$className was declared as a custom serialization scheme but could not be found.")
}
val constructor = validateScheme(schemaClass, className)
return CustomSerializationSchemeAdapter(constructor.newInstance() as CustomSerializationScheme)
}
private fun validateScheme(clazz: Class<*>, className: String): Constructor<*> {
if (!clazz.interfaces.contains(CustomSerializationScheme::class.java)) {
throw ConfigurationException("$className was declared as a custom serialization scheme but does not implement" +
" ${CustomSerializationScheme::class.java.canonicalName}")
}
return clazz.constructors.singleOrNull { it.parameters.isEmpty() } ?: throw ConfigurationException("$className was declared as a " +
"custom serialization scheme but does not have a no argument constructor.")
}

View File

@ -0,0 +1,78 @@
package net.corda.node.internal
import com.nhaarman.mockito_kotlin.whenever
import net.corda.core.serialization.CustomSerializationScheme
import net.corda.core.serialization.SerializationContext
import net.corda.core.serialization.SerializationSchemeContext
import net.corda.core.utilities.ByteSequence
import net.corda.node.internal.classloading.scanForCustomSerializationScheme
import org.junit.Test
import org.mockito.Mockito
import kotlin.test.assertFailsWith
class CustomSerializationSchemeScanningTest {
class NonSerializationScheme
open class DummySerializationScheme : CustomSerializationScheme {
override fun getSchemeId(): Int {
return 7;
}
override fun <T : Any> deserialize(bytes: ByteSequence, clazz: Class<T>, context: SerializationSchemeContext): T {
throw DummySerializationSchemeException("We should never get here.")
}
override fun <T : Any> serialize(obj: T, context: SerializationSchemeContext): ByteSequence {
throw DummySerializationSchemeException("Tried to serialize with DummySerializationScheme")
}
}
class DummySerializationSchemeException(override val message: String) : RuntimeException(message)
class DummySerializationSchemeWithoutNoArgConstructor(val myArgument: String) : DummySerializationScheme()
@Test(timeout = 300_000)
fun `Can scan for custom serialization scheme and build a serialization scheme`() {
val classLoader = Mockito.mock(ClassLoader::class.java)
whenever(classLoader.loadClass(DummySerializationScheme::class.java.canonicalName)).thenAnswer { DummySerializationScheme::class.java }
val scheme = scanForCustomSerializationScheme(DummySerializationScheme::class.java.canonicalName, classLoader)
val mockContext = Mockito.mock(SerializationContext::class.java)
assertFailsWith<DummySerializationSchemeException>("Tried to serialize with DummySerializationScheme") {
scheme.serialize(Any::class.java, mockContext)
}
}
@Test(timeout = 300_000)
fun `verification fails with a helpful error if the class is not found in the classloader`() {
val classLoader = Mockito.mock(ClassLoader::class.java)
val missingClassName = DummySerializationScheme::class.java.canonicalName
whenever(classLoader.loadClass(missingClassName)).thenAnswer { throw ClassNotFoundException()}
assertFailsWith<ConfigurationException>("$missingClassName was declared as a custom serialization scheme but could not " +
"be found.") {
scanForCustomSerializationScheme(missingClassName, classLoader)
}
}
@Test(timeout = 300_000)
fun `verification fails with a helpful error if the class is not a custom serialization scheme`() {
val canonicalName = NonSerializationScheme::class.java.canonicalName
val classLoader = Mockito.mock(ClassLoader::class.java)
whenever(classLoader.loadClass(canonicalName)).thenAnswer { NonSerializationScheme::class.java }
assertFailsWith<ConfigurationException>("$canonicalName was declared as a custom serialization scheme but does not " +
"implement CustomSerializationScheme.") {
scanForCustomSerializationScheme(canonicalName, classLoader)
}
}
@Test(timeout = 300_000)
fun `verification fails with a helpful error if the class does not have a no arg constructor`() {
val classLoader = Mockito.mock(ClassLoader::class.java)
val canonicalName = DummySerializationSchemeWithoutNoArgConstructor::class.java.canonicalName
whenever(classLoader.loadClass(canonicalName)).thenAnswer { DummySerializationSchemeWithoutNoArgConstructor::class.java }
assertFailsWith<ConfigurationException>("$canonicalName was declared as a custom serialization scheme but does not " +
"have a no argument constructor.") {
scanForCustomSerializationScheme(canonicalName, classLoader)
}
}
}

View File

@ -6,6 +6,7 @@ import net.corda.core.crypto.SecureHash
import net.corda.core.internal.VisibleForTesting
import net.corda.core.internal.copyBytes
import net.corda.core.serialization.*
import net.corda.core.serialization.internal.CustomSerializationSchemeUtils.Companion.getSchemeIdIfCustomSerializationMagic
import net.corda.core.utilities.ByteSequence
import net.corda.serialization.internal.amqp.amqpMagic
import org.slf4j.LoggerFactory
@ -47,6 +48,10 @@ data class SerializationContextImpl @JvmOverloads constructor(override val prefe
return copy(properties = properties + (property to value))
}
override fun withProperties(extraProperties: Map<Any, Any>): SerializationContext {
return copy(properties = properties + extraProperties)
}
override fun withoutReferences(): SerializationContext {
return copy(objectReferencesEnabled = false)
}
@ -106,7 +111,9 @@ open class SerializationFactoryImpl(
registeredSchemes.filter { it.canDeserializeVersion(magic, target) }.forEach { return@computeIfAbsent it } // XXX: Not single?
logger.warn("Cannot find serialization scheme for: [$lookupKey, " +
"${if (magic == amqpMagic) "AMQP" else "UNKNOWN MAGIC"}] registeredSchemes are: $registeredSchemes")
throw UnsupportedOperationException("Serialization scheme $lookupKey not supported.")
val schemeId = getSchemeIdIfCustomSerializationMagic(magic) ?: throw UnsupportedOperationException("Serialization scheme" +
" $lookupKey not supported.")
throw UnsupportedOperationException("Could not find custom serialization scheme with SchemeId = $schemeId.")
}) to magic
}

View File

@ -3,9 +3,11 @@ package net.corda.coretesting.internal
import net.corda.nodeapi.internal.rpc.client.AMQPClientSerializationScheme
import net.corda.core.internal.createInstancesOfClassesImplementing
import net.corda.core.serialization.CheckpointCustomSerializer
import net.corda.core.serialization.CustomSerializationScheme
import net.corda.core.serialization.SerializationCustomSerializer
import net.corda.core.serialization.SerializationWhitelist
import net.corda.core.serialization.internal.SerializationEnvironment
import net.corda.nodeapi.internal.serialization.CustomSerializationSchemeAdapter
import net.corda.nodeapi.internal.serialization.amqp.AMQPServerSerializationScheme
import net.corda.nodeapi.internal.serialization.kryo.KRYO_CHECKPOINT_CONTEXT
import net.corda.nodeapi.internal.serialization.kryo.KryoCheckpointSerializer
@ -14,6 +16,7 @@ import net.corda.serialization.internal.AMQP_RPC_CLIENT_CONTEXT
import net.corda.serialization.internal.AMQP_RPC_SERVER_CONTEXT
import net.corda.serialization.internal.AMQP_STORAGE_CONTEXT
import net.corda.serialization.internal.SerializationFactoryImpl
import net.corda.serialization.internal.SerializationScheme
import net.corda.testing.common.internal.asContextEnv
import java.util.ServiceLoader
import java.util.concurrent.ConcurrentHashMap
@ -27,20 +30,29 @@ fun createTestSerializationEnv(): SerializationEnvironment {
fun createTestSerializationEnv(classLoader: ClassLoader?): SerializationEnvironment {
var customCheckpointSerializers: Set<CheckpointCustomSerializer<*, *>> = emptySet()
val (clientSerializationScheme, serverSerializationScheme) = if (classLoader != null) {
val serializationSchemes: MutableList<SerializationScheme> = mutableListOf()
if (classLoader != null) {
val customSerializers = createInstancesOfClassesImplementing(classLoader, SerializationCustomSerializer::class.java)
customCheckpointSerializers = createInstancesOfClassesImplementing(classLoader, CheckpointCustomSerializer::class.java)
val serializationWhitelists = ServiceLoader.load(SerializationWhitelist::class.java, classLoader).toSet()
Pair(AMQPClientSerializationScheme(customSerializers, serializationWhitelists),
AMQPServerSerializationScheme(customSerializers, serializationWhitelists))
serializationSchemes.add(AMQPClientSerializationScheme(customSerializers, serializationWhitelists))
serializationSchemes.add(AMQPServerSerializationScheme(customSerializers, serializationWhitelists))
val customSchemes = createInstancesOfClassesImplementing(classLoader, CustomSerializationScheme::class.java)
for (customScheme in customSchemes) {
serializationSchemes.add(CustomSerializationSchemeAdapter(customScheme))
}
} else {
Pair(AMQPClientSerializationScheme(emptyList()), AMQPServerSerializationScheme(emptyList()))
serializationSchemes.add(AMQPClientSerializationScheme(emptyList()))
serializationSchemes.add(AMQPServerSerializationScheme(emptyList()))
}
val factory = SerializationFactoryImpl().apply {
registerScheme(clientSerializationScheme)
registerScheme(serverSerializationScheme)
for (serializationScheme in serializationSchemes) {
registerScheme(serializationScheme)
}
}
return SerializationEnvironment.with(
factory,