diff --git a/core/src/main/kotlin/com/r3corda/core/crypto/WhitelistTrustManager.kt b/core/src/main/kotlin/com/r3corda/core/crypto/WhitelistTrustManager.kt new file mode 100644 index 0000000000..068ba28284 --- /dev/null +++ b/core/src/main/kotlin/com/r3corda/core/crypto/WhitelistTrustManager.kt @@ -0,0 +1,168 @@ +package com.r3corda.core.crypto + +import sun.security.util.HostnameChecker +import java.net.InetAddress +import java.net.Socket +import java.net.UnknownHostException +import java.security.KeyStore +import java.security.Provider +import java.security.Security +import java.security.cert.CertificateException +import java.security.cert.X509Certificate +import javax.net.ssl.* + +/** + * Call this to change the default verification algorithm and this use the WhitelistTrustManager + * implementation. This is a work around to the fact that ArtemisMQ and probably many other libraries + * don't correctly configure the SSLParameters with setEndpointIdentificationAlgorithm and thus don't check + * that the certificate matches with the DNS entry requested. This exposes us to man in the middle attacks. + */ +fun registerWhitelistTrustManager() { + if (Security.getProvider("WhitelistTrustManager") == null) { + Security.addProvider(WhitelistTrustManagerProvider) + } +} + +/** + * Custom Securtity Provider that forces the TrustManagerFactory to be our custom one. + * Also holds the identity of the original TrustManager algorithm so + * that we can delegate most of the checking to the proper Java code. We simply add some more checks. + * + * The whitelist automatically includes the local server DNS name and IP address + * + */ +object WhitelistTrustManagerProvider : Provider("WhitelistTrustManager", + 1.0, + "Provider for custom trust manager that always validates certificate names") { + + val originalTrustProviderAlgorithm = Security.getProperty("ssl.TrustManagerFactory.algorithm") + + private val _whitelist = mutableSetOf() + val whitelist: Set get() = _whitelist.toSet() // The acceptable IP and DNS names for clients and servers. + + init { + // Add ourselves to whitelist + val host = InetAddress.getLocalHost() + addWhitelistEntry(host.hostName) + + // Register our custom TrustManagerFactorySpi + put("TrustManagerFactory.whitelistTrustManager", "com.r3corda.core.crypto.WhitelistTrustManagerSpi") + + // Forcibly change the TrustManagerFactory defaultAlgorithm to be us + Security.setProperty("ssl.TrustManagerFactory.algorithm", "whitelistTrustManager") + } + + /** + * Adds an extra name to the whitelist if not already present + */ + fun addWhitelistEntry(serverName: String) { + if(!_whitelist.contains(serverName)) { + addWhitelistEntries(listOf(serverName)) + } + } + + /** + * Adds a list of servers to the whitelist and also adds their fully resolved name after DNS lookup + */ + fun addWhitelistEntries(serverNames: List) { + _whitelist.addAll(serverNames) + for(serveName in serverNames) { + try { + val addresses = InetAddress.getAllByName(serveName).toList() + _whitelist.addAll(addresses.map { y -> y.canonicalHostName }) + _whitelist.addAll(addresses.map { y -> y.hostAddress }) + } catch (ex: UnknownHostException) { + // Ignore if the server name is not resolvable e.g. for wildcard addresses, or addresses that can only be resolved externally + } + } + } +} + +/** + * Registered TrustManagerFactorySpi + */ +class WhitelistTrustManagerSpi : TrustManagerFactorySpi() { + //Get the original implementation to delegate to (can't use Kotlin delegation on abstract classes unfortunately). + val originalProvider = TrustManagerFactory.getInstance(WhitelistTrustManagerProvider.originalTrustProviderAlgorithm) + + override fun engineInit(keyStore: KeyStore?) { + originalProvider.init(keyStore) + } + + override fun engineInit(managerFactoryParameters: ManagerFactoryParameters?) { + originalProvider.init(managerFactoryParameters) + } + + override fun engineGetTrustManagers(): Array { + val parent = originalProvider.trustManagers.first() as X509ExtendedTrustManager + //Wrap original provider in ours and return + return arrayOf(WhitelistTrustManager(parent)) + } +} + +/** + * Our TrustManager extension takes the standard certificate checker and first delegates all the + * chain checking to that. If everything is well formed we then simply add a check against our whitelist + */ +class WhitelistTrustManager(val originalProvider: X509ExtendedTrustManager) : X509ExtendedTrustManager() { + // Use same Helper class as standard HTTPS library validator + val checker = HostnameChecker.getInstance(HostnameChecker.TYPE_TLS) + + private fun checkIdentity(hostname: String?, cert: X509Certificate) { + // Based on standard code in sun.security.ssl.X509TrustManagerImpl.checkIdentity + if ((hostname != null) && hostname.startsWith("[") && hostname.endsWith("]")) { + checker.match(hostname.substring(1, hostname.length - 1), cert) + } else { + checker.match(hostname, cert) + } + } + + /** + * scan whitelist and confirm the certificate matches at least one entry + */ + private fun checkWhitelist(cert: X509Certificate) { + for (whiteListEntry in WhitelistTrustManagerProvider.whitelist) { + try { + checkIdentity(whiteListEntry, cert) + return // if we get here without throwing we had a match + } catch(ex: CertificateException) { + // + } + } + throw CertificateException("Certificate not on whitelist ${cert.subjectDN}") + } + + override fun checkClientTrusted(chain: Array, authType: String, socket: Socket?) { + originalProvider.checkClientTrusted(chain, authType, socket) + checkWhitelist(chain[0]) + } + + override fun checkClientTrusted(chain: Array, authType: String, engine: SSLEngine?) { + originalProvider.checkClientTrusted(chain, authType, engine) + checkWhitelist(chain[0]) + } + + override fun checkClientTrusted(chain: Array, authType: String) { + originalProvider.checkClientTrusted(chain, authType) + checkWhitelist(chain[0]) + } + + override fun checkServerTrusted(chain: Array, authType: String, socket: Socket?) { + originalProvider.checkServerTrusted(chain, authType, socket) + checkWhitelist(chain[0]) + } + + override fun checkServerTrusted(chain: Array, authType: String, engine: SSLEngine?) { + originalProvider.checkServerTrusted(chain, authType, engine) + checkWhitelist(chain[0]) + } + + override fun checkServerTrusted(chain: Array, authType: String) { + originalProvider.checkServerTrusted(chain, authType) + checkWhitelist(chain[0]) + } + + override fun getAcceptedIssuers(): Array { + return originalProvider.acceptedIssuers + } +} diff --git a/core/src/test/kotlin/com/r3corda/core/crypto/WhitelistTrustManagerTest.kt b/core/src/test/kotlin/com/r3corda/core/crypto/WhitelistTrustManagerTest.kt new file mode 100644 index 0000000000..a32e55e4ff --- /dev/null +++ b/core/src/test/kotlin/com/r3corda/core/crypto/WhitelistTrustManagerTest.kt @@ -0,0 +1,203 @@ +package com.r3corda.core.crypto + +import org.junit.BeforeClass +import org.junit.Test +import java.net.Socket +import java.security.KeyStore +import java.security.cert.CertificateException +import java.security.cert.X509Certificate +import javax.net.ssl.SSLEngine +import javax.net.ssl.TrustManagerFactory +import javax.net.ssl.X509ExtendedTrustManager +import kotlin.test.assertEquals +import kotlin.test.assertFailsWith +import kotlin.test.assertTrue + +class WhitelistTrustManagerTest { + companion object { + @BeforeClass + @JvmStatic + fun registerTrustManager() { + // Validate original factory + assertEquals("PKIX", TrustManagerFactory.getDefaultAlgorithm()) + + //register for all tests + registerWhitelistTrustManager() + } + } + + private fun getTrustmanagerAndCert(whitelist: String, certificateName: String): Pair { + WhitelistTrustManagerProvider.addWhitelistEntry(whitelist) + + val caCertAndKey = X509Utilities.createSelfSignedCACert(certificateName) + + val keyStore = KeyStore.getInstance(KeyStore.getDefaultType()) + keyStore.load(null, null) + keyStore.setCertificateEntry("cacert", caCertAndKey.certificate) + + val trustManagerFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()) + trustManagerFactory.init(keyStore) + + return Pair(trustManagerFactory.trustManagers.first() as X509ExtendedTrustManager, caCertAndKey.certificate) + } + + private fun getTrustmanagerAndUntrustedChainCert(): Pair { + WhitelistTrustManagerProvider.addWhitelistEntry("test.r3corda.com") + + val otherCaCertAndKey = X509Utilities.createSelfSignedCACert("bad root") + + val caCertAndKey = X509Utilities.createSelfSignedCACert("good root") + + val subject = X509Utilities.getDevX509Name("test.r3corda.com") + val serverKey = X509Utilities.generateECDSAKeyPairForSSL() + val serverCert = X509Utilities.createServerCert(subject, + serverKey.public, + otherCaCertAndKey, + listOf(), + listOf()) + + val keyStore = KeyStore.getInstance(KeyStore.getDefaultType()) + keyStore.load(null, null) + keyStore.setCertificateEntry("cacert", caCertAndKey.certificate) + + val trustManagerFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()) + trustManagerFactory.init(keyStore) + + return Pair(trustManagerFactory.trustManagers.first() as X509ExtendedTrustManager, serverCert) + } + + + @Test + fun `getDefaultAlgorithm TrustManager is WhitelistTrustManager`() { + registerWhitelistTrustManager() // Check double register is safe + + assertEquals("whitelistTrustManager", TrustManagerFactory.getDefaultAlgorithm()) + + val trustManagerFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()) + + trustManagerFactory.init(null as KeyStore?) + + val trustManagers = trustManagerFactory.trustManagers + + assertTrue { trustManagers.all { it is WhitelistTrustManager } } + } + + @Test + fun `check certificate works for whitelisted certificate and specific domain`() { + val (trustManager, cert) = getTrustmanagerAndCert("test.r3corda.com", "test.r3corda.com") + + trustManager.checkServerTrusted(arrayOf(cert), X509Utilities.SIGNATURE_ALGORITHM) + + trustManager.checkServerTrusted(arrayOf(cert), X509Utilities.SIGNATURE_ALGORITHM, null as Socket?) + + trustManager.checkServerTrusted(arrayOf(cert), X509Utilities.SIGNATURE_ALGORITHM, null as SSLEngine?) + + trustManager.checkClientTrusted(arrayOf(cert), X509Utilities.SIGNATURE_ALGORITHM) + + trustManager.checkClientTrusted(arrayOf(cert), X509Utilities.SIGNATURE_ALGORITHM, null as Socket?) + + trustManager.checkClientTrusted(arrayOf(cert), X509Utilities.SIGNATURE_ALGORITHM, null as SSLEngine?) + } + + @Test + fun `check certificate works for specific certificate and wildcard permitted domain`() { + val (trustManager, cert) = getTrustmanagerAndCert("*.r3corda.com", "test.r3corda.com") + + trustManager.checkServerTrusted(arrayOf(cert), X509Utilities.SIGNATURE_ALGORITHM) + + trustManager.checkServerTrusted(arrayOf(cert), X509Utilities.SIGNATURE_ALGORITHM, null as Socket?) + + trustManager.checkServerTrusted(arrayOf(cert), X509Utilities.SIGNATURE_ALGORITHM, null as SSLEngine?) + + trustManager.checkClientTrusted(arrayOf(cert), X509Utilities.SIGNATURE_ALGORITHM) + + trustManager.checkClientTrusted(arrayOf(cert), X509Utilities.SIGNATURE_ALGORITHM, null as Socket?) + + trustManager.checkClientTrusted(arrayOf(cert), X509Utilities.SIGNATURE_ALGORITHM, null as SSLEngine?) + } + + @Test + fun `check certificate works for wildcard certificate and non wildcard domain`() { + val (trustManager, cert) = getTrustmanagerAndCert("*.r3corda.com", "test.r3corda.com") + + trustManager.checkServerTrusted(arrayOf(cert), X509Utilities.SIGNATURE_ALGORITHM) + + trustManager.checkServerTrusted(arrayOf(cert), X509Utilities.SIGNATURE_ALGORITHM, null as Socket?) + + trustManager.checkServerTrusted(arrayOf(cert), X509Utilities.SIGNATURE_ALGORITHM, null as SSLEngine?) + + trustManager.checkClientTrusted(arrayOf(cert), X509Utilities.SIGNATURE_ALGORITHM) + + trustManager.checkClientTrusted(arrayOf(cert), X509Utilities.SIGNATURE_ALGORITHM, null as Socket?) + + trustManager.checkClientTrusted(arrayOf(cert), X509Utilities.SIGNATURE_ALGORITHM, null as SSLEngine?) + } + + @Test + fun `check unknown certificate rejected`() { + val (trustManager, cert) = getTrustmanagerAndCert("test.r3corda.com", "test.notr3.com") + + assertFailsWith { trustManager.checkServerTrusted(arrayOf(cert), X509Utilities.SIGNATURE_ALGORITHM) } + + assertFailsWith { trustManager.checkServerTrusted(arrayOf(cert), X509Utilities.SIGNATURE_ALGORITHM, null as Socket?) } + + assertFailsWith { trustManager.checkServerTrusted(arrayOf(cert), X509Utilities.SIGNATURE_ALGORITHM, null as SSLEngine?) } + + assertFailsWith { trustManager.checkClientTrusted(arrayOf(cert), X509Utilities.SIGNATURE_ALGORITHM) } + + assertFailsWith { trustManager.checkClientTrusted(arrayOf(cert), X509Utilities.SIGNATURE_ALGORITHM, null as Socket?) } + + assertFailsWith { trustManager.checkClientTrusted(arrayOf(cert), X509Utilities.SIGNATURE_ALGORITHM, null as SSLEngine?) } + } + + @Test + fun `check unknown wildcard certificate rejected`() { + val (trustManager, cert) = getTrustmanagerAndCert("test.r3corda.com", "*.notr3.com") + + assertFailsWith { trustManager.checkServerTrusted(arrayOf(cert), X509Utilities.SIGNATURE_ALGORITHM) } + + assertFailsWith { trustManager.checkServerTrusted(arrayOf(cert), X509Utilities.SIGNATURE_ALGORITHM, null as Socket?) } + + assertFailsWith { trustManager.checkServerTrusted(arrayOf(cert), X509Utilities.SIGNATURE_ALGORITHM, null as SSLEngine?) } + + assertFailsWith { trustManager.checkClientTrusted(arrayOf(cert), X509Utilities.SIGNATURE_ALGORITHM) } + + assertFailsWith { trustManager.checkClientTrusted(arrayOf(cert), X509Utilities.SIGNATURE_ALGORITHM, null as Socket?) } + + assertFailsWith { trustManager.checkClientTrusted(arrayOf(cert), X509Utilities.SIGNATURE_ALGORITHM, null as SSLEngine?) } + } + + @Test + fun `check unknown certificate rejected against mismatched wildcard`() { + val (trustManager, cert) = getTrustmanagerAndCert("*.r3corda.com", "test.notr3.com") + + assertFailsWith { trustManager.checkServerTrusted(arrayOf(cert), X509Utilities.SIGNATURE_ALGORITHM) } + + assertFailsWith { trustManager.checkServerTrusted(arrayOf(cert), X509Utilities.SIGNATURE_ALGORITHM, null as Socket?) } + + assertFailsWith { trustManager.checkServerTrusted(arrayOf(cert), X509Utilities.SIGNATURE_ALGORITHM, null as SSLEngine?) } + + assertFailsWith { trustManager.checkClientTrusted(arrayOf(cert), X509Utilities.SIGNATURE_ALGORITHM) } + + assertFailsWith { trustManager.checkClientTrusted(arrayOf(cert), X509Utilities.SIGNATURE_ALGORITHM, null as Socket?) } + + assertFailsWith { trustManager.checkClientTrusted(arrayOf(cert), X509Utilities.SIGNATURE_ALGORITHM, null as SSLEngine?) } + } + + @Test + fun `check certificate signed by untrusted root is still rejected, despite matched name`() { + val (trustManager, cert) = getTrustmanagerAndUntrustedChainCert() + + assertFailsWith { trustManager.checkServerTrusted(arrayOf(cert), X509Utilities.SIGNATURE_ALGORITHM) } + + assertFailsWith { trustManager.checkServerTrusted(arrayOf(cert), X509Utilities.SIGNATURE_ALGORITHM, null as Socket?) } + + assertFailsWith { trustManager.checkServerTrusted(arrayOf(cert), X509Utilities.SIGNATURE_ALGORITHM, null as SSLEngine?) } + + assertFailsWith { trustManager.checkClientTrusted(arrayOf(cert), X509Utilities.SIGNATURE_ALGORITHM) } + + assertFailsWith { trustManager.checkClientTrusted(arrayOf(cert), X509Utilities.SIGNATURE_ALGORITHM, null as Socket?) } + + assertFailsWith { trustManager.checkClientTrusted(arrayOf(cert), X509Utilities.SIGNATURE_ALGORITHM, null as SSLEngine?) } + } +} \ No newline at end of file