Migrate timer_task part 3 (#2066)

* Adding unit tests for the scheduler

* formatting

* fix unit test

* bug fixes

* proting a couple more missong functions

* more bug fixes

* fixing queries and serilization

* fix sas url generation

* fix condition

* [testing] enabling timer_tasks for testing

* Update src/ApiService/Tests/SchedulerTests.cs

Co-authored-by: George Pollard <porges@porg.es>

* address PR comments

* build fix

* address PR comment

* removing renamed function

* resolve merge

* Update src/ApiService/Tests/Integration/AzuriteStorage.cs

Co-authored-by: George Pollard <porges@porg.es>

* - Added verification of the state transition functions
- disabled validation on PoolName

* Update src/deployment/deploy.py

Co-authored-by: George Pollard <porges@porg.es>
Co-authored-by: George Pollard <gpollard@microsoft.com>
This commit is contained in:
Cheick Keita
2022-06-24 09:22:08 -07:00
committed by GitHub
parent fb9af4b811
commit d61fe48a55
23 changed files with 427 additions and 136 deletions

View File

@ -184,7 +184,8 @@ public record TaskDetails(
bool? PreserveExistingOutputs = null,
List<string>? ReportList = null,
int? MinimizedStackDepth = null,
string? CoverageFilter = null);
string? CoverageFilter = null
);
public record TaskVm(
Region Region,
@ -214,7 +215,8 @@ public record TaskConfig(
List<TaskContainers>? Containers = null,
Dictionary<string, string>? Tags = null,
List<TaskDebugFlag>? Debug = null,
bool? Colocate = null);
bool? Colocate = null
);
public record TaskEventSummary(
DateTimeOffset? Timestamp,
@ -590,9 +592,29 @@ public record WorkUnit(
Guid JobId,
Guid TaskId,
TaskType TaskType,
TaskUnitConfig Config
// JSON-serialized `TaskUnitConfig`.
[property: JsonConverter(typeof(TaskUnitConfigConverter))] TaskUnitConfig Config
);
public class TaskUnitConfigConverter : JsonConverter<TaskUnitConfig> {
public override TaskUnitConfig? Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) {
var taskUnitString = reader.GetString();
if (taskUnitString == null) {
return null;
}
return JsonSerializer.Deserialize<TaskUnitConfig>(taskUnitString, options);
}
public override void Write(Utf8JsonWriter writer, TaskUnitConfig value, JsonSerializerOptions options) {
var v = JsonSerializer.Serialize(value, new JsonSerializerOptions(options) {
DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull
});
writer.WriteStringValue(v);
}
}
public record VmDefinition(
Compare Compare,
int Value
@ -625,14 +647,36 @@ public record ContainerDefinition(
// TODO: service shouldn't pass SyncedDir, but just the url and let the agent
// come up with paths
public record SyncedDir(string Path, Uri url);
public record SyncedDir(string Path, Uri Url);
[JsonConverter(typeof(ContainerDefConverter))]
public interface IContainerDef { }
public record SingleContainer(SyncedDir SyncedDir) : IContainerDef;
public record MultipleContainer(List<SyncedDir> SyncedDirs) : IContainerDef;
public class ContainerDefConverter : JsonConverter<IContainerDef> {
public override IContainerDef? Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) {
throw new NotImplementedException();
}
public override void Write(Utf8JsonWriter writer, IContainerDef value, JsonSerializerOptions options) {
switch (value) {
case SingleContainer container:
JsonSerializer.Serialize(writer, container.SyncedDir, options);
break;
case MultipleContainer { SyncedDirs: var syncedDirs }:
JsonSerializer.Serialize(writer, syncedDirs, options);
break;
default:
throw new NotImplementedException();
}
}
}
public record TaskUnitConfig(
Guid InstanceId,
Guid JobId,
@ -686,7 +730,7 @@ public record TaskUnitConfig(
public IContainerDef? Tools { get; set; }
public IContainerDef? UniqueInputs { get; set; }
public IContainerDef? UniqueReports { get; set; }
public IContainerDef? RegressionReport { get; set; }
public IContainerDef? RegressionReports { get; set; }
}

View File

@ -1,6 +1,4 @@

using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
using System.Diagnostics.CodeAnalysis;
using System.Text.Json;
using System.Text.Json.Serialization;
using System.Text.RegularExpressions;
@ -51,7 +49,7 @@ public abstract class ValidatedStringConverter<T> : JsonConverter<T> where T : V
[JsonConverter(typeof(Converter))]
public record PoolName : ValidatedString {
private PoolName(string value) : base(value) {
Debug.Assert(Check.IsAlnumDash(value));
// Debug.Assert(Check.IsAlnumDash(value));
}
public static PoolName Parse(string input) {
@ -63,10 +61,14 @@ public record PoolName : ValidatedString {
}
public static bool TryParse(string input, [NotNullWhen(returnValue: true)] out PoolName? result) {
if (!Check.IsAlnumDash(input)) {
result = default;
return false;
}
// bypassing the validation because this code has a stricter validation than the python equivalent
// see (issue #2080)
// if (!Check.IsAlnumDash(input)) {
// result = default;
// return false;
// }
result = new PoolName(input);
return true;

View File

@ -21,8 +21,8 @@ public class TimerTasks {
_scheduler = scheduler;
}
//[Function("TimerTasks")]
public async Async.Task Run([TimerTrigger("1.00:00:00")] TimerInfo myTimer) {
[Function("TimerTasks")]
public async Async.Task Run([TimerTrigger("00:00:15")] TimerInfo myTimer) {
var expriredTasks = _taskOperations.SearchExpired();
await foreach (var task in expriredTasks) {
_logger.Info($"stopping expired task. job_id:{task.JobId} task_id:{task.TaskId}");

View File

@ -5,7 +5,6 @@ namespace Microsoft.OneFuzz.Service;
public interface IConfig {
string GetSetupContainer(TaskConfig config);
Async.Task<TaskUnitConfig> BuildTaskConfig(Job job, Task task);
}
@ -89,9 +88,13 @@ public class Config : IConfig {
await foreach (var data in containersByType) {
if (!data.containers.Any()) {
continue;
}
IContainerDef def = data.countainerDef switch {
ContainerDefinition { Compare: Compare.Equal, Value: 1 } or
ContainerDefinition { Compare: Compare.AtMost, Value: 1 } => new SingleContainer(data.containers[0]),
ContainerDefinition { Compare: Compare.AtMost, Value: 1 } when data.containers.Count == 1 => new SingleContainer(data.containers[0]),
_ => new MultipleContainer(data.containers)
};
@ -126,6 +129,9 @@ public class Config : IConfig {
case ContainerType.UniqueReports:
config.UniqueReports = def;
break;
case ContainerType.RegressionReports:
config.RegressionReports = def;
break;
}
}
@ -249,16 +255,4 @@ public class Config : IConfig {
return config;
}
public string GetSetupContainer(TaskConfig config) {
foreach (var container in config.Containers ?? throw new Exception("Missing containers")) {
if (container.Type == ContainerType.Setup) {
return container.Name.ContainerName;
}
}
throw new Exception($"task missing setup container: task_type = {config.Task.Type}");
}
}

View File

@ -164,14 +164,15 @@ public class Containers : IContainers {
if (uri.Query.Contains("sig")) {
return uri;
}
var accountName = uri.Host.Split('.')[0];
var (_, accountKey) = await _storage.GetStorageAccountNameAndKey(accountName);
var blobUriBuilder = new BlobUriBuilder(uri);
var accountKey = await _storage.GetStorageAccountNameKeyByName(blobUriBuilder.AccountName);
var sasBuilder = new BlobSasBuilder(
BlobContainerSasPermissions.Read | BlobContainerSasPermissions.Write | BlobContainerSasPermissions.Delete | BlobContainerSasPermissions.List,
DateTimeOffset.UtcNow + TimeSpan.FromHours(1));
DateTimeOffset.UtcNow + TimeSpan.FromHours(1)) {
BlobContainerName = blobUriBuilder.BlobContainerName,
};
var sas = sasBuilder.ToSasQueryParameters(new StorageSharedKeyCredential(accountName, accountKey)).ToString();
var sas = sasBuilder.ToSasQueryParameters(new StorageSharedKeyCredential(blobUriBuilder.AccountName, accountKey)).ToString();
return new UriBuilder(uri) {
Query = sas
}.Uri;
@ -179,13 +180,10 @@ public class Containers : IContainers {
public async Async.Task<Uri> GetContainerSasUrl(Container container, StorageType storageType, BlobContainerSasPermissions permissions, TimeSpan? duration = null) {
var client = await FindContainer(container, storageType) ?? throw new Exception($"unable to find container: {container.ContainerName} - {storageType}");
var (accountName, accountKey) = await _storage.GetStorageAccountNameAndKey(client.AccountName);
var (startTime, endTime) = SasTimeWindow(duration ?? TimeSpan.FromDays(30));
var sasBuilder = new BlobSasBuilder(permissions, endTime) {
StartsOn = startTime,
BlobContainerName = container.ContainerName,
BlobContainerName = container.ContainerName
};
var sasUrl = client.GenerateSasUri(sasBuilder);

View File

@ -6,12 +6,13 @@ public interface IJobOperations : IStatefulOrm<Job, JobState> {
Async.Task<Job?> Get(Guid jobId);
Async.Task OnStart(Job job);
IAsyncEnumerable<Job> SearchExpired();
Async.Task Stopping(Job job);
Async.Task<Job> Stopping(Job job);
IAsyncEnumerable<Job> SearchState(IEnumerable<JobState> states);
Async.Task StopNeverStartedJobs();
Async.Task StopIfAllDone(Job job);
}
public class JobOperations : StatefulOrm<Job, JobState>, IJobOperations {
public class JobOperations : StatefulOrm<Job, JobState, JobOperations>, IJobOperations {
private static TimeSpan JOB_NEVER_STARTED_DURATION = TimeSpan.FromDays(30);
public JobOperations(ILogTracer logTracer, IOnefuzzContext context) : base(logTracer, context) {
@ -36,6 +37,17 @@ public class JobOperations : StatefulOrm<Job, JobState>, IJobOperations {
return QueryAsync(filter: query);
}
public async Async.Task StopIfAllDone(Job job) {
var anyNotStoppedJobs = await _context.TaskOperations.GetByJobId(job.JobId).AnyAsync(task => task.State != TaskState.Stopped);
if (anyNotStoppedJobs) {
return;
}
_logTracer.Info($"stopping job as all tasks are stopped: {job.JobId}");
await Stopping(job);
}
public async Async.Task StopNeverStartedJobs() {
// # Note, the "not(end_time...)" with end_time set long before the use of
// # OneFuzz enables identifying those without end_time being set.
@ -58,7 +70,7 @@ public class JobOperations : StatefulOrm<Job, JobState>, IJobOperations {
}
}
public async Async.Task Stopping(Job job) {
public async Async.Task<Job> Stopping(Job job) {
job = job with { State = JobState.Stopping };
var tasks = await _context.TaskOperations.QueryAsync(filter: $"job_id eq '{job.JobId}'").ToListAsync();
var taskNotStopped = tasks.ToLookup(task => task.State != TaskState.Stopped);
@ -76,7 +88,12 @@ public class JobOperations : StatefulOrm<Job, JobState>, IJobOperations {
await _context.Events.SendEvent(new EventJobStopped(job.JobId, job.Config, job.UserInfo, taskInfo));
}
await Replace(job);
var result = await Replace(job);
if (result.IsOk) {
return job;
} else {
throw new Exception($"Failed to save job {job.JobId} : {result.ErrorV}");
}
}
}

View File

@ -43,6 +43,8 @@ public interface INodeOperations : IStatefulOrm<Node, NodeState> {
Async.Task MarkTasksStoppedEarly(Node node, Error? error = null);
static TimeSpan NODE_EXPIRATION_TIME = TimeSpan.FromHours(1.0);
static TimeSpan NODE_REIMAGE_TIME = TimeSpan.FromDays(6.0);
Async.Task StopTask(Guid task_id);
}
@ -51,7 +53,7 @@ public interface INodeOperations : IStatefulOrm<Node, NodeState> {
/// Enabling autoscaling for the scalesets based on the pool work queues.
/// https://docs.microsoft.com/en-us/azure/azure-monitor/platform/autoscale-common-metrics#commonly-used-storage-metrics
public class NodeOperations : StatefulOrm<Node, NodeState>, INodeOperations {
public class NodeOperations : StatefulOrm<Node, NodeState, NodeOperations>, INodeOperations {
public NodeOperations(
@ -269,7 +271,7 @@ public class NodeOperations : StatefulOrm<Node, NodeState>, INodeOperations {
public async Async.Task<Node?> GetByMachineId(Guid machineId) {
var data = QueryAsync(filter: $"RowKey eq '{machineId}'");
var data = QueryAsync(filter: Query.RowKey(machineId.ToString()));
return await data.FirstOrDefaultAsync();
}
@ -387,18 +389,49 @@ public class NodeOperations : StatefulOrm<Node, NodeState>, INodeOperations {
await _context.Events.SendEvent(new EventNodeDeleted(node.MachineId, node.ScalesetId, node.PoolName, node.State));
}
public async Async.Task StopTask(Guid task_id) {
// For now, this just re-images the node. Eventually, this
// should send a message to the node to let the agent shut down
// gracefully
var nodes = _context.NodeTasksOperations.GetNodesByTaskId(task_id);
await foreach (var node in nodes) {
await _context.NodeMessageOperations.SendMessage(node.MachineId, new NodeCommand(StopTask: new StopTaskNodeCommand(task_id)));
if (!(await StopIfComplete(node))) {
_logTracer.Info($"nodes: stopped task on node, but not reimaging due to other tasks: task_id:{task_id} machine_id:{node.MachineId}");
}
}
}
/// returns True on stopping the node and False if this doesn't stop the node
private async Task<bool> StopIfComplete(Node node, bool done = false) {
var nodeTaskIds = await _context.NodeTasksOperations.GetByMachineId(node.MachineId).Select(nt => nt.TaskId).ToArrayAsync();
var tasks = _context.TaskOperations.GetByTaskIds(nodeTaskIds);
await foreach (var task in tasks) {
if (!TaskStateHelper.ShuttingDown(task.State)) {
return false;
}
}
_logTracer.Info($"node: stopping busy node with all tasks complete: {node.MachineId}");
await Stop(node, done: done);
return true;
}
}
public interface INodeTasksOperations : IStatefulOrm<NodeTasks, NodeTaskState> {
IAsyncEnumerable<Node> GetNodesByTaskId(Guid taskId, INodeOperations nodeOps);
IAsyncEnumerable<Node> GetNodesByTaskId(Guid taskId);
IAsyncEnumerable<NodeAssignment> GetNodeAssignments(Guid taskId, INodeOperations nodeOps);
IAsyncEnumerable<NodeTasks> GetByMachineId(Guid machineId);
IAsyncEnumerable<NodeTasks> GetByTaskId(Guid taskId);
Async.Task ClearByMachineId(Guid machineId);
}
public class NodeTasksOperations : StatefulOrm<NodeTasks, NodeTaskState>, INodeTasksOperations {
public class NodeTasksOperations : StatefulOrm<NodeTasks, NodeTaskState, NodeTasksOperations>, INodeTasksOperations {
ILogTracer _log;
@ -408,18 +441,20 @@ public class NodeTasksOperations : StatefulOrm<NodeTasks, NodeTaskState>, INodeT
}
//TODO: suggest by Cheick: this can probably be optimize by query all NodesTasks then query the all machine in single request
public async IAsyncEnumerable<Node> GetNodesByTaskId(Guid taskId, INodeOperations nodeOps) {
await foreach (var entry in QueryAsync($"task_id eq '{taskId}'")) {
var node = await nodeOps.GetByMachineId(entry.MachineId);
public async IAsyncEnumerable<Node> GetNodesByTaskId(Guid taskId) {
await foreach (var entry in QueryAsync(Query.RowKey(taskId.ToString()))) {
var node = await _context.NodeOperations.GetByMachineId(entry.MachineId);
if (node is not null) {
yield return node;
}
}
}
public async IAsyncEnumerable<NodeAssignment> GetNodeAssignments(Guid taskId, INodeOperations nodeOps) {
await foreach (var entry in QueryAsync($"task_id eq '{taskId}'")) {
var node = await nodeOps.GetByMachineId(entry.MachineId);
await foreach (var entry in QueryAsync(Query.RowKey(taskId.ToString()))) {
var node = await _context.NodeOperations.GetByMachineId(entry.MachineId);
if (node is not null) {
var nodeAssignment = new NodeAssignment(node.MachineId, node.ScalesetId, entry.State);
yield return nodeAssignment;
@ -428,11 +463,11 @@ public class NodeTasksOperations : StatefulOrm<NodeTasks, NodeTaskState>, INodeT
}
public IAsyncEnumerable<NodeTasks> GetByMachineId(Guid machineId) {
return QueryAsync($"machine_id eq '{machineId}'");
return QueryAsync(Query.PartitionKey(machineId.ToString()));
}
public IAsyncEnumerable<NodeTasks> GetByTaskId(Guid taskId) {
return QueryAsync($"task_id eq '{taskId}'");
return QueryAsync(Query.RowKey(taskId.ToString()));
}
public async Async.Task ClearByMachineId(Guid machineId) {
@ -472,7 +507,7 @@ public class NodeMessageOperations : Orm<NodeMessage>, INodeMessageOperations {
}
public IAsyncEnumerable<NodeMessage> GetMessage(Guid machineId)
=> QueryAsync(Query.PartitionKey(machineId));
=> QueryAsync(Query.PartitionKey(machineId.ToString()));
public async Async.Task ClearMessages(Guid machineId) {
_logTracer.Info($"clearing messages for node {machineId}");

View File

@ -9,7 +9,7 @@ public interface IPoolOperations {
IAsyncEnumerable<Pool> GetByClientId(Guid clientId);
}
public class PoolOperations : StatefulOrm<Pool, PoolState>, IPoolOperations {
public class PoolOperations : StatefulOrm<Pool, PoolState, PoolOperations>, IPoolOperations {
public PoolOperations(ILogTracer log, IOnefuzzContext context)
: base(log, context) {

View File

@ -13,7 +13,7 @@ public interface IProxyOperations : IStatefulOrm<Proxy, VmState> {
bool IsOutdated(Proxy proxy);
Async.Task<Proxy?> GetOrCreate(string region);
}
public class ProxyOperations : StatefulOrm<Proxy, VmState>, IProxyOperations {
public class ProxyOperations : StatefulOrm<Proxy, VmState, ProxyOperations>, IProxyOperations {
static TimeSpan PROXY_LIFESPAN = TimeSpan.FromDays(7);

View File

@ -51,7 +51,7 @@ public class Queue : IQueue {
var accountId = _storage.GetPrimaryAccount(storageType);
_log.Verbose($"getting blob container (account_id: {accountId})");
var (name, key) = await _storage.GetStorageAccountNameAndKey(accountId);
var endpoint = _storage.GetQueueEndpoint(accountId);
var endpoint = _storage.GetQueueEndpoint(name);
var options = new QueueClientOptions { MessageEncoding = QueueMessageEncoding.Base64 };
return new QueueServiceClient(endpoint, new StorageSharedKeyCredential(name, key), options);
}

View File

@ -10,7 +10,7 @@ public interface IReproOperations : IStatefulOrm<Repro, VmState> {
public IAsyncEnumerable<Repro> SearchStates(IEnumerable<VmState>? States);
}
public class ReproOperations : StatefulOrm<Repro, VmState>, IReproOperations {
public class ReproOperations : StatefulOrm<Repro, VmState, ReproOperations>, IReproOperations {
private static readonly Dictionary<Os, string> DEFAULT_OS = new Dictionary<Os, string>
{
{Os.Linux, "Canonical:UbuntuServer:18.04-LTS:latest"},

View File

@ -14,7 +14,7 @@ public interface IScalesetOperations : IOrm<Scaleset> {
IAsyncEnumerable<Scaleset> GetByObjectId(Guid objectId);
}
public class ScalesetOperations : StatefulOrm<Scaleset, ScalesetState>, IScalesetOperations {
public class ScalesetOperations : StatefulOrm<Scaleset, ScalesetState, ScalesetOperations>, IScalesetOperations {
const string SCALESET_LOG_PREFIX = "scalesets: ";
ILogTracer _log;

View File

@ -74,7 +74,7 @@ public class Scheduler : IScheduler {
private async Async.Task<(BucketConfig, WorkSet)?> BuildWorkSet(Task[] tasks) {
var taskIds = tasks.Select(x => x.TaskId).ToHashSet();
var work_units = new List<WorkUnit>();
var workUnits = new List<WorkUnit>();
BucketConfig? bucketConfig = null;
foreach (var task in tasks) {
@ -99,7 +99,7 @@ public class Scheduler : IScheduler {
throw new Exception($"bucket configs differ: {bucketConfig} VS {result.Value.Item1}");
}
work_units.Add(result.Value.Item2);
workUnits.Add(result.Value.Item2);
}
if (bucketConfig != null) {
@ -108,7 +108,7 @@ public class Scheduler : IScheduler {
Reboot: bucketConfig.reboot,
Script: bucketConfig.setupScript != null,
SetupUrl: setupUrl,
WorkUnits: work_units
WorkUnits: workUnits
);
return (bucketConfig, workSet);
@ -182,9 +182,9 @@ public class Scheduler : IScheduler {
return (bucketConfig, workUnit);
}
record struct BucketId(Os os, Guid jobId, (string, string)? vm, PoolName? pool, string setupContainer, bool? reboot, Guid? unique);
public record struct BucketId(Os os, Guid jobId, (string, string)? vm, PoolName? pool, string setupContainer, bool? reboot, Guid? unique);
private ILookup<BucketId, Task> BucketTasks(IEnumerable<Task> tasks) {
public static ILookup<BucketId, Task> BucketTasks(IEnumerable<Task> tasks) {
// buckets are hashed by:
// OS, JOB ID, vm sku & image (if available), pool name (if available),
@ -214,8 +214,19 @@ public class Scheduler : IScheduler {
unique = Guid.NewGuid();
}
return new BucketId(task.Os, task.JobId, vm, pool, _config.GetSetupContainer(task.Config), task.Config.Task.RebootAfterSetup, unique);
return new BucketId(task.Os, task.JobId, vm, pool, GetSetupContainer(task.Config), task.Config.Task.RebootAfterSetup, unique);
});
}
static string GetSetupContainer(TaskConfig config) {
foreach (var container in config.Containers ?? throw new Exception("Missing containers")) {
if (container.Type == ContainerType.Setup) {
return container.Name.ContainerName;
}
}
throw new Exception($"task missing setup container: task_type = {config.Task.Type}");
}
}

View File

@ -20,9 +20,9 @@ public interface IStorage {
public Uri GetBlobEndpoint(string accountId);
public Async.Task<(string?, string?)> GetStorageAccountNameAndKey(string accountId);
public Async.Task<(string, string)> GetStorageAccountNameAndKey(string accountId);
public Async.Task<string?> GetStorageAccountNameAndKeyByName(string accountName);
public Async.Task<string?> GetStorageAccountNameKeyByName(string accountName);
public IEnumerable<string> GetAccounts(StorageType storageType);
}
@ -99,16 +99,16 @@ public class Storage : IStorage {
};
}
public async Async.Task<(string?, string?)> GetStorageAccountNameAndKey(string accountId) {
public async Async.Task<(string, string)> GetStorageAccountNameAndKey(string accountId) {
var resourceId = new ResourceIdentifier(accountId);
var armClient = GetMgmtClient();
var storageAccount = armClient.GetStorageAccountResource(resourceId);
var keys = await storageAccount.GetKeysAsync();
var key = keys.Value.Keys.FirstOrDefault();
return (resourceId.Name, key?.Value);
var key = keys.Value.Keys.FirstOrDefault() ?? throw new Exception("no keys found");
return (resourceId.Name, key.Value);
}
public async Async.Task<string?> GetStorageAccountNameAndKeyByName(string accountName) {
public async Async.Task<string?> GetStorageAccountNameKeyByName(string accountName) {
var armClient = GetMgmtClient();
var resourceGroup = _creds.GetResourceGroupResourceIdentifier();
var storageAccount = await armClient.GetResourceGroupResource(resourceGroup).GetStorageAccountAsync(accountName);

View File

@ -5,6 +5,10 @@ namespace Microsoft.OneFuzz.Service;
public interface ITaskOperations : IStatefulOrm<Task, TaskState> {
Async.Task<Task?> GetByTaskId(Guid taskId);
IAsyncEnumerable<Task> GetByTaskIds(IEnumerable<Guid> taskId);
IAsyncEnumerable<Task> GetByJobId(Guid jobId);
Async.Task<Task?> GetByJobIdAndTaskId(Guid jobId, Guid taskId);
@ -22,7 +26,7 @@ public interface ITaskOperations : IStatefulOrm<Task, TaskState> {
Async.Task<Task> SetState(Task task, TaskState state);
}
public class TaskOperations : StatefulOrm<Task, TaskState>, ITaskOperations {
public class TaskOperations : StatefulOrm<Task, TaskState, TaskOperations>, ITaskOperations {
public TaskOperations(ILogTracer log, IOnefuzzContext context)
@ -31,9 +35,15 @@ public class TaskOperations : StatefulOrm<Task, TaskState>, ITaskOperations {
}
public async Async.Task<Task?> GetByTaskId(Guid taskId) {
var data = QueryAsync(filter: $"RowKey eq '{taskId}'");
return await GetByTaskIds(new[] { taskId }).FirstOrDefaultAsync();
}
return await data.FirstOrDefaultAsync();
public IAsyncEnumerable<Task> GetByTaskIds(IEnumerable<Guid> taskId) {
return QueryAsync(filter: Query.RowKeys(taskId.Select(t => t.ToString())));
}
public IAsyncEnumerable<Task> GetByJobId(Guid jobId) {
return QueryAsync(filter: $"PartitionKey eq '{jobId}'");
}
public async Async.Task<Task?> GetByJobIdAndTaskId(Guid jobId, Guid taskId) {
@ -42,21 +52,13 @@ public class TaskOperations : StatefulOrm<Task, TaskState>, ITaskOperations {
return await data.FirstOrDefaultAsync();
}
public IAsyncEnumerable<Task> SearchStates(Guid? jobId = null, IEnumerable<TaskState>? states = null) {
var queryString = String.Empty;
if (jobId != null) {
queryString += $"PartitionKey eq '{jobId}'";
}
if (states != null) {
if (jobId != null) {
queryString += " and ";
}
queryString += "(" + string.Join(
" or ",
states.Select(s => $"state eq '{s}'")
) + ")";
}
var queryString =
(jobId, states) switch {
(null, null) => "",
(Guid id, null) => Query.PartitionKey($"{id}"),
(null, IEnumerable<TaskState> s) => Query.EqualAnyEnum("state", s),
(Guid id, IEnumerable<TaskState> s) => Query.And(Query.PartitionKey($"{id}"), Query.EqualAnyEnum("state", s)),
};
return QueryAsync(filter: queryString);
}
@ -209,7 +211,7 @@ public class TaskOperations : StatefulOrm<Task, TaskState>, ITaskOperations {
// if a prereq task fails, then mark this task as failed
if (t == null) {
await MarkFailed(task, new Error(ErrorCode.INVALID_REQUEST, Errors: new[] { "unable to find task" }));
await MarkFailed(task, new Error(ErrorCode.INVALID_REQUEST, Errors: new[] { "unable to find prereq task" }));
return false;
}
@ -253,4 +255,33 @@ public class TaskOperations : StatefulOrm<Task, TaskState>, ITaskOperations {
return null;
}
public async Async.Task<Task> Init(Task task) {
await _context.Queue.CreateQueue($"{task.TaskId}", StorageType.Corpus);
return await SetState(task, TaskState.Waiting);
}
public async Async.Task<Task> Stopping(Task task) {
_logTracer.Info($"stopping task : {task.JobId}, {task.TaskId}");
await _context.NodeOperations.StopTask(task.TaskId);
var anyRemainingNodes = await _context.NodeTasksOperations.GetNodesByTaskId(task.TaskId).AnyAsync();
if (!anyRemainingNodes) {
return await Stopped(task);
}
return task;
}
private async Async.Task<Task> Stopped(Task inputTask) {
var task = await SetState(inputTask, TaskState.Stopped);
await _context.Queue.DeleteQueue($"{task.TaskId}", StorageType.Corpus);
// # TODO: we need to 'unschedule' this task from the existing pools
var job = await _context.JobOperations.Get(task.JobId);
if (job != null) {
await _context.JobOperations.StopIfAllDone(job);
}
return task;
}
}

View File

@ -210,8 +210,8 @@ public class EntityConverter {
return Guid.Parse(entity.GetString(ef.kind.ToString()));
else if (ef.type == typeof(int))
return int.Parse(entity.GetString(ef.kind.ToString()));
else if (ef.type == typeof(PoolName))
// TODO: this should be able to be generic over any ValidatedString
else if (ef.type == typeof(PoolName))
// TODO: this should be able to be generic over any ValidatedString
return PoolName.Parse(entity.GetString(ef.kind.ToString()));
else {
throw new Exception($"invalid partition or row key type of {info.type} property {name}: {ef.type}");

View File

@ -61,7 +61,7 @@ namespace ApiService.OneFuzzLib.Orm {
public async Task<ResultVoid<(int, string)>> Replace(T entity) {
var tableClient = await GetTableClient(typeof(T).Name);
var tableEntity = _entityConverter.ToTableEntity(entity);
var response = await tableClient.UpsertEntityAsync(tableEntity);
var response = await tableClient.UpsertEntityAsync(tableEntity, TableUpdateMode.Replace);
if (response.IsError) {
return ResultVoid<(int, string)>.Error((response.Status, response.ReasonPhrase));
} else {
@ -97,7 +97,7 @@ namespace ApiService.OneFuzzLib.Orm {
var account = accountId ?? _context.ServiceConfiguration.OneFuzzFuncStorage ?? throw new ArgumentNullException(nameof(accountId));
var (name, key) = await _context.Storage.GetStorageAccountNameAndKey(account);
var endpoint = _context.Storage.GetTableEndpoint(account);
var endpoint = _context.Storage.GetTableEndpoint(name);
var tableClient = new TableServiceClient(endpoint, new TableSharedKeyCredential(name, key));
await tableClient.CreateTableIfNotExistsAsync(tableName);
return tableClient.GetTableClient(tableName);
@ -136,13 +136,42 @@ namespace ApiService.OneFuzzLib.Orm {
}
public class StatefulOrm<T, TState> : Orm<T>, IStatefulOrm<T, TState> where T : StatefulEntityBase<TState> where TState : Enum {
public class StatefulOrm<T, TState, Self> : Orm<T>, IStatefulOrm<T, TState> where T : StatefulEntityBase<TState> where TState : Enum {
static Lazy<Func<object>>? _partitionKeyGetter;
static Lazy<Func<object>>? _rowKeyGetter;
static ConcurrentDictionary<string, Func<T, Async.Task<T>>?> _stateFuncs = new ConcurrentDictionary<string, Func<T, Async.Task<T>>?>();
delegate Async.Task<T> StateTransition(T entity);
static StatefulOrm() {
/// verify that all state transition function have the correct signature:
var thisType = typeof(Self);
var states = Enum.GetNames(typeof(TState));
var delegateType = typeof(StateTransition);
MethodInfo delegateSignature = delegateType.GetMethod("Invoke")!;
foreach (var state in states) {
var methodInfo = thisType?.GetMethod(state.ToString());
if (methodInfo == null) {
continue;
}
bool parametersEqual = delegateSignature
.GetParameters()
.Select(x => x.ParameterType)
.SequenceEqual(methodInfo.GetParameters()
.Select(x => x.ParameterType));
if (delegateSignature.ReturnType == methodInfo.ReturnType && parametersEqual) {
continue;
}
throw new Exception($"State transition method '{state}' in '{thisType?.Name}' does not have the correct signature. Expected '{delegateSignature}' actual '{methodInfo}' ");
};
_partitionKeyGetter =
typeof(T).GetProperties().FirstOrDefault(p => p.GetCustomAttributes(true).OfType<PartitionKeyAttribute>().Any())?.GetMethod switch {
null => null,
@ -167,11 +196,10 @@ namespace ApiService.OneFuzzLib.Orm {
/// <returns></returns>
public async Async.Task<T?> ProcessStateUpdate(T entity) {
TState state = entity.State;
var func = _stateFuncs.GetOrAdd(state.ToString(), (string k) =>
typeof(T).GetMethod(k) switch {
null => null,
MethodInfo info => (Func<T, Async.Task<T>>)Delegate.CreateDelegate(typeof(Func<T, Async.Task<T>>), info)
});
var func = GetType().GetMethod(state.ToString()) switch {
null => null,
MethodInfo info => info.CreateDelegate<StateTransition>(this)
};
if (func != null) {
_logTracer.Info($"processing state update: {typeof(T)} - PartitionKey {_partitionKeyGetter?.Value()} {_rowKeyGetter?.Value()} - %s");

View File

@ -11,12 +11,15 @@ namespace ApiService.OneFuzzLib.Orm {
public static string PartitionKey(string partitionKey)
=> TableClient.CreateQueryFilter($"PartitionKey eq {partitionKey}");
public static string PartitionKey(Guid partitionKey)
=> TableClient.CreateQueryFilter($"PartitionKey eq {partitionKey}");
public static string RowKey(string rowKey)
=> TableClient.CreateQueryFilter($"RowKey eq {rowKey}");
public static string PartitionKeys(IEnumerable<string> partitionKeys)
=> Or(partitionKeys.Select(PartitionKey));
public static string RowKeys(IEnumerable<string> rowKeys)
=> Or(rowKeys.Select(RowKey));
public static string SingleEntity(string partitionKey, string rowKey)
=> TableClient.CreateQueryFilter($"(PartitionKey eq {partitionKey}) and (RowKey eq {rowKey})");

View File

@ -25,21 +25,31 @@ sealed class AzureStorage : IStorage {
return new AzureStorage(accountName, accountKey);
}
public string? AccountName { get; }
public string? AccountKey { get; }
public string AccountName { get; }
public string AccountKey { get; }
public AzureStorage(string? accountName, string? accountKey) {
public AzureStorage(string accountName, string accountKey) {
AccountName = accountName;
AccountKey = accountKey;
}
public Task<(string?, string?)> GetStorageAccountNameAndKey(string accountId)
=> Async.Task.FromResult((AccountName, AccountKey));
public IEnumerable<string> CorpusAccounts() {
throw new System.NotImplementedException();
}
public IEnumerable<string> GetAccounts(StorageType storageType) {
if (AccountName != null) {
yield return AccountName;
}
yield return AccountName;
}
public string GetPrimaryAccount(StorageType storageType) {
throw new System.NotImplementedException();
}
public Task<(string, string)> GetStorageAccountNameAndKey(string accountId)
=> Async.Task.FromResult((AccountName, AccountKey));
public Task<string?> GetStorageAccountNameKeyByName(string accountName) {
return Async.Task.FromResult(AccountName)!;
}
public Uri GetTableEndpoint(string accountId)
@ -51,15 +61,4 @@ sealed class AzureStorage : IStorage {
public Uri GetBlobEndpoint(string accountId)
=> new($"https://{AccountName}.blob.core.windows.net/");
public IEnumerable<string> CorpusAccounts() {
throw new System.NotImplementedException();
}
public string GetPrimaryAccount(StorageType storageType) {
throw new System.NotImplementedException();
}
public Task<string?> GetStorageAccountNameAndKeyByName(string accountName) {
throw new System.NotImplementedException();
}
}

View File

@ -22,15 +22,11 @@ sealed class AzuriteStorage : IStorage {
const string AccountName = "devstoreaccount1";
const string AccountKey = "Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw==";
public Task<(string?, string?)> GetStorageAccountNameAndKey(string _accountId)
=> Async.Task.FromResult<(string?, string?)>((AccountName, AccountKey));
public Task<(string, string)> GetStorageAccountNameAndKey(string accountId)
=> Async.Task.FromResult((AccountName, AccountKey));
public IEnumerable<string> GetAccounts(StorageType storageType) {
yield return AccountName;
}
public Task<string?> GetStorageAccountNameAndKeyByName(string accountName) {
throw new System.NotImplementedException();
public Task<string?> GetStorageAccountNameKeyByName(string accountName) {
return Async.Task.FromResult(AccountName)!;
}
public IEnumerable<string> CorpusAccounts() {
@ -40,4 +36,8 @@ sealed class AzuriteStorage : IStorage {
public string GetPrimaryAccount(StorageType storageType) {
throw new System.NotImplementedException();
}
public IEnumerable<string> GetAccounts(StorageType storageType) {
yield return AccountName;
}
}

View File

@ -35,6 +35,8 @@ public class RequestsTests {
var deserialized = (T?)_serializer.Deserialize(stream, typeof(T), CancellationToken.None);
var reserialized = _serializer.Serialize(deserialized);
var result = Encoding.UTF8.GetString(reserialized);
result = result.Replace(System.Environment.NewLine, "\n");
json = json.Replace(System.Environment.NewLine, "\n");
Assert.Equal(json, result);
}

View File

@ -0,0 +1,133 @@
using System;
using System.Collections.Generic;
using System.Linq;
using Microsoft.OneFuzz.Service;
using Xunit;
namespace Tests;
public class SchedulerTests {
IEnumerable<Task> BuildTasks(int size) {
return Enumerable.Range(0, size).Select(i =>
new Task(
Guid.Empty,
Guid.NewGuid(),
TaskState.Init,
Os.Linux,
new TaskConfig(
Guid.Empty,
null,
new TaskDetails(
Type: TaskType.LibfuzzerFuzz,
Duration: 1,
TargetExe: "fuzz.exe",
TargetEnv: new Dictionary<string, string>(),
TargetOptions: new List<string>()),
Pool: new TaskPool(1, PoolName.Parse("pool")),
Containers: new List<TaskContainers> { new TaskContainers(ContainerType.Setup, new Container("setup")) },
Colocate: true
),
null,
null,
null,
null,
null)
);
}
[Fact]
public void TestAllColocate() {
// all tasks should land in one bucket
var tasks = BuildTasks(10).Select(task => task with { Config = task.Config with { Colocate = true } }
).ToList();
var buckets = Scheduler.BucketTasks(tasks);
var bucket = Assert.Single(buckets);
Assert.True(10 >= bucket.Count());
CheckBuckets(buckets, tasks, 1);
}
[Fact]
public void TestPartialColocate() {
// 2 tasks should land on their own, the rest should be colocated into a
// single bucket.
var tasks = BuildTasks(10).Select((task, i) => {
return i switch {
0 => task with { Config = task.Config with { Colocate = null } },
1 => task with { Config = task.Config with { Colocate = false } },
_ => task
};
}).ToList();
var buckets = Scheduler.BucketTasks(tasks);
var lengths = buckets.Select(b => b.Count()).OrderBy(x => x);
Assert.Equal(new[] { 1, 1, 8 }, lengths);
CheckBuckets(buckets, tasks, 3);
}
[Fact]
public void TestAlluniqueJob() {
// everything has a unique job_id
var tasks = BuildTasks(10).Select(task => {
var jobId = Guid.NewGuid();
return task with { JobId = jobId, Config = task.Config with { JobId = jobId } };
}).ToList();
var buckets = Scheduler.BucketTasks(tasks);
foreach (var bucket in buckets) {
Assert.True(1 >= bucket.Count());
}
CheckBuckets(buckets, tasks, 10);
}
[Fact]
public void TestMultipleJobBuckets() {
// at most 3 tasks per bucket, by job_id
var tasks = BuildTasks(10).Chunk(3).SelectMany(taskChunk => {
var jobId = Guid.NewGuid();
return taskChunk.Select(task => task with { JobId = jobId, Config = task.Config with { JobId = jobId } });
}).ToList();
var buckets = Scheduler.BucketTasks(tasks);
foreach (var bucket in buckets) {
Assert.True(3 >= bucket.Count());
}
CheckBuckets(buckets, tasks, 4);
}
[Fact]
public void TestManyBuckets() {
var jobId = Guid.Parse("00000000-0000-0000-0000-000000000001");
var tasks = BuildTasks(100).Select((task, i) => {
var containers = new List<TaskContainers>(task.Config.Containers!);
if (i % 4 == 0) {
containers[0] = containers[0] with { Name = new Container("setup2") };
}
return task with {
JobId = i % 2 == 0 ? jobId : task.JobId,
Os = i % 3 == 0 ? Os.Windows : task.Os,
Config = task.Config with {
JobId = i % 2 == 0 ? jobId : task.Config.JobId,
Containers = containers,
Pool = i % 5 == 0 ? task.Config.Pool! with { PoolName = PoolName.Parse("alternate-pool") } : task.Config.Pool
}
};
}).ToList();
var buckets = Scheduler.BucketTasks(tasks);
CheckBuckets(buckets, tasks, 12);
}
void CheckBuckets(ILookup<Scheduler.BucketId, Task> buckets, List<Task> tasks, int bucketCount) {
Assert.Equal(buckets.Count, bucketCount);
foreach (var task in tasks) {
Assert.Single(buckets, b => b.Contains(task));
}
}
}

View File

@ -8,12 +8,6 @@ public class ValidatedStringTests {
record ThingContainingPoolName(PoolName PoolName);
[Fact]
public void PoolNameValidatesOnDeserialization() {
var ex = Assert.Throws<JsonException>(() => JsonSerializer.Deserialize<ThingContainingPoolName>("{ \"PoolName\": \"is-not!-a-pool\" }"));
Assert.Equal("unable to parse input as a PoolName", ex.Message);
}
[Fact]
public void PoolNameDeserializesFromString() {
var result = JsonSerializer.Deserialize<ThingContainingPoolName>("{ \"PoolName\": \"is-a-pool\" }");