Fixed TransientProperty so that it actually works during (de)serialisation

This commit is contained in:
Shams Asari 2017-07-17 10:18:33 +01:00
parent 195189070a
commit c62387f3f6
6 changed files with 124 additions and 47 deletions

View File

@ -35,7 +35,6 @@ import java.util.zip.ZipInputStream
import java.util.zip.ZipOutputStream import java.util.zip.ZipOutputStream
import kotlin.concurrent.withLock import kotlin.concurrent.withLock
import kotlin.reflect.KClass import kotlin.reflect.KClass
import kotlin.reflect.KProperty
val Int.days: Duration get() = Duration.ofDays(this.toLong()) val Int.days: Duration get() = Duration.ofDays(this.toLong())
@Suppress("unused") // It's here for completeness @Suppress("unused") // It's here for completeness
@ -215,17 +214,7 @@ class ThreadBox<out T>(val content: T, val lock: ReentrantLock = ReentrantLock()
@CordaSerializable @CordaSerializable
abstract class RetryableException(message: String) : FlowException(message) abstract class RetryableException(message: String) : FlowException(message)
/**
* A simple wrapper that enables the use of Kotlin's "val x by TransientProperty { ... }" syntax. Such a property
* will not be serialized to disk, and if it's missing (or the first time it's accessed), the initializer will be
* used to set it up. Note that the initializer will be called with the TransientProperty object locked.
*/
class TransientProperty<out T>(private val initializer: () -> T) {
@Transient private var v: T? = null
@Synchronized
operator fun getValue(thisRef: Any?, property: KProperty<*>) = v ?: initializer().also { v = it }
}
/** /**
* Given a path to a zip file, extracts it to the given directory. * Given a path to a zip file, extracts it to the given directory.

View File

@ -3,12 +3,13 @@ package net.corda.core.serialization
import com.esotericsoftware.kryo.* import com.esotericsoftware.kryo.*
import com.esotericsoftware.kryo.io.Input import com.esotericsoftware.kryo.io.Input
import com.esotericsoftware.kryo.io.Output import com.esotericsoftware.kryo.io.Output
import com.esotericsoftware.kryo.serializers.FieldSerializer
import com.esotericsoftware.kryo.util.DefaultClassResolver import com.esotericsoftware.kryo.util.DefaultClassResolver
import com.esotericsoftware.kryo.util.Util import com.esotericsoftware.kryo.util.Util
import net.corda.core.node.AttachmentsClassLoader import net.corda.core.node.AttachmentsClassLoader
import net.corda.core.utilities.loggerFor import net.corda.core.utilities.loggerFor
import java.io.PrintWriter import java.io.PrintWriter
import java.lang.reflect.Modifier import java.lang.reflect.Modifier.isAbstract
import java.nio.charset.StandardCharsets import java.nio.charset.StandardCharsets
import java.nio.file.Files import java.nio.file.Files
import java.nio.file.Paths import java.nio.file.Paths
@ -52,18 +53,16 @@ class CordaClassResolver(val whitelist: ClassWhitelist, val amqpEnabled: Boolean
} }
private fun checkClass(type: Class<*>): Registration? { private fun checkClass(type: Class<*>): Registration? {
/** If call path has disabled whitelisting (see [CordaKryo.register]), just return without checking. */ // If call path has disabled whitelisting (see [CordaKryo.register]), just return without checking.
if (!whitelistEnabled) return null if (!whitelistEnabled) return null
// Allow primitives, abstracts and interfaces // Allow primitives, abstracts and interfaces
if (type.isPrimitive || type == Any::class.java || Modifier.isAbstract(type.modifiers) || type == String::class.java) return null if (type.isPrimitive || type == Any::class.java || isAbstract(type.modifiers) || type == String::class.java) return null
// If array, recurse on element type // If array, recurse on element type
if (type.isArray) { if (type.isArray) return checkClass(type.componentType)
return checkClass(type.componentType)
}
if (!type.isEnum && Enum::class.java.isAssignableFrom(type)) {
// Specialised enum entry, so just resolve the parent Enum type since cannot annotate the specialised entry. // Specialised enum entry, so just resolve the parent Enum type since cannot annotate the specialised entry.
return checkClass(type.superclass) if (!type.isEnum && Enum::class.java.isAssignableFrom(type)) return checkClass(type.superclass)
} // Kotlin lambdas require some special treatment
if (kotlin.jvm.internal.Lambda::class.java.isAssignableFrom(type)) return null
// It's safe to have the Class already, since Kryo loads it with initialisation off. // It's safe to have the Class already, since Kryo loads it with initialisation off.
// If we use a whitelist with blacklisting capabilities, whitelist.hasListed(type) may throw an IllegalStateException if input class is blacklisted. // If we use a whitelist with blacklisting capabilities, whitelist.hasListed(type) may throw an IllegalStateException if input class is blacklisted.
// Thus, blacklisting precedes annotation checking. // Thus, blacklisting precedes annotation checking.
@ -74,34 +73,40 @@ class CordaClassResolver(val whitelist: ClassWhitelist, val amqpEnabled: Boolean
} }
override fun registerImplicit(type: Class<*>): Registration { override fun registerImplicit(type: Class<*>): Registration {
val hasAnnotation = checkForAnnotation(type)
// If something is not annotated, or AMQP is disabled, we stay serializing with Kryo. This will typically be the // If something is not annotated, or AMQP is disabled, we stay serializing with Kryo. This will typically be the
// case for flow checkpoints (ignoring all cases where AMQP is disabled) since our top level messaging data structures // case for flow checkpoints (ignoring all cases where AMQP is disabled) since our top level messaging data structures
// are annotated and once we enter AMQP serialisation we stay with it for the entire object subgraph. // are annotated and once we enter AMQP serialisation we stay with it for the entire object subgraph.
if (!hasAnnotation || !amqpEnabled) { if (checkForAnnotation(type) && amqpEnabled) {
// Build AMQP serializer
return register(Registration(type, KryoAMQPSerializer, NAME.toInt()))
}
val objectInstance = try { val objectInstance = try {
type.kotlin.objectInstance type.kotlin.objectInstance
} catch (t: Throwable) { } catch (t: Throwable) {
// objectInstance will throw if the type is something like a lambda null // objectInstance will throw if the type is something like a lambda
null
} }
// We have to set reference to true, since the flag influences how String fields are treated and we want it to be consistent. // We have to set reference to true, since the flag influences how String fields are treated and we want it to be consistent.
val references = kryo.references val references = kryo.references
try { try {
kryo.references = true kryo.references = true
val serializer = if (objectInstance != null) KotlinObjectSerializer(objectInstance) else kryo.getDefaultSerializer(type) val serializer = if (objectInstance != null) {
KotlinObjectSerializer(objectInstance)
} else if (kotlin.jvm.internal.Lambda::class.java.isAssignableFrom(type)) {
// Kotlin lambdas extend this class and any captured variables are stored in synthentic fields
FieldSerializer<Any>(kryo, type).apply { setIgnoreSyntheticFields(false) }
} else {
kryo.getDefaultSerializer(type)
}
return register(Registration(type, serializer, NAME.toInt())) return register(Registration(type, serializer, NAME.toInt()))
} finally { } finally {
kryo.references = references kryo.references = references
} }
} else {
// Build AMQP serializer
return register(Registration(type, KryoAMQPSerializer, NAME.toInt()))
}
} }
// Trivial Serializer which simply returns the given instance which we already know is a Kotlin object // Trivial Serializer which simply returns the given instance, which we already know is a Kotlin object
private class KotlinObjectSerializer(val objectInstance: Any) : Serializer<Any>() { private class KotlinObjectSerializer(private val objectInstance: Any) : Serializer<Any>() {
override fun read(kryo: Kryo, input: Input, type: Class<Any>): Any = objectInstance override fun read(kryo: Kryo, input: Input, type: Class<Any>): Any = objectInstance
override fun write(kryo: Kryo, output: Output, obj: Any) = Unit override fun write(kryo: Kryo, output: Output, obj: Any) = Unit
} }

