diff --git a/rustybits/zeroidc/build.rs b/rustybits/zeroidc/build.rs index 324ecfedd..c4b4c9072 100644 --- a/rustybits/zeroidc/build.rs +++ b/rustybits/zeroidc/build.rs @@ -8,10 +8,7 @@ fn main() { let crate_dir = env::var("CARGO_MANIFEST_DIR").unwrap(); let package_name = env::var("CARGO_PKG_NAME").unwrap(); - let output_file = target_dir() - .join(format!("{}.h", package_name)) - .display() - .to_string(); + let output_file = target_dir().join(format!("{}.h", package_name)).display().to_string(); let config = Config { language: Language::C, diff --git a/rustybits/zeroidc/src/error.rs b/rustybits/zeroidc/src/error.rs index 039429e6e..a3907c2d6 100644 --- a/rustybits/zeroidc/src/error.rs +++ b/rustybits/zeroidc/src/error.rs @@ -15,9 +15,7 @@ use thiserror::Error; #[derive(Error, Debug)] pub enum ZeroIDCError { #[error(transparent)] - DiscoveryError( - #[from] openidconnect::DiscoveryError>, - ), + DiscoveryError(#[from] openidconnect::DiscoveryError>), #[error(transparent)] ParseError(#[from] url::ParseError), diff --git a/rustybits/zeroidc/src/ext.rs b/rustybits/zeroidc/src/ext.rs index fd1e2e96b..28d88dc1f 100644 --- a/rustybits/zeroidc/src/ext.rs +++ b/rustybits/zeroidc/src/ext.rs @@ -160,11 +160,7 @@ pub extern "C" fn zeroidc_get_exp_time(ptr: *mut ZeroIDC) -> u64 { target_os = "macos", ))] #[no_mangle] -pub extern "C" fn zeroidc_set_nonce_and_csrf( - ptr: *mut ZeroIDC, - csrf_token: *const c_char, - nonce: *const c_char, -) { +pub extern "C" fn zeroidc_set_nonce_and_csrf(ptr: *mut ZeroIDC, csrf_token: *const c_char, nonce: *const c_char) { let idc = unsafe { assert!(!ptr.is_null()); &mut *ptr @@ -180,14 +176,8 @@ pub extern "C" fn zeroidc_set_nonce_and_csrf( return; } - let csrf_token = unsafe { CStr::from_ptr(csrf_token) } - .to_str() - .unwrap() - .to_string(); - let nonce = unsafe { CStr::from_ptr(nonce) } - .to_str() - .unwrap() - .to_string(); + let csrf_token = unsafe { CStr::from_ptr(csrf_token) }.to_str().unwrap().to_string(); + let nonce = unsafe { CStr::from_ptr(nonce) }.to_str().unwrap().to_string(); idc.set_nonce_and_csrf(csrf_token, nonce); } @@ -275,10 +265,7 @@ pub extern "C" fn zeroidc_token_exchange(idc: *mut ZeroIDC, code: *const c_char) } #[no_mangle] -pub extern "C" fn zeroidc_get_url_param_value( - param: *const c_char, - path: *const c_char, -) -> *mut c_char { +pub extern "C" fn zeroidc_get_url_param_value(param: *const c_char, path: *const c_char) -> *mut c_char { if param.is_null() { println!("param is null"); return std::ptr::null_mut(); diff --git a/rustybits/zeroidc/src/lib.rs b/rustybits/zeroidc/src/lib.rs index 0c63cb946..edf20481b 100644 --- a/rustybits/zeroidc/src/lib.rs +++ b/rustybits/zeroidc/src/lib.rs @@ -26,9 +26,8 @@ use jwt::Token; use openidconnect::core::{CoreClient, CoreProviderMetadata, CoreResponseType}; use openidconnect::reqwest::http_client; use openidconnect::{ - AccessToken, AccessTokenHash, AuthenticationFlow, AuthorizationCode, ClientId, CsrfToken, - IssuerUrl, Nonce, OAuth2TokenResponse, PkceCodeChallenge, PkceCodeVerifier, RedirectUrl, - RefreshToken, Scope, TokenResponse, + AccessToken, AccessTokenHash, AuthenticationFlow, AuthorizationCode, ClientId, CsrfToken, IssuerUrl, Nonce, + OAuth2TokenResponse, PkceCodeChallenge, PkceCodeVerifier, RedirectUrl, RefreshToken, Scope, TokenResponse, }; use std::error::Error; use std::str::from_utf8; @@ -153,13 +152,9 @@ impl ZeroIDC { let redirect = RedirectUrl::new(redir_url.to_string())?; idc.inner.lock().unwrap().oidc_client = Some( - CoreClient::from_provider_metadata( - provider_meta, - ClientId::new(client_id.to_string()), - None, - ) - .set_redirect_uri(redirect) - .set_auth_type(openidconnect::AuthType::RequestBody), + CoreClient::from_provider_metadata(provider_meta, ClientId::new(client_id.to_string()), None) + .set_redirect_uri(redirect) + .set_auth_type(openidconnect::AuthType::RequestBody), ); Ok(idc) @@ -184,22 +179,15 @@ impl ZeroIDC { let nonce = inner_local.lock().unwrap().nonce.clone(); while running { - let exp = - UNIX_EPOCH + Duration::from_secs(inner_local.lock().unwrap().exp_time); + let exp = UNIX_EPOCH + Duration::from_secs(inner_local.lock().unwrap().exp_time); let now = SystemTime::now(); #[cfg(debug_assertions)] { println!( "refresh token thread tick, now: {}, exp: {}", - systemtime_strftime( - now, - "[year]-[month]-[day] [hour]:[minute]:[second]" - ), - systemtime_strftime( - exp, - "[year]-[month]-[day] [hour]:[minute]:[second]" - ) + systemtime_strftime(now, "[year]-[month]-[day] [hour]:[minute]:[second]"), + systemtime_strftime(exp, "[year]-[month]-[day] [hour]:[minute]:[second]") ); } let refresh_token = inner_local.lock().unwrap().refresh_token.clone(); @@ -220,14 +208,11 @@ impl ZeroIDC { println!("Refresh Token: {}", refresh_token.secret()); } - let token_response = - inner_local.lock().unwrap().oidc_client.as_ref().map(|c| { - let res = c - .exchange_refresh_token(&refresh_token) - .request(http_client); + let token_response = inner_local.lock().unwrap().oidc_client.as_ref().map(|c| { + let res = c.exchange_refresh_token(&refresh_token).request(http_client); - res - }); + res + }); if let Some(res) = token_response { match res { @@ -246,20 +231,11 @@ impl ZeroIDC { ]; #[cfg(debug_assertions)] { - println!( - "New ID token: {}", - id_token.to_string() - ); + println!("New ID token: {}", id_token.to_string()); } let client = reqwest::blocking::Client::new(); let r = client - .post( - inner_local - .lock() - .unwrap() - .auth_endpoint - .clone(), - ) + .post(inner_local.lock().unwrap().auth_endpoint.clone()) .form(¶ms) .send(); @@ -268,10 +244,7 @@ impl ZeroIDC { if r.status().is_success() { #[cfg(debug_assertions)] { - println!( - "hit url: {}", - r.url().as_str() - ); + println!("hit url: {}", r.url().as_str()); println!("status: {}", r.status()); } @@ -279,24 +252,16 @@ impl ZeroIDC { let idt = &id_token.to_string(); let t: Result< - Token< - jwt::Header, - jwt::Claims, - jwt::Unverified<'_>, - >, + Token>, jwt::Error, > = Token::parse_unverified(idt); if let Ok(t) = t { - let claims = - t.claims().registered.clone(); + let claims = t.claims().registered.clone(); match claims.expiration { Some(exp) => { println!("exp: {}", exp); - inner_local - .lock() - .unwrap() - .exp_time = exp; + inner_local.lock().unwrap().exp_time = exp; } None => { panic!("expiration is None. This shouldn't happen") @@ -306,17 +271,11 @@ impl ZeroIDC { panic!("error parsing claims"); } - inner_local - .lock() - .unwrap() - .access_token = + inner_local.lock().unwrap().access_token = Some(access_token.clone()); if let Some(t) = res.refresh_token() { // println!("New Refresh Token: {}", t.secret()); - inner_local - .lock() - .unwrap() - .refresh_token = + inner_local.lock().unwrap().refresh_token = Some(t.clone()); } #[cfg(debug_assertions)] @@ -324,35 +283,22 @@ impl ZeroIDC { println!("Central post succeeded"); } } else { - println!( - "Central post failed: {}", - r.status() - ); - println!( - "hit url: {}", - r.url().as_str() - ); + println!("Central post failed: {}", r.status()); + println!("hit url: {}", r.url().as_str()); println!("Status: {}", r.status()); if let Ok(body) = r.bytes() { - if let Ok(body) = - std::str::from_utf8(&body) - { + if let Ok(body) = std::str::from_utf8(&body) { println!("Body: {}", body); } } - inner_local.lock().unwrap().exp_time = - 0; - inner_local.lock().unwrap().running = - false; + inner_local.lock().unwrap().exp_time = 0; + inner_local.lock().unwrap().running = false; } } Err(e) => { println!("Central post failed: {}", e); - println!( - "hit url: {}", - e.url().unwrap().as_str() - ); + println!("hit url: {}", e.url().unwrap().as_str()); println!("Status: {}", e.status().unwrap()); inner_local.lock().unwrap().exp_time = 0; inner_local.lock().unwrap().running = false; @@ -421,88 +367,86 @@ impl ZeroIDC { pub fn set_nonce_and_csrf(&mut self, csrf_token: String, nonce: String) { let local = Arc::clone(&self.inner); - (*local.lock().expect("can't lock inner")) - .as_opt() - .map(|i| { - if i.running { - println!("refresh thread running. not setting new nonce or csrf"); - return; - } + (*local.lock().expect("can't lock inner")).as_opt().map(|i| { + if i.running { + println!("refresh thread running. not setting new nonce or csrf"); + return; + } - let need_verifier = matches!(i.pkce_verifier, None); + let need_verifier = matches!(i.pkce_verifier, None); - let csrf_diff = if let Some(csrf) = i.csrf_token.clone() { - *csrf.secret() != csrf_token - } else { - false - }; + let csrf_diff = if let Some(csrf) = i.csrf_token.clone() { + *csrf.secret() != csrf_token + } else { + false + }; - let nonce_diff = if let Some(n) = i.nonce.clone() { - *n.secret() != nonce - } else { - false - }; + let nonce_diff = if let Some(n) = i.nonce.clone() { + *n.secret() != nonce + } else { + false + }; - if need_verifier || csrf_diff || nonce_diff { - let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256(); - let r = i.oidc_client.as_ref().map(|c| { - let mut auth_builder = c - .authorize_url( - AuthenticationFlow::::AuthorizationCode, - csrf_func(csrf_token), - nonce_func(nonce), - ) - .set_pkce_challenge(pkce_challenge); - match i.provider.as_str() { - "auth0" => { - auth_builder = auth_builder - .add_scope(Scope::new("profile".to_string())) - .add_scope(Scope::new("email".to_string())) - .add_scope(Scope::new("offline_access".to_string())); - } - "okta" => { - auth_builder = auth_builder - .add_scope(Scope::new("profile".to_string())) - .add_scope(Scope::new("email".to_string())) - .add_scope(Scope::new("groups".to_string())) - .add_scope(Scope::new("offline_access".to_string())); - } - "keycloak" => { - auth_builder = auth_builder - .add_scope(Scope::new("profile".to_string())) - .add_scope(Scope::new("email".to_string())); - } - "onelogin" => { - auth_builder = auth_builder - .add_scope(Scope::new("profile".to_string())) - .add_scope(Scope::new("email".to_string())) - .add_scope(Scope::new("groups".to_string())) - } - "default" => { - auth_builder = auth_builder - .add_scope(Scope::new("profile".to_string())) - .add_scope(Scope::new("email".to_string())) - .add_scope(Scope::new("offline_access".to_string())); - } - _ => { - auth_builder = auth_builder - .add_scope(Scope::new("profile".to_string())) - .add_scope(Scope::new("email".to_string())) - .add_scope(Scope::new("offline_access".to_string())); - } + if need_verifier || csrf_diff || nonce_diff { + let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256(); + let r = i.oidc_client.as_ref().map(|c| { + let mut auth_builder = c + .authorize_url( + AuthenticationFlow::::AuthorizationCode, + csrf_func(csrf_token), + nonce_func(nonce), + ) + .set_pkce_challenge(pkce_challenge); + match i.provider.as_str() { + "auth0" => { + auth_builder = auth_builder + .add_scope(Scope::new("profile".to_string())) + .add_scope(Scope::new("email".to_string())) + .add_scope(Scope::new("offline_access".to_string())); + } + "okta" => { + auth_builder = auth_builder + .add_scope(Scope::new("profile".to_string())) + .add_scope(Scope::new("email".to_string())) + .add_scope(Scope::new("groups".to_string())) + .add_scope(Scope::new("offline_access".to_string())); + } + "keycloak" => { + auth_builder = auth_builder + .add_scope(Scope::new("profile".to_string())) + .add_scope(Scope::new("email".to_string())); + } + "onelogin" => { + auth_builder = auth_builder + .add_scope(Scope::new("profile".to_string())) + .add_scope(Scope::new("email".to_string())) + .add_scope(Scope::new("groups".to_string())) + } + "default" => { + auth_builder = auth_builder + .add_scope(Scope::new("profile".to_string())) + .add_scope(Scope::new("email".to_string())) + .add_scope(Scope::new("offline_access".to_string())); + } + _ => { + auth_builder = auth_builder + .add_scope(Scope::new("profile".to_string())) + .add_scope(Scope::new("email".to_string())) + .add_scope(Scope::new("offline_access".to_string())); } - - auth_builder.url() - }); - - if let Some(r) = r { - i.url = Some(r.0); - i.csrf_token = Some(r.1); - i.nonce = Some(r.2); - i.pkce_verifier = Some(pkce_verifier); } + + auth_builder.url() + }); + + if let Some(r) = r { + i.url = Some(r.0); + i.csrf_token = Some(r.1); + i.nonce = Some(r.2); + i.pkce_verifier = Some(pkce_verifier); } - }); + } + }); } pub fn auth_url(&self) -> String { @@ -572,10 +516,7 @@ impl ZeroIDC { }; if let Some(expected_hash) = claims.access_token_hash() { - let actual_hash = match AccessTokenHash::from_token( - res.access_token(), - &signing_algo, - ) { + let actual_hash = match AccessTokenHash::from_token(res.access_token(), &signing_algo) { Ok(h) => h, Err(e) => { println!("Error hashing access token: {}", e); @@ -616,10 +557,7 @@ impl ZeroIDC { let split = split.split('_').collect::>(); if split.len() == 2 { - let params = [ - ("id_token", id_token.to_string()), - ("state", split[0].to_string()), - ]; + let params = [("id_token", id_token.to_string()), ("state", split[0].to_string())]; let client = reqwest::blocking::Client::new(); let res = client.post(i.auth_endpoint.clone()).form(¶ms).send(); @@ -634,10 +572,8 @@ impl ZeroIDC { let idt = &id_token.to_string(); - let t: Result< - Token>, - jwt::Error, - > = Token::parse_unverified(idt); + let t: Result>, jwt::Error> = + Token::parse_unverified(idt); if let Ok(t) = t { let claims = t.claims().registered.clone(); @@ -682,13 +618,12 @@ impl ZeroIDC { } else if res.status() == 402 { i.running = false; Err(SSOExchangeError::new( - "additional license seats required. Please contact your network administrator.".to_string(), + "additional license seats required. Please contact your network administrator." + .to_string(), )) } else { i.running = false; - Err(SSOExchangeError::new( - "error from central endpoint".to_string(), - )) + Err(SSOExchangeError::new("error from central endpoint".to_string())) } } Err(res) => { @@ -697,16 +632,12 @@ impl ZeroIDC { println!("Post error: {}", res); i.exp_time = 0; i.running = false; - Err(SSOExchangeError::new( - "error from central endpoint".to_string(), - )) + Err(SSOExchangeError::new("error from central endpoint".to_string())) } } } else { i.running = false; - Err(SSOExchangeError::new( - "error splitting state token".to_string(), - )) + Err(SSOExchangeError::new("error splitting state token".to_string())) } } else { i.running = false;