Migrating QueueTaskHeartbeat (#1777)

* Migrating QueueTaskHeartbeat

* changing the name of the input queue

* rename type alias Tasks to Async

* Fix property casing

* fixing types

* Removing IStorageProvider

* fix function name

* address PR comments
This commit is contained in:
Cheick Keita
2022-04-13 17:19:44 -07:00
committed by GitHub
parent 31b9163514
commit 9d8d3327d2
14 changed files with 286 additions and 120 deletions

View File

@ -43,4 +43,67 @@ public enum WebhookMessageState
Retrying, Retrying,
Succeeded, Succeeded,
Failed Failed
}
public enum TaskState
{
Init,
Waiting,
Scheduled,
Setting_up,
Running,
Stopping,
Stopped,
WaitJob
}
public enum TaskType
{
Coverage,
LibfuzzerFuzz,
LibfuzzerCoverage,
LibfuzzerCrashReport,
LibfuzzerMerge,
LibfuzzerRegression,
GenericAnalysis,
GenericSupervisor,
GenericMerge,
GenericGenerator,
GenericCrashReport,
GenericRegression
}
public enum Os
{
Windows,
Linux
}
public enum ContainerType
{
Analysis,
Coverage,
Crashes,
Inputs,
NoRepro,
ReadonlyInputs,
Reports,
Setup,
Tools,
UniqueInputs,
UniqueReports,
RegressionReports,
Logs
}
public enum StatsFormat
{
AFL
}
public enum TaskDebugFlag
{
KeepNodeOnFailure,
KeepNodeOnCompletion,
} }

View File

@ -104,11 +104,11 @@ namespace Microsoft.OneFuzz.Service
// ) : BaseEvent(); // ) : BaseEvent();
//record EventTaskHeartbeat( record EventTaskHeartbeat(
// JobId: Guid, Guid JobId,
// TaskId: Guid, Guid TaskId,
// Config: TaskConfig TaskConfig Config
//): BaseEvent(); ) : BaseEvent();
//record EventPing( //record EventPing(

View File

