C# Refactor For Webhook Queue (#1818)

* Working on Webhook Queue Changes.

* Initial Push w/ Webhook Work.

* Initial Push. Working on Testing.

* Moving InstanceId

* Removing primitives.

* Formatting.

* Fixing formatting?

* Moving GetInstanceId to containers.

* Moving comments for formatting.

* Removing unused dependency.

* Comments.

* Add WebhookEventGrid test.

* Fixing how tests work.

* Working to resolve conflicts.

* Resolving conflicts.

* Fixing chagnes.

* Tested code.

* Formatting.

* MoreFormatting.

* More formatting.

* Fixing syntax.

* Fixing syntax.

* Removing test and webhookmessagelogoperation class.

* Using config.

* Fixing ProxyOperations.
This commit is contained in:
Noah McGregor Harper
2022-04-26 13:07:10 -07:00
committed by GitHub
parent af7d815e4f
commit 0b1c7aea9c
18 changed files with 1533 additions and 1352 deletions

View File

@ -44,4 +44,4 @@
<CopyToPublishDirectory>Never</CopyToPublishDirectory>
</None>
</ItemGroup>
</Project>
</Project>

View File

@ -47,18 +47,18 @@ public class Request {
public async Task<HttpResponseMessage> Post(Uri url, String json, IDictionary<string, string>? headers = null) {
using var b = new StringContent(json);
b.Headers.ContentType = MediaTypeHeaderValue.Parse("application/json");
return await Send(method: HttpMethod.Post, url: url, headers: headers);
return await Send(method: HttpMethod.Post, content: b, url: url, headers: headers);
}
public async Task<HttpResponseMessage> Put(Uri url, String json, IDictionary<string, string>? headers = null) {
using var b = new StringContent(json);
b.Headers.ContentType = MediaTypeHeaderValue.Parse("application/json");
return await Send(method: HttpMethod.Put, url: url, headers: headers);
return await Send(method: HttpMethod.Put, content: b, url: url, headers: headers);
}
public async Task<HttpResponseMessage> Patch(Uri url, String json, IDictionary<string, string>? headers = null) {
using var b = new StringContent(json);
b.Headers.ContentType = MediaTypeHeaderValue.Parse("application/json");
return await Send(method: HttpMethod.Patch, url: url, headers: headers);
return await Send(method: HttpMethod.Patch, content: b, url: url, headers: headers);
}
}

View File

