Mirror of https://github.com/microsoft/onefuzz.git, synced 2025-06-12 01:58:18 +00:00
Migrate timer_task part 3 (#2066)
* Adding unit tests for the scheduler
* formatting
* fix unit test
* bug fixes
* porting a couple more missing functions
* more bug fixes
* fixing queries and serialization
* fix sas url generation
* fix condition
* [testing] enabling timer_tasks for testing
* Update src/ApiService/Tests/SchedulerTests.cs (Co-authored-by: George Pollard <porges@porg.es>)
* address PR comments
* build fix
* address PR comment
* removing renamed function
* resolve merge
* Update src/ApiService/Tests/Integration/AzuriteStorage.cs (Co-authored-by: George Pollard <porges@porg.es>)
* Added verification of the state transition functions; disabled validation on PoolName
* Update src/deployment/deploy.py (Co-authored-by: George Pollard <porges@porg.es>)

Co-authored-by: George Pollard <gpollard@microsoft.com>
@@ -184,7 +184,8 @@ public record TaskDetails(
    bool? PreserveExistingOutputs = null,
    List<string>? ReportList = null,
    int? MinimizedStackDepth = null,
    string? CoverageFilter = null);
    string? CoverageFilter = null
    );

public record TaskVm(
    Region Region,
@@ -214,7 +215,8 @@ public record TaskConfig(
    List<TaskContainers>? Containers = null,
    Dictionary<string, string>? Tags = null,
    List<TaskDebugFlag>? Debug = null,
    bool? Colocate = null);
    bool? Colocate = null
    );

public record TaskEventSummary(
    DateTimeOffset? Timestamp,
@@ -590,9 +592,29 @@ public record WorkUnit(
    Guid JobId,
    Guid TaskId,
    TaskType TaskType,
    TaskUnitConfig Config

    // JSON-serialized `TaskUnitConfig`.
    [property: JsonConverter(typeof(TaskUnitConfigConverter))] TaskUnitConfig Config
);

public class TaskUnitConfigConverter : JsonConverter<TaskUnitConfig> {
    public override TaskUnitConfig? Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) {
        var taskUnitString = reader.GetString();
        if (taskUnitString == null) {
            return null;
        }
        return JsonSerializer.Deserialize<TaskUnitConfig>(taskUnitString, options);
    }

    public override void Write(Utf8JsonWriter writer, TaskUnitConfig value, JsonSerializerOptions options) {
        var v = JsonSerializer.Serialize(value, new JsonSerializerOptions(options) {
            DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull
        });

        writer.WriteStringValue(v);
    }
}
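The converter added above serializes the task configuration as an embedded JSON string (with nulls dropped) rather than a nested object. Below is a minimal round-trip sketch of the same pattern; `MiniConfig`, `Envelope`, and `MiniConfigConverter` are made-up stand-ins for the real `TaskUnitConfig`/`WorkUnit` types, used only for illustration.

```csharp
using System;
using System.Text.Json;
using System.Text.Json.Serialization;

// Hypothetical stand-ins for TaskUnitConfig / WorkUnit, for illustration only.
public record MiniConfig(Guid JobId, string? Extra = null);
public record Envelope([property: JsonConverter(typeof(MiniConfigConverter))] MiniConfig Config);

public class MiniConfigConverter : JsonConverter<MiniConfig> {
    public override MiniConfig? Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) {
        // The property arrives as a string containing JSON; unwrap it.
        var s = reader.GetString();
        return s == null ? null : JsonSerializer.Deserialize<MiniConfig>(s, options);
    }

    public override void Write(Utf8JsonWriter writer, MiniConfig value, JsonSerializerOptions options) {
        // Serialize the config, drop nulls, and embed the result as a single string value.
        var s = JsonSerializer.Serialize(value, new JsonSerializerOptions(options) {
            DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull
        });
        writer.WriteStringValue(s);
    }
}

public static class ConverterDemo {
    public static void Main() {
        var json = JsonSerializer.Serialize(new Envelope(new MiniConfig(Guid.Empty)));
        Console.WriteLine(json);                 // Config is a string whose content is escaped JSON
        var back = JsonSerializer.Deserialize<Envelope>(json);
        Console.WriteLine(back!.Config.JobId);   // 00000000-0000-0000-0000-000000000000
    }
}
```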
public record VmDefinition(
    Compare Compare,
    int Value
@@ -625,14 +647,36 @@ public record ContainerDefinition(

// TODO: service shouldn't pass SyncedDir, but just the url and let the agent
// come up with paths
public record SyncedDir(string Path, Uri url);
public record SyncedDir(string Path, Uri Url);


[JsonConverter(typeof(ContainerDefConverter))]
public interface IContainerDef { }
public record SingleContainer(SyncedDir SyncedDir) : IContainerDef;
public record MultipleContainer(List<SyncedDir> SyncedDirs) : IContainerDef;


public class ContainerDefConverter : JsonConverter<IContainerDef> {
    public override IContainerDef? Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) {
        throw new NotImplementedException();
    }

    public override void Write(Utf8JsonWriter writer, IContainerDef value, JsonSerializerOptions options) {
        switch (value) {
            case SingleContainer container:
                JsonSerializer.Serialize(writer, container.SyncedDir, options);
                break;
            case MultipleContainer { SyncedDirs: var syncedDirs }:
                JsonSerializer.Serialize(writer, syncedDirs, options);
                break;
            default:
                throw new NotImplementedException();
        }
    }
}
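Given the converter above, a field typed as `IContainerDef` serializes to either a single object or an array of objects, presumably matching the format the agent already consumes. A small sketch, assuming the types above are in scope, default Pascal-case property naming, and illustrative container names/URLs:

```csharp
using System;
using System.Collections.Generic;
using System.Text.Json;

var single = new SingleContainer(new SyncedDir("setup", new Uri("https://example.com/setup")));
var many = new MultipleContainer(new List<SyncedDir> {
    new("inputs", new Uri("https://example.com/inputs")),
    new("crashes", new Uri("https://example.com/crashes")),
});

// Serialized as IContainerDef so the [JsonConverter] on the interface is applied:
Console.WriteLine(JsonSerializer.Serialize<IContainerDef>(single)); // an object: {"Path":"setup","Url":"..."}
Console.WriteLine(JsonSerializer.Serialize<IContainerDef>(many));   // an array of such objects
```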
public record TaskUnitConfig(
    Guid InstanceId,
    Guid JobId,
@@ -686,7 +730,7 @@ public record TaskUnitConfig(
    public IContainerDef? Tools { get; set; }
    public IContainerDef? UniqueInputs { get; set; }
    public IContainerDef? UniqueReports { get; set; }
    public IContainerDef? RegressionReport { get; set; }
    public IContainerDef? RegressionReports { get; set; }

}
@@ -1,6 +1,4 @@

using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
using System.Diagnostics.CodeAnalysis;
using System.Text.Json;
using System.Text.Json.Serialization;
using System.Text.RegularExpressions;
@@ -51,7 +49,7 @@ public abstract class ValidatedStringConverter<T> : JsonConverter<T> where T : V
[JsonConverter(typeof(Converter))]
public record PoolName : ValidatedString {
    private PoolName(string value) : base(value) {
        Debug.Assert(Check.IsAlnumDash(value));
        // Debug.Assert(Check.IsAlnumDash(value));
    }

    public static PoolName Parse(string input) {
@@ -63,10 +61,14 @@ public record PoolName : ValidatedString {
    }

    public static bool TryParse(string input, [NotNullWhen(returnValue: true)] out PoolName? result) {
        if (!Check.IsAlnumDash(input)) {
            result = default;
            return false;
        }

        // bypassing the validation because this code has a stricter validation than the python equivalent
        // see (issue #2080)

        // if (!Check.IsAlnumDash(input)) {
        //     result = default;
        //     return false;
        // }

        result = new PoolName(input);
        return true;
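For context on the disabled check: with `Check.IsAlnumDash` bypassed, `TryParse` now accepts any input, which is presumably why the `PoolNameValidatesOnDeserialization` test is deleted near the bottom of this diff. A one-line illustration (the pool name shown is the value used by the removed test):

```csharp
// Previously false; with the validation commented out (see issue #2080) this is now true.
var accepted = PoolName.TryParse("is-not!-a-pool", out var pool);
```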
@@ -21,8 +21,8 @@ public class TimerTasks {
        _scheduler = scheduler;
    }

    //[Function("TimerTasks")]
    public async Async.Task Run([TimerTrigger("1.00:00:00")] TimerInfo myTimer) {
    [Function("TimerTasks")]
    public async Async.Task Run([TimerTrigger("00:00:15")] TimerInfo myTimer) {
        var expriredTasks = _taskOperations.SearchExpired();
        await foreach (var task in expriredTasks) {
            _logger.Info($"stopping expired task. job_id:{task.JobId} task_id:{task.TaskId}");
@@ -5,7 +5,6 @@ namespace Microsoft.OneFuzz.Service;


public interface IConfig {
    string GetSetupContainer(TaskConfig config);
    Async.Task<TaskUnitConfig> BuildTaskConfig(Job job, Task task);
}

@@ -89,9 +88,13 @@ public class Config : IConfig {

        await foreach (var data in containersByType) {

            if (!data.containers.Any()) {
                continue;
            }

            IContainerDef def = data.countainerDef switch {
                ContainerDefinition { Compare: Compare.Equal, Value: 1 } or
                ContainerDefinition { Compare: Compare.AtMost, Value: 1 } => new SingleContainer(data.containers[0]),
                ContainerDefinition { Compare: Compare.AtMost, Value: 1 } when data.containers.Count == 1 => new SingleContainer(data.containers[0]),
                _ => new MultipleContainer(data.containers)
            };

@@ -126,6 +129,9 @@ public class Config : IConfig {
                case ContainerType.UniqueReports:
                    config.UniqueReports = def;
                    break;
                case ContainerType.RegressionReports:
                    config.RegressionReports = def;
                    break;
            }
        }

@@ -249,16 +255,4 @@ public class Config : IConfig {

        return config;
    }


    public string GetSetupContainer(TaskConfig config) {

        foreach (var container in config.Containers ?? throw new Exception("Missing containers")) {
            if (container.Type == ContainerType.Setup) {
                return container.Name.ContainerName;
            }
        }

        throw new Exception($"task missing setup container: task_type = {config.Task.Type}");
    }
}
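This switch is the "fix condition" item from the commit message: an at-most-one container definition now collapses to a `SingleContainer` only when exactly one container was actually supplied, and everything else becomes a `MultipleContainer`. A self-contained sketch of the corrected pattern follows; the types are renamed (`Cardinality`, `Rule`) and hypothetical, not the real OneFuzz definitions.

```csharp
using System;
using System.Collections.Generic;

enum Cardinality { Equal, AtMost }
record Rule(Cardinality Compare, int Value);

static class BucketRuleDemo {
    // Mirrors the switch above: single only when the rule allows one AND exactly one is present.
    static object Pick(Rule rule, List<string> containers) =>
        rule switch {
            { Compare: Cardinality.Equal, Value: 1 } or
            { Compare: Cardinality.AtMost, Value: 1 } when containers.Count == 1 => (object)containers[0],
            _ => containers
        };

    static void Main() {
        Console.WriteLine(Pick(new(Cardinality.AtMost, 1), new() { "setup" }));                  // "setup" (single)
        Console.WriteLine(Pick(new(Cardinality.AtMost, 1), new() { "a", "b" }) is List<string>); // True (multiple)
    }
}
```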
@@ -164,14 +164,15 @@ public class Containers : IContainers {
        if (uri.Query.Contains("sig")) {
            return uri;
        }

        var accountName = uri.Host.Split('.')[0];
        var (_, accountKey) = await _storage.GetStorageAccountNameAndKey(accountName);
        var blobUriBuilder = new BlobUriBuilder(uri);
        var accountKey = await _storage.GetStorageAccountNameKeyByName(blobUriBuilder.AccountName);
        var sasBuilder = new BlobSasBuilder(
            BlobContainerSasPermissions.Read | BlobContainerSasPermissions.Write | BlobContainerSasPermissions.Delete | BlobContainerSasPermissions.List,
            DateTimeOffset.UtcNow + TimeSpan.FromHours(1));
            DateTimeOffset.UtcNow + TimeSpan.FromHours(1)) {
            BlobContainerName = blobUriBuilder.BlobContainerName,
        };

        var sas = sasBuilder.ToSasQueryParameters(new StorageSharedKeyCredential(accountName, accountKey)).ToString();
        var sas = sasBuilder.ToSasQueryParameters(new StorageSharedKeyCredential(blobUriBuilder.AccountName, accountKey)).ToString();
        return new UriBuilder(uri) {
            Query = sas
        }.Uri;
@@ -179,13 +180,10 @@ public class Containers : IContainers {

    public async Async.Task<Uri> GetContainerSasUrl(Container container, StorageType storageType, BlobContainerSasPermissions permissions, TimeSpan? duration = null) {
        var client = await FindContainer(container, storageType) ?? throw new Exception($"unable to find container: {container.ContainerName} - {storageType}");
        var (accountName, accountKey) = await _storage.GetStorageAccountNameAndKey(client.AccountName);

        var (startTime, endTime) = SasTimeWindow(duration ?? TimeSpan.FromDays(30));

        var sasBuilder = new BlobSasBuilder(permissions, endTime) {
            StartsOn = startTime,
            BlobContainerName = container.ContainerName,
            BlobContainerName = container.ContainerName
        };

        var sasUrl = client.GenerateSasUri(sasBuilder);
@ -6,12 +6,13 @@ public interface IJobOperations : IStatefulOrm<Job, JobState> {
|
||||
Async.Task<Job?> Get(Guid jobId);
|
||||
Async.Task OnStart(Job job);
|
||||
IAsyncEnumerable<Job> SearchExpired();
|
||||
Async.Task Stopping(Job job);
|
||||
Async.Task<Job> Stopping(Job job);
|
||||
IAsyncEnumerable<Job> SearchState(IEnumerable<JobState> states);
|
||||
Async.Task StopNeverStartedJobs();
|
||||
Async.Task StopIfAllDone(Job job);
|
||||
}
|
||||
|
||||
public class JobOperations : StatefulOrm<Job, JobState>, IJobOperations {
|
||||
public class JobOperations : StatefulOrm<Job, JobState, JobOperations>, IJobOperations {
|
||||
private static TimeSpan JOB_NEVER_STARTED_DURATION = TimeSpan.FromDays(30);
|
||||
|
||||
public JobOperations(ILogTracer logTracer, IOnefuzzContext context) : base(logTracer, context) {
|
||||
@ -36,6 +37,17 @@ public class JobOperations : StatefulOrm<Job, JobState>, IJobOperations {
|
||||
return QueryAsync(filter: query);
|
||||
}
|
||||
|
||||
public async Async.Task StopIfAllDone(Job job) {
|
||||
var anyNotStoppedJobs = await _context.TaskOperations.GetByJobId(job.JobId).AnyAsync(task => task.State != TaskState.Stopped);
|
||||
|
||||
if (anyNotStoppedJobs) {
|
||||
return;
|
||||
}
|
||||
|
||||
_logTracer.Info($"stopping job as all tasks are stopped: {job.JobId}");
|
||||
await Stopping(job);
|
||||
}
|
||||
|
||||
public async Async.Task StopNeverStartedJobs() {
|
||||
// # Note, the "not(end_time...)" with end_time set long before the use of
|
||||
// # OneFuzz enables identifying those without end_time being set.
|
||||
@ -58,7 +70,7 @@ public class JobOperations : StatefulOrm<Job, JobState>, IJobOperations {
|
||||
}
|
||||
}
|
||||
|
||||
public async Async.Task Stopping(Job job) {
|
||||
public async Async.Task<Job> Stopping(Job job) {
|
||||
job = job with { State = JobState.Stopping };
|
||||
var tasks = await _context.TaskOperations.QueryAsync(filter: $"job_id eq '{job.JobId}'").ToListAsync();
|
||||
var taskNotStopped = tasks.ToLookup(task => task.State != TaskState.Stopped);
|
||||
@ -76,7 +88,12 @@ public class JobOperations : StatefulOrm<Job, JobState>, IJobOperations {
|
||||
await _context.Events.SendEvent(new EventJobStopped(job.JobId, job.Config, job.UserInfo, taskInfo));
|
||||
}
|
||||
|
||||
await Replace(job);
|
||||
var result = await Replace(job);
|
||||
|
||||
if (result.IsOk) {
|
||||
return job;
|
||||
} else {
|
||||
throw new Exception($"Failed to save job {job.JobId} : {result.ErrorV}");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -43,6 +43,8 @@ public interface INodeOperations : IStatefulOrm<Node, NodeState> {
|
||||
Async.Task MarkTasksStoppedEarly(Node node, Error? error = null);
|
||||
static TimeSpan NODE_EXPIRATION_TIME = TimeSpan.FromHours(1.0);
|
||||
static TimeSpan NODE_REIMAGE_TIME = TimeSpan.FromDays(6.0);
|
||||
|
||||
Async.Task StopTask(Guid task_id);
|
||||
}
|
||||
|
||||
|
||||
@ -51,7 +53,7 @@ public interface INodeOperations : IStatefulOrm<Node, NodeState> {
|
||||
/// Enabling autoscaling for the scalesets based on the pool work queues.
|
||||
/// https://docs.microsoft.com/en-us/azure/azure-monitor/platform/autoscale-common-metrics#commonly-used-storage-metrics
|
||||
|
||||
public class NodeOperations : StatefulOrm<Node, NodeState>, INodeOperations {
|
||||
public class NodeOperations : StatefulOrm<Node, NodeState, NodeOperations>, INodeOperations {
|
||||
|
||||
|
||||
public NodeOperations(
|
||||
@ -269,7 +271,7 @@ public class NodeOperations : StatefulOrm<Node, NodeState>, INodeOperations {
|
||||
|
||||
|
||||
public async Async.Task<Node?> GetByMachineId(Guid machineId) {
|
||||
var data = QueryAsync(filter: $"RowKey eq '{machineId}'");
|
||||
var data = QueryAsync(filter: Query.RowKey(machineId.ToString()));
|
||||
|
||||
return await data.FirstOrDefaultAsync();
|
||||
}
|
||||
@ -387,18 +389,49 @@ public class NodeOperations : StatefulOrm<Node, NodeState>, INodeOperations {
|
||||
await _context.Events.SendEvent(new EventNodeDeleted(node.MachineId, node.ScalesetId, node.PoolName, node.State));
|
||||
}
|
||||
|
||||
public async Async.Task StopTask(Guid task_id) {
|
||||
// For now, this just re-images the node. Eventually, this
|
||||
// should send a message to the node to let the agent shut down
|
||||
// gracefully
|
||||
|
||||
var nodes = _context.NodeTasksOperations.GetNodesByTaskId(task_id);
|
||||
|
||||
await foreach (var node in nodes) {
|
||||
await _context.NodeMessageOperations.SendMessage(node.MachineId, new NodeCommand(StopTask: new StopTaskNodeCommand(task_id)));
|
||||
|
||||
if (!(await StopIfComplete(node))) {
|
||||
_logTracer.Info($"nodes: stopped task on node, but not reimaging due to other tasks: task_id:{task_id} machine_id:{node.MachineId}");
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
/// returns True on stopping the node and False if this doesn't stop the node
|
||||
private async Task<bool> StopIfComplete(Node node, bool done = false) {
|
||||
var nodeTaskIds = await _context.NodeTasksOperations.GetByMachineId(node.MachineId).Select(nt => nt.TaskId).ToArrayAsync();
|
||||
var tasks = _context.TaskOperations.GetByTaskIds(nodeTaskIds);
|
||||
await foreach (var task in tasks) {
|
||||
if (!TaskStateHelper.ShuttingDown(task.State)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
_logTracer.Info($"node: stopping busy node with all tasks complete: {node.MachineId}");
|
||||
|
||||
await Stop(node, done: done);
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
public interface INodeTasksOperations : IStatefulOrm<NodeTasks, NodeTaskState> {
|
||||
IAsyncEnumerable<Node> GetNodesByTaskId(Guid taskId, INodeOperations nodeOps);
|
||||
IAsyncEnumerable<Node> GetNodesByTaskId(Guid taskId);
|
||||
IAsyncEnumerable<NodeAssignment> GetNodeAssignments(Guid taskId, INodeOperations nodeOps);
|
||||
IAsyncEnumerable<NodeTasks> GetByMachineId(Guid machineId);
|
||||
IAsyncEnumerable<NodeTasks> GetByTaskId(Guid taskId);
|
||||
Async.Task ClearByMachineId(Guid machineId);
|
||||
}
|
||||
|
||||
public class NodeTasksOperations : StatefulOrm<NodeTasks, NodeTaskState>, INodeTasksOperations {
|
||||
public class NodeTasksOperations : StatefulOrm<NodeTasks, NodeTaskState, NodeTasksOperations>, INodeTasksOperations {
|
||||
|
||||
ILogTracer _log;
|
||||
|
||||
@ -408,18 +441,20 @@ public class NodeTasksOperations : StatefulOrm<NodeTasks, NodeTaskState>, INodeT
|
||||
}
|
||||
|
||||
//TODO: suggest by Cheick: this can probably be optimize by query all NodesTasks then query the all machine in single request
|
||||
public async IAsyncEnumerable<Node> GetNodesByTaskId(Guid taskId, INodeOperations nodeOps) {
|
||||
await foreach (var entry in QueryAsync($"task_id eq '{taskId}'")) {
|
||||
var node = await nodeOps.GetByMachineId(entry.MachineId);
|
||||
|
||||
public async IAsyncEnumerable<Node> GetNodesByTaskId(Guid taskId) {
|
||||
await foreach (var entry in QueryAsync(Query.RowKey(taskId.ToString()))) {
|
||||
var node = await _context.NodeOperations.GetByMachineId(entry.MachineId);
|
||||
if (node is not null) {
|
||||
yield return node;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public async IAsyncEnumerable<NodeAssignment> GetNodeAssignments(Guid taskId, INodeOperations nodeOps) {
|
||||
|
||||
await foreach (var entry in QueryAsync($"task_id eq '{taskId}'")) {
|
||||
var node = await nodeOps.GetByMachineId(entry.MachineId);
|
||||
await foreach (var entry in QueryAsync(Query.RowKey(taskId.ToString()))) {
|
||||
var node = await _context.NodeOperations.GetByMachineId(entry.MachineId);
|
||||
if (node is not null) {
|
||||
var nodeAssignment = new NodeAssignment(node.MachineId, node.ScalesetId, entry.State);
|
||||
yield return nodeAssignment;
|
||||
@ -428,11 +463,11 @@ public class NodeTasksOperations : StatefulOrm<NodeTasks, NodeTaskState>, INodeT
|
||||
}
|
||||
|
||||
public IAsyncEnumerable<NodeTasks> GetByMachineId(Guid machineId) {
|
||||
return QueryAsync($"machine_id eq '{machineId}'");
|
||||
return QueryAsync(Query.PartitionKey(machineId.ToString()));
|
||||
}
|
||||
|
||||
public IAsyncEnumerable<NodeTasks> GetByTaskId(Guid taskId) {
|
||||
return QueryAsync($"task_id eq '{taskId}'");
|
||||
return QueryAsync(Query.RowKey(taskId.ToString()));
|
||||
}
|
||||
|
||||
public async Async.Task ClearByMachineId(Guid machineId) {
|
||||
@ -472,7 +507,7 @@ public class NodeMessageOperations : Orm<NodeMessage>, INodeMessageOperations {
|
||||
}
|
||||
|
||||
public IAsyncEnumerable<NodeMessage> GetMessage(Guid machineId)
|
||||
=> QueryAsync(Query.PartitionKey(machineId));
|
||||
=> QueryAsync(Query.PartitionKey(machineId.ToString()));
|
||||
|
||||
public async Async.Task ClearMessages(Guid machineId) {
|
||||
_logTracer.Info($"clearing messages for node {machineId}");
|
||||
|
@ -9,7 +9,7 @@ public interface IPoolOperations {
|
||||
IAsyncEnumerable<Pool> GetByClientId(Guid clientId);
|
||||
}
|
||||
|
||||
public class PoolOperations : StatefulOrm<Pool, PoolState>, IPoolOperations {
|
||||
public class PoolOperations : StatefulOrm<Pool, PoolState, PoolOperations>, IPoolOperations {
|
||||
|
||||
public PoolOperations(ILogTracer log, IOnefuzzContext context)
|
||||
: base(log, context) {
|
||||
|
@ -13,7 +13,7 @@ public interface IProxyOperations : IStatefulOrm<Proxy, VmState> {
|
||||
bool IsOutdated(Proxy proxy);
|
||||
Async.Task<Proxy?> GetOrCreate(string region);
|
||||
}
|
||||
public class ProxyOperations : StatefulOrm<Proxy, VmState>, IProxyOperations {
|
||||
public class ProxyOperations : StatefulOrm<Proxy, VmState, ProxyOperations>, IProxyOperations {
|
||||
|
||||
|
||||
static TimeSpan PROXY_LIFESPAN = TimeSpan.FromDays(7);
|
||||
|
@ -51,7 +51,7 @@ public class Queue : IQueue {
|
||||
var accountId = _storage.GetPrimaryAccount(storageType);
|
||||
_log.Verbose($"getting blob container (account_id: {accountId})");
|
||||
var (name, key) = await _storage.GetStorageAccountNameAndKey(accountId);
|
||||
var endpoint = _storage.GetQueueEndpoint(accountId);
|
||||
var endpoint = _storage.GetQueueEndpoint(name);
|
||||
var options = new QueueClientOptions { MessageEncoding = QueueMessageEncoding.Base64 };
|
||||
return new QueueServiceClient(endpoint, new StorageSharedKeyCredential(name, key), options);
|
||||
}
|
||||
|
@ -10,7 +10,7 @@ public interface IReproOperations : IStatefulOrm<Repro, VmState> {
|
||||
public IAsyncEnumerable<Repro> SearchStates(IEnumerable<VmState>? States);
|
||||
}
|
||||
|
||||
public class ReproOperations : StatefulOrm<Repro, VmState>, IReproOperations {
|
||||
public class ReproOperations : StatefulOrm<Repro, VmState, ReproOperations>, IReproOperations {
|
||||
private static readonly Dictionary<Os, string> DEFAULT_OS = new Dictionary<Os, string>
|
||||
{
|
||||
{Os.Linux, "Canonical:UbuntuServer:18.04-LTS:latest"},
|
||||
|
@ -14,7 +14,7 @@ public interface IScalesetOperations : IOrm<Scaleset> {
|
||||
IAsyncEnumerable<Scaleset> GetByObjectId(Guid objectId);
|
||||
}
|
||||
|
||||
public class ScalesetOperations : StatefulOrm<Scaleset, ScalesetState>, IScalesetOperations {
|
||||
public class ScalesetOperations : StatefulOrm<Scaleset, ScalesetState, ScalesetOperations>, IScalesetOperations {
|
||||
const string SCALESET_LOG_PREFIX = "scalesets: ";
|
||||
|
||||
ILogTracer _log;
|
||||
|
@ -74,7 +74,7 @@ public class Scheduler : IScheduler {
|
||||
|
||||
private async Async.Task<(BucketConfig, WorkSet)?> BuildWorkSet(Task[] tasks) {
|
||||
var taskIds = tasks.Select(x => x.TaskId).ToHashSet();
|
||||
var work_units = new List<WorkUnit>();
|
||||
var workUnits = new List<WorkUnit>();
|
||||
|
||||
BucketConfig? bucketConfig = null;
|
||||
foreach (var task in tasks) {
|
||||
@ -99,7 +99,7 @@ public class Scheduler : IScheduler {
|
||||
throw new Exception($"bucket configs differ: {bucketConfig} VS {result.Value.Item1}");
|
||||
}
|
||||
|
||||
work_units.Add(result.Value.Item2);
|
||||
workUnits.Add(result.Value.Item2);
|
||||
}
|
||||
|
||||
if (bucketConfig != null) {
|
||||
@ -108,7 +108,7 @@ public class Scheduler : IScheduler {
|
||||
Reboot: bucketConfig.reboot,
|
||||
Script: bucketConfig.setupScript != null,
|
||||
SetupUrl: setupUrl,
|
||||
WorkUnits: work_units
|
||||
WorkUnits: workUnits
|
||||
);
|
||||
|
||||
return (bucketConfig, workSet);
|
||||
@ -182,9 +182,9 @@ public class Scheduler : IScheduler {
|
||||
return (bucketConfig, workUnit);
|
||||
}
|
||||
|
||||
record struct BucketId(Os os, Guid jobId, (string, string)? vm, PoolName? pool, string setupContainer, bool? reboot, Guid? unique);
|
||||
public record struct BucketId(Os os, Guid jobId, (string, string)? vm, PoolName? pool, string setupContainer, bool? reboot, Guid? unique);
|
||||
|
||||
private ILookup<BucketId, Task> BucketTasks(IEnumerable<Task> tasks) {
|
||||
public static ILookup<BucketId, Task> BucketTasks(IEnumerable<Task> tasks) {
|
||||
|
||||
// buckets are hashed by:
|
||||
// OS, JOB ID, vm sku & image (if available), pool name (if available),
|
||||
@ -214,8 +214,19 @@ public class Scheduler : IScheduler {
|
||||
unique = Guid.NewGuid();
|
||||
}
|
||||
|
||||
return new BucketId(task.Os, task.JobId, vm, pool, _config.GetSetupContainer(task.Config), task.Config.Task.RebootAfterSetup, unique);
|
||||
return new BucketId(task.Os, task.JobId, vm, pool, GetSetupContainer(task.Config), task.Config.Task.RebootAfterSetup, unique);
|
||||
|
||||
});
|
||||
}
|
||||
|
||||
static string GetSetupContainer(TaskConfig config) {
|
||||
|
||||
foreach (var container in config.Containers ?? throw new Exception("Missing containers")) {
|
||||
if (container.Type == ContainerType.Setup) {
|
||||
return container.Name.ContainerName;
|
||||
}
|
||||
}
|
||||
|
||||
throw new Exception($"task missing setup container: task_type = {config.Task.Type}");
|
||||
}
|
||||
}
|
||||
|
@ -20,9 +20,9 @@ public interface IStorage {
|
||||
|
||||
public Uri GetBlobEndpoint(string accountId);
|
||||
|
||||
public Async.Task<(string?, string?)> GetStorageAccountNameAndKey(string accountId);
|
||||
public Async.Task<(string, string)> GetStorageAccountNameAndKey(string accountId);
|
||||
|
||||
public Async.Task<string?> GetStorageAccountNameAndKeyByName(string accountName);
|
||||
public Async.Task<string?> GetStorageAccountNameKeyByName(string accountName);
|
||||
|
||||
public IEnumerable<string> GetAccounts(StorageType storageType);
|
||||
}
|
||||
@ -99,16 +99,16 @@ public class Storage : IStorage {
|
||||
};
|
||||
}
|
||||
|
||||
public async Async.Task<(string?, string?)> GetStorageAccountNameAndKey(string accountId) {
|
||||
public async Async.Task<(string, string)> GetStorageAccountNameAndKey(string accountId) {
|
||||
var resourceId = new ResourceIdentifier(accountId);
|
||||
var armClient = GetMgmtClient();
|
||||
var storageAccount = armClient.GetStorageAccountResource(resourceId);
|
||||
var keys = await storageAccount.GetKeysAsync();
|
||||
var key = keys.Value.Keys.FirstOrDefault();
|
||||
return (resourceId.Name, key?.Value);
|
||||
var key = keys.Value.Keys.FirstOrDefault() ?? throw new Exception("no keys found");
|
||||
return (resourceId.Name, key.Value);
|
||||
}
|
||||
|
||||
public async Async.Task<string?> GetStorageAccountNameAndKeyByName(string accountName) {
|
||||
public async Async.Task<string?> GetStorageAccountNameKeyByName(string accountName) {
|
||||
var armClient = GetMgmtClient();
|
||||
var resourceGroup = _creds.GetResourceGroupResourceIdentifier();
|
||||
var storageAccount = await armClient.GetResourceGroupResource(resourceGroup).GetStorageAccountAsync(accountName);
|
||||
|
@ -5,6 +5,10 @@ namespace Microsoft.OneFuzz.Service;
|
||||
public interface ITaskOperations : IStatefulOrm<Task, TaskState> {
|
||||
Async.Task<Task?> GetByTaskId(Guid taskId);
|
||||
|
||||
IAsyncEnumerable<Task> GetByTaskIds(IEnumerable<Guid> taskId);
|
||||
|
||||
IAsyncEnumerable<Task> GetByJobId(Guid jobId);
|
||||
|
||||
Async.Task<Task?> GetByJobIdAndTaskId(Guid jobId, Guid taskId);
|
||||
|
||||
|
||||
@ -22,7 +26,7 @@ public interface ITaskOperations : IStatefulOrm<Task, TaskState> {
|
||||
Async.Task<Task> SetState(Task task, TaskState state);
|
||||
}
|
||||
|
||||
public class TaskOperations : StatefulOrm<Task, TaskState>, ITaskOperations {
|
||||
public class TaskOperations : StatefulOrm<Task, TaskState, TaskOperations>, ITaskOperations {
|
||||
|
||||
|
||||
public TaskOperations(ILogTracer log, IOnefuzzContext context)
|
||||
@ -31,9 +35,15 @@ public class TaskOperations : StatefulOrm<Task, TaskState>, ITaskOperations {
|
||||
}
|
||||
|
||||
public async Async.Task<Task?> GetByTaskId(Guid taskId) {
|
||||
var data = QueryAsync(filter: $"RowKey eq '{taskId}'");
|
||||
return await GetByTaskIds(new[] { taskId }).FirstOrDefaultAsync();
|
||||
}
|
||||
|
||||
return await data.FirstOrDefaultAsync();
|
||||
public IAsyncEnumerable<Task> GetByTaskIds(IEnumerable<Guid> taskId) {
|
||||
return QueryAsync(filter: Query.RowKeys(taskId.Select(t => t.ToString())));
|
||||
}
|
||||
|
||||
public IAsyncEnumerable<Task> GetByJobId(Guid jobId) {
|
||||
return QueryAsync(filter: $"PartitionKey eq '{jobId}'");
|
||||
}
|
||||
|
||||
public async Async.Task<Task?> GetByJobIdAndTaskId(Guid jobId, Guid taskId) {
|
||||
@ -42,21 +52,13 @@ public class TaskOperations : StatefulOrm<Task, TaskState>, ITaskOperations {
|
||||
return await data.FirstOrDefaultAsync();
|
||||
}
|
||||
public IAsyncEnumerable<Task> SearchStates(Guid? jobId = null, IEnumerable<TaskState>? states = null) {
|
||||
var queryString = String.Empty;
|
||||
if (jobId != null) {
|
||||
queryString += $"PartitionKey eq '{jobId}'";
|
||||
}
|
||||
|
||||
if (states != null) {
|
||||
if (jobId != null) {
|
||||
queryString += " and ";
|
||||
}
|
||||
|
||||
queryString += "(" + string.Join(
|
||||
" or ",
|
||||
states.Select(s => $"state eq '{s}'")
|
||||
) + ")";
|
||||
}
|
||||
var queryString =
|
||||
(jobId, states) switch {
|
||||
(null, null) => "",
|
||||
(Guid id, null) => Query.PartitionKey($"{id}"),
|
||||
(null, IEnumerable<TaskState> s) => Query.EqualAnyEnum("state", s),
|
||||
(Guid id, IEnumerable<TaskState> s) => Query.And(Query.PartitionKey($"{id}"), Query.EqualAnyEnum("state", s)),
|
||||
};
|
||||
|
||||
return QueryAsync(filter: queryString);
|
||||
}
|
||||
@ -209,7 +211,7 @@ public class TaskOperations : StatefulOrm<Task, TaskState>, ITaskOperations {
|
||||
|
||||
// if a prereq task fails, then mark this task as failed
|
||||
if (t == null) {
|
||||
await MarkFailed(task, new Error(ErrorCode.INVALID_REQUEST, Errors: new[] { "unable to find task" }));
|
||||
await MarkFailed(task, new Error(ErrorCode.INVALID_REQUEST, Errors: new[] { "unable to find prereq task" }));
|
||||
return false;
|
||||
}
|
||||
|
||||
@ -253,4 +255,33 @@ public class TaskOperations : StatefulOrm<Task, TaskState>, ITaskOperations {
|
||||
return null;
|
||||
|
||||
}
|
||||
|
||||
public async Async.Task<Task> Init(Task task) {
|
||||
await _context.Queue.CreateQueue($"{task.TaskId}", StorageType.Corpus);
|
||||
return await SetState(task, TaskState.Waiting);
|
||||
}
|
||||
|
||||
|
||||
public async Async.Task<Task> Stopping(Task task) {
|
||||
_logTracer.Info($"stopping task : {task.JobId}, {task.TaskId}");
|
||||
await _context.NodeOperations.StopTask(task.TaskId);
|
||||
var anyRemainingNodes = await _context.NodeTasksOperations.GetNodesByTaskId(task.TaskId).AnyAsync();
|
||||
if (!anyRemainingNodes) {
|
||||
return await Stopped(task);
|
||||
}
|
||||
return task;
|
||||
}
|
||||
|
||||
private async Async.Task<Task> Stopped(Task inputTask) {
|
||||
var task = await SetState(inputTask, TaskState.Stopped);
|
||||
await _context.Queue.DeleteQueue($"{task.TaskId}", StorageType.Corpus);
|
||||
|
||||
// # TODO: we need to 'unschedule' this task from the existing pools
|
||||
var job = await _context.JobOperations.Get(task.JobId);
|
||||
if (job != null) {
|
||||
await _context.JobOperations.StopIfAllDone(job);
|
||||
}
|
||||
|
||||
return task;
|
||||
}
|
||||
}
|
||||
|
@ -210,8 +210,8 @@ public class EntityConverter {
|
||||
return Guid.Parse(entity.GetString(ef.kind.ToString()));
|
||||
else if (ef.type == typeof(int))
|
||||
return int.Parse(entity.GetString(ef.kind.ToString()));
|
||||
else if (ef.type == typeof(PoolName))
|
||||
// TODO: this should be able to be generic over any ValidatedString
|
||||
else if (ef.type == typeof(PoolName))
|
||||
// TODO: this should be able to be generic over any ValidatedString
|
||||
return PoolName.Parse(entity.GetString(ef.kind.ToString()));
|
||||
else {
|
||||
throw new Exception($"invalid partition or row key type of {info.type} property {name}: {ef.type}");
|
||||
|
@@ -61,7 +61,7 @@ namespace ApiService.OneFuzzLib.Orm {
        public async Task<ResultVoid<(int, string)>> Replace(T entity) {
            var tableClient = await GetTableClient(typeof(T).Name);
            var tableEntity = _entityConverter.ToTableEntity(entity);
            var response = await tableClient.UpsertEntityAsync(tableEntity);
            var response = await tableClient.UpsertEntityAsync(tableEntity, TableUpdateMode.Replace);
            if (response.IsError) {
                return ResultVoid<(int, string)>.Error((response.Status, response.ReasonPhrase));
            } else {
@@ -97,7 +97,7 @@ namespace ApiService.OneFuzzLib.Orm {

            var account = accountId ?? _context.ServiceConfiguration.OneFuzzFuncStorage ?? throw new ArgumentNullException(nameof(accountId));
            var (name, key) = await _context.Storage.GetStorageAccountNameAndKey(account);
            var endpoint = _context.Storage.GetTableEndpoint(account);
            var endpoint = _context.Storage.GetTableEndpoint(name);
            var tableClient = new TableServiceClient(endpoint, new TableSharedKeyCredential(name, key));
            await tableClient.CreateTableIfNotExistsAsync(tableName);
            return tableClient.GetTableClient(tableName);
@@ -136,13 +136,42 @@ namespace ApiService.OneFuzzLib.Orm {
    }


    public class StatefulOrm<T, TState> : Orm<T>, IStatefulOrm<T, TState> where T : StatefulEntityBase<TState> where TState : Enum {
    public class StatefulOrm<T, TState, Self> : Orm<T>, IStatefulOrm<T, TState> where T : StatefulEntityBase<TState> where TState : Enum {
        static Lazy<Func<object>>? _partitionKeyGetter;
        static Lazy<Func<object>>? _rowKeyGetter;
        static ConcurrentDictionary<string, Func<T, Async.Task<T>>?> _stateFuncs = new ConcurrentDictionary<string, Func<T, Async.Task<T>>?>();

        delegate Async.Task<T> StateTransition(T entity);


        static StatefulOrm() {

            /// verify that all state transition function have the correct signature:
            var thisType = typeof(Self);
            var states = Enum.GetNames(typeof(TState));
            var delegateType = typeof(StateTransition);
            MethodInfo delegateSignature = delegateType.GetMethod("Invoke")!;

            foreach (var state in states) {
                var methodInfo = thisType?.GetMethod(state.ToString());
                if (methodInfo == null) {
                    continue;
                }

                bool parametersEqual = delegateSignature
                    .GetParameters()
                    .Select(x => x.ParameterType)
                    .SequenceEqual(methodInfo.GetParameters()
                        .Select(x => x.ParameterType));

                if (delegateSignature.ReturnType == methodInfo.ReturnType && parametersEqual) {
                    continue;
                }

                throw new Exception($"State transition method '{state}' in '{thisType?.Name}' does not have the correct signature. Expected '{delegateSignature}' actual '{methodInfo}' ");
            };


            _partitionKeyGetter =
                typeof(T).GetProperties().FirstOrDefault(p => p.GetCustomAttributes(true).OfType<PartitionKeyAttribute>().Any())?.GetMethod switch {
                    null => null,
@@ -167,11 +196,10 @@ namespace ApiService.OneFuzzLib.Orm {
        /// <returns></returns>
        public async Async.Task<T?> ProcessStateUpdate(T entity) {
            TState state = entity.State;
            var func = _stateFuncs.GetOrAdd(state.ToString(), (string k) =>
                typeof(T).GetMethod(k) switch {
                    null => null,
                    MethodInfo info => (Func<T, Async.Task<T>>)Delegate.CreateDelegate(typeof(Func<T, Async.Task<T>>), info)
                });
            var func = GetType().GetMethod(state.ToString()) switch {
                null => null,
                MethodInfo info => info.CreateDelegate<StateTransition>(this)
            };

            if (func != null) {
                _logTracer.Info($"processing state update: {typeof(T)} - PartitionKey {_partitionKeyGetter?.Value()} {_rowKeyGetter?.Value()} - %s");
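With the new `Self` type parameter, every derived ORM class is checked once at startup: any public method named after a `TState` value must match the `StateTransition` delegate, i.e. take the entity and return a task of the entity. A minimal sketch of a conforming (and a non-conforming) transition follows, using a hypothetical `Widget` entity and plain `System.Threading.Tasks.Task` in place of the service's `Async.Task`:

```csharp
using System;
using System.Threading.Tasks;

// Hypothetical entity and states; not real OneFuzz types.
public enum WidgetState { Init, Stopped }
public record Widget(Guid WidgetId, WidgetState State);

// Imagine: public class WidgetOperations : StatefulOrm<Widget, WidgetState, WidgetOperations>
public class WidgetOperations {
    // Conforming transition: public, named after a WidgetState value,
    // takes the entity and returns Task<Widget> -- the static check passes.
    public async Task<Widget> Init(Widget widget) {
        await Task.Yield();
        return widget with { State = WidgetState.Stopped };
    }

    // Non-conforming (wrong return type); the static constructor above would
    // throw "State transition method 'Stopped' ... does not have the correct signature".
    // public async Task Stopped(Widget widget) { ... }
}
```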
@@ -11,12 +11,15 @@ namespace ApiService.OneFuzzLib.Orm {
        public static string PartitionKey(string partitionKey)
            => TableClient.CreateQueryFilter($"PartitionKey eq {partitionKey}");

        public static string PartitionKey(Guid partitionKey)
            => TableClient.CreateQueryFilter($"PartitionKey eq {partitionKey}");

        public static string RowKey(string rowKey)
            => TableClient.CreateQueryFilter($"RowKey eq {rowKey}");

        public static string PartitionKeys(IEnumerable<string> partitionKeys)
            => Or(partitionKeys.Select(PartitionKey));

        public static string RowKeys(IEnumerable<string> rowKeys)
            => Or(rowKeys.Select(RowKey));

        public static string SingleEntity(string partitionKey, string rowKey)
            => TableClient.CreateQueryFilter($"(PartitionKey eq {partitionKey}) and (RowKey eq {rowKey})");
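These helpers are what replace the hand-built `$"machine_id eq '{...}'"` filter strings elsewhere in this diff. A small sketch of how callers compose them; the filter strings in the comments are indicative of Azure Table OData syntax, not captured output:

```csharp
using System;
using System.Linq;

var machineId = Guid.NewGuid();
var taskIds = new[] { Guid.NewGuid(), Guid.NewGuid() };

var byMachine = Query.PartitionKey(machineId.ToString());       // e.g. PartitionKey eq '...'
var byTasks = Query.RowKeys(taskIds.Select(t => t.ToString()));  // e.g. RowKey eq '...' or RowKey eq '...'
```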
@ -25,21 +25,31 @@ sealed class AzureStorage : IStorage {
|
||||
return new AzureStorage(accountName, accountKey);
|
||||
}
|
||||
|
||||
public string? AccountName { get; }
|
||||
public string? AccountKey { get; }
|
||||
public string AccountName { get; }
|
||||
public string AccountKey { get; }
|
||||
|
||||
public AzureStorage(string? accountName, string? accountKey) {
|
||||
public AzureStorage(string accountName, string accountKey) {
|
||||
AccountName = accountName;
|
||||
AccountKey = accountKey;
|
||||
}
|
||||
|
||||
public Task<(string?, string?)> GetStorageAccountNameAndKey(string accountId)
|
||||
=> Async.Task.FromResult((AccountName, AccountKey));
|
||||
public IEnumerable<string> CorpusAccounts() {
|
||||
throw new System.NotImplementedException();
|
||||
}
|
||||
|
||||
public IEnumerable<string> GetAccounts(StorageType storageType) {
|
||||
if (AccountName != null) {
|
||||
yield return AccountName;
|
||||
}
|
||||
yield return AccountName;
|
||||
}
|
||||
|
||||
public string GetPrimaryAccount(StorageType storageType) {
|
||||
throw new System.NotImplementedException();
|
||||
}
|
||||
|
||||
public Task<(string, string)> GetStorageAccountNameAndKey(string accountId)
|
||||
=> Async.Task.FromResult((AccountName, AccountKey));
|
||||
|
||||
public Task<string?> GetStorageAccountNameKeyByName(string accountName) {
|
||||
return Async.Task.FromResult(AccountName)!;
|
||||
}
|
||||
|
||||
public Uri GetTableEndpoint(string accountId)
|
||||
@ -51,15 +61,4 @@ sealed class AzureStorage : IStorage {
|
||||
public Uri GetBlobEndpoint(string accountId)
|
||||
=> new($"https://{AccountName}.blob.core.windows.net/");
|
||||
|
||||
public IEnumerable<string> CorpusAccounts() {
|
||||
throw new System.NotImplementedException();
|
||||
}
|
||||
|
||||
public string GetPrimaryAccount(StorageType storageType) {
|
||||
throw new System.NotImplementedException();
|
||||
}
|
||||
|
||||
public Task<string?> GetStorageAccountNameAndKeyByName(string accountName) {
|
||||
throw new System.NotImplementedException();
|
||||
}
|
||||
}
|
||||
|
@ -22,15 +22,11 @@ sealed class AzuriteStorage : IStorage {
|
||||
const string AccountName = "devstoreaccount1";
|
||||
const string AccountKey = "Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw==";
|
||||
|
||||
public Task<(string?, string?)> GetStorageAccountNameAndKey(string _accountId)
|
||||
=> Async.Task.FromResult<(string?, string?)>((AccountName, AccountKey));
|
||||
public Task<(string, string)> GetStorageAccountNameAndKey(string accountId)
|
||||
=> Async.Task.FromResult((AccountName, AccountKey));
|
||||
|
||||
public IEnumerable<string> GetAccounts(StorageType storageType) {
|
||||
yield return AccountName;
|
||||
}
|
||||
|
||||
public Task<string?> GetStorageAccountNameAndKeyByName(string accountName) {
|
||||
throw new System.NotImplementedException();
|
||||
public Task<string?> GetStorageAccountNameKeyByName(string accountName) {
|
||||
return Async.Task.FromResult(AccountName)!;
|
||||
}
|
||||
|
||||
public IEnumerable<string> CorpusAccounts() {
|
||||
@ -40,4 +36,8 @@ sealed class AzuriteStorage : IStorage {
|
||||
public string GetPrimaryAccount(StorageType storageType) {
|
||||
throw new System.NotImplementedException();
|
||||
}
|
||||
|
||||
public IEnumerable<string> GetAccounts(StorageType storageType) {
|
||||
yield return AccountName;
|
||||
}
|
||||
}
|
||||
|
@ -35,6 +35,8 @@ public class RequestsTests {
|
||||
var deserialized = (T?)_serializer.Deserialize(stream, typeof(T), CancellationToken.None);
|
||||
var reserialized = _serializer.Serialize(deserialized);
|
||||
var result = Encoding.UTF8.GetString(reserialized);
|
||||
result = result.Replace(System.Environment.NewLine, "\n");
|
||||
json = json.Replace(System.Environment.NewLine, "\n");
|
||||
Assert.Equal(json, result);
|
||||
}
|
||||
|
||||
|
src/ApiService/Tests/SchedulerTests.cs (new file, 133 lines)
@@ -0,0 +1,133 @@
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Linq;
|
||||
using Microsoft.OneFuzz.Service;
|
||||
using Xunit;
|
||||
|
||||
namespace Tests;
|
||||
|
||||
public class SchedulerTests {
|
||||
|
||||
IEnumerable<Task> BuildTasks(int size) {
|
||||
return Enumerable.Range(0, size).Select(i =>
|
||||
new Task(
|
||||
Guid.Empty,
|
||||
Guid.NewGuid(),
|
||||
TaskState.Init,
|
||||
Os.Linux,
|
||||
new TaskConfig(
|
||||
Guid.Empty,
|
||||
null,
|
||||
new TaskDetails(
|
||||
Type: TaskType.LibfuzzerFuzz,
|
||||
Duration: 1,
|
||||
TargetExe: "fuzz.exe",
|
||||
TargetEnv: new Dictionary<string, string>(),
|
||||
TargetOptions: new List<string>()),
|
||||
Pool: new TaskPool(1, PoolName.Parse("pool")),
|
||||
Containers: new List<TaskContainers> { new TaskContainers(ContainerType.Setup, new Container("setup")) },
|
||||
Colocate: true
|
||||
|
||||
),
|
||||
null,
|
||||
null,
|
||||
null,
|
||||
null,
|
||||
null)
|
||||
|
||||
);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void TestAllColocate() {
|
||||
// all tasks should land in one bucket
|
||||
|
||||
var tasks = BuildTasks(10).Select(task => task with { Config = task.Config with { Colocate = true } }
|
||||
).ToList();
|
||||
|
||||
var buckets = Scheduler.BucketTasks(tasks);
|
||||
var bucket = Assert.Single(buckets);
|
||||
Assert.True(10 >= bucket.Count());
|
||||
CheckBuckets(buckets, tasks, 1);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void TestPartialColocate() {
|
||||
// 2 tasks should land on their own, the rest should be colocated into a
|
||||
// single bucket.
|
||||
|
||||
var tasks = BuildTasks(10).Select((task, i) => {
|
||||
return i switch {
|
||||
0 => task with { Config = task.Config with { Colocate = null } },
|
||||
1 => task with { Config = task.Config with { Colocate = false } },
|
||||
_ => task
|
||||
};
|
||||
}).ToList();
|
||||
var buckets = Scheduler.BucketTasks(tasks);
|
||||
var lengths = buckets.Select(b => b.Count()).OrderBy(x => x);
|
||||
Assert.Equal(new[] { 1, 1, 8 }, lengths);
|
||||
CheckBuckets(buckets, tasks, 3);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void TestAlluniqueJob() {
|
||||
// everything has a unique job_id
|
||||
var tasks = BuildTasks(10).Select(task => {
|
||||
var jobId = Guid.NewGuid();
|
||||
return task with { JobId = jobId, Config = task.Config with { JobId = jobId } };
|
||||
}).ToList();
|
||||
|
||||
var buckets = Scheduler.BucketTasks(tasks);
|
||||
foreach (var bucket in buckets) {
|
||||
Assert.True(1 >= bucket.Count());
|
||||
}
|
||||
CheckBuckets(buckets, tasks, 10);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void TestMultipleJobBuckets() {
|
||||
// at most 3 tasks per bucket, by job_id
|
||||
var tasks = BuildTasks(10).Chunk(3).SelectMany(taskChunk => {
|
||||
var jobId = Guid.NewGuid();
|
||||
return taskChunk.Select(task => task with { JobId = jobId, Config = task.Config with { JobId = jobId } });
|
||||
}).ToList();
|
||||
|
||||
var buckets = Scheduler.BucketTasks(tasks);
|
||||
foreach (var bucket in buckets) {
|
||||
Assert.True(3 >= bucket.Count());
|
||||
}
|
||||
CheckBuckets(buckets, tasks, 4);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void TestManyBuckets() {
|
||||
var jobId = Guid.Parse("00000000-0000-0000-0000-000000000001");
|
||||
var tasks = BuildTasks(100).Select((task, i) => {
|
||||
var containers = new List<TaskContainers>(task.Config.Containers!);
|
||||
if (i % 4 == 0) {
|
||||
containers[0] = containers[0] with { Name = new Container("setup2") };
|
||||
}
|
||||
return task with {
|
||||
JobId = i % 2 == 0 ? jobId : task.JobId,
|
||||
Os = i % 3 == 0 ? Os.Windows : task.Os,
|
||||
Config = task.Config with {
|
||||
JobId = i % 2 == 0 ? jobId : task.Config.JobId,
|
||||
Containers = containers,
|
||||
Pool = i % 5 == 0 ? task.Config.Pool! with { PoolName = PoolName.Parse("alternate-pool") } : task.Config.Pool
|
||||
}
|
||||
};
|
||||
}).ToList();
|
||||
|
||||
var buckets = Scheduler.BucketTasks(tasks);
|
||||
|
||||
CheckBuckets(buckets, tasks, 12);
|
||||
}
|
||||
|
||||
void CheckBuckets(ILookup<Scheduler.BucketId, Task> buckets, List<Task> tasks, int bucketCount) {
|
||||
Assert.Equal(buckets.Count, bucketCount);
|
||||
|
||||
foreach (var task in tasks) {
|
||||
Assert.Single(buckets, b => b.Contains(task));
|
||||
}
|
||||
}
|
||||
}
|
@@ -8,12 +8,6 @@ public class ValidatedStringTests {

    record ThingContainingPoolName(PoolName PoolName);

    [Fact]
    public void PoolNameValidatesOnDeserialization() {
        var ex = Assert.Throws<JsonException>(() => JsonSerializer.Deserialize<ThingContainingPoolName>("{ \"PoolName\": \"is-not!-a-pool\" }"));
        Assert.Equal("unable to parse input as a PoolName", ex.Message);
    }

    [Fact]
    public void PoolNameDeserializesFromString() {
        var result = JsonSerializer.Deserialize<ThingContainingPoolName>("{ \"PoolName\": \"is-a-pool\" }");