rename client_id in pool to object_id (#2673)

* rename client_id in pool to object_id

* fix tests

* print out the content body when receiving an error response in the agent

* fix test

* Apply suggestions from code review

* Update src/ApiService/ApiService/Functions/AgentRegistration.cs

* format

* cleanup

* format

* address pr comment
This commit is contained in:
Cheick Keita
2022-12-12 19:39:49 -08:00
committed by GitHub
parent 3cf09c6a40
commit 4c1adb6e96
15 changed files with 120 additions and 79 deletions

View File

@ -73,7 +73,6 @@ public class AgentRegistration {
var baseAddress = _context.Creds.GetInstanceUrl(); var baseAddress = _context.Creds.GetInstanceUrl();
var eventsUrl = new Uri(baseAddress, "/api/agents/events"); var eventsUrl = new Uri(baseAddress, "/api/agents/events");
var commandsUrl = new Uri(baseAddress, "/api/agents/commands"); var commandsUrl = new Uri(baseAddress, "/api/agents/commands");
var workQueue = await _context.Queue.GetQueueSas( var workQueue = await _context.Queue.GetQueueSas(
_context.PoolOperations.GetPoolQueue(pool.PoolId), _context.PoolOperations.GetPoolQueue(pool.PoolId),
StorageType.Corpus, StorageType.Corpus,

View File

@ -67,7 +67,7 @@ public class Pool {
Errors: new string[] { "pool with that name already exists" }), Errors: new string[] { "pool with that name already exists" }),
"PoolCreate"); "PoolCreate");
} }
var newPool = await _context.PoolOperations.Create(name: create.Name, os: create.Os, architecture: create.Arch, managed: create.Managed, clientId: create.ClientId); var newPool = await _context.PoolOperations.Create(name: create.Name, os: create.Os, architecture: create.Arch, managed: create.Managed, objectId: create.ObjectId);
return await RequestHandling.Ok(req, await Populate(PoolToPoolResponse(newPool), true)); return await RequestHandling.Ok(req, await Populate(PoolToPoolResponse(newPool), true));
} }
@ -106,7 +106,7 @@ public class Pool {
PoolId: p.PoolId, PoolId: p.PoolId,
Os: p.Os, Os: p.Os,
State: p.State, State: p.State,
ClientId: p.ClientId, ObjectId: p.ObjectId,
Managed: p.Managed, Managed: p.Managed,
Arch: p.Arch, Arch: p.Arch,
Nodes: p.Nodes, Nodes: p.Nodes,

View File

@ -645,7 +645,7 @@ public record Pool(
bool Managed, bool Managed,
Architecture Arch, Architecture Arch,
PoolState State, PoolState State,
Guid? ClientId = null Guid? ObjectId = null
) : StatefulEntityBase<PoolState>(State) { ) : StatefulEntityBase<PoolState>(State) {
public List<Node>? Nodes { get; set; } public List<Node>? Nodes { get; set; }
public AgentConfig? Config { get; set; } public AgentConfig? Config { get; set; }

View File

@ -263,7 +263,7 @@ public record PoolCreate(
[property: Required] Os Os, [property: Required] Os Os,
[property: Required] Architecture Arch, [property: Required] Architecture Arch,
[property: Required] bool Managed, [property: Required] bool Managed,
Guid? ClientId = null Guid? ObjectId = null
) : BaseRequest; ) : BaseRequest;
public record WebhookCreate( public record WebhookCreate(

View File

@ -114,7 +114,7 @@ public record PoolGetResult(
bool Managed, bool Managed,
Architecture Arch, Architecture Arch,
PoolState State, PoolState State,
Guid? ClientId, Guid? ObjectId,
List<Node>? Nodes, List<Node>? Nodes,
AgentConfig? Config, AgentConfig? Config,
List<WorkSetSummary>? WorkQueue, List<WorkSetSummary>? WorkQueue,

View File

@ -50,6 +50,8 @@ public interface IServiceConfig {
// multiple instances to run against the same storage account, which // multiple instances to run against the same storage account, which
// is useful for things like integration testing. // is useful for things like integration testing.
public string OneFuzzStoragePrefix { get; } public string OneFuzzStoragePrefix { get; }
public Uri OneFuzzBaseAddress { get; }
} }
public class ServiceConfiguration : IServiceConfig { public class ServiceConfiguration : IServiceConfig {
@ -134,4 +136,12 @@ public class ServiceConfiguration : IServiceConfig {
public string OneFuzzNodeDisposalStrategy { get => GetEnv("ONEFUZZ_NODE_DISPOSAL_STRATEGY") ?? "scale_in"; } public string OneFuzzNodeDisposalStrategy { get => GetEnv("ONEFUZZ_NODE_DISPOSAL_STRATEGY") ?? "scale_in"; }
public string OneFuzzStoragePrefix => ""; // in production we never prefix the tables public string OneFuzzStoragePrefix => ""; // in production we never prefix the tables
public Uri OneFuzzBaseAddress {
get {
var hostName = Environment.GetEnvironmentVariable("WEBSITE_HOSTNAME");
var scheme = Environment.GetEnvironmentVariable("HTTPS") != null ? "https" : "http";
return new Uri($"{scheme}://{hostName}");
}
}
} }

View File

@ -46,9 +46,12 @@ public class EndpointAuthorization : IEndpointAuthorization {
} }
var token = tokenResult.OkV.UserInfo; var token = tokenResult.OkV.UserInfo;
if (await IsUser(tokenResult.OkV)) {
var (isAgent, reason) = await IsAgent(tokenResult.OkV);
if (!isAgent) {
if (!allowUser) { if (!allowUser) {
return await Reject(req, token); return await Reject(req, token, "endpoint not allowed for users");
} }
var access = await CheckAccess(req); var access = await CheckAccess(req);
@ -57,26 +60,24 @@ public class EndpointAuthorization : IEndpointAuthorization {
} }
} }
if (await IsAgent(tokenResult.OkV) && !allowAgent) {
return await Reject(req, token); if (isAgent && !allowAgent) {
return await Reject(req, token, reason);
} }
return await method(req); return await method(req);
} }
public async Async.Task<bool> IsUser(UserAuthInfo tokenData) {
return !await IsAgent(tokenData);
}
public async Async.Task<HttpResponseData> Reject(HttpRequestData req, UserInfo token) { public async Async.Task<HttpResponseData> Reject(HttpRequestData req, UserInfo token, String? reason = null) {
var body = await req.ReadAsStringAsync(); var body = await req.ReadAsStringAsync();
_log.Error($"reject token. url:{req.Url:Tag:Url} token:{token:Tag:Token} body:{body:Tag:Body}"); _log.Error($"reject token. reason:{reason} url:{req.Url:Tag:Url} token:{token:Tag:Token} body:{body:Tag:Body}");
return await _context.RequestHandling.NotOk( return await _context.RequestHandling.NotOk(
req, req,
new Error( new Error(
ErrorCode.UNAUTHORIZED, ErrorCode.UNAUTHORIZED,
new string[] { "Unrecognized agent" } new string[] { reason ?? "Unrecognized agent" }
), ),
"token verification", "token verification",
HttpStatusCode.Unauthorized HttpStatusCode.Unauthorized
@ -186,9 +187,10 @@ public class EndpointAuthorization : IEndpointAuthorization {
return null; return null;
} }
public async Async.Task<bool> IsAgent(UserAuthInfo authInfo) {
public async Async.Task<(bool, string)> IsAgent(UserAuthInfo authInfo) {
if (!AgentRoles.Overlaps(authInfo.Roles)) { if (!AgentRoles.Overlaps(authInfo.Roles)) {
return false; return (false, "no agent role");
} }
var tokenData = authInfo.UserInfo; var tokenData = authInfo.UserInfo;
@ -196,24 +198,24 @@ public class EndpointAuthorization : IEndpointAuthorization {
if (tokenData.ObjectId != null) { if (tokenData.ObjectId != null) {
var scalesets = _context.ScalesetOperations.GetByObjectId(tokenData.ObjectId.Value); var scalesets = _context.ScalesetOperations.GetByObjectId(tokenData.ObjectId.Value);
if (await scalesets.AnyAsync()) { if (await scalesets.AnyAsync()) {
return true; return (true, string.Empty);
} }
var principalId = await _context.Creds.GetScalesetPrincipalId(); var principalId = await _context.Creds.GetScalesetPrincipalId();
if (principalId == tokenData.ObjectId) { if (principalId == tokenData.ObjectId) {
return true; return (true, string.Empty);
} }
} }
if (!tokenData.ApplicationId.HasValue) { if (!tokenData.ObjectId.HasValue) {
return false; return (false, "no object id in token");
} }
var pools = _context.PoolOperations.GetByClientId(tokenData.ApplicationId.Value); var pools = _context.PoolOperations.GetByObjectId(tokenData.ObjectId.Value);
if (await pools.AnyAsync()) { if (await pools.AnyAsync()) {
return true; return (true, string.Empty);
} }
return false; return (false, "no matching scaleset or pool");
} }
} }

View File

@ -6,14 +6,14 @@ public interface IPoolOperations : IStatefulOrm<Pool, PoolState> {
Async.Task<OneFuzzResult<Pool>> GetByName(PoolName poolName); Async.Task<OneFuzzResult<Pool>> GetByName(PoolName poolName);
Async.Task<OneFuzzResult<Pool>> GetById(Guid poolId); Async.Task<OneFuzzResult<Pool>> GetById(Guid poolId);
Task<bool> ScheduleWorkset(Pool pool, WorkSet workSet); Task<bool> ScheduleWorkset(Pool pool, WorkSet workSet);
IAsyncEnumerable<Pool> GetByClientId(Guid clientId); IAsyncEnumerable<Pool> GetByObjectId(Guid objectId);
string GetPoolQueue(Guid poolId); string GetPoolQueue(Guid poolId);
Async.Task<List<ScalesetSummary>> GetScalesetSummary(PoolName name); Async.Task<List<ScalesetSummary>> GetScalesetSummary(PoolName name);
Async.Task<List<WorkSetSummary>> GetWorkQueue(Guid poolId, PoolState state); Async.Task<List<WorkSetSummary>> GetWorkQueue(Guid poolId, PoolState state);
IAsyncEnumerable<Pool> SearchStates(IEnumerable<PoolState> states); IAsyncEnumerable<Pool> SearchStates(IEnumerable<PoolState> states);
Async.Task<Pool> SetShutdown(Pool pool, bool Now); Async.Task<Pool> SetShutdown(Pool pool, bool Now);
Async.Task<Pool> Create(PoolName name, Os os, Architecture architecture, bool managed, Guid? clientId = null); Async.Task<Pool> Create(PoolName name, Os os, Architecture architecture, bool managed, Guid? objectId = null);
new Async.Task Delete(Pool pool); new Async.Task Delete(Pool pool);
// state transitions: // state transitions:
@ -32,7 +32,7 @@ public class PoolOperations : StatefulOrm<Pool, PoolState, PoolOperations>, IPoo
} }
public async Async.Task<Pool> Create(PoolName name, Os os, Architecture architecture, bool managed, Guid? clientId = null) { public async Async.Task<Pool> Create(PoolName name, Os os, Architecture architecture, bool managed, Guid? objectId = null) {
var newPool = new Service.Pool( var newPool = new Service.Pool(
PoolId: Guid.NewGuid(), PoolId: Guid.NewGuid(),
State: PoolState.Init, State: PoolState.Init,
@ -40,7 +40,7 @@ public class PoolOperations : StatefulOrm<Pool, PoolState, PoolOperations>, IPoo
Os: os, Os: os,
Managed: managed, Managed: managed,
Arch: architecture, Arch: architecture,
ClientId: clientId); ObjectId: objectId);
var r = await Insert(newPool); var r = await Insert(newPool);
if (!r.IsOk) { if (!r.IsOk) {
@ -87,8 +87,8 @@ public class PoolOperations : StatefulOrm<Pool, PoolState, PoolOperations>, IPoo
return await _context.Queue.QueueObject(GetPoolQueue(pool.PoolId), workSet, StorageType.Corpus); return await _context.Queue.QueueObject(GetPoolQueue(pool.PoolId), workSet, StorageType.Corpus);
} }
public IAsyncEnumerable<Pool> GetByClientId(Guid clientId) { public IAsyncEnumerable<Pool> GetByObjectId(Guid objectId) {
return QueryAsync(filter: $"client_id eq '{clientId}'"); return QueryAsync(filter: $"object_id eq '{objectId}'");
} }
public string GetPoolQueue(Guid poolId) public string GetPoolQueue(Guid poolId)

View File

@ -64,5 +64,6 @@ public sealed class TestServiceConfiguration : IServiceConfig {
public string? OneFuzzAllowOutdatedAgent => throw new NotImplementedException(); public string? OneFuzzAllowOutdatedAgent => throw new NotImplementedException();
public string? AppConfigurationEndpoint => throw new NotImplementedException(); public string? AppConfigurationEndpoint => throw new NotImplementedException();
public Uri OneFuzzBaseAddress { get => new Uri("http://test"); }
public string? AppConfigurationConnectionString => throw new NotImplementedException(); public string? AppConfigurationConnectionString => throw new NotImplementedException();
} }

1
src/agent/Cargo.lock generated
View File

@ -2616,6 +2616,7 @@ dependencies = [
"log", "log",
"onefuzz-telemetry", "onefuzz-telemetry",
"reqwest", "reqwest",
"thiserror",
"tokio", "tokio",
"wiremock", "wiremock",
] ]

View File

@ -12,6 +12,7 @@ backoff = { version = "0.4", features = ["tokio"] }
log = "0.4" log = "0.4"
onefuzz-telemetry = { path = "../onefuzz-telemetry" } onefuzz-telemetry = { path = "../onefuzz-telemetry" }
reqwest = { version = "0.11", features = ["json", "stream", "native-tls-vendored"], default-features=false } reqwest = { version = "0.11", features = ["json", "stream", "native-tls-vendored"], default-features=false }
thiserror = "1.0"
[dev-dependencies] [dev-dependencies]
tokio = { version = "1.16", features = ["macros"] } tokio = { version = "1.16", features = ["macros"] }

View File

@ -1,7 +1,7 @@
// Copyright (c) Microsoft Corporation. // Copyright (c) Microsoft Corporation.
// Licensed under the MIT License. // Licensed under the MIT License.
use anyhow::{Context, Result}; use anyhow::{anyhow, Context, Result};
use async_trait::async_trait; use async_trait::async_trait;
use backoff::{self, future::retry_notify, ExponentialBackoff}; use backoff::{self, future::retry_notify, ExponentialBackoff};
use onefuzz_telemetry::debug; use onefuzz_telemetry::debug;
@ -20,6 +20,32 @@ pub enum RetryCheck {
Succeed, Succeed,
} }
#[derive(Debug, thiserror::Error)]
pub enum ReqwestRetryError {
#[error("request failed with status code {status_code} and url {url}")]
Response {
status_code: StatusCode,
url: reqwest::Url,
source: anyhow::Error,
},
#[error("request failed to be sent")]
SendError { source: anyhow::Error },
}
impl ReqwestRetryError {
fn response_error(status_code: StatusCode, url: reqwest::Url, source: anyhow::Error) -> Self {
Self::Response {
status_code,
url,
source,
}
}
fn send_error(source: anyhow::Error) -> Self {
Self::SendError { source }
}
}
fn always_retry(_: StatusCode) -> RetryCheck { fn always_retry(_: StatusCode) -> RetryCheck {
RetryCheck::Retry RetryCheck::Retry
} }
@ -51,7 +77,8 @@ where
let counter = AtomicUsize::new(0); let counter = AtomicUsize::new(0);
let op = || async { let op = || async {
let attempt_count = counter.fetch_add(1, Ordering::SeqCst); let attempt_count = counter.fetch_add(1, Ordering::SeqCst);
let request = build_request().map_err(|err| backoff::Error::Permanent(Err(err)))?; let request = build_request()
.map_err(|err| backoff::Error::Permanent(ReqwestRetryError::send_error(err)))?;
let result = request let result = request
.send() .send()
.await .await
@ -59,9 +86,9 @@ where
match result { match result {
Err(x) => { Err(x) => {
if attempt_count >= max_retry { if attempt_count >= max_retry {
Err(backoff::Error::Permanent(Err(x))) Err(backoff::Error::Permanent(ReqwestRetryError::send_error(x)))
} else { } else {
Err(backoff::Error::transient(Err(x))) Err(backoff::Error::transient(ReqwestRetryError::send_error(x)))
} }
} }
Ok(x) => { Ok(x) => {
@ -70,31 +97,40 @@ where
} else { } else {
let status = x.status(); let status = x.status();
let result = check_status(status); let result = check_status(status);
let url = x.url().clone();
match result { match result {
RetryCheck::Succeed => Ok(x), RetryCheck::Succeed => Ok(x),
RetryCheck::Fail => { RetryCheck::Fail => {
match x.error_for_status().with_context(|| { let content = x.text().await.unwrap_or_else(|_| "".to_string());
format!("request attempt {} failed", attempt_count + 1) let e = anyhow!(
}) { "request attempt {} failed with status code {} and content {}",
// the is_success check earlier should have taken care of this already. attempt_count + 1,
Ok(x) => Ok(x), status,
Err(as_err) => Err(backoff::Error::Permanent(Err(as_err))), content
} );
Err(backoff::Error::Permanent(
ReqwestRetryError::response_error(status, url, e),
))
} }
RetryCheck::Retry => { RetryCheck::Retry => {
match x.error_for_status().with_context(|| { let content = x.text().await.unwrap_or_else(|_| "".to_string());
format!("request attempt {} failed", attempt_count + 1) let e = anyhow!(
}) { "request attempt {} failed with status code {} and content {}",
// the is_success check earlier should have taken care of this already. attempt_count + 1,
Ok(x) => Ok(x), status,
Err(as_err) => { content
if attempt_count >= max_retry { );
Err(backoff::Error::Permanent(Err(as_err)))
} else { if attempt_count >= max_retry {
Err(backoff::Error::transient(Err(as_err))) Err(backoff::Error::Permanent(
} ReqwestRetryError::response_error(status, url, e),
} ))
} else {
Err(backoff::Error::transient(
ReqwestRetryError::response_error(status, url, e),
))
} }
} }
} }
@ -109,20 +145,13 @@ where
..ExponentialBackoff::default() ..ExponentialBackoff::default()
}, },
op, op,
|err: Result<Response, anyhow::Error>, dur| match err { |err: ReqwestRetryError, dur| debug!("request attempt failed after {:?}: {:?}", dur, err),
Ok(response) => {
if let Err(err) = response.error_for_status() {
debug!("request attempt failed after {:?}: {:?}", dur, err)
}
}
err => debug!("request attempt failed after {:?}: {:?}", dur, err),
},
) )
.await; .await;
match result { match result {
Ok(response) | Err(Ok(response)) => Ok(response), Ok(response) => Ok(response),
Err(Err(err)) => Err(err), Err(error) => Err(error.into()),
} }
} }
@ -174,18 +203,16 @@ impl SendRetry for reqwest::RequestBuilder {
pub fn is_auth_failure(response: &Result<Response>) -> bool { pub fn is_auth_failure(response: &Result<Response>) -> bool {
// Check both cases to support `error_for_status()`. // Check both cases to support `error_for_status()`.
match response { match response {
Ok(response) => { Ok(response) => response.status() == StatusCode::UNAUTHORIZED,
return response.status() == StatusCode::UNAUTHORIZED; Err(error) => match error.downcast_ref::<ReqwestRetryError>() {
} Some(ReqwestRetryError::Response {
Err(error) => { status_code,
if let Some(error) = error.downcast_ref::<reqwest::Error>() { url: _,
if let Some(status) = error.status() { source: _,
return status == StatusCode::UNAUTHORIZED; }) => status_code == &StatusCode::UNAUTHORIZED,
} _ => false,
} },
}
} }
false
} }
#[cfg(test)] #[cfg(test)]

View File

@ -1239,7 +1239,7 @@ class Pool(Endpoint):
self, self,
name: str, name: str,
os: enums.OS, os: enums.OS,
client_id: Optional[UUID] = None, object_id: Optional[UUID] = None,
*, *,
unmanaged: bool = False, unmanaged: bool = False,
arch: enums.Architecture = enums.Architecture.x86_64, arch: enums.Architecture = enums.Architecture.x86_64,
@ -1256,7 +1256,7 @@ class Pool(Endpoint):
"POST", "POST",
models.Pool, models.Pool,
data=requests.PoolCreate( data=requests.PoolCreate(
name=name, os=os, arch=arch, managed=managed, client_id=client_id name=name, os=os, arch=arch, managed=managed, object_id=object_id
), ),
) )

View File

@ -92,7 +92,7 @@ class PoolCreate(BaseRequest):
os: OS os: OS
arch: Architecture arch: Architecture
managed: bool managed: bool
client_id: Optional[UUID] object_id: Optional[UUID]
autoscale: Optional[AutoScaleConfig] autoscale: Optional[AutoScaleConfig]