View File

@ -1,7 +1,9 @@
package net.corda.core.utilities package net.corda.core.utilities
import net.corda.core.serialization.CordaSerializable
import org.slf4j.Logger import org.slf4j.Logger
import org.slf4j.LoggerFactory import org.slf4j.LoggerFactory
import kotlin.reflect.KProperty
/** /**
* Get the [Logger] for a class using the syntax * Get the [Logger] for a class using the syntax
@ -20,5 +22,30 @@ inline fun Logger.debug(msg: () -> String) {
if (isDebugEnabled) debug(msg()) if (isDebugEnabled) debug(msg())
} }
/**
* A simple wrapper that enables the use of Kotlin's `val x by transient { ... }` syntax. Such a property
* will not be serialized, and if it's missing (or the first time it's accessed), the initializer will be
* used to set it up.
*/
@Suppress("DEPRECATION")
fun <T> transient(initializer: () -> T) = TransientProperty(initializer)
@Deprecated("Use transient")
@CordaSerializable
class TransientProperty<out T>(private val initialiser: () -> T) {
@Transient private var initialised = false
@Transient private var value: T? = null
@Suppress("UNCHECKED_CAST")
@Synchronized
operator fun getValue(thisRef: Any?, property: KProperty<*>): T {
if (!initialised) {
value = initialiser()
initialised = true
}
return value as T
}
}
/** @see NonEmptySet.copyOf */ /** @see NonEmptySet.copyOf */
fun <T> Collection<T>.toNonEmptySet(): NonEmptySet<T> = NonEmptySet.copyOf(this) fun <T> Collection<T>.toNonEmptySet(): NonEmptySet<T> = NonEmptySet.copyOf(this)

