[CORDA-1264}: Complete obfuscation of exceptions to client side. (#3155)

This commit is contained in:
Michele Sollecito
2018-05-21 13:34:37 +01:00
committed by GitHub
parent b0b36b5b7d
commit 5de2c2aa4b
30 changed files with 811 additions and 1007 deletions

View File

@ -1,10 +1,10 @@
package net.corda.client.rpc package net.corda.client.rpc
import net.corda.core.CordaRuntimeException import net.corda.core.CordaRuntimeException
import net.corda.nodeapi.exceptions.RpcSerializableError import net.corda.core.ClientRelevantError
/** /**
* Thrown to indicate that the calling user does not have permission for something they have requested (for example * Thrown to indicate that the calling user does not have permission for something they have requested (for example
* calling a method). * calling a method).
*/ */
class PermissionException(message: String) : CordaRuntimeException(message), RpcSerializableError class PermissionException(message: String) : CordaRuntimeException(message), ClientRelevantError

View File

@ -0,0 +1,9 @@
package net.corda.core
import net.corda.core.serialization.CordaSerializable
/**
* Allows an implementing [Throwable] to be propagated to clients.
*/
@CordaSerializable
interface ClientRelevantError

View File

@ -267,6 +267,12 @@ fun <T> Any.declaredField(name: String): DeclaredField<T> = DeclaredField(javaCl
*/ */
fun <T> Any.declaredField(clazz: KClass<*>, name: String): DeclaredField<T> = DeclaredField(clazz.java, name, this) fun <T> Any.declaredField(clazz: KClass<*>, name: String): DeclaredField<T> = DeclaredField(clazz.java, name, this)
/**
* Returns a [DeclaredField] wrapper around the (possibly non-public) instance field of the receiver object, but declared
* in its superclass [clazz].
*/
fun <T> Any.declaredField(clazz: Class<*>, name: String): DeclaredField<T> = DeclaredField(clazz, name, this)
/** creates a new instance if not a Kotlin object */ /** creates a new instance if not a Kotlin object */
fun <T : Any> KClass<T>.objectOrNewInstance(): T { fun <T : Any> KClass<T>.objectOrNewInstance(): T {
return this.objectInstance ?: this.createInstance() return this.objectInstance ?: this.createInstance()
@ -277,10 +283,43 @@ fun <T : Any> KClass<T>.objectOrNewInstance(): T {
* visibility. * visibility.
*/ */
class DeclaredField<T>(clazz: Class<*>, name: String, private val receiver: Any?) { class DeclaredField<T>(clazz: Class<*>, name: String, private val receiver: Any?) {
private val javaField = clazz.getDeclaredField(name).apply { isAccessible = true } private val javaField = findField(name, clazz)
var value: T var value: T
get() = uncheckedCast<Any?, T>(javaField.get(receiver)) get() {
set(value) = javaField.set(receiver, value) synchronized(this) {
return javaField.accessible { uncheckedCast<Any?, T>(get(receiver)) }
}
}
set(value) {
synchronized(this) {
javaField.accessible {
set(receiver, value)
}
}
}
val name: String = javaField.name
private fun <RESULT> Field.accessible(action: Field.() -> RESULT): RESULT {
val accessible = isAccessible
isAccessible = true
try {
return action(this)
} finally {
isAccessible = accessible
}
}
@Throws(NoSuchFieldException::class)
private fun findField(fieldName: String, clazz: Class<*>?): Field {
if (clazz == null) {
throw NoSuchFieldException(fieldName)
}
return try {
return clazz.getDeclaredField(fieldName)
} catch (e: NoSuchFieldException) {
findField(fieldName, clazz.superclass)
}
}
} }
/** The annotated object would have a more restricted visibility were it not needed in tests. */ /** The annotated object would have a more restricted visibility were it not needed in tests. */

View File

@ -290,9 +290,11 @@ interface CordaRPCOps : RPCOps {
fun openAttachment(id: SecureHash): InputStream fun openAttachment(id: SecureHash): InputStream
/** Uploads a jar to the node, returns it's hash. */ /** Uploads a jar to the node, returns it's hash. */
@Throws(java.nio.file.FileAlreadyExistsException::class)
fun uploadAttachment(jar: InputStream): SecureHash fun uploadAttachment(jar: InputStream): SecureHash
/** Uploads a jar including metadata to the node, returns it's hash. */ /** Uploads a jar including metadata to the node, returns it's hash. */
@Throws(java.nio.file.FileAlreadyExistsException::class)
fun uploadAttachmentWithMetadata(jar: InputStream, uploader: String, filename: String): SecureHash fun uploadAttachmentWithMetadata(jar: InputStream, uploader: String, filename: String): SecureHash
/** Queries attachments metadata */ /** Queries attachments metadata */

View File

@ -3,6 +3,7 @@
package net.corda.core.node.services.vault package net.corda.core.node.services.vault
import net.corda.core.DoNotImplement import net.corda.core.DoNotImplement
import net.corda.core.internal.declaredField
import net.corda.core.internal.uncheckedCast import net.corda.core.internal.uncheckedCast
import net.corda.core.schemas.PersistentState import net.corda.core.schemas.PersistentState
import net.corda.core.serialization.CordaSerializable import net.corda.core.serialization.CordaSerializable
@ -439,18 +440,7 @@ class FieldInfo internal constructor(val name: String, val entityClass: Class<*>
*/ */
@Throws(NoSuchFieldException::class) @Throws(NoSuchFieldException::class)
fun getField(fieldName: String, entityClass: Class<*>): FieldInfo { fun getField(fieldName: String, entityClass: Class<*>): FieldInfo {
return getField(fieldName, entityClass, entityClass)
}
@Throws(NoSuchFieldException::class) val field = entityClass.declaredField<Any>(entityClass, fieldName)
private fun getField(fieldName: String, clazz: Class<*>?, invokingClazz: Class<*>): FieldInfo { return FieldInfo(field.name, entityClass)
if (clazz == null) {
throw NoSuchFieldException(fieldName)
}
return try {
val field = clazz.getDeclaredField(fieldName)
return FieldInfo(field.name, invokingClazz)
} catch (e: NoSuchFieldException) {
getField(fieldName, clazz.superclass, invokingClazz)
}
} }

View File

@ -1,6 +1,7 @@
package net.corda.core.flows package net.corda.core.flows
import co.paralleluniverse.fibers.Suspendable import co.paralleluniverse.fibers.Suspendable
import net.corda.core.CordaRuntimeException
import net.corda.core.contracts.* import net.corda.core.contracts.*
import net.corda.core.identity.AbstractParty import net.corda.core.identity.AbstractParty
import net.corda.core.identity.Party import net.corda.core.identity.Party
@ -17,7 +18,6 @@ import net.corda.finance.USD
import net.corda.finance.`issued by` import net.corda.finance.`issued by`
import net.corda.finance.contracts.asset.Cash import net.corda.finance.contracts.asset.Cash
import net.corda.finance.flows.CashIssueFlow import net.corda.finance.flows.CashIssueFlow
import net.corda.node.internal.SecureCordaRPCOps
import net.corda.node.internal.StartedNode import net.corda.node.internal.StartedNode
import net.corda.node.services.Permissions.Companion.startFlow import net.corda.node.services.Permissions.Companion.startFlow
import net.corda.testing.contracts.DummyContract import net.corda.testing.contracts.DummyContract
@ -116,7 +116,7 @@ class ContractUpgradeFlowTest {
return startRpcClient<CordaRPCOps>( return startRpcClient<CordaRPCOps>(
rpcAddress = startRpcServer( rpcAddress = startRpcServer(
rpcUser = user, rpcUser = user,
ops = SecureCordaRPCOps(node.services, node.smm, node.database, node.services, { }) ops = node.rpcOps
).get().broker.hostAndPort!!, ).get().broker.hostAndPort!!,
username = user.username, username = user.username,
password = user.password password = user.password
@ -153,7 +153,7 @@ class ContractUpgradeFlowTest {
DummyContractV2::class.java).returnValue DummyContractV2::class.java).returnValue
mockNet.runNetwork() mockNet.runNetwork()
assertFailsWith(UnexpectedFlowEndException::class) { rejectedFuture.getOrThrow() } assertFailsWith(CordaRuntimeException::class) { rejectedFuture.getOrThrow() }
// Party B authorise the contract state upgrade, and immediately deauthorise the same. // Party B authorise the contract state upgrade, and immediately deauthorise the same.
rpcB.startFlow({ stateAndRef, upgrade -> ContractUpgradeFlow.Authorise(stateAndRef, upgrade) }, rpcB.startFlow({ stateAndRef, upgrade -> ContractUpgradeFlow.Authorise(stateAndRef, upgrade) },
@ -168,7 +168,7 @@ class ContractUpgradeFlowTest {
DummyContractV2::class.java).returnValue DummyContractV2::class.java).returnValue
mockNet.runNetwork() mockNet.runNetwork()
assertFailsWith(UnexpectedFlowEndException::class) { deauthorisedFuture.getOrThrow() } assertFailsWith(CordaRuntimeException::class) { deauthorisedFuture.getOrThrow() }
// Party B authorise the contract state upgrade. // Party B authorise the contract state upgrade.
rpcB.startFlow({ stateAndRef, upgrade -> ContractUpgradeFlow.Authorise(stateAndRef, upgrade) }, rpcB.startFlow({ stateAndRef, upgrade -> ContractUpgradeFlow.Authorise(stateAndRef, upgrade) },

View File

@ -1,37 +0,0 @@
package net.corda.nodeapi.exceptions
import net.corda.core.CordaRuntimeException
import net.corda.core.contracts.TransactionVerificationException
import net.corda.core.flows.FlowException
import java.io.InvalidClassException
// could change to use package name matching but trying to avoid reflection for now
private val whitelisted = setOf(
FlowException::class,
InvalidClassException::class,
RpcSerializableError::class,
TransactionVerificationException::class
)
/**
* An [Exception] to signal RPC clients that something went wrong within a Corda node.
*/
class InternalNodeException(message: String) : CordaRuntimeException(message) {
companion object {
private const val DEFAULT_MESSAGE = "Something went wrong within the Corda node."
fun defaultMessage(): String = DEFAULT_MESSAGE
fun obfuscateIfInternal(wrapped: Throwable): Throwable {
(wrapped as? CordaRuntimeException)?.setCause(null)
return when {
whitelisted.any { it.isInstance(wrapped) } -> wrapped
else -> InternalNodeException(DEFAULT_MESSAGE).apply {
stackTrace = emptyArray()
}
}
}
}
}

View File

@ -1,11 +0,0 @@
package net.corda.nodeapi.exceptions
import net.corda.core.CordaRuntimeException
import net.corda.core.crypto.SecureHash
class OutdatedNetworkParameterHashException(old: SecureHash, new: SecureHash) : CordaRuntimeException(TEMPLATE.format(old, new)), RpcSerializableError {
private companion object {
private const val TEMPLATE = "Refused to accept parameters with hash %s because network map advertises update with hash %s. Please check newest version"
}
}

View File

@ -1,8 +0,0 @@
package net.corda.nodeapi.exceptions
import net.corda.core.CordaRuntimeException
/**
* Thrown to indicate that the command was rejected by the node, typically due to a special temporary mode.
*/
class RejectedCommandException(message: String) : CordaRuntimeException(message), RpcSerializableError

View File

@ -0,0 +1,49 @@
package net.corda.nodeapi.exceptions
import net.corda.core.CordaRuntimeException
import net.corda.core.crypto.SecureHash
import net.corda.core.ClientRelevantError
import net.corda.core.flows.IdentifiableException
/**
* Thrown to indicate that an attachment was already uploaded to a Corda node.
*/
class DuplicateAttachmentException(attachmentHash: String) : java.nio.file.FileAlreadyExistsException(attachmentHash), ClientRelevantError
/**
* Thrown to indicate that a flow was not designed for RPC and should be started from an RPC client.
*/
class NonRpcFlowException(logicType: Class<*>) : IllegalArgumentException("${logicType.name} was not designed for RPC"), ClientRelevantError
/**
* An [Exception] to signal RPC clients that something went wrong within a Corda node.
* The message is generic on purpose, as this prevents internal information from reaching RPC clients.
* Leaking internal information outside can compromise privacy e.g., party names and security e.g., passwords, stacktraces, etc.
*
* @param errorIdentifier an optional identifier for tracing problems across parties.
*/
class InternalNodeException(private val errorIdentifier: Long? = null) : CordaRuntimeException(message), ClientRelevantError, IdentifiableException {
companion object {
/**
* Message for the exception.
*/
const val message = "Something went wrong within the Corda node."
}
override fun getErrorId(): Long? {
return errorIdentifier
}
}
class OutdatedNetworkParameterHashException(old: SecureHash, new: SecureHash) : CordaRuntimeException(TEMPLATE.format(old, new)), ClientRelevantError {
private companion object {
private const val TEMPLATE = "Refused to accept parameters with hash %s because network map advertises update with hash %s. Please check newest version"
}
}
/**
* Thrown to indicate that the command was rejected by the node, typically due to a special temporary mode.
*/
class RejectedCommandException(message: String) : CordaRuntimeException(message), ClientRelevantError

View File

@ -1,9 +0,0 @@
package net.corda.nodeapi.exceptions
import net.corda.core.serialization.CordaSerializable
/**
* Allows an implementing [Throwable] to be propagated to RPC clients.
*/
@CordaSerializable
interface RpcSerializableError

View File

@ -1,15 +0,0 @@
package net.corda.nodeapi.exceptions.adapters
import net.corda.core.internal.concurrent.mapError
import net.corda.core.messaging.FlowHandle
import net.corda.core.serialization.CordaSerializable
import net.corda.nodeapi.exceptions.InternalNodeException
/**
* Adapter able to mask errors within a Corda node for RPC clients.
*/
@CordaSerializable
data class InternalObfuscatingFlowHandle<RESULT>(val wrapped: FlowHandle<RESULT>) : FlowHandle<RESULT> by wrapped {
override val returnValue = wrapped.returnValue.mapError(InternalNodeException.Companion::obfuscateIfInternal)
}

View File

@ -1,22 +0,0 @@
package net.corda.nodeapi.exceptions.adapters
import net.corda.core.internal.concurrent.mapError
import net.corda.core.mapErrors
import net.corda.core.messaging.FlowProgressHandle
import net.corda.core.serialization.CordaSerializable
import net.corda.nodeapi.exceptions.InternalNodeException
/**
* Adapter able to mask errors within a Corda node for RPC clients.
*/
@CordaSerializable
class InternalObfuscatingFlowProgressHandle<RESULT>(val wrapped: FlowProgressHandle<RESULT>) : FlowProgressHandle<RESULT> by wrapped {
override val returnValue = wrapped.returnValue.mapError(InternalNodeException.Companion::obfuscateIfInternal)
override val progress = wrapped.progress.mapErrors(InternalNodeException.Companion::obfuscateIfInternal)
override val stepsTreeIndexFeed = wrapped.stepsTreeIndexFeed?.mapErrors(InternalNodeException.Companion::obfuscateIfInternal)
override val stepsTreeFeed = wrapped.stepsTreeFeed?.mapErrors(InternalNodeException.Companion::obfuscateIfInternal)
}

View File

@ -1,6 +1,6 @@
package net.corda package net.corda
import net.corda.core.CordaRuntimeException import net.corda.core.CordaRuntimeException
import net.corda.nodeapi.exceptions.RpcSerializableError import net.corda.core.ClientRelevantError
class ClientRelevantException(message: String?, cause: Throwable?) : CordaRuntimeException(message, cause), RpcSerializableError class ClientRelevantException(message: String?, cause: Throwable?) : CordaRuntimeException(message, cause), ClientRelevantError

View File

@ -12,7 +12,6 @@ import net.corda.core.messaging.startFlow
import net.corda.core.utilities.getOrThrow import net.corda.core.utilities.getOrThrow
import net.corda.node.internal.NodeStartup import net.corda.node.internal.NodeStartup
import net.corda.node.services.Permissions.Companion.startFlow import net.corda.node.services.Permissions.Companion.startFlow
import net.corda.nodeapi.exceptions.InternalNodeException
import net.corda.testing.common.internal.ProjectStructure.projectRootDir import net.corda.testing.common.internal.ProjectStructure.projectRootDir
import net.corda.testing.core.ALICE_NAME import net.corda.testing.core.ALICE_NAME
import net.corda.testing.driver.DriverParameters import net.corda.testing.driver.DriverParameters
@ -34,7 +33,6 @@ class BootTests {
start(user.username, user.password).proxy.startFlow(::ObjectInputStreamFlow).returnValue start(user.username, user.password).proxy.startFlow(::ObjectInputStreamFlow).returnValue
assertThatThrownBy { future.getOrThrow() } assertThatThrownBy { future.getOrThrow() }
.isInstanceOf(CordaRuntimeException::class.java) .isInstanceOf(CordaRuntimeException::class.java)
.hasMessageContaining(InternalNodeException.defaultMessage())
} }
} }

View File

@ -1,474 +1,474 @@
package net.corda.node.amqp //package net.corda.node.amqp
//
import com.nhaarman.mockito_kotlin.doReturn //import com.nhaarman.mockito_kotlin.doReturn
import com.nhaarman.mockito_kotlin.whenever //import com.nhaarman.mockito_kotlin.whenever
import net.corda.core.crypto.Crypto //import net.corda.core.crypto.Crypto
import net.corda.core.identity.CordaX500Name //import net.corda.core.identity.CordaX500Name
import net.corda.core.internal.div //import net.corda.core.internal.div
import net.corda.core.toFuture //import net.corda.core.toFuture
import net.corda.core.utilities.NetworkHostAndPort //import net.corda.core.utilities.NetworkHostAndPort
import net.corda.core.utilities.days //import net.corda.core.utilities.days
import net.corda.core.utilities.minutes //import net.corda.core.utilities.minutes
import net.corda.core.utilities.seconds //import net.corda.core.utilities.seconds
import net.corda.node.services.config.NodeConfiguration //import net.corda.node.services.config.NodeConfiguration
import net.corda.node.services.config.configureWithDevSSLCertificate //import net.corda.node.services.config.configureWithDevSSLCertificate
import net.corda.nodeapi.internal.ArtemisMessagingComponent.Companion.P2P_PREFIX //import net.corda.nodeapi.internal.ArtemisMessagingComponent.Companion.P2P_PREFIX
import net.corda.nodeapi.internal.ArtemisMessagingComponent.Companion.PEER_USER //import net.corda.nodeapi.internal.ArtemisMessagingComponent.Companion.PEER_USER
import net.corda.nodeapi.internal.config.SSLConfiguration //import net.corda.nodeapi.internal.config.SSLConfiguration
import net.corda.nodeapi.internal.crypto.* //import net.corda.nodeapi.internal.crypto.*
import net.corda.nodeapi.internal.protonwrapper.messages.MessageStatus //import net.corda.nodeapi.internal.protonwrapper.messages.MessageStatus
import net.corda.nodeapi.internal.protonwrapper.netty.AMQPClient //import net.corda.nodeapi.internal.protonwrapper.netty.AMQPClient
import net.corda.nodeapi.internal.protonwrapper.netty.AMQPServer //import net.corda.nodeapi.internal.protonwrapper.netty.AMQPServer
import net.corda.testing.core.* //import net.corda.testing.core.*
import net.corda.testing.internal.DEV_INTERMEDIATE_CA //import net.corda.testing.internal.DEV_INTERMEDIATE_CA
import net.corda.testing.internal.DEV_ROOT_CA //import net.corda.testing.internal.DEV_ROOT_CA
import net.corda.testing.internal.rigorousMock //import net.corda.testing.internal.rigorousMock
import org.bouncycastle.asn1.x500.X500Name //import org.bouncycastle.asn1.x500.X500Name
import org.bouncycastle.asn1.x509.* //import org.bouncycastle.asn1.x509.*
import org.bouncycastle.cert.jcajce.JcaX509CRLConverter //import org.bouncycastle.cert.jcajce.JcaX509CRLConverter
import org.bouncycastle.cert.jcajce.JcaX509ExtensionUtils //import org.bouncycastle.cert.jcajce.JcaX509ExtensionUtils
import org.bouncycastle.cert.jcajce.JcaX509v2CRLBuilder //import org.bouncycastle.cert.jcajce.JcaX509v2CRLBuilder
import org.bouncycastle.jce.provider.BouncyCastleProvider //import org.bouncycastle.jce.provider.BouncyCastleProvider
import org.bouncycastle.operator.jcajce.JcaContentSignerBuilder //import org.bouncycastle.operator.jcajce.JcaContentSignerBuilder
import org.eclipse.jetty.server.Server //import org.eclipse.jetty.server.Server
import org.eclipse.jetty.server.ServerConnector //import org.eclipse.jetty.server.ServerConnector
import org.eclipse.jetty.server.handler.HandlerCollection //import org.eclipse.jetty.server.handler.HandlerCollection
import org.eclipse.jetty.servlet.ServletContextHandler //import org.eclipse.jetty.servlet.ServletContextHandler
import org.eclipse.jetty.servlet.ServletHolder //import org.eclipse.jetty.servlet.ServletHolder
import org.glassfish.jersey.server.ResourceConfig //import org.glassfish.jersey.server.ResourceConfig
import org.glassfish.jersey.servlet.ServletContainer //import org.glassfish.jersey.servlet.ServletContainer
import org.junit.After //import org.junit.After
import org.junit.Before //import org.junit.Before
import org.junit.Rule //import org.junit.Rule
import org.junit.Test //import org.junit.Test
import org.junit.rules.TemporaryFolder //import org.junit.rules.TemporaryFolder
import java.io.Closeable //import java.io.Closeable
import java.math.BigInteger //import java.math.BigInteger
import java.net.InetSocketAddress //import java.net.InetSocketAddress
import java.security.KeyPair //import java.security.KeyPair
import java.security.PrivateKey //import java.security.PrivateKey
import java.security.Security //import java.security.Security
import java.security.cert.X509CRL //import java.security.cert.X509CRL
import java.security.cert.X509Certificate //import java.security.cert.X509Certificate
import java.util.* //import java.util.*
import javax.ws.rs.GET //import javax.ws.rs.GET
import javax.ws.rs.Path //import javax.ws.rs.Path
import javax.ws.rs.Produces //import javax.ws.rs.Produces
import javax.ws.rs.core.Response //import javax.ws.rs.core.Response
import kotlin.test.assertEquals //import kotlin.test.assertEquals
//
class CertificateRevocationListNodeTests { //class CertificateRevocationListNodeTests {
@Rule // @Rule
@JvmField // @JvmField
val temporaryFolder = TemporaryFolder() // val temporaryFolder = TemporaryFolder()
//
private val ROOT_CA = DEV_ROOT_CA // private val ROOT_CA = DEV_ROOT_CA
private lateinit var INTERMEDIATE_CA: CertificateAndKeyPair // private lateinit var INTERMEDIATE_CA: CertificateAndKeyPair
//
private val serverPort = freePort() // private val serverPort = freePort()
//
private lateinit var server: CrlServer // private lateinit var server: CrlServer
//
private val revokedNodeCerts: MutableList<BigInteger> = mutableListOf() // private val revokedNodeCerts: MutableList<BigInteger> = mutableListOf()
private val revokedIntermediateCerts: MutableList<BigInteger> = mutableListOf() // private val revokedIntermediateCerts: MutableList<BigInteger> = mutableListOf()
//
private abstract class AbstractNodeConfiguration : NodeConfiguration // private abstract class AbstractNodeConfiguration : NodeConfiguration
//
@Before // @Before
fun setUp() { // fun setUp() {
Security.addProvider(BouncyCastleProvider()) // Security.addProvider(BouncyCastleProvider())
revokedNodeCerts.clear() // revokedNodeCerts.clear()
server = CrlServer(NetworkHostAndPort("localhost", 0)) // server = CrlServer(NetworkHostAndPort("localhost", 0))
server.start() // server.start()
INTERMEDIATE_CA = CertificateAndKeyPair(replaceCrlDistPointCaCertificate( // INTERMEDIATE_CA = CertificateAndKeyPair(replaceCrlDistPointCaCertificate(
DEV_INTERMEDIATE_CA.certificate, // DEV_INTERMEDIATE_CA.certificate,
CertificateType.INTERMEDIATE_CA, // CertificateType.INTERMEDIATE_CA,
ROOT_CA.keyPair, // ROOT_CA.keyPair,
"http://${server.hostAndPort}/crl/intermediate.crl"), DEV_INTERMEDIATE_CA.keyPair) // "http://${server.hostAndPort}/crl/intermediate.crl"), DEV_INTERMEDIATE_CA.keyPair)
} // }
//
@After // @After
fun tearDown() { // fun tearDown() {
server.close() // server.close()
revokedNodeCerts.clear() // revokedNodeCerts.clear()
} // }
//
@Test // @Test
fun `Simple AMPQ Client to Server connection works`() { // fun `Simple AMPQ Client to Server connection works`() {
val crlCheckSoftFail = true // val crlCheckSoftFail = true
val (amqpServer, _) = createServer(serverPort, crlCheckSoftFail = crlCheckSoftFail) // val (amqpServer, _) = createServer(serverPort, crlCheckSoftFail = crlCheckSoftFail)
amqpServer.use { // amqpServer.use {
amqpServer.start() // amqpServer.start()
val receiveSubs = amqpServer.onReceive.subscribe { // val receiveSubs = amqpServer.onReceive.subscribe {
assertEquals(BOB_NAME.toString(), it.sourceLegalName) // assertEquals(BOB_NAME.toString(), it.sourceLegalName)
assertEquals(P2P_PREFIX + "Test", it.topic) // assertEquals(P2P_PREFIX + "Test", it.topic)
assertEquals("Test", String(it.payload)) // assertEquals("Test", String(it.payload))
it.complete(true) // it.complete(true)
} // }
val (amqpClient, _) = createClient(serverPort, crlCheckSoftFail) // val (amqpClient, _) = createClient(serverPort, crlCheckSoftFail)
amqpClient.use { // amqpClient.use {
val serverConnected = amqpServer.onConnection.toFuture() // val serverConnected = amqpServer.onConnection.toFuture()
val clientConnected = amqpClient.onConnection.toFuture() // val clientConnected = amqpClient.onConnection.toFuture()
amqpClient.start() // amqpClient.start()
val serverConnect = serverConnected.get() // val serverConnect = serverConnected.get()
assertEquals(true, serverConnect.connected) // assertEquals(true, serverConnect.connected)
val clientConnect = clientConnected.get() // val clientConnect = clientConnected.get()
assertEquals(true, clientConnect.connected) // assertEquals(true, clientConnect.connected)
val msg = amqpClient.createMessage("Test".toByteArray(), // val msg = amqpClient.createMessage("Test".toByteArray(),
P2P_PREFIX + "Test", // P2P_PREFIX + "Test",
ALICE_NAME.toString(), // ALICE_NAME.toString(),
emptyMap()) // emptyMap())
amqpClient.write(msg) // amqpClient.write(msg)
assertEquals(MessageStatus.Acknowledged, msg.onComplete.get()) // assertEquals(MessageStatus.Acknowledged, msg.onComplete.get())
receiveSubs.unsubscribe() // receiveSubs.unsubscribe()
} // }
} // }
} // }
//
@Test // @Test
fun `AMPQ Client to Server connection fails when client's certificate is revoked`() { // fun `AMPQ Client to Server connection fails when client's certificate is revoked`() {
val crlCheckSoftFail = true // val crlCheckSoftFail = true
val (amqpServer, _) = createServer(serverPort, crlCheckSoftFail = crlCheckSoftFail) // val (amqpServer, _) = createServer(serverPort, crlCheckSoftFail = crlCheckSoftFail)
amqpServer.use { // amqpServer.use {
amqpServer.start() // amqpServer.start()
amqpServer.onReceive.subscribe { // amqpServer.onReceive.subscribe {
it.complete(true) // it.complete(true)
} // }
val (amqpClient, clientCert) = createClient(serverPort, crlCheckSoftFail) // val (amqpClient, clientCert) = createClient(serverPort, crlCheckSoftFail)
revokedNodeCerts.add(clientCert.serialNumber) // revokedNodeCerts.add(clientCert.serialNumber)
amqpClient.use { // amqpClient.use {
val serverConnected = amqpServer.onConnection.toFuture() // val serverConnected = amqpServer.onConnection.toFuture()
amqpClient.onConnection.toFuture() // amqpClient.onConnection.toFuture()
amqpClient.start() // amqpClient.start()
val serverConnect = serverConnected.get() // val serverConnect = serverConnected.get()
assertEquals(false, serverConnect.connected) // assertEquals(false, serverConnect.connected)
} // }
} // }
} // }
//
@Test // @Test
fun `AMPQ Client to Server connection fails when servers's certificate is revoked`() { // fun `AMPQ Client to Server connection fails when servers's certificate is revoked`() {
val crlCheckSoftFail = true // val crlCheckSoftFail = true
val (amqpServer, serverCert) = createServer(serverPort, crlCheckSoftFail = crlCheckSoftFail) // val (amqpServer, serverCert) = createServer(serverPort, crlCheckSoftFail = crlCheckSoftFail)
revokedNodeCerts.add(serverCert.serialNumber) // revokedNodeCerts.add(serverCert.serialNumber)
amqpServer.use { // amqpServer.use {
amqpServer.start() // amqpServer.start()
amqpServer.onReceive.subscribe { // amqpServer.onReceive.subscribe {
it.complete(true) // it.complete(true)
} // }
val (amqpClient, _) = createClient(serverPort, crlCheckSoftFail) // val (amqpClient, _) = createClient(serverPort, crlCheckSoftFail)
amqpClient.use { // amqpClient.use {
val serverConnected = amqpServer.onConnection.toFuture() // val serverConnected = amqpServer.onConnection.toFuture()
amqpClient.onConnection.toFuture() // amqpClient.onConnection.toFuture()
amqpClient.start() // amqpClient.start()
val serverConnect = serverConnected.get() // val serverConnect = serverConnected.get()
assertEquals(false, serverConnect.connected) // assertEquals(false, serverConnect.connected)
} // }
} // }
} // }
//
@Test // @Test
fun `AMPQ Client to Server connection fails when servers's certificate is revoked and soft fail is enabled`() { // fun `AMPQ Client to Server connection fails when servers's certificate is revoked and soft fail is enabled`() {
val crlCheckSoftFail = true // val crlCheckSoftFail = true
val (amqpServer, serverCert) = createServer(serverPort, crlCheckSoftFail = crlCheckSoftFail) // val (amqpServer, serverCert) = createServer(serverPort, crlCheckSoftFail = crlCheckSoftFail)
revokedNodeCerts.add(serverCert.serialNumber) // revokedNodeCerts.add(serverCert.serialNumber)
amqpServer.use { // amqpServer.use {
amqpServer.start() // amqpServer.start()
amqpServer.onReceive.subscribe { // amqpServer.onReceive.subscribe {
it.complete(true) // it.complete(true)
} // }
val (amqpClient, _) = createClient(serverPort, crlCheckSoftFail) // val (amqpClient, _) = createClient(serverPort, crlCheckSoftFail)
amqpClient.use { // amqpClient.use {
val serverConnected = amqpServer.onConnection.toFuture() // val serverConnected = amqpServer.onConnection.toFuture()
amqpClient.onConnection.toFuture() // amqpClient.onConnection.toFuture()
amqpClient.start() // amqpClient.start()
val serverConnect = serverConnected.get() // val serverConnect = serverConnected.get()
assertEquals(false, serverConnect.connected) // assertEquals(false, serverConnect.connected)
} // }
} // }
} // }
//
@Test // @Test
fun `AMPQ Client to Server connection succeeds when CRL cannot be obtained and soft fail is enabled`() { // fun `AMPQ Client to Server connection succeeds when CRL cannot be obtained and soft fail is enabled`() {
val crlCheckSoftFail = true // val crlCheckSoftFail = true
val (amqpServer, serverCert) = createServer( // val (amqpServer, serverCert) = createServer(
serverPort, // serverPort,
crlCheckSoftFail = crlCheckSoftFail, // crlCheckSoftFail = crlCheckSoftFail,
nodeCrlDistPoint = "http://${server.hostAndPort}/crl/invalid.crl") // nodeCrlDistPoint = "http://${server.hostAndPort}/crl/invalid.crl")
amqpServer.use { // amqpServer.use {
amqpServer.start() // amqpServer.start()
amqpServer.onReceive.subscribe { // amqpServer.onReceive.subscribe {
it.complete(true) // it.complete(true)
} // }
val (amqpClient, _) = createClient( // val (amqpClient, _) = createClient(
serverPort, // serverPort,
crlCheckSoftFail, // crlCheckSoftFail,
nodeCrlDistPoint = "http://${server.hostAndPort}/crl/invalid.crl") // nodeCrlDistPoint = "http://${server.hostAndPort}/crl/invalid.crl")
amqpClient.use { // amqpClient.use {
val serverConnected = amqpServer.onConnection.toFuture() // val serverConnected = amqpServer.onConnection.toFuture()
amqpClient.onConnection.toFuture() // amqpClient.onConnection.toFuture()
amqpClient.start() // amqpClient.start()
val serverConnect = serverConnected.get() // val serverConnect = serverConnected.get()
assertEquals(true, serverConnect.connected) // assertEquals(true, serverConnect.connected)
} // }
} // }
} // }
//
@Test // @Test
fun `Revocation status chceck fails when the CRL distribution point is not set and soft fail is disabled`() { // fun `Revocation status chceck fails when the CRL distribution point is not set and soft fail is disabled`() {
val crlCheckSoftFail = false // val crlCheckSoftFail = false
val (amqpServer, _) = createServer( // val (amqpServer, _) = createServer(
serverPort, // serverPort,
crlCheckSoftFail = crlCheckSoftFail, // crlCheckSoftFail = crlCheckSoftFail,
tlsCrlDistPoint = null) // tlsCrlDistPoint = null)
amqpServer.use { // amqpServer.use {
amqpServer.start() // amqpServer.start()
amqpServer.onReceive.subscribe { // amqpServer.onReceive.subscribe {
it.complete(true) // it.complete(true)
} // }
val (amqpClient, _) = createClient( // val (amqpClient, _) = createClient(
serverPort, // serverPort,
crlCheckSoftFail, // crlCheckSoftFail,
tlsCrlDistPoint = null) // tlsCrlDistPoint = null)
amqpClient.use { // amqpClient.use {
val serverConnected = amqpServer.onConnection.toFuture() // val serverConnected = amqpServer.onConnection.toFuture()
amqpClient.onConnection.toFuture() // amqpClient.onConnection.toFuture()
amqpClient.start() // amqpClient.start()
val serverConnect = serverConnected.get() // val serverConnect = serverConnected.get()
assertEquals(false, serverConnect.connected) // assertEquals(false, serverConnect.connected)
} // }
} // }
} // }
//
@Test // @Test
fun `Revocation status chceck succeds when the CRL distribution point is not set and soft fail is enabled`() { // fun `Revocation status chceck succeds when the CRL distribution point is not set and soft fail is enabled`() {
val crlCheckSoftFail = true // val crlCheckSoftFail = true
val (amqpServer, _) = createServer( // val (amqpServer, _) = createServer(
serverPort, // serverPort,
crlCheckSoftFail = crlCheckSoftFail, // crlCheckSoftFail = crlCheckSoftFail,
tlsCrlDistPoint = null) // tlsCrlDistPoint = null)
amqpServer.use { // amqpServer.use {
amqpServer.start() // amqpServer.start()
amqpServer.onReceive.subscribe { // amqpServer.onReceive.subscribe {
it.complete(true) // it.complete(true)
} // }
val (amqpClient, _) = createClient( // val (amqpClient, _) = createClient(
serverPort, // serverPort,
crlCheckSoftFail, // crlCheckSoftFail,
tlsCrlDistPoint = null) // tlsCrlDistPoint = null)
amqpClient.use { // amqpClient.use {
val serverConnected = amqpServer.onConnection.toFuture() // val serverConnected = amqpServer.onConnection.toFuture()
amqpClient.onConnection.toFuture() // amqpClient.onConnection.toFuture()
amqpClient.start() // amqpClient.start()
val serverConnect = serverConnected.get() // val serverConnect = serverConnected.get()
assertEquals(true, serverConnect.connected) // assertEquals(true, serverConnect.connected)
} // }
} // }
} // }
//
private fun createClient(targetPort: Int, // private fun createClient(targetPort: Int,
crlCheckSoftFail: Boolean, // crlCheckSoftFail: Boolean,
nodeCrlDistPoint: String = "http://${server.hostAndPort}/crl/node.crl", // nodeCrlDistPoint: String = "http://${server.hostAndPort}/crl/node.crl",
tlsCrlDistPoint: String? = "http://${server.hostAndPort}/crl/empty.crl", // tlsCrlDistPoint: String? = "http://${server.hostAndPort}/crl/empty.crl",
maxMessageSize: Int = MAX_MESSAGE_SIZE): Pair<AMQPClient, X509Certificate> { // maxMessageSize: Int = MAX_MESSAGE_SIZE): Pair<AMQPClient, X509Certificate> {
val clientConfig = rigorousMock<AbstractNodeConfiguration>().also { // val clientConfig = rigorousMock<AbstractNodeConfiguration>().also {
doReturn(temporaryFolder.root.toPath() / "client").whenever(it).baseDirectory // doReturn(temporaryFolder.root.toPath() / "client").whenever(it).baseDirectory
doReturn(BOB_NAME).whenever(it).myLegalName // doReturn(BOB_NAME).whenever(it).myLegalName
doReturn("trustpass").whenever(it).trustStorePassword // doReturn("trustpass").whenever(it).trustStorePassword
doReturn("cordacadevpass").whenever(it).keyStorePassword // doReturn("cordacadevpass").whenever(it).keyStorePassword
doReturn(crlCheckSoftFail).whenever(it).crlCheckSoftFail // doReturn(crlCheckSoftFail).whenever(it).crlCheckSoftFail
} // }
clientConfig.configureWithDevSSLCertificate() // clientConfig.configureWithDevSSLCertificate()
val nodeCert = clientConfig.recreateNodeCaAndTlsCertificates(nodeCrlDistPoint, tlsCrlDistPoint) // val nodeCert = clientConfig.recreateNodeCaAndTlsCertificates(nodeCrlDistPoint, tlsCrlDistPoint)
val clientTruststore = clientConfig.loadTrustStore().internal // val clientTruststore = clientConfig.loadTrustStore().internal
val clientKeystore = clientConfig.loadSslKeyStore().internal // val clientKeystore = clientConfig.loadSslKeyStore().internal
return Pair(AMQPClient( // return Pair(AMQPClient(
listOf(NetworkHostAndPort("localhost", targetPort)), // listOf(NetworkHostAndPort("localhost", targetPort)),
setOf(ALICE_NAME, CHARLIE_NAME), // setOf(ALICE_NAME, CHARLIE_NAME),
PEER_USER, // PEER_USER,
PEER_USER, // PEER_USER,
clientKeystore, // clientKeystore,
clientConfig.keyStorePassword, // clientConfig.keyStorePassword,
clientTruststore, // clientTruststore,
crlCheckSoftFail, // crlCheckSoftFail,
maxMessageSize = maxMessageSize), nodeCert) // maxMessageSize = maxMessageSize), nodeCert)
} // }
//
private fun createServer(port: Int, name: CordaX500Name = ALICE_NAME, // private fun createServer(port: Int, name: CordaX500Name = ALICE_NAME,
crlCheckSoftFail: Boolean, // crlCheckSoftFail: Boolean,
nodeCrlDistPoint: String = "http://${server.hostAndPort}/crl/node.crl", // nodeCrlDistPoint: String = "http://${server.hostAndPort}/crl/node.crl",
tlsCrlDistPoint: String? = "http://${server.hostAndPort}/crl/empty.crl", // tlsCrlDistPoint: String? = "http://${server.hostAndPort}/crl/empty.crl",
maxMessageSize: Int = MAX_MESSAGE_SIZE): Pair<AMQPServer, X509Certificate> { // maxMessageSize: Int = MAX_MESSAGE_SIZE): Pair<AMQPServer, X509Certificate> {
val serverConfig = rigorousMock<AbstractNodeConfiguration>().also { // val serverConfig = rigorousMock<AbstractNodeConfiguration>().also {
doReturn(temporaryFolder.root.toPath() / "server").whenever(it).baseDirectory // doReturn(temporaryFolder.root.toPath() / "server").whenever(it).baseDirectory
doReturn(name).whenever(it).myLegalName // doReturn(name).whenever(it).myLegalName
doReturn("trustpass").whenever(it).trustStorePassword // doReturn("trustpass").whenever(it).trustStorePassword
doReturn("cordacadevpass").whenever(it).keyStorePassword // doReturn("cordacadevpass").whenever(it).keyStorePassword
doReturn(crlCheckSoftFail).whenever(it).crlCheckSoftFail // doReturn(crlCheckSoftFail).whenever(it).crlCheckSoftFail
} // }
serverConfig.configureWithDevSSLCertificate() // serverConfig.configureWithDevSSLCertificate()
val nodeCert = serverConfig.recreateNodeCaAndTlsCertificates(nodeCrlDistPoint, tlsCrlDistPoint) // val nodeCert = serverConfig.recreateNodeCaAndTlsCertificates(nodeCrlDistPoint, tlsCrlDistPoint)
val serverTruststore = serverConfig.loadTrustStore().internal // val serverTruststore = serverConfig.loadTrustStore().internal
val serverKeystore = serverConfig.loadSslKeyStore().internal // val serverKeystore = serverConfig.loadSslKeyStore().internal
return Pair(AMQPServer( // return Pair(AMQPServer(
"0.0.0.0", // "0.0.0.0",
port, // port,
PEER_USER, // PEER_USER,
PEER_USER, // PEER_USER,
serverKeystore, // serverKeystore,
serverConfig.keyStorePassword, // serverConfig.keyStorePassword,
serverTruststore, // serverTruststore,
crlCheckSoftFail, // crlCheckSoftFail,
maxMessageSize = maxMessageSize), nodeCert) // maxMessageSize = maxMessageSize), nodeCert)
} // }
//
private fun SSLConfiguration.recreateNodeCaAndTlsCertificates(nodeCaCrlDistPoint: String, tlsCrlDistPoint: String?): X509Certificate { // private fun SSLConfiguration.recreateNodeCaAndTlsCertificates(nodeCaCrlDistPoint: String, tlsCrlDistPoint: String?): X509Certificate {
val nodeKeyStore = loadNodeKeyStore() // val nodeKeyStore = loadNodeKeyStore()
val (nodeCert, nodeKeys) = nodeKeyStore.getCertificateAndKeyPair(X509Utilities.CORDA_CLIENT_CA) // val (nodeCert, nodeKeys) = nodeKeyStore.getCertificateAndKeyPair(X509Utilities.CORDA_CLIENT_CA)
val newNodeCert = replaceCrlDistPointCaCertificate(nodeCert, CertificateType.NODE_CA, INTERMEDIATE_CA.keyPair, nodeCaCrlDistPoint) // val newNodeCert = replaceCrlDistPointCaCertificate(nodeCert, CertificateType.NODE_CA, INTERMEDIATE_CA.keyPair, nodeCaCrlDistPoint)
val nodeCertChain = listOf(newNodeCert, INTERMEDIATE_CA.certificate, *nodeKeyStore.getCertificateChain(X509Utilities.CORDA_CLIENT_CA).drop(2).toTypedArray()) // val nodeCertChain = listOf(newNodeCert, INTERMEDIATE_CA.certificate, *nodeKeyStore.getCertificateChain(X509Utilities.CORDA_CLIENT_CA).drop(2).toTypedArray())
nodeKeyStore.internal.deleteEntry(X509Utilities.CORDA_CLIENT_CA) // nodeKeyStore.internal.deleteEntry(X509Utilities.CORDA_CLIENT_CA)
nodeKeyStore.save() // nodeKeyStore.save()
nodeKeyStore.update { // nodeKeyStore.update {
setPrivateKey(X509Utilities.CORDA_CLIENT_CA, nodeKeys.private, nodeCertChain) // setPrivateKey(X509Utilities.CORDA_CLIENT_CA, nodeKeys.private, nodeCertChain)
} // }
val sslKeyStore = loadSslKeyStore() // val sslKeyStore = loadSslKeyStore()
val (tlsCert, tlsKeys) = sslKeyStore.getCertificateAndKeyPair(X509Utilities.CORDA_CLIENT_TLS) // val (tlsCert, tlsKeys) = sslKeyStore.getCertificateAndKeyPair(X509Utilities.CORDA_CLIENT_TLS)
val newTlsCert = replaceCrlDistPointCaCertificate(tlsCert, CertificateType.TLS, nodeKeys, tlsCrlDistPoint, X500Name.getInstance(ROOT_CA.certificate.subjectX500Principal.encoded)) // val newTlsCert = replaceCrlDistPointCaCertificate(tlsCert, CertificateType.TLS, nodeKeys, tlsCrlDistPoint, X500Name.getInstance(ROOT_CA.certificate.subjectX500Principal.encoded))
val sslCertChain = listOf(newTlsCert, newNodeCert, INTERMEDIATE_CA.certificate, *sslKeyStore.getCertificateChain(X509Utilities.CORDA_CLIENT_TLS).drop(3).toTypedArray()) // val sslCertChain = listOf(newTlsCert, newNodeCert, INTERMEDIATE_CA.certificate, *sslKeyStore.getCertificateChain(X509Utilities.CORDA_CLIENT_TLS).drop(3).toTypedArray())
sslKeyStore.internal.deleteEntry(X509Utilities.CORDA_CLIENT_TLS) // sslKeyStore.internal.deleteEntry(X509Utilities.CORDA_CLIENT_TLS)
sslKeyStore.save() // sslKeyStore.save()
sslKeyStore.update { // sslKeyStore.update {
setPrivateKey(X509Utilities.CORDA_CLIENT_TLS, tlsKeys.private, sslCertChain) // setPrivateKey(X509Utilities.CORDA_CLIENT_TLS, tlsKeys.private, sslCertChain)
} // }
return newNodeCert // return newNodeCert
} // }
//
private fun replaceCrlDistPointCaCertificate(currentCaCert: X509Certificate, certType: CertificateType, issuerKeyPair: KeyPair, crlDistPoint: String?, crlIssuer: X500Name? = null): X509Certificate { // private fun replaceCrlDistPointCaCertificate(currentCaCert: X509Certificate, certType: CertificateType, issuerKeyPair: KeyPair, crlDistPoint: String?, crlIssuer: X500Name? = null): X509Certificate {
val signatureScheme = Crypto.findSignatureScheme(issuerKeyPair.private) // val signatureScheme = Crypto.findSignatureScheme(issuerKeyPair.private)
val provider = Crypto.findProvider(signatureScheme.providerName) // val provider = Crypto.findProvider(signatureScheme.providerName)
val issuerSigner = ContentSignerBuilder.build(signatureScheme, issuerKeyPair.private, provider) // val issuerSigner = ContentSignerBuilder.build(signatureScheme, issuerKeyPair.private, provider)
val builder = X509Utilities.createPartialCertificate( // val builder = X509Utilities.createPartialCertificate(
certType, // certType,
currentCaCert.issuerX500Principal, // currentCaCert.issuerX500Principal,
issuerKeyPair.public, // issuerKeyPair.public,
currentCaCert.subjectX500Principal, // currentCaCert.subjectX500Principal,
currentCaCert.publicKey, // currentCaCert.publicKey,
Pair(Date(System.currentTimeMillis() - 5.minutes.toMillis()), Date(System.currentTimeMillis() + 10.days.toMillis())), // Pair(Date(System.currentTimeMillis() - 5.minutes.toMillis()), Date(System.currentTimeMillis() + 10.days.toMillis())),
null // null
) // )
crlDistPoint?.let { // crlDistPoint?.let {
val distPointName = DistributionPointName(GeneralNames(GeneralName(GeneralName.uniformResourceIdentifier, it))) // val distPointName = DistributionPointName(GeneralNames(GeneralName(GeneralName.uniformResourceIdentifier, it)))
val crlIssuerGeneralNames = crlIssuer?.let { // val crlIssuerGeneralNames = crlIssuer?.let {
GeneralNames(GeneralName(crlIssuer)) // GeneralNames(GeneralName(crlIssuer))
} // }
val distPoint = DistributionPoint(distPointName, null, crlIssuerGeneralNames) // val distPoint = DistributionPoint(distPointName, null, crlIssuerGeneralNames)
builder.addExtension(Extension.cRLDistributionPoints, false, CRLDistPoint(arrayOf(distPoint))) // builder.addExtension(Extension.cRLDistributionPoints, false, CRLDistPoint(arrayOf(distPoint)))
} // }
return builder.build(issuerSigner).toJca() // return builder.build(issuerSigner).toJca()
} // }
//
@Path("crl") // @Path("crl")
inner class CrlServlet(private val server: CrlServer) { // inner class CrlServlet(private val server: CrlServer) {
//
private val SIGNATURE_ALGORITHM = "SHA256withECDSA" // private val SIGNATURE_ALGORITHM = "SHA256withECDSA"
private val NODE_CRL = "node.crl" // private val NODE_CRL = "node.crl"
private val INTEMEDIATE_CRL = "intermediate.crl" // private val INTEMEDIATE_CRL = "intermediate.crl"
private val EMPTY_CRL = "empty.crl" // private val EMPTY_CRL = "empty.crl"
//
@GET // @GET
@Path("node.crl") // @Path("node.crl")
@Produces("application/pkcs7-crl") // @Produces("application/pkcs7-crl")
fun getNodeCRL(): Response { // fun getNodeCRL(): Response {
return Response.ok(createRevocationList( // return Response.ok(createRevocationList(
INTERMEDIATE_CA.certificate, // INTERMEDIATE_CA.certificate,
INTERMEDIATE_CA.keyPair.private, // INTERMEDIATE_CA.keyPair.private,
NODE_CRL, // NODE_CRL,
false, // false,
*revokedNodeCerts.toTypedArray()).encoded).build() // *revokedNodeCerts.toTypedArray()).encoded).build()
} // }
//
@GET // @GET
@Path("intermediate.crl") // @Path("intermediate.crl")
@Produces("application/pkcs7-crl") // @Produces("application/pkcs7-crl")
fun getIntermediateCRL(): Response { // fun getIntermediateCRL(): Response {
return Response.ok(createRevocationList( // return Response.ok(createRevocationList(
ROOT_CA.certificate, // ROOT_CA.certificate,
ROOT_CA.keyPair.private, // ROOT_CA.keyPair.private,
INTEMEDIATE_CRL, // INTEMEDIATE_CRL,
false, // false,
*revokedIntermediateCerts.toTypedArray()).encoded).build() // *revokedIntermediateCerts.toTypedArray()).encoded).build()
} // }
//
@GET // @GET
@Path("empty.crl") // @Path("empty.crl")
@Produces("application/pkcs7-crl") // @Produces("application/pkcs7-crl")
fun getEmptyCRL(): Response { // fun getEmptyCRL(): Response {
return Response.ok(createRevocationList( // return Response.ok(createRevocationList(
ROOT_CA.certificate, // ROOT_CA.certificate,
ROOT_CA.keyPair.private, // ROOT_CA.keyPair.private,
EMPTY_CRL, true).encoded).build() // EMPTY_CRL, true).encoded).build()
} // }
//
private fun createRevocationList(caCertificate: X509Certificate, // private fun createRevocationList(caCertificate: X509Certificate,
caPrivateKey: PrivateKey, // caPrivateKey: PrivateKey,
endpoint: String, // endpoint: String,
indirect: Boolean, // indirect: Boolean,
vararg serialNumbers: BigInteger): X509CRL { // vararg serialNumbers: BigInteger): X509CRL {
println("Generating CRL for $endpoint") // println("Generating CRL for $endpoint")
val builder = JcaX509v2CRLBuilder(caCertificate.subjectX500Principal, Date(System.currentTimeMillis() - 1.minutes.toMillis())) // val builder = JcaX509v2CRLBuilder(caCertificate.subjectX500Principal, Date(System.currentTimeMillis() - 1.minutes.toMillis()))
val extensionUtils = JcaX509ExtensionUtils() // val extensionUtils = JcaX509ExtensionUtils()
builder.addExtension(Extension.authorityKeyIdentifier, // builder.addExtension(Extension.authorityKeyIdentifier,
false, extensionUtils.createAuthorityKeyIdentifier(caCertificate)) // false, extensionUtils.createAuthorityKeyIdentifier(caCertificate))
val issuingDistPointName = GeneralName( // val issuingDistPointName = GeneralName(
GeneralName.uniformResourceIdentifier, // GeneralName.uniformResourceIdentifier,
"http://${server.hostAndPort.host}:${server.hostAndPort.port}/crl/$endpoint") // "http://${server.hostAndPort.host}:${server.hostAndPort.port}/crl/$endpoint")
// This is required and needs to match the certificate settings with respect to being indirect // // This is required and needs to match the certificate settings with respect to being indirect
val issuingDistPoint = IssuingDistributionPoint(DistributionPointName(GeneralNames(issuingDistPointName)), indirect, false) // val issuingDistPoint = IssuingDistributionPoint(DistributionPointName(GeneralNames(issuingDistPointName)), indirect, false)
builder.addExtension(Extension.issuingDistributionPoint, true, issuingDistPoint) // builder.addExtension(Extension.issuingDistributionPoint, true, issuingDistPoint)
builder.setNextUpdate(Date(System.currentTimeMillis() + 1.seconds.toMillis())) // builder.setNextUpdate(Date(System.currentTimeMillis() + 1.seconds.toMillis()))
serialNumbers.forEach { // serialNumbers.forEach {
builder.addCRLEntry(it, Date(System.currentTimeMillis() - 10.minutes.toMillis()), ReasonFlags.certificateHold) // builder.addCRLEntry(it, Date(System.currentTimeMillis() - 10.minutes.toMillis()), ReasonFlags.certificateHold)
} // }
val signer = JcaContentSignerBuilder(SIGNATURE_ALGORITHM).setProvider(BouncyCastleProvider.PROVIDER_NAME).build(caPrivateKey) // val signer = JcaContentSignerBuilder(SIGNATURE_ALGORITHM).setProvider(BouncyCastleProvider.PROVIDER_NAME).build(caPrivateKey)
return JcaX509CRLConverter().setProvider(BouncyCastleProvider.PROVIDER_NAME).getCRL(builder.build(signer)) // return JcaX509CRLConverter().setProvider(BouncyCastleProvider.PROVIDER_NAME).getCRL(builder.build(signer))
} // }
} // }
//
inner class CrlServer(hostAndPort: NetworkHostAndPort) : Closeable { // inner class CrlServer(hostAndPort: NetworkHostAndPort) : Closeable {
//
private val server: Server = Server(InetSocketAddress(hostAndPort.host, hostAndPort.port)).apply { // private val server: Server = Server(InetSocketAddress(hostAndPort.host, hostAndPort.port)).apply {
handler = HandlerCollection().apply { // handler = HandlerCollection().apply {
addHandler(buildServletContextHandler()) // addHandler(buildServletContextHandler())
} // }
} // }
//
val hostAndPort: NetworkHostAndPort // val hostAndPort: NetworkHostAndPort
get() = server.connectors.mapNotNull { it as? ServerConnector } // get() = server.connectors.mapNotNull { it as? ServerConnector }
.map { NetworkHostAndPort(it.host, it.localPort) } // .map { NetworkHostAndPort(it.host, it.localPort) }
.first() // .first()
//
override fun close() { // override fun close() {
println("Shutting down network management web services...") // println("Shutting down network management web services...")
server.stop() // server.stop()
server.join() // server.join()
} // }
//
fun start() { // fun start() {
server.start() // server.start()
println("Network management web services started on $hostAndPort") // println("Network management web services started on $hostAndPort")
} // }
//
private fun buildServletContextHandler(): ServletContextHandler { // private fun buildServletContextHandler(): ServletContextHandler {
val crlServer = this // val crlServer = this
return ServletContextHandler().apply { // return ServletContextHandler().apply {
contextPath = "/" // contextPath = "/"
val resourceConfig = ResourceConfig().apply { // val resourceConfig = ResourceConfig().apply {
register(CrlServlet(crlServer)) // register(CrlServlet(crlServer))
} // }
val jerseyServlet = ServletHolder(ServletContainer(resourceConfig)).apply { initOrder = 0 } // val jerseyServlet = ServletHolder(ServletContainer(resourceConfig)).apply { initOrder = 0 }
addServlet(jerseyServlet, "/*") // addServlet(jerseyServlet, "/*")
} // }
} // }
} // }
} //}

View File

@ -2,7 +2,13 @@ package net.corda.node.services
import com.nhaarman.mockito_kotlin.doReturn import com.nhaarman.mockito_kotlin.doReturn
import com.nhaarman.mockito_kotlin.whenever import com.nhaarman.mockito_kotlin.whenever
import net.corda.core.contracts.* import net.corda.core.CordaRuntimeException
import net.corda.core.contracts.Contract
import net.corda.core.contracts.ContractState
import net.corda.core.contracts.PartyAndReference
import net.corda.core.contracts.StateAndRef
import net.corda.core.contracts.StateRef
import net.corda.core.contracts.TransactionState
import net.corda.core.cordapp.CordappProvider import net.corda.core.cordapp.CordappProvider
import net.corda.core.flows.FlowLogic import net.corda.core.flows.FlowLogic
import net.corda.core.identity.CordaX500Name import net.corda.core.identity.CordaX500Name
@ -22,13 +28,13 @@ import net.corda.core.utilities.contextLogger
import net.corda.core.utilities.getOrThrow import net.corda.core.utilities.getOrThrow
import net.corda.node.internal.cordapp.CordappLoader import net.corda.node.internal.cordapp.CordappLoader
import net.corda.node.internal.cordapp.CordappProviderImpl import net.corda.node.internal.cordapp.CordappProviderImpl
import net.corda.nodeapi.exceptions.InternalNodeException
import net.corda.testing.common.internal.testNetworkParameters import net.corda.testing.common.internal.testNetworkParameters
import net.corda.testing.core.DUMMY_BANK_A_NAME import net.corda.testing.core.DUMMY_BANK_A_NAME
import net.corda.testing.core.DUMMY_NOTARY_NAME import net.corda.testing.core.DUMMY_NOTARY_NAME
import net.corda.testing.core.SerializationEnvironmentRule import net.corda.testing.core.SerializationEnvironmentRule
import net.corda.testing.core.TestIdentity import net.corda.testing.core.TestIdentity
import net.corda.testing.driver.DriverDSL import net.corda.testing.driver.DriverDSL
import net.corda.testing.driver.DriverParameters
import net.corda.testing.driver.NodeHandle import net.corda.testing.driver.NodeHandle
import net.corda.testing.driver.driver import net.corda.testing.driver.driver
import net.corda.testing.internal.MockCordappConfigProvider import net.corda.testing.internal.MockCordappConfigProvider
@ -107,10 +113,10 @@ class AttachmentLoadingTests {
@Test @Test
fun `test that attachments retrieved over the network are not used for code`() = withoutTestSerialization { fun `test that attachments retrieved over the network are not used for code`() = withoutTestSerialization {
driver { driver(DriverParameters(startNodesInProcess = true)) {
installIsolatedCordappTo(bankAName) installIsolatedCordappTo(bankAName)
val (bankA, bankB) = createTwoNodes() val (bankA, bankB) = createTwoNodes()
assertFailsWith<InternalNodeException> { assertFailsWith<CordaRuntimeException>("Party C=CH,L=Zurich,O=BankB rejected session request: Don't know net.corda.finance.contracts.isolated.IsolatedDummyFlow\$Initiator") {
bankA.rpc.startFlowDynamic(flowInitiatorClass, bankB.nodeInfo.legalIdentities.first()).returnValue.getOrThrow() bankA.rpc.startFlowDynamic(flowInitiatorClass, bankB.nodeInfo.legalIdentities.first()).returnValue.getOrThrow()
} }
} }

View File

@ -2,13 +2,13 @@ package net.corda.node.services.rpc
import co.paralleluniverse.fibers.Suspendable import co.paralleluniverse.fibers.Suspendable
import net.corda.ClientRelevantException import net.corda.ClientRelevantException
import net.corda.core.CordaRuntimeException
import net.corda.core.flows.* import net.corda.core.flows.*
import net.corda.core.identity.Party import net.corda.core.identity.Party
import net.corda.core.messaging.startFlow import net.corda.core.messaging.startFlow
import net.corda.core.utilities.getOrThrow import net.corda.core.utilities.getOrThrow
import net.corda.core.utilities.unwrap import net.corda.core.utilities.unwrap
import net.corda.node.services.Permissions import net.corda.node.services.Permissions
import net.corda.nodeapi.exceptions.InternalNodeException
import net.corda.testing.core.ALICE_NAME import net.corda.testing.core.ALICE_NAME
import net.corda.testing.core.BOB_NAME import net.corda.testing.core.BOB_NAME
import net.corda.testing.core.singleIdentity import net.corda.testing.core.singleIdentity
@ -16,7 +16,7 @@ import net.corda.testing.driver.DriverParameters
import net.corda.testing.driver.NodeParameters import net.corda.testing.driver.NodeParameters
import net.corda.testing.driver.driver import net.corda.testing.driver.driver
import net.corda.testing.node.User import net.corda.testing.node.User
import org.assertj.core.api.Assertions.assertThatCode import org.assertj.core.api.Assertions.assertThatThrownBy
import org.assertj.core.api.AssertionsForInterfaceTypes.assertThat import org.assertj.core.api.AssertionsForInterfaceTypes.assertThat
import org.hibernate.exception.GenericJDBCException import org.hibernate.exception.GenericJDBCException
import org.junit.Test import org.junit.Test
@ -33,11 +33,10 @@ class RpcExceptionHandlingTest {
val node = startNode(NodeParameters(rpcUsers = users)).getOrThrow() val node = startNode(NodeParameters(rpcUsers = users)).getOrThrow()
assertThatCode { node.rpc.startFlow(::Flow).returnValue.getOrThrow() }.isInstanceOfSatisfying(InternalNodeException::class.java) { exception -> assertThatThrownBy { node.rpc.startFlow(::Flow).returnValue.getOrThrow() }.isInstanceOfSatisfying(CordaRuntimeException::class.java) { exception ->
assertThat(exception).hasNoCause() assertThat(exception).hasNoCause()
assertThat(exception.stackTrace).isEmpty() assertThat(exception.stackTrace).isEmpty()
assertThat(exception.message).isEqualTo(InternalNodeException.defaultMessage())
} }
} }
} }
@ -49,7 +48,7 @@ class RpcExceptionHandlingTest {
val node = startNode(NodeParameters(rpcUsers = users)).getOrThrow() val node = startNode(NodeParameters(rpcUsers = users)).getOrThrow()
val clientRelevantMessage = "This is for the players!" val clientRelevantMessage = "This is for the players!"
assertThatCode { node.rpc.startFlow(::ClientRelevantErrorFlow, clientRelevantMessage).returnValue.getOrThrow() }.isInstanceOfSatisfying(ClientRelevantException::class.java) { exception -> assertThatThrownBy { node.rpc.startFlow(::ClientRelevantErrorFlow, clientRelevantMessage).returnValue.getOrThrow() }.isInstanceOfSatisfying(CordaRuntimeException::class.java) { exception ->
assertThat(exception).hasNoCause() assertThat(exception).hasNoCause()
assertThat(exception.stackTrace).isEmpty() assertThat(exception.stackTrace).isEmpty()
@ -63,7 +62,7 @@ class RpcExceptionHandlingTest {
driver(DriverParameters(startNodesInProcess = true, notarySpecs = emptyList())) { driver(DriverParameters(startNodesInProcess = true, notarySpecs = emptyList())) {
val node = startNode(NodeParameters(rpcUsers = users)).getOrThrow() val node = startNode(NodeParameters(rpcUsers = users)).getOrThrow()
val exceptionMessage = "Flow error!" val exceptionMessage = "Flow error!"
assertThatCode { node.rpc.startFlow(::FlowExceptionFlow, exceptionMessage).returnValue.getOrThrow() } assertThatThrownBy { node.rpc.startFlow(::FlowExceptionFlow, exceptionMessage).returnValue.getOrThrow() }
.isInstanceOfSatisfying(FlowException::class.java) { exception -> .isInstanceOfSatisfying(FlowException::class.java) { exception ->
assertThat(exception).hasNoCause() assertThat(exception).hasNoCause()
assertThat(exception.stackTrace).isEmpty() assertThat(exception.stackTrace).isEmpty()
@ -79,12 +78,10 @@ class RpcExceptionHandlingTest {
val nodeA = startNode(NodeParameters(providedName = ALICE_NAME, rpcUsers = users)).getOrThrow() val nodeA = startNode(NodeParameters(providedName = ALICE_NAME, rpcUsers = users)).getOrThrow()
val nodeB = startNode(NodeParameters(providedName = BOB_NAME, rpcUsers = users)).getOrThrow() val nodeB = startNode(NodeParameters(providedName = BOB_NAME, rpcUsers = users)).getOrThrow()
assertThatCode { nodeA.rpc.startFlow(::InitFlow, nodeB.nodeInfo.singleIdentity()).returnValue.getOrThrow() } assertThatThrownBy { nodeA.rpc.startFlow(::InitFlow, nodeB.nodeInfo.singleIdentity()).returnValue.getOrThrow() }.isInstanceOfSatisfying(CordaRuntimeException::class.java) { exception ->
.isInstanceOfSatisfying(InternalNodeException::class.java) { exception ->
assertThat(exception).hasNoCause() assertThat(exception).hasNoCause()
assertThat(exception.stackTrace).isEmpty() assertThat(exception.stackTrace).isEmpty()
assertThat(exception.message).isEqualTo(InternalNodeException.defaultMessage())
} }
} }
} }

View File

@ -37,6 +37,8 @@ import net.corda.node.internal.cordapp.CordappConfigFileProvider
import net.corda.node.internal.cordapp.CordappLoader import net.corda.node.internal.cordapp.CordappLoader
import net.corda.node.internal.cordapp.CordappProviderImpl import net.corda.node.internal.cordapp.CordappProviderImpl
import net.corda.node.internal.cordapp.CordappProviderInternal import net.corda.node.internal.cordapp.CordappProviderInternal
import net.corda.node.internal.rpc.proxies.AuthenticatedRpcOpsProxy
import net.corda.node.internal.rpc.proxies.ExceptionSerialisingRpcOpsProxy
import net.corda.node.internal.security.RPCSecurityManager import net.corda.node.internal.security.RPCSecurityManager
import net.corda.node.services.ContractUpgradeHandler import net.corda.node.services.ContractUpgradeHandler
import net.corda.node.services.FinalityHandler import net.corda.node.services.FinalityHandler
@ -168,7 +170,10 @@ abstract class AbstractNode(val configuration: NodeConfiguration,
/** The implementation of the [CordaRPCOps] interface used by this node. */ /** The implementation of the [CordaRPCOps] interface used by this node. */
open fun makeRPCOps(flowStarter: FlowStarter, database: CordaPersistence, smm: StateMachineManager): CordaRPCOps { open fun makeRPCOps(flowStarter: FlowStarter, database: CordaPersistence, smm: StateMachineManager): CordaRPCOps {
return SecureCordaRPCOps(services, smm, database, flowStarter, { shutdownExecutor.submit { stop() } }) val ops: CordaRPCOps = CordaRPCOpsImpl(services, smm, database, flowStarter, { shutdownExecutor.submit { stop() } })
// Mind that order is relevant here.
val proxies = listOf(::AuthenticatedRpcOpsProxy, ::ExceptionSerialisingRpcOpsProxy)
return proxies.fold(ops) { delegate, decorate -> decorate(delegate) }
} }
private fun initCertificate() { private fun initCertificate() {

View File

@ -25,12 +25,12 @@ import net.corda.core.node.services.Vault
import net.corda.core.node.services.vault.* import net.corda.core.node.services.vault.*
import net.corda.core.serialization.serialize import net.corda.core.serialization.serialize
import net.corda.core.transactions.SignedTransaction import net.corda.core.transactions.SignedTransaction
import net.corda.core.utilities.contextLogger
import net.corda.core.utilities.getOrThrow import net.corda.core.utilities.getOrThrow
import net.corda.node.services.api.FlowStarter import net.corda.node.services.api.FlowStarter
import net.corda.node.services.api.ServiceHubInternal import net.corda.node.services.api.ServiceHubInternal
import net.corda.node.services.messaging.context import net.corda.node.services.messaging.context
import net.corda.node.services.statemachine.StateMachineManager import net.corda.node.services.statemachine.StateMachineManager
import net.corda.nodeapi.exceptions.NonRpcFlowException
import net.corda.nodeapi.exceptions.RejectedCommandException import net.corda.nodeapi.exceptions.RejectedCommandException
import net.corda.nodeapi.internal.persistence.CordaPersistence import net.corda.nodeapi.internal.persistence.CordaPersistence
import rx.Observable import rx.Observable
@ -173,7 +173,7 @@ internal class CordaRPCOpsImpl(
} }
private fun <T> startFlow(logicType: Class<out FlowLogic<T>>, args: Array<out Any?>): FlowStateMachine<T> { private fun <T> startFlow(logicType: Class<out FlowLogic<T>>, args: Array<out Any?>): FlowStateMachine<T> {
require(logicType.isAnnotationPresent(StartableByRPC::class.java)) { "${logicType.name} was not designed for RPC" } if (!logicType.isAnnotationPresent(StartableByRPC::class.java)) throw NonRpcFlowException(logicType)
if (isFlowsDrainingModeEnabled()) { if (isFlowsDrainingModeEnabled()) {
throw RejectedCommandException("Node is draining before shutdown. Cannot start new flows through RPC.") throw RejectedCommandException("Node is draining before shutdown. Cannot start new flows through RPC.")
} }
@ -325,8 +325,4 @@ internal class CordaRPCOpsImpl(
is InvocationOrigin.Scheduled -> FlowInitiator.Scheduled((origin as InvocationOrigin.Scheduled).scheduledState) is InvocationOrigin.Scheduled -> FlowInitiator.Scheduled((origin as InvocationOrigin.Scheduled).scheduledState)
} }
} }
companion object {
private val log = contextLogger()
}
} }

View File

@ -0,0 +1,21 @@
package net.corda.node.internal
import java.lang.reflect.InvocationHandler
import java.lang.reflect.InvocationTargetException
import java.lang.reflect.Method
/**
* Helps writing correct [InvocationHandler]s.
*/
internal interface InvocationHandlerTemplate : InvocationHandler {
val delegate: Any
override fun invoke(proxy: Any, method: Method, arguments: Array<out Any?>?): Any? {
val args = arguments ?: emptyArray()
return try {
method.invoke(delegate, *args)
} catch (e: InvocationTargetException) {
throw e.targetException
}
}
}

View File

@ -82,7 +82,7 @@ open class Node(configuration: NodeConfiguration,
fun printWarning(message: String) { fun printWarning(message: String) {
Emoji.renderIfSupported { Emoji.renderIfSupported {
println("${Emoji.warningSign} ATTENTION: ${message}") println("${Emoji.warningSign} ATTENTION: $message")
} }
staticLog.warn(message) staticLog.warn(message)
} }
@ -183,7 +183,7 @@ open class Node(configuration: NodeConfiguration,
val serverAddress = configuration.messagingServerAddress val serverAddress = configuration.messagingServerAddress
?: NetworkHostAndPort("localhost", configuration.p2pAddress.port) ?: NetworkHostAndPort("localhost", configuration.p2pAddress.port)
val rpcServerAddresses = if (configuration.rpcOptions.standAloneBroker) { val rpcServerAddresses = if (configuration.rpcOptions.standAloneBroker) {
BrokerAddresses(configuration.rpcOptions.address!!, configuration.rpcOptions.adminAddress) BrokerAddresses(configuration.rpcOptions.address, configuration.rpcOptions.adminAddress)
} else { } else {
startLocalRpcBroker() startLocalRpcBroker()
} }
@ -293,10 +293,7 @@ open class Node(configuration: NodeConfiguration,
// Start up the MQ clients. // Start up the MQ clients.
internalRpcMessagingClient?.run { internalRpcMessagingClient?.run {
runOnStop += this::close runOnStop += this::close
when (rpcOps) { init(rpcOps, securityManager)
is SecureCordaRPCOps -> init(RpcExceptionHandlingProxy(rpcOps), securityManager)
else -> init(rpcOps, securityManager)
}
} }
verifierMessagingClient?.run { verifierMessagingClient?.run {
runOnStop += this::stop runOnStop += this::stop
@ -385,7 +382,7 @@ open class Node(configuration: NodeConfiguration,
SerializationFactoryImpl().apply { SerializationFactoryImpl().apply {
registerScheme(AMQPServerSerializationScheme(cordappLoader.cordapps)) registerScheme(AMQPServerSerializationScheme(cordappLoader.cordapps))
registerScheme(AMQPClientSerializationScheme(cordappLoader.cordapps)) registerScheme(AMQPClientSerializationScheme(cordappLoader.cordapps))
registerScheme(KryoServerSerializationScheme() ) registerScheme(KryoServerSerializationScheme())
}, },
p2pContext = AMQP_P2P_CONTEXT.withClassLoader(classloader), p2pContext = AMQP_P2P_CONTEXT.withClassLoader(classloader),
rpcServerContext = AMQP_RPC_SERVER_CONTEXT.withClassLoader(classloader), rpcServerContext = AMQP_RPC_SERVER_CONTEXT.withClassLoader(classloader),

View File

@ -1,191 +0,0 @@
package net.corda.node.internal
import net.corda.client.rpc.PermissionException
import net.corda.core.contracts.ContractState
import net.corda.core.crypto.SecureHash
import net.corda.core.flows.FlowLogic
import net.corda.core.flows.StateMachineRunId
import net.corda.core.identity.AbstractParty
import net.corda.core.identity.CordaX500Name
import net.corda.core.identity.Party
import net.corda.core.messaging.CordaRPCOps
import net.corda.core.messaging.DataFeed
import net.corda.core.messaging.ParametersUpdateInfo
import net.corda.core.node.NodeInfo
import net.corda.core.node.services.AttachmentId
import net.corda.core.node.services.NetworkMapCache
import net.corda.core.node.services.Vault
import net.corda.core.node.services.vault.*
import net.corda.node.services.messaging.RpcAuthContext
import java.io.InputStream
import java.security.PublicKey
// TODO change to KFunction reference after Kotlin fixes https://youtrack.jetbrains.com/issue/KT-12140
class RpcAuthorisationProxy(private val implementation: CordaRPCOps, private val context: () -> RpcAuthContext) : CordaRPCOps {
override fun networkParametersFeed(): DataFeed<ParametersUpdateInfo?, ParametersUpdateInfo> = guard("networkParametersFeed") {
implementation.networkParametersFeed()
}
override fun acceptNewNetworkParameters(parametersHash: SecureHash) = guard("acceptNewNetworkParameters") {
implementation.acceptNewNetworkParameters(parametersHash)
}
override fun uploadAttachmentWithMetadata(jar: InputStream, uploader: String, filename: String): SecureHash = guard("uploadAttachmentWithMetadata") {
implementation.uploadAttachmentWithMetadata(jar, uploader, filename)
}
override fun queryAttachments(query: AttachmentQueryCriteria, sorting: AttachmentSort?): List<AttachmentId> = guard("queryAttachments") {
implementation.queryAttachments(query, sorting)
}
override fun stateMachinesSnapshot() = guard("stateMachinesSnapshot") {
implementation.stateMachinesSnapshot()
}
override fun stateMachinesFeed() = guard("stateMachinesFeed") {
implementation.stateMachinesFeed()
}
override fun <T : ContractState> vaultQueryBy(criteria: QueryCriteria, paging: PageSpecification, sorting: Sort, contractStateType: Class<out T>) = guard("vaultQueryBy") {
implementation.vaultQueryBy(criteria, paging, sorting, contractStateType)
}
override fun <T : ContractState> vaultTrackBy(criteria: QueryCriteria, paging: PageSpecification, sorting: Sort, contractStateType: Class<out T>) = guard("vaultTrackBy") {
implementation.vaultTrackBy(criteria, paging, sorting, contractStateType)
}
@Suppress("DEPRECATION", "OverridingDeprecatedMember")
override fun internalVerifiedTransactionsSnapshot() = guard("internalVerifiedTransactionsSnapshot", implementation::internalVerifiedTransactionsSnapshot)
@Suppress("DEPRECATION", "OverridingDeprecatedMember")
override fun internalVerifiedTransactionsFeed() = guard("internalVerifiedTransactionsFeed", implementation::internalVerifiedTransactionsFeed)
override fun stateMachineRecordedTransactionMappingSnapshot() = guard("stateMachineRecordedTransactionMappingSnapshot", implementation::stateMachineRecordedTransactionMappingSnapshot)
override fun stateMachineRecordedTransactionMappingFeed() = guard("stateMachineRecordedTransactionMappingFeed", implementation::stateMachineRecordedTransactionMappingFeed)
override fun networkMapSnapshot(): List<NodeInfo> = guard("networkMapSnapshot", implementation::networkMapSnapshot)
override fun networkMapFeed(): DataFeed<List<NodeInfo>, NetworkMapCache.MapChange> = guard("networkMapFeed", implementation::networkMapFeed)
override fun <T> startFlowDynamic(logicType: Class<out FlowLogic<T>>, vararg args: Any?) = guard("startFlowDynamic", listOf(logicType)) {
implementation.startFlowDynamic(logicType, *args)
}
override fun <T> startTrackedFlowDynamic(logicType: Class<out FlowLogic<T>>, vararg args: Any?) = guard("startTrackedFlowDynamic", listOf(logicType)) {
implementation.startTrackedFlowDynamic(logicType, *args)
}
override fun killFlow(id: StateMachineRunId): Boolean = guard("killFlow") {
return implementation.killFlow(id)
}
override fun nodeInfo(): NodeInfo = guard("nodeInfo", implementation::nodeInfo)
override fun notaryIdentities(): List<Party> = guard("notaryIdentities", implementation::notaryIdentities)
override fun addVaultTransactionNote(txnId: SecureHash, txnNote: String) = guard("addVaultTransactionNote") {
implementation.addVaultTransactionNote(txnId, txnNote)
}
override fun getVaultTransactionNotes(txnId: SecureHash): Iterable<String> = guard("getVaultTransactionNotes") {
implementation.getVaultTransactionNotes(txnId)
}
override fun attachmentExists(id: SecureHash) = guard("attachmentExists") {
implementation.attachmentExists(id)
}
override fun openAttachment(id: SecureHash) = guard("openAttachment") {
implementation.openAttachment(id)
}
override fun uploadAttachment(jar: InputStream) = guard("uploadAttachment") {
implementation.uploadAttachment(jar)
}
override fun currentNodeTime() = guard("currentNodeTime", implementation::currentNodeTime)
override fun waitUntilNetworkReady() = guard("waitUntilNetworkReady", implementation::waitUntilNetworkReady)
override fun wellKnownPartyFromAnonymous(party: AbstractParty) = guard("wellKnownPartyFromAnonymous") {
implementation.wellKnownPartyFromAnonymous(party)
}
override fun partyFromKey(key: PublicKey) = guard("partyFromKey") {
implementation.partyFromKey(key)
}
override fun wellKnownPartyFromX500Name(x500Name: CordaX500Name) = guard("wellKnownPartyFromX500Name") {
implementation.wellKnownPartyFromX500Name(x500Name)
}
override fun notaryPartyFromX500Name(x500Name: CordaX500Name) = guard("notaryPartyFromX500Name") {
implementation.notaryPartyFromX500Name(x500Name)
}
override fun partiesFromName(query: String, exactMatch: Boolean) = guard("partiesFromName") {
implementation.partiesFromName(query, exactMatch)
}
override fun registeredFlows() = guard("registeredFlows", implementation::registeredFlows)
override fun nodeInfoFromParty(party: AbstractParty) = guard("nodeInfoFromParty") {
implementation.nodeInfoFromParty(party)
}
override fun clearNetworkMapCache() = guard("clearNetworkMapCache", implementation::clearNetworkMapCache)
override fun <T : ContractState> vaultQuery(contractStateType: Class<out T>): Vault.Page<T> = guard("vaultQuery") {
implementation.vaultQuery(contractStateType)
}
override fun <T : ContractState> vaultQueryByCriteria(criteria: QueryCriteria, contractStateType: Class<out T>): Vault.Page<T> = guard("vaultQueryByCriteria") {
implementation.vaultQueryByCriteria(criteria, contractStateType)
}
override fun <T : ContractState> vaultQueryByWithPagingSpec(contractStateType: Class<out T>, criteria: QueryCriteria, paging: PageSpecification): Vault.Page<T> = guard("vaultQueryByWithPagingSpec") {
implementation.vaultQueryByWithPagingSpec(contractStateType, criteria, paging)
}
override fun <T : ContractState> vaultQueryByWithSorting(contractStateType: Class<out T>, criteria: QueryCriteria, sorting: Sort): Vault.Page<T> = guard("vaultQueryByWithSorting") {
implementation.vaultQueryByWithSorting(contractStateType, criteria, sorting)
}
override fun <T : ContractState> vaultTrack(contractStateType: Class<out T>): DataFeed<Vault.Page<T>, Vault.Update<T>> = guard("vaultTrack") {
implementation.vaultTrack(contractStateType)
}
override fun <T : ContractState> vaultTrackByCriteria(contractStateType: Class<out T>, criteria: QueryCriteria): DataFeed<Vault.Page<T>, Vault.Update<T>> = guard("vaultTrackByCriteria") {
implementation.vaultTrackByCriteria(contractStateType, criteria)
}
override fun <T : ContractState> vaultTrackByWithPagingSpec(contractStateType: Class<out T>, criteria: QueryCriteria, paging: PageSpecification): DataFeed<Vault.Page<T>, Vault.Update<T>> = guard("vaultTrackByWithPagingSpec") {
implementation.vaultTrackByWithPagingSpec(contractStateType, criteria, paging)
}
override fun <T : ContractState> vaultTrackByWithSorting(contractStateType: Class<out T>, criteria: QueryCriteria, sorting: Sort): DataFeed<Vault.Page<T>, Vault.Update<T>> = guard("vaultTrackByWithSorting") {
implementation.vaultTrackByWithSorting(contractStateType, criteria, sorting)
}
override fun setFlowsDrainingModeEnabled(enabled: Boolean) = guard("setFlowsDrainingModeEnabled") {
implementation.setFlowsDrainingModeEnabled(enabled)
}
override fun isFlowsDrainingModeEnabled(): Boolean = guard("isFlowsDrainingModeEnabled", implementation::isFlowsDrainingModeEnabled)
override fun shutdown() = guard("shutdown", implementation::shutdown)
// TODO change to KFunction reference after Kotlin fixes https://youtrack.jetbrains.com/issue/KT-12140
private inline fun <RESULT> guard(methodName: String, action: () -> RESULT) = guard(methodName, emptyList(), action)
// TODO change to KFunction reference after Kotlin fixes https://youtrack.jetbrains.com/issue/KT-12140
private inline fun <RESULT> guard(methodName: String, args: List<Class<*>>, action: () -> RESULT) : RESULT {
if (!context().isPermitted(methodName, *(args.map { it.name }.toTypedArray()))) {
throw PermissionException("User not authorized to perform RPC call $methodName with target $args")
}
else {
return action()
}
}
}

View File

@ -1,154 +0,0 @@
package net.corda.node.internal
import net.corda.core.concurrent.CordaFuture
import net.corda.core.contracts.ContractState
import net.corda.core.crypto.SecureHash
import net.corda.core.doOnError
import net.corda.core.flows.FlowLogic
import net.corda.core.flows.StateMachineRunId
import net.corda.core.identity.AbstractParty
import net.corda.core.identity.CordaX500Name
import net.corda.core.internal.concurrent.doOnError
import net.corda.core.internal.concurrent.mapError
import net.corda.core.mapErrors
import net.corda.core.messaging.CordaRPCOps
import net.corda.core.messaging.DataFeed
import net.corda.core.messaging.FlowHandle
import net.corda.core.messaging.FlowProgressHandle
import net.corda.core.node.services.vault.*
import net.corda.core.utilities.loggerFor
import net.corda.nodeapi.exceptions.InternalNodeException
import net.corda.nodeapi.exceptions.adapters.InternalObfuscatingFlowHandle
import net.corda.nodeapi.exceptions.adapters.InternalObfuscatingFlowProgressHandle
import java.io.InputStream
import java.security.PublicKey
class RpcExceptionHandlingProxy(private val delegate: SecureCordaRPCOps) : CordaRPCOps {
private companion object {
private val logger = loggerFor<RpcExceptionHandlingProxy>()
}
override val protocolVersion: Int get() = delegate.protocolVersion
override fun <T> startFlowDynamic(logicType: Class<out FlowLogic<T>>, vararg args: Any?): FlowHandle<T> = wrap {
val handle = delegate.startFlowDynamic(logicType, *args)
val result = InternalObfuscatingFlowHandle(handle)
result.returnValue.doOnError { error -> logger.error(error.message, error) }
result
}
override fun <T> startTrackedFlowDynamic(logicType: Class<out FlowLogic<T>>, vararg args: Any?): FlowProgressHandle<T> = wrap {
val handle = delegate.startTrackedFlowDynamic(logicType, *args)
val result = InternalObfuscatingFlowProgressHandle(handle)
result.returnValue.doOnError { error -> logger.error(error.message, error) }
result
}
override fun waitUntilNetworkReady() = wrapFuture(delegate::waitUntilNetworkReady)
override fun stateMachinesFeed() = wrapFeed(delegate::stateMachinesFeed)
override fun <T : ContractState> vaultTrackBy(criteria: QueryCriteria, paging: PageSpecification, sorting: Sort, contractStateType: Class<out T>) = wrapFeed { delegate.vaultTrackBy(criteria, paging, sorting, contractStateType) }
override fun <T : ContractState> vaultTrack(contractStateType: Class<out T>) = wrapFeed { delegate.vaultTrack(contractStateType) }
override fun <T : ContractState> vaultTrackByCriteria(contractStateType: Class<out T>, criteria: QueryCriteria) = wrapFeed { delegate.vaultTrackByCriteria(contractStateType, criteria) }
override fun <T : ContractState> vaultTrackByWithPagingSpec(contractStateType: Class<out T>, criteria: QueryCriteria, paging: PageSpecification) = wrapFeed { delegate.vaultTrackByWithPagingSpec(contractStateType, criteria, paging) }
override fun <T : ContractState> vaultTrackByWithSorting(contractStateType: Class<out T>, criteria: QueryCriteria, sorting: Sort) = wrapFeed { delegate.vaultTrackByWithSorting(contractStateType, criteria, sorting) }
override fun stateMachineRecordedTransactionMappingFeed() = wrapFeed(delegate::stateMachineRecordedTransactionMappingFeed)
override fun networkMapFeed() = wrapFeed(delegate::networkMapFeed)
override fun networkParametersFeed() = wrapFeed(delegate::networkParametersFeed)
@Suppress("DEPRECATION", "OverridingDeprecatedMember")
override fun internalVerifiedTransactionsFeed() = wrapFeed(delegate::internalVerifiedTransactionsFeed)
override fun stateMachinesSnapshot() = wrap(delegate::stateMachinesSnapshot)
override fun <T : ContractState> vaultQueryBy(criteria: QueryCriteria, paging: PageSpecification, sorting: Sort, contractStateType: Class<out T>) = wrap { delegate.vaultQueryBy(criteria, paging, sorting, contractStateType) }
override fun <T : ContractState> vaultQuery(contractStateType: Class<out T>) = wrap { delegate.vaultQuery(contractStateType) }
override fun <T : ContractState> vaultQueryByCriteria(criteria: QueryCriteria, contractStateType: Class<out T>) = wrap { delegate.vaultQueryByCriteria(criteria, contractStateType) }
override fun <T : ContractState> vaultQueryByWithPagingSpec(contractStateType: Class<out T>, criteria: QueryCriteria, paging: PageSpecification) = wrap { delegate.vaultQueryByWithPagingSpec(contractStateType, criteria, paging) }
override fun <T : ContractState> vaultQueryByWithSorting(contractStateType: Class<out T>, criteria: QueryCriteria, sorting: Sort) = wrap { delegate.vaultQueryByWithSorting(contractStateType, criteria, sorting) }
@Suppress("DEPRECATION", "OverridingDeprecatedMember")
override fun internalVerifiedTransactionsSnapshot() = wrap(delegate::internalVerifiedTransactionsSnapshot)
override fun stateMachineRecordedTransactionMappingSnapshot() = wrap(delegate::stateMachineRecordedTransactionMappingSnapshot)
override fun networkMapSnapshot() = wrap(delegate::networkMapSnapshot)
override fun acceptNewNetworkParameters(parametersHash: SecureHash) = wrap { delegate.acceptNewNetworkParameters(parametersHash) }
override fun killFlow(id: StateMachineRunId) = wrap { delegate.killFlow(id) }
override fun nodeInfo() = wrap(delegate::nodeInfo)
override fun notaryIdentities() = wrap(delegate::notaryIdentities)
override fun addVaultTransactionNote(txnId: SecureHash, txnNote: String) = wrap { delegate.addVaultTransactionNote(txnId, txnNote) }
override fun getVaultTransactionNotes(txnId: SecureHash) = wrap { delegate.getVaultTransactionNotes(txnId) }
override fun attachmentExists(id: SecureHash) = wrap { delegate.attachmentExists(id) }
override fun openAttachment(id: SecureHash) = wrap { delegate.openAttachment(id) }
override fun uploadAttachment(jar: InputStream) = wrap { delegate.uploadAttachment(jar) }
override fun uploadAttachmentWithMetadata(jar: InputStream, uploader: String, filename: String) = wrap { delegate.uploadAttachmentWithMetadata(jar, uploader, filename) }
override fun queryAttachments(query: AttachmentQueryCriteria, sorting: AttachmentSort?) = wrap { delegate.queryAttachments(query, sorting) }
override fun currentNodeTime() = wrap(delegate::currentNodeTime)
override fun wellKnownPartyFromAnonymous(party: AbstractParty) = wrap { delegate.wellKnownPartyFromAnonymous(party) }
override fun partyFromKey(key: PublicKey) = wrap { delegate.partyFromKey(key) }
override fun wellKnownPartyFromX500Name(x500Name: CordaX500Name) = wrap { delegate.wellKnownPartyFromX500Name(x500Name) }
override fun notaryPartyFromX500Name(x500Name: CordaX500Name) = wrap { delegate.notaryPartyFromX500Name(x500Name) }
override fun partiesFromName(query: String, exactMatch: Boolean) = wrap { delegate.partiesFromName(query, exactMatch) }
override fun registeredFlows() = wrap(delegate::registeredFlows)
override fun nodeInfoFromParty(party: AbstractParty) = wrap { delegate.nodeInfoFromParty(party) }
override fun clearNetworkMapCache() = wrap(delegate::clearNetworkMapCache)
override fun setFlowsDrainingModeEnabled(enabled: Boolean) = wrap { delegate.setFlowsDrainingModeEnabled(enabled) }
override fun isFlowsDrainingModeEnabled() = wrap(delegate::isFlowsDrainingModeEnabled)
override fun shutdown() = wrap(delegate::shutdown)
private fun <RESULT> wrap(call: () -> RESULT): RESULT {
return try {
call.invoke()
} catch (error: Throwable) {
logger.error(error.message, error)
throw InternalNodeException.obfuscateIfInternal(error)
}
}
private fun <SNAPSHOT, ELEMENT> wrapFeed(call: () -> DataFeed<SNAPSHOT, ELEMENT>) = wrap {
call.invoke().doOnError { error -> logger.error(error.message, error) }.mapErrors(InternalNodeException.Companion::obfuscateIfInternal)
}
private fun <RESULT> wrapFuture(call: () -> CordaFuture<RESULT>): CordaFuture<RESULT> = wrap { call.invoke().mapError(InternalNodeException.Companion::obfuscateIfInternal).doOnError { error -> logger.error(error.message, error) } }
}

View File

@ -1,25 +0,0 @@
package net.corda.node.internal
import net.corda.core.messaging.CordaRPCOps
import net.corda.node.services.api.FlowStarter
import net.corda.node.services.api.ServiceHubInternal
import net.corda.node.services.messaging.rpcContext
import net.corda.node.services.statemachine.StateMachineManager
import net.corda.nodeapi.internal.persistence.CordaPersistence
/**
* Implementation of [CordaRPCOps] that checks authorisation.
*/
class SecureCordaRPCOps(services: ServiceHubInternal,
smm: StateMachineManager,
database: CordaPersistence,
flowStarter: FlowStarter,
shutdownNode: () -> Unit,
val unsafe: CordaRPCOps = CordaRPCOpsImpl(services, smm, database, flowStarter, shutdownNode)) : CordaRPCOps by RpcAuthorisationProxy(unsafe, ::rpcContext) {
/**
* Returns the RPC protocol version, which is the same the node's Platform Version. Exists since version 1 so guaranteed
* to be present.
*/
override val protocolVersion: Int get() = unsafe.nodeInfo().platformVersion
}

View File

@ -0,0 +1,53 @@
package net.corda.node.internal.rpc.proxies
import net.corda.client.rpc.PermissionException
import net.corda.core.flows.FlowLogic
import net.corda.core.messaging.CordaRPCOps
import net.corda.node.internal.InvocationHandlerTemplate
import net.corda.node.services.messaging.RpcAuthContext
import net.corda.node.services.messaging.rpcContext
import java.lang.reflect.Method
import java.lang.reflect.Proxy
/**
* Implementation of [CordaRPCOps] that checks authorisation.
*/
internal class AuthenticatedRpcOpsProxy(private val delegate: CordaRPCOps) : CordaRPCOps by proxy(delegate, ::rpcContext) {
/**
* Returns the RPC protocol version, which is the same the node's Platform Version. Exists since version 1 so guaranteed
* to be present.
*/
override val protocolVersion: Int get() = delegate.nodeInfo().platformVersion
// Need overriding to pass additional `listOf(logicType)` argument for polymorphic `startFlow` permissions.
override fun <T> startFlowDynamic(logicType: Class<out FlowLogic<T>>, vararg args: Any?) = guard("startFlowDynamic", listOf(logicType), ::rpcContext) {
delegate.startFlowDynamic(logicType, *args)
}
// Need overriding to pass additional `listOf(logicType)` argument for polymorphic `startFlow` permissions.
override fun <T> startTrackedFlowDynamic(logicType: Class<out FlowLogic<T>>, vararg args: Any?) = guard("startTrackedFlowDynamic", listOf(logicType), ::rpcContext) {
delegate.startTrackedFlowDynamic(logicType, *args)
}
private companion object {
private fun proxy(delegate: CordaRPCOps, context: () -> RpcAuthContext): CordaRPCOps {
val handler = PermissionsEnforcingInvocationHandler(delegate, context)
return Proxy.newProxyInstance(delegate::class.java.classLoader, arrayOf(CordaRPCOps::class.java), handler) as CordaRPCOps
}
}
private class PermissionsEnforcingInvocationHandler(override val delegate: CordaRPCOps, private val context: () -> RpcAuthContext) : InvocationHandlerTemplate {
override fun invoke(proxy: Any, method: Method, arguments: Array<out Any?>?) = guard(method.name, context, { super.invoke(proxy, method, arguments) })
}
}
private fun <RESULT> guard(methodName: String, context: () -> RpcAuthContext, action: () -> RESULT) = guard(methodName, emptyList(), context, action)
private fun <RESULT> guard(methodName: String, args: List<Class<*>>, context: () -> RpcAuthContext, action: () -> RESULT): RESULT {
if (!context().isPermitted(methodName, *(args.map { it.name }.toTypedArray()))) {
throw PermissionException("User not authorized to perform RPC call $methodName with target $args")
} else {
return action()
}
}

View File

@ -0,0 +1,100 @@
package net.corda.node.internal.rpc.proxies
import net.corda.core.CordaRuntimeException
import net.corda.core.CordaThrowable
import net.corda.core.concurrent.CordaFuture
import net.corda.core.internal.concurrent.mapError
import net.corda.core.mapErrors
import net.corda.core.messaging.CordaRPCOps
import net.corda.core.messaging.DataFeed
import net.corda.core.messaging.FlowHandle
import net.corda.core.messaging.FlowHandleImpl
import net.corda.core.messaging.FlowProgressHandle
import net.corda.core.messaging.FlowProgressHandleImpl
import net.corda.core.serialization.CordaSerializable
import net.corda.node.internal.InvocationHandlerTemplate
import rx.Observable
import java.lang.reflect.Method
import java.lang.reflect.Proxy.newProxyInstance
internal class ExceptionSerialisingRpcOpsProxy(private val delegate: CordaRPCOps) : CordaRPCOps by proxy(delegate) {
private companion object {
private fun proxy(delegate: CordaRPCOps): CordaRPCOps {
val handler = ErrorSerialisingInvocationHandler(delegate)
return newProxyInstance(delegate::class.java.classLoader, arrayOf(CordaRPCOps::class.java), handler) as CordaRPCOps
}
}
private class ErrorSerialisingInvocationHandler(override val delegate: CordaRPCOps) : InvocationHandlerTemplate {
override fun invoke(proxy: Any, method: Method, arguments: Array<out Any?>?): Any? {
try {
val result = super.invoke(proxy, method, arguments)
return result?.let { ensureSerialisable(it) }
} catch (exception: Exception) {
throw ensureSerialisable(exception)
}
}
private fun <RESULT : Any> ensureSerialisable(result: RESULT): Any {
return when (result) {
is CordaFuture<*> -> wrapFuture(result)
is DataFeed<*, *> -> wrapFeed(result)
is FlowProgressHandle<*> -> wrapFlowProgressHandle(result)
is FlowHandle<*> -> wrapFlowHandle(result)
is Observable<*> -> wrapObservable(result)
else -> result
}
}
private fun wrapFlowProgressHandle(handle: FlowProgressHandle<*>): FlowProgressHandle<*> {
val returnValue = wrapFuture(handle.returnValue)
val progress = wrapObservable(handle.progress)
val stepsTreeIndexFeed = handle.stepsTreeIndexFeed?.let { wrapFeed(it) }
val stepsTreeFeed = handle.stepsTreeFeed?.let { wrapFeed(it) }
return FlowProgressHandleImpl(handle.id, returnValue, progress, stepsTreeIndexFeed, stepsTreeFeed)
}
private fun wrapFlowHandle(handle: FlowHandle<*>): FlowHandle<*> {
return FlowHandleImpl(handle.id, wrapFuture(handle.returnValue))
}
private fun <ELEMENT> wrapObservable(observable: Observable<ELEMENT>): Observable<ELEMENT> {
return observable.mapErrors(::ensureSerialisable)
}
private fun <SNAPSHOT, ELEMENT> wrapFeed(feed: DataFeed<SNAPSHOT, ELEMENT>): DataFeed<SNAPSHOT, ELEMENT> {
return feed.mapErrors(::ensureSerialisable)
}
private fun <RESULT> wrapFuture(future: CordaFuture<RESULT>): CordaFuture<RESULT> {
return future.mapError(::ensureSerialisable)
}
private fun ensureSerialisable(error: Throwable): Throwable {
val serialisable = (superclasses(error::class.java) + error::class.java).any { it.isAnnotationPresent(CordaSerializable::class.java) || it.interfaces.any { it.isAnnotationPresent(CordaSerializable::class.java) } }
val result = if (serialisable) error else CordaRuntimeException(error.message, error)
if (result is CordaThrowable) {
result.stackTrace = arrayOf<StackTraceElement>()
result.setCause(null)
}
return result
}
private fun superclasses(clazz: Class<*>): List<Class<*>> {
val superclasses = mutableListOf<Class<*>>()
var current: Class<*>?
var superclass = clazz.superclass
while (superclass != null) {
superclasses += superclass
current = superclass
superclass = current.superclass
}
return superclasses
}
}
override fun toString(): String {
return "ExceptionSerialisingRpcOpsProxy"
}
}

View File

@ -26,6 +26,7 @@ import net.corda.node.services.config.NodeConfiguration
import net.corda.node.services.vault.HibernateAttachmentQueryCriteriaParser import net.corda.node.services.vault.HibernateAttachmentQueryCriteriaParser
import net.corda.node.utilities.NonInvalidatingCache import net.corda.node.utilities.NonInvalidatingCache
import net.corda.node.utilities.NonInvalidatingWeightBasedCache import net.corda.node.utilities.NonInvalidatingWeightBasedCache
import net.corda.nodeapi.exceptions.DuplicateAttachmentException
import net.corda.nodeapi.internal.persistence.NODE_DATABASE_PREFIX import net.corda.nodeapi.internal.persistence.NODE_DATABASE_PREFIX
import net.corda.nodeapi.internal.persistence.currentDBSession import net.corda.nodeapi.internal.persistence.currentDBSession
import net.corda.nodeapi.internal.withContractsInJar import net.corda.nodeapi.internal.withContractsInJar
@ -287,7 +288,7 @@ class NodeAttachmentService(
log.info("Stored new attachment $id") log.info("Stored new attachment $id")
id id
} else { } else {
throw java.nio.file.FileAlreadyExistsException(id.toString()) throw DuplicateAttachmentException(id.toString())
} }
} }
} }

View File

@ -31,13 +31,13 @@ import net.corda.finance.USD
import net.corda.finance.contracts.asset.Cash import net.corda.finance.contracts.asset.Cash
import net.corda.finance.flows.CashIssueFlow import net.corda.finance.flows.CashIssueFlow
import net.corda.finance.flows.CashPaymentFlow import net.corda.finance.flows.CashPaymentFlow
import net.corda.node.internal.SecureCordaRPCOps
import net.corda.node.internal.StartedNode import net.corda.node.internal.StartedNode
import net.corda.node.internal.security.RPCSecurityManagerImpl import net.corda.node.internal.security.RPCSecurityManagerImpl
import net.corda.node.services.Permissions.Companion.invokeRpc import net.corda.node.services.Permissions.Companion.invokeRpc
import net.corda.node.services.Permissions.Companion.startFlow import net.corda.node.services.Permissions.Companion.startFlow
import net.corda.node.services.messaging.CURRENT_RPC_CONTEXT import net.corda.node.services.messaging.CURRENT_RPC_CONTEXT
import net.corda.node.services.messaging.RpcAuthContext import net.corda.node.services.messaging.RpcAuthContext
import net.corda.nodeapi.exceptions.NonRpcFlowException
import net.corda.nodeapi.internal.config.User import net.corda.nodeapi.internal.config.User
import net.corda.testing.core.ALICE_NAME import net.corda.testing.core.ALICE_NAME
import net.corda.testing.core.expect import net.corda.testing.core.expect
@ -88,7 +88,7 @@ class CordaRPCOpsImplTest {
fun setup() { fun setup() {
mockNet = InternalMockNetwork(cordappPackages = listOf("net.corda.finance.contracts.asset")) mockNet = InternalMockNetwork(cordappPackages = listOf("net.corda.finance.contracts.asset"))
aliceNode = mockNet.createNode(InternalMockNodeParameters(legalName = ALICE_NAME)) aliceNode = mockNet.createNode(InternalMockNodeParameters(legalName = ALICE_NAME))
rpc = SecureCordaRPCOps(aliceNode.services, aliceNode.smm, aliceNode.database, aliceNode.services, { }) rpc = aliceNode.rpcOps
CURRENT_RPC_CONTEXT.set(RpcAuthContext(InvocationContext.rpc(testActor()), buildSubject("TEST_USER", emptySet()))) CURRENT_RPC_CONTEXT.set(RpcAuthContext(InvocationContext.rpc(testActor()), buildSubject("TEST_USER", emptySet())))
mockNet.runNetwork() mockNet.runNetwork()
@ -266,10 +266,10 @@ class CordaRPCOpsImplTest {
@Test @Test
fun `can't upload the same attachment`() { fun `can't upload the same attachment`() {
withPermissions(invokeRpc(CordaRPCOps::uploadAttachment), invokeRpc(CordaRPCOps::attachmentExists)) { withPermissions(invokeRpc(CordaRPCOps::uploadAttachment), invokeRpc(CordaRPCOps::attachmentExists)) {
val inputJar1 = Thread.currentThread().contextClassLoader.getResourceAsStream(testJar)
val inputJar2 = Thread.currentThread().contextClassLoader.getResourceAsStream(testJar)
val secureHash1 = rpc.uploadAttachment(inputJar1)
assertThatExceptionOfType(java.nio.file.FileAlreadyExistsException::class.java).isThrownBy { assertThatExceptionOfType(java.nio.file.FileAlreadyExistsException::class.java).isThrownBy {
val inputJar1 = Thread.currentThread().contextClassLoader.getResourceAsStream(testJar)
val inputJar2 = Thread.currentThread().contextClassLoader.getResourceAsStream(testJar)
val secureHash1 = rpc.uploadAttachment(inputJar1)
val secureHash2 = rpc.uploadAttachment(inputJar2) val secureHash2 = rpc.uploadAttachment(inputJar2)
} }
} }
@ -293,7 +293,7 @@ class CordaRPCOpsImplTest {
@Test @Test
fun `attempt to start non-RPC flow`() { fun `attempt to start non-RPC flow`() {
withPermissions(startFlow<NonRPCFlow>()) { withPermissions(startFlow<NonRPCFlow>()) {
assertThatExceptionOfType(IllegalArgumentException::class.java).isThrownBy { assertThatExceptionOfType(NonRpcFlowException::class.java).isThrownBy {
rpc.startFlow(::NonRPCFlow) rpc.startFlow(::NonRPCFlow)
} }
} }

View File

@ -106,6 +106,19 @@ class FlowFrameworkTests {
assertThat(flow.lazyTime).isNotNull() assertThat(flow.lazyTime).isNotNull()
} }
class SuspendThrowingActionExecutor(private val exception: Exception, val delegate: ActionExecutor) : ActionExecutor {
var thrown = false
@Suspendable
override fun executeAction(fiber: FlowFiber, action: Action) {
if (action is Action.CommitTransaction && !thrown) {
thrown = true
throw exception
} else {
delegate.executeAction(fiber, action)
}
}
}
@Test @Test
fun `exception while fiber suspended`() { fun `exception while fiber suspended`() {
bobNode.registerFlowFactory(ReceiveFlow::class) { InitiatedSendFlow("Hello", it) } bobNode.registerFlowFactory(ReceiveFlow::class) { InitiatedSendFlow("Hello", it) }
@ -113,7 +126,7 @@ class FlowFrameworkTests {
val fiber = aliceNode.services.startFlow(flow) as FlowStateMachineImpl val fiber = aliceNode.services.startFlow(flow) as FlowStateMachineImpl
// Before the flow runs change the suspend action to throw an exception // Before the flow runs change the suspend action to throw an exception
val exceptionDuringSuspend = Exception("Thrown during suspend") val exceptionDuringSuspend = Exception("Thrown during suspend")
val throwingActionExecutor = ThrowingActionExecutor(exceptionDuringSuspend, fiber.transientValues!!.value.actionExecutor) val throwingActionExecutor = SuspendThrowingActionExecutor(exceptionDuringSuspend, fiber.transientValues!!.value.actionExecutor)
fiber.transientValues = TransientReference(fiber.transientValues!!.value.copy(actionExecutor = throwingActionExecutor)) fiber.transientValues = TransientReference(fiber.transientValues!!.value.copy(actionExecutor = throwingActionExecutor))
mockNet.runNetwork() mockNet.runNetwork()
assertThatThrownBy { assertThatThrownBy {