// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.

/// Fixed-capacity byte buffer that keeps only the most recently written bytes.
///
/// Bytes are appended through the [`std::io::Write`] impl. Once more than
/// `capacity` bytes have been written, the oldest bytes are discarded so that
/// `data()` always holds the tail of everything written so far.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct TailBuffer {
    /// Retained tail, oldest byte first. Its length never exceeds `capacity`.
    data: Vec<u8>,
    /// Maximum number of bytes to retain.
    capacity: usize,
}

impl TailBuffer {
    /// Creates an empty buffer that retains at most `capacity` bytes.
    pub fn new(capacity: usize) -> Self {
        let data = Vec::with_capacity(capacity);
        Self { data, capacity }
    }

    /// Returns the retained bytes, oldest first.
    pub fn data(&self) -> &[u8] {
        &self.data
    }

    /// Decodes the retained bytes as UTF-8, replacing invalid sequences with
    /// U+FFFD.
    pub fn to_string_lossy(&self) -> String {
        String::from_utf8_lossy(self.data()).into_owned()
    }
}

impl std::io::Write for TailBuffer {
    /// Appends `new_data`, discarding the oldest bytes as needed to stay within
    /// `capacity`. Always succeeds and reports all of `new_data` as written.
    fn write(&mut self, new_data: &[u8]) -> std::io::Result<usize> {
        if new_data.len() >= self.capacity {
            // The new data alone fills the buffer, so nothing old survives and
            // only the tail of the new data is kept.
            let start = new_data.len() - self.capacity;
            self.data.clear();
            self.data.extend_from_slice(&new_data[start..]);
        } else {
            // Make room first, so the length never exceeds `capacity` and the
            // backing allocation made in `new` is never grown.
            let overflow = (self.data.len() + new_data.len()).saturating_sub(self.capacity);
            self.data.drain(..overflow);
            self.data.extend_from_slice(new_data);
        }

        Ok(new_data.len())
    }

    /// Nothing is buffered outside `data`, so there is nothing to flush.
    fn flush(&mut self) -> std::io::Result<()> {
        Ok(())
    }
}

#[cfg(test)]
mod tests {
    use std::io::Write;

    use super::*;

    #[test]
    fn test_tail_buffer() {
        let mut buf = TailBuffer::new(5);

        assert!(buf.data().is_empty());

        buf.write(&[1, 2, 3]).unwrap();
        assert_eq!(buf.data(), &[1, 2, 3]);

        buf.write(&[]).unwrap();
        assert_eq!(buf.data(), &[1, 2, 3]);

        buf.write(&[4, 5]).unwrap();
        assert_eq!(buf.data(), &[1, 2, 3, 4, 5]);

        buf.write(&[6, 7, 8]).unwrap();
        assert_eq!(buf.data(), &[4, 5, 6, 7, 8]);

        buf.write(&[9, 10, 11, 12, 13]).unwrap();
        assert_eq!(buf.data(), &[9, 10, 11, 12, 13]);

        buf.write(&[14, 15, 16, 17, 18, 19, 20, 21, 22]).unwrap();
        assert_eq!(buf.data(), &[18, 19, 20, 21, 22]);
    }
}
b/src/agent/onefuzz-supervisor/src/main.rs
index e9c64eb53..556d437e3 100644
--- a/src/agent/onefuzz-supervisor/src/main.rs
+++ b/src/agent/onefuzz-supervisor/src/main.rs
@@ -29,6 +29,7 @@ use structopt::StructOpt;
 
 pub mod agent;
 pub mod auth;
+pub mod buffer;
 pub mod commands;
 pub mod config;
 pub mod coordinator;
diff --git a/src/agent/onefuzz-supervisor/src/worker.rs b/src/agent/onefuzz-supervisor/src/worker.rs
index 1d5bad9f9..3d673e673 100644
--- a/src/agent/onefuzz-supervisor/src/worker.rs
+++ b/src/agent/onefuzz-supervisor/src/worker.rs
@@ -1,18 +1,22 @@
 // Copyright (c) Microsoft Corporation.
 // Licensed under the MIT License.
-
 use std::{
     path::{Path, PathBuf},
-    process::{Child, Command, Stdio},
+    process::{Child, ChildStderr, ChildStdout, Command, Stdio},
+    thread::{self, JoinHandle},
 };
 
-use anyhow::{Context as AnyhowContext, Result};
+use anyhow::{format_err, Context as AnyhowContext, Result};
 use downcast_rs::Downcast;
 use onefuzz::process::{ExitStatus, Output};
 use tokio::fs;
 
