mirror of
https://github.com/corda/corda.git
synced 2025-03-11 15:04:14 +00:00
CompositeKey validation checks (#956)
This commit is contained in:
parent
4e355ba95e
commit
baaef30d5b
@ -5,6 +5,7 @@ import net.corda.core.serialization.CordaSerializable
|
||||
import org.bouncycastle.asn1.*
|
||||
import org.bouncycastle.asn1.x509.SubjectPublicKeyInfo
|
||||
import java.security.PublicKey
|
||||
import java.util.*
|
||||
|
||||
/**
|
||||
* A tree data structure that enables the representation of composite public keys.
|
||||
@ -28,16 +29,86 @@ class CompositeKey private constructor (val threshold: Int,
|
||||
children: List<NodeAndWeight>) : PublicKey {
|
||||
val children = children.sorted()
|
||||
init {
|
||||
require (children.size == children.toSet().size) { "Trying to construct CompositeKey with duplicated child nodes." }
|
||||
// If we want PublicKey we only keep one key, otherwise it will lead to semantically equivalent trees but having different structures.
|
||||
require(children.size > 1) { "Cannot construct CompositeKey with only one child node." }
|
||||
// TODO: replace with the more extensive, but slower, checkValidity() test.
|
||||
checkConstraints()
|
||||
}
|
||||
|
||||
@Transient
|
||||
private var validated = false
|
||||
|
||||
// Check for key duplication, threshold and weight constraints and test for aggregated weight integer overflow.
|
||||
private fun checkConstraints() {
|
||||
require(children.size == children.toSet().size) { "CompositeKey with duplicated child nodes detected." }
|
||||
// If we want PublicKey we only keep one key, otherwise it will lead to semantically equivalent trees
|
||||
// but having different structures.
|
||||
require(children.size > 1) { "CompositeKey must consist of two or more child nodes." }
|
||||
// We should ensure threshold is positive, because smaller allowable weight for a node key is 1.
|
||||
require(threshold > 0) { "CompositeKey threshold is set to $threshold, but it should be a positive integer." }
|
||||
// If threshold is bigger than total weight, then it will never be satisfied.
|
||||
val totalWeight = totalWeight()
|
||||
require(threshold <= totalWeight) { "CompositeKey threshold: $threshold cannot be bigger than aggregated weight of " +
|
||||
"child nodes: $totalWeight"}
|
||||
}
|
||||
|
||||
// Graph cycle detection in the composite key structure to avoid infinite loops on CompositeKey graph traversal and
|
||||
// when recursion is used (i.e. in isFulfilledBy()).
|
||||
// An IdentityHashMap Vs HashMap is used, because a graph cycle causes infinite loop on the CompositeKey.hashCode().
|
||||
private fun cycleDetection(visitedMap: IdentityHashMap<CompositeKey, Boolean>) {
|
||||
for ((node) in children) {
|
||||
if (node is CompositeKey) {
|
||||
val curVisitedMap = IdentityHashMap<CompositeKey, Boolean>()
|
||||
curVisitedMap.putAll(visitedMap)
|
||||
require(!curVisitedMap.contains(node)) { "Cycle detected for CompositeKey: $node" }
|
||||
curVisitedMap.put(node, true)
|
||||
node.cycleDetection(curVisitedMap)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* This method will detect graph cycles in the full composite key structure to protect against infinite loops when
|
||||
* traversing the graph and key duplicates in the each layer. It also checks if the threshold and weight constraint
|
||||
* requirements are met, while it tests for aggregated-weight integer overflow.
|
||||
* In practice, this method should be always invoked on the root [CompositeKey], as it inherently
|
||||
* validates the child nodes (all the way till the leaves).
|
||||
* TODO: Always call this method when deserialising [CompositeKey]s.
|
||||
*/
|
||||
fun checkValidity() {
|
||||
val visitedMap = IdentityHashMap<CompositeKey,Boolean>()
|
||||
visitedMap.put(this, true)
|
||||
cycleDetection(visitedMap) // Graph cycle testing on the root node.
|
||||
checkConstraints()
|
||||
for ((node, _) in children) {
|
||||
if (node is CompositeKey) {
|
||||
// We don't need to check for cycles on the rest of the nodes (testing on the root node is enough).
|
||||
node.checkConstraints()
|
||||
}
|
||||
}
|
||||
validated = true
|
||||
}
|
||||
|
||||
// Method to check if the total (aggregated) weight of child nodes overflows.
|
||||
// Unlike similar solutions that use long conversion, this approach takes advantage of the minimum weight being 1.
|
||||
private fun totalWeight(): Int {
|
||||
var sum = 0
|
||||
for ((_, weight) in children) {
|
||||
require (weight > 0) { "Non-positive weight: $weight detected." }
|
||||
sum = Math.addExact(sum, weight) // Add and check for integer overflow.
|
||||
}
|
||||
return sum
|
||||
}
|
||||
|
||||
/**
|
||||
* Holds node - weight pairs for a CompositeKey. Ordered first by weight, then by node's hashCode.
|
||||
* Each node should be assigned with a positive weight to avoid certain types of weight underflow attacks.
|
||||
*/
|
||||
@CordaSerializable
|
||||
data class NodeAndWeight(val node: PublicKey, val weight: Int): Comparable<NodeAndWeight>, ASN1Object() {
|
||||
|
||||
init {
|
||||
// We don't allow zero or negative weights. Minimum weight = 1.
|
||||
require (weight > 0) { "A non-positive weight was detected. Node info: $this" }
|
||||
}
|
||||
override fun compareTo(other: NodeAndWeight): Int {
|
||||
if (weight == other.weight) {
|
||||
return node.hashCode().compareTo(other.node.hashCode())
|
||||
@ -51,6 +122,10 @@ class CompositeKey private constructor (val threshold: Int,
|
||||
vector.add(ASN1Integer(weight.toLong()))
|
||||
return DERSequence(vector)
|
||||
}
|
||||
|
||||
override fun toString(): String {
|
||||
return "Public key: ${node.toStringShort()}, weight: $weight"
|
||||
}
|
||||
}
|
||||
|
||||
companion object {
|
||||
@ -75,21 +150,30 @@ class CompositeKey private constructor (val threshold: Int,
|
||||
}
|
||||
override fun getFormat() = ASN1Encoding.DER
|
||||
|
||||
// Extracted method from isFulfilledBy.
|
||||
private fun checkFulfilledBy(keysToCheck: Iterable<PublicKey>): Boolean {
|
||||
if (keysToCheck.any { it is CompositeKey } ) return false
|
||||
val totalWeight = children.map { (node, weight) ->
|
||||
if (node is CompositeKey) {
|
||||
if (node.checkFulfilledBy(keysToCheck)) weight else 0
|
||||
} else {
|
||||
if (keysToCheck.contains(node)) weight else 0
|
||||
}
|
||||
}.sum()
|
||||
return totalWeight >= threshold
|
||||
}
|
||||
|
||||
/**
|
||||
* Function checks if the public keys corresponding to the signatures are matched against the leaves of the composite
|
||||
* key tree in question, and the total combined weight of all children is calculated for every intermediary node.
|
||||
* If all thresholds are satisfied, the composite key requirement is considered to be met.
|
||||
*/
|
||||
fun isFulfilledBy(keysToCheck: Iterable<PublicKey>): Boolean {
|
||||
if (keysToCheck.any { it is CompositeKey } ) return false
|
||||
val totalWeight = children.map { (node, weight) ->
|
||||
if (node is CompositeKey) {
|
||||
if (node.isFulfilledBy(keysToCheck)) weight else 0
|
||||
} else {
|
||||
if (keysToCheck.contains(node)) weight else 0
|
||||
}
|
||||
}.sum()
|
||||
return totalWeight >= threshold
|
||||
// We validate keys only when checking if they're matched, as this checks subkeys as a result.
|
||||
// Doing these checks at deserialization/construction time would result in duplicate checks.
|
||||
if (!validated)
|
||||
checkValidity() // TODO: remove when checkValidity() will be eventually invoked during/after deserialization.
|
||||
return checkFulfilledBy(keysToCheck)
|
||||
}
|
||||
|
||||
/**
|
||||
@ -134,14 +218,14 @@ class CompositeKey private constructor (val threshold: Int,
|
||||
|
||||
/**
|
||||
* Builds the [CompositeKey]. If [threshold] is not specified, it will default to
|
||||
* the size of the children, effectively generating an "N of N" requirement.
|
||||
* the total (aggregated) weight of the children, effectively generating an "N of N" requirement.
|
||||
* During process removes single keys wrapped in [CompositeKey] and enforces ordering on child nodes.
|
||||
*/
|
||||
@Throws(IllegalArgumentException::class)
|
||||
fun build(threshold: Int? = null): PublicKey {
|
||||
val n = children.size
|
||||
if (n > 1)
|
||||
return CompositeKey(threshold ?: n, children)
|
||||
return CompositeKey(threshold ?: children.map { (_, weight) -> weight }.sum(), children)
|
||||
else if (n == 1) {
|
||||
require(threshold == null || threshold == children.first().weight)
|
||||
{ "Trying to build invalid CompositeKey, threshold value different than weight of single child node." }
|
||||
|
@ -4,6 +4,7 @@ import net.corda.core.serialization.OpaqueBytes
|
||||
import net.corda.core.serialization.serialize
|
||||
import org.junit.Test
|
||||
import kotlin.test.assertEquals
|
||||
import kotlin.test.assertFailsWith
|
||||
import kotlin.test.assertFalse
|
||||
import kotlin.test.assertTrue
|
||||
|
||||
@ -21,7 +22,6 @@ class CompositeKeyTests {
|
||||
val aliceSignature = aliceKey.sign(message)
|
||||
val bobSignature = bobKey.sign(message)
|
||||
val charlieSignature = charlieKey.sign(message)
|
||||
val compositeAliceSignature = CompositeSignaturesWithKeys(listOf(aliceSignature))
|
||||
|
||||
@Test
|
||||
fun `(Alice) fulfilled by Alice signature`() {
|
||||
@ -124,4 +124,139 @@ class CompositeKeyTests {
|
||||
val brokenBobSignature = DigitalSignature.WithKey(bobSignature.by, aliceSignature.bytes)
|
||||
assertFalse { engine.verify(CompositeSignaturesWithKeys(listOf(aliceSignature, brokenBobSignature)).serialize().bytes) }
|
||||
}
|
||||
|
||||
@Test()
|
||||
fun `composite key constraints`() {
|
||||
// Zero weight.
|
||||
assertFailsWith(IllegalArgumentException::class) {
|
||||
CompositeKey.Builder().addKey(alicePublicKey, 0)
|
||||
}
|
||||
// Negative weight.
|
||||
assertFailsWith(IllegalArgumentException::class) {
|
||||
CompositeKey.Builder().addKey(alicePublicKey, -1)
|
||||
}
|
||||
// Zero threshold.
|
||||
assertFailsWith(IllegalArgumentException::class) {
|
||||
CompositeKey.Builder().addKey(alicePublicKey).build(0)
|
||||
}
|
||||
// Negative threshold.
|
||||
assertFailsWith(IllegalArgumentException::class) {
|
||||
CompositeKey.Builder().addKey(alicePublicKey).build(-1)
|
||||
}
|
||||
// Threshold > Total-weight.
|
||||
assertFailsWith(IllegalArgumentException::class) {
|
||||
CompositeKey.Builder().addKey(alicePublicKey, 2).addKey(bobPublicKey, 2).build(5)
|
||||
}
|
||||
// Threshold value different than weight of single child node.
|
||||
assertFailsWith(IllegalArgumentException::class) {
|
||||
CompositeKey.Builder().addKey(alicePublicKey, 3).build(2)
|
||||
}
|
||||
// Aggregated weight integer overflow.
|
||||
assertFailsWith(IllegalArgumentException::class) {
|
||||
CompositeKey.Builder().addKey(alicePublicKey, Int.MAX_VALUE).addKey(bobPublicKey, Int.MAX_VALUE).build()
|
||||
}
|
||||
// Duplicated children.
|
||||
assertFailsWith(IllegalArgumentException::class) {
|
||||
CompositeKey.Builder().addKeys(alicePublicKey, bobPublicKey, alicePublicKey).build()
|
||||
}
|
||||
// Duplicated composite key children.
|
||||
assertFailsWith(IllegalArgumentException::class) {
|
||||
val compositeKey1 = CompositeKey.Builder().addKeys(alicePublicKey, bobPublicKey).build()
|
||||
val compositeKey2 = CompositeKey.Builder().addKeys(bobPublicKey, alicePublicKey).build()
|
||||
CompositeKey.Builder().addKeys(compositeKey1, compositeKey2).build()
|
||||
}
|
||||
}
|
||||
|
||||
@Test()
|
||||
fun `composite key validation with graph cycle detection`() {
|
||||
val key1 = CompositeKey.Builder().addKeys(alicePublicKey, bobPublicKey).build() as CompositeKey
|
||||
val key2 = CompositeKey.Builder().addKeys(alicePublicKey, key1).build() as CompositeKey
|
||||
val key3 = CompositeKey.Builder().addKeys(alicePublicKey, key2).build() as CompositeKey
|
||||
val key4 = CompositeKey.Builder().addKeys(alicePublicKey, key3).build() as CompositeKey
|
||||
val key5 = CompositeKey.Builder().addKeys(alicePublicKey, key4).build() as CompositeKey
|
||||
val key6 = CompositeKey.Builder().addKeys(alicePublicKey, key5, key2).build() as CompositeKey
|
||||
|
||||
// Initially, there is no any graph cycle.
|
||||
key1.checkValidity()
|
||||
key2.checkValidity()
|
||||
key3.checkValidity()
|
||||
key4.checkValidity()
|
||||
key5.checkValidity()
|
||||
// The fact that key6 has a direct reference to key2 and an indirect (via path key5->key4->key3->key2)
|
||||
// does not imply a cycle, as expected (independent paths).
|
||||
key6.checkValidity()
|
||||
|
||||
// We will create a graph cycle between key5 and key3. Key5 has already a reference to key3 (via key4).
|
||||
// To create a cycle, we add a reference (child) from key3 to key5.
|
||||
// Children list is immutable, so reflection is used to inject key5 as an extra NodeAndWeight child of key3.
|
||||
val field = key3.javaClass.getDeclaredField("children")
|
||||
field.isAccessible = true
|
||||
val fixedChildren = key3.children.plus(CompositeKey.NodeAndWeight(key5, 1))
|
||||
field.set(key3, fixedChildren)
|
||||
|
||||
/* A view of the example graph cycle.
|
||||
*
|
||||
* key6
|
||||
* / \
|
||||
* key5 key2
|
||||
* /
|
||||
* key4
|
||||
* /
|
||||
* key3
|
||||
* / \
|
||||
* key2 key5
|
||||
* /
|
||||
* key1
|
||||
*
|
||||
*/
|
||||
|
||||
// Detect the graph cycle starting from key3.
|
||||
assertFailsWith(IllegalArgumentException::class) {
|
||||
key3.checkValidity()
|
||||
}
|
||||
|
||||
// Detect the graph cycle starting from key4.
|
||||
assertFailsWith(IllegalArgumentException::class) {
|
||||
key4.checkValidity()
|
||||
}
|
||||
|
||||
// Detect the graph cycle starting from key5.
|
||||
assertFailsWith(IllegalArgumentException::class) {
|
||||
key5.checkValidity()
|
||||
}
|
||||
|
||||
// Detect the graph cycle starting from key6.
|
||||
// Typically, one needs to test on the root tree-node only (thus, a validity check on key6 would be enough).
|
||||
assertFailsWith(IllegalArgumentException::class) {
|
||||
key6.checkValidity()
|
||||
}
|
||||
|
||||
// Key2 (and all paths below it, i.e. key1) are outside the graph cycle and thus, there is no impact on them.
|
||||
key2.checkValidity()
|
||||
key1.checkValidity()
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `CompositeKey from multiple signature schemes and signature verification`() {
|
||||
val (privRSA, pubRSA) = Crypto.generateKeyPair(Crypto.RSA_SHA256)
|
||||
val (privK1, pubK1) = Crypto.generateKeyPair(Crypto.ECDSA_SECP256K1_SHA256)
|
||||
val (privR1, pubR1) = Crypto.generateKeyPair(Crypto.ECDSA_SECP256R1_SHA256)
|
||||
val (privEd, pubEd) = Crypto.generateKeyPair(Crypto.EDDSA_ED25519_SHA512)
|
||||
val (privSP, pubSP) = Crypto.generateKeyPair(Crypto.SPHINCS256_SHA256)
|
||||
|
||||
val RSASignature = privRSA.sign(message.bytes, pubRSA)
|
||||
val K1Signature = privK1.sign(message.bytes, pubK1)
|
||||
val R1Signature = privR1.sign(message.bytes, pubR1)
|
||||
val EdSignature = privEd.sign(message.bytes, pubEd)
|
||||
val SPSignature = privSP.sign(message.bytes, pubSP)
|
||||
|
||||
val compositeKey = CompositeKey.Builder().addKeys(pubRSA, pubK1, pubR1, pubEd, pubSP).build() as CompositeKey
|
||||
|
||||
val signatures = listOf(RSASignature, K1Signature, R1Signature, EdSignature, SPSignature)
|
||||
assertTrue { compositeKey.isFulfilledBy(signatures.byKeys()) }
|
||||
|
||||
// One signature is missing.
|
||||
val signaturesWithoutRSA = listOf(K1Signature, R1Signature, EdSignature, SPSignature)
|
||||
assertFalse { compositeKey.isFulfilledBy(signaturesWithoutRSA.byKeys()) }
|
||||
}
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user