Refactor agent commands (#1922)

* Checkpoint

* Disable the function for now

* snapshot

* Tested locally

* fmt
This commit is contained in:
Teo Voinea
2022-05-11 10:06:09 -04:00
committed by GitHub
parent c9b46e983d
commit 5f4a02503a
14 changed files with 252 additions and 39 deletions

View File

@ -6,34 +6,25 @@ namespace Microsoft.OneFuzz.Service;
public class AgentCanSchedule { public class AgentCanSchedule {
private readonly ILogTracer _log; private readonly ILogTracer _log;
private readonly IStorage _storage; private readonly IOnefuzzContext _context;
private readonly INodeOperations _nodeOperations; public AgentCanSchedule(ILogTracer log, IOnefuzzContext context) {
private readonly ITaskOperations _taskOperations;
private readonly IScalesetOperations _scalesetOperations;
public AgentCanSchedule(ILogTracer log, IStorage storage, INodeOperations nodeOperations, ITaskOperations taskOperations, IScalesetOperations scalesetOperations) {
_log = log; _log = log;
_storage = storage; _context = context;
_nodeOperations = nodeOperations;
_taskOperations = taskOperations;
_scalesetOperations = scalesetOperations;
} }
// [Function("AgentCanSchedule")] // [Function("AgentCanSchedule")]
public async Async.Task<HttpResponseData> Run([HttpTrigger] HttpRequestData req) { public async Async.Task<HttpResponseData> Run([HttpTrigger] HttpRequestData req) {
var request = await RequestHandling.ParseRequest<CanScheduleRequest>(req); var request = await RequestHandling.ParseRequest<CanScheduleRequest>(req);
if (!request.IsOk || request.OkV == null) { if (!request.IsOk || request.OkV == null) {
return await RequestHandling.NotOk(req, request.ErrorV, typeof(CanScheduleRequest).ToString(), _log); return await _context.RequestHandling.NotOk(req, request.ErrorV, typeof(CanScheduleRequest).ToString());
} }
var canScheduleRequest = request.OkV; var canScheduleRequest = request.OkV;
var node = await _nodeOperations.GetByMachineId(canScheduleRequest.MachineId); var node = await _context.NodeOperations.GetByMachineId(canScheduleRequest.MachineId);
if (node == null) { if (node == null) {
return await RequestHandling.NotOk( return await _context.RequestHandling.NotOk(
req, req,
new Error( new Error(
ErrorCode.UNABLE_TO_FIND, ErrorCode.UNABLE_TO_FIND,
@ -41,29 +32,24 @@ public class AgentCanSchedule {
"unable to find node" "unable to find node"
} }
), ),
canScheduleRequest.MachineId.ToString(), canScheduleRequest.MachineId.ToString()
_log
); );
} }
var allowed = true; var allowed = true;
var workStopped = false; var workStopped = false;
if (!await _nodeOperations.CanProcessNewWork(node)) { if (!await _context.NodeOperations.CanProcessNewWork(node)) {
allowed = false; allowed = false;
} }
var task = await _taskOperations.GetByTaskId(canScheduleRequest.TaskId); var task = await _context.TaskOperations.GetByTaskId(canScheduleRequest.TaskId);
workStopped = task == null || TaskStateHelper.ShuttingDown.Contains(task.State); workStopped = task == null || TaskStateHelper.ShuttingDown.Contains(task.State);
if (allowed) { if (allowed) {
allowed = (await _nodeOperations.AcquireScaleInProtection(node)).IsOk; allowed = (await _context.NodeOperations.AcquireScaleInProtection(node)).IsOk;
} }
return await RequestHandling.Ok( return await RequestHandling.Ok(req, new CanSchedule(allowed, workStopped));
req,
new BaseResponse[] {
new CanSchedule(allowed, workStopped)
});
} }
} }

View File

