Refactor MQTT integration to use rumqttc.

This commit is contained in:
Orne Brocaar 2023-11-29 12:00:42 +00:00
parent 345d0d8462
commit 17f0d8c495
6 changed files with 278 additions and 163 deletions

22
Cargo.lock generated
View File

@ -798,6 +798,7 @@ dependencies = [
"regex",
"reqwest",
"rquickjs",
"rumqttc",
"rust-embed",
"rustls",
"rustls-native-certs",
@ -1593,6 +1594,8 @@ version = "0.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "55ac459de2512911e4b674ce33cf20befaba382d05b62b008afc1c8b57cbf181"
dependencies = [
"futures-core",
"futures-sink",
"spin 0.9.8",
]
@ -3693,6 +3696,25 @@ dependencies = [
"zeroize",
]
[[package]]
name = "rumqttc"
version = "0.23.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8d8941c6791801b667d52bfe9ff4fc7c968d4f3f9ae8ae7abdaaa1c966feafc8"
dependencies = [
"bytes",
"flume 0.11.0",
"futures-util",
"log",
"rustls-native-certs",
"rustls-pemfile",
"rustls-webpki",
"thiserror",
"tokio",
"tokio-rustls",
"url",
]
[[package]]
name = "rust-embed"
version = "8.0.0"

View File

@ -110,6 +110,7 @@ openidconnect = { version = "3.3", features = ["accept-rfc3339-timestamps"] }
# MQTT
paho-mqtt = { version = "0.12", features = ["ssl"] }
rumqttc = { version = "0.23", features = ["url"] }
hex = "0.4"
# Codecs

View File

@ -1 +1,2 @@
pub mod errors;
pub mod tls;

View File

@ -0,0 +1,49 @@
use std::fs::File;
use std::io::BufReader;
use anyhow::{Context, Result};
// Return root certificates, optionally with the provided ca_file appended.
pub fn get_root_certs(ca_file: Option<String>) -> Result<rustls::RootCertStore> {
let mut roots = rustls::RootCertStore::empty();
let certs = rustls_native_certs::load_native_certs()?;
let certs: Vec<_> = certs.into_iter().map(|cert| cert.0).collect();
roots.add_parsable_certificates(&certs);
if let Some(ca_file) = &ca_file {
let f = File::open(ca_file).context("Open CA certificate")?;
let mut reader = BufReader::new(f);
let certs = rustls_pemfile::certs(&mut reader)?;
for cert in certs
.into_iter()
.map(rustls::Certificate)
.collect::<Vec<_>>()
{
roots.add(&cert)?;
}
}
Ok(roots)
}
pub fn load_cert(cert_file: &str) -> Result<Vec<rustls::Certificate>> {
let f = File::open(cert_file).context("Open TLS certificate")?;
let mut reader = BufReader::new(f);
let certs = rustls_pemfile::certs(&mut reader)?;
let certs = certs
.into_iter()
.map(rustls::Certificate)
.collect::<Vec<_>>();
Ok(certs)
}
pub fn load_key(key_file: &str) -> Result<rustls::PrivateKey> {
let f = File::open(key_file).context("Open private key")?;
let mut reader = BufReader::new(f);
let mut keys = rustls_pemfile::pkcs8_private_keys(&mut reader)?;
match keys.len() {
0 => Err(anyhow!("No private key found")),
1 => Ok(rustls::PrivateKey(keys.remove(0))),
_ => Err(anyhow!("More than one private key found")),
}
}

View File

@ -1,29 +1,32 @@
use std::collections::HashMap;
use std::env::temp_dir;
use std::io::Cursor;
use std::time::Duration;
use anyhow::{Context, Result};
use anyhow::Result;
use async_trait::async_trait;
use futures::stream::StreamExt;
use handlebars::Handlebars;
use paho_mqtt as mqtt;
use prost::Message;
use rand::Rng;
use regex::Regex;
use rumqttc::tokio_rustls::rustls;
use rumqttc::v5::mqttbytes::v5::{ConnectReturnCode, Publish};
use rumqttc::v5::{mqttbytes::QoS, AsyncClient, Event, Incoming, MqttOptions};
use rumqttc::Transport;
use serde::Serialize;
use tokio::sync::mpsc;
use tracing::{error, info};
use tokio::time::sleep;
use tracing::{error, info, trace, warn};
use super::Integration as IntegrationTrait;
use crate::config::MqttIntegration as Config;
use crate::helpers::tls::{get_root_certs, load_cert, load_key};
use chirpstack_api::integration;
pub struct Integration<'a> {
client: mqtt::AsyncClient,
client: AsyncClient,
templates: Handlebars<'a>,
json: bool,
qos: usize,
qos: QoS,
command_regex: Regex,
}
@ -70,76 +73,57 @@ impl<'a> Integration<'a> {
conf.client_id.clone()
};
// Create subscribe channel
// This is needed as we can't subscribe within the set_connected_callback as this would
// block the callback (we want to wait for success or error), which would create a
// deadlock. We need to re-subscribe on (re)connect to be sure we have a subscription. Even
// Get QoS
let qos = match conf.qos {
0 => QoS::AtMostOnce,
1 => QoS::AtLeastOnce,
2 => QoS::ExactlyOnce,
_ => return Err(anyhow!("Invalid QoS: {}", conf.qos)),
};
// Create connect channel
// We need to re-subscribe on (re)connect to be sure we have a subscription. Even
// in case of a persistent MQTT session, there is no guarantee that the MQTT persisted the
// session and that a re-connect would recover the subscription.
let (subscribe_tx, mut subscribe_rx) = mpsc::channel(10);
let (connect_tx, mut connect_rx) = mpsc::channel(1);
// create client
let create_opts = mqtt::CreateOptionsBuilder::new()
.server_uri(&conf.server)
.client_id(&client_id)
.persistence(mqtt::create_options::PersistenceType::FilePath(temp_dir()))
.finalize();
let mut client = mqtt::AsyncClient::new(create_opts).context("Create MQTT client")?;
client.set_connected_callback(move |_client| {
info!("Connected to MQTT broker");
if let Err(e) = subscribe_tx.try_send(()) {
error!(error = %e, "Send to subscribe channel error");
}
});
client.set_connection_lost_callback(|_client| {
error!("MQTT connection to broker lost");
});
// Create client
let mut mqtt_opts =
MqttOptions::parse_url(format!("{}?client_id={}", conf.server, client_id))?;
mqtt_opts.set_clean_start(conf.clean_session);
mqtt_opts.set_keep_alive(conf.keep_alive_interval);
if !conf.username.is_empty() || !conf.password.is_empty() {
mqtt_opts.set_credentials(&conf.username, &conf.password);
}
// connection options
let mut conn_opts_b = mqtt::ConnectOptionsBuilder::new();
conn_opts_b.automatic_reconnect(Duration::from_secs(1), Duration::from_secs(30));
conn_opts_b.clean_session(conf.clean_session);
conn_opts_b.keep_alive_interval(conf.keep_alive_interval);
if !conf.username.is_empty() {
conn_opts_b.user_name(&conf.username);
}
if !conf.password.is_empty() {
conn_opts_b.password(&conf.password);
}
if !conf.ca_cert.is_empty() || !conf.tls_cert.is_empty() || !conf.tls_key.is_empty() {
info!(
ca_cert = %conf.ca_cert,
tls_cert = %conf.tls_cert,
tls_key = %conf.tls_key,
"Configuring connection with TLS certificate"
"Configuring client with TLS certificate, ca_cert: {}, tls_cert: {}, tls_key: {}",
conf.ca_cert, conf.tls_cert, conf.tls_key
);
let mut ssl_opts_b = mqtt::SslOptionsBuilder::new();
let root_certs = get_root_certs(if conf.ca_cert.is_empty() {
None
} else {
Some(conf.ca_cert.clone())
})?;
if !conf.ca_cert.is_empty() {
ssl_opts_b
.trust_store(&conf.ca_cert)
.context("Failed to set gateway ca_cert")?;
}
let client_conf = if conf.tls_cert.is_empty() && conf.tls_key.is_empty() {
rustls::ClientConfig::builder()
.with_safe_defaults()
.with_root_certificates(root_certs.clone())
.with_no_client_auth()
} else {
rustls::ClientConfig::builder()
.with_safe_defaults()
.with_root_certificates(root_certs.clone())
.with_client_auth_cert(load_cert(&conf.tls_cert)?, load_key(&conf.tls_key)?)?
};
if !conf.tls_cert.is_empty() {
ssl_opts_b
.key_store(&conf.tls_cert)
.context("Failed to set gateway tls_cert")?;
}
if !conf.tls_key.is_empty() {
ssl_opts_b
.private_key(&conf.tls_key)
.context("Failed to set gateway tls_key")?;
}
conn_opts_b.ssl_options(ssl_opts_b.finalize());
mqtt_opts.set_transport(Transport::tls_with_config(client_conf.into()));
}
let conn_opts = conn_opts_b.finalize();
// get message stream
let mut stream = client.get_stream(None);
let (client, mut eventloop) = AsyncClient::new(mqtt_opts, 100);
let i = Integration {
command_regex: Regex::new(&templates.render(
@ -150,7 +134,7 @@ impl<'a> Integration<'a> {
command: r"(?P<command>[\w]+)".to_string(),
},
)?)?,
qos: conf.qos,
qos: qos,
json: conf.json,
client,
templates,
@ -158,54 +142,77 @@ impl<'a> Integration<'a> {
// connect
info!(server_uri = %conf.server, client_id = %client_id, clean_session = conf.clean_session, "Connecting to MQTT broker");
i.client
.connect(conn_opts)
.await
.context("Connect to MQTT broker")?;
// Command consume loop.
// (Re)subscribe loop
tokio::spawn({
let command_regex = i.command_regex.clone();
let client = i.client.clone();
let qos = i.qos;
async move {
info!("Starting MQTT consumer loop");
while let Some(msg_opt) = stream.next().await {
if let Some(msg) = msg_opt {
let caps = match command_regex.captures(msg.topic()) {
Some(v) => v,
None => {
error!(topic = %msg.topic(), "Error parsing command topic (regex captures returned None)");
continue;
}
};
if caps.len() != 4 {
error!(topic = %msg.topic(), "Parsing command topic returned invalid match count");
continue;
}
message_callback(
caps.get(1).map_or("", |m| m.as_str()).to_string(),
caps.get(2).map_or("", |m| m.as_str()).to_string(),
caps.get(3).map_or("", |m| m.as_str()).to_string(),
i.json,
msg,
)
.await;
while connect_rx.recv().await.is_some() {
info!(command_topic = %command_topic, "Subscribing to command topic");
if let Err(e) = client.subscribe(&command_topic, qos).await {
error!(error = %e, "Subscribe to command topic error");
}
}
}
});
// (Re)subscribe loop.
// Eventloop
tokio::spawn({
let client = i.client.clone();
let qos = conf.qos as i32;
let command_regex = i.command_regex.clone();
let json = i.json;
async move {
while subscribe_rx.recv().await.is_some() {
info!(command_topic = %command_topic, "Subscribing to command topic");
if let Err(e) = client.subscribe(&command_topic, qos).await {
error!(error = %e, "MQTT subscribe error");
info!("Starting MQTT event loop");
loop {
match eventloop.poll().await {
Ok(v) => {
trace!(event = ?v, "MQTT event");
match v {
Event::Incoming(Incoming::Publish(p)) => {
let topic = String::from_utf8_lossy(&p.topic);
let caps = match command_regex.captures(&topic) {
Some(v) => v,
None => {
warn!(topic = %topic, "Error parsing command topic (regex captures returned None");
continue;
}
};
if caps.len() != 4 {
warn!(topic = %topic, "Parsing command topic returned invalid match count");
continue;
}
message_callback(
caps.get(1).map_or("", |m| m.as_str()).to_string(),
caps.get(2).map_or("", |m| m.as_str()).to_string(),
caps.get(3).map_or("", |m| m.as_str()).to_string(),
json,
p,
)
.await;
}
Event::Incoming(Incoming::ConnAck(v)) => {
if v.code == ConnectReturnCode::Success {
if let Err(e) = connect_tx.try_send(()) {
error!(error = %e, "Send to subscribe channel error");
}
} else {
error!(code = ?v.code, "Connection error");
sleep(Duration::from_secs(1)).await
}
}
_ => {}
}
}
Err(e) => {
error!(error = %e, "MQTT error");
sleep(Duration::from_secs(1)).await
}
}
}
}
@ -226,10 +233,9 @@ impl<'a> Integration<'a> {
)?)
}
async fn publish_event(&self, topic: &str, b: &[u8]) -> Result<()> {
async fn publish_event(&self, topic: &str, b: Vec<u8>) -> Result<()> {
info!(topic = %topic, "Publishing event");
let msg = mqtt::Message::new(topic, b, self.qos as i32);
self.client.publish(msg).await?;
self.client.publish(topic, self.qos, false, b).await?;
Ok(())
}
}
@ -252,7 +258,7 @@ impl IntegrationTrait for Integration<'_> {
false => pl.encode_to_vec(),
};
self.publish_event(&topic, &b).await
self.publish_event(&topic, b).await
}
async fn join_event(
@ -271,7 +277,7 @@ impl IntegrationTrait for Integration<'_> {
false => pl.encode_to_vec(),
};
self.publish_event(&topic, &b).await
self.publish_event(&topic, b).await
}
async fn ack_event(
@ -290,7 +296,7 @@ impl IntegrationTrait for Integration<'_> {
false => pl.encode_to_vec(),
};
self.publish_event(&topic, &b).await
self.publish_event(&topic, b).await
}
async fn txack_event(
@ -309,7 +315,7 @@ impl IntegrationTrait for Integration<'_> {
false => pl.encode_to_vec(),
};
self.publish_event(&topic, &b).await
self.publish_event(&topic, b).await
}
async fn log_event(
@ -328,7 +334,7 @@ impl IntegrationTrait for Integration<'_> {
false => pl.encode_to_vec(),
};
self.publish_event(&topic, &b).await
self.publish_event(&topic, b).await
}
async fn status_event(
@ -347,7 +353,7 @@ impl IntegrationTrait for Integration<'_> {
false => pl.encode_to_vec(),
};
self.publish_event(&topic, &b).await
self.publish_event(&topic, b).await
}
async fn location_event(
@ -367,7 +373,7 @@ impl IntegrationTrait for Integration<'_> {
false => pl.encode_to_vec(),
};
self.publish_event(&topic, &b).await
self.publish_event(&topic, b).await
}
async fn integration_event(
@ -387,7 +393,7 @@ impl IntegrationTrait for Integration<'_> {
false => pl.encode_to_vec(),
};
self.publish_event(&topic, &b).await
self.publish_event(&topic, b).await
}
}
@ -396,20 +402,18 @@ async fn message_callback(
dev_eui: String,
command: String,
json: bool,
msg: mqtt::Message,
p: Publish,
) {
let topic = msg.topic();
let qos = msg.qos();
let b = msg.payload();
let topic = String::from_utf8_lossy(&p.topic);
info!(topic = topic, qos = qos, "Command received for device");
info!(topic = %topic, qos = ?p.qos, "Command received for device");
let err = || -> Result<()> {
match command.as_ref() {
"down" => {
let cmd: integration::DownlinkCommand = match json {
true => serde_json::from_slice(b)?,
false => integration::DownlinkCommand::decode(&mut Cursor::new(b))?,
true => serde_json::from_slice(&p.payload)?,
false => integration::DownlinkCommand::decode(&mut Cursor::new(&p.payload))?,
};
if dev_eui != cmd.dev_eui {
return Err(anyhow!(
@ -430,9 +434,9 @@ async fn message_callback(
.err();
if err.is_some() {
error!(
topic = topic,
qos = qos,
warn!(
topic = %topic,
qos = ?p.qos,
"Processing command error: {}",
err.as_ref().unwrap()
);
@ -447,9 +451,8 @@ pub mod test {
use crate::config::MqttIntegration;
use crate::storage::{application, device, device_profile, device_queue, tenant};
use crate::test;
use futures::stream::StreamExt;
use lrwn::EUI64;
use paho_mqtt as mqtt;
use tokio::sync::mpsc;
use tokio::time::{sleep, Duration};
use uuid::Uuid;
@ -498,24 +501,38 @@ pub mod test {
};
let i = Integration::new(&conf).await.unwrap();
let create_opts = mqtt::CreateOptionsBuilder::new()
.server_uri(&conf.server)
.finalize();
let mut client = mqtt::AsyncClient::new(create_opts).unwrap();
let conn_opts = mqtt::ConnectOptionsBuilder::new()
.clean_session(true)
.finalize();
let mut stream = client.get_stream(None);
client.connect(conn_opts).await.unwrap();
let mut mqtt_opts =
MqttOptions::parse_url(format!("{}?client_id=chirpstack_test", &conf.server)).unwrap();
mqtt_opts.set_clean_start(true);
let (client, mut eventloop) = AsyncClient::new(mqtt_opts, 100);
let (mqtt_tx, mut mqtt_rx) = mpsc::channel(100);
tokio::spawn({
async move {
loop {
match eventloop.poll().await {
Ok(v) => match v {
Event::Incoming(Incoming::Publish(p)) => mqtt_tx.send(p).await.unwrap(),
_ => {}
},
Err(_) => {
break;
}
}
}
}
});
client
.subscribe(
"application/00000000-0000-0000-0000-000000000000/device/+/event/+",
mqtt::QOS_0,
QoS::AtLeastOnce,
)
.await
.unwrap();
sleep(Duration::from_millis(100)).await;
// uplink event
let pl = integration::UplinkEvent {
device_info: Some(integration::DeviceInfo {
@ -526,12 +543,15 @@ pub mod test {
..Default::default()
};
i.uplink_event(&HashMap::new(), &pl).await.unwrap();
let msg = stream.next().await.unwrap().unwrap();
let msg = mqtt_rx.recv().await.unwrap();
assert_eq!(
"application/00000000-0000-0000-0000-000000000000/device/0102030405060708/event/up",
msg.topic()
String::from_utf8(msg.topic.to_vec()).unwrap()
);
assert_eq!(
serde_json::to_string(&pl).unwrap(),
String::from_utf8(msg.payload.to_vec()).unwrap()
);
assert_eq!(serde_json::to_string(&pl).unwrap(), msg.payload_str());
// join event
let pl = integration::JoinEvent {
@ -543,12 +563,15 @@ pub mod test {
..Default::default()
};
i.join_event(&HashMap::new(), &pl).await.unwrap();
let msg = stream.next().await.unwrap().unwrap();
let msg = mqtt_rx.recv().await.unwrap();
assert_eq!(
"application/00000000-0000-0000-0000-000000000000/device/0102030405060708/event/join",
msg.topic()
String::from_utf8(msg.topic.to_vec()).unwrap()
);
assert_eq!(
serde_json::to_string(&pl).unwrap(),
String::from_utf8(msg.payload.to_vec()).unwrap()
);
assert_eq!(serde_json::to_string(&pl).unwrap(), msg.payload_str());
// ack event
let pl = integration::AckEvent {
@ -560,12 +583,15 @@ pub mod test {
..Default::default()
};
i.ack_event(&HashMap::new(), &pl).await.unwrap();
let msg = stream.next().await.unwrap().unwrap();
let msg = mqtt_rx.recv().await.unwrap();
assert_eq!(
"application/00000000-0000-0000-0000-000000000000/device/0102030405060708/event/ack",
msg.topic()
String::from_utf8(msg.topic.to_vec()).unwrap()
);
assert_eq!(
serde_json::to_string(&pl).unwrap(),
String::from_utf8(msg.payload.to_vec()).unwrap()
);
assert_eq!(serde_json::to_string(&pl).unwrap(), msg.payload_str());
// txack event
let pl = integration::TxAckEvent {
@ -577,12 +603,15 @@ pub mod test {
..Default::default()
};
i.txack_event(&HashMap::new(), &pl).await.unwrap();
let msg = stream.next().await.unwrap().unwrap();
let msg = mqtt_rx.recv().await.unwrap();
assert_eq!(
"application/00000000-0000-0000-0000-000000000000/device/0102030405060708/event/txack",
msg.topic()
String::from_utf8(msg.topic.to_vec()).unwrap()
);
assert_eq!(
serde_json::to_string(&pl).unwrap(),
String::from_utf8(msg.payload.to_vec()).unwrap()
);
assert_eq!(serde_json::to_string(&pl).unwrap(), msg.payload_str());
// log event
let pl = integration::LogEvent {
@ -594,12 +623,15 @@ pub mod test {
..Default::default()
};
i.log_event(&HashMap::new(), &pl).await.unwrap();
let msg = stream.next().await.unwrap().unwrap();
let msg = mqtt_rx.recv().await.unwrap();
assert_eq!(
"application/00000000-0000-0000-0000-000000000000/device/0102030405060708/event/log",
msg.topic()
String::from_utf8(msg.topic.to_vec()).unwrap()
);
assert_eq!(
serde_json::to_string(&pl).unwrap(),
String::from_utf8(msg.payload.to_vec()).unwrap()
);
assert_eq!(serde_json::to_string(&pl).unwrap(), msg.payload_str());
// status event
let pl = integration::StatusEvent {
@ -611,12 +643,15 @@ pub mod test {
..Default::default()
};
i.status_event(&HashMap::new(), &pl).await.unwrap();
let msg = stream.next().await.unwrap().unwrap();
let msg = mqtt_rx.recv().await.unwrap();
assert_eq!(
"application/00000000-0000-0000-0000-000000000000/device/0102030405060708/event/status",
msg.topic()
String::from_utf8(msg.topic.to_vec()).unwrap()
);
assert_eq!(
serde_json::to_string(&pl).unwrap(),
String::from_utf8(msg.payload.to_vec()).unwrap()
);
assert_eq!(serde_json::to_string(&pl).unwrap(), msg.payload_str());
// location event
let pl = integration::LocationEvent {
@ -628,12 +663,15 @@ pub mod test {
..Default::default()
};
i.location_event(&HashMap::new(), &pl).await.unwrap();
let msg = stream.next().await.unwrap().unwrap();
let msg = mqtt_rx.recv().await.unwrap();
assert_eq!(
"application/00000000-0000-0000-0000-000000000000/device/0102030405060708/event/location",
msg.topic()
String::from_utf8(msg.topic.to_vec()).unwrap()
);
assert_eq!(
serde_json::to_string(&pl).unwrap(),
String::from_utf8(msg.payload.to_vec()).unwrap()
);
assert_eq!(serde_json::to_string(&pl).unwrap(), msg.payload_str());
// integration event
let pl = integration::IntegrationEvent {
@ -645,12 +683,15 @@ pub mod test {
..Default::default()
};
i.integration_event(&HashMap::new(), &pl).await.unwrap();
let msg = stream.next().await.unwrap().unwrap();
let msg = mqtt_rx.recv().await.unwrap();
assert_eq!(
"application/00000000-0000-0000-0000-000000000000/device/0102030405060708/event/integration",
msg.topic()
String::from_utf8(msg.topic.to_vec()).unwrap()
);
assert_eq!(
serde_json::to_string(&pl).unwrap(),
String::from_utf8(msg.payload.to_vec()).unwrap()
);
assert_eq!(serde_json::to_string(&pl).unwrap(), msg.payload_str());
// downlink command
let down_cmd = integration::DownlinkCommand {
@ -663,11 +704,12 @@ pub mod test {
};
let down_cmd_json = serde_json::to_string(&down_cmd).unwrap();
client
.publish(mqtt::Message::new(
.publish(
format!("application/{}/device/{}/command/down", app.id, dev.dev_eui),
QoS::AtLeastOnce,
false,
down_cmd_json,
mqtt::QOS_0,
))
)
.await
.unwrap();

View File

@ -466,7 +466,7 @@ async fn test_sns_roaming_not_allowed() {
.await
.unwrap();
let dev = device::create(device::Device {
let _dev = device::create(device::Device {
name: "device".into(),
application_id: app.id.clone(),
device_profile_id: dp.id.clone(),