View File

@ -1,10 +1,8 @@
package net.corda.core.utilities package net.corda.core.utilities
import net.corda.core.TransientProperty
import net.corda.core.serialization.CordaSerializable import net.corda.core.serialization.CordaSerializable
import rx.Observable import rx.Observable
import rx.Subscription import rx.Subscription
import rx.subjects.BehaviorSubject
import rx.subjects.PublishSubject import rx.subjects.PublishSubject
import java.util.* import java.util.*
@ -76,7 +74,7 @@ class ProgressTracker(vararg steps: Step) {
val steps = arrayOf(UNSTARTED, *steps, DONE) val steps = arrayOf(UNSTARTED, *steps, DONE)
// This field won't be serialized. // This field won't be serialized.
private val _changes by TransientProperty { PublishSubject.create<Change>() } private val _changes by transient { PublishSubject.create<Change>() }
@CordaSerializable @CordaSerializable
private data class Child(val tracker: ProgressTracker, @Transient val subscription: Subscription?) private data class Child(val tracker: ProgressTracker, @Transient val subscription: Subscription?)

View File

@ -0,0 +1,58 @@
package net.corda.core.utilities
import net.corda.core.crypto.random63BitValue
import net.corda.core.serialization.CordaSerializable
import net.corda.core.serialization.deserialize
import net.corda.core.serialization.serialize
import org.assertj.core.api.Assertions.assertThat
import org.junit.Test
class KotlinUtilsTest {
@Test
fun `transient property which is null`() {
val test = NullTransientProperty()
test.transientValue
test.transientValue
assertThat(test.evalCount).isEqualTo(1)
}
@Test
fun `transient property with non-capturing lamba`() {
val original = NonCapturingTransientProperty()
val originalVal = original.transientVal
val copy = original.serialize().deserialize()
val copyVal = copy.transientVal
assertThat(copyVal).isNotEqualTo(originalVal)
assertThat(copy.transientVal).isEqualTo(copyVal)
}
@Test
fun `transient property with capturing lamba`() {
val original = CapturingTransientProperty("Hello")
val originalVal = original.transientVal
val copy = original.serialize().deserialize()
val copyVal = copy.transientVal
assertThat(copyVal).isNotEqualTo(originalVal)
assertThat(copy.transientVal).isEqualTo(copyVal)
assertThat(copy.transientVal).startsWith("Hello")
}
private class NullTransientProperty {
var evalCount = 0
val transientValue by transient {
evalCount++
null
}
}
@CordaSerializable
private class NonCapturingTransientProperty {
val transientVal by transient { random63BitValue() }
}
@CordaSerializable
private class CapturingTransientProperty(prefix: String) {
private val seed = random63BitValue()
val transientVal by transient { prefix + seed + random63BitValue() }
}
}

View File

@ -3,7 +3,6 @@ package net.corda.irs.flows
import co.paralleluniverse.fibers.Suspendable import co.paralleluniverse.fibers.Suspendable
import net.corda.contracts.Fix import net.corda.contracts.Fix
import net.corda.contracts.FixableDealState import net.corda.contracts.FixableDealState
import net.corda.core.TransientProperty
import net.corda.core.contracts.* import net.corda.core.contracts.*
import net.corda.core.crypto.toBase58String import net.corda.core.crypto.toBase58String
import net.corda.core.flows.FlowLogic import net.corda.core.flows.FlowLogic
@ -17,6 +16,7 @@ import net.corda.core.seconds
import net.corda.core.serialization.CordaSerializable import net.corda.core.serialization.CordaSerializable
import net.corda.core.transactions.SignedTransaction import net.corda.core.transactions.SignedTransaction
import net.corda.core.transactions.TransactionBuilder import net.corda.core.transactions.TransactionBuilder
import net.corda.core.utilities.transient
import net.corda.core.utilities.ProgressTracker import net.corda.core.utilities.ProgressTracker
import net.corda.core.utilities.trace import net.corda.core.utilities.trace
import net.corda.flows.TwoPartyDealFlow import net.corda.flows.TwoPartyDealFlow
@ -103,7 +103,7 @@ object FixingFlow {
override val progressTracker: ProgressTracker = TwoPartyDealFlow.Primary.tracker()) : TwoPartyDealFlow.Primary() { override val progressTracker: ProgressTracker = TwoPartyDealFlow.Primary.tracker()) : TwoPartyDealFlow.Primary() {
@Suppress("UNCHECKED_CAST") @Suppress("UNCHECKED_CAST")
internal val dealToFix: StateAndRef<FixableDealState> by TransientProperty { internal val dealToFix: StateAndRef<FixableDealState> by transient {
val state = serviceHub.loadState(payload.ref) as TransactionState<FixableDealState> val state = serviceHub.loadState(payload.ref) as TransactionState<FixableDealState>
StateAndRef(state, payload.ref) StateAndRef(state, payload.ref)
} }