From 4ecd43fc0fb80d0ff5511c2d9a61f702c6c55aed Mon Sep 17 00:00:00 2001 From: George Pollard Date: Fri, 17 Jun 2022 10:28:09 +1200 Subject: [PATCH] Convert `agent_events` to C# & add tests (#2032) ### Function implementation Added implementation of `agent_events` function and some basic tests. Fixed several issues encountered in `TestOperations`. ### Additional JsonConverter The existing Python code for `agent_events` figures out which class to use only based upon the shape of the input, without a discriminator. I have added an additional `JsonConverter` named `SubclassConverter` which will pick a subclass of `T` to deserialize into based upon what properties exist in the JSON. ### Enum helpers Converted some enum helpers to extension methods to make them a bit more readable. --- src/ApiService/ApiService/AgentCanSchedule.cs | 2 +- src/ApiService/ApiService/AgentEvents.cs | 284 +++++++++++++++++ src/ApiService/ApiService/ApiService.csproj | 1 + .../ApiService/OneFuzzTypes/Converters.cs | 101 ++++++ .../ApiService/OneFuzzTypes/Enums.cs | 38 ++- .../ApiService/OneFuzzTypes/Model.cs | 102 +++--- .../ApiService/OneFuzzTypes/Requests.cs | 66 +++- src/ApiService/ApiService/Program.cs | 1 + src/ApiService/ApiService/TimerTasks.cs | 4 +- .../ApiService/onefuzzlib/NodeOperations.cs | 6 +- .../onefuzzlib/NotificationOperations.cs | 2 +- .../ApiService/onefuzzlib/OnefuzzContext.cs | 3 +- .../onefuzzlib/ScalesetOperations.cs | 2 +- .../onefuzzlib/TaskEventOperations.cs | 11 + .../ApiService/onefuzzlib/TaskOperations.cs | 15 +- .../ApiService/onefuzzlib/orm/Queries.cs | 58 ++-- src/ApiService/ApiService/packages.lock.json | 6 + src/ApiService/Tests/Fakes/TestContext.cs | 7 +- .../Tests/Fakes/TestServiceConfiguration.cs | 5 +- .../Tests/Functions/AgentEventsTests.cs | 291 ++++++++++++++++++ .../Tests/Functions/_FunctionTestBase.cs | 15 +- .../Tests/Integration/AzureStorage.cs | 4 +- src/ApiService/Tests/RequestsTests.cs | 192 ++++++++++++ src/ApiService/Tests/packages.lock.json | 8 +- 24 files changed, 1100 insertions(+), 124 deletions(-) create mode 100644 src/ApiService/ApiService/AgentEvents.cs create mode 100644 src/ApiService/ApiService/OneFuzzTypes/Converters.cs create mode 100644 src/ApiService/ApiService/onefuzzlib/TaskEventOperations.cs create mode 100644 src/ApiService/Tests/Functions/AgentEventsTests.cs create mode 100644 src/ApiService/Tests/RequestsTests.cs diff --git a/src/ApiService/ApiService/AgentCanSchedule.cs b/src/ApiService/ApiService/AgentCanSchedule.cs index 69799a0ef..577ba8d28 100644 --- a/src/ApiService/ApiService/AgentCanSchedule.cs +++ b/src/ApiService/ApiService/AgentCanSchedule.cs @@ -44,7 +44,7 @@ public class AgentCanSchedule { } var task = await _context.TaskOperations.GetByTaskId(canScheduleRequest.TaskId); - workStopped = task == null || TaskStateHelper.ShuttingDown.Contains(task.State); + workStopped = task == null || task.State.ShuttingDown(); if (allowed) { allowed = (await _context.NodeOperations.AcquireScaleInProtection(node)).IsOk; diff --git a/src/ApiService/ApiService/AgentEvents.cs b/src/ApiService/ApiService/AgentEvents.cs new file mode 100644 index 000000000..2f1dff6a8 --- /dev/null +++ b/src/ApiService/ApiService/AgentEvents.cs @@ -0,0 +1,284 @@ +using System.Threading.Tasks; +using Microsoft.Azure.Functions.Worker; +using Microsoft.Azure.Functions.Worker.Http; +using Microsoft.OneFuzz.Service.OneFuzzLib.Orm; + +namespace Microsoft.OneFuzz.Service; + +public class AgentEvents { + private readonly ILogTracer _log; + + private readonly IOnefuzzContext 
_context; + + public AgentEvents(ILogTracer log, IOnefuzzContext context) { + _log = log; + _context = context; + } + + private static readonly EntityConverter _entityConverter = new(); + + // [Function("AgentEvents")] + public async Async.Task Run([HttpTrigger("post")] HttpRequestData req) { + var request = await RequestHandling.ParseRequest(req); + if (!request.IsOk || request.OkV == null) { + return await _context.RequestHandling.NotOk(req, request.ErrorV, context: "node event"); + } + + var envelope = request.OkV; + _log.Info($"node event: machine_id: {envelope.MachineId} event: {_entityConverter.ToJsonString(envelope)}"); + + var error = envelope.Event switch { + NodeStateUpdate updateEvent => await OnStateUpdate(envelope.MachineId, updateEvent), + WorkerEvent workerEvent => await OnWorkerEvent(envelope.MachineId, workerEvent), + NodeEvent nodeEvent => await OnNodeEvent(envelope.MachineId, nodeEvent), + _ => new Error(ErrorCode.INVALID_REQUEST, new string[] { $"invalid node event: {envelope.Event.GetType().Name}" }), + }; + + if (error is Error e) { + return await _context.RequestHandling.NotOk(req, e, context: "node event"); + } else { + return await RequestHandling.Ok(req, new BoolResult(true)); + } + } + + private async Async.Task OnNodeEvent(Guid machineId, NodeEvent nodeEvent) { + if (nodeEvent.StateUpdate is not null) { + var result = await OnStateUpdate(machineId, nodeEvent.StateUpdate); + if (result is not null) { + return result; + } + } + + if (nodeEvent.WorkerEvent is not null) { + var result = await OnWorkerEvent(machineId, nodeEvent.WorkerEvent); + if (result is not null) { + return result; + } + } + + return null; + } + + private async Async.Task OnStateUpdate(Guid machineId, NodeStateUpdate ev) { + var node = await _context.NodeOperations.GetByMachineId(machineId); + if (node is null) { + _log.Warning($"unable to process state update event. machine_id:{machineId} state event:{ev}"); + return null; + } + + if (ev.State == NodeState.Free) { + if (node.ReimageRequested || node.DeleteRequested) { + _log.Info($"stopping free node with reset flags: {machineId}"); + await _context.NodeOperations.Stop(node); + return null; + } + + if (_context.NodeOperations.CouldShrinkScaleset(node)) { + _log.Info($"stopping free node to resize scaleset: {machineId}"); + await _context.NodeOperations.SetHalt(node); + return null; + } + } + + if (ev.State == NodeState.Init) { + if (node.DeleteRequested) { + _log.Info($"stopping node (init and delete_requested): {machineId}"); + await _context.NodeOperations.Stop(node); + return null; + } + + // Don’t check reimage_requested, as nodes only send 'init' state once. If + // they send 'init' with reimage_requested, it's because the node was reimaged + // successfully. + node = node with { ReimageRequested = false, InitializedAt = DateTimeOffset.UtcNow }; + await _context.NodeOperations.SetState(node, ev.State); + return null; + } + + _log.Info($"node state update: {machineId} from {node.State} to {ev.State}"); + await _context.NodeOperations.SetState(node, ev.State); + + if (ev.State == NodeState.Free) { + _log.Info($"node now available for work: {machineId}"); + } else if (ev.State == NodeState.SettingUp) { + if (ev.Data is NodeSettingUpEventData settingUpData) { + if (!settingUpData.Tasks.Any()) { + return new Error(ErrorCode.INVALID_REQUEST, Errors: new string[] { + $"setup without tasks. 
machine_id: {machineId}", + }); + } + + foreach (var taskId in settingUpData.Tasks) { + var task = await _context.TaskOperations.GetByTaskId(taskId); + if (task is null) { + return new Error( + ErrorCode.INVALID_REQUEST, + Errors: new string[] { $"unable to find task: {taskId}" }); + } + + _log.Info($"node starting task. machine_id: {machineId} job_id: {task.JobId} task_id: {task.TaskId}"); + + // The task state may be `running` if it has `vm_count` > 1, and + // another node is concurrently executing the task. If so, leave + // the state as-is, to represent the max progress made. + // + // Other states we would want to preserve are excluded by the + // outermost conditional check. + if (task.State != TaskState.Running && task.State != TaskState.SettingUp) { + await _context.TaskOperations.SetState(task, TaskState.SettingUp); + } + + var nodeTask = new NodeTasks( + MachineId: machineId, + TaskId: task.TaskId, + State: NodeTaskState.SettingUp); + await _context.NodeTasksOperations.Replace(nodeTask); + } + } + } else if (ev.State == NodeState.Done) { + Error? error = null; + if (ev.Data is NodeDoneEventData doneData) { + if (doneData.Error is not null) { + var errorText = _entityConverter.ToJsonString(doneData); + error = new Error(ErrorCode.TASK_FAILED, Errors: new string[] { errorText }); + _log.Error($"node 'done' with error: machine_id:{machineId}, data:{errorText}"); + } + } + + // if tasks are running on the node when it reports as Done + // those are stopped early + await _context.NodeOperations.MarkTasksStoppedEarly(node, error); + await _context.NodeOperations.ToReimage(node, done: true); + } + + return null; + } + + private async Async.Task OnWorkerEvent(Guid machineId, WorkerEvent ev) { + if (ev.Done is not null) { + return await OnWorkerEventDone(machineId, ev.Done); + } + + if (ev.Running is not null) { + return await OnWorkerEventRunning(machineId, ev.Running); + } + + return new Error( + Code: ErrorCode.INVALID_REQUEST, + Errors: new string[] { "WorkerEvent should have either 'done' or 'running' set" }); + } + + private async Async.Task OnWorkerEventRunning(Guid machineId, WorkerRunningEvent running) { + var (task, node) = await ( + _context.TaskOperations.GetByTaskId(running.TaskId), + _context.NodeOperations.GetByMachineId(machineId)); + + if (task is null) { + return new Error( + Code: ErrorCode.INVALID_REQUEST, + Errors: new string[] { $"unable to find task: {running.TaskId}" }); + } + + if (node is null) { + return new Error( + Code: ErrorCode.INVALID_REQUEST, + Errors: new string[] { $"unable to find node: {machineId}" }); + } + + if (!node.State.ReadyForReset()) { + await _context.NodeOperations.SetState(node, NodeState.Busy); + } + + var nodeTask = new NodeTasks( + MachineId: machineId, + TaskId: running.TaskId, + State: NodeTaskState.Running); + await _context.NodeTasksOperations.Replace(nodeTask); + + if (task.State.ShuttingDown()) { + _log.Info($"ignoring task start from node. machine_id:{machineId} job_id:{task.JobId} task_id:{task.TaskId} (state: {task.State})"); + return null; + } + + _log.Info($"task started on node. 
machine_id:{machineId} job_id:{task.JobId} task_id:{task.TaskId}"); + await _context.TaskOperations.SetState(task, TaskState.Running); + + var taskEvent = new TaskEvent( + TaskId: task.TaskId, + MachineId: machineId, + EventData: new WorkerEvent(Running: running)); + await _context.TaskEventOperations.Replace(taskEvent); + + return null; + } + + private async Async.Task OnWorkerEventDone(Guid machineId, WorkerDoneEvent done) { + var (task, node) = await ( + _context.TaskOperations.GetByTaskId(done.TaskId), + _context.NodeOperations.GetByMachineId(machineId)); + + if (task is null) { + return new Error( + Code: ErrorCode.INVALID_REQUEST, + Errors: new string[] { $"unable to find task: {done.TaskId}" }); + } + + if (node is null) { + return new Error( + Code: ErrorCode.INVALID_REQUEST, + Errors: new string[] { $"unable to find node: {machineId}" }); + } + + // trim stdout/stderr if too long + done = done with { + Stderr = LimitText(done.Stderr), + Stdout = LimitText(done.Stdout), + }; + + if (done.ExitStatus.Success) { + _log.Info($"task done. {task.JobId}:{task.TaskId} status:{done.ExitStatus}"); + await _context.TaskOperations.MarkStopping(task); + + // keep node if keep-on-completion is set + if (task.Config.Debug?.Contains(TaskDebugFlag.KeepNodeOnCompletion) == true) { + node = node with { DebugKeepNode = true }; + await _context.NodeOperations.Replace(node); + } + } else { + await _context.TaskOperations.MarkFailed( + task, + new Error( + Code: ErrorCode.TASK_FAILED, + Errors: new string[] { + $"task failed. exit_status:{done.ExitStatus}", + done.Stdout, + done.Stderr, + })); + + // keep node if any keep options are set + if ((task.Config.Debug?.Contains(TaskDebugFlag.KeepNodeOnFailure) == true) + || (task.Config.Debug?.Contains(TaskDebugFlag.KeepNodeOnCompletion) == true)) { + node = node with { DebugKeepNode = true }; + await _context.NodeOperations.Replace(node); + } + } + + if (!node.DebugKeepNode) { + await _context.NodeTasksOperations.Delete(new NodeTasks(machineId, done.TaskId)); + } + + var taskEvent = new TaskEvent(done.TaskId, machineId, new WorkerEvent { Done = done }); + await _context.TaskEventOperations.Replace(taskEvent); + return null; + } + + private static string LimitText(string str) { + const int MAX_OUTPUT_SIZE = 4096; + + if (str.Length <= MAX_OUTPUT_SIZE) { + return str; + } + + return str[..MAX_OUTPUT_SIZE]; + } +} diff --git a/src/ApiService/ApiService/ApiService.csproj b/src/ApiService/ApiService/ApiService.csproj index cf23b6f12..f3bd578b8 100644 --- a/src/ApiService/ApiService/ApiService.csproj +++ b/src/ApiService/ApiService/ApiService.csproj @@ -34,6 +34,7 @@ + diff --git a/src/ApiService/ApiService/OneFuzzTypes/Converters.cs b/src/ApiService/ApiService/OneFuzzTypes/Converters.cs new file mode 100644 index 000000000..587883748 --- /dev/null +++ b/src/ApiService/ApiService/OneFuzzTypes/Converters.cs @@ -0,0 +1,101 @@ +using System.Diagnostics; +using System.Text.Json; +using System.Text.Json.Serialization; +using Faithlife.Utility; +using Microsoft.OneFuzz.Service.OneFuzzLib.Orm; + +namespace Microsoft.OneFuzz.Service; + +// SubclassConverter allows serializing and deserializing a set of subclasses +// of the given T abstract base class, as long as all their properties are disjoint. +// +// It identifies which subclass to deserialize based upon the properties provided in the JSON. 
+public sealed class SubclassConverter<T> : JsonConverter<T> {
+    private static readonly IReadOnlyList<(HashSet<string> props, Type type)> ChildTypes = FindChildTypes(typeof(T));
+
+    private static List<(HashSet<string>, Type)> FindChildTypes(Type t) {
+        if (!t.IsAbstract) {
+            throw new ArgumentException("SubclassConverter can only be applied to abstract base classes");
+        }
+
+        // NB: assumes that the naming converter will always be the same, so we don’t need to regenerate the names each time
+        var namer = new OnefuzzNamingPolicy();
+
+        var result = new List<(HashSet<string> props, Type type)>();
+        foreach (var type in t.Assembly.ExportedTypes) {
+            if (type == t) {
+                // skip the type itself
+                continue;
+            }
+
+            if (type.IsAssignableTo(t)) {
+                var props = type.GetProperties().Select(p => namer.ConvertName(p.Name)).ToHashSet();
+                result.Add((props, type));
+            }
+        }
+
+        // ensure that property names are all distinct
+        for (int i = 0; i < result.Count; ++i) {
+            for (int j = 0; j < result.Count; ++j) {
+                if (i == j) {
+                    continue;
+                }
+
+                var intersection = result[i].props.Intersect(result[j].props);
+                if (intersection.Any()) {
+                    throw new ArgumentException(
+                        "Cannot use SubclassConverter on types with overlapping property names: "
+                        + $" {result[i].type} and {result[j].type} share properties: {intersection.Join(", ")}");
+                }
+            }
+        }
+
+        return result;
+    }
+
+    private static Type FindType(Utf8JsonReader reader) {
+        // note that this takes the reader by value instead of by 'ref'
+        // this means it won't affect the reader passed in, which can be
+        // used to deserialize the whole object
+
+        if (reader.TokenType != JsonTokenType.StartObject) {
+            throw new JsonException($"Expected to be reading object, not {reader.TokenType}");
+        }
+
+        if (!reader.Read() || reader.TokenType != JsonTokenType.PropertyName) {
+            throw new JsonException("Unable to read object property name");
+        }
+
+        var propertyName = reader.GetString();
+        if (propertyName is null) {
+            throw new JsonException("Unable to get property name");
+        }
+
+        foreach (var (props, type) in ChildTypes) {
+            if (props.Contains(propertyName)) {
+                return type;
+            }
+        }
+
+        throw new JsonException($"No subclass found with property '{propertyName}'");
+    }
+
+    public override bool CanConvert(Type typeToConvert) {
+        return typeToConvert == typeof(T) || ChildTypes.Any(x => x.type == typeToConvert);
+    }
+
+    public override T? 
Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) {
+        Debug.Assert(options.PropertyNamingPolicy?.GetType() == typeof(OnefuzzNamingPolicy)); // see NB above
+
+        var type = FindType(reader);
+        return (T?)JsonSerializer.Deserialize(ref reader, type, options);
+    }
+
+    public override void Write(Utf8JsonWriter writer, T value, JsonSerializerOptions options) {
+        Debug.Assert(options.PropertyNamingPolicy?.GetType() == typeof(OnefuzzNamingPolicy)); // see NB above
+        Debug.Assert(value != null);
+
+        // Note: we invoke GetType to get the derived type to serialize:
+        JsonSerializer.Serialize(writer, value, value.GetType(), options);
+    }
+}
diff --git a/src/ApiService/ApiService/OneFuzzTypes/Enums.cs b/src/ApiService/ApiService/OneFuzzTypes/Enums.cs
index f5e03149e..f6fe57237 100644
--- a/src/ApiService/ApiService/OneFuzzTypes/Enums.cs
+++ b/src/ApiService/ApiService/OneFuzzTypes/Enums.cs
@@ -173,18 +173,25 @@ public static class VmStateHelper {
 public static class TaskStateHelper {
-    private static readonly IReadOnlySet<TaskState> _available = new HashSet<TaskState> { TaskState.Waiting, TaskState.Scheduled, TaskState.SettingUp, TaskState.Running, TaskState.WaitJob };
-    private static readonly IReadOnlySet<TaskState> _needsWork = new HashSet<TaskState> { TaskState.Init, TaskState.Stopping };
-    private static readonly IReadOnlySet<TaskState> _shuttingDown = new HashSet<TaskState> { TaskState.Stopping, TaskState.Stopped };
-    private static readonly IReadOnlySet<TaskState> _hasStarted = new HashSet<TaskState> { TaskState.Running, TaskState.Stopping, TaskState.Stopped };
+    public static readonly IReadOnlySet<TaskState> AvailableStates =
+        new HashSet<TaskState> { TaskState.Waiting, TaskState.Scheduled, TaskState.SettingUp, TaskState.Running, TaskState.WaitJob };
-    public static IReadOnlySet<TaskState> Available => _available;
+    public static readonly IReadOnlySet<TaskState> NeedsWorkStates =
+        new HashSet<TaskState> { TaskState.Init, TaskState.Stopping };
-    public static IReadOnlySet<TaskState> NeedsWork => _needsWork;
+    public static readonly IReadOnlySet<TaskState> ShuttingDownStates =
+        new HashSet<TaskState> { TaskState.Stopping, TaskState.Stopped };
-    public static IReadOnlySet<TaskState> ShuttingDown => _shuttingDown;
+    public static readonly IReadOnlySet<TaskState> HasStartedStates =
+        new HashSet<TaskState> { TaskState.Running, TaskState.Stopping, TaskState.Stopped };
-    public static IReadOnlySet<TaskState> HasStarted => _hasStarted;
+    public static bool Available(this TaskState state) => AvailableStates.Contains(state);
+
+    public static bool NeedsWork(this TaskState state) => NeedsWorkStates.Contains(state);
+
+    public static bool ShuttingDown(this TaskState state) => ShuttingDownStates.Contains(state);
+
+    public static bool HasStarted(this TaskState state) => HasStartedStates.Contains(state);
 }
 public enum PoolState {
@@ -277,18 +284,21 @@ public enum NodeState {
 }
 public static class NodeStateHelper {
+    private static readonly IReadOnlySet<NodeState> _needsWork =
+        new HashSet<NodeState>(new[] { NodeState.Done, NodeState.Shutdown, NodeState.Halt });
-    private static readonly IReadOnlySet<NodeState> _needsWork = new HashSet<NodeState>(new[] { NodeState.Done, NodeState.Shutdown, NodeState.Halt });
-    private static readonly IReadOnlySet<NodeState> _readyForReset = new HashSet<NodeState>(new[] { NodeState.Done, NodeState.Shutdown, NodeState.Halt });
-    private static readonly IReadOnlySet<NodeState> _canProcessNewWork = new HashSet<NodeState>(new[] { NodeState.Free });
+    private static readonly IReadOnlySet<NodeState> _readyForReset
+        = new HashSet<NodeState>(new[] { NodeState.Done, NodeState.Shutdown, NodeState.Halt });
+    private static readonly IReadOnlySet<NodeState> _canProcessNewWork =
+        new HashSet<NodeState>(new[] { NodeState.Free });
-    public static IReadOnlySet<NodeState> NeedsWork => _needsWork;
+    public static 
bool NeedsWork(this NodeState state) => _needsWork.Contains(state); ///If Node is in one of these states, ignore updates from the agent. - public static IReadOnlySet ReadyForReset => _readyForReset; + public static bool ReadyForReset(this NodeState state) => _readyForReset.Contains(state); - public static IReadOnlySet CanProcessNewWork => _canProcessNewWork; + public static bool CanProcessNewWork(this NodeState state) => _canProcessNewWork.Contains(state); } diff --git a/src/ApiService/ApiService/OneFuzzTypes/Model.cs b/src/ApiService/ApiService/OneFuzzTypes/Model.cs index 508983596..7fd86ca08 100644 --- a/src/ApiService/ApiService/OneFuzzTypes/Model.cs +++ b/src/ApiService/ApiService/OneFuzzTypes/Model.cs @@ -67,8 +67,8 @@ public enum NodeTaskState { public record NodeTasks ( - Guid MachineId, - Guid TaskId, + [PartitionKey] Guid MachineId, + [RowKey] Guid TaskId, NodeTaskState State = NodeTaskState.Init ) : StatefulEntityBase(State); @@ -153,44 +153,40 @@ public record Error(ErrorCode Code, string[]? Errors = null); public record UserInfo(Guid? ApplicationId, Guid? ObjectId, String? Upn); - - public record TaskDetails( - TaskType Type, int Duration, - string? TargetExe, - Dictionary? TargetEnv, - List? TargetOptions, - int? TargetWorkers, - bool? TargetOptionsMerge, - bool? CheckAsanLog, - bool? CheckDebugger, - int? CheckRetryCount, - bool? CheckFuzzerHelp, - bool? ExpectCrashOnFailure, - bool? RenameOutput, - string? SupervisorExe, - Dictionary? SupervisorEnv, - List? SupervisorOptions, - string? SupervisorInputMarker, - string? GeneratorExe, - Dictionary? GeneratorEnv, - List? GeneratorOptions, - string? AnalyzerExe, - Dictionary? AnalyzerEnv, - List AnalyzerOptions, - ContainerType? WaitForFiles, - string? StatsFile, - StatsFormat? StatsFormat, - bool? RebootAfterSetup, - int? TargetTimeout, - int? EnsembleSyncDelay, - bool? PreserveExistingOutputs, - List? ReportList, - int? MinimizedStackDepth, - string? CoverageFilter -); + string? TargetExe = null, + Dictionary? TargetEnv = null, + List? TargetOptions = null, + int? TargetWorkers = null, + bool? TargetOptionsMerge = null, + bool? CheckAsanLog = null, + bool? CheckDebugger = null, + int? CheckRetryCount = null, + bool? CheckFuzzerHelp = null, + bool? ExpectCrashOnFailure = null, + bool? RenameOutput = null, + string? SupervisorExe = null, + Dictionary? SupervisorEnv = null, + List? SupervisorOptions = null, + string? SupervisorInputMarker = null, + string? GeneratorExe = null, + Dictionary? GeneratorEnv = null, + List? GeneratorOptions = null, + string? AnalyzerExe = null, + Dictionary? AnalyzerEnv = null, + List? AnalyzerOptions = null, + ContainerType? WaitForFiles = null, + string? StatsFile = null, + StatsFormat? StatsFormat = null, + bool? RebootAfterSetup = null, + int? TargetTimeout = null, + int? EnsembleSyncDelay = null, + bool? PreserveExistingOutputs = null, + List? ReportList = null, + int? MinimizedStackDepth = null, + string? CoverageFilter = null); public record TaskVm( Region Region, @@ -210,18 +206,17 @@ public record TaskContainers( ContainerType Type, Container Name ); + public record TaskConfig( Guid JobId, List? PrereqTasks, TaskDetails Task, - TaskVm? Vm, - TaskPool? Pool, - List? Containers, - Dictionary? Tags, - List? Debug, - bool? Colocate - ); - + TaskVm? Vm = null, + TaskPool? Pool = null, + List? Containers = null, + Dictionary? Tags = null, + List? Debug = null, + bool? Colocate = null); public record TaskEventSummary( DateTimeOffset? 
Timestamp, @@ -243,14 +238,21 @@ public record Task( TaskState State, Os Os, TaskConfig Config, - Error? Error, - Authentication? Auth, - DateTimeOffset? Heartbeat, - DateTimeOffset? EndTime, - UserInfo? UserInfo) : StatefulEntityBase(State) { + Error? Error = null, + Authentication? Auth = null, + DateTimeOffset? Heartbeat = null, + DateTimeOffset? EndTime = null, + UserInfo? UserInfo = null) : StatefulEntityBase(State) { List Events { get; set; } = new List(); List Nodes { get; set; } = new List(); } + +public record TaskEvent( + [PartitionKey, RowKey] Guid TaskId, + Guid MachineId, + WorkerEvent EventData +) : EntityBase; + public record AzureSecurityExtensionConfig(); public record GenevaExtensionConfig(); diff --git a/src/ApiService/ApiService/OneFuzzTypes/Requests.cs b/src/ApiService/ApiService/OneFuzzTypes/Requests.cs index 3b88a21b0..7f7432f9f 100644 --- a/src/ApiService/ApiService/OneFuzzTypes/Requests.cs +++ b/src/ApiService/ApiService/OneFuzzTypes/Requests.cs @@ -1,4 +1,6 @@ -namespace Microsoft.OneFuzz.Service; +using System.Text.Json.Serialization; + +namespace Microsoft.OneFuzz.Service; public record BaseRequest(); @@ -15,3 +17,65 @@ public record NodeCommandDelete( Guid MachineId, string MessageId ) : BaseRequest; + +public record NodeStateEnvelope( + NodeEventBase Event, + Guid MachineId +) : BaseRequest; + +// either NodeEvent or WorkerEvent +[JsonConverter(typeof(SubclassConverter))] +public abstract record NodeEventBase; + +public record NodeEvent( + [property: JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + NodeStateUpdate? StateUpdate, + [property: JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + WorkerEvent? WorkerEvent +) : NodeEventBase; + +public record WorkerEvent( + [property: JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + WorkerDoneEvent? Done = null, + [property: JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + WorkerRunningEvent? Running = null +) : NodeEventBase; + +public record WorkerRunningEvent( + Guid TaskId); + +public record WorkerDoneEvent( + Guid TaskId, + ExitStatus ExitStatus, + string Stderr, + string Stdout); + +public record NodeStateUpdate( + NodeState State, + [property: JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingNull)] + NodeStateData? Data = null +) : NodeEventBase; + +// NodeSettingUpEventData, NodeDoneEventData, or ProcessOutput +[JsonConverter(typeof(SubclassConverter))] +public abstract record NodeStateData; + +public record NodeSettingUpEventData( + List Tasks +) : NodeStateData; + +public record NodeDoneEventData( + string? Error, + ProcessOutput? ScriptOutput +) : NodeStateData; + +public record ProcessOutput( + ExitStatus ExitStatus, + string Stderr, + string Stdout +) : NodeStateData; + +public record ExitStatus( + int? Code, + int? 
Signal, + bool Success); diff --git a/src/ApiService/ApiService/Program.cs b/src/ApiService/ApiService/Program.cs index 08665aeb9..9eb81208a 100644 --- a/src/ApiService/ApiService/Program.cs +++ b/src/ApiService/ApiService/Program.cs @@ -78,6 +78,7 @@ public class Program { .AddScoped() .AddScoped() .AddScoped() + .AddScoped() .AddScoped() .AddScoped() .AddScoped() diff --git a/src/ApiService/ApiService/TimerTasks.cs b/src/ApiService/ApiService/TimerTasks.cs index ab90c1db9..6f066de26 100644 --- a/src/ApiService/ApiService/TimerTasks.cs +++ b/src/ApiService/ApiService/TimerTasks.cs @@ -44,7 +44,7 @@ public class TimerTasks { await _jobOperations.ProcessStateUpdates(job); } - var tasks = _taskOperations.SearchStates(states: TaskStateHelper.NeedsWork); + var tasks = _taskOperations.SearchStates(states: TaskStateHelper.NeedsWorkStates); await foreach (var task in tasks) { _logger.Info($"update task: {task.TaskId}"); await _taskOperations.ProcessStateUpdate(task); @@ -55,5 +55,3 @@ public class TimerTasks { await _jobOperations.StopNeverStartedJobs(); } } - - diff --git a/src/ApiService/ApiService/onefuzzlib/NodeOperations.cs b/src/ApiService/ApiService/onefuzzlib/NodeOperations.cs index fe6ab740c..24b40c9c6 100644 --- a/src/ApiService/ApiService/onefuzzlib/NodeOperations.cs +++ b/src/ApiService/ApiService/onefuzzlib/NodeOperations.cs @@ -98,12 +98,12 @@ public class NodeOperations : StatefulOrm, INodeOperations { return false; } - if (!NodeStateHelper.CanProcessNewWork.Contains(node.State)) { + if (!node.State.CanProcessNewWork()) { _logTracer.Info($"can_process_new_work node not in appropriate state for new work machine_id:{node.MachineId} state:{node.State}"); return false; } - if (NodeStateHelper.ReadyForReset.Contains(node.State)) { + if (node.State.ReadyForReset()) { _logTracer.Info($"can_process_new_work node is set for reset. 
machine_id:{node.MachineId}"); return false; } @@ -175,7 +175,7 @@ public class NodeOperations : StatefulOrm, INodeOperations { var nodeState = node.State; if (done) { - if (!NodeStateHelper.ReadyForReset.Contains(node.State)) { + if (!node.State.ReadyForReset()) { nodeState = NodeState.Done; } } diff --git a/src/ApiService/ApiService/onefuzzlib/NotificationOperations.cs b/src/ApiService/ApiService/onefuzzlib/NotificationOperations.cs index 8bfcb3e4e..5f205fac6 100644 --- a/src/ApiService/ApiService/onefuzzlib/NotificationOperations.cs +++ b/src/ApiService/ApiService/onefuzzlib/NotificationOperations.cs @@ -79,7 +79,7 @@ public class NotificationOperations : Orm, INotificationOperations public IAsyncEnumerable<(Task, IEnumerable)> GetQueueTasks() { // Nullability mismatch: We filter tuples where the containers are null - return _context.TaskOperations.SearchStates(states: TaskStateHelper.Available) + return _context.TaskOperations.SearchStates(states: TaskStateHelper.AvailableStates) .Select(task => (task, _context.TaskOperations.GetInputContainerQueues(task.Config))) .Where(taskTuple => taskTuple.Item2 != null)!; } diff --git a/src/ApiService/ApiService/onefuzzlib/OnefuzzContext.cs b/src/ApiService/ApiService/onefuzzlib/OnefuzzContext.cs index 9c7f68162..60cbcbec4 100644 --- a/src/ApiService/ApiService/onefuzzlib/OnefuzzContext.cs +++ b/src/ApiService/ApiService/onefuzzlib/OnefuzzContext.cs @@ -29,6 +29,7 @@ public interface IOnefuzzContext { IServiceConfig ServiceConfiguration { get; } IStorage Storage { get; } ITaskOperations TaskOperations { get; } + ITaskEventOperations TaskEventOperations { get; } IUserCredentials UserCredentials { get; } IVmOperations VmOperations { get; } IVmssOperations VmssOperations { get; } @@ -46,6 +47,7 @@ public class OnefuzzContext : IOnefuzzContext { public IWebhookOperations WebhookOperations { get => _serviceProvider.GetService() ?? throw new Exception("No IWebhookOperations service"); } public IWebhookMessageLogOperations WebhookMessageLogOperations { get => _serviceProvider.GetService() ?? throw new Exception("No IWebhookMessageLogOperations service"); } public ITaskOperations TaskOperations { get => _serviceProvider.GetService() ?? throw new Exception("No ITaskOperations service"); } + public ITaskEventOperations TaskEventOperations => _serviceProvider.GetRequiredService(); public IQueue Queue { get => _serviceProvider.GetService() ?? throw new Exception("No IQueue service"); } public IStorage Storage { get => _serviceProvider.GetService() ?? throw new Exception("No IStorage service"); } public IProxyOperations ProxyOperations { get => _serviceProvider.GetService() ?? 
throw new Exception("No IProxyOperations service"); } @@ -79,4 +81,3 @@ public class OnefuzzContext : IOnefuzzContext { _serviceProvider = serviceProvider; } } - diff --git a/src/ApiService/ApiService/onefuzzlib/ScalesetOperations.cs b/src/ApiService/ApiService/onefuzzlib/ScalesetOperations.cs index 5626179c4..a4d9b5b9a 100644 --- a/src/ApiService/ApiService/onefuzzlib/ScalesetOperations.cs +++ b/src/ApiService/ApiService/onefuzzlib/ScalesetOperations.cs @@ -191,7 +191,7 @@ public class ScalesetOperations : StatefulOrm, IScalese var nodesToReset = from x in existingNodes - where NodeStateHelper.ReadyForReset.Contains(x.State) + where x.State.ReadyForReset() select x; diff --git a/src/ApiService/ApiService/onefuzzlib/TaskEventOperations.cs b/src/ApiService/ApiService/onefuzzlib/TaskEventOperations.cs new file mode 100644 index 000000000..539d13d1e --- /dev/null +++ b/src/ApiService/ApiService/onefuzzlib/TaskEventOperations.cs @@ -0,0 +1,11 @@ +using ApiService.OneFuzzLib.Orm; + +namespace Microsoft.OneFuzz.Service; + +public interface ITaskEventOperations : IOrm { +} + +public sealed class TaskEventOperations : Orm, ITaskEventOperations { + public TaskEventOperations(ILogTracer logTracer, IOnefuzzContext context) + : base(logTracer, context) { } +} diff --git a/src/ApiService/ApiService/onefuzzlib/TaskOperations.cs b/src/ApiService/ApiService/onefuzzlib/TaskOperations.cs index 9bf78c228..b364f1633 100644 --- a/src/ApiService/ApiService/onefuzzlib/TaskOperations.cs +++ b/src/ApiService/ApiService/onefuzzlib/TaskOperations.cs @@ -71,19 +71,20 @@ public class TaskOperations : StatefulOrm, ITaskOperations { } public async Async.Task MarkStopping(Task task) { - if (TaskStateHelper.ShuttingDown.Contains(task.State)) { + if (task.State.ShuttingDown()) { _logTracer.Verbose($"ignoring post - task stop calls to stop {task.JobId}:{task.TaskId}"); return; } - if (TaskStateHelper.HasStarted.Contains(task.State)) { + if (!task.State.HasStarted()) { await MarkFailed(task, new Error(Code: ErrorCode.TASK_FAILED, Errors: new[] { "task never started" })); - + } else { + await SetState(task, TaskState.Stopping); } } public async Async.Task MarkFailed(Task task, Error error, List? taskInJob = null) { - if (TaskStateHelper.ShuttingDown.Contains(task.State)) { + if (task.State.ShuttingDown()) { _logTracer.Verbose( $"ignoring post-task stop failures for {task.JobId}:{task.TaskId}" ); @@ -105,7 +106,7 @@ public class TaskOperations : StatefulOrm, ITaskOperations { } private async Async.Task MarkDependantsFailed(Task task, List? taskInJob = null) { - taskInJob = taskInJob ?? 
await QueryAsync(filter: $"job_id eq ''{task.JobId}").ToListAsync(); + taskInJob ??= await SearchByPartitionKey(task.JobId.ToString()).ToListAsync(); foreach (var t in taskInJob) { if (t.Config.PrereqTasks != null) { @@ -123,6 +124,8 @@ public class TaskOperations : StatefulOrm, ITaskOperations { if (task.State == TaskState.Running || task.State == TaskState.SettingUp) { task = await OnStart(task with { State = state }); + } else { + task = task with { State = state }; } await this.Replace(task); @@ -210,7 +213,7 @@ public class TaskOperations : StatefulOrm, ITaskOperations { return false; } - if (!TaskStateHelper.HasStarted.Contains(t.State)) { + if (!t.State.HasStarted()) { return false; } } diff --git a/src/ApiService/ApiService/onefuzzlib/orm/Queries.cs b/src/ApiService/ApiService/onefuzzlib/orm/Queries.cs index 7d47f1064..51bf2c9c0 100644 --- a/src/ApiService/ApiService/onefuzzlib/orm/Queries.cs +++ b/src/ApiService/ApiService/onefuzzlib/orm/Queries.cs @@ -1,60 +1,58 @@ using System.Text.Json; +using Azure.Data.Tables; using Microsoft.OneFuzz.Service.OneFuzzLib.Orm; namespace ApiService.OneFuzzLib.Orm { public static class Query { - public static string PartitionKey(string partitionKey) { - // TODO: need to escape - return $"PartitionKey eq '{partitionKey}'"; - } + // For all queries below, note that TableClient.CreateQueryFilter takes a FormattableString + // and handles escaping the interpolated values properly. It also handles quoting the values + // where needed, so use {string} and not '{string}'. - public static string RowKey(string rowKey) { - // TODO: need to escape - return $"RowKey eq '{rowKey}'"; - } + public static string PartitionKey(string partitionKey) + => TableClient.CreateQueryFilter($"PartitionKey eq {partitionKey}"); - public static string SingleEntity(string partitionKey, string rowKey) { - // TODO: need to escape - return $"(PartitionKey eq '{partitionKey}') and (RowKey eq '{rowKey}')"; - } + public static string RowKey(string rowKey) + => TableClient.CreateQueryFilter($"RowKey eq {rowKey}"); - public static string Or(IEnumerable queries) { - return string.Join(" or ", queries.Select(x => $"({x})")); - } + public static string SingleEntity(string partitionKey, string rowKey) + => TableClient.CreateQueryFilter($"(PartitionKey eq {partitionKey}) and (RowKey eq {rowKey})"); - public static string Or(string q1, string q2) { - return Or(new[] { q1, q2 }); - } + public static string Or(IEnumerable queries) + // subqueries should already be properly escaped + => string.Join(" or ", queries.Select(x => $"({x})")); - public static string And(IEnumerable queries) { - return string.Join(" and ", queries.Select(x => $"({x})")); - } + public static string Or(string q1, string q2) => Or(new[] { q1, q2 }); - public static string And(string q1, string q2) { - return And(new[] { q1, q2 }); - } + public static string And(IEnumerable queries) + // subqueries should already be properly escaped + => string.Join(" and ", queries.Select(x => $"({x})")); + + public static string And(string q1, string q2) => And(new[] { q1, q2 }); public static string EqualAny(string property, IEnumerable values) { - return Or(values.Select(x => $"{property} eq '{x}'")); + // property should not be escaped, but the string should be: + return Or(values.Select(x => $"{property} eq '{EscapeString(x)}'")); } - public static string EqualAnyEnum(string property, IEnumerable enums) where T : Enum { IEnumerable convertedEnums = enums.Select(x => JsonSerializer.Serialize(x, 
EntityConverter.GetJsonSerializerOptions()).Trim('"')); return EqualAny(property, convertedEnums); } public static string TimeRange(DateTimeOffset min, DateTimeOffset max) { - // NB: this uses the auto-populated Timestamp property, and will result in scanning + // NB: this uses the auto-populated Timestamp property, and will result in a table scan // TODO: should this be inclusive at the endpoints? - return $"Timestamp lt datetime'{max:o}' and Timestamp gt datetime'{min:o}'"; + return TableClient.CreateQueryFilter($"Timestamp lt {max} and Timestamp gt {min}"); } public static string StartsWith(string property, string prefix) { var upperBound = prefix[..(prefix.Length - 1)] + (char)(prefix.Last() + 1); - // TODO: escaping - return $"{property} ge '{prefix}' and {property} lt '{upperBound}'"; + // property name should not be escaped, but strings should be: + return $"{property} ge '{EscapeString(prefix)}' and {property} lt '{EscapeString(upperBound)}'"; } + + // makes a string safe for interpolation between '…' + private static string EscapeString(string s) => s.Replace("'", "''"); } } diff --git a/src/ApiService/ApiService/packages.lock.json b/src/ApiService/ApiService/packages.lock.json index fbc2c2d78..934647070 100644 --- a/src/ApiService/ApiService/packages.lock.json +++ b/src/ApiService/ApiService/packages.lock.json @@ -302,6 +302,12 @@ "Microsoft.Bcl.AsyncInterfaces": "6.0.0" } }, + "TaskTupleAwaiter": { + "type": "Direct", + "requested": "[2.0.0, )", + "resolved": "2.0.0", + "contentHash": "rXkSI9t4vP2EaPhuchsWiD3elcLNth3UOZAlGohGmuckpkiOr57oMHuzM5WDzz7MJd+ZewE27/WfrZhhhFDHzA==" + }, "Azure.Storage.Common": { "type": "Transitive", "resolved": "12.10.0", diff --git a/src/ApiService/Tests/Fakes/TestContext.cs b/src/ApiService/Tests/Fakes/TestContext.cs index 3cdc8e970..c321db13b 100644 --- a/src/ApiService/Tests/Fakes/TestContext.cs +++ b/src/ApiService/Tests/Fakes/TestContext.cs @@ -20,6 +20,8 @@ public sealed class TestContext : IOnefuzzContext { NodeOperations = new NodeOperations(logTracer, this); JobOperations = new JobOperations(logTracer, this); NodeTasksOperations = new NodeTasksOperations(logTracer, this); + TaskEventOperations = new TaskEventOperations(logTracer, this); + NodeMessageOperations = new NodeMessageOperations(logTracer, this); } public TestEvents Events { get; set; } = new(); @@ -49,6 +51,8 @@ public sealed class TestContext : IOnefuzzContext { public IJobOperations JobOperations { get; } public INodeOperations NodeOperations { get; } public INodeTasksOperations NodeTasksOperations { get; } + public ITaskEventOperations TaskEventOperations { get; } + public INodeMessageOperations NodeMessageOperations { get; } // -- Remainder not implemented -- @@ -66,11 +70,8 @@ public sealed class TestContext : IOnefuzzContext { public IIpOperations IpOperations => throw new System.NotImplementedException(); - public ILogAnalytics LogAnalytics => throw new System.NotImplementedException(); - public INodeMessageOperations NodeMessageOperations => throw new System.NotImplementedException(); - public INotificationOperations NotificationOperations => throw new System.NotImplementedException(); public IPoolOperations PoolOperations => throw new System.NotImplementedException(); diff --git a/src/ApiService/Tests/Fakes/TestServiceConfiguration.cs b/src/ApiService/Tests/Fakes/TestServiceConfiguration.cs index 47bb74b5f..ac77f41cd 100644 --- a/src/ApiService/Tests/Fakes/TestServiceConfiguration.cs +++ b/src/ApiService/Tests/Fakes/TestServiceConfiguration.cs @@ -13,6 +13,8 @@ sealed class 
TestServiceConfiguration : IServiceConfig { public string? OneFuzzFuncStorage { get; } + public string OneFuzzVersion => "9999.0.0"; // very big version to pass any >= checks + // -- Remainder not implemented -- public LogDestination[] LogDestinations { get => throw new System.NotImplementedException(); set => throw new System.NotImplementedException(); } @@ -39,7 +41,6 @@ sealed class TestServiceConfiguration : IServiceConfig { public string? OneFuzzDataStorage => throw new System.NotImplementedException(); - public string? OneFuzzInstance => throw new System.NotImplementedException(); public string? OneFuzzInstanceName => throw new System.NotImplementedException(); @@ -55,6 +56,4 @@ sealed class TestServiceConfiguration : IServiceConfig { public string? OneFuzzResourceGroup => throw new System.NotImplementedException(); public string? OneFuzzTelemetry => throw new System.NotImplementedException(); - - public string OneFuzzVersion => throw new System.NotImplementedException(); } diff --git a/src/ApiService/Tests/Functions/AgentEventsTests.cs b/src/ApiService/Tests/Functions/AgentEventsTests.cs new file mode 100644 index 000000000..280025ab9 --- /dev/null +++ b/src/ApiService/Tests/Functions/AgentEventsTests.cs @@ -0,0 +1,291 @@ +using System; +using System.Linq; +using System.Net; +using Microsoft.OneFuzz.Service; +using Tests.Fakes; +using Xunit; +using Xunit.Abstractions; + +using Async = System.Threading.Tasks; + +namespace Tests.Functions; + +[Trait("Category", "Integration")] +public class AzureStorageAgentEventsTest : AgentEventsTestsBase { + public AzureStorageAgentEventsTest(ITestOutputHelper output) + : base(output, Integration.AzureStorage.FromEnvironment(), "UNUSED") { } +} + +public class AzuriteAgentEventsTest : AgentEventsTestsBase { + public AzuriteAgentEventsTest(ITestOutputHelper output) + : base(output, new Integration.AzuriteStorage(), "devstoreaccount1") { } +} + +public abstract class AgentEventsTestsBase : FunctionTestBase { + public AgentEventsTestsBase(ITestOutputHelper output, IStorage storage, string accountId) + : base(output, storage, accountId) { } + + // shared helper variables (per-test) + readonly Guid jobId = Guid.NewGuid(); + readonly Guid taskId = Guid.NewGuid(); + readonly Guid machineId = Guid.NewGuid(); + readonly string poolName = $"pool-{Guid.NewGuid()}"; + readonly Guid poolId = Guid.NewGuid(); + readonly string poolVersion = $"version-{Guid.NewGuid()}"; + + [Fact] + public async Async.Task WorkerEventMustHaveDoneOrRunningSet() { + var func = new AgentEvents(Logger, Context); + + var data = new NodeStateEnvelope( + MachineId: Guid.NewGuid(), + Event: new WorkerEvent(null, null)); + + var result = await func.Run(TestHttpRequestData.FromJson("POST", data)); + Assert.Equal(HttpStatusCode.BadRequest, result.StatusCode); + } + + + [Fact] + public async Async.Task WorkerDone_WithSuccessfulResult_ForRunningTask_MarksTaskAsStopping() { + await Context.InsertAll( + new Node(poolName, machineId, poolId, poolVersion), + // task state is running + new Task(jobId, taskId, TaskState.Running, Os.Linux, + new TaskConfig(jobId, null, new TaskDetails(TaskType.Coverage, 100)))); + + var func = new AgentEvents(Logger, Context); + + var data = new NodeStateEnvelope( + MachineId: machineId, + Event: new WorkerEvent(Done: new WorkerDoneEvent( + TaskId: taskId, + ExitStatus: new ExitStatus(Code: 0, Signal: 0, Success: true), + "stderr", + "stdout"))); + + var result = await func.Run(TestHttpRequestData.FromJson("POST", data)); + Assert.Equal(HttpStatusCode.OK, 
result.StatusCode); + + var task = await Context.TaskOperations.SearchAll().SingleAsync(); + + // should have transitioned into stopping + Assert.Equal(TaskState.Stopping, task.State); + } + + [Fact] + public async Async.Task WorkerDone_WithFailedResult_ForRunningTask_MarksTaskAsStoppingAndErrored() { + await Context.InsertAll( + new Node(poolName, machineId, poolId, poolVersion), + // task state is running + new Task(jobId, taskId, TaskState.Running, Os.Linux, + new TaskConfig(jobId, null, new TaskDetails(TaskType.Coverage, 100)))); + + + var func = new AgentEvents(Logger, Context); + + var data = new NodeStateEnvelope( + MachineId: machineId, + Event: new WorkerEvent(Done: new WorkerDoneEvent( + TaskId: taskId, + ExitStatus: new ExitStatus(Code: 0, Signal: 0, Success: false), // unsuccessful result + "stderr", + "stdout"))); + + var result = await func.Run(TestHttpRequestData.FromJson("POST", data)); + Assert.Equal(HttpStatusCode.OK, result.StatusCode); + + var task = await Context.TaskOperations.SearchAll().SingleAsync(); + Assert.Equal(TaskState.Stopping, task.State); // should have transitioned into stopping + Assert.Equal(ErrorCode.TASK_FAILED, task.Error?.Code); // should be an error + } + + [Fact] + public async Async.Task WorkerDone_ForNonStartedTask_MarksTaskAsFailed() { + await Context.InsertAll( + new Node(poolName, machineId, poolId, poolVersion), + // task state is scheduled, not running + new Task(jobId, taskId, TaskState.Scheduled, Os.Linux, + new TaskConfig(jobId, null, new TaskDetails(TaskType.Coverage, 100)))); + + var func = new AgentEvents(Logger, Context); + + var data = new NodeStateEnvelope( + MachineId: machineId, + Event: new WorkerEvent(Done: new WorkerDoneEvent( + TaskId: taskId, + ExitStatus: new ExitStatus(0, 0, true), + "stderr", + "stdout"))); + + var result = await func.Run(TestHttpRequestData.FromJson("POST", data)); + Assert.Equal(HttpStatusCode.OK, result.StatusCode); + + var task = await Context.TaskOperations.SearchAll().SingleAsync(); + + // should be failed - it never started running + Assert.Equal(TaskState.Stopping, task.State); + Assert.Equal(ErrorCode.TASK_FAILED, task.Error?.Code); + } + + [Fact] + public async Async.Task WorkerRunning_ForMissingTask_ReturnsError() { + await Context.InsertAll( + new Node(poolName, machineId, poolId, poolVersion)); + + var func = new AgentEvents(Logger, Context); + var data = new NodeStateEnvelope( + MachineId: machineId, + Event: new WorkerEvent(Running: new WorkerRunningEvent(taskId))); + + var result = await func.Run(TestHttpRequestData.FromJson("POST", data)); + Assert.Equal(HttpStatusCode.BadRequest, result.StatusCode); + Assert.Contains("unable to find task", BodyAsString(result)); + } + + [Fact] + public async Async.Task WorkerRunning_ForMissingNode_ReturnsError() { + await Context.InsertAll( + new Task(jobId, taskId, TaskState.Running, Os.Linux, + new TaskConfig(jobId, null, new TaskDetails(TaskType.Coverage, 0)))); + + var func = new AgentEvents(Logger, Context); + var data = new NodeStateEnvelope( + MachineId: machineId, + Event: new WorkerEvent(Running: new WorkerRunningEvent(taskId))); + + var result = await func.Run(TestHttpRequestData.FromJson("POST", data)); + Assert.Equal(HttpStatusCode.BadRequest, result.StatusCode); + Assert.Contains("unable to find node", BodyAsString(result)); + } + + [Fact] + public async Async.Task WorkerRunning_HappyPath() { + await Context.InsertAll( + new Node(poolName, machineId, poolId, poolVersion), + new Task(jobId, taskId, TaskState.Running, Os.Linux, + new 
TaskConfig(jobId, null, new TaskDetails(TaskType.Coverage, 0)))); + + var func = new AgentEvents(Logger, Context); + var data = new NodeStateEnvelope( + MachineId: machineId, + Event: new WorkerEvent(Running: new WorkerRunningEvent(taskId))); + + var result = await func.Run(TestHttpRequestData.FromJson("POST", data)); + Assert.Equal(HttpStatusCode.OK, result.StatusCode); + + // perform checks in parallel + await Async.Task.WhenAll( + Async.Task.Run(async () => { + // task should be marked running + var task = await Context.TaskOperations.SearchAll().SingleAsync(); + Assert.Equal(TaskState.Running, task.State); + }), + Async.Task.Run(async () => { + // node should now be marked busy + var node = await Context.NodeOperations.SearchAll().SingleAsync(); + Assert.Equal(NodeState.Busy, node.State); + }), + Async.Task.Run(async () => { + // there should be a node-task with correct values + var nodeTask = await Context.NodeTasksOperations.SearchAll().SingleAsync(); + Assert.Equal(machineId, nodeTask.MachineId); + Assert.Equal(taskId, nodeTask.TaskId); + Assert.Equal(NodeTaskState.Running, nodeTask.State); + }), + Async.Task.Run(async () => { + // there should be a task-event with correct values + var taskEvent = await Context.TaskEventOperations.SearchAll().SingleAsync(); + Assert.Equal(taskId, taskEvent.TaskId); + Assert.Equal(machineId, taskEvent.MachineId); + Assert.Equal(new WorkerEvent(Running: new WorkerRunningEvent(taskId)), taskEvent.EventData); + })); + } + + [Fact] + public async Async.Task NodeStateUpdate_ForMissingNode_IgnoresEvent() { + // nothing present in storage + + var func = new AgentEvents(Logger, Context); + var data = new NodeStateEnvelope( + MachineId: machineId, + Event: new NodeStateUpdate(NodeState.Init)); + + var result = await func.Run(TestHttpRequestData.FromJson("POST", data)); + Assert.Equal(HttpStatusCode.OK, result.StatusCode); + } + + + [Fact] + public async Async.Task NodeStateUpdate_CanTransitionFromInitToReady() { + await Context.InsertAll( + new Node(poolName, machineId, poolId, poolVersion, State: NodeState.Init)); + + var func = new AgentEvents(Logger, Context); + var data = new NodeStateEnvelope( + MachineId: machineId, + Event: new NodeStateUpdate(NodeState.Ready)); + + var result = await func.Run(TestHttpRequestData.FromJson("POST", data)); + Assert.Equal(HttpStatusCode.OK, result.StatusCode); + + var node = await Context.NodeOperations.SearchAll().SingleAsync(); + Assert.Equal(NodeState.Ready, node.State); + } + + [Fact] + public async Async.Task NodeStateUpdate_BecomingFree_StopsNode_IfMarkedForReimage() { + await Context.InsertAll( + new Node(poolName, machineId, poolId, poolVersion, ReimageRequested: true)); + + var func = new AgentEvents(Logger, Context); + var data = new NodeStateEnvelope( + MachineId: machineId, + Event: new NodeStateUpdate(NodeState.Free)); + + var result = await func.Run(TestHttpRequestData.FromJson("POST", data)); + Assert.Equal(HttpStatusCode.OK, result.StatusCode); + + await Async.Task.WhenAll( + Async.Task.Run(async () => { + // should still be in init state: + var node = await Context.NodeOperations.SearchAll().SingleAsync(); + Assert.Equal(NodeState.Init, node.State); + }), + Async.Task.Run(async () => { + // the node should be told to stop: + var messages = await Context.NodeMessageOperations.SearchAll().ToListAsync(); + Assert.Contains(messages, msg => + msg.MachineId == machineId && + msg.Message.Stop == new StopNodeCommand()); + })); + } + + [Fact] + public async Async.Task 
NodeStateUpdate_BecomingFree_StopsNode_IfMarkedForDeletion() { + await Context.InsertAll( + new Node(poolName, machineId, poolId, poolVersion, DeleteRequested: true)); + + var func = new AgentEvents(Logger, Context); + var data = new NodeStateEnvelope( + MachineId: machineId, + Event: new NodeStateUpdate(NodeState.Free)); + + var result = await func.Run(TestHttpRequestData.FromJson("POST", data)); + Assert.Equal(HttpStatusCode.OK, result.StatusCode); + + await Async.Task.WhenAll( + Async.Task.Run(async () => { + // the node should still be in init state: + var node = await Context.NodeOperations.SearchAll().SingleAsync(); + Assert.Equal(NodeState.Init, node.State); + }), + Async.Task.Run(async () => { + // the node should be told to stop: + var messages = await Context.NodeMessageOperations.SearchAll().ToListAsync(); + Assert.Contains(messages, msg => + msg.MachineId == machineId && + msg.Message.Stop == new StopNodeCommand()); + })); + } +} diff --git a/src/ApiService/Tests/Functions/_FunctionTestBase.cs b/src/ApiService/Tests/Functions/_FunctionTestBase.cs index 97670996a..e01b482d0 100644 --- a/src/ApiService/Tests/Functions/_FunctionTestBase.cs +++ b/src/ApiService/Tests/Functions/_FunctionTestBase.cs @@ -1,6 +1,8 @@ using System; +using System.IO; using ApiService.OneFuzzLib.Orm; using Azure.Data.Tables; +using Microsoft.Azure.Functions.Worker.Http; using Microsoft.OneFuzz.Service; using Tests.Fakes; using Xunit.Abstractions; @@ -24,16 +26,21 @@ public abstract class FunctionTestBase : IDisposable { // with each other - generate a prefix like t12345678 (table names must start with letter) private readonly string _tablePrefix = "t" + Guid.NewGuid().ToString()[..8]; - private readonly string _accountId; - protected ILogTracer Logger { get; } - protected TestContext CreateTestContext() => new(Logger, _storage, _tablePrefix, _accountId); + protected TestContext Context { get; } public FunctionTestBase(ITestOutputHelper output, IStorage storage, string accountId) { Logger = new TestLogTracer(output); _storage = storage; - _accountId = accountId; + + Context = new TestContext(Logger, _storage, _tablePrefix, accountId); + } + + protected static string BodyAsString(HttpResponseData data) { + data.Body.Seek(0, SeekOrigin.Begin); + using var sr = new StreamReader(data.Body); + return sr.ReadToEnd(); } public void Dispose() { diff --git a/src/ApiService/Tests/Integration/AzureStorage.cs b/src/ApiService/Tests/Integration/AzureStorage.cs index b69c0fdf9..a7d980e81 100644 --- a/src/ApiService/Tests/Integration/AzureStorage.cs +++ b/src/ApiService/Tests/Integration/AzureStorage.cs @@ -15,11 +15,11 @@ sealed class AzureStorage : IStorage { var accountKey = Environment.GetEnvironmentVariable("AZURE_ACCOUNT_KEY"); if (accountName is null) { - throw new Exception("AZURE_ACCOUNT_NAME must be set in environment to run integration tests"); + throw new Exception("AZURE_ACCOUNT_NAME must be set in environment to run integration tests (use --filter 'Category!=Integration' to skip them)"); } if (accountKey is null) { - throw new Exception("AZURE_ACCOUNT_KEY must be set in environment to run integration tests"); + throw new Exception("AZURE_ACCOUNT_KEY must be set in environment to run integration tests (use --filter 'Category!=Integration' to skip them)"); } return new AzureStorage(accountName, accountKey); diff --git a/src/ApiService/Tests/RequestsTests.cs b/src/ApiService/Tests/RequestsTests.cs new file mode 100644 index 000000000..176ee463f --- /dev/null +++ b/src/ApiService/Tests/RequestsTests.cs @@ -0,0 
+1,192 @@ +using System.IO; +using System.Text; +using System.Text.Json; +using System.Threading; +using Azure.Core.Serialization; +using Microsoft.OneFuzz.Service; +using Microsoft.OneFuzz.Service.OneFuzzLib.Orm; +using Xunit; + +namespace Tests; + +// This class contains tests for serialization and +// deserialization of examples generated by the +// onefuzz-agent’s `debug` sub-command. We test each +// example for roundtripping which ensures that no +// data is lost upon deserialization. +// +// We could set this up to run onefuzz-agent itself +// but that seems like additional unnecessary complexity; +// at the moment the Rust code is not built when building C#. +public class RequestsTests { + + private readonly JsonObjectSerializer _serializer = new(serializationOptions()); + + private static JsonSerializerOptions serializationOptions() { + // base on the serialization options used at runtime, but + // also indent to match inputs: + var result = EntityConverter.GetJsonSerializerOptions(); + result.WriteIndented = true; + return result; + } + + private void AssertRoundtrips(string json) { + var stream = new MemoryStream(Encoding.UTF8.GetBytes(json)); + var deserialized = (T?)_serializer.Deserialize(stream, typeof(T), CancellationToken.None); + var reserialized = _serializer.Serialize(deserialized); + var result = Encoding.UTF8.GetString(reserialized); + Assert.Equal(json, result); + } + + [Fact] + public void NodeEvent_WorkerEvent_Done() { + // generated with: onefuzz-agent debug node_event worker_event done + + AssertRoundtrips(@"{ + ""event"": { + ""worker_event"": { + ""done"": { + ""task_id"": ""00e1b131-e2a1-444d-8cc6-841e6cd48f93"", + ""exit_status"": { + ""code"": 0, + ""signal"": null, + ""success"": true + }, + ""stderr"": ""stderr output goes here"", + ""stdout"": ""stdout output goes here"" + } + } + }, + ""machine_id"": ""5ccbe157-a84c-486a-8171-d213fba27247"" +}"); + } + + [Fact] + public void NodeEvent_WorkerEvent_Running() { + // generated with: onefuzz-agent debug node_event worker_event running + + AssertRoundtrips(@"{ + ""event"": { + ""worker_event"": { + ""running"": { + ""task_id"": ""1763e113-02a0-4a3e-b477-92762f030d95"" + } + } + }, + ""machine_id"": ""e819efa5-c43f-46a2-bf9e-cc6a6de86ef9"" +}"); + } + + [Fact] + public void NodeEvent_StateUpdate_Init() { + // generated with: onefuzz-agent debug node_event state_update '"init"' + + AssertRoundtrips(@"{ + ""event"": { + ""state_update"": { + ""state"": ""init"" + } + }, + ""machine_id"": ""38bd035b-fa5b-4cbc-9037-aa4e6550f713"" +}"); + } + + [Fact] + public void NodeEvent_StateUpdate_Free() { + // generated with: onefuzz-agent debug node_event state_update '"free"' + + AssertRoundtrips(@"{ + ""event"": { + ""state_update"": { + ""state"": ""free"" + } + }, + ""machine_id"": ""09a0cd4c-a918-4777-98b6-617e42084eb1"" +}"); + } + + [Fact] + public void NodeEvent_StateUpdate_SettingUp() { + // generated with: onefuzz-agent debug node_event state_update '"setting_up"' + + AssertRoundtrips(@"{ + ""event"": { + ""state_update"": { + ""state"": ""setting_up"", + ""data"": { + ""tasks"": [ + ""163121e2-7df3-4567-9bd8-21b1653fac83"", + ""00604d49-b400-4877-8630-1d6ade31a61d"", + ""719a6316-98c4-4e77-9f3a-324f09505887"" + ] + } + } + }, + ""machine_id"": ""82da6784-fd8c-426a-8baf-643654a060d8"" +}"); + } + + + [Fact] + public void NodeEvent_StateUpdate_Rebooting() { + // generated with: onefuzz-agent debug node_event state_update '"rebooting"' + + AssertRoundtrips(@"{ + ""event"": { + ""state_update"": { + ""state"": 
""rebooting"" + } + }, + ""machine_id"": ""8825ca94-11d9-4e83-9df0-c052ee8b77c8"" +}"); + } + + + [Fact] + public void NodeEvent_StateUpdate_Ready() { + // generated with: onefuzz-agent debug node_event state_update '"ready"' + + AssertRoundtrips(@"{ + ""event"": { + ""state_update"": { + ""state"": ""ready"" + } + }, + ""machine_id"": ""a98f9a27-cfb9-426b-a6f2-5b2c04268697"" +}"); + } + + + [Fact] + public void NodeEvent_StateUpdate_Busy() { + // generated with: onefuzz-agent debug node_event state_update '"busy"' + + AssertRoundtrips(@"{ + ""event"": { + ""state_update"": { + ""state"": ""busy"" + } + }, + ""machine_id"": ""e4c70423-bb5c-40a9-9645-942243738240"" +}"); + } + + + [Fact] + public void NodeEvent_StateUpdate_Done() { + // generated with: onefuzz-agent debug node_event state_update '"done"' + + AssertRoundtrips(@"{ + ""event"": { + ""state_update"": { + ""state"": ""done"", + ""data"": { + ""error"": null, + ""script_output"": null + } + } + }, + ""machine_id"": ""5284cba4-aa7a-4285-b2b8-d5123c182bc3"" +}"); + } +} diff --git a/src/ApiService/Tests/packages.lock.json b/src/ApiService/Tests/packages.lock.json index 2a8cfab88..1a25fa420 100644 --- a/src/ApiService/Tests/packages.lock.json +++ b/src/ApiService/Tests/packages.lock.json @@ -2197,6 +2197,11 @@ "System.Xml.ReaderWriter": "4.3.0" } }, + "TaskTupleAwaiter": { + "type": "Transitive", + "resolved": "2.0.0", + "contentHash": "rXkSI9t4vP2EaPhuchsWiD3elcLNth3UOZAlGohGmuckpkiOr57oMHuzM5WDzz7MJd+ZewE27/WfrZhhhFDHzA==" + }, "xunit.abstractions": { "type": "Transitive", "resolved": "2.0.3", @@ -2273,7 +2278,8 @@ "Microsoft.Identity.Web.TokenCache": "1.23.1", "Semver": "2.1.0", "System.IdentityModel.Tokens.Jwt": "6.17.0", - "System.Linq.Async": "6.0.1" + "System.Linq.Async": "6.0.1", + "TaskTupleAwaiter": "2.0.0" } } }