@ -105,7 +105,7 @@ public interface ILogTracer {
void Critical(string message);
void Error(string message);
void Event(string evt, IReadOnlyDictionary<string, double>? metrics);
void Exception(Exception ex, IReadOnlyDictionary<string, double>? metrics);
void Exception(Exception ex, IReadOnlyDictionary<string, double>? metrics = null);
void ForceFlush();
void Info(string message);
void Warning(string message);

View File

@ -45,6 +45,7 @@ public abstract record BaseEvent() {
this switch {
EventNodeHeartbeat _ => EventType.NodeHeartbeat,
EventTaskHeartbeat _ => EventType.TaskHeartbeat,
EventPing _ => EventType.Ping,
EventInstanceConfigUpdated _ => EventType.InstanceConfigUpdated,
EventProxyCreated _ => EventType.ProxyCreated,
EventProxyDeleted _ => EventType.ProxyDeleted,
@ -66,6 +67,7 @@ public abstract record BaseEvent() {
EventType.NodeHeartbeat => typeof(EventNodeHeartbeat),
EventType.InstanceConfigUpdated => typeof(EventInstanceConfigUpdated),
EventType.TaskHeartbeat => typeof(EventTaskHeartbeat),
EventType.Ping => typeof(EventPing),
EventType.ProxyCreated => typeof(EventProxyCreated),
EventType.ProxyDeleted => typeof(EventProxyDeleted),
EventType.ProxyFailed => typeof(EventProxyFailed),
@ -151,11 +153,9 @@ public record EventTaskHeartbeat(
TaskConfig Config
) : BaseEvent();
//record EventPing(
// PingId: Guid
//): BaseEvent();
public record EventPing(
Guid PingId
) : BaseEvent();
//record EventScalesetCreated(
// Guid ScalesetId,
@ -295,7 +295,7 @@ public record EventMessage(
BaseEvent Event,
Guid InstanceId,
String InstanceName
) : EntityBase();
);
public class BaseEventConverter : JsonConverter<BaseEvent> {
public override BaseEvent? Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) {

View File

@ -245,7 +245,6 @@ public record NodeAssignment(
public record Task(
// Timestamp: Optional[datetime] = Field(alias="Timestamp")
[PartitionKey] Guid JobId,
[RowKey] Guid TaskId,
TaskState State,
@ -307,13 +306,13 @@ public record ApiAccessRule(
Guid[] AllowedGroups
);
//# initial set of admins can only be set during deployment.
//# if admins are set, only admins can update instance configs.
//# if set, only admins can manage pools or scalesets
public record InstanceConfig
(
[PartitionKey, RowKey] string InstanceName,
//# initial set of admins can only be set during deployment.
//# if admins are set, only admins can update instance configs.
Guid[]? Admins,
//# if set, only admins can manage pools or scalesets
bool? AllowPoolManagement,
string[] AllowedAadTenants,
NetworkConfig NetworkConfig,
@ -490,20 +489,20 @@ public record Repro(
UserInfo? UserInfo
) : StatefulEntityBase<VmState>(State);
// TODO: Make this >1 and < 7*24 (more than one hour, less than seven days)
public record ReproConfig(
Container Container,
string Path,
// TODO: Make this >1 and < 7*24 (more than one hour, less than seven days)
int Duration
);
// Skipping AutoScaleConfig because it's not used anymore
public record Pool(
DateTimeOffset Timestamp,
PoolName Name,
Guid PoolId,
Os Os,
bool Managed,
// Skipping AutoScaleConfig because it's not used anymore
Architecture Architecture,
PoolState State,
Guid? ClientId,

View File

@ -9,6 +9,11 @@ public enum WebhookMessageFormat {
EventGrid
}
public record WebhookMessageQueueObj(
Guid WebhookId,
Guid EventId
);
public record WebhookMessage(Guid EventId,
EventType EventType,
BaseEvent Event,
@ -27,24 +32,18 @@ public record WebhookMessageEventGrid(
[property: JsonConverter(typeof(BaseEventConverter))]
BaseEvent Data);
// TODO: This should inherit from Entity Base ? no, since there is
// a table WebhookMessaageLog
public record WebhookMessageLog(
[RowKey] Guid EventId,
EventType EventType,
[property: TypeDiscrimnatorAttribute("EventType", typeof(EventTypeProvider))]
[property: JsonConverter(typeof(BaseEventConverter))]
BaseEvent Event,
Guid InstanceId,
String InstanceName,
[PartitionKey] Guid WebhookId,
WebhookMessageState State = WebhookMessageState.Queued,
int TryCount = 0
) : WebhookMessage(EventId,
EventType,
Event,
InstanceId,
InstanceName,
WebhookId);
long TryCount,
WebhookMessageState State = WebhookMessageState.Queued
) : EntityBase();
public record Webhook(
[PartitionKey] Guid WebhookId,

View File

@ -45,7 +45,8 @@ public class Program {
return loggers;
}
//Move out expensive resources into separate class, and add those as Singleton
// ArmClient, Table Client(s), Queue Client(s), HttpClient, etc.
public static void Main() {
var host = new HostBuilder()
.ConfigureFunctionsWorkerDefaults(
@ -82,7 +83,7 @@ public class Program {
.AddScoped<IConfig, Config>()
//Move out expensive resources into separate class, and add those as Singleton
// ArmClient, Table Client(s), Queue Client(s), HttpClient, etc.
// ArmClient, Table Client(s), Queue Client(s), HttpClient, etc.\
.AddSingleton<ICreds, Creds>()
.AddSingleton<IServiceConfig, ServiceConfiguration>()
.AddHttpClient()

View File

@ -0,0 +1,24 @@
using System.Text.Json;
using Microsoft.Azure.Functions.Worker;
using Microsoft.OneFuzz.Service.OneFuzzLib.Orm;
namespace Microsoft.OneFuzz.Service;
public class QueueWebhooks {
private readonly ILogTracer _log;
private readonly IWebhookMessageLogOperations _webhookMessageLog;
public QueueWebhooks(ILogTracer log, IWebhookMessageLogOperations webhookMessageLog) {
_log = log;
_webhookMessageLog = webhookMessageLog;
}
[Function("QueueWebhooks")]
public async Async.Task Run([QueueTrigger("myqueue-items", Connection = "AzureWebJobsStorage")] string msg) {
_log.Info($"Webhook Message Queued: {msg}");
var obj = JsonSerializer.Deserialize<WebhookMessageQueueObj>(msg, EntityConverter.GetJsonSerializerOptions()).EnsureNotNull($"wrong data {msg}");
await _webhookMessageLog.ProcessFromQueue(obj);
}
}

View File

@ -35,7 +35,7 @@ public interface IServiceConfig {
public string? OneFuzzResourceGroup { get; }
public string? OneFuzzTelemetry { get; }
public string OnefuzzVersion { get; }
public string OneFuzzVersion { get; }
}
public class ServiceConfiguration : IServiceConfig {
@ -77,5 +77,5 @@ public class ServiceConfiguration : IServiceConfig {
public string? OneFuzzOwner { get => Environment.GetEnvironmentVariable("ONEFUZZ_OWNER"); }
public string? OneFuzzResourceGroup { get => Environment.GetEnvironmentVariable("ONEFUZZ_RESOURCE_GROUP"); }
public string? OneFuzzTelemetry { get => Environment.GetEnvironmentVariable("ONEFUZZ_TELEMETRY"); }
public string OnefuzzVersion { get => Environment.GetEnvironmentVariable("ONEFUZZ_VERSION") ?? "0.0.0"; }
public string OneFuzzVersion { get => Environment.GetEnvironmentVariable("ONEFUZZ_VERSION") ?? "0.0.0"; }
}

View File

@ -1,125 +1,124 @@
using System.Threading.Tasks;
using Azure;
using Azure.ResourceManager;
using Azure.Storage;
using Azure.Storage.Blobs;
using Azure.Storage.Sas;
namespace Microsoft.OneFuzz.Service;
public interface IContainers {
public Task<BinaryData?> GetBlob(Container container, string name, StorageType storageType);
public Async.Task<BlobContainerClient?> FindContainer(Container container, StorageType storageType);
public Async.Task<Uri?> GetFileSasUrl(Container container, string name, StorageType storageType, BlobSasPermissions permissions, TimeSpan? duration = null);
Async.Task saveBlob(Container container, string v1, string v2, StorageType config);
Task<Guid> GetInstanceId();
}
public class Containers : IContainers {
private ILogTracer _log;
private IStorage _storage;
private ICreds _creds;
private ArmClient _armClient;
public Containers(ILogTracer log, IStorage storage, ICreds creds) {
_log = log;
_storage = storage;
_creds = creds;
_armClient = creds.ArmClient;
}
public async Task<BinaryData?> GetBlob(Container container, string name, StorageType storageType) {
var client = await FindContainer(container, storageType);
if (client == null) {
return null;
}
try {
return (await client.GetBlobClient(name).DownloadContentAsync())
.Value.Content;
} catch (RequestFailedException) {
return null;
}
}
public async Async.Task<BlobContainerClient?> FindContainer(Container container, StorageType storageType) {
// # check secondary accounts first by searching in reverse.
// #
// # By implementation, the primary account is specified first, followed by
// # any secondary accounts.
// #
// # Secondary accounts, if they exist, are preferred for containers and have
// # increased IOP rates, this should be a slight optimization
return await _storage.GetAccounts(storageType)
.Reverse()
.Select(account => GetBlobService(account)?.GetBlobContainerClient(container.ContainerName))
.ToAsyncEnumerable()
.WhereAwait(async client => client != null && (await client.ExistsAsync()).Value)
.FirstOrDefaultAsync();
}
private BlobServiceClient? GetBlobService(string accountId) {
_log.Info($"getting blob container (account_id: {accountId}");
var (accountName, accountKey) = _storage.GetStorageAccountNameAndKey(accountId);
if (accountName == null) {
_log.Error("Failed to get storage account name");
return null;
}
var storageKeyCredential = new StorageSharedKeyCredential(accountName, accountKey);
var accountUrl = GetUrl(accountName);
return new BlobServiceClient(accountUrl, storageKeyCredential);
}
private static Uri GetUrl(string accountName) {
return new Uri($"https://{accountName}.blob.core.windows.net/");
}
public async Async.Task<Uri?> GetFileSasUrl(Container container, string name, StorageType storageType, BlobSasPermissions permissions, TimeSpan? duration = null) {
var client = await FindContainer(container, storageType) ?? throw new Exception($"unable to find container: {container.ContainerName} - {storageType}");
var (accountName, accountKey) = _storage.GetStorageAccountNameAndKey(client.AccountName);
var (startTime, endTime) = SasTimeWindow(duration ?? TimeSpan.FromDays(30));
var sasBuilder = new BlobSasBuilder(permissions, endTime) {
StartsOn = startTime,
BlobContainerName = container.ContainerName,
BlobName = name
};
var sasUrl = client.GetBlobClient(name).GenerateSasUri(sasBuilder);
return sasUrl;
}
public (DateTimeOffset, DateTimeOffset) SasTimeWindow(TimeSpan timeSpan) {
// SAS URLs are valid 6 hours earlier, primarily to work around dev
// workstations having out-of-sync time. Additionally, SAS URLs are stopped
// 15 minutes later than requested based on "Be careful with SAS start time"
// guidance.
// Ref: https://docs.microsoft.com/en-us/azure/storage/common/storage-sas-overview
var SAS_START_TIME_DELTA = TimeSpan.FromHours(6);
var SAS_END_TIME_DELTA = TimeSpan.FromMinutes(6);
var now = DateTimeOffset.UtcNow;
var start = now - SAS_START_TIME_DELTA;
var expiry = now + timeSpan + SAS_END_TIME_DELTA;
return (start, expiry);
}
public async System.Threading.Tasks.Task saveBlob(Container container, string name, string data, StorageType storageType) {
var client = await FindContainer(container, storageType) ?? throw new Exception($"unable to find container: {container.ContainerName} - {storageType}");
await client.UploadBlobAsync(name, new BinaryData(data));
}
public async Async.Task<Guid> GetInstanceId() {
var blob = await GetBlob(new Container("base-config"), "instance_id", StorageType.Config);
if (blob == null) {
throw new System.Exception("Blob Not Found");
}
return System.Guid.Parse(blob.ToString());
}
}
using System.Threading.Tasks;
using Azure;
using Azure.ResourceManager;
using Azure.Storage;
using Azure.Storage.Blobs;
using Azure.Storage.Sas;
namespace Microsoft.OneFuzz.Service;
public interface IContainers {
public Task<BinaryData?> GetBlob(Container container, string name, StorageType storageType);
public Async.Task<BlobContainerClient?> FindContainer(Container container, StorageType storageType);
public Async.Task<Uri?> GetFileSasUrl(Container container, string name, StorageType storageType, BlobSasPermissions permissions, TimeSpan? duration = null);
Async.Task saveBlob(Container container, string v1, string v2, StorageType config);
Task<Guid> GetInstanceId();
}
public class Containers : IContainers {
private ILogTracer _log;
private IStorage _storage;
private ICreds _creds;
private ArmClient _armClient;
public Containers(ILogTracer log, IStorage storage, ICreds creds) {
_log = log;
_storage = storage;
_creds = creds;
_armClient = creds.ArmClient;
}
public async Task<BinaryData?> GetBlob(Container container, string name, StorageType storageType) {
var client = await FindContainer(container, storageType);
if (client == null) {
return null;
}
try {
return (await client.GetBlobClient(name).DownloadContentAsync())
.Value.Content;
} catch (RequestFailedException) {
return null;
}
}
public async Async.Task<BlobContainerClient?> FindContainer(Container container, StorageType storageType) {
// # check secondary accounts first by searching in reverse.
// #
// # By implementation, the primary account is specified first, followed by
// # any secondary accounts.
// #
// # Secondary accounts, if they exist, are preferred for containers and have
// # increased IOP rates, this should be a slight optimization
return await _storage.GetAccounts(storageType)
.Reverse()
.Select(account => GetBlobService(account)?.GetBlobContainerClient(container.ContainerName))
.ToAsyncEnumerable()
.WhereAwait(async client => client != null && (await client.ExistsAsync()).Value)
.FirstOrDefaultAsync();
}
private BlobServiceClient? GetBlobService(string accountId) {
_log.Info($"getting blob container (account_id: {accountId}");
var (accountName, accountKey) = _storage.GetStorageAccountNameAndKey(accountId);
if (accountName == null) {
_log.Error("Failed to get storage account name");
return null;
}
var storageKeyCredential = new StorageSharedKeyCredential(accountName, accountKey);
var accountUrl = GetUrl(accountName);
return new BlobServiceClient(accountUrl, storageKeyCredential);
}
private static Uri GetUrl(string accountName) {
return new Uri($"https://{accountName}.blob.core.windows.net/");
}
public async Async.Task<Uri?> GetFileSasUrl(Container container, string name, StorageType storageType, BlobSasPermissions permissions, TimeSpan? duration = null) {
var client = await FindContainer(container, storageType) ?? throw new Exception($"unable to find container: {container.ContainerName} - {storageType}");
var (accountName, accountKey) = _storage.GetStorageAccountNameAndKey(client.AccountName);
var (startTime, endTime) = SasTimeWindow(duration ?? TimeSpan.FromDays(30));
var sasBuilder = new BlobSasBuilder(permissions, endTime) {
StartsOn = startTime,
BlobContainerName = container.ContainerName,
BlobName = name
};
var sasUrl = client.GetBlobClient(name).GenerateSasUri(sasBuilder);
return sasUrl;
}
public (DateTimeOffset, DateTimeOffset) SasTimeWindow(TimeSpan timeSpan) {
// SAS URLs are valid 6 hours earlier, primarily to work around dev
// workstations having out-of-sync time. Additionally, SAS URLs are stopped
// 15 minutes later than requested based on "Be careful with SAS start time"
// guidance.
// Ref: https://docs.microsoft.com/en-us/azure/storage/common/storage-sas-overview
var SAS_START_TIME_DELTA = TimeSpan.FromHours(6);
var SAS_END_TIME_DELTA = TimeSpan.FromMinutes(6);
var now = DateTimeOffset.UtcNow;
var start = now - SAS_START_TIME_DELTA;
var expiry = now + timeSpan + SAS_END_TIME_DELTA;
return (start, expiry);
}
public async System.Threading.Tasks.Task saveBlob(Container container, string name, string data, StorageType storageType) {
var client = await FindContainer(container, storageType) ?? throw new Exception($"unable to find container: {container.ContainerName} - {storageType}");
await client.UploadBlobAsync(name, new BinaryData(data));
}
public async Async.Task<Guid> GetInstanceId() {
var blob = await GetBlob(new Container("base-config"), "instance_id", StorageType.Config);
if (blob == null) {
throw new System.Exception("Blob Not Found");
}
return System.Guid.Parse(blob.ToString());
}
}

View File

@ -14,6 +14,7 @@ public interface ICreds {
public ResourceIdentifier GetResourceGroupResourceIdentifier();
public string GetInstanceName();
public ArmClient ArmClient { get; }
@ -59,6 +60,13 @@ public class Creds : ICreds {
return new ResourceIdentifier(resourceId);
}
public string GetInstanceName() {
var instanceName = _config.OneFuzzInstanceName
?? throw new System.Exception("Instance Name env var is not present");
return instanceName;
}
public ResourceGroupResource GetResourceGroupResource() {
var resourceId = GetResourceGroupResourceIdentifier();
return ArmClient.GetResourceGroupResource(resourceId);

View File

@ -55,5 +55,4 @@ public partial class TimerProxy {
}
}
}

View File

@ -55,7 +55,7 @@ public class ProxyOperations : StatefulOrm<Proxy, VmState>, IProxyOperations {
}
_logTracer.Info($"creating proxy: region:{region}");
var newProxy = new Proxy(region, Guid.NewGuid(), DateTimeOffset.UtcNow, VmState.Init, Auth.BuildAuth(), null, null, _config.OnefuzzVersion, null, false);
var newProxy = new Proxy(region, Guid.NewGuid(), DateTimeOffset.UtcNow, VmState.Init, Auth.BuildAuth(), null, null, _config.OneFuzzVersion, null, false);
await Replace(newProxy);
await _events.SendEvent(new EventProxyCreated(region, newProxy.ProxyId));
@ -83,8 +83,8 @@ public class ProxyOperations : StatefulOrm<Proxy, VmState>, IProxyOperations {
return false;
}
if (proxy.Version != _config.OnefuzzVersion) {
_logTracer.Info($"mismatch version: proxy:{proxy.Version} service:{_config.OnefuzzVersion} state:{proxy.State}");
if (proxy.Version != _config.OneFuzzVersion) {
_logTracer.Info($"mismatch version: proxy:{proxy.Version} service:{_config.OneFuzzVersion} state:{proxy.State}");
return true;
}

View File

@ -14,7 +14,6 @@ public interface IStorage {
public IEnumerable<string> CorpusAccounts();
string GetPrimaryAccount(StorageType storageType);
public (string?, string?) GetStorageAccountNameAndKey(string accountId);
public IEnumerable<string> GetAccounts(StorageType storageType);
}
@ -98,6 +97,28 @@ public class Storage : IStorage {
return (resourceId.Name, key?.Value);
}
public string ChooseAccounts(StorageType storageType) {
var accounts = GetAccounts(storageType);
if (!accounts.Any()) {
throw new Exception($"No Storage Accounts for {storageType}");
}
var account_list = accounts.ToList();
if (account_list.Count == 1) {
return account_list[0];
}
// Use a random secondary storage account if any are available. This
// reduces IOP contention for the Storage Queues, which are only available
// on primary accounts
//
// security note: this is not used as a security feature
var random = new Random();
var index = random.Next(account_list.Count);
return account_list[index]; // nosec
}
public IEnumerable<string> GetAccounts(StorageType storageType) {
switch (storageType) {
case StorageType.Corpus:

View File

@ -1,106 +1,253 @@
using ApiService.OneFuzzLib.Orm;
namespace Microsoft.OneFuzz.Service;
public interface IWebhookMessageLogOperations : IOrm<WebhookMessageLog> {
IAsyncEnumerable<WebhookMessageLog> SearchExpired();
}
public class WebhookMessageLogOperations : Orm<WebhookMessageLog>, IWebhookMessageLogOperations {
const int EXPIRE_DAYS = 7;
record WebhookMessageQueueObj(
Guid WebhookId,
Guid EventId
);
private readonly IQueue _queue;
private readonly ILogTracer _log;
public WebhookMessageLogOperations(IStorage storage, IQueue queue, ILogTracer log, IServiceConfig config) : base(storage, log, config) {
_queue = queue;
_log = log;
}
public async Async.Task QueueWebhook(WebhookMessageLog webhookLog) {
var obj = new WebhookMessageQueueObj(webhookLog.WebhookId, webhookLog.EventId);
TimeSpan? visibilityTimeout = webhookLog.State switch {
WebhookMessageState.Queued => TimeSpan.Zero,
WebhookMessageState.Retrying => TimeSpan.FromSeconds(30),
_ => null
};
if (visibilityTimeout == null) {
_log.WithTags(
new[] {
("WebhookId", webhookLog.WebhookId.ToString()),
("EventId", webhookLog.EventId.ToString()) }
).
Error($"invalid WebhookMessage queue state, not queuing. {webhookLog.WebhookId}:{webhookLog.EventId} - {webhookLog.State}");
} else {
await _queue.QueueObject("webhooks", obj, StorageType.Config, visibilityTimeout: visibilityTimeout);
}
}
private void QueueObject(string v, WebhookMessageQueueObj obj, StorageType config, int? visibility_timeout) {
throw new NotImplementedException();
}
public IAsyncEnumerable<WebhookMessageLog> SearchExpired() {
var expireTime = (DateTimeOffset.UtcNow - TimeSpan.FromDays(EXPIRE_DAYS)).ToString("o");
var timeFilter = $"Timestamp lt datetime'{expireTime}'";
return QueryAsync(filter: timeFilter);
}
}
public interface IWebhookOperations {
Async.Task SendEvent(EventMessage eventMessage);
}
public class WebhookOperations : Orm<Webhook>, IWebhookOperations {
private readonly IWebhookMessageLogOperations _webhookMessageLogOperations;
private readonly ILogTracer _log;
public WebhookOperations(IStorage storage, IWebhookMessageLogOperations webhookMessageLogOperations, ILogTracer log, IServiceConfig config)
: base(storage, log, config) {
_webhookMessageLogOperations = webhookMessageLogOperations;
_log = log;
}
async public Async.Task SendEvent(EventMessage eventMessage) {
await foreach (var webhook in GetWebhooksCached()) {
if (!webhook.EventTypes.Contains(eventMessage.EventType)) {
continue;
}
await AddEvent(webhook, eventMessage);
}
}
async private Async.Task AddEvent(Webhook webhook, EventMessage eventMessage) {
var message = new WebhookMessageLog(
EventId: eventMessage.EventId,
EventType: eventMessage.EventType,
Event: eventMessage.Event,
InstanceId: eventMessage.InstanceId,
InstanceName: eventMessage.InstanceName,
WebhookId: webhook.WebhookId
);
var r = await _webhookMessageLogOperations.Replace(message);
if (!r.IsOk) {
var (status, reason) = r.ErrorV;
_log.Error($"Failed to replace webhook message log due to [{status}] {reason}");
}
}
//todo: caching
public IAsyncEnumerable<Webhook> GetWebhooksCached() {
return QueryAsync();
}
}
using System.Net;
using System.Net.Http;
using System.Security.Cryptography;
using System.Text.Json;
using ApiService.OneFuzzLib.Orm;
using Microsoft.OneFuzz.Service.OneFuzzLib.Orm;
namespace Microsoft.OneFuzz.Service;
public interface IWebhookOperations {
Async.Task SendEvent(EventMessage eventMessage);
Async.Task<Webhook?> GetByWebhookId(Guid webhookId);
Async.Task<bool> Send(WebhookMessageLog messageLog);
}
public class WebhookOperations : Orm<Webhook>, IWebhookOperations {
private readonly IWebhookMessageLogOperations _webhookMessageLogOperations;
private readonly ILogTracer _log;
private readonly ICreds _creds;
private readonly IContainers _containers;
private readonly IHttpClientFactory _httpFactory;
public WebhookOperations(IHttpClientFactory httpFactory, ICreds creds, IStorage storage, IWebhookMessageLogOperations webhookMessageLogOperations, IContainers containers, ILogTracer log, IServiceConfig config)
: base(storage, log, config) {
_webhookMessageLogOperations = webhookMessageLogOperations;
_log = log;
_creds = creds;
_containers = containers;
_httpFactory = httpFactory;
}
async public Async.Task SendEvent(EventMessage eventMessage) {
await foreach (var webhook in GetWebhooksCached()) {
if (!webhook.EventTypes.Contains(eventMessage.EventType)) {
continue;
}
await AddEvent(webhook, eventMessage);
}
}
async private Async.Task AddEvent(Webhook webhook, EventMessage eventMessage) {
var message = new WebhookMessageLog(
EventId: eventMessage.EventId,
EventType: eventMessage.EventType,
Event: eventMessage.Event,
InstanceId: eventMessage.InstanceId,
InstanceName: eventMessage.InstanceName,
WebhookId: webhook.WebhookId,
TryCount: 0
);
var r = await _webhookMessageLogOperations.Replace(message);
if (!r.IsOk) {
var (status, reason) = r.ErrorV;
_log.Error($"Failed to replace webhook message log due to [{status}] {reason}");
}
}
public async Async.Task<bool> Send(WebhookMessageLog messageLog) {
var webhook = await GetByWebhookId(messageLog.WebhookId);
if (webhook == null || webhook.Url == null) {
throw new Exception($"Invalid Webhook. Webhook with WebhookId: {messageLog.WebhookId} Not Found");
}
var (data, digest) = await BuildMessage(webhookId: webhook.WebhookId, eventId: messageLog.EventId, eventType: messageLog.EventType, webhookEvent: messageLog.Event, secretToken: webhook.SecretToken, messageFormat: webhook.MessageFormat);
var headers = new Dictionary<string, string> { { "User-Agent", $"onefuzz-webhook {_config.OneFuzzVersion}" } };
if (digest != null) {
headers["X-Onefuzz-Digest"] = digest;
}
var client = new Request(_httpFactory.CreateClient());
_log.Info(data);
var response = client.Post(url: webhook.Url, json: data, headers: headers);
var result = response.Result;
if (result.StatusCode == HttpStatusCode.Accepted) {
return true;
}
return false;
}
// Not converting to bytes, as it's not neccessary in C#. Just keeping as string.
public async Async.Task<Tuple<string, string?>> BuildMessage(Guid webhookId, Guid eventId, EventType eventType, BaseEvent webhookEvent, String? secretToken, WebhookMessageFormat? messageFormat) {
var entityConverter = new EntityConverter();
string data = "";
if (messageFormat != null && messageFormat == WebhookMessageFormat.EventGrid) {
var eventGridMessage = new[] { new WebhookMessageEventGrid(Id: eventId, Data: webhookEvent, DataVersion: "1.0.0", Subject: _creds.GetInstanceName(), EventType: eventType, EventTime: DateTimeOffset.UtcNow) };
data = JsonSerializer.Serialize(eventGridMessage, options: EntityConverter.GetJsonSerializerOptions());
} else {
var instanceId = await _containers.GetInstanceId();
var webhookMessage = new WebhookMessage(WebhookId: webhookId, EventId: eventId, EventType: eventType, Event: webhookEvent, InstanceId: instanceId, InstanceName: _creds.GetInstanceName());
data = JsonSerializer.Serialize(webhookMessage, options: EntityConverter.GetJsonSerializerOptions());
}
string? digest = null;
var hmac = HMAC.Create("HMACSHA512");
if (secretToken != null && hmac != null) {
hmac.Key = System.Text.Encoding.UTF8.GetBytes(secretToken);
digest = Convert.ToHexString(hmac.ComputeHash(System.Text.Encoding.UTF8.GetBytes(data)));
}
return new Tuple<string, string?>(data, digest);
}
public async Async.Task<Webhook?> GetByWebhookId(Guid webhookId) {
var data = QueryAsync(filter: $"PartitionKey eq '{webhookId}'");
return await data.FirstOrDefaultAsync();
}
//todo: caching
public IAsyncEnumerable<Webhook> GetWebhooksCached() {
return QueryAsync();
}
}
public interface IWebhookMessageLogOperations : IOrm<WebhookMessageLog> {
IAsyncEnumerable<WebhookMessageLog> SearchExpired();
public Async.Task ProcessFromQueue(WebhookMessageQueueObj obj);
}
public class WebhookMessageLogOperations : Orm<WebhookMessageLog>, IWebhookMessageLogOperations {
const int EXPIRE_DAYS = 7;
const int MAX_TRIES = 5;
private readonly IQueue _queue;
private readonly ILogTracer _log;
private readonly IWebhookOperations _webhook;
public WebhookMessageLogOperations(IStorage storage, IQueue queue, ILogTracer log, IServiceConfig config, ICreds creds, IHttpClientFactory httpFactory, IContainers containers) : base(storage, log, config) {
_queue = queue;
_log = log;
_webhook = new WebhookOperations(httpFactory: httpFactory, creds: creds, storage: storage, webhookMessageLogOperations: this, containers: containers, log: log, config: config);
}
public async Async.Task QueueWebhook(WebhookMessageLog webhookLog) {
var obj = new WebhookMessageQueueObj(webhookLog.WebhookId, webhookLog.EventId);
TimeSpan? visibilityTimeout = webhookLog.State switch {
WebhookMessageState.Queued => TimeSpan.Zero,
WebhookMessageState.Retrying => TimeSpan.FromSeconds(30),
_ => null
};
if (visibilityTimeout == null) {
_log.WithTags(
new[] {
("WebhookId", webhookLog.WebhookId.ToString()),
("EventId", webhookLog.EventId.ToString()) }
).
Error($"invalid WebhookMessage queue state, not queuing. {webhookLog.WebhookId}:{webhookLog.EventId} - {webhookLog.State}");
} else {
await _queue.QueueObject("webhooks", obj, StorageType.Config, visibilityTimeout: visibilityTimeout);
}
}
public async Async.Task ProcessFromQueue(WebhookMessageQueueObj obj) {
var message = await GetWebhookMessageById(obj.WebhookId, obj.EventId);
if (message == null) {
_log.WithTags(
new[] {
("WebhookId", obj.WebhookId.ToString()),
("EventId", obj.EventId.ToString()) }
).
Error($"webhook message log not found for webhookId: {obj.WebhookId} and eventId: {obj.EventId}");
} else {
await Process(message);
}
}
private async System.Threading.Tasks.Task Process(WebhookMessageLog message) {
if (message.State == WebhookMessageState.Failed || message.State == WebhookMessageState.Succeeded) {
_log.WithTags(
new[] {
("WebhookId", message.WebhookId.ToString()),
("EventId", message.EventId.ToString()) }
).
Error($"webhook message already handled. {message.WebhookId}:{message.EventId}");
return;
}
var newMessage = message with { TryCount = message.TryCount + 1 };
_log.Info($"sending webhook: {message.WebhookId}:{message.EventId}");
var success = await Send(newMessage);
if (success) {
newMessage = newMessage with { State = WebhookMessageState.Succeeded };
await Replace(newMessage);
_log.Info($"sent webhook event {newMessage.WebhookId}:{newMessage.EventId}");
} else if (newMessage.TryCount < MAX_TRIES) {
newMessage = newMessage with { State = WebhookMessageState.Retrying };
await Replace(newMessage);
await QueueWebhook(newMessage);
_log.Warning($"sending webhook event failed, re-queued {newMessage.WebhookId}:{newMessage.EventId}");
} else {
newMessage = newMessage with { State = WebhookMessageState.Failed };
await Replace(newMessage);
_log.Info($"sending webhook: {newMessage.WebhookId} event: {newMessage.EventId} failed {newMessage.TryCount} times.");
}
}
private async Async.Task<bool> Send(WebhookMessageLog message) {
var webhook = await _webhook.GetByWebhookId(message.WebhookId);
if (webhook == null) {
_log.WithTags(
new[] {
("WebhookId", message.WebhookId.ToString()),
}
).
Error($"webhook not found for webhookId: {message.WebhookId}");
return false;
}
try {
return await _webhook.Send(message);
} catch (Exception exc) {
_log.WithTags(
new[] {
("WebhookId", message.WebhookId.ToString())
}
).
Exception(exc);
return false;
}
}
private void QueueObject(string v, WebhookMessageQueueObj obj, StorageType config, int? visibility_timeout) {
throw new NotImplementedException();
}
public IAsyncEnumerable<WebhookMessageLog> SearchExpired() {
var expireTime = (DateTimeOffset.UtcNow - TimeSpan.FromDays(EXPIRE_DAYS)).ToString("o");
var timeFilter = $"Timestamp lt datetime'{expireTime}'";
return QueryAsync(filter: timeFilter);
}
public async Async.Task<WebhookMessageLog?> GetWebhookMessageById(Guid webhookId, Guid eventId) {
var data = QueryAsync(filter: $"PartitionKey eq '{webhookId}' and RowKey eq '{eventId}'");
return await data.FirstOrDefaultAsync();
}
}

File diff suppressed because it is too large Load Diff

View File

@ -244,9 +244,6 @@ namespace Tests {
); ;
}
public static Gen<Report> Report() {
return Arb.Generate<Tuple<string, BlobRef, List<string>, Guid, int>>().Select(
arg =>
@ -373,6 +370,7 @@ namespace Tests {
return Arb.From(OrmGenerators.Notification());
}
public static Arbitrary<WebhookMessageEventGrid> WebhookMessageEventGrid() {
return Arb.From(OrmGenerators.WebhookMessageEventGrid());
}
@ -548,7 +546,6 @@ namespace Tests {
return Test(n);
}
[Property]
public bool Job(Job j) {
return Test(j);

View File

@ -246,19 +246,6 @@ namespace Tests {
Assert.Equal(expected.TheName, actual.TheName);
}
[Fact]
public void TestEventSerialization2() {
var converter = new EntityConverter();
var expectedEvent = new EventMessage(Guid.NewGuid(), EventType.NodeHeartbeat, new EventNodeHeartbeat(Guid.NewGuid(), Guid.NewGuid(), "test Poool"), Guid.NewGuid(), "test") {
ETag = new Azure.ETag("33a64df551425fcc55e4d42a148795d9f25f89d4")
};
var te = converter.ToTableEntity(expectedEvent);
var actualEvent = converter.ToRecord<EventMessage>(te);
Assert.Equal(expectedEvent, actualEvent);
}
record Entity3(
[PartitionKey] int Id,
[RowKey] string TheName,