client certificate signing utility

This commit is contained in:
Patrick Kuo
2016-09-16 17:17:32 +01:00
parent c328655f68
commit 9b4bf32fdc
48 changed files with 1812 additions and 132 deletions

View File

@ -297,10 +297,11 @@ abstract class AbstractNode(val dir: Path, val configuration: NodeConfiguration,
}
services.networkMapCache.addNode(info)
// In the unit test environment, we may run without any network map service sometimes.
if (networkMapService == null && inNodeNetworkMapService == null)
if (networkMapService == null && inNodeNetworkMapService == null) {
services.networkMapCache.runWithoutMapService()
return noNetworkMapConfigured()
else
return registerWithNetworkMap(networkMapService ?: info.address)
}
return registerWithNetworkMap(networkMapService ?: info.address)
}
private fun registerWithNetworkMap(networkMapServiceAddress: SingleMessageRecipient): ListenableFuture<Unit> {

View File

@ -34,4 +34,10 @@ data class Checkpoint(
val serialisedFiber: SerializedBytes<ProtocolStateMachineImpl<*>>,
val request: ProtocolIORequest?,
val receivedPayload: Any?
)
) {
// This flag is always false when loaded from storage as it isn't serialised.
// It is used to track when the associated fiber has been created, but not necessarily started when
// messages for protocols arrive before the system has fully loaded at startup.
@Transient
var fiberCreated: Boolean = false
}

View File

@ -1,5 +1,6 @@
package com.r3corda.node.services.network
import com.google.common.annotations.VisibleForTesting
import com.google.common.util.concurrent.ListenableFuture
import com.google.common.util.concurrent.MoreExecutors
import com.google.common.util.concurrent.SettableFuture
@ -45,6 +46,9 @@ open class InMemoryNetworkMapCache : SingletonSerializeAsToken(), NetworkMapCach
get() = registeredNodes.map { it.value }
private val _changed = PublishSubject.create<MapChange>()
override val changed: Observable<MapChange> = _changed
private val _registrationFuture = SettableFuture.create<Unit>()
override val mapServiceRegistered: ListenableFuture<Unit>
get() = _registrationFuture
private var registeredForPush = false
protected var registeredNodes = Collections.synchronizedMap(HashMap<Party, NodeInfo>())
@ -82,6 +86,7 @@ open class InMemoryNetworkMapCache : SingletonSerializeAsToken(), NetworkMapCach
// Add a message handler for the response, and prepare a future to put the data into.
// Note that the message handler will run on the network thread (not this one).
val future = SettableFuture.create<Unit>()
_registrationFuture.setFuture(future)
net.runOnNextMessage(NetworkMapService.FETCH_PROTOCOL_TOPIC, sessionID, MoreExecutors.directExecutor()) { message ->
val resp = message.data.deserialize<NetworkMapService.FetchMapResponse>()
// We may not receive any nodes back, if the map hasn't changed since the version specified
@ -120,6 +125,7 @@ open class InMemoryNetworkMapCache : SingletonSerializeAsToken(), NetworkMapCach
// Add a message handler for the response, and prepare a future to put the data into.
// Note that the message handler will run on the network thread (not this one).
val future = SettableFuture.create<Unit>()
_registrationFuture.setFuture(future)
net.runOnNextMessage(NetworkMapService.SUBSCRIPTION_PROTOCOL_TOPIC, sessionID, MoreExecutors.directExecutor()) { message ->
val resp = message.data.deserialize<NetworkMapService.SubscribeResponse>()
if (resp.confirmed) {
@ -151,4 +157,9 @@ open class InMemoryNetworkMapCache : SingletonSerializeAsToken(), NetworkMapCach
AddOrRemove.REMOVE -> removeNode(reg.node)
}
}
@VisibleForTesting
override fun runWithoutMapService() {
_registrationFuture.set(Unit)
}
}

View File

