Implement the node C# function (#2072)

1. Ports the `node` function from Python to C#.
2. Adds a missing authentication check.
3. Add validated string type `PoolName` for consistency with Python version.
This commit is contained in:
George Pollard
2022-06-23 13:44:14 +12:00
committed by GitHub
parent 1eeefce85c
commit 4eec0bfc45
30 changed files with 870 additions and 109 deletions

View File

@ -0,0 +1,179 @@
using System.Threading.Tasks;
using Microsoft.Azure.Functions.Worker;
using Microsoft.Azure.Functions.Worker.Http;
using Microsoft.OneFuzz.Service.OneFuzzLib.Orm;
namespace Microsoft.OneFuzz.Service;
public class NodeFunction {
private readonly ILogTracer _log;
private readonly IEndpointAuthorization _auth;
private readonly IOnefuzzContext _context;
public NodeFunction(ILogTracer log, IEndpointAuthorization auth, IOnefuzzContext context) {
_log = log;
_auth = auth;
_context = context;
}
private static readonly EntityConverter _entityConverter = new();
// [Function("Node")
public Async.Task<HttpResponseData> Run([HttpTrigger("GET", "PATCH", "POST", "DELETE")] HttpRequestData req) {
return _auth.CallIfUser(req, r => r.Method switch {
"GET" => Get(r),
"PATCH" => Patch(r),
"POST" => Post(r),
"DELETE" => Delete(r),
_ => throw new InvalidOperationException("Unsupported HTTP method"),
});
}
private async Async.Task<HttpResponseData> Get(HttpRequestData req) {
var request = await RequestHandling.ParseRequest<NodeSearch>(req);
if (!request.IsOk) {
return await _context.RequestHandling.NotOk(req, request.ErrorV, "pool get");
}
var search = request.OkV;
if (search.MachineId is Guid machineId) {
var node = await _context.NodeOperations.GetByMachineId(machineId);
if (node is null) {
return await _context.RequestHandling.NotOk(
req,
new Error(
Code: ErrorCode.UNABLE_TO_FIND,
Errors: new string[] { "unable to find node " }),
context: machineId.ToString());
}
var (tasks, messages) = await (
_context.NodeTasksOperations.GetByMachineId(machineId).ToListAsync().AsTask(),
_context.NodeMessageOperations.GetMessage(machineId).ToListAsync().AsTask());
var commands = messages.Select(m => m.Message).ToList();
return await RequestHandling.Ok(req, NodeToNodeSearchResult(node with { Tasks = tasks, Messages = commands }));
}
var nodes = await _context.NodeOperations.SearchStates(
states: search.State,
poolName: search.PoolName,
scaleSetId: search.ScalesetId).ToListAsync();
return await RequestHandling.Ok(req, nodes.Select(NodeToNodeSearchResult));
}
private static NodeSearchResult NodeToNodeSearchResult(Node node) {
return new NodeSearchResult(
PoolId: node.PoolId,
PoolName: node.PoolName,
MachineId: node.MachineId,
Version: node.Version,
Heartbeat: node.Heartbeat,
InitializedAt: node.InitializedAt,
State: node.State,
ScalesetId: node.ScalesetId,
ReimageRequested: node.ReimageRequested,
DeleteRequested: node.DeleteRequested,
DebugKeepNode: node.DebugKeepNode);
}
private async Async.Task<HttpResponseData> Patch(HttpRequestData req) {
var request = await RequestHandling.ParseRequest<NodeGet>(req);
if (!request.IsOk) {
return await _context.RequestHandling.NotOk(
req,
request.ErrorV,
"NodeReimage");
}
var authCheck = await _auth.CheckRequireAdmins(req);
if (!authCheck.IsOk) {
return await _context.RequestHandling.NotOk(req, authCheck.ErrorV, "NodeReimage");
}
var patch = request.OkV;
var node = await _context.NodeOperations.GetByMachineId(patch.MachineId);
if (node is null) {
return await _context.RequestHandling.NotOk(
req,
new Error(
Code: ErrorCode.UNABLE_TO_FIND,
Errors: new string[] { "unable to find node " }),
context: patch.MachineId.ToString());
}
await _context.NodeOperations.Stop(node, done: true);
if (node.DebugKeepNode) {
await _context.NodeOperations.Replace(node with { DebugKeepNode = false });
}
return await RequestHandling.Ok(req, true);
}
private async Async.Task<HttpResponseData> Post(HttpRequestData req) {
var request = await RequestHandling.ParseRequest<NodeUpdate>(req);
if (!request.IsOk) {
return await _context.RequestHandling.NotOk(
req,
request.ErrorV,
"NodeUpdate");
}
var authCheck = await _auth.CheckRequireAdmins(req);
if (!authCheck.IsOk) {
return await _context.RequestHandling.NotOk(req, authCheck.ErrorV, "NodeUpdate");
}
var post = request.OkV;
var node = await _context.NodeOperations.GetByMachineId(post.MachineId);
if (node is null) {
return await _context.RequestHandling.NotOk(
req,
new Error(
Code: ErrorCode.UNABLE_TO_FIND,
Errors: new string[] { "unable to find node " }),
context: post.MachineId.ToString());
}
if (post.DebugKeepNode is bool value) {
node = node with { DebugKeepNode = value };
}
await _context.NodeOperations.Replace(node);
return await RequestHandling.Ok(req, true);
}
private async Async.Task<HttpResponseData> Delete(HttpRequestData req) {
var request = await RequestHandling.ParseRequest<NodeGet>(req);
if (!request.IsOk) {
return await _context.RequestHandling.NotOk(
req,
request.ErrorV,
context: "NodeDelete");
}
var authCheck = await _auth.CheckRequireAdmins(req);
if (!authCheck.IsOk) {
return await _context.RequestHandling.NotOk(req, authCheck.ErrorV, "NodeDelete");
}
var delete = request.OkV;
var node = await _context.NodeOperations.GetByMachineId(delete.MachineId);
if (node is null) {
return await _context.RequestHandling.NotOk(
req,
new Error(
Code: ErrorCode.UNABLE_TO_FIND,
new string[] { "unable to find node" }),
context: delete.MachineId.ToString());
}
await _context.NodeOperations.SetHalt(node);
if (node.DebugKeepNode) {
await _context.NodeOperations.Replace(node with { DebugKeepNode = false });
}
return await RequestHandling.Ok(req, true);
}
}

View File

@ -1,7 +1,6 @@
using System.Text.Json;
using System.Text.Json.Serialization;
using Microsoft.OneFuzz.Service.OneFuzzLib.Orm;
using PoolName = System.String;
using Region = System.String;
namespace Microsoft.OneFuzz.Service;

View File

