Bug fixes related to the unmanaged nodes (#2632)

* bug fixes related to the unmanaged nodes
- fix the token request in the client
- adding a is_unmnaged field to the agentConfig
- fix typo in the appId claim type
- fix typo in the UnmanagedNode claim value
- fix query PoolOperation.GetByClientId

* remove unused import

* build fix

* change unmanaged  field to managed

* Apply suggestions from code review

Co-authored-by: Teo Voinea <58236992+tevoinea@users.noreply.github.com>

Co-authored-by: Teo Voinea <58236992+tevoinea@users.noreply.github.com>
This commit is contained in:
Cheick Keita
2022-11-17 12:51:52 -08:00
committed by GitHub
parent 5bc3dac1ae
commit 1f46388e6d
11 changed files with 37 additions and 29 deletions

View File

@ -133,7 +133,8 @@ public class Pool {
HeartbeatQueue: queueSas,
InstanceId: instanceId,
ClientCredentials: null,
MultiTenantDomain: _context.ServiceConfiguration.MultiTenantDomain)
MultiTenantDomain: _context.ServiceConfiguration.MultiTenantDomain,
Managed: p.Managed)
};
}
}

View File

@ -678,12 +678,11 @@ public record AgentConfig(
string? InstanceTelemetryKey,
string? MicrosoftTelemetryKey,
string? MultiTenantDomain,
Guid InstanceId
Guid InstanceId,
bool? Managed = true
);
public record Vm(
string Name,
Region Region,

View File

@ -77,7 +77,7 @@ public class UserCredentials : IUserCredentials {
switch (claim.Type) {
case "oid":
return acc with { UserInfo = acc.UserInfo with { ObjectId = Guid.Parse(claim.Value) } };
case "appId":
case "appid":
return acc with { UserInfo = acc.UserInfo with { ApplicationId = Guid.Parse(claim.Value) } };
case "upn":
return acc with { UserInfo = acc.UserInfo with { Upn = claim.Value } };
@ -88,7 +88,6 @@ public class UserCredentials : IUserCredentials {
return acc;
}
});
return OneFuzzResult<UserAuthInfo>.Ok(userInfo);
} else {
var tenantsStr = allowedTenants.OkV is null ? "null" : String.Join(';', allowedTenants.OkV!);

View File

@ -30,8 +30,7 @@ public class EndpointAuthorization : IEndpointAuthorization {
private readonly IOnefuzzContext _context;
private readonly ILogTracer _log;
private readonly GraphServiceClient _graphClient;
private static readonly HashSet<string> AgentRoles = new HashSet<string> { "UnmamagedNode", "ManagedNode" };
private static readonly HashSet<string> AgentRoles = new HashSet<string> { "UnmanagedNode", "ManagedNode" };
public EndpointAuthorization(IOnefuzzContext context, ILogTracer log, GraphServiceClient graphClient) {
_context = context;
@ -46,10 +45,10 @@ public class EndpointAuthorization : IEndpointAuthorization {
return await _context.RequestHandling.NotOk(req, tokenResult.ErrorV, "token verification", HttpStatusCode.Unauthorized);
}
var token = tokenResult.OkV;
if (await IsUser(token)) {
var token = tokenResult.OkV.UserInfo;
if (await IsUser(tokenResult.OkV)) {
if (!allowUser) {
return await Reject(req, tokenResult.OkV.UserInfo);
return await Reject(req, token);
}
var access = await CheckAccess(req);
@ -58,8 +57,8 @@ public class EndpointAuthorization : IEndpointAuthorization {
}
}
if (await IsAgent(token) && !allowAgent) {
return await Reject(req, tokenResult.OkV.UserInfo);
if (await IsAgent(tokenResult.OkV) && !allowAgent) {
return await Reject(req, token);
}
return await method(req);
@ -201,7 +200,9 @@ public class EndpointAuthorization : IEndpointAuthorization {
}
var principalId = await _context.Creds.GetScalesetPrincipalId();
return principalId == tokenData.ObjectId;
if (principalId == tokenData.ObjectId) {
return true;
}
}
if (!tokenData.ApplicationId.HasValue) {

View File

@ -220,7 +220,8 @@ public class Extensions : IExtensions {
InstanceTelemetryKey: _context.ServiceConfiguration.ApplicationInsightsInstrumentationKey,
MicrosoftTelemetryKey: _context.ServiceConfiguration.OneFuzzTelemetry,
MultiTenantDomain: _context.ServiceConfiguration.MultiTenantDomain,
InstanceId: instanceId
InstanceId: instanceId,
Managed: pool.Managed
);
var fileName = $"{pool.Name}/config.json";

View File

@ -1,6 +1,5 @@
using System.Threading.Tasks;
using ApiService.OneFuzzLib.Orm;
using Azure.Data.Tables;
namespace Microsoft.OneFuzz.Service;
public interface IPoolOperations : IStatefulOrm<Pool, PoolState> {
@ -89,7 +88,7 @@ public class PoolOperations : StatefulOrm<Pool, PoolState, PoolOperations>, IPoo
}
public IAsyncEnumerable<Pool> GetByClientId(Guid clientId) {
return QueryAsync(filter: TableClient.CreateQueryFilter($"client_id eq {clientId}"));
return QueryAsync(filter: $"client_id eq '{clientId}'");
}
public string GetPoolQueue(Guid poolId)

View File

@ -34,12 +34,19 @@ pub struct StaticConfig {
pub heartbeat_queue: Option<Url>,
pub instance_id: Uuid,
#[serde(default = "default_as_true")]
pub managed: bool,
}
fn default_as_true() -> bool {
true
}
// Temporary shim type to bridge the current service-provided config.
#[derive(Clone, Debug, Deserialize, Eq, PartialEq)]
struct RawStaticConfig {
pub credentials: Option<ClientCredentials>,
pub client_credentials: Option<ClientCredentials>,
pub pool_name: String,
@ -54,13 +61,16 @@ struct RawStaticConfig {
pub heartbeat_queue: Option<Url>,
pub instance_id: Uuid,
#[serde(default = "default_as_true")]
pub managed: bool,
}
impl StaticConfig {
pub fn new(data: &[u8]) -> Result<Self> {
let config: RawStaticConfig = serde_json::from_slice(data)?;
let credentials = match config.credentials {
let credentials = match config.client_credentials {
Some(client) => client.into(),
None => {
// Remove trailing `/`, which is treated as a distinct resource.
@ -83,6 +93,7 @@ impl StaticConfig {
instance_telemetry_key: config.instance_telemetry_key,
heartbeat_queue: config.heartbeat_queue,
instance_id: config.instance_id,
managed: config.managed,
};
Ok(config)
@ -103,6 +114,7 @@ impl StaticConfig {
let multi_tenant_domain = std::env::var("ONEFUZZ_MULTI_TENANT_DOMAIN").ok();
let onefuzz_url = Url::parse(&std::env::var("ONEFUZZ_URL")?)?;
let pool_name = std::env::var("ONEFUZZ_POOL")?;
let is_unmanaged = std::env::var("ONEFUZZ_IS_UNMANAGED").is_ok();
let heartbeat_queue = if let Ok(key) = std::env::var("ONEFUZZ_HEARTBEAT") {
Some(Url::parse(&key)?)
@ -142,6 +154,7 @@ impl StaticConfig {
microsoft_telemetry_key,
heartbeat_queue,
instance_id,
managed: !is_unmanaged,
})
}
@ -213,7 +226,8 @@ impl Registration {
.append_pair("machine_id", &machine_id.to_string())
.append_pair("machine_name", &machine_name)
.append_pair("pool_name", &config.pool_name)
.append_pair("version", env!("ONEFUZZ_VERSION"));
.append_pair("version", env!("ONEFUZZ_VERSION"))
.append_pair("os", std::env::consts::OS);
if managed {
let scaleset = onefuzz::machine_id::get_scaleset_name().await?;

View File

@ -277,7 +277,7 @@ async fn run_agent(config: StaticConfig) -> Result<()> {
let registration = match config::Registration::load_existing(config.clone()).await {
Ok(registration) => registration,
Err(_) => {
if scaleset.is_some() {
if config.managed {
config::Registration::create_managed(config.clone()).await?
} else {
config::Registration::create_unmanaged(config.clone()).await?

View File

@ -134,7 +134,6 @@ impl ClientCredentials {
let response = reqwest::Client::new()
.post(url)
.header("Content-Length", "0")
.form(&[
("client_id", self.client_id.to_hyphenated().to_string()),
("client_secret", self.client_secret.expose_ref().to_string()),

View File

@ -1268,13 +1268,7 @@ class Pool(Endpoint):
if pool.config is None:
raise Exception("Missing AgentConfig in response")
config = pool.config
config.client_credentials = models.ClientCredentials( # nosec - bandit consider this a hard coded password
client_id=pool.client_id,
client_secret="<client secret>",
)
return config
return pool.config
def shutdown(self, name: str, *, now: bool = False) -> responses.BoolResult:
expanded_name = self._disambiguate(

View File

@ -339,6 +339,7 @@ class AgentConfig(BaseModel):
microsoft_telemetry_key: Optional[str]
multi_tenant_domain: Optional[str]
instance_id: UUID
managed: Optional[bool] = Field(default=True)
class TaskUnitConfig(BaseModel):