From 9362ad28e8fd9c099a4fcc1d2c36827d830763b5 Mon Sep 17 00:00:00 2001
From: Konstantinos Chalkias <konstantinos.chalkias@r3.com>
Date: Tue, 9 May 2017 14:08:34 +0100
Subject: [PATCH] Check that a public key (EC point) lies on its corresponding
 curve. (#634)

Check that a public key EC point lies on its corresponding curve and it's not point at infinity.
---
 .../kotlin/net/corda/core/crypto/Crypto.kt    | 42 +++++++++++++++---
 .../net/corda/core/crypto/CryptoUtilsTest.kt  | 44 +++++++++++++++++--
 2 files changed, 76 insertions(+), 10 deletions(-)

diff --git a/core/src/main/kotlin/net/corda/core/crypto/Crypto.kt b/core/src/main/kotlin/net/corda/core/crypto/Crypto.kt
index 9f9a874198..e102c2bd99 100644
--- a/core/src/main/kotlin/net/corda/core/crypto/Crypto.kt
+++ b/core/src/main/kotlin/net/corda/core/crypto/Crypto.kt
@@ -2,6 +2,7 @@ package net.corda.core.crypto
 
 import net.corda.core.random63BitValue
 import net.i2p.crypto.eddsa.*
+import net.i2p.crypto.eddsa.math.GroupElement
 import net.i2p.crypto.eddsa.spec.EdDSANamedCurveSpec
 import net.i2p.crypto.eddsa.spec.EdDSANamedCurveTable
 import net.i2p.crypto.eddsa.spec.EdDSAPrivateKeySpec
@@ -18,9 +19,10 @@ import org.bouncycastle.asn1.x9.X9ObjectIdentifiers
 import org.bouncycastle.cert.bc.BcX509ExtensionUtils
 import org.bouncycastle.cert.jcajce.JcaX509CertificateConverter
 import org.bouncycastle.cert.jcajce.JcaX509v3CertificateBuilder
+import org.bouncycastle.jcajce.provider.asymmetric.ec.BCECPrivateKey
+import org.bouncycastle.jcajce.provider.asymmetric.ec.BCECPublicKey
 import org.bouncycastle.jcajce.provider.util.AsymmetricKeyInfoConverter
 import org.bouncycastle.jce.ECNamedCurveTable
-import org.bouncycastle.jce.interfaces.ECKey
 import org.bouncycastle.jce.provider.BouncyCastleProvider
 import org.bouncycastle.pkcs.PKCS10CertificationRequest
 import org.bouncycastle.pkcs.jcajce.JcaPKCS10CertificationRequestBuilder
@@ -184,19 +186,18 @@ object Crypto {
             if (algorithm == "SPHINCS-256") algorithm = "SPHINCS256" // because encoding may change algorithm name from SPHINCS256 to SPHINCS-256.
             if (algorithm == sig.algorithmName) {
                 // If more than one ECDSA schemes are supported, we should distinguish between them by checking their curve parameters.
-                // TODO: change 'continue' to 'break' if only one EdDSA curve will be used.
                 if (algorithm == "EdDSA") {
-                    if ((key as EdDSAKey).params == sig.algSpec) {
+                    if ((key is EdDSAPublicKey && publicKeyOnCurve(sig, key)) || (key is EdDSAPrivateKey && key.params == sig.algSpec)) {
                         return sig
-                    } else continue
+                    } else break // use continue if in the future we support more than one Edwards curves.
                 } else if (algorithm == "ECDSA") {
-                    if ((key as ECKey).parameters == sig.algSpec) {
+                    if ((key is BCECPublicKey && publicKeyOnCurve(sig, key)) || (key is BCECPrivateKey && key.parameters == sig.algSpec)) {
                         return sig
                     } else continue
                 } else return sig // it's either RSA_SHA256 or SPHINCS-256.
             }
         }
-        throw IllegalArgumentException("Unsupported key/algorithm for the private key: ${key.encoded.toBase58()}")
+        throw IllegalArgumentException("Unsupported key/algorithm for the key: ${key.encoded.toBase58()}")
     }
 
     /**
@@ -592,4 +593,33 @@ object Crypto {
         override fun generatePublic(keyInfo: SubjectPublicKeyInfo?): PublicKey? = keyInfo?.let { decodePublicKey(signatureScheme, it.encoded) }
         override fun generatePrivate(keyInfo: PrivateKeyInfo?): PrivateKey? = keyInfo?.let { decodePrivateKey(signatureScheme, it.encoded) }
     }
+
+    /**
+     * Check if a point's coordinates are on the expected curve to avoid certain types of ECC attacks.
+     * Point-at-infinity is not permitted as well.
+     * @see <a href="https://safecurves.cr.yp.to/twist.html">Small subgroup and invalid-curve attacks</a> for a more descriptive explanation on such attacks.
+     * We use this function on [findSignatureScheme] for a [PublicKey]; currently used for signature verification only.
+     * Thus, as these attacks are mostly not relevant to signature verification, we should note that
+     * we're doing it out of an abundance of caution and specifically to proactively protect developers
+     * against using these points as part of a DH key agreement or for use cases as yet unimagined.
+     * This method currently applies to BouncyCastle's ECDSA (both R1 and K1 curves) and I2P's EdDSA (ed25519 curve).
+     * @param publicKey a [PublicKey], usually used to validate a signer's public key in on the Curve.
+     * @param signatureScheme a [SignatureScheme] object, retrieved from supported signature schemes, see [Crypto].
+     * @return true if the point lies on the curve or false if it doesn't.
+     * @throws IllegalArgumentException if the requested signature scheme or the key type is not supported.
+     */
+    @Throws(IllegalArgumentException::class)
+    fun publicKeyOnCurve(signatureScheme: SignatureScheme, publicKey: PublicKey): Boolean {
+        if (!isSupportedSignatureScheme(signatureScheme))
+            throw IllegalArgumentException("Unsupported signature scheme: $signatureScheme.schemeCodeName")
+        when (publicKey) {
+            is BCECPublicKey -> return (publicKey.parameters == signatureScheme.algSpec && !publicKey.q.isInfinity && publicKey.q.isValid)
+            is EdDSAPublicKey -> return (publicKey.params == signatureScheme.algSpec && !isEdDSAPointAtInfinity(publicKey) && publicKey.a.isOnCurve)
+            else -> throw IllegalArgumentException("Unsupported key type: ${publicKey::class}")
+        }
+    }
+
+    // return true if EdDSA publicKey is point at infinity.
+    // For EdDSA a custom function is required as it is not supported by the I2P implementation.
+    private fun isEdDSAPointAtInfinity(publicKey: EdDSAPublicKey) = publicKey.a.toP3() == (EDDSA_ED25519_SHA512.algSpec as EdDSANamedCurveSpec).curve.getZero(GroupElement.Representation.P3)
 }
