Merge branch 'master' into minorTestChanges

This commit is contained in:
Viktor Kolomeyko 2017-08-14 17:26:04 +01:00 committed by GitHub
commit 3f9270e38d
48 changed files with 701 additions and 427 deletions

View File

@ -71,7 +71,7 @@ class CordaRPCClient(
fun initialiseSerialization() { fun initialiseSerialization() {
try { try {
SerializationDefaults.SERIALIZATION_FACTORY = SerializationFactoryImpl().apply { SerializationDefaults.SERIALIZATION_FACTORY = SerializationFactoryImpl().apply {
registerScheme(KryoClientSerializationScheme()) registerScheme(KryoClientSerializationScheme(this))
registerScheme(AMQPClientSerializationScheme()) registerScheme(AMQPClientSerializationScheme())
} }
SerializationDefaults.P2P_CONTEXT = KRYO_P2P_CONTEXT SerializationDefaults.P2P_CONTEXT = KRYO_P2P_CONTEXT

View File

@ -3,20 +3,21 @@ package net.corda.client.rpc.serialization
import com.esotericsoftware.kryo.pool.KryoPool import com.esotericsoftware.kryo.pool.KryoPool
import net.corda.client.rpc.internal.RpcClientObservableSerializer import net.corda.client.rpc.internal.RpcClientObservableSerializer
import net.corda.core.serialization.SerializationContext import net.corda.core.serialization.SerializationContext
import net.corda.core.serialization.SerializationFactory
import net.corda.core.utilities.ByteSequence import net.corda.core.utilities.ByteSequence
import net.corda.nodeapi.RPCKryo import net.corda.nodeapi.RPCKryo
import net.corda.nodeapi.internal.serialization.AbstractKryoSerializationScheme import net.corda.nodeapi.internal.serialization.AbstractKryoSerializationScheme
import net.corda.nodeapi.internal.serialization.DefaultKryoCustomizer import net.corda.nodeapi.internal.serialization.DefaultKryoCustomizer
import net.corda.nodeapi.internal.serialization.KryoHeaderV0_1 import net.corda.nodeapi.internal.serialization.KryoHeaderV0_1
class KryoClientSerializationScheme : AbstractKryoSerializationScheme() { class KryoClientSerializationScheme(serializationFactory: SerializationFactory) : AbstractKryoSerializationScheme(serializationFactory) {
override fun canDeserializeVersion(byteSequence: ByteSequence, target: SerializationContext.UseCase): Boolean { override fun canDeserializeVersion(byteSequence: ByteSequence, target: SerializationContext.UseCase): Boolean {
return byteSequence == KryoHeaderV0_1 && (target == SerializationContext.UseCase.RPCClient || target == SerializationContext.UseCase.P2P) return byteSequence == KryoHeaderV0_1 && (target == SerializationContext.UseCase.RPCClient || target == SerializationContext.UseCase.P2P)
} }
override fun rpcClientKryoPool(context: SerializationContext): KryoPool { override fun rpcClientKryoPool(context: SerializationContext): KryoPool {
return KryoPool.Builder { return KryoPool.Builder {
DefaultKryoCustomizer.customize(RPCKryo(RpcClientObservableSerializer, context.whitelist)).apply { classLoader = context.deserializationClassLoader } DefaultKryoCustomizer.customize(RPCKryo(RpcClientObservableSerializer, serializationFactory, context)).apply { classLoader = context.deserializationClassLoader }
}.build() }.build()
} }

View File

@ -11,18 +11,22 @@ import javax.persistence.Converter
* Completely anonymous parties are stored as null (to preserve privacy) * Completely anonymous parties are stored as null (to preserve privacy)
*/ */
@Converter(autoApply = true) @Converter(autoApply = true)
class AbstractPartyToX500NameAsStringConverter(val identitySvc: IdentityService) : AttributeConverter<AbstractParty, String> { class AbstractPartyToX500NameAsStringConverter(identitySvc: () -> IdentityService) : AttributeConverter<AbstractParty, String> {
private val identityService: IdentityService by lazy {
identitySvc()
}
override fun convertToDatabaseColumn(party: AbstractParty?): String? { override fun convertToDatabaseColumn(party: AbstractParty?): String? {
party?.let { party?.let {
return identitySvc.partyFromAnonymous(party)?.toString() return identityService.partyFromAnonymous(party)?.toString()
} }
return null // non resolvable anonymous parties return null // non resolvable anonymous parties
} }
override fun convertToEntityAttribute(dbData: String?): AbstractParty? { override fun convertToEntityAttribute(dbData: String?): AbstractParty? {
dbData?.let { dbData?.let {
val party = identitySvc.partyFromX500Name(X500Name(dbData)) val party = identityService.partyFromX500Name(X500Name(dbData))
return party as AbstractParty return party as AbstractParty
} }
return null // non resolvable anonymous parties are stored as nulls return null // non resolvable anonymous parties are stored as nulls

View File

@ -81,6 +81,11 @@ interface SerializationContext {
*/ */
fun withWhitelisted(clazz: Class<*>): SerializationContext fun withWhitelisted(clazz: Class<*>): SerializationContext
/**
* Helper method to return a new context based on this context but with serialization using the format this header sequence represents.
*/
fun withPreferredSerializationVersion(versionHeader: ByteSequence): SerializationContext
/** /**
* The use case that we are serializing for, since it influences the implementations chosen. * The use case that we are serializing for, since it influences the implementations chosen.
*/ */

View File

@ -8,6 +8,8 @@ import net.corda.core.concurrent.CordaFuture
import net.corda.core.CordaRuntimeException import net.corda.core.CordaRuntimeException
import net.corda.core.serialization.ClassWhitelist import net.corda.core.serialization.ClassWhitelist
import net.corda.core.serialization.CordaSerializable import net.corda.core.serialization.CordaSerializable
import net.corda.core.serialization.SerializationContext
import net.corda.core.serialization.SerializationFactory
import net.corda.core.toFuture import net.corda.core.toFuture
import net.corda.core.toObservable import net.corda.core.toObservable
import net.corda.nodeapi.config.OldConfig import net.corda.nodeapi.config.OldConfig
@ -47,7 +49,7 @@ class PermissionException(msg: String) : RuntimeException(msg)
// The Kryo used for the RPC wire protocol. Every type in the wire protocol is listed here explicitly. // The Kryo used for the RPC wire protocol. Every type in the wire protocol is listed here explicitly.
// This is annoying to write out, but will make it easier to formalise the wire protocol when the time comes, // This is annoying to write out, but will make it easier to formalise the wire protocol when the time comes,
// because we can see everything we're using in one place. // because we can see everything we're using in one place.
class RPCKryo(observableSerializer: Serializer<Observable<*>>, whitelist: ClassWhitelist) : CordaKryo(CordaClassResolver(whitelist)) { class RPCKryo(observableSerializer: Serializer<Observable<*>>, val serializationFactory: SerializationFactory, val serializationContext: SerializationContext) : CordaKryo(CordaClassResolver(serializationFactory, serializationContext)) {
init { init {
DefaultKryoCustomizer.customize(this) DefaultKryoCustomizer.customize(this)

View File

@ -11,9 +11,22 @@ import net.corda.nodeapi.internal.serialization.amqp.SerializationOutput
import net.corda.nodeapi.internal.serialization.amqp.SerializerFactory import net.corda.nodeapi.internal.serialization.amqp.SerializerFactory
import java.util.concurrent.ConcurrentHashMap import java.util.concurrent.ConcurrentHashMap
private const val AMQP_ENABLED = false internal val AMQP_ENABLED get() = SerializationDefaults.P2P_CONTEXT.preferedSerializationVersion == AmqpHeaderV1_0
abstract class AbstractAMQPSerializationScheme : SerializationScheme { abstract class AbstractAMQPSerializationScheme : SerializationScheme {
internal companion object {
fun registerCustomSerializers(factory: SerializerFactory) {
factory.apply {
register(net.corda.nodeapi.internal.serialization.amqp.custom.PublicKeySerializer)
register(net.corda.nodeapi.internal.serialization.amqp.custom.ThrowableSerializer(this))
register(net.corda.nodeapi.internal.serialization.amqp.custom.X500NameSerializer)
register(net.corda.nodeapi.internal.serialization.amqp.custom.BigDecimalSerializer)
register(net.corda.nodeapi.internal.serialization.amqp.custom.CurrencySerializer)
register(net.corda.nodeapi.internal.serialization.amqp.custom.InstantSerializer(this))
}
}
}
private val serializerFactoriesForContexts = ConcurrentHashMap<Pair<ClassWhitelist, ClassLoader>, SerializerFactory>() private val serializerFactoriesForContexts = ConcurrentHashMap<Pair<ClassWhitelist, ClassLoader>, SerializerFactory>()
protected abstract fun rpcClientSerializerFactory(context: SerializationContext): SerializerFactory protected abstract fun rpcClientSerializerFactory(context: SerializationContext): SerializerFactory
@ -30,7 +43,7 @@ abstract class AbstractAMQPSerializationScheme : SerializationScheme {
rpcServerSerializerFactory(context) rpcServerSerializerFactory(context)
else -> SerializerFactory(context.whitelist) // TODO pass class loader also else -> SerializerFactory(context.whitelist) // TODO pass class loader also
} }
} }.also { registerCustomSerializers(it) }
} }
override fun <T : Any> deserialize(byteSequence: ByteSequence, clazz: Class<T>, context: SerializationContext): T { override fun <T : Any> deserialize(byteSequence: ByteSequence, clazz: Class<T>, context: SerializationContext): T {

View File

@ -6,10 +6,9 @@ import com.esotericsoftware.kryo.io.Output
import com.esotericsoftware.kryo.serializers.FieldSerializer import com.esotericsoftware.kryo.serializers.FieldSerializer
import com.esotericsoftware.kryo.util.DefaultClassResolver import com.esotericsoftware.kryo.util.DefaultClassResolver
import com.esotericsoftware.kryo.util.Util import com.esotericsoftware.kryo.util.Util
import net.corda.core.serialization.AttachmentsClassLoader import net.corda.core.serialization.*
import net.corda.core.serialization.ClassWhitelist
import net.corda.core.serialization.CordaSerializable
import net.corda.core.utilities.loggerFor import net.corda.core.utilities.loggerFor
import net.corda.nodeapi.internal.serialization.amqp.AmqpHeaderV1_0
import java.io.PrintWriter import java.io.PrintWriter
import java.lang.reflect.Modifier.isAbstract import java.lang.reflect.Modifier.isAbstract
import java.nio.charset.StandardCharsets import java.nio.charset.StandardCharsets
@ -22,23 +21,13 @@ fun Kryo.addToWhitelist(type: Class<*>) {
((classResolver as? CordaClassResolver)?.whitelist as? MutableClassWhitelist)?.add(type) ((classResolver as? CordaClassResolver)?.whitelist as? MutableClassWhitelist)?.add(type)
} }
fun makeStandardClassResolver(): ClassResolver {
return CordaClassResolver(GlobalTransientClassWhiteList(BuiltInExceptionsWhitelist()))
}
fun makeNoWhitelistClassResolver(): ClassResolver {
return CordaClassResolver(AllWhitelist)
}
fun makeAllButBlacklistedClassResolver(): ClassResolver {
return CordaClassResolver(AllButBlacklisted)
}
/** /**
* @param amqpEnabled Setting this to true turns on experimental AMQP serialization for any class annotated with * @param amqpEnabled Setting this to true turns on experimental AMQP serialization for any class annotated with
* [CordaSerializable]. * [CordaSerializable].
*/ */
class CordaClassResolver(val whitelist: ClassWhitelist, val amqpEnabled: Boolean = false) : DefaultClassResolver() { class CordaClassResolver(val serializationFactory: SerializationFactory, val serializationContext: SerializationContext) : DefaultClassResolver() {
val whitelist: ClassWhitelist = TransientClassWhiteList(serializationContext.whitelist)
/** Returns the registration for the specified class, or null if the class is not registered. */ /** Returns the registration for the specified class, or null if the class is not registered. */
override fun getRegistration(type: Class<*>): Registration? { override fun getRegistration(type: Class<*>): Registration? {
return super.getRegistration(type) ?: checkClass(type) return super.getRegistration(type) ?: checkClass(type)
@ -78,9 +67,9 @@ class CordaClassResolver(val whitelist: ClassWhitelist, val amqpEnabled: Boolean
// If something is not annotated, or AMQP is disabled, we stay serializing with Kryo. This will typically be the // If something is not annotated, or AMQP is disabled, we stay serializing with Kryo. This will typically be the
// case for flow checkpoints (ignoring all cases where AMQP is disabled) since our top level messaging data structures // case for flow checkpoints (ignoring all cases where AMQP is disabled) since our top level messaging data structures
// are annotated and once we enter AMQP serialisation we stay with it for the entire object subgraph. // are annotated and once we enter AMQP serialisation we stay with it for the entire object subgraph.
if (checkForAnnotation(type) && amqpEnabled) { if (checkForAnnotation(type) && AMQP_ENABLED) {
// Build AMQP serializer // Build AMQP serializer
return register(Registration(type, KryoAMQPSerializer, NAME.toInt())) return register(Registration(type, KryoAMQPSerializer(serializationFactory, serializationContext), NAME.toInt()))
} }
val objectInstance = try { val objectInstance = try {
@ -179,6 +168,21 @@ class GlobalTransientClassWhiteList(val delegate: ClassWhitelist) : MutableClass
} }
} }
/**
* A whitelist that can be customised via the [CordaPluginRegistry], since implements [MutableClassWhitelist].
*/
class TransientClassWhiteList(val delegate: ClassWhitelist) : MutableClassWhitelist, ClassWhitelist by delegate {
val whitelist: MutableSet<String> = Collections.synchronizedSet(mutableSetOf())
override fun hasListed(type: Class<*>): Boolean {
return (type.name in whitelist) || delegate.hasListed(type)
}
override fun add(entry: Class<*>) {
whitelist += entry.name
}
}
/** /**
* This class is not currently used, but can be installed to log a large number of missing entries from the whitelist * This class is not currently used, but can be installed to log a large number of missing entries from the whitelist

View File

@ -46,10 +46,7 @@ import kotlin.collections.ArrayList
object DefaultKryoCustomizer { object DefaultKryoCustomizer {
private val pluginRegistries: List<CordaPluginRegistry> by lazy { private val pluginRegistries: List<CordaPluginRegistry> by lazy {
// No ClassResolver only constructor. MapReferenceResolver is the default as used by Kryo in other constructors. ServiceLoader.load(CordaPluginRegistry::class.java, this.javaClass.classLoader).toList()
val unusedKryo = Kryo(makeStandardClassResolver(), MapReferenceResolver())
val customization = KryoSerializationCustomization(unusedKryo)
ServiceLoader.load(CordaPluginRegistry::class.java, this.javaClass.classLoader).toList().filter { it.customizeSerialization(customization) }
} }
fun customize(kryo: Kryo): Kryo { fun customize(kryo: Kryo): Kryo {

View File

@ -4,7 +4,11 @@ import com.esotericsoftware.kryo.Kryo
import com.esotericsoftware.kryo.Serializer import com.esotericsoftware.kryo.Serializer
import com.esotericsoftware.kryo.io.Input import com.esotericsoftware.kryo.io.Input
import com.esotericsoftware.kryo.io.Output import com.esotericsoftware.kryo.io.Output
import net.corda.core.serialization.SerializationContext
import net.corda.core.serialization.SerializationFactory
import net.corda.core.serialization.SerializedBytes import net.corda.core.serialization.SerializedBytes
import net.corda.core.utilities.sequence
import net.corda.nodeapi.internal.serialization.amqp.AmqpHeaderV1_0
import net.corda.nodeapi.internal.serialization.amqp.DeserializationInput import net.corda.nodeapi.internal.serialization.amqp.DeserializationInput
import net.corda.nodeapi.internal.serialization.amqp.SerializationOutput import net.corda.nodeapi.internal.serialization.amqp.SerializationOutput
import net.corda.nodeapi.internal.serialization.amqp.SerializerFactory import net.corda.nodeapi.internal.serialization.amqp.SerializerFactory
@ -15,38 +19,19 @@ import net.corda.nodeapi.internal.serialization.amqp.SerializerFactory
* *
* There is no need to write out the length, since this can be peeked out of the first few bytes of the stream. * There is no need to write out the length, since this can be peeked out of the first few bytes of the stream.
*/ */
object KryoAMQPSerializer : Serializer<Any>() { class KryoAMQPSerializer(val serializationFactory: SerializationFactory, val serializationContext: SerializationContext) : Serializer<Any>() {
internal fun registerCustomSerializers(factory: SerializerFactory) {
factory.apply {
register(net.corda.nodeapi.internal.serialization.amqp.custom.PublicKeySerializer)
register(net.corda.nodeapi.internal.serialization.amqp.custom.ThrowableSerializer(this))
register(net.corda.nodeapi.internal.serialization.amqp.custom.X500NameSerializer)
register(net.corda.nodeapi.internal.serialization.amqp.custom.BigDecimalSerializer)
register(net.corda.nodeapi.internal.serialization.amqp.custom.CurrencySerializer)
register(net.corda.nodeapi.internal.serialization.amqp.custom.InstantSerializer(this))
}
}
// TODO: need to sort out the whitelist... we currently do not apply the whitelist attached to the [Kryo]
// instance to the factory. We need to do this before turning on AMQP serialization.
private val serializerFactory = SerializerFactory().apply {
registerCustomSerializers(this)
}
override fun write(kryo: Kryo, output: Output, obj: Any) { override fun write(kryo: Kryo, output: Output, obj: Any) {
val amqpOutput = SerializationOutput(serializerFactory) val bytes = serializationFactory.serialize(obj, serializationContext.withPreferredSerializationVersion(AmqpHeaderV1_0)).bytes
val bytes = amqpOutput.serialize(obj).bytes
// No need to write out the size since it's encoded within the AMQP. // No need to write out the size since it's encoded within the AMQP.
output.write(bytes) output.write(bytes)
} }
override fun read(kryo: Kryo, input: Input, type: Class<Any>): Any { override fun read(kryo: Kryo, input: Input, type: Class<Any>): Any {
val amqpInput = DeserializationInput(serializerFactory)
// Use our helper functions to peek the size of the serialized object out of the AMQP byte stream. // Use our helper functions to peek the size of the serialized object out of the AMQP byte stream.
val peekedBytes = input.readBytes(DeserializationInput.BYTES_NEEDED_TO_PEEK) val peekedBytes = input.readBytes(DeserializationInput.BYTES_NEEDED_TO_PEEK)
val size = DeserializationInput.peekSize(peekedBytes) val size = DeserializationInput.peekSize(peekedBytes)
val allBytes = peekedBytes.copyOf(size) val allBytes = peekedBytes.copyOf(size)
input.readBytes(allBytes, peekedBytes.size, size - peekedBytes.size) input.readBytes(allBytes, peekedBytes.size, size - peekedBytes.size)
return amqpInput.deserialize(SerializedBytes<Any>(allBytes), type) return serializationFactory.deserialize(allBytes.sequence(), type, serializationContext)
} }
} }

View File

@ -34,7 +34,6 @@ data class SerializationContextImpl(override val preferedSerializationVersion: B
override val properties: Map<Any, Any>, override val properties: Map<Any, Any>,
override val objectReferencesEnabled: Boolean, override val objectReferencesEnabled: Boolean,
override val useCase: SerializationContext.UseCase) : SerializationContext { override val useCase: SerializationContext.UseCase) : SerializationContext {
override fun withProperty(property: Any, value: Any): SerializationContext { override fun withProperty(property: Any, value: Any): SerializationContext {
return copy(properties = properties + (property to value)) return copy(properties = properties + (property to value))
} }
@ -52,6 +51,8 @@ data class SerializationContextImpl(override val preferedSerializationVersion: B
override fun hasListed(type: Class<*>): Boolean = whitelist.hasListed(type) || type.name == clazz.name override fun hasListed(type: Class<*>): Boolean = whitelist.hasListed(type) || type.name == clazz.name
}) })
} }
override fun withPreferredSerializationVersion(versionHeader: ByteSequence) = copy(preferedSerializationVersion = versionHeader)
} }
private const val HEADER_SIZE: Int = 8 private const val HEADER_SIZE: Int = 8
@ -118,7 +119,7 @@ private object AutoCloseableSerialisationDetector : Serializer<AutoCloseable>()
override fun read(kryo: Kryo, input: Input, type: Class<AutoCloseable>) = throw IllegalStateException("Should not reach here!") override fun read(kryo: Kryo, input: Input, type: Class<AutoCloseable>) = throw IllegalStateException("Should not reach here!")
} }
abstract class AbstractKryoSerializationScheme : SerializationScheme { abstract class AbstractKryoSerializationScheme(val serializationFactory: SerializationFactory) : SerializationScheme {
private val kryoPoolsForContexts = ConcurrentHashMap<Pair<ClassWhitelist, ClassLoader>, KryoPool>() private val kryoPoolsForContexts = ConcurrentHashMap<Pair<ClassWhitelist, ClassLoader>, KryoPool>()
protected abstract fun rpcClientKryoPool(context: SerializationContext): KryoPool protected abstract fun rpcClientKryoPool(context: SerializationContext): KryoPool
@ -130,7 +131,7 @@ abstract class AbstractKryoSerializationScheme : SerializationScheme {
SerializationContext.UseCase.Checkpoint -> SerializationContext.UseCase.Checkpoint ->
KryoPool.Builder { KryoPool.Builder {
val serializer = Fiber.getFiberSerializer(false) as KryoSerializer val serializer = Fiber.getFiberSerializer(false) as KryoSerializer
val classResolver = makeNoWhitelistClassResolver().apply { setKryo(serializer.kryo) } val classResolver = CordaClassResolver(serializationFactory, context).apply { setKryo(serializer.kryo) }
// TODO The ClassResolver can only be set in the Kryo constructor and Quasar doesn't provide us with a way of doing that // TODO The ClassResolver can only be set in the Kryo constructor and Quasar doesn't provide us with a way of doing that
val field = Kryo::class.java.getDeclaredField("classResolver").apply { isAccessible = true } val field = Kryo::class.java.getDeclaredField("classResolver").apply { isAccessible = true }
serializer.kryo.apply { serializer.kryo.apply {
@ -146,7 +147,7 @@ abstract class AbstractKryoSerializationScheme : SerializationScheme {
rpcServerKryoPool(context) rpcServerKryoPool(context)
else -> else ->
KryoPool.Builder { KryoPool.Builder {
DefaultKryoCustomizer.customize(CordaKryo(CordaClassResolver(context.whitelist))).apply { classLoader = it.second } DefaultKryoCustomizer.customize(CordaKryo(CordaClassResolver(serializationFactory, context))).apply { classLoader = it.second }
}.build() }.build()
} }
} }

View File

@ -5,8 +5,8 @@ import com.esotericsoftware.kryo.io.Input
import com.esotericsoftware.kryo.io.Output import com.esotericsoftware.kryo.io.Output
import com.esotericsoftware.kryo.util.MapReferenceResolver import com.esotericsoftware.kryo.util.MapReferenceResolver
import net.corda.core.node.services.AttachmentStorage import net.corda.core.node.services.AttachmentStorage
import net.corda.core.serialization.AttachmentsClassLoader import net.corda.core.serialization.*
import net.corda.core.serialization.CordaSerializable import net.corda.core.utilities.ByteSequence
import net.corda.nodeapi.AttachmentClassLoaderTests import net.corda.nodeapi.AttachmentClassLoaderTests
import net.corda.testing.node.MockAttachmentStorage import net.corda.testing.node.MockAttachmentStorage
import org.junit.Rule import org.junit.Rule
@ -76,71 +76,84 @@ class DefaultSerializableSerializer : Serializer<DefaultSerializable>() {
} }
class CordaClassResolverTests { class CordaClassResolverTests {
val factory: SerializationFactory = object : SerializationFactory {
override fun <T : Any> deserialize(byteSequence: ByteSequence, clazz: Class<T>, context: SerializationContext): T {
TODO("not implemented") //To change body of created functions use File | Settings | File Templates.
}
override fun <T : Any> serialize(obj: T, context: SerializationContext): SerializedBytes<T> {
TODO("not implemented") //To change body of created functions use File | Settings | File Templates.
}
}
val emptyWhitelistContext: SerializationContext = SerializationContextImpl(KryoHeaderV0_1, this.javaClass.classLoader, EmptyWhitelist, emptyMap(), true, SerializationContext.UseCase.P2P)
val allButBlacklistedContext: SerializationContext = SerializationContextImpl(KryoHeaderV0_1, this.javaClass.classLoader, AllButBlacklisted, emptyMap(), true, SerializationContext.UseCase.P2P)
@Test @Test
fun `Annotation on enum works for specialised entries`() { fun `Annotation on enum works for specialised entries`() {
// TODO: Remove this suppress when we upgrade to kotlin 1.1 or when JetBrain fixes the bug. // TODO: Remove this suppress when we upgrade to kotlin 1.1 or when JetBrain fixes the bug.
@Suppress("UNSUPPORTED_FEATURE") @Suppress("UNSUPPORTED_FEATURE")
CordaClassResolver(EmptyWhitelist).getRegistration(Foo.Bar::class.java) CordaClassResolver(factory, emptyWhitelistContext).getRegistration(Foo.Bar::class.java)
} }
@Test @Test
fun `Annotation on array element works`() { fun `Annotation on array element works`() {
val values = arrayOf(Element()) val values = arrayOf(Element())
CordaClassResolver(EmptyWhitelist).getRegistration(values.javaClass) CordaClassResolver(factory, emptyWhitelistContext).getRegistration(values.javaClass)
} }
@Test @Test
fun `Annotation not needed on abstract class`() { fun `Annotation not needed on abstract class`() {
CordaClassResolver(EmptyWhitelist).getRegistration(AbstractClass::class.java) CordaClassResolver(factory, emptyWhitelistContext).getRegistration(AbstractClass::class.java)
} }
@Test @Test
fun `Annotation not needed on interface`() { fun `Annotation not needed on interface`() {
CordaClassResolver(EmptyWhitelist).getRegistration(Interface::class.java) CordaClassResolver(factory, emptyWhitelistContext).getRegistration(Interface::class.java)
} }
@Test @Test
fun `Calling register method on modified Kryo does not consult the whitelist`() { fun `Calling register method on modified Kryo does not consult the whitelist`() {
val kryo = CordaKryo(CordaClassResolver(EmptyWhitelist)) val kryo = CordaKryo(CordaClassResolver(factory, emptyWhitelistContext))
kryo.register(NotSerializable::class.java) kryo.register(NotSerializable::class.java)
} }
@Test(expected = KryoException::class) @Test(expected = KryoException::class)
fun `Calling register method on unmodified Kryo does consult the whitelist`() { fun `Calling register method on unmodified Kryo does consult the whitelist`() {
val kryo = Kryo(CordaClassResolver(EmptyWhitelist), MapReferenceResolver()) val kryo = Kryo(CordaClassResolver(factory, emptyWhitelistContext), MapReferenceResolver())
kryo.register(NotSerializable::class.java) kryo.register(NotSerializable::class.java)
} }
@Test(expected = KryoException::class) @Test(expected = KryoException::class)
fun `Annotation is needed without whitelisting`() { fun `Annotation is needed without whitelisting`() {
CordaClassResolver(EmptyWhitelist).getRegistration(NotSerializable::class.java) CordaClassResolver(factory, emptyWhitelistContext).getRegistration(NotSerializable::class.java)
} }
@Test @Test
fun `Annotation is not needed with whitelisting`() { fun `Annotation is not needed with whitelisting`() {
val resolver = CordaClassResolver(GlobalTransientClassWhiteList(EmptyWhitelist)) val resolver = CordaClassResolver(factory, emptyWhitelistContext.withWhitelisted(NotSerializable::class.java))
(resolver.whitelist as MutableClassWhitelist).add(NotSerializable::class.java)
resolver.getRegistration(NotSerializable::class.java) resolver.getRegistration(NotSerializable::class.java)
} }
@Test @Test
fun `Annotation not needed on Object`() { fun `Annotation not needed on Object`() {
CordaClassResolver(EmptyWhitelist).getRegistration(Object::class.java) CordaClassResolver(factory, emptyWhitelistContext).getRegistration(Object::class.java)
} }
@Test @Test
fun `Annotation not needed on primitive`() { fun `Annotation not needed on primitive`() {
CordaClassResolver(EmptyWhitelist).getRegistration(Integer.TYPE) CordaClassResolver(factory, emptyWhitelistContext).getRegistration(Integer.TYPE)
} }
@Test(expected = KryoException::class) @Test(expected = KryoException::class)
fun `Annotation does not work for custom serializable`() { fun `Annotation does not work for custom serializable`() {
CordaClassResolver(EmptyWhitelist).getRegistration(CustomSerializable::class.java) CordaClassResolver(factory, emptyWhitelistContext).getRegistration(CustomSerializable::class.java)
} }
@Test(expected = KryoException::class) @Test(expected = KryoException::class)
fun `Annotation does not work in conjunction with Kryo annotation`() { fun `Annotation does not work in conjunction with Kryo annotation`() {
CordaClassResolver(EmptyWhitelist).getRegistration(DefaultSerializable::class.java) CordaClassResolver(factory, emptyWhitelistContext).getRegistration(DefaultSerializable::class.java)
} }
private fun importJar(storage: AttachmentStorage) = AttachmentClassLoaderTests.ISOLATED_CONTRACTS_JAR_PATH.openStream().use { storage.importAttachment(it) } private fun importJar(storage: AttachmentStorage) = AttachmentClassLoaderTests.ISOLATED_CONTRACTS_JAR_PATH.openStream().use { storage.importAttachment(it) }
@ -151,20 +164,20 @@ class CordaClassResolverTests {
val attachmentHash = importJar(storage) val attachmentHash = importJar(storage)
val classLoader = AttachmentsClassLoader(arrayOf(attachmentHash).map { storage.openAttachment(it)!! }) val classLoader = AttachmentsClassLoader(arrayOf(attachmentHash).map { storage.openAttachment(it)!! })
val attachedClass = Class.forName("net.corda.contracts.isolated.AnotherDummyContract", true, classLoader) val attachedClass = Class.forName("net.corda.contracts.isolated.AnotherDummyContract", true, classLoader)
CordaClassResolver(EmptyWhitelist).getRegistration(attachedClass) CordaClassResolver(factory, emptyWhitelistContext).getRegistration(attachedClass)
} }
@Test @Test
fun `Annotation is inherited from interfaces`() { fun `Annotation is inherited from interfaces`() {
CordaClassResolver(EmptyWhitelist).getRegistration(SerializableViaInterface::class.java) CordaClassResolver(factory, emptyWhitelistContext).getRegistration(SerializableViaInterface::class.java)
CordaClassResolver(EmptyWhitelist).getRegistration(SerializableViaSubInterface::class.java) CordaClassResolver(factory, emptyWhitelistContext).getRegistration(SerializableViaSubInterface::class.java)
} }
@Test @Test
fun `Annotation is inherited from superclass`() { fun `Annotation is inherited from superclass`() {
CordaClassResolver(EmptyWhitelist).getRegistration(SubElement::class.java) CordaClassResolver(factory, emptyWhitelistContext).getRegistration(SubElement::class.java)
CordaClassResolver(EmptyWhitelist).getRegistration(SubSubElement::class.java) CordaClassResolver(factory, emptyWhitelistContext).getRegistration(SubSubElement::class.java)
CordaClassResolver(EmptyWhitelist).getRegistration(SerializableViaSuperSubInterface::class.java) CordaClassResolver(factory, emptyWhitelistContext).getRegistration(SerializableViaSuperSubInterface::class.java)
} }
// Blacklist tests. // Blacklist tests.
@ -175,7 +188,7 @@ class CordaClassResolverTests {
fun `Check blacklisted class`() { fun `Check blacklisted class`() {
expectedEx.expect(IllegalStateException::class.java) expectedEx.expect(IllegalStateException::class.java)
expectedEx.expectMessage("Class java.util.HashSet is blacklisted, so it cannot be used in serialization.") expectedEx.expectMessage("Class java.util.HashSet is blacklisted, so it cannot be used in serialization.")
val resolver = CordaClassResolver(AllButBlacklisted) val resolver = CordaClassResolver(factory, allButBlacklistedContext)
// HashSet is blacklisted. // HashSet is blacklisted.
resolver.getRegistration(HashSet::class.java) resolver.getRegistration(HashSet::class.java)
} }
@ -185,7 +198,7 @@ class CordaClassResolverTests {
fun `Check blacklisted subclass`() { fun `Check blacklisted subclass`() {
expectedEx.expect(IllegalStateException::class.java) expectedEx.expect(IllegalStateException::class.java)
expectedEx.expectMessage("The superclass java.util.HashSet of net.corda.nodeapi.internal.serialization.CordaClassResolverTests\$SubHashSet is blacklisted, so it cannot be used in serialization.") expectedEx.expectMessage("The superclass java.util.HashSet of net.corda.nodeapi.internal.serialization.CordaClassResolverTests\$SubHashSet is blacklisted, so it cannot be used in serialization.")
val resolver = CordaClassResolver(AllButBlacklisted) val resolver = CordaClassResolver(factory, allButBlacklistedContext)
// SubHashSet extends the blacklisted HashSet. // SubHashSet extends the blacklisted HashSet.
resolver.getRegistration(SubHashSet::class.java) resolver.getRegistration(SubHashSet::class.java)
} }
@ -195,7 +208,7 @@ class CordaClassResolverTests {
fun `Check blacklisted subsubclass`() { fun `Check blacklisted subsubclass`() {
expectedEx.expect(IllegalStateException::class.java) expectedEx.expect(IllegalStateException::class.java)
expectedEx.expectMessage("The superclass java.util.HashSet of net.corda.nodeapi.internal.serialization.CordaClassResolverTests\$SubSubHashSet is blacklisted, so it cannot be used in serialization.") expectedEx.expectMessage("The superclass java.util.HashSet of net.corda.nodeapi.internal.serialization.CordaClassResolverTests\$SubSubHashSet is blacklisted, so it cannot be used in serialization.")
val resolver = CordaClassResolver(AllButBlacklisted) val resolver = CordaClassResolver(factory, allButBlacklistedContext)
// SubSubHashSet extends SubHashSet, which extends the blacklisted HashSet. // SubSubHashSet extends SubHashSet, which extends the blacklisted HashSet.
resolver.getRegistration(SubSubHashSet::class.java) resolver.getRegistration(SubSubHashSet::class.java)
} }
@ -205,7 +218,7 @@ class CordaClassResolverTests {
fun `Check blacklisted interface impl`() { fun `Check blacklisted interface impl`() {
expectedEx.expect(IllegalStateException::class.java) expectedEx.expect(IllegalStateException::class.java)
expectedEx.expectMessage("The superinterface java.sql.Connection of net.corda.nodeapi.internal.serialization.CordaClassResolverTests\$ConnectionImpl is blacklisted, so it cannot be used in serialization.") expectedEx.expectMessage("The superinterface java.sql.Connection of net.corda.nodeapi.internal.serialization.CordaClassResolverTests\$ConnectionImpl is blacklisted, so it cannot be used in serialization.")
val resolver = CordaClassResolver(AllButBlacklisted) val resolver = CordaClassResolver(factory, allButBlacklistedContext)
// ConnectionImpl implements blacklisted Connection. // ConnectionImpl implements blacklisted Connection.
resolver.getRegistration(ConnectionImpl::class.java) resolver.getRegistration(ConnectionImpl::class.java)
} }
@ -216,14 +229,14 @@ class CordaClassResolverTests {
fun `Check blacklisted super-interface impl`() { fun `Check blacklisted super-interface impl`() {
expectedEx.expect(IllegalStateException::class.java) expectedEx.expect(IllegalStateException::class.java)
expectedEx.expectMessage("The superinterface java.sql.Connection of net.corda.nodeapi.internal.serialization.CordaClassResolverTests\$SubConnectionImpl is blacklisted, so it cannot be used in serialization.") expectedEx.expectMessage("The superinterface java.sql.Connection of net.corda.nodeapi.internal.serialization.CordaClassResolverTests\$SubConnectionImpl is blacklisted, so it cannot be used in serialization.")
val resolver = CordaClassResolver(AllButBlacklisted) val resolver = CordaClassResolver(factory, allButBlacklistedContext)
// SubConnectionImpl implements SubConnection, which extends the blacklisted Connection. // SubConnectionImpl implements SubConnection, which extends the blacklisted Connection.
resolver.getRegistration(SubConnectionImpl::class.java) resolver.getRegistration(SubConnectionImpl::class.java)
} }
@Test @Test
fun `Check forcibly allowed`() { fun `Check forcibly allowed`() {
val resolver = CordaClassResolver(AllButBlacklisted) val resolver = CordaClassResolver(factory, allButBlacklistedContext)
// LinkedHashSet is allowed for serialization. // LinkedHashSet is allowed for serialization.
resolver.getRegistration(LinkedHashSet::class.java) resolver.getRegistration(LinkedHashSet::class.java)
} }
@ -234,7 +247,7 @@ class CordaClassResolverTests {
fun `Check blacklist precedes CordaSerializable`() { fun `Check blacklist precedes CordaSerializable`() {
expectedEx.expect(IllegalStateException::class.java) expectedEx.expect(IllegalStateException::class.java)
expectedEx.expectMessage("The superclass java.util.HashSet of net.corda.nodeapi.internal.serialization.CordaClassResolverTests\$CordaSerializableHashSet is blacklisted, so it cannot be used in serialization.") expectedEx.expectMessage("The superclass java.util.HashSet of net.corda.nodeapi.internal.serialization.CordaClassResolverTests\$CordaSerializableHashSet is blacklisted, so it cannot be used in serialization.")
val resolver = CordaClassResolver(AllButBlacklisted) val resolver = CordaClassResolver(factory, allButBlacklistedContext)
// CordaSerializableHashSet is @CordaSerializable, but extends the blacklisted HashSet. // CordaSerializableHashSet is @CordaSerializable, but extends the blacklisted HashSet.
resolver.getRegistration(CordaSerializableHashSet::class.java) resolver.getRegistration(CordaSerializableHashSet::class.java)
} }

View File

@ -12,6 +12,7 @@ import net.corda.core.utilities.sequence
import net.corda.node.serialization.KryoServerSerializationScheme import net.corda.node.serialization.KryoServerSerializationScheme
import net.corda.node.services.persistence.NodeAttachmentService import net.corda.node.services.persistence.NodeAttachmentService
import net.corda.testing.ALICE_PUBKEY import net.corda.testing.ALICE_PUBKEY
import net.corda.testing.TestDependencyInjectionBase
import org.assertj.core.api.Assertions.assertThat import org.assertj.core.api.Assertions.assertThat
import org.assertj.core.api.Assertions.assertThatThrownBy import org.assertj.core.api.Assertions.assertThatThrownBy
import org.junit.Before import org.junit.Before
@ -23,13 +24,13 @@ import java.time.Instant
import kotlin.test.assertEquals import kotlin.test.assertEquals
import kotlin.test.assertTrue import kotlin.test.assertTrue
class KryoTests { class KryoTests : TestDependencyInjectionBase() {
private lateinit var factory: SerializationFactory private lateinit var factory: SerializationFactory
private lateinit var context: SerializationContext private lateinit var context: SerializationContext
@Before @Before
fun setup() { fun setup() {
factory = SerializationFactoryImpl().apply { registerScheme(KryoServerSerializationScheme()) } factory = SerializationFactoryImpl().apply { registerScheme(KryoServerSerializationScheme(this)) }
context = SerializationContextImpl(KryoHeaderV0_1, context = SerializationContextImpl(KryoHeaderV0_1,
javaClass.classLoader, javaClass.classLoader,
AllWhitelist, AllWhitelist,
@ -199,7 +200,7 @@ class KryoTests {
} }
} }
Tmp() Tmp()
val factory = SerializationFactoryImpl().apply { registerScheme(KryoServerSerializationScheme()) } val factory = SerializationFactoryImpl().apply { registerScheme(KryoServerSerializationScheme(this)) }
val context = SerializationContextImpl(KryoHeaderV0_1, val context = SerializationContextImpl(KryoHeaderV0_1,
javaClass.classLoader, javaClass.classLoader,
AllWhitelist, AllWhitelist,

View File

@ -8,19 +8,20 @@ import net.corda.core.node.ServiceHub
import net.corda.core.serialization.* import net.corda.core.serialization.*
import net.corda.core.utilities.OpaqueBytes import net.corda.core.utilities.OpaqueBytes
import net.corda.node.serialization.KryoServerSerializationScheme import net.corda.node.serialization.KryoServerSerializationScheme
import net.corda.testing.TestDependencyInjectionBase
import org.assertj.core.api.Assertions.assertThat import org.assertj.core.api.Assertions.assertThat
import org.junit.Before import org.junit.Before
import org.junit.Test import org.junit.Test
import java.io.ByteArrayOutputStream import java.io.ByteArrayOutputStream
class SerializationTokenTest { class SerializationTokenTest : TestDependencyInjectionBase() {
lateinit var factory: SerializationFactory lateinit var factory: SerializationFactory
lateinit var context: SerializationContext lateinit var context: SerializationContext
@Before @Before
fun setup() { fun setup() {
factory = SerializationFactoryImpl().apply { registerScheme(KryoServerSerializationScheme()) } factory = SerializationFactoryImpl().apply { registerScheme(KryoServerSerializationScheme(this)) }
context = SerializationContextImpl(KryoHeaderV0_1, context = SerializationContextImpl(KryoHeaderV0_1,
javaClass.classLoader, javaClass.classLoader,
AllWhitelist, AllWhitelist,
@ -96,7 +97,7 @@ class SerializationTokenTest {
val context = serializeAsTokenContext(tokenizableBefore) val context = serializeAsTokenContext(tokenizableBefore)
val testContext = this.context.withTokenContext(context) val testContext = this.context.withTokenContext(context)
val kryo: Kryo = DefaultKryoCustomizer.customize(CordaKryo(makeNoWhitelistClassResolver())) val kryo: Kryo = DefaultKryoCustomizer.customize(CordaKryo(CordaClassResolver(factory, this.context)))
val stream = ByteArrayOutputStream() val stream = ByteArrayOutputStream()
Output(stream).use { Output(stream).use {
it.write(KryoHeaderV0_1.bytes) it.write(KryoHeaderV0_1.bytes)

View File

@ -11,8 +11,8 @@ import net.corda.core.identity.AbstractParty
import net.corda.core.serialization.CordaSerializable import net.corda.core.serialization.CordaSerializable
import net.corda.core.transactions.LedgerTransaction import net.corda.core.transactions.LedgerTransaction
import net.corda.nodeapi.RPCException import net.corda.nodeapi.RPCException
import net.corda.nodeapi.internal.serialization.AbstractAMQPSerializationScheme
import net.corda.nodeapi.internal.serialization.EmptyWhitelist import net.corda.nodeapi.internal.serialization.EmptyWhitelist
import net.corda.nodeapi.internal.serialization.KryoAMQPSerializer
import net.corda.nodeapi.internal.serialization.amqp.SerializerFactory.Companion.isPrimitive import net.corda.nodeapi.internal.serialization.amqp.SerializerFactory.Companion.isPrimitive
import net.corda.nodeapi.internal.serialization.amqp.custom.* import net.corda.nodeapi.internal.serialization.amqp.custom.*
import net.corda.testing.MEGA_CORP import net.corda.testing.MEGA_CORP
@ -528,10 +528,10 @@ class SerializationOutputTests {
val state = TransactionState<FooState>(FooState(), MEGA_CORP) val state = TransactionState<FooState>(FooState(), MEGA_CORP)
val factory = SerializerFactory() val factory = SerializerFactory()
KryoAMQPSerializer.registerCustomSerializers(factory) AbstractAMQPSerializationScheme.registerCustomSerializers(factory)
val factory2 = SerializerFactory() val factory2 = SerializerFactory()
KryoAMQPSerializer.registerCustomSerializers(factory2) AbstractAMQPSerializationScheme.registerCustomSerializers(factory2)
val desState = serdes(state, factory, factory2, expectedEqual = false, expectDeserializedEqual = false) val desState = serdes(state, factory, factory2, expectedEqual = false, expectDeserializedEqual = false)
assertTrue(desState is TransactionState<*>) assertTrue(desState is TransactionState<*>)

View File

@ -10,18 +10,13 @@ import com.google.common.collect.testing.features.MapFeature
import com.google.common.collect.testing.features.SetFeature import com.google.common.collect.testing.features.SetFeature
import com.google.common.collect.testing.testers.* import com.google.common.collect.testing.testers.*
import junit.framework.TestSuite import junit.framework.TestSuite
import net.corda.testing.TestDependencyInjectionBase import net.corda.testing.*
import net.corda.testing.initialiseTestSerialization
import net.corda.testing.node.makeTestDataSourceProperties import net.corda.testing.node.makeTestDataSourceProperties
import net.corda.testing.node.makeTestDatabaseProperties import net.corda.testing.node.makeTestDatabaseProperties
import net.corda.testing.resetTestSerialization
import org.assertj.core.api.Assertions.assertThat import org.assertj.core.api.Assertions.assertThat
import org.jetbrains.exposed.sql.Transaction
import org.jetbrains.exposed.sql.transactions.TransactionManager
import org.junit.* import org.junit.*
import org.junit.runner.RunWith import org.junit.runner.RunWith
import org.junit.runners.Suite import org.junit.runners.Suite
import java.sql.Connection
import java.util.* import java.util.*
@RunWith(Suite::class) @RunWith(Suite::class)
@ -47,7 +42,7 @@ class JDBCHashMapTestSuite {
@BeforeClass @BeforeClass
fun before() { fun before() {
initialiseTestSerialization() initialiseTestSerialization()
database = configureDatabase(makeTestDataSourceProperties(), makeTestDatabaseProperties()) database = configureDatabase(makeTestDataSourceProperties(), makeTestDatabaseProperties(), identitySvc = { throw UnsupportedOperationException("Identity Service should not be in use") })
setUpDatabaseTx() setUpDatabaseTx()
loadOnInitFalseMap = JDBCHashMap<String, String>("test_map_false", loadOnInit = false) loadOnInitFalseMap = JDBCHashMap<String, String>("test_map_false", loadOnInit = false)
memoryConstrainedMap = JDBCHashMap<String, String>("test_map_constrained", loadOnInit = false, maxBuckets = 1) memoryConstrainedMap = JDBCHashMap<String, String>("test_map_constrained", loadOnInit = false, maxBuckets = 1)
@ -233,7 +228,7 @@ class JDBCHashMapTestSuite {
@Before @Before
fun before() { fun before() {
database = configureDatabase(makeTestDataSourceProperties(), makeTestDatabaseProperties()) database = configureDatabase(makeTestDataSourceProperties(), makeTestDatabaseProperties(), identitySvc = { throw UnsupportedOperationException("Identity Service should not be in use") })
} }
@After @After

View File

@ -486,7 +486,7 @@ abstract class AbstractNode(open val configuration: NodeConfiguration,
private fun makeVaultObservers() { private fun makeVaultObservers() {
VaultSoftLockManager(services.vaultService, smm) VaultSoftLockManager(services.vaultService, smm)
ScheduledActivityObserver(services) ScheduledActivityObserver(services)
HibernateObserver(services.vaultService.rawUpdates, HibernateConfiguration(services.schemaService, configuration.database ?: Properties(), services.identityService)) HibernateObserver(services.vaultService.rawUpdates, HibernateConfiguration(services.schemaService, configuration.database ?: Properties(), {services.identityService}))
} }
private fun makeInfo(): NodeInfo { private fun makeInfo(): NodeInfo {
@ -545,7 +545,7 @@ abstract class AbstractNode(open val configuration: NodeConfiguration,
protected open fun initialiseDatabasePersistence(insideTransaction: () -> Unit) { protected open fun initialiseDatabasePersistence(insideTransaction: () -> Unit) {
val props = configuration.dataSourceProperties val props = configuration.dataSourceProperties
if (props.isNotEmpty()) { if (props.isNotEmpty()) {
this.database = configureDatabase(props, configuration.database) this.database = configureDatabase(props, configuration.database, identitySvc = { _services.identityService })
// Now log the vendor string as this will also cause a connection to be tested eagerly. // Now log the vendor string as this will also cause a connection to be tested eagerly.
database.transaction { database.transaction {
log.info("Connected to ${database.database.vendor} database.") log.info("Connected to ${database.database.vendor} database.")
@ -773,7 +773,7 @@ abstract class AbstractNode(open val configuration: NodeConfiguration,
override val networkMapCache by lazy { InMemoryNetworkMapCache(this) } override val networkMapCache by lazy { InMemoryNetworkMapCache(this) }
override val vaultService by lazy { NodeVaultService(this, configuration.dataSourceProperties, configuration.database) } override val vaultService by lazy { NodeVaultService(this, configuration.dataSourceProperties, configuration.database) }
override val vaultQueryService by lazy { override val vaultQueryService by lazy {
HibernateVaultQueryImpl(HibernateConfiguration(schemaService, configuration.database ?: Properties(), identityService), vaultService.updatesPublisher) HibernateVaultQueryImpl(HibernateConfiguration(schemaService, configuration.database ?: Properties(), { identityService }), vaultService.updatesPublisher)
} }
// Place the long term identity key in the KMS. Eventually, this is likely going to be separated again because // Place the long term identity key in the KMS. Eventually, this is likely going to be separated again because
// the KMS is meant for derived temporary keys used in transactions, and we're not supposed to sign things with // the KMS is meant for derived temporary keys used in transactions, and we're not supposed to sign things with

View File

@ -331,7 +331,7 @@ open class Node(override val configuration: FullNodeConfiguration,
private fun initialiseSerialization() { private fun initialiseSerialization() {
SerializationDefaults.SERIALIZATION_FACTORY = SerializationFactoryImpl().apply { SerializationDefaults.SERIALIZATION_FACTORY = SerializationFactoryImpl().apply {
registerScheme(KryoServerSerializationScheme()) registerScheme(KryoServerSerializationScheme(this))
registerScheme(AMQPServerSerializationScheme()) registerScheme(AMQPServerSerializationScheme())
} }
SerializationDefaults.P2P_CONTEXT = KRYO_P2P_CONTEXT SerializationDefaults.P2P_CONTEXT = KRYO_P2P_CONTEXT

View File

@ -2,6 +2,7 @@ package net.corda.node.serialization
import com.esotericsoftware.kryo.pool.KryoPool import com.esotericsoftware.kryo.pool.KryoPool
import net.corda.core.serialization.SerializationContext import net.corda.core.serialization.SerializationContext
import net.corda.core.serialization.SerializationFactory
import net.corda.core.utilities.ByteSequence import net.corda.core.utilities.ByteSequence
import net.corda.node.services.messaging.RpcServerObservableSerializer import net.corda.node.services.messaging.RpcServerObservableSerializer
import net.corda.nodeapi.RPCKryo import net.corda.nodeapi.RPCKryo
@ -9,7 +10,7 @@ import net.corda.nodeapi.internal.serialization.AbstractKryoSerializationScheme
import net.corda.nodeapi.internal.serialization.DefaultKryoCustomizer import net.corda.nodeapi.internal.serialization.DefaultKryoCustomizer
import net.corda.nodeapi.internal.serialization.KryoHeaderV0_1 import net.corda.nodeapi.internal.serialization.KryoHeaderV0_1
class KryoServerSerializationScheme : AbstractKryoSerializationScheme() { class KryoServerSerializationScheme(serializationFactory: SerializationFactory) : AbstractKryoSerializationScheme(serializationFactory) {
override fun canDeserializeVersion(byteSequence: ByteSequence, target: SerializationContext.UseCase): Boolean { override fun canDeserializeVersion(byteSequence: ByteSequence, target: SerializationContext.UseCase): Boolean {
return byteSequence == KryoHeaderV0_1 && target != SerializationContext.UseCase.RPCClient return byteSequence == KryoHeaderV0_1 && target != SerializationContext.UseCase.RPCClient
} }
@ -20,7 +21,7 @@ class KryoServerSerializationScheme : AbstractKryoSerializationScheme() {
override fun rpcServerKryoPool(context: SerializationContext): KryoPool { override fun rpcServerKryoPool(context: SerializationContext): KryoPool {
return KryoPool.Builder { return KryoPool.Builder {
DefaultKryoCustomizer.customize(RPCKryo(RpcServerObservableSerializer, context.whitelist)).apply { classLoader = context.deserializationClassLoader } DefaultKryoCustomizer.customize(RPCKryo(RpcServerObservableSerializer, serializationFactory, context)).apply { classLoader = context.deserializationClassLoader }
}.build() }.build()
} }
} }

View File

@ -20,7 +20,7 @@ import java.sql.Connection
import java.util.* import java.util.*
import java.util.concurrent.ConcurrentHashMap import java.util.concurrent.ConcurrentHashMap
class HibernateConfiguration(val schemaService: SchemaService, val databaseProperties: Properties, val identitySvc: IdentityService) { class HibernateConfiguration(val schemaService: SchemaService, val databaseProperties: Properties, private val identitySvc: () -> IdentityService) {
companion object { companion object {
val logger = loggerFor<HibernateConfiguration>() val logger = loggerFor<HibernateConfiguration>()
} }

View File

@ -6,14 +6,13 @@ import net.corda.core.identity.PartyAndCertificate
import net.corda.core.internal.ThreadBox import net.corda.core.internal.ThreadBox
import net.corda.core.node.services.IdentityService import net.corda.core.node.services.IdentityService
import net.corda.core.node.services.KeyManagementService import net.corda.core.node.services.KeyManagementService
import net.corda.core.serialization.SingletonSerializeAsToken import net.corda.core.serialization.*
import net.corda.node.utilities.* import net.corda.node.utilities.*
import org.bouncycastle.operator.ContentSigner import org.bouncycastle.operator.ContentSigner
import org.jetbrains.exposed.sql.ResultRow
import org.jetbrains.exposed.sql.statements.InsertStatement
import java.security.KeyPair import java.security.KeyPair
import java.security.PrivateKey import java.security.PrivateKey
import java.security.PublicKey import java.security.PublicKey
import javax.persistence.*
/** /**
* A persistent re-implementation of [E2ETestKeyManagementService] to support node re-start. * A persistent re-implementation of [E2ETestKeyManagementService] to support node re-start.
@ -25,60 +24,62 @@ import java.security.PublicKey
class PersistentKeyManagementService(val identityService: IdentityService, class PersistentKeyManagementService(val identityService: IdentityService,
initialKeys: Set<KeyPair>) : SingletonSerializeAsToken(), KeyManagementService { initialKeys: Set<KeyPair>) : SingletonSerializeAsToken(), KeyManagementService {
private object Table : JDBCHashedTable("${NODE_DATABASE_PREFIX}our_key_pairs") { @Entity
val publicKey = publicKey("public_key") @javax.persistence.Table(name = "${NODE_DATABASE_PREFIX}our_key_pairs")
val privateKey = blob("private_key") class PersistentKey(
}
@Id
private class InnerState { @Column(name = "public_key")
val keys = object : AbstractJDBCHashMap<PublicKey, PrivateKey, Table>(Table, loadOnInit = false) { var publicKey: String = "",
override fun keyFromRow(row: ResultRow): PublicKey = row[table.publicKey]
@Lob
override fun valueFromRow(row: ResultRow): PrivateKey = deserializeFromBlob(row[table.privateKey]) @Column(name = "private_key")
var privateKey: ByteArray = ByteArray(0)
override fun addKeyToInsert(insert: InsertStatement, entry: Map.Entry<PublicKey, PrivateKey>, finalizables: MutableList<() -> Unit>) { )
insert[table.publicKey] = entry.key
} private companion object {
fun createKeyMap(): AppendOnlyPersistentMap<PublicKey, PrivateKey, PersistentKey, String> {
override fun addValueToInsert(insert: InsertStatement, entry: Map.Entry<PublicKey, PrivateKey>, finalizables: MutableList<() -> Unit>) { return AppendOnlyPersistentMap(
insert[table.privateKey] = serializeToBlob(entry.value, finalizables) toPersistentEntityKey = { it.toBase58String() },
fromPersistentEntity = { Pair(parsePublicKeyBase58(it.publicKey),
it.privateKey.deserialize<PrivateKey>(context = SerializationDefaults.STORAGE_CONTEXT)) },
toPersistentEntity = { key: PublicKey, value: PrivateKey ->
PersistentKey().apply {
publicKey = key.toBase58String()
privateKey = value.serialize(context = SerializationDefaults.STORAGE_CONTEXT).bytes
} }
},
persistentEntityClass = PersistentKey::class.java
)
} }
} }
private val mutex = ThreadBox(InnerState()) val keysMap = createKeyMap()
init { init {
mutex.locked { initialKeys.forEach({ it -> keysMap.addWithDuplicatesAllowed(it.public, it.private) })
keys.putAll(initialKeys.associate { Pair(it.public, it.private) })
}
} }
override val keys: Set<PublicKey> get() = mutex.locked { keys.keys } override val keys: Set<PublicKey> get() = keysMap.allPersisted().map { it.first }.toSet()
override fun filterMyKeys(candidateKeys: Iterable<PublicKey>): Iterable<PublicKey> { override fun filterMyKeys(candidateKeys: Iterable<PublicKey>): Iterable<PublicKey> =
return mutex.locked { candidateKeys.filter { it in this.keys } } candidateKeys.filter { keysMap[it] != null }
}
override fun freshKey(): PublicKey { override fun freshKey(): PublicKey {
val keyPair = generateKeyPair() val keyPair = generateKeyPair()
mutex.locked { keysMap[keyPair.public] = keyPair.private
keys[keyPair.public] = keyPair.private
}
return keyPair.public return keyPair.public
} }
override fun freshKeyAndCert(identity: PartyAndCertificate, revocationEnabled: Boolean): AnonymousPartyAndPath { override fun freshKeyAndCert(identity: PartyAndCertificate, revocationEnabled: Boolean): AnonymousPartyAndPath =
return freshCertificate(identityService, freshKey(), identity, getSigner(identity.owningKey), revocationEnabled) freshCertificate(identityService, freshKey(), identity, getSigner(identity.owningKey), revocationEnabled)
}
private fun getSigner(publicKey: PublicKey): ContentSigner = getSigner(getSigningKeyPair(publicKey)) private fun getSigner(publicKey: PublicKey): ContentSigner = getSigner(getSigningKeyPair(publicKey))
//It looks for the PublicKey in the (potentially) CompositeKey that is ours, and then returns the associated PrivateKey to use in signing
private fun getSigningKeyPair(publicKey: PublicKey): KeyPair { private fun getSigningKeyPair(publicKey: PublicKey): KeyPair {
return mutex.locked { val pk = publicKey.keys.first { keysMap[it] != null } //TODO here for us to re-write this using an actual query if publicKey.keys.size > 1
val pk = publicKey.keys.first { keys.containsKey(it) } return KeyPair(pk, keysMap[pk]!!)
KeyPair(pk, keys[pk]!!)
}
} }
override fun sign(bytes: ByteArray, publicKey: PublicKey): DigitalSignature.WithKey { override fun sign(bytes: ByteArray, publicKey: PublicKey): DigitalSignature.WithKey {

View File

@ -1,58 +1,61 @@
package net.corda.node.services.persistence package net.corda.node.services.persistence
import net.corda.core.crypto.SecureHash import net.corda.core.serialization.SerializationDefaults
import net.corda.core.serialization.SerializationDefaults.CHECKPOINT_CONTEXT
import net.corda.core.serialization.SerializedBytes
import net.corda.core.serialization.deserialize import net.corda.core.serialization.deserialize
import net.corda.core.serialization.serialize import net.corda.core.serialization.serialize
import net.corda.node.services.api.Checkpoint import net.corda.node.services.api.Checkpoint
import net.corda.node.services.api.CheckpointStorage import net.corda.node.services.api.CheckpointStorage
import net.corda.node.utilities.* import net.corda.node.utilities.*
import org.jetbrains.exposed.sql.ResultRow import javax.persistence.Column
import org.jetbrains.exposed.sql.statements.InsertStatement import javax.persistence.Entity
import java.util.Collections.synchronizedMap import javax.persistence.Id
import javax.persistence.Lob
/** /**
* Simple checkpoint key value storage in DB using the underlying JDBCHashMap and transactional context of the call sites. * Simple checkpoint key value storage in DB.
*/ */
class DBCheckpointStorage : CheckpointStorage { class DBCheckpointStorage : CheckpointStorage {
private object Table : JDBCHashedTable("${NODE_DATABASE_PREFIX}checkpoints") { @Entity
val checkpointId = secureHash("checkpoint_id") @javax.persistence.Table(name = "${NODE_DATABASE_PREFIX}checkpoints")
val checkpoint = blob("checkpoint") class DBCheckpoint(
} @Id
@Column(name = "checkpoint_id", length = 64)
var checkpointId: String = "",
private class CheckpointMap : AbstractJDBCHashMap<SecureHash, SerializedBytes<Checkpoint>, Table>(Table, loadOnInit = false) { @Lob
override fun keyFromRow(row: ResultRow): SecureHash = row[table.checkpointId] @Column(name = "checkpoint")
var checkpoint: ByteArray = ByteArray(0)
)
override fun valueFromRow(row: ResultRow): SerializedBytes<Checkpoint> = bytesFromBlob(row[table.checkpoint]) override fun addCheckpoint(value: Checkpoint) {
val session = DatabaseTransactionManager.current().session
override fun addKeyToInsert(insert: InsertStatement, entry: Map.Entry<SecureHash, SerializedBytes<Checkpoint>>, finalizables: MutableList<() -> Unit>) { session.save(DBCheckpoint().apply {
insert[table.checkpointId] = entry.key checkpointId = value.id.toString()
} checkpoint = value.serialize(context = SerializationDefaults.CHECKPOINT_CONTEXT).bytes
})
override fun addValueToInsert(insert: InsertStatement, entry: Map.Entry<SecureHash, SerializedBytes<Checkpoint>>, finalizables: MutableList<() -> Unit>) {
insert[table.checkpoint] = bytesToBlob(entry.value, finalizables)
}
}
private val checkpointStorage = synchronizedMap(CheckpointMap())
override fun addCheckpoint(checkpoint: Checkpoint) {
checkpointStorage.put(checkpoint.id, checkpoint.serialize(context = CHECKPOINT_CONTEXT))
} }
override fun removeCheckpoint(checkpoint: Checkpoint) { override fun removeCheckpoint(checkpoint: Checkpoint) {
checkpointStorage.remove(checkpoint.id) ?: throw IllegalArgumentException("Checkpoint not found") val session = DatabaseTransactionManager.current().session
val criteriaBuilder = session.criteriaBuilder
val delete = criteriaBuilder.createCriteriaDelete(DBCheckpoint::class.java)
val root = delete.from(DBCheckpoint::class.java)
delete.where(criteriaBuilder.equal(root.get<String>(DBCheckpoint::checkpointId.name), checkpoint.id.toString()))
session.createQuery(delete).executeUpdate()
} }
override fun forEach(block: (Checkpoint) -> Boolean) { override fun forEach(block: (Checkpoint) -> Boolean) {
synchronized(checkpointStorage) { val session = DatabaseTransactionManager.current().session
for (checkpoint in checkpointStorage.values) { val criteriaQuery = session.criteriaBuilder.createQuery(DBCheckpoint::class.java)
if (!block(checkpoint.deserialize(context = CHECKPOINT_CONTEXT))) { val root = criteriaQuery.from(DBCheckpoint::class.java)
criteriaQuery.select(root)
val query = session.createQuery(criteriaQuery)
val checkpoints = query.resultList.map { e -> e.checkpoint.deserialize<Checkpoint>(context = SerializationDefaults.CHECKPOINT_CONTEXT) }.asSequence()
for (e in checkpoints) {
if (!block(e)) {
break break
} }
} }
} }
} }
}

View File

@ -1,6 +1,5 @@
package net.corda.node.services.persistence package net.corda.node.services.persistence
import net.corda.core.internal.ThreadBox
import net.corda.core.internal.bufferUntilSubscribed import net.corda.core.internal.bufferUntilSubscribed
import net.corda.core.crypto.SecureHash import net.corda.core.crypto.SecureHash
import net.corda.core.flows.StateMachineRunId import net.corda.core.flows.StateMachineRunId
@ -8,59 +7,57 @@ import net.corda.core.messaging.DataFeed
import net.corda.core.messaging.StateMachineTransactionMapping import net.corda.core.messaging.StateMachineTransactionMapping
import net.corda.node.services.api.StateMachineRecordedTransactionMappingStorage import net.corda.node.services.api.StateMachineRecordedTransactionMappingStorage
import net.corda.node.utilities.* import net.corda.node.utilities.*
import org.jetbrains.exposed.sql.ResultRow
import org.jetbrains.exposed.sql.statements.InsertStatement
import rx.subjects.PublishSubject import rx.subjects.PublishSubject
import java.util.*
import javax.annotation.concurrent.ThreadSafe import javax.annotation.concurrent.ThreadSafe
import javax.persistence.*
/** /**
* Database storage of a txhash -> state machine id mapping. * Database storage of a txhash -> state machine id mapping.
* *
* Mappings are added as transactions are persisted by [ServiceHub.recordTransaction], and never deleted. Used in the * Mappings are added as transactions are persisted by [ServiceHub.recordTransaction], and never deleted. Used in the
* RPC API to correlate transaction creation with flows. * RPC API to correlate transaction creation with flows.
*
*/ */
@ThreadSafe @ThreadSafe
class DBTransactionMappingStorage : StateMachineRecordedTransactionMappingStorage { class DBTransactionMappingStorage : StateMachineRecordedTransactionMappingStorage {
private object Table : JDBCHashedTable("${NODE_DATABASE_PREFIX}transaction_mappings") { @Entity
val txId = secureHash("tx_id") @javax.persistence.Table(name = "${NODE_DATABASE_PREFIX}transaction_mappings")
val stateMachineRunId = uuidString("state_machine_run_id") class DBTransactionMapping(
} @Id
@Column(name = "tx_id", length = 64)
var txId: String = "",
private class TransactionMappingsMap : AbstractJDBCHashMap<SecureHash, StateMachineRunId, Table>(Table, loadOnInit = false) { @Column(name = "state_machine_run_id", length = 36)
override fun keyFromRow(row: ResultRow): SecureHash = row[table.txId] var stateMachineRunId: String = ""
)
override fun valueFromRow(row: ResultRow): StateMachineRunId = StateMachineRunId(row[table.stateMachineRunId]) private companion object {
fun createMap(): AppendOnlyPersistentMap<SecureHash, StateMachineRunId, DBTransactionMapping, String> {
override fun addKeyToInsert(insert: InsertStatement, entry: Map.Entry<SecureHash, StateMachineRunId>, finalizables: MutableList<() -> Unit>) { return AppendOnlyPersistentMap(
insert[table.txId] = entry.key toPersistentEntityKey = { it.toString() },
fromPersistentEntity = { Pair(SecureHash.parse(it.txId), StateMachineRunId(UUID.fromString(it.stateMachineRunId))) },
toPersistentEntity = { key: SecureHash, value: StateMachineRunId ->
DBTransactionMapping().apply {
txId = key.toString()
stateMachineRunId = value.uuid.toString()
} }
},
override fun addValueToInsert(insert: InsertStatement, entry: Map.Entry<SecureHash, StateMachineRunId>, finalizables: MutableList<() -> Unit>) { persistentEntityClass = DBTransactionMapping::class.java
insert[table.stateMachineRunId] = entry.value.uuid
}
}
private class InnerState {
val stateMachineTransactionMap = TransactionMappingsMap()
val updates: PublishSubject<StateMachineTransactionMapping> = PublishSubject.create()
}
private val mutex = ThreadBox(InnerState())
override fun addMapping(stateMachineRunId: StateMachineRunId, transactionId: SecureHash) {
mutex.locked {
stateMachineTransactionMap[transactionId] = stateMachineRunId
updates.bufferUntilDatabaseCommit().onNext(StateMachineTransactionMapping(stateMachineRunId, transactionId))
}
}
override fun track(): DataFeed<List<StateMachineTransactionMapping>, StateMachineTransactionMapping> {
mutex.locked {
return DataFeed(
stateMachineTransactionMap.map { StateMachineTransactionMapping(it.value, it.key) },
updates.bufferUntilSubscribed().wrapWithDatabaseTransaction()
) )
} }
} }
val stateMachineTransactionMap = createMap()
val updates: PublishSubject<StateMachineTransactionMapping> = PublishSubject.create()
override fun addMapping(stateMachineRunId: StateMachineRunId, transactionId: SecureHash) {
stateMachineTransactionMap[transactionId] = stateMachineRunId
updates.bufferUntilDatabaseCommit().onNext(StateMachineTransactionMapping(stateMachineRunId, transactionId))
}
override fun track(): DataFeed<List<StateMachineTransactionMapping>, StateMachineTransactionMapping> =
DataFeed(stateMachineTransactionMap.allPersisted().map { StateMachineTransactionMapping(it.second, it.first) }.toList(),
updates.bufferUntilSubscribed().wrapWithDatabaseTransaction())
} }

View File

@ -4,73 +4,60 @@ import com.google.common.annotations.VisibleForTesting
import net.corda.core.internal.bufferUntilSubscribed import net.corda.core.internal.bufferUntilSubscribed
import net.corda.core.crypto.SecureHash import net.corda.core.crypto.SecureHash
import net.corda.core.messaging.DataFeed import net.corda.core.messaging.DataFeed
import net.corda.core.serialization.SingletonSerializeAsToken import net.corda.core.serialization.*
import net.corda.core.transactions.SignedTransaction import net.corda.core.transactions.SignedTransaction
import net.corda.node.services.api.WritableTransactionStorage import net.corda.node.services.api.WritableTransactionStorage
import net.corda.node.utilities.* import net.corda.node.utilities.*
import org.jetbrains.exposed.sql.ResultRow
import org.jetbrains.exposed.sql.exposedLogger
import org.jetbrains.exposed.sql.statements.InsertStatement
import rx.Observable import rx.Observable
import rx.subjects.PublishSubject import rx.subjects.PublishSubject
import java.util.Collections.synchronizedMap import javax.persistence.*
class DBTransactionStorage : WritableTransactionStorage, SingletonSerializeAsToken() { class DBTransactionStorage : WritableTransactionStorage, SingletonSerializeAsToken() {
private object Table : JDBCHashedTable("${NODE_DATABASE_PREFIX}transactions") {
val txId = secureHash("tx_id") @Entity
val transaction = blob("transaction") @Table(name = "${NODE_DATABASE_PREFIX}transactions")
class DBTransaction(
@Id
@Column(name = "tx_id", length = 64)
var txId: String = "",
@Lob
@Column
var transaction: ByteArray = ByteArray(0)
)
private companion object {
fun createTransactionsMap(): AppendOnlyPersistentMap<SecureHash, SignedTransaction, DBTransaction, String> {
return AppendOnlyPersistentMap(
toPersistentEntityKey = { it.toString() },
fromPersistentEntity = { Pair(SecureHash.parse(it.txId),
it.transaction.deserialize<SignedTransaction>( context = SerializationDefaults.STORAGE_CONTEXT)) },
toPersistentEntity = { key: SecureHash, value: SignedTransaction ->
DBTransaction().apply {
txId = key.toString()
transaction = value.serialize(context = SerializationDefaults.STORAGE_CONTEXT).bytes
} }
},
private class TransactionsMap : AbstractJDBCHashMap<SecureHash, SignedTransaction, Table>(Table, loadOnInit = false) { persistentEntityClass = DBTransaction::class.java
override fun keyFromRow(row: ResultRow): SecureHash = row[table.txId] )
override fun valueFromRow(row: ResultRow): SignedTransaction = deserializeFromBlob(row[table.transaction])
override fun addKeyToInsert(insert: InsertStatement, entry: Map.Entry<SecureHash, SignedTransaction>, finalizables: MutableList<() -> Unit>) {
insert[table.txId] = entry.key
}
override fun addValueToInsert(insert: InsertStatement, entry: Map.Entry<SecureHash, SignedTransaction>, finalizables: MutableList<() -> Unit>) {
insert[table.transaction] = serializeToBlob(entry.value, finalizables)
} }
} }
private val txStorage = synchronizedMap(TransactionsMap()) private val txStorage = createTransactionsMap()
override fun addTransaction(transaction: SignedTransaction): Boolean { override fun addTransaction(transaction: SignedTransaction): Boolean =
val recorded = synchronized(txStorage) { txStorage.addWithDuplicatesAllowed(transaction.id, transaction).apply {
val old = txStorage[transaction.id]
if (old == null) {
txStorage.put(transaction.id, transaction)
updatesPublisher.bufferUntilDatabaseCommit().onNext(transaction) updatesPublisher.bufferUntilDatabaseCommit().onNext(transaction)
true
} else {
false
}
}
if (!recorded) {
exposedLogger.warn("Duplicate recording of transaction ${transaction.id}")
}
return recorded
} }
override fun getTransaction(id: SecureHash): SignedTransaction? { override fun getTransaction(id: SecureHash): SignedTransaction? = txStorage[id]
synchronized(txStorage) {
return txStorage[id]
}
}
private val updatesPublisher = PublishSubject.create<SignedTransaction>().toSerialized() private val updatesPublisher = PublishSubject.create<SignedTransaction>().toSerialized()
override val updates: Observable<SignedTransaction> = updatesPublisher.wrapWithDatabaseTransaction() override val updates: Observable<SignedTransaction> = updatesPublisher.wrapWithDatabaseTransaction()
override fun track(): DataFeed<List<SignedTransaction>, SignedTransaction> { override fun track(): DataFeed<List<SignedTransaction>, SignedTransaction> =
synchronized(txStorage) { DataFeed(txStorage.allPersisted().map { it.second }.toList(), updatesPublisher.bufferUntilSubscribed().wrapWithDatabaseTransaction())
return DataFeed(txStorage.values.toList(), updatesPublisher.bufferUntilSubscribed().wrapWithDatabaseTransaction())
}
}
@VisibleForTesting @VisibleForTesting
val transactions: Iterable<SignedTransaction> get() = synchronized(txStorage) { val transactions: Iterable<SignedTransaction> get() = txStorage.allPersisted().map { it.second }.toList()
txStorage.values.toList()
}
} }

View File

@ -10,6 +10,11 @@ import net.corda.core.schemas.QueryableState
import net.corda.core.serialization.SingletonSerializeAsToken import net.corda.core.serialization.SingletonSerializeAsToken
import net.corda.node.services.api.SchemaService import net.corda.node.services.api.SchemaService
import net.corda.core.schemas.CommonSchemaV1 import net.corda.core.schemas.CommonSchemaV1
import net.corda.node.services.keys.PersistentKeyManagementService
import net.corda.node.services.persistence.DBCheckpointStorage
import net.corda.node.services.persistence.DBTransactionMappingStorage
import net.corda.node.services.persistence.DBTransactionStorage
import net.corda.node.services.transactions.PersistentUniquenessProvider
import net.corda.node.services.vault.VaultSchemaV1 import net.corda.node.services.vault.VaultSchemaV1
import net.corda.schemas.CashSchemaV1 import net.corda.schemas.CashSchemaV1
@ -23,14 +28,25 @@ import net.corda.schemas.CashSchemaV1
*/ */
class NodeSchemaService(customSchemas: Set<MappedSchema> = emptySet()) : SchemaService, SingletonSerializeAsToken() { class NodeSchemaService(customSchemas: Set<MappedSchema> = emptySet()) : SchemaService, SingletonSerializeAsToken() {
// Currently does not support configuring schema options. // Entities for compulsory services
object NodeServices
object NodeServicesV1 : MappedSchema(schemaFamily = NodeServices.javaClass, version = 1,
mappedTypes = listOf(DBCheckpointStorage.DBCheckpoint::class.java,
DBTransactionStorage.DBTransaction::class.java,
DBTransactionMappingStorage.DBTransactionMapping::class.java,
PersistentKeyManagementService.PersistentKey::class.java,
PersistentUniquenessProvider.PersistentUniqueness::class.java
))
// Required schemas are those used by internal Corda services // Required schemas are those used by internal Corda services
// For example, cash is used by the vault for coin selection (but will be extracted as a standalone CorDapp in future) // For example, cash is used by the vault for coin selection (but will be extracted as a standalone CorDapp in future)
val requiredSchemas: Map<MappedSchema, SchemaService.SchemaOptions> = val requiredSchemas: Map<MappedSchema, SchemaService.SchemaOptions> =
mapOf(Pair(CashSchemaV1, SchemaService.SchemaOptions()), mapOf(Pair(CashSchemaV1, SchemaService.SchemaOptions()),
Pair(CommonSchemaV1, SchemaService.SchemaOptions()), Pair(CommonSchemaV1, SchemaService.SchemaOptions()),
Pair(VaultSchemaV1, SchemaService.SchemaOptions())) Pair(VaultSchemaV1, SchemaService.SchemaOptions()),
Pair(NodeServicesV1, SchemaService.SchemaOptions()))
override val schemaOptions: Map<MappedSchema, SchemaService.SchemaOptions> = requiredSchemas.plus(customSchemas.map { override val schemaOptions: Map<MappedSchema, SchemaService.SchemaOptions> = requiredSchemas.plus(customSchemas.map {
mappedSchema -> Pair(mappedSchema, SchemaService.SchemaOptions()) mappedSchema -> Pair(mappedSchema, SchemaService.SchemaOptions())

View File

@ -1,69 +1,97 @@
package net.corda.node.services.transactions package net.corda.node.services.transactions
import net.corda.core.internal.ThreadBox
import net.corda.core.contracts.StateRef import net.corda.core.contracts.StateRef
import net.corda.core.crypto.SecureHash import net.corda.core.crypto.SecureHash
import net.corda.core.crypto.parsePublicKeyBase58
import net.corda.core.identity.Party import net.corda.core.identity.Party
import net.corda.core.internal.ThreadBox
import net.corda.core.node.services.UniquenessException import net.corda.core.node.services.UniquenessException
import net.corda.core.node.services.UniquenessProvider import net.corda.core.node.services.UniquenessProvider
import net.corda.core.serialization.SingletonSerializeAsToken import net.corda.core.serialization.SingletonSerializeAsToken
import net.corda.core.utilities.loggerFor import net.corda.core.utilities.loggerFor
import net.corda.node.utilities.* import net.corda.node.utilities.*
import org.bouncycastle.asn1.x500.X500Name import org.bouncycastle.asn1.x500.X500Name
import org.jetbrains.exposed.sql.ResultRow import java.io.Serializable
import org.jetbrains.exposed.sql.statements.InsertStatement
import java.util.* import java.util.*
import javax.annotation.concurrent.ThreadSafe import javax.annotation.concurrent.ThreadSafe
import javax.persistence.*
/** A RDBMS backed Uniqueness provider */ /** A RDBMS backed Uniqueness provider */
@ThreadSafe @ThreadSafe
class PersistentUniquenessProvider : UniquenessProvider, SingletonSerializeAsToken() { class PersistentUniquenessProvider : UniquenessProvider, SingletonSerializeAsToken() {
@Entity
@javax.persistence.Table(name = "${NODE_DATABASE_PREFIX}notary_commit_log")
class PersistentUniqueness (
@EmbeddedId
var id: StateRef = StateRef(),
@Column(name = "consuming_transaction_id")
var consumingTxHash: String = "",
@Column(name = "consuming_input_index", length = 36)
var consumingIndex: Int = 0,
@Embedded
var party: Party = Party()
) {
@Embeddable
data class StateRef (
@Column(name = "transaction_id")
var txId: String = "",
@Column(name = "output_index", length = 36)
var index: Int = 0
) : Serializable
@Embeddable
data class Party (
@Column(name = "requesting_party_name")
var name: String = "",
@Column(name = "requesting_party_key", length = 255)
var owningKey: String = ""
) : Serializable
}
private class InnerState {
val committedStates = createMap()
}
private val mutex = ThreadBox(InnerState())
companion object { companion object {
private val TABLE_NAME = "${NODE_DATABASE_PREFIX}notary_commit_log"
private val log = loggerFor<PersistentUniquenessProvider>() private val log = loggerFor<PersistentUniquenessProvider>()
fun createMap(): AppendOnlyPersistentMap<StateRef, UniquenessProvider.ConsumingTx, PersistentUniqueness, PersistentUniqueness.StateRef> {
return AppendOnlyPersistentMap(
toPersistentEntityKey = { PersistentUniqueness.StateRef(it.txhash.toString(), it.index) },
fromPersistentEntity = {
Pair(StateRef(SecureHash.parse(it.id.txId), it.id.index),
UniquenessProvider.ConsumingTx(SecureHash.parse(it.consumingTxHash), it.consumingIndex,
Party(X500Name(it.party.name), parsePublicKeyBase58(it.party.owningKey))))
},
toPersistentEntity = { key: StateRef, value: UniquenessProvider.ConsumingTx ->
PersistentUniqueness().apply {
id = PersistentUniqueness.StateRef(key.txhash.toString(), key.index)
consumingTxHash = value.id.toString()
consumingIndex = value.inputIndex
party = PersistentUniqueness.Party(value.requestingParty.name.toString())
} }
},
/** persistentEntityClass = PersistentUniqueness::class.java
* For each input state store the consuming transaction information.
*/
private object Table : JDBCHashedTable(TABLE_NAME) {
val output = stateRef("transaction_id", "output_index")
val consumingTxHash = secureHash("consuming_transaction_id")
val consumingIndex = integer("consuming_input_index")
val requestingParty = party("requesting_party_name", "requesting_party_key")
}
private val committedStates = ThreadBox(object : AbstractJDBCHashMap<StateRef, UniquenessProvider.ConsumingTx, Table>(Table, loadOnInit = false) {
override fun keyFromRow(row: ResultRow): StateRef = StateRef(row[table.output.txId], row[table.output.index])
override fun valueFromRow(row: ResultRow): UniquenessProvider.ConsumingTx = UniquenessProvider.ConsumingTx(
row[table.consumingTxHash],
row[table.consumingIndex],
Party(X500Name(row[table.requestingParty.name]), row[table.requestingParty.owningKey])
) )
override fun addKeyToInsert(insert: InsertStatement,
entry: Map.Entry<StateRef, UniquenessProvider.ConsumingTx>,
finalizables: MutableList<() -> Unit>) {
insert[table.output.txId] = entry.key.txhash
insert[table.output.index] = entry.key.index
} }
override fun addValueToInsert(insert: InsertStatement,
entry: Map.Entry<StateRef, UniquenessProvider.ConsumingTx>,
finalizables: MutableList<() -> Unit>) {
insert[table.consumingTxHash] = entry.value.id
insert[table.consumingIndex] = entry.value.inputIndex
insert[table.requestingParty.name] = entry.value.requestingParty.name.toString()
insert[table.requestingParty.owningKey] = entry.value.requestingParty.owningKey
} }
})
override fun commit(states: List<StateRef>, txId: SecureHash, callerIdentity: Party) { override fun commit(states: List<StateRef>, txId: SecureHash, callerIdentity: Party) {
val conflict = committedStates.locked {
val conflict = mutex.locked {
val conflictingStates = LinkedHashMap<StateRef, UniquenessProvider.ConsumingTx>() val conflictingStates = LinkedHashMap<StateRef, UniquenessProvider.ConsumingTx>()
for (inputState in states) { for (inputState in states) {
val consumingTx = get(inputState) val consumingTx = committedStates.get(inputState)
if (consumingTx != null) conflictingStates[inputState] = consumingTx if (consumingTx != null) conflictingStates[inputState] = consumingTx
} }
if (conflictingStates.isNotEmpty()) { if (conflictingStates.isNotEmpty()) {
@ -71,7 +99,7 @@ class PersistentUniquenessProvider : UniquenessProvider, SingletonSerializeAsTok
UniquenessProvider.Conflict(conflictingStates) UniquenessProvider.Conflict(conflictingStates)
} else { } else {
states.forEachIndexed { i, stateRef -> states.forEachIndexed { i, stateRef ->
put(stateRef, UniquenessProvider.ConsumingTx(txId, i, callerIdentity)) committedStates[stateRef] = UniquenessProvider.ConsumingTx(txId, i, callerIdentity)
} }
log.debug("Successfully committed all input states: $states") log.debug("Successfully committed all input states: $states")
null null

View File

@ -0,0 +1,113 @@
package net.corda.node.utilities
import net.corda.core.utilities.loggerFor
import java.util.*
/**
* Implements a caching layer on top of an *append-only* table accessed via Hibernate mapping. Note that if the same key is [put] twice the
* behaviour is unpredictable! There is a best-effort check for double inserts, but this should *not* be relied on, so
* ONLY USE THIS IF YOUR TABLE IS APPEND-ONLY
*/
class AppendOnlyPersistentMap<K, V, E, EK> (
val toPersistentEntityKey: (K) -> EK,
val fromPersistentEntity: (E) -> Pair<K,V>,
val toPersistentEntity: (key: K, value: V) -> E,
val persistentEntityClass: Class<E>,
cacheBound: Long = 1024
) { //TODO determine cacheBound based on entity class later or with node config allowing tuning, or using some heuristic based on heap size
private companion object {
val log = loggerFor<AppendOnlyPersistentMap<*, *, *, *>>()
}
private val cache = NonInvalidatingCache<K, Optional<V>>(
bound = cacheBound,
concurrencyLevel = 8,
loadFunction = { key -> Optional.ofNullable(loadValue(key)) }
)
/**
* Returns the value associated with the key, first loading that value from the storage if necessary.
*/
operator fun get(key: K): V? {
return cache.get(key).orElse(null)
}
/**
* Returns all key/value pairs from the underlying storage.
*/
fun allPersisted(): Sequence<Pair<K, V>> {
val criteriaQuery = DatabaseTransactionManager.current().session.criteriaBuilder.createQuery(persistentEntityClass)
val root = criteriaQuery.from(persistentEntityClass)
criteriaQuery.select(root)
val query = DatabaseTransactionManager.current().session.createQuery(criteriaQuery)
val result = query.resultList
return result.map { x -> fromPersistentEntity(x) }.asSequence()
}
private tailrec fun set(key: K, value: V, logWarning: Boolean = true, store: (K,V) -> V?): Boolean {
var insertionAttempt = false
var isUnique = true
val existingInCache = cache.get(key) { // Thread safe, if multiple threads may wait until the first one has loaded.
insertionAttempt = true
// Key wasn't in the cache and might be in the underlying storage.
// Depending on 'store' method, this may insert without checking key duplication or it may avoid inserting a duplicated key.
val existingInDb = store(key, value)
if (existingInDb != null) { // Always reuse an existing value from the storage of a duplicated key.
Optional.of(existingInDb)
} else {
Optional.of(value)
}
}
if (!insertionAttempt) {
if (existingInCache.isPresent) {
// Key already exists in cache, do nothing.
isUnique = false
} else {
// This happens when the key was queried before with no value associated. We invalidate the cached null
// value and recursively call set again. This is to avoid race conditions where another thread queries after
// the invalidate but before the set.
cache.invalidate(key)
return set(key, value, logWarning, store)
}
}
if (logWarning && !isUnique) {
log.warn("Double insert in ${this.javaClass.name} for entity class $persistentEntityClass key $key, not inserting the second time")
}
return isUnique
}
/**
* Puts the value into the map and the underlying storage.
* Inserting the duplicated key may be unpredictable.
*/
operator fun set(key: K, value: V) =
set(key, value, logWarning = false) {
key,value -> DatabaseTransactionManager.current().session.save(toPersistentEntity(key,value))
null
}
/**
* Puts the value into the map and underlying storage.
* Duplicated key is not added into the map and underlying storage.
* @return true if added key was unique, otherwise false
*/
fun addWithDuplicatesAllowed(key: K, value: V): Boolean =
set(key, value) {
key, value ->
val existingEntry = DatabaseTransactionManager.current().session.find(persistentEntityClass, toPersistentEntityKey(key))
if (existingEntry == null) {
DatabaseTransactionManager.current().session.save(toPersistentEntity(key,value))
null
} else {
fromPersistentEntity(existingEntry).second
}
}
private fun loadValue(key: K): V? {
val result = DatabaseTransactionManager.current().session.find(persistentEntityClass, toPersistentEntityKey(key))
return result?.let(fromPersistentEntity)?.second
}
}

View File

@ -2,6 +2,11 @@ package net.corda.node.utilities
import com.zaxxer.hikari.HikariConfig import com.zaxxer.hikari.HikariConfig
import com.zaxxer.hikari.HikariDataSource import com.zaxxer.hikari.HikariDataSource
import net.corda.core.node.services.IdentityService
import net.corda.core.schemas.MappedSchema
import net.corda.node.services.database.HibernateConfiguration
import net.corda.node.services.schema.NodeSchemaService
import org.hibernate.SessionFactory
import org.jetbrains.exposed.sql.Database import org.jetbrains.exposed.sql.Database
import rx.Observable import rx.Observable
@ -15,15 +20,21 @@ import java.util.concurrent.CopyOnWriteArrayList
//HikariDataSource implements Closeable which allows CordaPersistence to be Closeable //HikariDataSource implements Closeable which allows CordaPersistence to be Closeable
class CordaPersistence(var dataSource: HikariDataSource, databaseProperties: Properties): Closeable { class CordaPersistence(var dataSource: HikariDataSource, var nodeSchemaService: NodeSchemaService, val identitySvc: ()-> IdentityService, databaseProperties: Properties): Closeable {
/** Holds Exposed database, the field will be removed once Exposed library is removed */ /** Holds Exposed database, the field will be removed once Exposed library is removed */
lateinit var database: Database lateinit var database: Database
var transactionIsolationLevel = parserTransactionIsolationLevel(databaseProperties.getProperty("transactionIsolationLevel")) var transactionIsolationLevel = parserTransactionIsolationLevel(databaseProperties.getProperty("transactionIsolationLevel"))
val entityManagerFactory: SessionFactory by lazy(LazyThreadSafetyMode.NONE) {
transaction {
HibernateConfiguration(nodeSchemaService, databaseProperties, identitySvc).sessionFactoryForRegisteredSchemas()
}
}
companion object { companion object {
fun connect(dataSource: HikariDataSource, databaseProperties: Properties): CordaPersistence { fun connect(dataSource: HikariDataSource, nodeSchemaService: NodeSchemaService, identitySvc: () -> IdentityService, databaseProperties: Properties): CordaPersistence {
return CordaPersistence(dataSource, databaseProperties).apply { return CordaPersistence(dataSource, nodeSchemaService, identitySvc, databaseProperties).apply {
DatabaseTransactionManager(this) DatabaseTransactionManager(this)
} }
} }
@ -89,10 +100,10 @@ class CordaPersistence(var dataSource: HikariDataSource, databaseProperties: Pro
} }
} }
fun configureDatabase(dataSourceProperties: Properties, databaseProperties: Properties?): CordaPersistence { fun configureDatabase(dataSourceProperties: Properties, databaseProperties: Properties?, entitySchemas: Set<MappedSchema> = emptySet<MappedSchema>(), identitySvc: ()-> IdentityService): CordaPersistence {
val config = HikariConfig(dataSourceProperties) val config = HikariConfig(dataSourceProperties)
val dataSource = HikariDataSource(config) val dataSource = HikariDataSource(config)
val persistence = CordaPersistence.connect(dataSource, databaseProperties ?: Properties()) val persistence = CordaPersistence.connect(dataSource, NodeSchemaService(entitySchemas), identitySvc, databaseProperties ?: Properties())
//org.jetbrains.exposed.sql.Database will be removed once Exposed library is removed //org.jetbrains.exposed.sql.Database will be removed once Exposed library is removed
val database = Database.connect(dataSource) { _ -> ExposedTransactionManager() } val database = Database.connect(dataSource) { _ -> ExposedTransactionManager() }
@ -191,7 +202,6 @@ fun <T : Any> rx.Observable<T>.wrapWithDatabaseTransaction(db: CordaPersistence?
} }
} }
fun parserTransactionIsolationLevel(property: String?): Int = fun parserTransactionIsolationLevel(property: String?): Int =
when (property) { when (property) {
"none" -> Connection.TRANSACTION_NONE "none" -> Connection.TRANSACTION_NONE

View File

@ -1,6 +1,8 @@
package net.corda.node.utilities package net.corda.node.utilities
import co.paralleluniverse.strands.Strand import co.paralleluniverse.strands.Strand
import org.hibernate.Session
import org.hibernate.Transaction
import rx.subjects.PublishSubject import rx.subjects.PublishSubject
import rx.subjects.Subject import rx.subjects.Subject
import java.sql.Connection import java.sql.Connection
@ -21,13 +23,28 @@ class DatabaseTransaction(isolation: Int, val threadLocal: ThreadLocal<DatabaseT
} }
} }
private val sessionDelegate = lazy {
val session = cordaPersistence.entityManagerFactory.withOptions().connection(connection).openSession()
hibernateTransaction = session.beginTransaction()
session
}
val session: Session by sessionDelegate
private lateinit var hibernateTransaction : Transaction
val outerTransaction: DatabaseTransaction? = threadLocal.get() val outerTransaction: DatabaseTransaction? = threadLocal.get()
fun commit() { fun commit() {
if (sessionDelegate.isInitialized()) {
hibernateTransaction.commit()
}
connection.commit() connection.commit()
} }
fun rollback() { fun rollback() {
if (sessionDelegate.isInitialized() && session.isOpen) {
session.clear()
}
if (!connection.isClosed) { if (!connection.isClosed) {
connection.rollback() connection.rollback()
} }

View File

@ -0,0 +1,33 @@
package net.corda.node.utilities
import com.google.common.cache.CacheBuilder
import com.google.common.cache.CacheLoader
import com.google.common.cache.LoadingCache
import com.google.common.util.concurrent.ListenableFuture
class NonInvalidatingCache<K, V> private constructor(
val cache: LoadingCache<K, V>
): LoadingCache<K, V> by cache {
constructor(bound: Long, concurrencyLevel: Int, loadFunction: (K) -> V) :
this(buildCache(bound, concurrencyLevel, loadFunction))
private companion object {
private fun <K, V> buildCache(bound: Long, concurrencyLevel: Int, loadFunction: (K) -> V): LoadingCache<K, V> {
val builder = CacheBuilder.newBuilder().maximumSize(bound).concurrencyLevel(concurrencyLevel)
return builder.build(NonInvalidatingCacheLoader(loadFunction))
}
}
// TODO look into overriding loadAll() if we ever use it
private class NonInvalidatingCacheLoader<K, V>(val loadFunction: (K) -> V) : CacheLoader<K, V>() {
override fun reload(key: K, oldValue: V): ListenableFuture<V> {
throw IllegalStateException("Non invalidating cache refreshed")
}
override fun load(key: K) = loadFunction(key)
override fun loadAll(keys: Iterable<K>): MutableMap<K, V> {
return super.loadAll(keys)
}
}
}

View File

@ -10,7 +10,6 @@ import net.corda.core.schemas.CommonSchemaV1
import net.corda.core.schemas.PersistentStateRef import net.corda.core.schemas.PersistentStateRef
import net.corda.core.serialization.deserialize import net.corda.core.serialization.deserialize
import net.corda.core.transactions.SignedTransaction import net.corda.core.transactions.SignedTransaction
import net.corda.node.services.identity.InMemoryIdentityService
import net.corda.node.services.schema.HibernateObserver import net.corda.node.services.schema.HibernateObserver
import net.corda.node.services.schema.NodeSchemaService import net.corda.node.services.schema.NodeSchemaService
import net.corda.node.services.vault.VaultSchemaV1 import net.corda.node.services.vault.VaultSchemaV1
@ -27,6 +26,7 @@ import net.corda.testing.contracts.fillWithSomeTestLinearStates
import net.corda.testing.node.MockServices import net.corda.testing.node.MockServices
import net.corda.testing.node.makeTestDataSourceProperties import net.corda.testing.node.makeTestDataSourceProperties
import net.corda.testing.node.makeTestDatabaseProperties import net.corda.testing.node.makeTestDatabaseProperties
import net.corda.testing.node.makeTestIdentityService
import net.corda.testing.schemas.DummyLinearStateSchemaV1 import net.corda.testing.schemas.DummyLinearStateSchemaV1
import net.corda.testing.schemas.DummyLinearStateSchemaV2 import net.corda.testing.schemas.DummyLinearStateSchemaV2
import org.assertj.core.api.Assertions import org.assertj.core.api.Assertions
@ -65,11 +65,10 @@ class HibernateConfigurationTest : TestDependencyInjectionBase() {
issuerServices = MockServices(DUMMY_CASH_ISSUER_KEY, BOB_KEY, BOC_KEY) issuerServices = MockServices(DUMMY_CASH_ISSUER_KEY, BOB_KEY, BOC_KEY)
val dataSourceProps = makeTestDataSourceProperties() val dataSourceProps = makeTestDataSourceProperties()
val defaultDatabaseProperties = makeTestDatabaseProperties() val defaultDatabaseProperties = makeTestDatabaseProperties()
database = configureDatabase(dataSourceProps, defaultDatabaseProperties) database = configureDatabase(dataSourceProps, defaultDatabaseProperties, identitySvc = ::makeTestIdentityService)
val customSchemas = setOf(VaultSchemaV1, CashSchemaV1, SampleCashSchemaV2, SampleCashSchemaV3) val customSchemas = setOf(VaultSchemaV1, CashSchemaV1, SampleCashSchemaV2, SampleCashSchemaV3)
database.transaction { database.transaction {
val identityService = InMemoryIdentityService(MOCK_IDENTITIES, trustRoot = DUMMY_CA.certificate) hibernateConfig = HibernateConfiguration(NodeSchemaService(customSchemas), makeTestDatabaseProperties(), ::makeTestIdentityService)
hibernateConfig = HibernateConfiguration(NodeSchemaService(customSchemas), makeTestDatabaseProperties(), identityService)
services = object : MockServices(BOB_KEY, BOC_KEY, DUMMY_NOTARY_KEY) { services = object : MockServices(BOB_KEY, BOC_KEY, DUMMY_NOTARY_KEY) {
override val vaultService: VaultService = makeVaultService(dataSourceProps, hibernateConfig) override val vaultService: VaultService = makeVaultService(dataSourceProps, hibernateConfig)

View File

@ -17,13 +17,11 @@ import net.corda.node.services.vault.schemas.requery.VaultSchema
import net.corda.node.services.vault.schemas.requery.VaultStatesEntity import net.corda.node.services.vault.schemas.requery.VaultStatesEntity
import net.corda.node.utilities.CordaPersistence import net.corda.node.utilities.CordaPersistence
import net.corda.node.utilities.configureDatabase import net.corda.node.utilities.configureDatabase
import net.corda.testing.ALICE_PUBKEY import net.corda.testing.*
import net.corda.testing.DUMMY_NOTARY
import net.corda.testing.DUMMY_PUBKEY_1
import net.corda.testing.TestDependencyInjectionBase
import net.corda.testing.contracts.DummyContract import net.corda.testing.contracts.DummyContract
import net.corda.testing.node.makeTestDataSourceProperties import net.corda.testing.node.makeTestDataSourceProperties
import net.corda.testing.node.makeTestDatabaseProperties import net.corda.testing.node.makeTestDatabaseProperties
import net.corda.testing.node.makeTestIdentityService
import org.assertj.core.api.Assertions import org.assertj.core.api.Assertions
import org.junit.After import org.junit.After
import org.junit.Assert.assertEquals import org.junit.Assert.assertEquals
@ -42,7 +40,7 @@ class RequeryConfigurationTest : TestDependencyInjectionBase() {
@Before @Before
fun setUp() { fun setUp() {
val dataSourceProperties = makeTestDataSourceProperties() val dataSourceProperties = makeTestDataSourceProperties()
database = configureDatabase(dataSourceProperties, makeTestDatabaseProperties()) database = configureDatabase(dataSourceProperties, makeTestDatabaseProperties(), identitySvc = ::makeTestIdentityService)
newTransactionStorage() newTransactionStorage()
newRequeryStorage(dataSourceProperties) newRequeryStorage(dataSourceProperties)
} }

View File

@ -10,9 +10,6 @@ import net.corda.core.node.ServiceHub
import net.corda.core.node.services.VaultService import net.corda.core.node.services.VaultService
import net.corda.core.serialization.SingletonSerializeAsToken import net.corda.core.serialization.SingletonSerializeAsToken
import net.corda.core.transactions.TransactionBuilder import net.corda.core.transactions.TransactionBuilder
import net.corda.testing.ALICE_KEY
import net.corda.testing.DUMMY_CA
import net.corda.testing.DUMMY_NOTARY
import net.corda.node.services.MockServiceHubInternal import net.corda.node.services.MockServiceHubInternal
import net.corda.node.services.identity.InMemoryIdentityService import net.corda.node.services.identity.InMemoryIdentityService
import net.corda.node.services.persistence.DBCheckpointStorage import net.corda.node.services.persistence.DBCheckpointStorage
@ -22,14 +19,11 @@ import net.corda.node.services.vault.NodeVaultService
import net.corda.node.utilities.AffinityExecutor import net.corda.node.utilities.AffinityExecutor
import net.corda.node.utilities.CordaPersistence import net.corda.node.utilities.CordaPersistence
import net.corda.node.utilities.configureDatabase import net.corda.node.utilities.configureDatabase
import net.corda.testing.*
import net.corda.testing.node.InMemoryMessagingNetwork import net.corda.testing.node.InMemoryMessagingNetwork
import net.corda.testing.node.MockKeyManagementService import net.corda.testing.node.MockKeyManagementService
import net.corda.testing.getTestX509Name
import net.corda.testing.testNodeConfiguration
import net.corda.testing.initialiseTestSerialization
import net.corda.testing.node.* import net.corda.testing.node.*
import net.corda.testing.node.TestClock import net.corda.testing.node.TestClock
import net.corda.testing.resetTestSerialization
import org.assertj.core.api.Assertions.assertThat import org.assertj.core.api.Assertions.assertThat
import org.bouncycastle.asn1.x500.X500Name import org.bouncycastle.asn1.x500.X500Name
import org.junit.After import org.junit.After
@ -77,7 +71,7 @@ class NodeSchedulerServiceTest : SingletonSerializeAsToken() {
smmHasRemovedAllFlows = CountDownLatch(1) smmHasRemovedAllFlows = CountDownLatch(1)
calls = 0 calls = 0
val dataSourceProps = makeTestDataSourceProperties() val dataSourceProps = makeTestDataSourceProperties()
database = configureDatabase(dataSourceProps, makeTestDatabaseProperties()) database = configureDatabase(dataSourceProps, makeTestDatabaseProperties(), identitySvc = ::makeTestIdentityService)
val identityService = InMemoryIdentityService(trustRoot = DUMMY_CA.certificate) val identityService = InMemoryIdentityService(trustRoot = DUMMY_CA.certificate)
val kms = MockKeyManagementService(identityService, ALICE_KEY) val kms = MockKeyManagementService(identityService, ALICE_KEY)

View File

@ -13,6 +13,7 @@ import net.corda.node.services.api.DEFAULT_SESSION_ID
import net.corda.node.services.api.MonitoringService import net.corda.node.services.api.MonitoringService
import net.corda.node.services.config.NodeConfiguration import net.corda.node.services.config.NodeConfiguration
import net.corda.node.services.config.configureWithDevSSLCertificate import net.corda.node.services.config.configureWithDevSSLCertificate
import net.corda.node.services.identity.InMemoryIdentityService
import net.corda.node.services.network.InMemoryNetworkMapCache import net.corda.node.services.network.InMemoryNetworkMapCache
import net.corda.node.services.network.NetworkMapService import net.corda.node.services.network.NetworkMapService
import net.corda.node.services.transactions.PersistentUniquenessProvider import net.corda.node.services.transactions.PersistentUniquenessProvider
@ -23,6 +24,7 @@ import net.corda.testing.*
import net.corda.testing.node.MOCK_VERSION_INFO import net.corda.testing.node.MOCK_VERSION_INFO
import net.corda.testing.node.makeTestDataSourceProperties import net.corda.testing.node.makeTestDataSourceProperties
import net.corda.testing.node.makeTestDatabaseProperties import net.corda.testing.node.makeTestDatabaseProperties
import net.corda.testing.node.makeTestIdentityService
import org.assertj.core.api.Assertions.assertThat import org.assertj.core.api.Assertions.assertThat
import org.assertj.core.api.Assertions.assertThatThrownBy import org.assertj.core.api.Assertions.assertThatThrownBy
import org.junit.After import org.junit.After
@ -69,7 +71,7 @@ class ArtemisMessagingTests : TestDependencyInjectionBase() {
baseDirectory = baseDirectory, baseDirectory = baseDirectory,
myLegalName = ALICE.name) myLegalName = ALICE.name)
LogHelper.setLevel(PersistentUniquenessProvider::class) LogHelper.setLevel(PersistentUniquenessProvider::class)
database = configureDatabase(makeTestDataSourceProperties(), makeTestDatabaseProperties()) database = configureDatabase(makeTestDataSourceProperties(), makeTestDatabaseProperties(), identitySvc = ::makeTestIdentityService)
networkMapRegistrationFuture = doneFuture(Unit) networkMapRegistrationFuture = doneFuture(Unit)
} }

View File

@ -115,12 +115,12 @@ class InMemoryIdentityServiceTests {
service.verifyAndRegisterAnonymousIdentity(aliceTxIdentity, alice.party) service.verifyAndRegisterAnonymousIdentity(aliceTxIdentity, alice.party)
var actual = service.anonymousFromKey(aliceTxIdentity.party.owningKey) var actual = service.anonymousFromKey(aliceTxIdentity.party.owningKey)
assertEquals<AnonymousPartyAndPath>(aliceTxIdentity, actual!!) assertEquals(aliceTxIdentity, actual!!)
assertNull(service.anonymousFromKey(bobTxIdentity.party.owningKey)) assertNull(service.anonymousFromKey(bobTxIdentity.party.owningKey))
service.verifyAndRegisterAnonymousIdentity(bobTxIdentity, bob.party) service.verifyAndRegisterAnonymousIdentity(bobTxIdentity, bob.party)
actual = service.anonymousFromKey(bobTxIdentity.party.owningKey) actual = service.anonymousFromKey(bobTxIdentity.party.owningKey)
assertEquals<AnonymousPartyAndPath>(bobTxIdentity, actual!!) assertEquals(bobTxIdentity, actual!!)
} }
/** /**
@ -131,37 +131,28 @@ class InMemoryIdentityServiceTests {
fun `assert ownership`() { fun `assert ownership`() {
withTestSerialization { withTestSerialization {
val trustRoot = DUMMY_CA val trustRoot = DUMMY_CA
val (alice, aliceTxIdentity) = createParty(ALICE.name, trustRoot) val (alice, anonymousAlice) = createParty(ALICE.name, trustRoot)
val (bob, anonymousBob) = createParty(BOB.name, trustRoot)
val certFactory = CertificateFactory.getInstance("X509")
val bobRootKey = Crypto.generateKeyPair()
val bobRoot = getTestPartyAndCertificate(BOB.name, bobRootKey.public)
val bobRootCert = bobRoot.certificate
val bobTxKey = Crypto.generateKeyPair()
val bobTxCert = X509Utilities.createCertificate(CertificateType.IDENTITY, bobRootCert, bobRootKey, BOB.name, bobTxKey.public)
val bobCertPath = certFactory.generateCertPath(listOf(bobTxCert.cert, bobRootCert.cert))
val bob = PartyAndCertificate(BOB.name, bobRootKey.public, bobRootCert, bobCertPath)
// Now we have identities, construct the service and let it know about both // Now we have identities, construct the service and let it know about both
val service = InMemoryIdentityService(setOf(alice, bob), emptyMap(), trustRoot.certificate.cert) val service = InMemoryIdentityService(setOf(alice, bob), emptyMap(), trustRoot.certificate.cert)
service.verifyAndRegisterAnonymousIdentity(aliceTxIdentity, alice.party)
val anonymousBob = AnonymousPartyAndPath(AnonymousParty(bobTxKey.public),bobCertPath) service.verifyAndRegisterAnonymousIdentity(anonymousAlice, alice.party)
service.verifyAndRegisterAnonymousIdentity(anonymousBob, bob.party) service.verifyAndRegisterAnonymousIdentity(anonymousBob, bob.party)
// Verify that paths are verified // Verify that paths are verified
service.assertOwnership(alice.party, aliceTxIdentity.party) service.assertOwnership(alice.party, anonymousAlice.party)
service.assertOwnership(bob.party, anonymousBob.party) service.assertOwnership(bob.party, anonymousBob.party)
assertFailsWith<IllegalArgumentException> { assertFailsWith<IllegalArgumentException> {
service.assertOwnership(alice.party, anonymousBob.party) service.assertOwnership(alice.party, anonymousBob.party)
} }
assertFailsWith<IllegalArgumentException> { assertFailsWith<IllegalArgumentException> {
service.assertOwnership(bob.party, aliceTxIdentity.party) service.assertOwnership(bob.party, anonymousAlice.party)
} }
assertFailsWith<IllegalArgumentException> { assertFailsWith<IllegalArgumentException> {
val owningKey = Crypto.decodePublicKey(trustRoot.certificate.subjectPublicKeyInfo.encoded) val owningKey = Crypto.decodePublicKey(trustRoot.certificate.subjectPublicKeyInfo.encoded)
service.assertOwnership(Party(trustRoot.certificate.subject, owningKey), aliceTxIdentity.party) service.assertOwnership(Party(trustRoot.certificate.subject, owningKey), anonymousAlice.party)
} }
} }
} }

View File

@ -11,8 +11,8 @@ import net.corda.testing.LogHelper
import net.corda.testing.TestDependencyInjectionBase import net.corda.testing.TestDependencyInjectionBase
import net.corda.testing.node.makeTestDataSourceProperties import net.corda.testing.node.makeTestDataSourceProperties
import net.corda.testing.node.makeTestDatabaseProperties import net.corda.testing.node.makeTestDatabaseProperties
import net.corda.testing.node.makeTestIdentityService
import org.assertj.core.api.Assertions.assertThat import org.assertj.core.api.Assertions.assertThat
import org.assertj.core.api.Assertions.assertThatExceptionOfType
import org.junit.After import org.junit.After
import org.junit.Before import org.junit.Before
import org.junit.Test import org.junit.Test
@ -33,7 +33,7 @@ class DBCheckpointStorageTests : TestDependencyInjectionBase() {
@Before @Before
fun setUp() { fun setUp() {
LogHelper.setLevel(PersistentUniquenessProvider::class) LogHelper.setLevel(PersistentUniquenessProvider::class)
database = configureDatabase(makeTestDataSourceProperties(), makeTestDatabaseProperties()) database = configureDatabase(makeTestDataSourceProperties(), makeTestDatabaseProperties(), identitySvc = ::makeTestIdentityService)
newCheckpointStorage() newCheckpointStorage()
} }
@ -94,16 +94,6 @@ class DBCheckpointStorageTests : TestDependencyInjectionBase() {
} }
} }
@Test
fun `remove unknown checkpoint`() {
val checkpoint = newCheckpoint()
database.transaction {
assertThatExceptionOfType(IllegalArgumentException::class.java).isThrownBy {
checkpointStorage.removeCheckpoint(checkpoint)
}
}
}
@Test @Test
fun `add two checkpoints then remove first one`() { fun `add two checkpoints then remove first one`() {
val firstCheckpoint = newCheckpoint() val firstCheckpoint = newCheckpoint()

View File

@ -4,19 +4,28 @@ import net.corda.core.contracts.StateRef
import net.corda.core.crypto.Crypto import net.corda.core.crypto.Crypto
import net.corda.core.crypto.SecureHash import net.corda.core.crypto.SecureHash
import net.corda.core.crypto.SignatureMetadata import net.corda.core.crypto.SignatureMetadata
import net.corda.core.node.services.VaultService
import net.corda.core.crypto.TransactionSignature import net.corda.core.crypto.TransactionSignature
import net.corda.core.schemas.MappedSchema
import net.corda.core.toFuture import net.corda.core.toFuture
import net.corda.core.transactions.SignedTransaction import net.corda.core.transactions.SignedTransaction
import net.corda.core.transactions.WireTransaction import net.corda.core.transactions.WireTransaction
import net.corda.node.services.database.HibernateConfiguration
import net.corda.node.services.schema.HibernateObserver
import net.corda.node.services.schema.NodeSchemaService
import net.corda.node.services.transactions.PersistentUniquenessProvider import net.corda.node.services.transactions.PersistentUniquenessProvider
import net.corda.node.services.vault.NodeVaultService
import net.corda.node.services.vault.VaultSchemaV1
import net.corda.node.utilities.CordaPersistence import net.corda.node.utilities.CordaPersistence
import net.corda.node.utilities.configureDatabase import net.corda.node.utilities.configureDatabase
import net.corda.testing.ALICE_PUBKEY import net.corda.schemas.CashSchemaV1
import net.corda.testing.DUMMY_NOTARY import net.corda.schemas.SampleCashSchemaV2
import net.corda.testing.LogHelper import net.corda.schemas.SampleCashSchemaV3
import net.corda.testing.TestDependencyInjectionBase import net.corda.testing.*
import net.corda.testing.node.MockServices
import net.corda.testing.node.makeTestDataSourceProperties import net.corda.testing.node.makeTestDataSourceProperties
import net.corda.testing.node.makeTestDatabaseProperties import net.corda.testing.node.makeTestDatabaseProperties
import net.corda.testing.node.makeTestIdentityService
import org.assertj.core.api.Assertions.assertThat import org.assertj.core.api.Assertions.assertThat
import org.junit.After import org.junit.After
import org.junit.Before import org.junit.Before
@ -27,11 +36,43 @@ import kotlin.test.assertEquals
class DBTransactionStorageTests : TestDependencyInjectionBase() { class DBTransactionStorageTests : TestDependencyInjectionBase() {
lateinit var database: CordaPersistence lateinit var database: CordaPersistence
lateinit var transactionStorage: DBTransactionStorage lateinit var transactionStorage: DBTransactionStorage
lateinit var services: MockServices
val vault: VaultService get() = services.vaultService
// Hibernate configuration objects
lateinit var hibernateConfig: HibernateConfiguration
@Before @Before
fun setUp() { fun setUp() {
LogHelper.setLevel(PersistentUniquenessProvider::class) LogHelper.setLevel(PersistentUniquenessProvider::class)
database = configureDatabase(makeTestDataSourceProperties(), makeTestDatabaseProperties()) val dataSourceProps = makeTestDataSourceProperties()
val transactionSchema = MappedSchema(schemaFamily = javaClass, version = 1,
mappedTypes = listOf(DBTransactionStorage.DBTransaction::class.java))
val customSchemas = setOf(VaultSchemaV1, CashSchemaV1, SampleCashSchemaV2, SampleCashSchemaV3, transactionSchema)
database = configureDatabase(dataSourceProps, makeTestDatabaseProperties(), customSchemas, identitySvc = ::makeTestIdentityService)
database.transaction {
hibernateConfig = HibernateConfiguration(NodeSchemaService(customSchemas), makeTestDatabaseProperties(), identitySvc = ::makeTestIdentityService)
services = object : MockServices(BOB_KEY) {
override val vaultService: VaultService get() {
val vaultService = NodeVaultService(this, dataSourceProps, makeTestDatabaseProperties())
hibernatePersister = HibernateObserver(vaultService.rawUpdates, hibernateConfig)
return vaultService
}
override fun recordTransactions(txs: Iterable<SignedTransaction>) {
for (stx in txs) {
validatedTransactions.addTransaction(stx)
}
// Refactored to use notifyAll() as we have no other unit test for that method with multiple transactions.
vaultService.notifyAll(txs.map { it.tx })
}
}
}
newTransactionStorage() newTransactionStorage()
} }
@ -120,6 +161,37 @@ class DBTransactionStorageTests : TestDependencyInjectionBase() {
} }
} }
@Test
fun `transaction saved twice in same DB transaction scope`() {
val firstTransaction = newTransaction()
database.transaction {
transactionStorage.addTransaction(firstTransaction)
transactionStorage.addTransaction(firstTransaction)
}
assertTransactionIsRetrievable(firstTransaction)
database.transaction {
assertThat(transactionStorage.transactions).containsOnly(firstTransaction)
}
}
@Test
fun `transaction saved twice in two DB transaction scopes`() {
val firstTransaction = newTransaction()
val secondTransaction = newTransaction()
database.transaction {
transactionStorage.addTransaction(firstTransaction)
}
database.transaction {
transactionStorage.addTransaction(secondTransaction)
transactionStorage.addTransaction(firstTransaction)
}
assertTransactionIsRetrievable(firstTransaction)
database.transaction {
assertThat(transactionStorage.transactions).containsOnly(firstTransaction, secondTransaction)
}
}
@Test @Test
fun `updates are fired`() { fun `updates are fired`() {
val future = transactionStorage.updates.toFuture() val future = transactionStorage.updates.toFuture()

View File

@ -17,6 +17,7 @@ import net.corda.node.utilities.CordaPersistence
import net.corda.node.utilities.configureDatabase import net.corda.node.utilities.configureDatabase
import net.corda.testing.node.makeTestDataSourceProperties import net.corda.testing.node.makeTestDataSourceProperties
import net.corda.testing.node.makeTestDatabaseProperties import net.corda.testing.node.makeTestDatabaseProperties
import net.corda.testing.node.makeTestIdentityService
import org.junit.After import org.junit.After
import org.junit.Before import org.junit.Before
import org.junit.Test import org.junit.Test
@ -43,7 +44,7 @@ class NodeAttachmentStorageTest {
LogHelper.setLevel(PersistentUniquenessProvider::class) LogHelper.setLevel(PersistentUniquenessProvider::class)
dataSourceProperties = makeTestDataSourceProperties() dataSourceProperties = makeTestDataSourceProperties()
database = configureDatabase(dataSourceProperties, makeTestDatabaseProperties()) database = configureDatabase(dataSourceProperties, makeTestDatabaseProperties(), identitySvc = ::makeTestIdentityService)
configuration = RequeryConfiguration(dataSourceProperties, databaseProperties = makeTestDatabaseProperties()) configuration = RequeryConfiguration(dataSourceProperties, databaseProperties = makeTestDatabaseProperties())
fs = Jimfs.newFileSystem(Configuration.unix()) fs = Jimfs.newFileSystem(Configuration.unix())

View File

@ -18,6 +18,7 @@ import net.corda.testing.MEGA_CORP
import net.corda.testing.MOCK_IDENTITIES import net.corda.testing.MOCK_IDENTITIES
import net.corda.testing.node.makeTestDataSourceProperties import net.corda.testing.node.makeTestDataSourceProperties
import net.corda.testing.node.makeTestDatabaseProperties import net.corda.testing.node.makeTestDatabaseProperties
import net.corda.testing.node.makeTestIdentityService
import org.hibernate.annotations.Cascade import org.hibernate.annotations.Cascade
import org.hibernate.annotations.CascadeType import org.hibernate.annotations.CascadeType
import org.jetbrains.exposed.sql.transactions.TransactionManager import org.jetbrains.exposed.sql.transactions.TransactionManager
@ -35,7 +36,7 @@ class HibernateObserverTests {
@Before @Before
fun setUp() { fun setUp() {
LogHelper.setLevel(HibernateObserver::class) LogHelper.setLevel(HibernateObserver::class)
database = configureDatabase(makeTestDataSourceProperties(), makeTestDatabaseProperties()) database = configureDatabase(makeTestDataSourceProperties(), makeTestDatabaseProperties(), identitySvc = ::makeTestIdentityService)
} }
@After @After
@ -105,8 +106,7 @@ class HibernateObserverTests {
} }
@Suppress("UNUSED_VARIABLE") @Suppress("UNUSED_VARIABLE")
val identityService = InMemoryIdentityService(MOCK_IDENTITIES, trustRoot = DUMMY_CA.certificate) val observer = HibernateObserver(rawUpdatesPublisher, HibernateConfiguration(schemaService, makeTestDatabaseProperties(), ::makeTestIdentityService))
val observer = HibernateObserver(rawUpdatesPublisher, HibernateConfiguration(schemaService, makeTestDatabaseProperties(), identityService))
database.transaction { database.transaction {
rawUpdatesPublisher.onNext(Vault.Update(emptySet(), setOf(StateAndRef(TransactionState(TestState(), MEGA_CORP), StateRef(SecureHash.sha256("dummy"), 0))))) rawUpdatesPublisher.onNext(Vault.Update(emptySet(), setOf(StateAndRef(TransactionState(TestState(), MEGA_CORP), StateRef(SecureHash.sha256("dummy"), 0)))))
val parentRowCountResult = TransactionManager.current().connection.prepareStatement("select count(*) from Parents").executeQuery() val parentRowCountResult = TransactionManager.current().connection.prepareStatement("select count(*) from Parents").executeQuery()

View File

@ -11,11 +11,10 @@ import net.corda.core.utilities.getOrThrow
import net.corda.node.services.network.NetworkMapService import net.corda.node.services.network.NetworkMapService
import net.corda.node.utilities.CordaPersistence import net.corda.node.utilities.CordaPersistence
import net.corda.node.utilities.configureDatabase import net.corda.node.utilities.configureDatabase
import net.corda.testing.LogHelper import net.corda.testing.*
import net.corda.testing.TestDependencyInjectionBase
import net.corda.testing.freeLocalHostAndPort
import net.corda.testing.node.makeTestDataSourceProperties import net.corda.testing.node.makeTestDataSourceProperties
import net.corda.testing.node.makeTestDatabaseProperties import net.corda.testing.node.makeTestDatabaseProperties
import net.corda.testing.node.makeTestIdentityService
import org.jetbrains.exposed.sql.Transaction import org.jetbrains.exposed.sql.Transaction
import org.junit.After import org.junit.After
import org.junit.Before import org.junit.Before
@ -35,7 +34,7 @@ class DistributedImmutableMapTests : TestDependencyInjectionBase() {
fun setup() { fun setup() {
LogHelper.setLevel("-org.apache.activemq") LogHelper.setLevel("-org.apache.activemq")
LogHelper.setLevel(NetworkMapService::class) LogHelper.setLevel(NetworkMapService::class)
database = configureDatabase(makeTestDataSourceProperties(), makeTestDatabaseProperties()) database = configureDatabase(makeTestDataSourceProperties(), makeTestDatabaseProperties(), identitySvc = ::makeTestIdentityService)
cluster = setUpCluster() cluster = setUpCluster()
} }

View File

@ -4,12 +4,10 @@ import net.corda.core.crypto.SecureHash
import net.corda.core.node.services.UniquenessException import net.corda.core.node.services.UniquenessException
import net.corda.node.utilities.CordaPersistence import net.corda.node.utilities.CordaPersistence
import net.corda.node.utilities.configureDatabase import net.corda.node.utilities.configureDatabase
import net.corda.testing.LogHelper import net.corda.testing.*
import net.corda.testing.MEGA_CORP
import net.corda.testing.TestDependencyInjectionBase
import net.corda.testing.generateStateRef
import net.corda.testing.node.makeTestDataSourceProperties import net.corda.testing.node.makeTestDataSourceProperties
import net.corda.testing.node.makeTestDatabaseProperties import net.corda.testing.node.makeTestDatabaseProperties
import net.corda.testing.node.makeTestIdentityService
import org.junit.After import org.junit.After
import org.junit.Before import org.junit.Before
import org.junit.Test import org.junit.Test
@ -25,7 +23,7 @@ class PersistentUniquenessProviderTests : TestDependencyInjectionBase() {
@Before @Before
fun setUp() { fun setUp() {
LogHelper.setLevel(PersistentUniquenessProvider::class) LogHelper.setLevel(PersistentUniquenessProvider::class)
database = configureDatabase(makeTestDataSourceProperties(), makeTestDatabaseProperties()) database = configureDatabase(makeTestDataSourceProperties(), makeTestDatabaseProperties(), identitySvc = ::makeTestIdentityService)
} }
@After @After

View File

@ -15,13 +15,9 @@ import net.corda.core.node.services.*
import net.corda.core.node.services.vault.* import net.corda.core.node.services.vault.*
import net.corda.core.node.services.vault.QueryCriteria.* import net.corda.core.node.services.vault.QueryCriteria.*
import net.corda.core.utilities.seconds import net.corda.core.utilities.seconds
import net.corda.core.transactions.SignedTransaction
import net.corda.core.utilities.NonEmptySet import net.corda.core.utilities.NonEmptySet
import net.corda.core.utilities.OpaqueBytes import net.corda.core.utilities.OpaqueBytes
import net.corda.core.utilities.toHexString import net.corda.core.utilities.toHexString
import net.corda.node.services.database.HibernateConfiguration
import net.corda.node.services.identity.InMemoryIdentityService
import net.corda.node.services.schema.NodeSchemaService
import net.corda.core.utilities.* import net.corda.core.utilities.*
import net.corda.node.utilities.CordaPersistence import net.corda.node.utilities.CordaPersistence
import net.corda.node.utilities.configureDatabase import net.corda.node.utilities.configureDatabase
@ -34,6 +30,7 @@ import net.corda.testing.contracts.*
import net.corda.testing.node.MockServices import net.corda.testing.node.MockServices
import net.corda.testing.node.makeTestDatabaseAndMockServices import net.corda.testing.node.makeTestDatabaseAndMockServices
import net.corda.testing.node.makeTestDatabaseProperties import net.corda.testing.node.makeTestDatabaseProperties
import net.corda.testing.node.makeTestIdentityService
import net.corda.testing.schemas.DummyLinearStateSchemaV1 import net.corda.testing.schemas.DummyLinearStateSchemaV1
import org.assertj.core.api.Assertions import org.assertj.core.api.Assertions
import org.assertj.core.api.Assertions.assertThat import org.assertj.core.api.Assertions.assertThat
@ -77,7 +74,7 @@ class VaultQueryTests : TestDependencyInjectionBase() {
@Ignore @Ignore
@Test @Test
fun createPersistentTestDb() { fun createPersistentTestDb() {
val database = configureDatabase(makePersistentDataSourceProperties(), makeTestDatabaseProperties()) val database = configureDatabase(makePersistentDataSourceProperties(), makeTestDatabaseProperties(), identitySvc = ::makeTestIdentityService)
setUpDb(database, 5000) setUpDb(database, 5000)

View File

@ -5,6 +5,7 @@ import net.corda.core.internal.bufferUntilSubscribed
import net.corda.core.internal.tee import net.corda.core.internal.tee
import net.corda.testing.node.makeTestDataSourceProperties import net.corda.testing.node.makeTestDataSourceProperties
import net.corda.testing.node.makeTestDatabaseProperties import net.corda.testing.node.makeTestDatabaseProperties
import net.corda.testing.node.makeTestIdentityService
import org.assertj.core.api.Assertions.assertThat import org.assertj.core.api.Assertions.assertThat
import org.junit.After import org.junit.After
import org.junit.Test import org.junit.Test
@ -20,7 +21,7 @@ class ObservablesTests {
val toBeClosed = mutableListOf<Closeable>() val toBeClosed = mutableListOf<Closeable>()
fun createDatabase(): CordaPersistence { fun createDatabase(): CordaPersistence {
val database = configureDatabase(makeTestDataSourceProperties(), makeTestDatabaseProperties()) val database = configureDatabase(makeTestDataSourceProperties(), makeTestDatabaseProperties(), identitySvc = ::makeTestIdentityService)
toBeClosed += database toBeClosed += database
return database return database
} }

View File

@ -398,7 +398,7 @@ class X509UtilitiesTest {
@Test @Test
fun `serialize - deserialize X509CertififcateHolder`() { fun `serialize - deserialize X509CertififcateHolder`() {
val factory = SerializationFactoryImpl().apply { registerScheme(KryoServerSerializationScheme()) } val factory = SerializationFactoryImpl().apply { registerScheme(KryoServerSerializationScheme(this)) }
val context = SerializationContextImpl(KryoHeaderV0_1, val context = SerializationContextImpl(KryoHeaderV0_1,
javaClass.classLoader, javaClass.classLoader,
AllWhitelist, AllWhitelist,
@ -413,7 +413,7 @@ class X509UtilitiesTest {
@Test @Test
fun `serialize - deserialize X509CertPath`() { fun `serialize - deserialize X509CertPath`() {
val factory = SerializationFactoryImpl().apply { registerScheme(KryoServerSerializationScheme()) } val factory = SerializationFactoryImpl().apply { registerScheme(KryoServerSerializationScheme(this)) }
val context = SerializationContextImpl(KryoHeaderV0_1, val context = SerializationContextImpl(KryoHeaderV0_1,
javaClass.classLoader, javaClass.classLoader,
AllWhitelist, AllWhitelist,

View File

@ -18,10 +18,7 @@ import net.corda.irs.flows.RatesFixFlow
import net.corda.node.utilities.CordaPersistence import net.corda.node.utilities.CordaPersistence
import net.corda.node.utilities.configureDatabase import net.corda.node.utilities.configureDatabase
import net.corda.testing.* import net.corda.testing.*
import net.corda.testing.node.MockNetwork import net.corda.testing.node.*
import net.corda.testing.node.MockServices
import net.corda.testing.node.makeTestDataSourceProperties
import net.corda.testing.node.makeTestDatabaseProperties
import org.bouncycastle.asn1.x500.X500Name import org.bouncycastle.asn1.x500.X500Name
import org.junit.After import org.junit.After
import org.junit.Assert import org.junit.Assert
@ -60,7 +57,7 @@ class NodeInterestRatesTest : TestDependencyInjectionBase() {
@Before @Before
fun setUp() { fun setUp() {
database = configureDatabase(makeTestDataSourceProperties(), makeTestDatabaseProperties()) database = configureDatabase(makeTestDataSourceProperties(), makeTestDatabaseProperties(), identitySvc = ::makeTestIdentityService)
database.transaction { database.transaction {
oracle = NodeInterestRates.Oracle( oracle = NodeInterestRates.Oracle(
MEGA_CORP, MEGA_CORP,

View File

@ -6,7 +6,7 @@ import net.corda.core.utilities.ByteSequence
import net.corda.node.serialization.KryoServerSerializationScheme import net.corda.node.serialization.KryoServerSerializationScheme
import net.corda.nodeapi.internal.serialization.* import net.corda.nodeapi.internal.serialization.*
fun <T> withTestSerialization(block: () -> T): T { inline fun <T> withTestSerialization(block: () -> T): T {
initialiseTestSerialization() initialiseTestSerialization()
try { try {
return block() return block()
@ -61,8 +61,8 @@ fun initialiseTestSerialization() {
// Now configure all the testing related delegates. // Now configure all the testing related delegates.
(SerializationDefaults.SERIALIZATION_FACTORY as TestSerializationFactory).delegate = SerializationFactoryImpl().apply { (SerializationDefaults.SERIALIZATION_FACTORY as TestSerializationFactory).delegate = SerializationFactoryImpl().apply {
registerScheme(KryoClientSerializationScheme()) registerScheme(KryoClientSerializationScheme(this))
registerScheme(KryoServerSerializationScheme()) registerScheme(KryoServerSerializationScheme(this))
registerScheme(AMQPClientSerializationScheme()) registerScheme(AMQPClientSerializationScheme())
registerScheme(AMQPServerSerializationScheme()) registerScheme(AMQPServerSerializationScheme())
} }
@ -139,4 +139,8 @@ class TestSerializationContext : SerializationContext {
override fun withWhitelisted(clazz: Class<*>): SerializationContext { override fun withWhitelisted(clazz: Class<*>): SerializationContext {
return TestSerializationContext().apply { delegate = this@TestSerializationContext.delegate!!.withWhitelisted(clazz) } return TestSerializationContext().apply { delegate = this@TestSerializationContext.delegate!!.withWhitelisted(clazz) }
} }
override fun withPreferredSerializationVersion(versionHeader: ByteSequence): SerializationContext {
return TestSerializationContext().apply { delegate = this@TestSerializationContext.delegate!!.withPreferredSerializationVersion(versionHeader) }
}
} }

View File

@ -88,7 +88,7 @@ open class MockServices(vararg val keys: KeyPair) : ServiceHub {
lateinit var hibernatePersister: HibernateObserver lateinit var hibernatePersister: HibernateObserver
fun makeVaultService(dataSourceProps: Properties, hibernateConfig: HibernateConfiguration = HibernateConfiguration(NodeSchemaService(), makeTestDatabaseProperties(), identityService)): VaultService { fun makeVaultService(dataSourceProps: Properties, hibernateConfig: HibernateConfiguration = HibernateConfiguration(NodeSchemaService(), makeTestDatabaseProperties(), { identityService })): VaultService {
val vaultService = NodeVaultService(this, dataSourceProps, makeTestDatabaseProperties()) val vaultService = NodeVaultService(this, dataSourceProps, makeTestDatabaseProperties())
hibernatePersister = HibernateObserver(vaultService.rawUpdates, hibernateConfig) hibernatePersister = HibernateObserver(vaultService.rawUpdates, hibernateConfig)
return vaultService return vaultService
@ -216,13 +216,15 @@ fun makeTestDatabaseProperties(): Properties {
return props return props
} }
fun makeTestIdentityService() = InMemoryIdentityService(MOCK_IDENTITIES, trustRoot = DUMMY_CA.certificate)
fun makeTestDatabaseAndMockServices(customSchemas: Set<MappedSchema> = setOf(CommercialPaperSchemaV1, DummyLinearStateSchemaV1, CashSchemaV1), keys: List<KeyPair> = listOf(MEGA_CORP_KEY)): Pair<CordaPersistence, MockServices> { fun makeTestDatabaseAndMockServices(customSchemas: Set<MappedSchema> = setOf(CommercialPaperSchemaV1, DummyLinearStateSchemaV1, CashSchemaV1), keys: List<KeyPair> = listOf(MEGA_CORP_KEY)): Pair<CordaPersistence, MockServices> {
val dataSourceProps = makeTestDataSourceProperties() val dataSourceProps = makeTestDataSourceProperties()
val databaseProperties = makeTestDatabaseProperties() val databaseProperties = makeTestDatabaseProperties()
val database = configureDatabase(dataSourceProps, databaseProperties)
val database = configureDatabase(dataSourceProps, databaseProperties, identitySvc = ::makeTestIdentityService)
val mockService = database.transaction { val mockService = database.transaction {
val identityService = InMemoryIdentityService(MOCK_IDENTITIES, trustRoot = DUMMY_CA.certificate) val hibernateConfig = HibernateConfiguration(NodeSchemaService(customSchemas), databaseProperties, identitySvc = ::makeTestIdentityService)
val hibernateConfig = HibernateConfiguration(NodeSchemaService(customSchemas), databaseProperties, identityService)
object : MockServices(*(keys.toTypedArray())) { object : MockServices(*(keys.toTypedArray())) {
override val vaultService: VaultService = makeVaultService(dataSourceProps, hibernateConfig) override val vaultService: VaultService = makeVaultService(dataSourceProps, hibernateConfig)

View File

@ -32,11 +32,11 @@ class SimpleNode(val config: NodeConfiguration, val address: NetworkHostAndPort
rpcAddress: NetworkHostAndPort = freeLocalHostAndPort(), rpcAddress: NetworkHostAndPort = freeLocalHostAndPort(),
trustRoot: X509Certificate) : AutoCloseable { trustRoot: X509Certificate) : AutoCloseable {
val database: CordaPersistence = configureDatabase(config.dataSourceProperties, config.database)
val userService = RPCUserServiceImpl(config.rpcUsers) val userService = RPCUserServiceImpl(config.rpcUsers)
val monitoringService = MonitoringService(MetricRegistry()) val monitoringService = MonitoringService(MetricRegistry())
val identity: KeyPair = generateKeyPair() val identity: KeyPair = generateKeyPair()
val identityService: IdentityService = InMemoryIdentityService(trustRoot = trustRoot) val identityService: IdentityService = InMemoryIdentityService(trustRoot = trustRoot)
val database: CordaPersistence = configureDatabase(config.dataSourceProperties, config.database, identitySvc = {InMemoryIdentityService(trustRoot = trustRoot)})
val keyService: KeyManagementService = E2ETestKeyManagementService(identityService, setOf(identity)) val keyService: KeyManagementService = E2ETestKeyManagementService(identityService, setOf(identity))
val executor = ServiceAffinityExecutor(config.myLegalName.commonName, 1) val executor = ServiceAffinityExecutor(config.myLegalName.commonName, 1)
// TODO: We should have a dummy service hub rather than change behaviour in tests // TODO: We should have a dummy service hub rather than change behaviour in tests

View File

@ -7,6 +7,7 @@ import com.typesafe.config.ConfigParseOptions
import net.corda.core.internal.div import net.corda.core.internal.div
import net.corda.core.serialization.SerializationContext import net.corda.core.serialization.SerializationContext
import net.corda.core.serialization.SerializationDefaults import net.corda.core.serialization.SerializationDefaults
import net.corda.core.serialization.SerializationFactory
import net.corda.core.utilities.ByteSequence import net.corda.core.utilities.ByteSequence
import net.corda.core.utilities.NetworkHostAndPort import net.corda.core.utilities.NetworkHostAndPort
import net.corda.core.utilities.debug import net.corda.core.utilities.debug
@ -89,13 +90,13 @@ class Verifier {
private fun initialiseSerialization() { private fun initialiseSerialization() {
SerializationDefaults.SERIALIZATION_FACTORY = SerializationFactoryImpl().apply { SerializationDefaults.SERIALIZATION_FACTORY = SerializationFactoryImpl().apply {
registerScheme(KryoVerifierSerializationScheme) registerScheme(KryoVerifierSerializationScheme(this))
} }
SerializationDefaults.P2P_CONTEXT = KRYO_P2P_CONTEXT SerializationDefaults.P2P_CONTEXT = KRYO_P2P_CONTEXT
} }
} }
object KryoVerifierSerializationScheme : AbstractKryoSerializationScheme() { class KryoVerifierSerializationScheme(serializationFactory: SerializationFactory) : AbstractKryoSerializationScheme(serializationFactory) {
override fun canDeserializeVersion(byteSequence: ByteSequence, target: SerializationContext.UseCase): Boolean { override fun canDeserializeVersion(byteSequence: ByteSequence, target: SerializationContext.UseCase): Boolean {
return byteSequence.equals(KryoHeaderV0_1) && target == SerializationContext.UseCase.P2P return byteSequence.equals(KryoHeaderV0_1) && target == SerializationContext.UseCase.P2P
} }