diff --git a/src/ApiService/ApiService/Functions/Tasks.cs b/src/ApiService/ApiService/Functions/Tasks.cs new file mode 100644 index 000000000..2b5f89de8 --- /dev/null +++ b/src/ApiService/ApiService/Functions/Tasks.cs @@ -0,0 +1,150 @@ +using System.Net; +using Microsoft.Azure.Functions.Worker; +using Microsoft.Azure.Functions.Worker.Http; + +namespace Microsoft.OneFuzz.Service.Functions; + +public class Tasks { + private readonly ILogTracer _log; + private readonly IEndpointAuthorization _auth; + private readonly IOnefuzzContext _context; + + public Tasks(ILogTracer log, IEndpointAuthorization auth, IOnefuzzContext context) { + _log = log; + _auth = auth; + _context = context; + } + + [Function("Tasks")] + public Async.Task Run([HttpTrigger(AuthorizationLevel.Anonymous, "GET", "POST", "DELETE")] HttpRequestData req) { + return _auth.CallIfUser(req, r => r.Method switch { + "GET" => Get(r), + "POST" => Post(r), + "DELETE" => Delete(r), + _ => throw new InvalidOperationException("Unsupported HTTP method"), + }); + } + + private async Async.Task Get(HttpRequestData req) { + var request = await RequestHandling.ParseRequest(req); + if (!request.IsOk) { + return await _context.RequestHandling.NotOk(req, request.ErrorV, "task get"); + } + + if (request.OkV.TaskId != null) { + var task = await _context.TaskOperations.GetByTaskId(request.OkV.TaskId.Value); + if (task == null) { + return await _context.RequestHandling.NotOk(req, new Error(ErrorCode.INVALID_REQUEST, new[] { "unable to find task" + }), "task get"); + + } + task.Nodes = await _context.NodeTasksOperations.GetNodeAssignments(request.OkV.TaskId.Value).ToListAsync(); + task.Events = await _context.TaskEventOperations.GetSummary(request.OkV.TaskId.Value).ToListAsync(); + + var response = req.CreateResponse(HttpStatusCode.OK); + await response.WriteAsJsonAsync(task); + return response; + + } + + var tasks = await _context.TaskOperations.SearchAll().ToListAsync(); + var response2 = req.CreateResponse(HttpStatusCode.OK); + await response2.WriteAsJsonAsync(tasks); + return response2; + } + + + private async Async.Task Post(HttpRequestData req) { + var request = await RequestHandling.ParseRequest(req); + if (!request.IsOk) { + return await _context.RequestHandling.NotOk( + req, + request.ErrorV, + "task create"); + } + + var userInfo = await _context.UserCredentials.ParseJwtToken(req); + if (!userInfo.IsOk) { + return await _context.RequestHandling.NotOk(req, userInfo.ErrorV, "task create"); + } + + var checkConfig = await _context.Config.CheckConfig(request.OkV); + if (!checkConfig.IsOk) { + return await _context.RequestHandling.NotOk( + req, + new Error(ErrorCode.INVALID_REQUEST, new[] { checkConfig.ErrorV.Error }), + "task create"); + } + + if (System.Web.HttpUtility.ParseQueryString(req.Url.Query)["dryrun"] != null) { + var response = req.CreateResponse(HttpStatusCode.OK); + await response.WriteAsJsonAsync(new BoolResult(true)); + return response; + } + + var job = await _context.JobOperations.Get(request.OkV.JobId); + if (job == null) { + return await _context.RequestHandling.NotOk( + req, + new Error(ErrorCode.INVALID_REQUEST, new[] { "unable to find job" }), + request.OkV.JobId.ToString()); + } + + if (job.State != JobState.Enabled && job.State != JobState.Init) { + return await _context.RequestHandling.NotOk( + req, + new Error(ErrorCode.UNABLE_TO_ADD_TASK_TO_JOB, new[] { $"unable to add a job in state {job.State}" }), + request.OkV.JobId.ToString()); + } + + if (request.OkV.PrereqTasks != null) { + foreach (var taskId in request.OkV.PrereqTasks) { + var prereq = await _context.TaskOperations.GetByTaskId(taskId); + + if (prereq == null) { + return await _context.RequestHandling.NotOk( + req, + new Error(ErrorCode.INVALID_REQUEST, new[] { "unable to find task " }), + "task create prerequisite"); + } + } + } + + var task = await _context.TaskOperations.Create(request.OkV, request.OkV.JobId, userInfo.OkV); + + if (!task.IsOk) { + return await _context.RequestHandling.NotOk( + req, + task.ErrorV, + "task create invalid pool"); + } + + var taskResponse = req.CreateResponse(HttpStatusCode.OK); + await taskResponse.WriteAsJsonAsync(task.OkV); + return taskResponse; + } + + private async Async.Task Delete(HttpRequestData req) { + var request = await RequestHandling.ParseRequest(req); + if (!request.IsOk) { + return await _context.RequestHandling.NotOk( + req, + request.ErrorV, + context: "task delete"); + } + + + var task = await _context.TaskOperations.GetByTaskId(request.OkV.TaskId); + if (task == null) { + return await _context.RequestHandling.NotOk(req, new Error(ErrorCode.INVALID_REQUEST, new[] { "unable to find task" + }), "task delete"); + + } + + await _context.TaskOperations.MarkStopping(task); + + var response = req.CreateResponse(HttpStatusCode.OK); + await response.WriteAsJsonAsync(task); + return response; + } +} diff --git a/src/ApiService/ApiService/OneFuzzTypes/Events.cs b/src/ApiService/ApiService/OneFuzzTypes/Events.cs index 3f67591f7..184431c03 100644 --- a/src/ApiService/ApiService/OneFuzzTypes/Events.cs +++ b/src/ApiService/ApiService/OneFuzzTypes/Events.cs @@ -7,7 +7,6 @@ namespace Microsoft.OneFuzz.Service; - public enum EventType { JobCreated, JobStopped, @@ -63,6 +62,7 @@ public abstract record BaseEvent() { EventNodeDeleted _ => EventType.NodeDeleted, EventNodeCreated _ => EventType.NodeCreated, EventJobStopped _ => EventType.JobStopped, + EventTaskCreated _ => EventType.TaskCreated, var x => throw new NotSupportedException($"Unknown event type: {x.GetType()}"), }; @@ -91,6 +91,7 @@ public abstract record BaseEvent() { EventType.NodeDeleted => typeof(EventNodeDeleted), EventType.NodeCreated => typeof(EventNodeCreated), EventType.JobStopped => typeof(EventJobStopped), + EventType.TaskCreated => typeof(EventTaskCreated), _ => throw new ArgumentException($"Unknown event type: {eventType}"), }; } @@ -141,12 +142,12 @@ public record EventJobStopped( ) : BaseEvent(); -//record EventTaskCreated( -// Guid JobId, -// Guid TaskId, -// TaskConfig Config, -// UserInfo? UserInfo -// ) : BaseEvent(); +record EventTaskCreated( + Guid JobId, + Guid TaskId, + TaskConfig Config, + UserInfo? UserInfo + ) : BaseEvent(); public record EventTaskStateUpdated( diff --git a/src/ApiService/ApiService/OneFuzzTypes/Model.cs b/src/ApiService/ApiService/OneFuzzTypes/Model.cs index 02dfcc586..c53b1a60a 100644 --- a/src/ApiService/ApiService/OneFuzzTypes/Model.cs +++ b/src/ApiService/ApiService/OneFuzzTypes/Model.cs @@ -256,8 +256,8 @@ public record Task( DateTimeOffset? Heartbeat = null, DateTimeOffset? EndTime = null, UserInfo? UserInfo = null) : StatefulEntityBase(State) { - List Events { get; set; } = new List(); - List Nodes { get; set; } = new List(); + public List Events { get; set; } = new List(); + public List Nodes { get; set; } = new List(); } public record TaskEvent( diff --git a/src/ApiService/ApiService/OneFuzzTypes/Requests.cs b/src/ApiService/ApiService/OneFuzzTypes/Requests.cs index baa2ed180..824c5abe6 100644 --- a/src/ApiService/ApiService/OneFuzzTypes/Requests.cs +++ b/src/ApiService/ApiService/OneFuzzTypes/Requests.cs @@ -164,6 +164,13 @@ public record ProxyReset( string Region ); +public record TaskGet(Guid TaskId); + +public record TaskSearch( + Guid? JobId, + Guid? TaskId, + List State); + public record PoolSearch( Guid? PoolId = null, PoolName? Name = null, diff --git a/src/ApiService/ApiService/onefuzzlib/Config.cs b/src/ApiService/ApiService/onefuzzlib/Config.cs index 551844b1d..bf89bd693 100644 --- a/src/ApiService/ApiService/onefuzzlib/Config.cs +++ b/src/ApiService/ApiService/onefuzzlib/Config.cs @@ -1,24 +1,31 @@ -using Azure.Storage.Sas; +using System.IO; +using System.Threading.Tasks; +using Azure.Storage.Sas; namespace Microsoft.OneFuzz.Service; - public interface IConfig { Async.Task BuildTaskConfig(Job job, Task task); + Task> CheckConfig(TaskConfig config); } +public record TaskConfigError(string Error); + public class Config : IConfig { + private readonly IOnefuzzContext _context; private readonly IContainers _containers; private readonly IServiceConfig _serviceConfig; - + private readonly ILogTracer _logTracer; private readonly IQueue _queue; - public Config(IContainers containers, IServiceConfig serviceConfig, IQueue queue) { - _containers = containers; - _serviceConfig = serviceConfig; - _queue = queue; + public Config(ILogTracer logTracer, IOnefuzzContext context) { + _context = context; + _logTracer = logTracer; + _containers = _context.Containers; + _serviceConfig = _context.ServiceConfiguration; + _queue = _context.Queue; } private static BlobContainerSasPermissions ConvertPermissions(ContainerPermission permission) { @@ -257,4 +264,257 @@ public class Config : IConfig { return config; } + + public async Async.Task> CheckConfig(TaskConfig config) { + if (!Defs.TASK_DEFINITIONS.ContainsKey(config.Task.Type)) { + return ResultVoid.Error(new TaskConfigError($"unsupported task type: {config.Task.Type}")); + } + + if (config.Vm != null && config.Pool != null) { + return ResultVoid.Error(new TaskConfigError($"either the vm or pool must be specified, but not both")); + } + + var definition = Defs.TASK_DEFINITIONS[config.Task.Type]; + var r = await CheckContainers(definition, config); + if (!r.IsOk) { + return r; + } + + if (definition.Features.Contains(TaskFeature.SupervisorExe) && config.Task.SupervisorExe == null) { + var err = "missing supervisor_exe"; + _logTracer.Error(err); + return ResultVoid.Error(new TaskConfigError(err)); + } + + if (definition.Features.Contains(TaskFeature.TargetMustUseInput) && !TargetUsesInput(config)) { + return ResultVoid.Error(new TaskConfigError("{input} must be used in target_env or target_options")); + } + + if (config.Vm != null) { + return ResultVoid.Error(new TaskConfigError("specifying task config vm is no longer supported")); + } + + if (config.Pool == null) { + return ResultVoid.Error(new TaskConfigError("pool must be specified")); + } + + if (!CheckVal(definition.Vm.Compare, definition.Vm.Value, config.Pool!.Count)) { + var err = + $"invalid vm count: expected {definition.Vm.Compare} {definition.Vm.Value}, got {config.Pool.Count}"; + _logTracer.Error(err); + return ResultVoid.Error(new TaskConfigError(err)); + } + + var pool = await _context.PoolOperations.GetByName(config.Pool.PoolName); + if (!pool.IsOk) { + return ResultVoid.Error(new TaskConfigError($"invalid pool: {config.Pool.PoolName}")); + } + + var checkTarget = await CheckTargetExe(config, definition); + if (!checkTarget.IsOk) { + return checkTarget; + } + + if (definition.Features.Contains(TaskFeature.GeneratorExe)) { + var container = config.Containers!.First(x => x.Type == ContainerType.Tools); + + if (config.Task.GeneratorExe == null) { + return ResultVoid.Error(new TaskConfigError($"generator_exe is not defined")); + } + + var tool_paths = new[] { "{tools_dir}/", "{tools_dir}\\" }; + + foreach (var toolPath in tool_paths) { + if (config.Task.GeneratorExe.StartsWith(toolPath)) { + var generator = config.Task.GeneratorExe.Replace(toolPath, ""); + if (!await _containers.BlobExists(container.Name, generator, StorageType.Corpus)) { + var err = + $"generator_exe `{config.Task.GeneratorExe}` does not exist in the tools container `{container.Name}`"; + _logTracer.Error(err); + return ResultVoid.Error(new TaskConfigError(err)); + } + } + } + } + + if (definition.Features.Contains(TaskFeature.StatsFile)) { + if (config.Task.StatsFile != null && config.Task.StatsFormat == null) { + var err2 = "using a stats_file requires a stats_format"; + _logTracer.Error(err2); + return ResultVoid.Error(new TaskConfigError(err2)); + } + } + + return ResultVoid.Ok(); + + } + + private async Task> CheckTargetExe(TaskConfig config, TaskDefinition definition) { + if (config.Task.TargetExe == null) { + if (definition.Features.Contains(TaskFeature.TargetExe)) { + return ResultVoid.Error(new TaskConfigError("missing target_exe")); + } + + if (definition.Features.Contains(TaskFeature.TargetExeOptional)) { + return ResultVoid.Ok(); + } + return ResultVoid.Ok(); + } + + // User-submitted paths must be relative to the setup directory that contains them. + // They also must be normalized, and exclude special filesystem path elements. + // + // For example, accessing the blob store path "./foo" generates an exception, but + // "foo" and "foo/bar" do not. + + if (!IsValidBlobName(config.Task.TargetExe)) { + return ResultVoid.Error(new TaskConfigError("target_exe must be a canonicalized relative path")); + } + + + var container = config.Containers!.FirstOrDefault(x => x.Type == ContainerType.Setup); + if (container != null) { + if (!await _containers.BlobExists(container.Name, config.Task.TargetExe, StorageType.Corpus)) { + var err = + $"target_exe `{config.Task.TargetExe}` does not exist in the setup container `{container.Name}`"; + + _logTracer.Warning(err); + } + } + + return ResultVoid.Ok(); + } + + + + // Azure Blob Storage uses a flat scheme, and has no true directory hierarchy. Forward + // slashes are used to delimit a _virtual_ directory structure. + private static bool IsValidBlobName(string blobName) { + // https://docs.microsoft.com/en-us/rest/api/storageservices/naming-and-referencing-containers--blobs--and-metadata#blob-names + const int MIN_LENGTH = 1; + const int MAX_LENGTH = 1024; // inclusive + const int MAX_PATH_SEGMENTS = 254; + + var length = blobName.Length; + + // No leading/trailing whitespace. + if (blobName != blobName.Trim()) { + return false; + } + + if (length < MIN_LENGTH) { + return false; + } + + if (length > MAX_LENGTH) { + return false; + } + + var segments = blobName.Split(new[] { Path.AltDirectorySeparatorChar, Path.DirectorySeparatorChar }); + + if (segments.Length > MAX_PATH_SEGMENTS) { + return false; + } + + // No path segment should end with a dot (`.`). + if (segments.Any(s => s.EndsWith('.'))) { + return false; + } + + // Reject absolute paths to avoid confusion. + if (Path.IsPathRooted(blobName)) { + return false; + } + + // Reject paths with special relative filesystem entries. + if (segments.Contains(".")) { + return false; + } + + if (segments.Contains("..")) { + return false; + } + + return true; + } + + private static bool TargetUsesInput(TaskConfig config) { + if (config.Task.TargetOptions != null) { + if (config.Task.TargetOptions.Any(x => x.Contains("{input}"))) + return true; + } + + if (config.Task.TargetEnv != null) { + if (config.Task.TargetEnv.Values.Any(x => x.Contains("{input}"))) + return true; + } + return false; + } + + private async Task> CheckContainers(TaskDefinition definition, TaskConfig config) { + + if (config.Containers == null) { + return ResultVoid.Ok(); + } + + var exist = new HashSet(); + var containers = new Dictionary>(); + + foreach (var container in config.Containers) { + if (exist.Contains(container.Name.ContainerName)) { + continue; + } + if (await _containers.FindContainer(container.Name, StorageType.Corpus) == null) { + return ResultVoid.Error(new TaskConfigError($"missing container: {container.Name}")); + } + exist.Add(container.Name.ContainerName); + + if (!containers.ContainsKey(container.Type)) { + containers.Add(container.Type, new List()); + } + containers[container.Type].Add(container.Name); + } + + foreach (var containerDef in definition.Containers) { + var r = CheckContainer(containerDef.Compare, containerDef.Value, containerDef.Type, containers); + if (!r.IsOk) { + return r; + } + } + + var containerTypes = definition.Containers.Select(x => x.Type).ToHashSet(); + var missing = containers.Keys.Where(x => !containerTypes.Contains(x)).ToList(); + if (missing.Any()) { + var types = string.Join(", ", missing); + return ResultVoid.Error(new TaskConfigError($"unsupported container types for this task: {types}")); + } + + if (definition.MonitorQueue != null) { + if (!containerTypes.Contains(definition.MonitorQueue.Value)) { + return ResultVoid.Error(new TaskConfigError($"unable to monitor container type as it is not used by this task: {definition.MonitorQueue}")); + } + } + + return ResultVoid.Ok(); + } + + private static ResultVoid CheckContainer(Compare compare, long expected, ContainerType containerType, Dictionary> containers) { + var actual = containers.ContainsKey(containerType) ? containers[containerType].Count : 0; + + if (!CheckVal(compare, expected, actual)) { + return ResultVoid.Error( + new TaskConfigError($"container type {containerType}: expected {compare} {expected}, got {actual}")); + } + + return ResultVoid.Ok(); + } + + private static bool CheckVal(Compare compare, long expected, long actual) { + return compare switch { + Compare.Equal => expected == actual, + Compare.AtLeast => expected <= actual, + Compare.AtMost => expected >= actual, + _ => throw new NotSupportedException() + }; + } } diff --git a/src/ApiService/ApiService/onefuzzlib/Creds.cs b/src/ApiService/ApiService/onefuzzlib/Creds.cs index 0fbd7977f..550f2fc55 100644 --- a/src/ApiService/ApiService/onefuzzlib/Creds.cs +++ b/src/ApiService/ApiService/onefuzzlib/Creds.cs @@ -36,7 +36,7 @@ public interface ICreds { Async.Task> GetRegions(); } -public class Creds : ICreds { +public sealed class Creds : ICreds, IDisposable { private readonly ArmClient _armClient; private readonly DefaultAzureCredential _azureCredential; private readonly IServiceConfig _config; @@ -51,6 +51,7 @@ public class Creds : ICreds { _cache = cache; _azureCredential = new DefaultAzureCredential(); _armClient = new ArmClient(this.GetIdentity(), this.GetSubscription()); + } public DefaultAzureCredential GetIdentity() { @@ -89,12 +90,14 @@ public class Creds : ICreds { return ArmClient.GetResourceGroupResource(resourceId); } - public async Async.Task GetBaseRegion() { - var rg = await ArmClient.GetResourceGroupResource(GetResourceGroupResourceIdentifier()).GetAsync(); - if (rg.GetRawResponse().IsError) { - throw new Exception($"Failed to get base region due to [{rg.GetRawResponse().Status}] {rg.GetRawResponse().ReasonPhrase}"); - } - return rg.Value.Data.Location.Name; + public Async.Task GetBaseRegion() { + return _cache.GetOrCreateAsync(nameof(GetBaseRegion), async _ => { + var rg = await ArmClient.GetResourceGroupResource(GetResourceGroupResourceIdentifier()).GetAsync(); + if (rg.GetRawResponse().IsError) { + throw new Exception($"Failed to get base region due to [{rg.GetRawResponse().Status}] {rg.GetRawResponse().ReasonPhrase}"); + } + return rg.Value.Data.Location.Name; + }); } public Uri GetInstanceUrl() @@ -103,14 +106,15 @@ public class Creds : ICreds { public record ScaleSetIdentity(string principalId); - public async Async.Task GetScalesetPrincipalId() { - var path = GetScalesetIdentityResourcePath(); - var uid = ArmClient.GetGenericResource(new ResourceIdentifier(path)); + public Async.Task GetScalesetPrincipalId() { + return _cache.GetOrCreateAsync(nameof(GetScalesetPrincipalId), async entry => { + var path = GetScalesetIdentityResourcePath(); + var uid = ArmClient.GetGenericResource(new ResourceIdentifier(path)); - - var resource = await uid.GetAsync(); - var principalId = resource.Value.Data.Properties.ToObjectFromJson().principalId; - return new Guid(principalId); + var resource = await uid.GetAsync(); + var principalId = resource.Value.Data.Properties.ToObjectFromJson().principalId; + return new Guid(principalId); + }); } public string GetScalesetIdentityResourcePath() { @@ -125,6 +129,7 @@ public class Creds : ICreds { private static readonly Uri _graphResource = new("https://graph.microsoft.com"); private static readonly Uri _graphResourceEndpoint = new("https://graph.microsoft.com/v1.0"); + public async Task QueryMicrosoftGraph(HttpMethod method, string resource) { var cred = GetIdentity(); @@ -166,6 +171,10 @@ public class Creds : ICreds { return resource; } + public void Dispose() { + throw new NotImplementedException(); + } + public Task> GetRegions() => _cache.GetOrCreateAsync>( nameof(Creds) + "." + nameof(GetRegions), diff --git a/src/ApiService/ApiService/onefuzzlib/NodeOperations.cs b/src/ApiService/ApiService/onefuzzlib/NodeOperations.cs index b6bdcadfc..838bd63ed 100644 --- a/src/ApiService/ApiService/onefuzzlib/NodeOperations.cs +++ b/src/ApiService/ApiService/onefuzzlib/NodeOperations.cs @@ -459,7 +459,7 @@ public class NodeOperations : StatefulOrm, INod public interface INodeTasksOperations : IStatefulOrm { IAsyncEnumerable GetNodesByTaskId(Guid taskId); - IAsyncEnumerable GetNodeAssignments(Guid taskId, INodeOperations nodeOps); + IAsyncEnumerable GetNodeAssignments(Guid taskId); IAsyncEnumerable GetByMachineId(Guid machineId); IAsyncEnumerable GetByTaskId(Guid taskId); Async.Task ClearByMachineId(Guid machineId); @@ -485,7 +485,7 @@ public class NodeTasksOperations : StatefulOrm GetNodeAssignments(Guid taskId, INodeOperations nodeOps) { + public async IAsyncEnumerable GetNodeAssignments(Guid taskId) { await foreach (var entry in QueryAsync(Query.RowKey(taskId.ToString()))) { var node = await _context.NodeOperations.GetByMachineId(entry.MachineId); diff --git a/src/ApiService/ApiService/onefuzzlib/Queue.cs b/src/ApiService/ApiService/onefuzzlib/Queue.cs index be3b922de..c21e13c47 100644 --- a/src/ApiService/ApiService/onefuzzlib/Queue.cs +++ b/src/ApiService/ApiService/onefuzzlib/Queue.cs @@ -118,7 +118,7 @@ public class Queue : IQueue { return true; } } - return false; ; + return false; } public async Task> PeekQueue(string name, StorageType storageType) { diff --git a/src/ApiService/ApiService/onefuzzlib/TaskEventOperations.cs b/src/ApiService/ApiService/onefuzzlib/TaskEventOperations.cs index 539d13d1e..b001dc242 100644 --- a/src/ApiService/ApiService/onefuzzlib/TaskEventOperations.cs +++ b/src/ApiService/ApiService/onefuzzlib/TaskEventOperations.cs @@ -3,9 +3,26 @@ namespace Microsoft.OneFuzz.Service; public interface ITaskEventOperations : IOrm { + IAsyncEnumerable GetSummary(Guid taskId); } public sealed class TaskEventOperations : Orm, ITaskEventOperations { public TaskEventOperations(ILogTracer logTracer, IOnefuzzContext context) : base(logTracer, context) { } + + public IAsyncEnumerable GetSummary(Guid taskId) { + return + SearchByPartitionKeys(new[] { $"{taskId}" }) + .OrderBy(x => x.TimeStamp ?? DateTimeOffset.MaxValue) + .Select(x => new TaskEventSummary(x.TimeStamp, GetEventData(x.EventData), GetEventType(x.EventData))); + } + + private static string GetEventData(WorkerEvent ev) { + return ev.Done != null ? $"exit status: {ev.Done.ExitStatus}" : + ev.Running != null ? string.Empty : "Unrecognized event: {ev}"; + } + + private static string GetEventType(WorkerEvent ev) { + return ev.GetType().Name; + } } diff --git a/src/ApiService/ApiService/onefuzzlib/TaskOperations.cs b/src/ApiService/ApiService/onefuzzlib/TaskOperations.cs index 00710649f..db1a4dddc 100644 --- a/src/ApiService/ApiService/onefuzzlib/TaskOperations.cs +++ b/src/ApiService/ApiService/onefuzzlib/TaskOperations.cs @@ -1,4 +1,5 @@ -using ApiService.OneFuzzLib.Orm; +using System.Threading.Tasks; +using ApiService.OneFuzzLib.Orm; namespace Microsoft.OneFuzz.Service; @@ -24,6 +25,7 @@ public interface ITaskOperations : IStatefulOrm { Async.Task CheckPrereqTasks(Task task); Async.Task GetPool(Task task); Async.Task SetState(Task task, TaskState state); + Async.Task> Create(TaskConfig config, Guid jobId, UserInfo userInfo); } public class TaskOperations : StatefulOrm, ITaskOperations { @@ -164,6 +166,35 @@ public class TaskOperations : StatefulOrm, ITas return task; } + public async Task> Create(TaskConfig config, Guid jobId, UserInfo userInfo) { + + Os os; + if (config.Vm != null) { + var osResult = await _context.ImageOperations.GetOs(config.Vm.Region, config.Vm.Image); + if (!osResult.IsOk) { + return OneFuzzResult.Error(osResult.ErrorV); + } + os = osResult.OkV; + } else if (config.Pool != null) { + var pool = await _context.PoolOperations.GetByName(config.Pool.PoolName); + + if (!pool.IsOk) { + return OneFuzzResult.Error(pool.ErrorV); + } + os = pool.OkV.Os; + } else { + return OneFuzzResult.Error(new Error(ErrorCode.INVALID_CONFIGURATION, new[] { "task must have vm or pool" })); + } + + var task = new Task(jobId, Guid.NewGuid(), TaskState.Init, os, config, UserInfo: userInfo); + + await _context.TaskOperations.Insert(task); + await _context.Events.SendEvent(new EventTaskCreated(jobId, task.TaskId, config, userInfo)); + + _logTracer.Info($"created task. job_id:{jobId} task_id:{task.TaskId} type:{task.Config.Task.Type}"); + return OneFuzzResult.Ok(task); + } + private async Async.Task OnStart(Task task) { if (task.EndTime == null) { task = task with { EndTime = DateTimeOffset.UtcNow + TimeSpan.FromHours(task.Config.Task.Duration) }; diff --git a/src/ApiService/ApiService/onefuzzlib/orm/EntityConverter.cs b/src/ApiService/ApiService/onefuzzlib/orm/EntityConverter.cs index 353f10ea4..8caa9725d 100644 --- a/src/ApiService/ApiService/onefuzzlib/orm/EntityConverter.cs +++ b/src/ApiService/ApiService/onefuzzlib/orm/EntityConverter.cs @@ -17,7 +17,7 @@ public abstract record EntityBase { public static string NewSortedKey => $"{DateTimeOffset.MaxValue.Ticks - DateTimeOffset.UtcNow.Ticks}"; } -public abstract record StatefulEntityBase([property: JsonIgnore] T State) : EntityBase() where T : Enum; +public abstract record StatefulEntityBase([property: JsonIgnore] T BaseState) : EntityBase() where T : Enum; diff --git a/src/ApiService/ApiService/onefuzzlib/orm/Orm.cs b/src/ApiService/ApiService/onefuzzlib/orm/Orm.cs index c006e9d12..28268f998 100644 --- a/src/ApiService/ApiService/onefuzzlib/orm/Orm.cs +++ b/src/ApiService/ApiService/onefuzzlib/orm/Orm.cs @@ -205,7 +205,7 @@ namespace ApiService.OneFuzzLib.Orm { /// /// public async Async.Task ProcessStateUpdate(T entity) { - TState state = entity.State; + TState state = entity.BaseState; var func = GetType().GetMethod(state.ToString()) switch { null => null, MethodInfo info => info.CreateDelegate(this) @@ -227,13 +227,13 @@ namespace ApiService.OneFuzzLib.Orm { /// public async Async.Task ProcessStateUpdates(T entity, int MaxUpdates = 5) { for (int i = 0; i < MaxUpdates; i++) { - var state = entity.State; + var state = entity.BaseState; var newEntity = await ProcessStateUpdate(entity); if (newEntity == null) return null; - if (newEntity.State.Equals(state)) { + if (newEntity.BaseState.Equals(state)) { return newEntity; } } diff --git a/src/ApiService/Tests/OrmModelsTest.cs b/src/ApiService/Tests/OrmModelsTest.cs index d713c82cb..f686b4424 100644 --- a/src/ApiService/Tests/OrmModelsTest.cs +++ b/src/ApiService/Tests/OrmModelsTest.cs @@ -268,7 +268,7 @@ namespace Tests { InstanceName: arg.Item4, WebhookId: arg.Item5 ) - ); ; + ); } public static Gen WebhookMessageEventGrid() {