[CORDA-1264}: Complete obfuscation of exceptions to client side. (#3155)

This commit is contained in:
Michele Sollecito 2018-05-21 13:34:37 +01:00 committed by GitHub
parent b0b36b5b7d
commit 5de2c2aa4b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
30 changed files with 811 additions and 1007 deletions

View File

@ -1,10 +1,10 @@
package net.corda.client.rpc
import net.corda.core.CordaRuntimeException
import net.corda.nodeapi.exceptions.RpcSerializableError
import net.corda.core.ClientRelevantError
/**
* Thrown to indicate that the calling user does not have permission for something they have requested (for example
* calling a method).
*/
class PermissionException(message: String) : CordaRuntimeException(message), RpcSerializableError
class PermissionException(message: String) : CordaRuntimeException(message), ClientRelevantError

View File

@ -0,0 +1,9 @@
package net.corda.core
import net.corda.core.serialization.CordaSerializable
/**
* Allows an implementing [Throwable] to be propagated to clients.
*/
@CordaSerializable
interface ClientRelevantError

View File

@ -267,6 +267,12 @@ fun <T> Any.declaredField(name: String): DeclaredField<T> = DeclaredField(javaCl
*/
fun <T> Any.declaredField(clazz: KClass<*>, name: String): DeclaredField<T> = DeclaredField(clazz.java, name, this)
/**
* Returns a [DeclaredField] wrapper around the (possibly non-public) instance field of the receiver object, but declared
* in its superclass [clazz].
*/
fun <T> Any.declaredField(clazz: Class<*>, name: String): DeclaredField<T> = DeclaredField(clazz, name, this)
/** creates a new instance if not a Kotlin object */
fun <T : Any> KClass<T>.objectOrNewInstance(): T {
return this.objectInstance ?: this.createInstance()
@ -277,10 +283,43 @@ fun <T : Any> KClass<T>.objectOrNewInstance(): T {
* visibility.
*/
class DeclaredField<T>(clazz: Class<*>, name: String, private val receiver: Any?) {
private val javaField = clazz.getDeclaredField(name).apply { isAccessible = true }
private val javaField = findField(name, clazz)
var value: T
get() = uncheckedCast<Any?, T>(javaField.get(receiver))
set(value) = javaField.set(receiver, value)
get() {
synchronized(this) {
return javaField.accessible { uncheckedCast<Any?, T>(get(receiver)) }
}
}
set(value) {
synchronized(this) {
javaField.accessible {
set(receiver, value)
}
}
}
val name: String = javaField.name
private fun <RESULT> Field.accessible(action: Field.() -> RESULT): RESULT {
val accessible = isAccessible
isAccessible = true
try {
return action(this)
} finally {
isAccessible = accessible
}
}
@Throws(NoSuchFieldException::class)
private fun findField(fieldName: String, clazz: Class<*>?): Field {
if (clazz == null) {
throw NoSuchFieldException(fieldName)
}
return try {
return clazz.getDeclaredField(fieldName)
} catch (e: NoSuchFieldException) {
findField(fieldName, clazz.superclass)
}
}
}
/** The annotated object would have a more restricted visibility were it not needed in tests. */

View File

@ -290,9 +290,11 @@ interface CordaRPCOps : RPCOps {
fun openAttachment(id: SecureHash): InputStream
/** Uploads a jar to the node, returns it's hash. */
@Throws(java.nio.file.FileAlreadyExistsException::class)
fun uploadAttachment(jar: InputStream): SecureHash
/** Uploads a jar including metadata to the node, returns it's hash. */
@Throws(java.nio.file.FileAlreadyExistsException::class)
fun uploadAttachmentWithMetadata(jar: InputStream, uploader: String, filename: String): SecureHash
/** Queries attachments metadata */

View File

@ -3,6 +3,7 @@
package net.corda.core.node.services.vault
import net.corda.core.DoNotImplement
import net.corda.core.internal.declaredField
import net.corda.core.internal.uncheckedCast
import net.corda.core.schemas.PersistentState
import net.corda.core.serialization.CordaSerializable
@ -439,18 +440,7 @@ class FieldInfo internal constructor(val name: String, val entityClass: Class<*>
*/
@Throws(NoSuchFieldException::class)
fun getField(fieldName: String, entityClass: Class<*>): FieldInfo {
return getField(fieldName, entityClass, entityClass)
}
@Throws(NoSuchFieldException::class)
private fun getField(fieldName: String, clazz: Class<*>?, invokingClazz: Class<*>): FieldInfo {
if (clazz == null) {
throw NoSuchFieldException(fieldName)
}
return try {
val field = clazz.getDeclaredField(fieldName)
return FieldInfo(field.name, invokingClazz)
} catch (e: NoSuchFieldException) {
getField(fieldName, clazz.superclass, invokingClazz)
}
val field = entityClass.declaredField<Any>(entityClass, fieldName)
return FieldInfo(field.name, entityClass)
}

View File

@ -1,6 +1,7 @@
package net.corda.core.flows
import co.paralleluniverse.fibers.Suspendable
import net.corda.core.CordaRuntimeException
import net.corda.core.contracts.*
import net.corda.core.identity.AbstractParty
import net.corda.core.identity.Party
@ -17,7 +18,6 @@ import net.corda.finance.USD
import net.corda.finance.`issued by`
import net.corda.finance.contracts.asset.Cash
import net.corda.finance.flows.CashIssueFlow
import net.corda.node.internal.SecureCordaRPCOps
import net.corda.node.internal.StartedNode
import net.corda.node.services.Permissions.Companion.startFlow
import net.corda.testing.contracts.DummyContract
@ -116,7 +116,7 @@ class ContractUpgradeFlowTest {
return startRpcClient<CordaRPCOps>(
rpcAddress = startRpcServer(
rpcUser = user,
ops = SecureCordaRPCOps(node.services, node.smm, node.database, node.services, { })
ops = node.rpcOps
).get().broker.hostAndPort!!,
username = user.username,
password = user.password
@ -153,7 +153,7 @@ class ContractUpgradeFlowTest {
DummyContractV2::class.java).returnValue
mockNet.runNetwork()
assertFailsWith(UnexpectedFlowEndException::class) { rejectedFuture.getOrThrow() }
assertFailsWith(CordaRuntimeException::class) { rejectedFuture.getOrThrow() }
// Party B authorise the contract state upgrade, and immediately deauthorise the same.
rpcB.startFlow({ stateAndRef, upgrade -> ContractUpgradeFlow.Authorise(stateAndRef, upgrade) },
@ -168,7 +168,7 @@ class ContractUpgradeFlowTest {
DummyContractV2::class.java).returnValue
mockNet.runNetwork()
assertFailsWith(UnexpectedFlowEndException::class) { deauthorisedFuture.getOrThrow() }
assertFailsWith(CordaRuntimeException::class) { deauthorisedFuture.getOrThrow() }
// Party B authorise the contract state upgrade.
rpcB.startFlow({ stateAndRef, upgrade -> ContractUpgradeFlow.Authorise(stateAndRef, upgrade) },

View File

@ -1,37 +0,0 @@
package net.corda.nodeapi.exceptions
import net.corda.core.CordaRuntimeException
import net.corda.core.contracts.TransactionVerificationException
import net.corda.core.flows.FlowException
import java.io.InvalidClassException
// could change to use package name matching but trying to avoid reflection for now
private val whitelisted = setOf(
FlowException::class,
InvalidClassException::class,
RpcSerializableError::class,
TransactionVerificationException::class
)
/**
* An [Exception] to signal RPC clients that something went wrong within a Corda node.
*/
class InternalNodeException(message: String) : CordaRuntimeException(message) {
companion object {
private const val DEFAULT_MESSAGE = "Something went wrong within the Corda node."
fun defaultMessage(): String = DEFAULT_MESSAGE
fun obfuscateIfInternal(wrapped: Throwable): Throwable {
(wrapped as? CordaRuntimeException)?.setCause(null)
return when {
whitelisted.any { it.isInstance(wrapped) } -> wrapped
else -> InternalNodeException(DEFAULT_MESSAGE).apply {
stackTrace = emptyArray()
}
}
}
}
}

View File

@ -1,11 +0,0 @@
package net.corda.nodeapi.exceptions
import net.corda.core.CordaRuntimeException
import net.corda.core.crypto.SecureHash
class OutdatedNetworkParameterHashException(old: SecureHash, new: SecureHash) : CordaRuntimeException(TEMPLATE.format(old, new)), RpcSerializableError {
private companion object {
private const val TEMPLATE = "Refused to accept parameters with hash %s because network map advertises update with hash %s. Please check newest version"
}
}

View File

@ -1,8 +0,0 @@
package net.corda.nodeapi.exceptions
import net.corda.core.CordaRuntimeException
/**
* Thrown to indicate that the command was rejected by the node, typically due to a special temporary mode.
*/
class RejectedCommandException(message: String) : CordaRuntimeException(message), RpcSerializableError

View File

@ -0,0 +1,49 @@
package net.corda.nodeapi.exceptions
import net.corda.core.CordaRuntimeException
import net.corda.core.crypto.SecureHash
import net.corda.core.ClientRelevantError
import net.corda.core.flows.IdentifiableException
/**
* Thrown to indicate that an attachment was already uploaded to a Corda node.
*/
class DuplicateAttachmentException(attachmentHash: String) : java.nio.file.FileAlreadyExistsException(attachmentHash), ClientRelevantError
/**
* Thrown to indicate that a flow was not designed for RPC and should be started from an RPC client.
*/
class NonRpcFlowException(logicType: Class<*>) : IllegalArgumentException("${logicType.name} was not designed for RPC"), ClientRelevantError
/**
* An [Exception] to signal RPC clients that something went wrong within a Corda node.
* The message is generic on purpose, as this prevents internal information from reaching RPC clients.
* Leaking internal information outside can compromise privacy e.g., party names and security e.g., passwords, stacktraces, etc.
*
* @param errorIdentifier an optional identifier for tracing problems across parties.
*/
class InternalNodeException(private val errorIdentifier: Long? = null) : CordaRuntimeException(message), ClientRelevantError, IdentifiableException {
companion object {
/**
* Message for the exception.
*/
const val message = "Something went wrong within the Corda node."
}
override fun getErrorId(): Long? {
return errorIdentifier
}
}
class OutdatedNetworkParameterHashException(old: SecureHash, new: SecureHash) : CordaRuntimeException(TEMPLATE.format(old, new)), ClientRelevantError {
private companion object {
private const val TEMPLATE = "Refused to accept parameters with hash %s because network map advertises update with hash %s. Please check newest version"
}
}
/**
* Thrown to indicate that the command was rejected by the node, typically due to a special temporary mode.
*/
class RejectedCommandException(message: String) : CordaRuntimeException(message), ClientRelevantError

View File

@ -1,9 +0,0 @@
package net.corda.nodeapi.exceptions
import net.corda.core.serialization.CordaSerializable
/**
* Allows an implementing [Throwable] to be propagated to RPC clients.
*/
@CordaSerializable
interface RpcSerializableError

View File

@ -1,15 +0,0 @@
package net.corda.nodeapi.exceptions.adapters
import net.corda.core.internal.concurrent.mapError
import net.corda.core.messaging.FlowHandle
import net.corda.core.serialization.CordaSerializable
import net.corda.nodeapi.exceptions.InternalNodeException
/**
* Adapter able to mask errors within a Corda node for RPC clients.
*/
@CordaSerializable
data class InternalObfuscatingFlowHandle<RESULT>(val wrapped: FlowHandle<RESULT>) : FlowHandle<RESULT> by wrapped {
override val returnValue = wrapped.returnValue.mapError(InternalNodeException.Companion::obfuscateIfInternal)
}

View File

@ -1,22 +0,0 @@
package net.corda.nodeapi.exceptions.adapters
import net.corda.core.internal.concurrent.mapError
import net.corda.core.mapErrors
import net.corda.core.messaging.FlowProgressHandle
import net.corda.core.serialization.CordaSerializable
import net.corda.nodeapi.exceptions.InternalNodeException
/**
* Adapter able to mask errors within a Corda node for RPC clients.
*/
@CordaSerializable
class InternalObfuscatingFlowProgressHandle<RESULT>(val wrapped: FlowProgressHandle<RESULT>) : FlowProgressHandle<RESULT> by wrapped {
override val returnValue = wrapped.returnValue.mapError(InternalNodeException.Companion::obfuscateIfInternal)
override val progress = wrapped.progress.mapErrors(InternalNodeException.Companion::obfuscateIfInternal)
override val stepsTreeIndexFeed = wrapped.stepsTreeIndexFeed?.mapErrors(InternalNodeException.Companion::obfuscateIfInternal)
override val stepsTreeFeed = wrapped.stepsTreeFeed?.mapErrors(InternalNodeException.Companion::obfuscateIfInternal)
}

View File

@ -1,6 +1,6 @@
package net.corda
import net.corda.core.CordaRuntimeException
import net.corda.nodeapi.exceptions.RpcSerializableError
import net.corda.core.ClientRelevantError
class ClientRelevantException(message: String?, cause: Throwable?) : CordaRuntimeException(message, cause), RpcSerializableError
class ClientRelevantException(message: String?, cause: Throwable?) : CordaRuntimeException(message, cause), ClientRelevantError

View File

@ -12,7 +12,6 @@ import net.corda.core.messaging.startFlow
import net.corda.core.utilities.getOrThrow
import net.corda.node.internal.NodeStartup
import net.corda.node.services.Permissions.Companion.startFlow
import net.corda.nodeapi.exceptions.InternalNodeException
import net.corda.testing.common.internal.ProjectStructure.projectRootDir
import net.corda.testing.core.ALICE_NAME
import net.corda.testing.driver.DriverParameters
@ -34,7 +33,6 @@ class BootTests {
start(user.username, user.password).proxy.startFlow(::ObjectInputStreamFlow).returnValue
assertThatThrownBy { future.getOrThrow() }
.isInstanceOf(CordaRuntimeException::class.java)
.hasMessageContaining(InternalNodeException.defaultMessage())
}
}

View File

@ -1,474 +1,474 @@
package net.corda.node.amqp
import com.nhaarman.mockito_kotlin.doReturn
import com.nhaarman.mockito_kotlin.whenever
import net.corda.core.crypto.Crypto
import net.corda.core.identity.CordaX500Name
import net.corda.core.internal.div
import net.corda.core.toFuture
import net.corda.core.utilities.NetworkHostAndPort
import net.corda.core.utilities.days
import net.corda.core.utilities.minutes
import net.corda.core.utilities.seconds
import net.corda.node.services.config.NodeConfiguration
import net.corda.node.services.config.configureWithDevSSLCertificate
import net.corda.nodeapi.internal.ArtemisMessagingComponent.Companion.P2P_PREFIX
import net.corda.nodeapi.internal.ArtemisMessagingComponent.Companion.PEER_USER
import net.corda.nodeapi.internal.config.SSLConfiguration
import net.corda.nodeapi.internal.crypto.*
import net.corda.nodeapi.internal.protonwrapper.messages.MessageStatus
import net.corda.nodeapi.internal.protonwrapper.netty.AMQPClient
import net.corda.nodeapi.internal.protonwrapper.netty.AMQPServer
import net.corda.testing.core.*
import net.corda.testing.internal.DEV_INTERMEDIATE_CA
import net.corda.testing.internal.DEV_ROOT_CA
import net.corda.testing.internal.rigorousMock
import org.bouncycastle.asn1.x500.X500Name
import org.bouncycastle.asn1.x509.*
import org.bouncycastle.cert.jcajce.JcaX509CRLConverter
import org.bouncycastle.cert.jcajce.JcaX509ExtensionUtils
import org.bouncycastle.cert.jcajce.JcaX509v2CRLBuilder
import org.bouncycastle.jce.provider.BouncyCastleProvider
import org.bouncycastle.operator.jcajce.JcaContentSignerBuilder
import org.eclipse.jetty.server.Server
import org.eclipse.jetty.server.ServerConnector
import org.eclipse.jetty.server.handler.HandlerCollection
import org.eclipse.jetty.servlet.ServletContextHandler
import org.eclipse.jetty.servlet.ServletHolder
import org.glassfish.jersey.server.ResourceConfig
import org.glassfish.jersey.servlet.ServletContainer
import org.junit.After
import org.junit.Before
import org.junit.Rule
import org.junit.Test
import org.junit.rules.TemporaryFolder
import java.io.Closeable
import java.math.BigInteger
import java.net.InetSocketAddress
import java.security.KeyPair
import java.security.PrivateKey
import java.security.Security
import java.security.cert.X509CRL
import java.security.cert.X509Certificate
import java.util.*
import javax.ws.rs.GET
import javax.ws.rs.Path
import javax.ws.rs.Produces
import javax.ws.rs.core.Response
import kotlin.test.assertEquals
class CertificateRevocationListNodeTests {
@Rule
@JvmField
val temporaryFolder = TemporaryFolder()
private val ROOT_CA = DEV_ROOT_CA
private lateinit var INTERMEDIATE_CA: CertificateAndKeyPair
private val serverPort = freePort()
private lateinit var server: CrlServer
private val revokedNodeCerts: MutableList<BigInteger> = mutableListOf()
private val revokedIntermediateCerts: MutableList<BigInteger> = mutableListOf()
private abstract class AbstractNodeConfiguration : NodeConfiguration
@Before
fun setUp() {
Security.addProvider(BouncyCastleProvider())
revokedNodeCerts.clear()
server = CrlServer(NetworkHostAndPort("localhost", 0))
server.start()
INTERMEDIATE_CA = CertificateAndKeyPair(replaceCrlDistPointCaCertificate(
DEV_INTERMEDIATE_CA.certificate,
CertificateType.INTERMEDIATE_CA,
ROOT_CA.keyPair,
"http://${server.hostAndPort}/crl/intermediate.crl"), DEV_INTERMEDIATE_CA.keyPair)
}
@After
fun tearDown() {
server.close()
revokedNodeCerts.clear()
}
@Test
fun `Simple AMPQ Client to Server connection works`() {
val crlCheckSoftFail = true
val (amqpServer, _) = createServer(serverPort, crlCheckSoftFail = crlCheckSoftFail)
amqpServer.use {
amqpServer.start()
val receiveSubs = amqpServer.onReceive.subscribe {
assertEquals(BOB_NAME.toString(), it.sourceLegalName)
assertEquals(P2P_PREFIX + "Test", it.topic)
assertEquals("Test", String(it.payload))
it.complete(true)
}
val (amqpClient, _) = createClient(serverPort, crlCheckSoftFail)
amqpClient.use {
val serverConnected = amqpServer.onConnection.toFuture()
val clientConnected = amqpClient.onConnection.toFuture()
amqpClient.start()
val serverConnect = serverConnected.get()
assertEquals(true, serverConnect.connected)
val clientConnect = clientConnected.get()
assertEquals(true, clientConnect.connected)
val msg = amqpClient.createMessage("Test".toByteArray(),
P2P_PREFIX + "Test",
ALICE_NAME.toString(),
emptyMap())
amqpClient.write(msg)
assertEquals(MessageStatus.Acknowledged, msg.onComplete.get())
receiveSubs.unsubscribe()
}
}
}
@Test
fun `AMPQ Client to Server connection fails when client's certificate is revoked`() {
val crlCheckSoftFail = true
val (amqpServer, _) = createServer(serverPort, crlCheckSoftFail = crlCheckSoftFail)
amqpServer.use {
amqpServer.start()
amqpServer.onReceive.subscribe {
it.complete(true)
}
val (amqpClient, clientCert) = createClient(serverPort, crlCheckSoftFail)
revokedNodeCerts.add(clientCert.serialNumber)
amqpClient.use {
val serverConnected = amqpServer.onConnection.toFuture()
amqpClient.onConnection.toFuture()
amqpClient.start()
val serverConnect = serverConnected.get()
assertEquals(false, serverConnect.connected)
}
}
}
@Test
fun `AMPQ Client to Server connection fails when servers's certificate is revoked`() {
val crlCheckSoftFail = true
val (amqpServer, serverCert) = createServer(serverPort, crlCheckSoftFail = crlCheckSoftFail)
revokedNodeCerts.add(serverCert.serialNumber)
amqpServer.use {
amqpServer.start()
amqpServer.onReceive.subscribe {
it.complete(true)
}
val (amqpClient, _) = createClient(serverPort, crlCheckSoftFail)
amqpClient.use {
val serverConnected = amqpServer.onConnection.toFuture()
amqpClient.onConnection.toFuture()
amqpClient.start()
val serverConnect = serverConnected.get()
assertEquals(false, serverConnect.connected)
}
}
}
@Test
fun `AMPQ Client to Server connection fails when servers's certificate is revoked and soft fail is enabled`() {
val crlCheckSoftFail = true
val (amqpServer, serverCert) = createServer(serverPort, crlCheckSoftFail = crlCheckSoftFail)
revokedNodeCerts.add(serverCert.serialNumber)
amqpServer.use {
amqpServer.start()
amqpServer.onReceive.subscribe {
it.complete(true)
}
val (amqpClient, _) = createClient(serverPort, crlCheckSoftFail)
amqpClient.use {
val serverConnected = amqpServer.onConnection.toFuture()
amqpClient.onConnection.toFuture()
amqpClient.start()
val serverConnect = serverConnected.get()
assertEquals(false, serverConnect.connected)
}
}
}
@Test
fun `AMPQ Client to Server connection succeeds when CRL cannot be obtained and soft fail is enabled`() {
val crlCheckSoftFail = true
val (amqpServer, serverCert) = createServer(
serverPort,
crlCheckSoftFail = crlCheckSoftFail,
nodeCrlDistPoint = "http://${server.hostAndPort}/crl/invalid.crl")
amqpServer.use {
amqpServer.start()
amqpServer.onReceive.subscribe {
it.complete(true)
}
val (amqpClient, _) = createClient(
serverPort,
crlCheckSoftFail,
nodeCrlDistPoint = "http://${server.hostAndPort}/crl/invalid.crl")
amqpClient.use {
val serverConnected = amqpServer.onConnection.toFuture()
amqpClient.onConnection.toFuture()
amqpClient.start()
val serverConnect = serverConnected.get()
assertEquals(true, serverConnect.connected)
}
}
}
@Test
fun `Revocation status chceck fails when the CRL distribution point is not set and soft fail is disabled`() {
val crlCheckSoftFail = false
val (amqpServer, _) = createServer(
serverPort,
crlCheckSoftFail = crlCheckSoftFail,
tlsCrlDistPoint = null)
amqpServer.use {
amqpServer.start()
amqpServer.onReceive.subscribe {
it.complete(true)
}
val (amqpClient, _) = createClient(
serverPort,
crlCheckSoftFail,
tlsCrlDistPoint = null)
amqpClient.use {
val serverConnected = amqpServer.onConnection.toFuture()
amqpClient.onConnection.toFuture()
amqpClient.start()
val serverConnect = serverConnected.get()
assertEquals(false, serverConnect.connected)
}
}
}
@Test
fun `Revocation status chceck succeds when the CRL distribution point is not set and soft fail is enabled`() {
val crlCheckSoftFail = true
val (amqpServer, _) = createServer(
serverPort,
crlCheckSoftFail = crlCheckSoftFail,
tlsCrlDistPoint = null)
amqpServer.use {
amqpServer.start()
amqpServer.onReceive.subscribe {
it.complete(true)
}
val (amqpClient, _) = createClient(
serverPort,
crlCheckSoftFail,
tlsCrlDistPoint = null)
amqpClient.use {
val serverConnected = amqpServer.onConnection.toFuture()
amqpClient.onConnection.toFuture()
amqpClient.start()
val serverConnect = serverConnected.get()
assertEquals(true, serverConnect.connected)
}
}
}
private fun createClient(targetPort: Int,
crlCheckSoftFail: Boolean,
nodeCrlDistPoint: String = "http://${server.hostAndPort}/crl/node.crl",
tlsCrlDistPoint: String? = "http://${server.hostAndPort}/crl/empty.crl",
maxMessageSize: Int = MAX_MESSAGE_SIZE): Pair<AMQPClient, X509Certificate> {
val clientConfig = rigorousMock<AbstractNodeConfiguration>().also {
doReturn(temporaryFolder.root.toPath() / "client").whenever(it).baseDirectory
doReturn(BOB_NAME).whenever(it).myLegalName
doReturn("trustpass").whenever(it).trustStorePassword
doReturn("cordacadevpass").whenever(it).keyStorePassword
doReturn(crlCheckSoftFail).whenever(it).crlCheckSoftFail
}
clientConfig.configureWithDevSSLCertificate()
val nodeCert = clientConfig.recreateNodeCaAndTlsCertificates(nodeCrlDistPoint, tlsCrlDistPoint)
val clientTruststore = clientConfig.loadTrustStore().internal
val clientKeystore = clientConfig.loadSslKeyStore().internal
return Pair(AMQPClient(
listOf(NetworkHostAndPort("localhost", targetPort)),
setOf(ALICE_NAME, CHARLIE_NAME),
PEER_USER,
PEER_USER,
clientKeystore,
clientConfig.keyStorePassword,
clientTruststore,
crlCheckSoftFail,
maxMessageSize = maxMessageSize), nodeCert)
}
private fun createServer(port: Int, name: CordaX500Name = ALICE_NAME,
crlCheckSoftFail: Boolean,
nodeCrlDistPoint: String = "http://${server.hostAndPort}/crl/node.crl",
tlsCrlDistPoint: String? = "http://${server.hostAndPort}/crl/empty.crl",
maxMessageSize: Int = MAX_MESSAGE_SIZE): Pair<AMQPServer, X509Certificate> {
val serverConfig = rigorousMock<AbstractNodeConfiguration>().also {
doReturn(temporaryFolder.root.toPath() / "server").whenever(it).baseDirectory
doReturn(name).whenever(it).myLegalName
doReturn("trustpass").whenever(it).trustStorePassword
doReturn("cordacadevpass").whenever(it).keyStorePassword
doReturn(crlCheckSoftFail).whenever(it).crlCheckSoftFail
}
serverConfig.configureWithDevSSLCertificate()
val nodeCert = serverConfig.recreateNodeCaAndTlsCertificates(nodeCrlDistPoint, tlsCrlDistPoint)
val serverTruststore = serverConfig.loadTrustStore().internal
val serverKeystore = serverConfig.loadSslKeyStore().internal
return Pair(AMQPServer(
"0.0.0.0",
port,
PEER_USER,
PEER_USER,
serverKeystore,
serverConfig.keyStorePassword,
serverTruststore,
crlCheckSoftFail,
maxMessageSize = maxMessageSize), nodeCert)
}
private fun SSLConfiguration.recreateNodeCaAndTlsCertificates(nodeCaCrlDistPoint: String, tlsCrlDistPoint: String?): X509Certificate {
val nodeKeyStore = loadNodeKeyStore()
val (nodeCert, nodeKeys) = nodeKeyStore.getCertificateAndKeyPair(X509Utilities.CORDA_CLIENT_CA)
val newNodeCert = replaceCrlDistPointCaCertificate(nodeCert, CertificateType.NODE_CA, INTERMEDIATE_CA.keyPair, nodeCaCrlDistPoint)
val nodeCertChain = listOf(newNodeCert, INTERMEDIATE_CA.certificate, *nodeKeyStore.getCertificateChain(X509Utilities.CORDA_CLIENT_CA).drop(2).toTypedArray())
nodeKeyStore.internal.deleteEntry(X509Utilities.CORDA_CLIENT_CA)
nodeKeyStore.save()
nodeKeyStore.update {
setPrivateKey(X509Utilities.CORDA_CLIENT_CA, nodeKeys.private, nodeCertChain)
}
val sslKeyStore = loadSslKeyStore()
val (tlsCert, tlsKeys) = sslKeyStore.getCertificateAndKeyPair(X509Utilities.CORDA_CLIENT_TLS)
val newTlsCert = replaceCrlDistPointCaCertificate(tlsCert, CertificateType.TLS, nodeKeys, tlsCrlDistPoint, X500Name.getInstance(ROOT_CA.certificate.subjectX500Principal.encoded))
val sslCertChain = listOf(newTlsCert, newNodeCert, INTERMEDIATE_CA.certificate, *sslKeyStore.getCertificateChain(X509Utilities.CORDA_CLIENT_TLS).drop(3).toTypedArray())
sslKeyStore.internal.deleteEntry(X509Utilities.CORDA_CLIENT_TLS)
sslKeyStore.save()
sslKeyStore.update {
setPrivateKey(X509Utilities.CORDA_CLIENT_TLS, tlsKeys.private, sslCertChain)
}
return newNodeCert
}
private fun replaceCrlDistPointCaCertificate(currentCaCert: X509Certificate, certType: CertificateType, issuerKeyPair: KeyPair, crlDistPoint: String?, crlIssuer: X500Name? = null): X509Certificate {
val signatureScheme = Crypto.findSignatureScheme(issuerKeyPair.private)
val provider = Crypto.findProvider(signatureScheme.providerName)
val issuerSigner = ContentSignerBuilder.build(signatureScheme, issuerKeyPair.private, provider)
val builder = X509Utilities.createPartialCertificate(
certType,
currentCaCert.issuerX500Principal,
issuerKeyPair.public,
currentCaCert.subjectX500Principal,
currentCaCert.publicKey,
Pair(Date(System.currentTimeMillis() - 5.minutes.toMillis()), Date(System.currentTimeMillis() + 10.days.toMillis())),
null
)
crlDistPoint?.let {
val distPointName = DistributionPointName(GeneralNames(GeneralName(GeneralName.uniformResourceIdentifier, it)))
val crlIssuerGeneralNames = crlIssuer?.let {
GeneralNames(GeneralName(crlIssuer))
}
val distPoint = DistributionPoint(distPointName, null, crlIssuerGeneralNames)
builder.addExtension(Extension.cRLDistributionPoints, false, CRLDistPoint(arrayOf(distPoint)))
}
return builder.build(issuerSigner).toJca()
}
@Path("crl")
inner class CrlServlet(private val server: CrlServer) {
private val SIGNATURE_ALGORITHM = "SHA256withECDSA"
private val NODE_CRL = "node.crl"
private val INTEMEDIATE_CRL = "intermediate.crl"
private val EMPTY_CRL = "empty.crl"
@GET
@Path("node.crl")
@Produces("application/pkcs7-crl")
fun getNodeCRL(): Response {
return Response.ok(createRevocationList(
INTERMEDIATE_CA.certificate,
INTERMEDIATE_CA.keyPair.private,
NODE_CRL,
false,
*revokedNodeCerts.toTypedArray()).encoded).build()
}
@GET
@Path("intermediate.crl")
@Produces("application/pkcs7-crl")
fun getIntermediateCRL(): Response {
return Response.ok(createRevocationList(
ROOT_CA.certificate,
ROOT_CA.keyPair.private,
INTEMEDIATE_CRL,
false,
*revokedIntermediateCerts.toTypedArray()).encoded).build()
}
@GET
@Path("empty.crl")
@Produces("application/pkcs7-crl")
fun getEmptyCRL(): Response {
return Response.ok(createRevocationList(
ROOT_CA.certificate,
ROOT_CA.keyPair.private,
EMPTY_CRL, true).encoded).build()
}
private fun createRevocationList(caCertificate: X509Certificate,
caPrivateKey: PrivateKey,
endpoint: String,
indirect: Boolean,
vararg serialNumbers: BigInteger): X509CRL {
println("Generating CRL for $endpoint")
val builder = JcaX509v2CRLBuilder(caCertificate.subjectX500Principal, Date(System.currentTimeMillis() - 1.minutes.toMillis()))
val extensionUtils = JcaX509ExtensionUtils()
builder.addExtension(Extension.authorityKeyIdentifier,
false, extensionUtils.createAuthorityKeyIdentifier(caCertificate))
val issuingDistPointName = GeneralName(
GeneralName.uniformResourceIdentifier,
"http://${server.hostAndPort.host}:${server.hostAndPort.port}/crl/$endpoint")
// This is required and needs to match the certificate settings with respect to being indirect
val issuingDistPoint = IssuingDistributionPoint(DistributionPointName(GeneralNames(issuingDistPointName)), indirect, false)
builder.addExtension(Extension.issuingDistributionPoint, true, issuingDistPoint)
builder.setNextUpdate(Date(System.currentTimeMillis() + 1.seconds.toMillis()))
serialNumbers.forEach {
builder.addCRLEntry(it, Date(System.currentTimeMillis() - 10.minutes.toMillis()), ReasonFlags.certificateHold)
}
val signer = JcaContentSignerBuilder(SIGNATURE_ALGORITHM).setProvider(BouncyCastleProvider.PROVIDER_NAME).build(caPrivateKey)
return JcaX509CRLConverter().setProvider(BouncyCastleProvider.PROVIDER_NAME).getCRL(builder.build(signer))
}
}
inner class CrlServer(hostAndPort: NetworkHostAndPort) : Closeable {
private val server: Server = Server(InetSocketAddress(hostAndPort.host, hostAndPort.port)).apply {
handler = HandlerCollection().apply {
addHandler(buildServletContextHandler())
}
}
val hostAndPort: NetworkHostAndPort
get() = server.connectors.mapNotNull { it as? ServerConnector }
.map { NetworkHostAndPort(it.host, it.localPort) }
.first()
override fun close() {
println("Shutting down network management web services...")
server.stop()
server.join()
}
fun start() {
server.start()
println("Network management web services started on $hostAndPort")
}
private fun buildServletContextHandler(): ServletContextHandler {
val crlServer = this
return ServletContextHandler().apply {
contextPath = "/"
val resourceConfig = ResourceConfig().apply {
register(CrlServlet(crlServer))
}
val jerseyServlet = ServletHolder(ServletContainer(resourceConfig)).apply { initOrder = 0 }
addServlet(jerseyServlet, "/*")
}
}
}
}
//package net.corda.node.amqp
//
//import com.nhaarman.mockito_kotlin.doReturn
//import com.nhaarman.mockito_kotlin.whenever
//import net.corda.core.crypto.Crypto
//import net.corda.core.identity.CordaX500Name
//import net.corda.core.internal.div
//import net.corda.core.toFuture
//import net.corda.core.utilities.NetworkHostAndPort
//import net.corda.core.utilities.days
//import net.corda.core.utilities.minutes
//import net.corda.core.utilities.seconds
//import net.corda.node.services.config.NodeConfiguration
//import net.corda.node.services.config.configureWithDevSSLCertificate
//import net.corda.nodeapi.internal.ArtemisMessagingComponent.Companion.P2P_PREFIX
//import net.corda.nodeapi.internal.ArtemisMessagingComponent.Companion.PEER_USER
//import net.corda.nodeapi.internal.config.SSLConfiguration
//import net.corda.nodeapi.internal.crypto.*
//import net.corda.nodeapi.internal.protonwrapper.messages.MessageStatus
//import net.corda.nodeapi.internal.protonwrapper.netty.AMQPClient
//import net.corda.nodeapi.internal.protonwrapper.netty.AMQPServer
//import net.corda.testing.core.*
//import net.corda.testing.internal.DEV_INTERMEDIATE_CA
//import net.corda.testing.internal.DEV_ROOT_CA
//import net.corda.testing.internal.rigorousMock
//import org.bouncycastle.asn1.x500.X500Name
//import org.bouncycastle.asn1.x509.*
//import org.bouncycastle.cert.jcajce.JcaX509CRLConverter
//import org.bouncycastle.cert.jcajce.JcaX509ExtensionUtils
//import org.bouncycastle.cert.jcajce.JcaX509v2CRLBuilder
//import org.bouncycastle.jce.provider.BouncyCastleProvider
//import org.bouncycastle.operator.jcajce.JcaContentSignerBuilder
//import org.eclipse.jetty.server.Server
//import org.eclipse.jetty.server.ServerConnector
//import org.eclipse.jetty.server.handler.HandlerCollection
//import org.eclipse.jetty.servlet.ServletContextHandler
//import org.eclipse.jetty.servlet.ServletHolder
//import org.glassfish.jersey.server.ResourceConfig
//import org.glassfish.jersey.servlet.ServletContainer
//import org.junit.After
//import org.junit.Before
//import org.junit.Rule
//import org.junit.Test
//import org.junit.rules.TemporaryFolder
//import java.io.Closeable
//import java.math.BigInteger
//import java.net.InetSocketAddress
//import java.security.KeyPair
//import java.security.PrivateKey
//import java.security.Security
//import java.security.cert.X509CRL
//import java.security.cert.X509Certificate
//import java.util.*
//import javax.ws.rs.GET
//import javax.ws.rs.Path
//import javax.ws.rs.Produces
//import javax.ws.rs.core.Response
//import kotlin.test.assertEquals
//
//class CertificateRevocationListNodeTests {
// @Rule
// @JvmField
// val temporaryFolder = TemporaryFolder()
//
// private val ROOT_CA = DEV_ROOT_CA
// private lateinit var INTERMEDIATE_CA: CertificateAndKeyPair
//
// private val serverPort = freePort()
//
// private lateinit var server: CrlServer
//
// private val revokedNodeCerts: MutableList<BigInteger> = mutableListOf()
// private val revokedIntermediateCerts: MutableList<BigInteger> = mutableListOf()
//
// private abstract class AbstractNodeConfiguration : NodeConfiguration
//
// @Before
// fun setUp() {
// Security.addProvider(BouncyCastleProvider())
// revokedNodeCerts.clear()
// server = CrlServer(NetworkHostAndPort("localhost", 0))
// server.start()
// INTERMEDIATE_CA = CertificateAndKeyPair(replaceCrlDistPointCaCertificate(
// DEV_INTERMEDIATE_CA.certificate,
// CertificateType.INTERMEDIATE_CA,
// ROOT_CA.keyPair,
// "http://${server.hostAndPort}/crl/intermediate.crl"), DEV_INTERMEDIATE_CA.keyPair)
// }
//
// @After
// fun tearDown() {
// server.close()
// revokedNodeCerts.clear()
// }
//
// @Test
// fun `Simple AMPQ Client to Server connection works`() {
// val crlCheckSoftFail = true
// val (amqpServer, _) = createServer(serverPort, crlCheckSoftFail = crlCheckSoftFail)
// amqpServer.use {
// amqpServer.start()
// val receiveSubs = amqpServer.onReceive.subscribe {
// assertEquals(BOB_NAME.toString(), it.sourceLegalName)
// assertEquals(P2P_PREFIX + "Test", it.topic)
// assertEquals("Test", String(it.payload))
// it.complete(true)
// }
// val (amqpClient, _) = createClient(serverPort, crlCheckSoftFail)
// amqpClient.use {
// val serverConnected = amqpServer.onConnection.toFuture()
// val clientConnected = amqpClient.onConnection.toFuture()
// amqpClient.start()
// val serverConnect = serverConnected.get()
// assertEquals(true, serverConnect.connected)
// val clientConnect = clientConnected.get()
// assertEquals(true, clientConnect.connected)
// val msg = amqpClient.createMessage("Test".toByteArray(),
// P2P_PREFIX + "Test",
// ALICE_NAME.toString(),
// emptyMap())
// amqpClient.write(msg)
// assertEquals(MessageStatus.Acknowledged, msg.onComplete.get())
// receiveSubs.unsubscribe()
// }
// }
// }
//
// @Test
// fun `AMPQ Client to Server connection fails when client's certificate is revoked`() {
// val crlCheckSoftFail = true
// val (amqpServer, _) = createServer(serverPort, crlCheckSoftFail = crlCheckSoftFail)
// amqpServer.use {
// amqpServer.start()
// amqpServer.onReceive.subscribe {
// it.complete(true)
// }
// val (amqpClient, clientCert) = createClient(serverPort, crlCheckSoftFail)
// revokedNodeCerts.add(clientCert.serialNumber)
// amqpClient.use {
// val serverConnected = amqpServer.onConnection.toFuture()
// amqpClient.onConnection.toFuture()
// amqpClient.start()
// val serverConnect = serverConnected.get()
// assertEquals(false, serverConnect.connected)
// }
// }
// }
//
// @Test
// fun `AMPQ Client to Server connection fails when servers's certificate is revoked`() {
// val crlCheckSoftFail = true
// val (amqpServer, serverCert) = createServer(serverPort, crlCheckSoftFail = crlCheckSoftFail)
// revokedNodeCerts.add(serverCert.serialNumber)
// amqpServer.use {
// amqpServer.start()
// amqpServer.onReceive.subscribe {
// it.complete(true)
// }
// val (amqpClient, _) = createClient(serverPort, crlCheckSoftFail)
// amqpClient.use {
// val serverConnected = amqpServer.onConnection.toFuture()
// amqpClient.onConnection.toFuture()
// amqpClient.start()
// val serverConnect = serverConnected.get()
// assertEquals(false, serverConnect.connected)
// }
// }
// }
//
// @Test
// fun `AMPQ Client to Server connection fails when servers's certificate is revoked and soft fail is enabled`() {
// val crlCheckSoftFail = true
// val (amqpServer, serverCert) = createServer(serverPort, crlCheckSoftFail = crlCheckSoftFail)
// revokedNodeCerts.add(serverCert.serialNumber)
// amqpServer.use {
// amqpServer.start()
// amqpServer.onReceive.subscribe {
// it.complete(true)
// }
// val (amqpClient, _) = createClient(serverPort, crlCheckSoftFail)
// amqpClient.use {
// val serverConnected = amqpServer.onConnection.toFuture()
// amqpClient.onConnection.toFuture()
// amqpClient.start()
// val serverConnect = serverConnected.get()
// assertEquals(false, serverConnect.connected)
// }
// }
// }
//
// @Test
// fun `AMPQ Client to Server connection succeeds when CRL cannot be obtained and soft fail is enabled`() {
// val crlCheckSoftFail = true
// val (amqpServer, serverCert) = createServer(
// serverPort,
// crlCheckSoftFail = crlCheckSoftFail,
// nodeCrlDistPoint = "http://${server.hostAndPort}/crl/invalid.crl")
// amqpServer.use {
// amqpServer.start()
// amqpServer.onReceive.subscribe {
// it.complete(true)
// }
// val (amqpClient, _) = createClient(
// serverPort,
// crlCheckSoftFail,
// nodeCrlDistPoint = "http://${server.hostAndPort}/crl/invalid.crl")
// amqpClient.use {
// val serverConnected = amqpServer.onConnection.toFuture()
// amqpClient.onConnection.toFuture()
// amqpClient.start()
// val serverConnect = serverConnected.get()
// assertEquals(true, serverConnect.connected)
// }
// }
// }
//
// @Test
// fun `Revocation status chceck fails when the CRL distribution point is not set and soft fail is disabled`() {
// val crlCheckSoftFail = false
// val (amqpServer, _) = createServer(
// serverPort,
// crlCheckSoftFail = crlCheckSoftFail,
// tlsCrlDistPoint = null)
// amqpServer.use {
// amqpServer.start()
// amqpServer.onReceive.subscribe {
// it.complete(true)
// }
// val (amqpClient, _) = createClient(
// serverPort,
// crlCheckSoftFail,
// tlsCrlDistPoint = null)
// amqpClient.use {
// val serverConnected = amqpServer.onConnection.toFuture()
// amqpClient.onConnection.toFuture()
// amqpClient.start()
// val serverConnect = serverConnected.get()
// assertEquals(false, serverConnect.connected)
// }
// }
// }
//
// @Test
// fun `Revocation status chceck succeds when the CRL distribution point is not set and soft fail is enabled`() {
// val crlCheckSoftFail = true
// val (amqpServer, _) = createServer(
// serverPort,
// crlCheckSoftFail = crlCheckSoftFail,
// tlsCrlDistPoint = null)
// amqpServer.use {
// amqpServer.start()
// amqpServer.onReceive.subscribe {
// it.complete(true)
// }
// val (amqpClient, _) = createClient(
// serverPort,
// crlCheckSoftFail,
// tlsCrlDistPoint = null)
// amqpClient.use {
// val serverConnected = amqpServer.onConnection.toFuture()
// amqpClient.onConnection.toFuture()
// amqpClient.start()
// val serverConnect = serverConnected.get()
// assertEquals(true, serverConnect.connected)
// }
// }
// }
//
// private fun createClient(targetPort: Int,
// crlCheckSoftFail: Boolean,
// nodeCrlDistPoint: String = "http://${server.hostAndPort}/crl/node.crl",
// tlsCrlDistPoint: String? = "http://${server.hostAndPort}/crl/empty.crl",
// maxMessageSize: Int = MAX_MESSAGE_SIZE): Pair<AMQPClient, X509Certificate> {
// val clientConfig = rigorousMock<AbstractNodeConfiguration>().also {
// doReturn(temporaryFolder.root.toPath() / "client").whenever(it).baseDirectory
// doReturn(BOB_NAME).whenever(it).myLegalName
// doReturn("trustpass").whenever(it).trustStorePassword
// doReturn("cordacadevpass").whenever(it).keyStorePassword
// doReturn(crlCheckSoftFail).whenever(it).crlCheckSoftFail
// }
// clientConfig.configureWithDevSSLCertificate()
// val nodeCert = clientConfig.recreateNodeCaAndTlsCertificates(nodeCrlDistPoint, tlsCrlDistPoint)
// val clientTruststore = clientConfig.loadTrustStore().internal
// val clientKeystore = clientConfig.loadSslKeyStore().internal
// return Pair(AMQPClient(
// listOf(NetworkHostAndPort("localhost", targetPort)),
// setOf(ALICE_NAME, CHARLIE_NAME),
// PEER_USER,
// PEER_USER,
// clientKeystore,
// clientConfig.keyStorePassword,
// clientTruststore,
// crlCheckSoftFail,
// maxMessageSize = maxMessageSize), nodeCert)
// }
//
// private fun createServer(port: Int, name: CordaX500Name = ALICE_NAME,
// crlCheckSoftFail: Boolean,
// nodeCrlDistPoint: String = "http://${server.hostAndPort}/crl/node.crl",
// tlsCrlDistPoint: String? = "http://${server.hostAndPort}/crl/empty.crl",
// maxMessageSize: Int = MAX_MESSAGE_SIZE): Pair<AMQPServer, X509Certificate> {
// val serverConfig = rigorousMock<AbstractNodeConfiguration>().also {
// doReturn(temporaryFolder.root.toPath() / "server").whenever(it).baseDirectory
// doReturn(name).whenever(it).myLegalName
// doReturn("trustpass").whenever(it).trustStorePassword
// doReturn("cordacadevpass").whenever(it).keyStorePassword
// doReturn(crlCheckSoftFail).whenever(it).crlCheckSoftFail
// }
// serverConfig.configureWithDevSSLCertificate()
// val nodeCert = serverConfig.recreateNodeCaAndTlsCertificates(nodeCrlDistPoint, tlsCrlDistPoint)
// val serverTruststore = serverConfig.loadTrustStore().internal
// val serverKeystore = serverConfig.loadSslKeyStore().internal
// return Pair(AMQPServer(
// "0.0.0.0",
// port,
// PEER_USER,
// PEER_USER,
// serverKeystore,
// serverConfig.keyStorePassword,
// serverTruststore,
// crlCheckSoftFail,
// maxMessageSize = maxMessageSize), nodeCert)
// }
//
// private fun SSLConfiguration.recreateNodeCaAndTlsCertificates(nodeCaCrlDistPoint: String, tlsCrlDistPoint: String?): X509Certificate {
// val nodeKeyStore = loadNodeKeyStore()
// val (nodeCert, nodeKeys) = nodeKeyStore.getCertificateAndKeyPair(X509Utilities.CORDA_CLIENT_CA)
// val newNodeCert = replaceCrlDistPointCaCertificate(nodeCert, CertificateType.NODE_CA, INTERMEDIATE_CA.keyPair, nodeCaCrlDistPoint)
// val nodeCertChain = listOf(newNodeCert, INTERMEDIATE_CA.certificate, *nodeKeyStore.getCertificateChain(X509Utilities.CORDA_CLIENT_CA).drop(2).toTypedArray())
// nodeKeyStore.internal.deleteEntry(X509Utilities.CORDA_CLIENT_CA)
// nodeKeyStore.save()
// nodeKeyStore.update {
// setPrivateKey(X509Utilities.CORDA_CLIENT_CA, nodeKeys.private, nodeCertChain)
// }
// val sslKeyStore = loadSslKeyStore()
// val (tlsCert, tlsKeys) = sslKeyStore.getCertificateAndKeyPair(X509Utilities.CORDA_CLIENT_TLS)
// val newTlsCert = replaceCrlDistPointCaCertificate(tlsCert, CertificateType.TLS, nodeKeys, tlsCrlDistPoint, X500Name.getInstance(ROOT_CA.certificate.subjectX500Principal.encoded))
// val sslCertChain = listOf(newTlsCert, newNodeCert, INTERMEDIATE_CA.certificate, *sslKeyStore.getCertificateChain(X509Utilities.CORDA_CLIENT_TLS).drop(3).toTypedArray())
// sslKeyStore.internal.deleteEntry(X509Utilities.CORDA_CLIENT_TLS)
// sslKeyStore.save()
// sslKeyStore.update {
// setPrivateKey(X509Utilities.CORDA_CLIENT_TLS, tlsKeys.private, sslCertChain)
// }
// return newNodeCert
// }
//
// private fun replaceCrlDistPointCaCertificate(currentCaCert: X509Certificate, certType: CertificateType, issuerKeyPair: KeyPair, crlDistPoint: String?, crlIssuer: X500Name? = null): X509Certificate {
// val signatureScheme = Crypto.findSignatureScheme(issuerKeyPair.private)
// val provider = Crypto.findProvider(signatureScheme.providerName)
// val issuerSigner = ContentSignerBuilder.build(signatureScheme, issuerKeyPair.private, provider)
// val builder = X509Utilities.createPartialCertificate(
// certType,
// currentCaCert.issuerX500Principal,
// issuerKeyPair.public,
// currentCaCert.subjectX500Principal,
// currentCaCert.publicKey,
// Pair(Date(System.currentTimeMillis() - 5.minutes.toMillis()), Date(System.currentTimeMillis() + 10.days.toMillis())),
// null
// )
// crlDistPoint?.let {
// val distPointName = DistributionPointName(GeneralNames(GeneralName(GeneralName.uniformResourceIdentifier, it)))
// val crlIssuerGeneralNames = crlIssuer?.let {
// GeneralNames(GeneralName(crlIssuer))
// }
// val distPoint = DistributionPoint(distPointName, null, crlIssuerGeneralNames)
// builder.addExtension(Extension.cRLDistributionPoints, false, CRLDistPoint(arrayOf(distPoint)))
// }
// return builder.build(issuerSigner).toJca()
// }
//
// @Path("crl")
// inner class CrlServlet(private val server: CrlServer) {
//
// private val SIGNATURE_ALGORITHM = "SHA256withECDSA"
// private val NODE_CRL = "node.crl"
// private val INTEMEDIATE_CRL = "intermediate.crl"
// private val EMPTY_CRL = "empty.crl"
//
// @GET
// @Path("node.crl")
// @Produces("application/pkcs7-crl")
// fun getNodeCRL(): Response {
// return Response.ok(createRevocationList(
// INTERMEDIATE_CA.certificate,
// INTERMEDIATE_CA.keyPair.private,
// NODE_CRL,
// false,
// *revokedNodeCerts.toTypedArray()).encoded).build()
// }
//
// @GET
// @Path("intermediate.crl")
// @Produces("application/pkcs7-crl")
// fun getIntermediateCRL(): Response {
// return Response.ok(createRevocationList(
// ROOT_CA.certificate,
// ROOT_CA.keyPair.private,
// INTEMEDIATE_CRL,
// false,
// *revokedIntermediateCerts.toTypedArray()).encoded).build()
// }
//
// @GET
// @Path("empty.crl")
// @Produces("application/pkcs7-crl")
// fun getEmptyCRL(): Response {
// return Response.ok(createRevocationList(
// ROOT_CA.certificate,
// ROOT_CA.keyPair.private,
// EMPTY_CRL, true).encoded).build()
// }
//
// private fun createRevocationList(caCertificate: X509Certificate,
// caPrivateKey: PrivateKey,
// endpoint: String,
// indirect: Boolean,
// vararg serialNumbers: BigInteger): X509CRL {
// println("Generating CRL for $endpoint")
// val builder = JcaX509v2CRLBuilder(caCertificate.subjectX500Principal, Date(System.currentTimeMillis() - 1.minutes.toMillis()))
// val extensionUtils = JcaX509ExtensionUtils()
// builder.addExtension(Extension.authorityKeyIdentifier,
// false, extensionUtils.createAuthorityKeyIdentifier(caCertificate))
// val issuingDistPointName = GeneralName(
// GeneralName.uniformResourceIdentifier,
// "http://${server.hostAndPort.host}:${server.hostAndPort.port}/crl/$endpoint")
// // This is required and needs to match the certificate settings with respect to being indirect
// val issuingDistPoint = IssuingDistributionPoint(DistributionPointName(GeneralNames(issuingDistPointName)), indirect, false)
// builder.addExtension(Extension.issuingDistributionPoint, true, issuingDistPoint)
// builder.setNextUpdate(Date(System.currentTimeMillis() + 1.seconds.toMillis()))
// serialNumbers.forEach {
// builder.addCRLEntry(it, Date(System.currentTimeMillis() - 10.minutes.toMillis()), ReasonFlags.certificateHold)
// }
// val signer = JcaContentSignerBuilder(SIGNATURE_ALGORITHM).setProvider(BouncyCastleProvider.PROVIDER_NAME).build(caPrivateKey)
// return JcaX509CRLConverter().setProvider(BouncyCastleProvider.PROVIDER_NAME).getCRL(builder.build(signer))
// }
// }
//
// inner class CrlServer(hostAndPort: NetworkHostAndPort) : Closeable {
//
// private val server: Server = Server(InetSocketAddress(hostAndPort.host, hostAndPort.port)).apply {
// handler = HandlerCollection().apply {
// addHandler(buildServletContextHandler())
// }
// }
//
// val hostAndPort: NetworkHostAndPort
// get() = server.connectors.mapNotNull { it as? ServerConnector }
// .map { NetworkHostAndPort(it.host, it.localPort) }
// .first()
//
// override fun close() {
// println("Shutting down network management web services...")
// server.stop()
// server.join()
// }
//
// fun start() {
// server.start()
// println("Network management web services started on $hostAndPort")
// }
//
// private fun buildServletContextHandler(): ServletContextHandler {
// val crlServer = this
// return ServletContextHandler().apply {
// contextPath = "/"
// val resourceConfig = ResourceConfig().apply {
// register(CrlServlet(crlServer))
// }
// val jerseyServlet = ServletHolder(ServletContainer(resourceConfig)).apply { initOrder = 0 }
// addServlet(jerseyServlet, "/*")
// }
// }
// }
//}

View File

@ -2,7 +2,13 @@ package net.corda.node.services
import com.nhaarman.mockito_kotlin.doReturn
import com.nhaarman.mockito_kotlin.whenever
import net.corda.core.contracts.*
import net.corda.core.CordaRuntimeException
import net.corda.core.contracts.Contract
import net.corda.core.contracts.ContractState
import net.corda.core.contracts.PartyAndReference
import net.corda.core.contracts.StateAndRef
import net.corda.core.contracts.StateRef
import net.corda.core.contracts.TransactionState
import net.corda.core.cordapp.CordappProvider
import net.corda.core.flows.FlowLogic
import net.corda.core.identity.CordaX500Name
@ -22,13 +28,13 @@ import net.corda.core.utilities.contextLogger
import net.corda.core.utilities.getOrThrow
import net.corda.node.internal.cordapp.CordappLoader
import net.corda.node.internal.cordapp.CordappProviderImpl
import net.corda.nodeapi.exceptions.InternalNodeException
import net.corda.testing.common.internal.testNetworkParameters
import net.corda.testing.core.DUMMY_BANK_A_NAME
import net.corda.testing.core.DUMMY_NOTARY_NAME
import net.corda.testing.core.SerializationEnvironmentRule
import net.corda.testing.core.TestIdentity
import net.corda.testing.driver.DriverDSL
import net.corda.testing.driver.DriverParameters
import net.corda.testing.driver.NodeHandle
import net.corda.testing.driver.driver
import net.corda.testing.internal.MockCordappConfigProvider
@ -107,10 +113,10 @@ class AttachmentLoadingTests {
@Test
fun `test that attachments retrieved over the network are not used for code`() = withoutTestSerialization {
driver {
driver(DriverParameters(startNodesInProcess = true)) {
installIsolatedCordappTo(bankAName)
val (bankA, bankB) = createTwoNodes()
assertFailsWith<InternalNodeException> {
assertFailsWith<CordaRuntimeException>("Party C=CH,L=Zurich,O=BankB rejected session request: Don't know net.corda.finance.contracts.isolated.IsolatedDummyFlow\$Initiator") {
bankA.rpc.startFlowDynamic(flowInitiatorClass, bankB.nodeInfo.legalIdentities.first()).returnValue.getOrThrow()
}
}

View File

@ -2,13 +2,13 @@ package net.corda.node.services.rpc
import co.paralleluniverse.fibers.Suspendable
import net.corda.ClientRelevantException
import net.corda.core.CordaRuntimeException
import net.corda.core.flows.*
import net.corda.core.identity.Party
import net.corda.core.messaging.startFlow
import net.corda.core.utilities.getOrThrow
import net.corda.core.utilities.unwrap
import net.corda.node.services.Permissions
import net.corda.nodeapi.exceptions.InternalNodeException
import net.corda.testing.core.ALICE_NAME
import net.corda.testing.core.BOB_NAME
import net.corda.testing.core.singleIdentity
@ -16,7 +16,7 @@ import net.corda.testing.driver.DriverParameters
import net.corda.testing.driver.NodeParameters
import net.corda.testing.driver.driver
import net.corda.testing.node.User
import org.assertj.core.api.Assertions.assertThatCode
import org.assertj.core.api.Assertions.assertThatThrownBy
import org.assertj.core.api.AssertionsForInterfaceTypes.assertThat
import org.hibernate.exception.GenericJDBCException
import org.junit.Test
@ -33,11 +33,10 @@ class RpcExceptionHandlingTest {
val node = startNode(NodeParameters(rpcUsers = users)).getOrThrow()
assertThatCode { node.rpc.startFlow(::Flow).returnValue.getOrThrow() }.isInstanceOfSatisfying(InternalNodeException::class.java) { exception ->
assertThatThrownBy { node.rpc.startFlow(::Flow).returnValue.getOrThrow() }.isInstanceOfSatisfying(CordaRuntimeException::class.java) { exception ->
assertThat(exception).hasNoCause()
assertThat(exception.stackTrace).isEmpty()
assertThat(exception.message).isEqualTo(InternalNodeException.defaultMessage())
}
}
}
@ -49,7 +48,7 @@ class RpcExceptionHandlingTest {
val node = startNode(NodeParameters(rpcUsers = users)).getOrThrow()
val clientRelevantMessage = "This is for the players!"
assertThatCode { node.rpc.startFlow(::ClientRelevantErrorFlow, clientRelevantMessage).returnValue.getOrThrow() }.isInstanceOfSatisfying(ClientRelevantException::class.java) { exception ->
assertThatThrownBy { node.rpc.startFlow(::ClientRelevantErrorFlow, clientRelevantMessage).returnValue.getOrThrow() }.isInstanceOfSatisfying(CordaRuntimeException::class.java) { exception ->
assertThat(exception).hasNoCause()
assertThat(exception.stackTrace).isEmpty()
@ -63,7 +62,7 @@ class RpcExceptionHandlingTest {
driver(DriverParameters(startNodesInProcess = true, notarySpecs = emptyList())) {
val node = startNode(NodeParameters(rpcUsers = users)).getOrThrow()
val exceptionMessage = "Flow error!"
assertThatCode { node.rpc.startFlow(::FlowExceptionFlow, exceptionMessage).returnValue.getOrThrow() }
assertThatThrownBy { node.rpc.startFlow(::FlowExceptionFlow, exceptionMessage).returnValue.getOrThrow() }
.isInstanceOfSatisfying(FlowException::class.java) { exception ->
assertThat(exception).hasNoCause()
assertThat(exception.stackTrace).isEmpty()
@ -79,12 +78,10 @@ class RpcExceptionHandlingTest {
val nodeA = startNode(NodeParameters(providedName = ALICE_NAME, rpcUsers = users)).getOrThrow()
val nodeB = startNode(NodeParameters(providedName = BOB_NAME, rpcUsers = users)).getOrThrow()
assertThatCode { nodeA.rpc.startFlow(::InitFlow, nodeB.nodeInfo.singleIdentity()).returnValue.getOrThrow() }
.isInstanceOfSatisfying(InternalNodeException::class.java) { exception ->
assertThatThrownBy { nodeA.rpc.startFlow(::InitFlow, nodeB.nodeInfo.singleIdentity()).returnValue.getOrThrow() }.isInstanceOfSatisfying(CordaRuntimeException::class.java) { exception ->
assertThat(exception).hasNoCause()
assertThat(exception.stackTrace).isEmpty()
assertThat(exception.message).isEqualTo(InternalNodeException.defaultMessage())
}
}
}

View File

@ -37,6 +37,8 @@ import net.corda.node.internal.cordapp.CordappConfigFileProvider
import net.corda.node.internal.cordapp.CordappLoader
import net.corda.node.internal.cordapp.CordappProviderImpl
import net.corda.node.internal.cordapp.CordappProviderInternal
import net.corda.node.internal.rpc.proxies.AuthenticatedRpcOpsProxy
import net.corda.node.internal.rpc.proxies.ExceptionSerialisingRpcOpsProxy
import net.corda.node.internal.security.RPCSecurityManager
import net.corda.node.services.ContractUpgradeHandler
import net.corda.node.services.FinalityHandler
@ -168,7 +170,10 @@ abstract class AbstractNode(val configuration: NodeConfiguration,
/** The implementation of the [CordaRPCOps] interface used by this node. */
open fun makeRPCOps(flowStarter: FlowStarter, database: CordaPersistence, smm: StateMachineManager): CordaRPCOps {
return SecureCordaRPCOps(services, smm, database, flowStarter, { shutdownExecutor.submit { stop() } })
val ops: CordaRPCOps = CordaRPCOpsImpl(services, smm, database, flowStarter, { shutdownExecutor.submit { stop() } })
// Mind that order is relevant here.
val proxies = listOf(::AuthenticatedRpcOpsProxy, ::ExceptionSerialisingRpcOpsProxy)
return proxies.fold(ops) { delegate, decorate -> decorate(delegate) }
}
private fun initCertificate() {

View File

@ -25,12 +25,12 @@ import net.corda.core.node.services.Vault
import net.corda.core.node.services.vault.*
import net.corda.core.serialization.serialize
import net.corda.core.transactions.SignedTransaction
import net.corda.core.utilities.contextLogger
import net.corda.core.utilities.getOrThrow
import net.corda.node.services.api.FlowStarter
import net.corda.node.services.api.ServiceHubInternal
import net.corda.node.services.messaging.context
import net.corda.node.services.statemachine.StateMachineManager
import net.corda.nodeapi.exceptions.NonRpcFlowException
import net.corda.nodeapi.exceptions.RejectedCommandException
import net.corda.nodeapi.internal.persistence.CordaPersistence
import rx.Observable
@ -173,7 +173,7 @@ internal class CordaRPCOpsImpl(
}
private fun <T> startFlow(logicType: Class<out FlowLogic<T>>, args: Array<out Any?>): FlowStateMachine<T> {
require(logicType.isAnnotationPresent(StartableByRPC::class.java)) { "${logicType.name} was not designed for RPC" }
if (!logicType.isAnnotationPresent(StartableByRPC::class.java)) throw NonRpcFlowException(logicType)
if (isFlowsDrainingModeEnabled()) {
throw RejectedCommandException("Node is draining before shutdown. Cannot start new flows through RPC.")
}
@ -325,8 +325,4 @@ internal class CordaRPCOpsImpl(
is InvocationOrigin.Scheduled -> FlowInitiator.Scheduled((origin as InvocationOrigin.Scheduled).scheduledState)
}
}
companion object {
private val log = contextLogger()
}
}

View File

@ -0,0 +1,21 @@
package net.corda.node.internal
import java.lang.reflect.InvocationHandler
import java.lang.reflect.InvocationTargetException
import java.lang.reflect.Method
/**
* Helps writing correct [InvocationHandler]s.
*/
internal interface InvocationHandlerTemplate : InvocationHandler {
val delegate: Any
override fun invoke(proxy: Any, method: Method, arguments: Array<out Any?>?): Any? {
val args = arguments ?: emptyArray()
return try {
method.invoke(delegate, *args)
} catch (e: InvocationTargetException) {
throw e.targetException
}
}
}

View File

@ -82,7 +82,7 @@ open class Node(configuration: NodeConfiguration,
fun printWarning(message: String) {
Emoji.renderIfSupported {
println("${Emoji.warningSign} ATTENTION: ${message}")
println("${Emoji.warningSign} ATTENTION: $message")
}
staticLog.warn(message)
}
@ -183,7 +183,7 @@ open class Node(configuration: NodeConfiguration,
val serverAddress = configuration.messagingServerAddress
?: NetworkHostAndPort("localhost", configuration.p2pAddress.port)
val rpcServerAddresses = if (configuration.rpcOptions.standAloneBroker) {
BrokerAddresses(configuration.rpcOptions.address!!, configuration.rpcOptions.adminAddress)
BrokerAddresses(configuration.rpcOptions.address, configuration.rpcOptions.adminAddress)
} else {
startLocalRpcBroker()
}
@ -293,10 +293,7 @@ open class Node(configuration: NodeConfiguration,
// Start up the MQ clients.
internalRpcMessagingClient?.run {
runOnStop += this::close
when (rpcOps) {
is SecureCordaRPCOps -> init(RpcExceptionHandlingProxy(rpcOps), securityManager)
else -> init(rpcOps, securityManager)
}
init(rpcOps, securityManager)
}
verifierMessagingClient?.run {
runOnStop += this::stop
@ -385,7 +382,7 @@ open class Node(configuration: NodeConfiguration,
SerializationFactoryImpl().apply {
registerScheme(AMQPServerSerializationScheme(cordappLoader.cordapps))
registerScheme(AMQPClientSerializationScheme(cordappLoader.cordapps))
registerScheme(KryoServerSerializationScheme() )
registerScheme(KryoServerSerializationScheme())
},
p2pContext = AMQP_P2P_CONTEXT.withClassLoader(classloader),
rpcServerContext = AMQP_RPC_SERVER_CONTEXT.withClassLoader(classloader),

View File

@ -1,191 +0,0 @@
package net.corda.node.internal
import net.corda.client.rpc.PermissionException
import net.corda.core.contracts.ContractState
import net.corda.core.crypto.SecureHash
import net.corda.core.flows.FlowLogic
import net.corda.core.flows.StateMachineRunId
import net.corda.core.identity.AbstractParty
import net.corda.core.identity.CordaX500Name
import net.corda.core.identity.Party
import net.corda.core.messaging.CordaRPCOps
import net.corda.core.messaging.DataFeed
import net.corda.core.messaging.ParametersUpdateInfo
import net.corda.core.node.NodeInfo
import net.corda.core.node.services.AttachmentId
import net.corda.core.node.services.NetworkMapCache
import net.corda.core.node.services.Vault
import net.corda.core.node.services.vault.*
import net.corda.node.services.messaging.RpcAuthContext
import java.io.InputStream
import java.security.PublicKey
// TODO change to KFunction reference after Kotlin fixes https://youtrack.jetbrains.com/issue/KT-12140
class RpcAuthorisationProxy(private val implementation: CordaRPCOps, private val context: () -> RpcAuthContext) : CordaRPCOps {
override fun networkParametersFeed(): DataFeed<ParametersUpdateInfo?, ParametersUpdateInfo> = guard("networkParametersFeed") {
implementation.networkParametersFeed()
}
override fun acceptNewNetworkParameters(parametersHash: SecureHash) = guard("acceptNewNetworkParameters") {
implementation.acceptNewNetworkParameters(parametersHash)
}
override fun uploadAttachmentWithMetadata(jar: InputStream, uploader: String, filename: String): SecureHash = guard("uploadAttachmentWithMetadata") {
implementation.uploadAttachmentWithMetadata(jar, uploader, filename)
}
override fun queryAttachments(query: AttachmentQueryCriteria, sorting: AttachmentSort?): List<AttachmentId> = guard("queryAttachments") {
implementation.queryAttachments(query, sorting)
}
override fun stateMachinesSnapshot() = guard("stateMachinesSnapshot") {
implementation.stateMachinesSnapshot()
}
override fun stateMachinesFeed() = guard("stateMachinesFeed") {
implementation.stateMachinesFeed()
}
override fun <T : ContractState> vaultQueryBy(criteria: QueryCriteria, paging: PageSpecification, sorting: Sort, contractStateType: Class<out T>) = guard("vaultQueryBy") {
implementation.vaultQueryBy(criteria, paging, sorting, contractStateType)
}
override fun <T : ContractState> vaultTrackBy(criteria: QueryCriteria, paging: PageSpecification, sorting: Sort, contractStateType: Class<out T>) = guard("vaultTrackBy") {
implementation.vaultTrackBy(criteria, paging, sorting, contractStateType)
}
@Suppress("DEPRECATION", "OverridingDeprecatedMember")
override fun internalVerifiedTransactionsSnapshot() = guard("internalVerifiedTransactionsSnapshot", implementation::internalVerifiedTransactionsSnapshot)
@Suppress("DEPRECATION", "OverridingDeprecatedMember")
override fun internalVerifiedTransactionsFeed() = guard("internalVerifiedTransactionsFeed", implementation::internalVerifiedTransactionsFeed)
override fun stateMachineRecordedTransactionMappingSnapshot() = guard("stateMachineRecordedTransactionMappingSnapshot", implementation::stateMachineRecordedTransactionMappingSnapshot)
override fun stateMachineRecordedTransactionMappingFeed() = guard("stateMachineRecordedTransactionMappingFeed", implementation::stateMachineRecordedTransactionMappingFeed)
override fun networkMapSnapshot(): List<NodeInfo> = guard("networkMapSnapshot", implementation::networkMapSnapshot)
override fun networkMapFeed(): DataFeed<List<NodeInfo>, NetworkMapCache.MapChange> = guard("networkMapFeed", implementation::networkMapFeed)
override fun <T> startFlowDynamic(logicType: Class<out FlowLogic<T>>, vararg args: Any?) = guard("startFlowDynamic", listOf(logicType)) {
implementation.startFlowDynamic(logicType, *args)
}
override fun <T> startTrackedFlowDynamic(logicType: Class<out FlowLogic<T>>, vararg args: Any?) = guard("startTrackedFlowDynamic", listOf(logicType)) {
implementation.startTrackedFlowDynamic(logicType, *args)
}
override fun killFlow(id: StateMachineRunId): Boolean = guard("killFlow") {
return implementation.killFlow(id)
}
override fun nodeInfo(): NodeInfo = guard("nodeInfo", implementation::nodeInfo)
override fun notaryIdentities(): List<Party> = guard("notaryIdentities", implementation::notaryIdentities)
override fun addVaultTransactionNote(txnId: SecureHash, txnNote: String) = guard("addVaultTransactionNote") {
implementation.addVaultTransactionNote(txnId, txnNote)
}
override fun getVaultTransactionNotes(txnId: SecureHash): Iterable<String> = guard("getVaultTransactionNotes") {
implementation.getVaultTransactionNotes(txnId)
}
override fun attachmentExists(id: SecureHash) = guard("attachmentExists") {
implementation.attachmentExists(id)
}
override fun openAttachment(id: SecureHash) = guard("openAttachment") {
implementation.openAttachment(id)
}
override fun uploadAttachment(jar: InputStream) = guard("uploadAttachment") {
implementation.uploadAttachment(jar)
}
override fun currentNodeTime() = guard("currentNodeTime", implementation::currentNodeTime)
override fun waitUntilNetworkReady() = guard("waitUntilNetworkReady", implementation::waitUntilNetworkReady)
override fun wellKnownPartyFromAnonymous(party: AbstractParty) = guard("wellKnownPartyFromAnonymous") {
implementation.wellKnownPartyFromAnonymous(party)
}
override fun partyFromKey(key: PublicKey) = guard("partyFromKey") {
implementation.partyFromKey(key)
}
override fun wellKnownPartyFromX500Name(x500Name: CordaX500Name) = guard("wellKnownPartyFromX500Name") {
implementation.wellKnownPartyFromX500Name(x500Name)
}
override fun notaryPartyFromX500Name(x500Name: CordaX500Name) = guard("notaryPartyFromX500Name") {
implementation.notaryPartyFromX500Name(x500Name)
}
override fun partiesFromName(query: String, exactMatch: Boolean) = guard("partiesFromName") {
implementation.partiesFromName(query, exactMatch)
}
override fun registeredFlows() = guard("registeredFlows", implementation::registeredFlows)
override fun nodeInfoFromParty(party: AbstractParty) = guard("nodeInfoFromParty") {
implementation.nodeInfoFromParty(party)
}
override fun clearNetworkMapCache() = guard("clearNetworkMapCache", implementation::clearNetworkMapCache)
override fun <T : ContractState> vaultQuery(contractStateType: Class<out T>): Vault.Page<T> = guard("vaultQuery") {
implementation.vaultQuery(contractStateType)
}
override fun <T : ContractState> vaultQueryByCriteria(criteria: QueryCriteria, contractStateType: Class<out T>): Vault.Page<T> = guard("vaultQueryByCriteria") {
implementation.vaultQueryByCriteria(criteria, contractStateType)
}
override fun <T : ContractState> vaultQueryByWithPagingSpec(contractStateType: Class<out T>, criteria: QueryCriteria, paging: PageSpecification): Vault.Page<T> = guard("vaultQueryByWithPagingSpec") {
implementation.vaultQueryByWithPagingSpec(contractStateType, criteria, paging)
}
override fun <T : ContractState> vaultQueryByWithSorting(contractStateType: Class<out T>, criteria: QueryCriteria, sorting: Sort): Vault.Page<T> = guard("vaultQueryByWithSorting") {
implementation.vaultQueryByWithSorting(contractStateType, criteria, sorting)
}
override fun <T : ContractState> vaultTrack(contractStateType: Class<out T>): DataFeed<Vault.Page<T>, Vault.Update<T>> = guard("vaultTrack") {
implementation.vaultTrack(contractStateType)
}
override fun <T : ContractState> vaultTrackByCriteria(contractStateType: Class<out T>, criteria: QueryCriteria): DataFeed<Vault.Page<T>, Vault.Update<T>> = guard("vaultTrackByCriteria") {
implementation.vaultTrackByCriteria(contractStateType, criteria)
}
override fun <T : ContractState> vaultTrackByWithPagingSpec(contractStateType: Class<out T>, criteria: QueryCriteria, paging: PageSpecification): DataFeed<Vault.Page<T>, Vault.Update<T>> = guard("vaultTrackByWithPagingSpec") {
implementation.vaultTrackByWithPagingSpec(contractStateType, criteria, paging)
}
override fun <T : ContractState> vaultTrackByWithSorting(contractStateType: Class<out T>, criteria: QueryCriteria, sorting: Sort): DataFeed<Vault.Page<T>, Vault.Update<T>> = guard("vaultTrackByWithSorting") {
implementation.vaultTrackByWithSorting(contractStateType, criteria, sorting)
}
override fun setFlowsDrainingModeEnabled(enabled: Boolean) = guard("setFlowsDrainingModeEnabled") {
implementation.setFlowsDrainingModeEnabled(enabled)
}
override fun isFlowsDrainingModeEnabled(): Boolean = guard("isFlowsDrainingModeEnabled", implementation::isFlowsDrainingModeEnabled)
override fun shutdown() = guard("shutdown", implementation::shutdown)
// TODO change to KFunction reference after Kotlin fixes https://youtrack.jetbrains.com/issue/KT-12140
private inline fun <RESULT> guard(methodName: String, action: () -> RESULT) = guard(methodName, emptyList(), action)
// TODO change to KFunction reference after Kotlin fixes https://youtrack.jetbrains.com/issue/KT-12140
private inline fun <RESULT> guard(methodName: String, args: List<Class<*>>, action: () -> RESULT) : RESULT {
if (!context().isPermitted(methodName, *(args.map { it.name }.toTypedArray()))) {
throw PermissionException("User not authorized to perform RPC call $methodName with target $args")
}
else {
return action()
}
}
}

View File

@ -1,154 +0,0 @@
package net.corda.node.internal
import net.corda.core.concurrent.CordaFuture
import net.corda.core.contracts.ContractState
import net.corda.core.crypto.SecureHash
import net.corda.core.doOnError
import net.corda.core.flows.FlowLogic
import net.corda.core.flows.StateMachineRunId
import net.corda.core.identity.AbstractParty
import net.corda.core.identity.CordaX500Name
import net.corda.core.internal.concurrent.doOnError
import net.corda.core.internal.concurrent.mapError
import net.corda.core.mapErrors
import net.corda.core.messaging.CordaRPCOps
import net.corda.core.messaging.DataFeed
import net.corda.core.messaging.FlowHandle
import net.corda.core.messaging.FlowProgressHandle
import net.corda.core.node.services.vault.*
import net.corda.core.utilities.loggerFor
import net.corda.nodeapi.exceptions.InternalNodeException
import net.corda.nodeapi.exceptions.adapters.InternalObfuscatingFlowHandle
import net.corda.nodeapi.exceptions.adapters.InternalObfuscatingFlowProgressHandle
import java.io.InputStream
import java.security.PublicKey
class RpcExceptionHandlingProxy(private val delegate: SecureCordaRPCOps) : CordaRPCOps {
private companion object {
private val logger = loggerFor<RpcExceptionHandlingProxy>()
}
override val protocolVersion: Int get() = delegate.protocolVersion
override fun <T> startFlowDynamic(logicType: Class<out FlowLogic<T>>, vararg args: Any?): FlowHandle<T> = wrap {
val handle = delegate.startFlowDynamic(logicType, *args)
val result = InternalObfuscatingFlowHandle(handle)
result.returnValue.doOnError { error -> logger.error(error.message, error) }
result
}
override fun <T> startTrackedFlowDynamic(logicType: Class<out FlowLogic<T>>, vararg args: Any?): FlowProgressHandle<T> = wrap {
val handle = delegate.startTrackedFlowDynamic(logicType, *args)
val result = InternalObfuscatingFlowProgressHandle(handle)
result.returnValue.doOnError { error -> logger.error(error.message, error) }
result
}
override fun waitUntilNetworkReady() = wrapFuture(delegate::waitUntilNetworkReady)
override fun stateMachinesFeed() = wrapFeed(delegate::stateMachinesFeed)
override fun <T : ContractState> vaultTrackBy(criteria: QueryCriteria, paging: PageSpecification, sorting: Sort, contractStateType: Class<out T>) = wrapFeed { delegate.vaultTrackBy(criteria, paging, sorting, contractStateType) }
override fun <T : ContractState> vaultTrack(contractStateType: Class<out T>) = wrapFeed { delegate.vaultTrack(contractStateType) }
override fun <T : ContractState> vaultTrackByCriteria(contractStateType: Class<out T>, criteria: QueryCriteria) = wrapFeed { delegate.vaultTrackByCriteria(contractStateType, criteria) }
override fun <T : ContractState> vaultTrackByWithPagingSpec(contractStateType: Class<out T>, criteria: QueryCriteria, paging: PageSpecification) = wrapFeed { delegate.vaultTrackByWithPagingSpec(contractStateType, criteria, paging) }
override fun <T : ContractState> vaultTrackByWithSorting(contractStateType: Class<out T>, criteria: QueryCriteria, sorting: Sort) = wrapFeed { delegate.vaultTrackByWithSorting(contractStateType, criteria, sorting) }
override fun stateMachineRecordedTransactionMappingFeed() = wrapFeed(delegate::stateMachineRecordedTransactionMappingFeed)
override fun networkMapFeed() = wrapFeed(delegate::networkMapFeed)
override fun networkParametersFeed() = wrapFeed(delegate::networkParametersFeed)
@Suppress("DEPRECATION", "OverridingDeprecatedMember")
override fun internalVerifiedTransactionsFeed() = wrapFeed(delegate::internalVerifiedTransactionsFeed)
override fun stateMachinesSnapshot() = wrap(delegate::stateMachinesSnapshot)
override fun <T : ContractState> vaultQueryBy(criteria: QueryCriteria, paging: PageSpecification, sorting: Sort, contractStateType: Class<out T>) = wrap { delegate.vaultQueryBy(criteria, paging, sorting, contractStateType) }
override fun <T : ContractState> vaultQuery(contractStateType: Class<out T>) = wrap { delegate.vaultQuery(contractStateType) }
override fun <T : ContractState> vaultQueryByCriteria(criteria: QueryCriteria, contractStateType: Class<out T>) = wrap { delegate.vaultQueryByCriteria(criteria, contractStateType) }
override fun <T : ContractState> vaultQueryByWithPagingSpec(contractStateType: Class<out T>, criteria: QueryCriteria, paging: PageSpecification) = wrap { delegate.vaultQueryByWithPagingSpec(contractStateType, criteria, paging) }
override fun <T : ContractState> vaultQueryByWithSorting(contractStateType: Class<out T>, criteria: QueryCriteria, sorting: Sort) = wrap { delegate.vaultQueryByWithSorting(contractStateType, criteria, sorting) }
@Suppress("DEPRECATION", "OverridingDeprecatedMember")
override fun internalVerifiedTransactionsSnapshot() = wrap(delegate::internalVerifiedTransactionsSnapshot)
override fun stateMachineRecordedTransactionMappingSnapshot() = wrap(delegate::stateMachineRecordedTransactionMappingSnapshot)
override fun networkMapSnapshot() = wrap(delegate::networkMapSnapshot)
override fun acceptNewNetworkParameters(parametersHash: SecureHash) = wrap { delegate.acceptNewNetworkParameters(parametersHash) }
override fun killFlow(id: StateMachineRunId) = wrap { delegate.killFlow(id) }
override fun nodeInfo() = wrap(delegate::nodeInfo)
override fun notaryIdentities() = wrap(delegate::notaryIdentities)
override fun addVaultTransactionNote(txnId: SecureHash, txnNote: String) = wrap { delegate.addVaultTransactionNote(txnId, txnNote) }
override fun getVaultTransactionNotes(txnId: SecureHash) = wrap { delegate.getVaultTransactionNotes(txnId) }
override fun attachmentExists(id: SecureHash) = wrap { delegate.attachmentExists(id) }
override fun openAttachment(id: SecureHash) = wrap { delegate.openAttachment(id) }
override fun uploadAttachment(jar: InputStream) = wrap { delegate.uploadAttachment(jar) }
override fun uploadAttachmentWithMetadata(jar: InputStream, uploader: String, filename: String) = wrap { delegate.uploadAttachmentWithMetadata(jar, uploader, filename) }
override fun queryAttachments(query: AttachmentQueryCriteria, sorting: AttachmentSort?) = wrap { delegate.queryAttachments(query, sorting) }
override fun currentNodeTime() = wrap(delegate::currentNodeTime)
override fun wellKnownPartyFromAnonymous(party: AbstractParty) = wrap { delegate.wellKnownPartyFromAnonymous(party) }
override fun partyFromKey(key: PublicKey) = wrap { delegate.partyFromKey(key) }
override fun wellKnownPartyFromX500Name(x500Name: CordaX500Name) = wrap { delegate.wellKnownPartyFromX500Name(x500Name) }
override fun notaryPartyFromX500Name(x500Name: CordaX500Name) = wrap { delegate.notaryPartyFromX500Name(x500Name) }
override fun partiesFromName(query: String, exactMatch: Boolean) = wrap { delegate.partiesFromName(query, exactMatch) }
override fun registeredFlows() = wrap(delegate::registeredFlows)
override fun nodeInfoFromParty(party: AbstractParty) = wrap { delegate.nodeInfoFromParty(party) }
override fun clearNetworkMapCache() = wrap(delegate::clearNetworkMapCache)
override fun setFlowsDrainingModeEnabled(enabled: Boolean) = wrap { delegate.setFlowsDrainingModeEnabled(enabled) }
override fun isFlowsDrainingModeEnabled() = wrap(delegate::isFlowsDrainingModeEnabled)
override fun shutdown() = wrap(delegate::shutdown)
private fun <RESULT> wrap(call: () -> RESULT): RESULT {
return try {
call.invoke()
} catch (error: Throwable) {
logger.error(error.message, error)
throw InternalNodeException.obfuscateIfInternal(error)
}
}
private fun <SNAPSHOT, ELEMENT> wrapFeed(call: () -> DataFeed<SNAPSHOT, ELEMENT>) = wrap {
call.invoke().doOnError { error -> logger.error(error.message, error) }.mapErrors(InternalNodeException.Companion::obfuscateIfInternal)
}
private fun <RESULT> wrapFuture(call: () -> CordaFuture<RESULT>): CordaFuture<RESULT> = wrap { call.invoke().mapError(InternalNodeException.Companion::obfuscateIfInternal).doOnError { error -> logger.error(error.message, error) } }
}

View File

@ -1,25 +0,0 @@
package net.corda.node.internal
import net.corda.core.messaging.CordaRPCOps
import net.corda.node.services.api.FlowStarter
import net.corda.node.services.api.ServiceHubInternal
import net.corda.node.services.messaging.rpcContext
import net.corda.node.services.statemachine.StateMachineManager
import net.corda.nodeapi.internal.persistence.CordaPersistence
/**
* Implementation of [CordaRPCOps] that checks authorisation.
*/
class SecureCordaRPCOps(services: ServiceHubInternal,
smm: StateMachineManager,
database: CordaPersistence,
flowStarter: FlowStarter,
shutdownNode: () -> Unit,
val unsafe: CordaRPCOps = CordaRPCOpsImpl(services, smm, database, flowStarter, shutdownNode)) : CordaRPCOps by RpcAuthorisationProxy(unsafe, ::rpcContext) {
/**
* Returns the RPC protocol version, which is the same the node's Platform Version. Exists since version 1 so guaranteed
* to be present.
*/
override val protocolVersion: Int get() = unsafe.nodeInfo().platformVersion
}

View File

@ -0,0 +1,53 @@
package net.corda.node.internal.rpc.proxies
import net.corda.client.rpc.PermissionException
import net.corda.core.flows.FlowLogic
import net.corda.core.messaging.CordaRPCOps
import net.corda.node.internal.InvocationHandlerTemplate
import net.corda.node.services.messaging.RpcAuthContext
import net.corda.node.services.messaging.rpcContext
import java.lang.reflect.Method
import java.lang.reflect.Proxy
/**
* Implementation of [CordaRPCOps] that checks authorisation.
*/
internal class AuthenticatedRpcOpsProxy(private val delegate: CordaRPCOps) : CordaRPCOps by proxy(delegate, ::rpcContext) {
/**
* Returns the RPC protocol version, which is the same the node's Platform Version. Exists since version 1 so guaranteed
* to be present.
*/
override val protocolVersion: Int get() = delegate.nodeInfo().platformVersion
// Need overriding to pass additional `listOf(logicType)` argument for polymorphic `startFlow` permissions.
override fun <T> startFlowDynamic(logicType: Class<out FlowLogic<T>>, vararg args: Any?) = guard("startFlowDynamic", listOf(logicType), ::rpcContext) {
delegate.startFlowDynamic(logicType, *args)
}
// Need overriding to pass additional `listOf(logicType)` argument for polymorphic `startFlow` permissions.
override fun <T> startTrackedFlowDynamic(logicType: Class<out FlowLogic<T>>, vararg args: Any?) = guard("startTrackedFlowDynamic", listOf(logicType), ::rpcContext) {
delegate.startTrackedFlowDynamic(logicType, *args)
}
private companion object {
private fun proxy(delegate: CordaRPCOps, context: () -> RpcAuthContext): CordaRPCOps {
val handler = PermissionsEnforcingInvocationHandler(delegate, context)
return Proxy.newProxyInstance(delegate::class.java.classLoader, arrayOf(CordaRPCOps::class.java), handler) as CordaRPCOps
}
}
private class PermissionsEnforcingInvocationHandler(override val delegate: CordaRPCOps, private val context: () -> RpcAuthContext) : InvocationHandlerTemplate {
override fun invoke(proxy: Any, method: Method, arguments: Array<out Any?>?) = guard(method.name, context, { super.invoke(proxy, method, arguments) })
}
}
private fun <RESULT> guard(methodName: String, context: () -> RpcAuthContext, action: () -> RESULT) = guard(methodName, emptyList(), context, action)
private fun <RESULT> guard(methodName: String, args: List<Class<*>>, context: () -> RpcAuthContext, action: () -> RESULT): RESULT {
if (!context().isPermitted(methodName, *(args.map { it.name }.toTypedArray()))) {
throw PermissionException("User not authorized to perform RPC call $methodName with target $args")
} else {
return action()
}
}

View File

@ -0,0 +1,100 @@
package net.corda.node.internal.rpc.proxies
import net.corda.core.CordaRuntimeException
import net.corda.core.CordaThrowable
import net.corda.core.concurrent.CordaFuture
import net.corda.core.internal.concurrent.mapError
import net.corda.core.mapErrors
import net.corda.core.messaging.CordaRPCOps
import net.corda.core.messaging.DataFeed
import net.corda.core.messaging.FlowHandle
import net.corda.core.messaging.FlowHandleImpl
import net.corda.core.messaging.FlowProgressHandle
import net.corda.core.messaging.FlowProgressHandleImpl
import net.corda.core.serialization.CordaSerializable
import net.corda.node.internal.InvocationHandlerTemplate
import rx.Observable
import java.lang.reflect.Method
import java.lang.reflect.Proxy.newProxyInstance
internal class ExceptionSerialisingRpcOpsProxy(private val delegate: CordaRPCOps) : CordaRPCOps by proxy(delegate) {
private companion object {
private fun proxy(delegate: CordaRPCOps): CordaRPCOps {
val handler = ErrorSerialisingInvocationHandler(delegate)
return newProxyInstance(delegate::class.java.classLoader, arrayOf(CordaRPCOps::class.java), handler) as CordaRPCOps
}
}
private class ErrorSerialisingInvocationHandler(override val delegate: CordaRPCOps) : InvocationHandlerTemplate {
override fun invoke(proxy: Any, method: Method, arguments: Array<out Any?>?): Any? {
try {
val result = super.invoke(proxy, method, arguments)
return result?.let { ensureSerialisable(it) }
} catch (exception: Exception) {
throw ensureSerialisable(exception)
}
}
private fun <RESULT : Any> ensureSerialisable(result: RESULT): Any {
return when (result) {
is CordaFuture<*> -> wrapFuture(result)
is DataFeed<*, *> -> wrapFeed(result)
is FlowProgressHandle<*> -> wrapFlowProgressHandle(result)
is FlowHandle<*> -> wrapFlowHandle(result)
is Observable<*> -> wrapObservable(result)
else -> result
}
}
private fun wrapFlowProgressHandle(handle: FlowProgressHandle<*>): FlowProgressHandle<*> {
val returnValue = wrapFuture(handle.returnValue)
val progress = wrapObservable(handle.progress)
val stepsTreeIndexFeed = handle.stepsTreeIndexFeed?.let { wrapFeed(it) }
val stepsTreeFeed = handle.stepsTreeFeed?.let { wrapFeed(it) }
return FlowProgressHandleImpl(handle.id, returnValue, progress, stepsTreeIndexFeed, stepsTreeFeed)
}
private fun wrapFlowHandle(handle: FlowHandle<*>): FlowHandle<*> {
return FlowHandleImpl(handle.id, wrapFuture(handle.returnValue))
}
private fun <ELEMENT> wrapObservable(observable: Observable<ELEMENT>): Observable<ELEMENT> {
return observable.mapErrors(::ensureSerialisable)
}
private fun <SNAPSHOT, ELEMENT> wrapFeed(feed: DataFeed<SNAPSHOT, ELEMENT>): DataFeed<SNAPSHOT, ELEMENT> {
return feed.mapErrors(::ensureSerialisable)
}
private fun <RESULT> wrapFuture(future: CordaFuture<RESULT>): CordaFuture<RESULT> {
return future.mapError(::ensureSerialisable)
}
private fun ensureSerialisable(error: Throwable): Throwable {
val serialisable = (superclasses(error::class.java) + error::class.java).any { it.isAnnotationPresent(CordaSerializable::class.java) || it.interfaces.any { it.isAnnotationPresent(CordaSerializable::class.java) } }
val result = if (serialisable) error else CordaRuntimeException(error.message, error)
if (result is CordaThrowable) {
result.stackTrace = arrayOf<StackTraceElement>()
result.setCause(null)
}
return result
}
private fun superclasses(clazz: Class<*>): List<Class<*>> {
val superclasses = mutableListOf<Class<*>>()
var current: Class<*>?
var superclass = clazz.superclass
while (superclass != null) {
superclasses += superclass
current = superclass
superclass = current.superclass
}
return superclasses
}
}
override fun toString(): String {
return "ExceptionSerialisingRpcOpsProxy"
}
}

View File

@ -26,6 +26,7 @@ import net.corda.node.services.config.NodeConfiguration
import net.corda.node.services.vault.HibernateAttachmentQueryCriteriaParser
import net.corda.node.utilities.NonInvalidatingCache
import net.corda.node.utilities.NonInvalidatingWeightBasedCache
import net.corda.nodeapi.exceptions.DuplicateAttachmentException
import net.corda.nodeapi.internal.persistence.NODE_DATABASE_PREFIX
import net.corda.nodeapi.internal.persistence.currentDBSession
import net.corda.nodeapi.internal.withContractsInJar
@ -287,7 +288,7 @@ class NodeAttachmentService(
log.info("Stored new attachment $id")
id
} else {
throw java.nio.file.FileAlreadyExistsException(id.toString())
throw DuplicateAttachmentException(id.toString())
}
}
}

View File

@ -31,13 +31,13 @@ import net.corda.finance.USD
import net.corda.finance.contracts.asset.Cash
import net.corda.finance.flows.CashIssueFlow
import net.corda.finance.flows.CashPaymentFlow
import net.corda.node.internal.SecureCordaRPCOps
import net.corda.node.internal.StartedNode
import net.corda.node.internal.security.RPCSecurityManagerImpl
import net.corda.node.services.Permissions.Companion.invokeRpc
import net.corda.node.services.Permissions.Companion.startFlow
import net.corda.node.services.messaging.CURRENT_RPC_CONTEXT
import net.corda.node.services.messaging.RpcAuthContext
import net.corda.nodeapi.exceptions.NonRpcFlowException
import net.corda.nodeapi.internal.config.User
import net.corda.testing.core.ALICE_NAME
import net.corda.testing.core.expect
@ -88,7 +88,7 @@ class CordaRPCOpsImplTest {
fun setup() {
mockNet = InternalMockNetwork(cordappPackages = listOf("net.corda.finance.contracts.asset"))
aliceNode = mockNet.createNode(InternalMockNodeParameters(legalName = ALICE_NAME))
rpc = SecureCordaRPCOps(aliceNode.services, aliceNode.smm, aliceNode.database, aliceNode.services, { })
rpc = aliceNode.rpcOps
CURRENT_RPC_CONTEXT.set(RpcAuthContext(InvocationContext.rpc(testActor()), buildSubject("TEST_USER", emptySet())))
mockNet.runNetwork()
@ -266,10 +266,10 @@ class CordaRPCOpsImplTest {
@Test
fun `can't upload the same attachment`() {
withPermissions(invokeRpc(CordaRPCOps::uploadAttachment), invokeRpc(CordaRPCOps::attachmentExists)) {
val inputJar1 = Thread.currentThread().contextClassLoader.getResourceAsStream(testJar)
val inputJar2 = Thread.currentThread().contextClassLoader.getResourceAsStream(testJar)
val secureHash1 = rpc.uploadAttachment(inputJar1)
assertThatExceptionOfType(java.nio.file.FileAlreadyExistsException::class.java).isThrownBy {
val inputJar1 = Thread.currentThread().contextClassLoader.getResourceAsStream(testJar)
val inputJar2 = Thread.currentThread().contextClassLoader.getResourceAsStream(testJar)
val secureHash1 = rpc.uploadAttachment(inputJar1)
val secureHash2 = rpc.uploadAttachment(inputJar2)
}
}
@ -293,7 +293,7 @@ class CordaRPCOpsImplTest {
@Test
fun `attempt to start non-RPC flow`() {
withPermissions(startFlow<NonRPCFlow>()) {
assertThatExceptionOfType(IllegalArgumentException::class.java).isThrownBy {
assertThatExceptionOfType(NonRpcFlowException::class.java).isThrownBy {
rpc.startFlow(::NonRPCFlow)
}
}

View File

@ -106,6 +106,19 @@ class FlowFrameworkTests {
assertThat(flow.lazyTime).isNotNull()
}
class SuspendThrowingActionExecutor(private val exception: Exception, val delegate: ActionExecutor) : ActionExecutor {
var thrown = false
@Suspendable
override fun executeAction(fiber: FlowFiber, action: Action) {
if (action is Action.CommitTransaction && !thrown) {
thrown = true
throw exception
} else {
delegate.executeAction(fiber, action)
}
}
}
@Test
fun `exception while fiber suspended`() {
bobNode.registerFlowFactory(ReceiveFlow::class) { InitiatedSendFlow("Hello", it) }
@ -113,7 +126,7 @@ class FlowFrameworkTests {
val fiber = aliceNode.services.startFlow(flow) as FlowStateMachineImpl
// Before the flow runs change the suspend action to throw an exception
val exceptionDuringSuspend = Exception("Thrown during suspend")
val throwingActionExecutor = ThrowingActionExecutor(exceptionDuringSuspend, fiber.transientValues!!.value.actionExecutor)
val throwingActionExecutor = SuspendThrowingActionExecutor(exceptionDuringSuspend, fiber.transientValues!!.value.actionExecutor)
fiber.transientValues = TransientReference(fiber.transientValues!!.value.copy(actionExecutor = throwingActionExecutor))
mockNet.runNetwork()
assertThatThrownBy {