mirror of
https://github.com/microsoft/onefuzz.git
synced 2025-06-14 11:08:06 +00:00
migrate tasks (#2233)
* migrate task * cleanup * NotSupported instead of NotImplemented * address pr comments fix function name * fix default value * rename base state * add caching * return the task when deleting * build fix * format
This commit is contained in:
150
src/ApiService/ApiService/Functions/Tasks.cs
Normal file
150
src/ApiService/ApiService/Functions/Tasks.cs
Normal file
@ -0,0 +1,150 @@
|
||||
using System.Net;
|
||||
using Microsoft.Azure.Functions.Worker;
|
||||
using Microsoft.Azure.Functions.Worker.Http;
|
||||
|
||||
namespace Microsoft.OneFuzz.Service.Functions;
|
||||
|
||||
public class Tasks {
|
||||
private readonly ILogTracer _log;
|
||||
private readonly IEndpointAuthorization _auth;
|
||||
private readonly IOnefuzzContext _context;
|
||||
|
||||
public Tasks(ILogTracer log, IEndpointAuthorization auth, IOnefuzzContext context) {
|
||||
_log = log;
|
||||
_auth = auth;
|
||||
_context = context;
|
||||
}
|
||||
|
||||
[Function("Tasks")]
|
||||
public Async.Task<HttpResponseData> Run([HttpTrigger(AuthorizationLevel.Anonymous, "GET", "POST", "DELETE")] HttpRequestData req) {
|
||||
return _auth.CallIfUser(req, r => r.Method switch {
|
||||
"GET" => Get(r),
|
||||
"POST" => Post(r),
|
||||
"DELETE" => Delete(r),
|
||||
_ => throw new InvalidOperationException("Unsupported HTTP method"),
|
||||
});
|
||||
}
|
||||
|
||||
private async Async.Task<HttpResponseData> Get(HttpRequestData req) {
|
||||
var request = await RequestHandling.ParseRequest<TaskSearch>(req);
|
||||
if (!request.IsOk) {
|
||||
return await _context.RequestHandling.NotOk(req, request.ErrorV, "task get");
|
||||
}
|
||||
|
||||
if (request.OkV.TaskId != null) {
|
||||
var task = await _context.TaskOperations.GetByTaskId(request.OkV.TaskId.Value);
|
||||
if (task == null) {
|
||||
return await _context.RequestHandling.NotOk(req, new Error(ErrorCode.INVALID_REQUEST, new[] { "unable to find task"
|
||||
}), "task get");
|
||||
|
||||
}
|
||||
task.Nodes = await _context.NodeTasksOperations.GetNodeAssignments(request.OkV.TaskId.Value).ToListAsync();
|
||||
task.Events = await _context.TaskEventOperations.GetSummary(request.OkV.TaskId.Value).ToListAsync();
|
||||
|
||||
var response = req.CreateResponse(HttpStatusCode.OK);
|
||||
await response.WriteAsJsonAsync(task);
|
||||
return response;
|
||||
|
||||
}
|
||||
|
||||
var tasks = await _context.TaskOperations.SearchAll().ToListAsync();
|
||||
var response2 = req.CreateResponse(HttpStatusCode.OK);
|
||||
await response2.WriteAsJsonAsync(tasks);
|
||||
return response2;
|
||||
}
|
||||
|
||||
|
||||
private async Async.Task<HttpResponseData> Post(HttpRequestData req) {
|
||||
var request = await RequestHandling.ParseRequest<TaskConfig>(req);
|
||||
if (!request.IsOk) {
|
||||
return await _context.RequestHandling.NotOk(
|
||||
req,
|
||||
request.ErrorV,
|
||||
"task create");
|
||||
}
|
||||
|
||||
var userInfo = await _context.UserCredentials.ParseJwtToken(req);
|
||||
if (!userInfo.IsOk) {
|
||||
return await _context.RequestHandling.NotOk(req, userInfo.ErrorV, "task create");
|
||||
}
|
||||
|
||||
var checkConfig = await _context.Config.CheckConfig(request.OkV);
|
||||
if (!checkConfig.IsOk) {
|
||||
return await _context.RequestHandling.NotOk(
|
||||
req,
|
||||
new Error(ErrorCode.INVALID_REQUEST, new[] { checkConfig.ErrorV.Error }),
|
||||
"task create");
|
||||
}
|
||||
|
||||
if (System.Web.HttpUtility.ParseQueryString(req.Url.Query)["dryrun"] != null) {
|
||||
var response = req.CreateResponse(HttpStatusCode.OK);
|
||||
await response.WriteAsJsonAsync(new BoolResult(true));
|
||||
return response;
|
||||
}
|
||||
|
||||
var job = await _context.JobOperations.Get(request.OkV.JobId);
|
||||
if (job == null) {
|
||||
return await _context.RequestHandling.NotOk(
|
||||
req,
|
||||
new Error(ErrorCode.INVALID_REQUEST, new[] { "unable to find job" }),
|
||||
request.OkV.JobId.ToString());
|
||||
}
|
||||
|
||||
if (job.State != JobState.Enabled && job.State != JobState.Init) {
|
||||
return await _context.RequestHandling.NotOk(
|
||||
req,
|
||||
new Error(ErrorCode.UNABLE_TO_ADD_TASK_TO_JOB, new[] { $"unable to add a job in state {job.State}" }),
|
||||
request.OkV.JobId.ToString());
|
||||
}
|
||||
|
||||
if (request.OkV.PrereqTasks != null) {
|
||||
foreach (var taskId in request.OkV.PrereqTasks) {
|
||||
var prereq = await _context.TaskOperations.GetByTaskId(taskId);
|
||||
|
||||
if (prereq == null) {
|
||||
return await _context.RequestHandling.NotOk(
|
||||
req,
|
||||
new Error(ErrorCode.INVALID_REQUEST, new[] { "unable to find task " }),
|
||||
"task create prerequisite");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var task = await _context.TaskOperations.Create(request.OkV, request.OkV.JobId, userInfo.OkV);
|
||||
|
||||
if (!task.IsOk) {
|
||||
return await _context.RequestHandling.NotOk(
|
||||
req,
|
||||
task.ErrorV,
|
||||
"task create invalid pool");
|
||||
}
|
||||
|
||||
var taskResponse = req.CreateResponse(HttpStatusCode.OK);
|
||||
await taskResponse.WriteAsJsonAsync(task.OkV);
|
||||
return taskResponse;
|
||||
}
|
||||
|
||||
private async Async.Task<HttpResponseData> Delete(HttpRequestData req) {
|
||||
var request = await RequestHandling.ParseRequest<TaskGet>(req);
|
||||
if (!request.IsOk) {
|
||||
return await _context.RequestHandling.NotOk(
|
||||
req,
|
||||
request.ErrorV,
|
||||
context: "task delete");
|
||||
}
|
||||
|
||||
|
||||
var task = await _context.TaskOperations.GetByTaskId(request.OkV.TaskId);
|
||||
if (task == null) {
|
||||
return await _context.RequestHandling.NotOk(req, new Error(ErrorCode.INVALID_REQUEST, new[] { "unable to find task"
|
||||
}), "task delete");
|
||||
|
||||
}
|
||||
|
||||
await _context.TaskOperations.MarkStopping(task);
|
||||
|
||||
var response = req.CreateResponse(HttpStatusCode.OK);
|
||||
await response.WriteAsJsonAsync(task);
|
||||
return response;
|
||||
}
|
||||
}
|
@ -7,7 +7,6 @@ namespace Microsoft.OneFuzz.Service;
|
||||
|
||||
|
||||
|
||||
|
||||
public enum EventType {
|
||||
JobCreated,
|
||||
JobStopped,
|
||||
@ -63,6 +62,7 @@ public abstract record BaseEvent() {
|
||||
EventNodeDeleted _ => EventType.NodeDeleted,
|
||||
EventNodeCreated _ => EventType.NodeCreated,
|
||||
EventJobStopped _ => EventType.JobStopped,
|
||||
EventTaskCreated _ => EventType.TaskCreated,
|
||||
var x => throw new NotSupportedException($"Unknown event type: {x.GetType()}"),
|
||||
};
|
||||
|
||||
@ -91,6 +91,7 @@ public abstract record BaseEvent() {
|
||||
EventType.NodeDeleted => typeof(EventNodeDeleted),
|
||||
EventType.NodeCreated => typeof(EventNodeCreated),
|
||||
EventType.JobStopped => typeof(EventJobStopped),
|
||||
EventType.TaskCreated => typeof(EventTaskCreated),
|
||||
_ => throw new ArgumentException($"Unknown event type: {eventType}"),
|
||||
};
|
||||
}
|
||||
@ -141,12 +142,12 @@ public record EventJobStopped(
|
||||
) : BaseEvent();
|
||||
|
||||
|
||||
//record EventTaskCreated(
|
||||
// Guid JobId,
|
||||
// Guid TaskId,
|
||||
// TaskConfig Config,
|
||||
// UserInfo? UserInfo
|
||||
// ) : BaseEvent();
|
||||
record EventTaskCreated(
|
||||
Guid JobId,
|
||||
Guid TaskId,
|
||||
TaskConfig Config,
|
||||
UserInfo? UserInfo
|
||||
) : BaseEvent();
|
||||
|
||||
|
||||
public record EventTaskStateUpdated(
|
||||
|
@ -256,8 +256,8 @@ public record Task(
|
||||
DateTimeOffset? Heartbeat = null,
|
||||
DateTimeOffset? EndTime = null,
|
||||
UserInfo? UserInfo = null) : StatefulEntityBase<TaskState>(State) {
|
||||
List<TaskEventSummary> Events { get; set; } = new List<TaskEventSummary>();
|
||||
List<NodeAssignment> Nodes { get; set; } = new List<NodeAssignment>();
|
||||
public List<TaskEventSummary> Events { get; set; } = new List<TaskEventSummary>();
|
||||
public List<NodeAssignment> Nodes { get; set; } = new List<NodeAssignment>();
|
||||
}
|
||||
|
||||
public record TaskEvent(
|
||||
|
@ -164,6 +164,13 @@ public record ProxyReset(
|
||||
string Region
|
||||
);
|
||||
|
||||
public record TaskGet(Guid TaskId);
|
||||
|
||||
public record TaskSearch(
|
||||
Guid? JobId,
|
||||
Guid? TaskId,
|
||||
List<TaskState> State);
|
||||
|
||||
public record PoolSearch(
|
||||
Guid? PoolId = null,
|
||||
PoolName? Name = null,
|
||||
|
@ -1,24 +1,31 @@
|
||||
using Azure.Storage.Sas;
|
||||
using System.IO;
|
||||
using System.Threading.Tasks;
|
||||
using Azure.Storage.Sas;
|
||||
|
||||
namespace Microsoft.OneFuzz.Service;
|
||||
|
||||
|
||||
|
||||
public interface IConfig {
|
||||
Async.Task<TaskUnitConfig> BuildTaskConfig(Job job, Task task);
|
||||
Task<ResultVoid<TaskConfigError>> CheckConfig(TaskConfig config);
|
||||
}
|
||||
|
||||
public record TaskConfigError(string Error);
|
||||
|
||||
public class Config : IConfig {
|
||||
|
||||
private readonly IOnefuzzContext _context;
|
||||
private readonly IContainers _containers;
|
||||
private readonly IServiceConfig _serviceConfig;
|
||||
|
||||
private readonly ILogTracer _logTracer;
|
||||
private readonly IQueue _queue;
|
||||
|
||||
public Config(IContainers containers, IServiceConfig serviceConfig, IQueue queue) {
|
||||
_containers = containers;
|
||||
_serviceConfig = serviceConfig;
|
||||
_queue = queue;
|
||||
public Config(ILogTracer logTracer, IOnefuzzContext context) {
|
||||
_context = context;
|
||||
_logTracer = logTracer;
|
||||
_containers = _context.Containers;
|
||||
_serviceConfig = _context.ServiceConfiguration;
|
||||
_queue = _context.Queue;
|
||||
}
|
||||
|
||||
private static BlobContainerSasPermissions ConvertPermissions(ContainerPermission permission) {
|
||||
@ -257,4 +264,257 @@ public class Config : IConfig {
|
||||
|
||||
return config;
|
||||
}
|
||||
|
||||
public async Async.Task<ResultVoid<TaskConfigError>> CheckConfig(TaskConfig config) {
|
||||
if (!Defs.TASK_DEFINITIONS.ContainsKey(config.Task.Type)) {
|
||||
return ResultVoid<TaskConfigError>.Error(new TaskConfigError($"unsupported task type: {config.Task.Type}"));
|
||||
}
|
||||
|
||||
if (config.Vm != null && config.Pool != null) {
|
||||
return ResultVoid<TaskConfigError>.Error(new TaskConfigError($"either the vm or pool must be specified, but not both"));
|
||||
}
|
||||
|
||||
var definition = Defs.TASK_DEFINITIONS[config.Task.Type];
|
||||
var r = await CheckContainers(definition, config);
|
||||
if (!r.IsOk) {
|
||||
return r;
|
||||
}
|
||||
|
||||
if (definition.Features.Contains(TaskFeature.SupervisorExe) && config.Task.SupervisorExe == null) {
|
||||
var err = "missing supervisor_exe";
|
||||
_logTracer.Error(err);
|
||||
return ResultVoid<TaskConfigError>.Error(new TaskConfigError(err));
|
||||
}
|
||||
|
||||
if (definition.Features.Contains(TaskFeature.TargetMustUseInput) && !TargetUsesInput(config)) {
|
||||
return ResultVoid<TaskConfigError>.Error(new TaskConfigError("{input} must be used in target_env or target_options"));
|
||||
}
|
||||
|
||||
if (config.Vm != null) {
|
||||
return ResultVoid<TaskConfigError>.Error(new TaskConfigError("specifying task config vm is no longer supported"));
|
||||
}
|
||||
|
||||
if (config.Pool == null) {
|
||||
return ResultVoid<TaskConfigError>.Error(new TaskConfigError("pool must be specified"));
|
||||
}
|
||||
|
||||
if (!CheckVal(definition.Vm.Compare, definition.Vm.Value, config.Pool!.Count)) {
|
||||
var err =
|
||||
$"invalid vm count: expected {definition.Vm.Compare} {definition.Vm.Value}, got {config.Pool.Count}";
|
||||
_logTracer.Error(err);
|
||||
return ResultVoid<TaskConfigError>.Error(new TaskConfigError(err));
|
||||
}
|
||||
|
||||
var pool = await _context.PoolOperations.GetByName(config.Pool.PoolName);
|
||||
if (!pool.IsOk) {
|
||||
return ResultVoid<TaskConfigError>.Error(new TaskConfigError($"invalid pool: {config.Pool.PoolName}"));
|
||||
}
|
||||
|
||||
var checkTarget = await CheckTargetExe(config, definition);
|
||||
if (!checkTarget.IsOk) {
|
||||
return checkTarget;
|
||||
}
|
||||
|
||||
if (definition.Features.Contains(TaskFeature.GeneratorExe)) {
|
||||
var container = config.Containers!.First(x => x.Type == ContainerType.Tools);
|
||||
|
||||
if (config.Task.GeneratorExe == null) {
|
||||
return ResultVoid<TaskConfigError>.Error(new TaskConfigError($"generator_exe is not defined"));
|
||||
}
|
||||
|
||||
var tool_paths = new[] { "{tools_dir}/", "{tools_dir}\\" };
|
||||
|
||||
foreach (var toolPath in tool_paths) {
|
||||
if (config.Task.GeneratorExe.StartsWith(toolPath)) {
|
||||
var generator = config.Task.GeneratorExe.Replace(toolPath, "");
|
||||
if (!await _containers.BlobExists(container.Name, generator, StorageType.Corpus)) {
|
||||
var err =
|
||||
$"generator_exe `{config.Task.GeneratorExe}` does not exist in the tools container `{container.Name}`";
|
||||
_logTracer.Error(err);
|
||||
return ResultVoid<TaskConfigError>.Error(new TaskConfigError(err));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (definition.Features.Contains(TaskFeature.StatsFile)) {
|
||||
if (config.Task.StatsFile != null && config.Task.StatsFormat == null) {
|
||||
var err2 = "using a stats_file requires a stats_format";
|
||||
_logTracer.Error(err2);
|
||||
return ResultVoid<TaskConfigError>.Error(new TaskConfigError(err2));
|
||||
}
|
||||
}
|
||||
|
||||
return ResultVoid<TaskConfigError>.Ok();
|
||||
|
||||
}
|
||||
|
||||
private async Task<ResultVoid<TaskConfigError>> CheckTargetExe(TaskConfig config, TaskDefinition definition) {
|
||||
if (config.Task.TargetExe == null) {
|
||||
if (definition.Features.Contains(TaskFeature.TargetExe)) {
|
||||
return ResultVoid<TaskConfigError>.Error(new TaskConfigError("missing target_exe"));
|
||||
}
|
||||
|
||||
if (definition.Features.Contains(TaskFeature.TargetExeOptional)) {
|
||||
return ResultVoid<TaskConfigError>.Ok();
|
||||
}
|
||||
return ResultVoid<TaskConfigError>.Ok();
|
||||
}
|
||||
|
||||
// User-submitted paths must be relative to the setup directory that contains them.
|
||||
// They also must be normalized, and exclude special filesystem path elements.
|
||||
//
|
||||
// For example, accessing the blob store path "./foo" generates an exception, but
|
||||
// "foo" and "foo/bar" do not.
|
||||
|
||||
if (!IsValidBlobName(config.Task.TargetExe)) {
|
||||
return ResultVoid<TaskConfigError>.Error(new TaskConfigError("target_exe must be a canonicalized relative path"));
|
||||
}
|
||||
|
||||
|
||||
var container = config.Containers!.FirstOrDefault(x => x.Type == ContainerType.Setup);
|
||||
if (container != null) {
|
||||
if (!await _containers.BlobExists(container.Name, config.Task.TargetExe, StorageType.Corpus)) {
|
||||
var err =
|
||||
$"target_exe `{config.Task.TargetExe}` does not exist in the setup container `{container.Name}`";
|
||||
|
||||
_logTracer.Warning(err);
|
||||
}
|
||||
}
|
||||
|
||||
return ResultVoid<TaskConfigError>.Ok();
|
||||
}
|
||||
|
||||
|
||||
|
||||
// Azure Blob Storage uses a flat scheme, and has no true directory hierarchy. Forward
|
||||
// slashes are used to delimit a _virtual_ directory structure.
|
||||
private static bool IsValidBlobName(string blobName) {
|
||||
// https://docs.microsoft.com/en-us/rest/api/storageservices/naming-and-referencing-containers--blobs--and-metadata#blob-names
|
||||
const int MIN_LENGTH = 1;
|
||||
const int MAX_LENGTH = 1024; // inclusive
|
||||
const int MAX_PATH_SEGMENTS = 254;
|
||||
|
||||
var length = blobName.Length;
|
||||
|
||||
// No leading/trailing whitespace.
|
||||
if (blobName != blobName.Trim()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (length < MIN_LENGTH) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (length > MAX_LENGTH) {
|
||||
return false;
|
||||
}
|
||||
|
||||
var segments = blobName.Split(new[] { Path.AltDirectorySeparatorChar, Path.DirectorySeparatorChar });
|
||||
|
||||
if (segments.Length > MAX_PATH_SEGMENTS) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// No path segment should end with a dot (`.`).
|
||||
if (segments.Any(s => s.EndsWith('.'))) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Reject absolute paths to avoid confusion.
|
||||
if (Path.IsPathRooted(blobName)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Reject paths with special relative filesystem entries.
|
||||
if (segments.Contains(".")) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (segments.Contains("..")) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
private static bool TargetUsesInput(TaskConfig config) {
|
||||
if (config.Task.TargetOptions != null) {
|
||||
if (config.Task.TargetOptions.Any(x => x.Contains("{input}")))
|
||||
return true;
|
||||
}
|
||||
|
||||
if (config.Task.TargetEnv != null) {
|
||||
if (config.Task.TargetEnv.Values.Any(x => x.Contains("{input}")))
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
private async Task<ResultVoid<TaskConfigError>> CheckContainers(TaskDefinition definition, TaskConfig config) {
|
||||
|
||||
if (config.Containers == null) {
|
||||
return ResultVoid<TaskConfigError>.Ok();
|
||||
}
|
||||
|
||||
var exist = new HashSet<string>();
|
||||
var containers = new Dictionary<ContainerType, List<Container>>();
|
||||
|
||||
foreach (var container in config.Containers) {
|
||||
if (exist.Contains(container.Name.ContainerName)) {
|
||||
continue;
|
||||
}
|
||||
if (await _containers.FindContainer(container.Name, StorageType.Corpus) == null) {
|
||||
return ResultVoid<TaskConfigError>.Error(new TaskConfigError($"missing container: {container.Name}"));
|
||||
}
|
||||
exist.Add(container.Name.ContainerName);
|
||||
|
||||
if (!containers.ContainsKey(container.Type)) {
|
||||
containers.Add(container.Type, new List<Container>());
|
||||
}
|
||||
containers[container.Type].Add(container.Name);
|
||||
}
|
||||
|
||||
foreach (var containerDef in definition.Containers) {
|
||||
var r = CheckContainer(containerDef.Compare, containerDef.Value, containerDef.Type, containers);
|
||||
if (!r.IsOk) {
|
||||
return r;
|
||||
}
|
||||
}
|
||||
|
||||
var containerTypes = definition.Containers.Select(x => x.Type).ToHashSet();
|
||||
var missing = containers.Keys.Where(x => !containerTypes.Contains(x)).ToList();
|
||||
if (missing.Any()) {
|
||||
var types = string.Join(", ", missing);
|
||||
return ResultVoid<TaskConfigError>.Error(new TaskConfigError($"unsupported container types for this task: {types}"));
|
||||
}
|
||||
|
||||
if (definition.MonitorQueue != null) {
|
||||
if (!containerTypes.Contains(definition.MonitorQueue.Value)) {
|
||||
return ResultVoid<TaskConfigError>.Error(new TaskConfigError($"unable to monitor container type as it is not used by this task: {definition.MonitorQueue}"));
|
||||
}
|
||||
}
|
||||
|
||||
return ResultVoid<TaskConfigError>.Ok();
|
||||
}
|
||||
|
||||
private static ResultVoid<TaskConfigError> CheckContainer(Compare compare, long expected, ContainerType containerType, Dictionary<ContainerType, List<Container>> containers) {
|
||||
var actual = containers.ContainsKey(containerType) ? containers[containerType].Count : 0;
|
||||
|
||||
if (!CheckVal(compare, expected, actual)) {
|
||||
return ResultVoid<TaskConfigError>.Error(
|
||||
new TaskConfigError($"container type {containerType}: expected {compare} {expected}, got {actual}"));
|
||||
}
|
||||
|
||||
return ResultVoid<TaskConfigError>.Ok();
|
||||
}
|
||||
|
||||
private static bool CheckVal(Compare compare, long expected, long actual) {
|
||||
return compare switch {
|
||||
Compare.Equal => expected == actual,
|
||||
Compare.AtLeast => expected <= actual,
|
||||
Compare.AtMost => expected >= actual,
|
||||
_ => throw new NotSupportedException()
|
||||
};
|
||||
}
|
||||
}
|
||||
|
@ -36,7 +36,7 @@ public interface ICreds {
|
||||
Async.Task<IReadOnlyList<string>> GetRegions();
|
||||
}
|
||||
|
||||
public class Creds : ICreds {
|
||||
public sealed class Creds : ICreds, IDisposable {
|
||||
private readonly ArmClient _armClient;
|
||||
private readonly DefaultAzureCredential _azureCredential;
|
||||
private readonly IServiceConfig _config;
|
||||
@ -51,6 +51,7 @@ public class Creds : ICreds {
|
||||
_cache = cache;
|
||||
_azureCredential = new DefaultAzureCredential();
|
||||
_armClient = new ArmClient(this.GetIdentity(), this.GetSubscription());
|
||||
|
||||
}
|
||||
|
||||
public DefaultAzureCredential GetIdentity() {
|
||||
@ -89,12 +90,14 @@ public class Creds : ICreds {
|
||||
return ArmClient.GetResourceGroupResource(resourceId);
|
||||
}
|
||||
|
||||
public async Async.Task<string> GetBaseRegion() {
|
||||
var rg = await ArmClient.GetResourceGroupResource(GetResourceGroupResourceIdentifier()).GetAsync();
|
||||
if (rg.GetRawResponse().IsError) {
|
||||
throw new Exception($"Failed to get base region due to [{rg.GetRawResponse().Status}] {rg.GetRawResponse().ReasonPhrase}");
|
||||
}
|
||||
return rg.Value.Data.Location.Name;
|
||||
public Async.Task<string> GetBaseRegion() {
|
||||
return _cache.GetOrCreateAsync(nameof(GetBaseRegion), async _ => {
|
||||
var rg = await ArmClient.GetResourceGroupResource(GetResourceGroupResourceIdentifier()).GetAsync();
|
||||
if (rg.GetRawResponse().IsError) {
|
||||
throw new Exception($"Failed to get base region due to [{rg.GetRawResponse().Status}] {rg.GetRawResponse().ReasonPhrase}");
|
||||
}
|
||||
return rg.Value.Data.Location.Name;
|
||||
});
|
||||
}
|
||||
|
||||
public Uri GetInstanceUrl()
|
||||
@ -103,14 +106,15 @@ public class Creds : ICreds {
|
||||
|
||||
public record ScaleSetIdentity(string principalId);
|
||||
|
||||
public async Async.Task<Guid> GetScalesetPrincipalId() {
|
||||
var path = GetScalesetIdentityResourcePath();
|
||||
var uid = ArmClient.GetGenericResource(new ResourceIdentifier(path));
|
||||
public Async.Task<Guid> GetScalesetPrincipalId() {
|
||||
return _cache.GetOrCreateAsync(nameof(GetScalesetPrincipalId), async entry => {
|
||||
var path = GetScalesetIdentityResourcePath();
|
||||
var uid = ArmClient.GetGenericResource(new ResourceIdentifier(path));
|
||||
|
||||
|
||||
var resource = await uid.GetAsync();
|
||||
var principalId = resource.Value.Data.Properties.ToObjectFromJson<ScaleSetIdentity>().principalId;
|
||||
return new Guid(principalId);
|
||||
var resource = await uid.GetAsync();
|
||||
var principalId = resource.Value.Data.Properties.ToObjectFromJson<ScaleSetIdentity>().principalId;
|
||||
return new Guid(principalId);
|
||||
});
|
||||
}
|
||||
|
||||
public string GetScalesetIdentityResourcePath() {
|
||||
@ -125,6 +129,7 @@ public class Creds : ICreds {
|
||||
private static readonly Uri _graphResource = new("https://graph.microsoft.com");
|
||||
private static readonly Uri _graphResourceEndpoint = new("https://graph.microsoft.com/v1.0");
|
||||
|
||||
|
||||
public async Task<T> QueryMicrosoftGraph<T>(HttpMethod method, string resource) {
|
||||
var cred = GetIdentity();
|
||||
|
||||
@ -166,6 +171,10 @@ public class Creds : ICreds {
|
||||
return resource;
|
||||
}
|
||||
|
||||
public void Dispose() {
|
||||
throw new NotImplementedException();
|
||||
}
|
||||
|
||||
public Task<IReadOnlyList<string>> GetRegions()
|
||||
=> _cache.GetOrCreateAsync<IReadOnlyList<string>>(
|
||||
nameof(Creds) + "." + nameof(GetRegions),
|
||||
|
@ -459,7 +459,7 @@ public class NodeOperations : StatefulOrm<Node, NodeState, NodeOperations>, INod
|
||||
|
||||
public interface INodeTasksOperations : IStatefulOrm<NodeTasks, NodeTaskState> {
|
||||
IAsyncEnumerable<Node> GetNodesByTaskId(Guid taskId);
|
||||
IAsyncEnumerable<NodeAssignment> GetNodeAssignments(Guid taskId, INodeOperations nodeOps);
|
||||
IAsyncEnumerable<NodeAssignment> GetNodeAssignments(Guid taskId);
|
||||
IAsyncEnumerable<NodeTasks> GetByMachineId(Guid machineId);
|
||||
IAsyncEnumerable<NodeTasks> GetByTaskId(Guid taskId);
|
||||
Async.Task ClearByMachineId(Guid machineId);
|
||||
@ -485,7 +485,7 @@ public class NodeTasksOperations : StatefulOrm<NodeTasks, NodeTaskState, NodeTas
|
||||
}
|
||||
}
|
||||
|
||||
public async IAsyncEnumerable<NodeAssignment> GetNodeAssignments(Guid taskId, INodeOperations nodeOps) {
|
||||
public async IAsyncEnumerable<NodeAssignment> GetNodeAssignments(Guid taskId) {
|
||||
|
||||
await foreach (var entry in QueryAsync(Query.RowKey(taskId.ToString()))) {
|
||||
var node = await _context.NodeOperations.GetByMachineId(entry.MachineId);
|
||||
|
@ -118,7 +118,7 @@ public class Queue : IQueue {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false; ;
|
||||
return false;
|
||||
}
|
||||
|
||||
public async Task<IList<T>> PeekQueue<T>(string name, StorageType storageType) {
|
||||
|
@ -3,9 +3,26 @@
|
||||
namespace Microsoft.OneFuzz.Service;
|
||||
|
||||
public interface ITaskEventOperations : IOrm<TaskEvent> {
|
||||
IAsyncEnumerable<TaskEventSummary> GetSummary(Guid taskId);
|
||||
}
|
||||
|
||||
public sealed class TaskEventOperations : Orm<TaskEvent>, ITaskEventOperations {
|
||||
public TaskEventOperations(ILogTracer logTracer, IOnefuzzContext context)
|
||||
: base(logTracer, context) { }
|
||||
|
||||
public IAsyncEnumerable<TaskEventSummary> GetSummary(Guid taskId) {
|
||||
return
|
||||
SearchByPartitionKeys(new[] { $"{taskId}" })
|
||||
.OrderBy(x => x.TimeStamp ?? DateTimeOffset.MaxValue)
|
||||
.Select(x => new TaskEventSummary(x.TimeStamp, GetEventData(x.EventData), GetEventType(x.EventData)));
|
||||
}
|
||||
|
||||
private static string GetEventData(WorkerEvent ev) {
|
||||
return ev.Done != null ? $"exit status: {ev.Done.ExitStatus}" :
|
||||
ev.Running != null ? string.Empty : "Unrecognized event: {ev}";
|
||||
}
|
||||
|
||||
private static string GetEventType(WorkerEvent ev) {
|
||||
return ev.GetType().Name;
|
||||
}
|
||||
}
|
||||
|
@ -1,4 +1,5 @@
|
||||
using ApiService.OneFuzzLib.Orm;
|
||||
using System.Threading.Tasks;
|
||||
using ApiService.OneFuzzLib.Orm;
|
||||
|
||||
namespace Microsoft.OneFuzz.Service;
|
||||
|
||||
@ -24,6 +25,7 @@ public interface ITaskOperations : IStatefulOrm<Task, TaskState> {
|
||||
Async.Task<bool> CheckPrereqTasks(Task task);
|
||||
Async.Task<Pool?> GetPool(Task task);
|
||||
Async.Task<Task> SetState(Task task, TaskState state);
|
||||
Async.Task<OneFuzzResult<Task>> Create(TaskConfig config, Guid jobId, UserInfo userInfo);
|
||||
}
|
||||
|
||||
public class TaskOperations : StatefulOrm<Task, TaskState, TaskOperations>, ITaskOperations {
|
||||
@ -164,6 +166,35 @@ public class TaskOperations : StatefulOrm<Task, TaskState, TaskOperations>, ITas
|
||||
return task;
|
||||
}
|
||||
|
||||
public async Task<OneFuzzResult<Task>> Create(TaskConfig config, Guid jobId, UserInfo userInfo) {
|
||||
|
||||
Os os;
|
||||
if (config.Vm != null) {
|
||||
var osResult = await _context.ImageOperations.GetOs(config.Vm.Region, config.Vm.Image);
|
||||
if (!osResult.IsOk) {
|
||||
return OneFuzzResult<Task>.Error(osResult.ErrorV);
|
||||
}
|
||||
os = osResult.OkV;
|
||||
} else if (config.Pool != null) {
|
||||
var pool = await _context.PoolOperations.GetByName(config.Pool.PoolName);
|
||||
|
||||
if (!pool.IsOk) {
|
||||
return OneFuzzResult<Task>.Error(pool.ErrorV);
|
||||
}
|
||||
os = pool.OkV.Os;
|
||||
} else {
|
||||
return OneFuzzResult<Task>.Error(new Error(ErrorCode.INVALID_CONFIGURATION, new[] { "task must have vm or pool" }));
|
||||
}
|
||||
|
||||
var task = new Task(jobId, Guid.NewGuid(), TaskState.Init, os, config, UserInfo: userInfo);
|
||||
|
||||
await _context.TaskOperations.Insert(task);
|
||||
await _context.Events.SendEvent(new EventTaskCreated(jobId, task.TaskId, config, userInfo));
|
||||
|
||||
_logTracer.Info($"created task. job_id:{jobId} task_id:{task.TaskId} type:{task.Config.Task.Type}");
|
||||
return OneFuzzResult<Task>.Ok(task);
|
||||
}
|
||||
|
||||
private async Async.Task<Task> OnStart(Task task) {
|
||||
if (task.EndTime == null) {
|
||||
task = task with { EndTime = DateTimeOffset.UtcNow + TimeSpan.FromHours(task.Config.Task.Duration) };
|
||||
|
@ -17,7 +17,7 @@ public abstract record EntityBase {
|
||||
public static string NewSortedKey => $"{DateTimeOffset.MaxValue.Ticks - DateTimeOffset.UtcNow.Ticks}";
|
||||
}
|
||||
|
||||
public abstract record StatefulEntityBase<T>([property: JsonIgnore] T State) : EntityBase() where T : Enum;
|
||||
public abstract record StatefulEntityBase<T>([property: JsonIgnore] T BaseState) : EntityBase() where T : Enum;
|
||||
|
||||
|
||||
|
||||
|
@ -205,7 +205,7 @@ namespace ApiService.OneFuzzLib.Orm {
|
||||
/// <param name="entity"></param>
|
||||
/// <returns></returns>
|
||||
public async Async.Task<T?> ProcessStateUpdate(T entity) {
|
||||
TState state = entity.State;
|
||||
TState state = entity.BaseState;
|
||||
var func = GetType().GetMethod(state.ToString()) switch {
|
||||
null => null,
|
||||
MethodInfo info => info.CreateDelegate<StateTransition>(this)
|
||||
@ -227,13 +227,13 @@ namespace ApiService.OneFuzzLib.Orm {
|
||||
/// <param name="MaxUpdates"></param>
|
||||
public async Async.Task<T?> ProcessStateUpdates(T entity, int MaxUpdates = 5) {
|
||||
for (int i = 0; i < MaxUpdates; i++) {
|
||||
var state = entity.State;
|
||||
var state = entity.BaseState;
|
||||
var newEntity = await ProcessStateUpdate(entity);
|
||||
|
||||
if (newEntity == null)
|
||||
return null;
|
||||
|
||||
if (newEntity.State.Equals(state)) {
|
||||
if (newEntity.BaseState.Equals(state)) {
|
||||
return newEntity;
|
||||
}
|
||||
}
|
||||
|
@ -268,7 +268,7 @@ namespace Tests {
|
||||
InstanceName: arg.Item4,
|
||||
WebhookId: arg.Item5
|
||||
)
|
||||
); ;
|
||||
);
|
||||
}
|
||||
|
||||
public static Gen<WebhookMessageEventGrid> WebhookMessageEventGrid() {
|
||||
|
Reference in New Issue
Block a user