mirror of
https://github.com/microsoft/onefuzz.git
synced 2025-06-14 02:58:10 +00:00
Migrate timer_task part 3 (#2066)
* Adding unit tests for the scheduler * formatting * fix unit test * bug fixes * proting a couple more missong functions * more bug fixes * fixing queries and serilization * fix sas url generation * fix condition * [testing] enabling timer_tasks for testing * Update src/ApiService/Tests/SchedulerTests.cs Co-authored-by: George Pollard <porges@porg.es> * address PR comments * build fix * address PR comment * removing renamed function * resolve merge * Update src/ApiService/Tests/Integration/AzuriteStorage.cs Co-authored-by: George Pollard <porges@porg.es> * - Added verification of the state transition functions - disabled validation on PoolName * Update src/deployment/deploy.py Co-authored-by: George Pollard <porges@porg.es> Co-authored-by: George Pollard <gpollard@microsoft.com>
This commit is contained in:
@ -184,7 +184,8 @@ public record TaskDetails(
|
|||||||
bool? PreserveExistingOutputs = null,
|
bool? PreserveExistingOutputs = null,
|
||||||
List<string>? ReportList = null,
|
List<string>? ReportList = null,
|
||||||
int? MinimizedStackDepth = null,
|
int? MinimizedStackDepth = null,
|
||||||
string? CoverageFilter = null);
|
string? CoverageFilter = null
|
||||||
|
);
|
||||||
|
|
||||||
public record TaskVm(
|
public record TaskVm(
|
||||||
Region Region,
|
Region Region,
|
||||||
@ -214,7 +215,8 @@ public record TaskConfig(
|
|||||||
List<TaskContainers>? Containers = null,
|
List<TaskContainers>? Containers = null,
|
||||||
Dictionary<string, string>? Tags = null,
|
Dictionary<string, string>? Tags = null,
|
||||||
List<TaskDebugFlag>? Debug = null,
|
List<TaskDebugFlag>? Debug = null,
|
||||||
bool? Colocate = null);
|
bool? Colocate = null
|
||||||
|
);
|
||||||
|
|
||||||
public record TaskEventSummary(
|
public record TaskEventSummary(
|
||||||
DateTimeOffset? Timestamp,
|
DateTimeOffset? Timestamp,
|
||||||
@ -590,9 +592,29 @@ public record WorkUnit(
|
|||||||
Guid JobId,
|
Guid JobId,
|
||||||
Guid TaskId,
|
Guid TaskId,
|
||||||
TaskType TaskType,
|
TaskType TaskType,
|
||||||
TaskUnitConfig Config
|
|
||||||
|
// JSON-serialized `TaskUnitConfig`.
|
||||||
|
[property: JsonConverter(typeof(TaskUnitConfigConverter))] TaskUnitConfig Config
|
||||||
);
|
);
|
||||||
|
|
||||||
|
public class TaskUnitConfigConverter : JsonConverter<TaskUnitConfig> {
|
||||||
|
public override TaskUnitConfig? Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) {
|
||||||
|
var taskUnitString = reader.GetString();
|
||||||
|
if (taskUnitString == null) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
return JsonSerializer.Deserialize<TaskUnitConfig>(taskUnitString, options);
|
||||||
|
}
|
||||||
|
|
||||||
|
public override void Write(Utf8JsonWriter writer, TaskUnitConfig value, JsonSerializerOptions options) {
|
||||||
|
var v = JsonSerializer.Serialize(value, new JsonSerializerOptions(options) {
|
||||||
|
DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull
|
||||||
|
});
|
||||||
|
|
||||||
|
writer.WriteStringValue(v);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
public record VmDefinition(
|
public record VmDefinition(
|
||||||
Compare Compare,
|
Compare Compare,
|
||||||
int Value
|
int Value
|
||||||
@ -625,14 +647,36 @@ public record ContainerDefinition(
|
|||||||
|
|
||||||
// TODO: service shouldn't pass SyncedDir, but just the url and let the agent
|
// TODO: service shouldn't pass SyncedDir, but just the url and let the agent
|
||||||
// come up with paths
|
// come up with paths
|
||||||
public record SyncedDir(string Path, Uri url);
|
public record SyncedDir(string Path, Uri Url);
|
||||||
|
|
||||||
|
|
||||||
|
[JsonConverter(typeof(ContainerDefConverter))]
|
||||||
public interface IContainerDef { }
|
public interface IContainerDef { }
|
||||||
public record SingleContainer(SyncedDir SyncedDir) : IContainerDef;
|
public record SingleContainer(SyncedDir SyncedDir) : IContainerDef;
|
||||||
public record MultipleContainer(List<SyncedDir> SyncedDirs) : IContainerDef;
|
public record MultipleContainer(List<SyncedDir> SyncedDirs) : IContainerDef;
|
||||||
|
|
||||||
|
|
||||||
|
public class ContainerDefConverter : JsonConverter<IContainerDef> {
|
||||||
|
public override IContainerDef? Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) {
|
||||||
|
throw new NotImplementedException();
|
||||||
|
}
|
||||||
|
|
||||||
|
public override void Write(Utf8JsonWriter writer, IContainerDef value, JsonSerializerOptions options) {
|
||||||
|
switch (value) {
|
||||||
|
case SingleContainer container:
|
||||||
|
JsonSerializer.Serialize(writer, container.SyncedDir, options);
|
||||||
|
break;
|
||||||
|
case MultipleContainer { SyncedDirs: var syncedDirs }:
|
||||||
|
JsonSerializer.Serialize(writer, syncedDirs, options);
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
throw new NotImplementedException();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
public record TaskUnitConfig(
|
public record TaskUnitConfig(
|
||||||
Guid InstanceId,
|
Guid InstanceId,
|
||||||
Guid JobId,
|
Guid JobId,
|
||||||
@ -686,7 +730,7 @@ public record TaskUnitConfig(
|
|||||||
public IContainerDef? Tools { get; set; }
|
public IContainerDef? Tools { get; set; }
|
||||||
public IContainerDef? UniqueInputs { get; set; }
|
public IContainerDef? UniqueInputs { get; set; }
|
||||||
public IContainerDef? UniqueReports { get; set; }
|
public IContainerDef? UniqueReports { get; set; }
|
||||||
public IContainerDef? RegressionReport { get; set; }
|
public IContainerDef? RegressionReports { get; set; }
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1,6 +1,4 @@
|
|||||||
|
using System.Diagnostics.CodeAnalysis;
|
||||||
using System.Diagnostics;
|
|
||||||
using System.Diagnostics.CodeAnalysis;
|
|
||||||
using System.Text.Json;
|
using System.Text.Json;
|
||||||
using System.Text.Json.Serialization;
|
using System.Text.Json.Serialization;
|
||||||
using System.Text.RegularExpressions;
|
using System.Text.RegularExpressions;
|
||||||
@ -51,7 +49,7 @@ public abstract class ValidatedStringConverter<T> : JsonConverter<T> where T : V
|
|||||||
[JsonConverter(typeof(Converter))]
|
[JsonConverter(typeof(Converter))]
|
||||||
public record PoolName : ValidatedString {
|
public record PoolName : ValidatedString {
|
||||||
private PoolName(string value) : base(value) {
|
private PoolName(string value) : base(value) {
|
||||||
Debug.Assert(Check.IsAlnumDash(value));
|
// Debug.Assert(Check.IsAlnumDash(value));
|
||||||
}
|
}
|
||||||
|
|
||||||
public static PoolName Parse(string input) {
|
public static PoolName Parse(string input) {
|
||||||
@ -63,10 +61,14 @@ public record PoolName : ValidatedString {
|
|||||||
}
|
}
|
||||||
|
|
||||||
public static bool TryParse(string input, [NotNullWhen(returnValue: true)] out PoolName? result) {
|
public static bool TryParse(string input, [NotNullWhen(returnValue: true)] out PoolName? result) {
|
||||||
if (!Check.IsAlnumDash(input)) {
|
|
||||||
result = default;
|
// bypassing the validation because this code has a stricter validation than the python equivalent
|
||||||
return false;
|
// see (issue #2080)
|
||||||
}
|
|
||||||
|
// if (!Check.IsAlnumDash(input)) {
|
||||||
|
// result = default;
|
||||||
|
// return false;
|
||||||
|
// }
|
||||||
|
|
||||||
result = new PoolName(input);
|
result = new PoolName(input);
|
||||||
return true;
|
return true;
|
||||||
|
@ -21,8 +21,8 @@ public class TimerTasks {
|
|||||||
_scheduler = scheduler;
|
_scheduler = scheduler;
|
||||||
}
|
}
|
||||||
|
|
||||||
//[Function("TimerTasks")]
|
[Function("TimerTasks")]
|
||||||
public async Async.Task Run([TimerTrigger("1.00:00:00")] TimerInfo myTimer) {
|
public async Async.Task Run([TimerTrigger("00:00:15")] TimerInfo myTimer) {
|
||||||
var expriredTasks = _taskOperations.SearchExpired();
|
var expriredTasks = _taskOperations.SearchExpired();
|
||||||
await foreach (var task in expriredTasks) {
|
await foreach (var task in expriredTasks) {
|
||||||
_logger.Info($"stopping expired task. job_id:{task.JobId} task_id:{task.TaskId}");
|
_logger.Info($"stopping expired task. job_id:{task.JobId} task_id:{task.TaskId}");
|
||||||
|
@ -5,7 +5,6 @@ namespace Microsoft.OneFuzz.Service;
|
|||||||
|
|
||||||
|
|
||||||
public interface IConfig {
|
public interface IConfig {
|
||||||
string GetSetupContainer(TaskConfig config);
|
|
||||||
Async.Task<TaskUnitConfig> BuildTaskConfig(Job job, Task task);
|
Async.Task<TaskUnitConfig> BuildTaskConfig(Job job, Task task);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -89,9 +88,13 @@ public class Config : IConfig {
|
|||||||
|
|
||||||
await foreach (var data in containersByType) {
|
await foreach (var data in containersByType) {
|
||||||
|
|
||||||
|
if (!data.containers.Any()) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
IContainerDef def = data.countainerDef switch {
|
IContainerDef def = data.countainerDef switch {
|
||||||
ContainerDefinition { Compare: Compare.Equal, Value: 1 } or
|
ContainerDefinition { Compare: Compare.Equal, Value: 1 } or
|
||||||
ContainerDefinition { Compare: Compare.AtMost, Value: 1 } => new SingleContainer(data.containers[0]),
|
ContainerDefinition { Compare: Compare.AtMost, Value: 1 } when data.containers.Count == 1 => new SingleContainer(data.containers[0]),
|
||||||
_ => new MultipleContainer(data.containers)
|
_ => new MultipleContainer(data.containers)
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -126,6 +129,9 @@ public class Config : IConfig {
|
|||||||
case ContainerType.UniqueReports:
|
case ContainerType.UniqueReports:
|
||||||
config.UniqueReports = def;
|
config.UniqueReports = def;
|
||||||
break;
|
break;
|
||||||
|
case ContainerType.RegressionReports:
|
||||||
|
config.RegressionReports = def;
|
||||||
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -249,16 +255,4 @@ public class Config : IConfig {
|
|||||||
|
|
||||||
return config;
|
return config;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
public string GetSetupContainer(TaskConfig config) {
|
|
||||||
|
|
||||||
foreach (var container in config.Containers ?? throw new Exception("Missing containers")) {
|
|
||||||
if (container.Type == ContainerType.Setup) {
|
|
||||||
return container.Name.ContainerName;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
throw new Exception($"task missing setup container: task_type = {config.Task.Type}");
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
@ -164,14 +164,15 @@ public class Containers : IContainers {
|
|||||||
if (uri.Query.Contains("sig")) {
|
if (uri.Query.Contains("sig")) {
|
||||||
return uri;
|
return uri;
|
||||||
}
|
}
|
||||||
|
var blobUriBuilder = new BlobUriBuilder(uri);
|
||||||
var accountName = uri.Host.Split('.')[0];
|
var accountKey = await _storage.GetStorageAccountNameKeyByName(blobUriBuilder.AccountName);
|
||||||
var (_, accountKey) = await _storage.GetStorageAccountNameAndKey(accountName);
|
|
||||||
var sasBuilder = new BlobSasBuilder(
|
var sasBuilder = new BlobSasBuilder(
|
||||||
BlobContainerSasPermissions.Read | BlobContainerSasPermissions.Write | BlobContainerSasPermissions.Delete | BlobContainerSasPermissions.List,
|
BlobContainerSasPermissions.Read | BlobContainerSasPermissions.Write | BlobContainerSasPermissions.Delete | BlobContainerSasPermissions.List,
|
||||||
DateTimeOffset.UtcNow + TimeSpan.FromHours(1));
|
DateTimeOffset.UtcNow + TimeSpan.FromHours(1)) {
|
||||||
|
BlobContainerName = blobUriBuilder.BlobContainerName,
|
||||||
|
};
|
||||||
|
|
||||||
var sas = sasBuilder.ToSasQueryParameters(new StorageSharedKeyCredential(accountName, accountKey)).ToString();
|
var sas = sasBuilder.ToSasQueryParameters(new StorageSharedKeyCredential(blobUriBuilder.AccountName, accountKey)).ToString();
|
||||||
return new UriBuilder(uri) {
|
return new UriBuilder(uri) {
|
||||||
Query = sas
|
Query = sas
|
||||||
}.Uri;
|
}.Uri;
|
||||||
@ -179,13 +180,10 @@ public class Containers : IContainers {
|
|||||||
|
|
||||||
public async Async.Task<Uri> GetContainerSasUrl(Container container, StorageType storageType, BlobContainerSasPermissions permissions, TimeSpan? duration = null) {
|
public async Async.Task<Uri> GetContainerSasUrl(Container container, StorageType storageType, BlobContainerSasPermissions permissions, TimeSpan? duration = null) {
|
||||||
var client = await FindContainer(container, storageType) ?? throw new Exception($"unable to find container: {container.ContainerName} - {storageType}");
|
var client = await FindContainer(container, storageType) ?? throw new Exception($"unable to find container: {container.ContainerName} - {storageType}");
|
||||||
var (accountName, accountKey) = await _storage.GetStorageAccountNameAndKey(client.AccountName);
|
|
||||||
|
|
||||||
var (startTime, endTime) = SasTimeWindow(duration ?? TimeSpan.FromDays(30));
|
var (startTime, endTime) = SasTimeWindow(duration ?? TimeSpan.FromDays(30));
|
||||||
|
|
||||||
var sasBuilder = new BlobSasBuilder(permissions, endTime) {
|
var sasBuilder = new BlobSasBuilder(permissions, endTime) {
|
||||||
StartsOn = startTime,
|
StartsOn = startTime,
|
||||||
BlobContainerName = container.ContainerName,
|
BlobContainerName = container.ContainerName
|
||||||
};
|
};
|
||||||
|
|
||||||
var sasUrl = client.GenerateSasUri(sasBuilder);
|
var sasUrl = client.GenerateSasUri(sasBuilder);
|
||||||
|
@ -6,12 +6,13 @@ public interface IJobOperations : IStatefulOrm<Job, JobState> {
|
|||||||
Async.Task<Job?> Get(Guid jobId);
|
Async.Task<Job?> Get(Guid jobId);
|
||||||
Async.Task OnStart(Job job);
|
Async.Task OnStart(Job job);
|
||||||
IAsyncEnumerable<Job> SearchExpired();
|
IAsyncEnumerable<Job> SearchExpired();
|
||||||
Async.Task Stopping(Job job);
|
Async.Task<Job> Stopping(Job job);
|
||||||
IAsyncEnumerable<Job> SearchState(IEnumerable<JobState> states);
|
IAsyncEnumerable<Job> SearchState(IEnumerable<JobState> states);
|
||||||
Async.Task StopNeverStartedJobs();
|
Async.Task StopNeverStartedJobs();
|
||||||
|
Async.Task StopIfAllDone(Job job);
|
||||||
}
|
}
|
||||||
|
|
||||||
public class JobOperations : StatefulOrm<Job, JobState>, IJobOperations {
|
public class JobOperations : StatefulOrm<Job, JobState, JobOperations>, IJobOperations {
|
||||||
private static TimeSpan JOB_NEVER_STARTED_DURATION = TimeSpan.FromDays(30);
|
private static TimeSpan JOB_NEVER_STARTED_DURATION = TimeSpan.FromDays(30);
|
||||||
|
|
||||||
public JobOperations(ILogTracer logTracer, IOnefuzzContext context) : base(logTracer, context) {
|
public JobOperations(ILogTracer logTracer, IOnefuzzContext context) : base(logTracer, context) {
|
||||||
@ -36,6 +37,17 @@ public class JobOperations : StatefulOrm<Job, JobState>, IJobOperations {
|
|||||||
return QueryAsync(filter: query);
|
return QueryAsync(filter: query);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public async Async.Task StopIfAllDone(Job job) {
|
||||||
|
var anyNotStoppedJobs = await _context.TaskOperations.GetByJobId(job.JobId).AnyAsync(task => task.State != TaskState.Stopped);
|
||||||
|
|
||||||
|
if (anyNotStoppedJobs) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
_logTracer.Info($"stopping job as all tasks are stopped: {job.JobId}");
|
||||||
|
await Stopping(job);
|
||||||
|
}
|
||||||
|
|
||||||
public async Async.Task StopNeverStartedJobs() {
|
public async Async.Task StopNeverStartedJobs() {
|
||||||
// # Note, the "not(end_time...)" with end_time set long before the use of
|
// # Note, the "not(end_time...)" with end_time set long before the use of
|
||||||
// # OneFuzz enables identifying those without end_time being set.
|
// # OneFuzz enables identifying those without end_time being set.
|
||||||
@ -58,7 +70,7 @@ public class JobOperations : StatefulOrm<Job, JobState>, IJobOperations {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public async Async.Task Stopping(Job job) {
|
public async Async.Task<Job> Stopping(Job job) {
|
||||||
job = job with { State = JobState.Stopping };
|
job = job with { State = JobState.Stopping };
|
||||||
var tasks = await _context.TaskOperations.QueryAsync(filter: $"job_id eq '{job.JobId}'").ToListAsync();
|
var tasks = await _context.TaskOperations.QueryAsync(filter: $"job_id eq '{job.JobId}'").ToListAsync();
|
||||||
var taskNotStopped = tasks.ToLookup(task => task.State != TaskState.Stopped);
|
var taskNotStopped = tasks.ToLookup(task => task.State != TaskState.Stopped);
|
||||||
@ -76,7 +88,12 @@ public class JobOperations : StatefulOrm<Job, JobState>, IJobOperations {
|
|||||||
await _context.Events.SendEvent(new EventJobStopped(job.JobId, job.Config, job.UserInfo, taskInfo));
|
await _context.Events.SendEvent(new EventJobStopped(job.JobId, job.Config, job.UserInfo, taskInfo));
|
||||||
}
|
}
|
||||||
|
|
||||||
await Replace(job);
|
var result = await Replace(job);
|
||||||
|
|
||||||
|
if (result.IsOk) {
|
||||||
|
return job;
|
||||||
|
} else {
|
||||||
|
throw new Exception($"Failed to save job {job.JobId} : {result.ErrorV}");
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -43,6 +43,8 @@ public interface INodeOperations : IStatefulOrm<Node, NodeState> {
|
|||||||
Async.Task MarkTasksStoppedEarly(Node node, Error? error = null);
|
Async.Task MarkTasksStoppedEarly(Node node, Error? error = null);
|
||||||
static TimeSpan NODE_EXPIRATION_TIME = TimeSpan.FromHours(1.0);
|
static TimeSpan NODE_EXPIRATION_TIME = TimeSpan.FromHours(1.0);
|
||||||
static TimeSpan NODE_REIMAGE_TIME = TimeSpan.FromDays(6.0);
|
static TimeSpan NODE_REIMAGE_TIME = TimeSpan.FromDays(6.0);
|
||||||
|
|
||||||
|
Async.Task StopTask(Guid task_id);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -51,7 +53,7 @@ public interface INodeOperations : IStatefulOrm<Node, NodeState> {
|
|||||||
/// Enabling autoscaling for the scalesets based on the pool work queues.
|
/// Enabling autoscaling for the scalesets based on the pool work queues.
|
||||||
/// https://docs.microsoft.com/en-us/azure/azure-monitor/platform/autoscale-common-metrics#commonly-used-storage-metrics
|
/// https://docs.microsoft.com/en-us/azure/azure-monitor/platform/autoscale-common-metrics#commonly-used-storage-metrics
|
||||||
|
|
||||||
public class NodeOperations : StatefulOrm<Node, NodeState>, INodeOperations {
|
public class NodeOperations : StatefulOrm<Node, NodeState, NodeOperations>, INodeOperations {
|
||||||
|
|
||||||
|
|
||||||
public NodeOperations(
|
public NodeOperations(
|
||||||
@ -269,7 +271,7 @@ public class NodeOperations : StatefulOrm<Node, NodeState>, INodeOperations {
|
|||||||
|
|
||||||
|
|
||||||
public async Async.Task<Node?> GetByMachineId(Guid machineId) {
|
public async Async.Task<Node?> GetByMachineId(Guid machineId) {
|
||||||
var data = QueryAsync(filter: $"RowKey eq '{machineId}'");
|
var data = QueryAsync(filter: Query.RowKey(machineId.ToString()));
|
||||||
|
|
||||||
return await data.FirstOrDefaultAsync();
|
return await data.FirstOrDefaultAsync();
|
||||||
}
|
}
|
||||||
@ -387,18 +389,49 @@ public class NodeOperations : StatefulOrm<Node, NodeState>, INodeOperations {
|
|||||||
await _context.Events.SendEvent(new EventNodeDeleted(node.MachineId, node.ScalesetId, node.PoolName, node.State));
|
await _context.Events.SendEvent(new EventNodeDeleted(node.MachineId, node.ScalesetId, node.PoolName, node.State));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public async Async.Task StopTask(Guid task_id) {
|
||||||
|
// For now, this just re-images the node. Eventually, this
|
||||||
|
// should send a message to the node to let the agent shut down
|
||||||
|
// gracefully
|
||||||
|
|
||||||
|
var nodes = _context.NodeTasksOperations.GetNodesByTaskId(task_id);
|
||||||
|
|
||||||
|
await foreach (var node in nodes) {
|
||||||
|
await _context.NodeMessageOperations.SendMessage(node.MachineId, new NodeCommand(StopTask: new StopTaskNodeCommand(task_id)));
|
||||||
|
|
||||||
|
if (!(await StopIfComplete(node))) {
|
||||||
|
_logTracer.Info($"nodes: stopped task on node, but not reimaging due to other tasks: task_id:{task_id} machine_id:{node.MachineId}");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
/// returns True on stopping the node and False if this doesn't stop the node
|
||||||
|
private async Task<bool> StopIfComplete(Node node, bool done = false) {
|
||||||
|
var nodeTaskIds = await _context.NodeTasksOperations.GetByMachineId(node.MachineId).Select(nt => nt.TaskId).ToArrayAsync();
|
||||||
|
var tasks = _context.TaskOperations.GetByTaskIds(nodeTaskIds);
|
||||||
|
await foreach (var task in tasks) {
|
||||||
|
if (!TaskStateHelper.ShuttingDown(task.State)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_logTracer.Info($"node: stopping busy node with all tasks complete: {node.MachineId}");
|
||||||
|
|
||||||
|
await Stop(node, done: done);
|
||||||
|
return true;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
public interface INodeTasksOperations : IStatefulOrm<NodeTasks, NodeTaskState> {
|
public interface INodeTasksOperations : IStatefulOrm<NodeTasks, NodeTaskState> {
|
||||||
IAsyncEnumerable<Node> GetNodesByTaskId(Guid taskId, INodeOperations nodeOps);
|
IAsyncEnumerable<Node> GetNodesByTaskId(Guid taskId);
|
||||||
IAsyncEnumerable<NodeAssignment> GetNodeAssignments(Guid taskId, INodeOperations nodeOps);
|
IAsyncEnumerable<NodeAssignment> GetNodeAssignments(Guid taskId, INodeOperations nodeOps);
|
||||||
IAsyncEnumerable<NodeTasks> GetByMachineId(Guid machineId);
|
IAsyncEnumerable<NodeTasks> GetByMachineId(Guid machineId);
|
||||||
IAsyncEnumerable<NodeTasks> GetByTaskId(Guid taskId);
|
IAsyncEnumerable<NodeTasks> GetByTaskId(Guid taskId);
|
||||||
Async.Task ClearByMachineId(Guid machineId);
|
Async.Task ClearByMachineId(Guid machineId);
|
||||||
}
|
}
|
||||||
|
|
||||||
public class NodeTasksOperations : StatefulOrm<NodeTasks, NodeTaskState>, INodeTasksOperations {
|
public class NodeTasksOperations : StatefulOrm<NodeTasks, NodeTaskState, NodeTasksOperations>, INodeTasksOperations {
|
||||||
|
|
||||||
ILogTracer _log;
|
ILogTracer _log;
|
||||||
|
|
||||||
@ -408,18 +441,20 @@ public class NodeTasksOperations : StatefulOrm<NodeTasks, NodeTaskState>, INodeT
|
|||||||
}
|
}
|
||||||
|
|
||||||
//TODO: suggest by Cheick: this can probably be optimize by query all NodesTasks then query the all machine in single request
|
//TODO: suggest by Cheick: this can probably be optimize by query all NodesTasks then query the all machine in single request
|
||||||
public async IAsyncEnumerable<Node> GetNodesByTaskId(Guid taskId, INodeOperations nodeOps) {
|
|
||||||
await foreach (var entry in QueryAsync($"task_id eq '{taskId}'")) {
|
public async IAsyncEnumerable<Node> GetNodesByTaskId(Guid taskId) {
|
||||||
var node = await nodeOps.GetByMachineId(entry.MachineId);
|
await foreach (var entry in QueryAsync(Query.RowKey(taskId.ToString()))) {
|
||||||
|
var node = await _context.NodeOperations.GetByMachineId(entry.MachineId);
|
||||||
if (node is not null) {
|
if (node is not null) {
|
||||||
yield return node;
|
yield return node;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public async IAsyncEnumerable<NodeAssignment> GetNodeAssignments(Guid taskId, INodeOperations nodeOps) {
|
public async IAsyncEnumerable<NodeAssignment> GetNodeAssignments(Guid taskId, INodeOperations nodeOps) {
|
||||||
|
|
||||||
await foreach (var entry in QueryAsync($"task_id eq '{taskId}'")) {
|
await foreach (var entry in QueryAsync(Query.RowKey(taskId.ToString()))) {
|
||||||
var node = await nodeOps.GetByMachineId(entry.MachineId);
|
var node = await _context.NodeOperations.GetByMachineId(entry.MachineId);
|
||||||
if (node is not null) {
|
if (node is not null) {
|
||||||
var nodeAssignment = new NodeAssignment(node.MachineId, node.ScalesetId, entry.State);
|
var nodeAssignment = new NodeAssignment(node.MachineId, node.ScalesetId, entry.State);
|
||||||
yield return nodeAssignment;
|
yield return nodeAssignment;
|
||||||
@ -428,11 +463,11 @@ public class NodeTasksOperations : StatefulOrm<NodeTasks, NodeTaskState>, INodeT
|
|||||||
}
|
}
|
||||||
|
|
||||||
public IAsyncEnumerable<NodeTasks> GetByMachineId(Guid machineId) {
|
public IAsyncEnumerable<NodeTasks> GetByMachineId(Guid machineId) {
|
||||||
return QueryAsync($"machine_id eq '{machineId}'");
|
return QueryAsync(Query.PartitionKey(machineId.ToString()));
|
||||||
}
|
}
|
||||||
|
|
||||||
public IAsyncEnumerable<NodeTasks> GetByTaskId(Guid taskId) {
|
public IAsyncEnumerable<NodeTasks> GetByTaskId(Guid taskId) {
|
||||||
return QueryAsync($"task_id eq '{taskId}'");
|
return QueryAsync(Query.RowKey(taskId.ToString()));
|
||||||
}
|
}
|
||||||
|
|
||||||
public async Async.Task ClearByMachineId(Guid machineId) {
|
public async Async.Task ClearByMachineId(Guid machineId) {
|
||||||
@ -472,7 +507,7 @@ public class NodeMessageOperations : Orm<NodeMessage>, INodeMessageOperations {
|
|||||||
}
|
}
|
||||||
|
|
||||||
public IAsyncEnumerable<NodeMessage> GetMessage(Guid machineId)
|
public IAsyncEnumerable<NodeMessage> GetMessage(Guid machineId)
|
||||||
=> QueryAsync(Query.PartitionKey(machineId));
|
=> QueryAsync(Query.PartitionKey(machineId.ToString()));
|
||||||
|
|
||||||
public async Async.Task ClearMessages(Guid machineId) {
|
public async Async.Task ClearMessages(Guid machineId) {
|
||||||
_logTracer.Info($"clearing messages for node {machineId}");
|
_logTracer.Info($"clearing messages for node {machineId}");
|
||||||
|
@ -9,7 +9,7 @@ public interface IPoolOperations {
|
|||||||
IAsyncEnumerable<Pool> GetByClientId(Guid clientId);
|
IAsyncEnumerable<Pool> GetByClientId(Guid clientId);
|
||||||
}
|
}
|
||||||
|
|
||||||
public class PoolOperations : StatefulOrm<Pool, PoolState>, IPoolOperations {
|
public class PoolOperations : StatefulOrm<Pool, PoolState, PoolOperations>, IPoolOperations {
|
||||||
|
|
||||||
public PoolOperations(ILogTracer log, IOnefuzzContext context)
|
public PoolOperations(ILogTracer log, IOnefuzzContext context)
|
||||||
: base(log, context) {
|
: base(log, context) {
|
||||||
|
@ -13,7 +13,7 @@ public interface IProxyOperations : IStatefulOrm<Proxy, VmState> {
|
|||||||
bool IsOutdated(Proxy proxy);
|
bool IsOutdated(Proxy proxy);
|
||||||
Async.Task<Proxy?> GetOrCreate(string region);
|
Async.Task<Proxy?> GetOrCreate(string region);
|
||||||
}
|
}
|
||||||
public class ProxyOperations : StatefulOrm<Proxy, VmState>, IProxyOperations {
|
public class ProxyOperations : StatefulOrm<Proxy, VmState, ProxyOperations>, IProxyOperations {
|
||||||
|
|
||||||
|
|
||||||
static TimeSpan PROXY_LIFESPAN = TimeSpan.FromDays(7);
|
static TimeSpan PROXY_LIFESPAN = TimeSpan.FromDays(7);
|
||||||
|
@ -51,7 +51,7 @@ public class Queue : IQueue {
|
|||||||
var accountId = _storage.GetPrimaryAccount(storageType);
|
var accountId = _storage.GetPrimaryAccount(storageType);
|
||||||
_log.Verbose($"getting blob container (account_id: {accountId})");
|
_log.Verbose($"getting blob container (account_id: {accountId})");
|
||||||
var (name, key) = await _storage.GetStorageAccountNameAndKey(accountId);
|
var (name, key) = await _storage.GetStorageAccountNameAndKey(accountId);
|
||||||
var endpoint = _storage.GetQueueEndpoint(accountId);
|
var endpoint = _storage.GetQueueEndpoint(name);
|
||||||
var options = new QueueClientOptions { MessageEncoding = QueueMessageEncoding.Base64 };
|
var options = new QueueClientOptions { MessageEncoding = QueueMessageEncoding.Base64 };
|
||||||
return new QueueServiceClient(endpoint, new StorageSharedKeyCredential(name, key), options);
|
return new QueueServiceClient(endpoint, new StorageSharedKeyCredential(name, key), options);
|
||||||
}
|
}
|
||||||
|
@ -10,7 +10,7 @@ public interface IReproOperations : IStatefulOrm<Repro, VmState> {
|
|||||||
public IAsyncEnumerable<Repro> SearchStates(IEnumerable<VmState>? States);
|
public IAsyncEnumerable<Repro> SearchStates(IEnumerable<VmState>? States);
|
||||||
}
|
}
|
||||||
|
|
||||||
public class ReproOperations : StatefulOrm<Repro, VmState>, IReproOperations {
|
public class ReproOperations : StatefulOrm<Repro, VmState, ReproOperations>, IReproOperations {
|
||||||
private static readonly Dictionary<Os, string> DEFAULT_OS = new Dictionary<Os, string>
|
private static readonly Dictionary<Os, string> DEFAULT_OS = new Dictionary<Os, string>
|
||||||
{
|
{
|
||||||
{Os.Linux, "Canonical:UbuntuServer:18.04-LTS:latest"},
|
{Os.Linux, "Canonical:UbuntuServer:18.04-LTS:latest"},
|
||||||
|
@ -14,7 +14,7 @@ public interface IScalesetOperations : IOrm<Scaleset> {
|
|||||||
IAsyncEnumerable<Scaleset> GetByObjectId(Guid objectId);
|
IAsyncEnumerable<Scaleset> GetByObjectId(Guid objectId);
|
||||||
}
|
}
|
||||||
|
|
||||||
public class ScalesetOperations : StatefulOrm<Scaleset, ScalesetState>, IScalesetOperations {
|
public class ScalesetOperations : StatefulOrm<Scaleset, ScalesetState, ScalesetOperations>, IScalesetOperations {
|
||||||
const string SCALESET_LOG_PREFIX = "scalesets: ";
|
const string SCALESET_LOG_PREFIX = "scalesets: ";
|
||||||
|
|
||||||
ILogTracer _log;
|
ILogTracer _log;
|
||||||
|
@ -74,7 +74,7 @@ public class Scheduler : IScheduler {
|
|||||||
|
|
||||||
private async Async.Task<(BucketConfig, WorkSet)?> BuildWorkSet(Task[] tasks) {
|
private async Async.Task<(BucketConfig, WorkSet)?> BuildWorkSet(Task[] tasks) {
|
||||||
var taskIds = tasks.Select(x => x.TaskId).ToHashSet();
|
var taskIds = tasks.Select(x => x.TaskId).ToHashSet();
|
||||||
var work_units = new List<WorkUnit>();
|
var workUnits = new List<WorkUnit>();
|
||||||
|
|
||||||
BucketConfig? bucketConfig = null;
|
BucketConfig? bucketConfig = null;
|
||||||
foreach (var task in tasks) {
|
foreach (var task in tasks) {
|
||||||
@ -99,7 +99,7 @@ public class Scheduler : IScheduler {
|
|||||||
throw new Exception($"bucket configs differ: {bucketConfig} VS {result.Value.Item1}");
|
throw new Exception($"bucket configs differ: {bucketConfig} VS {result.Value.Item1}");
|
||||||
}
|
}
|
||||||
|
|
||||||
work_units.Add(result.Value.Item2);
|
workUnits.Add(result.Value.Item2);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (bucketConfig != null) {
|
if (bucketConfig != null) {
|
||||||
@ -108,7 +108,7 @@ public class Scheduler : IScheduler {
|
|||||||
Reboot: bucketConfig.reboot,
|
Reboot: bucketConfig.reboot,
|
||||||
Script: bucketConfig.setupScript != null,
|
Script: bucketConfig.setupScript != null,
|
||||||
SetupUrl: setupUrl,
|
SetupUrl: setupUrl,
|
||||||
WorkUnits: work_units
|
WorkUnits: workUnits
|
||||||
);
|
);
|
||||||
|
|
||||||
return (bucketConfig, workSet);
|
return (bucketConfig, workSet);
|
||||||
@ -182,9 +182,9 @@ public class Scheduler : IScheduler {
|
|||||||
return (bucketConfig, workUnit);
|
return (bucketConfig, workUnit);
|
||||||
}
|
}
|
||||||
|
|
||||||
record struct BucketId(Os os, Guid jobId, (string, string)? vm, PoolName? pool, string setupContainer, bool? reboot, Guid? unique);
|
public record struct BucketId(Os os, Guid jobId, (string, string)? vm, PoolName? pool, string setupContainer, bool? reboot, Guid? unique);
|
||||||
|
|
||||||
private ILookup<BucketId, Task> BucketTasks(IEnumerable<Task> tasks) {
|
public static ILookup<BucketId, Task> BucketTasks(IEnumerable<Task> tasks) {
|
||||||
|
|
||||||
// buckets are hashed by:
|
// buckets are hashed by:
|
||||||
// OS, JOB ID, vm sku & image (if available), pool name (if available),
|
// OS, JOB ID, vm sku & image (if available), pool name (if available),
|
||||||
@ -214,8 +214,19 @@ public class Scheduler : IScheduler {
|
|||||||
unique = Guid.NewGuid();
|
unique = Guid.NewGuid();
|
||||||
}
|
}
|
||||||
|
|
||||||
return new BucketId(task.Os, task.JobId, vm, pool, _config.GetSetupContainer(task.Config), task.Config.Task.RebootAfterSetup, unique);
|
return new BucketId(task.Os, task.JobId, vm, pool, GetSetupContainer(task.Config), task.Config.Task.RebootAfterSetup, unique);
|
||||||
|
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static string GetSetupContainer(TaskConfig config) {
|
||||||
|
|
||||||
|
foreach (var container in config.Containers ?? throw new Exception("Missing containers")) {
|
||||||
|
if (container.Type == ContainerType.Setup) {
|
||||||
|
return container.Name.ContainerName;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
throw new Exception($"task missing setup container: task_type = {config.Task.Type}");
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
@ -20,9 +20,9 @@ public interface IStorage {
|
|||||||
|
|
||||||
public Uri GetBlobEndpoint(string accountId);
|
public Uri GetBlobEndpoint(string accountId);
|
||||||
|
|
||||||
public Async.Task<(string?, string?)> GetStorageAccountNameAndKey(string accountId);
|
public Async.Task<(string, string)> GetStorageAccountNameAndKey(string accountId);
|
||||||
|
|
||||||
public Async.Task<string?> GetStorageAccountNameAndKeyByName(string accountName);
|
public Async.Task<string?> GetStorageAccountNameKeyByName(string accountName);
|
||||||
|
|
||||||
public IEnumerable<string> GetAccounts(StorageType storageType);
|
public IEnumerable<string> GetAccounts(StorageType storageType);
|
||||||
}
|
}
|
||||||
@ -99,16 +99,16 @@ public class Storage : IStorage {
|
|||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
public async Async.Task<(string?, string?)> GetStorageAccountNameAndKey(string accountId) {
|
public async Async.Task<(string, string)> GetStorageAccountNameAndKey(string accountId) {
|
||||||
var resourceId = new ResourceIdentifier(accountId);
|
var resourceId = new ResourceIdentifier(accountId);
|
||||||
var armClient = GetMgmtClient();
|
var armClient = GetMgmtClient();
|
||||||
var storageAccount = armClient.GetStorageAccountResource(resourceId);
|
var storageAccount = armClient.GetStorageAccountResource(resourceId);
|
||||||
var keys = await storageAccount.GetKeysAsync();
|
var keys = await storageAccount.GetKeysAsync();
|
||||||
var key = keys.Value.Keys.FirstOrDefault();
|
var key = keys.Value.Keys.FirstOrDefault() ?? throw new Exception("no keys found");
|
||||||
return (resourceId.Name, key?.Value);
|
return (resourceId.Name, key.Value);
|
||||||
}
|
}
|
||||||
|
|
||||||
public async Async.Task<string?> GetStorageAccountNameAndKeyByName(string accountName) {
|
public async Async.Task<string?> GetStorageAccountNameKeyByName(string accountName) {
|
||||||
var armClient = GetMgmtClient();
|
var armClient = GetMgmtClient();
|
||||||
var resourceGroup = _creds.GetResourceGroupResourceIdentifier();
|
var resourceGroup = _creds.GetResourceGroupResourceIdentifier();
|
||||||
var storageAccount = await armClient.GetResourceGroupResource(resourceGroup).GetStorageAccountAsync(accountName);
|
var storageAccount = await armClient.GetResourceGroupResource(resourceGroup).GetStorageAccountAsync(accountName);
|
||||||
|
@ -5,6 +5,10 @@ namespace Microsoft.OneFuzz.Service;
|
|||||||
public interface ITaskOperations : IStatefulOrm<Task, TaskState> {
|
public interface ITaskOperations : IStatefulOrm<Task, TaskState> {
|
||||||
Async.Task<Task?> GetByTaskId(Guid taskId);
|
Async.Task<Task?> GetByTaskId(Guid taskId);
|
||||||
|
|
||||||
|
IAsyncEnumerable<Task> GetByTaskIds(IEnumerable<Guid> taskId);
|
||||||
|
|
||||||
|
IAsyncEnumerable<Task> GetByJobId(Guid jobId);
|
||||||
|
|
||||||
Async.Task<Task?> GetByJobIdAndTaskId(Guid jobId, Guid taskId);
|
Async.Task<Task?> GetByJobIdAndTaskId(Guid jobId, Guid taskId);
|
||||||
|
|
||||||
|
|
||||||
@ -22,7 +26,7 @@ public interface ITaskOperations : IStatefulOrm<Task, TaskState> {
|
|||||||
Async.Task<Task> SetState(Task task, TaskState state);
|
Async.Task<Task> SetState(Task task, TaskState state);
|
||||||
}
|
}
|
||||||
|
|
||||||
public class TaskOperations : StatefulOrm<Task, TaskState>, ITaskOperations {
|
public class TaskOperations : StatefulOrm<Task, TaskState, TaskOperations>, ITaskOperations {
|
||||||
|
|
||||||
|
|
||||||
public TaskOperations(ILogTracer log, IOnefuzzContext context)
|
public TaskOperations(ILogTracer log, IOnefuzzContext context)
|
||||||
@ -31,9 +35,15 @@ public class TaskOperations : StatefulOrm<Task, TaskState>, ITaskOperations {
|
|||||||
}
|
}
|
||||||
|
|
||||||
public async Async.Task<Task?> GetByTaskId(Guid taskId) {
|
public async Async.Task<Task?> GetByTaskId(Guid taskId) {
|
||||||
var data = QueryAsync(filter: $"RowKey eq '{taskId}'");
|
return await GetByTaskIds(new[] { taskId }).FirstOrDefaultAsync();
|
||||||
|
}
|
||||||
|
|
||||||
return await data.FirstOrDefaultAsync();
|
public IAsyncEnumerable<Task> GetByTaskIds(IEnumerable<Guid> taskId) {
|
||||||
|
return QueryAsync(filter: Query.RowKeys(taskId.Select(t => t.ToString())));
|
||||||
|
}
|
||||||
|
|
||||||
|
public IAsyncEnumerable<Task> GetByJobId(Guid jobId) {
|
||||||
|
return QueryAsync(filter: $"PartitionKey eq '{jobId}'");
|
||||||
}
|
}
|
||||||
|
|
||||||
public async Async.Task<Task?> GetByJobIdAndTaskId(Guid jobId, Guid taskId) {
|
public async Async.Task<Task?> GetByJobIdAndTaskId(Guid jobId, Guid taskId) {
|
||||||
@ -42,21 +52,13 @@ public class TaskOperations : StatefulOrm<Task, TaskState>, ITaskOperations {
|
|||||||
return await data.FirstOrDefaultAsync();
|
return await data.FirstOrDefaultAsync();
|
||||||
}
|
}
|
||||||
public IAsyncEnumerable<Task> SearchStates(Guid? jobId = null, IEnumerable<TaskState>? states = null) {
|
public IAsyncEnumerable<Task> SearchStates(Guid? jobId = null, IEnumerable<TaskState>? states = null) {
|
||||||
var queryString = String.Empty;
|
var queryString =
|
||||||
if (jobId != null) {
|
(jobId, states) switch {
|
||||||
queryString += $"PartitionKey eq '{jobId}'";
|
(null, null) => "",
|
||||||
}
|
(Guid id, null) => Query.PartitionKey($"{id}"),
|
||||||
|
(null, IEnumerable<TaskState> s) => Query.EqualAnyEnum("state", s),
|
||||||
if (states != null) {
|
(Guid id, IEnumerable<TaskState> s) => Query.And(Query.PartitionKey($"{id}"), Query.EqualAnyEnum("state", s)),
|
||||||
if (jobId != null) {
|
};
|
||||||
queryString += " and ";
|
|
||||||
}
|
|
||||||
|
|
||||||
queryString += "(" + string.Join(
|
|
||||||
" or ",
|
|
||||||
states.Select(s => $"state eq '{s}'")
|
|
||||||
) + ")";
|
|
||||||
}
|
|
||||||
|
|
||||||
return QueryAsync(filter: queryString);
|
return QueryAsync(filter: queryString);
|
||||||
}
|
}
|
||||||
@ -209,7 +211,7 @@ public class TaskOperations : StatefulOrm<Task, TaskState>, ITaskOperations {
|
|||||||
|
|
||||||
// if a prereq task fails, then mark this task as failed
|
// if a prereq task fails, then mark this task as failed
|
||||||
if (t == null) {
|
if (t == null) {
|
||||||
await MarkFailed(task, new Error(ErrorCode.INVALID_REQUEST, Errors: new[] { "unable to find task" }));
|
await MarkFailed(task, new Error(ErrorCode.INVALID_REQUEST, Errors: new[] { "unable to find prereq task" }));
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -253,4 +255,33 @@ public class TaskOperations : StatefulOrm<Task, TaskState>, ITaskOperations {
|
|||||||
return null;
|
return null;
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public async Async.Task<Task> Init(Task task) {
|
||||||
|
await _context.Queue.CreateQueue($"{task.TaskId}", StorageType.Corpus);
|
||||||
|
return await SetState(task, TaskState.Waiting);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
public async Async.Task<Task> Stopping(Task task) {
|
||||||
|
_logTracer.Info($"stopping task : {task.JobId}, {task.TaskId}");
|
||||||
|
await _context.NodeOperations.StopTask(task.TaskId);
|
||||||
|
var anyRemainingNodes = await _context.NodeTasksOperations.GetNodesByTaskId(task.TaskId).AnyAsync();
|
||||||
|
if (!anyRemainingNodes) {
|
||||||
|
return await Stopped(task);
|
||||||
|
}
|
||||||
|
return task;
|
||||||
|
}
|
||||||
|
|
||||||
|
private async Async.Task<Task> Stopped(Task inputTask) {
|
||||||
|
var task = await SetState(inputTask, TaskState.Stopped);
|
||||||
|
await _context.Queue.DeleteQueue($"{task.TaskId}", StorageType.Corpus);
|
||||||
|
|
||||||
|
// # TODO: we need to 'unschedule' this task from the existing pools
|
||||||
|
var job = await _context.JobOperations.Get(task.JobId);
|
||||||
|
if (job != null) {
|
||||||
|
await _context.JobOperations.StopIfAllDone(job);
|
||||||
|
}
|
||||||
|
|
||||||
|
return task;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
@ -61,7 +61,7 @@ namespace ApiService.OneFuzzLib.Orm {
|
|||||||
public async Task<ResultVoid<(int, string)>> Replace(T entity) {
|
public async Task<ResultVoid<(int, string)>> Replace(T entity) {
|
||||||
var tableClient = await GetTableClient(typeof(T).Name);
|
var tableClient = await GetTableClient(typeof(T).Name);
|
||||||
var tableEntity = _entityConverter.ToTableEntity(entity);
|
var tableEntity = _entityConverter.ToTableEntity(entity);
|
||||||
var response = await tableClient.UpsertEntityAsync(tableEntity);
|
var response = await tableClient.UpsertEntityAsync(tableEntity, TableUpdateMode.Replace);
|
||||||
if (response.IsError) {
|
if (response.IsError) {
|
||||||
return ResultVoid<(int, string)>.Error((response.Status, response.ReasonPhrase));
|
return ResultVoid<(int, string)>.Error((response.Status, response.ReasonPhrase));
|
||||||
} else {
|
} else {
|
||||||
@ -97,7 +97,7 @@ namespace ApiService.OneFuzzLib.Orm {
|
|||||||
|
|
||||||
var account = accountId ?? _context.ServiceConfiguration.OneFuzzFuncStorage ?? throw new ArgumentNullException(nameof(accountId));
|
var account = accountId ?? _context.ServiceConfiguration.OneFuzzFuncStorage ?? throw new ArgumentNullException(nameof(accountId));
|
||||||
var (name, key) = await _context.Storage.GetStorageAccountNameAndKey(account);
|
var (name, key) = await _context.Storage.GetStorageAccountNameAndKey(account);
|
||||||
var endpoint = _context.Storage.GetTableEndpoint(account);
|
var endpoint = _context.Storage.GetTableEndpoint(name);
|
||||||
var tableClient = new TableServiceClient(endpoint, new TableSharedKeyCredential(name, key));
|
var tableClient = new TableServiceClient(endpoint, new TableSharedKeyCredential(name, key));
|
||||||
await tableClient.CreateTableIfNotExistsAsync(tableName);
|
await tableClient.CreateTableIfNotExistsAsync(tableName);
|
||||||
return tableClient.GetTableClient(tableName);
|
return tableClient.GetTableClient(tableName);
|
||||||
@ -136,13 +136,42 @@ namespace ApiService.OneFuzzLib.Orm {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
public class StatefulOrm<T, TState> : Orm<T>, IStatefulOrm<T, TState> where T : StatefulEntityBase<TState> where TState : Enum {
|
public class StatefulOrm<T, TState, Self> : Orm<T>, IStatefulOrm<T, TState> where T : StatefulEntityBase<TState> where TState : Enum {
|
||||||
static Lazy<Func<object>>? _partitionKeyGetter;
|
static Lazy<Func<object>>? _partitionKeyGetter;
|
||||||
static Lazy<Func<object>>? _rowKeyGetter;
|
static Lazy<Func<object>>? _rowKeyGetter;
|
||||||
static ConcurrentDictionary<string, Func<T, Async.Task<T>>?> _stateFuncs = new ConcurrentDictionary<string, Func<T, Async.Task<T>>?>();
|
static ConcurrentDictionary<string, Func<T, Async.Task<T>>?> _stateFuncs = new ConcurrentDictionary<string, Func<T, Async.Task<T>>?>();
|
||||||
|
|
||||||
|
delegate Async.Task<T> StateTransition(T entity);
|
||||||
|
|
||||||
|
|
||||||
static StatefulOrm() {
|
static StatefulOrm() {
|
||||||
|
|
||||||
|
/// verify that all state transition function have the correct signature:
|
||||||
|
var thisType = typeof(Self);
|
||||||
|
var states = Enum.GetNames(typeof(TState));
|
||||||
|
var delegateType = typeof(StateTransition);
|
||||||
|
MethodInfo delegateSignature = delegateType.GetMethod("Invoke")!;
|
||||||
|
|
||||||
|
foreach (var state in states) {
|
||||||
|
var methodInfo = thisType?.GetMethod(state.ToString());
|
||||||
|
if (methodInfo == null) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool parametersEqual = delegateSignature
|
||||||
|
.GetParameters()
|
||||||
|
.Select(x => x.ParameterType)
|
||||||
|
.SequenceEqual(methodInfo.GetParameters()
|
||||||
|
.Select(x => x.ParameterType));
|
||||||
|
|
||||||
|
if (delegateSignature.ReturnType == methodInfo.ReturnType && parametersEqual) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
throw new Exception($"State transition method '{state}' in '{thisType?.Name}' does not have the correct signature. Expected '{delegateSignature}' actual '{methodInfo}' ");
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
_partitionKeyGetter =
|
_partitionKeyGetter =
|
||||||
typeof(T).GetProperties().FirstOrDefault(p => p.GetCustomAttributes(true).OfType<PartitionKeyAttribute>().Any())?.GetMethod switch {
|
typeof(T).GetProperties().FirstOrDefault(p => p.GetCustomAttributes(true).OfType<PartitionKeyAttribute>().Any())?.GetMethod switch {
|
||||||
null => null,
|
null => null,
|
||||||
@ -167,11 +196,10 @@ namespace ApiService.OneFuzzLib.Orm {
|
|||||||
/// <returns></returns>
|
/// <returns></returns>
|
||||||
public async Async.Task<T?> ProcessStateUpdate(T entity) {
|
public async Async.Task<T?> ProcessStateUpdate(T entity) {
|
||||||
TState state = entity.State;
|
TState state = entity.State;
|
||||||
var func = _stateFuncs.GetOrAdd(state.ToString(), (string k) =>
|
var func = GetType().GetMethod(state.ToString()) switch {
|
||||||
typeof(T).GetMethod(k) switch {
|
null => null,
|
||||||
null => null,
|
MethodInfo info => info.CreateDelegate<StateTransition>(this)
|
||||||
MethodInfo info => (Func<T, Async.Task<T>>)Delegate.CreateDelegate(typeof(Func<T, Async.Task<T>>), info)
|
};
|
||||||
});
|
|
||||||
|
|
||||||
if (func != null) {
|
if (func != null) {
|
||||||
_logTracer.Info($"processing state update: {typeof(T)} - PartitionKey {_partitionKeyGetter?.Value()} {_rowKeyGetter?.Value()} - %s");
|
_logTracer.Info($"processing state update: {typeof(T)} - PartitionKey {_partitionKeyGetter?.Value()} {_rowKeyGetter?.Value()} - %s");
|
||||||
|
@ -11,12 +11,15 @@ namespace ApiService.OneFuzzLib.Orm {
|
|||||||
public static string PartitionKey(string partitionKey)
|
public static string PartitionKey(string partitionKey)
|
||||||
=> TableClient.CreateQueryFilter($"PartitionKey eq {partitionKey}");
|
=> TableClient.CreateQueryFilter($"PartitionKey eq {partitionKey}");
|
||||||
|
|
||||||
public static string PartitionKey(Guid partitionKey)
|
|
||||||
=> TableClient.CreateQueryFilter($"PartitionKey eq {partitionKey}");
|
|
||||||
|
|
||||||
public static string RowKey(string rowKey)
|
public static string RowKey(string rowKey)
|
||||||
=> TableClient.CreateQueryFilter($"RowKey eq {rowKey}");
|
=> TableClient.CreateQueryFilter($"RowKey eq {rowKey}");
|
||||||
|
|
||||||
|
public static string PartitionKeys(IEnumerable<string> partitionKeys)
|
||||||
|
=> Or(partitionKeys.Select(PartitionKey));
|
||||||
|
|
||||||
|
public static string RowKeys(IEnumerable<string> rowKeys)
|
||||||
|
=> Or(rowKeys.Select(RowKey));
|
||||||
|
|
||||||
public static string SingleEntity(string partitionKey, string rowKey)
|
public static string SingleEntity(string partitionKey, string rowKey)
|
||||||
=> TableClient.CreateQueryFilter($"(PartitionKey eq {partitionKey}) and (RowKey eq {rowKey})");
|
=> TableClient.CreateQueryFilter($"(PartitionKey eq {partitionKey}) and (RowKey eq {rowKey})");
|
||||||
|
|
||||||
|
@ -25,21 +25,31 @@ sealed class AzureStorage : IStorage {
|
|||||||
return new AzureStorage(accountName, accountKey);
|
return new AzureStorage(accountName, accountKey);
|
||||||
}
|
}
|
||||||
|
|
||||||
public string? AccountName { get; }
|
public string AccountName { get; }
|
||||||
public string? AccountKey { get; }
|
public string AccountKey { get; }
|
||||||
|
|
||||||
public AzureStorage(string? accountName, string? accountKey) {
|
public AzureStorage(string accountName, string accountKey) {
|
||||||
AccountName = accountName;
|
AccountName = accountName;
|
||||||
AccountKey = accountKey;
|
AccountKey = accountKey;
|
||||||
}
|
}
|
||||||
|
|
||||||
public Task<(string?, string?)> GetStorageAccountNameAndKey(string accountId)
|
public IEnumerable<string> CorpusAccounts() {
|
||||||
=> Async.Task.FromResult((AccountName, AccountKey));
|
throw new System.NotImplementedException();
|
||||||
|
}
|
||||||
|
|
||||||
public IEnumerable<string> GetAccounts(StorageType storageType) {
|
public IEnumerable<string> GetAccounts(StorageType storageType) {
|
||||||
if (AccountName != null) {
|
yield return AccountName;
|
||||||
yield return AccountName;
|
}
|
||||||
}
|
|
||||||
|
public string GetPrimaryAccount(StorageType storageType) {
|
||||||
|
throw new System.NotImplementedException();
|
||||||
|
}
|
||||||
|
|
||||||
|
public Task<(string, string)> GetStorageAccountNameAndKey(string accountId)
|
||||||
|
=> Async.Task.FromResult((AccountName, AccountKey));
|
||||||
|
|
||||||
|
public Task<string?> GetStorageAccountNameKeyByName(string accountName) {
|
||||||
|
return Async.Task.FromResult(AccountName)!;
|
||||||
}
|
}
|
||||||
|
|
||||||
public Uri GetTableEndpoint(string accountId)
|
public Uri GetTableEndpoint(string accountId)
|
||||||
@ -51,15 +61,4 @@ sealed class AzureStorage : IStorage {
|
|||||||
public Uri GetBlobEndpoint(string accountId)
|
public Uri GetBlobEndpoint(string accountId)
|
||||||
=> new($"https://{AccountName}.blob.core.windows.net/");
|
=> new($"https://{AccountName}.blob.core.windows.net/");
|
||||||
|
|
||||||
public IEnumerable<string> CorpusAccounts() {
|
|
||||||
throw new System.NotImplementedException();
|
|
||||||
}
|
|
||||||
|
|
||||||
public string GetPrimaryAccount(StorageType storageType) {
|
|
||||||
throw new System.NotImplementedException();
|
|
||||||
}
|
|
||||||
|
|
||||||
public Task<string?> GetStorageAccountNameAndKeyByName(string accountName) {
|
|
||||||
throw new System.NotImplementedException();
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
@ -22,15 +22,11 @@ sealed class AzuriteStorage : IStorage {
|
|||||||
const string AccountName = "devstoreaccount1";
|
const string AccountName = "devstoreaccount1";
|
||||||
const string AccountKey = "Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw==";
|
const string AccountKey = "Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw==";
|
||||||
|
|
||||||
public Task<(string?, string?)> GetStorageAccountNameAndKey(string _accountId)
|
public Task<(string, string)> GetStorageAccountNameAndKey(string accountId)
|
||||||
=> Async.Task.FromResult<(string?, string?)>((AccountName, AccountKey));
|
=> Async.Task.FromResult((AccountName, AccountKey));
|
||||||
|
|
||||||
public IEnumerable<string> GetAccounts(StorageType storageType) {
|
public Task<string?> GetStorageAccountNameKeyByName(string accountName) {
|
||||||
yield return AccountName;
|
return Async.Task.FromResult(AccountName)!;
|
||||||
}
|
|
||||||
|
|
||||||
public Task<string?> GetStorageAccountNameAndKeyByName(string accountName) {
|
|
||||||
throw new System.NotImplementedException();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public IEnumerable<string> CorpusAccounts() {
|
public IEnumerable<string> CorpusAccounts() {
|
||||||
@ -40,4 +36,8 @@ sealed class AzuriteStorage : IStorage {
|
|||||||
public string GetPrimaryAccount(StorageType storageType) {
|
public string GetPrimaryAccount(StorageType storageType) {
|
||||||
throw new System.NotImplementedException();
|
throw new System.NotImplementedException();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public IEnumerable<string> GetAccounts(StorageType storageType) {
|
||||||
|
yield return AccountName;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
@ -35,6 +35,8 @@ public class RequestsTests {
|
|||||||
var deserialized = (T?)_serializer.Deserialize(stream, typeof(T), CancellationToken.None);
|
var deserialized = (T?)_serializer.Deserialize(stream, typeof(T), CancellationToken.None);
|
||||||
var reserialized = _serializer.Serialize(deserialized);
|
var reserialized = _serializer.Serialize(deserialized);
|
||||||
var result = Encoding.UTF8.GetString(reserialized);
|
var result = Encoding.UTF8.GetString(reserialized);
|
||||||
|
result = result.Replace(System.Environment.NewLine, "\n");
|
||||||
|
json = json.Replace(System.Environment.NewLine, "\n");
|
||||||
Assert.Equal(json, result);
|
Assert.Equal(json, result);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
133
src/ApiService/Tests/SchedulerTests.cs
Normal file
133
src/ApiService/Tests/SchedulerTests.cs
Normal file
@ -0,0 +1,133 @@
|
|||||||
|
using System;
|
||||||
|
using System.Collections.Generic;
|
||||||
|
using System.Linq;
|
||||||
|
using Microsoft.OneFuzz.Service;
|
||||||
|
using Xunit;
|
||||||
|
|
||||||
|
namespace Tests;
|
||||||
|
|
||||||
|
public class SchedulerTests {
|
||||||
|
|
||||||
|
IEnumerable<Task> BuildTasks(int size) {
|
||||||
|
return Enumerable.Range(0, size).Select(i =>
|
||||||
|
new Task(
|
||||||
|
Guid.Empty,
|
||||||
|
Guid.NewGuid(),
|
||||||
|
TaskState.Init,
|
||||||
|
Os.Linux,
|
||||||
|
new TaskConfig(
|
||||||
|
Guid.Empty,
|
||||||
|
null,
|
||||||
|
new TaskDetails(
|
||||||
|
Type: TaskType.LibfuzzerFuzz,
|
||||||
|
Duration: 1,
|
||||||
|
TargetExe: "fuzz.exe",
|
||||||
|
TargetEnv: new Dictionary<string, string>(),
|
||||||
|
TargetOptions: new List<string>()),
|
||||||
|
Pool: new TaskPool(1, PoolName.Parse("pool")),
|
||||||
|
Containers: new List<TaskContainers> { new TaskContainers(ContainerType.Setup, new Container("setup")) },
|
||||||
|
Colocate: true
|
||||||
|
|
||||||
|
),
|
||||||
|
null,
|
||||||
|
null,
|
||||||
|
null,
|
||||||
|
null,
|
||||||
|
null)
|
||||||
|
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
[Fact]
|
||||||
|
public void TestAllColocate() {
|
||||||
|
// all tasks should land in one bucket
|
||||||
|
|
||||||
|
var tasks = BuildTasks(10).Select(task => task with { Config = task.Config with { Colocate = true } }
|
||||||
|
).ToList();
|
||||||
|
|
||||||
|
var buckets = Scheduler.BucketTasks(tasks);
|
||||||
|
var bucket = Assert.Single(buckets);
|
||||||
|
Assert.True(10 >= bucket.Count());
|
||||||
|
CheckBuckets(buckets, tasks, 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
[Fact]
|
||||||
|
public void TestPartialColocate() {
|
||||||
|
// 2 tasks should land on their own, the rest should be colocated into a
|
||||||
|
// single bucket.
|
||||||
|
|
||||||
|
var tasks = BuildTasks(10).Select((task, i) => {
|
||||||
|
return i switch {
|
||||||
|
0 => task with { Config = task.Config with { Colocate = null } },
|
||||||
|
1 => task with { Config = task.Config with { Colocate = false } },
|
||||||
|
_ => task
|
||||||
|
};
|
||||||
|
}).ToList();
|
||||||
|
var buckets = Scheduler.BucketTasks(tasks);
|
||||||
|
var lengths = buckets.Select(b => b.Count()).OrderBy(x => x);
|
||||||
|
Assert.Equal(new[] { 1, 1, 8 }, lengths);
|
||||||
|
CheckBuckets(buckets, tasks, 3);
|
||||||
|
}
|
||||||
|
|
||||||
|
[Fact]
|
||||||
|
public void TestAlluniqueJob() {
|
||||||
|
// everything has a unique job_id
|
||||||
|
var tasks = BuildTasks(10).Select(task => {
|
||||||
|
var jobId = Guid.NewGuid();
|
||||||
|
return task with { JobId = jobId, Config = task.Config with { JobId = jobId } };
|
||||||
|
}).ToList();
|
||||||
|
|
||||||
|
var buckets = Scheduler.BucketTasks(tasks);
|
||||||
|
foreach (var bucket in buckets) {
|
||||||
|
Assert.True(1 >= bucket.Count());
|
||||||
|
}
|
||||||
|
CheckBuckets(buckets, tasks, 10);
|
||||||
|
}
|
||||||
|
|
||||||
|
[Fact]
|
||||||
|
public void TestMultipleJobBuckets() {
|
||||||
|
// at most 3 tasks per bucket, by job_id
|
||||||
|
var tasks = BuildTasks(10).Chunk(3).SelectMany(taskChunk => {
|
||||||
|
var jobId = Guid.NewGuid();
|
||||||
|
return taskChunk.Select(task => task with { JobId = jobId, Config = task.Config with { JobId = jobId } });
|
||||||
|
}).ToList();
|
||||||
|
|
||||||
|
var buckets = Scheduler.BucketTasks(tasks);
|
||||||
|
foreach (var bucket in buckets) {
|
||||||
|
Assert.True(3 >= bucket.Count());
|
||||||
|
}
|
||||||
|
CheckBuckets(buckets, tasks, 4);
|
||||||
|
}
|
||||||
|
|
||||||
|
[Fact]
|
||||||
|
public void TestManyBuckets() {
|
||||||
|
var jobId = Guid.Parse("00000000-0000-0000-0000-000000000001");
|
||||||
|
var tasks = BuildTasks(100).Select((task, i) => {
|
||||||
|
var containers = new List<TaskContainers>(task.Config.Containers!);
|
||||||
|
if (i % 4 == 0) {
|
||||||
|
containers[0] = containers[0] with { Name = new Container("setup2") };
|
||||||
|
}
|
||||||
|
return task with {
|
||||||
|
JobId = i % 2 == 0 ? jobId : task.JobId,
|
||||||
|
Os = i % 3 == 0 ? Os.Windows : task.Os,
|
||||||
|
Config = task.Config with {
|
||||||
|
JobId = i % 2 == 0 ? jobId : task.Config.JobId,
|
||||||
|
Containers = containers,
|
||||||
|
Pool = i % 5 == 0 ? task.Config.Pool! with { PoolName = PoolName.Parse("alternate-pool") } : task.Config.Pool
|
||||||
|
}
|
||||||
|
};
|
||||||
|
}).ToList();
|
||||||
|
|
||||||
|
var buckets = Scheduler.BucketTasks(tasks);
|
||||||
|
|
||||||
|
CheckBuckets(buckets, tasks, 12);
|
||||||
|
}
|
||||||
|
|
||||||
|
void CheckBuckets(ILookup<Scheduler.BucketId, Task> buckets, List<Task> tasks, int bucketCount) {
|
||||||
|
Assert.Equal(buckets.Count, bucketCount);
|
||||||
|
|
||||||
|
foreach (var task in tasks) {
|
||||||
|
Assert.Single(buckets, b => b.Contains(task));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
@ -8,12 +8,6 @@ public class ValidatedStringTests {
|
|||||||
|
|
||||||
record ThingContainingPoolName(PoolName PoolName);
|
record ThingContainingPoolName(PoolName PoolName);
|
||||||
|
|
||||||
[Fact]
|
|
||||||
public void PoolNameValidatesOnDeserialization() {
|
|
||||||
var ex = Assert.Throws<JsonException>(() => JsonSerializer.Deserialize<ThingContainingPoolName>("{ \"PoolName\": \"is-not!-a-pool\" }"));
|
|
||||||
Assert.Equal("unable to parse input as a PoolName", ex.Message);
|
|
||||||
}
|
|
||||||
|
|
||||||
[Fact]
|
[Fact]
|
||||||
public void PoolNameDeserializesFromString() {
|
public void PoolNameDeserializesFromString() {
|
||||||
var result = JsonSerializer.Deserialize<ThingContainingPoolName>("{ \"PoolName\": \"is-a-pool\" }");
|
var result = JsonSerializer.Deserialize<ThingContainingPoolName>("{ \"PoolName\": \"is-a-pool\" }");
|
||||||
|
Reference in New Issue
Block a user