C# Refactor For Webhook Queue (#1818)

* Working on Webhook Queue Changes.

* Initial Push w/ Webhook Work.

* Initial Push. Working on Testing.

* Moving InstanceId

* Removing primitives.

* Formatting.

* Fixing formatting?

* Moving GetInstanceId to containers.

* Moving comments for formatting.

* Removing unused dependency.

* Comments.

* Add WebhookEventGrid test.

* Fixing how tests work.

* Working to resolve conflicts.

* Resolving conflicts.

* Fixing chagnes.

* Tested code.

* Formatting.

* MoreFormatting.

* More formatting.

* Fixing syntax.

* Fixing syntax.

* Removing test and webhookmessagelogoperation class.

* Using config.

* Fixing ProxyOperations.
This commit is contained in:
Noah McGregor Harper
2022-04-26 13:07:10 -07:00
committed by GitHub
parent af7d815e4f
commit 0b1c7aea9c
18 changed files with 1533 additions and 1352 deletions

View File

@ -44,4 +44,4 @@
<CopyToPublishDirectory>Never</CopyToPublishDirectory> <CopyToPublishDirectory>Never</CopyToPublishDirectory>
</None> </None>
</ItemGroup> </ItemGroup>
</Project> </Project>

View File

@ -47,18 +47,18 @@ public class Request {
public async Task<HttpResponseMessage> Post(Uri url, String json, IDictionary<string, string>? headers = null) { public async Task<HttpResponseMessage> Post(Uri url, String json, IDictionary<string, string>? headers = null) {
using var b = new StringContent(json); using var b = new StringContent(json);
b.Headers.ContentType = MediaTypeHeaderValue.Parse("application/json"); b.Headers.ContentType = MediaTypeHeaderValue.Parse("application/json");
return await Send(method: HttpMethod.Post, url: url, headers: headers); return await Send(method: HttpMethod.Post, content: b, url: url, headers: headers);
} }
public async Task<HttpResponseMessage> Put(Uri url, String json, IDictionary<string, string>? headers = null) { public async Task<HttpResponseMessage> Put(Uri url, String json, IDictionary<string, string>? headers = null) {
using var b = new StringContent(json); using var b = new StringContent(json);
b.Headers.ContentType = MediaTypeHeaderValue.Parse("application/json"); b.Headers.ContentType = MediaTypeHeaderValue.Parse("application/json");
return await Send(method: HttpMethod.Put, url: url, headers: headers); return await Send(method: HttpMethod.Put, content: b, url: url, headers: headers);
} }
public async Task<HttpResponseMessage> Patch(Uri url, String json, IDictionary<string, string>? headers = null) { public async Task<HttpResponseMessage> Patch(Uri url, String json, IDictionary<string, string>? headers = null) {
using var b = new StringContent(json); using var b = new StringContent(json);
b.Headers.ContentType = MediaTypeHeaderValue.Parse("application/json"); b.Headers.ContentType = MediaTypeHeaderValue.Parse("application/json");
return await Send(method: HttpMethod.Patch, url: url, headers: headers); return await Send(method: HttpMethod.Patch, content: b, url: url, headers: headers);
} }
} }

View File

