mirror of
https://github.com/microsoft/onefuzz.git
synced 2025-06-17 20:38:06 +00:00
Bug fixes related to the unmanaged nodes (#2632)
* bug fixes related to the unmanaged nodes - fix the token request in the client - adding a is_unmnaged field to the agentConfig - fix typo in the appId claim type - fix typo in the UnmanagedNode claim value - fix query PoolOperation.GetByClientId * remove unused import * build fix * change unmanaged field to managed * Apply suggestions from code review Co-authored-by: Teo Voinea <58236992+tevoinea@users.noreply.github.com> Co-authored-by: Teo Voinea <58236992+tevoinea@users.noreply.github.com>
This commit is contained in:
@ -133,7 +133,8 @@ public class Pool {
|
||||
HeartbeatQueue: queueSas,
|
||||
InstanceId: instanceId,
|
||||
ClientCredentials: null,
|
||||
MultiTenantDomain: _context.ServiceConfiguration.MultiTenantDomain)
|
||||
MultiTenantDomain: _context.ServiceConfiguration.MultiTenantDomain,
|
||||
Managed: p.Managed)
|
||||
};
|
||||
}
|
||||
}
|
||||
|
@ -678,12 +678,11 @@ public record AgentConfig(
|
||||
string? InstanceTelemetryKey,
|
||||
string? MicrosoftTelemetryKey,
|
||||
string? MultiTenantDomain,
|
||||
Guid InstanceId
|
||||
Guid InstanceId,
|
||||
bool? Managed = true
|
||||
);
|
||||
|
||||
|
||||
|
||||
|
||||
public record Vm(
|
||||
string Name,
|
||||
Region Region,
|
||||
|
@ -77,7 +77,7 @@ public class UserCredentials : IUserCredentials {
|
||||
switch (claim.Type) {
|
||||
case "oid":
|
||||
return acc with { UserInfo = acc.UserInfo with { ObjectId = Guid.Parse(claim.Value) } };
|
||||
case "appId":
|
||||
case "appid":
|
||||
return acc with { UserInfo = acc.UserInfo with { ApplicationId = Guid.Parse(claim.Value) } };
|
||||
case "upn":
|
||||
return acc with { UserInfo = acc.UserInfo with { Upn = claim.Value } };
|
||||
@ -88,7 +88,6 @@ public class UserCredentials : IUserCredentials {
|
||||
return acc;
|
||||
}
|
||||
});
|
||||
|
||||
return OneFuzzResult<UserAuthInfo>.Ok(userInfo);
|
||||
} else {
|
||||
var tenantsStr = allowedTenants.OkV is null ? "null" : String.Join(';', allowedTenants.OkV!);
|
||||
|
@ -30,8 +30,7 @@ public class EndpointAuthorization : IEndpointAuthorization {
|
||||
private readonly IOnefuzzContext _context;
|
||||
private readonly ILogTracer _log;
|
||||
private readonly GraphServiceClient _graphClient;
|
||||
|
||||
private static readonly HashSet<string> AgentRoles = new HashSet<string> { "UnmamagedNode", "ManagedNode" };
|
||||
private static readonly HashSet<string> AgentRoles = new HashSet<string> { "UnmanagedNode", "ManagedNode" };
|
||||
|
||||
public EndpointAuthorization(IOnefuzzContext context, ILogTracer log, GraphServiceClient graphClient) {
|
||||
_context = context;
|
||||
@ -46,10 +45,10 @@ public class EndpointAuthorization : IEndpointAuthorization {
|
||||
return await _context.RequestHandling.NotOk(req, tokenResult.ErrorV, "token verification", HttpStatusCode.Unauthorized);
|
||||
}
|
||||
|
||||
var token = tokenResult.OkV;
|
||||
if (await IsUser(token)) {
|
||||
var token = tokenResult.OkV.UserInfo;
|
||||
if (await IsUser(tokenResult.OkV)) {
|
||||
if (!allowUser) {
|
||||
return await Reject(req, tokenResult.OkV.UserInfo);
|
||||
return await Reject(req, token);
|
||||
}
|
||||
|
||||
var access = await CheckAccess(req);
|
||||
@ -58,8 +57,8 @@ public class EndpointAuthorization : IEndpointAuthorization {
|
||||
}
|
||||
}
|
||||
|
||||
if (await IsAgent(token) && !allowAgent) {
|
||||
return await Reject(req, tokenResult.OkV.UserInfo);
|
||||
if (await IsAgent(tokenResult.OkV) && !allowAgent) {
|
||||
return await Reject(req, token);
|
||||
}
|
||||
|
||||
return await method(req);
|
||||
@ -201,7 +200,9 @@ public class EndpointAuthorization : IEndpointAuthorization {
|
||||
}
|
||||
|
||||
var principalId = await _context.Creds.GetScalesetPrincipalId();
|
||||
return principalId == tokenData.ObjectId;
|
||||
if (principalId == tokenData.ObjectId) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
if (!tokenData.ApplicationId.HasValue) {
|
||||
|
@ -220,7 +220,8 @@ public class Extensions : IExtensions {
|
||||
InstanceTelemetryKey: _context.ServiceConfiguration.ApplicationInsightsInstrumentationKey,
|
||||
MicrosoftTelemetryKey: _context.ServiceConfiguration.OneFuzzTelemetry,
|
||||
MultiTenantDomain: _context.ServiceConfiguration.MultiTenantDomain,
|
||||
InstanceId: instanceId
|
||||
InstanceId: instanceId,
|
||||
Managed: pool.Managed
|
||||
);
|
||||
|
||||
var fileName = $"{pool.Name}/config.json";
|
||||
|
@ -1,6 +1,5 @@
|
||||
using System.Threading.Tasks;
|
||||
using ApiService.OneFuzzLib.Orm;
|
||||
using Azure.Data.Tables;
|
||||
namespace Microsoft.OneFuzz.Service;
|
||||
|
||||
public interface IPoolOperations : IStatefulOrm<Pool, PoolState> {
|
||||
@ -89,7 +88,7 @@ public class PoolOperations : StatefulOrm<Pool, PoolState, PoolOperations>, IPoo
|
||||
}
|
||||
|
||||
public IAsyncEnumerable<Pool> GetByClientId(Guid clientId) {
|
||||
return QueryAsync(filter: TableClient.CreateQueryFilter($"client_id eq {clientId}"));
|
||||
return QueryAsync(filter: $"client_id eq '{clientId}'");
|
||||
}
|
||||
|
||||
public string GetPoolQueue(Guid poolId)
|
||||
|
@ -34,12 +34,19 @@ pub struct StaticConfig {
|
||||
pub heartbeat_queue: Option<Url>,
|
||||
|
||||
pub instance_id: Uuid,
|
||||
|
||||
#[serde(default = "default_as_true")]
|
||||
pub managed: bool,
|
||||
}
|
||||
|
||||
fn default_as_true() -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
// Temporary shim type to bridge the current service-provided config.
|
||||
#[derive(Clone, Debug, Deserialize, Eq, PartialEq)]
|
||||
struct RawStaticConfig {
|
||||
pub credentials: Option<ClientCredentials>,
|
||||
pub client_credentials: Option<ClientCredentials>,
|
||||
|
||||
pub pool_name: String,
|
||||
|
||||
@ -54,13 +61,16 @@ struct RawStaticConfig {
|
||||
pub heartbeat_queue: Option<Url>,
|
||||
|
||||
pub instance_id: Uuid,
|
||||
|
||||
#[serde(default = "default_as_true")]
|
||||
pub managed: bool,
|
||||
}
|
||||
|
||||
impl StaticConfig {
|
||||
pub fn new(data: &[u8]) -> Result<Self> {
|
||||
let config: RawStaticConfig = serde_json::from_slice(data)?;
|
||||
|
||||
let credentials = match config.credentials {
|
||||
let credentials = match config.client_credentials {
|
||||
Some(client) => client.into(),
|
||||
None => {
|
||||
// Remove trailing `/`, which is treated as a distinct resource.
|
||||
@ -83,6 +93,7 @@ impl StaticConfig {
|
||||
instance_telemetry_key: config.instance_telemetry_key,
|
||||
heartbeat_queue: config.heartbeat_queue,
|
||||
instance_id: config.instance_id,
|
||||
managed: config.managed,
|
||||
};
|
||||
|
||||
Ok(config)
|
||||
@ -103,6 +114,7 @@ impl StaticConfig {
|
||||
let multi_tenant_domain = std::env::var("ONEFUZZ_MULTI_TENANT_DOMAIN").ok();
|
||||
let onefuzz_url = Url::parse(&std::env::var("ONEFUZZ_URL")?)?;
|
||||
let pool_name = std::env::var("ONEFUZZ_POOL")?;
|
||||
let is_unmanaged = std::env::var("ONEFUZZ_IS_UNMANAGED").is_ok();
|
||||
|
||||
let heartbeat_queue = if let Ok(key) = std::env::var("ONEFUZZ_HEARTBEAT") {
|
||||
Some(Url::parse(&key)?)
|
||||
@ -142,6 +154,7 @@ impl StaticConfig {
|
||||
microsoft_telemetry_key,
|
||||
heartbeat_queue,
|
||||
instance_id,
|
||||
managed: !is_unmanaged,
|
||||
})
|
||||
}
|
||||
|
||||
@ -213,7 +226,8 @@ impl Registration {
|
||||
.append_pair("machine_id", &machine_id.to_string())
|
||||
.append_pair("machine_name", &machine_name)
|
||||
.append_pair("pool_name", &config.pool_name)
|
||||
.append_pair("version", env!("ONEFUZZ_VERSION"));
|
||||
.append_pair("version", env!("ONEFUZZ_VERSION"))
|
||||
.append_pair("os", std::env::consts::OS);
|
||||
|
||||
if managed {
|
||||
let scaleset = onefuzz::machine_id::get_scaleset_name().await?;
|
||||
|
@ -277,7 +277,7 @@ async fn run_agent(config: StaticConfig) -> Result<()> {
|
||||
let registration = match config::Registration::load_existing(config.clone()).await {
|
||||
Ok(registration) => registration,
|
||||
Err(_) => {
|
||||
if scaleset.is_some() {
|
||||
if config.managed {
|
||||
config::Registration::create_managed(config.clone()).await?
|
||||
} else {
|
||||
config::Registration::create_unmanaged(config.clone()).await?
|
||||
|
@ -134,7 +134,6 @@ impl ClientCredentials {
|
||||
|
||||
let response = reqwest::Client::new()
|
||||
.post(url)
|
||||
.header("Content-Length", "0")
|
||||
.form(&[
|
||||
("client_id", self.client_id.to_hyphenated().to_string()),
|
||||
("client_secret", self.client_secret.expose_ref().to_string()),
|
||||
|
@ -1268,13 +1268,7 @@ class Pool(Endpoint):
|
||||
if pool.config is None:
|
||||
raise Exception("Missing AgentConfig in response")
|
||||
|
||||
config = pool.config
|
||||
config.client_credentials = models.ClientCredentials( # nosec - bandit consider this a hard coded password
|
||||
client_id=pool.client_id,
|
||||
client_secret="<client secret>",
|
||||
)
|
||||
|
||||
return config
|
||||
return pool.config
|
||||
|
||||
def shutdown(self, name: str, *, now: bool = False) -> responses.BoolResult:
|
||||
expanded_name = self._disambiguate(
|
||||
|
@ -339,6 +339,7 @@ class AgentConfig(BaseModel):
|
||||
microsoft_telemetry_key: Optional[str]
|
||||
multi_tenant_domain: Optional[str]
|
||||
instance_id: UUID
|
||||
managed: Optional[bool] = Field(default=True)
|
||||
|
||||
|
||||
class TaskUnitConfig(BaseModel):
|
||||
|
Reference in New Issue
Block a user