+use crate::buffer::TailBuffer;
 use crate::work::*;
 
+// Max length of captured output streams from worker child processes.
+const MAX_TAIL_LEN: usize = 4096;
+
 #[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)]
 #[serde(rename_all = "snake_case")]
 pub enum WorkerEvent {
@@ -231,24 +235,130 @@ impl IWorkerRunner for WorkerRunner {
         cmd.stderr(Stdio::piped());
         cmd.stdout(Stdio::piped());
 
-        let child = cmd.spawn().context("onefuzz-agent failed to start")?;
-        let child = Box::new(child);
-
-        Ok(child)
+        Ok(Box::new(RedirectedChild::spawn(cmd)?))
     }
 }
 
-impl IWorkerChild for Child {
+/// Child process with redirected output streams, tailed by two worker threads.
+struct RedirectedChild {
+    /// The child process.
+    child: Child,
+
+    /// Worker threads which continuously read from the redirected streams.
+    ///
+    /// Taken (left as `None`) by the first `try_wait` call that observes the
+    /// child's exit, since joining the threads consumes them.
+    streams: Option<StreamReaderThreads>,
+}
+
+impl RedirectedChild {
+    /// Spawns `cmd` with piped stderr and stdout, and starts tailing both.
+    ///
+    /// Returns an error if the process cannot be started.
+    pub fn spawn(mut cmd: Command) -> Result<Self> {
+        // Make sure we capture the child's output streams.
+        cmd.stderr(Stdio::piped());
+        cmd.stdout(Stdio::piped());
+
+        let mut child = cmd.spawn().context("onefuzz-agent failed to start")?;
+
+        // Both handles are present because both streams were piped above.
+        let stderr = child.stderr.take().expect("child stderr was piped");
+        let stdout = child.stdout.take().expect("child stdout was piped");
+        let streams = Some(StreamReaderThreads::new(stderr, stdout));
+
+        Ok(Self { child, streams })
+    }
+}
+
+/// Worker threads that tail the redirected output streams of a running child process.
+struct StreamReaderThreads {
+    stderr: JoinHandle<TailBuffer>,
+    stdout: JoinHandle<TailBuffer>,
+}
+
+/// Lossily decoded tails of a child's output streams.
+struct CapturedStreams {
+    stderr: String,
+    stdout: String,
+}
+
+impl StreamReaderThreads {
+    /// Starts one tailing thread per stream.
+    pub fn new(stderr: ChildStderr, stdout: ChildStdout) -> Self {
+        Self {
+            stderr: Self::spawn_tail(stderr),
+            stdout: Self::spawn_tail(stdout),
+        }
+    }
+
+    /// Spawns a thread that reads `stream` until end-of-stream or a read error,
+    /// and yields the last `MAX_TAIL_LEN` bytes it saw.
+    fn spawn_tail<R>(mut stream: R) -> JoinHandle<TailBuffer>
+    where
+        R: std::io::Read + Send + 'static,
+    {
+        use std::io::{ErrorKind, Write};
+
+        thread::spawn(move || {
+            let mut buf = TailBuffer::new(MAX_TAIL_LEN);
+            let mut tmp = [0u8; MAX_TAIL_LEN];
+
+            loop {
+                match stream.read(&mut tmp) {
+                    // End of stream.
+                    Ok(0) => break,
+                    Ok(count) => {
+                        if let Err(err) = buf.write_all(&tmp[..count]) {
+                            log::error!("error copying to circular buffer: {}", err);
+                            break;
+                        }
+                    }
+                    // An interrupted read transferred no data, so retry it.
+                    Err(err) if err.kind() == ErrorKind::Interrupted => continue,
+                    Err(err) => {
+                        log::error!("error reading redirected stream: {}", err);
+                        break;
+                    }
+                }
+            }
+
+            buf
+        })
+    }
+
+    /// Waits for both threads to finish and decodes their tails.
+    ///
+    /// Returns an error if either thread panicked.
+    pub fn join(self) -> Result<CapturedStreams> {
+        let stderr = self
+            .stderr
+            .join()
+            .map_err(|_| format_err!("stderr tail thread panicked"))?
+            .to_string_lossy();
+        let stdout = self
+            .stdout
+            .join()
+            .map_err(|_| format_err!("stdout tail thread panicked"))?
+            .to_string_lossy();
+
+        Ok(CapturedStreams { stderr, stdout })
+    }
+}
+
+impl IWorkerChild for RedirectedChild {
     fn try_wait(&mut self) -> Result<Option<Output>> {
-        let output = if let Some(exit_status) = self.try_wait()? {
+        let output = if let Some(exit_status) = self.child.try_wait()? {
             let exit_status = exit_status.into();
-            let stderr = read_to_string(&mut self.stderr)?;
-            let stdout = read_to_string(&mut self.stdout)?;
+            let streams = self.streams.take();
+            let streams = streams
+                .ok_or_else(|| format_err!("onefuzz-agent streams not captured"))?
+                .join()?;
 
             Some(Output {
                 exit_status,
-                stderr,
-                stdout,
+                stderr: streams.stderr,
+                stdout: streams.stdout,
             })
         } else {
             None
@@ -260,7 +370,7 @@ impl IWorkerChild for Child {
     fn kill(&mut self) -> Result<()> {
         use std::io::ErrorKind;
 
-        let killed = self.kill();
+        let killed = self.child.kill();
 
         if let Err(err) = &killed {
             if let ErrorKind::InvalidInput = err.kind() {
@@ -273,15 +383,6 @@ impl IWorkerChild for Child {
     }
 }
 
-fn read_to_string(stream: &mut Option<impl std::io::Read>) -> Result<String> {
-    let mut data = Vec::new();
-    if let Some(stream) = stream {
-        stream.read_to_end(&mut data)?;
-    }
-
-    Ok(String::from_utf8_lossy(&data).into_owned())
-}
-
 #[cfg(test)]
 pub mod double;
 
diff --git a/src/agent/onefuzz-supervisor/src/worker/tests.rs b/src/agent/onefuzz-supervisor/src/worker/tests.rs
index f22767095..0e168ba20 100644
--- a/src/agent/onefuzz-supervisor/src/worker/tests.rs
+++ b/src/agent/onefuzz-supervisor/src/worker/tests.rs
@@ -226,3 +226,65 @@ async fn test_worker_done() {
     assert!(matches!(worker, Worker::Done(..)));
     assert_eq!(events, vec![]);
 }
+
+#[cfg(target_os = "linux")]
+#[test]
+fn test_redirected_child() {
+    use std::iter::repeat;
+    use std::process::Command;
+
+    // Assume OS pipe capacity of 16 pages, each 4096 bytes.
+    //
+    // For each stream,
+    //
+    // 1. Write enough of one char to fill up the OS pipe.
+    // 2. Write a smaller count (< tail size) of another char to overflow it.
+    //
+    // Our tailing buffer has size 4096, so we will expect to see all of the
+    // bytes from the second char, and the remainder from the first char.
+    let script = "import sys;\
+sys.stdout.write('A' * 65536 + 'B' * 4000);\
+sys.stderr.write('C' * 65536 + 'D' * 4000)";
+
+    let mut cmd = Command::new("python3");
+    cmd.args(&["-c", script]);
+
+    let mut redirected = RedirectedChild::spawn(cmd).unwrap();
+    redirected.child.wait().unwrap();
+    let captured = redirected.streams.unwrap().join().unwrap();
+
+    let stdout: String = repeat('A').take(96).chain(repeat('B').take(4000)).collect();
+    assert_eq!(captured.stdout, stdout);
+
+    let stderr: String = repeat('C').take(96).chain(repeat('D').take(4000)).collect();
+    assert_eq!(captured.stderr, stderr);
+}
+
+#[cfg(target_os = "windows")]
+#[test]
+fn test_redirected_child() {
+    use std::iter::repeat;
+    use std::process::Command;
+
+    // Only write to stdout.
+    let script = "Write-Output ('A' * 65536 + 'B' * 4000)";
+
+    let mut cmd = Command::new("powershell.exe");
+    cmd.args(&[
+        "-NonInteractive",
+        "-ExecutionPolicy",
+        "Unrestricted",
+        "-Command",
+        script,
+    ]);
+
+    let mut redirected = RedirectedChild::spawn(cmd).unwrap();
+    redirected.child.wait().unwrap();
+    let captured = redirected.streams.unwrap().join().unwrap();
+
+    let mut stdout: String = repeat('A').take(94).chain(repeat('B').take(4000)).collect();
+    stdout.push_str("\r\n");
+    assert_eq!(captured.stdout, stdout);
+
+    assert_eq!(captured.stderr, "");
+}