Fixed TransientProperty so that it actually works during (de)serialisation

This commit is contained in:
Shams Asari 2017-07-17 10:18:33 +01:00
parent 195189070a
commit c62387f3f6
6 changed files with 124 additions and 47 deletions
core/src
main/kotlin/net/corda/core
test/kotlin/net/corda/core/utilities
samples/irs-demo/src/main/kotlin/net/corda/irs/flows

View File

@ -35,7 +35,6 @@ import java.util.zip.ZipInputStream
import java.util.zip.ZipOutputStream
import kotlin.concurrent.withLock
import kotlin.reflect.KClass
import kotlin.reflect.KProperty
val Int.days: Duration get() = Duration.ofDays(this.toLong())
@Suppress("unused") // It's here for completeness
@ -215,17 +214,7 @@ class ThreadBox<out T>(val content: T, val lock: ReentrantLock = ReentrantLock()
@CordaSerializable
abstract class RetryableException(message: String) : FlowException(message)
/**
* A simple wrapper that enables the use of Kotlin's "val x by TransientProperty { ... }" syntax. Such a property
* will not be serialized to disk, and if it's missing (or the first time it's accessed), the initializer will be
* used to set it up. Note that the initializer will be called with the TransientProperty object locked.
*/
class TransientProperty<out T>(private val initializer: () -> T) {
@Transient private var v: T? = null
@Synchronized
operator fun getValue(thisRef: Any?, property: KProperty<*>) = v ?: initializer().also { v = it }
}
/**
* Given a path to a zip file, extracts it to the given directory.

View File

@ -3,12 +3,13 @@ package net.corda.core.serialization
import com.esotericsoftware.kryo.*
import com.esotericsoftware.kryo.io.Input
import com.esotericsoftware.kryo.io.Output
import com.esotericsoftware.kryo.serializers.FieldSerializer
import com.esotericsoftware.kryo.util.DefaultClassResolver
import com.esotericsoftware.kryo.util.Util
import net.corda.core.node.AttachmentsClassLoader
import net.corda.core.utilities.loggerFor
import java.io.PrintWriter
import java.lang.reflect.Modifier
import java.lang.reflect.Modifier.isAbstract
import java.nio.charset.StandardCharsets
import java.nio.file.Files
import java.nio.file.Paths
@ -52,18 +53,16 @@ class CordaClassResolver(val whitelist: ClassWhitelist, val amqpEnabled: Boolean
}
private fun checkClass(type: Class<*>): Registration? {
/** If call path has disabled whitelisting (see [CordaKryo.register]), just return without checking. */
// If call path has disabled whitelisting (see [CordaKryo.register]), just return without checking.
if (!whitelistEnabled) return null
// Allow primitives, abstracts and interfaces
if (type.isPrimitive || type == Any::class.java || Modifier.isAbstract(type.modifiers) || type == String::class.java) return null
if (type.isPrimitive || type == Any::class.java || isAbstract(type.modifiers) || type == String::class.java) return null
// If array, recurse on element type
if (type.isArray) {
return checkClass(type.componentType)
}
if (!type.isEnum && Enum::class.java.isAssignableFrom(type)) {
// Specialised enum entry, so just resolve the parent Enum type since cannot annotate the specialised entry.
return checkClass(type.superclass)
}
if (type.isArray) return checkClass(type.componentType)
// Specialised enum entry, so just resolve the parent Enum type since cannot annotate the specialised entry.
if (!type.isEnum && Enum::class.java.isAssignableFrom(type)) return checkClass(type.superclass)
// Kotlin lambdas require some special treatment
if (kotlin.jvm.internal.Lambda::class.java.isAssignableFrom(type)) return null
// It's safe to have the Class already, since Kryo loads it with initialisation off.
// If we use a whitelist with blacklisting capabilities, whitelist.hasListed(type) may throw an IllegalStateException if input class is blacklisted.
// Thus, blacklisting precedes annotation checking.
@ -74,34 +73,40 @@ class CordaClassResolver(val whitelist: ClassWhitelist, val amqpEnabled: Boolean
}
override fun registerImplicit(type: Class<*>): Registration {
val hasAnnotation = checkForAnnotation(type)
// If something is not annotated, or AMQP is disabled, we stay serializing with Kryo. This will typically be the
// case for flow checkpoints (ignoring all cases where AMQP is disabled) since our top level messaging data structures
// are annotated and once we enter AMQP serialisation we stay with it for the entire object subgraph.
if (!hasAnnotation || !amqpEnabled) {
val objectInstance = try {
type.kotlin.objectInstance
} catch (t: Throwable) {
// objectInstance will throw if the type is something like a lambda
null
}
// We have to set reference to true, since the flag influences how String fields are treated and we want it to be consistent.
val references = kryo.references
try {
kryo.references = true
val serializer = if (objectInstance != null) KotlinObjectSerializer(objectInstance) else kryo.getDefaultSerializer(type)
return register(Registration(type, serializer, NAME.toInt()))
} finally {
kryo.references = references
}
} else {
if (checkForAnnotation(type) && amqpEnabled) {
// Build AMQP serializer
return register(Registration(type, KryoAMQPSerializer, NAME.toInt()))
}
val objectInstance = try {
type.kotlin.objectInstance
} catch (t: Throwable) {
null // objectInstance will throw if the type is something like a lambda
}
// We have to set reference to true, since the flag influences how String fields are treated and we want it to be consistent.
val references = kryo.references
try {
kryo.references = true
val serializer = if (objectInstance != null) {
KotlinObjectSerializer(objectInstance)
} else if (kotlin.jvm.internal.Lambda::class.java.isAssignableFrom(type)) {
// Kotlin lambdas extend this class and any captured variables are stored in synthentic fields
FieldSerializer<Any>(kryo, type).apply { setIgnoreSyntheticFields(false) }
} else {
kryo.getDefaultSerializer(type)
}
return register(Registration(type, serializer, NAME.toInt()))
} finally {
kryo.references = references
}
}
// Trivial Serializer which simply returns the given instance which we already know is a Kotlin object
private class KotlinObjectSerializer(val objectInstance: Any) : Serializer<Any>() {
// Trivial Serializer which simply returns the given instance, which we already know is a Kotlin object
private class KotlinObjectSerializer(private val objectInstance: Any) : Serializer<Any>() {
override fun read(kryo: Kryo, input: Input, type: Class<Any>): Any = objectInstance
override fun write(kryo: Kryo, output: Output, obj: Any) = Unit
}

View File

@ -1,7 +1,9 @@
package net.corda.core.utilities
import net.corda.core.serialization.CordaSerializable
import org.slf4j.Logger
import org.slf4j.LoggerFactory
import kotlin.reflect.KProperty
/**
* Get the [Logger] for a class using the syntax
@ -20,5 +22,30 @@ inline fun Logger.debug(msg: () -> String) {
if (isDebugEnabled) debug(msg())
}
/**
* A simple wrapper that enables the use of Kotlin's `val x by transient { ... }` syntax. Such a property
* will not be serialized, and if it's missing (or the first time it's accessed), the initializer will be
* used to set it up.
*/
@Suppress("DEPRECATION")
fun <T> transient(initializer: () -> T) = TransientProperty(initializer)
@Deprecated("Use transient")
@CordaSerializable
class TransientProperty<out T>(private val initialiser: () -> T) {
@Transient private var initialised = false
@Transient private var value: T? = null
@Suppress("UNCHECKED_CAST")
@Synchronized
operator fun getValue(thisRef: Any?, property: KProperty<*>): T {
if (!initialised) {
value = initialiser()
initialised = true
}
return value as T
}
}
/** @see NonEmptySet.copyOf */
fun <T> Collection<T>.toNonEmptySet(): NonEmptySet<T> = NonEmptySet.copyOf(this)
fun <T> Collection<T>.toNonEmptySet(): NonEmptySet<T> = NonEmptySet.copyOf(this)

View File

@ -1,10 +1,8 @@
package net.corda.core.utilities
import net.corda.core.TransientProperty
import net.corda.core.serialization.CordaSerializable
import rx.Observable
import rx.Subscription
import rx.subjects.BehaviorSubject
import rx.subjects.PublishSubject
import java.util.*
@ -76,7 +74,7 @@ class ProgressTracker(vararg steps: Step) {
val steps = arrayOf(UNSTARTED, *steps, DONE)
// This field won't be serialized.
private val _changes by TransientProperty { PublishSubject.create<Change>() }
private val _changes by transient { PublishSubject.create<Change>() }
@CordaSerializable
private data class Child(val tracker: ProgressTracker, @Transient val subscription: Subscription?)

View File

@ -0,0 +1,58 @@
package net.corda.core.utilities
import net.corda.core.crypto.random63BitValue
import net.corda.core.serialization.CordaSerializable
import net.corda.core.serialization.deserialize
import net.corda.core.serialization.serialize
import org.assertj.core.api.Assertions.assertThat
import org.junit.Test
class KotlinUtilsTest {
@Test
fun `transient property which is null`() {
val test = NullTransientProperty()
test.transientValue
test.transientValue
assertThat(test.evalCount).isEqualTo(1)
}
@Test
fun `transient property with non-capturing lamba`() {
val original = NonCapturingTransientProperty()
val originalVal = original.transientVal
val copy = original.serialize().deserialize()
val copyVal = copy.transientVal
assertThat(copyVal).isNotEqualTo(originalVal)
assertThat(copy.transientVal).isEqualTo(copyVal)
}
@Test
fun `transient property with capturing lamba`() {
val original = CapturingTransientProperty("Hello")
val originalVal = original.transientVal
val copy = original.serialize().deserialize()
val copyVal = copy.transientVal
assertThat(copyVal).isNotEqualTo(originalVal)
assertThat(copy.transientVal).isEqualTo(copyVal)
assertThat(copy.transientVal).startsWith("Hello")
}
private class NullTransientProperty {
var evalCount = 0
val transientValue by transient {
evalCount++
null
}
}
@CordaSerializable
private class NonCapturingTransientProperty {
val transientVal by transient { random63BitValue() }
}
@CordaSerializable
private class CapturingTransientProperty(prefix: String) {
private val seed = random63BitValue()
val transientVal by transient { prefix + seed + random63BitValue() }
}
}

View File

@ -3,7 +3,6 @@ package net.corda.irs.flows
import co.paralleluniverse.fibers.Suspendable
import net.corda.contracts.Fix
import net.corda.contracts.FixableDealState
import net.corda.core.TransientProperty
import net.corda.core.contracts.*
import net.corda.core.crypto.toBase58String
import net.corda.core.flows.FlowLogic
@ -17,6 +16,7 @@ import net.corda.core.seconds
import net.corda.core.serialization.CordaSerializable
import net.corda.core.transactions.SignedTransaction
import net.corda.core.transactions.TransactionBuilder
import net.corda.core.utilities.transient
import net.corda.core.utilities.ProgressTracker
import net.corda.core.utilities.trace
import net.corda.flows.TwoPartyDealFlow
@ -103,7 +103,7 @@ object FixingFlow {
override val progressTracker: ProgressTracker = TwoPartyDealFlow.Primary.tracker()) : TwoPartyDealFlow.Primary() {
@Suppress("UNCHECKED_CAST")
internal val dealToFix: StateAndRef<FixableDealState> by TransientProperty {
internal val dealToFix: StateAndRef<FixableDealState> by transient {
val state = serviceHub.loadState(payload.ref) as TransactionState<FixableDealState>
StateAndRef(state, payload.ref)
}