Merge remote-tracking branch 'remotes/open/master' into enterprise-merge-september-26

# Conflicts:
#	core/src/main/kotlin/net/corda/core/crypto/CryptoUtils.kt
#	node-api/src/main/kotlin/net/corda/nodeapi/internal/serialization/Kryo.kt
#	settings.gradle
This commit is contained in:
sollecitom
2017-09-26 18:08:47 +01:00
3608 changed files with 12107 additions and 334367 deletions

View File

@ -1,17 +1,15 @@
package net.corda.nodeapi
import net.corda.core.utilities.toBase58String
import net.corda.core.identity.CordaX500Name
import net.corda.core.identity.Party
import net.corda.core.messaging.MessageRecipientGroup
import net.corda.core.messaging.MessageRecipients
import net.corda.core.messaging.SingleMessageRecipient
import net.corda.core.node.NodeInfo
import net.corda.core.node.services.ServiceType
import net.corda.core.internal.read
import net.corda.core.serialization.CordaSerializable
import net.corda.core.serialization.SingletonSerializeAsToken
import net.corda.core.utilities.NetworkHostAndPort
import net.corda.core.utilities.toBase58String
import net.corda.nodeapi.config.SSLConfiguration
import java.security.KeyStore
import java.security.PublicKey
/**
@ -29,8 +27,7 @@ abstract class ArtemisMessagingComponent : SingletonSerializeAsToken() {
const val PEER_USER = "SystemUsers/Peer"
const val INTERNAL_PREFIX = "internal."
const val PEERS_PREFIX = "${INTERNAL_PREFIX}peers."
const val SERVICES_PREFIX = "${INTERNAL_PREFIX}services."
const val PEERS_PREFIX = "${INTERNAL_PREFIX}peers." //TODO Come up with better name for common peers/services queue
const val IP_REQUEST_PREFIX = "ip."
const val P2P_QUEUE = "p2p.inbound"
const val NOTIFICATIONS_ADDRESS = "${INTERNAL_PREFIX}activemq.notifications"
@ -64,13 +61,9 @@ abstract class ArtemisMessagingComponent : SingletonSerializeAsToken() {
@CordaSerializable
data class NodeAddress(override val queueName: String, override val hostAndPort: NetworkHostAndPort) : ArtemisPeerAddress {
companion object {
fun asPeer(peerIdentity: PublicKey, hostAndPort: NetworkHostAndPort): NodeAddress {
fun asSingleNode(peerIdentity: PublicKey, hostAndPort: NetworkHostAndPort): NodeAddress {
return NodeAddress("$PEERS_PREFIX${peerIdentity.toBase58String()}", hostAndPort)
}
fun asService(serviceIdentity: PublicKey, hostAndPort: NetworkHostAndPort): NodeAddress {
return NodeAddress("$SERVICES_PREFIX${serviceIdentity.toBase58String()}", hostAndPort)
}
}
}
@ -84,33 +77,18 @@ abstract class ArtemisMessagingComponent : SingletonSerializeAsToken() {
* @param identity The service identity's owning key.
*/
data class ServiceAddress(val identity: PublicKey) : ArtemisAddress, MessageRecipientGroup {
override val queueName: String = "$SERVICES_PREFIX${identity.toBase58String()}"
override val queueName: String = "$PEERS_PREFIX${identity.toBase58String()}"
}
/** The config object is used to pass in the passwords for the certificate KeyStore and TrustStore */
abstract val config: SSLConfiguration?
/**
* Returns nothing if the keystore was opened OK or throws if not. Useful to check the password, as
* unfortunately Artemis tends to bury the exception when the password is wrong.
*/
fun checkStorePasswords() {
val config = config ?: return
arrayOf(config.sslKeystore, config.nodeKeystore).forEach {
it.read {
KeyStore.getInstance("JKS").load(it, config.keyStorePassword.toCharArray())
}
}
config.trustStoreFile.read {
KeyStore.getInstance("JKS").load(it, config.trustStorePassword.toCharArray())
}
}
fun getArtemisPeerAddress(nodeInfo: NodeInfo): ArtemisPeerAddress {
return if (nodeInfo.advertisedServices.any { it.info.type == ServiceType.networkMap }) {
NetworkMapAddress(nodeInfo.addresses.first())
// Used for bridges creation.
fun getArtemisPeerAddress(party: Party, address: NetworkHostAndPort, netMapName: CordaX500Name? = null): ArtemisPeerAddress {
return if (party.name == netMapName) {
NetworkMapAddress(address)
} else {
NodeAddress.asPeer(nodeInfo.legalIdentity.owningKey, nodeInfo.addresses.first())
NodeAddress.asSingleNode(party.owningKey, address) // It also takes care of services nodes treated as peer nodes
}
}
}

View File

@ -1,18 +1,17 @@
package net.corda.nodeapi
import net.corda.core.identity.CordaX500Name
import net.corda.core.utilities.NetworkHostAndPort
import net.corda.nodeapi.config.SSLConfiguration
import org.apache.activemq.artemis.api.core.TransportConfiguration
import org.apache.activemq.artemis.core.remoting.impl.netty.NettyConnectorFactory
import org.apache.activemq.artemis.core.remoting.impl.netty.TransportConstants
import org.bouncycastle.asn1.x500.X500Name
import java.nio.file.FileSystems
import java.nio.file.Path
sealed class ConnectionDirection {
data class Inbound(val acceptorFactoryClassName: String) : ConnectionDirection()
data class Outbound(
val expectedCommonName: X500Name? = null,
val expectedCommonNames: Set<CordaX500Name> = emptySet(), // TODO SNI? Or we need a notion of node's network identity?
val connectorFactoryClassName: String = NettyConnectorFactory::class.java.name
) : ConnectionDirection()
}
@ -53,8 +52,8 @@ class ArtemisTcpTransport {
)
if (config != null && enableSSL) {
config.sslKeystore.expectedOnDefaultFileSystem()
config.trustStoreFile.expectedOnDefaultFileSystem()
config.sslKeystore.requireOnDefaultFileSystem()
config.trustStoreFile.requireOnDefaultFileSystem()
val tlsOptions = mapOf(
// Enable TLS transport layer with client certs and restrict to at least SHA256 in handshake
// and AES encryption
@ -68,7 +67,7 @@ class ArtemisTcpTransport {
TransportConstants.ENABLED_CIPHER_SUITES_PROP_NAME to CIPHER_SUITES.joinToString(","),
TransportConstants.ENABLED_PROTOCOLS_PROP_NAME to "TLSv1.2",
TransportConstants.NEED_CLIENT_AUTH_PROP_NAME to true,
VERIFY_PEER_LEGAL_NAME to (direction as? ConnectionDirection.Outbound)?.expectedCommonName
VERIFY_PEER_LEGAL_NAME to (direction as? ConnectionDirection.Outbound)?.expectedCommonNames
)
options.putAll(tlsOptions)
}
@ -80,7 +79,3 @@ class ArtemisTcpTransport {
}
}
}
fun Path.expectedOnDefaultFileSystem() {
require(fileSystem == FileSystems.getDefault()) { "Artemis only uses the default file system" }
}

View File

@ -0,0 +1,14 @@
@file:JvmName("ArtemisUtils")
package net.corda.nodeapi
import java.nio.file.FileSystems
import java.nio.file.Path
/**
* Require that the [Path] is on a default file system, and therefore is one that Artemis is willing to use.
* @throws IllegalArgumentException if the path is not on a default file system.
*/
fun Path.requireOnDefaultFileSystem() {
require(fileSystem == FileSystems.getDefault()) { "Artemis only uses the default file system" }
}

View File

@ -170,7 +170,7 @@ object RPCApi {
override fun writeToClientMessage(context: SerializationContext, message: ClientMessage) {
message.putIntProperty(TAG_FIELD_NAME, Tag.RPC_REPLY.ordinal)
message.putLongProperty(RPC_ID_FIELD_NAME, id.toLong)
message.bodyBuffer.writeBytes(result.safeSerialize(context) { Try.Failure(it) }.bytes)
message.bodyBuffer.writeBytes(result.safeSerialize(context) { Try.Failure<Any>(it) }.bytes)
}
}

View File

@ -1,87 +0,0 @@
@file:JvmName("RPCStructures")
package net.corda.nodeapi
import com.esotericsoftware.kryo.Registration
import com.esotericsoftware.kryo.Serializer
import net.corda.core.concurrent.CordaFuture
import net.corda.core.CordaRuntimeException
import net.corda.core.serialization.CordaSerializable
import net.corda.core.serialization.SerializationContext
import net.corda.core.toFuture
import net.corda.core.toObservable
import net.corda.nodeapi.config.OldConfig
import net.corda.nodeapi.internal.serialization.*
import rx.Observable
import java.io.InputStream
data class User(
@OldConfig("user")
val username: String,
val password: String,
val permissions: Set<String>) {
override fun toString(): String = "${javaClass.simpleName}($username, permissions=$permissions)"
fun toMap() = mapOf(
"username" to username,
"password" to password,
"permissions" to permissions
)
}
/** Records the protocol version in which this RPC was added. */
@Target(AnnotationTarget.FUNCTION)
@MustBeDocumented
annotation class RPCSinceVersion(val version: Int)
/**
* Thrown to indicate a fatal error in the RPC system itself, as opposed to an error generated by the invoked
* method.
*/
open class RPCException(message: String?, cause: Throwable?) : CordaRuntimeException(message, cause) {
constructor(msg: String) : this(msg, null)
}
/**
* Thrown to indicate that the calling user does not have permission for something they have requested (for example
* calling a method).
*/
@CordaSerializable
class PermissionException(msg: String) : RuntimeException(msg)
/**
* The Kryo used for the RPC wire protocol.
*/
// Every type in the wire protocol is listed here explicitly.
// This is annoying to write out, but will make it easier to formalise the wire protocol when the time comes,
// because we can see everything we're using in one place.
class RPCKryo(observableSerializer: Serializer<Observable<*>>, serializationContext: SerializationContext) : CordaKryo(CordaClassResolver(serializationContext)) {
init {
DefaultKryoCustomizer.customize(this)
// RPC specific classes
register(InputStream::class.java, InputStreamSerializer)
register(Observable::class.java, observableSerializer)
register(CordaFuture::class,
read = { kryo, input -> observableSerializer.read(kryo, input, Observable::class.java).toFuture() },
write = { kryo, output, obj -> observableSerializer.write(kryo, output, obj.toObservable()) }
)
}
override fun getRegistration(type: Class<*>): Registration {
if (Observable::class.java != type && Observable::class.java.isAssignableFrom(type)) {
return super.getRegistration(Observable::class.java)
}
if (InputStream::class.java != type && InputStream::class.java.isAssignableFrom(type)) {
return super.getRegistration(InputStream::class.java)
}
if (CordaFuture::class.java != type && CordaFuture::class.java.isAssignableFrom(type)) {
return super.getRegistration(CordaFuture::class.java)
}
type.requireExternal("RPC not allowed to deserialise internal classes")
return super.getRegistration(type)
}
private fun Class<*>.requireExternal(msg: String) {
require(!name.startsWith("net.corda.node.") && !name.contains(".internal.")) { "$msg: $name" }
}
}

View File

@ -0,0 +1,16 @@
package net.corda.nodeapi
import net.corda.nodeapi.config.OldConfig
data class User(
@OldConfig("user")
val username: String,
val password: String,
val permissions: Set<String>) {
override fun toString(): String = "${javaClass.simpleName}($username, permissions=$permissions)"
fun toMap() = mapOf(
"username" to username,
"password" to password,
"permissions" to permissions
)
}

View File

