ENT-11126: Use UNIX domain socket for communication with external verifier

These have the advantage of being more secure as only the current user has access to them and faster than local TCP as it avoids the entire TCP stack.
This commit is contained in:
Shams Asari 2024-03-26 10:29:19 +00:00
parent ea1aec1267
commit 62819f27f0
6 changed files with 148 additions and 67 deletions

View File

@ -47,6 +47,7 @@ import kotlin.io.path.copyTo
import kotlin.io.path.div
import kotlin.io.path.listDirectoryEntries
import kotlin.io.path.readText
import kotlin.io.path.useLines
class ExternalVerificationSignedCordappsTest {
private companion object {
@ -84,6 +85,16 @@ class ExternalVerificationSignedCordappsTest {
@JvmStatic
fun close() {
factory.close()
// Make sure all UNIX domain files are deleted
(notaries + currentNode).forEach { node ->
node.logFile("node")!!.useLines { lines ->
for (line in lines) {
if ("ExternalVerifierHandleImpl" in line && "Binding to UNIX domain file " in line) {
assertThat(Path(line.substringAfterLast("Binding to UNIX domain file "))).doesNotExist()
}
}
}
}
}
}
@ -262,7 +273,7 @@ private fun <T> Observable<T>.waitForFirst(predicate: (T) -> Boolean): Completab
}
private fun NodeProcess.assertTransactionsWereVerified(verificationType: VerificationType, vararg txIds: SecureHash) {
val nodeLogs = logs("node")!!
val nodeLogs = logContents("node")!!
val externalVerifierLogs = externalVerifierLogs()
for (txId in txIds) {
assertThat(nodeLogs).contains("WireTransaction(id=$txId) will be verified ${verificationType.logStatement}")
@ -273,15 +284,14 @@ private fun NodeProcess.assertTransactionsWereVerified(verificationType: Verific
}
}
private fun NodeProcess.externalVerifierLogs(): String? = logs("verifier")
private fun NodeProcess.externalVerifierLogs(): String? = logContents("verifier")
private fun NodeProcess.logs(name: String): String? {
return (nodeDir / "logs")
.listDirectoryEntries("$name-${InetAddress.getLocalHost().hostName}.log")
.singleOrNull()
?.readText()
private fun NodeProcess.logFile(name: String): Path? {
return (nodeDir / "logs").listDirectoryEntries("$name-${InetAddress.getLocalHost().hostName}.log").singleOrNull()
}
private fun NodeProcess.logContents(name: String): String? = logFile(name)?.readText()
private enum class VerificationType {
IN_PROCESS, EXTERNAL, BOTH;

View File

@ -1,6 +1,7 @@
package net.corda.node.verification
import net.corda.core.contracts.Attachment
import net.corda.core.crypto.random63BitValue
import net.corda.core.internal.AbstractAttachment
import net.corda.core.internal.copyTo
import net.corda.core.internal.level
@ -35,19 +36,27 @@ import net.corda.serialization.internal.verifier.ExternalVerifierOutbound.Verifi
import net.corda.serialization.internal.verifier.ExternalVerifierOutbound.VerifierRequest.GetTrustedClassAttachments
import net.corda.serialization.internal.verifier.readCordaSerializable
import net.corda.serialization.internal.verifier.writeCordaSerializable
import java.io.DataInputStream
import java.io.DataOutputStream
import java.io.IOException
import java.lang.Character.MAX_RADIX
import java.lang.ProcessBuilder.Redirect
import java.lang.management.ManagementFactory
import java.net.ServerSocket
import java.net.Socket
import java.net.StandardProtocolFamily
import java.net.UnixDomainSocketAddress
import java.nio.channels.ServerSocketChannel
import java.nio.channels.SocketChannel
import java.nio.file.Files
import java.nio.file.Path
import java.nio.file.StandardCopyOption.REPLACE_EXISTING
import java.nio.file.attribute.PosixFileAttributeView
import java.nio.file.attribute.PosixFilePermissions.fromString
import kotlin.io.path.Path
import kotlin.io.path.absolutePathString
import kotlin.io.path.createDirectories
import kotlin.io.path.deleteIfExists
import kotlin.io.path.div
import kotlin.io.path.fileAttributesViewOrNull
import kotlin.io.path.isExecutable
import kotlin.io.path.isWritable
/**
* Handle to the node's external verifier. The verifier process is started lazily on the first verification request.
@ -67,11 +76,13 @@ class ExternalVerifierHandleImpl(
Companion::class.java.getResourceAsStream("external-verifier.jar")!!.use {
it.copyTo(verifierJar, REPLACE_EXISTING)
}
log.debug { "Extracted external verifier jar to ${verifierJar.absolutePathString()}" }
verifierJar.toFile().deleteOnExit()
}
}
private lateinit var server: ServerSocket
private lateinit var socketFile: Path
private lateinit var serverChannel: ServerSocketChannel
@Volatile
private var connection: Connection? = null
@ -104,8 +115,16 @@ class ExternalVerifierHandleImpl(
}
private fun startServer() {
if (::server.isInitialized) return
server = ServerSocket(0)
if (::socketFile.isInitialized) return
// Try to create the UNIX domain file in /tmp to keep the full path under the 100 char limit. If we don't have access to it then
// fallback to the temp dir specified by the JVM and hope it's short enough.
val tempDir = Path("/tmp").takeIf { it.isWritable() && it.isExecutable() } ?: Path(System.getProperty("java.io.tmpdir"))
socketFile = tempDir / "corda-external-verifier-${random63BitValue().toString(MAX_RADIX)}.socket"
serverChannel = ServerSocketChannel.open(StandardProtocolFamily.UNIX)
log.debug { "Binding to UNIX domain file $socketFile" }
serverChannel.bind(UnixDomainSocketAddress.of(socketFile), 1)
// Lock down access to the file
socketFile.fileAttributesViewOrNull<PosixFileAttributeView>()?.setPermissions(fromString("rwx------"))
// Just in case...
Runtime.getRuntime().addShutdownHook(Thread(::close))
}
@ -126,11 +145,11 @@ class ExternalVerifierHandleImpl(
private fun tryVerification(request: VerificationRequest): Try<Unit> {
val connection = getConnection()
connection.toVerifier.writeCordaSerializable(request)
connection.channel.writeCordaSerializable(request)
// Send the verification request and then wait for any requests from verifier for more information. The last message will either
// be a verification success or failure message.
while (true) {
val message = connection.fromVerifier.readCordaSerializable<ExternalVerifierOutbound>()
val message = connection.channel.readCordaSerializable(ExternalVerifierOutbound::class)
log.debug { "Received from external verifier: $message" }
when (message) {
// Process the information the verifier needs and then loop back and wait for more messages
@ -153,7 +172,7 @@ class ExternalVerifierHandleImpl(
is GetTrustedClassAttachments -> TrustedClassAttachmentsResult(verificationSupport.getTrustedClassAttachments(request.className).map { it.id })
}
log.debug { "Sending response to external verifier: $result" }
connection.toVerifier.writeCordaSerializable(result)
connection.channel.writeCordaSerializable(result)
}
private fun Attachment.withTrust(): AttachmentWithTrust {
@ -168,21 +187,19 @@ class ExternalVerifierHandleImpl(
}
override fun close() {
connection?.let {
connection = null
try {
it.close()
} finally {
server.close()
}
connection?.close()
connection = null
if (::serverChannel.isInitialized) {
serverChannel.close()
}
if (::socketFile.isInitialized) {
socketFile.deleteIfExists()
}
}
private inner class Connection : AutoCloseable {
private val verifierProcess: Process
private val socket: Socket
val toVerifier: DataOutputStream
val fromVerifier: DataInputStream
val channel: SocketChannel
init {
val inheritedJvmArgs = ManagementFactory.getRuntimeMXBean().inputArguments.filter { "--add-opens" in it }
@ -192,7 +209,7 @@ class ExternalVerifierHandleImpl(
command += listOf(
"-jar",
"$verifierJar",
"${server.localPort}",
socketFile.absolutePathString(),
log.level.name.lowercase()
)
log.debug { "External verifier command: $command" }
@ -213,9 +230,7 @@ class ExternalVerifierHandleImpl(
connection = null
}
socket = server.accept()
toVerifier = DataOutputStream(socket.outputStream)
fromVerifier = DataInputStream(socket.inputStream)
channel = serverChannel.accept()
val cordapps = verificationSupport.cordappProvider.cordapps
val initialisation = Initialisation(
@ -224,12 +239,12 @@ class ExternalVerifierHandleImpl(
System.getProperty("experimental.corda.customSerializationScheme"), // See Node#initialiseSerialization
serializedCurrentNetworkParameters = verificationSupport.networkParameters.serialize()
)
toVerifier.writeCordaSerializable(initialisation)
channel.writeCordaSerializable(initialisation)
}
override fun close() {
try {
socket.close()
channel.close()
} finally {
verifierProcess.destroyForcibly()
}

View File

@ -0,0 +1,39 @@
package net.corda.serialization.internal.verifier
import net.corda.core.crypto.SecureHash
import net.corda.core.internal.concurrent.openFuture
import net.corda.serialization.internal.verifier.ExternalVerifierOutbound.VerifierRequest.GetAttachments
import net.corda.testing.core.SerializationEnvironmentRule
import org.assertj.core.api.Assertions.assertThat
import org.junit.Rule
import org.junit.Test
import java.net.InetSocketAddress
import java.nio.channels.ServerSocketChannel
import java.nio.channels.SocketChannel
import kotlin.concurrent.thread
class ExternalVerifierTypesTest {
@get:Rule
val testSerialization = SerializationEnvironmentRule()
@Test(timeout=300_000)
fun `socket channel read-write`() {
val payload = GetAttachments(setOf(SecureHash.randomSHA256(), SecureHash.randomSHA256()))
val serverChannel = ServerSocketChannel.open()
serverChannel.bind(null)
val future = openFuture<GetAttachments>()
thread {
SocketChannel.open().use {
it.connect(InetSocketAddress(serverChannel.socket().localPort))
val received = it.readCordaSerializable(GetAttachments::class)
future.set(received)
}
}
serverChannel.use { it.accept().writeCordaSerializable(payload) }
assertThat(future.get()).isEqualTo(payload)
}
}

View File

@ -8,15 +8,19 @@ import net.corda.core.identity.Party
import net.corda.core.internal.SerializedTransactionState
import net.corda.core.node.NetworkParameters
import net.corda.core.serialization.CordaSerializable
import net.corda.core.serialization.SerializationFactory
import net.corda.core.serialization.SerializedBytes
import net.corda.core.serialization.deserialize
import net.corda.core.serialization.serialize
import net.corda.core.transactions.CoreTransaction
import net.corda.core.utilities.Try
import java.io.DataInputStream
import java.io.DataOutputStream
import net.corda.core.utilities.sequence
import java.io.EOFException
import java.nio.ByteBuffer
import java.nio.channels.SocketChannel
import java.security.PublicKey
import kotlin.math.min
import kotlin.reflect.KClass
typealias SerializedNetworkParameters = SerializedBytes<NetworkParameters>
@ -71,18 +75,37 @@ sealed class ExternalVerifierOutbound {
data class VerificationResult(val result: Try<Unit>) : ExternalVerifierOutbound()
}
fun DataOutputStream.writeCordaSerializable(payload: Any) {
fun SocketChannel.writeCordaSerializable(payload: Any) {
val serialised = payload.serialize()
writeInt(serialised.size)
serialised.writeTo(this)
flush()
val buffer = ByteBuffer.allocate(DEFAULT_BUFFER_SIZE)
buffer.putInt(serialised.size)
var writtenSoFar = 0
while (writtenSoFar < serialised.size) {
val length = min(buffer.remaining(), serialised.size - writtenSoFar)
serialised.subSequence(writtenSoFar, length).putTo(buffer)
buffer.flip()
write(buffer)
writtenSoFar += length
buffer.clear()
}
}
inline fun <reified T : Any> DataInputStream.readCordaSerializable(): T {
val length = readInt()
val bytes = readNBytes(length)
if (bytes.size != length) {
throw EOFException("Incomplete read of ${T::class.java.name}")
}
return bytes.deserialize<T>()
fun <T : Any> SocketChannel.readCordaSerializable(clazz: KClass<T>): T {
val length = ByteBuffer.wrap(read(clazz, Integer.BYTES)).getInt()
val bytes = read(clazz, length)
return SerializationFactory.defaultFactory.deserialize(bytes.sequence(), clazz.java, SerializationFactory.defaultFactory.defaultContext)
}
private fun SocketChannel.read(clazz: KClass<*>, length: Int): ByteArray {
val bytes = ByteArray(length)
var readSoFar = 0
while (readSoFar < bytes.size) {
// Wrap a ByteBuffer around the byte array to read directly into it
val n = read(ByteBuffer.wrap(bytes, readSoFar, bytes.size - readSoFar))
if (n == -1) {
throw EOFException("Incomplete read of ${clazz.java.name}")
}
readSoFar += n
}
return bytes
}

View File

@ -47,9 +47,8 @@ import net.corda.serialization.internal.verifier.ExternalVerifierOutbound.Verifi
import net.corda.serialization.internal.verifier.loadCustomSerializationScheme
import net.corda.serialization.internal.verifier.readCordaSerializable
import net.corda.serialization.internal.verifier.writeCordaSerializable
import java.io.DataInputStream
import java.io.DataOutputStream
import java.net.URLClassLoader
import java.nio.channels.SocketChannel
import java.nio.file.Path
import java.security.PublicKey
import java.util.Optional
@ -57,11 +56,7 @@ import kotlin.io.path.div
import kotlin.io.path.listDirectoryEntries
@Suppress("MagicNumber")
class ExternalVerifier(
private val baseDirectory: Path,
private val fromNode: DataInputStream,
private val toNode: DataOutputStream
) {
class ExternalVerifier(private val baseDirectory: Path, private val channel: SocketChannel) {
companion object {
private val log = contextLogger()
}
@ -88,7 +83,7 @@ class ExternalVerifier(
fun run() {
initialise()
while (true) {
val request = fromNode.readCordaSerializable<VerificationRequest>()
val request = channel.readCordaSerializable(VerificationRequest::class)
log.debug { "Received $request" }
verifyTransaction(request)
}
@ -102,7 +97,7 @@ class ExternalVerifier(
))
log.info("Waiting for initialisation message from node...")
val initialisation = fromNode.readCordaSerializable<Initialisation>()
val initialisation = channel.readCordaSerializable(Initialisation::class)
log.info("Received $initialisation")
appClassLoader = createAppClassLoader()
@ -151,7 +146,7 @@ class ExternalVerifier(
log.info("${request.ctx.toSimpleString()} failed to verify", t)
Try.Failure(t)
}
toNode.writeCordaSerializable(VerificationResult(result))
channel.writeCordaSerializable(VerificationResult(result))
}
fun getParties(keys: Collection<PublicKey>): List<Party?> {
@ -195,8 +190,8 @@ class ExternalVerifier(
private inline fun <reified T : Any> request(request: Any): T {
log.debug { "Sending request to node: $request" }
toNode.writeCordaSerializable(request)
val response = fromNode.readCordaSerializable<T>()
channel.writeCordaSerializable(request)
val response = channel.readCordaSerializable(T::class)
log.debug { "Received response from node: $response" }
return response
}

View File

@ -2,9 +2,9 @@ package net.corda.verifier
import net.corda.core.utilities.loggerFor
import org.slf4j.bridge.SLF4JBridgeHandler
import java.io.DataInputStream
import java.io.DataOutputStream
import java.net.Socket
import java.net.StandardProtocolFamily
import java.net.UnixDomainSocketAddress
import java.nio.channels.SocketChannel
import java.nio.file.Path
import kotlin.io.path.div
import kotlin.system.exitProcess
@ -12,7 +12,7 @@ import kotlin.system.exitProcess
object Main {
@JvmStatic
fun main(args: Array<String>) {
val port = args[0].toInt()
val socketFile = args[0]
val loggingLevel = args[1]
val baseDirectory = Path.of("").toAbsolutePath()
@ -23,11 +23,10 @@ object Main {
log.info("Node base directory: $baseDirectory")
try {
val socket = Socket("localhost", port)
log.info("Connected to node on port $port")
val fromNode = DataInputStream(socket.getInputStream())
val toNode = DataOutputStream(socket.getOutputStream())
ExternalVerifier(baseDirectory, fromNode, toNode).run()
val channel = SocketChannel.open(StandardProtocolFamily.UNIX)
channel.connect(UnixDomainSocketAddress.of(socketFile))
log.info("Connected to node on UNIX domain file $socketFile")
ExternalVerifier(baseDirectory, channel).run()
} catch (t: Throwable) {
log.error("Unexpected error which has terminated the verifier", t)
exitProcess(1)