diff --git a/Cargo.lock b/Cargo.lock index cb99b7f8..662c9483 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -173,6 +173,12 @@ dependencies = [ "num-traits", ] +[[package]] +name = "arc-swap" +version = "1.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "69f7f8c3906b62b754cd5326047894316021dcfe5a194c8ea52bdd94934a3457" + [[package]] name = "ascii-canvas" version = "3.0.0" @@ -524,18 +530,19 @@ dependencies = [ [[package]] name = "axum" -version = "0.6.20" +version = "0.7.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3b829e4e32b91e643de6eafe82b1d90675f5874230191a4ffbc1b336dec4d6bf" +checksum = "3a6c9af12842a67734c9a2e355436e5d03b22383ed60cf13cd0c18fbfe3dcbcf" dependencies = [ "async-trait", "axum-core", - "bitflags 1.3.2", "bytes", "futures-util", - "http 0.2.12", - "http-body 0.4.6", - "hyper 0.14.30", + "http 1.1.0", + "http-body 1.0.1", + "http-body-util", + "hyper 1.4.1", + "hyper-util", "itoa", "matchit", "memchr", @@ -544,27 +551,60 @@ dependencies = [ "pin-project-lite", "rustversion", "serde", - "sync_wrapper 0.1.2", + "serde_json", + "serde_path_to_error", + "serde_urlencoded", + "sync_wrapper 1.0.1", + "tokio", "tower", "tower-layer", "tower-service", + "tracing", ] [[package]] name = "axum-core" -version = "0.3.4" +version = "0.4.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "759fa577a247914fd3f7f76d62972792636412fbfd634cd452f6a385a74d2d2c" +checksum = "a15c63fd72d41492dc4f497196f5da1fb04fb7529e631d73630d1b491e47a2e3" dependencies = [ "async-trait", "bytes", "futures-util", - "http 0.2.12", - "http-body 0.4.6", + "http 1.1.0", + "http-body 1.0.1", + "http-body-util", "mime", + "pin-project-lite", "rustversion", + "sync_wrapper 0.1.2", "tower-layer", "tower-service", + "tracing", +] + +[[package]] +name = "axum-server" +version = "0.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "56bac90848f6a9393ac03c63c640925c4b7c8ca21654de40d53f55964667c7d8" +dependencies = [ + "arc-swap", + "bytes", + "futures-util", + "http 1.1.0", + "http-body 1.0.1", + "http-body-util", + "hyper 1.4.1", + "hyper-util", + "pin-project-lite", + "rustls 0.23.12", + "rustls-pemfile", + "rustls-pki-types", + "tokio", + "tokio-rustls 0.26.0", + "tower", + "tower-service", ] [[package]] @@ -779,6 +819,8 @@ dependencies = [ "async-recursion", "async-trait", "aws-sign-v4", + "axum", + "axum-server", "backend", "base64 0.22.1", "bigdecimal", @@ -800,11 +842,10 @@ dependencies = [ "handlebars", "hex", "hmac", - "http 0.2.12", - "http-body 0.4.6", + "http 1.1.0", + "http-body 1.0.1", "httpmock", "humantime-serde", - "hyper 0.14.30", "jsonwebtoken", "lapin", "lazy_static", @@ -857,7 +898,6 @@ dependencies = [ "tracing-subscriber", "urlencoding", "uuid", - "warp", "x509-parser", ] @@ -960,7 +1000,7 @@ version = "4.5.8" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2bac35c6dafb060fd4d275d9a4ffae97917c13a6327903a8be2153cd964f7085" dependencies = [ - "heck 0.5.0", + "heck", "proc-macro2", "quote", "syn 2.0.72", @@ -1380,7 +1420,7 @@ checksum = "c5d9abe6314103864cc2d8901b7ae224e0ab1a103a0a416661b4097b0779b607" dependencies = [ "darling", "either", - "heck 0.5.0", + "heck", "proc-macro2", "quote", "syn 2.0.72", @@ -1837,25 +1877,6 @@ dependencies = [ "subtle", ] -[[package]] -name = "h2" -version = "0.3.26" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "81fe527a889e1532da5c525686d96d4c2e74cdd345badf8dfef9f6b39dd5f5e8" -dependencies = [ - "bytes", - "fnv", - "futures-core", - "futures-sink", - "futures-util", - "http 0.2.12", - "indexmap 2.2.6", - "slab", - "tokio", - "tokio-util", - "tracing", -] - [[package]] name = "h2" version = "0.4.5" @@ -1901,36 +1922,6 @@ version = "0.14.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1" -[[package]] -name = "headers" -version = "0.3.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "06683b93020a07e3dbcf5f8c0f6d40080d725bea7936fc01ad345c01b97dc270" -dependencies = [ - "base64 0.21.7", - "bytes", - "headers-core", - "http 0.2.12", - "httpdate", - "mime", - "sha1", -] - -[[package]] -name = "headers-core" -version = "0.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e7f66481bfee273957b1f20485a4ff3362987f85b2c236580d81b4eb7a326429" -dependencies = [ - "http 0.2.12", -] - -[[package]] -name = "heck" -version = "0.4.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8" - [[package]] name = "heck" version = "0.5.0" @@ -2038,12 +2029,6 @@ dependencies = [ "pin-project-lite", ] -[[package]] -name = "http-range-header" -version = "0.3.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "add0ab9360ddbd88cfeb3bd9574a1d85cfdfa14db10b3e21d3700dbc4328758f" - [[package]] name = "httparse" version = "1.9.4" @@ -2110,7 +2095,6 @@ dependencies = [ "futures-channel", "futures-core", "futures-util", - "h2 0.3.26", "http 0.2.12", "http-body 0.4.6", "httparse", @@ -2133,10 +2117,11 @@ dependencies = [ "bytes", "futures-channel", "futures-util", - "h2 0.4.5", + "h2", "http 1.1.0", "http-body 1.0.1", "httparse", + "httpdate", "itoa", "pin-project-lite", "smallvec", @@ -2165,14 +2150,15 @@ dependencies = [ [[package]] name = "hyper-timeout" -version = "0.4.1" +version = "0.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bbb958482e8c7be4bc3cf272a766a2b0bf1a6755e7a6ae777f017a31d11b13b1" +checksum = "3203a961e5c83b6f5498933e78b6b263e208c197b63e9c6c53cc82ffd3f63793" dependencies = [ - "hyper 0.14.30", + "hyper 1.4.1", + "hyper-util", "pin-project-lite", "tokio", - "tokio-io-timeout", + "tower-service", ] [[package]] @@ -2324,6 +2310,15 @@ dependencies = [ "either", ] +[[package]] +name = "itertools" +version = "0.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "413ee7dfc52ee1a4949ceeb7dbc8a33f2d6c088194d9f922fb8318faf1f01186" +dependencies = [ + "either", +] + [[package]] name = "itoa" version = "1.0.11" @@ -2898,9 +2893,9 @@ dependencies = [ [[package]] name = "pbjson" -version = "0.6.0" +version = "0.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1030c719b0ec2a2d25a5df729d6cff1acf3cc230bf766f4f97833591f7577b90" +checksum = "c7e6349fa080353f4a597daffd05cb81572a9c031a6d4fff7e504947496fcc68" dependencies = [ "base64 0.21.7", "serde", @@ -2908,21 +2903,21 @@ dependencies = [ [[package]] name = "pbjson-build" -version = "0.6.2" +version = "0.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2580e33f2292d34be285c5bc3dba5259542b083cfad6037b6d70345f24dcb735" +checksum = "6eea3058763d6e656105d1403cb04e0a41b7bbac6362d413e7c33be0c32279c9" dependencies = [ - "heck 0.4.1", - "itertools 0.11.0", + "heck", + "itertools 0.13.0", "prost", "prost-types", ] [[package]] name = "pbjson-types" -version = "0.6.0" +version = "0.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "18f596653ba4ac51bdecbb4ef6773bc7f56042dc13927910de1684ad3d32aa12" +checksum = "e54e5e7bfb1652f95bc361d76f3c780d8e526b134b85417e774166ee941f0887" dependencies = [ "bytes", "chrono", @@ -3305,9 +3300,9 @@ dependencies = [ [[package]] name = "prost" -version = "0.12.6" +version = "0.13.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "deb1435c188b76130da55f17a466d252ff7b1418b2ad3e037d127b94e3411f29" +checksum = "e13db3d3fde688c61e2446b4d843bc27a7e8af269a69440c0308021dc92333cc" dependencies = [ "bytes", "prost-derive", @@ -3315,13 +3310,13 @@ dependencies = [ [[package]] name = "prost-build" -version = "0.12.6" +version = "0.13.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "22505a5c94da8e3b7c2996394d1c933236c4d743e81a410bcca4e6989fc066a4" +checksum = "5bb182580f71dd070f88d01ce3de9f4da5021db7115d2e1c3605a754153b77c1" dependencies = [ "bytes", - "heck 0.5.0", - "itertools 0.12.1", + "heck", + "itertools 0.13.0", "log", "multimap", "once_cell", @@ -3336,12 +3331,12 @@ dependencies = [ [[package]] name = "prost-derive" -version = "0.12.6" +version = "0.13.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "81bddcdb20abf9501610992b6759a4c888aef7d1a7247ef75e2404275ac24af1" +checksum = "18bec9b0adc4eba778b33684b7ba3e7137789434769ee3ce3930463ef904cfca" dependencies = [ "anyhow", - "itertools 0.12.1", + "itertools 0.13.0", "proc-macro2", "quote", "syn 2.0.72", @@ -3349,9 +3344,9 @@ dependencies = [ [[package]] name = "prost-types" -version = "0.12.6" +version = "0.13.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9091c90b0a32608e984ff2fa4091273cbdd755d54935c51d520887f4a1dbd5b0" +checksum = "cee5168b05f49d4b0ca581206eb14a7b22fafd963efe729ac48eb03266e25cc2" dependencies = [ "prost", ] @@ -3946,12 +3941,6 @@ dependencies = [ "pin-utils", ] -[[package]] -name = "scoped-tls" -version = "1.0.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e1cf6437eb19a8f4a6cc0f7dca544973b0b78843adbfeb3683d1a94a0024a294" - [[package]] name = "scopeguard" version = "1.2.0" @@ -4135,17 +4124,6 @@ dependencies = [ "unsafe-libyaml", ] -[[package]] -name = "sha1" -version = "0.10.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e3bf829a2d51ab4a5ddf1352d8470c140cadc8301b2ae1789db023f01cedd6ba" -dependencies = [ - "cfg-if", - "cpufeatures", - "digest", -] - [[package]] name = "sha1_smol" version = "1.0.1" @@ -4514,16 +4492,6 @@ dependencies = [ "tokio", ] -[[package]] -name = "tokio-io-timeout" -version = "1.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "30b74022ada614a1b4834de765f9bb43877f910cc8ce4be40e89042c9223a8bf" -dependencies = [ - "pin-project-lite", - "tokio", -] - [[package]] name = "tokio-macros" version = "2.4.0" @@ -4682,23 +4650,26 @@ dependencies = [ [[package]] name = "tonic" -version = "0.11.0" +version = "0.12.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "76c4eb7a4e9ef9d4763600161f12f5070b92a578e1b634db88a6887844c91a13" +checksum = "38659f4a91aba8598d27821589f5db7dddd94601e7a01b1e485a50e5484c7401" dependencies = [ "async-stream", "async-trait", "axum", - "base64 0.21.7", + "base64 0.22.1", "bytes", - "h2 0.3.26", - "http 0.2.12", - "http-body 0.4.6", - "hyper 0.14.30", + "h2", + "http 1.1.0", + "http-body 1.0.1", + "http-body-util", + "hyper 1.4.1", "hyper-timeout", + "hyper-util", "percent-encoding", "pin-project", "prost", + "socket2 0.5.7", "tokio", "tokio-stream", "tower", @@ -4709,9 +4680,9 @@ dependencies = [ [[package]] name = "tonic-build" -version = "0.11.0" +version = "0.12.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "be4ef6dd70a610078cb4e338a0f79d06bc759ff1b22d2120c2ff02ae264ba9c2" +checksum = "568392c5a2bd0020723e3f387891176aabafe36fd9fcd074ad309dfa0c8eb964" dependencies = [ "prettyplease", "proc-macro2", @@ -4722,9 +4693,9 @@ dependencies = [ [[package]] name = "tonic-reflection" -version = "0.11.0" +version = "0.12.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "548c227bd5c0fae5925812c4ec6c66ffcfced23ea370cb823f4d18f0fc1cb6a7" +checksum = "b742c83ad673e9ab5b4ce0981f7b9e8932be9d60e8682cbf9120494764dbc173" dependencies = [ "prost", "prost-types", @@ -4735,15 +4706,15 @@ dependencies = [ [[package]] name = "tonic-web" -version = "0.11.0" +version = "0.12.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dc3b0e1cedbf19fdfb78ef3d672cb9928e0a91a9cb4629cc0c916e8cff8aaaa1" +checksum = "8dc0e36ac436560b9a8c9edad4521cf5dd5deb1af591936db9660191b6ecf619" dependencies = [ - "base64 0.21.7", + "base64 0.22.1", "bytes", - "http 0.2.12", - "http-body 0.4.6", - "hyper 0.14.30", + "http 1.1.0", + "http-body 1.0.1", + "http-body-util", "pin-project", "tokio-stream", "tonic", @@ -4775,18 +4746,16 @@ dependencies = [ [[package]] name = "tower-http" -version = "0.4.4" +version = "0.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "61c5bb1d698276a2443e5ecfabc1008bf15a36c12e6a7176e7bf089ea9131140" +checksum = "1e9cd434a998747dd2c4276bc96ee2e0c7a2eadf3cae88e52be55a05fa9053f5" dependencies = [ "base64 0.21.7", "bitflags 2.6.0", "bytes", - "futures-core", - "futures-util", - "http 0.2.12", - "http-body 0.4.6", - "http-range-header", + "http 1.1.0", + "http-body 1.0.1", + "http-body-util", "mime", "pin-project-lite", "tower-layer", @@ -5042,35 +5011,6 @@ dependencies = [ "try-lock", ] -[[package]] -name = "warp" -version = "0.3.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4378d202ff965b011c64817db11d5829506d3404edeadb61f190d111da3f231c" -dependencies = [ - "bytes", - "futures-channel", - "futures-util", - "headers", - "http 0.2.12", - "hyper 0.14.30", - "log", - "mime", - "mime_guess", - "percent-encoding", - "pin-project", - "rustls-pemfile", - "scoped-tls", - "serde", - "serde_json", - "serde_urlencoded", - "tokio", - "tokio-rustls 0.25.0", - "tokio-util", - "tower-service", - "tracing", -] - [[package]] name = "wasi" version = "0.11.0+wasi-snapshot-preview1" diff --git a/api/rust/Cargo.toml b/api/rust/Cargo.toml index e16333a4..49634087 100644 --- a/api/rust/Cargo.toml +++ b/api/rust/Cargo.toml @@ -16,23 +16,23 @@ internal = [] [dependencies] - prost = "0.12" - prost-types = "0.12" + prost = "0.13" + prost-types = "0.13" hex = "0.4" rand = "0.8" - tonic = { version = "0.11", features = [ + tonic = { version = "0.12", features = [ "codegen", "prost", ], default-features = false, optional = true } tokio = { version = "1.38", features = ["macros"], optional = true } - pbjson = { version = "0.6", optional = true } - pbjson-types = { version = "0.6", optional = true } + pbjson = { version = "0.7", optional = true } + pbjson-types = { version = "0.7", optional = true } serde = { version = "1.0", optional = true } diesel = { version = "2.2", features = ["postgres_backend"], optional = true } [build-dependencies] - tonic-build = { version = "0.11", features = [ + tonic-build = { version = "0.12", features = [ "prost", ], default-features = false } - pbjson-build = "0.6" + pbjson-build = "0.7" diff --git a/chirpstack/Cargo.toml b/chirpstack/Cargo.toml index bf3ab6d8..5fae4b4b 100644 --- a/chirpstack/Cargo.toml +++ b/chirpstack/Cargo.toml @@ -88,26 +88,26 @@ ] } # gRPC and Protobuf - tonic = "0.11" - tonic-web = "0.11" - tonic-reflection = "0.11" + tonic = "0.12" + tonic-web = "0.12" + tonic-reflection = "0.12" tokio = { version = "1.38", features = ["macros", "rt-multi-thread"] } tokio-stream = "0.1" - prost-types = "0.12" - prost = "0.12" - pbjson-types = "0.6" + prost-types = "0.13" + prost = "0.13" + pbjson-types = "0.7" # gRPC and HTTP multiplexing - warp = { version = "0.3", features = ["tls"], default-features = false } - hyper = "0.14" - tower = "0.4" + axum = "0.7" + axum-server = { version = "0.7.1", features = ["tls-rustls-no-provider"] } + tower = { version = "0.4" } futures = "0.3" futures-util = "0.3" - http = "0.2" - http-body = "0.4" + http = "1.1" + http-body = "1.0" rust-embed = "8.5" mime_guess = "2.0" - tower-http = { version = "0.4", features = ["trace", "auth"] } + tower-http = { version = "0.5", features = ["trace", "auth"] } # Error handling thiserror = "1.0" diff --git a/chirpstack/src/api/auth/mod.rs b/chirpstack/src/api/auth/mod.rs index 427c6035..f3258f2a 100644 --- a/chirpstack/src/api/auth/mod.rs +++ b/chirpstack/src/api/auth/mod.rs @@ -6,7 +6,7 @@ pub mod claims; pub mod error; pub mod validator; -#[derive(PartialEq, Eq, Debug)] +#[derive(PartialEq, Eq, Debug, Clone)] pub enum AuthID { None, User(Uuid), diff --git a/chirpstack/src/api/backend/mod.rs b/chirpstack/src/api/backend/mod.rs index 28d46272..9748cb72 100644 --- a/chirpstack/src/api/backend/mod.rs +++ b/chirpstack/src/api/backend/mod.rs @@ -5,18 +5,28 @@ use std::sync::Arc; use std::time::Duration; use anyhow::Result; +use axum::{ + body::Bytes, + response::{IntoResponse, Json, Response}, + Router, +}; use chrono::Utc; +use http::StatusCode; use redis::streams::StreamReadReply; +use rustls::{ + server::{NoClientAuth, WebPkiClientVerifier}, + ServerConfig, +}; use serde::Serialize; use tokio::sync::oneshot; use tokio::task; use tracing::{error, info, span, warn, Instrument, Level}; use uuid::Uuid; -use warp::{http::StatusCode, Filter, Reply}; use crate::backend::{joinserver, keywrap, roaming}; use crate::downlink::data_fns; use crate::helpers::errors::PrintFullError; +use crate::helpers::tls::{get_root_certs, load_cert, load_key}; use crate::storage::{ device, error::Error as StorageError, get_async_redis_conn, passive_roaming, redis_key, }; @@ -39,47 +49,47 @@ pub async fn setup() -> Result<()> { let addr: SocketAddr = conf.backend_interfaces.bind.parse()?; info!(bind = %conf.backend_interfaces.bind, "Setting up backend interfaces API"); - let routes = warp::post() - .and(warp::body::content_length_limit(1024 * 16)) - .and(warp::body::aggregate()) - .then(handle_request); + let app = Router::new().fallback(handle_request); if !conf.backend_interfaces.ca_cert.is_empty() || !conf.backend_interfaces.tls_cert.is_empty() || !conf.backend_interfaces.tls_key.is_empty() { - let mut w = warp::serve(routes).tls(); - if !conf.backend_interfaces.ca_cert.is_empty() { - w = w.client_auth_required_path(&conf.backend_interfaces.ca_cert); - } - if !conf.backend_interfaces.tls_cert.is_empty() { - w = w.cert_path(&conf.backend_interfaces.tls_cert); - } - if !conf.backend_interfaces.tls_key.is_empty() { - w = w.key_path(&conf.backend_interfaces.tls_key); - } - w.run(addr).await; + let mut server_config = ServerConfig::builder() + .with_client_cert_verifier(if conf.backend_interfaces.ca_cert.is_empty() { + Arc::new(NoClientAuth) + } else { + let root_certs = get_root_certs(Some(conf.backend_interfaces.ca_cert.clone()))?; + WebPkiClientVerifier::builder(root_certs.into()).build()? + }) + .with_single_cert( + load_cert(&conf.backend_interfaces.tls_cert).await?, + load_key(&conf.backend_interfaces.tls_key).await?, + )?; + server_config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()]; + + axum_server::bind_rustls( + addr, + axum_server::tls_rustls::RustlsConfig::from_config(Arc::new(server_config)), + ) + .serve(app.into_make_service()) + .await?; } else { - warp::serve(routes).run(addr).await; + axum_server::bind(addr) + .serve(app.into_make_service()) + .await?; } Ok(()) } -pub async fn handle_request(mut body: impl warp::Buf) -> http::Response<hyper::Body> { - let mut b: Vec<u8> = vec![]; - - while body.has_remaining() { - b.extend_from_slice(body.chunk()); - let cnt = body.chunk().len(); - body.advance(cnt); - } +pub async fn handle_request(b: Bytes) -> Response { + let b: Vec<u8> = b.into(); let bp: BasePayload = match serde_json::from_slice(&b) { Ok(v) => v, Err(e) => { - return warp::reply::with_status(e.to_string(), StatusCode::BAD_REQUEST) - .into_response(); + return (StatusCode::BAD_REQUEST, e.to_string()).into_response(); } }; @@ -87,7 +97,7 @@ pub async fn handle_request(mut body: impl warp::Buf) -> http::Response<hyper::B _handle_request(bp, b).instrument(span).await } -pub async fn _handle_request(bp: BasePayload, b: Vec<u8>) -> http::Response<hyper::Body> { +pub async fn _handle_request(bp: BasePayload, b: Vec<u8>) -> Response { info!("Request received"); let sender_client = { @@ -100,7 +110,7 @@ pub async fn _handle_request(bp: BasePayload, b: Vec<u8>) -> http::Response<hype let msg = format!("Error decoding SenderID: {}", e); let pl = bp.to_base_payload_result(backend::ResultCode::MalformedRequest, &msg); log_request_response(&bp, &b, &pl).await; - return warp::reply::json(&pl).into_response(); + return Json(&pl).into_response(); } }; @@ -111,7 +121,7 @@ pub async fn _handle_request(bp: BasePayload, b: Vec<u8>) -> http::Response<hype let msg = format!("Unknown SenderID: {}", sender_id); let pl = bp.to_base_payload_result(backend::ResultCode::UnknownSender, &msg); log_request_response(&bp, &b, &pl).await; - return warp::reply::json(&pl).into_response(); + return Json(&pl).into_response(); } } } else if bp.sender_id.len() == 3 { @@ -123,7 +133,7 @@ pub async fn _handle_request(bp: BasePayload, b: Vec<u8>) -> http::Response<hype let msg = format!("Error decoding SenderID: {}", e); let pl = bp.to_base_payload_result(backend::ResultCode::MalformedRequest, &msg); log_request_response(&bp, &b, &pl).await; - return warp::reply::json(&pl).into_response(); + return Json(&pl).into_response(); } }; @@ -134,7 +144,7 @@ pub async fn _handle_request(bp: BasePayload, b: Vec<u8>) -> http::Response<hype let msg = format!("Unknown SenderID: {}", sender_id); let pl = bp.to_base_payload_result(backend::ResultCode::UnknownSender, &msg); log_request_response(&bp, &b, &pl).await; - return warp::reply::json(&pl).into_response(); + return Json(&pl).into_response(); } } } else { @@ -145,7 +155,7 @@ pub async fn _handle_request(bp: BasePayload, b: Vec<u8>) -> http::Response<hype "Invalid SenderID length", ); log_request_response(&bp, &b, &pl).await; - return warp::reply::json(&pl).into_response(); + return Json(&pl).into_response(); } }; @@ -156,7 +166,7 @@ pub async fn _handle_request(bp: BasePayload, b: Vec<u8>) -> http::Response<hype error!(error = %e.full(), "Handle async answer error"); } }); - return warp::reply::with_status("", StatusCode::OK).into_response(); + return (StatusCode::OK, "").into_response(); } match bp.message_type { @@ -165,11 +175,11 @@ pub async fn _handle_request(bp: BasePayload, b: Vec<u8>) -> http::Response<hype MessageType::XmitDataReq => handle_xmit_data_req(sender_client, bp, &b).await, MessageType::HomeNSReq => handle_home_ns_req(sender_client, bp, &b).await, // Unknown message - _ => warp::reply::with_status( - "Handler for {:?} is not implemented", + _ => ( StatusCode::INTERNAL_SERVER_ERROR, + format!("Handler for {:?} is not implemented", bp.message_type), ) - .into_response(), + .into_response(), } } @@ -201,7 +211,7 @@ async fn handle_pr_start_req( sender_client: Arc<backend::Client>, bp: backend::BasePayload, b: &[u8], -) -> http::Response<hyper::Body> { +) -> Response { if sender_client.is_async() { let b = b.to_vec(); task::spawn(async move { @@ -222,18 +232,17 @@ async fn handle_pr_start_req( error!(error = %e.full(), transaction_id = bp.transaction_id, "Send async PRStartAns error"); } }); - - warp::reply::with_status("", StatusCode::OK).into_response() + (StatusCode::OK, "").into_response() } else { match _handle_pr_start_req(b).await { Ok(ans) => { log_request_response(&bp, b, &ans).await; - warp::reply::json(&ans).into_response() + Json(&ans).into_response() } Err(e) => { let ans = err_to_response(e, &bp); log_request_response(&bp, b, &ans).await; - warp::reply::json(&ans).into_response() + Json(&ans).into_response() } } } @@ -363,7 +372,7 @@ async fn handle_pr_stop_req( sender_client: Arc<backend::Client>, bp: backend::BasePayload, b: &[u8], -) -> http::Response<hyper::Body> { +) -> Response { if sender_client.is_async() { let b = b.to_vec(); task::spawn(async move { @@ -383,18 +392,17 @@ async fn handle_pr_stop_req( error!(error = %e.full(), "Send async PRStopAns error"); } }); - - warp::reply::with_status("", StatusCode::OK).into_response() + (StatusCode::OK, "").into_response() } else { match _handle_pr_stop_req(b).await { Ok(ans) => { log_request_response(&bp, b, &ans).await; - warp::reply::json(&ans).into_response() + Json(&ans).into_response() } Err(e) => { let ans = err_to_response(e, &bp); log_request_response(&bp, b, &ans).await; - warp::reply::json(&ans).into_response() + Json(&ans).into_response() } } } @@ -430,13 +438,13 @@ async fn handle_xmit_data_req( sender_client: Arc<backend::Client>, bp: backend::BasePayload, b: &[u8], -) -> http::Response<hyper::Body> { +) -> Response { let pl: backend::XmitDataReqPayload = match serde_json::from_slice(b) { Ok(v) => v, Err(e) => { let ans = err_to_response(anyhow::Error::new(e), &bp); log_request_response(&bp, b, &ans).await; - return warp::reply::json(&ans).into_response(); + return Json(&ans).into_response(); } }; @@ -465,18 +473,17 @@ async fn handle_xmit_data_req( error!(error = %e.full(), "Send async XmitDataAns error"); } }); - - warp::reply::with_status("", StatusCode::OK).into_response() + (StatusCode::OK, "").into_response() } else { match _handle_xmit_data_req(pl).await { Ok(ans) => { log_request_response(&bp, b, &ans).await; - warp::reply::json(&ans).into_response() + Json(&ans).into_response() } Err(e) => { let ans = err_to_response(e, &bp); log_request_response(&bp, b, &ans).await; - warp::reply::json(&ans).into_response() + Json(&ans).into_response() } } } @@ -529,13 +536,13 @@ async fn handle_home_ns_req( sender_client: Arc<backend::Client>, bp: backend::BasePayload, b: &[u8], -) -> http::Response<hyper::Body> { +) -> Response { let pl: backend::HomeNSReqPayload = match serde_json::from_slice(b) { Ok(v) => v, Err(e) => { let ans = err_to_response(anyhow::Error::new(e), &bp); log_request_response(&bp, b, &ans).await; - return warp::reply::json(&ans).into_response(); + return Json(&ans).into_response(); } }; @@ -560,17 +567,17 @@ async fn handle_home_ns_req( } }); - warp::reply::with_status("", StatusCode::OK).into_response() + (StatusCode::OK, "").into_response() } else { match _handle_home_ns_req(pl).await { Ok(ans) => { log_request_response(&bp, b, &ans).await; - warp::reply::json(&ans).into_response() + Json(&ans).into_response() } Err(e) => { let ans = err_to_response(e, &bp); log_request_response(&bp, b, &ans).await; - warp::reply::json(&ans).into_response() + Json(&ans).into_response() } } } @@ -587,7 +594,7 @@ async fn _handle_home_ns_req(pl: backend::HomeNSReqPayload) -> Result<backend::H }) } -async fn handle_async_ans(bp: &BasePayload, b: &[u8]) -> Result<http::Response<hyper::Body>> { +async fn handle_async_ans(bp: &BasePayload, b: &[u8]) -> Result<Response> { let transaction_id = bp.transaction_id; let key = redis_key(format!("backend:async:{}", transaction_id)); @@ -609,7 +616,7 @@ async fn handle_async_ans(bp: &BasePayload, b: &[u8]) -> Result<http::Response<h .query_async(&mut get_async_redis_conn().await?) .await?; - Ok(warp::reply().into_response()) + Ok((StatusCode::OK, "").into_response()) } pub async fn get_async_receiver( diff --git a/chirpstack/src/api/grpc_multiplex.rs b/chirpstack/src/api/grpc_multiplex.rs new file mode 100644 index 00000000..1f64b773 --- /dev/null +++ b/chirpstack/src/api/grpc_multiplex.rs @@ -0,0 +1,141 @@ +use std::future::Future; +use std::pin::Pin; +use std::task::{Context, Poll}; + +use futures::ready; +use http::{header::CONTENT_TYPE, Request, Response}; +use http_body::Body; +use pin_project::pin_project; +use tower::{Layer, Service}; + +type BoxError = Box<dyn std::error::Error + Send + Sync>; + +#[pin_project(project = GrpcMultiplexFutureEnumProj)] +enum GrpcMultiplexFutureEnum<FS, FO> { + Grpc { + #[pin] + future: FS, + }, + Other { + #[pin] + future: FO, + }, +} + +#[pin_project] +pub struct GrpcMultiplexFuture<FS, FO> { + #[pin] + future: GrpcMultiplexFutureEnum<FS, FO>, +} + +impl<ResBody, FS, FO, ES, EO> Future for GrpcMultiplexFuture<FS, FO> +where + ResBody: Body, + FS: Future<Output = Result<Response<ResBody>, ES>>, + FO: Future<Output = Result<Response<ResBody>, EO>>, + ES: Into<BoxError> + Send, + EO: Into<BoxError> + Send, +{ + type Output = Result<Response<ResBody>, Box<dyn std::error::Error + Send + Sync + 'static>>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + let this = self.project(); + match this.future.project() { + GrpcMultiplexFutureEnumProj::Grpc { future } => future.poll(cx).map_err(Into::into), + GrpcMultiplexFutureEnumProj::Other { future } => future.poll(cx).map_err(Into::into), + } + } +} + +#[derive(Debug, Clone)] +pub struct GrpcMultiplexService<S, O> { + grpc: S, + other: O, + grpc_ready: bool, + other_ready: bool, +} + +impl<ReqBody, ResBody, S, O> Service<Request<ReqBody>> for GrpcMultiplexService<S, O> +where + ResBody: Body, + S: Service<Request<ReqBody>, Response = Response<ResBody>>, + O: Service<Request<ReqBody>, Response = Response<ResBody>>, + S::Error: Into<BoxError> + Send, + O::Error: Into<BoxError> + Send, +{ + type Response = S::Response; + type Error = Box<dyn std::error::Error + Send + Sync + 'static>; + type Future = GrpcMultiplexFuture<S::Future, O::Future>; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { + loop { + match (self.grpc_ready, self.other_ready) { + (true, true) => { + return Ok(()).into(); + } + (false, _) => { + ready!(self.grpc.poll_ready(cx)).map_err(Into::into)?; + self.grpc_ready = true; + } + (_, false) => { + ready!(self.other.poll_ready(cx)).map_err(Into::into)?; + self.other_ready = true; + } + } + } + } + + fn call(&mut self, request: Request<ReqBody>) -> Self::Future { + assert!(self.grpc_ready); + assert!(self.other_ready); + + if is_grpc_request(&request) { + GrpcMultiplexFuture { + future: GrpcMultiplexFutureEnum::Grpc { + future: self.grpc.call(request), + }, + } + } else { + GrpcMultiplexFuture { + future: GrpcMultiplexFutureEnum::Other { + future: self.other.call(request), + }, + } + } + } +} + +#[derive(Debug, Clone)] +pub struct GrpcMultiplexLayer<O> { + other: O, +} + +impl<O> GrpcMultiplexLayer<O> { + pub fn new(other: O) -> Self { + Self { other } + } +} + +impl<S, O> Layer<S> for GrpcMultiplexLayer<O> +where + O: Clone, +{ + type Service = GrpcMultiplexService<S, O>; + + fn layer(&self, grpc: S) -> Self::Service { + GrpcMultiplexService { + grpc, + other: self.other.clone(), + grpc_ready: false, + other_ready: false, + } + } +} + +fn is_grpc_request<B>(req: &Request<B>) -> bool { + req.headers() + .get(CONTENT_TYPE) + .map(|content_type| content_type.as_bytes()) + .filter(|content_type| content_type.starts_with(b"application/grpc")) + .is_some() +} diff --git a/chirpstack/src/api/mod.rs b/chirpstack/src/api/mod.rs index a4a903d8..773860e7 100644 --- a/chirpstack/src/api/mod.rs +++ b/chirpstack/src/api/mod.rs @@ -1,4 +1,3 @@ -use std::convert::Infallible; use std::time::{Duration, Instant}; use std::{ future::Future, @@ -7,23 +6,27 @@ use std::{ }; use anyhow::Result; -use futures::future::{self, Either, TryFutureExt}; -use hyper::{service::make_service_fn, Server}; +use axum::{response::IntoResponse, routing::get, Router}; +use http::{ + header::{self, HeaderMap, HeaderValue}, + Request, StatusCode, Uri, +}; use pin_project::pin_project; use prometheus_client::encoding::EncodeLabelSet; use prometheus_client::metrics::counter::Counter; use prometheus_client::metrics::family::Family; use prometheus_client::metrics::histogram::Histogram; use rust_embed::RustEmbed; -use tokio::{task, try_join}; +use tokio::task; +use tokio::try_join; use tonic::transport::Server as TonicServer; use tonic::Code; use tonic_reflection::server::Builder as TonicReflectionBuilder; use tonic_web::GrpcWebLayer; -use tower::{Service, ServiceBuilder}; +use tower::util::ServiceExt; +use tower::Service; use tower_http::trace::TraceLayer; use tracing::{error, info}; -use warp::{http::header::HeaderValue, path::Tail, reply::Response, Filter, Rejection, Reply}; use chirpstack_api::api::application_service_server::ApplicationServiceServer; use chirpstack_api::api::device_profile_service_server::DeviceProfileServiceServer; @@ -51,6 +54,7 @@ pub mod device_profile; pub mod device_profile_template; pub mod error; pub mod gateway; +mod grpc_multiplex; pub mod helpers; pub mod internal; pub mod monitoring; @@ -89,210 +93,120 @@ lazy_static! { }; } -type Error = Box<dyn std::error::Error + Send + Sync + 'static>; - #[derive(RustEmbed)] #[folder = "../ui/build"] struct Asset; +type BoxError = Box<dyn std::error::Error + Send + Sync>; + pub async fn setup() -> Result<()> { let conf = config::get(); - let addr = conf.api.bind.parse()?; + let bind = conf.api.bind.parse()?; - info!(bind = %conf.api.bind, "Setting up API interface"); + info!(bind = %bind, "Setting up API interface"); - // Taken from the tonic hyper_warp_multiplex example: - // https://github.com/hyperium/tonic/blob/master/examples/src/hyper_warp_multiplex/server.rs#L101 - let service = make_service_fn(move |_| { - // tonic gRPC service - let tonic_service = TonicServer::builder() - .accept_http1(true) - .layer(GrpcWebLayer::new()) - .add_service( - TonicReflectionBuilder::configure() - .register_encoded_file_descriptor_set(chirpstack_api::api::DESCRIPTOR) - .build() - .unwrap(), - ) - .add_service(InternalServiceServer::with_interceptor( - internal::Internal::new( - validator::RequestValidator::new(), - conf.api.secret.clone(), - ), - auth::auth_interceptor, - )) - .add_service(ApplicationServiceServer::with_interceptor( - application::Application::new(validator::RequestValidator::new()), - auth::auth_interceptor, - )) - .add_service(DeviceProfileServiceServer::with_interceptor( - device_profile::DeviceProfile::new(validator::RequestValidator::new()), - auth::auth_interceptor, - )) - .add_service(DeviceProfileTemplateServiceServer::with_interceptor( - device_profile_template::DeviceProfileTemplate::new( - validator::RequestValidator::new(), - ), - auth::auth_interceptor, - )) - .add_service(TenantServiceServer::with_interceptor( - tenant::Tenant::new(validator::RequestValidator::new()), - auth::auth_interceptor, - )) - .add_service(DeviceServiceServer::with_interceptor( - device::Device::new(validator::RequestValidator::new()), - auth::auth_interceptor, - )) - .add_service(UserServiceServer::with_interceptor( - user::User::new(validator::RequestValidator::new()), - auth::auth_interceptor, - )) - .add_service(GatewayServiceServer::with_interceptor( - gateway::Gateway::new(validator::RequestValidator::new()), - auth::auth_interceptor, - )) - .add_service(MulticastGroupServiceServer::with_interceptor( - multicast::MulticastGroup::new(validator::RequestValidator::new()), - auth::auth_interceptor, - )) - .add_service(RelayServiceServer::with_interceptor( - relay::Relay::new(validator::RequestValidator::new()), - auth::auth_interceptor, - )) - .into_service(); - let mut tonic_service = ServiceBuilder::new() - .layer( - TraceLayer::new_for_grpc() - .make_span_with(|req: &http::Request<hyper::Body>| { - tracing::info_span!( - "gRPC", - uri = %req.uri().path(), - ) - }) - .on_request(OnRequest {}) - .on_response(OnResponse {}), - ) - .layer(ApiLogger {}) - .service(tonic_service); + let web = Router::new() + .route("/auth/oidc/login", get(oidc::login_handler)) + .route("/auth/oidc/callback", get(oidc::callback_handler)) + .route("/auth/oauth2/login", get(oauth2::login_handler)) + .route("/auth/oauth2/callback", get(oauth2::callback_handler)) + .fallback(service_static_handler) + .into_service() + .map_response(|r| r.map(tonic::body::boxed)); - // HTTP service - let warp_service = warp::service( - warp::path!("auth" / "oidc" / "login") - .and_then(oidc::login_handler) - .or(warp::path!("auth" / "oidc" / "callback") - .and(warp::query::<oidc::CallbackArgs>()) - .and_then(oidc::callback_handler)) - .or(warp::path!("auth" / "oauth2" / "login").and_then(oauth2::login_handler)) - .or(warp::path!("auth" / "oauth2" / "callback") - .and(warp::query::<oauth2::CallbackArgs>()) - .and_then(oauth2::callback_handler)) - .or(warp::path::tail().and_then(http_serve)), - ); - let mut warp_service = ServiceBuilder::new() - .layer( - TraceLayer::new_for_http() - .make_span_with(|req: &http::Request<hyper::Body>| { - tracing::info_span!( - "http", - method = req.method().as_str(), - uri = %req.uri().path(), - version = ?req.version(), - ) - }) - .on_request(OnRequest {}) - .on_response(OnResponse {}), - ) - .service(warp_service); - - future::ok::<_, Infallible>(tower::service_fn( - move |req: hyper::Request<hyper::Body>| match req.method() { - &hyper::Method::GET => Either::Left( - warp_service - .call(req) - .map_ok(|res| res.map(EitherBody::Right)) - .map_err(Error::from), - ), - _ => Either::Right( - tonic_service - .call(req) - .map_ok(|res| res.map(EitherBody::Left)) - .map_err(Error::from), - ), - }, + let grpc = TonicServer::builder() + .accept_http1(true) + .layer( + TraceLayer::new_for_grpc() + .make_span_with(|req: &Request<_>| { + tracing::info_span!( + "gRPC", + uri = %req.uri().path(), + ) + }) + .on_request(OnRequest {}) + .on_response(OnResponse {}), + ) + .layer(grpc_multiplex::GrpcMultiplexLayer::new(web)) + .layer(ApiLoggerLayer {}) + .layer(GrpcWebLayer::new()) + .add_service( + TonicReflectionBuilder::configure() + .register_encoded_file_descriptor_set(chirpstack_api::api::DESCRIPTOR) + .build() + .unwrap(), + ) + .add_service(InternalServiceServer::with_interceptor( + internal::Internal::new(validator::RequestValidator::new(), conf.api.secret.clone()), + auth::auth_interceptor, )) - }); + .add_service(ApplicationServiceServer::with_interceptor( + application::Application::new(validator::RequestValidator::new()), + auth::auth_interceptor, + )) + .add_service(DeviceProfileServiceServer::with_interceptor( + device_profile::DeviceProfile::new(validator::RequestValidator::new()), + auth::auth_interceptor, + )) + .add_service(DeviceProfileTemplateServiceServer::with_interceptor( + device_profile_template::DeviceProfileTemplate::new(validator::RequestValidator::new()), + auth::auth_interceptor, + )) + .add_service(TenantServiceServer::with_interceptor( + tenant::Tenant::new(validator::RequestValidator::new()), + auth::auth_interceptor, + )) + .add_service(DeviceServiceServer::with_interceptor( + device::Device::new(validator::RequestValidator::new()), + auth::auth_interceptor, + )) + .add_service(UserServiceServer::with_interceptor( + user::User::new(validator::RequestValidator::new()), + auth::auth_interceptor, + )) + .add_service(GatewayServiceServer::with_interceptor( + gateway::Gateway::new(validator::RequestValidator::new()), + auth::auth_interceptor, + )) + .add_service(MulticastGroupServiceServer::with_interceptor( + multicast::MulticastGroup::new(validator::RequestValidator::new()), + auth::auth_interceptor, + )) + .add_service(RelayServiceServer::with_interceptor( + relay::Relay::new(validator::RequestValidator::new()), + auth::auth_interceptor, + )); let backend_handle = tokio::spawn(backend::setup()); let monitoring_handle = tokio::spawn(monitoring::setup()); - let api_handle = tokio::spawn(Server::bind(&addr).serve(service)); + let grpc_handle = tokio::spawn(grpc.serve(bind)); - let _ = try_join!(api_handle, backend_handle, monitoring_handle)?; + let _ = try_join!(grpc_handle, backend_handle, monitoring_handle)?; Ok(()) } -enum EitherBody<A, B> { - Left(A), - Right(B), -} - -impl<A, B> http_body::Body for EitherBody<A, B> -where - A: http_body::Body + Send + Unpin, - B: http_body::Body<Data = A::Data> + Send + Unpin, - A::Error: Into<Error>, - B::Error: Into<Error>, -{ - type Data = A::Data; - type Error = Box<dyn std::error::Error + Send + Sync + 'static>; - - fn is_end_stream(&self) -> bool { - match self { - EitherBody::Left(b) => b.is_end_stream(), - EitherBody::Right(b) => b.is_end_stream(), - } - } - - fn poll_data( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll<Option<Result<Self::Data, Self::Error>>> { - match self.get_mut() { - EitherBody::Left(b) => Pin::new(b).poll_data(cx).map(map_option_err), - EitherBody::Right(b) => Pin::new(b).poll_data(cx).map(map_option_err), - } - } - - fn poll_trailers( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll<Result<Option<http::HeaderMap>, Self::Error>> { - match self.get_mut() { - EitherBody::Left(b) => Pin::new(b).poll_trailers(cx).map_err(Into::into), - EitherBody::Right(b) => Pin::new(b).poll_trailers(cx).map_err(Into::into), - } - } -} - -fn map_option_err<T, U: Into<Error>>(err: Option<Result<T, U>>) -> Option<Result<T, Error>> { - err.map(|e| e.map_err(Into::into)) -} - -async fn http_serve(path: Tail) -> Result<impl Reply, Rejection> { - let mut path = path.as_str(); +async fn service_static_handler(uri: Uri) -> impl IntoResponse { + let mut path = { + let mut chars = uri.path().chars(); + chars.next(); + chars.as_str() + }; if path.is_empty() { path = "index.html"; } - let asset = Asset::get(path).ok_or_else(warp::reject::not_found)?; - let mime = mime_guess::from_path(path).first_or_octet_stream(); - - let mut res = Response::new(asset.data.into()); - res.headers_mut().insert( - "content-type", - HeaderValue::from_str(mime.as_ref()).unwrap(), - ); - Ok(res) + if let Some(asset) = Asset::get(path) { + let mime = mime_guess::from_path(path).first_or_octet_stream(); + let mut headers = HeaderMap::new(); + headers.insert( + header::CONTENT_TYPE, + HeaderValue::from_str(mime.as_ref()).unwrap(), + ); + (StatusCode::OK, headers, asset.data.into()) + } else { + (StatusCode::NOT_FOUND, HeaderMap::new(), vec![]) + } } #[derive(Debug, Clone)] @@ -320,13 +234,14 @@ struct GrpcLabels { status_code: String, } -struct ApiLogger {} +#[derive(Debug, Clone)] +struct ApiLoggerLayer {} -impl<S> tower::Layer<S> for ApiLogger { +impl<S> tower::Layer<S> for ApiLoggerLayer { type Service = ApiLoggerService<S>; - fn layer(&self, service: S) -> Self::Service { - ApiLoggerService { inner: service } + fn layer(&self, inner: S) -> Self::Service { + ApiLoggerService { inner } } } @@ -335,15 +250,15 @@ struct ApiLoggerService<S> { inner: S, } -impl<S, ReqBody, ResBody> Service<http::Request<ReqBody>> for ApiLoggerService<S> +impl<ReqBody, ResBody, S> Service<http::Request<ReqBody>> for ApiLoggerService<S> where - S: Service<http::Request<ReqBody>, Response = http::Response<ResBody>>, - ReqBody: http_body::Body, ResBody: http_body::Body, + S: Service<http::Request<ReqBody>, Response = http::Response<ResBody>>, + S::Error: Into<BoxError> + Send, { type Response = S::Response; type Error = S::Error; - type Future = ApiLoggerResponseFuture<S::Future>; + type Future = ApiLoggerFuture<S::Future>; fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { self.inner.poll_ready(cx) @@ -352,10 +267,10 @@ where fn call(&mut self, request: http::Request<ReqBody>) -> Self::Future { let uri = request.uri().path().to_string(); let uri_parts: Vec<&str> = uri.split('/').collect(); - let response_future = self.inner.call(request); + let future = self.inner.call(request); let start = Instant::now(); - ApiLoggerResponseFuture { - response_future, + ApiLoggerFuture { + future, start, service: uri_parts.get(1).map(|v| v.to_string()).unwrap_or_default(), method: uri_parts.get(2).map(|v| v.to_string()).unwrap_or_default(), @@ -364,25 +279,26 @@ where } #[pin_project] -struct ApiLoggerResponseFuture<F> { +struct ApiLoggerFuture<F> { #[pin] - response_future: F, + future: F, start: Instant, service: String, method: String, } -impl<F, ResBody, Error> Future for ApiLoggerResponseFuture<F> +impl<ResBody, F, E> Future for ApiLoggerFuture<F> where - F: Future<Output = Result<http::Response<ResBody>, Error>>, ResBody: http_body::Body, + F: Future<Output = Result<http::Response<ResBody>, E>>, + E: Into<BoxError> + Send, { - type Output = Result<http::Response<ResBody>, Error>; + type Output = Result<http::Response<ResBody>, E>; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { let this = self.project(); - match this.response_future.poll(cx) { + match this.future.poll(cx) { Poll::Ready(result) => { if let Ok(response) = &result { let status_code: i32 = match response.headers().get("grpc-status") { diff --git a/chirpstack/src/api/monitoring/mod.rs b/chirpstack/src/api/monitoring/mod.rs index 94cf877d..98101cf2 100644 --- a/chirpstack/src/api/monitoring/mod.rs +++ b/chirpstack/src/api/monitoring/mod.rs @@ -1,51 +1,49 @@ -use std::convert::Infallible; use std::net::SocketAddr; use anyhow::{Context, Result}; +use axum::{ + response::{IntoResponse, Response}, + routing::get, + Router, +}; use diesel_async::RunQueryDsl; +use http::StatusCode; use tracing::info; -use warp::{http::Response, http::StatusCode, Filter}; use crate::config; use crate::monitoring::prometheus; use crate::storage::{get_async_db_conn, get_async_redis_conn}; -pub async fn setup() { +pub async fn setup() -> Result<()> { let conf = config::get(); if conf.monitoring.bind.is_empty() { - return; + return Ok(()); } let addr: SocketAddr = conf.monitoring.bind.parse().unwrap(); info!(bind = %conf.monitoring.bind, "Setting up monitoring endpoint"); - let prom_endpoint = warp::get() - .and(warp::path!("metrics")) - .and_then(prometheus_handler); + let app = Router::new() + .route("/metrics", get(prometheus_handler)) + .route("/health", get(health_handler)); - let health_endpoint = warp::get() - .and(warp::path!("health")) - .and_then(health_handler); - - let routes = prom_endpoint.or(health_endpoint); - - warp::serve(routes).run(addr).await; + axum_server::bind(addr) + .serve(app.into_make_service()) + .await?; + Ok(()) } -async fn prometheus_handler() -> Result<impl warp::Reply, Infallible> { +async fn prometheus_handler() -> Response { let body = prometheus::encode_to_string().unwrap_or_default(); - Ok(Response::builder().body(body)) + body.into_response() } -async fn health_handler() -> Result<impl warp::Reply, Infallible> { +async fn health_handler() -> Response { if let Err(e) = _health_handler().await { - return Ok(warp::reply::with_status( - e.to_string(), - StatusCode::SERVICE_UNAVAILABLE, - )); + (StatusCode::SERVICE_UNAVAILABLE, e.to_string()).into_response() + } else { + (StatusCode::OK, "".to_string()).into_response() } - - Ok(warp::reply::with_status("OK".to_string(), StatusCode::OK)) } async fn _health_handler() -> Result<()> { diff --git a/chirpstack/src/api/oauth2.rs b/chirpstack/src/api/oauth2.rs index 5b4a2f74..462a37f1 100644 --- a/chirpstack/src/api/oauth2.rs +++ b/chirpstack/src/api/oauth2.rs @@ -1,7 +1,10 @@ -use std::str::FromStr; - use anyhow::{Context, Result}; +use axum::{ + extract::Query, + response::{IntoResponse, Redirect, Response}, +}; use chrono::Duration; +use http::StatusCode; use oauth2::basic::BasicClient; use oauth2::reqwest; use oauth2::{ @@ -11,7 +14,6 @@ use oauth2::{ use reqwest::header::AUTHORIZATION; use serde::{Deserialize, Serialize}; use tracing::{error, trace}; -use warp::{Rejection, Reply}; use crate::config; use crate::helpers::errors::PrintFullError; @@ -26,7 +28,7 @@ struct ClerkUserinfo { pub user_id: String, } -#[derive(Serialize, Deserialize)] +#[derive(Deserialize)] pub struct CallbackArgs { pub code: String, pub state: String, @@ -39,16 +41,12 @@ pub struct User { pub external_id: String, } -pub async fn login_handler() -> Result<impl Reply, Rejection> { +pub async fn login_handler() -> Response { let client = match get_client() { Ok(v) => v, Err(e) => { error!(error = %e.full(), "Get OAuth2 client error"); - return Ok(warp::reply::with_status( - "Internal error", - warp::http::StatusCode::INTERNAL_SERVER_ERROR, - ) - .into_response()); + return (StatusCode::INTERNAL_SERVER_ERROR, "Internal error").into_response(); } }; @@ -64,26 +62,16 @@ pub async fn login_handler() -> Result<impl Reply, Rejection> { if let Err(e) = store_verifier(&csrf_token, &pkce_verifier).await { error!(error = %e.full(), "Store verifier error"); - return Ok(warp::reply::with_status( - "Internal error", - warp::http::StatusCode::INTERNAL_SERVER_ERROR, - ) - .into_response()); + return (StatusCode::INTERNAL_SERVER_ERROR, "Internal error").into_response(); } - Ok( - warp::redirect::found(warp::http::Uri::from_str(auth_url.as_str()).unwrap()) - .into_response(), - ) + Redirect::temporary(auth_url.as_str()).into_response() } -pub async fn callback_handler(args: CallbackArgs) -> Result<impl Reply, Rejection> { - // warp::redirect does not work with '#'. - Ok(warp::reply::with_header( - warp::http::StatusCode::PERMANENT_REDIRECT, - warp::http::header::LOCATION, - format!("/#/login?code={}&state={}", args.code, args.state), - )) +pub async fn callback_handler(args: Query<CallbackArgs>) -> Response { + let args: CallbackArgs = args.0; + Redirect::permanent(&format!("/#/login?code={}&state={}", args.code, args.state)) + .into_response() } fn get_client() -> Result<Client> { diff --git a/chirpstack/src/api/oidc.rs b/chirpstack/src/api/oidc.rs index b01f89f4..9424edfc 100644 --- a/chirpstack/src/api/oidc.rs +++ b/chirpstack/src/api/oidc.rs @@ -1,8 +1,12 @@ use std::collections::HashMap; -use std::str::FromStr; use anyhow::{Context, Result}; +use axum::{ + extract::Query, + response::{IntoResponse, Redirect, Response}, +}; use chrono::Duration; +use http::StatusCode; use openidconnect::core::{ CoreClient, CoreGenderClaim, CoreIdTokenClaims, CoreIdTokenVerifier, CoreProviderMetadata, CoreResponseType, @@ -15,7 +19,6 @@ use openidconnect::{ use serde::{Deserialize, Serialize}; use serde_json::Value; use tracing::{error, trace}; -use warp::{Rejection, Reply}; use crate::config; use crate::helpers::errors::PrintFullError; @@ -40,22 +43,18 @@ pub struct CustomClaims { impl AdditionalClaims for CustomClaims {} -#[derive(Serialize, Deserialize)] +#[derive(Deserialize)] pub struct CallbackArgs { pub code: String, pub state: String, } -pub async fn login_handler() -> Result<impl Reply, Rejection> { +pub async fn login_handler() -> Response { let client = match get_client().await { Ok(v) => v, Err(e) => { error!(error = %e.full(), "Get OIDC client error"); - return Ok(warp::reply::with_status( - "Internal error", - warp::http::StatusCode::INTERNAL_SERVER_ERROR, - ) - .into_response()); + return (StatusCode::INTERNAL_SERVER_ERROR, "Internal error").into_response(); } }; @@ -72,26 +71,16 @@ pub async fn login_handler() -> Result<impl Reply, Rejection> { if let Err(e) = store_nonce(&csrf_state, &nonce).await { error!(error = %e.full(), "Store nonce error"); - return Ok(warp::reply::with_status( - "Internal error", - warp::http::StatusCode::INTERNAL_SERVER_ERROR, - ) - .into_response()); + return (StatusCode::INTERNAL_SERVER_ERROR, "Internal error").into_response(); } - Ok( - warp::redirect::found(warp::http::Uri::from_str(auth_url.as_str()).unwrap()) - .into_response(), - ) + Redirect::temporary(auth_url.as_str()).into_response() } -pub async fn callback_handler(args: CallbackArgs) -> Result<impl Reply, Rejection> { - // warp::redirect does not work with '#'. - Ok(warp::reply::with_header( - warp::http::StatusCode::PERMANENT_REDIRECT, - warp::http::header::LOCATION, - format!("/#/login?code={}&state={}", args.code, args.state), - )) +pub async fn callback_handler(args: Query<CallbackArgs>) -> Response { + let args: CallbackArgs = args.0; + Redirect::permanent(&format!("/#/login?code={}&state={}", args.code, args.state)) + .into_response() } pub async fn get_user(code: &str, state: &str) -> Result<User> { diff --git a/chirpstack/src/helpers/tls.rs b/chirpstack/src/helpers/tls.rs index e92a1415..e8e17c4a 100644 --- a/chirpstack/src/helpers/tls.rs +++ b/chirpstack/src/helpers/tls.rs @@ -2,6 +2,8 @@ use std::fs::File; use std::io::BufReader; use anyhow::{Context, Result}; +use rustls::pki_types::{CertificateDer, PrivateKeyDer}; +use tokio::fs; // Return root certificates, optionally with the provided ca_file appended. pub fn get_root_certs(ca_file: Option<String>) -> Result<rustls::RootCertStore> { @@ -22,6 +24,38 @@ pub fn get_root_certs(ca_file: Option<String>) -> Result<rustls::RootCertStore> Ok(roots) } +pub async fn load_cert(cert_file: &str) -> Result<Vec<CertificateDer<'static>>> { + let cert_s = fs::read_to_string(cert_file) + .await + .context("Read TLS certificate")?; + let mut cert_b = cert_s.as_bytes(); + let certs = rustls_pemfile::certs(&mut cert_b); + let mut out = Vec::new(); + for cert in certs { + out.push(cert?.into_owned()); + } + Ok(out) +} + +pub async fn load_key(key_file: &str) -> Result<PrivateKeyDer<'static>> { + let key_s = fs::read_to_string(key_file) + .await + .context("Read private key")?; + let key_s = private_key_to_pkcs8(&key_s)?; + let mut key_b = key_s.as_bytes(); + let mut keys = rustls_pemfile::pkcs8_private_keys(&mut key_b); + if let Some(key) = keys.next() { + match key { + Ok(v) => return Ok(PrivateKeyDer::Pkcs8(v.clone_key())), + Err(e) => { + return Err(anyhow!("Error parsing private key, error: {}", e)); + } + } + } + + Err(anyhow!("No private key found")) +} + pub fn private_key_to_pkcs8(pem: &str) -> Result<String> { if pem.contains("RSA PRIVATE KEY") { use rsa::{ diff --git a/chirpstack/src/test/class_a_pr_test.rs b/chirpstack/src/test/class_a_pr_test.rs index 3aecfdb8..3413ca07 100644 --- a/chirpstack/src/test/class_a_pr_test.rs +++ b/chirpstack/src/test/class_a_pr_test.rs @@ -383,7 +383,9 @@ async fn test_sns_uplink() { let resp = backend_api::handle_request(Bytes::from(serde_json::to_string(&pr_start_req).unwrap())) .await; - let resp_b = hyper::body::to_bytes(resp.into_body()).await.unwrap(); + let resp_b = axum::body::to_bytes(resp.into_body(), usize::MAX) + .await + .unwrap(); let pr_start_ans: backend::PRStartAnsPayload = serde_json::from_slice(&resp_b).unwrap(); @@ -572,7 +574,9 @@ async fn test_sns_roaming_not_allowed() { let resp = backend_api::handle_request(Bytes::from(serde_json::to_string(&pr_start_req).unwrap())) .await; - let resp_b = hyper::body::to_bytes(resp.into_body()).await.unwrap(); + let resp_b = axum::body::to_bytes(resp.into_body(), usize::MAX) + .await + .unwrap(); let pr_start_ans: backend::PRStartAnsPayload = serde_json::from_slice(&resp_b).unwrap(); assert_eq!( @@ -683,7 +687,9 @@ async fn test_sns_dev_not_found() { let resp = backend_api::handle_request(Bytes::from(serde_json::to_string(&pr_start_req).unwrap())) .await; - let resp_b = hyper::body::to_bytes(resp.into_body()).await.unwrap(); + let resp_b = axum::body::to_bytes(resp.into_body(), usize::MAX) + .await + .unwrap(); let pr_start_ans: backend::PRStartAnsPayload = serde_json::from_slice(&resp_b).unwrap(); diff --git a/chirpstack/src/test/otaa_pr_test.rs b/chirpstack/src/test/otaa_pr_test.rs index 09e9af12..026f8002 100644 --- a/chirpstack/src/test/otaa_pr_test.rs +++ b/chirpstack/src/test/otaa_pr_test.rs @@ -379,7 +379,9 @@ async fn test_sns() { let resp = backend_api::handle_request(Bytes::from(serde_json::to_string(&pr_start_req).unwrap())) .await; - let resp_b = hyper::body::to_bytes(resp.into_body()).await.unwrap(); + let resp_b = axum::body::to_bytes(resp.into_body(), usize::MAX) + .await + .unwrap(); let pr_start_ans: backend::PRStartAnsPayload = serde_json::from_slice(&resp_b).unwrap(); @@ -562,7 +564,9 @@ async fn test_sns_roaming_not_allowed() { let resp = backend_api::handle_request(Bytes::from(serde_json::to_string(&pr_start_req).unwrap())) .await; - let resp_b = hyper::body::to_bytes(resp.into_body()).await.unwrap(); + let resp_b = axum::body::to_bytes(resp.into_body(), usize::MAX) + .await + .unwrap(); let pr_start_ans: backend::PRStartAnsPayload = serde_json::from_slice(&resp_b).unwrap();