@ -8,11 +8,11 @@ import org.apache.activemq.artemis.api.core.client.ClientMessage
import org.apache.activemq.artemis.reader.MessageUtil
object VerifierApi {
val VERIFIER_USERNAME = "SystemUsers/Verifier"
val VERIFICATION_REQUESTS_QUEUE_NAME = "verifier.requests"
val VERIFICATION_RESPONSES_QUEUE_NAME_PREFIX = "verifier.responses"
private val VERIFICATION_ID_FIELD_NAME = "id"
private val RESULT_EXCEPTION_FIELD_NAME = "result-exception"
const val VERIFIER_USERNAME = "SystemUsers/Verifier"
const val VERIFICATION_REQUESTS_QUEUE_NAME = "verifier.requests"
const val VERIFICATION_RESPONSES_QUEUE_NAME_PREFIX = "verifier.responses"
private const val VERIFICATION_ID_FIELD_NAME = "id"
private const val RESULT_EXCEPTION_FIELD_NAME = "result-exception"
data class VerificationRequest(
val verificationId: Long,

View File

@ -1,11 +1,11 @@
@file:JvmName("ConfigUtilities")
package net.corda.nodeapi.config
import com.typesafe.config.Config
import com.typesafe.config.ConfigUtil
import net.corda.core.identity.CordaX500Name
import net.corda.core.internal.noneOrSingle
import net.corda.core.utilities.NetworkHostAndPort
import net.corda.core.utilities.parseNetworkHostAndPort
import org.bouncycastle.asn1.x500.X500Name
import org.slf4j.LoggerFactory
import java.net.Proxy
import java.net.URL
@ -68,11 +68,11 @@ private fun Config.getSingleValue(path: String, type: KType): Any? {
Boolean::class -> getBoolean(path)
LocalDate::class -> LocalDate.parse(getString(path))
Instant::class -> Instant.parse(getString(path))
NetworkHostAndPort::class -> getString(path).parseNetworkHostAndPort()
NetworkHostAndPort::class -> NetworkHostAndPort.parse(getString(path))
Path::class -> Paths.get(getString(path))
URL::class -> URL(getString(path))
Properties::class -> getConfig(path).toProperties()
X500Name::class -> X500Name(getString(path))
CordaX500Name::class -> CordaX500Name.parse(getString(path))
else -> if (typeClass.java.isEnum) {
parseEnum(typeClass.java, getString(path))
} else {
@ -96,10 +96,10 @@ private fun Config.getCollectionValue(path: String, type: KType): Collection<Any
Boolean::class -> getBooleanList(path)
LocalDate::class -> getStringList(path).map(LocalDate::parse)
Instant::class -> getStringList(path).map(Instant::parse)
NetworkHostAndPort::class -> getStringList(path).map { it.parseNetworkHostAndPort() }
NetworkHostAndPort::class -> getStringList(path).map(NetworkHostAndPort.Companion::parse)
Path::class -> getStringList(path).map { Paths.get(it) }
URL::class -> getStringList(path).map(::URL)
X500Name::class -> getStringList(path).map(::X500Name)
CordaX500Name::class -> getStringList(path).map(CordaX500Name.Companion::parse)
Properties::class -> getConfigList(path).map(Config::toProperties)
else -> if (elementClass.java.isEnum) {
getStringList(path).map { parseEnum(elementClass.java, it) }

View File

@ -0,0 +1,108 @@
package net.corda.nodeapi.internal
import net.corda.core.contracts.Attachment
import net.corda.core.crypto.SecureHash
import net.corda.core.serialization.CordaSerializable
import java.io.ByteArrayInputStream
import java.io.ByteArrayOutputStream
import java.io.FileNotFoundException
import java.io.InputStream
import java.net.URL
import java.net.URLConnection
import java.net.URLStreamHandler
import java.security.CodeSigner
import java.security.CodeSource
import java.security.SecureClassLoader
import java.util.*
/**
* A custom ClassLoader that knows how to load classes from a set of attachments. The attachments themselves only
* need to provide JAR streams, and so could be fetched from a database, local disk, etc. Constructing an
* AttachmentsClassLoader is somewhat expensive, as every attachment is scanned to ensure that there are no overlapping
* file paths.
*/
class AttachmentsClassLoader(attachments: List<Attachment>, parent: ClassLoader = ClassLoader.getSystemClassLoader()) : SecureClassLoader(parent) {
private val pathsToAttachments = HashMap<String, Attachment>()
private val idsToAttachments = HashMap<SecureHash, Attachment>()
@CordaSerializable
class OverlappingAttachments(val path: String) : Exception() {
override fun toString() = "Multiple attachments define a file at path $path"
}
init {
for (attachment in attachments) {
attachment.openAsJAR().use { jar ->
while (true) {
val entry = jar.nextJarEntry ?: break
// We already verified that paths are not strange/game playing when we inserted the attachment
// into the storage service. So we don't need to repeat it here.
//
// We forbid files that differ only in case, or path separator to avoid issues for Windows/Mac developers where the
// filesystem tries to be case insensitive. This may break developers who attempt to use ProGuard.
//
// Also convert to Unix path separators as all resource/class lookups will expect this.
val path = entry.name.toLowerCase().replace('\\', '/')
if (path in pathsToAttachments)
throw OverlappingAttachments(path)
pathsToAttachments[path] = attachment
}
}
idsToAttachments[attachment.id] = attachment
}
}
// Example: attachment://0b4fc1327f3bbebf1bfe98330ea402ae035936c3cb6da9bd3e26eeaa9584e74d/some/file.txt
//
// We have to provide a fake stream handler to satisfy the URL class that the scheme is known. But it's not
// a real scheme and we don't register it. It's just here to ensure that there aren't codepaths that could
// lead to data loading that we don't control right here in this class (URLs can have evil security properties!)
private val fakeStreamHandler = object : URLStreamHandler() {
override fun openConnection(u: URL?): URLConnection? {
throw UnsupportedOperationException()
}
}
private fun Attachment.toURL(path: String?) = URL(null, "attachment://$id/" + (path ?: ""), fakeStreamHandler)
override fun findClass(name: String): Class<*> {
val path = name.replace('.', '/').toLowerCase() + ".class"
val attachment = pathsToAttachments[path] ?: throw ClassNotFoundException(name)
val stream = ByteArrayOutputStream()
try {
attachment.extractFile(path, stream)
} catch(e: FileNotFoundException) {
throw ClassNotFoundException(name)
}
val bytes = stream.toByteArray()
// We don't attempt to propagate signatures from the JAR into the codesource, because our sandbox does not
// depend on external policy files to specify what it can do, so the data wouldn't be useful.
val codesource = CodeSource(attachment.toURL(null), emptyArray<CodeSigner>())
// TODO: Define an empty ProtectionDomain to start enforcing the standard Java sandbox.
// The standard Java sandbox is insufficient for our needs and a much more sophisticated sandboxing
// ClassLoader will appear here in future, but it can't hurt to use the default one too: defence in depth!
return defineClass(name, bytes, 0, bytes.size, codesource)
}
override fun findResource(name: String): URL? {
val attachment = pathsToAttachments[name.toLowerCase()] ?: return null
return attachment.toURL(name)
}
override fun getResourceAsStream(name: String): InputStream? {
val url = getResource(name) ?: return null // May check parent classloaders, for example.
if (url.protocol != "attachment") return null
val attachment = idsToAttachments[SecureHash.parse(url.host)] ?: return null
val path = url.path?.substring(1) ?: return null // Chop off the leading slash.
try {
val stream = ByteArrayOutputStream()
attachment.extractFile(path, stream)
return ByteArrayInputStream(stream.toByteArray())
} catch(e: FileNotFoundException) {
return null
}
}
}

View File

@ -0,0 +1,27 @@
package net.corda.nodeapi.internal
import net.corda.core.identity.CordaX500Name
import net.corda.core.serialization.CordaSerializable
/**
* A container for additional information for an advertised service.
*
* @param type the ServiceType identifier
* @param name the service name, used for differentiating multiple services of the same type. Can also be used as a
* grouping identifier for nodes collectively running a distributed service.
*/
@CordaSerializable
data class ServiceInfo(val type: ServiceType, val name: CordaX500Name? = null) {
companion object {
fun parse(encoded: String): ServiceInfo {
val parts = encoded.split("|")
require(parts.size in 1..2) { "Invalid number of elements found" }
val type = ServiceType.parse(parts[0])
val name = parts.getOrNull(1)
val principal = name?.let { CordaX500Name.parse(it) }
return ServiceInfo(type, principal)
}
}
override fun toString() = if (name != null) "$type|$name" else type.toString()
}

View File

@ -0,0 +1,45 @@
package net.corda.nodeapi.internal
import net.corda.core.serialization.CordaSerializable
/**
* Identifier for service types a node can expose over the network to other peers. These types are placed into network
* map advertisements. Services that are purely local and are not providing functionality to other parts of the network
* don't need a declared service type.
*/
@CordaSerializable
class ServiceType private constructor(val id: String) {
init {
// Enforce:
//
// * IDs must start with a lower case letter
// * IDs can only contain alphanumeric, full stop and underscore ASCII characters
require(id.matches(Regex("[a-z][a-zA-Z0-9._]+"))) { id }
}
companion object {
val corda: ServiceType
get() {
val stack = Throwable().stackTrace
val caller = stack.first().className
require(caller.startsWith("net.corda.")) { "Corda ServiceType namespace is reserved for Corda core components" }
return ServiceType("corda")
}
val notary: ServiceType = corda.getSubType("notary")
val networkMap: ServiceType = corda.getSubType("network_map")
fun parse(id: String): ServiceType = ServiceType(id)
private fun baseWithSubType(baseId: String, subTypeId: String) = ServiceType("$baseId.$subTypeId")
}
fun getSubType(subTypeId: String): ServiceType = baseWithSubType(id, subTypeId)
fun isSubTypeOf(superType: ServiceType) = (id == superType.id) || id.startsWith(superType.id + ".")
fun isNotary() = isSubTypeOf(notary)
override fun equals(other: Any?): Boolean = other === this || other is ServiceType && other.id == this.id
override fun hashCode(): Int = id.hashCode()
override fun toString(): String = id
}

View File

@ -1,3 +1,5 @@
@file:JvmName("AMQPSerializationScheme")
package net.corda.nodeapi.internal.serialization
import net.corda.core.node.CordaPluginRegistry
@ -19,6 +21,11 @@ class AMQPSerializationCustomization(val factory: SerializerFactory) : Serializa
}
fun SerializerFactory.addToWhitelist(vararg types: Class<*>) {
require(types.toSet().size == types.size) {
val duplicates = types.toMutableList()
types.toSet().forEach { duplicates -= it }
"Cannot add duplicate classes to the whitelist ($duplicates)."
}
for (type in types) {
(this.whitelist as? MutableClassWhitelist)?.add(type)
}
@ -54,6 +61,11 @@ abstract class AbstractAMQPSerializationScheme : SerializationScheme {
register(net.corda.nodeapi.internal.serialization.amqp.custom.ClassSerializer(this))
register(net.corda.nodeapi.internal.serialization.amqp.custom.X509CertificateHolderSerializer)
register(net.corda.nodeapi.internal.serialization.amqp.custom.PartyAndCertificateSerializer(factory))
register(net.corda.nodeapi.internal.serialization.amqp.custom.StringBufferSerializer)
register(net.corda.nodeapi.internal.serialization.amqp.custom.SimpleStringSerializer)
register(net.corda.nodeapi.internal.serialization.amqp.custom.InputStreamSerializer)
register(net.corda.nodeapi.internal.serialization.amqp.custom.BitSetSerializer(this))
register(net.corda.nodeapi.internal.serialization.amqp.custom.EnumSetSerializer(this))
}
val customizer = AMQPSerializationCustomization(factory)
pluginRegistries.forEach { it.customizeSerialization(customizer) }

View File

@ -6,7 +6,7 @@ import com.esotericsoftware.kryo.io.Output
import com.esotericsoftware.kryo.serializers.FieldSerializer
import com.esotericsoftware.kryo.util.DefaultClassResolver
import com.esotericsoftware.kryo.util.Util
import net.corda.core.serialization.AttachmentsClassLoader
import net.corda.nodeapi.internal.AttachmentsClassLoader
import net.corda.core.serialization.ClassWhitelist
import net.corda.core.serialization.CordaSerializable
import net.corda.core.serialization.SerializationContext
@ -19,21 +19,28 @@ import java.nio.file.Paths
import java.nio.file.StandardOpenOption
import java.util.*
fun Kryo.addToWhitelist(vararg types: Class<*>) {
for (type in types) {
((classResolver as? CordaClassResolver)?.whitelist as? MutableClassWhitelist)?.add(type)
}
}
/**
* Corda specific class resolver which enables extra customisation for the purposes of serialization using Kryo
*/
class CordaClassResolver(serializationContext: SerializationContext) : DefaultClassResolver() {
val whitelist: ClassWhitelist = TransientClassWhiteList(serializationContext.whitelist)
/*
* These classes are assignment-compatible Java equivalents of Kotlin classes.
* The point is that we do not want to send Kotlin types "over the wire" via RPC.
*/
private val javaAliases: Map<Class<*>, Class<*>> = mapOf(
listOf<Any>().javaClass to Collections.emptyList<Any>().javaClass,
setOf<Any>().javaClass to Collections.emptySet<Any>().javaClass,
mapOf<Any, Any>().javaClass to Collections.emptyMap<Any, Any>().javaClass
)
private fun typeForSerializationOf(type: Class<*>): Class<*> = javaAliases[type] ?: type
/** Returns the registration for the specified class, or null if the class is not registered. */
override fun getRegistration(type: Class<*>): Registration? {
return super.getRegistration(type) ?: checkClass(type)
val targetType = typeForSerializationOf(type)
return super.getRegistration(targetType) ?: checkClass(targetType)
}
private var whitelistEnabled = true
@ -67,9 +74,9 @@ class CordaClassResolver(serializationContext: SerializationContext) : DefaultCl
}
override fun registerImplicit(type: Class<*>): Registration {
val targetType = typeForSerializationOf(type)
val objectInstance = try {
type.kotlin.objectInstance
targetType.kotlin.objectInstance
} catch (t: Throwable) {
null // objectInstance will throw if the type is something like a lambda
}
@ -80,17 +87,21 @@ class CordaClassResolver(serializationContext: SerializationContext) : DefaultCl
kryo.references = true
val serializer = when {
objectInstance != null -> KotlinObjectSerializer(objectInstance)
kotlin.jvm.internal.Lambda::class.java.isAssignableFrom(type) -> // Kotlin lambdas extend this class and any captured variables are stored in synthetic fields
FieldSerializer<Any>(kryo, type).apply { setIgnoreSyntheticFields(false) }
Throwable::class.java.isAssignableFrom(type) -> ThrowableSerializer(kryo, type)
else -> kryo.getDefaultSerializer(type)
kotlin.jvm.internal.Lambda::class.java.isAssignableFrom(targetType) -> // Kotlin lambdas extend this class and any captured variables are stored in synthetic fields
FieldSerializer<Any>(kryo, targetType).apply { setIgnoreSyntheticFields(false) }
Throwable::class.java.isAssignableFrom(targetType) -> ThrowableSerializer(kryo, targetType)
else -> kryo.getDefaultSerializer(targetType)
}
return register(Registration(type, serializer, NAME.toInt()))
return register(Registration(targetType, serializer, NAME.toInt()))
} finally {
kryo.references = references
}
}
override fun writeName(output: Output, type: Class<*>, registration: Registration) {
super.writeName(output, registration.type ?: type, registration)
}
// Trivial Serializer which simply returns the given instance, which we already know is a Kotlin object
private class KotlinObjectSerializer(private val objectInstance: Any) : Serializer<Any>() {
override fun read(kryo: Kryo, input: Input, type: Class<Any>): Any = objectInstance
@ -134,10 +145,6 @@ interface MutableClassWhitelist : ClassWhitelist {
fun add(entry: Class<*>)
}
object EmptyWhitelist : ClassWhitelist {
override fun hasListed(type: Class<*>): Boolean = false
}
class BuiltInExceptionsWhitelist : ClassWhitelist {
companion object {
private val packageName = "^(?:java|kotlin)(?:[.]|$)".toRegex()

View File

@ -0,0 +1,29 @@
package net.corda.nodeapi.internal.serialization
import com.esotericsoftware.kryo.Kryo
import com.esotericsoftware.kryo.io.Output
import com.esotericsoftware.kryo.serializers.ClosureSerializer
import java.io.Serializable
object CordaClosureSerializer : ClosureSerializer() {
val ERROR_MESSAGE = "Unable to serialize Java Lambda expression, unless explicitly declared e.g., Runnable r = (Runnable & Serializable) () -> System.out.println(\"Hello world!\");"
override fun write(kryo: Kryo, output: Output, target: Any) {
if (!isSerializable(target)) {
throw IllegalArgumentException(ERROR_MESSAGE)
}
super.write(kryo, output, target)
}
private fun isSerializable(target: Any): Boolean {
return target is Serializable
}
}
object CordaClosureBlacklistSerializer : ClosureSerializer() {
val ERROR_MESSAGE = "Java 8 Lambda expressions are not supported for serialization."
override fun write(kryo: Kryo, output: Output, target: Any) {
throw IllegalArgumentException(ERROR_MESSAGE)
}
}

View File

@ -4,12 +4,14 @@ import com.esotericsoftware.kryo.Kryo
import com.esotericsoftware.kryo.Serializer
import com.esotericsoftware.kryo.io.Input
import com.esotericsoftware.kryo.io.Output
import com.esotericsoftware.kryo.serializers.ClosureSerializer
import com.esotericsoftware.kryo.serializers.CompatibleFieldSerializer
import com.esotericsoftware.kryo.serializers.FieldSerializer
import de.javakaffee.kryoserializers.ArraysAsListSerializer
import de.javakaffee.kryoserializers.BitSetSerializer
import de.javakaffee.kryoserializers.UnmodifiableCollectionsSerializer
import de.javakaffee.kryoserializers.guava.*
import net.corda.core.contracts.ContractAttachment
import net.corda.core.contracts.PrivacySalt
import net.corda.core.crypto.CompositeKey
import net.corda.core.identity.PartyAndCertificate
@ -38,6 +40,7 @@ import org.slf4j.Logger
import sun.security.ec.ECPublicKeyImpl
import sun.security.provider.certpath.X509CertPath
import java.io.BufferedInputStream
import java.io.ByteArrayOutputStream
import java.io.FileInputStream
import java.io.InputStream
import java.lang.reflect.Modifier.isPublic
@ -113,6 +116,12 @@ object DefaultKryoCustomizer {
// Don't deserialize PrivacySalt via its default constructor.
register(PrivacySalt::class.java, PrivacySaltSerializer)
// Used by the remote verifier, and will possibly be removed in future.
register(ContractAttachment::class.java, ContractAttachmentSerializer)
register(java.lang.invoke.SerializedLambda::class.java)
register(ClosureSerializer.Closure::class.java, CordaClosureBlacklistSerializer)
val customization = KryoSerializationCustomization(this)
pluginRegistries.forEach { it.customizeSerialization(customization) }
}
@ -170,4 +179,18 @@ object DefaultKryoCustomizer {
return PrivacySalt(input.readBytesWithLength())
}
}
private object ContractAttachmentSerializer : Serializer<ContractAttachment>() {
override fun write(kryo: Kryo, output: Output, obj: ContractAttachment) {
val buffer = ByteArrayOutputStream()
obj.attachment.open().use { it.copyTo(buffer) }
output.writeBytesWithLength(buffer.toByteArray())
output.writeString(obj.contract)
}
override fun read(kryo: Kryo, input: Input, type: Class<ContractAttachment>): ContractAttachment {
val attachment = GeneratedAttachment(input.readBytesWithLength())
return ContractAttachment(attachment, input.readString())
}
}
}

View File

@ -7,9 +7,6 @@ import net.corda.core.utilities.NetworkHostAndPort
import org.apache.activemq.artemis.api.core.SimpleString
import rx.Notification
import rx.exceptions.OnErrorNotImplementedException
import java.math.BigDecimal
import java.time.LocalDate
import java.time.Period
import java.util.*
/**
@ -22,31 +19,29 @@ class DefaultWhitelist : CordaPluginRegistry() {
Notification::class.java,
Notification.Kind::class.java,
ArrayList::class.java,
listOf<Any>().javaClass, // EmptyList
Pair::class.java,
ByteArray::class.java,
UUID::class.java,
LinkedHashSet::class.java,
setOf<Unit>().javaClass, // EmptySet
Currency::class.java,
listOf(Unit).javaClass, // SingletonList
setOf(Unit).javaClass, // SingletonSet
mapOf(Unit to Unit).javaClass, // SingletonSet
mapOf(Unit to Unit).javaClass, // SingletonMap
NetworkHostAndPort::class.java,
SimpleString::class.java,
KryoException::class.java,
KryoException::class.java, // TODO: Will be removed when we migrate away from Kryo
StringBuffer::class.java,
Unit::class.java,
java.io.ByteArrayInputStream::class.java,
java.lang.Class::class.java,
java.math.BigDecimal::class.java,
java.security.KeyPair::class.java,
// Matches the list in TimeSerializers.addDefaultSerializers:
java.time.Duration::class.java,
java.time.Instant::class.java,
java.time.LocalDate::class.java,
java.time.LocalDateTime::class.java,
java.time.LocalTime::class.java,
java.time.ZoneOffset::class.java,
java.time.ZoneId::class.java,
java.time.OffsetTime::class.java,
@ -57,12 +52,12 @@ class DefaultWhitelist : CordaPluginRegistry() {
java.time.MonthDay::class.java,
java.time.Period::class.java,
java.time.DayOfWeek::class.java, // No custom serialiser but it's an enum.
java.time.Month::class.java, // No custom serialiser but it's an enum.
java.util.Collections.singletonMap("A", "B").javaClass,
java.util.Collections.emptyMap<Any, Any>().javaClass,
java.util.Collections.emptySet<Any>().javaClass,
java.util.Collections.emptyList<Any>().javaClass,
java.util.LinkedHashMap::class.java,
BigDecimal::class.java,
LocalDate::class.java,
Period::class.java,
BitSet::class.java,
OnErrorNotImplementedException::class.java,
StackTraceElement::class.java)

View File

@ -0,0 +1,8 @@
package net.corda.nodeapi.internal.serialization
import net.corda.core.crypto.sha256
import net.corda.core.internal.AbstractAttachment
class GeneratedAttachment(bytes: ByteArray) : AbstractAttachment({ bytes }) {
override val id = bytes.sha256()
}

View File

@ -7,16 +7,19 @@ import com.esotericsoftware.kryo.io.Output
import com.esotericsoftware.kryo.serializers.CompatibleFieldSerializer
import com.esotericsoftware.kryo.serializers.FieldSerializer
import com.esotericsoftware.kryo.util.MapReferenceResolver
import net.corda.core.contracts.*
import net.corda.core.concurrent.CordaFuture
import net.corda.core.contracts.PrivacySalt
import net.corda.core.contracts.StateRef
import net.corda.core.crypto.CompositeKey
import net.corda.core.crypto.Crypto
import net.corda.core.crypto.SecureHash
import net.corda.core.crypto.TransactionSignature
import net.corda.core.identity.Party
import net.corda.core.serialization.AttachmentsClassLoader
import net.corda.core.serialization.MissingAttachmentsException
import net.corda.core.serialization.SerializationContext
import net.corda.core.serialization.SerializeAsTokenContext
import net.corda.core.serialization.SerializedBytes
import net.corda.core.toFuture
import net.corda.core.toObservable
import net.corda.core.transactions.*
import net.corda.core.transactions.CoreTransaction
import net.corda.core.transactions.NotaryChangeWireTransaction
import net.corda.core.transactions.SignedTransaction
@ -32,6 +35,7 @@ import org.bouncycastle.asn1.x500.X500Name
import org.bouncycastle.cert.X509CertificateHolder
import org.slf4j.Logger
import org.slf4j.LoggerFactory
import rx.Observable
import sun.security.ec.ECPublicKeyImpl
import sun.security.util.DerValue
import java.io.ByteArrayInputStream
@ -245,42 +249,15 @@ fun Input.readBytesWithLength(): ByteArray {
@ThreadSafe
object WireTransactionSerializer : Serializer<WireTransaction>() {
override fun write(kryo: Kryo, output: Output, obj: WireTransaction) {
kryo.writeClassAndObject(output, obj.inputs)
kryo.writeClassAndObject(output, obj.attachments)
kryo.writeClassAndObject(output, obj.outputs)
kryo.writeClassAndObject(output, obj.commands)
kryo.writeClassAndObject(output, obj.notary)
kryo.writeClassAndObject(output, obj.timeWindow)
kryo.writeClassAndObject(output, obj.componentGroups)
kryo.writeClassAndObject(output, obj.privacySalt)
}
private fun attachmentsClassLoader(kryo: Kryo, attachmentHashes: List<SecureHash>): ClassLoader? {
kryo.context[attachmentsClassLoaderEnabledPropertyName] as? Boolean ?: false || return null
val serializationContext = kryo.serializationContext() ?: return null // Some tests don't set one.
val missing = ArrayList<SecureHash>()
val attachments = ArrayList<Attachment>()
attachmentHashes.forEach { id ->
serializationContext.serviceHub.attachments.openAttachment(id)?.let { attachments += it } ?: run { missing += id }
}
missing.isNotEmpty() && throw MissingAttachmentsException(missing)
return AttachmentsClassLoader(attachments)
}
@Suppress("UNCHECKED_CAST")
override fun read(kryo: Kryo, input: Input, type: Class<WireTransaction>): WireTransaction {
val inputs = kryo.readClassAndObject(input) as List<StateRef>
val attachmentHashes = kryo.readClassAndObject(input) as List<SecureHash>
// If we're deserialising in the sandbox context, we use our special attachments classloader.
// Otherwise we just assume the code we need is on the classpath already.
kryo.useClassLoader(attachmentsClassLoader(kryo, attachmentHashes) ?: javaClass.classLoader) {
val outputs = kryo.readClassAndObject(input) as List<TransactionState<ContractState>>
val commands = kryo.readClassAndObject(input) as List<Command<*>>
val notary = kryo.readClassAndObject(input) as Party?
val timeWindow = kryo.readClassAndObject(input) as TimeWindow?
val privacySalt = kryo.readClassAndObject(input) as PrivacySalt
return WireTransaction(inputs, attachmentHashes, outputs, commands, notary, timeWindow, privacySalt)
}
val componentGroups = kryo.readClassAndObject(input) as List<ComponentGroup>
val privacySalt = kryo.readClassAndObject(input) as PrivacySalt
return WireTransaction(componentGroups, privacySalt)
}
}
@ -414,8 +391,7 @@ inline fun <reified T> readListOfLength(kryo: Kryo, input: Input, minLen: Int =
if (elemCount < minLen) throw KryoException("Cannot deserialize list, too little elements. Minimum required: $minLen, got: $elemCount")
if (expectedLen != null && elemCount != expectedLen)
throw KryoException("Cannot deserialize list, expected length: $expectedLen, got: $elemCount.")
val list = (1..elemCount).map { kryo.readClassAndObject(input) as T }
return list
return (1..elemCount).map { kryo.readClassAndObject(input) as T }
}
/**
@ -460,6 +436,44 @@ open class CordaKryo(classResolver: ClassResolver) : Kryo(classResolver, MapRefe
}
}
/**
* The Kryo used for the RPC wire protocol.
*/
// Every type in the wire protocol is listed here explicitly.
// This is annoying to write out, but will make it easier to formalise the wire protocol when the time comes,
// because we can see everything we're using in one place.
class RPCKryo(observableSerializer: Serializer<Observable<*>>, serializationContext: SerializationContext) : CordaKryo(CordaClassResolver(serializationContext)) {
init {
DefaultKryoCustomizer.customize(this)
// RPC specific classes
register(InputStream::class.java, InputStreamSerializer)
register(Observable::class.java, observableSerializer)
register(CordaFuture::class,
read = { kryo, input -> observableSerializer.read(kryo, input, Observable::class.java).toFuture() },
write = { kryo, output, obj -> observableSerializer.write(kryo, output, obj.toObservable()) }
)
}
override fun getRegistration(type: Class<*>): Registration {
if (Observable::class.java != type && Observable::class.java.isAssignableFrom(type)) {
return super.getRegistration(Observable::class.java)
}
if (InputStream::class.java != type && InputStream::class.java.isAssignableFrom(type)) {
return super.getRegistration(InputStream::class.java)
}
if (CordaFuture::class.java != type && CordaFuture::class.java.isAssignableFrom(type)) {
return super.getRegistration(CordaFuture::class.java)
}
type.requireExternal("RPC not allowed to deserialise internal classes")
return super.getRegistration(type)
}
private fun Class<*>.requireExternal(msg: String) {
require(!name.startsWith("net.corda.node.") && ".internal" !in name) { "$msg: $name" }
}
}
inline fun <T : Any> Kryo.register(
type: KClass<T>,
crossinline read: (Kryo, Input) -> T,
@ -482,7 +496,7 @@ inline fun <reified T : Any> Kryo.noReferencesWithin() {
register(T::class.java, NoReferencesSerializer(getSerializer(T::class.java)))
}
class NoReferencesSerializer<T>(val baseSerializer: Serializer<T>) : Serializer<T>() {
class NoReferencesSerializer<T>(private val baseSerializer: Serializer<T>) : Serializer<T>() {
override fun read(kryo: Kryo, input: Input, type: Class<T>): T {
return kryo.withoutReferences { baseSerializer.read(kryo, input, type) }
@ -610,4 +624,4 @@ class ThrowableSerializer<T>(kryo: Kryo, type: Class<T>) : Serializer<Throwable>
}
private fun Throwable.setSuppressedToSentinel() = suppressedField.set(this, sentinelValue)
}
}

View File

@ -5,6 +5,13 @@ import net.corda.core.serialization.SerializationCustomization
class KryoSerializationCustomization(val kryo: Kryo) : SerializationCustomization {
override fun addToWhitelist(vararg types: Class<*>) {
kryo.addToWhitelist(*types)
require(types.toSet().size == types.size) {
val duplicates = types.toMutableList()
types.toSet().forEach { duplicates -= it }
"Cannot add duplicate classes to the whitelist ($duplicates)."
}
for (type in types) {
((kryo.classResolver as? CordaClassResolver)?.whitelist as? MutableClassWhitelist)?.add(type)
}
}
}

View File

@ -8,6 +8,7 @@ import com.esotericsoftware.kryo.Serializer
import com.esotericsoftware.kryo.io.Input
import com.esotericsoftware.kryo.io.Output
import com.esotericsoftware.kryo.pool.KryoPool
import com.esotericsoftware.kryo.serializers.ClosureSerializer
import com.google.common.cache.Cache
import com.google.common.cache.CacheBuilder
import net.corda.core.contracts.Attachment
@ -16,6 +17,7 @@ import net.corda.core.internal.LazyPool
import net.corda.core.serialization.*
import net.corda.core.utilities.ByteSequence
import net.corda.core.utilities.OpaqueBytes
import net.corda.nodeapi.internal.AttachmentsClassLoader
import java.io.ByteArrayOutputStream
import java.io.NotSerializableException
import java.util.*
@ -59,7 +61,7 @@ data class SerializationContextImpl(override val preferredSerializationVersion:
serializationContext.serviceHub.attachments.openAttachment(id)?.let { attachments += it } ?: run { missing += id }
}
missing.isNotEmpty() && throw MissingAttachmentsException(missing)
AttachmentsClassLoader(attachments)
AttachmentsClassLoader(attachments, parent = deserializationClassLoader)
})
} catch (e: ExecutionException) {
// Caught from within the cache get, so unwrap.
@ -166,6 +168,7 @@ abstract class AbstractKryoSerializationScheme : SerializationScheme {
field.set(this, classResolver)
DefaultKryoCustomizer.customize(this)
addDefaultSerializer(AutoCloseable::class.java, AutoCloseableSerialisationDetector)
register(ClosureSerializer.Closure::class.java, CordaClosureSerializer)
classLoader = it.second
}
}.build()

View File

@ -61,13 +61,13 @@ class CollectionSerializer(val declaredType: ParameterizedType, factory: Seriali
private val typeNotation: TypeNotation = RestrictedType(SerializerFactory.nameForType(declaredType), null, emptyList(), "list", Descriptor(typeDescriptor), emptyList())
override fun writeClassInfo(output: SerializationOutput) {
override fun writeClassInfo(output: SerializationOutput) = ifThrowsAppend({declaredType.typeName}) {
if (output.writeTypeNotations(typeNotation)) {
output.requireSerializer(declaredType.actualTypeArguments[0])
}
}
override fun writeObject(obj: Any, data: Data, type: Type, output: SerializationOutput) {
override fun writeObject(obj: Any, data: Data, type: Type, output: SerializationOutput) = ifThrowsAppend({declaredType.typeName}) {
// Write described
data.withDescribed(typeNotation.descriptor) {
withList {
@ -78,8 +78,8 @@ class CollectionSerializer(val declaredType: ParameterizedType, factory: Seriali
}
}
override fun readObject(obj: Any, schema: Schema, input: DeserializationInput): Any {
override fun readObject(obj: Any, schema: Schema, input: DeserializationInput): Any = ifThrowsAppend({declaredType.typeName}) {
// TODO: Can we verify the entries in the list?
return concreteBuilder((obj as List<*>).map { input.readObjectOrNull(it, schema, declaredType.actualTypeArguments[0]) })
concreteBuilder((obj as List<*>).map { input.readObjectOrNull(it, schema, declaredType.actualTypeArguments[0]) })
}
}

View File

@ -29,6 +29,11 @@ abstract class CustomSerializer<T> : AMQPSerializer<T> {
*/
abstract val schemaForDocumentation: Schema
/**
* Whether subclasses using this serializer via inheritance should have a mapping in the schema.
*/
open val revealSubclassesInSchema: Boolean = false
override fun writeObject(obj: Any, data: Data, type: Type, output: SerializationOutput) {
data.withDescribed(descriptor) {
@Suppress("UNCHECKED_CAST")

View File

@ -86,7 +86,7 @@ class DeserializedParameterizedType(private val rawType: Class<*>, private val p
if (pos == typeStart) {
skippingWhitespace = false
if (params[pos].isWhitespace()) {
typeStart = pos++
typeStart = ++pos
} else if (!needAType) {
throw NotSerializableException("Not expecting a type")
} else if (params[pos] == '?') {

View File

@ -11,7 +11,7 @@ import java.lang.reflect.Type
*/
class EnumSerializer(declaredType: Type, declaredClass: Class<*>, factory: SerializerFactory) : AMQPSerializer<Any> {
override val type: Type = declaredType
override val typeDescriptor = Symbol.valueOf("$DESCRIPTOR_DOMAIN:${fingerprintForType(type, factory)}")
override val typeDescriptor = Symbol.valueOf("$DESCRIPTOR_DOMAIN:${fingerprintForType(type, factory)}")!!
private val typeNotation: TypeNotation
init {

View File

@ -18,7 +18,7 @@ private typealias MapCreationFunction = (Map<*, *>) -> Map<*, *>
*/
class MapSerializer(private val declaredType: ParameterizedType, factory: SerializerFactory) : AMQPSerializer<Any> {
override val type: Type = declaredType as? DeserializedParameterizedType ?: DeserializedParameterizedType.make(SerializerFactory.nameForType(declaredType))
override val typeDescriptor = Symbol.valueOf("$DESCRIPTOR_DOMAIN:${fingerprintForType(type, factory)}")
override val typeDescriptor: Symbol = Symbol.valueOf("$DESCRIPTOR_DOMAIN:${fingerprintForType(type, factory)}")
companion object {
// NB: Order matters in this map, the most specific classes should be listed at the end
@ -29,7 +29,11 @@ class MapSerializer(private val declaredType: ParameterizedType, factory: Serial
NavigableMap::class.java to { map -> Collections.unmodifiableNavigableMap(TreeMap(map)) },
// concrete classes for user convenience
LinkedHashMap::class.java to { map -> LinkedHashMap(map) },
TreeMap::class.java to { map -> TreeMap(map) }
TreeMap::class.java to { map -> TreeMap(map) },
EnumMap::class.java to { map ->
@Suppress("UNCHECKED_CAST")
EnumMap(map as Map<EnumJustUsedForCasting, Any>)
}
))
private fun findConcreteType(clazz: Class<*>): MapCreationFunction {
@ -37,6 +41,7 @@ class MapSerializer(private val declaredType: ParameterizedType, factory: Serial
}
fun deriveParameterizedType(declaredType: Type, declaredClass: Class<*>, actualClass: Class<*>?): ParameterizedType {
declaredClass.checkSupportedMapType()
if(supportedTypes.containsKey(declaredClass)) {
// Simple case - it is already known to be a map.
@Suppress("UNCHECKED_CAST")
@ -63,15 +68,15 @@ class MapSerializer(private val declaredType: ParameterizedType, factory: Serial
private val typeNotation: TypeNotation = RestrictedType(SerializerFactory.nameForType(declaredType), null, emptyList(), "map", Descriptor(typeDescriptor), emptyList())
override fun writeClassInfo(output: SerializationOutput) {
override fun writeClassInfo(output: SerializationOutput) = ifThrowsAppend({declaredType.typeName}) {
if (output.writeTypeNotations(typeNotation)) {
output.requireSerializer(declaredType.actualTypeArguments[0])
output.requireSerializer(declaredType.actualTypeArguments[1])
}
}
override fun writeObject(obj: Any, data: Data, type: Type, output: SerializationOutput) {
obj.javaClass.checkNotUnsupportedHashMap()
override fun writeObject(obj: Any, data: Data, type: Type, output: SerializationOutput) = ifThrowsAppend({declaredType.typeName}) {
obj.javaClass.checkSupportedMapType()
// Write described
data.withDescribed(typeNotation.descriptor) {
// Write map
@ -85,18 +90,22 @@ class MapSerializer(private val declaredType: ParameterizedType, factory: Serial
}
}
override fun readObject(obj: Any, schema: Schema, input: DeserializationInput): Any {
override fun readObject(obj: Any, schema: Schema, input: DeserializationInput): Any = ifThrowsAppend({declaredType.typeName}) {
// TODO: General generics question. Do we need to validate that entries in Maps and Collections match the generic type? Is it a security hole?
val entries: Iterable<Pair<Any?, Any?>> = (obj as Map<*, *>).map { readEntry(schema, input, it) }
return concreteBuilder(entries.toMap())
concreteBuilder(entries.toMap())
}
private fun readEntry(schema: Schema, input: DeserializationInput, entry: Map.Entry<Any?, Any?>) =
input.readObjectOrNull(entry.key, schema, declaredType.actualTypeArguments[0]) to
input.readObjectOrNull(entry.value, schema, declaredType.actualTypeArguments[1])
// Cannot use * as a bound for EnumMap and EnumSet since * is not an enum. So, we use a sample enum instead.
// We don't actually care about the type, we just need to make the compiler happier.
internal enum class EnumJustUsedForCasting { NOT_USED }
}
internal fun Class<*>.checkNotUnsupportedHashMap() {
internal fun Class<*>.checkSupportedMapType() {
if (HashMap::class.java.isAssignableFrom(this) && !LinkedHashMap::class.java.isAssignableFrom(this)) {
throw IllegalArgumentException(
"Map type $this is unstable under iteration. Suggested fix: use java.util.LinkedHashMap instead.")

View File

@ -41,7 +41,7 @@ open class ObjectSerializer(val clazz: Type, factory: SerializerFactory) : AMQPS
}
}
override fun writeObject(obj: Any, data: Data, type: Type, output: SerializationOutput) {
override fun writeObject(obj: Any, data: Data, type: Type, output: SerializationOutput) = ifThrowsAppend({clazz.typeName}) {
// Write described
data.withDescribed(typeNotation.descriptor) {
// Write list
@ -53,11 +53,11 @@ open class ObjectSerializer(val clazz: Type, factory: SerializerFactory) : AMQPS
}
}
override fun readObject(obj: Any, schema: Schema, input: DeserializationInput): Any {
override fun readObject(obj: Any, schema: Schema, input: DeserializationInput): Any = ifThrowsAppend({clazz.typeName}) {
if (obj is List<*>) {
if (obj.size > propertySerializers.size) throw NotSerializableException("Too many properties in described type $typeName")
val params = obj.zip(propertySerializers).map { it.second.readProperty(it.first, schema, input) }
return construct(params)
construct(params)
} else throw NotSerializableException("Body of described type is unexpected $obj")
}
@ -74,5 +74,4 @@ open class ObjectSerializer(val clazz: Type, factory: SerializerFactory) : AMQPS
return javaConstructor?.newInstance(*properties.toTypedArray()) ?:
throw NotSerializableException("Attempt to deserialize an interface: $clazz. Serialized form is invalid.")
}
}

View File

@ -76,19 +76,21 @@ sealed class PropertySerializer(val name: String, val readMethod: Method?, val r
// This is lazy so we don't get an infinite loop when a method returns an instance of the class.
private val typeSerializer: AMQPSerializer<*> by lazy { lazyTypeSerializer() }
override fun writeClassInfo(output: SerializationOutput) {
override fun writeClassInfo(output: SerializationOutput) = ifThrowsAppend({nameForDebug}) {
if (resolvedType != Any::class.java) {
typeSerializer.writeClassInfo(output)
}
}
override fun readProperty(obj: Any?, schema: Schema, input: DeserializationInput): Any? {
return input.readObjectOrNull(obj, schema, resolvedType)
override fun readProperty(obj: Any?, schema: Schema, input: DeserializationInput): Any? = ifThrowsAppend({nameForDebug}) {
input.readObjectOrNull(obj, schema, resolvedType)
}
override fun writeProperty(obj: Any?, data: Data, output: SerializationOutput) {
override fun writeProperty(obj: Any?, data: Data, output: SerializationOutput) = ifThrowsAppend({nameForDebug}) {
output.writeObjectOrNull(readMethod!!.invoke(obj), data, resolvedType)
}
private val nameForDebug = "$name(${resolvedType.typeName})"
}
/**

View File

@ -18,9 +18,9 @@ import net.corda.nodeapi.internal.serialization.carpenter.Field as CarpenterFiel
import net.corda.nodeapi.internal.serialization.carpenter.Schema as CarpenterSchema
// TODO: get an assigned number as per AMQP spec
val DESCRIPTOR_TOP_32BITS: Long = 0xc0da0000
const val DESCRIPTOR_TOP_32BITS: Long = 0xc0da0000
val DESCRIPTOR_DOMAIN: String = "net.corda"
const val DESCRIPTOR_DOMAIN: String = "net.corda"
// "corda" + majorVersionByte + minorVersionMSB + minorVersionLSB
val AmqpHeaderV1_0: OpaqueBytes = OpaqueBytes("corda\u0001\u0000\u0000".toByteArray())
@ -468,7 +468,8 @@ private fun fingerprintForType(type: Type, contextType: Type?, alreadySeen: Muta
}
}
private fun isCollectionOrMap(type: Class<*>) = Collection::class.java.isAssignableFrom(type) || Map::class.java.isAssignableFrom(type)
private fun isCollectionOrMap(type: Class<*>) = (Collection::class.java.isAssignableFrom(type) || Map::class.java.isAssignableFrom(type)) &&
!EnumSet::class.java.isAssignableFrom(type)
private fun fingerprintForObject(type: Type, contextType: Type?, alreadySeen: MutableSet<Type>, hasher: Hasher, factory: SerializerFactory): Hasher {
// Hash the class + properties + interfaces

View File

@ -2,6 +2,7 @@ package net.corda.nodeapi.internal.serialization.amqp
import com.google.common.primitives.Primitives
import com.google.common.reflect.TypeToken
import net.corda.core.serialization.SerializationContext
import org.apache.qpid.proton.codec.Data
import java.beans.IndexedPropertyDescriptor
import java.beans.Introspector
@ -229,4 +230,34 @@ internal fun Type.isSubClassOf(type: Type): Boolean {
internal fun suitableForObjectReference(type: Type): Boolean {
val clazz = type.asClass()
return type != ByteArray::class.java && (clazz != null && !clazz.isPrimitive && !Primitives.unwrap(clazz).isPrimitive)
}
}
/**
* Common properties that are to be used in the [SerializationContext.properties] to alter serialization behavior/content
*/
internal enum class CommonPropertyNames {
IncludeInternalInfo,
}
/**
* Utility function which helps tracking the path in the object graph when exceptions are thrown.
* Since there might be a chain of nested calls it is useful to record which part of the graph caused an issue.
* Path information is added to the message of the exception being thrown.
*/
internal inline fun <T> ifThrowsAppend(strToAppendFn: () -> String, block: () -> T): T {
try {
return block()
} catch (th: Throwable) {
th.setMessage("${strToAppendFn()} -> ${th.message}")
throw th
}
}
/**
* Not a public property so will have to use reflection
*/
private fun Throwable.setMessage(newMsg: String) {
val detailMessageField = Throwable::class.java.getDeclaredField("detailMessage")
detailMessageField.isAccessible = true
detailMessageField.set(this, newMsg)
}

View File

@ -13,7 +13,7 @@ import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.CopyOnWriteArrayList
import javax.annotation.concurrent.ThreadSafe
data class schemaAndDescriptor(val schema: Schema, val typeDescriptor: Any)
data class FactorySchemaAndDescriptor(val schema: Schema, val typeDescriptor: Any)
/**
* Factory of serializers designed to be shared across threads and invocations.
@ -22,8 +22,8 @@ data class schemaAndDescriptor(val schema: Schema, val typeDescriptor: Any)
// TODO: maybe support for caching of serialized form of some core types for performance
// TODO: profile for performance in general
// TODO: use guava caches etc so not unbounded
// TODO: do we need to support a transient annotation to exclude certain properties?
// TODO: allow definition of well known types that are left out of the schema.
// TODO: migrate some core types to unsigned integer descriptor
// TODO: document and alert to the fact that classes cannot default superclass/interface properties otherwise they are "erased" due to matching with constructor.
// TODO: type name prefixes for interfaces and abstract classes? Or use label?
// TODO: generic types should define restricted type alias with source of the wildcarded version, I think, if we're to generate classes from schema
@ -34,7 +34,7 @@ class SerializerFactory(val whitelist: ClassWhitelist, cl: ClassLoader) {
private val serializersByType = ConcurrentHashMap<Type, AMQPSerializer<Any>>()
private val serializersByDescriptor = ConcurrentHashMap<Any, AMQPSerializer<Any>>()
private val customSerializers = CopyOnWriteArrayList<CustomSerializer<out Any>>()
private val classCarpenter = ClassCarpenter(cl)
val classCarpenter = ClassCarpenter(cl, whitelist)
val classloader: ClassLoader
get() = classCarpenter.classloader
@ -61,25 +61,26 @@ class SerializerFactory(val whitelist: ClassWhitelist, cl: ClassLoader) {
val actualType: Type = inferTypeVariables(actualClass, declaredClass, declaredType) ?: declaredType
val serializer = when {
// Declared class may not be set to Collection, but actual class could be a collection.
// In this case use of CollectionSerializer is perfectly appropriate.
// Declared class may not be set to Collection, but actual class could be a collection.
// In this case use of CollectionSerializer is perfectly appropriate.
(Collection::class.java.isAssignableFrom(declaredClass) ||
(actualClass != null && Collection::class.java.isAssignableFrom(actualClass))) -> {
val declaredTypeAmended= CollectionSerializer.deriveParameterizedType(declaredType, declaredClass, actualClass)
serializersByType.computeIfAbsent(declaredTypeAmended) {
CollectionSerializer(declaredTypeAmended, this)
}
(actualClass != null && Collection::class.java.isAssignableFrom(actualClass))) &&
!EnumSet::class.java.isAssignableFrom(actualClass ?: declaredClass) -> {
val declaredTypeAmended = CollectionSerializer.deriveParameterizedType(declaredType, declaredClass, actualClass)
serializersByType.computeIfAbsent(declaredTypeAmended) {
CollectionSerializer(declaredTypeAmended, this)
}
}
// Declared class may not be set to Map, but actual class could be a map.
// In this case use of MapSerializer is perfectly appropriate.
// Declared class may not be set to Map, but actual class could be a map.
// In this case use of MapSerializer is perfectly appropriate.
(Map::class.java.isAssignableFrom(declaredClass) ||
(actualClass != null && Map::class.java.isAssignableFrom(actualClass))) -> {
val declaredTypeAmended= MapSerializer.deriveParameterizedType(declaredType, declaredClass, actualClass)
serializersByType.computeIfAbsent(declaredClass) {
makeMapSerializer(declaredTypeAmended)
}
(actualClass != null && Map::class.java.isAssignableFrom(actualClass))) -> {
val declaredTypeAmended = MapSerializer.deriveParameterizedType(declaredType, declaredClass, actualClass)
serializersByType.computeIfAbsent(declaredTypeAmended) {
makeMapSerializer(declaredTypeAmended)
}
}
Enum::class.java.isAssignableFrom(declaredClass) -> serializersByType.computeIfAbsent(declaredClass) {
Enum::class.java.isAssignableFrom(actualClass ?: declaredClass) -> serializersByType.computeIfAbsent(actualClass ?: declaredClass) {
EnumSerializer(actualType, actualClass ?: declaredClass, this)
}
else -> makeClassSerializer(actualClass ?: declaredClass, actualType, declaredType)
@ -98,7 +99,7 @@ class SerializerFactory(val whitelist: ClassWhitelist, cl: ClassLoader) {
private fun inferTypeVariables(actualClass: Class<*>?, declaredClass: Class<*>, declaredType: Type): Type? =
when (declaredType) {
is ParameterizedType -> inferTypeVariables(actualClass, declaredClass, declaredType)
// Nothing to infer, otherwise we'd have ParameterizedType
// Nothing to infer, otherwise we'd have ParameterizedType
is Class<*> -> actualClass
is GenericArrayType -> {
val declaredComponent = declaredType.genericComponentType
@ -164,7 +165,7 @@ class SerializerFactory(val whitelist: ClassWhitelist, cl: ClassLoader) {
@Throws(NotSerializableException::class)
fun get(typeDescriptor: Any, schema: Schema): AMQPSerializer<Any> {
return serializersByDescriptor[typeDescriptor] ?: {
processSchema(schemaAndDescriptor(schema, typeDescriptor))
processSchema(FactorySchemaAndDescriptor(schema, typeDescriptor))
serializersByDescriptor[typeDescriptor] ?: throw NotSerializableException(
"Could not find type matching descriptor $typeDescriptor.")
}()
@ -188,8 +189,8 @@ class SerializerFactory(val whitelist: ClassWhitelist, cl: ClassLoader) {
* Iterate over an AMQP schema, for each type ascertain weather it's on ClassPath of [classloader] amd
* if not use the [ClassCarpenter] to generate a class to use in it's place
*/
private fun processSchema(schema: schemaAndDescriptor, sentinel: Boolean = false) {
val carpenterSchemas = CarpenterSchemas.newInstance()
private fun processSchema(schema: FactorySchemaAndDescriptor, sentinel: Boolean = false) {
val metaSchema = CarpenterMetaSchema.newInstance()
for (typeNotation in schema.schema.types) {
try {
val serialiser = processSchemaEntry(typeNotation)
@ -201,13 +202,13 @@ class SerializerFactory(val whitelist: ClassWhitelist, cl: ClassLoader) {
getEvolutionSerializer(typeNotation, serialiser)
}
} catch (e: ClassNotFoundException) {
if (sentinel || (typeNotation !is CompositeType)) throw e
typeNotation.carpenterSchema(classloader, carpenterSchemas = carpenterSchemas)
if (sentinel) throw e
metaSchema.buildFor(typeNotation, classloader)
}
}
if (carpenterSchemas.isNotEmpty()) {
val mc = MetaCarpenter(carpenterSchemas, classCarpenter)
if (metaSchema.isNotEmpty()) {
val mc = MetaCarpenter(metaSchema, classCarpenter)
mc.build()
processSchema(schema, true)
}
@ -234,8 +235,7 @@ class SerializerFactory(val whitelist: ClassWhitelist, cl: ClassLoader) {
} else {
findCustomSerializer(clazz, declaredType) ?: run {
if (type.isArray()) {
// Allow Object[] since this can be quite common (i.e. an untyped array)
if (type.componentType() != Object::class.java) whitelisted(type.componentType())
// Don't need to check the whitelist since each element will come back through the whitelisting process.
if (clazz.componentType.isPrimitive) PrimArraySerializer.make(type, this)
else ArraySerializer.make(type, this)
} else if (clazz.kotlin.objectInstance != null) {
@ -256,7 +256,7 @@ class SerializerFactory(val whitelist: ClassWhitelist, cl: ClassLoader) {
for (customSerializer in customSerializers) {
if (customSerializer.isSerializerFor(clazz)) {
val declaredSuperClass = declaredType.asClass()?.superclass
if (declaredSuperClass == null || !customSerializer.isSerializerFor(declaredSuperClass)) {
if (declaredSuperClass == null || !customSerializer.isSerializerFor(declaredSuperClass) || !customSerializer.revealSubclassesInSchema) {
return customSerializer
} else {
// Make a subclass serializer for the subclass and return that...
@ -288,7 +288,7 @@ class SerializerFactory(val whitelist: ClassWhitelist, cl: ClassLoader) {
private fun makeMapSerializer(declaredType: ParameterizedType): AMQPSerializer<Any> {
val rawType = declaredType.rawType as Class<*>
rawType.checkNotUnsupportedHashMap()
rawType.checkSupportedMapType()
return MapSerializer(declaredType, this)
}

View File

@ -0,0 +1,17 @@
package net.corda.nodeapi.internal.serialization.amqp.custom
import net.corda.nodeapi.internal.serialization.amqp.CustomSerializer
import net.corda.nodeapi.internal.serialization.amqp.SerializerFactory
import java.util.*
/**
* A serializer that writes out a [BitSet] as an integer number of bits, plus the necessary number of bytes to encode that
* many bits.
*/
class BitSetSerializer(factory: SerializerFactory) : CustomSerializer.Proxy<BitSet, BitSetSerializer.BitSetProxy>(BitSet::class.java, BitSetProxy::class.java, factory) {
override fun toProxy(obj: BitSet): BitSetProxy = BitSetProxy(obj.toByteArray())
override fun fromProxy(proxy: BitSetProxy): BitSet = BitSet.valueOf(proxy.bytes)
data class BitSetProxy(val bytes: ByteArray)
}

View File

@ -0,0 +1,34 @@
package net.corda.nodeapi.internal.serialization.amqp.custom
import net.corda.nodeapi.internal.serialization.amqp.CustomSerializer
import net.corda.nodeapi.internal.serialization.amqp.MapSerializer
import net.corda.nodeapi.internal.serialization.amqp.SerializerFactory
import java.util.*
@Suppress("UNCHECKED_CAST")
/**
* A serializer that writes out an [EnumSet] as a type, plus list of instances in the set.
*/
class EnumSetSerializer(factory: SerializerFactory) : CustomSerializer.Proxy<EnumSet<*>, EnumSetSerializer.EnumSetProxy>(EnumSet::class.java, EnumSetProxy::class.java, factory) {
override val additionalSerializers: Iterable<CustomSerializer<out Any>> = listOf(ClassSerializer(factory))
override fun toProxy(obj: EnumSet<*>): EnumSetProxy = EnumSetProxy(elementType(obj), obj.toList())
private fun elementType(set: EnumSet<*>): Class<*> {
return if (set.isEmpty()) {
EnumSet.complementOf(set as EnumSet<MapSerializer.EnumJustUsedForCasting>).first().javaClass
} else {
set.first().javaClass
}
}
override fun fromProxy(proxy: EnumSetProxy): EnumSet<*> {
return if (proxy.elements.isEmpty()) {
EnumSet.noneOf(proxy.clazz as Class<MapSerializer.EnumJustUsedForCasting>)
} else {
EnumSet.copyOf(proxy.elements as List<MapSerializer.EnumJustUsedForCasting>)
}
}
data class EnumSetProxy(val clazz: Class<*>, val elements: List<Any>)
}

View File

@ -0,0 +1,41 @@
package net.corda.nodeapi.internal.serialization.amqp.custom
import net.corda.nodeapi.internal.serialization.amqp.*
import org.apache.qpid.proton.amqp.Binary
import org.apache.qpid.proton.codec.Data
import java.io.ByteArrayInputStream
import java.io.InputStream
import java.lang.reflect.Type
/**
* A serializer that writes out the content of an input stream as bytes and deserializes into a [ByteArrayInputStream].
*/
object InputStreamSerializer : CustomSerializer.Implements<InputStream>(InputStream::class.java) {
override val revealSubclassesInSchema: Boolean = true
override val schemaForDocumentation = Schema(listOf(RestrictedType(type.toString(), "", listOf(type.toString()), SerializerFactory.primitiveTypeName(ByteArray::class.java)!!, descriptor, emptyList())))
override fun writeDescribedObject(obj: InputStream, data: Data, type: Type, output: SerializationOutput) {
val startingSize = maxOf(4096, obj.available() + 1)
var buffer = ByteArray(startingSize)
var pos = 0
while (true) {
val numberOfBytesRead = obj.read(buffer, pos, buffer.size - pos)
if (numberOfBytesRead != -1) {
pos += numberOfBytesRead
// If the buffer is now full, resize it.
if (pos == buffer.size) {
buffer = buffer.copyOf(buffer.size + maxOf(4096, obj.available() + 1))
}
} else {
data.putBinary(Binary(buffer, 0, pos))
break
}
}
}
override fun readObject(obj: Any, schema: Schema, input: DeserializationInput): InputStream {
val bits = input.readObject(obj, schema, ByteArray::class.java) as ByteArray
return ByteArrayInputStream(bits)
}
}

View File

@ -0,0 +1,9 @@
package net.corda.nodeapi.internal.serialization.amqp.custom
import net.corda.nodeapi.internal.serialization.amqp.CustomSerializer
import org.apache.activemq.artemis.api.core.SimpleString
/**
* A serializer for [SimpleString].
*/
object SimpleStringSerializer : CustomSerializer.ToString<SimpleString>(SimpleString::class.java)

View File

@ -0,0 +1,8 @@
package net.corda.nodeapi.internal.serialization.amqp.custom
import net.corda.nodeapi.internal.serialization.amqp.CustomSerializer
/**
* A serializer for [StringBuffer].
*/
object StringBufferSerializer : CustomSerializer.ToString<StringBuffer>(StringBuffer::class.java)

View File

@ -2,11 +2,9 @@ package net.corda.nodeapi.internal.serialization.amqp.custom
import net.corda.core.CordaRuntimeException
import net.corda.core.CordaThrowable
import net.corda.core.serialization.SerializationFactory
import net.corda.core.utilities.loggerFor
import net.corda.nodeapi.internal.serialization.amqp.CustomSerializer
import net.corda.nodeapi.internal.serialization.amqp.SerializerFactory
import net.corda.nodeapi.internal.serialization.amqp.constructorForDeserialization
import net.corda.nodeapi.internal.serialization.amqp.propertiesForSerialization
import net.corda.nodeapi.internal.serialization.amqp.*
import java.io.NotSerializableException
class ThrowableSerializer(factory: SerializerFactory) : CustomSerializer.Proxy<Throwable, ThrowableSerializer.ThrowableProxy>(Throwable::class.java, ThrowableProxy::class.java, factory) {
@ -15,6 +13,8 @@ class ThrowableSerializer(factory: SerializerFactory) : CustomSerializer.Proxy<T
private val logger = loggerFor<ThrowableSerializer>()
}
override val revealSubclassesInSchema: Boolean = true
override val additionalSerializers: Iterable<CustomSerializer<out Any>> = listOf(StackTraceElementSerializer(factory))
override fun toProxy(obj: Throwable): ThrowableProxy {
@ -28,12 +28,20 @@ class ThrowableSerializer(factory: SerializerFactory) : CustomSerializer.Proxy<T
extraProperties[prop.name] = prop.readMethod!!.invoke(obj)
}
} catch(e: NotSerializableException) {
logger.warn("Unexpected exception", e)
}
obj.originalMessage
} else {
obj.message
}
return ThrowableProxy(obj.javaClass.name, message, obj.stackTrace, obj.cause, obj.suppressed, extraProperties)
val stackTraceToInclude = if (shouldIncludeInternalInfo()) obj.stackTrace else emptyArray()
return ThrowableProxy(obj.javaClass.name, message, stackTraceToInclude, obj.cause, obj.suppressed, extraProperties)
}
private fun shouldIncludeInternalInfo(): Boolean {
val currentContext = SerializationFactory.currentFactory?.currentContext
val includeInternalInfo = currentContext?.properties?.get(CommonPropertyNames.IncludeInternalInfo)
return true == includeInternalInfo
}
override fun fromProxy(proxy: ThrowableProxy): Throwable {

View File

@ -2,12 +2,14 @@ package net.corda.nodeapi.internal.serialization.amqp.custom
import net.corda.nodeapi.internal.serialization.amqp.CustomSerializer
import net.corda.nodeapi.internal.serialization.amqp.SerializerFactory
import java.time.*
import java.time.ZoneId
/**
* A serializer for [ZoneId] that uses a proxy object to write out the string form.
*/
class ZoneIdSerializer(factory: SerializerFactory) : CustomSerializer.Proxy<ZoneId, ZoneIdSerializer.ZoneIdProxy>(ZoneId::class.java, ZoneIdProxy::class.java, factory) {
override val revealSubclassesInSchema: Boolean = true
override fun toProxy(obj: ZoneId): ZoneIdProxy = ZoneIdProxy(obj.id)
override fun fromProxy(proxy: ZoneIdProxy): ZoneId = ZoneId.of(proxy.id)

View File

@ -1,11 +1,14 @@
@file:JvmName("AMQPSchemaExtensions")
package net.corda.nodeapi.internal.serialization.carpenter
import net.corda.nodeapi.internal.serialization.amqp.CompositeType
import net.corda.nodeapi.internal.serialization.amqp.RestrictedType
import net.corda.nodeapi.internal.serialization.amqp.Field as AMQPField
import net.corda.nodeapi.internal.serialization.amqp.Schema as AMQPSchema
fun AMQPSchema.carpenterSchema(classloader: ClassLoader) : CarpenterSchemas {
val rtn = CarpenterSchemas.newInstance()
fun AMQPSchema.carpenterSchema(classloader: ClassLoader): CarpenterMetaSchema {
val rtn = CarpenterMetaSchema.newInstance()
types.filterIsInstance<CompositeType>().forEach {
it.carpenterSchema(classloader, carpenterSchemas = rtn)
@ -38,7 +41,7 @@ fun AMQPField.typeAsString() = if (type == "*") requires[0] else type
* on the class path. For testing purposes schema generation can be forced
*/
fun CompositeType.carpenterSchema(classloader: ClassLoader,
carpenterSchemas: CarpenterSchemas,
carpenterSchemas: CarpenterMetaSchema,
force: Boolean = false) {
if (classloader.exists(name)) {
validatePropertyTypes(classloader)
@ -83,6 +86,18 @@ fun CompositeType.carpenterSchema(classloader: ClassLoader,
}
}
// This is potentially problematic as we're assuming the only type of restriction we will be
// carpenting for, an enum, but actually trying to split out RestrictedType into something
// more polymorphic is hard. Additionally, to conform to AMQP we're really serialising
// this as a list so...
fun RestrictedType.carpenterSchema(carpenterSchemas: CarpenterMetaSchema) {
val m: MutableMap<String, Field> = mutableMapOf()
choices.forEach { m[it.name] = EnumField() }
carpenterSchemas.carpenterSchemas.add(EnumSchema(name = name, fields = m))
}
// map a pair of (typename, mandatory) to the corresponding class type
// where the mandatory AMQP flag maps to the types nullability
val typeStrToType: Map<Pair<String, Boolean>, Class<out Any?>> = mapOf(

View File

@ -1,5 +1,6 @@
package net.corda.nodeapi.internal.serialization.carpenter
import net.corda.core.serialization.ClassWhitelist
import net.corda.core.serialization.CordaSerializable
import org.objectweb.asm.ClassWriter
import org.objectweb.asm.MethodVisitor
@ -20,11 +21,21 @@ interface SimpleFieldAccess {
operator fun get(name: String): Any?
}
class CarpenterClassLoader (parentClassLoader: ClassLoader = Thread.currentThread().contextClassLoader) :
class CarpenterClassLoader(parentClassLoader: ClassLoader = Thread.currentThread().contextClassLoader) :
ClassLoader(parentClassLoader) {
fun load(name: String, bytes: ByteArray) = defineClass(name, bytes, 0, bytes.size)
}
/**
* Which version of the java runtime are we constructing objects against
*/
private const val TARGET_VERSION = V1_8
private val jlEnum get() = Type.getInternalName(Enum::class.java)
private val jlString get() = Type.getInternalName(String::class.java)
private val jlObject get() = Type.getInternalName(Object::class.java)
private val jlClass get() = Type.getInternalName(Class::class.java)
/**
* A class carpenter generates JVM bytecodes for a class given a schema and then loads it into a sub-classloader.
* The generated classes have getters, a toString method and implement a simple property access interface. The
@ -69,7 +80,8 @@ class CarpenterClassLoader (parentClassLoader: ClassLoader = Thread.currentThrea
*
* Equals/hashCode methods are not yet supported.
*/
class ClassCarpenter(cl: ClassLoader = Thread.currentThread().contextClassLoader) {
class ClassCarpenter(cl: ClassLoader = Thread.currentThread().contextClassLoader,
val whitelist: ClassWhitelist) {
// TODO: Generics.
// TODO: Sandbox the generated code when a security manager is in use.
// TODO: Generate equals/hashCode.
@ -107,49 +119,66 @@ class ClassCarpenter(cl: ClassLoader = Thread.currentThread().contextClassLoader
when (it) {
is InterfaceSchema -> generateInterface(it)
is ClassSchema -> generateClass(it)
is EnumSchema -> generateEnum(it)
}
}
assert (schema.name in _loaded)
assert(schema.name in _loaded)
return _loaded[schema.name]!!
}
private fun generateEnum(enumSchema: Schema): Class<*> {
return generate(enumSchema) { cw, schema ->
cw.apply {
visit(TARGET_VERSION, ACC_PUBLIC + ACC_FINAL + ACC_SUPER + ACC_ENUM, schema.jvmName,
"L$jlEnum<L${schema.jvmName};>;", jlEnum, null)
visitAnnotation(Type.getDescriptor(CordaSerializable::class.java), true).visitEnd()
generateFields(schema)
generateStaticEnumConstructor(schema)
generateEnumConstructor()
generateEnumValues(schema)
generateEnumValueOf(schema)
}.visitEnd()
}
}
private fun generateInterface(interfaceSchema: Schema): Class<*> {
return generate(interfaceSchema) { cw, schema ->
val interfaces = schema.interfaces.map { it.name.jvm }.toTypedArray()
with(cw) {
visit(V1_8, ACC_PUBLIC + ACC_ABSTRACT + ACC_INTERFACE, schema.jvmName, null, "java/lang/Object", interfaces)
cw.apply {
visit(TARGET_VERSION, ACC_PUBLIC + ACC_ABSTRACT + ACC_INTERFACE, schema.jvmName, null,
jlObject, interfaces)
visitAnnotation(Type.getDescriptor(CordaSerializable::class.java), true).visitEnd()
generateAbstractGetters(schema)
visitEnd()
}
}.visitEnd()
}
}
private fun generateClass(classSchema: Schema): Class<*> {
return generate(classSchema) { cw, schema ->
val superName = schema.superclass?.jvmName ?: "java/lang/Object"
val superName = schema.superclass?.jvmName ?: jlObject
val interfaces = schema.interfaces.map { it.name.jvm }.toMutableList()
if (SimpleFieldAccess::class.java !in schema.interfaces) interfaces.add(SimpleFieldAccess::class.java.name.jvm)
if (SimpleFieldAccess::class.java !in schema.interfaces) {
interfaces.add(SimpleFieldAccess::class.java.name.jvm)
}
with(cw) {
visit(V1_8, ACC_PUBLIC + ACC_SUPER, schema.jvmName, null, superName, interfaces.toTypedArray())
cw.apply {
visit(TARGET_VERSION, ACC_PUBLIC + ACC_SUPER, schema.jvmName, null, superName,
interfaces.toTypedArray())
visitAnnotation(Type.getDescriptor(CordaSerializable::class.java), true).visitEnd()
generateFields(schema)
generateConstructor(schema)
generateClassConstructor(schema)
generateGetters(schema)
if (schema.superclass == null)
generateGetMethod() // From SimplePropertyAccess
generateToString(schema)
visitEnd()
}
}.visitEnd()
}
}
@ -165,25 +194,26 @@ class ClassCarpenter(cl: ClassLoader = Thread.currentThread().contextClassLoader
}
private fun ClassWriter.generateFields(schema: Schema) {
schema.fields.forEach { it.value.generateField(this) }
schema.generateFields(this)
}
private fun ClassWriter.generateToString(schema: Schema) {
val toStringHelper = "com/google/common/base/MoreObjects\$ToStringHelper"
with(visitMethod(ACC_PUBLIC, "toString", "()Ljava/lang/String;", null, null)) {
with(visitMethod(ACC_PUBLIC, "toString", "()L$jlString;", null, null)) {
visitCode()
// com.google.common.base.MoreObjects.toStringHelper("TypeName")
visitLdcInsn(schema.name.split('.').last())
visitMethodInsn(INVOKESTATIC, "com/google/common/base/MoreObjects", "toStringHelper", "(Ljava/lang/String;)L$toStringHelper;", false)
visitMethodInsn(INVOKESTATIC, "com/google/common/base/MoreObjects", "toStringHelper",
"(L$jlString;)L$toStringHelper;", false)
// Call the add() methods.
for ((name, field) in schema.fieldsIncludingSuperclasses().entries) {
visitLdcInsn(name)
visitVarInsn(ALOAD, 0) // this
visitFieldInsn(GETFIELD, schema.jvmName, name, schema.descriptorsIncludingSuperclasses()[name])
visitMethodInsn(INVOKEVIRTUAL, toStringHelper, "add", "(Ljava/lang/String;${field.type})L$toStringHelper;", false)
visitMethodInsn(INVOKEVIRTUAL, toStringHelper, "add", "(L$jlString;${field.type})L$toStringHelper;", false)
}
// call toString() on the builder and return.
visitMethodInsn(INVOKEVIRTUAL, toStringHelper, "toString", "()Ljava/lang/String;", false)
visitMethodInsn(INVOKEVIRTUAL, toStringHelper, "toString", "()L$jlString;", false)
visitInsn(ARETURN)
visitMaxs(0, 0)
visitEnd()
@ -192,14 +222,14 @@ class ClassCarpenter(cl: ClassLoader = Thread.currentThread().contextClassLoader
private fun ClassWriter.generateGetMethod() {
val ourJvmName = ClassCarpenter::class.java.name.jvm
with(visitMethod(ACC_PUBLIC, "get", "(Ljava/lang/String;)Ljava/lang/Object;", null, null)) {
with(visitMethod(ACC_PUBLIC, "get", "(L$jlString;)L$jlObject;", null, null)) {
visitCode()
visitVarInsn(ALOAD, 0) // Load 'this'
visitVarInsn(ALOAD, 1) // Load the name argument
// Using this generic helper method is slow, as it relies on reflection. A faster way would be
// to use a tableswitch opcode, or just push back on the user and ask them to use actual reflection
// or MethodHandles (super fast reflection) to access the object instead.
visitMethodInsn(INVOKESTATIC, ourJvmName, "getField", "(Ljava/lang/Object;Ljava/lang/String;)Ljava/lang/Object;", false)
visitMethodInsn(INVOKESTATIC, ourJvmName, "getField", "(L$jlObject;L$jlString;)L$jlObject;", false)
visitInsn(ARETURN)
visitMaxs(0, 0)
visitEnd()
@ -207,45 +237,113 @@ class ClassCarpenter(cl: ClassLoader = Thread.currentThread().contextClassLoader
}
private fun ClassWriter.generateGetters(schema: Schema) {
for ((name, type) in schema.fields) {
with(visitMethod(ACC_PUBLIC, "get" + name.capitalize(), "()" + type.descriptor, null, null)) {
@Suppress("UNCHECKED_CAST")
for ((name, type) in (schema.fields as Map<String, ClassField>)) {
visitMethod(ACC_PUBLIC, "get" + name.capitalize(), "()" + type.descriptor, null, null).apply {
type.addNullabilityAnnotation(this)
visitCode()
visitVarInsn(ALOAD, 0) // Load 'this'
visitFieldInsn(GETFIELD, schema.jvmName, name, type.descriptor)
when (type.field) {
java.lang.Boolean.TYPE, Integer.TYPE, java.lang.Short.TYPE, java.lang.Byte.TYPE,
java.lang.Character.TYPE -> visitInsn(IRETURN)
java.lang.Character.TYPE -> visitInsn(IRETURN)
java.lang.Long.TYPE -> visitInsn(LRETURN)
java.lang.Double.TYPE -> visitInsn(DRETURN)
java.lang.Float.TYPE -> visitInsn(FRETURN)
else -> visitInsn(ARETURN)
}
visitMaxs(0, 0)
visitEnd()
}
}.visitEnd()
}
}
private fun ClassWriter.generateAbstractGetters(schema: Schema) {
for ((name, field) in schema.fields) {
@Suppress("UNCHECKED_CAST")
for ((name, field) in (schema.fields as Map<String, ClassField>)) {
val opcodes = ACC_ABSTRACT + ACC_PUBLIC
with(visitMethod(opcodes, "get" + name.capitalize(), "()${field.descriptor}", null, null)) {
// abstract method doesn't have any implementation so just end
visitEnd()
}
// abstract method doesn't have any implementation so just end
visitMethod(opcodes, "get" + name.capitalize(), "()${field.descriptor}", null, null).visitEnd()
}
}
private fun ClassWriter.generateConstructor(schema: Schema) {
with(visitMethod(
private fun ClassWriter.generateStaticEnumConstructor(schema: Schema) {
visitMethod(ACC_STATIC, "<clinit>", "()V", null, null).apply {
visitCode()
visitIntInsn(BIPUSH, schema.fields.size)
visitTypeInsn(ANEWARRAY, schema.jvmName)
visitInsn(DUP)
var idx = 0
schema.fields.forEach {
visitInsn(DUP)
visitIntInsn(BIPUSH, idx)
visitTypeInsn(NEW, schema.jvmName)
visitInsn(DUP)
visitLdcInsn(it.key)
visitIntInsn(BIPUSH, idx++)
visitMethodInsn(INVOKESPECIAL, schema.jvmName, "<init>", "(L$jlString;I)V", false)
visitInsn(DUP)
visitFieldInsn(PUTSTATIC, schema.jvmName, it.key, "L${schema.jvmName};")
visitInsn(AASTORE)
}
visitFieldInsn(PUTSTATIC, schema.jvmName, "\$VALUES", schema.asArray)
visitInsn(RETURN)
visitMaxs(0, 0)
}.visitEnd()
}
private fun ClassWriter.generateEnumValues(schema: Schema) {
visitMethod(ACC_PUBLIC + ACC_STATIC, "values", "()${schema.asArray}", null, null).apply {
visitCode()
visitFieldInsn(GETSTATIC, schema.jvmName, "\$VALUES", schema.asArray)
visitMethodInsn(INVOKEVIRTUAL, schema.asArray, "clone", "()L$jlObject;", false)
visitTypeInsn(CHECKCAST, schema.asArray)
visitInsn(ARETURN)
visitMaxs(0, 0)
}.visitEnd()
}
private fun ClassWriter.generateEnumValueOf(schema: Schema) {
visitMethod(ACC_PUBLIC + ACC_STATIC, "valueOf", "(L$jlString;)L${schema.jvmName};", null, null).apply {
visitCode()
visitLdcInsn(Type.getType("L${schema.jvmName};"))
visitVarInsn(ALOAD, 0)
visitMethodInsn(INVOKESTATIC, jlEnum, "valueOf", "(L$jlClass;L$jlString;)L$jlEnum;", true)
visitTypeInsn(CHECKCAST, schema.jvmName)
visitInsn(ARETURN)
visitMaxs(0, 0)
}.visitEnd()
}
private fun ClassWriter.generateEnumConstructor() {
visitMethod(ACC_PROTECTED, "<init>", "(L$jlString;I)V", "()V", null).apply {
visitParameter("\$enum\$name", ACC_SYNTHETIC)
visitParameter("\$enum\$ordinal", ACC_SYNTHETIC)
visitCode()
visitVarInsn(ALOAD, 0) // this
visitVarInsn(ALOAD, 1)
visitVarInsn(ILOAD, 2)
visitMethodInsn(INVOKESPECIAL, jlEnum, "<init>", "(L$jlString;I)V", false)
visitInsn(RETURN)
visitMaxs(0, 0)
}.visitEnd()
}
private fun ClassWriter.generateClassConstructor(schema: Schema) {
visitMethod(
ACC_PUBLIC,
"<init>",
"(" + schema.descriptorsIncludingSuperclasses().values.joinToString("") + ")V",
null,
null))
{
null).apply {
var idx = 0
schema.fields.values.forEach { it.visitParameter(this, idx++) }
visitCode()
@ -255,7 +353,7 @@ class ClassCarpenter(cl: ClassLoader = Thread.currentThread().contextClassLoader
visitVarInsn(ALOAD, 0)
val sc = schema.superclass
if (sc == null) {
visitMethodInsn(INVOKESPECIAL, "java/lang/Object", "<init>", "()V", false)
visitMethodInsn(INVOKESPECIAL, jlObject, "<init>", "()V", false)
} else {
var slot = 1
superclassFields.values.forEach { slot += load(slot, it) }
@ -265,7 +363,8 @@ class ClassCarpenter(cl: ClassLoader = Thread.currentThread().contextClassLoader
// Assign the fields from parameters.
var slot = 1 + superclassFields.size
for ((name, field) in schema.fields.entries) {
@Suppress("UNCHECKED_CAST")
for ((name, field) in (schema.fields as Map<String, ClassField>)) {
field.nullTest(this, slot)
visitVarInsn(ALOAD, 0) // Load 'this' onto the stack
@ -274,14 +373,13 @@ class ClassCarpenter(cl: ClassLoader = Thread.currentThread().contextClassLoader
}
visitInsn(RETURN)
visitMaxs(0, 0)
visitEnd()
}
}.visitEnd()
}
private fun MethodVisitor.load(slot: Int, type: Field): Int {
when (type.field) {
java.lang.Boolean.TYPE, Integer.TYPE, java.lang.Short.TYPE, java.lang.Byte.TYPE,
java.lang.Character.TYPE -> visitVarInsn(ILOAD, slot)
java.lang.Character.TYPE -> visitVarInsn(ILOAD, slot)
java.lang.Long.TYPE -> visitVarInsn(LLOAD, slot)
java.lang.Double.TYPE -> visitVarInsn(DLOAD, slot)
java.lang.Float.TYPE -> visitVarInsn(FLOAD, slot)
@ -325,7 +423,8 @@ class ClassCarpenter(cl: ClassLoader = Thread.currentThread().contextClassLoader
}
companion object {
@JvmStatic @Suppress("UNUSED")
@JvmStatic
@Suppress("UNUSED")
fun getField(obj: Any, name: String): Any? = obj.javaClass.getMethod("get" + name.capitalize()).invoke(obj)
}
}

View File

@ -1,11 +1,11 @@
package net.corda.nodeapi.internal.serialization.carpenter
class DuplicateNameException : RuntimeException (
class DuplicateNameException : RuntimeException(
"An attempt was made to register two classes with the same name within the same ClassCarpenter namespace.")
class InterfaceMismatchException(msg: String) : RuntimeException(msg)
class NullablePrimitiveException(msg: String) : RuntimeException(msg)
class UncarpentableException (name: String, field: String, type: String) :
Exception ("Class $name is loadable yet contains field $field of unknown type $type")
class UncarpentableException(name: String, field: String, type: String) :
Exception("Class $name is loadable yet contains field $field of unknown type $type")

View File

@ -1,6 +1,7 @@
package net.corda.nodeapi.internal.serialization.carpenter
import net.corda.nodeapi.internal.serialization.amqp.CompositeType
import net.corda.nodeapi.internal.serialization.amqp.RestrictedType
import net.corda.nodeapi.internal.serialization.amqp.TypeNotation
/**
@ -22,22 +23,19 @@ import net.corda.nodeapi.internal.serialization.amqp.TypeNotation
* in turn look up all of those classes in the [dependsOn] list, remove their dependency on the newly created class,
* and if that list is reduced to zero know we can now generate a [Schema] for them and carpent them up
*/
data class CarpenterSchemas (
data class CarpenterMetaSchema(
val carpenterSchemas: MutableList<Schema>,
val dependencies: MutableMap<String, Pair<TypeNotation, MutableList<String>>>,
val dependsOn: MutableMap<String, MutableList<String>>) {
companion object CarpenterSchemaConstructor {
fun newInstance(): CarpenterSchemas {
return CarpenterSchemas(
mutableListOf<Schema>(),
mutableMapOf<String, Pair<TypeNotation, MutableList<String>>>(),
mutableMapOf<String, MutableList<String>>())
fun newInstance(): CarpenterMetaSchema {
return CarpenterMetaSchema(mutableListOf(), mutableMapOf(), mutableMapOf())
}
}
fun addDepPair(type: TypeNotation, dependant: String, dependee: String) {
dependsOn.computeIfAbsent(dependee, { mutableListOf<String>() }).add(dependant)
dependencies.computeIfAbsent(dependant, { Pair(type, mutableListOf<String>()) }).second.add(dependee)
dependsOn.computeIfAbsent(dependee, { mutableListOf() }).add(dependant)
dependencies.computeIfAbsent(dependant, { Pair(type, mutableListOf()) }).second.add(dependee)
}
val size
@ -45,34 +43,42 @@ data class CarpenterSchemas (
fun isEmpty() = carpenterSchemas.isEmpty()
fun isNotEmpty() = carpenterSchemas.isNotEmpty()
// We could make this an abstract method on TypeNotation but that
// would mean the amqp package being "more" infected with carpenter
// specific bits.
fun buildFor(target: TypeNotation, cl: ClassLoader) = when (target) {
is RestrictedType -> target.carpenterSchema(this)
is CompositeType -> target.carpenterSchema(cl, this, false)
}
}
/**
* Take a dependency tree of [CarpenterSchemas] and reduce it to zero by carpenting those classes that
* Take a dependency tree of [CarpenterMetaSchema] and reduce it to zero by carpenting those classes that
* require it. As classes are carpented check for depdency resolution, if now free generate a [Schema] for
* that class and add it to the list of classes ([CarpenterSchemas.carpenterSchemas]) that require
* that class and add it to the list of classes ([CarpenterMetaSchema.carpenterSchemas]) that require
* carpenting
*
* @property cc a reference to the actual class carpenter we're using to constuct classes
* @property objects a list of carpented classes loaded into the carpenters class loader
*/
abstract class MetaCarpenterBase (val schemas : CarpenterSchemas, val cc : ClassCarpenter = ClassCarpenter()) {
abstract class MetaCarpenterBase(val schemas: CarpenterMetaSchema, val cc: ClassCarpenter) {
val objects = mutableMapOf<String, Class<*>>()
fun step(newObject: Schema) {
objects[newObject.name] = cc.build (newObject)
objects[newObject.name] = cc.build(newObject)
// go over the list of everything that had a dependency on the newly
// carpented class existing and remove it from their dependency list, If that
// list is now empty we have no impediment to carpenting that class up
schemas.dependsOn.remove(newObject.name)?.forEach { dependent ->
assert (newObject.name in schemas.dependencies[dependent]!!.second)
assert(newObject.name in schemas.dependencies[dependent]!!.second)
schemas.dependencies[dependent]?.second?.remove(newObject.name)
// we're out of blockers so we can now create the type
if (schemas.dependencies[dependent]?.second?.isEmpty() ?: false) {
(schemas.dependencies.remove (dependent)?.first as CompositeType).carpenterSchema (
if (schemas.dependencies[dependent]?.second?.isEmpty() == true) {
(schemas.dependencies.remove(dependent)?.first as CompositeType).carpenterSchema(
classloader = cc.classloader,
carpenterSchemas = schemas)
}
@ -81,25 +87,23 @@ abstract class MetaCarpenterBase (val schemas : CarpenterSchemas, val cc : Class
abstract fun build()
val classloader : ClassLoader
get() = cc.classloader
val classloader: ClassLoader
get() = cc.classloader
}
class MetaCarpenter(schemas : CarpenterSchemas,
cc : ClassCarpenter = ClassCarpenter()) : MetaCarpenterBase(schemas, cc) {
class MetaCarpenter(schemas: CarpenterMetaSchema, cc: ClassCarpenter) : MetaCarpenterBase(schemas, cc) {
override fun build() {
while (schemas.carpenterSchemas.isNotEmpty()) {
val newObject = schemas.carpenterSchemas.removeAt(0)
step (newObject)
step(newObject)
}
}
}
class TestMetaCarpenter(schemas : CarpenterSchemas,
cc : ClassCarpenter = ClassCarpenter()) : MetaCarpenterBase(schemas, cc) {
class TestMetaCarpenter(schemas: CarpenterMetaSchema, cc: ClassCarpenter) : MetaCarpenterBase(schemas, cc) {
override fun build() {
if (schemas.carpenterSchemas.isEmpty()) return
step (schemas.carpenterSchemas.removeAt(0))
step(schemas.carpenterSchemas.removeAt(0))
}
}

View File

@ -1,148 +1,110 @@
package net.corda.nodeapi.internal.serialization.carpenter
import jdk.internal.org.objectweb.asm.Opcodes.*
import kotlin.collections.LinkedHashMap
import org.objectweb.asm.ClassWriter
import org.objectweb.asm.MethodVisitor
import org.objectweb.asm.Type
import java.util.*
import org.objectweb.asm.Opcodes.*
/**
* A Schema represents a desired class.
* A Schema is the representation of an object the Carpenter can contsruct
*
* Known Sub Classes
* - [ClassSchema]
* - [InterfaceSchema]
* - [EnumSchema]
*/
abstract class Schema(
val name: String,
fields: Map<String, Field>,
var fields: Map<String, Field>,
val superclass: Schema? = null,
val interfaces: List<Class<*>> = emptyList())
{
private fun Map<String, Field>.descriptors() =
LinkedHashMap(this.mapValues { it.value.descriptor })
val interfaces: List<Class<*>> = emptyList(),
updater: (String, Field) -> Unit) {
private fun Map<String, Field>.descriptors() = LinkedHashMap(this.mapValues { it.value.descriptor })
/* Fix the order up front if the user didn't, inject the name into the field as it's
neater when iterating */
val fields = LinkedHashMap(fields.mapValues { it.value.copy(it.key, it.value.field) })
init {
fields.forEach { updater(it.key, it.value) }
// Fix the order up front if the user didn't, inject the name into the field as it's
// neater when iterating
fields = LinkedHashMap(fields)
}
fun fieldsIncludingSuperclasses(): Map<String, Field> =
(superclass?.fieldsIncludingSuperclasses() ?: emptyMap()) + LinkedHashMap(fields)
fun descriptorsIncludingSuperclasses(): Map<String, String> =
fun descriptorsIncludingSuperclasses(): Map<String, String?> =
(superclass?.descriptorsIncludingSuperclasses() ?: emptyMap()) + fields.descriptors()
abstract fun generateFields(cw: ClassWriter)
val jvmName: String
get() = name.replace(".", "/")
val asArray: String
get() = "[L$jvmName;"
}
/**
* Represents a concrete object
*/
class ClassSchema(
name: String,
fields: Map<String, Field>,
superclass: Schema? = null,
interfaces: List<Class<*>> = emptyList()
) : Schema(name, fields, superclass, interfaces)
) : Schema(name, fields, superclass, interfaces, { newName, field -> field.name = newName }) {
override fun generateFields(cw: ClassWriter) {
cw.apply { fields.forEach { it.value.generateField(this) } }
}
}
/**
* Represents an interface. Carpented interfaces can be used within [ClassSchema]s
* if that class should be implementing that interface
*/
class InterfaceSchema(
name: String,
fields: Map<String, Field>,
superclass: Schema? = null,
interfaces: List<Class<*>> = emptyList()
) : Schema(name, fields, superclass, interfaces)
) : Schema(name, fields, superclass, interfaces, { newName, field -> field.name = newName }) {
override fun generateFields(cw: ClassWriter) {
cw.apply { fields.forEach { it.value.generateField(this) } }
}
}
/**
* Represents an enumerated type
*/
class EnumSchema(
name: String,
fields: Map<String, Field>
) : Schema(name, fields, null, emptyList(), { fieldName, field ->
(field as EnumField).name = fieldName
field.descriptor = "L${name.replace(".", "/")};"
}) {
override fun generateFields(cw: ClassWriter) {
with(cw) {
fields.forEach { it.value.generateField(this) }
visitField(ACC_PRIVATE + ACC_FINAL + ACC_STATIC + ACC_SYNTHETIC,
"\$VALUES", asArray, null, null)
}
}
}
/**
* Factory object used by the serialiser when building [Schema]s based
* on an AMQP schema
*/
object CarpenterSchemaFactory {
fun newInstance (
fun newInstance(
name: String,
fields: Map<String, Field>,
superclass: Schema? = null,
interfaces: List<Class<*>> = emptyList(),
isInterface: Boolean = false
) : Schema =
if (isInterface) InterfaceSchema (name, fields, superclass, interfaces)
else ClassSchema (name, fields, superclass, interfaces)
): Schema =
if (isInterface) InterfaceSchema(name, fields, superclass, interfaces)
else ClassSchema(name, fields, superclass, interfaces)
}
abstract class Field(val field: Class<out Any?>) {
companion object {
const val unsetName = "Unset"
}
var name: String = unsetName
abstract val nullabilityAnnotation: String
val descriptor: String
get() = Type.getDescriptor(this.field)
val type: String
get() = if (this.field.isPrimitive) this.descriptor else "Ljava/lang/Object;"
fun generateField(cw: ClassWriter) {
val fieldVisitor = cw.visitField(ACC_PROTECTED + ACC_FINAL, name, descriptor, null, null)
fieldVisitor.visitAnnotation(nullabilityAnnotation, true).visitEnd()
fieldVisitor.visitEnd()
}
fun addNullabilityAnnotation(mv: MethodVisitor) {
mv.visitAnnotation(nullabilityAnnotation, true).visitEnd()
}
fun visitParameter(mv: MethodVisitor, idx: Int) {
with(mv) {
visitParameter(name, 0)
if (!field.isPrimitive) {
visitParameterAnnotation(idx, nullabilityAnnotation, true).visitEnd()
}
}
}
abstract fun copy(name: String, field: Class<out Any?>): Field
abstract fun nullTest(mv: MethodVisitor, slot: Int)
}
class NonNullableField(field: Class<out Any?>) : Field(field) {
override val nullabilityAnnotation = "Ljavax/annotation/Nonnull;"
constructor(name: String, field: Class<out Any?>) : this(field) {
this.name = name
}
override fun copy(name: String, field: Class<out Any?>) = NonNullableField(name, field)
override fun nullTest(mv: MethodVisitor, slot: Int) {
assert(name != unsetName)
if (!field.isPrimitive) {
with(mv) {
visitVarInsn(ALOAD, 0) // load this
visitVarInsn(ALOAD, slot) // load parameter
visitLdcInsn("param \"$name\" cannot be null")
visitMethodInsn(INVOKESTATIC,
"java/util/Objects",
"requireNonNull",
"(Ljava/lang/Object;Ljava/lang/String;)Ljava/lang/Object;", false)
visitInsn(POP)
}
}
}
}
class NullableField(field: Class<out Any?>) : Field(field) {
override val nullabilityAnnotation = "Ljavax/annotation/Nullable;"
constructor(name: String, field: Class<out Any?>) : this(field) {
if (field.isPrimitive) {
throw NullablePrimitiveException (
"Field $name is primitive type ${Type.getDescriptor(field)} and thus cannot be nullable")
}
this.name = name
}
override fun copy(name: String, field: Class<out Any?>) = NullableField(name, field)
override fun nullTest(mv: MethodVisitor, slot: Int) {
assert(name != unsetName)
}
}
object FieldFactory {
fun newInstance (mandatory: Boolean, name: String, field: Class<out Any?>) =
if (mandatory) NonNullableField (name, field) else NullableField (name, field)
}

View File

@ -0,0 +1,137 @@
package net.corda.nodeapi.internal.serialization.carpenter
import jdk.internal.org.objectweb.asm.Opcodes.*
import org.objectweb.asm.ClassWriter
import org.objectweb.asm.MethodVisitor
import org.objectweb.asm.Type
abstract class Field(val field: Class<out Any?>) {
abstract var descriptor: String?
companion object {
const val unsetName = "Unset"
}
var name: String = unsetName
abstract val type: String
abstract fun generateField(cw: ClassWriter)
abstract fun visitParameter(mv: MethodVisitor, idx: Int)
}
/**
* Any field that can be a member of an object
*
* Known
* - [NullableField]
* - [NonNullableField]
*/
abstract class ClassField(field: Class<out Any?>) : Field(field) {
abstract val nullabilityAnnotation: String
abstract fun nullTest(mv: MethodVisitor, slot: Int)
override var descriptor = Type.getDescriptor(this.field)
override val type: String get() = if (this.field.isPrimitive) this.descriptor else "Ljava/lang/Object;"
fun addNullabilityAnnotation(mv: MethodVisitor) {
mv.visitAnnotation(nullabilityAnnotation, true).visitEnd()
}
override fun generateField(cw: ClassWriter) {
cw.visitField(ACC_PROTECTED + ACC_FINAL, name, descriptor, null, null).visitAnnotation(
nullabilityAnnotation, true).visitEnd()
}
override fun visitParameter(mv: MethodVisitor, idx: Int) {
with(mv) {
visitParameter(name, 0)
if (!field.isPrimitive) {
visitParameterAnnotation(idx, nullabilityAnnotation, true).visitEnd()
}
}
}
}
/**
* A member of a constructed class that can be assigned to null, the
* mandatory type for primitives, but also any member that cannot be
* null
*
* maps to AMQP mandatory = true fields
*/
open class NonNullableField(field: Class<out Any?>) : ClassField(field) {
override val nullabilityAnnotation = "Ljavax/annotation/Nonnull;"
constructor(name: String, field: Class<out Any?>) : this(field) {
this.name = name
}
override fun nullTest(mv: MethodVisitor, slot: Int) {
assert(name != unsetName)
if (!field.isPrimitive) {
with(mv) {
visitVarInsn(ALOAD, 0) // load this
visitVarInsn(ALOAD, slot) // load parameter
visitLdcInsn("param \"$name\" cannot be null")
visitMethodInsn(INVOKESTATIC,
"java/util/Objects",
"requireNonNull",
"(Ljava/lang/Object;Ljava/lang/String;)Ljava/lang/Object;", false)
visitInsn(POP)
}
}
}
}
/**
* A member of a constructed class that can be assigned to null,
*
* maps to AMQP mandatory = false fields
*/
class NullableField(field: Class<out Any?>) : ClassField(field) {
override val nullabilityAnnotation = "Ljavax/annotation/Nullable;"
constructor(name: String, field: Class<out Any?>) : this(field) {
this.name = name
}
init {
if (field.isPrimitive) {
throw NullablePrimitiveException(
"Field $name is primitive type ${Type.getDescriptor(field)} and thus cannot be nullable")
}
}
override fun nullTest(mv: MethodVisitor, slot: Int) {
assert(name != unsetName)
}
}
/**
* Represents enum constants within an enum
*/
class EnumField : Field(Enum::class.java) {
override var descriptor: String? = null
override val type: String
get() = "Ljava/lang/Enum;"
override fun generateField(cw: ClassWriter) {
cw.visitField(ACC_PUBLIC + ACC_FINAL + ACC_STATIC + ACC_ENUM, name,
descriptor, null, null).visitEnd()
}
override fun visitParameter(mv: MethodVisitor, idx: Int) {
mv.visitParameter(name, 0)
}
}
/**
* Constructs a Field Schema object of the correct type depending weather
* the AMQP schema indicates it's mandatory (non nullable) or not (nullable)
*/
object FieldFactory {
fun newInstance(mandatory: Boolean, name: String, field: Class<out Any?>) =
if (mandatory) NonNullableField(name, field) else NullableField(name, field)
}

View File

@ -0,0 +1,77 @@
package net.corda.nodeapi.internal.serialization;
import com.google.common.collect.Maps;
import net.corda.core.serialization.SerializationContext;
import net.corda.core.serialization.SerializationDefaults;
import net.corda.core.serialization.SerializationFactory;
import net.corda.core.serialization.SerializedBytes;
import net.corda.testing.TestDependencyInjectionBase;
import org.junit.Before;
import org.junit.Test;
import java.io.Serializable;
import java.util.EnumSet;
import java.util.concurrent.Callable;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.ThrowableAssert.catchThrowable;
public final class ForbiddenLambdaSerializationTests extends TestDependencyInjectionBase {
private SerializationFactory factory;
@Before
public void setup() {
factory = SerializationDefaults.INSTANCE.getSERIALIZATION_FACTORY();
}
@Test
public final void serialization_fails_for_serializable_java_lambdas() throws Exception {
EnumSet<SerializationContext.UseCase> contexts = EnumSet.complementOf(EnumSet.of(SerializationContext.UseCase.Checkpoint));
contexts.forEach(ctx -> {
SerializationContext context = new SerializationContextImpl(SerializationSchemeKt.getKryoHeaderV0_1(), this.getClass().getClassLoader(), AllWhitelist.INSTANCE, Maps.newHashMap(), true, ctx);
String value = "Hey";
Callable<String> target = (Callable<String> & Serializable) () -> value;
Throwable throwable = catchThrowable(() -> serialize(target, context));
assertThat(throwable).isNotNull();
assertThat(throwable).isInstanceOf(IllegalArgumentException.class);
if (ctx != SerializationContext.UseCase.RPCServer && ctx != SerializationContext.UseCase.Storage) {
assertThat(throwable).hasMessage(CordaClosureBlacklistSerializer.INSTANCE.getERROR_MESSAGE());
} else {
assertThat(throwable).hasMessageContaining("RPC not allowed to deserialise internal classes");
}
});
}
@Test
@SuppressWarnings("unchecked")
public final void serialization_fails_for_not_serializable_java_lambdas() throws Exception {
EnumSet<SerializationContext.UseCase> contexts = EnumSet.complementOf(EnumSet.of(SerializationContext.UseCase.Checkpoint));
contexts.forEach(ctx -> {
SerializationContext context = new SerializationContextImpl(SerializationSchemeKt.getKryoHeaderV0_1(), this.getClass().getClassLoader(), AllWhitelist.INSTANCE, Maps.newHashMap(), true, ctx);
String value = "Hey";
Callable<String> target = () -> value;
Throwable throwable = catchThrowable(() -> serialize(target, context));
assertThat(throwable).isNotNull();
assertThat(throwable).isInstanceOf(IllegalArgumentException.class);
assertThat(throwable).isInstanceOf(IllegalArgumentException.class);
if (ctx != SerializationContext.UseCase.RPCServer && ctx != SerializationContext.UseCase.Storage) {
assertThat(throwable).hasMessage(CordaClosureBlacklistSerializer.INSTANCE.getERROR_MESSAGE());
} else {
assertThat(throwable).hasMessageContaining("RPC not allowed to deserialise internal classes");
}
});
}
private <T> SerializedBytes<T> serialize(final T target, final SerializationContext context) {
return factory.serialize(target, context);
}
}

View File

@ -0,0 +1,61 @@
package net.corda.nodeapi.internal.serialization;
import com.google.common.collect.Maps;
import net.corda.core.serialization.SerializationContext;
import net.corda.core.serialization.SerializationDefaults;
import net.corda.core.serialization.SerializationFactory;
import net.corda.core.serialization.SerializedBytes;
import net.corda.testing.TestDependencyInjectionBase;
import org.junit.Before;
import org.junit.Test;
import java.io.Serializable;
import java.util.concurrent.Callable;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.ThrowableAssert.catchThrowable;
public final class LambdaCheckpointSerializationTest extends TestDependencyInjectionBase {
private SerializationFactory factory;
private SerializationContext context;
@Before
public void setup() {
factory = SerializationDefaults.INSTANCE.getSERIALIZATION_FACTORY();
context = new SerializationContextImpl(SerializationSchemeKt.getKryoHeaderV0_1(), this.getClass().getClassLoader(), AllWhitelist.INSTANCE, Maps.newHashMap(), true, SerializationContext.UseCase.Checkpoint);
}
@Test
@SuppressWarnings("unchecked")
public final void serialization_works_for_serializable_java_lambdas() throws Exception {
String value = "Hey";
Callable<String> target = (Callable<String> & Serializable) () -> value;
SerializedBytes<Callable<String>> serialized = serialize(target);
Callable<String> deserialized = deserialize(serialized, Callable.class);
assertThat(deserialized.call()).isEqualTo(value);
}
@Test
@SuppressWarnings("unchecked")
public final void serialization_fails_for_not_serializable_java_lambdas() throws Exception {
String value = "Hey";
Callable<String> target = () -> value;
Throwable throwable = catchThrowable(() -> serialize(target));
assertThat(throwable).isNotNull();
assertThat(throwable).isInstanceOf(IllegalArgumentException.class);
assertThat(throwable).hasMessage(CordaClosureSerializer.INSTANCE.getERROR_MESSAGE());
}
private <T> SerializedBytes<T> serialize(final T target) {
return factory.serialize(target, context);
}
private <T> T deserialize(final SerializedBytes<? extends T> bytes, final Class<T> type) {
return factory.deserialize(bytes, type, context);
}
}

View File

@ -0,0 +1,17 @@
package net.corda.nodeapi
import net.corda.core.contracts.ContractState
import net.corda.core.contracts.PartyAndReference
import net.corda.core.identity.Party
import net.corda.core.transactions.TransactionBuilder
/**
* This interface deliberately mirrors the one in the finance:isolated module.
* We will actually link [AnotherDummyContract] against this interface rather
* than the one inside isolated.jar, which means we won't need to use reflection
* to execute the contract's generateInitial() method.
*/
interface DummyContractBackdoor {
fun generateInitial(owner: PartyAndReference, magicNumber: Int, notary: Party): TransactionBuilder
fun inspectState(state: ContractState): Int
}

View File

@ -4,11 +4,10 @@ import com.typesafe.config.Config
import com.typesafe.config.ConfigFactory.empty
import com.typesafe.config.ConfigRenderOptions.defaults
import com.typesafe.config.ConfigValueFactory
import net.corda.core.identity.CordaX500Name
import net.corda.core.internal.div
import net.corda.core.utilities.NetworkHostAndPort
import net.corda.core.utilities.getX500Name
import org.assertj.core.api.Assertions.assertThat
import org.bouncycastle.asn1.x500.X500Name
import org.junit.Test
import java.net.URL
import java.nio.file.Path
@ -112,7 +111,7 @@ class ConfigParsingTest {
@Test
fun x500Name() {
testPropertyType<X500NameData, X500NameListData, X500Name>(getX500Name(O = "Mock Party", L = "London", C = "GB"), getX500Name(O = "Mock Party 2", L = "London", C = "GB"), valuesToString = true)
testPropertyType<X500NameData, X500NameListData, CordaX500Name>(CordaX500Name(organisation = "Mock Party", locality = "London", country = "GB"), CordaX500Name(organisation = "Mock Party 2", locality = "London", country = "GB"), valuesToString = true)
}
@Test
@ -229,8 +228,8 @@ class ConfigParsingTest {
data class PathListData(override val values: List<Path>) : ListData<Path>
data class URLData(override val value: URL) : SingleData<URL>
data class URLListData(override val values: List<URL>) : ListData<URL>
data class X500NameData(override val value: X500Name) : SingleData<X500Name>
data class X500NameListData(override val values: List<X500Name>) : ListData<X500Name>
data class X500NameData(override val value: CordaX500Name) : SingleData<CordaX500Name>
data class X500NameListData(override val values: List<CordaX500Name>) : ListData<CordaX500Name>
data class PropertiesData(override val value: Properties) : SingleData<Properties>
data class PropertiesListData(override val values: List<Properties>) : ListData<Properties>
data class MultiPropertyData(val i: Int, val b: Boolean, val l: List<String>)

View File

@ -0,0 +1,74 @@
package net.corda.nodeapi.internal
import net.corda.core.contracts.*
import net.corda.core.identity.AbstractParty
import net.corda.core.identity.Party
import net.corda.core.serialization.deserialize
import net.corda.core.serialization.serialize
import net.corda.core.transactions.LedgerTransaction
import net.corda.core.transactions.TransactionBuilder
import net.corda.testing.*
import net.corda.testing.node.MockServices
import org.junit.After
import org.junit.Assert.*
import org.junit.Before
import org.junit.Test
class AttachmentsClassLoaderStaticContractTests : TestDependencyInjectionBase() {
class AttachmentDummyContract : Contract {
companion object {
private val ATTACHMENT_PROGRAM_ID = "net.corda.nodeapi.internal.AttachmentsClassLoaderStaticContractTests\$AttachmentDummyContract"
}
data class State(val magicNumber: Int = 0) : ContractState {
override val participants: List<AbstractParty>
get() = listOf()
}
interface Commands : CommandData {
class Create : TypeOnlyCommandData(), Commands
}
override fun verify(tx: LedgerTransaction) {
// Always accepts.
}
fun generateInitial(owner: PartyAndReference, magicNumber: Int, notary: Party): TransactionBuilder {
val state = State(magicNumber)
return TransactionBuilder(notary)
.withItems(StateAndContract(state, ATTACHMENT_PROGRAM_ID), Command(Commands.Create(), owner.party.owningKey))
}
}
private lateinit var serviceHub: MockServices
@Before
fun `create service hub`() {
serviceHub = MockServices(cordappPackages=listOf("net.corda.nodeapi.internal"))
}
@After
fun `clear packages`() {
unsetCordappPackages()
}
@Test
fun `test serialization of WireTransaction with statically loaded contract`() {
val tx = AttachmentDummyContract().generateInitial(MEGA_CORP.ref(0), 42, DUMMY_NOTARY)
val wireTransaction = tx.toWireTransaction(serviceHub)
val bytes = wireTransaction.serialize()
val copiedWireTransaction = bytes.deserialize()
assertEquals(1, copiedWireTransaction.outputs.size)
assertEquals(42, (copiedWireTransaction.outputs[0].data as AttachmentDummyContract.State).magicNumber)
}
@Test
fun `verify that contract DummyContract is in classPath`() {
val contractClass = Class.forName("net.corda.nodeapi.internal.AttachmentsClassLoaderStaticContractTests\$AttachmentDummyContract")
val contract = contractClass.newInstance() as Contract
assertNotNull(contract)
}
}

View File

@ -1,19 +1,18 @@
package net.corda.nodeapi
package net.corda.nodeapi.internal
import com.nhaarman.mockito_kotlin.mock
import com.nhaarman.mockito_kotlin.whenever
import net.corda.core.contracts.*
import net.corda.core.crypto.SecureHash
import net.corda.core.identity.AbstractParty
import net.corda.core.identity.Party
import net.corda.core.internal.declaredField
import net.corda.core.node.ServiceHub
import net.corda.core.node.services.AttachmentStorage
import net.corda.core.serialization.*
import net.corda.core.transactions.LedgerTransaction
import net.corda.core.transactions.TransactionBuilder
import net.corda.core.utilities.ByteSequence
import net.corda.core.utilities.OpaqueBytes
import net.corda.node.internal.cordapp.CordappLoader
import net.corda.node.internal.cordapp.CordappProviderImpl
import net.corda.nodeapi.DummyContractBackdoor
import net.corda.nodeapi.internal.serialization.SerializeAsTokenContextImpl
import net.corda.nodeapi.internal.serialization.attachmentsClassLoaderEnabledPropertyName
import net.corda.nodeapi.internal.serialization.withTokenContext
@ -22,8 +21,10 @@ import net.corda.testing.MEGA_CORP
import net.corda.testing.TestDependencyInjectionBase
import net.corda.testing.kryoSpecific
import net.corda.testing.node.MockAttachmentStorage
import net.corda.testing.node.MockServices
import org.apache.commons.io.IOUtils
import org.junit.Assert
import org.junit.Assert.*
import org.junit.Before
import org.junit.Test
import java.io.ByteArrayInputStream
import java.io.ByteArrayOutputStream
@ -31,99 +32,86 @@ import java.net.URL
import java.net.URLClassLoader
import java.util.jar.JarOutputStream
import java.util.zip.ZipEntry
import kotlin.test.assertEquals
import kotlin.test.assertFailsWith
import kotlin.test.assertNotNull
import kotlin.test.assertTrue
interface DummyContractBackdoor {
fun generateInitial(owner: PartyAndReference, magicNumber: Int, notary: Party): TransactionBuilder
fun inspectState(state: ContractState): Int
}
val ATTACHMENT_TEST_PROGRAM_ID = AttachmentClassLoaderTests.AttachmentDummyContract()
class AttachmentClassLoaderTests : TestDependencyInjectionBase() {
class AttachmentsClassLoaderTests : TestDependencyInjectionBase() {
companion object {
val ISOLATED_CONTRACTS_JAR_PATH: URL = AttachmentClassLoaderTests::class.java.getResource("isolated.jar")
val ISOLATED_CONTRACTS_JAR_PATH: URL = AttachmentsClassLoaderTests::class.java.getResource("isolated.jar")
private const val ISOLATED_CONTRACT_CLASS_NAME = "net.corda.finance.contracts.isolated.AnotherDummyContract"
private fun SerializationContext.withAttachmentStorage(attachmentStorage: AttachmentStorage): SerializationContext {
val serviceHub = mock<ServiceHub>()
whenever(serviceHub.attachments).thenReturn(attachmentStorage)
return this.withServiceHub(serviceHub)
}
private fun SerializationContext.withServiceHub(serviceHub: ServiceHub): SerializationContext {
return this.withTokenContext(SerializeAsTokenContextImpl(serviceHub) {}).withProperty(attachmentsClassLoaderEnabledPropertyName, true)
}
}
class AttachmentDummyContract : Contract {
data class State(val magicNumber: Int = 0) : ContractState {
override val contract = ATTACHMENT_TEST_PROGRAM_ID
override val participants: List<AbstractParty>
get() = listOf()
}
private lateinit var serviceHub: DummyServiceHub
interface Commands : CommandData {
class Create : TypeOnlyCommandData(), Commands
}
class DummyServiceHub : MockServices() {
override val cordappProvider: CordappProviderImpl
= CordappProviderImpl(CordappLoader.createDevMode(listOf(ISOLATED_CONTRACTS_JAR_PATH))).start(attachments)
override fun verify(tx: LedgerTransaction) {
// Always accepts.
}
fun generateInitial(owner: PartyAndReference, magicNumber: Int, notary: Party): TransactionBuilder {
val state = State(magicNumber)
return TransactionBuilder(notary).withItems(state, Command(Commands.Create(), owner.party.owningKey))
}
private val cordapp get() = cordappProvider.cordapps.first()
val attachmentId get() = cordappProvider.getCordappAttachmentId(cordapp)!!
val appContext get() = cordappProvider.getAppContext(cordapp)
}
fun importJar(storage: AttachmentStorage) = ISOLATED_CONTRACTS_JAR_PATH.openStream().use { storage.importAttachment(it) }
// These ClassLoaders work together to load 'AnotherDummyContract' in a disposable way, such that even though
// the class may be on the unit test class path (due to default IDE settings, etc), it won't be loaded into the
// regular app classloader but rather than ClassLoaderForTests. This helps keep our environment clean and
// ensures we have precise control over where it's loaded.
object FilteringClassLoader : ClassLoader() {
override fun loadClass(name: String, resolve: Boolean): Class<*>? {
@Throws(ClassNotFoundException::class)
override fun loadClass(name: String, resolve: Boolean): Class<*> {
if ("AnotherDummyContract" in name) {
return null
} else
return super.loadClass(name, resolve)
throw ClassNotFoundException(name)
}
return super.loadClass(name, resolve)
}
}
class ClassLoaderForTests : URLClassLoader(arrayOf(ISOLATED_CONTRACTS_JAR_PATH), FilteringClassLoader)
@Test
fun `dynamically load AnotherDummyContract from isolated contracts jar`() {
val child = ClassLoaderForTests()
val contractClass = Class.forName("net.corda.contracts.isolated.AnotherDummyContract", true, child)
val contract = contractClass.newInstance() as Contract
assertEquals(SecureHash.sha256("https://anotherdummy.org"), contract.declaredField<Any?>("legalContractReference").value)
@Before
fun `create service hub`() {
serviceHub = DummyServiceHub()
}
fun fakeAttachment(filepath: String, content: String): ByteArray {
@Test
fun `dynamically load AnotherDummyContract from isolated contracts jar`() {
ClassLoaderForTests().use { child ->
val contractClass = Class.forName(ISOLATED_CONTRACT_CLASS_NAME, true, child)
val contract = contractClass.newInstance() as Contract
assertEquals("helloworld", contract.declaredField<Any?>("magicString").value)
}
}
private fun fakeAttachment(filepath: String, content: String): ByteArray {
val bs = ByteArrayOutputStream()
val js = JarOutputStream(bs)
js.putNextEntry(ZipEntry(filepath))
js.writer().apply { append(content); flush() }
js.closeEntry()
js.close()
JarOutputStream(bs).use { js ->
js.putNextEntry(ZipEntry(filepath))
js.writer().apply { append(content); flush() }
js.closeEntry()
}
return bs.toByteArray()
}
fun readAttachment(attachment: Attachment, filepath: String): ByteArray {
private fun readAttachment(attachment: Attachment, filepath: String): ByteArray {
ByteArrayOutputStream().use {
attachment.extractFile(filepath, it)
return it.toByteArray()
}
}
@Test
fun `test MockAttachmentStorage open as jar`() {
val storage = MockAttachmentStorage()
val key = importJar(storage)
val storage = serviceHub.attachments
val key = serviceHub.attachmentId
val attachment = storage.openAttachment(key)!!
val jar = attachment.openAsJAR()
@ -133,9 +121,9 @@ class AttachmentClassLoaderTests : TestDependencyInjectionBase() {
@Test
fun `test overlapping file exception`() {
val storage = MockAttachmentStorage()
val storage = serviceHub.attachments
val att0 = importJar(storage)
val att0 = serviceHub.attachmentId
val att1 = storage.importAttachment(ByteArrayInputStream(fakeAttachment("file.txt", "some data")))
val att2 = storage.importAttachment(ByteArrayInputStream(fakeAttachment("file.txt", "some other data")))
@ -146,9 +134,9 @@ class AttachmentClassLoaderTests : TestDependencyInjectionBase() {
@Test
fun `basic`() {
val storage = MockAttachmentStorage()
val storage = serviceHub.attachments
val att0 = importJar(storage)
val att0 = serviceHub.attachmentId
val att1 = storage.importAttachment(ByteArrayInputStream(fakeAttachment("file1.txt", "some data")))
val att2 = storage.importAttachment(ByteArrayInputStream(fakeAttachment("file2.txt", "some other data")))
@ -165,47 +153,38 @@ class AttachmentClassLoaderTests : TestDependencyInjectionBase() {
val att2 = storage.importAttachment(ByteArrayInputStream(fakeAttachment("\\folder1\\folderb\\file2.txt", "some other data")))
val data1a = readAttachment(storage.openAttachment(att1)!!, "/folder1/foldera/file1.txt")
Assert.assertArrayEquals("some data".toByteArray(), data1a)
assertArrayEquals("some data".toByteArray(), data1a)
val data1b = readAttachment(storage.openAttachment(att1)!!, "\\folder1\\foldera\\file1.txt")
Assert.assertArrayEquals("some data".toByteArray(), data1b)
assertArrayEquals("some data".toByteArray(), data1b)
val data2a = readAttachment(storage.openAttachment(att2)!!, "\\folder1\\folderb\\file2.txt")
Assert.assertArrayEquals("some other data".toByteArray(), data2a)
assertArrayEquals("some other data".toByteArray(), data2a)
val data2b = readAttachment(storage.openAttachment(att2)!!, "/folder1/folderb/file2.txt")
Assert.assertArrayEquals("some other data".toByteArray(), data2b)
assertArrayEquals("some other data".toByteArray(), data2b)
}
@Test
fun `loading class AnotherDummyContract`() {
val storage = MockAttachmentStorage()
val storage = serviceHub.attachments
val att0 = importJar(storage)
val att0 = serviceHub.attachmentId
val att1 = storage.importAttachment(ByteArrayInputStream(fakeAttachment("file1.txt", "some data")))
val att2 = storage.importAttachment(ByteArrayInputStream(fakeAttachment("file2.txt", "some other data")))
val cl = AttachmentsClassLoader(arrayOf(att0, att1, att2).map { storage.openAttachment(it)!! }, FilteringClassLoader)
val contractClass = Class.forName("net.corda.contracts.isolated.AnotherDummyContract", true, cl)
val contractClass = Class.forName(ISOLATED_CONTRACT_CLASS_NAME, true, cl)
val contract = contractClass.newInstance() as Contract
assertEquals(cl, contract.javaClass.classLoader)
assertEquals(SecureHash.sha256("https://anotherdummy.org"), contract.declaredField<Any?>("legalContractReference").value)
assertEquals("helloworld", contract.declaredField<Any?>("magicString").value)
}
@Test
fun `verify that contract DummyContract is in classPath`() {
val contractClass = Class.forName("net.corda.nodeapi.AttachmentClassLoaderTests\$AttachmentDummyContract")
val contract = contractClass.newInstance() as Contract
assertNotNull(contract)
}
fun createContract2Cash(): Contract {
val cl = ClassLoaderForTests()
val contractClass = Class.forName("net.corda.contracts.isolated.AnotherDummyContract", true, cl)
return contractClass.newInstance() as Contract
private fun createContract2Cash(): Contract {
ClassLoaderForTests().use { cl ->
val contractClass = Class.forName(ISOLATED_CONTRACT_CLASS_NAME, true, cl)
return contractClass.newInstance() as Contract
}
}
@Test
@ -214,9 +193,9 @@ class AttachmentClassLoaderTests : TestDependencyInjectionBase() {
val bytes = contract.serialize()
val storage = MockAttachmentStorage()
val storage = serviceHub.attachments
val att0 = importJar(storage)
val att0 = serviceHub.attachmentId
val att1 = storage.importAttachment(ByteArrayInputStream(fakeAttachment("file1.txt", "some data")))
val att2 = storage.importAttachment(ByteArrayInputStream(fakeAttachment("file2.txt", "some other data")))
@ -242,15 +221,15 @@ class AttachmentClassLoaderTests : TestDependencyInjectionBase() {
val bytes = data.serialize(context = context2)
val storage = MockAttachmentStorage()
val storage = serviceHub.attachments
val att0 = importJar(storage)
val att0 = serviceHub.attachmentId
val att1 = storage.importAttachment(ByteArrayInputStream(fakeAttachment("file1.txt", "some data")))
val att2 = storage.importAttachment(ByteArrayInputStream(fakeAttachment("file2.txt", "some other data")))
val cl = AttachmentsClassLoader(arrayOf(att0, att1, att2).map { storage.openAttachment(it)!! }, FilteringClassLoader)
val context = SerializationFactory.defaultFactory.defaultContext.withClassLoader(cl).withWhitelisted(Class.forName("net.corda.contracts.isolated.AnotherDummyContract", true, cl))
val context = SerializationFactory.defaultFactory.defaultContext.withClassLoader(cl).withWhitelisted(Class.forName(ISOLATED_CONTRACT_CLASS_NAME, true, cl))
val state2 = bytes.deserialize(context = context)
assertEquals(cl, state2.contract.javaClass.classLoader)
@ -259,7 +238,7 @@ class AttachmentClassLoaderTests : TestDependencyInjectionBase() {
// We should be able to load same class from a different class loader and have them be distinct.
val cl2 = AttachmentsClassLoader(arrayOf(att0, att1, att2).map { storage.openAttachment(it)!! }, FilteringClassLoader)
val context3 = SerializationFactory.defaultFactory.defaultContext.withClassLoader(cl2).withWhitelisted(Class.forName("net.corda.contracts.isolated.AnotherDummyContract", true, cl2))
val context3 = SerializationFactory.defaultFactory.defaultContext.withClassLoader(cl2).withWhitelisted(Class.forName(ISOLATED_CONTRACT_CLASS_NAME, true, cl2))
val state3 = bytes.deserialize(context = context3)
assertEquals(cl2, state3.contract.javaClass.classLoader)
@ -293,107 +272,110 @@ class AttachmentClassLoaderTests : TestDependencyInjectionBase() {
assertEquals(bytesSequence, copiedBytesSequence)
}
@Test
fun `test serialization of WireTransaction with statically loaded contract`() {
val tx = ATTACHMENT_TEST_PROGRAM_ID.generateInitial(MEGA_CORP.ref(0), 42, DUMMY_NOTARY)
val wireTransaction = tx.toWireTransaction()
val bytes = wireTransaction.serialize()
val copiedWireTransaction = bytes.deserialize()
assertEquals(1, copiedWireTransaction.outputs.size)
assertEquals(42, (copiedWireTransaction.outputs[0].data as AttachmentDummyContract.State).magicNumber)
}
@Test
fun `test serialization of WireTransaction with dynamically loaded contract`() {
val child = ClassLoaderForTests()
val contractClass = Class.forName("net.corda.contracts.isolated.AnotherDummyContract", true, child)
val child = serviceHub.appContext.classLoader
val contractClass = Class.forName(ISOLATED_CONTRACT_CLASS_NAME, true, child)
val contract = contractClass.newInstance() as DummyContractBackdoor
val tx = contract.generateInitial(MEGA_CORP.ref(0), 42, DUMMY_NOTARY)
val storage = MockAttachmentStorage()
val context = SerializationFactory.defaultFactory.defaultContext.withWhitelisted(contract.javaClass)
.withWhitelisted(Class.forName("net.corda.contracts.isolated.AnotherDummyContract\$State", true, child))
.withWhitelisted(Class.forName("net.corda.contracts.isolated.AnotherDummyContract\$Commands\$Create", true, child))
.withAttachmentStorage(storage)
val context = SerializationFactory.defaultFactory.defaultContext
.withWhitelisted(contract.javaClass)
.withWhitelisted(Class.forName("$ISOLATED_CONTRACT_CLASS_NAME\$State", true, child))
.withWhitelisted(Class.forName("$ISOLATED_CONTRACT_CLASS_NAME\$Commands\$Create", true, child))
.withServiceHub(serviceHub)
.withClassLoader(child)
// todo - think about better way to push attachmentStorage down to serializer
val bytes = run {
val attachmentRef = importJar(storage)
tx.addAttachment(storage.openAttachment(attachmentRef)!!.id)
val wireTransaction = tx.toWireTransaction()
val wireTransaction = tx.toWireTransaction(serviceHub, context)
wireTransaction.serialize(context = context)
}
val copiedWireTransaction = bytes.deserialize(context = context)
assertEquals(1, copiedWireTransaction.outputs.size)
val contract2 = copiedWireTransaction.getOutput(0).contract as DummyContractBackdoor
// Contracts need to be loaded by the same classloader as the ContractState itself
val contractClassloader = copiedWireTransaction.getOutput(0).javaClass.classLoader
val contract2 = contractClassloader.loadClass(copiedWireTransaction.outputs.first().contract).newInstance() as DummyContractBackdoor
assertEquals(contract2.javaClass.classLoader, copiedWireTransaction.outputs[0].data.javaClass.classLoader)
assertEquals(42, contract2.inspectState(copiedWireTransaction.outputs[0].data))
}
@Test
fun `test deserialize of WireTransaction where contract cannot be found`() = kryoSpecific<AttachmentClassLoaderTests>("Kryo verifies/loads attachments on deserialization, whereas AMQP currently does not") {
val child = ClassLoaderForTests()
val contractClass = Class.forName("net.corda.contracts.isolated.AnotherDummyContract", true, child)
val contract = contractClass.newInstance() as DummyContractBackdoor
val tx = contract.generateInitial(MEGA_CORP.ref(0), 42, DUMMY_NOTARY)
val storage = MockAttachmentStorage()
fun `test deserialize of WireTransaction where contract cannot be found`() {
kryoSpecific<AttachmentsClassLoaderTests>("Kryo verifies/loads attachments on deserialization, whereas AMQP currently does not") {
ClassLoaderForTests().use { child ->
val contractClass = Class.forName(ISOLATED_CONTRACT_CLASS_NAME, true, child)
val contract = contractClass.newInstance() as DummyContractBackdoor
val tx = contract.generateInitial(MEGA_CORP.ref(0), 42, DUMMY_NOTARY)
// todo - think about better way to push attachmentStorage down to serializer
val attachmentRef = importJar(storage)
val bytes = run {
val attachmentRef = serviceHub.attachmentId
val bytes = run {
val outboundContext = SerializationFactory.defaultFactory.defaultContext
.withServiceHub(serviceHub)
.withClassLoader(child)
val wireTransaction = tx.toWireTransaction(serviceHub, outboundContext)
wireTransaction.serialize(context = outboundContext)
}
// use empty attachmentStorage
tx.addAttachment(storage.openAttachment(attachmentRef)!!.id)
val e = assertFailsWith(MissingAttachmentsException::class) {
val mockAttStorage = MockAttachmentStorage()
val inboundContext = SerializationFactory.defaultFactory.defaultContext
.withAttachmentStorage(mockAttStorage)
.withAttachmentsClassLoader(listOf(attachmentRef))
bytes.deserialize(context = inboundContext)
val wireTransaction = tx.toWireTransaction()
wireTransaction.serialize(context = SerializationFactory.defaultFactory.defaultContext.withAttachmentStorage(storage))
}
// use empty attachmentStorage
val e = assertFailsWith(MissingAttachmentsException::class) {
val mockAttStorage = MockAttachmentStorage()
bytes.deserialize(context = SerializationFactory.defaultFactory.defaultContext.withAttachmentStorage(mockAttStorage))
if(mockAttStorage.openAttachment(attachmentRef) == null) {
throw MissingAttachmentsException(listOf(attachmentRef))
if (mockAttStorage.openAttachment(attachmentRef) == null) {
throw MissingAttachmentsException(listOf(attachmentRef))
}
}
assertEquals(attachmentRef, e.ids.single())
}
}
assertEquals(attachmentRef, e.ids.single())
}
@Test
fun `test loading a class from attachment during deserialization`() {
val child = ClassLoaderForTests()
val contractClass = Class.forName("net.corda.contracts.isolated.AnotherDummyContract", true, child)
val contract = contractClass.newInstance() as DummyContractBackdoor
val storage = MockAttachmentStorage()
val attachmentRef = importJar(storage)
val outboundContext = SerializationFactory.defaultFactory.defaultContext.withClassLoader(child)
// We currently ignore annotations in attachments, so manually whitelist.
val inboundContext = SerializationFactory.defaultFactory.defaultContext.withWhitelisted(contract.javaClass).withAttachmentStorage(storage).withAttachmentsClassLoader(listOf(attachmentRef))
ClassLoaderForTests().use { child ->
val contractClass = Class.forName(ISOLATED_CONTRACT_CLASS_NAME, true, child)
val contract = contractClass.newInstance() as DummyContractBackdoor
val outboundContext = SerializationFactory.defaultFactory.defaultContext.withClassLoader(child)
val attachmentRef = serviceHub.attachmentId
// We currently ignore annotations in attachments, so manually whitelist.
val inboundContext = SerializationFactory
.defaultFactory
.defaultContext
.withWhitelisted(contract.javaClass)
.withServiceHub(serviceHub)
.withAttachmentsClassLoader(listOf(attachmentRef))
// Serialize with custom context to avoid populating the default context with the specially loaded class
val serialized = contract.serialize(context = outboundContext)
// Then deserialize with the attachment class loader associated with the attachment
serialized.deserialize(context = inboundContext)
// Serialize with custom context to avoid populating the default context with the specially loaded class
val serialized = contract.serialize(context = outboundContext)
// Then deserialize with the attachment class loader associated with the attachment
serialized.deserialize(context = inboundContext)
}
}
@Test
fun `test loading a class with attachment missing during deserialization`() {
val child = ClassLoaderForTests()
val contractClass = Class.forName("net.corda.contracts.isolated.AnotherDummyContract", true, child)
val contract = contractClass.newInstance() as DummyContractBackdoor
val storage = MockAttachmentStorage()
val attachmentRef = SecureHash.randomSHA256()
val outboundContext = SerializationFactory.defaultFactory.defaultContext.withClassLoader(child)
// Serialize with custom context to avoid populating the default context with the specially loaded class
val serialized = contract.serialize(context = outboundContext)
ClassLoaderForTests().use { child ->
val contractClass = Class.forName(ISOLATED_CONTRACT_CLASS_NAME, true, child)
val contract = contractClass.newInstance() as DummyContractBackdoor
val attachmentRef = SecureHash.randomSHA256()
val outboundContext = SerializationFactory.defaultFactory.defaultContext.withClassLoader(child)
// Serialize with custom context to avoid populating the default context with the specially loaded class
val serialized = contract.serialize(context = outboundContext)
// Then deserialize with the attachment class loader associated with the attachment
val e = assertFailsWith(MissingAttachmentsException::class) {
// We currently ignore annotations in attachments, so manually whitelist.
val inboundContext = SerializationFactory.defaultFactory.defaultContext.withWhitelisted(contract.javaClass).withAttachmentStorage(storage).withAttachmentsClassLoader(listOf(attachmentRef))
serialized.deserialize(context = inboundContext)
// Then deserialize with the attachment class loader associated with the attachment
val e = assertFailsWith(MissingAttachmentsException::class) {
// We currently ignore annotations in attachments, so manually whitelist.
val inboundContext = SerializationFactory
.defaultFactory
.defaultContext
.withWhitelisted(contract.javaClass)
.withServiceHub(serviceHub)
.withAttachmentsClassLoader(listOf(attachmentRef))
serialized.deserialize(context = inboundContext)
}
assertEquals(attachmentRef, e.ids.single())
}
assertEquals(attachmentRef, e.ids.single())
}
}
}

View File

@ -3,11 +3,16 @@ package net.corda.nodeapi.internal.serialization
import com.esotericsoftware.kryo.*
import com.esotericsoftware.kryo.io.Input
import com.esotericsoftware.kryo.io.Output
import com.esotericsoftware.kryo.util.DefaultClassResolver
import com.esotericsoftware.kryo.util.MapReferenceResolver
import com.nhaarman.mockito_kotlin.mock
import com.nhaarman.mockito_kotlin.verify
import com.nhaarman.mockito_kotlin.whenever
import net.corda.core.node.services.AttachmentStorage
import net.corda.core.serialization.*
import net.corda.core.utilities.ByteSequence
import net.corda.nodeapi.AttachmentClassLoaderTests
import net.corda.nodeapi.internal.AttachmentsClassLoader
import net.corda.nodeapi.internal.AttachmentsClassLoaderTests
import net.corda.testing.node.MockAttachmentStorage
import org.junit.Rule
import org.junit.Test
@ -15,6 +20,7 @@ import org.junit.rules.ExpectedException
import java.lang.IllegalStateException
import java.sql.Connection
import java.util.*
import kotlin.test.*
@CordaSerializable
enum class Foo {
@ -52,7 +58,6 @@ open class SerializableViaSubInterface : SerializableSubInterface
class SerializableViaSuperSubInterface : SerializableViaSubInterface()
@CordaSerializable
class CustomSerializable : KryoSerializable {
override fun read(kryo: Kryo?, input: Input?) {
@ -75,16 +80,25 @@ class DefaultSerializableSerializer : Serializer<DefaultSerializable>() {
}
}
object EmptyWhitelist : ClassWhitelist {
override fun hasListed(type: Class<*>): Boolean = false
}
class CordaClassResolverTests {
private companion object {
val emptyListClass = listOf<Any>().javaClass
val emptySetClass = setOf<Any>().javaClass
val emptyMapClass = mapOf<Any, Any>().javaClass
}
val factory: SerializationFactory = object : SerializationFactory() {
override fun <T : Any> deserialize(byteSequence: ByteSequence, clazz: Class<T>, context: SerializationContext): T {
TODO("not implemented") //To change body of created functions use File | Settings | File Templates.
TODO("not implemented")
}
override fun <T : Any> serialize(obj: T, context: SerializationContext): SerializedBytes<T> {
TODO("not implemented") //To change body of created functions use File | Settings | File Templates.
TODO("not implemented")
}
}
private val emptyWhitelistContext: SerializationContext = SerializationContextImpl(KryoHeaderV0_1, this.javaClass.classLoader, EmptyWhitelist, emptyMap(), true, SerializationContext.UseCase.P2P)
@ -156,14 +170,14 @@ class CordaClassResolverTests {
CordaClassResolver(emptyWhitelistContext).getRegistration(DefaultSerializable::class.java)
}
private fun importJar(storage: AttachmentStorage) = AttachmentClassLoaderTests.ISOLATED_CONTRACTS_JAR_PATH.openStream().use { storage.importAttachment(it) }
private fun importJar(storage: AttachmentStorage) = AttachmentsClassLoaderTests.ISOLATED_CONTRACTS_JAR_PATH.openStream().use { storage.importAttachment(it) }
@Test(expected = KryoException::class)
fun `Annotation does not work in conjunction with AttachmentClassLoader annotation`() {
val storage = MockAttachmentStorage()
val attachmentHash = importJar(storage)
val classLoader = AttachmentsClassLoader(arrayOf(attachmentHash).map { storage.openAttachment(it)!! })
val attachedClass = Class.forName("net.corda.contracts.isolated.AnotherDummyContract", true, classLoader)
val attachedClass = Class.forName("net.corda.finance.contracts.isolated.AnotherDummyContract", true, classLoader)
CordaClassResolver(emptyWhitelistContext).getRegistration(attachedClass)
}
@ -193,6 +207,69 @@ class CordaClassResolverTests {
resolver.getRegistration(HashSet::class.java)
}
@Test
fun `Kotlin EmptyList not registered`() {
val resolver = CordaClassResolver(allButBlacklistedContext)
assertNull(resolver.getRegistration(emptyListClass))
}
@Test
fun `Kotlin EmptyList registers as Java emptyList`() {
val javaEmptyListClass = Collections.emptyList<Any>().javaClass
val kryo = mock<Kryo>()
val resolver = CordaClassResolver(allButBlacklistedContext).apply { setKryo(kryo) }
whenever(kryo.getDefaultSerializer(javaEmptyListClass)).thenReturn(DefaultSerializableSerializer())
val registration = resolver.registerImplicit(emptyListClass)
assertNotNull(registration)
assertEquals(javaEmptyListClass, registration.type)
assertEquals(DefaultClassResolver.NAME.toInt(), registration.id)
verify(kryo).getDefaultSerializer(javaEmptyListClass)
assertEquals(registration, resolver.getRegistration(emptyListClass))
}
@Test
fun `Kotlin EmptySet not registered`() {
val resolver = CordaClassResolver(allButBlacklistedContext)
assertNull(resolver.getRegistration(emptySetClass))
}
@Test
fun `Kotlin EmptySet registers as Java emptySet`() {
val javaEmptySetClass = Collections.emptySet<Any>().javaClass
val kryo = mock<Kryo>()
val resolver = CordaClassResolver(allButBlacklistedContext).apply { setKryo(kryo) }
whenever(kryo.getDefaultSerializer(javaEmptySetClass)).thenReturn(DefaultSerializableSerializer())
val registration = resolver.registerImplicit(emptySetClass)
assertNotNull(registration)
assertEquals(javaEmptySetClass, registration.type)
assertEquals(DefaultClassResolver.NAME.toInt(), registration.id)
verify(kryo).getDefaultSerializer(javaEmptySetClass)
assertEquals(registration, resolver.getRegistration(emptySetClass))
}
@Test
fun `Kotlin EmptyMap not registered`() {
val resolver = CordaClassResolver(allButBlacklistedContext)
assertNull(resolver.getRegistration(emptyMapClass))
}
@Test
fun `Kotlin EmptyMap registers as Java emptyMap`() {
val javaEmptyMapClass = Collections.emptyMap<Any, Any>().javaClass
val kryo = mock<Kryo>()
val resolver = CordaClassResolver(allButBlacklistedContext).apply { setKryo(kryo) }
whenever(kryo.getDefaultSerializer(javaEmptyMapClass)).thenReturn(DefaultSerializableSerializer())
val registration = resolver.registerImplicit(emptyMapClass)
assertNotNull(registration)
assertEquals(javaEmptyMapClass, registration.type)
assertEquals(DefaultClassResolver.NAME.toInt(), registration.id)
verify(kryo).getDefaultSerializer(javaEmptyMapClass)
assertEquals(registration, resolver.getRegistration(emptyMapClass))
}
open class SubHashSet<E> : HashSet<E>()
@Test
fun `Check blacklisted subclass`() {

View File

@ -25,9 +25,8 @@ import org.slf4j.LoggerFactory
import java.io.ByteArrayInputStream
import java.io.InputStream
import java.time.Instant
import kotlin.test.assertEquals
import kotlin.test.assertNotNull
import kotlin.test.assertTrue
import java.util.Collections
import kotlin.test.*
class KryoTests : TestDependencyInjectionBase() {
private lateinit var factory: SerializationFactory
@ -113,6 +112,27 @@ class KryoTests : TestDependencyInjectionBase() {
assertThat(deserialised).isSameAs(TestSingleton)
}
@Test
fun `check Kotlin EmptyList can be serialised`() {
val deserialisedList: List<Int> = emptyList<Int>().serialize(factory, context).deserialize(factory, context)
assertEquals(0, deserialisedList.size)
assertEquals<Any>(Collections.emptyList<Int>().javaClass, deserialisedList.javaClass)
}
@Test
fun `check Kotlin EmptySet can be serialised`() {
val deserialisedSet: Set<Int> = emptySet<Int>().serialize(factory, context).deserialize(factory, context)
assertEquals(0, deserialisedSet.size)
assertEquals<Any>(Collections.emptySet<Int>().javaClass, deserialisedSet.javaClass)
}
@Test
fun `check Kotlin EmptyMap can be serialised`() {
val deserialisedMap: Map<Int, Int> = emptyMap<Int, Int>().serialize(factory, context).deserialize(factory, context)
assertEquals(0, deserialisedMap.size)
assertEquals<Any>(Collections.emptyMap<Int, Int>().javaClass, deserialisedMap.javaClass)
}
@Test
fun `InputStream serialisation`() {
val rubbish = ByteArray(12345, { (it * it * 0.12345).toByte() })

View File

@ -1,18 +1,23 @@
package net.corda.nodeapi.internal.serialization
import net.corda.core.serialization.CordaSerializable
import net.corda.core.serialization.SerializedBytes
import net.corda.core.serialization.deserialize
import net.corda.core.serialization.serialize
import com.esotericsoftware.kryo.Kryo
import com.esotericsoftware.kryo.util.DefaultClassResolver
import net.corda.core.serialization.*
import net.corda.node.services.statemachine.SessionData
import net.corda.testing.TestDependencyInjectionBase
import net.corda.testing.amqpSpecific
import org.assertj.core.api.Assertions
import org.junit.Assert.*
import org.junit.Test
import java.io.ByteArrayOutputStream
import java.io.NotSerializableException
import kotlin.test.assertEquals
import java.nio.charset.StandardCharsets.*
import java.util.*
class ListsSerializationTest : TestDependencyInjectionBase() {
private companion object {
val javaEmptyListClass = Collections.emptyList<Any>().javaClass
}
@Test
fun `check list can be serialized as root of serialization graph`() {
@ -23,16 +28,32 @@ class ListsSerializationTest : TestDependencyInjectionBase() {
@Test
fun `check list can be serialized as part of SessionData`() {
run {
val sessionData = SessionData(123, listOf(1))
assertEqualAfterRoundTripSerialization(sessionData)
}
run {
val sessionData = SessionData(123, listOf(1, 2))
assertEqualAfterRoundTripSerialization(sessionData)
}
run {
val sessionData = SessionData(123, emptyList<Int>())
assertEqualAfterRoundTripSerialization(sessionData)
}
}
@Test
fun `check empty list serialises as Java emptyList`() {
val nameID = 0
val serializedForm = emptyList<Int>().serialize()
val output = ByteArrayOutputStream().apply {
write(KryoHeaderV0_1.bytes)
write(DefaultClassResolver.NAME + 2)
write(nameID)
write(javaEmptyListClass.name.toAscii())
write(Kryo.NOT_NULL.toInt())
}
assertArrayEquals(output.toByteArray(), serializedForm.bytes)
}
@CordaSerializable
@ -55,4 +76,8 @@ internal inline fun<reified T : Any> assertEqualAfterRoundTripSerialization(obj:
val deserializedInstance = serializedForm.deserialize()
assertEquals(obj, deserializedInstance)
}
}
internal fun String.toAscii(): ByteArray = toByteArray(US_ASCII).apply {
this[lastIndex] = (this[lastIndex] + 0x80).toByte()
}

View File

@ -1,18 +1,25 @@
package net.corda.nodeapi.internal.serialization
import com.esotericsoftware.kryo.Kryo
import com.esotericsoftware.kryo.util.DefaultClassResolver
import net.corda.core.serialization.CordaSerializable
import net.corda.core.serialization.serialize
import net.corda.node.services.statemachine.SessionData
import net.corda.testing.TestDependencyInjectionBase
import net.corda.testing.amqpSpecific
import net.corda.testing.kryoSpecific
import org.assertj.core.api.Assertions
import org.junit.Assert.assertArrayEquals
import org.junit.Test
import org.bouncycastle.asn1.x500.X500Name
import java.io.NotSerializableException
import java.io.ByteArrayOutputStream
import java.util.*
class MapsSerializationTest : TestDependencyInjectionBase() {
private val smallMap = mapOf("foo" to "bar", "buzz" to "bull")
private companion object {
val javaEmptyMapClass = Collections.emptyMap<Any, Any>().javaClass
val smallMap = mapOf("foo" to "bar", "buzz" to "bull")
}
@Test
fun `check EmptyMap serialization`() = amqpSpecific<MapsSerializationTest>("kotlin.collections.EmptyMap is not enabled for Kryo serialization") {
@ -38,7 +45,8 @@ class MapsSerializationTest : TestDependencyInjectionBase() {
val payload = HashMap<String, String>(smallMap)
val wrongPayloadType = WrongPayloadType(payload)
Assertions.assertThatThrownBy { wrongPayloadType.serialize() }
.isInstanceOf(NotSerializableException::class.java).hasMessageContaining("Cannot derive map type for declaredType")
.isInstanceOf(IllegalArgumentException::class.java).hasMessageContaining(
"Map type class java.util.HashMap is unstable under iteration. Suggested fix: use java.util.LinkedHashMap instead.")
}
@CordaSerializable
@ -54,4 +62,18 @@ class MapsSerializationTest : TestDependencyInjectionBase() {
MyKey(10.0) to MyValue(X500Name("CN=ten")))
assertEqualAfterRoundTripSerialization(myMap)
}
@Test
fun `check empty map serialises as Java emptyMap`() = kryoSpecific<MapsSerializationTest>("Specifically checks Kryo serialization") {
val nameID = 0
val serializedForm = emptyMap<Int, Int>().serialize()
val output = ByteArrayOutputStream().apply {
write(KryoHeaderV0_1.bytes)
write(DefaultClassResolver.NAME + 2)
write(nameID)
write(javaEmptyMapClass.name.toAscii())
write(Kryo.NOT_NULL.toInt())
}
assertArrayEquals(output.toByteArray(), serializedForm.bytes)
}
}

View File

@ -0,0 +1,54 @@
package net.corda.nodeapi.internal.serialization
import com.esotericsoftware.kryo.Kryo
import com.esotericsoftware.kryo.util.DefaultClassResolver
import net.corda.core.serialization.serialize
import net.corda.node.services.statemachine.SessionData
import net.corda.testing.TestDependencyInjectionBase
import org.junit.Assert.*
import org.junit.Test
import java.io.ByteArrayOutputStream
import java.util.*
class SetsSerializationTest : TestDependencyInjectionBase() {
private companion object {
val javaEmptySetClass = Collections.emptySet<Any>().javaClass
}
@Test
fun `check set can be serialized as root of serialization graph`() {
assertEqualAfterRoundTripSerialization(emptySet<Int>())
assertEqualAfterRoundTripSerialization(setOf(1))
assertEqualAfterRoundTripSerialization(setOf(1, 2))
}
@Test
fun `check set can be serialized as part of SessionData`() {
run {
val sessionData = SessionData(123, setOf(1))
assertEqualAfterRoundTripSerialization(sessionData)
}
run {
val sessionData = SessionData(123, setOf(1, 2))
assertEqualAfterRoundTripSerialization(sessionData)
}
run {
val sessionData = SessionData(123, emptySet<Int>())
assertEqualAfterRoundTripSerialization(sessionData)
}
}
@Test
fun `check empty set serialises as Java emptySet`() {
val nameID = 0
val serializedForm = emptySet<Int>().serialize()
val output = ByteArrayOutputStream().apply {
write(KryoHeaderV0_1.bytes)
write(DefaultClassResolver.NAME + 2)
write(nameID)
write(javaEmptySetClass.name.toAscii())
write(Kryo.NOT_NULL.toInt())
}
assertArrayEquals(output.toByteArray(), serializedForm.bytes)
}
}

View File

@ -2,10 +2,9 @@ package net.corda.nodeapi.internal.serialization.amqp
import org.assertj.core.api.Assertions
import org.junit.Test
import java.io.NotSerializableException
import java.util.*
class DeserializeCollectionTests {
class DeserializeMapTests {
companion object {
/**
* If you want to see the schema encoded into the envelope after serialisation change this to true
@ -82,7 +81,7 @@ class DeserializeCollectionTests {
// expected to throw
Assertions.assertThatThrownBy { TestSerializationOutput(VERBOSE, sf).serialize(c) }
.isInstanceOf(NotSerializableException::class.java).hasMessageContaining("Cannot derive map type for declaredType")
.isInstanceOf(IllegalArgumentException::class.java).hasMessageContaining("Unable to serialise deprecated type class java.util.Hashtable. Suggested fix: prefer java.util.map implementations")
}
@Test
@ -92,7 +91,7 @@ class DeserializeCollectionTests {
// expect this to throw
Assertions.assertThatThrownBy { TestSerializationOutput(VERBOSE, sf).serialize(c) }
.isInstanceOf(NotSerializableException::class.java).hasMessageContaining("Cannot derive map type for declaredType")
.isInstanceOf(IllegalArgumentException::class.java).hasMessageContaining("Map type class java.util.HashMap is unstable under iteration. Suggested fix: use java.util.LinkedHashMap instead.")
}
@Test
@ -101,7 +100,7 @@ class DeserializeCollectionTests {
val c = C (WeakHashMap (mapOf("A" to 1, "B" to 2)))
Assertions.assertThatThrownBy { TestSerializationOutput(VERBOSE, sf).serialize(c) }
.isInstanceOf(NotSerializableException::class.java).hasMessageContaining("Cannot derive map type for declaredType")
.isInstanceOf(IllegalArgumentException::class.java).hasMessageContaining("Weak references with map types not supported. Suggested fix: use java.util.LinkedHashMap instead.")
}
@Test

View File

@ -0,0 +1,110 @@
package net.corda.nodeapi.internal.serialization.amqp
import net.corda.nodeapi.internal.serialization.AllWhitelist
import org.junit.Test
import kotlin.test.*
import net.corda.nodeapi.internal.serialization.carpenter.*
class DeserializeNeedingCarpentryOfEnumsTest : AmqpCarpenterBase(AllWhitelist) {
companion object {
/**
* If you want to see the schema encoded into the envelope after serialisation change this to true
*/
private const val VERBOSE = false
}
@Test
fun singleEnum() {
//
// Setup the test
//
val setupFactory = testDefaultFactory()
val enumConstants = listOf("AAA", "BBB", "CCC", "DDD", "EEE", "FFF",
"GGG", "HHH", "III", "JJJ").associateBy({ it }, { EnumField() })
// create the enum
val testEnumType = setupFactory.classCarpenter.build(EnumSchema("test.testEnumType", enumConstants))
// create the class that has that enum as an element
val testClassType = setupFactory.classCarpenter.build(ClassSchema("test.testClassType",
mapOf("a" to NonNullableField(testEnumType))))
// create an instance of the class we can then serialise
val testInstance = testClassType.constructors[0].newInstance(testEnumType.getMethod(
"valueOf", String::class.java).invoke(null, "BBB"))
// serialise the object
val serialisedBytes = TestSerializationOutput(VERBOSE, setupFactory).serialize(testInstance)
//
// Test setup done, now on with the actual test
//
// need a second factory to ensure a second carpenter is used and thus the class we're attempting
// to de-serialise isn't in the factories class loader
val testFactory = testDefaultFactoryWithWhitelist()
val deserializedObj = DeserializationInput(testFactory).deserialize(serialisedBytes)
assertTrue(deserializedObj::class.java.getMethod("getA").invoke(deserializedObj)::class.java.isEnum)
assertEquals("BBB",
(deserializedObj::class.java.getMethod("getA").invoke(deserializedObj) as Enum<*>).name)
}
@Test
fun compositeIncludingEnums() {
//
// Setup the test
//
val setupFactory = testDefaultFactory()
val enumConstants = listOf("AAA", "BBB", "CCC", "DDD", "EEE", "FFF",
"GGG", "HHH", "III", "JJJ").associateBy({ it }, { EnumField() })
// create the enum
val testEnumType1 = setupFactory.classCarpenter.build(EnumSchema("test.testEnumType1", enumConstants))
val testEnumType2 = setupFactory.classCarpenter.build(EnumSchema("test.testEnumType2", enumConstants))
// create the class that has that enum as an element
val testClassType = setupFactory.classCarpenter.build(ClassSchema("test.testClassType",
mapOf(
"a" to NonNullableField(testEnumType1),
"b" to NonNullableField(testEnumType2),
"c" to NullableField(testEnumType1),
"d" to NullableField(String::class.java))))
val vOf1 = testEnumType1.getMethod("valueOf", String::class.java)
val vOf2 = testEnumType2.getMethod("valueOf", String::class.java)
val testStr = "so many things [Ø Þ]"
// create an instance of the class we can then serialise
val testInstance = testClassType.constructors[0].newInstance(
vOf1.invoke(null, "CCC"),
vOf2.invoke(null, "EEE"),
null,
testStr)
// serialise the object
val serialisedBytes = TestSerializationOutput(VERBOSE, setupFactory).serialize(testInstance)
//
// Test setup done, now on with the actual test
//
// need a second factory to ensure a second carpenter is used and thus the class we're attempting
// to de-serialise isn't in the factories class loader
val testFactory = testDefaultFactoryWithWhitelist()
val deserializedObj = DeserializationInput(testFactory).deserialize(serialisedBytes)
assertTrue(deserializedObj::class.java.getMethod("getA").invoke(deserializedObj)::class.java.isEnum)
assertEquals("CCC",
(deserializedObj::class.java.getMethod("getA").invoke(deserializedObj) as Enum<*>).name)
assertTrue(deserializedObj::class.java.getMethod("getB").invoke(deserializedObj)::class.java.isEnum)
assertEquals("EEE",
(deserializedObj::class.java.getMethod("getB").invoke(deserializedObj) as Enum<*>).name)
assertNull(deserializedObj::class.java.getMethod("getC").invoke(deserializedObj))
assertEquals(testStr, deserializedObj::class.java.getMethod("getD").invoke(deserializedObj))
}
}

View File

@ -1,5 +1,6 @@
package net.corda.nodeapi.internal.serialization.amqp
import net.corda.nodeapi.internal.serialization.AllWhitelist
import org.junit.Test
import kotlin.test.*
import net.corda.nodeapi.internal.serialization.carpenter.*
@ -8,7 +9,7 @@ import net.corda.nodeapi.internal.serialization.carpenter.*
// those classes don't exist within the system's Class Loader the deserialiser will be forced to carpent
// versions of them up using its own internal class carpenter (each carpenter houses it's own loader). This
// replicates the situation where a receiver doesn't have some or all elements of a schema present on it's classpath
class DeserializeNeedingCarpentrySimpleTypesTest : AmqpCarpenterBase() {
class DeserializeNeedingCarpentrySimpleTypesTest : AmqpCarpenterBase(AllWhitelist) {
companion object {
/**
* If you want to see the schema encoded into the envelope after serialisation change this to true
@ -16,12 +17,12 @@ class DeserializeNeedingCarpentrySimpleTypesTest : AmqpCarpenterBase() {
private const val VERBOSE = false
}
val sf = testDefaultFactory()
val sf2 = testDefaultFactory()
private val sf = testDefaultFactory()
private val sf2 = testDefaultFactory()
@Test
fun singleInt() {
val clazz = ClassCarpenter().build(ClassSchema(testName(), mapOf(
val clazz = ClassCarpenter(whitelist = AllWhitelist).build(ClassSchema(testName(), mapOf(
"int" to NonNullableField(Integer::class.javaPrimitiveType!!)
)))
@ -31,9 +32,9 @@ class DeserializeNeedingCarpentrySimpleTypesTest : AmqpCarpenterBase() {
// despite being carpented, and thus not on the class path, we should've cached clazz
// inside the serialiser object and thus we should have created the same type
assertEquals (db::class.java, clazz)
assertNotEquals (db2::class.java, clazz)
assertNotEquals (db::class.java, db2::class.java)
assertEquals(db::class.java, clazz)
assertNotEquals(db2::class.java, clazz)
assertNotEquals(db::class.java, db2::class.java)
assertEquals(1, db::class.java.getMethod("getInt").invoke(db))
assertEquals(1, db2::class.java.getMethod("getInt").invoke(db2))
@ -41,7 +42,7 @@ class DeserializeNeedingCarpentrySimpleTypesTest : AmqpCarpenterBase() {
@Test
fun singleIntNullable() {
val clazz = ClassCarpenter().build(ClassSchema(testName(), mapOf(
val clazz = ClassCarpenter(whitelist = AllWhitelist).build(ClassSchema(testName(), mapOf(
"int" to NullableField(Integer::class.java)
)))
@ -57,7 +58,7 @@ class DeserializeNeedingCarpentrySimpleTypesTest : AmqpCarpenterBase() {
@Test
fun singleIntNullableNull() {
val clazz = ClassCarpenter().build(ClassSchema(testName(), mapOf(
val clazz = ClassCarpenter(whitelist = AllWhitelist).build(ClassSchema(testName(), mapOf(
"int" to NullableField(Integer::class.java)
)))
@ -73,7 +74,7 @@ class DeserializeNeedingCarpentrySimpleTypesTest : AmqpCarpenterBase() {
@Test
fun singleChar() {
val clazz = ClassCarpenter().build(ClassSchema(testName(), mapOf(
val clazz = ClassCarpenter(whitelist = AllWhitelist).build(ClassSchema(testName(), mapOf(
"char" to NonNullableField(Character::class.javaPrimitiveType!!)
)))
@ -86,7 +87,7 @@ class DeserializeNeedingCarpentrySimpleTypesTest : AmqpCarpenterBase() {
@Test
fun singleCharNullable() {
val clazz = ClassCarpenter().build(ClassSchema(testName(), mapOf(
val clazz = ClassCarpenter(whitelist = AllWhitelist).build(ClassSchema(testName(), mapOf(
"char" to NullableField(Character::class.javaObjectType)
)))
@ -99,7 +100,7 @@ class DeserializeNeedingCarpentrySimpleTypesTest : AmqpCarpenterBase() {
@Test
fun singleCharNullableNull() {
val clazz = ClassCarpenter().build(ClassSchema(testName(), mapOf(
val clazz = ClassCarpenter(whitelist = AllWhitelist).build(ClassSchema(testName(), mapOf(
"char" to NullableField(java.lang.Character::class.java)
)))
@ -112,11 +113,11 @@ class DeserializeNeedingCarpentrySimpleTypesTest : AmqpCarpenterBase() {
@Test
fun singleLong() {
val clazz = ClassCarpenter().build(ClassSchema(testName(), mapOf(
val clazz = ClassCarpenter(whitelist = AllWhitelist).build(ClassSchema(testName(), mapOf(
"long" to NonNullableField(Long::class.javaPrimitiveType!!)
)))
val l : Long = 1
val l: Long = 1
val sb = TestSerializationOutput(VERBOSE, sf).serialize(clazz.constructors.first().newInstance(l))
val db = DeserializationInput(sf2).deserialize(sb)
@ -126,11 +127,11 @@ class DeserializeNeedingCarpentrySimpleTypesTest : AmqpCarpenterBase() {
@Test
fun singleLongNullable() {
val clazz = ClassCarpenter().build(ClassSchema(testName(), mapOf(
val clazz = ClassCarpenter(whitelist = AllWhitelist).build(ClassSchema(testName(), mapOf(
"long" to NullableField(Long::class.javaObjectType)
)))
val l : Long = 1
val l: Long = 1
val sb = TestSerializationOutput(VERBOSE, sf).serialize(clazz.constructors.first().newInstance(l))
val db = DeserializationInput(sf2).deserialize(sb)
@ -140,7 +141,7 @@ class DeserializeNeedingCarpentrySimpleTypesTest : AmqpCarpenterBase() {
@Test
fun singleLongNullableNull() {
val clazz = ClassCarpenter().build(ClassSchema(testName(), mapOf(
val clazz = ClassCarpenter(whitelist = AllWhitelist).build(ClassSchema(testName(), mapOf(
"long" to NullableField(Long::class.javaObjectType)
)))
@ -153,7 +154,7 @@ class DeserializeNeedingCarpentrySimpleTypesTest : AmqpCarpenterBase() {
@Test
fun singleBoolean() {
val clazz = ClassCarpenter().build(ClassSchema(testName(), mapOf(
val clazz = ClassCarpenter(whitelist = AllWhitelist).build(ClassSchema(testName(), mapOf(
"boolean" to NonNullableField(Boolean::class.javaPrimitiveType!!)
)))
@ -166,7 +167,7 @@ class DeserializeNeedingCarpentrySimpleTypesTest : AmqpCarpenterBase() {
@Test
fun singleBooleanNullable() {
val clazz = ClassCarpenter().build(ClassSchema(testName(), mapOf(
val clazz = ClassCarpenter(whitelist = AllWhitelist).build(ClassSchema(testName(), mapOf(
"boolean" to NullableField(Boolean::class.javaObjectType)
)))
@ -179,7 +180,7 @@ class DeserializeNeedingCarpentrySimpleTypesTest : AmqpCarpenterBase() {
@Test
fun singleBooleanNullableNull() {
val clazz = ClassCarpenter().build(ClassSchema(testName(), mapOf(
val clazz = ClassCarpenter(whitelist = AllWhitelist).build(ClassSchema(testName(), mapOf(
"boolean" to NullableField(Boolean::class.javaObjectType)
)))
@ -192,7 +193,7 @@ class DeserializeNeedingCarpentrySimpleTypesTest : AmqpCarpenterBase() {
@Test
fun singleDouble() {
val clazz = ClassCarpenter().build(ClassSchema(testName(), mapOf(
val clazz = ClassCarpenter(whitelist = AllWhitelist).build(ClassSchema(testName(), mapOf(
"double" to NonNullableField(Double::class.javaPrimitiveType!!)
)))
@ -205,7 +206,7 @@ class DeserializeNeedingCarpentrySimpleTypesTest : AmqpCarpenterBase() {
@Test
fun singleDoubleNullable() {
val clazz = ClassCarpenter().build(ClassSchema(testName(), mapOf(
val clazz = ClassCarpenter(whitelist = AllWhitelist).build(ClassSchema(testName(), mapOf(
"double" to NullableField(Double::class.javaObjectType)
)))
@ -218,7 +219,7 @@ class DeserializeNeedingCarpentrySimpleTypesTest : AmqpCarpenterBase() {
@Test
fun singleDoubleNullableNull() {
val clazz = ClassCarpenter().build(ClassSchema(testName(), mapOf(
val clazz = ClassCarpenter(whitelist = AllWhitelist).build(ClassSchema(testName(), mapOf(
"double" to NullableField(Double::class.javaObjectType)
)))
@ -231,7 +232,7 @@ class DeserializeNeedingCarpentrySimpleTypesTest : AmqpCarpenterBase() {
@Test
fun singleShort() {
val clazz = ClassCarpenter().build(ClassSchema(testName(), mapOf(
val clazz = ClassCarpenter(whitelist = AllWhitelist).build(ClassSchema(testName(), mapOf(
"short" to NonNullableField(Short::class.javaPrimitiveType!!)
)))
@ -244,7 +245,7 @@ class DeserializeNeedingCarpentrySimpleTypesTest : AmqpCarpenterBase() {
@Test
fun singleShortNullable() {
val clazz = ClassCarpenter().build(ClassSchema(testName(), mapOf(
val clazz = ClassCarpenter(whitelist = AllWhitelist).build(ClassSchema(testName(), mapOf(
"short" to NullableField(Short::class.javaObjectType)
)))
@ -257,7 +258,7 @@ class DeserializeNeedingCarpentrySimpleTypesTest : AmqpCarpenterBase() {
@Test
fun singleShortNullableNull() {
val clazz = ClassCarpenter().build(ClassSchema(testName(), mapOf(
val clazz = ClassCarpenter(whitelist = AllWhitelist).build(ClassSchema(testName(), mapOf(
"short" to NullableField(Short::class.javaObjectType)
)))
@ -270,7 +271,7 @@ class DeserializeNeedingCarpentrySimpleTypesTest : AmqpCarpenterBase() {
@Test
fun singleFloat() {
val clazz = ClassCarpenter().build(ClassSchema(testName(), mapOf(
val clazz = ClassCarpenter(whitelist = AllWhitelist).build(ClassSchema(testName(), mapOf(
"float" to NonNullableField(Float::class.javaPrimitiveType!!)
)))
@ -283,7 +284,7 @@ class DeserializeNeedingCarpentrySimpleTypesTest : AmqpCarpenterBase() {
@Test
fun singleFloatNullable() {
val clazz = ClassCarpenter().build(ClassSchema(testName(), mapOf(
val clazz = ClassCarpenter(whitelist = AllWhitelist).build(ClassSchema(testName(), mapOf(
"float" to NullableField(Float::class.javaObjectType)
)))
@ -296,7 +297,7 @@ class DeserializeNeedingCarpentrySimpleTypesTest : AmqpCarpenterBase() {
@Test
fun singleFloatNullableNull() {
val clazz = ClassCarpenter().build(ClassSchema(testName(), mapOf(
val clazz = ClassCarpenter(whitelist = AllWhitelist).build(ClassSchema(testName(), mapOf(
"float" to NullableField(Float::class.javaObjectType)
)))
@ -309,11 +310,11 @@ class DeserializeNeedingCarpentrySimpleTypesTest : AmqpCarpenterBase() {
@Test
fun singleByte() {
val clazz = ClassCarpenter().build(ClassSchema(testName(), mapOf(
val clazz = ClassCarpenter(whitelist = AllWhitelist).build(ClassSchema(testName(), mapOf(
"byte" to NonNullableField(Byte::class.javaPrimitiveType!!)
)))
val b : Byte = 0b0101
val b: Byte = 0b0101
val sb = TestSerializationOutput(VERBOSE, sf).serialize(clazz.constructors.first().newInstance(b))
val db = DeserializationInput(sf2).deserialize(sb)
@ -324,11 +325,11 @@ class DeserializeNeedingCarpentrySimpleTypesTest : AmqpCarpenterBase() {
@Test
fun singleByteNullable() {
val clazz = ClassCarpenter().build(ClassSchema(testName(), mapOf(
val clazz = ClassCarpenter(whitelist = AllWhitelist).build(ClassSchema(testName(), mapOf(
"byte" to NullableField(Byte::class.javaObjectType)
)))
val b : Byte = 0b0101
val b: Byte = 0b0101
val sb = TestSerializationOutput(VERBOSE, sf).serialize(clazz.constructors.first().newInstance(b))
val db = DeserializationInput(sf2).deserialize(sb)
@ -339,7 +340,7 @@ class DeserializeNeedingCarpentrySimpleTypesTest : AmqpCarpenterBase() {
@Test
fun singleByteNullableNull() {
val clazz = ClassCarpenter().build(ClassSchema(testName(), mapOf(
val clazz = ClassCarpenter(whitelist = AllWhitelist).build(ClassSchema(testName(), mapOf(
"byte" to NullableField(Byte::class.javaObjectType)
)))
@ -352,9 +353,9 @@ class DeserializeNeedingCarpentrySimpleTypesTest : AmqpCarpenterBase() {
@Test
fun simpleTypeKnownInterface() {
val clazz = ClassCarpenter().build (ClassSchema(
val clazz = ClassCarpenter(whitelist = AllWhitelist).build(ClassSchema(
testName(), mapOf("name" to NonNullableField(String::class.java)),
interfaces = listOf (I::class.java)))
interfaces = listOf(I::class.java)))
val testVal = "Some Person"
val classInstance = clazz.constructors[0].newInstance(testVal)
val serialisedBytes = TestSerializationOutput(VERBOSE, sf).serialize(classInstance)
@ -367,34 +368,34 @@ class DeserializeNeedingCarpentrySimpleTypesTest : AmqpCarpenterBase() {
@Test
fun manyTypes() {
val manyClass = ClassCarpenter().build (ClassSchema(testName(), mapOf(
"intA" to NonNullableField (Int::class.java),
"intB" to NullableField (Integer::class.java),
"intC" to NullableField (Integer::class.java),
"strA" to NonNullableField (String::class.java),
"strB" to NullableField (String::class.java),
"strC" to NullableField (String::class.java),
"charA" to NonNullableField (Char::class.java),
"charB" to NullableField (Character::class.javaObjectType),
"charC" to NullableField (Character::class.javaObjectType),
"shortA" to NonNullableField (Short::class.javaPrimitiveType!!),
"shortB" to NullableField (Short::class.javaObjectType),
"shortC" to NullableField (Short::class.javaObjectType),
"longA" to NonNullableField (Long::class.javaPrimitiveType!!),
val manyClass = ClassCarpenter(whitelist = AllWhitelist).build(ClassSchema(testName(), mapOf(
"intA" to NonNullableField(Int::class.java),
"intB" to NullableField(Integer::class.java),
"intC" to NullableField(Integer::class.java),
"strA" to NonNullableField(String::class.java),
"strB" to NullableField(String::class.java),
"strC" to NullableField(String::class.java),
"charA" to NonNullableField(Char::class.java),
"charB" to NullableField(Character::class.javaObjectType),
"charC" to NullableField(Character::class.javaObjectType),
"shortA" to NonNullableField(Short::class.javaPrimitiveType!!),
"shortB" to NullableField(Short::class.javaObjectType),
"shortC" to NullableField(Short::class.javaObjectType),
"longA" to NonNullableField(Long::class.javaPrimitiveType!!),
"longB" to NullableField(Long::class.javaObjectType),
"longC" to NullableField(Long::class.javaObjectType),
"booleanA" to NonNullableField (Boolean::class.javaPrimitiveType!!),
"booleanB" to NullableField (Boolean::class.javaObjectType),
"booleanC" to NullableField (Boolean::class.javaObjectType),
"doubleA" to NonNullableField (Double::class.javaPrimitiveType!!),
"doubleB" to NullableField (Double::class.javaObjectType),
"doubleC" to NullableField (Double::class.javaObjectType),
"floatA" to NonNullableField (Float::class.javaPrimitiveType!!),
"floatB" to NullableField (Float::class.javaObjectType),
"floatC" to NullableField (Float::class.javaObjectType),
"byteA" to NonNullableField (Byte::class.javaPrimitiveType!!),
"byteB" to NullableField (Byte::class.javaObjectType),
"byteC" to NullableField (Byte::class.javaObjectType))))
"booleanA" to NonNullableField(Boolean::class.javaPrimitiveType!!),
"booleanB" to NullableField(Boolean::class.javaObjectType),
"booleanC" to NullableField(Boolean::class.javaObjectType),
"doubleA" to NonNullableField(Double::class.javaPrimitiveType!!),
"doubleB" to NullableField(Double::class.javaObjectType),
"doubleC" to NullableField(Double::class.javaObjectType),
"floatA" to NonNullableField(Float::class.javaPrimitiveType!!),
"floatB" to NullableField(Float::class.javaObjectType),
"floatC" to NullableField(Float::class.javaObjectType),
"byteA" to NonNullableField(Byte::class.javaPrimitiveType!!),
"byteB" to NullableField(Byte::class.javaObjectType),
"byteC" to NullableField(Byte::class.javaObjectType))))
val serialisedBytes = TestSerializationOutput(VERBOSE, factory).serialize(
manyClass.constructors.first().newInstance(
@ -411,33 +412,33 @@ class DeserializeNeedingCarpentrySimpleTypesTest : AmqpCarpenterBase() {
val deserializedObj = DeserializationInput(sf2).deserialize(serialisedBytes)
assertNotEquals(manyClass, deserializedObj::class.java)
assertEquals(1, deserializedObj::class.java.getMethod("getIntA").invoke(deserializedObj))
assertEquals(2, deserializedObj::class.java.getMethod("getIntB").invoke(deserializedObj))
assertEquals(1, deserializedObj::class.java.getMethod("getIntA").invoke(deserializedObj))
assertEquals(2, deserializedObj::class.java.getMethod("getIntB").invoke(deserializedObj))
assertEquals(null, deserializedObj::class.java.getMethod("getIntC").invoke(deserializedObj))
assertEquals("a", deserializedObj::class.java.getMethod("getStrA").invoke(deserializedObj))
assertEquals("b", deserializedObj::class.java.getMethod("getStrB").invoke(deserializedObj))
assertEquals("a", deserializedObj::class.java.getMethod("getStrA").invoke(deserializedObj))
assertEquals("b", deserializedObj::class.java.getMethod("getStrB").invoke(deserializedObj))
assertEquals(null, deserializedObj::class.java.getMethod("getStrC").invoke(deserializedObj))
assertEquals('c', deserializedObj::class.java.getMethod("getCharA").invoke(deserializedObj))
assertEquals('d', deserializedObj::class.java.getMethod("getCharB").invoke(deserializedObj))
assertEquals('c', deserializedObj::class.java.getMethod("getCharA").invoke(deserializedObj))
assertEquals('d', deserializedObj::class.java.getMethod("getCharB").invoke(deserializedObj))
assertEquals(null, deserializedObj::class.java.getMethod("getCharC").invoke(deserializedObj))
assertEquals(3.toShort(), deserializedObj::class.java.getMethod("getShortA").invoke(deserializedObj))
assertEquals(4.toShort(), deserializedObj::class.java.getMethod("getShortB").invoke(deserializedObj))
assertEquals(null, deserializedObj::class.java.getMethod("getShortC").invoke(deserializedObj))
assertEquals(null, deserializedObj::class.java.getMethod("getShortC").invoke(deserializedObj))
assertEquals(100.toLong(), deserializedObj::class.java.getMethod("getLongA").invoke(deserializedObj))
assertEquals(200.toLong(), deserializedObj::class.java.getMethod("getLongB").invoke(deserializedObj))
assertEquals(null, deserializedObj::class.java.getMethod("getLongC").invoke(deserializedObj))
assertEquals(true, deserializedObj::class.java.getMethod("getBooleanA").invoke(deserializedObj))
assertEquals(null, deserializedObj::class.java.getMethod("getLongC").invoke(deserializedObj))
assertEquals(true, deserializedObj::class.java.getMethod("getBooleanA").invoke(deserializedObj))
assertEquals(false, deserializedObj::class.java.getMethod("getBooleanB").invoke(deserializedObj))
assertEquals(null, deserializedObj::class.java.getMethod("getBooleanC").invoke(deserializedObj))
assertEquals(null, deserializedObj::class.java.getMethod("getBooleanC").invoke(deserializedObj))
assertEquals(10.0, deserializedObj::class.java.getMethod("getDoubleA").invoke(deserializedObj))
assertEquals(20.0, deserializedObj::class.java.getMethod("getDoubleB").invoke(deserializedObj))
assertEquals(null, deserializedObj::class.java.getMethod("getDoubleC").invoke(deserializedObj))
assertEquals(10.0F, deserializedObj::class.java.getMethod("getFloatA").invoke(deserializedObj))
assertEquals(20.0F, deserializedObj::class.java.getMethod("getFloatB").invoke(deserializedObj))
assertEquals(null, deserializedObj::class.java.getMethod("getFloatC").invoke(deserializedObj))
assertEquals(null, deserializedObj::class.java.getMethod("getFloatC").invoke(deserializedObj))
assertEquals(0b0101.toByte(), deserializedObj::class.java.getMethod("getByteA").invoke(deserializedObj))
assertEquals(0b1010.toByte(), deserializedObj::class.java.getMethod("getByteB").invoke(deserializedObj))
assertEquals(null, deserializedObj::class.java.getMethod("getByteC").invoke(deserializedObj))
assertEquals(null, deserializedObj::class.java.getMethod("getByteC").invoke(deserializedObj))
}
}

View File

@ -8,7 +8,7 @@ import net.corda.nodeapi.internal.serialization.AllWhitelist
@CordaSerializable
interface I {
fun getName() : String
fun getName(): String
}
// These tests work by having the class carpenter build the classes we serialise and then deserialise them
@ -19,7 +19,7 @@ interface I {
// However, those classes don't exist within the system's Class Loader and thus the deserialiser will be forced
// to carpent versions of them up using its own internal class carpenter (each carpenter houses it's own loader). This
// replicates the situation where a receiver doesn't have some or all elements of a schema present on it's classpath
class DeserializeNeedingCarpentryTests : AmqpCarpenterBase() {
class DeserializeNeedingCarpentryTests : AmqpCarpenterBase(AllWhitelist) {
companion object {
/**
* If you want to see the schema encoded into the envelope after serialisation change this to true
@ -27,38 +27,40 @@ class DeserializeNeedingCarpentryTests : AmqpCarpenterBase() {
private const val VERBOSE = false
}
val sf1 = testDefaultFactory()
val sf2 = testDefaultFactoryWithWhitelist() // Deserialize with whitelisting on to check that `CordaSerializable` annotation present.
private val sf1 = testDefaultFactory()
// Deserialize with whitelisting on to check that `CordaSerializable` annotation present.
private val sf2 = testDefaultFactoryWithWhitelist()
@Test
fun verySimpleType() {
val testVal = 10
val clazz = ClassCarpenter().build(ClassSchema(testName(), mapOf("a" to NonNullableField(Int::class.java))))
val clazz = ClassCarpenter(whitelist = AllWhitelist).build(ClassSchema(testName(),
mapOf("a" to NonNullableField(Int::class.java))))
val classInstance = clazz.constructors[0].newInstance(testVal)
val serialisedBytes = TestSerializationOutput(VERBOSE, sf1).serialize(classInstance)
val deserializedObj1 = DeserializationInput(sf1).deserialize(serialisedBytes)
assertEquals(clazz, deserializedObj1::class.java)
assertEquals (testVal, deserializedObj1::class.java.getMethod("getA").invoke(deserializedObj1))
assertEquals(testVal, deserializedObj1::class.java.getMethod("getA").invoke(deserializedObj1))
val deserializedObj2 = DeserializationInput(sf1).deserialize(serialisedBytes)
assertEquals(clazz, deserializedObj2::class.java)
assertEquals(deserializedObj1::class.java, deserializedObj2::class.java)
assertEquals (testVal, deserializedObj2::class.java.getMethod("getA").invoke(deserializedObj2))
assertEquals(testVal, deserializedObj2::class.java.getMethod("getA").invoke(deserializedObj2))
val deserializedObj3 = DeserializationInput(sf2).deserialize(serialisedBytes)
assertNotEquals(clazz, deserializedObj3::class.java)
assertNotEquals(deserializedObj1::class.java, deserializedObj3::class.java)
assertNotEquals(deserializedObj2::class.java, deserializedObj3::class.java)
assertEquals (testVal, deserializedObj3::class.java.getMethod("getA").invoke(deserializedObj3))
assertEquals(testVal, deserializedObj3::class.java.getMethod("getA").invoke(deserializedObj3))
val deserializedObj4 = DeserializationInput(sf2).deserialize(serialisedBytes)
assertNotEquals(clazz, deserializedObj4::class.java)
assertNotEquals(deserializedObj1::class.java, deserializedObj4::class.java)
assertNotEquals(deserializedObj2::class.java, deserializedObj4::class.java)
assertEquals(deserializedObj3::class.java, deserializedObj4::class.java)
assertEquals (testVal, deserializedObj4::class.java.getMethod("getA").invoke(deserializedObj4))
assertEquals(testVal, deserializedObj4::class.java.getMethod("getA").invoke(deserializedObj4))
}
@ -67,7 +69,7 @@ class DeserializeNeedingCarpentryTests : AmqpCarpenterBase() {
val testValA = 10
val testValB = 20
val testValC = 20
val clazz = ClassCarpenter().build(ClassSchema("${testName()}_clazz",
val clazz = ClassCarpenter(whitelist = AllWhitelist).build(ClassSchema("${testName()}_clazz",
mapOf("a" to NonNullableField(Int::class.java))))
val concreteA = clazz.constructors[0].newInstance(testValA)
@ -77,13 +79,13 @@ class DeserializeNeedingCarpentryTests : AmqpCarpenterBase() {
val deserialisedA = DeserializationInput(sf2).deserialize(
TestSerializationOutput(VERBOSE, sf1).serialize(concreteA))
assertEquals (testValA, deserialisedA::class.java.getMethod("getA").invoke(deserialisedA))
assertEquals(testValA, deserialisedA::class.java.getMethod("getA").invoke(deserialisedA))
val deserialisedB = DeserializationInput(sf2).deserialize(
TestSerializationOutput(VERBOSE, sf1).serialize(concreteB))
assertEquals (testValB, deserialisedA::class.java.getMethod("getA").invoke(deserialisedB))
assertEquals (deserialisedA::class.java, deserialisedB::class.java)
assertEquals(testValB, deserialisedA::class.java.getMethod("getA").invoke(deserialisedB))
assertEquals(deserialisedA::class.java, deserialisedB::class.java)
// C is deseriliased with a different factory, meaning a different class carpenter, so the type
// won't already exist and it will be carpented a second time showing that when A and B are the
@ -93,16 +95,16 @@ class DeserializeNeedingCarpentryTests : AmqpCarpenterBase() {
val deserialisedC = DeserializationInput(lfactory).deserialize(
TestSerializationOutput(VERBOSE, lfactory).serialize(concreteC))
assertEquals (testValC, deserialisedC::class.java.getMethod("getA").invoke(deserialisedC))
assertNotEquals (deserialisedA::class.java, deserialisedC::class.java)
assertNotEquals (deserialisedB::class.java, deserialisedC::class.java)
assertEquals(testValC, deserialisedC::class.java.getMethod("getA").invoke(deserialisedC))
assertNotEquals(deserialisedA::class.java, deserialisedC::class.java)
assertNotEquals(deserialisedB::class.java, deserialisedC::class.java)
}
@Test
fun simpleTypeKnownInterface() {
val clazz = ClassCarpenter().build (ClassSchema(
val clazz = ClassCarpenter(whitelist = AllWhitelist).build(ClassSchema(
testName(), mapOf("name" to NonNullableField(String::class.java)),
interfaces = listOf (I::class.java)))
interfaces = listOf(I::class.java)))
val testVal = "Some Person"
val classInstance = clazz.constructors[0].newInstance(testVal)
@ -115,12 +117,13 @@ class DeserializeNeedingCarpentryTests : AmqpCarpenterBase() {
@Test
fun arrayOfTypes() {
val clazz = ClassCarpenter().build(ClassSchema(testName(), mapOf("a" to NonNullableField(Int::class.java))))
val clazz = ClassCarpenter(whitelist = AllWhitelist).build(ClassSchema(testName(),
mapOf("a" to NonNullableField(Int::class.java))))
@CordaSerializable
data class Outer (val a : Array<Any>)
data class Outer(val a: Array<Any>)
val outer = Outer (arrayOf (
val outer = Outer(arrayOf(
clazz.constructors[0].newInstance(1),
clazz.constructors[0].newInstance(2),
clazz.constructors[0].newInstance(3)))
@ -148,7 +151,7 @@ class DeserializeNeedingCarpentryTests : AmqpCarpenterBase() {
@Test
fun reusedClasses() {
val cc = ClassCarpenter()
val cc = ClassCarpenter(whitelist = AllWhitelist)
val innerType = cc.build(ClassSchema("${testName()}.inner", mapOf("a" to NonNullableField(Int::class.java))))
val outerType = cc.build(ClassSchema("${testName()}.outer", mapOf("a" to NonNullableField(innerType))))
@ -157,21 +160,21 @@ class DeserializeNeedingCarpentryTests : AmqpCarpenterBase() {
val serializedI = TestSerializationOutput(VERBOSE, sf1).serialize(inner)
val deserialisedI = DeserializationInput(sf2).deserialize(serializedI)
val serialisedO = TestSerializationOutput(VERBOSE, sf1).serialize(outer)
val serialisedO = TestSerializationOutput(VERBOSE, sf1).serialize(outer)
val deserialisedO = DeserializationInput(sf2).deserialize(serialisedO)
// ensure out carpented version of inner is reused
assertEquals (deserialisedI::class.java,
assertEquals(deserialisedI::class.java,
(deserialisedO::class.java.getMethod("getA").invoke(deserialisedO))::class.java)
}
@Test
fun nestedTypes() {
val cc = ClassCarpenter()
val nestedClass = cc.build (ClassSchema("nestedType",
val cc = ClassCarpenter(whitelist = AllWhitelist)
val nestedClass = cc.build(ClassSchema("nestedType",
mapOf("name" to NonNullableField(String::class.java))))
val outerClass = cc.build (ClassSchema("outerType",
val outerClass = cc.build(ClassSchema("outerType",
mapOf("inner" to NonNullableField(nestedClass))))
val classInstance = outerClass.constructors.first().newInstance(nestedClass.constructors.first().newInstance("name"))
@ -184,33 +187,34 @@ class DeserializeNeedingCarpentryTests : AmqpCarpenterBase() {
@Test
fun repeatedNestedTypes() {
val cc = ClassCarpenter()
val nestedClass = cc.build (ClassSchema("nestedType",
val cc = ClassCarpenter(whitelist = AllWhitelist)
val nestedClass = cc.build(ClassSchema("nestedType",
mapOf("name" to NonNullableField(String::class.java))))
@CordaSerializable
data class outer(val a: Any, val b: Any)
val classInstance = outer (
val classInstance = outer(
nestedClass.constructors.first().newInstance("foo"),
nestedClass.constructors.first().newInstance("bar"))
val serialisedBytes = TestSerializationOutput(VERBOSE, sf1).serialize(classInstance)
val deserializedObj = DeserializationInput(sf2).deserialize(serialisedBytes)
assertEquals ("foo", deserializedObj.a::class.java.getMethod("getName").invoke(deserializedObj.a))
assertEquals ("bar", deserializedObj.b::class.java.getMethod("getName").invoke(deserializedObj.b))
assertEquals("foo", deserializedObj.a::class.java.getMethod("getName").invoke(deserializedObj.a))
assertEquals("bar", deserializedObj.b::class.java.getMethod("getName").invoke(deserializedObj.b))
}
@Test
fun listOfType() {
val unknownClass = ClassCarpenter().build (ClassSchema(testName(), mapOf(
val unknownClass = ClassCarpenter(whitelist = AllWhitelist).build(ClassSchema(testName(), mapOf(
"v1" to NonNullableField(Int::class.java),
"v2" to NonNullableField(Int::class.java))))
@CordaSerializable
data class outer (val l : List<Any>)
val toSerialise = outer (listOf (
data class outer(val l: List<Any>)
val toSerialise = outer(listOf(
unknownClass.constructors.first().newInstance(1, 2),
unknownClass.constructors.first().newInstance(3, 4),
unknownClass.constructors.first().newInstance(5, 6),
@ -227,16 +231,16 @@ class DeserializeNeedingCarpentryTests : AmqpCarpenterBase() {
@Test
fun unknownInterface() {
val cc = ClassCarpenter()
val cc = ClassCarpenter(whitelist = AllWhitelist)
val interfaceClass = cc.build (InterfaceSchema(
val interfaceClass = cc.build(InterfaceSchema(
"gen.Interface",
mapOf("age" to NonNullableField (Int::class.java))))
mapOf("age" to NonNullableField(Int::class.java))))
val concreteClass = cc.build (ClassSchema (testName(), mapOf(
"age" to NonNullableField (Int::class.java),
val concreteClass = cc.build(ClassSchema(testName(), mapOf(
"age" to NonNullableField(Int::class.java),
"name" to NonNullableField(String::class.java)),
interfaces = listOf (I::class.java, interfaceClass)))
interfaces = listOf(I::class.java, interfaceClass)))
val serialisedBytes = TestSerializationOutput(VERBOSE, sf1).serialize(
concreteClass.constructors.first().newInstance(12, "timmy"))

View File

@ -8,9 +8,11 @@ import net.corda.core.contracts.TransactionState
import net.corda.core.crypto.SecureHash
import net.corda.core.flows.FlowException
import net.corda.core.identity.AbstractParty
import net.corda.core.internal.toX509CertHolder
import net.corda.core.serialization.CordaSerializable
import net.corda.core.serialization.SerializationFactory
import net.corda.core.transactions.LedgerTransaction
import net.corda.nodeapi.RPCException
import net.corda.client.rpc.RPCException
import net.corda.nodeapi.internal.serialization.AbstractAMQPSerializationScheme
import net.corda.nodeapi.internal.serialization.AllWhitelist
import net.corda.nodeapi.internal.serialization.EmptyWhitelist
@ -18,6 +20,8 @@ import net.corda.nodeapi.internal.serialization.amqp.SerializerFactory.Companion
import net.corda.testing.BOB_IDENTITY
import net.corda.testing.MEGA_CORP
import net.corda.testing.MEGA_CORP_PUBKEY
import net.corda.testing.withTestSerialization
import org.apache.activemq.artemis.api.core.SimpleString
import org.apache.qpid.proton.amqp.*
import org.apache.qpid.proton.codec.DecoderImpl
import org.apache.qpid.proton.codec.EncoderImpl
@ -25,6 +29,7 @@ import org.junit.Assert.assertNotSame
import org.junit.Assert.assertSame
import org.junit.Ignore
import org.junit.Test
import java.io.ByteArrayInputStream
import java.io.IOException
import java.io.NotSerializableException
import java.math.BigDecimal
@ -178,7 +183,7 @@ class SerializationOutputTests {
assertTrue(Objects.deepEquals(desObj, desObj2) == expectDeserializedEqual)
// TODO: add some schema assertions to check correctly formed.
return desObj2
return desObj
}
@Test
@ -353,7 +358,7 @@ class SerializationOutputTests {
serdes(obj)
}
@Test(expected = NotSerializableException::class)
@Test
fun `test TreeMap`() {
val obj = TreeMap<Int, Foo>()
obj[456] = Foo("Fred", 123)
@ -418,10 +423,18 @@ class SerializationOutputTests {
factory2.register(net.corda.nodeapi.internal.serialization.amqp.custom.ThrowableSerializer(factory2))
val t = IllegalAccessException("message").fillInStackTrace()
val desThrowable = serdes(t, factory, factory2, false) as Throwable
val desThrowable = serdesThrowableWithInternalInfo(t, factory, factory2, false)
assertSerializedThrowableEquivalent(t, desThrowable)
}
private fun serdesThrowableWithInternalInfo(t: Throwable, factory: SerializerFactory, factory2: SerializerFactory, expectedEqual: Boolean = true): Throwable = withTestSerialization {
val newContext = SerializationFactory.defaultFactory.defaultContext.withProperty(CommonPropertyNames.IncludeInternalInfo, true)
val deserializedObj = SerializationFactory.defaultFactory.asCurrent { withCurrentContext(newContext) { serdes(t, factory, factory2, expectedEqual) } }
return deserializedObj
}
@Test
fun `test complex throwables serialize`() {
val factory = SerializerFactory(AllWhitelist, ClassLoader.getSystemClassLoader())
@ -437,7 +450,7 @@ class SerializationOutputTests {
throw IllegalStateException("Layer 2", t)
}
} catch(t: Throwable) {
val desThrowable = serdes(t, factory, factory2, false)
val desThrowable = serdesThrowableWithInternalInfo(t, factory, factory2, false)
assertSerializedThrowableEquivalent(t, desThrowable)
}
}
@ -469,7 +482,7 @@ class SerializationOutputTests {
throw e
}
} catch(t: Throwable) {
val desThrowable = serdes(t, factory, factory2, false)
val desThrowable = serdesThrowableWithInternalInfo(t, factory, factory2, false)
assertSerializedThrowableEquivalent(t, desThrowable)
}
}
@ -483,7 +496,7 @@ class SerializationOutputTests {
factory2.register(net.corda.nodeapi.internal.serialization.amqp.custom.ThrowableSerializer(factory2))
val obj = FlowException("message").fillInStackTrace()
serdes(obj, factory, factory2)
serdesThrowableWithInternalInfo(obj, factory, factory2)
}
@Test
@ -495,7 +508,7 @@ class SerializationOutputTests {
factory2.register(net.corda.nodeapi.internal.serialization.amqp.custom.ThrowableSerializer(factory2))
val obj = RPCException("message").fillInStackTrace()
serdes(obj, factory, factory2)
serdesThrowableWithInternalInfo(obj, factory, factory2)
}
@Test
@ -521,16 +534,14 @@ class SerializationOutputTests {
}
}
val FOO_PROGRAM_ID = "net.corda.nodeapi.internal.serialization.amqp.SerializationOutputTests.FooContract"
class FooState : ContractState {
override val contract: Contract
get() = FooContract
override val participants: List<AbstractParty>
get() = emptyList()
override val participants: List<AbstractParty> = emptyList()
}
@Test
fun `test transaction state`() {
val state = TransactionState(FooState(), MEGA_CORP)
val state = TransactionState(FooState(), FOO_PROGRAM_ID, MEGA_CORP)
val factory = SerializerFactory(AllWhitelist, ClassLoader.getSystemClassLoader())
AbstractAMQPSerializationScheme.registerCustomSerializers(factory)
@ -712,8 +723,18 @@ class SerializationOutputTests {
serdes(obj, factory, factory2)
}
// TODO: ignored due to Proton-J bug https://issues.apache.org/jira/browse/PROTON-1551
@Ignore
@Test
fun `test month serialize`() {
val obj = Month.APRIL
serdes(obj)
}
@Test
fun `test day of week serialize`() {
val obj = DayOfWeek.FRIDAY
serdes(obj)
}
@Test
fun `test certificate holder serialize`() {
val factory = SerializerFactory(AllWhitelist, ClassLoader.getSystemClassLoader())
@ -722,12 +743,10 @@ class SerializationOutputTests {
val factory2 = SerializerFactory(AllWhitelist, ClassLoader.getSystemClassLoader())
factory2.register(net.corda.nodeapi.internal.serialization.amqp.custom.X509CertificateHolderSerializer)
val obj = BOB_IDENTITY.certificate
val obj = BOB_IDENTITY.certificate.toX509CertHolder()
serdes(obj, factory, factory2)
}
// TODO: ignored due to Proton-J bug https://issues.apache.org/jira/browse/PROTON-1551
@Ignore
@Test
fun `test party and certificate serialize`() {
val factory = SerializerFactory(AllWhitelist, ClassLoader.getSystemClassLoader())
@ -798,7 +817,6 @@ class SerializationOutputTests {
data class Bob(val byteArrays: List<ByteArray>)
@Ignore("Causes DeserializedParameterizedType.make() to fail")
@Test
fun `test list of byte arrays`() {
val a = ByteArray(1)
@ -807,7 +825,9 @@ class SerializationOutputTests {
val factory = SerializerFactory(AllWhitelist, ClassLoader.getSystemClassLoader())
val factory2 = SerializerFactory(AllWhitelist, ClassLoader.getSystemClassLoader())
serdes(obj, factory, factory2)
val obj2 = serdes(obj, factory, factory2, false, false)
assertNotSame(obj2.byteArrays[0], obj2.byteArrays[2])
}
data class Vic(val a: List<String>, val b: List<String>)
@ -866,4 +886,88 @@ class SerializationOutputTests {
val objCopy = serdes(obj, factory, factory2, false, false)
assertNotSame(objCopy.a, objCopy.b)
}
@Test
fun `test StringBuffer serialize`() {
val factory = SerializerFactory(AllWhitelist, ClassLoader.getSystemClassLoader())
factory.register(net.corda.nodeapi.internal.serialization.amqp.custom.StringBufferSerializer)
val factory2 = SerializerFactory(AllWhitelist, ClassLoader.getSystemClassLoader())
factory2.register(net.corda.nodeapi.internal.serialization.amqp.custom.StringBufferSerializer)
val obj = StringBuffer("Bob")
val obj2 = serdes(obj, factory, factory2, false, false)
assertEquals(obj.toString(), obj2.toString())
}
@Test
fun `test SimpleString serialize`() {
val factory = SerializerFactory(AllWhitelist, ClassLoader.getSystemClassLoader())
factory.register(net.corda.nodeapi.internal.serialization.amqp.custom.SimpleStringSerializer)
val factory2 = SerializerFactory(AllWhitelist, ClassLoader.getSystemClassLoader())
factory2.register(net.corda.nodeapi.internal.serialization.amqp.custom.SimpleStringSerializer)
val obj = SimpleString("Bob")
serdes(obj, factory, factory2)
}
@Test
fun `test kotlin Unit serialize`() {
val obj = Unit
serdes(obj)
}
@Test
fun `test kotlin Pair serialize`() {
val obj = Pair("a", 3)
serdes(obj)
}
@Test
fun `test InputStream serialize`() {
val factory = SerializerFactory(AllWhitelist, ClassLoader.getSystemClassLoader())
factory.register(net.corda.nodeapi.internal.serialization.amqp.custom.InputStreamSerializer)
val factory2 = SerializerFactory(AllWhitelist, ClassLoader.getSystemClassLoader())
factory2.register(net.corda.nodeapi.internal.serialization.amqp.custom.InputStreamSerializer)
val bytes = ByteArray(10) { it.toByte() }
val obj = ByteArrayInputStream(bytes)
val obj2 = serdes(obj, factory, factory2, expectedEqual = false, expectDeserializedEqual = false)
val obj3 = ByteArrayInputStream(bytes) // Can't use original since the stream pointer has moved.
assertEquals(obj3.available(), obj2.available())
assertEquals(obj3.read(), obj2.read())
}
@Test
fun `test EnumSet serialize`() {
val factory = SerializerFactory(AllWhitelist, ClassLoader.getSystemClassLoader())
factory.register(net.corda.nodeapi.internal.serialization.amqp.custom.EnumSetSerializer(factory))
val factory2 = SerializerFactory(AllWhitelist, ClassLoader.getSystemClassLoader())
factory2.register(net.corda.nodeapi.internal.serialization.amqp.custom.EnumSetSerializer(factory2))
val obj = EnumSet.of(Month.APRIL, Month.AUGUST)
serdes(obj, factory, factory2)
}
@Test
fun `test BitSet serialize`() {
val factory = SerializerFactory(AllWhitelist, ClassLoader.getSystemClassLoader())
factory.register(net.corda.nodeapi.internal.serialization.amqp.custom.BitSetSerializer(factory))
val factory2 = SerializerFactory(AllWhitelist, ClassLoader.getSystemClassLoader())
factory2.register(net.corda.nodeapi.internal.serialization.amqp.custom.BitSetSerializer(factory2))
val obj = BitSet.valueOf(kotlin.ByteArray(16) { it.toByte() }).get(0, 123)
serdes(obj, factory, factory2)
}
@Test
fun `test EnumMap serialize`() {
val obj = EnumMap<Month, Int>(Month::class.java)
obj[Month.APRIL] = Month.APRIL.value
obj[Month.AUGUST] = Month.AUGUST.value
serdes(obj)
}
}

View File

@ -1,5 +1,6 @@
package net.corda.nodeapi.internal.serialization.carpenter
import net.corda.nodeapi.internal.serialization.AllWhitelist
import org.junit.Test
import java.beans.Introspector
import java.lang.reflect.Field
@ -15,12 +16,12 @@ class ClassCarpenterTest {
val b: Int
}
val cc = ClassCarpenter()
private val cc = ClassCarpenter(whitelist = AllWhitelist)
// We have to ignore synthetic fields even though ClassCarpenter doesn't create any because the JaCoCo
// coverage framework auto-magically injects one method and one field into every class loaded into the JVM.
val Class<*>.nonSyntheticFields: List<Field> get() = declaredFields.filterNot { it.isSynthetic }
val Class<*>.nonSyntheticMethods: List<Method> get() = declaredMethods.filterNot { it.isSynthetic }
private val Class<*>.nonSyntheticFields: List<Field> get() = declaredFields.filterNot { it.isSynthetic }
private val Class<*>.nonSyntheticMethods: List<Method> get() = declaredMethods.filterNot { it.isSynthetic }
@Test
fun empty() {
@ -266,7 +267,7 @@ class ClassCarpenterTest {
mapOf("a" to NonNullableField(Int::class.java)))
val clazz = cc.build(schema)
val a : Int? = null
val a: Int? = null
clazz.constructors[0].newInstance(a)
}
@ -288,10 +289,10 @@ class ClassCarpenterTest {
mapOf("a" to NullableField(Integer::class.java)))
val clazz = cc.build(schema)
val a1 : Int? = null
val a1: Int? = null
clazz.constructors[0].newInstance(a1)
val a2 : Int? = 10
val a2: Int? = 10
clazz.constructors[0].newInstance(a2)
}
@ -304,7 +305,7 @@ class ClassCarpenterTest {
val clazz = cc.build(schema)
val a : Int? = 10
val a: Int? = 10
clazz.constructors[0].newInstance(a)
}
@ -317,7 +318,7 @@ class ClassCarpenterTest {
val clazz = cc.build(schema)
val a : Int? = null
val a: Int? = null
clazz.constructors[0].newInstance(a)
}
@ -350,7 +351,7 @@ class ClassCarpenterTest {
val clazz = cc.build(schema)
val a : IntArray? = null
val a: IntArray? = null
clazz.constructors[0].newInstance(a)
}
@ -472,14 +473,14 @@ class ClassCarpenterTest {
val clazz = cc.build(schema)
assertEquals (2, clazz.declaredFields.size)
assertEquals (1, clazz.getDeclaredField("a").annotations.size)
assertEquals(2, clazz.declaredFields.size)
assertEquals(1, clazz.getDeclaredField("a").annotations.size)
assertEquals(Nullable::class.java, clazz.getDeclaredField("a").annotations[0].annotationClass.java)
assertEquals (1, clazz.getDeclaredField("b").annotations.size)
assertEquals(1, clazz.getDeclaredField("b").annotations.size)
assertEquals(Nonnull::class.java, clazz.getDeclaredField("b").annotations[0].annotationClass.java)
assertEquals (1, clazz.getMethod("getA").annotations.size)
assertEquals(1, clazz.getMethod("getA").annotations.size)
assertEquals(Nullable::class.java, clazz.getMethod("getA").annotations[0].annotationClass.java)
assertEquals (1, clazz.getMethod("getB").annotations.size)
assertEquals(1, clazz.getMethod("getB").annotations.size)
assertEquals(Nonnull::class.java, clazz.getMethod("getB").annotations[0].annotationClass.java)
}

View File

@ -1,8 +1,10 @@
package net.corda.nodeapi.internal.serialization.carpenter
import net.corda.core.serialization.ClassWhitelist
import net.corda.nodeapi.internal.serialization.amqp.*
import net.corda.nodeapi.internal.serialization.amqp.Field
import net.corda.nodeapi.internal.serialization.amqp.Schema
import net.corda.nodeapi.internal.serialization.AllWhitelist
fun mangleName(name: String) = "${name}__carpenter"
@ -33,8 +35,9 @@ fun Schema.mangleNames(names: List<String>): Schema {
return Schema(types = newTypes)
}
open class AmqpCarpenterBase {
var factory = testDefaultFactory()
open class AmqpCarpenterBase(whitelist: ClassWhitelist) {
var cc = ClassCarpenter(whitelist = whitelist)
var factory = SerializerFactory(AllWhitelist, cc.classloader)
fun serialise(clazz: Any) = SerializationOutput(factory).serialize(clazz)
fun testName(): String = Thread.currentThread().stackTrace[2].methodName

View File

@ -1,6 +1,7 @@
package net.corda.nodeapi.internal.serialization.carpenter
import net.corda.core.serialization.CordaSerializable
import net.corda.nodeapi.internal.serialization.AllWhitelist
import net.corda.nodeapi.internal.serialization.amqp.CompositeType
import net.corda.nodeapi.internal.serialization.amqp.DeserializationInput
import org.junit.Test
@ -13,7 +14,7 @@ interface I_ {
val a: Int
}
class CompositeMembers : AmqpCarpenterBase() {
class CompositeMembers : AmqpCarpenterBase(AllWhitelist) {
@Test
fun bothKnown() {
val testA = 10
@ -42,7 +43,7 @@ class CompositeMembers : AmqpCarpenterBase() {
var amqpSchemaB: CompositeType? = null
for (type in obj.envelope.schema.types) {
when (type.name.split ("$").last()) {
when (type.name.split("$").last()) {
"A" -> amqpSchemaA = type as CompositeType
"B" -> amqpSchemaB = type as CompositeType
}
@ -88,7 +89,7 @@ class CompositeMembers : AmqpCarpenterBase() {
val b = B(A(testA), testB)
val obj = DeserializationInput(factory).deserializeAndReturnEnvelope(serialise(b))
val amqpSchema = obj.envelope.schema.mangleNames(listOf (classTestName ("A")))
val amqpSchema = obj.envelope.schema.mangleNames(listOf(classTestName("A")))
assert(obj.obj is B)
@ -116,7 +117,7 @@ class CompositeMembers : AmqpCarpenterBase() {
assertEquals(1, carpenterSchema.size)
val metaCarpenter = MetaCarpenter(carpenterSchema)
val metaCarpenter = MetaCarpenter(carpenterSchema, ClassCarpenter(whitelist = AllWhitelist))
metaCarpenter.build()
@ -151,7 +152,7 @@ class CompositeMembers : AmqpCarpenterBase() {
assertEquals(1, carpenterSchema.dependsOn.size)
assert(mangleName(classTestName("A")) in carpenterSchema.dependsOn)
val metaCarpenter = TestMetaCarpenter(carpenterSchema)
val metaCarpenter = TestMetaCarpenter(carpenterSchema, ClassCarpenter(whitelist = AllWhitelist))
assertEquals(0, metaCarpenter.objects.size)
@ -251,7 +252,8 @@ class CompositeMembers : AmqpCarpenterBase() {
assert(obj.obj is C)
val carpenterSchema = obj.envelope.schema.mangleNames(listOf(classTestName("A"), classTestName("B")))
TestMetaCarpenter(carpenterSchema.carpenterSchema(ClassLoader.getSystemClassLoader()))
TestMetaCarpenter(carpenterSchema.carpenterSchema(
ClassLoader.getSystemClassLoader()), ClassCarpenter(whitelist = AllWhitelist))
}
/*

View File

@ -0,0 +1,106 @@
package net.corda.nodeapi.internal.serialization.carpenter
import net.corda.nodeapi.internal.serialization.AllWhitelist
import org.junit.Test
import kotlin.test.assertEquals
import kotlin.test.assertTrue
class EnumClassTests : AmqpCarpenterBase(AllWhitelist) {
@Test
fun oneValue() {
val enumConstants = mapOf("A" to EnumField())
val schema = EnumSchema("gen.enum", enumConstants)
assertTrue(cc.build(schema).isEnum)
}
@Test
fun oneValueInstantiate() {
val enumConstants = mapOf("A" to EnumField())
val schema = EnumSchema("gen.enum", enumConstants)
val clazz = cc.build(schema)
assertTrue(clazz.isEnum)
assertEquals(enumConstants.size, clazz.enumConstants.size)
assertEquals("A", clazz.enumConstants.first().toString())
assertEquals(0, (clazz.enumConstants.first() as Enum<*>).ordinal)
assertEquals("A", (clazz.enumConstants.first() as Enum<*>).name)
}
@Test
fun twoValuesInstantiate() {
val enumConstants = mapOf("left" to EnumField(), "right" to EnumField())
val schema = EnumSchema("gen.enum", enumConstants)
val clazz = cc.build(schema)
assertTrue(clazz.isEnum)
assertEquals(enumConstants.size, clazz.enumConstants.size)
val left = clazz.enumConstants[0] as Enum<*>
val right = clazz.enumConstants[1] as Enum<*>
assertEquals(0, left.ordinal)
assertEquals("left", left.name)
assertEquals(1, right.ordinal)
assertEquals("right", right.name)
}
@Test
fun manyValues() {
val enumConstants = listOf("AAA", "BBB", "CCC", "DDD", "EEE", "FFF",
"GGG", "HHH", "III", "JJJ").associateBy({ it }, { EnumField() })
val schema = EnumSchema("gen.enum", enumConstants)
val clazz = cc.build(schema)
assertTrue(clazz.isEnum)
assertEquals(enumConstants.size, clazz.enumConstants.size)
var idx = 0
enumConstants.forEach {
val constant = clazz.enumConstants[idx] as Enum<*>
assertEquals(idx++, constant.ordinal)
assertEquals(it.key, constant.name)
}
}
@Test
fun assignment() {
val enumConstants = listOf("AAA", "BBB", "CCC", "DDD", "EEE", "FFF").associateBy({ it }, { EnumField() })
val schema = EnumSchema("gen.enum", enumConstants)
val clazz = cc.build(schema)
assertEquals("CCC", clazz.getMethod("valueOf", String::class.java).invoke(null, "CCC").toString())
assertEquals("CCC", (clazz.getMethod("valueOf", String::class.java).invoke(null, "CCC") as Enum<*>).name)
val ddd = clazz.getMethod("valueOf", String::class.java).invoke(null, "DDD") as Enum<*>
assertTrue(ddd::class.java.isEnum)
assertEquals("DDD", ddd.name)
assertEquals(3, ddd.ordinal)
}
// if anything goes wrong with this test it's going to end up throwing *some*
// exception, hence the lack of asserts
@Test
fun assignAndTest() {
val cc2 = ClassCarpenter(whitelist = AllWhitelist)
val schema1 = EnumSchema("gen.enum",
listOf("AAA", "BBB", "CCC", "DDD", "EEE", "FFF").associateBy({ it }, { EnumField() }))
val enumClazz = cc2.build(schema1)
val schema2 = ClassSchema("gen.class",
mapOf(
"a" to NonNullableField(Int::class.java),
"b" to NonNullableField(enumClazz)))
val classClazz = cc2.build(schema2)
// make sure we can construct a class that has an enum we've constructed as a member
classClazz.constructors[0].newInstance(1, enumClazz.getMethod(
"valueOf", String::class.java).invoke(null, "BBB"))
}
}

View File

@ -1,6 +1,7 @@
package net.corda.nodeapi.internal.serialization.carpenter
import net.corda.core.serialization.CordaSerializable
import net.corda.nodeapi.internal.serialization.AllWhitelist
import net.corda.nodeapi.internal.serialization.amqp.DeserializationInput
import org.junit.Test
import kotlin.test.*
@ -32,7 +33,7 @@ interface IIII {
val i: I
}
class InheritanceSchemaToClassCarpenterTests : AmqpCarpenterBase() {
class InheritanceSchemaToClassCarpenterTests : AmqpCarpenterBase(AllWhitelist) {
@Test
fun interfaceParent1() {
class A(override val j: Int) : J
@ -61,7 +62,7 @@ class InheritanceSchemaToClassCarpenterTests : AmqpCarpenterBase() {
assertEquals(1, aSchema.interfaces.size)
assertEquals(J::class.java, aSchema.interfaces[0])
val aBuilder = ClassCarpenter().build(aSchema)
val aBuilder = ClassCarpenter(whitelist = AllWhitelist).build(aSchema)
val objJ = aBuilder.constructors[0].newInstance(testJ)
val j = objJ as J
@ -106,7 +107,7 @@ class InheritanceSchemaToClassCarpenterTests : AmqpCarpenterBase() {
assertEquals(1, aSchema.interfaces.size)
assertEquals(J::class.java, aSchema.interfaces[0])
val aBuilder = ClassCarpenter().build(aSchema)
val aBuilder = ClassCarpenter(whitelist = AllWhitelist).build(aSchema)
val objJ = aBuilder.constructors[0].newInstance(testJ, testJJ)
val j = objJ as J
@ -154,7 +155,7 @@ class InheritanceSchemaToClassCarpenterTests : AmqpCarpenterBase() {
assertTrue(I::class.java in aSchema.interfaces)
assertTrue(II::class.java in aSchema.interfaces)
val aBuilder = ClassCarpenter().build(aSchema)
val aBuilder = ClassCarpenter(whitelist = AllWhitelist).build(aSchema)
val objA = aBuilder.constructors[0].newInstance(testI, testII)
val i = objA as I
val ii = objA as II
@ -200,7 +201,7 @@ class InheritanceSchemaToClassCarpenterTests : AmqpCarpenterBase() {
assertTrue(I::class.java in aSchema.interfaces)
assertTrue(III::class.java in aSchema.interfaces)
val aBuilder = ClassCarpenter().build(aSchema)
val aBuilder = ClassCarpenter(whitelist = AllWhitelist).build(aSchema)
val objA = aBuilder.constructors[0].newInstance(testI, testIII)
val i = objA as I
val iii = objA as III
@ -247,8 +248,8 @@ class InheritanceSchemaToClassCarpenterTests : AmqpCarpenterBase() {
assertNotEquals(null, aCarpenterSchema)
assertNotEquals(null, bCarpenterSchema)
val cc = ClassCarpenter()
val cc2 = ClassCarpenter()
val cc = ClassCarpenter(whitelist = AllWhitelist)
val cc2 = ClassCarpenter(whitelist = AllWhitelist)
val bBuilder = cc.build(bCarpenterSchema!!)
bBuilder.constructors[0].newInstance(a, testIIII)
@ -332,7 +333,7 @@ class InheritanceSchemaToClassCarpenterTests : AmqpCarpenterBase() {
assertEquals(1, carpenterSchema.dependsOn[iName]!!.size)
assertEquals(aName, carpenterSchema.dependsOn[iName]!![0])
val mc = MetaCarpenter(carpenterSchema)
val mc = MetaCarpenter(carpenterSchema, ClassCarpenter(whitelist = AllWhitelist))
mc.build()
assertEquals(0, mc.schemas.carpenterSchemas.size)
@ -385,7 +386,7 @@ class InheritanceSchemaToClassCarpenterTests : AmqpCarpenterBase() {
assertNotNull(carpenterSchema.dependencies[aName]!!.second.find { it == iName })
assertNotNull(carpenterSchema.dependencies[aName]!!.second.find { it == iiName })
val mc = MetaCarpenter(carpenterSchema)
val mc = MetaCarpenter(carpenterSchema, ClassCarpenter(whitelist = AllWhitelist))
mc.build()
assertEquals(0, mc.schemas.carpenterSchemas.size)
@ -445,7 +446,7 @@ class InheritanceSchemaToClassCarpenterTests : AmqpCarpenterBase() {
assertNotNull(carpenterSchema.dependencies[aName]!!.second.find { it == iiiName })
assertNotNull(carpenterSchema.dependencies[aName]!!.second.find { it == iName })
val mc = MetaCarpenter(carpenterSchema)
val mc = MetaCarpenter(carpenterSchema, ClassCarpenter(whitelist = AllWhitelist))
mc.build()
assertEquals(0, mc.schemas.carpenterSchemas.size)

View File

@ -1,13 +1,14 @@
package net.corda.nodeapi.internal.serialization.carpenter
import net.corda.core.serialization.CordaSerializable
import net.corda.nodeapi.internal.serialization.AllWhitelist
import net.corda.nodeapi.internal.serialization.amqp.CompositeType
import net.corda.nodeapi.internal.serialization.amqp.DeserializationInput
import org.junit.Test
import kotlin.test.assertEquals
import kotlin.test.assertNotEquals
class MultiMemberCompositeSchemaToClassCarpenterTests : AmqpCarpenterBase() {
class MultiMemberCompositeSchemaToClassCarpenterTests : AmqpCarpenterBase(AllWhitelist) {
@Test
fun twoInts() {
@ -35,7 +36,7 @@ class MultiMemberCompositeSchemaToClassCarpenterTests : AmqpCarpenterBase() {
assertEquals("b", amqpSchema.fields[1].name)
assertEquals("int", amqpSchema.fields[1].type)
val carpenterSchema = CarpenterSchemas.newInstance()
val carpenterSchema = CarpenterMetaSchema.newInstance()
amqpSchema.carpenterSchema(
classloader = ClassLoader.getSystemClassLoader(),
carpenterSchemas = carpenterSchema,
@ -46,7 +47,7 @@ class MultiMemberCompositeSchemaToClassCarpenterTests : AmqpCarpenterBase() {
assertNotEquals(null, aSchema)
val pinochio = ClassCarpenter().build(aSchema!!)
val pinochio = ClassCarpenter(whitelist = AllWhitelist).build(aSchema!!)
val p = pinochio.constructors[0].newInstance(testA, testB)
assertEquals(pinochio.getMethod("getA").invoke(p), amqpObj.a)
@ -79,7 +80,7 @@ class MultiMemberCompositeSchemaToClassCarpenterTests : AmqpCarpenterBase() {
assertEquals("b", amqpSchema.fields[1].name)
assertEquals("string", amqpSchema.fields[1].type)
val carpenterSchema = CarpenterSchemas.newInstance()
val carpenterSchema = CarpenterMetaSchema.newInstance()
amqpSchema.carpenterSchema(
classloader = ClassLoader.getSystemClassLoader(),
carpenterSchemas = carpenterSchema,
@ -90,7 +91,7 @@ class MultiMemberCompositeSchemaToClassCarpenterTests : AmqpCarpenterBase() {
assertNotEquals(null, aSchema)
val pinochio = ClassCarpenter().build(aSchema!!)
val pinochio = ClassCarpenter(whitelist = AllWhitelist).build(aSchema!!)
val p = pinochio.constructors[0].newInstance(testA, testB)
assertEquals(pinochio.getMethod("getA").invoke(p), amqpObj.a)

View File

@ -1,12 +1,13 @@
package net.corda.nodeapi.internal.serialization.carpenter
import net.corda.core.serialization.CordaSerializable
import net.corda.nodeapi.internal.serialization.AllWhitelist
import net.corda.nodeapi.internal.serialization.amqp.CompositeType
import net.corda.nodeapi.internal.serialization.amqp.DeserializationInput
import org.junit.Test
import kotlin.test.assertEquals
class SingleMemberCompositeSchemaToClassCarpenterTests : AmqpCarpenterBase() {
class SingleMemberCompositeSchemaToClassCarpenterTests : AmqpCarpenterBase(AllWhitelist) {
@Test
fun singleInteger() {
@CordaSerializable
@ -29,14 +30,14 @@ class SingleMemberCompositeSchemaToClassCarpenterTests : AmqpCarpenterBase() {
assertEquals("a", amqpSchema.fields[0].name)
assertEquals("int", amqpSchema.fields[0].type)
val carpenterSchema = CarpenterSchemas.newInstance()
val carpenterSchema = CarpenterMetaSchema.newInstance()
amqpSchema.carpenterSchema(
classloader = ClassLoader.getSystemClassLoader(),
carpenterSchemas = carpenterSchema,
force = true)
val aSchema = carpenterSchema.carpenterSchemas.find { it.name == classTestName("A") }!!
val aBuilder = ClassCarpenter().build(aSchema)
val aBuilder = ClassCarpenter(whitelist = AllWhitelist).build(aSchema)
val p = aBuilder.constructors[0].newInstance(test)
assertEquals(aBuilder.getMethod("getA").invoke(p), amqpObj.a)
@ -60,14 +61,14 @@ class SingleMemberCompositeSchemaToClassCarpenterTests : AmqpCarpenterBase() {
assert(obj.envelope.schema.types[0] is CompositeType)
val amqpSchema = obj.envelope.schema.types[0] as CompositeType
val carpenterSchema = CarpenterSchemas.newInstance()
val carpenterSchema = CarpenterMetaSchema.newInstance()
amqpSchema.carpenterSchema(
classloader = ClassLoader.getSystemClassLoader(),
carpenterSchemas = carpenterSchema,
force = true)
val aSchema = carpenterSchema.carpenterSchemas.find { it.name == classTestName("A") }!!
val aBuilder = ClassCarpenter().build(aSchema)
val aBuilder = ClassCarpenter(whitelist = AllWhitelist).build(aSchema)
val p = aBuilder.constructors[0].newInstance(test)
assertEquals(aBuilder.getMethod("getA").invoke(p), amqpObj.a)
@ -95,14 +96,14 @@ class SingleMemberCompositeSchemaToClassCarpenterTests : AmqpCarpenterBase() {
assertEquals("a", amqpSchema.fields[0].name)
assertEquals("long", amqpSchema.fields[0].type)
val carpenterSchema = CarpenterSchemas.newInstance()
val carpenterSchema = CarpenterMetaSchema.newInstance()
amqpSchema.carpenterSchema(
classloader = ClassLoader.getSystemClassLoader(),
carpenterSchemas = carpenterSchema,
force = true)
val aSchema = carpenterSchema.carpenterSchemas.find { it.name == classTestName("A") }!!
val aBuilder = ClassCarpenter().build(aSchema)
val aBuilder = ClassCarpenter(whitelist = AllWhitelist).build(aSchema)
val p = aBuilder.constructors[0].newInstance(test)
assertEquals(aBuilder.getMethod("getA").invoke(p), amqpObj.a)
@ -130,14 +131,14 @@ class SingleMemberCompositeSchemaToClassCarpenterTests : AmqpCarpenterBase() {
assertEquals("a", amqpSchema.fields[0].name)
assertEquals("short", amqpSchema.fields[0].type)
val carpenterSchema = CarpenterSchemas.newInstance()
val carpenterSchema = CarpenterMetaSchema.newInstance()
amqpSchema.carpenterSchema(
classloader = ClassLoader.getSystemClassLoader(),
carpenterSchemas = carpenterSchema,
force = true)
val aSchema = carpenterSchema.carpenterSchemas.find { it.name == classTestName("A") }!!
val aBuilder = ClassCarpenter().build(aSchema)
val aBuilder = ClassCarpenter(whitelist = AllWhitelist).build(aSchema)
val p = aBuilder.constructors[0].newInstance(test)
assertEquals(aBuilder.getMethod("getA").invoke(p), amqpObj.a)
@ -165,14 +166,14 @@ class SingleMemberCompositeSchemaToClassCarpenterTests : AmqpCarpenterBase() {
assertEquals("a", amqpSchema.fields[0].name)
assertEquals("double", amqpSchema.fields[0].type)
val carpenterSchema = CarpenterSchemas.newInstance()
val carpenterSchema = CarpenterMetaSchema.newInstance()
amqpSchema.carpenterSchema(
classloader = ClassLoader.getSystemClassLoader(),
carpenterSchemas = carpenterSchema,
force = true)
val aSchema = carpenterSchema.carpenterSchemas.find { it.name == classTestName("A") }!!
val aBuilder = ClassCarpenter().build(aSchema)
val aBuilder = ClassCarpenter(whitelist = AllWhitelist).build(aSchema)
val p = aBuilder.constructors[0].newInstance(test)
assertEquals(aBuilder.getMethod("getA").invoke(p), amqpObj.a)
@ -183,7 +184,7 @@ class SingleMemberCompositeSchemaToClassCarpenterTests : AmqpCarpenterBase() {
@CordaSerializable
data class A(val a: Float)
val test: Float = 10.0F
val test = 10.0F
val a = A(test)
val obj = DeserializationInput(factory).deserializeAndReturnEnvelope(serialise(a))
@ -200,14 +201,14 @@ class SingleMemberCompositeSchemaToClassCarpenterTests : AmqpCarpenterBase() {
assertEquals("a", amqpSchema.fields[0].name)
assertEquals("float", amqpSchema.fields[0].type)
val carpenterSchema = CarpenterSchemas.newInstance()
val carpenterSchema = CarpenterMetaSchema.newInstance()
amqpSchema.carpenterSchema(
classloader = ClassLoader.getSystemClassLoader(),
carpenterSchemas = carpenterSchema,
force = true)
val aSchema = carpenterSchema.carpenterSchemas.find { it.name == classTestName("A") }!!
val aBuilder = ClassCarpenter().build(aSchema)
val aBuilder = ClassCarpenter(whitelist = AllWhitelist).build(aSchema)
val p = aBuilder.constructors[0].newInstance(test)
assertEquals(aBuilder.getMethod("getA").invoke(p), amqpObj.a)