Merge pull request #1402 from corda/corda/os-merge-20-09-2018

Corda OS merge 20.09.2018
This commit is contained in:
szymonsztuka 2018-09-21 10:10:35 +01:00 committed by GitHub
commit ce9fe71e35
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
90 changed files with 1670 additions and 607 deletions

View File

@ -4448,8 +4448,6 @@ public interface net.corda.core.serialization.SerializationCustomSerializer
public abstract PROXY toProxy(OBJ)
##
public final class net.corda.core.serialization.SerializationDefaults extends java.lang.Object
@NotNull
public final net.corda.core.serialization.SerializationContext getCHECKPOINT_CONTEXT()
@NotNull
public final net.corda.core.serialization.SerializationContext getP2P_CONTEXT()
@NotNull
@ -6886,8 +6884,6 @@ public final class net.corda.testing.core.SerializationEnvironmentRule extends j
@NotNull
public org.junit.runners.model.Statement apply(org.junit.runners.model.Statement, org.junit.runner.Description)
@NotNull
public final net.corda.core.serialization.SerializationContext getCheckpointContext()
@NotNull
public final net.corda.core.serialization.SerializationFactory getSerializationFactory()
public static final net.corda.testing.core.SerializationEnvironmentRule$Companion Companion
##

View File

@ -91,6 +91,7 @@ see changes to this list.
* Ivan Schasny (R3)
* James Brown (R3)
* James Carlyle (R3)
* Janis Olekss (Accenture)
* Jared Harwayne-Gidansky (BNY Mellon)
* Jayavaradhan Sambedu (Société Générale)
* Joel Dudley (R3)
@ -137,6 +138,7 @@ see changes to this list.
* Mike Hearn (R3)
* Mike Ward (R3)
* Mike Reichelt (US Bank)
* Milen Dobrinov (Industria)
* Mohamed Amine LEGHERABA
* Mustafa Ozturk (Natixis)
* Nick Skinner (Northern Trust)
@ -145,6 +147,7 @@ see changes to this list.
* Nuam Athaweth (MUFG)
* Oscar Zibordi de Paiva (Scopus Soluções em TI)
* OP Financial
* Parnika Sharma (BCS Technology)
* Patrick Kuo (R3)
* Pekka Kaipio (OP Financial)
* Phillip Griffin
@ -176,6 +179,7 @@ see changes to this list.
* Scott James
* Sean Zhang (Wells Fargo)
* Shams Asari (R3)
* Shivan Sawant (Persistent Systems Limited)
* Siddhartha Sengupta (Tradewind Markets)
* Simon Taylor (Barclays)
* Sofus Mortensen (Digital Asset Holdings)

View File

