mirror of
https://github.com/zerotier/ZeroTierOne.git
synced 2024-12-23 06:42:23 +00:00
173 lines
5.4 KiB
Rust
173 lines
5.4 KiB
Rust
|
use futures::join;
|
||
|
use lazy_static::lazy_static;
|
||
|
use native_tls::{Certificate, Identity};
|
||
|
use std::{fs, io::Error, path::PathBuf, process::Command};
|
||
|
use tokio::{
|
||
|
io::{AsyncReadExt, AsyncWrite, AsyncWriteExt},
|
||
|
net::{TcpListener, TcpStream},
|
||
|
};
|
||
|
use tokio_native_tls::{TlsAcceptor, TlsConnector};
|
||
|
|
||
|
lazy_static! {
    /// Directory from which `context()` loads `identity.p12` and `root-ca.der`.
    ///
    /// On Unix a fresh temporary directory is created and
    /// `./scripts/generate-certificate.sh` is invoked with that directory as its
    /// argument; on other platforms the `tests` directory is used as-is.
    ///
    /// # Panics
    ///
    /// Panics if the temporary directory cannot be created, the script cannot be
    /// started, or the script exits with a non-zero status.
    static ref CERT_DIR: PathBuf = {
        if cfg!(unix) {
            let dir = tempfile::TempDir::new().unwrap();

            // Pass the directory as a positional parameter ("$1") instead of
            // splicing it into the command string, so a path containing spaces or
            // shell metacharacters reaches the script verbatim. This also avoids
            // requiring the path to be valid UTF-8.
            let output = Command::new("sh")
                .arg("-c")
                .arg(r#"./scripts/generate-certificate.sh "$1""#)
                .arg("sh")
                .arg(dir.path())
                .output()
                .expect("failed to execute process");

            // Fail here with the script's own diagnostics rather than later with
            // an unrelated-looking "file not found" when the fixtures are read.
            assert!(
                output.status.success(),
                "generate-certificate.sh failed ({}): {}",
                output.status,
                String::from_utf8_lossy(&output.stderr)
            );

            // Persist the directory so it outlives the `TempDir` guard.
            dir.into_path()
        } else {
            PathBuf::from("tests")
        }
    };
}
|
||
|
|
||
|
#[tokio::test]
|
||
|
async fn client_to_server() {
|
||
|
let srv = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||
|
let addr = srv.local_addr().unwrap();
|
||
|
|
||
|
let (server_tls, client_tls) = context();
|
||
|
|
||
|
// Create a future to accept one socket, connect the ssl stream, and then
|
||
|
// read all the data from it.
|
||
|
let server = async move {
|
||
|
let (socket, _) = srv.accept().await.unwrap();
|
||
|
let mut socket = server_tls.accept(socket).await.unwrap();
|
||
|
|
||
|
// Verify access to all of the nested inner streams (e.g. so that peer
|
||
|
// certificates can be accessed). This is just a compile check.
|
||
|
let native_tls_stream: &native_tls::TlsStream<_> = socket.get_ref();
|
||
|
let _peer_cert = native_tls_stream.peer_certificate().unwrap();
|
||
|
let allow_std_stream: &tokio_native_tls::AllowStd<_> = native_tls_stream.get_ref();
|
||
|
let _tokio_tcp_stream: &tokio::net::TcpStream = allow_std_stream.get_ref();
|
||
|
|
||
|
let mut data = Vec::new();
|
||
|
socket.read_to_end(&mut data).await.unwrap();
|
||
|
data
|
||
|
};
|
||
|
|
||
|
// Create a future to connect to our server, connect the ssl stream, and
|
||
|
// then write a bunch of data to it.
|
||
|
let client = async move {
|
||
|
let socket = TcpStream::connect(&addr).await.unwrap();
|
||
|
let socket = client_tls.connect("foobar.com", socket).await.unwrap();
|
||
|
copy_data(socket).await
|
||
|
};
|
||
|
|
||
|
// Finally, run everything!
|
||
|
let (data, _) = join!(server, client);
|
||
|
// assert_eq!(amt, AMT);
|
||
|
assert!(data == vec![9; AMT]);
|
||
|
}
|
||
|
|
||
|
#[tokio::test]
|
||
|
async fn server_to_client() {
|
||
|
// Create a server listening on a port, then figure out what that port is
|
||
|
let srv = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||
|
let addr = srv.local_addr().unwrap();
|
||
|
|
||
|
let (server_tls, client_tls) = context();
|
||
|
|
||
|
let server = async move {
|
||
|
let (socket, _) = srv.accept().await.unwrap();
|
||
|
let socket = server_tls.accept(socket).await.unwrap();
|
||
|
copy_data(socket).await
|
||
|
};
|
||
|
|
||
|
let client = async move {
|
||
|
let socket = TcpStream::connect(&addr).await.unwrap();
|
||
|
let mut socket = client_tls.connect("foobar.com", socket).await.unwrap();
|
||
|
let mut data = Vec::new();
|
||
|
socket.read_to_end(&mut data).await.unwrap();
|
||
|
data
|
||
|
};
|
||
|
|
||
|
// Finally, run everything!
|
||
|
let (_, data) = join!(server, client);
|
||
|
assert!(data == vec![9; AMT]);
|
||
|
}
|
||
|
|
||
|
#[tokio::test]
|
||
|
async fn one_byte_at_a_time() {
|
||
|
const AMT: usize = 1024;
|
||
|
|
||
|
let srv = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||
|
let addr = srv.local_addr().unwrap();
|
||
|
|
||
|
let (server_tls, client_tls) = context();
|
||
|
|
||
|
let server = async move {
|
||
|
let (socket, _) = srv.accept().await.unwrap();
|
||
|
let mut socket = server_tls.accept(socket).await.unwrap();
|
||
|
let mut amt = 0;
|
||
|
for b in std::iter::repeat(9).take(AMT) {
|
||
|
let data = [b as u8];
|
||
|
socket.write_all(&data).await.unwrap();
|
||
|
amt += 1;
|
||
|
}
|
||
|
amt
|
||
|
};
|
||
|
|
||
|
let client = async move {
|
||
|
let socket = TcpStream::connect(&addr).await.unwrap();
|
||
|
let mut socket = client_tls.connect("foobar.com", socket).await.unwrap();
|
||
|
let mut data = Vec::new();
|
||
|
loop {
|
||
|
let mut buf = [0; 1];
|
||
|
match socket.read_exact(&mut buf).await {
|
||
|
Ok(_) => data.extend_from_slice(&buf),
|
||
|
Err(ref err) if err.kind() == std::io::ErrorKind::UnexpectedEof => break,
|
||
|
Err(err) => panic!(err),
|
||
|
}
|
||
|
}
|
||
|
data
|
||
|
};
|
||
|
|
||
|
let (amt, data) = join!(server, client);
|
||
|
assert_eq!(amt, AMT);
|
||
|
assert!(data == vec![9; AMT as usize]);
|
||
|
}
|
||
|
|
||
|
fn context() -> (TlsAcceptor, TlsConnector) {
|
||
|
let pkcs12 = fs::read(CERT_DIR.join("identity.p12")).unwrap();
|
||
|
let der = fs::read(CERT_DIR.join("root-ca.der")).unwrap();
|
||
|
|
||
|
let identity = Identity::from_pkcs12(&pkcs12, "mypass").unwrap();
|
||
|
let acceptor = native_tls::TlsAcceptor::builder(identity).build().unwrap();
|
||
|
|
||
|
let cert = Certificate::from_der(&der).unwrap();
|
||
|
let connector = native_tls::TlsConnector::builder()
|
||
|
.add_root_certificate(cert)
|
||
|
.build()
|
||
|
.unwrap();
|
||
|
|
||
|
(acceptor.into(), connector.into())
|
||
|
}
|
||
|
|
||
|
/// Number of bytes `copy_data` writes, and the payload size the bulk-transfer
/// tests expect to receive (128 KiB).
const AMT: usize = 128 * 1024;
|
||
|
|
||
|
async fn copy_data<W: AsyncWrite + Unpin>(mut w: W) -> Result<usize, Error> {
|
||
|
let mut data = vec![9; AMT as usize];
|
||
|
let mut amt = 0;
|
||
|
while !data.is_empty() {
|
||
|
let written = w.write(&data).await?;
|
||
|
if written <= data.len() {
|
||
|
amt += written;
|
||
|
data.resize(data.len() - written, 0);
|
||
|
} else {
|
||
|
w.write_all(&data).await?;
|
||
|
amt += data.len();
|
||
|
break;
|
||
|
}
|
||
|
|
||
|
println!("remaining: {}", data.len());
|
||
|
}
|
||
|
Ok(amt)
|
||
|
}
|