Backward compatible fix for Merkle tree case where a single leaf becomes the tree root (#6895)

This commit is contained in:
Edoardo Ierina 2021-04-12 11:09:39 +01:00 committed by GitHub
parent bde2d2a8c7
commit 241170ffa4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 242 additions and 23 deletions

View File

@ -5,8 +5,11 @@ import com.nhaarman.mockito_kotlin.mock
import com.nhaarman.mockito_kotlin.whenever
import net.corda.core.contracts.*
import net.corda.core.crypto.*
import net.corda.core.crypto.internal.DigestAlgorithmFactory
import net.corda.core.identity.CordaX500Name
import net.corda.core.identity.Party
import net.corda.core.internal.BLAKE2s256DigestAlgorithm
import net.corda.core.internal.SHA256BLAKE2s256DigestAlgorithm
import net.corda.core.node.NotaryInfo
import net.corda.core.node.services.IdentityService
import net.corda.core.serialization.deserialize
@ -49,12 +52,19 @@ class PartialMerkleTreeTest(private var digestService: DigestService) {
val MINI_CORP get() = miniCorp.party
val MINI_CORP_PUBKEY get() = miniCorp.publicKey
init {
DigestAlgorithmFactory.registerClass(BLAKE2s256DigestAlgorithm::class.java.name)
DigestAlgorithmFactory.registerClass(SHA256BLAKE2s256DigestAlgorithm::class.java.name)
}
@JvmStatic
@Parameterized.Parameters
fun data(): Collection<DigestService> = listOf(
DigestService.sha2_256,
DigestService.sha2_384,
DigestService.sha2_512
DigestService.sha2_512,
DigestService("BLAKE_TEST"),
DigestService("SHA256-BLAKE2S256-TEST")
)
}

View File

@ -140,7 +140,7 @@ class PartialMerkleTreeWithNamedHashTest {
fun `building Merkle tree one node`() {
val node = 'a'.serialize().sha2_384()
val mt = MerkleTree.getMerkleTree(listOf(node), DigestService.sha2_384)
assertEquals(node, mt.hash)
assertNotEquals(node, mt.hash)
}
@Test(timeout=300_000)

View File

@ -0,0 +1,186 @@
package net.corda.coretests.transactions
import net.corda.core.contracts.ComponentGroupEnum.COMMANDS_GROUP
import net.corda.core.contracts.ComponentGroupEnum.INPUTS_GROUP
import net.corda.core.contracts.ComponentGroupEnum.NOTARY_GROUP
import net.corda.core.contracts.ComponentGroupEnum.OUTPUTS_GROUP
import net.corda.core.contracts.ComponentGroupEnum.SIGNERS_GROUP
import net.corda.core.contracts.ComponentGroupEnum.TIMEWINDOW_GROUP
import net.corda.core.contracts.StateRef
import net.corda.core.contracts.PrivacySalt
import net.corda.core.contracts.TimeWindow
import net.corda.core.contracts.TransactionState
import net.corda.core.crypto.DigestService
import net.corda.core.crypto.SecureHash
import net.corda.core.crypto.generateKeyPair
import net.corda.core.crypto.internal.DigestAlgorithmFactory
import net.corda.core.internal.SHA256BLAKE2s256DigestAlgorithm
import net.corda.core.internal.accessAvailableComponentHashes
import net.corda.core.internal.accessAvailableComponentNonces
import net.corda.core.internal.accessGroupMerkleRoots
import net.corda.core.serialization.serialize
import net.corda.core.transactions.ComponentGroup
import net.corda.core.transactions.WireTransaction
import net.corda.testing.contracts.DummyContract
import net.corda.testing.contracts.DummyState
import net.corda.testing.core.BOB_NAME
import net.corda.testing.core.DUMMY_NOTARY_NAME
import net.corda.testing.core.SerializationEnvironmentRule
import net.corda.testing.core.TestIdentity
import net.corda.testing.core.dummyCommand
import org.junit.Before
import org.junit.Rule
import org.junit.Test
import java.time.Instant
import kotlin.test.assertEquals
class MerkleTreeAgilityTest {
private companion object {
val DUMMY_KEY_1 = generateKeyPair()
val DUMMY_KEY_2 = generateKeyPair()
val BOB = TestIdentity(BOB_NAME, 80).party
val DUMMY_NOTARY = TestIdentity(DUMMY_NOTARY_NAME, 20).party
}
@Rule
@JvmField
val testSerialization = SerializationEnvironmentRule()
private val dummyOutState = TransactionState(DummyState(0), DummyContract.PROGRAM_ID, DUMMY_NOTARY)
private val stateRef1 = StateRef(SecureHash.randomSHA256(), 0)
private val stateRef2 = StateRef(SecureHash.randomSHA256(), 1)
private val stateRef3 = StateRef(SecureHash.randomSHA256(), 2)
private val stateRef4 = StateRef(SecureHash.randomSHA256(), 3)
private val singleInput = listOf(stateRef1) // 1 elements.
private val threeInputs = listOf(stateRef1, stateRef2, stateRef3) // 3 elements.
private val fourInputs = listOf(stateRef1, stateRef2, stateRef3, stateRef4) // 4 elements.
private val outputs = listOf(dummyOutState, dummyOutState.copy(notary = BOB)) // 2 elements.
private val commands = listOf(dummyCommand(DUMMY_KEY_1.public, DUMMY_KEY_2.public)) // 1 element.
private val notary = DUMMY_NOTARY
private val timeWindow = TimeWindow.fromOnly(Instant.now())
private val privacySalt: PrivacySalt = PrivacySalt()
private val singleInputsGroup by lazy { ComponentGroup(INPUTS_GROUP.ordinal, singleInput.map { it.serialize() }) }
private val threeInputsGroup by lazy { ComponentGroup(INPUTS_GROUP.ordinal, threeInputs.map { it.serialize() }) }
private val fourInputsGroup by lazy { ComponentGroup(INPUTS_GROUP.ordinal, fourInputs.map { it.serialize() }) }
private val outputGroup by lazy { ComponentGroup(OUTPUTS_GROUP.ordinal, outputs.map { it.serialize() }) }
private val commandGroup by lazy { ComponentGroup(COMMANDS_GROUP.ordinal, commands.map { it.value.serialize() }) }
private val notaryGroup by lazy { ComponentGroup(NOTARY_GROUP.ordinal, listOf(notary.serialize())) }
private val timeWindowGroup by lazy { ComponentGroup(TIMEWINDOW_GROUP.ordinal, listOf(timeWindow.serialize())) }
private val signersGroup by lazy { ComponentGroup(SIGNERS_GROUP.ordinal, commands.map { it.signers.serialize() }) }
private val componentGroupsSingle by lazy {
listOf(singleInputsGroup, outputGroup, commandGroup, notaryGroup, timeWindowGroup, signersGroup)
}
private val componentGroupsFourInputs by lazy {
listOf(fourInputsGroup, outputGroup, commandGroup, notaryGroup, timeWindowGroup, signersGroup)
}
private val componentGroupsThreeInputs by lazy {
listOf(threeInputsGroup, outputGroup, commandGroup, notaryGroup, timeWindowGroup, signersGroup)
}
private val defaultDigestService = DigestService.sha2_256
private val customDigestService = DigestService("SHA256-BLAKE2S256-TEST")
@Before
fun before() {
DigestAlgorithmFactory.registerClass(SHA256BLAKE2s256DigestAlgorithm::class.java.name)
}
@Test(timeout = 300_000)
fun `component nonces are correct for custom preimage resistant hash algo`() {
val wireTransaction = WireTransaction(componentGroups = componentGroupsFourInputs, privacySalt = privacySalt, digestService = customDigestService)
val expected = componentGroupsFourInputs.associate {
it.groupIndex to it.components.mapIndexed { componentIndexInGroup, _ ->
customDigestService.computeNonce(privacySalt, it.groupIndex, componentIndexInGroup)
}
}
assertEquals(expected, wireTransaction.accessAvailableComponentNonces())
}
@Test(timeout = 300_000)
fun `component nonces are correct for default SHA256 hash algo`() {
val wireTransaction = WireTransaction(componentGroups = componentGroupsFourInputs, privacySalt = privacySalt)
val expected = componentGroupsFourInputs.associate {
it.groupIndex to it.components.mapIndexed { componentIndexInGroup, componentBytes ->
defaultDigestService.componentHash(componentBytes, privacySalt, it.groupIndex, componentIndexInGroup)
}
}
assertEquals(expected, wireTransaction.accessAvailableComponentNonces())
}
@Test(timeout = 300_000)
fun `custom algorithm transaction pads leaf in single component component group`() {
val wtx = WireTransaction(componentGroups = componentGroupsSingle, privacySalt = privacySalt, digestService = customDigestService)
val inputsTreeLeaves: List<SecureHash> = wtx.accessAvailableComponentHashes()[INPUTS_GROUP.ordinal]!!
val expected = customDigestService.hash(inputsTreeLeaves[0].bytes + customDigestService.zeroHash.bytes)
assertEquals(expected, wtx.accessGroupMerkleRoots()[INPUTS_GROUP.ordinal]!!)
}
@Test(timeout = 300_000)
fun `default algorithm transaction does not pad leaf in single component component group`() {
val wtx = WireTransaction(componentGroups = componentGroupsSingle, privacySalt = privacySalt, digestService = defaultDigestService)
val inputsTreeLeaves: List<SecureHash> = wtx.accessAvailableComponentHashes()[INPUTS_GROUP.ordinal]!!
val expected = inputsTreeLeaves[0]
assertEquals(expected, wtx.accessGroupMerkleRoots()[INPUTS_GROUP.ordinal]!!)
}
@Test(timeout = 300_000)
fun `custom algorithm transaction has expected root for four components component group tree`() {
val wtx = WireTransaction(componentGroups = componentGroupsFourInputs, privacySalt = privacySalt, digestService = customDigestService)
val inputsTreeLeaves: List<SecureHash> = wtx.accessAvailableComponentHashes()[INPUTS_GROUP.ordinal]!!
val h1 = customDigestService.hash(inputsTreeLeaves[0].bytes + inputsTreeLeaves[1].bytes)
val h2 = customDigestService.hash(inputsTreeLeaves[2].bytes + inputsTreeLeaves[3].bytes)
val expected = customDigestService.hash(h1.bytes + h2.bytes)
assertEquals(expected, wtx.accessGroupMerkleRoots()[INPUTS_GROUP.ordinal]!!)
}
@Test(timeout = 300_000)
fun `default algorithm transaction has expected root for four components component group tree`() {
val wtx = WireTransaction(componentGroups = componentGroupsFourInputs, privacySalt = privacySalt, digestService = defaultDigestService)
val inputsTreeLeaves: List<SecureHash> = wtx.accessAvailableComponentHashes()[INPUTS_GROUP.ordinal]!!
val h1 = defaultDigestService.hash(inputsTreeLeaves[0].bytes + inputsTreeLeaves[1].bytes)
val h2 = defaultDigestService.hash(inputsTreeLeaves[2].bytes + inputsTreeLeaves[3].bytes)
val expected = defaultDigestService.hash(h1.bytes + h2.bytes)
assertEquals(expected, wtx.accessGroupMerkleRoots()[INPUTS_GROUP.ordinal]!!)
}
@Test(timeout = 300_000)
fun `custom algorithm transaction has expected root for three components component group tree`() {
val wtx = WireTransaction(componentGroups = componentGroupsThreeInputs, privacySalt = privacySalt, digestService = customDigestService)
val inputsTreeLeaves: List<SecureHash> = wtx.accessAvailableComponentHashes()[INPUTS_GROUP.ordinal]!!
val h1 = customDigestService.hash(inputsTreeLeaves[0].bytes + inputsTreeLeaves[1].bytes)
val h2 = customDigestService.hash(inputsTreeLeaves[2].bytes + customDigestService.zeroHash.bytes)
val expected = customDigestService.hash(h1.bytes + h2.bytes)
assertEquals(expected, wtx.accessGroupMerkleRoots()[INPUTS_GROUP.ordinal]!!)
}
@Test(timeout = 300_000)
fun `default algorithm transaction has expected root for three components component group tree`() {
val wtx = WireTransaction(componentGroups = componentGroupsThreeInputs, privacySalt = privacySalt, digestService = defaultDigestService)
val inputsTreeLeaves: List<SecureHash> = wtx.accessAvailableComponentHashes()[INPUTS_GROUP.ordinal]!!
val h1 = defaultDigestService.hash(inputsTreeLeaves[0].bytes + inputsTreeLeaves[1].bytes)
val h2 = defaultDigestService.hash(inputsTreeLeaves[2].bytes + defaultDigestService.zeroHash.bytes)
val expected = defaultDigestService.hash(h1.bytes + h2.bytes)
assertEquals(expected, wtx.accessGroupMerkleRoots()[INPUTS_GROUP.ordinal]!!)
}
}

View File

@ -37,19 +37,19 @@ sealed class MerkleTree {
require(algorithms.size == 1) {
"Cannot build Merkle tree with multiple hash algorithms: $algorithms"
}
val leaves = padWithZeros(allLeavesHashes).map { Leaf(it) }
val leaves = padWithZeros(allLeavesHashes, nodeDigestService.hashAlgorithm == SecureHash.SHA2_256).map { Leaf(it) }
return buildMerkleTree(leaves, nodeDigestService)
}
// If number of leaves in the tree is not a power of 2, we need to pad it with zero hashes.
private fun padWithZeros(allLeavesHashes: List<SecureHash>): List<SecureHash> {
private fun padWithZeros(allLeavesHashes: List<SecureHash>, singleLeafWithoutPadding: Boolean): List<SecureHash> {
var n = allLeavesHashes.size
if (isPow2(n)) return allLeavesHashes
if (isPow2(n) && (n > 1 || singleLeafWithoutPadding)) return allLeavesHashes
val paddedHashes = ArrayList(allLeavesHashes)
val zeroHash = SecureHash.zeroHashFor(paddedHashes[0].algorithm)
while (!isPow2(n++)) {
do {
paddedHashes.add(zeroHash)
}
} while (!isPow2(++n))
return paddedHashes
}

View File

@ -1,33 +1,19 @@
package net.corda.core.crypto
import net.corda.core.crypto.internal.DigestAlgorithmFactory
import org.bouncycastle.crypto.digests.Blake2sDigest
import net.corda.core.internal.BLAKE2s256DigestAlgorithm
import org.junit.Assert.assertArrayEquals
import org.junit.Before
import org.junit.Test
import kotlin.test.assertEquals
class Blake2s256DigestServiceTest {
class BLAKE2s256DigestService : DigestAlgorithm {
override val algorithm = "BLAKE_TEST"
override val digestLength = 32
override fun digest(bytes: ByteArray): ByteArray {
val blake2s256 = Blake2sDigest(null, digestLength, null, "12345678".toByteArray())
blake2s256.reset()
blake2s256.update(bytes, 0, bytes.size)
val hash = ByteArray(digestLength)
blake2s256.doFinal(hash, 0)
return hash
}
}
private val service = DigestService("BLAKE_TEST")
@Before
fun before() {
DigestAlgorithmFactory.registerClass(BLAKE2s256DigestService::class.java.name)
DigestAlgorithmFactory.registerClass(BLAKE2s256DigestAlgorithm::class.java.name)
}
@Test(timeout = 300_000)

View File

@ -0,0 +1,36 @@
package net.corda.core.internal
import net.corda.core.crypto.DigestAlgorithm
import net.corda.core.crypto.SecureHash
import org.bouncycastle.crypto.digests.Blake2sDigest
/**
* A set of custom hash algorithms
*/
open class BLAKE2s256DigestAlgorithm : DigestAlgorithm {
override val algorithm = "BLAKE_TEST"
override val digestLength = 32
protected fun blake2sHash(bytes: ByteArray): ByteArray {
val blake2s256 = Blake2sDigest(null, digestLength, null, "12345678".toByteArray())
blake2s256.reset()
blake2s256.update(bytes, 0, bytes.size)
val hash = ByteArray(digestLength)
blake2s256.doFinal(hash, 0)
return hash
}
override fun digest(bytes: ByteArray): ByteArray = blake2sHash(bytes)
}
class SHA256BLAKE2s256DigestAlgorithm : BLAKE2s256DigestAlgorithm() {
override val algorithm = "SHA256-BLAKE2S256-TEST"
override fun digest(bytes: ByteArray): ByteArray = SecureHash.hashAs(SecureHash.SHA2_256, bytes).bytes
override fun componentDigest(bytes: ByteArray): ByteArray = blake2sHash(bytes)
override fun nonceDigest(bytes: ByteArray): ByteArray = blake2sHash(bytes)
}

View File

@ -18,6 +18,7 @@ fun WireTransaction.accessGroupHashes() = this.groupHashes
fun WireTransaction.accessGroupMerkleRoots() = this.groupsMerkleRoots
fun WireTransaction.accessAvailableComponentHashes() = this.availableComponentHashes
fun WireTransaction.accessAvailableComponentNonces() = this.availableComponentNonces
@Suppress("LongParameterList")
fun createLedgerTransaction(