diff --git a/core/src/main/kotlin/net/corda/core/internal/InternalUtils.kt b/core/src/main/kotlin/net/corda/core/internal/InternalUtils.kt index 73e6d7819d..97ef208514 100644 --- a/core/src/main/kotlin/net/corda/core/internal/InternalUtils.kt +++ b/core/src/main/kotlin/net/corda/core/internal/InternalUtils.kt @@ -378,10 +378,16 @@ val CordaX500Name.x500Name: X500Name val CordaX500Name.Companion.unspecifiedCountry get() = "ZZ" -fun T.signWithCert(privateKey: PrivateKey, certificate: X509Certificate): SignedDataWithCert { +inline fun T.signWithCert(signer: (SerializedBytes) -> DigitalSignatureWithCert): SignedDataWithCert { val serialised = serialize() - val signature = Crypto.doSign(privateKey, serialised.bytes) - return SignedDataWithCert(serialised, DigitalSignatureWithCert(certificate, signature)) + return SignedDataWithCert(serialised, signer(serialised)) +} + +fun T.signWithCert(privateKey: PrivateKey, certificate: X509Certificate): SignedDataWithCert { + return signWithCert { + val signature = Crypto.doSign(privateKey, it.bytes) + DigitalSignatureWithCert(certificate, signature) + } } inline fun SerializedBytes.sign(signer: (SerializedBytes) -> DigitalSignature.WithKey): SignedData { diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/crypto/X509Utilities.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/crypto/X509Utilities.kt index b23a8fc861..4d93816756 100644 --- a/node-api/src/main/kotlin/net/corda/nodeapi/internal/crypto/X509Utilities.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/crypto/X509Utilities.kt @@ -4,10 +4,7 @@ import net.corda.core.CordaOID import net.corda.core.crypto.Crypto import net.corda.core.crypto.SignatureScheme import net.corda.core.crypto.random63BitValue -import net.corda.core.internal.CertRole -import net.corda.core.internal.reader -import net.corda.core.internal.uncheckedCast -import net.corda.core.internal.writer +import net.corda.core.internal.* import net.corda.core.utilities.days import net.corda.core.utilities.millis import org.bouncycastle.asn1.* @@ -415,4 +412,6 @@ enum class CertificateType(val keyUsage: KeyUsage, vararg val purposes: KeyPurpo ) } -data class CertificateAndKeyPair(val certificate: X509Certificate, val keyPair: KeyPair) +data class CertificateAndKeyPair(val certificate: X509Certificate, val keyPair: KeyPair) { + fun sign(obj: T): SignedDataWithCert = obj.signWithCert(keyPair.private, certificate) +} diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/network/NetworkMap.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/network/NetworkMap.kt index 2e652d9279..0433aed3f2 100644 --- a/node-api/src/main/kotlin/net/corda/nodeapi/internal/network/NetworkMap.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/network/NetworkMap.kt @@ -2,10 +2,13 @@ package net.corda.nodeapi.internal.network import net.corda.core.crypto.SecureHash import net.corda.core.internal.CertRole +import net.corda.core.internal.DigitalSignatureWithCert import net.corda.core.internal.SignedDataWithCert +import net.corda.core.internal.signWithCert import net.corda.core.node.NetworkParameters import net.corda.core.node.NodeInfo import net.corda.core.serialization.CordaSerializable +import net.corda.core.serialization.SerializedBytes import net.corda.nodeapi.internal.crypto.X509Utilities import java.security.cert.X509Certificate import java.time.Instant @@ -53,3 +56,10 @@ fun SignedDataWithCert.verifiedNetworkMapCert(rootCert: X509Certifi X509Utilities.validateCertificateChain(rootCert, sig.by, rootCert) return verified() } + +class NetworkMapAndSigned private constructor(val networkMap: NetworkMap, val signed: SignedNetworkMap) { + constructor(networkMap: NetworkMap, signer: (SerializedBytes) -> DigitalSignatureWithCert) : this(networkMap, networkMap.signWithCert(signer)) + constructor(signed: SignedNetworkMap) : this(signed.verified(), signed) + operator fun component1(): NetworkMap = networkMap + operator fun component2(): SignedNetworkMap = signed +} diff --git a/node-api/src/main/kotlin/net/corda/nodeapi/internal/network/NetworkParametersCopier.kt b/node-api/src/main/kotlin/net/corda/nodeapi/internal/network/NetworkParametersCopier.kt index 18376251a7..df2e325605 100644 --- a/node-api/src/main/kotlin/net/corda/nodeapi/internal/network/NetworkParametersCopier.kt +++ b/node-api/src/main/kotlin/net/corda/nodeapi/internal/network/NetworkParametersCopier.kt @@ -1,6 +1,8 @@ package net.corda.nodeapi.internal.network -import net.corda.core.internal.* +import net.corda.core.internal.VisibleForTesting +import net.corda.core.internal.copyTo +import net.corda.core.internal.div import net.corda.core.node.NetworkParameters import net.corda.core.serialization.serialize import net.corda.nodeapi.internal.createDevNetworkMapCa @@ -11,16 +13,13 @@ import java.nio.file.StandardCopyOption class NetworkParametersCopier( networkParameters: NetworkParameters, - networkMapCa: CertificateAndKeyPair = createDevNetworkMapCa(), + signingCertAndKeyPair: CertificateAndKeyPair = createDevNetworkMapCa(), overwriteFile: Boolean = false, @VisibleForTesting val update: Boolean = false ) { private val copyOptions = if (overwriteFile) arrayOf(StandardCopyOption.REPLACE_EXISTING) else emptyArray() - private val serialisedSignedNetParams = networkParameters.signWithCert( - networkMapCa.keyPair.private, - networkMapCa.certificate - ).serialize() + private val serialisedSignedNetParams = signingCertAndKeyPair.sign(networkParameters).serialize() fun install(nodeDir: Path) { val fileName = if (update) NETWORK_PARAMS_UPDATE_FILE_NAME else NETWORK_PARAMS_FILE_NAME diff --git a/node/src/test/kotlin/net/corda/node/services/network/NetworkMapUpdaterTest.kt b/node/src/test/kotlin/net/corda/node/services/network/NetworkMapUpdaterTest.kt index 2e77979c24..2f6535a74f 100644 --- a/node/src/test/kotlin/net/corda/node/services/network/NetworkMapUpdaterTest.kt +++ b/node/src/test/kotlin/net/corda/node/services/network/NetworkMapUpdaterTest.kt @@ -49,7 +49,7 @@ class NetworkMapUpdaterTest { private val networkMapCache = createMockNetworkMapCache() private val nodeInfoMap = ConcurrentHashMap() private val networkParamsMap = HashMap() - private val networkMapCa: CertificateAndKeyPair = createDevNetworkMapCa() + private val networkMapCertAndKeyPair: CertificateAndKeyPair = createDevNetworkMapCa() private val cacheExpiryMs = 100 private val networkMapClient = createMockNetworkMapClient() private val scheduler = TestScheduler() @@ -254,7 +254,7 @@ class NetworkMapUpdaterTest { } on { getNetworkParameters(any()) }.then { val paramsHash: SecureHash = uncheckedCast(it.arguments[0]) - networkParamsMap[paramsHash]?.signWithCert(networkMapCa.keyPair.private, networkMapCa.certificate) + networkParamsMap[paramsHash]?.let { networkMapCertAndKeyPair.sign(it) } } on { ackNetworkParametersUpdate(any()) }.then { Unit diff --git a/testing/node-driver/src/main/kotlin/net/corda/testing/node/internal/network/NetworkMapServer.kt b/testing/node-driver/src/main/kotlin/net/corda/testing/node/internal/network/NetworkMapServer.kt index 4d0e952ce2..763cce2bf3 100644 --- a/testing/node-driver/src/main/kotlin/net/corda/testing/node/internal/network/NetworkMapServer.kt +++ b/testing/node-driver/src/main/kotlin/net/corda/testing/node/internal/network/NetworkMapServer.kt @@ -35,7 +35,7 @@ import javax.ws.rs.core.Response.status class NetworkMapServer(private val cacheTimeout: Duration, hostAndPort: NetworkHostAndPort, - private val networkMapCa: CertificateAndKeyPair = createDevNetworkMapCa(), + private val networkMapCertAndKeyPair: CertificateAndKeyPair = createDevNetworkMapCa(), private val myHostNameValue: String = "test.host.name", vararg additionalServices: Any) : Closeable { companion object { @@ -108,9 +108,7 @@ class NetworkMapServer(private val cacheTimeout: Duration, inner class InMemoryNetworkMapService { private val nodeInfoMap = mutableMapOf() val latestAcceptedParametersMap = mutableMapOf() - private val signedNetParams by lazy { - networkParameters.signWithCert(networkMapCa.keyPair.private, networkMapCa.certificate) - } + private val signedNetParams by lazy { networkMapCertAndKeyPair.sign(networkParameters) } @POST @Path("publish") @@ -143,7 +141,7 @@ class NetworkMapServer(private val cacheTimeout: Duration, @Produces(MediaType.APPLICATION_OCTET_STREAM) fun getNetworkMap(): Response { val networkMap = NetworkMap(nodeInfoMap.keys.toList(), signedNetParams.raw.hash, parametersUpdate) - val signedNetworkMap = networkMap.signWithCert(networkMapCa.keyPair.private, networkMapCa.certificate) + val signedNetworkMap = networkMapCertAndKeyPair.sign(networkMap) return Response.ok(signedNetworkMap.serialize().bytes).header("Cache-Control", "max-age=${cacheTimeout.seconds}").build() } @@ -172,8 +170,10 @@ class NetworkMapServer(private val cacheTimeout: Duration, val requestedParameters = if (requestedHash == signedNetParams.raw.hash) { signedNetParams } else if (requestedHash == nextNetworkParameters?.serialize()?.hash) { - nextNetworkParameters?.signWithCert(networkMapCa.keyPair.private, networkMapCa.certificate) - } else null + nextNetworkParameters?.let { networkMapCertAndKeyPair.sign(it) } + } else { + null + } requireNotNull(requestedParameters) return Response.ok(requestedParameters!!.serialize().bytes).build() }