@ -3,7 +3,6 @@ using System.Text.Json.Serialization;
using Microsoft.OneFuzz.Service.OneFuzzLib.Orm;
using Endpoint = System.String;
using GroupId = System.Guid;
using PoolName = System.String;
using PrincipalId = System.Guid;
using Region = System.String;
@ -152,7 +151,6 @@ public record Error(ErrorCode Code, string[]? Errors = null);
public record UserInfo(Guid? ApplicationId, Guid? ObjectId, String? Upn);
public record TaskDetails(
TaskType Type,
int Duration,
@ -316,11 +314,11 @@ public record InstanceConfig
NetworkSecurityGroupConfig ProxyNsgConfig,
AzureVmExtensionConfig? Extensions,
string ProxyVmSku,
IDictionary<Endpoint, ApiAccessRule>? ApiAccessRules,
IDictionary<PrincipalId, GroupId[]>? GroupMembership,
IDictionary<string, string>? VmTags,
IDictionary<string, string>? VmssTags
IDictionary<Endpoint, ApiAccessRule>? ApiAccessRules = null,
IDictionary<PrincipalId, GroupId[]>? GroupMembership = null,
IDictionary<string, string>? VmTags = null,
IDictionary<string, string>? VmssTags = null,
bool? RequireAdminPrivileges = null
) : EntityBase() {
public InstanceConfig(string instanceName) : this(
instanceName,
@ -330,12 +328,7 @@ public record InstanceConfig
new NetworkConfig(),
new NetworkSecurityGroupConfig(),
null,
"Standard_B2s",
null,
null,
null,
null) { }
"Standard_B2s") { }
public InstanceConfig() : this(String.Empty) { }
public List<Guid>? CheckAdmins(List<Guid>? value) {
@ -346,7 +339,6 @@ public record InstanceConfig
}
}
//# At the moment, this only checks allowed_aad_tenants, however adding
//# support for 3rd party JWT validation is anticipated in a future release.
public ResultVoid<List<string>> CheckInstanceConfig() {

View File

@ -18,6 +18,22 @@ public record NodeCommandDelete(
string MessageId
) : BaseRequest;
public record NodeGet(
Guid MachineId
) : BaseRequest;
public record NodeUpdate(
Guid MachineId,
bool? DebugKeepNode
) : BaseRequest;
public record NodeSearch(
Guid? MachineId = null,
List<NodeState>? State = null,
Guid? ScalesetId = null,
PoolName? PoolName = null
) : BaseRequest;
public record NodeStateEnvelope(
NodeEventBase Event,
Guid MachineId

View File

@ -4,7 +4,10 @@ using System.Text.Json.Serialization;
namespace Microsoft.OneFuzz.Service;
[JsonConverter(typeof(BaseResponseConverter))]
public abstract record BaseResponse();
public abstract record BaseResponse() {
public static implicit operator BaseResponse(bool value)
=> new BoolResult(value);
};
public record CanSchedule(
bool Allowed,
@ -15,6 +18,23 @@ public record PendingNodeCommand(
NodeCommandEnvelope? Envelope
) : BaseResponse();
// TODO: not sure how much of this is actually
// needed in the search results, so at the moment
// it is a copy of the whole Node type
public record NodeSearchResult(
PoolName PoolName,
Guid MachineId,
Guid? PoolId,
string Version,
DateTimeOffset? Heartbeat,
DateTimeOffset? InitializedAt,
NodeState State,
Guid? ScalesetId,
bool ReimageRequested,
bool DeleteRequested,
bool DebugKeepNode
) : BaseResponse();
public record BoolResult(
bool Result
) : BaseResponse();

View File

@ -54,6 +54,9 @@ namespace Microsoft.OneFuzz.Service {
public static OneFuzzResult<T_Ok> Error(ErrorCode errorCode, string error) => new(errorCode, new[] { error });
public static OneFuzzResult<T_Ok> Error(Error err) => new(err);
// Allow simple conversion of Errors to Results.
public static implicit operator OneFuzzResult<T_Ok>(Error err) => new(err);
}
public struct OneFuzzResultVoid {
@ -69,9 +72,12 @@ namespace Microsoft.OneFuzz.Service {
private OneFuzzResultVoid(Error err) => (ErrorV, IsOk) = (err, false);
public static OneFuzzResultVoid Ok() => new();
public static OneFuzzResultVoid Ok => new();
public static OneFuzzResultVoid Error(ErrorCode errorCode, string[] errors) => new(errorCode, errors);
public static OneFuzzResultVoid Error(ErrorCode errorCode, string error) => new(errorCode, new[] { error });
public static OneFuzzResultVoid Error(Error err) => new(err);
// Allow simple conversion of Errors to Results.
public static implicit operator OneFuzzResultVoid(Error err) => new(err);
}
}

View File

@ -0,0 +1,142 @@

using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
using System.Text.Json;
using System.Text.Json.Serialization;
using System.Text.RegularExpressions;
namespace Microsoft.OneFuzz.Service;
static class Check {
private static readonly Regex _isAlnum = new(@"\A[a-zA-Z0-9]+\z", RegexOptions.Compiled);
public static bool IsAlnum(string input) => _isAlnum.IsMatch(input);
private static readonly Regex _isAlnumDash = new(@"\A[a-zA-Z0-9\-]+\z", RegexOptions.Compiled);
public static bool IsAlnumDash(string input) => _isAlnumDash.IsMatch(input);
}
// Base class for types that are wrappers around a validated string.
public abstract record ValidatedString(string String) {
public sealed override string ToString() => String;
}
// JSON converter for types that are wrappers around a validated string.
public abstract class ValidatedStringConverter<T> : JsonConverter<T> where T : ValidatedString {
protected abstract bool TryParse(string input, out T? output);
public sealed override bool CanConvert(Type typeToConvert)
=> typeToConvert == typeof(T);
public sealed override T? Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) {
if (reader.TokenType != JsonTokenType.String) {
throw new JsonException("expected a string");
}
var value = reader.GetString();
if (value is null) {
throw new JsonException("expected a string");
}
if (TryParse(value, out var result)) {
return result;
} else {
throw new JsonException($"unable to parse input as a {typeof(T).Name}");
}
}
public sealed override void Write(Utf8JsonWriter writer, T value, JsonSerializerOptions options)
=> writer.WriteStringValue(value.String);
}
[JsonConverter(typeof(Converter))]
public record PoolName : ValidatedString {
private PoolName(string value) : base(value) {
Debug.Assert(Check.IsAlnumDash(value));
}
public static PoolName Parse(string input) {
if (TryParse(input, out var result)) {
return result;
}
throw new ArgumentException("Pool name must have only numbers, letters or dashes");
}
public static bool TryParse(string input, [NotNullWhen(returnValue: true)] out PoolName? result) {
if (!Check.IsAlnumDash(input)) {
result = default;
return false;
}
result = new PoolName(input);
return true;
}
public sealed class Converter : ValidatedStringConverter<PoolName> {
protected override bool TryParse(string input, out PoolName? output)
=> PoolName.TryParse(input, out output);
}
}
/* TODO: to be enabled in a separate PR
[JsonConverter(typeof(Converter))]
public record Region : ValidatedString {
private Region(string value) : base(value) {
Debug.Assert(Check.IsAlnum(value));
}
public static Region Parse(string input) {
if (TryParse(input, out var result)) {
return result;
}
throw new ArgumentException("Region name must have only numbers, letters or dashes");
}
public static bool TryParse(string input, [NotNullWhen(returnValue: true)] out Region? result) {
if (!Check.IsAlnum(input)) {
result = default;
return false;
}
result = new Region(input);
return true;
}
public sealed class Converter : ValidatedStringConverter<Region> {
protected override bool TryParse(string input, out Region? output)
=> Region.TryParse(input, out output);
}
}
[JsonConverter(typeof(Converter))]
public record Container : ValidatedString {
private Container(string value) : base(value) {
Debug.Assert(Check.IsAlnumDash(value));
}
public static Container Parse(string input) {
if (TryParse(input, out var result)) {
return result;
}
throw new ArgumentException("Container name must have only numbers, letters or dashes");
}
public static bool TryParse(string input, [NotNullWhen(returnValue: true)] out Container? result) {
if (!Check.IsAlnumDash(input)) {
result = default;
return false;
}
result = new Container(input);
return true;
}
public sealed class Converter : ValidatedStringConverter<Container> {
protected override bool TryParse(string input, out Container? output)
=> Container.TryParse(input, out output);
}
}
*/

View File

@ -165,7 +165,9 @@ namespace ApiService.TestHooks {
if (query.ContainsKey("states")) {
states = query["states"].Split('-').Select(s => Enum.Parse<NodeState>(s)).ToList();
}
string? poolName = UriExtension.GetString("poolName", query);
string? poolNameString = UriExtension.GetString("poolName", query);
PoolName? poolName = poolNameString is null ? null : PoolName.Parse(poolNameString);
var excludeUpdateScheduled = UriExtension.GetBool("excludeUpdateScheduled", query, false);
int? numResults = UriExtension.GetInt("numResults", query);
@ -209,7 +211,7 @@ namespace ApiService.TestHooks {
var query = UriExtension.GetQueryComponents(req.Url);
Guid poolId = Guid.Parse(query["poolId"]);
string poolName = query["poolName"];
var poolName = PoolName.Parse(query["poolName"]);
Guid machineId = Guid.Parse(query["machineId"]);
Guid? scaleSetId = default;

View File

@ -25,7 +25,7 @@ namespace ApiService.TestHooks {
_log.Info("get pool");
var query = UriExtension.GetQueryComponents(req.Url);
var poolRes = await _poolOps.GetByName(query["name"]);
var poolRes = await _poolOps.GetByName(PoolName.Parse(query["name"]));
if (poolRes.IsOk) {
var resp = req.CreateResponse(HttpStatusCode.OK);

View File

@ -58,7 +58,7 @@ public class UserCredentials : IUserCredentials {
return OneFuzzResult<string[]>.Ok(allowedAddTenantsQuery.ToArray());
}
public async Task<OneFuzzResult<UserInfo>> ParseJwtToken(HttpRequestData req) {
public virtual async Task<OneFuzzResult<UserInfo>> ParseJwtToken(HttpRequestData req) {
var authToken = GetAuthToken(req);
if (authToken is null) {
return OneFuzzResult<UserInfo>.Error(ErrorCode.INVALID_REQUEST, new[] { "unable to find authorization token" });

View File

@ -19,6 +19,8 @@ public interface IEndpointAuthorization {
Func<HttpRequestData, Async.Task<HttpResponseData>> method,
bool allowUser = false,
bool allowAgent = false);
Async.Task<OneFuzzResultVoid> CheckRequireAdmins(HttpRequestData req);
}
public class EndpointAuthorization : IEndpointAuthorization {
@ -30,7 +32,7 @@ public class EndpointAuthorization : IEndpointAuthorization {
_log = log;
}
public async Async.Task<HttpResponseData> CallIf(HttpRequestData req, Func<HttpRequestData, Async.Task<HttpResponseData>> method, bool allowUser = false, bool allowAgent = false) {
public virtual async Async.Task<HttpResponseData> CallIf(HttpRequestData req, Func<HttpRequestData, Async.Task<HttpResponseData>> method, bool allowUser = false, bool allowAgent = false) {
var tokenResult = await _context.UserCredentials.ParseJwtToken(req);
if (!tokenResult.IsOk) {
@ -77,6 +79,59 @@ public class EndpointAuthorization : IEndpointAuthorization {
);
}
public async Async.Task<OneFuzzResultVoid> CheckRequireAdmins(HttpRequestData req) {
var tokenResult = await _context.UserCredentials.ParseJwtToken(req);
if (!tokenResult.IsOk) {
return tokenResult.ErrorV;
}
var config = await _context.ConfigOperations.Fetch();
if (config is null) {
return new Error(
Code: ErrorCode.INVALID_CONFIGURATION,
Errors: new string[] { "no instance configuration found " });
}
return CheckRequireAdminsImpl(config, tokenResult.OkV);
}
private static OneFuzzResultVoid CheckRequireAdminsImpl(InstanceConfig config, UserInfo userInfo) {
// When there are no admins in the `admins` list, all users are considered
// admins. However, `require_admin_privileges` is still useful to protect from
// mistakes.
//
// To make changes while still protecting against accidental changes to
// pools, do the following:
//
// 1. set `require_admin_privileges` to `False`
// 2. make the change
// 3. set `require_admin_privileges` to `True`
if (config.RequireAdminPrivileges == false) {
return OneFuzzResultVoid.Ok;
}
if (config.Admins is null) {
return new Error(
Code: ErrorCode.UNAUTHORIZED,
Errors: new string[] { "pool modification disabled " });
}
if (userInfo.ObjectId is Guid objectId) {
if (config.Admins.Contains(objectId)) {
return OneFuzzResultVoid.Ok;
}
return new Error(
Code: ErrorCode.UNAUTHORIZED,
Errors: new string[] { "not authorized to manage pools" });
} else {
return new Error(
Code: ErrorCode.UNAUTHORIZED,
Errors: new string[] { "user had no Object ID" });
}
}
public OneFuzzResultVoid CheckAccess(HttpRequestData req) {
throw new NotImplementedException();
}

View File

@ -22,7 +22,7 @@ public interface INodeOperations : IStatefulOrm<Node, NodeState> {
IAsyncEnumerable<Node> SearchStates(Guid? poolId = default,
Guid? scaleSetId = default,
IEnumerable<NodeState>? states = default,
string? poolName = default,
PoolName? poolName = default,
bool excludeUpdateScheduled = false,
int? numResults = default);
@ -32,7 +32,7 @@ public interface INodeOperations : IStatefulOrm<Node, NodeState> {
Async.Task<Node> Create(
Guid poolId,
string poolName,
PoolName poolName,
Guid machineId,
Guid? scaleSetId,
string version,
@ -67,7 +67,7 @@ public class NodeOperations : StatefulOrm<Node, NodeState>, INodeOperations {
_logTracer.Info($"Setting scale-in protection on node {node.MachineId}");
return await _context.VmssOperations.UpdateScaleInProtection((Guid)node.ScalesetId, node.MachineId, protectFromScaleIn: true);
}
return OneFuzzResultVoid.Ok();
return OneFuzzResultVoid.Ok;
}
public async Async.Task<bool> ScalesetNodeExists(Node node) {
@ -207,7 +207,7 @@ public class NodeOperations : StatefulOrm<Node, NodeState>, INodeOperations {
public async Async.Task<Node> Create(
Guid poolId,
string poolName,
PoolName poolName,
Guid machineId,
Guid? scaleSetId,
string version,
@ -308,7 +308,7 @@ public class NodeOperations : StatefulOrm<Node, NodeState>, INodeOperations {
Guid? poolId = default,
Guid? scaleSetId = default,
IEnumerable<NodeState>? states = default,
string? poolName = default,
PoolName? poolName = default,
bool excludeUpdateScheduled = false,
int? numResults = default) {
@ -318,6 +318,10 @@ public class NodeOperations : StatefulOrm<Node, NodeState>, INodeOperations {
queryParts.Add($"(pool_id eq '{poolId}')");
}
if (poolName is not null) {
queryParts.Add($"(PartitionKey eq '{poolName}')");
}
if (scaleSetId is not null) {
queryParts.Add($"(scaleset_id eq '{scaleSetId}')");
}
@ -346,7 +350,7 @@ public class NodeOperations : StatefulOrm<Node, NodeState>, INodeOperations {
Guid? poolId = default,
Guid? scaleSetId = default,
IEnumerable<NodeState>? states = default,
string? poolName = default,
PoolName? poolName = default,
bool excludeUpdateScheduled = false,
int? numResults = default) {
var query = NodeOperations.SearchStatesQuery(_context.ServiceConfiguration.OneFuzzVersion, poolId, scaleSetId, states, poolName, excludeUpdateScheduled, numResults);
@ -467,9 +471,8 @@ public class NodeMessageOperations : Orm<NodeMessage>, INodeMessageOperations {
_log = log;
}
public IAsyncEnumerable<NodeMessage> GetMessage(Guid machineId) {
return QueryAsync($"PartitionKey eq '{machineId}'");
}
public IAsyncEnumerable<NodeMessage> GetMessage(Guid machineId)
=> QueryAsync(Query.PartitionKey(machineId));
public async Async.Task ClearMessages(Guid machineId) {
_logTracer.Info($"clearing messages for node {machineId}");

View File

@ -48,7 +48,7 @@ namespace Microsoft.OneFuzz.Service {
public async Async.Task<OneFuzzResultVoid> DissociateNic(Nsg nsg, NetworkInterfaceResource nic) {
if (nic.Data.NetworkSecurityGroup == null) {
return OneFuzzResultVoid.Ok();
return OneFuzzResultVoid.Ok;
}
var azureNsg = await GetNsg(nsg.Name);
@ -83,7 +83,7 @@ namespace Microsoft.OneFuzz.Service {
err,
)
*/
return OneFuzzResultVoid.Ok();
return OneFuzzResultVoid.Ok;
}
return OneFuzzResultVoid.Error(
ErrorCode.UNABLE_TO_UPDATE,
@ -93,7 +93,7 @@ namespace Microsoft.OneFuzz.Service {
);
}
return OneFuzzResultVoid.Ok();
return OneFuzzResultVoid.Ok;
}
public async Async.Task<NetworkSecurityGroupResource?> GetNsg(string name) {

View File

@ -4,7 +4,7 @@ using ApiService.OneFuzzLib.Orm;
namespace Microsoft.OneFuzz.Service;
public interface IPoolOperations {
public Async.Task<OneFuzzResult<Pool>> GetByName(string poolName);
public Async.Task<OneFuzzResult<Pool>> GetByName(PoolName poolName);
Task<bool> ScheduleWorkset(Pool pool, WorkSet workSet);
IAsyncEnumerable<Pool> GetByClientId(Guid clientId);
}
@ -16,8 +16,8 @@ public class PoolOperations : StatefulOrm<Pool, PoolState>, IPoolOperations {
}
public async Async.Task<OneFuzzResult<Pool>> GetByName(string poolName) {
var pools = QueryAsync(filter: $"PartitionKey eq '{poolName}'");
public async Async.Task<OneFuzzResult<Pool>> GetByName(PoolName poolName) {
var pools = QueryAsync(filter: $"PartitionKey eq '{poolName.String}'");
if (pools == null || await pools.CountAsync() == 0) {
return OneFuzzResult<Pool>.Error(ErrorCode.INVALID_REQUEST, "unable to find pool");

View File

@ -6,7 +6,7 @@ namespace Microsoft.OneFuzz.Service;
public interface IScalesetOperations : IOrm<Scaleset> {
IAsyncEnumerable<Scaleset> Search();
public IAsyncEnumerable<Scaleset?> SearchByPool(string poolName);
public IAsyncEnumerable<Scaleset?> SearchByPool(PoolName poolName);
public Async.Task UpdateConfigs(Scaleset scaleSet);
@ -29,8 +29,8 @@ public class ScalesetOperations : StatefulOrm<Scaleset, ScalesetState>, IScalese
return QueryAsync();
}
public IAsyncEnumerable<Scaleset> SearchByPool(string poolName) {
return QueryAsync(filter: $"pool_name eq '{poolName}'");
public IAsyncEnumerable<Scaleset> SearchByPool(PoolName poolName) {
return QueryAsync(filter: $"PartitionKey eq '{poolName}'");
}

View File

@ -182,7 +182,7 @@ public class Scheduler : IScheduler {
return (bucketConfig, workUnit);
}
record struct BucketId(Os os, Guid jobId, (string, string)? vm, string? pool, string setupContainer, bool? reboot, Guid? unique);
record struct BucketId(Os os, Guid jobId, (string, string)? vm, PoolName? pool, string setupContainer, bool? reboot, Guid? unique);
private ILookup<BucketId, Task> BucketTasks(IEnumerable<Task> tasks) {
@ -205,7 +205,7 @@ public class Scheduler : IScheduler {
}
// check for multiple VMs for 1.0.0 and later tasks
string? pool = task.Config.Pool?.PoolName;
var pool = task.Config.Pool?.PoolName;
if ((task.Config.Pool?.Count ?? 0) > 1) {
unique = Guid.NewGuid();
}
@ -219,5 +219,3 @@ public class Scheduler : IScheduler {
});
}
}

View File

@ -80,7 +80,7 @@ public class VmssOperations : IVmssOperations {
}
var _ = await res.UpdateAsync(WaitUntil.Started, patch);
_log.Info($"VM extensions updated: {name}");
return OneFuzzResultVoid.Ok();
return OneFuzzResultVoid.Ok;
} else {
return OneFuzzResultVoid.Error(canUpdate.ErrorV);
@ -169,13 +169,13 @@ public class VmssOperations : IVmssOperations {
_log.WithHttpStatus((r.GetRawResponse().Status, r.GetRawResponse().ReasonPhrase)).Error(msg);
return OneFuzzResultVoid.Error(ErrorCode.UNABLE_TO_UPDATE, msg);
} else {
return OneFuzzResultVoid.Ok();
return OneFuzzResultVoid.Ok;
}
} catch (Exception ex) when (ex is RequestFailedException || ex is CloudException) {
if (ex.Message.Contains(INSTANCE_NOT_FOUND) && protectFromScaleIn == false) {
_log.Info($"Tried to remove scale in protection on node {name} {vmId} but instance no longer exists");
return OneFuzzResultVoid.Ok();
return OneFuzzResultVoid.Ok;
} else {
var msg = $"failed to update scale in protection on vm {vmId} for scaleset {name}";
_log.Exception(ex, msg);

View File

@ -143,10 +143,9 @@ public class EntityConverter {
});
}
public string ToJsonString<T>(T typedEntity) {
var serialized = JsonSerializer.Serialize(typedEntity, _options);
return serialized;
}
public string ToJsonString<T>(T typedEntity) => JsonSerializer.Serialize(typedEntity, _options);
public T? FromJsonString<T>(string value) => JsonSerializer.Deserialize<T>(value, _options);
public TableEntity ToTableEntity<T>(T typedEntity) where T : EntityBase {
if (typedEntity == null) {
@ -211,8 +210,11 @@ public class EntityConverter {
return Guid.Parse(entity.GetString(ef.kind.ToString()));
else if (ef.type == typeof(int))
return int.Parse(entity.GetString(ef.kind.ToString()));
else if (ef.type == typeof(PoolName))
// TODO: this should be able to be generic over any ValidatedString
return PoolName.Parse(entity.GetString(ef.kind.ToString()));
else {
throw new Exception("invalid ");
throw new Exception($"invalid partition or row key type of {info.type} property {name}: {ef.type}");
}
}
@ -247,7 +249,6 @@ public class EntityConverter {
outputType = typeProvider.GetTypeInfo(v);
}
if (objType == typeof(string)) {
var value = entity.GetString(fieldName);
if (value.StartsWith('[') || value.StartsWith('{') || value == "null") {
@ -283,6 +284,3 @@ public class EntityConverter {
}
}

View File

@ -11,6 +11,9 @@ namespace ApiService.OneFuzzLib.Orm {
public static string PartitionKey(string partitionKey)
=> TableClient.CreateQueryFilter($"PartitionKey eq {partitionKey}");
public static string PartitionKey(Guid partitionKey)
=> TableClient.CreateQueryFilter($"PartitionKey eq {partitionKey}");
public static string RowKey(string rowKey)
=> TableClient.CreateQueryFilter($"RowKey eq {rowKey}");

View File

@ -24,6 +24,9 @@ public sealed class TestContext : IOnefuzzContext {
NodeTasksOperations = new NodeTasksOperations(logTracer, this);
TaskEventOperations = new TaskEventOperations(logTracer, this);
NodeMessageOperations = new NodeMessageOperations(logTracer, this);
ConfigOperations = new ConfigOperations(logTracer, this);
UserCredentials = new UserCredentials(logTracer, ConfigOperations);
}
public TestEvents Events { get; set; } = new();
@ -36,6 +39,7 @@ public sealed class TestContext : IOnefuzzContext {
Node n => NodeOperations.Insert(n),
Job j => JobOperations.Insert(j),
NodeTasks nt => NodeTasksOperations.Insert(nt),
InstanceConfig ic => ConfigOperations.Insert(ic),
_ => throw new NotImplementedException($"Need to add an TestContext.InsertAll case for {x.GetType()} entities"),
}));
@ -48,6 +52,7 @@ public sealed class TestContext : IOnefuzzContext {
public IStorage Storage { get; }
public ICreds Creds { get; }
public IContainers Containers { get; }
public IUserCredentials UserCredentials { get; set; }
public IRequestHandling RequestHandling { get; }
@ -57,12 +62,12 @@ public sealed class TestContext : IOnefuzzContext {
public INodeTasksOperations NodeTasksOperations { get; }
public ITaskEventOperations TaskEventOperations { get; }
public INodeMessageOperations NodeMessageOperations { get; }
public IConfigOperations ConfigOperations { get; }
// -- Remainder not implemented --
public IConfig Config => throw new System.NotImplementedException();
public IConfigOperations ConfigOperations => throw new System.NotImplementedException();
public IDiskOperations DiskOperations => throw new System.NotImplementedException();
@ -92,8 +97,6 @@ public sealed class TestContext : IOnefuzzContext {
public ISecretsOperations SecretsOperations => throw new System.NotImplementedException();
public IUserCredentials UserCredentials => throw new System.NotImplementedException();
public IVmOperations VmOperations => throw new System.NotImplementedException();
public IVmssOperations VmssOperations => throw new System.NotImplementedException();

View File

@ -11,16 +11,16 @@ enum RequestType {
Agent,
}
sealed class TestEndpointAuthorization : IEndpointAuthorization {
sealed class TestEndpointAuthorization : EndpointAuthorization {
private readonly RequestType _type;
private readonly IOnefuzzContext _context;
public TestEndpointAuthorization(RequestType type, IOnefuzzContext context) {
public TestEndpointAuthorization(RequestType type, ILogTracer log, IOnefuzzContext context) : base(context, log) {
_type = type;
_context = context;
}
public Task<HttpResponseData> CallIf(
public override Task<HttpResponseData> CallIf(
HttpRequestData req,
Func<HttpRequestData, Task<HttpResponseData>> method,
bool allowUser = false,

View File

@ -19,6 +19,8 @@ public sealed class TestServiceConfiguration : IServiceConfig {
public string? ApplicationInsightsInstrumentationKey { get; set; } = "TestAppInsightsInstrumentationKey";
public string? OneFuzzInstanceName => "UnitTestInstance";
// -- Remainder not implemented --
public LogDestination[] LogDestinations { get => throw new System.NotImplementedException(); set => throw new System.NotImplementedException(); }
@ -42,8 +44,6 @@ public sealed class TestServiceConfiguration : IServiceConfig {
public string? OneFuzzInstance => throw new System.NotImplementedException();
public string? OneFuzzInstanceName => throw new System.NotImplementedException();
public string? OneFuzzKeyvault => throw new System.NotImplementedException();
public string? OneFuzzMonitor => throw new System.NotImplementedException();

View File

@ -0,0 +1,19 @@
using System.Threading.Tasks;
using Microsoft.Azure.Functions.Worker.Http;
using Microsoft.OneFuzz.Service;
using Async = System.Threading.Tasks;
namespace Tests.Fakes;
sealed class TestUserCredentials : UserCredentials {
private readonly OneFuzzResult<UserInfo> _tokenResult;
public TestUserCredentials(ILogTracer log, IConfigOperations instanceConfig, OneFuzzResult<UserInfo> tokenResult)
: base(log, instanceConfig) {
_tokenResult = tokenResult;
}
public override Task<OneFuzzResult<UserInfo>> ParseJwtToken(HttpRequestData req) => Async.Task.FromResult(_tokenResult);
}

View File

@ -29,7 +29,7 @@ public abstract class AgentEventsTestsBase : FunctionTestBase {
readonly Guid jobId = Guid.NewGuid();
readonly Guid taskId = Guid.NewGuid();
readonly Guid machineId = Guid.NewGuid();
readonly string poolName = $"pool-{Guid.NewGuid()}";
readonly PoolName poolName = PoolName.Parse($"pool-{Guid.NewGuid()}");
readonly Guid poolId = Guid.NewGuid();
readonly string poolVersion = $"version-{Guid.NewGuid()}";

View File

@ -27,7 +27,7 @@ public abstract class InfoTestBase : FunctionTestBase {
[Fact]
public async Async.Task TestInfo_WithoutAuthorization_IsRejected() {
var auth = new TestEndpointAuthorization(RequestType.NoAuthorization, Context);
var auth = new TestEndpointAuthorization(RequestType.NoAuthorization, Logger, Context);
var func = new Info(auth, Context);
var result = await func.Run(TestHttpRequestData.Empty("GET"));
@ -36,7 +36,7 @@ public abstract class InfoTestBase : FunctionTestBase {
[Fact]
public async Async.Task TestInfo_WithAgentCredentials_IsRejected() {
var auth = new TestEndpointAuthorization(RequestType.Agent, Context);
var auth = new TestEndpointAuthorization(RequestType.Agent, Logger, Context);
var func = new Info(auth, Context);
var result = await func.Run(TestHttpRequestData.Empty("GET"));
@ -52,7 +52,7 @@ public abstract class InfoTestBase : FunctionTestBase {
await containerClient.CreateAsync();
await containerClient.GetBlobClient("instance_id").UploadAsync(new BinaryData(instanceId));
var auth = new TestEndpointAuthorization(RequestType.User, Context);
var auth = new TestEndpointAuthorization(RequestType.User, Logger, Context);
var func = new Info(auth, Context);
var result = await func.Run(TestHttpRequestData.Empty("GET"));

View File

@ -0,0 +1,291 @@

using System;
using System.Collections.Generic;
using System.Linq;
using System.Net;
using Microsoft.OneFuzz.Service;
using Tests.Fakes;
using Xunit;
using Xunit.Abstractions;
using Async = System.Threading.Tasks;
namespace Tests.Functions;
[Trait("Category", "Integration")]
public class AzureStorageNodeTest : NodeTestBase {
public AzureStorageNodeTest(ITestOutputHelper output)
: base(output, Integration.AzureStorage.FromEnvironment()) { }
}
public class AzuriteNodeTest : NodeTestBase {
public AzuriteNodeTest(ITestOutputHelper output)
: base(output, new Integration.AzuriteStorage()) { }
}
public abstract class NodeTestBase : FunctionTestBase {
public NodeTestBase(ITestOutputHelper output, IStorage storage)
: base(output, storage) { }
private readonly Guid _machineId = Guid.NewGuid();
private readonly Guid _scalesetId = Guid.NewGuid();
private readonly PoolName _poolName = PoolName.Parse($"pool-{Guid.NewGuid()}");
private readonly string _version = Guid.NewGuid().ToString();
[Fact]
public async Async.Task Search_SpecificNode_NotFound_ReturnsNotFound() {
var auth = new TestEndpointAuthorization(RequestType.User, Logger, Context);
var req = new NodeSearch(MachineId: _machineId);
var func = new NodeFunction(Logger, auth, Context);
var result = await func.Run(TestHttpRequestData.FromJson("GET", req));
Assert.Equal(HttpStatusCode.BadRequest, result.StatusCode);
}
[Fact]
public async Async.Task Search_SpecificNode_Found_ReturnsOk() {
await Context.InsertAll(
new Node(_poolName, _machineId, null, _version));
var auth = new TestEndpointAuthorization(RequestType.User, Logger, Context);
var req = new NodeSearch(MachineId: _machineId);
var func = new NodeFunction(Logger, auth, Context);
var result = await func.Run(TestHttpRequestData.FromJson("GET", req));
Assert.Equal(HttpStatusCode.OK, result.StatusCode);
// make sure we got the data from the table
var deserialized = BodyAs<NodeSearchResult>(result);
Assert.Equal(_version, deserialized.Version);
}
[Fact]
public async Async.Task Search_MultipleNodes_CanFindNone() {
var auth = new TestEndpointAuthorization(RequestType.User, Logger, Context);
var req = new NodeSearch();
var func = new NodeFunction(Logger, auth, Context);
var result = await func.Run(TestHttpRequestData.FromJson("GET", req));
Assert.Equal(HttpStatusCode.OK, result.StatusCode);
Assert.Equal(0, result.Body.Length);
}
[Fact]
public async Async.Task Search_MultipleNodes_ByPoolName() {
await Context.InsertAll(
new Node(PoolName.Parse("otherPool"), Guid.NewGuid(), null, _version),
new Node(_poolName, Guid.NewGuid(), null, _version),
new Node(_poolName, Guid.NewGuid(), null, _version));
var req = new NodeSearch(PoolName: _poolName);
var auth = new TestEndpointAuthorization(RequestType.User, Logger, Context);
var func = new NodeFunction(Logger, auth, Context);
var result = await func.Run(TestHttpRequestData.FromJson("GET", req));
Assert.Equal(HttpStatusCode.OK, result.StatusCode);
// make sure we got the data from the table
var deserialized = BodyAs<NodeSearchResult[]>(result);
Assert.Equal(2, deserialized.Length);
}
[Fact]
public async Async.Task Search_MultipleNodes_ByScalesetId() {
await Context.InsertAll(
new Node(_poolName, Guid.NewGuid(), null, _version, ScalesetId: _scalesetId),
new Node(_poolName, Guid.NewGuid(), null, _version, ScalesetId: _scalesetId),
new Node(_poolName, Guid.NewGuid(), null, _version));
var req = new NodeSearch(ScalesetId: _scalesetId);
var auth = new TestEndpointAuthorization(RequestType.User, Logger, Context);
var func = new NodeFunction(Logger, auth, Context);
var result = await func.Run(TestHttpRequestData.FromJson("GET", req));
Assert.Equal(HttpStatusCode.OK, result.StatusCode);
// make sure we got the data from the table
var deserialized = BodyAs<NodeSearchResult[]>(result);
Assert.Equal(2, deserialized.Length);
}
[Fact]
public async Async.Task Search_MultipleNodes_ByState() {
await Context.InsertAll(
new Node(_poolName, Guid.NewGuid(), null, _version, State: NodeState.Busy),
new Node(_poolName, Guid.NewGuid(), null, _version, State: NodeState.Busy),
new Node(_poolName, Guid.NewGuid(), null, _version));
var req = new NodeSearch(State: new List<NodeState> { NodeState.Busy });
var auth = new TestEndpointAuthorization(RequestType.User, Logger, Context);
var func = new NodeFunction(Logger, auth, Context);
var result = await func.Run(TestHttpRequestData.FromJson("GET", req));
Assert.Equal(HttpStatusCode.OK, result.StatusCode);
// make sure we got the data from the table
var deserialized = BodyAs<NodeSearchResult[]>(result);
Assert.Equal(2, deserialized.Length);
}
[Fact]
public async Async.Task Search_MultipleNodes_ByMultipleStates() {
await Context.InsertAll(
new Node(_poolName, Guid.NewGuid(), null, _version, State: NodeState.Free),
new Node(_poolName, Guid.NewGuid(), null, _version, State: NodeState.Busy),
new Node(_poolName, Guid.NewGuid(), null, _version, State: NodeState.Busy),
new Node(_poolName, Guid.NewGuid(), null, _version));
var req = new NodeSearch(State: new List<NodeState> { NodeState.Free, NodeState.Busy });
var auth = new TestEndpointAuthorization(RequestType.User, Logger, Context);
var func = new NodeFunction(Logger, auth, Context);
var result = await func.Run(TestHttpRequestData.FromJson("GET", req));
Assert.Equal(HttpStatusCode.OK, result.StatusCode);
// make sure we got the data from the table
var deserialized = BodyAs<NodeSearchResult[]>(result);
Assert.Equal(3, deserialized.Length);
}
[Theory]
[InlineData("PATCH")]
[InlineData("POST")]
[InlineData("DELETE")]
public async Async.Task RequiresAdmin(string method) {
// config must be found
await Context.InsertAll(
new InstanceConfig(Context.ServiceConfiguration.OneFuzzInstanceName!));
// must be a user to auth
var auth = new TestEndpointAuthorization(RequestType.User, Logger, Context);
// override the found user credentials
var userInfo = new UserInfo(ApplicationId: Guid.NewGuid(), ObjectId: Guid.NewGuid(), "upn");
Context.UserCredentials = new TestUserCredentials(Logger, Context.ConfigOperations, OneFuzzResult<UserInfo>.Ok(userInfo));
var req = new NodeGet(MachineId: _machineId);
var func = new NodeFunction(Logger, auth, Context);
var result = await func.Run(TestHttpRequestData.FromJson(method, req));
Assert.Equal(HttpStatusCode.BadRequest, result.StatusCode);
var err = BodyAs<Error>(result);
Assert.Equal(ErrorCode.UNAUTHORIZED, err.Code);
Assert.Contains("pool modification disabled", err.Errors?.Single());
}
[Theory]
[InlineData("PATCH")]
[InlineData("POST")]
[InlineData("DELETE")]
public async Async.Task RequiresAdmin_CanBeDisabled(string method) {
// disable requiring admin privileges
await Context.InsertAll(
new InstanceConfig(Context.ServiceConfiguration.OneFuzzInstanceName!) {
RequireAdminPrivileges = false
});
// must be a user to auth
var auth = new TestEndpointAuthorization(RequestType.User, Logger, Context);
// override the found user credentials
var userInfo = new UserInfo(ApplicationId: Guid.NewGuid(), ObjectId: Guid.NewGuid(), "upn");
Context.UserCredentials = new TestUserCredentials(Logger, Context.ConfigOperations, OneFuzzResult<UserInfo>.Ok(userInfo));
var req = new NodeGet(MachineId: _machineId);
var func = new NodeFunction(Logger, auth, Context);
var result = await func.Run(TestHttpRequestData.FromJson(method, req));
// we will fail with BadRequest but due to not being able to find the Node,
// not because of UNAUTHORIZED
Assert.Equal(HttpStatusCode.BadRequest, result.StatusCode);
Assert.Equal(ErrorCode.UNABLE_TO_FIND, BodyAs<Error>(result).Code);
}
[Theory]
[InlineData("PATCH")]
[InlineData("POST")]
[InlineData("DELETE")]
public async Async.Task UserCanBeAdmin(string method) {
var userObjectId = Guid.NewGuid();
// config specifies that user is admin
await Context.InsertAll(
new InstanceConfig(Context.ServiceConfiguration.OneFuzzInstanceName!) {
Admins = new[] { userObjectId }
});
// must be a user to auth
var auth = new TestEndpointAuthorization(RequestType.User, Logger, Context);
// override the found user credentials
var userInfo = new UserInfo(ApplicationId: Guid.NewGuid(), ObjectId: userObjectId, "upn");
Context.UserCredentials = new TestUserCredentials(Logger, Context.ConfigOperations, OneFuzzResult<UserInfo>.Ok(userInfo));
var req = new NodeGet(MachineId: _machineId);
var func = new NodeFunction(Logger, auth, Context);
var result = await func.Run(TestHttpRequestData.FromJson(method, req));
// we will fail with BadRequest but due to not being able to find the Node,
// not because of UNAUTHORIZED
Assert.Equal(HttpStatusCode.BadRequest, result.StatusCode);
Assert.Equal(ErrorCode.UNABLE_TO_FIND, BodyAs<Error>(result).Code);
}
[Theory]
[InlineData("PATCH")]
[InlineData("POST")]
[InlineData("DELETE")]
public async Async.Task EnablingAdminForAnotherUserDoesNotPermitThisUser(string method) {
var userObjectId = Guid.NewGuid();
var otherObjectId = Guid.NewGuid();
// config specifies that a different user is admin
await Context.InsertAll(
new InstanceConfig(Context.ServiceConfiguration.OneFuzzInstanceName!) {
Admins = new[] { otherObjectId }
});
// must be a user to auth
var auth = new TestEndpointAuthorization(RequestType.User, Logger, Context);
// override the found user credentials
var userInfo = new UserInfo(ApplicationId: Guid.NewGuid(), ObjectId: userObjectId, "upn");
Context.UserCredentials = new TestUserCredentials(Logger, Context.ConfigOperations, OneFuzzResult<UserInfo>.Ok(userInfo));
var req = new NodeGet(MachineId: _machineId);
var func = new NodeFunction(Logger, auth, Context);
var result = await func.Run(TestHttpRequestData.FromJson(method, req));
Assert.Equal(HttpStatusCode.BadRequest, result.StatusCode);
var err = BodyAs<Error>(result);
Assert.Equal(ErrorCode.UNAUTHORIZED, err.Code);
Assert.Contains("not authorized to manage pools", err.Errors?.Single());
}
[Theory]
[InlineData("PATCH")]
[InlineData("POST")]
[InlineData("DELETE")]
public async Async.Task CanPerformOperation(string method) {
// disable requiring admin privileges
await Context.InsertAll(
new InstanceConfig(Context.ServiceConfiguration.OneFuzzInstanceName!) {
RequireAdminPrivileges = false
},
new Node(_poolName, _machineId, null, _version));
// must be a user to auth
var auth = new TestEndpointAuthorization(RequestType.User, Logger, Context);
// override the found user credentials
var userInfo = new UserInfo(ApplicationId: Guid.NewGuid(), ObjectId: Guid.NewGuid(), "upn");
Context.UserCredentials = new TestUserCredentials(Logger, Context.ConfigOperations, OneFuzzResult<UserInfo>.Ok(userInfo));
// all of these operations use NodeGet
var req = new NodeGet(MachineId: _machineId);
var func = new NodeFunction(Logger, auth, Context);
var result = await func.Run(TestHttpRequestData.FromJson(method, req));
Assert.Equal(HttpStatusCode.OK, result.StatusCode);
}
}

View File

@ -7,6 +7,7 @@ using Azure.Storage;
using Azure.Storage.Blobs;
using Microsoft.Azure.Functions.Worker.Http;
using Microsoft.OneFuzz.Service;
using Microsoft.OneFuzz.Service.OneFuzzLib.Orm;
using Tests.Fakes;
using Xunit.Abstractions;
@ -60,6 +61,9 @@ public abstract class FunctionTestBase : IDisposable {
return sr.ReadToEnd();
}
protected static T BodyAs<T>(HttpResponseData data)
=> new EntityConverter().FromJsonString<T>(BodyAsString(data)) ?? throw new Exception($"unable to deserialize body as {typeof(T)}");
public void Dispose() {
var (accountName, accountKey) = _storage.GetStorageAccountNameAndKey("").Result; // sync for test impls
if (accountName is not null && accountKey is not null) {
@ -72,6 +76,7 @@ public abstract class FunctionTestBase : IDisposable {
new StorageSharedKeyCredential(accountName, accountKey));
}
}
private void CleanupBlobs(Uri endpoint, StorageSharedKeyCredential creds) {
var blobClient = new BlobServiceClient(endpoint, creds);

View File

@ -74,21 +74,26 @@ namespace Tests {
);
}
public static Gen<Node> Node() {
return Arb.Generate<Tuple<Tuple<DateTimeOffset?, string, Guid?, Guid, NodeState>, Tuple<Guid?, DateTimeOffset, string, bool, bool, bool>>>().Select(
arg => new Node(
public static Gen<PoolName> PoolNameGen { get; }
= from name in Arb.Generate<NonEmptyString>()
where PoolName.TryParse(name.Get, out _)
select PoolName.Parse(name.Get);
public static Gen<Node> Node { get; }
= from arg in Arb.Generate<Tuple<Tuple<DateTimeOffset?, Guid?, Guid, NodeState>, Tuple<Guid?, DateTimeOffset, string, bool, bool, bool>>>()
from poolName in PoolNameGen
select new Node(
InitializedAt: arg.Item1.Item1,
PoolName: arg.Item1.Item2,
PoolName: poolName,
PoolId: arg.Item1.Item3,
MachineId: arg.Item1.Item4,
State: arg.Item1.Item5,
MachineId: arg.Item1.Item3,
State: arg.Item1.Item4,
ScalesetId: arg.Item2.Item1,
Heartbeat: arg.Item2.Item2,
Version: arg.Item2.Item3,
ReimageRequested: arg.Item2.Item4,
DeleteRequested: arg.Item2.Item5,
DebugKeepNode: arg.Item2.Item6));
}
DebugKeepNode: arg.Item2.Item6);
public static Gen<ProxyForward> ProxyForward() {
return Arb.Generate<Tuple<Tuple<string, long, Guid, Guid, Guid?, long>, Tuple<IPv4Address, DateTimeOffset>>>().Select(
@ -200,20 +205,20 @@ namespace Tests {
)
);
}
public static Gen<Scaleset> Scaleset() {
return Arb.Generate<Tuple<
Tuple<string, Guid, ScalesetState, Authentication?, string, string, string>,
public static Gen<Scaleset> Scaleset { get; }
= from arg in Arb.Generate<Tuple<
Tuple<Guid, ScalesetState, Authentication?, string, string, string>,
Tuple<int, bool, bool, bool, Error?, List<ScalesetNodeState>, Guid?>,
Tuple<Guid?, Dictionary<string, string>>>>().Select(
arg =>
new Scaleset(
PoolName: arg.Item1.Item1,
ScalesetId: arg.Item1.Item2,
State: arg.Item1.Item3,
Auth: arg.Item1.Item4,
VmSku: arg.Item1.Item5,
Image: arg.Item1.Item6,
Region: arg.Item1.Item7,
Tuple<Guid?, Dictionary<string, string>>>>()
from poolName in PoolNameGen
select new Scaleset(
PoolName: poolName,
ScalesetId: arg.Item1.Item1,
State: arg.Item1.Item2,
Auth: arg.Item1.Item3,
VmSku: arg.Item1.Item4,
Image: arg.Item1.Item5,
Region: arg.Item1.Item6,
Size: arg.Item2.Item1,
SpotInstance: arg.Item2.Item2,
@ -224,11 +229,7 @@ namespace Tests {
ClientId: arg.Item2.Item7,
ClientObjectId: arg.Item3.Item1,
Tags: arg.Item3.Item2
)
);
}
Tags: arg.Item3.Item2);
public static Gen<Webhook> Webhook() {
return Arb.Generate<Tuple<Guid, string, Uri?, List<EventType>, string, WebhookMessageFormat>>().Select(
@ -331,7 +332,9 @@ namespace Tests {
public class OrmArb {
public static Arbitrary<Version> Vresion() {
public static Arbitrary<PoolName> PoolName { get; } = OrmGenerators.PoolNameGen.ToArbitrary();
public static Arbitrary<Version> Version() {
return Arb.From(OrmGenerators.Version());
}
@ -348,7 +351,7 @@ namespace Tests {
}
public static Arbitrary<Node> Node() {
return Arb.From(OrmGenerators.Node());
return Arb.From(OrmGenerators.Node);
}
public static Arbitrary<ProxyForward> ProxyForward() {
@ -383,9 +386,8 @@ namespace Tests {
return Arb.From(OrmGenerators.Task());
}
public static Arbitrary<Scaleset> Scaleset() {
return Arb.From(OrmGenerators.Scaleset());
}
public static Arbitrary<Scaleset> Scaleset()
=> Arb.From(OrmGenerators.Scaleset);
public static Arbitrary<Webhook> Webhook() {
return Arb.From(OrmGenerators.Webhook());

View File

@ -234,7 +234,7 @@ namespace Tests {
[Fact]
public void TestEventSerialization() {
var expectedEvent = new EventMessage(Guid.NewGuid(), EventType.NodeHeartbeat, new EventNodeHeartbeat(Guid.NewGuid(), Guid.NewGuid(), "test Poool"), Guid.NewGuid(), "test");
var expectedEvent = new EventMessage(Guid.NewGuid(), EventType.NodeHeartbeat, new EventNodeHeartbeat(Guid.NewGuid(), Guid.NewGuid(), PoolName.Parse("test-Poool")), Guid.NewGuid(), "test");
var serialized = JsonSerializer.Serialize(expectedEvent, EntityConverter.GetJsonSerializerOptions());
var actualEvent = JsonSerializer.Deserialize<EventMessage>(serialized, EntityConverter.GetJsonSerializerOptions());
Assert.Equal(expectedEvent, actualEvent);

View File

@ -0,0 +1,28 @@
using System.Text.Json;
using Microsoft.OneFuzz.Service;
using Xunit;
namespace Tests;
public class ValidatedStringTests {
record ThingContainingPoolName(PoolName PoolName);
[Fact]
public void PoolNameValidatesOnDeserialization() {
var ex = Assert.Throws<JsonException>(() => JsonSerializer.Deserialize<ThingContainingPoolName>("{ \"PoolName\": \"is-not!-a-pool\" }"));
Assert.Equal("unable to parse input as a PoolName", ex.Message);
}
[Fact]
public void PoolNameDeserializesFromString() {
var result = JsonSerializer.Deserialize<ThingContainingPoolName>("{ \"PoolName\": \"is-a-pool\" }");
Assert.Equal("is-a-pool", result?.PoolName.String);
}
[Fact]
public void PoolNameSerializesToString() {
var result = JsonSerializer.Serialize(new ThingContainingPoolName(PoolName.Parse("is-a-pool")));
Assert.Equal("{\"PoolName\":\"is-a-pool\"}", result);
}
}