Prohibit Java deserialisation in the Corda process (#566)

This commit is contained in:
Andrzej Cichocki
2017-04-21 16:26:35 +01:00
committed by GitHub
parent 4cb21257e6
commit 160d13b6f7
20 changed files with 215 additions and 101 deletions

View File

@ -0,0 +1,43 @@
package net.corda.node
import net.corda.core.flows.FlowLogic
import net.corda.core.getOrThrow
import net.corda.core.messaging.startFlow
import net.corda.core.node.CordaPluginRegistry
import net.corda.node.driver.driver
import net.corda.node.services.startFlowPermission
import net.corda.nodeapi.User
import org.assertj.core.api.Assertions.assertThatThrownBy
import org.junit.Test
import java.io.*
class BootTests {
@Test
fun `java deserialization is disabled`() {
driver {
val user = User("u", "p", setOf(startFlowPermission<ObjectInputStreamFlow>()))
val future = startNode(rpcUsers = listOf(user)).getOrThrow().rpcClientToNode().apply {
start(user.username, user.password)
}.proxy().startFlow(::ObjectInputStreamFlow).returnValue
assertThatThrownBy { future.getOrThrow() }.isInstanceOf(InvalidClassException::class.java).hasMessage("filter status: REJECTED")
}
}
}
class ObjectInputStreamFlow : FlowLogic<Unit>() {
override fun call() {
System.clearProperty("jdk.serialFilter") // This checks that the node has already consumed the property.
val data = ByteArrayOutputStream().apply { ObjectOutputStream(this).use { it.writeObject(object : Serializable {}) } }.toByteArray()
ObjectInputStream(data.inputStream()).use { it.readObject() }
}
}
class BootTestsPlugin : CordaPluginRegistry() {
override val requiredFlows: Map<String, Set<String>> = mapOf(ObjectInputStreamFlow::class.java.name to setOf())
}

View File

@ -1,9 +1,9 @@
package net.corda.node.driver
import com.google.common.util.concurrent.ListenableFuture
import net.corda.core.div
import net.corda.core.getOrThrow
import net.corda.core.list
import net.corda.core.node.NodeInfo
import net.corda.core.node.services.ServiceInfo
import net.corda.core.readLines
import net.corda.core.utilities.DUMMY_BANK_A
@ -19,54 +19,51 @@ import java.util.concurrent.Executors
import java.util.concurrent.ScheduledExecutorService
class DriverTests {
companion object {
val executorService: ScheduledExecutorService = Executors.newScheduledThreadPool(2)
fun nodeMustBeUp(nodeInfo: NodeInfo) {
companion object {
private val executorService: ScheduledExecutorService = Executors.newScheduledThreadPool(2)
private fun nodeMustBeUp(handleFuture: ListenableFuture<NodeHandle>) = handleFuture.getOrThrow().apply {
val hostAndPort = ArtemisMessagingComponent.toHostAndPort(nodeInfo.address)
// Check that the port is bound
addressMustBeBound(executorService, hostAndPort)
addressMustBeBound(executorService, hostAndPort, process)
}
fun nodeMustBeDown(nodeInfo: NodeInfo) {
val hostAndPort = ArtemisMessagingComponent.toHostAndPort(nodeInfo.address)
private fun nodeMustBeDown(handle: NodeHandle) {
val hostAndPort = ArtemisMessagingComponent.toHostAndPort(handle.nodeInfo.address)
// Check that the port is bound
addressMustNotBeBound(executorService, hostAndPort)
}
}
@Test
fun `simple node startup and shutdown`() {
val (notary, regulator) = driver {
val handles = driver {
val notary = startNode(DUMMY_NOTARY.name, setOf(ServiceInfo(SimpleNotaryService.type)))
val regulator = startNode("Regulator", setOf(ServiceInfo(RegulatorService.type)))
nodeMustBeUp(notary.getOrThrow().nodeInfo)
nodeMustBeUp(regulator.getOrThrow().nodeInfo)
Pair(notary.getOrThrow(), regulator.getOrThrow())
listOf(nodeMustBeUp(notary), nodeMustBeUp(regulator))
}
nodeMustBeDown(notary.nodeInfo)
nodeMustBeDown(regulator.nodeInfo)
handles.map { nodeMustBeDown(it) }
}
@Test
fun `starting node with no services`() {
val noService = driver {
val noService = startNode(DUMMY_BANK_A.name)
nodeMustBeUp(noService.getOrThrow().nodeInfo)
noService.getOrThrow()
nodeMustBeUp(noService)
}
nodeMustBeDown(noService.nodeInfo)
nodeMustBeDown(noService)
}
@Test
fun `random free port allocation`() {
val nodeHandle = driver(portAllocation = PortAllocation.RandomFree) {
val nodeInfo = startNode(DUMMY_BANK_A.name)
nodeMustBeUp(nodeInfo.getOrThrow().nodeInfo)
nodeInfo.getOrThrow()
nodeMustBeUp(nodeInfo)
}
nodeMustBeDown(nodeHandle.nodeInfo)
nodeMustBeDown(nodeHandle)
}
@Test
@ -81,4 +78,5 @@ class DriverTests {
assertThat(debugLinesPresent).isTrue()
}
}
}

View File

@ -0,0 +1 @@
net.corda.node.BootTestsPlugin

View File

@ -9,6 +9,7 @@ import net.corda.core.*
import net.corda.core.node.NodeVersionInfo
import net.corda.core.node.Version
import net.corda.core.utilities.Emoji
import net.corda.core.utilities.LogHelper.withLevel
import net.corda.node.internal.Node
import net.corda.node.services.config.FullNodeConfiguration
import net.corda.node.shell.InteractiveShell
@ -17,10 +18,12 @@ import net.corda.node.utilities.registration.NetworkRegistrationHelper
import org.fusesource.jansi.Ansi
import org.fusesource.jansi.AnsiConsole
import org.slf4j.LoggerFactory
import org.slf4j.bridge.SLF4JBridgeHandler
import java.io.*
import java.lang.management.ManagementFactory
import java.net.InetAddress
import java.nio.file.Path
import java.nio.file.Paths
import java.util.Locale
import kotlin.system.exitProcess
private var renderBasicInfoToConsole = true
@ -34,9 +37,21 @@ fun printBasicNodeInfo(description: String, info: String? = null) {
val LOGS_DIRECTORY_NAME = "logs"
private fun initLogging(cmdlineOptions: CmdLineOptions) {
val loggingLevel = cmdlineOptions.loggingLevel.name.toLowerCase(Locale.ENGLISH)
System.setProperty("defaultLogLevel", loggingLevel) // These properties are referenced from the XML config file.
if (cmdlineOptions.logToConsole) {
System.setProperty("consoleLogLevel", loggingLevel)
renderBasicInfoToConsole = false
}
System.setProperty("log-path", (cmdlineOptions.baseDirectory / LOGS_DIRECTORY_NAME).toString())
SLF4JBridgeHandler.removeHandlersForRootLogger() // The default j.u.l config adds a ConsoleHandler.
SLF4JBridgeHandler.install()
}
fun main(args: Array<String>) {
val startTime = System.currentTimeMillis()
checkJavaVersion()
assertCanNormalizeEmptyPath()
val argsParser = ArgsParser()
@ -48,13 +63,8 @@ fun main(args: Array<String>) {
exitProcess(1)
}
// Set up logging. These properties are referenced from the XML config file.
val loggingLevel = cmdlineOptions.loggingLevel.name.toLowerCase()
System.setProperty("defaultLogLevel", loggingLevel)
if (cmdlineOptions.logToConsole) {
System.setProperty("consoleLogLevel", loggingLevel)
renderBasicInfoToConsole = false
}
initLogging(cmdlineOptions)
disableJavaDeserialization() // Should be after initLogging to avoid TMI.
// Manifest properties are only available if running from the corda jar
fun manifestValue(name: String): String? = if (Manifests.exists(name)) Manifests.read(name) else null
@ -79,9 +89,6 @@ fun main(args: Array<String>) {
drawBanner(nodeVersionInfo)
val dir: Path = cmdlineOptions.baseDirectory
System.setProperty("log-path", (dir / "logs").toString())
val log = LoggerFactory.getLogger("Main")
printBasicNodeInfo("Logs can be found in", System.getProperty("log-path"))
@ -137,7 +144,7 @@ fun main(args: Array<String>) {
val runShell = !cmdlineOptions.noLocalShell && System.console() != null
node.startupComplete then {
try {
InteractiveShell.startShell(dir, runShell, cmdlineOptions.sshdServer, node)
InteractiveShell.startShell(cmdlineOptions.baseDirectory, runShell, cmdlineOptions.sshdServer, node)
} catch(e: Throwable) {
log.error("Shell failed to start", e)
}
@ -155,15 +162,34 @@ fun main(args: Array<String>) {
exitProcess(0)
}
private fun checkJavaVersion() {
private fun assertCanNormalizeEmptyPath() {
// Check we're not running a version of Java with a known bug: https://github.com/corda/corda/issues/83
try {
Paths.get("").normalize()
} catch (e: ArrayIndexOutOfBoundsException) {
println("""
javaIsTooOld()
}
}
private fun javaIsTooOld(): Nothing {
println("""
You are using a version of Java that is not supported (${System.getProperty("java.version")}). Please upgrade to the latest version.
Corda will now exit...""")
exitProcess(1)
exitProcess(1)
}
private fun disableJavaDeserialization() {
// ObjectInputFilter and friends are in java.io in Java 9 but sun.misc in backports, so we use the system property interface for portability:
System.setProperty("jdk.serialFilter", "maxbytes=0")
// Attempt a deserialization so that ObjectInputFilter (permanently) inits itself:
val data = ByteArrayOutputStream().apply { ObjectOutputStream(this).use { it.writeObject(object : Serializable {}) } }.toByteArray()
try {
withLevel("java.io.serialization", "WARN") {
ObjectInputStream(data.inputStream()).use { it.readObject() } // Logs REJECTED at INFO, which we don't want users to see.
}
javaIsTooOld()
} catch (e: InvalidClassException) {
// Good, our system property is honoured (assuming ObjectInputFilter wasn't inited earlier).
}
}

View File

@ -96,7 +96,7 @@ interface DriverDSLExposedInterface {
*
* @param handle The handle for the node that this webserver connects to via RPC.
*/
fun startWebserver(handle: NodeHandle): ListenableFuture<HostAndPort>
fun startWebserver(handle: NodeHandle): ListenableFuture<WebserverHandle>
/**
* Starts a network map service node. Note that only a single one should ever be running, so you will probably want
@ -122,6 +122,11 @@ data class NodeHandle(
fun rpcClientToNode(): CordaRPCClient = CordaRPCClient(configuration.rpcAddress!!)
}
data class WebserverHandle(
val listenAddress: HostAndPort,
val process: Process
)
sealed class PortAllocation {
abstract fun nextPort(): Int
fun nextHostAndPort(): HostAndPort = HostAndPort.fromParts("localhost", nextPort())
@ -228,8 +233,16 @@ fun getTimestampAsDirectoryName(): String {
return DateTimeFormatter.ofPattern("yyyyMMddHHmmss").withZone(UTC).format(Instant.now())
}
fun addressMustBeBound(executorService: ScheduledExecutorService, hostAndPort: HostAndPort): ListenableFuture<Unit> {
class ListenProcessDeathException(message: String) : Exception(message)
/**
* @throws ListenProcessDeathException if [listenProcess] dies before the check succeeds, i.e. the check can't succeed as intended.
*/
fun addressMustBeBound(executorService: ScheduledExecutorService, hostAndPort: HostAndPort, listenProcess: Process): ListenableFuture<Unit> {
return poll(executorService, "address $hostAndPort to bind") {
if (!listenProcess.isAlive) {
throw ListenProcessDeathException("The process that was expected to listen on $hostAndPort has died with status: ${listenProcess.exitValue()}")
}
try {
Socket(hostAndPort.host, hostAndPort.port).close()
Unit
@ -265,12 +278,17 @@ fun <A> poll(
}
var counter = 0
fun schedulePoll() {
executorService.schedule({
executorService.schedule(task@ {
counter++
if (counter == warnCount) {
log.warn("Been polling $pollName for ${pollIntervalMs * warnCount / 1000.0} seconds...")
}
val result = check()
val result = try {
check()
} catch (t: Throwable) {
resultFuture.setException(t)
return@task
}
if (result == null) {
schedulePoll()
} else {
@ -482,7 +500,7 @@ class DriverDSL(
}
}
private fun queryWebserver(handle: NodeHandle, process: Process): HostAndPort {
private fun queryWebserver(handle: NodeHandle, process: Process): WebserverHandle {
val protocol = if (handle.configuration.useHTTPS) "https://" else "http://"
val url = URL("$protocol${handle.webAddress}/api/status")
val client = OkHttpClient.Builder().connectTimeout(5, SECONDS).readTimeout(60, SECONDS).build()
@ -490,7 +508,7 @@ class DriverDSL(
while (process.isAlive) try {
val response = client.newCall(Request.Builder().url(url).build()).execute()
if (response.isSuccessful && (response.body().string() == "started")) {
return handle.webAddress
return WebserverHandle(handle.webAddress, process)
}
} catch(e: ConnectException) {
log.debug("Retrying webserver info at ${handle.webAddress}")
@ -499,13 +517,11 @@ class DriverDSL(
throw IllegalStateException("Webserver at ${handle.webAddress} has died")
}
override fun startWebserver(handle: NodeHandle): ListenableFuture<HostAndPort> {
override fun startWebserver(handle: NodeHandle): ListenableFuture<WebserverHandle> {
val debugPort = if (isDebug) debugPortAllocation.nextPort() else null
val process = DriverDSL.startWebserver(executorService, handle, debugPort)
registerProcess(process)
return process.map {
queryWebserver(handle, it)
}
val processFuture = DriverDSL.startWebserver(executorService, handle, debugPort)
registerProcess(processFuture)
return processFuture.map { queryWebserver(handle, it) }
}
override fun start() {
@ -577,7 +593,7 @@ class DriverDSL(
errorLogPath = nodeConf.baseDirectory / LOGS_DIRECTORY_NAME / "error.log",
workingDirectory = nodeConf.baseDirectory
)
}.flatMap { process -> addressMustBeBound(executorService, nodeConf.p2pAddress).map { process } }
}.flatMap { process -> addressMustBeBound(executorService, nodeConf.p2pAddress, process).map { process } }
}
private fun startWebserver(
@ -594,7 +610,7 @@ class DriverDSL(
extraJvmArguments = listOf("-Dname=node-${handle.configuration.p2pAddress}-webserver"),
errorLogPath = Paths.get("error.$className.log")
)
}.flatMap { process -> addressMustBeBound(executorService, handle.webAddress).map { process } }
}.flatMap { process -> addressMustBeBound(executorService, handle.webAddress, process).map { process } }
}
}
}

View File

@ -1,7 +1,10 @@
package net.corda.node.services.transactions
import com.google.common.net.HostAndPort
import io.atomix.catalyst.buffer.BufferInput
import io.atomix.catalyst.buffer.BufferOutput
import io.atomix.catalyst.serializer.Serializer
import io.atomix.catalyst.serializer.TypeSerializer
import io.atomix.catalyst.transport.Address
import io.atomix.catalyst.transport.Transport
import io.atomix.catalyst.transport.netty.NettyTransport
@ -69,7 +72,21 @@ class RaftUniquenessProvider(
val address = Address(myAddress.host, myAddress.port)
val storage = buildStorage(storagePath)
val transport = buildTransport(config)
val serializer = Serializer()
val serializer = Serializer().apply {
// Add serializers so Catalyst doesn't attempt to fall back on Java serialization for these types, which is disabled process-wide:
register(DistributedImmutableMap.Commands.PutAll::class.java) {
object : TypeSerializer<DistributedImmutableMap.Commands.PutAll<*, *>> {
override fun write(obj: DistributedImmutableMap.Commands.PutAll<*, *>, buffer: BufferOutput<out BufferOutput<*>>, serializer: Serializer) = writeMap(obj.entries, buffer, serializer)
override fun read(type: Class<DistributedImmutableMap.Commands.PutAll<*, *>>, buffer: BufferInput<out BufferInput<*>>, serializer: Serializer) = DistributedImmutableMap.Commands.PutAll(readMap(buffer, serializer))
}
}
register(LinkedHashMap::class.java) {
object : TypeSerializer<LinkedHashMap<*, *>> {
override fun write(obj: LinkedHashMap<*, *>, buffer: BufferOutput<out BufferOutput<*>>, serializer: Serializer) = writeMap(obj, buffer, serializer)
override fun read(type: Class<LinkedHashMap<*, *>>, buffer: BufferInput<out BufferInput<*>>, serializer: Serializer) = readMap(buffer, serializer)
}
}
}
server = CopycatServer.builder(address)
.withStateMachine(stateMachineFactory)
@ -141,4 +158,16 @@ class RaftUniquenessProvider(
fun String.toStateRef() = split(":").let { StateRef(SecureHash.parse(it[0]), it[1].toInt()) }
return items.map { it.key.toStateRef() to it.value.deserialize<UniquenessProvider.ConsumingTx>() }.toMap()
}
}
}
private fun writeMap(map: Map<*, *>, buffer: BufferOutput<out BufferOutput<*>>, serializer: Serializer) = with(map) {
buffer.writeInt(size)
forEach {
with(serializer) {
writeObject(it.key, buffer)
writeObject(it.value, buffer)
}
}
}
private fun readMap(buffer: BufferInput<out BufferInput<*>>, serializer: Serializer) = LinkedHashMap<Any, Any>().apply { repeat(buffer.readInt()) { put(serializer.readObject(buffer), serializer.readObject(buffer)) } }

View File

@ -54,8 +54,6 @@ import java.util.*
import java.util.concurrent.CountDownLatch
import java.util.concurrent.ExecutionException
import java.util.concurrent.Future
import java.util.logging.Level
import java.util.logging.Logger
import kotlin.concurrent.thread
// TODO: Add command history.
@ -80,8 +78,6 @@ object InteractiveShell {
this.node = node
var runSSH = runSSHServer
Logger.getLogger("").level = Level.OFF // TODO: Is this really needed?
val config = Properties()
if (runSSH) {
// TODO: Finish and enable SSH access.