Replace warp with axum.

The warp dependency was causing some issues with upgrading dependencies
as it depends on http v0.2, where other dependencies (e.g. tonic) have
already upgraded to http v1+.
This commit is contained in:
Orne Brocaar 2024-08-01 11:33:57 +01:00
parent 98978135c4
commit 4e0106a4e8
13 changed files with 564 additions and 541 deletions

294
Cargo.lock generated
View File

@ -173,6 +173,12 @@ dependencies = [
"num-traits",
]
[[package]]
name = "arc-swap"
version = "1.7.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "69f7f8c3906b62b754cd5326047894316021dcfe5a194c8ea52bdd94934a3457"
[[package]]
name = "ascii-canvas"
version = "3.0.0"
@ -524,18 +530,19 @@ dependencies = [
[[package]]
name = "axum"
version = "0.6.20"
version = "0.7.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3b829e4e32b91e643de6eafe82b1d90675f5874230191a4ffbc1b336dec4d6bf"
checksum = "3a6c9af12842a67734c9a2e355436e5d03b22383ed60cf13cd0c18fbfe3dcbcf"
dependencies = [
"async-trait",
"axum-core",
"bitflags 1.3.2",
"bytes",
"futures-util",
"http 0.2.12",
"http-body 0.4.6",
"hyper 0.14.30",
"http 1.1.0",
"http-body 1.0.1",
"http-body-util",
"hyper 1.4.1",
"hyper-util",
"itoa",
"matchit",
"memchr",
@ -544,27 +551,60 @@ dependencies = [
"pin-project-lite",
"rustversion",
"serde",
"sync_wrapper 0.1.2",
"serde_json",
"serde_path_to_error",
"serde_urlencoded",
"sync_wrapper 1.0.1",
"tokio",
"tower",
"tower-layer",
"tower-service",
"tracing",
]
[[package]]
name = "axum-core"
version = "0.3.4"
version = "0.4.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "759fa577a247914fd3f7f76d62972792636412fbfd634cd452f6a385a74d2d2c"
checksum = "a15c63fd72d41492dc4f497196f5da1fb04fb7529e631d73630d1b491e47a2e3"
dependencies = [
"async-trait",
"bytes",
"futures-util",
"http 0.2.12",
"http-body 0.4.6",
"http 1.1.0",
"http-body 1.0.1",
"http-body-util",
"mime",
"pin-project-lite",
"rustversion",
"sync_wrapper 0.1.2",
"tower-layer",
"tower-service",
"tracing",
]
[[package]]
name = "axum-server"
version = "0.7.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "56bac90848f6a9393ac03c63c640925c4b7c8ca21654de40d53f55964667c7d8"
dependencies = [
"arc-swap",
"bytes",
"futures-util",
"http 1.1.0",
"http-body 1.0.1",
"http-body-util",
"hyper 1.4.1",
"hyper-util",
"pin-project-lite",
"rustls 0.23.12",
"rustls-pemfile",
"rustls-pki-types",
"tokio",
"tokio-rustls 0.26.0",
"tower",
"tower-service",
]
[[package]]
@ -779,6 +819,8 @@ dependencies = [
"async-recursion",
"async-trait",
"aws-sign-v4",
"axum",
"axum-server",
"backend",
"base64 0.22.1",
"bigdecimal",
@ -800,11 +842,10 @@ dependencies = [
"handlebars",
"hex",
"hmac",
"http 0.2.12",
"http-body 0.4.6",
"http 1.1.0",
"http-body 1.0.1",
"httpmock",
"humantime-serde",
"hyper 0.14.30",
"jsonwebtoken",
"lapin",
"lazy_static",
@ -857,7 +898,6 @@ dependencies = [
"tracing-subscriber",
"urlencoding",
"uuid",
"warp",
"x509-parser",
]
@ -960,7 +1000,7 @@ version = "4.5.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2bac35c6dafb060fd4d275d9a4ffae97917c13a6327903a8be2153cd964f7085"
dependencies = [
"heck 0.5.0",
"heck",
"proc-macro2",
"quote",
"syn 2.0.72",
@ -1380,7 +1420,7 @@ checksum = "c5d9abe6314103864cc2d8901b7ae224e0ab1a103a0a416661b4097b0779b607"
dependencies = [
"darling",
"either",
"heck 0.5.0",
"heck",
"proc-macro2",
"quote",
"syn 2.0.72",
@ -1837,25 +1877,6 @@ dependencies = [
"subtle",
]
[[package]]
name = "h2"
version = "0.3.26"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "81fe527a889e1532da5c525686d96d4c2e74cdd345badf8dfef9f6b39dd5f5e8"
dependencies = [
"bytes",
"fnv",
"futures-core",
"futures-sink",
"futures-util",
"http 0.2.12",
"indexmap 2.2.6",
"slab",
"tokio",
"tokio-util",
"tracing",
]
[[package]]
name = "h2"
version = "0.4.5"
@ -1901,36 +1922,6 @@ version = "0.14.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1"
[[package]]
name = "headers"
version = "0.3.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "06683b93020a07e3dbcf5f8c0f6d40080d725bea7936fc01ad345c01b97dc270"
dependencies = [
"base64 0.21.7",
"bytes",
"headers-core",
"http 0.2.12",
"httpdate",
"mime",
"sha1",
]
[[package]]
name = "headers-core"
version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e7f66481bfee273957b1f20485a4ff3362987f85b2c236580d81b4eb7a326429"
dependencies = [
"http 0.2.12",
]
[[package]]
name = "heck"
version = "0.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8"
[[package]]
name = "heck"
version = "0.5.0"
@ -2038,12 +2029,6 @@ dependencies = [
"pin-project-lite",
]
[[package]]
name = "http-range-header"
version = "0.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "add0ab9360ddbd88cfeb3bd9574a1d85cfdfa14db10b3e21d3700dbc4328758f"
[[package]]
name = "httparse"
version = "1.9.4"
@ -2110,7 +2095,6 @@ dependencies = [
"futures-channel",
"futures-core",
"futures-util",
"h2 0.3.26",
"http 0.2.12",
"http-body 0.4.6",
"httparse",
@ -2133,10 +2117,11 @@ dependencies = [
"bytes",
"futures-channel",
"futures-util",
"h2 0.4.5",
"h2",
"http 1.1.0",
"http-body 1.0.1",
"httparse",
"httpdate",
"itoa",
"pin-project-lite",
"smallvec",
@ -2165,14 +2150,15 @@ dependencies = [
[[package]]
name = "hyper-timeout"
version = "0.4.1"
version = "0.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bbb958482e8c7be4bc3cf272a766a2b0bf1a6755e7a6ae777f017a31d11b13b1"
checksum = "3203a961e5c83b6f5498933e78b6b263e208c197b63e9c6c53cc82ffd3f63793"
dependencies = [
"hyper 0.14.30",
"hyper 1.4.1",
"hyper-util",
"pin-project-lite",
"tokio",
"tokio-io-timeout",
"tower-service",
]
[[package]]
@ -2324,6 +2310,15 @@ dependencies = [
"either",
]
[[package]]
name = "itertools"
version = "0.13.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "413ee7dfc52ee1a4949ceeb7dbc8a33f2d6c088194d9f922fb8318faf1f01186"
dependencies = [
"either",
]
[[package]]
name = "itoa"
version = "1.0.11"
@ -2898,9 +2893,9 @@ dependencies = [
[[package]]
name = "pbjson"
version = "0.6.0"
version = "0.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1030c719b0ec2a2d25a5df729d6cff1acf3cc230bf766f4f97833591f7577b90"
checksum = "c7e6349fa080353f4a597daffd05cb81572a9c031a6d4fff7e504947496fcc68"
dependencies = [
"base64 0.21.7",
"serde",
@ -2908,21 +2903,21 @@ dependencies = [
[[package]]
name = "pbjson-build"
version = "0.6.2"
version = "0.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2580e33f2292d34be285c5bc3dba5259542b083cfad6037b6d70345f24dcb735"
checksum = "6eea3058763d6e656105d1403cb04e0a41b7bbac6362d413e7c33be0c32279c9"
dependencies = [
"heck 0.4.1",
"itertools 0.11.0",
"heck",
"itertools 0.13.0",
"prost",
"prost-types",
]
[[package]]
name = "pbjson-types"
version = "0.6.0"
version = "0.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "18f596653ba4ac51bdecbb4ef6773bc7f56042dc13927910de1684ad3d32aa12"
checksum = "e54e5e7bfb1652f95bc361d76f3c780d8e526b134b85417e774166ee941f0887"
dependencies = [
"bytes",
"chrono",
@ -3305,9 +3300,9 @@ dependencies = [
[[package]]
name = "prost"
version = "0.12.6"
version = "0.13.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "deb1435c188b76130da55f17a466d252ff7b1418b2ad3e037d127b94e3411f29"
checksum = "e13db3d3fde688c61e2446b4d843bc27a7e8af269a69440c0308021dc92333cc"
dependencies = [
"bytes",
"prost-derive",
@ -3315,13 +3310,13 @@ dependencies = [
[[package]]
name = "prost-build"
version = "0.12.6"
version = "0.13.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "22505a5c94da8e3b7c2996394d1c933236c4d743e81a410bcca4e6989fc066a4"
checksum = "5bb182580f71dd070f88d01ce3de9f4da5021db7115d2e1c3605a754153b77c1"
dependencies = [
"bytes",
"heck 0.5.0",
"itertools 0.12.1",
"heck",
"itertools 0.13.0",
"log",
"multimap",
"once_cell",
@ -3336,12 +3331,12 @@ dependencies = [
[[package]]
name = "prost-derive"
version = "0.12.6"
version = "0.13.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "81bddcdb20abf9501610992b6759a4c888aef7d1a7247ef75e2404275ac24af1"
checksum = "18bec9b0adc4eba778b33684b7ba3e7137789434769ee3ce3930463ef904cfca"
dependencies = [
"anyhow",
"itertools 0.12.1",
"itertools 0.13.0",
"proc-macro2",
"quote",
"syn 2.0.72",
@ -3349,9 +3344,9 @@ dependencies = [
[[package]]
name = "prost-types"
version = "0.12.6"
version = "0.13.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9091c90b0a32608e984ff2fa4091273cbdd755d54935c51d520887f4a1dbd5b0"
checksum = "cee5168b05f49d4b0ca581206eb14a7b22fafd963efe729ac48eb03266e25cc2"
dependencies = [
"prost",
]
@ -3946,12 +3941,6 @@ dependencies = [
"pin-utils",
]
[[package]]
name = "scoped-tls"
version = "1.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e1cf6437eb19a8f4a6cc0f7dca544973b0b78843adbfeb3683d1a94a0024a294"
[[package]]
name = "scopeguard"
version = "1.2.0"
@ -4135,17 +4124,6 @@ dependencies = [
"unsafe-libyaml",
]
[[package]]
name = "sha1"
version = "0.10.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e3bf829a2d51ab4a5ddf1352d8470c140cadc8301b2ae1789db023f01cedd6ba"
dependencies = [
"cfg-if",
"cpufeatures",
"digest",
]
[[package]]
name = "sha1_smol"
version = "1.0.1"
@ -4514,16 +4492,6 @@ dependencies = [
"tokio",
]
[[package]]
name = "tokio-io-timeout"
version = "1.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "30b74022ada614a1b4834de765f9bb43877f910cc8ce4be40e89042c9223a8bf"
dependencies = [
"pin-project-lite",
"tokio",
]
[[package]]
name = "tokio-macros"
version = "2.4.0"
@ -4682,23 +4650,26 @@ dependencies = [
[[package]]
name = "tonic"
version = "0.11.0"
version = "0.12.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "76c4eb7a4e9ef9d4763600161f12f5070b92a578e1b634db88a6887844c91a13"
checksum = "38659f4a91aba8598d27821589f5db7dddd94601e7a01b1e485a50e5484c7401"
dependencies = [
"async-stream",
"async-trait",
"axum",
"base64 0.21.7",
"base64 0.22.1",
"bytes",
"h2 0.3.26",
"http 0.2.12",
"http-body 0.4.6",
"hyper 0.14.30",
"h2",
"http 1.1.0",
"http-body 1.0.1",
"http-body-util",
"hyper 1.4.1",
"hyper-timeout",
"hyper-util",
"percent-encoding",
"pin-project",
"prost",
"socket2 0.5.7",
"tokio",
"tokio-stream",
"tower",
@ -4709,9 +4680,9 @@ dependencies = [
[[package]]
name = "tonic-build"
version = "0.11.0"
version = "0.12.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "be4ef6dd70a610078cb4e338a0f79d06bc759ff1b22d2120c2ff02ae264ba9c2"
checksum = "568392c5a2bd0020723e3f387891176aabafe36fd9fcd074ad309dfa0c8eb964"
dependencies = [
"prettyplease",
"proc-macro2",
@ -4722,9 +4693,9 @@ dependencies = [
[[package]]
name = "tonic-reflection"
version = "0.11.0"
version = "0.12.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "548c227bd5c0fae5925812c4ec6c66ffcfced23ea370cb823f4d18f0fc1cb6a7"
checksum = "b742c83ad673e9ab5b4ce0981f7b9e8932be9d60e8682cbf9120494764dbc173"
dependencies = [
"prost",
"prost-types",
@ -4735,15 +4706,15 @@ dependencies = [
[[package]]
name = "tonic-web"
version = "0.11.0"
version = "0.12.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dc3b0e1cedbf19fdfb78ef3d672cb9928e0a91a9cb4629cc0c916e8cff8aaaa1"
checksum = "8dc0e36ac436560b9a8c9edad4521cf5dd5deb1af591936db9660191b6ecf619"
dependencies = [
"base64 0.21.7",
"base64 0.22.1",
"bytes",
"http 0.2.12",
"http-body 0.4.6",
"hyper 0.14.30",
"http 1.1.0",
"http-body 1.0.1",
"http-body-util",
"pin-project",
"tokio-stream",
"tonic",
@ -4775,18 +4746,16 @@ dependencies = [
[[package]]
name = "tower-http"
version = "0.4.4"
version = "0.5.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "61c5bb1d698276a2443e5ecfabc1008bf15a36c12e6a7176e7bf089ea9131140"
checksum = "1e9cd434a998747dd2c4276bc96ee2e0c7a2eadf3cae88e52be55a05fa9053f5"
dependencies = [
"base64 0.21.7",
"bitflags 2.6.0",
"bytes",
"futures-core",
"futures-util",
"http 0.2.12",
"http-body 0.4.6",
"http-range-header",
"http 1.1.0",
"http-body 1.0.1",
"http-body-util",
"mime",
"pin-project-lite",
"tower-layer",
@ -5042,35 +5011,6 @@ dependencies = [
"try-lock",
]
[[package]]
name = "warp"
version = "0.3.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4378d202ff965b011c64817db11d5829506d3404edeadb61f190d111da3f231c"
dependencies = [
"bytes",
"futures-channel",
"futures-util",
"headers",
"http 0.2.12",
"hyper 0.14.30",
"log",
"mime",
"mime_guess",
"percent-encoding",
"pin-project",
"rustls-pemfile",
"scoped-tls",
"serde",
"serde_json",
"serde_urlencoded",
"tokio",
"tokio-rustls 0.25.0",
"tokio-util",
"tower-service",
"tracing",
]
[[package]]
name = "wasi"
version = "0.11.0+wasi-snapshot-preview1"

14
api/rust/Cargo.toml vendored
View File

@ -16,23 +16,23 @@
internal = []
[dependencies]
prost = "0.12"
prost-types = "0.12"
prost = "0.13"
prost-types = "0.13"
hex = "0.4"
rand = "0.8"
tonic = { version = "0.11", features = [
tonic = { version = "0.12", features = [
"codegen",
"prost",
], default-features = false, optional = true }
tokio = { version = "1.38", features = ["macros"], optional = true }
pbjson = { version = "0.6", optional = true }
pbjson-types = { version = "0.6", optional = true }
pbjson = { version = "0.7", optional = true }
pbjson-types = { version = "0.7", optional = true }
serde = { version = "1.0", optional = true }
diesel = { version = "2.2", features = ["postgres_backend"], optional = true }
[build-dependencies]
tonic-build = { version = "0.11", features = [
tonic-build = { version = "0.12", features = [
"prost",
], default-features = false }
pbjson-build = "0.6"
pbjson-build = "0.7"

View File

@ -88,26 +88,26 @@
] }
# gRPC and Protobuf
tonic = "0.11"
tonic-web = "0.11"
tonic-reflection = "0.11"
tonic = "0.12"
tonic-web = "0.12"
tonic-reflection = "0.12"
tokio = { version = "1.38", features = ["macros", "rt-multi-thread"] }
tokio-stream = "0.1"
prost-types = "0.12"
prost = "0.12"
pbjson-types = "0.6"
prost-types = "0.13"
prost = "0.13"
pbjson-types = "0.7"
# gRPC and HTTP multiplexing
warp = { version = "0.3", features = ["tls"], default-features = false }
hyper = "0.14"
tower = "0.4"
axum = "0.7"
axum-server = { version = "0.7.1", features = ["tls-rustls-no-provider"] }
tower = { version = "0.4" }
futures = "0.3"
futures-util = "0.3"
http = "0.2"
http-body = "0.4"
http = "1.1"
http-body = "1.0"
rust-embed = "8.5"
mime_guess = "2.0"
tower-http = { version = "0.4", features = ["trace", "auth"] }
tower-http = { version = "0.5", features = ["trace", "auth"] }
# Error handling
thiserror = "1.0"

View File

@ -6,7 +6,7 @@ pub mod claims;
pub mod error;
pub mod validator;
#[derive(PartialEq, Eq, Debug)]
#[derive(PartialEq, Eq, Debug, Clone)]
pub enum AuthID {
None,
User(Uuid),

View File

@ -5,18 +5,28 @@ use std::sync::Arc;
use std::time::Duration;
use anyhow::Result;
use axum::{
body::Bytes,
response::{IntoResponse, Json, Response},
Router,
};
use chrono::Utc;
use http::StatusCode;
use redis::streams::StreamReadReply;
use rustls::{
server::{NoClientAuth, WebPkiClientVerifier},
ServerConfig,
};
use serde::Serialize;
use tokio::sync::oneshot;
use tokio::task;
use tracing::{error, info, span, warn, Instrument, Level};
use uuid::Uuid;
use warp::{http::StatusCode, Filter, Reply};
use crate::backend::{joinserver, keywrap, roaming};
use crate::downlink::data_fns;
use crate::helpers::errors::PrintFullError;
use crate::helpers::tls::{get_root_certs, load_cert, load_key};
use crate::storage::{
device, error::Error as StorageError, get_async_redis_conn, passive_roaming, redis_key,
};
@ -39,47 +49,47 @@ pub async fn setup() -> Result<()> {
let addr: SocketAddr = conf.backend_interfaces.bind.parse()?;
info!(bind = %conf.backend_interfaces.bind, "Setting up backend interfaces API");
let routes = warp::post()
.and(warp::body::content_length_limit(1024 * 16))
.and(warp::body::aggregate())
.then(handle_request);
let app = Router::new().fallback(handle_request);
if !conf.backend_interfaces.ca_cert.is_empty()
|| !conf.backend_interfaces.tls_cert.is_empty()
|| !conf.backend_interfaces.tls_key.is_empty()
{
let mut w = warp::serve(routes).tls();
if !conf.backend_interfaces.ca_cert.is_empty() {
w = w.client_auth_required_path(&conf.backend_interfaces.ca_cert);
}
if !conf.backend_interfaces.tls_cert.is_empty() {
w = w.cert_path(&conf.backend_interfaces.tls_cert);
}
if !conf.backend_interfaces.tls_key.is_empty() {
w = w.key_path(&conf.backend_interfaces.tls_key);
}
w.run(addr).await;
let mut server_config = ServerConfig::builder()
.with_client_cert_verifier(if conf.backend_interfaces.ca_cert.is_empty() {
Arc::new(NoClientAuth)
} else {
let root_certs = get_root_certs(Some(conf.backend_interfaces.ca_cert.clone()))?;
WebPkiClientVerifier::builder(root_certs.into()).build()?
})
.with_single_cert(
load_cert(&conf.backend_interfaces.tls_cert).await?,
load_key(&conf.backend_interfaces.tls_key).await?,
)?;
server_config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
axum_server::bind_rustls(
addr,
axum_server::tls_rustls::RustlsConfig::from_config(Arc::new(server_config)),
)
.serve(app.into_make_service())
.await?;
} else {
warp::serve(routes).run(addr).await;
axum_server::bind(addr)
.serve(app.into_make_service())
.await?;
}
Ok(())
}
pub async fn handle_request(mut body: impl warp::Buf) -> http::Response<hyper::Body> {
let mut b: Vec<u8> = vec![];
while body.has_remaining() {
b.extend_from_slice(body.chunk());
let cnt = body.chunk().len();
body.advance(cnt);
}
pub async fn handle_request(b: Bytes) -> Response {
let b: Vec<u8> = b.into();
let bp: BasePayload = match serde_json::from_slice(&b) {
Ok(v) => v,
Err(e) => {
return warp::reply::with_status(e.to_string(), StatusCode::BAD_REQUEST)
.into_response();
return (StatusCode::BAD_REQUEST, e.to_string()).into_response();
}
};
@ -87,7 +97,7 @@ pub async fn handle_request(mut body: impl warp::Buf) -> http::Response<hyper::B
_handle_request(bp, b).instrument(span).await
}
pub async fn _handle_request(bp: BasePayload, b: Vec<u8>) -> http::Response<hyper::Body> {
pub async fn _handle_request(bp: BasePayload, b: Vec<u8>) -> Response {
info!("Request received");
let sender_client = {
@ -100,7 +110,7 @@ pub async fn _handle_request(bp: BasePayload, b: Vec<u8>) -> http::Response<hype
let msg = format!("Error decoding SenderID: {}", e);
let pl = bp.to_base_payload_result(backend::ResultCode::MalformedRequest, &msg);
log_request_response(&bp, &b, &pl).await;
return warp::reply::json(&pl).into_response();
return Json(&pl).into_response();
}
};
@ -111,7 +121,7 @@ pub async fn _handle_request(bp: BasePayload, b: Vec<u8>) -> http::Response<hype
let msg = format!("Unknown SenderID: {}", sender_id);
let pl = bp.to_base_payload_result(backend::ResultCode::UnknownSender, &msg);
log_request_response(&bp, &b, &pl).await;
return warp::reply::json(&pl).into_response();
return Json(&pl).into_response();
}
}
} else if bp.sender_id.len() == 3 {
@ -123,7 +133,7 @@ pub async fn _handle_request(bp: BasePayload, b: Vec<u8>) -> http::Response<hype
let msg = format!("Error decoding SenderID: {}", e);
let pl = bp.to_base_payload_result(backend::ResultCode::MalformedRequest, &msg);
log_request_response(&bp, &b, &pl).await;
return warp::reply::json(&pl).into_response();
return Json(&pl).into_response();
}
};
@ -134,7 +144,7 @@ pub async fn _handle_request(bp: BasePayload, b: Vec<u8>) -> http::Response<hype
let msg = format!("Unknown SenderID: {}", sender_id);
let pl = bp.to_base_payload_result(backend::ResultCode::UnknownSender, &msg);
log_request_response(&bp, &b, &pl).await;
return warp::reply::json(&pl).into_response();
return Json(&pl).into_response();
}
}
} else {
@ -145,7 +155,7 @@ pub async fn _handle_request(bp: BasePayload, b: Vec<u8>) -> http::Response<hype
"Invalid SenderID length",
);
log_request_response(&bp, &b, &pl).await;
return warp::reply::json(&pl).into_response();
return Json(&pl).into_response();
}
};
@ -156,7 +166,7 @@ pub async fn _handle_request(bp: BasePayload, b: Vec<u8>) -> http::Response<hype
error!(error = %e.full(), "Handle async answer error");
}
});
return warp::reply::with_status("", StatusCode::OK).into_response();
return (StatusCode::OK, "").into_response();
}
match bp.message_type {
@ -165,11 +175,11 @@ pub async fn _handle_request(bp: BasePayload, b: Vec<u8>) -> http::Response<hype
MessageType::XmitDataReq => handle_xmit_data_req(sender_client, bp, &b).await,
MessageType::HomeNSReq => handle_home_ns_req(sender_client, bp, &b).await,
// Unknown message
_ => warp::reply::with_status(
"Handler for {:?} is not implemented",
_ => (
StatusCode::INTERNAL_SERVER_ERROR,
format!("Handler for {:?} is not implemented", bp.message_type),
)
.into_response(),
.into_response(),
}
}
@ -201,7 +211,7 @@ async fn handle_pr_start_req(
sender_client: Arc<backend::Client>,
bp: backend::BasePayload,
b: &[u8],
) -> http::Response<hyper::Body> {
) -> Response {
if sender_client.is_async() {
let b = b.to_vec();
task::spawn(async move {
@ -222,18 +232,17 @@ async fn handle_pr_start_req(
error!(error = %e.full(), transaction_id = bp.transaction_id, "Send async PRStartAns error");
}
});
warp::reply::with_status("", StatusCode::OK).into_response()
(StatusCode::OK, "").into_response()
} else {
match _handle_pr_start_req(b).await {
Ok(ans) => {
log_request_response(&bp, b, &ans).await;
warp::reply::json(&ans).into_response()
Json(&ans).into_response()
}
Err(e) => {
let ans = err_to_response(e, &bp);
log_request_response(&bp, b, &ans).await;
warp::reply::json(&ans).into_response()
Json(&ans).into_response()
}
}
}
@ -363,7 +372,7 @@ async fn handle_pr_stop_req(
sender_client: Arc<backend::Client>,
bp: backend::BasePayload,
b: &[u8],
) -> http::Response<hyper::Body> {
) -> Response {
if sender_client.is_async() {
let b = b.to_vec();
task::spawn(async move {
@ -383,18 +392,17 @@ async fn handle_pr_stop_req(
error!(error = %e.full(), "Send async PRStopAns error");
}
});
warp::reply::with_status("", StatusCode::OK).into_response()
(StatusCode::OK, "").into_response()
} else {
match _handle_pr_stop_req(b).await {
Ok(ans) => {
log_request_response(&bp, b, &ans).await;
warp::reply::json(&ans).into_response()
Json(&ans).into_response()
}
Err(e) => {
let ans = err_to_response(e, &bp);
log_request_response(&bp, b, &ans).await;
warp::reply::json(&ans).into_response()
Json(&ans).into_response()
}
}
}
@ -430,13 +438,13 @@ async fn handle_xmit_data_req(
sender_client: Arc<backend::Client>,
bp: backend::BasePayload,
b: &[u8],
) -> http::Response<hyper::Body> {
) -> Response {
let pl: backend::XmitDataReqPayload = match serde_json::from_slice(b) {
Ok(v) => v,
Err(e) => {
let ans = err_to_response(anyhow::Error::new(e), &bp);
log_request_response(&bp, b, &ans).await;
return warp::reply::json(&ans).into_response();
return Json(&ans).into_response();
}
};
@ -465,18 +473,17 @@ async fn handle_xmit_data_req(
error!(error = %e.full(), "Send async XmitDataAns error");
}
});
warp::reply::with_status("", StatusCode::OK).into_response()
(StatusCode::OK, "").into_response()
} else {
match _handle_xmit_data_req(pl).await {
Ok(ans) => {
log_request_response(&bp, b, &ans).await;
warp::reply::json(&ans).into_response()
Json(&ans).into_response()
}
Err(e) => {
let ans = err_to_response(e, &bp);
log_request_response(&bp, b, &ans).await;
warp::reply::json(&ans).into_response()
Json(&ans).into_response()
}
}
}
@ -529,13 +536,13 @@ async fn handle_home_ns_req(
sender_client: Arc<backend::Client>,
bp: backend::BasePayload,
b: &[u8],
) -> http::Response<hyper::Body> {
) -> Response {
let pl: backend::HomeNSReqPayload = match serde_json::from_slice(b) {
Ok(v) => v,
Err(e) => {
let ans = err_to_response(anyhow::Error::new(e), &bp);
log_request_response(&bp, b, &ans).await;
return warp::reply::json(&ans).into_response();
return Json(&ans).into_response();
}
};
@ -560,17 +567,17 @@ async fn handle_home_ns_req(
}
});
warp::reply::with_status("", StatusCode::OK).into_response()
(StatusCode::OK, "").into_response()
} else {
match _handle_home_ns_req(pl).await {
Ok(ans) => {
log_request_response(&bp, b, &ans).await;
warp::reply::json(&ans).into_response()
Json(&ans).into_response()
}
Err(e) => {
let ans = err_to_response(e, &bp);
log_request_response(&bp, b, &ans).await;
warp::reply::json(&ans).into_response()
Json(&ans).into_response()
}
}
}
@ -587,7 +594,7 @@ async fn _handle_home_ns_req(pl: backend::HomeNSReqPayload) -> Result<backend::H
})
}
async fn handle_async_ans(bp: &BasePayload, b: &[u8]) -> Result<http::Response<hyper::Body>> {
async fn handle_async_ans(bp: &BasePayload, b: &[u8]) -> Result<Response> {
let transaction_id = bp.transaction_id;
let key = redis_key(format!("backend:async:{}", transaction_id));
@ -609,7 +616,7 @@ async fn handle_async_ans(bp: &BasePayload, b: &[u8]) -> Result<http::Response<h
.query_async(&mut get_async_redis_conn().await?)
.await?;
Ok(warp::reply().into_response())
Ok((StatusCode::OK, "").into_response())
}
pub async fn get_async_receiver(

View File

@ -0,0 +1,141 @@
use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};
use futures::ready;
use http::{header::CONTENT_TYPE, Request, Response};
use http_body::Body;
use pin_project::pin_project;
use tower::{Layer, Service};
type BoxError = Box<dyn std::error::Error + Send + Sync>;
#[pin_project(project = GrpcMultiplexFutureEnumProj)]
enum GrpcMultiplexFutureEnum<FS, FO> {
Grpc {
#[pin]
future: FS,
},
Other {
#[pin]
future: FO,
},
}
#[pin_project]
pub struct GrpcMultiplexFuture<FS, FO> {
#[pin]
future: GrpcMultiplexFutureEnum<FS, FO>,
}
impl<ResBody, FS, FO, ES, EO> Future for GrpcMultiplexFuture<FS, FO>
where
ResBody: Body,
FS: Future<Output = Result<Response<ResBody>, ES>>,
FO: Future<Output = Result<Response<ResBody>, EO>>,
ES: Into<BoxError> + Send,
EO: Into<BoxError> + Send,
{
type Output = Result<Response<ResBody>, Box<dyn std::error::Error + Send + Sync + 'static>>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.project();
match this.future.project() {
GrpcMultiplexFutureEnumProj::Grpc { future } => future.poll(cx).map_err(Into::into),
GrpcMultiplexFutureEnumProj::Other { future } => future.poll(cx).map_err(Into::into),
}
}
}
#[derive(Debug, Clone)]
pub struct GrpcMultiplexService<S, O> {
grpc: S,
other: O,
grpc_ready: bool,
other_ready: bool,
}
impl<ReqBody, ResBody, S, O> Service<Request<ReqBody>> for GrpcMultiplexService<S, O>
where
ResBody: Body,
S: Service<Request<ReqBody>, Response = Response<ResBody>>,
O: Service<Request<ReqBody>, Response = Response<ResBody>>,
S::Error: Into<BoxError> + Send,
O::Error: Into<BoxError> + Send,
{
type Response = S::Response;
type Error = Box<dyn std::error::Error + Send + Sync + 'static>;
type Future = GrpcMultiplexFuture<S::Future, O::Future>;
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
loop {
match (self.grpc_ready, self.other_ready) {
(true, true) => {
return Ok(()).into();
}
(false, _) => {
ready!(self.grpc.poll_ready(cx)).map_err(Into::into)?;
self.grpc_ready = true;
}
(_, false) => {
ready!(self.other.poll_ready(cx)).map_err(Into::into)?;
self.other_ready = true;
}
}
}
}
fn call(&mut self, request: Request<ReqBody>) -> Self::Future {
assert!(self.grpc_ready);
assert!(self.other_ready);
if is_grpc_request(&request) {
GrpcMultiplexFuture {
future: GrpcMultiplexFutureEnum::Grpc {
future: self.grpc.call(request),
},
}
} else {
GrpcMultiplexFuture {
future: GrpcMultiplexFutureEnum::Other {
future: self.other.call(request),
},
}
}
}
}
#[derive(Debug, Clone)]
pub struct GrpcMultiplexLayer<O> {
other: O,
}
impl<O> GrpcMultiplexLayer<O> {
pub fn new(other: O) -> Self {
Self { other }
}
}
impl<S, O> Layer<S> for GrpcMultiplexLayer<O>
where
O: Clone,
{
type Service = GrpcMultiplexService<S, O>;
fn layer(&self, grpc: S) -> Self::Service {
GrpcMultiplexService {
grpc,
other: self.other.clone(),
grpc_ready: false,
other_ready: false,
}
}
}
fn is_grpc_request<B>(req: &Request<B>) -> bool {
req.headers()
.get(CONTENT_TYPE)
.map(|content_type| content_type.as_bytes())
.filter(|content_type| content_type.starts_with(b"application/grpc"))
.is_some()
}

View File

@ -1,4 +1,3 @@
use std::convert::Infallible;
use std::time::{Duration, Instant};
use std::{
future::Future,
@ -7,23 +6,27 @@ use std::{
};
use anyhow::Result;
use futures::future::{self, Either, TryFutureExt};
use hyper::{service::make_service_fn, Server};
use axum::{response::IntoResponse, routing::get, Router};
use http::{
header::{self, HeaderMap, HeaderValue},
Request, StatusCode, Uri,
};
use pin_project::pin_project;
use prometheus_client::encoding::EncodeLabelSet;
use prometheus_client::metrics::counter::Counter;
use prometheus_client::metrics::family::Family;
use prometheus_client::metrics::histogram::Histogram;
use rust_embed::RustEmbed;
use tokio::{task, try_join};
use tokio::task;
use tokio::try_join;
use tonic::transport::Server as TonicServer;
use tonic::Code;
use tonic_reflection::server::Builder as TonicReflectionBuilder;
use tonic_web::GrpcWebLayer;
use tower::{Service, ServiceBuilder};
use tower::util::ServiceExt;
use tower::Service;
use tower_http::trace::TraceLayer;
use tracing::{error, info};
use warp::{http::header::HeaderValue, path::Tail, reply::Response, Filter, Rejection, Reply};
use chirpstack_api::api::application_service_server::ApplicationServiceServer;
use chirpstack_api::api::device_profile_service_server::DeviceProfileServiceServer;
@ -51,6 +54,7 @@ pub mod device_profile;
pub mod device_profile_template;
pub mod error;
pub mod gateway;
mod grpc_multiplex;
pub mod helpers;
pub mod internal;
pub mod monitoring;
@ -89,210 +93,120 @@ lazy_static! {
};
}
type Error = Box<dyn std::error::Error + Send + Sync + 'static>;
#[derive(RustEmbed)]
#[folder = "../ui/build"]
struct Asset;
type BoxError = Box<dyn std::error::Error + Send + Sync>;
pub async fn setup() -> Result<()> {
let conf = config::get();
let addr = conf.api.bind.parse()?;
let bind = conf.api.bind.parse()?;
info!(bind = %conf.api.bind, "Setting up API interface");
info!(bind = %bind, "Setting up API interface");
// Taken from the tonic hyper_warp_multiplex example:
// https://github.com/hyperium/tonic/blob/master/examples/src/hyper_warp_multiplex/server.rs#L101
let service = make_service_fn(move |_| {
// tonic gRPC service
let tonic_service = TonicServer::builder()
.accept_http1(true)
.layer(GrpcWebLayer::new())
.add_service(
TonicReflectionBuilder::configure()
.register_encoded_file_descriptor_set(chirpstack_api::api::DESCRIPTOR)
.build()
.unwrap(),
)
.add_service(InternalServiceServer::with_interceptor(
internal::Internal::new(
validator::RequestValidator::new(),
conf.api.secret.clone(),
),
auth::auth_interceptor,
))
.add_service(ApplicationServiceServer::with_interceptor(
application::Application::new(validator::RequestValidator::new()),
auth::auth_interceptor,
))
.add_service(DeviceProfileServiceServer::with_interceptor(
device_profile::DeviceProfile::new(validator::RequestValidator::new()),
auth::auth_interceptor,
))
.add_service(DeviceProfileTemplateServiceServer::with_interceptor(
device_profile_template::DeviceProfileTemplate::new(
validator::RequestValidator::new(),
),
auth::auth_interceptor,
))
.add_service(TenantServiceServer::with_interceptor(
tenant::Tenant::new(validator::RequestValidator::new()),
auth::auth_interceptor,
))
.add_service(DeviceServiceServer::with_interceptor(
device::Device::new(validator::RequestValidator::new()),
auth::auth_interceptor,
))
.add_service(UserServiceServer::with_interceptor(
user::User::new(validator::RequestValidator::new()),
auth::auth_interceptor,
))
.add_service(GatewayServiceServer::with_interceptor(
gateway::Gateway::new(validator::RequestValidator::new()),
auth::auth_interceptor,
))
.add_service(MulticastGroupServiceServer::with_interceptor(
multicast::MulticastGroup::new(validator::RequestValidator::new()),
auth::auth_interceptor,
))
.add_service(RelayServiceServer::with_interceptor(
relay::Relay::new(validator::RequestValidator::new()),
auth::auth_interceptor,
))
.into_service();
let mut tonic_service = ServiceBuilder::new()
.layer(
TraceLayer::new_for_grpc()
.make_span_with(|req: &http::Request<hyper::Body>| {
tracing::info_span!(
"gRPC",
uri = %req.uri().path(),
)
})
.on_request(OnRequest {})
.on_response(OnResponse {}),
)
.layer(ApiLogger {})
.service(tonic_service);
let web = Router::new()
.route("/auth/oidc/login", get(oidc::login_handler))
.route("/auth/oidc/callback", get(oidc::callback_handler))
.route("/auth/oauth2/login", get(oauth2::login_handler))
.route("/auth/oauth2/callback", get(oauth2::callback_handler))
.fallback(service_static_handler)
.into_service()
.map_response(|r| r.map(tonic::body::boxed));
// HTTP service
let warp_service = warp::service(
warp::path!("auth" / "oidc" / "login")
.and_then(oidc::login_handler)
.or(warp::path!("auth" / "oidc" / "callback")
.and(warp::query::<oidc::CallbackArgs>())
.and_then(oidc::callback_handler))
.or(warp::path!("auth" / "oauth2" / "login").and_then(oauth2::login_handler))
.or(warp::path!("auth" / "oauth2" / "callback")
.and(warp::query::<oauth2::CallbackArgs>())
.and_then(oauth2::callback_handler))
.or(warp::path::tail().and_then(http_serve)),
);
let mut warp_service = ServiceBuilder::new()
.layer(
TraceLayer::new_for_http()
.make_span_with(|req: &http::Request<hyper::Body>| {
tracing::info_span!(
"http",
method = req.method().as_str(),
uri = %req.uri().path(),
version = ?req.version(),
)
})
.on_request(OnRequest {})
.on_response(OnResponse {}),
)
.service(warp_service);
future::ok::<_, Infallible>(tower::service_fn(
move |req: hyper::Request<hyper::Body>| match req.method() {
&hyper::Method::GET => Either::Left(
warp_service
.call(req)
.map_ok(|res| res.map(EitherBody::Right))
.map_err(Error::from),
),
_ => Either::Right(
tonic_service
.call(req)
.map_ok(|res| res.map(EitherBody::Left))
.map_err(Error::from),
),
},
let grpc = TonicServer::builder()
.accept_http1(true)
.layer(
TraceLayer::new_for_grpc()
.make_span_with(|req: &Request<_>| {
tracing::info_span!(
"gRPC",
uri = %req.uri().path(),
)
})
.on_request(OnRequest {})
.on_response(OnResponse {}),
)
.layer(grpc_multiplex::GrpcMultiplexLayer::new(web))
.layer(ApiLoggerLayer {})
.layer(GrpcWebLayer::new())
.add_service(
TonicReflectionBuilder::configure()
.register_encoded_file_descriptor_set(chirpstack_api::api::DESCRIPTOR)
.build()
.unwrap(),
)
.add_service(InternalServiceServer::with_interceptor(
internal::Internal::new(validator::RequestValidator::new(), conf.api.secret.clone()),
auth::auth_interceptor,
))
});
.add_service(ApplicationServiceServer::with_interceptor(
application::Application::new(validator::RequestValidator::new()),
auth::auth_interceptor,
))
.add_service(DeviceProfileServiceServer::with_interceptor(
device_profile::DeviceProfile::new(validator::RequestValidator::new()),
auth::auth_interceptor,
))
.add_service(DeviceProfileTemplateServiceServer::with_interceptor(
device_profile_template::DeviceProfileTemplate::new(validator::RequestValidator::new()),
auth::auth_interceptor,
))
.add_service(TenantServiceServer::with_interceptor(
tenant::Tenant::new(validator::RequestValidator::new()),
auth::auth_interceptor,
))
.add_service(DeviceServiceServer::with_interceptor(
device::Device::new(validator::RequestValidator::new()),
auth::auth_interceptor,
))
.add_service(UserServiceServer::with_interceptor(
user::User::new(validator::RequestValidator::new()),
auth::auth_interceptor,
))
.add_service(GatewayServiceServer::with_interceptor(
gateway::Gateway::new(validator::RequestValidator::new()),
auth::auth_interceptor,
))
.add_service(MulticastGroupServiceServer::with_interceptor(
multicast::MulticastGroup::new(validator::RequestValidator::new()),
auth::auth_interceptor,
))
.add_service(RelayServiceServer::with_interceptor(
relay::Relay::new(validator::RequestValidator::new()),
auth::auth_interceptor,
));
let backend_handle = tokio::spawn(backend::setup());
let monitoring_handle = tokio::spawn(monitoring::setup());
let api_handle = tokio::spawn(Server::bind(&addr).serve(service));
let grpc_handle = tokio::spawn(grpc.serve(bind));
let _ = try_join!(api_handle, backend_handle, monitoring_handle)?;
let _ = try_join!(grpc_handle, backend_handle, monitoring_handle)?;
Ok(())
}
enum EitherBody<A, B> {
Left(A),
Right(B),
}
impl<A, B> http_body::Body for EitherBody<A, B>
where
A: http_body::Body + Send + Unpin,
B: http_body::Body<Data = A::Data> + Send + Unpin,
A::Error: Into<Error>,
B::Error: Into<Error>,
{
type Data = A::Data;
type Error = Box<dyn std::error::Error + Send + Sync + 'static>;
fn is_end_stream(&self) -> bool {
match self {
EitherBody::Left(b) => b.is_end_stream(),
EitherBody::Right(b) => b.is_end_stream(),
}
}
fn poll_data(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Result<Self::Data, Self::Error>>> {
match self.get_mut() {
EitherBody::Left(b) => Pin::new(b).poll_data(cx).map(map_option_err),
EitherBody::Right(b) => Pin::new(b).poll_data(cx).map(map_option_err),
}
}
fn poll_trailers(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Result<Option<http::HeaderMap>, Self::Error>> {
match self.get_mut() {
EitherBody::Left(b) => Pin::new(b).poll_trailers(cx).map_err(Into::into),
EitherBody::Right(b) => Pin::new(b).poll_trailers(cx).map_err(Into::into),
}
}
}
fn map_option_err<T, U: Into<Error>>(err: Option<Result<T, U>>) -> Option<Result<T, Error>> {
err.map(|e| e.map_err(Into::into))
}
async fn http_serve(path: Tail) -> Result<impl Reply, Rejection> {
let mut path = path.as_str();
async fn service_static_handler(uri: Uri) -> impl IntoResponse {
let mut path = {
let mut chars = uri.path().chars();
chars.next();
chars.as_str()
};
if path.is_empty() {
path = "index.html";
}
let asset = Asset::get(path).ok_or_else(warp::reject::not_found)?;
let mime = mime_guess::from_path(path).first_or_octet_stream();
let mut res = Response::new(asset.data.into());
res.headers_mut().insert(
"content-type",
HeaderValue::from_str(mime.as_ref()).unwrap(),
);
Ok(res)
if let Some(asset) = Asset::get(path) {
let mime = mime_guess::from_path(path).first_or_octet_stream();
let mut headers = HeaderMap::new();
headers.insert(
header::CONTENT_TYPE,
HeaderValue::from_str(mime.as_ref()).unwrap(),
);
(StatusCode::OK, headers, asset.data.into())
} else {
(StatusCode::NOT_FOUND, HeaderMap::new(), vec![])
}
}
#[derive(Debug, Clone)]
@ -320,13 +234,14 @@ struct GrpcLabels {
status_code: String,
}
struct ApiLogger {}
#[derive(Debug, Clone)]
struct ApiLoggerLayer {}
impl<S> tower::Layer<S> for ApiLogger {
impl<S> tower::Layer<S> for ApiLoggerLayer {
type Service = ApiLoggerService<S>;
fn layer(&self, service: S) -> Self::Service {
ApiLoggerService { inner: service }
fn layer(&self, inner: S) -> Self::Service {
ApiLoggerService { inner }
}
}
@ -335,15 +250,15 @@ struct ApiLoggerService<S> {
inner: S,
}
impl<S, ReqBody, ResBody> Service<http::Request<ReqBody>> for ApiLoggerService<S>
impl<ReqBody, ResBody, S> Service<http::Request<ReqBody>> for ApiLoggerService<S>
where
S: Service<http::Request<ReqBody>, Response = http::Response<ResBody>>,
ReqBody: http_body::Body,
ResBody: http_body::Body,
S: Service<http::Request<ReqBody>, Response = http::Response<ResBody>>,
S::Error: Into<BoxError> + Send,
{
type Response = S::Response;
type Error = S::Error;
type Future = ApiLoggerResponseFuture<S::Future>;
type Future = ApiLoggerFuture<S::Future>;
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.inner.poll_ready(cx)
@ -352,10 +267,10 @@ where
fn call(&mut self, request: http::Request<ReqBody>) -> Self::Future {
let uri = request.uri().path().to_string();
let uri_parts: Vec<&str> = uri.split('/').collect();
let response_future = self.inner.call(request);
let future = self.inner.call(request);
let start = Instant::now();
ApiLoggerResponseFuture {
response_future,
ApiLoggerFuture {
future,
start,
service: uri_parts.get(1).map(|v| v.to_string()).unwrap_or_default(),
method: uri_parts.get(2).map(|v| v.to_string()).unwrap_or_default(),
@ -364,25 +279,26 @@ where
}
#[pin_project]
struct ApiLoggerResponseFuture<F> {
struct ApiLoggerFuture<F> {
#[pin]
response_future: F,
future: F,
start: Instant,
service: String,
method: String,
}
impl<F, ResBody, Error> Future for ApiLoggerResponseFuture<F>
impl<ResBody, F, E> Future for ApiLoggerFuture<F>
where
F: Future<Output = Result<http::Response<ResBody>, Error>>,
ResBody: http_body::Body,
F: Future<Output = Result<http::Response<ResBody>, E>>,
E: Into<BoxError> + Send,
{
type Output = Result<http::Response<ResBody>, Error>;
type Output = Result<http::Response<ResBody>, E>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.project();
match this.response_future.poll(cx) {
match this.future.poll(cx) {
Poll::Ready(result) => {
if let Ok(response) = &result {
let status_code: i32 = match response.headers().get("grpc-status") {

View File

@ -1,51 +1,49 @@
use std::convert::Infallible;
use std::net::SocketAddr;
use anyhow::{Context, Result};
use axum::{
response::{IntoResponse, Response},
routing::get,
Router,
};
use diesel_async::RunQueryDsl;
use http::StatusCode;
use tracing::info;
use warp::{http::Response, http::StatusCode, Filter};
use crate::config;
use crate::monitoring::prometheus;
use crate::storage::{get_async_db_conn, get_async_redis_conn};
pub async fn setup() {
pub async fn setup() -> Result<()> {
let conf = config::get();
if conf.monitoring.bind.is_empty() {
return;
return Ok(());
}
let addr: SocketAddr = conf.monitoring.bind.parse().unwrap();
info!(bind = %conf.monitoring.bind, "Setting up monitoring endpoint");
let prom_endpoint = warp::get()
.and(warp::path!("metrics"))
.and_then(prometheus_handler);
let app = Router::new()
.route("/metrics", get(prometheus_handler))
.route("/health", get(health_handler));
let health_endpoint = warp::get()
.and(warp::path!("health"))
.and_then(health_handler);
let routes = prom_endpoint.or(health_endpoint);
warp::serve(routes).run(addr).await;
axum_server::bind(addr)
.serve(app.into_make_service())
.await?;
Ok(())
}
async fn prometheus_handler() -> Result<impl warp::Reply, Infallible> {
async fn prometheus_handler() -> Response {
let body = prometheus::encode_to_string().unwrap_or_default();
Ok(Response::builder().body(body))
body.into_response()
}
async fn health_handler() -> Result<impl warp::Reply, Infallible> {
async fn health_handler() -> Response {
if let Err(e) = _health_handler().await {
return Ok(warp::reply::with_status(
e.to_string(),
StatusCode::SERVICE_UNAVAILABLE,
));
(StatusCode::SERVICE_UNAVAILABLE, e.to_string()).into_response()
} else {
(StatusCode::OK, "".to_string()).into_response()
}
Ok(warp::reply::with_status("OK".to_string(), StatusCode::OK))
}
async fn _health_handler() -> Result<()> {

View File

@ -1,7 +1,10 @@
use std::str::FromStr;
use anyhow::{Context, Result};
use axum::{
extract::Query,
response::{IntoResponse, Redirect, Response},
};
use chrono::Duration;
use http::StatusCode;
use oauth2::basic::BasicClient;
use oauth2::reqwest;
use oauth2::{
@ -11,7 +14,6 @@ use oauth2::{
use reqwest::header::AUTHORIZATION;
use serde::{Deserialize, Serialize};
use tracing::{error, trace};
use warp::{Rejection, Reply};
use crate::config;
use crate::helpers::errors::PrintFullError;
@ -26,7 +28,7 @@ struct ClerkUserinfo {
pub user_id: String,
}
#[derive(Serialize, Deserialize)]
#[derive(Deserialize)]
pub struct CallbackArgs {
pub code: String,
pub state: String,
@ -39,16 +41,12 @@ pub struct User {
pub external_id: String,
}
pub async fn login_handler() -> Result<impl Reply, Rejection> {
pub async fn login_handler() -> Response {
let client = match get_client() {
Ok(v) => v,
Err(e) => {
error!(error = %e.full(), "Get OAuth2 client error");
return Ok(warp::reply::with_status(
"Internal error",
warp::http::StatusCode::INTERNAL_SERVER_ERROR,
)
.into_response());
return (StatusCode::INTERNAL_SERVER_ERROR, "Internal error").into_response();
}
};
@ -64,26 +62,16 @@ pub async fn login_handler() -> Result<impl Reply, Rejection> {
if let Err(e) = store_verifier(&csrf_token, &pkce_verifier).await {
error!(error = %e.full(), "Store verifier error");
return Ok(warp::reply::with_status(
"Internal error",
warp::http::StatusCode::INTERNAL_SERVER_ERROR,
)
.into_response());
return (StatusCode::INTERNAL_SERVER_ERROR, "Internal error").into_response();
}
Ok(
warp::redirect::found(warp::http::Uri::from_str(auth_url.as_str()).unwrap())
.into_response(),
)
Redirect::temporary(auth_url.as_str()).into_response()
}
pub async fn callback_handler(args: CallbackArgs) -> Result<impl Reply, Rejection> {
// warp::redirect does not work with '#'.
Ok(warp::reply::with_header(
warp::http::StatusCode::PERMANENT_REDIRECT,
warp::http::header::LOCATION,
format!("/#/login?code={}&state={}", args.code, args.state),
))
pub async fn callback_handler(args: Query<CallbackArgs>) -> Response {
let args: CallbackArgs = args.0;
Redirect::permanent(&format!("/#/login?code={}&state={}", args.code, args.state))
.into_response()
}
fn get_client() -> Result<Client> {

View File

@ -1,8 +1,12 @@
use std::collections::HashMap;
use std::str::FromStr;
use anyhow::{Context, Result};
use axum::{
extract::Query,
response::{IntoResponse, Redirect, Response},
};
use chrono::Duration;
use http::StatusCode;
use openidconnect::core::{
CoreClient, CoreGenderClaim, CoreIdTokenClaims, CoreIdTokenVerifier, CoreProviderMetadata,
CoreResponseType,
@ -15,7 +19,6 @@ use openidconnect::{
use serde::{Deserialize, Serialize};
use serde_json::Value;
use tracing::{error, trace};
use warp::{Rejection, Reply};
use crate::config;
use crate::helpers::errors::PrintFullError;
@ -40,22 +43,18 @@ pub struct CustomClaims {
impl AdditionalClaims for CustomClaims {}
#[derive(Serialize, Deserialize)]
#[derive(Deserialize)]
pub struct CallbackArgs {
pub code: String,
pub state: String,
}
pub async fn login_handler() -> Result<impl Reply, Rejection> {
pub async fn login_handler() -> Response {
let client = match get_client().await {
Ok(v) => v,
Err(e) => {
error!(error = %e.full(), "Get OIDC client error");
return Ok(warp::reply::with_status(
"Internal error",
warp::http::StatusCode::INTERNAL_SERVER_ERROR,
)
.into_response());
return (StatusCode::INTERNAL_SERVER_ERROR, "Internal error").into_response();
}
};
@ -72,26 +71,16 @@ pub async fn login_handler() -> Result<impl Reply, Rejection> {
if let Err(e) = store_nonce(&csrf_state, &nonce).await {
error!(error = %e.full(), "Store nonce error");
return Ok(warp::reply::with_status(
"Internal error",
warp::http::StatusCode::INTERNAL_SERVER_ERROR,
)
.into_response());
return (StatusCode::INTERNAL_SERVER_ERROR, "Internal error").into_response();
}
Ok(
warp::redirect::found(warp::http::Uri::from_str(auth_url.as_str()).unwrap())
.into_response(),
)
Redirect::temporary(auth_url.as_str()).into_response()
}
pub async fn callback_handler(args: CallbackArgs) -> Result<impl Reply, Rejection> {
// warp::redirect does not work with '#'.
Ok(warp::reply::with_header(
warp::http::StatusCode::PERMANENT_REDIRECT,
warp::http::header::LOCATION,
format!("/#/login?code={}&state={}", args.code, args.state),
))
pub async fn callback_handler(args: Query<CallbackArgs>) -> Response {
let args: CallbackArgs = args.0;
Redirect::permanent(&format!("/#/login?code={}&state={}", args.code, args.state))
.into_response()
}
pub async fn get_user(code: &str, state: &str) -> Result<User> {

View File

@ -2,6 +2,8 @@ use std::fs::File;
use std::io::BufReader;
use anyhow::{Context, Result};
use rustls::pki_types::{CertificateDer, PrivateKeyDer};
use tokio::fs;
// Return root certificates, optionally with the provided ca_file appended.
pub fn get_root_certs(ca_file: Option<String>) -> Result<rustls::RootCertStore> {
@ -22,6 +24,38 @@ pub fn get_root_certs(ca_file: Option<String>) -> Result<rustls::RootCertStore>
Ok(roots)
}
pub async fn load_cert(cert_file: &str) -> Result<Vec<CertificateDer<'static>>> {
let cert_s = fs::read_to_string(cert_file)
.await
.context("Read TLS certificate")?;
let mut cert_b = cert_s.as_bytes();
let certs = rustls_pemfile::certs(&mut cert_b);
let mut out = Vec::new();
for cert in certs {
out.push(cert?.into_owned());
}
Ok(out)
}
pub async fn load_key(key_file: &str) -> Result<PrivateKeyDer<'static>> {
let key_s = fs::read_to_string(key_file)
.await
.context("Read private key")?;
let key_s = private_key_to_pkcs8(&key_s)?;
let mut key_b = key_s.as_bytes();
let mut keys = rustls_pemfile::pkcs8_private_keys(&mut key_b);
if let Some(key) = keys.next() {
match key {
Ok(v) => return Ok(PrivateKeyDer::Pkcs8(v.clone_key())),
Err(e) => {
return Err(anyhow!("Error parsing private key, error: {}", e));
}
}
}
Err(anyhow!("No private key found"))
}
pub fn private_key_to_pkcs8(pem: &str) -> Result<String> {
if pem.contains("RSA PRIVATE KEY") {
use rsa::{

View File

@ -383,7 +383,9 @@ async fn test_sns_uplink() {
let resp =
backend_api::handle_request(Bytes::from(serde_json::to_string(&pr_start_req).unwrap()))
.await;
let resp_b = hyper::body::to_bytes(resp.into_body()).await.unwrap();
let resp_b = axum::body::to_bytes(resp.into_body(), usize::MAX)
.await
.unwrap();
let pr_start_ans: backend::PRStartAnsPayload = serde_json::from_slice(&resp_b).unwrap();
@ -572,7 +574,9 @@ async fn test_sns_roaming_not_allowed() {
let resp =
backend_api::handle_request(Bytes::from(serde_json::to_string(&pr_start_req).unwrap()))
.await;
let resp_b = hyper::body::to_bytes(resp.into_body()).await.unwrap();
let resp_b = axum::body::to_bytes(resp.into_body(), usize::MAX)
.await
.unwrap();
let pr_start_ans: backend::PRStartAnsPayload = serde_json::from_slice(&resp_b).unwrap();
assert_eq!(
@ -683,7 +687,9 @@ async fn test_sns_dev_not_found() {
let resp =
backend_api::handle_request(Bytes::from(serde_json::to_string(&pr_start_req).unwrap()))
.await;
let resp_b = hyper::body::to_bytes(resp.into_body()).await.unwrap();
let resp_b = axum::body::to_bytes(resp.into_body(), usize::MAX)
.await
.unwrap();
let pr_start_ans: backend::PRStartAnsPayload = serde_json::from_slice(&resp_b).unwrap();

View File

@ -379,7 +379,9 @@ async fn test_sns() {
let resp =
backend_api::handle_request(Bytes::from(serde_json::to_string(&pr_start_req).unwrap()))
.await;
let resp_b = hyper::body::to_bytes(resp.into_body()).await.unwrap();
let resp_b = axum::body::to_bytes(resp.into_body(), usize::MAX)
.await
.unwrap();
let pr_start_ans: backend::PRStartAnsPayload = serde_json::from_slice(&resp_b).unwrap();
@ -562,7 +564,9 @@ async fn test_sns_roaming_not_allowed() {
let resp =
backend_api::handle_request(Bytes::from(serde_json::to_string(&pr_start_req).unwrap()))
.await;
let resp_b = hyper::body::to_bytes(resp.into_body()).await.unwrap();
let resp_b = axum::body::to_bytes(resp.into_body(), usize::MAX)
.await
.unwrap();
let pr_start_ans: backend::PRStartAnsPayload = serde_json::from_slice(&resp_b).unwrap();