From baaef30d5b32a1fb4cf5f0bf1aee954a70288839 Mon Sep 17 00:00:00 2001 From: Konstantinos Chalkias Date: Wed, 5 Jul 2017 16:14:18 +0100 Subject: [PATCH] CompositeKey validation checks (#956) --- .../net/corda/core/crypto/CompositeKey.kt | 112 ++++++++++++-- .../corda/core/crypto/CompositeKeyTests.kt | 137 +++++++++++++++++- 2 files changed, 234 insertions(+), 15 deletions(-) diff --git a/core/src/main/kotlin/net/corda/core/crypto/CompositeKey.kt b/core/src/main/kotlin/net/corda/core/crypto/CompositeKey.kt index 1e0ae94678..d7d3e8f205 100644 --- a/core/src/main/kotlin/net/corda/core/crypto/CompositeKey.kt +++ b/core/src/main/kotlin/net/corda/core/crypto/CompositeKey.kt @@ -5,6 +5,7 @@ import net.corda.core.serialization.CordaSerializable import org.bouncycastle.asn1.* import org.bouncycastle.asn1.x509.SubjectPublicKeyInfo import java.security.PublicKey +import java.util.* /** * A tree data structure that enables the representation of composite public keys. @@ -28,16 +29,86 @@ class CompositeKey private constructor (val threshold: Int, children: List) : PublicKey { val children = children.sorted() init { - require (children.size == children.toSet().size) { "Trying to construct CompositeKey with duplicated child nodes." } - // If we want PublicKey we only keep one key, otherwise it will lead to semantically equivalent trees but having different structures. - require(children.size > 1) { "Cannot construct CompositeKey with only one child node." } + // TODO: replace with the more extensive, but slower, checkValidity() test. + checkConstraints() + } + + @Transient + private var validated = false + + // Check for key duplication, threshold and weight constraints and test for aggregated weight integer overflow. + private fun checkConstraints() { + require(children.size == children.toSet().size) { "CompositeKey with duplicated child nodes detected." } + // If we want PublicKey we only keep one key, otherwise it will lead to semantically equivalent trees + // but having different structures. + require(children.size > 1) { "CompositeKey must consist of two or more child nodes." } + // We should ensure threshold is positive, because smaller allowable weight for a node key is 1. + require(threshold > 0) { "CompositeKey threshold is set to $threshold, but it should be a positive integer." } + // If threshold is bigger than total weight, then it will never be satisfied. + val totalWeight = totalWeight() + require(threshold <= totalWeight) { "CompositeKey threshold: $threshold cannot be bigger than aggregated weight of " + + "child nodes: $totalWeight"} + } + + // Graph cycle detection in the composite key structure to avoid infinite loops on CompositeKey graph traversal and + // when recursion is used (i.e. in isFulfilledBy()). + // An IdentityHashMap Vs HashMap is used, because a graph cycle causes infinite loop on the CompositeKey.hashCode(). + private fun cycleDetection(visitedMap: IdentityHashMap) { + for ((node) in children) { + if (node is CompositeKey) { + val curVisitedMap = IdentityHashMap() + curVisitedMap.putAll(visitedMap) + require(!curVisitedMap.contains(node)) { "Cycle detected for CompositeKey: $node" } + curVisitedMap.put(node, true) + node.cycleDetection(curVisitedMap) + } + } + } + + /** + * This method will detect graph cycles in the full composite key structure to protect against infinite loops when + * traversing the graph and key duplicates in the each layer. It also checks if the threshold and weight constraint + * requirements are met, while it tests for aggregated-weight integer overflow. + * In practice, this method should be always invoked on the root [CompositeKey], as it inherently + * validates the child nodes (all the way till the leaves). + * TODO: Always call this method when deserialising [CompositeKey]s. + */ + fun checkValidity() { + val visitedMap = IdentityHashMap() + visitedMap.put(this, true) + cycleDetection(visitedMap) // Graph cycle testing on the root node. + checkConstraints() + for ((node, _) in children) { + if (node is CompositeKey) { + // We don't need to check for cycles on the rest of the nodes (testing on the root node is enough). + node.checkConstraints() + } + } + validated = true + } + + // Method to check if the total (aggregated) weight of child nodes overflows. + // Unlike similar solutions that use long conversion, this approach takes advantage of the minimum weight being 1. + private fun totalWeight(): Int { + var sum = 0 + for ((_, weight) in children) { + require (weight > 0) { "Non-positive weight: $weight detected." } + sum = Math.addExact(sum, weight) // Add and check for integer overflow. + } + return sum } /** * Holds node - weight pairs for a CompositeKey. Ordered first by weight, then by node's hashCode. + * Each node should be assigned with a positive weight to avoid certain types of weight underflow attacks. */ @CordaSerializable data class NodeAndWeight(val node: PublicKey, val weight: Int): Comparable, ASN1Object() { + + init { + // We don't allow zero or negative weights. Minimum weight = 1. + require (weight > 0) { "A non-positive weight was detected. Node info: $this" } + } override fun compareTo(other: NodeAndWeight): Int { if (weight == other.weight) { return node.hashCode().compareTo(other.node.hashCode()) @@ -51,6 +122,10 @@ class CompositeKey private constructor (val threshold: Int, vector.add(ASN1Integer(weight.toLong())) return DERSequence(vector) } + + override fun toString(): String { + return "Public key: ${node.toStringShort()}, weight: $weight" + } } companion object { @@ -75,21 +150,30 @@ class CompositeKey private constructor (val threshold: Int, } override fun getFormat() = ASN1Encoding.DER + // Extracted method from isFulfilledBy. + private fun checkFulfilledBy(keysToCheck: Iterable): Boolean { + if (keysToCheck.any { it is CompositeKey } ) return false + val totalWeight = children.map { (node, weight) -> + if (node is CompositeKey) { + if (node.checkFulfilledBy(keysToCheck)) weight else 0 + } else { + if (keysToCheck.contains(node)) weight else 0 + } + }.sum() + return totalWeight >= threshold + } + /** * Function checks if the public keys corresponding to the signatures are matched against the leaves of the composite * key tree in question, and the total combined weight of all children is calculated for every intermediary node. * If all thresholds are satisfied, the composite key requirement is considered to be met. */ fun isFulfilledBy(keysToCheck: Iterable): Boolean { - if (keysToCheck.any { it is CompositeKey } ) return false - val totalWeight = children.map { (node, weight) -> - if (node is CompositeKey) { - if (node.isFulfilledBy(keysToCheck)) weight else 0 - } else { - if (keysToCheck.contains(node)) weight else 0 - } - }.sum() - return totalWeight >= threshold + // We validate keys only when checking if they're matched, as this checks subkeys as a result. + // Doing these checks at deserialization/construction time would result in duplicate checks. + if (!validated) + checkValidity() // TODO: remove when checkValidity() will be eventually invoked during/after deserialization. + return checkFulfilledBy(keysToCheck) } /** @@ -134,14 +218,14 @@ class CompositeKey private constructor (val threshold: Int, /** * Builds the [CompositeKey]. If [threshold] is not specified, it will default to - * the size of the children, effectively generating an "N of N" requirement. + * the total (aggregated) weight of the children, effectively generating an "N of N" requirement. * During process removes single keys wrapped in [CompositeKey] and enforces ordering on child nodes. */ @Throws(IllegalArgumentException::class) fun build(threshold: Int? = null): PublicKey { val n = children.size if (n > 1) - return CompositeKey(threshold ?: n, children) + return CompositeKey(threshold ?: children.map { (_, weight) -> weight }.sum(), children) else if (n == 1) { require(threshold == null || threshold == children.first().weight) { "Trying to build invalid CompositeKey, threshold value different than weight of single child node." } diff --git a/core/src/test/kotlin/net/corda/core/crypto/CompositeKeyTests.kt b/core/src/test/kotlin/net/corda/core/crypto/CompositeKeyTests.kt index 8d2bed5a73..9029da4448 100644 --- a/core/src/test/kotlin/net/corda/core/crypto/CompositeKeyTests.kt +++ b/core/src/test/kotlin/net/corda/core/crypto/CompositeKeyTests.kt @@ -4,6 +4,7 @@ import net.corda.core.serialization.OpaqueBytes import net.corda.core.serialization.serialize import org.junit.Test import kotlin.test.assertEquals +import kotlin.test.assertFailsWith import kotlin.test.assertFalse import kotlin.test.assertTrue @@ -21,7 +22,6 @@ class CompositeKeyTests { val aliceSignature = aliceKey.sign(message) val bobSignature = bobKey.sign(message) val charlieSignature = charlieKey.sign(message) - val compositeAliceSignature = CompositeSignaturesWithKeys(listOf(aliceSignature)) @Test fun `(Alice) fulfilled by Alice signature`() { @@ -124,4 +124,139 @@ class CompositeKeyTests { val brokenBobSignature = DigitalSignature.WithKey(bobSignature.by, aliceSignature.bytes) assertFalse { engine.verify(CompositeSignaturesWithKeys(listOf(aliceSignature, brokenBobSignature)).serialize().bytes) } } + + @Test() + fun `composite key constraints`() { + // Zero weight. + assertFailsWith(IllegalArgumentException::class) { + CompositeKey.Builder().addKey(alicePublicKey, 0) + } + // Negative weight. + assertFailsWith(IllegalArgumentException::class) { + CompositeKey.Builder().addKey(alicePublicKey, -1) + } + // Zero threshold. + assertFailsWith(IllegalArgumentException::class) { + CompositeKey.Builder().addKey(alicePublicKey).build(0) + } + // Negative threshold. + assertFailsWith(IllegalArgumentException::class) { + CompositeKey.Builder().addKey(alicePublicKey).build(-1) + } + // Threshold > Total-weight. + assertFailsWith(IllegalArgumentException::class) { + CompositeKey.Builder().addKey(alicePublicKey, 2).addKey(bobPublicKey, 2).build(5) + } + // Threshold value different than weight of single child node. + assertFailsWith(IllegalArgumentException::class) { + CompositeKey.Builder().addKey(alicePublicKey, 3).build(2) + } + // Aggregated weight integer overflow. + assertFailsWith(IllegalArgumentException::class) { + CompositeKey.Builder().addKey(alicePublicKey, Int.MAX_VALUE).addKey(bobPublicKey, Int.MAX_VALUE).build() + } + // Duplicated children. + assertFailsWith(IllegalArgumentException::class) { + CompositeKey.Builder().addKeys(alicePublicKey, bobPublicKey, alicePublicKey).build() + } + // Duplicated composite key children. + assertFailsWith(IllegalArgumentException::class) { + val compositeKey1 = CompositeKey.Builder().addKeys(alicePublicKey, bobPublicKey).build() + val compositeKey2 = CompositeKey.Builder().addKeys(bobPublicKey, alicePublicKey).build() + CompositeKey.Builder().addKeys(compositeKey1, compositeKey2).build() + } + } + + @Test() + fun `composite key validation with graph cycle detection`() { + val key1 = CompositeKey.Builder().addKeys(alicePublicKey, bobPublicKey).build() as CompositeKey + val key2 = CompositeKey.Builder().addKeys(alicePublicKey, key1).build() as CompositeKey + val key3 = CompositeKey.Builder().addKeys(alicePublicKey, key2).build() as CompositeKey + val key4 = CompositeKey.Builder().addKeys(alicePublicKey, key3).build() as CompositeKey + val key5 = CompositeKey.Builder().addKeys(alicePublicKey, key4).build() as CompositeKey + val key6 = CompositeKey.Builder().addKeys(alicePublicKey, key5, key2).build() as CompositeKey + + // Initially, there is no any graph cycle. + key1.checkValidity() + key2.checkValidity() + key3.checkValidity() + key4.checkValidity() + key5.checkValidity() + // The fact that key6 has a direct reference to key2 and an indirect (via path key5->key4->key3->key2) + // does not imply a cycle, as expected (independent paths). + key6.checkValidity() + + // We will create a graph cycle between key5 and key3. Key5 has already a reference to key3 (via key4). + // To create a cycle, we add a reference (child) from key3 to key5. + // Children list is immutable, so reflection is used to inject key5 as an extra NodeAndWeight child of key3. + val field = key3.javaClass.getDeclaredField("children") + field.isAccessible = true + val fixedChildren = key3.children.plus(CompositeKey.NodeAndWeight(key5, 1)) + field.set(key3, fixedChildren) + + /* A view of the example graph cycle. + * + * key6 + * / \ + * key5 key2 + * / + * key4 + * / + * key3 + * / \ + * key2 key5 + * / + * key1 + * + */ + + // Detect the graph cycle starting from key3. + assertFailsWith(IllegalArgumentException::class) { + key3.checkValidity() + } + + // Detect the graph cycle starting from key4. + assertFailsWith(IllegalArgumentException::class) { + key4.checkValidity() + } + + // Detect the graph cycle starting from key5. + assertFailsWith(IllegalArgumentException::class) { + key5.checkValidity() + } + + // Detect the graph cycle starting from key6. + // Typically, one needs to test on the root tree-node only (thus, a validity check on key6 would be enough). + assertFailsWith(IllegalArgumentException::class) { + key6.checkValidity() + } + + // Key2 (and all paths below it, i.e. key1) are outside the graph cycle and thus, there is no impact on them. + key2.checkValidity() + key1.checkValidity() + } + + @Test + fun `CompositeKey from multiple signature schemes and signature verification`() { + val (privRSA, pubRSA) = Crypto.generateKeyPair(Crypto.RSA_SHA256) + val (privK1, pubK1) = Crypto.generateKeyPair(Crypto.ECDSA_SECP256K1_SHA256) + val (privR1, pubR1) = Crypto.generateKeyPair(Crypto.ECDSA_SECP256R1_SHA256) + val (privEd, pubEd) = Crypto.generateKeyPair(Crypto.EDDSA_ED25519_SHA512) + val (privSP, pubSP) = Crypto.generateKeyPair(Crypto.SPHINCS256_SHA256) + + val RSASignature = privRSA.sign(message.bytes, pubRSA) + val K1Signature = privK1.sign(message.bytes, pubK1) + val R1Signature = privR1.sign(message.bytes, pubR1) + val EdSignature = privEd.sign(message.bytes, pubEd) + val SPSignature = privSP.sign(message.bytes, pubSP) + + val compositeKey = CompositeKey.Builder().addKeys(pubRSA, pubK1, pubR1, pubEd, pubSP).build() as CompositeKey + + val signatures = listOf(RSASignature, K1Signature, R1Signature, EdSignature, SPSignature) + assertTrue { compositeKey.isFulfilledBy(signatures.byKeys()) } + + // One signature is missing. + val signaturesWithoutRSA = listOf(K1Signature, R1Signature, EdSignature, SPSignature) + assertFalse { compositeKey.isFulfilledBy(signaturesWithoutRSA.byKeys()) } + } }