Bug fixes related to the unmanaged nodes (#2632)

* bug fixes related to the unmanaged nodes
- fix the token request in the client
- adding a is_unmnaged field to the agentConfig
- fix typo in the appId claim type
- fix typo in the UnmanagedNode claim value
- fix query PoolOperation.GetByClientId

* remove unused import

* build fix

* change unmanaged  field to managed

* Apply suggestions from code review

Co-authored-by: Teo Voinea <58236992+tevoinea@users.noreply.github.com>

Co-authored-by: Teo Voinea <58236992+tevoinea@users.noreply.github.com>
This commit is contained in:
Cheick Keita
2022-11-17 12:51:52 -08:00
committed by GitHub
parent 5bc3dac1ae
commit 1f46388e6d
11 changed files with 37 additions and 29 deletions

View File

@ -133,7 +133,8 @@ public class Pool {
HeartbeatQueue: queueSas, HeartbeatQueue: queueSas,
InstanceId: instanceId, InstanceId: instanceId,
ClientCredentials: null, ClientCredentials: null,
MultiTenantDomain: _context.ServiceConfiguration.MultiTenantDomain) MultiTenantDomain: _context.ServiceConfiguration.MultiTenantDomain,
Managed: p.Managed)
}; };
} }
} }

View File

@ -678,12 +678,11 @@ public record AgentConfig(
string? InstanceTelemetryKey, string? InstanceTelemetryKey,
string? MicrosoftTelemetryKey, string? MicrosoftTelemetryKey,
string? MultiTenantDomain, string? MultiTenantDomain,
Guid InstanceId Guid InstanceId,
bool? Managed = true
); );
public record Vm( public record Vm(
string Name, string Name,
Region Region, Region Region,

View File

@ -77,7 +77,7 @@ public class UserCredentials : IUserCredentials {
switch (claim.Type) { switch (claim.Type) {
case "oid": case "oid":
return acc with { UserInfo = acc.UserInfo with { ObjectId = Guid.Parse(claim.Value) } }; return acc with { UserInfo = acc.UserInfo with { ObjectId = Guid.Parse(claim.Value) } };
case "appId": case "appid":
return acc with { UserInfo = acc.UserInfo with { ApplicationId = Guid.Parse(claim.Value) } }; return acc with { UserInfo = acc.UserInfo with { ApplicationId = Guid.Parse(claim.Value) } };
case "upn": case "upn":
return acc with { UserInfo = acc.UserInfo with { Upn = claim.Value } }; return acc with { UserInfo = acc.UserInfo with { Upn = claim.Value } };
@ -88,7 +88,6 @@ public class UserCredentials : IUserCredentials {
return acc; return acc;
} }
}); });
return OneFuzzResult<UserAuthInfo>.Ok(userInfo); return OneFuzzResult<UserAuthInfo>.Ok(userInfo);
} else { } else {
var tenantsStr = allowedTenants.OkV is null ? "null" : String.Join(';', allowedTenants.OkV!); var tenantsStr = allowedTenants.OkV is null ? "null" : String.Join(';', allowedTenants.OkV!);

View File

@ -30,8 +30,7 @@ public class EndpointAuthorization : IEndpointAuthorization {
private readonly IOnefuzzContext _context; private readonly IOnefuzzContext _context;
private readonly ILogTracer _log; private readonly ILogTracer _log;
private readonly GraphServiceClient _graphClient; private readonly GraphServiceClient _graphClient;
private static readonly HashSet<string> AgentRoles = new HashSet<string> { "UnmanagedNode", "ManagedNode" };
private static readonly HashSet<string> AgentRoles = new HashSet<string> { "UnmamagedNode", "ManagedNode" };
public EndpointAuthorization(IOnefuzzContext context, ILogTracer log, GraphServiceClient graphClient) { public EndpointAuthorization(IOnefuzzContext context, ILogTracer log, GraphServiceClient graphClient) {
_context = context; _context = context;
@ -46,10 +45,10 @@ public class EndpointAuthorization : IEndpointAuthorization {
return await _context.RequestHandling.NotOk(req, tokenResult.ErrorV, "token verification", HttpStatusCode.Unauthorized); return await _context.RequestHandling.NotOk(req, tokenResult.ErrorV, "token verification", HttpStatusCode.Unauthorized);
} }
var token = tokenResult.OkV; var token = tokenResult.OkV.UserInfo;
if (await IsUser(token)) { if (await IsUser(tokenResult.OkV)) {
if (!allowUser) { if (!allowUser) {
return await Reject(req, tokenResult.OkV.UserInfo); return await Reject(req, token);
} }
var access = await CheckAccess(req); var access = await CheckAccess(req);
@ -58,8 +57,8 @@ public class EndpointAuthorization : IEndpointAuthorization {
} }
} }
if (await IsAgent(token) && !allowAgent) { if (await IsAgent(tokenResult.OkV) && !allowAgent) {
return await Reject(req, tokenResult.OkV.UserInfo); return await Reject(req, token);
} }
return await method(req); return await method(req);
@ -201,7 +200,9 @@ public class EndpointAuthorization : IEndpointAuthorization {
} }
var principalId = await _context.Creds.GetScalesetPrincipalId(); var principalId = await _context.Creds.GetScalesetPrincipalId();
return principalId == tokenData.ObjectId; if (principalId == tokenData.ObjectId) {
return true;
}
} }
if (!tokenData.ApplicationId.HasValue) { if (!tokenData.ApplicationId.HasValue) {

View File

@ -220,7 +220,8 @@ public class Extensions : IExtensions {
InstanceTelemetryKey: _context.ServiceConfiguration.ApplicationInsightsInstrumentationKey, InstanceTelemetryKey: _context.ServiceConfiguration.ApplicationInsightsInstrumentationKey,
MicrosoftTelemetryKey: _context.ServiceConfiguration.OneFuzzTelemetry, MicrosoftTelemetryKey: _context.ServiceConfiguration.OneFuzzTelemetry,
MultiTenantDomain: _context.ServiceConfiguration.MultiTenantDomain, MultiTenantDomain: _context.ServiceConfiguration.MultiTenantDomain,
InstanceId: instanceId InstanceId: instanceId,
Managed: pool.Managed
); );
var fileName = $"{pool.Name}/config.json"; var fileName = $"{pool.Name}/config.json";

View File

@ -1,6 +1,5 @@
using System.Threading.Tasks; using System.Threading.Tasks;
using ApiService.OneFuzzLib.Orm; using ApiService.OneFuzzLib.Orm;
using Azure.Data.Tables;
namespace Microsoft.OneFuzz.Service; namespace Microsoft.OneFuzz.Service;
public interface IPoolOperations : IStatefulOrm<Pool, PoolState> { public interface IPoolOperations : IStatefulOrm<Pool, PoolState> {
@ -89,7 +88,7 @@ public class PoolOperations : StatefulOrm<Pool, PoolState, PoolOperations>, IPoo
} }
public IAsyncEnumerable<Pool> GetByClientId(Guid clientId) { public IAsyncEnumerable<Pool> GetByClientId(Guid clientId) {
return QueryAsync(filter: TableClient.CreateQueryFilter($"client_id eq {clientId}")); return QueryAsync(filter: $"client_id eq '{clientId}'");
} }
public string GetPoolQueue(Guid poolId) public string GetPoolQueue(Guid poolId)

View File

@ -34,12 +34,19 @@ pub struct StaticConfig {
pub heartbeat_queue: Option<Url>, pub heartbeat_queue: Option<Url>,
pub instance_id: Uuid, pub instance_id: Uuid,
#[serde(default = "default_as_true")]
pub managed: bool,
}
fn default_as_true() -> bool {
true
} }
// Temporary shim type to bridge the current service-provided config. // Temporary shim type to bridge the current service-provided config.
#[derive(Clone, Debug, Deserialize, Eq, PartialEq)] #[derive(Clone, Debug, Deserialize, Eq, PartialEq)]
struct RawStaticConfig { struct RawStaticConfig {
pub credentials: Option<ClientCredentials>, pub client_credentials: Option<ClientCredentials>,
pub pool_name: String, pub pool_name: String,
@ -54,13 +61,16 @@ struct RawStaticConfig {
pub heartbeat_queue: Option<Url>, pub heartbeat_queue: Option<Url>,
pub instance_id: Uuid, pub instance_id: Uuid,
#[serde(default = "default_as_true")]
pub managed: bool,
} }
impl StaticConfig { impl StaticConfig {
pub fn new(data: &[u8]) -> Result<Self> { pub fn new(data: &[u8]) -> Result<Self> {
let config: RawStaticConfig = serde_json::from_slice(data)?; let config: RawStaticConfig = serde_json::from_slice(data)?;
let credentials = match config.credentials { let credentials = match config.client_credentials {
Some(client) => client.into(), Some(client) => client.into(),
None => { None => {
// Remove trailing `/`, which is treated as a distinct resource. // Remove trailing `/`, which is treated as a distinct resource.
@ -83,6 +93,7 @@ impl StaticConfig {
instance_telemetry_key: config.instance_telemetry_key, instance_telemetry_key: config.instance_telemetry_key,
heartbeat_queue: config.heartbeat_queue, heartbeat_queue: config.heartbeat_queue,
instance_id: config.instance_id, instance_id: config.instance_id,
managed: config.managed,
}; };
Ok(config) Ok(config)
@ -103,6 +114,7 @@ impl StaticConfig {
let multi_tenant_domain = std::env::var("ONEFUZZ_MULTI_TENANT_DOMAIN").ok(); let multi_tenant_domain = std::env::var("ONEFUZZ_MULTI_TENANT_DOMAIN").ok();
let onefuzz_url = Url::parse(&std::env::var("ONEFUZZ_URL")?)?; let onefuzz_url = Url::parse(&std::env::var("ONEFUZZ_URL")?)?;
let pool_name = std::env::var("ONEFUZZ_POOL")?; let pool_name = std::env::var("ONEFUZZ_POOL")?;
let is_unmanaged = std::env::var("ONEFUZZ_IS_UNMANAGED").is_ok();
let heartbeat_queue = if let Ok(key) = std::env::var("ONEFUZZ_HEARTBEAT") { let heartbeat_queue = if let Ok(key) = std::env::var("ONEFUZZ_HEARTBEAT") {
Some(Url::parse(&key)?) Some(Url::parse(&key)?)
@ -142,6 +154,7 @@ impl StaticConfig {
microsoft_telemetry_key, microsoft_telemetry_key,
heartbeat_queue, heartbeat_queue,
instance_id, instance_id,
managed: !is_unmanaged,
}) })
} }
@ -213,7 +226,8 @@ impl Registration {
.append_pair("machine_id", &machine_id.to_string()) .append_pair("machine_id", &machine_id.to_string())
.append_pair("machine_name", &machine_name) .append_pair("machine_name", &machine_name)
.append_pair("pool_name", &config.pool_name) .append_pair("pool_name", &config.pool_name)
.append_pair("version", env!("ONEFUZZ_VERSION")); .append_pair("version", env!("ONEFUZZ_VERSION"))
.append_pair("os", std::env::consts::OS);
if managed { if managed {
let scaleset = onefuzz::machine_id::get_scaleset_name().await?; let scaleset = onefuzz::machine_id::get_scaleset_name().await?;

View File

@ -277,7 +277,7 @@ async fn run_agent(config: StaticConfig) -> Result<()> {
let registration = match config::Registration::load_existing(config.clone()).await { let registration = match config::Registration::load_existing(config.clone()).await {
Ok(registration) => registration, Ok(registration) => registration,
Err(_) => { Err(_) => {
if scaleset.is_some() { if config.managed {
config::Registration::create_managed(config.clone()).await? config::Registration::create_managed(config.clone()).await?
} else { } else {
config::Registration::create_unmanaged(config.clone()).await? config::Registration::create_unmanaged(config.clone()).await?

View File

@ -134,7 +134,6 @@ impl ClientCredentials {
let response = reqwest::Client::new() let response = reqwest::Client::new()
.post(url) .post(url)
.header("Content-Length", "0")
.form(&[ .form(&[
("client_id", self.client_id.to_hyphenated().to_string()), ("client_id", self.client_id.to_hyphenated().to_string()),
("client_secret", self.client_secret.expose_ref().to_string()), ("client_secret", self.client_secret.expose_ref().to_string()),

View File

@ -1268,13 +1268,7 @@ class Pool(Endpoint):
if pool.config is None: if pool.config is None:
raise Exception("Missing AgentConfig in response") raise Exception("Missing AgentConfig in response")
config = pool.config return pool.config
config.client_credentials = models.ClientCredentials( # nosec - bandit consider this a hard coded password
client_id=pool.client_id,
client_secret="<client secret>",
)
return config
def shutdown(self, name: str, *, now: bool = False) -> responses.BoolResult: def shutdown(self, name: str, *, now: bool = False) -> responses.BoolResult:
expanded_name = self._disambiguate( expanded_name = self._disambiguate(

View File

@ -339,6 +339,7 @@ class AgentConfig(BaseModel):
microsoft_telemetry_key: Optional[str] microsoft_telemetry_key: Optional[str]
multi_tenant_domain: Optional[str] multi_tenant_domain: Optional[str]
instance_id: UUID instance_id: UUID
managed: Optional[bool] = Field(default=True)
class TaskUnitConfig(BaseModel): class TaskUnitConfig(BaseModel):