diff --git a/zeroidc/src/error.rs b/zeroidc/src/error.rs index 3feab7696..039429e6e 100644 --- a/zeroidc/src/error.rs +++ b/zeroidc/src/error.rs @@ -13,10 +13,11 @@ use thiserror::Error; #[derive(Error, Debug)] -pub enum ZeroIDCError -{ +pub enum ZeroIDCError { #[error(transparent)] - DiscoveryError(#[from] openidconnect::DiscoveryError>), + DiscoveryError( + #[from] openidconnect::DiscoveryError>, + ), #[error(transparent)] ParseError(#[from] url::ParseError), @@ -30,8 +31,6 @@ pub struct SSOExchangeError { impl SSOExchangeError { pub fn new(message: String) -> Self { - SSOExchangeError{ - message - } + SSOExchangeError { message } } } diff --git a/zeroidc/src/ext.rs b/zeroidc/src/ext.rs index 0831cb193..dc951dbb9 100644 --- a/zeroidc/src/ext.rs +++ b/zeroidc/src/ext.rs @@ -12,19 +12,17 @@ use std::ffi::{CStr, CString}; use std::os::raw::c_char; -use url::{Url}; +use url::Url; use crate::ZeroIDC; -#[cfg( - any( - all(target_os = "linux", target_arch = "x86"), - all(target_os = "linux", target_arch = "x86_64"), - all(target_os = "linux", target_arch = "aarch64"), - target_os = "windows", - target_os = "macos", - ) -)] +#[cfg(any( + all(target_os = "linux", target_arch = "x86"), + all(target_os = "linux", target_arch = "x86_64"), + all(target_os = "linux", target_arch = "aarch64"), + target_os = "windows", + target_os = "macos", +))] #[no_mangle] pub extern "C" fn zeroidc_new( issuer: *const c_char, @@ -56,25 +54,21 @@ pub extern "C" fn zeroidc_new( auth_endpoint.to_str().unwrap(), web_listen_port, ) { - Ok(idc) => { - return Box::into_raw(Box::new(idc)); - } + Ok(idc) => Box::into_raw(Box::new(idc)), Err(s) => { println!("Error creating ZeroIDC instance: {}", s); - return std::ptr::null_mut(); + std::ptr::null_mut() } } } -#[cfg( - any( - all(target_os = "linux", target_arch = "x86"), - all(target_os = "linux", target_arch = "x86_64"), - all(target_os = "linux", target_arch = "aarch64"), - target_os = "windows", - target_os = "macos", - ) -)] +#[cfg(any( + all(target_os = "linux", target_arch = "x86"), + all(target_os = "linux", target_arch = "x86_64"), + all(target_os = "linux", target_arch = "aarch64"), + target_os = "windows", + target_os = "macos", +))] #[no_mangle] pub extern "C" fn zeroidc_delete(ptr: *mut ZeroIDC) { if ptr.is_null() { @@ -85,21 +79,19 @@ pub extern "C" fn zeroidc_delete(ptr: *mut ZeroIDC) { &mut *ptr }; idc.stop(); - + unsafe { Box::from_raw(ptr); } } -#[cfg( - any( - all(target_os = "linux", target_arch = "x86"), - all(target_os = "linux", target_arch = "x86_64"), - all(target_os = "linux", target_arch = "aarch64"), - target_os = "windows", - target_os = "macos", - ) -)] +#[cfg(any( + all(target_os = "linux", target_arch = "x86"), + all(target_os = "linux", target_arch = "x86_64"), + all(target_os = "linux", target_arch = "aarch64"), + target_os = "windows", + target_os = "macos", +))] #[no_mangle] pub extern "C" fn zeroidc_start(ptr: *mut ZeroIDC) { let idc = unsafe { @@ -109,15 +101,13 @@ pub extern "C" fn zeroidc_start(ptr: *mut ZeroIDC) { idc.start(); } -#[cfg( - any( - all(target_os = "linux", target_arch = "x86"), - all(target_os = "linux", target_arch = "x86_64"), - all(target_os = "linux", target_arch = "aarch64"), - target_os = "windows", - target_os = "macos", - ) -)] +#[cfg(any( + all(target_os = "linux", target_arch = "x86"), + all(target_os = "linux", target_arch = "x86_64"), + all(target_os = "linux", target_arch = "aarch64"), + target_os = "windows", + target_os = "macos", +))] #[no_mangle] pub extern "C" fn zeroidc_stop(ptr: *mut ZeroIDC) { let idc = unsafe { @@ -127,15 +117,13 @@ pub extern "C" fn zeroidc_stop(ptr: *mut ZeroIDC) { idc.stop(); } -#[cfg( - any( - all(target_os = "linux", target_arch = "x86"), - all(target_os = "linux", target_arch = "x86_64"), - all(target_os = "linux", target_arch = "aarch64"), - target_os = "windows", - target_os = "macos", - ) -)] +#[cfg(any( + all(target_os = "linux", target_arch = "x86"), + all(target_os = "linux", target_arch = "x86_64"), + all(target_os = "linux", target_arch = "aarch64"), + target_os = "windows", + target_os = "macos", +))] #[no_mangle] pub extern "C" fn zeroidc_is_running(ptr: *mut ZeroIDC) -> bool { let idc = unsafe { @@ -156,20 +144,19 @@ pub extern "C" fn zeroidc_get_exp_time(ptr: *mut ZeroIDC) -> u64 { id.get_exp_time() } -#[cfg( - any( - all(target_os = "linux", target_arch = "x86"), - all(target_os = "linux", target_arch = "x86_64"), - all(target_os = "linux", target_arch = "aarch64"), - target_os = "windows", - target_os = "macos", - ) -)] +#[cfg(any( + all(target_os = "linux", target_arch = "x86"), + all(target_os = "linux", target_arch = "x86_64"), + all(target_os = "linux", target_arch = "aarch64"), + target_os = "windows", + target_os = "macos", +))] #[no_mangle] pub extern "C" fn zeroidc_set_nonce_and_csrf( ptr: *mut ZeroIDC, csrf_token: *const c_char, - nonce: *const c_char) { + nonce: *const c_char, +) { let idc = unsafe { assert!(!ptr.is_null()); &mut *ptr @@ -193,19 +180,17 @@ pub extern "C" fn zeroidc_set_nonce_and_csrf( .to_str() .unwrap() .to_string(); - + idc.set_nonce_and_csrf(csrf_token, nonce); } -#[cfg( - any( - all(target_os = "linux", target_arch = "x86"), - all(target_os = "linux", target_arch = "x86_64"), - all(target_os = "linux", target_arch = "aarch64"), - target_os = "windows", - target_os = "macos", - ) -)] +#[cfg(any( + all(target_os = "linux", target_arch = "x86"), + all(target_os = "linux", target_arch = "x86_64"), + all(target_os = "linux", target_arch = "aarch64"), + target_os = "windows", + target_os = "macos", +))] #[no_mangle] pub extern "C" fn free_cstr(s: *mut c_char) { if s.is_null() { @@ -218,40 +203,34 @@ pub extern "C" fn free_cstr(s: *mut c_char) { } } -#[cfg( - any( - all(target_os = "linux", target_arch = "x86"), - all(target_os = "linux", target_arch = "x86_64"), - all(target_os = "linux", target_arch = "aarch64"), - target_os = "windows", - target_os = "macos", - ) -)] +#[cfg(any( + all(target_os = "linux", target_arch = "x86"), + all(target_os = "linux", target_arch = "x86_64"), + all(target_os = "linux", target_arch = "aarch64"), + target_os = "windows", + target_os = "macos", +))] #[no_mangle] pub extern "C" fn zeroidc_get_auth_url(ptr: *mut ZeroIDC) -> *mut c_char { if ptr.is_null() { println!("passed a null object"); return std::ptr::null_mut(); } - let idc = unsafe { - &mut *ptr - }; - + let idc = unsafe { &mut *ptr }; + let s = CString::new(idc.auth_url()).unwrap(); return s.into_raw(); } -#[cfg( - any( - all(target_os = "linux", target_arch = "x86"), - all(target_os = "linux", target_arch = "x86_64"), - all(target_os = "linux", target_arch = "aarch64"), - target_os = "windows", - target_os = "macos", - ) -)] +#[cfg(any( + all(target_os = "linux", target_arch = "x86"), + all(target_os = "linux", target_arch = "x86_64"), + all(target_os = "linux", target_arch = "aarch64"), + target_os = "windows", + target_os = "macos", +))] #[no_mangle] -pub extern "C" fn zeroidc_token_exchange(idc: *mut ZeroIDC, code: *const c_char ) -> *mut c_char { +pub extern "C" fn zeroidc_token_exchange(idc: *mut ZeroIDC, code: *const c_char) -> *mut c_char { if idc.is_null() { println!("idc is null"); return std::ptr::null_mut(); @@ -261,29 +240,29 @@ pub extern "C" fn zeroidc_token_exchange(idc: *mut ZeroIDC, code: *const c_char println!("code is null"); return std::ptr::null_mut(); } - let idc = unsafe { - &mut *idc - }; + let idc = unsafe { &mut *idc }; - let code = unsafe{CStr::from_ptr(code)}.to_str().unwrap(); + let code = unsafe { CStr::from_ptr(code) }.to_str().unwrap(); let ret = idc.do_token_exchange(code); match ret { Ok(ret) => { let ret = CString::new(ret).unwrap(); - return ret.into_raw(); - - }, + ret.into_raw() + } Err(e) => { - let errstr = format!("{{\"errorMessage\":\"{}\"\"}}", e).to_string(); + let errstr = format!("{{\"errorMessage\":\"{}\"\"}}", e); let ret = CString::new(errstr).unwrap(); - return ret.into_raw(); + ret.into_raw() } } } #[no_mangle] -pub extern "C" fn zeroidc_get_url_param_value(param: *const c_char, path: *const c_char) -> *mut c_char { +pub extern "C" fn zeroidc_get_url_param_value( + param: *const c_char, + path: *const c_char, +) -> *mut c_char { if param.is_null() { println!("param is null"); return std::ptr::null_mut(); @@ -292,21 +271,21 @@ pub extern "C" fn zeroidc_get_url_param_value(param: *const c_char, path: *const println!("path is null"); return std::ptr::null_mut(); } - let param = unsafe {CStr::from_ptr(param)}.to_str().unwrap(); - let path = unsafe {CStr::from_ptr(path)}.to_str().unwrap(); + let param = unsafe { CStr::from_ptr(param) }.to_str().unwrap(); + let path = unsafe { CStr::from_ptr(path) }.to_str().unwrap(); let url = "http://localhost:9993".to_string() + path; let url = Url::parse(&url).unwrap(); - let pairs = url.query_pairs(); + let pairs = url.query_pairs(); for p in pairs { if p.0 == param { let s = CString::new(p.1.into_owned()).unwrap(); - return s.into_raw() + return s.into_raw(); } } - return std::ptr::null_mut(); + std::ptr::null_mut() } #[no_mangle] @@ -316,36 +295,32 @@ pub extern "C" fn zeroidc_network_id_from_state(state: *const c_char) -> *mut c_ return std::ptr::null_mut(); } - let state = unsafe{CStr::from_ptr(state)}.to_str().unwrap(); + let state = unsafe { CStr::from_ptr(state) }.to_str().unwrap(); - let split = state.split("_"); + let split = state.split('_'); let split = split.collect::>(); if split.len() != 2 { return std::ptr::null_mut(); } let s = CString::new(split[1]).unwrap(); - return s.into_raw(); + s.into_raw() } -#[cfg( - any( - all(target_os = "linux", target_arch = "x86"), - all(target_os = "linux", target_arch = "x86_64"), - all(target_os = "linux", target_arch = "aarch64"), - target_os = "windows", - target_os = "macos", - ) -)] +#[cfg(any( + all(target_os = "linux", target_arch = "x86"), + all(target_os = "linux", target_arch = "x86_64"), + all(target_os = "linux", target_arch = "aarch64"), + target_os = "windows", + target_os = "macos", +))] #[no_mangle] pub extern "C" fn zeroidc_kick_refresh_thread(idc: *mut ZeroIDC) { if idc.is_null() { println!("idc is null"); return; } - let idc = unsafe { - &mut *idc - }; + let idc = unsafe { &mut *idc }; idc.kick_refresh_thread(); } diff --git a/zeroidc/src/lib.rs b/zeroidc/src/lib.rs index 8b870be42..003d1d74d 100644 --- a/zeroidc/src/lib.rs +++ b/zeroidc/src/lib.rs @@ -22,42 +22,41 @@ extern crate url; use crate::error::*; use bytes::Bytes; -use jwt::{Token}; +use jwt::Token; use openidconnect::core::{CoreClient, CoreProviderMetadata, CoreResponseType}; use openidconnect::reqwest::http_client; -use openidconnect::{AccessToken, AccessTokenHash, AuthorizationCode, AuthenticationFlow, ClientId, CsrfToken, IssuerUrl, Nonce, OAuth2TokenResponse, PkceCodeChallenge, PkceCodeVerifier, RedirectUrl, RefreshToken, Scope, TokenResponse}; +use openidconnect::{ + AccessToken, AccessTokenHash, AuthenticationFlow, AuthorizationCode, ClientId, CsrfToken, + IssuerUrl, Nonce, OAuth2TokenResponse, PkceCodeChallenge, PkceCodeVerifier, RedirectUrl, + RefreshToken, Scope, TokenResponse, +}; use std::error::Error; use std::str::from_utf8; use std::sync::{Arc, Mutex}; use std::thread::{sleep, spawn, JoinHandle}; -use std::time::{SystemTime, UNIX_EPOCH, Duration}; -use time::{OffsetDateTime, format_description}; - +use std::time::{Duration, SystemTime, UNIX_EPOCH}; +use time::{format_description, OffsetDateTime}; use url::Url; -#[cfg( - any( - all(target_os = "linux", target_arch = "x86"), - all(target_os = "linux", target_arch = "x86_64"), - all(target_os = "linux", target_arch = "aarch64"), - target_os = "windows", - target_os = "macos", - ) -)] +#[cfg(any( + all(target_os = "linux", target_arch = "x86"), + all(target_os = "linux", target_arch = "x86_64"), + all(target_os = "linux", target_arch = "aarch64"), + target_os = "windows", + target_os = "macos", +))] pub struct ZeroIDC { inner: Arc>, } -#[cfg( - any( - all(target_os = "linux", target_arch = "x86"), - all(target_os = "linux", target_arch = "x86_64"), - all(target_os = "linux", target_arch = "aarch64"), - target_os = "windows", - target_os = "macos", - ) -)] +#[cfg(any( + all(target_os = "linux", target_arch = "x86"), + all(target_os = "linux", target_arch = "x86_64"), + all(target_os = "linux", target_arch = "aarch64"), + target_os = "windows", + target_os = "macos", +))] struct Inner { running: bool, auth_endpoint: String, @@ -82,40 +81,35 @@ impl Inner { } fn csrf_func(csrf_token: String) -> Box CsrfToken> { - return Box::new(move || CsrfToken::new(csrf_token.to_string())); + Box::new(move || CsrfToken::new(csrf_token.to_string())) } fn nonce_func(nonce: String) -> Box Nonce> { - return Box::new(move || Nonce::new(nonce.to_string())); + Box::new(move || Nonce::new(nonce.to_string())) } #[cfg(debug_assertions)] fn systemtime_strftime(dt: T, format: &str) -> String - where T: Into +where + T: Into, { let f = format_description::parse(format); match f { - Ok(f) => { - match dt.into().format(&f) { - Ok(s) => s, - Err(_e) => "".to_string(), - } - }, - Err(_e) => { - "".to_string() + Ok(f) => match dt.into().format(&f) { + Ok(s) => s, + Err(_e) => "".to_string(), }, + Err(_e) => "".to_string(), } } -#[cfg( - any( - all(target_os = "linux", target_arch = "x86"), - all(target_os = "linux", target_arch = "x86_64"), - all(target_os = "linux", target_arch = "aarch64"), - target_os = "windows", - target_os = "macos", - ) -)] +#[cfg(any( + all(target_os = "linux", target_arch = "x86"), + all(target_os = "linux", target_arch = "x86_64"), + all(target_os = "linux", target_arch = "aarch64"), + target_os = "windows", + target_os = "macos", +))] impl ZeroIDC { pub fn new( issuer: &str, @@ -137,12 +131,14 @@ impl ZeroIDC { url: None, csrf_token: None, nonce: None, - pkce_verifier: None, + pkce_verifier: None, })), }; - println!("issuer: {}, client_id: {}, auth_endopint: {}, local_web_port: {}", - issuer, client_id, auth_ep, local_web_port); + println!( + "issuer: {}, client_id: {}, auth_endopint: {}, local_web_port: {}", + issuer, client_id, auth_ep, local_web_port + ); let iss = IssuerUrl::new(issuer.to_string())?; let provider_meta = CoreProviderMetadata::discover(&iss, http_client)?; @@ -184,35 +180,53 @@ impl ZeroIDC { let nonce = (*inner_local.lock().unwrap()).nonce.clone(); while running { - let exp = UNIX_EPOCH + Duration::from_secs((*inner_local.lock().unwrap()).exp_time); + let exp = + UNIX_EPOCH + Duration::from_secs((*inner_local.lock().unwrap()).exp_time); let now = SystemTime::now(); - #[cfg(debug_assertions)] { - println!("refresh token thread tick, now: {}, exp: {}", systemtime_strftime(now, "[year]-[month]-[day] [hour]:[minute]:[second]"), systemtime_strftime(exp, "[year]-[month]-[day] [hour]:[minute]:[second]")); + #[cfg(debug_assertions)] + { + println!( + "refresh token thread tick, now: {}, exp: {}", + systemtime_strftime( + now, + "[year]-[month]-[day] [hour]:[minute]:[second]" + ), + systemtime_strftime( + exp, + "[year]-[month]-[day] [hour]:[minute]:[second]" + ) + ); } let refresh_token = (*inner_local.lock().unwrap()).refresh_token.clone(); - - if let Some(refresh_token) = refresh_token { + + if let Some(refresh_token) = refresh_token { let should_kick = (*inner_local.lock().unwrap()).kick; if now >= (exp - Duration::from_secs(30)) || should_kick { if should_kick { - #[cfg(debug_assertions)] { + #[cfg(debug_assertions)] + { println!("refresh thread kicked"); } (*inner_local.lock().unwrap()).kick = false; } - #[cfg(debug_assertions)] { + #[cfg(debug_assertions)] + { println!("Refresh Token: {}", refresh_token.secret()); } - let token_response = (*inner_local.lock().unwrap()).oidc_client.as_ref().map(|c| { - let res = c.exchange_refresh_token(&refresh_token) - .request(http_client); - - res - }); - + let token_response = (*inner_local.lock().unwrap()) + .oidc_client + .as_ref() + .map(|c| { + let res = c + .exchange_refresh_token(&refresh_token) + .request(http_client); + + res + }); + if let Some(res) = token_response { match res { Ok(res) => { @@ -223,78 +237,126 @@ impl ZeroIDC { None => "".to_string(), }; - let params = [("id_token", id_token.to_string()),("state", "refresh".to_string()),("extra_nonce", n)]; - #[cfg(debug_assertions)] { - println!("New ID token: {}", id_token.to_string()); + let params = [ + ("id_token", id_token.to_string()), + ("state", "refresh".to_string()), + ("extra_nonce", n), + ]; + #[cfg(debug_assertions)] + { + println!( + "New ID token: {}", + id_token.to_string() + ); } let client = reqwest::blocking::Client::new(); - let r = client.post((*inner_local.lock().unwrap()).auth_endpoint.clone()) + let r = client + .post( + (*inner_local.lock().unwrap()) + .auth_endpoint + .clone(), + ) .form(¶ms) .send(); match r { Ok(r) => { if r.status().is_success() { - #[cfg(debug_assertions)] { - println!("hit url: {}", r.url().as_str()); + #[cfg(debug_assertions)] + { + println!( + "hit url: {}", + r.url().as_str() + ); println!("status: {}", r.status()); } let access_token = res.access_token(); let idt = &id_token.to_string(); - let t: Result>, jwt::Error> = - Token::parse_unverified(idt); - + let t: Result< + Token< + jwt::Header, + jwt::Claims, + jwt::Unverified<'_>, + >, + jwt::Error, + > = Token::parse_unverified(idt); + if let Ok(t) = t { - let claims = t.claims().registered.clone(); + let claims = + t.claims().registered.clone(); match claims.expiration { Some(exp) => { - (*inner_local.lock().unwrap()).exp_time = exp; - }, + (*inner_local + .lock() + .unwrap()) + .exp_time = exp; + } None => { panic!("expiration is None. This shouldn't happen") } } - } + } - (*inner_local.lock().unwrap()).access_token = Some(access_token.clone()); + (*inner_local.lock().unwrap()) + .access_token = + Some(access_token.clone()); if let Some(t) = res.refresh_token() { // println!("New Refresh Token: {}", t.secret()); - (*inner_local.lock().unwrap()).refresh_token = Some(t.clone()); + (*inner_local.lock().unwrap()) + .refresh_token = + Some(t.clone()); } - #[cfg(debug_assertions)] { + #[cfg(debug_assertions)] + { println!("Central post succeeded"); } } else { - println!("Central post failed: {}", r.status().to_string()); - println!("hit url: {}", r.url().as_str()); + println!( + "Central post failed: {}", + r.status() + ); + println!( + "hit url: {}", + r.url().as_str() + ); println!("Status: {}", r.status()); if let Ok(body) = r.bytes() { - if let Ok(body) = std::str::from_utf8(&body) { + if let Ok(body) = + std::str::from_utf8(&body) + { println!("Body: {}", body); } - } - - (*inner_local.lock().unwrap()).exp_time = 0; - (*inner_local.lock().unwrap()).running = false; + + (*inner_local.lock().unwrap()) + .exp_time = 0; + (*inner_local.lock().unwrap()) + .running = false; } - }, + } Err(e) => { - println!("Central post failed: {}", e.to_string()); - println!("hit url: {}", e.url().unwrap().as_str()); + println!( + "Central post failed: {}", + e.to_string() + ); + println!( + "hit url: {}", + e.url().unwrap().as_str() + ); println!("Status: {}", e.status().unwrap()); (*inner_local.lock().unwrap()).exp_time = 0; - (*inner_local.lock().unwrap()).running = false; + (*inner_local.lock().unwrap()).running = + false; } } - }, + } None => { println!("no id token?!?"); } } - }, + } Err(e) => { println!("token error: {}", e); } @@ -323,7 +385,7 @@ impl ZeroIDC { pub fn stop(&mut self) { let local = self.inner.clone(); - if self.is_running(){ + if self.is_running() { (*local.lock().unwrap()).running = false; } } @@ -341,73 +403,75 @@ impl ZeroIDC { pub fn set_nonce_and_csrf(&mut self, csrf_token: String, nonce: String) { let local = Arc::clone(&self.inner); - (*local.lock().expect("can't lock inner")).as_opt().map(|i| { - if i.running { - println!("refresh thread running. not setting new nonce or csrf"); - return - } + (*local.lock().expect("can't lock inner")) + .as_opt() + .map(|i| { + if i.running { + println!("refresh thread running. not setting new nonce or csrf"); + return; + } - let need_verifier = match i.pkce_verifier { - None => true, - _ => false, - }; + let need_verifier = match i.pkce_verifier { + None => true, + _ => false, + }; - let csrf_diff = if let Some(csrf) = i.csrf_token.clone() { - if *csrf.secret() != csrf_token { - true + let csrf_diff = if let Some(csrf) = i.csrf_token.clone() { + if *csrf.secret() != csrf_token { + true + } else { + false + } } else { false - } - } else { - false - }; + }; - let nonce_diff = if let Some(n) = i.nonce.clone() { - if *n.secret() != nonce { - true + let nonce_diff = if let Some(n) = i.nonce.clone() { + if *n.secret() != nonce { + true + } else { + false + } } else { false + }; + + if need_verifier || csrf_diff || nonce_diff { + let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256(); + let r = i.oidc_client.as_ref().map(|c| { + let (auth_url, csrf_token, nonce) = c + .authorize_url( + AuthenticationFlow::::AuthorizationCode, + csrf_func(csrf_token), + nonce_func(nonce), + ) + .add_scope(Scope::new("profile".to_string())) + .add_scope(Scope::new("email".to_string())) + .add_scope(Scope::new("offline_access".to_string())) + .add_scope(Scope::new("openid".to_string())) + .set_pkce_challenge(pkce_challenge) + .url(); + + (auth_url, csrf_token, nonce) + }); + + if let Some(r) = r { + i.url = Some(r.0); + i.csrf_token = Some(r.1); + i.nonce = Some(r.2); + i.pkce_verifier = Some(pkce_verifier); + } } - } else { - false - }; - - if need_verifier || csrf_diff || nonce_diff { - let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256(); - let r = i.oidc_client.as_ref().map(|c| { - let (auth_url, csrf_token, nonce) = c - .authorize_url( - AuthenticationFlow::::AuthorizationCode, - csrf_func(csrf_token), - nonce_func(nonce), - ) - .add_scope(Scope::new("profile".to_string())) - .add_scope(Scope::new("email".to_string())) - .add_scope(Scope::new("offline_access".to_string())) - .add_scope(Scope::new("openid".to_string())) - .set_pkce_challenge(pkce_challenge) - .url(); - - (auth_url, csrf_token, nonce) - }); - - if let Some(r) = r { - i.url = Some(r.0); - i.csrf_token = Some(r.1); - i.nonce = Some(r.2); - i.pkce_verifier = Some(pkce_verifier); - } - } - }); + }); } pub fn auth_url(&self) -> String { - let url = (*self.inner.lock().expect("can't lock inner")).as_opt().map(|i| { - match i.url.clone() { + let url = (*self.inner.lock().expect("can't lock inner")) + .as_opt() + .map(|i| match i.url.clone() { Some(u) => u.to_string(), _ => "".to_string(), - } - }); + }); match url { Some(url) => url.to_string(), @@ -423,13 +487,14 @@ impl ZeroIDC { let token_response = i.oidc_client.as_ref().map(|c| { println!("auth code: {}", code); - let r = c.exchange_code(AuthorizationCode::new(code.to_string())) + let r = c + .exchange_code(AuthorizationCode::new(code.to_string())) .set_pkce_verifier(verifier) .request(http_client); // validate the token hashes match r { - Ok(res) =>{ + Ok(res) => { let n = match i.nonce.clone() { Some(n) => n, None => { @@ -437,7 +502,7 @@ impl ZeroIDC { return None; } }; - + let id = match res.id_token() { Some(t) => t, None => { @@ -463,7 +528,10 @@ impl ZeroIDC { }; if let Some(expected_hash) = claims.access_token_hash() { - let actual_hash = match AccessTokenHash::from_token(res.access_token(), &signing_algo) { + let actual_hash = match AccessTokenHash::from_token( + res.access_token(), + &signing_algo, + ) { Ok(h) => h, Err(e) => { println!("Error hashing access token: {}", e); @@ -477,77 +545,80 @@ impl ZeroIDC { } } Some(res) - }, + } Err(e) => { println!("token response error: {:?}", e.to_string()); println!("\t {:?}", e.source()); - - return None; - }, + + None + } } }); - + if let Some(Some(tok)) = token_response { let id_token = tok.id_token().unwrap(); - #[cfg(debug_assertions)] { + #[cfg(debug_assertions)] + { println!("ID token: {}", id_token.to_string()); } let mut split = "".to_string(); - match i.csrf_token.clone() { - Some(csrf_token) => { - split = csrf_token.secret().to_owned(); - }, - _ => (), + if let Some(tok) = i.csrf_token.clone() { + split = tok.secret().to_owned(); } - let split = split.split("_").collect::>(); - + let split = split.split('_').collect::>(); + if split.len() == 2 { - let params = [("id_token", id_token.to_string()),("state", split[0].to_string())]; + let params = [ + ("id_token", id_token.to_string()), + ("state", split[0].to_string()), + ]; let client = reqwest::blocking::Client::new(); - let res = client.post(i.auth_endpoint.clone()) - .form(¶ms) - .send(); + let res = client.post(i.auth_endpoint.clone()).form(¶ms).send(); match res { Ok(res) => { - #[cfg(debug_assertions)] { + #[cfg(debug_assertions)] + { println!("hit url: {}", res.url().as_str()); println!("Status: {}", res.status()); } let idt = &id_token.to_string(); - let t: Result>, jwt::Error>= - Token::parse_unverified(idt); - + let t: Result< + Token>, + jwt::Error, + > = Token::parse_unverified(idt); + if let Ok(t) = t { let claims = t.claims().registered.clone(); match claims.expiration { Some(exp) => { i.exp_time = exp; println!("Set exp time to: {:?}", i.exp_time); - }, + } None => { panic!("expiration is None. This shouldn't happen"); } } - } + } i.access_token = Some(tok.access_token().clone()); if let Some(t) = tok.refresh_token() { i.refresh_token = Some(t.clone()); should_start = true; } - #[cfg(debug_assertions)] { + #[cfg(debug_assertions)] + { let access_token = tok.access_token(); println!("Access Token: {}", access_token.secret()); let refresh_token = tok.refresh_token(); println!("Refresh Token: {}", refresh_token.unwrap().secret()); } - + let bytes = match res.bytes() { Ok(bytes) => bytes, Err(_) => Bytes::from(""), @@ -558,38 +629,37 @@ impl ZeroIDC { Err(_) => "".to_string(), }; - return Ok(bytes); - }, + Ok(bytes) + } Err(res) => { println!("error result: {}", res); println!("hit url: {}", res.url().unwrap().as_str()); println!("Status: {}", res.status().unwrap()); - println!("Post error: {}", res.to_string()); + println!("Post error: {}", res); i.exp_time = 0; - return Err(SSOExchangeError::new("error from central endpoint".to_string())); + Err(SSOExchangeError::new( + "error from central endpoint".to_string(), + )) } } } else { - return Err(SSOExchangeError::new("error splitting state token".to_string())); + Err(SSOExchangeError::new( + "error splitting state token".to_string(), + )) } } else { - return Err(SSOExchangeError::new("invalid token response".to_string())); + Err(SSOExchangeError::new("invalid token response".to_string())) } } else { - return Err(SSOExchangeError::new("invalid pkce verifier".to_string())); + Err(SSOExchangeError::new("invalid pkce verifier".to_string())) } - }); if should_start { self.start(); } match res { - Some(res) => { - return res; - }, - _ => { - return Err(SSOExchangeError::new("invalid result".to_string())); - }, - }; + Some(res) => res, + _ => Err(SSOExchangeError::new("invalid result".to_string())), + } } }