diff --git a/core/src/main/kotlin/net/corda/core/Utils.kt b/core/src/main/kotlin/net/corda/core/Utils.kt index bf38a24d07..3ee55d50ba 100644 --- a/core/src/main/kotlin/net/corda/core/Utils.kt +++ b/core/src/main/kotlin/net/corda/core/Utils.kt @@ -35,7 +35,6 @@ import java.util.zip.ZipInputStream import java.util.zip.ZipOutputStream import kotlin.concurrent.withLock import kotlin.reflect.KClass -import kotlin.reflect.KProperty val Int.days: Duration get() = Duration.ofDays(this.toLong()) @Suppress("unused") // It's here for completeness @@ -215,17 +214,7 @@ class ThreadBox(val content: T, val lock: ReentrantLock = ReentrantLock() @CordaSerializable abstract class RetryableException(message: String) : FlowException(message) -/** - * A simple wrapper that enables the use of Kotlin's "val x by TransientProperty { ... }" syntax. Such a property - * will not be serialized to disk, and if it's missing (or the first time it's accessed), the initializer will be - * used to set it up. Note that the initializer will be called with the TransientProperty object locked. - */ -class TransientProperty(private val initializer: () -> T) { - @Transient private var v: T? = null - @Synchronized - operator fun getValue(thisRef: Any?, property: KProperty<*>) = v ?: initializer().also { v = it } -} /** * Given a path to a zip file, extracts it to the given directory. diff --git a/core/src/main/kotlin/net/corda/core/serialization/CordaClassResolver.kt b/core/src/main/kotlin/net/corda/core/serialization/CordaClassResolver.kt index 11f450cac9..864167b91e 100644 --- a/core/src/main/kotlin/net/corda/core/serialization/CordaClassResolver.kt +++ b/core/src/main/kotlin/net/corda/core/serialization/CordaClassResolver.kt @@ -3,12 +3,13 @@ package net.corda.core.serialization import com.esotericsoftware.kryo.* import com.esotericsoftware.kryo.io.Input import com.esotericsoftware.kryo.io.Output +import com.esotericsoftware.kryo.serializers.FieldSerializer import com.esotericsoftware.kryo.util.DefaultClassResolver import com.esotericsoftware.kryo.util.Util import net.corda.core.node.AttachmentsClassLoader import net.corda.core.utilities.loggerFor import java.io.PrintWriter -import java.lang.reflect.Modifier +import java.lang.reflect.Modifier.isAbstract import java.nio.charset.StandardCharsets import java.nio.file.Files import java.nio.file.Paths @@ -52,18 +53,16 @@ class CordaClassResolver(val whitelist: ClassWhitelist, val amqpEnabled: Boolean } private fun checkClass(type: Class<*>): Registration? { - /** If call path has disabled whitelisting (see [CordaKryo.register]), just return without checking. */ + // If call path has disabled whitelisting (see [CordaKryo.register]), just return without checking. if (!whitelistEnabled) return null // Allow primitives, abstracts and interfaces - if (type.isPrimitive || type == Any::class.java || Modifier.isAbstract(type.modifiers) || type == String::class.java) return null + if (type.isPrimitive || type == Any::class.java || isAbstract(type.modifiers) || type == String::class.java) return null // If array, recurse on element type - if (type.isArray) { - return checkClass(type.componentType) - } - if (!type.isEnum && Enum::class.java.isAssignableFrom(type)) { - // Specialised enum entry, so just resolve the parent Enum type since cannot annotate the specialised entry. - return checkClass(type.superclass) - } + if (type.isArray) return checkClass(type.componentType) + // Specialised enum entry, so just resolve the parent Enum type since cannot annotate the specialised entry. + if (!type.isEnum && Enum::class.java.isAssignableFrom(type)) return checkClass(type.superclass) + // Kotlin lambdas require some special treatment + if (kotlin.jvm.internal.Lambda::class.java.isAssignableFrom(type)) return null // It's safe to have the Class already, since Kryo loads it with initialisation off. // If we use a whitelist with blacklisting capabilities, whitelist.hasListed(type) may throw an IllegalStateException if input class is blacklisted. // Thus, blacklisting precedes annotation checking. @@ -74,34 +73,40 @@ class CordaClassResolver(val whitelist: ClassWhitelist, val amqpEnabled: Boolean } override fun registerImplicit(type: Class<*>): Registration { - val hasAnnotation = checkForAnnotation(type) // If something is not annotated, or AMQP is disabled, we stay serializing with Kryo. This will typically be the // case for flow checkpoints (ignoring all cases where AMQP is disabled) since our top level messaging data structures // are annotated and once we enter AMQP serialisation we stay with it for the entire object subgraph. - if (!hasAnnotation || !amqpEnabled) { - val objectInstance = try { - type.kotlin.objectInstance - } catch (t: Throwable) { - // objectInstance will throw if the type is something like a lambda - null - } - // We have to set reference to true, since the flag influences how String fields are treated and we want it to be consistent. - val references = kryo.references - try { - kryo.references = true - val serializer = if (objectInstance != null) KotlinObjectSerializer(objectInstance) else kryo.getDefaultSerializer(type) - return register(Registration(type, serializer, NAME.toInt())) - } finally { - kryo.references = references - } - } else { + if (checkForAnnotation(type) && amqpEnabled) { // Build AMQP serializer return register(Registration(type, KryoAMQPSerializer, NAME.toInt())) } + + val objectInstance = try { + type.kotlin.objectInstance + } catch (t: Throwable) { + null // objectInstance will throw if the type is something like a lambda + } + + // We have to set reference to true, since the flag influences how String fields are treated and we want it to be consistent. + val references = kryo.references + try { + kryo.references = true + val serializer = if (objectInstance != null) { + KotlinObjectSerializer(objectInstance) + } else if (kotlin.jvm.internal.Lambda::class.java.isAssignableFrom(type)) { + // Kotlin lambdas extend this class and any captured variables are stored in synthentic fields + FieldSerializer(kryo, type).apply { setIgnoreSyntheticFields(false) } + } else { + kryo.getDefaultSerializer(type) + } + return register(Registration(type, serializer, NAME.toInt())) + } finally { + kryo.references = references + } } - // Trivial Serializer which simply returns the given instance which we already know is a Kotlin object - private class KotlinObjectSerializer(val objectInstance: Any) : Serializer() { + // Trivial Serializer which simply returns the given instance, which we already know is a Kotlin object + private class KotlinObjectSerializer(private val objectInstance: Any) : Serializer() { override fun read(kryo: Kryo, input: Input, type: Class): Any = objectInstance override fun write(kryo: Kryo, output: Output, obj: Any) = Unit } diff --git a/core/src/main/kotlin/net/corda/core/utilities/KotlinUtils.kt b/core/src/main/kotlin/net/corda/core/utilities/KotlinUtils.kt index 19e1a90e06..6df50156d3 100644 --- a/core/src/main/kotlin/net/corda/core/utilities/KotlinUtils.kt +++ b/core/src/main/kotlin/net/corda/core/utilities/KotlinUtils.kt @@ -1,7 +1,9 @@ package net.corda.core.utilities +import net.corda.core.serialization.CordaSerializable import org.slf4j.Logger import org.slf4j.LoggerFactory +import kotlin.reflect.KProperty /** * Get the [Logger] for a class using the syntax @@ -20,5 +22,30 @@ inline fun Logger.debug(msg: () -> String) { if (isDebugEnabled) debug(msg()) } +/** + * A simple wrapper that enables the use of Kotlin's `val x by transient { ... }` syntax. Such a property + * will not be serialized, and if it's missing (or the first time it's accessed), the initializer will be + * used to set it up. + */ +@Suppress("DEPRECATION") +fun transient(initializer: () -> T) = TransientProperty(initializer) + +@Deprecated("Use transient") +@CordaSerializable +class TransientProperty(private val initialiser: () -> T) { + @Transient private var initialised = false + @Transient private var value: T? = null + + @Suppress("UNCHECKED_CAST") + @Synchronized + operator fun getValue(thisRef: Any?, property: KProperty<*>): T { + if (!initialised) { + value = initialiser() + initialised = true + } + return value as T + } +} + /** @see NonEmptySet.copyOf */ -fun Collection.toNonEmptySet(): NonEmptySet = NonEmptySet.copyOf(this) \ No newline at end of file +fun Collection.toNonEmptySet(): NonEmptySet = NonEmptySet.copyOf(this) diff --git a/core/src/main/kotlin/net/corda/core/utilities/ProgressTracker.kt b/core/src/main/kotlin/net/corda/core/utilities/ProgressTracker.kt index fc788eb10c..6cd9526795 100644 --- a/core/src/main/kotlin/net/corda/core/utilities/ProgressTracker.kt +++ b/core/src/main/kotlin/net/corda/core/utilities/ProgressTracker.kt @@ -1,10 +1,8 @@ package net.corda.core.utilities -import net.corda.core.TransientProperty import net.corda.core.serialization.CordaSerializable import rx.Observable import rx.Subscription -import rx.subjects.BehaviorSubject import rx.subjects.PublishSubject import java.util.* @@ -76,7 +74,7 @@ class ProgressTracker(vararg steps: Step) { val steps = arrayOf(UNSTARTED, *steps, DONE) // This field won't be serialized. - private val _changes by TransientProperty { PublishSubject.create() } + private val _changes by transient { PublishSubject.create() } @CordaSerializable private data class Child(val tracker: ProgressTracker, @Transient val subscription: Subscription?) diff --git a/core/src/test/kotlin/net/corda/core/utilities/KotlinUtilsTest.kt b/core/src/test/kotlin/net/corda/core/utilities/KotlinUtilsTest.kt new file mode 100644 index 0000000000..b58c74a51f --- /dev/null +++ b/core/src/test/kotlin/net/corda/core/utilities/KotlinUtilsTest.kt @@ -0,0 +1,58 @@ +package net.corda.core.utilities + +import net.corda.core.crypto.random63BitValue +import net.corda.core.serialization.CordaSerializable +import net.corda.core.serialization.deserialize +import net.corda.core.serialization.serialize +import org.assertj.core.api.Assertions.assertThat +import org.junit.Test + +class KotlinUtilsTest { + @Test + fun `transient property which is null`() { + val test = NullTransientProperty() + test.transientValue + test.transientValue + assertThat(test.evalCount).isEqualTo(1) + } + + @Test + fun `transient property with non-capturing lamba`() { + val original = NonCapturingTransientProperty() + val originalVal = original.transientVal + val copy = original.serialize().deserialize() + val copyVal = copy.transientVal + assertThat(copyVal).isNotEqualTo(originalVal) + assertThat(copy.transientVal).isEqualTo(copyVal) + } + + @Test + fun `transient property with capturing lamba`() { + val original = CapturingTransientProperty("Hello") + val originalVal = original.transientVal + val copy = original.serialize().deserialize() + val copyVal = copy.transientVal + assertThat(copyVal).isNotEqualTo(originalVal) + assertThat(copy.transientVal).isEqualTo(copyVal) + assertThat(copy.transientVal).startsWith("Hello") + } + + private class NullTransientProperty { + var evalCount = 0 + val transientValue by transient { + evalCount++ + null + } + } + + @CordaSerializable + private class NonCapturingTransientProperty { + val transientVal by transient { random63BitValue() } + } + + @CordaSerializable + private class CapturingTransientProperty(prefix: String) { + private val seed = random63BitValue() + val transientVal by transient { prefix + seed + random63BitValue() } + } +} \ No newline at end of file diff --git a/samples/irs-demo/src/main/kotlin/net/corda/irs/flows/FixingFlow.kt b/samples/irs-demo/src/main/kotlin/net/corda/irs/flows/FixingFlow.kt index 57110870aa..9b8a11cd96 100644 --- a/samples/irs-demo/src/main/kotlin/net/corda/irs/flows/FixingFlow.kt +++ b/samples/irs-demo/src/main/kotlin/net/corda/irs/flows/FixingFlow.kt @@ -3,7 +3,6 @@ package net.corda.irs.flows import co.paralleluniverse.fibers.Suspendable import net.corda.contracts.Fix import net.corda.contracts.FixableDealState -import net.corda.core.TransientProperty import net.corda.core.contracts.* import net.corda.core.crypto.toBase58String import net.corda.core.flows.FlowLogic @@ -17,6 +16,7 @@ import net.corda.core.seconds import net.corda.core.serialization.CordaSerializable import net.corda.core.transactions.SignedTransaction import net.corda.core.transactions.TransactionBuilder +import net.corda.core.utilities.transient import net.corda.core.utilities.ProgressTracker import net.corda.core.utilities.trace import net.corda.flows.TwoPartyDealFlow @@ -103,7 +103,7 @@ object FixingFlow { override val progressTracker: ProgressTracker = TwoPartyDealFlow.Primary.tracker()) : TwoPartyDealFlow.Primary() { @Suppress("UNCHECKED_CAST") - internal val dealToFix: StateAndRef by TransientProperty { + internal val dealToFix: StateAndRef by transient { val state = serviceHub.loadState(payload.ref) as TransactionState StateAndRef(state, payload.ref) }