rename client_id in pool to object_id (#2673)

* rename client_id in pool to object_id

* fix tests

* print out the content body when receiving an error response in the agent

* fix test

* Apply suggestions from code review

* Update src/ApiService/ApiService/Functions/AgentRegistration.cs

* format

* cleanup

* format

* address pr comment
This commit is contained in:
Cheick Keita
2022-12-12 19:39:49 -08:00
committed by GitHub
parent 3cf09c6a40
commit 4c1adb6e96
15 changed files with 120 additions and 79 deletions

View File

@ -73,7 +73,6 @@ public class AgentRegistration {
var baseAddress = _context.Creds.GetInstanceUrl();
var eventsUrl = new Uri(baseAddress, "/api/agents/events");
var commandsUrl = new Uri(baseAddress, "/api/agents/commands");
var workQueue = await _context.Queue.GetQueueSas(
_context.PoolOperations.GetPoolQueue(pool.PoolId),
StorageType.Corpus,

View File

@ -67,7 +67,7 @@ public class Pool {
Errors: new string[] { "pool with that name already exists" }),
"PoolCreate");
}
var newPool = await _context.PoolOperations.Create(name: create.Name, os: create.Os, architecture: create.Arch, managed: create.Managed, clientId: create.ClientId);
var newPool = await _context.PoolOperations.Create(name: create.Name, os: create.Os, architecture: create.Arch, managed: create.Managed, objectId: create.ObjectId);
return await RequestHandling.Ok(req, await Populate(PoolToPoolResponse(newPool), true));
}
@ -106,7 +106,7 @@ public class Pool {
PoolId: p.PoolId,
Os: p.Os,
State: p.State,
ClientId: p.ClientId,
ObjectId: p.ObjectId,
Managed: p.Managed,
Arch: p.Arch,
Nodes: p.Nodes,

View File

@ -645,7 +645,7 @@ public record Pool(
bool Managed,
Architecture Arch,
PoolState State,
Guid? ClientId = null
Guid? ObjectId = null
) : StatefulEntityBase<PoolState>(State) {
public List<Node>? Nodes { get; set; }
public AgentConfig? Config { get; set; }

View File

@ -263,7 +263,7 @@ public record PoolCreate(
[property: Required] Os Os,
[property: Required] Architecture Arch,
[property: Required] bool Managed,
Guid? ClientId = null
Guid? ObjectId = null
) : BaseRequest;
public record WebhookCreate(

View File

@ -114,7 +114,7 @@ public record PoolGetResult(
bool Managed,
Architecture Arch,
PoolState State,
Guid? ClientId,
Guid? ObjectId,
List<Node>? Nodes,
AgentConfig? Config,
List<WorkSetSummary>? WorkQueue,

View File

@ -50,6 +50,8 @@ public interface IServiceConfig {
// multiple instances to run against the same storage account, which
// is useful for things like integration testing.
public string OneFuzzStoragePrefix { get; }
public Uri OneFuzzBaseAddress { get; }
}
public class ServiceConfiguration : IServiceConfig {
@ -134,4 +136,12 @@ public class ServiceConfiguration : IServiceConfig {
public string OneFuzzNodeDisposalStrategy { get => GetEnv("ONEFUZZ_NODE_DISPOSAL_STRATEGY") ?? "scale_in"; }
public string OneFuzzStoragePrefix => ""; // in production we never prefix the tables
public Uri OneFuzzBaseAddress {
get {
var hostName = Environment.GetEnvironmentVariable("WEBSITE_HOSTNAME");
var scheme = Environment.GetEnvironmentVariable("HTTPS") != null ? "https" : "http";
return new Uri($"{scheme}://{hostName}");
}
}
}

View File

@ -46,9 +46,12 @@ public class EndpointAuthorization : IEndpointAuthorization {
}
var token = tokenResult.OkV.UserInfo;
if (await IsUser(tokenResult.OkV)) {
var (isAgent, reason) = await IsAgent(tokenResult.OkV);
if (!isAgent) {
if (!allowUser) {
return await Reject(req, token);
return await Reject(req, token, "endpoint not allowed for users");
}
var access = await CheckAccess(req);
@ -57,26 +60,24 @@ public class EndpointAuthorization : IEndpointAuthorization {
}
}
if (await IsAgent(tokenResult.OkV) && !allowAgent) {
return await Reject(req, token);
if (isAgent && !allowAgent) {
return await Reject(req, token, reason);
}
return await method(req);
}
public async Async.Task<bool> IsUser(UserAuthInfo tokenData) {
return !await IsAgent(tokenData);
}
public async Async.Task<HttpResponseData> Reject(HttpRequestData req, UserInfo token) {
public async Async.Task<HttpResponseData> Reject(HttpRequestData req, UserInfo token, String? reason = null) {
var body = await req.ReadAsStringAsync();
_log.Error($"reject token. url:{req.Url:Tag:Url} token:{token:Tag:Token} body:{body:Tag:Body}");
_log.Error($"reject token. reason:{reason} url:{req.Url:Tag:Url} token:{token:Tag:Token} body:{body:Tag:Body}");
return await _context.RequestHandling.NotOk(
req,
new Error(
ErrorCode.UNAUTHORIZED,
new string[] { "Unrecognized agent" }
new string[] { reason ?? "Unrecognized agent" }
),
"token verification",
HttpStatusCode.Unauthorized
@ -186,9 +187,10 @@ public class EndpointAuthorization : IEndpointAuthorization {
return null;
}
public async Async.Task<bool> IsAgent(UserAuthInfo authInfo) {
public async Async.Task<(bool, string)> IsAgent(UserAuthInfo authInfo) {
if (!AgentRoles.Overlaps(authInfo.Roles)) {
return false;
return (false, "no agent role");
}
var tokenData = authInfo.UserInfo;
@ -196,24 +198,24 @@ public class EndpointAuthorization : IEndpointAuthorization {
if (tokenData.ObjectId != null) {
var scalesets = _context.ScalesetOperations.GetByObjectId(tokenData.ObjectId.Value);
if (await scalesets.AnyAsync()) {
return true;
return (true, string.Empty);
}
var principalId = await _context.Creds.GetScalesetPrincipalId();
if (principalId == tokenData.ObjectId) {
return true;
return (true, string.Empty);
}
}
if (!tokenData.ApplicationId.HasValue) {
return false;
if (!tokenData.ObjectId.HasValue) {
return (false, "no object id in token");
}
var pools = _context.PoolOperations.GetByClientId(tokenData.ApplicationId.Value);
var pools = _context.PoolOperations.GetByObjectId(tokenData.ObjectId.Value);
if (await pools.AnyAsync()) {
return true;
return (true, string.Empty);
}
return false;
return (false, "no matching scaleset or pool");
}
}

View File

@ -6,14 +6,14 @@ public interface IPoolOperations : IStatefulOrm<Pool, PoolState> {
Async.Task<OneFuzzResult<Pool>> GetByName(PoolName poolName);
Async.Task<OneFuzzResult<Pool>> GetById(Guid poolId);
Task<bool> ScheduleWorkset(Pool pool, WorkSet workSet);
IAsyncEnumerable<Pool> GetByClientId(Guid clientId);
IAsyncEnumerable<Pool> GetByObjectId(Guid objectId);
string GetPoolQueue(Guid poolId);
Async.Task<List<ScalesetSummary>> GetScalesetSummary(PoolName name);
Async.Task<List<WorkSetSummary>> GetWorkQueue(Guid poolId, PoolState state);
IAsyncEnumerable<Pool> SearchStates(IEnumerable<PoolState> states);
Async.Task<Pool> SetShutdown(Pool pool, bool Now);
Async.Task<Pool> Create(PoolName name, Os os, Architecture architecture, bool managed, Guid? clientId = null);
Async.Task<Pool> Create(PoolName name, Os os, Architecture architecture, bool managed, Guid? objectId = null);
new Async.Task Delete(Pool pool);
// state transitions:
@ -32,7 +32,7 @@ public class PoolOperations : StatefulOrm<Pool, PoolState, PoolOperations>, IPoo
}
public async Async.Task<Pool> Create(PoolName name, Os os, Architecture architecture, bool managed, Guid? clientId = null) {
public async Async.Task<Pool> Create(PoolName name, Os os, Architecture architecture, bool managed, Guid? objectId = null) {
var newPool = new Service.Pool(
PoolId: Guid.NewGuid(),
State: PoolState.Init,
@ -40,7 +40,7 @@ public class PoolOperations : StatefulOrm<Pool, PoolState, PoolOperations>, IPoo
Os: os,
Managed: managed,
Arch: architecture,
ClientId: clientId);
ObjectId: objectId);
var r = await Insert(newPool);
if (!r.IsOk) {
@ -87,8 +87,8 @@ public class PoolOperations : StatefulOrm<Pool, PoolState, PoolOperations>, IPoo
return await _context.Queue.QueueObject(GetPoolQueue(pool.PoolId), workSet, StorageType.Corpus);
}
public IAsyncEnumerable<Pool> GetByClientId(Guid clientId) {
return QueryAsync(filter: $"client_id eq '{clientId}'");
public IAsyncEnumerable<Pool> GetByObjectId(Guid objectId) {
return QueryAsync(filter: $"object_id eq '{objectId}'");
}
public string GetPoolQueue(Guid poolId)

View File

@ -64,5 +64,6 @@ public sealed class TestServiceConfiguration : IServiceConfig {
public string? OneFuzzAllowOutdatedAgent => throw new NotImplementedException();
public string? AppConfigurationEndpoint => throw new NotImplementedException();
public Uri OneFuzzBaseAddress { get => new Uri("http://test"); }
public string? AppConfigurationConnectionString => throw new NotImplementedException();
}

View File

@ -1,2 +1,2 @@
target
.agent-run
.agent-run

1
src/agent/Cargo.lock generated
View File

@ -2616,6 +2616,7 @@ dependencies = [
"log",
"onefuzz-telemetry",
"reqwest",
"thiserror",
"tokio",
"wiremock",
]

View File

@ -12,6 +12,7 @@ backoff = { version = "0.4", features = ["tokio"] }
log = "0.4"
onefuzz-telemetry = { path = "../onefuzz-telemetry" }
reqwest = { version = "0.11", features = ["json", "stream", "native-tls-vendored"], default-features=false }
thiserror = "1.0"
[dev-dependencies]
tokio = { version = "1.16", features = ["macros"] }

View File

@ -1,7 +1,7 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
use anyhow::{Context, Result};
use anyhow::{anyhow, Context, Result};
use async_trait::async_trait;
use backoff::{self, future::retry_notify, ExponentialBackoff};
use onefuzz_telemetry::debug;
@ -20,6 +20,32 @@ pub enum RetryCheck {
Succeed,
}
#[derive(Debug, thiserror::Error)]
pub enum ReqwestRetryError {
#[error("request failed with status code {status_code} and url {url}")]
Response {
status_code: StatusCode,
url: reqwest::Url,
source: anyhow::Error,
},
#[error("request failed to be sent")]
SendError { source: anyhow::Error },
}
impl ReqwestRetryError {
fn response_error(status_code: StatusCode, url: reqwest::Url, source: anyhow::Error) -> Self {
Self::Response {
status_code,
url,
source,
}
}
fn send_error(source: anyhow::Error) -> Self {
Self::SendError { source }
}
}
fn always_retry(_: StatusCode) -> RetryCheck {
RetryCheck::Retry
}
@ -51,7 +77,8 @@ where
let counter = AtomicUsize::new(0);
let op = || async {
let attempt_count = counter.fetch_add(1, Ordering::SeqCst);
let request = build_request().map_err(|err| backoff::Error::Permanent(Err(err)))?;
let request = build_request()
.map_err(|err| backoff::Error::Permanent(ReqwestRetryError::send_error(err)))?;
let result = request
.send()
.await
@ -59,9 +86,9 @@ where
match result {
Err(x) => {
if attempt_count >= max_retry {
Err(backoff::Error::Permanent(Err(x)))
Err(backoff::Error::Permanent(ReqwestRetryError::send_error(x)))
} else {
Err(backoff::Error::transient(Err(x)))
Err(backoff::Error::transient(ReqwestRetryError::send_error(x)))
}
}
Ok(x) => {
@ -70,31 +97,40 @@ where
} else {
let status = x.status();
let result = check_status(status);
let url = x.url().clone();
match result {
RetryCheck::Succeed => Ok(x),
RetryCheck::Fail => {
match x.error_for_status().with_context(|| {
format!("request attempt {} failed", attempt_count + 1)
}) {
// the is_success check earlier should have taken care of this already.
Ok(x) => Ok(x),
Err(as_err) => Err(backoff::Error::Permanent(Err(as_err))),
}
let content = x.text().await.unwrap_or_else(|_| "".to_string());
let e = anyhow!(
"request attempt {} failed with status code {} and content {}",
attempt_count + 1,
status,
content
);
Err(backoff::Error::Permanent(
ReqwestRetryError::response_error(status, url, e),
))
}
RetryCheck::Retry => {
match x.error_for_status().with_context(|| {
format!("request attempt {} failed", attempt_count + 1)
}) {
// the is_success check earlier should have taken care of this already.
Ok(x) => Ok(x),
Err(as_err) => {
if attempt_count >= max_retry {
Err(backoff::Error::Permanent(Err(as_err)))
} else {
Err(backoff::Error::transient(Err(as_err)))
}
}
let content = x.text().await.unwrap_or_else(|_| "".to_string());
let e = anyhow!(
"request attempt {} failed with status code {} and content {}",
attempt_count + 1,
status,
content
);
if attempt_count >= max_retry {
Err(backoff::Error::Permanent(
ReqwestRetryError::response_error(status, url, e),
))
} else {
Err(backoff::Error::transient(
ReqwestRetryError::response_error(status, url, e),
))
}
}
}
@ -109,20 +145,13 @@ where
..ExponentialBackoff::default()
},
op,
|err: Result<Response, anyhow::Error>, dur| match err {
Ok(response) => {
if let Err(err) = response.error_for_status() {
debug!("request attempt failed after {:?}: {:?}", dur, err)
}
}
err => debug!("request attempt failed after {:?}: {:?}", dur, err),
},
|err: ReqwestRetryError, dur| debug!("request attempt failed after {:?}: {:?}", dur, err),
)
.await;
match result {
Ok(response) | Err(Ok(response)) => Ok(response),
Err(Err(err)) => Err(err),
Ok(response) => Ok(response),
Err(error) => Err(error.into()),
}
}
@ -174,18 +203,16 @@ impl SendRetry for reqwest::RequestBuilder {
pub fn is_auth_failure(response: &Result<Response>) -> bool {
// Check both cases to support `error_for_status()`.
match response {
Ok(response) => {
return response.status() == StatusCode::UNAUTHORIZED;
}
Err(error) => {
if let Some(error) = error.downcast_ref::<reqwest::Error>() {
if let Some(status) = error.status() {
return status == StatusCode::UNAUTHORIZED;
}
}
}
Ok(response) => response.status() == StatusCode::UNAUTHORIZED,
Err(error) => match error.downcast_ref::<ReqwestRetryError>() {
Some(ReqwestRetryError::Response {
status_code,
url: _,
source: _,
}) => status_code == &StatusCode::UNAUTHORIZED,
_ => false,
},
}
false
}
#[cfg(test)]

View File

@ -1239,7 +1239,7 @@ class Pool(Endpoint):
self,
name: str,
os: enums.OS,
client_id: Optional[UUID] = None,
object_id: Optional[UUID] = None,
*,
unmanaged: bool = False,
arch: enums.Architecture = enums.Architecture.x86_64,
@ -1256,7 +1256,7 @@ class Pool(Endpoint):
"POST",
models.Pool,
data=requests.PoolCreate(
name=name, os=os, arch=arch, managed=managed, client_id=client_id
name=name, os=os, arch=arch, managed=managed, object_id=object_id
),
)

View File

@ -92,7 +92,7 @@ class PoolCreate(BaseRequest):
os: OS
arch: Architecture
managed: bool
client_id: Optional[UUID]
object_id: Optional[UUID]
autoscale: Optional[AutoScaleConfig]