mirror of
https://github.com/microsoft/onefuzz.git
synced 2025-06-18 04:38:09 +00:00
Bug fixes related to the unmanaged nodes (#2632)
* bug fixes related to the unmanaged nodes - fix the token request in the client - adding a is_unmnaged field to the agentConfig - fix typo in the appId claim type - fix typo in the UnmanagedNode claim value - fix query PoolOperation.GetByClientId * remove unused import * build fix * change unmanaged field to managed * Apply suggestions from code review Co-authored-by: Teo Voinea <58236992+tevoinea@users.noreply.github.com> Co-authored-by: Teo Voinea <58236992+tevoinea@users.noreply.github.com>
This commit is contained in:
@ -133,7 +133,8 @@ public class Pool {
|
|||||||
HeartbeatQueue: queueSas,
|
HeartbeatQueue: queueSas,
|
||||||
InstanceId: instanceId,
|
InstanceId: instanceId,
|
||||||
ClientCredentials: null,
|
ClientCredentials: null,
|
||||||
MultiTenantDomain: _context.ServiceConfiguration.MultiTenantDomain)
|
MultiTenantDomain: _context.ServiceConfiguration.MultiTenantDomain,
|
||||||
|
Managed: p.Managed)
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -678,12 +678,11 @@ public record AgentConfig(
|
|||||||
string? InstanceTelemetryKey,
|
string? InstanceTelemetryKey,
|
||||||
string? MicrosoftTelemetryKey,
|
string? MicrosoftTelemetryKey,
|
||||||
string? MultiTenantDomain,
|
string? MultiTenantDomain,
|
||||||
Guid InstanceId
|
Guid InstanceId,
|
||||||
|
bool? Managed = true
|
||||||
);
|
);
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
public record Vm(
|
public record Vm(
|
||||||
string Name,
|
string Name,
|
||||||
Region Region,
|
Region Region,
|
||||||
|
@ -77,7 +77,7 @@ public class UserCredentials : IUserCredentials {
|
|||||||
switch (claim.Type) {
|
switch (claim.Type) {
|
||||||
case "oid":
|
case "oid":
|
||||||
return acc with { UserInfo = acc.UserInfo with { ObjectId = Guid.Parse(claim.Value) } };
|
return acc with { UserInfo = acc.UserInfo with { ObjectId = Guid.Parse(claim.Value) } };
|
||||||
case "appId":
|
case "appid":
|
||||||
return acc with { UserInfo = acc.UserInfo with { ApplicationId = Guid.Parse(claim.Value) } };
|
return acc with { UserInfo = acc.UserInfo with { ApplicationId = Guid.Parse(claim.Value) } };
|
||||||
case "upn":
|
case "upn":
|
||||||
return acc with { UserInfo = acc.UserInfo with { Upn = claim.Value } };
|
return acc with { UserInfo = acc.UserInfo with { Upn = claim.Value } };
|
||||||
@ -88,7 +88,6 @@ public class UserCredentials : IUserCredentials {
|
|||||||
return acc;
|
return acc;
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
return OneFuzzResult<UserAuthInfo>.Ok(userInfo);
|
return OneFuzzResult<UserAuthInfo>.Ok(userInfo);
|
||||||
} else {
|
} else {
|
||||||
var tenantsStr = allowedTenants.OkV is null ? "null" : String.Join(';', allowedTenants.OkV!);
|
var tenantsStr = allowedTenants.OkV is null ? "null" : String.Join(';', allowedTenants.OkV!);
|
||||||
|
@ -30,8 +30,7 @@ public class EndpointAuthorization : IEndpointAuthorization {
|
|||||||
private readonly IOnefuzzContext _context;
|
private readonly IOnefuzzContext _context;
|
||||||
private readonly ILogTracer _log;
|
private readonly ILogTracer _log;
|
||||||
private readonly GraphServiceClient _graphClient;
|
private readonly GraphServiceClient _graphClient;
|
||||||
|
private static readonly HashSet<string> AgentRoles = new HashSet<string> { "UnmanagedNode", "ManagedNode" };
|
||||||
private static readonly HashSet<string> AgentRoles = new HashSet<string> { "UnmamagedNode", "ManagedNode" };
|
|
||||||
|
|
||||||
public EndpointAuthorization(IOnefuzzContext context, ILogTracer log, GraphServiceClient graphClient) {
|
public EndpointAuthorization(IOnefuzzContext context, ILogTracer log, GraphServiceClient graphClient) {
|
||||||
_context = context;
|
_context = context;
|
||||||
@ -46,10 +45,10 @@ public class EndpointAuthorization : IEndpointAuthorization {
|
|||||||
return await _context.RequestHandling.NotOk(req, tokenResult.ErrorV, "token verification", HttpStatusCode.Unauthorized);
|
return await _context.RequestHandling.NotOk(req, tokenResult.ErrorV, "token verification", HttpStatusCode.Unauthorized);
|
||||||
}
|
}
|
||||||
|
|
||||||
var token = tokenResult.OkV;
|
var token = tokenResult.OkV.UserInfo;
|
||||||
if (await IsUser(token)) {
|
if (await IsUser(tokenResult.OkV)) {
|
||||||
if (!allowUser) {
|
if (!allowUser) {
|
||||||
return await Reject(req, tokenResult.OkV.UserInfo);
|
return await Reject(req, token);
|
||||||
}
|
}
|
||||||
|
|
||||||
var access = await CheckAccess(req);
|
var access = await CheckAccess(req);
|
||||||
@ -58,8 +57,8 @@ public class EndpointAuthorization : IEndpointAuthorization {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (await IsAgent(token) && !allowAgent) {
|
if (await IsAgent(tokenResult.OkV) && !allowAgent) {
|
||||||
return await Reject(req, tokenResult.OkV.UserInfo);
|
return await Reject(req, token);
|
||||||
}
|
}
|
||||||
|
|
||||||
return await method(req);
|
return await method(req);
|
||||||
@ -201,7 +200,9 @@ public class EndpointAuthorization : IEndpointAuthorization {
|
|||||||
}
|
}
|
||||||
|
|
||||||
var principalId = await _context.Creds.GetScalesetPrincipalId();
|
var principalId = await _context.Creds.GetScalesetPrincipalId();
|
||||||
return principalId == tokenData.ObjectId;
|
if (principalId == tokenData.ObjectId) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!tokenData.ApplicationId.HasValue) {
|
if (!tokenData.ApplicationId.HasValue) {
|
||||||
|
@ -220,7 +220,8 @@ public class Extensions : IExtensions {
|
|||||||
InstanceTelemetryKey: _context.ServiceConfiguration.ApplicationInsightsInstrumentationKey,
|
InstanceTelemetryKey: _context.ServiceConfiguration.ApplicationInsightsInstrumentationKey,
|
||||||
MicrosoftTelemetryKey: _context.ServiceConfiguration.OneFuzzTelemetry,
|
MicrosoftTelemetryKey: _context.ServiceConfiguration.OneFuzzTelemetry,
|
||||||
MultiTenantDomain: _context.ServiceConfiguration.MultiTenantDomain,
|
MultiTenantDomain: _context.ServiceConfiguration.MultiTenantDomain,
|
||||||
InstanceId: instanceId
|
InstanceId: instanceId,
|
||||||
|
Managed: pool.Managed
|
||||||
);
|
);
|
||||||
|
|
||||||
var fileName = $"{pool.Name}/config.json";
|
var fileName = $"{pool.Name}/config.json";
|
||||||
|
@ -1,6 +1,5 @@
|
|||||||
using System.Threading.Tasks;
|
using System.Threading.Tasks;
|
||||||
using ApiService.OneFuzzLib.Orm;
|
using ApiService.OneFuzzLib.Orm;
|
||||||
using Azure.Data.Tables;
|
|
||||||
namespace Microsoft.OneFuzz.Service;
|
namespace Microsoft.OneFuzz.Service;
|
||||||
|
|
||||||
public interface IPoolOperations : IStatefulOrm<Pool, PoolState> {
|
public interface IPoolOperations : IStatefulOrm<Pool, PoolState> {
|
||||||
@ -89,7 +88,7 @@ public class PoolOperations : StatefulOrm<Pool, PoolState, PoolOperations>, IPoo
|
|||||||
}
|
}
|
||||||
|
|
||||||
public IAsyncEnumerable<Pool> GetByClientId(Guid clientId) {
|
public IAsyncEnumerable<Pool> GetByClientId(Guid clientId) {
|
||||||
return QueryAsync(filter: TableClient.CreateQueryFilter($"client_id eq {clientId}"));
|
return QueryAsync(filter: $"client_id eq '{clientId}'");
|
||||||
}
|
}
|
||||||
|
|
||||||
public string GetPoolQueue(Guid poolId)
|
public string GetPoolQueue(Guid poolId)
|
||||||
|
@ -34,12 +34,19 @@ pub struct StaticConfig {
|
|||||||
pub heartbeat_queue: Option<Url>,
|
pub heartbeat_queue: Option<Url>,
|
||||||
|
|
||||||
pub instance_id: Uuid,
|
pub instance_id: Uuid,
|
||||||
|
|
||||||
|
#[serde(default = "default_as_true")]
|
||||||
|
pub managed: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
fn default_as_true() -> bool {
|
||||||
|
true
|
||||||
}
|
}
|
||||||
|
|
||||||
// Temporary shim type to bridge the current service-provided config.
|
// Temporary shim type to bridge the current service-provided config.
|
||||||
#[derive(Clone, Debug, Deserialize, Eq, PartialEq)]
|
#[derive(Clone, Debug, Deserialize, Eq, PartialEq)]
|
||||||
struct RawStaticConfig {
|
struct RawStaticConfig {
|
||||||
pub credentials: Option<ClientCredentials>,
|
pub client_credentials: Option<ClientCredentials>,
|
||||||
|
|
||||||
pub pool_name: String,
|
pub pool_name: String,
|
||||||
|
|
||||||
@ -54,13 +61,16 @@ struct RawStaticConfig {
|
|||||||
pub heartbeat_queue: Option<Url>,
|
pub heartbeat_queue: Option<Url>,
|
||||||
|
|
||||||
pub instance_id: Uuid,
|
pub instance_id: Uuid,
|
||||||
|
|
||||||
|
#[serde(default = "default_as_true")]
|
||||||
|
pub managed: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl StaticConfig {
|
impl StaticConfig {
|
||||||
pub fn new(data: &[u8]) -> Result<Self> {
|
pub fn new(data: &[u8]) -> Result<Self> {
|
||||||
let config: RawStaticConfig = serde_json::from_slice(data)?;
|
let config: RawStaticConfig = serde_json::from_slice(data)?;
|
||||||
|
|
||||||
let credentials = match config.credentials {
|
let credentials = match config.client_credentials {
|
||||||
Some(client) => client.into(),
|
Some(client) => client.into(),
|
||||||
None => {
|
None => {
|
||||||
// Remove trailing `/`, which is treated as a distinct resource.
|
// Remove trailing `/`, which is treated as a distinct resource.
|
||||||
@ -83,6 +93,7 @@ impl StaticConfig {
|
|||||||
instance_telemetry_key: config.instance_telemetry_key,
|
instance_telemetry_key: config.instance_telemetry_key,
|
||||||
heartbeat_queue: config.heartbeat_queue,
|
heartbeat_queue: config.heartbeat_queue,
|
||||||
instance_id: config.instance_id,
|
instance_id: config.instance_id,
|
||||||
|
managed: config.managed,
|
||||||
};
|
};
|
||||||
|
|
||||||
Ok(config)
|
Ok(config)
|
||||||
@ -103,6 +114,7 @@ impl StaticConfig {
|
|||||||
let multi_tenant_domain = std::env::var("ONEFUZZ_MULTI_TENANT_DOMAIN").ok();
|
let multi_tenant_domain = std::env::var("ONEFUZZ_MULTI_TENANT_DOMAIN").ok();
|
||||||
let onefuzz_url = Url::parse(&std::env::var("ONEFUZZ_URL")?)?;
|
let onefuzz_url = Url::parse(&std::env::var("ONEFUZZ_URL")?)?;
|
||||||
let pool_name = std::env::var("ONEFUZZ_POOL")?;
|
let pool_name = std::env::var("ONEFUZZ_POOL")?;
|
||||||
|
let is_unmanaged = std::env::var("ONEFUZZ_IS_UNMANAGED").is_ok();
|
||||||
|
|
||||||
let heartbeat_queue = if let Ok(key) = std::env::var("ONEFUZZ_HEARTBEAT") {
|
let heartbeat_queue = if let Ok(key) = std::env::var("ONEFUZZ_HEARTBEAT") {
|
||||||
Some(Url::parse(&key)?)
|
Some(Url::parse(&key)?)
|
||||||
@ -142,6 +154,7 @@ impl StaticConfig {
|
|||||||
microsoft_telemetry_key,
|
microsoft_telemetry_key,
|
||||||
heartbeat_queue,
|
heartbeat_queue,
|
||||||
instance_id,
|
instance_id,
|
||||||
|
managed: !is_unmanaged,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -213,7 +226,8 @@ impl Registration {
|
|||||||
.append_pair("machine_id", &machine_id.to_string())
|
.append_pair("machine_id", &machine_id.to_string())
|
||||||
.append_pair("machine_name", &machine_name)
|
.append_pair("machine_name", &machine_name)
|
||||||
.append_pair("pool_name", &config.pool_name)
|
.append_pair("pool_name", &config.pool_name)
|
||||||
.append_pair("version", env!("ONEFUZZ_VERSION"));
|
.append_pair("version", env!("ONEFUZZ_VERSION"))
|
||||||
|
.append_pair("os", std::env::consts::OS);
|
||||||
|
|
||||||
if managed {
|
if managed {
|
||||||
let scaleset = onefuzz::machine_id::get_scaleset_name().await?;
|
let scaleset = onefuzz::machine_id::get_scaleset_name().await?;
|
||||||
|
@ -277,7 +277,7 @@ async fn run_agent(config: StaticConfig) -> Result<()> {
|
|||||||
let registration = match config::Registration::load_existing(config.clone()).await {
|
let registration = match config::Registration::load_existing(config.clone()).await {
|
||||||
Ok(registration) => registration,
|
Ok(registration) => registration,
|
||||||
Err(_) => {
|
Err(_) => {
|
||||||
if scaleset.is_some() {
|
if config.managed {
|
||||||
config::Registration::create_managed(config.clone()).await?
|
config::Registration::create_managed(config.clone()).await?
|
||||||
} else {
|
} else {
|
||||||
config::Registration::create_unmanaged(config.clone()).await?
|
config::Registration::create_unmanaged(config.clone()).await?
|
||||||
|
@ -134,7 +134,6 @@ impl ClientCredentials {
|
|||||||
|
|
||||||
let response = reqwest::Client::new()
|
let response = reqwest::Client::new()
|
||||||
.post(url)
|
.post(url)
|
||||||
.header("Content-Length", "0")
|
|
||||||
.form(&[
|
.form(&[
|
||||||
("client_id", self.client_id.to_hyphenated().to_string()),
|
("client_id", self.client_id.to_hyphenated().to_string()),
|
||||||
("client_secret", self.client_secret.expose_ref().to_string()),
|
("client_secret", self.client_secret.expose_ref().to_string()),
|
||||||
|
@ -1268,13 +1268,7 @@ class Pool(Endpoint):
|
|||||||
if pool.config is None:
|
if pool.config is None:
|
||||||
raise Exception("Missing AgentConfig in response")
|
raise Exception("Missing AgentConfig in response")
|
||||||
|
|
||||||
config = pool.config
|
return pool.config
|
||||||
config.client_credentials = models.ClientCredentials( # nosec - bandit consider this a hard coded password
|
|
||||||
client_id=pool.client_id,
|
|
||||||
client_secret="<client secret>",
|
|
||||||
)
|
|
||||||
|
|
||||||
return config
|
|
||||||
|
|
||||||
def shutdown(self, name: str, *, now: bool = False) -> responses.BoolResult:
|
def shutdown(self, name: str, *, now: bool = False) -> responses.BoolResult:
|
||||||
expanded_name = self._disambiguate(
|
expanded_name = self._disambiguate(
|
||||||
|
@ -339,6 +339,7 @@ class AgentConfig(BaseModel):
|
|||||||
microsoft_telemetry_key: Optional[str]
|
microsoft_telemetry_key: Optional[str]
|
||||||
multi_tenant_domain: Optional[str]
|
multi_tenant_domain: Optional[str]
|
||||||
instance_id: UUID
|
instance_id: UUID
|
||||||
|
managed: Optional[bool] = Field(default=True)
|
||||||
|
|
||||||
|
|
||||||
class TaskUnitConfig(BaseModel):
|
class TaskUnitConfig(BaseModel):
|
||||||
|
Reference in New Issue
Block a user