@ -0,0 +1,57 @@
using Microsoft.Azure.Functions.Worker;
using Microsoft.Azure.Functions.Worker.Http;
namespace Microsoft.OneFuzz.Service;
public class AgentCommands {
private readonly ILogTracer _log;
private readonly IOnefuzzContext _context;
public AgentCommands(ILogTracer log, IOnefuzzContext context) {
_log = log;
_context = context;
}
// [Function("AgentCommands")]
public async Async.Task<HttpResponseData> Run([HttpTrigger("get", "delete")] HttpRequestData req) {
return req.Method switch {
"GET" => await Get(req),
"DELETE" => await Delete(req),
_ => throw new NotImplementedException($"HTTP Method {req.Method} is not supported for this method")
};
}
private async Async.Task<HttpResponseData> Get(HttpRequestData req) {
var request = await RequestHandling.ParseRequest<NodeCommandGet>(req);
if (!request.IsOk || request.OkV == null) {
return await _context.RequestHandling.NotOk(req, request.ErrorV, typeof(NodeCommandGet).ToString());
}
var nodeCommand = request.OkV;
var message = await _context.NodeMessageOperations.GetMessage(nodeCommand.MachineId).FirstOrDefaultAsync();
if (message != null) {
var command = message.Message;
var messageId = message.MessageId;
var envelope = new NodeCommandEnvelope(command, messageId);
return await RequestHandling.Ok(req, new PendingNodeCommand(envelope));
} else {
return await RequestHandling.Ok(req, new PendingNodeCommand(null));
}
}
private async Async.Task<HttpResponseData> Delete(HttpRequestData req) {
var request = await RequestHandling.ParseRequest<NodeCommandDelete>(req);
if (!request.IsOk || request.OkV == null) {
return await _context.RequestHandling.NotOk(req, request.ErrorV, typeof(NodeCommandDelete).ToString());
}
var nodeCommand = request.OkV;
var message = await _context.NodeMessageOperations.GetEntityAsync(nodeCommand.MachineId.ToString(), nodeCommand.MessageId);
if (message != null) {
await _context.NodeMessageOperations.Delete(message);
}
return await RequestHandling.Ok(req, new BoolResult(true));
}
}

View File

@ -695,3 +695,8 @@ public record TaskUnitConfig(
public IContainerDef? RegressionReport { get; set; } public IContainerDef? RegressionReport { get; set; }
} }
public record NodeCommandEnvelope(
NodeCommand Command,
string MessageId
);

View File

@ -6,3 +6,12 @@ public record CanScheduleRequest(
Guid MachineId, Guid MachineId,
Guid TaskId Guid TaskId
) : BaseRequest; ) : BaseRequest;
public record NodeCommandGet(
Guid MachineId
) : BaseRequest;
public record NodeCommandDelete(
Guid MachineId,
string MessageId
) : BaseRequest;

View File

@ -1,8 +1,32 @@
namespace Microsoft.OneFuzz.Service; using System.Text.Json;
using System.Text.Json.Serialization;
public record BaseResponse(); namespace Microsoft.OneFuzz.Service;
[JsonConverter(typeof(BaseResponseConverter))]
public abstract record BaseResponse();
public record CanSchedule( public record CanSchedule(
bool Allowed, bool Allowed,
bool WorkStopped bool WorkStopped
) : BaseResponse; ) : BaseResponse();
public record PendingNodeCommand(
NodeCommandEnvelope? Envelope
) : BaseResponse();
public record BoolResult(
bool Result
) : BaseResponse();
public class BaseResponseConverter : JsonConverter<BaseResponse> {
public override BaseResponse? Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) {
return null;
}
public override void Write(Utf8JsonWriter writer, BaseResponse value, JsonSerializerOptions options) {
var eventType = value.GetType();
JsonSerializer.Serialize(writer, value, eventType, options);
}
}

View File