@ -82,6 +82,7 @@ buildscript {
// Update 121 is required for ObjectInputFilter.
// Updates [131, 161] also have zip compression bugs on MacOS (High Sierra).
// when the java version in NodeStartup.hasMinimumJavaVersion() changes, so must this check
ext.java8_minUpdateVersion = '171'
repositories {

View File

@ -197,7 +197,7 @@ class JacksonSupportTest(@Suppress("unused") private val name: String, factory:
fun DigitalSignatureWithCert() {
val digitalSignature = DigitalSignatureWithCert(MINI_CORP.identity.certificate, secureRandomBytes(128))
val json = mapper.valueToTree<ObjectNode>(digitalSignature)
val (by, bytes) = json.assertHasOnlyFields("by", "bytes")
val (by, bytes) = json.assertHasOnlyFields("by", "bytes", "parentCertsChain")
assertThat(by.valueAs<X509Certificate>(mapper)).isEqualTo(MINI_CORP.identity.certificate)
assertThat(bytes.binaryValue()).isEqualTo(digitalSignature.bytes)
assertThat(mapper.convertValue<DigitalSignatureWithCert>(json)).isEqualTo(digitalSignature)

View File

@ -3,9 +3,13 @@ package net.corda.client.rpc
import net.corda.core.CordaRuntimeException
/**
* Thrown to indicate a fatal error in the RPC system itself, as opposed to an error generated by the invoked
* method.
* Thrown to indicate a fatal error in the RPC system itself, as opposed to an error generated by the invoked method.
*/
open class RPCException(message: String?, cause: Throwable?) : CordaRuntimeException(message, cause) {
constructor(msg: String) : this(msg, null)
}
/**
* Signals that the underlying [RPCConnection] dropped.
*/
open class ConnectionFailureException(cause: Throwable? = null) : RPCException("Connection failure detected.", cause)

View File

@ -2,11 +2,8 @@ package net.corda.client.rpc.internal
import net.corda.client.rpc.CordaRPCClient
import net.corda.client.rpc.CordaRPCClientConfiguration
import net.corda.core.messaging.CordaRPCOps
import net.corda.core.messaging.pendingFlowsCount
import net.corda.core.utilities.NetworkHostAndPort
import net.corda.core.messaging.ClientRpcSslOptions
import rx.Observable
/** Utility which exposes the internal Corda RPC constructor to other internal Corda components */
fun createCordaRPCClientWithSslAndClassLoader(
@ -14,14 +11,4 @@ fun createCordaRPCClientWithSslAndClassLoader(
configuration: CordaRPCClientConfiguration = CordaRPCClientConfiguration.DEFAULT,
sslConfiguration: ClientRpcSslOptions? = null,
classLoader: ClassLoader? = null
) = CordaRPCClient.createWithSslAndClassLoader(hostAndPort, configuration, sslConfiguration, classLoader)
fun CordaRPCOps.drainAndShutdown(): Observable<Unit> {
setFlowsDrainingModeEnabled(true)
return pendingFlowsCount().updates
.doOnError { error ->
throw error
}
.doOnCompleted { shutdown() }.map { }
}
) = CordaRPCClient.createWithSslAndClassLoader(hostAndPort, configuration, sslConfiguration, classLoader)

View File

@ -7,6 +7,7 @@ import com.github.benmanes.caffeine.cache.RemovalCause
import com.github.benmanes.caffeine.cache.RemovalListener
import com.google.common.util.concurrent.SettableFuture
import com.google.common.util.concurrent.ThreadFactoryBuilder
import net.corda.client.rpc.ConnectionFailureException
import net.corda.client.rpc.CordaRPCClientConfiguration
import net.corda.client.rpc.RPCException
import net.corda.client.rpc.RPCSinceVersion
@ -552,7 +553,7 @@ class RPCClientProxyHandler(
m.keys.forEach { k ->
observationExecutorPool.run(k) {
try {
m[k]?.onError(RPCException("Connection failure detected."))
m[k]?.onError(ConnectionFailureException())
} catch (th: Throwable) {
log.error("Unexpected exception when RPC connection failure handling", th)
}
@ -561,7 +562,7 @@ class RPCClientProxyHandler(
observableContext.observableMap.invalidateAll()
rpcReplyMap.forEach { _, replyFuture ->
replyFuture.setException(RPCException("Connection failure detected."))
replyFuture.setException(ConnectionFailureException())
}
rpcReplyMap.clear()

View File

@ -5,7 +5,7 @@ kotlinVersion=1.2.51
platformVersion=4
guavaVersion=25.1-jre
proguardVersion=6.0.3
bouncycastleVersion=1.57
bouncycastleVersion=1.60
typesafeConfigVersion=1.3.1
jsr305Version=3.0.2
artifactoryPluginVersion=4.7.3

View File

@ -52,6 +52,7 @@ task patchCore(type: Zip, dependsOn: coreJarTask) {
exclude 'net/corda/core/crypto/SHA256DigestSupplier.class'
exclude 'net/corda/core/internal/*ToggleField*.class'
exclude 'net/corda/core/serialization/*SerializationFactory*.class'
exclude 'net/corda/core/serialization/internal/CheckpointSerializationFactory*.class'
exclude 'net/corda/core/utilities/SgxSupport*.class'
}

View File

@ -0,0 +1,74 @@
package net.corda.core.serialization.internal
import net.corda.core.KeepForDJVM
import net.corda.core.serialization.SerializedBytes
import net.corda.core.utilities.ByteSequence
import java.io.NotSerializableException
/**
* A deterministic version of [CheckpointSerializationFactory] that does not use thread-locals to manage serialization
* context.
*/
@KeepForDJVM
class CheckpointSerializationFactory(
private val scheme: CheckpointSerializationScheme
) {
val defaultContext: CheckpointSerializationContext get() = _currentContext ?: effectiveSerializationEnv.checkpointContext
private val creator: List<StackTraceElement> = Exception().stackTrace.asList()
/**
* Deserialize the bytes in to an object, using the prefixed bytes to determine the format.
*
* @param byteSequence The bytes to deserialize, including a format header prefix.
* @param clazz The class or superclass or the object to be deserialized, or [Any] or [Object] if unknown.
* @param context A context that configures various parameters to deserialization.
*/
@Throws(NotSerializableException::class)
fun <T : Any> deserialize(byteSequence: ByteSequence, clazz: Class<T>, context: CheckpointSerializationContext): T {
return withCurrentContext(context) { scheme.deserialize(byteSequence, clazz, context) }
}
/**
* Serialize an object to bytes using the preferred serialization format version from the context.
*
* @param obj The object to be serialized.
* @param context A context that configures various parameters to serialization, including the serialization format version.
*/
fun <T : Any> serialize(obj: T, context: CheckpointSerializationContext): SerializedBytes<T> {
return withCurrentContext(context) { scheme.serialize(obj, context) }
}
override fun toString(): String {
return "${this.javaClass.name} scheme=$scheme ${creator.joinToString("\n")}"
}
override fun equals(other: Any?): Boolean {
return other is CheckpointSerializationFactory && other.scheme == this.scheme
}
override fun hashCode(): Int = scheme.hashCode()
private var _currentContext: CheckpointSerializationContext? = null
/**
* Change the current context inside the block to that supplied.
*/
fun <T> withCurrentContext(context: CheckpointSerializationContext?, block: () -> T): T {
val priorContext = _currentContext
if (context != null) _currentContext = context
try {
return block()
} finally {
if (context != null) _currentContext = priorContext
}
}
companion object {
/**
* A default factory for serialization/deserialization.
*/
val defaultFactory: CheckpointSerializationFactory get() = effectiveSerializationEnv.checkpointSerializationFactory
}
}

View File

@ -106,7 +106,11 @@ fun PublicKey.isValid(content: ByteArray, signature: DigitalSignature): Boolean
/** Render a public key to its hash (in Base58) of its serialised form using the DL prefix. */
fun PublicKey.toStringShort(): String = "DL" + this.toSHA256Bytes().toBase58()
/** Return a [Set] of the contained keys if this is a [CompositeKey]; otherwise, return a [Set] with a single element (this [PublicKey]). */
/**
* Return a [Set] of the contained leaf keys if this is a [CompositeKey].
* Otherwise, return a [Set] with a single element (this [PublicKey]).
* <i>Note that leaf keys cannot be of type [CompositeKey].</i>
*/
val PublicKey.keys: Set<PublicKey> get() = (this as? CompositeKey)?.leafKeys ?: setOf(this)
/** Return true if [otherKey] fulfils the requirements of this [PublicKey]. */
@ -115,7 +119,12 @@ fun PublicKey.isFulfilledBy(otherKey: PublicKey): Boolean = isFulfilledBy(setOf(
/** Return true if [otherKeys] fulfil the requirements of this [PublicKey]. */
fun PublicKey.isFulfilledBy(otherKeys: Iterable<PublicKey>): Boolean = (this as? CompositeKey)?.isFulfilledBy(otherKeys) ?: (this in otherKeys)
/** Checks whether any of the given [keys] matches a leaf on the [CompositeKey] tree or a single [PublicKey]. */
/**
* Checks whether any of the given [keys] matches a leaf on the [CompositeKey] tree or a single [PublicKey].
*
* <i>Note that this function checks against leaves, which cannot be of type [CompositeKey]. Due to that, if any of the
* [otherKeys] is a [CompositeKey], this function will not find a match.</i>
*/
fun PublicKey.containsAny(otherKeys: Iterable<PublicKey>): Boolean {
return if (this is CompositeKey) keys.intersect(otherKeys).isNotEmpty()
else this in otherKeys

View File

@ -4,26 +4,42 @@ import net.corda.core.crypto.DigitalSignature
import net.corda.core.crypto.SignedData
import net.corda.core.crypto.verify
import net.corda.core.serialization.CordaSerializable
import net.corda.core.serialization.DeprecatedConstructorForDeserialization
import net.corda.core.serialization.SerializedBytes
import net.corda.core.serialization.deserialize
import net.corda.core.utilities.OpaqueBytes
import java.security.cert.CertPath
import java.security.cert.X509Certificate
import java.security.cert.*
// TODO: Rename this to DigitalSignature.WithCert once we're happy for it to be public API. The methods will need documentation
// and the correct exceptions will be need to be annotated
/** A digital signature with attached certificate of the public key. */
open class DigitalSignatureWithCert(val by: X509Certificate, bytes: ByteArray) : DigitalSignature(bytes) {
/** A digital signature with attached certificate of the public key and (optionally) the remaining chain of the certificates from the certificate path. */
class DigitalSignatureWithCert(val by: X509Certificate, val parentCertsChain: List<X509Certificate>, bytes: ByteArray) : DigitalSignature(bytes) {
@DeprecatedConstructorForDeserialization(1)
constructor(by: X509Certificate, bytes: ByteArray) : this(by, emptyList(), bytes)
val fullCertChain: List<X509Certificate> get() = listOf(by) + parentCertsChain
val fullCertPath: CertPath get() = CertificateFactory.getInstance("X.509").generateCertPath(fullCertChain)
fun verify(content: ByteArray): Boolean = by.publicKey.verify(content, this)
fun verify(content: OpaqueBytes): Boolean = verify(content.bytes)
}
/**
* A digital signature with attached certificate path. The first certificate in the path corresponds to the data signer key.
* @param path certificate path associated with this signature
* @param bytes signature bytes
*/
class DigitalSignatureWithCertPath(val path: List<X509Certificate>, bytes: ByteArray): DigitalSignatureWithCert(path.first(), bytes)
init {
if (parentCertsChain.isNotEmpty()) {
val parameters = PKIXParameters(setOf(TrustAnchor(parentCertsChain.last(), null))).apply { isRevocationEnabled = false }
try {
CertPathValidator.getInstance("PKIX").validate(fullCertPath, parameters)
} catch (e: CertPathValidatorException) {
throw IllegalArgumentException(
"""Cert path failed to validate.
Reason: ${e.reason}
Offending cert index: ${e.index}
Cert path: $fullCertPath
""", e)
}
}
}
}
/** Similar to [SignedData] but instead of just attaching the public key, the certificate for the key is attached instead. */
@CordaSerializable

View File

@ -22,7 +22,6 @@ import net.corda.core.serialization.CordaSerializable
import net.corda.core.transactions.SignedTransaction
import net.corda.core.utilities.Try
import rx.Observable
import rx.subjects.PublishSubject
import java.io.IOException
import java.io.InputStream
import java.security.PublicKey
@ -405,38 +404,20 @@ interface CordaRPCOps : RPCOps {
* This does not wait for flows to be completed.
*/
fun shutdown()
}
/**
* Returns a [DataFeed] that keeps track on the count of pending flows.
*/
fun CordaRPCOps.pendingFlowsCount(): DataFeed<Int, Pair<Int, Int>> {
/**
* Shuts the node down. Returns immediately.
* @param drainPendingFlows whether the node will wait for pending flows to be completed before exiting. While draining, new flows from RPC will be rejected.
*/
fun terminate(drainPendingFlows: Boolean = false)
val stateMachineState = stateMachinesFeed()
var pendingFlowsCount = stateMachineState.snapshot.size
var completedFlowsCount = 0
val updates = PublishSubject.create<Pair<Int, Int>>()
stateMachineState
.updates
.doOnNext { update ->
when (update) {
is StateMachineUpdate.Added -> {
pendingFlowsCount++
updates.onNext(completedFlowsCount to pendingFlowsCount)
}
is StateMachineUpdate.Removed -> {
completedFlowsCount++
updates.onNext(completedFlowsCount to pendingFlowsCount)
if (completedFlowsCount == pendingFlowsCount) {
updates.onCompleted()
}
}
}
}.subscribe()
if (pendingFlowsCount == 0) {
updates.onCompleted()
}
return DataFeed(pendingFlowsCount, updates)
/**
* Returns whether the node is waiting for pending flows to complete before shutting down.
* Disabling draining mode cancels this state.
*
* @return whether the node will shutdown when the pending flows count reaches zero.
*/
fun isWaitingForShutdown(): Boolean
}
inline fun <reified T : ContractState> CordaRPCOps.vaultQueryBy(criteria: QueryCriteria = QueryCriteria.VaultQueryCriteria(),

View File

@ -160,7 +160,7 @@ class Vault<out T : ContractState>(val states: Iterable<StateAndRef<T>>) {
val notary: AbstractParty?,
val lockId: String?,
val lockUpdateTime: Instant?,
val isRelevant: Vault.RelevancyStatus?
val relevancyStatus: Vault.RelevancyStatus?
) {
constructor(ref: StateRef,
contractStateClassName: String,

View File

@ -73,7 +73,7 @@ sealed class QueryCriteria : GenericQueryCriteria<QueryCriteria, IQueryCriteriaP
abstract class CommonQueryCriteria : QueryCriteria() {
abstract val status: Vault.StateStatus
open val isRelevant: Vault.RelevancyStatus = Vault.RelevancyStatus.ALL
open val relevancyStatus: Vault.RelevancyStatus = Vault.RelevancyStatus.ALL
abstract val contractStateTypes: Set<Class<out ContractState>>?
override fun visit(parser: IQueryCriteriaParser): Collection<Predicate> {
return parser.parseCriteria(this)
@ -90,7 +90,7 @@ sealed class QueryCriteria : GenericQueryCriteria<QueryCriteria, IQueryCriteriaP
val notary: List<AbstractParty>? = null,
val softLockingCondition: SoftLockingCondition? = null,
val timeCondition: TimeCondition? = null,
override val isRelevant: Vault.RelevancyStatus = Vault.RelevancyStatus.ALL
override val relevancyStatus: Vault.RelevancyStatus = Vault.RelevancyStatus.ALL
) : CommonQueryCriteria() {
override fun visit(parser: IQueryCriteriaParser): Collection<Predicate> {
super.visit(parser)
@ -125,15 +125,15 @@ sealed class QueryCriteria : GenericQueryCriteria<QueryCriteria, IQueryCriteriaP
val externalId: List<String>? = null,
override val status: Vault.StateStatus = Vault.StateStatus.UNCONSUMED,
override val contractStateTypes: Set<Class<out ContractState>>? = null,
override val isRelevant: Vault.RelevancyStatus = Vault.RelevancyStatus.ALL
override val relevancyStatus: Vault.RelevancyStatus = Vault.RelevancyStatus.ALL
) : CommonQueryCriteria() {
constructor(
participants: List<AbstractParty>? = null,
linearId: List<UniqueIdentifier>? = null,
status: Vault.StateStatus = Vault.StateStatus.UNCONSUMED,
contractStateTypes: Set<Class<out ContractState>>? = null,
isRelevant: Vault.RelevancyStatus
) : this(participants, linearId?.map { it.id }, linearId?.mapNotNull { it.externalId }, status, contractStateTypes, isRelevant)
relevancyStatus: Vault.RelevancyStatus
) : this(participants, linearId?.map { it.id }, linearId?.mapNotNull { it.externalId }, status, contractStateTypes, relevancyStatus)
constructor(
participants: List<AbstractParty>? = null,
@ -175,7 +175,7 @@ sealed class QueryCriteria : GenericQueryCriteria<QueryCriteria, IQueryCriteriaP
val issuerRef: List<OpaqueBytes>? = null,
override val status: Vault.StateStatus = Vault.StateStatus.UNCONSUMED,
override val contractStateTypes: Set<Class<out ContractState>>? = null,
override val isRelevant: Vault.RelevancyStatus = Vault.RelevancyStatus.ALL
override val relevancyStatus: Vault.RelevancyStatus = Vault.RelevancyStatus.ALL
) : CommonQueryCriteria() {
override fun visit(parser: IQueryCriteriaParser): Collection<Predicate> {
super.visit(parser)
@ -215,7 +215,7 @@ sealed class QueryCriteria : GenericQueryCriteria<QueryCriteria, IQueryCriteriaP
val expression: CriteriaExpression<L, Boolean>,
override val status: Vault.StateStatus = Vault.StateStatus.UNCONSUMED,
override val contractStateTypes: Set<Class<out ContractState>>? = null,
override val isRelevant: Vault.RelevancyStatus = Vault.RelevancyStatus.ALL
override val relevancyStatus: Vault.RelevancyStatus = Vault.RelevancyStatus.ALL
) : CommonQueryCriteria() {
override fun visit(parser: IQueryCriteriaParser): Collection<Predicate> {
super.visit(parser)

View File

@ -207,7 +207,13 @@ interface SerializationContext {
* The use case that we are serializing for, since it influences the implementations chosen.
*/
@KeepForDJVM
enum class UseCase { P2P, RPCServer, RPCClient, Storage, Checkpoint, Testing }
enum class UseCase {
P2P,
RPCServer,
RPCClient,
Storage,
Testing
}
}
/**
@ -230,7 +236,6 @@ object SerializationDefaults {
@DeleteForDJVM val RPC_SERVER_CONTEXT get() = effectiveSerializationEnv.rpcServerContext
@DeleteForDJVM val RPC_CLIENT_CONTEXT get() = effectiveSerializationEnv.rpcClientContext
@DeleteForDJVM val STORAGE_CONTEXT get() = effectiveSerializationEnv.storageContext
@DeleteForDJVM val CHECKPOINT_CONTEXT get() = effectiveSerializationEnv.checkpointContext
}
/**

View File

@ -0,0 +1,198 @@
package net.corda.core.serialization.internal
import net.corda.core.DeleteForDJVM
import net.corda.core.DoNotImplement
import net.corda.core.KeepForDJVM
import net.corda.core.crypto.SecureHash
import net.corda.core.serialization.*
import net.corda.core.utilities.ByteSequence
import net.corda.core.utilities.sequence
import java.io.NotSerializableException
object CheckpointSerializationDefaults {
@DeleteForDJVM
val CHECKPOINT_CONTEXT get() = effectiveSerializationEnv.checkpointContext
val CHECKPOINT_SERIALIZATION_FACTORY get() = effectiveSerializationEnv.checkpointSerializationFactory
}
/**
* A class for serializing and deserializing objects at checkpoints, using Kryo serialization.
*/
@KeepForDJVM
class CheckpointSerializationFactory(
private val scheme: CheckpointSerializationScheme
) {
val defaultContext: CheckpointSerializationContext get() = _currentContext.get() ?: effectiveSerializationEnv.checkpointContext
private val creator: List<StackTraceElement> = Exception().stackTrace.asList()
/**
* Deserialize the bytes in to an object, using the prefixed bytes to determine the format.
*
* @param byteSequence The bytes to deserialize, including a format header prefix.
* @param clazz The class or superclass or the object to be deserialized, or [Any] or [Object] if unknown.
* @param context A context that configures various parameters to deserialization.
*/
fun <T : Any> deserialize(byteSequence: ByteSequence, clazz: Class<T>, context: CheckpointSerializationContext): T {
return withCurrentContext(context) { scheme.deserialize(byteSequence, clazz, context) }
}
/**
* Serialize an object to bytes using the preferred serialization format version from the context.
*
* @param obj The object to be serialized.
* @param context A context that configures various parameters to serialization, including the serialization format version.
*/
fun <T : Any> serialize(obj: T, context: CheckpointSerializationContext): SerializedBytes<T> {
return withCurrentContext(context) { scheme.serialize(obj, context) }
}
override fun toString(): String {
return "${this.javaClass.name} scheme=$scheme ${creator.joinToString("\n")}"
}
override fun equals(other: Any?): Boolean {
return other is CheckpointSerializationFactory && other.scheme == this.scheme
}
override fun hashCode(): Int = scheme.hashCode()
private val _currentContext = ThreadLocal<CheckpointSerializationContext?>()
/**
* Change the current context inside the block to that supplied.
*/
fun <T> withCurrentContext(context: CheckpointSerializationContext?, block: () -> T): T {
val priorContext = _currentContext.get()
if (context != null) _currentContext.set(context)
try {
return block()
} finally {
if (context != null) _currentContext.set(priorContext)
}
}
companion object {
val defaultFactory: CheckpointSerializationFactory get() = effectiveSerializationEnv.checkpointSerializationFactory
}
}
@KeepForDJVM
@DoNotImplement
interface CheckpointSerializationScheme {
@Throws(NotSerializableException::class)
fun <T : Any> deserialize(byteSequence: ByteSequence, clazz: Class<T>, context: CheckpointSerializationContext): T
@Throws(NotSerializableException::class)
fun <T : Any> serialize(obj: T, context: CheckpointSerializationContext): SerializedBytes<T>
}
/**
* Parameters to checkpoint serialization and deserialization.
*/
@KeepForDJVM
@DoNotImplement
interface CheckpointSerializationContext {
/**
* If non-null, apply this encoding (typically compression) when serializing.
*/
val encoding: SerializationEncoding?
/**
* The class loader to use for deserialization.
*/
val deserializationClassLoader: ClassLoader
/**
* A whitelist that contains (mostly for security purposes) which classes can be serialized and deserialized.
*/
val whitelist: ClassWhitelist
/**
* A whitelist that determines (mostly for security purposes) whether a particular encoding may be used when deserializing.
*/
val encodingWhitelist: EncodingWhitelist
/**
* A map of any addition properties specific to the particular use case.
*/
val properties: Map<Any, Any>
/**
* Duplicate references to the same object preserved in the wire format and when deserialized when this is true,
* otherwise they appear as new copies of the object.
*/
val objectReferencesEnabled: Boolean
/**
* Helper method to return a new context based on this context with the property added.
*/
fun withProperty(property: Any, value: Any): CheckpointSerializationContext
/**
* Helper method to return a new context based on this context with object references disabled.
*/
fun withoutReferences(): CheckpointSerializationContext
/**
* Helper method to return a new context based on this context with the deserialization class loader changed.
*/
fun withClassLoader(classLoader: ClassLoader): CheckpointSerializationContext
/**
* Helper method to return a new context based on this context with the appropriate class loader constructed from the passed attachment identifiers.
* (Requires the attachment storage to have been enabled).
*/
@Throws(MissingAttachmentsException::class)
fun withAttachmentsClassLoader(attachmentHashes: List<SecureHash>): CheckpointSerializationContext
/**
* Helper method to return a new context based on this context with the given class specifically whitelisted.
*/
fun withWhitelisted(clazz: Class<*>): CheckpointSerializationContext
/**
* A shallow copy of this context but with the given (possibly null) encoding.
*/
fun withEncoding(encoding: SerializationEncoding?): CheckpointSerializationContext
/**
* A shallow copy of this context but with the given encoding whitelist.
*/
fun withEncodingWhitelist(encodingWhitelist: EncodingWhitelist): CheckpointSerializationContext
}
/*
* The following extension methods are disambiguated from the AMQP-serialization methods by requiring that an
* explicit [CheckpointSerializationContext] parameter be provided.
*/
/*
* Convenience extension method for deserializing a ByteSequence, utilising the default factory.
*/
inline fun <reified T : Any> ByteSequence.checkpointDeserialize(serializationFactory: CheckpointSerializationFactory = CheckpointSerializationFactory.defaultFactory,
context: CheckpointSerializationContext): T {
return serializationFactory.deserialize(this, T::class.java, context)
}
/**
* Convenience extension method for deserializing SerializedBytes with type matching, utilising the default factory.
*/
inline fun <reified T : Any> SerializedBytes<T>.checkpointDeserialize(serializationFactory: CheckpointSerializationFactory = CheckpointSerializationFactory.defaultFactory,
context: CheckpointSerializationContext): T {
return serializationFactory.deserialize(this, T::class.java, context)
}
/**
* Convenience extension method for deserializing a ByteArray, utilising the default factory.
*/
inline fun <reified T : Any> ByteArray.checkpointDeserialize(serializationFactory: CheckpointSerializationFactory = CheckpointSerializationFactory.defaultFactory,
context: CheckpointSerializationContext): T {
require(isNotEmpty()) { "Empty bytes" }
return this.sequence().checkpointDeserialize(serializationFactory, context)
}
/**
* Convenience extension method for serializing an object of type T, utilising the default factory.
*/
fun <T : Any> T.checkpointSerialize(serializationFactory: CheckpointSerializationFactory = CheckpointSerializationFactory.defaultFactory,
context: CheckpointSerializationContext): SerializedBytes<T> {
return serializationFactory.serialize(this, context)
}

View File

@ -12,11 +12,12 @@ import net.corda.core.serialization.SerializationFactory
@KeepForDJVM
interface SerializationEnvironment {
val serializationFactory: SerializationFactory
val checkpointSerializationFactory: CheckpointSerializationFactory
val p2pContext: SerializationContext
val rpcServerContext: SerializationContext
val rpcClientContext: SerializationContext
val storageContext: SerializationContext
val checkpointContext: SerializationContext
val checkpointContext: CheckpointSerializationContext
}
@KeepForDJVM
@ -26,18 +27,21 @@ open class SerializationEnvironmentImpl(
rpcServerContext: SerializationContext? = null,
rpcClientContext: SerializationContext? = null,
storageContext: SerializationContext? = null,
checkpointContext: SerializationContext? = null) : SerializationEnvironment {
checkpointContext: CheckpointSerializationContext? = null,
checkpointSerializationFactory: CheckpointSerializationFactory? = null) : SerializationEnvironment {
// Those that are passed in as null are never inited:
override lateinit var rpcServerContext: SerializationContext
override lateinit var rpcClientContext: SerializationContext
override lateinit var storageContext: SerializationContext
override lateinit var checkpointContext: SerializationContext
override lateinit var checkpointContext: CheckpointSerializationContext
override lateinit var checkpointSerializationFactory: CheckpointSerializationFactory
init {
rpcServerContext?.let { this.rpcServerContext = it }
rpcClientContext?.let { this.rpcClientContext = it }
storageContext?.let { this.storageContext = it }
checkpointContext?.let { this.checkpointContext = it }
checkpointSerializationFactory?.let { this.checkpointSerializationFactory = it }
}
}

View File

@ -1,11 +1,14 @@
package net.corda.core.flows;
import net.corda.core.serialization.internal.CheckpointSerializationDefaults;
import net.corda.core.serialization.internal.CheckpointSerializationFactory;
import net.corda.core.serialization.SerializationDefaults;
import net.corda.core.serialization.SerializationFactory;
import net.corda.testing.core.SerializationEnvironmentRule;
import org.junit.Rule;
import org.junit.Test;
import static net.corda.core.serialization.internal.CheckpointSerializationAPIKt.checkpointSerialize;
import static net.corda.core.serialization.SerializationAPIKt.serialize;
import static org.junit.Assert.assertNull;
@ -28,10 +31,13 @@ public class SerializationApiInJavaTest {
public void enforceSerializationDefaultsApi() {
SerializationDefaults defaults = SerializationDefaults.INSTANCE;
SerializationFactory factory = defaults.getSERIALIZATION_FACTORY();
CheckpointSerializationDefaults checkpointDefaults = CheckpointSerializationDefaults.INSTANCE;
CheckpointSerializationFactory checkpointSerializationFactory = checkpointDefaults.getCHECKPOINT_SERIALIZATION_FACTORY();
serialize("hello", factory, defaults.getP2P_CONTEXT());
serialize("hello", factory, defaults.getRPC_SERVER_CONTEXT());
serialize("hello", factory, defaults.getRPC_CLIENT_CONTEXT());
serialize("hello", factory, defaults.getSTORAGE_CONTEXT());
serialize("hello", factory, defaults.getCHECKPOINT_CONTEXT());
checkpointSerialize("hello", checkpointSerializationFactory, checkpointDefaults.getCHECKPOINT_CONTEXT());
}
}

View File

@ -105,7 +105,7 @@ internal class UseRefState(val linearId: UniqueIdentifier) : FlowLogic<SignedTra
val notary = serviceHub.networkMapCache.notaryIdentities.first()
val query = QueryCriteria.LinearStateQueryCriteria(
linearId = listOf(linearId),
isRelevant = Vault.RelevancyStatus.ALL
relevancyStatus = Vault.RelevancyStatus.ALL
)
val referenceState = serviceHub.vaultService.queryBy<ContractState>(query).states.single()
return subFlow(FinalityFlow(

View File

@ -3,9 +3,10 @@ package net.corda.core.utilities
import com.esotericsoftware.kryo.KryoException
import net.corda.core.crypto.random63BitValue
import net.corda.core.serialization.*
import net.corda.core.serialization.internal.checkpointDeserialize
import net.corda.core.serialization.internal.checkpointSerialize
import net.corda.node.serialization.kryo.KRYO_CHECKPOINT_CONTEXT
import net.corda.node.serialization.kryo.kryoMagic
import net.corda.serialization.internal.SerializationContextImpl
import net.corda.serialization.internal.CheckpointSerializationContextImpl
import net.corda.testing.core.SerializationEnvironmentRule
import org.assertj.core.api.Assertions.assertThat
import org.junit.Rule
@ -24,12 +25,11 @@ class KotlinUtilsTest {
@Rule
val expectedEx: ExpectedException = ExpectedException.none()
private val KRYO_CHECKPOINT_NOWHITELIST_CONTEXT = SerializationContextImpl(kryoMagic,
private val KRYO_CHECKPOINT_NOWHITELIST_CONTEXT = CheckpointSerializationContextImpl(
javaClass.classLoader,
EmptyWhitelist,
emptyMap(),
true,
SerializationContext.UseCase.Checkpoint,
null)
@Test
@ -44,7 +44,7 @@ class KotlinUtilsTest {
fun `checkpointing a transient property with non-capturing lambda`() {
val original = NonCapturingTransientProperty()
val originalVal = original.transientVal
val copy = original.serialize(context = KRYO_CHECKPOINT_CONTEXT).deserialize(context = KRYO_CHECKPOINT_CONTEXT)
val copy = original.checkpointSerialize(context = KRYO_CHECKPOINT_CONTEXT).checkpointDeserialize(context = KRYO_CHECKPOINT_CONTEXT)
val copyVal = copy.transientVal
assertThat(copyVal).isNotEqualTo(originalVal)
assertThat(copy.transientVal).isEqualTo(copyVal)
@ -55,15 +55,14 @@ class KotlinUtilsTest {
expectedEx.expect(KryoException::class.java)
expectedEx.expectMessage("is not annotated or on the whitelist, so cannot be used in serialization")
val original = NonCapturingTransientProperty()
original.serialize(context = KRYO_CHECKPOINT_CONTEXT.withEncoding(null))
.deserialize(context = KRYO_CHECKPOINT_NOWHITELIST_CONTEXT)
original.checkpointSerialize(context = KRYO_CHECKPOINT_CONTEXT.withEncoding(null)).checkpointDeserialize(context = KRYO_CHECKPOINT_NOWHITELIST_CONTEXT)
}
@Test
fun `checkpointing a transient property with capturing lambda`() {
val original = CapturingTransientProperty("Hello")
val originalVal = original.transientVal
val copy = original.serialize(context = KRYO_CHECKPOINT_CONTEXT).deserialize(context = KRYO_CHECKPOINT_CONTEXT)
val copy = original.checkpointSerialize(context = KRYO_CHECKPOINT_CONTEXT).checkpointDeserialize(context = KRYO_CHECKPOINT_CONTEXT)
val copyVal = copy.transientVal
assertThat(copyVal).isNotEqualTo(originalVal)
assertThat(copy.transientVal).isEqualTo(copyVal)
@ -76,8 +75,8 @@ class KotlinUtilsTest {
expectedEx.expectMessage("is not annotated or on the whitelist, so cannot be used in serialization")
val original = CapturingTransientProperty("Hello")
original.serialize(context = KRYO_CHECKPOINT_CONTEXT.withEncoding(null))
.deserialize(context = KRYO_CHECKPOINT_NOWHITELIST_CONTEXT)
original.checkpointSerialize(context = KRYO_CHECKPOINT_CONTEXT.withEncoding(null)).checkpointDeserialize(context = KRYO_CHECKPOINT_NOWHITELIST_CONTEXT)
}
private class NullTransientProperty {

View File

@ -22,6 +22,7 @@ import net.corda.core.node.NodeInfo
import net.corda.core.node.services.CordaService
import net.corda.core.serialization.SerializeAsToken
import net.corda.core.serialization.SingletonSerializeAsToken
import net.corda.core.serialization.internal.CheckpointSerializationFactory
import net.corda.core.serialization.internal.SerializationEnvironmentImpl
import net.corda.core.serialization.internal.effectiveSerializationEnv
import net.corda.core.serialization.internal.nodeSerializationEnv
@ -37,7 +38,7 @@ import net.corda.node.internal.cordapp.CordappProviderImpl
import net.corda.node.internal.cordapp.JarScanningCordappLoader
import net.corda.node.serialization.amqp.AMQPServerSerializationScheme
import net.corda.node.serialization.kryo.KRYO_CHECKPOINT_CONTEXT
import net.corda.node.serialization.kryo.KryoServerSerializationScheme
import net.corda.node.serialization.kryo.KryoSerializationScheme
import net.corda.node.services.ContractUpgradeHandler
import net.corda.node.services.FinalityHandler
import net.corda.node.services.NotaryChangeHandler
@ -358,8 +359,8 @@ class FlowWorkerServiceHub(override val configuration: NodeConfiguration, overri
SerializationFactoryImpl().apply {
registerScheme(AMQPServerSerializationScheme(cordappLoader.cordapps))
registerScheme(AMQPClientSerializationScheme(cordappLoader.cordapps))
registerScheme(KryoServerSerializationScheme())
},
checkpointSerializationFactory = CheckpointSerializationFactory(KryoSerializationScheme),
p2pContext = AMQP_P2P_CONTEXT.withClassLoader(classloader),
rpcServerContext = AMQP_RPC_SERVER_CONTEXT.withClassLoader(classloader),
storageContext = AMQP_STORAGE_CONTEXT.withClassLoader(classloader),

View File

@ -338,6 +338,14 @@ class CordaRpcWorkerOps(
shutdownNode.invoke()
}
override fun terminate(drainPendingFlows: Boolean) {
TODO("not implemented")
}
override fun isWaitingForShutdown(): Boolean {
TODO("not implemented")
}
private fun stateMachineInfoFromFlowLogic(flowLogic: FlowLogic<*>): StateMachineInfo {
return StateMachineInfo(flowLogic.runId, flowLogic.javaClass.name, flowLogic.stateMachine.context.toFlowInitiator(), flowLogic.track(), flowLogic.stateMachine.context)
}

View File

@ -14,6 +14,7 @@ import net.corda.core.node.services.ContractUpgradeService
import net.corda.core.node.services.TransactionVerifierService
import net.corda.core.serialization.SerializeAsToken
import net.corda.core.serialization.SingletonSerializeAsToken
import net.corda.core.serialization.internal.CheckpointSerializationFactory
import net.corda.core.serialization.internal.SerializationEnvironmentImpl
import net.corda.core.serialization.internal.effectiveSerializationEnv
import net.corda.core.serialization.internal.nodeSerializationEnv
@ -27,7 +28,7 @@ import net.corda.node.internal.cordapp.CordappProviderImpl
import net.corda.node.internal.cordapp.JarScanningCordappLoader
import net.corda.node.serialization.amqp.AMQPServerSerializationScheme
import net.corda.node.serialization.kryo.KRYO_CHECKPOINT_CONTEXT
import net.corda.node.serialization.kryo.KryoServerSerializationScheme
import net.corda.node.serialization.kryo.KryoSerializationScheme
import net.corda.node.services.api.AuditService
import net.corda.node.services.api.MonitoringService
import net.corda.node.services.api.ServiceHubInternal
@ -194,8 +195,8 @@ class RpcWorkerServiceHub(override val configuration: NodeConfiguration, overrid
SerializationFactoryImpl().apply {
registerScheme(AMQPServerSerializationScheme(cordappLoader.cordapps))
registerScheme(AMQPClientSerializationScheme(cordappLoader.cordapps))
registerScheme(KryoServerSerializationScheme())
},
checkpointSerializationFactory = CheckpointSerializationFactory(KryoSerializationScheme),
p2pContext = AMQP_P2P_CONTEXT.withClassLoader(classloader),
rpcServerContext = AMQP_RPC_SERVER_CONTEXT.withClassLoader(classloader),
storageContext = AMQP_STORAGE_CONTEXT.withClassLoader(classloader),

View File

@ -22,7 +22,7 @@ private fun generateCashSumCriteria(currency: Currency): QueryCriteria {
val ccyIndex = builder { CashSchemaV1.PersistentCashState::currency.equal(currency.currencyCode) }
// This query should only return cash states the calling node is a participant of (meaning they can be modified/spent).
val ccyCriteria = QueryCriteria.VaultCustomQueryCriteria(ccyIndex, isRelevant = Vault.RelevancyStatus.RELEVANT)
val ccyCriteria = QueryCriteria.VaultCustomQueryCriteria(ccyIndex, relevancyStatus = Vault.RelevancyStatus.RELEVANT)
return sumCriteria.and(ccyCriteria)
}
@ -32,7 +32,7 @@ private fun generateCashSumsCriteria(): QueryCriteria {
orderBy = Sort.Direction.DESC)
}
// This query should only return cash states the calling node is a participant of (meaning they can be modified/spent).
return QueryCriteria.VaultCustomQueryCriteria(sum, isRelevant = Vault.RelevancyStatus.RELEVANT)
return QueryCriteria.VaultCustomQueryCriteria(sum, relevancyStatus = Vault.RelevancyStatus.RELEVANT)
}
private fun rowsToAmount(currency: Currency, rows: Vault.Page<FungibleAsset<*>>): Amount<Currency> {

View File

@ -40,7 +40,7 @@ class CashSelectionH2Impl : AbstractCashSelection() {
FROM vault_states AS vs, contract_cash_states AS ccs
WHERE vs.transaction_id = ccs.transaction_id AND vs.output_index = ccs.output_index
AND vs.state_status = 0
AND vs.is_relevant = 0
AND vs.relevancy_status = 0
AND ccs.ccy_code = ? and @t < ?
AND (vs.lock_id = ? OR vs.lock_id is null)
""" +

View File

@ -34,7 +34,7 @@ class CashSelectionOracleImpl : AbstractCashSelection(maxRetries = 16, retrySlee
FROM contract_cash_states ccs, vault_states vs
WHERE vs.transaction_id = ccs.transaction_id AND vs.output_index = ccs.output_index
AND vs.state_status = 0
AND vs.is_modifiable = 0
AND vs.relevancy_status = 0
AND ccs.ccy_code = ?
AND (vs.lock_id = ? OR vs.lock_id is null)
"""+

View File

@ -42,7 +42,7 @@ class CashSelectionPostgreSQLImpl : AbstractCashSelection() {
FROM vault_states AS vs, contract_cash_states AS ccs
WHERE vs.transaction_id = ccs.transaction_id AND vs.output_index = ccs.output_index
AND vs.state_status = 0
AND vs.is_relevant = 0
AND vs.relevancy_status = 0
AND ccs.ccy_code = ?
AND (vs.lock_id = ? OR vs.lock_id is null)
""" +

View File

@ -64,7 +64,7 @@ class CashSelectionSQLServerImpl : AbstractCashSelection(maxRetries = 16, retryS
ON vs.transaction_id = ccs.transaction_id AND vs.output_index = ccs.output_index
WHERE
vs.state_status = 0
AND vs.is_relevant = 0
AND vs.relevancy_status = 0
AND ccs.ccy_code = ?
AND (vs.lock_id = ? OR vs.lock_id IS NULL)
"""

View File

@ -0,0 +1,52 @@
package net.corda.nodeapi.internal
import net.corda.core.messaging.CordaRPCOps
import net.corda.core.messaging.DataFeed
import net.corda.core.messaging.StateMachineUpdate
import rx.Observable
import rx.schedulers.Schedulers
import rx.subjects.PublishSubject
import java.util.concurrent.TimeUnit
/**
* Returns a [DataFeed] of the number of pending flows. The [Observable] for the updates will complete the moment all pending flows will have terminated.
*/
fun CordaRPCOps.pendingFlowsCount(): DataFeed<Int, Pair<Int, Int>> {
val updates = PublishSubject.create<Pair<Int, Int>>()
val initialPendingFlowsCount = stateMachinesFeed().let {
var completedFlowsCount = 0
var pendingFlowsCount = it.snapshot.size
it.updates.observeOn(Schedulers.io()).subscribe({ update ->
when (update) {
is StateMachineUpdate.Added -> {
pendingFlowsCount++
updates.onNext(completedFlowsCount to pendingFlowsCount)
}
is StateMachineUpdate.Removed -> {
completedFlowsCount++
updates.onNext(completedFlowsCount to pendingFlowsCount)
if (completedFlowsCount == pendingFlowsCount) {
updates.onCompleted()
}
}
}
}, updates::onError)
if (pendingFlowsCount == 0) {
updates.onCompleted()
}
pendingFlowsCount
}
return DataFeed(initialPendingFlowsCount, updates)
}
/**
* Returns an [Observable] that will complete when the node will have cancelled the draining shutdown hook.
*
* @param interval the value of the polling interval, default is 5.
* @param unit the time unit of the polling interval, default is [TimeUnit.SECONDS].
*/
fun CordaRPCOps.hasCancelledDrainingShutdown(interval: Long = 5, unit: TimeUnit = TimeUnit.SECONDS): Observable<Unit> {
return Observable.interval(interval, unit).map { isWaitingForShutdown() }.takeFirst { waiting -> waiting == false }.map { Unit }
}

View File

@ -147,11 +147,19 @@ internal constructor(private val initSerEnv: Boolean,
}
}
/** Entry point for Cordform */
/** Old Entry point for Cordform
*
* TODO: Remove once the gradle plugins are updated to 4.0.30
*/
fun bootstrap(directory: Path, cordappJars: List<Path>) {
bootstrap(directory, cordappJars, copyCordapps = true, fromCordform = true)
}
/** Entry point for Cordform */
fun bootstrapCordform(directory: Path, cordappJars: List<Path>) {
bootstrap(directory, cordappJars, copyCordapps = false, fromCordform = true)
}
/** Entry point for the tool */
fun bootstrap(directory: Path, copyCordapps: Boolean) {
// Don't accidently include the bootstrapper jar as a CorDapp!

View File

@ -2,8 +2,6 @@ package net.corda.nodeapi.internal.network
import net.corda.core.crypto.SecureHash
import net.corda.core.internal.CertRole
import net.corda.core.internal.DigitalSignatureWithCert
import net.corda.core.internal.DigitalSignatureWithCertPath
import net.corda.core.internal.SignedDataWithCert
import net.corda.core.node.NetworkParameters
import net.corda.core.node.NodeInfo
@ -59,9 +57,10 @@ data class ParametersUpdate(
/** Verify that a Network Map certificate path and its [CertRole] is correct. */
fun <T : Any> SignedDataWithCert<T>.verifiedNetworkMapCert(rootCert: X509Certificate): T {
require(CertRole.extract(sig.by) == CertRole.NETWORK_MAP) { "Incorrect cert role: ${CertRole.extract(sig.by)}" }
val path = when (this.sig) {
is DigitalSignatureWithCertPath -> (sig as DigitalSignatureWithCertPath).path
else -> listOf(sig.by, rootCert)
val path = if (sig.parentCertsChain.isEmpty()) {
listOf(sig.by, rootCert)
} else {
sig.fullCertChain
}
X509Utilities.validateCertificateChain(rootCert, path)
return verified()

View File

@ -1,9 +1,13 @@
package net.corda.nodeapi.internal.crypto
import net.corda.core.crypto.Crypto
import net.corda.core.crypto.*
import net.corda.core.crypto.Crypto.COMPOSITE_KEY
import net.corda.core.crypto.Crypto.ECDSA_SECP256K1_SHA256
import net.corda.core.crypto.Crypto.ECDSA_SECP256R1_SHA256
import net.corda.core.crypto.Crypto.EDDSA_ED25519_SHA512
import net.corda.core.crypto.Crypto.RSA_SHA256
import net.corda.core.crypto.Crypto.SPHINCS256_SHA256
import net.corda.core.crypto.Crypto.generateKeyPair
import net.corda.core.crypto.newSecureRandom
import net.corda.core.identity.CordaX500Name
import net.corda.core.internal.div
import net.corda.core.serialization.SerializationContext
@ -12,6 +16,8 @@ import net.corda.core.serialization.serialize
import net.corda.node.serialization.amqp.AMQPServerSerializationScheme
import net.corda.nodeapi.internal.config.MutualSslConfiguration
import net.corda.nodeapi.internal.createDevNodeCa
import net.corda.nodeapi.internal.crypto.X509Utilities.DEFAULT_IDENTITY_SIGNATURE_SCHEME
import net.corda.nodeapi.internal.crypto.X509Utilities.DEFAULT_TLS_SIGNATURE_SCHEME
import net.corda.nodeapi.internal.protonwrapper.netty.init
import net.corda.nodeapi.internal.registerDevP2pCertificates
import net.corda.nodeapi.internal.registerDevSigningCertificates
@ -24,17 +30,24 @@ import net.corda.testing.core.BOB_NAME
import net.corda.testing.core.TestIdentity
import net.corda.testing.internal.stubs.CertificateStoreStubs
import net.corda.testing.internal.createDevIntermediateCaCertPath
import net.i2p.crypto.eddsa.EdDSAPrivateKey
import org.assertj.core.api.Assertions.assertThat
import org.bouncycastle.asn1.x509.*
import org.bouncycastle.jcajce.provider.asymmetric.rsa.BCRSAPrivateCrtKey
import org.bouncycastle.pqc.jcajce.provider.sphincs.BCSphincs256PrivateKey
import org.junit.Rule
import org.junit.Test
import org.junit.rules.TemporaryFolder
import sun.security.rsa.RSAPrivateCrtKeyImpl
import java.io.DataInputStream
import java.io.DataOutputStream
import java.io.IOException
import java.net.InetAddress
import java.net.InetSocketAddress
import java.nio.file.Path
import java.security.Key
import java.security.KeyPair
import java.security.PrivateKey
import java.security.cert.CertPath
import java.security.cert.X509Certificate
import java.util.*
@ -53,6 +66,28 @@ class X509UtilitiesTest {
"TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256",
"TLS_DHE_RSA_WITH_AES_128_GCM_SHA256"
)
// We ensure that all of the algorithms are both used (at least once) as first and second in the following [Pair]s.
// We also add [DEFAULT_TLS_SIGNATURE_SCHEME] and [DEFAULT_IDENTITY_SIGNATURE_SCHEME] combinations for consistency.
val certChainSchemeCombinations = listOf(
Pair(DEFAULT_TLS_SIGNATURE_SCHEME, DEFAULT_TLS_SIGNATURE_SCHEME),
Pair(DEFAULT_IDENTITY_SIGNATURE_SCHEME, DEFAULT_IDENTITY_SIGNATURE_SCHEME),
Pair(DEFAULT_TLS_SIGNATURE_SCHEME, DEFAULT_IDENTITY_SIGNATURE_SCHEME),
Pair(ECDSA_SECP256R1_SHA256, SPHINCS256_SHA256),
Pair(ECDSA_SECP256K1_SHA256, RSA_SHA256),
Pair(EDDSA_ED25519_SHA512, ECDSA_SECP256K1_SHA256),
Pair(RSA_SHA256, EDDSA_ED25519_SHA512),
Pair(SPHINCS256_SHA256, ECDSA_SECP256R1_SHA256)
)
val schemeToKeyTypes = listOf(
// By default, JKS returns SUN EC key.
Triple(ECDSA_SECP256R1_SHA256,java.security.interfaces.ECPrivateKey::class.java, org.bouncycastle.jce.interfaces.ECPrivateKey::class.java),
Triple(ECDSA_SECP256K1_SHA256,java.security.interfaces.ECPrivateKey::class.java, org.bouncycastle.jce.interfaces.ECPrivateKey::class.java),
Triple(EDDSA_ED25519_SHA512, EdDSAPrivateKey::class.java, EdDSAPrivateKey::class.java),
// By default, JKS returns SUN RSA key.
Triple(RSA_SHA256, RSAPrivateCrtKeyImpl::class.java, BCRSAPrivateCrtKey::class.java),
Triple(SPHINCS256_SHA256, BCSphincs256PrivateKey::class.java, BCSphincs256PrivateKey::class.java)
)
}
@Rule
@ -61,7 +96,11 @@ class X509UtilitiesTest {
@Test
fun `create valid self-signed CA certificate`() {
val caKey = generateKeyPair(X509Utilities.DEFAULT_TLS_SIGNATURE_SCHEME)
Crypto.supportedSignatureSchemes().filter { it != COMPOSITE_KEY }.forEach { validSelfSignedCertificate(it) }
}
private fun validSelfSignedCertificate(signatureScheme: SignatureScheme) {
val caKey = generateKeyPair(signatureScheme)
val subject = X500Principal("CN=Test Cert,O=R3 Ltd,L=London,C=GB")
val caCert = X509Utilities.createSelfSignedCACertificate(subject, caKey)
assertEquals(subject, caCert.subjectX500Principal) // using our subject common name
@ -78,8 +117,12 @@ class X509UtilitiesTest {
@Test
fun `load and save a PEM file certificate`() {
Crypto.supportedSignatureSchemes().filter { it != COMPOSITE_KEY }.forEach { loadSavePEMCert(it) }
}
private fun loadSavePEMCert(signatureScheme: SignatureScheme) {
val tmpCertificateFile = tempFile("cacert.pem")
val caKey = generateKeyPair(X509Utilities.DEFAULT_TLS_SIGNATURE_SCHEME)
val caKey = generateKeyPair(signatureScheme)
val caCert = X509Utilities.createSelfSignedCACertificate(X500Principal("CN=Test Cert,O=R3 Ltd,L=London,C=GB"), caKey)
X509Utilities.saveCertificateAsPEMFile(caCert, tmpCertificateFile)
val readCertificate = X509Utilities.loadCertificateFromPEMFile(tmpCertificateFile)
@ -88,29 +131,52 @@ class X509UtilitiesTest {
@Test
fun `create valid server certificate chain`() {
val caKey = generateKeyPair(X509Utilities.DEFAULT_TLS_SIGNATURE_SCHEME)
val caCert = X509Utilities.createSelfSignedCACertificate(X500Principal("CN=Test CA Cert,O=R3 Ltd,L=London,C=GB"), caKey)
val subject = X500Principal("CN=Server Cert,O=R3 Ltd,L=London,C=GB")
val keyPair = generateKeyPair(X509Utilities.DEFAULT_TLS_SIGNATURE_SCHEME)
val serverCert = X509Utilities.createCertificate(CertificateType.TLS, caCert, caKey, subject, keyPair.public)
assertEquals(subject, serverCert.subjectX500Principal) // using our subject common name
assertEquals(caCert.issuerX500Principal, serverCert.issuerX500Principal) // Issued by our CA cert
serverCert.checkValidity(Date()) // throws on verification problems
serverCert.verify(caKey.public) // throws on verification problems
serverCert.toBc().run {
certChainSchemeCombinations.forEach { createValidServerCertChain(it.first, it.second) }
}
private fun createValidServerCertChain(signatureSchemeRoot: SignatureScheme, signatureSchemeChild: SignatureScheme) {
val (caKeyPair, caCert, _, childCert, _, childSubject)
= genCaAndChildKeysCertsAndSubjects(signatureSchemeRoot, signatureSchemeChild)
assertEquals(childSubject, childCert.subjectX500Principal) // Using our subject common name.
assertEquals(caCert.issuerX500Principal, childCert.issuerX500Principal) // Issued by our CA cert.
childCert.checkValidity(Date()) // Throws on verification problems.
childCert.verify(caKeyPair.public) // Throws on verification problems.
childCert.toBc().run {
val basicConstraints = BasicConstraints.getInstance(getExtension(Extension.basicConstraints).parsedValue)
val keyUsage = KeyUsage.getInstance(getExtension(Extension.keyUsage).parsedValue)
assertFalse { keyUsage.hasUsages(5) } // Bit 5 == keyCertSign according to ASN.1 spec (see full comment on KeyUsage property)
assertNull(basicConstraints.pathLenConstraint) // Non-CA certificate
assertFalse { keyUsage.hasUsages(5) } // Bit 5 == keyCertSign according to ASN.1 spec (see full comment on KeyUsage property).
assertNull(basicConstraints.pathLenConstraint) // Non-CA certificate.
}
}
private data class CaAndChildKeysCertsAndSubjects(val caKeyPair: KeyPair,
val caCert: X509Certificate,
val childKeyPair: KeyPair,
val childCert: X509Certificate,
val caSubject: X500Principal,
val childSubject: X500Principal)
private fun genCaAndChildKeysCertsAndSubjects(signatureSchemeRoot: SignatureScheme,
signatureSchemeChild: SignatureScheme,
rootSubject: X500Principal = X500Principal("CN=Test CA Cert,O=R3 Ltd,L=London,C=GB"),
childSubject: X500Principal = X500Principal("CN=Test Child Cert,O=R3 Ltd,L=London,C=GB")): CaAndChildKeysCertsAndSubjects {
val caKeyPair = generateKeyPair(signatureSchemeRoot)
val caCert = X509Utilities.createSelfSignedCACertificate(rootSubject, caKeyPair)
val childKeyPair = generateKeyPair(signatureSchemeChild)
val childCert = X509Utilities.createCertificate(CertificateType.TLS, caCert, caKeyPair, childSubject, childKeyPair.public)
return CaAndChildKeysCertsAndSubjects(caKeyPair, caCert, childKeyPair, childCert, rootSubject, childSubject)
}
@Test
fun `create valid server certificate chain includes CRL info`() {
val caKey = generateKeyPair(X509Utilities.DEFAULT_TLS_SIGNATURE_SCHEME)
certChainSchemeCombinations.forEach { createValidServerCertIncludeCRL(it.first, it.second) }
}
private fun createValidServerCertIncludeCRL(signatureSchemeRoot: SignatureScheme, signatureSchemeChild: SignatureScheme) {
val caKey = generateKeyPair(signatureSchemeRoot)
val caCert = X509Utilities.createSelfSignedCACertificate(X500Principal("CN=Test CA Cert,O=R3 Ltd,L=London,C=GB"), caKey)
val caSubjectKeyIdentifier = SubjectKeyIdentifier.getInstance(caCert.toBc().getExtension(Extension.subjectKeyIdentifier).parsedValue)
val keyPair = generateKeyPair(X509Utilities.DEFAULT_TLS_SIGNATURE_SCHEME)
val keyPair = generateKeyPair(signatureSchemeChild)
val crlDistPoint = "http://test.com"
val serverCert = X509Utilities.createCertificate(
CertificateType.TLS,
@ -128,57 +194,33 @@ class X509UtilitiesTest {
}
@Test
fun `storing EdDSA key in java keystore`() {
fun `storing all supported key types in java keystore`() {
Crypto.supportedSignatureSchemes().filter { it != COMPOSITE_KEY }.forEach { storeKeyToKeystore(it) }
}
private fun storeKeyToKeystore(signatureScheme: SignatureScheme) {
val tmpKeyStore = tempFile("keystore.jks")
val keyPair = generateKeyPair(EDDSA_ED25519_SHA512)
val keyPair = generateKeyPair(signatureScheme)
val testName = X500Principal("CN=Test,O=R3 Ltd,L=London,C=GB")
val selfSignCert = X509Utilities.createSelfSignedCACertificate(testName, keyPair)
assertTrue(Arrays.equals(selfSignCert.publicKey.encoded, keyPair.public.encoded))
// Save the EdDSA private key with self sign cert in the keystore.
// Save the private key with self sign cert in the keystore.
val keyStore = loadOrCreateKeyStore(tmpKeyStore, "keystorepass")
keyStore.setKeyEntry("Key", keyPair.private, "password".toCharArray(), arrayOf(selfSignCert))
keyStore.save(tmpKeyStore, "keystorepass")
// Load the keystore from file and make sure keys are intact.
val keyStore2 = loadOrCreateKeyStore(tmpKeyStore, "keystorepass")
val privateKey = keyStore2.getKey("Key", "password".toCharArray())
val pubKey = keyStore2.getCertificate("Key").publicKey
val reloadedKeystore = loadOrCreateKeyStore(tmpKeyStore, "keystorepass")
val reloadedPrivateKey = reloadedKeystore.getKey("Key", "password".toCharArray())
val reloadedPublicKey = reloadedKeystore.getCertificate("Key").publicKey
assertNotNull(pubKey)
assertNotNull(privateKey)
assertEquals(keyPair.public, pubKey)
assertEquals(keyPair.private, privateKey)
}
@Test
fun `signing EdDSA key with EcDSA certificate`() {
val tmpKeyStore = tempFile("keystore.jks")
val ecDSAKey = generateKeyPair(Crypto.ECDSA_SECP256R1_SHA256)
val testName = X500Principal("CN=Test,O=R3 Ltd,L=London,C=GB")
val ecDSACert = X509Utilities.createSelfSignedCACertificate(testName, ecDSAKey)
val edDSAKeypair = generateKeyPair(EDDSA_ED25519_SHA512)
val edDSACert = X509Utilities.createCertificate(CertificateType.TLS, ecDSACert, ecDSAKey, BOB.name.x500Principal, edDSAKeypair.public)
// Save the EdDSA private key with cert chains.
val keyStore = loadOrCreateKeyStore(tmpKeyStore, "keystorepass")
keyStore.setKeyEntry("Key", edDSAKeypair.private, "password".toCharArray(), arrayOf(ecDSACert, edDSACert))
keyStore.save(tmpKeyStore, "keystorepass")
// Load the keystore from file and make sure keys are intact.
val keyStore2 = loadOrCreateKeyStore(tmpKeyStore, "keystorepass")
val privateKey = keyStore2.getKey("Key", "password".toCharArray())
val certs = keyStore2.getCertificateChain("Key")
val pubKey = certs.last().publicKey
assertEquals(2, certs.size)
assertNotNull(pubKey)
assertNotNull(privateKey)
assertEquals(edDSAKeypair.public, pubKey)
assertEquals(edDSAKeypair.private, privateKey)
assertNotNull(reloadedPublicKey)
assertNotNull(reloadedPrivateKey)
assertEquals(keyPair.public, reloadedPublicKey)
assertEquals(keyPair.private, reloadedPrivateKey)
}
@Test
@ -316,7 +358,17 @@ class X509UtilitiesTest {
@Test
fun `get correct private key type from Keystore`() {
val keyPair = generateKeyPair(Crypto.ECDSA_SECP256R1_SHA256)
schemeToKeyTypes.forEach { getCorrectKeyFromKeystore(it.first, it.second, it.third) }
}
private fun <U, C> getCorrectKeyFromKeystore(signatureScheme: SignatureScheme, uncastedClass: Class<U>, castedClass: Class<C>) {
val keyPair = generateKeyPair(signatureScheme)
val (keyFromKeystore, keyFromKeystoreCasted) = storeAndGetKeysFromKeystore(keyPair)
assertThat(keyFromKeystore).isInstanceOf(uncastedClass)
assertThat(keyFromKeystoreCasted).isInstanceOf(castedClass)
}
private fun storeAndGetKeysFromKeystore(keyPair: KeyPair): Pair<Key, PrivateKey> {
val testName = X500Principal("CN=Test,O=R3 Ltd,L=London,C=GB")
val selfSignCert = X509Utilities.createSelfSignedCACertificate(testName, keyPair)
val keyStore = loadOrCreateKeyStore(tempFile("testKeystore.jks"), "keystorepassword")
@ -324,13 +376,15 @@ class X509UtilitiesTest {
val keyFromKeystore = keyStore.getKey("Key", "keypassword".toCharArray())
val keyFromKeystoreCasted = keyStore.getSupportedKey("Key", "keypassword")
assertTrue(keyFromKeystore is java.security.interfaces.ECPrivateKey) // by default JKS returns SUN EC key
assertTrue(keyFromKeystoreCasted is org.bouncycastle.jce.interfaces.ECPrivateKey)
return Pair(keyFromKeystore, keyFromKeystoreCasted)
}
@Test
fun `serialize - deserialize X509Certififcate`() {
fun `serialize - deserialize X509Certificate`() {
Crypto.supportedSignatureSchemes().filter { it != COMPOSITE_KEY }.forEach { serializeDeserializeX509Cert(it) }
}
private fun serializeDeserializeX509Cert(signatureScheme: SignatureScheme) {
val factory = SerializationFactoryImpl().apply { registerScheme(AMQPServerSerializationScheme()) }
val context = SerializationContextImpl(amqpMagic,
javaClass.classLoader,
@ -339,7 +393,7 @@ class X509UtilitiesTest {
true,
SerializationContext.UseCase.P2P,
null)
val expected = X509Utilities.createSelfSignedCACertificate(ALICE.name.x500Principal, Crypto.generateKeyPair(X509Utilities.DEFAULT_TLS_SIGNATURE_SCHEME))
val expected = X509Utilities.createSelfSignedCACertificate(ALICE.name.x500Principal, generateKeyPair(signatureScheme))
val serialized = expected.serialize(factory, context).bytes
val actual = serialized.deserialize<X509Certificate>(factory, context)
assertEquals(expected, actual)
@ -347,6 +401,10 @@ class X509UtilitiesTest {
@Test
fun `serialize - deserialize X509CertPath`() {
Crypto.supportedSignatureSchemes().filter { it != COMPOSITE_KEY }.forEach { serializeDeserializeX509CertPath(it) }
}
private fun serializeDeserializeX509CertPath(signatureScheme: SignatureScheme) {
val factory = SerializationFactoryImpl().apply { registerScheme(AMQPServerSerializationScheme()) }
val context = SerializationContextImpl(
amqpMagic,
@ -357,7 +415,7 @@ class X509UtilitiesTest {
SerializationContext.UseCase.P2P,
null
)
val rootCAKey = Crypto.generateKeyPair(X509Utilities.DEFAULT_TLS_SIGNATURE_SCHEME)
val rootCAKey = generateKeyPair(signatureScheme)
val rootCACert = X509Utilities.createSelfSignedCACertificate(ALICE_NAME.x500Principal, rootCAKey)
val certificate = X509Utilities.createCertificate(CertificateType.TLS, rootCACert, rootCAKey, BOB_NAME.x500Principal, BOB.publicKey)
val expected = X509Utilities.buildCertPath(certificate, rootCACert)
@ -365,4 +423,33 @@ class X509UtilitiesTest {
val actual: CertPath = serialized.deserialize(factory, context)
assertEquals(expected, actual)
}
@Test
fun `signing a key type with another key type certificate then store and reload correctly from keystore`() {
certChainSchemeCombinations.forEach { signCertWithOtherKeyTypeAndTestKeystoreReload(it.first, it.second) }
}
private fun signCertWithOtherKeyTypeAndTestKeystoreReload(signatureSchemeRoot: SignatureScheme, signatureSchemeChild: SignatureScheme) {
val tmpKeyStore = tempFile("keystore.jks")
val (_, caCert, childKeyPair, childCert) = genCaAndChildKeysCertsAndSubjects(signatureSchemeRoot, signatureSchemeChild)
// Save the child private key with cert chains.
val keyStore = loadOrCreateKeyStore(tmpKeyStore, "keystorepass")
keyStore.setKeyEntry("Key", childKeyPair.private, "password".toCharArray(), arrayOf(caCert, childCert))
keyStore.save(tmpKeyStore, "keystorepass")
// Load the keystore from file and make sure keys are intact.
val reloadedKeystore = loadOrCreateKeyStore(tmpKeyStore, "keystorepass")
val reloadedPrivateKey = reloadedKeystore.getKey("Key", "password".toCharArray())
val reloadedCerts = reloadedKeystore.getCertificateChain("Key")
val reloadedPublicKey = reloadedCerts.last().publicKey
assertEquals(2, reloadedCerts.size)
assertNotNull(reloadedPublicKey)
assertNotNull(reloadedPrivateKey)
assertEquals(childKeyPair.public, reloadedPublicKey)
assertEquals(childKeyPair.private, reloadedPrivateKey)
}
}

View File

@ -43,6 +43,7 @@ class BootTests : IntegrationTest() {
@Test
fun `java deserialization is disabled`() {
val user = User("u", "p", setOf(startFlow<ObjectInputStreamFlow>()))
val devParams = NodeParameters(providedName = BOB_NAME, rpcUsers = listOf(user))
val params = NodeParameters(rpcUsers = listOf(user))
fun NodeHandle.attemptJavaDeserialization() {
@ -52,7 +53,7 @@ class BootTests : IntegrationTest() {
}
}
driver {
val devModeNode = startNode(params).getOrThrow()
val devModeNode = startNode(devParams).getOrThrow()
val node = startNode(ALICE_NAME, devMode = false, parameters = params).getOrThrow()
assertThatThrownBy { devModeNode.attemptJavaDeserialization() }.isInstanceOf(CordaRuntimeException::class.java)

View File

@ -9,6 +9,7 @@ import net.corda.core.messaging.startFlow
import net.corda.core.utilities.getOrThrow
import net.corda.core.utilities.unwrap
import net.corda.testing.core.*
import net.corda.testing.core.singleIdentity
import net.corda.testing.driver.DriverParameters
import net.corda.testing.driver.TestCorDapp
import net.corda.testing.driver.driver
@ -58,8 +59,8 @@ class AsymmetricCorDappsTests : IntegrationTest() {
driver(DriverParameters(startNodesInProcess = false, cordappsForAllNodes = emptySet())) {
val nodeA = startNode(additionalCordapps = setOf(TestCorDapp.Factory.create("Szymon CorDapp", "1.0", classes = setOf(Ping::class.java)))).getOrThrow()
val nodeB = startNode(additionalCordapps = setOf(TestCorDapp.Factory.create("Szymon CorDapp", "1.0", classes = setOf(Ping::class.java, Pong::class.java)))).getOrThrow()
val nodeA = startNode(providedName = ALICE_NAME, additionalCordapps = setOf(TestCorDapp.Factory.create("Szymon CorDapp", "1.0", classes = setOf(Ping::class.java)))).getOrThrow()
val nodeB = startNode(providedName = BOB_NAME, additionalCordapps = setOf(TestCorDapp.Factory.create("Szymon CorDapp", "1.0", classes = setOf(Ping::class.java, Pong::class.java)))).getOrThrow()
nodeA.rpc.startFlow(::Ping, nodeB.nodeInfo.singleIdentity(), 1).returnValue.getOrThrow()
}
}
@ -73,7 +74,7 @@ class AsymmetricCorDappsTests : IntegrationTest() {
val cordappForNodeB = TestCorDapp.Factory.create("nodeB_only", "1.0", classes = setOf(Pong::class.java))
driver(DriverParameters(startNodesInProcess = false, cordappsForAllNodes = setOf(sharedCordapp))) {
val (nodeA, nodeB) = listOf(startNode(), startNode(additionalCordapps = setOf(cordappForNodeB))).transpose().getOrThrow()
val (nodeA, nodeB) = listOf(startNode(providedName = ALICE_NAME), startNode(providedName = BOB_NAME, additionalCordapps = setOf(cordappForNodeB))).transpose().getOrThrow()
nodeA.rpc.startFlow(::Ping, nodeB.nodeInfo.singleIdentity(), 1).returnValue.getOrThrow()
}
}
@ -87,7 +88,7 @@ class AsymmetricCorDappsTests : IntegrationTest() {
val cordappForNodeB = TestCorDapp.Factory.create("nodeB_only", "1.0", classes = setOf(Pong::class.java))
driver(DriverParameters(startNodesInProcess = true, cordappsForAllNodes = setOf(sharedCordapp))) {
val (nodeA, nodeB) = listOf(startNode(), startNode(additionalCordapps = setOf(cordappForNodeB))).transpose().getOrThrow()
val (nodeA, nodeB) = listOf(startNode(providedName = ALICE_NAME), startNode(providedName = BOB_NAME, additionalCordapps = setOf(cordappForNodeB))).transpose().getOrThrow()
nodeA.rpc.startFlow(::Ping, nodeB.nodeInfo.singleIdentity(), 1).returnValue.getOrThrow()
}
}

View File

@ -1,8 +1,11 @@
package net.corda.node.modes.draining
import co.paralleluniverse.fibers.Suspendable
import net.corda.client.rpc.internal.drainAndShutdown
import net.corda.core.flows.*
import net.corda.core.flows.FlowLogic
import net.corda.core.flows.FlowSession
import net.corda.core.flows.InitiatedBy
import net.corda.core.flows.InitiatingFlow
import net.corda.core.flows.StartableByRPC
import net.corda.core.identity.Party
import net.corda.core.internal.concurrent.map
import net.corda.core.messaging.startFlow
@ -10,6 +13,7 @@ import net.corda.core.utilities.contextLogger
import net.corda.core.utilities.getOrThrow
import net.corda.core.utilities.unwrap
import net.corda.node.services.Permissions
import net.corda.nodeapi.internal.hasCancelledDrainingShutdown
import net.corda.testing.core.*
import net.corda.testing.driver.DriverParameters
import net.corda.testing.driver.PortAllocation
@ -19,6 +23,7 @@ import net.corda.testing.internal.IntegrationTestSchemas
import net.corda.testing.internal.toDatabaseSchemaName
import net.corda.testing.internal.chooseIdentity
import net.corda.testing.node.User
import net.corda.testing.node.internal.waitForShutdown
import org.assertj.core.api.AssertionsForInterfaceTypes.assertThat
import org.junit.*
import org.junit.After
@ -87,24 +92,74 @@ class P2PFlowsDrainingModeTest : IntegrationTest() {
}
@Test
fun `clean shutdown by draining`() {
driver(DriverParameters(startNodesInProcess = true, portAllocation = portAllocation, notarySpecs = emptyList())) {
fun `terminate node waiting for pending flows`() {
driver(DriverParameters(portAllocation = portAllocation, notarySpecs = emptyList())) {
val nodeA = startNode(providedName = ALICE_NAME, rpcUsers = users).getOrThrow()
val nodeB = startNode(providedName = BOB_NAME, rpcUsers = users).getOrThrow()
var successful = false
val latch = CountDownLatch(1)
nodeB.rpc.setFlowsDrainingModeEnabled(true)
IntRange(1, 10).forEach { nodeA.rpc.startFlow(::InitiateSessionFlow, nodeB.nodeInfo.chooseIdentity()) }
nodeA.rpc.drainAndShutdown()
.doOnError { error ->
error.printStackTrace()
successful = false
}
.doOnCompleted { successful = true }
.doAfterTerminate { latch.countDown() }
.subscribe()
nodeA.waitForShutdown().doOnError(Throwable::printStackTrace).doOnError { successful = false }.doOnCompleted { successful = true }.doAfterTerminate(latch::countDown).subscribe()
nodeA.rpc.terminate(true)
nodeB.rpc.setFlowsDrainingModeEnabled(false)
latch.await()
assertThat(successful).isTrue()
}
}
@Test
fun `terminate resets persistent draining mode property when waiting for pending flows`() {
driver(DriverParameters(portAllocation = portAllocation, notarySpecs = emptyList())) {
val nodeA = startNode(providedName = ALICE_NAME, rpcUsers = users).getOrThrow()
var successful = false
val latch = CountDownLatch(1)
// This would not be needed, as `terminate(true)` sets draining mode anyway, but it's here to ensure that it removes the persistent value anyway.
nodeA.rpc.setFlowsDrainingModeEnabled(true)
nodeA.rpc.waitForShutdown().doOnError(Throwable::printStackTrace).doOnError { successful = false }.doOnCompleted(nodeA::stop).doOnCompleted {
val nodeARestarted = startNode(providedName = ALICE_NAME, rpcUsers = users).getOrThrow()
successful = !nodeARestarted.rpc.isFlowsDrainingModeEnabled()
}.doAfterTerminate(latch::countDown).subscribe()
nodeA.rpc.terminate(true)
latch.await()
assertThat(successful).isTrue()
}
}
@Test
fun `disabling draining mode cancels draining shutdown`() {
driver(DriverParameters(portAllocation = portAllocation, notarySpecs = emptyList())) {
val nodeA = startNode(providedName = ALICE_NAME, rpcUsers = users).getOrThrow()
val nodeB = startNode(providedName = BOB_NAME, rpcUsers = users).getOrThrow()
var successful = false
val latch = CountDownLatch(1)
nodeB.rpc.setFlowsDrainingModeEnabled(true)
IntRange(1, 10).forEach { nodeA.rpc.startFlow(::InitiateSessionFlow, nodeB.nodeInfo.chooseIdentity()) }
nodeA.waitForShutdown().doOnError(Throwable::printStackTrace).doAfterTerminate { successful = false }.doAfterTerminate(latch::countDown).subscribe()
nodeA.rpc.terminate(true)
nodeA.rpc.hasCancelledDrainingShutdown().doOnError(Throwable::printStackTrace).doOnError { successful = false }.doOnCompleted { successful = true }.doAfterTerminate(latch::countDown).subscribe()
nodeA.rpc.setFlowsDrainingModeEnabled(false)
nodeB.rpc.setFlowsDrainingModeEnabled(false)
latch.await()
assertThat(successful).isTrue()

View File

@ -24,10 +24,8 @@ class NodeCmdLineOptions {
names = ["-f", "--config-file"],
description = ["The path to the config file. By default this is node.conf in the base directory."]
)
var configFileArgument: Path? = null
val configFile : Path
get() = configFileArgument ?: (baseDirectory / "node.conf")
private var _configFile: Path? = null
val configFile: Path get() = _configFile ?: (baseDirectory / "node.conf")
@Option(
names = ["--sshd"],
@ -57,7 +55,8 @@ class NodeCmdLineOptions {
names = ["-t", "--network-root-truststore"],
description = ["Network root trust store obtained from network operator."]
)
var networkRootTrustStorePath: Path = baseDirectory / "certificates" / "network-root-truststore.jks"
private var _networkRootTrustStorePath: Path? = null
val networkRootTrustStorePath: Path get() = _networkRootTrustStorePath ?: baseDirectory / "certificates" / "network-root-truststore.jks"
@Option(
names = ["-p", "--network-root-truststore-password"],
@ -101,7 +100,7 @@ class NodeCmdLineOptions {
)
var clearNetworkMapCache: Boolean = false
val nodeRegistrationOption : NodeRegistrationOption? by lazy {
val nodeRegistrationOption: NodeRegistrationOption? by lazy {
if (isRegistration) {
requireNotNull(networkRootTrustStorePassword) { "Network root trust store password must be provided in registration mode using --network-root-truststore-password." }
require(networkRootTrustStorePath.exists()) { "Network root trust store path: '$networkRootTrustStorePath' doesn't exist" }

View File

@ -240,7 +240,7 @@ abstract class AbstractNode<S>(val configuration: NodeConfiguration,
/** The implementation of the [CordaRPCOps] interface used by this node. */
open fun makeRPCOps(): CordaRPCOps {
val ops: CordaRPCOps = CordaRPCOpsImpl(services, smm, flowStarter) { shutdownExecutor.submit { stop() } }
val ops: CordaRPCOps = CordaRPCOpsImpl(services, smm, flowStarter) { shutdownExecutor.submit { stop() } }.also { it.closeOnStop() }
val proxies = mutableListOf<(CordaRPCOps) -> CordaRPCOps>()
// Mind that order is relevant here.
proxies += ::AuthenticatedRpcOpsProxy

View File

@ -5,11 +5,12 @@ import net.corda.core.crypto.SecureHash
import net.corda.core.flows.FlowLogic
import net.corda.core.node.ServiceHub
import net.corda.core.serialization.SerializationDefaults
import net.corda.core.serialization.deserialize
import net.corda.core.serialization.internal.CheckpointSerializationDefaults
import net.corda.core.serialization.internal.checkpointDeserialize
import net.corda.node.services.api.CheckpointStorage
import net.corda.node.services.statemachine.SubFlow
import net.corda.node.services.statemachine.SubFlowVersion
import net.corda.serialization.internal.SerializeAsTokenContextImpl
import net.corda.serialization.internal.CheckpointSerializeAsTokenContextImpl
import net.corda.serialization.internal.withTokenContext
object CheckpointVerifier {
@ -19,13 +20,13 @@ object CheckpointVerifier {
* @throws CheckpointIncompatibleException if any offending checkpoint is found.
*/
fun verifyCheckpointsCompatible(checkpointStorage: CheckpointStorage, currentCordapps: List<Cordapp>, platformVersion: Int, serviceHub: ServiceHub, tokenizableServices: List<Any>) {
val checkpointSerializationContext = SerializationDefaults.CHECKPOINT_CONTEXT.withTokenContext(
SerializeAsTokenContextImpl(tokenizableServices, SerializationDefaults.SERIALIZATION_FACTORY, SerializationDefaults.CHECKPOINT_CONTEXT, serviceHub)
val checkpointSerializationContext = CheckpointSerializationDefaults.CHECKPOINT_CONTEXT.withTokenContext(
CheckpointSerializeAsTokenContextImpl(tokenizableServices, CheckpointSerializationDefaults.CHECKPOINT_SERIALIZATION_FACTORY, CheckpointSerializationDefaults.CHECKPOINT_CONTEXT, serviceHub)
)
checkpointStorage.getAllCheckpoints().forEach { (_, serializedCheckpoint) ->
val checkpoint = try {
serializedCheckpoint.deserialize(context = checkpointSerializationContext)
serializedCheckpoint.checkpointDeserialize(context = checkpointSerializationContext)
} catch (e: Exception) {
throw CheckpointIncompatibleException.CannotBeDeserialisedException(e)
}

View File

@ -28,6 +28,7 @@ import net.corda.core.node.services.vault.*
import net.corda.core.serialization.serialize
import net.corda.core.transactions.SignedTransaction
import net.corda.core.utilities.getOrThrow
import net.corda.core.utilities.loggerFor
import net.corda.node.internal.exceptions.StateMachineStoppedException
import net.corda.node.services.api.FlowStarter
import net.corda.node.services.api.ServiceHubInternal
@ -35,11 +36,14 @@ import net.corda.node.services.messaging.context
import net.corda.node.services.statemachine.StateMachineManager
import net.corda.nodeapi.exceptions.NonRpcFlowException
import net.corda.nodeapi.exceptions.RejectedCommandException
import net.corda.nodeapi.internal.pendingFlowsCount
import rx.Observable
import rx.Subscription
import java.io.InputStream
import java.net.ConnectException
import java.security.PublicKey
import java.time.Instant
import java.util.concurrent.atomic.AtomicReference
/**
* Server side implementations of RPCs available to MQ based client tools. Execution takes place on the server
@ -50,7 +54,24 @@ internal class CordaRPCOpsImpl(
private val smm: StateMachineManager,
private val flowStarter: FlowStarter,
private val shutdownNode: () -> Unit
) : CordaRPCOps {
) : CordaRPCOps, AutoCloseable {
private companion object {
private val logger = loggerFor<CordaRPCOpsImpl>()
}
private val drainingShutdownHook = AtomicReference<Subscription?>()
init {
services.nodeProperties.flowsDrainingMode.values.filter { it.isDisabled() }.subscribe({
cancelDrainingShutdownHook()
}, {
// Nothing to do in case of errors here.
})
}
private fun Pair<Boolean, Boolean>.isDisabled(): Boolean = first && !second
/**
* Returns the RPC protocol version, which is the same the node's platform Version. Exists since version 1 so guaranteed
* to be present.
@ -227,7 +248,7 @@ internal class CordaRPCOpsImpl(
return services.networkMapCache.getNodeByLegalIdentity(party)
}
override fun registeredFlows(): List<String> = services.rpcFlows.map { it.name }.sorted()
override fun registeredFlows(): List<String> = services.rpcFlows.asSequence().map(Class<*>::getName).sorted().toList()
override fun clearNetworkMapCache() {
services.networkMapCache.clearNetworkMapCache()
@ -276,18 +297,46 @@ internal class CordaRPCOpsImpl(
return vaultTrackBy(criteria, PageSpecification(), sorting, contractStateType)
}
override fun setFlowsDrainingModeEnabled(enabled: Boolean) {
services.nodeProperties.flowsDrainingMode.setEnabled(enabled)
override fun setFlowsDrainingModeEnabled(enabled: Boolean) = setPersistentDrainingModeProperty(enabled, propagateChange = true)
override fun isFlowsDrainingModeEnabled() = services.nodeProperties.flowsDrainingMode.isEnabled()
override fun shutdown() = terminate(false)
override fun terminate(drainPendingFlows: Boolean) {
if (drainPendingFlows) {
logger.info("Waiting for pending flows to complete before shutting down.")
setFlowsDrainingModeEnabled(true)
drainingShutdownHook.set(pendingFlowsCount().updates.doOnNext {(completed, total) ->
logger.info("Pending flows progress before shutdown: $completed / $total.")
}.doOnCompleted { setPersistentDrainingModeProperty(false, false) }.doOnCompleted(::cancelDrainingShutdownHook).doOnCompleted { logger.info("No more pending flows to drain. Shutting down.") }.doOnCompleted(shutdownNode::invoke).subscribe({
// Nothing to do on each update here, only completion matters.
}, { error ->
logger.error("Error while waiting for pending flows to drain in preparation for shutdown. Cause was: ${error.message}", error)
}))
} else {
shutdownNode.invoke()
}
}
override fun isFlowsDrainingModeEnabled(): Boolean {
return services.nodeProperties.flowsDrainingMode.isEnabled()
override fun isWaitingForShutdown() = drainingShutdownHook.get() != null
override fun close() {
cancelDrainingShutdownHook()
}
override fun shutdown() {
shutdownNode.invoke()
private fun cancelDrainingShutdownHook() {
drainingShutdownHook.getAndSet(null)?.let {
it.unsubscribe()
logger.info("Cancelled draining shutdown hook.")
}
}
private fun setPersistentDrainingModeProperty(enabled: Boolean, propagateChange: Boolean) = services.nodeProperties.flowsDrainingMode.setEnabled(enabled, propagateChange)
private fun stateMachineInfoFromFlowLogic(flowLogic: FlowLogic<*>): StateMachineInfo {
return StateMachineInfo(flowLogic.runId, flowLogic.javaClass.name, flowLogic.stateMachine.context.toFlowInitiator(), flowLogic.track(), flowLogic.stateMachine.context)
}

View File

@ -21,6 +21,7 @@ import net.corda.core.messaging.RPCOps
import net.corda.core.node.NetworkParameters
import net.corda.core.node.NodeInfo
import net.corda.core.node.ServiceHub
import net.corda.core.serialization.internal.CheckpointSerializationFactory
import net.corda.core.serialization.internal.SerializationEnvironmentImpl
import net.corda.core.serialization.internal.nodeSerializationEnv
import net.corda.core.utilities.NetworkHostAndPort
@ -37,7 +38,7 @@ import net.corda.node.internal.security.RPCSecurityManagerImpl
import net.corda.node.internal.security.RPCSecurityManagerWithAdditionalUser
import net.corda.node.serialization.amqp.AMQPServerSerializationScheme
import net.corda.node.serialization.kryo.KRYO_CHECKPOINT_CONTEXT
import net.corda.node.serialization.kryo.KryoServerSerializationScheme
import net.corda.node.serialization.kryo.KryoSerializationScheme
import net.corda.node.services.Permissions
import net.corda.node.services.api.FlowStarter
import net.corda.node.services.api.ServiceHubInternal
@ -475,8 +476,8 @@ open class Node(configuration: NodeConfiguration,
SerializationFactoryImpl().apply {
registerScheme(AMQPServerSerializationScheme(cordappLoader.cordapps))
registerScheme(AMQPClientSerializationScheme(cordappLoader.cordapps))
registerScheme(KryoServerSerializationScheme())
},
checkpointSerializationFactory = CheckpointSerializationFactory(KryoSerializationScheme),
p2pContext = AMQP_P2P_CONTEXT.withClassLoader(classloader),
rpcServerContext = AMQP_RPC_SERVER_CONTEXT.withClassLoader(classloader),
storageContext = AMQP_STORAGE_CONTEXT.withClassLoader(classloader),

View File

@ -36,6 +36,7 @@ import net.corda.nodeapi.internal.persistence.DatabaseIncompatibleException
import net.corda.nodeapi.internal.persistence.DatabaseMigrationException
import net.corda.nodeapi.internal.persistence.oracleJdbcDriverSerialFilter
import net.corda.tools.shell.InteractiveShell
import org.apache.commons.lang.SystemUtils
import org.fusesource.jansi.Ansi
import org.slf4j.bridge.SLF4JBridgeHandler
import picocli.CommandLine.Mixin
@ -54,7 +55,7 @@ import java.util.*
import kotlin.system.exitProcess
/** This class is responsible for starting a Node from command line arguments. */
open class NodeStartup: CordaCliWrapper("corda", "Runs a Corda Node") {
open class NodeStartup : CordaCliWrapper("corda", "Runs a Corda Node") {
companion object {
private val logger by lazy { loggerFor<Node>() } // I guess this is lazy to allow for logging init, but why Node?
const val LOGS_DIRECTORY_NAME = "logs"
@ -63,7 +64,7 @@ open class NodeStartup: CordaCliWrapper("corda", "Runs a Corda Node") {
}
@Mixin
var cmdLineOptions = NodeCmdLineOptions()
val cmdLineOptions = NodeCmdLineOptions()
/**
* @return exit code based on the success of the node startup. This value is intended to be the exit code of the process.
@ -133,14 +134,21 @@ open class NodeStartup: CordaCliWrapper("corda", "Runs a Corda Node") {
}
private fun isValidJavaVersion(): Boolean {
if (!canNormalizeEmptyPath()) {
println("You are using a version of Java that is not supported (${System.getProperty("java.version")}). Please upgrade to the latest supported version.")
if (!hasMinimumJavaVersion()) {
println("You are using a version of Java that is not supported (${SystemUtils.JAVA_VERSION}). Please upgrade to the latest version of Java 8.")
println("Corda will now exit...")
return false
}
return true
}
private fun hasMinimumJavaVersion(): Boolean {
// when the ext.java8_minUpdateVersion gradle constant changes, so must this check
val major = SystemUtils.JAVA_VERSION_FLOAT
val update = SystemUtils.JAVA_VERSION.substringAfter("_").toLong()
return major == 1.8F && update >= 171
}
// TODO: Reconsider if automatic re-registration should be applied when something failed during initial registration.
// There might be cases where the node user should investigate what went wrong before registering again.
private fun checkUnfinishedRegistration() {
@ -157,25 +165,12 @@ open class NodeStartup: CordaCliWrapper("corda", "Runs a Corda Node") {
private val startNodeExpectedErrors = setOf(DatabaseMigrationException::class, MultipleCordappsForFlowException::class, CheckpointIncompatibleException::class, AddressBindingException::class, NetworkParametersReader::class, DatabaseIncompatibleException::class)
private fun Exception.logAsExpected(message: String? = this.message, print: (String?) -> Unit = logger::error) = print("$message [errorCode=${errorCode()}]")
private fun Exception.logAsExpected(message: String? = this.message, print: (String?) -> Unit = logger::error) = print(message)
private fun Exception.logAsUnexpected(message: String? = this.message, error: Exception = this, print: (String?, Throwable) -> Unit = logger::error) = print("$message${this.message?.let { ": $it" } ?: ""} [errorCode=${errorCode()}]", error)
private fun Exception.logAsUnexpected(message: String? = this.message, error: Exception = this, print: (String?, Throwable) -> Unit = logger::error) = print("$message${this.message?.let { ": $it" } ?: ""}", error)
private fun Exception.isOpenJdkKnownIssue() = message?.startsWith("Unknown named curve:") == true
private fun Exception.errorCode(): String {
val hash = staticLocationBasedHash()
return Integer.toOctalString(hash)
}
private fun Throwable.staticLocationBasedHash(visited: Set<Throwable> = setOf(this)): Int {
val cause = this.cause
return when {
cause != null && !visited.contains(cause) -> Objects.hash(this::class.java.name, stackTrace.customHashCode(), cause.staticLocationBasedHash(visited + cause))
else -> Objects.hash(this::class.java.name, stackTrace.customHashCode())
}
}
private val handleRegistrationError = { error: Exception ->
when (error) {
is NodeRegistrationException -> error.logAsExpected("Node registration service is unavailable. Perhaps try to perform the initial registration again after a while.")
@ -201,19 +196,6 @@ open class NodeStartup: CordaCliWrapper("corda", "Runs a Corda Node") {
}
}
private fun Array<StackTraceElement?>?.customHashCode(): Int {
if (this == null) {
return 0
}
return Arrays.hashCode(map { it?.customHashCode() ?: 0 }.toIntArray())
}
private fun StackTraceElement.customHashCode(): Int {
return Objects.hash(StackTraceElement::class.java.name, methodName, lineNumber)
}
private fun configFileNotFoundMessage(configFile: Path): String {
return """
Unable to load the node config file from '$configFile'.
@ -239,15 +221,14 @@ open class NodeStartup: CordaCliWrapper("corda", "Runs a Corda Node") {
}
private fun checkRegistrationMode(): Boolean {
val baseDirectory = cmdLineOptions.baseDirectory.normalize().toAbsolutePath()
// If the node was started with `--initial-registration`, create marker file.
// We do this here to ensure the marker is created even if parsing the args with NodeArgsParser fails.
val marker = File((baseDirectory / INITIAL_REGISTRATION_MARKER).toUri())
val marker = cmdLineOptions.baseDirectory / INITIAL_REGISTRATION_MARKER
if (!cmdLineOptions.isRegistration && !marker.exists()) {
return false
}
try {
marker.createNewFile()
marker.createFile()
} catch (e: Exception) {
logger.warn("Could not create marker file for `--initial-registration`.", e)
}
@ -548,16 +529,6 @@ open class NodeStartup: CordaCliWrapper("corda", "Runs a Corda Node") {
return hostName
}
private fun canNormalizeEmptyPath(): Boolean {
// Check we're not running a version of Java with a known bug: https://github.com/corda/corda/issues/83
return try {
Paths.get("").normalize()
true
} catch (e: ArrayIndexOutOfBoundsException) {
false
}
}
open fun drawBanner(versionInfo: VersionInfo) {
Emoji.renderIfSupported {
val messages = arrayListOf(
@ -636,4 +607,3 @@ open class NodeStartup: CordaCliWrapper("corda", "Runs a Corda Node") {
}
}

View File

@ -8,8 +8,8 @@ import com.esotericsoftware.kryo.util.DefaultClassResolver
import com.esotericsoftware.kryo.util.Util
import net.corda.core.internal.kotlinObjectInstance
import net.corda.core.internal.writer
import net.corda.core.serialization.internal.CheckpointSerializationContext
import net.corda.core.serialization.ClassWhitelist
import net.corda.core.serialization.SerializationContext
import net.corda.core.utilities.contextLogger
import net.corda.serialization.internal.AttachmentsClassLoader
import net.corda.serialization.internal.MutableClassWhitelist
@ -25,7 +25,7 @@ import java.util.*
/**
* Corda specific class resolver which enables extra customisation for the purposes of serialization using Kryo
*/
class CordaClassResolver(serializationContext: SerializationContext) : DefaultClassResolver() {
class CordaClassResolver(serializationContext: CheckpointSerializationContext) : DefaultClassResolver() {
val whitelist: ClassWhitelist = TransientClassWhiteList(serializationContext.whitelist)
// These classes are assignment-compatible Java equivalents of Kotlin classes.

View File

@ -14,12 +14,11 @@ import net.corda.core.crypto.SecureHash
import net.corda.core.crypto.TransactionSignature
import net.corda.core.internal.uncheckedCast
import net.corda.core.serialization.SerializationContext
import net.corda.core.serialization.SerializationContext.UseCase.Checkpoint
import net.corda.core.serialization.SerializationContext.UseCase.Storage
import net.corda.core.serialization.SerializeAsTokenContext
import net.corda.core.serialization.SerializedBytes
import net.corda.core.transactions.*
import net.corda.core.utilities.OpaqueBytes
import net.corda.serialization.internal.checkUseCase
import net.corda.core.utilities.SgxSupport
import net.corda.serialization.internal.serializationContextKey
import org.slf4j.Logger
@ -279,16 +278,9 @@ object SignedTransactionSerializer : Serializer<SignedTransaction>() {
}
}
sealed class UseCaseSerializer<T>(private val allowedUseCases: EnumSet<SerializationContext.UseCase>) : Serializer<T>() {
protected fun checkUseCase() {
net.corda.serialization.internal.checkUseCase(allowedUseCases)
}
}
@ThreadSafe
object PrivateKeySerializer : UseCaseSerializer<PrivateKey>(EnumSet.of(Storage, Checkpoint)) {
object PrivateKeySerializer : Serializer<PrivateKey>() {
override fun write(kryo: Kryo, output: Output, obj: PrivateKey) {
checkUseCase()
output.writeBytesWithLength(obj.encoded)
}

View File

@ -10,10 +10,9 @@ import com.esotericsoftware.kryo.io.Output
import com.esotericsoftware.kryo.pool.KryoPool
import com.esotericsoftware.kryo.serializers.ClosureSerializer
import net.corda.core.internal.uncheckedCast
import net.corda.core.serialization.ClassWhitelist
import net.corda.core.serialization.SerializationContext
import net.corda.core.serialization.SerializationDefaults
import net.corda.core.serialization.SerializedBytes
import net.corda.core.serialization.*
import net.corda.core.serialization.internal.CheckpointSerializationContext
import net.corda.core.serialization.internal.CheckpointSerializationScheme
import net.corda.core.utilities.ByteSequence
import net.corda.serialization.internal.*
import net.corda.serialization.internal.CordaSerializationEncoding.SNAPPY
@ -33,46 +32,30 @@ private object AutoCloseableSerialisationDetector : Serializer<AutoCloseable>()
override fun read(kryo: Kryo, input: Input, type: Class<AutoCloseable>) = throw IllegalStateException("Should not reach here!")
}
abstract class AbstractKryoSerializationScheme : SerializationScheme {
object KryoSerializationScheme : CheckpointSerializationScheme {
private val kryoPoolsForContexts = ConcurrentHashMap<Pair<ClassWhitelist, ClassLoader>, KryoPool>()
protected abstract fun rpcClientKryoPool(context: SerializationContext): KryoPool
protected abstract fun rpcServerKryoPool(context: SerializationContext): KryoPool
// this can be overridden in derived serialization schemes
protected open val publicKeySerializer: Serializer<PublicKey> = PublicKeySerializer
private fun getPool(context: SerializationContext): KryoPool {
private fun getPool(context: CheckpointSerializationContext): KryoPool {
return kryoPoolsForContexts.computeIfAbsent(Pair(context.whitelist, context.deserializationClassLoader)) {
when (context.useCase) {
SerializationContext.UseCase.Checkpoint ->
KryoPool.Builder {
val serializer = Fiber.getFiberSerializer(false) as KryoSerializer
val classResolver = CordaClassResolver(context).apply { setKryo(serializer.kryo) }
// TODO The ClassResolver can only be set in the Kryo constructor and Quasar doesn't provide us with a way of doing that
val field = Kryo::class.java.getDeclaredField("classResolver").apply { isAccessible = true }
serializer.kryo.apply {
field.set(this, classResolver)
// don't allow overriding the public key serializer for checkpointing
DefaultKryoCustomizer.customize(this)
addDefaultSerializer(AutoCloseable::class.java, AutoCloseableSerialisationDetector)
register(ClosureSerializer.Closure::class.java, CordaClosureSerializer)
classLoader = it.second
}
}.build()
SerializationContext.UseCase.RPCClient ->
rpcClientKryoPool(context)
SerializationContext.UseCase.RPCServer ->
rpcServerKryoPool(context)
else ->
KryoPool.Builder {
DefaultKryoCustomizer.customize(CordaKryo(CordaClassResolver(context)), publicKeySerializer).apply { classLoader = it.second }
}.build()
}
KryoPool.Builder {
val serializer = Fiber.getFiberSerializer(false) as KryoSerializer
val classResolver = CordaClassResolver(context).apply { setKryo(serializer.kryo) }
// TODO The ClassResolver can only be set in the Kryo constructor and Quasar doesn't provide us with a way of doing that
val field = Kryo::class.java.getDeclaredField("classResolver").apply { isAccessible = true }
serializer.kryo.apply {
field.set(this, classResolver)
// don't allow overriding the public key serializer for checkpointing
DefaultKryoCustomizer.customize(this)
addDefaultSerializer(AutoCloseable::class.java, AutoCloseableSerialisationDetector)
register(ClosureSerializer.Closure::class.java, CordaClosureSerializer)
classLoader = it.second
}
}.build()
}
}
private fun <T : Any> SerializationContext.kryo(task: Kryo.() -> T): T {
private fun <T : Any> CheckpointSerializationContext.kryo(task: Kryo.() -> T): T {
return getPool(this).run { kryo ->
kryo.context.ensureCapacity(properties.size)
properties.forEach { kryo.context.put(it.key, it.value) }
@ -84,7 +67,7 @@ abstract class AbstractKryoSerializationScheme : SerializationScheme {
}
}
override fun <T : Any> deserialize(byteSequence: ByteSequence, clazz: Class<T>, context: SerializationContext): T {
override fun <T : Any> deserialize(byteSequence: ByteSequence, clazz: Class<T>, context: CheckpointSerializationContext): T {
val dataBytes = kryoMagic.consume(byteSequence)
?: throw KryoException("Serialized bytes header does not match expected format.")
return context.kryo {
@ -112,7 +95,7 @@ abstract class AbstractKryoSerializationScheme : SerializationScheme {
}
}
override fun <T : Any> serialize(obj: T, context: SerializationContext): SerializedBytes<T> {
override fun <T : Any> serialize(obj: T, context: CheckpointSerializationContext): SerializedBytes<T> {
return context.kryo {
SerializedBytes(kryoOutput {
kryoMagic.writeTo(this)
@ -132,13 +115,11 @@ abstract class AbstractKryoSerializationScheme : SerializationScheme {
}
}
val KRYO_CHECKPOINT_CONTEXT = SerializationContextImpl(
kryoMagic,
val KRYO_CHECKPOINT_CONTEXT = CheckpointSerializationContextImpl(
SerializationDefaults.javaClass.classLoader,
QuasarWhitelist,
emptyMap(),
true,
SerializationContext.UseCase.Checkpoint,
SNAPPY,
AlwaysAcceptEncodingWhitelist
)

View File

@ -1,14 +0,0 @@
package net.corda.node.serialization.kryo
import com.esotericsoftware.kryo.pool.KryoPool
import net.corda.core.serialization.SerializationContext
import net.corda.serialization.internal.CordaSerializationMagic
class KryoServerSerializationScheme : AbstractKryoSerializationScheme() {
override fun canDeserializeVersion(magic: CordaSerializationMagic, target: SerializationContext.UseCase): Boolean {
return magic == kryoMagic && target == SerializationContext.UseCase.Checkpoint
}
override fun rpcClientKryoPool(context: SerializationContext): KryoPool = throw UnsupportedOperationException()
override fun rpcServerKryoPool(context: SerializationContext): KryoPool = throw UnsupportedOperationException()
}

View File

@ -8,7 +8,7 @@ interface NodePropertiesStore {
interface FlowsDrainingModeOperations {
fun setEnabled(enabled: Boolean)
fun setEnabled(enabled: Boolean, propagateChange: Boolean = true)
fun isEnabled(): Boolean

View File

@ -57,12 +57,13 @@ class FlowsDrainingModeOperationsImpl(readPhysicalNodeId: () -> String, private
override val values = PublishSubject.create<Pair<Boolean, Boolean>>()!!
override fun setEnabled(enabled: Boolean) {
var oldValue: Boolean? = null
persistence.transaction {
oldValue = map.put(nodeSpecificFlowsExecutionModeKey, enabled.toString())?.toBoolean() ?: false
override fun setEnabled(enabled: Boolean, propagateChange: Boolean) {
val oldValue = persistence.transaction {
map.put(nodeSpecificFlowsExecutionModeKey, enabled.toString())?.toBoolean() ?: false
}
if (propagateChange) {
values.onNext(oldValue to enabled)
}
values.onNext(oldValue!! to enabled)
}
override fun isEnabled(): Boolean {

View File

@ -4,9 +4,9 @@ import co.paralleluniverse.fibers.Fiber
import co.paralleluniverse.fibers.Suspendable
import com.codahale.metrics.*
import net.corda.core.internal.concurrent.thenMatch
import net.corda.core.serialization.SerializationContext
import net.corda.core.serialization.SerializedBytes
import net.corda.core.serialization.serialize
import net.corda.core.serialization.*
import net.corda.core.serialization.internal.CheckpointSerializationContext
import net.corda.core.serialization.internal.checkpointSerialize
import net.corda.core.utilities.contextLogger
import net.corda.core.utilities.trace
import net.corda.node.services.api.CheckpointStorage
@ -27,7 +27,7 @@ class ActionExecutorImpl(
private val checkpointStorage: CheckpointStorage,
private val flowMessaging: FlowMessaging,
private val stateMachineManager: StateMachineManagerInternal,
private val checkpointSerializationContext: SerializationContext,
private val checkpointSerializationContext: CheckpointSerializationContext,
metrics: MetricRegistry
) : ActionExecutor {
@ -237,7 +237,7 @@ class ActionExecutorImpl(
}
private fun serializeCheckpoint(checkpoint: Checkpoint): SerializedBytes<Checkpoint> {
return checkpoint.serialize(context = checkpointSerializationContext)
return checkpoint.checkpointSerialize(context = checkpointSerializationContext)
}
private fun cancelFlowTimeout(action: Action.CancelFlowTimeout) {

View File

@ -12,8 +12,8 @@ import net.corda.core.cordapp.Cordapp
import net.corda.core.flows.*
import net.corda.core.identity.Party
import net.corda.core.internal.*
import net.corda.core.serialization.SerializationContext
import net.corda.core.serialization.serialize
import net.corda.core.serialization.internal.CheckpointSerializationContext
import net.corda.core.serialization.internal.checkpointSerialize
import net.corda.core.utilities.Try
import net.corda.core.utilities.debug
import net.corda.core.utilities.trace
@ -70,7 +70,7 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
val actionExecutor: ActionExecutor,
val stateMachine: StateMachine,
val serviceHub: ServiceHubInternal,
val checkpointSerializationContext: SerializationContext,
val checkpointSerializationContext: CheckpointSerializationContext,
val unfinishedFibers: ReusableLatch
)
@ -373,7 +373,7 @@ class FlowStateMachineImpl<R>(override val id: StateMachineRunId,
Event.Suspend(
ioRequest = ioRequest,
maySkipCheckpoint = skipPersistingCheckpoint,
fiber = this.serialize(context = serializationContext.value)
fiber = this.checkpointSerialize(context = serializationContext.value)
)
} catch (throwable: Throwable) {
Event.Error(throwable)

View File

@ -19,6 +19,10 @@ import net.corda.core.internal.concurrent.map
import net.corda.core.internal.concurrent.openFuture
import net.corda.core.messaging.DataFeed
import net.corda.core.serialization.*
import net.corda.core.serialization.internal.CheckpointSerializationContext
import net.corda.core.serialization.internal.CheckpointSerializationDefaults
import net.corda.core.serialization.internal.checkpointDeserialize
import net.corda.core.serialization.internal.checkpointSerialize
import net.corda.core.utilities.ProgressTracker
import net.corda.core.utilities.Try
import net.corda.core.utilities.contextLogger
@ -35,7 +39,7 @@ import net.corda.node.services.statemachine.transitions.StateMachine
import net.corda.node.utilities.AffinityExecutor
import net.corda.node.utilities.injectOldProgressTracker
import net.corda.nodeapi.internal.persistence.CordaPersistence
import net.corda.serialization.internal.SerializeAsTokenContextImpl
import net.corda.serialization.internal.CheckpointSerializeAsTokenContextImpl
import net.corda.serialization.internal.withTokenContext
import org.apache.activemq.artemis.utils.ReusableLatch
import rx.Observable
@ -110,7 +114,7 @@ class MultiThreadedStateMachineManager(
private val transitionExecutor = makeTransitionExecutor()
private val ourSenderUUID get() = serviceHub.networkService.ourSenderUUID
private var checkpointSerializationContext: SerializationContext? = null
private var checkpointSerializationContext: CheckpointSerializationContext? = null
private var tokenizableServices: List<Any>? = null
private var actionExecutor: ActionExecutor? = null
@ -134,8 +138,8 @@ class MultiThreadedStateMachineManager(
override fun start(tokenizableServices: List<Any>) {
checkQuasarJavaAgentPresence()
this.tokenizableServices = tokenizableServices
val checkpointSerializationContext = SerializationDefaults.CHECKPOINT_CONTEXT.withTokenContext(
SerializeAsTokenContextImpl(tokenizableServices, SerializationDefaults.SERIALIZATION_FACTORY, SerializationDefaults.CHECKPOINT_CONTEXT, serviceHub)
val checkpointSerializationContext = CheckpointSerializationDefaults.CHECKPOINT_CONTEXT.withTokenContext(
CheckpointSerializeAsTokenContextImpl(tokenizableServices, CheckpointSerializationDefaults.CHECKPOINT_SERIALIZATION_FACTORY, CheckpointSerializationDefaults.CHECKPOINT_CONTEXT, serviceHub)
)
this.checkpointSerializationContext = checkpointSerializationContext
this.actionExecutor = makeActionExecutor(checkpointSerializationContext)
@ -535,7 +539,7 @@ class MultiThreadedStateMachineManager(
val resultFuture = openFuture<Any?>()
flowStateMachineImpl.transientValues = TransientReference(createTransientValues(flowId, resultFuture))
flowLogic.stateMachine = flowStateMachineImpl
val frozenFlowLogic = (flowLogic as FlowLogic<*>).serialize(context = checkpointSerializationContext!!)
val frozenFlowLogic = (flowLogic as FlowLogic<*>).checkpointSerialize(context = checkpointSerializationContext!!)
val flowCorDappVersion = FlowStateMachineImpl.createSubFlowVersion(serviceHub.cordappProvider.getCordappForFlow(flowLogic), serviceHub.myInfo.platformVersion)
val initialCheckpoint = Checkpoint.create(invocationContext, flowStart, flowLogic.javaClass, frozenFlowLogic, ourIdentity, flowCorDappVersion).getOrThrow()
@ -616,7 +620,7 @@ class MultiThreadedStateMachineManager(
private fun deserializeCheckpoint(serializedCheckpoint: SerializedBytes<Checkpoint>): Checkpoint? {
return try {
serializedCheckpoint.deserialize(context = checkpointSerializationContext!!)
serializedCheckpoint.checkpointDeserialize(context = checkpointSerializationContext!!)
} catch (exception: Throwable) {
logger.error("Encountered unrestorable checkpoint!", exception)
null
@ -661,7 +665,7 @@ class MultiThreadedStateMachineManager(
val resultFuture = openFuture<Any?>()
val fiber = when (flowState) {
is FlowState.Unstarted -> {
val logic = flowState.frozenFlowLogic.deserialize(context = checkpointSerializationContext!!)
val logic = flowState.frozenFlowLogic.checkpointDeserialize(context = checkpointSerializationContext!!)
val state = StateMachineState(
checkpoint = checkpoint,
pendingDeduplicationHandlers = initialDeduplicationHandler?.let { listOf(it) } ?: emptyList(),
@ -680,7 +684,7 @@ class MultiThreadedStateMachineManager(
fiber
}
is FlowState.Started -> {
val fiber = flowState.frozenFiber.deserialize(context = checkpointSerializationContext!!)
val fiber = flowState.frozenFiber.checkpointDeserialize(context = checkpointSerializationContext!!)
val state = StateMachineState(
checkpoint = checkpoint,
pendingDeduplicationHandlers = initialDeduplicationHandler?.let { listOf(it) } ?: emptyList(),
@ -740,7 +744,7 @@ class MultiThreadedStateMachineManager(
}
}
private fun makeActionExecutor(checkpointSerializationContext: SerializationContext): ActionExecutor {
private fun makeActionExecutor(checkpointSerializationContext: CheckpointSerializationContext): ActionExecutor {
return ActionExecutorImpl(
serviceHub,
checkpointStorage,

View File

@ -19,6 +19,10 @@ import net.corda.core.internal.concurrent.map
import net.corda.core.internal.concurrent.openFuture
import net.corda.core.messaging.DataFeed
import net.corda.core.serialization.*
import net.corda.core.serialization.internal.CheckpointSerializationContext
import net.corda.core.serialization.internal.CheckpointSerializationDefaults
import net.corda.core.serialization.internal.checkpointDeserialize
import net.corda.core.serialization.internal.checkpointSerialize
import net.corda.core.utilities.ProgressTracker
import net.corda.core.utilities.Try
import net.corda.core.utilities.contextLogger
@ -36,7 +40,7 @@ import net.corda.node.utilities.AffinityExecutor
import net.corda.node.utilities.injectOldProgressTracker
import net.corda.nodeapi.internal.persistence.CordaPersistence
import net.corda.nodeapi.internal.persistence.wrapWithDatabaseTransaction
import net.corda.serialization.internal.SerializeAsTokenContextImpl
import net.corda.serialization.internal.CheckpointSerializeAsTokenContextImpl
import net.corda.serialization.internal.withTokenContext
import org.apache.activemq.artemis.utils.ReusableLatch
import rx.Observable
@ -103,7 +107,7 @@ class SingleThreadedStateMachineManager(
private val transitionExecutor = makeTransitionExecutor()
private val ourSenderUUID = serviceHub.networkService.ourSenderUUID
private var checkpointSerializationContext: SerializationContext? = null
private var checkpointSerializationContext: CheckpointSerializationContext? = null
private var actionExecutor: ActionExecutor? = null
override val allStateMachines: List<FlowLogic<*>>
@ -122,8 +126,8 @@ class SingleThreadedStateMachineManager(
override fun start(tokenizableServices: List<Any>) {
checkQuasarJavaAgentPresence()
val checkpointSerializationContext = SerializationDefaults.CHECKPOINT_CONTEXT.withTokenContext(
SerializeAsTokenContextImpl(tokenizableServices, SerializationDefaults.SERIALIZATION_FACTORY, SerializationDefaults.CHECKPOINT_CONTEXT, serviceHub)
val checkpointSerializationContext = CheckpointSerializationDefaults.CHECKPOINT_CONTEXT.withTokenContext(
CheckpointSerializeAsTokenContextImpl(tokenizableServices, CheckpointSerializationDefaults.CHECKPOINT_SERIALIZATION_FACTORY, CheckpointSerializationDefaults.CHECKPOINT_CONTEXT, serviceHub)
)
this.checkpointSerializationContext = checkpointSerializationContext
this.actionExecutor = makeActionExecutor(checkpointSerializationContext)
@ -531,7 +535,7 @@ class SingleThreadedStateMachineManager(
val resultFuture = openFuture<Any?>()
flowStateMachineImpl.transientValues = TransientReference(createTransientValues(flowId, resultFuture))
flowLogic.stateMachine = flowStateMachineImpl
val frozenFlowLogic = (flowLogic as FlowLogic<*>).serialize(context = checkpointSerializationContext!!)
val frozenFlowLogic = (flowLogic as FlowLogic<*>).checkpointSerialize(context = checkpointSerializationContext!!)
val flowCorDappVersion = createSubFlowVersion(serviceHub.cordappProvider.getCordappForFlow(flowLogic), serviceHub.myInfo.platformVersion)
@ -613,7 +617,7 @@ class SingleThreadedStateMachineManager(
private fun deserializeCheckpoint(serializedCheckpoint: SerializedBytes<Checkpoint>): Checkpoint? {
return try {
serializedCheckpoint.deserialize(context = checkpointSerializationContext!!)
serializedCheckpoint.checkpointDeserialize(context = checkpointSerializationContext!!)
} catch (exception: Throwable) {
logger.error("Encountered unrestorable checkpoint!", exception)
null
@ -658,7 +662,7 @@ class SingleThreadedStateMachineManager(
val resultFuture = openFuture<Any?>()
val fiber = when (flowState) {
is FlowState.Unstarted -> {
val logic = flowState.frozenFlowLogic.deserialize(context = checkpointSerializationContext!!)
val logic = flowState.frozenFlowLogic.checkpointDeserialize(context = checkpointSerializationContext!!)
val state = StateMachineState(
checkpoint = checkpoint,
pendingDeduplicationHandlers = initialDeduplicationHandler?.let { listOf(it) } ?: emptyList(),
@ -677,7 +681,7 @@ class SingleThreadedStateMachineManager(
fiber
}
is FlowState.Started -> {
val fiber = flowState.frozenFiber.deserialize(context = checkpointSerializationContext!!)
val fiber = flowState.frozenFiber.checkpointDeserialize(context = checkpointSerializationContext!!)
val state = StateMachineState(
checkpoint = checkpoint,
pendingDeduplicationHandlers = initialDeduplicationHandler?.let { listOf(it) } ?: emptyList(),
@ -742,7 +746,7 @@ class SingleThreadedStateMachineManager(
}
}
private fun makeActionExecutor(checkpointSerializationContext: SerializationContext): ActionExecutor {
private fun makeActionExecutor(checkpointSerializationContext: CheckpointSerializationContext): ActionExecutor {
return ActionExecutorImpl(
serviceHub,
checkpointStorage,

View File

@ -2,9 +2,9 @@ package net.corda.node.services.statemachine.interceptors
import co.paralleluniverse.fibers.Suspendable
import net.corda.core.flows.StateMachineRunId
import net.corda.core.serialization.SerializationContext
import net.corda.core.serialization.SerializedBytes
import net.corda.core.serialization.deserialize
import net.corda.core.serialization.*
import net.corda.core.serialization.internal.CheckpointSerializationContext
import net.corda.core.serialization.internal.checkpointDeserialize
import net.corda.core.utilities.contextLogger
import net.corda.node.services.statemachine.ActionExecutor
import net.corda.node.services.statemachine.Event
@ -68,7 +68,7 @@ class FiberDeserializationChecker {
private val jobQueue = LinkedBlockingQueue<Job>()
private var foundUnrestorableFibers: Boolean = false
fun start(checkpointSerializationContext: SerializationContext) {
fun start(checkpointSerializationContext: CheckpointSerializationContext) {
require(checkerThread == null)
checkerThread = thread(name = "FiberDeserializationChecker") {
while (true) {
@ -76,7 +76,7 @@ class FiberDeserializationChecker {
when (job) {
is Job.Check -> {
try {
job.serializedFiber.deserialize(context = checkpointSerializationContext)
job.serializedFiber.checkpointDeserialize(context = checkpointSerializationContext)
} catch (throwable: Throwable) {
log.error("Encountered unrestorable checkpoint!", throwable)
foundUnrestorableFibers = true

View File

@ -20,10 +20,10 @@ import net.corda.core.flows.StateConsumptionDetails
import net.corda.core.internal.VisibleForTesting
import net.corda.core.internal.notary.isConsumedByTheSameTx
import net.corda.core.internal.notary.validateTimeWindow
import net.corda.core.serialization.SerializationDefaults
import net.corda.core.serialization.SerializationFactory
import net.corda.core.serialization.deserialize
import net.corda.core.serialization.serialize
import net.corda.core.serialization.*
import net.corda.core.serialization.internal.CheckpointSerializationDefaults
import net.corda.core.serialization.internal.CheckpointSerializationFactory
import net.corda.core.serialization.internal.checkpointSerialize
import net.corda.core.utilities.ByteSequence
import net.corda.core.utilities.contextLogger
import net.corda.core.utilities.debug
@ -200,11 +200,11 @@ class RaftTransactionCommitLog<E, EK>(
}
class CordaKryoSerializer<T : Any> : TypeSerializer<T> {
private val context = SerializationDefaults.CHECKPOINT_CONTEXT.withEncoding(CordaSerializationEncoding.SNAPPY)
private val factory = SerializationFactory.defaultFactory
private val context = CheckpointSerializationDefaults.CHECKPOINT_CONTEXT.withEncoding(CordaSerializationEncoding.SNAPPY)
private val factory = CheckpointSerializationFactory.defaultFactory
override fun write(obj: T, buffer: BufferOutput<*>, serializer: Serializer) {
val serialized = obj.serialize(context = context)
val serialized = obj.checkpointSerialize(context = context)
buffer.writeInt(serialized.size)
buffer.write(serialized.bytes)
}

View File

@ -533,16 +533,16 @@ class HibernateQueryCriteriaParser(val contractStateType: Class<out ContractStat
}
// state relevance.
if (criteria.isRelevant != Vault.RelevancyStatus.ALL) {
val predicateID = Pair(VaultSchemaV1.VaultStates::isRelevant.name, EQUAL)
if (criteria.relevancyStatus != Vault.RelevancyStatus.ALL) {
val predicateID = Pair(VaultSchemaV1.VaultStates::relevancyStatus.name, EQUAL)
if (commonPredicates.containsKey(predicateID)) {
val existingStatus = ((commonPredicates[predicateID] as ComparisonPredicate).rightHandOperand as LiteralExpression).literal
if (existingStatus != criteria.isRelevant) {
log.warn("Overriding previous attribute [${VaultSchemaV1.VaultStates::isRelevant.name}] value $existingStatus with ${criteria.status}")
commonPredicates.replace(predicateID, criteriaBuilder.equal(vaultStates.get<Vault.RelevancyStatus>(VaultSchemaV1.VaultStates::isRelevant.name), criteria.isRelevant))
if (existingStatus != criteria.relevancyStatus) {
log.warn("Overriding previous attribute [${VaultSchemaV1.VaultStates::relevancyStatus.name}] value $existingStatus with ${criteria.status}")
commonPredicates.replace(predicateID, criteriaBuilder.equal(vaultStates.get<Vault.RelevancyStatus>(VaultSchemaV1.VaultStates::relevancyStatus.name), criteria.relevancyStatus))
}
} else {
commonPredicates[predicateID] = criteriaBuilder.equal(vaultStates.get<Vault.RelevancyStatus>(VaultSchemaV1.VaultStates::isRelevant.name), criteria.isRelevant)
commonPredicates[predicateID] = criteriaBuilder.equal(vaultStates.get<Vault.RelevancyStatus>(VaultSchemaV1.VaultStates::relevancyStatus.name), criteria.relevancyStatus)
}
}

View File

@ -5,6 +5,7 @@ import co.paralleluniverse.strands.Strand
import com.github.benmanes.caffeine.cache.Caffeine
import net.corda.core.contracts.*
import net.corda.core.crypto.SecureHash
import net.corda.core.crypto.containsAny
import net.corda.core.internal.*
import net.corda.core.messaging.DataFeed
import net.corda.core.node.ServicesForResolution
@ -133,10 +134,12 @@ class NodeVaultService(
// For EVERY state to be committed to the vault, this checks whether it is spendable by the recording
// node. The behaviour is as follows:
//
// 1) All vault updates marked as RELEVANT will, of, course all have isRelevant = true.
// 2) For ALL_VISIBLE updates, those which are not relevant according to the relevancy rules will have isRelevant = false.
// 1) All vault updates marked as RELEVANT will, of course, all have relevancy_status = 1 in the
// "vault_states" table.
// 2) For ALL_VISIBLE updates, those which are not relevant according to the relevancy rules will have
// relevancy_status = 0 in the "vault_states" table.
//
// This is useful when it comes to querying for fungible states, when we do not want non-relevant states
// This is useful when it comes to querying for fungible states, when we do not want irrelevant states
// included in the result.
//
// The same functionality could be obtained by passing in a list of participants to the vault query,
@ -156,7 +159,7 @@ class NodeVaultService(
lockId = uuid,
lockUpdateTime = if (uuid == null) null else now,
recordedTime = clock.instant(),
isRelevant = if (isRelevant) Vault.RelevancyStatus.RELEVANT else Vault.RelevancyStatus.NOT_RELEVANT
relevancyStatus = if (isRelevant) Vault.RelevancyStatus.RELEVANT else Vault.RelevancyStatus.NOT_RELEVANT
)
stateToAdd.stateRef = PersistentStateRef(stateAndRef.key)
session.save(stateToAdd)
@ -454,7 +457,7 @@ class NodeVaultService(
val enrichedCriteria = QueryCriteria.VaultQueryCriteria(
contractStateTypes = setOf(contractStateType),
softLockingCondition = QueryCriteria.SoftLockingCondition(QueryCriteria.SoftLockingType.UNLOCKED_AND_SPECIFIED, listOf(lockId)),
isRelevant = Vault.RelevancyStatus.RELEVANT
relevancyStatus = Vault.RelevancyStatus.RELEVANT
)
val results = queryBy(contractStateType, enrichedCriteria.and(eligibleStatesQuery), sorter)
@ -485,7 +488,7 @@ class NodeVaultService(
is OwnableState -> (state.participants.map { it.owningKey } + state.owner.owningKey).toSet()
else -> state.participants.map { it.owningKey }
}
return keysToCheck.any { it in myKeys }
return keysToCheck.any { it.containsAny(myKeys) }
}
@Throws(VaultQueryException::class)
@ -561,7 +564,7 @@ class NodeVaultService(
vaultState.notary,
vaultState.lockId,
vaultState.lockUpdateTime,
vaultState.isRelevant))
vaultState.relevancyStatus))
} else {
// TODO: improve typing of returned other results
log.debug { "OtherResults: ${Arrays.toString(result.toArray())}" }

View File

@ -61,8 +61,8 @@ object VaultSchemaV1 : MappedSchema(schemaFamily = VaultSchema.javaClass, versio
var lockId: String? = null,
/** Used to determine whether a state abides by the relevancy rules of the recording node */
@Column(name = "is_relevant", nullable = false)
var isRelevant: Vault.RelevancyStatus,
@Column(name = "relevancy_status", nullable = false)
var relevancyStatus: Vault.RelevancyStatus,
/** refers to the last time a lock was taken (reserved) or updated (released, re-reserved) */
@Column(name = "lock_timestamp", nullable = true)

View File

@ -2,14 +2,18 @@
<databaseChangeLog xmlns="http://www.liquibase.org/xml/ns/dbchangelog"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://www.liquibase.org/xml/ns/dbchangelog http://www.liquibase.org/xml/ns/dbchangelog/dbchangelog-3.5.xsd">
<changeSet author="R3.Corda" id="add_is_relevant_column">
<preConditions onFail="MARK_RAN"><not><columnExists tableName="vault_states" columnName="is_relevant"/></not></preConditions>
<changeSet author="R3.Corda" id="add_relevancy_status_column">
<preConditions onFail="MARK_RAN">
<not>
<columnExists tableName="vault_states" columnName="relevancy_status"/>
</not>
</preConditions>
<addColumn tableName="vault_states">
<column name="is_relevant" type="INT"/>
<column name="relevancy_status" type="INT"/>
</addColumn>
<update tableName="vault_states">
<column name="is_relevant" valueNumeric="0"/>
<column name="relevancy_status" valueNumeric="0"/>
</update>
<addNotNullConstraint tableName="vault_states" columnName="is_relevant" columnDataType="INT" />
<addNotNullConstraint tableName="vault_states" columnName="relevancy_status" columnDataType="INT"/>
</changeSet>
</databaseChangeLog>

View File

@ -1,43 +0,0 @@
package net.corda.node
import net.corda.core.internal.div
import net.corda.node.internal.NodeStartup
import net.corda.nodeapi.internal.config.UnknownConfigKeysPolicy
import org.assertj.core.api.Assertions.assertThat
import org.junit.BeforeClass
import org.junit.Test
import org.slf4j.event.Level
import java.nio.file.Path
import java.nio.file.Paths
class NodeCmdLineOptionsTest {
private val parser = NodeStartup()
companion object {
private lateinit var workingDirectory: Path
@BeforeClass
@JvmStatic
fun initDirectories() {
workingDirectory = Paths.get(".").normalize().toAbsolutePath()
}
}
@Test
fun `no command line arguments`() {
assertThat(parser.cmdLineOptions.baseDirectory).isEqualTo(workingDirectory)
assertThat(parser.cmdLineOptions.configFile).isEqualTo(workingDirectory / "node.conf")
assertThat(parser.verbose).isEqualTo(false)
assertThat(parser.loggingLevel).isEqualTo(Level.INFO)
assertThat(parser.cmdLineOptions.nodeRegistrationOption).isEqualTo(null)
assertThat(parser.cmdLineOptions.noLocalShell).isEqualTo(false)
assertThat(parser.cmdLineOptions.sshdServer).isEqualTo(false)
assertThat(parser.cmdLineOptions.justGenerateNodeInfo).isEqualTo(false)
assertThat(parser.cmdLineOptions.justGenerateRpcSslCerts).isEqualTo(false)
assertThat(parser.cmdLineOptions.bootstrapRaftCluster).isEqualTo(false)
assertThat(parser.cmdLineOptions.unknownConfigKeysPolicy).isEqualTo(UnknownConfigKeysPolicy.FAIL)
assertThat(parser.cmdLineOptions.devMode).isEqualTo(null)
assertThat(parser.cmdLineOptions.clearNetworkMapCache).isEqualTo(false)
assertThat(parser.cmdLineOptions.networkRootTrustStorePath).isEqualTo(workingDirectory / "certificates" / "network-root-truststore.jks")
}
}

View File

@ -0,0 +1,52 @@
package net.corda.node.internal
import net.corda.core.internal.div
import net.corda.nodeapi.internal.config.UnknownConfigKeysPolicy
import org.assertj.core.api.Assertions.assertThat
import org.junit.BeforeClass
import org.junit.Test
import org.slf4j.event.Level
import picocli.CommandLine
import java.nio.file.Path
import java.nio.file.Paths
class NodeStartupTest {
private val startup = NodeStartup()
companion object {
private lateinit var workingDirectory: Path
@BeforeClass
@JvmStatic
fun initDirectories() {
workingDirectory = Paths.get(".").normalize().toAbsolutePath()
}
}
@Test
fun `no command line arguments`() {
CommandLine.populateCommand(startup)
assertThat(startup.cmdLineOptions.baseDirectory).isEqualTo(workingDirectory)
assertThat(startup.cmdLineOptions.configFile).isEqualTo(workingDirectory / "node.conf")
assertThat(startup.verbose).isEqualTo(false)
assertThat(startup.loggingLevel).isEqualTo(Level.INFO)
assertThat(startup.cmdLineOptions.nodeRegistrationOption).isEqualTo(null)
assertThat(startup.cmdLineOptions.noLocalShell).isEqualTo(false)
assertThat(startup.cmdLineOptions.sshdServer).isEqualTo(false)
assertThat(startup.cmdLineOptions.justGenerateNodeInfo).isEqualTo(false)
assertThat(startup.cmdLineOptions.justGenerateRpcSslCerts).isEqualTo(false)
assertThat(startup.cmdLineOptions.bootstrapRaftCluster).isEqualTo(false)
assertThat(startup.cmdLineOptions.unknownConfigKeysPolicy).isEqualTo(UnknownConfigKeysPolicy.FAIL)
assertThat(startup.cmdLineOptions.devMode).isEqualTo(null)
assertThat(startup.cmdLineOptions.clearNetworkMapCache).isEqualTo(false)
assertThat(startup.cmdLineOptions.networkRootTrustStorePath).isEqualTo(workingDirectory / "certificates" / "network-root-truststore.jks")
}
@Test
fun `--base-directory`() {
CommandLine.populateCommand(startup, "--base-directory", (workingDirectory / "another-base-dir").toString())
assertThat(startup.cmdLineOptions.baseDirectory).isEqualTo(workingDirectory / "another-base-dir")
assertThat(startup.cmdLineOptions.configFile).isEqualTo(workingDirectory / "another-base-dir" / "node.conf")
assertThat(startup.cmdLineOptions.networkRootTrustStorePath).isEqualTo(workingDirectory / "another-base-dir" / "certificates" / "network-root-truststore.jks")
}
}

View File

@ -5,7 +5,6 @@ import com.esotericsoftware.kryo.KryoException
import com.esotericsoftware.kryo.KryoSerializable
import com.esotericsoftware.kryo.io.Input
import com.esotericsoftware.kryo.io.Output
import com.esotericsoftware.kryo.pool.KryoPool
import com.google.common.primitives.Ints
import com.nhaarman.mockito_kotlin.doReturn
import com.nhaarman.mockito_kotlin.whenever
@ -13,6 +12,10 @@ import net.corda.core.contracts.PrivacySalt
import net.corda.core.crypto.*
import net.corda.core.internal.FetchDataFlow
import net.corda.core.serialization.*
import net.corda.core.serialization.internal.CheckpointSerializationContext
import net.corda.core.serialization.internal.CheckpointSerializationFactory
import net.corda.core.serialization.internal.checkpointDeserialize
import net.corda.core.serialization.internal.checkpointSerialize
import net.corda.core.utilities.ByteSequence
import net.corda.core.utilities.ProgressTracker
import net.corda.core.utilities.sequence
@ -36,16 +39,6 @@ import java.util.*
import kotlin.collections.ArrayList
import kotlin.test.*
class TestScheme : AbstractKryoSerializationScheme() {
override fun canDeserializeVersion(magic: CordaSerializationMagic, target: SerializationContext.UseCase): Boolean {
return magic == kryoMagic && target != SerializationContext.UseCase.RPCClient
}
override fun rpcClientKryoPool(context: SerializationContext): KryoPool = throw UnsupportedOperationException()
override fun rpcServerKryoPool(context: SerializationContext): KryoPool = throw UnsupportedOperationException()
}
@RunWith(Parameterized::class)
class KryoTests(private val compression: CordaSerializationEncoding?) {
companion object {
@ -55,18 +48,17 @@ class KryoTests(private val compression: CordaSerializationEncoding?) {
fun compression() = arrayOf<CordaSerializationEncoding?>(null) + CordaSerializationEncoding.values()
}
private lateinit var factory: SerializationFactory
private lateinit var context: SerializationContext
private lateinit var factory: CheckpointSerializationFactory
private lateinit var context: CheckpointSerializationContext
@Before
fun setup() {
factory = SerializationFactoryImpl().apply { registerScheme(TestScheme()) }
context = SerializationContextImpl(kryoMagic,
factory = CheckpointSerializationFactory(KryoSerializationScheme)
context = CheckpointSerializationContextImpl(
javaClass.classLoader,
AllWhitelist,
emptyMap(),
true,
SerializationContext.UseCase.Storage,
compression,
rigorousMock<EncodingWhitelist>().also {
if (compression != null) doReturn(true).whenever(it).acceptEncoding(compression)
@ -77,15 +69,15 @@ class KryoTests(private val compression: CordaSerializationEncoding?) {
fun `simple data class`() {
val birthday = Instant.parse("1984-04-17T00:30:00.00Z")
val mike = Person("mike", birthday)
val bits = mike.serialize(factory, context)
assertThat(bits.deserialize(factory, context)).isEqualTo(Person("mike", birthday))
val bits = mike.checkpointSerialize(factory, context)
assertThat(bits.checkpointDeserialize(factory, context)).isEqualTo(Person("mike", birthday))
}
@Test
fun `null values`() {
val bob = Person("bob", null)
val bits = bob.serialize(factory, context)
assertThat(bits.deserialize(factory, context)).isEqualTo(Person("bob", null))
val bits = bob.checkpointSerialize(factory, context)
assertThat(bits.checkpointDeserialize(factory, context)).isEqualTo(Person("bob", null))
}
@Test
@ -93,10 +85,10 @@ class KryoTests(private val compression: CordaSerializationEncoding?) {
val noReferencesContext = context.withoutReferences()
val obj : ByteSequence = Ints.toByteArray(0x01234567).sequence()
val originalList : ArrayList<ByteSequence> = ArrayList<ByteSequence>().apply { this += obj }
val deserialisedList = originalList.serialize(factory, noReferencesContext).deserialize(factory, noReferencesContext)
val deserialisedList = originalList.checkpointSerialize(factory, noReferencesContext).checkpointDeserialize(factory, noReferencesContext)
originalList += obj
deserialisedList += obj
assertThat(deserialisedList.serialize(factory, noReferencesContext)).isEqualTo(originalList.serialize(factory, noReferencesContext))
assertThat(deserialisedList.checkpointSerialize(factory, noReferencesContext)).isEqualTo(originalList.checkpointSerialize(factory, noReferencesContext))
}
@Test
@ -113,14 +105,14 @@ class KryoTests(private val compression: CordaSerializationEncoding?) {
this += instant
this += instant
}
assertThat(listWithSameInstances.serialize(factory, noReferencesContext)).isEqualTo(listWithCopies.serialize(factory, noReferencesContext))
assertThat(listWithSameInstances.checkpointSerialize(factory, noReferencesContext)).isEqualTo(listWithCopies.checkpointSerialize(factory, noReferencesContext))
}
@Test
fun `cyclic object graph`() {
val cyclic = Cyclic(3)
val bits = cyclic.serialize(factory, context)
assertThat(bits.deserialize(factory, context)).isEqualTo(cyclic)
val bits = cyclic.checkpointSerialize(factory, context)
assertThat(bits.checkpointDeserialize(factory, context)).isEqualTo(cyclic)
}
@Test
@ -132,7 +124,7 @@ class KryoTests(private val compression: CordaSerializationEncoding?) {
signature.verify(bitsToSign)
assertThatThrownBy { signature.verify(wrongBits) }
val deserialisedKeyPair = keyPair.serialize(factory, context).deserialize(factory, context)
val deserialisedKeyPair = keyPair.checkpointSerialize(factory, context).checkpointDeserialize(factory, context)
val deserialisedSignature = deserialisedKeyPair.sign(bitsToSign)
deserialisedSignature.verify(bitsToSign)
assertThatThrownBy { deserialisedSignature.verify(wrongBits) }
@ -140,28 +132,28 @@ class KryoTests(private val compression: CordaSerializationEncoding?) {
@Test
fun `write and read Kotlin object singleton`() {
val serialised = TestSingleton.serialize(factory, context)
val deserialised = serialised.deserialize(factory, context)
val serialised = TestSingleton.checkpointSerialize(factory, context)
val deserialised = serialised.checkpointDeserialize(factory, context)
assertThat(deserialised).isSameAs(TestSingleton)
}
@Test
fun `check Kotlin EmptyList can be serialised`() {
val deserialisedList: List<Int> = emptyList<Int>().serialize(factory, context).deserialize(factory, context)
val deserialisedList: List<Int> = emptyList<Int>().checkpointSerialize(factory, context).checkpointDeserialize(factory, context)
assertEquals(0, deserialisedList.size)
assertEquals<Any>(Collections.emptyList<Int>().javaClass, deserialisedList.javaClass)
}
@Test
fun `check Kotlin EmptySet can be serialised`() {
val deserialisedSet: Set<Int> = emptySet<Int>().serialize(factory, context).deserialize(factory, context)
val deserialisedSet: Set<Int> = emptySet<Int>().checkpointSerialize(factory, context).checkpointDeserialize(factory, context)
assertEquals(0, deserialisedSet.size)
assertEquals<Any>(Collections.emptySet<Int>().javaClass, deserialisedSet.javaClass)
}
@Test
fun `check Kotlin EmptyMap can be serialised`() {
val deserialisedMap: Map<Int, Int> = emptyMap<Int, Int>().serialize(factory, context).deserialize(factory, context)
val deserialisedMap: Map<Int, Int> = emptyMap<Int, Int>().checkpointSerialize(factory, context).checkpointDeserialize(factory, context)
assertEquals(0, deserialisedMap.size)
assertEquals<Any>(Collections.emptyMap<Int, Int>().javaClass, deserialisedMap.javaClass)
}
@ -169,7 +161,7 @@ class KryoTests(private val compression: CordaSerializationEncoding?) {
@Test
fun `InputStream serialisation`() {
val rubbish = ByteArray(12345) { (it * it * 0.12345).toByte() }
val readRubbishStream: InputStream = rubbish.inputStream().serialize(factory, context).deserialize(factory, context)
val readRubbishStream: InputStream = rubbish.inputStream().checkpointSerialize(factory, context).checkpointDeserialize(factory, context)
for (i in 0..12344) {
assertEquals(rubbish[i], readRubbishStream.read().toByte())
}
@ -179,7 +171,7 @@ class KryoTests(private val compression: CordaSerializationEncoding?) {
@Test
fun `InputStream serialisation does not write trailing garbage`() {
val byteArrays = listOf("123", "456").map { it.toByteArray() }
val streams = byteArrays.map { it.inputStream() }.serialize(factory, context).deserialize(factory, context).iterator()
val streams = byteArrays.map { it.inputStream() }.checkpointSerialize(factory, context).checkpointDeserialize(factory, context).iterator()
byteArrays.forEach { assertArrayEquals(it, streams.next().readBytes()) }
assertFalse(streams.hasNext())
}
@ -190,16 +182,16 @@ class KryoTests(private val compression: CordaSerializationEncoding?) {
val testBytes = testString.toByteArray()
val meta = SignableData(testBytes.sha256(), SignatureMetadata(1, Crypto.findSignatureScheme(ALICE_PUBKEY).schemeNumberID))
val serializedMetaData = meta.serialize(factory, context).bytes
val meta2 = serializedMetaData.deserialize<SignableData>(factory, context)
val serializedMetaData = meta.checkpointSerialize(factory, context).bytes
val meta2 = serializedMetaData.checkpointDeserialize<SignableData>(factory, context)
assertEquals(meta2, meta)
}
@Test
fun `serialize - deserialize Logger`() {
val storageContext: SerializationContext = context // TODO: make it storage context
val storageContext: CheckpointSerializationContext = context
val logger = LoggerFactory.getLogger("aName")
val logger2 = logger.serialize(factory, storageContext).deserialize(factory, storageContext)
val logger2 = logger.checkpointSerialize(factory, storageContext).checkpointDeserialize(factory, storageContext)
assertEquals(logger.name, logger2.name)
assertTrue(logger === logger2)
}
@ -211,7 +203,7 @@ class KryoTests(private val compression: CordaSerializationEncoding?) {
SecureHash.sha256(rubbish),
rubbish.size,
rubbish.inputStream()
).serialize(factory, context).deserialize(factory, context)
).checkpointSerialize(factory, context).checkpointDeserialize(factory, context)
for (i in 0..12344) {
assertEquals(rubbish[i], readRubbishStream.read().toByte())
}
@ -238,8 +230,8 @@ class KryoTests(private val compression: CordaSerializationEncoding?) {
21, 22, 23, 24, 25, 26, 27, 28, 29, 30,
31, 32
))
val serializedBytes = expected.serialize(factory, context)
val actual = serializedBytes.deserialize(factory, context)
val serializedBytes = expected.checkpointSerialize(factory, context)
val actual = serializedBytes.checkpointDeserialize(factory, context)
assertEquals(expected, actual)
}
@ -286,15 +278,14 @@ class KryoTests(private val compression: CordaSerializationEncoding?) {
}
}
Tmp()
val factory = SerializationFactoryImpl().apply { registerScheme(TestScheme()) }
val context = SerializationContextImpl(kryoMagic,
val factory = CheckpointSerializationFactory(KryoSerializationScheme)
val context = CheckpointSerializationContextImpl(
javaClass.classLoader,
AllWhitelist,
emptyMap(),
true,
SerializationContext.UseCase.P2P,
null)
pt.serialize(factory, context)
pt.checkpointSerialize(factory, context)
}
@Test
@ -302,7 +293,7 @@ class KryoTests(private val compression: CordaSerializationEncoding?) {
val exception = IllegalArgumentException("fooBar")
val toBeSuppressedOnSenderSide = IllegalStateException("bazz1")
exception.addSuppressed(toBeSuppressedOnSenderSide)
val exception2 = exception.serialize(factory, context).deserialize(factory, context)
val exception2 = exception.checkpointSerialize(factory, context).checkpointDeserialize(factory, context)
assertEquals(exception.message, exception2.message)
assertEquals(1, exception2.suppressed.size)
@ -317,7 +308,7 @@ class KryoTests(private val compression: CordaSerializationEncoding?) {
@Test
fun `serialize - deserialize Exception no suppressed`() {
val exception = IllegalArgumentException("fooBar")
val exception2 = exception.serialize(factory, context).deserialize(factory, context)
val exception2 = exception.checkpointSerialize(factory, context).checkpointDeserialize(factory, context)
assertEquals(exception.message, exception2.message)
assertEquals(0, exception2.suppressed.size)
@ -331,7 +322,7 @@ class KryoTests(private val compression: CordaSerializationEncoding?) {
fun `serialize - deserialize HashNotFound`() {
val randomHash = SecureHash.randomSHA256()
val exception = FetchDataFlow.HashNotFound(randomHash)
val exception2 = exception.serialize(factory, context).deserialize(factory, context)
val exception2 = exception.checkpointSerialize(factory, context).checkpointDeserialize(factory, context)
assertEquals(randomHash, exception2.requested)
}
@ -339,17 +330,17 @@ class KryoTests(private val compression: CordaSerializationEncoding?) {
fun `compression has the desired effect`() {
compression ?: return
val data = ByteArray(12345).also { Random(0).nextBytes(it) }.let { it + it }
val compressed = data.serialize(factory, context)
val compressed = data.checkpointSerialize(factory, context)
assertEquals(.5, compressed.size.toDouble() / data.size, .03)
assertArrayEquals(data, compressed.deserialize(factory, context))
assertArrayEquals(data, compressed.checkpointDeserialize(factory, context))
}
@Test
fun `a particular encoding can be banned for deserialization`() {
compression ?: return
doReturn(false).whenever(context.encodingWhitelist).acceptEncoding(compression)
val compressed = "whatever".serialize(factory, context)
catchThrowable { compressed.deserialize(factory, context) }.run {
val compressed = "whatever".checkpointSerialize(factory, context)
catchThrowable { compressed.checkpointDeserialize(factory, context) }.run {
assertSame<Any>(KryoException::class.java, javaClass)
assertEquals(encodingNotPermittedFormat.format(compression), message)
}
@ -360,8 +351,8 @@ class KryoTests(private val compression: CordaSerializationEncoding?) {
class Holder(val holder: ByteArray)
val obj = Holder(ByteArray(20000))
val uncompressedSize = obj.serialize(factory, context.withEncoding(null)).size
val compressedSize = obj.serialize(factory, context.withEncoding(CordaSerializationEncoding.SNAPPY)).size
val uncompressedSize = obj.checkpointSerialize(factory, context.withEncoding(null)).size
val compressedSize = obj.checkpointSerialize(factory, context.withEncoding(CordaSerializationEncoding.SNAPPY)).size
// If these need fixing, sounds like Kryo wire format changed and checkpoints might not surive an upgrade.
assertEquals(20222, uncompressedSize)
assertEquals(1111, compressedSize)

View File

@ -3,9 +3,9 @@ package net.corda.node.services.persistence
import net.corda.core.context.InvocationContext
import net.corda.core.flows.FlowLogic
import net.corda.core.flows.StateMachineRunId
import net.corda.core.serialization.SerializationDefaults
import net.corda.core.serialization.internal.CheckpointSerializationDefaults
import net.corda.core.serialization.SerializedBytes
import net.corda.core.serialization.serialize
import net.corda.core.serialization.internal.checkpointSerialize
import net.corda.node.internal.CheckpointIncompatibleException
import net.corda.node.internal.CheckpointVerifier
import net.corda.node.internal.configureDatabase
@ -189,9 +189,9 @@ class DBCheckpointStorageTests {
val logic: FlowLogic<*> = object : FlowLogic<Unit>() {
override fun call() {}
}
val frozenLogic = logic.serialize(context = SerializationDefaults.CHECKPOINT_CONTEXT)
val frozenLogic = logic.checkpointSerialize(context = CheckpointSerializationDefaults.CHECKPOINT_CONTEXT)
val checkpoint = Checkpoint.create(InvocationContext.shell(), FlowStart.Explicit, logic.javaClass, frozenLogic, ALICE, SubFlowVersion.CoreFlow(version)).getOrThrow()
return id to checkpoint.serialize(context = SerializationDefaults.CHECKPOINT_CONTEXT)
return id to checkpoint.checkpointSerialize(context = CheckpointSerializationDefaults.CHECKPOINT_CONTEXT)
}
}

View File

@ -14,9 +14,9 @@ import net.corda.core.identity.Party
import net.corda.core.internal.FlowIORequest
import net.corda.core.internal.FlowStateMachine
import net.corda.core.node.ServiceHub
import net.corda.core.serialization.SerializationDefaults
import net.corda.core.serialization.deserialize
import net.corda.core.serialization.serialize
import net.corda.core.serialization.internal.CheckpointSerializationDefaults
import net.corda.core.serialization.internal.checkpointDeserialize
import net.corda.core.serialization.internal.checkpointSerialize
import net.corda.testing.core.SerializationEnvironmentRule
import net.corda.testing.node.TestClock
import org.junit.Assert
@ -70,7 +70,7 @@ class FlowStateMachineComparatorTest {
val sm1 = FlowStateMachineImpl<Unit>(StateMachineRunId(UUID.randomUUID()),
scheduler = scheduler,
logic = EmptyFlow, creationTime = clock.millis())
val sm2 = sm1.serialize(context = SerializationDefaults.CHECKPOINT_CONTEXT).deserialize(context = SerializationDefaults.CHECKPOINT_CONTEXT)
val sm2 = sm1.checkpointSerialize(context = CheckpointSerializationDefaults.CHECKPOINT_CONTEXT).checkpointDeserialize(context = CheckpointSerializationDefaults.CHECKPOINT_CONTEXT)
Fiber.unparkDeserialized(sm2, scheduler)
val comparator = FlowStateMachineComparator()

View File

@ -48,7 +48,7 @@ import java.math.BigDecimal
import java.util.*
import java.util.concurrent.CountDownLatch
import java.util.concurrent.Executors
import javax.persistence.*
import javax.persistence.PersistenceException
import kotlin.test.assertEquals
import kotlin.test.assertFalse
import kotlin.test.assertTrue
@ -759,13 +759,13 @@ class NodeVaultServiceTest {
// Test two.
// RelevancyStatus set to NOT_RELEVANT.
val criteriaTwo = VaultQueryCriteria(isRelevant = Vault.RelevancyStatus.NOT_RELEVANT)
val criteriaTwo = VaultQueryCriteria(relevancyStatus = Vault.RelevancyStatus.NOT_RELEVANT)
val resultTwo = vaultService.queryBy<DummyState>(criteriaTwo).states.getNumbers()
assertEquals(setOf(4, 5), resultTwo)
// Test three.
// RelevancyStatus set to ALL.
val criteriaThree = VaultQueryCriteria(isRelevant = Vault.RelevancyStatus.RELEVANT)
val criteriaThree = VaultQueryCriteria(relevancyStatus = Vault.RelevancyStatus.RELEVANT)
val resultThree = vaultService.queryBy<DummyState>(criteriaThree).states.getNumbers()
assertEquals(setOf(1, 3, 6), resultThree)

View File

@ -0,0 +1,246 @@
#!/usr/bin/python
#-------------------------------------------------------------------------------
#
# Usage
# =======
#
# ./jiraReleaseChecker.py <oldTag> <jiraTag> <jiraUser> [-m mode]
# ./jiraReleaseChecker.py release-V3.1 "Corda 3.3" some.user@r3.com [-m not-in-jira]
#
# <oldTag> is the point prior to the current branches head in history from
# which to inspect commits. Normally this will be the tag of the previous
# release. e.g.
#
# master ----------------------------------------------
# \
# release/4 -----------+--------------+------------
# / /
# release/4.0 release/4.1
#
# The current release in the above example will be 4.2 and those commits
# extend from 4.1 having been backported from master. Thus <oldTag> is
# release/4.1
#
# <jiraTag> should refer to the version string used within
# the R3 Corda Jira to track the release. For example, for 3.3 this would be
# "Corda 3.3"
#
# <jiraUser> should be a registered user able to authenticate with the
# R3 Jira system. Authentication and password management is handled through
# the native OS keyring implementation.
#
# The script should be run on the relevant release branch within the git
# repository.
#
# Modes
# -------
#
# The tool can operate in 3 modes
#
# * rst - The default when omitted. Will take the combined lists
# of issues fixed from both Jira and commit summaries and
# format that list in such a way it can be included within
# the release notes for the next release. Will include hyper
# links to the R3 Jira for each ticket.
# * not-in-jira - Print a list of tickets that are included in commit
# summaries but are not tagged in Jira as fixed in the release
# * not-in-commit - Print a list of tickets that are tagged in Jira but that
# are not mentioned in any commit summary,
#
# Pre Requisites
# ================
#
# pip
# pyjira
# gitpython
# keyring (optional)
#
# Installation
# --------------
# Should be a simple matter of ``pip install <package>``
#
# Issues
# ========
#
# Doesn't really handle many errors all that well, also gives no mechanism
# to enter a correct password into the keyring if a wrong one is added which
# isn't great but for now this should do
#
#-------------------------------------------------------------------------------
import re
import sys
import getpass
import argparse
try :
import keyring
except ImportError :
disableKeyring = True
else :
disableKeyring = False
from jira import JIRA
from git import Repo
#-------------------------------------------------------------------------------
R3_JIRA_ADDR = "https://r3-cev.atlassian.net"
JIRA_MAX_RESULTS = 50
#-------------------------------------------------------------------------------
#
# For a given user (provide via the command line) authenticate with Jira and
# return an interface object instance
#
def jiraLogin(user) :
password = keyring.get_password ('jira', user) if not disableKeyring else None
if not password:
password = getpass.getpass("Please enter your JIRA password, " +
"it will be stored in your OS Keyring: ")
if not disableKeyring :
keyring.set_password ('jira', user, password)
return JIRA(R3_JIRA_ADDR, auth=(user, password))
#-------------------------------------------------------------------------------
#
# Cope with Jira REST API paginating query results
#
def jiraQuery (jira, query) :
offset = 0
results = JIRA_MAX_RESULTS
rtn = []
while (results == JIRA_MAX_RESULTS) :
issues = jira.search_issues(query, maxResults=JIRA_MAX_RESULTS, startAt=offset)
results = len(issues)
if results > 0 :
offset += JIRA_MAX_RESULTS
rtn += issues
return rtn
#-------------------------------------------------------------------------------
#
# Take a Jira issue and format it in such a way we can include it as a line
# item in the release notes formatted with a hyperlink to the issue in Jira
#
def issueToRST(issue) :
return "* %s [`%s <%s/browse/%s>`_]" % (
issue.fields.summary,
issue.key,
R3_JIRA_ADDR,
issue.key)
#-------------------------------------------------------------------------------
#
# Get a list of jiras from Jira where those jiras are marked as fixed
# in some specific version (set on the command line).
#
# Optionally, an already authenticated Jira connection instance can be
# provided to avoid re-authenticating. The authenticated object
# is returned for reuse.
#
def getJirasFromJira(args_, jira_ = None) :
jira = jiraLogin(args_.jiraUser) if jira_ == None else jira_
return jiraQuery(jira, \
'project in (Corda, Ent) And fixVersion in ("%s") and status in (Done)' % (args_.jiraTag)) \
, jira
#-------------------------------------------------------------------------------
def getJiraIdsFromJira(args_, jira_ = None) :
jira = jiraLogin(args_.jiraUser) if jira_ == None else jira_
jirasFromJira, _ = jiraQuery(jira, \
'project in (Corda, Ent) And fixVersion in ("%s") and status in (Done)' % (args_.jiraTag)) \
, jira
return [ j.key for j in jirasFromJira ], jira
#-------------------------------------------------------------------------------
def getJiraIdsFromCommits(args_) :
jiraMatch = re.compile("(CORDA-\d+|ENT-\d+)")
repo = Repo(".", search_parent_directories = True)
jirasFromCommits = []
for commit in list (repo.iter_commits ("%s.." % (args_.oldTag))) :
jirasFromCommits += jiraMatch.findall(commit.summary)
return jirasFromCommits
#-------------------------------------------------------------------------------
#
# Take the set of all tickets completed in a release (the union of those
# tagged in Jira and those marked in commit summaries) and format them
# for inclusion in the release notes (rst format).
#
def rst (args_) :
jiraIdsFromCommits = getJiraIdsFromCommits(args_)
jirasFromJira, jiraObj = getJirasFromJira(args_)
jiraIdsFromJira = [ jira.key for jira in jirasFromJira ]
#
# Grab the set of JIRA's that aren't tagged as fixed in the release but are
# mentioned in a commit and pull down the JIRA information for those so as
# to get access to their summary
#
extraJiras = set(jiraIdsFromCommits).difference(jiraIdsFromJira)
jirasFromJira += jiraQuery(jiraObj, "key in (%s)" % (", ".join(extraJiras)))
for jira in jirasFromJira :
print issueToRST(jira)
#-------------------------------------------------------------------------------
def notInJira(args_) :
jiraIdsFromCommits = getJiraIdsFromCommits(args_)
jiraIdsFromJira, _ = getJiraIdsFromJira(args_)
print 'Issues mentioned in commits but not set as "fixed in" in Jira'
for jiraId in set(jiraIdsFromJira).difference(jiraIdsFromCommits) :
print jiraId
#-------------------------------------------------------------------------------
def notInCommit(args_) :
jiraIdsFromCommits = getJiraIdsFromCommits(args_)
jiraIdsFromJira, _ = getJiraIdsFromJira(args_)
print 'Issues tagged in Jira as fixed but not mentioned in any commit summary'
for jiraId in set(jiraIdsFromCommits).difference(jiraIdsFromJira) :
print jiraId
#-------------------------------------------------------------------------------
if __name__ == "__main__" :
parser = argparse.ArgumentParser()
parser.add_argument("-m", "--mode", help="display a square of a given number",
choices = [ "rst", "not-in-jira", "not-in-commit"])
parser.add_argument("oldTag", help="The previous release tag")
parser.add_argument("jiraTag", help="The current Jira release")
parser.add_argument("jiraUser", help="Who to authenticate with Jira as")
args = parser.parse_args()
if not args.mode : args.mode = "rst"
if args.mode == "rst" : rst(args)
elif args.mode == "not-in-jira" : notInJira(args)
elif args.mode == "not-in-commit" : notInCommit(args)
#-------------------------------------------------------------------------------

View File

@ -0,0 +1,49 @@
package net.corda.serialization.internal
import net.corda.core.KeepForDJVM
import net.corda.core.crypto.SecureHash
import net.corda.core.serialization.*
import net.corda.core.serialization.internal.CheckpointSerializationContext
@KeepForDJVM
data class CheckpointSerializationContextImpl @JvmOverloads constructor(
override val deserializationClassLoader: ClassLoader,
override val whitelist: ClassWhitelist,
override val properties: Map<Any, Any>,
override val objectReferencesEnabled: Boolean,
override val encoding: SerializationEncoding?,
override val encodingWhitelist: EncodingWhitelist = NullEncodingWhitelist) : CheckpointSerializationContext {
private val builder = AttachmentsClassLoaderBuilder(properties, deserializationClassLoader)
/**
* {@inheritDoc}
*
* We need to cache the AttachmentClassLoaders to avoid too many contexts, since the class loader is part of cache key for the context.
*/
override fun withAttachmentsClassLoader(attachmentHashes: List<SecureHash>): CheckpointSerializationContext {
properties[attachmentsClassLoaderEnabledPropertyName] as? Boolean == true || return this
val classLoader = builder.build(attachmentHashes) ?: return this
return withClassLoader(classLoader)
}
override fun withProperty(property: Any, value: Any): CheckpointSerializationContext {
return copy(properties = properties + (property to value))
}
override fun withoutReferences(): CheckpointSerializationContext {
return copy(objectReferencesEnabled = false)
}
override fun withClassLoader(classLoader: ClassLoader): CheckpointSerializationContext {
return copy(deserializationClassLoader = classLoader)
}
override fun withWhitelisted(clazz: Class<*>): CheckpointSerializationContext {
return copy(whitelist = object : ClassWhitelist {
override fun hasListed(type: Class<*>): Boolean = whitelist.hasListed(type) || type.name == clazz.name
})
}
override fun withEncoding(encoding: SerializationEncoding?) = copy(encoding = encoding)
override fun withEncodingWhitelist(encodingWhitelist: EncodingWhitelist) = copy(encodingWhitelist = encodingWhitelist)
}

View File

@ -3,14 +3,14 @@ package net.corda.serialization.internal
import net.corda.core.DeleteForDJVM
import net.corda.core.node.ServiceHub
import net.corda.core.serialization.SerializationContext
import net.corda.core.serialization.SerializationFactory
import net.corda.core.serialization.SerializeAsToken
import net.corda.core.serialization.SerializeAsTokenContext
import net.corda.core.serialization.*
import net.corda.core.serialization.internal.CheckpointSerializationContext
import net.corda.core.serialization.internal.CheckpointSerializationFactory
val serializationContextKey = SerializeAsTokenContext::class.java
fun SerializationContext.withTokenContext(serializationContext: SerializeAsTokenContext): SerializationContext = this.withProperty(serializationContextKey, serializationContext)
fun CheckpointSerializationContext.withTokenContext(serializationContext: SerializeAsTokenContext): CheckpointSerializationContext = this.withProperty(serializationContextKey, serializationContext)
/**
* A context for mapping SerializationTokens to/from SerializeAsTokens.
@ -55,6 +55,53 @@ class SerializeAsTokenContextImpl(override val serviceHub: ServiceHub, init: Ser
}
}
override fun getSingleton(className: String) = classNameToSingleton[className]
?: throw IllegalStateException("Unable to find tokenized instance of $className in context $this")
}
/**
* A context for mapping SerializationTokens to/from SerializeAsTokens.
*
* A context is initialised with an object containing all the instances of [SerializeAsToken] to eagerly register all the tokens.
* In our case this can be the [ServiceHub].
*
* Then it is a case of using the companion object methods on [SerializeAsTokenSerializer] to set and clear context as necessary
* when serializing to enable/disable tokenization.
*/
@DeleteForDJVM
class CheckpointSerializeAsTokenContextImpl(override val serviceHub: ServiceHub, init: SerializeAsTokenContext.() -> Unit) : SerializeAsTokenContext {
constructor(toBeTokenized: Any, serializationFactory: CheckpointSerializationFactory, context: CheckpointSerializationContext, serviceHub: ServiceHub) : this(serviceHub, {
serializationFactory.serialize(toBeTokenized, context.withTokenContext(this))
})
private val classNameToSingleton = mutableMapOf<String, SerializeAsToken>()
private var readOnly = false
init {
/**
* Go ahead and eagerly serialize the object to register all of the tokens in the context.
*
* This results in the toToken() method getting called for any [SingletonSerializeAsToken] instances which
* are encountered in the object graph as they are serialized and will therefore register the token to
* object mapping for those instances. We then immediately set the readOnly flag to stop further adhoc or
* accidental registrations from occuring as these could not be deserialized in a deserialization-first
* scenario if they are not part of this iniital context construction serialization.
*/
init(this)
readOnly = true
}
override fun putSingleton(toBeTokenized: SerializeAsToken) {
val className = toBeTokenized.javaClass.name
if (className !in classNameToSingleton) {
// Only allowable if we are in SerializeAsTokenContext init (readOnly == false)
if (readOnly) {
throw UnsupportedOperationException("Attempt to write token for lazy registered $className. All tokens should be registered during context construction.")
}
classNameToSingleton[className] = toBeTokenized
}
}
override fun getSingleton(className: String) = classNameToSingleton[className]
?: throw IllegalStateException("Unable to find tokenized instance of $className in context $this")
}

View File

@ -13,3 +13,11 @@ fun checkUseCase(allowedUseCases: EnumSet<SerializationContext.UseCase>) {
throw IllegalStateException("UseCase '${currentContext.useCase}' is not within '$allowedUseCases'")
}
}
fun checkUseCase(allowedUseCase: SerializationContext.UseCase) {
val currentContext: SerializationContext = SerializationFactory.currentFactory?.currentContext
?: throw IllegalStateException("Current context is not set")
if (allowedUseCase != currentContext.useCase) {
throw IllegalStateException("UseCase '${currentContext.useCase}' is not '$allowedUseCase'")
}
}

View File

@ -163,8 +163,6 @@ abstract class AbstractAMQPSerializationScheme(
return synchronized(serializerFactoriesForContexts) {
serializerFactoriesForContexts.computeIfAbsent(Pair(context.whitelist, context.deserializationClassLoader)) {
when (context.useCase) {
SerializationContext.UseCase.Checkpoint ->
throw IllegalStateException("AMQP should not be used for checkpoint serialization.")
SerializationContext.UseCase.RPCClient ->
rpcClientSerializerFactory(context)
SerializationContext.UseCase.RPCServer ->

View File

@ -2,7 +2,6 @@ package net.corda.serialization.internal.amqp.custom
import net.corda.core.crypto.Crypto
import net.corda.core.serialization.SerializationContext
import net.corda.core.serialization.SerializationContext.UseCase.Checkpoint
import net.corda.core.serialization.SerializationContext.UseCase.Storage
import net.corda.serialization.internal.amqp.*
import net.corda.serialization.internal.checkUseCase
@ -13,14 +12,12 @@ import java.util.*
object PrivateKeySerializer : CustomSerializer.Implements<PrivateKey>(PrivateKey::class.java) {
private val allowedUseCases = EnumSet.of(Storage, Checkpoint)
override val schemaForDocumentation = Schema(listOf(RestrictedType(type.toString(), "", listOf(type.toString()), SerializerFactory.primitiveTypeName(ByteArray::class.java)!!, descriptor, emptyList())))
override fun writeDescribedObject(obj: PrivateKey, data: Data, type: Type, output: SerializationOutput,
context: SerializationContext
) {
checkUseCase(allowedUseCases)
checkUseCase(Storage)
output.writeObject(obj.encoded, data, clazz, context)
}

View File

@ -4,7 +4,6 @@ import com.google.common.collect.Maps;
import net.corda.core.serialization.SerializationContext;
import net.corda.core.serialization.SerializationFactory;
import net.corda.core.serialization.SerializedBytes;
import net.corda.serialization.internal.amqp.AMQPNotSerializableException;
import net.corda.serialization.internal.amqp.SchemaKt;
import net.corda.testing.core.SerializationEnvironmentRule;
import org.junit.Before;
@ -20,8 +19,10 @@ import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.ThrowableAssert.catchThrowable;
public final class ForbiddenLambdaSerializationTests {
private EnumSet<SerializationContext.UseCase> contexts = EnumSet.complementOf(
EnumSet.of(SerializationContext.UseCase.Checkpoint, SerializationContext.UseCase.Testing));
EnumSet.of(SerializationContext.UseCase.Testing));
@Rule
public final SerializationEnvironmentRule testSerialization = new SerializationEnvironmentRule();
private SerializationFactory factory;

View File

@ -1,11 +1,11 @@
package net.corda.serialization.internal;
import net.corda.core.serialization.SerializationContext;
import net.corda.core.serialization.SerializationFactory;
import net.corda.core.serialization.SerializedBytes;
import net.corda.core.serialization.*;
import net.corda.core.serialization.internal.CheckpointSerializationContext;
import net.corda.core.serialization.internal.CheckpointSerializationFactory;
import net.corda.node.serialization.kryo.CordaClosureSerializer;
import net.corda.node.serialization.kryo.KryoSerializationSchemeKt;
import net.corda.testing.core.SerializationEnvironmentRule;
import net.corda.testing.core.internal.CheckpointSerializationEnvironmentRule;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
@ -18,21 +18,22 @@ import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.ThrowableAssert.catchThrowable;
public final class LambdaCheckpointSerializationTest {
@Rule
public final SerializationEnvironmentRule testSerialization = new SerializationEnvironmentRule();
private SerializationFactory factory;
private SerializationContext context;
public final CheckpointSerializationEnvironmentRule testCheckpointSerialization =
new CheckpointSerializationEnvironmentRule();
private CheckpointSerializationFactory factory;
private CheckpointSerializationContext context;
@Before
public void setup() {
factory = testSerialization.getSerializationFactory();
context = new SerializationContextImpl(
KryoSerializationSchemeKt.getKryoMagic(),
factory = testCheckpointSerialization.getCheckpointSerializationFactory();
context = new CheckpointSerializationContextImpl(
getClass().getClassLoader(),
AllWhitelist.INSTANCE,
Collections.emptyMap(),
true,
SerializationContext.UseCase.Checkpoint,
null
);
}

View File

@ -3,8 +3,13 @@ package net.corda.serialization.internal
import net.corda.core.contracts.ContractAttachment
import net.corda.core.identity.CordaX500Name
import net.corda.core.serialization.*
import net.corda.core.serialization.internal.CheckpointSerializationContext
import net.corda.core.serialization.internal.CheckpointSerializationFactory
import net.corda.core.serialization.internal.checkpointDeserialize
import net.corda.core.serialization.internal.checkpointSerialize
import net.corda.testing.contracts.DummyContract
import net.corda.testing.core.SerializationEnvironmentRule
import net.corda.testing.core.internal.CheckpointSerializationEnvironmentRule
import net.corda.testing.internal.rigorousMock
import net.corda.testing.node.MockServices
import org.apache.commons.lang.ArrayUtils.EMPTY_BYTE_ARRAY
@ -17,28 +22,29 @@ import org.junit.Test
import kotlin.test.assertEquals
class ContractAttachmentSerializerTest {
@Rule
@JvmField
val testSerialization = SerializationEnvironmentRule()
val testCheckpointSerialization = CheckpointSerializationEnvironmentRule()
private lateinit var factory: SerializationFactory
private lateinit var context: SerializationContext
private lateinit var contextWithToken: SerializationContext
private lateinit var factory: CheckpointSerializationFactory
private lateinit var context: CheckpointSerializationContext
private lateinit var contextWithToken: CheckpointSerializationContext
private val mockServices = MockServices(emptyList(), CordaX500Name("MegaCorp", "London", "GB"), rigorousMock())
@Before
fun setup() {
factory = testSerialization.serializationFactory
context = testSerialization.checkpointContext
contextWithToken = context.withTokenContext(SerializeAsTokenContextImpl(Any(), factory, context, mockServices))
factory = testCheckpointSerialization.checkpointSerializationFactory
context = testCheckpointSerialization.checkpointSerializationContext
contextWithToken = context.withTokenContext(CheckpointSerializeAsTokenContextImpl(Any(), factory, context, mockServices))
}
@Test
fun `write contract attachment and read it back`() {
val contractAttachment = ContractAttachment(GeneratedAttachment(EMPTY_BYTE_ARRAY), DummyContract.PROGRAM_ID)
// no token context so will serialize the whole attachment
val serialized = contractAttachment.serialize(factory, context)
val deserialized = serialized.deserialize(factory, context)
val serialized = contractAttachment.checkpointSerialize(factory, context)
val deserialized = serialized.checkpointDeserialize(factory, context)
assertEquals(contractAttachment.id, deserialized.attachment.id)
assertEquals(contractAttachment.contract, deserialized.contract)
@ -53,8 +59,8 @@ class ContractAttachmentSerializerTest {
mockServices.attachments.importAttachment(attachment.open(), "test", null)
val contractAttachment = ContractAttachment(attachment, DummyContract.PROGRAM_ID)
val serialized = contractAttachment.serialize(factory, contextWithToken)
val deserialized = serialized.deserialize(factory, contextWithToken)
val serialized = contractAttachment.checkpointSerialize(factory, contextWithToken)
val deserialized = serialized.checkpointDeserialize(factory, contextWithToken)
assertEquals(contractAttachment.id, deserialized.attachment.id)
assertEquals(contractAttachment.contract, deserialized.contract)
@ -70,7 +76,7 @@ class ContractAttachmentSerializerTest {
mockServices.attachments.importAttachment(attachment.open(), "test", null)
val contractAttachment = ContractAttachment(attachment, DummyContract.PROGRAM_ID)
val serialized = contractAttachment.serialize(factory, contextWithToken)
val serialized = contractAttachment.checkpointSerialize(factory, contextWithToken)
assertThat(serialized.size).isLessThan(largeAttachmentSize)
}
@ -82,8 +88,8 @@ class ContractAttachmentSerializerTest {
// don't importAttachment in mockService
val contractAttachment = ContractAttachment(attachment, DummyContract.PROGRAM_ID)
val serialized = contractAttachment.serialize(factory, contextWithToken)
val deserialized = serialized.deserialize(factory, contextWithToken)
val serialized = contractAttachment.checkpointSerialize(factory, contextWithToken)
val deserialized = serialized.checkpointDeserialize(factory, contextWithToken)
assertThatThrownBy { deserialized.attachment.open() }.isInstanceOf(MissingAttachmentsException::class.java)
}
@ -94,8 +100,8 @@ class ContractAttachmentSerializerTest {
// don't importAttachment in mockService
val contractAttachment = ContractAttachment(attachment, DummyContract.PROGRAM_ID)
val serialized = contractAttachment.serialize(factory, contextWithToken)
serialized.deserialize(factory, contextWithToken)
val serialized = contractAttachment.checkpointSerialize(factory, contextWithToken)
serialized.checkpointDeserialize(factory, contextWithToken)
// MissingAttachmentsException thrown if we try to open attachment
}

View File

@ -11,12 +11,11 @@ import com.nhaarman.mockito_kotlin.verify
import com.nhaarman.mockito_kotlin.whenever
import net.corda.core.internal.DEPLOYED_CORDAPP_UPLOADER
import net.corda.core.node.services.AttachmentStorage
import net.corda.core.serialization.internal.CheckpointSerializationContext
import net.corda.core.serialization.ClassWhitelist
import net.corda.core.serialization.CordaSerializable
import net.corda.core.serialization.SerializationContext
import net.corda.node.serialization.kryo.CordaClassResolver
import net.corda.node.serialization.kryo.CordaKryo
import net.corda.node.serialization.kryo.kryoMagic
import net.corda.testing.internal.rigorousMock
import net.corda.testing.services.MockAttachmentStorage
import org.junit.Rule
@ -115,8 +114,8 @@ class CordaClassResolverTests {
val emptyMapClass = mapOf<Any, Any>().javaClass
}
private val emptyWhitelistContext: SerializationContext = SerializationContextImpl(kryoMagic, this.javaClass.classLoader, EmptyWhitelist, emptyMap(), true, SerializationContext.UseCase.P2P, null)
private val allButBlacklistedContext: SerializationContext = SerializationContextImpl(kryoMagic, this.javaClass.classLoader, AllButBlacklisted, emptyMap(), true, SerializationContext.UseCase.P2P, null)
private val emptyWhitelistContext: CheckpointSerializationContext = CheckpointSerializationContextImpl(this.javaClass.classLoader, EmptyWhitelist, emptyMap(), true, null)
private val allButBlacklistedContext: CheckpointSerializationContext = CheckpointSerializationContextImpl(this.javaClass.classLoader, AllButBlacklisted, emptyMap(), true, null)
@Test
fun `Annotation on enum works for specialised entries`() {
CordaClassResolver(emptyWhitelistContext).getRegistration(Foo.Bar::class.java)

View File

@ -3,6 +3,8 @@ package net.corda.serialization.internal
import net.corda.core.crypto.Crypto
import net.corda.core.serialization.SerializationContext.UseCase.*
import net.corda.core.serialization.SerializationDefaults
import net.corda.core.serialization.internal.CheckpointSerializationDefaults
import net.corda.core.serialization.internal.checkpointSerialize
import net.corda.core.serialization.serialize
import net.corda.testing.core.SerializationEnvironmentRule
import org.assertj.core.api.Assertions.assertThatThrownBy
@ -33,13 +35,13 @@ class PrivateKeySerializationTest(private val privateKey: PrivateKey, private va
@Test
fun `passed with expected UseCases`() {
assertTrue { privateKey.serialize(context = SerializationDefaults.STORAGE_CONTEXT).bytes.isNotEmpty() }
assertTrue { privateKey.serialize(context = SerializationDefaults.CHECKPOINT_CONTEXT).bytes.isNotEmpty() }
assertTrue { privateKey.checkpointSerialize(context = CheckpointSerializationDefaults.CHECKPOINT_CONTEXT).bytes.isNotEmpty() }
}
@Test
fun `failed with wrong UseCase`() {
assertThatThrownBy { privateKey.serialize(context = SerializationDefaults.P2P_CONTEXT) }
.isInstanceOf(IllegalStateException::class.java)
.hasMessageContaining("UseCase '$P2P' is not within")
.hasMessageContaining("UseCase '$P2P' is not 'Storage")
}
}

View File

@ -4,6 +4,10 @@ import com.esotericsoftware.kryo.Kryo
import com.esotericsoftware.kryo.KryoException
import com.esotericsoftware.kryo.io.Output
import net.corda.core.serialization.*
import net.corda.core.serialization.internal.CheckpointSerializationContext
import net.corda.core.serialization.internal.CheckpointSerializationFactory
import net.corda.core.serialization.internal.checkpointDeserialize
import net.corda.core.serialization.internal.checkpointSerialize
import net.corda.core.utilities.OpaqueBytes
import net.corda.node.serialization.kryo.CordaClassResolver
import net.corda.node.serialization.kryo.CordaKryo
@ -11,6 +15,7 @@ import net.corda.node.serialization.kryo.DefaultKryoCustomizer
import net.corda.node.serialization.kryo.kryoMagic
import net.corda.testing.internal.rigorousMock
import net.corda.testing.core.SerializationEnvironmentRule
import net.corda.testing.core.internal.CheckpointSerializationEnvironmentRule
import org.assertj.core.api.Assertions.assertThat
import org.junit.Before
import org.junit.Rule
@ -18,16 +23,18 @@ import org.junit.Test
import java.io.ByteArrayOutputStream
class SerializationTokenTest {
@Rule
@JvmField
val testSerialization = SerializationEnvironmentRule()
private lateinit var factory: SerializationFactory
private lateinit var context: SerializationContext
val testCheckpointSerialization = CheckpointSerializationEnvironmentRule()
private lateinit var factory: CheckpointSerializationFactory
private lateinit var context: CheckpointSerializationContext
@Before
fun setup() {
factory = testSerialization.serializationFactory
context = testSerialization.checkpointContext.withWhitelisted(SingletonSerializationToken::class.java)
factory = testCheckpointSerialization.checkpointSerializationFactory
context = testCheckpointSerialization.checkpointSerializationContext.withWhitelisted(SingletonSerializationToken::class.java)
}
// Large tokenizable object so we can tell from the smaller number of serialized bytes it was actually tokenized
@ -42,16 +49,16 @@ class SerializationTokenTest {
override fun equals(other: Any?) = other is LargeTokenizable && other.bytes.size == this.bytes.size
}
private fun serializeAsTokenContext(toBeTokenized: Any) = SerializeAsTokenContextImpl(toBeTokenized, factory, context, rigorousMock())
private fun serializeAsTokenContext(toBeTokenized: Any) = CheckpointSerializeAsTokenContextImpl(toBeTokenized, factory, context, rigorousMock())
@Test
fun `write token and read tokenizable`() {
val tokenizableBefore = LargeTokenizable()
val context = serializeAsTokenContext(tokenizableBefore)
val testContext = this.context.withTokenContext(context)
val serializedBytes = tokenizableBefore.serialize(factory, testContext)
val serializedBytes = tokenizableBefore.checkpointSerialize(factory, testContext)
assertThat(serializedBytes.size).isLessThan(tokenizableBefore.numBytes)
val tokenizableAfter = serializedBytes.deserialize(factory, testContext)
val tokenizableAfter = serializedBytes.checkpointDeserialize(factory, testContext)
assertThat(tokenizableAfter).isSameAs(tokenizableBefore)
}
@ -62,8 +69,8 @@ class SerializationTokenTest {
val tokenizableBefore = UnitSerializeAsToken()
val context = serializeAsTokenContext(tokenizableBefore)
val testContext = this.context.withTokenContext(context)
val serializedBytes = tokenizableBefore.serialize(factory, testContext)
val tokenizableAfter = serializedBytes.deserialize(factory, testContext)
val serializedBytes = tokenizableBefore.checkpointSerialize(factory, testContext)
val tokenizableAfter = serializedBytes.checkpointDeserialize(factory, testContext)
assertThat(tokenizableAfter).isSameAs(tokenizableBefore)
}
@ -72,7 +79,7 @@ class SerializationTokenTest {
val tokenizableBefore = UnitSerializeAsToken()
val context = serializeAsTokenContext(emptyList<Any>())
val testContext = this.context.withTokenContext(context)
tokenizableBefore.serialize(factory, testContext)
tokenizableBefore.checkpointSerialize(factory, testContext)
}
@Test(expected = UnsupportedOperationException::class)
@ -80,14 +87,14 @@ class SerializationTokenTest {
val tokenizableBefore = UnitSerializeAsToken()
val context = serializeAsTokenContext(emptyList<Any>())
val testContext = this.context.withTokenContext(context)
val serializedBytes = tokenizableBefore.toToken(serializeAsTokenContext(emptyList<Any>())).serialize(factory, testContext)
serializedBytes.deserialize(factory, testContext)
val serializedBytes = tokenizableBefore.toToken(serializeAsTokenContext(emptyList<Any>())).checkpointSerialize(factory, testContext)
serializedBytes.checkpointDeserialize(factory, testContext)
}
@Test(expected = KryoException::class)
fun `no context set`() {
val tokenizableBefore = UnitSerializeAsToken()
tokenizableBefore.serialize(factory, context)
tokenizableBefore.checkpointSerialize(factory, context)
}
@Test(expected = KryoException::class)
@ -105,7 +112,7 @@ class SerializationTokenTest {
kryo.writeObject(it, emptyList<Any>())
}
val serializedBytes = SerializedBytes<Any>(stream.toByteArray())
serializedBytes.deserialize(factory, testContext)
serializedBytes.checkpointDeserialize(factory, testContext)
}
private class WrongTypeSerializeAsToken : SerializeAsToken {
@ -121,7 +128,7 @@ class SerializationTokenTest {
val tokenizableBefore = WrongTypeSerializeAsToken()
val context = serializeAsTokenContext(tokenizableBefore)
val testContext = this.context.withTokenContext(context)
val serializedBytes = tokenizableBefore.serialize(factory, testContext)
serializedBytes.deserialize(factory, testContext)
val serializedBytes = tokenizableBefore.checkpointSerialize(factory, testContext)
serializedBytes.checkpointDeserialize(factory, testContext)
}
}

View File

@ -51,8 +51,10 @@ import net.corda.testing.node.User
import net.corda.testing.node.internal.DriverDSLImpl.Companion.cordappsInCurrentAndAdditionalPackages
import okhttp3.OkHttpClient
import okhttp3.Request
import rx.Observable
import rx.Subscription
import rx.schedulers.Schedulers
import rx.subjects.AsyncSubject
import java.lang.management.ManagementFactory
import java.net.ConnectException
import java.net.URL
@ -763,9 +765,11 @@ class DriverDSLImpl(
val systemProperties = mutableMapOf(
"name" to config.corda.myLegalName,
"visualvm.display.name" to "corda-${config.corda.myLegalName}",
"log4j2.debug" to if (debugPort != null) "true" else "false"
"visualvm.display.name" to "corda-${config.corda.myLegalName}"
)
debugPort?.let {
systemProperties += "log4j2.debug" to "true"
}
systemProperties += inheritFromParentProcess()
systemProperties += overriddenSystemProperties

View File

@ -3,6 +3,7 @@ package net.corda.testing.node.internal
import com.typesafe.config.Config
import com.typesafe.config.ConfigFactory
import com.typesafe.config.ConfigParseOptions
import net.corda.client.rpc.ConnectionFailureException
import net.corda.client.rpc.CordaRPCClient
import net.corda.core.CordaException
import net.corda.core.concurrent.CordaFuture
@ -12,6 +13,7 @@ import net.corda.core.flows.FlowLogic
import net.corda.core.internal.FlowStateMachine
import net.corda.core.internal.concurrent.openFuture
import net.corda.core.internal.times
import net.corda.core.messaging.CordaRPCOps
import net.corda.core.utilities.NetworkHostAndPort
import net.corda.core.utilities.getOrThrow
import net.corda.core.utilities.millis
@ -24,11 +26,14 @@ import net.corda.node.services.messaging.Message
import net.corda.nodeapi.internal.persistence.DatabaseConfig
import net.corda.nodeapi.internal.persistence.TransactionIsolationLevel
import net.corda.testing.database.DatabaseConstants
import net.corda.testing.driver.NodeHandle
import net.corda.testing.internal.chooseIdentity
import net.corda.testing.node.InMemoryMessagingNetwork
import net.corda.testing.node.User
import net.corda.testing.node.testContext
import org.slf4j.LoggerFactory
import rx.Observable
import rx.subjects.AsyncSubject
import java.net.Socket
import java.net.SocketException
import java.time.Duration
@ -214,3 +219,21 @@ fun inMemoryH2DataSourceConfig(providedNodeName: String? = null, postfix: String
}
fun CordaRPCClient.start(user: User) = start(user.username, user.password)
fun NodeHandle.waitForShutdown(): Observable<Unit> {
return rpc.waitForShutdown().doAfterTerminate(::stop)
}
fun CordaRPCOps.waitForShutdown(): Observable<Unit> {
val completable = AsyncSubject.create<Unit>()
stateMachinesFeed().updates.subscribe({ _ -> }, { error ->
if (error is ConnectionFailureException) {
completable.onCompleted()
} else {
completable.onError(error)
}
})
return completable
}

View File

@ -241,6 +241,14 @@ class CordaRPCProxyClient(private val targetHostAndPort: NetworkHostAndPort) : C
TODO("not implemented")
}
override fun isWaitingForShutdown(): Boolean {
TODO("not implemented")
}
override fun terminate(drainPendingFlows: Boolean) {
TODO("not implemented")
}
private inline fun <reified T : Any> doPost(hostAndPort: NetworkHostAndPort, path: String, payload: ByteArray) : T {
val url = URL("http://$hostAndPort/rpc/$path")
val connection = url.openHttpConnection().apply {

View File

@ -5,6 +5,7 @@ import com.nhaarman.mockito_kotlin.doAnswer
import com.nhaarman.mockito_kotlin.whenever
import net.corda.core.DoNotImplement
import net.corda.core.internal.staticField
import net.corda.core.serialization.SerializationContext
import net.corda.core.serialization.internal.SerializationEnvironment
import net.corda.core.serialization.internal.effectiveSerializationEnv
import net.corda.testing.common.internal.asContextEnv
@ -45,7 +46,6 @@ class SerializationEnvironmentRule(private val inheritable: Boolean = false) : T
private lateinit var env: SerializationEnvironment
val serializationFactory get() = env.serializationFactory
val checkpointContext get() = env.checkpointContext
override fun apply(base: Statement, description: Description): Statement {
init(description.toString())

View File

@ -0,0 +1,71 @@
package net.corda.testing.core.internal
import com.nhaarman.mockito_kotlin.any
import com.nhaarman.mockito_kotlin.doAnswer
import com.nhaarman.mockito_kotlin.whenever
import net.corda.core.internal.staticField
import net.corda.core.serialization.internal.SerializationEnvironment
import net.corda.core.serialization.internal.effectiveSerializationEnv
import net.corda.testing.common.internal.asContextEnv
import net.corda.testing.core.SerializationEnvironmentRule
import net.corda.testing.internal.createTestSerializationEnv
import net.corda.testing.internal.inVMExecutors
import net.corda.testing.internal.rigorousMock
import net.corda.testing.internal.testThreadFactory
import org.apache.activemq.artemis.core.remoting.impl.invm.InVMConnector
import org.junit.rules.TestRule
import org.junit.runner.Description
import org.junit.runners.model.Statement
import java.util.concurrent.ExecutorService
import java.util.concurrent.Executors
/**
* A test checkpoint serialization rule implementation for use in tests.
*
* @param inheritable whether new threads inherit the environment, use sparingly.
*/
class CheckpointSerializationEnvironmentRule(private val inheritable: Boolean = false) : TestRule {
companion object {
init {
// Can't turn it off, and it creates threads that do serialization, so hack it:
InVMConnector::class.staticField<ExecutorService>("threadPoolExecutor").value = rigorousMock<ExecutorService>().also {
doAnswer {
inVMExecutors.computeIfAbsent(effectiveSerializationEnv) {
Executors.newCachedThreadPool(testThreadFactory(true)) // Close enough to what InVMConnector makes normally.
}.execute(it.arguments[0] as Runnable)
}.whenever(it).execute(any())
}
}
/** Do not call, instead use [SerializationEnvironmentRule] as a [org.junit.Rule]. */
fun <T> run(taskLabel: String, task: (SerializationEnvironment) -> T): T {
return CheckpointSerializationEnvironmentRule().apply { init(taskLabel) }.runTask(task)
}
}
private lateinit var env: SerializationEnvironment
override fun apply(base: Statement, description: Description): Statement {
init(description.toString())
return object : Statement() {
override fun evaluate() = runTask { base.evaluate() }
}
}
private fun init(envLabel: String) {
env = createTestSerializationEnv(envLabel)
}
private fun <T> runTask(task: (SerializationEnvironment) -> T): T {
try {
return env.asContextEnv(inheritable, task)
} finally {
inVMExecutors.remove(env)
}
}
val checkpointSerializationFactory get() = env.checkpointSerializationFactory
val checkpointSerializationContext get() = env.checkpointContext
}

View File

@ -4,10 +4,11 @@ import com.nhaarman.mockito_kotlin.doNothing
import com.nhaarman.mockito_kotlin.whenever
import net.corda.client.rpc.internal.serialization.amqp.AMQPClientSerializationScheme
import net.corda.core.DoNotImplement
import net.corda.core.serialization.internal.CheckpointSerializationFactory
import net.corda.core.serialization.internal.*
import net.corda.node.serialization.amqp.AMQPServerSerializationScheme
import net.corda.node.serialization.kryo.KRYO_CHECKPOINT_CONTEXT
import net.corda.node.serialization.kryo.KryoServerSerializationScheme
import net.corda.node.serialization.kryo.KryoSerializationScheme
import net.corda.serialization.internal.*
import net.corda.testing.core.SerializationEnvironmentRule
import java.util.concurrent.ConcurrentHashMap
@ -33,8 +34,6 @@ internal fun createTestSerializationEnv(label: String): SerializationEnvironment
val factory = SerializationFactoryImpl().apply {
registerScheme(AMQPClientSerializationScheme(emptyList()))
registerScheme(AMQPServerSerializationScheme(emptyList()))
// needed for checkpointing
registerScheme(KryoServerSerializationScheme())
}
return object : SerializationEnvironmentImpl(
factory,
@ -42,7 +41,8 @@ internal fun createTestSerializationEnv(label: String): SerializationEnvironment
AMQP_RPC_SERVER_CONTEXT,
AMQP_RPC_CLIENT_CONTEXT,
AMQP_STORAGE_CONTEXT,
KRYO_CHECKPOINT_CONTEXT
KRYO_CHECKPOINT_CONTEXT,
CheckpointSerializationFactory(KryoSerializationScheme)
) {
override fun toString() = "testSerializationEnv($label)"
}

View File

@ -17,10 +17,7 @@ import net.corda.core.serialization.internal._contextSerializationEnv
import net.corda.core.utilities.base64ToByteArray
import net.corda.core.utilities.hexToByteArray
import net.corda.core.utilities.sequence
import net.corda.serialization.internal.AMQP_P2P_CONTEXT
import net.corda.serialization.internal.AMQP_STORAGE_CONTEXT
import net.corda.serialization.internal.CordaSerializationMagic
import net.corda.serialization.internal.SerializationFactoryImpl
import net.corda.serialization.internal.*
import net.corda.serialization.internal.amqp.AbstractAMQPSerializationScheme
import net.corda.serialization.internal.amqp.DeserializationInput
import net.corda.serialization.internal.amqp.amqpMagic

View File

@ -11,6 +11,7 @@ dependencies {
compile "info.picocli:picocli:$picocli_version"
compile "com.jcabi:jcabi-manifests:$jcabi_manifests_version"
compile "org.slf4j:slf4j-api:$slf4j_version"
compile "org.apache.logging.log4j:log4j-slf4j-impl:$log4j_version"
// JAnsi: for drawing things to the terminal in nicely coloured ways.
compile "org.fusesource.jansi:jansi:$jansi_version"

View File

@ -0,0 +1,35 @@
package net.corda.cliutils
import org.apache.logging.log4j.core.LoggerContext
import org.apache.logging.log4j.core.config.Configuration
import org.apache.logging.log4j.core.config.ConfigurationFactory
import org.apache.logging.log4j.core.config.ConfigurationSource
import org.apache.logging.log4j.core.config.Order
import org.apache.logging.log4j.core.config.plugins.Plugin
import org.apache.logging.log4j.core.config.xml.XmlConfiguration
import org.apache.logging.log4j.core.impl.LogEventFactory
@Plugin(name = "CordaLog4j2ConfigFactory", category = "ConfigurationFactory")
@Order(10)
class CordaLog4j2ConfigFactory : ConfigurationFactory() {
private companion object {
private val SUPPORTED_TYPES = arrayOf(".xml", "*")
}
override fun getConfiguration(loggerContext: LoggerContext, source: ConfigurationSource): Configuration = ErrorCodeAppendingConfiguration(loggerContext, source)
override fun getSupportedTypes() = SUPPORTED_TYPES
private class ErrorCodeAppendingConfiguration(loggerContext: LoggerContext, source: ConfigurationSource) : XmlConfiguration(loggerContext, source) {
override fun doConfigure() {
super.doConfigure()
loggers.values.forEach {
val existingFactory = it.logEventFactory
it.logEventFactory = LogEventFactory { loggerName, marker, fqcn, level, message, properties, error -> existingFactory.createEvent(loggerName, marker, fqcn, level, message?.withErrorCodeFor(error, level), properties, error) }
}
}
}
}

View File

@ -0,0 +1,61 @@
package net.corda.cliutils
import org.apache.logging.log4j.Level
import org.apache.logging.log4j.message.Message
import org.apache.logging.log4j.message.SimpleMessage
import java.util.*
internal fun Message.withErrorCodeFor(error: Throwable?, level: Level): Message {
return when {
error != null && level.isInRange(Level.FATAL, Level.WARN) -> CompositeMessage("$formattedMessage [errorCode=${error.errorCode()}]", format, parameters, throwable)
else -> this
}
}
private fun Throwable.errorCode(hashedFields: (Throwable) -> Array<out Any?> = Throwable::defaultHashedFields): String {
val hash = staticLocationBasedHash(hashedFields)
return hash.toBase(36)
}
private fun Throwable.staticLocationBasedHash(hashedFields: (Throwable) -> Array<out Any?>, visited: Set<Throwable> = setOf(this)): Int {
val cause = this.cause
val fields = hashedFields.invoke(this)
return when {
cause != null && !visited.contains(cause) -> Objects.hash(*fields, cause.staticLocationBasedHash(hashedFields, visited + cause))
else -> Objects.hash(*fields)
}
}
private fun Int.toBase(base: Int): String = Integer.toUnsignedString(this, base)
private fun Array<StackTraceElement?>.customHashCode(maxElementsToConsider: Int = this.size): Int {
return Arrays.hashCode(take(maxElementsToConsider).map { it?.customHashCode() ?: 0 }.toIntArray())
}
private fun StackTraceElement.customHashCode(hashedFields: (StackTraceElement) -> Array<out Any?> = StackTraceElement::defaultHashedFields): Int {
return Objects.hash(*hashedFields.invoke(this))
}
private fun Throwable.defaultHashedFields(): Array<out Any?> {
return arrayOf(this::class.java.name, stackTrace?.customHashCode(3) ?: 0)
}
private fun StackTraceElement.defaultHashedFields(): Array<out Any?> {
return arrayOf(className, methodName)
}
private class CompositeMessage(message: String?, private val formatArg: String?, private val parameters: Array<out Any?>?, private val error: Throwable?) : SimpleMessage(message) {
override fun getThrowable(): Throwable? = error
override fun getParameters(): Array<out Any?>? = parameters
override fun getFormat(): String? = formatArg
}

View File

@ -0,0 +1 @@
log4j.configurationFactory=net.corda.cliutils.CordaLog4j2ConfigFactory

View File

@ -3,6 +3,7 @@ package net.corda.bootstrapper.serialization
import net.corda.core.serialization.internal.SerializationEnvironmentImpl
import net.corda.core.serialization.internal.nodeSerializationEnv
import net.corda.node.serialization.amqp.AMQPServerSerializationScheme
import net.corda.node.serialization.kryo.KRYO_CHECKPOINT_CONTEXT
import net.corda.serialization.internal.AMQP_P2P_CONTEXT
import net.corda.serialization.internal.AMQP_STORAGE_CONTEXT
import net.corda.serialization.internal.SerializationFactoryImpl
@ -20,7 +21,7 @@ class SerializationEngine {
p2pContext = AMQP_P2P_CONTEXT.withClassLoader(classloader),
rpcServerContext = AMQP_P2P_CONTEXT.withClassLoader(classloader),
storageContext = AMQP_STORAGE_CONTEXT.withClassLoader(classloader),
checkpointContext = AMQP_P2P_CONTEXT.withClassLoader(classloader)
checkpointContext = KRYO_CHECKPOINT_CONTEXT.withClassLoader(classloader)
)
}
}

View File

@ -44,7 +44,7 @@ public class RunShellCommand extends InteractiveShellCommand {
emitHelp(context, parser);
return null;
}
return InteractiveShell.runRPCFromString(command, out, context, ops(), objectMapper(), isSsh());
return InteractiveShell.runRPCFromString(command, out, context, ops(), objectMapper());
}
private void emitHelp(InvocationContext<Map> context, StringToMethodCallParser<CordaRPCOps> parser) {

View File

@ -18,6 +18,7 @@ import net.corda.core.internal.*
import net.corda.core.internal.concurrent.doneFuture
import net.corda.core.internal.concurrent.openFuture
import net.corda.core.messaging.*
import net.corda.nodeapi.internal.pendingFlowsCount
import net.corda.tools.shell.utlities.ANSIProgressRenderer
import net.corda.tools.shell.utlities.StdoutANSIProgressRenderer
import org.crsh.command.InvocationContext
@ -408,7 +409,7 @@ object InteractiveShell {
}
@JvmStatic
fun runRPCFromString(input: List<String>, out: RenderPrintWriter, context: InvocationContext<out Any>, cordaRPCOps: CordaRPCOps, om: ObjectMapper, isSsh: Boolean = false): Any? {
fun runRPCFromString(input: List<String>, out: RenderPrintWriter, context: InvocationContext<out Any>, cordaRPCOps: CordaRPCOps, om: ObjectMapper): Any? {
val cmd = input.joinToString(" ").trim { it <= ' ' }
if (cmd.startsWith("startflow", ignoreCase = true)) {
// The flow command provides better support and startFlow requires special handling anyway due to
@ -417,7 +418,7 @@ object InteractiveShell {
out.println("Please use the 'flow' command to interact with flows rather than the 'run' command.", Color.yellow)
return null
} else if (cmd.substringAfter(" ").trim().equals("gracefulShutdown", ignoreCase = true)) {
return InteractiveShell.gracefulShutdown(out, cordaRPCOps, isSsh)
return InteractiveShell.gracefulShutdown(out, cordaRPCOps)
}
var result: Any? = null
@ -456,9 +457,8 @@ object InteractiveShell {
return result
}
@JvmStatic
fun gracefulShutdown(userSessionOut: RenderPrintWriter, cordaRPCOps: CordaRPCOps, isSsh: Boolean = false) {
fun gracefulShutdown(userSessionOut: RenderPrintWriter, cordaRPCOps: CordaRPCOps) {
fun display(statements: RenderPrintWriter.() -> Unit) {
statements.invoke(userSessionOut)
@ -467,40 +467,48 @@ object InteractiveShell {
var isShuttingDown = false
try {
display { println("Orchestrating a clean shutdown, press CTRL+C to cancel...") }
isShuttingDown = true
display {
println("Orchestrating a clean shutdown...")
println("...enabling draining mode")
}
cordaRPCOps.setFlowsDrainingModeEnabled(true)
display {
println("...waiting for in-flight flows to be completed")
}
cordaRPCOps.pendingFlowsCount().updates
.doOnError { error ->
log.error(error.message)
throw error
}
.doOnNext { (first, second) ->
display {
println("...remaining: ${first}/${second}")
cordaRPCOps.terminate(true)
val latch = CountDownLatch(1)
cordaRPCOps.pendingFlowsCount().updates.doOnError { error ->
log.error(error.message)
throw error
}.doAfterTerminate(latch::countDown).subscribe(
// For each update.
{ (first, second) -> display { println("...remaining: $first / $second") } },
// On error.
{ error ->
if (!isShuttingDown) {
display { println("RPC failed: ${error.rootCause}", Color.red) }
}
}
.doOnCompleted {
if (isSsh) {
// print in the original Shell process
System.out.println("Shutting down the node via remote SSH session (it may take a while)")
}
display {
println("Shutting down the node (it may take a while)")
}
cordaRPCOps.shutdown()
isShuttingDown = true
},
// When completed.
{
connection.forceClose()
display {
println("...done, quitting standalone shell now.")
}
// This will only show up in the standalone Shell, because the embedded one is killed as part of a node's shutdown.
display { println("...done, quitting the shell now.") }
onExit.invoke()
}.toBlocking().single()
})
while (!Thread.currentThread().isInterrupted) {
try {
latch.await()
break
} catch (e: InterruptedException) {
try {
cordaRPCOps.setFlowsDrainingModeEnabled(false)
display { println("...cancelled clean shutdown.") }
} finally {
Thread.currentThread().interrupt()
break
}
}
}
} catch (e: StringToMethodCallParser.UnparseableCallException) {
display {
println(e.message, Color.red)
@ -508,9 +516,7 @@ object InteractiveShell {
}
} catch (e: Exception) {
if (!isShuttingDown) {
display {
println("RPC failed: ${e.rootCause}", Color.red)
}
display { println("RPC failed: ${e.rootCause}", Color.red) }
}
} finally {
InputStreamSerializer.invokeContext = null