diff --git a/src/ApiService/ApiService/Functions/AgentRegistration.cs b/src/ApiService/ApiService/Functions/AgentRegistration.cs index 00ccd2925..0236372d8 100644 --- a/src/ApiService/ApiService/Functions/AgentRegistration.cs +++ b/src/ApiService/ApiService/Functions/AgentRegistration.cs @@ -73,7 +73,6 @@ public class AgentRegistration { var baseAddress = _context.Creds.GetInstanceUrl(); var eventsUrl = new Uri(baseAddress, "/api/agents/events"); var commandsUrl = new Uri(baseAddress, "/api/agents/commands"); - var workQueue = await _context.Queue.GetQueueSas( _context.PoolOperations.GetPoolQueue(pool.PoolId), StorageType.Corpus, diff --git a/src/ApiService/ApiService/Functions/Pool.cs b/src/ApiService/ApiService/Functions/Pool.cs index 6ca19ec64..497a8fd2c 100644 --- a/src/ApiService/ApiService/Functions/Pool.cs +++ b/src/ApiService/ApiService/Functions/Pool.cs @@ -67,7 +67,7 @@ public class Pool { Errors: new string[] { "pool with that name already exists" }), "PoolCreate"); } - var newPool = await _context.PoolOperations.Create(name: create.Name, os: create.Os, architecture: create.Arch, managed: create.Managed, clientId: create.ClientId); + var newPool = await _context.PoolOperations.Create(name: create.Name, os: create.Os, architecture: create.Arch, managed: create.Managed, objectId: create.ObjectId); return await RequestHandling.Ok(req, await Populate(PoolToPoolResponse(newPool), true)); } @@ -106,7 +106,7 @@ public class Pool { PoolId: p.PoolId, Os: p.Os, State: p.State, - ClientId: p.ClientId, + ObjectId: p.ObjectId, Managed: p.Managed, Arch: p.Arch, Nodes: p.Nodes, diff --git a/src/ApiService/ApiService/OneFuzzTypes/Model.cs b/src/ApiService/ApiService/OneFuzzTypes/Model.cs index db425d281..ff274f0a3 100644 --- a/src/ApiService/ApiService/OneFuzzTypes/Model.cs +++ b/src/ApiService/ApiService/OneFuzzTypes/Model.cs @@ -645,7 +645,7 @@ public record Pool( bool Managed, Architecture Arch, PoolState State, - Guid? ClientId = null + Guid? ObjectId = null ) : StatefulEntityBase(State) { public List? Nodes { get; set; } public AgentConfig? Config { get; set; } diff --git a/src/ApiService/ApiService/OneFuzzTypes/Requests.cs b/src/ApiService/ApiService/OneFuzzTypes/Requests.cs index f00ec5d33..062abecce 100644 --- a/src/ApiService/ApiService/OneFuzzTypes/Requests.cs +++ b/src/ApiService/ApiService/OneFuzzTypes/Requests.cs @@ -263,7 +263,7 @@ public record PoolCreate( [property: Required] Os Os, [property: Required] Architecture Arch, [property: Required] bool Managed, - Guid? ClientId = null + Guid? ObjectId = null ) : BaseRequest; public record WebhookCreate( diff --git a/src/ApiService/ApiService/OneFuzzTypes/Responses.cs b/src/ApiService/ApiService/OneFuzzTypes/Responses.cs index 2dbaa18fd..b5391f7ad 100644 --- a/src/ApiService/ApiService/OneFuzzTypes/Responses.cs +++ b/src/ApiService/ApiService/OneFuzzTypes/Responses.cs @@ -114,7 +114,7 @@ public record PoolGetResult( bool Managed, Architecture Arch, PoolState State, - Guid? ClientId, + Guid? ObjectId, List? Nodes, AgentConfig? Config, List? WorkQueue, diff --git a/src/ApiService/ApiService/ServiceConfiguration.cs b/src/ApiService/ApiService/ServiceConfiguration.cs index bc8eb7ba9..57f056122 100644 --- a/src/ApiService/ApiService/ServiceConfiguration.cs +++ b/src/ApiService/ApiService/ServiceConfiguration.cs @@ -50,6 +50,8 @@ public interface IServiceConfig { // multiple instances to run against the same storage account, which // is useful for things like integration testing. public string OneFuzzStoragePrefix { get; } + + public Uri OneFuzzBaseAddress { get; } } public class ServiceConfiguration : IServiceConfig { @@ -134,4 +136,12 @@ public class ServiceConfiguration : IServiceConfig { public string OneFuzzNodeDisposalStrategy { get => GetEnv("ONEFUZZ_NODE_DISPOSAL_STRATEGY") ?? "scale_in"; } public string OneFuzzStoragePrefix => ""; // in production we never prefix the tables + + public Uri OneFuzzBaseAddress { + get { + var hostName = Environment.GetEnvironmentVariable("WEBSITE_HOSTNAME"); + var scheme = Environment.GetEnvironmentVariable("HTTPS") != null ? "https" : "http"; + return new Uri($"{scheme}://{hostName}"); + } + } } diff --git a/src/ApiService/ApiService/onefuzzlib/EndpointAuthorization.cs b/src/ApiService/ApiService/onefuzzlib/EndpointAuthorization.cs index 6eccef84a..2abceb30d 100644 --- a/src/ApiService/ApiService/onefuzzlib/EndpointAuthorization.cs +++ b/src/ApiService/ApiService/onefuzzlib/EndpointAuthorization.cs @@ -46,9 +46,12 @@ public class EndpointAuthorization : IEndpointAuthorization { } var token = tokenResult.OkV.UserInfo; - if (await IsUser(tokenResult.OkV)) { + + var (isAgent, reason) = await IsAgent(tokenResult.OkV); + + if (!isAgent) { if (!allowUser) { - return await Reject(req, token); + return await Reject(req, token, "endpoint not allowed for users"); } var access = await CheckAccess(req); @@ -57,26 +60,24 @@ public class EndpointAuthorization : IEndpointAuthorization { } } - if (await IsAgent(tokenResult.OkV) && !allowAgent) { - return await Reject(req, token); + + if (isAgent && !allowAgent) { + return await Reject(req, token, reason); } return await method(req); } - public async Async.Task IsUser(UserAuthInfo tokenData) { - return !await IsAgent(tokenData); - } - public async Async.Task Reject(HttpRequestData req, UserInfo token) { + public async Async.Task Reject(HttpRequestData req, UserInfo token, String? reason = null) { var body = await req.ReadAsStringAsync(); - _log.Error($"reject token. url:{req.Url:Tag:Url} token:{token:Tag:Token} body:{body:Tag:Body}"); + _log.Error($"reject token. reason:{reason} url:{req.Url:Tag:Url} token:{token:Tag:Token} body:{body:Tag:Body}"); return await _context.RequestHandling.NotOk( req, new Error( ErrorCode.UNAUTHORIZED, - new string[] { "Unrecognized agent" } + new string[] { reason ?? "Unrecognized agent" } ), "token verification", HttpStatusCode.Unauthorized @@ -186,9 +187,10 @@ public class EndpointAuthorization : IEndpointAuthorization { return null; } - public async Async.Task IsAgent(UserAuthInfo authInfo) { + + public async Async.Task<(bool, string)> IsAgent(UserAuthInfo authInfo) { if (!AgentRoles.Overlaps(authInfo.Roles)) { - return false; + return (false, "no agent role"); } var tokenData = authInfo.UserInfo; @@ -196,24 +198,24 @@ public class EndpointAuthorization : IEndpointAuthorization { if (tokenData.ObjectId != null) { var scalesets = _context.ScalesetOperations.GetByObjectId(tokenData.ObjectId.Value); if (await scalesets.AnyAsync()) { - return true; + return (true, string.Empty); } var principalId = await _context.Creds.GetScalesetPrincipalId(); if (principalId == tokenData.ObjectId) { - return true; + return (true, string.Empty); } } - if (!tokenData.ApplicationId.HasValue) { - return false; + if (!tokenData.ObjectId.HasValue) { + return (false, "no object id in token"); } - var pools = _context.PoolOperations.GetByClientId(tokenData.ApplicationId.Value); + var pools = _context.PoolOperations.GetByObjectId(tokenData.ObjectId.Value); if (await pools.AnyAsync()) { - return true; + return (true, string.Empty); } - return false; + return (false, "no matching scaleset or pool"); } } diff --git a/src/ApiService/ApiService/onefuzzlib/PoolOperations.cs b/src/ApiService/ApiService/onefuzzlib/PoolOperations.cs index 09c0ec342..9a3662f3b 100644 --- a/src/ApiService/ApiService/onefuzzlib/PoolOperations.cs +++ b/src/ApiService/ApiService/onefuzzlib/PoolOperations.cs @@ -6,14 +6,14 @@ public interface IPoolOperations : IStatefulOrm { Async.Task> GetByName(PoolName poolName); Async.Task> GetById(Guid poolId); Task ScheduleWorkset(Pool pool, WorkSet workSet); - IAsyncEnumerable GetByClientId(Guid clientId); + IAsyncEnumerable GetByObjectId(Guid objectId); string GetPoolQueue(Guid poolId); Async.Task> GetScalesetSummary(PoolName name); Async.Task> GetWorkQueue(Guid poolId, PoolState state); IAsyncEnumerable SearchStates(IEnumerable states); Async.Task SetShutdown(Pool pool, bool Now); - Async.Task Create(PoolName name, Os os, Architecture architecture, bool managed, Guid? clientId = null); + Async.Task Create(PoolName name, Os os, Architecture architecture, bool managed, Guid? objectId = null); new Async.Task Delete(Pool pool); // state transitions: @@ -32,7 +32,7 @@ public class PoolOperations : StatefulOrm, IPoo } - public async Async.Task Create(PoolName name, Os os, Architecture architecture, bool managed, Guid? clientId = null) { + public async Async.Task Create(PoolName name, Os os, Architecture architecture, bool managed, Guid? objectId = null) { var newPool = new Service.Pool( PoolId: Guid.NewGuid(), State: PoolState.Init, @@ -40,7 +40,7 @@ public class PoolOperations : StatefulOrm, IPoo Os: os, Managed: managed, Arch: architecture, - ClientId: clientId); + ObjectId: objectId); var r = await Insert(newPool); if (!r.IsOk) { @@ -87,8 +87,8 @@ public class PoolOperations : StatefulOrm, IPoo return await _context.Queue.QueueObject(GetPoolQueue(pool.PoolId), workSet, StorageType.Corpus); } - public IAsyncEnumerable GetByClientId(Guid clientId) { - return QueryAsync(filter: $"client_id eq '{clientId}'"); + public IAsyncEnumerable GetByObjectId(Guid objectId) { + return QueryAsync(filter: $"object_id eq '{objectId}'"); } public string GetPoolQueue(Guid poolId) diff --git a/src/ApiService/IntegrationTests/Fakes/TestServiceConfiguration.cs b/src/ApiService/IntegrationTests/Fakes/TestServiceConfiguration.cs index 6be25ff65..72a1622d4 100644 --- a/src/ApiService/IntegrationTests/Fakes/TestServiceConfiguration.cs +++ b/src/ApiService/IntegrationTests/Fakes/TestServiceConfiguration.cs @@ -64,5 +64,6 @@ public sealed class TestServiceConfiguration : IServiceConfig { public string? OneFuzzAllowOutdatedAgent => throw new NotImplementedException(); public string? AppConfigurationEndpoint => throw new NotImplementedException(); + public Uri OneFuzzBaseAddress { get => new Uri("http://test"); } public string? AppConfigurationConnectionString => throw new NotImplementedException(); } diff --git a/src/agent/.gitignore b/src/agent/.gitignore index 41c40876d..8769e3d5f 100644 --- a/src/agent/.gitignore +++ b/src/agent/.gitignore @@ -1,2 +1,2 @@ target -.agent-run +.agent-run \ No newline at end of file diff --git a/src/agent/Cargo.lock b/src/agent/Cargo.lock index 65b6a90b6..efbf12dc9 100644 --- a/src/agent/Cargo.lock +++ b/src/agent/Cargo.lock @@ -2616,6 +2616,7 @@ dependencies = [ "log", "onefuzz-telemetry", "reqwest", + "thiserror", "tokio", "wiremock", ] diff --git a/src/agent/reqwest-retry/Cargo.toml b/src/agent/reqwest-retry/Cargo.toml index 25b24f324..8e73b3123 100644 --- a/src/agent/reqwest-retry/Cargo.toml +++ b/src/agent/reqwest-retry/Cargo.toml @@ -12,6 +12,7 @@ backoff = { version = "0.4", features = ["tokio"] } log = "0.4" onefuzz-telemetry = { path = "../onefuzz-telemetry" } reqwest = { version = "0.11", features = ["json", "stream", "native-tls-vendored"], default-features=false } +thiserror = "1.0" [dev-dependencies] tokio = { version = "1.16", features = ["macros"] } diff --git a/src/agent/reqwest-retry/src/lib.rs b/src/agent/reqwest-retry/src/lib.rs index 216681b3b..f2a80f17a 100644 --- a/src/agent/reqwest-retry/src/lib.rs +++ b/src/agent/reqwest-retry/src/lib.rs @@ -1,7 +1,7 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT License. -use anyhow::{Context, Result}; +use anyhow::{anyhow, Context, Result}; use async_trait::async_trait; use backoff::{self, future::retry_notify, ExponentialBackoff}; use onefuzz_telemetry::debug; @@ -20,6 +20,32 @@ pub enum RetryCheck { Succeed, } +#[derive(Debug, thiserror::Error)] +pub enum ReqwestRetryError { + #[error("request failed with status code {status_code} and url {url}")] + Response { + status_code: StatusCode, + url: reqwest::Url, + source: anyhow::Error, + }, + #[error("request failed to be sent")] + SendError { source: anyhow::Error }, +} + +impl ReqwestRetryError { + fn response_error(status_code: StatusCode, url: reqwest::Url, source: anyhow::Error) -> Self { + Self::Response { + status_code, + url, + source, + } + } + + fn send_error(source: anyhow::Error) -> Self { + Self::SendError { source } + } +} + fn always_retry(_: StatusCode) -> RetryCheck { RetryCheck::Retry } @@ -51,7 +77,8 @@ where let counter = AtomicUsize::new(0); let op = || async { let attempt_count = counter.fetch_add(1, Ordering::SeqCst); - let request = build_request().map_err(|err| backoff::Error::Permanent(Err(err)))?; + let request = build_request() + .map_err(|err| backoff::Error::Permanent(ReqwestRetryError::send_error(err)))?; let result = request .send() .await @@ -59,9 +86,9 @@ where match result { Err(x) => { if attempt_count >= max_retry { - Err(backoff::Error::Permanent(Err(x))) + Err(backoff::Error::Permanent(ReqwestRetryError::send_error(x))) } else { - Err(backoff::Error::transient(Err(x))) + Err(backoff::Error::transient(ReqwestRetryError::send_error(x))) } } Ok(x) => { @@ -70,31 +97,40 @@ where } else { let status = x.status(); let result = check_status(status); + let url = x.url().clone(); match result { RetryCheck::Succeed => Ok(x), RetryCheck::Fail => { - match x.error_for_status().with_context(|| { - format!("request attempt {} failed", attempt_count + 1) - }) { - // the is_success check earlier should have taken care of this already. - Ok(x) => Ok(x), - Err(as_err) => Err(backoff::Error::Permanent(Err(as_err))), - } + let content = x.text().await.unwrap_or_else(|_| "".to_string()); + let e = anyhow!( + "request attempt {} failed with status code {} and content {}", + attempt_count + 1, + status, + content + ); + + Err(backoff::Error::Permanent( + ReqwestRetryError::response_error(status, url, e), + )) } RetryCheck::Retry => { - match x.error_for_status().with_context(|| { - format!("request attempt {} failed", attempt_count + 1) - }) { - // the is_success check earlier should have taken care of this already. - Ok(x) => Ok(x), - Err(as_err) => { - if attempt_count >= max_retry { - Err(backoff::Error::Permanent(Err(as_err))) - } else { - Err(backoff::Error::transient(Err(as_err))) - } - } + let content = x.text().await.unwrap_or_else(|_| "".to_string()); + let e = anyhow!( + "request attempt {} failed with status code {} and content {}", + attempt_count + 1, + status, + content + ); + + if attempt_count >= max_retry { + Err(backoff::Error::Permanent( + ReqwestRetryError::response_error(status, url, e), + )) + } else { + Err(backoff::Error::transient( + ReqwestRetryError::response_error(status, url, e), + )) } } } @@ -109,20 +145,13 @@ where ..ExponentialBackoff::default() }, op, - |err: Result, dur| match err { - Ok(response) => { - if let Err(err) = response.error_for_status() { - debug!("request attempt failed after {:?}: {:?}", dur, err) - } - } - err => debug!("request attempt failed after {:?}: {:?}", dur, err), - }, + |err: ReqwestRetryError, dur| debug!("request attempt failed after {:?}: {:?}", dur, err), ) .await; match result { - Ok(response) | Err(Ok(response)) => Ok(response), - Err(Err(err)) => Err(err), + Ok(response) => Ok(response), + Err(error) => Err(error.into()), } } @@ -174,18 +203,16 @@ impl SendRetry for reqwest::RequestBuilder { pub fn is_auth_failure(response: &Result) -> bool { // Check both cases to support `error_for_status()`. match response { - Ok(response) => { - return response.status() == StatusCode::UNAUTHORIZED; - } - Err(error) => { - if let Some(error) = error.downcast_ref::() { - if let Some(status) = error.status() { - return status == StatusCode::UNAUTHORIZED; - } - } - } + Ok(response) => response.status() == StatusCode::UNAUTHORIZED, + Err(error) => match error.downcast_ref::() { + Some(ReqwestRetryError::Response { + status_code, + url: _, + source: _, + }) => status_code == &StatusCode::UNAUTHORIZED, + _ => false, + }, } - false } #[cfg(test)] diff --git a/src/cli/onefuzz/api.py b/src/cli/onefuzz/api.py index 51aad410b..f7659356d 100644 --- a/src/cli/onefuzz/api.py +++ b/src/cli/onefuzz/api.py @@ -1239,7 +1239,7 @@ class Pool(Endpoint): self, name: str, os: enums.OS, - client_id: Optional[UUID] = None, + object_id: Optional[UUID] = None, *, unmanaged: bool = False, arch: enums.Architecture = enums.Architecture.x86_64, @@ -1256,7 +1256,7 @@ class Pool(Endpoint): "POST", models.Pool, data=requests.PoolCreate( - name=name, os=os, arch=arch, managed=managed, client_id=client_id + name=name, os=os, arch=arch, managed=managed, object_id=object_id ), ) diff --git a/src/pytypes/onefuzztypes/requests.py b/src/pytypes/onefuzztypes/requests.py index 54bad0ade..1bdf37dff 100644 --- a/src/pytypes/onefuzztypes/requests.py +++ b/src/pytypes/onefuzztypes/requests.py @@ -92,7 +92,7 @@ class PoolCreate(BaseRequest): os: OS arch: Architecture managed: bool - client_id: Optional[UUID] + object_id: Optional[UUID] autoscale: Optional[AutoScaleConfig]