Continue port of QueueNodeHearbeat (#1761)

[WIP ] continue port of  QueueNodeHearbeat
This commit is contained in:
Cheick Keita
2022-04-07 11:55:12 -07:00
committed by GitHub
parent 5004b7902a
commit 627b401c87
16 changed files with 1043 additions and 306 deletions

View File

@ -91,30 +91,49 @@ class Console : ILog {
}
}
public class LogTracer {
public interface ILogTracer
{
IDictionary<string, string> Tags { get; }
void Critical(string message);
void Error(string message);
void Event(string evt, IDictionary<string, double>? metrics);
void Exception(Exception ex, IDictionary<string, double>? metrics);
void ForceFlush();
void Info(string message);
void Warning(string message);
}
public class LogTracer : ILogTracer
{
private List<ILog> loggers;
private IDictionary<string, string> tags = new Dictionary<string, string>();
private Guid correlationId;
public LogTracer(Guid correlationId, List<ILog> loggers) {
public LogTracer(Guid correlationId, List<ILog> loggers)
{
this.correlationId = correlationId;
this.loggers = loggers;
}
public IDictionary<string, string> Tags => tags;
public void Info(string message) {
public void Info(string message)
{
var caller = new StackTrace()?.GetFrame(1)?.GetMethod()?.Name;
foreach (var logger in loggers) {
foreach (var logger in loggers)
{
logger.Log(correlationId, message, SeverityLevel.Information, Tags, caller);
}
}
public void Warning(string message) {
public void Warning(string message)
{
var caller = new StackTrace()?.GetFrame(1)?.GetMethod()?.Name;
foreach (var logger in loggers) {
foreach (var logger in loggers)
{
logger.Log(correlationId, message, SeverityLevel.Warning, Tags, caller);
}
}
@ -137,36 +156,50 @@ public class LogTracer {
}
}
public void Event(string evt, IDictionary<string, double>? metrics) {
public void Event(string evt, IDictionary<string, double>? metrics)
{
var caller = new StackTrace()?.GetFrame(1)?.GetMethod()?.Name;
foreach (var logger in loggers) {
foreach (var logger in loggers)
{
logger.LogEvent(correlationId, evt, Tags, metrics, caller);
}
}
public void Exception(Exception ex, IDictionary<string, double>? metrics) {
public void Exception(Exception ex, IDictionary<string, double>? metrics)
{
var caller = new StackTrace()?.GetFrame(1)?.GetMethod()?.Name;
foreach (var logger in loggers) {
foreach (var logger in loggers)
{
logger.LogException(correlationId, ex, Tags, metrics, caller);
}
}
public void ForceFlush() {
foreach (var logger in loggers) {
public void ForceFlush()
{
foreach (var logger in loggers)
{
logger.Flush();
}
}
}
public class LogTracerFactory {
public interface ILogTracerFactory
{
LogTracer MakeLogTracer(Guid correlationId);
}
public class LogTracerFactory : ILogTracerFactory
{
private List<ILog> loggers;
public LogTracerFactory(List<ILog> loggers) {
public LogTracerFactory(List<ILog> loggers)
{
this.loggers = loggers;
}
public LogTracer MakeLogTracer(Guid correlationId) {
public LogTracer MakeLogTracer(Guid correlationId)
{
return new LogTracer(correlationId, this.loggers);
}

View File

@ -25,4 +25,12 @@ public enum ErrorCode {
UNABLE_TO_UPDATE = 471,
PROXY_FAILED = 472,
INVALID_CONFIGURATION = 473,
}
public enum WebhookMessageState {
Queued,
Retrying,
Succeeded,
Failed
}

View File

@ -0,0 +1,251 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Text.Json.Serialization;
using System.Threading.Tasks;
using PoolName = System.String;
namespace Microsoft.OneFuzz.Service
{
public enum EventType
{
JobCreated,
JobStopped,
NodeCreated,
NodeDeleted,
NodeStateUpdated,
Ping,
PoolCreated,
PoolDeleted,
ProxyCreated,
ProxyDeleted,
ProxyFailed,
ProxyStateUpdated,
ScalesetCreated,
ScalesetDeleted,
ScalesetFailed,
ScalesetStateUpdated,
ScalesetResizeScheduled,
TaskCreated,
TaskFailed,
TaskStateUpdated,
TaskStopped,
CrashReported,
RegressionReported,
FileAdded,
TaskHeartbeat,
NodeHeartbeat,
InstanceConfigUpdated
}
public abstract record BaseEvent() {
public EventType GetEventType() {
return
this switch {
EventNodeHeartbeat _ => EventType.NodeHeartbeat,
_ => throw new NotImplementedException(),
};
}
};
//public record EventTaskStopped(
// Guid JobId,
// Guid TaskId,
// UserInfo? UserInfo,
// TaskConfig Config
//) : BaseEvent();
//record EventTaskFailed(
// Guid JobId,
// Guid TaskId,
// Error Error,
// UserInfo? UserInfo,
// TaskConfig Config
// ) : BaseEvent();
//record EventJobCreated(
// Guid JobId,
// JobConfig Config,
// UserInfo? UserInfo
// ) : BaseEvent();
//record JobTaskStopped(
// Guid TaskId,
// TaskType TaskType,
// Error? Error
// ) : BaseEvent();
//record EventJobStopped(
// Guid JobId: UUId,
// JobConfig Config,
// UserInfo? UserInfo,
// List<JobTaskStopped> TaskInfo
//): BaseEvent();
//record EventTaskCreated(
// Guid JobId,
// Guid TaskId,
// TaskConfig Config,
// UserInfo? UserInfo
// ) : BaseEvent();
//record EventTaskStateUpdated(
// Guid JobId,
// Guid TaskId,
// TaskState State,
// DateTimeOffset? EndTime,
// TaskConfig Config
// ) : BaseEvent();
//record EventTaskHeartbeat(
// JobId: Guid,
// TaskId: Guid,
// Config: TaskConfig
//): BaseEvent();
//record EventPing(
// PingId: Guid
//): BaseEvent();
//record EventScalesetCreated(
// Guid ScalesetId,
// PoolName PoolName,
// string VmSku,
// string Image,
// Region Region,
// int Size) : BaseEvent();
//record EventScalesetFailed(
// Guid ScalesetId,
// PoolName: PoolName,
// Error: Error
//): BaseEvent();
//record EventScalesetDeleted(
// Guid ScalesetId,
// PoolName PoolName,
// ) : BaseEvent();
//record EventScalesetResizeScheduled(
// Guid ScalesetId,
// PoolName PoolName,
// int size
// ) : BaseEvent();
//record EventPoolDeleted(
// PoolName PoolName
// ) : BaseEvent();
//record EventPoolCreated(
// PoolName PoolName,
// Os Os,
// Architecture Arch,
// bool Managed,
// AutoScaleConfig? Autoscale
// ) : BaseEvent();
//record EventProxyCreated(
// Region Region,
// Guid? ProxyId,
// ) : BaseEvent();
//record EventProxyDeleted(
// Region Region,
// Guid? ProxyId
//) : BaseEvent();
//record EventProxyFailed(
// Region Region,
// Guid? ProxyId,
// Error Error
//) : BaseEvent();
//record EventProxyStateUpdated(
// Region Region,
// Guid ProxyId,
// VmState State
// ) : BaseEvent();
//record EventNodeCreated(
// Guid MachineId,
// Guid? ScalesetId,
// PoolName PoolName
// ) : BaseEvent();
public record EventNodeHeartbeat(
Guid MachineId,
Guid? ScalesetId,
PoolName PoolName
) : BaseEvent();
// record EventNodeDeleted(
// Guid MachineId,
// Guid ScalesetId,
// PoolName PoolName
// ) : BaseEvent();
// record EventScalesetStateUpdated(
// Guid ScalesetId,
// PoolName PoolName,
// ScalesetState State
// ) : BaseEvent();
// record EventNodeStateUpdated(
// Guid MachineId,
// Guid? ScalesetId,
// PoolName PoolName,
// NodeState state
// ) : BaseEvent();
// record EventCrashReported(
// Report Report,
// Container Container,
// [property: JsonPropertyName("filename")] String FileName,
// TaskConfig? TaskConfig
// ) : BaseEvent();
// record EventRegressionReported(
// RegressionReport RegressionReport,
// Container Container,
// [property: JsonPropertyName("filename")] String FileName,
// TaskConfig? TaskConfig
// ) : BaseEvent();
// record EventFileAdded(
// Container Container,
// [property: JsonPropertyName("filename")] String FileName
// ) : BaseEvent();
// record EventInstanceConfigUpdated(
// InstanceConfig Config
// ) : BaseEvent();
}

View File

@ -1,6 +1,7 @@
using Microsoft.OneFuzz.Service.OneFuzzLib.Orm;
using System;
using System.Collections.Generic;
using PoolName = System.String;
namespace Microsoft.OneFuzz.Service;
@ -74,7 +75,7 @@ public enum NodeState
public partial record Node
(
DateTimeOffset? InitializedAt,
[PartitionKey] string PoolName,
[PartitionKey] PoolName PoolName,
Guid? PoolId,
[RowKey] Guid MachineId,
NodeState State,
@ -90,3 +91,34 @@ public partial record Node
public record Error (ErrorCode Code, string[]? Errors = null);
public record UserInfo (Guid? ApplicationId, Guid? ObjectId, String? Upn);
public record EventMessage(
Guid EventId,
EventType EventType,
BaseEvent Event,
Guid InstanceId,
String InstanceName
): EntityBase();
//record AnyHttpUrl(AnyUrl):
// allowed_schemes = {'http', 'https
//
//public record TaskConfig(
// Guid jobId,
// List<Guid> PrereqTasks,
// TaskDetails Task,
// TaskVm? vm,
// TaskPool pool: Optional[]
// containers: List[TaskContainers]
// tags: Dict[str, str]
// debug: Optional[List[TaskDebugFlag]]
// colocate: Optional[bool]
// ): EntityBase();

View File

@ -0,0 +1,59 @@
using Microsoft.OneFuzz.Service.OneFuzzLib.Orm;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Text.Json.Serialization;
using System.Threading.Tasks;
namespace Microsoft.OneFuzz.Service;
public enum WebhookMessageFormat
{
Onefuzz,
EventGrid
}
public record WebhookMessage(Guid EventId,
EventType EventType,
BaseEvent Event,
Guid InstanceId,
String InstanceName,
Guid WebhookId): EventMessage(EventId, EventType, Event, InstanceId, InstanceName);
public record WebhookMessageEventGrid(
[property: JsonPropertyName("dataVersion")] string DataVersion,
string Subject,
[property: JsonPropertyName("EventType")] EventType EventType,
[property: JsonPropertyName("eventTime")] DateTimeOffset eventTime,
Guid Id,
BaseEvent data);
public record WebhookMessageLog(
[RowKey] Guid EventId,
EventType EventType,
BaseEvent Event,
Guid InstanceId,
String InstanceName,
[PartitionKey] Guid WebhookId,
WebhookMessageState State = WebhookMessageState.Queued,
int TryCount = 0
) : WebhookMessage(EventId,
EventType,
Event,
InstanceId,
InstanceName,
WebhookId);
public record Webhook(
[PartitionKey] Guid WebhookId,
[RowKey] string Name,
Uri? url,
List<EventType> EventTypes,
string SecretToken, // SecretString??
WebhookMessageFormat? MessageFormat
) : EntityBase();

View File

@ -34,7 +34,9 @@ public class Program
var host = new HostBuilder()
.ConfigureFunctionsWorkerDefaults()
.ConfigureServices((context, services) =>
services.AddSingleton(_ => new LogTracerFactory(GetLoggers()))
services
.AddSingleton<ILogTracerFactory>(_ => new LogTracerFactory(GetLoggers()))
.AddScoped<ILogTracer>(s => s.GetService<LogTracerFactory>()?.MakeLogTracer(Guid.NewGuid()) ?? throw new InvalidOperationException("Unable to create a logger") )
.AddSingleton<IStorageProvider>(_ => new StorageProvider(EnvironmentVariables.OneFuzz.FuncStorage ?? throw new InvalidOperationException("Missing account id") ))
.AddSingleton<ICreds>(_ => new Creds())
.AddSingleton<IStorage, Storage>()

View File

@ -14,20 +14,25 @@ namespace Microsoft.OneFuzz.Service;
public class QueueNodeHearbeat
{
private readonly ILogger _logger;
private readonly IStorageProvider _storageProvider;
public QueueNodeHearbeat(ILoggerFactory loggerFactory, IStorageProvider storageProvider)
private readonly IEvents _events;
private readonly INodeOperations _nodes;
public QueueNodeHearbeat(ILoggerFactory loggerFactory, INodeOperations nodes, IEvents events)
{
_logger = loggerFactory.CreateLogger<QueueNodeHearbeat>();
_storageProvider = storageProvider;
_nodes = nodes;
_events = events;
}
[Function("QueueNodeHearbeat")]
public async Task Run([QueueTrigger("myqueue-items", Connection = "AzureWebJobsStorage")] string msg)
{
_logger.LogInformation($"heartbeat: {msg}");
var hb = JsonSerializer.Deserialize<NodeHeartbeatEntry>(msg, EntityConverter.GetJsonSerializerOptions()).EnsureNotNull($"wrong data {msg}");
var node = await Node.GetByMachineId(_storageProvider, hb.NodeId);
var node = await _nodes.GetByMachineId(hb.NodeId);
if (node == null) {
_logger.LogWarning($"invalid node id: {hb.NodeId}");
@ -36,16 +41,8 @@ public class QueueNodeHearbeat
var newNode = node with { Heartbeat = DateTimeOffset.UtcNow };
await _storageProvider.Replace(newNode);
await _nodes.Replace(newNode);
//send_event(
// EventNodeHeartbeat(
// machine_id = node.machine_id,
// scaleset_id = node.scaleset_id,
// pool_name = node.pool_name,
// )
//)
_logger.LogInformation($"heartbeat: {msg}");
await _events.SendEvent(new EventNodeHeartbeat(node.MachineId, node.ScalesetId, node.PoolName));
}
}

View File

@ -0,0 +1,73 @@
using ApiService.OneFuzzLib;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
namespace Microsoft.OneFuzz.Service
{
public record SignalREvent
(
string Target,
List<EventMessage> arguments
);
public interface IEvents {
public Task SendEvent(BaseEvent anEvent);
public Task QueueSignalrEvent(EventMessage message);
}
public class Events : IEvents
{
private readonly IQueue _queue;
private readonly ILogTracer _logger;
private readonly IWebhookOperations _webhook;
public Events(IQueue queue, ILogTracer logger, IWebhookOperations webhook)
{
_queue = queue;
_logger = logger;
_webhook = webhook;
}
public async Task QueueSignalrEvent(EventMessage eventMessage)
{
var message = new SignalREvent("events", new List<EventMessage>() { eventMessage });
var encodedMessage = Encoding.UTF8.GetBytes(System.Text.Json.JsonSerializer.Serialize(message)) ;
await _queue.SendMessage("signalr-events", encodedMessage, StorageType.Config);
}
public async Task SendEvent(BaseEvent anEvent) {
var eventType = anEvent.GetEventType();
var eventMessage = new EventMessage(
Guid.NewGuid(),
eventType,
anEvent,
Guid.NewGuid(), // todo
"test" //todo
);
await QueueSignalrEvent(eventMessage);
await _webhook.SendEvent(eventMessage);
LogEvent(anEvent, eventType);
}
public void LogEvent(BaseEvent anEvent, EventType eventType)
{
//todo
//var scrubedEvent = FilterEvent(anEvent);
//throw new NotImplementedException();
}
private object FilterEvent(BaseEvent anEvent)
{
throw new NotImplementedException();
}
}
}

View File

@ -0,0 +1,32 @@
using ApiService.OneFuzzLib.Orm;
using Azure.Data.Tables;
using Microsoft.OneFuzz.Service.OneFuzzLib.Orm;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;
namespace Microsoft.OneFuzz.Service;
public interface INodeOperations : IOrm<Node>
{
Task<Node?> GetByMachineId(Guid machineId);
}
public class NodeOperations : Orm<Node>, INodeOperations
{
public NodeOperations(IStorage storage)
:base(storage)
{
}
public async Task<Node?> GetByMachineId(Guid machineId)
{
var data = QueryAsync(filter: $"RowKey eq '{machineId}'");
return await data.FirstOrDefaultAsync();
}
}

View File

@ -1,17 +0,0 @@
using Microsoft.OneFuzz.Service.OneFuzzLib.Orm;
using System;
using System.Linq;
using System.Threading.Tasks;
namespace Microsoft.OneFuzz.Service;
public partial record Node
{
public async static Task<Node?> GetByMachineId(IStorageProvider storageProvider, Guid machineId) {
var tableClient = await storageProvider.GetTableClient("Node");
var data = storageProvider.QueryAsync<Node>(filter: $"RowKey eq '{machineId}'");
return await data.FirstOrDefaultAsync();
}
}

View File

@ -0,0 +1,81 @@
using Azure.Storage;
using Azure.Storage.Queues;
using Microsoft.OneFuzz.Service.OneFuzzLib.Orm;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Text.Json;
using System.Threading.Tasks;
namespace Microsoft.OneFuzz.Service;
public interface IQueue
{
Task SendMessage(string name, byte[] message, StorageType storageType, TimeSpan? visibilityTimeout = null, TimeSpan? timeToLive = null);
Task<bool> QueueObject<T>(string name, T obj, StorageType storageType, TimeSpan? visibilityTimeout);
}
public class Queue : IQueue
{
IStorage _storage;
ILog _logger;
public Queue(IStorage storage, ILog logger)
{
_storage = storage;
_logger = logger;
}
public async Task SendMessage(string name, byte[] message, StorageType storageType, TimeSpan? visibilityTimeout=null, TimeSpan? timeToLive=null ) {
var queue = GetQueue(name, storageType);
if (queue != null) {
try
{
await queue.SendMessageAsync(Convert.ToBase64String(message), visibilityTimeout: visibilityTimeout, timeToLive: timeToLive);
}
catch (Exception) {
}
}
}
public QueueClient? GetQueue(string name, StorageType storageType ) {
var client = GetQueueClient(storageType);
try
{
return client.GetQueueClient(name);
}
catch (Exception) {
return null;
}
}
public QueueServiceClient GetQueueClient(StorageType storageType)
{
var accountId = _storage.GetPrimaryAccount(storageType);
//_logger.LogDEbug("getting blob container (account_id: %s)", account_id)
(var name, var key) = _storage.GetStorageAccountNameAndKey(accountId);
var accountUrl = new Uri($"https://%s.queue.core.windows.net{name}");
var client = new QueueServiceClient(accountUrl, new StorageSharedKeyCredential(name, key));
return client;
}
public async Task<bool> QueueObject<T>(string name, T obj, StorageType storageType, TimeSpan? visibilityTimeout)
{
var queue = GetQueue(name, storageType) ?? throw new Exception($"unable to queue object, no such queue: {name}");
var serialized = JsonSerializer.Serialize(obj, EntityConverter.GetJsonSerializerOptions()) ;
//var encoded = Encoding.UTF8.GetBytes(serialized);
try
{
await queue.SendMessageAsync(serialized, visibilityTimeout: visibilityTimeout);
return true;
} catch (Exception) {
return false;
}
}
}

View File

@ -5,13 +5,23 @@ using Azure.ResourceManager.Storage;
using Azure.Core;
using Microsoft.Extensions.Logging;
using System.Text.Json;
using System.Linq;
using System.Threading.Tasks;
using Azure.Data.Tables;
namespace Microsoft.OneFuzz.Service;
public enum StorageType {
Corpus,
Config
}
public interface IStorage {
public ArmClient GetMgmtClient();
public IEnumerable<string> CorpusAccounts();
string GetPrimaryAccount(StorageType storageType);
public (string?, string?) GetStorageAccountNameAndKey(string accountId);
}
public class Storage : IStorage {
@ -76,4 +86,24 @@ public class Storage : IStorage {
_logger.LogInformation($"corpus accounts: {JsonSerializer.Serialize(results)}");
return results;
}
public string GetPrimaryAccount(StorageType storageType)
{
return
storageType switch
{
StorageType.Corpus => GetFuzzStorage(),
StorageType.Config => GetFuncStorage(),
_ => throw new NotImplementedException(),
};
}
public (string?, string?) GetStorageAccountNameAndKey(string accountId)
{
var resourceId = new ResourceIdentifier(accountId);
var armClient = GetMgmtClient();
var storageAccount = armClient.GetStorageAccount(resourceId);
var key = storageAccount.GetKeys().Value.Keys.FirstOrDefault();
return (resourceId.Name, key?.Value);
}
}

View File

@ -0,0 +1,110 @@
using ApiService.OneFuzzLib.Orm;
using Microsoft.OneFuzz.Service;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
namespace ApiService.OneFuzzLib;
public interface IWebhookMessageLogOperations: IOrm<WebhookMessageLog>
{
}
public class WebhookMessageLogOperations : Orm<WebhookMessageLog>, IWebhookMessageLogOperations
{
record WebhookMessageQueueObj (
Guid WebhookId,
Guid EventId
);
private readonly IQueue _queue;
private readonly ILogTracer _log;
public WebhookMessageLogOperations(IStorage storage, IQueue queue, ILogTracer log) : base(storage)
{
_queue = queue;
_log = log;
}
public async Task QueueWebhook(WebhookMessageLog webhookLog) {
var obj = new WebhookMessageQueueObj(webhookLog.WebhookId, webhookLog.EventId);
TimeSpan? visibilityTimeout = webhookLog.State switch
{
WebhookMessageState.Queued => TimeSpan.Zero,
WebhookMessageState.Retrying => TimeSpan.FromSeconds(30),
_ => null
};
if (visibilityTimeout == null)
{
_log.Error($"invalid WebhookMessage queue state, not queuing. {webhookLog.WebhookId}:{webhookLog.EventId} - {webhookLog.State}");
}
else
{
await _queue.QueueObject("webhooks", obj, StorageType.Config, visibilityTimeout: visibilityTimeout);
}
}
private void QueueObject(string v, WebhookMessageQueueObj obj, StorageType config, int? visibility_timeout)
{
throw new NotImplementedException();
}
}
public interface IWebhookOperations
{
Task SendEvent(EventMessage eventMessage);
}
public class WebhookOperations: Orm<Webhook>, IWebhookOperations
{
private readonly IWebhookMessageLogOperations _webhookMessageLogOperations;
public WebhookOperations(IStorage storage, IWebhookMessageLogOperations webhookMessageLogOperations)
:base(storage)
{
_webhookMessageLogOperations = webhookMessageLogOperations;
}
async public Task SendEvent(EventMessage eventMessage)
{
await foreach (var webhook in GetWebhooksCached())
{
if (!webhook.EventTypes.Contains(eventMessage.EventType))
{
continue;
}
await AddEvent(webhook, eventMessage);
}
}
async private Task AddEvent(Webhook webhook, EventMessage eventMessage)
{
var message = new WebhookMessageLog(
EventId: eventMessage.EventId,
EventType: eventMessage.EventType,
Event: eventMessage.Event,
InstanceId: eventMessage.InstanceId,
InstanceName: eventMessage.InstanceName,
WebhookId: webhook.WebhookId
);
await _webhookMessageLogOperations.Replace(message);
}
//todo: caching
public IAsyncEnumerable<Webhook> GetWebhooksCached()
{
return QueryAsync();
}
}

View File

@ -0,0 +1,261 @@
using Azure.Data.Tables;
using System;
using System.Reflection;
using System.Linq;
using System.Linq.Expressions;
using System.Text.Json;
using System.Text.Json.Serialization;
using System.Collections.Concurrent;
using Azure;
namespace Microsoft.OneFuzz.Service.OneFuzzLib.Orm;
public abstract record EntityBase
{
public ETag? ETag { get; set; }
public DateTimeOffset? TimeStamp { get; set; }
//public ApiService.OneFuzzLib.Orm.IOrm<EntityBase>? Orm { get; set; }
}
/// Indicates that the enum cases should no be renamed
[AttributeUsage(AttributeTargets.Enum)]
public class SkipRename : Attribute { }
public class RowKeyAttribute : Attribute { }
public class PartitionKeyAttribute : Attribute { }
public enum EntityPropertyKind
{
PartitionKey,
RowKey,
Column
}
public record EntityProperty(string name, string columnName, Type type, EntityPropertyKind kind);
public record EntityInfo(Type type, EntityProperty[] properties, Func<object?[], object> constructor);
class OnefuzzNamingPolicy : JsonNamingPolicy
{
public override string ConvertName(string name)
{
return CaseConverter.PascalToSnake(name);
}
}
public class EntityConverter
{
private readonly JsonSerializerOptions _options;
private readonly ConcurrentDictionary<Type, EntityInfo> _cache;
public EntityConverter()
{
_options = GetJsonSerializerOptions();
_cache = new ConcurrentDictionary<Type, EntityInfo>();
}
public static JsonSerializerOptions GetJsonSerializerOptions() {
var options = new JsonSerializerOptions()
{
PropertyNamingPolicy = new OnefuzzNamingPolicy(),
};
options.Converters.Add(new CustomEnumConverterFactory());
return options;
}
internal Func<object?[], object> BuildConstructerFrom(ConstructorInfo constructorInfo)
{
var constructorParameters = Expression.Parameter(typeof(object?[]));
var parameterExpressions =
constructorInfo.GetParameters().Select((parameterInfo, i) =>
{
var ithIndex = Expression.Constant(i);
var ithParameter = Expression.ArrayIndex(constructorParameters, ithIndex);
var unboxedIthParameter = Expression.Convert(ithParameter, parameterInfo.ParameterType);
return unboxedIthParameter;
}).ToArray();
NewExpression constructorCall = Expression.New(constructorInfo, parameterExpressions);
Func<object?[], object> ctor = Expression.Lambda<Func<object?[], object>>(constructorCall, constructorParameters).Compile();
return ctor;
}
private EntityInfo GetEntityInfo<T>()
{
return _cache.GetOrAdd(typeof(T), type =>
{
var constructor = type.GetConstructors()[0];
var parameterInfos = constructor.GetParameters();
var parameters =
parameterInfos.Select(f =>
{
var name = f.Name.EnsureNotNull($"Invalid paramter {f}");
var parameterType = f.ParameterType.EnsureNotNull($"Invalid paramter {f}");
var isRowkey = f.GetCustomAttribute(typeof(RowKeyAttribute)) != null;
var isPartitionkey = f.GetCustomAttribute(typeof(PartitionKeyAttribute)) != null;
var (columnName, kind) =
isRowkey
? ("RowKey", EntityPropertyKind.RowKey)
: isPartitionkey
? ("PartitionKey", EntityPropertyKind.PartitionKey)
: (// JsonPropertyNameAttribute can only be applied to properties
typeof(T).GetProperty(name)?.GetCustomAttribute<JsonPropertyNameAttribute>()?.Name
?? CaseConverter.PascalToSnake(name),
EntityPropertyKind.Column
);
return new EntityProperty(name, columnName, parameterType, kind);
}).ToArray();
return new EntityInfo(typeof(T), parameters, BuildConstructerFrom(constructor));
});
}
public TableEntity ToTableEntity<T>(T typedEntity) where T: EntityBase
{
if (typedEntity == null)
{
throw new NullReferenceException();
}
var type = typeof(T)!;
if (type is null)
{
throw new NullReferenceException();
}
var tableEntity = new TableEntity();
var entityInfo = GetEntityInfo<T>();
foreach (var prop in entityInfo.properties)
{
var value = entityInfo.type.GetProperty(prop.name)?.GetValue(typedEntity);
if (prop.type == typeof(Guid) || prop.type == typeof(Guid?))
{
tableEntity.Add(prop.columnName, value?.ToString());
}
else if (prop.type == typeof(bool)
|| prop.type == typeof(bool?)
|| prop.type == typeof(string)
|| prop.type == typeof(DateTime)
|| prop.type == typeof(DateTime?)
|| prop.type == typeof(DateTimeOffset)
|| prop.type == typeof(DateTimeOffset?)
|| prop.type == typeof(int)
|| prop.type == typeof(int?)
|| prop.type == typeof(Int64)
|| prop.type == typeof(Int64?)
|| prop.type == typeof(double)
|| prop.type == typeof(double?)
)
{
tableEntity.Add(prop.columnName, value);
}
else if (prop.type.IsEnum)
{
var values =
(value?.ToString()?.Split(',', StringSplitOptions.RemoveEmptyEntries | StringSplitOptions.TrimEntries)
.Select(CaseConverter.PascalToSnake)).EnsureNotNull($"Unable to read enum data {value}");
tableEntity.Add(prop.columnName, string.Join(",", values));
}
else
{
var serialized = JsonSerializer.Serialize(value, _options);
tableEntity.Add(prop.columnName, serialized);
}
}
if (typedEntity.ETag.HasValue) {
tableEntity.ETag = typedEntity.ETag.Value;
}
return tableEntity;
}
public T ToRecord<T>(TableEntity entity) where T: EntityBase
{
var entityInfo = GetEntityInfo<T>();
var parameters =
entityInfo.properties.Select(ef =>
{
if (ef.kind == EntityPropertyKind.PartitionKey || ef.kind == EntityPropertyKind.RowKey)
{
if (ef.type == typeof(string))
return entity.GetString(ef.kind.ToString());
else if (ef.type == typeof(Guid))
return Guid.Parse(entity.GetString(ef.kind.ToString()));
else
{
throw new Exception("invalid ");
}
}
var fieldName = ef.columnName;
if (ef.type == typeof(string))
{
return entity.GetString(fieldName);
}
else if (ef.type == typeof(bool))
{
return entity.GetBoolean(fieldName);
}
else if (ef.type == typeof(DateTimeOffset) || ef.type == typeof(DateTimeOffset?))
{
return entity.GetDateTimeOffset(fieldName);
}
else if (ef.type == typeof(DateTime))
{
return entity.GetDateTime(fieldName);
}
else if (ef.type == typeof(double))
{
return entity.GetDouble(fieldName);
}
else if (ef.type == typeof(Guid) || ef.type == typeof(Guid?))
{
return (object?)Guid.Parse(entity.GetString(fieldName));
}
else if (ef.type == typeof(int))
{
return entity.GetInt32(fieldName);
}
else if (ef.type == typeof(Int64))
{
return entity.GetInt64(fieldName);
}
else if (ef.type.IsEnum)
{
var stringValues =
entity.GetString(fieldName).Split(",", StringSplitOptions.RemoveEmptyEntries | StringSplitOptions.TrimEntries)
.Select(CaseConverter.SnakeToPascal);
return Enum.Parse(ef.type, string.Join(",", stringValues));
}
else
{
var value = entity.GetString(fieldName);
return JsonSerializer.Deserialize(value, ef.type, options: _options); ;
}
}
).ToArray();
var entityRecord = (T)entityInfo.constructor.Invoke(parameters);
entityRecord.ETag = entity.ETag;
entityRecord.TimeStamp = entity.Timestamp;
return entityRecord;
}
}

View File

@ -1,259 +1,60 @@
using Azure.Core;
using Azure.Data.Tables;
using Microsoft.OneFuzz.Service;
using Microsoft.OneFuzz.Service.OneFuzzLib.Orm;
using System;
using System.Reflection;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;
using System.Text.Json;
using System.Text.Json.Serialization;
using System.Collections.Concurrent;
using Azure;
using System.Text;
using System.Threading.Tasks;
namespace Microsoft.OneFuzz.Service.OneFuzzLib.Orm;
public abstract record EntityBase
namespace ApiService.OneFuzzLib.Orm
{
public ETag? ETag { get; set; }
public DateTimeOffset? TimeStamp { get; set; }
}
/// Indicates that the enum cases should no be renamed
[AttributeUsage(AttributeTargets.Enum)]
public class SkipRename : Attribute { }
public class RowKeyAttribute : Attribute { }
public class PartitionKeyAttribute : Attribute { }
public enum EntityPropertyKind
{
PartitionKey,
RowKey,
Column
}
public record EntityProperty(string name, string columnName, Type type, EntityPropertyKind kind);
public record EntityInfo(Type type, EntityProperty[] properties, Func<object?[], object> constructor);
class OnefuzzNamingPolicy : JsonNamingPolicy
{
public override string ConvertName(string name)
public interface IOrm<T> where T : EntityBase
{
return CaseConverter.PascalToSnake(name);
}
}
public class EntityConverter
{
private readonly JsonSerializerOptions _options;
private readonly ConcurrentDictionary<Type, EntityInfo> _cache;
public EntityConverter()
{
_options = GetJsonSerializerOptions();
_cache = new ConcurrentDictionary<Type, EntityInfo>();
Task<TableClient> GetTableClient(string table, string? accountId = null);
IAsyncEnumerable<T> QueryAsync(string filter);
Task<bool> Replace(T entity);
}
public class Orm<T> : IOrm<T> where T : EntityBase
{
IStorage _storage;
EntityConverter _entityConverter;
public static JsonSerializerOptions GetJsonSerializerOptions() {
var options = new JsonSerializerOptions()
public Orm(IStorage storage)
{
PropertyNamingPolicy = new OnefuzzNamingPolicy(),
};
options.Converters.Add(new CustomEnumConverterFactory());
return options;
}
internal Func<object?[], object> BuildConstructerFrom(ConstructorInfo constructorInfo)
{
var constructorParameters = Expression.Parameter(typeof(object?[]));
var parameterExpressions =
constructorInfo.GetParameters().Select((parameterInfo, i) =>
{
var ithIndex = Expression.Constant(i);
var ithParameter = Expression.ArrayIndex(constructorParameters, ithIndex);
var unboxedIthParameter = Expression.Convert(ithParameter, parameterInfo.ParameterType);
return unboxedIthParameter;
}).ToArray();
NewExpression constructorCall = Expression.New(constructorInfo, parameterExpressions);
Func<object?[], object> ctor = Expression.Lambda<Func<object?[], object>>(constructorCall, constructorParameters).Compile();
return ctor;
}
private EntityInfo GetEntityInfo<T>()
{
return _cache.GetOrAdd(typeof(T), type =>
{
var constructor = type.GetConstructors()[0];
var parameterInfos = constructor.GetParameters();
var parameters =
parameterInfos.Select(f =>
{
var name = f.Name.EnsureNotNull($"Invalid paramter {f}");
var parameterType = f.ParameterType.EnsureNotNull($"Invalid paramter {f}");
var isRowkey = f.GetCustomAttribute(typeof(RowKeyAttribute)) != null;
var isPartitionkey = f.GetCustomAttribute(typeof(PartitionKeyAttribute)) != null;
var (columnName, kind) =
isRowkey
? ("RowKey", EntityPropertyKind.RowKey)
: isPartitionkey
? ("PartitionKey", EntityPropertyKind.PartitionKey)
: (// JsonPropertyNameAttribute can only be applied to properties
typeof(T).GetProperty(name)?.GetCustomAttribute<JsonPropertyNameAttribute>()?.Name
?? CaseConverter.PascalToSnake(name),
EntityPropertyKind.Column
);
return new EntityProperty(name, columnName, parameterType, kind);
}).ToArray();
return new EntityInfo(typeof(T), parameters, BuildConstructerFrom(constructor));
});
}
public TableEntity ToTableEntity<T>(T typedEntity) where T: EntityBase
{
if (typedEntity == null)
{
throw new NullReferenceException();
_storage = storage;
_entityConverter = new EntityConverter();
}
var type = typeof(T)!;
if (type is null)
public async IAsyncEnumerable<T> QueryAsync(string? filter=null)
{
throw new NullReferenceException();
var tableClient = await GetTableClient(typeof(T).Name);
await foreach (var x in tableClient.QueryAsync<TableEntity>(filter).Select(x => _entityConverter.ToRecord<T>(x)))
{
yield return x;
}
}
var tableEntity = new TableEntity();
var entityInfo = GetEntityInfo<T>();
foreach (var prop in entityInfo.properties)
public async Task<bool> Replace(T entity)
{
var value = entityInfo.type.GetProperty(prop.name)?.GetValue(typedEntity);
if (prop.type == typeof(Guid) || prop.type == typeof(Guid?))
{
tableEntity.Add(prop.columnName, value?.ToString());
}
else if (prop.type == typeof(bool)
|| prop.type == typeof(bool?)
|| prop.type == typeof(string)
|| prop.type == typeof(DateTime)
|| prop.type == typeof(DateTime?)
|| prop.type == typeof(DateTimeOffset)
|| prop.type == typeof(DateTimeOffset?)
|| prop.type == typeof(int)
|| prop.type == typeof(int?)
|| prop.type == typeof(Int64)
|| prop.type == typeof(Int64?)
|| prop.type == typeof(double)
|| prop.type == typeof(double?)
)
{
tableEntity.Add(prop.columnName, value);
}
else if (prop.type.IsEnum)
{
var values =
(value?.ToString()?.Split(',', StringSplitOptions.RemoveEmptyEntries | StringSplitOptions.TrimEntries)
.Select(CaseConverter.PascalToSnake)).EnsureNotNull($"Unable to read enum data {value}");
tableEntity.Add(prop.columnName, string.Join(",", values));
}
else
{
var serialized = JsonSerializer.Serialize(value, _options);
tableEntity.Add(prop.columnName, serialized);
}
var tableClient = await GetTableClient(typeof(T).Name);
var tableEntity = _entityConverter.ToTableEntity(entity);
var response = await tableClient.UpsertEntityAsync(tableEntity);
return !response.IsError;
}
if (typedEntity.ETag.HasValue) {
tableEntity.ETag = typedEntity.ETag.Value;
public async Task<TableClient> GetTableClient(string table, string? accountId = null)
{
var account = accountId ?? EnvironmentVariables.OneFuzz.FuncStorage ?? throw new ArgumentNullException(nameof(accountId));
var (name, key) = _storage.GetStorageAccountNameAndKey(account);
var identifier = new ResourceIdentifier(account);
var tableClient = new TableServiceClient(new Uri($"https://{identifier.Name}.table.core.windows.net"), new TableSharedKeyCredential(name, key));
await tableClient.CreateTableIfNotExistsAsync(table);
return tableClient.GetTableClient(table);
}
return tableEntity;
}
public T ToRecord<T>(TableEntity entity) where T: EntityBase
{
var entityInfo = GetEntityInfo<T>();
var parameters =
entityInfo.properties.Select(ef =>
{
if (ef.kind == EntityPropertyKind.PartitionKey || ef.kind == EntityPropertyKind.RowKey)
{
if (ef.type == typeof(string))
return entity.GetString(ef.kind.ToString());
else if (ef.type == typeof(Guid))
return Guid.Parse(entity.GetString(ef.kind.ToString()));
else
{
throw new Exception("invalid ");
}
}
var fieldName = ef.columnName;
if (ef.type == typeof(string))
{
return entity.GetString(fieldName);
}
else if (ef.type == typeof(bool))
{
return entity.GetBoolean(fieldName);
}
else if (ef.type == typeof(DateTimeOffset) || ef.type == typeof(DateTimeOffset?))
{
return entity.GetDateTimeOffset(fieldName);
}
else if (ef.type == typeof(DateTime))
{
return entity.GetDateTime(fieldName);
}
else if (ef.type == typeof(double))
{
return entity.GetDouble(fieldName);
}
else if (ef.type == typeof(Guid) || ef.type == typeof(Guid?))
{
return (object?)Guid.Parse(entity.GetString(fieldName));
}
else if (ef.type == typeof(int))
{
return entity.GetInt32(fieldName);
}
else if (ef.type == typeof(Int64))
{
return entity.GetInt64(fieldName);
}
else if (ef.type.IsEnum)
{
var stringValues =
entity.GetString(fieldName).Split(",", StringSplitOptions.RemoveEmptyEntries | StringSplitOptions.TrimEntries)
.Select(CaseConverter.SnakeToPascal);
return Enum.Parse(ef.type, string.Join(",", stringValues));
}
else
{
var value = entity.GetString(fieldName);
return JsonSerializer.Deserialize(value, ef.type, options: _options); ;
}
}
).ToArray();
var entityRecord = (T)entityInfo.constructor.Invoke(parameters);
entityRecord.ETag = entity.ETag;
entityRecord.TimeStamp = entity.Timestamp;
return entityRecord;
}
}

View File

@ -14,8 +14,8 @@ namespace Microsoft.OneFuzz.Service.OneFuzzLib.Orm;
public interface IStorageProvider
{
Task<TableClient> GetTableClient(string table);
IAsyncEnumerable<T> QueryAsync<T>(string filter) where T : EntityBase;
Task<bool> Replace<T>(T entity) where T : EntityBase;
//IAsyncEnumerable<T> QueryAsync<T>(string filter) where T : EntityBase;
//Task<bool> Replace<T>(T entity) where T : EntityBase;
}
@ -48,21 +48,5 @@ public class StorageProvider : IStorageProvider
return (resourceId.Name, key?.Value);
}
public async IAsyncEnumerable<T> QueryAsync<T>(string filter) where T : EntityBase
{
var tableClient = await GetTableClient(typeof(T).Name);
await foreach (var x in tableClient.QueryAsync<TableEntity>(filter).Select(x => _entityConverter.ToRecord<T>(x))) {
yield return x;
}
}
public async Task<bool> Replace<T>(T entity) where T : EntityBase
{
var tableClient = await GetTableClient(typeof(T).Name);
var tableEntity = _entityConverter.ToTableEntity(entity);
var response = await tableClient.UpsertEntityAsync(tableEntity);
return !response.IsError;
}
}