diff --git a/src/ApiService/ApiService/OneFuzzTypes/Model.cs b/src/ApiService/ApiService/OneFuzzTypes/Model.cs index 38876528c..ae866b878 100644 --- a/src/ApiService/ApiService/OneFuzzTypes/Model.cs +++ b/src/ApiService/ApiService/OneFuzzTypes/Model.cs @@ -184,7 +184,8 @@ public record TaskDetails( bool? PreserveExistingOutputs = null, List? ReportList = null, int? MinimizedStackDepth = null, - string? CoverageFilter = null); + string? CoverageFilter = null +); public record TaskVm( Region Region, @@ -214,7 +215,8 @@ public record TaskConfig( List? Containers = null, Dictionary? Tags = null, List? Debug = null, - bool? Colocate = null); + bool? Colocate = null + ); public record TaskEventSummary( DateTimeOffset? Timestamp, @@ -590,9 +592,29 @@ public record WorkUnit( Guid JobId, Guid TaskId, TaskType TaskType, - TaskUnitConfig Config + + // JSON-serialized `TaskUnitConfig`. + [property: JsonConverter(typeof(TaskUnitConfigConverter))] TaskUnitConfig Config ); +public class TaskUnitConfigConverter : JsonConverter { + public override TaskUnitConfig? Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) { + var taskUnitString = reader.GetString(); + if (taskUnitString == null) { + return null; + } + return JsonSerializer.Deserialize(taskUnitString, options); + } + + public override void Write(Utf8JsonWriter writer, TaskUnitConfig value, JsonSerializerOptions options) { + var v = JsonSerializer.Serialize(value, new JsonSerializerOptions(options) { + DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull + }); + + writer.WriteStringValue(v); + } +} + public record VmDefinition( Compare Compare, int Value @@ -625,14 +647,36 @@ public record ContainerDefinition( // TODO: service shouldn't pass SyncedDir, but just the url and let the agent // come up with paths -public record SyncedDir(string Path, Uri url); +public record SyncedDir(string Path, Uri Url); +[JsonConverter(typeof(ContainerDefConverter))] public interface IContainerDef { } public record SingleContainer(SyncedDir SyncedDir) : IContainerDef; public record MultipleContainer(List SyncedDirs) : IContainerDef; +public class ContainerDefConverter : JsonConverter { + public override IContainerDef? Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) { + throw new NotImplementedException(); + } + + public override void Write(Utf8JsonWriter writer, IContainerDef value, JsonSerializerOptions options) { + switch (value) { + case SingleContainer container: + JsonSerializer.Serialize(writer, container.SyncedDir, options); + break; + case MultipleContainer { SyncedDirs: var syncedDirs }: + JsonSerializer.Serialize(writer, syncedDirs, options); + break; + default: + throw new NotImplementedException(); + } + } +} + + + public record TaskUnitConfig( Guid InstanceId, Guid JobId, @@ -686,7 +730,7 @@ public record TaskUnitConfig( public IContainerDef? Tools { get; set; } public IContainerDef? UniqueInputs { get; set; } public IContainerDef? UniqueReports { get; set; } - public IContainerDef? RegressionReport { get; set; } + public IContainerDef? RegressionReports { get; set; } } diff --git a/src/ApiService/ApiService/OneFuzzTypes/Validated.cs b/src/ApiService/ApiService/OneFuzzTypes/Validated.cs index 965321d45..d6d0991a2 100644 --- a/src/ApiService/ApiService/OneFuzzTypes/Validated.cs +++ b/src/ApiService/ApiService/OneFuzzTypes/Validated.cs @@ -1,6 +1,4 @@ - -using System.Diagnostics; -using System.Diagnostics.CodeAnalysis; +using System.Diagnostics.CodeAnalysis; using System.Text.Json; using System.Text.Json.Serialization; using System.Text.RegularExpressions; @@ -51,7 +49,7 @@ public abstract class ValidatedStringConverter : JsonConverter where T : V [JsonConverter(typeof(Converter))] public record PoolName : ValidatedString { private PoolName(string value) : base(value) { - Debug.Assert(Check.IsAlnumDash(value)); + // Debug.Assert(Check.IsAlnumDash(value)); } public static PoolName Parse(string input) { @@ -63,10 +61,14 @@ public record PoolName : ValidatedString { } public static bool TryParse(string input, [NotNullWhen(returnValue: true)] out PoolName? result) { - if (!Check.IsAlnumDash(input)) { - result = default; - return false; - } + + // bypassing the validation because this code has a stricter validation than the python equivalent + // see (issue #2080) + + // if (!Check.IsAlnumDash(input)) { + // result = default; + // return false; + // } result = new PoolName(input); return true; diff --git a/src/ApiService/ApiService/TimerTasks.cs b/src/ApiService/ApiService/TimerTasks.cs index 6f066de26..f6960ee1a 100644 --- a/src/ApiService/ApiService/TimerTasks.cs +++ b/src/ApiService/ApiService/TimerTasks.cs @@ -21,8 +21,8 @@ public class TimerTasks { _scheduler = scheduler; } - //[Function("TimerTasks")] - public async Async.Task Run([TimerTrigger("1.00:00:00")] TimerInfo myTimer) { + [Function("TimerTasks")] + public async Async.Task Run([TimerTrigger("00:00:15")] TimerInfo myTimer) { var expriredTasks = _taskOperations.SearchExpired(); await foreach (var task in expriredTasks) { _logger.Info($"stopping expired task. job_id:{task.JobId} task_id:{task.TaskId}"); diff --git a/src/ApiService/ApiService/onefuzzlib/Config.cs b/src/ApiService/ApiService/onefuzzlib/Config.cs index 70f22f320..a86c2a8a4 100644 --- a/src/ApiService/ApiService/onefuzzlib/Config.cs +++ b/src/ApiService/ApiService/onefuzzlib/Config.cs @@ -5,7 +5,6 @@ namespace Microsoft.OneFuzz.Service; public interface IConfig { - string GetSetupContainer(TaskConfig config); Async.Task BuildTaskConfig(Job job, Task task); } @@ -89,9 +88,13 @@ public class Config : IConfig { await foreach (var data in containersByType) { + if (!data.containers.Any()) { + continue; + } + IContainerDef def = data.countainerDef switch { ContainerDefinition { Compare: Compare.Equal, Value: 1 } or - ContainerDefinition { Compare: Compare.AtMost, Value: 1 } => new SingleContainer(data.containers[0]), + ContainerDefinition { Compare: Compare.AtMost, Value: 1 } when data.containers.Count == 1 => new SingleContainer(data.containers[0]), _ => new MultipleContainer(data.containers) }; @@ -126,6 +129,9 @@ public class Config : IConfig { case ContainerType.UniqueReports: config.UniqueReports = def; break; + case ContainerType.RegressionReports: + config.RegressionReports = def; + break; } } @@ -249,16 +255,4 @@ public class Config : IConfig { return config; } - - - public string GetSetupContainer(TaskConfig config) { - - foreach (var container in config.Containers ?? throw new Exception("Missing containers")) { - if (container.Type == ContainerType.Setup) { - return container.Name.ContainerName; - } - } - - throw new Exception($"task missing setup container: task_type = {config.Task.Type}"); - } } diff --git a/src/ApiService/ApiService/onefuzzlib/Containers.cs b/src/ApiService/ApiService/onefuzzlib/Containers.cs index f28286e8f..30e8a6e46 100644 --- a/src/ApiService/ApiService/onefuzzlib/Containers.cs +++ b/src/ApiService/ApiService/onefuzzlib/Containers.cs @@ -164,14 +164,15 @@ public class Containers : IContainers { if (uri.Query.Contains("sig")) { return uri; } - - var accountName = uri.Host.Split('.')[0]; - var (_, accountKey) = await _storage.GetStorageAccountNameAndKey(accountName); + var blobUriBuilder = new BlobUriBuilder(uri); + var accountKey = await _storage.GetStorageAccountNameKeyByName(blobUriBuilder.AccountName); var sasBuilder = new BlobSasBuilder( BlobContainerSasPermissions.Read | BlobContainerSasPermissions.Write | BlobContainerSasPermissions.Delete | BlobContainerSasPermissions.List, - DateTimeOffset.UtcNow + TimeSpan.FromHours(1)); + DateTimeOffset.UtcNow + TimeSpan.FromHours(1)) { + BlobContainerName = blobUriBuilder.BlobContainerName, + }; - var sas = sasBuilder.ToSasQueryParameters(new StorageSharedKeyCredential(accountName, accountKey)).ToString(); + var sas = sasBuilder.ToSasQueryParameters(new StorageSharedKeyCredential(blobUriBuilder.AccountName, accountKey)).ToString(); return new UriBuilder(uri) { Query = sas }.Uri; @@ -179,13 +180,10 @@ public class Containers : IContainers { public async Async.Task GetContainerSasUrl(Container container, StorageType storageType, BlobContainerSasPermissions permissions, TimeSpan? duration = null) { var client = await FindContainer(container, storageType) ?? throw new Exception($"unable to find container: {container.ContainerName} - {storageType}"); - var (accountName, accountKey) = await _storage.GetStorageAccountNameAndKey(client.AccountName); - var (startTime, endTime) = SasTimeWindow(duration ?? TimeSpan.FromDays(30)); - var sasBuilder = new BlobSasBuilder(permissions, endTime) { StartsOn = startTime, - BlobContainerName = container.ContainerName, + BlobContainerName = container.ContainerName }; var sasUrl = client.GenerateSasUri(sasBuilder); diff --git a/src/ApiService/ApiService/onefuzzlib/JobOperations.cs b/src/ApiService/ApiService/onefuzzlib/JobOperations.cs index 779cf1d2e..772607f5c 100644 --- a/src/ApiService/ApiService/onefuzzlib/JobOperations.cs +++ b/src/ApiService/ApiService/onefuzzlib/JobOperations.cs @@ -6,12 +6,13 @@ public interface IJobOperations : IStatefulOrm { Async.Task Get(Guid jobId); Async.Task OnStart(Job job); IAsyncEnumerable SearchExpired(); - Async.Task Stopping(Job job); + Async.Task Stopping(Job job); IAsyncEnumerable SearchState(IEnumerable states); Async.Task StopNeverStartedJobs(); + Async.Task StopIfAllDone(Job job); } -public class JobOperations : StatefulOrm, IJobOperations { +public class JobOperations : StatefulOrm, IJobOperations { private static TimeSpan JOB_NEVER_STARTED_DURATION = TimeSpan.FromDays(30); public JobOperations(ILogTracer logTracer, IOnefuzzContext context) : base(logTracer, context) { @@ -36,6 +37,17 @@ public class JobOperations : StatefulOrm, IJobOperations { return QueryAsync(filter: query); } + public async Async.Task StopIfAllDone(Job job) { + var anyNotStoppedJobs = await _context.TaskOperations.GetByJobId(job.JobId).AnyAsync(task => task.State != TaskState.Stopped); + + if (anyNotStoppedJobs) { + return; + } + + _logTracer.Info($"stopping job as all tasks are stopped: {job.JobId}"); + await Stopping(job); + } + public async Async.Task StopNeverStartedJobs() { // # Note, the "not(end_time...)" with end_time set long before the use of // # OneFuzz enables identifying those without end_time being set. @@ -58,7 +70,7 @@ public class JobOperations : StatefulOrm, IJobOperations { } } - public async Async.Task Stopping(Job job) { + public async Async.Task Stopping(Job job) { job = job with { State = JobState.Stopping }; var tasks = await _context.TaskOperations.QueryAsync(filter: $"job_id eq '{job.JobId}'").ToListAsync(); var taskNotStopped = tasks.ToLookup(task => task.State != TaskState.Stopped); @@ -76,7 +88,12 @@ public class JobOperations : StatefulOrm, IJobOperations { await _context.Events.SendEvent(new EventJobStopped(job.JobId, job.Config, job.UserInfo, taskInfo)); } - await Replace(job); + var result = await Replace(job); + if (result.IsOk) { + return job; + } else { + throw new Exception($"Failed to save job {job.JobId} : {result.ErrorV}"); + } } } diff --git a/src/ApiService/ApiService/onefuzzlib/NodeOperations.cs b/src/ApiService/ApiService/onefuzzlib/NodeOperations.cs index 1f12f83db..633494c0c 100644 --- a/src/ApiService/ApiService/onefuzzlib/NodeOperations.cs +++ b/src/ApiService/ApiService/onefuzzlib/NodeOperations.cs @@ -43,6 +43,8 @@ public interface INodeOperations : IStatefulOrm { Async.Task MarkTasksStoppedEarly(Node node, Error? error = null); static TimeSpan NODE_EXPIRATION_TIME = TimeSpan.FromHours(1.0); static TimeSpan NODE_REIMAGE_TIME = TimeSpan.FromDays(6.0); + + Async.Task StopTask(Guid task_id); } @@ -51,7 +53,7 @@ public interface INodeOperations : IStatefulOrm { /// Enabling autoscaling for the scalesets based on the pool work queues. /// https://docs.microsoft.com/en-us/azure/azure-monitor/platform/autoscale-common-metrics#commonly-used-storage-metrics -public class NodeOperations : StatefulOrm, INodeOperations { +public class NodeOperations : StatefulOrm, INodeOperations { public NodeOperations( @@ -269,7 +271,7 @@ public class NodeOperations : StatefulOrm, INodeOperations { public async Async.Task GetByMachineId(Guid machineId) { - var data = QueryAsync(filter: $"RowKey eq '{machineId}'"); + var data = QueryAsync(filter: Query.RowKey(machineId.ToString())); return await data.FirstOrDefaultAsync(); } @@ -387,18 +389,49 @@ public class NodeOperations : StatefulOrm, INodeOperations { await _context.Events.SendEvent(new EventNodeDeleted(node.MachineId, node.ScalesetId, node.PoolName, node.State)); } + public async Async.Task StopTask(Guid task_id) { + // For now, this just re-images the node. Eventually, this + // should send a message to the node to let the agent shut down + // gracefully + + var nodes = _context.NodeTasksOperations.GetNodesByTaskId(task_id); + + await foreach (var node in nodes) { + await _context.NodeMessageOperations.SendMessage(node.MachineId, new NodeCommand(StopTask: new StopTaskNodeCommand(task_id))); + + if (!(await StopIfComplete(node))) { + _logTracer.Info($"nodes: stopped task on node, but not reimaging due to other tasks: task_id:{task_id} machine_id:{node.MachineId}"); + } + } + + } + + /// returns True on stopping the node and False if this doesn't stop the node + private async Task StopIfComplete(Node node, bool done = false) { + var nodeTaskIds = await _context.NodeTasksOperations.GetByMachineId(node.MachineId).Select(nt => nt.TaskId).ToArrayAsync(); + var tasks = _context.TaskOperations.GetByTaskIds(nodeTaskIds); + await foreach (var task in tasks) { + if (!TaskStateHelper.ShuttingDown(task.State)) { + return false; + } + } + _logTracer.Info($"node: stopping busy node with all tasks complete: {node.MachineId}"); + + await Stop(node, done: done); + return true; + } } public interface INodeTasksOperations : IStatefulOrm { - IAsyncEnumerable GetNodesByTaskId(Guid taskId, INodeOperations nodeOps); + IAsyncEnumerable GetNodesByTaskId(Guid taskId); IAsyncEnumerable GetNodeAssignments(Guid taskId, INodeOperations nodeOps); IAsyncEnumerable GetByMachineId(Guid machineId); IAsyncEnumerable GetByTaskId(Guid taskId); Async.Task ClearByMachineId(Guid machineId); } -public class NodeTasksOperations : StatefulOrm, INodeTasksOperations { +public class NodeTasksOperations : StatefulOrm, INodeTasksOperations { ILogTracer _log; @@ -408,18 +441,20 @@ public class NodeTasksOperations : StatefulOrm, INodeT } //TODO: suggest by Cheick: this can probably be optimize by query all NodesTasks then query the all machine in single request - public async IAsyncEnumerable GetNodesByTaskId(Guid taskId, INodeOperations nodeOps) { - await foreach (var entry in QueryAsync($"task_id eq '{taskId}'")) { - var node = await nodeOps.GetByMachineId(entry.MachineId); + + public async IAsyncEnumerable GetNodesByTaskId(Guid taskId) { + await foreach (var entry in QueryAsync(Query.RowKey(taskId.ToString()))) { + var node = await _context.NodeOperations.GetByMachineId(entry.MachineId); if (node is not null) { yield return node; } } } + public async IAsyncEnumerable GetNodeAssignments(Guid taskId, INodeOperations nodeOps) { - await foreach (var entry in QueryAsync($"task_id eq '{taskId}'")) { - var node = await nodeOps.GetByMachineId(entry.MachineId); + await foreach (var entry in QueryAsync(Query.RowKey(taskId.ToString()))) { + var node = await _context.NodeOperations.GetByMachineId(entry.MachineId); if (node is not null) { var nodeAssignment = new NodeAssignment(node.MachineId, node.ScalesetId, entry.State); yield return nodeAssignment; @@ -428,11 +463,11 @@ public class NodeTasksOperations : StatefulOrm, INodeT } public IAsyncEnumerable GetByMachineId(Guid machineId) { - return QueryAsync($"machine_id eq '{machineId}'"); + return QueryAsync(Query.PartitionKey(machineId.ToString())); } public IAsyncEnumerable GetByTaskId(Guid taskId) { - return QueryAsync($"task_id eq '{taskId}'"); + return QueryAsync(Query.RowKey(taskId.ToString())); } public async Async.Task ClearByMachineId(Guid machineId) { @@ -472,7 +507,7 @@ public class NodeMessageOperations : Orm, INodeMessageOperations { } public IAsyncEnumerable GetMessage(Guid machineId) - => QueryAsync(Query.PartitionKey(machineId)); + => QueryAsync(Query.PartitionKey(machineId.ToString())); public async Async.Task ClearMessages(Guid machineId) { _logTracer.Info($"clearing messages for node {machineId}"); diff --git a/src/ApiService/ApiService/onefuzzlib/PoolOperations.cs b/src/ApiService/ApiService/onefuzzlib/PoolOperations.cs index 67153a1f0..9f291285e 100644 --- a/src/ApiService/ApiService/onefuzzlib/PoolOperations.cs +++ b/src/ApiService/ApiService/onefuzzlib/PoolOperations.cs @@ -9,7 +9,7 @@ public interface IPoolOperations { IAsyncEnumerable GetByClientId(Guid clientId); } -public class PoolOperations : StatefulOrm, IPoolOperations { +public class PoolOperations : StatefulOrm, IPoolOperations { public PoolOperations(ILogTracer log, IOnefuzzContext context) : base(log, context) { diff --git a/src/ApiService/ApiService/onefuzzlib/ProxyOperations.cs b/src/ApiService/ApiService/onefuzzlib/ProxyOperations.cs index dcc8c7308..3c3d7716f 100644 --- a/src/ApiService/ApiService/onefuzzlib/ProxyOperations.cs +++ b/src/ApiService/ApiService/onefuzzlib/ProxyOperations.cs @@ -13,7 +13,7 @@ public interface IProxyOperations : IStatefulOrm { bool IsOutdated(Proxy proxy); Async.Task GetOrCreate(string region); } -public class ProxyOperations : StatefulOrm, IProxyOperations { +public class ProxyOperations : StatefulOrm, IProxyOperations { static TimeSpan PROXY_LIFESPAN = TimeSpan.FromDays(7); diff --git a/src/ApiService/ApiService/onefuzzlib/Queue.cs b/src/ApiService/ApiService/onefuzzlib/Queue.cs index 99a476ac0..8ddaa1165 100644 --- a/src/ApiService/ApiService/onefuzzlib/Queue.cs +++ b/src/ApiService/ApiService/onefuzzlib/Queue.cs @@ -51,7 +51,7 @@ public class Queue : IQueue { var accountId = _storage.GetPrimaryAccount(storageType); _log.Verbose($"getting blob container (account_id: {accountId})"); var (name, key) = await _storage.GetStorageAccountNameAndKey(accountId); - var endpoint = _storage.GetQueueEndpoint(accountId); + var endpoint = _storage.GetQueueEndpoint(name); var options = new QueueClientOptions { MessageEncoding = QueueMessageEncoding.Base64 }; return new QueueServiceClient(endpoint, new StorageSharedKeyCredential(name, key), options); } diff --git a/src/ApiService/ApiService/onefuzzlib/ReproOperations.cs b/src/ApiService/ApiService/onefuzzlib/ReproOperations.cs index c45c07c45..b54408c0f 100644 --- a/src/ApiService/ApiService/onefuzzlib/ReproOperations.cs +++ b/src/ApiService/ApiService/onefuzzlib/ReproOperations.cs @@ -10,7 +10,7 @@ public interface IReproOperations : IStatefulOrm { public IAsyncEnumerable SearchStates(IEnumerable? States); } -public class ReproOperations : StatefulOrm, IReproOperations { +public class ReproOperations : StatefulOrm, IReproOperations { private static readonly Dictionary DEFAULT_OS = new Dictionary { {Os.Linux, "Canonical:UbuntuServer:18.04-LTS:latest"}, diff --git a/src/ApiService/ApiService/onefuzzlib/ScalesetOperations.cs b/src/ApiService/ApiService/onefuzzlib/ScalesetOperations.cs index 0faee3795..59a546bae 100644 --- a/src/ApiService/ApiService/onefuzzlib/ScalesetOperations.cs +++ b/src/ApiService/ApiService/onefuzzlib/ScalesetOperations.cs @@ -14,7 +14,7 @@ public interface IScalesetOperations : IOrm { IAsyncEnumerable GetByObjectId(Guid objectId); } -public class ScalesetOperations : StatefulOrm, IScalesetOperations { +public class ScalesetOperations : StatefulOrm, IScalesetOperations { const string SCALESET_LOG_PREFIX = "scalesets: "; ILogTracer _log; diff --git a/src/ApiService/ApiService/onefuzzlib/Scheduler.cs b/src/ApiService/ApiService/onefuzzlib/Scheduler.cs index 1d4bc7114..55322078f 100644 --- a/src/ApiService/ApiService/onefuzzlib/Scheduler.cs +++ b/src/ApiService/ApiService/onefuzzlib/Scheduler.cs @@ -74,7 +74,7 @@ public class Scheduler : IScheduler { private async Async.Task<(BucketConfig, WorkSet)?> BuildWorkSet(Task[] tasks) { var taskIds = tasks.Select(x => x.TaskId).ToHashSet(); - var work_units = new List(); + var workUnits = new List(); BucketConfig? bucketConfig = null; foreach (var task in tasks) { @@ -99,7 +99,7 @@ public class Scheduler : IScheduler { throw new Exception($"bucket configs differ: {bucketConfig} VS {result.Value.Item1}"); } - work_units.Add(result.Value.Item2); + workUnits.Add(result.Value.Item2); } if (bucketConfig != null) { @@ -108,7 +108,7 @@ public class Scheduler : IScheduler { Reboot: bucketConfig.reboot, Script: bucketConfig.setupScript != null, SetupUrl: setupUrl, - WorkUnits: work_units + WorkUnits: workUnits ); return (bucketConfig, workSet); @@ -182,9 +182,9 @@ public class Scheduler : IScheduler { return (bucketConfig, workUnit); } - record struct BucketId(Os os, Guid jobId, (string, string)? vm, PoolName? pool, string setupContainer, bool? reboot, Guid? unique); + public record struct BucketId(Os os, Guid jobId, (string, string)? vm, PoolName? pool, string setupContainer, bool? reboot, Guid? unique); - private ILookup BucketTasks(IEnumerable tasks) { + public static ILookup BucketTasks(IEnumerable tasks) { // buckets are hashed by: // OS, JOB ID, vm sku & image (if available), pool name (if available), @@ -214,8 +214,19 @@ public class Scheduler : IScheduler { unique = Guid.NewGuid(); } - return new BucketId(task.Os, task.JobId, vm, pool, _config.GetSetupContainer(task.Config), task.Config.Task.RebootAfterSetup, unique); + return new BucketId(task.Os, task.JobId, vm, pool, GetSetupContainer(task.Config), task.Config.Task.RebootAfterSetup, unique); }); } + + static string GetSetupContainer(TaskConfig config) { + + foreach (var container in config.Containers ?? throw new Exception("Missing containers")) { + if (container.Type == ContainerType.Setup) { + return container.Name.ContainerName; + } + } + + throw new Exception($"task missing setup container: task_type = {config.Task.Type}"); + } } diff --git a/src/ApiService/ApiService/onefuzzlib/Storage.cs b/src/ApiService/ApiService/onefuzzlib/Storage.cs index 0db103bd2..44db128c9 100644 --- a/src/ApiService/ApiService/onefuzzlib/Storage.cs +++ b/src/ApiService/ApiService/onefuzzlib/Storage.cs @@ -20,9 +20,9 @@ public interface IStorage { public Uri GetBlobEndpoint(string accountId); - public Async.Task<(string?, string?)> GetStorageAccountNameAndKey(string accountId); + public Async.Task<(string, string)> GetStorageAccountNameAndKey(string accountId); - public Async.Task GetStorageAccountNameAndKeyByName(string accountName); + public Async.Task GetStorageAccountNameKeyByName(string accountName); public IEnumerable GetAccounts(StorageType storageType); } @@ -99,16 +99,16 @@ public class Storage : IStorage { }; } - public async Async.Task<(string?, string?)> GetStorageAccountNameAndKey(string accountId) { + public async Async.Task<(string, string)> GetStorageAccountNameAndKey(string accountId) { var resourceId = new ResourceIdentifier(accountId); var armClient = GetMgmtClient(); var storageAccount = armClient.GetStorageAccountResource(resourceId); var keys = await storageAccount.GetKeysAsync(); - var key = keys.Value.Keys.FirstOrDefault(); - return (resourceId.Name, key?.Value); + var key = keys.Value.Keys.FirstOrDefault() ?? throw new Exception("no keys found"); + return (resourceId.Name, key.Value); } - public async Async.Task GetStorageAccountNameAndKeyByName(string accountName) { + public async Async.Task GetStorageAccountNameKeyByName(string accountName) { var armClient = GetMgmtClient(); var resourceGroup = _creds.GetResourceGroupResourceIdentifier(); var storageAccount = await armClient.GetResourceGroupResource(resourceGroup).GetStorageAccountAsync(accountName); diff --git a/src/ApiService/ApiService/onefuzzlib/TaskOperations.cs b/src/ApiService/ApiService/onefuzzlib/TaskOperations.cs index b364f1633..01e32d8e4 100644 --- a/src/ApiService/ApiService/onefuzzlib/TaskOperations.cs +++ b/src/ApiService/ApiService/onefuzzlib/TaskOperations.cs @@ -5,6 +5,10 @@ namespace Microsoft.OneFuzz.Service; public interface ITaskOperations : IStatefulOrm { Async.Task GetByTaskId(Guid taskId); + IAsyncEnumerable GetByTaskIds(IEnumerable taskId); + + IAsyncEnumerable GetByJobId(Guid jobId); + Async.Task GetByJobIdAndTaskId(Guid jobId, Guid taskId); @@ -22,7 +26,7 @@ public interface ITaskOperations : IStatefulOrm { Async.Task SetState(Task task, TaskState state); } -public class TaskOperations : StatefulOrm, ITaskOperations { +public class TaskOperations : StatefulOrm, ITaskOperations { public TaskOperations(ILogTracer log, IOnefuzzContext context) @@ -31,9 +35,15 @@ public class TaskOperations : StatefulOrm, ITaskOperations { } public async Async.Task GetByTaskId(Guid taskId) { - var data = QueryAsync(filter: $"RowKey eq '{taskId}'"); + return await GetByTaskIds(new[] { taskId }).FirstOrDefaultAsync(); + } - return await data.FirstOrDefaultAsync(); + public IAsyncEnumerable GetByTaskIds(IEnumerable taskId) { + return QueryAsync(filter: Query.RowKeys(taskId.Select(t => t.ToString()))); + } + + public IAsyncEnumerable GetByJobId(Guid jobId) { + return QueryAsync(filter: $"PartitionKey eq '{jobId}'"); } public async Async.Task GetByJobIdAndTaskId(Guid jobId, Guid taskId) { @@ -42,21 +52,13 @@ public class TaskOperations : StatefulOrm, ITaskOperations { return await data.FirstOrDefaultAsync(); } public IAsyncEnumerable SearchStates(Guid? jobId = null, IEnumerable? states = null) { - var queryString = String.Empty; - if (jobId != null) { - queryString += $"PartitionKey eq '{jobId}'"; - } - - if (states != null) { - if (jobId != null) { - queryString += " and "; - } - - queryString += "(" + string.Join( - " or ", - states.Select(s => $"state eq '{s}'") - ) + ")"; - } + var queryString = + (jobId, states) switch { + (null, null) => "", + (Guid id, null) => Query.PartitionKey($"{id}"), + (null, IEnumerable s) => Query.EqualAnyEnum("state", s), + (Guid id, IEnumerable s) => Query.And(Query.PartitionKey($"{id}"), Query.EqualAnyEnum("state", s)), + }; return QueryAsync(filter: queryString); } @@ -209,7 +211,7 @@ public class TaskOperations : StatefulOrm, ITaskOperations { // if a prereq task fails, then mark this task as failed if (t == null) { - await MarkFailed(task, new Error(ErrorCode.INVALID_REQUEST, Errors: new[] { "unable to find task" })); + await MarkFailed(task, new Error(ErrorCode.INVALID_REQUEST, Errors: new[] { "unable to find prereq task" })); return false; } @@ -253,4 +255,33 @@ public class TaskOperations : StatefulOrm, ITaskOperations { return null; } + + public async Async.Task Init(Task task) { + await _context.Queue.CreateQueue($"{task.TaskId}", StorageType.Corpus); + return await SetState(task, TaskState.Waiting); + } + + + public async Async.Task Stopping(Task task) { + _logTracer.Info($"stopping task : {task.JobId}, {task.TaskId}"); + await _context.NodeOperations.StopTask(task.TaskId); + var anyRemainingNodes = await _context.NodeTasksOperations.GetNodesByTaskId(task.TaskId).AnyAsync(); + if (!anyRemainingNodes) { + return await Stopped(task); + } + return task; + } + + private async Async.Task Stopped(Task inputTask) { + var task = await SetState(inputTask, TaskState.Stopped); + await _context.Queue.DeleteQueue($"{task.TaskId}", StorageType.Corpus); + + // # TODO: we need to 'unschedule' this task from the existing pools + var job = await _context.JobOperations.Get(task.JobId); + if (job != null) { + await _context.JobOperations.StopIfAllDone(job); + } + + return task; + } } diff --git a/src/ApiService/ApiService/onefuzzlib/orm/EntityConverter.cs b/src/ApiService/ApiService/onefuzzlib/orm/EntityConverter.cs index bb155f3c6..306ff0665 100644 --- a/src/ApiService/ApiService/onefuzzlib/orm/EntityConverter.cs +++ b/src/ApiService/ApiService/onefuzzlib/orm/EntityConverter.cs @@ -210,8 +210,8 @@ public class EntityConverter { return Guid.Parse(entity.GetString(ef.kind.ToString())); else if (ef.type == typeof(int)) return int.Parse(entity.GetString(ef.kind.ToString())); - else if (ef.type == typeof(PoolName)) - // TODO: this should be able to be generic over any ValidatedString + else if (ef.type == typeof(PoolName)) + // TODO: this should be able to be generic over any ValidatedString return PoolName.Parse(entity.GetString(ef.kind.ToString())); else { throw new Exception($"invalid partition or row key type of {info.type} property {name}: {ef.type}"); diff --git a/src/ApiService/ApiService/onefuzzlib/orm/Orm.cs b/src/ApiService/ApiService/onefuzzlib/orm/Orm.cs index ee9c03dc1..6dc0edd18 100644 --- a/src/ApiService/ApiService/onefuzzlib/orm/Orm.cs +++ b/src/ApiService/ApiService/onefuzzlib/orm/Orm.cs @@ -61,7 +61,7 @@ namespace ApiService.OneFuzzLib.Orm { public async Task> Replace(T entity) { var tableClient = await GetTableClient(typeof(T).Name); var tableEntity = _entityConverter.ToTableEntity(entity); - var response = await tableClient.UpsertEntityAsync(tableEntity); + var response = await tableClient.UpsertEntityAsync(tableEntity, TableUpdateMode.Replace); if (response.IsError) { return ResultVoid<(int, string)>.Error((response.Status, response.ReasonPhrase)); } else { @@ -97,7 +97,7 @@ namespace ApiService.OneFuzzLib.Orm { var account = accountId ?? _context.ServiceConfiguration.OneFuzzFuncStorage ?? throw new ArgumentNullException(nameof(accountId)); var (name, key) = await _context.Storage.GetStorageAccountNameAndKey(account); - var endpoint = _context.Storage.GetTableEndpoint(account); + var endpoint = _context.Storage.GetTableEndpoint(name); var tableClient = new TableServiceClient(endpoint, new TableSharedKeyCredential(name, key)); await tableClient.CreateTableIfNotExistsAsync(tableName); return tableClient.GetTableClient(tableName); @@ -136,13 +136,42 @@ namespace ApiService.OneFuzzLib.Orm { } - public class StatefulOrm : Orm, IStatefulOrm where T : StatefulEntityBase where TState : Enum { + public class StatefulOrm : Orm, IStatefulOrm where T : StatefulEntityBase where TState : Enum { static Lazy>? _partitionKeyGetter; static Lazy>? _rowKeyGetter; static ConcurrentDictionary>?> _stateFuncs = new ConcurrentDictionary>?>(); + delegate Async.Task StateTransition(T entity); + static StatefulOrm() { + + /// verify that all state transition function have the correct signature: + var thisType = typeof(Self); + var states = Enum.GetNames(typeof(TState)); + var delegateType = typeof(StateTransition); + MethodInfo delegateSignature = delegateType.GetMethod("Invoke")!; + + foreach (var state in states) { + var methodInfo = thisType?.GetMethod(state.ToString()); + if (methodInfo == null) { + continue; + } + + bool parametersEqual = delegateSignature + .GetParameters() + .Select(x => x.ParameterType) + .SequenceEqual(methodInfo.GetParameters() + .Select(x => x.ParameterType)); + + if (delegateSignature.ReturnType == methodInfo.ReturnType && parametersEqual) { + continue; + } + + throw new Exception($"State transition method '{state}' in '{thisType?.Name}' does not have the correct signature. Expected '{delegateSignature}' actual '{methodInfo}' "); + }; + + _partitionKeyGetter = typeof(T).GetProperties().FirstOrDefault(p => p.GetCustomAttributes(true).OfType().Any())?.GetMethod switch { null => null, @@ -167,11 +196,10 @@ namespace ApiService.OneFuzzLib.Orm { /// public async Async.Task ProcessStateUpdate(T entity) { TState state = entity.State; - var func = _stateFuncs.GetOrAdd(state.ToString(), (string k) => - typeof(T).GetMethod(k) switch { - null => null, - MethodInfo info => (Func>)Delegate.CreateDelegate(typeof(Func>), info) - }); + var func = GetType().GetMethod(state.ToString()) switch { + null => null, + MethodInfo info => info.CreateDelegate(this) + }; if (func != null) { _logTracer.Info($"processing state update: {typeof(T)} - PartitionKey {_partitionKeyGetter?.Value()} {_rowKeyGetter?.Value()} - %s"); diff --git a/src/ApiService/ApiService/onefuzzlib/orm/Queries.cs b/src/ApiService/ApiService/onefuzzlib/orm/Queries.cs index 24a6822a9..115536d05 100644 --- a/src/ApiService/ApiService/onefuzzlib/orm/Queries.cs +++ b/src/ApiService/ApiService/onefuzzlib/orm/Queries.cs @@ -11,12 +11,15 @@ namespace ApiService.OneFuzzLib.Orm { public static string PartitionKey(string partitionKey) => TableClient.CreateQueryFilter($"PartitionKey eq {partitionKey}"); - public static string PartitionKey(Guid partitionKey) - => TableClient.CreateQueryFilter($"PartitionKey eq {partitionKey}"); - public static string RowKey(string rowKey) => TableClient.CreateQueryFilter($"RowKey eq {rowKey}"); + public static string PartitionKeys(IEnumerable partitionKeys) + => Or(partitionKeys.Select(PartitionKey)); + + public static string RowKeys(IEnumerable rowKeys) + => Or(rowKeys.Select(RowKey)); + public static string SingleEntity(string partitionKey, string rowKey) => TableClient.CreateQueryFilter($"(PartitionKey eq {partitionKey}) and (RowKey eq {rowKey})"); diff --git a/src/ApiService/Tests/Integration/AzureStorage.cs b/src/ApiService/Tests/Integration/AzureStorage.cs index 1b015d265..a630625a1 100644 --- a/src/ApiService/Tests/Integration/AzureStorage.cs +++ b/src/ApiService/Tests/Integration/AzureStorage.cs @@ -25,21 +25,31 @@ sealed class AzureStorage : IStorage { return new AzureStorage(accountName, accountKey); } - public string? AccountName { get; } - public string? AccountKey { get; } + public string AccountName { get; } + public string AccountKey { get; } - public AzureStorage(string? accountName, string? accountKey) { + public AzureStorage(string accountName, string accountKey) { AccountName = accountName; AccountKey = accountKey; } - public Task<(string?, string?)> GetStorageAccountNameAndKey(string accountId) - => Async.Task.FromResult((AccountName, AccountKey)); + public IEnumerable CorpusAccounts() { + throw new System.NotImplementedException(); + } public IEnumerable GetAccounts(StorageType storageType) { - if (AccountName != null) { - yield return AccountName; - } + yield return AccountName; + } + + public string GetPrimaryAccount(StorageType storageType) { + throw new System.NotImplementedException(); + } + + public Task<(string, string)> GetStorageAccountNameAndKey(string accountId) + => Async.Task.FromResult((AccountName, AccountKey)); + + public Task GetStorageAccountNameKeyByName(string accountName) { + return Async.Task.FromResult(AccountName)!; } public Uri GetTableEndpoint(string accountId) @@ -51,15 +61,4 @@ sealed class AzureStorage : IStorage { public Uri GetBlobEndpoint(string accountId) => new($"https://{AccountName}.blob.core.windows.net/"); - public IEnumerable CorpusAccounts() { - throw new System.NotImplementedException(); - } - - public string GetPrimaryAccount(StorageType storageType) { - throw new System.NotImplementedException(); - } - - public Task GetStorageAccountNameAndKeyByName(string accountName) { - throw new System.NotImplementedException(); - } } diff --git a/src/ApiService/Tests/Integration/AzuriteStorage.cs b/src/ApiService/Tests/Integration/AzuriteStorage.cs index e60879b0a..80035263b 100644 --- a/src/ApiService/Tests/Integration/AzuriteStorage.cs +++ b/src/ApiService/Tests/Integration/AzuriteStorage.cs @@ -22,15 +22,11 @@ sealed class AzuriteStorage : IStorage { const string AccountName = "devstoreaccount1"; const string AccountKey = "Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw=="; - public Task<(string?, string?)> GetStorageAccountNameAndKey(string _accountId) - => Async.Task.FromResult<(string?, string?)>((AccountName, AccountKey)); + public Task<(string, string)> GetStorageAccountNameAndKey(string accountId) + => Async.Task.FromResult((AccountName, AccountKey)); - public IEnumerable GetAccounts(StorageType storageType) { - yield return AccountName; - } - - public Task GetStorageAccountNameAndKeyByName(string accountName) { - throw new System.NotImplementedException(); + public Task GetStorageAccountNameKeyByName(string accountName) { + return Async.Task.FromResult(AccountName)!; } public IEnumerable CorpusAccounts() { @@ -40,4 +36,8 @@ sealed class AzuriteStorage : IStorage { public string GetPrimaryAccount(StorageType storageType) { throw new System.NotImplementedException(); } + + public IEnumerable GetAccounts(StorageType storageType) { + yield return AccountName; + } } diff --git a/src/ApiService/Tests/RequestsTests.cs b/src/ApiService/Tests/RequestsTests.cs index 176ee463f..a903caf54 100644 --- a/src/ApiService/Tests/RequestsTests.cs +++ b/src/ApiService/Tests/RequestsTests.cs @@ -35,6 +35,8 @@ public class RequestsTests { var deserialized = (T?)_serializer.Deserialize(stream, typeof(T), CancellationToken.None); var reserialized = _serializer.Serialize(deserialized); var result = Encoding.UTF8.GetString(reserialized); + result = result.Replace(System.Environment.NewLine, "\n"); + json = json.Replace(System.Environment.NewLine, "\n"); Assert.Equal(json, result); } diff --git a/src/ApiService/Tests/SchedulerTests.cs b/src/ApiService/Tests/SchedulerTests.cs new file mode 100644 index 000000000..3918f1775 --- /dev/null +++ b/src/ApiService/Tests/SchedulerTests.cs @@ -0,0 +1,133 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using Microsoft.OneFuzz.Service; +using Xunit; + +namespace Tests; + +public class SchedulerTests { + + IEnumerable BuildTasks(int size) { + return Enumerable.Range(0, size).Select(i => + new Task( + Guid.Empty, + Guid.NewGuid(), + TaskState.Init, + Os.Linux, + new TaskConfig( + Guid.Empty, + null, + new TaskDetails( + Type: TaskType.LibfuzzerFuzz, + Duration: 1, + TargetExe: "fuzz.exe", + TargetEnv: new Dictionary(), + TargetOptions: new List()), + Pool: new TaskPool(1, PoolName.Parse("pool")), + Containers: new List { new TaskContainers(ContainerType.Setup, new Container("setup")) }, + Colocate: true + + ), + null, + null, + null, + null, + null) + + ); + } + + [Fact] + public void TestAllColocate() { + // all tasks should land in one bucket + + var tasks = BuildTasks(10).Select(task => task with { Config = task.Config with { Colocate = true } } + ).ToList(); + + var buckets = Scheduler.BucketTasks(tasks); + var bucket = Assert.Single(buckets); + Assert.True(10 >= bucket.Count()); + CheckBuckets(buckets, tasks, 1); + } + + [Fact] + public void TestPartialColocate() { + // 2 tasks should land on their own, the rest should be colocated into a + // single bucket. + + var tasks = BuildTasks(10).Select((task, i) => { + return i switch { + 0 => task with { Config = task.Config with { Colocate = null } }, + 1 => task with { Config = task.Config with { Colocate = false } }, + _ => task + }; + }).ToList(); + var buckets = Scheduler.BucketTasks(tasks); + var lengths = buckets.Select(b => b.Count()).OrderBy(x => x); + Assert.Equal(new[] { 1, 1, 8 }, lengths); + CheckBuckets(buckets, tasks, 3); + } + + [Fact] + public void TestAlluniqueJob() { + // everything has a unique job_id + var tasks = BuildTasks(10).Select(task => { + var jobId = Guid.NewGuid(); + return task with { JobId = jobId, Config = task.Config with { JobId = jobId } }; + }).ToList(); + + var buckets = Scheduler.BucketTasks(tasks); + foreach (var bucket in buckets) { + Assert.True(1 >= bucket.Count()); + } + CheckBuckets(buckets, tasks, 10); + } + + [Fact] + public void TestMultipleJobBuckets() { + // at most 3 tasks per bucket, by job_id + var tasks = BuildTasks(10).Chunk(3).SelectMany(taskChunk => { + var jobId = Guid.NewGuid(); + return taskChunk.Select(task => task with { JobId = jobId, Config = task.Config with { JobId = jobId } }); + }).ToList(); + + var buckets = Scheduler.BucketTasks(tasks); + foreach (var bucket in buckets) { + Assert.True(3 >= bucket.Count()); + } + CheckBuckets(buckets, tasks, 4); + } + + [Fact] + public void TestManyBuckets() { + var jobId = Guid.Parse("00000000-0000-0000-0000-000000000001"); + var tasks = BuildTasks(100).Select((task, i) => { + var containers = new List(task.Config.Containers!); + if (i % 4 == 0) { + containers[0] = containers[0] with { Name = new Container("setup2") }; + } + return task with { + JobId = i % 2 == 0 ? jobId : task.JobId, + Os = i % 3 == 0 ? Os.Windows : task.Os, + Config = task.Config with { + JobId = i % 2 == 0 ? jobId : task.Config.JobId, + Containers = containers, + Pool = i % 5 == 0 ? task.Config.Pool! with { PoolName = PoolName.Parse("alternate-pool") } : task.Config.Pool + } + }; + }).ToList(); + + var buckets = Scheduler.BucketTasks(tasks); + + CheckBuckets(buckets, tasks, 12); + } + + void CheckBuckets(ILookup buckets, List tasks, int bucketCount) { + Assert.Equal(buckets.Count, bucketCount); + + foreach (var task in tasks) { + Assert.Single(buckets, b => b.Contains(task)); + } + } +} diff --git a/src/ApiService/Tests/ValidatedStringTests.cs b/src/ApiService/Tests/ValidatedStringTests.cs index 783bf7698..26eebd35e 100644 --- a/src/ApiService/Tests/ValidatedStringTests.cs +++ b/src/ApiService/Tests/ValidatedStringTests.cs @@ -8,12 +8,6 @@ public class ValidatedStringTests { record ThingContainingPoolName(PoolName PoolName); - [Fact] - public void PoolNameValidatesOnDeserialization() { - var ex = Assert.Throws(() => JsonSerializer.Deserialize("{ \"PoolName\": \"is-not!-a-pool\" }")); - Assert.Equal("unable to parse input as a PoolName", ex.Message); - } - [Fact] public void PoolNameDeserializesFromString() { var result = JsonSerializer.Deserialize("{ \"PoolName\": \"is-a-pool\" }");