Change to a ConcurrentHashSet whitelist so that we don't hold any locks across the DNS lookup.

This commit is contained in:
Matthew Nesbit 2016-07-26 14:40:30 +01:00
parent 4c4484b820
commit fea452d9ac

View File

@ -9,6 +9,7 @@ import java.security.Provider
import java.security.Security import java.security.Security
import java.security.cert.CertificateException import java.security.cert.CertificateException
import java.security.cert.X509Certificate import java.security.cert.X509Certificate
import java.util.concurrent.ConcurrentHashMap
import javax.net.ssl.* import javax.net.ssl.*
/** /**
@ -37,11 +38,11 @@ object WhitelistTrustManagerProvider : Provider("WhitelistTrustManager",
val originalTrustProviderAlgorithm = Security.getProperty("ssl.TrustManagerFactory.algorithm") val originalTrustProviderAlgorithm = Security.getProperty("ssl.TrustManagerFactory.algorithm")
private val _whitelist = mutableSetOf<String>() private val _whitelist = ConcurrentHashMap.newKeySet<String>()
val whitelist: Set<String> get() = _whitelist.toSet() // The acceptable IP and DNS names for clients and servers. val whitelist: Set<String> get() = _whitelist.toSet() // The acceptable IP and DNS names for clients and servers.
init { init {
// Add ourselves to whitelist // Add ourselves to whitelist as currently we have to connect to a local ArtemisMQ broker
val host = InetAddress.getLocalHost() val host = InetAddress.getLocalHost()
addWhitelistEntry(host.hostName) addWhitelistEntry(host.hostName)
@ -52,7 +53,7 @@ object WhitelistTrustManagerProvider : Provider("WhitelistTrustManager",
// This will apply to all code using TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()) // This will apply to all code using TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm())
// Which includes the standard HTTPS implementation and most other SSL code // Which includes the standard HTTPS implementation and most other SSL code
// TrustManagerFactory.getInstance(WhitelistTrustManagerProvider.originalTrustProviderAlgorithm)) will // TrustManagerFactory.getInstance(WhitelistTrustManagerProvider.originalTrustProviderAlgorithm)) will
// Allow access to the original implementation which is normally "PKIX" // allow access to the original implementation which is normally "PKIX"
Security.setProperty("ssl.TrustManagerFactory.algorithm", "whitelistTrustManager") Security.setProperty("ssl.TrustManagerFactory.algorithm", "whitelistTrustManager")
} }
@ -61,29 +62,25 @@ object WhitelistTrustManagerProvider : Provider("WhitelistTrustManager",
* If this is a new entry it will internally request a DNS lookup which may block the calling thread. * If this is a new entry it will internally request a DNS lookup which may block the calling thread.
*/ */
fun addWhitelistEntry(serverName: String) { fun addWhitelistEntry(serverName: String) {
synchronized(WhitelistTrustManagerProvider) { if (!_whitelist.contains(serverName)) { // Double check locking to avoid DNS cost. Safe as we never delete from the set
if (!_whitelist.contains(serverName)) { addWhitelistEntries(listOf(serverName))
addWhitelistEntries(listOf(serverName))
}
} }
} }
/** /**
* Adds a list of servers to the whitelist and also adds their fully resolved name after DNS lookup * Adds a list of servers to the whitelist and also adds their fully resolved name/ip address after DNS lookup
* If the server name is not an actual DNS name this is silently ignored * If the server name is not an actual DNS name this is silently ignored.
* The DNS request may block the calling thread. * The DNS request may block the calling thread.
*/ */
fun addWhitelistEntries(serverNames: List<String>) { fun addWhitelistEntries(serverNames: List<String>) {
synchronized(WhitelistTrustManagerProvider) { _whitelist.addAll(serverNames)
_whitelist.addAll(serverNames) for (name in serverNames) {
for (name in serverNames) { try {
try { val addresses = InetAddress.getAllByName(name).toList()
val addresses = InetAddress.getAllByName(name).toList() _whitelist.addAll(addresses.map { y -> y.canonicalHostName })
_whitelist.addAll(addresses.map { y -> y.canonicalHostName }) _whitelist.addAll(addresses.map { y -> y.hostAddress })
_whitelist.addAll(addresses.map { y -> y.hostAddress }) } catch (ex: UnknownHostException) {
} catch (ex: UnknownHostException) { // Ignore if the server name is not resolvable e.g. for wildcard addresses, or addresses that can only be resolved externally
// Ignore if the server name is not resolvable e.g. for wildcard addresses, or addresses that can only be resolved externally
}
} }
} }
} }