@ -103,6 +103,7 @@ public class Program {
.AddScoped<IVmssOperations, VmssOperations>() .AddScoped<IVmssOperations, VmssOperations>()
.AddScoped<INodeTasksOperations, NodeTasksOperations>() .AddScoped<INodeTasksOperations, NodeTasksOperations>()
.AddScoped<INodeMessageOperations, NodeMessageOperations>() .AddScoped<INodeMessageOperations, NodeMessageOperations>()
.AddScoped<IRequestHandling, RequestHandling>()
.AddScoped<IOnefuzzContext, OnefuzzContext>() .AddScoped<IOnefuzzContext, OnefuzzContext>()
.AddSingleton<ICreds, Creds>() .AddSingleton<ICreds, Creds>()

View File

@ -9,7 +9,7 @@ namespace Microsoft.OneFuzz.Service;
public interface IUserCredentials { public interface IUserCredentials {
public string? GetBearerToken(HttpRequestData req); public string? GetBearerToken(HttpRequestData req);
public string? GetAuthToken(HttpRequestData req); public string? GetAuthToken(HttpRequestData req);
public Task<OneFuzzResult<UserInfo>> ParseJwtToken(LogTracer log, HttpRequestData req); public Task<OneFuzzResult<UserInfo>> ParseJwtToken(HttpRequestData req);
} }
public class UserCredentials : IUserCredentials { public class UserCredentials : IUserCredentials {
@ -58,7 +58,7 @@ public class UserCredentials : IUserCredentials {
return OneFuzzResult<string[]>.Ok(allowedAddTenantsQuery.ToArray()); return OneFuzzResult<string[]>.Ok(allowedAddTenantsQuery.ToArray());
} }
public async Task<OneFuzzResult<UserInfo>> ParseJwtToken(LogTracer log, HttpRequestData req) { public async Task<OneFuzzResult<UserInfo>> ParseJwtToken(HttpRequestData req) {
var authToken = GetAuthToken(req); var authToken = GetAuthToken(req);
if (authToken is null) { if (authToken is null) {
return OneFuzzResult<UserInfo>.Error(ErrorCode.INVALID_REQUEST, new[] { "unable to find authorization token" }); return OneFuzzResult<UserInfo>.Error(ErrorCode.INVALID_REQUEST, new[] { "unable to find authorization token" });
@ -84,11 +84,11 @@ public class UserCredentials : IUserCredentials {
return OneFuzzResult<UserInfo>.Ok(new(applicationId, objectId, upn)); return OneFuzzResult<UserInfo>.Ok(new(applicationId, objectId, upn));
} else { } else {
log.Error($"issuer not from allowed tenant: {token.Issuer} - {allowedTenants}"); _log.Error($"issuer not from allowed tenant: {token.Issuer} - {allowedTenants}");
return OneFuzzResult<UserInfo>.Error(ErrorCode.INVALID_REQUEST, new[] { "unauthorized AAD issuer" }); return OneFuzzResult<UserInfo>.Error(ErrorCode.INVALID_REQUEST, new[] { "unauthorized AAD issuer" });
} }
} else { } else {
log.Error("Failed to get allowed tenants"); _log.Error("Failed to get allowed tenants");
return OneFuzzResult<UserInfo>.Error(allowedTenants.ErrorV); return OneFuzzResult<UserInfo>.Error(allowedTenants.ErrorV);
} }
} }

View File

@ -1,4 +1,5 @@
using Azure.Core; using System.Text.Json;
using Azure.Core;
using Azure.Identity; using Azure.Identity;
using Azure.ResourceManager; using Azure.ResourceManager;
using Azure.ResourceManager.Resources; using Azure.ResourceManager.Resources;
@ -23,6 +24,7 @@ public interface ICreds {
public Async.Task<string> GetBaseRegion(); public Async.Task<string> GetBaseRegion();
public Uri GetInstanceUrl(); public Uri GetInstanceUrl();
Guid GetScalesetPrincipalId();
} }
public class Creds : ICreds { public class Creds : ICreds {
@ -85,4 +87,19 @@ public class Creds : ICreds {
public Uri GetInstanceUrl() { public Uri GetInstanceUrl() {
return new Uri($"https://{GetInstanceName()}.azurewebsites.net"); return new Uri($"https://{GetInstanceName()}.azurewebsites.net");
} }
public Guid GetScalesetPrincipalId() {
var uid = ArmClient.GetGenericResource(
new ResourceIdentifier(GetScalesetIdentityResourcePath())
);
var principalId = JsonSerializer.Deserialize<JsonDocument>(uid.Data.Properties.ToString())?.RootElement.GetProperty("principalId").GetString()!;
return new Guid(principalId);
}
public string GetScalesetIdentityResourcePath() {
var scalesetIdName = $"{GetInstanceName()}-scalesetid";
var resourceGroupPath = $"/subscriptions/{GetSubscription()}/resourceGroups/{GetBaseResourceGroup()}/providers";
return $"{resourceGroupPath}/Microsoft.ManagedIdentity/userAssignedIdentities/{scalesetIdName}";
}
} }

View File

@ -0,0 +1,91 @@
using System.Net;
using Microsoft.Azure.Functions.Worker.Http;
namespace Microsoft.OneFuzz.Service;
public class EndpointAuthorization {
private readonly IOnefuzzContext _context;
private readonly ILogTracer _log;
public EndpointAuthorization(IOnefuzzContext context, ILogTracer log) {
_context = context;
_log = log;
}
public async Async.Task<HttpResponseData> CallIfAgent(HttpRequestData req, Func<HttpRequestData, Async.Task<HttpResponseData>> method) {
return await CallIf(req, method, allowAgent: true);
}
public async Async.Task<HttpResponseData> CallIf(HttpRequestData req, Func<HttpRequestData, Async.Task<HttpResponseData>> method, bool allowUser = false, bool allowAgent = false) {
var tokenResult = await _context.UserCredentials.ParseJwtToken(req);
if (!tokenResult.IsOk) {
return await _context.RequestHandling.NotOk(req, tokenResult.ErrorV, "token verification", HttpStatusCode.Unauthorized);
}
var token = tokenResult.OkV!;
if (await IsUser(token)) {
if (!allowUser) {
return await Reject(req, token);
}
var access = CheckAccess(req);
if (!access.IsOk) {
return await _context.RequestHandling.NotOk(req, access.ErrorV, "access control", HttpStatusCode.Unauthorized);
}
}
if (await IsAgent(token) && !allowAgent) {
return await Reject(req, token);
}
return await method(req);
}
public async Async.Task<bool> IsUser(UserInfo tokenData) {
return !await IsAgent(tokenData);
}
public async Async.Task<HttpResponseData> Reject(HttpRequestData req, UserInfo token) {
_log.Error(
$"reject token. url:{req.Url} token:{token} body:{await req.ReadAsStringAsync()}"
);
return await _context.RequestHandling.NotOk(
req,
new Error(
ErrorCode.UNAUTHORIZED,
new string[] { "Unrecognized agent" }
),
"token verification",
HttpStatusCode.Unauthorized
);
}
public OneFuzzResultVoid CheckAccess(HttpRequestData req) {
throw new NotImplementedException();
}
public async Async.Task<bool> IsAgent(UserInfo tokenData) {
if (tokenData.ObjectId != null) {
var scalesets = _context.ScalesetOperations.GetByObjectId(tokenData.ObjectId.Value);
if (await scalesets.AnyAsync()) {
return true;
}
var principalId = _context.Creds.GetScalesetPrincipalId();
return principalId == tokenData.ObjectId;
}
if (!tokenData.ApplicationId.HasValue) {
return false;
}
var pools = _context.PoolOperations.GetByClientId(tokenData.ApplicationId.Value);
if (await pools.AnyAsync()) {
return true;
}
return false;
}
}

View File

@ -468,7 +468,7 @@ public class NodeMessageOperations : Orm<NodeMessage>, INodeMessageOperations {
} }
public IAsyncEnumerable<NodeMessage> GetMessage(Guid machineId) { public IAsyncEnumerable<NodeMessage> GetMessage(Guid machineId) {
return QueryAsync($"machine_id eq '{machineId}'"); return QueryAsync($"PartitionKey eq '{machineId}'");
} }
public async Async.Task ClearMessages(Guid machineId) { public async Async.Task ClearMessages(Guid machineId) {

View File

@ -34,6 +34,8 @@ public interface IOnefuzzContext {
IVmssOperations VmssOperations { get; } IVmssOperations VmssOperations { get; }
IWebhookMessageLogOperations WebhookMessageLogOperations { get; } IWebhookMessageLogOperations WebhookMessageLogOperations { get; }
IWebhookOperations WebhookOperations { get; } IWebhookOperations WebhookOperations { get; }
IRequestHandling RequestHandling { get; }
} }
public class OnefuzzContext : IOnefuzzContext { public class OnefuzzContext : IOnefuzzContext {
@ -71,6 +73,8 @@ public class OnefuzzContext : IOnefuzzContext {
public ICreds Creds { get => _serviceProvider.GetService<ICreds>() ?? throw new Exception("No ICreds service"); } public ICreds Creds { get => _serviceProvider.GetService<ICreds>() ?? throw new Exception("No ICreds service"); }
public IServiceConfig ServiceConfiguration { get => _serviceProvider.GetService<IServiceConfig>() ?? throw new Exception("No IServiceConfiguration service"); } public IServiceConfig ServiceConfiguration { get => _serviceProvider.GetService<IServiceConfig>() ?? throw new Exception("No IServiceConfiguration service"); }
public IRequestHandling RequestHandling { get => _serviceProvider.GetService<IRequestHandling>() ?? throw new Exception("No IRequestHandling service"); }
public OnefuzzContext(IServiceProvider serviceProvider) { public OnefuzzContext(IServiceProvider serviceProvider) {
_serviceProvider = serviceProvider; _serviceProvider = serviceProvider;
} }

View File

@ -6,6 +6,7 @@ namespace Microsoft.OneFuzz.Service;
public interface IPoolOperations { public interface IPoolOperations {
public Async.Task<OneFuzzResult<Pool>> GetByName(string poolName); public Async.Task<OneFuzzResult<Pool>> GetByName(string poolName);
Task<bool> ScheduleWorkset(Pool pool, WorkSet workSet); Task<bool> ScheduleWorkset(Pool pool, WorkSet workSet);
IAsyncEnumerable<Pool> GetByClientId(Guid clientId);
} }
public class PoolOperations : StatefulOrm<Pool, PoolState>, IPoolOperations { public class PoolOperations : StatefulOrm<Pool, PoolState>, IPoolOperations {
@ -37,6 +38,10 @@ public class PoolOperations : StatefulOrm<Pool, PoolState>, IPoolOperations {
return await _context.Queue.QueueObject(GetPoolQueue(pool), workSet, StorageType.Corpus); return await _context.Queue.QueueObject(GetPoolQueue(pool), workSet, StorageType.Corpus);
} }
public IAsyncEnumerable<Pool> GetByClientId(Guid clientId) {
return QueryAsync(filter: $"client_id eq '{clientId.ToString()}'");
}
private string GetPoolQueue(Pool pool) { private string GetPoolQueue(Pool pool) {
return $"pool-{pool.PoolId.ToString("N")}"; return $"pool-{pool.PoolId.ToString("N")}";
} }

View File

@ -1,14 +1,21 @@
using System.Net; using System.Net;
using Microsoft.Azure.Functions.Worker.Http; using Microsoft.Azure.Functions.Worker.Http;
namespace Microsoft.OneFuzz.Service; namespace Microsoft.OneFuzz.Service;
public class RequestHandling { public interface IRequestHandling {
public static async Async.Task<HttpResponseData> NotOk(HttpRequestData request, Error error, string context, ILogTracer log, HttpStatusCode statusCode = HttpStatusCode.BadRequest) { Async.Task<HttpResponseData> NotOk(HttpRequestData request, Error error, string context, HttpStatusCode statusCode = HttpStatusCode.BadRequest);
}
public class RequestHandling : IRequestHandling {
private readonly ILogTracer _log;
public RequestHandling(ILogTracer log) {
_log = log;
}
public async Async.Task<HttpResponseData> NotOk(HttpRequestData request, Error error, string context, HttpStatusCode statusCode = HttpStatusCode.BadRequest) {
var statusNum = (int)statusCode; var statusNum = (int)statusCode;
if (statusNum >= 400 && statusNum <= 599) { if (statusNum >= 400 && statusNum <= 599) {
log.Error($"request error - {context}: {error}"); _log.Error($"request error - {context}: {error}");
var response = HttpResponseData.CreateResponse(request); var response = HttpResponseData.CreateResponse(request);
await response.WriteAsJsonAsync(error); await response.WriteAsJsonAsync(error);
@ -59,11 +66,13 @@ public class RequestHandling {
} else if (response.Any()) { } else if (response.Any()) {
await resp.WriteAsJsonAsync(response.Single()); await resp.WriteAsJsonAsync(response.Single());
} }
// TODO: ModelMixin stuff // TODO: ModelMixin stuff
return resp; return resp;
} }
public async static Async.Task<HttpResponseData> Ok(HttpRequestData req, BaseResponse response) {
return await Ok(req, new BaseResponse[] { response });
}
} }

View File

@ -11,6 +11,7 @@ public interface IScalesetOperations : IOrm<Scaleset> {
public Async.Task UpdateConfigs(Scaleset scaleSet); public Async.Task UpdateConfigs(Scaleset scaleSet);
public Async.Task<OneFuzzResult<Scaleset>> GetById(Guid scalesetId); public Async.Task<OneFuzzResult<Scaleset>> GetById(Guid scalesetId);
IAsyncEnumerable<Scaleset> GetByObjectId(Guid objectId);
} }
public class ScalesetOperations : StatefulOrm<Scaleset, ScalesetState>, IScalesetOperations { public class ScalesetOperations : StatefulOrm<Scaleset, ScalesetState>, IScalesetOperations {
@ -327,4 +328,8 @@ public class ScalesetOperations : StatefulOrm<Scaleset, ScalesetState>, IScalese
return OneFuzzResult<Scaleset>.Ok(await data.SingleAsync()); return OneFuzzResult<Scaleset>.Ok(await data.SingleAsync());
} }
public IAsyncEnumerable<Scaleset> GetByObjectId(Guid objectId) {
return QueryAsync(filter: $"client_object_id eq '{objectId}'");
}
} }