@ -14,6 +14,7 @@ import com.r3corda.core.messaging.send
import com.r3corda.core.protocols.ProtocolLogic
import com.r3corda.core.protocols.ProtocolStateMachine
import com.r3corda.core.serialization.*
import com.r3corda.core.then
import com.r3corda.core.utilities.ProgressTracker
import com.r3corda.core.utilities.trace
import com.r3corda.node.services.api.Checkpoint
@ -73,6 +74,7 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, tokenizableService
private val checkpointingMeter = metrics.meter("Protocols.Checkpointing Rate")
private val totalStartedProtocols = metrics.counter("Protocols.Started")
private val totalFinishedProtocols = metrics.counter("Protocols.Finished")
private var started = false
// Context for tokenized services in checkpoints
private val serializationContext = SerializeAsTokenContext(tokenizableServices, quasarKryo())
@ -118,13 +120,23 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, tokenizableService
}
fun start() {
checkpointStorage.checkpoints.forEach { restoreFromCheckpoint(it) }
checkpointStorage.checkpoints.forEach { createFiberForCheckpoint(it) }
serviceHub.networkMapCache.mapServiceRegistered.then(executor) {
synchronized(started) {
started = true
stateMachines.forEach { restartFiber(it.key, it.value) }
}
}
}
private fun restoreFromCheckpoint(checkpoint: Checkpoint) {
val fiber = deserializeFiber(checkpoint.serialisedFiber)
initFiber(fiber, { checkpoint })
private fun createFiberForCheckpoint(checkpoint: Checkpoint) {
if (!checkpoint.fiberCreated) {
val fiber = deserializeFiber(checkpoint.serialisedFiber)
initFiber(fiber, { checkpoint })
}
}
private fun restartFiber(fiber: ProtocolStateMachineImpl<*>, checkpoint: Checkpoint) {
if (checkpoint.request is ReceiveRequest<*>) {
val topicSession = checkpoint.request.receiveTopicSession
fiber.logger.info("Restored ${fiber.logic} - it was previously waiting for message of type ${checkpoint.request.receiveType.name} on $topicSession")
@ -179,7 +191,7 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, tokenizableService
}
}
private fun initFiber(psm: ProtocolStateMachineImpl<*>, startingCheckpoint: () -> Checkpoint) {
private fun initFiber(psm: ProtocolStateMachineImpl<*>, startingCheckpoint: () -> Checkpoint): Checkpoint {
psm.serviceHub = serviceHub
psm.suspendAction = { request ->
psm.logger.trace { "Suspended fiber ${psm.id} ${psm.logic}" }
@ -194,8 +206,12 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, tokenizableService
totalFinishedProtocols.inc()
notifyChangeObservers(psm, AddOrRemove.REMOVE)
}
stateMachines[psm] = startingCheckpoint()
val checkpoint = startingCheckpoint()
checkpoint.fiberCreated = true
totalStartedProtocols.inc()
stateMachines[psm] = checkpoint
notifyChangeObservers(psm, AddOrRemove.ADD)
return checkpoint
}
/**
@ -206,17 +222,22 @@ class StateMachineManager(val serviceHub: ServiceHubInternal, tokenizableService
fun <T> add(loggerName: String, logic: ProtocolLogic<T>): ProtocolStateMachine<T> {
val fiber = ProtocolStateMachineImpl(logic, scheduler, loggerName)
// Need to add before iterating in case of immediate completion
initFiber(fiber) {
val checkpoint = initFiber(fiber) {
val checkpoint = Checkpoint(serializeFiber(fiber), null, null)
checkpointStorage.addCheckpoint(checkpoint)
checkpoint
}
checkpointStorage.addCheckpoint(checkpoint)
synchronized(started) { // If we are not started then our checkpoint will be picked up during start
if (!started) {
return fiber
}
}
try {
executor.executeASAP {
iterateStateMachine(fiber, null) {
fiber.start()
}
totalStartedProtocols.inc()
}
} catch (e: ExecutionException) {
// There are two ways we can take exceptions in this method:

View File

@ -0,0 +1,133 @@
package com.r3corda.node.utilities.certsigning
import com.r3corda.core.crypto.X509Utilities
import com.r3corda.core.crypto.X509Utilities.CORDA_CLIENT_CA
import com.r3corda.core.crypto.X509Utilities.CORDA_CLIENT_CA_PRIVATE_KEY
import com.r3corda.core.crypto.X509Utilities.CORDA_ROOT_CA
import com.r3corda.core.crypto.X509Utilities.addOrReplaceCertificate
import com.r3corda.core.crypto.X509Utilities.addOrReplaceKey
import com.r3corda.core.div
import com.r3corda.core.minutes
import com.r3corda.core.utilities.loggerFor
import com.r3corda.node.services.config.FullNodeConfiguration
import com.r3corda.node.services.config.NodeConfiguration
import com.r3corda.node.services.messaging.ArtemisMessagingComponent
import joptsimple.OptionParser
import java.nio.file.Files
import java.nio.file.Path
import java.nio.file.Paths
import java.security.KeyPair
import java.security.cert.Certificate
import kotlin.system.exitProcess
/**
* This check the [certificatePath] for certificates required to connect to the Corda network.
* If the certificates are not found, a [PKCS10CertificationRequest] will be submitted to Corda network permissioning server using [CertificateSigningService].
* This process will enter a slow polling loop until the request has been approved, and then
* the certificate chain will be downloaded and stored in [KeyStore] reside in [certificatePath].
*/
class CertificateSigner(certificatePath: Path, val nodeConfig: NodeConfiguration, val certService: CertificateSigningService) : ArtemisMessagingComponent(certificatePath, nodeConfig) {
companion object {
val pollInterval = 1.minutes
val log = loggerFor<CertificateSigner>()
}
fun buildKeyStore() {
val caKeyStore = X509Utilities.loadOrCreateKeyStore(keyStorePath, config.keyStorePassword)
if (!caKeyStore.containsAlias(CORDA_CLIENT_CA)) {
// No certificate found in key store, create certificate signing request and post request to signing server.
log.info("No certificate found in key store, creating certificate signing request...")
// Create or load key pair from the key store.
val keyPair = X509Utilities.loadOrCreateKeyPairFromKeyStore(keyStorePath, config.keyStorePassword,
config.keyStorePassword, CORDA_CLIENT_CA_PRIVATE_KEY) {
X509Utilities.createSelfSignedCACert(nodeConfig.myLegalName)
}
log.info("Submitting certificate signing request to Corda certificate signing server.")
val requestId = submitCertificateSigningRequest(keyPair)
log.info("Successfully submitted request to Corda certificate signing server, request ID : $requestId")
log.info("Start polling server for certificate signing approval.")
val certificates = pollServerForCertificates(requestId)
log.info("Certificate signing request approved, installing new certificates.")
// Save private key and certificate chain to the key store.
caKeyStore.addOrReplaceKey(CORDA_CLIENT_CA_PRIVATE_KEY, keyPair.private,
config.keyStorePassword.toCharArray(), certificates)
// Assumes certificate chain always starts with client certificate and end with root certificate.
caKeyStore.addOrReplaceCertificate(CORDA_CLIENT_CA, certificates.first())
X509Utilities.saveKeyStore(caKeyStore, keyStorePath, config.keyStorePassword)
// Save certificates to trust store.
val trustStore = X509Utilities.loadOrCreateKeyStore(trustStorePath, config.trustStorePassword)
// Assumes certificate chain always starts with client certificate and end with root certificate.
trustStore.addOrReplaceCertificate(CORDA_ROOT_CA, certificates.last())
X509Utilities.saveKeyStore(trustStore, trustStorePath, config.trustStorePassword)
} else {
log.trace("Certificate already exists, exiting certificate signer...")
}
}
/**
* Poll Certificate Signing Server for approved certificate,
* enter a slow polling loop if server return null.
* @param requestId Certificate signing request ID.
* @return Map of certificate chain.
*/
private fun pollServerForCertificates(requestId: String): Array<Certificate> {
// Poll server to download the signed certificate once request has been approved.
var certificates = certService.retrieveCertificates(requestId)
while (certificates == null) {
Thread.sleep(pollInterval.toMillis())
certificates = certService.retrieveCertificates(requestId)
}
return certificates
}
/**
* Submit Certificate Signing Request to Certificate signing service if request ID not found in file system
* New request ID will be stored in requestId.txt
* @param keyPair Public Private key pair generated for SSL certification.
* @return Request ID return from the server.
*/
private fun submitCertificateSigningRequest(keyPair: KeyPair): String {
val requestIdStore = certificatePath / "certificate-request-id.txt"
// Retrieve request id from file if exists, else post a request to server.
return if (!Files.exists(requestIdStore)) {
val request = X509Utilities.createCertificateSigningRequest(nodeConfig.myLegalName, nodeConfig.nearestCity, nodeConfig.emailAddress, keyPair)
// Post request to signing server via http.
val requestId = certService.submitRequest(request)
// Persists request ID to file in case of node shutdown.
Files.write(requestIdStore, listOf(requestId), Charsets.UTF_8)
requestId
} else {
Files.readAllLines(requestIdStore).first()
}
}
}
object ParamsSpec {
val parser = OptionParser()
val baseDirectoryArg = parser.accepts("base-dir", "The directory to put all key stores under").withRequiredArg()
val configFileArg = parser.accepts("config-file", "The path to the config file").withRequiredArg()
}
fun main(args: Array<String>) {
val cmdlineOptions = try {
ParamsSpec.parser.parse(*args)
} catch (ex: Exception) {
CertificateSigner.log.error("Unable to parse args", ex)
exitProcess(1)
}
val baseDirectoryPath = Paths.get(cmdlineOptions.valueOf(ParamsSpec.baseDirectoryArg) ?: throw IllegalArgumentException("Please provide Corda node base directory path"))
val configFile = if (cmdlineOptions.has(ParamsSpec.configFileArg)) Paths.get(cmdlineOptions.valueOf(ParamsSpec.configFileArg)) else null
val conf = FullNodeConfiguration(NodeConfiguration.loadConfig(baseDirectoryPath, configFile, allowMissingConfig = true))
// TODO: Use HTTPS instead
CertificateSigner(baseDirectoryPath / "certificate", conf, HTTPCertificateSigningService(conf.certificateSigningService)).buildKeyStore()
}

View File

@ -0,0 +1,11 @@
package com.r3corda.node.utilities.certsigning
import org.bouncycastle.pkcs.PKCS10CertificationRequest
import java.security.cert.Certificate
interface CertificateSigningService {
/** Submits a CSR to the signing service and returns an opaque request ID. */
fun submitRequest(request: PKCS10CertificationRequest): String
/** Poll Certificate Signing Server for the request and returns a chain of certificates if request has been approved, null otherwise. */
fun retrieveCertificates(requestId: String): Array<Certificate>?
}

View File

@ -0,0 +1,59 @@
package com.r3corda.node.utilities.certsigning
import com.google.common.net.HostAndPort
import org.apache.commons.io.IOUtils
import org.bouncycastle.pkcs.PKCS10CertificationRequest
import java.io.IOException
import java.net.HttpURLConnection
import java.net.URL
import java.security.cert.Certificate
import java.security.cert.CertificateFactory
import java.util.*
import java.util.zip.ZipInputStream
class HTTPCertificateSigningService(val server: HostAndPort) : CertificateSigningService {
companion object {
// TODO: Propagate version information from gradle
val clientVersion = "1.0"
}
override fun retrieveCertificates(requestId: String): Array<Certificate>? {
// Poll server to download the signed certificate once request has been approved.
val url = URL("http://$server/api/certificate/$requestId")
val conn = url.openConnection() as HttpURLConnection
conn.requestMethod = "GET"
return when (conn.responseCode) {
HttpURLConnection.HTTP_OK -> conn.inputStream.use {
ZipInputStream(it).use {
val certificates = ArrayList<Certificate>()
while (it.nextEntry != null) {
certificates.add(CertificateFactory.getInstance("X.509").generateCertificate(it))
}
certificates.toTypedArray()
}
}
HttpURLConnection.HTTP_NO_CONTENT -> null
HttpURLConnection.HTTP_UNAUTHORIZED -> throw IOException("Certificate signing request has been rejected, please contact Corda network administrator for more information.")
else -> throw IOException("Unexpected response code ${conn.responseCode} - ${IOUtils.toString(conn.errorStream)}")
}
}
override fun submitRequest(request: PKCS10CertificationRequest): String {
// Post request to certificate signing server via http.
val conn = URL("http://$server/api/certificate").openConnection() as HttpURLConnection
conn.doOutput = true
conn.requestMethod = "POST"
conn.setRequestProperty("Content-Type", "application/octet-stream")
conn.setRequestProperty("Client-Version", clientVersion)
conn.outputStream.write(request.encoded)
return when (conn.responseCode) {
HttpURLConnection.HTTP_OK -> IOUtils.toString(conn.inputStream)
HttpURLConnection.HTTP_FORBIDDEN -> throw IOException("Client version $clientVersion is forbidden from accessing permissioning server, please upgrade to newer version.")
else -> throw IOException("Unexpected response code ${conn.responseCode} - ${IOUtils.toString(conn.errorStream)}")
}
}
}

View File

@ -91,6 +91,7 @@ class NodeSchedulerServiceTest : SingletonSerializeAsToken() {
smmHasRemovedAllProtocols.countDown()
}
}
mockSMM.start()
services.smm = mockSMM
}