@ -2,6 +2,8 @@ using Microsoft.OneFuzz.Service.OneFuzzLib.Orm;
using System; using System;
using System.Collections.Generic; using System.Collections.Generic;
using PoolName = System.String; using PoolName = System.String;
using Region = System.String;
using Container = System.String;
namespace Microsoft.OneFuzz.Service; namespace Microsoft.OneFuzz.Service;
@ -29,9 +31,15 @@ public enum HeartbeatType
TaskAlive, TaskAlive,
} }
public record HeartbeatData(HeartbeatType type); public record HeartbeatData(HeartbeatType Type);
public record NodeHeartbeatEntry(Guid NodeId, HeartbeatData[] data); public record TaskHeartbeatEntry(
Guid TaskId,
Guid? JobId,
Guid MachineId,
HeartbeatData[] Data
);
public record NodeHeartbeatEntry(Guid NodeId, HeartbeatData[] Data);
public record NodeCommandStopIfFree(); public record NodeCommandStopIfFree();
@ -79,7 +87,7 @@ public enum NodeState
public record ProxyHeartbeat public record ProxyHeartbeat
( (
string Region, Region Region,
Guid ProxyId, Guid ProxyId,
List<ProxyForward> Forwards, List<ProxyForward> Forwards,
DateTimeOffset TimeStamp DateTimeOffset TimeStamp
@ -102,35 +110,35 @@ public partial record Node
public partial record ProxyForward public partial record ProxyForward
( (
[PartitionKey] string Region, [PartitionKey] Region Region,
[RowKey] int DstPort, [RowKey] int DstPort,
int SrcPort, int SrcPort,
string DstIp string DstIp
) : EntityBase(); ) : EntityBase();
public partial record ProxyConfig public partial record ProxyConfig
( (
Uri Url, Uri Url,
string Notification, string Notification,
string Region, Region Region,
Guid? ProxyId, Guid? ProxyId,
List<ProxyForward> Forwards, List<ProxyForward> Forwards,
string InstanceTelemetryKey, string InstanceTelemetryKey,
string MicrosoftTelemetryKey string MicrosoftTelemetryKey
); );
public partial record Proxy public partial record Proxy
( (
[PartitionKey] string Region, [PartitionKey] Region Region,
[RowKey] Guid ProxyId, [RowKey] Guid ProxyId,
DateTimeOffset? CreatedTimestamp, DateTimeOffset? CreatedTimestamp,
VmState State, VmState State,
Authentication Auth, Authentication Auth,
string? Ip, string? Ip,
Error? Error, Error? Error,
string Version, string Version,
ProxyHeartbeat? heartbeat ProxyHeartbeat? heartbeat
) : EntityBase(); ) : EntityBase();
@ -148,23 +156,102 @@ public record EventMessage(
) : EntityBase(); ) : EntityBase();
//record AnyHttpUrl(AnyUrl): public record TaskDetails(
// allowed_schemes = {'http', 'https
// TaskType Type,
int Duration,
string? TargetExe,
Dictionary<string, string>? TargetEnv,
List<string>? TargetOptions,
int? TargetWorkers,
bool? TargetOptionsMerge,
bool? CheckAsanLog,
bool? CheckDebugger,
int? CheckRetryCount,
bool? CheckFuzzerHelp,
bool? ExpectCrashOnFailure,
bool? RenameOutput,
string? SupervisorExe,
Dictionary<string, string>? SupervisorEnv,
List<string>? SupervisorOptions,
string? SupervisorInputMarker,
string? GeneratorExe,
Dictionary<string, string>? GeneratorEnv,
List<string>? GeneratorOptions,
string? AnalyzerExe,
Dictionary<string, string>? AnalyzerEnv,
List<string> AnalyzerOptions,
ContainerType? WaitForFiles,
string? StatsFile,
StatsFormat? StatsFormat,
bool? RebootAfterSetup,
int? TargetTimeout,
int? EnsembleSyncDelay,
bool? PreserveExistingOutputs,
List<string>? ReportList,
int? MinimizedStackDepth,
string? CoverageFilter
);
public record TaskVm(
Region Region,
string Sku,
string Image,
int Count,
bool SpotInstance,
bool? RebootAfterSetup
);
public record TaskPool(
int Count,
PoolName PoolName
);
public record TaskContainers(
ContainerType Type,
Container Name
);
public record TaskConfig(
Guid JobId,
List<Guid>? PrereqTasks,
TaskDetails Task,
TaskVm? Vm,
TaskPool? Pool,
List<TaskContainers>? Containers,
Dictionary<string, string>? Tags,
List<TaskDebugFlag>? Debug,
bool? Colocate
);
public record TaskEventSummary(
DateTimeOffset? Timestamp,
string EventData,
string EventType
);
public record NodeAssignment(
Guid NodeId,
Guid? ScalesetId,
NodeTaskState State
);
//public record TaskConfig(
// Guid jobId,
// List<Guid> PrereqTasks,
// TaskDetails Task,
// TaskVm? vm,
// TaskPool pool: Optional[]
// containers: List[TaskContainers]
// tags: Dict[str, str]
// debug: Optional[List[TaskDebugFlag]]
// colocate: Optional[bool]
// ): EntityBase();
public record Task(
// Timestamp: Optional[datetime] = Field(alias="Timestamp")
[PartitionKey] Guid JobId,
[RowKey] Guid TaskId,
TaskState State,
Os Os,
TaskConfig Config,
Error? Error,
Authentication? Auth,
DateTimeOffset? Heartbeat,
DateTimeOffset? EndTime,
UserInfo? UserInfo) : EntityBase()
{
List<TaskEventSummary> Events { get; set; } = new List<TaskEventSummary>();
List<NodeAssignment> Nodes { get; set; } = new List<NodeAssignment>();
}

View File

