diff --git a/core/src/main/kotlin/net/corda/core/internal/InternalUtils.kt b/core/src/main/kotlin/net/corda/core/internal/InternalUtils.kt index a7aa8624c4..3632d5a9c0 100644 --- a/core/src/main/kotlin/net/corda/core/internal/InternalUtils.kt +++ b/core/src/main/kotlin/net/corda/core/internal/InternalUtils.kt @@ -33,6 +33,7 @@ import java.io.* import java.lang.reflect.Field import java.math.BigDecimal import java.net.HttpURLConnection +import java.net.HttpURLConnection.HTTP_OK import java.net.URL import java.nio.ByteBuffer import java.nio.charset.Charset @@ -352,10 +353,11 @@ val KClass<*>.packageName: String get() = java.`package`.name fun URL.openHttpConnection(): HttpURLConnection = openConnection() as HttpURLConnection -fun URL.post(serializedData: OpaqueBytes): ByteArray { +fun URL.post(serializedData: OpaqueBytes, vararg properties: Pair): ByteArray { return openHttpConnection().run { doOutput = true requestMethod = "POST" + properties.forEach { (key, value) -> setRequestProperty(key, value) } setRequestProperty("Content-Type", "application/octet-stream") outputStream.use { serializedData.open().copyTo(it) } checkOkResponse() @@ -364,12 +366,13 @@ fun URL.post(serializedData: OpaqueBytes): ByteArray { } fun HttpURLConnection.checkOkResponse() { - if (responseCode != 200) { - val message = errorStream.use { it.reader().readText() } - throw IOException("Response Code $responseCode: $message") + if (responseCode != HTTP_OK) { + throw IOException("Response Code $responseCode: $errorMessage") } } +val HttpURLConnection.errorMessage: String? get() = errorStream?.let { it.use { it.reader().readText() } } + inline fun HttpURLConnection.responseAs(): T { checkOkResponse() return inputStream.readObject() diff --git a/node/src/main/kotlin/net/corda/node/services/network/NetworkMapClient.kt b/node/src/main/kotlin/net/corda/node/services/network/NetworkMapClient.kt index 4508b7753b..7bd6b7aec9 100644 --- a/node/src/main/kotlin/net/corda/node/services/network/NetworkMapClient.kt +++ b/node/src/main/kotlin/net/corda/node/services/network/NetworkMapClient.kt @@ -48,7 +48,7 @@ class NetworkMapClient(compatibilityZoneURL: URL, val trustedRoot: X509Certifica val connection = networkMapUrl.openHttpConnection() val signedNetworkMap = connection.responseAs() val networkMap = signedNetworkMap.verifiedNetworkMapCert(trustedRoot) - val timeout = connection.cacheControl().maxAgeSeconds().seconds + val timeout = connection.cacheControl.maxAgeSeconds().seconds logger.trace { "Fetched network map update from $networkMapUrl successfully: $networkMap" } return NetworkMapResponse(networkMap, timeout) } diff --git a/node/src/main/kotlin/net/corda/node/utilities/registration/HTTPNetworkRegistrationService.kt b/node/src/main/kotlin/net/corda/node/utilities/registration/HTTPNetworkRegistrationService.kt index 7f3d0ddf24..6cf4b0e369 100644 --- a/node/src/main/kotlin/net/corda/node/utilities/registration/HTTPNetworkRegistrationService.kt +++ b/node/src/main/kotlin/net/corda/node/utilities/registration/HTTPNetworkRegistrationService.kt @@ -1,12 +1,13 @@ package net.corda.node.utilities.registration -import com.google.common.net.MediaType +import net.corda.core.internal.errorMessage import net.corda.core.internal.openHttpConnection +import net.corda.core.internal.post +import net.corda.core.utilities.OpaqueBytes import net.corda.core.utilities.seconds import net.corda.nodeapi.internal.crypto.X509CertificateFactory import okhttp3.CacheControl import okhttp3.Headers -import org.apache.commons.io.IOUtils import org.bouncycastle.pkcs.PKCS10CertificationRequest import java.io.IOException import java.net.HttpURLConnection @@ -21,7 +22,7 @@ class HTTPNetworkRegistrationService(compatibilityZoneURL: URL) : NetworkRegistr companion object { // TODO: Propagate version information from gradle - val clientVersion = "1.0" + const val CLIENT_VERSION = "1.0" } @Throws(CertificateRequestException::class) @@ -29,7 +30,7 @@ class HTTPNetworkRegistrationService(compatibilityZoneURL: URL) : NetworkRegistr // Poll server to download the signed certificate once request has been approved. val conn = URL("$registrationURL/$requestId").openHttpConnection() conn.requestMethod = "GET" - val maxAge = conn.cacheControl().maxAgeSeconds() + val maxAge = conn.cacheControl.maxAgeSeconds() // Default poll interval to 10 seconds if not specified by the server, for backward compatibility. val pollInterval = if (maxAge == -1) 10.seconds else maxAge.seconds @@ -44,33 +45,15 @@ class HTTPNetworkRegistrationService(compatibilityZoneURL: URL) : NetworkRegistr } HTTP_NO_CONTENT -> CertificateResponse(pollInterval, null) HTTP_UNAUTHORIZED -> throw CertificateRequestException("Certificate signing request has been rejected: ${conn.errorMessage}") - else -> throwUnexpectedResponseCode(conn) + else -> throw IOException("Response Code ${conn.responseCode}: ${conn.errorMessage}") } } override fun submitRequest(request: PKCS10CertificationRequest): String { - // Post request to certificate signing server via http. - val conn = URL("$registrationURL").openHttpConnection() - conn.doOutput = true - conn.requestMethod = "POST" - conn.setRequestProperty("Content-Type", "application/octet-stream") - conn.setRequestProperty("Client-Version", clientVersion) - conn.outputStream.write(request.encoded) - - return when (conn.responseCode) { - HTTP_OK -> IOUtils.toString(conn.inputStream, conn.charset) - HTTP_FORBIDDEN -> throw IOException("Client version $clientVersion is forbidden from accessing permissioning server, please upgrade to newer version.") - else -> throwUnexpectedResponseCode(conn) - } + return String(registrationURL.post(OpaqueBytes(request.encoded), "Client-Version" to CLIENT_VERSION)) } - - private fun throwUnexpectedResponseCode(connection: HttpURLConnection): Nothing { - throw IOException("Unexpected response code ${connection.responseCode} - ${connection.errorMessage}") - } - - private val HttpURLConnection.charset: String get() = MediaType.parse(contentType).charset().or(Charsets.UTF_8).name() - - private val HttpURLConnection.errorMessage: String get() = IOUtils.toString(errorStream, charset) } -fun HttpURLConnection.cacheControl(): CacheControl = CacheControl.parse(Headers.of(headerFields.filterKeys { it != null }.mapValues { it.value[0] })) +val HttpURLConnection.cacheControl: CacheControl get() { + return CacheControl.parse(Headers.of(headerFields.filterKeys { it != null }.mapValues { it.value[0] })) +}