View File

@ -13,6 +13,8 @@ import org.assertj.core.api.Assertions.assertThat
import org.junit.After
import org.junit.Before
import org.junit.Test
import kotlin.test.assertEquals
import kotlin.test.assertTrue
class StateMachineManagerTests {
@ -62,13 +64,72 @@ class StateMachineManagerTests {
assertThat(restoredProtocol.receivedPayload).isEqualTo(payload)
}
@Test
fun `protocol added before network map does run after init`() {
val node3 = net.createNode(node1.info.address) //create vanilla node
val protocol = ProtocolNoBlocking()
node3.smm.add("test", protocol)
assertEquals(false, protocol.protocolStarted) // Not started yet as no network activity has been allowed yet
net.runNetwork() // Allow network map messages to flow
assertEquals(true, protocol.protocolStarted) // Now we should have run the protocol
}
@Test
fun `protocol added before network map will be init checkpointed`() {
var node3 = net.createNode(node1.info.address) //create vanilla node
val protocol = ProtocolNoBlocking()
node3.smm.add("test", protocol)
assertEquals(false, protocol.protocolStarted) // Not started yet as no network activity has been allowed yet
node3.stop()
node3 = net.createNode(node1.info.address, forcedID = node3.id)
val restoredProtocol = node3.smm.findStateMachines(ProtocolNoBlocking::class.java).single().first
assertEquals(false, restoredProtocol.protocolStarted) // Not started yet as no network activity has been allowed yet
net.runNetwork() // Allow network map messages to flow
node3.smm.executor.flush()
assertEquals(true, restoredProtocol.protocolStarted) // Now we should have run the protocol and hopefully cleared the init checkpoint
node3.stop()
// Now it is completed the protocol should leave no Checkpoint.
node3 = net.createNode(node1.info.address, forcedID = node3.id)
net.runNetwork() // Allow network map messages to flow
node3.smm.executor.flush()
assertTrue(node3.smm.findStateMachines(ProtocolNoBlocking::class.java).isEmpty())
}
@Test
fun `protocol loaded from checkpoint will respond to messages from before start`() {
val topic = "send-and-receive"
val payload = random63BitValue()
val sendProtocol = SendProtocol(topic, node2.info.identity, payload)
val receiveProtocol = ReceiveProtocol(topic, node1.info.identity)
connectProtocols(sendProtocol, receiveProtocol)
node2.smm.add("test", receiveProtocol) // Prepare checkpointed receive protocol
node2.stop() // kill receiver
node1.smm.add("test", sendProtocol) // now generate message to spool up and thus come in ahead of messages for NetworkMapService
val restoredProtocol = node2.restartAndGetRestoredProtocol<ReceiveProtocol>(node1.info.address)
assertThat(restoredProtocol.receivedPayload).isEqualTo(payload)
}
private inline fun <reified P : NonTerminatingProtocol> MockNode.restartAndGetRestoredProtocol(networkMapAddress: SingleMessageRecipient? = null): P {
val servicesArray = advertisedServices.toTypedArray()
val node = mockNet.createNode(networkMapAddress, id, advertisedServices = *servicesArray)
mockNet.runNetwork() // allow NetworkMapService messages to stabilise and thus start the state machine
return node.smm.findStateMachines(P::class.java).single().first
}
private class ProtocolNoBlocking : ProtocolLogic<Unit>() {
@Transient var protocolStarted = false
@Suspendable
override fun call() {
protocolStarted = true
}
override val topic: String get() = throw UnsupportedOperationException()
}
private class ProtocolWithoutCheckpoints : NonTerminatingProtocol() {
@Transient var protocolStarted = false

View File

@ -0,0 +1,81 @@
package com.r3corda.node.utilities.certsigning
import com.google.common.net.HostAndPort
import com.nhaarman.mockito_kotlin.any
import com.nhaarman.mockito_kotlin.eq
import com.nhaarman.mockito_kotlin.mock
import com.r3corda.core.crypto.SecureHash
import com.r3corda.core.crypto.X509Utilities
import com.r3corda.core.div
import com.r3corda.node.services.config.NodeConfiguration
import org.junit.Rule
import org.junit.Test
import org.junit.rules.TemporaryFolder
import java.nio.file.Files
import kotlin.test.assertEquals
import kotlin.test.assertFalse
import kotlin.test.assertTrue
class CertificateSignerTest {
@Rule
@JvmField
val tempFolder: TemporaryFolder = TemporaryFolder()
@Test
fun buildKeyStore() {
val id = SecureHash.randomSHA256().toString()
val certs = arrayOf(X509Utilities.createSelfSignedCACert("CORDA_CLIENT_CA").certificate,
X509Utilities.createSelfSignedCACert("CORDA_INTERMEDIATE_CA").certificate,
X509Utilities.createSelfSignedCACert("CORDA_ROOT_CA").certificate)
val certService: CertificateSigningService = mock {
on { submitRequest(any()) }.then { id }
on { retrieveCertificates(eq(id)) }.then { certs }
}
val keyStore = tempFolder.root.toPath().resolve("sslkeystore.jks")
val tmpTrustStore = tempFolder.root.toPath().resolve("truststore.jks")
assertFalse(Files.exists(keyStore))
assertFalse(Files.exists(tmpTrustStore))
val config = object : NodeConfiguration {
override val myLegalName: String = "me"
override val nearestCity: String = "London"
override val emailAddress: String = ""
override val devMode: Boolean = true
override val exportJMXto: String = ""
override val keyStorePassword: String = "testpass"
override val trustStorePassword: String = "trustpass"
override val certificateSigningService: HostAndPort = HostAndPort.fromParts("localhost", 0)
}
CertificateSigner(tempFolder.root.toPath(), config, certService).buildKeyStore()
assertTrue(Files.exists(keyStore))
assertTrue(Files.exists(tmpTrustStore))
X509Utilities.loadKeyStore(keyStore, config.keyStorePassword).run {
assertTrue(containsAlias(X509Utilities.CORDA_CLIENT_CA_PRIVATE_KEY))
assertTrue(containsAlias(X509Utilities.CORDA_CLIENT_CA))
assertFalse(containsAlias(X509Utilities.CORDA_INTERMEDIATE_CA))
assertFalse(containsAlias(X509Utilities.CORDA_INTERMEDIATE_CA_PRIVATE_KEY))
assertFalse(containsAlias(X509Utilities.CORDA_ROOT_CA))
assertFalse(containsAlias(X509Utilities.CORDA_ROOT_CA_PRIVATE_KEY))
}
X509Utilities.loadKeyStore(tmpTrustStore, config.trustStorePassword).run {
assertFalse(containsAlias(X509Utilities.CORDA_CLIENT_CA_PRIVATE_KEY))
assertFalse(containsAlias(X509Utilities.CORDA_CLIENT_CA))
assertFalse(containsAlias(X509Utilities.CORDA_INTERMEDIATE_CA))
assertFalse(containsAlias(X509Utilities.CORDA_INTERMEDIATE_CA_PRIVATE_KEY))
assertTrue(containsAlias(X509Utilities.CORDA_ROOT_CA))
assertFalse(containsAlias(X509Utilities.CORDA_ROOT_CA_PRIVATE_KEY))
}
assertEquals(id, Files.readAllLines(tempFolder.root.toPath() / "certificate-request-id.txt").first())
}
}