@ -1,10 +1,14 @@
// to avoid collision with Task in model.cs
global using Async = System.Threading.Tasks;
using System; using System;
using System.Collections.Generic; using System.Collections.Generic;
using Microsoft.Extensions.Hosting; using Microsoft.Extensions.Hosting;
using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.DependencyInjection;
using Microsoft.OneFuzz.Service.OneFuzzLib.Orm;
using ApiService.OneFuzzLib; using ApiService.OneFuzzLib;
namespace Microsoft.OneFuzz.Service; namespace Microsoft.OneFuzz.Service;
public class Program public class Program
@ -34,11 +38,11 @@ public class Program
.ConfigureServices((context, services) => .ConfigureServices((context, services) =>
services services
.AddSingleton<ILogTracerFactory>(_ => new LogTracerFactory(GetLoggers())) .AddSingleton<ILogTracerFactory>(_ => new LogTracerFactory(GetLoggers()))
.AddSingleton<IStorageProvider>(_ => new StorageProvider(EnvironmentVariables.OneFuzz.FuncStorage ?? throw new InvalidOperationException("Missing account id")))
.AddSingleton<INodeOperations, NodeOperations>() .AddSingleton<INodeOperations, NodeOperations>()
.AddSingleton<IEvents, Events>() .AddSingleton<IEvents, Events>()
.AddSingleton<IWebhookOperations, WebhookOperations>() .AddSingleton<IWebhookOperations, WebhookOperations>()
.AddSingleton<IWebhookMessageLogOperations, WebhookMessageLogOperations>() .AddSingleton<IWebhookMessageLogOperations, WebhookMessageLogOperations>()
.AddSingleton<ITaskOperations, TaskOperations>()
.AddSingleton<IQueue, Queue>() .AddSingleton<IQueue, Queue>()
.AddSingleton<ICreds>(_ => new Creds()) .AddSingleton<ICreds>(_ => new Creds())
.AddSingleton<IStorage, Storage>() .AddSingleton<IStorage, Storage>()

View File

@ -2,7 +2,6 @@ using System;
using Microsoft.Azure.Functions.Worker; using Microsoft.Azure.Functions.Worker;
using System.Collections.Generic; using System.Collections.Generic;
using System.Text.Json; using System.Text.Json;
using System.Threading.Tasks;
using Microsoft.OneFuzz.Service.OneFuzzLib.Orm; using Microsoft.OneFuzz.Service.OneFuzzLib.Orm;
using System.Linq; using System.Linq;
@ -15,19 +14,17 @@ public class QueueFileChanges
const int MAX_DEQUEUE_COUNT = 5; const int MAX_DEQUEUE_COUNT = 5;
private readonly ILogTracerFactory _loggerFactory; private readonly ILogTracerFactory _loggerFactory;
private readonly IStorageProvider _storageProvider;
private readonly IStorage _storage; private readonly IStorage _storage;
public QueueFileChanges(ILogTracerFactory loggerFactory, IStorageProvider storageProvider, IStorage storage) public QueueFileChanges(ILogTracerFactory loggerFactory, IStorage storage)
{ {
_loggerFactory = loggerFactory; _loggerFactory = loggerFactory;
_storageProvider = storageProvider;
_storage = storage; _storage = storage;
} }
[Function("QueueFileChanges")] [Function("QueueFileChanges")]
public Task Run( public Async.Task Run(
[QueueTrigger("file-changes-refactored", Connection = "AzureWebJobsStorage")] string msg, [QueueTrigger("file-changes-refactored", Connection = "AzureWebJobsStorage")] string msg,
int dequeueCount) int dequeueCount)
{ {
@ -42,18 +39,18 @@ public class QueueFileChanges
if (!fileChangeEvent.ContainsKey(eventType) if (!fileChangeEvent.ContainsKey(eventType)
|| fileChangeEvent[eventType] != "Microsoft.Storage.BlobCreated") || fileChangeEvent[eventType] != "Microsoft.Storage.BlobCreated")
{ {
return Task.CompletedTask; return Async.Task.CompletedTask;
} }
const string topic = "topic"; const string topic = "topic";
if (!fileChangeEvent.ContainsKey(topic) if (!fileChangeEvent.ContainsKey(topic)
|| !_storage.CorpusAccounts(log).Contains(fileChangeEvent[topic])) || !_storage.CorpusAccounts(log).Contains(fileChangeEvent[topic]))
{ {
return Task.CompletedTask; return Async.Task.CompletedTask;
} }
file_added(log, fileChangeEvent, lastTry); file_added(log, fileChangeEvent, lastTry);
return Task.CompletedTask; return Async.Task.CompletedTask;
} }
private void file_added(ILogTracer log, Dictionary<string, string> fileChangeEvent, bool failTaskOnTransientError) private void file_added(ILogTracer log, Dictionary<string, string> fileChangeEvent, bool failTaskOnTransientError)

View File

@ -1,7 +1,6 @@
using System; using System;
using Microsoft.Azure.Functions.Worker; using Microsoft.Azure.Functions.Worker;
using System.Text.Json; using System.Text.Json;
using System.Threading.Tasks;
using Microsoft.OneFuzz.Service.OneFuzzLib.Orm; using Microsoft.OneFuzz.Service.OneFuzzLib.Orm;
namespace Microsoft.OneFuzz.Service; namespace Microsoft.OneFuzz.Service;
@ -22,7 +21,7 @@ public class QueueNodeHearbeat
} }
[Function("QueueNodeHearbeat")] [Function("QueueNodeHearbeat")]
public async Task Run([QueueTrigger("myqueue-items", Connection = "AzureWebJobsStorage")] string msg) public async Async.Task Run([QueueTrigger("myqueue-items", Connection = "AzureWebJobsStorage")] string msg)
{ {
var log = _loggerFactory.MakeLogTracer(Guid.NewGuid()); var log = _loggerFactory.MakeLogTracer(Guid.NewGuid());
log.Info($"heartbeat: {msg}"); log.Info($"heartbeat: {msg}");

View File

@ -1,7 +1,6 @@
using System; using System;
using Microsoft.Azure.Functions.Worker; using Microsoft.Azure.Functions.Worker;
using System.Text.Json; using System.Text.Json;
using System.Threading.Tasks;
using Microsoft.OneFuzz.Service.OneFuzzLib.Orm; using Microsoft.OneFuzz.Service.OneFuzzLib.Orm;
namespace Microsoft.OneFuzz.Service; namespace Microsoft.OneFuzz.Service;
@ -19,7 +18,7 @@ public class QueueProxyHearbeat
} }
[Function("QueueProxyHearbeat")] [Function("QueueProxyHearbeat")]
public async Task Run([QueueTrigger("myqueue-items", Connection = "AzureWebJobsStorage")] string msg) public async Async.Task Run([QueueTrigger("myqueue-items", Connection = "AzureWebJobsStorage")] string msg)
{ {
var log = _loggerFactory.MakeLogTracer(Guid.NewGuid()); var log = _loggerFactory.MakeLogTracer(Guid.NewGuid());

View File

@ -0,0 +1,43 @@
using System;
using Microsoft.Azure.Functions.Worker;
using Microsoft.Extensions.Logging;
using System.Text.Json;
using Microsoft.OneFuzz.Service.OneFuzzLib.Orm;
namespace Microsoft.OneFuzz.Service;
public class QueueTaskHearbeat
{
private readonly ILogger _logger;
private readonly IEvents _events;
private readonly ITaskOperations _tasks;
public QueueTaskHearbeat(ILoggerFactory loggerFactory, ITaskOperations tasks, IEvents events)
{
_logger = loggerFactory.CreateLogger<QueueTaskHearbeat>();
_tasks = tasks;
_events = events;
}
[Function("QueueTaskHearbeat")]
public async Async.Task Run([QueueTrigger("myqueue-items2", Connection = "AzureWebJobsStorage")] string msg)
{
_logger.LogInformation($"heartbeat: {msg}");
var hb = JsonSerializer.Deserialize<TaskHeartbeatEntry>(msg, EntityConverter.GetJsonSerializerOptions()).EnsureNotNull($"wrong data {msg}");
var task = await _tasks.GetByTaskId(hb.TaskId);
if (task == null)
{
_logger.LogWarning($"invalid task id: {hb.TaskId}");
return;
}
var newTask = task with { Heartbeat = DateTimeOffset.UtcNow };
await _tasks.Replace(newTask);
await _events.SendEvent(new EventTaskHeartbeat(newTask.JobId, newTask.TaskId, newTask.Config));
}
}

View File

@ -5,6 +5,7 @@ using System.Threading.Tasks;
using Microsoft.Azure.Functions.Worker.Http; using Microsoft.Azure.Functions.Worker.Http;
using Microsoft.IdentityModel.Tokens; using Microsoft.IdentityModel.Tokens;
namespace Microsoft.OneFuzz.Service; namespace Microsoft.OneFuzz.Service;
public class UserCredentials public class UserCredentials
@ -53,7 +54,7 @@ public class UserCredentials
static Task<OneFuzzResult<string[]>> GetAllowedTenants() static Task<OneFuzzResult<string[]>> GetAllowedTenants()
{ {
return Task.FromResult(OneFuzzResult<string[]>.Ok(Array.Empty<string>())); return Async.Task.FromResult(OneFuzzResult<string[]>.Ok(Array.Empty<string>()));
} }
/* /*

View File

@ -5,7 +5,6 @@ using System.Collections.Generic;
using System.Text; using System.Text;
using System.Text.Json; using System.Text.Json;
using System.Text.Json.Serialization; using System.Text.Json.Serialization;
using System.Threading.Tasks;
namespace Microsoft.OneFuzz.Service namespace Microsoft.OneFuzz.Service
{ {
@ -20,9 +19,9 @@ namespace Microsoft.OneFuzz.Service
public interface IEvents public interface IEvents
{ {
public Task SendEvent(BaseEvent anEvent); public Async.Task SendEvent(BaseEvent anEvent);
public Task QueueSignalrEvent(EventMessage message); public Async.Task QueueSignalrEvent(EventMessage message);
} }
public class Events : IEvents public class Events : IEvents
@ -38,14 +37,14 @@ namespace Microsoft.OneFuzz.Service
_webhook = webhook; _webhook = webhook;
} }
public async Task QueueSignalrEvent(EventMessage eventMessage) public async Async.Task QueueSignalrEvent(EventMessage eventMessage)
{ {
var message = new SignalREvent("events", new List<EventMessage>() { eventMessage }); var message = new SignalREvent("events", new List<EventMessage>() { eventMessage });
var encodedMessage = Encoding.UTF8.GetBytes(JsonSerializer.Serialize(message)); var encodedMessage = Encoding.UTF8.GetBytes(JsonSerializer.Serialize(message));
await _queue.SendMessage("signalr-events", encodedMessage, StorageType.Config); await _queue.SendMessage("signalr-events", encodedMessage, StorageType.Config);
} }
public async Task SendEvent(BaseEvent anEvent) public async Async.Task SendEvent(BaseEvent anEvent)
{ {
var log = _loggerFactory.MakeLogTracer(Guid.NewGuid()); var log = _loggerFactory.MakeLogTracer(Guid.NewGuid());
var eventType = anEvent.GetEventType(); var eventType = anEvent.GetEventType();

View File

@ -8,8 +8,8 @@ using System.Threading.Tasks;
namespace Microsoft.OneFuzz.Service; namespace Microsoft.OneFuzz.Service;
public interface IQueue public interface IQueue
{ {
Task SendMessage(string name, byte[] message, StorageType storageType, TimeSpan? visibilityTimeout = null, TimeSpan? timeToLive = null); Async.Task SendMessage(string name, byte[] message, StorageType storageType, TimeSpan? visibilityTimeout = null, TimeSpan? timeToLive = null);
Task<bool> QueueObject<T>(string name, T obj, StorageType storageType, TimeSpan? visibilityTimeout); Async.Task<bool> QueueObject<T>(string name, T obj, StorageType storageType, TimeSpan? visibilityTimeout);
} }
@ -25,7 +25,7 @@ public class Queue : IQueue
} }
public async Task SendMessage(string name, byte[] message, StorageType storageType, TimeSpan? visibilityTimeout = null, TimeSpan? timeToLive = null) public async Async.Task SendMessage(string name, byte[] message, StorageType storageType, TimeSpan? visibilityTimeout = null, TimeSpan? timeToLive = null)
{ {
var queue = GetQueue(name, storageType); var queue = GetQueue(name, storageType);
if (queue != null) if (queue != null)

View File

@ -0,0 +1,28 @@
using ApiService.OneFuzzLib.Orm;
using System;
using System.Linq;
namespace Microsoft.OneFuzz.Service;
public interface ITaskOperations : IOrm<Task>
{
Async.Task<Task?> GetByTaskId(Guid taskId);
}
public class TaskOperations : Orm<Task>, ITaskOperations
{
public TaskOperations(IStorage storage)
: base(storage)
{
}
public async Async.Task<Task?> GetByTaskId(Guid taskId)
{
var data = QueryAsync(filter: $"RowKey eq '{taskId}'");
return await data.FirstOrDefaultAsync();
}
}

View File

@ -2,7 +2,6 @@
using Microsoft.OneFuzz.Service; using Microsoft.OneFuzz.Service;
using System; using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.Threading.Tasks;
namespace ApiService.OneFuzzLib; namespace ApiService.OneFuzzLib;
@ -29,7 +28,7 @@ public class WebhookMessageLogOperations : Orm<WebhookMessageLog>, IWebhookMessa
} }
public async Task QueueWebhook(WebhookMessageLog webhookLog) public async Async.Task QueueWebhook(WebhookMessageLog webhookLog)
{ {
var log = _loggerFactory.MakeLogTracer(Guid.NewGuid()); var log = _loggerFactory.MakeLogTracer(Guid.NewGuid());
var obj = new WebhookMessageQueueObj(webhookLog.WebhookId, webhookLog.EventId); var obj = new WebhookMessageQueueObj(webhookLog.WebhookId, webhookLog.EventId);
@ -65,7 +64,7 @@ public class WebhookMessageLogOperations : Orm<WebhookMessageLog>, IWebhookMessa
public interface IWebhookOperations public interface IWebhookOperations
{ {
Task SendEvent(EventMessage eventMessage); Async.Task SendEvent(EventMessage eventMessage);
} }
public class WebhookOperations : Orm<Webhook>, IWebhookOperations public class WebhookOperations : Orm<Webhook>, IWebhookOperations
@ -77,7 +76,7 @@ public class WebhookOperations : Orm<Webhook>, IWebhookOperations
_webhookMessageLogOperations = webhookMessageLogOperations; _webhookMessageLogOperations = webhookMessageLogOperations;
} }
async public Task SendEvent(EventMessage eventMessage) async public Async.Task SendEvent(EventMessage eventMessage)
{ {
await foreach (var webhook in GetWebhooksCached()) await foreach (var webhook in GetWebhooksCached())
{ {
@ -89,7 +88,7 @@ public class WebhookOperations : Orm<Webhook>, IWebhookOperations
} }
} }
async private Task AddEvent(Webhook webhook, EventMessage eventMessage) async private Async.Task AddEvent(Webhook webhook, EventMessage eventMessage)
{ {
var message = new WebhookMessageLog( var message = new WebhookMessageLog(
EventId: eventMessage.EventId, EventId: eventMessage.EventId,

View File

@ -1,53 +0,0 @@
using Azure.Data.Tables;
using System;
using System.Linq;
using System.Threading.Tasks;
using Azure.Core;
using Azure.ResourceManager.Storage;
using Azure.ResourceManager;
using Azure.Identity;
namespace Microsoft.OneFuzz.Service.OneFuzzLib.Orm;
public interface IStorageProvider
{
Task<TableClient> GetTableClient(string table);
//IAsyncEnumerable<T> QueryAsync<T>(string filter) where T : EntityBase;
//Task<bool> Replace<T>(T entity) where T : EntityBase;
}
public class StorageProvider : IStorageProvider
{
private readonly string _accountId;
private readonly EntityConverter _entityConverter;
private readonly ArmClient _armClient;
public StorageProvider(string accountId)
{
_accountId = accountId;
_entityConverter = new EntityConverter();
_armClient = new ArmClient(new DefaultAzureCredential());
}
public async Task<TableClient> GetTableClient(string table)
{
var (name, key) = GetStorageAccountNameAndKey(_accountId);
var identifier = new ResourceIdentifier(_accountId);
var tableClient = new TableServiceClient(new Uri($"https://{identifier.Name}.table.core.windows.net"), new TableSharedKeyCredential(name, key));
await tableClient.CreateTableIfNotExistsAsync(table);
return tableClient.GetTableClient(table);
}
public (string?, string?) GetStorageAccountNameAndKey(string accountId)
{
var resourceId = new ResourceIdentifier(accountId);
var storageAccount = _armClient.GetStorageAccountResource(resourceId);
var key = storageAccount.GetKeys().Value.Keys.FirstOrDefault();
return (resourceId.Name, key?.Value);
}
}