@ -105,7 +105,7 @@ public interface ILogTracer {
void Critical(string message); void Critical(string message);
void Error(string message); void Error(string message);
void Event(string evt, IReadOnlyDictionary<string, double>? metrics); void Event(string evt, IReadOnlyDictionary<string, double>? metrics);
void Exception(Exception ex, IReadOnlyDictionary<string, double>? metrics); void Exception(Exception ex, IReadOnlyDictionary<string, double>? metrics = null);
void ForceFlush(); void ForceFlush();
void Info(string message); void Info(string message);
void Warning(string message); void Warning(string message);

View File

@ -45,6 +45,7 @@ public abstract record BaseEvent() {
this switch { this switch {
EventNodeHeartbeat _ => EventType.NodeHeartbeat, EventNodeHeartbeat _ => EventType.NodeHeartbeat,
EventTaskHeartbeat _ => EventType.TaskHeartbeat, EventTaskHeartbeat _ => EventType.TaskHeartbeat,
EventPing _ => EventType.Ping,
EventInstanceConfigUpdated _ => EventType.InstanceConfigUpdated, EventInstanceConfigUpdated _ => EventType.InstanceConfigUpdated,
EventProxyCreated _ => EventType.ProxyCreated, EventProxyCreated _ => EventType.ProxyCreated,
EventProxyDeleted _ => EventType.ProxyDeleted, EventProxyDeleted _ => EventType.ProxyDeleted,
@ -66,6 +67,7 @@ public abstract record BaseEvent() {
EventType.NodeHeartbeat => typeof(EventNodeHeartbeat), EventType.NodeHeartbeat => typeof(EventNodeHeartbeat),
EventType.InstanceConfigUpdated => typeof(EventInstanceConfigUpdated), EventType.InstanceConfigUpdated => typeof(EventInstanceConfigUpdated),
EventType.TaskHeartbeat => typeof(EventTaskHeartbeat), EventType.TaskHeartbeat => typeof(EventTaskHeartbeat),
EventType.Ping => typeof(EventPing),
EventType.ProxyCreated => typeof(EventProxyCreated), EventType.ProxyCreated => typeof(EventProxyCreated),
EventType.ProxyDeleted => typeof(EventProxyDeleted), EventType.ProxyDeleted => typeof(EventProxyDeleted),
EventType.ProxyFailed => typeof(EventProxyFailed), EventType.ProxyFailed => typeof(EventProxyFailed),
@ -151,11 +153,9 @@ public record EventTaskHeartbeat(
TaskConfig Config TaskConfig Config
) : BaseEvent(); ) : BaseEvent();
public record EventPing(
//record EventPing( Guid PingId
// PingId: Guid ) : BaseEvent();
//): BaseEvent();
//record EventScalesetCreated( //record EventScalesetCreated(
// Guid ScalesetId, // Guid ScalesetId,
@ -295,7 +295,7 @@ public record EventMessage(
BaseEvent Event, BaseEvent Event,
Guid InstanceId, Guid InstanceId,
String InstanceName String InstanceName
) : EntityBase(); );
public class BaseEventConverter : JsonConverter<BaseEvent> { public class BaseEventConverter : JsonConverter<BaseEvent> {
public override BaseEvent? Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) { public override BaseEvent? Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) {

View File

@ -245,7 +245,6 @@ public record NodeAssignment(
public record Task( public record Task(
// Timestamp: Optional[datetime] = Field(alias="Timestamp")
[PartitionKey] Guid JobId, [PartitionKey] Guid JobId,
[RowKey] Guid TaskId, [RowKey] Guid TaskId,
TaskState State, TaskState State,
@ -307,13 +306,13 @@ public record ApiAccessRule(
Guid[] AllowedGroups Guid[] AllowedGroups
); );
//# initial set of admins can only be set during deployment.
//# if admins are set, only admins can update instance configs.
//# if set, only admins can manage pools or scalesets
public record InstanceConfig public record InstanceConfig
( (
[PartitionKey, RowKey] string InstanceName, [PartitionKey, RowKey] string InstanceName,
//# initial set of admins can only be set during deployment.
//# if admins are set, only admins can update instance configs.
Guid[]? Admins, Guid[]? Admins,
//# if set, only admins can manage pools or scalesets
bool? AllowPoolManagement, bool? AllowPoolManagement,
string[] AllowedAadTenants, string[] AllowedAadTenants,
NetworkConfig NetworkConfig, NetworkConfig NetworkConfig,
@ -490,20 +489,20 @@ public record Repro(
UserInfo? UserInfo UserInfo? UserInfo
) : StatefulEntityBase<VmState>(State); ) : StatefulEntityBase<VmState>(State);
// TODO: Make this >1 and < 7*24 (more than one hour, less than seven days)
public record ReproConfig( public record ReproConfig(
Container Container, Container Container,
string Path, string Path,
// TODO: Make this >1 and < 7*24 (more than one hour, less than seven days)
int Duration int Duration
); );
// Skipping AutoScaleConfig because it's not used anymore
public record Pool( public record Pool(
DateTimeOffset Timestamp, DateTimeOffset Timestamp,
PoolName Name, PoolName Name,
Guid PoolId, Guid PoolId,
Os Os, Os Os,
bool Managed, bool Managed,
// Skipping AutoScaleConfig because it's not used anymore
Architecture Architecture, Architecture Architecture,
PoolState State, PoolState State,
Guid? ClientId, Guid? ClientId,

View File

@ -9,6 +9,11 @@ public enum WebhookMessageFormat {
EventGrid EventGrid
} }
public record WebhookMessageQueueObj(
Guid WebhookId,
Guid EventId
);
public record WebhookMessage(Guid EventId, public record WebhookMessage(Guid EventId,
EventType EventType, EventType EventType,
BaseEvent Event, BaseEvent Event,
@ -27,24 +32,18 @@ public record WebhookMessageEventGrid(
[property: JsonConverter(typeof(BaseEventConverter))] [property: JsonConverter(typeof(BaseEventConverter))]
BaseEvent Data); BaseEvent Data);
// TODO: This should inherit from Entity Base ? no, since there is
// a table WebhookMessaageLog
public record WebhookMessageLog( public record WebhookMessageLog(
[RowKey] Guid EventId, [RowKey] Guid EventId,
EventType EventType, EventType EventType,
[property: TypeDiscrimnatorAttribute("EventType", typeof(EventTypeProvider))]
[property: JsonConverter(typeof(BaseEventConverter))]
BaseEvent Event, BaseEvent Event,
Guid InstanceId, Guid InstanceId,
String InstanceName, String InstanceName,
[PartitionKey] Guid WebhookId, [PartitionKey] Guid WebhookId,
WebhookMessageState State = WebhookMessageState.Queued, long TryCount,
int TryCount = 0 WebhookMessageState State = WebhookMessageState.Queued
) : WebhookMessage(EventId, ) : EntityBase();
EventType,
Event,
InstanceId,
InstanceName,
WebhookId);
public record Webhook( public record Webhook(
[PartitionKey] Guid WebhookId, [PartitionKey] Guid WebhookId,

View File

@ -45,7 +45,8 @@ public class Program {
return loggers; return loggers;
} }
//Move out expensive resources into separate class, and add those as Singleton
// ArmClient, Table Client(s), Queue Client(s), HttpClient, etc.
public static void Main() { public static void Main() {
var host = new HostBuilder() var host = new HostBuilder()
.ConfigureFunctionsWorkerDefaults( .ConfigureFunctionsWorkerDefaults(
@ -82,7 +83,7 @@ public class Program {
.AddScoped<IConfig, Config>() .AddScoped<IConfig, Config>()
//Move out expensive resources into separate class, and add those as Singleton //Move out expensive resources into separate class, and add those as Singleton
// ArmClient, Table Client(s), Queue Client(s), HttpClient, etc. // ArmClient, Table Client(s), Queue Client(s), HttpClient, etc.\
.AddSingleton<ICreds, Creds>() .AddSingleton<ICreds, Creds>()
.AddSingleton<IServiceConfig, ServiceConfiguration>() .AddSingleton<IServiceConfig, ServiceConfiguration>()
.AddHttpClient() .AddHttpClient()

View File

@ -0,0 +1,24 @@
using System.Text.Json;
using Microsoft.Azure.Functions.Worker;
using Microsoft.OneFuzz.Service.OneFuzzLib.Orm;
namespace Microsoft.OneFuzz.Service;
public class QueueWebhooks {
private readonly ILogTracer _log;
private readonly IWebhookMessageLogOperations _webhookMessageLog;
public QueueWebhooks(ILogTracer log, IWebhookMessageLogOperations webhookMessageLog) {
_log = log;
_webhookMessageLog = webhookMessageLog;
}
[Function("QueueWebhooks")]
public async Async.Task Run([QueueTrigger("myqueue-items", Connection = "AzureWebJobsStorage")] string msg) {
_log.Info($"Webhook Message Queued: {msg}");
var obj = JsonSerializer.Deserialize<WebhookMessageQueueObj>(msg, EntityConverter.GetJsonSerializerOptions()).EnsureNotNull($"wrong data {msg}");
await _webhookMessageLog.ProcessFromQueue(obj);
}
}

View File

@ -35,7 +35,7 @@ public interface IServiceConfig {
public string? OneFuzzResourceGroup { get; } public string? OneFuzzResourceGroup { get; }
public string? OneFuzzTelemetry { get; } public string? OneFuzzTelemetry { get; }
public string OnefuzzVersion { get; } public string OneFuzzVersion { get; }
} }
public class ServiceConfiguration : IServiceConfig { public class ServiceConfiguration : IServiceConfig {
@ -77,5 +77,5 @@ public class ServiceConfiguration : IServiceConfig {
public string? OneFuzzOwner { get => Environment.GetEnvironmentVariable("ONEFUZZ_OWNER"); } public string? OneFuzzOwner { get => Environment.GetEnvironmentVariable("ONEFUZZ_OWNER"); }
public string? OneFuzzResourceGroup { get => Environment.GetEnvironmentVariable("ONEFUZZ_RESOURCE_GROUP"); } public string? OneFuzzResourceGroup { get => Environment.GetEnvironmentVariable("ONEFUZZ_RESOURCE_GROUP"); }
public string? OneFuzzTelemetry { get => Environment.GetEnvironmentVariable("ONEFUZZ_TELEMETRY"); } public string? OneFuzzTelemetry { get => Environment.GetEnvironmentVariable("ONEFUZZ_TELEMETRY"); }
public string OnefuzzVersion { get => Environment.GetEnvironmentVariable("ONEFUZZ_VERSION") ?? "0.0.0"; } public string OneFuzzVersion { get => Environment.GetEnvironmentVariable("ONEFUZZ_VERSION") ?? "0.0.0"; }
} }

View File

@ -1,125 +1,124 @@
using System.Threading.Tasks; using System.Threading.Tasks;
using Azure; using Azure;
using Azure.ResourceManager; using Azure.ResourceManager;
using Azure.Storage; using Azure.Storage;
using Azure.Storage.Blobs; using Azure.Storage.Blobs;
using Azure.Storage.Sas; using Azure.Storage.Sas;
namespace Microsoft.OneFuzz.Service; namespace Microsoft.OneFuzz.Service;
public interface IContainers { public interface IContainers {
public Task<BinaryData?> GetBlob(Container container, string name, StorageType storageType); public Task<BinaryData?> GetBlob(Container container, string name, StorageType storageType);
public Async.Task<BlobContainerClient?> FindContainer(Container container, StorageType storageType); public Async.Task<BlobContainerClient?> FindContainer(Container container, StorageType storageType);
public Async.Task<Uri?> GetFileSasUrl(Container container, string name, StorageType storageType, BlobSasPermissions permissions, TimeSpan? duration = null); public Async.Task<Uri?> GetFileSasUrl(Container container, string name, StorageType storageType, BlobSasPermissions permissions, TimeSpan? duration = null);
Async.Task saveBlob(Container container, string v1, string v2, StorageType config); Async.Task saveBlob(Container container, string v1, string v2, StorageType config);
Task<Guid> GetInstanceId(); Task<Guid> GetInstanceId();
} }
public class Containers : IContainers { public class Containers : IContainers {
private ILogTracer _log; private ILogTracer _log;
private IStorage _storage; private IStorage _storage;
private ICreds _creds; private ICreds _creds;
private ArmClient _armClient; private ArmClient _armClient;
public Containers(ILogTracer log, IStorage storage, ICreds creds) { public Containers(ILogTracer log, IStorage storage, ICreds creds) {
_log = log; _log = log;
_storage = storage; _storage = storage;
_creds = creds; _creds = creds;
_armClient = creds.ArmClient; _armClient = creds.ArmClient;
} }
public async Task<BinaryData?> GetBlob(Container container, string name, StorageType storageType) { public async Task<BinaryData?> GetBlob(Container container, string name, StorageType storageType) {
var client = await FindContainer(container, storageType); var client = await FindContainer(container, storageType);
if (client == null) { if (client == null) {
return null; return null;
} }
try { try {
return (await client.GetBlobClient(name).DownloadContentAsync()) return (await client.GetBlobClient(name).DownloadContentAsync())
.Value.Content; .Value.Content;
} catch (RequestFailedException) { } catch (RequestFailedException) {
return null; return null;
} }
} }
public async Async.Task<BlobContainerClient?> FindContainer(Container container, StorageType storageType) { public async Async.Task<BlobContainerClient?> FindContainer(Container container, StorageType storageType) {
// # check secondary accounts first by searching in reverse. // # check secondary accounts first by searching in reverse.
// # // #
// # By implementation, the primary account is specified first, followed by // # By implementation, the primary account is specified first, followed by
// # any secondary accounts. // # any secondary accounts.
// # // #
// # Secondary accounts, if they exist, are preferred for containers and have // # Secondary accounts, if they exist, are preferred for containers and have
// # increased IOP rates, this should be a slight optimization // # increased IOP rates, this should be a slight optimization
return await _storage.GetAccounts(storageType) return await _storage.GetAccounts(storageType)
.Reverse() .Reverse()
.Select(account => GetBlobService(account)?.GetBlobContainerClient(container.ContainerName)) .Select(account => GetBlobService(account)?.GetBlobContainerClient(container.ContainerName))
.ToAsyncEnumerable() .ToAsyncEnumerable()
.WhereAwait(async client => client != null && (await client.ExistsAsync()).Value) .WhereAwait(async client => client != null && (await client.ExistsAsync()).Value)
.FirstOrDefaultAsync(); .FirstOrDefaultAsync();
} }
private BlobServiceClient? GetBlobService(string accountId) { private BlobServiceClient? GetBlobService(string accountId) {
_log.Info($"getting blob container (account_id: {accountId}"); _log.Info($"getting blob container (account_id: {accountId}");
var (accountName, accountKey) = _storage.GetStorageAccountNameAndKey(accountId); var (accountName, accountKey) = _storage.GetStorageAccountNameAndKey(accountId);
if (accountName == null) { if (accountName == null) {
_log.Error("Failed to get storage account name"); _log.Error("Failed to get storage account name");
return null; return null;
} }
var storageKeyCredential = new StorageSharedKeyCredential(accountName, accountKey); var storageKeyCredential = new StorageSharedKeyCredential(accountName, accountKey);
var accountUrl = GetUrl(accountName); var accountUrl = GetUrl(accountName);
return new BlobServiceClient(accountUrl, storageKeyCredential); return new BlobServiceClient(accountUrl, storageKeyCredential);
} }
private static Uri GetUrl(string accountName) { private static Uri GetUrl(string accountName) {
return new Uri($"https://{accountName}.blob.core.windows.net/"); return new Uri($"https://{accountName}.blob.core.windows.net/");
} }
public async Async.Task<Uri?> GetFileSasUrl(Container container, string name, StorageType storageType, BlobSasPermissions permissions, TimeSpan? duration = null) { public async Async.Task<Uri?> GetFileSasUrl(Container container, string name, StorageType storageType, BlobSasPermissions permissions, TimeSpan? duration = null) {
var client = await FindContainer(container, storageType) ?? throw new Exception($"unable to find container: {container.ContainerName} - {storageType}"); var client = await FindContainer(container, storageType) ?? throw new Exception($"unable to find container: {container.ContainerName} - {storageType}");
var (accountName, accountKey) = _storage.GetStorageAccountNameAndKey(client.AccountName); var (accountName, accountKey) = _storage.GetStorageAccountNameAndKey(client.AccountName);
var (startTime, endTime) = SasTimeWindow(duration ?? TimeSpan.FromDays(30)); var (startTime, endTime) = SasTimeWindow(duration ?? TimeSpan.FromDays(30));
var sasBuilder = new BlobSasBuilder(permissions, endTime) { var sasBuilder = new BlobSasBuilder(permissions, endTime) {
StartsOn = startTime, StartsOn = startTime,
BlobContainerName = container.ContainerName, BlobContainerName = container.ContainerName,
BlobName = name BlobName = name
}; };
var sasUrl = client.GetBlobClient(name).GenerateSasUri(sasBuilder); var sasUrl = client.GetBlobClient(name).GenerateSasUri(sasBuilder);
return sasUrl; return sasUrl;
} }
public (DateTimeOffset, DateTimeOffset) SasTimeWindow(TimeSpan timeSpan) { public (DateTimeOffset, DateTimeOffset) SasTimeWindow(TimeSpan timeSpan) {
// SAS URLs are valid 6 hours earlier, primarily to work around dev // SAS URLs are valid 6 hours earlier, primarily to work around dev
// workstations having out-of-sync time. Additionally, SAS URLs are stopped // workstations having out-of-sync time. Additionally, SAS URLs are stopped
// 15 minutes later than requested based on "Be careful with SAS start time" // 15 minutes later than requested based on "Be careful with SAS start time"
// guidance. // guidance.
// Ref: https://docs.microsoft.com/en-us/azure/storage/common/storage-sas-overview // Ref: https://docs.microsoft.com/en-us/azure/storage/common/storage-sas-overview
var SAS_START_TIME_DELTA = TimeSpan.FromHours(6); var SAS_START_TIME_DELTA = TimeSpan.FromHours(6);
var SAS_END_TIME_DELTA = TimeSpan.FromMinutes(6); var SAS_END_TIME_DELTA = TimeSpan.FromMinutes(6);
var now = DateTimeOffset.UtcNow; var now = DateTimeOffset.UtcNow;
var start = now - SAS_START_TIME_DELTA; var start = now - SAS_START_TIME_DELTA;
var expiry = now + timeSpan + SAS_END_TIME_DELTA; var expiry = now + timeSpan + SAS_END_TIME_DELTA;
return (start, expiry); return (start, expiry);
} }
public async System.Threading.Tasks.Task saveBlob(Container container, string name, string data, StorageType storageType) { public async System.Threading.Tasks.Task saveBlob(Container container, string name, string data, StorageType storageType) {
var client = await FindContainer(container, storageType) ?? throw new Exception($"unable to find container: {container.ContainerName} - {storageType}"); var client = await FindContainer(container, storageType) ?? throw new Exception($"unable to find container: {container.ContainerName} - {storageType}");
await client.UploadBlobAsync(name, new BinaryData(data)); await client.UploadBlobAsync(name, new BinaryData(data));
} }
public async Async.Task<Guid> GetInstanceId() { public async Async.Task<Guid> GetInstanceId() {
var blob = await GetBlob(new Container("base-config"), "instance_id", StorageType.Config); var blob = await GetBlob(new Container("base-config"), "instance_id", StorageType.Config);
if (blob == null) { if (blob == null) {
throw new System.Exception("Blob Not Found"); throw new System.Exception("Blob Not Found");
} }
return System.Guid.Parse(blob.ToString()); return System.Guid.Parse(blob.ToString());
} }
} }

View File

@ -14,6 +14,7 @@ public interface ICreds {
public ResourceIdentifier GetResourceGroupResourceIdentifier(); public ResourceIdentifier GetResourceGroupResourceIdentifier();
public string GetInstanceName();
public ArmClient ArmClient { get; } public ArmClient ArmClient { get; }
@ -59,6 +60,13 @@ public class Creds : ICreds {
return new ResourceIdentifier(resourceId); return new ResourceIdentifier(resourceId);
} }
public string GetInstanceName() {
var instanceName = _config.OneFuzzInstanceName
?? throw new System.Exception("Instance Name env var is not present");
return instanceName;
}
public ResourceGroupResource GetResourceGroupResource() { public ResourceGroupResource GetResourceGroupResource() {
var resourceId = GetResourceGroupResourceIdentifier(); var resourceId = GetResourceGroupResourceIdentifier();
return ArmClient.GetResourceGroupResource(resourceId); return ArmClient.GetResourceGroupResource(resourceId);

View File

@ -55,5 +55,4 @@ public partial class TimerProxy {
} }
} }

View File

@ -55,7 +55,7 @@ public class ProxyOperations : StatefulOrm<Proxy, VmState>, IProxyOperations {
} }
_logTracer.Info($"creating proxy: region:{region}"); _logTracer.Info($"creating proxy: region:{region}");
var newProxy = new Proxy(region, Guid.NewGuid(), DateTimeOffset.UtcNow, VmState.Init, Auth.BuildAuth(), null, null, _config.OnefuzzVersion, null, false); var newProxy = new Proxy(region, Guid.NewGuid(), DateTimeOffset.UtcNow, VmState.Init, Auth.BuildAuth(), null, null, _config.OneFuzzVersion, null, false);
await Replace(newProxy); await Replace(newProxy);
await _events.SendEvent(new EventProxyCreated(region, newProxy.ProxyId)); await _events.SendEvent(new EventProxyCreated(region, newProxy.ProxyId));
@ -83,8 +83,8 @@ public class ProxyOperations : StatefulOrm<Proxy, VmState>, IProxyOperations {
return false; return false;
} }
if (proxy.Version != _config.OnefuzzVersion) { if (proxy.Version != _config.OneFuzzVersion) {
_logTracer.Info($"mismatch version: proxy:{proxy.Version} service:{_config.OnefuzzVersion} state:{proxy.State}"); _logTracer.Info($"mismatch version: proxy:{proxy.Version} service:{_config.OneFuzzVersion} state:{proxy.State}");
return true; return true;
} }

View File

@ -14,7 +14,6 @@ public interface IStorage {
public IEnumerable<string> CorpusAccounts(); public IEnumerable<string> CorpusAccounts();
string GetPrimaryAccount(StorageType storageType); string GetPrimaryAccount(StorageType storageType);
public (string?, string?) GetStorageAccountNameAndKey(string accountId); public (string?, string?) GetStorageAccountNameAndKey(string accountId);
public IEnumerable<string> GetAccounts(StorageType storageType); public IEnumerable<string> GetAccounts(StorageType storageType);
} }
@ -98,6 +97,28 @@ public class Storage : IStorage {
return (resourceId.Name, key?.Value); return (resourceId.Name, key?.Value);
} }
public string ChooseAccounts(StorageType storageType) {
var accounts = GetAccounts(storageType);
if (!accounts.Any()) {
throw new Exception($"No Storage Accounts for {storageType}");
}
var account_list = accounts.ToList();
if (account_list.Count == 1) {
return account_list[0];
}
// Use a random secondary storage account if any are available. This
// reduces IOP contention for the Storage Queues, which are only available
// on primary accounts
//
// security note: this is not used as a security feature
var random = new Random();
var index = random.Next(account_list.Count);
return account_list[index]; // nosec
}
public IEnumerable<string> GetAccounts(StorageType storageType) { public IEnumerable<string> GetAccounts(StorageType storageType) {
switch (storageType) { switch (storageType) {
case StorageType.Corpus: case StorageType.Corpus:

View File

@ -1,106 +1,253 @@
using ApiService.OneFuzzLib.Orm; using System.Net;
using System.Net.Http;
namespace Microsoft.OneFuzz.Service; using System.Security.Cryptography;
using System.Text.Json;
using ApiService.OneFuzzLib.Orm;
public interface IWebhookMessageLogOperations : IOrm<WebhookMessageLog> { using Microsoft.OneFuzz.Service.OneFuzzLib.Orm;
IAsyncEnumerable<WebhookMessageLog> SearchExpired();
} namespace Microsoft.OneFuzz.Service;
public interface IWebhookOperations {
public class WebhookMessageLogOperations : Orm<WebhookMessageLog>, IWebhookMessageLogOperations { Async.Task SendEvent(EventMessage eventMessage);
const int EXPIRE_DAYS = 7; Async.Task<Webhook?> GetByWebhookId(Guid webhookId);
Async.Task<bool> Send(WebhookMessageLog messageLog);
record WebhookMessageQueueObj( }
Guid WebhookId,
Guid EventId public class WebhookOperations : Orm<Webhook>, IWebhookOperations {
);
private readonly IWebhookMessageLogOperations _webhookMessageLogOperations;
private readonly IQueue _queue; private readonly ILogTracer _log;
private readonly ILogTracer _log; private readonly ICreds _creds;
public WebhookMessageLogOperations(IStorage storage, IQueue queue, ILogTracer log, IServiceConfig config) : base(storage, log, config) { private readonly IContainers _containers;
_queue = queue; private readonly IHttpClientFactory _httpFactory;
_log = log;
} public WebhookOperations(IHttpClientFactory httpFactory, ICreds creds, IStorage storage, IWebhookMessageLogOperations webhookMessageLogOperations, IContainers containers, ILogTracer log, IServiceConfig config)
: base(storage, log, config) {
_webhookMessageLogOperations = webhookMessageLogOperations;
public async Async.Task QueueWebhook(WebhookMessageLog webhookLog) { _log = log;
var obj = new WebhookMessageQueueObj(webhookLog.WebhookId, webhookLog.EventId); _creds = creds;
_containers = containers;
TimeSpan? visibilityTimeout = webhookLog.State switch { _httpFactory = httpFactory;
WebhookMessageState.Queued => TimeSpan.Zero, }
WebhookMessageState.Retrying => TimeSpan.FromSeconds(30),
_ => null async public Async.Task SendEvent(EventMessage eventMessage) {
}; await foreach (var webhook in GetWebhooksCached()) {
if (!webhook.EventTypes.Contains(eventMessage.EventType)) {
if (visibilityTimeout == null) { continue;
_log.WithTags( }
new[] { await AddEvent(webhook, eventMessage);
("WebhookId", webhookLog.WebhookId.ToString()), }
("EventId", webhookLog.EventId.ToString()) } }
).
Error($"invalid WebhookMessage queue state, not queuing. {webhookLog.WebhookId}:{webhookLog.EventId} - {webhookLog.State}"); async private Async.Task AddEvent(Webhook webhook, EventMessage eventMessage) {
} else { var message = new WebhookMessageLog(
await _queue.QueueObject("webhooks", obj, StorageType.Config, visibilityTimeout: visibilityTimeout); EventId: eventMessage.EventId,
} EventType: eventMessage.EventType,
} Event: eventMessage.Event,
InstanceId: eventMessage.InstanceId,
private void QueueObject(string v, WebhookMessageQueueObj obj, StorageType config, int? visibility_timeout) { InstanceName: eventMessage.InstanceName,
throw new NotImplementedException(); WebhookId: webhook.WebhookId,
} TryCount: 0
);
public IAsyncEnumerable<WebhookMessageLog> SearchExpired() {
var expireTime = (DateTimeOffset.UtcNow - TimeSpan.FromDays(EXPIRE_DAYS)).ToString("o"); var r = await _webhookMessageLogOperations.Replace(message);
if (!r.IsOk) {
var timeFilter = $"Timestamp lt datetime'{expireTime}'"; var (status, reason) = r.ErrorV;
return QueryAsync(filter: timeFilter); _log.Error($"Failed to replace webhook message log due to [{status}] {reason}");
} }
} }
public async Async.Task<bool> Send(WebhookMessageLog messageLog) {
public interface IWebhookOperations { var webhook = await GetByWebhookId(messageLog.WebhookId);
Async.Task SendEvent(EventMessage eventMessage); if (webhook == null || webhook.Url == null) {
} throw new Exception($"Invalid Webhook. Webhook with WebhookId: {messageLog.WebhookId} Not Found");
}
public class WebhookOperations : Orm<Webhook>, IWebhookOperations {
private readonly IWebhookMessageLogOperations _webhookMessageLogOperations; var (data, digest) = await BuildMessage(webhookId: webhook.WebhookId, eventId: messageLog.EventId, eventType: messageLog.EventType, webhookEvent: messageLog.Event, secretToken: webhook.SecretToken, messageFormat: webhook.MessageFormat);
private readonly ILogTracer _log;
public WebhookOperations(IStorage storage, IWebhookMessageLogOperations webhookMessageLogOperations, ILogTracer log, IServiceConfig config) var headers = new Dictionary<string, string> { { "User-Agent", $"onefuzz-webhook {_config.OneFuzzVersion}" } };
: base(storage, log, config) {
_webhookMessageLogOperations = webhookMessageLogOperations; if (digest != null) {
_log = log; headers["X-Onefuzz-Digest"] = digest;
} }
async public Async.Task SendEvent(EventMessage eventMessage) { var client = new Request(_httpFactory.CreateClient());
await foreach (var webhook in GetWebhooksCached()) { _log.Info(data);
if (!webhook.EventTypes.Contains(eventMessage.EventType)) { var response = client.Post(url: webhook.Url, json: data, headers: headers);
continue; var result = response.Result;
} if (result.StatusCode == HttpStatusCode.Accepted) {
await AddEvent(webhook, eventMessage); return true;
} }
} return false;
}
async private Async.Task AddEvent(Webhook webhook, EventMessage eventMessage) {
var message = new WebhookMessageLog( // Not converting to bytes, as it's not neccessary in C#. Just keeping as string.
EventId: eventMessage.EventId, public async Async.Task<Tuple<string, string?>> BuildMessage(Guid webhookId, Guid eventId, EventType eventType, BaseEvent webhookEvent, String? secretToken, WebhookMessageFormat? messageFormat) {
EventType: eventMessage.EventType, var entityConverter = new EntityConverter();
Event: eventMessage.Event, string data = "";
InstanceId: eventMessage.InstanceId, if (messageFormat != null && messageFormat == WebhookMessageFormat.EventGrid) {
InstanceName: eventMessage.InstanceName, var eventGridMessage = new[] { new WebhookMessageEventGrid(Id: eventId, Data: webhookEvent, DataVersion: "1.0.0", Subject: _creds.GetInstanceName(), EventType: eventType, EventTime: DateTimeOffset.UtcNow) };
WebhookId: webhook.WebhookId data = JsonSerializer.Serialize(eventGridMessage, options: EntityConverter.GetJsonSerializerOptions());
); } else {
var instanceId = await _containers.GetInstanceId();
var r = await _webhookMessageLogOperations.Replace(message); var webhookMessage = new WebhookMessage(WebhookId: webhookId, EventId: eventId, EventType: eventType, Event: webhookEvent, InstanceId: instanceId, InstanceName: _creds.GetInstanceName());
if (!r.IsOk) {
var (status, reason) = r.ErrorV; data = JsonSerializer.Serialize(webhookMessage, options: EntityConverter.GetJsonSerializerOptions());
_log.Error($"Failed to replace webhook message log due to [{status}] {reason}"); }
}
} string? digest = null;
var hmac = HMAC.Create("HMACSHA512");
if (secretToken != null && hmac != null) {
//todo: caching hmac.Key = System.Text.Encoding.UTF8.GetBytes(secretToken);
public IAsyncEnumerable<Webhook> GetWebhooksCached() { digest = Convert.ToHexString(hmac.ComputeHash(System.Text.Encoding.UTF8.GetBytes(data)));
return QueryAsync(); }
} return new Tuple<string, string?>(data, digest);
} }
public async Async.Task<Webhook?> GetByWebhookId(Guid webhookId) {
var data = QueryAsync(filter: $"PartitionKey eq '{webhookId}'");
return await data.FirstOrDefaultAsync();
}
//todo: caching
public IAsyncEnumerable<Webhook> GetWebhooksCached() {
return QueryAsync();
}
}
public interface IWebhookMessageLogOperations : IOrm<WebhookMessageLog> {
IAsyncEnumerable<WebhookMessageLog> SearchExpired();
public Async.Task ProcessFromQueue(WebhookMessageQueueObj obj);
}
public class WebhookMessageLogOperations : Orm<WebhookMessageLog>, IWebhookMessageLogOperations {
const int EXPIRE_DAYS = 7;
const int MAX_TRIES = 5;
private readonly IQueue _queue;
private readonly ILogTracer _log;
private readonly IWebhookOperations _webhook;
public WebhookMessageLogOperations(IStorage storage, IQueue queue, ILogTracer log, IServiceConfig config, ICreds creds, IHttpClientFactory httpFactory, IContainers containers) : base(storage, log, config) {
_queue = queue;
_log = log;
_webhook = new WebhookOperations(httpFactory: httpFactory, creds: creds, storage: storage, webhookMessageLogOperations: this, containers: containers, log: log, config: config);
}
public async Async.Task QueueWebhook(WebhookMessageLog webhookLog) {
var obj = new WebhookMessageQueueObj(webhookLog.WebhookId, webhookLog.EventId);
TimeSpan? visibilityTimeout = webhookLog.State switch {
WebhookMessageState.Queued => TimeSpan.Zero,
WebhookMessageState.Retrying => TimeSpan.FromSeconds(30),
_ => null
};
if (visibilityTimeout == null) {
_log.WithTags(
new[] {
("WebhookId", webhookLog.WebhookId.ToString()),
("EventId", webhookLog.EventId.ToString()) }
).
Error($"invalid WebhookMessage queue state, not queuing. {webhookLog.WebhookId}:{webhookLog.EventId} - {webhookLog.State}");
} else {
await _queue.QueueObject("webhooks", obj, StorageType.Config, visibilityTimeout: visibilityTimeout);
}
}
public async Async.Task ProcessFromQueue(WebhookMessageQueueObj obj) {
var message = await GetWebhookMessageById(obj.WebhookId, obj.EventId);
if (message == null) {
_log.WithTags(
new[] {
("WebhookId", obj.WebhookId.ToString()),
("EventId", obj.EventId.ToString()) }
).
Error($"webhook message log not found for webhookId: {obj.WebhookId} and eventId: {obj.EventId}");
} else {
await Process(message);
}
}
private async System.Threading.Tasks.Task Process(WebhookMessageLog message) {
if (message.State == WebhookMessageState.Failed || message.State == WebhookMessageState.Succeeded) {
_log.WithTags(
new[] {
("WebhookId", message.WebhookId.ToString()),
("EventId", message.EventId.ToString()) }
).
Error($"webhook message already handled. {message.WebhookId}:{message.EventId}");
return;
}
var newMessage = message with { TryCount = message.TryCount + 1 };
_log.Info($"sending webhook: {message.WebhookId}:{message.EventId}");
var success = await Send(newMessage);
if (success) {
newMessage = newMessage with { State = WebhookMessageState.Succeeded };
await Replace(newMessage);
_log.Info($"sent webhook event {newMessage.WebhookId}:{newMessage.EventId}");
} else if (newMessage.TryCount < MAX_TRIES) {
newMessage = newMessage with { State = WebhookMessageState.Retrying };
await Replace(newMessage);
await QueueWebhook(newMessage);
_log.Warning($"sending webhook event failed, re-queued {newMessage.WebhookId}:{newMessage.EventId}");
} else {
newMessage = newMessage with { State = WebhookMessageState.Failed };
await Replace(newMessage);
_log.Info($"sending webhook: {newMessage.WebhookId} event: {newMessage.EventId} failed {newMessage.TryCount} times.");
}
}
private async Async.Task<bool> Send(WebhookMessageLog message) {
var webhook = await _webhook.GetByWebhookId(message.WebhookId);
if (webhook == null) {
_log.WithTags(
new[] {
("WebhookId", message.WebhookId.ToString()),
}
).
Error($"webhook not found for webhookId: {message.WebhookId}");
return false;
}
try {
return await _webhook.Send(message);
} catch (Exception exc) {
_log.WithTags(
new[] {
("WebhookId", message.WebhookId.ToString())
}
).
Exception(exc);
return false;
}
}
private void QueueObject(string v, WebhookMessageQueueObj obj, StorageType config, int? visibility_timeout) {
throw new NotImplementedException();
}
public IAsyncEnumerable<WebhookMessageLog> SearchExpired() {
var expireTime = (DateTimeOffset.UtcNow - TimeSpan.FromDays(EXPIRE_DAYS)).ToString("o");
var timeFilter = $"Timestamp lt datetime'{expireTime}'";
return QueryAsync(filter: timeFilter);
}
public async Async.Task<WebhookMessageLog?> GetWebhookMessageById(Guid webhookId, Guid eventId) {
var data = QueryAsync(filter: $"PartitionKey eq '{webhookId}' and RowKey eq '{eventId}'");
return await data.FirstOrDefaultAsync();
}
}

File diff suppressed because it is too large Load Diff

View File

@ -244,9 +244,6 @@ namespace Tests {
); ; ); ;
} }
public static Gen<Report> Report() { public static Gen<Report> Report() {
return Arb.Generate<Tuple<string, BlobRef, List<string>, Guid, int>>().Select( return Arb.Generate<Tuple<string, BlobRef, List<string>, Guid, int>>().Select(
arg => arg =>
@ -373,6 +370,7 @@ namespace Tests {
return Arb.From(OrmGenerators.Notification()); return Arb.From(OrmGenerators.Notification());
} }
public static Arbitrary<WebhookMessageEventGrid> WebhookMessageEventGrid() { public static Arbitrary<WebhookMessageEventGrid> WebhookMessageEventGrid() {
return Arb.From(OrmGenerators.WebhookMessageEventGrid()); return Arb.From(OrmGenerators.WebhookMessageEventGrid());
} }
@ -548,7 +546,6 @@ namespace Tests {
return Test(n); return Test(n);
} }
[Property] [Property]
public bool Job(Job j) { public bool Job(Job j) {
return Test(j); return Test(j);

View File

@ -246,19 +246,6 @@ namespace Tests {
Assert.Equal(expected.TheName, actual.TheName); Assert.Equal(expected.TheName, actual.TheName);
} }
[Fact]
public void TestEventSerialization2() {
var converter = new EntityConverter();
var expectedEvent = new EventMessage(Guid.NewGuid(), EventType.NodeHeartbeat, new EventNodeHeartbeat(Guid.NewGuid(), Guid.NewGuid(), "test Poool"), Guid.NewGuid(), "test") {
ETag = new Azure.ETag("33a64df551425fcc55e4d42a148795d9f25f89d4")
};
var te = converter.ToTableEntity(expectedEvent);
var actualEvent = converter.ToRecord<EventMessage>(te);
Assert.Equal(expectedEvent, actualEvent);
}
record Entity3( record Entity3(
[PartitionKey] int Id, [PartitionKey] int Id,
[RowKey] string TheName, [RowKey] string TheName,