mirror of
https://github.com/microsoft/onefuzz.git
synced 2025-06-17 20:38:06 +00:00
rename client_id in pool to object_id (#2673)
* rename client_id in pool to object_id * fix tests * print out the content body when receiving an error response in the agent * fix test * Apply suggestions from code review * Update src/ApiService/ApiService/Functions/AgentRegistration.cs * format * cleanup * format * address pr comment
This commit is contained in:
@ -73,7 +73,6 @@ public class AgentRegistration {
|
||||
var baseAddress = _context.Creds.GetInstanceUrl();
|
||||
var eventsUrl = new Uri(baseAddress, "/api/agents/events");
|
||||
var commandsUrl = new Uri(baseAddress, "/api/agents/commands");
|
||||
|
||||
var workQueue = await _context.Queue.GetQueueSas(
|
||||
_context.PoolOperations.GetPoolQueue(pool.PoolId),
|
||||
StorageType.Corpus,
|
||||
|
@ -67,7 +67,7 @@ public class Pool {
|
||||
Errors: new string[] { "pool with that name already exists" }),
|
||||
"PoolCreate");
|
||||
}
|
||||
var newPool = await _context.PoolOperations.Create(name: create.Name, os: create.Os, architecture: create.Arch, managed: create.Managed, clientId: create.ClientId);
|
||||
var newPool = await _context.PoolOperations.Create(name: create.Name, os: create.Os, architecture: create.Arch, managed: create.Managed, objectId: create.ObjectId);
|
||||
return await RequestHandling.Ok(req, await Populate(PoolToPoolResponse(newPool), true));
|
||||
}
|
||||
|
||||
@ -106,7 +106,7 @@ public class Pool {
|
||||
PoolId: p.PoolId,
|
||||
Os: p.Os,
|
||||
State: p.State,
|
||||
ClientId: p.ClientId,
|
||||
ObjectId: p.ObjectId,
|
||||
Managed: p.Managed,
|
||||
Arch: p.Arch,
|
||||
Nodes: p.Nodes,
|
||||
|
@ -645,7 +645,7 @@ public record Pool(
|
||||
bool Managed,
|
||||
Architecture Arch,
|
||||
PoolState State,
|
||||
Guid? ClientId = null
|
||||
Guid? ObjectId = null
|
||||
) : StatefulEntityBase<PoolState>(State) {
|
||||
public List<Node>? Nodes { get; set; }
|
||||
public AgentConfig? Config { get; set; }
|
||||
|
@ -263,7 +263,7 @@ public record PoolCreate(
|
||||
[property: Required] Os Os,
|
||||
[property: Required] Architecture Arch,
|
||||
[property: Required] bool Managed,
|
||||
Guid? ClientId = null
|
||||
Guid? ObjectId = null
|
||||
) : BaseRequest;
|
||||
|
||||
public record WebhookCreate(
|
||||
|
@ -114,7 +114,7 @@ public record PoolGetResult(
|
||||
bool Managed,
|
||||
Architecture Arch,
|
||||
PoolState State,
|
||||
Guid? ClientId,
|
||||
Guid? ObjectId,
|
||||
List<Node>? Nodes,
|
||||
AgentConfig? Config,
|
||||
List<WorkSetSummary>? WorkQueue,
|
||||
|
@ -50,6 +50,8 @@ public interface IServiceConfig {
|
||||
// multiple instances to run against the same storage account, which
|
||||
// is useful for things like integration testing.
|
||||
public string OneFuzzStoragePrefix { get; }
|
||||
|
||||
public Uri OneFuzzBaseAddress { get; }
|
||||
}
|
||||
|
||||
public class ServiceConfiguration : IServiceConfig {
|
||||
@ -134,4 +136,12 @@ public class ServiceConfiguration : IServiceConfig {
|
||||
|
||||
public string OneFuzzNodeDisposalStrategy { get => GetEnv("ONEFUZZ_NODE_DISPOSAL_STRATEGY") ?? "scale_in"; }
|
||||
public string OneFuzzStoragePrefix => ""; // in production we never prefix the tables
|
||||
|
||||
public Uri OneFuzzBaseAddress {
|
||||
get {
|
||||
var hostName = Environment.GetEnvironmentVariable("WEBSITE_HOSTNAME");
|
||||
var scheme = Environment.GetEnvironmentVariable("HTTPS") != null ? "https" : "http";
|
||||
return new Uri($"{scheme}://{hostName}");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -46,9 +46,12 @@ public class EndpointAuthorization : IEndpointAuthorization {
|
||||
}
|
||||
|
||||
var token = tokenResult.OkV.UserInfo;
|
||||
if (await IsUser(tokenResult.OkV)) {
|
||||
|
||||
var (isAgent, reason) = await IsAgent(tokenResult.OkV);
|
||||
|
||||
if (!isAgent) {
|
||||
if (!allowUser) {
|
||||
return await Reject(req, token);
|
||||
return await Reject(req, token, "endpoint not allowed for users");
|
||||
}
|
||||
|
||||
var access = await CheckAccess(req);
|
||||
@ -57,26 +60,24 @@ public class EndpointAuthorization : IEndpointAuthorization {
|
||||
}
|
||||
}
|
||||
|
||||
if (await IsAgent(tokenResult.OkV) && !allowAgent) {
|
||||
return await Reject(req, token);
|
||||
|
||||
if (isAgent && !allowAgent) {
|
||||
return await Reject(req, token, reason);
|
||||
}
|
||||
|
||||
return await method(req);
|
||||
}
|
||||
|
||||
public async Async.Task<bool> IsUser(UserAuthInfo tokenData) {
|
||||
return !await IsAgent(tokenData);
|
||||
}
|
||||
|
||||
public async Async.Task<HttpResponseData> Reject(HttpRequestData req, UserInfo token) {
|
||||
public async Async.Task<HttpResponseData> Reject(HttpRequestData req, UserInfo token, String? reason = null) {
|
||||
var body = await req.ReadAsStringAsync();
|
||||
_log.Error($"reject token. url:{req.Url:Tag:Url} token:{token:Tag:Token} body:{body:Tag:Body}");
|
||||
_log.Error($"reject token. reason:{reason} url:{req.Url:Tag:Url} token:{token:Tag:Token} body:{body:Tag:Body}");
|
||||
|
||||
return await _context.RequestHandling.NotOk(
|
||||
req,
|
||||
new Error(
|
||||
ErrorCode.UNAUTHORIZED,
|
||||
new string[] { "Unrecognized agent" }
|
||||
new string[] { reason ?? "Unrecognized agent" }
|
||||
),
|
||||
"token verification",
|
||||
HttpStatusCode.Unauthorized
|
||||
@ -186,9 +187,10 @@ public class EndpointAuthorization : IEndpointAuthorization {
|
||||
return null;
|
||||
}
|
||||
|
||||
public async Async.Task<bool> IsAgent(UserAuthInfo authInfo) {
|
||||
|
||||
public async Async.Task<(bool, string)> IsAgent(UserAuthInfo authInfo) {
|
||||
if (!AgentRoles.Overlaps(authInfo.Roles)) {
|
||||
return false;
|
||||
return (false, "no agent role");
|
||||
}
|
||||
|
||||
var tokenData = authInfo.UserInfo;
|
||||
@ -196,24 +198,24 @@ public class EndpointAuthorization : IEndpointAuthorization {
|
||||
if (tokenData.ObjectId != null) {
|
||||
var scalesets = _context.ScalesetOperations.GetByObjectId(tokenData.ObjectId.Value);
|
||||
if (await scalesets.AnyAsync()) {
|
||||
return true;
|
||||
return (true, string.Empty);
|
||||
}
|
||||
|
||||
var principalId = await _context.Creds.GetScalesetPrincipalId();
|
||||
if (principalId == tokenData.ObjectId) {
|
||||
return true;
|
||||
return (true, string.Empty);
|
||||
}
|
||||
}
|
||||
|
||||
if (!tokenData.ApplicationId.HasValue) {
|
||||
return false;
|
||||
if (!tokenData.ObjectId.HasValue) {
|
||||
return (false, "no object id in token");
|
||||
}
|
||||
|
||||
var pools = _context.PoolOperations.GetByClientId(tokenData.ApplicationId.Value);
|
||||
var pools = _context.PoolOperations.GetByObjectId(tokenData.ObjectId.Value);
|
||||
if (await pools.AnyAsync()) {
|
||||
return true;
|
||||
return (true, string.Empty);
|
||||
}
|
||||
|
||||
return false;
|
||||
return (false, "no matching scaleset or pool");
|
||||
}
|
||||
}
|
||||
|
@ -6,14 +6,14 @@ public interface IPoolOperations : IStatefulOrm<Pool, PoolState> {
|
||||
Async.Task<OneFuzzResult<Pool>> GetByName(PoolName poolName);
|
||||
Async.Task<OneFuzzResult<Pool>> GetById(Guid poolId);
|
||||
Task<bool> ScheduleWorkset(Pool pool, WorkSet workSet);
|
||||
IAsyncEnumerable<Pool> GetByClientId(Guid clientId);
|
||||
IAsyncEnumerable<Pool> GetByObjectId(Guid objectId);
|
||||
string GetPoolQueue(Guid poolId);
|
||||
Async.Task<List<ScalesetSummary>> GetScalesetSummary(PoolName name);
|
||||
Async.Task<List<WorkSetSummary>> GetWorkQueue(Guid poolId, PoolState state);
|
||||
IAsyncEnumerable<Pool> SearchStates(IEnumerable<PoolState> states);
|
||||
Async.Task<Pool> SetShutdown(Pool pool, bool Now);
|
||||
|
||||
Async.Task<Pool> Create(PoolName name, Os os, Architecture architecture, bool managed, Guid? clientId = null);
|
||||
Async.Task<Pool> Create(PoolName name, Os os, Architecture architecture, bool managed, Guid? objectId = null);
|
||||
new Async.Task Delete(Pool pool);
|
||||
|
||||
// state transitions:
|
||||
@ -32,7 +32,7 @@ public class PoolOperations : StatefulOrm<Pool, PoolState, PoolOperations>, IPoo
|
||||
|
||||
}
|
||||
|
||||
public async Async.Task<Pool> Create(PoolName name, Os os, Architecture architecture, bool managed, Guid? clientId = null) {
|
||||
public async Async.Task<Pool> Create(PoolName name, Os os, Architecture architecture, bool managed, Guid? objectId = null) {
|
||||
var newPool = new Service.Pool(
|
||||
PoolId: Guid.NewGuid(),
|
||||
State: PoolState.Init,
|
||||
@ -40,7 +40,7 @@ public class PoolOperations : StatefulOrm<Pool, PoolState, PoolOperations>, IPoo
|
||||
Os: os,
|
||||
Managed: managed,
|
||||
Arch: architecture,
|
||||
ClientId: clientId);
|
||||
ObjectId: objectId);
|
||||
|
||||
var r = await Insert(newPool);
|
||||
if (!r.IsOk) {
|
||||
@ -87,8 +87,8 @@ public class PoolOperations : StatefulOrm<Pool, PoolState, PoolOperations>, IPoo
|
||||
return await _context.Queue.QueueObject(GetPoolQueue(pool.PoolId), workSet, StorageType.Corpus);
|
||||
}
|
||||
|
||||
public IAsyncEnumerable<Pool> GetByClientId(Guid clientId) {
|
||||
return QueryAsync(filter: $"client_id eq '{clientId}'");
|
||||
public IAsyncEnumerable<Pool> GetByObjectId(Guid objectId) {
|
||||
return QueryAsync(filter: $"object_id eq '{objectId}'");
|
||||
}
|
||||
|
||||
public string GetPoolQueue(Guid poolId)
|
||||
|
@ -64,5 +64,6 @@ public sealed class TestServiceConfiguration : IServiceConfig {
|
||||
|
||||
public string? OneFuzzAllowOutdatedAgent => throw new NotImplementedException();
|
||||
public string? AppConfigurationEndpoint => throw new NotImplementedException();
|
||||
public Uri OneFuzzBaseAddress { get => new Uri("http://test"); }
|
||||
public string? AppConfigurationConnectionString => throw new NotImplementedException();
|
||||
}
|
||||
|
1
src/agent/Cargo.lock
generated
1
src/agent/Cargo.lock
generated
@ -2616,6 +2616,7 @@ dependencies = [
|
||||
"log",
|
||||
"onefuzz-telemetry",
|
||||
"reqwest",
|
||||
"thiserror",
|
||||
"tokio",
|
||||
"wiremock",
|
||||
]
|
||||
|
@ -12,6 +12,7 @@ backoff = { version = "0.4", features = ["tokio"] }
|
||||
log = "0.4"
|
||||
onefuzz-telemetry = { path = "../onefuzz-telemetry" }
|
||||
reqwest = { version = "0.11", features = ["json", "stream", "native-tls-vendored"], default-features=false }
|
||||
thiserror = "1.0"
|
||||
|
||||
[dev-dependencies]
|
||||
tokio = { version = "1.16", features = ["macros"] }
|
||||
|
@ -1,7 +1,7 @@
|
||||
// Copyright (c) Microsoft Corporation.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
use anyhow::{Context, Result};
|
||||
use anyhow::{anyhow, Context, Result};
|
||||
use async_trait::async_trait;
|
||||
use backoff::{self, future::retry_notify, ExponentialBackoff};
|
||||
use onefuzz_telemetry::debug;
|
||||
@ -20,6 +20,32 @@ pub enum RetryCheck {
|
||||
Succeed,
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum ReqwestRetryError {
|
||||
#[error("request failed with status code {status_code} and url {url}")]
|
||||
Response {
|
||||
status_code: StatusCode,
|
||||
url: reqwest::Url,
|
||||
source: anyhow::Error,
|
||||
},
|
||||
#[error("request failed to be sent")]
|
||||
SendError { source: anyhow::Error },
|
||||
}
|
||||
|
||||
impl ReqwestRetryError {
|
||||
fn response_error(status_code: StatusCode, url: reqwest::Url, source: anyhow::Error) -> Self {
|
||||
Self::Response {
|
||||
status_code,
|
||||
url,
|
||||
source,
|
||||
}
|
||||
}
|
||||
|
||||
fn send_error(source: anyhow::Error) -> Self {
|
||||
Self::SendError { source }
|
||||
}
|
||||
}
|
||||
|
||||
fn always_retry(_: StatusCode) -> RetryCheck {
|
||||
RetryCheck::Retry
|
||||
}
|
||||
@ -51,7 +77,8 @@ where
|
||||
let counter = AtomicUsize::new(0);
|
||||
let op = || async {
|
||||
let attempt_count = counter.fetch_add(1, Ordering::SeqCst);
|
||||
let request = build_request().map_err(|err| backoff::Error::Permanent(Err(err)))?;
|
||||
let request = build_request()
|
||||
.map_err(|err| backoff::Error::Permanent(ReqwestRetryError::send_error(err)))?;
|
||||
let result = request
|
||||
.send()
|
||||
.await
|
||||
@ -59,9 +86,9 @@ where
|
||||
match result {
|
||||
Err(x) => {
|
||||
if attempt_count >= max_retry {
|
||||
Err(backoff::Error::Permanent(Err(x)))
|
||||
Err(backoff::Error::Permanent(ReqwestRetryError::send_error(x)))
|
||||
} else {
|
||||
Err(backoff::Error::transient(Err(x)))
|
||||
Err(backoff::Error::transient(ReqwestRetryError::send_error(x)))
|
||||
}
|
||||
}
|
||||
Ok(x) => {
|
||||
@ -70,31 +97,40 @@ where
|
||||
} else {
|
||||
let status = x.status();
|
||||
let result = check_status(status);
|
||||
let url = x.url().clone();
|
||||
|
||||
match result {
|
||||
RetryCheck::Succeed => Ok(x),
|
||||
RetryCheck::Fail => {
|
||||
match x.error_for_status().with_context(|| {
|
||||
format!("request attempt {} failed", attempt_count + 1)
|
||||
}) {
|
||||
// the is_success check earlier should have taken care of this already.
|
||||
Ok(x) => Ok(x),
|
||||
Err(as_err) => Err(backoff::Error::Permanent(Err(as_err))),
|
||||
}
|
||||
let content = x.text().await.unwrap_or_else(|_| "".to_string());
|
||||
let e = anyhow!(
|
||||
"request attempt {} failed with status code {} and content {}",
|
||||
attempt_count + 1,
|
||||
status,
|
||||
content
|
||||
);
|
||||
|
||||
Err(backoff::Error::Permanent(
|
||||
ReqwestRetryError::response_error(status, url, e),
|
||||
))
|
||||
}
|
||||
RetryCheck::Retry => {
|
||||
match x.error_for_status().with_context(|| {
|
||||
format!("request attempt {} failed", attempt_count + 1)
|
||||
}) {
|
||||
// the is_success check earlier should have taken care of this already.
|
||||
Ok(x) => Ok(x),
|
||||
Err(as_err) => {
|
||||
let content = x.text().await.unwrap_or_else(|_| "".to_string());
|
||||
let e = anyhow!(
|
||||
"request attempt {} failed with status code {} and content {}",
|
||||
attempt_count + 1,
|
||||
status,
|
||||
content
|
||||
);
|
||||
|
||||
if attempt_count >= max_retry {
|
||||
Err(backoff::Error::Permanent(Err(as_err)))
|
||||
Err(backoff::Error::Permanent(
|
||||
ReqwestRetryError::response_error(status, url, e),
|
||||
))
|
||||
} else {
|
||||
Err(backoff::Error::transient(Err(as_err)))
|
||||
}
|
||||
}
|
||||
Err(backoff::Error::transient(
|
||||
ReqwestRetryError::response_error(status, url, e),
|
||||
))
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -109,20 +145,13 @@ where
|
||||
..ExponentialBackoff::default()
|
||||
},
|
||||
op,
|
||||
|err: Result<Response, anyhow::Error>, dur| match err {
|
||||
Ok(response) => {
|
||||
if let Err(err) = response.error_for_status() {
|
||||
debug!("request attempt failed after {:?}: {:?}", dur, err)
|
||||
}
|
||||
}
|
||||
err => debug!("request attempt failed after {:?}: {:?}", dur, err),
|
||||
},
|
||||
|err: ReqwestRetryError, dur| debug!("request attempt failed after {:?}: {:?}", dur, err),
|
||||
)
|
||||
.await;
|
||||
|
||||
match result {
|
||||
Ok(response) | Err(Ok(response)) => Ok(response),
|
||||
Err(Err(err)) => Err(err),
|
||||
Ok(response) => Ok(response),
|
||||
Err(error) => Err(error.into()),
|
||||
}
|
||||
}
|
||||
|
||||
@ -174,18 +203,16 @@ impl SendRetry for reqwest::RequestBuilder {
|
||||
pub fn is_auth_failure(response: &Result<Response>) -> bool {
|
||||
// Check both cases to support `error_for_status()`.
|
||||
match response {
|
||||
Ok(response) => {
|
||||
return response.status() == StatusCode::UNAUTHORIZED;
|
||||
Ok(response) => response.status() == StatusCode::UNAUTHORIZED,
|
||||
Err(error) => match error.downcast_ref::<ReqwestRetryError>() {
|
||||
Some(ReqwestRetryError::Response {
|
||||
status_code,
|
||||
url: _,
|
||||
source: _,
|
||||
}) => status_code == &StatusCode::UNAUTHORIZED,
|
||||
_ => false,
|
||||
},
|
||||
}
|
||||
Err(error) => {
|
||||
if let Some(error) = error.downcast_ref::<reqwest::Error>() {
|
||||
if let Some(status) = error.status() {
|
||||
return status == StatusCode::UNAUTHORIZED;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
false
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
@ -1239,7 +1239,7 @@ class Pool(Endpoint):
|
||||
self,
|
||||
name: str,
|
||||
os: enums.OS,
|
||||
client_id: Optional[UUID] = None,
|
||||
object_id: Optional[UUID] = None,
|
||||
*,
|
||||
unmanaged: bool = False,
|
||||
arch: enums.Architecture = enums.Architecture.x86_64,
|
||||
@ -1256,7 +1256,7 @@ class Pool(Endpoint):
|
||||
"POST",
|
||||
models.Pool,
|
||||
data=requests.PoolCreate(
|
||||
name=name, os=os, arch=arch, managed=managed, client_id=client_id
|
||||
name=name, os=os, arch=arch, managed=managed, object_id=object_id
|
||||
),
|
||||
)
|
||||
|
||||
|
@ -92,7 +92,7 @@ class PoolCreate(BaseRequest):
|
||||
os: OS
|
||||
arch: Architecture
|
||||
managed: bool
|
||||
client_id: Optional[UUID]
|
||||
object_id: Optional[UUID]
|
||||
autoscale: Optional[AutoScaleConfig]
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user