diff --git a/core/src/test/kotlin/net/corda/core/crypto/CryptoUtilsTest.kt b/core/src/test/kotlin/net/corda/core/crypto/CryptoUtilsTest.kt
index 6f4e2acb01..5cf97d4a98 100644
--- a/core/src/test/kotlin/net/corda/core/crypto/CryptoUtilsTest.kt
+++ b/core/src/test/kotlin/net/corda/core/crypto/CryptoUtilsTest.kt
@@ -2,20 +2,23 @@ package net.corda.core.crypto
 
 import com.google.common.collect.Sets
 import net.i2p.crypto.eddsa.EdDSAKey
+import net.i2p.crypto.eddsa.EdDSAPublicKey
+import net.i2p.crypto.eddsa.math.GroupElement
+import net.i2p.crypto.eddsa.spec.EdDSANamedCurveSpec
 import net.i2p.crypto.eddsa.spec.EdDSANamedCurveTable
+import net.i2p.crypto.eddsa.spec.EdDSAPublicKeySpec
 import org.bouncycastle.asn1.pkcs.PrivateKeyInfo
 import org.bouncycastle.asn1.x509.SubjectPublicKeyInfo
+import org.bouncycastle.jcajce.provider.asymmetric.ec.BCECPublicKey
 import org.bouncycastle.jce.ECNamedCurveTable
 import org.bouncycastle.jce.interfaces.ECKey
 import org.bouncycastle.pqc.jcajce.provider.sphincs.BCSphincs256PrivateKey
 import org.bouncycastle.pqc.jcajce.provider.sphincs.BCSphincs256PublicKey
 import org.junit.Assert.assertNotEquals
 import org.junit.Test
+import java.security.KeyPairGenerator
 import java.util.*
-import kotlin.test.assertEquals
-import kotlin.test.assertNotNull
-import kotlin.test.assertTrue
-import kotlin.test.fail
+import kotlin.test.*
 
 /**
  * Run tests for cryptographic algorithms
@@ -620,4 +623,37 @@ class CryptoUtilsTest {
             encodedPrivK1[i] = b.dec()
         }
     }
+
+    @Test
+    fun `Check ECDSA public key on curve`() {
+        val keyPairK1 = Crypto.generateKeyPair(Crypto.ECDSA_SECP256K1_SHA256)
+        val pubK1 = keyPairK1.public as BCECPublicKey
+        assertTrue(Crypto.publicKeyOnCurve(Crypto.ECDSA_SECP256K1_SHA256, pubK1))
+        // use R1 curve for check.
+        assertFalse(Crypto.publicKeyOnCurve(Crypto.ECDSA_SECP256R1_SHA256, pubK1))
+        // use ed25519 curve for check.
+        assertFalse(Crypto.publicKeyOnCurve(Crypto.EDDSA_ED25519_SHA512, pubK1))
+    }
+
+    @Test
+    fun `Check EdDSA public key on curve`() {
+        val keyPairEdDSA = Crypto.generateKeyPair(Crypto.EDDSA_ED25519_SHA512)
+        val pubEdDSA = keyPairEdDSA.public
+        assertTrue(Crypto.publicKeyOnCurve(Crypto.EDDSA_ED25519_SHA512, pubEdDSA))
+        // use R1 curve for check.
+        assertFalse(Crypto.publicKeyOnCurve(Crypto.ECDSA_SECP256R1_SHA256, pubEdDSA))
+        // check for point at infinity.
+        val pubKeySpec = EdDSAPublicKeySpec((Crypto.EDDSA_ED25519_SHA512.algSpec as EdDSANamedCurveSpec).curve.getZero(GroupElement.Representation.P3), Crypto.EDDSA_ED25519_SHA512.algSpec as EdDSANamedCurveSpec)
+        assertFalse(Crypto.publicKeyOnCurve(Crypto.EDDSA_ED25519_SHA512, EdDSAPublicKey(pubKeySpec)))
+    }
+
+    @Test(expected = IllegalArgumentException::class)
+    fun `Unsupported EC public key type on curve`() {
+        val keyGen = KeyPairGenerator.getInstance("EC") // sun.security.ec.ECPublicKeyImpl
+        keyGen.initialize(256, newSecureRandom())
+        val pairSun = keyGen.generateKeyPair()
+        val pubSun = pairSun.getPublic()
+        // should fail as pubSun is not a BCECPublicKey.
+        Crypto.publicKeyOnCurve(Crypto.ECDSA_SECP256R1_SHA256, pubSun)
+    }
 }