Refactor JS join_eui to join_eui_prefix.

This makes it possible to use a range of JoinEUIs per Join Server.
Use-cases are either Join Servers using a JoinEUI range or the
configuration of a "catch-all" Join Server prefix ("0000000000000000/0").
This commit is contained in:
Orne Brocaar
2023-10-30 15:44:34 +00:00
parent 091909e8ea
commit 09e1ae0263
13 changed files with 275 additions and 77 deletions

View File

@ -149,11 +149,12 @@ impl Client {
pub async fn join_req( pub async fn join_req(
&self, &self,
receiver_id: Vec<u8>,
pl: &mut JoinReqPayload, pl: &mut JoinReqPayload,
async_resp: Option<Receiver<Vec<u8>>>, async_resp: Option<Receiver<Vec<u8>>>,
) -> Result<JoinAnsPayload> { ) -> Result<JoinAnsPayload> {
pl.base.sender_id = self.config.sender_id.clone(); pl.base.sender_id = self.config.sender_id.clone();
pl.base.receiver_id = self.config.receiver_id.clone(); pl.base.receiver_id = receiver_id;
pl.base.message_type = MessageType::JoinReq; pl.base.message_type = MessageType::JoinReq;
let mut ans: JoinAnsPayload = Default::default(); let mut ans: JoinAnsPayload = Default::default();
@ -231,11 +232,12 @@ impl Client {
pub async fn home_ns_req( pub async fn home_ns_req(
&self, &self,
receiver_id: Vec<u8>,
pl: &mut HomeNSReqPayload, pl: &mut HomeNSReqPayload,
async_resp: Option<Receiver<Vec<u8>>>, async_resp: Option<Receiver<Vec<u8>>>,
) -> Result<HomeNSAnsPayload> { ) -> Result<HomeNSAnsPayload> {
pl.base.sender_id = self.config.sender_id.clone(); pl.base.sender_id = self.config.sender_id.clone();
pl.base.receiver_id = self.config.receiver_id.clone(); pl.base.receiver_id = receiver_id;
pl.base.message_type = MessageType::HomeNSReq; pl.base.message_type = MessageType::HomeNSReq;
let mut ans: HomeNSAnsPayload = Default::default(); let mut ans: HomeNSAnsPayload = Default::default();
@ -1176,8 +1178,7 @@ pub mod test {
let server = MockServer::start(); let server = MockServer::start();
let c = Client::new(ClientConfig { let c = Client::new(ClientConfig {
sender_id: "010203".into(), sender_id: vec![1, 2, 3],
receiver_id: "0102030405060708".into(),
server: server.url("/"), server: server.url("/"),
async_timeout: Duration::from_secs(1), async_timeout: Duration::from_secs(1),
..Default::default() ..Default::default()
@ -1186,8 +1187,8 @@ pub mod test {
let mut req = HomeNSReqPayload { let mut req = HomeNSReqPayload {
base: BasePayload { base: BasePayload {
sender_id: "010203".into(), sender_id: vec![1, 2, 3],
receiver_id: "0102030405060708".into(), receiver_id: vec![1, 2, 3, 4, 5, 6, 7, 8],
message_type: MessageType::HomeNSReq, message_type: MessageType::HomeNSReq,
transaction_id: 1234, transaction_id: 1234,
..Default::default() ..Default::default()
@ -1198,8 +1199,8 @@ pub mod test {
let ans = HomeNSAnsPayload { let ans = HomeNSAnsPayload {
base: BasePayloadResult { base: BasePayloadResult {
base: BasePayload { base: BasePayload {
sender_id: "0102030405060708".into(), sender_id: vec![1, 2, 3, 4, 5, 6, 7, 8],
receiver_id: "010203".into(), receiver_id: vec![1, 2, 3],
message_type: MessageType::HomeNSAns, message_type: MessageType::HomeNSAns,
transaction_id: 1234, transaction_id: 1234,
..Default::default() ..Default::default()
@ -1222,14 +1223,19 @@ pub mod test {
// OK // OK
let (tx, rx) = oneshot::channel(); let (tx, rx) = oneshot::channel();
tx.send(serde_json::to_vec(&ans).unwrap()).unwrap(); tx.send(serde_json::to_vec(&ans).unwrap()).unwrap();
let resp = c.home_ns_req(&mut req, Some(rx)).await.unwrap(); let resp = c
.home_ns_req(vec![1, 2, 3, 4, 5, 6, 7, 8], &mut req, Some(rx))
.await
.unwrap();
mock.assert(); mock.assert();
mock.delete(); mock.delete();
assert_eq!(resp, ans); assert_eq!(resp, ans);
// Timeout // Timeout
let (_tx, rx) = oneshot::channel(); let (_tx, rx) = oneshot::channel();
let resp = c.home_ns_req(&mut req, Some(rx)).await; let resp = c
.home_ns_req(vec![1, 2, 3, 4, 5, 6, 7, 8], &mut req, Some(rx))
.await;
assert!(resp.is_err()); assert!(resp.is_err());
} }
@ -1238,8 +1244,7 @@ pub mod test {
let server = MockServer::start(); let server = MockServer::start();
let c = Client::new(ClientConfig { let c = Client::new(ClientConfig {
sender_id: "010203".into(), sender_id: vec![1, 2, 3],
receiver_id: "0102030405060708".into(),
server: server.url("/"), server: server.url("/"),
async_timeout: Duration::from_secs(1), async_timeout: Duration::from_secs(1),
..Default::default() ..Default::default()
@ -1248,8 +1253,8 @@ pub mod test {
let mut req = HomeNSReqPayload { let mut req = HomeNSReqPayload {
base: BasePayload { base: BasePayload {
sender_id: "010203".into(), sender_id: vec![1, 2, 3],
receiver_id: "0102030405060708".into(), receiver_id: vec![1, 2, 3, 4, 5, 6, 7, 8],
message_type: MessageType::HomeNSReq, message_type: MessageType::HomeNSReq,
transaction_id: 1234, transaction_id: 1234,
..Default::default() ..Default::default()
@ -1260,8 +1265,8 @@ pub mod test {
let ans = HomeNSAnsPayload { let ans = HomeNSAnsPayload {
base: BasePayloadResult { base: BasePayloadResult {
base: BasePayload { base: BasePayload {
sender_id: "0102030405060708".into(), sender_id: vec![1, 2, 3, 4, 5, 6, 7, 8],
receiver_id: "010203".into(), receiver_id: vec![1, 2, 3],
message_type: MessageType::HomeNSAns, message_type: MessageType::HomeNSAns,
transaction_id: 1234, transaction_id: 1234,
..Default::default() ..Default::default()
@ -1284,14 +1289,19 @@ pub mod test {
// OK // OK
let (tx, rx) = oneshot::channel(); let (tx, rx) = oneshot::channel();
tx.send(serde_json::to_vec(&ans).unwrap()).unwrap(); tx.send(serde_json::to_vec(&ans).unwrap()).unwrap();
let resp = c.home_ns_req(&mut req, Some(rx)).await.unwrap(); let resp = c
.home_ns_req(vec![1, 2, 3, 4, 5, 6, 7, 8], &mut req, Some(rx))
.await
.unwrap();
mock.assert(); mock.assert();
mock.delete(); mock.delete();
assert_eq!(resp, ans); assert_eq!(resp, ans);
// Timeout // Timeout
let (_tx, rx) = oneshot::channel(); let (_tx, rx) = oneshot::channel();
let resp = c.home_ns_req(&mut req, Some(rx)).await; let resp = c
.home_ns_req(vec![1, 2, 3, 4, 5, 6, 7, 8], &mut req, Some(rx))
.await;
assert!(resp.is_err()); assert!(resp.is_err());
} }
@ -1300,8 +1310,7 @@ pub mod test {
let server = MockServer::start(); let server = MockServer::start();
let c = Client::new(ClientConfig { let c = Client::new(ClientConfig {
sender_id: "010203".into(), sender_id: vec![1, 2, 3],
receiver_id: "0102030405060708".into(),
server: server.url("/"), server: server.url("/"),
..Default::default() ..Default::default()
}) })
@ -1309,8 +1318,8 @@ pub mod test {
let mut req = HomeNSReqPayload { let mut req = HomeNSReqPayload {
base: BasePayload { base: BasePayload {
sender_id: "010203".into(), sender_id: vec![1, 2, 3],
receiver_id: "0102030405060708".into(), receiver_id: vec![1, 2, 3, 4, 5, 6, 7, 8],
message_type: MessageType::HomeNSReq, message_type: MessageType::HomeNSReq,
transaction_id: 1234, transaction_id: 1234,
..Default::default() ..Default::default()
@ -1321,8 +1330,8 @@ pub mod test {
let ans = HomeNSAnsPayload { let ans = HomeNSAnsPayload {
base: BasePayloadResult { base: BasePayloadResult {
base: BasePayload { base: BasePayload {
sender_id: "0102030405060708".into(), sender_id: vec![1, 2, 3, 4, 5, 6, 7, 8],
receiver_id: "010203".into(), receiver_id: vec![1, 2, 3],
message_type: MessageType::HomeNSAns, message_type: MessageType::HomeNSAns,
transaction_id: 1234, transaction_id: 1234,
..Default::default() ..Default::default()
@ -1342,7 +1351,10 @@ pub mod test {
.body(serde_json::to_string(&req).unwrap()); .body(serde_json::to_string(&req).unwrap());
then.body(serde_json::to_vec(&ans).unwrap()).status(200); then.body(serde_json::to_vec(&ans).unwrap()).status(200);
}); });
let resp = c.home_ns_req(&mut req, None).await.unwrap(); let resp = c
.home_ns_req(vec![1, 2, 3, 4, 5, 6, 7, 8], &mut req, None)
.await
.unwrap();
mock.assert(); mock.assert();
mock.delete(); mock.delete();
assert_eq!(resp, ans); assert_eq!(resp, ans);
@ -1354,7 +1366,9 @@ pub mod test {
.body(serde_json::to_string(&req).unwrap()); .body(serde_json::to_string(&req).unwrap());
then.status(500); then.status(500);
}); });
let resp = c.home_ns_req(&mut req, None).await; let resp = c
.home_ns_req(vec![1, 2, 3, 4, 5, 6, 7, 8], &mut req, None)
.await;
mock.assert(); mock.assert();
mock.delete(); mock.delete();
assert!(resp.is_err()); assert!(resp.is_err());

View File

@ -101,7 +101,7 @@ pub async fn _handle_request(bp: BasePayload, b: Vec<u8>) -> http::Response<hype
} }
}; };
match joinserver::get(&sender_id) { match joinserver::get(sender_id) {
Ok(v) => v, Ok(v) => v,
Err(_) => { Err(_) => {
warn!("Unknown SenderID"); warn!("Unknown SenderID");

View File

@ -1,4 +1,3 @@
use std::collections::HashMap;
use std::sync::{Arc, RwLock}; use std::sync::{Arc, RwLock};
use anyhow::Result; use anyhow::Result;
@ -6,22 +5,24 @@ use tracing::info;
use crate::config; use crate::config;
use backend::{Client, ClientConfig}; use backend::{Client, ClientConfig};
use lrwn::EUI64; use lrwn::{EUI64Prefix, EUI64};
lazy_static! { lazy_static! {
static ref CLIENTS: RwLock<HashMap<EUI64, Arc<Client>>> = RwLock::new(HashMap::new()); static ref CLIENTS: RwLock<Vec<(EUI64Prefix, Arc<Client>)>> = RwLock::new(vec![]);
} }
pub fn setup() -> Result<()> { pub fn setup() -> Result<()> {
info!("Setting up Join Server clients"); info!("Setting up Join Server clients");
let conf = config::get(); let conf = config::get();
let mut clients_w = CLIENTS.write().unwrap();
*clients_w = vec![];
for js in &conf.join_server.servers { for js in &conf.join_server.servers {
info!(join_eui = %js.join_eui, "Configuring Join Server"); info!(join_eui_prefix = %js.join_eui_prefix, "Configuring Join Server");
let c = Client::new(ClientConfig { let c = Client::new(ClientConfig {
sender_id: conf.network.net_id.to_vec(), sender_id: conf.network.net_id.to_vec(),
receiver_id: js.join_eui.to_vec(),
server: js.server.clone(), server: js.server.clone(),
ca_cert: js.ca_cert.clone(), ca_cert: js.ca_cert.clone(),
tls_cert: js.tls_cert.clone(), tls_cert: js.tls_cert.clone(),
@ -30,32 +31,28 @@ pub fn setup() -> Result<()> {
..Default::default() ..Default::default()
})?; })?;
set(&js.join_eui, c); clients_w.push((js.join_eui_prefix, Arc::new(c)));
} }
Ok(()) Ok(())
} }
pub fn set(join_eui: &EUI64, c: Client) { pub fn get(join_eui: EUI64) -> Result<Arc<Client>> {
let mut clients_w = CLIENTS.write().unwrap();
clients_w.insert(*join_eui, Arc::new(c));
}
pub fn get(join_eui: &EUI64) -> Result<Arc<Client>> {
let clients_r = CLIENTS.read().unwrap(); let clients_r = CLIENTS.read().unwrap();
Ok(clients_r for client in clients_r.iter() {
.get(join_eui) if client.0.matches(join_eui) {
.ok_or_else(|| { return Ok(client.1.clone());
anyhow!( }
"Join Server client for join_eui {} does not exist", }
join_eui
) Err(anyhow!(
})? "Join Server client for join_eui {} does not exist",
.clone()) join_eui
))
} }
#[cfg(test)] #[cfg(test)]
pub fn reset() { pub fn reset() {
let mut clients_w = CLIENTS.write().unwrap(); let mut clients_w = CLIENTS.write().unwrap();
*clients_w = HashMap::new(); *clients_w = vec![];
} }

View File

@ -600,11 +600,24 @@ pub fn run() {
[join_server] [join_server]
# Per Join Server configuration (this can be repeated). # Per Join Server configuration (this can be repeated).
#
# ChirpStack will try to match the Join-Request JoinEUI against each
# join_eui_prefix in the same order as they appear in the configuration.
#
# If you configure a 'catch-all' Join Server, then this entry must appear
# as the last item in the list.
#
# Example: # Example:
# [[join_server.servers]] # [[join_server.servers]]
# #
# # JoinEUI of the Join Server. # # JoinEUI prefix that must be routed to the Join Server.
# join_eui="0102030405060708" # #
# # Example '0102030405060700/56` means that the 56MSB of the
# # join_eui_prefix will be used to match against the JoinEUI.
# # Thus the following JoinEUI range will be forwarded to the
# # configured Join Server:
# # 0102030405060700 - 01020304050607ff
# join_eui_prefix="0102030405060708/64"
# #
# # Server endpoint. # # Server endpoint.
# server="https://example.com:1234/join/endpoint" # server="https://example.com:1234/join/endpoint"
@ -633,7 +646,7 @@ pub fn run() {
{{#each join_server.servers}} {{#each join_server.servers}}
[[join_server.servers]] [[join_server.servers]]
join_eui="{{ this.join_eui }}" join_eui_prefix="{{ this.join_eui_prefix }}"
server="{{ this.server }}" server="{{ this.server }}"
async_interface={{ this.async_interface }} async_interface={{ this.async_interface }}
async_interface_timeout="{{ this.async_interface_timeout }}" async_interface_timeout="{{ this.async_interface_timeout }}"

View File

@ -7,7 +7,7 @@ use anyhow::{Context, Result};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use lrwn::region::CommonName; use lrwn::region::CommonName;
use lrwn::{AES128Key, DevAddrPrefix, NetID, EUI64}; use lrwn::{AES128Key, DevAddrPrefix, EUI64Prefix, NetID};
lazy_static! { lazy_static! {
static ref CONFIG: Mutex<Arc<Configuration>> = Mutex::new(Arc::new(Default::default())); static ref CONFIG: Mutex<Arc<Configuration>> = Mutex::new(Arc::new(Default::default()));
@ -416,7 +416,8 @@ pub struct JoinServer {
#[derive(Serialize, Deserialize, Default, Clone)] #[derive(Serialize, Deserialize, Default, Clone)]
#[serde(default)] #[serde(default)]
pub struct JoinServerServer { pub struct JoinServerServer {
pub join_eui: EUI64, #[serde(alias = "join_eui")]
pub join_eui_prefix: EUI64Prefix,
pub server: String, pub server: String,
#[serde(with = "humantime_serde")] #[serde(with = "humantime_serde")]
pub async_timeout: Duration, pub async_timeout: Duration,

View File

@ -8,7 +8,7 @@ use crate::{
uplink, uplink,
}; };
use chirpstack_api::{common, gw, integration as integration_pb, internal}; use chirpstack_api::{common, gw, integration as integration_pb, internal};
use lrwn::{DevAddr, EUI64}; use lrwn::{DevAddr, EUI64Prefix, EUI64};
struct Test { struct Test {
name: String, name: String,
@ -361,11 +361,11 @@ async fn run_test(t: &Test) {
}); });
let mut conf: config::Configuration = (*config::get()).clone(); let mut conf: config::Configuration = (*config::get()).clone();
conf.join_server.servers.push(config::JoinServerServer { conf.join_server.servers = vec![config::JoinServerServer {
join_eui: EUI64::from_be_bytes([1, 2, 3, 4, 5, 6, 7, 8]), join_eui_prefix: EUI64Prefix::new([1, 2, 3, 4, 5, 6, 7, 8], 64),
server: server.url("/"), server: server.url("/"),
..Default::default() ..Default::default()
}); }];
config::set(conf); config::set(conf);
region::setup().unwrap(); region::setup().unwrap();
joinserver::setup().unwrap(); joinserver::setup().unwrap();

View File

@ -17,7 +17,7 @@ use crate::storage::{
}; };
use crate::{config, test, uplink}; use crate::{config, test, uplink};
use chirpstack_api::gw; use chirpstack_api::gw;
use lrwn::{AES128Key, NetID, EUI64}; use lrwn::{AES128Key, EUI64Prefix, NetID, EUI64};
#[tokio::test] #[tokio::test]
async fn test_fns() { async fn test_fns() {
@ -32,18 +32,18 @@ async fn test_fns() {
conf.network.net_id = NetID::from_str("010203").unwrap(); conf.network.net_id = NetID::from_str("010203").unwrap();
// Set Join Server. // Set Join Server.
conf.join_server.servers.push(config::JoinServerServer { conf.join_server.servers = vec![config::JoinServerServer {
join_eui: EUI64::from_str("0102030405060708").unwrap(), join_eui_prefix: EUI64Prefix::new([1, 2, 3, 4, 5, 6, 7, 8], 64),
server: js_mock.url("/"), server: js_mock.url("/"),
..Default::default() ..Default::default()
}); }];
// Set roaming agreement. // Set roaming agreement.
conf.roaming.servers.push(config::RoamingServer { conf.roaming.servers = vec![config::RoamingServer {
net_id: NetID::from_str("030201").unwrap(), net_id: NetID::from_str("030201").unwrap(),
server: sns_mock.url("/"), server: sns_mock.url("/"),
..Default::default() ..Default::default()
}); }];
config::set(conf); config::set(conf);
joinserver::setup().unwrap(); joinserver::setup().unwrap();

View File

@ -264,7 +264,7 @@ impl JoinRequest {
if self.device_keys.is_none() { if self.device_keys.is_none() {
trace!(join_eui = %jr.join_eui, "Getting Join Server client"); trace!(join_eui = %jr.join_eui, "Getting Join Server client");
self.js_client = Some(joinserver::get(&jr.join_eui)?); self.js_client = Some(joinserver::get(jr.join_eui)?);
} }
Ok(()) Ok(())
@ -537,6 +537,7 @@ impl JoinRequest {
trace!("Getting join-accept from Join Server"); trace!("Getting join-accept from Join Server");
let js_client = self.js_client.as_ref().unwrap(); let js_client = self.js_client.as_ref().unwrap();
let jr = self.join_request.as_ref().unwrap();
let region_network = config::get_region_network(&self.uplink_frame_set.region_config_id)?; let region_network = config::get_region_network(&self.uplink_frame_set.region_config_id)?;
let region_conf = region::get(&self.uplink_frame_set.region_config_id)?; let region_conf = region::get(&self.uplink_frame_set.region_config_id)?;
@ -584,7 +585,9 @@ impl JoinRequest {
), ),
}; };
let join_ans_pl = js_client.join_req(&mut join_req_pl, async_receiver).await?; let join_ans_pl = js_client
.join_req(jr.join_eui.to_vec(), &mut join_req_pl, async_receiver)
.await?;
if let Some(v) = &join_ans_pl.app_s_key { if let Some(v) = &join_ans_pl.app_s_key {
self.app_s_key = Some(common::KeyEnvelope { self.app_s_key = Some(common::KeyEnvelope {

View File

@ -58,7 +58,7 @@ impl JoinRequest {
trace!("Getting home netid"); trace!("Getting home netid");
trace!(join_eui = %self.join_request.join_eui, "Trying to get join-server client"); trace!(join_eui = %self.join_request.join_eui, "Trying to get join-server client");
let js_client = joinserver::get(&self.join_request.join_eui)?; let js_client = joinserver::get(self.join_request.join_eui)?;
let mut home_ns_req = backend::HomeNSReqPayload { let mut home_ns_req = backend::HomeNSReqPayload {
dev_eui: self.join_request.dev_eui.to_vec(), dev_eui: self.join_request.dev_eui.to_vec(),
@ -83,7 +83,11 @@ impl JoinRequest {
trace!("Requesting home netid"); trace!("Requesting home netid");
let home_ns_ans = js_client let home_ns_ans = js_client
.home_ns_req(&mut home_ns_req, async_receiver) .home_ns_req(
self.join_request.join_eui.to_vec(),
&mut home_ns_req,
async_receiver,
)
.await?; .await?;
self.home_net_id = Some(NetID::from_slice(&home_ns_ans.h_net_id)?); self.home_net_id = Some(NetID::from_slice(&home_ns_ans.h_net_id)?);

View File

@ -147,7 +147,7 @@ impl JoinRequest {
if self.device_keys.is_none() { if self.device_keys.is_none() {
trace!(join_eui = %jr.join_eui, "Getting Join Server client"); trace!(join_eui = %jr.join_eui, "Getting Join Server client");
self.js_client = Some(joinserver::get(&jr.join_eui)?); self.js_client = Some(joinserver::get(jr.join_eui)?);
} }
Ok(()) Ok(())
@ -226,6 +226,7 @@ impl JoinRequest {
trace!("Getting join-accept from Join Server"); trace!("Getting join-accept from Join Server");
let js_client = self.js_client.as_ref().unwrap(); let js_client = self.js_client.as_ref().unwrap();
let jr = self.join_request.as_ref().unwrap();
let region_network = config::get_region_network(&self.uplink_frame_set.region_config_id)?; let region_network = config::get_region_network(&self.uplink_frame_set.region_config_id)?;
let region_conf = region::get(&self.uplink_frame_set.region_config_id)?; let region_conf = region::get(&self.uplink_frame_set.region_config_id)?;
@ -262,7 +263,9 @@ impl JoinRequest {
..Default::default() ..Default::default()
}; };
let join_ans_pl = js_client.join_req(&mut join_req_pl, None).await?; let join_ans_pl = js_client
.join_req(jr.join_eui.to_vec(), &mut join_req_pl, None)
.await?;
if let Some(v) = &join_ans_pl.app_s_key { if let Some(v) = &join_ans_pl.app_s_key {
self.app_s_key = Some(common::KeyEnvelope { self.app_s_key = Some(common::KeyEnvelope {

View File

@ -46,8 +46,12 @@ impl FromStr for DevAddrPrefix {
type Err = Error; type Err = Error;
fn from_str(s: &str) -> Result<Self, Self::Err> { fn from_str(s: &str) -> Result<Self, Self::Err> {
let s = s.to_string(); let s = s.to_string();
let mut size: u32 = 32;
let parts: Vec<&str> = s.split('/').collect(); let parts: Vec<&str> = s.split('/').collect();
if parts.len() != 2 { if parts.len() == 2 {
size = parts[1].parse().map_err(|_| Error::DevAddrPrefixFormat)?;
}
if parts.len() > 2 {
return Err(Error::DevAddrPrefixFormat); return Err(Error::DevAddrPrefixFormat);
} }
@ -57,7 +61,6 @@ impl FromStr for DevAddrPrefix {
let mut mask: [u8; 4] = [0; 4]; let mut mask: [u8; 4] = [0; 4];
hex::decode_to_slice(parts[0], &mut mask)?; hex::decode_to_slice(parts[0], &mut mask)?;
let size: u32 = parts[1].parse().map_err(|_| Error::DevAddrPrefixFormat)?;
Ok(DevAddrPrefix(mask, size)) Ok(DevAddrPrefix(mask, size))
} }
@ -311,6 +314,10 @@ mod tests {
let p = DevAddrPrefix::from_str("01000000/8").unwrap(); let p = DevAddrPrefix::from_str("01000000/8").unwrap();
assert_eq!(DevAddrPrefix::new([1, 0, 0, 0], 8), p); assert_eq!(DevAddrPrefix::new([1, 0, 0, 0], 8), p);
assert_eq!("01000000/8", p.to_string()); assert_eq!("01000000/8", p.to_string());
let p = DevAddrPrefix::from_str("01020304").unwrap();
assert_eq!(DevAddrPrefix::new([1, 2, 3, 4], 32), p);
assert_eq!("01020304/32", p.to_string());
} }
#[test] #[test]

View File

@ -17,6 +17,12 @@ pub enum Error {
#[error("DevAddrPrefix must be in the form 00000000/0")] #[error("DevAddrPrefix must be in the form 00000000/0")]
DevAddrPrefixFormat, DevAddrPrefixFormat,
#[error("EUI64Prefix must be in the form 0000000000000000/0")]
EUI64PrefixFormat,
#[error(transparent)] #[error(transparent)]
FromHexError(#[from] hex::FromHexError), FromHexError(#[from] hex::FromHexError),
#[error(transparent)]
Anyhow(#[from] anyhow::Error),
} }

View File

@ -1,7 +1,7 @@
use std::fmt; use std::fmt;
use std::str::FromStr; use std::str::FromStr;
use anyhow::Result; use anyhow::{Context, Result};
#[cfg(feature = "diesel")] #[cfg(feature = "diesel")]
use diesel::{backend::Backend, deserialize, serialize, sql_types::Binary}; use diesel::{backend::Backend, deserialize, serialize, sql_types::Binary};
#[cfg(feature = "serde")] #[cfg(feature = "serde")]
@ -151,12 +151,115 @@ impl diesel::sql_types::SqlType for EUI64 {
type IsNull = diesel::sql_types::is_nullable::NotNull; type IsNull = diesel::sql_types::is_nullable::NotNull;
} }
#[derive(PartialEq, Eq, Copy, Clone, Default)]
pub struct EUI64Prefix([u8; 8], u64);
impl EUI64Prefix {
pub fn new(prefix: [u8; 8], size: u64) -> Self {
EUI64Prefix(prefix, size)
}
pub fn matches(&self, eui: EUI64) -> bool {
if self.size() == 0 {
return true;
}
let eui = u64::from_be_bytes(eui.to_be_bytes());
let prefix = u64::from_be_bytes(self.prefix());
let shift = 64 - self.size();
(prefix >> shift) == (eui >> shift)
}
fn prefix(&self) -> [u8; 8] {
self.0
}
fn size(&self) -> u64 {
self.1
}
}
impl fmt::Display for EUI64Prefix {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{}/{}", hex::encode(self.prefix()), self.size())
}
}
impl fmt::Debug for EUI64Prefix {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{}/{}", hex::encode(self.prefix()), self.size())
}
}
impl FromStr for EUI64Prefix {
type Err = Error;
fn from_str(s: &str) -> Result<Self, Self::Err> {
let s = s.to_string();
let mut size: u64 = 64;
let parts: Vec<&str> = s.split("/").collect();
if parts.len() == 2 {
size = parts[1].parse().map_err(|_| Error::EUI64PrefixFormat)?;
}
if parts.len() > 2 {
return Err(Error::EUI64PrefixFormat);
}
if parts[0].len() != 16 {
return Err(Error::EUI64PrefixFormat);
}
let mut mask: [u8; 8] = [0; 8];
hex::decode_to_slice(parts[0], &mut mask).context("Decode EUI64Prefix")?;
Ok(EUI64Prefix(mask, size))
}
}
#[cfg(feature = "serde")]
impl Serialize for EUI64Prefix {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
serializer.serialize_str(&self.to_string())
}
}
#[cfg(feature = "serde")]
impl<'de> Deserialize<'de> for EUI64Prefix {
fn deserialize<D>(deserialize: D) -> Result<EUI64Prefix, D::Error>
where
D: Deserializer<'de>,
{
deserialize.deserialize_str(EUI64PrefixVisitor)
}
}
#[cfg(feature = "serde")]
struct EUI64PrefixVisitor;
#[cfg(feature = "serde")]
impl<'de> Visitor<'de> for EUI64PrefixVisitor {
type Value = EUI64Prefix;
fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
formatter.write_str("A EUI64Prefix in the format 0000000000000000/0 is expected")
}
fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
where
E: de::Error,
{
EUI64Prefix::from_str(value).map_err(|e| E::custom(format!("{}", e)))
}
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
#[test] #[test]
fn test_to_le_bytes() { fn test_eui64_to_le_bytes() {
let eui = EUI64::from_be_bytes([0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08]); let eui = EUI64::from_be_bytes([0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08]);
assert_eq!( assert_eq!(
@ -166,7 +269,7 @@ mod tests {
} }
#[test] #[test]
fn test_from_le_bytes() { fn test_eui64_from_le_bytes() {
let eui64_from_le = EUI64::from_le_bytes([0x08, 0x07, 0x06, 0x05, 0x04, 0x03, 0x02, 0x01]); let eui64_from_le = EUI64::from_le_bytes([0x08, 0x07, 0x06, 0x05, 0x04, 0x03, 0x02, 0x01]);
let eui64_from_be = EUI64::from_be_bytes([0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08]); let eui64_from_be = EUI64::from_be_bytes([0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08]);
@ -174,14 +277,61 @@ mod tests {
} }
#[test] #[test]
fn test_to_string() { fn test_eui64_to_string() {
let eui = EUI64::from_be_bytes([0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08]); let eui = EUI64::from_be_bytes([0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08]);
assert_eq!(eui.to_string(), "0102030405060708"); assert_eq!(eui.to_string(), "0102030405060708");
} }
#[test] #[test]
fn test_from_str() { fn test_eui64_from_str() {
let eui = EUI64::from_be_bytes([0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08]); let eui = EUI64::from_be_bytes([0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08]);
assert_eq!(eui, EUI64::from_str(&"0102030405060708").unwrap()); assert_eq!(eui, EUI64::from_str(&"0102030405060708").unwrap());
} }
#[test]
fn test_eui64_prefix() {
let p = EUI64Prefix::from_str("0102030405060708").unwrap();
assert_eq!(EUI64Prefix::new([1, 2, 3, 4, 5, 6, 7, 8], 64), p);
assert_eq!("0102030405060708/64", p.to_string());
let p = EUI64Prefix::from_str("0100000000000000/8").unwrap();
assert_eq!(EUI64Prefix::new([1, 0, 0, 0, 0, 0, 0, 0], 8), p);
assert_eq!("0100000000000000/8", p.to_string());
}
#[test]
fn test_eui64_prefix_is_eui64() {
struct Test {
prefix: EUI64Prefix,
eui: EUI64,
matches: bool,
}
let tests = vec![
Test {
prefix: EUI64Prefix::from_str("0000000000000000/0").unwrap(),
eui: EUI64::from_str("0000000000000000").unwrap(),
matches: true,
},
Test {
prefix: EUI64Prefix::from_str("0000000000000000/0").unwrap(),
eui: EUI64::from_str("ffffffffffffffff").unwrap(),
matches: true,
},
Test {
eui: EUI64::from_str("ffffffff00000000").unwrap(),
prefix: EUI64Prefix::from_str("ff00000000000000/8").unwrap(),
matches: true,
},
Test {
eui: EUI64::from_str("ffffffff00000000").unwrap(),
prefix: EUI64Prefix::from_str("ff00000000000000/9").unwrap(),
matches: false,
},
];
for tst in &tests {
assert_eq!(tst.matches, tst.prefix.matches(tst.eui));
}
}
} }