ENT-3658, ENT-3660: Add timeouts and hospital flow handling to CryptoServices (#5226)

This commit is contained in:
fowlerrr 2019-07-03 12:39:32 +01:00 committed by Shams Asari
parent 2bfd2c8cb5
commit 6df142bf7a
7 changed files with 185 additions and 14 deletions

View File

@ -2,10 +2,14 @@ package net.corda.nodeapi.internal.cryptoservice
import net.corda.core.DoNotImplement import net.corda.core.DoNotImplement
import net.corda.core.crypto.SignatureScheme import net.corda.core.crypto.SignatureScheme
import net.corda.core.utilities.getOrThrow
import net.corda.nodeapi.internal.crypto.X509Utilities import net.corda.nodeapi.internal.crypto.X509Utilities
import org.bouncycastle.operator.ContentSigner import org.bouncycastle.operator.ContentSigner
import java.security.KeyPair import java.security.KeyPair
import java.security.PublicKey import java.security.PublicKey
import java.time.Duration
import java.util.concurrent.Executors
import java.util.concurrent.TimeoutException
/** /**
* Unlike [CryptoService] can only perform "read-only" operations but never create new key pairs. * Unlike [CryptoService] can only perform "read-only" operations but never create new key pairs.
@ -48,7 +52,27 @@ interface SignOnlyCryptoService {
* Fully-powered crypto service which can sign as well as create new key pairs. * Fully-powered crypto service which can sign as well as create new key pairs.
*/ */
@DoNotImplement @DoNotImplement
interface CryptoService : SignOnlyCryptoService { abstract class CryptoService(private val timeout: Duration? = null) : AutoCloseable, SignOnlyCryptoService {
private val executor = Executors.newCachedThreadPool()
override fun close() {
executor.shutdown()
}
/**
* Adds a timeout for the given [func].
* @param timeout The time to wait on the function completing (in milliseconds)
* @param func The call that we're waiting on
* @return the return value of the function call
* @throws TimedCryptoServiceException if we reach the timeout
*/
private fun <A> withTimeout(timeout: Duration?, func: () -> A) : A {
try {
return executor.submit(func).getOrThrow(timeout)
} catch (e: TimeoutException) {
throw TimedCryptoServiceException("Timed-out while waiting for ${timeout?.toMillis()} milliseconds")
}
}
/** /**
* Generate and store a new [KeyPair]. * Generate and store a new [KeyPair].
@ -58,7 +82,39 @@ interface CryptoService : SignOnlyCryptoService {
* *
* Returns the [PublicKey] of the generated [KeyPair]. * Returns the [PublicKey] of the generated [KeyPair].
*/ */
fun generateKeyPair(alias: String, scheme: SignatureScheme): PublicKey fun generateKeyPair(alias: String, scheme: SignatureScheme): PublicKey =
withTimeout(timeout) { _generateKeyPair(alias, scheme) }
protected abstract fun _generateKeyPair(alias: String, scheme: SignatureScheme): PublicKey
/** Check if this [CryptoService] has a private key entry for the input alias. */
override fun containsKey(alias: String): Boolean =
withTimeout(timeout) { _containsKey(alias) }
protected abstract fun _containsKey(alias: String): Boolean
/**
* Returns the [PublicKey] of the input alias or null if it doesn't exist.
*/
override fun getPublicKey(alias: String): PublicKey =
withTimeout(timeout) { _getPublicKey(alias) }
protected abstract fun _getPublicKey(alias: String): PublicKey
/**
* Sign a [ByteArray] using the private key identified by the input alias.
* Returns the signature bytes formatted according to the signature scheme.
* The signAlgorithm if specified determines the signature scheme used for signing, if
* not specified then the signature scheme is based on the private key scheme.
*/
override fun sign(alias: String, data: ByteArray, signAlgorithm: String?): ByteArray =
withTimeout(timeout) { _sign(alias, data, signAlgorithm) }
protected abstract fun _sign(alias: String, data: ByteArray, signAlgorithm: String?): ByteArray
/**
* Returns [ContentSigner] for the key identified by the input alias.
*/
override fun getSigner(alias: String): ContentSigner =
withTimeout(timeout) { _getSigner(alias) }
protected abstract fun _getSigner(alias: String): ContentSigner
} }
open class CryptoServiceException(message: String?, cause: Throwable? = null) : Exception(message, cause) open class CryptoServiceException(message: String?, cause: Throwable? = null) : Exception(message, cause)
class TimedCryptoServiceException(message: String?, cause: Throwable? = null) : CryptoServiceException(message, cause)

View File

@ -3,11 +3,16 @@ package net.corda.nodeapi.internal.cryptoservice
import net.corda.core.identity.CordaX500Name import net.corda.core.identity.CordaX500Name
import net.corda.nodeapi.internal.config.FileBasedCertificateStoreSupplier import net.corda.nodeapi.internal.config.FileBasedCertificateStoreSupplier
import net.corda.nodeapi.internal.cryptoservice.bouncycastle.BCCryptoService import net.corda.nodeapi.internal.cryptoservice.bouncycastle.BCCryptoService
import java.nio.file.Path import java.time.Duration
class CryptoServiceFactory { class CryptoServiceFactory {
companion object { companion object {
fun makeCryptoService(cryptoServiceName: SupportedCryptoServices, legalName: CordaX500Name, signingCertificateStore: FileBasedCertificateStoreSupplier? = null, cryptoServiceConf: Path? = null): CryptoService { fun makeCryptoService(
cryptoServiceName: SupportedCryptoServices,
legalName: CordaX500Name,
signingCertificateStore: FileBasedCertificateStoreSupplier? = null,
timeout: Duration? = null
): CryptoService {
// The signing certificate store can be null for other services as only BCC requires is at the moment. // The signing certificate store can be null for other services as only BCC requires is at the moment.
if (cryptoServiceName != SupportedCryptoServices.BC_SIMPLE || signingCertificateStore == null) { if (cryptoServiceName != SupportedCryptoServices.BC_SIMPLE || signingCertificateStore == null) {
throw IllegalArgumentException("Currently only BouncyCastle is used as a crypto service. A valid signing certificate store is required.") throw IllegalArgumentException("Currently only BouncyCastle is used as a crypto service. A valid signing certificate store is required.")

View File

@ -23,13 +23,14 @@ import javax.security.auth.x500.X500Principal
* and a Java KeyStore in the form of [CertificateStore] to store private keys. * and a Java KeyStore in the form of [CertificateStore] to store private keys.
* This service reuses the [NodeConfiguration.signingCertificateStore] to store keys. * This service reuses the [NodeConfiguration.signingCertificateStore] to store keys.
*/ */
class BCCryptoService(private val legalName: X500Principal, private val certificateStoreSupplier: CertificateStoreSupplier) : CryptoService { class BCCryptoService(private val legalName: X500Principal,
private val certificateStoreSupplier: CertificateStoreSupplier) : CryptoService() {
// TODO check if keyStore exists. // TODO check if keyStore exists.
// TODO make it private when E2ETestKeyManagementService does not require direct access to the private key. // TODO make it private when E2ETestKeyManagementService does not require direct access to the private key.
var certificateStore: CertificateStore = certificateStoreSupplier.get(true) var certificateStore: CertificateStore = certificateStoreSupplier.get(true)
override fun generateKeyPair(alias: String, scheme: SignatureScheme): PublicKey { override fun _generateKeyPair(alias: String, scheme: SignatureScheme): PublicKey {
try { try {
val keyPair = Crypto.generateKeyPair(scheme) val keyPair = Crypto.generateKeyPair(scheme)
importKey(alias, keyPair) importKey(alias, keyPair)
@ -39,11 +40,11 @@ class BCCryptoService(private val legalName: X500Principal, private val certific
} }
} }
override fun containsKey(alias: String): Boolean { override fun _containsKey(alias: String): Boolean {
return certificateStore.contains(alias) return certificateStore.contains(alias)
} }
override fun getPublicKey(alias: String): PublicKey { override fun _getPublicKey(alias: String): PublicKey {
try { try {
return certificateStore.query { getPublicKey(alias) } return certificateStore.query { getPublicKey(alias) }
} catch (e: Exception) { } catch (e: Exception) {
@ -51,8 +52,7 @@ class BCCryptoService(private val legalName: X500Principal, private val certific
} }
} }
@JvmOverloads override fun _sign(alias: String, data: ByteArray, signAlgorithm: String?): ByteArray {
override fun sign(alias: String, data: ByteArray, signAlgorithm: String?): ByteArray {
try { try {
return when(signAlgorithm) { return when(signAlgorithm) {
null -> Crypto.doSign(certificateStore.query { getPrivateKey(alias, certificateStore.entryPassword) }, data) null -> Crypto.doSign(certificateStore.query { getPrivateKey(alias, certificateStore.entryPassword) }, data)
@ -71,7 +71,7 @@ class BCCryptoService(private val legalName: X500Principal, private val certific
return signature.sign() return signature.sign()
} }
override fun getSigner(alias: String): ContentSigner { override fun _getSigner(alias: String): ContentSigner {
try { try {
val privateKey = certificateStore.query { getPrivateKey(alias, certificateStore.entryPassword) } val privateKey = certificateStore.query { getPrivateKey(alias, certificateStore.entryPassword) }
val signatureScheme = Crypto.findSignatureScheme(privateKey) val signatureScheme = Crypto.findSignatureScheme(privateKey)

View File

@ -0,0 +1,68 @@
package net.corda.nodeapi.internal.cryptoservice
import net.corda.core.crypto.SignatureScheme
import net.corda.core.internal.times
import org.bouncycastle.operator.ContentSigner
import org.junit.After
import org.junit.Before
import org.junit.Test
import java.security.PublicKey
import java.time.Duration
import kotlin.test.assertFailsWith
import kotlin.test.expect
class CryptoServiceTest {
private val TEST_TIMEOUT = Duration.ofMillis(500)
private var sleepTime = TEST_TIMEOUT
private lateinit var stub: CryptoService
@Before
fun setUp() {
stub = CryptoServiceStub()
}
@After
fun tearDown() {
stub.close()
}
inner class CryptoServiceStub : CryptoService(TEST_TIMEOUT) {
private fun sleeper() {
Thread.sleep(sleepTime.toMillis())
}
override fun _generateKeyPair(alias: String, scheme: SignatureScheme): PublicKey {
throw NotImplementedError("Not needed for this test")
}
override fun _containsKey(alias: String): Boolean {
sleeper()
return true
}
override fun _getPublicKey(alias: String): PublicKey {
throw NotImplementedError("Not needed for this test")
}
override fun _sign(alias: String, data: ByteArray, signAlgorithm: String?): ByteArray {
throw NotImplementedError("Not needed for this test")
}
override fun _getSigner(alias: String): ContentSigner {
throw NotImplementedError("Not needed for this test")
}
}
@Test
fun `if no timeout is reached then correct value is returned`() {
sleepTime = Duration.ZERO
expect(true) { stub.containsKey("Test") }
}
@Test
fun `when timeout is reached the correct exception is thrown`() {
sleepTime = TEST_TIMEOUT.times(2)
assertFailsWith(TimedCryptoServiceException::class) { stub.containsKey("Test") }
}
}

View File

@ -111,7 +111,6 @@ import java.util.concurrent.TimeUnit.MINUTES
import java.util.concurrent.TimeUnit.SECONDS import java.util.concurrent.TimeUnit.SECONDS
import java.util.function.Consumer import java.util.function.Consumer
import javax.persistence.EntityManager import javax.persistence.EntityManager
import net.corda.core.crypto.generateKeyPair as cryptoGenerateKeyPair
/** /**
* A base node implementation that can be customised either for production (with real implementations that do real * A base node implementation that can be customised either for production (with real implementations that do real
@ -176,7 +175,11 @@ abstract class AbstractNode<S>(val configuration: NodeConfiguration,
val transactionStorage = makeTransactionStorage(configuration.transactionCacheSizeBytes).tokenize() val transactionStorage = makeTransactionStorage(configuration.transactionCacheSizeBytes).tokenize()
val networkMapClient: NetworkMapClient? = configuration.networkServices?.let { NetworkMapClient(it.networkMapURL, versionInfo) } val networkMapClient: NetworkMapClient? = configuration.networkServices?.let { NetworkMapClient(it.networkMapURL, versionInfo) }
val attachments = NodeAttachmentService(metricRegistry, cacheFactory, database, configuration.devMode).tokenize() val attachments = NodeAttachmentService(metricRegistry, cacheFactory, database, configuration.devMode).tokenize()
val cryptoService = CryptoServiceFactory.makeCryptoService(SupportedCryptoServices.BC_SIMPLE, configuration.myLegalName, configuration.signingCertificateStore) val cryptoService = CryptoServiceFactory.makeCryptoService(
SupportedCryptoServices.BC_SIMPLE,
configuration.myLegalName,
configuration.signingCertificateStore
).closeOnStop()
@Suppress("LeakingThis") @Suppress("LeakingThis")
val networkParametersStorage = makeNetworkParametersStorage() val networkParametersStorage = makeNetworkParametersStorage()
val cordappProvider = CordappProviderImpl(cordappLoader, CordappConfigFileProvider(configuration.cordappDirectories), attachments).tokenize() val cordappProvider = CordappProviderImpl(cordappLoader, CordappConfigFileProvider(configuration.cordappDirectories), attachments).tokenize()

View File

@ -11,6 +11,7 @@ import net.corda.core.messaging.DataFeed
import net.corda.core.utilities.contextLogger import net.corda.core.utilities.contextLogger
import net.corda.core.utilities.seconds import net.corda.core.utilities.seconds
import net.corda.node.services.FinalityHandler import net.corda.node.services.FinalityHandler
import net.corda.nodeapi.internal.cryptoservice.TimedCryptoServiceException
import org.hibernate.exception.ConstraintViolationException import org.hibernate.exception.ConstraintViolationException
import rx.subjects.PublishSubject import rx.subjects.PublishSubject
import java.sql.SQLException import java.sql.SQLException
@ -25,7 +26,7 @@ import kotlin.math.pow
class StaffedFlowHospital(private val flowMessaging: FlowMessaging, private val ourSenderUUID: String) { class StaffedFlowHospital(private val flowMessaging: FlowMessaging, private val ourSenderUUID: String) {
private companion object { private companion object {
private val log = contextLogger() private val log = contextLogger()
private val staff = listOf(DeadlockNurse, DuplicateInsertSpecialist, DoctorTimeout, FinalityDoctor) private val staff = listOf(DeadlockNurse, DuplicateInsertSpecialist, DoctorTimeout, CryptoServiceTimeout, FinalityDoctor)
} }
private val mutex = ThreadBox(object { private val mutex = ThreadBox(object {
@ -303,6 +304,24 @@ class StaffedFlowHospital(private val flowMessaging: FlowMessaging, private val
} }
} }
/**
* Restarts [TimedFlow], keeping track of the number of retries and making sure it does not
* exceed the limit specified by the [FlowTimeoutException].
*/
object CryptoServiceTimeout : Staff {
override fun consult(flowFiber: FlowFiber, currentState: StateMachineState, newError: Throwable, history: FlowMedicalHistory): Diagnosis {
return if (newError is TimedCryptoServiceException) {
if (history.notDischargedForTheSameThingMoreThan(2, this, currentState)) {
Diagnosis.DISCHARGE
} else {
Diagnosis.OVERNIGHT_OBSERVATION
}
} else {
Diagnosis.NOT_MY_SPECIALTY
}
}
}
object FinalityDoctor : Staff { object FinalityDoctor : Staff {
override fun consult(flowFiber: FlowFiber, currentState: StateMachineState, newError: Throwable, history: FlowMedicalHistory): Diagnosis { override fun consult(flowFiber: FlowFiber, currentState: StateMachineState, newError: Throwable, history: FlowMedicalHistory): Diagnosis {
return if (currentState.flowLogic is FinalityHandler || isFromReceiveFinalityFlow(newError)) { return if (currentState.flowLogic is FinalityHandler || isFromReceiveFinalityFlow(newError)) {

View File

@ -26,6 +26,7 @@ import net.corda.core.utilities.ProgressTracker.Change
import net.corda.core.utilities.getOrThrow import net.corda.core.utilities.getOrThrow
import net.corda.core.utilities.unwrap import net.corda.core.utilities.unwrap
import net.corda.node.services.persistence.checkpoints import net.corda.node.services.persistence.checkpoints
import net.corda.nodeapi.internal.cryptoservice.TimedCryptoServiceException
import net.corda.testing.contracts.DummyContract import net.corda.testing.contracts.DummyContract
import net.corda.testing.contracts.DummyState import net.corda.testing.contracts.DummyState
import net.corda.testing.core.ALICE_NAME import net.corda.testing.core.ALICE_NAME
@ -424,6 +425,25 @@ class FlowFrameworkTests {
assertThat(result.getOrThrow()).isEqualTo("HelloHello") assertThat(result.getOrThrow()).isEqualTo("HelloHello")
} }
@Test
fun `timed out CryptoService is sent to the flow hospital`() {
bobNode.registerCordappFlowFactory(ReceiveFlow::class) {
ExceptionFlow { TimedCryptoServiceException("We timed out!") }
}
aliceNode.services.startFlow(ReceiveFlow(bob)).resultFuture
mockNet.runNetwork()
assertThat(receivedSessionMessages.filter { it.message is ExistingSessionMessage && it.message.payload is ErrorSessionMessage }).hasSize(1)
val medicalRecords = bobNode.smm.flowHospital.track().apply { updates.notUsed() }.snapshot
// We expect three discharges and then overnight observation (in that order)
assertThat(medicalRecords).hasSize(4)
assertThat(medicalRecords.filter { it.outcome == StaffedFlowHospital.Outcome.DISCHARGE }).hasSize(3)
assertThat(medicalRecords.last().outcome == StaffedFlowHospital.Outcome.OVERNIGHT_OBSERVATION)
}
@Test @Test
fun `non-FlowException thrown on other side`() { fun `non-FlowException thrown on other side`() {
val erroringFlowFuture = bobNode.registerCordappFlowFactory(ReceiveFlow::class) { val erroringFlowFuture = bobNode.registerCordappFlowFactory(ReceiveFlow::class) {