diff --git a/core/src/main/kotlin/net/corda/core/serialization/DefaultKryoCustomizer.kt b/core/src/main/kotlin/net/corda/core/serialization/DefaultKryoCustomizer.kt index a165c5acc2..4dad6f4037 100644 --- a/core/src/main/kotlin/net/corda/core/serialization/DefaultKryoCustomizer.kt +++ b/core/src/main/kotlin/net/corda/core/serialization/DefaultKryoCustomizer.kt @@ -19,6 +19,8 @@ import net.i2p.crypto.eddsa.EdDSAPublicKey import org.objenesis.strategy.StdInstantiatorStrategy import org.slf4j.Logger import java.io.BufferedInputStream +import java.io.FileInputStream +import java.io.InputStream import java.util.* object DefaultKryoCustomizer { @@ -53,6 +55,7 @@ object DefaultKryoCustomizer { ImmutableMapSerializer.registerSerializers(this) ImmutableMultimapSerializer.registerSerializers(this) + // InputStream subclasses whitelisting, required for attachments. register(BufferedInputStream::class.java, InputStreamSerializer) register(Class.forName("sun.net.www.protocol.jar.JarURLConnection\$JarURLInputStream"), InputStreamSerializer) @@ -81,6 +84,11 @@ object DefaultKryoCustomizer { addDefaultSerializer(Logger::class.java, LoggerSerializer) + register(FileInputStream::class.java, InputStreamSerializer) + // Required for HashCheckingStream (de)serialization. + // Note that return type should be specifically set to InputStream, otherwise it may not work, i.e. val aStream : InputStream = HashCheckingStream(...). + addDefaultSerializer(InputStream::class.java, InputStreamSerializer) + val customization = KryoSerializationCustomization(this) pluginRegistries.forEach { it.customizeSerialization(customization) } } diff --git a/core/src/test/kotlin/net/corda/core/serialization/KryoTests.kt b/core/src/test/kotlin/net/corda/core/serialization/KryoTests.kt index d27c0cbb95..1df24fe81c 100644 --- a/core/src/test/kotlin/net/corda/core/serialization/KryoTests.kt +++ b/core/src/test/kotlin/net/corda/core/serialization/KryoTests.kt @@ -8,9 +8,7 @@ import org.assertj.core.api.Assertions.assertThat import org.assertj.core.api.Assertions.assertThatThrownBy import org.bouncycastle.jce.provider.BouncyCastleProvider import org.bouncycastle.pqc.jcajce.provider.BouncyCastlePQCProvider -import org.junit.After import org.junit.Before -import org.junit.Ignore import org.junit.Test import org.slf4j.LoggerFactory import java.io.InputStream @@ -19,6 +17,8 @@ import java.time.Instant import java.util.* import kotlin.test.assertEquals import kotlin.test.assertTrue +import net.corda.node.services.persistence.NodeAttachmentService +import java.io.ByteArrayInputStream class KryoTests { @@ -132,6 +132,16 @@ class KryoTests { assertTrue(logger === logger2) } + @Test + fun `HashCheckingStream (de)serialize`() { + val rubbish = ByteArray(12345, { (it * it * 0.12345).toByte() }) + val readRubbishStream : InputStream = NodeAttachmentService.HashCheckingStream(SecureHash.sha256(rubbish), rubbish.size, ByteArrayInputStream(rubbish)).serialize(kryo).deserialize(kryo) + for (i in 0 .. 12344) { + assertEquals(rubbish[i], readRubbishStream.read().toByte()) + } + assertEquals(-1, readRubbishStream.read()) + } + @CordaSerializable private data class Person(val name: String, val birthday: Instant?) diff --git a/node/src/main/kotlin/net/corda/node/services/persistence/NodeAttachmentService.kt b/node/src/main/kotlin/net/corda/node/services/persistence/NodeAttachmentService.kt index eb79ae2cf3..4eddc22224 100644 --- a/node/src/main/kotlin/net/corda/node/services/persistence/NodeAttachmentService.kt +++ b/node/src/main/kotlin/net/corda/node/services/persistence/NodeAttachmentService.kt @@ -65,11 +65,12 @@ class NodeAttachmentService(override var storePath: Path, dataSourceProperties: * inside it, we haven't read the whole file, so we can't check the hash. But when copying it over the network * this will provide an additional safety check against user error. */ - private class HashCheckingStream(val expected: SecureHash.SHA256, - val expectedSize: Int, - input: InputStream, - private val counter: CountingInputStream = CountingInputStream(input), - private val stream: HashingInputStream = HashingInputStream(Hashing.sha256(), counter)) : FilterInputStream(stream) { + @VisibleForTesting @CordaSerializable + class HashCheckingStream(val expected: SecureHash.SHA256, + val expectedSize: Int, + input: InputStream, + private val counter: CountingInputStream = CountingInputStream(input), + private val stream: HashingInputStream = HashingInputStream(Hashing.sha256(), counter)) : FilterInputStream(stream) { override fun close() { super.close() @@ -86,7 +87,7 @@ class NodeAttachmentService(override var storePath: Path, dataSourceProperties: private val checkOnLoad: Boolean) : Attachment { override fun open(): InputStream { - var stream = ByteArrayInputStream(attachment) + val stream = ByteArrayInputStream(attachment) // This is just an optional safety check. If it slows things down too much it can be disabled. if (id is SecureHash.SHA256 && checkOnLoad)